├── agentlightning ├── .gitignore ├── examples │ ├── .gitignore │ ├── calc_x │ │ ├── README.md │ │ ├── train.sh │ │ └── calc_agent.py │ └── spider │ │ ├── spider_eval │ │ ├── convert_dataset.py │ │ ├── parse.py │ │ └── exec_eval.py │ │ ├── README.md │ │ └── train.sh ├── scripts │ ├── restart_ray.sh │ ├── verl_git_diff.sh │ └── verl_git_apply.sh ├── agentlightning │ ├── cli │ │ ├── vllm.py │ │ └── agentops_server.py │ ├── __init__.py │ ├── logging.py │ ├── instrumentation │ │ ├── agentops_langchain.py │ │ ├── litellm.py │ │ ├── __init__.py │ │ ├── vllm.py │ │ ├── agentops.py │ │ └── verl_chat_scheduler.py │ ├── reward.py │ └── client.py ├── pyproject.toml └── README.md ├── eppo ├── code │ ├── __init__.py │ ├── tools │ │ ├── __init__.py │ │ ├── env_utils.py │ │ └── config.py │ ├── policy │ │ ├── __init__.py │ │ ├── base.py │ │ ├── utils.py │ │ └── ppo.py │ ├── network │ │ ├── base.py │ │ ├── mlp.py │ │ ├── __init__.py │ │ └── cnn.py │ └── env │ │ ├── __init__.py │ │ ├── game_logger.py │ │ └── logging.py ├── requirements.txt ├── exp │ └── atari_local.yml └── README.md ├── bootorl ├── model │ └── __init__.py ├── utils │ ├── __init__.py │ ├── timer.py │ ├── sample.py │ ├── renderer.py │ └── planning.py ├── environment.yml ├── LICENSE ├── Dockerfile ├── .gitignore ├── analysis │ ├── visualize_state.py │ ├── visualize_distribution.py │ └── calc_distance.py ├── README.md └── main │ ├── plan.py │ ├── analyze_plan.py │ ├── extract_distribution.py │ └── train.py ├── a2ls ├── conda_env.yml ├── video.py ├── logger.py ├── README.md └── encoder.py ├── CODE_OF_CONDUCT.md ├── LICENSE ├── SUPPORT.md ├── README.md ├── SECURITY.md ├── .github └── workflows │ └── codeql.yml └── .gitignore /agentlightning/.gitignore: -------------------------------------------------------------------------------- 1 | verl_old 2 | -------------------------------------------------------------------------------- /eppo/code/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. -------------------------------------------------------------------------------- /agentlightning/examples/.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | wandb/ 3 | outputs/ 4 | checkpoints/ 5 | calc-x-data.zip 6 | agentops.log 7 | -------------------------------------------------------------------------------- /eppo/code/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. -------------------------------------------------------------------------------- /bootorl/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
3 | 4 | from .gpt import GPT -------------------------------------------------------------------------------- /agentlightning/scripts/restart_ray.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ray stop 4 | env RAY_DEBUG=legacy HYDRA_FULL_ERROR=1 VLLM_USE_V1=1 ray start --head --dashboard-host=0.0.0.0 5 | -------------------------------------------------------------------------------- /eppo/code/policy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | from .base import POLICIES 4 | from .marl_policy import MARLPPOPolicy 5 | -------------------------------------------------------------------------------- /agentlightning/agentlightning/cli/vllm.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from vllm.entrypoints.cli.main import main 4 | 5 | from agentlightning.instrumentation.vllm import instrument_vllm 6 | 7 | 8 | if __name__ == "__main__": 9 | instrument_vllm() 10 | main() 11 | -------------------------------------------------------------------------------- /agentlightning/agentlightning/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1' 2 | 3 | from .client import VerlAgentClient, SamplingParameters, TaskData 4 | from .config import lightning_cli 5 | from .logging import configure_logger 6 | from .reward import reward 7 | from .trace import lightning_span_processor 8 | from .trainer import LitAgent, Trainer 9 | -------------------------------------------------------------------------------- /eppo/code/policy/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | from typing import Optional 4 | 5 | import gym 6 | from tianshou.policy import BasePolicy 7 | from utilsd.config import Registry 8 | 9 | 10 | class POLICIES(metaclass=Registry, name='policy'): 11 | pass -------------------------------------------------------------------------------- /a2ls/conda_env.yml: -------------------------------------------------------------------------------- 1 | name: autorrl 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.7 6 | - pytorch 7 | - torchvision 8 | - cudatoolkit=10.2 9 | - absl-py 10 | - pyparsing 11 | - pillow=6.1 12 | - pip: 13 | - termcolor 14 | - git+git://github.com/deepmind/dm_control.git 15 | - git+git://github.com/1nadequacy/dmc2gym.git 16 | - tb-nightly 17 | - imageio 18 | - imageio-ffmpeg 19 | - torchvision 20 | - scikit-image -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /bootorl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | from .argparser import ArgParser 5 | from .dataset import DiscretizedDataset 6 | from .dataset import NoisyDiscretizedDataset 7 | from .dataset import AnalyzeDiscretizedDataset 8 | from .dataset import load_environment 9 | from .trainer import Trainer 10 | from .tester import Tester 11 | from .timer import Timer 12 | from .planning import plan 13 | from .renderer import Renderer -------------------------------------------------------------------------------- /eppo/code/network/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | import torch 4 | import torch.nn as nn 5 | from utilsd.config import Registry 6 | 7 | 8 | class NETWORKS(metaclass=Registry, name='network'): 9 | pass 10 | 11 | class BaseNetwork(nn.Module): 12 | output_dim: int 13 | 14 | def load_weight(policy, path): 15 | assert isinstance(policy, nn.Module), 'Policy has to be an nn.Module to load weight.' 16 | policy.load_state_dict(torch.load(path, map_location='cpu')) -------------------------------------------------------------------------------- /agentlightning/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "agentlightning" 3 | version = "0.1" 4 | description = "Agent Lightning" 5 | readme = "README.md" 6 | requires-python = ">=3.9" 7 | dependencies = [ 8 | "graphviz", 9 | "psutil", 10 | "setproctitle", 11 | "flask", 12 | "agentops", 13 | ] 14 | 15 | [project.optional-dependencies] 16 | dev = [ 17 | "flake8", 18 | "pytest", 19 | "hatch", 20 | ] 21 | experiment = [ 22 | "random-word", 23 | ] 24 | 25 | [build-system] 26 | requires = ["hatchling"] 27 | build-backend = "hatchling.build" 28 | 29 | [tool.hatch.build.targets.wheel] 30 | packages = ["agentlightning"] 31 | -------------------------------------------------------------------------------- /bootorl/environment.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | name: boot 5 | channels: 6 | - defaults 7 | - conda-forge 8 | - pytorch 9 | dependencies: 10 | - python=3.8.5 11 | - pip>=20.2 12 | - conda>=4.9.2 13 | - patchelf=0.12 14 | - pytorch=1.10.0 15 | - torchvision=0.11.0 16 | - torchaudio=0.10.0 17 | - cudatoolkit=11.3 18 | - pip: 19 | - scikit-learn 20 | - pandas 21 | - git+https://github.com/rail-berkeley/d4rl.git@d842aa194b416e564e54b0730d9f934e3e32f854 22 | - git+https://github.com/openai/gym.git@66c431d4b3072a1db44d564dab812b9d23c06e14 23 | - tqdm>=4.51.0 -------------------------------------------------------------------------------- /eppo/code/policy/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 
2 | # Licensed under the MIT License. 3 | from tianshou.data import to_torch 4 | from utilsd import use_cuda 5 | from code.network.base import load_weight 6 | 7 | __all__ = ['chain_dedup', 'preprocess_obs', 'load_weight'] 8 | 9 | 10 | def chain_dedup(*iterables): 11 | seen = set() 12 | for iterable in iterables: 13 | for i in iterable: 14 | if i not in seen: 15 | seen.add(i) 16 | yield i 17 | 18 | 19 | def preprocess_obs(obs): 20 | return dict(to_torch(obs, device='cuda' if use_cuda() else 'cpu')) 21 | -------------------------------------------------------------------------------- /agentlightning/examples/calc_x/README.md: -------------------------------------------------------------------------------- 1 | # Calc-X Example 2 | 3 | This example requires a single node with one GPU of at least 40GB memory. 4 | 5 | 1. Download the data in parquet format from [here](https://drive.google.com/file/d/1FQMyKLLd6hP9dw9rfZn1EZOWNvKaDsqw/view?usp=sharing) and unzip it to the `data` folder: `unzip calc-x-data.zip -d data`. 6 | 2. Start ray: `bash ../../scripts/restart_ray.sh`. To use Wandb, you need to set the WANDB_API_KEY environment variable before starting ray. 7 | 3. Run the agent: `python calc_agent.py`. It automatically launches 4 agent workers by default. 8 | 4. In another terminal, launch the training server: `bash train.sh`. 9 | -------------------------------------------------------------------------------- /agentlightning/agentlightning/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def configure_logger(level: int = logging.INFO, name: str = "agentlightning") -> logging.Logger: 5 | logger = logging.getLogger(name) 6 | logger.handlers.clear() # clear existing handlers 7 | 8 | # log to stdout 9 | handler = logging.StreamHandler() 10 | handler.setLevel(level) 11 | formatter = logging.Formatter("%(asctime)s [%(levelname)s] (Process-%(process)d %(name)s) %(message)s") 12 | handler.setFormatter(formatter) 13 | logger.addHandler(handler) 14 | logger.setLevel(level) 15 | logger.propagate = False # prevent double logging 16 | return logger 17 | -------------------------------------------------------------------------------- /eppo/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | addict==2.4.0 3 | ale-py==0.7.5 4 | atari-py==0.2.9 5 | cachetools==5.0.0 6 | click==8.1.3 7 | glfw==2.5.3 8 | google-auth==2.6.6 9 | google-auth-oauthlib==0.4.6 10 | grpcio==1.46.1 11 | gym==0.23.1 12 | gym-notices==0.0.6 13 | importlib-metadata==4.11.3 14 | importlib-resources==5.7.1 15 | lockfile==0.12.2 16 | lz4==4.0.0 17 | markdown==3.3.7 18 | oauthlib==3.2.1 19 | opencv-python==4.5.5.64 20 | pathlib==1.0.1 21 | protobuf==3.20.2 22 | pyasn1==0.4.8 23 | pyasn1-modules==0.2.8 24 | pygame==2.1.0 25 | requests-oauthlib==1.3.1 26 | rsa==4.8 27 | ruamel-yaml==0.17.21 28 | ruamel-yaml-clib==0.2.6 29 | tensorboard==2.9.0 30 | tensorboard-data-server==0.6.1 31 | tensorboard-plugin-wit==1.8.1 32 | tianshou==0.4.8 33 | utilsd==0.0.15 -------------------------------------------------------------------------------- /agentlightning/agentlightning/cli/agentops_server.py: -------------------------------------------------------------------------------- 1 | import time 2 | from agentlightning.instrumentation.agentops import AgentOpsServerManager 3 | 4 | 5 | if __name__ == "__main__": 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser(description="Start AgentOps server") 9 | 
parser.add_argument("--daemon", action="store_true", help="Run server as a daemon") 10 | parser.add_argument("--port", type=int, default=8002, help="Port to run the server on") 11 | args = parser.parse_args() 12 | 13 | manager = AgentOpsServerManager(daemon=args.daemon, port=args.port) 14 | try: 15 | manager.start() 16 | # Wait forever 17 | while True: 18 | time.sleep(1) 19 | except KeyboardInterrupt: 20 | manager.stop() 21 | -------------------------------------------------------------------------------- /agentlightning/agentlightning/instrumentation/agentops_langchain.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from agentops.integration.callbacks.langchain import LangchainCallbackHandler 3 | 4 | 5 | original_on_chain_start = LangchainCallbackHandler.on_chain_start 6 | 7 | 8 | def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> None: 9 | if "name" in kwargs: 10 | if serialized is None: 11 | serialized = {} 12 | serialized = serialized.copy() 13 | serialized["name"] = kwargs["name"] 14 | if "run_id" in kwargs: 15 | if serialized is None: 16 | serialized = {} 17 | serialized = serialized.copy() 18 | if "id" not in serialized: 19 | serialized["id"] = kwargs["run_id"] 20 | return original_on_chain_start(self, serialized, inputs, **kwargs) 21 | 22 | 23 | def instrument_agentops_langchain(): 24 | LangchainCallbackHandler.on_chain_start = on_chain_start 25 | -------------------------------------------------------------------------------- /agentlightning/agentlightning/instrumentation/litellm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any 2 | 3 | from litellm.integrations.opentelemetry import OpenTelemetry 4 | 5 | # It's unclear whether or not this file is useful 6 | # It seems that LiteLLM owns its own telemetry from their own entrance 7 | # https://docs.litellm.ai/docs/observability/agentops_integration 8 | 9 | original_set_attributes = OpenTelemetry.set_attributes 10 | 11 | 12 | def patched_set_attributes(self, span: Any, kwargs, response_obj: Optional[Any]): 13 | original_set_attributes(self, span, kwargs, response_obj) 14 | # Add custom attributes 15 | if response_obj.get("prompt_token_ids"): 16 | span.set_attribute( 17 | "prompt_token_ids", list(response_obj.get("prompt_token_ids")) 18 | ) 19 | if response_obj.get("response_token_ids"): 20 | span.set_attribute( 21 | "response_token_ids", list(response_obj.get("response_token_ids")[0]) 22 | ) 23 | 24 | 25 | def instrument_litellm(): 26 | OpenTelemetry.set_attributes = patched_set_attributes 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /bootorl/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /eppo/code/network/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
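# A small fully-connected backbone: stacked Linear+ReLU layers with configurable input/hidden/output sizes, registered in the NETWORKS registry so it can be selected from the YAML experiment configs.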
3 | from code.network.base import NETWORKS 4 | import torch 5 | import torch.nn as nn 6 | 7 | @NETWORKS.register_module() 8 | class MLP(nn.Module): 9 | def __init__(self, 10 | input_dims: int, 11 | hidden_dim: int=64, 12 | output_dim: int=32, 13 | num_layers: int=2): 14 | super().__init__() 15 | self.raw_fc = nn.Sequential() 16 | self.output_dim = output_dim 17 | if num_layers==1: 18 | layers = [nn.Linear(input_dims, output_dim), nn.ReLU()] 19 | else: 20 | layers = [nn.Linear(input_dims, hidden_dim), nn.ReLU()] 21 | for i in range(num_layers-1): 22 | layers.append(nn.Linear(hidden_dim,hidden_dim)) 23 | layers.append(nn.ReLU()) 24 | layers.append(nn.Linear(hidden_dim,output_dim)) 25 | layers.append(nn.ReLU()) 26 | self.layers = nn.ModuleList(layers) 27 | 28 | def forward(self,x): 29 | for i, l in enumerate(self.layers): 30 | x = self.layers[i](x) 31 | return x 32 | 33 | -------------------------------------------------------------------------------- /agentlightning/examples/spider/spider_eval/convert_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | 5 | data_dir = "data" 6 | target_data_dir = "data" 7 | columns = ["db_id", "question", "query"] 8 | 9 | dev_path = os.path.join(data_dir, "dev.json") 10 | dev_df = pd.read_json(dev_path) 11 | print(dev_df) 12 | dev_df[columns].to_parquet(os.path.join(target_data_dir, "dev.parquet"), index=False) 13 | 14 | train_path = os.path.join(data_dir, "train_spider.json") 15 | train_df = pd.read_json(train_path) 16 | print(train_df) 17 | train_df[columns].to_parquet(os.path.join(target_data_dir, "train_spider.parquet"), index=False) 18 | 19 | test_path = os.path.join(data_dir, "test.json") 20 | test_df = pd.read_json(test_path) 21 | print(test_df) 22 | test_df[columns].to_parquet(os.path.join(target_data_dir, "test.parquet"), index=False) 23 | 24 | # Select 100 of test df as test_dev 25 | test_dev_df = test_df.sample(n=100, random_state=42) 26 | test_dev_df[columns].to_parquet(os.path.join(target_data_dir, "test_dev.parquet"), index=False) 27 | 28 | # Select 500 of test df as test_dev 29 | test_dev_df = test_df.sample(n=500, random_state=0) 30 | test_dev_df[columns].to_parquet(os.path.join(target_data_dir, "test_dev_500.parquet"), index=False) 31 | -------------------------------------------------------------------------------- /a2ls/video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
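# Records frames rendered by a gym environment and saves them as a video file via imageio.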
3 | import imageio 4 | import os 5 | import numpy as np 6 | 7 | 8 | class VideoRecorder(object): 9 | def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30): 10 | self.dir_name = dir_name 11 | self.height = height 12 | self.width = width 13 | self.camera_id = camera_id 14 | self.fps = fps 15 | self.frames = [] 16 | 17 | def init(self, enabled=True): 18 | self.frames = [] 19 | self.enabled = self.dir_name is not None and enabled 20 | 21 | def record(self, env): 22 | if self.enabled: 23 | try: 24 | frame = env.render( 25 | mode='rgb_array', 26 | height=self.height, 27 | width=self.width, 28 | camera_id=self.camera_id 29 | ) 30 | except: 31 | frame = env.render( 32 | mode='rgb_array', 33 | ) 34 | 35 | self.frames.append(frame) 36 | 37 | def save(self, file_name): 38 | if self.enabled: 39 | path = os.path.join(self.dir_name, file_name) 40 | imageio.mimsave(path, self.frames, fps=self.fps) 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /eppo/exp/atari_local.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
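# Example EPPO training configuration for Atari (PongNoFrameskip-v4); the expected fields are defined by the config classes in code/tools/config.py.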
3 | _custom_imports_: 4 | - code.policy 5 | - code.network 6 | env: 7 | concurrency: 16 8 | env_name: PongNoFrameskip-v4 #BreakoutNoFrameskip-v4 9 | max_episode_steps: 108000 10 | episode_life: true 11 | clip_rewards: true 12 | frame_stack: true 13 | scale: false 14 | network1: 15 | type: AtariCNN 16 | input_dims: 2 # mountain car's obervation space, will be changed in code 17 | num_layers: 1 18 | output_dim: 64 19 | network2: 20 | type: AtariCNN 21 | input_dims: 2 # mountain car's obervation space, will be changed in code 22 | num_layers: 1 23 | output_dim: 64 24 | policy: 25 | type: MARLPPOPolicy 26 | lr: 0.0001 27 | num_policy: 4 28 | weight_decay: 1.0e-05 29 | discount_factor: 0.99 30 | gae_lambda: 0.95 31 | max_grad_norm: 0.5 32 | value_clip: true 33 | eps_clip: 0.1 34 | diverse_coef: 0.01 35 | sub_policy_coef: 0.5 36 | center_policy_coef: 1.0 37 | trainer: 38 | max_epoch: 1000 39 | repeat_per_collect: 5 40 | earlystop_patience: 40000 41 | episode_per_collect: 10000 42 | batch_size: 256 43 | val_every_n_epoch: 10 44 | fast_dev_run: false 45 | buffer_size: 200000 46 | save_epoch: false 47 | runtime: 48 | use_cuda: true 49 | seed: 42 50 | 51 | 52 | -------------------------------------------------------------------------------- /eppo/code/network/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | import torch 4 | import torch.nn as nn 5 | from .mlp import MLP 6 | from .cnn import MiniCNN, AtariCNN, MiniLargeCNN 7 | from .base import BaseNetwork, NETWORKS 8 | 9 | class Reshape(nn.Module): 10 | def __init__(self, *args): 11 | super(Reshape, self).__init__() 12 | self.shape = args 13 | def forward(self, x): 14 | return x.view((x.size(0),)+self.shape) 15 | 16 | class Attention(nn.Module): 17 | def __init__(self, in_dim, out_dim): 18 | super().__init__() 19 | self.q_net = nn.Linear(in_dim, out_dim) 20 | self.k_net = nn.Linear(in_dim, out_dim) 21 | self.v_net = nn.Linear(in_dim, out_dim) 22 | 23 | def forward(self, Q, K, V): 24 | q = self.q_net(Q) 25 | k = self.k_net(K) 26 | v = self.v_net(V) 27 | 28 | attn = torch.einsum("ijk,ilk->ijl", q, k) 29 | attn = attn.to(Q.device) 30 | attn_prob = torch.softmax(attn, dim=-1) 31 | 32 | attn_vec = torch.einsum("ijk,ikl->ijl", attn_prob, v) 33 | 34 | return attn_vec 35 | 36 | class SelfAttention(Attention): 37 | def __init__(self, in_dim, out_dim): 38 | super().__init__(in_dim,out_dim) 39 | 40 | def forward(self, X): 41 | return super().forward(X,X,X) -------------------------------------------------------------------------------- /agentlightning/examples/spider/README.md: -------------------------------------------------------------------------------- 1 | # Spider Example 2 | 3 | This example requires a single node with one GPU of at least 40GB memory. 4 | 5 | 1. Download Spider 1.0 dataset from [here](https://yale-lily.github.io/spider) and unzip it to the `data` folder. 6 | 2. Use `python spider_eval/convert_dataset.py` to convert the dataset to the parquet format. 7 | 3. Start ray: `bash ../../scripts/restart_ray.sh`. To use Wandb, you need to set the WANDB_API_KEY environment variable before starting ray. 8 | 4. Run the agent: `VERL_API_BASE=http://localhost:9999/ python sql_agent.py`. Use `python sql_agent.py --help` to see options like running multiple agents. 9 | 5. In another terminal, launch the training server: `bash train.sh`. 
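As a reference, a single-node run of step 4 that combines the options discussed in the evaluation below might look like the following; the flag values are illustrative, not a tuned configuration:

```bash
VERL_API_BASE=http://localhost:9999/ python sql_agent.py \
    --trainer.n-workers 10 \
    --trainer.val-temperature 0 \
    --litsqlagent.trained-agents write \
    --litsqlagent.max-turns 3 \
    --litsqlagent.table-info-truncate 2048 \
    --litsqlagent.execution-truncate 2048
```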
10 | 11 | ## Evaluation 12 | 13 | Setting: 14 | 15 | * 1 node with 4 80GB A100 GPUs. 16 | * Model: Qwen/Qwen2.5-Coder-3B-Instruct 17 | * Train write and rewrite agents only (`--litsqlagent.trained-agents write`). The check agent share the same model but the corresponding interaction is not trained. 18 | * 10 parallel agent workers (`--trainer.n-workers 10`). 19 | * Validation temperature = 0 (`--trainer.val-temperature 0`). 20 | * Maximum 3 turns (1 write + 2 rewrites, `--litsqlagent.max-turns 3`). 21 | * Truncate the database schema description and execution result to 512 characters (`--litsqlagent.table-info-truncate 2048 --litsqlagent.execution-truncate 2048`). 22 | * RL algorithm is GRPO with learning rate 1e-6. 23 | * Uses random 500 samples from the test set for validation. 24 | 25 | Under the base setting, the performance on validation set boosted from 62.4% to 76.2% in 400 training steps. 26 | The W&B report is available [here](https://api.wandb.ai/links/ultmaster/agnice3m). 27 | -------------------------------------------------------------------------------- /agentlightning/agentlightning/instrumentation/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | AGENTOPS_INSTALLED = False 4 | AGENTOPS_LANGCHAIN_INSTALLED = False 5 | LITELLM_INSTALLED = False 6 | VLLM_INSTALLED = False 7 | 8 | try: 9 | from . import agentops 10 | 11 | AGENTOPS_INSTALLED = True 12 | except ImportError: 13 | pass 14 | 15 | try: 16 | from . import litellm 17 | 18 | LITELLM_INSTALLED = True 19 | except ImportError: 20 | pass 21 | 22 | try: 23 | from . import vllm 24 | 25 | VLLM_INSTALLED = True 26 | except ImportError: 27 | pass 28 | 29 | 30 | try: 31 | from . import agentops_langchain 32 | 33 | AGENTOPS_LANGCHAIN_INSTALLED = True 34 | except ImportError: 35 | pass 36 | 37 | 38 | def instrument_all(): 39 | if AGENTOPS_INSTALLED: 40 | from .agentops import instrument_agentops 41 | 42 | instrument_agentops() 43 | else: 44 | warnings.warn("agentops is not installed. It's therefore not instrumented.") 45 | 46 | if LITELLM_INSTALLED: 47 | from .litellm import instrument_litellm 48 | 49 | instrument_litellm() 50 | else: 51 | warnings.warn("litellm is not installed. It's therefore not instrumented.") 52 | 53 | if VLLM_INSTALLED: 54 | from .vllm import instrument_vllm 55 | 56 | instrument_vllm() 57 | else: 58 | warnings.warn("vllm is not installed. It's therefore not instrumented.") 59 | 60 | if AGENTOPS_LANGCHAIN_INSTALLED: 61 | from .agentops_langchain import instrument_agentops_langchain 62 | 63 | instrument_agentops_langchain() 64 | else: 65 | warnings.warn("Agentops-langchain integration is not installed. It's therefore not instrumented.") 66 | -------------------------------------------------------------------------------- /bootorl/utils/timer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
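# Wall-clock timer: each call returns the time since the previous call and, optionally, the total elapsed time and a smoothed ETA over total_num steps.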
3 | 4 | import time 5 | 6 | 7 | class Timer: 8 | def __init__(self, total_num=-1, formatting=True, return_total_time=True, return_eta=True): 9 | self.start_time = time.time() 10 | self.last_time = self.start_time 11 | self.total_num = total_num 12 | self.cnt = 0 13 | 14 | self.formatting = formatting 15 | self.return_total_time = return_total_time 16 | self.return_eta = return_eta and (self.total_num > 0) 17 | self.prev_time_est = None 18 | 19 | def _fmt(self, time): 20 | if not self.formatting: 21 | return f"{time:.3f}" 22 | minutes, seconds = time // 60, time % 60 23 | hours, minutes = minutes // 60, minutes % 60 24 | string = f"{minutes:0>2.0f}:{seconds:0>6.3f}" 25 | if hours > 0: 26 | string = f"{hours:0>2.0f}:{string}" 27 | return string 28 | 29 | def __call__(self): 30 | self.cnt += 1 31 | 32 | curr_time = time.time() 33 | diff_time = curr_time - self.last_time 34 | returns = [self._fmt(diff_time)] 35 | if self.return_total_time: 36 | returns.append(self._fmt(curr_time - self.start_time)) 37 | if self.return_eta: 38 | self.prev_time_est = diff_time * 0.5 + self.prev_time_est * 0.5 if self.prev_time_est is not None else diff_time 39 | if self.cnt > self.total_num or self.cnt == 0: 40 | eta = "??" 41 | else: 42 | eta = self.prev_time_est * (self.total_num - self.cnt) 43 | eta = self._fmt(eta) 44 | returns.append(eta) 45 | 46 | self.last_time = curr_time 47 | return tuple(returns) 48 | -------------------------------------------------------------------------------- /bootorl/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | FROM pytorch/pytorch:1.10.0-cuda11.3-cudnn8-devel 5 | 6 | WORKDIR /workspace 7 | 8 | # Install new cuda-keyring package 9 | # Noted at https://forums.developer.nvidia.com/t/notice-cuda-linux-repository-key-rotation/212772 10 | RUN rm /etc/apt/sources.list.d/cuda.list /etc/apt/sources.list.d/nvidia-ml.list \ 11 | && apt-key del 7fa2af80 \ 12 | && apt-get update && apt-get install -y --no-install-recommends wget \ 13 | && wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb \ 14 | && dpkg -i cuda-keyring_1.0-1_all.deb 15 | 16 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive \ 17 | && apt-get install -y zlib1g zlib1g-dev libosmesa6-dev libgl1-mesa-glx libglfw3 libglew2.0 cmake git \ 18 | && ln -s /usr/lib/x86_64-linux-gnu/libGL.so.1 /usr/lib/x86_64-linux-gnu/libGL.so 19 | 20 | # Install MuJoCo 2.1.0. 21 | ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/root/.mujoco/mujoco210/bin 22 | RUN mkdir -p /root/.mujoco \ 23 | && wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz -O mujoco210.tar.gz \ 24 | && tar -xvzf mujoco210.tar.gz -C /root/.mujoco \ 25 | && rm mujoco210.tar.gz 26 | 27 | # Install packages, mainly d4rl, which will also install corresponding dependencies automatically. 
28 | RUN pip install -U scikit-learn pandas \ 29 | && pip install git+https://github.com/rail-berkeley/d4rl.git@d842aa194b416e564e54b0730d9f934e3e32f854 \ 30 | && pip install git+https://github.com/openai/gym.git@66c431d4b3072a1db44d564dab812b9d23c06e14 31 | 32 | # Pre-download dataset if necessary 33 | # RUN python -c "import gym; import d4rl; [gym.make(f'{game}-{level}-v2').unwrapped.get_dataset() for level in \ 34 | # ['medium', 'medium-replay', 'medium-expert', 'expert'] for game in ['halfcheetah', 'hopper', 'walker2d']];" 35 | -------------------------------------------------------------------------------- /agentlightning/agentlightning/instrumentation/vllm.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import List 4 | 5 | from vllm.entrypoints.openai.protocol import ChatCompletionResponse 6 | import vllm.entrypoints.openai.protocol 7 | from vllm.entrypoints.openai.serving_chat import OpenAIServingChat 8 | 9 | 10 | class ChatCompletionResponsePatched(ChatCompletionResponse): 11 | prompt_token_ids: List[int] | None = None 12 | response_token_ids: List[int] | None = None 13 | 14 | 15 | original_chat_completion_full_generator = ( 16 | OpenAIServingChat.chat_completion_full_generator 17 | ) 18 | 19 | 20 | async def chat_completion_full_generator( 21 | self, 22 | request, 23 | result_generator, 24 | request_id: str, 25 | model_name: str, 26 | conversation, 27 | tokenizer, 28 | request_metadata, 29 | ): 30 | prompt_token_ids: List[int] | None = None 31 | response_token_ids: List[List[int]] | None = None 32 | 33 | async def _generate_inceptor(): 34 | nonlocal prompt_token_ids, response_token_ids 35 | async for res in result_generator: 36 | yield res 37 | prompt_token_ids = res.prompt_token_ids 38 | response_token_ids = [output.token_ids for output in res.outputs] 39 | 40 | response = await original_chat_completion_full_generator( 41 | self, 42 | request, 43 | _generate_inceptor(), 44 | request_id, 45 | model_name, 46 | conversation, 47 | tokenizer, 48 | request_metadata, 49 | ) 50 | response = response.model_copy( 51 | update={ 52 | "prompt_token_ids": prompt_token_ids, 53 | "response_token_ids": response_token_ids, 54 | } 55 | ) 56 | 57 | return response 58 | 59 | 60 | def instrument_vllm(): 61 | vllm.entrypoints.openai.protocol.ChatCompletionResponse = ( 62 | ChatCompletionResponsePatched 63 | ) 64 | OpenAIServingChat.chat_completion_full_generator = chat_completion_full_generator 65 | -------------------------------------------------------------------------------- /eppo/code/env/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
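# Environment configuration classes (dummy/shmem/subproc vectorization modes) and gym wrappers shared by the EPPO training code.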
3 | 4 | from enum import Enum 5 | from typing import Optional, Iterable, Callable 6 | import gym 7 | 8 | from tianshou.env import BaseVectorEnv 9 | from utilsd.config import PythonConfig, configclass 10 | from .logging import Logger 11 | from .finite_env import FiniteDummyVectorEnv, FiniteShmemVectorEnv, FiniteSubprocVectorEnv 12 | 13 | 14 | class ParallelMode(str, Enum): 15 | dummy = "dummy" 16 | shmem = "shmem" 17 | subproc = "subproc" 18 | 19 | 20 | @configclass 21 | class EnvConfig(PythonConfig): 22 | 23 | concurrency: int 24 | parallel_mode: ParallelMode = ParallelMode.shmem 25 | 26 | def post_validate(self): 27 | assert self.concurrency >= 1 28 | return True 29 | 30 | @configclass 31 | class DivEnvConfig(PythonConfig): 32 | num_skill: int 33 | time_per_step: int 34 | vol_limit: Optional[float] # the limitation of current decision compared to volume 35 | 36 | concurrency: int 37 | parallel_mode: ParallelMode = ParallelMode.shmem 38 | 39 | def post_validate(self): 40 | assert self.vol_limit is None or self.vol_limit < 1 41 | assert self.concurrency >= 1 42 | return True 43 | 44 | @configclass 45 | class AtariEnvConfig(PythonConfig): 46 | concurrency: int 47 | env_name: str 48 | episode_life: bool=True 49 | clip_rewards: bool=True 50 | frame_stack: bool=False 51 | scale: bool=False 52 | dnc: bool=False 53 | max_episode_steps: Optional[int]=None 54 | parallel_mode: ParallelMode = ParallelMode.shmem 55 | 56 | 57 | 58 | class MujocoWrapper(gym.Wrapper): 59 | def __init__(self,env, max_step=1000): 60 | super().__init__(env) 61 | self.max_step = max_step 62 | self.cur_step = 0 63 | 64 | def reset(self): 65 | self.cur_step = 0 66 | return self.env.reset() 67 | 68 | def step(self,act): 69 | self.cur_step += 1 70 | obs,rwd,done,info = self.env.step(act) 71 | if self.cur_step>=self.max_step: 72 | done = True 73 | return obs,rwd,done,info -------------------------------------------------------------------------------- /eppo/code/tools/env_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
3 | import copy 4 | import dataclasses 5 | from pathlib import Path 6 | from typing import Any, Callable, Dict, Optional, Tuple, List 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch, gym 11 | from tianshou.data import Collector, VectorReplayBuffer 12 | from tianshou.env import BaseVectorEnv 13 | from tianshou.policy import BasePolicy 14 | from torch.utils.data import Dataset 15 | from utilsd import get_output_dir, get_checkpoint_dir, setup_experiment, use_cuda 16 | from utilsd.experiment import print_config 17 | from utilsd.earlystop import EarlyStop, EarlyStopStatus 18 | from utilsd.logging import print_log 19 | 20 | from code.env import EnvConfig, ParallelMode, Logger, FiniteDummyVectorEnv,FiniteSubprocVectorEnv,FiniteShmemVectorEnv, AtariEnvConfig, MujocoWrapper 21 | 22 | import sys,os 23 | cur_dir = os.getcwd() 24 | sys.path.insert(0,cur_dir) 25 | print(sys.path) 26 | from code.env.atari_env import wrap_deepmind 27 | 28 | 29 | def atari_game_env_factory(env_config: AtariEnvConfig, env_name:str, logger: Logger,seed=42, dnc=False): 30 | def single_env(env_config=env_config): 31 | env = wrap_deepmind(env_config.env_name, env_config.episode_life,env_config.clip_rewards, env_config.scale, env_config.frame_stack, dnc=dnc, env_name=env_config.env_name) 32 | return env 33 | def test_env(env_config=env_config): 34 | env = wrap_deepmind(env_config.env_name, episode_life=False, clip_rewards=False, dnc=False, env_name=env_config.env_name,) 35 | return env 36 | 37 | if env_config.parallel_mode == ParallelMode.dummy: 38 | venv_cls = FiniteDummyVectorEnv 39 | elif env_config.parallel_mode == ParallelMode.shmem: 40 | venv_cls = FiniteShmemVectorEnv 41 | elif env_config.parallel_mode == ParallelMode.subproc: 42 | venv_cls = FiniteSubprocVectorEnv 43 | 44 | envs1 = venv_cls(logger, [single_env for _ in range(env_config.concurrency)]) 45 | envs2 = venv_cls(logger, [test_env for _ in range(env_config.concurrency)]) 46 | envs1.seed(seed) 47 | envs2.seed(seed) 48 | return envs1,envs2 49 | -------------------------------------------------------------------------------- /agentlightning/agentlightning/reward.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import inspect 3 | import warnings 4 | from typing import TypedDict, Optional 5 | 6 | from agentops.sdk.decorators import operation 7 | 8 | 9 | class RewardSpanData(TypedDict): 10 | type: "reward" 11 | value: Optional[float] 12 | 13 | 14 | def reward(fn: callable) -> callable: 15 | """ 16 | A decorator to wrap a function that computes rewards. 17 | It will automatically handle the input and output of the function. 18 | """ 19 | 20 | def wrap_result(result: Optional[float]) -> RewardSpanData: 21 | """ 22 | Wrap the result of the function in a dict. 
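Non-numeric results trigger a warning and are recorded with a None value.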
23 | """ 24 | if result is None: 25 | return {"type": "reward", "value": None} 26 | if not isinstance(result, (float, int)): 27 | warnings.warn(f"Reward is ignored because it is not a number: {result}") 28 | return {"type": "reward", "value": None} 29 | return {"type": "reward", "value": float(result)} 30 | 31 | # Check if the function is async 32 | is_async = asyncio.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn) 33 | 34 | if is_async: 35 | 36 | async def wrapper_async(*args, **kwargs): 37 | result: Optional[float] = None 38 | 39 | @operation 40 | async def agentops_reward_operation() -> RewardSpanData: 41 | # The reward function we are interested in tracing 42 | # It takes zero inputs and return a formatted dict 43 | nonlocal result 44 | result = await fn(*args, **kwargs) 45 | return wrap_result(result) 46 | 47 | await agentops_reward_operation() 48 | return result 49 | 50 | return wrapper_async 51 | 52 | else: 53 | 54 | def wrapper(*args, **kwargs): 55 | result: Optional[float] = None 56 | 57 | @operation 58 | def agentops_reward_operation() -> RewardSpanData: 59 | nonlocal result 60 | result = fn(*args, **kwargs) 61 | return wrap_result(result) 62 | 63 | agentops_reward_operation() 64 | return result 65 | 66 | return wrapper 67 | -------------------------------------------------------------------------------- /agentlightning/examples/calc_x/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | export N_GPUS=1 6 | export BASE_MODEL=Qwen/Qwen2.5-1.5B-Instruct 7 | export DATA_DIR=data 8 | export ROLLOUT_TP_SIZE=1 9 | export EXPERIMENT_NAME=calc_x 10 | export PROJECT_NAME=AgentLightning 11 | 12 | echo "Starting training script..." 13 | 14 | python -m verl.trainer.main_ppo \ 15 | agent_mode.enable=True \ 16 | actor_rollout_ref.rollout.mode=async \ 17 | actor_rollout_ref.rollout.chat_scheduler=agentlightning.instrumentation.verl_chat_scheduler.NaiveChatCompletionScheduler \ 18 | algorithm.adv_estimator=grpo \ 19 | actor_rollout_ref.model.path=${BASE_MODEL} \ 20 | data.train_files=${DATA_DIR}/train.parquet \ 21 | data.val_files=${DATA_DIR}/test.parquet \ 22 | actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \ 23 | trainer.n_gpus_per_node=${N_GPUS} \ 24 | data.train_batch_size=32 \ 25 | actor_rollout_ref.rollout.n=4 \ 26 | actor_rollout_ref.actor.ppo_mini_batch_size=32 \ 27 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ 28 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ 29 | data.max_prompt_length=4096 \ 30 | data.max_response_length=2048 \ 31 | data.filter_overlong_prompts=True \ 32 | data.truncation='error' \ 33 | trainer.val_before_train=True \ 34 | actor_rollout_ref.actor.optim.lr=1e-6 \ 35 | actor_rollout_ref.model.use_remove_padding=True \ 36 | actor_rollout_ref.actor.use_kl_loss=False \ 37 | actor_rollout_ref.actor.kl_loss_coef=0.000 \ 38 | actor_rollout_ref.actor.entropy_coeff=0 \ 39 | actor_rollout_ref.actor.clip_ratio_low=0.2 \ 40 | actor_rollout_ref.actor.clip_ratio_high=0.3 \ 41 | actor_rollout_ref.model.enable_gradient_checkpointing=True \ 42 | actor_rollout_ref.actor.fsdp_config.param_offload=True \ 43 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ 44 | actor_rollout_ref.rollout.name=vllm \ 45 | actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ 46 | actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ 47 | actor_rollout_ref.ref.fsdp_config.param_offload=True \ 48 | algorithm.use_kl_in_reward=False \ 49 | 
trainer.critic_warmup=0 \ 50 | trainer.logger=['console','wandb'] \ 51 | trainer.project_name=${PROJECT_NAME} \ 52 | trainer.experiment_name=${EXPERIMENT_NAME} \ 53 | trainer.nnodes=1 \ 54 | trainer.save_freq=256 \ 55 | trainer.test_freq=32 \ 56 | trainer.total_epochs=2 $@ 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoRL Research 2 | 3 | This repository contains code for a series of research projects on Automated Reinforcement Learning (AutoRL). 4 | 5 | ## News 6 | 7 | * 2025.7.22 [Agent Lightning](agentlightning) has graduated as a standalone repository. Check it out [here](https://github.com/microsoft/agent-lightning). 8 | * 2025.6.6 [Agent Lightning](agentlightning) is now available as a research preview. Blog post is [here](https://www.microsoft.com/en-us/research/project/agent-lightning/). 9 | * 2023.3.10 [Bootstrapped Transformer for Offline Reinforcement Learning](https://seqml.github.io/bootorl/) is now available in [bootorl](bootorl). 10 | * 2022.10.12 [Reinforcement Learning with Automated Auxiliary Loss Search](https://seqml.github.io/a2ls/) is now available in [a2ls](a2ls). 11 | * 2022.9.21 [Towards Applicable Reinforcement Learning: Improving the Generalization and Sample Efficiency with Policy Ensemble](https://seqml.github.io/eppo/) is now available in [eppo](eppo). 12 | 13 | ## Contributing 14 | 15 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 16 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 17 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 18 | 19 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 20 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 21 | provided by the bot. You will only need to do this once across all repos using our CLA. 22 | 23 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 24 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 25 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 26 | 27 | ## Trademarks 28 | 29 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 30 | trademarks or logos is subject to and must follow 31 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 32 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 33 | Any use of third-party trademarks or logos are subject to those third-party's policies. 34 | -------------------------------------------------------------------------------- /agentlightning/examples/spider/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | export N_GPUS=1 6 | export BASE_MODEL=Qwen/Qwen2.5-Coder-1.5B-Instruct 7 | export DATA_DIR=data 8 | export ROLLOUT_TP_SIZE=1 9 | export EXPERIMENT_NAME=spider 10 | export PROJECT_NAME=AgentLightning 11 | 12 | echo "Starting training script..." 
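# Launches verl PPO training in agent mode; any extra overrides passed to this script are appended to the command below via "$@".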
13 | 14 | python -m verl.trainer.main_ppo \ 15 | agent_mode.enable=True \ 16 | actor_rollout_ref.rollout.mode=async \ 17 | actor_rollout_ref.rollout.chat_scheduler=agentlightning.instrumentation.verl_chat_scheduler.NaiveChatCompletionScheduler \ 18 | algorithm.adv_estimator=grpo \ 19 | actor_rollout_ref.model.path=${BASE_MODEL} \ 20 | data.train_files=${DATA_DIR}/train_spider.parquet \ 21 | data.val_files=${DATA_DIR}/test_dev_500.parquet \ 22 | actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP_SIZE \ 23 | trainer.n_gpus_per_node=${N_GPUS} \ 24 | data.train_batch_size=32 \ 25 | actor_rollout_ref.rollout.n=4 \ 26 | actor_rollout_ref.actor.ppo_mini_batch_size=32 \ 27 | actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ 28 | actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ 29 | data.max_prompt_length=4096 \ 30 | data.max_response_length=2048 \ 31 | data.filter_overlong_prompts=True \ 32 | data.truncation='error' \ 33 | trainer.val_before_train=True \ 34 | actor_rollout_ref.actor.optim.lr=1e-6 \ 35 | actor_rollout_ref.model.use_remove_padding=True \ 36 | actor_rollout_ref.actor.use_kl_loss=False \ 37 | actor_rollout_ref.actor.kl_loss_coef=0.000 \ 38 | actor_rollout_ref.actor.entropy_coeff=0 \ 39 | actor_rollout_ref.actor.clip_ratio_low=0.2 \ 40 | actor_rollout_ref.actor.clip_ratio_high=0.3 \ 41 | actor_rollout_ref.model.enable_gradient_checkpointing=True \ 42 | actor_rollout_ref.actor.fsdp_config.param_offload=True \ 43 | actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ 44 | actor_rollout_ref.rollout.name=vllm \ 45 | actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \ 46 | actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ 47 | actor_rollout_ref.ref.fsdp_config.param_offload=True \ 48 | algorithm.use_kl_in_reward=False \ 49 | trainer.critic_warmup=0 \ 50 | trainer.logger=['console','wandb'] \ 51 | trainer.project_name=${PROJECT_NAME} \ 52 | trainer.experiment_name=${EXPERIMENT_NAME} \ 53 | trainer.nnodes=1 \ 54 | trainer.save_freq=256 \ 55 | trainer.test_freq=32 \ 56 | trainer.total_epochs=2 $@ 57 | -------------------------------------------------------------------------------- /bootorl/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /eppo/code/tools/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
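# Typed experiment configuration schema (environment, policy/network registries, trainer, runtime) used to parse the YAML files under exp/.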
3 | from typing import Optional, List 4 | 5 | from utilsd.config import PythonConfig, RegistryConfig, RuntimeConfig, configclass 6 | 7 | from code.env import EnvConfig 8 | from code.network import NETWORKS 9 | from code.policy import POLICIES 10 | from code.env import AtariEnvConfig 11 | 12 | 13 | @configclass 14 | class TrainerConfig(PythonConfig): 15 | max_epoch: int 16 | episode_per_collect: int 17 | batch_size: int 18 | repeat_per_collect: int 19 | earlystop_patience: int 20 | val_every_n_epoch: int 21 | fast_dev_run: bool = False 22 | buffer_size: int=200000 23 | save_epoch: bool=False 24 | 25 | @configclass 26 | class OffTrainerConfig(PythonConfig): 27 | max_epoch: int 28 | episode_per_collect: int 29 | steps_per_epoch: int 30 | update_per_step: float 31 | batch_size: int 32 | earlystop_patience: int 33 | val_every_n_epoch: int 34 | fast_dev_run: bool = False 35 | buffer_size: int=200000 36 | save_epoch: bool=False 37 | 38 | 39 | @configclass 40 | class BacktestConfig(PythonConfig): 41 | env_type: str="atari" 42 | eps_num: int = 20 43 | 44 | @configclass 45 | class GameRunConfig(PythonConfig): 46 | env: AtariEnvConfig 47 | policy: RegistryConfig[POLICIES] 48 | network1: Optional[RegistryConfig[NETWORKS]] = None 49 | network2: Optional[RegistryConfig[NETWORKS]] = None 50 | trainer: Optional[TrainerConfig] = None 51 | runtime: RuntimeConfig = RuntimeConfig() 52 | backtest: bool = False 53 | use_step: bool = False 54 | bk: Optional[BacktestConfig]=None 55 | 56 | @configclass 57 | class GameOffRunConfig(PythonConfig): 58 | env: AtariEnvConfig 59 | policy: RegistryConfig[POLICIES] 60 | network: Optional[List[RegistryConfig[NETWORKS]]] = None 61 | trainer: Optional[OffTrainerConfig] = None 62 | runtime: RuntimeConfig = RuntimeConfig() 63 | 64 | @configclass 65 | class AtariRunConfig(PythonConfig): 66 | env: EnvConfig 67 | policy: RegistryConfig[POLICIES] 68 | network1: Optional[RegistryConfig[NETWORKS]] = None 69 | network2: Optional[RegistryConfig[NETWORKS]] = None 70 | trainer: Optional[TrainerConfig] = None 71 | runtime: RuntimeConfig = RuntimeConfig() 72 | 73 | 74 | @configclass 75 | class GameUtilsConfig(PythonConfig): 76 | env: AtariEnvConfig 77 | network1: Optional[RegistryConfig[NETWORKS]] = None 78 | runtime: RuntimeConfig = RuntimeConfig() 79 | target: int 80 | pattern: int -------------------------------------------------------------------------------- /eppo/README.md: -------------------------------------------------------------------------------- 1 | # Towards Applicable Reinforcement Learning: Improving the Generalization and Sample Efficiency with Policy Ensemble 2 | This is the experiment code for our IJCAI 2022 paper "[Towards Applicable Reinforcement Learning: Improving the Generalization and Sample Efficiency with Policy Ensemble](https://seqml.github.io/eppo/)". 3 | 4 | ## Abstract 5 | > It is challenging for reinforcement learning (RL) algorithms to succeed in real-world applications like financial trading and logistic system due to the noisy observation and environment shifting between training and evaluation. Thus, it requires both high sample efficiency and generalization for resolving real-world tasks. However, directly applying typical RL algorithms can lead to poor performance in such scenarios. Considering the great performance of ensemble methods on both accuracy and generalization in supervised learning (SL), we design a robust and applicable method named Ensemble Proximal Policy Optimization (EPPO), which learns ensemble policies in an end-to-end manner. 
Notably, EPPO combines each policy and the policy ensemble organically and optimizes both simultaneously. In addition, EPPO adopts a diversity enhancement regularization over the policy space which helps to generalize to unseen states and promotes exploration. We theoretically prove EPPO increases exploration efficacy, and through comprehensive experimental evaluations on various tasks, we demonstrate that EPPO achieves higher efficiency and is robust for real-world applications compared with vanilla policy optimization algorithms and other ensemble methods. Code and supplemental materials are available at https://seqml.github.io/eppo. 6 | 7 | ## Environment Dependencies 8 | ### Dependencies 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ### Running 14 | Take `Pong` environment in Atari benchmarks as an example, to run EPPO, you can do the following. 15 | ``` 16 | python code/tools/train_on_atari.py exp/atari_local.yml 17 | ``` 18 | 19 | To run EPPO-Ens, please set the `center_policy_coef` in `exp/atari_local.yml` to 0. 20 | 21 | To run EPPO-Div, please set the `diverse_coef` in `exp/atari_local.yml` to 0. 22 | 23 | ## Reference 24 | You are more than welcome to cite our paper: 25 | ``` 26 | @article{yang2022towards, 27 | title={Towards Applicable Reinforcement Learning: Improving the Generalization and Sample Efficiency with Policy Ensemble}, 28 | author={Yang, Zhengyu and Ren, Kan and Luo, Xufang and Liu, Minghuan and Liu, Weiqing and Bian, Jiang and Zhang, Weinan and Li, Dongsheng}, 29 | journal={arXiv preprint arXiv:2205.09284}, 30 | year={2022} 31 | } 32 | ``` 33 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 
18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '15 17 * * 2' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 
57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | with: 74 | category: "/language:${{matrix.language}}" 75 | -------------------------------------------------------------------------------- /agentlightning/scripts/verl_git_diff.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # --- Configuration --- 4 | # The commit hash to compare against. 5 | COMMIT_HASH="2dc3e0ebadb479bb3f2b48cfc7f28a3b70d5ce60" 6 | # The absolute or relative path to the git repository. 7 | VERL_REPO_PATH=${VERL_REPO_PATH:-"~/verl"} 8 | # The subfolder within the repository to limit the diff to. 9 | SUBFOLDER="verl" 10 | # The output directory for files from the old commit. 11 | OLD_VERSION_DIR="agentlightning/instrumentation/verl_old" 12 | # The output directory for files from the current version (HEAD). 13 | CURRENT_VERSION_DIR="agentlightning/instrumentation/verl_patch" 14 | # --------------------- 15 | 16 | # --- Script Start --- 17 | 18 | # Store the original directory where the script was run 19 | ORIGINAL_DIR=$(pwd) 20 | 21 | # Expand the tilde (~) in the repo path to the user's home directory 22 | VERL_REPO_PATH=$(eval echo "$VERL_REPO_PATH") 23 | 24 | # --- Validations --- 25 | if [ ! -d "$VERL_REPO_PATH" ]; then 26 | echo "Error: Repository path '$VERL_REPO_PATH' does not exist." 27 | exit 1 28 | fi 29 | 30 | if [ ! -d "$VERL_REPO_PATH/.git" ]; then 31 | echo "Error: Directory '$VERL_REPO_PATH' is not a git repository." 32 | exit 1 33 | fi 34 | 35 | # Change to the repository directory 36 | cd "$VERL_REPO_PATH" || exit 37 | 38 | # --- Main Logic --- 39 | echo "Creating output directories in '$ORIGINAL_DIR'..." 40 | # Use absolute paths for output directories to ensure they are created relative to the original location 41 | mkdir -p "$ORIGINAL_DIR/$OLD_VERSION_DIR" 42 | mkdir -p "$ORIGINAL_DIR/$CURRENT_VERSION_DIR" 43 | 44 | echo "Finding differences between $COMMIT_HASH and HEAD in '$SUBFOLDER'..." 45 | 46 | # Get the list of changed files relative to the repo root 47 | changed_files=$(git diff --name-only "$COMMIT_HASH" HEAD -- "$SUBFOLDER") 48 | 49 | if [ -z "$changed_files" ]; then 50 | echo "No differences found in the specified subfolder." 51 | cd "$ORIGINAL_DIR" # Go back to the original directory before exiting 52 | exit 0 53 | fi 54 | 55 | echo "Exporting files..." 
56 | 57 | while read -r file; do 58 | echo " -> $file" 59 | 60 | # Define absolute paths for destination files 61 | dest_old="$ORIGINAL_DIR/$OLD_VERSION_DIR/$file" 62 | dest_current="$ORIGINAL_DIR/$CURRENT_VERSION_DIR/$file" 63 | 64 | # Create the directory structure within the output directories 65 | mkdir -p "$(dirname "$dest_old")" 66 | mkdir -p "$(dirname "$dest_current")" 67 | 68 | # Get the file from the specific commit and redirect output to the absolute path 69 | git show "$COMMIT_HASH:$file" > "$dest_old" 2>/dev/null || echo " - Note: File did not exist in the old commit." 70 | 71 | # Copy the file from the current working directory (which is VERL_REPO_PATH) to the absolute path 72 | if [ -f "$file" ]; then 73 | cp "$file" "$dest_current" 74 | else 75 | echo " - Note: File is deleted in the current version." 76 | fi 77 | 78 | done <<< "$changed_files" 79 | 80 | # Change back to the original directory 81 | cd "$ORIGINAL_DIR" 82 | 83 | echo 84 | echo "Done. You can now compare the contents of '$OLD_VERSION_DIR' and '$CURRENT_VERSION_DIR'." 85 | -------------------------------------------------------------------------------- /bootorl/analysis/visualize_state.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import sys 6 | import numpy as np 7 | 8 | sys.path.append(os.getcwd()) 9 | 10 | from PIL import Image 11 | from utils import DiscretizedDataset 12 | from utils import Renderer 13 | 14 | 15 | def load_states(target="boot_genr_repeat_ar", env="halfcheetah-medium"): 16 | transitions = np.load(f"./logs/{env}/distribution/{target}/transitions.npz") 17 | trans_origin = transitions["trans_origin"] 18 | trans_generated_reconstruct = transitions[f"trans_recon_{target.split('_')[3]}_{target.split('_')[1]}"] 19 | dataset = DiscretizedDataset(logger=None, env=f"{env}-v2", n_bins=100, sequence_length=10, penalty=-100, discount=0.99) 20 | discretizer = dataset.discretizer 21 | mean, std = dataset.raw_data_mean, dataset.raw_data_std 22 | 23 | noise = np.random.normal(scale=3e-4, size=trans_origin.shape) 24 | noise = noise * std + mean 25 | trans_noisy_state = trans_origin.copy() 26 | trans_noisy_state[:, :dataset.observation_dim] += noise[:, :dataset.observation_dim] 27 | trans_discretized_noisy_state = discretizer.discretize(trans_noisy_state) 28 | trans_recon_noisy_state = discretizer.reconstruct(trans_discretized_noisy_state) 29 | 30 | state_dim = dataset.observation_dim 31 | state_origin = trans_origin[:, :state_dim] 32 | state_generated_reconstruct = trans_generated_reconstruct[:, :state_dim] 33 | state_noisy_reconstruct = trans_recon_noisy_state[:, :state_dim] 34 | return state_origin, state_generated_reconstruct, state_noisy_reconstruct 35 | 36 | 37 | def render(target="boot_genr_repeat_ar", game="halfcheetah", level="medium", num=5): 38 | save_dir = f"./analysis/images/{target}/" 39 | if not os.path.exists(save_dir): 40 | os.makedirs(save_dir) 41 | renderer = Renderer(env=f"{game}-{level}-v2") 42 | state_origin, state_generated_reconstruct, state_noisy_reconstruct = load_states(target=target, env=f"{game}-{level}") 43 | 44 | for i in range(num): 45 | img_file_name = os.path.join(save_dir, f"{game}-{level}_original_{i}.jpg") 46 | img = renderer.render(state_origin[i], dim=2048) 47 | img = Image.fromarray(img) 48 | img.save(img_file_name) 49 | for i in range(num): 50 | img_file_name = os.path.join(save_dir, 
f"{game}-{level}_generated_{i}.jpg") 51 | img = renderer.render(state_generated_reconstruct[i], dim=2048) 52 | img = Image.fromarray(img) 53 | img.save(img_file_name) 54 | for i in range(num): 55 | img_file_name = os.path.join(save_dir, f"{game}-{level}_noisy_{i}.jpg") 56 | img = renderer.render(state_noisy_reconstruct[i], dim=2048) 57 | img = Image.fromarray(img) 58 | img.save(img_file_name) 59 | 60 | 61 | def render_all(): 62 | dist = {} 63 | for target in ["boot_genr_once_ar", "boot_genr_once_tf"]: 64 | for game in ["halfcheetah", "hopper", "walker2d"]: 65 | for level in ["medium", "medium-replay", "medium-expert"]: 66 | render(target=target, game=game, level=level) 67 | 68 | 69 | render_all() 70 | -------------------------------------------------------------------------------- /agentlightning/scripts/verl_git_apply.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # --- Configuration --- 4 | # The script will only run if the repo's HEAD is at this exact commit. 5 | REQUIRED_COMMIT_ID="2dc3e0ebadb479bb3f2b48cfc7f28a3b70d5ce60" 6 | 7 | # The directory containing the 'new' or 'patched' files to be applied. 8 | # This is the source of truth for the final state. 9 | PATCH_SOURCE_DIR="agentlightning/instrumentation/verl_patch" 10 | # --------------------- 11 | 12 | # --- Functions --- 13 | usage() { 14 | echo "Usage: $0 " 15 | echo 16 | echo "Applies file additions and modifications from '$PATCH_SOURCE_DIR'" 17 | echo "to the specified repository path non-interactively." 18 | echo "The repository MUST be at commit $REQUIRED_COMMIT_ID." 19 | echo 20 | echo "Arguments:" 21 | echo " The path to the git repository to apply the patch to." 22 | exit 1 23 | } 24 | 25 | # --- Script Start --- 26 | 27 | # The repository path is the first argument 28 | VERL_REPO_PATH="$1" 29 | 30 | # Check if repository path was provided 31 | if [ -z "$VERL_REPO_PATH" ]; then 32 | echo "Error: Repository path is required." 33 | usage 34 | fi 35 | 36 | # --- Main Logic --- 37 | echo "***************************************************" 38 | echo "* Non-Interactive Repository Patch Apply Script *" 39 | echo "* (Add/Modify Operations Only) *" 40 | echo "***************************************************" 41 | echo 42 | echo "Target Repository: '$VERL_REPO_PATH'" 43 | echo "Patch Source: '$PATCH_SOURCE_DIR'" 44 | echo "Required Commit: '$REQUIRED_COMMIT_ID'" 45 | echo 46 | 47 | # --- Validations --- 48 | 49 | # Expand the tilde (~) in the repo path to the user's home directory 50 | VERL_REPO_PATH=$(eval echo "$VERL_REPO_PATH") 51 | 52 | if [ ! -d "$VERL_REPO_PATH" ] || [ ! -d "$VERL_REPO_PATH/.git" ]; then 53 | echo "Error: Target repository path '$VERL_REPO_PATH' does not exist or is not a git repository." 54 | exit 1 55 | fi 56 | 57 | if [ ! -d "$PATCH_SOURCE_DIR" ]; then 58 | echo "Error: Patch source directory '$PATCH_SOURCE_DIR' does not exist." 59 | exit 1 60 | fi 61 | 62 | # --- PRE-FLIGHT CHECK: Verify Commit ID --- 63 | echo "Verifying repository commit ID..." 64 | current_commit=$(git -C "$VERL_REPO_PATH" rev-parse HEAD) 65 | 66 | if [ "$current_commit" != "$REQUIRED_COMMIT_ID" ]; then 67 | echo "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" 68 | echo "!! Critical Error: Incorrect repository state." 69 | echo "!! The script can only be applied when the repository's" 70 | echo "!! HEAD is at the required commit." 71 | echo "!! Expected: $REQUIRED_COMMIT_ID" 72 | echo "!! 
Found: $current_commit" 73 | echo "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" 74 | exit 1 75 | else 76 | echo "Success. Repository is at the correct commit." 77 | fi 78 | 79 | # --- Apply Additions and Modifications --- 80 | echo 81 | echo "Applying file additions and modifications..." 82 | find "$PATCH_SOURCE_DIR" -type f -print0 | while IFS= read -r -d $'\0' source_file; do 83 | relative_path="${source_file#$PATCH_SOURCE_DIR/}" 84 | destination_file="$VERL_REPO_PATH/$relative_path" 85 | 86 | # Ensure the destination directory exists before copying 87 | mkdir -p "$(dirname "$destination_file")" 88 | 89 | # Copy the file, overwriting the destination 90 | cp -v "$source_file" "$destination_file" 91 | done 92 | 93 | echo 94 | echo "Patch application complete." 95 | echo "Review the changes in '$VERL_REPO_PATH' with 'git status' or 'git diff'." 96 | -------------------------------------------------------------------------------- /bootorl/README.md: -------------------------------------------------------------------------------- 1 | # Bootstrapped Transformer 2 | Source code for NeurIPS 2022 paper *[Bootstrapped Transformer for Offline Reinforcement Learning](https://seqml.github.io/bootorl/)*. 3 | 4 | ## Abstract 5 | > Offline reinforcement learning (RL) aims at learning policies from previously collected static trajectory data without interacting with the real environment. Recent works provide a novel perspective by viewing offline RL as a generic sequence generation problem, adopting sequence models such as Transformer architecture to model distributions over trajectories, and repurposing beam search as a planning algorithm. However, the training datasets utilized in general offline RL tasks are quite limited and often suffer from insufficient distribution coverage, which could be harmful to training sequence generation. In this paper, we propose a novel algorithm named Bootstrapped Transformer, which incorporates the idea of bootstrapping and leverages the learned model to self-generate more offline data to further boost the sequence model training. We conduct extensive experiments on two offline RL benchmarks and demonstrate that our model can largely remedy the existing offline RL training limitations and beat other strong baseline methods. We also analyze the generated pseudo data and the revealed characteristics may shed some light on offline RL training. The codes and supplementary materials are available at https://seqml.github.io/bootorl. 6 | 7 | ## Dependencies 8 | 9 | Python dependencies are listed in [`./environment.yml`](./environment.yml). 10 | 11 | We also provides an extra dockerfile as [`./Dockerfile`](./Dockerfile) for reproducibility. 12 | 13 | ## Usage 14 | 15 | To train the model, run with 16 | ``` 17 | python main/train.py --dataset hopper-medium-replay-v2 \ 18 | --bootstrap True \ 19 | --bootstrap_type once \ 20 | --generation_type autoregressive 21 | ``` 22 | or 23 | ``` 24 | python main/train.py --dataset hopper-medium-replay-v2 \ 25 | --bootstrap True \ 26 | --bootstrap_type repeat \ 27 | --generation_type teacherforcing 28 | ``` 29 | depending on your choice of hyperparameters and bootstrap schemes. All default hyperparameters used in our experiments are placed at [`./utils/argparser.py`](`./utils/argparser.py`). You can find it in `DEFAULT_ARGS` at the beginning of this file. By default, training logs and saved models are output to `./logs/-/` directory. 
30 | 31 | To evaluate the performance of a trained model, run with 32 | ``` 33 | python main/plan.py --dataset hopper-medium-replay-v2 \ 34 | --checkpoint <checkpoint_directory> \ 35 | --suffix <suffix> 36 | ``` 37 | where `checkpoint_directory` should be the directory containing your model `state_*.pt`. By default, evaluation results are output to `./logs/-/` directory. 38 | 39 | 40 | ## Acknowledgements 41 | Some of the source code of this work was implemented on top of *Trajectory Transformer* (https://arxiv.org/abs/2106.02039). 42 | *Trajectory Transformer* uses the GPT implementation from Andrej Karpathy's *minGPT* repo. 43 | 44 | ## Citation 45 | You are more than welcome to cite our paper: 46 | ``` 47 | @article{wang2022bootstrapped, 48 | title={Bootstrapped Transformer for Offline Reinforcement Learning}, 49 | author={Wang, Kerong and Zhao, Hanye and Luo, Xufang and Ren, Kan and Zhang, Weinan and Li, Dongsheng}, 50 | journal={arXiv preprint arXiv:2206.08569}, 51 | year={2022} 52 | } 53 | ``` -------------------------------------------------------------------------------- /bootorl/main/plan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import sys 6 | import glob 7 | import random 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | from argparse import Namespace 13 | 14 | sys.path.append(os.getcwd()) 15 | 16 | import utils 17 | from utils import DiscretizedDataset 18 | from utils import plan 19 | from utils import Renderer 20 | from model import GPT 21 | 22 | 23 | def set_seed(seed): 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | 29 | 30 | def find_ckpt_epoch(loadpath, target="latest"): 31 | assert target in ["all", "latest", "earliest"] or isinstance(target, int) 32 | states = glob.glob1(loadpath, 'state_*') 33 | epochs = [int(state.replace('state_', '').replace('.pt', '')) for state in states] 34 | if target == "all": 35 | return epochs 36 | if target == "latest": 37 | return [max(epochs)] 38 | elif target == "earliest": 39 | return [min(epochs)] 40 | elif isinstance(target, int): 41 | assert target in epochs 42 | return [target] 43 | 44 | 45 | # Setup 46 | parser = utils.ArgParser() 47 | args = parser.parse_args() 48 | logger = parser.get_logger() 49 | set_seed(args.seed) 50 | 51 | # Environment 52 | env = utils.load_environment(args.dataset) 53 | 54 | # Dataset 55 | dataset = DiscretizedDataset( 56 | logger=logger, 57 | env=args.dataset, 58 | n_bins=args.n_bins, 59 | sequence_length=args.sequence_length, 60 | penalty=args.termination_penalty, 61 | discount=args.discount, 62 | ) 63 | 64 | obs_dim = dataset.observation_dim 65 | act_dim = dataset.action_dim 66 | trans_dim = dataset.joined_dim 67 | block_size = args.sequence_length * trans_dim - 1 68 | 69 | # Model 70 | model_config = Namespace( 71 | vocab_size=args.n_bins, 72 | block_size=block_size, 73 | n_layer=args.n_layer, 74 | n_head=args.n_head, 75 | n_embd=args.n_embd * args.n_head, 76 | observation_dim=obs_dim, 77 | action_dim=act_dim, 78 | transition_dim=trans_dim, 79 | action_weight=args.action_weight, 80 | reward_weight=args.reward_weight, 81 | value_weight=args.value_weight, 82 | embd_pdrop=args.embd_pdrop, 83 | resid_pdrop=args.resid_pdrop, 84 | attn_pdrop=args.attn_pdrop 85 | ) 86 | 87 | renderer = Renderer(env, dataset.observation_dim, dataset.action_dim) 88 | 89 | epoch_ranges = 
sorted(find_ckpt_epoch(args.checkpoint_path)) 90 | logger.debug(f"Find checkpoint epochs: {epoch_ranges}") 91 | all_info = {} 92 | for epoch in epoch_ranges: 93 | logger.info(f'Loading model epoch: {epoch}') 94 | state_path = os.path.join(args.checkpoint_path, f'state_{epoch}.pt') 95 | state = torch.load(state_path) 96 | 97 | model = GPT(model_config).to(args.device) 98 | model.load_state_dict(state, strict=True) 99 | 100 | info = plan(args, env, dataset, model, logger) 101 | rollout_states = info.pop("rollout_states") 102 | info = pd.DataFrame(info) 103 | info.index.name = "Timestep" 104 | all_info[(args.seed, epoch)] = info 105 | 106 | # rollout_states = np.stack(rollout_states, axis=0) 107 | # images = renderer.render_observations(rollout_states) 108 | # images_savepath = os.path.join(args.output_dir, f"epoch{epoch}_result.gif") 109 | # renderer.save_gif(images, images_savepath) 110 | 111 | all_info = pd.concat(all_info, names=["Seed", "Epoch"]) 112 | all_info.to_csv(os.path.join(args.output_dir, "reward_analysis.csv"), sep="\t") 113 | -------------------------------------------------------------------------------- /bootorl/utils/sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | import numpy as np 5 | import torch 6 | import time 7 | 8 | 9 | def top_k_logits(logits, k): 10 | v, ix = torch.topk(logits, k) 11 | out = logits.clone() 12 | out[out < v[:, [-1]]] = -float('inf') 13 | return out 14 | 15 | 16 | def filter_cdf(logits, threshold): 17 | batch_inds = torch.arange(logits.shape[0], device=logits.device, dtype=torch.long) 18 | bins_inds = torch.arange(logits.shape[-1], device=logits.device) 19 | probs = logits.softmax(dim=-1) 20 | probs_sorted, _ = torch.sort(probs, dim=-1) 21 | probs_cum = torch.cumsum(probs_sorted, dim=-1) 22 | ## get minimum probability p such that the cdf up to p is at least `threshold` 23 | mask = probs_cum < threshold 24 | masked_inds = torch.argmax(mask * bins_inds, dim=-1) 25 | probs_threshold = probs_sorted[batch_inds, masked_inds] 26 | ## filter 27 | out = logits.clone() 28 | logits_mask = probs <= probs_threshold.unsqueeze(dim=-1) 29 | # Not -inf to prevent error: Assertion `cumdist[size - 1] > static_cast(0)` failed 30 | out[logits_mask] = -1000 31 | return out 32 | 33 | 34 | def round_to_multiple(x, N): 35 | pad = (N - x % N) % N 36 | return x + pad 37 | 38 | 39 | def forward(model, x, max_block=None, allow_crop=True, crop_increment=None, **kwargs): 40 | model.train(False) 41 | bs = model.module.get_block_size() if hasattr(model, "module") else model.get_block_size() 42 | block_size = min(bs, max_block or np.inf) 43 | if x.shape[1] > block_size: 44 | assert allow_crop 45 | n_crop = round_to_multiple(x.shape[1] - block_size, crop_increment) 46 | assert n_crop % crop_increment == 0 47 | x = x[:, n_crop:] 48 | logits, _, _ = model(x, return_info=False, **kwargs) 49 | return logits 50 | 51 | 52 | def sample(model, x, temperature=1.0, topk=None, cdf=None, **forward_kwargs): 53 | logits = forward(model, x, **forward_kwargs) 54 | logits = logits[:, -1] / temperature 55 | raw_probs = logits.softmax(dim=-1) 56 | if cdf is not None: 57 | logits = filter_cdf(logits, cdf) 58 | if topk is not None: 59 | logits = top_k_logits(logits, topk) 60 | probs = logits.softmax(dim=-1) 61 | indices = torch.multinomial(probs, num_samples=1) 62 | return indices, raw_probs 63 | 64 | 65 | @torch.no_grad() 66 | def sample_n(model, x, N, 
**sample_kwargs): 67 | batch_size = len(x) 68 | vs = model.module.vocab_size if hasattr(model, "module") else model.vocab_size 69 | probs = torch.zeros(batch_size, N, vs + 1, device=x.device) 70 | for n in range(N): 71 | indices, p = sample(model, x, **sample_kwargs) 72 | x = torch.cat((x, indices), dim=1) 73 | probs[:, n] = p 74 | return x, probs 75 | 76 | @torch.no_grad() 77 | def sample_rollout(model, x, N, temperature=1.0, cdf=None, topk=None): 78 | bs = len(x) 79 | vs = model.module.vocab_size if hasattr(model, "module") else model.vocab_size 80 | probs = torch.zeros(bs, N, vs + 1, device=x.device) 81 | for n in range(N): 82 | logits, _, _ = model(x, return_info=False) 83 | logits = logits[:, -1] / temperature 84 | raw_probs = logits.softmax(dim=-1) 85 | if cdf is not None: 86 | logits = filter_cdf(logits, cdf) 87 | if topk is not None: 88 | logits = top_k_logits(logits, topk) 89 | real_probs = logits.softmax(dim=-1) 90 | indices = torch.multinomial(real_probs, num_samples=1) 91 | x = torch.cat((x, indices), dim=1) 92 | probs[:, n] = raw_probs 93 | return x, probs 94 | -------------------------------------------------------------------------------- /bootorl/main/analyze_plan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import sys 6 | import glob 7 | import random 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | from argparse import Namespace 13 | 14 | sys.path.append(os.getcwd()) 15 | 16 | import utils 17 | from utils import DiscretizedDataset 18 | from utils import plan 19 | from utils import Renderer 20 | from model import GPT 21 | 22 | 23 | def set_seed(seed): 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | 29 | 30 | def find_ckpt_epoch(loadpath, target="latest"): 31 | assert target in ["all", "latest", "earliest"] or isinstance(target, int) 32 | states = glob.glob1(loadpath, 'state_*') 33 | epochs = [int(state.replace('state_', '').replace('.pt', '')) for state in states] 34 | if target == "all": 35 | return epochs 36 | if target == "latest": 37 | return [max(epochs)] 38 | elif target == "earliest": 39 | return [min(epochs)] 40 | elif isinstance(target, int): 41 | assert target in epochs 42 | return [target] 43 | 44 | 45 | # Setup 46 | parser = utils.ArgParser() 47 | args = parser.parse_args() 48 | logger = parser.get_logger() 49 | set_seed(args.seed) 50 | 51 | # Environment 52 | env = utils.load_environment(args.dataset) 53 | 54 | # Dataset 55 | dataset = DiscretizedDataset( 56 | logger=logger, 57 | env=args.dataset, 58 | n_bins=args.n_bins, 59 | sequence_length=args.sequence_length, 60 | penalty=args.termination_penalty, 61 | discount=args.discount, 62 | ) 63 | 64 | obs_dim = dataset.observation_dim 65 | act_dim = dataset.action_dim 66 | trans_dim = dataset.joined_dim 67 | block_size = args.sequence_length * trans_dim - 1 68 | 69 | # Model 70 | model_config = Namespace( 71 | vocab_size=args.n_bins, 72 | block_size=block_size, 73 | n_layer=args.n_layer, 74 | n_head=args.n_head, 75 | n_embd=args.n_embd * args.n_head, 76 | observation_dim=obs_dim, 77 | action_dim=act_dim, 78 | transition_dim=trans_dim, 79 | action_weight=args.action_weight, 80 | reward_weight=args.reward_weight, 81 | value_weight=args.value_weight, 82 | embd_pdrop=args.embd_pdrop, 83 | resid_pdrop=args.resid_pdrop, 84 | attn_pdrop=args.attn_pdrop 85 | ) 86 
| 87 | renderer = Renderer(env, dataset.observation_dim, dataset.action_dim) 88 | 89 | epoch_ranges = sorted(find_ckpt_epoch(args.checkpoint_path)) 90 | logger.debug(f"Find checkpoint epochs: {epoch_ranges}") 91 | all_info = {} 92 | for epoch in epoch_ranges: 93 | logger.info(f'Loading model epoch: {epoch}') 94 | state_path = os.path.join(args.checkpoint_path, f'state_{epoch}.pt') 95 | state = torch.load(state_path) 96 | 97 | model = GPT(model_config).to(args.device) 98 | model.load_state_dict(state, strict=True) 99 | 100 | info = plan(args, env, dataset, model, logger) 101 | for k, v in info.items(): 102 | print(k, len(v)) 103 | info = pd.DataFrame(info) 104 | info.index.name = "Timestep" 105 | all_info[(args.seed, epoch)] = info 106 | 107 | rollout_states = np.stack(info["rollout_states"], axis=0) 108 | predict_states = np.stack(info["predict_states"], axis=0) 109 | print(len(rollout_states)) 110 | print(len(predict_states)) 111 | print(rollout_states[:3]) 112 | print(predict_states[:3]) 113 | 114 | 115 | all_info = pd.concat(all_info, names=["Seed", "Epoch"]) 116 | all_info.to_csv(os.path.join(args.output_dir, "reward_analysis.csv"), sep="\t") 117 | -------------------------------------------------------------------------------- /eppo/code/policy/ppo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | import gym 7 | import torch 8 | import torch.nn as nn 9 | from gym.spaces import Discrete 10 | from tianshou.data import Batch 11 | from tianshou.policy import PPOPolicy 12 | 13 | from code.network import BaseNetwork 14 | from .base import POLICIES 15 | from .utils import chain_dedup, load_weight, preprocess_obs 16 | 17 | 18 | class PPOActor(nn.Module): 19 | def __init__(self, extractor: BaseNetwork, action_dim: int): 20 | super().__init__() 21 | self.extractor = extractor 22 | self.layer_out = nn.Sequential( 23 | nn.Linear(extractor.output_dim, action_dim), 24 | nn.Softmax(dim=-1) 25 | ) 26 | 27 | def forward(self, obs, state=None, info={}): 28 | feature = self.extractor(preprocess_obs(obs)) 29 | out = self.layer_out(feature) 30 | return out, state 31 | 32 | 33 | class PPOCritic(nn.Module): 34 | def __init__(self, extractor: BaseNetwork): 35 | super().__init__() 36 | self.extractor = extractor 37 | self.value_out = nn.Linear(extractor.output_dim, 1) 38 | 39 | def forward(self, obs, state=None, info={}): 40 | feature = self.extractor(preprocess_obs(obs)) 41 | return self.value_out(feature).squeeze(dim=-1) 42 | 43 | 44 | @POLICIES.register_module() 45 | class PPO(PPOPolicy): 46 | def __init__(self, 47 | lr: float, 48 | weight_decay: float = 0., 49 | discount_factor: float = 1., 50 | max_grad_norm: float = 100., 51 | reward_normalization: bool = True, 52 | eps_clip: float = 0.3, 53 | value_clip: float = True, 54 | vf_coef: float = 1., 55 | gae_lambda: float = 1., 56 | network: Optional[BaseNetwork] = None, 57 | obs_space: Optional[gym.Space] = None, 58 | action_space: Optional[gym.Space] = None, 59 | weight_file: Optional[Path] = None): 60 | assert network is not None and obs_space is not None 61 | assert isinstance(action_space, Discrete) 62 | actor = PPOActor(network, action_space.n) 63 | critic = PPOCritic(network) 64 | optimizer = torch.optim.Adam( 65 | chain_dedup(actor.parameters(), critic.parameters()), 66 | lr=lr, weight_decay=weight_decay) 67 | super().__init__(actor, critic, 
optimizer, torch.distributions.Categorical, 68 | discount_factor=discount_factor, 69 | max_grad_norm=max_grad_norm, 70 | reward_normalization=reward_normalization, 71 | eps_clip=eps_clip, 72 | value_clip=value_clip, 73 | vf_coef=vf_coef, 74 | gae_lambda=gae_lambda) 75 | if weight_file is not None: 76 | load_weight(self, weight_file) 77 | 78 | def forward(self, batch, state=None, **kwargs): 79 | """ 80 | This is already done in https://github.com/thu-ml/tianshou/pull/354 81 | Should be no longer needed for future release of tianshou. 82 | """ 83 | logits, h = self.actor(batch.obs, state=state) 84 | if isinstance(logits, tuple): 85 | dist = self.dist_fn(*logits) 86 | logits_, _ = logits 87 | else: 88 | dist = self.dist_fn(logits) 89 | logits_ = logits 90 | if self.training: 91 | act = dist.sample() 92 | else: 93 | act = torch.argmax(logits_, dim=-1) 94 | return Batch(logits=logits, act=act, state=h, dist=dist) 95 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | # env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | .DS_Store 163 | .vscode/ 164 | -------------------------------------------------------------------------------- /eppo/code/env/game_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
3 | import inspect 4 | import json 5 | from collections import defaultdict,deque 6 | from typing import TextIO 7 | 8 | from torch.utils.tensorboard import SummaryWriter 9 | from utilsd import get_tb_log_dir, get_output_dir 10 | from utilsd.avgmeter import MetricMeter 11 | from utilsd.logging import print_log 12 | 13 | from .finite_env import BaseLogger 14 | 15 | 16 | __all__ = ['Logger'] 17 | 18 | _tb_logger = _json_writer = None 19 | 20 | 21 | def _get_tb_logger() -> SummaryWriter: 22 | global _tb_logger 23 | if _tb_logger is None: 24 | _tb_logger = SummaryWriter(log_dir=get_tb_log_dir()) 25 | return _tb_logger 26 | 27 | 28 | def _get_json_writer() -> TextIO: 29 | global _json_writer 30 | if _json_writer is None: 31 | _json_writer = (get_output_dir() / 'summary.json').open('a') 32 | return _json_writer 33 | 34 | 35 | class GameLogger(BaseLogger): 36 | 37 | def __init__(self, ep_total, *, log_interval=100, prefix='Episode', tb_prefix='', count_global='episode', max_len=20): 38 | self.meter = MetricMeter() 39 | self.ep_count = 0 40 | self.global_step = 0 41 | self.ep_total = ep_total 42 | self.log_interval = log_interval 43 | self.prefix = prefix 44 | self.active_env_ids = set() 45 | assert count_global in ['step', 'episode'] 46 | self.count_global = count_global 47 | 48 | self.tb_writer = _get_tb_logger() 49 | self.tb_prefix = tb_prefix 50 | 51 | self.json_writer = _get_json_writer() 52 | 53 | self.episode_lengths = dict() 54 | self.episode_rewards = dict() 55 | self.eps_rwd = deque(maxlen=max_len) 56 | 57 | def log_step(self, env_id, obs, rew, done, info): 58 | self.active_env_ids.add(env_id) 59 | self.episode_lengths[env_id] += 1 60 | self.episode_rewards[env_id] += rew 61 | 62 | if self.count_global == 'step': 63 | self.global_step += 1 64 | 65 | if not done: 66 | return 67 | 68 | if self.count_global == 'episode': 69 | self.global_step += 1 70 | 71 | self.ep_count += 1 72 | self.eps_rwd.append(self.episode_rewards[env_id]) 73 | logs = dict() # deal with batch 74 | logs.update({ 75 | 'step_per_episode': self.episode_lengths[env_id], 76 | 'reward': self.episode_rewards[env_id], 77 | 'num_active_envs': len(self.active_env_ids) 78 | }) 79 | # print(logs) 80 | # exit(1) 81 | 82 | self.meter.update({k: v for k, v in logs.items()}) 83 | if self.ep_count % self.log_interval == 0 or self.ep_count >= self.ep_total: 84 | frm = inspect.stack()[1] 85 | mod = inspect.getmodule(frm[0]) 86 | print_log(f'{self.prefix} [{self.ep_count}/{self.ep_total}] {self.meter}', mod.__name__) 87 | print_log(f'{self.prefix} [{self.ep_count}/{self.ep_total}] {sum(self.eps_rwd)/len(self.eps_rwd)}', mod.__name__) 88 | 89 | def log_reset(self, env_id, obs): 90 | self.episode_lengths[env_id] = 0 91 | self.episode_rewards[env_id] = 0. 
92 | 93 | def set_prefix(self, prefix): 94 | self.prefix = prefix 95 | 96 | def write_summary(self, extra_metrics=None): 97 | if extra_metrics: 98 | self.meter.update(extra_metrics) 99 | summary = self.summary() 100 | if len(self.eps_rwd)>0: 101 | summary["eps_rwd_avg"] = sum(self.eps_rwd)/len(self.eps_rwd) 102 | else: 103 | summary["eps_rwd_avg"] = 0 104 | print_log(f'{self.prefix} Summary:\n' + '\n'.join([f' {k}\t{v:.4f}' for k, v in summary.items()]), __name__) 105 | for key, value in summary.items(): 106 | if self.tb_prefix: 107 | key = self.tb_prefix + '/' + key 108 | self.tb_writer.add_scalar(key, value, global_step=self.global_step) 109 | summary = {'prefix': self.tb_prefix, 'step': self.global_step, **summary} 110 | self.json_writer.write(json.dumps(summary) + '\n') 111 | self.json_writer.flush() 112 | 113 | def summary(self): 114 | return {key: self.meter[key].avg for key in self.meter} 115 | 116 | def reset(self, prefix=None): 117 | self.ep_count = self.step_count = 0 118 | self.meter.reset() 119 | self.active_env_ids = set() 120 | if prefix is not None: 121 | self.set_prefix(prefix) 122 | 123 | def state_dict(self): 124 | # logging status within epoch is not saved 125 | return { 126 | 'global_step': self.global_step, 127 | } 128 | 129 | def load_state_dict(self, state_dict): 130 | self.global_step = state_dict['global_step'] 131 | -------------------------------------------------------------------------------- /agentlightning/README.md: -------------------------------------------------------------------------------- 1 | # Agent Lightning 2 | 3 | **Update 7/22/2025: We are maintaining the latest version of this project at https://github.com/microsoft/agent-lightning** 4 | 5 | **Warning: This project is currently in a research preview stage. The APIs are not stable and the functionalities are not well tested.** 6 | 7 | Welcome to Agent Lightning! This guide will walk you through setting up and running the project. 8 | 9 | ## Installation 10 | 11 | First, let's get your environment set up. We'll be using `/path/to/agentlightning` to refer to the directory containing this README file. 12 | 13 | ### 1. Set Up Your Environment 14 | 15 | We strongly recommend creating a new virtual environment to avoid conflicts with other packages. You can use either `conda` or `venv`. **Python 3.10 or later** is recommended. 16 | 17 | ### 2. Install Core Dependencies 18 | 19 | Next, let's install the essential packages: `uv`, `PyTorch`, `FlashAttention`, and `vLLM`. 20 | 21 | * **Install `uv`** (This is required for some MCP agents): 22 | 23 | ```bash 24 | curl -LsSf https://astral.sh/uv/install.sh | sh 25 | ``` 26 | 27 | * **Install `PyTorch`, `FlashAttention`, and `vLLM`**: 28 | The following versions and installation order have been tested and are confirmed to work. 29 | 30 | ```bash 31 | pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 32 | pip install flash-attn --no-build-isolation 33 | pip install vllm==v0.8.5.post1 34 | ``` 35 | 36 | ### 3. Install Patched VERL 37 | 38 | Agent Lightning requires a patched version of VERL for full compatibility. If you have a different version of VERL installed, please uninstall it first. 39 | 40 | ```bash 41 | # Clone the specific commit of VERL 42 | git clone https://github.com/volcengine/verl /path/to/your/verl 43 | cd /path/to/your/verl 44 | git checkout 2dc3e0ebadb479bb3f2b48cfc7f28a3b70d5ce60 45 | 46 | # Install the patched version 47 | pip install -e . 
48 | 49 | # Apply the patch from Agent Lightning 50 | cd /path/to/agentlightning 51 | bash scripts/verl_git_apply.sh /path/to/your/verl 52 | ``` 53 | 54 | ### 4. Install Agent Lightning 55 | 56 | Now, you're ready to install Agent Lightning itself. 57 | 58 | ```bash 59 | cd /path/to/agentlightning 60 | pip install -e . 61 | ``` 62 | 63 | ### 5. Install Optional Frameworks 64 | 65 | If you plan to use other agent frameworks, you can install them with the following commands. If you don't need these, feel free to skip this step. 66 | 67 | ```bash 68 | # AutoGen (Recommended to install first) 69 | pip install "autogen-agentchat" "autogen-ext[openai]" 70 | 71 | # LiteLLM 72 | pip install "litellm[proxy]" 73 | 74 | # MCP 75 | pip install mcp 76 | 77 | # OpenAI Agents 78 | pip install openai-agents 79 | 80 | # LangChain 81 | pip install langgraph "langchain[openai]" langchain-community langchain-text-splitters 82 | 83 | # SQL-related dependencies 84 | pip install sqlparse nltk 85 | ``` 86 | 87 | Don't worry if dependency conflicts arise during this step. Follow the installation order above and the conflicts generally do not matter. 88 | 89 | ## Architecture 90 | 91 | Currently, Agent Lightning is built around a **training server** and one or multiple **agents**. 92 | 93 | * The **server** manages the training data, prepares samples for the agents, and provides the LLM endpoint. 94 | * **Agents** retrieve samples from the server, process them (which may involve interacting with the LLM), and send the results back. These results, or "trajectories," are lists of prompts and responses from the LLM. 95 | * The **server** then collects these trajectories and computes the loss to optimize the language models. 96 | 97 | ## Examples 98 | 99 | For more detailed examples, please see the `examples` folder. 100 | 101 | ## Important Caveats 102 | 103 | 1. **AgentOps Integration**: Agent Lightning uses [AgentOps](https://github.com/AgentOps-AI/agentops) for agent tracking by default. If you're already using AgentOps in your own code, you'll need to disable our managed AgentOps client by setting `agentops_managed` to `False` in the `Trainer` and handle your integration by yourself. 104 | 105 | 2. **Debugging Traces**: If you encounter issues with tracing, you can visualize the trace tree using `processor.last_trace()._tree_visualize("tree_graph")`. Please note that this API is experimental and may change in future releases. 106 | 107 | 3. **Launching the Server and Agents**: Currently, the training server and agent clients must be launched in separate processes. You can open two terminal windows or run one of them in the background. The order in which you launch them generally doesn't matter. 108 | 109 | 4. **Environment Variables**: The environment variables and working directory at the time of `ray init` are important. If you run into "file not found" errors, try restarting Ray from your current working directory. 110 | 111 | 5. **Handling Timeouts**: The training server may hang if samples fail or time out on the agent side. To prevent this, we recommend setting limits on the prompt and response lengths, as this is the most common cause of failures. 
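
To make the server/agent loop described above concrete, below is a minimal, hypothetical agent sketch modeled on the bundled `examples/calc_x/calc_agent.py`. The dataset field names (`question`, `answer`), the trivial exact-match reward, and the local server URL are illustrative assumptions, not fixed parts of the API. Start the training server first (see the `examples` folder), then run the agent in a separate process as noted in caveat 3 above.

```python
# A minimal sketch of a custom agent, modeled on examples/calc_x/calc_agent.py.
# The sample keys, the exact-match reward, and the server URL are assumptions.
from typing import Any

from agentlightning import LitAgent, SamplingParameters, Trainer, configure_logger, reward

configure_logger()


@reward
async def exact_match(prediction: str, ground_truth: str) -> float:
    # The decorator tracks this value as the rollout's reward.
    return float(prediction.strip() == ground_truth.strip())


class EchoAgent(LitAgent):
    async def training_rollout_async(
        self, sample: Any, *, sampling_parameters: SamplingParameters | None = None, rollout_id: str | None = None
    ) -> Any:
        # A real agent would query the LLM endpoint served by the trainer here
        # (see self.trainer.get_openai_endpoint() in the calc_x example) and
        # parse an answer out of the model's response.
        answer = str(sample["question"])
        await exact_match(answer, str(sample["answer"]))


if __name__ == "__main__":
    # Launch four agent workers against a training server assumed to be at this URL.
    Trainer(n_workers=4).fit(EchoAgent(), "http://localhost:9999/")
```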
112 | -------------------------------------------------------------------------------- /agentlightning/examples/calc_x/calc_agent.py: -------------------------------------------------------------------------------- 1 | import math 2 | import string 3 | import re 4 | from typing import Any 5 | 6 | from agentlightning.client import SamplingParameters 7 | import sympy 8 | from autogen_agentchat.agents import AssistantAgent 9 | from autogen_core.models import ModelFamily 10 | from autogen_ext.models.openai import OpenAIChatCompletionClient 11 | from autogen_ext.tools.mcp import McpWorkbench, StdioServerParams 12 | 13 | from agentlightning import Trainer, LitAgent, SamplingParameters, reward, configure_logger 14 | 15 | configure_logger() 16 | 17 | calculator_mcp_server = StdioServerParams(command="uvx", args=["mcp-server-calculator"]) 18 | 19 | 20 | # Copied and adapted from https://github.com/prompteus/calc-x/blob/master/gadgets/metrics.py 21 | 22 | 23 | def normalize_option(option: str) -> str: 24 | """ 25 | >>> normalize_option(" (A) \n") 26 | 'A' 27 | """ 28 | return re.sub(r"(\s+|\(|\))", "", option) 29 | 30 | 31 | def is_option_result(result: str) -> bool: 32 | """ 33 | >>> is_option_result(" A) \n") 34 | True 35 | >>> is_option_result(" 23/7 ") 36 | False 37 | """ 38 | return normalize_option(result) in list(string.ascii_letters) 39 | 40 | 41 | def float_eval(input_str: str) -> float: 42 | if " = around " in input_str: 43 | input_str = input_str.split(" = around ")[0] 44 | expr = sympy.parse_expr(input_str, evaluate=True) 45 | return float(expr.evalf()) 46 | 47 | 48 | def scalar_are_results_same(pred_result: str, true_result: str, rel_tol: float) -> bool: 49 | pred_result = str(pred_result) if pred_result is not None else "" 50 | true_result = str(true_result) if true_result is not None else "" 51 | 52 | if pred_result.strip() == true_result.strip(): 53 | return True 54 | 55 | if is_option_result(true_result): 56 | # The task is to select correct option 57 | true_result = normalize_option(true_result) 58 | pred_result = normalize_option(pred_result) 59 | return pred_result == true_result 60 | 61 | # The task is to calculate the result as a number 62 | try: 63 | pred_float = float_eval(pred_result) 64 | true_float = float_eval(true_result) 65 | return math.isclose(pred_float, true_float, rel_tol=rel_tol) 66 | except Exception: 67 | pass 68 | 69 | return False 70 | 71 | 72 | @reward 73 | async def eval(prediction: str, ground_truth: str) -> float: 74 | return float(scalar_are_results_same(prediction, ground_truth, 1e-2)) 75 | 76 | 77 | def get_agent(model, openai_base_url, temperature, workbench): 78 | model_client = OpenAIChatCompletionClient( 79 | model=model, 80 | base_url=openai_base_url, 81 | api_key="token-abc123", 82 | model_info={ 83 | "vision": False, 84 | "function_calling": True, 85 | "json_output": False, 86 | "family": ModelFamily.UNKNOWN, 87 | "structured_output": False, 88 | }, 89 | temperature=temperature, 90 | ) 91 | 92 | calc_agent = AssistantAgent( 93 | name="calc", 94 | model_client=model_client, 95 | workbench=workbench, 96 | reflect_on_tool_use=True, 97 | ) 98 | return calc_agent 99 | 100 | 101 | class CalcAgent(LitAgent): 102 | 103 | async def training_rollout_async( 104 | self, sample: Any, *, sampling_parameters: SamplingParameters | None = None, rollout_id: str | None = None 105 | ) -> Any: 106 | assert sampling_parameters is not None 107 | async with McpWorkbench(calculator_mcp_server) as workbench: 108 | calc_agent = get_agent( 109 | 
sampling_parameters["model"], 110 | self.trainer.get_openai_endpoint(), 111 | sampling_parameters["temperature"], 112 | workbench, 113 | ) 114 | try: 115 | output_format = "Output the answer when you are ready. The answer should be surrounded by three sharps (`###`), in the form of ### ANSWER: ###." 116 | task = sample["question"] + " " + output_format 117 | result = await calc_agent.run(task=task) 118 | # evaluate 119 | answer = re.search(r"###\s*ANSWER:\s*(.+?)(\s*###|$)", result.messages[-1].content) 120 | if answer: 121 | answer = answer.group(1) 122 | else: 123 | answer = result.messages[-1].content 124 | except Exception as e: 125 | print("Failure:", str(e)) 126 | answer = "None" 127 | reward = await eval(answer, str(sample["result"])) # reward is tracked with the decorator 128 | print("answer: {} ground_truth: {} reward: {}".format(answer, sample["result"], reward)) 129 | 130 | async def validation_rollout_async( 131 | self, sample: Any, *, sampling_parameters: SamplingParameters | None = None, rollout_id: str | None = None 132 | ) -> Any: 133 | return await self.training_rollout_async( 134 | sample, sampling_parameters={"temperature": 0, "model": sampling_parameters["model"]}, rollout_id=rollout_id 135 | ) 136 | 137 | 138 | if __name__ == "__main__": 139 | Trainer(n_workers=4).fit(CalcAgent(), "http://localhost:9999/") 140 | -------------------------------------------------------------------------------- /bootorl/main/extract_distribution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import sys 6 | import glob 7 | import random 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | from argparse import Namespace 13 | 14 | sys.path.append(os.getcwd()) 15 | 16 | import utils 17 | from utils import AnalyzeDiscretizedDataset 18 | from utils import Tester 19 | from model import GPT 20 | 21 | 22 | # python main/extract_distribution.py --dataset halfcheetah-medium-v2 --checkpoint ... 
23 | 24 | 25 | def set_seed(seed): 26 | random.seed(seed) 27 | np.random.seed(seed) 28 | torch.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | 31 | 32 | def find_ckpt_epoch(loadpath, target="latest"): 33 | assert target in ["all", "latest", "earliest"] or isinstance(target, int) 34 | states = glob.glob1(loadpath, 'state_*') 35 | epochs = [int(state.replace('state_', '').replace('.pt', '')) for state in states] 36 | if target == "all": 37 | return epochs 38 | if target == "latest": 39 | return [max(epochs)] 40 | elif target == "earliest": 41 | return [min(epochs)] 42 | elif isinstance(target, int): 43 | assert target in epochs 44 | return [target] 45 | 46 | 47 | # Setup 48 | parser = utils.ArgParser() 49 | args = parser.parse_args() 50 | logger = parser.get_logger() 51 | set_seed(args.seed) 52 | 53 | # Environment 54 | env = utils.load_environment(args.dataset) 55 | 56 | # Dataset 57 | dataset = AnalyzeDiscretizedDataset( 58 | logger=logger, 59 | env=args.dataset, 60 | n_bins=args.n_bins, 61 | sequence_length=args.sequence_length, 62 | penalty=args.termination_penalty, 63 | discount=args.discount, 64 | ) 65 | 66 | obs_dim = dataset.observation_dim 67 | act_dim = dataset.action_dim 68 | trans_dim = dataset.joined_dim 69 | block_size = args.sequence_length * trans_dim - 1 70 | 71 | # Model 72 | model_config = Namespace( 73 | vocab_size=args.n_bins, 74 | block_size=block_size, 75 | n_layer=args.n_layer, 76 | n_head=args.n_head, 77 | n_embd=args.n_embd * args.n_head, 78 | observation_dim=obs_dim, 79 | action_dim=act_dim, 80 | transition_dim=trans_dim, 81 | action_weight=args.action_weight, 82 | reward_weight=args.reward_weight, 83 | value_weight=args.value_weight, 84 | embd_pdrop=args.embd_pdrop, 85 | resid_pdrop=args.resid_pdrop, 86 | attn_pdrop=args.attn_pdrop 87 | ) 88 | 89 | tester_config = Namespace( 90 | logger=logger, 91 | batch_size=args.batch_size, 92 | num_workers=0, 93 | ) 94 | tester = Tester(tester_config) 95 | 96 | 97 | epoch_ranges = sorted(find_ckpt_epoch(args.checkpoint_path)) 98 | logger.debug(f"Find checkpoint epochs: {epoch_ranges}") 99 | all_info = {} 100 | for epoch in epoch_ranges: 101 | logger.info(f'Loading model epoch: {epoch}') 102 | state_path = os.path.join(args.checkpoint_path, f'state_{epoch}.pt') 103 | state = torch.load(state_path) 104 | 105 | model = GPT(model_config) 106 | model.load_state_dict(state, strict=True) 107 | 108 | device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" 109 | if args.parallel: 110 | model = torch.nn.DataParallel(model).to(device) 111 | else: 112 | model = model.to(device) 113 | 114 | generated_dataset = tester.generate_data(model, dataset, with_tqdm=args.with_tqdm) 115 | 116 | masks = torch.stack([generated_dataset[i]["mask"][-1] for i in range(len(generated_dataset))], dim=0) 117 | trans_origin = torch.stack([generated_dataset[i]["origin"][-trans_dim:] for i in range(len(generated_dataset))], dim=0)[masks] 118 | trans_discretized = torch.stack([generated_dataset[i]["discretized"][1][-trans_dim:] for i in range(len(generated_dataset))], dim=0)[masks] 119 | trans_generated_tf_realr = torch.stack([generated_dataset[i]["generated_tf_realr"][-trans_dim:] for i in range(len(generated_dataset))], dim=0)[masks] 120 | trans_recon_tf_realr = dataset.discretizer.reconstruct(trans_generated_tf_realr) 121 | trans_generated_ar_realr = torch.stack([generated_dataset[i]["generated_ar_realr"][-trans_dim:] for i in range(len(generated_dataset))], dim=0)[masks] 122 | trans_recon_ar_realr = 
dataset.discretizer.reconstruct(trans_generated_ar_realr) 123 | trans_generated_tf_genr = torch.stack([generated_dataset[i]["generated_tf_genr"][-trans_dim:] for i in range(len(generated_dataset))], dim=0)[masks] 124 | trans_recon_tf_genr = dataset.discretizer.reconstruct(trans_generated_tf_genr) 125 | trans_generated_ar_genr = torch.stack([generated_dataset[i]["generated_ar_genr"][-trans_dim:] for i in range(len(generated_dataset))], dim=0)[masks] 126 | trans_recon_ar_genr = dataset.discretizer.reconstruct(trans_generated_ar_genr) 127 | 128 | np.savez( 129 | os.path.join(args.output_dir, "transitions.npz"), 130 | trans_origin=trans_origin, 131 | trans_discretized=trans_discretized, 132 | trans_generated_tf_realr=trans_generated_tf_realr, 133 | trans_recon_tf_realr=trans_recon_tf_realr, 134 | trans_generated_ar_realr=trans_generated_ar_realr, 135 | trans_recon_ar_realr=trans_recon_ar_realr, 136 | trans_generated_tf_genr=trans_generated_tf_genr, 137 | trans_recon_tf_genr=trans_recon_tf_genr, 138 | trans_generated_ar_genr=trans_generated_ar_genr, 139 | trans_recon_ar_genr=trans_recon_ar_genr, 140 | ) 141 | -------------------------------------------------------------------------------- /bootorl/utils/renderer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | import numpy as np 5 | import torch 6 | import gym 7 | import os 8 | import contextlib 9 | import mujoco_py as mjc 10 | from PIL import Image 11 | 12 | @contextlib.contextmanager 13 | def suppress_output(): 14 | """ 15 | A context manager that redirects stdout and stderr to devnull 16 | https://stackoverflow.com/a/52442331 17 | """ 18 | with open(os.devnull, 'w') as fnull: 19 | with contextlib.redirect_stderr(fnull) as err, contextlib.redirect_stdout(fnull) as out: 20 | yield (err, out) 21 | 22 | with suppress_output(): 23 | import d4rl 24 | 25 | ANTMAZE_BOUNDS = { 26 | 'antmaze-umaze-v0': (-3, 11), 27 | 'antmaze-medium-play-v0': (-3, 23), 28 | 'antmaze-medium-diverse-v0': (-3, 23), 29 | 'antmaze-large-play-v0': (-3, 39), 30 | 'antmaze-large-diverse-v0': (-3, 39), 31 | } 32 | 33 | 34 | def load_environment(name): 35 | with suppress_output(): 36 | wrapped_env = gym.make(name) 37 | env = wrapped_env.unwrapped 38 | env.max_episode_steps = wrapped_env._max_episode_steps 39 | env.name = name 40 | return env 41 | 42 | 43 | def to_np(x): 44 | if torch.is_tensor(x): 45 | x = x.detach().cpu().numpy() 46 | return x 47 | 48 | 49 | def save_gif(image_list, images_savepath): 50 | w, h = image_list[0].shape[0], image_list[0].shape[1] 51 | images = [] 52 | for i, img in enumerate(image_list): 53 | img_ = Image.fromarray(img) 54 | w_ = int(w * i / len(image_list)) 55 | pbar = Image.new('RGBA', (w_, h // 50), (0, 0, 255, 0)) # blue progress bar 56 | img_.paste(pbar, (0, h - h // 50)) 57 | images.append(img_) 58 | images[0].save(images_savepath, format="GIF", append_images=images[1:], save_all=True, duration=len(images)/24, loop=0) 59 | 60 | 61 | def set_state(env, state): 62 | qpos_dim = env.sim.data.qpos.size 63 | qvel_dim = env.sim.data.qvel.size 64 | qstate_dim = qpos_dim + qvel_dim 65 | 66 | if 'ant-' in env.name: 67 | ypos = np.zeros(1) 68 | state = np.concatenate([ypos, state]) 69 | 70 | if state.size == qpos_dim - 1 or state.size == qstate_dim - 1: 71 | xpos = np.zeros(1) 72 | state = np.concatenate([xpos, state]) 73 | 74 | if state.size == qpos_dim: 75 | qvel = np.zeros(qvel_dim) 76 | state = 
np.concatenate([state, qvel]) 77 | 78 | if 'ant-' in env.name: 79 | xpos = np.zeros(1) 80 | state = np.concatenate([xpos, state])[:qstate_dim] 81 | 82 | if state.size > qpos_dim + qvel_dim: 83 | state = state[:qstate_dim] 84 | 85 | assert state.size == qpos_dim + qvel_dim 86 | 87 | env.set_state(state[:qpos_dim], state[qpos_dim:]) 88 | 89 | 90 | class Renderer: 91 | def __init__(self, env, observation_dim=None, action_dim=None): 92 | self.env = load_environment(env) if type(env) is str else env 93 | self.observation_dim = observation_dim or np.prod(self.env.observation_space.shape) 94 | self.action_dim = action_dim or np.prod(self.env.action_space.shape) 95 | self.viewer = mjc.MjRenderContextOffscreen(self.env.sim) 96 | self.set_viewer() 97 | 98 | def set_viewer(self, render_kwargs=None): 99 | if render_kwargs is None: 100 | if self.env.name.startswith("antmaze"): 101 | pos = sum(ANTMAZE_BOUNDS.get(self.env.name)) / 2 102 | render_kwargs = { 103 | 'trackbodyid': 2, 104 | 'distance': 3, 105 | 'lookat': [pos, pos, pos * 4], 106 | 'elevation': -90 107 | } 108 | else: 109 | render_kwargs = { 110 | 'trackbodyid': 2, 111 | 'distance': 2, 112 | 'lookat': [0, -0.5, 1], 113 | 'elevation': -20 114 | } 115 | for key, val in render_kwargs.items(): 116 | if key == 'lookat': 117 | self.viewer.cam.lookat[:] = val[:] 118 | else: 119 | setattr(self.viewer.cam, key, val) 120 | 121 | def render(self, observation, dim=512): 122 | observation = to_np(observation) 123 | set_state(self.env, observation) 124 | dim = (dim, dim) if type(dim) == int else dim 125 | self.viewer.render(*dim) 126 | data = self.viewer.read_pixels(*dim, depth=False) 127 | data = data[::-1, :, :] 128 | return data 129 | 130 | def render_observations(self, observations, **kwargs): 131 | images = [] 132 | for observation in observations: 133 | img = self.render(observation, **kwargs) 134 | images.append(img) 135 | return images 136 | 137 | def save_gif(self, image_list, images_savepath): 138 | save_gif(image_list, images_savepath) 139 | 140 | 141 | if __name__ == "__main__": 142 | import sys 143 | sys.path.append(os.getcwd()) 144 | from dataset import DiscretizedDataset 145 | dataset = DiscretizedDataset( 146 | logger=None, 147 | env="antmaze-large-play-v0", 148 | n_bins=100, 149 | sequence_length=10, 150 | penalty=0, 151 | discount=0.99, 152 | ) 153 | traj_len = dataset.path_lengths[0] 154 | traj = dataset.joined_segmented[0, :traj_len] 155 | print(traj.shape) 156 | 157 | env = load_environment("antmaze-large-play-v0") 158 | renderer = Renderer(env, dataset.observation_dim, dataset.action_dim) 159 | imgs = renderer.render_observations(traj) 160 | 161 | images_savepath = "test.gif" 162 | save_gif(imgs, images_savepath) 163 | -------------------------------------------------------------------------------- /eppo/code/network/cnn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
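# Convolutional encoders registered under NETWORKS:
#   - AtariCNN: the standard 4-frame Atari encoder (optionally returning raw conv features).
#   - MiniCNN / MiniLargeCNN: small encoders for low-resolution (POMDP) observations; the
#     flattened feature size is inferred from a dummy forward pass.
#   - SokoCNN: presumably an encoder for Sokoban-style grid observations (inferred from the name).
# Usage sketch (hypothetical input shape, for illustration only):
#     net = MiniCNN(input_dims=[3, 16, 16], output_dim=64)
#     feats = net(torch.zeros(1, 3, 16, 16))   # -> tensor of shape (1, 64)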
3 | from code.network.base import NETWORKS 4 | import torch 5 | import torch.nn as nn 6 | from typing import List 7 | import numpy as np 8 | from utilsd import print_log,use_cuda 9 | 10 | #specially for atari 11 | @NETWORKS.register_module() 12 | class AtariCNN(nn.Module): 13 | def __init__(self, 14 | input_dims: int, 15 | hidden_dim: int=64, 16 | output_dim: int=64, 17 | num_layers: int=2, 18 | features_only: bool=False): 19 | super().__init__() 20 | self.state_embed = nn.Sequential( 21 | nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(), 22 | nn.Conv2d(32, 64, 4, stride=2, padding=0), nn.ReLU(), 23 | nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(), 24 | nn.Flatten()) 25 | if not features_only: 26 | self.state_embed = nn.Sequential( 27 | self.state_embed, nn.Linear(3136, output_dim), nn.ReLU() 28 | ) 29 | self.output_dim = output_dim 30 | else: 31 | self.output_dim = 3136 32 | 33 | 34 | def forward(self,x): 35 | #x = torch.transpose(x,1,3).float() 36 | # print(x) 37 | # exit(1) 38 | return self.state_embed(x.float()) 39 | 40 | 41 | @NETWORKS.register_module() 42 | class MiniCNN(nn.Module): 43 | def __init__(self, 44 | input_dims: List[int], #c,h,w 45 | hidden_dim: int=64, 46 | output_dim: int=64, 47 | num_layers: int=2): 48 | super().__init__() 49 | # self.state_embed = nn.Sequential( 50 | # nn.Conv2d(input_dims[0], 16, 4, stride=2, padding=0), nn.ReLU(), 51 | # nn.Conv2d(16, 32, 4, stride=2, padding=0), nn.ReLU(), 52 | # nn.Conv2d(32, 32, 3, stride=2, padding=0), nn.ReLU(), 53 | # nn.Flatten()) 54 | #specified for POMDP scene 55 | #print(f"==========={torch.__version__}==============") 56 | self.state_embed = nn.Sequential( 57 | nn.Conv2d(input_dims[0], 16, kernel_size=4, stride=2, padding=0), nn.ReLU(inplace=True), 58 | nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=0), nn.ReLU(inplace=True), 59 | # nn.Conv2d(32, 32, 3, stride=2, padding=0), nn.ReLU(), 60 | nn.Flatten()) 61 | with torch.no_grad(): 62 | tmp_data = torch.zeros(1, input_dims[0], input_dims[1], input_dims[2]) 63 | # if use_cuda(): 64 | # tmp_data = tmp_data.cuda() 65 | self.outter_dim = np.prod(self.state_embed(tmp_data).shape[1:]) 66 | self.net = nn.Sequential( 67 | self.state_embed, 68 | nn.Linear(self.outter_dim, output_dim), nn.ReLU(inplace=True) 69 | ) 70 | 71 | self.output_dim = output_dim 72 | 73 | def forward(self,x): 74 | #x = torch.transpose(x,1,3).float() 75 | # print(x) 76 | # exit(1) 77 | return self.net(x.float()) 78 | 79 | 80 | @NETWORKS.register_module() 81 | class MiniLargeCNN(nn.Module): 82 | def __init__(self, 83 | input_dims: List[int], #c,h,w 84 | hidden_dim: int=64, 85 | output_dim: int=64, 86 | num_layers: int=2): 87 | super().__init__() 88 | self.state_embed = nn.Sequential( 89 | nn.Conv2d(input_dims[0], 16, kernel_size=4, stride=2, padding=0), nn.ReLU(inplace=True), 90 | nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=0), nn.ReLU(inplace=True), 91 | nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=0), nn.ReLU(inplace=True), 92 | # nn.Conv2d(32, 32, 3, stride=2, padding=0), nn.ReLU(), 93 | nn.Flatten()) 94 | with torch.no_grad(): 95 | tmp_data = torch.zeros(1, input_dims[0], input_dims[1], input_dims[2]) 96 | # if use_cuda(): 97 | # tmp_data = tmp_data.cuda() 98 | self.outter_dim = np.prod(self.state_embed(tmp_data).shape[1:]) 99 | self.net = nn.Sequential( 100 | self.state_embed, 101 | nn.Linear(self.outter_dim, output_dim), nn.ReLU(inplace=True) 102 | ) 103 | 104 | self.output_dim = output_dim 105 | 106 | def forward(self,x): 107 | #x = torch.transpose(x,1,3).float() 108 | # 
print(x) 109 | # exit(1) 110 | return self.net(x.float()) 111 | 112 | 113 | @NETWORKS.register_module() 114 | class SokoCNN(nn.Module): 115 | def __init__(self, 116 | input_dims: List[int], #c,h,w 117 | hidden_dim: int=64, 118 | output_dim: int=64, 119 | num_layers: int=2): 120 | super().__init__() 121 | self.state_embed = nn.Sequential( 122 | nn.Conv2d(input_dims[0], 16, 8, stride=4, padding=0), nn.ReLU(), 123 | nn.Conv2d(16, 32, 4, stride=2, padding=0), nn.ReLU(), 124 | nn.Conv2d(32, 32, 3, stride=1, padding=0), nn.ReLU(), 125 | nn.Flatten()) 126 | with torch.no_grad(): 127 | self.outter_dim = np.prod( 128 | self.state_embed(torch.zeros(1, input_dims[0], input_dims[1], input_dims[2])).shape[1:]) 129 | self.net = nn.Sequential( 130 | self.state_embed, 131 | nn.Linear(self.outter_dim, output_dim), nn.ReLU(inplace=True) 132 | ) 133 | 134 | self.output_dim = output_dim 135 | 136 | def forward(self,x): 137 | #x = torch.transpose(x,1,3).float() 138 | # print(x) 139 | # exit(1) 140 | return self.net(x.float()) -------------------------------------------------------------------------------- /agentlightning/agentlightning/instrumentation/agentops.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing 3 | import signal 4 | import socket 5 | import time 6 | 7 | import flask 8 | import setproctitle 9 | import opentelemetry.instrumentation.openai.shared.chat_wrappers 10 | from opentelemetry.instrumentation.openai.shared.chat_wrappers import ( 11 | _handle_response, 12 | dont_throw, 13 | ) 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | _original_handle_response = _handle_response 18 | 19 | 20 | @dont_throw 21 | def _handle_response_with_tokens(response, span, *args, **kwargs): 22 | _original_handle_response(response, span, *args, **kwargs) 23 | if hasattr(response, "prompt_token_ids"): 24 | span.set_attribute("prompt_token_ids", list(response.prompt_token_ids)) 25 | if hasattr(response, "response_token_ids"): 26 | span.set_attribute("response_token_ids", list(response.response_token_ids[0])) 27 | 28 | # For LiteLLM, response is a openai._legacy_response.LegacyAPIResponse 29 | if hasattr(response, "http_response") and hasattr(response.http_response, "json"): 30 | json_data = response.http_response.json() 31 | if isinstance(json_data, dict): 32 | if "prompt_token_ids" in json_data: 33 | span.set_attribute("prompt_token_ids", list(json_data["prompt_token_ids"])) 34 | if "response_token_ids" in json_data: 35 | span.set_attribute("response_token_ids", list(json_data["response_token_ids"][0])) 36 | 37 | 38 | def instrument_agentops(): 39 | opentelemetry.instrumentation.openai.shared.chat_wrappers._handle_response = _handle_response_with_tokens 40 | 41 | 42 | def agentops_local_server(): 43 | """ 44 | Returns a Flask app that can be used to test agentops integration. 45 | This server provides endpoints for token fetching and a catch-all endpoint. 46 | """ 47 | app = flask.Flask(__name__) 48 | 49 | @app.route("/v3/auth/token", methods=["POST"]) 50 | def fetch_token(): 51 | return {"token": "dummy", "project_id": "dummy"} 52 | 53 | @app.route("/", defaults={"path": ""}, methods=["GET", "POST"]) 54 | @app.route("/<path:path>", methods=["GET", "POST"]) 55 | def catch_all(path): 56 | return {"path": path} 57 | 58 | return app 59 | 60 | 61 | def _run_server(**kwargs): 62 | """ 63 | Internal function to run the Flask server. 64 | This is used to avoid issues with multiprocessing and Flask's reloader.
65 | """ 66 | signal.signal(signal.SIGINT, signal.SIG_IGN) # Ignore SIGINT in worker processes 67 | setproctitle.setproctitle(multiprocessing.current_process().name) 68 | app = agentops_local_server() 69 | app.run(**kwargs) 70 | 71 | 72 | class AgentOpsServerManager: 73 | def __init__(self, daemon: bool = True, port: int | None = None): 74 | self.server_process: multiprocessing.Process | None = None 75 | self.server_port = port 76 | self.daemon = daemon 77 | logger.info("AgentOpsServerManager initialized.") 78 | 79 | def _find_available_port(self) -> int: 80 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 81 | s.bind(("", 0)) 82 | return s.getsockname()[1] 83 | 84 | def start(self): 85 | if self.server_process and self.server_process.is_alive(): 86 | logger.warning("AgentOps server process appears to be already running.") 87 | return 88 | 89 | if self.server_port is None: 90 | self.server_port = self._find_available_port() 91 | 92 | logger.info(f"Starting AgentOps local server on port {self.server_port}...") 93 | 94 | self.server_process = multiprocessing.Process( 95 | target=_run_server, 96 | kwargs={"host": "127.0.0.1", "port": self.server_port, "use_reloader": False, "debug": False}, 97 | daemon=self.daemon, 98 | name="AgentLightning-AgentOpsServer", 99 | ) 100 | self.server_process.start() 101 | logger.info( 102 | f"AgentOps local server process (PID: {self.server_process.pid}) started, targeting port {self.server_port}." 103 | ) 104 | time.sleep(0.5) # Brief wait for server to start up 105 | if not self.server_process.is_alive(): 106 | logger.error(f"AgentOps local server failed to start or exited prematurely.") 107 | 108 | def is_alive(self) -> bool: 109 | if self.server_process and self.server_process.is_alive(): 110 | return True 111 | return False 112 | 113 | def stop(self): 114 | if self.is_alive(): 115 | logger.info(f"Stopping AgentOps local server (PID: {self.server_process.pid})...") 116 | self.server_process.terminate() # Send SIGTERM 117 | self.server_process.join(timeout=5) # Wait for clean exit 118 | if self.server_process.is_alive(): 119 | logger.warning( 120 | f"AgentOps server (PID: {self.server_process.pid}) did not terminate gracefully, killing..." 121 | ) 122 | self.server_process.kill() # Force kill 123 | self.server_process.join(timeout=10) # Wait for kill 124 | self.server_process = None 125 | logger.info(f"AgentOps local server stopped.") 126 | else: 127 | logger.info("AgentOps local server was not running or already stopped.") 128 | 129 | def get_port(self) -> int | None: 130 | # Check liveness again in case it died since start() 131 | if self.is_alive() and self.server_port is not None: 132 | return self.server_port 133 | # If called after server stopped or failed, port might be stale or None 134 | if self.server_port is not None and (self.server_process is None or not self.server_process.is_alive()): 135 | logger.warning( 136 | f"AgentOps server port {self.server_port} is stored, but server process is not alive. Returning stored port." 137 | ) 138 | return self.server_port 139 | -------------------------------------------------------------------------------- /eppo/code/env/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
3 | import inspect 4 | import json 5 | from collections import defaultdict 6 | from typing import TextIO 7 | 8 | import numpy as np 9 | import pandas as pd 10 | from torch.utils.tensorboard.writer import SummaryWriter 11 | from utilsd import get_tb_log_dir, get_output_dir 12 | from utilsd.avgmeter import MetricMeter 13 | from utilsd.logging import print_log 14 | 15 | from .finite_env import BaseLogger 16 | 17 | 18 | __all__ = ["Logger"] 19 | 20 | _tb_logger = _json_writer = None 21 | 22 | 23 | def _get_tb_logger() -> SummaryWriter: 24 | global _tb_logger 25 | if _tb_logger is None: 26 | _tb_logger = SummaryWriter(log_dir=get_tb_log_dir()) 27 | return _tb_logger 28 | 29 | 30 | def _get_json_writer() -> TextIO: 31 | global _json_writer 32 | if _json_writer is None: 33 | _json_writer = (get_output_dir() / "summary.json").open("a") 34 | return _json_writer 35 | 36 | 37 | def _groupby_category(category, value, key) -> defaultdict: 38 | """ 39 | Group the values by category. 40 | """ 41 | if not isinstance(value, (list, tuple, np.ndarray)): 42 | value = [value] * len(category) 43 | assert len(category) == len(value) 44 | grouped = defaultdict(list) 45 | for c, v in zip(category, value): 46 | grouped[f"{key}/{c}"].append(v) 47 | return grouped 48 | 49 | 50 | class Logger(BaseLogger): 51 | def __init__( 52 | self, ep_total, *, log_interval=100, prefix="Episode", tb_prefix="", count_global="episode", reward_func=np.mean 53 | ): 54 | self.meter = MetricMeter() 55 | self.ep_count = 0 56 | self.global_step = 0 57 | self.ep_total = ep_total 58 | self.log_interval = log_interval 59 | self.prefix = prefix 60 | self.logs = [] 61 | self.history = [] 62 | self.active_env_ids = set() 63 | assert count_global in ["step", "episode"] 64 | self.count_global = count_global 65 | 66 | self.tb_writer = _get_tb_logger() 67 | self.tb_prefix = tb_prefix 68 | 69 | self.json_writer = _get_json_writer() 70 | 71 | self.episode_lengths = dict() 72 | self.episode_rewards = dict() 73 | self.episode_rewards_info = dict() 74 | self.reward_func = reward_func 75 | 76 | def log_step(self, env_id, obs, rew, done, info): 77 | self.active_env_ids.add(env_id) 78 | self.episode_lengths[env_id] += 1 79 | self.episode_rewards[env_id] += self.reward_func(rew) 80 | 81 | for k, v in info.get("reward", {}).items(): 82 | self.episode_rewards_info[env_id][k] += v 83 | 84 | if self.count_global == "step": 85 | self.global_step += 1 86 | 87 | if not done: 88 | return 89 | 90 | if self.count_global == "episode": 91 | self.global_step += 1 92 | 93 | self.ep_count += 1 94 | index = dict(info["index"]) 95 | logs = dict(info["logs"]) # deal with batch 96 | logs.update( 97 | { 98 | "step_per_episode": self.episode_lengths[env_id], 99 | "reward": self.episode_rewards[env_id], 100 | "num_active_envs": len(self.active_env_ids), 101 | } 102 | ) 103 | logs.update({f"reward/{k}": v for k, v in self.episode_rewards_info[env_id].items()}) 104 | 105 | # TODO: meter.update support array input 106 | 107 | category = info.get("category", 'default') 108 | if not isinstance(category, (list, tuple, np.ndarray)): 109 | category = [category] 110 | cate_logs = {} 111 | for k, v in logs.items(): 112 | cate_logs.update(_groupby_category(category, v, k)) 113 | self.meter.update({k: np.nanmean(v) for k, v in cate_logs.items()}) 114 | 115 | self.meter.update({k: np.nanmean(v) for k, v in logs.items()}) 116 | self.logs += pd.DataFrame({**index, **logs, "category": category}).to_dict(orient="records") 117 | self.history.append({**index, **info["history"]}) 118 | if 
self.ep_count % self.log_interval == 0 or self.ep_count >= self.ep_total: 119 | frm = inspect.stack()[1] 120 | mod = inspect.getmodule(frm[0]) 121 | print_log(f"{self.prefix} [{self.ep_count}/{self.ep_total}] {self.meter}", mod.__name__) 122 | 123 | def log_reset(self, env_id, obs): 124 | self.episode_lengths[env_id] = 0 125 | self.episode_rewards[env_id] = 0.0 126 | self.episode_rewards_info[env_id] = defaultdict(float) 127 | 128 | def set_prefix(self, prefix): 129 | self.prefix = prefix 130 | 131 | def write_summary(self, extra_metrics=None): 132 | if extra_metrics: 133 | self.meter.update(extra_metrics) 134 | summary = self.summary() 135 | print_log(f"{self.prefix} Summary:\n" + "\n".join([f" {k}\t{v:.4f}" for k, v in summary.items()]), __name__) 136 | for key, value in summary.items(): 137 | if self.tb_prefix: 138 | key = self.tb_prefix + "/" + key 139 | self.tb_writer.add_scalar(key, value, global_step=self.global_step) 140 | summary = {"prefix": self.tb_prefix, "step": self.global_step, **summary} 141 | self.json_writer.write(json.dumps(summary) + "\n") 142 | self.json_writer.flush() 143 | 144 | def summary(self): 145 | return {key: self.meter[key].avg for key in self.meter} 146 | 147 | def reset(self, prefix=None): 148 | self.ep_count = self.step_count = 0 149 | self.meter.reset() 150 | self.active_env_ids = set() 151 | self.logs = [] 152 | if prefix is not None: 153 | self.set_prefix(prefix) 154 | 155 | def state_dict(self): 156 | # logging status within epoch is not saved 157 | return { 158 | "global_step": self.global_step, 159 | } 160 | 161 | def load_state_dict(self, state_dict): 162 | self.global_step = state_dict["global_step"] 163 | -------------------------------------------------------------------------------- /a2ls/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
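# Training/evaluation logger: scalar metrics are mirrored to TensorBoard and to
# train.log / eval.log (one JSON record per dump), and averaged meters are
# pretty-printed to the console using the FORMAT_CONFIG templates below.
# Usage sketch (hypothetical log directory, for illustration only):
#     logger = Logger('./logs', use_tb=False)
#     logger.log('train/episode_reward', 123.4, step=1000)
#     logger.dump(step=1000)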
3 | from torch.utils.tensorboard import SummaryWriter 4 | from collections import defaultdict 5 | import json 6 | import os 7 | import shutil 8 | import torch 9 | import torchvision 10 | import numpy as np 11 | from termcolor import colored 12 | 13 | FORMAT_CONFIG = { 14 | 'rl': { 15 | 'train': [ 16 | ('episode', 'E', 'int'), ('step', 'S', 'int'), 17 | ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'), 18 | ('batch_reward', 'BR', 'float'), ('actor_loss', 'A_LOSS', 'float'), 19 | ('critic_loss', 'CR_LOSS', 'float'), 20 | ('curl_loss', 'CU_LOSS', 'float'), 21 | ('auxi_loss', 'AU_LOSS', 'float') 22 | ], 23 | 'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float')] 24 | } 25 | } 26 | 27 | 28 | class AverageMeter(object): 29 | def __init__(self): 30 | self._sum = 0 31 | self._count = 0 32 | 33 | def update(self, value, n=1): 34 | self._sum += value 35 | self._count += n 36 | 37 | def value(self): 38 | return self._sum / max(1, self._count) 39 | 40 | 41 | class MetersGroup(object): 42 | def __init__(self, file_name, formating): 43 | self._file_name = file_name 44 | if os.path.exists(file_name): 45 | os.remove(file_name) 46 | self._formating = formating 47 | self._meters = defaultdict(AverageMeter) 48 | 49 | def log(self, key, value, n=1): 50 | self._meters[key].update(value, n) 51 | 52 | def _prime_meters(self): 53 | data = dict() 54 | for key, meter in self._meters.items(): 55 | if key.startswith('train'): 56 | key = key[len('train') + 1:] 57 | else: 58 | key = key[len('eval') + 1:] 59 | key = key.replace('/', '_') 60 | data[key] = meter.value() 61 | return data 62 | 63 | def _dump_to_file(self, data): 64 | with open(self._file_name, 'a') as f: 65 | f.write(json.dumps(data) + '\n') 66 | 67 | def _format(self, key, value, ty): 68 | template = '%s: ' 69 | if ty == 'int': 70 | template += '%d' 71 | elif ty == 'float': 72 | template += '%.04f' 73 | elif ty == 'time': 74 | template += '%.01f s' 75 | else: 76 | raise 'invalid format type: %s' % ty 77 | return template % (key, value) 78 | 79 | def _dump_to_console(self, data, prefix): 80 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 81 | pieces = ['{:5}'.format(prefix)] 82 | for key, disp_key, ty in self._formating: 83 | value = data.get(key, 0) 84 | pieces.append(self._format(disp_key, value, ty)) 85 | print('| %s' % (' | '.join(pieces))) 86 | 87 | def dump(self, step, prefix): 88 | if len(self._meters) == 0: 89 | return 90 | data = self._prime_meters() 91 | data['step'] = step 92 | self._dump_to_file(data) 93 | self._dump_to_console(data, prefix) 94 | self._meters.clear() 95 | 96 | 97 | class Logger(object): 98 | def __init__(self, log_dir, use_tb=True, config='rl'): 99 | self._log_dir = log_dir 100 | if use_tb: 101 | tb_dir = os.path.join(log_dir, 'tb') 102 | if os.path.exists(tb_dir): 103 | shutil.rmtree(tb_dir) 104 | self._sw = SummaryWriter(tb_dir) 105 | else: 106 | self._sw = None 107 | self._train_mg = MetersGroup( 108 | os.path.join(log_dir, 'train.log'), 109 | formating=FORMAT_CONFIG[config]['train'] 110 | ) 111 | self._eval_mg = MetersGroup( 112 | os.path.join(log_dir, 'eval.log'), 113 | formating=FORMAT_CONFIG[config]['eval'] 114 | ) 115 | 116 | def _try_sw_log(self, key, value, step): 117 | if self._sw is not None: 118 | self._sw.add_scalar(key, value, step) 119 | 120 | def _try_sw_log_image(self, key, image, step): 121 | if self._sw is not None: 122 | assert image.dim() == 3 123 | grid = torchvision.utils.make_grid(image.unsqueeze(1)) 124 | self._sw.add_image(key, grid, step) 125 | 126 | def 
_try_sw_log_video(self, key, frames, step): 127 | if self._sw is not None: 128 | frames = torch.from_numpy(np.array(frames)) 129 | frames = frames.unsqueeze(0) 130 | self._sw.add_video(key, frames, step, fps=30) 131 | 132 | def _try_sw_log_histogram(self, key, histogram, step): 133 | if self._sw is not None: 134 | self._sw.add_histogram(key, histogram, step) 135 | 136 | def log(self, key, value, step, n=1): 137 | assert key.startswith('train') or key.startswith('eval') 138 | if type(value) == torch.Tensor: 139 | value = value.item() 140 | self._try_sw_log(key, value / n, step) 141 | mg = self._train_mg if key.startswith('train') else self._eval_mg 142 | mg.log(key, value, n) 143 | 144 | def log_param(self, key, param, step): 145 | self.log_histogram(key + '_w', param.weight.data, step) 146 | if hasattr(param.weight, 'grad') and param.weight.grad is not None: 147 | self.log_histogram(key + '_w_g', param.weight.grad.data, step) 148 | if hasattr(param, 'bias'): 149 | self.log_histogram(key + '_b', param.bias.data, step) 150 | if hasattr(param.bias, 'grad') and param.bias.grad is not None: 151 | self.log_histogram(key + '_b_g', param.bias.grad.data, step) 152 | 153 | def log_image(self, key, image, step): 154 | assert key.startswith('train') or key.startswith('eval') 155 | self._try_sw_log_image(key, image, step) 156 | 157 | def log_video(self, key, frames, step): 158 | assert key.startswith('train') or key.startswith('eval') 159 | self._try_sw_log_video(key, frames, step) 160 | 161 | def log_histogram(self, key, histogram, step): 162 | assert key.startswith('train') or key.startswith('eval') 163 | self._try_sw_log_histogram(key, histogram, step) 164 | 165 | def dump(self, step): 166 | self._train_mg.dump(step, 'train') 167 | self._eval_mg.dump(step, 'eval') 168 | -------------------------------------------------------------------------------- /bootorl/main/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
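# Training entry point: fits a GPT model on the discretized offline dataset, optionally
# bootstrapping with model-generated transitions once a warm-up fraction of the epochs
# has passed, with periodic checkpointing (state_{epoch}.pt) and crash-resume support
# via resume_status.json / resume_model_state.pt.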
3 | 4 | import os 5 | import sys 6 | import json 7 | import pandas as pd 8 | import torch 9 | from argparse import Namespace 10 | 11 | sys.path.append(os.getcwd()) 12 | 13 | import utils 14 | from utils import DiscretizedDataset 15 | from model import GPT 16 | 17 | 18 | def to(xs, device): 19 | return [x.to(device) for x in xs] 20 | 21 | 22 | # Setup 23 | parser = utils.ArgParser() 24 | args = parser.parse_args() 25 | logger = parser.get_logger() 26 | 27 | 28 | # Resume check 29 | resume_file = os.path.join(args.output_dir, "resume_status.json") 30 | resume_model_state = os.path.join(args.output_dir, "resume_model_state.pt") 31 | if args.resume == "y" and os.path.exists(resume_file): 32 | with open(resume_file, "r") as f: 33 | resume_status = json.load(f) 34 | resume_flag = True 35 | logger.info(f"Resume mode enabled, reading status from file {resume_file}") 36 | if resume_status["train_finished"]: 37 | logger.info(f"Training has finished, exiting task.") 38 | exit(0) 39 | else: 40 | resume_flag = False 41 | logger.info(f"Resume mode disabled, starting a new task.") 42 | 43 | 44 | # Dataset 45 | dataset = DiscretizedDataset( 46 | logger=logger, 47 | env=args.dataset, 48 | n_bins=args.n_bins, 49 | sequence_length=args.sequence_length, 50 | penalty=args.termination_penalty, 51 | discount=args.discount, 52 | ) 53 | 54 | # Model 55 | obs_dim = dataset.observation_dim 56 | act_dim = dataset.action_dim 57 | trans_dim = dataset.joined_dim 58 | block_size = args.sequence_length * trans_dim - 1 59 | 60 | model_config = Namespace( 61 | vocab_size=args.n_bins, 62 | block_size=block_size, 63 | n_layer=args.n_layer, 64 | n_head=args.n_head, 65 | n_embd=args.n_embd * args.n_head, 66 | observation_dim=obs_dim, 67 | action_dim=act_dim, 68 | transition_dim=trans_dim, 69 | action_weight=args.action_weight, 70 | reward_weight=args.reward_weight, 71 | value_weight=args.value_weight, 72 | embd_pdrop=args.embd_pdrop, 73 | resid_pdrop=args.resid_pdrop, 74 | attn_pdrop=args.attn_pdrop 75 | ) 76 | 77 | device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" 78 | if args.parallel: 79 | logger.info(f"Using device: {device}; Enabling data parallel in torch") 80 | model = torch.nn.DataParallel(GPT(model_config)).to(device) 81 | else: 82 | logger.info(f"Using device: {device}; Disabling data parallel in torch") 83 | model = GPT(model_config).to(device) 84 | param_num = sum(p.numel() for p in model.parameters()) 85 | logger.info(f"Total number of parameters: {param_num}.") 86 | 87 | 88 | # Trainer 89 | trainer_config = Namespace( 90 | logger=logger, 91 | batch_size=args.batch_size, 92 | learning_rate=args.learning_rate, 93 | betas=(0.9, 0.95), 94 | grad_norm_clip=1.0, 95 | weight_decay=0.1, 96 | lr_decay=args.lr_decay, 97 | warmup_tokens=len(dataset)*block_size, 98 | final_tokens=20*len(dataset)*block_size, 99 | num_workers=0, 100 | ) 101 | trainer = utils.Trainer(trainer_config) 102 | 103 | 104 | # scale number of epochs to keep number of updates constant 105 | n_epochs = int(1e6 / len(dataset) * args.n_epochs_ref) 106 | # calculate epoch index of when to save model 107 | save_interval = max(1, int(n_epochs // args.n_saves)) 108 | save_epochs = [e for e in range(n_epochs) if e % save_interval == 0 or e == n_epochs - 1] 109 | 110 | bootstrap_kwargs = { 111 | "bootstrap": args.bootstrap, 112 | "bootstrap_type": args.bootstrap_type, 113 | "generation_type": args.generation_type, 114 | "generation_epoch_thresh": int(args.generation_epoch_thresh * n_epochs), 115 | "generation_len": args.generation_len, 
116 | "generation_num": args.generation_num, 117 | "generation_confidence_type": args.generation_confidence_type, 118 | "generation_confidence_factor": args.generation_confidence_factor, 119 | "generation_real_r": args.generation_real_r, 120 | "generation_real_R": args.generation_real_R, 121 | } 122 | 123 | logger.info(f"Experiment: {args.exp_name} | Total epochs: {n_epochs} | Total saves: {len(save_epochs)}") 124 | logger.info(f"Saving model at epochs: {save_epochs}") 125 | if args.bootstrap: 126 | logger.info(f"Bootstrapping is enabled.") 127 | logger.info(f"Performing bootstrapping after epoch {bootstrap_kwargs['generation_epoch_thresh']}.") 128 | 129 | 130 | # Resume processing 131 | if resume_flag: 132 | init_epoch = resume_status["current_epoch"] + 1 133 | info = resume_status["info"] 134 | trainer.n_epochs = init_epoch 135 | trainer.n_tokens = resume_status["current_n_tokens"] 136 | timer = utils.Timer(total_num=(n_epochs - init_epoch)) 137 | model.load_state_dict(torch.load(resume_model_state), strict=True) 138 | else: 139 | init_epoch = 0 140 | info = {} 141 | timer = utils.Timer(total_num=n_epochs) 142 | 143 | 144 | for epoch in range(init_epoch, n_epochs): 145 | logger.info(f"Epoch: {epoch:>3d} / {n_epochs:>3d} | {args.exp_name}") 146 | info_ = trainer.train(model, dataset, with_tqdm=args.with_tqdm, log_freq=100, bootstrap_kwargs=bootstrap_kwargs) 147 | 148 | if epoch % save_interval == 0 or epoch == n_epochs - 1: 149 | model_path = os.path.join(args.output_dir, f'state_{epoch}.pt') 150 | logger.info(f"Epoch: {epoch:>3d} / {n_epochs:>3d} | Saving model to {model_path}") 151 | state = model.module.state_dict() if hasattr(model, "module") else model.state_dict() 152 | torch.save(state, model_path) 153 | info[epoch] = info_ 154 | 155 | if args.resume == "y": 156 | resume_status = { 157 | "current_epoch": epoch, 158 | "current_n_tokens": trainer.n_tokens.item(), 159 | "train_finished": (epoch == n_epochs - 1), 160 | "info": info, 161 | } 162 | with open(resume_file, "w") as f: 163 | json.dump(resume_status, f) 164 | logger.info(f"Dumping resume status to file {resume_file} at epoch {epoch}.") 165 | state = model.module.state_dict() if hasattr(model, "module") else model.state_dict() 166 | torch.save(state, resume_model_state) 167 | logger.info(f"Dumping resume model state to file {resume_model_state} at epoch {epoch}.") 168 | 169 | diff_time, total_time, eta = timer() 170 | logger.info(f"Epoch: {epoch:>3d} / {n_epochs:>3d} | Time: {diff_time} | Total time: {total_time} | ETA: {eta}") 171 | 172 | df = pd.DataFrame.from_dict(info, orient="index").sort_index() 173 | df.to_csv(os.path.join(args.output_dir, "train_epoch_info.csv")) 174 | -------------------------------------------------------------------------------- /agentlightning/agentlightning/instrumentation/verl_chat_scheduler.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | 3 | # https://github.com/volcengine/verl/blob/bd94bd61fe4193e56f2845dc794004afbef7f818/examples/ppo_trainer/naive_chat_scheduler.py 4 | # This file is part of VERL example. It should be included in the VERL package but it's not currently. 5 | 6 | # Copyright 2024 Bytedance Ltd. and/or its affiliates 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 
10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | import asyncio 20 | from typing import Any, Dict, List 21 | 22 | import torch 23 | from openai.types.chat.chat_completion import ChatCompletion 24 | from tensordict import TensorDict 25 | 26 | from verl.protocol import DataProto 27 | from verl.workers.rollout.async_server import ChatCompletionScheduler 28 | 29 | 30 | class NaiveChatCompletionScheduler(ChatCompletionScheduler): 31 | """ 32 | A very naive implementation of ChatCompletionScheduler for demo purpose, 33 | only do single-turn chat completion. 34 | """ 35 | 36 | async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto: 37 | kwargs = dict( 38 | n=self.config.n, 39 | max_completion_tokens=self.config.response_length, 40 | temperature=self.config.temperature, 41 | top_p=self.config.top_p, 42 | ) 43 | 44 | do_sample = batch.meta_info.get("do_sample", True) 45 | is_validate = batch.meta_info.get("validate", False) 46 | if not do_sample or is_validate: 47 | kwargs["n"] = 1 48 | kwargs["temperature"] = 0 49 | 50 | kwargs.update(sampling_params) 51 | print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}") 52 | 53 | async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): 54 | assert exception is None, f"exception: {exception}" 55 | conversation, batch_conversations, batch_index = ( 56 | info["conversation"], 57 | info["batch_conversations"], 58 | info["batch_index"], 59 | ) 60 | 61 | conversations = [] 62 | for choice in completions.choices: 63 | chat = conversation.copy() 64 | chat.append({"role": choice.message.role, "content": choice.message.content}) 65 | conversations.append(chat) 66 | batch_conversations[batch_index] = conversations 67 | 68 | # NOTE: we can call tools and resubmit chat completions here. 69 | # call_tools(completions, info) 70 | # await self.submit_chat_completions(callback2, ...) 71 | 72 | # TODO: we may need to control max concurrent requests here, or it will harm prefix cache hit rate. 73 | tasks, batch_conversations = [], [None] * len(batch) 74 | for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]): 75 | # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] 
76 | tasks.append( 77 | asyncio.create_task( 78 | self.submit_chat_completions( 79 | callback=callback, 80 | callback_additional_info={ 81 | "batch_conversations": batch_conversations, 82 | "batch_index": batch_index, 83 | "conversation": list(conversation), 84 | }, 85 | model=self.model_name, 86 | messages=conversation.tolist(), 87 | **kwargs, 88 | ) 89 | ) 90 | ) 91 | await asyncio.gather(*tasks) 92 | print("[NaiveChatCompletionScheduler] generate_sequences done") 93 | 94 | return self._postprocess(batch, batch_conversations, kwargs["n"]) 95 | 96 | def _postprocess(self, batch: DataProto, batch_conversations: List[List[List[Dict[str, str]]]], n: int) -> DataProto: 97 | # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py 98 | # prompts: left pad 99 | # responses: right pad 100 | # input_ids: prompt + response 101 | # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] 102 | # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] 103 | 104 | # prompts: [prompt] from input dataset 105 | prompts = [self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False) for prompt in batch.non_tensor_batch["raw_prompt"]] 106 | 107 | # flatten batch_conversations if n > 1 108 | assert len(batch_conversations) == len(prompts) 109 | batch_conversations = [conversation for conversations in batch_conversations for conversation in conversations] 110 | assert len(batch_conversations) == len(prompts) * n 111 | 112 | # sequences: [prompt + response] 113 | sequences = [self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False) for conversation in batch_conversations] 114 | 115 | # responses: [response] 116 | # TODO: mask out tools calling tokens? 117 | responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)] 118 | 119 | prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left") 120 | responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right") 121 | if n > 1: 122 | prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0) 123 | prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0) 124 | 125 | input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1) 126 | attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1) 127 | position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask 128 | 129 | batch = TensorDict( 130 | { 131 | "prompts": prompts["input_ids"], 132 | "responses": responses["input_ids"], 133 | "input_ids": input_ids, 134 | "attention_mask": attention_mask, 135 | "position_ids": position_ids, 136 | }, 137 | batch_size=len(input_ids), 138 | ) 139 | 140 | return DataProto(batch=batch) 141 | -------------------------------------------------------------------------------- /a2ls/README.md: -------------------------------------------------------------------------------- 1 | # A2LS: Reinforcement Learning with Automated Auxiliary Loss Search 2 | 3 | Code for NeurIPS 2022 paper [Reinforcement Learning with Automated Auxiliary Loss Search](https://seqml.github.io/a2ls/). 4 | 5 | This repository is the implementation of A2LS based on the official implementation of [CURL](https://mishalaskin.github.io/curl/) for the DeepMind control experiments. 6 | 7 | ## Installation 8 | 9 | All of the dependencies are in the `conda_env.yml` file. 
They can be installed manually or with the following command: 10 | 11 | ``` 12 | conda env create -f conda_env.yml 13 | ``` 14 | 15 | ## Instructions 16 | We present some running examples of training RL with auxiliary losses with our code base. 17 | 18 | ### A2-winner 19 | ![\mathcal{L}_{\text{A2-winner}} = \| h(g_{\theta}(s_{t+1}, a_{t+1}, a_{t+2}, a_{t+3})) - g_{\hat{\theta}}(r_t, r_{t+1}, s_{t+2}, s_{t+3}) \|_2](https://latex.codecogs.com/svg.image?%5Cmathcal%7BL%7D_%7B%5Ctext%7BA2-winner%7D%7D%20=%20%5C%7C%20h(g_%7B%5Ctheta%7D(s_%7Bt+1%7D,%20a_%7Bt+1%7D,%20a_%7Bt+2%7D,%20a_%7Bt+3%7D))%20-%20g_%7B%5Chat%7B%5Ctheta%7D%7D(r_t,%20r_%7Bt+1%7D,%20s_%7Bt+2%7D,%20s_%7Bt+3%7D)%20%5C%7C_2) 20 | 21 | 22 | To train a SAC agent with `A2-winner` on image-based Cheetah-Run with default hyper-parameters (please refer to appendix for detailed hyper-parameters for each experiment setting): 23 | ``` 24 | python train.py \ 25 | --domain_name cheetah_run \ 26 | --encoder_type pixel \ 27 | --agent auxi_sac \ 28 | --auxi_pred_horizon 4 \ 29 | --auxi_pred_input_s 1000 --auxi_pred_input_a 1111 --auxi_pred_input_r 1101 --auxi_pred_input_s_ 0\ 30 | --auxi_pred_output_s 0111 --auxi_pred_output_a 0000 --auxi_pred_output_r 0000 --auxi_pred_output_s_ 1\ 31 | --similarity_metric mse 32 | ``` 33 | 34 | To train a SAC agent with `A2-winner` on vector-based Cheetah-Run with default hyper-parameters (please refer to appendix for detailed hyper-parameters for each experiment setting): 35 | ``` 36 | python train.py \ 37 | --domain_name cheetah_run \ 38 | --encoder_type ofe --encoder_hidden_size 40 --num_layers 1 \ 39 | --agent auxi_sac \ 40 | --auxi_pred_horizon 4 \ 41 | --auxi_pred_input_s 1000 --auxi_pred_input_a 1111 --auxi_pred_input_r 1101 --auxi_pred_input_s_ 0\ 42 | --auxi_pred_output_s 0111 --auxi_pred_output_a 0000 --auxi_pred_output_r 0000 --auxi_pred_output_s_ 1\ 43 | --similarity_metric mse 44 | ``` 45 | 46 | ### A2-winner-v 47 | 48 | 49 | ![\mathcal{L}_{\text{A2-winner-v}} = \| h(g_{\theta}(s_{t}, a_{t}, a_{t+1}, s_{t+2} a_{t+2}, a_{t+3}, r_{t+3}, a_{t+4}, r_{t+4}, a_{t+5}, a_{t+7}, s_{t+8}, a_{t+8}, r_{t+8})) - g_{\hat{\theta}}(s_{t+1}, s_{t+3}, a_{t+4}, s_{t+6}, s_{t+9}) \|_2](https://latex.codecogs.com/svg.image?%5Cmathcal%7BL%7D_%7B%5Ctext%7BA2-winner-v%7D%7D%20=%20%5C%7C%20h(g_%7B%5Ctheta%7D(s_%7Bt%7D,%20a_%7Bt%7D,%20a_%7Bt+1%7D,%20s_%7Bt+2%7D%20a_%7Bt+2%7D,%20a_%7Bt+3%7D,%20r_%7Bt+3%7D,%20a_%7Bt+4%7D,%20r_%7Bt+4%7D,%20a_%7Bt+5%7D,%20a_%7Bt+7%7D,%20s_%7Bt+8%7D,%20a_%7Bt+8%7D,%20r_%7Bt+8%7D))%20-%20g_%7B%5Chat%7B%5Ctheta%7D%7D(s_%7Bt+1%7D,%20s_%7Bt+3%7D,%20a_%7Bt+4%7D,%20s_%7Bt+6%7D,%20s_%7Bt+9%7D)%20%5C%7C_2) 50 | 51 | 52 | To train a SAC agent with `A2-winner` on image-based Cheetah-Run with default hyper-parameters (please refer to appendix for detailed hyper-parameters for each experiment setting): 53 | ``` 54 | python train.py \ 55 | --domain_name cheetah_run \ 56 | --encoder_type pixel \ 57 | --agent auxi_sac \ 58 | --auxi_pred_horizon 9 \ 59 | --auxi_pred_input_s 101000001 --auxi_pred_input_a 111111011 --auxi_pred_input_r 000110001 --auxi_pred_input_s_ 0\ 60 | --auxi_pred_output_s 010100100 --auxi_pred_output_a 000010000 --auxi_pred_output_r 000000000 --auxi_pred_output_s_ 1\ 61 | --similarity_metric mse 62 | ``` 63 | 64 | To train a SAC agent with `A2-winner` on vector-based Cheetah-Run with default hyper-parameters (please refer to appendix for detailed hyper-parameters for each experiment setting): 65 | ``` 66 | python train.py \ 67 | --domain_name cheetah_run \ 68 | --encoder_type ofe 
--encoder_hidden_size 40 --num_layers 1 \ 69 | --agent auxi_sac \ 70 | --auxi_pred_horizon 9 \ 71 | --auxi_pred_input_s 101000001 --auxi_pred_input_a 111111011 --auxi_pred_input_r 000110001 --auxi_pred_input_s_ 0\ 72 | --auxi_pred_output_s 010100100 --auxi_pred_output_a 000010000 --auxi_pred_output_r 000000000 --auxi_pred_output_s_ 1\ 73 | --similarity_metric mse 74 | ``` 75 | 76 | 77 | ## Baselines running examples 78 | ### SAC 79 | To train a baseline SAC agent on image-based Cheetah-Run with default hyper-parameters: 80 | ``` 81 | python train.py \ 82 | --domain_name cheetah_run \ 83 | --encoder_type pixel \ 84 | --agent pixel_sac 85 | ``` 86 | 87 | To train a baseline SAC agent on vector-based Cheetah-Run with default hyper-parameters and default architectures (MLP): 88 | ``` 89 | python train.py \ 90 | --domain_name cheetah_run \ 91 | --encoder_type mlp --encoder_hidden_size 40 --num_layers 1\ 92 | --agent pixel_sac 93 | ``` 94 | 95 | To train a baseline SAC agent on vector-based Cheetah-Run with default hyper-parameters and dense-connected architectures (MLP): 96 | ``` 97 | python train.py \ 98 | --domain_name cheetah_run \ 99 | --encoder_type ofe --encoder_hidden_size 40 --num_layers 1\ 100 | --agent pixel_sac 101 | ``` 102 | 103 | ### CURL 104 | To train a baseline SAC agent with `CURL` loss on image-based Cheetah-Run with default hyper-parameters: 105 | ``` 106 | python train.py \ 107 | --domain_name cheetah_run \ 108 | --encoder_type pixel \ 109 | --agent curl_sac 110 | ``` 111 | 112 | To train a baseline SAC agent with `CURL` loss on vector-based Cheetah-Run with default hyper-parameters and default architectures (MLP): 113 | ``` 114 | python train.py \ 115 | --domain_name cheetah_run \ 116 | --encoder_type mlp --encoder_hidden_size 40 --num_layers 1\ 117 | --agent curl_sac 118 | ``` 119 | 120 | To train a baseline SAC agent with `CURL` loss on vector-based Cheetah-Run with default hyper-parameters and dense-connected architectures (MLP): 121 | ``` 122 | python train.py \ 123 | --domain_name cheetah_run \ 124 | --encoder_type ofe --encoder_hidden_size 40 --num_layers 1\ 125 | --agent curl_sac 126 | ``` 127 | 128 | ## Issues 129 | 130 | For GPU-accelerated rendering, make sure EGL is installed on your machine and set `export MUJOCO_GL=egl`. 131 | 132 | For environment troubleshooting issues, see the DeepMind control documentation. 133 | 134 | 135 | ## Citation 136 | You are more than welcome to cite our paper: 137 | ``` 138 | @article{he2022reinforcement, 139 | title={Reinforcement Learning with Automated Auxiliary Loss Search}, 140 | author={He, Tairan and Zhang, Yuge and Ren, Kan and Liu, Minghuan and Wang, Che and Zhang, Weinan and Yang, Yuqing and Li, Dongsheng}, 141 | journal={Advances in Neural Information Processing Systems}, 142 | year={2022} 143 | } 144 | ``` -------------------------------------------------------------------------------- /bootorl/utils/planning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License.
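# Beam-search planning utilities: make_prefix() tokenizes the current observation
# (plus recent context transitions), beam_plan() rolls the GPT forward to score
# candidate action sequences by discounted reward plus terminal value, and plan()
# executes the first action of the best sequence in the environment at every step.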
3 | 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from .sample import sample_n 9 | from .timer import Timer 10 | 11 | VALUE_PLACEHOLDER = 0 12 | 13 | 14 | def make_prefix(discretizer, context, obs, device, prefix_context=True): 15 | obs_discrete = discretizer.discretize(obs, subslice=[0, obs.size]) 16 | obs_discrete = torch.tensor(obs_discrete, dtype=torch.long, device=device) 17 | prefix = torch.cat(context + [obs_discrete], dim=-1) if prefix_context else obs_discrete 18 | return prefix 19 | 20 | 21 | def extract_actions(x, observation_dim, action_dim, t=None): 22 | assert x.shape[1] == observation_dim + action_dim + 2 23 | actions = x[:, observation_dim:observation_dim+action_dim] 24 | return actions[t] if t is not None else actions 25 | 26 | 27 | def update_context(context, discretizer, obs, act, rew, device, max_context_transitions): 28 | # use a placeholder for value because input values are masked out by model 29 | rew_val = np.array([rew, VALUE_PLACEHOLDER]) 30 | transition = np.concatenate([obs, act, rew_val]) 31 | # discretize transition and convert to torch tensor 32 | transition_discrete = discretizer.discretize(transition) 33 | transition_discrete = torch.tensor(transition_discrete, dtype=torch.long, device=device) 34 | # add new transition to context 35 | context.append(transition_discrete) 36 | # crop context if necessary 37 | context = context[-max_context_transitions:] 38 | return context 39 | 40 | 41 | @torch.no_grad() 42 | def beam_plan( 43 | model, value_fn, x, rollout_steps, beam_width, n_expand, obs_dim, act_dim, 44 | discount=0.99, max_trans=None, k_obs=None, k_act=None, k_rew=1, cdf_obs=None, cdf_act=None, cdf_rew=None, 45 | with_tqdm=True, return_values=False 46 | ): 47 | # convert max number of transitions to max number of tokens 48 | transition_dim = obs_dim + act_dim + 2 49 | max_block = max_trans * transition_dim - 1 if max_trans else None 50 | 51 | # pass in max numer of tokens to sample function 52 | sample_kwargs = {'max_block': max_block, 'crop_increment': transition_dim} 53 | 54 | # repeat input for search 55 | x = x.repeat(beam_width, 1) 56 | 57 | # construct reward and discount tensors for estimating values 58 | rewards = torch.zeros(beam_width, rollout_steps + 1, device=x.device) 59 | discounts = discount ** torch.arange(rollout_steps + 1, device=x.device) 60 | 61 | pbar = tqdm(range(rollout_steps), leave=False) if with_tqdm else range(rollout_steps) 62 | for t in pbar: 63 | # repeat everything by `n_expand` before we sample actions 64 | x = x.repeat(n_expand, 1) 65 | rewards = rewards.repeat(n_expand, 1) 66 | 67 | # sample actions 68 | x, _ = sample_n(model, x, act_dim, topk=k_act, cdf=cdf_act, **sample_kwargs) 69 | 70 | # sample reward and value estimate 71 | x, r_probs = sample_n(model, x, 2, topk=k_rew, cdf=cdf_rew, **sample_kwargs) 72 | r_t, V_t = value_fn(r_probs) 73 | rewards[:, t] = r_t 74 | rewards[:, t+1] = V_t 75 | 76 | # estimate values using rewards up to `t` and terminal value at `t` 77 | values = (rewards * discounts).sum(dim=-1) 78 | 79 | # get `beam_width` best actions 80 | values, inds = torch.topk(values, beam_width) 81 | 82 | # index into search candidates to retain `beam_width` highest-reward sequences 83 | x = x[inds] 84 | rewards = rewards[inds] 85 | 86 | # sample next observation (unless we have reached the end of the planning horizon) 87 | if t < rollout_steps - 1: 88 | x, _ = sample_n(model, x, obs_dim, topk=k_obs, cdf=cdf_obs, **sample_kwargs) 89 | 90 | if with_tqdm: 91 | 
pbar.set_description(f"Context shape: {list(x.shape)} | " 92 | f"V_(t) estimate: [{V_t.min():.2f}, {V_t.max():.2f}] | " 93 | f"V_(t+{t}) estimate [{values.min():.2f}, {values.max():.2f}]") 94 | 95 | x = x.view(beam_width, -1, transition_dim) 96 | x = x[:, -rollout_steps:] 97 | 98 | # return best sequence 99 | argmax = values.argmax() 100 | best_sequence = x[argmax] 101 | best_value = (rewards * discounts)[argmax] 102 | 103 | if not return_values: 104 | return best_sequence 105 | else: 106 | return best_sequence, best_value 107 | 108 | 109 | def plan(args, env, dataset, model, logger, device='cuda:0'): 110 | timer = Timer(total_num=env.max_episode_steps) 111 | 112 | discretizer = dataset.discretizer 113 | discount = dataset.discount 114 | observation_dim = dataset.observation_dim 115 | action_dim = dataset.action_dim 116 | 117 | value_fn = lambda x: discretizer.value_fn(x, args.percentile) 118 | 119 | # main loop 120 | observation = env.reset() 121 | total_reward = 0 122 | 123 | ## observations for rendering 124 | rollout_states = [observation.copy().tolist()] 125 | predict_states = [observation.copy().tolist()] 126 | rollout_values = [] 127 | real_rewards = [] 128 | 129 | ## previous (tokenized) transitions for conditioning transformer 130 | context = [] 131 | 132 | T = env.max_episode_steps 133 | for t in range(T): 134 | prefix = make_prefix(discretizer, context, observation, device, args.prefix_context) 135 | sequence, best_value = beam_plan( 136 | model, value_fn, prefix, args.horizon, args.beam_width, args.n_expand, observation_dim, action_dim, 137 | discount, args.max_context_transitions, with_tqdm=args.with_tqdm, return_values=True, 138 | k_obs=args.k_obs, k_act=args.k_act, cdf_obs=args.cdf_obs, cdf_act=args.cdf_act, 139 | ) 140 | 141 | sequence_recon = discretizer.reconstruct(sequence) 142 | action = extract_actions(sequence_recon, observation_dim, action_dim, t=0) 143 | pred_observation = sequence_recon[0, :observation_dim] 144 | next_observation, reward, terminal, _ = env.step(action) 145 | 146 | rollout_states.append(next_observation.copy().tolist()) 147 | predict_states.append(pred_observation.copy().tolist()) 148 | rollout_values.append(best_value.cpu().numpy().tolist()) 149 | real_rewards.append(reward) 150 | 151 | total_reward += reward 152 | score = env.get_normalized_score(total_reward) 153 | context = update_context(context, discretizer, observation, action, reward, device, args.max_context_transitions) 154 | 155 | diff_time, total_time, eta = timer() 156 | logger.info(f"Step: {t:>4} / {T:>4} | r': {best_value[0]:.2f} | r: {reward:.2f} | R: {total_reward:.2f} | score: {score:.4f}") 157 | logger.debug(f"Previous time: {diff_time} | Total time: {total_time} | ETA: {eta}") 158 | 159 | if terminal: 160 | break 161 | observation = next_observation 162 | 163 | info = { 164 | "rollout_states": rollout_states[1:], 165 | "predict_states": predict_states[1:], 166 | "rollout_values": rollout_values, 167 | "real_rewards": real_rewards, 168 | } 169 | 170 | return info 171 | -------------------------------------------------------------------------------- /bootorl/analysis/visualize_distribution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 
3 | 4 | import os 5 | import sys 6 | sys.path.append(os.getcwd()) 7 | from utils import DiscretizedDataset 8 | import numpy as np 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | import matplotlib.lines as mlines 12 | 13 | plt.rcParams["font.family"] = "Times New Roman" 14 | plt.rcParams["font.size"] = 22 15 | 16 | 17 | def plot_tsne(game="halfcheetah", level="medium", sample=1, update=True): 18 | if not update and os.path.exists(f"./logs/{game}-{level}/distribution/reduced_tsne.npz"): 19 | data = np.load(f"./logs/{game}-{level}/distribution/reduced_tsne.npz") 20 | reduced_origin = data["reduced_origin"] 21 | reduced_recon_ar = data["reduced_recon_ar"] 22 | reduced_recon_tf = data["reduced_recon_tf"] 23 | reduced_noise = data["reduced_noise"] 24 | reduced_merged = np.concatenate([reduced_origin, reduced_recon_ar, reduced_recon_tf, reduced_noise]) 25 | else: 26 | trans = [] 27 | for i, scheme in enumerate(["boot_genr_once_ar", "boot_genr_once_tf"]): 28 | transitions = np.load(f"./logs/{game}-{level}/distribution/{scheme}/transitions.npz") 29 | trans_origin = transitions["trans_origin"][:10000] 30 | trans_generated_reconstruct = transitions[f"trans_recon_{scheme.split('_')[3]}_{scheme.split('_')[1]}"][:10000] 31 | if i == 0: 32 | trans.append(trans_origin) 33 | 34 | dataset = DiscretizedDataset(logger=None, env=f"{game}-{level}-v2", n_bins=100, sequence_length=10, penalty=-100, discount=0.99) 35 | discretizer = dataset.discretizer 36 | mean, std = dataset.raw_data_mean, dataset.raw_data_std 37 | trans_noise = trans_origin.copy() 38 | noise = np.random.normal(scale=3e-4, size=trans_noise.shape) 39 | noise = noise * std + mean 40 | trans_noise[:, :dataset.observation_dim] += noise[:, :dataset.observation_dim] 41 | trans_noise = discretizer.discretize(trans_noise) 42 | trans_noise = discretizer.reconstruct(trans_noise) 43 | 44 | trans.append(trans_generated_reconstruct) 45 | 46 | trans.append(trans_noise) 47 | trans_merged = np.concatenate(trans) 48 | print(trans_merged.shape) 49 | 50 | from sklearn.manifold import TSNE 51 | n_components = 2 52 | tsne = TSNE(n_components) 53 | reduced_merged = tsne.fit_transform(trans_merged) 54 | reduced_origin = reduced_merged[ 0: 10000] 55 | reduced_recon_ar = reduced_merged[10000: 20000] 56 | reduced_recon_tf = reduced_merged[20000: 30000] 57 | reduced_noise = reduced_merged[30000: 40000] 58 | np.savez( 59 | f"./logs/{game}-{level}/distribution/reduced_tsne.npz", 60 | reduced_origin=reduced_origin, 61 | reduced_recon_ar=reduced_recon_ar, 62 | reduced_recon_tf=reduced_recon_tf, 63 | reduced_noise=reduced_noise, 64 | ) 65 | 66 | x_min, y_min = reduced_merged.min(axis=0) 67 | x_max, y_max = reduced_merged.max(axis=0) 68 | x_left, x_right = 0.5 * (x_max + x_min) - 0.55 * (x_max - x_min), 0.5 * (x_max + x_min) + 0.55 * (x_max - x_min) 69 | y_bottom, y_top = 0.5 * (y_max + y_min) - 0.55 * (y_max - y_min), 0.5 * (y_max + y_min) + 0.80 * (y_max - y_min) 70 | 71 | fig, axes = plt.subplots(1, 4, figsize=(4 * 6, 3.8), dpi=320) 72 | for i in range(4): 73 | axes[i].set_xlim(x_left, x_right) 74 | axes[i].set_ylim(y_bottom, y_top) 75 | axes[i].grid(ls="--", alpha=0.5) 76 | 77 | axes[0].scatter(reduced_origin[::sample, 0], reduced_origin[::sample, 1], marker='o', s=18, c=f"#2980b9", alpha=0.15*sample) 78 | axes[0].set_title('Original Dataset', y=-0.33) 79 | h = mlines.Line2D([], [], color='#2980b9', marker='o', linestyle='None', markersize=12, label='Original Data') 80 | axes[0].legend(handles=[h], bbox_to_anchor=(0.01, 0.99), loc='upper left', 
borderpad=0.2, handlelength=1, handletextpad=0.4, borderaxespad=0.1) 81 | 82 | h = axes[1].scatter(reduced_recon_tf[::sample, 0], reduced_recon_tf[::sample, 1], marker='^', s=18, c=f"#c0392b", alpha=0.15*sample) 83 | axes[1].set_title('Teacher-forcing Generation', y=-0.33) 84 | h = mlines.Line2D([], [], color='#c0392b', marker='^', linestyle='None', markersize=12, label='Generated Data') 85 | axes[1].legend(handles=[h], bbox_to_anchor=(0.01, 0.99), loc='upper left', borderpad=0.2, handlelength=1, handletextpad=0.4, borderaxespad=0.1) 86 | 87 | h = axes[2].scatter(reduced_recon_ar[::sample, 0], reduced_recon_ar[::sample, 1], marker='^', s=18, c=f"#f39c12", alpha=0.15*sample) 88 | axes[2].set_title('Autoregressive Generation', y=-0.33) 89 | h = mlines.Line2D([], [], color='#f39c12', marker='^', linestyle='None', markersize=12, label='Generated Data') 90 | axes[2].legend(handles=[h], bbox_to_anchor=(0.01, 0.99), loc='upper left', borderpad=0.2, handlelength=1, handletextpad=0.4, borderaxespad=0.1) 91 | 92 | axes[3].scatter(reduced_origin[::sample, 0], reduced_origin[::sample, 1], marker='o', s=18, c=f"#2980b9", alpha=0.15*sample) 93 | axes[3].scatter(reduced_recon_tf[::sample, 0], reduced_recon_tf[::sample, 1], marker='^', s=18, c=f"#c0392b", alpha=0.15*sample) 94 | axes[3].scatter(reduced_recon_ar[::sample, 0], reduced_recon_ar[::sample, 1], marker='^', s=18, c=f"#f39c12", alpha=0.12*sample) 95 | axes[3].set_title(f'Augmented Dataset', y=-0.33) 96 | 97 | plt.tight_layout() 98 | print(f"Saving figure to `./analysis/images/tsne_distribution/tsne_distribution_{game}-{level}`") 99 | plt.savefig(f"./analysis/images/tsne_distribution/tsne_distribution_{game}-{level}.pdf") 100 | plt.savefig(f"./analysis/images/tsne_distribution/tsne_distribution_{game}-{level}.png") 101 | plt.clf() 102 | 103 | 104 | # ======================= Noise ======================= 105 | fig, axes = plt.subplots(1, 3, figsize=(3 * 6, 3.8), dpi=320) 106 | for i in range(3): 107 | axes[i].set_xlim(x_left, x_right) 108 | axes[i].set_ylim(y_bottom, y_top) 109 | axes[i].grid(ls="--", alpha=0.5) 110 | 111 | axes[0].scatter(reduced_origin[::sample, 0], reduced_origin[::sample, 1], marker='o', s=18, c=f"#2980b9", alpha=0.15*sample) 112 | axes[0].set_title('Original Dataset', y=-0.33) 113 | h = mlines.Line2D([], [], color='#2980b9', marker='o', linestyle='None', markersize=12, label='Original Data') 114 | axes[0].legend(handles=[h], bbox_to_anchor=(0.01, 0.99), loc='upper left', borderpad=0.2, handlelength=1, handletextpad=0.4, borderaxespad=0.1) 115 | 116 | h = axes[1].scatter(reduced_noise[::sample, 0], reduced_noise[::sample, 1], marker='^', s=18, c=f"#27ae60", alpha=0.15*sample) 117 | axes[1].set_title('Noisy Data', y=-0.33) 118 | h = mlines.Line2D([], [], color='#27ae60', marker='^', linestyle='None', markersize=12, label='Generated Data') 119 | axes[1].legend(handles=[h], bbox_to_anchor=(0.01, 0.99), loc='upper left', borderpad=0.2, handlelength=1, handletextpad=0.4, borderaxespad=0.1) 120 | 121 | axes[2].scatter(reduced_origin[::sample, 0], reduced_origin[::sample, 1], marker='o', s=18, c=f"#2980b9", alpha=0.15*sample) 122 | axes[2].scatter(reduced_noise[::sample, 0], reduced_noise[::sample, 1], marker='^', s=18, c=f"#27ae60", alpha=0.12*sample) 123 | axes[2].set_title(f'Augmented Dataset', y=-0.33) 124 | 125 | plt.tight_layout() 126 | print(f"Saving figure to `./analysis/images/tsne_distribution/tsne_distribution_{game}-{level}_noise`") 127 | 
plt.savefig(f"./analysis/images/tsne_distribution/tsne_distribution_{game}-{level}_noise.pdf") 128 | plt.savefig(f"./analysis/images/tsne_distribution/tsne_distribution_{game}-{level}_noise.png") 129 | plt.clf() 130 | 131 | 132 | for game in ["halfcheetah", "hopper", "walker2d"]: 133 | for level in ["medium", "medium-replay", "medium-expert"]: 134 | plot_tsne(game, level, sample=4, update=False) 135 | 136 | -------------------------------------------------------------------------------- /bootorl/analysis/calc_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import torch 7 | import os 8 | import sys 9 | sys.path.append(os.getcwd()) 10 | from utils import DiscretizedDataset 11 | from itertools import accumulate 12 | 13 | pd.set_option("display.precision", 2) 14 | 15 | 16 | def calc_mmd(x, y): 17 | """Emprical maximum mean discrepancy. The lower the result 18 | the more evidence that distributions are the same. 19 | 20 | Args: 21 | x: first sample, distribution P 22 | y: second sample, distribution Q 23 | 24 | borrowed from https://www.kaggle.com/code/onurtunali/maximum-mean-discrepancy/notebook 25 | """ 26 | x, y = torch.from_numpy(x), torch.from_numpy(y) 27 | xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) 28 | rx, ry = (xx.diag().unsqueeze(0).expand_as(xx)), (yy.diag().unsqueeze(0).expand_as(yy)) 29 | dxx, dyy, dxy = rx.t() + rx - 2. * xx, ry.t() + ry - 2. * yy, rx.t() + ry - 2. * zz 30 | XX, YY, XY = (torch.zeros(xx.shape), torch.zeros(xx.shape), torch.zeros(xx.shape)) 31 | bandwidth_range = [10, 15, 20, 50] 32 | for a in bandwidth_range: 33 | XX += torch.exp(-0.5*dxx/a) 34 | YY += torch.exp(-0.5*dyy/a) 35 | XY += torch.exp(-0.5*dxy/a) 36 | return torch.mean(XX + YY - 2. 
* XY) 37 | 38 | 39 | def calc_r(real_r, pred_r_list): 40 | reversed_rewards = real_r[::-1] 41 | cumulative_rewards = list(accumulate(reversed_rewards, lambda x, y: x * 0.99 + y)) 42 | real_R = np.array(cumulative_rewards) 43 | pred_r = [t[0] for t in pred_r_list] 44 | reversed_rewards = pred_r[::-1] 45 | cumulative_rewards = list(accumulate(reversed_rewards, lambda x, y: x * 0.99 + y)) 46 | pred_R = np.array(cumulative_rewards) 47 | return real_r, real_R, pred_r, pred_R 48 | 49 | 50 | def calc_dist(target="boot_genr_repeat_ar", env="halfcheetah-medium"): 51 | transitions = np.load(f"./logs/{env}/distribution/{target}/transitions.npz") 52 | trans_origin = transitions["trans_origin"][:10000] 53 | trans_discretized = transitions["trans_discretized"][:10000] 54 | trans_generated_discretized = transitions[f"trans_generated_{target.split('_')[3]}_{target.split('_')[1]}"][:10000] 55 | trans_generated_reconstruct = transitions[f"trans_recon_{target.split('_')[3]}_{target.split('_')[1]}"][:10000] 56 | rmse = np.sqrt(((trans_generated_discretized - trans_discretized) ** 2).mean()) 57 | print(f"\tRMSE (state): {rmse:.4f}") 58 | mmd = calc_mmd(trans_generated_reconstruct, trans_origin) 59 | print(f"\tMMD (10^-3): {mmd * 1000:.4f}") 60 | return rmse, mmd 61 | 62 | 63 | def calc_boot_dist(): 64 | dist = {} 65 | for target in ["boot_realr_once_ar", "boot_realr_once_tf", "boot_genr_once_ar", "boot_genr_once_tf"]: 66 | for game in ["halfcheetah", "hopper", "walker2d"]: 67 | for level in ["medium", "medium-replay", "medium-expert"]: 68 | rmse, mmd = calc_dist(target=target, env=f"{game}-{level}") 69 | dist[(target, "rmse", game, level)] = rmse 70 | dist[(target, "mmd", game, level)] = mmd 71 | multi_idx = pd.MultiIndex.from_tuples(dist, names=["target", "metric", "game", "level"]) 72 | dist = pd.DataFrame(dist.values(), index=multi_idx).sort_index() 73 | dist = dist.groupby(["target", "metric"]).mean() 74 | dist.to_csv("./analysis/dist.csv", index=False) 75 | print(dist) 76 | return dist 77 | 78 | 79 | def calc_eval_dist(game="halfcheetah", level="medium"): 80 | dist = {} 81 | for target in ["boot_genr_once_ar", "boot_genr_once_tf"]: 82 | for game in ["halfcheetah", "hopper", "walker2d"]: 83 | for level in ["medium", "medium-replay", "medium-expert"]: 84 | df = pd.read_csv(f"./logs/{game}-{level}/plan_analysis/{target}/reward_analysis.csv", sep="\t") 85 | real_s = np.array([eval(r) for r in df["rollout_states"]]) 86 | pred_s = np.array([eval(r) for r in df["predict_states"]]) 87 | real_r = np.array([r for r in df["real_rewards"]]) 88 | pred_r_list = np.array([eval(r) for r in df["rollout_values"]]) 89 | real_r, real_R, pred_r, pred_R = calc_r(real_r, pred_r_list) 90 | 91 | dataset = DiscretizedDataset(logger=None, env=f"{game}-{level}-v2", n_bins=100, sequence_length=10, penalty=-100, discount=0.99) 92 | discretizer = dataset.discretizer 93 | mean, std = dataset.raw_data_mean, dataset.raw_data_std 94 | 95 | real_traj = np.zeros((len(real_s), dataset.joined_dim), dtype=real_s.dtype) 96 | real_traj[:, :dataset.observation_dim] = real_s 97 | real_traj[:, -2] = real_r 98 | real_traj[:, -1] = real_R 99 | pred_traj = np.zeros((len(pred_s), dataset.joined_dim), dtype=pred_s.dtype) 100 | pred_traj[:, :dataset.observation_dim] = pred_s 101 | pred_traj[:, -2] = pred_r 102 | pred_traj[:, -1] = pred_R 103 | 104 | mmd = calc_mmd(real_traj, pred_traj) 105 | print(f"\tMMD (10^-3): {mmd * 1000:.4f}") 106 | 107 | real_traj = discretizer.discretize(real_traj) 108 | pred_traj = discretizer.discretize(pred_traj) 109 | rmse = 
np.sqrt(((real_traj - pred_traj) ** 2).mean()) 110 | print(f"\tRMSE (state): {rmse:.4f}") 111 | 112 | dist[(target, "rmse", game, level)] = rmse 113 | dist[(target, "mmd", game, level)] = mmd * 1000 114 | 115 | multi_idx = pd.MultiIndex.from_tuples(dist, names=["target", "metric", "game", "level"]) 116 | dist = pd.DataFrame(dist.values(), index=multi_idx).sort_index() 117 | dist = dist.groupby(["target", "metric"]).mean() 118 | dist.to_csv("./analysis/eval_dist.csv") 119 | print(dist) 120 | 121 | 122 | def calc_noise_eval_dist(game="halfcheetah", level="medium"): 123 | dist = {} 124 | for target in ["boot_s4rl_noise_last"]: 125 | for game in ["halfcheetah", "hopper", "walker2d"]: 126 | for level in ["medium", "medium-replay", "medium-expert"]: 127 | df = pd.read_csv(f"./logs/{game}-{level}/plan_analysis/{target}/reward_analysis.csv", sep="\t") 128 | real_s = np.array([eval(r) for r in df["rollout_states"]]) 129 | pred_s = np.array([eval(r) for r in df["predict_states"]]) 130 | real_r = np.array([r for r in df["real_rewards"]]) 131 | pred_r_list = np.array([eval(r) for r in df["rollout_values"]]) 132 | real_r, real_R, pred_r, pred_R = calc_r(real_r, pred_r_list) 133 | 134 | dataset = DiscretizedDataset(logger=None, env=f"{game}-{level}-v2", n_bins=100, sequence_length=10, penalty=-100, discount=0.99) 135 | discretizer = dataset.discretizer 136 | mean, std = dataset.raw_data_mean, dataset.raw_data_std 137 | 138 | real_traj = np.zeros((len(real_s), dataset.joined_dim), dtype=real_s.dtype) 139 | real_traj[:, :dataset.observation_dim] = real_s 140 | real_traj[:, -2] = real_r 141 | real_traj[:, -1] = real_R 142 | pred_traj = np.zeros((len(pred_s), dataset.joined_dim), dtype=pred_s.dtype) 143 | pred_traj[:, :dataset.observation_dim] = pred_s 144 | pred_traj[:, -2] = pred_r 145 | pred_traj[:, -1] = pred_R 146 | 147 | noisy_traj = real_traj.copy() 148 | 149 | noise = np.random.normal(scale=3e-4, size=noisy_traj.shape) 150 | noise = noise * std + mean 151 | noisy_traj[:, :dataset.observation_dim] += noise[:, :dataset.observation_dim] 152 | 153 | mmd = calc_mmd(real_traj, noisy_traj) 154 | print(f"\tMMD (10^-3): {mmd * 1000:.4f}") 155 | 156 | real_traj = discretizer.discretize(real_traj) 157 | noisy_traj = discretizer.discretize(noisy_traj) 158 | rmse = np.sqrt(((real_traj - noisy_traj) ** 2).mean()) 159 | print(f"\tRMSE (state): {rmse:.4f}") 160 | 161 | dist[(target, "rmse", game, level)] = rmse 162 | dist[(target, "mmd", game, level)] = mmd * 1000 163 | 164 | multi_idx = pd.MultiIndex.from_tuples(dist, names=["target", "metric", "game", "level"]) 165 | dist = pd.DataFrame(dist.values(), index=multi_idx).sort_index() 166 | dist = dist.groupby(["target", "metric"]).mean() 167 | dist.to_csv("./analysis/noise_eval_dist.csv") 168 | print(dist) 169 | 170 | 171 | calc_boot_dist() 172 | calc_eval_dist() 173 | calc_noise_eval_dist() 174 | -------------------------------------------------------------------------------- /agentlightning/examples/spider/spider_eval/parse.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | # The evaluation code is from https://github.com/taoyds/test-suite-sql-eval 3 | 4 | import re 5 | import sqlparse 6 | from typing import List, Tuple, Set, Iterator, Dict, Any, Union 7 | from sqlparse.sql import Comparison, Identifier 8 | from sqlparse.tokens import Whitespace 9 | import itertools 10 | from collections import namedtuple 11 | 12 | Token = namedtuple('Token', ['ttype', 'value']) 13 | VALUE_NUM_SYMBOL = 'VALUERARE' 14 | 
QUOTE_CHARS = {'`', '\'', '"'} 15 | 16 | 17 | def tokenize(query: str) -> List[Token]: 18 | tokens = list([Token(t.ttype, t.value) for t in sqlparse.parse(query)[0].flatten()]) 19 | return tokens 20 | 21 | 22 | def join_tokens(tokens: List[Token]) -> str: 23 | return ''.join([x.value for x in tokens]).strip().replace(' ', ' ') 24 | 25 | 26 | def round_trip_test(query: str) -> None: 27 | tokens = tokenize(query) 28 | reconstructed = ''.join([token.value for token in tokens]) 29 | assert query == reconstructed, "Round trip test fails for string %s" % query 30 | 31 | 32 | def postprocess(query: str) -> str: 33 | query = query.replace('> =', '>=').replace('< =', '<=').replace('! =', '!=') 34 | return query 35 | 36 | 37 | # strip_query, reformat_query and replace values 38 | # were implemented by Yu Tao for processing CoSQL 39 | def strip_query(query: str) -> Tuple[List[str], List[str]]: 40 | query_keywords, all_values = [], [] 41 | 42 | # then replace all stuff enclosed by "" with a numerical value to get it marked as {VALUE} 43 | 44 | # Tao's implementation is commented out here. 45 | """ 46 | str_1 = re.findall("\"[^\"]*\"", query) 47 | str_2 = re.findall("\'[^\']*\'", query) 48 | values = str_1 + str_2 49 | """ 50 | 51 | toks = sqlparse.parse(query)[0].flatten() 52 | values = [t.value for t in toks if t.ttype == sqlparse.tokens.Literal.String.Single or t.ttype == sqlparse.tokens.Literal.String.Symbol] 53 | 54 | 55 | for val in values: 56 | all_values.append(val) 57 | query = query.replace(val.strip(), VALUE_NUM_SYMBOL) 58 | 59 | query_tokenized = query.split() 60 | float_nums = re.findall("[-+]?\d*\.\d+", query) 61 | all_values += [qt for qt in query_tokenized if qt in float_nums] 62 | query_tokenized = [VALUE_NUM_SYMBOL if qt in float_nums else qt for qt in query_tokenized] 63 | 64 | query = " ".join(query_tokenized) 65 | int_nums = [i.strip() for i in re.findall("[^tT]\d+", query)] 66 | 67 | all_values += [qt for qt in query_tokenized if qt in int_nums] 68 | query_tokenized = [VALUE_NUM_SYMBOL if qt in int_nums else qt for qt in query_tokenized] 69 | # print int_nums, query, query_tokenized 70 | 71 | for tok in query_tokenized: 72 | if "." in tok: 73 | table = re.findall("[Tt]\d+\.", tok) 74 | if len(table) > 0: 75 | to = tok.replace(".", " . 
").split() 76 | to = [t.lower() for t in to if len(t) > 0] 77 | query_keywords.extend(to) 78 | else: 79 | query_keywords.append(tok.lower()) 80 | 81 | elif len(tok) > 0: 82 | query_keywords.append(tok.lower()) 83 | return query_keywords, all_values 84 | 85 | 86 | def reformat_query(query: str) -> str: 87 | query = query.strip().replace(";", "").replace("\t", "") 88 | query = ' '.join([t.value for t in tokenize(query) if t.ttype != sqlparse.tokens.Whitespace]) 89 | t_stars = ["t1.*", "t2.*", "t3.*", "T1.*", "T2.*", "T3.*"] 90 | for ts in t_stars: 91 | query = query.replace(ts, "*") 92 | return query 93 | 94 | 95 | def replace_values(sql: str) -> Tuple[List[str], Set[str]]: 96 | sql = sqlparse.format(sql, reindent=False, keyword_case='upper') 97 | # sql = re.sub(r"(<=|>=|!=|=|<|>|,)", r" \1 ", sql) 98 | sql = re.sub(r"(T\d+\.)\s", r"\1", sql) 99 | query_toks_no_value, values = strip_query(sql) 100 | return query_toks_no_value, set(values) 101 | 102 | 103 | # extract the non-value tokens and the set of values 104 | # from a sql query 105 | def extract_query_values(sql: str) -> Tuple[List[str], Set[str]]: 106 | reformated = reformat_query(query=sql) 107 | query_value_replaced, values = replace_values(reformated) 108 | return query_value_replaced, values 109 | 110 | 111 | # plug in the values into query with value slots 112 | def plugin(query_value_replaced: List[str], values_in_order: List[str]) -> str: 113 | q_length = len(query_value_replaced) 114 | query_w_values = query_value_replaced[:] 115 | value_idx = [idx for idx in range(q_length) if query_value_replaced[idx] == VALUE_NUM_SYMBOL.lower()] 116 | assert len(value_idx) == len(values_in_order) 117 | 118 | for idx, value in zip(value_idx, values_in_order): 119 | query_w_values[idx] = value 120 | return ' '.join(query_w_values) 121 | 122 | 123 | # a generator generating all possible ways of 124 | # filling values into predicted query 125 | def plugin_all_permutations(query_value_replaced: List[str], values: Set[str]) -> Iterator[str]: 126 | num_slots = len([v for v in query_value_replaced if v == VALUE_NUM_SYMBOL.lower()]) 127 | for values in itertools.product(*[list(values) for _ in range(num_slots)]): 128 | yield plugin(query_value_replaced, list(values)) 129 | 130 | 131 | # given the gold query and the model prediction 132 | # extract values from the gold, extract predicted sql with value slots 133 | # return 1) number of possible ways to plug in gold values and 2) an iterator of predictions with value plugged in 134 | def get_all_preds_for_execution(gold: str, pred: str) -> Tuple[int, Iterator[str]]: 135 | _, gold_values = extract_query_values(gold) 136 | pred_query_value_replaced, _ = extract_query_values(pred) 137 | num_slots = len([v for v in pred_query_value_replaced if v == VALUE_NUM_SYMBOL.lower()]) 138 | num_alternatives = len(gold_values) ** num_slots 139 | return num_alternatives, plugin_all_permutations(pred_query_value_replaced, gold_values) 140 | 141 | 142 | def remove_distinct(s): 143 | toks = [t.value for t in list(sqlparse.parse(s)[0].flatten())] 144 | return ''.join([t for t in toks if t.lower() != 'distinct']) 145 | 146 | 147 | def extract_all_comparison_from_node(node: Token) -> List[Comparison]: 148 | comparison_list = [] 149 | if hasattr(node, 'tokens'): 150 | for t in node.tokens: 151 | comparison_list.extend(extract_all_comparison_from_node(t)) 152 | if type(node) == Comparison: 153 | comparison_list.append(node) 154 | return comparison_list 155 | 156 | 157 | def extract_all_comparison(query: str) -> 
List[Comparison]: 158 | tree = sqlparse.parse(query)[0] 159 | comparison_list = extract_all_comparison_from_node(tree) 160 | return comparison_list 161 | 162 | 163 | def extract_toks_from_comparison(comparison_node: Comparison) -> List[Token]: 164 | tokens = [t for t in comparison_node.tokens if t.ttype != Whitespace] 165 | return tokens 166 | 167 | 168 | def extract_info_from_comparison(comparison_node: Comparison) -> Dict[str, Any]: 169 | tokens = extract_toks_from_comparison(comparison_node) 170 | left, op, right = tokens 171 | 172 | returned_dict = { 173 | 'left': left, 174 | 'op': op.value, 175 | 'right': right 176 | } 177 | 178 | if type(left) != Identifier: 179 | return returned_dict 180 | 181 | table = None 182 | if len(left.tokens) == 3 and re.match('^[tT][0-9]$', left.tokens[0].value) is None: 183 | table = left.tokens[0].value.lower() 184 | col = left.tokens[-1].value 185 | 186 | if type(right) == Identifier: 187 | if len(right.tokens) == 1 and type(right.tokens[0]) == sqlparse.sql.Token: 188 | right_val = right.tokens[0].value 189 | else: 190 | return returned_dict 191 | elif type(right) == sqlparse.sql.Token: 192 | right_val = right.value 193 | else: 194 | return returned_dict 195 | 196 | returned_dict['table_col'], returned_dict['val'] = (table, col.upper()), process_str_value(right_val) 197 | 198 | return returned_dict 199 | 200 | 201 | def extract_all_comparison_from_query(query: str) -> List[Dict[str, Any]]: 202 | comparison_list = extract_all_comparison(query) 203 | return [extract_info_from_comparison(c) for c in comparison_list] 204 | 205 | 206 | def extract_typed_value_in_comparison_from_query(query: str) -> List[Tuple[Tuple[Union[str, None], str], str]]: 207 | cmps = extract_all_comparison_from_query(query) 208 | typed_values = [(cmp['table_col'], cmp['val']) for cmp in cmps if 'table_col' in cmp] 209 | for table, col, val1, val2 in re.findall('(?:([^\.\s]*)\.)?([^\.\s]+) between ([^\s;]+) and ([^\s;]+)', query, re.IGNORECASE): 210 | if table == '': 211 | table = None 212 | else: 213 | table = table.lower() 214 | col = col.upper() 215 | for v in [val1, val2]: 216 | typed_values.append(((table, col), v)) 217 | return typed_values 218 | 219 | 220 | def process_str_value(v: str) -> str: 221 | if len(v) > 0 and v[0] in QUOTE_CHARS: 222 | v = v[1:] 223 | if len(v) > 0 and v[-1] in QUOTE_CHARS: 224 | v = v[:-1] 225 | for c in QUOTE_CHARS: 226 | v = v.replace(c + c, c) 227 | return v 228 | -------------------------------------------------------------------------------- /agentlightning/agentlightning/client.py: -------------------------------------------------------------------------------- 1 | import aiohttp 2 | import asyncio 3 | import logging 4 | import requests 5 | import time 6 | import urllib.parse 7 | from typing import List, TypedDict 8 | 9 | from .trace import Transition 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class TaskData(TypedDict): 16 | rollout_id: str 17 | is_train: bool 18 | 19 | 20 | class SamplingParameters(TypedDict): 21 | model: str 22 | temperature: float 23 | 24 | 25 | class VerlAgentClient: 26 | 27 | def __init__(self, endpoint: str, poll_interval: float = 5.0, timeout: float = 10.0) -> None: 28 | """ 29 | Initialize the VerlAgentClient with the given endpoint. 30 | 31 | :param endpoint: The root URL of the VeRL agent server. 32 | :param poll_interval: The interval in seconds to wait between polling for new tasks. 33 | :param timeout: The timeout in seconds for HTTP requests. 
34 | """ 35 | self.endpoint = endpoint 36 | self.task_count = 0 37 | self.poll_interval = poll_interval 38 | self.timeout = timeout 39 | 40 | @property 41 | def openai_endpoint(self) -> str: 42 | """The OpenAI endpoint for the VeRL agent server.""" 43 | return urllib.parse.urljoin(self.endpoint, "v1") 44 | 45 | async def request_json_async(self, url: str) -> dict | None: 46 | """ 47 | Make a GET request to the specified URL and return the JSON response. 48 | 49 | :param url: The URL to request. 50 | :return: The JSON response as a dictionary. 51 | """ 52 | timeout = aiohttp.ClientTimeout(total=self.timeout) 53 | async with aiohttp.ClientSession(timeout=timeout) as session: 54 | try: 55 | async with session.get(url) as resp: 56 | resp.raise_for_status() 57 | return await resp.json() 58 | except Exception as e: 59 | logger.debug("GET request failed: %s", e) 60 | return None 61 | 62 | async def post_json_async(self, url: str, payload: dict) -> dict | None: 63 | """ 64 | Make a POST request to the specified URL with the given payload and return the JSON response. 65 | 66 | :param url: The URL to post to. 67 | :param payload: The data to send in the POST request. 68 | :return: The JSON response as a dictionary. 69 | """ 70 | timeout = aiohttp.ClientTimeout(total=self.timeout) 71 | async with aiohttp.ClientSession(timeout=timeout) as session: 72 | try: 73 | async with session.post(url, json=payload) as resp: 74 | resp.raise_for_status() 75 | return await resp.json() 76 | except Exception as e: 77 | logger.debug("POST request failed: %s", e) 78 | return None 79 | 80 | async def poll_next_task_async(self) -> TaskData: 81 | """Poll the server for the next task data sample until it is available. 82 | 83 | Returns a task data dict which has the same format as the dataset sample. 84 | It has an extra `rollout_id` field, which is a unique identifier for the task, 85 | and an `is_train` field indicating whether the task is for training or evaluation. 86 | """ 87 | url = urllib.parse.urljoin(self.endpoint, "next_data_sample") 88 | while True: 89 | data = await self.request_json_async(url) 90 | if data and data.get("is_available"): 91 | task_data = data["data"] 92 | self.task_count += 1 93 | logger.info("[Task %d Received] %s", self.task_count, task_data) 94 | return task_data 95 | else: 96 | logger.debug("No task available yet. Retrying in 5 seconds...") 97 | await asyncio.sleep(self.poll_interval) 98 | 99 | async def poll_sampling_parameters_async(self) -> SamplingParameters: 100 | """Poll the server for sampling parameters until they are available. 101 | 102 | The client agent is expected to respect the designated sampling parameters 103 | when calling the LLMs, to maximize the power of the algorithms. 104 | """ 105 | url = urllib.parse.urljoin(self.endpoint, "train_information") 106 | while True: 107 | data = await self.request_json_async(url) 108 | if data: 109 | logger.info("Sampling parameters received: %s", data) 110 | return data 111 | else: 112 | logger.debug( 113 | "No sampling parameters available yet. Retrying in 5 seconds..." 
114 | ) 115 | await asyncio.sleep(self.poll_interval) 116 | 117 | async def post_trajectory_async( 118 | self, rollout_id: str, transitions: List[Transition] 119 | ) -> dict: 120 | url = urllib.parse.urljoin(self.endpoint, "report") 121 | payload = self._to_acceptable_trajectory_payload(rollout_id, transitions) 122 | 123 | return await self.post_json_async(url, payload) 124 | 125 | def _to_acceptable_trajectory_payload( 126 | self, rollout_id: str, transitions: List[Transition] 127 | ) -> dict: 128 | """Convert a list of Transition objects to a payload dictionary.""" 129 | return { 130 | "rollout_id": rollout_id, 131 | "reward": sum(t.reward for t in transitions if t.reward is not None), 132 | "trace_list": [ 133 | {"prompt_ids": list(t.state), "response_ids": list(t.action)} 134 | for t in transitions 135 | ], 136 | } 137 | 138 | # Synchronous methods 139 | def request_json(self, url: str) -> dict | None: 140 | """ 141 | Make a GET request to the specified URL and return the JSON response. 142 | 143 | :param url: The URL to request. 144 | :return: The JSON response as a dictionary. 145 | """ 146 | try: 147 | response = requests.get(url, timeout=self.timeout) 148 | response.raise_for_status() 149 | return response.json() 150 | except Exception as e: 151 | logger.debug("GET request failed: %s", e) 152 | return None 153 | 154 | def post_json(self, url: str, payload: dict) -> dict | None: 155 | """ 156 | Make a POST request to the specified URL with the given payload and return the JSON response. 157 | 158 | :param url: The URL to post to. 159 | :param payload: The data to send in the POST request. 160 | :return: The JSON response as a dictionary. 161 | """ 162 | try: 163 | response = requests.post(url, json=payload, timeout=self.timeout) 164 | response.raise_for_status() 165 | return response.json() 166 | except Exception as e: 167 | logger.debug("POST request failed: %s", e) 168 | return None 169 | 170 | def poll_next_task(self) -> TaskData: 171 | """Poll the server for the next task data sample until it is available. 172 | 173 | Returns a task data dict which has the same format as the dataset sample. 174 | It has an extra `rollout_id` field, which is a unique identifier for the task, 175 | and an `is_train` field indicating whether the task is for training or evaluation. 176 | """ 177 | url = urllib.parse.urljoin(self.endpoint, "next_data_sample") 178 | while True: 179 | data = self.request_json(url) 180 | if data and data.get("is_available"): 181 | task_data = data["data"] 182 | self.task_count += 1 183 | logger.info("[Task %d Received] %s", self.task_count, task_data) 184 | return task_data 185 | else: 186 | logger.debug("No task available yet. Retrying in %s seconds...", self.poll_interval) 187 | time.sleep(self.poll_interval) 188 | 189 | def poll_sampling_parameters(self) -> SamplingParameters: 190 | """Poll the server for sampling parameters until they are available. 191 | 192 | The client agent is expected to respect the designated sampling parameters 193 | when calling the LLMs, to maximize the power of the algorithms. 194 | """ 195 | url = urllib.parse.urljoin(self.endpoint, "train_information") 196 | while True: 197 | data = self.request_json(url) 198 | if data: 199 | logger.info("Sampling parameters received: %s", data) 200 | return data 201 | else: 202 | logger.debug( 203 | "No sampling parameters available yet. 
Retrying in %s seconds...", 204 | self.poll_interval 205 | ) 206 | time.sleep(self.poll_interval) 207 | 208 | def post_trajectory( 209 | self, rollout_id: str, transitions: List[Transition] 210 | ) -> dict: 211 | """Post a trajectory to the server synchronously. 212 | 213 | :param rollout_id: The unique identifier for the rollout. 214 | :param transitions: List of transitions in the trajectory. 215 | :return: The server response as a dictionary. 216 | """ 217 | url = urllib.parse.urljoin(self.endpoint, "report") 218 | payload = self._to_acceptable_trajectory_payload(rollout_id, transitions) 219 | return self.post_json(url, payload) 220 | -------------------------------------------------------------------------------- /a2ls/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.functional import linear 6 | 7 | 8 | def tie_weights(src, trg): 9 | assert type(src) == type(trg) 10 | trg.weight = src.weight 11 | trg.bias = src.bias 12 | 13 | 14 | # for 84 x 84 inputs 15 | OUT_DIM = {2: 39, 4: 35, 6: 31} 16 | # for 64 x 64 inputs 17 | OUT_DIM_64 = {2: 29, 4: 25, 6: 21} 18 | 19 | 20 | class PixelEncoder(nn.Module): 21 | """Convolutional encoder of pixels observations.""" 22 | def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32,encoder_hidden_size=256 ,output_logits=False,*args): 23 | super().__init__() 24 | 25 | assert len(obs_shape) == 3 26 | self.obs_shape = obs_shape 27 | self.feature_dim = feature_dim 28 | self.num_layers = num_layers 29 | 30 | self.convs = nn.ModuleList( 31 | [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)] 32 | ) 33 | for i in range(num_layers - 1): 34 | self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1)) 35 | 36 | out_dim = OUT_DIM_64[num_layers] if obs_shape[-1] == 64 else OUT_DIM[num_layers] 37 | self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim) 38 | self.ln = nn.LayerNorm(self.feature_dim) 39 | 40 | self.outputs = dict() 41 | self.output_logits = output_logits 42 | 43 | def reparameterize(self, mu, logstd): 44 | std = torch.exp(logstd) 45 | eps = torch.randn_like(std) 46 | return mu + eps * std 47 | 48 | def forward_conv(self, obs): 49 | obs = obs / 255. 
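# pixel observations are scaled from [0, 255] to [0, 1] before entering the conv stack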
50 | self.outputs['obs'] = obs 51 | 52 | conv = torch.relu(self.convs[0](obs)) 53 | self.outputs['conv1'] = conv 54 | 55 | for i in range(1, self.num_layers): 56 | conv = torch.relu(self.convs[i](conv)) 57 | self.outputs['conv%s' % (i + 1)] = conv 58 | 59 | h = conv.view(conv.size(0), -1) 60 | return h 61 | 62 | def forward(self, obs, detach=False): 63 | h = self.forward_conv(obs) 64 | 65 | if detach: 66 | h = h.detach() 67 | 68 | h_fc = self.fc(h) 69 | self.outputs['fc'] = h_fc 70 | 71 | h_norm = self.ln(h_fc) 72 | self.outputs['ln'] = h_norm 73 | 74 | if self.output_logits: 75 | out = h_norm 76 | else: 77 | out = torch.tanh(h_norm) 78 | self.outputs['tanh'] = out 79 | 80 | return out 81 | 82 | def copy_conv_weights_from(self, source): 83 | """Tie convolutional layers""" 84 | # only tie conv layers 85 | for i in range(self.num_layers): 86 | tie_weights(src=source.convs[i], trg=self.convs[i]) 87 | 88 | def log(self, L, step, log_freq): 89 | if step % log_freq != 0: 90 | return 91 | 92 | for k, v in self.outputs.items(): 93 | L.log_histogram('train_encoder/%s_hist' % k, v, step) 94 | if len(v.shape) > 2: 95 | L.log_image('train_encoder/%s_img' % k, v[0], step) 96 | 97 | for i in range(self.num_layers): 98 | L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step) 99 | L.log_param('train_encoder/fc', self.fc, step) 100 | L.log_param('train_encoder/ln', self.ln, step) 101 | 102 | 103 | class IdentityEncoder(nn.Module): 104 | def __init__(self, obs_shape, feature_dim, num_layers, num_filters,*args): 105 | super().__init__() 106 | 107 | assert len(obs_shape) == 1 108 | self.feature_dim = obs_shape[0] 109 | 110 | 111 | def forward(self, obs, detach=False): 112 | return obs 113 | 114 | def copy_conv_weights_from(self, source): 115 | pass 116 | 117 | def log(self, L, step, log_freq): 118 | pass 119 | 120 | 121 | class MlpEncoder(nn.Module): 122 | def __init__(self, obs_shape, feature_dim, num_layers, num_filters, encoder_hidden_size, output_logits=False,*args): 123 | super().__init__() 124 | 125 | assert len(obs_shape) == 1 126 | 127 | self.obs_shape = obs_shape 128 | self.feature_dim = feature_dim 129 | self.num_layers = num_layers 130 | self.linears = nn.ModuleList( 131 | [nn.Linear(obs_shape[0], encoder_hidden_size)] 132 | ) 133 | for i in range(num_layers - 1): 134 | self.linears.append(nn.Linear(encoder_hidden_size, encoder_hidden_size)) 135 | 136 | 137 | self.fc = nn.Linear(encoder_hidden_size, self.feature_dim) 138 | self.ln = nn.LayerNorm(self.feature_dim) 139 | 140 | self.outputs = dict() 141 | self.output_logits = output_logits 142 | 143 | def forward_linear(self, obs): 144 | 145 | self.outputs['obs'] = obs 146 | 147 | linear = torch.relu(self.linears[0](obs)) 148 | self.outputs['linear1'] = linear 149 | 150 | for i in range(1, self.num_layers): 151 | linear = torch.relu(self.linears[i](linear)) 152 | self.outputs['linear%s' % (i + 1)] = linear 153 | 154 | h = linear 155 | return h 156 | 157 | def forward(self, obs, detach=False): 158 | h = self.forward_linear(obs) 159 | 160 | if detach: 161 | h = h.detach() 162 | 163 | h_fc = self.fc(h) 164 | self.outputs['fc'] = h_fc 165 | 166 | h_norm = self.ln(h_fc) 167 | self.outputs['ln'] = h_norm 168 | 169 | if self.output_logits: 170 | out = h_norm 171 | else: 172 | out = torch.tanh(h_norm) 173 | self.outputs['tanh'] = out 174 | 175 | return out 176 | 177 | def copy_conv_weights_from(self, source): 178 | pass 179 | 180 | def log(self, L, step, log_freq): 181 | pass 182 | 183 | 184 | class OfeEncoder(nn.Module): 185 | def __init__(self, 
obs_shape, feature_dim, num_layers, num_filters, encoder_hidden_size, output_logits=False,*args): 186 | super().__init__() 187 | 188 | assert len(obs_shape) == 1 189 | 190 | self.obs_shape = obs_shape 191 | self.feature_dim = num_layers * encoder_hidden_size + obs_shape[0] 192 | self.num_layers = num_layers 193 | self.linears = nn.ModuleList( 194 | [nn.Linear(obs_shape[0], encoder_hidden_size)] 195 | ) 196 | for i in range(num_layers - 1): 197 | self.linears.append(nn.Linear( ( obs_shape[0] + (i+1) * encoder_hidden_size ), encoder_hidden_size)) # ofenet structure 198 | 199 | 200 | 201 | self.fc = nn.Linear(self.feature_dim, self.feature_dim) 202 | self.ln = nn.LayerNorm(self.feature_dim) 203 | 204 | self.outputs = dict() 205 | self.output_logits = output_logits 206 | 207 | def forward_linear(self, obs): 208 | 209 | self.outputs['obs'] = obs 210 | 211 | linear = torch.cat( [obs, torch.relu(self.linears[0](obs))], axis =1) 212 | self.outputs['linear1'] = linear 213 | #import ipdb; ipdb.set_trace() 214 | for i in range(1, self.num_layers): 215 | linear = torch.cat( [linear, torch.relu(self.linears[i](linear))], axis=1) 216 | self.outputs['linear%s' % (i + 1)] = linear 217 | 218 | h = linear 219 | return h 220 | 221 | def forward(self, obs, detach=False): 222 | h = self.forward_linear(obs) 223 | 224 | if detach: 225 | h = h.detach() 226 | 227 | h_fc = self.fc(h) 228 | self.outputs['fc'] = h_fc 229 | 230 | h_norm = self.ln(h_fc) 231 | self.outputs['ln'] = h_norm 232 | 233 | if self.output_logits: 234 | out = h_norm 235 | else: 236 | out = torch.tanh(h_norm) 237 | self.outputs['tanh'] = out 238 | 239 | return out 240 | 241 | def copy_conv_weights_from(self, source): 242 | pass 243 | 244 | def log(self, L, step, log_freq): 245 | pass 246 | 247 | 248 | class ActionOrRewardEncoder(nn.Module): 249 | def __init__(self, input_dim, feature_dim, num_layers, *args): 250 | super().__init__() 251 | assert type(input_dim) == int 252 | assert num_layers == 1 253 | self.feature_dim = feature_dim 254 | 255 | self.forward_linear_layers = nn.Sequential(nn.Linear(input_dim,feature_dim)) 256 | 257 | self.fc = nn.Linear(self.feature_dim, self.feature_dim) 258 | self.ln = nn.LayerNorm(self.feature_dim) 259 | self.outputs = dict() 260 | 261 | def forward(self, obs, detach=False): 262 | h = self.forward_linear_layers(obs) 263 | self.outputs['obs'] = obs 264 | self.outputs['linear'] = h 265 | 266 | if detach: 267 | h = h.detach() 268 | 269 | h_fc = self.fc(h) 270 | self.outputs['fc'] = h_fc 271 | 272 | h_norm = self.ln(h_fc) 273 | self.outputs['ln'] = h_norm 274 | 275 | out = torch.tanh(h_norm) 276 | return out 277 | 278 | 279 | def copy_linear_weights_from(self, source): 280 | pass 281 | 282 | def log(self, L, step, log_freq): 283 | pass 284 | 285 | 286 | _AVAILABLE_ENCODERS = {'pixel': PixelEncoder, 'identity': IdentityEncoder, 'mlp': MlpEncoder, 'ofe': OfeEncoder} 287 | 288 | 289 | def make_encoder( 290 | encoder_type, obs_shape, feature_dim, num_layers, num_filters, encoder_hidden_size, output_logits=False 291 | ): 292 | assert encoder_type in _AVAILABLE_ENCODERS 293 | return _AVAILABLE_ENCODERS[encoder_type]( 294 | obs_shape, feature_dim, num_layers, num_filters, encoder_hidden_size, output_logits 295 | ) 296 | -------------------------------------------------------------------------------- /agentlightning/examples/spider/spider_eval/exec_eval.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | # The evaluation code is from 
https://github.com/taoyds/test-suite-sql-eval 3 | 4 | import os 5 | import re 6 | import asyncio 7 | import sqlite3 8 | import threading 9 | from typing import Tuple, Any, List, Set 10 | from itertools import product 11 | from collections import defaultdict 12 | import tqdm 13 | import random 14 | from .parse import get_all_preds_for_execution, remove_distinct 15 | import time 16 | import pickle as pkl 17 | import subprocess 18 | from itertools import chain 19 | 20 | 21 | 22 | threadLock = threading.Lock() 23 | TIMEOUT = 60 24 | EXEC_TMP_DIR = '/tmp/' 25 | 26 | def permute_tuple(element: Tuple, perm: Tuple) -> Tuple: 27 | assert len(element) == len(perm) 28 | return tuple([element[i] for i in perm]) 29 | 30 | 31 | def unorder_row(row: Tuple) -> Tuple: 32 | return tuple(sorted(row, key=lambda x: str(x) + str(type(x)))) 33 | 34 | 35 | # unorder each row in the table 36 | # [result_1 and result_2 has the same bag of unordered row] 37 | # is a necessary condition of 38 | # [result_1 and result_2 are equivalent in denotation] 39 | def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool: 40 | s1 = [unorder_row(row) for row in result1] 41 | s2 = [unorder_row(row) for row in result2] 42 | if order_matters: 43 | return s1 == s2 44 | else: 45 | return set(s1) == set(s2) 46 | 47 | 48 | # return whether two bag of relations are equivalent 49 | def multiset_eq(l1: List, l2: List) -> bool: 50 | if len(l1) != len(l2): 51 | return False 52 | d = defaultdict(int) 53 | for e in l1: 54 | d[e] = d[e] + 1 55 | for e in l2: 56 | d[e] = d[e] - 1 57 | if d[e] < 0: 58 | return False 59 | return True 60 | 61 | 62 | def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]): 63 | num_cols = len(result2[0]) 64 | perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)] 65 | if num_cols <= 3: 66 | return product(*perm_constraints) 67 | 68 | # we sample 20 rows and constrain the space of permutations 69 | for _ in range(20): 70 | random_tab2_row = random.choice(result2) 71 | 72 | for tab1_col in range(num_cols): 73 | for tab2_col in set(perm_constraints[tab1_col]): 74 | if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]: 75 | perm_constraints[tab1_col].remove(tab2_col) 76 | return product(*perm_constraints) 77 | 78 | 79 | # check whether two denotations are correct 80 | def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool: 81 | if len(result1) == 0 and len(result2) == 0: 82 | return True 83 | 84 | # if length is not the same, then they are definitely different bag of rows 85 | if len(result1) != len(result2): 86 | return False 87 | 88 | num_cols = len(result1[0]) 89 | 90 | # if the results do not have the same number of columns, they are different 91 | if len(result2[0]) != num_cols: 92 | return False 93 | 94 | # unorder each row and compare whether the denotation is the same 95 | # this can already find most pair of denotations that are different 96 | if not quick_rej(result1, result2, order_matters): 97 | return False 98 | 99 | # the rest of the problem is in fact more complicated than one might think 100 | # we want to find a permutation of column order and a permutation of row order, 101 | # s.t. 
result_1 is the same as result_2 102 | # we return true if we can find such column & row permutations 103 | # and false if we cannot 104 | tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)] 105 | 106 | # on a high level, we enumerate all possible column permutations that might make result_1 == result_2 107 | # we decrease the size of the column permutation space by the function get_constraint_permutation 108 | # if one of the permutation make result_1, result_2 equivalent, then they are equivalent 109 | for perm in get_constraint_permutation(tab1_sets_by_columns, result2): 110 | if len(perm) != len(set(perm)): 111 | continue 112 | if num_cols == 1: 113 | result2_perm = result2 114 | else: 115 | result2_perm = [permute_tuple(element, perm) for element in result2] 116 | if order_matters: 117 | if result1 == result2_perm: 118 | return True 119 | else: 120 | # in fact the first condition must hold if the second condition holds 121 | # but the first is way more efficient implementation-wise 122 | # and we use it to quickly reject impossible candidates 123 | if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm): 124 | return True 125 | return False 126 | 127 | 128 | def replace_cur_year(query: str) -> str: 129 | return re.sub( 130 | "YEAR\s*\(\s*CURDATE\s*\(\s*\)\s*\)\s*", "2020", query, flags=re.IGNORECASE 131 | ) 132 | 133 | 134 | # get the database cursor for a sqlite database path 135 | def get_cursor_from_path(sqlite_path: str): 136 | try: 137 | if not os.path.exists(sqlite_path): 138 | print("Openning a new connection %s" % sqlite_path) 139 | connection = sqlite3.connect(sqlite_path) 140 | except Exception as e: 141 | print(sqlite_path) 142 | raise e 143 | connection.text_factory = lambda b: b.decode(errors="ignore") 144 | cursor = connection.cursor() 145 | return cursor 146 | 147 | 148 | async def exec_on_db_(sqlite_path: str, query: str) -> Tuple[str, Any]: 149 | query = replace_cur_year(query) 150 | cursor = get_cursor_from_path(sqlite_path) 151 | try: 152 | cursor.execute(query) 153 | result = cursor.fetchall() 154 | cursor.close() 155 | cursor.connection.close() 156 | return "result", result 157 | except Exception as e: 158 | cursor.close() 159 | cursor.connection.close() 160 | return "exception", e 161 | 162 | async def exec_on_db( 163 | sqlite_path: str, query: str, process_id: str = "", timeout: int = TIMEOUT 164 | ) -> Tuple[str, Any]: 165 | try: 166 | return await asyncio.wait_for(exec_on_db_(sqlite_path, query), timeout) 167 | except asyncio.TimeoutError: 168 | return ('exception', TimeoutError) 169 | except Exception as e: 170 | return ("exception", e) 171 | 172 | 173 | # postprocess the model predictions to avoid execution errors 174 | # e.g. removing spaces between ">" and "=" 175 | def postprocess(query: str) -> str: 176 | query = query.replace('> =', '>=').replace('< =', '<=').replace('! =', '!=') 177 | return query 178 | 179 | 180 | # approximate whether p_str and g_str are semantically equivalent 181 | # db is the database path 182 | # we are going to evaluate whether they are equivalent in all the databases 183 | # that are in the same directory as db 184 | # 0 if denotationally equivalent 185 | # 1 otherwise 186 | # the meaning of each auxillary argument can be seen in the parser definition in evaluation.py 187 | def eval_exec_match(db: str, p_str: str, g_str: str, plug_value: bool, keep_distinct: bool, progress_bar_for_each_datapoint: bool) -> int: 188 | # post-process the prediction. 189 | # e.g. 
removing spaces between ">" and "=" 190 | p_str, g_str = postprocess(p_str), postprocess(g_str) 191 | if not keep_distinct: 192 | p_str = remove_distinct(p_str) 193 | g_str = remove_distinct(g_str) 194 | 195 | # we decide whether two denotations are equivalent based on "bag semantics" 196 | # https://courses.cs.washington.edu/courses/cse444/10sp/lectures/lecture16.pdf 197 | # if there is order by in query, then we assume order of the rows matter 198 | # order by might also be used to find the max/min instead of sorting, 199 | # but in that case the result mostly only contains one row and hence order_matters does not make a difference 200 | order_matters = 'order by' in g_str.lower() 201 | 202 | # find all databases in the same directory 203 | db_dir = os.path.dirname(db) 204 | db_paths = [os.path.join(db_dir, basename) for basename in os.listdir(db_dir) if '.sqlite' in basename] 205 | 206 | preds = [p_str] 207 | # if plug in value (i.e. we do not consider value prediction correctness) 208 | # enumerate all ways to plug in values in the gold query to the model predictions 209 | # otherwise, we only evaluate the predicted query with its own value prediction 210 | if plug_value: 211 | _, preds = get_all_preds_for_execution(g_str, p_str) 212 | # we did not add this line in our EMNLP work 213 | # this reduces "false negatives" when value is substituted 214 | preds = chain([p_str], preds) 215 | 216 | for pred in preds: 217 | 218 | pred_passes = 1 219 | # compare the gold and predicted denotations on each database in the directory 220 | # wrap with progress bar if required 221 | if progress_bar_for_each_datapoint: 222 | ranger = tqdm.tqdm(db_paths) 223 | else: 224 | ranger = db_paths 225 | 226 | for db_path in ranger: 227 | g_flag, g_denotation = asyncio.run(exec_on_db(db_path, g_str)) 228 | p_flag, p_denotation = asyncio.run(exec_on_db(db_path, pred)) 229 | 230 | # we should expect the gold to be succesfully executed on the database 231 | assert g_flag != 'exception', 'gold query %s has error on database file %s' % (g_str, db_path) 232 | 233 | # wrong if execution fails 234 | if p_flag == 'exception': 235 | pred_passes = 0 236 | 237 | # if denotations are not equivalent, the prediction must be wrong 238 | elif not result_eq(g_denotation, p_denotation, order_matters=order_matters): 239 | pred_passes = 0 240 | if pred_passes == 0: 241 | break 242 | 243 | # the model prediction has the same denotation as the gold for all databases 244 | if pred_passes == 1: 245 | return 1 246 | 247 | # none of the predictions passed 248 | return 0 249 | --------------------------------------------------------------------------------
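A minimal usage sketch for the execution-accuracy check above. The database path and the two SQL queries are hypothetical, chosen only to illustrate the call, and it assumes spider_eval is importable as a package:

from spider_eval.exec_eval import eval_exec_match

db = "data/spider/database/concert_singer/concert_singer.sqlite"  # hypothetical path
gold = "SELECT count(*) FROM singer"                               # hypothetical gold query
pred = "SELECT count(singer_id) FROM singer"                       # hypothetical prediction

# Returns 1 if the prediction is denotationally equivalent to the gold query on every
# .sqlite database in the same directory as `db`, and 0 otherwise.
match = eval_exec_match(
    db=db,
    p_str=pred,
    g_str=gold,
    plug_value=False,       # evaluate the prediction with its own literal values
    keep_distinct=False,    # strip DISTINCT before comparing denotations
    progress_bar_for_each_datapoint=False,
)
print("execution match:", match)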