├── .gitignore ├── LICENSE.txt ├── README.md ├── pyproject.toml ├── src └── pytorch_icem │ ├── __init__.py │ └── icem.py └── tests └── pendulum_approximate_continuous.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | pytorch_icem.egg-info 3 | __pycache__ 4 | dist 5 | *.png -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 University of Michigan ARM Lab 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
20 | 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch iCEM Implementation 2 | This repository implements the improved Cross Entropy Method (iCEM) 3 | with approximate dynamics in pytorch, from [this paper](https://martius-lab.github.io/iCEM/). 4 | 5 | iCEM is a sampling based MPC method that does not require the true 6 | dynamics; planning can be done with approximate dynamics (such as with a neural network). 7 | Compared to plain CEM, it samples temporally correlated (colored) action noise 8 | and reuses the best samples (elites) across iterations and control steps. 9 | 10 | Thus it can be used in place of other trajectory optimization methods 11 | such as the Cross Entropy Method (CEM), MPPI, or random shooting. 12 | 13 | 14 | # Related projects 15 | - [pytorch CEM](https://github.com/LemonPi/pytorch_cem) - alternative sampling based MPC 16 | - [pytorch MPPI](https://github.com/UM-ARM-Lab/pytorch_mppi) - alternative sampling based MPC 17 | - [iCEM](https://github.com/martius-lab/iCEM) - original paper's numpy implementation and experiments code 18 | 19 | 20 | # Installation 21 | ```shell 22 | pip install pytorch-icem 23 | ``` 24 | for running tests, install with 25 | ```shell 26 | pip install pytorch-icem[test] 27 | ``` 28 | for development, clone the repository then install in editable mode 29 | ```shell 30 | pip install -e . 31 | ``` 32 | 33 | # Usage 34 | See `tests/pendulum_approximate_continuous.py` for usage with a neural network approximating 35 | the pendulum dynamics.
Basic use case is shown below 36 | 37 | ```python 38 | from pytorch_icem import iCEM 39 | 40 | # create controller with chosen parameters 41 | ctrl = iCEM(dynamics, trajectory_cost, nx, nu, sigma=sigma, 42 | warmup_iters=10, online_iters=10, 43 | num_samples=N_SAMPLES, num_elites=10, horizon=TIMESTEPS, device=d, ) 44 | 45 | # assuming you have a gym-like env 46 | obs, _ = env.reset() 47 | for i in range(100): 48 | action = ctrl.command(obs) 49 | obs, reward, done, _, _ = env.step(action.cpu().numpy()) 50 | ``` 51 | 52 | # Requirements 53 | - pytorch (>= 1.0) 54 | - `next state <- dynamics(state, action)` function (doesn't have to be true dynamics) 55 | - `state` is `K x nx`, `action` is `K x nu` 56 | - `trajectory cost <- cost(state, action)` function for the whole state action trajectory, T is the horizon 57 | - `cost` is `K` (one scalar per trajectory), `state` is `K x T x nx`, `action` is `K x T x nu` 58 | 59 | # Features 60 | - Parallel/batch pytorch implementation for accelerated sampling 61 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "pytorch_icem" 3 | version = "0.1.0" 4 | description = "Improved Cross Entropy Method (iCEM) implemented in pytorch" 5 | readme = "README.md" # Optional 6 | 7 | # Specify which Python versions you support. In contrast to the 8 | # 'Programming Language' classifiers above, 'pip install' will check this 9 | # and refuse to install the project if the version does not match.
See 10 | # https://packaging.python.org/guides/distributing-packages-using-setuptools/#python-requires 11 | requires-python = ">=3.6" 12 | 13 | # This is either text indicating the license for the distribution, or a file 14 | # that contains the license 15 | # https://packaging.python.org/en/latest/specifications/core-metadata/#license 16 | license = { file = "LICENSE.txt" } 17 | 18 | # This field adds keywords for your project which will appear on the 19 | # project page. What does your project relate to? 20 | # 21 | # Note that this is a list of additional keywords, separated 22 | # by commas, to be used to assist searching for the distribution in a 23 | # larger catalog. 24 | keywords = ["icem", "mpc", "pytorch", "control", "robotics"] # Optional 25 | authors = [ 26 | { name = "Sheng Zhong", email = "zhsh@umich.edu" }, 27 | { name = "Thomas Power", email = "tpower@umich.edu" }, 28 | ] 29 | maintainers = [ 30 | { name = "Sheng Zhong", email = "zhsh@umich.edu" }, 31 | { name = "Thomas Power", email = "tpower@umich.edu" }, 32 | ] 33 | 34 | # Classifiers help users find your project by categorizing it. 35 | # 36 | # For a list of valid classifiers, see https://pypi.org/classifiers/ 37 | classifiers = [# Optional 38 | "Development Status :: 4 - Beta", 39 | # Indicate who your project is intended for 40 | "Intended Audience :: Developers", 41 | # Pick your license as you wish 42 | "License :: OSI Approved :: MIT License", 43 | # Specify the Python versions you support here. In particular, ensure 44 | # that you indicate you support Python 3. These classifiers are *not* 45 | # checked by "pip install". See instead "python_requires" below. 46 | "Programming Language :: Python :: 3", 47 | "Programming Language :: Python :: 3 :: Only", 48 | ] 49 | 50 | # This field lists other packages that your project depends on to run. 51 | # Any package you put here will be installed by pip when your project is 52 | # installed, so they must be valid existing projects. 
53 | # 54 | # For an analysis of this field vs pip's requirements files see: 55 | # https://packaging.python.org/discussions/install-requires-vs-requirements/ 56 | dependencies = [# Optional 57 | 'torch', 58 | 'numpy', 59 | 'colorednoise', 60 | 'arm-pytorch-utilities>=0.4', 61 | ] 62 | 63 | # List additional groups of dependencies here (e.g. development 64 | # dependencies). Users will be able to install these using the "extras" 65 | # syntax, for example: 66 | # 67 | # $ pip install sampleproject[dev] 68 | # 69 | # Similar to `dependencies` above, these must be valid existing 70 | # projects. 71 | [project.optional-dependencies] # Optional 72 | test = [ 73 | "pytest", 74 | 'gym', 75 | 'pygame', 76 | 'pyglet==1.5.27', 77 | 'window-recorder', 78 | ] 79 | 80 | # List URLs that are relevant to your project 81 | # 82 | # This field corresponds to the "Project-URL" and "Home-Page" metadata fields: 83 | # https://packaging.python.org/specifications/core-metadata/#project-url-multiple-use 84 | # https://packaging.python.org/specifications/core-metadata/#home-page-optional 85 | # 86 | # Examples listed include a pattern for specifying where the package tracks 87 | # issues, where the source is hosted, where to say thanks to the package 88 | # maintainers, and where to support the project financially. The key is 89 | # what's used to render the link text on PyPI. 90 | [project.urls] # Optional 91 | "Homepage" = "https://github.com/UM-ARM-Lab/pytorch_icem" 92 | "Bug Reports" = "https://github.com/UM-ARM-Lab/pytorch_icem/issues" 93 | "Source" = "https://github.com/UM-ARM-Lab/pytorch_icem" 94 | 95 | # The following would provide a command line executable called `sample` 96 | # which executes the function `main` from this package when invoked. 97 | #[project.scripts] # Optional 98 | #sample = "sample:main" 99 | 100 | # This is configuration specific to the `setuptools` build backend. 101 | # If you are using a different build backend, you will need to change this. 
import logging

import colorednoise
import torch
from arm_pytorch_utilities import handle_batch_input

logger = logging.getLogger(__name__)


def accumulate_running_cost(running_cost, terminal_state_weight=10.0):
    """Wrap a per-step running cost into a per-trajectory cost.

    :param running_cost: callable ``running_cost(x, u)``. Its result is indexed with ``[:, -1]`` and
        summed over the last dimension, so it must return one cost per sample per time step.
    :param terminal_state_weight: total weight given to the cost of the last time step
    :return: callable ``(x, u) -> cost`` that sums the running cost over time with the last step
        weighted by ``terminal_state_weight``
    """

    def _accumulate_running_cost(x, u):
        step_costs = running_cost(x, u)
        terminal_cost = step_costs[:, -1]
        # the sum already counts the terminal step once, so only add the remaining (weight - 1)
        return torch.sum(step_costs, dim=-1) + terminal_cost * (terminal_state_weight - 1)

    return _accumulate_running_cost


class iCEM:
    """Sampling based controller implementing the improved Cross Entropy Method (iCEM).

    Each call to :meth:`command` samples ``num_samples`` action sequences around the current mean
    (colored noise when ``noise_beta > 0``), rolls them out through ``dynamics``, scores them with
    ``trajectory_cost``, refits mean/std to the ``num_elites`` cheapest sequences and carries a
    fraction of those elites over to the next iteration.
    """

    # number of best sampled trajectories returned for visualization by ``command``
    _NUM_VISUALIZED_TRAJECTORIES = 64

    def __init__(self, dynamics, trajectory_cost, nx, nu, sigma=None, num_samples=100, num_elites=10, horizon=15,
                 elites_keep_fraction=0.5,
                 alpha=0.05,
                 noise_beta=3,
                 warmup_iters=100, online_iters=100,
                 includes_x0=False,
                 fixed_H=True,
                 device="cpu"):
        """
        :param dynamics: callable ``dynamics(state, action) -> next state``, called with batched
            ``N x nx`` states and ``N x nu`` actions
        :param trajectory_cost: callable ``trajectory_cost(states, actions) -> cost`` called with
            ``N x H x nx`` states and ``N x H x nu`` actions; should return one cost per trajectory
        :param nx: state dimension
        :param nu: action dimension
        :param sigma: initial sampling standard deviation; ``None`` (ones), a python scalar, or
            anything convertible to a tensor that is a scalar or a vector of length ``nu``
        :param num_samples: number of action sequences evaluated per iteration
        :param num_elites: number of lowest cost sequences used to refit the distribution
        :param horizon: planning horizon H
        :param elites_keep_fraction: fraction of the elites reused in the next iteration
        :param alpha: momentum; weight of the *old* mean/std when refitting
        :param noise_beta: colored noise exponent; ``<= 0`` falls back to white Gaussian noise
        :param warmup_iters: optimization iterations on the first ``command`` call after a reset
        :param online_iters: optimization iterations on every later ``command`` call
        :param includes_x0: whether rolled out state trajectories start at the current state
            (``x_0 .. x_{H-1}``) rather than at the first predicted state (``x_1 .. x_H``)
        :param fixed_H: if False the horizon shrinks by one on every ``command`` after the first
        :param device: torch device on which sampling is done
        :raises ValueError: if ``sigma`` is neither a scalar nor a vector of length ``nu``
        """
        self.dynamics = dynamics
        self.trajectory_cost = trajectory_cost

        self.nx = nx
        self.nu = nu
        self.H = horizon
        # remembered so reset() can restore a horizon that was shrunk with fixed_H=False
        self._initial_H = horizon
        self.fixed_H = fixed_H
        self.N = num_samples
        self.device = device

        self.sigma = self._validated_sigma(sigma)
        # every sampled tensor follows sigma's dtype
        self.dtype = self.sigma.dtype

        self.warmup_iters = warmup_iters
        self.online_iters = online_iters
        self.includes_x0 = includes_x0
        self.noise_beta = noise_beta
        self.K = num_elites
        self.alpha = alpha
        self.keep_fraction = elites_keep_fraction

        self.std = self.sigma.clone()

        # initialise mean and std of actions
        self.mean = torch.zeros(self.H, self.nu, device=self.device, dtype=self.dtype)
        self.kept_elites = None
        self.warmed_up = False

    def _validated_sigma(self, sigma):
        """Convert ``sigma`` into a floating point tensor of shape ``(nu,)`` on ``self.device``."""
        if sigma is None:
            return torch.ones(self.nu, device=self.device).float()
        if isinstance(sigma, (int, float)):
            return torch.ones(self.nu, device=self.device).float() * float(sigma)

        sigma = torch.as_tensor(sigma, device=self.device)
        if not sigma.is_floating_point():
            # integer sigma would make self.dtype an integer type, which cannot be sampled
            sigma = sigma.float()
        if sigma.dim() == 0:
            sigma = sigma.expand(self.nu)
        # compare the actual length against nu (not the number of dimensions)
        if tuple(sigma.shape) != (self.nu,):
            raise ValueError(f"Sigma must be either a scalar or a vector of length nu {self.nu}")
        # own copy so later changes to the caller's tensor do not leak into the controller
        return sigma.detach().clone()

    def reset(self):
        """Forget the optimized distribution and kept elites; the next command warms up again."""
        self.warmed_up = False
        if not self.fixed_H:
            # undo the horizon shrinking done by command()
            self.H = self._initial_H
        self.mean = torch.zeros(self.H, self.nu, device=self.device, dtype=self.dtype)
        self.std = self.sigma.clone()
        self.kept_elites = None

    def sample_action_sequences(self, state, N):
        """Sample ``N`` action sequences around the current mean.

        :param state: current state (unused, kept for interface compatibility)
        :param N: number of sequences to sample
        :return: ``N x H x nu`` tensor of action sequences
        """
        # colored noise
        if self.noise_beta > 0:
            # Important improvement
            # self.mean has shape h,d: we need to swap d and h because temporal correlations are in last axis)
            # noinspection PyUnresolvedReferences
            samples = colorednoise.powerlaw_psd_gaussian(self.noise_beta, size=(N, self.nu,
                                                                                self.H)).transpose(
                [0, 2, 1])
            samples = torch.from_numpy(samples).to(device=self.device, dtype=self.dtype)
        else:
            samples = torch.randn(N, self.H, self.nu, device=self.device, dtype=self.dtype)

        U = self.mean + self.std * samples
        return U

    def update_distribution(self, elites):
        """Refit mean and std to the elites, blended with the previous values by ``alpha``.

        :param elites: ``K x H x nu`` tensor of the K best control sequences by cost
        """
        # fit around mean of elites
        new_mean = elites.mean(dim=0)
        if elites.shape[0] > 1:
            new_std = elites.std(dim=0)
        else:
            # the sample std of a single elite is NaN and would poison every later sample
            new_std = torch.zeros_like(new_mean)

        self.mean = (1 - self.alpha) * new_mean + self.alpha * self.mean  # [h,d]
        self.std = (1 - self.alpha) * new_std + self.alpha * self.std

    @handle_batch_input(n=3)
    def _cost(self, x, u):
        return self.trajectory_cost(x, u)

    @handle_batch_input(n=2)
    def _dynamics(self, x, u):
        return self.dynamics(x, u)

    def _rollout_dynamics(self, x0, u):
        """Roll every action sequence in ``u`` out from ``x0``.

        :param x0: current state with ``nx`` elements
        :param u: ``N x H x nu`` action sequences
        :return: ``N x H x nx`` states; ``x_0 .. x_{H-1}`` if ``includes_x0`` else ``x_1 .. x_H``
        :raises ValueError: if ``u`` does not match the controller's horizon or action dimension
        """
        N, H, du = u.shape
        if H != self.H or du != self.nu:
            raise ValueError(f"Expected action sequences of shape N x {self.H} x {self.nu}, got {tuple(u.shape)}")

        x = [x0.reshape(1, self.nx).repeat(N, 1)]
        for t in range(self.H):
            x.append(self._dynamics(x[-1], u[:, t]))

        if self.includes_x0:
            return torch.stack(x[:-1], dim=1)
        return torch.stack(x[1:], dim=1)

    def _shrink_horizon(self):
        """Reduce the horizon by one step and trim the stored distribution/elites to match.

        The trailing step is dropped; after ``shift()`` that is the filler step appended by the
        shift, so the still relevant part of the plan is preserved.
        """
        if self.H <= 1:
            raise RuntimeError("Planning horizon is exhausted; call reset() before commanding again")
        self.H -= 1
        self.mean = self.mean[:self.H]
        if self.std.dim() > 1:
            # std is per-step (H x nu) unless it was just re-initialised from sigma
            self.std = self.std[:self.H]
        if self.kept_elites is not None:
            self.kept_elites = self.kept_elites[:, :self.H]

    def command(self, state, shift_nominal_trajectory=True, return_full_trajectories=False, **kwargs):
        """Optimize the action distribution from ``state`` and return the best found action.

        :param state: current state (tensor or anything ``torch.tensor`` accepts) with ``nx`` elements
        :param shift_nominal_trajectory: whether to shift mean and kept elites one step forward in
            time afterwards, ready for the next control step
        :param return_full_trajectories: return trajectories instead of just the first action
        :param kwargs: ignored
        :return: the first action (``nu``) of the lowest cost sequence, or if
            ``return_full_trajectories`` a tuple of the best trajectory (``H x (nx + nu)``) and the
            (at most 64) lowest cost sampled trajectories of the last iteration
        :raises ValueError: if the configured number of iterations is less than 1
        :raises RuntimeError: if ``fixed_H`` is False and the horizon cannot shrink any further
        """
        if not torch.is_tensor(state):
            state = torch.tensor(state, device=self.device, dtype=self.dtype)
        x = state

        online = self.warmed_up
        iterations = self.online_iters if online else self.warmup_iters
        if iterations < 1:
            # without at least one iteration there are no elites to pick an action from
            raise ValueError(f"Need at least one optimization iteration, got {iterations}")

        if online and not self.fixed_H:
            self._shrink_horizon()
        self.warmed_up = True

        for _ in range(iterations):
            if self.kept_elites is None:
                # Sample actions
                U = self.sample_action_sequences(x, self.N)
            else:
                # reuse the elites from the previous iteration
                U = self.sample_action_sequences(x, self.N - len(self.kept_elites))
                U = torch.cat((U, self.kept_elites), dim=0)

            # evaluate costs and update the distribution!
            pred_x = self._rollout_dynamics(x, U)
            costs = self._cost(pred_x, U)
            if costs.dim() == 2 and costs.shape[-1] == 1:
                # accept N x 1 costs; sorting them as-is would sort along the singleton dimension
                costs = costs.squeeze(-1)
            _, indices = torch.sort(costs)
            elites = U[indices[:self.K]]
            self.update_distribution(elites)
            # save kept elites fraction
            self.kept_elites = U[indices[:int(self.K * self.keep_fraction)]]

        # Return best sampled trajectory
        out_U = elites[0].clone()
        if shift_nominal_trajectory:
            self.shift()

        if not return_full_trajectories:
            return out_U[0]

        out_X = self._rollout_dynamics(x, out_U.reshape(1, self.H, self.nu)).reshape(self.H, self.nx)
        out_trajectory = torch.cat((out_X, out_U), dim=-1)

        # only the lowest cost trajectories of the last iteration are returned, for visualization
        sampled_trajectories = torch.cat((pred_x, U), dim=-1)
        sampled_trajectories = sampled_trajectories[indices][:self._NUM_VISUALIZED_TRAJECTORIES]
        return out_trajectory, sampled_trajectories

    def shift(self):
        """Advance the plan by one time step.

        The mean is rolled forward with a zero action appended, the std is re-initialised to
        sigma, and the kept elites are rolled forward with a freshly sampled last action.
        """
        # roll distribution
        self.mean = torch.roll(self.mean, -1, dims=0)
        self.mean[-1] = torch.zeros(self.nu, device=self.device, dtype=self.dtype)
        self.std = self.sigma.clone()
        # Also shift the elites
        if self.kept_elites is not None:
            self.kept_elites = torch.roll(self.kept_elites, -1, dims=1)
            self.kept_elites[:, -1] = self.sigma * torch.randn(len(self.kept_elites), self.nu,
                                                               device=self.device, dtype=self.dtype)
"""
Same as approximate dynamics, but now the input is sine and cosine of theta (output is still dtheta)
This is a continuous representation of theta, which some papers show is easier for a NN to learn.
"""
import logging
import math
import random
import time

import gym
import numpy as np
import torch
from gym import logger as gym_log

from pytorch_icem import icem

gym_log.set_level(gym_log.INFO)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG,
                    format='[%(levelname)s %(asctime)s %(pathname)s:%(lineno)d] %(message)s',
                    datefmt='%m-%d %H:%M:%S')


