├── .gitignore ├── LICENSE.txt ├── README.md ├── pyproject.toml ├── src └── pytorch_mppi │ ├── __init__.py │ ├── autotune.py │ ├── autotune_global.py │ ├── autotune_qd.py │ └── mppi.py └── tests ├── auto_tune_parameters.py ├── pendulum.py ├── pendulum_approximate.py ├── pendulum_approximate_continuous.py ├── smooth_mppi.py └── test_batch_wrapper.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | pytorch_mppi.egg-info 3 | __pycache__ 4 | dist 5 | *.png 6 | tests/images 7 | tests/*.pkl -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 University of Michigan ARM Lab 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch MPPI Implementation 2 | This repository implements Model Predictive Path Integral (MPPI) 3 | with approximate dynamics in pytorch. MPPI typically requires actual 4 | trajectory samples, but [this paper](https://ieeexplore.ieee.org/document/7989202/) 5 | showed that it could be done with approximate dynamics (such as with a neural network) 6 | using importance sampling. 7 | 8 | Thus it can be used in place of other trajectory optimization methods 9 | such as the Cross Entropy Method (CEM), or random shooting. 10 | 11 | --- 12 | New since Aug 2024 smoothing methods, including our own KMPPI, see the [section](#smoothing) below on smoothing 13 | 14 | # Installation 15 | ```shell 16 | pip install pytorch-mppi 17 | ``` 18 | for autotuning hyperparameters, install with 19 | ```shell 20 | pip install pytorch-mppi[tune] 21 | ``` 22 | 23 | for running tests, install with 24 | ```shell 25 | pip install pytorch-mppi[test] 26 | ``` 27 | for development, clone the repository then install in editable mode 28 | ```shell 29 | pip install -e . 30 | ``` 31 | 32 | # Usage 33 | See `tests/pendulum_approximate.py` for usage with a neural network approximating 34 | the pendulum dynamics. See the `not_batch` branch for an easier to read 35 | algorithm. 
Basic use case is shown below 36 | 37 | ```python 38 | from pytorch_mppi import MPPI 39 | 40 | # create controller with chosen parameters 41 | ctrl = MPPI(dynamics, running_cost, nx, noise_sigma, num_samples=N_SAMPLES, horizon=TIMESTEPS, 42 | lambda_=lambda_, device=d, 43 | u_min=torch.tensor(ACTION_LOW, dtype=torch.double, device=d), 44 | u_max=torch.tensor(ACTION_HIGH, dtype=torch.double, device=d)) 45 | 46 | # assuming you have a gym-like env 47 | obs = env.reset() 48 | for i in range(100): 49 | action = ctrl.command(obs) 50 | obs, reward, done, _ = env.step(action.cpu().numpy()) 51 | ``` 52 | 53 | # Requirements 54 | - pytorch (>= 1.0) 55 | - `next state <- dynamics(state, action)` function (doesn't have to be true dynamics) 56 | - `state` is `K x nx`, `action` is `K x nu` 57 | - `cost <- running_cost(state, action)` function 58 | - `cost` is `K x 1`, state is `K x nx`, `action` is `K x nu` 59 | 60 | # Features 61 | - Approximate dynamics MPPI with importance sampling 62 | - Parallel/batch pytorch implementation for accelerated sampling 63 | - Control bounds via sampling control noise from rectified gaussian 64 | - Handle stochastic dynamic models (assuming each call is a sample) by sampling multiple state trajectories for the same 65 | action trajectory with `rollout_samples` 66 | - 67 | # Parameter tuning and hints 68 | `terminal_state_cost` - function(state (K x T x nx)) -> cost (K x 1) by default there is no terminal 69 | cost, but if you experience your trajectory getting close to but never quite reaching the goal, then 70 | having a terminal cost can help. The function should scale with the horizon (T) to keep up with the 71 | scaling of the running cost. 72 | 73 | `lambda_` - higher values increases the cost of control noise, so you end up with more 74 | samples around the mean; generally lower values work better (try `1e-2`) 75 | 76 | `num_samples` - number of trajectories to sample; generally the more the better. 77 | Runtime performance scales much better with `num_samples` than `horizon`, especially 78 | if you're using a GPU device (remember to pass that in!) 79 | 80 | `noise_mu` - the default is 0 for all control dimensions, which may work out 81 | really poorly if you have control bounds and the allowed range is not 0-centered. 82 | Remember to change this to an appropriate value for non-symmetric control dimensions. 83 | 84 | ## Smoothing 85 | From version 0.8.0 onwards, you can use MPPI variants that smooth the control signal. We've implemented 86 | [SMPPI](https://arxiv.org/pdf/2112.09988) as well our own kernel interpolation MPPI (KMPPI). In the base algorithm, 87 | you can achieve somewhat smoother trajectories by increasing `lambda_`; however, that comes at the cost of 88 | optimality. Explicit smoothing algorithms can achieve smoothness without sacrificing optimality. 89 | 90 | We used it and described it in our recent paper ([arxiv](https://arxiv.org/abs/2408.10450)) and you can cite it 91 | until we release a work dedicated to KMPPI. Below we show the difference between MPPI, SMPPI, and KMPPI on a toy 92 | 2D navigation problem where the control is a constrained delta position. You can check it out in `tests/smooth_mppi.py`. 
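For SMPPI, the constructor mirrors `MPPI` and adds options for the lifted (action-derivative) control space. A minimal sketch, assuming `dynamics`, `running_cost`, `nx`, and `noise_sigma` are defined as in the basic usage above (the specific values are illustrative, not tuned):
```python
import torch
from pytorch_mppi import SMPPI

ctrl = SMPPI(dynamics, running_cost, nx, noise_sigma,
             num_samples=500, horizon=20, lambda_=1e-2,
             delta_t=0.1,            # step used to integrate sampled action derivatives into actions
             w_action_seq_cost=1.,   # weight of the smoothness cost on the action sequence
             # u_min/u_max bound the sampled action derivative; action_min/action_max bound the action itself
             action_min=torch.tensor([-1., -1.]),
             action_max=torch.tensor([1., 1.]))
```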
93 | 94 | The API is mostly the same, with some additional constructor options: 95 | ```python 96 | import pytorch_mppi as mppi 97 | ctrl = mppi.KMPPI(args, 98 | kernel=mppi.RBFKernel(sigma=2), # kernel in trajectory time space (1 dimensional) 99 | num_support_pts=5, # number of control points to sample, <= horizon 100 | **kwargs) 101 | ``` 102 | The kernel can be any subclass of `mppi.TimeKernel`. It is a kernel in the trajectory time space (1 dimensional). 103 | Note that B-spline smoothing can be achieved by using a B-spline kernel. The number of support points is the number 104 | of control points to sample. Any trajectory points in between are interpolated using the kernel. For example if a 105 | trajectory horizon is 20 and `num_support_pts` is 5, then 5 control points evenly spaced throughout the horizon 106 | (with the first and last corresponding to the actual start and end of the trajectory) are sampled. The rest of the 107 | trajectory is interpolated using the kernel. The kernel is applied to the control signal, not the state signal. 108 | 109 | MPPI without smoothing 110 | 111 | ![MPPI](https://imgur.com/9wEcT2s.gif) 112 | 113 | [SMPPI](https://arxiv.org/pdf/2112.09988) smoothing by sampling noise in the action derivative space doesn't work well on this problem 114 | 115 | ![SMPPI](https://imgur.com/xwYy3aj.gif) 116 | 117 | KMPPI smoothing with RBF kernel works well 118 | 119 | ![KMPPI](https://imgur.com/IG1Zrtd.gif) 120 | 121 | 122 | ## Autotune 123 | from version 0.5.0 onwards, you can automatically tune the hyperparameters. 124 | A convenient tuner compatible with the popular [ray tune](https://docs.ray.io/en/latest/tune/index.html) library 125 | is implemented. You can select from a variety of cutting edge black-box optimizers such as 126 | [CMA-ES](https://github.com/CMA-ES/pycma), [HyperOpt](http://hyperopt.github.io/hyperopt/), 127 | [fmfn/BayesianOptimization](https://github.com/fmfn/BayesianOptimization), and so on. 128 | See `tests/auto_tune_parameters.py` for an example. A tutorial based on it follows. 129 | 130 | The tuner can be used for other controllers as well, but you will need to define the appropriate 131 | `TunableParameter` subclasses. 132 | 133 | First we create a toy 2D environment to do controls on and create the controller with some 134 | default parameters. 135 | ```python 136 | import torch 137 | from pytorch_mppi import MPPI 138 | 139 | device = "cpu" 140 | dtype = torch.double 141 | 142 | # create toy environment to do on control on (default start and goal) 143 | env = Toy2DEnvironment(visualize=True, terminal_scale=10) 144 | 145 | # create MPPI with some initial parameters 146 | mppi = MPPI(env.dynamics, env.running_cost, 2, 147 | terminal_state_cost=env.terminal_cost, 148 | noise_sigma=torch.diag(torch.tensor([5., 5.], dtype=dtype, device=device)), 149 | num_samples=500, 150 | horizon=20, device=device, 151 | u_max=torch.tensor([2., 2.], dtype=dtype, device=device), 152 | lambda_=1) 153 | ``` 154 | 155 | We then need to create an evaluation function for the tuner to tune on. 156 | It should take no arguments and output a `EvaluationResult` populated at least by costs. 157 | If you don't need rollouts for the cost evaluation, then you can set it to None in the return. 
158 | Tips for creating the evaluation function are described in comments below: 159 | 160 | ```python 161 | from pytorch_mppi import autotune 162 | # use the same nominal trajectory to start with for all the evaluations for fairness 163 | nominal_trajectory = mppi.U.clone() 164 | # parameters for our sample evaluation function - lots of choices for the evaluation function 165 | evaluate_running_cost = True 166 | num_refinement_steps = 10 167 | num_trajectories = 5 168 | 169 | def evaluate(): 170 | costs = [] 171 | rollouts = [] 172 | # we sample multiple trajectories for the same start to goal problem, but in your case you should consider 173 | # evaluating over a diverse dataset of trajectories 174 | for j in range(num_trajectories): 175 | mppi.U = nominal_trajectory.clone() 176 | # the nominal trajectory at the start will be different if the horizon's changed 177 | mppi.change_horizon(mppi.T) 178 | # usually MPPI will have its nominal trajectory warm-started from the previous iteration 179 | # for a fair test of tuning we will reset its nominal trajectory to the same random one each time 180 | # we manually warm it by refining it for some steps 181 | for k in range(num_refinement_steps): 182 | mppi.command(env.start, shift_nominal_trajectory=False) 183 | 184 | rollout = mppi.get_rollouts(env.start) 185 | 186 | this_cost = 0 187 | rollout = rollout[0] 188 | # here we evaluate on the rollout MPPI cost of the resulting trajectories 189 | # alternative costs for tuning the parameters are possible, such as just considering terminal cost 190 | if evaluate_running_cost: 191 | for t in range(len(rollout) - 1): 192 | this_cost = this_cost + env.running_cost(rollout[t], mppi.U[t]) 193 | this_cost = this_cost + env.terminal_cost(rollout, mppi.U) 194 | 195 | rollouts.append(rollout) 196 | costs.append(this_cost) 197 | # can return None for rollouts if they do not need to be calculated 198 | return autotune.EvaluationResult(torch.stack(costs), torch.stack(rollouts)) 199 | ``` 200 | 201 | With this we have enough to start tuning. For example, we can tune iteratively with the CMA-ES optimizer 202 | 203 | ```python 204 | # these are subclass of TunableParameter (specifically MPPIParameter) that we want to tune 205 | params_to_tune = [autotune.SigmaParameter(mppi), autotune.HorizonParameter(mppi), autotune.LambdaParameter(mppi)] 206 | # create a tuner with a CMA-ES optimizer 207 | tuner = autotune.Autotune(params_to_tune, evaluate_fn=evaluate, optimizer=autotune.CMAESOpt(sigma=1.0)) 208 | # tune parameters for a number of iterations 209 | iterations = 30 210 | for i in range(iterations): 211 | # results of this optimization step are returned 212 | res = tuner.optimize_step() 213 | # we can render the rollouts in the environment 214 | env.draw_rollouts(res.rollouts) 215 | # get best results and apply it to the controller 216 | # (by default the controller will take on the latest tuned parameter, which may not be best) 217 | res = tuner.get_best_result() 218 | tuner.apply_parameters(res.param_values) 219 | ``` 220 | This is a local search method that optimizes starting from the initially defined parameters. 221 | For global searching, we use ray tune compatible searching algorithms. Note that you can modify the 222 | search space of each parameter, but default reasonable ones are provided. 
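For instance, you can override a parameter's search space when constructing it (a sketch; the bounds below are illustrative and replace sigma's default `loguniform(1e-4, 1e2)`):
```python
from ray import tune
from pytorch_mppi import autotune_global

# search over a narrower, problem-informed range for the noise sigma
sigma_param = autotune_global.SigmaGlobalParameter(mppi, search_space=tune.loguniform(1e-2, 1e1))
```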
223 | 224 | ```python 225 | # can also use a Ray Tune optimizer, see 226 | # https://docs.ray.io/en/latest/tune/api_docs/suggestion.html#search-algorithms-tune-search 227 | # rather than adapting the current parameters, these optimizers allow you to define a search space for each 228 | # and will search on that space 229 | from pytorch_mppi import autotune_global 230 | from ray.tune.search.hyperopt import HyperOptSearch 231 | from ray.tune.search.bayesopt import BayesOptSearch 232 | 233 | # the global version of the parameters define a reasonable search space for each parameter 234 | params_to_tune = [autotune_global.SigmaGlobalParameter(mppi), 235 | autotune_global.HorizonGlobalParameter(mppi), 236 | autotune_global.LambdaGlobalParameter(mppi)] 237 | 238 | # be sure to close any figures before ray tune optimization or they will be duplicated 239 | env.visualize = False 240 | plt.close('all') 241 | tuner = autotune_global.AutotuneGlobal(params_to_tune, evaluate_fn=evaluate, 242 | optimizer=autotune_global.RayOptimizer(HyperOptSearch)) 243 | # ray tuners cannot be tuned iteratively, but you can specify how many iterations to tune for 244 | res = tuner.optimize_all(100) 245 | res = tuner.get_best_result() 246 | tuner.apply_parameters(res.params) 247 | ``` 248 | 249 | For example tuning hyperparameters (with CMA-ES) only on the toy problem (the nominal trajectory is reset each time so they are sampling from noise): 250 | 251 | ![toy tuning](https://i.imgur.com/2qtYMwu.gif) 252 | 253 | If you want more than just the best solution found, such as if you want diversity 254 | across hyperparameter values, or if your evaluation function has large uncertainty, 255 | then you can directly query past results by 256 | ```python 257 | for res in tuner.optim.all_res: 258 | # the cost 259 | print(res.metrics['cost']) 260 | # extract the parameters 261 | params = tuner.config_to_params(res.config) 262 | print(params) 263 | # apply the parameters to the controller 264 | tuner.apply_parameters(params) 265 | ``` 266 | 267 | Alternatively you can try Quality Diversity optimization using the 268 | [CMA-ME optimizer](https://github.com/icaros-usc/pyribs). This optimizer will 269 | try to optimize for high quality parameters while ensuring there is diversity across 270 | them. However, it is very slow and you might be better using a `RayOptimizer` and selecting 271 | for top results while checking for diversity. 272 | To use it, you need to install 273 | ```python 274 | pip install ribs 275 | ``` 276 | 277 | You then use it as 278 | 279 | ```python 280 | import pytorch_mppi.autotune_qd 281 | 282 | optim = pytorch_mppi.autotune_qd.CMAMEOpt() 283 | tuner = autotune_global.AutotuneGlobal(params_to_tune, evaluate_fn=evaluate, 284 | optimizer=optim) 285 | 286 | iterations = 10 287 | for i in range(iterations): 288 | # results of this optimization step are returned 289 | res = tuner.optimize_step() 290 | # we can render the rollouts in the environment 291 | best_params = optim.get_diverse_top_parameters(5) 292 | for res in best_params: 293 | print(res) 294 | ``` 295 | 296 | # Tests 297 | Under `tests` you can find the `MPPI` method applied to known pendulum dynamics 298 | and approximate pendulum dynamics (with a 2 layer feedforward net 299 | estimating the state residual). Using a continuous angle representation 300 | (feeding `cos(\theta), sin(\theta)` instead of `\theta` directly) makes 301 | a huge difference. 
Although both works, the continuous representation 302 | is much more robust to controller parameters and random seed. In addition, 303 | the problem of continuing to spin after over-swinging does not appear. 304 | 305 | Sample result on approximate dynamics with 100 steps of random policy data 306 | to initialize the dynamics: 307 | 308 | ![pendulum results](https://i.imgur.com/euYQJ25.gif) 309 | 310 | # Related projects 311 | - [pytorch CEM](https://github.com/LemonPi/pytorch_cem) - an alternative MPC shooting method with similar API as this 312 | project 313 | - [pytorch iCEM](https://github.com/UM-ARM-Lab/pytorch_icem) - alternative sampling based MPC 314 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "pytorch_mppi" 3 | version = "0.8.0" 4 | description = "Model Predictive Path Integral (MPPI) implemented in pytorch" 5 | readme = "README.md" # Optional 6 | 7 | # Specify which Python versions you support. In contrast to the 8 | # 'Programming Language' classifiers above, 'pip install' will check this 9 | # and refuse to install the project if the version does not match. See 10 | # https://packaging.python.org/guides/distributing-packages-using-setuptools/#python-requires 11 | requires-python = ">=3.6" 12 | 13 | # This is either text indicating the license for the distribution, or a file 14 | # that contains the license 15 | # https://packaging.python.org/en/latest/specifications/core-metadata/#license 16 | license = { file = "LICENSE.txt" } 17 | 18 | # This field adds keywords for your project which will appear on the 19 | # project page. What does your project relate to? 20 | # 21 | # Note that this is a list of additional keywords, separated 22 | # by commas, to be used to assist searching for the distribution in a 23 | # larger catalog. 24 | keywords = ["mppi", "pytorch", "control", "robotics"] # Optional 25 | authors = [ 26 | { name = "Sheng Zhong", email = "zhsh@umich.edu" } # Optional 27 | ] 28 | maintainers = [ 29 | { name = "Sheng Zhong", email = "zhsh@umich.edu" } # Optional 30 | ] 31 | 32 | # Classifiers help users find your project by categorizing it. 33 | # 34 | # For a list of valid classifiers, see https://pypi.org/classifiers/ 35 | classifiers = [# Optional 36 | "Development Status :: 4 - Beta", 37 | # Indicate who your project is intended for 38 | "Intended Audience :: Developers", 39 | # Pick your license as you wish 40 | "License :: OSI Approved :: MIT License", 41 | # Specify the Python versions you support here. In particular, ensure 42 | # that you indicate you support Python 3. These classifiers are *not* 43 | # checked by "pip install". See instead "python_requires" below. 44 | "Programming Language :: Python :: 3", 45 | "Programming Language :: Python :: 3 :: Only", 46 | ] 47 | 48 | # This field lists other packages that your project depends on to run. 49 | # Any package you put here will be installed by pip when your project is 50 | # installed, so they must be valid existing projects. 51 | # 52 | # For an analysis of this field vs pip's requirements files see: 53 | # https://packaging.python.org/discussions/install-requires-vs-requirements/ 54 | dependencies = [# Optional 55 | 'torch', 56 | 'numpy', 57 | 'arm-pytorch-utilities>=0.4', 58 | ] 59 | 60 | # List additional groups of dependencies here (e.g. development 61 | # dependencies). 
Users will be able to install these using the "extras" 62 | # syntax, for example: 63 | # 64 | # $ pip install sampleproject[dev] 65 | # 66 | # Similar to `dependencies` above, these must be valid existing 67 | # projects. 68 | [project.optional-dependencies] # Optional 69 | tune = [ 70 | 'cma', 71 | 'ray[tune]', 72 | 'bayesian-optimization', 73 | 'hyperopt', 74 | ] 75 | test = [ 76 | "pytest", 77 | 'gym', 78 | 'pygame', 79 | 'pyglet==1.5.27', 80 | 'window-recorder', 81 | 'cma', 82 | 'ray[tune]', 83 | 'bayesian-optimization', 84 | 'hyperopt', 85 | ] 86 | 87 | # List URLs that are relevant to your project 88 | # 89 | # This field corresponds to the "Project-URL" and "Home-Page" metadata fields: 90 | # https://packaging.python.org/specifications/core-metadata/#project-url-multiple-use 91 | # https://packaging.python.org/specifications/core-metadata/#home-page-optional 92 | # 93 | # Examples listed include a pattern for specifying where the package tracks 94 | # issues, where the source is hosted, where to say thanks to the package 95 | # maintainers, and where to support the project financially. The key is 96 | # what's used to render the link text on PyPI. 97 | [project.urls] # Optional 98 | "Homepage" = "https://github.com/LemonPi/pytorch_mppi" 99 | "Bug Reports" = "https://github.com/LemonPi/pytorch_mppi/issues" 100 | "Source" = "https://github.com/LemonPi/pytorch_mppi" 101 | 102 | # The following would provide a command line executable called `sample` 103 | # which executes the function `main` from this package when invoked. 104 | #[project.scripts] # Optional 105 | #sample = "sample:main" 106 | 107 | # This is configuration specific to the `setuptools` build backend. 108 | # If you are using a different build backend, you will need to change this. 109 | [tool.setuptools] 110 | # If there are data files included in your packages that need to be 111 | # installed, specify them here. 
112 | 113 | [build-system] 114 | # These are the assumed default build requirements from pip: 115 | # https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support 116 | requires = ["setuptools>=43.0.0", "wheel"] 117 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /src/pytorch_mppi/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_mppi.mppi import MPPI, SMPPI, KMPPI 2 | -------------------------------------------------------------------------------- /src/pytorch_mppi/autotune.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import abc 3 | 4 | import numpy as np 5 | import torch 6 | import typing 7 | 8 | from arm_pytorch_utilities.tensor_utils import ensure_tensor 9 | from torch.distributions import MultivariateNormal 10 | 11 | from pytorch_mppi import MPPI 12 | # optimizers 13 | import cma 14 | 15 | logger = logging.getLogger(__file__) 16 | 17 | 18 | class EvaluationResult(typing.NamedTuple): 19 | # (N) cost for each trajectory evaluated 20 | costs: torch.Tensor 21 | # (N x H x nx) where H is the horizon and nx is the state dimension 22 | rollouts: torch.Tensor 23 | # parameter values populated by the tuner after evaluation returns 24 | params: dict = None 25 | # iteration number populated by the tuner after evaluation returns 26 | iteration: int = None 27 | 28 | 29 | class Optimizer: 30 | def __init__(self): 31 | self.tuner: typing.Optional[Autotune] = None 32 | self.optim = None 33 | 34 | @abc.abstractmethod 35 | def setup_optimization(self) -> None: 36 | """Create backend optim object with optimization parameters and MPPI parameters from the tuner""" 37 | 38 | @abc.abstractmethod 39 | def optimize_step(self) -> EvaluationResult: 40 | """Optimize a single step, returning the evaluation result from the latest parameters""" 41 | 42 | def optimize_all(self, iterations) -> EvaluationResult: 43 | """Optimize multiple steps, returning the best evaluation results. 
44 | Some optimizers may only have this implemented.""" 45 | res = None 46 | for i in range(iterations): 47 | res = self.optimize_step() 48 | return res 49 | 50 | 51 | class CMAESOpt(Optimizer): 52 | """Optimize using CMA-ES, an evolutionary algorithm that maintains a Gaussian population, 53 | starting around the initial parameters with a variance (potentially different for each hyperparameter).""" 54 | 55 | def __init__(self, population=10, sigma=0.1): 56 | self.population = population 57 | self.sigma = sigma 58 | super().__init__() 59 | 60 | def setup_optimization(self): 61 | x0 = self.tuner.flatten_params() 62 | 63 | options = {"popsize": self.population, "seed": np.random.randint(0, 10000), "tolfun": 1e-5, "tolfunhist": 1e-6} 64 | self.optim = cma.CMAEvolutionStrategy(x0=x0, sigma0=self.sigma, inopts=options) 65 | 66 | def optimize_step(self): 67 | params = self.optim.ask() 68 | # convert params for use 69 | 70 | cost_per_param = [] 71 | all_rollouts = [] 72 | for param in params: 73 | self.tuner.unflatten_params(param) 74 | res = self.tuner.evaluate_fn() 75 | cost_per_param.append(res.costs.mean().cpu().numpy()) 76 | all_rollouts.append(res.rollouts) 77 | 78 | cost_per_param = np.array(cost_per_param) 79 | self.optim.tell(params, cost_per_param) 80 | 81 | best_param = self.optim.best.x 82 | self.tuner.unflatten_params(best_param) 83 | res = self.tuner.evaluate_fn() 84 | return res 85 | 86 | 87 | class TunableParameter(abc.ABC): 88 | """A parameter that can be tuned by the autotuner. Holds references to the object that defines its actual value.""" 89 | 90 | @staticmethod 91 | @abc.abstractmethod 92 | def name(): 93 | """Get the name of the parameter""" 94 | 95 | @abc.abstractmethod 96 | def dim(self): 97 | """Get the dimension of the parameter""" 98 | 99 | @abc.abstractmethod 100 | def get_current_parameter_value(self): 101 | """Get the current underlying value of the parameter""" 102 | 103 | @abc.abstractmethod 104 | def ensure_valid_value(self, value): 105 | """Return a validated parameter value as close in intent as the input value as possible""" 106 | 107 | @abc.abstractmethod 108 | def apply_parameter_value(self, value): 109 | """Apply the parameter value to the underlying object""" 110 | 111 | @abc.abstractmethod 112 | def attach_to_state(self, state: dict): 113 | """Reattach/reinitialize the parameter to a new internal state. 
This should be similar to a call to __init__""" 114 | 115 | def get_parameter_value_from_config(self, config): 116 | """Get the serialized value of the parameter from a config dictionary, where each name is a scalar""" 117 | return config[self.name()] 118 | 119 | def get_config_from_parameter_value(self, value): 120 | """Reverse of the above method, get a config dictionary from a parameter value""" 121 | return {self.name(): value} 122 | 123 | 124 | class MPPIParameter(TunableParameter, abc.ABC): 125 | def __init__(self, mppi: MPPI, dim=None): 126 | self.mppi = mppi 127 | self._dim = dim 128 | if self.mppi is not None: 129 | self.d = self.mppi.d 130 | self.dtype = self.mppi.dtype 131 | if dim is None: 132 | self._dim = self.mppi.nu 133 | 134 | def attach_to_state(self, state: dict): 135 | self.mppi = state['mppi'] 136 | self.d = self.mppi.d 137 | self.dtype = self.mppi.dtype 138 | 139 | 140 | class SigmaParameter(MPPIParameter): 141 | eps = 0.0001 142 | 143 | @staticmethod 144 | def name(): 145 | return 'sigma' 146 | 147 | def dim(self): 148 | return self._dim 149 | 150 | def get_current_parameter_value(self): 151 | return torch.cat([self.mppi.noise_sigma[i][i].view(1) for i in range(self.dim())]) 152 | 153 | def ensure_valid_value(self, value): 154 | sigma = ensure_tensor(self.d, self.dtype, value) 155 | sigma[sigma < self.eps] = self.eps 156 | return sigma 157 | 158 | def apply_parameter_value(self, value): 159 | sigma = self.ensure_valid_value(value) 160 | self.mppi.noise_sigma = torch.diag(sigma) 161 | self.mppi.noise_dist = MultivariateNormal(self.mppi.noise_mu, covariance_matrix=self.mppi.noise_sigma) 162 | self.mppi.noise_sigma_inv = torch.inverse(self.mppi.noise_sigma.detach()) 163 | 164 | def get_parameter_value_from_config(self, config): 165 | return torch.tensor([config[f'{self.name()}{i}'] for i in range(self.dim())], dtype=self.dtype, device=self.d) 166 | 167 | def get_config_from_parameter_value(self, value): 168 | return {f'{self.name()}{i}': value[i].item() for i in range(self.dim())} 169 | 170 | 171 | class MuParameter(MPPIParameter): 172 | @staticmethod 173 | def name(): 174 | return 'mu' 175 | 176 | def dim(self): 177 | return self._dim 178 | 179 | def get_current_parameter_value(self): 180 | return self.mppi.noise_mu.clone() 181 | 182 | def ensure_valid_value(self, value): 183 | mu = ensure_tensor(self.d, self.dtype, value) 184 | return mu 185 | 186 | def apply_parameter_value(self, value): 187 | mu = self.ensure_valid_value(value) 188 | self.mppi.noise_dist = MultivariateNormal(mu, covariance_matrix=self.mppi.noise_sigma) 189 | self.mppi.noise_sigma_inv = torch.inverse(self.mppi.noise_sigma.detach()) 190 | 191 | def get_parameter_value_from_config(self, config): 192 | return torch.tensor([config[f'{self.name()}{i}'] for i in range(self.dim())], dtype=self.dtype, device=self.d) 193 | 194 | def get_config_from_parameter_value(self, value): 195 | return {f'{self.name()}{i}': value[i].item() for i in range(self.dim())} 196 | 197 | 198 | class LambdaParameter(MPPIParameter): 199 | eps = 0.0001 200 | 201 | @staticmethod 202 | def name(): 203 | return 'lambda' 204 | 205 | def dim(self): 206 | return 1 207 | 208 | def get_current_parameter_value(self): 209 | return self.mppi.lambda_ 210 | 211 | def ensure_valid_value(self, value): 212 | if torch.is_tensor(value) or isinstance(value, np.ndarray): 213 | value = value[0] 214 | v = max(value, self.eps) 215 | return v 216 | 217 | def apply_parameter_value(self, value): 218 | v = self.ensure_valid_value(value) 219 | 
self.mppi.lambda_ = v 220 | 221 | 222 | class HorizonParameter(MPPIParameter): 223 | @staticmethod 224 | def name(): 225 | return 'horizon' 226 | 227 | def dim(self): 228 | return 1 229 | 230 | def get_current_parameter_value(self): 231 | return self.mppi.T 232 | 233 | def ensure_valid_value(self, value): 234 | if torch.is_tensor(value) or isinstance(value, np.ndarray): 235 | value = value[0] 236 | v = max(round(value), 1) 237 | return v 238 | 239 | def apply_parameter_value(self, value): 240 | v = self.ensure_valid_value(value) 241 | self.mppi.change_horizon(v) 242 | 243 | 244 | class Autotune: 245 | """Tune selected hyperparameters using state-of-the-art optimizers on an evaluation function. 246 | Subclass to define other parameters to optimize over such as terminal cost scaling. 247 | See tests/auto_tune_parameters.py for an example evaluate_fn 248 | """ 249 | eps = 0.0001 250 | 251 | def __init__(self, params_to_tune: typing.Sequence[TunableParameter], 252 | evaluate_fn: typing.Callable[[], EvaluationResult], 253 | reload_state_fn: typing.Callable[[], dict] = None, 254 | optimizer=CMAESOpt()): 255 | """ 256 | 257 | :param params_to_tune: sequence of tunable parameters 258 | :param evaluate_fn: function that returns an EvaluationResult that we want to minimize 259 | :param reload_state_fn: function that returns a dictionary of state to reattach to the parameters 260 | :param optimizer: optimizer that searches in the parameter space 261 | """ 262 | self.evaluate_fn = evaluate_fn 263 | self.reload_state_fn = reload_state_fn 264 | 265 | self.params = params_to_tune 266 | self.optim = optimizer 267 | self.optim.tuner = self 268 | self.results = [] 269 | 270 | self.attach_parameters() 271 | self.optim.setup_optimization() 272 | 273 | def optimize_step(self) -> EvaluationResult: 274 | res = self.optim.optimize_step() 275 | res = self.log_current_result(res) 276 | return res 277 | 278 | def optimize_all(self, iterations) -> EvaluationResult: 279 | res = self.optim.optimize_all(iterations) 280 | res = self.log_current_result(res) 281 | return res 282 | 283 | def get_best_result(self) -> EvaluationResult: 284 | return min(self.results, key=lambda res: res.costs.mean().item()) 285 | 286 | def log_current_result(self, res: EvaluationResult): 287 | with torch.no_grad(): 288 | iteration = len(self.results) 289 | kv = self.get_parameter_values(self.params) 290 | res = res._replace(iteration=iteration, 291 | params={k: v.detach().clone() if torch.is_tensor(v) else v for k, v in 292 | kv.items()}) 293 | logger.info(f"i:{iteration} cost: {res.costs.mean().item()} params:{kv}") 294 | self.results.append(res) 295 | return res 296 | 297 | def get_parameter_values(self, params_to_tune: typing.Sequence[TunableParameter]): 298 | # take on the assigned values to the MPPI 299 | return {p.name(): p.get_current_parameter_value() for p in params_to_tune} 300 | 301 | def flatten_params(self): 302 | x = [] 303 | kv = self.get_parameter_values(self.params) 304 | # TODO ensure this is the same order as define and unflatten 305 | for k, v in kv.items(): 306 | if torch.is_tensor(v): 307 | x.append(v.detach().cpu().numpy()) 308 | else: 309 | x.append([v]) 310 | x = np.concatenate(x) 311 | return x 312 | 313 | def unflatten_params(self, x, apply=True): 314 | # have to be in the same order as the flattening 315 | param_values = {} 316 | i = 0 317 | for p in self.params: 318 | raw_value = x[i:i + p.dim()] 319 | param_values[p.name()] = p.ensure_valid_value(raw_value) 320 | i += p.dim() 321 | if apply: 322 | 
self.apply_parameters(param_values) 323 | return param_values 324 | 325 | def apply_parameters(self, param_values): 326 | for p in self.params: 327 | p.apply_parameter_value(param_values[p.name()]) 328 | 329 | def attach_parameters(self): 330 | """Attach parameters to any underlying state they require In most cases the parameters are defined already 331 | attached to whatever state it needs, e.g. the MPPI controller object for changing the parameter values. 332 | However, there are cases where the full state is not serializable, e.g. when using a multiprocessing pool 333 | and so we pass only the information required to load the state. We then must load the state and reattach 334 | the parameters to the state each training iteration.""" 335 | if self.reload_state_fn is not None: 336 | state = self.reload_state_fn() 337 | for p in self.params: 338 | p.attach_to_state(state) 339 | 340 | def config_to_params(self, config): 341 | """Configs are param dictionaries where each must be a scalar""" 342 | return {p.name(): p.get_parameter_value_from_config(config) for p in self.params} 343 | -------------------------------------------------------------------------------- /src/pytorch_mppi/autotune_global.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | import torch.cuda 4 | 5 | # pip install "ray[tune]" bayesian-optimization hyperopt 6 | from ray import tune 7 | from ray import train 8 | 9 | from pytorch_mppi import autotune 10 | from ray.tune.search.hyperopt import HyperOptSearch 11 | 12 | 13 | class GlobalTunableParameter(autotune.TunableParameter, abc.ABC): 14 | def __init__(self, search_space): 15 | self.search_space = search_space 16 | 17 | @abc.abstractmethod 18 | def total_search_space(self) -> dict: 19 | """Return the potentially multidimensional search space for this parameter, which is a dictionary mapping 20 | each of the parameter's corresponding config names to a search space.""" 21 | 22 | def get_linearized_search_space_value(self, param_values): 23 | if self.dim() == 1: 24 | return [self._linearize_space_value(self.search_space, param_values[self.name()])] 25 | return [self._linearize_space_value(self.search_space, param_values[f"{self.name()}"][i].item()) for i in 26 | range(self.dim())] 27 | 28 | @staticmethod 29 | def linearize_search_space(space): 30 | # tune doesn't have public API for type checking samplers 31 | sampler = space.get_sampler() 32 | if hasattr(sampler, 'base'): 33 | b = np.log(sampler.base) 34 | return np.log(space.lower) / b, np.log(space.upper) / b 35 | return space.lower, space.upper 36 | 37 | @staticmethod 38 | def _linearize_space_value(space, v): 39 | # tune doesn't have public API for type checking samplers 40 | sampler = space.get_sampler() 41 | # log 42 | if hasattr(sampler, 'base'): 43 | b = np.log(sampler.base) 44 | return np.log(v) / b 45 | # quantized 46 | if hasattr(sampler, 'q'): 47 | return np.round(np.divide(v, sampler.q)) * sampler.q 48 | return v 49 | 50 | 51 | class SigmaGlobalParameter(autotune.SigmaParameter, GlobalTunableParameter): 52 | def __init__(self, *args, search_space=tune.loguniform(1e-4, 1e2), **kwargs): 53 | super().__init__(*args, **kwargs) 54 | GlobalTunableParameter.__init__(self, search_space) 55 | 56 | def total_search_space(self) -> dict: 57 | return {f"{self.name()}{i}": self.search_space for i in range(self.dim())} 58 | 59 | 60 | class MuGlobalParameter(autotune.MuParameter, GlobalTunableParameter): 61 | def __init__(self, *args, 
search_space=tune.uniform(-1, 1), **kwargs): 62 | super().__init__(*args, **kwargs) 63 | GlobalTunableParameter.__init__(self, search_space) 64 | 65 | def total_search_space(self) -> dict: 66 | return {f"{self.name()}{i}": self.search_space for i in range(self.dim())} 67 | 68 | 69 | class LambdaGlobalParameter(autotune.LambdaParameter, GlobalTunableParameter): 70 | def __init__(self, *args, search_space=tune.loguniform(1e-5, 1e3), **kwargs): 71 | super().__init__(*args, **kwargs) 72 | GlobalTunableParameter.__init__(self, search_space) 73 | 74 | def total_search_space(self) -> dict: 75 | return {self.name(): self.search_space} 76 | 77 | 78 | class HorizonGlobalParameter(autotune.HorizonParameter, GlobalTunableParameter): 79 | def __init__(self, *args, search_space=tune.randint(1, 50), **kwargs): 80 | super().__init__(*args, **kwargs) 81 | GlobalTunableParameter.__init__(self, search_space) 82 | 83 | def total_search_space(self) -> dict: 84 | return {self.name(): self.search_space} 85 | 86 | 87 | class AutotuneGlobal(autotune.Autotune): 88 | def search_space(self): 89 | space = {} 90 | for p in self.params: 91 | assert isinstance(p, GlobalTunableParameter) 92 | space.update(p.total_search_space()) 93 | return space 94 | 95 | def linearized_search_space(self): 96 | return {k: GlobalTunableParameter.linearize_search_space(space) for k, space in self.search_space().items()} 97 | 98 | def linearize_params(self, param_values): 99 | v = [] 100 | for p in self.params: 101 | assert isinstance(p, GlobalTunableParameter) 102 | v.extend(p.get_linearized_search_space_value(param_values)) 103 | return np.array(v) 104 | 105 | def initial_value(self): 106 | init = {} 107 | param_values = self.get_parameter_values(self.params) 108 | for p in self.params: 109 | assert isinstance(p, GlobalTunableParameter) 110 | init.update(p.get_config_from_parameter_value(param_values[p.name()])) 111 | return init 112 | 113 | 114 | class RayOptimizer(autotune.Optimizer): 115 | def __init__(self, search_alg=HyperOptSearch, 116 | default_iterations=100): 117 | self.iterations = default_iterations 118 | self.search_alg = search_alg 119 | self.all_res = None 120 | super().__init__() 121 | 122 | def setup_optimization(self): 123 | if not isinstance(self.tuner, AutotuneGlobal): 124 | raise RuntimeError(f"Ray optimizers require global search space information provided by AutotuneMPPIGlobal") 125 | space = self.tuner.search_space() 126 | init = self.tuner.initial_value() 127 | 128 | hyperopt_search = self.search_alg(points_to_evaluate=[init], metric="cost", mode="min") 129 | 130 | trainable_with_resources = tune.with_resources(self.trainable, {"gpu": 1 if torch.cuda.is_available() else 0}) 131 | self.optim = tune.Tuner( 132 | trainable_with_resources, 133 | tune_config=tune.TuneConfig( 134 | num_samples=self.iterations, 135 | search_alg=hyperopt_search, 136 | metric="cost", 137 | mode="min", 138 | ), 139 | param_space=space, 140 | ) 141 | 142 | def trainable(self, config): 143 | self.tuner.attach_parameters() 144 | self.tuner.apply_parameters(self.tuner.config_to_params(config)) 145 | res = self.tuner.evaluate_fn() 146 | train.report({'cost': res.costs.mean().item()}) 147 | 148 | def optimize_step(self): 149 | raise RuntimeError("Ray optimizers only allow tuning of all iterations at once") 150 | 151 | def optimize_all(self, iterations): 152 | self.iterations = iterations 153 | self.setup_optimization() 154 | self.all_res = self.optim.fit() 155 | 
self.tuner.apply_parameters(self.tuner.config_to_params(self.all_res.get_best_result().config)) 156 | res = self.tuner.evaluate_fn() 157 | return res 158 | -------------------------------------------------------------------------------- /src/pytorch_mppi/autotune_qd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # pip install ribs 4 | import ribs 5 | 6 | from pytorch_mppi import autotune 7 | from pytorch_mppi.autotune_global import AutotuneGlobal 8 | 9 | 10 | class CMAMEOpt(autotune.Optimizer): 11 | """Quality Diversity optimize using CMA-ME to find a set of good and diverse hyperparameters""" 12 | 13 | def __init__(self, population=10, sigma=1.0, bins=15): 14 | """ 15 | 16 | :param population: number of parameters to sample at once (scales linearly) 17 | :param sigma: initial variance along all dimensions 18 | :param bins: int or a Sequence[int] for each hyperparameter for the number of bins in the archive. 19 | More bins means more granularity along that dimension. 20 | """ 21 | self.population = population 22 | self.sigma = sigma 23 | self.archive = None 24 | self.qd_score_offset = -3000 25 | self.num_emitters = 1 26 | self.bins = bins 27 | super().__init__() 28 | 29 | def setup_optimization(self): 30 | if not isinstance(self.tuner, AutotuneGlobal): 31 | raise RuntimeError(f"Quality diversity optimizers require global search space information provided " 32 | f"by AutotuneMPPIGlobal") 33 | 34 | x = self.tuner.flatten_params() 35 | ranges = self.tuner.linearized_search_space() 36 | ranges = list(ranges.values()) 37 | 38 | param_dim = len(x) 39 | bins = self.bins 40 | if isinstance(bins, (int, float)): 41 | bins = [bins for _ in range(param_dim)] 42 | self.archive = ribs.archives.GridArchive(solution_dim=param_dim, 43 | dims=bins, 44 | ranges=ranges, 45 | seed=np.random.randint(0, 10000), qd_score_offset=self.qd_score_offset) 46 | emitters = [ 47 | ribs.emitters.EvolutionStrategyEmitter(self.archive, x0=x, sigma0=self.sigma, batch_size=self.population, 48 | seed=np.random.randint(0, 10000)) for i in 49 | range(self.num_emitters) 50 | ] 51 | self.optim = ribs.schedulers.Scheduler(self.archive, emitters) 52 | 53 | def optimize_step(self): 54 | if not isinstance(self.tuner, AutotuneGlobal): 55 | raise RuntimeError(f"Quality diversity optimizers require global search space information provided " 56 | f"by AutotuneMPPIGlobal") 57 | 58 | params = self.optim.ask() 59 | # measure is the whole hyperparameter set - we want to diverse along each dimension 60 | 61 | cost_per_param = [] 62 | all_rollouts = [] 63 | bcs = [] 64 | for param in params: 65 | full_param = self.tuner.unflatten_params(param) 66 | res = self.tuner.evaluate_fn() 67 | cost_per_param.append(res.costs.mean().cpu().numpy()) 68 | all_rollouts.append(res.rollouts) 69 | behavior = self.tuner.linearize_params(full_param) 70 | bcs.append(behavior) 71 | 72 | cost_per_param = np.array(cost_per_param) 73 | self.optim.tell(-cost_per_param, bcs) 74 | 75 | best_param = self.archive.best_elite 76 | # best_param = self.optim.best.x 77 | self.tuner.unflatten_params(best_param.solution) 78 | res = self.tuner.evaluate_fn() 79 | return res 80 | 81 | def get_diverse_top_parameters(self, num_top): 82 | df = self.archive.as_pandas() 83 | objectives = df.objective_batch() 84 | solutions = df.solution_batch() 85 | # store to allow restoring on next step 86 | if len(solutions) > num_top: 87 | order = np.argpartition(-objectives, num_top) 88 | solutions = solutions[order[:num_top]] 89 | 
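        # convert each flattened solution back into a parameter dictionary without applying it to the controller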
90 | return [self.tuner.unflatten_params(x, apply=False) for x in solutions] 91 | -------------------------------------------------------------------------------- /src/pytorch_mppi/mppi.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import typing 4 | 5 | import torch 6 | from torch.distributions.multivariate_normal import MultivariateNormal 7 | from arm_pytorch_utilities import handle_batch_input 8 | from functorch import vmap 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def _ensure_non_zero(cost, beta, factor): 14 | return torch.exp(-factor * (cost - beta)) 15 | 16 | 17 | class SpecificActionSampler: 18 | def __init__(self): 19 | self.start_idx = 0 20 | self.end_idx = 0 21 | self.slice = slice(0, 0) 22 | 23 | def sample_trajectories(self, state, info): 24 | raise NotImplementedError 25 | 26 | def specific_dynamics(self, next_state, state, action, t): 27 | """Handle dynamics in a specific way for the specific action sampler; defaults to using default dynamics""" 28 | return next_state 29 | 30 | def register_sample_start_end(self, start_idx, end_idx): 31 | self.start_idx = start_idx 32 | self.end_idx = end_idx 33 | self.slice = slice(start_idx, end_idx) 34 | 35 | 36 | class MPPI(): 37 | """ 38 | Model Predictive Path Integral control 39 | This implementation batch samples the trajectories and so scales well with the number of samples K. 40 | 41 | Implemented according to algorithm 2 in Williams et al., 2017 42 | 'Information Theoretic MPC for Model-Based Reinforcement Learning', 43 | based off of https://github.com/ferreirafabio/mppi_pendulum 44 | """ 45 | 46 | def __init__(self, dynamics, running_cost, nx, noise_sigma, num_samples=100, horizon=15, device="cpu", 47 | terminal_state_cost=None, 48 | lambda_=1., 49 | noise_mu=None, 50 | u_min=None, 51 | u_max=None, 52 | u_init=None, 53 | U_init=None, 54 | u_scale=1, 55 | u_per_command=1, 56 | step_dependent_dynamics=False, 57 | rollout_samples=1, 58 | rollout_var_cost=0, 59 | rollout_var_discount=0.95, 60 | sample_null_action=False, 61 | specific_action_sampler: typing.Optional[SpecificActionSampler] = None, 62 | noise_abs_cost=False): 63 | """ 64 | :param dynamics: function(state, action) -> next_state (K x nx) taking in batch state (K x nx) and action (K x nu) 65 | :param running_cost: function(state, action) -> cost (K) taking in batch state and action (same as dynamics) 66 | :param nx: state dimension 67 | :param noise_sigma: (nu x nu) control noise covariance (assume v_t ~ N(u_t, noise_sigma)) 68 | :param num_samples: K, number of trajectories to sample 69 | :param horizon: T, length of each trajectory 70 | :param device: pytorch device 71 | :param terminal_state_cost: function(state) -> cost (K x 1) taking in batch state 72 | :param lambda_: temperature, positive scalar where larger values will allow more exploration 73 | :param noise_mu: (nu) control noise mean (used to bias control samples); defaults to zero mean 74 | :param u_min: (nu) minimum values for each dimension of control to pass into dynamics 75 | :param u_max: (nu) maximum values for each dimension of control to pass into dynamics 76 | :param u_init: (nu) what to initialize new end of trajectory control to be; defeaults to zero 77 | :param U_init: (T x nu) initial control sequence; defaults to noise 78 | :param step_dependent_dynamics: whether the passed in dynamics needs horizon step passed in (as 3rd arg) 79 | :param rollout_samples: M, number of state trajectories to rollout for 
each control trajectory 80 | (should be 1 for deterministic dynamics and more for models that output a distribution) 81 | :param rollout_var_cost: Cost attached to the variance of costs across trajectory rollouts 82 | :param rollout_var_discount: Discount of variance cost over control horizon 83 | :param sample_null_action: Whether to explicitly sample a null action (bad for starting in a local minima) 84 | :param specific_action_sampler: Function to explicitly sample actions to use instead of sampling from noise from 85 | nominal trajectory, may output a number of action trajectories fewer than horizon 86 | :param noise_abs_cost: Whether to use the absolute value of the action noise to avoid bias when all states have the same cost 87 | """ 88 | self.d = device 89 | self.dtype = noise_sigma.dtype 90 | self.K = num_samples # N_SAMPLES 91 | self.T = horizon # TIMESTEPS 92 | 93 | # dimensions of state and control 94 | self.nx = nx 95 | self.nu = 1 if len(noise_sigma.shape) == 0 else noise_sigma.shape[0] 96 | self.lambda_ = lambda_ 97 | 98 | if noise_mu is None: 99 | noise_mu = torch.zeros(self.nu, dtype=self.dtype) 100 | 101 | if u_init is None: 102 | u_init = torch.zeros_like(noise_mu) 103 | 104 | # handle 1D edge case 105 | if self.nu == 1: 106 | noise_mu = noise_mu.view(-1) 107 | noise_sigma = noise_sigma.view(-1, 1) 108 | 109 | # bounds 110 | self.u_min = u_min 111 | self.u_max = u_max 112 | self.u_scale = u_scale 113 | self.u_per_command = u_per_command 114 | # make sure if any of them is specified, both are specified 115 | if self.u_max is not None and self.u_min is None: 116 | if not torch.is_tensor(self.u_max): 117 | self.u_max = torch.tensor(self.u_max) 118 | self.u_min = -self.u_max 119 | if self.u_min is not None and self.u_max is None: 120 | if not torch.is_tensor(self.u_min): 121 | self.u_min = torch.tensor(self.u_min) 122 | self.u_max = -self.u_min 123 | if self.u_min is not None: 124 | self.u_min = self.u_min.to(device=self.d) 125 | self.u_max = self.u_max.to(device=self.d) 126 | 127 | self.noise_mu = noise_mu.to(self.d) 128 | self.noise_sigma = noise_sigma.to(self.d) 129 | self.noise_sigma_inv = torch.inverse(self.noise_sigma) 130 | self.noise_dist = MultivariateNormal(self.noise_mu, covariance_matrix=self.noise_sigma) 131 | # T x nu control sequence 132 | self.U = U_init 133 | self.u_init = u_init.to(self.d) 134 | 135 | if self.U is None: 136 | self.U = self.noise_dist.sample((self.T,)) 137 | 138 | self.step_dependency = step_dependent_dynamics 139 | self.F = dynamics 140 | self.running_cost = running_cost 141 | self.terminal_state_cost = terminal_state_cost 142 | self.sample_null_action = sample_null_action 143 | self.specific_action_sampler = specific_action_sampler 144 | self.noise_abs_cost = noise_abs_cost 145 | self.state = None 146 | self.info = None 147 | 148 | # handling dynamics models that output a distribution (take multiple trajectory samples) 149 | self.M = rollout_samples 150 | self.rollout_var_cost = rollout_var_cost 151 | self.rollout_var_discount = rollout_var_discount 152 | 153 | # sampled results from last command 154 | self.cost_total = None 155 | self.cost_total_non_zero = None 156 | self.omega = None 157 | self.states = None 158 | self.actions = None 159 | 160 | def get_params(self): 161 | return f"K={self.K} T={self.T} M={self.M} lambda={self.lambda_} noise_mu={self.noise_mu.cpu().numpy()} noise_sigma={self.noise_sigma.cpu().numpy()}".replace( 162 | "\n", ",") 163 | 164 | @handle_batch_input(n=2) 165 | def _dynamics(self, state, u, t): 166 | return 
self.F(state, u, t) if self.step_dependency else self.F(state, u) 167 | 168 | @handle_batch_input(n=2) 169 | def _running_cost(self, state, u, t): 170 | return self.running_cost(state, u, t) if self.step_dependency else self.running_cost(state, u) 171 | 172 | def get_action_sequence(self): 173 | return self.U 174 | 175 | def shift_nominal_trajectory(self): 176 | """ 177 | Shift the nominal trajectory forward one step 178 | """ 179 | # shift command 1 time step 180 | self.U = torch.roll(self.U, -1, dims=0) 181 | self.U[-1] = self.u_init 182 | 183 | def command(self, state, shift_nominal_trajectory=True, info=None): 184 | """ 185 | :param state: (nx) or (K x nx) current state, or samples of states (for propagating a distribution of states) 186 | :param shift_nominal_trajectory: Whether to roll the nominal trajectory forward one step. This should be True 187 | if the command is to be executed. If the nominal trajectory is to be refined then it should be False. 188 | :param info: Optional dictionary to store context information 189 | :returns action: (nu) best action 190 | """ 191 | self.info = info 192 | if shift_nominal_trajectory: 193 | self.shift_nominal_trajectory() 194 | 195 | return self._command(state) 196 | 197 | def _compute_weighting(self, cost_total): 198 | beta = torch.min(cost_total) 199 | self.cost_total_non_zero = _ensure_non_zero(cost_total, beta, 1 / self.lambda_) 200 | eta = torch.sum(self.cost_total_non_zero) 201 | self.omega = (1. / eta) * self.cost_total_non_zero 202 | return self.omega 203 | 204 | def _command(self, state): 205 | if not torch.is_tensor(state): 206 | state = torch.tensor(state) 207 | self.state = state.to(dtype=self.dtype, device=self.d) 208 | cost_total = self._compute_total_cost_batch() 209 | 210 | self._compute_weighting(cost_total) 211 | perturbations = torch.sum(self.omega.view(-1, 1, 1) * self.noise, dim=0) 212 | 213 | self.U = self.U + perturbations 214 | action = self.get_action_sequence()[:self.u_per_command] 215 | # reduce dimensionality if we only need the first command 216 | if self.u_per_command == 1: 217 | action = action[0] 218 | return action 219 | 220 | def change_horizon(self, horizon): 221 | if horizon < self.U.shape[0]: 222 | # truncate trajectory 223 | self.U = self.U[:horizon] 224 | elif horizon > self.U.shape[0]: 225 | # extend with u_init 226 | self.U = torch.cat((self.U, self.u_init.repeat(horizon - self.U.shape[0], 1))) 227 | self.T = horizon 228 | 229 | def reset(self): 230 | """ 231 | Clear controller state after finishing a trial 232 | """ 233 | self.U = self.noise_dist.sample((self.T,)) 234 | 235 | def _compute_rollout_costs(self, perturbed_actions): 236 | K, T, nu = perturbed_actions.shape 237 | assert nu == self.nu 238 | 239 | cost_total = torch.zeros(K, device=self.d, dtype=self.dtype) 240 | cost_samples = cost_total.repeat(self.M, 1) 241 | cost_var = torch.zeros_like(cost_total) 242 | 243 | # allow propagation of a sample of states (ex. 
to carry a distribution), or to start with a single state 244 | if self.state.shape == (K, self.nx): 245 | state = self.state 246 | else: 247 | state = self.state.view(1, -1).repeat(K, 1) 248 | 249 | # rollout action trajectory M times to estimate expected cost 250 | state = state.repeat(self.M, 1, 1) 251 | 252 | states = [] 253 | actions = [] 254 | for t in range(T): 255 | u = self.u_scale * perturbed_actions[:, t].repeat(self.M, 1, 1) 256 | next_state = self._dynamics(state, u, t) 257 | # potentially handle dynamics in a specific way for the specific action sampler 258 | next_state = self._sample_specific_dynamics(next_state, state, u, t) 259 | state = next_state 260 | c = self._running_cost(state, u, t) 261 | cost_samples = cost_samples + c 262 | if self.M > 1: 263 | cost_var += c.var(dim=0) * (self.rollout_var_discount ** t) 264 | 265 | # Save total states/actions 266 | states.append(state) 267 | actions.append(u) 268 | 269 | # Actions is K x T x nu 270 | # States is K x T x nx 271 | actions = torch.stack(actions, dim=-2) 272 | states = torch.stack(states, dim=-2) 273 | 274 | # action perturbation cost 275 | if self.terminal_state_cost: 276 | c = self.terminal_state_cost(states, actions) 277 | cost_samples = cost_samples + c 278 | cost_total = cost_total + cost_samples.mean(dim=0) 279 | cost_total = cost_total + cost_var * self.rollout_var_cost 280 | return cost_total, states, actions 281 | 282 | def _compute_perturbed_action_and_noise(self): 283 | # parallelize sampling across trajectories 284 | # resample noise each time we take an action 285 | noise = self.noise_dist.rsample((self.K, self.T)) 286 | # broadcast own control to noise over samples; now it's K x T x nu 287 | perturbed_action = self.U + noise 288 | perturbed_action = self._sample_specific_actions(perturbed_action) 289 | # naively bound control 290 | self.perturbed_action = self._bound_action(perturbed_action) 291 | # bounded noise after bounding (some got cut off, so we don't penalize that in action cost) 292 | self.noise = self.perturbed_action - self.U 293 | 294 | def _sample_specific_actions(self, perturbed_action): 295 | # specific sampling of actions (encoding trajectory prior and domain knowledge to create biases) 296 | i = 0 297 | if self.sample_null_action: 298 | perturbed_action[i] = 0 299 | i += 1 300 | if self.specific_action_sampler is not None: 301 | actions = self.specific_action_sampler.sample_trajectories(self.state, self.info) 302 | # check how long it is 303 | actions = actions.reshape(-1, self.T, self.nu) 304 | perturbed_action[i:i + actions.shape[0]] = actions 305 | self.specific_action_sampler.register_sample_start_end(i, i + actions.shape[0]) 306 | i += actions.shape[0] 307 | return perturbed_action 308 | 309 | def _sample_specific_dynamics(self, next_state, state, u, t): 310 | if self.specific_action_sampler is not None: 311 | next_state = self.specific_action_sampler.specific_dynamics(next_state, state, u, t) 312 | return next_state 313 | 314 | def _compute_total_cost_batch(self): 315 | self._compute_perturbed_action_and_noise() 316 | if self.noise_abs_cost: 317 | action_cost = self.lambda_ * torch.abs(self.noise) @ self.noise_sigma_inv 318 | # NOTE: The original paper does self.lambda_ * torch.abs(self.noise) @ self.noise_sigma_inv, but this biases 319 | # the actions with low noise if all states have the same cost. With abs(noise) we prefer actions close to the 320 | # nomial trajectory. 
321 | else: 322 | action_cost = self.lambda_ * self.noise @ self.noise_sigma_inv # Like original paper 323 | 324 | rollout_cost, self.states, actions = self._compute_rollout_costs(self.perturbed_action) 325 | self.actions = actions / self.u_scale 326 | 327 | # action perturbation cost 328 | perturbation_cost = torch.sum(self.U * action_cost, dim=(1, 2)) 329 | self.cost_total = rollout_cost + perturbation_cost 330 | return self.cost_total 331 | 332 | def _bound_action(self, action): 333 | if self.u_max is not None: 334 | return torch.max(torch.min(action, self.u_max), self.u_min) 335 | return action 336 | 337 | def _slice_control(self, t): 338 | return slice(t * self.nu, (t + 1) * self.nu) 339 | 340 | def get_rollouts(self, state, num_rollouts=1, U=None): 341 | """ 342 | :param state: either (nx) vector or (num_rollouts x nx) for sampled initial states 343 | :param num_rollouts: Number of rollouts with same action sequence - for generating samples with stochastic 344 | dynamics 345 | :returns states: num_rollouts x T x nx vector of trajectories 346 | 347 | """ 348 | state = state.view(-1, self.nx) 349 | if state.size(0) == 1: 350 | state = state.repeat(num_rollouts, 1) 351 | 352 | if U is None: 353 | U = self.get_action_sequence() 354 | T = U.shape[0] 355 | states = torch.zeros((num_rollouts, T + 1, self.nx), dtype=U.dtype, device=U.device) 356 | states[:, 0] = state 357 | for t in range(T): 358 | next_state = self._dynamics(states[:, t].view(num_rollouts, -1), 359 | self.u_scale * U[t].tile(num_rollouts, 1), t) 360 | # dynamics may augment state; here we just take the first nx dimensions 361 | states[:, t + 1] = next_state[:, :self.nx] 362 | 363 | return states[:, 1:] 364 | 365 | 366 | class SMPPI(MPPI): 367 | """Smooth MPPI by lifting the control space and penalizing the change in action from 368 | https://arxiv.org/pdf/2112.09988 369 | """ 370 | 371 | def __init__(self, *args, w_action_seq_cost=1., delta_t=1., U_init=None, action_min=None, action_max=None, 372 | **kwargs): 373 | self.w_action_seq_cost = w_action_seq_cost 374 | self.delta_t = delta_t 375 | 376 | super().__init__(*args, U_init=U_init, **kwargs) 377 | 378 | # these are the actual commanded actions, which is now no longer directly sampled 379 | self.action_min = action_min 380 | self.action_max = action_max 381 | if self.action_min is not None and self.action_max is None: 382 | if not torch.is_tensor(self.action_min): 383 | self.action_min = torch.tensor(self.action_min) 384 | self.action_max = -self.action_min 385 | if self.action_max is not None and self.action_min is None: 386 | if not torch.is_tensor(self.action_max): 387 | self.action_max = torch.tensor(self.action_max) 388 | self.action_min = -self.action_max 389 | if self.action_min is not None: 390 | self.action_min = self.action_min.to(device=self.d) 391 | self.action_max = self.action_max.to(device=self.d) 392 | 393 | # this smooth formulation works better if control starts from 0 394 | if U_init is None: 395 | self.action_sequence = torch.zeros_like(self.U) 396 | else: 397 | self.action_sequence = U_init 398 | self.U = torch.zeros_like(self.U) 399 | 400 | def get_params(self): 401 | return f"{super().get_params()} w={self.w_action_seq_cost} t={self.delta_t}" 402 | 403 | def shift_nominal_trajectory(self): 404 | self.U = torch.roll(self.U, -1, dims=0) 405 | self.U[-1] = self.u_init 406 | self.action_sequence = torch.roll(self.action_sequence, -1, dims=0) 407 | self.action_sequence[-1] = self.action_sequence[-2] # add T-1 action to T 408 | 409 | def 
get_action_sequence(self): 410 | return self.action_sequence 411 | 412 | def reset(self): 413 | self.U = torch.zeros_like(self.U) 414 | self.action_sequence = torch.zeros_like(self.U) 415 | 416 | def change_horizon(self, horizon): 417 | if horizon < self.U.shape[0]: 418 | # truncate trajectory 419 | self.U = self.U[:horizon] 420 | self.action_sequence = self.action_sequence[:horizon] 421 | elif horizon > self.U.shape[0]: 422 | # extend with u_init 423 | extend_for = horizon - self.U.shape[0] 424 | self.U = torch.cat((self.U, self.u_init.repeat(extend_for, 1))) 425 | self.action_sequence = torch.cat((self.action_sequence, self.action_sequence[-1].repeat(extend_for, 1))) 426 | self.T = horizon 427 | 428 | def _bound_d_action(self, control): 429 | if self.u_max is not None: 430 | return torch.max(torch.min(control, self.u_max), self.u_min) # action 431 | return control 432 | 433 | def _bound_action(self, action): 434 | if self.action_max is not None: 435 | return torch.max(torch.min(action, self.action_max), self.action_min) 436 | return action 437 | 438 | def _command(self, state): 439 | if not torch.is_tensor(state): 440 | state = torch.tensor(state) 441 | self.state = state.to(dtype=self.dtype, device=self.d) 442 | cost_total = self._compute_total_cost_batch() 443 | 444 | self._compute_weighting(cost_total) 445 | perturbations = torch.sum(self.omega.view(-1, 1, 1) * self.noise, dim=0) 446 | 447 | self.U = self.U + perturbations 448 | # U is now the lifted control space, so we integrate it 449 | self.action_sequence += self.U * self.delta_t 450 | 451 | action = self.get_action_sequence()[:self.u_per_command] 452 | # reduce dimensionality if we only need the first command 453 | if self.u_per_command == 1: 454 | action = action[0] 455 | return action 456 | 457 | def _compute_perturbed_action_and_noise(self): 458 | # parallelize sampling across trajectories 459 | # resample noise each time we take an action 460 | noise = self.noise_dist.rsample((self.K, self.T)) 461 | # broadcast own control to noise over samples; now it's K x T x nu 462 | perturbed_control = self.U + noise 463 | # naively bound control 464 | self.perturbed_control = self._bound_d_action(perturbed_control) 465 | # bounded noise after bounding (some got cut off, so we don't penalize that in action cost) 466 | self.perturbed_action = self.action_sequence + perturbed_control * self.delta_t 467 | self.perturbed_action = self._sample_specific_actions(self.perturbed_action) 468 | self.perturbed_action = self._bound_action(self.perturbed_action) 469 | 470 | self.noise = (self.perturbed_action - self.action_sequence) / self.delta_t - self.U 471 | 472 | def _compute_total_cost_batch(self): 473 | self._compute_perturbed_action_and_noise() 474 | if self.noise_abs_cost: 475 | action_cost = self.lambda_ * torch.abs(self.noise) @ self.noise_sigma_inv 476 | # NOTE: The original paper does self.lambda_ * torch.abs(self.noise) @ self.noise_sigma_inv, but this biases 477 | # the actions with low noise if all states have the same cost. With abs(noise) we prefer actions close to the 478 | # nomial trajectory. 
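            # (clarifying note, added) on top of this perturbation cost, SMPPI adds an explicit smoothness penalty
            #   w_action_seq_cost * sum_t || u_scale * (a_{t+1} - a_t) ||^2
            # computed just below from torch.diff along the time dimension of the perturbed action sequence.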
479 | else: 480 | action_cost = self.lambda_ * self.noise @ self.noise_sigma_inv # Like original paper 481 | 482 | # action difference as cost 483 | action_diff = self.u_scale * torch.diff(self.perturbed_action, dim=-2) 484 | action_smoothness_cost = torch.sum(torch.square(action_diff), dim=(1, 2)) 485 | # handle non-homogeneous action sequence cost 486 | action_smoothness_cost *= self.w_action_seq_cost 487 | 488 | rollout_cost, self.states, actions = self._compute_rollout_costs(self.perturbed_action) 489 | self.actions = actions / self.u_scale 490 | 491 | # action perturbation cost 492 | perturbation_cost = torch.sum(self.U * action_cost, dim=(1, 2)) 493 | self.cost_total = rollout_cost + perturbation_cost + action_smoothness_cost 494 | return self.cost_total 495 | 496 | 497 | class TimeKernel: 498 | """Kernel acting on the time dimension of trajectories for use in interpolation and smoothing""" 499 | 500 | def __call__(self, t, tk): 501 | raise NotImplementedError 502 | 503 | 504 | class RBFKernel(TimeKernel): 505 | def __init__(self, sigma=1): 506 | self.sigma = sigma 507 | 508 | def __repr__(self): 509 | return f"RBFKernel(sigma={self.sigma})" 510 | 511 | def __call__(self, t, tk): 512 | d = torch.sum((t[:, None] - tk) ** 2, dim=-1) 513 | k = torch.exp(-d / (1e-8 + 2 * self.sigma ** 2)) 514 | return k 515 | 516 | 517 | class KMPPI(MPPI): 518 | """MPPI with kernel interpolation of control points for smoothing""" 519 | 520 | def __init__(self, *args, num_support_pts=None, kernel: TimeKernel = RBFKernel(), **kwargs): 521 | super().__init__(*args, **kwargs) 522 | self.num_support_pts = num_support_pts or self.T // 2 523 | # control points to be sampled 524 | self.theta = torch.zeros((self.num_support_pts, self.nu), dtype=self.dtype, device=self.d) 525 | self.Tk = None 526 | self.Hs = None 527 | # interpolation kernel 528 | self.interpolation_kernel = kernel 529 | self.intp_krnl = None 530 | self.prepare_vmap_interpolation() 531 | 532 | def get_params(self): 533 | return f"{super().get_params()} num_support_pts={self.num_support_pts} kernel={self.interpolation_kernel}" 534 | 535 | def reset(self): 536 | super().reset() 537 | self.theta.zero_() 538 | 539 | def shift_nominal_trajectory(self): 540 | super().shift_nominal_trajectory() 541 | self.theta, _ = self.do_kernel_interpolation(self.Tk[0] + 1, self.Tk[0], self.theta) 542 | 543 | def do_kernel_interpolation(self, t, tk, c): 544 | K = self.interpolation_kernel(t.unsqueeze(-1), tk.unsqueeze(-1)) 545 | Ktktk = self.interpolation_kernel(tk.unsqueeze(-1), tk.unsqueeze(-1)) 546 | # print(K.shape, Ktktk.shape) 547 | # row normalize K 548 | # K = K / K.sum(dim=1).unsqueeze(1) 549 | 550 | # KK = K @ torch.inverse(Ktktk) 551 | KK = torch.linalg.solve(Ktktk, K, left=False) 552 | 553 | return torch.matmul(KK, c), K 554 | 555 | def prepare_vmap_interpolation(self): 556 | self.Tk = torch.linspace(0, self.T - 1, int(self.num_support_pts), device=self.d, dtype=self.dtype).unsqueeze( 557 | 0).repeat(self.K, 1) 558 | self.Hs = torch.linspace(0, self.T - 1, int(self.T), device=self.d, dtype=self.dtype).unsqueeze(0).repeat( 559 | self.K, 1) 560 | self.intp_krnl = vmap(self.do_kernel_interpolation) 561 | 562 | def deparameterize_to_trajectory_single(self, theta): 563 | return self.do_kernel_interpolation(self.Hs[0], self.Tk[0], theta) 564 | 565 | def deparameterize_to_trajectory_batch(self, theta): 566 | assert theta.shape == (self.K, self.num_support_pts, self.nu) 567 | return self.intp_krnl(self.Hs, self.Tk, theta) 568 | 569 | def 
_compute_perturbed_action_and_noise(self): 570 | # parallelize sampling across trajectories 571 | # resample noise each time we take an action 572 | noise = self.noise_dist.rsample((self.K, self.num_support_pts)) 573 | perturbed_control_pts = self.theta + noise 574 | # control points in the same space as control and should be bounded 575 | perturbed_control_pts = self._bound_action(perturbed_control_pts) 576 | self.noise_theta = perturbed_control_pts - self.theta 577 | perturbed_action, _ = self.deparameterize_to_trajectory_batch(perturbed_control_pts) 578 | perturbed_action = self._sample_specific_actions(perturbed_action) 579 | # naively bound control 580 | self.perturbed_action = self._bound_action(perturbed_action) 581 | # bounded noise after bounding (some got cut off, so we don't penalize that in action cost) 582 | self.noise = self.perturbed_action - self.U 583 | 584 | def _command(self, state): 585 | if not torch.is_tensor(state): 586 | state = torch.tensor(state) 587 | self.state = state.to(dtype=self.dtype, device=self.d) 588 | cost_total = self._compute_total_cost_batch() 589 | 590 | self._compute_weighting(cost_total) 591 | perturbations = torch.sum(self.omega.view(-1, 1, 1) * self.noise_theta, dim=0) 592 | 593 | self.theta = self.theta + perturbations 594 | self.U, _ = self.deparameterize_to_trajectory_single(self.theta) 595 | 596 | action = self.get_action_sequence()[:self.u_per_command] 597 | # reduce dimensionality if we only need the first command 598 | if self.u_per_command == 1: 599 | action = action[0] 600 | return action 601 | 602 | 603 | def run_mppi(mppi, env, retrain_dynamics, retrain_after_iter=50, iter=1000, render=True): 604 | dataset = torch.zeros((retrain_after_iter, mppi.nx + mppi.nu), dtype=mppi.U.dtype, device=mppi.d) 605 | total_reward = 0 606 | for i in range(iter): 607 | state = env.unwrapped.state.copy() 608 | command_start = time.perf_counter() 609 | action = mppi.command(state) 610 | elapsed = time.perf_counter() - command_start 611 | res = env.step(action.cpu().numpy()) 612 | s, r = res[0], res[1] 613 | total_reward += r 614 | logger.debug("action taken: %.4f cost received: %.4f time taken: %.5fs", action, -r, elapsed) 615 | if render: 616 | env.render() 617 | 618 | di = i % retrain_after_iter 619 | if di == 0 and i > 0: 620 | retrain_dynamics(dataset) 621 | # don't have to clear dataset since it'll be overridden, but useful for debugging 622 | dataset.zero_() 623 | dataset[di, :mppi.nx] = torch.tensor(state, dtype=mppi.U.dtype) 624 | dataset[di, mppi.nx:] = action 625 | return total_reward, dataset 626 | -------------------------------------------------------------------------------- /tests/auto_tune_parameters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import typing 3 | 4 | import window_recorder 5 | from arm_pytorch_utilities import linalg 6 | import matplotlib.colors 7 | from matplotlib import pyplot as plt 8 | 9 | from pytorch_mppi.mppi import handle_batch_input 10 | 11 | from pytorch_mppi import autotune 12 | 13 | from pytorch_mppi import MPPI 14 | from pytorch_seed import seed 15 | import logging 16 | # import window_recorder 17 | from contextlib import nullcontext 18 | 19 | plt.switch_backend('Qt5Agg') 20 | 21 | logger = logging.getLogger(__file__) 22 | logging.basicConfig(level=logging.INFO, 23 | format='[%(levelname)s %(asctime)s %(pathname)s:%(lineno)d] %(message)s', 24 | datefmt='%m-%d %H:%M:%S') 25 | 26 | 27 | class LinearDeltaDynamics: 28 | def __init__(self, B): 29 | self.B 
= B 30 | 31 | @handle_batch_input(n=2) 32 | def __call__(self, state, action): 33 | nx = state + action @ self.B.transpose(0, 1) 34 | return nx 35 | 36 | 37 | class ScaledLinearDynamics: 38 | def __init__(self, cost, B): 39 | self.B = B 40 | self.cost = cost 41 | 42 | @handle_batch_input(n=2) 43 | def __call__(self, state, action): 44 | nx = state + action @ self.B.transpose(0, 1) / torch.log(self.cost(state) + 1e-8).reshape(-1, 1) * 2 45 | return nx 46 | 47 | 48 | class LQRCost: 49 | def __init__(self, Q, R, goal): 50 | self.Q = Q 51 | self.R = R 52 | self.goal = goal 53 | 54 | @handle_batch_input(n=2) 55 | def __call__(self, state, action=None): 56 | dx = self.goal - state 57 | c = linalg.batch_quadratic_product(dx, self.Q) 58 | if action is not None: 59 | c += linalg.batch_quadratic_product(action, self.R) 60 | return c 61 | 62 | 63 | class HillCost: 64 | def __init__(self, Q, center, cost_at_center=1): 65 | self.Q = Q 66 | self.center = center 67 | self.cost_at_center = cost_at_center 68 | 69 | @handle_batch_input(n=2) 70 | def __call__(self, state, action=None): 71 | dx = self.center - state 72 | d = linalg.batch_quadratic_product(dx, self.Q) 73 | c = self.cost_at_center * torch.exp(-d) 74 | return c 75 | 76 | 77 | class Toy2DEnvironment: 78 | def __init__(self, start=None, goal=None, dtype=torch.double, device="cpu", evaluate_running_cost=True, 79 | visualize=True, 80 | num_trajectories=5, 81 | terminal_scale=100, 82 | r=0.01): 83 | self.d = device 84 | self.dtype = dtype 85 | self.state_ranges = [ 86 | (-5, 5), 87 | (-5, 5) 88 | ] 89 | self.evaluate_running_cost = evaluate_running_cost 90 | self.num_trajectories = num_trajectories 91 | self.visualize = visualize 92 | self.nx = 2 93 | 94 | self.start = start or torch.tensor([-3, -2], device=self.d, dtype=self.dtype) 95 | self.goal = goal or torch.tensor([2, 2], device=self.d, dtype=self.dtype) 96 | 97 | self.costs = [] 98 | 99 | eye = torch.eye(2, device=self.d, dtype=self.dtype) 100 | goal_cost = LQRCost(eye, eye * r, self.goal) 101 | self.costs.append(goal_cost) 102 | 103 | # for increasing difficulty, we add some "hills" 104 | self.costs.append(HillCost(torch.tensor([[0.1, 0.05], [0.05, 0.1]], device=self.d, dtype=self.dtype) * 2.5, 105 | torch.tensor([-0.5, -1.], device=self.d, dtype=self.dtype), cost_at_center=200)) 106 | 107 | B = torch.tensor([[0.5, 0], [0, -0.5]], device=self.d, dtype=self.dtype) 108 | self.dynamics = LinearDeltaDynamics(B) 109 | # self.dynamics = ScaledLinearDynamics(self.running_cost, B) 110 | 111 | self.terminal_scale = terminal_scale 112 | self.start_visualization() 113 | 114 | def terminal_cost(self, states, actions): 115 | return self.terminal_scale * self.running_cost(states[..., -1, :]) 116 | 117 | @handle_batch_input(n=2) 118 | def running_cost(self, state, action=None): 119 | c = None 120 | for cost in self.costs: 121 | if c is None: 122 | c = cost(state, action) 123 | else: 124 | c += cost(state, action) 125 | return c 126 | 127 | def start_visualization(self): 128 | if self.visualize: 129 | plt.ion() 130 | plt.show() 131 | 132 | self.fig, self.ax = plt.subplots(figsize=(7, 7)) 133 | self.ax.set_aspect('equal') 134 | self.ax.set(xlim=self.state_ranges[0]) 135 | self.ax.set(ylim=self.state_ranges[0]) 136 | 137 | self.cmap = "Greys" 138 | # artists for clearing / redrawing 139 | self.start_artist = None 140 | self.goal_artist = None 141 | self.cost_artist = None 142 | self.rollout_artist = None 143 | self.draw_costs() 144 | self.draw_start() 145 | self.draw_goal() 146 | 147 | def draw_results(self, 
params, all_results: typing.Sequence[autotune.EvaluationResult]): 148 | iterations = [res.iteration for res in all_results] 149 | loss = [res.costs.mean().item() for res in all_results] 150 | 151 | # loss curve 152 | fig, ax = plt.subplots() 153 | ax.plot(iterations, loss) 154 | ax.set_xlabel('iteration') 155 | ax.set_ylabel('cost') 156 | plt.pause(0.001) 157 | plt.savefig('cost.png') 158 | 159 | if 'sigma' in params: 160 | sigma = [res.params['sigma'] for res in all_results] 161 | sigma = torch.stack(sigma) 162 | fig, ax = plt.subplots(nrows=2, sharex=True) 163 | ax[0].plot(iterations, sigma[:, 0]) 164 | ax[1].plot(iterations, sigma[:, 1]) 165 | ax[1].set_xlabel('iteration') 166 | ax[0].set_ylabel('sigma[0]') 167 | ax[1].set_ylabel('sigma[1]') 168 | plt.draw() 169 | plt.pause(0.005) 170 | plt.savefig('sigma.png') 171 | 172 | def draw_rollouts(self, rollouts): 173 | if not self.visualize: 174 | return 175 | self.clear_artist(self.rollout_artist) 176 | artists = [] 177 | for rollout in rollouts: 178 | r = torch.cat((self.start.reshape(1, -1), rollout)) 179 | artists += self.ax.plot(r[:, 0], r[:, 1], color="skyblue") 180 | artists += [self.ax.scatter(r[-1, 0], r[-1, 1], color="tab:red")] 181 | self.rollout_artist = artists 182 | plt.pause(0.001) 183 | 184 | def draw_costs(self, resolution=0.05, value_padding=0): 185 | if not self.visualize: 186 | return 187 | coords = [torch.arange(low, high + resolution, resolution, dtype=self.dtype, device=self.d) for low, high in 188 | self.state_ranges] 189 | pts = torch.cartesian_prod(*coords) 190 | val = self.running_cost(pts) 191 | 192 | norm = matplotlib.colors.Normalize(vmin=val.min().cpu() - value_padding, vmax=val.max().cpu()) 193 | 194 | x = coords[0].cpu() 195 | z = coords[1].cpu() 196 | v = val.reshape(len(x), len(z)).transpose(0, 1).cpu() 197 | 198 | self.clear_artist(self.cost_artist) 199 | a = [] 200 | a.append(self.ax.contourf(x, z, v, levels=[2, 4, 8, 16, 24, 32, 40, 50, 60, 80, 100, 150, 200, 250], norm=norm, 201 | cmap=self.cmap)) 202 | a.append(self.ax.contour(x, z, v, levels=a[0].levels, colors='k', linestyles='dashed')) 203 | a.append(self.ax.clabel(a[1], a[1].levels, inline=True, fontsize=13)) 204 | self.cost_artist = a 205 | 206 | plt.draw() 207 | plt.pause(0.0005) 208 | 209 | @staticmethod 210 | def clear_artist(artist): 211 | if artist is not None: 212 | for a in artist: 213 | a.remove() 214 | 215 | def draw_start(self): 216 | if not self.visualize: 217 | return 218 | self.clear_artist(self.start_artist) 219 | self.start_artist = self.draw_state(self.start, "tab:blue", label='start') 220 | 221 | def draw_goal(self): 222 | # when combined with other costs it's no longer the single goal so no need for label 223 | return 224 | if not self.visualize: 225 | return 226 | self.clear_artist(self.goal_artist) 227 | # when combined with other costs it's no longer the single goal so no need for label 228 | self.goal_artist = self.draw_state(self.goal, "tab:green") # , label='goal') 229 | 230 | def draw_state(self, state, color, label=None, ox=-0.3, oy=0.3): 231 | artists = [self.ax.scatter(state[0].cpu(), state[1].cpu(), color=color)] 232 | if label is not None: 233 | artists.append(self.ax.text(state[0].cpu() + ox, state[1].cpu() + oy, label, color=color)) 234 | plt.pause(0.0001) 235 | return artists 236 | 237 | 238 | def main(): 239 | seed(1) 240 | device = "cpu" 241 | dtype = torch.double 242 | 243 | # create toy environment to do on control on (default start and goal) 244 | env = Toy2DEnvironment(visualize=True, terminal_scale=10) 245 | 
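    # (note, added) the tuner below treats noise_sigma, the horizon, and lambda_ as the free parameters;
    # each optimization step calls evaluate(), which resets the controller to the same nominal trajectory,
    # warm-starts it with a few refinement commands, and scores the resulting rollouts from env.start.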
246 | # create MPPI with some initial parameters 247 | mppi = MPPI(env.dynamics, env.running_cost, 2, 248 | noise_sigma=torch.diag(torch.tensor([5., 5.], dtype=dtype, device=device)), 249 | num_samples=500, 250 | horizon=20, device=device, 251 | terminal_state_cost=env.terminal_cost, 252 | u_max=torch.tensor([2., 2.], dtype=dtype, device=device), 253 | lambda_=1) 254 | 255 | # use the same nominal trajectory to start with for all the evaluations for fairness 256 | nominal_trajectory = mppi.U.clone() 257 | # parameters for our sample evaluation function - lots of choices for the evaluation function 258 | evaluate_running_cost = True 259 | num_refinement_steps = 10 260 | num_trajectories = 5 261 | 262 | def evaluate(): 263 | costs = [] 264 | rollouts = [] 265 | # we sample multiple trajectories for the same start to goal problem, but in your case you should consider 266 | # evaluating over a diverse dataset of trajectories 267 | for j in range(num_trajectories): 268 | mppi.U = nominal_trajectory.clone() 269 | # the nominal trajectory at the start will be different if the horizon's changed 270 | mppi.change_horizon(mppi.T) 271 | # usually MPPI will have its nominal trajectory warm-started from the previous iteration 272 | # for a fair test of tuning we will reset its nominal trajectory to the same random one each time 273 | # we manually warm it by refining it for some steps 274 | for k in range(num_refinement_steps): 275 | mppi.command(env.start, shift_nominal_trajectory=False) 276 | 277 | rollout = mppi.get_rollouts(env.start) 278 | 279 | this_cost = 0 280 | rollout = rollout[0] 281 | # here we evaluate on the rollout MPPI cost of the resulting trajectories 282 | # alternative costs for tuning the parameters are possible, such as just considering terminal cost 283 | if evaluate_running_cost: 284 | for t in range(len(rollout) - 1): 285 | this_cost = this_cost + env.running_cost(rollout[t], mppi.U[t]) 286 | this_cost = this_cost + env.terminal_cost(rollout, mppi.U) 287 | 288 | rollouts.append(rollout) 289 | costs.append(this_cost) 290 | return autotune.EvaluationResult(torch.stack(costs), torch.stack(rollouts)) 291 | 292 | # choose from autotune.AutotuneMPPI.TUNABLE_PARAMS 293 | params_to_tune = [autotune.SigmaParameter(mppi), autotune.HorizonParameter(mppi), autotune.LambdaParameter(mppi)] 294 | # create a tuner with a CMA-ES optimizer 295 | # tuner = autotune.Autotune(params_to_tune, evaluate_fn=evaluate, optimizer=autotune.CMAESOpt(sigma=1.0)) 296 | # # tune parameters for a number of iterations 297 | # with window_recorder.WindowRecorder(["Figure 1"]): 298 | # iterations = 30 299 | # for i in range(iterations): 300 | # # results of this optimization step are returned 301 | # res = tuner.optimize_step() 302 | # # we can render the rollouts in the environment 303 | # env.draw_rollouts(res.rollouts) 304 | # 305 | # # get best results and apply it to the controller 306 | # # (by default the controller will take on the latest tuned parameter, which may not be best) 307 | # res = tuner.get_best_result() 308 | # tuner.apply_parameters(res.params) 309 | # env.draw_results(res.params, tuner.results) 310 | 311 | try: 312 | # can also use a Ray Tune optimizer, see 313 | # https://docs.ray.io/en/latest/tune/api_docs/suggestion.html#search-algorithms-tune-search 314 | # rather than adapting the current parameters, these optimizers allow you to define a search space for each 315 | # and will search on that space 316 | # be sure to close plt windows or else ray will duplicate them 317 | from 
pytorch_mppi import autotune_global 318 | from ray.tune.search.hyperopt import HyperOptSearch 319 | from ray.tune.search.bayesopt import BayesOptSearch 320 | 321 | params_to_tune = [autotune_global.SigmaGlobalParameter(mppi), 322 | autotune_global.HorizonGlobalParameter(mppi), 323 | autotune_global.LambdaGlobalParameter(mppi)] 324 | env.visualize = False 325 | plt.close('all') 326 | tuner = autotune_global.AutotuneGlobal(params_to_tune, evaluate_fn=evaluate, 327 | optimizer=autotune_global.RayOptimizer(HyperOptSearch)) 328 | # ray tuners cannot be tuned iteratively, but you can specify how many iterations to tune for 329 | res = tuner.optimize_all(100) 330 | env.visualize = True 331 | env.start_visualization() 332 | env.draw_rollouts(res.rollouts) 333 | env.draw_results(res.params, tuner.results) 334 | 335 | # can also use quality diversity optimization 336 | # import pytorch_mppi.autotune_qd 337 | # optim = pytorch_mppi.autotune_qd.CMAMEOpt() 338 | # tuner = autotune_global.AutotuneGlobal(mppi, params_to_tune, evaluate_fn=evaluate, 339 | # optimizer=optim) 340 | # 341 | # iterations = 10 342 | # for i in range(iterations): 343 | # # results of this optimization step are returned 344 | # res = tuner.optimize_step() 345 | # # we can render the rollouts in the environment 346 | # best_params = optim.get_diverse_top_parameters(5) 347 | # for res in best_params: 348 | # logger.info(res) 349 | 350 | except ImportError: 351 | print("To test the ray tuning, install with:\npip install 'ray[tune]' bayesian-optimization hyperopt") 352 | pass 353 | 354 | 355 | if __name__ == "__main__": 356 | main() 357 | -------------------------------------------------------------------------------- /tests/pendulum.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | import logging 5 | import math 6 | from pytorch_mppi import mppi 7 | from gym import logger as gym_log 8 | 9 | gym_log.set_level(gym_log.INFO) 10 | logger = logging.getLogger(__name__) 11 | logging.basicConfig(level=logging.DEBUG, 12 | format='[%(levelname)s %(asctime)s %(pathname)s:%(lineno)d] %(message)s', 13 | datefmt='%m-%d %H:%M:%S') 14 | 15 | if __name__ == "__main__": 16 | ENV_NAME = "Pendulum-v1" 17 | TIMESTEPS = 15 # T 18 | N_SAMPLES = 100 # K 19 | ACTION_LOW = -2.0 20 | ACTION_HIGH = 2.0 21 | 22 | d = "cpu" 23 | dtype = torch.double 24 | 25 | noise_sigma = torch.tensor(10, device=d, dtype=dtype) 26 | # noise_sigma = torch.tensor([[10, 0], [0, 10]], device=d, dtype=dtype) 27 | lambda_ = 1. 
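    # NOTE (illustrative, added): MPPI calls dynamics and running_cost with batched tensors, so the functions
    # below must accept and return batches, e.g.
    #   s = torch.zeros(N_SAMPLES, 2, dtype=dtype)   # K x nx states
    #   a = torch.zeros(N_SAMPLES, 1, dtype=dtype)   # K x nu actions
    #   dynamics(s, a)      -> (N_SAMPLES, 2) next states
    #   running_cost(s, a)  -> (N_SAMPLES,) costs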
28 | 29 | 30 | def dynamics(state, perturbed_action): 31 | # true dynamics from gym 32 | th = state[:, 0].view(-1, 1) 33 | thdot = state[:, 1].view(-1, 1) 34 | 35 | g = 10 36 | m = 1 37 | l = 1 38 | dt = 0.05 39 | 40 | u = perturbed_action 41 | u = torch.clamp(u, -2, 2) 42 | 43 | newthdot = thdot + (3 * g / (2 * l) * np.sin(th) + 3.0 / (m * l ** 2) * u) * dt 44 | newthdot = np.clip(newthdot, -8, 8) 45 | newth = th + newthdot * dt 46 | 47 | state = torch.cat((newth, newthdot), dim=1) 48 | return state 49 | 50 | 51 | def angle_normalize(x): 52 | return (((x + math.pi) % (2 * math.pi)) - math.pi) 53 | 54 | 55 | def running_cost(state, action): 56 | theta = state[:, 0] 57 | theta_dt = state[:, 1] 58 | action = action[:, 0] 59 | cost = angle_normalize(theta) ** 2 + 0.1 * theta_dt ** 2 60 | return cost 61 | 62 | 63 | def train(new_data): 64 | pass 65 | 66 | 67 | downward_start = True 68 | env = gym.make(ENV_NAME, render_mode="human") 69 | 70 | env.reset() 71 | if downward_start: 72 | env.state = env.unwrapped.state = [np.pi, 1] 73 | 74 | nx = 2 75 | mppi_gym = mppi.MPPI(dynamics, running_cost, nx, noise_sigma, num_samples=N_SAMPLES, horizon=TIMESTEPS, 76 | lambda_=lambda_, u_min=torch.tensor(ACTION_LOW, device=d), 77 | u_max=torch.tensor(ACTION_HIGH, device=d), device=d) 78 | total_reward = mppi.run_mppi(mppi_gym, env, train) 79 | logger.info("Total reward %f", total_reward) 80 | -------------------------------------------------------------------------------- /tests/pendulum_approximate.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | import logging 5 | import math 6 | from pytorch_mppi import mppi 7 | from gym import logger as gym_log 8 | 9 | gym_log.set_level(gym_log.INFO) 10 | logger = logging.getLogger(__name__) 11 | logging.basicConfig(level=logging.INFO, 12 | format='[%(levelname)s %(asctime)s %(pathname)s:%(lineno)d] %(message)s', 13 | datefmt='%m-%d %H:%M:%S') 14 | 15 | if __name__ == "__main__": 16 | ENV_NAME = "Pendulum-v1" 17 | TIMESTEPS = 30 # T 18 | N_SAMPLES = 1000 # K 19 | ACTION_LOW = -2.0 20 | ACTION_HIGH = 2.0 21 | 22 | d = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 23 | dtype = torch.double 24 | 25 | noise_sigma = torch.tensor(1, device=d, dtype=dtype) 26 | # noise_sigma = torch.tensor([[10, 0], [0, 10]], device=d, dtype=dtype) 27 | lambda_ = 1. 
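    # NOTE (added for clarity): the script bootstraps the residual network with BOOT_STRAP_ITER random
    # transitions, then mppi.run_mppi() below calls train() with the freshly collected (state, action) rows
    # every retrain_after_iter steps (50 by default), so the approximate dynamics keep improving online.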
28 | 29 | import random 30 | 31 | randseed = 25 32 | if randseed is None: 33 | randseed = random.randint(0, 1000000) 34 | random.seed(randseed) 35 | np.random.seed(randseed) 36 | torch.manual_seed(randseed) 37 | logger.info("random seed %d", randseed) 38 | 39 | # new hyperparmaeters for approximate dynamics 40 | H_UNITS = 32 41 | TRAIN_EPOCH = 150 42 | BOOT_STRAP_ITER = 100 43 | 44 | nx = 2 45 | nu = 1 46 | # network output is state residual 47 | network = torch.nn.Sequential( 48 | torch.nn.Linear(nx + nu, H_UNITS), 49 | torch.nn.Tanh(), 50 | torch.nn.Linear(H_UNITS, H_UNITS), 51 | torch.nn.Tanh(), 52 | torch.nn.Linear(H_UNITS, nx) 53 | ).double().to(device=d) 54 | 55 | 56 | def dynamics(state, perturbed_action): 57 | u = torch.clamp(perturbed_action, ACTION_LOW, ACTION_HIGH) 58 | if state.dim() == 1 or u.dim() == 1: 59 | state = state.view(1, -1) 60 | u = u.view(1, -1) 61 | if u.shape[1] > 1: 62 | u = u[:, 0].view(-1, 1) 63 | xu = torch.cat((state, u), dim=1) 64 | state_residual = network(xu) 65 | next_state = state + state_residual 66 | next_state[:, 0] = angle_normalize(next_state[:, 0]) 67 | return next_state 68 | 69 | 70 | def true_dynamics(state, perturbed_action): 71 | # true dynamics from gym 72 | th = state[:, 0].view(-1, 1) 73 | thdot = state[:, 1].view(-1, 1) 74 | 75 | g = 10 76 | m = 1 77 | l = 1 78 | dt = 0.05 79 | 80 | u = perturbed_action 81 | u = torch.clamp(u, -2, 2) 82 | 83 | newthdot = thdot + (-3 * g / (2 * l) * torch.sin(th + np.pi) + 3. / (m * l ** 2) * u) * dt 84 | newth = th + newthdot * dt 85 | newthdot = torch.clamp(newthdot, -8, 8) 86 | 87 | state = torch.cat((newth, newthdot), dim=1) 88 | return state 89 | 90 | 91 | def angular_diff_batch(a, b): 92 | """Angle difference from b to a (a - b)""" 93 | d = a - b 94 | d[d > math.pi] -= 2 * math.pi 95 | d[d < -math.pi] += 2 * math.pi 96 | return d 97 | 98 | 99 | def angle_normalize(x): 100 | return (((x + math.pi) % (2 * math.pi)) - math.pi) 101 | 102 | 103 | def running_cost(state, action): 104 | theta = state[:, 0] 105 | theta_dt = state[:, 1] 106 | action = action[:, 0] 107 | cost = angle_normalize(theta) ** 2 + 0.1 * theta_dt ** 2 108 | return cost 109 | 110 | 111 | dataset = None 112 | # create some true dynamics validation set to compare model against 113 | Nv = 1000 114 | statev = torch.cat(((torch.rand(Nv, 1, dtype=torch.double, device=d) - 0.5) * 2 * math.pi, 115 | (torch.rand(Nv, 1, dtype=torch.double, device=d) - 0.5) * 16), dim=1) 116 | actionv = (torch.rand(Nv, 1, dtype=torch.double, device=d) - 0.5) * (ACTION_HIGH - ACTION_LOW) 117 | 118 | 119 | def train(new_data): 120 | global dataset 121 | # not normalized inside the simulator 122 | new_data[:, 0] = angle_normalize(new_data[:, 0]) 123 | if not torch.is_tensor(new_data): 124 | new_data = torch.from_numpy(new_data) 125 | # clamp actions 126 | new_data[:, -1] = torch.clamp(new_data[:, -1], ACTION_LOW, ACTION_HIGH) 127 | new_data = new_data.to(device=d) 128 | # append data to whole dataset 129 | if dataset is None: 130 | dataset = new_data 131 | else: 132 | dataset = torch.cat((dataset, new_data), dim=0) 133 | 134 | # train on the whole dataset (assume small enough we can train on all together) 135 | XU = dataset 136 | dtheta = angular_diff_batch(XU[1:, 0], XU[:-1, 0]) 137 | dtheta_dt = XU[1:, 1] - XU[:-1, 1] 138 | Y = torch.cat((dtheta.view(-1, 1), dtheta_dt.view(-1, 1)), dim=1) # x' - x residual 139 | XU = XU[:-1] # make same size as Y 140 | 141 | # thaw network 142 | for param in network.parameters(): 143 | param.requires_grad = True 144 | 145 | 
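        # (note, added) train with Adam on the full dataset; the parameters are frozen again after this loop,
        # presumably so the many network() calls inside MPPI rollouts do not track gradients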
optimizer = torch.optim.Adam(network.parameters()) 146 | for epoch in range(TRAIN_EPOCH): 147 | optimizer.zero_grad() 148 | # MSE loss 149 | Yhat = network(XU) 150 | loss = (Y - Yhat).norm(2, dim=1) ** 2 151 | loss.mean().backward() 152 | optimizer.step() 153 | logger.debug("ds %d epoch %d loss %f", dataset.shape[0], epoch, loss.mean().item()) 154 | 155 | # freeze network 156 | for param in network.parameters(): 157 | param.requires_grad = False 158 | 159 | # evaluate network against true dynamics 160 | yt = true_dynamics(statev, actionv) 161 | yp = dynamics(statev, actionv) 162 | dtheta = angular_diff_batch(yp[:, 0], yt[:, 0]) 163 | dtheta_dt = yp[:, 1] - yt[:, 1] 164 | E = torch.cat((dtheta.view(-1, 1), dtheta_dt.view(-1, 1)), dim=1).norm(dim=1) 165 | logger.info("Error with true dynamics theta %f theta_dt %f norm %f", dtheta.abs().mean(), 166 | dtheta_dt.abs().mean(), E.mean()) 167 | logger.debug("Start next collection sequence") 168 | 169 | 170 | downward_start = True 171 | env = gym.make(ENV_NAME, render_mode="human") # bypass the default TimeLimit wrapper 172 | env.reset() 173 | if downward_start: 174 | env.state = env.unwrapped.state = [np.pi, 1] 175 | 176 | # bootstrap network with random actions 177 | if BOOT_STRAP_ITER: 178 | logger.info("bootstrapping with random action for %d actions", BOOT_STRAP_ITER) 179 | new_data = np.zeros((BOOT_STRAP_ITER, nx + nu)) 180 | for i in range(BOOT_STRAP_ITER): 181 | pre_action_state = env.state 182 | action = np.random.uniform(low=ACTION_LOW, high=ACTION_HIGH) 183 | env.step([action]) 184 | # env.render() 185 | new_data[i, :nx] = pre_action_state 186 | new_data[i, nx:] = action 187 | 188 | train(new_data) 189 | logger.info("bootstrapping finished") 190 | 191 | env.reset() 192 | if downward_start: 193 | env.state = env.unwrapped.state = [np.pi, 1] 194 | 195 | mppi_gym = mppi.MPPI(dynamics, running_cost, nx, noise_sigma, num_samples=N_SAMPLES, horizon=TIMESTEPS, 196 | lambda_=lambda_, device=d, u_min=torch.tensor(ACTION_LOW, dtype=torch.double, device=d), 197 | u_max=torch.tensor(ACTION_HIGH, dtype=torch.double, device=d)) 198 | total_reward, data = mppi.run_mppi(mppi_gym, env, train) 199 | logger.info("Total reward %f", total_reward) 200 | -------------------------------------------------------------------------------- /tests/pendulum_approximate_continuous.py: -------------------------------------------------------------------------------- 1 | """ 2 | Same as approximate dynamics, but now the input is sine and cosine of theta (output is still dtheta) 3 | This is a continuous representation of theta, which some papers show is easier for a NN to learn. 4 | """ 5 | import gym 6 | import numpy as np 7 | import torch 8 | import logging 9 | import math 10 | from pytorch_mppi import mppi 11 | from gym import logger as gym_log 12 | 13 | gym_log.set_level(gym_log.INFO) 14 | logger = logging.getLogger(__name__) 15 | logging.basicConfig(level=logging.INFO, 16 | format='[%(levelname)s %(asctime)s %(pathname)s:%(lineno)d] %(message)s', 17 | datefmt='%m-%d %H:%M:%S') 18 | 19 | if __name__ == "__main__": 20 | ENV_NAME = "Pendulum-v1" 21 | TIMESTEPS = 15 # T 22 | N_SAMPLES = 100 # K 23 | ACTION_LOW = -2.0 24 | ACTION_HIGH = 2.0 25 | 26 | d = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 27 | dtype = torch.double 28 | 29 | noise_sigma = torch.tensor(1, device=d, dtype=dtype) 30 | # noise_sigma = torch.tensor([[10, 0], [0, 10]], device=d, dtype=dtype) 31 | lambda_ = 1. 
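    # NOTE (added for clarity): unlike pendulum_approximate.py, the network here is fed
    # (sin(theta), cos(theta), theta_dot, u) -- hence the nx + nu + 1 input units -- while still predicting
    # the (dtheta, dtheta_dot) residual, avoiding the discontinuity of raw theta at the wrap-around.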
32 | 33 | import random 34 | 35 | randseed = 24 36 | if randseed is None: 37 | randseed = random.randint(0, 1000000) 38 | random.seed(randseed) 39 | np.random.seed(randseed) 40 | torch.manual_seed(randseed) 41 | logger.info("random seed %d", randseed) 42 | 43 | # new hyperparmaeters for approximate dynamics 44 | H_UNITS = 32 45 | TRAIN_EPOCH = 150 46 | BOOT_STRAP_ITER = 100 47 | 48 | nx = 2 49 | nu = 1 50 | # network output is state residual 51 | network = torch.nn.Sequential( 52 | torch.nn.Linear(nx + nu + 1, H_UNITS), 53 | torch.nn.Tanh(), 54 | torch.nn.Linear(H_UNITS, H_UNITS), 55 | torch.nn.Tanh(), 56 | torch.nn.Linear(H_UNITS, nx) 57 | ).double().to(device=d) 58 | 59 | 60 | def dynamics(state, perturbed_action): 61 | u = torch.clamp(perturbed_action, ACTION_LOW, ACTION_HIGH) 62 | if state.dim() == 1 or u.dim() == 1: 63 | state = state.view(1, -1) 64 | u = u.view(1, -1) 65 | if u.shape[1] > 1: 66 | u = u[:, 0].view(-1, 1) 67 | xu = torch.cat((state, u), dim=1) 68 | # feed in cosine and sine of angle instead of theta 69 | xu = torch.cat((torch.sin(xu[:, 0]).view(-1, 1), torch.cos(xu[:, 0]).view(-1, 1), xu[:, 1:]), dim=1) 70 | state_residual = network(xu) 71 | # output dtheta directly so can just add 72 | next_state = state + state_residual 73 | next_state[:, 0] = angle_normalize(next_state[:, 0]) 74 | return next_state 75 | 76 | 77 | def true_dynamics(state, perturbed_action): 78 | # true dynamics from gym 79 | th = state[:, 0].view(-1, 1) 80 | thdot = state[:, 1].view(-1, 1) 81 | 82 | g = 10 83 | m = 1 84 | l = 1 85 | dt = 0.05 86 | 87 | u = perturbed_action 88 | u = torch.clamp(u, -2, 2) 89 | 90 | newthdot = thdot + (3 * g / (2 * l) * torch.sin(th) + 3.0 / (m * l**2) * u) * dt 91 | newthdot = torch.clip(newthdot, -8, 8) 92 | newth = th + newthdot * dt 93 | 94 | state = torch.cat((newth, newthdot), dim=1) 95 | return state 96 | 97 | 98 | def angular_diff_batch(a, b): 99 | """Angle difference from b to a (a - b)""" 100 | d = a - b 101 | d[d > math.pi] -= 2 * math.pi 102 | d[d < -math.pi] += 2 * math.pi 103 | return d 104 | 105 | 106 | def angle_normalize(x): 107 | return (((x + math.pi) % (2 * math.pi)) - math.pi) 108 | 109 | 110 | def running_cost(state, action): 111 | theta = state[:, 0] 112 | theta_dt = state[:, 1] 113 | action = action[:, 0] 114 | cost = angle_normalize(theta) ** 2 + 0.1 * theta_dt ** 2 115 | return cost 116 | 117 | 118 | dataset = None 119 | # create some true dynamics validation set to compare model against 120 | Nv = 1000 121 | statev = torch.cat(((torch.rand(Nv, 1, dtype=torch.double, device=d) - 0.5) * 2 * math.pi, 122 | (torch.rand(Nv, 1, dtype=torch.double, device=d) - 0.5) * 16), dim=1) 123 | actionv = (torch.rand(Nv, 1, dtype=torch.double, device=d) - 0.5) * (ACTION_HIGH - ACTION_LOW) 124 | 125 | 126 | def train(new_data): 127 | global dataset 128 | # not normalized inside the simulator 129 | new_data[:, 0] = angle_normalize(new_data[:, 0]) 130 | if not torch.is_tensor(new_data): 131 | new_data = torch.from_numpy(new_data) 132 | # clamp actions 133 | new_data[:, -1] = torch.clamp(new_data[:, -1], ACTION_LOW, ACTION_HIGH) 134 | new_data = new_data.to(device=d) 135 | # append data to whole dataset 136 | if dataset is None: 137 | dataset = new_data 138 | else: 139 | dataset = torch.cat((dataset, new_data), dim=0) 140 | 141 | # train on the whole dataset (assume small enough we can train on all together) 142 | XU = dataset 143 | dtheta = angular_diff_batch(XU[1:, 0], XU[:-1, 0]) 144 | dtheta_dt = XU[1:, 1] - XU[:-1, 1] 145 | Y = torch.cat((dtheta.view(-1, 1), 
dtheta_dt.view(-1, 1)), dim=1) # x' - x residual 146 | xu = XU[:-1] # make same size as Y 147 | xu = torch.cat((torch.sin(xu[:, 0]).view(-1, 1), torch.cos(xu[:, 0]).view(-1, 1), xu[:, 1:]), dim=1) 148 | 149 | # thaw network 150 | for param in network.parameters(): 151 | param.requires_grad = True 152 | 153 | optimizer = torch.optim.Adam(network.parameters()) 154 | for epoch in range(TRAIN_EPOCH): 155 | optimizer.zero_grad() 156 | # MSE loss 157 | Yhat = network(xu) 158 | loss = (Y - Yhat).norm(2, dim=1) ** 2 159 | loss.mean().backward() 160 | optimizer.step() 161 | logger.debug("ds %d epoch %d loss %f", dataset.shape[0], epoch, loss.mean().item()) 162 | 163 | # freeze network 164 | for param in network.parameters(): 165 | param.requires_grad = False 166 | 167 | # evaluate network against true dynamics 168 | yt = true_dynamics(statev, actionv) 169 | yp = dynamics(statev, actionv) 170 | dtheta = angular_diff_batch(yp[:, 0], yt[:, 0]) 171 | dtheta_dt = yp[:, 1] - yt[:, 1] 172 | E = torch.cat((dtheta.view(-1, 1), dtheta_dt.view(-1, 1)), dim=1).norm(dim=1) 173 | logger.info("Error with true dynamics theta %f theta_dt %f norm %f", dtheta.abs().mean(), 174 | dtheta_dt.abs().mean(), E.mean()) 175 | logger.debug("Start next collection sequence") 176 | 177 | 178 | downward_start = True 179 | env = gym.make(ENV_NAME, render_mode="human") # bypass the default TimeLimit wrapper 180 | env.reset() 181 | if downward_start: 182 | env.state = env.unwrapped.state = [np.pi, 1] 183 | 184 | # bootstrap network with random actions 185 | if BOOT_STRAP_ITER: 186 | logger.info("bootstrapping with random action for %d actions", BOOT_STRAP_ITER) 187 | new_data = np.zeros((BOOT_STRAP_ITER, nx + nu)) 188 | for i in range(BOOT_STRAP_ITER): 189 | pre_action_state = env.state 190 | action = np.random.uniform(low=ACTION_LOW, high=ACTION_HIGH) 191 | env.step([action]) 192 | # env.render() 193 | new_data[i, :nx] = pre_action_state 194 | new_data[i, nx:] = action 195 | 196 | train(new_data) 197 | logger.info("bootstrapping finished") 198 | 199 | env.reset() 200 | if downward_start: 201 | env.state = env.unwrapped.state = [np.pi, 1] 202 | 203 | mppi_gym = mppi.MPPI(dynamics, running_cost, nx, noise_sigma, num_samples=N_SAMPLES, horizon=TIMESTEPS, 204 | lambda_=lambda_, device=d, u_min=torch.tensor(ACTION_LOW, dtype=torch.double, device=d), 205 | u_max=torch.tensor(ACTION_HIGH, dtype=torch.double, device=d)) 206 | total_reward, data = mppi.run_mppi(mppi_gym, env, train) 207 | logger.info("Total reward %f", total_reward) 208 | -------------------------------------------------------------------------------- /tests/smooth_mppi.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | from arm_pytorch_utilities import linalg, handle_batch_input, sort_nicely, cache 6 | import matplotlib.colors 7 | from matplotlib import pyplot as plt 8 | import os 9 | 10 | from pytorch_mppi import mppi 11 | import pytorch_seed 12 | import logging 13 | import collections 14 | 15 | def is_sequence(obj): 16 | if isinstance(obj, str): 17 | return False 18 | return isinstance(obj, collections.abc.Sequence) 19 | 20 | 21 | plt.switch_backend('Qt5Agg') 22 | 23 | logger = logging.getLogger(__file__) 24 | logging.basicConfig(level=logging.INFO, 25 | format='[%(levelname)s %(asctime)s %(pathname)s:%(lineno)d] %(message)s', 26 | datefmt='%m-%d %H:%M:%S') 27 | 28 | 29 | class LinearDeltaDynamics: 30 | def __init__(self, B): 31 | self.B = B 32 | 33 | @handle_batch_input(n=2) 34 | def 
__call__(self, state, action): 35 | nx = state + action @ self.B.transpose(0, 1) 36 | return nx 37 | 38 | 39 | class ScaledLinearDynamics: 40 | def __init__(self, cost, B): 41 | self.B = B 42 | self.cost = cost 43 | 44 | @handle_batch_input(n=2) 45 | def __call__(self, state, action): 46 | nx = state + action @ self.B.transpose(0, 1) / torch.log(self.cost(state) + 1e-8).reshape(-1, 1) * 2 47 | return nx 48 | 49 | 50 | class LQRCost: 51 | def __init__(self, Q, R, goal): 52 | self.Q = Q 53 | self.R = R 54 | self.goal = goal 55 | 56 | @handle_batch_input(n=2) 57 | def __call__(self, state, action=None): 58 | dx = self.goal - state 59 | c = linalg.batch_quadratic_product(dx, self.Q) 60 | if action is not None: 61 | c += linalg.batch_quadratic_product(action, self.R) 62 | return c 63 | 64 | 65 | class HillCost: 66 | def __init__(self, Q, center, cost_at_center=1): 67 | self.Q = Q 68 | self.center = center 69 | self.cost_at_center = cost_at_center 70 | 71 | @handle_batch_input(n=2) 72 | def __call__(self, state, action=None): 73 | dx = self.center - state 74 | d = linalg.batch_quadratic_product(dx, self.Q) 75 | c = self.cost_at_center * torch.exp(-d) 76 | return c 77 | 78 | 79 | class Toy2DEnvironment: 80 | def __init__(self, start=None, goal=None, dtype=torch.double, device="cpu", evaluate_running_cost=True, 81 | visualize=True, 82 | num_trajectories=5, 83 | terminal_scale=100, 84 | r=0.01): 85 | self.d = device 86 | self.dtype = dtype 87 | self.state_ranges = [ 88 | (-5, 5), 89 | (-5, 5) 90 | ] 91 | self.evaluate_running_cost = evaluate_running_cost 92 | self.num_trajectories = num_trajectories 93 | self.visualize = visualize 94 | self.nx = 2 95 | 96 | self.start = start or torch.tensor([-3, -2], device=self.d, dtype=self.dtype) 97 | self.goal = goal or torch.tensor([2, 2], device=self.d, dtype=self.dtype) 98 | self.state = self.start 99 | 100 | self.costs = [] 101 | 102 | eye = torch.eye(2, device=self.d, dtype=self.dtype) 103 | goal_cost = LQRCost(eye, eye * r, self.goal) 104 | self.costs.append(goal_cost) 105 | 106 | # for increasing difficulty, we add some "hills" 107 | self.costs.append(HillCost(torch.tensor([[0.1, 0.05], [0.05, 0.1]], device=self.d, dtype=self.dtype) * 2.5, 108 | torch.tensor([-0.5, -1.], device=self.d, dtype=self.dtype), cost_at_center=200)) 109 | 110 | B = torch.tensor([[0.5, 0], [0, -0.5]], device=self.d, dtype=self.dtype) 111 | self.dynamics = LinearDeltaDynamics(B) 112 | # self.dynamics = ScaledLinearDynamics(self.running_cost, B) 113 | 114 | self.terminal_scale = terminal_scale 115 | 116 | self.trajectory_artist = None 117 | 118 | self.start_visualization() 119 | 120 | def reset(self): 121 | self.state = self.start 122 | self.clear_artist(self.rollout_artist) 123 | self.rollout_artist = None 124 | self.clear_trajectory() 125 | self.trajectory_artist = None 126 | 127 | def step(self, action): 128 | self.state = self.dynamics(self.state, action) 129 | return self.state 130 | 131 | def terminal_cost(self, states, actions): 132 | return self.terminal_scale * self.running_cost(states[..., -1, :]) 133 | 134 | @handle_batch_input(n=2) 135 | def running_cost(self, state, action=None): 136 | c = None 137 | for cost in self.costs: 138 | if c is None: 139 | c = cost(state, action) 140 | else: 141 | c += cost(state, action) 142 | return c 143 | 144 | def start_visualization(self): 145 | if self.visualize: 146 | plt.ion() 147 | plt.show() 148 | 149 | self.fig, self.ax = plt.subplots(figsize=(7, 7)) 150 | self.ax.set_aspect('equal') 151 | self.ax.set(xlim=self.state_ranges[0]) 
152 | self.ax.set(ylim=self.state_ranges[0]) 153 | 154 | self.cmap = "Greys" 155 | # artists for clearing / redrawing 156 | self.start_artist = None 157 | self.goal_artist = None 158 | self.cost_artist = None 159 | self.rollout_artist = None 160 | self.draw_costs() 161 | # self.draw_start() 162 | self.draw_goal() 163 | 164 | @staticmethod 165 | def get_v(i, values, rollouts): 166 | if type(values) != str and is_sequence(values) and len(values) == len(rollouts): 167 | c = values[i] 168 | else: 169 | c = values 170 | return c 171 | 172 | def draw_rollouts(self, rollouts, color="skyblue", label=None, end_state_color="tab:red", linewidth=1.5): 173 | if not self.visualize: 174 | return 175 | self.clear_artist(self.rollout_artist) 176 | artists = [] 177 | for i, rollout in enumerate(rollouts): 178 | # prepend start state 179 | rollout = torch.cat([self.state.view(1, -1), rollout], dim=0) 180 | r = rollout.cpu() 181 | artists += [self.ax.scatter(r[0, 0], r[0, 1], color="tab:blue")] 182 | # if color is a string treat it as a single color, otherwise treat it as a list of colors 183 | artists += self.ax.plot(r[:, 0], r[:, 1], 184 | color=self.get_v(i, color, rollouts), 185 | label=self.get_v(i, label, rollouts), 186 | linewidth=self.get_v(i, linewidth, rollouts)) 187 | artists += [self.ax.scatter(r[-1, 0], r[-1, 1], color=self.get_v(i, end_state_color, rollouts))] 188 | self.rollout_artist = artists 189 | if label is not None: 190 | self.ax.legend(loc = "upper right") 191 | plt.pause(0.001) 192 | 193 | def draw_trajectory_step(self, prev_state, cur_state, color="tab:blue"): 194 | if not self.visualize: 195 | return 196 | if self.trajectory_artist is None: 197 | self.trajectory_artist = [] 198 | artists = self.trajectory_artist 199 | artists += self.ax.plot([prev_state[0].cpu(), cur_state[0].cpu()], 200 | [prev_state[1].cpu(), cur_state[1].cpu()], color=color) 201 | plt.draw() 202 | plt.pause(0.001) 203 | 204 | def clear_trajectory(self): 205 | self.clear_artist(self.trajectory_artist) 206 | 207 | def draw_costs(self, resolution=0.05, value_padding=0): 208 | if not self.visualize: 209 | return 210 | coords = [torch.arange(low, high + resolution, resolution, dtype=self.dtype, device=self.d) for low, high in 211 | self.state_ranges] 212 | pts = torch.cartesian_prod(*coords) 213 | val = self.running_cost(pts) 214 | 215 | norm = matplotlib.colors.Normalize(vmin=val.min().cpu() - value_padding, vmax=val.max().cpu()) 216 | 217 | x = coords[0].cpu() 218 | z = coords[1].cpu() 219 | v = val.reshape(len(x), len(z)).transpose(0, 1).cpu() 220 | 221 | self.clear_artist(self.cost_artist) 222 | a = [] 223 | a.append(self.ax.contourf(x, z, v, levels=[2, 4, 8, 16, 24, 32, 40, 50, 60, 80, 100, 150, 200, 250], norm=norm, 224 | cmap=self.cmap)) 225 | # a.append(self.ax.contourf(x, z, v, norm=norm, 226 | # cmap=self.cmap)) 227 | # reduce opacity 228 | a.append(self.ax.contour(x, z, v, levels=a[0].levels, colors='k', linestyles='dashed')) 229 | a.append(self.ax.clabel(a[1], a[1].levels, inline=True, fontsize=13)) 230 | a[1].set_alpha(0.3) 231 | self.cost_artist = a 232 | 233 | plt.draw() 234 | plt.pause(0.0005) 235 | 236 | @staticmethod 237 | def clear_artist(artist): 238 | if artist is not None: 239 | for a in artist: 240 | a.remove() 241 | 242 | def draw_start(self): 243 | if not self.visualize: 244 | return 245 | self.clear_artist(self.start_artist) 246 | self.start_artist = self.draw_state(self.start, "tab:blue", label='start') 247 | 248 | def draw_goal(self): 249 | # when combined with other costs it's no longer the 
single goal so no need for label 250 | return 251 | if not self.visualize: 252 | return 253 | self.clear_artist(self.goal_artist) 254 | # when combined with other costs it's no longer the single goal so no need for label 255 | self.goal_artist = self.draw_state(self.goal, "tab:green") # , label='goal') 256 | 257 | def draw_state(self, state, color, label=None, ox=-0.3, oy=0.3): 258 | artists = [self.ax.scatter(state[0].cpu(), state[1].cpu(), color=color)] 259 | if label is not None: 260 | artists.append(self.ax.text(state[0].cpu() + ox, state[1].cpu() + oy, label, color=color)) 261 | plt.pause(0.0001) 262 | return artists 263 | 264 | 265 | def make_gif(imgs_dir, gif_name): 266 | import imageio 267 | images = [] 268 | # human sort 269 | names = os.listdir(imgs_dir) 270 | sort_nicely(names) 271 | for filename in names: 272 | if filename.endswith(".png"): 273 | images.append(imageio.v2.imread(os.path.join(imgs_dir, filename))) 274 | imageio.mimsave(gif_name, images, duration=0.1) 275 | 276 | 277 | def make_gif_ffmpeg(imgs_dir, gif_name, fps=6): 278 | import subprocess 279 | # first generate palette and then use it to generate gif 280 | palette_path = os.path.join("images", "palette.png") 281 | cmd = ["ffmpeg", "-y", "-i", os.path.join(imgs_dir, "%d.png"), "-vf", "palettegen", palette_path] 282 | subprocess.run(cmd) 283 | cmd = ["ffmpeg", "-y", "-framerate", str(fps), "-i", os.path.join(imgs_dir, "%d.png"), "-i", palette_path, 284 | "-lavfi", "paletteuse", gif_name] 285 | subprocess.run(cmd) 286 | 287 | 288 | def do_control(env, ctrl, ch, seeds=(0,), run_steps=20, num_refinement_steps=1, save_img=True, plot_single=False, 289 | evaluate_running_cost=True, plot_trajectory_candidates=False): 290 | if save_img: 291 | os.makedirs("images", exist_ok=True) 292 | os.makedirs("images/runs", exist_ok=True) 293 | os.makedirs("images/gif", exist_ok=True) 294 | 295 | for seed in seeds: 296 | pytorch_seed.seed(seed) 297 | key = f"{ctrl.__class__.__name__}" 298 | 299 | # use the same nominal trajectory to start with for all the evaluations for fairness 300 | # parameters for our sample evaluation function - lots of choices for the evaluation function 301 | rollout_costs = [] 302 | actual_costs = [] 303 | controls = [] 304 | state = env.state 305 | # we sample multiple trajectories for the same start to goal problem, but in your case you should consider 306 | # evaluating over a diverse dataset of trajectories 307 | for i in range(run_steps): 308 | # mppi.U = nominal_trajectory.clone() 309 | # the nominal trajectory at the start will be different if the horizon's changed 310 | # mppi.change_horizon(mppi.T) 311 | # usually MPPI will have its nominal trajectory warm-started from the previous iteration 312 | # for a fair test of tuning we will reset its nominal trajectory to the same random one each time 313 | # we manually warm it by refining it for some steps 314 | u = None 315 | for k in range(num_refinement_steps): 316 | last_refinement = k == num_refinement_steps - 1 317 | u = ctrl.command(state, shift_nominal_trajectory=last_refinement) 318 | 319 | rollout = ctrl.get_rollouts(state) 320 | 321 | rollout_cost = 0 322 | this_cost = env.running_cost(state) 323 | rollout = rollout[0] 324 | # here we evaluate on the rollout MPPI cost of the resulting trajectories 325 | # alternative costs for tuning the parameters are possible, such as just considering terminal cost 326 | if evaluate_running_cost: 327 | for t in range(len(rollout) - 1): 328 | rollout_cost = rollout_cost + env.running_cost(rollout[t], ctrl.U[t]) 
329 | rollout_cost = rollout_cost + env.terminal_cost(rollout, ctrl.U) 330 | 331 | prev_state = copy.deepcopy(state) 332 | state = env.step(u) 333 | 334 | if plot_trajectory_candidates: 335 | from matplotlib import cm 336 | # only plot some candidates rather than all of them 337 | num_candidates = min(10, ctrl.K) 338 | # for the combined trajectory 339 | color = [] 340 | end_color = [] 341 | rollouts = [] 342 | linewidth = [] 343 | labels = [] 344 | # for all the candidates 345 | # create matplotlib color map based on cost 346 | cost = ctrl.cost_total.cpu() 347 | best_idx = torch.argsort(cost) 348 | 349 | norm = matplotlib.colors.Normalize(vmin=cost.min(), vmax=cost[best_idx][:num_candidates*5].max()) 350 | m = cm.ScalarMappable(norm=norm, cmap=cm.jet) 351 | 352 | traj_color = m.to_rgba(cost) 353 | # lower alpha 354 | traj_color[:, 3] = 0.2 355 | 356 | # get rollouts per sampled action trajectory 357 | for j in range(num_candidates): 358 | idx = best_idx[j] 359 | this_U = ctrl.actions[0, idx] 360 | rollouts.append(ctrl.get_rollouts(state, U=this_U)[0]) 361 | color.append(traj_color[idx]) 362 | end_color.append([1, 0, 0, 0.2]) 363 | linewidth.append(1) 364 | labels.append(None) 365 | color.append("skyblue") 366 | end_color.append("tab:red") 367 | rollouts.append(rollout) 368 | linewidth.append(2) 369 | labels.append(key) 370 | env.draw_rollouts(rollouts, color=color, end_state_color=end_color, linewidth=linewidth, label=labels) 371 | else: 372 | # just draw the single state rollout 373 | env.draw_rollouts([rollout], label=key) 374 | 375 | env.draw_trajectory_step(prev_state, state) 376 | 377 | if save_img: 378 | plt.savefig(f"images/runs/{i}.png") 379 | 380 | print(f"step {i} state {state} current cost {this_cost} rollout cost {rollout_cost}") 381 | actual_costs.append(this_cost.cpu()) 382 | rollout_costs.append(rollout_cost.cpu()) 383 | controls.append(u.cpu()) 384 | 385 | controls = torch.stack(controls) 386 | # consider total difference 387 | control_diff = torch.diff(controls, dim=0) 388 | print(f"total accumulated cost: {sum(actual_costs)}") 389 | print(f"total accumulated rollout cost: {sum(rollout_costs)}") 390 | env.reset() 391 | ctrl.reset() 392 | 393 | secondary_key = (seed, ctrl.get_params()) 394 | # make_gif("images/runs", f"images/gif/{key}_{seed}.gif") 395 | make_gif_ffmpeg("images/runs", f"images/gif/{key}_{seed}.gif", fps=10) 396 | if key not in ch: 397 | ch[key] = {} 398 | ch[key][secondary_key] = { 399 | "actual_costs": actual_costs, 400 | "rollout_costs": rollout_costs, 401 | "controls": controls, 402 | "control_diff": control_diff, 403 | } 404 | ch.save() 405 | 406 | if plot_single: 407 | # plot the costs with the step 408 | fig, ax = plt.subplots(nrows=3, sharex=True, figsize=(8, 14)) 409 | # xlim 0 410 | ax[0].set_xlim(0, run_steps - 1) 411 | # tick on x for every step discretely 412 | ax[0].plot(actual_costs) 413 | ax[0].set_title(f"actual costs total: {sum(actual_costs)}") 414 | ax[0].set_ylim(0, max(actual_costs) * 1.1) 415 | ax[1].plot(rollout_costs) 416 | ax[1].set_title(f"rollout costs total: {sum(rollout_costs)}") 417 | # set the y axis to be log scale 418 | ax[1].set_yscale('log') 419 | 420 | ax[2].plot(controls[:, 0], label='u0') 421 | ax[2].plot(controls[:, 1], label='u1') 422 | ax[2].legend() 423 | ax[2].set_title(f"control inputs total diff: {control_diff.abs().sum()}") 424 | ax[2].set_xticks(range(run_steps)) 425 | plt.tight_layout() 426 | plt.pause(0.001) 427 | input("Press Enter to close the window and exit...") 428 | 429 | 430 | def plot_result(ch): 431 
431 |     num_steps = 20
432 |
433 |     def simplify_name(name):
434 |         # remove shared parameters
435 |         return name.replace("K=500 T=20 M=1 ", "").replace("noise_mu=[0. 0.] noise_sigma=[[1. 0.], [0. 1.]]",
436 |                                                             "").replace("_lambda=1 ", " ")
437 |
438 |     methods = {}
439 |     for key, values in ch.items():
440 |         for secondary_key, data in values.items():
441 |             actual_costs = torch.tensor(data["actual_costs"])
442 |             rollout_costs = torch.tensor(data["rollout_costs"])
443 |             controls = data["controls"]
444 |             control_diff = data["control_diff"]
445 |
446 |             method_name = f"{key}_{secondary_key[1]}"
447 |             if method_name not in methods:
448 |                 methods[method_name] = {
449 |                     "actual_costs": [actual_costs],
450 |                     "rollout_costs": [rollout_costs],
451 |                     "controls": [controls],
452 |                     "control_diff": [control_diff],
453 |                 }
454 |             else:
455 |                 m = methods[method_name]
456 |                 m["actual_costs"].append(actual_costs)
457 |                 m["rollout_costs"].append(rollout_costs)
458 |                 m["controls"].append(controls)
459 |                 m["control_diff"].append(control_diff)
460 |
461 |     method_names = "\n".join(methods.keys())
462 |     print(f"all method keys\n{method_names}")
463 |
464 |     allowed_names = [
465 |         # "MPPI_K=500 T=20 M=1 lambda=1 noise_mu=[0. 0.] noise_sigma=[[1. 0.], [0. 1.]]",
466 |         # "SMPPI_K=500 T=20 M=1 lambda=1 noise_mu=[0. 0.] noise_sigma=[[1. 0.], [0. 1.]] w=5 t=1.0",
467 |         # "SMPPI_K=500 T=20 M=1 lambda=10 noise_mu=[0. 0.] noise_sigma=[[1. 0.], [0. 1.]] w=10 t=1.0",
468 |         # "KMPPI_K=500 T=20 M=1 lambda=1 noise_mu=[0. 0.] noise_sigma=[[1. 0.], [0. 1.]] num_support_pts=5 kernel=rbf4theta",
469 |         # "KMPPI_K=500 T=20 M=1 lambda=1 noise_mu=[0. 0.] noise_sigma=[[1. 0.], [0. 1.]] num_support_pts=5 kernel=RBFKernel(sigma=1.5)",
470 |         "KMPPI_K=500 T=20 M=1 lambda=1 noise_mu=[0. 0.] noise_sigma=[[1. 0.], [0. 1.]] num_support_pts=5 kernel=RBFKernel(sigma=2)",
471 |         # "KMPPI_K=500 T=20 M=1 lambda=10 noise_mu=[0. 0.] noise_sigma=[[1. 0.], [0. 1.]] num_support_pts=5 kernel=RBFKernel(sigma=2)",
472 |         # "KMPPI_K=500 T=20 M=1 lambda=1 noise_mu=[0. 0.] noise_sigma=[[1. 0.], [0. 1.]] num_support_pts=5 kernel=RBFKernel(sigma=3)",
473 |     ]
474 |
475 |     fig, ax = plt.subplots(nrows=2, sharex=True, figsize=(8, 14))
476 |     ax[0].set_xlim(0, num_steps - 1)
477 |     # only set the min of y to be 0
478 |     ax[0].set_title("trajectory cost")
479 |     ax[1].set_title("rollout cost")
480 |     ax[1].set_yscale('log')
481 |     ax[1].set_xticks(range(num_steps))
482 |     ax[1].set_xlabel("step")
483 |     f, a = plt.subplots()
484 |     a.set_title("control inputs total diff")
485 |     # tick on x for every step discretely
486 |     for method in allowed_names:
487 |         data = methods[method]
488 |         method = simplify_name(method)
489 |         actual_costs = torch.stack(data["actual_costs"])
490 |         rollout_costs = torch.stack(data["rollout_costs"])
491 |         controls = data["controls"]
492 |         control_diff = data["control_diff"]
493 |         for i, v in enumerate([actual_costs, rollout_costs]):
494 |             # plot the median along dim 0 and the 25th and 75th percentile
495 |             lower = torch.quantile(v, .25, dim=0)
496 |             upper = torch.quantile(v, .75, dim=0)
497 |             ax[i].fill_between(range(num_steps), lower, upper, alpha=0.2)
498 |             ax[i].plot(v.median(dim=0)[0], label=method)
499 |
500 |         # compute total control diff
501 |         control_diff = torch.stack(control_diff)
502 |         total_diff = control_diff.abs().sum(dim=(1, 2))
503 |         c1 = actual_costs.sum(dim=1)
504 |         c2 = rollout_costs.sum(dim=1)
505 |         print(
506 |             f"method {method}\ntrajectory cost {c1.mean():.1f} ({c1.std():.1f})\nrollout cost {c2.mean():.1f} ({c2.std():.1f})\ncontrol diff {total_diff.mean():.1f} ({total_diff.std():.1f})")
507 |         # plot frequency of total control diff
508 |         # kernel density estimate
509 |         from scipy.stats import gaussian_kde
510 |         density = gaussian_kde(total_diff)
511 |         density.covariance_factor = lambda: .25
512 |         density._compute_covariance()
513 |         xs = torch.linspace(0, total_diff.max() * 1.2, 50)
514 |         a.plot(xs, density(xs), label=method)
515 |         # plot histogram of total control diff
516 |
517 |     ax[0].set_ylim(0, None)
518 |     ax[1].legend()
519 |     a.set_ylim(0, None)
520 |     a.set_xlim(0, None)
521 |     a.legend()
522 |     plt.tight_layout()
523 |     plt.show()
524 |     input("Press Enter to close the window and exit...")
525 |
526 |
527 | def main(plot_only=False):
528 |     device = "cuda" if torch.cuda.is_available() else "cpu"
529 |     dtype = torch.double
530 |     pytorch_seed.seed(2)
531 |     ch = cache.LocalCache("mppi_res.pkl")
532 |
533 |     if plot_only:
534 |         plot_result(ch)
535 |         return
536 |
537 |     # create toy environment to do control on (default start and goal)
538 |     env = Toy2DEnvironment(visualize=True, terminal_scale=10, device=device)
539 |     shared_params = {
540 |         "num_samples": 500,
541 |         "horizon": 20,
542 |         "noise_mu": torch.zeros(2, dtype=dtype, device=device),
543 |         "noise_sigma": torch.diag(torch.tensor([1., 1.], dtype=dtype, device=device)),
544 |         "u_max": torch.tensor([1., 1.], dtype=dtype, device=device),
545 |         "terminal_state_cost": env.terminal_cost,
546 |         "lambda_": 1,
547 |         "device": device,
548 |     }
549 |     # create the MPPI variants with shared initial parameters
550 |     mmppi = mppi.MPPI(env.dynamics, env.running_cost, 2,
551 |                       **shared_params)
552 |     smppi = mppi.SMPPI(env.dynamics, env.running_cost, 2,
553 |                        **shared_params,
554 |                        w_action_seq_cost=10,
555 |                        action_max=torch.tensor([1., 1.], dtype=dtype, device=device))
556 |     kmppi = mppi.KMPPI(env.dynamics, env.running_cost, 2,
557 |                        **shared_params,
558 |                        kernel=mppi.RBFKernel(sigma=2),
559 |                        num_support_pts=5,
560 |                        )
561 |
562 |     for ctrl in [mmppi, kmppi, smppi]:
563 |         do_control(env, ctrl, ch, seeds=range(5), run_steps=20, num_refinement_steps=1, save_img=True,
564 |                    plot_single=False, evaluate_running_cost=False, plot_trajectory_candidates=True)
565 |
566 |
567 | if __name__ == "__main__":
568 |     main()
569 |
--------------------------------------------------------------------------------
/tests/test_batch_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_mppi.mppi import handle_batch_input
3 |
4 |
5 | @handle_batch_input(n=2)
6 | def add_2d(a, b):
7 |     assert a.ndim == 2
8 |     assert b.ndim == 2
9 |     return a + b
10 |
11 |
12 | @handle_batch_input(n=3)
13 | def add_3d(a, b):
14 |     assert a.ndim == 3
15 |     assert b.ndim == 3
16 |     return a + b
17 |
18 |
19 | def test_batch_wrapper_2d():
20 |     a_2d = torch.tensor([[0.1, 0.2, 0.3]])
21 |     b_2d = torch.tensor([[0.5, -0.2, 0.3]])
22 |     a_3d = torch.tile(a_2d, [1, 1, 1])
23 |     b_3d = torch.tile(b_2d, [1, 1, 1])
24 |     a_4d = torch.tile(a_3d, [2, 1, 1, 1])
25 |     b_4d = torch.tile(b_3d, [2, 1, 1, 1])
26 |     expected_sum_2d = torch.tensor([[0.6, 0.0, 0.6]])
27 |     expected_sum_3d = torch.tensor([[[0.6, 0.0, 0.6]]])
28 |     expected_sum_4d = torch.tensor([[[[0.6, 0.0, 0.6]]], [[[0.6, 0.0, 0.6]]]])
29 |     sum_2d = add_2d(a_2d, b_2d)
30 |     sum_3d = add_2d(a_3d, b_3d)
31 |     sum_4d = add_2d(a_4d, b_4d)
32 |     assert torch.allclose(sum_2d, expected_sum_2d)
33 |     assert torch.allclose(sum_3d, expected_sum_3d)
34 |     assert torch.allclose(sum_4d, expected_sum_4d)
35 |
36 |
37 | def test_batch_wrapper_3d():
38 |     a_3d = torch.tensor([[[0.1, 0.2, 0.3]]])
39 |     b_3d = torch.tensor([[[0.5, -0.2, 0.3]]])
40 |     a_4d = torch.tile(a_3d, [2, 1, 1, 1])
41 |     b_4d = torch.tile(b_3d, [2, 1, 1, 1])
42 |     expected_sum_3d = torch.tensor([[[0.6, 0.0, 0.6]]])
43 |     expected_sum_4d = torch.tensor([[[[0.6, 0.0, 0.6]]], [[[0.6, 0.0, 0.6]]]])
44 |     sum_3d = add_3d(a_3d, b_3d)
45 |     sum_4d = add_3d(a_4d, b_4d)
46 |     assert torch.allclose(sum_3d, expected_sum_3d)
47 |     assert torch.allclose(sum_4d, expected_sum_4d)
48 |
--------------------------------------------------------------------------------