├── .github
│   └── workflows
│       ├── ruff.yml
│       └── unittest.yml
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── df_clf_cbf_comp.jpg
│   ├── dp_sensitive_to_init_pos.jpg
│   ├── model_input.jpg
│   ├── result_anim.gif
│   └── result_plot.png
├── config
│   └── config.yaml
├── core
│   ├── controllers
│   │   ├── base_controller.py
│   │   ├── quadrotor_clf_cbf_qp.py
│   │   └── quadrotor_diffusion_policy.py
│   ├── dataset
│   │   └── quadrotor_dataset.py
│   ├── env
│   │   └── planar_quadrotor.py
│   ├── networks
│   │   ├── conditional_unet1d.py
│   │   ├── conv1d_components.py
│   │   └── positional_embedding.py
│   └── trainers
│       ├── base_trainer.py
│       └── quadrotor_diffusion_policy_trainer.py
├── demo.ipynb
├── learning_note.md
├── pyproject.toml
├── tests
│   ├── __init__.py
│   ├── config
│   │   └── test_config.yaml
│   ├── fixtures
│   │   └── test_dataset.joblib
│   ├── test_datasets.py
│   └── test_networks.py
├── train.ipynb
└── utils
    ├── normalizers.py
    ├── utils.py
    └── visualization.py
/.github/workflows/ruff.yml:
--------------------------------------------------------------------------------
1 | name: Ruff
2 | on: [push, pull_request]
3 | jobs:
4 | ruff:
5 | runs-on: ubuntu-latest
6 | steps:
7 | - uses: actions/checkout@v4
8 | - uses: chartboost/ruff-action@v1
9 |
--------------------------------------------------------------------------------
/.github/workflows/unittest.yml:
--------------------------------------------------------------------------------
1 | name: Python Unit Tests
2 |
3 | # Defines when the action should run. This workflow triggers on push and pull requests to the main branch.
4 | on:
5 | push:
6 | branches: [ main ]
7 | pull_request:
8 | branches: [ main ]
9 |
10 | # Jobs that the workflow will execute
11 | jobs:
12 | run-unittests:
13 | # The type of runner that the job will run on
14 | runs-on: ubuntu-latest
15 | strategy:
16 | matrix:
17 | python-version: ["3.10"]
18 |
19 | # Steps represent a sequence of tasks that will be executed as part of the job
20 | steps:
21 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
22 | - uses: actions/checkout@v3
23 |
24 | # Sets up a Python environment
25 | - name: Set up Python ${{ matrix.python-version }}
26 | uses: actions/setup-python@v4
27 | with:
28 | python-version: ${{ matrix.python-version }}
29 |
30 | # Install dependencies
31 | - name: Install dependencies
32 | run: |
33 | python -m pip install --upgrade pip
34 | pip install torch==1.13.1 torchvision==0.14.1 jax==0.4.23 jaxlib==0.4.23 diffusers==0.18.2 joblib
35 |
36 | # Run unittests using the Python unittest module
37 | - name: Run unittest
38 | run: python -m unittest discover
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | *.pyc
3 | tmp_imgs/*.*
4 | *-checkpoint.py
5 | *-checkpoint.ipynb
6 | *-checkpoint.yaml
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 shaoanlu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # diffusion_policy_quadrotor
2 | This repository demonstrates imitation learning with a diffusion policy for quadrotor control. The implementation is adapted from the official Diffusion Policy [repository](https://github.com/real-stanford/diffusion_policy) and adds a CLF-CBF controller that can be used to improve the safety of the generated trajectory.
3 |
4 | ## Result
5 | The control task is to drive the quadrotor from the initial position (0, 0) to the goal position (5, 5) without colliding with the obstacles. The animation shows the denoising process in which the diffusion policy predicts the future trajectory, followed by the quadrotor applying the actions.
6 |
7 |
8 |
9 |
10 | ## Usage
11 | The notebook `demo.ipynb` demonstrates a closed-loop simulation using the diffusion policy controller for quadrotor collision avoidance. You can run it in [Colab](https://colab.research.google.com/github/shaoanlu/diffusion_policy_quadrotor/blob/main/demo.ipynb).
12 |
13 | The training script is provided as `train.ipynb`.
14 |
15 | ## Dependencies
16 | The program was developed and tested in the following environment.
17 | - Python 3.10
18 | - `torch==2.2.1`
19 | - `jax==0.4.26`
20 | - `jaxlib==0.4.26`
21 | - `diffusers==0.27.2`
22 | - `torchvision==0.14.1`
23 | - `gdown` (to download pre-trained weights)
24 | - `joblib` (format of training data)
25 |
26 | ## Diffusion policy
27 | The policy takes as input 1) the latest N steps of observations $o_t$ (position and velocity) and 2) the encoding of obstacle information $O_{BST}$ (a flattened 7x7 grid with obstacle radii as values). The outputs are N steps of actions $a_t$ (future position and future velocity).
28 |
29 |
30 |
31 | *The quadrotor icon is from [flaticon](https://www.flaticon.com/free-icon/quadcopter_5447794).
32 |
33 |
34 | ### Deviation from the original implementation
35 | - Add a linear layer before the Mish activation to the condition encoder of `ConditionalResidualBlock1D`. This is to prevent the activation from truncating large negative values from the normalized observation.
36 | - A CLF-CBF-QP controller is implemented and used to modify the noisy actions during the denoising process of the policy. By default, this controller is disabled.
37 | - A finetuned model for single-step inference is used by default.
38 |
39 |
40 |
41 |
42 | ## References
43 | Papers
44 | - [Diffusion Policy: Visuomotor Policy Learning via Action Diffusion](https://diffusion-policy.cs.columbia.edu/) [arXiv:2303.04137]
45 | - [3D Diffusion Policy: Generalizable Visuomotor Policy Learning via Simple 3D Representations](https://3d-diffusion-policy.github.io/) [arXiv:2403.03954]
46 | - [Fine-Tuning Image-Conditional Diffusion Models is Easier than You Think](https://gonzalomartingarcia.github.io/diffusion-e2e-ft/) [arXiv:2409.11355]
47 |
48 | Videos and Lectures
49 | - [DeepLearning.AI: How Diffusion Models Work](https://www.deeplearning.ai/short-courses/how-diffusion-models-work/)
50 | - [[论文速览]Diffusion Policy: Visuomotor Policy Learning via Action Diff.[2303.04137]](https://www.bilibili.com/video/BV1Cu411Y7d7)
51 | - [[論文導讀]Planning with Diffusion for Flexible Behavior Synthesis解說 (含程式碼)](https://youtu.be/ciCcvWutle4)
52 | - [6.4210 Fall 2023 Lecture 18: Visuomotor Policies (via Behavior Cloning)](https://youtu.be/i-303tTtEig)
53 |
54 | ## Learning note
55 | ### Failure case: the diffusion policy controller failed to extrapolate from training data
56 | Figure: A failure case of the controller.
57 | - The left figure is a trajectory in the training data.
58 | - The middle figure is the closed-loop simulation result of the controller starting from the SAME initial position as the training data.
59 | - The right figure is the closed-loop simulation result of the controller starting from a DIFFERENT initial position, which resulted in a trajectory with collision.
60 |
61 |
62 |
63 | Refer to [`learning_note.md`](learning_note.md) for other notes.
64 |
65 |
--------------------------------------------------------------------------------
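The obstacle encoding described in the README (a flattened 7x7 grid with obstacle radii as values) can be sketched in plain Python. This is a minimal illustration, not the repository's implementation: the helper name `encode_obstacles` is hypothetical, and it assumes obstacle centers lie on integer coordinates 1..7 so that subtracting 1 gives the grid cell, which mirrors the indexing used in `core/dataset/quadrotor_dataset.py`.

```python
from typing import List, Sequence, Tuple

GRID_SIZE = 7  # the README describes a flattened 7x7 grid


def encode_obstacles(
    centers: Sequence[Tuple[float, float]], radii: Sequence[float]
) -> List[float]:
    """Return a flattened GRID_SIZE x GRID_SIZE grid holding each obstacle's radius."""
    grid = [[0.0] * GRID_SIZE for _ in range(GRID_SIZE)]
    for (center_y, center_z), radius in zip(centers, radii):
        # Assumption: centers sit on integer coordinates 1..7,
        # so (coordinate - 1) is the zero-based cell index.
        row, col = int(center_y - 1), int(center_z - 1)
        grid[row][col] = radius
    # Row-major flattening, the same order as numpy's default flatten()
    return [value for row in grid for value in row]


# Two hypothetical obstacles, chosen only for illustration
encoding = encode_obstacles(centers=[(2.0, 3.0), (5.0, 5.0)], radii=[0.4, 0.6])
print(len(encoding))        # 49
print(encoding[1 * 7 + 2])  # 0.4
```

Cells without an obstacle stay at zero, so the resulting vector has the length given by `obstacle_encode_dim: 49` in the config.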
/assets/df_clf_cbf_comp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/assets/df_clf_cbf_comp.jpg
--------------------------------------------------------------------------------
/assets/dp_sensitive_to_init_pos.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/assets/dp_sensitive_to_init_pos.jpg
--------------------------------------------------------------------------------
/assets/model_input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/assets/model_input.jpg
--------------------------------------------------------------------------------
/assets/result_anim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/assets/result_anim.gif
--------------------------------------------------------------------------------
/assets/result_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/assets/result_plot.png
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | name: "planar_quadrotor"
2 |
3 | pred_horizon: 96
4 | obs_horizon: 2
5 | action_horizon: 10
6 |
7 | controller:
8 | common:
9 | sampling_time: 0.1 # sec
10 | use_single_step_inference: true
11 | networks:
12 | obs_dim: 6
13 | action_dim: 6
14 | obstacle_encode_dim: 49
15 | noise_scheduler:
16 | type: "ddpm"
17 | ddpm:
18 | num_train_timesteps: 100 # number of diffusion iterations
19 | beta_schedule: "squaredcos_cap_v2"
20 | clip_sample: true # required when predict_epsilon=False
21 | prediction_type: "epsilon"
22 | ddim: # faster inference
23 | num_train_timesteps: 100
24 | beta_schedule: "squaredcos_cap_v2"
25 | clip_sample: true
26 | prediction_type: "epsilon"
27 | dpmsolver: # faster inference, experimental
28 | num_train_timesteps: 100
29 | beta_schedule: "squaredcos_cap_v2"
30 | prediction_type: "epsilon"
31 | use_karras_sigmas: true
32 |
33 | cbf_clf_controller:
34 | denoising_guidance_step: 100 # equals num_train_timesteps
35 | cbf_alpha: 10.0
36 | clf_gamma: 0.03
37 | penalty_slack_cbf: 1.0e+3
38 | penalty_slack_clf: 1.0
39 |
40 | trainer:
41 | use_ema: true
42 | batch_size: 256
43 | optimizer:
44 | name: "adamw"
45 | learning_rate: 1.0e-4
46 | weight_decay: 1.0e-6
47 | lr_scheduler:
48 | name: "cosine"
49 | num_warmup_steps: 500
50 |
51 | dataloader:
52 | batch_size: 256
53 |
54 | normalizer:
55 | action:
56 | min: [-2.6515920785069467, -10.169477462768555, -4.270973491668701, -13.883374419993748, -1.9637073183059692, -20.395111083984375]
57 | max: [8.0543811917305, 6.976394176483154, 8.477932775914669, 11.327190831233095, 2.5276688146591186, 18.05487823486328]
58 | observation:
59 | min: [-2.649836301803589, -9.564324378967285, -4.264063358306885, -13.772777557373047, -1.9476231336593628, -17.225351333618164]
60 | max: [8.05234146118164, 6.976299285888672, 8.474746704101562, 10.96181583404541, 2.5151419639587402, 18.054880142211914]
61 |
62 | simulator:
63 | m_q: 1.0 # kg
64 | I_xx: 0.1 # kg.m^2
65 | l_q: 0.3 # m, length of the quadrotor
66 | g: 9.81
67 | dt: 0.01
68 |
--------------------------------------------------------------------------------
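The `normalizer` entries above store per-dimension minimum and maximum values. The dataset code uses such statistics to map data to [-1, 1] and back. The following is a minimal pure-Python sketch of that min-max mapping; the function names and the two-dimensional statistics (rounded from the first two observation dimensions) are illustrative assumptions, not the repository's API.

```python
from typing import List, Sequence


def normalize(values: Sequence[float], lo: Sequence[float], hi: Sequence[float]) -> List[float]:
    """Map each value from [lo, hi] to [-1, 1], dimension by dimension."""
    return [2.0 * (v - l) / (h - l) - 1.0 for v, l, h in zip(values, lo, hi)]


def unnormalize(nvalues: Sequence[float], lo: Sequence[float], hi: Sequence[float]) -> List[float]:
    """Inverse of normalize: map each value from [-1, 1] back to [lo, hi]."""
    return [(n + 1.0) / 2.0 * (h - l) + l for n, l, h in zip(nvalues, lo, hi)]


# Illustrative statistics (assumption: rounded from the first two observation dimensions)
obs_min = [-2.65, -9.56]
obs_max = [8.05, 6.98]

sample = [2.70, -1.29]
normalized = normalize(sample, obs_min, obs_max)
restored = unnormalize(normalized, obs_min, obs_max)
print(normalized)
print(restored)
```

Values at the stored minimum map to -1 and values at the maximum map to 1. Inputs outside the recorded range fall outside [-1, 1].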
/core/controllers/base_controller.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Dict, List
3 |
4 |
5 | class BaseController:
6 | def __init__(self, device: str = "cuda"):
7 | self.device = device
8 |
9 | def predict_action(self, obs_dict: Dict[str, List]) -> np.ndarray:
10 | raise NotImplementedError()
11 |
12 | def reset(self):
13 | # reset state for stateful policies
14 | pass
15 |
--------------------------------------------------------------------------------
/core/controllers/quadrotor_clf_cbf_qp.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 | import numpy as np
3 | import osqp
4 | from scipy import sparse
5 |
6 | from core.controllers.base_controller import BaseController
7 |
8 |
9 | class QuadrotorCLFCBFController(BaseController):
10 | """
11 | A CLF-CBF safety filter assuming simple velocity-controlled dynamics
12 | y_dot = u1
13 | z_dot = u2
14 | The barrier function h is defined from the distance to each obstacle
15 | """
16 |
17 | def __init__(self, config: Dict, device: str = "cuda"):
18 | super().__init__(device)
19 | self.obstacle_info = {"center": [], "radius": []}
20 | self.set_config(config)
21 |
22 | def predict_action(self, obs_dict: Dict[str, List], control: np.ndarray, target_position: np.ndarray) -> np.ndarray:
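23 | # set_obstacle() resets the stored list on every call, so calling it in a loop would keep only the last obstacle; copy all obstacles at once instead
24 | self.obstacle_info = {"center": list(obs_dict["obstacle_info"]["center"]), "radius": list(obs_dict["obstacle_info"]["radius"])}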
25 |
26 | safe_command = self.clf_cbf_control(
27 | state=obs_dict["state"],
28 | control=control,
29 | obs_center=self.obstacle_info["center"],
30 | obs_radius=self.obstacle_info["radius"],
31 | cbf_alpha=self.cbf_alpha,
32 | clf_gamma=self.clf_gamma,
33 | penalty_slack_cbf=self.penalty_slack_cbf,
34 | penalty_slack_clf=self.penalty_slack_clf,
35 | target_position=target_position,
36 | )
37 | return safe_command
38 |
39 | def set_obstacle(self, center: tuple, radius: float):
40 | self.obstacle_info = {"center": [], "radius": []}
41 | self.obstacle_info["center"].append(center)
42 | self.obstacle_info["radius"].append(radius)
43 |
44 | def set_config(self, config: Dict):
45 | self.cbf_alpha = config["cbf_clf_controller"]["cbf_alpha"]
46 | self.clf_gamma = config["cbf_clf_controller"]["clf_gamma"]
47 | self.penalty_slack_cbf = config["cbf_clf_controller"]["penalty_slack_cbf"]
48 | self.penalty_slack_clf = config["cbf_clf_controller"]["penalty_slack_clf"]
49 | self.denoising_guidance_step = config["cbf_clf_controller"]["denoising_guidance_step"]
50 | self.quadrotor_params = config["simulator"]
51 |
52 | @staticmethod
53 | def _barrier_func(y, z, obs_y, obs_z, obs_r) -> float:
54 | return (y - obs_y) ** 2 + (z - obs_z) ** 2 - (obs_r) ** 2
55 |
56 | @staticmethod
57 | def _barrier_func_dot(y, z, obs_y, obs_z) -> list:
58 | return [2 * (y - obs_y), 2 * (z - obs_z)]
59 |
60 | @staticmethod
61 | def _lyapunoc_func(y, z, des_y, des_z) -> float:
62 | return (y - des_y) ** 2 + (z - des_z) ** 2
63 |
64 | @staticmethod
65 | def _lyapunov_func_dot(y, z, des_y, des_z) -> list:
66 | return [2 * (y - des_y), 2 * (z - des_z)]
67 |
68 | @staticmethod
69 | def _define_QP_problem_data(
70 | u1: float,
71 | u2: float,
72 | cbf_alpha: float,
73 | clf_gamma: float,
74 | penalty_slack_cbf: float,
75 | penalty_slack_clf: float,
76 | h: list,
77 | coeffs_dhdx: list,
78 | v: list,
79 | coeffs_dvdx: list,
80 | vmin=-15.0,
81 | vmax=15.0,
82 | ):
83 | # vmin and vmax come from the arguments above and bound the commanded velocities u1 and u2
84 |
85 | P = sparse.csc_matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, penalty_slack_cbf, 0], [0, 0, 0, penalty_slack_clf]])
86 | q = np.array([-u1, -u2, 0, 0])
87 | A = sparse.csc_matrix(
88 | [c for c in coeffs_dhdx]
89 | + [c for c in coeffs_dvdx]
90 | + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
91 | )
92 | lb = np.array([-cbf_alpha * h_ for h_ in h] + [-np.inf for _ in v] + [vmin, vmin, 0, 0])
93 | ub = np.array([np.inf for _ in h] + [-clf_gamma * v_ for v_ in v] + [vmax, vmax, np.inf, np.inf])
94 | return P, q, A, lb, ub
95 |
96 | @staticmethod
97 | def _get_quadrotor_state(state):
98 | y, y_dot, z, z_dot, phi, phi_dot = state
99 | return y, y_dot, z, z_dot, phi, phi_dot
100 |
101 | def _calculate_cbf_coeffs(self, state: np.ndarray, obs_center: List, obs_radius: List, minimal_distance: float):
102 | """
103 | Let the barrier function be h and the system state x; the CBF constraint is
104 | h_dot(x) >= - alpha * h - δ
105 | """
106 | h = [] # barrier values (here, remaining distance to each obstacle)
107 | coeffs_dhdx = [] # dhdt = dhdx * dxdt = dhdx * u
108 | for center, radius in zip(obs_center, obs_radius):
109 | y, _, z, _, _, _ = self._get_quadrotor_state(state)
110 | h.append(self._barrier_func(y, z, center[0], center[1], radius + minimal_distance))
111 | # Additional [1, 0] incorporates the CBF slack variable into the constraint
112 | coeffs_dhdx.append(self._barrier_func_dot(y, z, center[0], center[1]) + [1, 0])
113 | return h, coeffs_dhdx
114 |
115 | def _calculate_clf_coeffs(self, state: np.ndarray, target_y: float, _target_z: float):
116 | """
117 | Let the Lyapunov function be v and the system state x; the CLF constraint is
118 | v_dot(x) - δ <= - gamma * v
119 | """
120 | y, _, z, _, _, _ = self._get_quadrotor_state(state)
121 | v = [self._lyapunoc_func(y, z, target_y, _target_z)]
122 | # Additional [0, -1] incorporates the CLF slack variable into the constraint
123 | coeffs_dvdx = [self._lyapunov_func_dot(y, z, target_y, _target_z) + [0, -1]]
124 | return v, coeffs_dvdx
125 |
126 | def clf_cbf_control(
127 | self,
128 | state: np.ndarray,
129 | control: np.ndarray,
130 | obs_center: List,
131 | obs_radius: List,
132 | cbf_alpha: float = 15.0,
133 | clf_gamma: float = 0.01,
134 | penalty_slack_cbf: float = 1e2,
135 | penalty_slack_clf: float = 1.0,
136 | target_position: tuple = (5.0, 5.0),
137 | ):
138 | """
139 | Calculate the safe command by solving the following optimization problem
140 |
141 | minimize || u - u_nom ||^2 + k * δ^2
142 | u, δ
143 | s.t.
144 | h'(x) ≥ -𝛼 * h(x) - δ1
145 | v'(x) ≤ -γ * v(x) + δ2
146 | u_min ≤ u ≤ u_max
147 | 0 ≤ δ1,δ2 ≤ inf
148 | where
149 | u = [ux, uy] is the velocity command along the y and z axes (y_dot = ux, z_dot = uy).
150 | δ = [δ1, δ2] are the slack variables of the CBF and CLF constraints
151 | h(x) is the control barrier function and h'(x) its time derivative
152 | v(x) is the Lyapunov function and v'(x) its time derivative
153 |
154 | The problem above can be formulated as QP (ref: https://osqp.org/docs/solver/index.html)
155 |
156 | minimize 1/2 * x^T * Px + q^T x
157 | x
158 | s.t.
159 | l ≤ Ax ≤ u
160 | where
161 | x = [ux, uy, δ1, δ2]
162 |
163 | """
164 | u1, u2 = control
165 | target_y, target_z = target_position
166 |
167 | # Calculate values of the barrier function and coeffs in h_dot to state
168 | h, coeffs_dhdx = self._calculate_cbf_coeffs(state, obs_center, obs_radius, self.quadrotor_params["l_q"])
169 | # Calculate value of the lyapunov function and coeffs in v_dot to state
170 | v, coeffs_dvdx = self._calculate_clf_coeffs(state, target_y, target_z)
171 |
172 | # Define problem
173 | P, q, A, lb, ub = self._define_QP_problem_data(
174 | u1, u2, cbf_alpha, clf_gamma, penalty_slack_cbf, penalty_slack_clf, h, coeffs_dhdx, v, coeffs_dvdx
175 | )
176 |
177 | # Solve QP
178 | prob = osqp.OSQP()
179 | prob.setup(P, q, A, lb, ub, verbose=False, time_limit=0)
180 | # Solve QP problem
181 | res = prob.solve()
182 |
183 | safe_u1, safe_u2, _, _ = res.x
184 | return np.array([safe_u1, safe_u2])
185 |
--------------------------------------------------------------------------------
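The CBF constraint used by the controller above can be checked by hand for a single obstacle. The sketch below assumes the same velocity-controlled model (y_dot = u1, z_dot = u2) and the same squared-distance barrier function. The helper names and the numeric values are hypothetical and chosen only to make the arithmetic easy to follow; no QP is solved here.

```python
from typing import List, Sequence


def barrier(y: float, z: float, obs_y: float, obs_z: float, obs_r: float) -> float:
    """h(x) = squared distance to the obstacle center minus squared radius."""
    return (y - obs_y) ** 2 + (z - obs_z) ** 2 - obs_r ** 2


def barrier_grad(y: float, z: float, obs_y: float, obs_z: float) -> List[float]:
    """Gradient of h with respect to (y, z)."""
    return [2 * (y - obs_y), 2 * (z - obs_z)]


def cbf_satisfied(u: Sequence[float], y: float, z: float, obs_y: float, obs_z: float,
                  obs_r: float, alpha: float, slack: float = 0.0) -> bool:
    """Check h_dot + slack >= -alpha * h for the velocity command u = (u1, u2)."""
    h = barrier(y, z, obs_y, obs_z, obs_r)
    grad = barrier_grad(y, z, obs_y, obs_z)
    h_dot = grad[0] * u[0] + grad[1] * u[1]  # chain rule: dh/dt = dh/dx * u
    return h_dot + slack >= -alpha * h


# Quadrotor at (0, 0), obstacle at (2, 0) with radius 1, alpha = 1 (illustrative values)
# h = 4 - 1 = 3 and the gradient is [-4, 0], so the constraint is -4 * u1 >= -3
print(cbf_satisfied((1.0, 0.0), 0.0, 0.0, 2.0, 0.0, 1.0, alpha=1.0))  # False
print(cbf_satisfied((0.5, 0.0), 0.0, 0.0, 2.0, 0.0, 1.0, alpha=1.0))  # True
print(cbf_satisfied((0.0, 1.0), 0.0, 0.0, 2.0, 0.0, 1.0, alpha=1.0))  # True
```

A larger `alpha` lets the state approach the obstacle boundary faster, and a positive slack relaxes the constraint at a cost weighted by `penalty_slack_cbf` in the QP objective.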
/core/controllers/quadrotor_diffusion_policy.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | import numpy as np
4 | import torch
5 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
6 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
7 | from diffusers.schedulers.scheduling_dpmsolver_multistep import (
8 | DPMSolverMultistepScheduler,
9 | )
10 |
11 | from core.controllers.base_controller import BaseController
12 | from core.controllers.quadrotor_clf_cbf_qp import QuadrotorCLFCBFController
13 | from core.networks.conditional_unet1d import ConditionalUnet1D
14 | from utils.normalizers import BaseNormalizer
15 |
16 |
17 | def build_networks_from_config(config: Dict):
18 | action_dim = config["controller"]["networks"]["action_dim"]
19 | obs_dim = config["controller"]["networks"]["obs_dim"]
20 | obs_horizon = config["obs_horizon"]
21 | obstacle_encode_dim = config["controller"]["networks"]["obstacle_encode_dim"]
22 | return ConditionalUnet1D(input_dim=action_dim, global_cond_dim=obs_dim * obs_horizon + obstacle_encode_dim)
23 |
24 |
25 | def build_noise_scheduler_from_config(config: Dict):
26 | type_noise_scheduler = config["controller"]["noise_scheduler"]["type"]
27 | if type_noise_scheduler.lower() == "ddpm":
28 | return DDPMScheduler(
29 | num_train_timesteps=config["controller"]["noise_scheduler"]["ddpm"]["num_train_timesteps"],
30 | beta_schedule=config["controller"]["noise_scheduler"]["ddpm"]["beta_schedule"],
31 | clip_sample=config["controller"]["noise_scheduler"]["ddpm"]["clip_sample"],
32 | prediction_type=config["controller"]["noise_scheduler"]["ddpm"]["prediction_type"],
33 | )
34 | elif type_noise_scheduler.lower() == "ddim":
35 | return DDIMScheduler(
36 | num_train_timesteps=config["controller"]["noise_scheduler"]["ddim"]["num_train_timesteps"],
37 | beta_schedule=config["controller"]["noise_scheduler"]["ddim"]["beta_schedule"],
38 | clip_sample=config["controller"]["noise_scheduler"]["ddim"]["clip_sample"],
39 | prediction_type=config["controller"]["noise_scheduler"]["ddim"]["prediction_type"],
40 | )
41 | elif type_noise_scheduler.lower() == "dpmsolver":
42 | return DPMSolverMultistepScheduler(
43 | num_train_timesteps=config["controller"]["noise_scheduler"]["dpmsolver"]["num_train_timesteps"],
44 | beta_schedule=config["controller"]["noise_scheduler"]["dpmsolver"]["beta_schedule"],
45 | prediction_type=config["controller"]["noise_scheduler"]["dpmsolver"]["prediction_type"],
46 | use_karras_sigmas=config["controller"]["noise_scheduler"]["dpmsolver"]["use_karras_sigmas"],
47 | )
48 | else:
49 | raise NotImplementedError
50 |
51 |
52 | class QuadrotorDiffusionPolicy(BaseController):
53 | def __init__(
54 | self,
55 | model: ConditionalUnet1D,
56 | noise_scheduler: DDPMScheduler,
57 | normalizer: BaseNormalizer,
58 | clf_cbf_controller: QuadrotorCLFCBFController,
59 | config: Dict,
60 | device: str = "cuda",
61 | ):
62 | self.device = device
63 | self.net = model
64 | self.noise_scheduler = noise_scheduler
65 | self.normalizer = normalizer
66 |
67 | self.set_config(config)
68 | self.net.to(self.device)
69 |
70 | self.clf_cbf_controller = clf_cbf_controller
71 | self.use_clf_cbf_guidance = clf_cbf_controller is not None
72 |
73 | def predict_action(self, obs_dict: Dict[str, List]) -> np.ndarray:
74 | # stack the observations
75 | obs_seq = np.stack(obs_dict["state"])
76 | # normalize observation and make it 1D
77 | nobs = self.normalizer.normalize_data(obs_seq, stats=self.norm_stats["obs"])
78 | nobs = nobs.flatten()
79 | # concat obstacle information to observations
80 | nobs = np.concatenate([nobs] + obs_dict["obs_encode"], axis=0)
81 | # device transfer
82 | nobs = torch.from_numpy(nobs).to(self.device, dtype=torch.float32)
83 |
84 | # infer action
85 | with torch.no_grad():
86 | # reshape observation to (1, obs_horizon*obs_dim+obstacle_encode_dim)
87 | obs_cond = nobs.unsqueeze(0).flatten(start_dim=1)
88 |
89 | # initialize action from Gaussian noise
90 | noisy_action = torch.randn((1, self.pred_horizon, self.action_dim), device=self.device)
91 | naction = noisy_action
92 |
93 | # init scheduler
94 | self.noise_scheduler.set_timesteps(self.noise_scheduler.config.num_train_timesteps)
95 |
96 | # denoise
97 | denoise_timesteps = (
98 | self.noise_scheduler.timesteps[:1] if self.use_single_step_inference else self.noise_scheduler.timesteps
99 | )
100 | for k in denoise_timesteps:
101 | # predict noise
102 | noise_pred = self.net(sample=naction, timestep=k, global_cond=obs_cond)
103 | # inverse diffusion step (remove noise)
104 | if self.use_single_step_inference:
105 | naction = noisy_action - noise_pred
106 | else:
107 | naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample
108 |
109 | if self.use_clf_cbf_guidance:
110 | diffusing_action = self.normalizer.unnormalize_data(
111 | naction.detach().to("cpu").numpy().squeeze(), stats=self.norm_stats["act"]
112 | ) # (pred_horizon, 6)
113 | if k < self.clf_cbf_controller.denoising_guidance_step:
114 | refined_action = diffusing_action.copy()
115 | for idx, act in enumerate(diffusing_action):
116 | clf_cbf_obs, pred_control, target_position = self._preprocess_cbf_clf_input(
117 | obs_dict, act, diffusing_action
118 | )
119 | safe_yz_velocity = self.clf_cbf_controller.predict_action(
120 | obs_dict=clf_cbf_obs,
121 | control=pred_control,
122 | target_position=target_position,
123 | )
124 | refined_action[idx, ...] = self._calculate_refined_action_step(act, safe_yz_velocity)
125 | naction = self.normalizer.normalize_data(np.array(refined_action), stats=self.norm_stats["act"])
126 | naction = torch.from_numpy(naction).to(self.device, dtype=torch.float32).unsqueeze(0)
127 |
128 | # unnormalize action
129 | naction = naction.detach().to("cpu").numpy()
130 | # (1, pred_horizon, action_dim)
131 | naction = naction[0]
132 | action_pred = self.normalizer.unnormalize_data(naction, stats=self.norm_stats["act"])
133 |
134 | # only take action_horizon number of actions
135 | start = self.obs_horizon - 1
136 | end = start + self.action_horizon
137 | action = action_pred[start:end, :] # (action_horizon, action_dim)
138 |
139 | return action
140 |
141 | def load_weights(self, ckpt_path: str):
142 | state_dict = torch.load(ckpt_path, map_location=self.device)
143 | self.net.load_state_dict(state_dict)
144 |
145 | def set_config(self, config: Dict):
146 | self.obs_horizon = config["obs_horizon"]
147 | self.action_horizon = config["action_horizon"]
148 | self.pred_horizon = config["pred_horizon"]
149 | self.action_dim = config["controller"]["networks"]["action_dim"]
150 | self.sampling_time = config["controller"]["common"]["sampling_time"]
151 | self.norm_stats = {
152 | "act": config["normalizer"]["action"],
153 | "obs": config["normalizer"]["observation"],
154 | }
155 | self.quadrotor_params = config["simulator"]
156 | self.use_single_step_inference = config.get("controller").get("common").get("use_single_step_inference", False)
157 |
158 | def calculate_force_command(self, state: np.ndarray, ref_state: np.ndarray) -> np.ndarray:
159 | y, y_dot, z, z_dot, phi, phi_dot = state
160 | yr, yr_dot, zr, zr_dot, phir, phir_dot = ref_state
161 | (
162 | dt,
163 | m_q,
164 | ) = self.quadrotor_params["dt"], self.quadrotor_params["m_q"]
165 | g, I_xx = self.quadrotor_params["g"], self.quadrotor_params["I_xx"]
166 | # how on earth do you want to calculate acceleration from position signals
167 | # est_zr_dot = (zr - z) / dt
168 | # est_phir_dot = (phir - phi) / dt
169 | zr_ddot = (zr_dot - z_dot) / dt
170 | phir_ddot = (phir_dot - phi_dot) / dt
171 | return np.array([m_q * (g + zr_ddot), I_xx * phir_ddot])
172 |
173 | def _preprocess_cbf_clf_input(
174 | self, obs_dict: Dict[str, List], pred_action: np.ndarray, diffusing_action: np.ndarray
175 | ):
176 | pred_state = pred_action
177 | pred_control = pred_action[[1, 3]]
178 | target_position_y, target_position_z = diffusing_action[-1, [0, 2]] # myopic planning of CLF
179 | target_position = (target_position_y, target_position_z)
180 | obstacle_info = {"center": obs_dict["obs_center"], "radius": obs_dict["obs_radius"]}
181 | return {"state": pred_state, "obstacle_info": obstacle_info}, pred_control, target_position
182 |
183 | def _calculate_refined_action_step(self, pred_act, safe_yz_velocity):
184 | refined_step_action = pred_act.copy()
185 | refined_step_action[0] += safe_yz_velocity[0] * self.sampling_time
186 | refined_step_action[2] += safe_yz_velocity[1] * self.sampling_time
187 | refined_step_action[1] = safe_yz_velocity[0]
188 | refined_step_action[3] = safe_yz_velocity[1]
189 | refined_step_action[4] = -np.arctan(safe_yz_velocity[0] / safe_yz_velocity[1])
190 | # refinedstep_action[5] = (refinedstep_action[4] - pred_act[4]) / self.sampling_time
191 | return refined_step_action
192 |
--------------------------------------------------------------------------------
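The last step of `predict_action` keeps only a short window of the predicted sequence. The sketch below reproduces that slicing with plain lists, using the horizons from `config/config.yaml` (`pred_horizon: 96`, `obs_horizon: 2`, `action_horizon: 10`). The function name `select_executed_actions` is hypothetical, and integers stand in for the 6-dimensional action vectors.

```python
from typing import List, Sequence


def select_executed_actions(action_pred: Sequence, obs_horizon: int, action_horizon: int) -> List:
    """Keep action_horizon steps, starting at the step aligned with the latest observation."""
    start = obs_horizon - 1
    end = start + action_horizon
    return list(action_pred[start:end])


# Stand-in for a (pred_horizon, action_dim) prediction: one integer per time step
predicted = list(range(96))
executed = select_executed_actions(predicted, obs_horizon=2, action_horizon=10)
print(executed)  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```

Only 10 of the 96 predicted steps are returned by each call. The remaining steps are discarded, and the policy is presumably queried again with fresh observations, in a receding-horizon fashion.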
/core/dataset/quadrotor_dataset.py:
--------------------------------------------------------------------------------
1 | import joblib
2 | import torch
3 | from typing import Dict
4 | import numpy as np
5 |
6 |
7 | def create_sample_indices(
8 | episode_ends: np.ndarray,
9 | sequence_length: int,
10 | pad_before: int = 0,
11 | pad_after: int = 0,
12 | ):
13 | indices = list()
14 | for i in range(len(episode_ends)):
15 | start_idx = 0
16 | if i > 0:
17 | start_idx = episode_ends[i - 1]
18 | end_idx = episode_ends[i]
19 | episode_length = end_idx - start_idx
20 |
21 | min_start = -pad_before
22 | max_start = episode_length - sequence_length + pad_after
23 |
24 | # range stops one idx before end
25 | for idx in range(min_start, max_start + 1):
26 | buffer_start_idx = max(idx, 0) + start_idx
27 | buffer_end_idx = min(idx + sequence_length, episode_length) + start_idx
28 | start_offset = buffer_start_idx - (idx + start_idx)
29 | end_offset = (idx + sequence_length + start_idx) - buffer_end_idx
30 | sample_start_idx = 0 + start_offset
31 | sample_end_idx = sequence_length - end_offset
32 | indices.append([buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx])
33 | indices = np.array(indices)
34 | return indices
35 |
36 |
37 | def sample_sequence(
38 | train_data,
39 | sequence_length,
40 | buffer_start_idx,
41 | buffer_end_idx,
42 | sample_start_idx,
43 | sample_end_idx,
44 | ):
45 | result = dict()
46 | for key, input_arr in train_data.items():
47 | sample = input_arr[buffer_start_idx:buffer_end_idx]
48 | data = sample
49 | if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
50 | data = np.zeros(shape=(sequence_length,) + input_arr.shape[1:], dtype=input_arr.dtype)
51 | if sample_start_idx > 0:
52 | data[:sample_start_idx] = sample[0]
53 | if sample_end_idx < sequence_length:
54 | data[sample_end_idx:] = sample[-1]
55 | data[sample_start_idx:sample_end_idx] = sample
56 | result[key] = data
57 | return result
58 |
59 |
60 | # normalize data
61 | def get_data_stats(data):
62 | data = data.reshape(-1, data.shape[-1])
63 | stats = {"min": np.min(data, axis=0), "max": np.max(data, axis=0)}
64 | return stats
65 |
66 |
67 | def normalize_data(data, stats):
68 | # normalize to [0,1]
69 | ndata = (data - stats["min"]) / (stats["max"] - stats["min"])
70 | # normalize to [-1, 1]
71 | ndata = ndata * 2 - 1
72 | return ndata
73 |
74 |
75 | def unnormalize_data(ndata, stats):
76 | ndata = (ndata + 1) / 2
77 | data = ndata * (stats["max"] - stats["min"]) + stats["min"]
78 | return data
79 |
80 |
81 | def preprocess_dataset(dataset):
82 | dataset["obs_encode"] = []
83 | dataset["episode_ends"] = []
84 |
85 | current_idx = 0
86 | for s, c in zip(dataset["state"], dataset["control"]):
87 | dataset["episode_ends"].append(current_idx + s.shape[0])
88 | current_idx += s.shape[0]
89 |
90 | for s, info in zip(dataset["state"], dataset["info"]):
91 | obs_encode = np.zeros((7, 7))
92 | for center, radius in zip(info["obs_center"], info["obs_radius"]):
93 | obs_encode[tuple((center - 1).astype(np.int32))] = radius
94 | dataset["obs_encode"].append(np.vstack([obs_encode.flatten()] * s.shape[0]))
95 | return dataset
96 |
97 |
98 | # dataset
99 | class PlanarQuadrotorStateDataset(torch.utils.data.Dataset):
100 | def __init__(
101 | self,
102 | dataset_path: str,
103 | config: Dict,
104 | ):
105 | self.set_config(config)
106 | # load the dataset with joblib
107 | dataset_root = joblib.load(dataset_path)
108 | # preprocessing
109 | dataset_root = preprocess_dataset(dataset_root)
110 | # All demonstration episodes are concatenated along the first dimension N
111 | train_data = {
112 | # (N, action_dim)
113 | "action": np.concatenate(dataset_root["desired_state"], axis=0),
114 | # (N, obs_dim)
115 | "obs": np.concatenate(dataset_root["state"], axis=0),
116 | # (N, 7x7)
117 | "obstacle": np.concatenate(dataset_root["obs_encode"], axis=0),
118 | }
119 | # Marks one-past the last index for each episode
120 | episode_ends = dataset_root["episode_ends"]
121 |
122 | # compute start and end of each state-action sequence
123 | # also handles padding
124 | indices = create_sample_indices(
125 | episode_ends=episode_ends,
126 | sequence_length=self.pred_horizon,
127 | # add padding such that each timestep in the dataset is seen
128 | pad_before=self.obs_horizon - 1,
129 | pad_after=self.action_horizon - 1,
130 | )
131 |
132 | # compute statistics and normalized data to [-1,1]
133 | stats = dict()
134 | normalized_train_data = dict()
135 | for key, data in train_data.items():
136 | if key == "obstacle":
137 | continue
138 | stats[key] = get_data_stats(data)
139 | normalized_train_data[key] = normalize_data(data, stats[key])
140 | normalized_train_data["obstacle"] = train_data["obstacle"]
141 |
142 | self.indices = indices
143 | self.stats = stats
144 | self.normalized_train_data = normalized_train_data
145 |
146 | def __len__(self):
147 | # all possible segments of the dataset
148 | return len(self.indices)
149 |
150 | def __getitem__(self, idx):
151 | # get the start/end indices for this datapoint
152 | buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx = self.indices[idx]
153 |
154 | # get normalized data using these indices
155 | nsample = sample_sequence(
156 | train_data=self.normalized_train_data,
157 | sequence_length=self.pred_horizon,
158 | buffer_start_idx=buffer_start_idx,
159 | buffer_end_idx=buffer_end_idx,
160 | sample_start_idx=sample_start_idx,
161 | sample_end_idx=sample_end_idx,
162 | )
163 |
164 | # discard unused observations
165 | # (obs_horizon * obs_dim + obstacle_encode_dim)
166 | nsample["obs"] = np.concatenate(
167 | [nsample["obs"][: self.obs_horizon, :].flatten(), nsample["obstacle"][0]],
168 | axis=0,
169 | )
170 | return nsample
171 |
172 | def set_config(self, config: Dict):
173 | self.config = config
174 | self.obs_horizon = config["obs_horizon"]
175 | self.action_horizon = config["action_horizon"]
176 | self.pred_horizon = config["pred_horizon"]
177 |
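The normalization helpers and the edge padding in `sample_sequence` above can be illustrated on a toy episode. This is a minimal sketch, not the repository's code: the names `normalize`, `unnormalize`, and `pad_window` are hypothetical stand-ins, and the episode values are made up for illustration.

```python
import numpy as np


def normalize(data, stats):
    # Map each feature from [min, max] to [-1, 1], as normalize_data does.
    return (data - stats["min"]) / (stats["max"] - stats["min"]) * 2 - 1


def unnormalize(ndata, stats):
    # Inverse map from [-1, 1] back to [min, max], as unnormalize_data does.
    return (ndata + 1) / 2 * (stats["max"] - stats["min"]) + stats["min"]


def pad_window(sample, sequence_length, sample_start_idx, sample_end_idx):
    # Edge-pad a short slice to a fixed-length window, as sample_sequence does:
    # leading slots repeat the first row, trailing slots repeat the last row.
    data = np.zeros((sequence_length,) + sample.shape[1:], dtype=sample.dtype)
    data[:sample_start_idx] = sample[0]
    data[sample_end_idx:] = sample[-1]
    data[sample_start_idx:sample_end_idx] = sample
    return data


# Toy episode: 3 timesteps, 2 features (illustrative values only).
episode = np.array([[0.0, 10.0], [1.0, 20.0], [2.0, 30.0]])
stats = {"min": episode.min(axis=0), "max": episode.max(axis=0)}

normalized = normalize(episode, stats)
restored = unnormalize(normalized, stats)

# Window of length 5 with one padded step before and one after the data.
window = pad_window(normalized, sequence_length=5, sample_start_idx=1, sample_end_idx=4)
```

Note that a feature that is constant over the whole dataset has `max == min`, so the division produces NaN. Neither this sketch nor the helpers above guard against that case.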
--------------------------------------------------------------------------------
/core/env/planar_quadrotor.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional
3 |
4 | import jax
5 | from jax import numpy as jnp
6 |
7 | # Default quadrotor parameters
8 | m_q = 1.0 # kg
9 | I_xx = 0.1 # kg.m^2
10 | l_q = 0.3 # m, length of the quadrotor
11 | g = 9.81
12 |
13 |
14 | class PlanarQuadrotorEnv:
15 | def __init__(self, config: dict = None, state: Optional[jnp.ndarray] = None):
16 | if config is None:
17 | self.m_q = m_q
18 | self.I_xx = I_xx
19 | self.g = g
20 | self.l_q = l_q
21 | else:
22 | self.m_q = config["simulator"]["m_q"]
23 | self.I_xx = config["simulator"]["I_xx"]
24 | self.g = config["simulator"]["g"]
25 | self.l_q = config["simulator"]["l_q"]
26 |
27 | self.state: Optional[jnp.ndarray] = state
28 |
29 | @partial(jax.jit, static_argnums=0)
30 | def step(self, state=None, control=[0, 0], dt: float = 0.01):
31 | """
32 | dynamics with JAX-compatible code.
33 |
34 | Equations are from the Aerial Robotics coursera lecture
35 | https://www.coursera.org/lecture/robotics-flight/2-d-quadrotor-control-kakc6
36 | """
37 | if state is None:
38 | state = self.state
39 | if state is None:
40 | raise Exception("state variable is not defined.")
41 |
42 | y, y_dot, z, z_dot, phi, phi_dot = state
43 | u1, u2 = control
44 | # Quadrotor dynamics
45 | y_ddot = -u1 * jnp.sin(phi) / self.m_q
46 | z_ddot = -self.g + u1 * jnp.cos(phi) / self.m_q
47 | phi_ddot = u2 / self.I_xx
48 |
49 | next_state = (
50 | state
51 | + jnp.array(
52 | [
53 | y_dot + y_ddot * dt,
54 | y_ddot,
55 | z_dot + z_ddot * dt,
56 | z_ddot,
57 | phi_dot + phi_ddot * dt,
58 | phi_ddot,
59 | ]
60 | )
61 | * dt
62 | )
63 | self.state = next_state
64 | return next_state
65 |
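The update rule in `step` above can be checked by hand with a plain-Python version of the same equations. This is a sketch only: the function below is a hypothetical re-statement that uses the default parameters from the top of the file and drops JAX. Reading `u1` as total thrust and `u2` as roll torque is inferred from how they enter the dynamics.

```python
import math


def step(state, control, dt=0.01, m_q=1.0, I_xx=0.1, g=9.81):
    """One update of the planar quadrotor, mirroring PlanarQuadrotorEnv.step."""
    y, y_dot, z, z_dot, phi, phi_dot = state
    u1, u2 = control  # u1: total thrust, u2: roll torque (inferred from the equations)
    y_ddot = -u1 * math.sin(phi) / m_q
    z_ddot = -g + u1 * math.cos(phi) / m_q
    phi_ddot = u2 / I_xx
    # Each position is advanced with the already-updated velocity.
    return (
        y + (y_dot + y_ddot * dt) * dt,
        y_dot + y_ddot * dt,
        z + (z_dot + z_ddot * dt) * dt,
        z_dot + z_ddot * dt,
        phi + (phi_dot + phi_ddot * dt) * dt,
        phi_dot + phi_ddot * dt,
    )


rest = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
hover = step(rest, control=(1.0 * 9.81, 0.0))  # thrust equal to m_q * g
falling = step(rest, control=(0.0, 0.0))       # no thrust: free fall
```

With thrust equal to `m_q * g` and zero roll, the state stays at rest. With zero thrust, one 0.01 s step gives a vertical velocity of `-g * dt` and a height change of `-g * dt**2`.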
--------------------------------------------------------------------------------
/core/networks/conditional_unet1d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Union
4 |
5 | from core.networks.conv1d_components import Downsample1d, Upsample1d, Conv1dBlock
6 | from core.networks.positional_embedding import SinusoidalPosEmb
7 |
8 |
9 | class ConditionalResidualBlock1D(nn.Module):
10 | def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8):
11 | super().__init__()
12 |
13 | self.blocks = nn.ModuleList(
14 | [
15 | Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
16 | Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
17 | ]
18 | )
19 |
20 | # FiLM modulation https://arxiv.org/abs/1709.07871
21 | # predicts per-channel scale and bias
22 | cond_channels = out_channels * 2
23 | self.out_channels = out_channels
24 | self.cond_encoder = nn.Sequential(
25 | nn.Linear(cond_dim, cond_dim), nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1))
26 | )
27 |
28 | # make sure dimensions compatible
29 | self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
30 |
31 | def forward(self, x, cond):
32 | """
33 | x : [ batch_size x in_channels x horizon ]
34 | cond : [ batch_size x cond_dim]
35 |
36 | returns:
37 | out : [ batch_size x out_channels x horizon ]
38 | """
39 | out = self.blocks[0](x)
40 | embed = self.cond_encoder(cond)
41 |
42 | embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
43 | scale = embed[:, 0, ...]
44 | bias = embed[:, 1, ...]
45 | out = scale * out + bias
46 |
47 | out = self.blocks[1](out)
48 | out = out + self.residual_conv(x)
49 | return out
50 |
51 |
52 | class ConditionalUnet1D(nn.Module):
53 | def __init__(
54 | self,
55 | input_dim,
56 | global_cond_dim,
57 | diffusion_step_embed_dim=256,
58 | down_dims=[256, 512, 1024],
59 | kernel_size=5,
60 | n_groups=8,
61 | ):
62 | """
63 | input_dim: Dim of actions.
64 | global_cond_dim: Dim of global conditioning applied with FiLM
65 | in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
66 | diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
67 | down_dims: Channel size for each UNet level.
68 | The length of this array determines the number of levels.
69 | kernel_size: Conv kernel size
70 | n_groups: Number of groups for GroupNorm
71 | """
72 |
73 | super().__init__()
74 | all_dims = [input_dim] + list(down_dims)
75 | start_dim = down_dims[0]
76 |
77 | dsed = diffusion_step_embed_dim
78 | diffusion_step_encoder = nn.Sequential(
79 | SinusoidalPosEmb(dsed),
80 | nn.Linear(dsed, dsed * 4),
81 | nn.Mish(),
82 | nn.Linear(dsed * 4, dsed),
83 | )
84 | cond_dim = dsed + global_cond_dim
85 |
86 | in_out = list(zip(all_dims[:-1], all_dims[1:]))
87 | mid_dim = all_dims[-1]
88 | self.mid_modules = nn.ModuleList(
89 | [
90 | ConditionalResidualBlock1D(
91 | mid_dim,
92 | mid_dim,
93 | cond_dim=cond_dim,
94 | kernel_size=kernel_size,
95 | n_groups=n_groups,
96 | ),
97 | ConditionalResidualBlock1D(
98 | mid_dim,
99 | mid_dim,
100 | cond_dim=cond_dim,
101 | kernel_size=kernel_size,
102 | n_groups=n_groups,
103 | ),
104 | ]
105 | )
106 |
107 | down_modules = nn.ModuleList([])
108 | for ind, (dim_in, dim_out) in enumerate(in_out):
109 | is_last = ind >= (len(in_out) - 1)
110 | down_modules.append(
111 | nn.ModuleList(
112 | [
113 | ConditionalResidualBlock1D(
114 | dim_in,
115 | dim_out,
116 | cond_dim=cond_dim,
117 | kernel_size=kernel_size,
118 | n_groups=n_groups,
119 | ),
120 | ConditionalResidualBlock1D(
121 | dim_out,
122 | dim_out,
123 | cond_dim=cond_dim,
124 | kernel_size=kernel_size,
125 | n_groups=n_groups,
126 | ),
127 | Downsample1d(dim_out) if not is_last else nn.Identity(),
128 | ]
129 | )
130 | )
131 |
132 | up_modules = nn.ModuleList([])
133 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
134 | is_last = ind >= (len(in_out) - 1)
135 | up_modules.append(
136 | nn.ModuleList(
137 | [
138 | ConditionalResidualBlock1D(
139 | dim_out * 2,
140 | dim_in,
141 | cond_dim=cond_dim,
142 | kernel_size=kernel_size,
143 | n_groups=n_groups,
144 | ),
145 | ConditionalResidualBlock1D(
146 | dim_in,
147 | dim_in,
148 | cond_dim=cond_dim,
149 | kernel_size=kernel_size,
150 | n_groups=n_groups,
151 | ),
152 | Upsample1d(dim_in) if not is_last else nn.Identity(),
153 | ]
154 | )
155 | )
156 |
157 | final_conv = nn.Sequential(
158 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
159 | nn.Conv1d(start_dim, input_dim, 1),
160 | )
161 |
162 | self.diffusion_step_encoder = diffusion_step_encoder
163 | self.up_modules = up_modules
164 | self.down_modules = down_modules
165 | self.final_conv = final_conv
166 |
167 | def forward(
168 | self,
169 | sample: torch.Tensor,
170 | timestep: Union[torch.Tensor, float, int],
171 | global_cond=None,
172 | ):
173 | """
174 | sample: (B,T,input_dim)
175 | timestep: (B,) or int, diffusion step
176 | global_cond: (B,global_cond_dim)
177 | output: (B,T,input_dim)
178 | """
179 | # (B,T,C)
180 | sample = sample.moveaxis(-1, -2)
181 | # (B,C,T)
182 |
183 | # 1. time
184 | timesteps = timestep
185 | if not torch.is_tensor(timesteps):
186 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
187 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
188 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
189 | timesteps = timesteps[None].to(sample.device)
190 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
191 | timesteps = timesteps.expand(sample.shape[0])
192 |
193 | global_feature = self.diffusion_step_encoder(timesteps)
194 |
195 | if global_cond is not None:
196 | global_feature = torch.cat([global_feature, global_cond], dim=-1)
197 |
198 | x = sample
199 | h = []
200 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
201 | x = resnet(x, global_feature)
202 | x = resnet2(x, global_feature)
203 | h.append(x)
204 | x = downsample(x)
205 |
206 | for mid_module in self.mid_modules:
207 | x = mid_module(x, global_feature)
208 |
209 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
210 | x = torch.cat((x, h.pop()), dim=1)
211 | x = resnet(x, global_feature)
212 | x = resnet2(x, global_feature)
213 | x = upsample(x)
214 |
215 | x = self.final_conv(x)
216 |
217 | # (B,C,T)
218 | x = x.moveaxis(-1, -2)
219 | # (B,T,C)
220 | return x
221 |
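The FiLM step in `ConditionalResidualBlock1D.forward` reshapes the conditioning embedding of width `2 * out_channels` so that the first half becomes a per-channel scale and the second half a per-channel bias. The plain-Python sketch below shows that split for a single sample. The function name `film` and the numbers are illustrative assumptions, not part of the repository.

```python
def film(features, embed):
    """Apply FiLM to one sample.

    features: list of C channels, each a list of T values.
    embed: flat list of 2 * C values, as produced by the conditioning encoder.
    """
    channels = len(features)
    if len(embed) != 2 * channels:
        raise ValueError("embed must hold one scale and one bias per channel")
    # reshape(B, 2, C, 1) on a flat 2*C vector: first C entries are scale, last C are bias.
    scale, bias = embed[:channels], embed[channels:]
    return [[scale[c] * v + bias[c] for v in features[c]] for c in range(channels)]


features = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # C=2 channels, T=3 steps
embed = [2.0, 0.5, 10.0, -1.0]                 # scales (2.0, 0.5), biases (10.0, -1.0)
out = film(features, embed)
```

The same scale and bias are applied at every position along the horizon, which is why the embedding only needs one pair of values per channel.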
--------------------------------------------------------------------------------
/core/networks/conv1d_components.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Downsample1d(nn.Module):
5 | def __init__(self, dim):
6 | super().__init__()
7 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
8 |
9 | def forward(self, x):
10 | return self.conv(x)
11 |
12 |
13 | class Upsample1d(nn.Module):
14 | def __init__(self, dim):
15 | super().__init__()
16 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
17 |
18 | def forward(self, x):
19 | return self.conv(x)
20 |
21 |
22 | class Conv1dBlock(nn.Module):
23 | """
24 | Conv1d --> GroupNorm --> Mish
25 | """
26 |
27 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
28 | super().__init__()
29 |
30 | self.block = nn.Sequential(
31 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
32 | nn.GroupNorm(n_groups, out_channels),
33 | nn.Mish(),
34 | )
35 |
36 | def forward(self, x):
37 | return self.block(x)
38 |
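Because `Downsample1d` halves the sequence length and `Upsample1d` doubles it, the skip connections in `ConditionalUnet1D.forward` only line up for some horizons. The sketch below traces lengths through the network using the standard output-length formulas for `nn.Conv1d` and `nn.ConvTranspose1d`. The helper names are hypothetical, and three levels are assumed to match the default `down_dims`.

```python
def downsample_len(length):
    # nn.Conv1d(kernel_size=3, stride=2, padding=1)
    return (length + 2 * 1 - 3) // 2 + 1


def upsample_len(length):
    # nn.ConvTranspose1d(kernel_size=4, stride=2, padding=1)
    return (length - 1) * 2 - 2 * 1 + 4


def trace_unet_lengths(horizon, num_levels=3):
    """Return the output length, or None if a skip connection would not line up."""
    skips = []
    length = horizon
    for level in range(num_levels):
        skips.append(length)
        if level < num_levels - 1:  # the last down level uses nn.Identity
            length = downsample_len(length)
    for _ in range(num_levels - 1):  # one fewer up level than down levels
        if skips.pop() != length:
            return None  # torch.cat along channels needs equal lengths
        length = upsample_len(length)
    return length


ok = trace_unet_lengths(96)        # lengths 96 -> 48 -> 24 -> 48 -> 96
mismatch = trace_unet_lengths(10)  # 10 -> 5 -> 3, then 6 does not match 5
longer = trace_unet_lengths(7)     # concatenations work, but the output has length 8
```

Under these assumptions, a prediction horizon that is a multiple of 4 keeps the output the same length as the input for three levels. Other horizons either fail at the concatenation or return a sequence of a different length.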
--------------------------------------------------------------------------------
/core/networks/positional_embedding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class SinusoidalPosEmb(nn.Module):
7 | def __init__(self, dim):
8 | super().__init__()
9 | self.dim = dim
10 |
11 | def forward(self, x):
12 | device = x.device
13 | half_dim = self.dim // 2
14 | emb = math.log(10000) / (half_dim - 1)
15 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
16 | emb = x[:, None] * emb[None, :]
17 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
18 | return emb
19 |
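The embedding computed by `SinusoidalPosEmb` can be reproduced for one scalar timestep in plain Python, which makes the frequency layout easier to inspect. This is a sketch: the function name `sinusoidal_embedding` is hypothetical and `dim=8` is chosen only to keep the output small.

```python
import math


def sinusoidal_embedding(timestep, dim):
    half_dim = dim // 2
    # Frequencies decay geometrically from 1 down to 1/10000.
    step = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-i * step) for i in range(half_dim)]
    angles = [timestep * f for f in freqs]
    # Sines fill the first half of the vector, cosines the second half.
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


emb_t0 = sinusoidal_embedding(0, dim=8)
emb_t5 = sinusoidal_embedding(5, dim=8)
```

At timestep 0 the sine half is all zeros and the cosine half is all ones. As in the module above, `dim` must be at least 4, since `half_dim - 1` appears in a denominator, and an odd `dim` yields `dim - 1` values.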
--------------------------------------------------------------------------------
/core/trainers/base_trainer.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 |
4 | class BaseDiffusionPolicyTrainer:
5 | def __init__(
6 | self,
7 | ):
8 | pass
9 |
10 | @abstractmethod
11 | def train(self, *args, **kwargs):
12 | raise NotImplementedError
13 |
14 | @abstractmethod
15 | def save_checkpoint(self, path: str):
16 | raise NotImplementedError
17 |
--------------------------------------------------------------------------------
/core/trainers/quadrotor_diffusion_policy_trainer.py:
--------------------------------------------------------------------------------
1 | from diffusers.optimization import get_scheduler
2 | from diffusers.training_utils import EMAModel
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from tqdm.auto import tqdm
7 | from typing import Dict, Optional
8 |
9 | from core.controllers.quadrotor_diffusion_policy import build_noise_scheduler_from_config
10 | from core.dataset.quadrotor_dataset import PlanarQuadrotorStateDataset
11 | from core.trainers.base_trainer import BaseDiffusionPolicyTrainer
12 | from utils.utils import get_device
13 |
14 |
15 | def build_dataloader_from_dataset_and_config(config: Dict, dataset: torch.utils.data.Dataset):
16 | return torch.utils.data.DataLoader(
17 | dataset,
18 | batch_size=config["trainer"]["batch_size"],
19 | shuffle=True,
20 | pin_memory=True,
21 | )
22 |
23 |
24 | class PlanarQuadrotorDiffusionPolicyTrainer(BaseDiffusionPolicyTrainer):
25 | def __init__(
26 | self, net: nn.Module, dataset: PlanarQuadrotorStateDataset, config: Dict, device: Optional[str] = None
27 | ):
28 | self.net = net
29 | self.noise_scheduler = build_noise_scheduler_from_config(config)
30 | self.dataset = dataset
31 | self.set_config(config)
32 | self.device = get_device() if device is None else torch.device(device)
33 |
34 | self.net.to(self.device)
35 |
36 | # build optimizer
37 | if config["trainer"]["optimizer"]["name"].lower() == "adamw":
38 | self.optimizer = torch.optim.AdamW(
39 | params=self.net.parameters(),
40 | lr=config["trainer"]["optimizer"]["learning_rate"],
41 | weight_decay=config["trainer"]["optimizer"]["weight_decay"],
42 | )
43 | else:
44 | raise NotImplementedError
45 |
46 | # build dataset
47 | self.dataloader = build_dataloader_from_dataset_and_config(config, dataset)
48 |
49 | # set EMA
50 | self.use_ema = config["trainer"]["use_ema"]
51 | self.ema = EMAModel(parameters=self.net.parameters(), power=0.75) if self.use_ema else None
52 |
53 | def prepare_inputs(self, batch):
54 | # data normalized in dataset
55 | # device transfer
56 | obs_cond = batch["obs"].to(self.device, dtype=torch.float32) # FiLM conditioning
57 | action = batch["action"].to(self.device, dtype=torch.float32)
58 | batch_size = obs_cond.shape[0]
59 | return obs_cond, action, batch_size
60 |
61 | def optimization_step(self, action, obs_cond, batch_size):
62 | # sample noise to add to actions
63 | noise = torch.randn(action.shape, device=self.device)
64 |
65 | # sample a diffusion iteration for each data point
66 | timesteps = torch.randint(
67 | 0, self.noise_scheduler.config.num_train_timesteps, (batch_size,), device=self.device
68 | ).long()
69 |
70 | # add noise to the clean actions according to the noise magnitude at each diffusion iteration
71 | # (this is the forward diffusion process)
72 | noisy_actions = self.noise_scheduler.add_noise(action, noise, timesteps)
73 |
74 | # predict the noise residual
75 | noise_pred = self.net(noisy_actions, timesteps, global_cond=obs_cond)
76 |
77 | # L2 loss
78 | loss = nn.functional.mse_loss(noise_pred, noise)
79 |
80 | # optimize
81 | loss.backward()
82 | self.optimizer.step()
83 | self.optimizer.zero_grad()
84 | # step lr scheduler every batch
85 | # this is different from standard pytorch behavior
86 | self.lr_scheduler.step()
87 |
88 | # update Exponential Moving Average of the model weights
89 | if self.use_ema:
90 | self.ema.step(self.net.parameters())
91 |
92 | return loss
93 |
94 | def train(self, num_epochs: int, save_ckpt_epoch: int = None):
95 | if save_ckpt_epoch is None:
96 | save_ckpt_epoch = num_epochs
97 |
98 | # set learning rate scheduler
99 | self.lr_scheduler = get_scheduler(
100 | name=self.config["trainer"]["lr_scheduler"]["name"],
101 | optimizer=self.optimizer,
102 | num_warmup_steps=self.config["trainer"]["lr_scheduler"]["num_warmup_steps"],
103 | num_training_steps=len(self.dataloader) * num_epochs,
104 | )
105 |
106 | # training loop
107 | trn_loss = []
108 | with tqdm(range(num_epochs), desc="Epoch") as tglobal:
109 | # epoch loop
110 | for epoch_idx in tglobal:
111 | epoch_loss = list()
112 | # batch loop
113 | with tqdm(self.dataloader, desc="Batch", leave=False) as tepoch:
114 | for nbatch in tepoch:
115 | obs_cond, action, B = self.prepare_inputs(nbatch)
116 | loss = self.optimization_step(action, obs_cond, B)
117 |
118 | # logging
119 | loss_cpu = loss.item()
120 | epoch_loss.append(loss_cpu)
121 | tepoch.set_postfix(loss=loss_cpu)
122 |
123 | tglobal.set_postfix(loss=np.mean(epoch_loss))
124 | trn_loss.append(np.mean(epoch_loss))
125 |
126 | # save intermediate ckpt
127 | if (epoch_idx + 1) % save_ckpt_epoch == 0:
128 | self.save_checkpoint(path=f"ckpt_ep{epoch_idx}.ckpt")
129 |
130 | return trn_loss
131 |
132 | def save_checkpoint(self, path: str):
133 | save_model = self.net
134 | if self.config["trainer"]["use_ema"]:
135 | self.ema.copy_to(save_model.parameters())
136 | torch.save(save_model.state_dict(), path)
137 |
138 | def set_config(self, config: Dict):
139 | self.config = config
140 | self.obs_horizon = config["obs_horizon"]
141 | self.obs_dim = config["controller"]["networks"]["obs_dim"]
142 | self.action_horizon = config["action_horizon"]
143 | self.action_dim = config["controller"]["networks"]["action_dim"]
144 | self.pred_horizon = config["pred_horizon"]
145 | self.obstacle_encode_dim = config["controller"]["networks"]["obstacle_encode_dim"]
146 |
147 |
148 | class PlanarQuadrotorDiffusionPolicyFineTuningTrainer(PlanarQuadrotorDiffusionPolicyTrainer):
149 | def __init__(self, net: nn.Module, dataset: PlanarQuadrotorStateDataset, config: Dict, device: str | None = None):
150 | super().__init__(net, dataset, config, device)
151 |
152 | def optimization_step(self, action, obs_cond, batch_size):
153 | # sample noise to add to actions
154 | noise = torch.randn(action.shape, device=self.device)
155 |
156 | # sample a diffusion iteration for each data point
157 | timesteps = torch.randint(
158 | self.noise_scheduler.config.num_train_timesteps - 1,
159 | self.noise_scheduler.config.num_train_timesteps,
160 | (batch_size,),
161 | device=self.device,
162 | ).long()
163 |
164 | # add noise to the clean actions according to the noise magnitude at each diffusion iteration
165 | # (this is the forward diffusion process)
166 | noisy_actions = self.noise_scheduler.add_noise(action, noise, timesteps)
167 |
168 | # predict the noise residual
169 | noise_pred = self.net(noisy_actions, timesteps, global_cond=obs_cond)
170 |
171 | # L2 loss
172 | loss = nn.functional.mse_loss(noisy_actions - noise_pred, action)
173 |
174 | # optimize
175 | loss.backward()
176 | self.optimizer.step()
177 | self.optimizer.zero_grad()
178 | # step lr scheduler every batch
179 | # this is different from standard pytorch behavior
180 | self.lr_scheduler.step()
181 |
182 | # update Exponential Moving Average of the model weights
183 | if self.use_ema:
184 | self.ema.step(self.net.parameters())
185 |
186 | return loss
187 |
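Both trainers above rely on `noise_scheduler.add_noise` for the forward diffusion process. Assuming a DDPM-style scheduler, that call has the closed form `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps`. The plain-Python sketch below illustrates it. The linear beta schedule, its endpoints, and the 100-step length are assumptions for illustration only, since the schedule actually built by `build_noise_scheduler_from_config` is not shown here.

```python
import math
import random


def make_alpha_bars(num_train_timesteps, beta_start=1e-4, beta_end=0.02):
    # Linear beta schedule (assumed); alpha_bar_t is the running product of (1 - beta).
    alpha_bars, running = [], 1.0
    for t in range(num_train_timesteps):
        beta = beta_start + (beta_end - beta_start) * t / (num_train_timesteps - 1)
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def add_noise(action, noise, alpha_bar):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    a, b = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * x + b * e for x, e in zip(action, noise)]


random.seed(0)
alpha_bars = make_alpha_bars(num_train_timesteps=100)
action = [0.5, -0.25, 1.0]  # illustrative normalized action values
noise = [random.gauss(0.0, 1.0) for _ in action]
early = add_noise(action, noise, alpha_bars[0])   # almost the clean action
late = add_noise(action, noise, alpha_bars[-1])   # much more noise mixed in
```

The base trainer draws a timestep uniformly from `[0, num_train_timesteps)`. The fine-tuning trainer calls `torch.randint(N - 1, N, ...)`, which always returns `N - 1`, so it trains only at the noisiest step.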
--------------------------------------------------------------------------------
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "970e2b9e-ccce-47b2-bccc-9471841cce44",
6 | "metadata": {},
7 | "source": [
8 | "## Installation required in Colab"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "tckkiSjNY0Is",
15 | "metadata": {
16 | "colab": {
17 | "base_uri": "https://localhost:8080/"
18 | },
19 | "id": "tckkiSjNY0Is",
20 | "outputId": "627b5bff-c792-4954-947f-40baab0de93c"
21 | },
22 | "outputs": [
23 | {
24 | "name": "stdout",
25 | "output_type": "stream",
26 | "text": [
27 | "Cloning into 'diffusion_policy_quadrotor'...\n",
28 | "remote: Enumerating objects: 155, done.\u001b[K\n",
29 | "remote: Counting objects: 100% (155/155), done.\u001b[K\n",
30 | "remote: Compressing objects: 100% (104/104), done.\u001b[K\n",
31 | "remote: Total 155 (delta 66), reused 111 (delta 38), pack-reused 0\u001b[K\n",
32 | "Receiving objects: 100% (155/155), 4.19 MiB | 25.69 MiB/s, done.\n",
33 | "Resolving deltas: 100% (66/66), done.\n"
34 | ]
35 | }
36 | ],
37 | "source": [
38 | "! git clone https://github.com/shaoanlu/diffusion_policy_quadrotor.git"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 2,
44 | "id": "7rrgEckSZEdx",
45 | "metadata": {
46 | "colab": {
47 | "base_uri": "https://localhost:8080/"
48 | },
49 | "id": "7rrgEckSZEdx",
50 | "outputId": "5cd7bcf5-5b4c-4a48-996b-7257cdf7f03e"
51 | },
52 | "outputs": [
53 | {
54 | "name": "stdout",
55 | "output_type": "stream",
56 | "text": [
57 | "/content/diffusion_policy_quadrotor\n",
58 | "\u001b[0m\u001b[01;34massets\u001b[0m/ \u001b[01;34mconfig\u001b[0m/ \u001b[01;34mcore\u001b[0m/ demo.ipynb LICENSE pyproject.toml README.md \u001b[01;34mutils\u001b[0m/\n"
59 | ]
60 | }
61 | ],
62 | "source": [
63 | "%cd diffusion_policy_quadrotor\n",
64 | "%ls"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 3,
70 | "id": "rK2Hjrq_Y4T6",
71 | "metadata": {
72 | "id": "rK2Hjrq_Y4T6"
73 | },
74 | "outputs": [],
75 | "source": [
76 | "%%capture\n",
77 | "!pip3 install torch==1.13.1 torchvision==0.14.1 diffusers==0.18.2 jax==0.4.23 jaxlib==0.4.23"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "id": "ee7a0434-3333-4502-9563-adcc12d4a413",
83 | "metadata": {
84 | "id": "ee7a0434-3333-4502-9563-adcc12d4a413"
85 | },
86 | "source": [
87 | "## Description\n",
88 | "\n",
89 | "This notebook demonstrates using a diffusion policy controller to drive a quadrotor from (0, 0) to (5, 5) in the presence of random circular obstacles."
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 1,
95 | "id": "fffc0bb0-584b-4602-885f-a1cfa06d21b6",
96 | "metadata": {
97 | "id": "fffc0bb0-584b-4602-885f-a1cfa06d21b6"
98 | },
99 | "outputs": [],
100 | "source": [
101 | "%load_ext autoreload\n",
102 | "%autoreload 2"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 2,
108 | "id": "670f6c7f-93c9-47dc-b2e4-81ea8949b3f9",
109 | "metadata": {
110 | "id": "670f6c7f-93c9-47dc-b2e4-81ea8949b3f9"
111 | },
112 | "outputs": [],
113 | "source": [
114 | "import numpy as np\n",
115 | "import os\n",
116 | "import torch\n",
117 | "import yaml\n",
118 | "import collections\n",
119 | "from tqdm.auto import tqdm\n",
120 | "import gdown"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 3,
126 | "id": "fa8f1de5-924a-4889-aacc-bac7601012c5",
127 | "metadata": {
128 | "colab": {
129 | "base_uri": "https://localhost:8080/",
130 | "height": 87,
131 | "referenced_widgets": [
132 | "49a3b00a700b40f8b1de991477efe2b9",
133 | "fb587d9dbfa84d72af2b3c45d46c2cc1",
134 | "7487091f00dd4e3bba795ebab6b0f5b1",
135 | "a817704da5184ef6922fb679bf0c395e",
136 | "a3a8dd4705b64b5f9b21b2e3d4ce24b8",
137 | "3a953a43ed574a4c8e8ebc52d5993347",
138 | "1900dfab0c6e4f8a8350f2d024fc22c6",
139 | "a9afd6d521f44216adb119a54f531376",
140 | "632b2841ae2d44b18e67fe288ba136a1",
141 | "01969ff8ad4140328cda2a14df8fe439",
142 | "fca3ddf1eb6d439ba86878e00e4560fd"
143 | ]
144 | },
145 | "id": "fa8f1de5-924a-4889-aacc-bac7601012c5",
146 | "outputId": "730f9c66-b606-4840-f614-540a89379eef"
147 | },
148 | "outputs": [],
149 | "source": [
150 | "from utils.normalizers import LinearNormalizer\n",
151 | "from core.controllers.quadrotor_diffusion_policy import QuadrotorDiffusionPolicy, build_networks_from_config, build_noise_scheduler_from_config\n",
152 | "from core.controllers.quadrotor_clf_cbf_qp import QuadrotorCLFCBFController\n",
153 | "from core.env.planar_quadrotor import PlanarQuadrotorEnv\n",
154 | "from utils.visualization import visualize_quadrotor_simulation_result"
155 | ]
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "id": "06cf3d67-3f96-457e-a634-77b6e5f2e31e",
160 | "metadata": {
161 | "id": "06cf3d67-3f96-457e-a634-77b6e5f2e31e"
162 | },
163 | "source": [
164 | "## Load the config file"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": 4,
170 | "id": "0c8c8f90-8ac2-4f87-943d-e41b7ff9ff6a",
171 | "metadata": {
172 | "id": "0c8c8f90-8ac2-4f87-943d-e41b7ff9ff6a",
173 | "scrolled": true
174 | },
175 | "outputs": [],
176 | "source": [
177 | "with open(\"config/config.yaml\", \"r\") as file:\n",
178 | " config = yaml.safe_load(file)\n",
179 | "\n",
180 | "# Whether to use a finetuned model trained following tricks mentioned in\n",
181 | "# [Fine-Tuning Image-Conditional Diffusion Models is Easier than You Think](https://arxiv.org/abs/2409.11355)\n",
182 | "use_single_step_inference = config.get(\"controller\").get(\"common\").get(\"use_single_step_inference\", False)"
183 | ]
184 | },
185 | {
186 | "cell_type": "markdown",
187 | "id": "dd56c89a-520f-46b5-b9d6-ebef0160b184",
188 | "metadata": {
189 | "id": "dd56c89a-520f-46b5-b9d6-ebef0160b184"
190 | },
191 | "source": [
192 | "## Instantiate the controller"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 5,
198 | "id": "cb0707a7-3d66-46da-b722-2b2f42d8e099",
199 | "metadata": {
200 | "id": "cb0707a7-3d66-46da-b722-2b2f42d8e099"
201 | },
202 | "outputs": [],
203 | "source": [
204 | "clf_cbf_controller = QuadrotorCLFCBFController(config=config)\n",
205 | "\n",
206 | "controller = QuadrotorDiffusionPolicy(\n",
207 | " model=build_networks_from_config(config),\n",
208 | " noise_scheduler=build_noise_scheduler_from_config(config),\n",
209 | " normalizer=LinearNormalizer(),\n",
210 | " clf_cbf_controller=None, # set as clf_cbf_controller to enable CLF-CBF traj refinement\n",
211 | " config=config,\n",
212 | " device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
213 | ")"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "id": "cac28d88-e304-46b3-bc28-5fc54cd3705d",
219 | "metadata": {
220 | "id": "cac28d88-e304-46b3-bc28-5fc54cd3705d"
221 | },
222 | "source": [
223 | "## Download and load pretrained weights"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": null,
229 | "id": "727e3937-dd21-46a2-9fff-c41c1d653022",
230 | "metadata": {},
231 | "outputs": [],
232 | "source": [
233 | "# download pretrained weights from Google drive\n",
234 | "if use_single_step_inference:\n",
235 | " ckpts_path = \"ema_noise_pred_net2_ph96_oh2_ah10_v8_singlestepFT.ckpt\"\n",
236 | " if not os.path.isfile(ckpts_path):\n",
237 | " gdown.download(id=\"1UhxlzoQ6DOt0HZhokzU4ktl2MfzlA2Ii\", output=ckpts_path, quiet=False)\n",
238 | "else:\n",
239 | " ckpts_path = \"ema_noise_pred_net2_ph96_oh2_ah10_v8.ckpt\"\n",
240 | " if not os.path.isfile(ckpts_path):\n",
241 | " gdown.download(id=\"1-as6EqMLECxU7IVLZZIEDAcXMox_RkI_\", output=ckpts_path, quiet=False)\n",
242 | "\n",
243 | "# load weights\n",
244 | "controller.load_weights(ckpts_path)"
245 | ]
246 | },
247 | {
248 | "cell_type": "markdown",
249 | "id": "a0fda0f4-f5cb-4fa9-86da-abf991edfb4b",
250 | "metadata": {
251 | "id": "a0fda0f4-f5cb-4fa9-86da-abf991edfb4b"
252 | },
253 | "source": [
254 | "## Instantiate the simulator"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": 7,
260 | "id": "6b0ae93e-8fe6-40af-92e6-12643aa21f6d",
261 | "metadata": {
262 | "id": "6b0ae93e-8fe6-40af-92e6-12643aa21f6d"
263 | },
264 | "outputs": [],
265 | "source": [
266 | "sim = PlanarQuadrotorEnv(config)"
267 | ]
268 | },
269 | {
270 | "cell_type": "code",
271 | "execution_count": null,
272 | "id": "a679998d-7a2e-48dc-b0a4-30c24314e307",
273 | "metadata": {
274 | "id": "a679998d-7a2e-48dc-b0a4-30c24314e307"
275 | },
276 | "outputs": [],
277 | "source": []
278 | },
279 | {
280 | "cell_type": "markdown",
281 | "id": "1f1c50bc-e529-432f-92a9-9bb3c6e9fdf7",
282 | "metadata": {
283 | "id": "1f1c50bc-e529-432f-92a9-9bb3c6e9fdf7"
284 | },
285 | "source": [
286 | "## Run simulation"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": 8,
292 | "id": "eecaed67-9814-4a62-aa7f-2caf5d5b43aa",
293 | "metadata": {
294 | "id": "eecaed67-9814-4a62-aa7f-2caf5d5b43aa"
295 | },
296 | "outputs": [],
297 | "source": [
298 | "def generate_random_obstacles():\n",
299 | " num_obstacles = np.random.randint(1, 8)\n",
300 | " obs_center, obs_radius = np.empty((num_obstacles, 2)), np.ones((num_obstacles,))\n",
301 | " for obs_idx in range(num_obstacles): # set XY position of each obstacle\n",
302 | " obs_center[obs_idx, ...] = np.random.randint(1, 8, size=(2,))\n",
303 | " obs_radius[obs_idx] = np.random.uniform(0.2, 1.5)\n",
304 | " return obs_center, obs_radius\n",
305 | "\n",
306 | "def encode_obstacle_info(obs_center: np.ndarray, obs_radius: np.ndarray):\n",
307 | " obs_encode = np.zeros((7, 7))\n",
308 | " for center, radius in zip(obs_center, obs_radius):\n",
309 | " obs_encode[tuple((center-1).astype(np.int32))] = radius\n",
310 | " obs_encode = obs_encode.flatten()\n",
311 | " return obs_encode"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": 9,
317 | "id": "3250c190-d380-423c-8455-21a6c1b2c093",
318 | "metadata": {
319 | "colab": {
320 | "base_uri": "https://localhost:8080/",
321 | "height": 67,
322 | "referenced_widgets": [
323 | "fa8ca951c0f54d3c965b136791c54f60",
324 | "14869b876e6143a9848a8421b8da93c3",
325 | "b00b242d48c34b739b335e6137bbec52",
326 | "ba8a57462eb4492880c4980ee126f888",
327 | "7afde10b06494393ab9b53818c580fdc",
328 | "bc399470fb094590a12ba13e98c1f6a8",
329 | "eba7a2b036374b1c9c329d884ce16ad4",
330 | "13066600e23245888e019523e0200005",
331 | "304615b68c9d496ea35928d1546df4ca",
332 | "3de6654be9de4d22b15fb28ad7625f57",
333 | "81e23a9379b34c0cb78e375d47a63a89"
334 | ]
335 | },
336 | "id": "3250c190-d380-423c-8455-21a6c1b2c093",
337 | "outputId": "4c9609f0-a944-42f7-909c-c23cab24dbf0"
338 | },
339 | "outputs": [
340 | {
341 | "data": {
342 | "application/vnd.jupyter.widget-view+json": {
343 | "model_id": "3d93b38cf8264ebe80428df1989880d1",
344 | "version_major": 2,
345 | "version_minor": 0
346 | },
347 | "text/plain": [
348 | "Eval: 0%| | 0/400 [00:00, ?it/s]"
349 | ]
350 | },
351 | "metadata": {},
352 | "output_type": "display_data"
353 | }
354 | ],
355 | "source": [
356 | "np.random.seed(123)\n",
357 | "\n",
358 | "# Env parameters\n",
359 | "max_steps = 400\n",
360 | "dt = 0.01\n",
361 | "ratio_sim_ts = 10 # ratio of sampling time between the simulator and the controller\n",
362 | "\n",
363 | "# get first observation\n",
364 | "state = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) # [y, y_dot, z, z_dot, phi, phi_dot]\n",
365 | "states = [state] # `states` is a list containing the states over time\n",
366 | "controls = [np.array([0, 0])]\n",
367 | "obs_center, obs_radius = generate_random_obstacles()\n",
368 | "obs_encode = encode_obstacle_info(obs_center, obs_radius)\n",
369 | "obs = {\n",
370 | "    \"state\": collections.deque([state] * controller.obs_horizon, maxlen=controller.obs_horizon),\n",
371 | "    \"obs_encode\": [obs_encode],\n",
372 | "    \"obs_center\": obs_center,\n",
373 | "    \"obs_radius\": obs_radius,\n",
374 | "}\n",
375 | "\n",
376 | "# termination params\n",
377 | "done = False\n",
378 | "step_idx = 0\n",
379 | "\n",
380 | "with tqdm(total=max_steps, desc=\"Eval\") as pbar:\n",
381 | "    while not done:\n",
382 | "        # controller inference\n",
383 | "        action = controller.predict_action(obs)\n",
384 | "\n",
385 | "        # execute action_horizon steps without replanning\n",
386 | "        for i in range(action.shape[0]):\n",
387 | "            # stepping env\n",
388 | "            command = controller.calculate_force_command(state, action[i])\n",
389 | "            for _ in range(ratio_sim_ts):\n",
390 | "                state = sim.step(state, command, dt=dt/ratio_sim_ts)\n",
391 | "            # save observations and controls\n",
392 | "            obs[\"state\"].append(state.copy())\n",
393 | "            states.append(state.copy())\n",
394 | "            controls.append(action[i].copy())\n",
395 | "\n",
396 | "            # update progress bar\n",
397 | "            step_idx += 1\n",
398 | "            pbar.update(1)\n",
399 | "            if step_idx > max_steps:\n",
400 | "                done = True\n",
401 | "            if done:\n",
402 | "                break"
403 | ]
404 | },
405 | {
406 | "cell_type": "code",
407 | "execution_count": null,
408 | "id": "b006ad91-e853-4a03-80c5-7ecb142bc475",
409 | "metadata": {
410 | "id": "b006ad91-e853-4a03-80c5-7ecb142bc475"
411 | },
412 | "outputs": [],
413 | "source": []
414 | },
415 | {
416 | "cell_type": "markdown",
417 | "id": "602bdd60-2b90-48db-a29b-7f8645501142",
418 | "metadata": {
419 | "id": "602bdd60-2b90-48db-a29b-7f8645501142"
420 | },
421 | "source": [
422 | "## Visualize result"
423 | ]
424 | },
425 | {
426 | "cell_type": "code",
427 | "execution_count": 10,
428 | "id": "b92bcec2-d95d-4361-9b92-991346d83950",
429 | "metadata": {
430 | "colab": {
431 | "base_uri": "https://localhost:8080/",
432 | "height": 435
433 | },
434 | "id": "b92bcec2-d95d-4361-9b92-991346d83950",
435 | "outputId": "0d8c5a83-8002-4deb-a87b-48a04af6d0af"
436 | },
437 | "outputs": [
438 | {
439 | "data": {
440 | "image/png": "",
441 | "text/plain": [
442 | ""
443 | ]
444 | },
445 | "metadata": {},
446 | "output_type": "display_data"
447 | }
448 | ],
449 | "source": [
450 | "states = np.array(states)\n",
451 | "visualize_quadrotor_simulation_result(sim, states, obs_center=obs_center, obs_radius=obs_radius)"
452 | ]
453 | },
454 | {
455 | "cell_type": "code",
456 | "execution_count": null,
457 | "id": "33740290-22ae-45bd-822f-0bea782691bd",
458 | "metadata": {
459 | "id": "33740290-22ae-45bd-822f-0bea782691bd"
460 | },
461 | "outputs": [],
462 | "source": []
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": null,
467 | "id": "f15cae17-f337-4ba0-a4ea-5a10a8bc0390",
468 | "metadata": {
469 | "id": "f15cae17-f337-4ba0-a4ea-5a10a8bc0390"
470 | },
471 | "outputs": [],
472 | "source": []
473 | },
474 | {
475 | "cell_type": "code",
476 | "execution_count": null,
477 | "id": "7f889988-1fef-4c11-b113-457247db185c",
478 | "metadata": {
479 | "id": "7f889988-1fef-4c11-b113-457247db185c"
480 | },
481 | "outputs": [],
482 | "source": []
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": null,
487 | "id": "6c503b86-5211-4403-a830-b8a100e73b97",
488 | "metadata": {
489 | "id": "6c503b86-5211-4403-a830-b8a100e73b97"
490 | },
491 | "outputs": [],
492 | "source": []
493 | },
494 | {
495 | "cell_type": "code",
496 | "execution_count": null,
497 | "id": "f0772564-5e68-4884-a266-f5467db4a433",
498 | "metadata": {
499 | "id": "f0772564-5e68-4884-a266-f5467db4a433"
500 | },
501 | "outputs": [],
502 | "source": []
503 | },
504 | {
505 | "cell_type": "code",
506 | "execution_count": null,
507 | "id": "3c1ad555-40ce-4917-b0f7-e26dce4d485d",
508 | "metadata": {
509 | "id": "3c1ad555-40ce-4917-b0f7-e26dce4d485d"
510 | },
511 | "outputs": [],
512 | "source": []
513 | }
514 | ],
515 | "metadata": {
516 | "accelerator": "GPU",
517 | "colab": {
518 | "gpuType": "T4",
519 | "provenance": []
520 | },
521 | "kernelspec": {
522 | "display_name": "Python 3 (ipykernel)",
523 | "language": "python",
524 | "name": "python3"
525 | },
526 | "language_info": {
527 | "codemirror_mode": {
528 | "name": "ipython",
529 | "version": 3
530 | },
531 | "file_extension": ".py",
532 | "mimetype": "text/x-python",
533 | "name": "python",
534 | "nbconvert_exporter": "python",
535 | "pygments_lexer": "ipython3",
536 | "version": "3.10.14"
537 | },
538 | "widgets": {
539 | "application/vnd.jupyter.widget-state+json": {
540 | "state": {},
541 | "version_major": 2,
542 | "version_minor": 0
543 | }
544 | }
545 | },
546 | "nbformat": 4,
547 | "nbformat_minor": 5
548 | }
549 |
--------------------------------------------------------------------------------
/learning_note.md:
--------------------------------------------------------------------------------
1 | ## Learning note
2 | ### Main insight
3 | - The model should learn the residuals (gradient of the denoising) if possible. This greatly stabilizes the training.
4 | - Advantages of diffusion models: 1) the ability to model multi-modality, 2) stable training, and 3) temporally consistent output.
5 | - Iteratively add training data covering failure modes to turn extrapolation into interpolation.
6 | ### Scribbles
7 | - The trained policy does not reach the goal collision-free 100% of the time (there are no collisions in its training data).
8 | - It is unable to recover from OOD data.
9 | - A long observation horizon might hurt performance, possibly because it makes the model more prone to overfitting.
10 |   - The diffusion policy [paper](https://arxiv.org/pdf/2303.04137) also discusses this in its appendix.
11 | - I feel we don't need diffusion models for the simple task in this repo; supervised learning might be equally effective.
12 | - The controller performs worse when the conditional vector contains extreme values (maximum or minimum).
13 |   - For instance, it collides more often with obstacles that have the maximum radius of 1.5.
14 |   - Collect more data to make everything interpolation instead of extrapolation.
15 | - Even though the loss curve appears saturated, the performance of the controller can still improve as training continues.
16 |   - The training loss curves of the diffusion model are extremely smooth, by the way.
17 |   - On the other hand, it might be difficult to tell whether the model is overfitting by looking at the trajectory or the denoising process.
18 |   - But in general I feel there is little harm in training the diffusion model for as long as possible.
19 | - DDPM and DDIM samplers yield the best results.
20 | - Inference is not real-time. The controller is set to run at 100 Hz.
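
To put a number on the real-time gap noted in the last bullet, one could time the inference call against the control period. The following is a minimal sketch, assuming a 100 Hz control rate (10 ms per step). `measure_latency`, `is_real_time`, and `fake_inference` are hypothetical names introduced for illustration, and `fake_inference` is only a stand-in for the controller's actual prediction call.

```python
import time


def measure_latency(fn, n_runs=20):
    """Return the mean wall-clock time (seconds) of one call to fn."""
    start = time.perf_counter()
    for _ in range(n_runs):
        fn()
    return (time.perf_counter() - start) / n_runs


def is_real_time(latency_s, control_rate_hz=100.0, steps_per_plan=1):
    """True if one inference fits in the time covered by the steps executed per plan."""
    budget_s = steps_per_plan / control_rate_hz
    return latency_s <= budget_s


def fake_inference():
    # Hypothetical placeholder: replace with the real prediction call.
    time.sleep(0.002)


latency = measure_latency(fake_inference)
print(f"mean latency: {latency * 1e3:.2f} ms")
print("fits a 100 Hz loop when replanning every step:", is_real_time(latency))
```

If several predicted actions are executed before replanning, the budget grows accordingly. For example, `is_real_time(latency, steps_per_plan=10)` checks against 100 ms instead of 10 ms.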
21 |
22 | ### Possible reasons for failures on collision avoidance
23 | 1. The training data contains no collisions.
24 | 2. A policy learned with imitation learning can accumulate error during closed-loop control.
25 |
26 | When the quadrotor gets too close to an obstacle (due to 2), the input state becomes OOD (due to 1), so the diffusion policy is unable to recover from such a situation.
27 |
28 | - Possible fix: add training data in which the quadrotor recovers from collisions.
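
One way to start collecting such recovery data is to begin demonstration episodes from states close to an obstacle's boundary. The sketch below is an assumption about how that could look, not something this repo implements. `sample_near_obstacle_states` and its margin parameters are hypothetical, and the state layout `[y, y_dot, z, z_dot, phi, phi_dot]` follows the demo notebook.

```python
import math
import random


def sample_near_obstacle_states(obs_center, obs_radius, n_per_obstacle=5,
                                min_margin=0.05, max_margin=0.3, seed=None):
    """Sample zero-velocity states just outside each obstacle's boundary.

    Each state is [y, y_dot, z, z_dot, phi, phi_dot].
    Note: a state sampled near one obstacle may still lie inside another;
    filter those out before using them as initial conditions.
    """
    rng = random.Random(seed)
    states = []
    for (cy, cz), radius in zip(obs_center, obs_radius):
        for _ in range(n_per_obstacle):
            angle = rng.uniform(0.0, 2.0 * math.pi)
            dist = radius + rng.uniform(min_margin, max_margin)
            y = cy + dist * math.cos(angle)
            z = cz + dist * math.sin(angle)
            states.append([y, 0.0, z, 0.0, 0.0, 0.0])
    return states


init_states = sample_near_obstacle_states([(3.0, 3.0), (5.0, 2.0)], [1.0, 0.5], seed=0)
print(len(init_states))  # 10
```

Each sampled state would then be handed to whichever controller generates the demonstrations, so that the dataset contains trajectories moving away from near-collision configurations.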
29 |
30 | ### Things that didn't work
31 | - Tried encoding the distance to each obstacle. Did not observe any improvement in collision avoidance.
32 | - Tried replacing the obstacle encoding with a vision encoder. Did not see any performance improvement.
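
For reference, a distance-based encoding of the kind mentioned in the first bullet could be computed as below. The exact form that was tried is not recorded in this note, so this is only one plausible variant: the signed distance from the quadrotor position to each obstacle's surface. `obstacle_distances` is a hypothetical helper.

```python
import math


def obstacle_distances(position, obs_center, obs_radius):
    """Signed distance from a (y, z) point to each obstacle surface.

    A negative value means the point lies inside that obstacle.
    """
    y, z = position
    return [
        math.hypot(y - cy, z - cz) - radius
        for (cy, cz), radius in zip(obs_center, obs_radius)
    ]


dists = obstacle_distances((0.0, 0.0), [(3.0, 4.0), (1.0, 0.0)], [1.0, 1.5])
print(dists)  # [4.0, -0.5]
```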
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 |
3 | # lint.select = ["A", "ANN", "B", "C90", "D", "E", "F", "I", "N", "COM", "DTZ", "PD", "RUF", "TID", "UP", "W"]
4 | # lint.ignore = ["D203", "D212"]
5 |
6 | lint.fixable = ["I", "RUF100"]
7 | lint.unfixable = []
8 |
9 | # Exclude a variety of commonly ignored directories.
10 | exclude = [
11 | ".bzr",
12 | ".direnv",
13 | ".eggs",
14 | ".git",
15 | ".hg",
16 | ".mypy_cache",
17 | ".nox",
18 | ".pants.d",
19 | ".pytype",
20 | ".ruff_cache",
21 | ".svn",
22 | ".tox",
23 | ".venv",
24 | "__pypackages__",
25 | "_build",
26 | "buck-out",
27 | "build",
28 | "dist",
29 | "node_modules",
30 | "venv",
31 | ]
32 |
33 | line-length = 120
34 |
35 | # Allow unused variables when underscore-prefixed.
36 | lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
37 |
38 | target-version = "py310"
39 |
40 | [tool.ruff.lint.mccabe]
41 | max-complexity = 10
42 |
43 |
44 | [tool.pyright]
45 | typeCheckingMode = "basic"
46 | defineConstant = { DEBUG = true }
47 |
48 | reportMissingImports = false
49 | reportMissingTypeStubs = false
50 | reportInvalidStringEscapeSequence = false
51 |
52 | pythonVersion = "3.10"
53 | pythonPlatform = "Linux"
54 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/tests/__init__.py
--------------------------------------------------------------------------------
/tests/config/test_config.yaml:
--------------------------------------------------------------------------------
1 | name: "planar_quadrotor"
2 |
3 | pred_horizon: 96
4 | obs_horizon: 2
5 | action_horizon: 10
6 |
7 | controller:
8 |   networks:
9 |     obs_dim: 6
10 |     action_dim: 6
11 |     obstacle_encode_dim: 49
12 |   noise_scheduler:
13 |     type: "ddpm"
14 |     ddpm:
15 |       num_train_timesteps: 100 # number of diffusion iterations
16 |       beta_schedule: "squaredcos_cap_v2"
17 |       clip_sample: true # required when predict_epsilon=False
18 |       prediction_type: "epsilon"
19 |     ddim: # faster inference
20 |       num_train_timesteps: 100
21 |       beta_schedule: "squaredcos_cap_v2"
22 |       clip_sample: true
23 |       prediction_type: "epsilon"
24 |     dpmsolver: # faster inference, experimental
25 |       num_train_timesteps: 100
26 |       beta_schedule: "squaredcos_cap_v2"
27 |       prediction_type: "epsilon"
28 |       use_karras_sigmas: true
29 |
30 | trainer:
31 |   use_ema: true
32 |   batch_size: 8
33 |   optimizer:
34 |     name: "adamw"
35 |     learning_rate: 1.0e-4
36 |     weight_decay: 1.0e-6
37 |   lr_scheduler:
38 |     name: "cosine"
39 |     num_warmup_steps: 500
40 |
41 | dataloader:
42 |   batch_size: 3
43 |
44 | normalizer:
45 |   action:
46 |     min: [-2.6515920785069467, -10.169477462768555, -4.270973491668701, -13.883374419993748, -1.9637073183059692, -20.395111083984375]
47 |     max: [8.0543811917305, 6.976394176483154, 8.477932775914669, 11.327190831233095, 2.5276688146591186, 18.05487823486328]
48 |   observation:
49 |     min: [-2.649836301803589, -9.564324378967285, -4.264063358306885, -13.772777557373047, -1.9476231336593628, -17.225351333618164]
50 |     max: [8.05234146118164, 6.976299285888672, 8.474746704101562, 10.96181583404541, 2.5151419639587402, 18.054880142211914]
51 |
52 | simulator:
53 |   m_q: 1.0 # kg
54 |   I_xx: 0.1 # kg.m^2
55 |   l_q: 0.3 # m, length of the quadrotor
56 |   g: 9.81
57 |   dt: 0.01
--------------------------------------------------------------------------------
/tests/fixtures/test_dataset.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaoanlu/diffusion_policy_quadrotor/58781ae2ad1e7b4ff83b40d4726a3eae4f2cefff/tests/fixtures/test_dataset.joblib
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | import yaml
4 |
5 | from core.dataset.quadrotor_dataset import PlanarQuadrotorStateDataset
6 |
7 |
8 | class TestPlanarQuadrotorStateDataset(unittest.TestCase):
9 |     def setUp(self):
10 |         self.dataset_path = "tests/fixtures/test_dataset.joblib"
11 |
12 |         with open("tests/config/test_config.yaml", "r") as file:
13 |             self.config = yaml.safe_load(file)
14 |
15 |         self.obs_horizon = self.config["obs_horizon"]
16 |         self.action_horizon = self.config["action_horizon"]
17 |         self.pred_horizon = self.config["pred_horizon"]
18 |         self.action_dim = self.config["controller"]["networks"]["action_dim"]
19 |         self.obs_dim = self.config["controller"]["networks"]["obs_dim"]
20 |         self.obstacle_encode_dim = self.config["controller"]["networks"]["obstacle_encode_dim"]
21 |
22 |     def test_init(self):
23 |         dataset = PlanarQuadrotorStateDataset(dataset_path=self.dataset_path, config=self.config)
24 |         self.assertIsInstance(dataset, PlanarQuadrotorStateDataset)
25 |
26 |     def test_iter(self):
27 |         batch_size = self.config["dataloader"]["batch_size"]
28 |         dataset = PlanarQuadrotorStateDataset(dataset_path=self.dataset_path, config=self.config)
29 |         dataloader = torch.utils.data.DataLoader(
30 |             dataset,
31 |             batch_size=batch_size,
32 |             shuffle=True,
33 |             pin_memory=True,
34 |         )
35 |
36 |         # batch contents match the expected shapes
37 |         batch = next(iter(dataloader))
38 |         self.assertEqual(batch["obs"].shape, (batch_size, self.obs_dim * self.obs_horizon + self.obstacle_encode_dim))
39 |         self.assertEqual(batch["action"].shape, (batch_size, self.pred_horizon, self.action_dim))
40 |
41 |
42 | if __name__ == "__main__":
43 |     unittest.main()
44 |
--------------------------------------------------------------------------------
/tests/test_networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import unittest
3 | import yaml
4 |
5 | from core.networks.conditional_unet1d import ConditionalUnet1D
6 |
7 |
8 | class TestConditionalUnet1D(unittest.TestCase):
9 |     def setUp(self):
10 |         with open("tests/config/test_config.yaml", "r") as file:
11 |             self.config = yaml.safe_load(file)
12 |
13 |         self.obs_horizon = self.config["obs_horizon"]
14 |         self.action_horizon = self.config["action_horizon"]
15 |         self.pred_horizon = self.config["pred_horizon"]
16 |         self.action_dim = self.config["controller"]["networks"]["action_dim"]
17 |         self.obs_dim = self.config["controller"]["networks"]["obs_dim"]
18 |         self.obstacle_encode_dim = self.config["controller"]["networks"]["obstacle_encode_dim"]
19 |
20 |     def test_init(self):
21 |         noise_pred_net = ConditionalUnet1D(
22 |             input_dim=self.action_dim, global_cond_dim=self.obs_dim * self.obs_horizon + self.obstacle_encode_dim
23 |         )
24 |         self.assertIsInstance(noise_pred_net, ConditionalUnet1D)
25 |
26 |     def test_inference(self):
27 |         net = ConditionalUnet1D(
28 |             input_dim=self.action_dim, global_cond_dim=self.obs_dim * self.obs_horizon + self.obstacle_encode_dim
29 |         )
30 |
31 |         # example inputs
32 |         noised_action = torch.randn((1, self.pred_horizon, self.action_dim))
33 |         obs = torch.zeros((1, self.obs_horizon * self.obs_dim + self.obstacle_encode_dim))
34 |         diffusion_iter = torch.zeros((1,))
35 |
36 |         # the noise prediction network
37 |         # takes noisy action, diffusion iteration and observation as input
38 |         # predicts the noise added to action
39 |         noise = net(sample=noised_action, timestep=diffusion_iter, global_cond=obs.flatten(start_dim=1))
40 |         # removing noise
41 |         denoised_action = noised_action - noise
42 |
43 |         self.assertEqual(denoised_action.shape, (1, self.pred_horizon, self.action_dim))
44 |
45 |
46 | if __name__ == "__main__":
47 |     unittest.main()
48 |
--------------------------------------------------------------------------------
/utils/normalizers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class BaseNormalizer:
5 |     def normalize_data(self, *args, **kwargs):
6 |         raise NotImplementedError()
7 |
8 |     def unnormalize_data(self, *args, **kwargs):
9 |         raise NotImplementedError()
10 |
11 |
12 | class LinearNormalizer(BaseNormalizer):
13 |     def __init__(self):
14 |         pass
15 |
16 |     def normalize_data(self, data, stats):
17 |         # normalize to [0, 1]
18 |         ndata = (data - np.array(stats["min"])) / (np.array(stats["max"]) - np.array(stats["min"]))
19 |         # normalize to [-1, 1]
20 |         ndata = ndata * 2 - 1
21 |         return ndata
22 |
23 |     def unnormalize_data(self, ndata, stats):
24 |         ndata = (ndata + 1) / 2
25 |         data = ndata * (np.array(stats["max"]) - np.array(stats["min"])) + np.array(stats["min"])
26 |         return data
27 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def get_device():
5 |     if torch.cuda.is_available():
6 |         device = "cuda"
7 |     else:
8 |         device = "cpu"
9 |     return torch.device(device)
10 |
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 |
4 |
5 | def visualize_quadrotor_simulation_result(
6 |     quadrotor, states: np.ndarray, obs_center: np.ndarray, obs_radius: np.ndarray
7 | ):
8 |     def plot_obstacles(obs_center, obs_radius):
9 |         for obs_p, obs_r in zip(obs_center, obs_radius):
10 |             circle = plt.Circle(
11 |                 tuple(obs_p),
12 |                 obs_r,
13 |                 color="grey",
14 |                 fill=True,
15 |                 linestyle="--",
16 |                 linewidth=2,
17 |                 alpha=0.5,
18 |             )
19 |             plt.gca().add_artist(circle)
20 |
21 |     plt.figure()
22 |     plt.gca().set_aspect("equal", adjustable="box")
23 |     ys = [s[0] for s in states]
24 |     zs = [s[2] for s in states]
25 |     phis = [s[4] for s in states]
26 |     # Draw the obstacles as circles
27 |     plot_obstacles(obs_center, obs_radius)
28 |
29 |     # Plot the trajectory
30 |     plt.scatter(ys, zs)
31 |
32 |     # Plot quadrotor pose
33 |     y_ = [(y - quadrotor.l_q * np.cos(phi), y + quadrotor.l_q * np.cos(phi)) for (y, phi) in zip(ys[::10], phis[::10])]
34 |     z_ = [(z - quadrotor.l_q * np.sin(phi), z + quadrotor.l_q * np.sin(phi)) for (z, phi) in zip(zs[::10], phis[::10])]
35 |     for yy, zz in zip(y_, z_):
36 |         plt.plot(yy, zz, marker="o", color="r", alpha=0.5)
37 |
38 |     # plot start and end point
39 |     init_y, init_z = states[0][0], states[0][2]
40 |     plt.scatter(init_y, init_z, s=200, color="green", alpha=0.75, label="init. position")
41 |     plt.scatter(5.0, 5.0, s=200, color="purple", alpha=0.75, label="target position")
42 |
43 |     plt.xlim(-0.5, 7)
44 |     plt.ylim(-0.5, 7)
45 |     plt.grid()
46 |     plt.show()
47 |
--------------------------------------------------------------------------------