def run(ctrl: icem.iCEM, env, retrain_dynamics, retrain_after_iter=50, iter=1000, render=True):
    """Control ``env`` with ``ctrl`` while periodically retraining the dynamics model.

    :param ctrl: controller providing ``command(state)`` and the ``nx``/``nu``/``device``/``dtype`` attributes
    :param env: gym environment whose unwrapped env exposes the simulator ``state``
    :param retrain_dynamics: callable given the ``retrain_after_iter x (nx + nu)`` buffer of
        (state, action) rows collected since the last retraining
    :param retrain_after_iter: number of steps between retraining calls (size of the buffer)
    :param iter: total number of control steps (name kept for backward compatibility)
    :param render: whether to call ``env.render()`` after every step
    :return: total reward and the (partially filled) buffer of the last collection round
    """
    # same dtype as the controller so the collected rows can be fed to the (double) network
    dataset = torch.zeros((retrain_after_iter, ctrl.nx + ctrl.nu), device=ctrl.device, dtype=ctrl.dtype)
    total_reward = 0
    for i in range(iter):
        state = env.unwrapped.state.copy()
        command_start = time.perf_counter()
        action = ctrl.command(state)
        elapsed = time.perf_counter() - command_start
        res = env.step(action.cpu().numpy())
        r = res[1]
        total_reward += r
        logger.debug("action taken: %.4f cost received: %.4f time taken: %.5fs", action, -r, elapsed)
        if render:
            env.render()

        di = i % retrain_after_iter
        if di == 0 and i > 0:
            retrain_dynamics(dataset)
            # don't have to clear dataset since it'll be overridden, but useful for debugging
            dataset.zero_()
        dataset[di, :ctrl.nx] = torch.tensor(state, device=ctrl.device)
        dataset[di, ctrl.nx:] = action
    return total_reward, dataset


