├── .github
│   └── workflows
│       ├── ruff.yml
│       └── unittest.yml
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── df_clf_cbf_comp.jpg
│   ├── dp_sensitive_to_init_pos.jpg
│   ├── model_input.jpg
│   ├── result_anim.gif
│   └── result_plot.png
├── config
│   └── config.yaml
├── core
│   ├── controllers
│   │   ├── base_controller.py
│   │   ├── quadrotor_clf_cbf_qp.py
│   │   └── quadrotor_diffusion_policy.py
│   ├── dataset
│   │   └── quadrotor_dataset.py
│   ├── env
│   │   └── planar_quadrotor.py
│   ├── networks
│   │   ├── conditional_unet1d.py
│   │   ├── conv1d_components.py
│   │   └── positional_embedding.py
│   └── trainers
│       ├── base_trainer.py
│       └── quadrotor_diffusion_policy_trainer.py
├── demo.ipynb
├── learning_note.md
├── pyproject.toml
├── tests
│   ├── __init__.py
│   ├── config
│   │   └── test_config.yaml
│   ├── fixtures
│   │   └── test_dataset.joblib
│   ├── test_datasets.py
│   └── test_networks.py
├── train.ipynb
└── utils
    ├── normalizers.py
    ├── utils.py
    └── visualization.py

/.github/workflows/ruff.yml:
--------------------------------------------------------------------------------
name: Ruff
on: [push, pull_request]
jobs:
  ruff:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: chartboost/ruff-action@v1

--------------------------------------------------------------------------------
/.github/workflows/unittest.yml:
--------------------------------------------------------------------------------
name: Python Unit Tests

# Defines when the action should run. This workflow triggers on push and pull requests to the main branch.
on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

# Jobs that the workflow will execute
jobs:
  run-unittests:
    # The type of runner that the job will run on
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]

    # Steps represent a sequence of tasks that will be executed as part of the job
    steps:
      # Checks out your repository under $GITHUB_WORKSPACE, so your job can access it
      - uses: actions/checkout@v3

      # Sets up a Python environment
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      # Install dependencies
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install torch==1.13.1 torchvision==0.14.1 jax==0.4.23 jaxlib==0.4.23 diffusers==0.18.2 joblib

      # Run unittests using the Python unittest module
      - name: Run unittest
        run: python -m unittest discover

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

*.pyc
tmp_imgs/*.*
*-checkpoint.py
*-checkpoint.ipynb
*-checkpoint.yaml
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 shaoanlu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# diffusion_policy_quadrotor
This repository provides a demonstration of imitation learning using a diffusion policy for quadrotor control. The implementation is adapted from the official Diffusion Policy [repository](https://github.com/real-stanford/diffusion_policy), with an additional CLF-CBF controller that can be used to improve the safety of the generated trajectory.

## Result
The control task is to drive the quadrotor from the initial position (0, 0) to the goal position (5, 5) without colliding with the obstacles. The animation shows the denoising process of the diffusion policy predicting the future trajectory, followed by the quadrotor applying the actions.

[Figures: result animation and plot]


## Usage
The notebook `demo.ipynb` demonstrates a closed-loop simulation using the diffusion policy controller for quadrotor collision avoidance. You can run it in Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shaoanlu/diffusion_policy_quadrotor/blob/main/demo.ipynb).

The training script is provided as `train.ipynb`.
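The closed-loop simulation follows the receding-horizon scheme of the original Diffusion Policy: the policy predicts `pred_horizon` actions from the latest `obs_horizon` observations, only `action_horizon` of them are executed, and then the policy re-plans. The sketch below illustrates only this bookkeeping, using the horizons from `config/config.yaml`. `dummy_policy` and the "perfect tracking" state update are hypothetical stand-ins for `QuadrotorDiffusionPolicy` and the simulator, and pre-filling the observation window with the initial state is an assumption, not necessarily what `demo.ipynb` does.

```python
from collections import deque

import numpy as np

# Horizons copied from config/config.yaml
OBS_HORIZON = 2
ACTION_HORIZON = 10
PRED_HORIZON = 96
OBS_DIM = ACTION_DIM = 6


def dummy_policy(obs_window):
    """Hypothetical stand-in for QuadrotorDiffusionPolicy: returns a
    (PRED_HORIZON, ACTION_DIM) plan that just repeats the latest state."""
    return np.tile(obs_window[-1], (PRED_HORIZON, 1))


def select_executed_actions(action_pred, obs_horizon=OBS_HORIZON, action_horizon=ACTION_HORIZON):
    # Same slicing as QuadrotorDiffusionPolicy.predict_action: skip the
    # predicted steps that overlap the observation window, keep action_horizon steps.
    start = obs_horizon - 1
    return action_pred[start : start + action_horizon]


def rollout(initial_state, num_replans=3):
    # Assumption: the observation window is pre-filled with the initial state.
    obs_window = deque([initial_state] * OBS_HORIZON, maxlen=OBS_HORIZON)
    executed = []
    for _ in range(num_replans):
        action_pred = dummy_policy(list(obs_window))  # (PRED_HORIZON, ACTION_DIM)
        for action in select_executed_actions(action_pred):
            # Hypothetical "perfect tracking": the next state equals the desired state.
            obs_window.append(action)
            executed.append(action)
    return np.array(executed)


traj = rollout(np.zeros(OBS_DIM))
print(traj.shape)  # (30, 6)
```

Each re-planning cycle executes 10 of the 96 predicted steps, so three cycles yield 30 executed actions. Executing only a short prefix of a long plan trades some computation for the ability to react to new observations.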

## Dependencies
The program was developed and tested in the following environment.
- Python 3.10
- `torch==2.2.1`
- `jax==0.4.26`
- `jaxlib==0.4.26`
- `diffusers==0.27.2`
- `torchvision==0.14.1`
- `gdown` (to download pre-trained weights)
- `joblib` (format of training data)

## Diffusion policy
The policy takes as input 1) the latest `obs_horizon` steps of observations $o_t$ (position, velocity, and attitude of the quadrotor) and 2) the encoding of the obstacle information $O_{BST}$ (a flattened 7x7 grid with obstacle radii as values). The outputs are `pred_horizon` steps of actions $a_t$ (future positions, velocities, and attitudes).

[Figure: model input]

*The quadrotor icon is from [flaticon](https://www.flaticon.com/free-icon/quadcopter_5447794).


### Deviation from the original implementation
- A linear layer is added before the Mish activation in the condition encoder of `ConditionalResidualBlock1D`. This prevents the activation from truncating large negative values of the normalized observation.
- A CLF-CBF-QP controller is implemented and can be used to modify the noisy actions during the denoising process of the policy. By default, this controller is disabled.
- A finetuned model for single-step inference is used by default.
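The conditioning input described in the Diffusion policy section can be sketched as follows. The grid encoding mirrors `preprocess_dataset` in `core/dataset/quadrotor_dataset.py` (each obstacle radius is written to grid cell `center - 1`, which assumes integer obstacle centers in 1..7), and the min-max normalization to [-1, 1] mirrors `normalize_data`. The helper names, the normalization stats, and the obstacle centers and radii below are made up for illustration.

```python
import numpy as np

OBS_HORIZON = 2  # config.yaml: obs_horizon
OBS_DIM = 6      # config.yaml: controller.networks.obs_dim
GRID_SIZE = 7    # 7 x 7 grid -> obstacle_encode_dim = 49


def encode_obstacles(centers, radii, grid_size=GRID_SIZE):
    # Write each obstacle radius into the grid cell (center - 1), as in
    # preprocess_dataset; assumes integer centers in [1, grid_size].
    grid = np.zeros((grid_size, grid_size))
    for center, radius in zip(centers, radii):
        grid[tuple((np.asarray(center) - 1).astype(np.int32))] = radius
    return grid.flatten()


def normalize(data, stats):
    # Min-max normalization to [-1, 1], as in normalize_data
    return (data - stats["min"]) / (stats["max"] - stats["min"]) * 2 - 1


def build_global_cond(obs_window, obs_stats, centers, radii):
    nobs = normalize(np.stack(obs_window), obs_stats).flatten()  # (OBS_HORIZON * OBS_DIM,)
    return np.concatenate([nobs, encode_obstacles(centers, radii)])


# Made-up values for illustration (not taken from the training data)
obs_stats = {"min": -10.0 * np.ones(OBS_DIM), "max": 10.0 * np.ones(OBS_DIM)}
window = [np.zeros(OBS_DIM), np.ones(OBS_DIM)]  # OBS_HORIZON consecutive states
cond = build_global_cond(window, obs_stats, centers=[(2, 3), (5, 5)], radii=[0.5, 0.8])
print(cond.shape)  # (61,)
```

The resulting length, `obs_dim * obs_horizon + obstacle_encode_dim` = 6 × 2 + 49 = 61, matches the `global_cond_dim` passed to `ConditionalUnet1D` in `build_networks_from_config`. Note that the obstacle grid is not normalized, consistent with the dataset code, which skips the `"obstacle"` key when computing statistics.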

[Figure]


## References
Papers
- [Diffusion Policy: Visuomotor Policy Learning via Action Diffusion](https://diffusion-policy.cs.columbia.edu/) [arXiv:2303.04137]
- [3D Diffusion Policy: Generalizable Visuomotor Policy Learning via Simple 3D Representations](https://3d-diffusion-policy.github.io/) [arXiv:2403.03954]
- [Fine-Tuning Image-Conditional Diffusion Models is Easier than You Think](https://gonzalomartingarcia.github.io/diffusion-e2e-ft/) [arXiv:2409.11355]

Videos and Lectures
- [DeepLearning.AI: How Diffusion Models Work](https://www.deeplearning.ai/short-courses/how-diffusion-models-work/)
- [[Paper Quick Read] Diffusion Policy: Visuomotor Policy Learning via Action Diff. [2303.04137]](https://www.bilibili.com/video/BV1Cu411Y7d7)
- [[Paper Walkthrough] Planning with Diffusion for Flexible Behavior Synthesis, explained (with code)](https://youtu.be/ciCcvWutle4)
- [6.4210 Fall 2023 Lecture 18: Visuomotor Policies (via Behavior Cloning)](https://youtu.be/i-303tTtEig)

## Learning note
### Failure case: the diffusion policy controller failed to extrapolate beyond the training data
Figure: A failure case of the controller.
- The left figure is a trajectory in the training data.
- The middle figure is the closed-loop simulation result of the controller starting from the SAME initial position as the training data.
- The right figure is the closed-loop simulation result of the controller starting from a DIFFERENT initial position, which results in a trajectory with a collision.

[Figure]

Refer to [`learning_note.md`](learning_note.md) for other notes.

--------------------------------------------------------------------------------
/assets/df_clf_cbf_comp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/assets/df_clf_cbf_comp.jpg
--------------------------------------------------------------------------------
/assets/dp_sensitive_to_init_pos.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/assets/dp_sensitive_to_init_pos.jpg
--------------------------------------------------------------------------------
/assets/model_input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/assets/model_input.jpg
--------------------------------------------------------------------------------
/assets/result_anim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/assets/result_anim.gif
--------------------------------------------------------------------------------
/assets/result_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/assets/result_plot.png
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
name: "planar_quadrotor"

pred_horizon: 96
obs_horizon: 2
action_horizon: 10

controller:
  common:
    sampling_time: 0.1 # sec
    use_single_step_inference: true
  networks:
    obs_dim: 6
    action_dim: 6
    obstacle_encode_dim: 49
  noise_scheduler:
    type: "ddpm"
    ddpm:
      num_train_timesteps: 100 # number of diffusion iterations
      beta_schedule: "squaredcos_cap_v2"
      clip_sample: true # required when predict_epsilon=False
      prediction_type: "epsilon"
    ddim: # faster inference
      num_train_timesteps: 100
      beta_schedule: "squaredcos_cap_v2"
      clip_sample: true
      prediction_type: "epsilon"
    dpmsolver: # faster inference, experimental
      num_train_timesteps: 100
      beta_schedule: "squaredcos_cap_v2"
      prediction_type: "epsilon"
      use_karras_sigmas: true

cbf_clf_controller:
  denoising_guidance_step: 100 # equals num_train_timesteps
  cbf_alpha: 10.0
  clf_gamma: 0.03
  penalty_slack_cbf: 1.0e+3
  penalty_slack_clf: 1.0

trainer:
  use_ema: true
  batch_size: 256
  optimizer:
    name: "adamw"
    learning_rate: 1.0e-4
    weight_decay: 1.0e-6
  lr_scheduler:
    name: "cosine"
    num_warmup_steps: 500

dataloader:
  batch_size: 256

normalizer:
  action:
    min: [-2.6515920785069467, -10.169477462768555, -4.270973491668701, -13.883374419993748, -1.9637073183059692, -20.395111083984375]
    max: [8.0543811917305, 6.976394176483154, 8.477932775914669, 11.327190831233095, 2.5276688146591186, 18.05487823486328]
  observation:
    min: [-2.649836301803589, -9.564324378967285, -4.264063358306885, -13.772777557373047, -1.9476231336593628, -17.225351333618164]
    max: [8.05234146118164, 6.976299285888672, 8.474746704101562, 10.96181583404541, 2.5151419639587402, 18.054880142211914]

simulator:
  m_q: 1.0 # kg
  I_xx: 0.1 # kg.m^2
  l_q: 0.3 # m, length of the quadrotor
  g: 9.81
  dt: 0.01

--------------------------------------------------------------------------------
/core/controllers/base_controller.py:
--------------------------------------------------------------------------------
import numpy as np
from typing import Dict, List


class BaseController:
    def __init__(self, device: str = "cuda"):
        self.device = device

    def predict_action(self, obs_dict: Dict[str, List]) -> np.ndarray:
        raise NotImplementedError()

    def reset(self):
        # reset state for stateful policies
        pass

--------------------------------------------------------------------------------
/core/controllers/quadrotor_clf_cbf_qp.py:
--------------------------------------------------------------------------------
from typing import Dict, List
import numpy as np
import osqp
from scipy import sparse

from core.controllers.base_controller import BaseController


class QuadrotorCLFCBFController(BaseController):
    """
    A CLF-CBF safety filter assuming simple velocity-controlled dynamics
        y_dot = u1
        z_dot = u2
    The barrier function h of each obstacle is the squared distance to the obstacle
    center minus the squared (inflated) obstacle radius.
    """

    def __init__(self, config: Dict, device: str = "cuda"):
        super().__init__(device)
        self.obstacle_info = {"center": [], "radius": []}
        self.set_config(config)

    def predict_action(self, obs_dict: Dict[str, List], control: np.ndarray, target_position: np.ndarray) -> np.ndarray:
        # Clear the previous obstacles once, then register every obstacle in the observation
        self.obstacle_info = {"center": [], "radius": []}
        for center, radius in zip(obs_dict["obstacle_info"]["center"], obs_dict["obstacle_info"]["radius"]):
            self.set_obstacle(center, radius)

        safe_command = self.clf_cbf_control(
            state=obs_dict["state"],
            control=control,
            obs_center=self.obstacle_info["center"],
            obs_radius=self.obstacle_info["radius"],
            cbf_alpha=self.cbf_alpha,
            clf_gamma=self.clf_gamma,
            penalty_slack_cbf=self.penalty_slack_cbf,
            penalty_slack_clf=self.penalty_slack_clf,
            target_position=target_position,
        )
        return safe_command

    def set_obstacle(self, center: tuple, radius: float):
        # Append (instead of resetting the dict) so that all obstacles are kept
        self.obstacle_info["center"].append(center)
        self.obstacle_info["radius"].append(radius)

    def set_config(self, config: Dict):
        self.cbf_alpha = config["cbf_clf_controller"]["cbf_alpha"]
        self.clf_gamma = config["cbf_clf_controller"]["clf_gamma"]
        self.penalty_slack_cbf = config["cbf_clf_controller"]["penalty_slack_cbf"]
        self.penalty_slack_clf = config["cbf_clf_controller"]["penalty_slack_clf"]
        self.denoising_guidance_step = config["cbf_clf_controller"]["denoising_guidance_step"]
        self.quadrotor_params = config["simulator"]

    @staticmethod
    def _barrier_func(y, z, obs_y, obs_z, obs_r) -> float:
        return (y - obs_y) ** 2 + (z - obs_z) ** 2 - (obs_r) ** 2

    @staticmethod
    def _barrier_func_dot(y, z, obs_y, obs_z) -> list:
        return [2 * (y - obs_y), 2 * (z - obs_z)]

    @staticmethod
    def _lyapunov_func(y, z, des_y, des_z) -> float:
        return (y - des_y) ** 2 + (z - des_z) ** 2

    @staticmethod
    def _lyapunov_func_dot(y, z, des_y, des_z) -> list:
        return [2 * (y - des_y), 2 * (z - des_z)]

    @staticmethod
    def _define_QP_problem_data(
        u1: float,
        u2: float,
        cbf_alpha: float,
        clf_gamma: float,
        penalty_slack_cbf: float,
        penalty_slack_clf: float,
        h: list,
        coeffs_dhdx: list,
        v: list,
        coeffs_dvdx: list,
        vmin=-15.0,
        vmax=15.0,
    ):
        # Decision variables: x = [u1, u2, δ1, δ2]
        P = sparse.csc_matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, penalty_slack_cbf, 0], [0, 0, 0, penalty_slack_clf]])
        q = np.array([-u1, -u2, 0, 0])
        A = sparse.csc_matrix(
            [c for c in coeffs_dhdx]
            + [c for c in coeffs_dvdx]
            + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
        )
        lb = np.array([-cbf_alpha * h_ for h_ in h] + [-np.inf for _ in v] + [vmin, vmin, 0, 0])
        ub = np.array([np.inf for _ in h] + [-clf_gamma * v_ for v_ in v] + [vmax, vmax, np.inf, np.inf])
        return P, q, A, lb, ub

    @staticmethod
    def _get_quadrotor_state(state):
        y, y_dot, z, z_dot, phi, phi_dot = state
        return y, y_dot, z, z_dot, phi, phi_dot

    def _calculate_cbf_coeffs(self, state: np.ndarray, obs_center: List, obs_radius: List, minimal_distance: float):
        """
        Let h be the barrier function and x the system state. The CBF constraint is
            h_dot(x) >= -alpha * h(x) - δ
        """
        h = []  # barrier values (here, remaining distance to each obstacle)
        coeffs_dhdx = []  # dhdt = dhdx * dxdt = dhdx * u
        y, _, z, _, _, _ = self._get_quadrotor_state(state)
        for center, radius in zip(obs_center, obs_radius):
            h.append(self._barrier_func(y, z, center[0], center[1], radius + minimal_distance))
            # Additional [1, 0] incorporates the CBF slack variable into the constraint
            coeffs_dhdx.append(self._barrier_func_dot(y, z, center[0], center[1]) + [1, 0])
        return h, coeffs_dhdx

    def _calculate_clf_coeffs(self, state: np.ndarray, target_y: float, target_z: float):
        """
        Let v be the Lyapunov function and x the system state. The CLF constraint is
            v_dot(x) - δ <= -gamma * v(x)
        """
        y, _, z, _, _, _ = self._get_quadrotor_state(state)
        v = [self._lyapunov_func(y, z, target_y, target_z)]
        # Additional [0, -1] incorporates the CLF slack variable into the constraint
        coeffs_dvdx = [self._lyapunov_func_dot(y, z, target_y, target_z) + [0, -1]]
        return v, coeffs_dvdx

    def clf_cbf_control(
        self,
        state: np.ndarray,
        control: np.ndarray,
        obs_center: List,
        obs_radius: List,
        cbf_alpha: float = 15.0,
        clf_gamma: float = 0.01,
        penalty_slack_cbf: float = 1e2,
        penalty_slack_clf: float = 1.0,
        target_position: tuple = (5.0, 5.0),
    ):
        """
        Calculate the safe command by solving the following optimization problem

            minimize  || u - u_nom ||^2 + penalty_slack_cbf * δ1^2 + penalty_slack_clf * δ2^2
              u, δ
            s.t.
                h'(x) ≥ -𝛼 * h(x) - δ1
                v'(x) ≤ -γ * v(x) + δ2
                u_min ≤ u ≤ u_max
                0 ≤ δ1, δ2 ≤ inf
        where
            u = [u1, u2] are the velocity commands along the y and z axes, respectively
            u_nom is the nominal command (the `control` argument)
            δ = [δ1, δ2] are the CBF and CLF slack variables
            h(x) is the control barrier function and h'(x) its derivative
            v(x) is the Lyapunov function and v'(x) its derivative

        The problem above can be formulated as a QP (ref: https://osqp.org/docs/solver/index.html)

            minimize  1/2 * x^T * P * x + q^T * x
                x
            s.t.
                lb ≤ A * x ≤ ub
        where
            x = [u1, u2, δ1, δ2]

        """
        u1, u2 = control
        target_y, target_z = target_position

        # Calculate the values of the barrier functions and the coefficients of h_dot w.r.t. the state
        h, coeffs_dhdx = self._calculate_cbf_coeffs(state, obs_center, obs_radius, self.quadrotor_params["l_q"])
        # Calculate the value of the Lyapunov function and the coefficients of v_dot w.r.t. the state
        v, coeffs_dvdx = self._calculate_clf_coeffs(state, target_y, target_z)

        # Define the problem
        P, q, A, lb, ub = self._define_QP_problem_data(
            u1, u2, cbf_alpha, clf_gamma, penalty_slack_cbf, penalty_slack_clf, h, coeffs_dhdx, v, coeffs_dvdx
        )

        # Set up and solve the QP
        prob = osqp.OSQP()
        prob.setup(P, q, A, lb, ub, verbose=False, time_limit=0)
        res = prob.solve()

        safe_u1, safe_u2, _, _ = res.x
        return np.array([safe_u1, safe_u2])

--------------------------------------------------------------------------------
/core/controllers/quadrotor_diffusion_policy.py:
--------------------------------------------------------------------------------
from typing import Dict, List

import numpy as np
import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import (
    DPMSolverMultistepScheduler,
)

from core.controllers.base_controller import BaseController
from core.controllers.quadrotor_clf_cbf_qp import QuadrotorCLFCBFController
from core.networks.conditional_unet1d import ConditionalUnet1D
from utils.normalizers import BaseNormalizer


def build_networks_from_config(config: Dict):
    action_dim = config["controller"]["networks"]["action_dim"]
    obs_dim = config["controller"]["networks"]["obs_dim"]
    obs_horizon = config["obs_horizon"]
    obstacle_encode_dim = config["controller"]["networks"]["obstacle_encode_dim"]
    return ConditionalUnet1D(input_dim=action_dim, global_cond_dim=obs_dim * obs_horizon + obstacle_encode_dim)


def build_noise_scheduler_from_config(config: Dict):
    type_noise_scheduler = config["controller"]["noise_scheduler"]["type"]
    if type_noise_scheduler.lower() == "ddpm":
        return DDPMScheduler(
            num_train_timesteps=config["controller"]["noise_scheduler"]["ddpm"]["num_train_timesteps"],
            beta_schedule=config["controller"]["noise_scheduler"]["ddpm"]["beta_schedule"],
            clip_sample=config["controller"]["noise_scheduler"]["ddpm"]["clip_sample"],
            prediction_type=config["controller"]["noise_scheduler"]["ddpm"]["prediction_type"],
        )
    elif type_noise_scheduler.lower() == "ddim":
        return DDIMScheduler(
            num_train_timesteps=config["controller"]["noise_scheduler"]["ddim"]["num_train_timesteps"],
            beta_schedule=config["controller"]["noise_scheduler"]["ddim"]["beta_schedule"],
            clip_sample=config["controller"]["noise_scheduler"]["ddim"]["clip_sample"],
            prediction_type=config["controller"]["noise_scheduler"]["ddim"]["prediction_type"],
        )
    elif type_noise_scheduler.lower() == "dpmsolver":
        return DPMSolverMultistepScheduler(
            num_train_timesteps=config["controller"]["noise_scheduler"]["dpmsolver"]["num_train_timesteps"],
            beta_schedule=config["controller"]["noise_scheduler"]["dpmsolver"]["beta_schedule"],
            prediction_type=config["controller"]["noise_scheduler"]["dpmsolver"]["prediction_type"],
            use_karras_sigmas=config["controller"]["noise_scheduler"]["dpmsolver"]["use_karras_sigmas"],
        )
    else:
        raise NotImplementedError


class QuadrotorDiffusionPolicy(BaseController):
    def __init__(
        self,
        model: ConditionalUnet1D,
        noise_scheduler: DDPMScheduler,
        normalizer: BaseNormalizer,
        clf_cbf_controller: QuadrotorCLFCBFController,
        config: Dict,
        device: str = "cuda",
    ):
        self.device = device
        self.net = model
        self.noise_scheduler = noise_scheduler
        self.normalizer = normalizer

        self.set_config(config)
        self.net.to(self.device)

        self.clf_cbf_controller = clf_cbf_controller
        self.use_clf_cbf_guidance = clf_cbf_controller is not None

    def predict_action(self, obs_dict: Dict[str, List]) -> np.ndarray:
        # stack the observations
        obs_seq = np.stack(obs_dict["state"])
        # normalize observation and make it 1D
        nobs = self.normalizer.normalize_data(obs_seq, stats=self.norm_stats["obs"])
        nobs = nobs.flatten()
        # concat obstacle information to observations
        nobs = np.concatenate([nobs] + obs_dict["obs_encode"], axis=0)
        # device transfer
        nobs = torch.from_numpy(nobs).to(self.device, dtype=torch.float32)

        # infer action
        with torch.no_grad():
            # reshape observation to (1, obs_horizon*obs_dim+obstacle_encode_dim)
            obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)

            # initialize action from Gaussian noise
            noisy_action = torch.randn((1, self.pred_horizon, self.action_dim), device=self.device)
            naction = noisy_action

            # init scheduler
            self.noise_scheduler.set_timesteps(self.noise_scheduler.config.num_train_timesteps)

            # denoise
            denoise_timesteps = (
                self.noise_scheduler.timesteps[:1] if self.use_single_step_inference else self.noise_scheduler.timesteps
            )
            for k in denoise_timesteps:
                # predict noise
                noise_pred = self.net(sample=naction, timestep=k, global_cond=obs_cond)
                # inverse diffusion step (remove noise)
                if self.use_single_step_inference:
                    naction = noisy_action - noise_pred
                else:
                    naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample

                if self.use_clf_cbf_guidance:
                    diffusing_action = self.normalizer.unnormalize_data(
                        naction.detach().to("cpu").numpy().squeeze(), stats=self.norm_stats["act"]
                    )  # (pred_horizon, 6)
                    if k < self.clf_cbf_controller.denoising_guidance_step:
                        refined_action = diffusing_action.copy()
                        for idx, act in enumerate(diffusing_action):
                            clf_cbf_obs, pred_control, target_position = self._preprocess_cbf_clf_input(
                                obs_dict, act, diffusing_action
                            )
                            safe_yz_velocity = self.clf_cbf_controller.predict_action(
                                obs_dict=clf_cbf_obs,
                                control=pred_control,
                                target_position=target_position,
                            )
                            refined_action[idx, ...] = self._calculate_refined_action_step(act, safe_yz_velocity)
                        naction = self.normalizer.normalize_data(np.array(refined_action), stats=self.norm_stats["act"])
                        naction = torch.from_numpy(naction).to(self.device, dtype=torch.float32).unsqueeze(0)

        # unnormalize action
        naction = naction.detach().to("cpu").numpy()
        # (1, pred_horizon, action_dim)
        naction = naction[0]
        action_pred = self.normalizer.unnormalize_data(naction, stats=self.norm_stats["act"])

        # only take action_horizon number of actions
        start = self.obs_horizon - 1
        end = start + self.action_horizon
        action = action_pred[start:end, :]  # (action_horizon, action_dim)

        return action

    def load_weights(self, ckpt_path: str):
        state_dict = torch.load(ckpt_path, map_location=self.device)
        self.net.load_state_dict(state_dict)

    def set_config(self, config: Dict):
        self.obs_horizon = config["obs_horizon"]
        self.action_horizon = config["action_horizon"]
        self.pred_horizon = config["pred_horizon"]
        self.action_dim = config["controller"]["networks"]["action_dim"]
        self.sampling_time = config["controller"]["common"]["sampling_time"]
        self.norm_stats = {
            "act": config["normalizer"]["action"],
            "obs": config["normalizer"]["observation"],
        }
        self.quadrotor_params = config["simulator"]
        self.use_single_step_inference = config.get("controller").get("common").get("use_single_step_inference", False)

    def calculate_force_command(self, state: np.ndarray, ref_state: np.ndarray) -> np.ndarray:
        y, y_dot, z, z_dot, phi, phi_dot = state
        yr, yr_dot, zr, zr_dot, phir, phir_dot = ref_state
        dt, m_q = self.quadrotor_params["dt"], self.quadrotor_params["m_q"]
        g, I_xx = self.quadrotor_params["g"], self.quadrotor_params["I_xx"]
        # Accelerations cannot be estimated reliably from position signals alone,
        # so finite differences of the velocities are used instead
        # est_zr_dot = (zr - z) / dt
        # est_phir_dot = (phir - phi) / dt
        zr_ddot = (zr_dot - z_dot) / dt
        phir_ddot = (phir_dot - phi_dot) / dt
        return np.array([m_q * (g + zr_ddot), I_xx * phir_ddot])

    def _preprocess_cbf_clf_input(
        self, obs_dict: Dict[str, List], pred_action: np.ndarray, diffusing_action: np.ndarray
    ):
        pred_state = pred_action
        pred_control = pred_action[[1, 3]]
        target_position_y, target_position_z = diffusing_action[-1, [0, 2]]  # myopic planning of CLF
        target_position = (target_position_y, target_position_z)
        obstacle_info = {"center": obs_dict["obs_center"], "radius": obs_dict["obs_radius"]}
        return {"state": pred_state, "obstacle_info": obstacle_info}, pred_control, target_position

    def _calculate_refined_action_step(self, pred_act, safe_yz_velocity):
        refined_step_action = pred_act.copy()
        refined_step_action[0] += safe_yz_velocity[0] * self.sampling_time
        refined_step_action[2] += safe_yz_velocity[1] * self.sampling_time
        refined_step_action[1] = safe_yz_velocity[0]
        refined_step_action[3] = safe_yz_velocity[1]
        refined_step_action[4] = -np.arctan(safe_yz_velocity[0] / safe_yz_velocity[1])
        # refined_step_action[5] = (refined_step_action[4] - pred_act[4]) / self.sampling_time
        return refined_step_action

--------------------------------------------------------------------------------
/core/dataset/quadrotor_dataset.py:
--------------------------------------------------------------------------------
import joblib
import torch
from typing import Dict
import numpy as np


def create_sample_indices(
    episode_ends: np.ndarray,
    sequence_length: int,
    pad_before: int = 0,
    pad_after: int = 0,
):
    indices = list()
    for i in range(len(episode_ends)):
        start_idx = 0
        if i > 0:
            start_idx = episode_ends[i - 1]
        end_idx = episode_ends[i]
        episode_length = end_idx - start_idx

        min_start = -pad_before
        max_start = episode_length - sequence_length + pad_after

        # range stops one idx before end
        for idx in range(min_start, max_start + 1):
            buffer_start_idx = max(idx, 0) + start_idx
            buffer_end_idx = min(idx + sequence_length, episode_length) + start_idx
            start_offset = buffer_start_idx - (idx + start_idx)
            end_offset = (idx + sequence_length + start_idx) - buffer_end_idx
            sample_start_idx = 0 + start_offset
            sample_end_idx = sequence_length - end_offset
            indices.append([buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx])
    indices = np.array(indices)
    return indices


def sample_sequence(
    train_data,
    sequence_length,
    buffer_start_idx,
    buffer_end_idx,
    sample_start_idx,
    sample_end_idx,
):
    result = dict()
    for key, input_arr in train_data.items():
        sample = input_arr[buffer_start_idx:buffer_end_idx]
        data = sample
        if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
            data = np.zeros(shape=(sequence_length,) + input_arr.shape[1:], dtype=input_arr.dtype)
            if sample_start_idx > 0:
                data[:sample_start_idx] = sample[0]
            if sample_end_idx < sequence_length:
                data[sample_end_idx:] = sample[-1]
            data[sample_start_idx:sample_end_idx] = sample
        result[key] = data
    return result


# normalize data
def get_data_stats(data):
    data = data.reshape(-1, data.shape[-1])
    stats = {"min": np.min(data, axis=0), "max": np.max(data, axis=0)}
    return stats


def normalize_data(data, stats):
    # normalize to [0, 1]
    ndata = (data - stats["min"]) / (stats["max"] - stats["min"])
    # normalize to [-1, 1]
    ndata = ndata * 2 - 1
    return ndata


def unnormalize_data(ndata, stats):
    ndata = (ndata + 1) / 2
    data = ndata * (stats["max"] - stats["min"]) + stats["min"]
    return data

80 | 81 | def preprocess_dataset(dataset): 82 | dataset["obs_encode"] = [] 83 | dataset["episode_ends"] = [] 84 | 85 | current_idx = 0 86 | for s, c in zip(dataset["state"], dataset["control"]): 87 | dataset["episode_ends"].append(current_idx + s.shape[0]) 88 | current_idx += s.shape[0] 89 | 90 | for s, info in zip(dataset["state"], dataset["info"]): 91 | obs_encode = np.zeros((7, 7)) 92 | for center, radius in zip(info["obs_center"], info["obs_radius"]): 93 | obs_encode[tuple((center - 1).astype(np.int32))] = radius 94 | dataset["obs_encode"].append(np.vstack([obs_encode.flatten()] * s.shape[0])) 95 | return dataset 96 | 97 | 98 | # dataset 99 | class PlanarQuadrotorStateDataset(torch.utils.data.Dataset): 100 | def __init__( 101 | self, 102 | dataset_path: str, 103 | config: Dict, 104 | ): 105 | self.set_config(config) 106 | # read from zarr dataset 107 | dataset_root = joblib.load(dataset_path) 108 | # proprocessing 109 | dataset_root = preprocess_dataset(dataset_root) 110 | # All demonstration episodes are concatinated in the first dimension N 111 | train_data = { 112 | # (N, action_dim) 113 | "action": np.concatenate(dataset_root["desired_state"], axis=0), 114 | # (N, obs_dim) 115 | "obs": np.concatenate(dataset_root["state"], axis=0), 116 | # (N, 7x7) 117 | "obstacle": np.concatenate(dataset_root["obs_encode"], axis=0), 118 | } 119 | # Marks one-past the last index for each episode 120 | episode_ends = dataset_root["episode_ends"] 121 | 122 | # compute start and end of each state-action sequence 123 | # also handles padding 124 | indices = create_sample_indices( 125 | episode_ends=episode_ends, 126 | sequence_length=self.pred_horizon, 127 | # add padding such that each timestep in the dataset are seen 128 | pad_before=self.obs_horizon - 1, 129 | pad_after=self.action_horizon - 1, 130 | ) 131 | 132 | # compute statistics and normalized data to [-1,1] 133 | stats = dict() 134 | normalized_train_data = dict() 135 | for key, data in train_data.items(): 136 | if 
key == "obstacle": 137 | continue 138 | stats[key] = get_data_stats(data) 139 | normalized_train_data[key] = normalize_data(data, stats[key]) 140 | normalized_train_data["obstacle"] = train_data["obstacle"] 141 | 142 | self.indices = indices 143 | self.stats = stats 144 | self.normalized_train_data = normalized_train_data 145 | 146 | def __len__(self): 147 | # all possible segments of the dataset 148 | return len(self.indices) 149 | 150 | def __getitem__(self, idx): 151 | # get the start/end indices for this datapoint 152 | buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx = self.indices[idx] 153 | 154 | # get normalized data using these indices 155 | nsample = sample_sequence( 156 | train_data=self.normalized_train_data, 157 | sequence_length=self.pred_horizon, 158 | buffer_start_idx=buffer_start_idx, 159 | buffer_end_idx=buffer_end_idx, 160 | sample_start_idx=sample_start_idx, 161 | sample_end_idx=sample_end_idx, 162 | ) 163 | 164 | # discard unused observations 165 | # (obs_horizon * obs_dim + obstacle_encode_dim) 166 | nsample["obs"] = np.concatenate( 167 | [nsample["obs"][: self.obs_horizon, :].flatten(), nsample["obstacle"][0]], 168 | axis=0, 169 | ) 170 | return nsample 171 | 172 | def set_config(self, config: Dict): 173 | self.config = config 174 | self.obs_horizon = config["obs_horizon"] 175 | self.action_horizon = config["action_horizon"] 176 | self.pred_horizon = config["pred_horizon"] 177 | -------------------------------------------------------------------------------- /core/env/planar_quadrotor.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional 3 | 4 | import jax 5 | from jax import numpy as jnp 6 | 7 | # Default quadrotor parameters 8 | m_q = 1.0 # kg 9 | I_xx = 0.1 # kg.m^2 10 | l_q = 0.3 # m, length of the quadrotor 11 | g = 9.81 12 | 13 | 14 | class PlanarQuadrotorEnv: 15 | def __init__(self, config: dict = None, state:
Optional[jnp.ndarray] = None): 16 | if config is None: 17 | self.m_q = m_q 18 | self.I_xx = I_xx 19 | self.g = g 20 | self.l_q = l_q 21 | else: 22 | self.m_q = config["simulator"]["m_q"] 23 | self.I_xx = config["simulator"]["I_xx"] 24 | self.g = config["simulator"]["g"] 25 | self.l_q = config["simulator"]["l_q"] 26 | 27 | self.state: Optional[jnp.ndarray] = state 28 | 29 | @partial(jax.jit, static_argnums=0) 30 | def step(self, state=None, control=[0, 0], dt: float = 0.01): 31 | """ 32 | dynamics with JAX-compatible code. 33 | 34 | Equations are from the Aerial Robotics coursera lecture 35 | https://www.coursera.org/lecture/robotics-flight/2-d-quadrotor-control-kakc6 36 | """ 37 | if state is None: 38 | state = self.state 39 | if state is None: 40 | raise Exception("state variable is not defined.") 41 | 42 | y, y_dot, z, z_dot, phi, phi_dot = state 43 | u1, u2 = control 44 | # Quadrotor dynamics 45 | y_ddot = -u1 * jnp.sin(phi) / self.m_q 46 | z_ddot = -self.g + u1 * jnp.cos(phi) / self.m_q 47 | phi_ddot = u2 / self.I_xx 48 | 49 | next_state = ( 50 | state 51 | + jnp.array( 52 | [ 53 | y_dot + y_ddot * dt, 54 | y_ddot, 55 | z_dot + z_ddot * dt, 56 | z_ddot, 57 | phi_dot + phi_ddot * dt, 58 | phi_ddot, 59 | ] 60 | ) 61 | * dt 62 | ) 63 | self.state = next_state 64 | return next_state 65 | -------------------------------------------------------------------------------- /core/networks/conditional_unet1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Union 4 | 5 | from core.networks.conv1d_components import Downsample1d, Upsample1d, Conv1dBlock 6 | from core.networks.positional_embedding import SinusoidalPosEmb 7 | 8 | 9 | class ConditionalResidualBlock1D(nn.Module): 10 | def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8): 11 | super().__init__() 12 | 13 | self.blocks = nn.ModuleList( 14 | [ 15 | Conv1dBlock(in_channels, out_channels, 
kernel_size, n_groups=n_groups), 16 | Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), 17 | ] 18 | ) 19 | 20 | # FiLM modulation https://arxiv.org/abs/1709.07871 21 | # predicts per-channel scale and bias 22 | cond_channels = out_channels * 2 23 | self.out_channels = out_channels 24 | self.cond_encoder = nn.Sequential( 25 | nn.Linear(cond_dim, cond_dim), nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1)) 26 | ) 27 | 28 | # make sure dimensions are compatible 29 | self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() 30 | 31 | def forward(self, x, cond): 32 | """ 33 | x : [ batch_size x in_channels x horizon ] 34 | cond : [ batch_size x cond_dim] 35 | 36 | returns: 37 | out : [ batch_size x out_channels x horizon ] 38 | """ 39 | out = self.blocks[0](x) 40 | embed = self.cond_encoder(cond) 41 | 42 | embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1) 43 | scale = embed[:, 0, ...] 44 | bias = embed[:, 1, ...] 45 | out = scale * out + bias 46 | 47 | out = self.blocks[1](out) 48 | out = out + self.residual_conv(x) 49 | return out 50 | 51 | 52 | class ConditionalUnet1D(nn.Module): 53 | def __init__( 54 | self, 55 | input_dim, 56 | global_cond_dim, 57 | diffusion_step_embed_dim=256, 58 | down_dims=[256, 512, 1024], 59 | kernel_size=5, 60 | n_groups=8, 61 | ): 62 | """ 63 | input_dim: Dim of actions. 64 | global_cond_dim: Dim of global conditioning applied with FiLM 65 | in addition to diffusion step embedding. This is usually obs_horizon * obs_dim 66 | diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k 67 | down_dims: Channel size for each UNet level. 68 | The length of this array determines the number of levels.
69 | kernel_size: Conv kernel size 70 | n_groups: Number of groups for GroupNorm 71 | """ 72 | 73 | super().__init__() 74 | all_dims = [input_dim] + list(down_dims) 75 | start_dim = down_dims[0] 76 | 77 | dsed = diffusion_step_embed_dim 78 | diffusion_step_encoder = nn.Sequential( 79 | SinusoidalPosEmb(dsed), 80 | nn.Linear(dsed, dsed * 4), 81 | nn.Mish(), 82 | nn.Linear(dsed * 4, dsed), 83 | ) 84 | cond_dim = dsed + global_cond_dim 85 | 86 | in_out = list(zip(all_dims[:-1], all_dims[1:])) 87 | mid_dim = all_dims[-1] 88 | self.mid_modules = nn.ModuleList( 89 | [ 90 | ConditionalResidualBlock1D( 91 | mid_dim, 92 | mid_dim, 93 | cond_dim=cond_dim, 94 | kernel_size=kernel_size, 95 | n_groups=n_groups, 96 | ), 97 | ConditionalResidualBlock1D( 98 | mid_dim, 99 | mid_dim, 100 | cond_dim=cond_dim, 101 | kernel_size=kernel_size, 102 | n_groups=n_groups, 103 | ), 104 | ] 105 | ) 106 | 107 | down_modules = nn.ModuleList([]) 108 | for ind, (dim_in, dim_out) in enumerate(in_out): 109 | is_last = ind >= (len(in_out) - 1) 110 | down_modules.append( 111 | nn.ModuleList( 112 | [ 113 | ConditionalResidualBlock1D( 114 | dim_in, 115 | dim_out, 116 | cond_dim=cond_dim, 117 | kernel_size=kernel_size, 118 | n_groups=n_groups, 119 | ), 120 | ConditionalResidualBlock1D( 121 | dim_out, 122 | dim_out, 123 | cond_dim=cond_dim, 124 | kernel_size=kernel_size, 125 | n_groups=n_groups, 126 | ), 127 | Downsample1d(dim_out) if not is_last else nn.Identity(), 128 | ] 129 | ) 130 | ) 131 | 132 | up_modules = nn.ModuleList([]) 133 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 134 | is_last = ind >= (len(in_out) - 1) 135 | up_modules.append( 136 | nn.ModuleList( 137 | [ 138 | ConditionalResidualBlock1D( 139 | dim_out * 2, 140 | dim_in, 141 | cond_dim=cond_dim, 142 | kernel_size=kernel_size, 143 | n_groups=n_groups, 144 | ), 145 | ConditionalResidualBlock1D( 146 | dim_in, 147 | dim_in, 148 | cond_dim=cond_dim, 149 | kernel_size=kernel_size, 150 | n_groups=n_groups, 151 | ), 152 | 
Upsample1d(dim_in) if not is_last else nn.Identity(), 153 | ] 154 | ) 155 | ) 156 | 157 | final_conv = nn.Sequential( 158 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), 159 | nn.Conv1d(start_dim, input_dim, 1), 160 | ) 161 | 162 | self.diffusion_step_encoder = diffusion_step_encoder 163 | self.up_modules = up_modules 164 | self.down_modules = down_modules 165 | self.final_conv = final_conv 166 | 167 | def forward( 168 | self, 169 | sample: torch.Tensor, 170 | timestep: Union[torch.Tensor, float, int], 171 | global_cond=None, 172 | ): 173 | """ 174 | x: (B,T,input_dim) 175 | timestep: (B,) or int, diffusion step 176 | global_cond: (B,global_cond_dim) 177 | output: (B,T,input_dim) 178 | """ 179 | # (B,T,C) 180 | sample = sample.moveaxis(-1, -2) 181 | # (B,C,T) 182 | 183 | # 1. time 184 | timesteps = timestep 185 | if not torch.is_tensor(timesteps): 186 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 187 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) 188 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: 189 | timesteps = timesteps[None].to(sample.device) 190 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 191 | timesteps = timesteps.expand(sample.shape[0]) 192 | 193 | global_feature = self.diffusion_step_encoder(timesteps) 194 | 195 | if global_cond is not None: 196 | global_feature = torch.cat([global_feature, global_cond], axis=-1) 197 | 198 | x = sample 199 | h = [] 200 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): 201 | x = resnet(x, global_feature) 202 | x = resnet2(x, global_feature) 203 | h.append(x) 204 | x = downsample(x) 205 | 206 | for mid_module in self.mid_modules: 207 | x = mid_module(x, global_feature) 208 | 209 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): 210 | x = torch.cat((x, h.pop()), dim=1) 211 | x = resnet(x, global_feature) 212 | x = resnet2(x, 
global_feature) 213 | x = upsample(x) 214 | 215 | x = self.final_conv(x) 216 | 217 | # (B,C,T) 218 | x = x.moveaxis(-1, -2) 219 | # (B,T,C) 220 | return x 221 | -------------------------------------------------------------------------------- /core/networks/conv1d_components.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Downsample1d(nn.Module): 5 | def __init__(self, dim): 6 | super().__init__() 7 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 8 | 9 | def forward(self, x): 10 | return self.conv(x) 11 | 12 | 13 | class Upsample1d(nn.Module): 14 | def __init__(self, dim): 15 | super().__init__() 16 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 17 | 18 | def forward(self, x): 19 | return self.conv(x) 20 | 21 | 22 | class Conv1dBlock(nn.Module): 23 | """ 24 | Conv1d --> GroupNorm --> Mish 25 | """ 26 | 27 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 28 | super().__init__() 29 | 30 | self.block = nn.Sequential( 31 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 32 | nn.GroupNorm(n_groups, out_channels), 33 | nn.Mish(), 34 | ) 35 | 36 | def forward(self, x): 37 | return self.block(x) 38 | -------------------------------------------------------------------------------- /core/networks/positional_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class SinusoidalPosEmb(nn.Module): 7 | def __init__(self, dim): 8 | super().__init__() 9 | self.dim = dim 10 | 11 | def forward(self, x): 12 | device = x.device 13 | half_dim = self.dim // 2 14 | emb = math.log(10000) / (half_dim - 1) 15 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 16 | emb = x[:, None] * emb[None, :] 17 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 18 | return emb 19 | 
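`SinusoidalPosEmb` in `core/networks/positional_embedding.py` turns the diffusion step k into a fixed feature vector: with `half_dim = dim // 2` and frequencies `w_i = exp(-i * ln(10000) / (half_dim - 1))`, the output is `[sin(k * w_0), ..., sin(k * w_{half_dim-1}), cos(k * w_0), ..., cos(k * w_{half_dim-1})]`. The dependency-free sketch below reproduces that formula for a single scalar step so the layout can be inspected without PyTorch. `sinusoidal_embedding` is a hypothetical helper, not part of the repo, and it assumes `dim >= 4` so that `half_dim - 1` is non-zero.

```python
import math
from typing import List


def sinusoidal_embedding(k: float, dim: int) -> List[float]:
    """Hypothetical pure-Python mirror of SinusoidalPosEmb for one diffusion step k.

    Assumes dim >= 4 so that (half_dim - 1) is non-zero.
    """
    half_dim = dim // 2
    # Frequencies decay geometrically from 1 (i = 0) to 1/10000 (i = half_dim - 1).
    scale = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-i * scale) for i in range(half_dim)]
    args = [k * f for f in freqs]
    # Sines first, then cosines, matching torch.cat((emb.sin(), emb.cos()), dim=-1).
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]


# Step 0: every sine is 0 and every cosine is 1.
print(sinusoidal_embedding(0, 8))  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
```

As in the module, an odd `dim` yields `dim - 1` values, because both halves have length `dim // 2`. Comparing the output against `SinusoidalPosEmb(8)(torch.tensor([5.0]))[0]` should give matching numbers up to float32 rounding; this is a suggested sanity check, not one of the repo's tests.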
-------------------------------------------------------------------------------- /core/trainers/base_trainer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class BaseDiffusionPolicyTrainer: 5 | def __init__( 6 | self, 7 | ): 8 | pass 9 | 10 | @abstractmethod 11 | def train(self, *args, **kwargs): 12 | raise NotImplementedError 13 | 14 | @abstractmethod 15 | def save_checkpoint(self, path: str): 16 | raise NotImplementedError 17 | -------------------------------------------------------------------------------- /core/trainers/quadrotor_diffusion_policy_trainer.py: -------------------------------------------------------------------------------- 1 | from diffusers.optimization import get_scheduler 2 | from diffusers.training_utils import EMAModel 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from tqdm.auto import tqdm 7 | from typing import Dict, Optional 8 | 9 | from core.controllers.quadrotor_diffusion_policy import build_noise_scheduler_from_config 10 | from core.dataset.quadrotor_dataset import PlanarQuadrotorStateDataset 11 | from core.trainers.base_trainer import BaseDiffusionPolicyTrainer 12 | from utils.utils import get_device 13 | 14 | 15 | def build_dataloader_from_dataset_and_config(config: Dict, dataset: torch.utils.data.Dataset): 16 | return torch.utils.data.DataLoader( 17 | dataset, 18 | batch_size=config["trainer"]["batch_size"], 19 | shuffle=True, 20 | pin_memory=True, 21 | ) 22 | 23 | 24 | class PlanarQuadrotorDiffusionPolicyTrainer(BaseDiffusionPolicyTrainer): 25 | def __init__( 26 | self, net: nn.Module, dataset: PlanarQuadrotorStateDataset, config: Dict, device: Optional[str] = None 27 | ): 28 | self.net = net 29 | self.noise_scheduler = build_noise_scheduler_from_config(config) 30 | self.dataset = dataset 31 | self.set_config(config) 32 | self.device = get_device() if device is None else torch.device(device) 33 | 34 | 
self.net.to(self.device) 35 | 36 | # build optimizer 37 | if config["trainer"]["optimizer"]["name"].lower() == "adamw": 38 | self.optimizer = torch.optim.AdamW( 39 | params=self.net.parameters(), 40 | lr=config["trainer"]["optimizer"]["learning_rate"], 41 | weight_decay=config["trainer"]["optimizer"]["weight_decay"], 42 | ) 43 | else: 44 | raise NotImplementedError 45 | 46 | # build dataloader 47 | self.dataloader = build_dataloader_from_dataset_and_config(config, dataset) 48 | 49 | # set EMA 50 | self.use_ema = config["trainer"]["use_ema"] 51 | self.ema = EMAModel(parameters=self.net.parameters(), power=0.75) if self.use_ema else None 52 | 53 | def prepare_inputs(self, batch): 54 | # data normalized in dataset 55 | # device transfer 56 | obs_cond = batch["obs"].to(self.device, dtype=torch.float32) # FiLM conditioning 57 | action = batch["action"].to(self.device, dtype=torch.float32) 58 | batch_size = obs_cond.shape[0] 59 | return obs_cond, action, batch_size 60 | 61 | def optimization_step(self, action, obs_cond, batch_size): 62 | # sample noise to add to actions 63 | noise = torch.randn(action.shape, device=self.device) 64 | 65 | # sample a diffusion iteration for each data point 66 | timesteps = torch.randint( 67 | 0, self.noise_scheduler.config.num_train_timesteps, (batch_size,), device=self.device 68 | ).long() 69 | 70 | # add noise to the clean actions according to the noise magnitude at each diffusion iteration 71 | # (this is the forward diffusion process) 72 | noisy_actions = self.noise_scheduler.add_noise(action, noise, timesteps) 73 | 74 | # predict the noise residual 75 | noise_pred = self.net(noisy_actions, timesteps, global_cond=obs_cond) 76 | 77 | # L2 loss 78 | loss = nn.functional.mse_loss(noise_pred, noise) 79 | 80 | # optimize 81 | loss.backward() 82 | self.optimizer.step() 83 | self.optimizer.zero_grad() 84 | # step lr scheduler every batch 85 | # this is different from standard pytorch behavior 86 | self.lr_scheduler.step() 87 | # update
Exponential Moving Average of the model weights 89 | if self.use_ema: 90 | self.ema.step(self.net.parameters()) 91 | 92 | return loss 93 | 94 | def train(self, num_epochs: int, save_ckpt_epoch: int = None): 95 | if save_ckpt_epoch is None: 96 | save_ckpt_epoch = num_epochs 97 | 98 | # set learning rate scheduler 99 | self.lr_scheduler = get_scheduler( 100 | name=self.config["trainer"]["lr_scheduler"]["name"], 101 | optimizer=self.optimizer, 102 | num_warmup_steps=self.config["trainer"]["lr_scheduler"]["num_warmup_steps"], 103 | num_training_steps=len(self.dataloader) * num_epochs, 104 | ) 105 | 106 | # training loop 107 | trn_loss = [] 108 | with tqdm(range(num_epochs), desc="Epoch") as tglobal: 109 | # epoch loop 110 | for epoch_idx in tglobal: 111 | epoch_loss = list() 112 | # batch loop 113 | with tqdm(self.dataloader, desc="Batch", leave=False) as tepoch: 114 | for nbatch in tepoch: 115 | obs_cond, action, B = self.prepare_inputs(nbatch) 116 | loss = self.optimization_step(action, obs_cond, B) 117 | 118 | # logging 119 | loss_cpu = loss.item() 120 | epoch_loss.append(loss_cpu) 121 | tepoch.set_postfix(loss=loss_cpu) 122 | 123 | tglobal.set_postfix(loss=np.mean(epoch_loss)) 124 | trn_loss.append(np.mean(epoch_loss)) 125 | 126 | # save intermediate ckpt 127 | if (epoch_idx + 1) % save_ckpt_epoch == 0: 128 | self.save_checkpoint(path=f"ckpt_ep{epoch_idx}.ckpt") 129 | 130 | return trn_loss 131 | 132 | def save_checkpoint(self, path: str): 133 | save_model = self.net 134 | if self.config["trainer"]["use_ema"]: 135 | self.ema.copy_to(save_model.parameters()) 136 | torch.save(save_model.state_dict(), path) 137 | 138 | def set_config(self, config: Dict): 139 | self.config = config 140 | self.obs_horizon = config["obs_horizon"] 141 | self.obs_dim = config["controller"]["networks"]["obs_dim"] 142 | self.action_horizon = config["action_horizon"] 143 | self.action_dim = config["controller"]["networks"]["action_dim"] 144 | self.pred_horizon = config["pred_horizon"] 145 | 
self.obstacle_encode_dim = config["controller"]["networks"]["obstacle_encode_dim"] 146 | 147 | 148 | class PlanarQuadrotorDiffusionPolicyFineTuningTrainer(PlanarQuadrotorDiffusionPolicyTrainer): 149 | def __init__(self, net: nn.Module, dataset: PlanarQuadrotorStateDataset, config: Dict, device: str | None = None): 150 | super().__init__(net, dataset, config, device) 151 | 152 | def optimization_step(self, action, obs_cond, batch_size): 153 | # sample noise to add to actions 154 | noise = torch.randn(action.shape, device=self.device) 155 | 156 | # sample a diffusion iteration for each data point 157 | timesteps = torch.randint( 158 | self.noise_scheduler.config.num_train_timesteps - 1, 159 | self.noise_scheduler.config.num_train_timesteps, 160 | (batch_size,), 161 | device=self.device, 162 | ).long() 163 | 164 | # add noise to the clean images according to the noise magnitude at each diffusion iteration 165 | # (this is the forward diffusion process) 166 | noisy_actions = self.noise_scheduler.add_noise(action, noise, timesteps) 167 | 168 | # predict the noise residual 169 | noise_pred = self.net(noisy_actions, timesteps, global_cond=obs_cond) 170 | 171 | # L2 loss 172 | loss = nn.functional.mse_loss(noisy_actions - noise_pred, action) 173 | 174 | # optimize 175 | loss.backward() 176 | self.optimizer.step() 177 | self.optimizer.zero_grad() 178 | # step lr scheduler every batch 179 | # this is different from standard pytorch behavior 180 | self.lr_scheduler.step() 181 | 182 | # update Exponential Moving Average of the model weights 183 | if self.use_ema: 184 | self.ema.step(self.net.parameters()) 185 | 186 | return loss 187 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "970e2b9e-ccce-47b2-bccc-9471841cce44", 6 | "metadata": {}, 7 | "source": [ 8 | "## Installation required 
in Colab" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "tckkiSjNY0Is", 15 | "metadata": { 16 | "colab": { 17 | "base_uri": "https://localhost:8080/" 18 | }, 19 | "id": "tckkiSjNY0Is", 20 | "outputId": "627b5bff-c792-4954-947f-40baab0de93c" 21 | }, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Cloning into 'diffusion_policy_quadrotor'...\n", 28 | "remote: Enumerating objects: 155, done.\u001b[K\n", 29 | "remote: Counting objects: 100% (155/155), done.\u001b[K\n", 30 | "remote: Compressing objects: 100% (104/104), done.\u001b[K\n", 31 | "remote: Total 155 (delta 66), reused 111 (delta 38), pack-reused 0\u001b[K\n", 32 | "Receiving objects: 100% (155/155), 4.19 MiB | 25.69 MiB/s, done.\n", 33 | "Resolving deltas: 100% (66/66), done.\n" 34 | ] 35 | } 36 | ], 37 | "source": [ 38 | "! git clone https://github.com/shaoanlu/diffusion_policy_quadrotor.git" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "id": "7rrgEckSZEdx", 45 | "metadata": { 46 | "colab": { 47 | "base_uri": "https://localhost:8080/" 48 | }, 49 | "id": "7rrgEckSZEdx", 50 | "outputId": "5cd7bcf5-5b4c-4a48-996b-7257cdf7f03e" 51 | }, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "/content/diffusion_policy_quadrotor\n", 58 | "\u001b[0m\u001b[01;34massets\u001b[0m/ \u001b[01;34mconfig\u001b[0m/ \u001b[01;34mcore\u001b[0m/ demo.ipynb LICENSE pyproject.toml README.md \u001b[01;34mutils\u001b[0m/\n" 59 | ] 60 | } 61 | ], 62 | "source": [ 63 | "%cd diffusion_policy_quadrotor\n", 64 | "%ls" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "id": "rK2Hjrq_Y4T6", 71 | "metadata": { 72 | "id": "rK2Hjrq_Y4T6" 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "%%capture\n", 77 | "!pip3 install torch==1.13.1 torchvision==0.14.1 diffusers==0.18.2 jax==0.4.23 jaxlib==0.4.23" 78 | ] 79 | }, 80 | { 81 | "cell_type": 
"markdown", 82 | "id": "ee7a0434-3333-4502-9563-adcc12d4a413", 83 | "metadata": { 84 | "id": "ee7a0434-3333-4502-9563-adcc12d4a413" 85 | }, 86 | "source": [ 87 | "## Description\n", 88 | "\n", 89 | "This notebook demonstrates how a diffusion policy controller drives a quadrotor from (0, 0) to (5, 5) in the presence of randomly placed circular obstacles." 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 1, 95 | "id": "fffc0bb0-584b-4602-885f-a1cfa06d21b6", 96 | "metadata": { 97 | "id": "fffc0bb0-584b-4602-885f-a1cfa06d21b6" 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "%load_ext autoreload\n", 102 | "%autoreload 2" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 2, 108 | "id": "670f6c7f-93c9-47dc-b2e4-81ea8949b3f9", 109 | "metadata": { 110 | "id": "670f6c7f-93c9-47dc-b2e4-81ea8949b3f9" 111 | }, 112 | "outputs": [], 113 | "source": [ 114 | "import numpy as np\n", 115 | "import os\n", 116 | "import torch\n", 117 | "import yaml\n", 118 | "import collections\n", 119 | "from tqdm.auto import tqdm\n", 120 | "import gdown" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 3, 126 | "id": "fa8f1de5-924a-4889-aacc-bac7601012c5", 127 | "metadata": { 128 | "colab": { 129 | "base_uri": "https://localhost:8080/", 130 | "height": 87, 131 | "referenced_widgets": [ 132 | "49a3b00a700b40f8b1de991477efe2b9", 133 | "fb587d9dbfa84d72af2b3c45d46c2cc1", 134 | "7487091f00dd4e3bba795ebab6b0f5b1", 135 | "a817704da5184ef6922fb679bf0c395e", 136 | "a3a8dd4705b64b5f9b21b2e3d4ce24b8", 137 | "3a953a43ed574a4c8e8ebc52d5993347", 138 | "1900dfab0c6e4f8a8350f2d024fc22c6", 139 | "a9afd6d521f44216adb119a54f531376", 140 | "632b2841ae2d44b18e67fe288ba136a1", 141 | "01969ff8ad4140328cda2a14df8fe439", 142 | "fca3ddf1eb6d439ba86878e00e4560fd" 143 | ] 144 | }, 145 | "id": "fa8f1de5-924a-4889-aacc-bac7601012c5", 146 | "outputId": "730f9c66-b606-4840-f614-540a89379eef" 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "from
utils.normalizers import LinearNormalizer\n", 151 | "from core.controllers.quadrotor_diffusion_policy import QuadrotorDiffusionPolicy, build_networks_from_config, build_noise_scheduler_from_config\n", 152 | "from core.controllers.quadrotor_clf_cbf_qp import QuadrotorCLFCBFController\n", 153 | "from core.env.planar_quadrotor import PlanarQuadrotorEnv\n", 154 | "from utils.visualization import visualize_quadrotor_simulation_result" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "id": "06cf3d67-3f96-457e-a634-77b6e5f2e31e", 160 | "metadata": { 161 | "id": "06cf3d67-3f96-457e-a634-77b6e5f2e31e" 162 | }, 163 | "source": [ 164 | "## Load the config file" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 4, 170 | "id": "0c8c8f90-8ac2-4f87-943d-e41b7ff9ff6a", 171 | "metadata": { 172 | "id": "0c8c8f90-8ac2-4f87-943d-e41b7ff9ff6a", 173 | "scrolled": true 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "with open(\"config/config.yaml\", \"r\") as file:\n", 178 | " config = yaml.safe_load(file)\n", 179 | "\n", 180 | "# Whether to use a finetuned model trained following tricks mentioned in\n", 181 | "# [Fine-Tuning Image-Conditional Diffusion Models is Easier than You Think](https://arxiv.org/abs/2409.11355)\n", 182 | "use_single_step_inference = config.get(\"controller\").get(\"common\").get(\"use_single_step_inference\", False)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "id": "dd56c89a-520f-46b5-b9d6-ebef0160b184", 188 | "metadata": { 189 | "id": "dd56c89a-520f-46b5-b9d6-ebef0160b184" 190 | }, 191 | "source": [ 192 | "## Instantiate the controller" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 5, 198 | "id": "cb0707a7-3d66-46da-b722-2b2f42d8e099", 199 | "metadata": { 200 | "id": "cb0707a7-3d66-46da-b722-2b2f42d8e099" 201 | }, 202 | "outputs": [], 203 | "source": [ 204 | "clf_cbf_controller = QuadrotorCLFCBFController(config=config)\n", 205 | "\n", 206 | "controller = 
QuadrotorDiffusionPolicy(\n", 207 | " model=build_networks_from_config(config),\n", 208 | " noise_scheduler=build_noise_scheduler_from_config(config),\n", 209 | " normalizer=LinearNormalizer(),\n", 210 | " clf_cbf_controller=None, # set as clf_cbf_controller to enable CLF-CBF traj refinement\n", 211 | " config=config,\n", 212 | " device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n", 213 | ")" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "id": "cac28d88-e304-46b3-bc28-5fc54cd3705d", 219 | "metadata": { 220 | "id": "cac28d88-e304-46b3-bc28-5fc54cd3705d" 221 | }, 222 | "source": [ 223 | "## Download and load pretrained weights" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "id": "727e3937-dd21-46a2-9fff-c41c1d653022", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "# download pretrained weights from Google drive\n", 234 | "if use_single_step_inference:\n", 235 | " ckpts_path = \"ema_noise_pred_net2_ph96_oh2_ah10_v8_singlestepFT.ckpt\"\n", 236 | " if not os.path.isfile(ckpts_path):\n", 237 | " gdown.download(id=\"1UhxlzoQ6DOt0HZhokzU4ktl2MfzlA2Ii\", output=ckpts_path, quiet=False)\n", 238 | "else:\n", 239 | " ckpts_path = \"ema_noise_pred_net2_ph96_oh2_ah10_v8.ckpt\"\n", 240 | " if not os.path.isfile(ckpts_path):\n", 241 | " gdown.download(id=\"1-as6EqMLECxU7IVLZZIEDAcXMox_RkI_\", output=ckpts_path, quiet=False)\n", 242 | "\n", 243 | "# load weights\n", 244 | "controller.load_weights(ckpts_path)" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "id": "a0fda0f4-f5cb-4fa9-86da-abf991edfb4b", 250 | "metadata": { 251 | "id": "a0fda0f4-f5cb-4fa9-86da-abf991edfb4b" 252 | }, 253 | "source": [ 254 | "## Instantiate the simulator" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 7, 260 | "id": "6b0ae93e-8fe6-40af-92e6-12643aa21f6d", 261 | "metadata": { 262 | "id": "6b0ae93e-8fe6-40af-92e6-12643aa21f6d" 263 | }, 264 | "outputs": [], 265 | "source": [ 
266 | "sim = PlanarQuadrotorEnv(config)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "id": "a679998d-7a2e-48dc-b0a4-30c24314e307", 273 | "metadata": { 274 | "id": "a679998d-7a2e-48dc-b0a4-30c24314e307" 275 | }, 276 | "outputs": [], 277 | "source": [] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "id": "1f1c50bc-e529-432f-92a9-9bb3c6e9fdf7", 282 | "metadata": { 283 | "id": "1f1c50bc-e529-432f-92a9-9bb3c6e9fdf7" 284 | }, 285 | "source": [ 286 | "## Run simulation" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 8, 292 | "id": "eecaed67-9814-4a62-aa7f-2caf5d5b43aa", 293 | "metadata": { 294 | "id": "eecaed67-9814-4a62-aa7f-2caf5d5b43aa" 295 | }, 296 | "outputs": [], 297 | "source": [ 298 | "def generate_random_obstacles():\n", 299 | " num_obstacles = np.random.randint(1, 8)\n", 300 | " obs_center, obs_radius = np.empty((num_obstacles, 2)), np.ones((num_obstacles,))\n", 301 | " for obs_idx in range(num_obstacles): # set XY position of each obstacle\n", 302 | " obs_center[obs_idx, ...] 
= np.random.randint(1, 8, size=(2,))\n", 303 | " obs_radius[obs_idx] = np.random.uniform(0.2, 1.5)\n", 304 | " return obs_center, obs_radius\n", 305 | "\n", 306 | "def encode_obstacle_info(obs_center: np.ndarray, obs_radius: np.ndarray):\n", 307 | " obs_encode = np.zeros((7, 7))\n", 308 | " for center, radius in zip(obs_center, obs_radius):\n", 309 | " obs_encode[tuple((center-1).astype(np.int32))] = radius\n", 310 | " obs_encode = obs_encode.flatten()\n", 311 | " return obs_encode" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 9, 317 | "id": "3250c190-d380-423c-8455-21a6c1b2c093", 318 | "metadata": { 319 | "colab": { 320 | "base_uri": "https://localhost:8080/", 321 | "height": 67, 322 | "referenced_widgets": [ 323 | "fa8ca951c0f54d3c965b136791c54f60", 324 | "14869b876e6143a9848a8421b8da93c3", 325 | "b00b242d48c34b739b335e6137bbec52", 326 | "ba8a57462eb4492880c4980ee126f888", 327 | "7afde10b06494393ab9b53818c580fdc", 328 | "bc399470fb094590a12ba13e98c1f6a8", 329 | "eba7a2b036374b1c9c329d884ce16ad4", 330 | "13066600e23245888e019523e0200005", 331 | "304615b68c9d496ea35928d1546df4ca", 332 | "3de6654be9de4d22b15fb28ad7625f57", 333 | "81e23a9379b34c0cb78e375d47a63a89" 334 | ] 335 | }, 336 | "id": "3250c190-d380-423c-8455-21a6c1b2c093", 337 | "outputId": "4c9609f0-a944-42f7-909c-c23cab24dbf0" 338 | }, 339 | "outputs": [ 340 | { 341 | "data": { 342 | "application/vnd.jupyter.widget-view+json": { 343 | "model_id": "3d93b38cf8264ebe80428df1989880d1", 344 | "version_major": 2, 345 | "version_minor": 0 346 | }, 347 | "text/plain": [ 348 | "Eval: 0%| | 0/400 [00:00 max_steps:\n", 400 | " done = True\n", 401 | " if done:\n", 402 | " break" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "id": "b006ad91-e853-4a03-80c5-7ecb142bc475", 409 | "metadata": { 410 | "id": "b006ad91-e853-4a03-80c5-7ecb142bc475" 411 | }, 412 | "outputs": [], 413 | "source": [] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | 
"id": "602bdd60-2b90-48db-a29b-7f8645501142", 418 | "metadata": { 419 | "id": "602bdd60-2b90-48db-a29b-7f8645501142" 420 | }, 421 | "source": [ 422 | "## Visualize result" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 10, 428 | "id": "b92bcec2-d95d-4361-9b92-991346d83950", 429 | "metadata": { 430 | "colab": { 431 | "base_uri": "https://localhost:8080/", 432 | "height": 435 433 | }, 434 | "id": "b92bcec2-d95d-4361-9b92-991346d83950", 435 | "outputId": "0d8c5a83-8002-4deb-a87b-48a04af6d0af" 436 | }, 437 | "outputs": [ 438 | { 439 | "data": { 440 | "image/png": "", 441 | "text/plain": [ 442 | "
" 443 | ] 444 | }, 445 | "metadata": {}, 446 | "output_type": "display_data" 447 | } 448 | ], 449 | "source": [ 450 | "states = np.array(states)\n", 451 | "visualize_quadrotor_simulation_result(sim, states, obs_center=obs_center, obs_radius=obs_radius)" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": null, 457 | "id": "33740290-22ae-45bd-822f-0bea782691bd", 458 | "metadata": { 459 | "id": "33740290-22ae-45bd-822f-0bea782691bd" 460 | }, 461 | "outputs": [], 462 | "source": [] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "id": "f15cae17-f337-4ba0-a4ea-5a10a8bc0390", 468 | "metadata": { 469 | "id": "f15cae17-f337-4ba0-a4ea-5a10a8bc0390" 470 | }, 471 | "outputs": [], 472 | "source": [] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": null, 477 | "id": "7f889988-1fef-4c11-b113-457247db185c", 478 | "metadata": { 479 | "id": "7f889988-1fef-4c11-b113-457247db185c" 480 | }, 481 | "outputs": [], 482 | "source": [] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": null, 487 | "id": "6c503b86-5211-4403-a830-b8a100e73b97", 488 | "metadata": { 489 | "id": "6c503b86-5211-4403-a830-b8a100e73b97" 490 | }, 491 | "outputs": [], 492 | "source": [] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "id": "f0772564-5e68-4884-a266-f5467db4a433", 498 | "metadata": { 499 | "id": "f0772564-5e68-4884-a266-f5467db4a433" 500 | }, 501 | "outputs": [], 502 | "source": [] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | "id": "3c1ad555-40ce-4917-b0f7-e26dce4d485d", 508 | "metadata": { 509 | "id": "3c1ad555-40ce-4917-b0f7-e26dce4d485d" 510 | }, 511 | "outputs": [], 512 | "source": [] 513 | } 514 | ], 515 | "metadata": { 516 | "accelerator": "GPU", 517 | "colab": { 518 | "gpuType": "T4", 519 | "provenance": [] 520 | }, 521 | "kernelspec": { 522 | "display_name": "Python 3 (ipykernel)", 523 | "language": "python", 524 | "name": 
"python3" 525 | }, 526 | "language_info": { 527 | "codemirror_mode": { 528 | "name": "ipython", 529 | "version": 3 530 | }, 531 | "file_extension": ".py", 532 | "mimetype": "text/x-python", 533 | "name": "python", 534 | "nbconvert_exporter": "python", 535 | "pygments_lexer": "ipython3", 536 | "version": "3.10.14" 537 | }, 538 | "widgets": { 539 | "application/vnd.jupyter.widget-state+json": { 540 | "state": {}, 541 | "version_major": 2, 542 | "version_minor": 0 543 | } 544 | } 545 | }, 546 | "nbformat": 4, 547 | "nbformat_minor": 5 548 | } 549 | -------------------------------------------------------------------------------- /learning_note.md: -------------------------------------------------------------------------------- 1 | ## Learning note 2 | ### Main insight 3 | - The model should learn the residuals (gradient of the denoising) if possible. This greatly stabilizes the training. 4 | - Advantages of diffusion models: 1) the capability to model multi-modality, 2) stable training, and 3) temporal consistency of the output. 5 | - Iteratively add training data covering failure modes to turn extrapolation into interpolation. 6 | ### Scribbles 7 | - The trained policy does not always reach the goal without collision (its training data contains no collisions). 8 | - It is unable to recover from out-of-distribution (OOD) inputs. 9 | - A long observation horizon might hurt performance, possibly because it increases the risk of overfitting. 10 | - The diffusion policy [paper](https://arxiv.org/pdf/2303.04137) also discusses this in its appendix. 11 | - I feel we don't need diffusion models for the simple task in this repo; supervised learning might be equally effective. 12 | - The controller performs worse when extreme values (maximum or minimum) are present in the conditioning vector. 13 | - For instance, it collides more often with obstacles of the maximum radius, 1.5. 14 | - Collect more data so that everything becomes interpolation instead of extrapolation.
15 | - Even though the loss curve appears saturated, the performance of the controller can still improve as training continues. 16 | - The training loss curves of the diffusion model are extremely smooth, btw. 17 | - On the other hand, it can be difficult to tell whether the model is overfitting by looking at the trajectories or the denoising process. 18 | - But in general I feel there is little harm in training the diffusion model for as long as possible. 19 | - DDPM and DDIM samplers yield the best results. 20 | - Inference is not real-time. The controller is set to run at 100 Hz. 21 | 22 | ### Possible reasons for failures in collision avoidance 23 | 1. The training data contains no collisions. 24 | 2. A policy learned with imitation learning can accumulate errors during closed-loop control. 25 | 26 | When the quadrotor gets too close to the obstacles (due to 2), the input state becomes OOD (due to 1), so the diffusion policy is unable to recover from such situations. 27 | 28 | - Possible fix: add training data in which the quadrotor recovers from collisions. 29 | 30 | ### Things that didn't work 31 | - Tried encoding distances to each obstacle. Did not observe improvement in collision avoidance. 32 | - Tried using a vision encoder to replace the obstacle encoding. Didn't see a performance improvement. -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | 3 | # lint.select = ["A", "ANN", "B", "C90", "D", "E", "F", "I", "N", "COM", "DTZ", "PD", "RUF", "TID", "UP", "W"] 4 | # lint.ignore = ["D203", "D212"] 5 | 6 | lint.fixable = ["I", "RUF100"] 7 | lint.unfixable = [] 8 | 9 | # Exclude a variety of commonly ignored directories.
10 | exclude = [ 11 | ".bzr", 12 | ".direnv", 13 | ".eggs", 14 | ".git", 15 | ".hg", 16 | ".mypy_cache", 17 | ".nox", 18 | ".pants.d", 19 | ".pytype", 20 | ".ruff_cache", 21 | ".svn", 22 | ".tox", 23 | ".venv", 24 | "__pypackages__", 25 | "_build", 26 | "buck-out", 27 | "build", 28 | "dist", 29 | "node_modules", 30 | "venv", 31 | ] 32 | 33 | line-length = 120 34 | 35 | # Allow unused variables when underscore-prefixed. 36 | lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 37 | 38 | target-version = "py310" 39 | 40 | [tool.ruff.lint.mccabe] 41 | max-complexity = 10 42 | 43 | 44 | [tool.pyright] 45 | typeCheckingMode = "lazy" 46 | defineConstant = { DEBUG = true } 47 | 48 | reportMissingImports = false 49 | reportMissingTypeStubs = false 50 | reportInvalidStringEscapeSequence = false 51 | 52 | pythonVersion = "3.10" 53 | pythonPlatform = "Linux" 54 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/tests/__init__.py -------------------------------------------------------------------------------- /tests/config/test_config.yaml: -------------------------------------------------------------------------------- 1 | name: "planar_quadrotor" 2 | 3 | pred_horizon: 96 4 | obs_horizon: 2 5 | action_horizon: 10 6 | 7 | controller: 8 | networks: 9 | obs_dim: 6 10 | action_dim: 6 11 | obstacle_encode_dim: 49 12 | noise_scheduler: 13 | type: "ddpm" 14 | ddpm: 15 | num_train_timesteps: 100 # number of diffusion iterations 16 | beta_schedule: "squaredcos_cap_v2" 17 | clip_sample: true # required when predict_epsilon=False 18 | prediction_type: "epsilon" 19 | ddim: # faster inference 20 | num_train_timesteps: 100 21 | beta_schedule: "squaredcos_cap_v2" 22 | clip_sample: true 23 | prediction_type: "epsilon" 24 | dpmsolver: # 
faster inference, experimental 25 | num_train_timesteps: 100 26 | beta_schedule: "squaredcos_cap_v2" 27 | prediction_type: "epsilon" 28 | use_karras_sigmas: true 29 | 30 | trainer: 31 | use_ema: true 32 | batch_size: 8 33 | optimizer: 34 | name: "adamw" 35 | learning_rate: 1.0e-4 36 | weight_decay: 1.0e-6 37 | lr_scheduler: 38 | name: "cosine" 39 | num_warmup_steps: 500 40 | 41 | dataloader: 42 | batch_size: 3 43 | 44 | normalizer: 45 | action: 46 | min: [-2.6515920785069467, -10.169477462768555, -4.270973491668701, -13.883374419993748, -1.9637073183059692, -20.395111083984375] 47 | max: [8.0543811917305, 6.976394176483154, 8.477932775914669, 11.327190831233095, 2.5276688146591186, 18.05487823486328] 48 | observation: 49 | min: [-2.649836301803589, -9.564324378967285, -4.264063358306885, -13.772777557373047, -1.9476231336593628, -17.225351333618164] 50 | max: [8.05234146118164, 6.976299285888672, 8.474746704101562, 10.96181583404541, 2.5151419639587402, 18.054880142211914] 51 | 52 | simulator: 53 | m_q: 1.0 # kg 54 | I_xx: 0.1 # kg.m^2 55 | l_q: 0.3 # m, length of the quadrotor 56 | g: 9.81 57 | dt: 0.01 58 | -------------------------------------------------------------------------------- /tests/fixtures/test_dataset.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/tests/fixtures/test_dataset.joblib -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | import yaml 4 | 5 | from core.dataset.quadrotor_dataset import PlanarQuadrotorStateDataset 6 | 7 | 8 | class TestPlanarQuadrotorStateDataset(unittest.TestCase): 9 | def setUp(self): 10 | self.dataset_path = "tests/fixtures/test_dataset.joblib" 11 | 12 | with open("tests/config/test_config.yaml", 
"r") as file: 13 | self.config = yaml.safe_load(file) 14 | 15 | self.obs_horizon = self.config["obs_horizon"] 16 | self.action_horizon = self.config["action_horizon"] 17 | self.pred_horizon = self.config["pred_horizon"] 18 | self.action_dim = self.config["controller"]["networks"]["action_dim"] 19 | self.obs_dim = self.config["controller"]["networks"]["obs_dim"] 20 | self.obstacle_encode_dim = self.config["controller"]["networks"]["obstacle_encode_dim"] 21 | 22 | def test_init(self): 23 | dataset = PlanarQuadrotorStateDataset(dataset_path=self.dataset_path, config=self.config) 24 | self.assertTrue(isinstance(dataset, PlanarQuadrotorStateDataset)) 25 | 26 | def test_iter(self): 27 | batch_size = self.config["dataloader"]["batch_size"] 28 | dataset = PlanarQuadrotorStateDataset(dataset_path=self.dataset_path, config=self.config) 29 | dataloader = torch.utils.data.DataLoader( 30 | dataset, 31 | batch_size=batch_size, 32 | shuffle=True, 33 | pin_memory=True, 34 | ) 35 | 36 | # batch contents match the expected shapes 37 | batch = next(iter(dataloader)) 38 | self.assertEqual(batch["obs"].shape, (batch_size, self.obs_dim * self.obs_horizon + self.obstacle_encode_dim)) 39 | self.assertEqual(batch["action"].shape, (batch_size, self.pred_horizon, self.action_dim)) 40 | 41 | 42 | if __name__ == "__main__": 43 | unittest.main() 44 | -------------------------------------------------------------------------------- /tests/test_networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import unittest 3 | import yaml 4 | 5 | from core.networks.conditional_unet1d import ConditionalUnet1D 6 | 7 | 8 | class TestConditionalUnet1D(unittest.TestCase): 9 | def setUp(self): 10 | with open("tests/config/test_config.yaml", "r") as file: 11 | self.config = yaml.safe_load(file) 12 | 13 | self.obs_horizon = self.config["obs_horizon"] 14 | self.action_horizon = self.config["action_horizon"] 15 | self.pred_horizon =
self.config["pred_horizon"] 16 | self.action_dim = self.config["controller"]["networks"]["action_dim"] 17 | self.obs_dim = self.config["controller"]["networks"]["obs_dim"] 18 | self.obstacle_encode_dim = self.config["controller"]["networks"]["obstacle_encode_dim"] 19 | 20 | def test_init(self): 21 | noise_pred_net = ConditionalUnet1D( 22 | input_dim=self.action_dim, global_cond_dim=self.obs_dim * self.obs_horizon + self.obstacle_encode_dim 23 | ) 24 | self.assertTrue(isinstance(noise_pred_net, ConditionalUnet1D)) 25 | 26 | def test_inference(self): 27 | net = ConditionalUnet1D( 28 | input_dim=self.action_dim, global_cond_dim=self.obs_dim * self.obs_horizon + self.obstacle_encode_dim 29 | ) 30 | 31 | # example inputs 32 | noised_action = torch.randn((1, self.pred_horizon, self.action_dim)) 33 | obs = torch.zeros((1, self.obs_horizon * self.obs_dim + self.obstacle_encode_dim)) 34 | diffusion_iter = torch.zeros((1,)) 35 | 36 | # the noise prediction network 37 | # takes noisy action, diffusion iteration and observation as input 38 | # predicts the noise added to action 39 | noise = net(sample=noised_action, timestep=diffusion_iter, global_cond=obs.flatten(start_dim=1)) 40 | # simplified noise removal, only used to check the output shape (a real sampler applies the noise scheduler's update rule) 41 | denoised_action = noised_action - noise 42 | 43 | self.assertEqual(denoised_action.shape, (1, self.pred_horizon, self.action_dim)) 44 | 45 | 46 | if __name__ == "__main__": 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /utils/normalizers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class BaseNormalizer: 5 | def normalize_data(self, *args, **kwargs): 6 | raise NotImplementedError() 7 | 8 | def unnormalize_data(self, *args, **kwargs): 9 | raise NotImplementedError() 10 | 11 | 12 | class LinearNormalizer(BaseNormalizer): 13 | def __init__(self): 14 | pass 15 | 16 | def normalize_data(self, data, stats): 17 | # normalize to [0, 1] 18 | ndata = (data
- np.array(stats["min"])) / (np.array(stats["max"]) - np.array(stats["min"])) 19 | # normalize to [-1, 1] 20 | ndata = ndata * 2 - 1 21 | return ndata 22 | 23 | def unnormalize_data(self, ndata, stats): 24 | ndata = (ndata + 1) / 2 25 | data = ndata * (np.array(stats["max"]) - np.array(stats["min"])) + np.array(stats["min"]) 26 | return data 27 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_device(): 5 | if torch.cuda.is_available(): 6 | device = "cuda" 7 | else: 8 | device = "cpu" 9 | return torch.device(device) 10 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | 4 | 5 | def visualize_quadrotor_simulation_result( 6 | quadrotor, states: np.ndarray, obs_center: np.ndarray, obs_radius: np.ndarray 7 | ): 8 | def plot_obstacles(obs_center, obs_radius): 9 | for obs_p, obs_r in zip(obs_center, obs_radius): 10 | circle = plt.Circle( 11 | tuple(obs_p), 12 | obs_r, 13 | color="grey", 14 | fill=True, 15 | linestyle="--", 16 | linewidth=2, 17 | alpha=0.5, 18 | ) 19 | plt.gca().add_artist(circle) 20 | 21 | plt.figure() 22 | plt.gca().set_aspect("equal", adjustable="box") 23 | ys = [s[0] for s in states] 24 | zs = [s[2] for s in states] 25 | phis = [s[4] for s in states] 26 | # Generate circle for CBF 27 | plot_obstacles(obs_center, obs_radius) 28 | 29 | # Plot the trajectory 30 | plt.scatter(ys, zs) 31 | 32 | # Plot quadrotor pose 33 | y_ = [(y - quadrotor.l_q * np.cos(phi), y + quadrotor.l_q * np.cos(phi)) for (y, phi) in zip(ys[::10], phis[::10])] 34 | z_ = [(z - quadrotor.l_q * np.sin(phi), z + quadrotor.l_q * np.sin(phi)) for (z, phi) in zip(zs[::10], phis[::10])] 35 | for yy, zz in zip(y_, z_): 36 | 
plt.plot(yy, zz, marker="o", color="r", alpha=0.5) 37 | 38 | # plot start and end point 39 | init_x, init_y = states[0][0], states[0][2] 40 | plt.scatter(init_x, init_y, s=200, color="green", alpha=0.75, label="init. position") 41 | plt.scatter(5.0, 5.0, s=200, color="purple", alpha=0.75, label="target position") 42 | 43 | plt.xlim(-0.5, 7) 44 | plt.ylim(-0.5, 7) 45 | plt.grid() 46 | plt.show() 47 | --------------------------------------------------------------------------------
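`LinearNormalizer` in `utils/normalizers.py` rescales each dimension from its recorded `[min, max]` to `[-1, 1]`, and `unnormalize_data` inverts that mapping. The snippet below is a minimal standalone sketch of the round trip, not the repo's code: it re-implements the same formula as free functions (`normalize` and `unnormalize` are illustrative names, not part of the repo) and borrows the action stats from `tests/config/test_config.yaml`.

```python
import numpy as np

# Per-dimension action stats copied from tests/config/test_config.yaml
ACTION_STATS = {
    "min": [-2.6515920785069467, -10.169477462768555, -4.270973491668701,
            -13.883374419993748, -1.9637073183059692, -20.395111083984375],
    "max": [8.0543811917305, 6.976394176483154, 8.477932775914669,
            11.327190831233095, 2.5276688146591186, 18.05487823486328],
}


def normalize(data, stats):
    """Min-max scale each dimension to [-1, 1] (same formula as LinearNormalizer.normalize_data)."""
    lo, hi = np.asarray(stats["min"]), np.asarray(stats["max"])
    # first map to [0, 1], then stretch to [-1, 1]
    return (np.asarray(data) - lo) / (hi - lo) * 2 - 1


def unnormalize(ndata, stats):
    """Invert normalize(): map [-1, 1] back to the original [min, max] range."""
    lo, hi = np.asarray(stats["min"]), np.asarray(stats["max"])
    return (np.asarray(ndata) + 1) / 2 * (hi - lo) + lo


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    lo, hi = np.asarray(ACTION_STATS["min"]), np.asarray(ACTION_STATS["max"])
    # a batch of 4 six-dimensional actions drawn inside the recorded range
    actions = rng.uniform(lo, hi, size=(4, 6))
    n = normalize(actions, ACTION_STATS)
    print(n.min() >= -1, n.max() <= 1)  # every entry lands in [-1, 1]
    print(np.allclose(unnormalize(n, ACTION_STATS), actions))  # round trip recovers the input
```

Inputs outside the recorded range map outside `[-1, 1]` (an action above `max` normalizes to a value greater than 1), which may be one way the extreme-value and extrapolation issues in the learning note show up at the network input. The sketch also assumes `max != min` in every dimension; a constant dimension would divide by zero.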