├── dmpo ├── __init__.py ├── envs │ ├── __init__.py │ ├── crazyflie2.stl │ ├── math_utils.py │ ├── quadrotor_param.py │ └── quadrotor.py ├── models │ ├── __init__.py │ ├── net_utils.py │ └── mlp.py ├── mpc │ ├── __init__.py │ ├── model │ │ ├── __init__.py │ │ └── quadrotor_model.py │ ├── rollout │ │ ├── __init__.py │ │ ├── rollout_base.py │ │ ├── rollout_generator.py │ │ └── quadrotor_rollout.py │ └── task │ │ ├── __init__.py │ │ ├── quadrotor_rollout_task.py │ │ └── base_rollout_task.py ├── controllers │ ├── __init__.py │ └── dmpo_policy.py ├── dataset │ ├── __init__.py │ └── dataset_buffer.py ├── experiment │ ├── __init__.py │ ├── experiment.py │ ├── experiment_utils.py │ ├── ppo_rollout.py │ ├── ppo_trainer.py │ └── ppo_experiment.py └── utils.py ├── scripts ├── crazyflie2.stl ├── ppo_main.py └── run_dmpo_quadrotor.py ├── requirements.txt ├── environment.yml ├── setup.py ├── LICENSE ├── config ├── envs │ └── zigzagyaw.yaml ├── mpc │ └── quadrotor_zigzagyaw_mppi.yml └── experiments │ └── quadrotor_dmpo_zigzagyaw.yml └── README.md /dmpo/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dmpo/envs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dmpo/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dmpo/mpc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dmpo/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dmpo/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dmpo/experiment/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dmpo/mpc/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dmpo/mpc/rollout/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dmpo/mpc/task/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/crazyflie2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jisacks/dmpo/HEAD/scripts/crazyflie2.stl -------------------------------------------------------------------------------- /dmpo/envs/crazyflie2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jisacks/dmpo/HEAD/dmpo/envs/crazyflie2.stl -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | torch==2.1.0 2 | pytorch-lightning==2.0.9 3 | numpy 4 | tensorboard 5 | ghalton 6 | tqdm 7 | meshcat 8 | rowan 9 | numpy-quaternion -------------------------------------------------------------------------------- /dmpo/mpc/rollout/rollout_base.py: -------------------------------------------------------------------------------- 1 | class RolloutBase: 2 | def __init__(self): 3 | pass 4 | 5 | def cost_fn(self, state, act): 6 | pass 7 | 8 | def rollout_fn(self, state, act): 9 | pass 10 | 11 | def current_cost(self, current_state): 12 | pass 13 | 14 | def update_params(self): 15 | pass 16 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dmpo 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.8 7 | - pytorch=2.1.0 8 | - torchvision 9 | - torchaudio 10 | - pytorch-cuda=11.8 11 | - numpy 12 | - pip 13 | - pip: 14 | - lightning==2.0.9 15 | - tensorboard 16 | - ghalton 17 | - tqdm 18 | - meshcat 19 | - rowan 20 | - numpy-quaternion 21 | - pyyaml -------------------------------------------------------------------------------- /dmpo/models/net_utils.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | from .mlp import MLP 3 | 4 | from typing import Optional 5 | 6 | def create_net(net_type: str, in_size: Optional[int]=None, out_size: Optional[int]=None, **kwargs) -> Module: 7 | if net_type == 'mlp': 8 | net = MLP(in_size=in_size, out_size=out_size, **kwargs) 9 | else: 10 | raise ValueError('Invalid network type {} specified'.format(net_type)) 11 | return net 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Runs the installation 2 | from setuptools import find_packages, setup 3 | 4 | # Avoids duplication of requirements 5 | with open("requirements.txt") as file: 6 | requirements = file.read().splitlines() 7 | 8 | setup( 9 | name="dmpo", 10 | author="Jacob Sacks", 11 | author_email="jsacks6@cs.washington.edu", 12 | description="PyTorch code for the paper Deep Model Predictive Optimization", 13 | url="https://github.com/jisacks/dmpo", 14 | install_requires=requirements, 15 | include_package_data=True, 16 | python_requires=">=3.8", 17 | version='1.0.0', 18 | packages=find_packages(), 19 | ) -------------------------------------------------------------------------------- /scripts/ppo_main.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import os 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('-c', '--config', type=str, help='Config file to load.') 8 | args = parser.parse_args() 9 | 10 | # Load the configuration file 11 | config_file = args.config 12 | config = yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader) 13 | 14 | # Set the seed 15 | seed = config.get('seed', 0) 16 | 17 | # Create the experiment 18 | from dmpo.experiment.ppo_experiment import PPOExperiment 19 | exp = PPOExperiment(**config) 20 | 21 | import numpy as np 22 | import torch 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | 26 | exp.run() -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jake Sacks 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /config/envs/zigzagyaw.yaml: -------------------------------------------------------------------------------- 1 | # simulation parameters 2 | sim_dt: 0.02 3 | sim_tf: 4. 4 | traj: 'zig-zag-yaw' 5 | Vwind: 0 # velocity of wind in world frame, 0 means not considering wind 6 | initial_state: [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.01, 0.] 7 | min_dt: 0.6 # 0.6 8 | max_dt: 1.5 # 1.5 9 | 10 | mass: 0.04 11 | J: [16.571710, 16.655602, 29.261652] 12 | a_min: [0, 0, 0, 0] # bounds of sampling action: [thrust, omega (unit: rad/s)] 13 | a_max: [12, 12, 12, 12] 14 | noise_process_std: [0.3, 2] 15 | 16 | # MPPI parameters 17 | sim_dt_MPPI: 0.02 18 | lam: 0.003 # temparature 19 | H: 40 # horizon 20 | N: 8192 # number of samples 21 | sample_std: [0.25, 0.1, 2., 0.02] # standard deviation for sampling: [thrust (unit: hovering thrust), omega (unit: rad/s)] 22 | gamma_mean: 0.9 # learning rate 23 | gamma_Sigma: 0. # learning rate 24 | omega_gain: 40. 
# gain of the low-level controller 25 | discount: 0.99 # discount factor in MPPI 26 | 27 | # reward functions 28 | alpha_p: 0.05 29 | alpha_z: 0.0 30 | alpha_w: 0.0 31 | alpha_a: 0.0 32 | alpha_R: 0.05 33 | alpha_v: 0.0 34 | alpha_yaw: 0.0 35 | alpha_pitch: 0.0 36 | alpha_u_delta: 0.0 37 | alpha_u_thrust: 0.01 38 | alpha_u_omega: 0.01 -------------------------------------------------------------------------------- /dmpo/mpc/rollout/rollout_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import Tensor 4 | from typing import Optional, Union, Dict, Any, Callable 5 | 6 | class RolloutGenerator(): 7 | """ 8 | Class which handles rolling out dynamics models 9 | """ 10 | def __init__(self, 11 | rollout_fn: Optional[Callable[[Tensor, Tensor], dict]]=None, 12 | tensor_args: Dict[str, Any]={'device': 'cpu', 'dtype': torch.float32}) -> None: 13 | self._rollout_fn = rollout_fn 14 | self.tensor_args = tensor_args 15 | 16 | @property 17 | def rollout_fn(self): 18 | return self._rollout_fn 19 | 20 | @rollout_fn.setter 21 | def rollout_fn(self, fn: Callable[[Tensor, Tensor], dict]): 22 | self._rollout_fn = fn 23 | 24 | def run_rollouts(self, state: Tensor, act_seq: Tensor) -> Dict[str, Any]: 25 | state = state.to(**self.tensor_args) 26 | act_seq = act_seq.to(**self.tensor_args) 27 | 28 | trajectories = self._rollout_fn(state, act_seq) 29 | return trajectories 30 | 31 | def update_params(self, kwargs: Dict[str, Any]) -> bool: 32 | return self.rollout_fn.update_params(**kwargs) -------------------------------------------------------------------------------- /config/mpc/quadrotor_zigzagyaw_mppi.yml: -------------------------------------------------------------------------------- 1 | rollout: 2 | env_type: quadrotor 3 | num_envs: 1 4 | n_episodes: 1 5 | ep_length: 200 6 | base_seed: 123 7 | break_if_done: False 8 | use_condition: True 9 | dynamic_env: True 10 | 11 | environment: 12 | config: ../config/envs/zigzagyaw.yaml 13 | action_is_mf: False 14 | use_delay_model: True 15 | delay_coeff: 0.4 16 | # randomize_mass: True 17 | # randomize_delay_coeff: True 18 | # force_pert: True 19 | force_is_z: True 20 | mass_range: [0.7, 1.3] 21 | delay_range: [0.2, 0.6] 22 | force_range: [-3.5, 3.5] 23 | use_obs_noise: True 24 | 25 | model: 26 | config: ../config/envs/zigzagyaw.yaml 27 | action_is_mf: False 28 | use_delay_model: True 29 | delay_coeff: 0.4 30 | 31 | cost: 32 | alpha_p: 0.05 33 | alpha_z: 0.0 34 | alpha_w: 0.0 35 | alpha_a: 0.0 36 | alpha_R: 0.05 37 | alpha_v: 0.0 38 | alpha_yaw: 0.0 39 | alpha_pitch: 0.0 40 | alpha_u_delta: 0.0 41 | alpha_u_thrust: 0.01 42 | alpha_u_omega: 0.01 43 | 44 | env_cost: 45 | alpha_p: 0.05 46 | alpha_z: 0.0 47 | alpha_w: 0.0 48 | alpha_a: 0.0 49 | alpha_R: 0.05 50 | alpha_v: 0.0 51 | alpha_yaw: 0.0 52 | alpha_pitch: 0.0 53 | alpha_u_delta: 0.0 54 | alpha_u_thrust: 0.01 55 | alpha_u_omega: 0.01 56 | -------------------------------------------------------------------------------- /dmpo/mpc/task/quadrotor_rollout_task.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base_rollout_task import BaseRolloutTask 4 | from ..model.quadrotor_model import QuadrotorModel 5 | from ..rollout.quadrotor_rollout import QuadrotorRollout 6 | 7 | from typing import Optional, Dict, Any 8 | 9 | class QuadrotorRolloutTask(BaseRolloutTask): 10 | def __init__(self, 11 | exp_params: Dict[str, Any], 12 | num_envs: Optional[int]=None, 13 | tensor_args: Dict[str, 
Any]={'device': "cpu", 'dtype': torch.float32}, 14 | **kwargs) -> None: 15 | super().__init__( 16 | exp_params=exp_params, 17 | num_envs=num_envs, 18 | tensor_args=tensor_args 19 | ) 20 | 21 | def get_rollout_fn(self, exp_params: Dict[str, Any]): 22 | dynamics_params = exp_params['model'] 23 | self.dynamics_model = QuadrotorModel(tensor_args=self.tensor_args, num_envs=self.num_envs, **dynamics_params) 24 | rollout_fn = QuadrotorRollout(dynamics_model=self.dynamics_model, 25 | tensor_args=self.tensor_args, 26 | num_envs=self.num_envs, 27 | exp_params=exp_params) 28 | return rollout_fn 29 | -------------------------------------------------------------------------------- /dmpo/mpc/task/base_rollout_task.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..rollout.rollout_generator import RolloutGenerator 4 | 5 | from torch import Tensor 6 | from typing import Optional, Dict, Any, List 7 | 8 | class BaseRolloutTask(): 9 | """ 10 | Base class for defining MPC tasks 11 | """ 12 | def __init__(self, 13 | exp_params: Dict[str, Any], 14 | num_envs: Optional[int]=None, 15 | tensor_args: Dict[str, Any]={'device': "cpu", 'dtype': torch.float32}) -> None: 16 | """ 17 | :param exp_params: Rollout parameters 18 | :param num_envs: # of environments (threads) 19 | :param tensor_args: PyTorch params 20 | """ 21 | self.tensor_args = tensor_args 22 | self.num_envs = num_envs 23 | self.running: bool = False 24 | self.top_idx: Optional[List[Tensor]] = None 25 | self.top_values: Optional[List[Tensor]] = None 26 | self.top_trajs: Optional[List[Tensor]] = None 27 | 28 | self.rollout_fn = self.init_rollout(exp_params) 29 | self.init_aux(exp_params, num_envs) 30 | 31 | def init_aux(self, exp_params: Dict[str, Any], num_envs: int): 32 | pass 33 | 34 | def get_rollout_fn(self, exp_params: Dict[str, Any]): 35 | raise NotImplementedError('Function get_rollout_fn has not implemented.') 36 | 37 | def init_rollout(self, exp_params): 38 | rollout_fn = self.get_rollout_fn(exp_params) 39 | return RolloutGenerator(rollout_fn=rollout_fn, tensor_args=self.tensor_args) 40 | 41 | def run_rollouts(self, states: Tensor, act_seqs: Tensor) -> Dict[str, Any]: 42 | trajectories = self.rollout_fn.run_rollouts(states, act_seqs) 43 | return trajectories 44 | 45 | def update_params(self, kwargs: Dict[str, Any]) -> bool: 46 | self.rollout_fn.update_params(kwargs) 47 | return True 48 | -------------------------------------------------------------------------------- /dmpo/experiment/experiment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning.loggers import TensorBoardLogger 3 | 4 | from .experiment_utils import create_model, create_optim 5 | from .. 
import utils 6 | 7 | from torch import Tensor 8 | from torch.nn import Module 9 | from typing import Dict, Any, Optional, List, Union, Tuple, Callable 10 | 11 | class Experiment(): 12 | """ 13 | Base class for all experiments 14 | """ 15 | def __init__(self, 16 | model_config: Dict[str, Any], 17 | optim_config: Dict[str, Any], 18 | n_epochs: int=1, 19 | log_folder: str='logs', 20 | exp_name: str='experiment', 21 | dtype: str='float', 22 | device: str='cuda'): 23 | """ 24 | :param model_config: Configuration for model 25 | :param optim_config: Configuration for optimizer 26 | :param n_epochs: # of epochs to run 27 | :param n_gpus: # of GPUs to utilize 28 | :param log_folder: Folder to use for logging 29 | :param exp_name: Name of experiment 30 | :param dtype: PyTorch datatype to use 31 | :param device: PyTorch device to use 32 | :param val_every: # of epochs per which we run the validation step 33 | :param max_grad_norm: Maximum gradient norm clip 34 | :param lr_scheduler_config: Configuration for learning rate scheduler (optional) 35 | """ 36 | self.n_epochs = n_epochs 37 | 38 | # Set tensor args 39 | self.dtype = utils.TorchDtype[dtype] 40 | self.device = torch.device(device) if torch.cuda.is_available() else torch.device('cpu') 41 | self.tensor_args = {'device':self.device, 'dtype':self.dtype} 42 | 43 | # Create the model 44 | self.model = create_model(tensor_args=self.tensor_args, **model_config) 45 | self.model.to(**self.tensor_args) 46 | 47 | # Create the logger 48 | self.logger = TensorBoardLogger(log_folder, exp_name) 49 | 50 | # Get the optimizers and learning rate schedulers 51 | self.optim, self.optim_args = create_optim(optim_config) 52 | -------------------------------------------------------------------------------- /dmpo/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torch import Tensor 4 | from typing import Optional, List, Union, Dict, Any 5 | 6 | from .. 
import utils 7 | 8 | class MLP(nn.Module): 9 | """ 10 | Class for constructing multilayer perceptrons 11 | """ 12 | def __init__(self, 13 | in_size: int, 14 | out_size: int, 15 | hidden_size: List[int], 16 | act: Union[str, nn.Module], 17 | dropout_prob: Optional[float]=None, 18 | init_scale: Optional[float]=None, 19 | act_params: Dict[str, Any]={}, 20 | last_linear: bool=True) -> None: 21 | super().__init__() 22 | self.in_size = in_size 23 | self.out_size = out_size 24 | self.hidden_size = hidden_size 25 | 26 | # Convert activation string to Module 27 | if isinstance(act, str) and act != 'identity': 28 | act = utils.ActivationType[act] 29 | 30 | # Create the MLP 31 | net = [] 32 | prev_dim = in_size 33 | if not hidden_size is None and len(hidden_size) > 0: 34 | for hidden in hidden_size: 35 | # Construct the linear layer 36 | layer = nn.Linear(prev_dim, hidden) 37 | net.append(layer) 38 | 39 | # Apply dropout 40 | if not dropout_prob is None: 41 | net.append(nn.Dropout(p=dropout_prob)) 42 | 43 | # Append activation 44 | if act != 'identity': 45 | net.append(act(**act_params)) 46 | 47 | prev_dim = hidden 48 | 49 | # Apply terminal linear layer 50 | if last_linear: 51 | layer = nn.Linear(prev_dim, out_size) 52 | if not init_scale is None: 53 | layer.weight.data.normal_(0, init_scale) 54 | layer.bias.data.fill_(0) 55 | net.append(layer) 56 | 57 | self.net = nn.Sequential(*net) 58 | 59 | def forward(self, x: Tensor) -> Tensor: 60 | # Automatically reshape input to be (batch_size, input_dim) 61 | x_shape = x.shape 62 | x = x.reshape(-1, x_shape[-1]) 63 | out = self.net(x) 64 | 65 | # Reshape input to original shape 66 | out = out.reshape(*x_shape[:-1], -1) 67 | return out -------------------------------------------------------------------------------- /config/experiments/quadrotor_dmpo_zigzagyaw.yml: -------------------------------------------------------------------------------- 1 | # Experiment configuration 2 | n_iters: 1000 3 | n_epochs: 1 4 | batch_size: 8 5 | n_gpus: 1 6 | log_folder: quadrotor_logs 7 | exp_name: quadrotor_dmpo_zigzagyaw 8 | dtype: float 9 | env_device: cuda 10 | seed: 0 11 | 12 | # Rollout configuration 13 | train_episode_len: 100 14 | val_episode_len: 200 15 | train_episodes: 1 16 | val_episodes: 1 17 | val_every: 10 18 | break_if_done: False 19 | use_condition: True 20 | n_pretrain_epochs: 10 21 | n_pretrain_steps: 1 22 | n_val_envs: 32 23 | 24 | # Training configuration 25 | max_grad_norm: 1 26 | num_workers: 0 27 | 28 | # Dataset configuration 29 | dataset_config: 30 | discount: 0.99 31 | gae_lambda: 0.95 32 | seq_length: 1 33 | stride: 1 34 | 35 | # Environment configuration 36 | env_name: quadrotor 37 | dynamic_env: True 38 | env_config: 39 | num_envs: 8 40 | config: ../config/envs/zigzagyaw.yaml 41 | action_is_mf: False 42 | use_delay_model: True 43 | delay_coeff: 0.4 44 | mass_range: [0.7, 1.3] 45 | delay_range: [0.2, 0.6] 46 | force_range: [-3.5, 3.5] 47 | force_is_z: True 48 | use_obs_noise: True 49 | train_env_config: 50 | randomize_mass: True 51 | randomize_delay_coeff: True 52 | force_pert: True 53 | val_env_config: 54 | randomize_mass: True 55 | randomize_delay_coeff: True 56 | force_pert: True 57 | 58 | # Task configuration 59 | task_config: 60 | task_config_file: ../config/mpc/quadrotor_zigzagyaw_mppi.yml 61 | 62 | # Model configuration 63 | model_config: 64 | model_type: dmpo_policy 65 | d_action: 4 66 | d_state: 13 67 | num_particles: 512 68 | horizon: 32 69 | gamma: 1.0 70 | top_k: 8 71 | init_mean: [0.3924, 0, 0 ,0] 72 | init_std: [0.1, 1., 1., 1.] 
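  # Note (added): with action_is_mf: False in env_config, the 4-dim action is (thrust, body rates); 0.3924 N in init_mean is the hover thrust of the 0.04 kg quadrotor (0.04 * 9.81).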
73 | mean_search_std: [0.1, 1., 1., 1.] 74 | std_search_std: [0.01, 0.1, 0.1, 0.1] 75 | learn_search_std: True 76 | learn_rollout_std: True 77 | is_delta: False 78 | is_gated: True 79 | is_residual: True 80 | d_cond: 56 81 | cond_mode: cat 82 | cond_actor: False 83 | cond_critic: True 84 | cond_shift: False 85 | critic_use_cost: False 86 | actor_use_state: False 87 | state_scale: null 88 | cond_scale: null 89 | mppi_params: 90 | temperature: 0.05 91 | step_size: 0.8 92 | scale_costs: True 93 | actor_params: 94 | net_type: mlp 95 | hidden_size: [256] 96 | act: relu 97 | init_scale: 0.001 98 | critic_params: 99 | net_type: mlp 100 | hidden_size: [1024] 101 | act: relu 102 | init_scale: 0.001 103 | shift_params: 104 | net_type: mlp 105 | hidden_size: [256] 106 | act: relu 107 | init_scale: 0.001 108 | 109 | # Optimizer configuration 110 | actor_optim_config: 111 | optim: Adam 112 | optim_args: 113 | lr: 0.000001 114 | 115 | critic_optim_config: 116 | optim: Adam 117 | optim_args: 118 | lr: 0.0001 119 | 120 | # Trainer configuration 121 | trainer_config: 122 | clip_epsilon: 0.2 123 | std_clip_epsilon: 0.2 124 | entropy_penalty: 0. 125 | kl_penalty: 0.0 126 | model_subsets: [[actor, shift_model], [critic]] 127 | 128 | -------------------------------------------------------------------------------- /dmpo/experiment/experiment_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..envs.quadrotor import QuadrotorEnv 5 | from ..mpc.task.base_rollout_task import BaseRolloutTask 6 | from ..mpc.task.quadrotor_rollout_task import QuadrotorRolloutTask 7 | from ..controllers.dmpo_policy import DMPOPolicy 8 | from .. import utils 9 | 10 | from torch.nn import Module 11 | from typing import Dict, Any, Optional, List, Tuple, Callable 12 | 13 | ModelTypes = ['dmpo_policy'] 14 | EnvTypes = ['quadrotor'] 15 | 16 | def create_model(model_type: str, 17 | tensor_args: Dict[str, Any] = {'device': 'cpu', 'dtype': torch.float}, 18 | **kwargs: Dict[str, Any]) -> Module: 19 | if model_type in ModelTypes: 20 | if model_type == 'dmpo_policy': 21 | return DMPOPolicy(tensor_args=tensor_args, **kwargs) 22 | else: 23 | raise ValueError('Invalid model type {} specified.'.format(model_type)) 24 | 25 | def create_env(env_name: str, 26 | env_config_file: Optional[str]=None, 27 | tensor_args: Dict[str, Any]={'device': 'cpu', 'dtype': torch.float}, 28 | **kwargs: Dict[str, Any]): 29 | if env_name in EnvTypes: 30 | # Load in the environment configuration file 31 | if not env_config_file is None: 32 | env_params = utils.load_yaml(env_config_file) 33 | 34 | # Allow for kwargs to overwrite configuration file 35 | for k, v in env_params.items(): 36 | if not k in kwargs.keys(): 37 | kwargs[k] = v 38 | 39 | # Handle each environment separately 40 | if env_name == 'quadrotor': 41 | return QuadrotorEnv(tensor_args=tensor_args, **kwargs) 42 | else: 43 | raise ValueError('Specified environment type {} is unsupported.'.format(env_name)) 44 | 45 | def create_task(env_name: str, 46 | num_envs: int, 47 | task_config_file: str, 48 | tensor_args: Dict[str, Any] = {'device': 'cpu', 'dtype': torch.float}, 49 | **kwargs: Dict[str, Any]) -> BaseRolloutTask: 50 | if env_name in EnvTypes: 51 | # Load in the task configuration file 52 | task_params = utils.load_yaml(task_config_file) 53 | 54 | # Handle each environment separately 55 | if env_name == 'quadrotor': 56 | return QuadrotorRolloutTask(exp_params=task_params, 57 | num_envs=num_envs, 58 | 
tensor_args=tensor_args, 59 | **kwargs) 60 | else: 61 | raise ValueError('Specified environment type {} is unsupported.'.format(env_name)) 62 | 63 | def create_optim(config: Dict[str, Any]) -> Tuple[Callable, Dict[str, Any]]: 64 | optim_type = config.get('optim', 'Adam') 65 | optim_args = config.get('optim_args', {}) 66 | 67 | if optim_type == 'Adam': 68 | optim = torch.optim.Adam 69 | elif optim_type == 'SGD': 70 | optim = torch.optim.SGD 71 | elif optim_type == 'Adagrad': 72 | optim = torch.optim.Adagrad 73 | elif optim_type == 'RMSprop': 74 | optim = torch.optim.RMSprop 75 | elif optim_type == 'RAdam': 76 | optim = torch.optim.RAdam 77 | else: 78 | raise ValueError('Specified optimizer type {} unsupported.'.format(optim_type)) 79 | 80 | return optim, optim_args -------------------------------------------------------------------------------- /scripts/run_dmpo_quadrotor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from dmpo.envs.quadrotor import QuadrotorEnv 4 | from dmpo.experiment.ppo_rollout import ppo_rollout 5 | from dmpo.experiment.experiment_utils import create_model, create_task 6 | from dmpo import utils 7 | 8 | if __name__ == '__main__': 9 | # Main script configuration 10 | mpc_file = '../config/mpc/quadrotor_zigzagyaw_mppi.yml' 11 | train_file = '../config/experiments/quadrotor_dmpo_zigzagyaw.yml' 12 | model_file = 'quadrotor_logs/quadrotor_dmpo_zigzagyaw/version_0/checkpoints/best.pt' 13 | use_gpu = True 14 | is_mppi = True 15 | 16 | # Set up PyTorch 17 | device = 'cuda' if use_gpu else 'cpu' 18 | env_tensor_args = {'device': device, 'dtype': torch.double} 19 | ctrl_tensor_args = {'device': device, 'dtype': torch.double} 20 | 21 | # Get environment 22 | exp_params = utils.load_yaml(mpc_file) 23 | rollout_params = exp_params['rollout'] 24 | env_params = exp_params['environment'] 25 | env_cost_params = exp_params['env_cost'] 26 | cost_params = exp_params['cost'] 27 | 28 | env_type = rollout_params.pop('env_type') 29 | num_envs = rollout_params.pop('num_envs') 30 | env = QuadrotorEnv(num_envs=num_envs, tensor_args=env_tensor_args, **env_params) 31 | 32 | # Instantiate controller from configuration 33 | if not is_mppi: 34 | model_dict = torch.load(model_file, map_location='cpu') 35 | 36 | state_dict = model_dict['state_dict'] 37 | model_config = model_dict['model_config'] 38 | 39 | train_config = utils.load_yaml(train_file) 40 | task_config = train_config['task_config'] 41 | else: 42 | train_config = utils.load_yaml(train_file) 43 | model_config = train_config['model_config'] 44 | task_config = train_config['task_config'] 45 | 46 | if is_mppi: 47 | model_config['mppi_mode'] = True 48 | model_config['horizon'] = 32 49 | model_config['num_particles'] = 2048 50 | model_config['n_iters'] = 1 51 | task_config['task_config_file'] = mpc_file 52 | 53 | model = create_model(tensor_args=ctrl_tensor_args, **model_config) 54 | if not is_mppi: 55 | model.load_state_dict(state_dict) 56 | model.to(**ctrl_tensor_args) 57 | 58 | model.action_lows = utils.to_tensor(env.action_lows, ctrl_tensor_args) 59 | model.action_highs = utils.to_tensor(env.action_highs, ctrl_tensor_args) 60 | model.use_mean = True 61 | 62 | # Create the MPC task for performing rollouts 63 | task = create_task(env_name=env_type, 64 | num_envs=num_envs, 65 | tensor_args=ctrl_tensor_args, 66 | **task_config) 67 | model.set_task(task) 68 | 69 | # Perform the rollouts 70 | with torch.no_grad(): 71 | trajectories = ppo_rollout(env=env, 72 | model=model, 73 | 
tensor_args=ctrl_tensor_args, 74 | save_data=False, 75 | run_critic=False, 76 | **rollout_params) 77 | 78 | # Compute statistics 79 | success_dict = env.evaluate_success(trajectories) 80 | stat_dict = utils.compute_statistics(success_dict) 81 | 82 | success_percentage = success_dict['success_percentage'] 83 | mean_success_cost = stat_dict['mean_cost'] 84 | std_success_cost = stat_dict['std_cost'] 85 | 86 | print('Success Metric = {:.2f}, Mean Cost = {:.3e} +/- {:.3e}'.format( 87 | success_percentage, 88 | mean_success_cost, 89 | std_success_cost, 90 | )) 91 | 92 | # Visualize trajectory 93 | states = trajectories[0]['states'] 94 | 95 | mean_state_samples = None 96 | mppi_state_samples = None 97 | 98 | ref_traj = env.ref_trajectory[0].cpu().T 99 | env.visualize(states, ref_traj, env.avg_dt) 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | DMPO
3 |
4 |
5 | Deep Model Predictive Optimization
6 |
7 |
8 |
9 | Jacob Sacks •
10 | Rwik Rana •
11 | Kevin Huang •
12 | Alex Spitzer •
13 | Guanya Shi •
14 | Byron Boots
15 |
16 |
17 |
18 |
19 |
20 | Paper |
21 | Website
22 |

23 | 24 | A major challenge in robotics is to design robust policies which enable complex and agile behaviors in the real world. 25 | On one end of the spectrum, we have model-free reinforcement learning (MFRL), which is incredibly flexible and general 26 | but often results in brittle policies. In contrast, model predictive control (MPC) continually re-plans at each time 27 | step to remain robust to perturbations and model inaccuracies. However, despite its real-world successes, MPC often 28 | under-performs the optimal strategy. This is due to model quality, myopic behavior from short planning horizons, and 29 | approximations due to computational constraints. And even with a perfect model and enough compute, MPC can get stuck in 30 | bad local optima, depending heavily on the quality of the optimization algorithm. To this end, we propose Deep Model 31 | Predictive Optimization (DMPO), which learns the inner-loop of an MPC optimization algorithm directly via experience, 32 | specifically tailored to the needs of the control problem. We evaluate DMPO on a real quadrotor agile trajectory 33 | tracking task, on which it improves performance over a baseline MPC algorithm for a given computational budget. 34 | It can outperform the best MPC algorithm by up to 27% with fewer samples and an end-to-end policy trained with MFRL 35 | by 19%. Moreover, because DMPO requires fewer samples, it can also achieve these benefits with 4.3X less memory. 36 | When we subject the quadrotor to turbulent wind fields with an attached drag plate, DMPO can adapt zero-shot while 37 | still outperforming all baselines. 38 | 39 | ## Installation 40 | Create a new conda environment with: 41 | ``` 42 | conda env create -f environment.yml 43 | conda activate dmpo 44 | ``` 45 | and then install the repository with: 46 | ``` 47 | cd src 48 | pip install -e . 49 | cd .. 50 | ``` 51 | ## Test that the MPPI baseline works correctly 52 | We provide a test script in ```scripts/run_dmpo_quadrotor.py```. There is a ```is_mppi``` flag in the configuration section 53 | of the script, which if True, will run DMPO in MPPI mode (no learned residual). This is a good sanity check to make sure 54 | things installed properly. 55 | The script will display the total cost of the trajectory, and enable visualization in the browser via meshcat. 56 | You can run the script by: 57 | ``` 58 | cd scripts 59 | python run_dmpo_quadrotor.py 60 | ``` 61 | To test a trained model, turn the ```is_mppi``` off and specify the model location as the ```model_file``` variable. 62 | 63 | ## How to train a DMPO policy 64 | All experiment configurations are specified via YAML files. 65 | We provide an example for training a quadrotor to perform a zig-zag with yaw flips in 66 | ```config/experiments/quadrotor_dmpo_zigzagyaw.yml```. 67 | To run a training session with this configuration file, perform the following commands: 68 | ``` 69 | python ppo_main.py --config ../config/experiments/quadrotor_dmpo_zigzagyaw.yml 70 | ``` 71 | Once the model is done training, provide the correct path to the ```run_dmpo_quadrotor.py``` test script, run an episode, 72 | and visualize with meshcat. 73 | 74 | ## License 75 | The majority of DMPO is licensed under MIT license, however portions of the project are available under separate license 76 | terms. Pytorch-Lightning is under the Apache License 2.0 license. 77 | See [LICENSE](https://github.com/jisacks/dmpo/blob/main/LICENSE) for details. 
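As a supplement to the training section above, the snippet below is a minimal sketch of launching the same training run programmatically instead of through the command line. It simply mirrors what `scripts/ppo_main.py` already does (load the YAML, seed NumPy and PyTorch, construct `PPOExperiment` from the top-level keys, and call `run()`); the config path is assumed to be resolved relative to the `scripts/` directory.
```
import yaml
import numpy as np
import torch

from dmpo.experiment.ppo_experiment import PPOExperiment

# Load the experiment configuration (path assumed relative to scripts/)
with open('../config/experiments/quadrotor_dmpo_zigzagyaw.yml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Seed NumPy and PyTorch, as ppo_main.py does
seed = config.get('seed', 0)
np.random.seed(seed)
torch.manual_seed(seed)

# Every top-level key in the YAML file is forwarded as a keyword argument
exp = PPOExperiment(**config)
exp.run()
```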
78 | -------------------------------------------------------------------------------- /dmpo/envs/math_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | # Quaternion routines adapted from rowan to use autograd 5 | def qmultiply(q1, q2): 6 | return np.concatenate(( 7 | np.array([q1[0] * q2[0] - np.sum(q1[1:4] * q2[1:4])]), # w1w2 8 | q1[0] * q2[1:4] + q2[0] * q1[1:4] + np.cross(q1[1:4], q2[1:4]))) 9 | 10 | def qconjugate(q): 11 | return np.concatenate((q[0:1],-q[1:4])) 12 | 13 | def qrotate(q, v): 14 | quat_v = np.concatenate((np.array([0]), v)) 15 | return qmultiply(q, qmultiply(quat_v, qconjugate(q)))[1:] 16 | 17 | def qexp(q): 18 | norm = np.linalg.norm(q[1:4]) 19 | e = np.exp(q[0]) 20 | result_w = e * np.cos(norm) 21 | if np.isclose(norm, 0): 22 | result_v = np.zeros(3) 23 | else: 24 | result_v = e * q[1:4] / norm * np.sin(norm) 25 | return np.concatenate((np.array([result_w]), result_v)) 26 | 27 | def qintegrate(q, v, dt, frame='body'): 28 | quat_v = np.concatenate((np.array([0]), v*dt/2)) 29 | if frame == 'body': 30 | return qmultiply(q, qexp(quat_v)) 31 | if frame == 'world': 32 | return qmultiply(qexp(quat_v), q) 33 | 34 | def qstandardize(q): 35 | if q[0] < 0: 36 | q *= -1 37 | return q / np.linalg.norm(q) 38 | 39 | def qtoR(q): 40 | q0 = q[0] 41 | q1 = q[1] 42 | q2 = q[2] 43 | q3 = q[3] 44 | 45 | # First row of the rotation matrix 46 | r00 = 2 * (q0 * q0 + q1 * q1) - 1 47 | r01 = 2 * (q1 * q2 - q0 * q3) 48 | r02 = 2 * (q1 * q3 + q0 * q2) 49 | 50 | # Second row of the rotation matrix 51 | r10 = 2 * (q1 * q2 + q0 * q3) 52 | r11 = 2 * (q0 * q0 + q2 * q2) - 1 53 | r12 = 2 * (q2 * q3 - q0 * q1) 54 | 55 | # Third row of the rotation matrix 56 | r20 = 2 * (q1 * q3 - q0 * q2) 57 | r21 = 2 * (q2 * q3 + q0 * q1) 58 | r22 = 2 * (q0 * q0 + q3 * q3) - 1 59 | 60 | # 3x3 rotation matrix 61 | rot_matrix = np.array([[r00, r01, r02], \ 62 | [r10, r11, r12], \ 63 | [r20, r21, r22]]) 64 | 65 | return rot_matrix 66 | 67 | 68 | # torch versions: 69 | def sqaured_distance_torch(x, x_d): 70 | ''' 71 | x: actual state, tensor, (N, m) 72 | x_d: desired, tensor, (m,) 73 | output: squared distance, tensor, (N,) 74 | ''' 75 | return torch.einsum('ij,ij->i', x-x_d, x-x_d) 76 | 77 | def qdistance_torch(q1, q2): 78 | ''' 79 | q1: tensor, (N, 4) 80 | q2: tensor, (N, 4) 81 | output: tensor, (N,) 82 | distance = 1 - ^2 83 | ''' 84 | return 1 - torch.einsum('bi,bi->b', q1, q2)**2 85 | 86 | def qmultiply_torch(q1, q2): 87 | ''' 88 | q1: tensor, (N, 4) 89 | q2: tensor, (N, 4) 90 | output: tensor, (N, 4) 91 | ''' 92 | temp = torch.zeros_like(q1) 93 | temp[:, 0] = q1[:,0] * q2[:,0] - torch.sum(q1[:,1:] * q2[:,1:], dim=1) 94 | temp[:, 1:] = q1[:,0:1] * q2[:,1:] + q2[:,0:1] * q1[:,1:] + torch.cross(q1[:,1:],q2[:,1:], dim=1) 95 | return temp 96 | 97 | def qconjugate_torch(q): 98 | temp = torch.zeros_like(q) 99 | if len(q.shape) == 1: 100 | temp[0] = q[0] 101 | temp[1:] = -q[1:] 102 | else: 103 | temp[:,0] = q[:,0] 104 | temp[:,1:] = -q[:,1:] 105 | return temp 106 | 107 | def qrotate_torch(q, v): 108 | ''' 109 | q: tensor, (N, 4) 110 | v: tensor, (N, 3) 111 | output: tensor, (N, 3) 112 | ''' 113 | # a more efficient way suggested by Alex: 114 | temp = 2. 
* torch.cross(q[:, 1:], v) 115 | return v + q[:, 0:1] * temp + torch.cross(q[:, 1:], temp) 116 | 117 | quat_v = torch.zeros_like(q) 118 | quat_v[:, 1:] = v 119 | return qmultiply_torch(q, qmultiply_torch(quat_v, qconjugate_torch(q)))[:,1:] 120 | 121 | def qexp_torch(q): 122 | ''' 123 | q: tensor, (N, 4) 124 | output: tensor, (N, 4) 125 | ''' 126 | norm = torch.linalg.norm(q[:,1:], dim=1) 127 | e = torch.exp(q[:,0]) 128 | result_w = e * torch.cos(norm) 129 | 130 | N = q.shape[0] 131 | result_v = torch.zeros_like(q[:,1:]) 132 | result_v[norm>0] = e.view(N,1)[norm>0] * q[norm>0,1:] / norm.view(N,1)[norm>0] * torch.sin(norm).view(N,1)[norm>0] 133 | 134 | return torch.cat((result_w.view(N,1), result_v), dim=1) 135 | 136 | def qintegrate_torch(q, v, dt, frame='body'): 137 | ''' 138 | q: tensor, (N, 4) 139 | v: tensor, (N, 3) 140 | output: tensor, (N, 4) 141 | ''' 142 | quat_v = torch.zeros_like(q) 143 | quat_v[:,1:] = v * dt / 2. 144 | if frame == 'body': 145 | return qmultiply_torch(q, qexp_torch(quat_v)) 146 | if frame == 'world': 147 | return qmultiply_torch(qexp_torch(quat_v), q) 148 | 149 | def qstandardize_torch(q): 150 | ''' 151 | q: tensor, (N, 4) 152 | output: tensor, (N, 4) 153 | ''' 154 | return torch.where(q[:, 0:1] < 0, -q, q) 155 | 156 | def qtoR_torch(q): 157 | ''' 158 | q: tensor, (N, 4) 159 | output: rotation matrix tensor, (N, 3, 3) 160 | ''' 161 | q0 = q[:, 0] 162 | q1 = q[:, 1] 163 | q2 = q[:, 2] 164 | q3 = q[:, 3] 165 | R = torch.zeros(q.shape[0], 3, 3).to(q) 166 | 167 | # First row of the rotation matrix 168 | R[:, 0, 0] = 2 * (q0 * q0 + q1 * q1) - 1 169 | R[:, 0, 1] = 2 * (q1 * q2 - q0 * q3) 170 | R[:, 0, 2] = 2 * (q1 * q3 + q0 * q2) 171 | 172 | # Second row of the rotation matrix 173 | R[:, 1, 0] = 2 * (q1 * q2 + q0 * q3) 174 | R[:, 1, 1] = 2 * (q0 * q0 + q2 * q2) - 1 175 | R[:, 1, 2] = 2 * (q2 * q3 - q0 * q1) 176 | 177 | # Third row of the rotation matrix 178 | R[:, 2, 0] = 2 * (q1 * q3 - q0 * q2) 179 | R[:, 2, 1] = 2 * (q2 * q3 + q0 * q1) 180 | R[:, 2, 2] = 2 * (q0 * q0 + q3 * q3) - 1 181 | 182 | return R 183 | 184 | def get_quaternion_from_euler(roll, pitch, yaw): 185 | qx = np.sin(roll / 2) * np.cos(pitch / 2) * np.cos(yaw / 2) - np.cos(roll / 2) * np.sin(pitch / 2) * np.sin(yaw / 2) 186 | qy = np.cos(roll / 2) * np.sin(pitch / 2) * np.cos(yaw / 2) + np.sin(roll / 2) * np.cos(pitch / 2) * np.sin(yaw / 2) 187 | qz = np.cos(roll / 2) * np.cos(pitch / 2) * np.sin(yaw / 2) - np.sin(roll / 2) * np.sin(pitch / 2) * np.cos(yaw / 2) 188 | qw = np.cos(roll / 2) * np.cos(pitch / 2) * np.cos(yaw / 2) + np.sin(roll / 2) * np.sin(pitch / 2) * np.sin(yaw / 2) 189 | return np.stack([qx, qy, qz, qw], axis=-1) -------------------------------------------------------------------------------- /dmpo/envs/quadrotor_param.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch import Tensor 5 | from numpy.typing import NDArray 6 | from typing import Dict, Any, Optional, List, Union, Tuple 7 | 8 | class QuadrotorParam(): 9 | """ 10 | Class containing parameters for quadrotor model 11 | """ 12 | def __init__(self, 13 | config: Dict[str, Any], 14 | MPPI: bool=False): 15 | """ 16 | :param config: dictionary containing configuration 17 | :param MPPI: indicates if we should use the MPPI controller parameters (which may be different) 18 | """ 19 | self.env_name = 'Quadrotor-Crazyflie' 20 | 21 | # Quadrotor parameters 22 | self.sim_t0 = 0 23 | self.sim_tf = config['sim_tf'] 24 | self.sim_dt = config['sim_dt'] 25 
| self.sim_times = np.arange(self.sim_t0, self.sim_tf, self.sim_dt) 26 | if MPPI: 27 | self.sim_dt = config['sim_dt_MPPI'] 28 | self.sim_times = np.arange(self.sim_t0, self.sim_tf, self.sim_dt) 29 | 30 | # Control limits [N] for motor forces 31 | self.a_min = np.array(config.get('a_min', [0., 0, 0, 0])) 32 | self.a_max = np.array(config.get('a_max', [12., 12., 12., 12])) / 1000 * 9.81 # g->N 33 | 34 | # Crazyflie 2.0 quadrotor.py 35 | self.mass = config.get('mass', 0.034) #kg 36 | self.J = np.array(config.get('J', [16.571710, 16.655602, 29.261652])) * 1e-6 37 | self.d = 0.047 38 | 39 | # Side force model parameters for wind perturbations 40 | if config['Vwind'] == 0: 41 | self.wind = False 42 | self.Vwind = None 43 | else: 44 | self.wind = True 45 | self.Vwind = np.array(config['Vwind']) # velocity of wind in world frame 46 | self.Ct = 2.87e-3 47 | self.Cs = 2.31e-5 48 | self.k1 = 1.425 49 | self.k2 = 3.126 50 | self.rho = 1.225 # air density (in SI units) 51 | 52 | # Note: we assume here that our control is forces 53 | arm_length = 0.046 # m 54 | arm = 0.707106781 * arm_length 55 | t2t = 0.006 # thrust-to-torque ratio 56 | self.t2t = t2t 57 | self.B0 = np.array([ 58 | [1, 1, 1, 1], 59 | [-arm, -arm, arm, arm], 60 | [-arm, arm, arm, -arm], 61 | [-t2t, t2t, -t2t, t2t] 62 | ]) 63 | self.g = 9.81 # not signed 64 | 65 | # Exploration parameters: state boundary and initial state sampling range 66 | self.s_min = np.array( \ 67 | [-8, -8, -8, \ 68 | -5, -5, -5, \ 69 | -1.001, -1.001, -1.001, -1.001, 70 | -20, -20, -20]) 71 | self.rpy_limit = np.array([5, 5, 5]) 72 | self.limits = np.array([0.1,0.1,0.1,0.1,0.1,0.1,0,0,0,0,0,0,0]) 73 | 74 | # Measurement noise 75 | self.noise_measurement_std = np.zeros(13) 76 | self.noise_measurement_std[:3] = 0.005 77 | self.noise_measurement_std[3:6] = 0.005 78 | self.noise_measurement_std[6:10] = 0.01 79 | self.noise_measurement_std[10:] = 0.01 80 | 81 | # Process noise 82 | self.noise_process_std = config.get('noise_process_std', [0.3, 2.]) 83 | 84 | # Reference trajectory parameters 85 | self.ref_type = config['traj'] 86 | self.max_dist = config.get('max_dist', [1., 1., 0.]) 87 | self.min_dt = config.get('min_dt', 0.6) 88 | self.max_dt = config.get('max_dt', 1.5) 89 | 90 | # Cost function parameters 91 | self.alpha_p = config['alpha_p'] 92 | self.alpha_z = config['alpha_z'] 93 | self.alpha_w = config['alpha_w'] 94 | self.alpha_a = config['alpha_a'] 95 | self.alpha_R = config['alpha_R'] 96 | self.alpha_v = config['alpha_v'] 97 | self.alpha_yaw = config['alpha_yaw'] 98 | self.alpha_pitch = config['alpha_pitch'] 99 | self.alpha_u_delta = config['alpha_u_delta'] 100 | self.alpha_u_thrust = config['alpha_u_thrust'] 101 | self.alpha_u_omega = config['alpha_u_omega'] 102 | 103 | def get_reference(self, 104 | num_envs: int, 105 | dts: Optional[NDArray]=None, 106 | pos: Optional[NDArray]=None) -> Tuple[Tensor, NDArray, NDArray]: 107 | """ 108 | :param num_envs: # of reference trajectories to generate (one per environment) 109 | :param dts: delta in time between waypoints (randomly generated if not specified) 110 | :param pos: waypoints for zig-zag (randomly generated if not specified) 111 | """ 112 | 113 | self.ref_trajectory = np.zeros((num_envs, 13, len(self.sim_times))) 114 | self.ref_trajectory[:, 6, :] = 1. 
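        # Added note: the 13-dim state layout is [x, y, z, vx, vy, vz, qw, qx, qy, qz, wx, wy, wz],
        # so index 6 is qw and the reference starts at the identity attitude; the zig-zag-yaw branch
        # below toggles indices 6 and 9 (qw/qz) to command 180-degree yaw flips between waypoints.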
115 | 116 | if self.ref_type == 'zig-zag-yaw': 117 | if dts is None: 118 | dts = np.random.uniform(self.min_dt, self.max_dt, size=(num_envs, len(self.sim_times), 1)) 119 | dts = dts.repeat(3, axis=2) 120 | d_times = dts.cumsum(axis=1) 121 | if pos is None: 122 | pos = np.random.uniform(0, np.array(self.max_dist), size=(num_envs, len(self.sim_times)//2, 3)) 123 | 124 | for env_idx in range(num_envs): 125 | for p_idx in range(3): 126 | for step, time in enumerate(self.sim_times[1:]): 127 | ref_idx = np.searchsorted(d_times[env_idx, :, p_idx], time) 128 | sign = 1 if np.ceil(ref_idx/2 - ref_idx//2) == 0 else -1 129 | ref_pos = sign * pos[env_idx, ref_idx, p_idx] 130 | prev_ref_pos = -1 * sign * pos[env_idx, ref_idx-1, p_idx] if ref_idx > 0 else 0 131 | ref_time = d_times[env_idx, ref_idx, p_idx] 132 | prev_ref_time = d_times[env_idx, ref_idx-1, p_idx] if ref_idx > 0 else 0 133 | cur_pos = self.ref_trajectory[env_idx, p_idx, step] 134 | 135 | delta = ref_pos - prev_ref_pos 136 | if delta != 0: 137 | delta = delta/(ref_time-prev_ref_time)*self.sim_dt 138 | self.ref_trajectory[env_idx, p_idx, step+1] = cur_pos + delta 139 | self.ref_trajectory[env_idx, 6, step+1] = 1 if sign == 1 else 0 140 | self.ref_trajectory[env_idx, 9, step+1] = 1 if sign == -1 else 0 141 | elif self.ref_type == 'zig-zag': 142 | if dts is None: 143 | dts = np.random.uniform(self.min_dt, self.max_dt, size=(num_envs, len(self.sim_times), 1)) 144 | dts = dts.repeat(3, axis=2) 145 | d_times = dts.cumsum(axis=1) 146 | if pos is None: 147 | pos = np.random.uniform(0, np.array(self.max_dist), size=(num_envs, len(self.sim_times)//2, 3)) 148 | 149 | for env_idx in range(num_envs): 150 | for p_idx in range(3): 151 | for step, time in enumerate(self.sim_times[1:]): 152 | ref_idx = np.searchsorted(d_times[env_idx, :, p_idx], time) 153 | sign = 1 if np.ceil(ref_idx/2 - ref_idx//2) == 0 else -1 154 | ref_pos = sign * pos[env_idx, ref_idx, p_idx] 155 | prev_ref_pos = -1 * sign * pos[env_idx, ref_idx-1, p_idx] if ref_idx > 0 else 0 156 | ref_time = d_times[env_idx, ref_idx, p_idx] 157 | prev_ref_time = d_times[env_idx, ref_idx-1, p_idx] if ref_idx > 0 else 0 158 | cur_pos = self.ref_trajectory[env_idx, p_idx, step] 159 | 160 | delta = ref_pos - prev_ref_pos 161 | if delta != 0: 162 | delta = delta/(ref_time-prev_ref_time)*self.sim_dt 163 | self.ref_trajectory[env_idx, p_idx, step+1] = cur_pos + delta 164 | else: 165 | raise ValueError('Invalid reference trajectory type specified.') 166 | 167 | return torch.tensor(self.ref_trajectory), dts, pos 168 | 169 | 170 | -------------------------------------------------------------------------------- /dmpo/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | import copy 6 | import yaml 7 | import ghalton 8 | 9 | from torch import Tensor 10 | from numpy.typing import NDArray 11 | from typing import Optional, Tuple, Dict, Any, Union, List 12 | 13 | ActivationType = {'relu': nn.ReLU, 14 | 'sigmoid': nn.Sigmoid, 15 | 'tanh': nn.Tanh, 16 | 'elu': nn.ELU, 17 | 'leakyrelu': nn.LeakyReLU, 18 | 'prelu': nn.PReLU, 19 | 'relu6': nn.ReLU6, 20 | 'selu': nn.SELU, 21 | 'celu': nn.CELU, 22 | 'gelu': nn.GELU, 23 | 'silu': nn.SiLU, 24 | 'mish': nn.Mish, 25 | 'softplus': nn.Softplus, 26 | 'tanhshrink': nn.Tanhshrink 27 | } 28 | TorchDtype = {'float': torch.float, 'double': torch.double, 'long': torch.long} 29 | 30 | def compute_statistics(success_dict: Dict[str, Any]) -> Dict[str, Any]: 31 | 
successes = success_dict['successes'] 32 | total_costs = success_dict['total_costs'] 33 | 34 | mean_success_cost = torch.mean(total_costs[successes]) 35 | std_success_cost = torch.std(total_costs[successes]) 36 | mean_fail_cost = torch.mean(total_costs[~successes]) 37 | std_fail_cost = torch.std(total_costs[~successes]) 38 | mean_cost = torch.mean(total_costs) 39 | std_cost = torch.std(total_costs) 40 | ret_dict = dict( 41 | mean_success_cost=mean_success_cost, 42 | std_success_cost=std_success_cost, 43 | mean_fail_cost=mean_fail_cost, 44 | std_fail_cost=std_fail_cost, 45 | mean_cost=mean_cost, 46 | std_cost=std_cost 47 | ) 48 | return ret_dict 49 | 50 | def make_dir(path: str): 51 | """ 52 | Create a directory if the given path does not exist 53 | """ 54 | if not os.path.exists(path): 55 | os.makedirs(path) 56 | 57 | def remove_file(path: str): 58 | if os.path.exists(path): 59 | os.remove(path) 60 | 61 | def load_yaml(filename: str) -> Dict[str, Any]: 62 | with open(filename) as file: 63 | params = yaml.load(file, Loader=yaml.FullLoader) 64 | return params 65 | 66 | def to_tensor(input: Union[List[Union[Tensor, NDArray, float]], Tensor, NDArray, float], 67 | tensor_args={'device':'cpu', 'dtype':torch.float32}) -> Tensor: 68 | if isinstance(input, list): 69 | if isinstance(input[0], list): 70 | return torch.tensor(input, **tensor_args) 71 | elif isinstance(input[0], np.ndarray): 72 | return torch.stack([to_tensor(x) for x in input]) 73 | elif isinstance(input[0], torch.Tensor): 74 | return torch.stack(input).to(**tensor_args) 75 | elif isinstance(input[0], float): 76 | return torch.tensor(input, **tensor_args) 77 | else: 78 | raise ValueError('Invalid input to convert to tensor.') 79 | elif isinstance(input, torch.Tensor): 80 | return input.to(**tensor_args) 81 | elif isinstance(input, np.ndarray): 82 | return torch.from_numpy(input).to(**tensor_args) 83 | elif isinstance(input, float): 84 | return torch.tensor(input, **tensor_args) 85 | else: 86 | raise ValueError('Invalid input to convert to tensor.') 87 | 88 | def as_list(x: Union[List[Any], Any], n: int =1) -> List[Any]: 89 | """ 90 | Return a variable as a list 91 | :param x: Variable to be copied n times in a list (returns x if already a list) 92 | :param n: # of times to copy the variable 93 | """ 94 | if isinstance(x, list): 95 | return x 96 | else: 97 | return [x for _ in range(n)] 98 | 99 | def stack_list_tensors(input: List[Tensor]) -> List[Tensor]: 100 | """ 101 | Converts a list (length T) of Tensors (N, ...) into a list (length N) of Tensors (T, ...) 
102 | """ 103 | out = [] 104 | for idx in range(len(input[0])): 105 | tensor_list = torch.stack([x[idx] for x in input], dim=0) 106 | out.append(tensor_list) 107 | return out 108 | 109 | def transpose_dict_list(input: List[Dict[str, Any]]) -> List[Dict[str, List[Any]]]: 110 | """ 111 | Converts a list (length T) of dictionaries of lists\Tensors (length N) into a list (length N) of dictionaries 112 | consisting of lists (length T) 113 | """ 114 | keys = list(input[0].keys()) 115 | length = len(input[0][keys[0]]) 116 | out = list({} for _ in range(length)) 117 | for k in keys: 118 | y = [x[k] for x in input] 119 | for idx in range(length): 120 | v = [data[idx] if isinstance(data, list) or isinstance(data, torch.Tensor) else data for data in y] 121 | v = [x.detach() for x in v if isinstance(x, torch.Tensor)] 122 | out[idx][k] = copy.deepcopy(v) 123 | return out 124 | 125 | def stack_list_list_dict_tensors(input: List[List[Dict[str, Tensor]]]) -> List[Dict[str, Tensor]]: 126 | """ 127 | Converts a list (length T) of lists (length N) of dictionaries of Tensors to a list (length N) of dictionaries of 128 | Tensors (T, ...) 129 | """ 130 | keys = list(input[0][0].keys()) 131 | length = len(input[0]) 132 | out = list({} for _ in range(length)) 133 | for k in keys: 134 | for idx in range(length): 135 | out[idx][k] = [] 136 | for t in range(len(input)): 137 | y = [x[k] for x in input[t]] 138 | out[idx][k].append(copy.deepcopy(y[idx])) 139 | out[idx][k] = torch.stack(out[idx][k]) 140 | return out 141 | 142 | def load_struct_from_dict(struct_instance, dict_instance): 143 | for key in dict_instance.keys(): 144 | if (hasattr(struct_instance, key)): 145 | if (isinstance(dict_instance[key], dict)): 146 | sub_struct = load_struct_from_dict(getattr(struct_instance, key), dict_instance[key]) 147 | setattr(struct_instance, key, sub_struct) 148 | else: 149 | setattr(struct_instance, key, dict_instance[key]) 150 | return struct_instance 151 | 152 | def set_if_empty(dict, key, value): 153 | if not key in dict.keys(): 154 | dict[key] = value 155 | return dict 156 | 157 | def generate_gaussian_halton_samples(num_samples: int, 158 | ndims: int, 159 | seed_val: int=123, 160 | device='cpu', 161 | dtype=torch.float32) -> Tensor: 162 | """ 163 | Generate Halton sequence and transform to Gaussian distribution 164 | :param num_samples: # of samples 165 | :param ndims: # of independent Gaussians from which to sample 166 | :param seed_val: Seed for the Halton sequence generator 167 | :param device: PyTorch device 168 | :param dtype: PyTorch dtype 169 | """ 170 | sequencer = ghalton.GeneralizedHalton(ndims, seed_val) 171 | uniform_halton_samples = torch.tensor(sequencer.get(num_samples), device=device, dtype=dtype) 172 | gaussian_halton_samples = torch.sqrt(torch.tensor([2.0], device=device, dtype=dtype)) \ 173 | * torch.erfinv(2 * uniform_halton_samples - 1) 174 | return gaussian_halton_samples -------------------------------------------------------------------------------- /dmpo/mpc/model/quadrotor_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import yaml 4 | 5 | from ...envs.quadrotor_param import QuadrotorParam 6 | from ...envs.math_utils import * 7 | 8 | from torch import Tensor 9 | from numpy.typing import NDArray 10 | from typing import Dict, Any, Optional, List, Union, Tuple 11 | 12 | class QuadrotorModel(): 13 | ''' 14 | Quadrotor model used by MPC 15 | ''' 16 | def __init__(self, 17 | config: str, 18 | num_envs: int=1, 19 | 
use_omega: bool = False, 20 | action_is_mf: bool = True, 21 | convert_mf_to_omega: bool = False, 22 | use_delay_model: bool = False, 23 | delay_coeff: float = 0.4, 24 | tensor_args: Dict[str, Any]={'device': 'cpu', 'dtype': torch.float32}): 25 | """ 26 | :param config: YAML configuration file, which will be parsed to form a QuadrotorParam object 27 | :param num_envs: # of parallel environments to simulation 28 | :param use_omega: use the omega controller, which converts desired (thrust, omega) to motor forces 29 | :param action_is_mf: specified that the action space is motor forces 30 | :param convert_mf_to_omega: converts a motor force command to desired (thrust, omega) to model Crazyflie 31 | :param use_delay_model: use the delay model to translate desired (thrust, omega) into actual thrust and omega 32 | :param delay_coeff: coefficient of delay model 33 | :param tensor_args: PyTorch tensor arguments 34 | """ 35 | super().__init__() 36 | self.num_envs = num_envs 37 | self.use_omega = use_omega 38 | self.action_is_mf = action_is_mf 39 | self.convert_mf_to_omega = convert_mf_to_omega 40 | self.use_delay_model = use_delay_model 41 | self.delay_coeff = delay_coeff 42 | self.tensor_args = tensor_args 43 | 44 | # Get the quadrotor configuration 45 | self.config = yaml.load(open(config), yaml.FullLoader) 46 | self.param = QuadrotorParam(self.config, MPPI=True) 47 | 48 | # Init timing 49 | self.times = self.param.sim_times 50 | self.time_step = 0 51 | self.avg_dt = self.times[1] - self.times[0] 52 | 53 | # Init system state 54 | self.init_state = torch.tensor(self.config['initial_state'], **self.tensor_args) 55 | 56 | # Control bounds 57 | self.a_min = torch.tensor(self.param.a_min, **self.tensor_args) 58 | self.a_max = torch.tensor(self.param.a_max, **self.tensor_args) 59 | 60 | if (not self.action_is_mf and not self.convert_mf_to_omega) or use_omega: 61 | self.action_lows = torch.tensor([0., -10, -10, -10], **self.tensor_args) 62 | #self.action_highs = torch.tensor([self.a_max[0]*4, 12, 12, 12], **self.tensor_args) 63 | self.action_highs = torch.tensor([0.7848, 10, 10, 10], **self.tensor_args) 64 | else: 65 | self.action_lows = self.a_min 66 | self.action_highs = self.a_max 67 | 68 | # Initial conditions 69 | self.s_min = torch.tensor(self.param.s_min, **self.tensor_args) 70 | self.s_max = -self.s_min 71 | self.rpy_limit = torch.tensor(self.param.rpy_limit, **self.tensor_args) 72 | self.limits = torch.tensor(self.param.limits, **self.tensor_args) 73 | 74 | # Constants 75 | self.d_state = 13 76 | self.d_obs = 13 77 | self.d_action = 4 78 | 79 | self.mass = self.param.mass 80 | self.g = self.param.g 81 | self.inv_mass = 1 / self.mass 82 | 83 | self.d = self.param.d 84 | self.rho = self.param.rho 85 | self.Cs = self.param.Cs 86 | self.Ct = self.param.Ct 87 | self.k1 = self.param.k1 88 | self.k2 = self.param.k2 89 | 90 | self.B0 = torch.tensor(self.param.B0, **self.tensor_args) 91 | self.B0_inv = torch.linalg.inv(self.B0) 92 | 93 | self.J = torch.tensor(self.param.J, **self.tensor_args) 94 | if self.J.shape == (3, 3): 95 | self.J = torch.as_tensor(self.J, **self.tensor_args) 96 | self.inv_J = torch.linalg.inv(self.J) 97 | else: 98 | self.J = torch.diag(torch.as_tensor(self.J, **self.tensor_args)) 99 | self.inv_J = torch.linalg.inv(self.J) 100 | 101 | # Controller gains 102 | self.omega_gain = self.config['omega_gain'] 103 | 104 | # Plotting stuff 105 | self.states_name = [ 106 | 'Position X [m]', 107 | 'Position Y [m]', 108 | 'Position Z [m]', 109 | 'Velocity X [m/s]', 110 | 'Velocity Y [m/s]', 
111 | 'Velocity Z [m/s]', 112 | 'qw', 113 | 'qx', 114 | 'qy', 115 | 'qz', 116 | 'Angular Velocity X [rad/s]', 117 | 'Angular Velocity Y [rad/s]', 118 | 'Angular Velocity Z [rad/s]'] 119 | 120 | self.deduced_state_names = [ 121 | 'Roll [deg]', 122 | 'Pitch [deg]', 123 | 'Yaw [deg]', 124 | ] 125 | 126 | self.actions_name = [ 127 | 'Motor Force 1 [N]', 128 | 'Motor Force 2 [N]', 129 | 'Motor Force 3 [N]', 130 | 'Motor Force 4 [N]'] 131 | 132 | def f(self, s: Tensor, a: Tensor) -> Tensor: 133 | num_envs, num_samples, d_state = s.shape 134 | dsdt = torch.zeros(num_envs*num_samples, 13).to(**self.tensor_args) 135 | v = s[:, :, 3:6].view(-1, 3) # velocity (N, 3) 136 | q = s[:, :, 6:10].view(-1, 4) # quaternion (N, 4) 137 | omega = s[:, :, 10:].view(-1, 3) # angular velocity (N, 3) 138 | a = a.view(-1, 4) 139 | 140 | if self.action_is_mf and not self.convert_mf_to_omega: 141 | # If action space is motor forces and we did not convert to omega space, then compute wrench 142 | eta = a @ self.B0.T # output wrench (N, 4) 143 | else: 144 | # Otherwise, our action is (thrust, omega) 145 | eta = a 146 | 147 | f_u = torch.zeros(num_envs*num_samples, 3).to(**self.tensor_args) 148 | f_u[:, 2] = eta[:, 0] # total thrust (N, 3) 149 | tau_u = eta[:, 1:] # torque (N, 3) 150 | 151 | # dynamics 152 | # \dot{p} = v 153 | dsdt[:, :3] = v # <- implies velocity and position in same frame 154 | 155 | # mv = mg + R f_u # <- implies f_u in body frame, p, v in world frame 156 | dsdt[:, 5] -= self.g 157 | dsdt[:, 3:6] += qrotate_torch(q, f_u) / self.mass 158 | 159 | # \dot{R} = R S(w) 160 | # see https://rowan.readthedocs.io/en/latest/package-calculus.html 161 | qnew = qintegrate_torch(q, omega, self.avg_dt, frame='body') 162 | qnew = qstandardize_torch(qnew) 163 | 164 | # transform qnew to a "delta q" that works with the usual euler integration 165 | dsdt[:, 6:10] = (qnew - q) / self.avg_dt 166 | 167 | if self.action_is_mf and not self.convert_mf_to_omega: 168 | # Compute omega from torques 169 | # J\dot{w} = Jw x w + tau_u 170 | Jomega = omega @ self.J.T 171 | dsdt[:, 10:] = torch.cross(Jomega, omega) + tau_u 172 | dsdt[:, 10:] = dsdt[:, 10:] @ self.inv_J.T 173 | else: 174 | # Set updated omega to be the control command 175 | dsdt[:, 10:] = (tau_u - omega) / self.avg_dt 176 | 177 | dsdt = dsdt.view(num_envs, num_samples, -1) 178 | return dsdt 179 | 180 | def step(self, s: Tensor, a: Tensor) -> Tensor: 181 | new_s = s + self.avg_dt * self.f(s, a) 182 | return new_s 183 | 184 | def rollout_open_loop(self, start_state: Tensor, act_seq: Tensor) -> Dict[str, Any]: 185 | num_particles, horizon, _ = act_seq.shape 186 | state_seq = torch.zeros((num_particles, horizon, self.d_state), **self.tensor_args) 187 | state_t = start_state.repeat((num_particles, 1)) 188 | 189 | for t in range(horizon): 190 | state_t = self.step(state_t, act_seq[:,t]) 191 | state_seq[:, t] = state_t 192 | 193 | trajectories = dict( 194 | actions = act_seq, 195 | state_seq = state_seq 196 | ) 197 | 198 | return trajectories 199 | 200 | def get_next_state(self, curr_state, act, dt): 201 | pass -------------------------------------------------------------------------------- /dmpo/experiment/ppo_rollout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | from time import time 5 | 6 | from .. 
import utils 7 | 8 | from torch.nn import Module 9 | from typing import Dict, Any, Optional, List 10 | 11 | def ppo_rollout(env, 12 | model: Module, 13 | n_episodes: int=1, 14 | ep_length: int=1, 15 | base_seed: int=0, 16 | break_if_done: bool=True, 17 | use_condition: bool=False, 18 | dynamic_env: bool=False, 19 | use_tqdm: bool=True, 20 | save_data: bool=True, 21 | run_critic=True, 22 | tensor_args: Dict[str, Any]={'device': 'cpu', 'dtype': torch.float32}) -> List[Dict[str, Any]]: 23 | """ 24 | Rollout controller in environment 25 | :param env: Environment to rollout in 26 | :param model: Container of actor to rollout and critic 27 | :param n_episodes: # of episodes 28 | :param ep_length: Episode length 29 | :param base_seed: Base seed (which is modified per episode) 30 | :param break_if_done: Break from episode loop if done signal has been achieved 31 | :param use_condition: Extract conditional information from environment to give to the policy 32 | :param dynamic_env: Run a dynamic environment 33 | :param save_data: Flag to indicate if we should return all rollout data 34 | :param run_critic: Flag to indicate if we should run the critic 35 | :param tensor_args: PyTorch Tensor settings 36 | """ 37 | trajectories = [] 38 | 39 | for ep in range(n_episodes): 40 | 41 | # Set the seed 42 | episode_seed = base_seed + 12345*ep 43 | np.random.seed(episode_seed) 44 | torch.random.manual_seed(episode_seed) 45 | if tensor_args['device'] == 'cuda': 46 | torch.cuda.manual_seed(episode_seed) 47 | 48 | # Reset the environment 49 | env.reset() 50 | param_dict = env.get_param_dict() 51 | 52 | # Optional settings if using MPC under the hood 53 | if hasattr(model, 'update_params'): 54 | model.update_params(param_dict) 55 | 56 | if hasattr(model, 'set_seed'): 57 | model.set_seed(episode_seed) 58 | 59 | # Retrieve environment information for conditioning model 60 | if use_condition: 61 | env_info = env.get_env_description() 62 | cond = torch.stack(env_info, dim=0).to(**tensor_args) 63 | 64 | if hasattr(model, 'set_cond'): 65 | model.set_cond(env_info) 66 | else: 67 | cond = None 68 | 69 | # Reset the actor and critic 70 | if hasattr(model, 'reset'): 71 | model.reset() 72 | 73 | # Perform an episode 74 | state_list = []; cond_list = []; reward_list = []; info_list = [] 75 | action_list = []; horizon_list = []; costs_list = []; value_list = [] 76 | old_mean_list = []; old_std_list = [] 77 | time_step_list = [] 78 | 79 | pbar = tqdm(range(ep_length)) if use_tqdm else range(ep_length) 80 | for t in pbar: 81 | # Get current state 82 | curr_states = env.get_env_state() 83 | 84 | if hasattr(env, 'get_env_obs'): 85 | curr_obs = env.get_env_obs() 86 | else: 87 | curr_obs = curr_states 88 | 89 | # Run the actor and critic 90 | out = model(curr_obs, cond=cond, run_critic=run_critic) 91 | 92 | # Retrieve model outputs 93 | action = out['action'] 94 | value = out['value'] 95 | horizon = out['horizon'] if 'horizon' in out.keys() else None 96 | old_mean = out['old_mean'] if 'old_mean' in out.keys() else None 97 | old_std = out['old_std'] if 'old_std' in out.keys() else None 98 | costs = out['costs'] if 'costs' in out.keys() else None 99 | 100 | # Prepare action to be applied 101 | if isinstance(action, list) and isinstance(action[0], torch.Tensor): 102 | action = torch.stack(action).cpu() 103 | 104 | # Take a step in the environment 105 | with torch.no_grad(): 106 | obs, reward, done, info = env.step(action) 107 | 108 | # Store information 109 | state_list.append(torch.stack(curr_states).cpu() if not 
isinstance(curr_states, torch.Tensor) else curr_states.cpu()) 110 | action_list.append(action.cpu()) 111 | reward_list.append(reward.cpu()) 112 | info_list.append(info) 113 | 114 | if save_data: 115 | value_list.append(value.cpu()) 116 | time_step_list.append(t) 117 | 118 | if not horizon is None: 119 | horizon_list.append(horizon.cpu()) 120 | 121 | if not cond is None: 122 | cond_list.append(cond.cpu()) 123 | 124 | if not costs is None: 125 | costs_list.append(costs.cpu()) 126 | else: 127 | costs_list.append(None) 128 | 129 | if not old_mean is None: 130 | old_mean_list.append(old_mean.cpu()) 131 | else: 132 | old_mean_list.append(None) 133 | 134 | if not old_std is None: 135 | old_std_list.append(old_std.cpu()) 136 | else: 137 | old_std_list.append(None) 138 | 139 | # Handle a dynamic environment 140 | if dynamic_env: 141 | param_dict = env.get_param_dict() 142 | if hasattr(model, 'update_params'): 143 | model.update_params(param_dict) 144 | 145 | if use_condition: 146 | env_info = env.get_env_description() 147 | cond = torch.stack(env_info, dim=0).to(**tensor_args) 148 | 149 | if hasattr(model, 'set_cond'): 150 | model.set_cond(env_info) 151 | 152 | # Break if done 153 | if break_if_done and done: 154 | break 155 | 156 | # Store trajectories 157 | states = utils.stack_list_tensors(state_list) 158 | actions = utils.stack_list_tensors(action_list) 159 | rewards = utils.stack_list_tensors(reward_list) 160 | infos = utils.transpose_dict_list(info_list) 161 | 162 | if save_data: 163 | horizons = utils.stack_list_tensors(horizon_list) 164 | values = utils.stack_list_tensors(value_list) 165 | time_steps = [torch.tensor(time_step_list).unsqueeze(1) for _ in range(env.num_envs)] 166 | 167 | if not cond is None: 168 | conds = utils.stack_list_tensors(cond_list) 169 | else: 170 | conds = None 171 | 172 | if not costs_list[0] is None: 173 | costs = utils.stack_list_tensors(costs_list) 174 | else: 175 | costs = None 176 | 177 | if not old_mean_list[0] is None: 178 | old_means = utils.stack_list_tensors(old_mean_list) 179 | else: 180 | old_means = None 181 | 182 | if not old_std_list[0] is None: 183 | old_stds = utils.stack_list_tensors(old_std_list) 184 | else: 185 | old_stds = None 186 | 187 | for idx in range(env.num_envs): 188 | if save_data: 189 | traj = dict( 190 | states=states[idx], 191 | rewards=rewards[idx], 192 | infos=infos[idx], 193 | actions=actions[idx], 194 | horizons=horizons[idx], 195 | costs=costs[idx] if not costs is None else None, 196 | values=values[idx], 197 | conds=conds[idx] if not conds is None else None, 198 | means=old_means[idx] if old_means is not None else None, 199 | stds=old_stds[idx] if not old_stds is None else None, 200 | time_steps=time_steps[idx] 201 | ) 202 | else: 203 | traj = dict( 204 | states=states[idx], 205 | actions=actions[idx], 206 | rewards=rewards[idx], 207 | infos=infos[idx] 208 | ) 209 | trajectories.append(traj) 210 | return trajectories 211 | -------------------------------------------------------------------------------- /dmpo/dataset/dataset_buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.utils.data.sampler import SubsetRandomSampler, Sampler, Sequence 4 | import numpy as np 5 | 6 | from .. 
import utils 7 | 8 | from torch import Tensor 9 | from torch.nn import Module 10 | from typing import Dict, Any, Optional, List, Union, Tuple, Callable, Iterator 11 | 12 | class SubsetSampler(Sampler[int]): 13 | r"""Samples elements randomly from a given list of indices, without replacement. 14 | 15 | Args: 16 | indices (sequence): a sequence of indices 17 | generator (Generator): Generator used in sampling. 18 | """ 19 | indices: Sequence[int] 20 | 21 | def __init__(self, indices: Sequence[int], generator=None) -> None: 22 | self.indices = indices 23 | self.generator = generator 24 | 25 | def __iter__(self) -> Iterator[int]: 26 | for i in self.indices: 27 | yield self.indices[i] 28 | 29 | def __len__(self) -> int: 30 | return len(self.indices) 31 | 32 | class DatasetBuffer(Dataset): 33 | """ 34 | Dataset buffer class for PPO 35 | """ 36 | def __init__(self, discount: float=1, seq_length: int=1, stride: int=1, gae_lambda: Optional[float]=None) -> None: 37 | """ 38 | :param discount: Discount factor 39 | """ 40 | super().__init__() 41 | self.discount = discount 42 | self.seq_length = seq_length 43 | self.stride = stride 44 | self.gae_lambda = gae_lambda 45 | 46 | self.clear() 47 | 48 | def clear(self) -> None: 49 | self.states = [] 50 | self.actions = [] 51 | self.horizons = [] 52 | self.costs = [] 53 | self.conds = [] 54 | self.rewards = [] 55 | self.values = [] 56 | self.stds = [] 57 | self.returns = [] 58 | self.advantages = [] 59 | self.means = [] 60 | self.time_steps = [] 61 | 62 | def push(self, trajectories: List[Dict[str, Any]]) -> None: 63 | # Iterate through the trajectories 64 | for traj in trajectories: 65 | self.states.append(traj['states']) 66 | self.actions.append(traj['actions']) 67 | self.costs.append(traj['costs']) 68 | self.rewards.append(traj['rewards']) 69 | self.values.append(traj['values']) 70 | self.conds.append(traj['conds']) 71 | self.horizons.append(traj['horizons']) 72 | self.means.append(traj['means']) 73 | self.stds.append(traj['stds']) 74 | self.time_steps.append(traj['time_steps']) 75 | 76 | def __len__(self) -> int: 77 | return len(self.states) 78 | 79 | def compute_returns_and_advantages(self) -> None: 80 | if self.gae_lambda is None: 81 | for traj_idx in range(len(self)): 82 | rewards = self.rewards[traj_idx] 83 | values = self.values[traj_idx] 84 | 85 | R = values[-1] 86 | returns = [] 87 | for reward in reversed(rewards): 88 | R = self.discount*R + reward 89 | returns.insert(0, R) 90 | returns = torch.stack(returns) 91 | self.returns.append(returns) 92 | 93 | returns = torch.stack(self.returns) 94 | values = torch.stack(self.values) 95 | adv = returns - values 96 | else: 97 | for traj_idx in range(len(self)): 98 | rewards = self.rewards[traj_idx] 99 | values = self.values[traj_idx] 100 | 101 | last_gae_lam = 0 102 | advantages = [] 103 | for t in reversed(range(len(rewards))): 104 | if t == len(rewards)-1: 105 | next_values = values[-1] 106 | else: 107 | next_values = values[t+1] 108 | delta = rewards[t] + self.discount*next_values - values[t] 109 | last_gae_lam = delta + self.discount*self.gae_lambda*last_gae_lam 110 | advantages.insert(0, last_gae_lam) 111 | advantages = torch.stack(advantages) 112 | returns = advantages + values 113 | self.advantages.append(advantages) 114 | self.returns.append(returns) 115 | adv = torch.stack(self.advantages) 116 | 117 | adv = (adv - adv.mean()) / (adv.std() + 1e-6) 118 | self.advantages = adv 119 | 120 | def split_into_subsequences(self) -> None: 121 | states = self.states 122 | actions = self.actions 123 | horizons 
= self.horizons 124 | costs = self.costs 125 | conds = self.conds 126 | returns = self.returns 127 | rewards = self.rewards 128 | advantages = self.advantages 129 | means = self.means 130 | stds = self.stds 131 | time_steps = self.time_steps 132 | 133 | self.clear() 134 | 135 | for traj_idx in range(len(states)): 136 | ep_length = states[traj_idx].shape[0] 137 | indices = np.arange(0, ep_length-self.seq_length+1, self.stride) 138 | n_subseqs = len(indices) 139 | 140 | # Split into subsequences 141 | self.states.extend( 142 | [states[traj_idx][i:i+self.seq_length] for i in indices]) 143 | self.actions.extend( 144 | [actions[traj_idx][i:i+self.seq_length] for i in indices]) 145 | self.rewards.extend( 146 | [rewards[traj_idx][i:i+self.seq_length] for i in indices]) 147 | self.time_steps.extend( 148 | [time_steps[traj_idx][i:i+self.seq_length] for i in indices]) 149 | 150 | # Handle returns and advantages if computed 151 | if len(returns) > 0: 152 | self.returns.extend( 153 | [returns[traj_idx][i:i+self.seq_length] for i in indices]) 154 | self.advantages.extend( 155 | [advantages[traj_idx][i:i+self.seq_length] for i in indices]) 156 | else: 157 | self.returns.extend([None for _ in range(n_subseqs)]) 158 | self.advantages.extend([None for _ in range(n_subseqs)]) 159 | 160 | # Append the full sampled plans 161 | if not horizons[traj_idx] is None: 162 | self.horizons.extend( 163 | [horizons[traj_idx][i:i+self.seq_length] for i in indices]) 164 | else: 165 | self.horizons.extend([None for _ in range(n_subseqs)]) 166 | 167 | # Handle the optional means variables 168 | if not means[traj_idx] is None: 169 | self.means.extend( 170 | [means[traj_idx][i:i+self.seq_length] for i in indices]) 171 | else: 172 | self.means.extend([None for _ in range(n_subseqs)]) 173 | 174 | # Handle the optional stds variables 175 | if not stds[traj_idx] is None: 176 | self.stds.extend( 177 | [stds[traj_idx][i:i + self.seq_length] for i in indices]) 178 | else: 179 | self.stds.extend([None for _ in range(n_subseqs)]) 180 | 181 | # Handle the optional costs variable 182 | if not costs[traj_idx] is None: 183 | self.costs.extend( 184 | [costs[traj_idx][i:i+self.seq_length] for i in indices]) 185 | else: 186 | self.costs.extend([None for _ in range(n_subseqs)]) 187 | 188 | # Handle the optional condition variable 189 | if not conds[traj_idx] is None: 190 | self.conds.extend( 191 | [conds[traj_idx][i:i+self.seq_length] for i in indices]) 192 | else: 193 | self.conds.extend([None for _ in range(n_subseqs)]) 194 | 195 | 196 | def get_samplers(self): 197 | indices = np.arange(len(self)) 198 | np.random.shuffle(indices) 199 | sampler = SubsetSampler(indices) 200 | return sampler 201 | 202 | def __getitem__(self, idx): 203 | return self.states[idx], \ 204 | self.actions[idx], \ 205 | self.rewards[idx], \ 206 | self.horizons[idx] if not self.horizons[idx] is None else torch.empty(()), \ 207 | self.costs[idx] if not self.costs[idx] is None else torch.empty(()), \ 208 | self.conds[idx] if not self.conds[idx] is None else torch.empty(()), \ 209 | self.returns[idx] if not self.returns[idx] is None else torch.empty(()), \ 210 | self.advantages[idx] if not self.advantages[idx] is None else torch.empty(()), \ 211 | self.means[idx] if not self.means[idx] is None else torch.empty(()), \ 212 | self.stds[idx] if not self.stds[idx] is None else torch.empty(()), \ 213 | self.time_steps[idx] 214 | -------------------------------------------------------------------------------- /dmpo/experiment/ppo_trainer.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions as D 3 | from copy import deepcopy 4 | 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 7 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 8 | from pytorch_lightning.callbacks.callback import Callback 9 | 10 | from torch import Tensor 11 | from torch.nn import Module 12 | from typing import Dict, Any, Optional, List, Union, Tuple, Callable 13 | 14 | class PPOTrainer(pl.LightningModule): 15 | """ 16 | PyTorch Lightning module which implements PPO for updating the policy 17 | """ 18 | def __init__(self, 19 | model: Module, 20 | actor_optim: Callable, 21 | actor_optim_args: Dict[str, Any], 22 | critic_optim: Callable, 23 | critic_optim_args: Dict[str, Any], 24 | entropy_penalty: float, 25 | kl_penalty: float, 26 | clip_epsilon: Optional[float] = None, 27 | std_clip_epsilon: Optional[float] = None, 28 | max_grad_norm: Optional[float] = None, 29 | model_subsets: Optional[List[Union[str, List[str]]]] = None, 30 | critic_only: bool=False): 31 | super().__init__() 32 | 33 | # Actor and critic optimizers and associated arguments 34 | self.optims = [actor_optim, critic_optim] 35 | self.optims_args = [actor_optim_args, critic_optim_args] 36 | self.model_subsets = model_subsets 37 | 38 | # PPO parameters 39 | self.clip_epsilon = clip_epsilon 40 | self.std_clip_epsilon = std_clip_epsilon 41 | self.entropy_penalty = entropy_penalty 42 | self.kl_penalty = kl_penalty 43 | self.max_grad_norm = max_grad_norm 44 | self.critic_only = critic_only 45 | 46 | # Turn off PyTorch Lightning automatic optimization 47 | self.automatic_optimization = False 48 | 49 | # Create a copy of the model for computing CPI loss 50 | self.model = model 51 | if hasattr(self.model, 'rollout_task'): 52 | self.rollout_task = self.model.rollout_task 53 | self.model.rollout_task = None 54 | 55 | self.old_model = deepcopy(self.model) 56 | 57 | self.model.rollout_task = self.rollout_task 58 | self.old_model.rollout_task = self.rollout_task 59 | else: 60 | self.old_model = deepcopy(self.model) 61 | 62 | def training_step(self, batch: List[Any], batch_idx: int) -> Dict[str, Any]: 63 | # Get data from batch 64 | states, actions, rewards, horizons, costs, conds, returns, advantages, means, stds, time_steps = batch 65 | batch_size, T, H, d_action = horizons.shape 66 | 67 | # Get the optimizers 68 | actor_opt, critic_opt = self.optimizers() 69 | 70 | # Reset the models 71 | if hasattr(self.model, 'reset'): 72 | self.model.reset() 73 | self.old_model.reset() 74 | 75 | # Iterate over subsequence for actor update 76 | old_log_probs_list = []; old_std_log_probs_list = [] 77 | log_probs_list = []; std_log_probs_list = [] 78 | entropy_list = []; kl_div_list = [] 79 | 80 | for t in range(T): 81 | state = states[:, t] 82 | cond = conds[:, t] if conds.ndim > 1 else None 83 | horizon = horizons[:, t] 84 | cost = costs[:, t] if costs.ndim > 1 else None 85 | mean = means[:, t] if means.ndim > 1 else None 86 | std = stds[:, t] if stds.ndim > 1 else None 87 | 88 | # Get old action log probabilities 89 | with torch.no_grad(): 90 | self.old_model.to('cuda') 91 | out = self.old_model(state, cond=cond, costs=cost, mean=mean, std=std) 92 | old_dist = out['mean_dist'] 93 | old_std_dist = out['std_dist'] if 'std_dist' in out.keys() else None 94 | 95 | old_log_probs = old_dist.log_prob(horizon) 96 | old_log_probs = old_log_probs.mean(dim=(2)).mean(dim=(1)) 
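# (Added annotation; not part of the original file.) The per-timestep log-probabilities
# collected in this loop feed the standard PPO clipped surrogate computed further below:
#     ratio      = exp(log_probs - old_log_probs)
#     cpi_loss   = ratio * A
#     clip_loss  = clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * A
#     actor_loss = -min(cpi_loss, clip_loss).mean()
# e.g. with clip_epsilon = 0.2 and advantage A = +1, a ratio of 1.5 contributes 1.2 rather than 1.5.
# An entropy bonus (weight entropy_penalty) is then subtracted and a KL penalty (weight kl_penalty)
# added to form the total actor loss.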
97 | old_log_probs_list.append(old_log_probs) 98 | 99 | if not old_std_dist is None: 100 | old_std_log_probs = old_std_dist.log_prob(std) 101 | old_std_log_probs = old_std_log_probs.mean(dim=(2)).mean(dim=(1)) 102 | old_std_log_probs_list.append(old_std_log_probs) 103 | 104 | # Get new action log probabilities 105 | out = self.model(state, cond=cond, costs=cost, mean=mean, std=std) 106 | dist = out['mean_dist'] 107 | std_dist = out['std_dist'] if 'std_dist' in out.keys() else None 108 | 109 | log_probs = dist.log_prob(horizon) 110 | log_probs = log_probs.mean(dim=(2)).mean(dim=(1)) 111 | log_probs_list.append(log_probs) 112 | 113 | if not std_dist is None: 114 | std_log_probs = std_dist.log_prob(std) 115 | std_log_probs = std_log_probs.mean(dim=(2)).mean(dim=(1)) 116 | std_log_probs_list.append(std_log_probs) 117 | 118 | # Compute entropy and KL divergence 119 | if not dist is None: 120 | entropy = dist.entropy() 121 | kl_div = D.kl_divergence(dist, old_dist) 122 | 123 | entropy_list.append(entropy) 124 | kl_div_list.append(kl_div) 125 | 126 | # Stack the lists 127 | old_log_probs = torch.stack(old_log_probs_list, dim=1) 128 | log_probs = torch.stack(log_probs_list, dim=1) 129 | entropies = torch.stack(entropy_list, dim=1) 130 | kl_divs = torch.stack(kl_div_list, dim=1) 131 | 132 | # Compute the actor loss 133 | ratio = torch.exp(log_probs - old_log_probs) 134 | cpi_loss = ratio * advantages.squeeze(2) 135 | 136 | # Compute actor loss for learning STD 137 | if not std_dist is None: 138 | old_std_log_probs = torch.stack(old_std_log_probs_list, dim=1) 139 | std_log_probs = torch.stack(std_log_probs_list, dim=1) 140 | std_ratio = torch.exp(std_log_probs - old_std_log_probs) 141 | std_cpi_loss = std_ratio * advantages.squeeze(2) 142 | else: 143 | std_ratio = None 144 | 145 | if not self.clip_epsilon is None: 146 | clip_loss = ratio.clamp(1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages.squeeze(2) 147 | actor_loss = -torch.min(cpi_loss, clip_loss).mean() 148 | 149 | if not std_ratio is None: 150 | clip_loss = std_ratio.clamp(1 - self.std_clip_epsilon, 151 | 1 + self.std_clip_epsilon) * advantages.squeeze(2) 152 | actor_loss += -torch.min(std_cpi_loss, clip_loss).mean() 153 | else: 154 | actor_loss = -cpi_loss.mean() - std_cpi_loss.mean() 155 | 156 | # Compute entropy penalty and KL divergence 157 | entropy = entropies.mean() 158 | kl_div = kl_divs.mean() 159 | total_loss = actor_loss - self.entropy_penalty * entropy + self.kl_penalty * kl_div 160 | 161 | # Compute ratio deviation 162 | ratio_dev = torch.abs(ratio - 1).mean() 163 | 164 | # Optimize the actor 165 | actor_opt.zero_grad() 166 | 167 | if not self.critic_only: 168 | self.manual_backward(total_loss) 169 | if not self.max_grad_norm is None: 170 | torch.nn.utils.clip_grad_norm_(actor_opt.param_groups[0]['params'], self.max_grad_norm) 171 | actor_opt.step() 172 | 173 | actor_opt.zero_grad() 174 | critic_opt.zero_grad() 175 | 176 | # Reset the model 177 | if hasattr(self.model, 'reset'): 178 | self.model.reset() 179 | self.old_model.reset() 180 | 181 | # Iterate over subsequence for critic update 182 | values_list = [] 183 | for t in range(T): 184 | state = states[:, t] 185 | cond = conds[:, t] if conds.ndim > 1 else None 186 | cost = costs[:, t] if costs.ndim > 1 else None 187 | mean = means[:, t] if means.ndim > 1 else None 188 | std = stds[:, t] if stds.ndim > 1 else None 189 | 190 | out = self.model(state, cond=cond, costs=cost, mean=mean, std=std) 191 | value = out['value'] 192 | values_list.append(value) 193 | 194 | # 
Stack the list 195 | values = torch.stack(values_list, dim=1) 196 | 197 | # Compute the critic loss 198 | critic_loss = 0.5 * (returns - values).pow(2).mean() 199 | 200 | # Optimize the critic 201 | critic_opt.zero_grad() 202 | self.manual_backward(critic_loss) 203 | if not self.max_grad_norm is None: 204 | torch.nn.utils.clip_grad_norm_(critic_opt.param_groups[0]['params'], self.max_grad_norm) 205 | critic_opt.step() 206 | 207 | actor_opt.zero_grad() 208 | critic_opt.zero_grad() 209 | 210 | # Reset the model 211 | if hasattr(self.model, 'reset'): 212 | self.model.reset() 213 | self.old_model.reset() 214 | 215 | loss_dict = dict( 216 | actor_loss=actor_loss.detach(), 217 | critic_loss=critic_loss.detach(), 218 | total_loss=total_loss.detach(), 219 | ratio=ratio.mean().detach(), 220 | ratio_dev=ratio_dev.detach(), 221 | entropy=entropy.detach(), 222 | kl_div=kl_div.detach() 223 | ) 224 | 225 | if not std_ratio is None: 226 | loss_dict['std_ratio'] = std_ratio.mean().detach() 227 | 228 | batch_dict = {'loss': total_loss, 'log': loss_dict} 229 | self.log_dict(loss_dict, on_step=False, on_epoch=True, prog_bar=True, logger=True) 230 | return batch_dict 231 | 232 | def configure_optimizers(self) -> List[Dict[str, Any]]: 233 | optimizers = [] 234 | for idx, optim in enumerate(self.optims): 235 | # Get the proper subset of parameters for each optimizer 236 | if not self.model_subsets is None: 237 | model_subset = self.model_subsets 238 | 239 | if isinstance(model_subset[idx], str): 240 | # If a single attribute is a subset 241 | params = filter(lambda x: x.requires_grad, getattr(self.model, model_subset[idx]).parameters()) 242 | else: 243 | # If we have a list of attributes for the subset 244 | params = [] 245 | for subset in model_subset[idx]: 246 | params.extend(filter(lambda x: x.requires_grad, getattr(self.model, subset).parameters())) 247 | else: 248 | # Use all parameters for the optimizer 249 | params = filter(lambda x: x.requires_grad, self.model.parameters()) 250 | 251 | # Create the optimizer 252 | optimizer = optim(params, **self.optims_args[idx]) 253 | optimizers.append(optimizer) 254 | opt_dict = [{'optimizer': optimizer} for optimizer in optimizers] 255 | return opt_dict -------------------------------------------------------------------------------- /dmpo/mpc/rollout/quadrotor_rollout.py: -------------------------------------------------------------------------------- 1 | from .rollout_base import RolloutBase 2 | from ..model.quadrotor_model import QuadrotorModel 3 | from ...envs.math_utils import * 4 | 5 | import torch 6 | 7 | from torch import Tensor 8 | from typing import Dict, Any, Optional, List 9 | 10 | class QuadrotorRollout(RolloutBase): 11 | def __init__(self, 12 | dynamics_model: QuadrotorModel, 13 | exp_params: Dict[str, Any], 14 | num_envs: int = 1, 15 | tensor_args: Dict[str, Any] = {'device': 'cpu', 'dtype': torch.float32}): 16 | """ 17 | :param dynamics_model: Dynamics model used in rollout 18 | :param exp_params: Cost function parameters 19 | :param num_envs: Number of environments 20 | :param tensor_args: PyTorch params 21 | """ 22 | super(QuadrotorRollout, self).__init__() 23 | self.dynamics_model = dynamics_model 24 | self.tensor_args = tensor_args 25 | self.exp_params = exp_params 26 | self.num_envs = num_envs 27 | 28 | self.use_omega = dynamics_model.use_omega 29 | self.action_is_mf = dynamics_model.action_is_mf 30 | self.convert_mf_to_omega = dynamics_model.convert_mf_to_omega 31 | self.use_delay_model = dynamics_model.use_delay_model 32 | 33 | self.param = 
self.dynamics_model.param 34 | self.alpha_p = self.param.alpha_p 35 | self.alpha_z = self.param.alpha_z 36 | self.alpha_w = self.param.alpha_w 37 | self.alpha_a = self.param.alpha_a 38 | self.alpha_R = self.param.alpha_R 39 | self.alpha_v = self.param.alpha_v 40 | self.alpha_pitch = self.param.alpha_pitch 41 | self.alpha_yaw = self.param.alpha_yaw 42 | self.t = torch.zeros((self.num_envs), **self.tensor_args) 43 | 44 | def cost_fn(self, state: Tensor, act: Tensor) -> Dict[str, Tensor]: 45 | time_step = torch.ceil(self.t / self.dynamics_model.avg_dt) 46 | num_envs, num_samples, H, _ = state.shape 47 | indices = torch.arange(H).to(**self.tensor_args) + time_step[0] 48 | indices = indices.clip(0, self.ref_trajectory.shape[-1]-1) 49 | indices = indices.to(torch.long) 50 | 51 | state_ref = self.ref_trajectory[:, :, indices].permute((0, 2, 1)) 52 | p_des = state_ref[:, :, 0:3] 53 | v_des = state_ref[:, :, 3:6] 54 | w_des = state_ref[:, :, 10:] 55 | q_des = state_ref[:, :, 6:10] 56 | 57 | # Position tracking error 58 | if self.alpha_p > 0: 59 | ep = torch.linalg.norm(state[:, :, :, 0:3] - p_des[:, None], dim=-1) 60 | else: 61 | ep = 0. 62 | 63 | # Additional cost on Z tracking error 64 | if self.alpha_z > 0: 65 | ez = torch.linalg.norm(state[:, :, :, 2:3] - p_des[:, None, :, 2:3], dim=-1) 66 | else: 67 | ez = 0. 68 | 69 | # Velocity tracking error 70 | if self.alpha_v > 0: 71 | ev = torch.linalg.norm(state[:, :, :, 3:6] - v_des[:, None], dim=-1) 72 | else: 73 | ev = 0. 74 | 75 | # Angular velocity tracking error 76 | if self.alpha_w > 0: 77 | ew = torch.linalg.norm(state[:, :, :, 10:] - w_des[:, None], dim=-1) 78 | else: 79 | ew = 0. 80 | 81 | # Orientation tracking error 82 | if self.alpha_R > 0: 83 | q_des_repeated = q_des[:, None].repeat(1, num_samples, 1, 1) 84 | eR = qdistance_torch(state[:, :, :, 6:10].view(-1, 4), q_des_repeated.view(-1, 4)) 85 | eR = eR.view(num_envs, num_samples, H) 86 | else: 87 | eR = 0. 88 | 89 | # Control cost 90 | if self.alpha_a > 0: 91 | ea = torch.linalg.norm(act, dim=-1) 92 | else: 93 | ea = 0. 94 | 95 | # Yaw tracking error 96 | if self.alpha_yaw > 0: 97 | q_des_repeated = q_des[:, None, :, :].repeat(1, num_samples, 1, 1).view(-1, 4) 98 | qe = qmultiply_torch(qconjugate_torch(q_des_repeated), state[:, :, :, 6:10].view(-1, 4)) 99 | Re = qtoR_torch(qe) 100 | eyaw = torch.atan2(Re[:, 1, 0], Re[:, 0, 0]) ** 2 101 | eyaw = eyaw.view(num_envs, num_samples, H) 102 | else: 103 | eyaw = 0. 104 | 105 | # Pitch tracking error 106 | if self.alpha_pitch > 0: 107 | q_des_repeated = q_des[:, None, :, :].repeat(1, num_samples, 1, 1).view(-1, 4) 108 | qe = qmultiply_torch(qconjugate_torch(q_des_repeated), state[:, :, :, 6:10].view(-1, 4)) 109 | Re = qtoR_torch(qe) 110 | epitch = (torch.asin(Re[:,2,0].clip(-1, 1)))**2 111 | epitch = epitch.view(num_envs, num_samples, H) 112 | else: 113 | epitch = 0. 
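# (Added annotation; not part of the original file.) Shape convention taken from the code
# above: `state` entering cost_fn is (num_envs, num_samples, H, 13), and every active error
# term (ep, ez, ev, ew, ea, eyaw, eR, epitch) is reduced to (num_envs, num_samples, H)
# before the weighted sum below, so the returned `cost` is a per-environment, per-sample,
# per-step quantity scaled by the model time step avg_dt.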
114 | 115 | cost = (self.alpha_p * ep 116 | + self.alpha_z * ez 117 | + self.alpha_v * ev 118 | + self.alpha_w * ew 119 | + self.alpha_a * ea 120 | + self.alpha_yaw * eyaw 121 | + self.alpha_R * eR 122 | + self.alpha_pitch * epitch) * self.dynamics_model.avg_dt 123 | 124 | return dict( 125 | cost=cost, 126 | ep=self.alpha_p * ep * self.dynamics_model.avg_dt, 127 | ez=self.alpha_z * ez * self.dynamics_model.avg_dt, 128 | ev=self.alpha_v * ev * self.dynamics_model.avg_dt, 129 | ew=self.alpha_w * ew * self.dynamics_model.avg_dt, 130 | ea=self.alpha_a * ea * self.dynamics_model.avg_dt, 131 | eyaw=self.alpha_yaw * eyaw * self.dynamics_model.avg_dt, 132 | eR=self.alpha_R * eR * self.dynamics_model.avg_dt, 133 | epitch=self.alpha_pitch * epitch * self.dynamics_model.avg_dt 134 | ) 135 | 136 | def omega_controller(self, s: Tensor, a: Tensor) -> Tensor: 137 | # Converts desired (thrust, omega) to motor forces 138 | num_envs, num_samples, _ = a.shape 139 | 140 | T_d = a[:, :, 0].reshape(-1) 141 | omega_d = a[:, :, 1:].reshape(-1, 3) 142 | omega = s[:, :, 10:13].reshape(-1, 3) 143 | omega_e = omega_d - omega 144 | 145 | torque = self.dynamics_model.omega_gain * omega_e # tensor, (3,) 146 | torque = torch.mm(self.dynamics_model.J, torque.T).T 147 | torque -= torch.cross(torch.mm(self.dynamics_model.J, omega.T).T, omega) 148 | 149 | wrench = torch.cat((T_d.view(-1, 1), torque), dim=1) # tensor, (N, 4) 150 | motorForce = torch.mm(self.dynamics_model.B0_inv, wrench.T).T 151 | motorForce = torch.clip(motorForce, self.dynamics_model.a_min, self.dynamics_model.a_max) 152 | motorForce = motorForce.view(num_envs, num_samples, -1) 153 | return motorForce 154 | 155 | def convert_motor_forces(self, s: Tensor, a: Tensor) -> Tensor: 156 | ''' 157 | Converts motor forces to desired (thrust, omega) 158 | ''' 159 | num_envs, num_particles, _ = a.shape 160 | eta = a.view(num_envs*num_particles, -1) @ self.dynamics_model.B0.T 161 | T_d = eta[:, :1] 162 | tau_u = eta[:, 1:] 163 | 164 | omega = s[:, :, 10:].view(num_envs*num_particles, -1) 165 | Jomega = omega @ self.dynamics_model.J.T 166 | d_omega = torch.cross(Jomega, omega) + tau_u 167 | d_omega = d_omega @ self.dynamics_model.inv_J.T 168 | omega_d = omega + d_omega*self.dynamics_model.avg_dt 169 | 170 | new_a = torch.cat((T_d, omega_d), dim=-1) 171 | new_a = new_a.view(num_envs, num_particles, -1) 172 | return new_a 173 | 174 | def rollout_fn(self, start_state: Tensor, act_seq: Tensor) -> Dict[str, Any]: 175 | actions = act_seq 176 | if actions.ndim == 3: 177 | actions = actions.unsqueeze(0) 178 | 179 | num_envs, num_particles, horizon, _ = actions.shape 180 | state_seq = torch.zeros((num_envs, num_particles, horizon, start_state.shape[-1]), **self.tensor_args) 181 | states = start_state.unsqueeze(1).repeat((1, num_particles, 1)) 182 | 183 | # Apply delay model if action space is not motor forces 184 | if not self.action_is_mf and self.use_delay_model: 185 | if self.actions is None: 186 | self.actions = self.env_actions.unsqueeze(2).repeat(1, num_particles, horizon, 1) 187 | action = torch.clip(actions, self.dynamics_model.action_lows, self.dynamics_model.action_highs) 188 | self.actions = self.actions + self.dynamics_model.delay_coeff * (actions - self.actions) 189 | 190 | for t in range(horizon): 191 | if not self.action_is_mf and not self.convert_mf_to_omega: 192 | # Ensure control bounds if using desired (thrust, omega) as control space 193 | action = actions[:, :, t] 194 | action = torch.clip(action, self.dynamics_model.action_lows, 
self.dynamics_model.action_highs) 195 | elif self.convert_mf_to_omega: 196 | # Convert motor forces to desired (thrust, omega) command 197 | action = actions[:, :, t] 198 | action = torch.clip(action, self.dynamics_model.a_min, self.dynamics_model.a_max) 199 | action = self.convert_motor_forces(states, action) 200 | elif self.action_is_mf: 201 | # Ensure control bounds if using motor forces 202 | action = actions[:, :, t] 203 | action = torch.clip(action, self.dynamics_model.a_min, self.dynamics_model.a_max) 204 | 205 | if self.use_omega: 206 | # Apply omega controller to convert motor forces to desired (thrust, omega) 207 | action = self.omega_controller(states, actions[:, :, t]) 208 | elif self.use_delay_model: 209 | # Get action from the delayed actions 210 | action = self.actions[:, :, t] 211 | 212 | states = self.dynamics_model.step(states, action) 213 | state_seq[:, :, t] = states 214 | 215 | cost_seq = self.cost_fn(state_seq, act_seq) 216 | 217 | trajectories = dict( 218 | actions=act_seq, 219 | costs=cost_seq, 220 | rollout_time=0.0, 221 | state_seq=state_seq 222 | ) 223 | 224 | return trajectories 225 | 226 | def __call__(self, start_state: Tensor, act_seq: Tensor) -> Dict[str, Any]: 227 | return self.rollout_fn(start_state, act_seq) 228 | 229 | def update_params(self, t, actions, ref_dts, ref_pos): 230 | self.t = torch.tensor(t, **self.tensor_args) 231 | 232 | # Reset stored actions if at initial time point 233 | self.env_actions = torch.stack(actions).to(**self.tensor_args)[:, None] 234 | if np.any(t) == 0: 235 | self.actions = None 236 | 237 | # Get the reference trajectory 238 | kwargs = {} 239 | if not ref_dts is None: 240 | kwargs['dts'] = np.stack(ref_dts) 241 | if not ref_pos is None: 242 | kwargs['pos'] = np.stack(ref_pos) 243 | 244 | self.num_envs = len(ref_pos) 245 | self.ref_trajectory, _, _ = self.param.get_reference(self.num_envs, **kwargs) 246 | self.ref_trajectory = self.ref_trajectory.to(**self.tensor_args) 247 | return True -------------------------------------------------------------------------------- /dmpo/experiment/ppo_experiment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import pytorch_lightning as pl 4 | import numpy as np 5 | 6 | from .experiment import Experiment 7 | from .ppo_rollout import ppo_rollout 8 | from ..dataset.dataset_buffer import DatasetBuffer 9 | from .experiment_utils import create_task, create_env, create_optim 10 | from .ppo_trainer import PPOTrainer 11 | from .. 
import utils 12 | 13 | from torch import Tensor 14 | from torch.nn import Module 15 | from typing import Dict, Any, Optional, List, Union, Tuple, Callable 16 | 17 | class PPOExperiment(Experiment): 18 | """ 19 | Experiment class for running PPO 20 | """ 21 | def __init__(self, 22 | env_name: str, 23 | env_config: Dict[str, Any], 24 | model_config: Dict[str, Any], 25 | actor_optim_config: Dict[str, Any], 26 | critic_optim_config: Dict[str, Any], 27 | task_config: Optional[Dict[str, Any]]=None, 28 | batch_size: int=1, 29 | seed: int=0, 30 | n_iters: int=1, 31 | use_condition: bool=False, 32 | train_episode_len: int=1, 33 | val_episode_len: int=1, 34 | train_episodes: int=1, 35 | val_episodes: int=1, 36 | n_val_envs: Optional[int]=None, 37 | break_if_done: bool=False, 38 | dynamic_env: bool=False, 39 | num_workers: int=0, 40 | model_file: Optional[str]=None, 41 | n_epochs: int=1, 42 | n_gpus: int=1, 43 | log_folder: str='logs', 44 | exp_name: str='experiment', 45 | dtype: str='float', 46 | device: str='cuda', 47 | env_device: str='cpu', 48 | val_every: int=1, 49 | max_grad_norm: float=1, 50 | n_pretrain_steps: Optional[int]=None, 51 | n_pretrain_epochs: int=1, 52 | dataset_config: Dict[str, Any]={}, 53 | trainer_config: Dict[str, Any]={}, 54 | train_env_config: Dict[str, Any]={}, 55 | val_env_config: Dict[str, Any]={}) -> None: 56 | """ 57 | :param env_name: Environment to load 58 | :param env_config: Configuration for environment 59 | :param model_config: Configuration for actor\critic model 60 | :param actor_optim_config: Configuration for actor optimizer 61 | :param critic_optim_config: Configuration for critic optimizer 62 | :param task_config: Configuration for the MPC task 63 | :param batch_size: Batch size 64 | :param seed: Base seed for rollouts 65 | :param n_iters: # of iterations of rollout-train-test 66 | :param use_condition: Handle logic for forming and sending conditioning variable to policy and critic 67 | :param train_episode_len: Max length of training episodes 68 | :param val_episode_len: Max length of validation episodes 69 | :param train_episodes: # of train episodes per iteration 70 | :param val_episodes: # of validation episodes per iteration 71 | :param n_val_envs: # of validation environments (can be different than # of training envs) 72 | :param break_if_done: Break from rollout if all environments have finished 73 | :param dynamic_env: Environment is dynamic (requires additional processing for MPC) 74 | :param num_workers: # of threads to use for loading training data 75 | :param model_file: Optional pre-trained model to load 76 | :param n_epochs: # of epochs to train for each iteration 77 | :param n_gpus: # of GPU devices to use 78 | :param log_folder: Folder to use for logging and checkpoints 79 | :param exp_name: Name of experiment for logging 80 | :param dtype: PyTorch data type 81 | :param device: PyTorch device 82 | :param val_every: # of epochs per which we run on validation data 83 | :param max_grad_norm: Maximum norm used for gradient clipping 84 | :param n_pretrain_steps: # of iterations to pretrain critic 85 | :param n_pretrain_epochs: # of epochs per iteration for pretraining critic 86 | :param dataset_config: Configuration for dataset buffer 87 | :param trainer_config: Configuration for the trainer 88 | :param train_env_config: Additional environment configuration for training rollouts 89 | :param val_env_config: Additional environment configuration for validation rollouts 90 | """ 91 | super().__init__(model_config=model_config, 92 | 
optim_config=actor_optim_config, 93 | n_epochs=n_epochs, 94 | log_folder=log_folder, 95 | exp_name=exp_name, 96 | dtype=dtype, 97 | device=device) 98 | 99 | self.batch_size = batch_size 100 | self.seed = seed 101 | self.n_iters = n_iters 102 | self.use_condition = use_condition 103 | self.train_episode_len = train_episode_len 104 | self.val_episode_len = val_episode_len 105 | self.train_episodes = train_episodes 106 | self.val_episodes = val_episodes 107 | self.val_every = val_every 108 | self.n_val_envs = n_val_envs 109 | self.break_if_done = break_if_done 110 | self.dynamic_env = dynamic_env 111 | self.num_workers = num_workers 112 | self.n_pretrain_steps = n_pretrain_steps 113 | self.n_pretrain_epochs = n_pretrain_epochs 114 | self.model_config = model_config 115 | self.task_config = task_config 116 | self.train_env_config = train_env_config 117 | self.val_env_config = val_env_config 118 | 119 | self.best_percent = None 120 | self.best_cost = None 121 | 122 | # Optionally load the model 123 | if not model_file is None: 124 | model_dict = torch.load(model_file, map_location='cpu') 125 | state_dict = model_dict['state_dict'] 126 | self.model.load_state_dict(state_dict) 127 | self.model.to(**self.tensor_args) 128 | 129 | if hasattr(self.model, 'horizon'): 130 | self.horizon = self.model.horizon 131 | self.d_action = self.model.d_action 132 | 133 | # Get the critic optimizer 134 | self.critic_optim, self.critic_optim_args = create_optim(critic_optim_config) 135 | 136 | # Create the environment 137 | self.env_tensor_args = {'device':env_device, 'dtype':self.dtype} 138 | self.env = create_env(env_name, tensor_args=self.env_tensor_args, **env_config) 139 | 140 | self.model.action_lows = utils.to_tensor(self.env.action_lows, self.tensor_args) 141 | self.model.action_highs = utils.to_tensor(self.env.action_highs, self.tensor_args) 142 | 143 | # Create the task 144 | if not task_config is None: 145 | self.task = create_task(env_name=env_name, 146 | num_envs=self.env.num_envs, 147 | tensor_args=self.tensor_args, 148 | **task_config) 149 | self.model.set_task(self.task) 150 | else: 151 | self.task = None 152 | 153 | # Create the dataset buffer 154 | self.dataset = DatasetBuffer(**dataset_config) 155 | 156 | # Create the PPO model trainer 157 | self.model_trainer = PPOTrainer(model=self.model, 158 | actor_optim=self.optim, 159 | actor_optim_args=self.optim_args, 160 | critic_optim=self.critic_optim, 161 | critic_optim_args=self.critic_optim_args, 162 | max_grad_norm=max_grad_norm, 163 | **trainer_config) 164 | 165 | # Create the PyTorch Lightning Trainer instance 166 | self.trainer = pl.Trainer(devices=n_gpus, 167 | accelerator='gpu', 168 | logger=self.logger, 169 | max_epochs=self.n_epochs, 170 | check_val_every_n_epoch=val_every, 171 | num_sanity_val_steps=0) 172 | 173 | def run_rollouts(self, itr: int=0, is_train: bool=False): 174 | if is_train: 175 | # Set the training configuration 176 | n_episodes = self.train_episodes 177 | ep_length = self.train_episode_len 178 | self.env.set_param(self.train_env_config) 179 | 180 | # Use a different random seed for each training trial 181 | seed = self.seed + 123*itr 182 | else: 183 | # Set the validation configuration 184 | n_episodes = self.val_episodes 185 | ep_length = self.val_episode_len 186 | self.env.set_param(self.val_env_config) 187 | 188 | # Use a fixed random seed for each validation trial 189 | seed = self.seed 190 | 191 | # Change # of environments if different for validation 192 | if not self.n_val_envs is None: 193 | n_envs = 
self.env.num_envs 194 | self.env.num_envs = self.n_val_envs 195 | 196 | # Validation mode uses the mean optimizer update 197 | if not is_train: 198 | use_mean = self.model.use_mean 199 | self.model.use_mean = True 200 | 201 | # Perform the rollout 202 | with torch.no_grad(): 203 | trajectories = ppo_rollout(env=self.env, 204 | model=self.model, 205 | n_episodes=n_episodes, 206 | ep_length=ep_length, 207 | base_seed=seed, 208 | break_if_done=self.break_if_done, 209 | use_condition=self.use_condition, 210 | dynamic_env=self.dynamic_env, 211 | use_tqdm=True, 212 | tensor_args=self.tensor_args) 213 | 214 | # Reset parameters if modified during validation mode 215 | if not is_train: 216 | self.model.use_mean = use_mean 217 | 218 | if not self.n_val_envs is None: 219 | self.env.num_envs = n_envs 220 | 221 | # Compute statistics 222 | success_dict = self.env.evaluate_success(trajectories) 223 | stat_dict = utils.compute_statistics(success_dict) 224 | 225 | # Display statistics 226 | success_percentage = success_dict['success_percentage'] 227 | mean_cost = stat_dict['mean_cost'] 228 | std_cost = stat_dict['std_cost'] 229 | mean_success_cost = stat_dict['mean_success_cost'] 230 | std_success_cost = stat_dict['std_success_cost'] 231 | 232 | # Save current model 233 | utils.make_dir(self.logger.log_dir + '/checkpoints') 234 | 235 | filename = self.logger.log_dir + '/checkpoints/last.pt' 236 | torch.save(dict(state_dict=self.model.state_dict(), 237 | model_config=self.model_config, 238 | task_config=self.task_config), filename) 239 | 240 | # Save model if validation performance better in terms of cost 241 | if not is_train: 242 | if (self.best_cost is None) or (self.best_cost >= mean_cost): 243 | self.best_percent = success_percentage 244 | self.best_cost = mean_cost 245 | self.model.reset() 246 | 247 | utils.make_dir(self.logger.log_dir + '/checkpoints') 248 | filename = self.logger.log_dir + '/checkpoints/best.pt' 249 | torch.save(dict(state_dict=self.model.state_dict(), 250 | model_config=self.model_config, 251 | task_config=self.task_config), 252 | filename) 253 | 254 | # Log validation performance 255 | if not is_train: 256 | self.logger.experiment.add_scalar('Best Success Percentage', self.best_percent, itr) 257 | self.logger.experiment.add_scalar('Best Mean Cost', self.best_cost, itr) 258 | 259 | self.logger.experiment.add_scalar('Success Percentage', success_percentage, itr) 260 | self.logger.experiment.add_scalar('Mean Cost', mean_cost, itr) 261 | self.logger.experiment.add_scalar('Mean Success Cost', mean_success_cost, itr) 262 | 263 | print('Success Metric (Best) = {:.2f} ({:.2f}), ' 264 | 'Mean Cost (Best) = {:.3e} +/- {:.3e} ({:.3e}), ' 265 | 'Mean Success Cost = {:.3e} +/- {:.3e}'.format(success_percentage, 266 | self.best_percent if not self.best_percent is None else 0, 267 | mean_cost, 268 | std_cost, 269 | self.best_cost if not self.best_percent is None else np.inf, 270 | mean_success_cost, 271 | std_success_cost)) 272 | 273 | return trajectories 274 | 275 | def run(self) -> None: 276 | # Main loop 277 | for itr in range(self.n_iters): 278 | print('Iteration {}'.format(itr)) 279 | 280 | # Collect training data 281 | trajectories = self.run_rollouts(itr=itr, is_train=True) 282 | 283 | # Add trajectories to dataset 284 | self.dataset.clear() 285 | self.dataset.push(trajectories) 286 | self.dataset.compute_returns_and_advantages() 287 | self.dataset.split_into_subsequences() 288 | 289 | # Fit the model 290 | sampler = self.dataset.get_samplers() 291 | data_loader = 
DataLoader(self.dataset, 292 | batch_size=self.batch_size, 293 | sampler=sampler, 294 | num_workers=self.num_workers) 295 | 296 | # Reset the PyTorch Lightning trainer 297 | self.trainer.fit_loop.epoch_progress.reset() 298 | 299 | # Remove the current checkpoints 300 | utils.remove_file(self.trainer.checkpoint_callback.best_model_path) 301 | 302 | last_model_path = '/'.join(self.trainer.checkpoint_callback.best_model_path.split('/')[:-1]) 303 | last_model_path = last_model_path + '/last.ckpt' 304 | utils.remove_file(last_model_path) 305 | 306 | # Reset the PyTorch Lightning checkpoint callback 307 | self.trainer.checkpoint_callback.best_k_models = {} 308 | self.trainer.checkpoint_callback.best_model_score = None 309 | self.trainer.checkpoint_callback.best_model_path = '' 310 | self.trainer.checkpoint_callback.filename = None 311 | 312 | # Set up for pretraining 313 | if not self.n_pretrain_steps is None and itr < self.n_pretrain_steps: 314 | self.model_trainer.critic_only = True 315 | self.trainer.fit_loop.max_epochs = self.n_pretrain_epochs 316 | else: 317 | self.model_trainer.critic_only = False 318 | self.trainer.fit_loop.max_epochs = self.n_epochs 319 | 320 | # Train the model 321 | self.model_trainer.old_model.load_state_dict(self.model.state_dict()) 322 | self.trainer.fit(self.model_trainer, data_loader) 323 | 324 | # Test current policy 325 | self.model.to(**self.tensor_args) 326 | self.model.eval() 327 | 328 | if itr%self.val_every == 0: 329 | self.run_rollouts(itr=itr, is_train=False) 330 | -------------------------------------------------------------------------------- /dmpo/envs/quadrotor.py: -------------------------------------------------------------------------------- 1 | import meshcat 2 | import meshcat.geometry as g 3 | import meshcat.transformations as tf 4 | import rowan 5 | import time 6 | from copy import deepcopy 7 | import yaml 8 | from math import ceil 9 | 10 | from .quadrotor_param import QuadrotorParam 11 | from .math_utils import * 12 | 13 | from torch import Tensor 14 | from typing import Dict, Any, Optional, List, Tuple 15 | 16 | class QuadrotorEnv(): 17 | """ 18 | Quadrotor simulation environment 19 | """ 20 | def __init__(self, 21 | config: str, 22 | num_envs: int=1, 23 | use_omega: bool=False, 24 | action_is_mf: bool=True, 25 | convert_mf_to_omega: bool=False, 26 | use_delay_model: bool=False, 27 | delay_coeff: float=0.2, 28 | randomize_mass: bool = False, 29 | mass_range: List[float]=[0.7, 1.3], 30 | randomize_delay_coeff: bool = False, 31 | delay_range: List[float]=[0.2, 0.6], 32 | force_pert: bool=False, 33 | force_range: List[float]=[-3.5, 3.5], 34 | force_is_z: bool=False, 35 | ou_theta: float=0.15, 36 | ou_sigma: float=0.20, 37 | use_obs_noise: bool=False, 38 | tensor_args: Dict[str, Any]={'device': 'cpu', 'dtype': torch.float32}): 39 | """ 40 | :param config: YAML configuration file, which will be parsed to form a QuadrotorParam object 41 | :param num_envs: # of parallel environments to simulation 42 | :param use_omega: use the omega controller, which converts desired (thrust, omega) to motor forces 43 | :param action_is_mf: specified that the action space is motor forces 44 | :param convert_mf_to_omega: converts a motor force command to desired (thrust, omega) to model Crazyflie 45 | :param use_delay_model: use the delay model to translate desired (thrust, omega) into actual thrust and omega 46 | :param delay_coeff: coefficient of delay model 47 | :param randomize_mass: use domain randomization for mass 48 | :param mass_range: scaling factors for 
mass for domain randomization 49 | :param randomize_delay_coeff: use domain randomization for delay coefficient 50 | :param delay_range: range of delay coefficients to use for domain randomization 51 | :param force_pert: use random force perturbations 52 | :param force_range: range of force perturbations 53 | :param force_is_z: force perturbations can also be in Z direction (will just be XY if false) 54 | :param ou_theta: OU process theta parameter for changing force perturbation over time 55 | :param ou_sigma: OU process sigma parameter for changing force perturbation over time 56 | :param use_obs_noise: corrupt state with observation noise 57 | :param tensor_args: PyTorch tensor arguments 58 | """ 59 | self.num_envs = num_envs 60 | self.tensor_args = tensor_args 61 | 62 | # Load in the configuration 63 | self.config = yaml.load(open(config), yaml.FullLoader) 64 | self.param = QuadrotorParam(self.config) 65 | 66 | # Action space parameters 67 | self.use_omega = use_omega 68 | self.action_is_mf = action_is_mf 69 | self.convert_mf_to_omega = convert_mf_to_omega 70 | 71 | # Delay model parameters 72 | self.use_delay_model = use_delay_model 73 | self.true_delay_coeff = delay_coeff 74 | 75 | # Domain randomization parameters 76 | self.mass_range = mass_range 77 | self.randomize_mass = randomize_mass 78 | self.delay_range = delay_range 79 | self.randomize_delay_coeff = randomize_delay_coeff 80 | self.force_pert = force_pert 81 | self.force_range = force_range 82 | self.force_is_z = force_is_z 83 | self.use_obs_noise = use_obs_noise 84 | self.ou_theta = ou_theta 85 | self.ou_sigma = ou_sigma 86 | 87 | # Init timing 88 | self.times = self.param.sim_times 89 | self.time_step = 0 90 | self.avg_dt = self.times[1] - self.times[0] 91 | 92 | # Init system state 93 | self.init_state = torch.tensor(self.config['initial_state'], **self.tensor_args) 94 | 95 | # Control bounds in motor force space 96 | self.a_min = torch.tensor(self.param.a_min, **self.tensor_args) 97 | self.a_max = torch.tensor(self.param.a_max, **self.tensor_args) 98 | 99 | # Get action bounds for controller (may be motor force space if specified) 100 | if (not self.action_is_mf and not self.convert_mf_to_omega) or use_omega: 101 | self.action_lows = torch.tensor([0., -10, -10, -10], **self.tensor_args) 102 | #self.action_highs = torch.tensor([self.a_max[0]*4, 12, 12, 12], **self.tensor_args) 103 | self.action_highs = torch.tensor([0.7848, 10, 10, 10], **self.tensor_args) 104 | else: 105 | self.action_lows = self.a_min 106 | self.action_highs = self.a_max 107 | 108 | # Initial conditions 109 | self.s_min = torch.tensor(self.param.s_min, **self.tensor_args) 110 | self.s_max = -self.s_min 111 | self.rpy_limit = torch.tensor(self.param.rpy_limit, **self.tensor_args) 112 | self.limits = torch.tensor(self.param.limits, **self.tensor_args) 113 | 114 | # Constants 115 | self.d_state = 13 116 | self.d_obs = 13 117 | self.d_action = 4 118 | 119 | self.mass = torch.tensor([self.param.mass], **self.tensor_args).repeat(self.num_envs) 120 | self.g = self.param.g 121 | self.inv_mass = 1 / self.mass 122 | 123 | self.d = self.param.d 124 | self.rho = self.param.rho 125 | self.Cs = self.param.Cs 126 | self.Ct = self.param.Ct 127 | self.k1 = self.param.k1 128 | self.k2 = self.param.k2 129 | 130 | self.B0 = torch.tensor(self.param.B0, **self.tensor_args) 131 | self.B0_inv = torch.linalg.inv(self.B0) 132 | 133 | self.J = torch.tensor(self.param.J, **self.tensor_args) 134 | if self.J.shape == (3, 3): 135 | self.J = torch.as_tensor(self.J, **self.tensor_args) 
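# (Added annotation; not part of the original file.) The inertia `param.J` may be supplied
# either as a full 3x3 matrix, which this branch keeps as-is, or as a length-3 vector of
# principal moments, which the `else` branch below promotes to a diagonal matrix via
# torch.diag before computing inv_J.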
136 | self.inv_J = torch.linalg.inv(self.J) 137 | else: 138 | self.J = torch.diag(torch.as_tensor(self.J, **self.tensor_args)) 139 | self.inv_J = torch.linalg.inv(self.J) 140 | 141 | # Controller gains 142 | self.omega_gain = self.config['omega_gain'] 143 | 144 | # Plotting stuff 145 | self.states_name = [ 146 | 'Position X [m]', 147 | 'Position Y [m]', 148 | 'Position Z [m]', 149 | 'Velocity X [m/s]', 150 | 'Velocity Y [m/s]', 151 | 'Velocity Z [m/s]', 152 | 'qw', 153 | 'qx', 154 | 'qy', 155 | 'qz', 156 | 'Angular Velocity X [rad/s]', 157 | 'Angular Velocity Y [rad/s]', 158 | 'Angular Velocity Z [rad/s]'] 159 | 160 | self.deduced_state_names = [ 161 | 'Roll [deg]', 162 | 'Pitch [deg]', 163 | 'Yaw [deg]', 164 | ] 165 | 166 | self.actions_name = [ 167 | 'Motor Force 1 [N]', 168 | 'Motor Force 2 [N]', 169 | 'Motor Force 3 [N]', 170 | 'Motor Force 4 [N]'] 171 | 172 | # Reward function coefficients 173 | # ref: row 8, Table 3, USC sim-to-real paper 174 | self.alpha_p = self.param.alpha_p 175 | self.alpha_z = self.param.alpha_z 176 | self.alpha_w = self.param.alpha_w 177 | self.alpha_a = self.param.alpha_a 178 | self.alpha_R = self.param.alpha_R 179 | self.alpha_v = self.param.alpha_v 180 | self.alpha_yaw = self.param.alpha_yaw 181 | self.alpha_pitch = self.param.alpha_pitch 182 | self.alpha_u_delta = self.param.alpha_u_delta 183 | self.alpha_u_thrust = self.param.alpha_u_thrust 184 | self.alpha_u_omega = self.param.alpha_u_omega 185 | 186 | def get_env_state(self) -> List[Tensor]: 187 | return self.states 188 | 189 | def set_env_state(self, state: List[Tensor]) -> None: 190 | self.states = state 191 | 192 | def set_param(self, param_dict: Dict[str, Any]) -> None: 193 | for k, v in param_dict.items(): 194 | setattr(self, k, v) 195 | 196 | def reset(self) -> Tensor: 197 | # Determine the initial state of the quadrotor 198 | if self.init_state is None: 199 | self.states = torch.zeros(self.d_state) 200 | 201 | # Position and velocity 202 | limits = self.limits 203 | self.states[0:6] = torch.rand(6) * (2 * limits[0:6]) - limits[0:6] 204 | 205 | # Rotation 206 | rpy = np.radians(np.random.uniform(-self.rpy_limit, self.rpy_limit, 3)) 207 | q = rowan.from_euler(rpy[0], rpy[1], rpy[2], 'xyz') 208 | self.states[6:10] = torch.tensor(q, **self.tensor_args) 209 | 210 | # Angular velocity 211 | self.states[10:13] = torch.rand(3) * (2 * limits[10:13]) - limits[10:13] 212 | else: 213 | self.states = self.init_state 214 | 215 | # Reset state and action variables 216 | self.states = self.states.unsqueeze(0).repeat(self.num_envs, 1) 217 | self.time_step = 0 218 | self.actions = torch.zeros((self.num_envs, self.d_action), **self.tensor_args) 219 | self.prev_actions = torch.zeros((self.num_envs, self.d_action), **self.tensor_args) 220 | 221 | # Randomize mass 222 | if self.randomize_mass: 223 | mass_scale = torch.rand(self.num_envs, **self.tensor_args)*(self.mass_range[1] - self.mass_range[0]) + self.mass_range[0] 224 | self.mass = torch.tensor([self.param.mass], **self.tensor_args).repeat(self.num_envs) 225 | self.mass = self.mass * mass_scale 226 | self.inv_mass = 1 / self.mass 227 | else: 228 | self.mass = torch.tensor([self.param.mass], **self.tensor_args).repeat(self.num_envs) 229 | self.inv_mass = 1 / self.mass 230 | 231 | # Randomize delay coeff: 232 | if self.randomize_delay_coeff: 233 | self.delay_coeff = torch.rand(self.num_envs, **self.tensor_args) * (self.delay_range[1] - self.delay_range[0]) + self.delay_range[0] 234 | else: 235 | self.delay_coeff = torch.tensor([self.true_delay_coeff], 
**self.tensor_args).repeat(self.num_envs) 236 | 237 | if self.force_pert: 238 | self.force_dist = torch.rand((self.num_envs, 3), **self.tensor_args) * (self.force_range[1] - self.force_range[0]) + self.force_range[0] 239 | if not self.force_is_z: 240 | self.force_dist[:, -1] = 0 241 | else: 242 | self.force_dist = torch.zeros((self.num_envs, 3), **self.tensor_args) 243 | 244 | # Get the reference trajectories 245 | self.ref_trajectory, self.ref_dts, self.ref_pos = self.param.get_reference(self.num_envs) 246 | self.ref_trajectory = self.ref_trajectory.to(**self.tensor_args) 247 | 248 | return self.states 249 | 250 | def get_env_obs(self) -> List[Tensor]: 251 | if self.use_obs_noise: 252 | noise = torch.randn((self.num_envs, self.d_state), **self.tensor_args) 253 | noise = noise * torch.as_tensor(self.param.noise_measurement_std, **self.tensor_args) 254 | noisystate = self.states + noise 255 | noisystate[:, 6:10] /= torch.norm(noisystate[:, 6:10], dim=-1, keepdim=True) 256 | return noisystate 257 | else: 258 | return self.states 259 | 260 | def f(self, s: Tensor, a: Tensor) -> Tensor: 261 | num_envs = s.shape[0] 262 | dsdt = torch.zeros(num_envs, 13).to(**self.tensor_args) 263 | v = s[:, 3:6] # velocity (N, 3) 264 | q = s[:, 6:10] # quaternion (N, 4) 265 | omega = s[:, 10:] # angular velocity (N, 3) 266 | 267 | if self.action_is_mf and not self.convert_mf_to_omega: 268 | # If action space is motor forces and we did not convert to omega space, then compute wrench 269 | eta = a @ self.B0.T # output wrench (N, 4) 270 | else: 271 | # Otherwise, our action is (thrust, omega) 272 | eta = a 273 | 274 | f_u = torch.zeros(num_envs, 3).to(**self.tensor_args) 275 | f_u[:, 2] = eta[:, 0] # total thrust (N, 3) 276 | tau_u = eta[:, 1:] # torque (N, 3) or desired omega 277 | 278 | # dynamics 279 | # \dot{p} = v 280 | dsdt[:, :3] = v # <- implies velocity and position in same frame 281 | 282 | # Apply the force perturbation 283 | if self.force_pert: 284 | dsdt[:, 3:6] += self.force_dist 285 | 286 | # mv = mg + R f_u # <- implies f_u in body frame, p, v in world frame 287 | dsdt[:, 5] -= self.g 288 | dsdt[:, 3:6] += qrotate_torch(q, f_u) / self.mass[:, None] 289 | 290 | # \dot{R} = R S(w) 291 | # see https://rowan.readthedocs.io/en/latest/package-calculus.html 292 | qnew = qintegrate_torch(q, omega, self.avg_dt, frame='body') 293 | qnew = qstandardize_torch(qnew) 294 | 295 | # transform qnew to a "delta q" that works with the usual euler integration 296 | dsdt[:, 6:10] = (qnew - q) / self.avg_dt 297 | 298 | if self.action_is_mf and not self.convert_mf_to_omega: 299 | # Compute omega from torques 300 | # J\dot{w} = Jw x w + tau_u 301 | Jomega = omega @ self.J.T 302 | dsdt[:, 10:] = torch.cross(Jomega, omega) + tau_u 303 | dsdt[:, 10:] = dsdt[:, 10:] @ self.inv_J.T 304 | else: 305 | # Set updated omega to be the control command 306 | dsdt[:, 10:] = (tau_u - omega) / self.avg_dt 307 | 308 | # Adding noise 309 | dsdt[:, 3:6] += torch.normal(mean=0, 310 | std=self.param.noise_process_std[0], 311 | size=(self.num_envs, 3), 312 | **self.tensor_args) 313 | dsdt[:, 10:] += torch.normal(mean=0, 314 | std=self.param.noise_process_std[1], 315 | size=(self.num_envs, 3), 316 | **self.tensor_args) 317 | 318 | return dsdt 319 | 320 | def next_state(self, s: Tensor, a: Tensor) -> Tensor: 321 | new_s = s + self.avg_dt * self.f(s, a) 322 | return new_s 323 | 324 | def get_cost(self, a: Tensor) -> Tensor: 325 | state_ref = self.ref_trajectory[:, :, self.time_step] 326 | p_des = state_ref[:, 0:3] 327 | v_des = state_ref[:, 3:6] 
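# (Added annotation; not part of the original file.) The reference state uses the same
# layout as the simulator state: indices 0:3 position, 3:6 velocity, 6:10 quaternion
# (qw, qx, qy, qz), 10:13 angular velocity, so each tracking error below compares matching
# components of `self.states` and `state_ref`.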
328 | w_des = state_ref[:, 10:] 329 | q_des = state_ref[:, 6:10] 330 | 331 | # Position tracking error 332 | if self.alpha_p > 0: 333 | ep = torch.linalg.norm(self.states[:, 0:3] - p_des, dim=1) 334 | else: 335 | ep = 0. 336 | 337 | # Additional cost on Z tracking error 338 | if self.alpha_z > 0: 339 | ez = torch.linalg.norm(self.states[:, 2:3] - p_des[:, 2:3], dim=-1) 340 | else: 341 | ez = 0. 342 | 343 | # Velocity tracking error 344 | if self.alpha_v > 0: 345 | ev = torch.linalg.norm(self.states[:, 3:6] - v_des, dim=1) 346 | else: 347 | ev = 0. 348 | 349 | # Angular velocity tracking error 350 | if self.alpha_w > 0: 351 | ew = torch.linalg.norm(self.states[:, 10:] - w_des, dim=1) 352 | else: 353 | ew = 0. 354 | 355 | # Orientation tracking error 356 | if self.alpha_R > 0: 357 | eR = qdistance_torch(self.states[:, 6:10], q_des) 358 | else: 359 | eR = 0. 360 | 361 | # Control cost 362 | if self.alpha_a > 0: 363 | ea = torch.linalg.norm(a, dim=1) 364 | else: 365 | ea = 0. 366 | 367 | # Yaw tracking error 368 | if self.alpha_yaw > 0: 369 | qe = qmultiply_torch(qconjugate_torch(q_des), self.states[:, 6:10]) 370 | Re = qtoR_torch(qe) 371 | eyaw = torch.atan2(Re[:, 1, 0], Re[:, 0, 0]) ** 2 372 | else: 373 | eyaw = 0. 374 | 375 | # Pitch tracking error 376 | if self.alpha_pitch > 0: 377 | qe = qmultiply_torch(qconjugate_torch(q_des), self.states[:, 6:10]) 378 | Re = qtoR_torch(qe) 379 | epitch = (torch.asin(Re[:, 2, 0].clip(-1, 1)))**2 380 | else: 381 | epitch = 0. 382 | 383 | # Penalize control changes between time steps 384 | if self.alpha_u_delta > 0: 385 | edelta = torch.norm(a - self.prev_actions, dim=1) 386 | else: 387 | edelta = 0. 388 | 389 | # Separate thrust control cost (use only if control space includes thrust) 390 | if self.alpha_u_thrust > 0: 391 | ethrust = torch.norm(a[:, :1], dim=1) 392 | else: 393 | ethrust = 0. 394 | 395 | # Separate omega control cost (use only if control space includes omega) 396 | if self.alpha_u_omega > 0: 397 | eomega = torch.norm(a[:, 1:], dim=1) 398 | else: 399 | eomega = 0. 400 | 401 | cost = (self.alpha_p * ep 402 | + self.alpha_z * ez 403 | + self.alpha_v * ev 404 | + self.alpha_w * ew 405 | + self.alpha_a * ea 406 | + self.alpha_yaw * eyaw 407 | + self.alpha_R * eR 408 | + self.alpha_pitch * epitch 409 | + self.alpha_u_delta * edelta 410 | + self.alpha_u_thrust * ethrust 411 | + self.alpha_u_omega * eomega) * self.avg_dt 412 | return cost 413 | 414 | def get_env_infos(self) -> Dict[str, Any]: 415 | done = (self.time_step + 1) >= len(self.times) 416 | dones = [done for _ in range(self.num_envs)] 417 | return dict(dones=dones, done=done) 418 | 419 | def omega_controller(self, a: Tensor) -> Tensor: 420 | ''' 421 | Converts desired (thrust, omega) to motor forces 422 | ''' 423 | T_d = a[:, 0] 424 | omega_d = a[:, 1:] 425 | 426 | omega = self.states[:, 10:13] 427 | omega_e = omega_d - omega 428 | 429 | torque = self.omega_gain * omega_e # tensor, (N, 3) 430 | torque = torch.mm(self.J, torque.T).T 431 | torque -= torch.cross(torch.mm(self.J, omega.T).T, omega) 432 | 433 | wrench = torch.cat((T_d.view(self.num_envs, 1), torque), dim=1) # tensor, (N, 4) 434 | motorForce = torch.mm(self.B0_inv, wrench.T).T 435 | motorForce = torch.clip(motorForce, self.a_min, self.a_max) 436 | return motorForce 437 | 438 | def convert_motor_forces(self, a: Tensor) -> Tensor: 439 | ''' 440 | Converts motor forces to desired (thrust, omega) 441 | ''' 442 | eta = a @ self.B0.T 443 | T_d = eta[:, :1] 444 | tau_u = eta[:, 1:] 445 | 446 | omega = self.states[:, 10:] 447 | Jomega = omega @
self.J.T 448 | d_omega = torch.cross(Jomega, omega) + tau_u 449 | d_omega = d_omega @ self.inv_J.T 450 | omega_d = omega + d_omega*self.avg_dt 451 | 452 | new_a = torch.cat((T_d, omega_d), dim=-1) 453 | return new_a 454 | 455 | def step(self, a: Tensor) -> Tuple[Any, ...]: 456 | a = a.to(**self.tensor_args) 457 | 458 | if not self.action_is_mf and not self.convert_mf_to_omega: 459 | # Ensure control bounds if using desired (thrust, omega) as control space 460 | a = torch.clip(a, self.action_lows, self.action_highs) 461 | elif self.convert_mf_to_omega: 462 | # Convert motor forces to desired (thrust, omega) command 463 | a = torch.clip(a, self.a_min, self.a_max) 464 | a = self.convert_motor_forces(a) 465 | elif self.action_is_mf: 466 | # Ensure control bounds if using motor forces 467 | a = torch.clip(a, self.a_min, self.a_max) 468 | 469 | # Apply omega controller to convert motor forces to desired (thrust, omega) 470 | if self.use_omega: 471 | a = self.omega_controller(a) 472 | cmd = a.clone() 473 | 474 | # Apply delay model on controls (should only be used for desired thrust, omega action space) 475 | if self.use_delay_model: 476 | self.actions = self.actions + self.delay_coeff[:, None]*(a - self.actions) 477 | a = self.actions 478 | 479 | # Compute the state transitions 480 | new_states = self.next_state(self.states, a) 481 | self.states = new_states 482 | 483 | # Increment force perturbation with OU process 484 | if self.force_pert: 485 | d_force = -self.ou_theta * self.force_dist 486 | d_force += torch.randn(d_force.shape, **self.tensor_args) * self.ou_sigma 487 | self.force_dist += d_force * self.avg_dt 488 | self.force_dist = self.force_dist.clamp(self.force_range[0], self.force_range[1]) 489 | if not self.force_is_z: 490 | self.force_dist[:, -1] = 0 491 | 492 | reward = -self.get_cost(cmd) 493 | info = self.get_env_infos() 494 | self.time_step += 1 495 | self.prev_actions = cmd 496 | return self.states, reward, info['done'], info 497 | 498 | def render(self, 499 | env_idx: int=0, 500 | samples: Optional[Tensor]=None, 501 | mean: Optional[Tensor]=None, 502 | state: Optional[Tensor]=None) -> Any: 503 | return None 504 | 505 | def get_param_dict(self) -> List[Dict[str, Any]]: 506 | """ 507 | Return a dictionary with parameters to be sent to MPC controller 508 | """ 509 | param_dict = [dict(t=self.times[self.time_step if self.time_step < len(self.times)-1 else -1], 510 | actions=self.actions[idx], 511 | ref_dts=self.ref_dts[idx] if not self.ref_dts is None else None, 512 | ref_pos=self.ref_pos[idx] if not self.ref_pos is None else None) 513 | for idx in range(self.num_envs)] 514 | return param_dict 515 | 516 | def get_env_description(self) -> List[Tensor]: 517 | """ 518 | Get a Tensor containing the reference trajectory used to condition policy and value function 519 | """ 520 | T = 32 521 | stride = 4 522 | dim = T//stride 523 | 524 | ref_traj = self.ref_trajectory[:, :, self.time_step:self.time_step+T:stride] 525 | ref_traj = torch.cat((ref_traj[:, :3], ref_traj[:, 6:10]), dim=1) 526 | 527 | if ref_traj.shape[2] < dim: 528 | diff = T//stride - ref_traj.shape[2] 529 | end_step = ref_traj[:, :, -1:].repeat(1, 1, diff) 530 | ref_traj = torch.cat((ref_traj, end_step), dim=-1) 531 | 532 | cond = ref_traj.reshape(self.num_envs, -1) 533 | #cond = torch.cat((cond, self.mass[:, None], self.delay_coeff[:, None], self.force_dist), dim=-1) 534 | info = [cond[i] for i in range(self.num_envs)] 535 | return info 536 | 537 | def evaluate_success(self, trajectories: List[Dict[str, Any]]) -> Dict[str, 
Any]: 538 | num_traj = len(trajectories) 539 | successes = [] 540 | total_costs = [] 541 | 542 | for idx, traj in enumerate(trajectories): 543 | costs = -traj['rewards'] 544 | 545 | total_cost = costs.sum() 546 | success = True 547 | 548 | successes.append(success) 549 | total_costs.append(total_cost) 550 | successes = torch.tensor(successes) 551 | total_costs = torch.tensor(total_costs) 552 | 553 | success_percentage = torch.sum(successes) / num_traj * 100. 554 | ret_dict = dict( 555 | successes=successes, 556 | total_costs=total_costs, 557 | success_percentage=success_percentage 558 | ) 559 | return ret_dict 560 | 561 | def visualize(self, states, ref_traj, dt): 562 | # Create a new visualizer 563 | vis = meshcat.Visualizer() 564 | vis.open() 565 | 566 | vis["/Cameras/default"].set_transform( 567 | tf.translation_matrix([0, 0, 0]).dot( 568 | tf.euler_matrix(0, np.radians(-30), -np.pi / 2))) 569 | 570 | vis["/Cameras/default/rotated/"].set_transform( 571 | tf.translation_matrix([1, 0, 0])) 572 | 573 | vis['/Background'].set_property('top_color', [0, 0, 0]) 574 | vis['/Background'].set_property('bottom_color', [0, 0, 0]) 575 | 576 | vis["Quadrotor"].set_object(g.StlMeshGeometry.from_file('./crazyflie2.stl')) 577 | 578 | vertices = np.array([[0, 0.5], [0, 0], [0, 0]]).astype(np.float32) 579 | vis["lines_segments"].set_object(g.Line(g.PointsGeometry(vertices), 580 | g.MeshBasicMaterial(color=0xff0000, linewidth=100.))) 581 | 582 | vis['ref'].set_object(g.Line(g.PointsGeometry(ref_traj.numpy()[:, :3].T), 583 | g.LineBasicMaterial(color=0xff99ff, linewidth=100.))) 584 | 585 | while True: 586 | for state in states: 587 | vis["Quadrotor"].set_transform( 588 | tf.translation_matrix([state[0], state[1], state[2]]).dot( 589 | tf.quaternion_matrix(state[6:10]))) 590 | 591 | vis["lines_segments"].set_transform( 592 | tf.translation_matrix([state[0], state[1], state[2]]).dot( 593 | tf.quaternion_matrix(state[6:10]))) 594 | 595 | time.sleep(dt) -------------------------------------------------------------------------------- /dmpo/controllers/dmpo_policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributions as D 4 | 5 | from ..models.net_utils import create_net 6 | from ..mpc.task.base_rollout_task import BaseRolloutTask 7 | from .. 
import utils 8 | 9 | from torch import Tensor 10 | from torch.nn import Module 11 | from typing import Dict, Any, Optional, List, Union, Tuple, Callable 12 | 13 | STD_MAX = 1e3 14 | STD_MIN = 1e-6 15 | 16 | class DMPOPolicy(nn.Module): 17 | """ 18 | Contains actor and critic for DMPO 19 | """ 20 | def __init__(self, 21 | d_action: int, 22 | d_state: int, 23 | actor_params: Dict[str, Any], 24 | critic_params: Dict[str, Any], 25 | shift_params: Optional[Dict[str, Any]] = None, 26 | rollout_task: Optional[BaseRolloutTask] = None, 27 | horizon: int = 1, 28 | num_particles: int = 1, 29 | gamma: float = 1., 30 | top_k: int = 1, 31 | n_iters: int = 1, 32 | init_mean: Union[float, Tensor] = 0., 33 | init_std: Union[float, Tensor] = 1., 34 | sample_params: Optional[Dict[str, Any]] = None, 35 | seed_val: int = 0, 36 | state_seq_key: str = 'state_seq', 37 | mppi_params: Optional[Dict[str, Any]] = None, 38 | d_cond: Optional[int] = None, 39 | cond_mode: Optional[str] = None, 40 | cond_actor: bool = False, 41 | cond_critic: bool = False, 42 | cond_shift: bool = False, 43 | critic_use_cost: bool = False, 44 | actor_use_state: bool = False, 45 | use_mean: bool = False, 46 | mppi_mode: bool = False, 47 | is_delta: bool = False, 48 | is_gated: bool = False, 49 | is_residual: bool = False, 50 | mean_search_std: Optional[float] = None, 51 | std_search_std: Optional[float] = None, 52 | learn_search_std: bool = False, 53 | learn_rollout_std: bool = False, 54 | action_lows: Optional[Union[float, List[float], Tensor]] = None, 55 | action_highs: Optional[Union[float, List[float], Tensor]] = None, 56 | state_scale: Optional[Union[float, List[float], Tensor]] = None, 57 | cond_scale: Optional[Union[float, List[float], Tensor]] = None, 58 | tensor_args: Dict[str, Any] = {'device': 'cpu', 'dtype': torch.float32}): 59 | """ 60 | :param d_action: Dimensionality of the action space 61 | :param d_state: Dimensionality of the state space 62 | :param actor_params: Dictionary of model parameters for the actor network 63 | :param critic_params: Dictionary of model parameters for the critic network 64 | :param shift_params: Dictionary of model parameters for the shift model network 65 | :param rollout_task: Rollout function 66 | :param horizon: Horizon of MPC controller 67 | :param num_particles: Number of particles used for rollouts 68 | :param gamma: Discount factor 69 | :param top_k: Number of top performing trajectories to save for visualization purposes 70 | :param n_iters: Number of optimizer iterations 71 | :param init_mean: Initial mean of the optimizee 72 | :param init_std: Initial STD of the optimizee 73 | :param sample_params: Dictionary of parameters for sampling 74 | :param seed_val: Seed used to generate fixed samples from the standard Gaussian 75 | :param state_seq_key: Key used to acquire state sequences from dictionary returned by rollout module 76 | :param mppi_params: Dictionary of MPPI parameters 77 | :param d_cond: Dimensionality of conditioning variable 78 | :param cond_mode: Method for incorporating the conditioning variable (e.g. 
concatenating it to network input) 79 | :param cond_actor: Flag for conditioning the actor 80 | :param cond_critic: Flag for conditioning the critic 81 | :param cond_shift: Flag for conditioning the shift model 82 | :param critic_use_cost: Flag to indicate that the critic should use the trajectory costs rather than state 83 | :param actor_use_state: Flag to indicate that the actor should also use state 84 | :param use_mean: Flag to indicate that we use the mean of our optimizer distribution (for test mode) 85 | :param mppi_mode: Flag to indicate that we should run vanilla MPPI 86 | :param is_delta: Flag to indicate that the actor output should be added to current plan (used if not a residual on MPPI) 87 | :param is_gated: Flag to indicate that we use a gating term on the mean update 88 | :param is_residual: Flag to indicate that DMPO is learning residuals on MPPI 89 | :param mean_search_std: Initial STD of the optimizer policy over optimizee means 90 | :param std_search_std: Initial STD of the optimizer policy over optimizee STDs 91 | :param learn_search_std: Flag to indicate that we should learn the search STDs 92 | :param learn_rollout_std: Flag to indicate that we should learn the optimizee STDs (rather than assuming they are fixed) 93 | :param action_lows: Minimum action values 94 | :param action_highs: Maximum action values 95 | :param state_scale: Scale factor on state prior to any network input 96 | :param cond_scale: Scale factor on conditioning variable prior to any network input 97 | :param tensor_args: PyTorch Tensor settings 98 | """ 99 | super().__init__() 100 | 101 | # Rollout parameters 102 | self.d_action = d_action 103 | self.d_state = d_state 104 | self.d_cond = d_cond 105 | self.horizon = horizon 106 | 107 | self.num_particles = num_particles 108 | self.gamma = gamma 109 | self.top_k = top_k 110 | self.n_iters = n_iters 111 | self.init_n_iters = n_iters 112 | self.seed_val = seed_val 113 | self.state_seq_key = state_seq_key 114 | 115 | # Rollout function 116 | self.rollout_task = rollout_task 117 | 118 | # Sample parameters 119 | self.use_halton = sample_params.get('use_halton', True) if not sample_params is None else True 120 | 121 | # MPPI parameters 122 | self.temperature = mppi_params.get('temperature', 1e-3) if not mppi_params is None else 1e-3 123 | self.step_size = mppi_params.get('step_size', 1) if not mppi_params is None else 1 124 | self.scale_costs = mppi_params.get('scale_costs', True) if not mppi_params is None else True 125 | 126 | # Conditioning parameters 127 | self.cond_mode = cond_mode if not self.d_cond is None else None 128 | self.cond_actor = cond_actor 129 | self.cond_critic = cond_critic 130 | self.cond_shift = cond_shift 131 | 132 | # Other parameters 133 | self.use_mean = use_mean 134 | self.mppi_mode = mppi_mode 135 | self.learn_search_std = learn_search_std 136 | self.learn_rollout_std = learn_rollout_std 137 | self.is_gated = is_gated 138 | self.is_residual = is_residual 139 | self.is_delta = is_delta 140 | self.critic_use_cost = critic_use_cost 141 | self.actor_use_state = actor_use_state 142 | 143 | self.action_lows = utils.to_tensor(action_lows, tensor_args) if not action_lows is None else None 144 | self.action_highs = utils.to_tensor(action_highs, tensor_args) if not action_highs is None else None 145 | self.state_scale = utils.to_tensor(state_scale, tensor_args) if not state_scale is None else None 146 | self.cond_scale = utils.to_tensor(cond_scale, tensor_args) if not cond_scale is None else None 147 | self.tensor_args = tensor_args 148 
| 149 | # Sizes of each input type 150 | mean_size = self.horizon * self.d_action 151 | std_size = self.horizon * self.d_action 152 | cost_size = self.num_particles 153 | cond_size = self.d_cond 154 | 155 | # Create the shift network 156 | if not shift_params is None: 157 | in_size = mean_size 158 | if learn_rollout_std: 159 | in_size += std_size 160 | if cond_mode == 'cat' and not d_cond is None and cond_shift: 161 | in_size += cond_size 162 | 163 | out_size = self.horizon*self.d_action 164 | if learn_rollout_std: 165 | out_size += self.horizon*self.d_action 166 | 167 | self.shift_model = create_net(in_size=in_size, out_size=out_size, **shift_params) 168 | else: 169 | self.shift_model = None 170 | 171 | # Create the actor network 172 | in_size = mean_size + cost_size 173 | if actor_use_state: 174 | in_size += d_state 175 | if learn_rollout_std: 176 | in_size += std_size 177 | if cond_mode == 'cat' and not d_cond is None and cond_actor: 178 | in_size += cond_size 179 | 180 | out_size = mean_size 181 | if is_gated: 182 | out_size += mean_size 183 | if learn_search_std: 184 | out_size += mean_size 185 | if learn_rollout_std: 186 | out_size += std_size 187 | if learn_search_std: 188 | out_size += std_size 189 | 190 | self.actor = create_net(in_size=in_size, out_size=out_size, **actor_params) 191 | 192 | # Create the critic 193 | if not critic_use_cost: 194 | in_size = d_state + mean_size 195 | else: 196 | in_size = mean_size + cost_size 197 | 198 | if learn_rollout_std: 199 | in_size += std_size 200 | if cond_mode == 'cat' and not d_cond is None and cond_critic: 201 | in_size += cond_size 202 | 203 | self.critic = create_net(in_size=in_size, out_size=1, **critic_params) 204 | 205 | # Set the initial mean 206 | if not isinstance(init_mean, Tensor): 207 | if isinstance(init_mean, float): 208 | self.init_mean = init_mean*torch.ones((horizon, d_action), **tensor_args) 209 | else: 210 | self.init_mean = torch.tensor(init_mean, **tensor_args).unsqueeze(0).repeat(horizon, 1) 211 | else: 212 | self.init_mean = init_mean.to(**tensor_args) 213 | 214 | # Set the initial rollout STD 215 | if not isinstance(init_std, Tensor): 216 | if isinstance(init_std, float): 217 | self.init_std = init_std*torch.ones((horizon, d_action), **tensor_args) 218 | else: 219 | self.init_std = torch.tensor(init_std, **tensor_args).unsqueeze(0).repeat(horizon, 1) 220 | else: 221 | self.init_std = init_std.to(**tensor_args) 222 | 223 | # Set the initial mean search STD 224 | if not mean_search_std is None: 225 | if not isinstance(mean_search_std, Tensor): 226 | if isinstance(mean_search_std, float): 227 | self.mean_search_std = mean_search_std * torch.ones((horizon, d_action), **tensor_args) 228 | else: 229 | self.mean_search_std = torch.tensor(mean_search_std, **tensor_args).unsqueeze(0).repeat(horizon, 230 | 1) 231 | else: 232 | self.mean_search_std = mean_search_std.to(**tensor_args) 233 | else: 234 | self.mean_search_std = None 235 | 236 | # Set the initial STD search STD 237 | if not std_search_std is None: 238 | if not isinstance(std_search_std, Tensor): 239 | if isinstance(std_search_std, float): 240 | self.std_search_std = std_search_std * torch.ones((horizon, d_action), **tensor_args) 241 | else: 242 | self.std_search_std = torch.tensor(std_search_std, **tensor_args).unsqueeze(0).repeat(horizon, 243 | 1) 244 | else: 245 | self.std_search_std = std_search_std.to(**tensor_args) 246 | else: 247 | self.std_search_std = None 248 | 249 | def set_n_iters(self, n_iters: int): 250 | self.n_iters = n_iters 251 | 252 | def 
reset(self): 253 | self.mean = self.init_mean.clone() 254 | self.std = self.init_std.clone() 255 | self.samples = None 256 | self.state_samples = None 257 | self.mppi_mean = None 258 | self.params_stacked = {} 259 | 260 | def update_params(self, kwargs: Dict[str, Any]): 261 | kwargs_stacked = {key: [] for key in kwargs[0].keys()} 262 | for kwargs_dict in kwargs: 263 | for k, v in kwargs_dict.items(): 264 | kwargs_stacked[k].append(v) 265 | self.params_stacked = kwargs_stacked 266 | self.rollout_task.update_params(kwargs_stacked) 267 | 268 | def set_task(self, task: BaseRolloutTask) -> None: 269 | self.rollout_task = task 270 | 271 | def generate_samples(self, batch_size: int): 272 | num_particles = self.num_particles-1 273 | 274 | if self.samples is None: 275 | if self.use_halton: 276 | self.samples = utils.generate_gaussian_halton_samples(num_samples=num_particles, 277 | ndims=self.d_action*self.horizon, 278 | seed_val=self.seed_val, 279 | device=self.tensor_args['device'], 280 | dtype=self.tensor_args['dtype']) 281 | else: 282 | with torch.random.fork_rng([torch.device(self.tensor_args['device'])]) as rng: 283 | torch.random.manual_seed(self.seed_val) 284 | self.samples = torch.randn((num_particles, self.d_action*self.horizon), 285 | **self.tensor_args) 286 | 287 | self.samples = self.samples[None, :, :] 288 | 289 | samples = self.samples.view(1, num_particles, self.horizon, self.d_action) 290 | samples = samples.repeat(batch_size, 1, 1, 1) 291 | 292 | # Always ensure mean is a sample 293 | zeros = torch.zeros((batch_size, 1, self.horizon, self.d_action), **self.tensor_args) 294 | samples = torch.cat((zeros, samples), dim=1) 295 | 296 | std = self.std[:, None, :, :] 297 | samples = samples * std + self.mean[:, None, :, :] 298 | 299 | if not self.action_lows is None and not self.action_highs is None: 300 | samples = samples.clamp(self.action_lows, self.action_highs) 301 | return samples 302 | 303 | def run_rollouts(self, x: Tensor) -> Tuple[Tensor, Tensor]: 304 | batch_size = x.shape[0] 305 | 306 | with torch.no_grad(): 307 | # Generate action samples 308 | samples = self.generate_samples(batch_size) 309 | 310 | # Run the rollouts 311 | trajectories = self.rollout_task.run_rollouts(x, samples) 312 | 313 | # Collect the results 314 | costs = trajectories['costs'] 315 | states = trajectories[self.state_seq_key] 316 | 317 | if not isinstance(costs, Tensor): 318 | costs = costs['cost'] 319 | 320 | # Compute the total costs 321 | gamma = torch.tensor([self.gamma ** i for i in range(costs.shape[-1])], **self.tensor_args) 322 | costs = torch.sum(gamma[None, None, :] * costs, dim=-1) 323 | 324 | # Process states 325 | self.state_samples = states 326 | top_values, top_idx = torch.topk(-costs, self.top_k, dim=-1) 327 | self.top_values = -top_values 328 | self.top_idx = top_idx 329 | self.top_trajs = [torch.index_select(states[idx], 0, top_idx[idx]) for idx in range(states.shape[0])] 330 | 331 | return costs, samples 332 | 333 | def process_mean(self, x: Tensor, mean: Tensor): 334 | batch_size = x.shape[0] 335 | if mean.ndim == 2: 336 | mean = mean.unsqueeze(0).repeat(batch_size, 1, 1) 337 | old_mean = mean 338 | 339 | # Shift mean forward 340 | shifted_mean = torch.cat((mean[:, 1:], torch.zeros_like(mean[:, -1:])), dim=1) 341 | return shifted_mean, old_mean 342 | 343 | def process_std(self, x: Tensor, std: Tensor, shift=True): 344 | batch_size = x.shape[0] 345 | if std.ndim == 2: 346 | std = std.unsqueeze(0).repeat(batch_size, 1, 1) 347 | old_std = std 348 | 349 | # Shift mean forward 350 | if 
shift: 351 | shifted_std = torch.cat((std[:, 1:], std[:, -1:]), dim=1) 352 | return shifted_std, old_std 353 | else: 354 | return old_std 355 | 356 | def get_actor_embedding(self, 357 | x: Tensor, 358 | costs: Tensor, 359 | mean: Optional[Tensor] = None, 360 | std: Optional[Tensor] = None, 361 | cond: Optional[Tensor] = None) -> Tensor: 362 | batch_size = x.shape[0] 363 | if not self.state_scale is None: 364 | x = x/self.state_scale 365 | 366 | # Compute the normalized costs 367 | costs = costs.reshape(batch_size, -1) 368 | cost_mean = costs.mean(dim=-1) 369 | cost_std = costs.std(dim=-1) 370 | costs = (costs - cost_mean[:, None]) / (cost_std[:, None] + 1e-6) 371 | 372 | # Reshape the mean and STD 373 | if not self.action_highs is None and not self.action_lows is None: 374 | mean = (mean - self.action_lows) / (self.action_highs - self.action_lows) 375 | mean = mean.view(batch_size, -1) 376 | 377 | if not std is None: 378 | if not self.action_highs is None and not self.action_lows is None: 379 | std = std / (self.init_std[0]*10) 380 | std = std.view(batch_size, -1) 381 | 382 | # Prepare the condition 383 | if not cond is None and not self.cond_scale is None: 384 | if self.cond_scale != cond.shape[-1]: 385 | cond_scale = self.cond_scale[:1].repeat(cond.shape[-1]) 386 | cond = cond / cond_scale 387 | else: 388 | cond = cond / self.cond_scale 389 | 390 | # Form the network input 391 | net_in = torch.cat((costs, mean), dim=-1) 392 | if self.actor_use_state: 393 | net_in = torch.cat((net_in, x), dim=-1) 394 | if self.learn_rollout_std: 395 | net_in = torch.cat((net_in, std), dim=-1) 396 | if not cond is None and not self.cond_mode is None and self.cond_actor: 397 | if self.cond_mode == 'cat': 398 | net_in = torch.cat((net_in, cond), dim=-1) 399 | else: 400 | raise ValueError('Invalid condition mode {} specified.'.format(self.cond_mode)) 401 | 402 | return net_in 403 | 404 | def get_critic_embedding(self, 405 | x: Tensor, 406 | costs: Tensor, 407 | mean: Optional[Tensor] = None, 408 | std: Optional[Tensor] = None, 409 | cond: Optional[Tensor] = None) -> Tensor: 410 | 411 | batch_size = x.shape[0] 412 | if not self.state_scale is None: 413 | x = x/self.state_scale 414 | 415 | # Reshape the mean and STD 416 | if not self.action_highs is None and not self.action_lows is None: 417 | mean = (mean - self.action_lows) / (self.action_highs - self.action_lows) 418 | mean = mean.view(batch_size, -1) 419 | 420 | if not std is None: 421 | if not self.action_highs is None and not self.action_lows is None: 422 | std = std / (self.init_std[0]*10) 423 | std = std.view(batch_size, -1) 424 | 425 | # Compute the total costs 426 | costs = costs.reshape(batch_size, -1) 427 | cost_mean = costs.mean(dim=-1) 428 | cost_std = costs.std(dim=-1) 429 | costs = (costs - cost_mean[:, None]) / (cost_std[:, None] + 1e-6) 430 | 431 | # Prepare the condition 432 | if not cond is None and not self.cond_scale is None: 433 | if self.cond_scale != cond.shape[-1]: 434 | cond_scale = self.cond_scale[:1].repeat(cond.shape[-1]) 435 | cond = cond / cond_scale 436 | else: 437 | cond = cond / self.cond_scale 438 | 439 | # Form the network input 440 | if not self.critic_use_cost: 441 | net_in = torch.cat((x, mean), dim=-1) 442 | else: 443 | net_in = torch.cat((costs, mean), dim=-1) 444 | 445 | if self.learn_rollout_std: 446 | net_in = torch.cat((net_in, std), dim=-1) 447 | 448 | if not cond is None and not self.cond_mode is None and self.cond_critic: 449 | if self.cond_mode == 'cat': 450 | net_in = torch.cat((net_in, cond), dim=-1) 451 | 
else: 452 | raise ValueError('Invalid condition mode {} specified.'.format(self.cond_mode)) 453 | 454 | return net_in 455 | 456 | def get_shift_embedding(self, 457 | x: Tensor, 458 | mean: Optional[Tensor] = None, 459 | std: Optional[Tensor] = None, 460 | shifted_std: Optional[Tensor] = None, 461 | cond: Optional[Tensor] = None) -> Tensor: 462 | batch_size = x.shape[0] 463 | 464 | # Reshape the mean and STD 465 | if not self.action_highs is None and not self.action_lows is None: 466 | mean = (mean - self.action_lows) / (self.action_highs - self.action_lows) 467 | mean = mean.view(batch_size, -1) 468 | 469 | if not std is None: 470 | if not shifted_std is None: 471 | std = std / (shifted_std*10) 472 | std = std.view(batch_size, -1) 473 | 474 | # Prepare the condition 475 | if not cond is None and not self.cond_scale is None: 476 | if self.cond_scale != cond.shape[-1]: 477 | cond_scale = self.cond_scale[:1].repeat(cond.shape[-1]) 478 | cond = cond / cond_scale 479 | else: 480 | cond = cond / self.cond_scale 481 | 482 | # Form the network input 483 | if self.learn_rollout_std: 484 | net_in = torch.cat((mean, std), dim=-1) 485 | else: 486 | net_in = mean 487 | 488 | if not cond is None and not self.cond_mode is None and self.cond_shift: 489 | if self.cond_mode == 'cat': 490 | net_in = torch.cat((net_in, cond), dim=-1) 491 | else: 492 | raise ValueError('Invalid condition mode {} specified.'.format(self.cond_mode)) 493 | 494 | return net_in 495 | 496 | def run_actor(self, 497 | x: Tensor, 498 | costs: Tensor, 499 | mean: Optional[Tensor] = None, 500 | std: Optional[Tensor] = None, 501 | cond: Optional[Tensor] = None) -> Tuple[D.Distribution, D.Distribution]: 502 | 503 | # Compute the MPPI update 504 | if self.mppi_mode or self.is_residual: 505 | with torch.no_grad(): 506 | mppi_costs = costs 507 | if self.scale_costs: 508 | min_costs = mppi_costs.min(dim=-1)[0][:, None] 509 | max_costs = mppi_costs.max(dim=-1)[0][:, None] 510 | mppi_costs = (mppi_costs - min_costs)/(max_costs - min_costs + 1e-6) 511 | 512 | weights = torch.softmax(-mppi_costs / self.temperature, dim=1) 513 | samples = self.generate_samples(weights.shape[0]) 514 | update = torch.sum(samples * weights[:, :, None, None], dim=1) 515 | 516 | mppi_mean = (1 - self.step_size) * self.mean + self.step_size * update.view(-1, self.horizon, self.d_action) 517 | else: 518 | mppi_mean = None 519 | self.mppi_mean = mppi_mean 520 | 521 | # Compute the DMPO update 522 | if not self.mppi_mode: 523 | 524 | # Get the actor embedding 525 | net_in = self.get_actor_embedding(x, costs=costs, mean=mean, std=std, cond=cond) 526 | 527 | # Run the actor 528 | out = self.actor(net_in) 529 | mean = out[:, :self.horizon*self.d_action].view(-1, self.horizon, self.d_action) 530 | idx = 1 531 | 532 | # Scale the actor output so that it is in a reasonable range 533 | if not self.action_lows is None and not self.action_highs is None: 534 | mean = torch.tanh(mean) 535 | mean = mean*(self.action_highs - self.action_lows) 536 | 537 | # Get the gating term if used 538 | if self.is_gated: 539 | gating = out[:, self.horizon*self.d_action*idx:self.horizon*self.d_action*(idx+1)] 540 | gating = torch.tanh(gating).view(-1, self.horizon, self.d_action) 541 | idx += 1 542 | 543 | # Get the updated search STD for the optimizee mean 544 | if self.learn_search_std: 545 | log_mean_std = out[:, self.horizon*self.d_action*idx:self.horizon*self.d_action*(idx+1)] 546 | mean_std = log_mean_std.exp().view(-1, self.horizon, self.d_action) 547 | mean_std =
torch.clamp(self.mean_search_std[None, :, :]*mean_std, STD_MIN, STD_MAX) 548 | idx += 1 549 | else: 550 | mean_std = self.mean_search_std 551 | 552 | # Handle updates to the optimizee STD if used 553 | if self.learn_rollout_std: 554 | 555 | # Compute the updated optimizee STD 556 | log_std = out[:, self.horizon*self.d_action*idx:self.horizon*self.d_action*(idx+1)] 557 | std = log_std.exp().view(-1, self.horizon, self.d_action) 558 | std = torch.clamp(self.init_std[None, :, :]*std, STD_MIN, STD_MAX) 559 | idx += 1 560 | 561 | # Get the updated search STD for the optimizee STD 562 | if self.learn_search_std: 563 | log_std_std = out[:, self.horizon*self.d_action*idx:self.horizon*self.d_action*(idx+1)] 564 | std_std = log_std_std.exp().view(-1, self.horizon, self.d_action) 565 | std_std = torch.clamp(self.std_search_std*std_std, STD_MIN, STD_MAX) 566 | else: 567 | std_std = self.std_search_std 568 | 569 | # Compute the updated mean 570 | if not self.is_residual and self.is_delta: 571 | # If not in residual mode and learning a delta on current plan, optionally with a gating term 572 | if self.is_gated: 573 | mean = (1-gating)*self.mean + gating*mean 574 | else: 575 | mean = mean + self.mean 576 | elif self.is_residual: 577 | # If in residual mode, form the update on the MPPI proposed mean, optionally using a gating term 578 | if self.is_gated: 579 | mean = (1-gating)*mppi_mean + gating*mean 580 | else: 581 | mean = mppi_mean + mean 582 | 583 | mean_dist = D.Normal(mean, mean_std) 584 | if self.learn_rollout_std: 585 | std_dist = D.Normal(std, std_std) 586 | else: 587 | std_dist = None 588 | else: 589 | mean_dist = D.Normal(mppi_mean, self.std) 590 | std_dist = None 591 | 592 | return mean_dist, std_dist 593 | 594 | def run_critic(self, 595 | x: Tensor, 596 | costs: Tensor, 597 | mean: Optional[Tensor] = None, 598 | std: Optional[Tensor] = None, 599 | cond: Optional[Tensor] = None) -> Tensor: 600 | 601 | # Get the critic embedding 602 | net_in = self.get_critic_embedding(x, costs=costs, mean=mean, std=std, cond=cond) 603 | 604 | # Run the critics 605 | return self.critic(net_in) 606 | 607 | def run_shift_model(self, 608 | x: Tensor, 609 | mean: Optional[Tensor] = None, 610 | shifted_mean: Optional[Tensor] = None, 611 | std: Optional[Tensor] = None, 612 | shifted_std: Optional[Tensor] = None, 613 | cond: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: 614 | # Get the shift model embedding 615 | net_in = self.get_shift_embedding(x, mean=mean, std=std, shifted_std=shifted_std, cond=cond) 616 | 617 | # Run the shift model 618 | out = self.shift_model(net_in) 619 | 620 | if self.learn_rollout_std: 621 | mean, log_std = torch.split(out, self.horizon*self.d_action, dim=-1) 622 | mean = mean.view(-1, self.horizon, self.d_action) 623 | log_std = log_std.view(-1, self.horizon, self.d_action) 624 | else: 625 | mean = out.view(-1, self.horizon, self.d_action) 626 | 627 | # Ensure network output is in a reasonable range 628 | if not self.action_lows is None and not self.action_highs is None: 629 | mean = torch.tanh(mean) 630 | mean = mean * (self.action_highs - self.action_lows) 631 | 632 | # Update the shifted optimizee STD 633 | if self.learn_rollout_std: 634 | std = log_std.exp().view(-1, self.horizon, self.d_action) 635 | std = torch.clamp(shifted_std*std, STD_MIN, STD_MAX) 636 | 637 | # Optionally set the optimizee mean as a residual on the shift-forward mean 638 | if self.is_residual: 639 | mean = mean + shifted_mean 640 | 641 | # Return the shifted mean and STD 642 | if self.learn_rollout_std: 643 
| return mean, std 644 | else: 645 | return mean 646 | 647 | def forward(self, 648 | x: Tensor, 649 | costs: Optional[Tensor] = None, 650 | mean: Optional[Tensor] = None, 651 | std: Optional[Tensor] = None, 652 | cond: Optional[Tensor] = None, 653 | run_critic: bool = True, 654 | **kwargs) -> Dict[str, Any]: 655 | x = utils.to_tensor(x, self.tensor_args) 656 | 657 | # Handle the mean input 658 | if not mean is None: 659 | self.mean = mean 660 | shifted_mean, old_mean = self.process_mean(x, self.mean) 661 | self.mean = shifted_mean 662 | 663 | # Handle the STD input 664 | if self.learn_rollout_std and not self.mppi_mode: 665 | if not std is None: 666 | self.std = std 667 | shifted_std, old_std = self.process_std(x, self.std) 668 | self.std = shifted_std 669 | else: 670 | old_std = self.process_std(x, self.std, shift=False) 671 | self.std = old_std 672 | 673 | # Use learned shift model 674 | if not self.shift_model is None and not self.mppi_mode: 675 | if self.learn_rollout_std: 676 | self.mean, self.std = self.run_shift_model(x, old_mean, self.mean, old_std, self.std, cond) 677 | else: 678 | self.mean = self.run_shift_model(x, old_mean,self.mean, old_std, self.std, cond) 679 | 680 | # Run the specified # of iterations of optimization 681 | for iter in range(self.n_iters): 682 | 683 | # Ensure we reoptimize if running multiple iterations 684 | if iter > 0: 685 | costs = None 686 | self.rollout_task.update_params(self.params_stacked) 687 | 688 | # Run rollouts to compute costs 689 | if costs is None: 690 | costs, samples = self.run_rollouts(x) 691 | else: 692 | samples = None 693 | 694 | # Run the actor 695 | mean_dist, std_dist = self.run_actor(x, 696 | costs=costs, 697 | mean=self.mean, 698 | std=self.std, 699 | cond=cond) 700 | 701 | # Sample the mean 702 | if self.use_mean: 703 | horizon = mean_dist.loc 704 | else: 705 | horizon = mean_dist.rsample() 706 | horizon = horizon.clamp(self.action_lows, self.action_highs) 707 | self.mean = horizon 708 | 709 | # Sample the STD 710 | if self.learn_rollout_std and not self.mppi_mode: 711 | if self.use_mean: 712 | self.std = std_dist.loc 713 | else: 714 | self.std = std_dist.rsample() 715 | self.std = self.std.clamp(STD_MIN, STD_MAX) 716 | 717 | # Set the action to be the first in the horizon 718 | action = horizon[:, 0] 719 | 720 | # Run the critic 721 | if self.mppi_mode or not run_critic: 722 | value = torch.zeros(x.shape[0]) 723 | else: 724 | value = self.run_critic(x, 725 | costs=costs, 726 | mean=old_mean, 727 | std=old_std, 728 | cond=cond) 729 | 730 | return dict( 731 | action=action, 732 | horizon=horizon, 733 | costs=costs, 734 | mean=self.mean, 735 | old_mean=old_mean, 736 | std=self.std, 737 | old_std=old_std, 738 | samples=samples, 739 | mean_dist=mean_dist, 740 | std_dist=std_dist, 741 | value=value 742 | ) --------------------------------------------------------------------------------
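The MPPI branch in DMPOPolicy.run_actor (active when mppi_mode or is_residual is set) reduces to an exponentiated-cost weighted average of the sampled action plans, blended with the previous plan by step_size. Below is a minimal standalone sketch of that update; the tensor sizes, temperature, step size, and the random costs/samples are illustrative placeholders rather than values taken from this repository.

import torch

# Placeholder shapes; real values come from the policy and rollout configuration
batch, particles, horizon, d_action = 2, 16, 8, 4
temperature, step_size = 1e-3, 1.0

costs = torch.rand(batch, particles)                        # stand-in for total rollout costs
samples = torch.randn(batch, particles, horizon, d_action)  # stand-in for sampled action plans
mean = torch.zeros(batch, horizon, d_action)                # current plan

# Min-max scale the costs per batch element (the scale_costs branch)
min_c = costs.min(dim=-1, keepdim=True)[0]
max_c = costs.max(dim=-1, keepdim=True)[0]
scaled = (costs - min_c) / (max_c - min_c + 1e-6)

# Exponentiated-cost weights over particles, then a weighted average of the plans
weights = torch.softmax(-scaled / temperature, dim=1)
update = torch.sum(samples * weights[:, :, None, None], dim=1)

# Blend with the previous plan by the MPPI step size
mppi_mean = (1 - step_size) * mean + step_size * update     # (batch, horizon, d_action)

In residual mode, this mppi_mean becomes the base plan that the learned actor then perturbs, optionally through the gating term.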