if __name__ == "__main__":
    ENV_NAME = "Pendulum-v1"
    TIMESTEPS = 15  # T
    N_SAMPLES = 100  # K
    ACTION_LOW = -2.0
    ACTION_HIGH = 2.0

    d = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.double

    sigma = torch.tensor([1], device=d, dtype=dtype)

    randseed = 24
    if randseed is None:
        randseed = random.randint(0, 1000000)
    random.seed(randseed)
    np.random.seed(randseed)
    torch.manual_seed(randseed)
    logger.info("random seed %d", randseed)

    # new hyperparameters for approximate dynamics
    H_UNITS = 32
    TRAIN_EPOCH = 150
    BOOT_STRAP_ITER = 100

    nx = 2
    nu = 1
    # network output is state residual; input is (sin theta, cos theta, theta_dt, u)
    network = torch.nn.Sequential(
        torch.nn.Linear(nx + nu + 1, H_UNITS),
        torch.nn.Tanh(),
        torch.nn.Linear(H_UNITS, H_UNITS),
        torch.nn.Tanh(),
        torch.nn.Linear(H_UNITS, nx)
    ).double().to(device=d)


    def dynamics(state, perturbed_action):
        """Learned dynamics: next state predicted by ``network`` for batched states and actions."""
        u = torch.clamp(perturbed_action, ACTION_LOW, ACTION_HIGH)
        if state.dim() == 1 or u.dim() == 1:
            state = state.view(1, -1)
            u = u.view(1, -1)
        if u.shape[1] > 1:
            u = u[:, 0].view(-1, 1)
        xu = torch.cat((state, u), dim=1)
        # feed in cosine and sine of angle instead of theta
        xu = torch.cat((torch.sin(xu[:, 0]).view(-1, 1), torch.cos(xu[:, 0]).view(-1, 1), xu[:, 1:]), dim=1)
        state_residual = network(xu)
        # output dtheta directly so can just add
        next_state = state + state_residual
        next_state[:, 0] = angle_normalize(next_state[:, 0])
        return next_state


    def true_dynamics(state, perturbed_action):
        """Analytic pendulum dynamics (same equations as gym), used only to validate the model."""
        th = state[:, 0].view(-1, 1)
        thdot = state[:, 1].view(-1, 1)

        g = 10
        m = 1
        l = 1
        dt = 0.05

        u = perturbed_action
        u = torch.clamp(u, -2, 2)

        newthdot = thdot + (3 * g / (2 * l) * torch.sin(th) + 3.0 / (m * l ** 2) * u) * dt
        newthdot = torch.clip(newthdot, -8, 8)
        newth = th + newthdot * dt

        state = torch.cat((newth, newthdot), dim=1)
        return state


    def angular_diff_batch(a, b):
        """Angle difference from b to a (a - b)"""
        diff = a - b
        diff[diff > math.pi] -= 2 * math.pi
        diff[diff < -math.pi] += 2 * math.pi
        return diff


    def angle_normalize(x):
        """Wrap angles to [-pi, pi)."""
        return (((x + math.pi) % (2 * math.pi)) - math.pi)


    def running_cost(state, action):
        """Per-step cost: squared (normalized) angle plus weighted squared angular velocity."""
        theta = state[..., 0]
        theta_dt = state[..., 1]
        cost = angle_normalize(theta) ** 2 + 0.1 * theta_dt ** 2
        return cost


    dataset = None
    # create some true dynamics validation set to compare model against
    Nv = 1000
    statev = torch.cat(((torch.rand(Nv, 1, dtype=torch.double, device=d) - 0.5) * 2 * math.pi,
                        (torch.rand(Nv, 1, dtype=torch.double, device=d) - 0.5) * 16), dim=1)
    actionv = (torch.rand(Nv, 1, dtype=torch.double, device=d) - 0.5) * (ACTION_HIGH - ACTION_LOW)


    def train(new_data):
        """Append ``new_data`` rows of (state, action) to the dataset and refit ``network`` on all of it.

        Consecutive rows are treated as consecutive time steps when building the residual targets.
        """
        global dataset
        # work on a private copy: the caller reuses (and zeroes) its buffer after this call,
        # which must neither be modified here nor end up aliased by the stored dataset
        new_data = torch.as_tensor(new_data, device=d, dtype=dtype).clone()
        # not normalized inside the simulator
        new_data[:, 0] = angle_normalize(new_data[:, 0])
        # clamp actions
        new_data[:, -1] = torch.clamp(new_data[:, -1], ACTION_LOW, ACTION_HIGH)
        # append data to whole dataset
        if dataset is None:
            dataset = new_data
        else:
            dataset = torch.cat((dataset, new_data), dim=0)

        # train on the whole dataset (assume small enough we can train on all together)
        XU = dataset
        dtheta = angular_diff_batch(XU[1:, 0], XU[:-1, 0])
        dtheta_dt = XU[1:, 1] - XU[:-1, 1]
        Y = torch.cat((dtheta.view(-1, 1), dtheta_dt.view(-1, 1)), dim=1)  # x' - x residual
        xu = XU[:-1]  # make same size as Y
        xu = torch.cat((torch.sin(xu[:, 0]).view(-1, 1), torch.cos(xu[:, 0]).view(-1, 1), xu[:, 1:]), dim=1)

        # thaw network
        for param in network.parameters():
            param.requires_grad = True

        optimizer = torch.optim.Adam(network.parameters())
        for epoch in range(TRAIN_EPOCH):
            optimizer.zero_grad()
            # MSE loss
            Yhat = network(xu)
            loss = (Y - Yhat).norm(2, dim=1) ** 2
            loss.mean().backward()
            optimizer.step()
            logger.debug("ds %d epoch %d loss %f", dataset.shape[0], epoch, loss.mean().item())

        # freeze network
        for param in network.parameters():
            param.requires_grad = False

        # evaluate network against true dynamics
        yt = true_dynamics(statev, actionv)
        yp = dynamics(statev, actionv)
        dtheta = angular_diff_batch(yp[:, 0], yt[:, 0])
        dtheta_dt = yp[:, 1] - yt[:, 1]
        E = torch.cat((dtheta.view(-1, 1), dtheta_dt.view(-1, 1)), dim=1).norm(dim=1)
        logger.info("Error with true dynamics theta %f theta_dt %f norm %f", dtheta.abs().mean().item(),
                    dtheta_dt.abs().mean().item(), E.mean().item())
        logger.debug("Start next collection sequence")


    downward_start = True
    # gym.make returns the env inside wrappers, so the simulator state is always accessed through
    # env.unwrapped; an attribute set on the wrapper would shadow it and never be updated by step()
    env = gym.make(ENV_NAME, render_mode="human")
    env.reset()
    if downward_start:
        env.unwrapped.state = np.array([np.pi, 1.0])

    # bootstrap network with random actions
    if BOOT_STRAP_ITER:
        logger.info("bootstrapping with random action for %d actions", BOOT_STRAP_ITER)
        new_data = np.zeros((BOOT_STRAP_ITER, nx + nu))
        for i in range(BOOT_STRAP_ITER):
            pre_action_state = np.array(env.unwrapped.state, dtype=float)
            action = np.random.uniform(low=ACTION_LOW, high=ACTION_HIGH)
            env.step([action])
            # env.render()
            new_data[i, :nx] = pre_action_state
            new_data[i, nx:] = action

        train(new_data)
        logger.info("bootstrapping finished")

    env.reset()
    if downward_start:
        env.unwrapped.state = np.array([np.pi, 1.0])

    ctrl = icem.iCEM(dynamics, icem.accumulate_running_cost(running_cost), nx, nu, sigma=sigma,
                     warmup_iters=5, online_iters=5,
                     num_samples=N_SAMPLES, num_elites=10, horizon=TIMESTEPS, device=d, )
    total_reward, data = run(ctrl, env, train)
    logger.info("Total reward %f", total_reward)