├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── __init__.py ├── algo ├── APG.py ├── MASHAC.py ├── PPO.py ├── SHAC.py ├── __init__.py ├── buffer.py └── dreamerv3 │ ├── __init__.py │ ├── models │ ├── __init__.py │ ├── agent.py │ ├── blocks.py │ ├── diffusion │ │ ├── __init__.py │ │ ├── denoiser.py │ │ ├── diffusion_sampler.py │ │ └── inner_model.py │ └── state_predictor.py │ ├── wmenv │ ├── __init__.py │ ├── replaybuffer.py │ ├── utils.py │ └── world_state_env.py │ └── world.py ├── cfg ├── algo │ ├── apg.yaml │ ├── apg_sto.yaml │ ├── appo.yaml │ ├── mashac.yaml │ ├── ppo.yaml │ ├── sha2c.yaml │ ├── shac.yaml │ └── world.yaml ├── config_test.yaml ├── config_train.yaml ├── dynamics │ ├── pmc.yaml │ ├── pmd.yaml │ └── quad.yaml ├── env │ ├── imu │ │ └── default_imu.yaml │ ├── mapc.yaml │ ├── oa.yaml │ ├── oa_small.yaml │ ├── obstacles │ │ ├── outdoor.yaml │ │ └── small_room.yaml │ ├── pc.yaml │ ├── racing.yaml │ ├── randomizer │ │ └── default_randomizer.yaml │ └── render │ │ ├── oa_render.yaml │ │ └── pc_render.yaml ├── hydra │ ├── help │ │ ├── test_help.yaml │ │ └── train_help.yaml │ └── sweeper │ │ └── optuna_sweep.yaml ├── logger │ ├── tensorboard.yaml │ └── wandb.yaml ├── network │ ├── cnn.yaml │ ├── mlp.yaml │ ├── rcnn.yaml │ └── rnn.yaml └── sensor │ ├── camera.yaml │ ├── lidar.yaml │ └── relpos.yaml ├── dynamics ├── __init__.py ├── base_dynamics.py ├── controller.py ├── pointmass.py └── quadrotor.py ├── env ├── __init__.py ├── base_env.py ├── obstacle_avoidance.py ├── position_control.py ├── position_control_multi_agent.py └── racing.py ├── network ├── __init__.py ├── agents.py ├── multiagents.py └── networks.py ├── pyproject.toml ├── requirements.txt ├── script ├── __init__.py ├── export.py ├── fps_test.py ├── test.py └── train.py ├── setup.py └── utils ├── __init__.py ├── assets.py ├── exporter.py ├── logger.py ├── math.py ├── nn.py ├── randomizer.py ├── render.py ├── runner.py └── sensor.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | run.sh 3 | outputs/ 4 | imgui.ini 5 | diffaero.egg-info/ 6 | *.log -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "deploy"] 2 | path = deploy 3 | url = git@github.com:zxh0916/diffaero-deploy.git 4 | branch = master 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2025, State Key Lab of Autonomous Intelligent Unmanned Systems, Beijing Institute of Technology 4 | Copyright (c) 2025, Zhongguancun Academy 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffAero: A GPU-Accelerated Differentiable Simulation Framework for Efficient Quadrotor Policy Learning 2 | 3 | This repository contains the code of the paper: [DiffAero: A GPU-Accelerated Differentiable Simulation Framework for Efficient Quadrotor Policy Learning](https://arxiv.org/abs/2509.10247) 4 | 5 | - [DiffAero: A GPU-Accelerated Differentiable Simulation Framework for Efficient Quadrotor Policy Learning](#diffaero-a-gpu-accelerated-differentiable-simulation-framework-for-efficient-quadrotor-policy-learning) 6 | - [Introduction](#introduction) 7 | - [Features](#features) 8 | - [Environments](#environments) 9 | - [Learning algorithms](#learning-algorithms) 10 | - [Dynamical models](#dynamical-models) 11 | - [Sensors](#sensors) 12 | - [Installation](#installation) 13 | - [System requirements](#system-requirements) 14 | - [Installing the DiffAero](#installing-the-diffaero) 15 | - [Usage](#usage) 16 | - [Basic usage](#basic-usage) 17 | - [Visualization](#visualization) 18 | - [Visualization with taichi GUI](#visualization-with-taichi-gui) 19 | - [Visualize the depth camera and LiDAR data](#visualize-the-depth-camera-and-lidar-data) 20 | - [Record First-Person View Videos](#record-first-person-view-videos) 21 | - [Sweep across multiple configurations](#sweep-across-multiple-configurations) 22 | - [Sweep across multiple GPUs in parallel](#sweep-across-multiple-gpus-in-parallel) 23 | - [Automatic Hyperparameter Tuning](#automatic-hyperparameter-tuning) 24 | - [Deploy](#deploy) 25 | - [TODO-List](#todo-list) 26 | - [Citation](#citation) 27 | 28 | ## Introduction 29 | 30 | DiffAero is a GPU-accelerated differentiable quadrotor simulator that parallelizes both physics and rendering. It achieves orders-of-magnitude performance improvements over existing platforms with little VRAM consumption. It provides a modular and extensible framework supporting four differentiable dynamics models, three sensor modalities, and three flight tasks. Its PyTorch-based interface unifies four learning formulations and three learning paradigms. This flexibility enables DiffAero to serve as a benchmark for learning algorithms and allows researchers to investigate a wide range of problems, from differentiable policy learning to multi-agent coordination. Users can combine different components almost arbitrarily to initiate a custom-configured training process with minimal effort. 31 | 32 | ## Features 33 | 34 | | Module | Currently Supported | 35 | |----------------|-------------------------------------------------------------------------| 36 | | Tasks | Position Control, Obstacle Avoidance, Racing | 37 | | Differential Learning Algorithms | BPTT, SHAC, SHA2C | 38 | | Reinforcement Learning Algorithms | PPO, Dreamer V3 | 39 | | Sensors | Depth Camera, LiDAR | 40 | | Dynamic Models | Full Quadrotor, Continuous Point-Mass, Discrete Point-Mass | 41 | 42 | ### Environments 43 | 44 | DiffAero now supports three flight tasks: 45 | - **Position Control** (`env=pc`): The goal is to navigate to and hover on the specified target positions from random initial positions, without colliding with other agents. 46 | - **Obstacle Avoidance** (`env=oa`): The goal is to navigate to and hover on target positions while avoiding collision with environmental obstacles and other quadrotors, given exteroceptive informations: 47 | - Relative positions of obstacles w.r.t. the quadrotor, or 48 | - Image from the depth camera attached to the quadrotor, or 49 | - Ray distance from the LiDAR attached to the quadrotor. 50 | - **Racing** (`env=racing`): The goal is to navigate through a series of gates in the shortest time, without colliding with the gates. 51 | 52 | ### Learning algorithms 53 | 54 | We have implemented several learning algorithms, including RL algorithms and algorithms that exploit the differentiability of the simulator: 55 | 56 | - **Reinforcement Learning algorithms**: 57 | - **PPO** (`algo=ppo`): [Proximal Policy Optimization](https://arxiv.org/abs/1707.06347) 58 | - **Dreamer V3** (`algo=world`): [Mastering Diverse Domains through World Models](http://arxiv.org/abs/2301.04104) 59 | 60 | - **Differential algorithms**: 61 | - **BPTT** (`algo=apg(_sto)`): Direct back-propagation through time, supports deterministic policy (`algo=apg`) and stochastic policy (`algo=apg_sto`) 62 | - **SHAC** (`algo=shac`): [Accelerated Policy Learning with Parallel Differentiable Simulation](http://arxiv.org/abs/2204.07137) 63 | - **SHA2C** (`algo=sha2c`): Short-Horizon Asymmetric Actor-Critic 64 | 65 | ### Dynamical models 66 | 67 | We have implemented four types of dynamic models for the quadrotor: 68 | - **Full Quadrotor Dynamics** (`dynamics=quad`): Simulates the full dynamics of the quadrotor, including the aerodynamic effect, as described in [Efficient and Robust Time-Optimal Trajectory Planning and Control for Agile Quadrotor Flight](http://arxiv.org/abs/2305.02772). 69 | - **(TODO) Simplified Quadrotor Dynamics** (`dynamics=simple`): Simulates the attitude dynamics of the quadrotor, but without considering body rate dynamics, as described in [Learning Quadrotor Control From Visual Features Using Differentiable Simulation](http://arxiv.org/abs/2410.15979). 70 | - **Discrete Point Mass Dynamics** (`dynamics=pmd`): Simulates the quadrotor as a point mass, ignoring its pose for faster simulation and smoother gradient flow, as described in [Back to Newton's Laws: Learning Vision-based Agile Flight via Differentiable Physics](http://arxiv.org/abs/2407.10648). 71 | - **Continuous Point Mass Dynamics** (`dynamics=pmc`): Simulates the quadrotor as a point mass, ignoring its pose, but with continuous time integration. 72 | 73 | ### Sensors 74 | DiffAero supports two types of exteroceptive sensors: 75 | - **Depth Camera** (`sensor=camera`): Provides depth information about the environment. 76 | - **LiDAR** (`sensor=lidar`): Provides distance measurements to nearby obstacles. 77 | 78 | ## Installation 79 | 80 | ### System requirements 81 | 82 | - System: Ubuntu. 83 | - Pytorch 2.x. 84 | 85 | ### Installing the DiffAero 86 | 87 | Clone this repo and install the python package: 88 | 89 | ```bash 90 | git clone https://github.com/zxh0916/diffaero.git 91 | cd diffaero && pip install -e . 92 | ``` 93 | 94 | ## Usage 95 | 96 | ### Basic usage 97 | Under the repo's root directory, run the following command to train a policy (`[a,b,c]` means `a` or `b` or `c`, etc.): 98 | 99 | ```bash 100 | python script/train.py env=[pc,oa,racing] algo=[apg,apg_sto,shac,sha2c,ppo,world] 101 | ``` 102 | 103 | Note that `env=[pc,oa]` means use `env=pc` or `env=oa`, etc. 104 | 105 | Once the training is done, run the following command to test the trained policy: 106 | 107 | ```bash 108 | python script/test.py env=[pc,oa,racing] checkpoint=/absolute/path/to/checkpoints/directory use_training_cfg=True n_envs=64 109 | ``` 110 | 111 | To list all configuration choices, run: 112 | 113 | ```bash 114 | python script/train.py -h 115 | ``` 116 | 117 | To enable tab-completion in command line, run: 118 | ```bash 119 | eval "$(python script/train.py -sc install=bash)" 120 | ``` 121 | 122 | ### Visualization 123 | 124 | #### Visualization with taichi GUI 125 | 126 | DiffAero supports real-time visualization using [taichi GGUI system](https://docs.taichi-lang.org/docs/ggui). To enable the GUI, set `headless=False` in the training or testing command. Note that the taichi GUI can only be used with GPU0 (`device=0`) on workstation with multiple GPUs. For example, to visualize the training process of the Position Control task, run: 127 | ```bash 128 | python script/train.py env=pc headless=False device=0 129 | ``` 130 | 131 | #### Visualize the depth camera and LiDAR data 132 | 133 | To visualize the depth camera and LiDAR data in the Obstacle Avoidance task, set `display_image=True` in the training or testing command. For example, to visualize the depth camera data during testing, run: 134 | ```bash 135 | python script/train.py env=oa display_image=True 136 | ``` 137 | 138 | #### Record First-Person View Videos 139 | 140 | The Obstacle Avoidance task supports recording first-person view videos from the quadrotor's first-person perspective. To record videos, set `record_video=True` in the testing command: 141 | ```bash 142 | python script/train.py env=oa checkpoint=/absolute/path/to/checkpoints/directory use_training_cfg=True n_envs=16 record_video=True 143 | ``` 144 | The recorded videos will be saved in the `outputs/test/YYYY-MM-DD/HH-MM/video` directory under the repo's root directory.a 145 | 146 | ### Sweep across multiple configurations 147 | 148 | DiffAero supports sweeping across multiple configurations using [hydra](https://hydra.cc). For example, you can specify multiple values to one argument by separating them with commas, and hydra will automatically generate all combinations of the specified values. For example, to sweep across different environments and algorithms, you can run: 149 | ```bash 150 | python script/train.py -m env=pc,oa,racing algo=apg,apg_sto,shac,sha2c,ppo,world # generate 3x6=18 combinations, executed sequentially 151 | ``` 152 | 153 | #### Sweep across multiple GPUs in parallel 154 | 155 | For workstations with multiple GPUs, you can specify multiple devices by setting `device` to string containing multiple GPU indices and setting `n_jobs` greater than 1 to sweep through configuation combinations in parallel using [hydra-joblib-launcher](https://hydra.cc/docs/plugins/joblib_launcher/) and [joblib](https://joblib.readthedocs.io/en/stable/). For example, to use the first 4 GPUs (GPU0, GPU1, GPU2, GPU3), run: 156 | ```bash 157 | # generate 2x2x3=12 combinations, executed in parallel on 4 GPUs, with 3 jobs each 158 | python script/train.py -m env=pc,oa algo=apg_sto,shac algo.l_rollout=16,32,64 n_jobs=4 device="0123" 159 | ``` 160 | 161 | #### Automatic Hyperparameter Tuning 162 | 163 | DiffAero supports automatic hyperparameter tuning using [hydra-optuna-sweeper](https://hydra.cc/docs/plugins/optuna_sweeper/) and [Optuna](https://optuna.org/). To search for the hyperparameter configuration that maximizes the success rate, uncomment the `override hydra/sweeper: optuna_sweep` line in `cfg/config_train.yaml`, specify the hyperparameters to be optimized in the `cfg/hydra/sweeper/optuna_sweep.yaml` file, and run 164 | ```python 165 | python script/train.py -m 166 | ``` 167 | This feature can be combined with multi-device parallel sweep to further speed up the hyperparameter search. 168 | 169 | ## Deploy 170 | 171 | If you want to evaluate and deploy your trained policy in Gazebo or in real world, please refer to this repository (Coming soon). 172 | 173 | ## TODO-List 174 | - [ ] Add simplified quadrotor dynamics model. 175 | - [ ] Add support to train policies with [rsl_rl](https://github.com/leggedrobotics/rsl_rl) (maybe). 176 | - [ ] Update the LiDAR sensor to be more realistic. 177 | 178 | ## Citation 179 | 180 | If you find DiffAero useful in your research, please consider citing: 181 | 182 | ```bibtex 183 | @misc{zhang2025diffaero, 184 | title={DiffAero: A GPU-Accelerated Differentiable Simulation Framework for Efficient Quadrotor Policy Learning}, 185 | author={Xinhong Zhang and Runqing Wang and Yunfan Ren and Jian Sun and Hao Fang and Jie Chen and Gang Wang}, 186 | year={2025}, 187 | eprint={2509.10247}, 188 | archivePrefix={arXiv}, 189 | primaryClass={cs.RO}, 190 | url={https://arxiv.org/abs/2509.10247}, 191 | } 192 | ``` 193 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | DIFFAERO_ROOT_DIR = os.path.dirname(os.path.realpath(__file__)) 4 | DIFFAERO_ENVS_DIR = os.path.join(DIFFAERO_ROOT_DIR, 'envs') 5 | 6 | print("The root dir of DiffAero:", DIFFAERO_ROOT_DIR) 7 | 8 | from . import env 9 | from . import algo 10 | from . import network 11 | from . import script 12 | from . import utils 13 | 14 | __all__ = [ 15 | "env", 16 | "algo", 17 | "network", 18 | "script", 19 | "utils", 20 | ] -------------------------------------------------------------------------------- /algo/APG.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Tuple, Dict, Union, Optional 2 | import os 3 | 4 | from omegaconf import DictConfig 5 | import torch 6 | from torch import Tensor 7 | from tensordict import TensorDict 8 | 9 | from diffaero.network.agents import ( 10 | tensordict2tuple, 11 | DeterministicActor, 12 | StochasticActor) 13 | from diffaero.utils.runner import timeit 14 | from diffaero.utils.exporter import PolicyExporter 15 | 16 | class APG: 17 | def __init__( 18 | self, 19 | cfg: DictConfig, 20 | obs_dim: int, 21 | action_dim: int, 22 | l_rollout: int, 23 | device: torch.device 24 | ): 25 | self.actor = DeterministicActor(cfg.network, obs_dim, action_dim).to(device) 26 | self.optimizer = torch.optim.Adam(self.actor.parameters(), lr=cfg.lr) 27 | self.max_grad_norm: float = cfg.max_grad_norm 28 | self.l_rollout: int = l_rollout 29 | self.actor_loss = torch.zeros(1, device=device) 30 | self.device = device 31 | 32 | def act(self, obs, test=False): 33 | # type: (Union[Tensor, TensorDict], bool) -> Tuple[Tensor, Dict[str, Tensor]] 34 | return self.actor(tensordict2tuple(obs)), {} 35 | 36 | def record_loss(self, loss, policy_info, env_info): 37 | # type: (Tensor, Dict[str, Tensor], Dict[str, Tensor]) -> None 38 | self.actor_loss += loss.mean() 39 | 40 | def update_actor(self): 41 | # type: () -> Tuple[Dict[str, float], Dict[str, float]] 42 | self.actor_loss = self.actor_loss / self.l_rollout 43 | self.optimizer.zero_grad() 44 | self.actor_loss.backward() 45 | grad_norm = sum([p.grad.data.norm().item() ** 2 for p in self.actor.parameters()]) ** 0.5 46 | if self.max_grad_norm is not None: 47 | torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.max_grad_norm) 48 | self.optimizer.step() 49 | actor_loss = self.actor_loss.item() 50 | self.actor_loss = torch.zeros(1, device=self.device) 51 | return {"actor_loss": actor_loss}, {"actor_grad_norm": grad_norm} 52 | 53 | @timeit 54 | def step(self, cfg, env, logger, obs, on_step_cb=None): 55 | for _ in range(cfg.l_rollout): 56 | action, policy_info = self.act(obs) 57 | obs, (loss, reward), terminated, env_info = env.step(env.rescale_action(action)) 58 | self.reset(env_info["reset"]) 59 | self.record_loss(loss, policy_info, env_info) 60 | if on_step_cb is not None: 61 | on_step_cb( 62 | obs=obs, 63 | action=action, 64 | policy_info=policy_info, 65 | env_info=env_info) 66 | 67 | losses, grad_norms = self.update_actor() 68 | self.detach() 69 | return obs, policy_info, env_info, losses, grad_norms 70 | 71 | def save(self, path): 72 | if not os.path.exists(path): 73 | os.makedirs(path) 74 | self.actor.save(path) 75 | 76 | def load(self, path): 77 | self.actor.load(path) 78 | 79 | def reset(self, env_idx: Tensor): 80 | if self.actor.is_rnn_based: 81 | self.actor.reset(env_idx) 82 | 83 | def detach(self): 84 | if self.actor.is_rnn_based: 85 | self.actor.detach() 86 | 87 | @staticmethod 88 | def build(cfg, env, device): 89 | return APG( 90 | cfg=cfg, 91 | obs_dim=env.obs_dim, 92 | action_dim=env.action_dim, 93 | l_rollout=cfg.l_rollout, 94 | device=device) 95 | 96 | def export( 97 | self, 98 | path: str, 99 | export_cfg: DictConfig, 100 | verbose: bool = False, 101 | ): 102 | PolicyExporter(self.actor).export(path, export_cfg, verbose) 103 | 104 | 105 | class APG_stochastic(APG): 106 | def __init__( 107 | self, 108 | cfg: DictConfig, 109 | obs_dim: int, 110 | action_dim: int, 111 | l_rollout: int, 112 | device: torch.device 113 | ): 114 | super().__init__(cfg, obs_dim, action_dim, l_rollout, device) 115 | del self.optimizer; del self.actor 116 | self.actor = StochasticActor(cfg.network, obs_dim, action_dim).to(device) 117 | self.optimizer = torch.optim.Adam(self.actor.parameters(), lr=cfg.lr) 118 | self.entropy_loss = torch.zeros(1, device=device) 119 | self.entropy_weight: float = cfg.entropy_weight 120 | 121 | def act(self, obs, test=False): 122 | # type: (Union[Tensor, TensorDict], bool) -> Tuple[Tensor, Dict[str, Tensor]] 123 | action, sample, logprob, entropy = self.actor(tensordict2tuple(obs), test=test) 124 | return action, {"sample": sample, "logprob": logprob, "entropy": entropy} 125 | 126 | def record_loss(self, loss, policy_info, env_info): 127 | # type: (Tensor, Dict[str, Tensor], Dict[str, Tensor]) -> None 128 | self.actor_loss += loss.mean() 129 | self.entropy_loss -= policy_info["entropy"].mean() 130 | 131 | def update_actor(self): 132 | # type: () -> Tuple[Dict[str, float], Dict[str, float]] 133 | actor_loss = self.actor_loss / self.l_rollout 134 | entropy_loss = self.entropy_loss / self.l_rollout 135 | total_loss = actor_loss + self.entropy_weight * entropy_loss 136 | self.optimizer.zero_grad() 137 | total_loss.backward() 138 | if self.max_grad_norm is not None: 139 | grad_norm = torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.max_grad_norm) 140 | else: 141 | grad_norm = torch.nn.utils.get_total_norm(self.actor.parameters()) 142 | self.optimizer.step() 143 | self.actor_loss = torch.zeros(1, device=self.device) 144 | self.entropy_loss = torch.zeros(1, device=self.device) 145 | return {"actor_loss": actor_loss.mean().item(), "entropy_loss": entropy_loss.mean().item()}, {"actor_grad_norm": grad_norm} 146 | 147 | @staticmethod 148 | def build(cfg, env, device): 149 | return APG_stochastic( 150 | cfg=cfg, 151 | obs_dim=env.obs_dim, 152 | action_dim=env.action_dim, 153 | l_rollout=cfg.l_rollout, 154 | device=device) -------------------------------------------------------------------------------- /algo/MASHAC.py: -------------------------------------------------------------------------------- 1 | from termios import N_SLIP 2 | from typing import Union, Sequence, Tuple, Dict, Optional 3 | from copy import deepcopy 4 | import os 5 | 6 | from omegaconf import DictConfig 7 | import torch 8 | from torch import Tensor 9 | import torch.nn.functional as F 10 | from tensordict import TensorDict 11 | 12 | from diffaero.algo.buffer import RolloutBufferMASHAC, RNNStateBuffer 13 | from diffaero.network.agents import tensordict2tuple 14 | from diffaero.network.multiagents import MAStochasticActorCriticV 15 | from diffaero.utils.runner import timeit 16 | from diffaero.utils.exporter import PolicyExporter 17 | 18 | 19 | class MASHAC: 20 | def __init__( 21 | self, 22 | cfg: DictConfig, 23 | obs_dim: int, 24 | global_state_dim: int, 25 | n_agents: int, 26 | action_dim: int, 27 | n_envs: int, 28 | l_rollout: int, 29 | device: torch.device 30 | ): 31 | self.obs_dim = obs_dim 32 | self.action_dim = action_dim 33 | self.agent = MAStochasticActorCriticV(cfg.network, obs_dim, global_state_dim, action_dim).to(device) 34 | if self.agent.is_rnn_based: 35 | self.rnn_state_buffer = RNNStateBuffer(l_rollout, n_envs, cfg.network.rnn_hidden_dim, cfg.network.rnn_n_layers, device) 36 | self.actor_optim = torch.optim.Adam(self.agent.actor.parameters(), lr=cfg.actor_lr) 37 | self.critic_optim = torch.optim.Adam(self.agent.critic.parameters(), lr=cfg.critic_lr) 38 | self.buffer = RolloutBufferMASHAC(l_rollout, n_envs, obs_dim, global_state_dim, n_agents, device) 39 | self._critic_target = deepcopy(self.agent.critic) 40 | for p in self._critic_target.parameters(): 41 | p.requires_grad_(False) 42 | 43 | self.discount: float = cfg.gamma 44 | self.lmbda: float = cfg.lmbda 45 | self.entropy_weight: float = cfg.entropy_weight 46 | self.actor_grad_norm: float = cfg.actor_grad_norm 47 | self.critic_grad_norm: float = cfg.critic_grad_norm 48 | self.target_update_rate: float = cfg.target_update_rate 49 | self.n_minibatch: int = cfg.n_minibatch 50 | self.n_envs: int = n_envs 51 | self.l_rollout: int = l_rollout 52 | self.device = device 53 | 54 | self.actor_loss = torch.tensor(0., device=self.device) 55 | self.rollout_gamma = torch.ones(self.n_envs, device=self.device) 56 | self.cumulated_loss = torch.zeros(self.n_envs, device=self.device) 57 | self.entropy_loss = torch.tensor(0., device=self.device) 58 | 59 | def act(self, obs, global_state, test=False): 60 | # type: (Union[Tensor, TensorDict], Tensor, bool) -> Tuple[Tensor, Dict[str, Tensor]] 61 | if self.agent.is_rnn_based: 62 | self.rnn_state_buffer.add(self.agent.actor.actor_mean.hidden_state, self.agent.critic.critic.hidden_state) 63 | action, sample, logprob, entropy, value = self.agent.get_action_and_value(tensordict2tuple(obs), global_state, test=test) 64 | return action, {"sample": sample, "logprob": logprob, "entropy": entropy, "value": value} 65 | 66 | def value_target(self, global_state): 67 | # type: (Tensor) -> Tensor 68 | return self._critic_target(global_state).squeeze(-1) 69 | 70 | @torch.no_grad() 71 | def bootstrap_tdlambda(self): 72 | # value of the next obs should be zero if the next obs is a terminal obs 73 | next_values = self.buffer.next_values * (1 - self.buffer.next_terminated) 74 | if self.lmbda == 0.: 75 | target_values = self.buffer.rewards + self.discount * next_values 76 | else: 77 | target_values = torch.zeros_like(next_values).to(self.device) 78 | Ai = torch.zeros(self.n_envs, dtype=torch.float32, device=self.device) 79 | Bi = torch.zeros(self.n_envs, dtype=torch.float32, device=self.device) 80 | lam = torch.ones(self.n_envs, dtype=torch.float32, device=self.device) 81 | self.buffer.next_dones[-1] = 1. 82 | for i in reversed(range(self.l_rollout)): 83 | lam = lam * self.lmbda * (1. - self.buffer.next_dones[i]) + self.buffer.next_dones[i] 84 | Ai = (1. - self.buffer.next_dones[i]) * ( 85 | self.discount * (self.lmbda * Ai + next_values[i]) + \ 86 | (1. - lam) / (1. - self.lmbda) * self.buffer.rewards[i]) 87 | Bi = self.discount * (next_values[i] * self.buffer.next_dones[i] + Bi * (1. - self.buffer.next_dones[i])) + \ 88 | self.buffer.rewards[i] 89 | # Bi = self.discount * torch.where(self.buffer.next_dones[i], next_values[i], Bi) + self.buffer.rewards[i] 90 | target_values[i] = (1.0 - self.lmbda) * Ai + lam * Bi 91 | return target_values.view(-1) 92 | 93 | @torch.no_grad() 94 | def bootstrap_gae(self): 95 | advantages = torch.zeros_like(self.buffer.rewards) 96 | lastgaelam = 0 97 | for t in reversed(range(self.l_rollout)): 98 | nextnonterminal = 1.0 - self.buffer.next_terminated[t] 99 | nextnonreset = 1.0 - self.buffer.next_dones[t] 100 | # nextnonterminal = 1.0 - self.buffer.next_dones[t] 101 | nextvalues = self.buffer.next_values[t] 102 | # TD-error / vanilla advantage function. 103 | delta = self.buffer.rewards[t] + self.discount * nextvalues * nextnonterminal - self.buffer.values[t] 104 | # Generalized Advantage Estimation bootstraping formula. 105 | advantages[t] = lastgaelam = delta + self.discount * self.lmbda * nextnonreset * lastgaelam 106 | target_values = advantages + self.buffer.values 107 | return target_values.view(-1) 108 | 109 | def record_loss(self, loss, policy_info, env_info, last_step=False): 110 | # type: (Tensor, Dict[str, Tensor], Dict[str, Tensor], Optional[bool]) -> Tensor 111 | reset = torch.ones_like(env_info["reset"]) if last_step else env_info["reset"] 112 | truncated = torch.ones_like(env_info["reset"]) if last_step else env_info["truncated"] 113 | # add cumulated loss if rollout ends or trajectory ends (terminated or truncated) 114 | self.cumulated_loss = self.cumulated_loss + self.rollout_gamma * loss 115 | cumulated_loss = self.cumulated_loss[reset].sum() 116 | # add terminal value if rollout ends or truncated 117 | next_value = self.value_target(tensordict2tuple(env_info["next_state_before_reset"])) 118 | terminal_value = (self.rollout_gamma * self.discount * next_value)[truncated].sum() 119 | assert terminal_value.requires_grad == True 120 | # add up the discounted cumulated loss, the terminal value and the entropy loss 121 | self.actor_loss = self.actor_loss + cumulated_loss + terminal_value 122 | self.entropy_loss = self.entropy_loss - policy_info["entropy"].sum() 123 | # reset the discount factor, clear the cumulated loss if trajectory ends 124 | self.rollout_gamma = torch.where(reset, 1, self.rollout_gamma * self.discount) 125 | self.cumulated_loss = torch.where(reset, 0, self.cumulated_loss) 126 | return next_value.detach() 127 | 128 | def clear_loss(self): 129 | self.rollout_gamma.fill_(1.) 130 | self.actor_loss.detach_().fill_(0.) 131 | self.cumulated_loss.detach_().fill_(0.) 132 | self.entropy_loss.detach_().fill_(0.) 133 | 134 | def update_actor(self) -> Dict[str, float]: 135 | actor_loss = self.actor_loss / (self.n_envs * self.l_rollout) 136 | entropy_loss = self.entropy_loss / (self.n_envs * self.l_rollout) 137 | total_loss = actor_loss + self.entropy_weight * entropy_loss 138 | self.actor_optim.zero_grad() 139 | total_loss.backward() 140 | grad_norm = sum([p.grad.data.norm().item() ** 2 for p in self.agent.actor.parameters()]) ** 0.5 141 | if self.actor_grad_norm is not None: 142 | torch.nn.utils.clip_grad_norm_(self.agent.actor.parameters(), max_norm=self.actor_grad_norm) 143 | self.actor_optim.step() 144 | return {"actor_loss": actor_loss.item(), "entropy_loss": entropy_loss.item()}, {"actor_grad_norm": grad_norm} 145 | 146 | def update_critic(self, target_values: Tensor) -> Dict[str, float]: 147 | T, N = self.l_rollout, self.n_envs 148 | batch_indices = torch.randperm(T*N, device=self.device) 149 | mb_size = T*N // self.n_minibatch 150 | global_states = self.buffer.global_states.flatten(0, 1) 151 | if self.agent.is_rnn_based: 152 | critic_hidden_state = self.rnn_state_buffer.critic_rnn_state.flatten(0, 1) 153 | for start in range(0, T*N, mb_size): 154 | end = start + mb_size 155 | mb_indices = batch_indices[start:end] 156 | if self.agent.is_rnn_based: 157 | values = self.agent.get_value(global_states[mb_indices], 158 | critic_hidden_state[mb_indices].permute(1, 0, 2)) 159 | else: 160 | values = self.agent.get_value(global_states[mb_indices]) 161 | critic_loss = F.mse_loss(values, target_values[mb_indices]) 162 | self.critic_optim.zero_grad() 163 | critic_loss.backward() 164 | grad_norm = sum([p.grad.data.norm().item() ** 2 for p in self.agent.critic.parameters()]) ** 0.5 165 | if self.critic_grad_norm is not None: 166 | torch.nn.utils.clip_grad_norm_(self.agent.critic.parameters(), max_norm=self.critic_grad_norm) 167 | self.critic_optim.step() 168 | for p, p_t in zip(self.agent.critic.parameters(), self._critic_target.parameters()): 169 | p_t.data.lerp_(p.data, self.target_update_rate) 170 | return {"critic_loss": critic_loss.item()}, {"critic_grad_norm": grad_norm} 171 | 172 | @timeit 173 | def step(self, cfg, env, logger, obs, on_step_cb=None): 174 | obs, global_state = obs 175 | self.buffer.clear() 176 | if self.agent.is_rnn_based: 177 | self.rnn_state_buffer.clear() 178 | self.clear_loss() 179 | for t in range(cfg.l_rollout): 180 | action, policy_info = self.act(obs, global_state) 181 | (next_obs, next_global_state), (loss, reward), terminated, env_info = env.step(env.rescale_action(action), next_state_before_reset=True) 182 | next_value = self.record_loss(loss, policy_info, env_info, last_step=(t==cfg.l_rollout-1)) 183 | # divide by 10 to avoid disstability 184 | self.buffer.add( 185 | obs=obs, 186 | global_state=global_state, 187 | reward=loss/10, 188 | value=policy_info["value"], 189 | next_done=env_info["reset"], 190 | next_terminated=terminated, 191 | next_value=next_value 192 | ) 193 | self.reset(env_info["reset"]) 194 | obs = next_obs 195 | global_state = next_global_state 196 | if on_step_cb is not None: 197 | on_step_cb( 198 | obs=obs, 199 | action=action, 200 | policy_info=policy_info, 201 | env_info=env_info) 202 | target_values = self.bootstrap_gae() 203 | actor_losses, actor_grad_norms = self.update_actor() 204 | critic_losses, critic_grad_norms = self.update_critic(target_values) 205 | self.detach() 206 | losses = {**actor_losses, **critic_losses} 207 | grad_norms = {**actor_grad_norms, **critic_grad_norms} 208 | return (obs, global_state), policy_info, env_info, losses, grad_norms 209 | 210 | def save(self, path): 211 | if not os.path.exists(path): 212 | os.makedirs(path) 213 | self.agent.save(path) 214 | 215 | def load(self, path): 216 | self.agent.load(path) 217 | 218 | def reset(self, env_idx: Tensor): 219 | if self.agent.is_rnn_based: 220 | self.agent.reset(env_idx) 221 | 222 | def detach(self): 223 | if self.agent.is_rnn_based: 224 | self.agent.detach() 225 | self._critic_target.detach() 226 | 227 | @staticmethod 228 | def build(cfg, env, device): 229 | return MASHAC( 230 | cfg=cfg, 231 | obs_dim=env.obs_dim, 232 | global_state_dim=env.global_state_dim, 233 | n_agents=env.n_agents, 234 | action_dim=env.action_dim, 235 | n_envs=env.n_envs, 236 | l_rollout=cfg.l_rollout, 237 | device=device) 238 | 239 | def export( 240 | self, 241 | path: str, 242 | export_cfg: DictConfig, 243 | verbose: bool = False, 244 | ): 245 | PolicyExporter(self.agent.actor).export(path, export_cfg, verbose) -------------------------------------------------------------------------------- /algo/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | 6 | from diffaero.algo.PPO import PPO, AsymmetricPPO 7 | from diffaero.algo.APG import APG, APG_stochastic 8 | from diffaero.algo.SHAC import SHAC, SHA2C 9 | from diffaero.algo.MASHAC import MASHAC 10 | from diffaero.algo.dreamerv3 import World_Agent 11 | AGENT_ALIAS = { 12 | "ppo": PPO, 13 | "appo": AsymmetricPPO, 14 | "shac": SHAC, 15 | "sha2c": SHA2C, 16 | "mashac": MASHAC, 17 | "apg": APG, 18 | "apg_sto": APG_stochastic, 19 | "world": World_Agent, 20 | } 21 | 22 | def build_agent(cfg: DictConfig, env, device: torch.device): 23 | return AGENT_ALIAS[cfg.name].build(cfg, env, device) -------------------------------------------------------------------------------- /algo/buffer.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional, Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | from tensordict import TensorDict 6 | 7 | from diffaero.utils.logger import Logger 8 | from diffaero.utils.runner import timeit 9 | 10 | class RNNStateBuffer: 11 | def __init__(self, l_rollout, n_envs, rnn_hidden_dim, rnn_n_layers, device): 12 | # type: (int, int, int, int, torch.device) -> None 13 | factory_kwargs = {"dtype": torch.float32, "device": device} 14 | self.actor_rnn_state = torch.zeros((l_rollout, n_envs, rnn_n_layers, rnn_hidden_dim), **factory_kwargs) 15 | self.critic_rnn_state = torch.zeros((l_rollout, n_envs, rnn_n_layers, rnn_hidden_dim), **factory_kwargs) 16 | 17 | def clear(self): 18 | self.step = 0 19 | 20 | @torch.no_grad() 21 | def add(self, actor_hidden_state: Optional[Tensor], critic_hidden_state: Optional[Tensor] = None): 22 | if actor_hidden_state is not None: 23 | self.actor_rnn_state[self.step] = actor_hidden_state.permute(1, 0, 2) 24 | if critic_hidden_state is not None: 25 | self.critic_rnn_state[self.step] = critic_hidden_state.permute(1, 0, 2) 26 | self.step += 1 27 | 28 | class RolloutBufferSHAC: 29 | def __init__(self, l_rollout, n_envs, obs_dim, action_dim, device): 30 | # type: (int, int, Union[int, Tuple[int, Tuple[int, int]]], int, torch.device) -> None 31 | factory_kwargs = {"dtype": torch.float32, "device": device} 32 | 33 | assert isinstance(obs_dim, tuple) or isinstance(obs_dim, int) 34 | if isinstance(obs_dim, tuple): 35 | self.obs = TensorDict({ 36 | "state": torch.zeros((l_rollout, n_envs, obs_dim[0]), **factory_kwargs), 37 | "perception": torch.zeros((l_rollout, n_envs, obs_dim[1][0], obs_dim[1][1]), **factory_kwargs) 38 | }, batch_size=(l_rollout, n_envs)) 39 | else: 40 | self.obs = torch.zeros((l_rollout, n_envs, obs_dim), **factory_kwargs) 41 | self.samples = torch.zeros((l_rollout, n_envs, action_dim), **factory_kwargs) 42 | self.logprobs = torch.zeros((l_rollout, n_envs), **factory_kwargs) 43 | self.losses = torch.zeros((l_rollout, n_envs), **factory_kwargs) 44 | self.values = torch.zeros((l_rollout, n_envs), **factory_kwargs) 45 | self.next_dones = torch.zeros((l_rollout, n_envs), **factory_kwargs) 46 | self.next_terminated = torch.zeros((l_rollout, n_envs), **factory_kwargs) 47 | self.next_values = torch.zeros((l_rollout, n_envs), **factory_kwargs) 48 | 49 | def clear(self): 50 | self.step = 0 51 | 52 | @torch.no_grad() 53 | def add(self, obs, sample, logprob, loss, value, next_done, next_terminated, next_value): 54 | # type: (Union[Tensor, TensorDict], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> None 55 | self.obs[self.step] = obs 56 | self.samples[self.step] = sample 57 | self.logprobs[self.step] = logprob 58 | self.losses[self.step] = loss 59 | self.values[self.step] = value 60 | self.next_dones[self.step] = next_done.float() 61 | self.next_terminated[self.step] = next_terminated.float() 62 | self.next_values[self.step] = next_value 63 | self.step += 1 64 | 65 | class RolloutBufferMASHAC: 66 | def __init__(self, l_rollout, n_envs, obs_dim, global_state_dim, n_agents, device): 67 | # type: (int, int, Union[int, Tuple[int, Tuple[int, int]]], int, int, torch.device) -> None 68 | factory_kwargs = {"dtype": torch.float32, "device": device} 69 | 70 | assert isinstance(obs_dim, tuple) or isinstance(obs_dim, int) 71 | assert isinstance(global_state_dim, int) 72 | if isinstance(obs_dim, tuple): 73 | self.obs = TensorDict({ 74 | "state": torch.zeros((l_rollout, n_envs, n_agents, obs_dim[0]), **factory_kwargs), 75 | "perception": torch.zeros((l_rollout, n_envs, obs_dim[1][0], obs_dim[1][1]), **factory_kwargs) 76 | }, batch_size=(l_rollout, n_envs)) 77 | else: 78 | self.obs = torch.zeros((l_rollout, n_envs, n_agents, obs_dim), **factory_kwargs) 79 | self.global_states = torch.zeros((l_rollout, n_envs, global_state_dim), **factory_kwargs) 80 | self.rewards = torch.zeros((l_rollout, n_envs), **factory_kwargs) 81 | self.values = torch.zeros((l_rollout, n_envs), **factory_kwargs) 82 | self.next_dones = torch.zeros((l_rollout, n_envs), **factory_kwargs) 83 | self.next_terminated = torch.zeros((l_rollout, n_envs), **factory_kwargs) 84 | self.next_values = torch.zeros((l_rollout, n_envs), **factory_kwargs) 85 | 86 | def clear(self): 87 | self.step = 0 88 | 89 | @torch.no_grad() 90 | def add(self, obs, global_state, reward, value, next_done, next_terminated, next_value): 91 | # type: (Union[Tensor, TensorDict], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> None 92 | self.obs[self.step] = obs 93 | self.global_states[self.step] = global_state 94 | self.rewards[self.step] = reward 95 | self.values[self.step] = value 96 | self.next_dones[self.step] = next_done.float() 97 | self.next_terminated[self.step] = next_terminated.float() 98 | self.next_values[self.step] = next_value 99 | self.step += 1 100 | 101 | 102 | class RolloutBufferPPO: 103 | def __init__(self, l_rollout, n_envs, obs_dim, action_dim, device): 104 | # type: (int, int, Union[int, Tuple[int, Tuple[int, int]]], int, torch.device) -> None 105 | factory_kwargs = {"dtype": torch.float32, "device": device} 106 | 107 | assert isinstance(obs_dim, tuple) or isinstance(obs_dim, int) 108 | if isinstance(obs_dim, tuple): 109 | self.obs = TensorDict({ 110 | "state": torch.zeros((l_rollout, n_envs, obs_dim[0]), **factory_kwargs), 111 | "perception": torch.zeros((l_rollout, n_envs, obs_dim[1][0], obs_dim[1][1]), **factory_kwargs) 112 | }, batch_size=(l_rollout, n_envs)) 113 | else: 114 | self.obs = torch.zeros((l_rollout, n_envs, obs_dim), **factory_kwargs) 115 | self.samples = torch.zeros((l_rollout, n_envs, action_dim), **factory_kwargs) 116 | self.logprobs = torch.zeros((l_rollout, n_envs), **factory_kwargs) 117 | self.rewards = torch.zeros((l_rollout, n_envs), **factory_kwargs) 118 | self.next_dones = torch.zeros((l_rollout, n_envs), **factory_kwargs) 119 | self.values = torch.zeros((l_rollout, n_envs), **factory_kwargs) 120 | self.next_values = torch.zeros((l_rollout, n_envs), **factory_kwargs) 121 | 122 | def clear(self): 123 | self.step = 0 124 | 125 | @torch.no_grad() 126 | def add(self, obs, sample, logprob, reward, next_done, value, next_value): 127 | # type: (Union[Tensor, TensorDict], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> None 128 | self.obs[self.step] = obs 129 | self.samples[self.step] = sample 130 | self.logprobs[self.step] = logprob 131 | self.rewards[self.step] = reward 132 | self.next_dones[self.step] = next_done.float() 133 | self.values[self.step] = value 134 | self.next_values[self.step] = next_value 135 | self.step += 1 136 | 137 | 138 | class RolloutBufferAPPO(RolloutBufferPPO): 139 | def __init__(self, l_rollout, n_envs, obs_dim, state_dim, action_dim, device): 140 | # type: (int, int, Union[int, Tuple[int, Tuple[int, int]]], int, int, torch.device) -> None 141 | super().__init__(l_rollout, n_envs, obs_dim, action_dim, device) 142 | factory_kwargs = {"dtype": torch.float32, "device": device} 143 | self.states = torch.zeros((l_rollout, n_envs, state_dim), **factory_kwargs) 144 | 145 | @torch.no_grad() 146 | def add(self, obs, state, sample, logprob, reward, next_done, value, next_value): 147 | # type: (Union[Tensor, TensorDict], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> None 148 | self.states[self.step] = state 149 | super().add(obs, sample, logprob, reward, next_done, value, next_value) -------------------------------------------------------------------------------- /algo/dreamerv3/__init__.py: -------------------------------------------------------------------------------- 1 | from .world import World_Agent -------------------------------------------------------------------------------- /algo/dreamerv3/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .state_predictor import DepthStateModelCfg 2 | from .agent import ActorCriticConfig -------------------------------------------------------------------------------- /algo/dreamerv3/models/agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as distributions 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | import copy 8 | from dataclasses import dataclass 9 | from torch.cuda.amp import autocast 10 | 11 | from .blocks import SymLogTwoHotLoss, MLP 12 | 13 | class EMAScalar(): 14 | def __init__(self, decay) -> None: 15 | self.scalar = 0.0 16 | self.decay = decay 17 | 18 | def __call__(self, value): 19 | self.update(value) 20 | return self.get() 21 | 22 | def update(self, value): 23 | self.scalar = self.scalar * self.decay + value * (1 - self.decay) 24 | 25 | def get(self): 26 | return self.scalar 27 | 28 | 29 | class ContDist: 30 | def __init__(self, dist=None, absmax=None): 31 | super().__init__() 32 | self._dist = dist 33 | self.mean = dist.mean 34 | self.absmax = absmax 35 | 36 | def __getattr__(self, name): 37 | return getattr(self._dist, name) 38 | 39 | def entropy(self): 40 | return self._dist.entropy() 41 | 42 | def mode(self): 43 | out = self._dist.mean 44 | if self.absmax is not None: 45 | out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach() 46 | return out 47 | 48 | def sample(self, sample_shape=()): 49 | out = self._dist.rsample(sample_shape) 50 | if self.absmax is not None: 51 | out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach() 52 | return out 53 | 54 | def log_prob(self, x): 55 | return self._dist.log_prob(x) 56 | 57 | 58 | def percentile(x, percentage): 59 | flat_x = torch.flatten(x) 60 | kth = int(percentage*len(flat_x)) 61 | per = torch.kthvalue(flat_x, kth).values 62 | return per 63 | 64 | 65 | def calc_lambda_return(rewards, values, termination, gamma, lam, device, dtype=torch.float32): 66 | # Invert termination to have 0 if the episode ended and 1 otherwise 67 | inv_termination = (termination * -1) + 1 68 | 69 | batch_size, batch_length = rewards.shape[:2] 70 | gamma_return = torch.zeros((batch_size, batch_length+1), dtype=dtype, device=device) 71 | gamma_return[:, -1] = values[:, -1] 72 | for t in reversed(range(batch_length)): # with last bootstrap 73 | gamma_return[:, t] = \ 74 | rewards[:, t] + \ 75 | gamma * inv_termination[:, t] * (1-lam) * values[:, t] + \ 76 | gamma * inv_termination[:, t] * lam * gamma_return[:, t+1] 77 | return gamma_return[:, :-1] 78 | 79 | @dataclass 80 | class ActorCriticConfig: 81 | feat_dim: int 82 | num_layers: int 83 | hidden_dim: int 84 | action_dim: int 85 | gamma: float 86 | lambd: float 87 | entropy_coef: float 88 | device: torch.device 89 | max_std: float=1.0 90 | min_std: float=0.1 91 | 92 | 93 | class ActorCriticAgent(nn.Module): 94 | def __init__(self, cfg: ActorCriticConfig, envs) -> None: 95 | super().__init__() 96 | self.gamma = cfg.gamma 97 | self.lambd = cfg.lambd 98 | self.entropy_coef = cfg.entropy_coef 99 | self.use_amp = False 100 | self.tensor_dtype = torch.bfloat16 if self.use_amp else torch.float32 101 | self._min_std = cfg.min_std 102 | self._max_std = cfg.max_std 103 | self.register_buffer('min_action',torch.tensor(-1.)) 104 | self.register_buffer('max_action', torch.tensor(1.)) 105 | self.min_action: torch.Tensor; self.max_action: torch.Tensor 106 | 107 | self.device = cfg.device 108 | feat_dim = cfg.feat_dim 109 | hidden_dim = cfg.hidden_dim 110 | num_layers = cfg.num_layers 111 | action_dim = cfg.action_dim 112 | 113 | self.symlog_twohot_loss = SymLogTwoHotLoss(255, -20, 20) 114 | 115 | self.actor_mean_std = nn.Sequential( 116 | MLP(feat_dim, hidden_dim, hidden_dim, num_layers, 'ReLU', 'LayerNorm',bias=False), 117 | nn.Linear(hidden_dim, action_dim*2), 118 | ) 119 | self.critic = nn.Sequential( 120 | MLP(feat_dim, hidden_dim, hidden_dim, num_layers, 'ReLU', 'LayerNorm', bias=False), 121 | nn.Linear(hidden_dim, 255), 122 | ) 123 | self.slow_critic = copy.deepcopy(self.critic) 124 | 125 | self.lowerbound_ema = EMAScalar(decay=0.99) 126 | self.upperbound_ema = EMAScalar(decay=0.99) 127 | 128 | self.optimizer = torch.optim.Adam(self.parameters(), lr=3e-4, eps=1e-5) 129 | self.scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp) 130 | 131 | @torch.no_grad() 132 | def update_slow_critic(self, decay=0.98): 133 | for slow_param, param in zip(self.slow_critic.parameters(), self.critic.parameters()): 134 | slow_param.data.copy_(slow_param.data * decay + param.data * (1 - decay)) 135 | 136 | def dist(self,mean,std): 137 | return torch.distributions.Normal(mean,std) 138 | 139 | def policy(self, x): 140 | LOG_STD_MAX = 3 141 | LOG_STD_MIN = -5 142 | mean_std = self.actor_mean_std(x) 143 | mean, std = torch.chunk(mean_std, 2, dim=-1) 144 | log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * ( 145 | torch.tanh(std) + 1) 146 | # std = torch.exp(log_std).expand_as(mean) 147 | std = torch.exp(log_std) 148 | return mean,std 149 | 150 | def value(self, x): 151 | value = self.critic(x) 152 | value = self.symlog_twohot_loss.decode(value) 153 | return value 154 | 155 | @torch.no_grad() 156 | def slow_value(self, x): 157 | value = self.slow_critic(x) 158 | value = self.symlog_twohot_loss.decode(value) 159 | return value 160 | 161 | def get_dist_raw_value(self, x): 162 | mean,std = self.policy(x[:,:-1]) 163 | dist = self.dist(mean,std) 164 | raw_value = self.critic(x) 165 | return dist, raw_value 166 | 167 | @torch.no_grad() 168 | def sample(self, latent, greedy=False): 169 | self.eval() 170 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp): 171 | mean,std = self.policy(latent) 172 | dist = self.dist(mean,std) 173 | if greedy: 174 | sample = mean 175 | else: 176 | sample = dist.sample() 177 | action = (self.max_action - self.min_action) * (torch.tanh(sample)*0.5 + 0.5) + self.min_action 178 | return action,sample 179 | 180 | def sample_as_env_action(self, latent, greedy=False): 181 | action = self.sample(latent, greedy) 182 | return action.to(torch.float32).detach().cpu().squeeze(0).numpy() 183 | 184 | def update(self, latent, action, reward, termination, logger=None): 185 | ''' 186 | Update policy and value model 187 | ''' 188 | self.train() 189 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=self.use_amp): 190 | dist, raw_value = self.get_dist_raw_value(latent) 191 | log_prob = dist.log_prob(action) 192 | log_prob = log_prob.sum(-1) 193 | entropy = dist.entropy().sum(-1) 194 | 195 | # decode value, calc lambda return 196 | slow_value = self.slow_value(latent) 197 | slow_lambda_return = calc_lambda_return(reward, slow_value, termination, self.gamma, self.lambd,self.device) 198 | value = self.symlog_twohot_loss.decode(raw_value) 199 | lambda_return = calc_lambda_return(reward, value, termination, self.gamma, self.lambd,self.device) 200 | 201 | # update value function with slow critic regularization 202 | value_loss = self.symlog_twohot_loss(raw_value[:, :-1], lambda_return.detach()) 203 | slow_value_regularization_loss = self.symlog_twohot_loss(raw_value[:, :-1], slow_lambda_return.detach()) 204 | 205 | lower_bound = self.lowerbound_ema(percentile(lambda_return, 0.05)) 206 | upper_bound = self.upperbound_ema(percentile(lambda_return, 0.95)) 207 | S = upper_bound-lower_bound 208 | norm_ratio = torch.max(torch.ones(1).to(self.device), S) # max(1, S) in the paper 209 | norm_advantage = (lambda_return-value[:, :-1]) / norm_ratio 210 | policy_loss = -(log_prob * norm_advantage.detach()).mean() 211 | 212 | entropy_loss = entropy.mean() 213 | 214 | loss = policy_loss + value_loss + slow_value_regularization_loss - self.entropy_coef * entropy_loss 215 | 216 | # gradient descent 217 | self.scaler.scale(loss).backward() 218 | self.scaler.unscale_(self.optimizer) # for clip grad 219 | gradnorm = torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=100.0) 220 | self.scaler.step(self.optimizer) 221 | self.scaler.update() 222 | self.optimizer.zero_grad(set_to_none=True) 223 | 224 | self.update_slow_critic() 225 | 226 | if logger is not None: 227 | logger.log('ActorCritic/policy_loss', policy_loss.item()) 228 | logger.log('ActorCritic/value_loss', value_loss.item()) 229 | logger.log('ActorCritic/entropy_loss', entropy_loss.item()) 230 | logger.log('ActorCritic/S', S.item()) 231 | logger.log('ActorCritic/gradnorm', gradnorm.item()) 232 | logger.log('ActorCritic/total_loss', loss.item()) 233 | 234 | agent_info = { 235 | 'ActorCritic/policy_loss': policy_loss.item(), 236 | 'ActorCritic/value_loss': value_loss.item(), 237 | 'ActorCritic/entropy_loss': entropy_loss.item(), 238 | 'ActorCritic/S': S.item(), 239 | 'ActorCritic/gradnorm': gradnorm.item(), 240 | 'ActorCritic/total_loss': loss.item() 241 | } 242 | 243 | return agent_info -------------------------------------------------------------------------------- /algo/dreamerv3/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .denoiser import Denoiser, DenoiserConfig, SigmaDistributionConfig 2 | from .inner_model import InnerModelConfig 3 | from .diffusion_sampler import DiffusionSampler, DiffusionSamplerConfig 4 | -------------------------------------------------------------------------------- /algo/dreamerv3/models/diffusion/denoiser.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import Tensor 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from data import Batch 10 | from .inner_model import InnerModel, InnerModelConfig 11 | 12 | 13 | def add_dims(input: Tensor, n: int) -> Tensor: 14 | return input.reshape(input.shape + (1,) * (n - input.ndim)) 15 | 16 | 17 | @dataclass 18 | class SigmaDistributionConfig: 19 | loc: float 20 | scale: float 21 | sigma_min: float 22 | sigma_max: float 23 | 24 | 25 | @dataclass 26 | class DenoiserConfig: 27 | inner_model: InnerModelConfig 28 | sigma_data: float 29 | sigma_offset_noise: float 30 | 31 | 32 | class Denoiser(nn.Module): 33 | def __init__(self, cfg: DenoiserConfig) -> None: 34 | super().__init__() 35 | self.cfg = cfg 36 | self.inner_model = InnerModel(cfg.inner_model) 37 | self.sample_sigma_training = None 38 | 39 | @property 40 | def device(self) -> torch.device: 41 | return self.inner_model.noise_emb.weight.device 42 | 43 | def setup_training(self, cfg: SigmaDistributionConfig) -> None: 44 | assert self.sample_sigma_training is None 45 | 46 | def sample_sigma(n: int, device: torch.device): 47 | s = torch.randn(n, device=device) * cfg.scale + cfg.loc 48 | return s.exp().clip(cfg.sigma_min, cfg.sigma_max) 49 | 50 | self.sample_sigma_training = sample_sigma 51 | 52 | def forward(self, noisy_next_obs: Tensor, sigma: Tensor, obs: Tensor, act: Tensor,drone_states:Tensor,obstacles:Tensor) -> Tuple[Tensor, Tensor]: 53 | c_in, c_out, c_skip, c_noise = self._compute_conditioners(sigma) 54 | rescaled_obs = obs / self.cfg.sigma_data ##这里做rescale作用未知 55 | rescaled_noise = noisy_next_obs * c_in 56 | model_output = self.inner_model(rescaled_noise, c_noise, rescaled_obs, act,drone_states,obstacles) #F_theta(c_in*noised_x_t+1,x_[0:t],act_[0:t]) 57 | denoised = model_output * c_out + noisy_next_obs * c_skip 58 | return model_output, denoised 59 | 60 | @torch.no_grad() 61 | def denoise(self, noisy_next_obs: Tensor, sigma: Tensor, obs: Tensor, act: Tensor,drone_states:Tensor,obstacles:Tensor) -> Tensor: 62 | _, d = self(noisy_next_obs, sigma, obs, act,drone_states,obstacles) 63 | # Quantize to {0, ..., 255}, then back to [-1, 1] 64 | d = d.clamp(-1, 1).add(1).div(2).mul(255).byte().div(255).mul(2).sub(1) 65 | return d 66 | 67 | def _compute_conditioners(self, sigma: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: 68 | sigma = (sigma**2 + self.cfg.sigma_offset_noise**2).sqrt() 69 | c_in = 1 / (sigma**2 + self.cfg.sigma_data**2).sqrt() 70 | c_skip = self.cfg.sigma_data**2 / (sigma**2 + self.cfg.sigma_data**2) 71 | c_out = sigma * c_skip.sqrt() 72 | c_noise = sigma.log() / 4 73 | return *(add_dims(c, 4) for c in (c_in, c_out, c_skip)), add_dims(c_noise, 1) 74 | 75 | def compute_loss(self, batch: Batch): 76 | n = self.cfg.inner_model.num_steps_conditioning 77 | seq_length = batch.obs.size(1) - n 78 | 79 | all_obs = batch.obs.clone() 80 | loss = 0 81 | 82 | for i in range(seq_length): 83 | obs = all_obs[:, i : n + i] 84 | next_obs = all_obs[:, n + i] 85 | act = batch.act[:, i : n + i] 86 | drone_states = batch.drone_state[:, i : n + i] 87 | obstacles = batch.obstacle_relpos[:, i : n + i] 88 | mask = batch.mask_padding[:, n + i] 89 | 90 | b, t, c, h, w = obs.shape 91 | obs = obs.reshape(b, t * c, h, w) 92 | 93 | sigma = self.sample_sigma_training(b, self.device) 94 | _, c_out, c_skip, _ = self._compute_conditioners(sigma) 95 | 96 | offset_noise = self.cfg.sigma_offset_noise * torch.randn(b, c, 1, 1, device=next_obs.device) 97 | noisy_next_obs = next_obs + offset_noise + torch.randn_like(next_obs) * add_dims(sigma, next_obs.ndim) 98 | 99 | model_output, denoised = self(noisy_next_obs, sigma, obs, act, drone_states,obstacles) 100 | 101 | target = (next_obs - c_skip * noisy_next_obs) / c_out 102 | loss += F.mse_loss(model_output[mask], target[mask]) 103 | 104 | all_obs[:, n + i] = denoised.detach().clamp(-1, 1) 105 | 106 | loss /= seq_length 107 | return loss, {"loss_denoising": loss.detach()} 108 | -------------------------------------------------------------------------------- /algo/dreamerv3/models/diffusion/diffusion_sampler.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Tuple 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from .denoiser import Denoiser 8 | 9 | 10 | @dataclass 11 | class DiffusionSamplerConfig: 12 | num_steps_denoising: int 13 | sigma_min: float = 2e-3 14 | sigma_max: float = 5 15 | rho: int = 7 16 | order: int = 1 17 | s_churn: float = 0 18 | s_tmin: float = 0 19 | s_tmax: float = float("inf") 20 | s_noise: float = 1 21 | 22 | 23 | class DiffusionSampler: 24 | def __init__(self, denoiser: Denoiser, cfg: DiffusionSamplerConfig): 25 | self.denoiser = denoiser 26 | self.cfg = cfg 27 | self.sigmas = build_sigmas(cfg.num_steps_denoising, cfg.sigma_min, cfg.sigma_max, cfg.rho, denoiser.device) 28 | 29 | @torch.no_grad() 30 | def sample_next_obs(self, obs: Tensor, act: Tensor,drone_states:Tensor,obstacles:Tensor) -> Tuple[Tensor, List[Tensor]]: 31 | device = obs.device 32 | b, t, c, h, w = obs.size() 33 | obs = obs.reshape(b, t * c, h, w) 34 | s_in = torch.ones(b, device=device) 35 | gamma_ = min(self.cfg.s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) #这里gamma_就是0? 36 | x = torch.randn(b, c, h, w, device=device) 37 | trajectory = [x] 38 | for sigma, next_sigma in zip(self.sigmas[:-1], self.sigmas[1:]): 39 | gamma = gamma_ if self.cfg.s_tmin <= sigma <= self.cfg.s_tmax else 0 40 | sigma_hat = sigma * (gamma + 1) 41 | if gamma > 0: 42 | eps = torch.randn_like(x) * self.cfg.s_noise 43 | x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5 44 | denoised = self.denoiser.denoise(x, sigma, obs, act,drone_states,obstacles) 45 | d = (x - denoised) / sigma_hat 46 | dt = next_sigma - sigma_hat 47 | if self.cfg.order == 1 or next_sigma == 0: 48 | # Euler method 49 | x = x + d * dt 50 | else: 51 | # Heun's method 52 | x_2 = x + d * dt 53 | denoised_2 = self.denoiser.denoise(x_2, next_sigma * s_in, obs, act) 54 | d_2 = (x_2 - denoised_2) / next_sigma 55 | d_prime = (d + d_2) / 2 56 | x = x + d_prime * dt 57 | trajectory.append(x) 58 | return x, trajectory 59 | 60 | 61 | def build_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: int, device: torch.device) -> Tensor: 62 | min_inv_rho = sigma_min ** (1 / rho) 63 | max_inv_rho = sigma_max ** (1 / rho) 64 | l = torch.linspace(0, 1, num_steps, device=device) 65 | sigmas = (max_inv_rho + l * (min_inv_rho - max_inv_rho)) ** rho 66 | return torch.cat((sigmas, sigmas.new_zeros(1))) 67 | -------------------------------------------------------------------------------- /algo/dreamerv3/models/diffusion/inner_model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from ..blocks import Conv3x3, FourierFeatures, GroupNorm, UNet,CrossAttn1d 10 | 11 | 12 | @dataclass 13 | class InnerModelConfig: 14 | img_channels: int 15 | num_steps_conditioning: int 16 | cond_channels: int 17 | depths: List[int] 18 | channels: List[int] 19 | attn_depths: List[bool] 20 | drone_states_dim: int 21 | obstacles_num: int 22 | d_model: int 23 | num_actions: Optional[int] = None 24 | 25 | 26 | class InnerModel(nn.Module): 27 | def __init__(self, cfg: InnerModelConfig) -> None: 28 | super().__init__() 29 | self.noise_emb = FourierFeatures(cfg.cond_channels) #cond_channels=256 30 | self.act_state_emb = nn.Sequential( 31 | nn.Linear(cfg.num_actions + cfg.d_model,cfg.cond_channels // cfg.num_steps_conditioning), 32 | nn.SiLU(), 33 | nn.Linear(cfg.cond_channels // cfg.num_steps_conditioning,cfg.cond_channels // cfg.num_steps_conditioning), #num_steps_conditioning=4 34 | nn.Flatten(), # b t e -> b (t e) 35 | ) 36 | self.drone_states_emb = nn.Sequential( 37 | nn.Linear(cfg.drone_states_dim,cfg.d_model), 38 | nn.LayerNorm(cfg.d_model), 39 | nn.SiLU(), 40 | ) 41 | self.obstacles_emb = nn.Sequential( 42 | nn.Linear(3,cfg.d_model), 43 | nn.LayerNorm(cfg.d_model), 44 | nn.SiLU(), 45 | ) 46 | self.states_obstacles_attn = CrossAttn1d(d_model=cfg.d_model,num_heads=4) 47 | # self.attn_proj = nn.Sequential( 48 | # nn.Linear(cfg.d_model,cfg.cond_channels // cfg.num_steps_conditioning), 49 | # nn.LayerNorm(cfg.cond_channels // cfg.num_steps_conditioning), 50 | # nn.SiLU(), 51 | # nn.Linear(cfg.cond_channels // cfg.num_steps_conditioning,cfg.cond_channels // cfg.num_steps_conditioning), 52 | # nn.Flatten(), # b t e -> b (t e) 53 | # ) 54 | 55 | self.cond_proj = nn.Sequential( 56 | nn.Linear(cfg.cond_channels, cfg.cond_channels), 57 | nn.SiLU(), 58 | nn.Linear(cfg.cond_channels, cfg.cond_channels), 59 | ) 60 | self.conv_in = Conv3x3((cfg.num_steps_conditioning + 1) * cfg.img_channels, cfg.channels[0]) #img_channels=3 61 | 62 | self.unet = UNet(cfg.cond_channels, cfg.depths, cfg.channels, cfg.attn_depths) #256,[2,2,2,2],[64,64,64,64],[0,0,0,0] 63 | 64 | self.norm_out = GroupNorm(cfg.channels[0]) 65 | self.conv_out = Conv3x3(cfg.channels[0], cfg.img_channels) 66 | nn.init.zeros_(self.conv_out.weight) 67 | 68 | def forward(self, noisy_next_obs: Tensor, c_noise: Tensor, obs: Tensor, 69 | act: Tensor,drone_states:Tensor,obstacles:Tensor) -> Tensor: 70 | assert len(drone_states.shape)==3 and len(obstacles.shape)==4 71 | b,l,n,d = obstacles.shape 72 | drone_states = drone_states.unsqueeze(-2) 73 | obstacles_z = obstacles[:,:,:,-1:0] 74 | mask = torch.ones(b,l,n,1).bool().to(obs.device) 75 | mask.masked_fill(obstacles_z<=-1000,0) 76 | mask = mask.transpose(-1,-2) 77 | 78 | drone_obstacles_attn,_ = self.states_obstacles_attn(self.drone_states_emb(drone_states),self.obstacles_emb(obstacles),mask) #b,l,1,d_model 79 | drone_obstacles_attn = drone_obstacles_attn.squeeze(-2) #b,l,d_model 80 | 81 | # cond = self.cond_proj(self.noise_emb(c_noise) + self.act_emb(act) + self.attn_proj(drone_obstacles_attn)) 82 | cond = self.cond_proj(self.noise_emb(c_noise) + self.act_state_emb(torch.cat([act,drone_obstacles_attn],dim=-1))) 83 | x = self.conv_in(torch.cat((obs, noisy_next_obs), dim=1)) #b (t+1)*c h w 84 | x, _, _ = self.unet(x, cond) 85 | x = self.conv_out(F.silu(self.norm_out(x))) 86 | return x 87 | 88 | def main(): 89 | inner_cfg = InnerModelConfig( 90 | img_channels=3, 91 | num_steps_conditioning=4, 92 | cond_channels=256, 93 | depths=[2,2,2,2], 94 | channels=[64,64,64,64], 95 | attn_depths=[False,False,False,False], 96 | num_actions=5 97 | ) 98 | innermodel = InnerModel(inner_cfg) 99 | noisy_next_obs = torch.randn(16,3,64,64) 100 | obs = torch.randn(16,3*4,64,64) 101 | c_noise = torch.randn(16,) 102 | act = torch.randint(0,4,(16,4)) 103 | x = innermodel(noisy_next_obs,c_noise,obs,act) 104 | print(x.shape) 105 | 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /algo/dreamerv3/wmenv/__init__.py: -------------------------------------------------------------------------------- 1 | from .world_state_env import DepthStateEnvConfig 2 | from .replaybuffer import buffercfg -------------------------------------------------------------------------------- /algo/dreamerv3/wmenv/replaybuffer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import numpy as np 4 | import torch 5 | 6 | # from diffaero.utils.runner import timeit 7 | 8 | @dataclass 9 | class buffercfg: 10 | perception_width: int 11 | perception_height: int 12 | state_dim: int 13 | action_dim: int 14 | num_envs: int 15 | max_length: int 16 | warmup_length: int 17 | store_on_gpu: bool 18 | device: str 19 | use_perception: bool 20 | 21 | class ReplayBuffer(): 22 | def __init__(self, cfg:buffercfg) -> None: 23 | self.store_on_gpu = cfg.store_on_gpu 24 | device = torch.device(cfg.device) 25 | if cfg.store_on_gpu: 26 | self.state_buffer = torch.empty((cfg.max_length//cfg.num_envs, cfg.num_envs, cfg.state_dim), dtype=torch.float32, device=device, requires_grad=False) 27 | self.action_buffer = torch.empty((cfg.max_length//cfg.num_envs, cfg.num_envs,cfg.action_dim), dtype=torch.float32, device=device, requires_grad=False) 28 | self.reward_buffer = torch.empty((cfg.max_length//cfg.num_envs, cfg.num_envs), dtype=torch.float32, device=device, requires_grad=False) 29 | self.termination_buffer = torch.empty((cfg.max_length//cfg.num_envs, cfg.num_envs), dtype=torch.float32, device=device, requires_grad=False) 30 | if cfg.use_perception: 31 | self.perception_buffer = torch.empty((cfg.max_length//cfg.num_envs, cfg.num_envs, 1, cfg.perception_height, cfg.perception_width), dtype=torch.float32, device=device, requires_grad=False) 32 | else: 33 | raise ValueError("Only support gpu!!!") 34 | 35 | self.length = 0 36 | self.num_envs = cfg.num_envs 37 | self.last_pointer = -1 38 | self.max_length = cfg.max_length 39 | self.warmup_length = cfg.warmup_length 40 | self.use_perception = cfg.use_perception 41 | 42 | def ready(self): 43 | return self.length * self.num_envs > self.warmup_length and self.length > 64 44 | 45 | @torch.no_grad() 46 | # @timeit 47 | def sample(self, batch_size, batch_length): 48 | perception = None 49 | if batch_size < self.num_envs: 50 | batch_size = self.num_envs 51 | if self.store_on_gpu: 52 | indexes = torch.randint(0, self.length - batch_length, (batch_size,), device=self.state_buffer.device) 53 | arange = torch.arange(batch_length, device=self.state_buffer.device) 54 | idxs = torch.flatten(indexes.unsqueeze(1) + arange.unsqueeze(0)) # shape: (batch_size * batch_length) 55 | env_idx = torch.randint(0, self.num_envs, (batch_size, 1), device=self.state_buffer.device).expand(-1, batch_length).reshape(-1) 56 | state = self.state_buffer[idxs, env_idx].reshape(batch_size, batch_length, -1) 57 | action = self.action_buffer[idxs, env_idx].reshape(batch_size, batch_length, -1) 58 | reward = self.reward_buffer[idxs, env_idx].reshape(batch_size, batch_length) 59 | termination = self.termination_buffer[idxs, env_idx].reshape(batch_size, batch_length) 60 | if self.use_perception: 61 | perception = self.perception_buffer[idxs, env_idx].reshape(batch_size, batch_length, *self.perception_buffer.shape[2:]) 62 | else: 63 | raise ValueError("Only support gpu!!!") 64 | 65 | return state, action, reward, termination, perception 66 | 67 | def append(self, state, action, reward, termination, perception=None): 68 | self.last_pointer = (self.last_pointer + 1) % (self.max_length//self.num_envs) 69 | if self.store_on_gpu: 70 | self.state_buffer[self.last_pointer] = state 71 | self.action_buffer[self.last_pointer] = action 72 | self.reward_buffer[self.last_pointer] = reward 73 | self.termination_buffer[self.last_pointer] = termination 74 | if self.use_perception and perception is not None: 75 | self.perception_buffer[self.last_pointer] = perception 76 | else: 77 | raise ValueError("Only support gpu!!!") 78 | 79 | if len(self) < self.max_length: 80 | self.length += 1 81 | 82 | def load_external(self, path:str, max_action:torch.Tensor=None, min_action:torch.Tensor=None): 83 | if min_action == None: 84 | min_action = torch.tensor([[-20, -20, 0]]).to(self.state_buffer.device) 85 | max_action = torch.tensor([[20, 20, 40]]).to(self.state_buffer.device) 86 | with np.load(path) as data: 87 | state = np.squeeze(data["state"], axis=1) # [length, 9] 88 | perception = data["perception"] # [length, 1, 9, 16] 89 | action = data["action"] # [length, 3] 90 | self.extern_action_buff = (torch.from_numpy(action).float().to(self.state_buffer.device) - min_action) / (max_action - min_action) * 2.0 - 1.0 91 | self.extern_state_buff = torch.from_numpy(state).float().to(self.state_buffer.device) 92 | self.extern_perception_buff = torch.from_numpy(perception).float().to(self.state_buffer.device) 93 | 94 | def sample_extern(self, batch_size:int, batch_length:int): 95 | assert hasattr(self, "extern_action_buff"), "Please load external data first!!!" 96 | index = torch.randint(0, self.extern_action_buff.shape[0] - batch_length, (batch_size,), device=self.state_buffer.device) 97 | state = torch.stack([self.extern_state_buff[i:i + batch_length] for i in index], dim=0) 98 | action = torch.stack([self.extern_action_buff[i:i + batch_length] for i in index], dim=0) 99 | perception = torch.stack([self.extern_perception_buff[i:i + batch_length] for i in index], dim=0) 100 | return state, action, perception 101 | 102 | def __len__(self): 103 | return self.length * self.num_envs 104 | 105 | if __name__ == "__main__": 106 | cfg = buffercfg( 107 | perception_width=16, 108 | perception_height=9, 109 | state_dim=9, 110 | action_dim=3, 111 | num_envs=16, 112 | max_length=10000, 113 | warmup_length=1000, 114 | store_on_gpu=True, 115 | device="cuda:0", 116 | use_perception=True, 117 | ) 118 | rplb = ReplayBuffer(cfg) 119 | rplb.load_external("/home/zxh/ws/wrqws/diff/traj/all_trajs.npz") 120 | s, a, p = rplb.sample_extern(4, 32) 121 | print(s.shape, a.shape, p.shape) -------------------------------------------------------------------------------- /algo/dreamerv3/wmenv/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, List, Tuple 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.optim.lr_scheduler import LambdaLR 8 | import torch.nn as nn 9 | from torch.optim import AdamW 10 | 11 | Logs = List[Dict[str, float]] 12 | ComputeLossOutput = Tuple[Tensor, Dict[str, Any]] 13 | 14 | class StateDictMixin: 15 | def _init_fields(self) -> None: 16 | def has_sd(x: str) -> bool: 17 | return callable(getattr(x, "state_dict", None)) and callable(getattr(x, "load_state_dict", None)) 18 | 19 | self._all_fields = {k for k in vars(self) if not k.startswith("_")} 20 | self._fields_sd = {k for k in self._all_fields if has_sd(getattr(self, k))} 21 | 22 | def _get_field(self, k: str) -> Any: 23 | return getattr(self, k).state_dict() if k in self._fields_sd else getattr(self, k) 24 | 25 | def _set_field(self, k: str, v: Any) -> None: 26 | getattr(self, k).load_state_dict(v) if k in self._fields_sd else setattr(self, k, v) 27 | 28 | def state_dict(self) -> Dict[str, Any]: 29 | if not hasattr(self, "_all_fields"): 30 | self._init_fields() 31 | return {k: self._get_field(k) for k in self._all_fields} 32 | 33 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 34 | if not hasattr(self, "_all_fields"): 35 | self._init_fields() 36 | assert set(list(state_dict.keys())) == self._all_fields 37 | for k, v in state_dict.items(): 38 | self._set_field(k, v) 39 | 40 | 41 | @dataclass 42 | class CommonTools(StateDictMixin): 43 | denoiser: Any 44 | rew_end_model: Any 45 | actor_critic: Any 46 | 47 | def get(self, name: str) -> Any: 48 | return getattr(self, name) 49 | 50 | def set(self, name: str, value: Any): 51 | return setattr(self, name, value) 52 | 53 | def configure_opt(model: nn.Module, lr: float, weight_decay: float, eps: float, *blacklist_module_names: str) -> AdamW: 54 | """Credits to https://github.com/karpathy/minGPT""" 55 | # separate out all parameters to those that will and won't experience regularizing weight decay 56 | decay = set() 57 | no_decay = set() 58 | whitelist_weight_modules = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.LSTMCell, nn.LSTM, nn.GRUCell, nn.ConvTranspose2d) 59 | blacklist_weight_modules = (nn.LayerNorm, nn.Embedding, nn.GroupNorm, nn.BatchNorm2d) 60 | for mn, m in model.named_modules(): 61 | for pn, p in m.named_parameters(): 62 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 63 | if any([fpn.startswith(module_name) for module_name in blacklist_module_names]): 64 | no_decay.add(fpn) 65 | elif "bias" in pn: 66 | # all biases will not be decayed 67 | no_decay.add(fpn) 68 | elif (pn.endswith("weight") or pn.startswith("weight_")) and isinstance(m, whitelist_weight_modules): 69 | # weights of whitelist modules will be weight decayed 70 | decay.add(fpn) 71 | elif (pn.endswith("weight") or pn.startswith("weight_")) and isinstance(m, blacklist_weight_modules): 72 | # weights of blacklist modules will NOT be weight decayed 73 | no_decay.add(fpn) 74 | 75 | # validate that we considered every parameter 76 | param_dict = {pn: p for pn, p in model.named_parameters()} 77 | inter_params = decay & no_decay 78 | union_params = decay | no_decay 79 | assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!" 80 | assert ( 81 | len(param_dict.keys() - union_params) == 0 82 | ), f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" 83 | 84 | # create the pytorch optimizer object 85 | optim_groups = [ 86 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, 87 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 88 | ] 89 | optimizer = AdamW(optim_groups, lr=lr, eps=eps) 90 | return optimizer 91 | 92 | 93 | def count_parameters(model: nn.Module) -> int: 94 | return sum(p.numel() for p in model.parameters()) 95 | 96 | 97 | def extract_state_dict(state_dict: OrderedDict, module_name: str) -> OrderedDict: 98 | return OrderedDict({k.split(".", 1)[1]: v for k, v in state_dict.items() if k.startswith(module_name)}) 99 | 100 | 101 | def get_lr_sched(opt: torch.optim.Optimizer, num_warmup_steps: int) -> LambdaLR: 102 | def lr_lambda(current_step: int): 103 | return 1 if current_step >= num_warmup_steps else current_step / max(1, num_warmup_steps) 104 | 105 | return LambdaLR(opt, lr_lambda, last_epoch=-1) 106 | 107 | 108 | def init_lstm(model: nn.Module) -> None: 109 | for name, p in model.named_parameters(): 110 | if "weight_ih" in name: 111 | nn.init.xavier_uniform_(p.data) 112 | elif "weight_hh" in name: 113 | nn.init.orthogonal_(p.data) 114 | elif "bias_ih" in name: 115 | p.data.fill_(0) 116 | # Set forget-gate bias to 1 117 | n = p.size(0) 118 | p.data[(n // 4) : (n // 2)].fill_(1) 119 | elif "bias_hh" in name: 120 | p.data.fill_(0) 121 | 122 | class EMAScalar(): 123 | def __init__(self, decay) -> None: 124 | self.scalar = 0.0 125 | self.decay = decay 126 | 127 | def __call__(self, value): 128 | self.update(value) 129 | return self.get() 130 | 131 | def update(self, value): 132 | self.scalar = self.scalar * self.decay + value * (1 - self.decay) 133 | 134 | def get(self): 135 | return self.scalar 136 | -------------------------------------------------------------------------------- /algo/dreamerv3/wmenv/world_state_env.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Generator, List, Tuple,Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from diffaero.algo.dreamerv3.models.state_predictor import DepthStateModel 8 | from diffaero.algo.dreamerv3.models.blocks import symexp,symlog 9 | from .replaybuffer import ReplayBuffer 10 | from diffaero.utils.logger import Logger 11 | from diffaero.utils.runner import timeit 12 | # from models.rew_end_model import RewEndModel 13 | 14 | ResetOutput = Tuple[torch.FloatTensor, Dict[str, Any]] 15 | StepOutput = Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]] 16 | InitialCondition = Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]] 17 | 18 | 19 | @dataclass 20 | class DepthStateEnvConfig: 21 | horizon: int 22 | batch_size: int 23 | batch_length: int 24 | use_perception: bool = False 25 | use_extern: bool = False 26 | 27 | class DepthStateEnv: 28 | def __init__( 29 | self, 30 | state_model: DepthStateModel, 31 | replaybuffer: ReplayBuffer, 32 | cfg: DepthStateEnvConfig, 33 | ) -> None: 34 | self.state_model = state_model 35 | self.replaybuffer = replaybuffer 36 | self.cfg = cfg 37 | self.hidden = None 38 | self.use_extern = cfg.use_extern 39 | 40 | @torch.no_grad() 41 | @timeit 42 | def make_generator_init(self,): 43 | batch_size = self.cfg.batch_size 44 | batch_length = self.cfg.batch_length 45 | if self.use_extern: 46 | states, actions, perceptions = self.replaybuffer.sample_extern(batch_size, batch_length) 47 | else: 48 | states, actions, _ , _, perceptions = self.replaybuffer.sample(batch_size, batch_length) 49 | hidden = None 50 | 51 | for i in range(batch_length): 52 | if perceptions != None: 53 | latent,_ = self.state_model.sample_with_post(states[:,i],perceptions[:,i],hidden) 54 | else: 55 | latent,_ = self.state_model.sample_with_post(states[:,i],None,hidden) 56 | latent = self.state_model.flatten(latent) 57 | latent,_,hidden=self.state_model.sample_with_prior(latent,actions[:,i],hidden) 58 | 59 | latent = self.state_model.flatten(latent) 60 | self.latent = latent 61 | self.hidden = hidden 62 | return latent, hidden 63 | 64 | @torch.no_grad() 65 | @timeit 66 | def step(self,action:Tensor): 67 | assert action.ndim==2 68 | prior_sample,pred_reward,pred_end,hidden = self.state_model.predict_next(latent=self.latent, act=action, hidden=self.hidden) 69 | flattened_sample = prior_sample.view(*prior_sample.shape[:-2],-1) 70 | self.latent = flattened_sample 71 | self.hidden = hidden 72 | return flattened_sample,pred_reward,pred_end,hidden 73 | 74 | def decode(self, latents:Tensor, hiddens:Tensor): 75 | _, videos = self.state_model.decode(latents, hiddens) 76 | assert videos.ndim == 4, f"Expected videos to have 4 dimensions, got {videos.ndim}" 77 | return videos 78 | -------------------------------------------------------------------------------- /cfg/algo/apg.yaml: -------------------------------------------------------------------------------- 1 | name: apg 2 | network: ${network} 3 | 4 | l_rollout: 32 5 | 6 | lr: 0.001 7 | max_grad_norm: 1. -------------------------------------------------------------------------------- /cfg/algo/apg_sto.yaml: -------------------------------------------------------------------------------- 1 | name: apg_sto 2 | network: ${network} 3 | 4 | l_rollout: 32 5 | 6 | lr: 0.001 7 | max_grad_norm: 1. 8 | 9 | entropy_weight: 0.03 -------------------------------------------------------------------------------- /cfg/algo/appo.yaml: -------------------------------------------------------------------------------- 1 | name: appo 2 | network: ${network} 3 | 4 | l_rollout: 16 5 | 6 | lr: 0.0026 7 | eps: 1e-8 8 | gamma: 0.99 9 | lmbda: 0.95 10 | entropy_weight: 0.01 11 | value_weight: 2 12 | clip_coef: 0.2 13 | clip_value_loss: True 14 | norm_adv: True 15 | actor_grad_norm: 1. 16 | critic_grad_norm: 17 | n_minibatch: 8 18 | n_epoch: 4 19 | 20 | critic_network: 21 | name: mlp 22 | hidden_dim: [256, 128] -------------------------------------------------------------------------------- /cfg/algo/mashac.yaml: -------------------------------------------------------------------------------- 1 | name: mashac 2 | network: ${network} 3 | 4 | l_rollout: 32 5 | 6 | actor_lr: 0.001 7 | critic_lr: 0.003 8 | gamma: 0.99 9 | lmbda: 0.95 10 | entropy_weight: 0.01 11 | actor_grad_norm: 1. 12 | critic_grad_norm: 13 | n_minibatch: 8 14 | target_update_rate: 0.005 15 | 16 | share_parameter: True -------------------------------------------------------------------------------- /cfg/algo/ppo.yaml: -------------------------------------------------------------------------------- 1 | name: ppo 2 | network: ${network} 3 | 4 | l_rollout: 16 5 | 6 | lr: 0.0026 7 | eps: 1e-8 8 | gamma: 0.99 9 | lmbda: 0.95 10 | entropy_weight: 0.01 11 | value_weight: 2 12 | clip_coef: 0.2 13 | clip_value_loss: True 14 | norm_adv: True 15 | actor_grad_norm: 1. 16 | critic_grad_norm: 17 | n_minibatch: 8 18 | n_epoch: 4 -------------------------------------------------------------------------------- /cfg/algo/sha2c.yaml: -------------------------------------------------------------------------------- 1 | name: sha2c 2 | network: ${network} 3 | 4 | l_rollout: 32 5 | 6 | actor_lr: 0.001 7 | critic_lr: 0.003 8 | gamma: 0.99 9 | lmbda: 0.95 10 | entropy_weight: 0.01 11 | actor_grad_norm: 1. 12 | critic_grad_norm: 10. 13 | n_minibatch: 8 14 | target_update_rate: 1. 15 | 16 | critic_network: 17 | name: mlp 18 | hidden_dim: [256, 128] -------------------------------------------------------------------------------- /cfg/algo/shac.yaml: -------------------------------------------------------------------------------- 1 | name: shac 2 | network: ${network} 3 | 4 | l_rollout: 32 5 | 6 | actor_lr: 0.001 7 | critic_lr: 0.003 8 | gamma: 0.99 9 | lmbda: 0.95 10 | entropy_weight: 0.01 11 | actor_grad_norm: 1. 12 | critic_grad_norm: 13 | n_minibatch: 8 14 | target_update_rate: 0.005 -------------------------------------------------------------------------------- /cfg/algo/world.yaml: -------------------------------------------------------------------------------- 1 | name: world 2 | n_envs: ${n_envs} 3 | l_rollout: 1 4 | 5 | common: 6 | device: cuda:1 7 | run_name: v3oaa2cbigcontinu 8 | use_checkpoint: False 9 | use_symlog: True 10 | total_timesteps: 200_000_000 11 | # checkpoint_path: /home/zxh/ws/wrqws/diffaero/outputs/2024-12-27/22-18-41/checkpoints 12 | checkpoint_path: 13 | is_test: False 14 | use_amp: False 15 | 16 | world_state_env: 17 | _target_: algo.dreamerv3.wmenv.DepthStateEnvConfig 18 | horizon: 16 19 | batch_size: 1024 20 | batch_length: 4 21 | use_perception: False 22 | use_extern: False 23 | 24 | replaybuffer: 25 | _target_: algo.dreamerv3.wmenv.buffercfg 26 | perception_width: 16 27 | perception_height: 9 28 | state_dim: 13 29 | action_dim: 3 30 | num_envs: 64 31 | max_length: 1048576 32 | warmup_length: 5000 33 | store_on_gpu: True 34 | device: "cuda:1" 35 | use_perception: True 36 | 37 | state_predictor: 38 | state_model: 39 | _target_: algo.dreamerv3.models.state_predictor.DepthStateModelCfg 40 | state_dim: 13 41 | image_width: 16 42 | image_height: 9 43 | hidden_dim: 512 44 | action_dim: 3 45 | latent_dim: 1024 46 | categoricals: 32 47 | num_classes: 255 48 | use_simnorm: False 49 | only_state: False 50 | img_recon_loss_weight: 1. 51 | end_loss_pos_weight: 30 52 | enable_rec: True 53 | rec_coef: 1.0 54 | rew_coef: 1.0 55 | end_coef: 1.0 56 | rep_coef: 0.1 57 | dyn_coef: 0.5 58 | 59 | training: 60 | max_grad_norm: 100.0 61 | total_steps: 500000 62 | batch_size: 64 63 | batch_length: 64 64 | worldmodel_update_freq: 1 65 | use_amp: ${algo.common.use_amp} 66 | 67 | optimizer: 68 | lr: 1e-4 69 | weight_decay: 1e-2 70 | eps: 1e-8 71 | 72 | actor_critic: 73 | model: 74 | __target__: algo.dreamerv3.models.agent.ActorCriticConfig 75 | feat_dim: 1536 # hidden + latent 76 | num_layers: 2 77 | hidden_dim: 256 78 | action_dim: 3 79 | gamma: 0.985 80 | lambd: 0.95 81 | entropy_coef: 3e-4 82 | device: cuda:1 83 | max_std: 1 84 | min_std: 0.1 85 | training: 86 | max_grad_norm: 100.0 87 | total_steps: 500000 88 | imagine_length: 16 89 | use_states: False 90 | scale_rewards: False 91 | 92 | training: 93 | max_grad_norm: 1.0 94 | total_steps: 500000 95 | batch_size: 64 96 | batch_length: 6 97 | total_episodes: 1000 98 | imagine_period: 10000 99 | use_amp: ${algo.common.use_amp} 100 | 101 | imagine: 102 | imagine_length: 64 -------------------------------------------------------------------------------- /cfg/config_test.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - algo: apg_sto 3 | - env: pc 4 | - dynamics: pmc # [pmc, pmd, quad] 5 | - network: mlp # [mlp, cnn] 6 | - sensor: camera # [camera, lidar, relpos] 7 | - logger: tensorboard # [tensorboard, wandb] 8 | - _self_ 9 | - override hydra/help: test_help 10 | - override hydra/launcher: joblib 11 | - override hydra/job_logging: colorlog 12 | - override hydra/hydra_logging: colorlog 13 | # - override hydra/sweeper: optuna_sweep 14 | 15 | # use rollout length specified by specific algorithm 16 | # since different algorithm need different rollout length to perform well 17 | l_rollout: ${algo.l_rollout} 18 | n_agents: 1 19 | n_envs: 4096 20 | n_steps: 10000 21 | log_level: info 22 | device: 0 23 | n_jobs: 1 24 | max_vel: 5.0 25 | ref_path: 26 | 27 | headless: True 28 | display_image: False 29 | record_video: False 30 | video_saveas: mp4 # [mp4, tensorboard] 31 | use_training_cfg: False 32 | 33 | runname: "" 34 | checkpoint: 35 | export: 36 | obs_frame: ${env.obs_frame} 37 | action_frame: ${dynamics.action_frame} 38 | jit: True 39 | onnx: True 40 | 41 | seed: 0 42 | torch_deterministic: False 43 | 44 | env: 45 | dt: 0.0333 46 | min_target_vel: ${max_vel} 47 | max_target_vel: ${max_vel} 48 | randomizer: 49 | enabled: False 50 | 51 | hydra: 52 | run: 53 | dir: ./outputs/test/${now:%Y-%m-%d}/${now:%H-%M-%S} 54 | sweep: 55 | dir: ./outputs/test/${now:%Y-%m-%d}/multirun/${now:%H-%M-%S} 56 | launcher: 57 | n_jobs: ${n_jobs} -------------------------------------------------------------------------------- /cfg/config_train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - algo: apg_sto 3 | - env: pc 4 | - dynamics: pmc # [pmc, pmd, quad] 5 | - network: mlp # [mlp, cnn] 6 | - sensor: camera # [camera, lidar, relpos] 7 | - logger: tensorboard # [tensorboard, wandb] 8 | - _self_ 9 | - override hydra/help: train_help 10 | - override hydra/launcher: joblib 11 | - override hydra/job_logging: colorlog 12 | - override hydra/hydra_logging: colorlog 13 | # - override hydra/sweeper: optuna_sweep 14 | 15 | # use rollout length specified by specific algorithm 16 | # since different algorithm need different rollout length to perform well 17 | l_rollout: ${algo.l_rollout} 18 | n_agents: 1 19 | n_envs: 1024 20 | n_updates: 1000 21 | log_freq: 10 22 | log_level: info 23 | save_freq: 100 24 | device: 0 25 | n_jobs: 1 26 | 27 | headless: True 28 | display_image: False 29 | record_video: False 30 | 31 | runname: "" 32 | checkpoint: 33 | export: 34 | obs_frame: ${env.obs_frame} 35 | action_frame: ${dynamics.action_frame} 36 | jit: True 37 | onnx: True 38 | 39 | torch_profile: False 40 | 41 | seed: 0 42 | torch_deterministic: False 43 | 44 | env: 45 | dt: 0.0333 46 | 47 | ref_path: 48 | 49 | hydra: 50 | run: 51 | dir: ./outputs/train/${now:%Y-%m-%d}/${now:%H-%M-%S} 52 | sweep: 53 | dir: ./outputs/train/${now:%Y-%m-%d}/multirun/${now:%H-%M-%S} 54 | launcher: 55 | n_jobs: ${n_jobs} -------------------------------------------------------------------------------- /cfg/dynamics/pmc.yaml: -------------------------------------------------------------------------------- 1 | name: countinuous_pointmass 2 | abbr: pmc 3 | 4 | n_envs: ${n_envs} 5 | n_agents: ${n_agents} 6 | dt: ${env.dt} 7 | action_frame: "local" # "world" or "local" 8 | m: 1.0 # total mass 9 | g: 9.81 # gravity 10 | D: # drag coef normed by mass 11 | default: 0.375 12 | enabled: True 13 | min: 0.3 14 | max: 0.45 15 | lmbda: # acc exp smoothing 16 | default: 2.6 17 | enabled: True 18 | min: 2.5 19 | max: 2.7 20 | alpha: 1.0 # gradient decay factor 21 | align_yaw_with_target_direction: False 22 | align_yaw_with_vel_ema: True 23 | vel_ema_factor: 24 | default: 0.1 25 | enabled: True 26 | min: 0.08 27 | max: 0.12 28 | 29 | max_acc: 30 | xy: 31 | default: 20 32 | enabled: True 33 | min: 10 34 | max: 20 35 | z: 36 | default: 40 37 | enabled: True 38 | min: 25 39 | max: 40 40 | 41 | solver_type: "rk4" # "rk4" or "euler" 42 | n_substeps: 1 43 | -------------------------------------------------------------------------------- /cfg/dynamics/pmd.yaml: -------------------------------------------------------------------------------- 1 | name: discrete_pointmass 2 | abbr: pmd 3 | 4 | n_envs: ${n_envs} 5 | n_agents: ${n_agents} 6 | dt: ${env.dt} 7 | action_frame: "local" # "world" or "local" 8 | m: 1.0 # total mass 9 | g: 9.81 # gravity 10 | D: # drag coef normed by mass 11 | default: 0.375 12 | enabled: True 13 | min: 0.3 14 | max: 0.45 15 | lmbda: # acc exp smoothing 16 | default: 2.6 17 | enabled: False 18 | min: 2.5 19 | max: 2.7 20 | alpha: 2 # gradient decay factor 21 | align_yaw_with_target_direction: False 22 | align_yaw_with_vel_ema: True 23 | vel_ema_factor: 24 | default: 0.1 25 | enabled: True 26 | min: 0.08 27 | max: 0.12 28 | 29 | max_acc: 30 | xy: 31 | default: 20 32 | enabled: True 33 | min: 10 34 | max: 20 35 | z: 36 | default: 40 37 | enabled: True 38 | min: 25 39 | max: 40 -------------------------------------------------------------------------------- /cfg/dynamics/quad.yaml: -------------------------------------------------------------------------------- 1 | name: quadrotor 2 | abbr: quad 3 | 4 | n_envs: ${n_envs} 5 | n_agents: ${n_agents} 6 | dt: ${env.dt} 7 | action_frame: "body" 8 | alpha: 1.0 # gradient decay factor 9 | m: # total mass 10 | default: 1.0 11 | enabled: True 12 | min: 0.8 13 | max: 1.2 14 | arm_l: # arm length 15 | default: 0.15 16 | enabled: True 17 | min: 0.1 18 | max: 0.2 19 | c_tau: # torque constant 20 | default: 0.0133 21 | enabled: True 22 | min: 0.01 23 | max: 0.0166 24 | g: 9.81 # gravity 25 | J: 26 | xy: 27 | default: 0.01 28 | enabled: True 29 | min: 0.008 30 | max: 0.012 31 | z: 32 | default: 0.02 33 | enabled: True 34 | min: 0.015 35 | max: 0.025 36 | D: 37 | xy: 38 | default: 0.6 39 | enabled: True 40 | min: 0.5 41 | max: 0.7 42 | z: 43 | default: 0.6 44 | enabled: True 45 | min: 0.5 46 | max: 0.7 47 | max_w_xy: 5.0 # max angular velocity (xy) 48 | max_w_z: 1.0 # max angular velocity (z) 49 | max_T: 4.179 # max thrust 50 | min_T: 0.0 # min thrust 51 | lmbda: 0.1 # acc exp smoothing 52 | 53 | solver_type: "rk4" # "rk4" or "euler" 54 | n_substeps: 1 55 | 56 | controller: 57 | compensate_gravity: False 58 | 59 | min_normed_thrust: 0. 60 | max_normed_thrust: 5. 61 | 62 | min_pitch_rate: -3.14 63 | max_pitch_rate: 3.14 64 | 65 | min_roll_rate: -3.14 66 | max_roll_rate: 3.14 67 | 68 | min_yaw_rate: -3.14 69 | max_yaw_rate: 3.14 70 | 71 | thrust_ratio: 1. 72 | torque_ratio: 1. 73 | 74 | min_normed_torque: [-15., -15., -15.] 75 | max_normed_torque: [ 15., 15., 15.] 76 | 77 | K_angvel: [1., 1., 1.] -------------------------------------------------------------------------------- /cfg/env/imu/default_imu.yaml: -------------------------------------------------------------------------------- 1 | name: default_imu 2 | 3 | enable_drift: True 4 | enable_noise: True 5 | 6 | imu_mounting_error_range_deg: 2. 7 | 8 | # acc_drift_std: 0.0013564659966250536 9 | # acc_noise_std: 0.0015690639999999998 10 | 11 | acc_drift_std: 0.13564659966250536 12 | acc_noise_std: 0.15690639999999998 13 | 14 | # gyro_drift_std: 1.4352700094407325e-05 15 | # gyro_noise_std: 0.0002443460952792061 16 | 17 | gyro_drift_std: 1.4352700094407325e-03 18 | gyro_noise_std: 0.02443460952792061 -------------------------------------------------------------------------------- /cfg/env/mapc.yaml: -------------------------------------------------------------------------------- 1 | name: multi_agent_position_control 2 | abbr: mapc 3 | 4 | dynamics: ${dynamics} 5 | dt: 0.0333 6 | length: 7 | default: 10 8 | enabled: False 9 | min: 10 10 | max: 10 11 | n_envs: ${n_envs} 12 | n_agents: ${n_agents} 13 | obs_frame: "world" # "world" or "local" 14 | max_time: 20 15 | wait_before_truncate: 5. 16 | max_target_vel: 5. 17 | min_target_vel: 5. 18 | last_action_in_obs: False 19 | 20 | solver_type: 21 | n_substeps: 1 22 | 23 | loss_weights: 24 | pointmass: 25 | vel: 1. 26 | jerk: 0.0001 27 | pos: 1. 28 | collision: 5. 29 | 30 | reward_weights: 31 | constant: 4. 32 | pointmass: 33 | vel: 1. 34 | jerk: 0.005 35 | pos: 1. 36 | collision: 16. 37 | 38 | defaults: 39 | - render: pc_render 40 | - randomizer: default_randomizer 41 | - imu: default_imu 42 | - _self_ -------------------------------------------------------------------------------- /cfg/env/oa.yaml: -------------------------------------------------------------------------------- 1 | name: obstacle_avoidance 2 | abbr: oa 3 | 4 | sensor: ${sensor} 5 | dynamics: ${dynamics} 6 | dt: 0.0333 7 | length: 8 | default: 25 9 | enabled: False 10 | min: 25 11 | max: 25 12 | height_scale: 0.25 13 | n_envs: ${n_envs} 14 | n_agents: ${n_agents} 15 | obs_frame: "local" # "world" or "local" 16 | max_time: 30 17 | wait_before_truncate: 5. 18 | max_target_vel: 6. 19 | min_target_vel: 3. 20 | last_action_in_obs: False 21 | r_drone: 0.2 22 | n_obstacles: 30 23 | ground_plane: True 24 | 25 | loss_weights: 26 | pointmass: 27 | vel: 2 28 | z: 1. 29 | oa: 5. 30 | jerk: 0.01 31 | pos: 5. 32 | collision: 0. 33 | quadrotor: 34 | vel: 1. 35 | oa: 1. 36 | jerk: 1. 37 | pos: 5. 38 | collision: 0. 39 | 40 | reward_weights: 41 | # for Reinforcement Learning 42 | # constant: 1. 43 | # pointmass: 44 | # vel: 0.06 45 | # z: 0.2 46 | # oa: 0.6 47 | # jerk: 0.0005 48 | # pos: 0.5 49 | # arrive: 0. 50 | # collision: 5. 51 | 52 | # for SHA2C 53 | constant: 2 54 | pointmass: 55 | vel: 0. 56 | z: 0. 57 | oa: 0. 58 | jerk: 0.001 59 | pos: 0. 60 | arrive: 1. 61 | collision: 50. 62 | quadrotor: 63 | vel: 1. 64 | oa: 1. 65 | jerk: 1. 66 | pos: 5. 67 | collision: 50. 68 | 69 | enable_grid: False 70 | defaults: 71 | - render: oa_render 72 | - randomizer: default_randomizer 73 | - imu: default_imu 74 | - obstacles: outdoor 75 | - _self_ 76 | -------------------------------------------------------------------------------- /cfg/env/oa_small.yaml: -------------------------------------------------------------------------------- 1 | name: obstacle_avoidance 2 | abbr: oa_small 3 | 4 | sensor: ${sensor} 5 | dynamics: ${dynamics} 6 | dt: 0.0333 7 | length: 8 | default: 5 9 | enabled: True 10 | min: 3 11 | max: 7 12 | height_scale: 0.25 13 | n_envs: ${n_envs} 14 | n_agents: ${n_agents} 15 | obs_frame: "local" # "world" or "local" 16 | max_time: 30 17 | wait_before_truncate: 5. 18 | max_target_vel: 4. 19 | min_target_vel: 2. 20 | last_action_in_obs: False 21 | r_drone: 0.14 22 | n_obstacles: 12 23 | ground_plane: True 24 | 25 | loss_weights: 26 | pointmass: 27 | vel: 1. 28 | z: 2. 29 | oa: 4. 30 | jerk: 0.02 31 | pos: 5. 32 | collision: 0. 33 | quadrotor: 34 | vel: 1. 35 | oa: 1. 36 | jerk: 1. 37 | pos: 5. 38 | collision: 0. 39 | 40 | reward_weights: 41 | # for Reinforcement Learning 42 | # constant: 1. 43 | # pointmass: 44 | # vel: 0.06 45 | # z: 0.2 46 | # oa: 0.6 47 | # jerk: 0.0005 48 | # pos: 0.5 49 | # arrive: 0. 50 | # collision: 5. 51 | 52 | # for SHA2C 53 | constant: 1. 54 | pointmass: 55 | vel: 0. 56 | z: 0. 57 | oa: 0. 58 | jerk: 0.001 59 | pos: 0. 60 | arrive: 1. 61 | collision: 50. 62 | quadrotor: 63 | vel: 1. 64 | oa: 1. 65 | jerk: 1. 66 | pos: 5. 67 | collision: 50. 68 | 69 | enable_grid: False 70 | defaults: 71 | - render: oa_render 72 | - randomizer: default_randomizer 73 | - imu: default_imu 74 | - obstacles: small_room 75 | - _self_ 76 | -------------------------------------------------------------------------------- /cfg/env/obstacles/outdoor.yaml: -------------------------------------------------------------------------------- 1 | name: outdoor 2 | 3 | walls: False 4 | ceiling: False 5 | n_obstacles: ${env.n_obstacles} 6 | height_scale: ${env.height_scale} 7 | sphere_percentage: 0.33 8 | sphere_radius_range: [0.6, 2.0, 0.4] # [min, max, step] 9 | cube_lw_range: [0.6, 1.4, 0.2] # [min, max, step] 10 | cube_h_range: [10.0, 15.0, 1.0] # [min, max, step] 11 | randomize_cube_pose: True 12 | cube_roll_pitch_range: 15. # degrees 13 | randpos_std_min: 6 14 | randpos_std_max: 7 15 | safety_range: 1.7 -------------------------------------------------------------------------------- /cfg/env/obstacles/small_room.yaml: -------------------------------------------------------------------------------- 1 | name: small_room 2 | 3 | walls: True 4 | ceiling: True 5 | n_obstacles: ${env.n_obstacles} 6 | height_scale: ${env.height_scale} 7 | sphere_percentage: 0.4 8 | sphere_radius_range: [0.1, 0.5, 0.1] # [min, max, step] 9 | cube_lw_range: [0.1, 0.6, 0.1] # [min, max, step] 10 | cube_h_range: [3.0, 5.0, 1.0] # [min, max, step] 11 | randomize_cube_pose: True 12 | cube_roll_pitch_range: 15. # degrees 13 | randpos_std_min: 1. 14 | randpos_std_max: 1.5 15 | safety_range: 1.0 -------------------------------------------------------------------------------- /cfg/env/pc.yaml: -------------------------------------------------------------------------------- 1 | name: position_control 2 | abbr: pc 3 | 4 | dynamics: ${dynamics} 5 | dt: 0.0333 6 | length: 7 | default: 10 8 | enabled: False 9 | min: 10 10 | max: 10 11 | n_envs: ${n_envs} 12 | n_agents: ${n_agents} 13 | obs_frame: "local" # "world" or "local" 14 | max_time: 20 15 | wait_before_truncate: 5. 16 | max_target_vel: 10. 17 | min_target_vel: 5. 18 | last_action_in_obs: False 19 | 20 | loss_weights: 21 | pointmass: 22 | vel: 1. 23 | jerk: 0.01 24 | pos: 0.1 25 | quadrotor: 26 | vel: 1. 27 | jerk: 0.001 28 | pos: 3. 29 | attitude: 0.1 30 | 31 | reward_weights: 32 | constant: 1. 33 | pointmass: 34 | vel: 0.5 35 | jerk: 0.01 36 | pos: 0.1 37 | quadrotor: 38 | vel: 1. 39 | jerk: 0.001 40 | pos: 3. 41 | attitude: 0.1 42 | 43 | render: 44 | env_spacing: 0.5 45 | 46 | defaults: 47 | - render: pc_render 48 | - randomizer: default_randomizer 49 | - imu: default_imu 50 | - _self_ -------------------------------------------------------------------------------- /cfg/env/racing.yaml: -------------------------------------------------------------------------------- 1 | name: racing 2 | abbr: racing 3 | 4 | dynamics: ${dynamics} 5 | dt: 0.0333 6 | length: 7 | default: 10 8 | enabled: False 9 | min: 10 10 | max: 10 11 | n_envs: ${n_envs} 12 | n_agents: ${n_agents} 13 | obs_frame: "world" # "world" or "local" 14 | max_time: 40 15 | wait_before_truncate: 40 16 | max_target_vel: 10. 17 | min_target_vel: 5. 18 | last_action_in_obs: False 19 | use_vel_track: False 20 | ref_path: ${ref_path} 21 | 22 | gates: 23 | radius: 1.5 24 | height: 1.5 25 | 26 | loss_weights: 27 | pointmass: 28 | vel: 0.5 29 | jerk: 2e-6 30 | pos: 3. 31 | progress: 0. 32 | track: 1. 33 | quadrotor: 34 | vel: 1. 35 | jerk: 0.001 36 | pos: 3. 37 | attitude: 0.1 38 | 39 | reward_weights: 40 | constant: 0. 41 | pointmass: 42 | vel: 0. 43 | jerk: 0. 44 | passed: 0. 45 | oob: 0. 46 | pos: 0. 47 | collision: 10 48 | progress: 1. 49 | track: 0. 50 | quadrotor: 51 | vel: 0. 52 | jerk: 0.1 53 | pos: 0. 54 | attitude: 0. 55 | progress: 10. 56 | collision: 10 57 | 58 | render: 59 | env_spacing: 0. 60 | 61 | defaults: 62 | - render: pc_render 63 | - randomizer: default_randomizer 64 | - imu: default_imu 65 | - _self_ -------------------------------------------------------------------------------- /cfg/env/randomizer/default_randomizer.yaml: -------------------------------------------------------------------------------- 1 | name: default_randomizer 2 | 3 | enabled: True -------------------------------------------------------------------------------- /cfg/env/render/oa_render.yaml: -------------------------------------------------------------------------------- 1 | name: obstacle_avoidance_renderer 2 | 3 | headless: ${headless} 4 | n_envs: ${env.n_envs} 5 | n_agents: ${n_agents} 6 | dt: ${env.dt} 7 | render_n_envs: 16 8 | env_spacing: ${env.length.max} 9 | physics_engine: "physx" 10 | ground_plane: ${env.ground_plane} 11 | record_video: ${record_video} 12 | 13 | sphere_n_segments: 12 14 | 15 | video_camera: 16 | width: 512 17 | height: 288 18 | fov: 60.0 19 | far_plane: 20 20 | 21 | viewer: 22 | ref_env: 0 23 | pos: [-5, -5, 4] 24 | lookat: [0, 0, 0] 25 | 26 | camera: ${env.sensor} 27 | -------------------------------------------------------------------------------- /cfg/env/render/pc_render.yaml: -------------------------------------------------------------------------------- 1 | name: position_control_renderer 2 | 3 | headless: ${headless} 4 | n_envs: ${env.n_envs} 5 | n_agents: ${n_agents} 6 | dt: ${env.dt} 7 | render_n_envs: 256 8 | env_spacing: 2 9 | physics_engine: "physx" 10 | ground_plane: True 11 | record_video: ${record_video} 12 | 13 | video_camera: 14 | width: 512 15 | height: 288 16 | fov: 60.0 17 | far_plane: 20 -------------------------------------------------------------------------------- /cfg/hydra/help/test_help.yaml: -------------------------------------------------------------------------------- 1 | app_name: DiffAero Testing Script 2 | 3 | header: ===== ${hydra.help.app_name} ===== 4 | 5 | footer: |- 6 | View https://github.com/zxh0916/diffaero for more information. 7 | Powered by Hydra (https://hydra.cc) 8 | Use --hydra-help to view Hydra specific help and for Tab completion. 9 | 10 | # Basic Hydra flags: 11 | # $FLAGS_HELP 12 | # 13 | # Config groups, choose one of: 14 | # $APP_CONFIG_GROUPS: All config groups that does not start with hydra/. 15 | # $HYDRA_CONFIG_GROUPS: All the Hydra config groups (starts with hydra/) 16 | # 17 | # Configuration generated with overrides: 18 | # $CONFIG : Generated config 19 | # 20 | template: |- 21 | ${hydra.help.header} 22 | This is ${hydra.help.app_name}! 23 | You can start the simulation and test your trained policy fully on GPU with this script! 24 | If `checkpoint` not specified, the script will use the latest checkpoint in the `outputs` directory. 25 | 26 | Compose your configuration from those groups: 27 | 28 | ↓↓↓↓↓ Configuration groups ↓↓↓↓↓ 29 | $APP_CONFIG_GROUPS↑↑↑↑↑ Configuration groups ↑↑↑↑↑ 30 | 31 | Visualize the environment and your trained policy via GUI: 32 | `python script/test.py checkpoint=/absolute/path/to/checkpoint headless=False` 33 | 34 | Run many experiments in parallel by: 35 | `python script/test.py -m checkpoint=/absolute/path/to/checkpoint seed=0,1,2,3 n_jobs=4` 36 | 37 | ${hydra.help.footer} -------------------------------------------------------------------------------- /cfg/hydra/help/train_help.yaml: -------------------------------------------------------------------------------- 1 | app_name: DiffAero Training Script 2 | 3 | header: ===== ${hydra.help.app_name} ===== 4 | 5 | footer: |- 6 | View https://github.com/zxh0916/diffaero for more information. 7 | Powered by Hydra (https://hydra.cc) 8 | Use --hydra-help to view Hydra specific help and for Tab completion. 9 | 10 | # Basic Hydra flags: 11 | # $FLAGS_HELP 12 | # 13 | # Config groups, choose one of: 14 | # $APP_CONFIG_GROUPS: All config groups that does not start with hydra/. 15 | # $HYDRA_CONFIG_GROUPS: All the Hydra config groups (starts with hydra/) 16 | # 17 | # Configuration generated with overrides: 18 | # $CONFIG : Generated config 19 | # 20 | template: |- 21 | ${hydra.help.header} 22 | This is ${hydra.help.app_name}! 23 | You can start the simulation and train a policy fully on GPU with this script! 24 | 25 | Compose your configuration from those groups: 26 | 27 | ↓↓↓↓↓ Configuration groups ↓↓↓↓↓ 28 | $APP_CONFIG_GROUPS↑↑↑↑↑ Configuration groups ↑↑↑↑↑ 29 | 30 | Feel free to override any of the configuration groups above by: 31 | `python script/train.py env=pc algo=shac algo.actor_lr=0.003` 32 | 33 | Run many experiments in parallel by: 34 | `python script/train.py -m seed=0,1,2,3 n_jobs=4` 35 | 36 | ${hydra.help.footer} -------------------------------------------------------------------------------- /cfg/hydra/sweeper/optuna_sweep.yaml: -------------------------------------------------------------------------------- 1 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 2 | sampler: 3 | _target_: optuna.samplers.TPESampler 4 | seed: ${seed} 5 | consider_prior: true 6 | prior_weight: 1.0 7 | consider_magic_clip: true 8 | consider_endpoints: false 9 | n_startup_trials: 10 10 | n_ei_candidates: 24 11 | multivariate: false 12 | warn_independent_sampling: true 13 | direction: maximize 14 | storage: null 15 | study_name: maximize-success-rate 16 | n_jobs: ${n_jobs} 17 | search_space: null 18 | custom_search_space: null 19 | 20 | n_trials: 10 21 | params: 22 | # https://hydra.cc/docs/advanced/override_grammar/extended/#sweeps 23 | algo.lr: tag(log,interval(0.0001,0.1)) -------------------------------------------------------------------------------- /cfg/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | name: tensorboard -------------------------------------------------------------------------------- /cfg/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | name: wandb 2 | 3 | project: diffaero 4 | entity: 5 | quiet: True -------------------------------------------------------------------------------- /cfg/network/cnn.yaml: -------------------------------------------------------------------------------- 1 | name: cnn 2 | 3 | cnn_layers: [ 4 | # [in_channels, out_channels, stride] 5 | [1, 8, 1], 6 | [8, 16, 2], 7 | [16, 8, 1], 8 | [8, 8, 2] 9 | ] 10 | 11 | hidden_dim: [256, 128] -------------------------------------------------------------------------------- /cfg/network/mlp.yaml: -------------------------------------------------------------------------------- 1 | name: mlp 2 | 3 | hidden_dim: [256, 128] -------------------------------------------------------------------------------- /cfg/network/rcnn.yaml: -------------------------------------------------------------------------------- 1 | name: rcnn 2 | 3 | cnn_layers: [ 4 | # [in_channels, out_channels, stride] 5 | [1, 8, 1], 6 | [8, 16, 2], 7 | [16, 8, 1], 8 | [8, 8, 2] 9 | ] 10 | 11 | hidden_dim: [256, 128] 12 | 13 | rnn_hidden_dim: 512 14 | rnn_n_layers: 1 -------------------------------------------------------------------------------- /cfg/network/rnn.yaml: -------------------------------------------------------------------------------- 1 | name: rnn 2 | 3 | hidden_dim: [256, 128] 4 | 5 | rnn_hidden_dim: 512 6 | rnn_n_layers: 1 -------------------------------------------------------------------------------- /cfg/sensor/camera.yaml: -------------------------------------------------------------------------------- 1 | name: camera 2 | 3 | # Intel Realsense D435 4 | width: 16 5 | height: 9 6 | horizontal_fov: 86.0 7 | max_dist: 5.0 8 | 9 | onboard_position: [0.2, 0., 0.05] 10 | onboard_attitude: [0., 0., 0., 1.] # xyzw quaternion 11 | 12 | # pose randomization 13 | n_envs: ${n_envs} 14 | n_agents: ${n_agents} 15 | pitch_angle_deg: 16 | default: 0 17 | enabled: True 18 | min: -5 19 | max: 5 20 | 21 | yaw_angle_deg: 22 | default: 0 23 | enabled: True 24 | min: -3 25 | max: 3 26 | 27 | roll_angle_deg: 28 | default: 0 29 | enabled: True 30 | min: -3 31 | max: 3 -------------------------------------------------------------------------------- /cfg/sensor/lidar.yaml: -------------------------------------------------------------------------------- 1 | name: lidar 2 | 3 | depression_angle: -20. 4 | elevation_angle: 20. 5 | 6 | n_rays_vertical: 5 7 | n_rays_horizontal: 36 8 | 9 | max_dist: 20 10 | 11 | # pose randomization 12 | n_envs: ${n_envs} 13 | n_agents: ${n_agents} 14 | pitch_angle_deg: 15 | default: 0 16 | enabled: True 17 | min: -3 18 | max: 3 19 | 20 | yaw_angle_deg: 21 | default: 0 22 | enabled: True 23 | min: -10 24 | max: 10 25 | 26 | roll_angle_deg: 27 | default: 0 28 | enabled: True 29 | min: -3 30 | max: 3 -------------------------------------------------------------------------------- /cfg/sensor/relpos.yaml: -------------------------------------------------------------------------------- 1 | name: relpos 2 | 3 | n_obstacles: ${env.n_obstacles} 4 | walls: ${env.obstacles.walls} 5 | ceiling: ${env.obstacles.ceiling} -------------------------------------------------------------------------------- /dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | 6 | from .pointmass import ContinuousPointMassModel, DiscretePointMassModel, PointMassModelBase 7 | from .quadrotor import QuadrotorModel 8 | 9 | DYNAMICS_ALIAS = { 10 | "countinuous_pointmass": ContinuousPointMassModel, 11 | "discrete_pointmass": DiscretePointMassModel, 12 | "quadrotor": QuadrotorModel 13 | } 14 | 15 | def build_dynamics(cfg, device): 16 | # type: (DictConfig, torch.device) -> Union[ContinuousPointMassModel, DiscretePointMassModel, QuadrotorModel] 17 | return DYNAMICS_ALIAS[cfg.name](cfg, device) -------------------------------------------------------------------------------- /dynamics/base_dynamics.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | from torch import Tensor 5 | import torch.autograd as autograd 6 | import torch.nn.functional as F 7 | import pytorch3d.transforms as T 8 | from omegaconf import DictConfig 9 | 10 | from diffaero.utils.math import quat_rotate, quat_rotate_inverse, mvp, axis_rotmat, quaternion_to_euler 11 | from diffaero.utils.logger import Logger 12 | 13 | class BaseDynamics(ABC): 14 | def __init__(self, cfg: DictConfig, device: torch.device): 15 | self.type: str 16 | self.state_dim: int 17 | self.action_dim: int 18 | self.device = device 19 | self.n_agents: int = cfg.n_agents 20 | self.n_envs: int = cfg.n_envs 21 | self.dt: float = cfg.dt 22 | self.alpha: float = cfg.alpha 23 | 24 | self._G = torch.tensor(cfg.g, device=device, dtype=torch.float32) 25 | self._G_vec = torch.tensor([0.0, 0.0, -self._G], device=device, dtype=torch.float32) 26 | if self.n_agents > 1: 27 | self._G_vec.unsqueeze_(0) 28 | 29 | def detach(self): 30 | """Detach the state to prevent backpropagation through released computation graphs.""" 31 | self._state = self._state.detach() 32 | 33 | def grad_decay(self, state: Tensor) -> Tensor: 34 | if self.alpha > 0: 35 | state = GradientDecay.apply(state, self.alpha, self.dt) 36 | return state 37 | 38 | @abstractmethod 39 | def step(self, U: Tensor) -> None: 40 | """Step the model with the given action U. 41 | 42 | Args: 43 | U (Tensor): The action tensor of shape (n_envs, n_agents, 3). 44 | """ 45 | raise NotImplementedError("This method should be implemented in subclasses.") 46 | 47 | # Action ranges 48 | @property 49 | @abstractmethod 50 | def min_action(self) -> Tensor: 51 | raise NotImplementedError 52 | 53 | @property 54 | @abstractmethod 55 | def max_action(self) -> Tensor: 56 | raise NotImplementedError 57 | 58 | # Properties of agents, requires_grad=True if stepped with undetached action inputs 59 | @property 60 | @abstractmethod 61 | def _p(self) -> Tensor: 62 | raise NotImplementedError 63 | 64 | @property 65 | @abstractmethod 66 | def _v(self) -> Tensor: 67 | raise NotImplementedError 68 | 69 | @property 70 | @abstractmethod 71 | def _a(self) -> Tensor: 72 | raise NotImplementedError 73 | 74 | @property 75 | @abstractmethod 76 | def _w(self) -> Tensor: 77 | raise NotImplementedError 78 | 79 | @property 80 | @abstractmethod 81 | def _q(self) -> Tensor: 82 | """Quaternion representing the orientation of the body frame in world frame, with real part last.""" 83 | raise NotImplementedError 84 | 85 | # Detached versions of properties 86 | @property 87 | def p(self) -> Tensor: 88 | return self._p.detach() 89 | 90 | @property 91 | def v(self) -> Tensor: 92 | return self._v.detach() 93 | 94 | @property 95 | def a(self) -> Tensor: 96 | return self._a.detach() 97 | 98 | @property 99 | def w(self) -> Tensor: 100 | return self._w.detach() 101 | 102 | @property 103 | def q(self) -> Tensor: 104 | return self._q.detach() 105 | 106 | # Rotation utilities 107 | @property 108 | def R(self) -> Tensor: 109 | "Rotation matrix with columns being coordinate of axis unit vectors of body frame in world frame." 110 | return T.quaternion_to_matrix(self.q.roll(1, -1)) 111 | 112 | @property 113 | def Rz(self) -> Tensor: 114 | "Rotation matrix with columns being coordinate of axis unit vectors of local frame in world frame." 115 | # Rz = self.R.clone() 116 | # fwd = Rz[..., 0] 117 | # fwd[..., 2] = 0. 118 | # fwd = F.normalize(fwd, dim=-1) 119 | # up = torch.zeros_like(fwd) 120 | # up[..., 2] = 1. 121 | # left = torch.cross(up, fwd, dim=-1) 122 | # return torch.stack([fwd, left, up], dim=-1) 123 | return axis_rotmat("Z", quaternion_to_euler(self.q)[..., 2]) 124 | 125 | @property 126 | def ux(self) -> Tensor: 127 | "Unit vector along the x-axis of the body frame in world frame." 128 | return self.R[..., 0] 129 | 130 | @property 131 | def uy(self) -> Tensor: 132 | "Unit vector along the y-axis of the body frame in world frame." 133 | return self.R[..., 1] 134 | 135 | @property 136 | def uz(self) -> Tensor: 137 | "Unit vector along the z-axis of the body frame in world frame." 138 | return self.R[..., 2] 139 | 140 | def world2body(self, vec_w: Tensor) -> Tensor: 141 | """ 142 | Convert vector from world frame to body frame. 143 | Args: 144 | vec_w (Tensor): vector in world frame 145 | Returns: 146 | Tensor: vector in body frame 147 | """ 148 | return quat_rotate_inverse(self.q, vec_w) 149 | 150 | def body2world(self, vec_b: Tensor) -> Tensor: 151 | """ 152 | Convert vector from body frame to world frame. 153 | Args: 154 | vec_b (Tensor): vector in body frame 155 | Returns: 156 | Tensor: vector in world frame 157 | """ 158 | return quat_rotate(self.q, vec_b) 159 | 160 | def world2local(self, vec_w: Tensor) -> Tensor: 161 | """ 162 | Convert vector from world frame to local frame. 163 | Args: 164 | vec_w (Tensor): vector in world frame 165 | Returns: 166 | Tensor: vector in local frame 167 | """ 168 | # Logger.debug(mvp(self.Rz.transpose(-1, -2), self.ux)[0][..., 1].cpu(), "should be around 0") 169 | return mvp(self.Rz.transpose(-1, -2), vec_w) 170 | 171 | def local2world(self, vec_l: Tensor) -> Tensor: 172 | """ 173 | Convert vector from local frame to world frame. 174 | Args: 175 | vec_l (Tensor): vector in local frame 176 | Returns: 177 | Tensor: vector in world frame 178 | """ 179 | return mvp(self.Rz, vec_l) 180 | 181 | class GradientDecay(autograd.Function): 182 | @staticmethod 183 | def forward(ctx, state: Tensor, alpha: float, dt: float): 184 | ctx.save_for_backward(torch.tensor(-alpha * dt, device=state.device).exp()) 185 | return state 186 | 187 | @staticmethod 188 | def backward(ctx, grad_state: Tensor): 189 | decay_factor = ctx.saved_tensors[0] 190 | if ctx.needs_input_grad[0]: 191 | grad_state = grad_state * decay_factor 192 | return grad_state, None, None -------------------------------------------------------------------------------- /dynamics/controller.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import pytorch3d.transforms as p3d_transforms 5 | from omegaconf import DictConfig 6 | 7 | class BaseController: 8 | """Convert the action from RL agent to force and torques to be applied on the drone.""" 9 | def __init__( 10 | self, 11 | mass: torch.Tensor, 12 | inertia: torch.Tensor, 13 | gravity: torch.Tensor, 14 | cfg: DictConfig, 15 | device: torch.device 16 | ): 17 | self.cfg = cfg 18 | self.device = device 19 | self.mass = mass 20 | self.inertia = inertia 21 | self.gravity = gravity 22 | self.thrust_ratio: float = cfg.thrust_ratio 23 | self.torque_ratio: float = cfg.torque_ratio 24 | 25 | # lower bound of controller output (actual normed force & torque) 26 | self.min_thrust = torch.tensor(cfg.min_normed_thrust, device=device) 27 | self.min_torque = torch.tensor(list(cfg.min_normed_torque), device=device) 28 | 29 | # upper bound of controller output (actual normed force & torque) 30 | self.max_thrust = torch.tensor(cfg.max_normed_thrust, device=device) 31 | self.max_torque = torch.tensor(list(cfg.max_normed_torque), device=device) 32 | 33 | def __call__(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 34 | raise NotImplementedError 35 | 36 | def postprocess(self, normed_thrust, normed_torque): 37 | # type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 38 | normed_torque = normed_torque * self.torque_ratio 39 | normed_thrust = normed_thrust * self.thrust_ratio 40 | # compensate gravity 41 | if self.cfg.compensate_gravity: 42 | normed_thrust += 1. 43 | thrust = normed_thrust * self.gravity * self.mass 44 | torque = normed_torque * self.inertia 45 | return thrust, torque 46 | 47 | 48 | class RateController(BaseController): 49 | """ 50 | Body Rate Controller. 51 | 52 | Take desired thrust, roll rate, picth rate, and yaw rate as input 53 | and output actual force and torque to be applied on the robot. 54 | """ 55 | def __init__( 56 | self, 57 | mass: torch.Tensor, 58 | inertia: torch.Tensor, 59 | gravity: torch.Tensor, 60 | cfg: DictConfig, 61 | device: torch.device 62 | ): 63 | super().__init__(mass, inertia, gravity, cfg, device) 64 | self.K_angvel = torch.tensor(cfg.K_angvel, device=device) 65 | 66 | # lower bound of controller input (action) 67 | self.min_action = torch.tensor([ 68 | cfg.min_normed_thrust, 69 | cfg.min_roll_rate, 70 | cfg.min_pitch_rate, 71 | cfg.min_yaw_rate 72 | ], device=device) 73 | 74 | # upper bound of controller input (action) 75 | self.max_action = torch.tensor([ 76 | cfg.max_normed_thrust, 77 | cfg.max_roll_rate, 78 | cfg.max_pitch_rate, 79 | cfg.max_yaw_rate 80 | ], device=device) 81 | 82 | def __call__(self, q_xyzw, w, action): 83 | # type: (torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 84 | 85 | # quaternion with real component first 86 | R_b2i = p3d_transforms.quaternion_to_matrix(q_xyzw.roll(1, dims=-1)) 87 | # for numeric stability, very important 88 | R_b2i.clamp_(min=-1.0+1e-6, max=1.0-1e-6) 89 | # Convert current rotation matrix to euler angles 90 | R_i2b = torch.transpose(R_b2i, -1, -2) 91 | 92 | desired_angvel_b = action[:, 1:] 93 | actual_angvel_b = torch.bmm(R_i2b, w.unsqueeze(-1)).squeeze(-1) 94 | angvel_err = desired_angvel_b - actual_angvel_b 95 | 96 | # Ω × JΩ 97 | cross = torch.cross(actual_angvel_b, (self.inertia @ actual_angvel_b.unsqueeze(-1)).squeeze(-1), dim=1) 98 | cross.div_(torch.max(cross.norm(dim=-1, keepdim=True) / 100, 99 | torch.tensor(1., device=cross.device)).detach()) 100 | angacc = self.torque_ratio * self.K_angvel * angvel_err 101 | torque = (self.inertia @ angacc.unsqueeze(-1)).squeeze(-1) + cross 102 | thrust = action[:, 0] * self.thrust_ratio * self.gravity * self.mass 103 | return thrust, torque -------------------------------------------------------------------------------- /dynamics/pointmass.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import functional as F 6 | from pytorch3d import transforms as T 7 | from omegaconf import DictConfig 8 | 9 | from diffaero.dynamics.base_dynamics import BaseDynamics 10 | from diffaero.utils.math import EulerIntegral, rk4, axis_rotmat, mvp, quat_standardize, quat_mul 11 | from diffaero.utils.randomizer import build_randomizer 12 | 13 | class PointMassModelBase(BaseDynamics): 14 | def __init__(self, cfg: DictConfig, device: torch.device): 15 | super().__init__(cfg, device) 16 | self.type = "pointmass" 17 | self.action_frame: str = cfg.action_frame 18 | assert self.action_frame in ["world", "local"], f"Invalid action frame: {self.action_frame}. Must be 'world' or 'local'." 19 | self.state_dim = 9 20 | self.action_dim = 3 21 | self._state = torch.zeros(self.n_envs, self.n_agents, self.state_dim, device=device) 22 | self._vel_ema = torch.zeros(self.n_envs, self.n_agents, 3, device=device) 23 | self._acc = torch.zeros(self.n_envs, self.n_agents, 3, device=device) 24 | xyz = torch.zeros(self.n_envs, self.n_agents, 3, device=device) 25 | w = torch.ones(self.n_envs, self.n_agents, 1, device=device) 26 | self.quat_xyzw = torch.cat([xyz, w], dim=-1) 27 | self.quat_xyzw_init = self.quat_xyzw.clone() 28 | if self.n_agents == 1: 29 | self._state.squeeze_(1) 30 | self._vel_ema.squeeze_(1) 31 | self._acc.squeeze_(1) 32 | self.quat_xyzw.squeeze_(1) 33 | self.quat_xyzw_init.squeeze_(1) 34 | self.align_yaw_with_target_direction: bool = cfg.align_yaw_with_target_direction 35 | self.align_yaw_with_vel_ema: bool = cfg.align_yaw_with_vel_ema 36 | 37 | self.vel_ema_factor = build_randomizer(cfg.vel_ema_factor, [self.n_envs, self.n_agents, 1], device=device) 38 | self._D = build_randomizer(cfg.D, [self.n_envs, self.n_agents, 1], device=device) 39 | self.lmbda = build_randomizer(cfg.lmbda, [self.n_envs, self.n_agents, 1], device=device) 40 | if self.n_agents == 1: 41 | self.vel_ema_factor.value.squeeze_(1) 42 | self._D.value.squeeze_(1) 43 | self.lmbda.value.squeeze_(1) 44 | self.max_acc_xy = build_randomizer(cfg.max_acc.xy, [self.n_envs, self.n_agents], device=device) 45 | self.max_acc_z = build_randomizer(cfg.max_acc.z, [self.n_envs, self.n_agents], device=device) 46 | 47 | @property 48 | def min_action(self) -> Tensor: 49 | zero = torch.zeros_like(self.max_acc_xy.value) 50 | min_action = torch.stack([-self.max_acc_xy.value, -self.max_acc_xy.value, zero], dim=-1) 51 | if self.n_agents == 1: 52 | min_action.squeeze_(1) 53 | return min_action 54 | 55 | @property 56 | def max_action(self) -> Tensor: 57 | max_action = torch.stack([self.max_acc_xy.value, self.max_acc_xy.value, self.max_acc_z.value], dim=-1) 58 | if self.n_agents == 1: 59 | max_action.squeeze_(1) 60 | return max_action 61 | 62 | def detach(self): 63 | super().detach() 64 | self._vel_ema.detach_() 65 | self._acc.detach_() 66 | 67 | def reset_idx(self, env_idx: Tensor) -> None: 68 | mask = torch.zeros(*self._vel_ema.shape[:-1], dtype=torch.bool, device=self.device) 69 | mask[env_idx] = True 70 | mask3 = mask.unsqueeze(-1).expand_as(self._vel_ema) 71 | self._vel_ema = torch.where(mask3, 0., self._vel_ema) 72 | self._acc = torch.where(mask3, 0., self._acc) 73 | mask4 = mask.unsqueeze(-1).expand_as(self.quat_xyzw) 74 | self.quat_xyzw = torch.where(mask4, self.quat_xyzw_init, self.quat_xyzw) 75 | 76 | @property 77 | def q(self) -> Tensor: return self.quat_xyzw 78 | @property 79 | def w(self) -> Tensor: 80 | warnings.warn("Access of angular velocity in point mass model is not supported. Returning zero tensor instead.") 81 | return torch.zeros_like(self.p) 82 | @property 83 | def _p(self) -> Tensor: return self._state[..., 0:3] 84 | @property 85 | def _v(self) -> Tensor: return self._state[..., 3:6] 86 | @property 87 | def _a(self) -> Tensor: return self._acc 88 | @property 89 | def _a_thrust(self) -> Tensor: return self._state[..., 6:9] 90 | @property 91 | def a_thrust(self) -> Tensor: return self._a_thrust.detach() 92 | @property 93 | def _q(self) -> Tensor: 94 | warnings.warn("Direct access of quaternion with gradient in point mass model is not supported. Returning detached version instead.") 95 | return self.q 96 | @property 97 | def _w(self) -> Tensor: 98 | warnings.warn("Access of angular velocity with gradient in point mass model is not supported. Returning zero tensor instead.") 99 | return torch.zeros_like(self.p) 100 | 101 | def update_state(self, next_state: Tensor) -> None: 102 | self._state = self.grad_decay(next_state) 103 | self._vel_ema = torch.lerp(self._vel_ema, self._v, self.vel_ema_factor.value) 104 | self._acc = self._a_thrust + self._G_vec - self._D.value * self._v 105 | with torch.no_grad(): 106 | orientation = self._vel_ema if self.align_yaw_with_vel_ema else self.v 107 | self.quat_xyzw = point_mass_quat(self.a_thrust, orientation=orientation) 108 | 109 | @torch.jit.script 110 | def continuous_point_mass_dynamics_local( 111 | X: Tensor, 112 | U: Tensor, 113 | dt: float, 114 | Rz: Tensor, 115 | G_vec: Tensor, 116 | D: Tensor, 117 | lmbda: Tensor, 118 | ): 119 | """Dynamics function for continuous point mass model in local frame.""" 120 | p, v, a_thrust = X[..., :3], X[..., 3:6], X[..., 6:9] 121 | p_dot = v 122 | fdrag = -D * v 123 | v_dot = a_thrust + G_vec + fdrag 124 | control_delay_factor = (1 - torch.exp(-lmbda * dt)) / dt 125 | a_thrust_cmd_local = U 126 | a_thrust_cmd = torch.matmul(Rz, a_thrust_cmd_local.unsqueeze(-1)).squeeze(-1) 127 | a_dot = control_delay_factor * (a_thrust_cmd - a_thrust) 128 | 129 | X_dot = torch.concat([p_dot, v_dot, a_dot], dim=-1) 130 | return X_dot 131 | 132 | @torch.jit.script 133 | def continuous_point_mass_dynamics_world( 134 | X: Tensor, 135 | U: Tensor, 136 | dt: float, 137 | G_vec: Tensor, 138 | D: Tensor, 139 | lmbda: Tensor, 140 | ): 141 | """Dynamics function for continuous point mass model in local frame.""" 142 | p, v, a_thrust = X[..., :3], X[..., 3:6], X[..., 6:9] 143 | p_dot = v 144 | fdrag = -D * v 145 | v_dot = a_thrust + G_vec + fdrag 146 | control_delay_factor = (1 - torch.exp(-lmbda * dt)) / dt 147 | a_thrust_cmd = U 148 | a_dot = control_delay_factor * (a_thrust_cmd - a_thrust) 149 | 150 | X_dot = torch.concat([p_dot, v_dot, a_dot], dim=-1) 151 | return X_dot 152 | 153 | class ContinuousPointMassModel(PointMassModelBase): 154 | def __init__(self, cfg: DictConfig, device: torch.device): 155 | super().__init__(cfg, device) 156 | self.n_substeps: int = cfg.n_substeps 157 | assert cfg.solver_type in ["euler", "rk4"] 158 | if cfg.solver_type == "euler": 159 | self.solver = EulerIntegral 160 | elif cfg.solver_type == "rk4": 161 | self.solver = rk4 162 | self.Rz_temp: Tensor 163 | 164 | def dynamics(self, X: Tensor, U: Tensor) -> Tensor: 165 | if self.action_frame == "local": 166 | X_dot = continuous_point_mass_dynamics_local( 167 | X, U, self.dt, self.Rz_temp, self._G_vec, self._D.value, self.lmbda.value 168 | ) 169 | elif self.action_frame == "world": 170 | X_dot = continuous_point_mass_dynamics_world( 171 | X, U, self.dt, self._G_vec, self._D.value, self.lmbda.value 172 | ) 173 | return X_dot 174 | 175 | def step(self, U: Tensor) -> None: 176 | if self.action_frame == "local": 177 | self.Rz_temp = self.Rz.clone() 178 | next_state = self.solver(self.dynamics, self._state, U, dt=self.dt, M=self.n_substeps) 179 | self.update_state(next_state) 180 | 181 | 182 | @torch.jit.script 183 | def discrete_point_mass_dynamics_local( 184 | X: Tensor, 185 | U: Tensor, 186 | dt: float, 187 | Rz: Tensor, 188 | G_vec: Tensor, 189 | D: Tensor, 190 | lmbda: Tensor, 191 | ): 192 | """Dynamics function for discrete point mass model in local frame.""" 193 | p, v, a_thrust = X[..., :3], X[..., 3:6], X[..., 6:9] 194 | next_p = p + dt * (v + 0.5 * (a_thrust + G_vec) * dt) 195 | control_delay_factor = 1 - torch.exp(-lmbda*dt) 196 | a_thrust_cmd_local = U 197 | a_thrust_cmd = mvp(Rz, a_thrust_cmd_local) 198 | next_a = torch.lerp(a_thrust, a_thrust_cmd, control_delay_factor) - D * v 199 | next_v = v + dt * (0.5 * (a_thrust + next_a) + G_vec) 200 | 201 | next_state = torch.cat([next_p, next_v, next_a], dim=-1) 202 | return next_state 203 | 204 | @torch.jit.script 205 | def discrete_point_mass_dynamics_world( 206 | X: Tensor, 207 | U: Tensor, 208 | dt: float, 209 | G_vec: Tensor, 210 | D: Tensor, 211 | lmbda: Tensor, 212 | ): 213 | """Dynamics function for discrete point mass model in world frame.""" 214 | p, v, a_thrust = X[..., :3], X[..., 3:6], X[..., 6:9] 215 | next_p = p + dt * (v + 0.5 * (a_thrust + G_vec) * dt) 216 | control_delay_factor = 1 - torch.exp(-lmbda*dt) 217 | a_thrust_cmd = U 218 | next_a = torch.lerp(a_thrust, a_thrust_cmd, control_delay_factor) - D * v 219 | next_v = v + dt * (0.5 * (a_thrust + next_a) + G_vec) 220 | 221 | next_state = torch.cat([next_p, next_v, next_a], dim=-1) 222 | return next_state 223 | 224 | class DiscretePointMassModel(PointMassModelBase): 225 | 226 | def step(self, U: Tensor) -> None: 227 | if self.action_frame == "local": 228 | next_state = discrete_point_mass_dynamics_local( 229 | self._state, U, self.dt, self.Rz, self._G_vec, self._D.value, self.lmbda.value 230 | ) 231 | elif self.action_frame == "world": 232 | next_state = discrete_point_mass_dynamics_world( 233 | self._state, U, self.dt, self._G_vec, self._D.value, self.lmbda.value 234 | ) 235 | self.update_state(next_state) 236 | 237 | 238 | @torch.jit.script 239 | def point_mass_quat(a: Tensor, orientation: Tensor) -> Tensor: 240 | """Compute the drone pose using target direction and thrust acceleration direction. 241 | 242 | Args: 243 | a (Tensor): the acceleration of the drone in world frame. 244 | orientation (Tensor): at which direction(yaw) the drone should be facing. 245 | 246 | Returns: 247 | Tensor: attitude quaternion of the drone with real part last. 248 | """ 249 | up: Tensor = F.normalize(a, dim=-1) 250 | yaw = torch.atan2(orientation[..., 1], orientation[..., 0]) 251 | mat_yaw = axis_rotmat("Z", yaw) 252 | new_up = (mat_yaw.transpose(-2, -1) @ up.unsqueeze(-1)).squeeze(-1) 253 | z = torch.zeros_like(new_up) 254 | z[..., -1] = 1. 255 | quat_axis = F.normalize(torch.cross(z, new_up, dim=-1), dim=-1) 256 | cos = torch.cosine_similarity(new_up, z, dim=-1) 257 | sin = torch.norm(new_up[..., :2], dim=-1) / (torch.norm(new_up, dim=-1) + 1e-7) 258 | quat_angle = torch.atan2(sin, cos) 259 | quat_pitch_roll_xyz = quat_axis * torch.sin(0.5 * quat_angle).unsqueeze(-1) 260 | quat_pitch_roll_w = torch.cos(0.5 * quat_angle).unsqueeze(-1) 261 | quat_pitch_roll = quat_standardize(torch.cat([quat_pitch_roll_xyz, quat_pitch_roll_w], dim=-1)) 262 | yaw_half = yaw.unsqueeze(-1) / 2 263 | quat_yaw = torch.concat([torch.sin(yaw_half) * z, torch.cos(yaw_half)], dim=-1) # T.matrix_to_quaternion(mat_yaw) 264 | quat_xyzw = quat_mul(quat_yaw, quat_pitch_roll) 265 | 266 | # ori = torch.stack([orientation[..., 0], orientation[..., 1], torch.zeros_like(orientation[..., 2])], dim=-1) 267 | # print(F.normalize(quaternion_apply(quaternion_invert(quat_yaw), ori), dim=-1)[..., 0]) # 1 268 | # assert torch.max(torch.abs(quaternion_apply(quat_wxyz, z) - up)) < 1e-6 269 | # assert torch.max(torch.abs(quaternion_apply(quaternion_invert(quat_wxyz), up) - z)) < 1e-6 270 | # assert torch.max(torch.abs(quaternion_apply(quat_pitch_roll, z) - new_up)) < 1e-6 271 | 272 | # mat = T.quaternion_to_matrix(quat_wxyz) 273 | # print(((mat @ z.unsqueeze(-1)).squeeze(-1) - up).norm(dim=-1).max()) 274 | 275 | # euler = quaternion_to_euler(quat_xyzw) 276 | # mat_roll, mat_pitch, mat_yaw = axis_rotmat("X", euler[..., 0]), axis_rotmat("Y", euler[..., 1]), axis_rotmat("Z", euler[..., 2]) 277 | # mat_rot = mat_roll @ mat_pitch @ mat_yaw 278 | # print((mat_rot @ z.unsqueeze(-1)).squeeze(-1) - up) 279 | 280 | return quat_xyzw 281 | -------------------------------------------------------------------------------- /dynamics/quadrotor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from omegaconf import DictConfig 4 | 5 | from diffaero.dynamics.base_dynamics import BaseDynamics 6 | from diffaero.dynamics.controller import RateController 7 | from diffaero.utils.math import * 8 | from diffaero.utils.randomizer import build_randomizer 9 | 10 | class QuadrotorModel(BaseDynamics): 11 | def __init__(self, cfg: DictConfig, device: torch.device): 12 | super().__init__(cfg, device) 13 | self.type = "quadrotor" 14 | self.state_dim = 13 15 | self.action_dim = 4 16 | self._state = torch.zeros(self.n_envs, self.n_agents, self.state_dim, device=device) 17 | self._acc = torch.zeros(self.n_envs, self.n_agents, 3, device=device) 18 | if self.n_agents == 1: 19 | self._state.squeeze_(1) 20 | self._acc.squeeze_(1) 21 | 22 | self.n_substeps: int = cfg.n_substeps 23 | assert cfg.solver_type in ["euler", "rk4"] 24 | if cfg.solver_type == "euler": 25 | self.solver = EulerIntegral 26 | elif cfg.solver_type == "rk4": 27 | self.solver = rk4 28 | 29 | wrap = lambda x: torch.tensor(x, device=device, dtype=torch.float32) 30 | 31 | self._m = build_randomizer(cfg.m, self.n_envs, device) # total mass 32 | self._arm_l = build_randomizer(cfg.arm_l, self.n_envs, device) # arm length 33 | self._c_tau = build_randomizer(cfg.c_tau, self.n_envs, device) # torque constant 34 | 35 | # inertia 36 | self.J_xy = build_randomizer(cfg.J.xy, self.n_envs, device) 37 | self.J_z = build_randomizer(cfg.J.z, self.n_envs, device) 38 | # drag coefficients 39 | self.D_xy = build_randomizer(cfg.D.xy, self.n_envs, device) 40 | self.D_z = build_randomizer(cfg.D.z, self.n_envs, device) 41 | 42 | self._v_xy_max = wrap(float('inf')) 43 | self._v_z_max = wrap(float('inf')) 44 | self._omega_xy_max = wrap(cfg.max_w_xy) 45 | self._omega_z_max = wrap(cfg.max_w_z) 46 | self._T_max = wrap(cfg.max_T) 47 | self._T_min = wrap(cfg.min_T) 48 | 49 | self._X_lb = wrap([-float('inf'), -float('inf'), -float('inf'), 50 | -self._v_xy_max, -self._v_xy_max, -self._v_z_max, 51 | -1, -1, -1, -1, 52 | -self._omega_xy_max, -self._omega_xy_max, -self._omega_z_max]) 53 | 54 | self._X_ub = wrap([float('inf'), float('inf'), float('inf'), 55 | self._v_xy_max, self._v_xy_max, self._v_z_max, 56 | 1, 1, 1, 1, 57 | self._omega_xy_max, self._omega_xy_max, self._omega_z_max]) 58 | 59 | self._U_lb = wrap([self._T_min, self._T_min, self._T_min, self._T_min]) 60 | self._U_ub = wrap([self._T_max, self._T_max, self._T_max, self._T_max]) 61 | 62 | self.controller = RateController(self._m.value, self._J, self._G, cfg.controller, self.device) 63 | 64 | @property 65 | def min_action(self) -> Tensor: 66 | return self.controller.min_action 67 | @property 68 | def max_action(self) -> Tensor: 69 | return self.controller.max_action 70 | 71 | @property 72 | def _tau_thrust_matrix(self) -> Tensor: 73 | c, d = self._c_tau.value, self._arm_l.value / (2**0.5) 74 | ones = torch.ones(self.n_envs, 4, device=c.device, dtype=c.dtype) 75 | _tau_thrust_matrix = torch.stack([ 76 | torch.stack([ d, -d, -d, d], dim=-1), 77 | torch.stack([-d, d, -d, d], dim=-1), 78 | torch.stack([ c, c, -c, -c], dim=-1), 79 | ones], dim=-2) 80 | print(_tau_thrust_matrix.shape, _tau_thrust_matrix[0]) 81 | return _tau_thrust_matrix 82 | 83 | @property 84 | def _J(self) -> Tensor: 85 | J = torch.zeros(self.n_envs, 3, 3, device=self.device) 86 | J[:, 0, 0] = self.J_xy.value 87 | J[:, 1, 1] = self.J_xy.value 88 | J[:, 2, 2] = self.J_z.value 89 | return J 90 | 91 | @property 92 | def _J_inv(self) -> Tensor: 93 | J_inv = torch.zeros(self.n_envs, 3, 3, device=self.device) 94 | J_inv[:, 0, 0] = 1. / self.J_xy.value 95 | J_inv[:, 1, 1] = 1. / self.J_xy.value 96 | J_inv[:, 2, 2] = 1. / self.J_z.value 97 | return J_inv 98 | 99 | @property 100 | def _D(self) -> Tensor: 101 | D = torch.zeros(self.n_envs, 3, 3, device=self.device) 102 | D[:, 0, 0] = self.D_xy.value 103 | D[:, 1, 1] = self.D_xy.value 104 | D[:, 2, 2] = self.D_z.value 105 | return D 106 | 107 | def detach(self): 108 | super().detach() 109 | self._acc.detach_() 110 | 111 | def dynamics(self, X: torch.Tensor, U: torch.Tensor) -> torch.Tensor: 112 | # Unpacking state and input variables 113 | p, q, v, w = X[..., :3], X[..., 3:7], X[..., 7:10], X[..., 10:13] 114 | # Calculate torques and thrust 115 | # T1, T2, T3, T4 = U[:, 0], U[:, 1], U[:, 2], U[:, 3] 116 | # taux = (T1 + T4 - T2 - T3) * self._arm_l / torch.sqrt(torch.tensor(2.0)) 117 | # tauy = (T1 + T3 - T2 - T4) * self._arm_l / torch.sqrt(torch.tensor(2.0)) 118 | # tauz = (T3 + T4 - T1 - T2) * self._c_tau 119 | # thrust = (T1 + T2 + T3 + T4) 120 | # torque = torch.stack((taux, tauy, tauz), dim=1) 121 | thrust, torque = self.controller(q, w, U) 122 | 123 | M = torque - torch.cross(w, torch.matmul(self._J, w.unsqueeze(-1)).squeeze(-1), dim=-1) 124 | w_dot = torch.matmul(self._J_inv, M.unsqueeze(-1)).squeeze(-1) 125 | 126 | # Drag force 127 | fdrag = quat_rotate(q, (self._D @ quat_rotate(quat_inv(q), v).unsqueeze(-1)).squeeze(-1)) 128 | 129 | # thrust acceleration 130 | thrust_acc = quat_axis(q, 2) * (thrust / self._m.value).unsqueeze(-1) 131 | 132 | # overall acceleration 133 | acc = thrust_acc + self._G_vec - fdrag / self._m.value.unsqueeze(-1) 134 | self._acc = acc 135 | 136 | # quaternion derivative 137 | q_dot = 0.5 * quat_mul(q, torch.cat((w, torch.zeros((q.size(0), 1), device=self.device)), dim=-1)) 138 | 139 | # State derivatives 140 | X_dot = torch.concat([v, q_dot, acc, w_dot], dim=-1) 141 | 142 | return X_dot 143 | 144 | def step(self, U: Tensor) -> None: 145 | new_state = self.solver(self.dynamics, self._state, U, dt=self.dt, M=self.n_substeps) 146 | q_l = torch.norm(new_state[..., 3:7], dim=1, keepdim=True).detach() 147 | new_state[..., 3:7] = new_state[..., 3:7] / q_l 148 | self._state = self.grad_decay(new_state) 149 | 150 | def reset_idx(self, env_idx: Tensor) -> None: 151 | mask = torch.zeros_like(self._acc, dtype=torch.bool) 152 | mask[env_idx] = True 153 | self._acc = torch.where(mask, 0., self._acc) 154 | 155 | @property 156 | def _p(self) -> Tensor: return self._state[:, 0:3] 157 | @property 158 | def _q(self) -> Tensor: return self._state[:, 3:7] 159 | @property 160 | def _v(self) -> Tensor: return self._state[:, 7:10] 161 | @property 162 | def _w(self) -> Tensor: return self._state[:, 10:13] 163 | @property 164 | def _a(self) -> Tensor: return self._acc 165 | -------------------------------------------------------------------------------- /env/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | 6 | from diffaero.env.position_control import PositionControl, Sim2RealPositionControl 7 | from diffaero.env.position_control_multi_agent import MultiAgentPositionControl 8 | from diffaero.env.obstacle_avoidance import ObstacleAvoidance 9 | from diffaero.env.racing import Racing 10 | 11 | ENV_ALIAS = { 12 | "position_control": PositionControl, 13 | "sim2real_position_control": Sim2RealPositionControl, 14 | "multi_agent_position_control": MultiAgentPositionControl, 15 | "obstacle_avoidance": ObstacleAvoidance, 16 | "racing": Racing 17 | } 18 | 19 | def build_env(cfg, device): 20 | # type: (DictConfig, torch.device) -> Union[PositionControl, MultiAgentPositionControl, ObstacleAvoidance, Racing] 21 | env_class = ENV_ALIAS[cfg.name] 22 | return env_class(cfg, device) -------------------------------------------------------------------------------- /env/base_env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple, Dict, Union, Optional 3 | 4 | from omegaconf import DictConfig 5 | import torch 6 | from torch import Tensor 7 | from tensordict import TensorDict 8 | 9 | from diffaero.dynamics import build_dynamics 10 | from diffaero.dynamics.pointmass import point_mass_quat, PointMassModelBase 11 | from diffaero.utils.sensor import IMU 12 | from diffaero.utils.randomizer import RandomizerManager, build_randomizer 13 | from diffaero.utils.render import PositionControlRenderer, ObstacleAvoidanceRenderer 14 | 15 | class BaseEnv(ABC): 16 | def __init__(self, cfg: DictConfig, device: torch.device): 17 | self.randomizer = RandomizerManager(cfg.randomizer) 18 | self.dynamics = build_dynamics(cfg.dynamics, device) 19 | self.dynamic_type: str = self.dynamics.type 20 | self.imu = IMU(cfg.imu, dynamics=self.dynamics) 21 | self.action_dim = self.dynamics.action_dim 22 | self.obs_dim: Union[int, Tuple[int, Tuple[int, int]]] 23 | self.obs_frame: str = cfg.obs_frame 24 | self.state_dim: int 25 | self.n_agents: int = cfg.n_agents 26 | self.dt: float = cfg.dt 27 | self.n_envs: int = cfg.n_envs 28 | self.L = build_randomizer(cfg.length, [self.n_envs], device=device) 29 | if not isinstance(self, BaseEnvMultiAgent): 30 | assert self.n_agents == 1 31 | self.target_pos = torch.zeros(self.n_envs, 3, device=device) 32 | self.init_pos = torch.zeros(self.n_envs, 3, device=device) 33 | self.last_action = torch.zeros(self.n_envs, self.action_dim, device=device) 34 | if self.n_agents > 1: 35 | self.init_pos = torch.zeros(self.n_envs, self.n_agents, 3, device=device) 36 | self.last_action = torch.zeros(self.n_envs, self.n_agents, self.action_dim, device=device) 37 | assert isinstance(self, BaseEnvMultiAgent) 38 | self.progress = torch.zeros(self.n_envs, device=device, dtype=torch.long) 39 | self.arrive_time = torch.zeros(self.n_envs, device=device, dtype=torch.float) 40 | self.max_steps: int = int(cfg.max_time / cfg.dt) 41 | self.wait_before_truncate: float = cfg.wait_before_truncate 42 | self.cfg = cfg 43 | self.loss_weights: DictConfig = cfg.loss_weights 44 | self.reward_weights: DictConfig = cfg.reward_weights 45 | self.device = device 46 | self.max_vel = torch.zeros(self.n_envs, device=device) 47 | self.min_target_vel: float = cfg.min_target_vel 48 | self.max_target_vel: float = cfg.max_target_vel 49 | self.renderer: Optional[Union[PositionControlRenderer, ObstacleAvoidanceRenderer]] 50 | 51 | def check_dims(self): 52 | assert self.obs_frame in ["body", "local", "world"], f"Invalid observation frame: {self.obs_frame}. Must be one of ['body', 'local', 'world']." 53 | assert self.get_observations().size(-1) == self.obs_dim, f"Observation dimension mismatch: {self.get_observations().size(-1)} != {self.obs_dim}" 54 | assert self.get_state().size(-1) == self.state_dim, f"State dimension mismatch: {self.get_state().size(-1)} != {self.state_dim}" 55 | 56 | @abstractmethod 57 | def get_observations(self, with_grad=False): 58 | raise NotImplementedError 59 | 60 | @abstractmethod 61 | def get_state(self, with_grad=False): 62 | raise NotImplementedError 63 | 64 | def detach(self): 65 | self.dynamics.detach() 66 | 67 | @property 68 | def p(self): return self.dynamics.p 69 | @property 70 | def v(self): return self.dynamics.v 71 | @property 72 | def a(self): return self.dynamics.a 73 | @property 74 | def w(self): return self.dynamics.w 75 | @property 76 | def q(self) -> Tensor: 77 | if isinstance(self.dynamics, PointMassModelBase) and self.dynamics.align_yaw_with_target_direction: 78 | return point_mass_quat(self.a, orientation=self.target_vel) 79 | else: 80 | return self.dynamics.q 81 | @property 82 | def _p(self): return self.dynamics._p 83 | @property 84 | def _v(self): return self.dynamics._v 85 | @property 86 | def _a(self): return self.dynamics._a 87 | @property 88 | def _w(self): return self.dynamics._w 89 | @property 90 | def _q(self) -> Tensor: 91 | if isinstance(self.dynamics, PointMassModelBase) and self.dynamics.align_yaw_with_target_direction: 92 | return point_mass_quat(self._a, orientation=self.target_vel) 93 | else: 94 | return self.dynamics._q 95 | 96 | @property 97 | def target_vel(self): 98 | target_relpos = self.target_pos - self.p 99 | target_dist = target_relpos.norm(dim=-1) # [n_envs] 100 | return target_relpos / torch.max(target_dist / self.max_vel, torch.ones_like(target_dist)).unsqueeze(-1) 101 | 102 | def _step(self, action: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: 103 | """Common step logic for single agent environments.""" 104 | # simulation step 105 | self.dynamics.step(action) 106 | self.imu.step() 107 | # termination and truncation logic 108 | terminated, truncated = self.terminated(), self.truncated() 109 | self.progress += 1 110 | if self.renderer is not None: 111 | self.renderer.render(self.states_for_render()) 112 | # truncate if `reset_all` is commanded by the user from GUI 113 | truncated = torch.full_like(truncated, self.renderer.gui_states["reset_all"]) | truncated 114 | # arrival flag denoting if the agent has reached the target position 115 | arrived = (self.p - self.target_pos).norm(dim=-1) < 0.5 116 | curr_time = self.progress.float() * self.dt 117 | # time that the agents approached the target positions for the first time 118 | self.arrive_time.copy_(torch.where(arrived & (self.arrive_time == 0), curr_time, self.arrive_time)) 119 | # truncate if the agents have been at the target positions for a while 120 | truncated |= arrived & (curr_time > (self.arrive_time + self.wait_before_truncate)) 121 | # average velocity of the agents 122 | avg_vel = (self.init_pos - self.target_pos).norm(dim=-1) / self.arrive_time 123 | # success flag denoting whether the agent has reached the target position at the end of the episode 124 | success = arrived & truncated 125 | # update last action 126 | self.last_action.copy_(action.detach()) 127 | return terminated, truncated, success, avg_vel 128 | 129 | @abstractmethod 130 | def states_for_render(self): 131 | # type: () -> Dict[str, Tensor] 132 | raise NotImplementedError 133 | 134 | @abstractmethod 135 | def loss_and_reward(self, action): 136 | # type: (Tensor) -> Tuple[Tensor, Tensor, Dict[str, float]] 137 | raise NotImplementedError 138 | 139 | @abstractmethod 140 | def reset_idx(self, env_idx: Tensor): 141 | raise NotImplementedError 142 | 143 | def reset(self): 144 | self.reset_idx(torch.arange(self.n_envs, device=self.device)) 145 | return self.get_observations() 146 | 147 | @abstractmethod 148 | def terminated(self) -> Tensor: 149 | raise NotImplementedError 150 | 151 | def truncated(self) -> Tensor: 152 | return self.progress >= self.max_steps 153 | 154 | def rescale_action(self, action: Tensor) -> Tensor: 155 | return self.dynamics.min_action + (self.dynamics.max_action - self.dynamics.min_action) * (action + 1) / 2 156 | 157 | def world2body(self, vec_w: Tensor) -> Tensor: 158 | return self.dynamics.world2body(vec_w) 159 | 160 | def body2world(self, vec_b: Tensor) -> Tensor: 161 | return self.dynamics.body2world(vec_b) 162 | 163 | class BaseEnvMultiAgent(BaseEnv, ABC): 164 | def __init__(self, cfg: DictConfig, device: torch.device): 165 | super().__init__(cfg, device) 166 | assert self.n_agents > 1 167 | self.target_pos_base = torch.zeros(self.n_envs, self.n_agents, 3, device=device) 168 | self.target_pos_rel = torch.zeros(self.n_envs, self.n_agents, 3, device=device) 169 | 170 | @property 171 | def target_pos(self): 172 | return self.target_pos_base + self.target_pos_rel 173 | 174 | def step(self, action, need_global_state_before_reset=True) -> Tuple[ 175 | Union[Tuple[Tensor, Tensor], Tuple[TensorDict, Tensor]], 176 | Tensor, 177 | Tensor, 178 | Dict[str, Union[Dict[str, Tensor], Dict[str, float], Tensor]] 179 | ]: 180 | raise NotImplementedError 181 | 182 | @property 183 | def target_vel(self): # TODO 184 | # 这里要改成每个环境中的num_agents个飞机分别以距离自身最近的target_pos为目标计算相对的target_relpos: 185 | target_relpos = self.target_pos - self.p 186 | # target_relpos = self.multidrone_targetpos 187 | target_dist = target_relpos.norm(dim=-1) # [n_envs, n_agents] 188 | return target_relpos / torch.max(target_dist / self.max_vel.unsqueeze(-1), torch.ones_like(target_dist)).unsqueeze(-1) 189 | 190 | @property 191 | def allocated_target_pos(self): 192 | """Allocate a target for each agent.""" 193 | # 计算每架飞机相对于每个目标点的距离并找出最近的目标点的索引 194 | distance = torch.norm(self.p[:, :, None] - self.target_pos[:, None, :], dim=-1) # [n_envs, n_agents, n_agents] 195 | closest_target_indices = torch.min(distance, dim=-1).indices # [n_envs, n_agents] 196 | # 使用gather方法获取最接近的目标位置 197 | closest_targets = self.target_pos.gather(dim=1, index=closest_target_indices.unsqueeze(-1).expand_as(self.target_pos)) # [n_envs, n_agents, 3] 198 | # 计算每个无人机到其最接近目标点的相对位置向量 199 | return closest_targets # [n_envs, n_agents, 3] 200 | 201 | @property 202 | def internal_min_distance(self) -> Tensor: 203 | # 计算每个环境中无人机之间的距离 204 | distances = torch.norm(self._p[:, :, None, :] - self._p[:, None, :, :], dim=-1) # [n_envs, n_agents, n_agents] 205 | # 去除对角线元素 206 | diag = torch.diag(torch.ones(self.n_agents, device=self.device)).unsqueeze(0).expand(self.n_envs, -1, -1) 207 | distances = torch.where(diag.bool(), float('inf'), distances) 208 | # 找到每个agent的最小距离 209 | min_distances = distances.min(dim=-1).values 210 | return min_distances 211 | 212 | @abstractmethod 213 | def get_observations(self, with_grad=False): 214 | raise NotImplementedError 215 | 216 | @abstractmethod 217 | def get_state(self, with_grad=False): 218 | raise NotImplementedError 219 | 220 | def get_obs_and_state(self, with_grad=False): 221 | return self.get_observations(with_grad), self.get_state(with_grad) 222 | 223 | def reset(self): 224 | self.reset_idx(torch.arange(self.n_envs, device=self.device)) 225 | return self.get_obs_and_state() -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Dict, Tuple, List, Optional 2 | 3 | from omegaconf import DictConfig 4 | import torch.nn as nn 5 | 6 | from .networks import MLP, CNN, RNN, RCNN, build_network 7 | from .agents import ( 8 | AgentBase, 9 | DeterministicActor, 10 | StochasticActor, 11 | CriticQ, 12 | CriticV, 13 | ActorCriticBase, 14 | StochasticActorCriticQ, 15 | StochasticActorCriticV 16 | ) -------------------------------------------------------------------------------- /network/multiagents.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Union, Optional, List 2 | import os 3 | 4 | from omegaconf import DictConfig, OmegaConf 5 | import torch 6 | from torch import Tensor 7 | import torch.nn as nn 8 | from tensordict import TensorDict 9 | 10 | from .networks import build_network 11 | from .agents import StochasticActor 12 | 13 | class MAAgentBase(nn.Module): 14 | def __init__( 15 | self, 16 | cfg: DictConfig, 17 | obs_dim: Union[int, Tuple[int, Tuple[int, int]]], 18 | global_state_dim: int 19 | ): 20 | super().__init__() 21 | self.obs_dim = obs_dim 22 | self.global_state_dim = global_state_dim 23 | self.is_rnn_based = cfg.name.lower() == "rnn" or cfg.name.lower() == "rcnn" 24 | 25 | 26 | class MACriticV(MAAgentBase): 27 | def __init__( 28 | self, 29 | cfg: DictConfig, 30 | obs_dim: Union[int, Tuple[int, Tuple[int, int]]], 31 | global_state_dim: int 32 | ): 33 | super().__init__(cfg, obs_dim, global_state_dim) 34 | self.critic = build_network(cfg, global_state_dim, 1) 35 | 36 | def forward(self, global_state: Tensor, hidden: Optional[Tensor] = None) -> Tensor: 37 | return self.critic(global_state, hidden=hidden).squeeze(-1) 38 | 39 | def save(self, path: str): 40 | torch.save(self.critic.state_dict(), os.path.join(path, "critic.pth")) 41 | 42 | def load(self, path: str): 43 | self.critic.load_state_dict(torch.load(os.path.join(path, "critic.pth"), weights_only=True)) 44 | 45 | def reset(self, indices: Tensor): 46 | self.critic.reset(indices) 47 | 48 | def detach(self): 49 | self.critic.detach() 50 | 51 | 52 | class MAStochasticActorCriticV(MAAgentBase): 53 | def __init__( 54 | self, 55 | cfg: DictConfig, 56 | obs_dim: Union[int, Tuple[int, Tuple[int, int]]], 57 | global_state_dim: int, 58 | action_dim: int 59 | ): 60 | super().__init__(cfg, obs_dim, global_state_dim) 61 | self.critic = MACriticV(cfg, obs_dim, global_state_dim) 62 | self.actor = StochasticActor(cfg, obs_dim, action_dim) 63 | 64 | def get_value(self, global_state: Union[Tensor, Tuple[Tensor, Tensor]], hidden: Optional[Tensor] = None) -> Tensor: 65 | return self.critic(global_state, hidden=hidden) 66 | 67 | def get_action(self, obs, sample=None, test=False, hidden=None): 68 | # type: (Union[Tensor, Tuple[Tensor, Tensor]], Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor] 69 | return self.actor(obs, sample, test, hidden=hidden) 70 | 71 | def get_action_and_value(self, obs, global_state, sample=None, test=False): 72 | # type: (Union[Tensor, Tuple[Tensor, Tensor]], Tensor, Optional[Tensor], bool) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor] 73 | return *self.get_action(obs, sample=sample, test=test), self.get_value(global_state) 74 | 75 | def save(self, path: str): 76 | self.actor.save(path) 77 | self.critic.save(path) 78 | 79 | def load(self, path: str): 80 | self.actor.load(path) 81 | self.critic.load(path) 82 | 83 | def reset(self, indices: Tensor): 84 | self.actor.reset(indices) 85 | self.critic.reset(indices) 86 | 87 | def detach(self): 88 | self.actor.detach() 89 | self.critic.detach() -------------------------------------------------------------------------------- /network/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Union, Optional, List 2 | from math import ceil 3 | 4 | from omegaconf import DictConfig 5 | import torch 6 | from torch import Tensor 7 | import torch.nn as nn 8 | from torchvision.models.resnet import conv3x3, conv1x1 9 | 10 | from diffaero.utils.nn import mlp 11 | 12 | def obs_action_concat(state: Union[Tensor, Tuple[Tensor, Tensor]], action: Optional[Tensor] = None) -> Tensor: 13 | if isinstance(state, Tensor): 14 | return torch.cat([state, action], dim=-1) if action is not None else state 15 | else: 16 | return torch.cat([state[0], state[1].flatten(-2)] + ([] if action is None else [action]), dim=-1) 17 | 18 | class BaseNetwork(nn.Module): 19 | def __init__( 20 | self, 21 | input_dim: Union[int, Tuple[int, Tuple[int, int]]], 22 | rnn_n_layers: int = 0, 23 | rnn_hidden_dim: int = 0 24 | ): 25 | super().__init__() 26 | self.input_dim = input_dim 27 | self.rnn_n_layers = rnn_n_layers 28 | self.rnn_hidden_dim = rnn_hidden_dim 29 | self.hidden_state: Optional[Tensor] = None 30 | 31 | def reset(self, indices: Tensor) -> None: 32 | pass 33 | 34 | def detach(self) -> None: 35 | pass 36 | 37 | class MLP(BaseNetwork): 38 | def __init__( 39 | self, 40 | cfg: DictConfig, 41 | input_dim: Union[int, Tuple[int, Tuple[int, int]]], 42 | output_dim: int, 43 | output_act: Optional[nn.Module] = None 44 | ): 45 | super().__init__(input_dim) 46 | if not isinstance(input_dim, int): 47 | D, (H, W) = input_dim 48 | input_dim = D + H * W 49 | self.head = mlp(input_dim, cfg.hidden_dim, output_dim, output_act=output_act) 50 | 51 | def forward( 52 | self, 53 | obs: Union[Tensor, Tuple[Tensor, Tensor]], # [N, D_state] or ([N, D_state], [N, H, W]) 54 | action: Optional[Tensor] = None, # [N, D_action] 55 | hidden: Optional[Tensor] = None 56 | ) -> Tensor: 57 | return self.head(obs_action_concat(obs, action)) 58 | 59 | def forward_export( 60 | self, 61 | obs: Union[Tensor, Tuple[Tensor, Tensor]], # [N, D_obs] 62 | action: Optional[Tensor] = None, # [N, D_action] 63 | ) -> Tensor: 64 | return self.forward(obs=obs, action=action) 65 | 66 | class BasicBlock(nn.Module): 67 | def __init__( 68 | self, 69 | in_channels: int, 70 | out_channels: int, 71 | stride: int = 1, 72 | ) -> None: 73 | super().__init__() 74 | self.conv1 = conv3x3(in_channels, out_channels, stride) 75 | self.act = nn.ELU() 76 | self.conv2 = conv3x3(out_channels, out_channels) 77 | if stride > 1 or in_channels != out_channels: 78 | self.skip_conn = conv1x1(in_channels, out_channels, stride) 79 | else: 80 | self.skip_conn = nn.Identity() 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | identity = x 84 | out = self.conv1(x) 85 | out = self.act(out) 86 | out = self.conv2(out) 87 | identity = self.skip_conn(x) 88 | out += identity 89 | out = self.act(out) 90 | return out 91 | 92 | class CNNBackbone(nn.Sequential): 93 | def __init__(self, cnn_layers: List[Tuple[int, int, int]], input_dim: Tuple[int, Tuple[int, int]]): 94 | D, (H, W) = input_dim 95 | layers: List[nn.Module] = [] 96 | ds_rate = 1 97 | assert len(cnn_layers) > 0, "CNNBackbone must have at least one layer." 98 | for layer in cnn_layers: 99 | in_channels, out_channels, stride = layer 100 | ds_rate *= stride 101 | layers.append(BasicBlock(in_channels, out_channels, stride=stride)) 102 | layers.append(nn.Flatten(start_dim=-3)) 103 | 104 | if any([H % ds_rate != 0, W % ds_rate != 0]): 105 | hpad = ceil(H / ds_rate) * ds_rate - H 106 | wpad = ceil(W / ds_rate) * ds_rate - W 107 | top, left = hpad // 2, wpad // 2 108 | bottom, right = hpad - top, wpad - left 109 | layers.insert(0, nn.ZeroPad2d((left, right, top, bottom))) 110 | 111 | h_out, w_out = ceil(H / ds_rate), ceil(W / ds_rate) 112 | super().__init__(*layers) 113 | self.out_dim = D + out_channels * h_out * w_out 114 | self.h_out = h_out 115 | self.w_out = w_out 116 | 117 | class CNN(BaseNetwork): 118 | def __init__( 119 | self, 120 | cfg: DictConfig, 121 | input_dim: Tuple[int, Tuple[int, int]], 122 | output_dim: int, 123 | output_act: Optional[nn.Module] = None 124 | ): 125 | super().__init__(input_dim) 126 | self.cnn = CNNBackbone(cfg.cnn_layers, input_dim) 127 | self.head = mlp(self.cnn.out_dim, cfg.hidden_dim, output_dim, output_act=output_act) 128 | 129 | def forward( 130 | self, 131 | obs: Tuple[Tensor, Tensor], # ([N, D_state], [N, H, W]) 132 | action: Optional[Tensor] = None, # [N, D_action] 133 | hidden: Optional[Tensor] = None 134 | ) -> Tensor: 135 | perception = obs[1] 136 | if perception.ndim == 3: 137 | perception = perception.unsqueeze(1) 138 | input = [obs[0], self.cnn(perception)] + ([] if action is None else [action]) 139 | return self.head(torch.cat(input, dim=-1)) 140 | 141 | def forward_export( 142 | self, 143 | state: Tensor, # [N, D_state] 144 | perception: Tensor, # [N, H, W] 145 | action: Optional[Tensor] = None, # [N, D_action] 146 | ) -> Tensor: 147 | if perception.ndim == 3: 148 | perception = perception.unsqueeze(1) 149 | input = [state, self.cnn(perception)] + ([] if action is None else [action]) 150 | return self.head(torch.cat(input, dim=-1)) 151 | 152 | 153 | class RNN(BaseNetwork): 154 | def __init__( 155 | self, 156 | cfg: DictConfig, 157 | input_dim: Union[int, Tuple[int, Tuple[int, int]]], 158 | output_dim: int, 159 | output_act: Optional[nn.Module] = None 160 | ): 161 | super().__init__(input_dim, cfg.rnn_n_layers, cfg.rnn_hidden_dim) 162 | if not isinstance(input_dim, int): 163 | D, (H, W) = input_dim 164 | input_dim = D + H * W 165 | self.gru = torch.nn.GRU( 166 | input_size=input_dim, 167 | hidden_size=self.rnn_hidden_dim, 168 | num_layers=self.rnn_n_layers, 169 | bias=True, 170 | batch_first=True, 171 | dropout=0.0, 172 | bidirectional=False, 173 | dtype=torch.float 174 | ) 175 | self.head = mlp(self.rnn_hidden_dim, cfg.hidden_dim, output_dim, output_act=output_act) 176 | 177 | def forward( 178 | self, 179 | obs: Union[Tensor, Tuple[Tensor, Tensor]], # [N, D_state] or ([N, D_state], [N, H, W]) 180 | action: Optional[Tensor] = None, # [N, D_action] 181 | hidden: Optional[Tensor] = None, # [n_layers, N, D_hidden] 182 | ) -> Tensor: 183 | # self.gru.flatten_parameters() 184 | rnn_input = obs_action_concat(obs, action) 185 | 186 | use_own_hidden = hidden is None 187 | if use_own_hidden: 188 | if self.hidden_state is None: 189 | hidden = torch.zeros(self.rnn_n_layers, rnn_input.size(0), self.rnn_hidden_dim, dtype=rnn_input.dtype, device=rnn_input.device) 190 | else: 191 | hidden = self.hidden_state 192 | 193 | rnn_out, hidden_out = self.gru(rnn_input.unsqueeze(1), hidden) 194 | if use_own_hidden: 195 | self.hidden_state = hidden_out 196 | return self.head(rnn_out.squeeze(1)) 197 | 198 | def forward_export( 199 | self, 200 | obs: Union[Tensor, Tuple[Tensor, Tensor]], # [N, D_state] or ([N, D_state], [N, H, W]) 201 | hidden: Tensor, # [n_layers, N, D_hidden] 202 | action: Optional[Tensor] = None, # [N, D_action] 203 | ) -> Tuple[Tensor, Tensor]: 204 | rnn_input = obs_action_concat(obs, action) 205 | rnn_out, hidden = self.gru(rnn_input.unsqueeze(1), hidden) 206 | return self.head(rnn_out.squeeze(1)), hidden 207 | 208 | def reset(self, indices: Tensor): 209 | if self.hidden_state is not None: 210 | self.hidden_state[:, indices, :] = 0 211 | 212 | def detach(self): 213 | if self.hidden_state is not None: 214 | self.hidden_state.detach_() 215 | 216 | 217 | class RCNN(BaseNetwork): 218 | def __init__( 219 | self, 220 | cfg: DictConfig, 221 | input_dim: Tuple[int, Tuple[int, int]], 222 | output_dim: int, 223 | output_act: Optional[nn.Module] = None 224 | ): 225 | super().__init__(input_dim, cfg.rnn_n_layers, cfg.rnn_hidden_dim) 226 | self.cnn = CNNBackbone(cfg.cnn_layers, input_dim) 227 | self.gru = torch.nn.GRU( 228 | input_size=self.cnn.out_dim, 229 | hidden_size=self.rnn_hidden_dim, 230 | num_layers=self.rnn_n_layers, 231 | bias=True, 232 | batch_first=True, 233 | dropout=0.0, 234 | bidirectional=False, 235 | dtype=torch.float 236 | ) 237 | self.head = mlp(self.rnn_hidden_dim, cfg.hidden_dim, output_dim, output_act=output_act) 238 | 239 | def forward( 240 | self, 241 | obs: Tuple[Tensor, Tensor], # ([N, D_state], [N, H, W]) 242 | action: Optional[Tensor] = None, # [N, D_action] 243 | hidden: Optional[Tensor] = None, # [n_layers, N, D_hidden] 244 | ) -> Tensor: 245 | # self.gru.flatten_parameters() 246 | 247 | perception = obs[1] 248 | if perception.ndim == 3: 249 | perception = perception.unsqueeze(1) 250 | rnn_input = torch.cat([obs[0], self.cnn(perception)] + ([] if action is None else [action]), dim=-1) 251 | 252 | use_own_hidden = hidden is None 253 | if use_own_hidden: 254 | if self.hidden_state is None: 255 | hidden = torch.zeros(self.rnn_n_layers, rnn_input.size(0), self.rnn_hidden_dim, dtype=rnn_input.dtype, device=rnn_input.device) 256 | else: 257 | hidden = self.hidden_state 258 | 259 | rnn_out, hidden_out = self.gru(rnn_input.unsqueeze(1), hidden) 260 | if use_own_hidden: 261 | self.hidden_state = hidden_out 262 | return self.head(rnn_out.squeeze(1)) 263 | 264 | def forward_export( 265 | self, 266 | state: Tensor, # [N, D_state] 267 | perception: Tensor, # [N, H, W] 268 | hidden: Tensor, # [n_layers, N, D_hidden] 269 | action: Optional[Tensor] = None, # [N, D_action] 270 | ) -> Tuple[Tensor, Tensor]: 271 | if perception.ndim == 3: 272 | perception = perception.unsqueeze(1) 273 | rnn_input = torch.cat([state, self.cnn(perception)] + ([] if action is None else [action]), dim=-1) 274 | rnn_out, hidden = self.gru(rnn_input.unsqueeze(1), hidden) 275 | return self.head(rnn_out.squeeze(1)), hidden 276 | 277 | def reset(self, indices: Tensor): 278 | if self.hidden_state is not None: 279 | self.hidden_state[:, indices, :] = 0 280 | 281 | def detach(self): 282 | if self.hidden_state is not None: 283 | self.hidden_state.detach_() 284 | 285 | 286 | BACKBONE_ALIAS: Dict[str, Union[type[MLP], type[CNN], type[RNN], type[RCNN]]] = { 287 | "mlp": MLP, 288 | "cnn": CNN, 289 | "rnn": RNN, 290 | "rcnn": RCNN 291 | } 292 | 293 | def build_network( 294 | cfg: DictConfig, 295 | input_dim: Union[int, Tuple[int, Tuple[int, int]]], 296 | output_dim: int, 297 | output_act: Optional[nn.Module] = None 298 | ) -> Union[MLP, CNN, RNN, RCNN]: 299 | return BACKBONE_ALIAS[cfg.name](cfg, input_dim, output_dim, output_act) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 64.0.0", "wheel"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.0 2 | tensordict 3 | taichi 4 | tqdm 5 | hydra-core 6 | hydra-joblib-launcher 7 | hydra_colorlog 8 | hydra-optuna-sweeper 9 | welford_torch 10 | line_profiler 11 | tensorboard 12 | tensorboardX 13 | torch-tb-profiler 14 | wandb 15 | gpustat 16 | opencv-python 17 | pytorch3d 18 | open3d 19 | numpy 20 | moviepy==1.0.3 21 | imageio 22 | imageio-ffmpeg 23 | matplotlib 24 | onnx 25 | onnxruntime -------------------------------------------------------------------------------- /script/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import main as train 2 | from .test import main as test 3 | from .export import main as export -------------------------------------------------------------------------------- /script/export.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | from pathlib import Path 4 | 5 | import torch 6 | import hydra 7 | from omegaconf import DictConfig, OmegaConf 8 | 9 | from diffaero.env import build_env 10 | from diffaero.algo import build_agent 11 | 12 | @hydra.main(config_path=str(Path(__file__).parent.parent.joinpath("cfg")), config_name="config_test", version_base="1.3") 13 | def main(cfg: DictConfig): 14 | print(f"Using device cpu.") 15 | device = torch.device("cpu") 16 | 17 | assert cfg.checkpoint is not None 18 | ckpt_path = Path(cfg.checkpoint).resolve() 19 | cfg_path = ckpt_path.parent.joinpath(".hydra", "config.yaml") 20 | ckpt_cfg = OmegaConf.load(cfg_path) 21 | cfg.algo = ckpt_cfg.algo 22 | # cfg.dynamics = ckpt_cfg.dynamics 23 | if cfg.algo.name != 'world': 24 | cfg.network = ckpt_cfg.network 25 | ckpt_cfg.env.render.headless = True 26 | cfg.dynamics = ckpt_cfg.dynamics 27 | cfg.sensor = ckpt_cfg.sensor 28 | cfg.env.n_envs = cfg.n_envs = 1 29 | ckpt_cfg.env.max_target_vel = cfg.env.max_target_vel 30 | ckpt_cfg.env.min_target_vel = cfg.env.min_target_vel 31 | ckpt_cfg.env.n_envs = cfg.env.n_envs 32 | cfg.env = ckpt_cfg.env 33 | 34 | env = build_env(cfg.env, device=device) 35 | agent = build_agent(cfg.algo, env, device) 36 | agent.load(ckpt_path) 37 | assert any(dict(cfg.export).values()) 38 | agent.export( 39 | path=ckpt_path, 40 | export_cfg=cfg.export, 41 | verbose=True 42 | ) 43 | 44 | if __name__ == "__main__": 45 | main() -------------------------------------------------------------------------------- /script/fps_test.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | sys.path.append('..') 4 | from pathlib import Path 5 | 6 | import hydra 7 | from tqdm import tqdm 8 | from omegaconf import DictConfig, OmegaConf 9 | 10 | @hydra.main(config_path=str(Path(__file__).parent.parent.joinpath("cfg")), config_name="config_test", version_base="1.3") 11 | def main(cfg: DictConfig): 12 | 13 | import torch 14 | import numpy as np 15 | 16 | from diffaero.env import build_env 17 | from gpustat import new_query as gpu_query 18 | 19 | device_idx = cfg.device 20 | device = f"cuda:{device_idx}" if torch.cuda.is_available() and device_idx != -1 else "cpu" 21 | print(f"Using device {device}.") 22 | device = torch.device(device) 23 | 24 | if cfg.seed != -1: 25 | random.seed(cfg.seed) 26 | np.random.seed(cfg.seed) 27 | torch.manual_seed(cfg.seed) 28 | torch.backends.cudnn.deterministic = cfg.torch_deterministic 29 | 30 | env = build_env(cfg.env, device=device) 31 | 32 | pbar = tqdm(range(cfg.n_steps)) 33 | try: 34 | with torch.no_grad(): 35 | obs = env.reset() 36 | start = pbar._time() 37 | for i in pbar: 38 | action = torch.zeros(env.n_envs, env.action_dim, device=device) 39 | env.step(action) 40 | pbar.set_postfix({"FPS": f"{int(cfg.n_envs * pbar.n / (pbar._time() - start)):,d}"}) 41 | except KeyboardInterrupt: 42 | print("Interrupted.") 43 | finally: 44 | end = pbar._time() 45 | fps = int(cfg.n_envs * pbar.n / (end - start)) 46 | processes = [] 47 | for gpu in gpu_query(): 48 | processes.extend(gpu.processes) 49 | for process in processes: 50 | if process["command"].startswith("python") and "script/fps_test.py" in process["full_command"]: 51 | vram = process["gpu_memory_usage"] 52 | break 53 | print("Overrides: ", " ".join(hydra.core.hydra_config.HydraConfig.get().overrides.task)) 54 | print(f"GPU Memory Usage: {vram} MiB, FPS: {fps:,d}") 55 | 56 | if __name__ == "__main__": 57 | main() -------------------------------------------------------------------------------- /script/test.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | sys.path.append('..') 4 | from pathlib import Path 5 | 6 | import hydra 7 | from omegaconf import DictConfig, OmegaConf 8 | 9 | @hydra.main(config_path=str(Path(__file__).parent.parent.joinpath("cfg")), config_name="config_test", version_base="1.3") 10 | def main(cfg: DictConfig): 11 | 12 | import torch 13 | import numpy as np 14 | 15 | from diffaero.env import build_env 16 | from diffaero.algo import build_agent 17 | from diffaero.utils.logger import Logger 18 | from diffaero.utils.runner import TestRunner 19 | 20 | logger = Logger(cfg, run_name=cfg.runname) 21 | 22 | device_idx = cfg.device 23 | device = f"cuda:{device_idx}" if torch.cuda.is_available() and device_idx != -1 else "cpu" 24 | Logger.info(f"Using device {device}.") 25 | device = torch.device(device) 26 | 27 | if cfg.seed != -1: 28 | random.seed(cfg.seed) 29 | np.random.seed(cfg.seed) 30 | torch.manual_seed(cfg.seed) 31 | torch.backends.cudnn.deterministic = cfg.torch_deterministic 32 | 33 | ckpt_path = Path(cfg.checkpoint).resolve() 34 | cfg_path = ckpt_path.parent.joinpath(".hydra", "config.yaml") 35 | ckpt_cfg = OmegaConf.load(cfg_path) 36 | cfg.algo = ckpt_cfg.algo 37 | if cfg.algo.name != 'world': 38 | cfg.network = ckpt_cfg.network 39 | else: 40 | cfg.algo.common.is_test = True 41 | if cfg.use_training_cfg: 42 | cfg.dynamics = ckpt_cfg.dynamics 43 | cfg.sensor = ckpt_cfg.sensor 44 | ckpt_cfg.env.max_target_vel = cfg.env.max_target_vel 45 | ckpt_cfg.env.min_target_vel = cfg.env.min_target_vel 46 | ckpt_cfg.env.n_envs = cfg.env.n_envs 47 | cfg.env = ckpt_cfg.env 48 | 49 | env = build_env(cfg.env, device=device) 50 | 51 | agent = build_agent(cfg.algo, env, device) 52 | agent.load(ckpt_path) 53 | 54 | runner = TestRunner(cfg, logger, env, agent) 55 | 56 | try: 57 | runner.run() 58 | except KeyboardInterrupt: 59 | Logger.warning("Interrupted.") 60 | 61 | success_rate = runner.close() 62 | 63 | return success_rate 64 | 65 | if __name__ == "__main__": 66 | main() -------------------------------------------------------------------------------- /script/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | sys.path.append('..') 4 | from pathlib import Path 5 | 6 | import hydra 7 | from omegaconf import DictConfig, OmegaConf 8 | 9 | def allocate_device(cfg: DictConfig): 10 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 11 | multirun = hydra_cfg.mode == hydra.types.RunMode.MULTIRUN 12 | use_multiple_devices = isinstance(cfg.device, str) and len(cfg.device) > 0 13 | multirun_across_devices = multirun and use_multiple_devices 14 | if multirun_across_devices: 15 | available_devices = list(map(int, list(cfg.device))) 16 | n_devices = len(available_devices) 17 | job_id = hydra_cfg.job.num 18 | job_device = available_devices[job_id % n_devices] 19 | else: 20 | job_device = int(cfg.device) if isinstance(cfg.device, int) else 0 21 | return job_device, multirun_across_devices 22 | 23 | @hydra.main(config_path=str(Path(__file__).parent.parent.joinpath("cfg")), config_name="config_train", version_base="1.3") 24 | def main(cfg: DictConfig): 25 | 26 | job_device, multirun_across_devices = allocate_device(cfg) 27 | if multirun_across_devices: 28 | import os 29 | os.environ["CUDA_VISIBLE_DEVICES"] = str(job_device) 30 | cfg.device = 0 31 | 32 | import torch 33 | torch.set_float32_matmul_precision('high') # for better performance 34 | import numpy as np 35 | 36 | from diffaero.env import build_env 37 | from diffaero.algo import build_agent 38 | from diffaero.utils.logger import Logger 39 | from diffaero.utils.runner import TrainRunner 40 | 41 | logger = Logger(cfg, run_name=cfg.runname) 42 | 43 | device = f"cuda:{cfg.device}" if torch.cuda.is_available() and cfg.device != -1 else "cpu" 44 | device_repr = f"cuda:{job_device}" if multirun_across_devices and device != "cpu" else device 45 | Logger.info(f"Using device {device_repr}.") 46 | device = torch.device(device) 47 | 48 | if cfg.seed != -1: 49 | random.seed(cfg.seed) 50 | np.random.seed(cfg.seed) 51 | torch.manual_seed(cfg.seed) 52 | torch.backends.cudnn.deterministic = cfg.torch_deterministic 53 | 54 | if cfg.checkpoint is not None and len(cfg.checkpoint) > 0: 55 | ckpt_path = Path(cfg.checkpoint).resolve() 56 | cfg_path = ckpt_path.parent.joinpath(".hydra", "config.yaml") 57 | ckpt_cfg = OmegaConf.load(cfg_path) 58 | cfg.sensor = ckpt_cfg.sensor 59 | train_from_checkpoint = True 60 | else: 61 | ckpt_path = '' 62 | train_from_checkpoint = False 63 | 64 | env = build_env(cfg.env, device=device) 65 | 66 | agent = build_agent(cfg.algo, env, device) 67 | if train_from_checkpoint: 68 | agent.load(ckpt_path) 69 | 70 | runner = TrainRunner(cfg, logger, env, agent) 71 | 72 | try: 73 | runner.run() 74 | except KeyboardInterrupt: 75 | Logger.warning("Interrupted.") 76 | 77 | max_success_rate = runner.close() 78 | 79 | return max_success_rate 80 | 81 | if __name__ == "__main__": 82 | main() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='diffaero', 5 | version='0.1', 6 | # packages=find_packages(), 7 | packages=[".", "env", "algo", "network", "utils", "script"], 8 | install_requires=[ 9 | 'torch>=2.0.0', 10 | 'tensordict', 11 | 'taichi', 12 | 'tqdm', 13 | 'hydra-core', 14 | 'hydra-joblib-launcher', 15 | 'hydra_colorlog', 16 | 'hydra-optuna-sweeper', 17 | 'welford_torch', 18 | 'line_profiler', 19 | 'tensorboard', 20 | 'tensorboardX', 21 | 'torch-tb-profiler', 22 | 'wandb', 23 | 'gpustat', 24 | 'opencv-python', 25 | 'pytorch3d@git+https://github.com/facebookresearch/pytorch3d.git@stable#egg=pytorch3d', 26 | 'open3d', 27 | 'numpy', 28 | 'moviepy==1.0.3', 29 | 'imageio', 30 | 'imageio-ffmpeg', 31 | 'matplotlib', 32 | 'onnx', 33 | 'onnxruntime' 34 | ], 35 | author='Xinhong Zhang', 36 | author_email='xhzhang@bit.edu.cn', 37 | description='', 38 | url='https://github.com/flyingbitac/diffaero' 39 | ) 40 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flyingbitac/diffaero/4034ed7e23d38c2eb0084120e0a0fdae85452001/utils/__init__.py -------------------------------------------------------------------------------- /utils/exporter.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Union, Optional, List 2 | from copy import deepcopy 3 | import os 4 | 5 | import torch 6 | from torch import Tensor 7 | import torch.nn as nn 8 | from omegaconf import DictConfig 9 | import onnxruntime as ort 10 | import numpy as np 11 | 12 | from diffaero.network.agents import StochasticActor, DeterministicActor 13 | from diffaero.network.networks import MLP, CNN, RNN, RCNN 14 | from diffaero.dynamics.pointmass import point_mass_quat 15 | from diffaero.utils.logger import Logger 16 | 17 | class PolicyExporter(nn.Module): 18 | def __init__(self, actor: Union[StochasticActor, DeterministicActor]): 19 | super().__init__() 20 | self.is_stochastic = isinstance(actor, StochasticActor) 21 | self.is_recurrent = actor.is_rnn_based 22 | actor_net = actor.actor_mean if self.is_stochastic else actor.actor 23 | if self.is_recurrent: 24 | actor_net.hidden_state = torch.empty(0) 25 | self.hidden_shape = (actor_net.rnn_n_layers, 1, actor_net.rnn_hidden_dim) 26 | self.actor = deepcopy(actor_net).cpu() 27 | if isinstance(self.actor, MLP): 28 | self.forward = self.forward_MLP 29 | elif isinstance(self.actor, CNN): 30 | self.forward = self.forward_CNN 31 | elif isinstance(self.actor, RNN): 32 | self.forward = self.forward_RNN 33 | elif isinstance(self.actor, RCNN): 34 | self.forward = self.forward_RCNN 35 | 36 | self.input_dim = self.actor.input_dim 37 | state_dim = self.input_dim[0] if isinstance(self.input_dim, tuple) else self.input_dim 38 | perception_dim = self.input_dim[1] if isinstance(self.input_dim, tuple) else None 39 | self.named_inputs = [ 40 | ("state", torch.zeros(1, state_dim)), 41 | ("orientation", torch.zeros(1, 3)), 42 | ("Rz", torch.zeros(1, 3, 3)), 43 | ("min_action", torch.zeros(1, 3)), 44 | ("max_action", torch.zeros(1, 3)), 45 | ] 46 | if perception_dim is not None: 47 | if isinstance(self.actor, (MLP, RNN)): 48 | self.named_inputs[0] = ("state", (torch.zeros(1, state_dim), torch.zeros(1, perception_dim[0], perception_dim[1]))) 49 | elif isinstance(self.actor, (CNN, RCNN)): 50 | self.named_inputs.insert(1, ("perception", torch.zeros(1, perception_dim[0], perception_dim[1]))) 51 | self.output_names = [ 52 | "action", 53 | "quat_xyzw_cmd", 54 | "acc_norm" 55 | ] 56 | if self.is_recurrent: 57 | self.named_inputs.append(("hidden_in", torch.zeros(self.hidden_shape))) 58 | self.output_names.append("hidden_out") 59 | 60 | self.obs_frame: str 61 | self.action_frame: str 62 | 63 | def post_process_local(self, raw_action, min_action, max_action, orientation, Rz, is_stochastic): 64 | # type: (Tensor, Tensor, Tensor, Tensor, Tensor, bool) -> Tuple[Tensor, Tensor, Tensor] 65 | raw_action = raw_action.tanh() if is_stochastic else raw_action 66 | action = (raw_action * 0.5 + 0.5) * (max_action - min_action) + min_action 67 | acc_cmd = torch.matmul(Rz, action.unsqueeze(-1)).squeeze(-1) 68 | quat_xyzw = point_mass_quat(acc_cmd, orientation) 69 | acc_norm = acc_cmd.norm(p=2, dim=-1) 70 | return acc_cmd, quat_xyzw, acc_norm 71 | 72 | def post_process_world(self, raw_action, min_action, max_action, orientation, Rz, is_stochastic): 73 | # type: (Tensor, Tensor, Tensor, Tensor, Tensor, bool) -> Tuple[Tensor, Tensor, Tensor] 74 | raw_action = raw_action.tanh() if is_stochastic else raw_action 75 | action = (raw_action * 0.5 + 0.5) * (max_action - min_action) + min_action 76 | quat_xyzw = point_mass_quat(action, orientation) 77 | acc_norm = action.norm(p=2, dim=-1) 78 | return action, quat_xyzw, acc_norm 79 | 80 | def post_process(self, raw_action, min_action, max_action, orientation, Rz): 81 | # type: (Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor] 82 | if self.action_frame == "local": 83 | return self.post_process_local(raw_action, min_action, max_action, orientation, Rz, self.is_stochastic) 84 | elif self.action_frame == "world": 85 | return self.post_process_world(raw_action, min_action, max_action, orientation, Rz, self.is_stochastic) 86 | else: 87 | raise ValueError(f"Unknown action frame: {self.action_frame}") 88 | 89 | def forward_MLP(self, state, orientation, Rz, min_action, max_action): 90 | # type: (Union[Tensor, Tuple[Tensor, Tensor]], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor] 91 | raw_action = self.actor.forward_export(state) 92 | action, quat_xyzw, acc_norm = self.post_process(raw_action, min_action, max_action, orientation, Rz) 93 | return action, quat_xyzw, acc_norm 94 | 95 | def forward_CNN(self, state, perception, orientation, Rz, min_action, max_action): 96 | # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor] 97 | raw_action = self.actor.forward_export(state=state, perception=perception) 98 | action, quat_xyzw, acc_norm = self.post_process(raw_action, min_action, max_action, orientation, Rz) 99 | return action, quat_xyzw, acc_norm 100 | 101 | def forward_RNN(self, state, orientation, Rz, min_action, max_action, hidden_in): 102 | # type: (Union[Tensor, Tuple[Tensor, Tensor]], Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor] 103 | raw_action, hidden_out = self.actor.forward_export(state, hidden=hidden_in) 104 | action, quat_xyzw, acc_norm = self.post_process(raw_action, min_action, max_action, orientation, Rz) 105 | return action, quat_xyzw, acc_norm, hidden_out 106 | 107 | def forward_RCNN(self, state, perception, orientation, Rz, min_action, max_action, hidden_in): 108 | # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor] 109 | raw_action, hidden_out = self.actor.forward_export(state=state, perception=perception, hidden=hidden_in) 110 | action, quat_xyzw, acc_norm = self.post_process(raw_action, min_action, max_action, orientation, Rz) 111 | return action, quat_xyzw, acc_norm, hidden_out 112 | 113 | def export( 114 | self, 115 | path: str, 116 | export_cfg: DictConfig, 117 | verbose=False, 118 | ): 119 | self.obs_frame = export_cfg.obs_frame 120 | self.action_frame = export_cfg.action_frame 121 | if export_cfg.jit: 122 | self.export_jit(path, verbose) 123 | if export_cfg.onnx: 124 | self.export_onnx(path) 125 | 126 | @torch.no_grad() 127 | def export_jit(self, path: str, verbose=False): 128 | traced_script_module = torch.jit.script(self) 129 | if verbose: 130 | Logger.info("Code of scripted module: \n" + traced_script_module.code) 131 | export_path = os.path.join(path, "exported_actor.pt2") 132 | traced_script_module.save(export_path) 133 | Logger.info(f"The checkpoint is compiled and exported to {export_path}.") 134 | 135 | @torch.no_grad() 136 | def export_onnx(self, path: str): 137 | export_path = os.path.join(path, "exported_actor.onnx") 138 | names, test_inputs = zip(*self.named_inputs) 139 | torch.onnx.export( 140 | model=self, 141 | args=test_inputs, 142 | f=export_path, 143 | input_names=names, 144 | output_names=self.output_names 145 | ) 146 | Logger.info(f"The checkpoint is compiled and exported to {export_path}.") 147 | 148 | # self.eval() 149 | # ort_session = ort.InferenceSession(export_path) 150 | # verify_inputs = [] 151 | # for input in test_inputs: 152 | # if isinstance(input, Tensor): 153 | # verify_inputs.append(torch.randn_like(input)) 154 | # elif isinstance(input, tuple): 155 | # verify_inputs.append(tuple(torch.randn_like(t) for t in input)) 156 | # ort_inputs = {} 157 | # for name, input in zip(names, verify_inputs): 158 | # if isinstance(input, Tensor): 159 | # ort_inputs[name] = input.numpy() 160 | # elif isinstance(input, tuple): 161 | # ort_inputs[name] = tuple([t.numpy() for t in input]) 162 | # ort_outs: Tuple[np.ndarray, ...] = ort_session.run(None, ort_inputs) # type: ignore 163 | # torch_outs: Tuple[Tensor, ...] = self(*verify_inputs) 164 | # # compare ONNX Runtime and PyTorch results 165 | # for i in range(len(ort_outs)): 166 | # np.testing.assert_allclose(ort_outs[i], torch_outs[i].cpu().numpy(), rtol=1e-03, atol=1e-05, verbose=True) 167 | # Logger.info(f"The onnx model at {export_path} is verified!") -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | import os 3 | import copy 4 | import inspect 5 | import logging 6 | from pathlib import Path 7 | 8 | import hydra 9 | from omegaconf import OmegaConf, DictConfig 10 | from torch.utils.tensorboard.writer import SummaryWriter 11 | from tqdm import tqdm 12 | 13 | from diffaero import DIFFAERO_ROOT_DIR 14 | 15 | class TensorBoardLogger: 16 | def __init__( 17 | self, 18 | cfg: DictConfig, 19 | logdir: str, 20 | run_name: str = "" 21 | ): 22 | self.cfg = cfg 23 | self.logdir = logdir 24 | Logger.info("Using Tensorboard Logger.") 25 | self.writer = SummaryWriter(log_dir=os.path.join(self.logdir, run_name)) 26 | self.log_hparams() 27 | 28 | def log_scalar(self, tag, value, step): 29 | self.writer.add_scalar(tag, value, step) 30 | 31 | def log_scalars(self, value_dict, step): 32 | for k, v in value_dict.items(): 33 | if isinstance(v, dict): 34 | self.log_scalars({k+"/"+k_: v_ for k_, v_ in v.items()}, step) 35 | else: 36 | self.log_scalar(k, v, step) 37 | 38 | def log_histogram(self, tag, values, step): 39 | self.writer.add_histogram(tag, values, step) 40 | 41 | def log_image(self, tag, img, step): 42 | self.writer.add_image(tag, img, step, dataformats='CHW') 43 | 44 | def log_images(self, tag, img, step): 45 | self.writer.add_images(tag, img, step) 46 | 47 | def log_video(self, tag, video, step, fps): 48 | self.writer.add_video(tag, video, step, fps=fps) 49 | 50 | def close(self): 51 | self.writer.close() 52 | 53 | def log_hparams(self): 54 | to_yaml = lambda x: OmegaConf.to_yaml(x, resolve=True).replace(" ", "- ").replace("\n", " \n") 55 | if hasattr(self.cfg.env, "render"): 56 | delattr(self.cfg.env, "render") 57 | self.writer.add_text("Env HParams", to_yaml(self.cfg.env), 0) 58 | self.writer.add_text("Train HParams", to_yaml(self.cfg.algo), 0) 59 | overrides_path = os.path.join(self.logdir, ".hydra", "overrides.yaml") 60 | if os.path.exists(overrides_path): 61 | with open(overrides_path, "r") as f: 62 | overrides = [line.strip('- ') for line in f.readlines()] 63 | self.writer.add_text("Overrides", ' '.join(overrides), 0) 64 | 65 | 66 | class WandBLogger: 67 | def __init__( 68 | self, 69 | cfg: DictConfig, 70 | logdir: str, 71 | run_name: str = "" 72 | ): 73 | self.cfg = cfg 74 | self.logdir = logdir 75 | Logger.info("Using WandB Logger.") 76 | 77 | overrides_path = os.path.join(self.logdir, ".hydra", "overrides.yaml") 78 | if os.path.exists(overrides_path): 79 | with open(overrides_path, "r") as f: 80 | overrides = " ".join([line.strip('- ') for line in f.readlines()]) 81 | import wandb 82 | wandb.init( 83 | project=cfg.logger.project, 84 | entity=cfg.logger.entity, 85 | dir=self.logdir, 86 | sync_tensorboard=False, 87 | config={**dict(cfg), "overrides": overrides}, # type: ignore 88 | name=run_name, 89 | settings=wandb.Settings( 90 | quiet=cfg.logger.quiet 91 | ) 92 | ) 93 | self.writer = wandb 94 | 95 | def log_scalar(self, tag, value, step): 96 | self.writer.log({tag: value}, step=step) 97 | 98 | def log_scalars(self, value_dict, step): 99 | for k, v in value_dict.items(): 100 | if isinstance(v, dict): 101 | self.log_scalars({k+"/"+k_: v_ for k_, v_ in v.items()}, step) 102 | else: 103 | self.log_scalar(k, v, step) 104 | 105 | def log_histogram(self, tag, values, step): 106 | self.writer.log({tag: values}, step=step) 107 | 108 | def log_image(self, tag, img, step): 109 | self.writer.log({tag: img}, step=step) 110 | 111 | def log_images(self, tag, img, step): 112 | self.writer.log({tag: img}, step=step) 113 | 114 | def log_video(self, tag, video, step, fps): 115 | self.writer.log({tag: video}, step=step) 116 | 117 | def close(self): 118 | self.writer.finish() 119 | 120 | 121 | def msg2str(*msgs): 122 | return " ".join([str(msg) for msg in msgs]) 123 | 124 | class Logger: 125 | logging = logging.getLogger() 126 | def __init__( 127 | self, 128 | cfg: DictConfig, 129 | run_name: str = "" 130 | ): 131 | logger_alias = { 132 | "tensorboard": TensorBoardLogger, 133 | "wandb": WandBLogger 134 | } 135 | self.cfg = copy.deepcopy(cfg) 136 | assert str(cfg.log_level).upper() in logging._nameToLevel.keys() 137 | Logger.logging.setLevel(logging._nameToLevel[str(cfg.log_level).upper()]) 138 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() # type: ignore 139 | 140 | self.logdir = hydra_cfg.runtime.output_dir 141 | run_names = ( 142 | [ 143 | cfg.dynamics.abbr, 144 | cfg.env.abbr, 145 | cfg.algo.name 146 | ] + 147 | ([cfg.algo.network.name] if hasattr(cfg.algo, "network") and hasattr(cfg.algo.network, "name") else []) + 148 | ([run_name] if len(run_name) > 0 else []) + 149 | [ 150 | str(cfg.seed) 151 | ] 152 | ) 153 | type = cfg.logger.name.lower() 154 | self._logger: Union[TensorBoardLogger, WandBLogger] = logger_alias[type](self.cfg, self.logdir, run_name="__".join(run_names)) 155 | Logger.info("Output directory:", self.logdir) 156 | 157 | is_multirun = hydra_cfg.mode == hydra.types.RunMode.MULTIRUN # type: ignore 158 | job_id = hydra_cfg.job.num if is_multirun else 0 159 | desc = f"Job {job_id:2d}" if is_multirun else "" 160 | n = cfg.n_updates if hasattr(cfg, "n_updates") else cfg.n_steps 161 | self.pbar = tqdm(range(n), position=job_id%self.cfg.n_jobs, desc=desc) 162 | 163 | @staticmethod 164 | def _get_logger(inspect_stack: List[inspect.FrameInfo]): 165 | rel_path = Path(inspect_stack[1].filename).resolve().relative_to(DIFFAERO_ROOT_DIR) 166 | Logger.logging.name = f"{str(rel_path)}:{inspect_stack[1].lineno}" 167 | return Logger.logging 168 | 169 | @staticmethod 170 | def debug(*msgs): 171 | with tqdm.external_write_mode(): 172 | Logger._get_logger(inspect.stack()).debug(msg2str(*msgs)) 173 | 174 | @staticmethod 175 | def info(*msgs): 176 | with tqdm.external_write_mode(): 177 | Logger._get_logger(inspect.stack()).info(msg2str(*msgs)) 178 | 179 | @staticmethod 180 | def warning(*msgs): 181 | with tqdm.external_write_mode(): 182 | Logger._get_logger(inspect.stack()).warning(msg2str(*msgs)) 183 | 184 | @staticmethod 185 | def error(*msgs): 186 | with tqdm.external_write_mode(): 187 | Logger._get_logger(inspect.stack()).error(msg2str(*msgs)) 188 | 189 | @staticmethod 190 | def critical(*msgs): 191 | with tqdm.external_write_mode(): 192 | Logger._get_logger(inspect.stack()).critical(msg2str(*msgs)) 193 | 194 | @property 195 | def n(self): 196 | return self.pbar.n 197 | 198 | def log_scalar(self, tag, value): 199 | return self._logger.log_scalar(tag, value, self.n) 200 | 201 | def log_scalars(self, value_dict): 202 | return self._logger.log_scalars(value_dict, self.n) 203 | 204 | def log_histogram(self, tag, values): 205 | return self._logger.log_histogram(tag, values, self.n) 206 | 207 | def log_image(self, tag, img): 208 | return self._logger.log_image(tag, img, self.n) 209 | 210 | def log_images(self, tag, img): 211 | return self._logger.log_images(tag, img, self.n) 212 | 213 | def log_video(self, tag, video, fps): 214 | return self._logger.log_video(tag, video, self.n, fps) 215 | 216 | def close(self): 217 | return self._logger.close() -------------------------------------------------------------------------------- /utils/math.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2023, NVIDIA Corporation 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # 1. Redistributions of source code must retain the above copyright notice, this 8 | # list of conditions and the following disclaimer. 9 | # 10 | # 2. Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # 14 | # 3. Neither the name of the copyright holder nor the names of its 15 | # contributors may be used to endorse or promote products derived from 16 | # this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | 29 | from typing import Callable, Union, Optional, Tuple, List 30 | 31 | import torch 32 | from torch import Tensor 33 | from pytorch3d import transforms as T 34 | 35 | # Runge-Kutta 4th Order Method 36 | def rk4(f, X0, U, dt, M=1): 37 | # type: (Callable[[Tensor, Tensor], Tensor], Tensor, Tensor, float, int) -> Tensor 38 | DT = dt / M 39 | X1 = X0 40 | for _ in range(M): 41 | k1 = DT * f(X1, U) 42 | k2 = DT * f(X1 + 0.5 * k1, U) 43 | k3 = DT * f(X1 + 0.5 * k2, U) 44 | k4 = DT * f(X1 + k3, U) 45 | X1 = X1 + (k1 + 2 * k2 + 2 * k3 + k4) / 6 46 | return X1 47 | 48 | # Euler Integration 49 | def EulerIntegral(f, X0, U, dt, M=1): 50 | # type: (Callable[[Tensor, Tensor], Tensor], Tensor, Tensor, float, int) -> Tensor 51 | DT = dt / M 52 | X1 = X0 53 | for _ in range(M): 54 | X1 = X1 + DT * f(X1, U) 55 | return X1 56 | 57 | @torch.jit.script 58 | def euler_to_quaternion(roll, pitch, yaw): 59 | # type: (Tensor, Tensor, Tensor) -> Tensor 60 | cy = torch.cos(yaw * 0.5) 61 | sy = torch.sin(yaw * 0.5) 62 | cr = torch.cos(roll * 0.5) 63 | sr = torch.sin(roll * 0.5) 64 | cp = torch.cos(pitch * 0.5) 65 | sp = torch.sin(pitch * 0.5) 66 | 67 | qw = cy * cr * cp + sy * sr * sp 68 | qx = cy * sr * cp - sy * cr * sp 69 | qy = cy * cr * sp + sy * sr * cp 70 | qz = sy * cr * cp - cy * sr * sp 71 | 72 | return torch.stack([qx, qy, qz, qw], dim=-1) 73 | 74 | def random_quat_from_eular_zyx( 75 | yaw_range: Tuple[float, float] = (-torch.pi, torch.pi), 76 | pitch_range: Tuple[float, float] = (-torch.pi, torch.pi), 77 | roll_range: Tuple[float, float] = (-torch.pi, torch.pi), 78 | size: Union[int, Tuple[int, int]] = 1, 79 | device = None 80 | ) -> Tuple[float, float, float, float]: 81 | """ 82 | Return a quaternion with eular angles uniformly sampled from given range. 83 | 84 | Args: 85 | yaw_range: range of yaw angle in radians. 86 | pitch_range: range of pitch angle in radians. 87 | roll_range: range of roll angle in radians. 88 | 89 | Returns: 90 | Real and imagine part of the quaternion. 91 | """ 92 | yaw = torch.rand(size, device=device) * (yaw_range[1] - yaw_range[0]) + yaw_range[0] 93 | pitch = torch.rand(size, device=device) * (pitch_range[1] - pitch_range[0]) + pitch_range[0] 94 | roll = torch.rand(size, device=device) * (roll_range[1] - roll_range[0]) + roll_range[0] 95 | quat_xyzw = euler_to_quaternion(roll, pitch, yaw) 96 | return quat_xyzw 97 | 98 | @torch.jit.script 99 | def quat_rotate(quat_xyzw: Tensor, v: Tensor) -> Tensor: 100 | q_w = quat_xyzw[..., -1] 101 | q_vec = quat_xyzw[..., :3] 102 | a = v * (q_w ** 2 - q_vec.pow(2).sum(dim=-1)).unsqueeze(-1) 103 | b = 2. * q_w.unsqueeze(-1) * torch.cross(q_vec, v, dim=-1) 104 | c = 2. * q_vec * (q_vec * v).sum(dim=-1, keepdim=True) 105 | return a + b + c 106 | 107 | @torch.jit.script 108 | def quat_rotate_inverse(quat_xyzw: Tensor, v: Tensor) -> Tensor: 109 | q_w = quat_xyzw[..., -1] 110 | q_vec = quat_xyzw[..., :3] 111 | a = v * (q_w ** 2 - q_vec.pow(2).sum(dim=-1)).unsqueeze(-1) 112 | b = 2. * q_w.unsqueeze(-1) * torch.cross(q_vec, v, dim=-1) 113 | c = 2. * q_vec * (q_vec * v).sum(dim=-1, keepdim=True) 114 | return a - b + c 115 | 116 | @torch.jit.script 117 | def quat_axis(quat_xyzw: Tensor, axis: int = 0) -> Tensor: 118 | basis_vec = torch.zeros(quat_xyzw.shape[0], 3, device=quat_xyzw.device) 119 | basis_vec[..., axis] = 1 120 | return quat_rotate(quat_xyzw, basis_vec) 121 | 122 | @torch.jit.script 123 | def quat_mul(a: Tensor, b: Tensor) -> Tensor: 124 | shape = a.shape 125 | a = a.reshape(-1, 4) 126 | b = b.reshape(-1, 4) 127 | 128 | x1, y1, z1, w1 = a[..., 0], a[..., 1], a[..., 2], a[..., 3] 129 | x2, y2, z2, w2 = b[..., 0], b[..., 1], b[..., 2], b[..., 3] 130 | ww = (z1 + x1) * (x2 + y2) 131 | yy = (w1 - y1) * (w2 + z2) 132 | zz = (w1 + y1) * (w2 - z2) 133 | xx = ww + yy + zz 134 | qq = 0.5 * (xx + (z1 - x1) * (x2 - y2)) 135 | w = qq - ww + (z1 - y1) * (y2 - z2) 136 | x = qq - xx + (x1 + w1) * (x2 + w2) 137 | y = qq - yy + (w1 - x1) * (y2 + z2) 138 | z = qq - zz + (z1 + y1) * (w2 - x2) 139 | 140 | quat = torch.stack([x, y, z, w], dim=-1).reshape(shape) 141 | 142 | return quat 143 | 144 | def quat_standardize(quat_xyzw: torch.Tensor) -> torch.Tensor: 145 | return torch.where(quat_xyzw[..., -1:] < 0, -quat_xyzw, quat_xyzw) 146 | 147 | @torch.jit.script 148 | def quat_inv(quat_xyzw: Tensor) -> Tensor: 149 | return torch.cat([-quat_xyzw[..., :3], quat_xyzw[..., 3:4]], dim=-1) 150 | 151 | @torch.jit.script 152 | def axis_rotmat(axis: str, angle: Tensor) -> Tensor: 153 | """ 154 | Return the rotation matrices for one of the rotations about an axis 155 | of which Euler angles describe, for each value of the angle given. 156 | 157 | Args: 158 | axis: Axis label "X" or "Y or "Z". 159 | angle: any shape tensor of Euler angles in radians 160 | 161 | Returns: 162 | Rotation matrices as tensor of shape (..., 3, 3). 163 | """ 164 | 165 | cos = torch.cos(angle) 166 | sin = torch.sin(angle) 167 | one = torch.ones_like(angle) 168 | zero = torch.zeros_like(angle) 169 | 170 | if axis == "X": 171 | R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) 172 | elif axis == "Y": 173 | R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) 174 | else: # axis == "Z" 175 | R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) 176 | 177 | return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) 178 | 179 | def rand_range( 180 | min: float, 181 | max: float, 182 | size: Union[int, Tuple[int, ...]], 183 | device: torch.device 184 | ): 185 | if isinstance(size, int): 186 | size = (size,) 187 | return torch.rand(*size, device=device) * (max - min) + min 188 | 189 | @torch.jit.script 190 | def quaternion_to_euler(quat_xyzw: Tensor) -> Tensor: 191 | # return T.matrix_to_euler_angles(T.quaternion_to_matrix(quat_wxyz), "ZYX")[..., [2, 1, 0]] 192 | x, y, z, w = quat_xyzw.unbind(dim=-1) 193 | roll = torch.atan2(2.0 * (w * x + y * z), 1.0 - 2.0 * (x**2 + y**2)) 194 | pitch = torch.asin(2.0 * (w * y - x * z)) 195 | yaw = torch.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y**2 + z**2)) 196 | return torch.stack([roll, pitch, yaw], dim=-1) 197 | 198 | @torch.jit.script 199 | def quaternion_invert(quat_wxyz: Tensor) -> Tensor: 200 | neg = torch.ones_like(quat_wxyz) 201 | neg[..., 1:] = -1 202 | return quat_wxyz * neg 203 | 204 | @torch.jit.script 205 | def quaternion_apply(quat_wxyz: Tensor, point: Tensor) -> Tensor: 206 | if point.size(-1) != 3: 207 | raise ValueError(f"Points are not in 3D, f{point.shape}.") 208 | real_parts = torch.zeros_like(point[..., 0:1]) 209 | point_as_quaternion = torch.cat((real_parts, point), -1) 210 | out = T.quaternion_raw_multiply( 211 | T.quaternion_raw_multiply(quat_wxyz, point_as_quaternion), 212 | quaternion_invert(quat_wxyz), 213 | ) 214 | return out[..., 1:] 215 | 216 | @torch.jit.script 217 | def mvp(mat: Tensor, vec: Tensor) -> Tensor: 218 | """ 219 | Matrix-vector product. 220 | 221 | Args: 222 | mat: A tensor of shape (..., n, m). 223 | vec: A tensor of shape (..., m). 224 | 225 | Returns: 226 | A tensor of shape (..., n). 227 | """ 228 | if mat.dim() < 2 or vec.dim() < 1: 229 | raise ValueError("Input dimensions are not compatible for matrix-vector multiplication.") 230 | return torch.matmul(mat, vec.unsqueeze(-1)).squeeze(-1) 231 | 232 | @torch.jit.script 233 | def tanh_unsquash(x: Tensor, min: Tensor, max: Tensor): 234 | min, max = min.expand_as(x), max.expand_as(x) 235 | return min + (x.tanh().add(1.).mul(0.5) * (max - min)) -------------------------------------------------------------------------------- /utils/nn.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Union, Optional, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | def num_params(model: nn.Module) -> int: 7 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 8 | 9 | def layer_init(layer, std=2.**0.5, bias_const=0.0): 10 | torch.nn.init.orthogonal_(layer.weight, std) 11 | torch.nn.init.constant_(layer.bias, bias_const) 12 | return layer 13 | 14 | def weight_init(m, std=0.02, bias_const=0.0): 15 | """Custom weight initialization for TD-MPC2.""" 16 | if isinstance(m, nn.Linear): 17 | nn.init.trunc_normal_(m.weight, std=std) 18 | if m.bias is not None: 19 | nn.init.constant_(m.bias, bias_const) 20 | elif isinstance(m, nn.Embedding): 21 | nn.init.uniform_(m.weight, -0.02, 0.02) 22 | elif isinstance(m, nn.ParameterList): 23 | for i,p in enumerate(m): 24 | if p.dim() == 3: # Linear 25 | nn.init.trunc_normal_(p, std=0.02) # Weight 26 | nn.init.constant_(m[i+1], 0) # Bias 27 | return m 28 | 29 | def zero_(params): 30 | """Initialize parameters to zero.""" 31 | for p in params: 32 | p.data.fill_(0) 33 | 34 | class NormedLinear(nn.Module): 35 | """ 36 | Linear layer with LayerNorm, activation. 37 | """ 38 | def __init__(self, in_features, out_features, act=nn.Mish(inplace=True)): 39 | super().__init__() 40 | self.linear = nn.Linear(in_features, out_features) 41 | self.ln = nn.LayerNorm(self.linear.out_features) 42 | self.act = act 43 | def forward(self, x): 44 | x = self.linear(x) 45 | return self.act(self.ln(x)) 46 | def __repr__(self): 47 | return f"NormedLinear(in_features={self.linear.in_features}, "\ 48 | f"out_features={self.linear.out_features}, "\ 49 | f"bias={self.linear.bias is not None}, "\ 50 | f"act={self.act.__class__.__name__})" 51 | def apply(self, fn): 52 | self.linear.apply(fn) 53 | return self 54 | 55 | def mlp( 56 | in_dim: int, 57 | mlp_dims: Union[int, List[int]], 58 | out_dim: int, 59 | hidden_act: nn.Module = nn.ELU(inplace=True), 60 | output_act: Optional[nn.Module] = None): 61 | """ 62 | Basic building block of TD-MPC2. 63 | MLP with LayerNorm, Mish activations. 64 | """ 65 | if isinstance(mlp_dims, int): 66 | mlp_dims = [mlp_dims] 67 | dims = [in_dim] + mlp_dims + [out_dim] 68 | mlp = nn.ModuleList() 69 | for i in range(len(dims) - 2): 70 | mlp.append(NormedLinear(dims[i], dims[i+1], act=hidden_act).apply(layer_init)) 71 | mlp.append(layer_init(nn.Linear(dims[-2], dims[-1]), std=0.01)) 72 | if output_act is not None: 73 | mlp.append(output_act) 74 | return nn.Sequential(*mlp) 75 | 76 | def clip_grad_norm( 77 | m: nn.Module, 78 | max_grad_norm: Optional[float] = None 79 | ) -> float: 80 | if max_grad_norm is not None: 81 | grad_norm = torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=max_grad_norm) 82 | else: 83 | grads = [p.grad for p in m.parameters() if p.grad is not None] 84 | grad_norm = torch.nn.utils.get_total_norm(grads) 85 | return grad_norm.item() -------------------------------------------------------------------------------- /utils/randomizer.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, List, Optional 2 | from textwrap import dedent 3 | 4 | from omegaconf import DictConfig 5 | import torch 6 | 7 | class RandomizerBase: 8 | def __init__( 9 | self, 10 | shape: Union[int, List[int], torch.Size], 11 | default_value: Union[float, bool], 12 | device: torch.device, 13 | enabled: bool = True, 14 | dtype: torch.dtype = torch.float, 15 | ): 16 | self.value = torch.zeros(shape, device=device, dtype=dtype) 17 | self.default_value = default_value 18 | self.enabled = enabled 19 | self.excluded_attributes = [ 20 | "excluded_attributes", 21 | "value", 22 | "default_value", 23 | "randomize", 24 | "default", 25 | "enabled", 26 | ] 27 | self.default() 28 | 29 | def __getattr__(self, name: str): 30 | if name not in self.excluded_attributes and hasattr(self.value, name): 31 | return getattr(self.value, name) 32 | else: 33 | return getattr(self, name) 34 | 35 | def __str__(self) -> str: 36 | return str(self.value) 37 | 38 | def randomize(self, idx: Optional[torch.Tensor] = None) -> torch.Tensor: 39 | raise NotImplementedError 40 | 41 | def default(self) -> torch.Tensor: 42 | self.value = torch.full_like(self.value, self.default_value) 43 | return self.value 44 | 45 | def __add__(self, other): 46 | return self.value + other 47 | def __rsub__(self, other): 48 | return other - self.value 49 | def __sub__(self, other): 50 | return self.value - other 51 | def __rmul__(self, other): 52 | return other * self.value 53 | def __mul__(self, other): 54 | return self.value * other 55 | def __div__(self, other): 56 | return self.value / other 57 | def __neg__(self): 58 | return -self.value 59 | def reshape(self, shape: Union[int, List[int], torch.Size]): 60 | return self.value.reshape(shape) 61 | def squeeze(self, dim: int = -1): 62 | return self.value.squeeze(dim) 63 | def unsqueeze(self, dim: int = -1): 64 | return self.value.unsqueeze(dim) 65 | 66 | class UniformRandomizer(RandomizerBase): 67 | def __init__( 68 | self, 69 | shape: Union[int, List[int], torch.Size], 70 | default_value: Union[float, bool], 71 | device: torch.device, 72 | enabled: bool = True, 73 | low: float = 0.0, 74 | high: float = 1.0, 75 | dtype: torch.dtype = torch.float, 76 | ): 77 | self.low = low 78 | self.high = high 79 | super().__init__(shape, default_value, device, enabled, dtype) 80 | self.excluded_attributes.extend(["low", "high"]) 81 | 82 | def randomize(self, idx: Optional[torch.Tensor] = None) -> torch.Tensor: 83 | if idx is not None: 84 | mask = torch.zeros_like(self.value, dtype=torch.bool) 85 | mask[idx] = True 86 | new = torch.rand_like(self.value) * (self.high - self.low) + self.low 87 | self.value = torch.where(mask, new, self.value) 88 | else: 89 | self.value.uniform_(self.low, self.high) 90 | return self.value 91 | 92 | def __repr__(self) -> str: 93 | return dedent(f""" 94 | UniformRandomizer( 95 | enabled={self.enabled}, 96 | low={self.low}, 97 | high={self.high}, 98 | default={self.default_value}, 99 | shape={self.value.shape}, 100 | device={self.value.device}, 101 | dtype={self.value.dtype} 102 | )""") 103 | 104 | @staticmethod 105 | def build( 106 | cfg: DictConfig, 107 | shape: Union[int, List[int], torch.Size], 108 | device: torch.device, 109 | dtype: torch.dtype = torch.float 110 | ) -> 'UniformRandomizer': 111 | return UniformRandomizer( 112 | shape=shape, 113 | default_value=cfg.default, 114 | device=device, 115 | enabled=cfg.enabled, 116 | low=cfg.min, 117 | high=cfg.max, 118 | dtype=dtype, 119 | ) 120 | 121 | class NormalRandomizer(RandomizerBase): 122 | def __init__( 123 | self, 124 | shape: Union[int, List[int], torch.Size], 125 | default_value: Union[float, bool], 126 | device: torch.device, 127 | enabled: bool = True, 128 | mean: float = 0.0, 129 | std: float = 1.0, 130 | dtype: torch.dtype = torch.float, 131 | ): 132 | self.mean = mean 133 | self.std = std 134 | super().__init__(shape, default_value, device, enabled, dtype) 135 | self.excluded_attributes.extend(["mean", "std"]) 136 | 137 | def randomize(self, idx: Optional[torch.Tensor] = None) -> torch.Tensor: 138 | if idx is not None: 139 | mask = torch.zeros_like(self.value, dtype=torch.bool) 140 | mask[idx] = True 141 | new = torch.randn_like(self.value) * self.std + self.mean 142 | self.value = torch.where(mask, new, self.value) 143 | else: 144 | self.value.normal_(self.mean, self.std) 145 | return self.value 146 | 147 | def __repr__(self) -> str: 148 | return dedent(f""" 149 | NormalRandomizer(enabled={self.enabled}, 150 | mean={self.mean}, 151 | std={self.std}, 152 | default={self.default_value}, 153 | shape={self.value.shape}, 154 | device={self.value.device}, 155 | dtype={self.value.dtype} 156 | )""") 157 | 158 | @staticmethod 159 | def build( 160 | cfg: DictConfig, 161 | shape: Union[int, List[int], torch.Size], 162 | device: torch.device, 163 | dtype: torch.dtype = torch.float 164 | ) -> 'NormalRandomizer': 165 | return NormalRandomizer( 166 | shape=shape, 167 | default_value=cfg.default, 168 | device=device, 169 | enabled=cfg.enabled, 170 | mean=cfg.mean, 171 | std=cfg.std, 172 | dtype=dtype, 173 | ) 174 | 175 | class RandomizerManager: 176 | randomizers: List[Union[UniformRandomizer, NormalRandomizer]] = [] 177 | def __init__( 178 | self, 179 | cfg: DictConfig, 180 | ): 181 | self.enabled: bool = cfg.enabled 182 | 183 | def refresh(self, idx: Optional[torch.Tensor] = None): 184 | for randomizer in self.randomizers: 185 | if self.enabled and randomizer.enabled: 186 | randomizer.randomize(idx) 187 | else: 188 | randomizer.default() 189 | 190 | def __str__(self) -> str: 191 | return ( 192 | "RandomizeManager(\n\t" + 193 | f"Enabled: {self.enabled},\n\t" + 194 | ",\n ".join([randomizer.__repr__() for randomizer in self.randomizers]) + 195 | "\n)" 196 | ) 197 | 198 | def build_randomizer( 199 | cfg: DictConfig, 200 | shape: Union[int, List[int], torch.Size], 201 | device: torch.device, 202 | dtype: torch.dtype = torch.float, 203 | ) -> Union[UniformRandomizer, NormalRandomizer]: 204 | if hasattr(cfg, "min") and hasattr(cfg, "max"): 205 | randomizer = UniformRandomizer.build(cfg, shape, device, dtype) 206 | elif hasattr(cfg, "mean") and hasattr(cfg, "std"): 207 | randomizer = NormalRandomizer.build(cfg, shape, device, dtype) 208 | else: 209 | raise ValueError("Invalid randomizer configuration. Must contain 'min' and 'max' for UniformRandomizer or 'mean' and 'std' for NormalRandomizer.") 210 | RandomizerManager.randomizers.append(randomizer) 211 | return randomizer 212 | 213 | if __name__ == "__main__": 214 | # Example usage 215 | print(UniformRandomizer([2, 3], 0.5, torch.device("cpu"), low=0.0, high=1.0).randomize(torch.tensor([0]))) 216 | print(build_randomizer(DictConfig({"defalut": 0.5, "min": 0, "max": 1}), [2, 3], torch.device("cpu")).randomize()) 217 | print(build_randomizer(DictConfig({"defalut": 0.5, "mean": 0, "std": 1}), [2, 3], torch.device("cpu")).randomize()) 218 | print(RandomizerManager(DictConfig({"enable": False}))) --------------------------------------------------------------------------------