├── LICENSE
├── README.md
├── envs
│   └── cartpole_cont.py
├── gym_cartpole.py
├── gym_pendulum.py
├── plot_cost.py
├── pytorch_mppi
│   ├── __init__.py
│   ├── mppi.py
│   └── smooth_mppi.py
└── setup.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Kim Taekyung

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SMPPI Implementation in PyTorch

This repository implements the idea of Smooth Model Predictive Path Integral control (SMPPI), using a neural network dynamics model in PyTorch. SMPPI is a general framework that obtains smooth actions from sampling-based MPC without any extra smoothing algorithm (e.g. a Savitzky-Golay filter).
The related paper will be released soon.

# Installation

Clone the repository, then run 'pip install -e .' or 'pip3 install -e .', depending on your environment.

Or you can manually install the dependencies:

- pytorch
- numpy
- gym
- scipy

# How to Run Example

You can run our test examples as follows.

For pendulum,
```bash
python gym_pendulum.py
```
For cartpole,
```bash
python gym_cartpole.py
```

The pendulum example is an inverted pendulum in the gym environment. The sample results of the four different controllers are shown below:

[Figure panels: MPPI w/o Smoothing; MPPI (apply smoothing on noise sequence); MPPI (apply smoothing on control sequence); SMPPI]

The cartpole example is a cartpole (continuous action) environment. Since MPPI requires random noise sampling of actions, the cartpole environment in OpenAI Gym (which has only two discrete actions, left or right) is not suitable for testing MPPI, so we made a custom environment that provides a continuous action. In this environment, the action can vary continuously from -10.0 to 10.0. (For more detail, see envs/cartpole_cont.py.)

The sample result of the SMPPI controller is shown below:

[Figure: SMPPI controller on the continuous cartpole environment]

Both examples collect a dataset of state-action pairs through exploration. The dynamics models are retrained every 50 iterations. SMPPI can accurately find the optimal action sequence right after the neural network dynamics are retrained.

# How to Use

Simply import SMPPI from 'pytorch_mppi' to obtain a sequence of smooth optimal actions from sampling-based MPC.

```python
import torch
from pytorch_mppi import smooth_mppi as smppi

# define your dynamics model (works with either nominal dynamics or a neural network approximation)
# create the controller with the chosen parameters
mppi_env = smppi.MPPI(dynamics, running_cost, nx, nu, noise_sigma,
                      num_samples=N_SAMPLES,
                      horizon=TIMESTEPS, lambda_=lambda_, gamma_=gamma_, device=device,
                      w_action_seq_cost=Omega,
                      u_min=torch.tensor(D_ACTION_LOW, dtype=dtype, device=device),
                      u_max=torch.tensor(D_ACTION_HIGH, dtype=dtype, device=device),
                      action_min=torch.tensor(ACTION_LOW, dtype=dtype, device=device),
                      action_max=torch.tensor(ACTION_HIGH, dtype=dtype, device=device))

# assuming you have a gym-like env
obs = env.reset()
for i in range(100):
    action = mppi_env.command(obs)
    obs, reward, done, _ = env.step(action.cpu().numpy())
```

Alternatively, you can test the original MPPI with different smoothing methods.

```python
import torch
from pytorch_mppi import mppi

# define your dynamics model (works with either nominal dynamics or a neural network approximation)
# create the controller with the chosen parameters
mppi_env = mppi.MPPI(dynamics, running_cost, nx, nu, noise_sigma,
                     num_samples=N_SAMPLES,
                     horizon=TIMESTEPS, lambda_=lambda_, gamma_=gamma_, device=device,
                     u_min=torch.tensor(ACTION_LOW, dtype=dtype, device=device),
                     u_max=torch.tensor(ACTION_HIGH, dtype=dtype, device=device),
                     smooth=SMOOTHING_METHOD)
```

You have three options for 'SMOOTHING_METHOD':

1. __"no filter"__ : no smoothing
2. __"smooth u"__ : smooth the control sequence after adding noise
3. __"smooth noise"__ : smooth the noise sequence before adding it to the controls

For the smoothing algorithm, we use a convolutional Savitzky-Golay filter (in scipy).
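
The three options differ only in where the filter is applied. As a rough, dependency-free sketch (not the repository's implementation, which relies on scipy), the snippet below convolves a sequence with the standard tabulated 5-point quadratic Savitzky-Golay coefficients. The helper names `savgol_smooth`, `perturb` and `roughness`, the window length of 5, and the edge handling by repeating the end samples are all assumptions made for illustration.

```python
import random

# Standard 5-point quadratic Savitzky-Golay smoothing coefficients.
# The window length of 5 is an assumption made for this sketch.
SAVGOL_5_2 = [-3 / 35, 12 / 35, 17 / 35, 12 / 35, -3 / 35]


def savgol_smooth(seq, coeffs=SAVGOL_5_2):
    """Smooth a 1-D sequence by convolving it with Savitzky-Golay coefficients.

    Edges are handled by repeating the first and last samples
    (an assumption; other edge modes are possible).
    """
    half = len(coeffs) // 2
    padded = [seq[0]] * half + list(seq) + [seq[-1]] * half
    return [
        sum(c * padded[i + j] for j, c in enumerate(coeffs))
        for i in range(len(seq))
    ]


def perturb(u_nominal, noise, smooth):
    """Build one perturbed control sequence under the three options (hypothetical helper)."""
    if smooth == "no filter":
        return [u + e for u, e in zip(u_nominal, noise)]
    if smooth == "smooth u":
        # add the noise first, then smooth the resulting control sequence
        return savgol_smooth([u + e for u, e in zip(u_nominal, noise)])
    if smooth == "smooth noise":
        # smooth the noise first, then add it to the nominal controls
        return [u + e for u, e in zip(u_nominal, savgol_smooth(noise))]
    raise ValueError(f"unknown smoothing option: {smooth}")


def roughness(seq):
    """Sum of squared first differences; smaller means smoother."""
    return sum((b - a) ** 2 for a, b in zip(seq, seq[1:]))


rng = random.Random(0)
u_nominal = [0.0] * 15  # horizon of 15 steps, as in gym_pendulum.py
noise = [rng.gauss(0.0, 1.0) for _ in u_nominal]

for option in ("no filter", "smooth u", "smooth noise"):
    print(option, round(roughness(perturb(u_nominal, noise, option)), 3))
```

With a zero nominal sequence, "smooth u" and "smooth noise" coincide. They differ once the nominal controls are themselves non-smooth, because only "smooth u" filters the nominal part as well. In practice `scipy.signal.savgol_filter` provides this kind of smoothing with a configurable window length and polynomial order.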

# Parameters Description

### lambda\_

- temperature; a positive scalar where larger values allow more exploration
- we recommend 10.0 ~ 20.0 when you have more than 1,000 samples

### gamma\_

- running action cost parameter
- see the [MPPI paper](https://ieeexplore.ieee.org/abstract/document/8558663?casa_token=RTtdCK4jrykAAAAA:YgIhGuAKv_dPA_JjvaxHT2npZuaFVI0utE4JSnDkALwqbUvh676UydsOUg44ka5rawG7edPo) for more detail

### w_action_seq_cost

- (nu x nu) weight parameter for smoothing the action sequence

### num_samples

- number of trajectories to sample; generally the more the better (choose this parameter based on the size of your neural network model)
- try to keep it between 1K and 10K, if your GPU allows it

### noise_sigma

- (nu x nu) control noise covariance; larger covariance yields more exploration

See our paper for further information (will be released soon).

# Requirements

- `next state <- dynamics(state, action)` function (doesn't have to be true dynamics)
  - `state` is `K x nx`, `action` is `K x nu`
- `cost <- running_cost(state, action)` function
  - `cost` is `K x 1`, `state` is `K x nx`, `action` is `K x nu`

__The shapes of the important tensors (such as 'states', 'noise', 'actions') are all documented in comments in the scripts.__

# Related Works

This repository is based on the [PyTorch implementation of MPPI](https://github.com/UM-ARM-Lab/pytorch_mppi), to which I previously contributed. Thanks to [LemonPi](https://github.com/LemonPi) for the great work.

--------------------------------------------------------------------------------
/envs/cartpole_cont.py:
--------------------------------------------------------------------------------
1 | """
2 | Classic cart-pole system implemented by Rich Sutton et al.
3 | Copied from http://incompleteideas.net/sutton/book/code/pole.c 4 | permalink: https://perma.cc/C9ZM-652R 5 | """ 6 | 7 | import math 8 | import gym 9 | from gym import spaces, logger 10 | from gym.utils import seeding 11 | import numpy as np 12 | 13 | 14 | class CartPoleContEnv(gym.Env): 15 | 16 | metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 50} 17 | 18 | def __init__(self): 19 | self.gravity = 9.8 20 | self.masscart = 1.0 21 | self.masspole = 0.1 22 | self.total_mass = self.masspole + self.masscart 23 | self.length = 0.5 # actually half the pole's length 24 | self.polemass_length = self.masspole * self.length 25 | # self.force_mag = 30.0 26 | self.tau = 0.02 # seconds between state updates 27 | self.min_action = -10.0 28 | self.max_action = 10.0 29 | 30 | # Angle at which to fail the episode 31 | self.theta_threshold_radians = 12 * 2 * math.pi / 360 32 | self.x_threshold = 2.4#10.0 33 | 34 | # Angle limit set to 2 * theta_threshold_radians so failing observation 35 | # is still within bounds. 
36 | high = np.array( 37 | [ 38 | #np.finfo(np.float32).max, 39 | self.x_threshold * 2, 40 | np.finfo(np.float32).max, 41 | np.finfo(np.float32).max, 42 | #self.theta_threshold_radians * 2, 43 | np.finfo(np.float32).max 44 | ], 45 | dtype=np.float32 46 | ) 47 | 48 | self.action_space = spaces.Box( 49 | low = self.min_action, 50 | high = self.max_action, 51 | shape=(1,), 52 | dtype=np.float32 53 | ) 54 | self.observation_space = spaces.Box(-high, high, dtype=np.float32) 55 | 56 | self.seed() 57 | self.viewer = None 58 | self.state = None 59 | 60 | self.steps_beyond_done = None 61 | 62 | def seed(self, seed=None): 63 | self.np_random, seed = seeding.np_random(seed) 64 | return [seed] 65 | 66 | def step(self, action): 67 | #err_msg = "%r (%s) invalid" % (action, type(action)) 68 | #assert self.action_space.contains(action), err_msg 69 | 70 | x, x_dot, theta, theta_dot = self.state 71 | #force = self.force_mag * float(action) 72 | force = float(action) 73 | costheta = math.cos(theta) 74 | sintheta = math.sin(theta) 75 | 76 | # For the interested reader: 77 | # https://coneural.org/florian/papers/05_cart_pole.pdf 78 | temp = ( 79 | force + self.polemass_length * theta_dot ** 2 * sintheta 80 | ) / self.total_mass 81 | thetaacc = (self.gravity * sintheta - costheta * temp) / ( 82 | self.length * (4.0 / 3.0 - self.masspole * costheta ** 2 / self.total_mass) 83 | ) 84 | xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass 85 | 86 | x = x + self.tau * x_dot 87 | x_dot = x_dot + self.tau * xacc 88 | theta = theta + self.tau * theta_dot 89 | theta_dot = theta_dot + self.tau * thetaacc 90 | 91 | self.state = (x, x_dot, theta, theta_dot) 92 | 93 | #done = False 94 | done = bool( 95 | x < -self.x_threshold 96 | or x > self.x_threshold 97 | #or theta < -self.theta_threshold_radians 98 | #or theta > self.theta_threshold_radians 99 | ) 100 | 101 | if not done: 102 | reward = 1.0 103 | elif self.steps_beyond_done is None: 104 | # Pole just fell! 
105 | self.steps_beyond_done = 0 106 | reward = 1.0 107 | else: 108 | if self.steps_beyond_done == 0: 109 | logger.warn( 110 | "You are calling 'step()' even though this " 111 | "environment has already returned done = True. You " 112 | "should always call 'reset()' once you receive 'done = " 113 | "True' -- any further steps are undefined behavior." 114 | ) 115 | self.steps_beyond_done += 1 116 | reward = 0.0 117 | 118 | return np.array(self.state, dtype=np.float32), reward, done, {} 119 | 120 | def reset(self): 121 | self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,)) 122 | self.state[2] += np.pi 123 | self.steps_beyond_done = None 124 | return np.array(self.state, dtype=np.float32) 125 | 126 | def render(self, mode="human"): 127 | screen_width = 600 128 | screen_height = 400 129 | 130 | world_width = self.x_threshold * 2 131 | scale = screen_width / world_width 132 | carty = 100 # TOP OF CART 133 | polewidth = 10.0 134 | polelen = scale * (2 * self.length) 135 | cartwidth = 50.0 136 | cartheight = 30.0 137 | 138 | if self.viewer is None: 139 | from gym.envs.classic_control import rendering 140 | 141 | self.viewer = rendering.Viewer(screen_width, screen_height) 142 | l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2 143 | axleoffset = cartheight / 4.0 144 | cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 145 | self.carttrans = rendering.Transform() 146 | cart.add_attr(self.carttrans) 147 | self.viewer.add_geom(cart) 148 | l, r, t, b = ( 149 | -polewidth / 2, 150 | polewidth / 2, 151 | polelen - polewidth / 2, 152 | -polewidth / 2, 153 | ) 154 | pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 155 | pole.set_color(0.8, 0.6, 0.4) 156 | self.poletrans = rendering.Transform(translation=(0, axleoffset)) 157 | pole.add_attr(self.poletrans) 158 | pole.add_attr(self.carttrans) 159 | self.viewer.add_geom(pole) 160 | self.axle = rendering.make_circle(polewidth / 2) 161 | 
self.axle.add_attr(self.poletrans) 162 | self.axle.add_attr(self.carttrans) 163 | self.axle.set_color(0.5, 0.5, 0.8) 164 | self.viewer.add_geom(self.axle) 165 | self.track = rendering.Line((0, carty), (screen_width, carty)) 166 | self.track.set_color(0, 0, 0) 167 | self.viewer.add_geom(self.track) 168 | 169 | self._pole_geom = pole 170 | 171 | if self.state is None: 172 | return None 173 | 174 | # Edit the pole polygon vertex 175 | pole = self._pole_geom 176 | l, r, t, b = ( 177 | -polewidth / 2, 178 | polewidth / 2, 179 | polelen - polewidth / 2, 180 | -polewidth / 2, 181 | ) 182 | pole.v = [(l, b), (l, t), (r, t), (r, b)] 183 | 184 | x = self.state 185 | cartx = x[0] * scale + screen_width / 2.0 # MIDDLE OF CART 186 | self.carttrans.set_translation(cartx, carty) 187 | self.poletrans.set_rotation(-x[2]) 188 | 189 | return self.viewer.render(return_rgb_array=mode == "rgb_array") 190 | 191 | def close(self): 192 | if self.viewer: 193 | self.viewer.close() 194 | self.viewer = None 195 | -------------------------------------------------------------------------------- /gym_cartpole.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | import logging 5 | import math 6 | from gym import wrappers, logger as gym_log 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data import Dataset 9 | from envs import cartpole_cont 10 | 11 | gym_log.set_level(gym_log.INFO) 12 | logger = logging.getLogger(__name__) 13 | logging.basicConfig(level=logging.INFO, 14 | format='[%(levelname)s %(asctime)s %(pathname)s:%(lineno)d] %(message)s', 15 | datefmt='%m-%d %H:%M:%S') 16 | 17 | SMPPI = True 18 | 19 | # three options for control smoothing 20 | # 1: "no filter" : no smoothing 21 | # 2: "smooth u" : smooth control sequence after adding noise 22 | # 3: "smooth noise" : smooth noise sequence before adding noise 23 | ## for more detail, please wait for our paper now under review. 
24 | if not SMPPI: 25 | SMOOTH = "no filter" 26 | 27 | 28 | 29 | if SMPPI: 30 | from pytorch_mppi import smooth_mppi as mppi 31 | else: 32 | from pytorch_mppi import mppi 33 | 34 | if __name__ == "__main__": 35 | # ENV_NAME = "ContinuousCart_Pole-v1" 36 | TIMESTEPS = 75 37 | N_SAMPLES = 1000 38 | ACTION_LOW = -10.0 39 | ACTION_HIGH = 10.0 40 | D_ACTION_LOW = -1.0 41 | D_ACTION_HIGH = 1.0 42 | 43 | device = torch.device("cuda") if torch.cuda.is_available( 44 | ) else torch.device("cpu") 45 | dtype = torch.double 46 | 47 | noise_sigma = torch.tensor([5.0], device=device, dtype=dtype) 48 | # if size of action space is larger than 1: 49 | # noise_sigma = torch.tensor([[1, 0], [0, 2]], device=d, dtype=dtype) 50 | lambda_ = 10. 51 | gamma_ = 0.1 52 | 53 | import random 54 | 55 | randseed = 42 56 | if randseed is None: 57 | randseed = random.randint(0, 1000000) 58 | random.seed(randseed) 59 | np.random.seed(randseed) 60 | torch.manual_seed(randseed) 61 | logger.info("random seed %d", randseed) 62 | 63 | H_UNITS = 32 64 | TRAIN_EPOCH = 100 65 | BOOT_STRAP_ITER = 0 #30000 66 | EPISODE_CUT = 1000 67 | BATCH_SIZE = 50 68 | 69 | cost_tolerance = 40. 
70 | SUCCESS_CRITERION = 1000 71 | 72 | nx = 4 73 | nu = 1 74 | # network output is state residual 75 | network = torch.nn.Sequential( 76 | torch.nn.Linear(nx + nu, H_UNITS), 77 | torch.nn.Tanh(), 78 | torch.nn.Linear(H_UNITS, H_UNITS), 79 | torch.nn.Tanh(), 80 | torch.nn.Linear(H_UNITS, nx - 1) 81 | ).double().to(device=device) 82 | 83 | def dynamics(state, perturbed_action): 84 | tau = 0.02 85 | u = torch.clamp(perturbed_action, ACTION_LOW, ACTION_HIGH) 86 | xu = torch.cat((state, u), dim=1) 87 | # feed in cosine and sine of angle instead of theta 88 | xu = torch.cat( 89 | (xu[:, 1].view(-1, 1), 90 | torch.sin(xu[:, 2]).view(-1, 1), 91 | torch.cos(xu[:, 2]).view(-1, 1), 92 | xu[:, 3].view(-1, 1), 93 | xu[:, 4].view(-1, 1)), dim=1) 94 | network.eval() 95 | with torch.no_grad(): 96 | state_residual = network(xu) 97 | 98 | # output dtheta directly so can just add 99 | next_state = torch.zeros_like(state) 100 | next_state[:, 1:] = state[:, 1:].clone().detach() + state_residual 101 | next_state[:, 0] = state[:, 0].clone().detach() + tau * state[:, 1].clone().detach() 102 | next_state[:, 2] = angle_normalize(next_state[:, 2]) 103 | return next_state 104 | 105 | def true_dynamics(state, perturbed_action): 106 | perturbed_action = torch.clamp(perturbed_action, ACTION_LOW, ACTION_HIGH) 107 | gravity = 9.8 108 | masscart = 1.0 109 | masspole = 0.1 110 | total_mass = masspole + masscart 111 | length = 0.5 # actually half the pole's length 112 | polemass_length = masspole * length 113 | tau = 0.02 # seconds between state updates 114 | 115 | x = state[:, 0].view(-1, 1) 116 | x_dot = state[:, 1].view(-1, 1) 117 | theta = state[:, 2].view(-1, 1) 118 | theta_dot = state[:, 3].view(-1, 1) 119 | 120 | #force = force_mag * perturbed_action 121 | force = perturbed_action 122 | costheta = torch.cos(theta) 123 | sintheta = torch.sin(theta) 124 | # For the interested reader: 125 | # https://coneural.org/florian/papers/05_cart_pole.pdf 126 | temp = ( 127 | force + polemass_length * 
theta_dot ** 2 * sintheta 128 | ) / total_mass 129 | thetaacc = (gravity * sintheta - costheta * temp) / ( 130 | length * (4.0 / 3.0 - masspole * costheta ** 2 / total_mass) 131 | ) 132 | xacc = temp - polemass_length * thetaacc * costheta / total_mass 133 | 134 | x = x + tau * x_dot 135 | x_dot = x_dot + tau * xacc 136 | theta = theta + tau * theta_dot 137 | theta_dot = theta_dot + tau * thetaacc 138 | theta = angle_normalize(theta) 139 | next_state = torch.cat((x, x_dot, theta, theta_dot), dim=1) 140 | return next_state 141 | 142 | def angular_diff_batch(a, b): 143 | """Angle difference from b to a (a - b)""" 144 | d = a - b 145 | d[d > math.pi] -= 2 * math.pi 146 | d[d < -math.pi] += 2 * math.pi 147 | return d 148 | 149 | def angle_normalize(x): 150 | return (((x + math.pi) % (2 * math.pi)) - math.pi) 151 | 152 | def running_cost(state): 153 | x = state[:, 0] 154 | x_dot = state[:, 1] 155 | theta = state[:, 2] 156 | theta_dot = state[:, 3] 157 | w_x = 50. 158 | w_x_dot = 0.01 159 | w_theta = 5. 
160 | w_theta_dot = 0.01 161 | cost = w_x * x ** 2 + w_x_dot * x_dot ** 2 + w_theta * angle_normalize(theta) ** 2 + w_theta_dot * theta_dot ** 2 162 | return cost 163 | 164 | dataset_xu = None 165 | dataset_Y = None 166 | # create some true dynamics validation set to compare model 167 | Nv = 1000 168 | statev = torch.cat(((torch.rand(Nv, 1, dtype=dtype, device=device) - 0.5) * 2 * 1.2, 169 | (torch.rand(Nv, 1, dtype=dtype, device=device) - 0.5) * 2, 170 | (torch.rand(Nv, 1, dtype=dtype, device=device) - 0.5) * 2 * math.pi * 10 / 180, 171 | (torch.rand(Nv, 1, dtype=dtype, device=device) - 0.5) * 2 172 | ), dim=1) 173 | actionv = (torch.rand(Nv, 1, dtype=dtype, device=device) - 174 | 0.5) * (ACTION_HIGH - ACTION_LOW) 175 | 176 | class CustomDataset(Dataset): 177 | def __init__(self, x, y): 178 | self.x_data = x 179 | self.y_data = y 180 | 181 | def __len__(self): 182 | return len(self.x_data) 183 | 184 | def __getitem__(self, item): 185 | x_ = self.x_data[item] 186 | y_ = self.y_data[item] 187 | return x_, y_ 188 | 189 | def dataset_append(state, action, next_state): 190 | global dataset_xu, dataset_Y 191 | state[2] = angle_normalize(state[2]) 192 | next_state[2] = angle_normalize(next_state[2]) 193 | action = torch.clamp(action.clone().detach(), ACTION_LOW, ACTION_HIGH) 194 | 195 | xu = torch.cat((state, action), dim=0) 196 | 197 | xu = torch.tensor((xu[1], 198 | torch.sin(xu[2]), 199 | torch.cos(xu[2]), 200 | xu[3], 201 | xu[4])).view(1, -1) 202 | dx = next_state[0] - state[0] 203 | dx_dot = next_state[1] - state[1] 204 | dtheta = angular_diff_batch(next_state[2], state[2]) 205 | dtheta_dot = next_state[3] - state[3] 206 | Y = torch.tensor((dx_dot, dtheta, dtheta_dot)).view(1, -1).clone().detach() 207 | 208 | if dataset_xu is None and dataset_Y is None: 209 | dataset_xu = xu 210 | dataset_Y = Y 211 | 212 | else: 213 | dataset_xu = torch.cat((dataset_xu, xu), dim=0) 214 | dataset_Y = torch.cat((dataset_Y, Y), dim=0) 215 | 216 | def train(epoch=TRAIN_EPOCH): 217 | 
global dataset_xu, dataset_Y, network 218 | 219 | # thaw network 220 | for param in network.parameters(): 221 | param.requires_grad = True 222 | 223 | optimizer = torch.optim.Adam(network.parameters(), lr=1e-3) 224 | train_dataset = CustomDataset(dataset_xu, dataset_Y) 225 | train_loader = DataLoader( 226 | train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) 227 | 228 | network.train() 229 | for i in range(epoch): 230 | # MSE loss 231 | for x, y in train_loader: 232 | x, y = x.to(device), y.to(device) 233 | 234 | yhat = network(x) 235 | loss = (y - yhat).norm(2, dim=1) ** 2 236 | optimizer.zero_grad() 237 | loss.mean().backward() 238 | optimizer.step() 239 | logger.debug("ds %d epoch %d loss %f", 240 | dataset_xu.shape[0], i, loss.mean().item()) 241 | 242 | # freeze network 243 | for param in network.parameters(): 244 | param.requires_grad = False 245 | 246 | # evaluate network against true dynamics 247 | yt = true_dynamics(statev, actionv) 248 | yp = dynamics(statev, actionv) 249 | dx = yp[:, 0] - yt[:, 0] 250 | dx_dot = yp[:, 1] - yt[:, 1] 251 | dtheta = angular_diff_batch(yp[:, 2], yt[:, 2]) 252 | dtheta_dot = yp[:, 3] - yt[:, 3] 253 | E = torch.cat((dx.view(-1, 1), dx_dot.view(-1, 1), dtheta.view(-1, 1), dtheta_dot.view(-1, 1)), 254 | dim=1).norm(dim=1) 255 | logger.info("Error with true dynamics x %f x_dot %f theta %f theta_dot %f norm %f", dx.abs().mean(), dx_dot.abs().mean(), dtheta.abs().mean(), 256 | dtheta_dot.abs().mean(), E.mean()) 257 | logger.debug("Start next collection sequence") 258 | 259 | def model_save(): 260 | global network 261 | torch.save(network.state_dict(), 'model_weights_cartpole.pth') 262 | 263 | env = cartpole_cont.CartPoleContEnv() 264 | if BOOT_STRAP_ITER: 265 | logger.info( 266 | "bootstrapping with random action for %d actions", BOOT_STRAP_ITER) 267 | data_count = 0 268 | while True: 269 | env.reset() 270 | for i in range(EPISODE_CUT): 271 | state = env.state 272 | state = torch.tensor(state, 
dtype=torch.float64).to(device=device) 273 | action = np.random.uniform(low=ACTION_LOW, high=ACTION_HIGH) 274 | action = torch.tensor([action], dtype=torch.float64).to(device=device) 275 | s, _, done, _ = env.step(action.cpu().numpy()) 276 | next_state = env.state 277 | next_state = torch.tensor(next_state, dtype=torch.float64).to(device=device) 278 | dataset_append(state, action, next_state) 279 | data_count += 1 280 | if data_count == BOOT_STRAP_ITER: 281 | break 282 | if done: 283 | break 284 | if data_count == BOOT_STRAP_ITER: 285 | break 286 | train(epoch=500) 287 | logger.info("bootstrapping finished") 288 | 289 | env = wrappers.Monitor(env, '/tmp/mppi/', force=True) 290 | 291 | if SMPPI: 292 | mppi_gym = mppi.MPPI(dynamics, running_cost, nx, nu, noise_sigma, 293 | num_samples=N_SAMPLES, 294 | horizon=TIMESTEPS, 295 | lambda_=lambda_, 296 | gamma_=gamma_, 297 | device=device, 298 | u_min=torch.tensor( 299 | D_ACTION_LOW, dtype=dtype, device=device), 300 | u_max=torch.tensor( 301 | D_ACTION_HIGH, dtype=dtype, device=device), 302 | action_min=torch.tensor( 303 | ACTION_LOW, dtype=dtype, device=device), 304 | action_max=torch.tensor(ACTION_HIGH, dtype=dtype, device=device)) 305 | else: 306 | mppi_gym = mppi.MPPI(dynamics, running_cost, nx, nu, noise_sigma, 307 | num_samples=N_SAMPLES, 308 | horizon=TIMESTEPS, 309 | lambda_=lambda_, 310 | gamma_=gamma_, 311 | device=device, 312 | u_min=torch.tensor( 313 | ACTION_LOW, dtype=dtype, device=device), 314 | u_max=torch.tensor( 315 | ACTION_HIGH, dtype=dtype, device=device), 316 | smooth=SMOOTH) 317 | 318 | cost_history = mppi.run_mppi_episode(mppi_gym, env, dataset_append, train, running_cost, model_save, cost_tolerance, SUCCESS_CRITERION) -------------------------------------------------------------------------------- /gym_pendulum.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | import logging 5 | import math 6 | from gym 
import wrappers, logger as gym_log 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data import Dataset 9 | 10 | gym_log.set_level(gym_log.INFO) 11 | logger = logging.getLogger(__name__) 12 | logging.basicConfig(level=logging.INFO, 13 | format='[%(levelname)s %(asctime)s %(pathname)s:%(lineno)d] %(message)s', 14 | datefmt='%m-%d %H:%M:%S') 15 | 16 | SMPPI = True 17 | 18 | downward_start = True 19 | INIT_VEL = 0 20 | 21 | # three options for control smoothing 22 | # 1: "no filter" : no smoothing 23 | # 2: "smooth u" : smooth control sequence after adding noise 24 | # 3: "smooth noise" : smooth noise sequence before adding noise 25 | # for more detail, please wait for our paper now under review. 26 | if not SMPPI: 27 | SMOOTH = "no filter" 28 | 29 | 30 | if SMPPI: 31 | from pytorch_mppi import smooth_mppi as mppi 32 | else: 33 | from pytorch_mppi import mppi 34 | 35 | if __name__ == "__main__": 36 | ENV_NAME = "Pendulum-v1" 37 | TIMESTEPS = 15 # T 38 | N_SAMPLES = 1000 # K 39 | ACTION_LOW = -2.0 40 | ACTION_HIGH = 2.0 41 | D_ACTION_LOW = -8.0 42 | D_ACTION_HIGH = 8.0 43 | 44 | device = torch.device("cuda") if torch.cuda.is_available( 45 | ) else torch.device("cpu") 46 | dtype = torch.double 47 | 48 | noise_sigma = torch.tensor([1.], device=device, dtype=dtype) 49 | # if size of action space is larger than 1: 50 | # noise_sigma = torch.tensor([[1, 0], [0, 2]], device=d, dtype=dtype) 51 | lambda_ = 10. 
52 | gamma_ = 0.1 53 | 54 | import random 55 | 56 | randseed = 42 57 | if randseed is None: 58 | randseed = random.randint(0, 1000000) 59 | random.seed(randseed) 60 | np.random.seed(randseed) 61 | torch.manual_seed(randseed) 62 | logger.info("random seed %d", randseed) 63 | 64 | # new hyperparmaeters for approximate dynamics 65 | H_UNITS = 32 66 | TRAIN_EPOCH = 100 # 150 67 | BOOT_STRAP_ITER = 0 68 | EPISODE_CUT = 1000 69 | BATCH_SIZE = 50 70 | 71 | cost_tolerance = 0.1 72 | SUCCESS_CRITERION = 300 73 | 74 | nx = 2 75 | nu = 1 76 | # network output is state residual 77 | network = torch.nn.Sequential( 78 | torch.nn.Linear(nx + nu + 1, H_UNITS), 79 | torch.nn.Tanh(), 80 | torch.nn.Linear(H_UNITS, H_UNITS), 81 | torch.nn.Tanh(), 82 | torch.nn.Linear(H_UNITS, nx) 83 | ).double().to(device=device) 84 | 85 | def dynamics(state, perturbed_action): 86 | u = torch.clamp(perturbed_action, ACTION_LOW, ACTION_HIGH) 87 | if state.dim() == 1 or u.dim() == 1: 88 | state = state.view(1, -1) 89 | u = u.view(1, -1) 90 | xu = torch.cat((state, u), dim=1) 91 | # feed in cosine and sine of angle instead of theta 92 | xu = torch.cat((torch.sin( 93 | xu[:, 0]).view(-1, 1), torch.cos(xu[:, 0]).view(-1, 1), xu[:, 1:]), dim=1) 94 | 95 | network.eval() 96 | with torch.no_grad(): 97 | state_residual = network(xu) 98 | # output dtheta directly so can just add 99 | next_state = state.clone().detach() + state_residual 100 | next_state[:, 0] = angle_normalize(next_state[:, 0]) 101 | return next_state 102 | 103 | def true_dynamics(state, perturbed_action): 104 | # true dynamics from gym 105 | th = state[:, 0].view(-1, 1) 106 | thdot = state[:, 1].view(-1, 1) 107 | 108 | g = 10 109 | m = 1 110 | l = 1 111 | dt = 0.05 112 | 113 | u = perturbed_action 114 | u = torch.clamp(u, -2, 2) 115 | 116 | newthdot = thdot + (-3 * g / (2 * l) * 117 | torch.sin(th + np.pi) + 3. 
/ (m * l ** 2) * u) * dt 118 | newth = th + newthdot * dt 119 | newthdot = torch.clamp(newthdot, -8, 8) 120 | 121 | next_state = torch.cat((newth, newthdot), dim=1) 122 | return next_state 123 | 124 | def angular_diff_batch(a, b): 125 | """Angle difference from b to a (a - b)""" 126 | d = a - b 127 | d[d > math.pi] -= 2 * math.pi 128 | d[d < -math.pi] += 2 * math.pi 129 | return d 130 | 131 | def angle_normalize(x): 132 | return (((x + math.pi) % (2 * math.pi)) - math.pi) 133 | 134 | def running_cost(state): 135 | theta = state[:, 0] 136 | theta_dt = state[:, 1] 137 | cost = angle_normalize(theta) ** 2 + 0.1 * theta_dt ** 2 138 | return cost 139 | 140 | dataset_xu = None 141 | dataset_Y = None 142 | # create some true dynamics validation set to compare model against 143 | Nv = 1000 144 | statev = torch.cat(((torch.rand(Nv, 1, dtype=dtype, device=device) - 0.5) * 2 * math.pi, 145 | (torch.rand(Nv, 1, dtype=dtype, device=device) - 0.5) * 16), dim=1) 146 | actionv = (torch.rand(Nv, 1, dtype=dtype, device=device) - 147 | 0.5) * (ACTION_HIGH - ACTION_LOW) 148 | 149 | class CustomDataset(Dataset): 150 | def __init__(self, x, y): 151 | self.x_data = x 152 | self.y_data = y 153 | 154 | def __len__(self): 155 | return len(self.x_data) 156 | 157 | def __getitem__(self, item): 158 | x_ = self.x_data[item] 159 | y_ = self.y_data[item] 160 | return x_, y_ 161 | 162 | def dataset_append(state, action, next_state): 163 | global dataset_xu, dataset_Y 164 | state[0] = angle_normalize(state[0]) 165 | next_state[0] = angle_normalize(next_state[0]) 166 | action = torch.clamp(action.clone().detach(), ACTION_LOW, ACTION_HIGH) 167 | 168 | xu = torch.cat((state, action), dim=0) 169 | 170 | xu = torch.tensor( 171 | (torch.sin(xu[0]), 172 | torch.cos(xu[0]), 173 | xu[1], 174 | xu[2])).view(1, -1) 175 | dtheta = angular_diff_batch(next_state[0], state[0]) 176 | dtheta_dot = next_state[1] - state[1] 177 | Y = torch.tensor((dtheta, dtheta_dot)).view(1, -1).clone().detach() 178 | 179 | if 
dataset_xu is None and dataset_Y is None: 180 | dataset_xu = xu 181 | dataset_Y = Y 182 | 183 | else: 184 | dataset_xu = torch.cat((dataset_xu, xu), dim=0) 185 | dataset_Y = torch.cat((dataset_Y, Y), dim=0) 186 | 187 | def train(epoch=TRAIN_EPOCH): 188 | global dataset_xu, dataset_Y, network 189 | # thaw network 190 | for param in network.parameters(): 191 | param.requires_grad = True 192 | 193 | optimizer = torch.optim.Adam(network.parameters()) 194 | train_dataset = CustomDataset(dataset_xu, dataset_Y) 195 | train_loader = DataLoader( 196 | train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) 197 | 198 | network.train() 199 | for i in range(epoch): 200 | # MSE loss 201 | for x, y in train_loader: 202 | x, y = x.to(device), y.to(device) 203 | yhat = network(x) 204 | loss = (y - yhat).norm(2, dim=1) ** 2 205 | optimizer.zero_grad() 206 | loss.mean().backward() 207 | optimizer.step() 208 | logger.debug("ds %d epoch %d loss %f", 209 | dataset_xu.shape[0], i, loss.mean().item()) 210 | 211 | # freeze network 212 | for param in network.parameters(): 213 | param.requires_grad = False 214 | 215 | # evaluate network against true dynamics 216 | yt = true_dynamics(statev, actionv) 217 | yp = dynamics(statev, actionv) 218 | dtheta = angular_diff_batch(yp[:, 0], yt[:, 0]) 219 | dtheta_dt = yp[:, 1] - yt[:, 1] 220 | E = torch.cat((dtheta.view(-1, 1), dtheta_dt.view(-1, 1)), 221 | dim=1).norm(dim=1) 222 | logger.info("Error with true dynamics theta %f theta_dt %f norm %f", dtheta.abs().mean(), 223 | dtheta_dt.abs().mean(), E.mean()) 224 | logger.debug("Start next collection sequence") 225 | 226 | def model_save(): 227 | global network 228 | torch.save(network.state_dict(), 'model_weights_pendulum.pth') 229 | 230 | env = gym.make(ENV_NAME).env # bypass the default TimeLimit wrapper 231 | # bootstrap network with random actions 232 | if BOOT_STRAP_ITER: 233 | logger.info( 234 | "bootstrapping with random action for %d actions", BOOT_STRAP_ITER) 235 | data_count = 0 
236 | while True: 237 | env.reset() 238 | for i in range(EPISODE_CUT): 239 | state = env.state 240 | state = torch.tensor( 241 | state, dtype=torch.float64).to(device=device) 242 | action = np.random.uniform(low=ACTION_LOW, high=ACTION_HIGH) 243 | action = torch.tensor( 244 | [action], dtype=torch.float64).to(device=device) 245 | s, _, done, _ = env.step(action.cpu().numpy()) 246 | next_state = env.state 247 | next_state = torch.tensor( 248 | next_state, dtype=torch.float64).to(device=device) 249 | dataset_append(state, action, next_state) 250 | data_count += 1 251 | if data_count == BOOT_STRAP_ITER: 252 | break 253 | if done: 254 | break 255 | if data_count == BOOT_STRAP_ITER: 256 | break 257 | train(epoch=500) 258 | logger.info("bootstrapping finished") 259 | 260 | env = wrappers.Monitor(env, '/tmp/mppi/', force=True) 261 | env.reset() 262 | if downward_start: 263 | env.env.state = [np.pi, INIT_VEL] 264 | 265 | if SMPPI: 266 | mppi_gym = mppi.MPPI(dynamics, running_cost, nx, nu, noise_sigma, 267 | num_samples=N_SAMPLES, 268 | horizon=TIMESTEPS, 269 | lambda_=lambda_, 270 | gamma_=gamma_, 271 | device=device, 272 | u_min=torch.tensor( 273 | D_ACTION_LOW, dtype=dtype, device=device), 274 | u_max=torch.tensor( 275 | D_ACTION_HIGH, dtype=dtype, device=device), 276 | action_min=torch.tensor( 277 | ACTION_LOW, dtype=dtype, device=device), 278 | action_max=torch.tensor(ACTION_HIGH, dtype=dtype, device=device)) 279 | else: 280 | mppi_gym = mppi.MPPI(dynamics, running_cost, nx, nu, noise_sigma, 281 | num_samples=N_SAMPLES, 282 | horizon=TIMESTEPS, 283 | lambda_=lambda_, 284 | gamma_=gamma_, 285 | device=device, 286 | u_min=torch.tensor( 287 | ACTION_LOW, dtype=dtype, device=device), 288 | u_max=torch.tensor( 289 | ACTION_HIGH, dtype=dtype, device=device), 290 | smooth=SMOOTH) 291 | 292 | cost_history = mppi.run_mppi_episode( 293 | mppi_gym, env, dataset_append, train, running_cost, model_save, cost_tolerance, SUCCESS_CRITERION) 294 | 
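A note on the evaluation step in `train()` above: the angle error is computed with `angular_diff_batch`, whose definition is not part of this excerpt. The following is a minimal, standalone sketch of what such a wrap-around difference could look like, assuming it reduces to the same modulo trick as `angle_normalize` in `pytorch_mppi/mppi.py`. The scalar helper `angular_diff` is a hypothetical name used only for illustration; it is not the repository's function.

```python
import math


def angle_normalize(x):
    # Same formula as angle_normalize in pytorch_mppi/mppi.py: wrap into [-pi, pi).
    return ((x + math.pi) % (2 * math.pi)) - math.pi


def angular_diff(a, b):
    # Hypothetical scalar stand-in for angular_diff_batch (an assumption):
    # signed shortest difference a - b between two angles, in radians.
    return angle_normalize(a - b)


if __name__ == "__main__":
    # Two angles on either side of the +/- pi seam are about 0.2 rad apart,
    # not about 2*pi apart as a plain subtraction would suggest.
    print(angular_diff(math.pi - 0.1, -math.pi + 0.1))
```

Without the wrap, a learned model that predicts an angle just across the seam would be charged an error close to 2*pi even though the prediction is nearly right, which would distort the "Error with true dynamics" log line.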
-------------------------------------------------------------------------------- /plot_cost.py: -------------------------------------------------------------------------------- 1 | from plotly.subplots import make_subplots 2 | import plotly.graph_objects as go 3 | import csv 4 | import numpy as np 5 | import os 6 | import random 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import Dataset 11 | from torch.utils.data import DataLoader 12 | import torch.optim as optim 13 | import argparse 14 | import seaborn as sns 15 | import matplotlib.pyplot as plt 16 | import pandas as pd 17 | 18 | timestep = 50 19 | 20 | 21 | def csv2np(filename): 22 | raw_data = [] 23 | with open(filename, newline='') as csvfile: 24 | spamreader = csv.reader(csvfile, delimiter=',', quotechar='|') 25 | for i, row in enumerate(spamreader): 26 | raw_data.append(row) 27 | return np.array(raw_data, dtype='float32') 28 | 29 | 30 | filename = 'original' 31 | split_data = csv2np(filename+'.csv') 32 | split_iteration = np.hsplit(split_data, timestep) 33 | split_iteration = np.array(split_iteration) 34 | print(split_iteration.shape) 35 | iteration_column = np.reshape( 36 | np.arange(split_iteration.shape[0]), (split_iteration.shape[0], 1)) 37 | df = pd.DataFrame(data=np.repeat(iteration_column, 38 | split_iteration.shape[1], axis=0), columns=['Iteration']) 39 | mean_cost = [np.mean(it, axis=1) for it in split_iteration] 40 | mean_cost = np.transpose(np.array(mean_cost)) 41 | mean_cost = np.clip(mean_cost, 0.01, 10) 42 | df.loc[:, 'Original w/o filter'] = pd.Series(np.squeeze(np.reshape( 43 | np.transpose(mean_cost), (1, -1))), index=df.index) 44 | 45 | # fig = go.Figure() 46 | # mean_cost = split_data 47 | # t = np.arange(0, mean_cost.shape[1]) 48 | 49 | # fig.add_trace(go.Scatter(x=t, y=mean_cost[0], mode='lines')) 50 | # fig.add_trace(go.Scatter(x=t, y=mean_cost[1], mode='lines')) 51 | # fig.add_trace(go.Scatter(x=t, y=mean_cost[2], mode='lines')) 52 | # 
fig.add_trace(go.Scatter(x=t, y=mean_cost[3], mode='lines')) 53 | # fig.add_trace(go.Scatter(x=t, y=mean_cost[4], mode='lines')) 54 | # fig.add_trace(go.Scatter(x=t, y=mean_cost[5], mode='lines')) 55 | # fig.add_trace(go.Scatter(x=t, y=mean_cost[6], mode='lines')) 56 | # fig.show() 57 | 58 | 59 | # fig = go.Figure() 60 | 61 | # fig.add_trace(go.Scatter(x=t, y=mean_cost, mode='lines', name='original', marker={ 62 | # 'color': 'rgb(255,0,0)'}, line=dict(width=4, dash='dash'))) 63 | 64 | filename = 'u_cost' 65 | split_data = csv2np(filename+'.csv') 66 | split_iteration = np.hsplit(split_data, timestep) 67 | split_iteration = np.array(split_iteration) 68 | mean_cost = [np.mean(it, axis=1) for it in split_iteration] 69 | mean_cost = np.transpose(np.array(mean_cost)) 70 | mean_cost = np.clip(mean_cost, 0.01, 10) 71 | df.loc[:, 'Original w/ action cost'] = pd.Series(np.squeeze(np.reshape( 72 | np.transpose(mean_cost), (1, -1))), index=df.index) 73 | 74 | filename = 'original_filter_on_noise' 75 | split_data = csv2np(filename+'.csv') 76 | split_iteration = np.hsplit(split_data, timestep) 77 | split_iteration = np.array(split_iteration) 78 | mean_cost = [np.mean(it, axis=1) for it in split_iteration] 79 | mean_cost = np.transpose(np.array(mean_cost)) 80 | mean_cost = np.clip(mean_cost, 0.01, 10) 81 | df.loc[:, 'Original (SGF(\u03B5))'] = pd.Series(np.squeeze(np.reshape( 82 | np.transpose(mean_cost), (1, -1))), index=df.index) 83 | 84 | filename = 'original_filter_on_u' 85 | split_data = csv2np(filename+'.csv') 86 | split_iteration = np.hsplit(split_data, timestep) 87 | split_iteration = np.array(split_iteration) 88 | mean_cost = [np.mean(it, axis=1) for it in split_iteration] 89 | mean_cost = np.transpose(np.array(mean_cost)) 90 | mean_cost = np.clip(mean_cost, 0.01, 10) 91 | df.loc[:, 'Original (SGF(u))'] = pd.Series(np.squeeze(np.reshape( 92 | np.transpose(mean_cost), (1, -1))), index=df.index) 93 | 94 | 95 | filename = 'das' 96 | split_data = csv2np(filename+'.csv') 97 | 
split_iteration = np.hsplit(split_data, timestep) 98 | split_iteration = np.array(split_iteration) 99 | mean_cost = [np.mean(it, axis=1) for it in split_iteration] 100 | mean_cost = np.transpose(np.array(mean_cost)) 101 | mean_cost = np.clip(mean_cost, 0.01, 10) 102 | df.loc[:, 'Ours'] = pd.Series(np.squeeze(np.reshape( 103 | np.transpose(mean_cost), (1, -1))), index=df.index) 104 | 105 | 106 | #sns.lineplot(data=df, x="iteration", y="Ours", ci=50, legend='full') 107 | #sns.lineplot(data=df, x="iteration", y="Original", ci=50, legend='full') 108 | df = df.melt('Iteration', var_name='Method', value_name='State cost') 109 | plt.figure(figsize=[6, 4]) 110 | 111 | sns.set_palette(reversed(sns.color_palette('Set1', 5)), 5) 112 | sns.lineplot(data=df, x='Iteration', y='State cost', hue='Method', ci=80) 113 | plt.yscale('log') 114 | plt.tight_layout() 115 | plt.xlim(0, 35) 116 | plt.ylim(0.009, 10) 117 | plt.xticks(fontsize=12) 118 | plt.yticks(fontsize=12) 119 | plt.xlabel(xlabel='Iteration', fontsize=12) 120 | plt.ylabel(ylabel='State cost', fontsize=12) 121 | plt.legend(prop={'size': 12}) 122 | plt.savefig('/home/add/Desktop/tempo.pdf', format='pdf') 123 | # plt.legend() 124 | plt.show() 125 | plt.close() 126 | 127 | 128 | # fig.update_layout( 129 | # xaxis_title="itertaion", 130 | # yaxis_title="cost", 131 | # height=650, 132 | # width=1800, 133 | # margin=dict(l=30, r=30, t=30, b=30), 134 | # font=dict(size=20), 135 | # legend=dict(font=dict(size=25), yanchor="top", 136 | # y=0.99, xanchor="right", x=0.99) 137 | # ) 138 | # fig.layout.template = 'plotly_white' 139 | # fig.show() 140 | 141 | # if not os.path.exists("image"): 142 | # os.mkdir("image") 143 | # fig.write_image("./image/inference_"+filename+"_vx.svg") 144 | -------------------------------------------------------------------------------- /pytorch_mppi/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tkkim-robot/smooth-mppi-pytorch/dcc6f4285069dbf8add0a111c8b936d0ef1830e6/pytorch_mppi/__init__.py -------------------------------------------------------------------------------- /pytorch_mppi/mppi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import logging 4 | from torch.distributions.multivariate_normal import MultivariateNormal 5 | import numpy as np 6 | import math 7 | import csv 8 | from scipy.signal import savgol_filter 9 | 10 | f = open('mppi.csv', 'a', encoding='utf-8', newline='') 11 | wr = csv.writer(f) 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | def is_tensor_like(x): 16 | return torch.is_tensor(x) or type(x) is np.ndarray 17 | 18 | 19 | class MPPI(): 20 | """ 21 | Model Predictive Path Integral control 22 | This implementation batch samples the trajectories and so scales well with the number of samples K. 23 | 24 | Implemented according to algorithm 2 in Williams et al., 2017 25 | 'Information Theoretic MPC for Model-Based Reinforcement Learning', 26 | based off of https://github.com/ferreirafabio/mppi_pendulum 27 | """ 28 | 29 | def __init__(self, dynamics, running_cost, nx, nu, noise_sigma, num_samples=100, horizon=15, device="cpu", 30 | terminal_state_cost=None, 31 | lambda_=1., 32 | gamma_=1., 33 | noise_mu=None, 34 | u_min=None, 35 | u_max=None, 36 | u_init=None, 37 | U_init=None, 38 | u_scale=1, 39 | u_per_command=1, 40 | step_dependent_dynamics=False, 41 | rollout_var_cost=0, 42 | rollout_var_discount=0.95, 43 | sample_null_action=False, 44 | smooth="no filter"): 45 | """ 46 | :param dynamics: function(state, action) -> next_state (K x nx) taking in batch state (K x nx) and action (K x nu) 47 | :param running_cost: function(state) -> cost (K) taking in batch state (K x nx); the action is not passed in 48 | :param nx: state dimension 49 | :param noise_sigma: (nu x nu) control noise covariance (assume v_t ~ N(u_t, 
noise_sigma)) 50 | :param num_samples: K, number of trajectories to sample 51 | :param horizon: T, length of each trajectory 52 | :param device: pytorch device 53 | :param terminal_state_cost: function(states, actions) -> cost (K) taking in the rollout states (K x T x nx) and actions (K x T x nu) 54 | :param lambda_: temperature, positive scalar where larger values will allow more exploration 55 | :param noise_mu: (nu) control noise mean (used to bias control samples); defaults to zero mean 56 | :param u_min: (nu) minimum values for each dimension of control to pass into dynamics 57 | :param u_max: (nu) maximum values for each dimension of control to pass into dynamics 58 | :param u_init: (nu) what to initialize new end of trajectory control to be; defaults to zero 59 | :param U_init: (T x nu) initial control sequence; defaults to zeros 60 | :param step_dependent_dynamics: whether the passed in dynamics needs horizon step passed in (as 3rd arg) 61 | :param rollout_var_cost: Cost attached to the variance of costs across trajectory rollouts 62 | :param rollout_var_discount: Discount of variance cost over control horizon 63 | :param sample_null_action: Whether to explicitly sample a null action (bad for starting in a local minimum) 64 | :param smooth: smoothing method: 'no filter' (default), 'smooth u' (Savitzky-Golay filter on the control sequence) or 'smooth noise' (Savitzky-Golay filter on the weighted noise) 65 | 66 | * option 67 | :param rollout_samples: M, number of state trajectories to rollout for each control trajectory 68 | (should be 1 for deterministic dynamics and more for models that output a distribution) 69 | """ 70 | 71 | self.smooth_method = smooth 72 | self.device = device 73 | self.dtype = noise_sigma.dtype 74 | self.K = num_samples # N_SAMPLES 75 | self.T = horizon # TIMESTEPS 76 | 77 | # dimensions of state and control 78 | self.nx = nx 79 | self.nu = 1 if len(noise_sigma.shape) == 0 else noise_sigma.shape[0] 80 | self.lambda_ = lambda_ 81 | 
self.gamma_ = gamma_ 82 | 83 | if noise_mu is None: 84 | noise_mu = torch.zeros(self.nu, dtype=self.dtype) 85 | 86 | if u_init is None: 87 
| u_init = torch.zeros_like(noise_mu) 88 | 89 | if U_init is None: 90 | U_init = torch.zeros(self.T, self.nu).to(device) 91 | 92 | # handle 1D edge case 93 | if self.nu == 1: 94 | noise_mu = noise_mu.view(-1) 95 | noise_sigma = noise_sigma.view(-1, 1) 96 | 97 | # bounds 98 | self.u_min = u_min 99 | self.u_max = u_max 100 | self.u_scale = u_scale 101 | self.u_per_command = u_per_command 102 | # make sure if any of them is specified, both are specified 103 | if self.u_max is not None and self.u_min is None: 104 | if not torch.is_tensor(self.u_max): 105 | self.u_max = torch.tensor(self.u_max) 106 | self.u_min = -self.u_max 107 | if self.u_min is not None and self.u_max is None: 108 | if not torch.is_tensor(self.u_min): 109 | self.u_min = torch.tensor(self.u_min) 110 | self.u_max = -self.u_min 111 | if self.u_min is not None: 112 | self.u_min = self.u_min.to(device=self.device) 113 | self.u_max = self.u_max.to(device=self.device) 114 | 115 | self.noise_mu = noise_mu.to(self.device) 116 | self.noise_sigma = noise_sigma.to(self.device) 117 | self.noise_sigma_inv = torch.inverse(self.noise_sigma) 118 | self.noise_dist = MultivariateNormal( 119 | self.noise_mu, covariance_matrix=self.noise_sigma) 120 | # T x nu control sequence 121 | self.U = U_init 122 | self.u_init = u_init.to(self.device) 123 | 124 | if self.U is None: 125 | self.U = self.noise_dist.sample((self.T,)) 126 | self.U = torch.zeros_like(self.U) 127 | 128 | self.U_history = torch.zeros_like(self.U)[:5] 129 | 130 | self.step_dependency = step_dependent_dynamics 131 | self.F = dynamics 132 | self.running_cost = running_cost 133 | self.terminal_state_cost = terminal_state_cost 134 | self.sample_null_action = sample_null_action 135 | self.state = None 136 | 137 | # handling dynamics models that output a distribution (take multiple trajectory samples) 138 | self.rollout_var_cost = rollout_var_cost 139 | self.rollout_var_discount = rollout_var_discount 140 | 141 | # sampled results from last command 142 | 
self.cost_total = None 143 | self.cost_total_non_zero = None 144 | self.omega = None 145 | self.states = None 146 | self.actions = None 147 | 148 | def _dynamics(self, state, u, t): 149 | return self.F(state, u, t) if self.step_dependency else self.F(state, u) 150 | 151 | def _running_cost(self, state): 152 | return self.running_cost(state) 153 | 154 | def command(self, state): 155 | """ 156 | :param state: (nx) or (K x nx) current state, or samples of states (for propagating a distribution of states) 157 | :returns action: (nu) best action 158 | """ 159 | # shift command 1 time step 160 | self.U = torch.roll(self.U, -1, dims=0) 161 | self.U[-1] = self.u_init 162 | 163 | perturbed_action = self.action_sampling() 164 | 165 | cost_total, states = self._compute_batch_rollout_costs( 166 | perturbed_action, state) 167 | self.omega = self._compute_weighting(cost_total) 168 | 169 | weighted_noise = torch.sum( 170 | self.omega.view(-1, 1, 1) * self.noise, dim=0) 171 | 172 | if self.smooth_method == "no filter": 173 | self.U += weighted_noise 174 | elif self.smooth_method == "smooth u": 175 | self.U += weighted_noise 176 | U_filtered = savgol_filter( 177 | torch.cat([self.U_history, self.U]).to('cpu'), 5, 3, axis=0) 178 | self.U = torch.tensor(U_filtered[-self.T:]).to(self.device) 179 | 180 | self.U_history = torch.roll(self.U_history, -1, dims=0) 181 | self.U_history[-1] = self.U[0] 182 | elif self.smooth_method == "smooth noise": 183 | self.U += torch.tensor(savgol_filter(weighted_noise.to('cpu'), 184 | 5, 3, axis=0)).to(self.device) 185 | else: 186 | raise Exception("Wrong smooth option !!!") 187 | 188 | 189 | action = self.U[0] 190 | 191 | return action 192 | 193 | def reset(self): 194 | """ 195 | Clear controller state after finishing a trial 196 | """ 197 | self.U = torch.zeros_like(self.U) 198 | 199 | def _compute_weighting(self, cost_total): 200 | beta = torch.min(cost_total) 201 | cost_total_non_zero = torch.exp(-1/self.lambda_ * (cost_total - beta)) 202 | eta = 
torch.sum(cost_total_non_zero) 203 | omega = (1. / eta) * cost_total_non_zero 204 | return omega 205 | 206 | def _compute_batch_rollout_costs(self, perturbed_actions, state): 207 | K, T, nu = perturbed_actions.shape 208 | assert nu == self.nu 209 | 210 | cost_total = torch.zeros(K, device=self.device, dtype=self.dtype) 211 | cost_samples = torch.zeros(K, device=self.device, dtype=self.dtype) 212 | 213 | # allow propagation of a sample of states (e.g. to carry a distribution), or to start with a single state 214 | # state -> nx 215 | if state.shape == (K, self.nx): 216 | state = state 217 | else: 218 | state = state.view(1, -1).repeat(K, 1) 219 | # state -> K*nx 220 | 221 | states = [] 222 | actions = [] 223 | 224 | prev_map_mask = torch.zeros((K), device=self.device) 225 | for t in range(T): 226 | # perturbed_actions -> K*T*nu 227 | # perturbed_actions[:, t] -> K*nu 228 | u = self.u_scale * perturbed_actions[:, t] # u -> K*nu 229 | 230 | state = self._dynamics(state, u, t) 231 | c = self._running_cost(state) # c -> K 232 | 233 | cost_samples += c # cost_samples -> K 234 | 235 | # Save all states/actions 236 | states.append(state) 237 | actions.append(u) 238 | 239 | # actions -> [K*nu, K*nu ...] 
with size T 240 | # torch.stack(actions, dim=-2) -> K*T*nu 241 | actions = torch.stack(actions, dim=-2) 242 | states = torch.stack(states, dim=-2) 243 | 244 | # terminal state cost 245 | if self.terminal_state_cost: 246 | phi = self.terminal_state_cost(states, actions) 247 | cost_samples += phi 248 | 249 | action_cost = self.gamma_ * self.noise @ self.noise_sigma_inv 250 | 251 | # action_cost -> K*T*nu 252 | # U -> T*nu 253 | perturbation_cost = torch.sum(self.U * action_cost, dim=(1, 2)) 254 | 255 | cost_total += cost_samples + perturbation_cost # K dim 256 | 257 | return cost_total, states 258 | 259 | def _bound_action(self, action): 260 | 261 | return torch.max(torch.min(action, self.u_max), self.u_min) # action 262 | 263 | def action_sampling(self): 264 | # parallelize sampling across trajectories 265 | # resample noise each time we take an action 266 | 267 | # A small portion (1%) of the samples are pure Gaussian perturbations around zero 268 | self.noise = self.noise_dist.sample( 269 | (round(self.K*0.99), self.T)) # K*T*nu (noise_dist has nu-dim) 270 | # broadcast own control to noise over samples; now it's K x T x nu 271 | perturbed_action = self.U + self.noise 272 | perturbed_action = torch.cat( 273 | [perturbed_action, self.noise_dist.sample((round(self.K*0.01), self.T))]) 274 | 275 | if self.sample_null_action: 276 | perturbed_action[self.K - 1] = 0 277 | 278 | perturbed_action = self._bound_action(perturbed_action) 279 | 280 | # subtract U to recover the bounded noise 281 | self.noise = perturbed_action - self.U 282 | 283 | return perturbed_action 284 | 285 | 286 | def angle_normalize(x): 287 | return (((x + math.pi) % (2 * math.pi)) - math.pi) 288 | 289 | def run_mppi_episode(mppi, env, dataset_append, retrain_dynamics, cost, model_save, cost_tolerance, SUCCESS_CRITERION, retrain_after_iter=50, num_episode=30, render=True): 290 | dataset_count = 0 291 | cost_history = [] 292 | cost_ = 0. 
293 | for ep in range(num_episode): 294 | env.reset() 295 | success_count = 0 296 | cost_episode = [] 297 | 298 | while True: 299 | if render: 300 | env.render() 301 | state = env.state 302 | state = torch.tensor(state, dtype=mppi.noise_sigma.dtype).to(device=mppi.device) 303 | command_start = time.perf_counter() 304 | action = mppi.command(state) 305 | elapsed = time.perf_counter() - command_start 306 | s, _, done, _ = env.step(action.cpu().numpy()) 307 | next_state = env.state 308 | next_state = torch.tensor(next_state, dtype=mppi.noise_sigma.dtype).to(device=mppi.device) 309 | 310 | # Collect training data 311 | dataset_append(state, action, next_state) 312 | 313 | logger.debug( 314 | "action taken: %.4f cost received: %.4f time taken: %.5fs", action, cost_, elapsed) 315 | 316 | dataset_count += 1 317 | di = dataset_count % retrain_after_iter 318 | if di == 0 and dataset_count > 0: 319 | retrain_dynamics() 320 | 321 | cost_ = cost(next_state.view(1, -1)) 322 | cost_episode.append(cost_.item()) 323 | 324 | if cost_ < cost_tolerance: 325 | success_count += 1 326 | if success_count >= SUCCESS_CRITERION: 327 | print("Task completed") 328 | cost_history.append(cost_episode) 329 | model_save() 330 | return cost_history 331 | else: 332 | success_count = 0 333 | 334 | if done: 335 | print("Episode {} terminated".format(ep + 1)) 336 | break 337 | wr.writerow(cost_episode) 338 | cost_history.append(cost_episode) 339 | return cost_history -------------------------------------------------------------------------------- /pytorch_mppi/smooth_mppi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import logging 4 | from torch.distributions.multivariate_normal import MultivariateNormal 5 | import numpy as np 6 | import math 7 | import csv 8 | 9 | f = open('smppi.csv', 'a', encoding='utf-8', newline='') 10 | wr = csv.writer(f) 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | def 
is_tensor_like(x): 
15 | return torch.is_tensor(x) or type(x) is np.ndarray 16 | 17 | 18 | class MPPI(): 19 | """ 20 | Model Predictive Path Integral control 21 | This implementation batch samples the trajectories and so scales well with the number of samples K. 22 | 23 | Implemented according to algorithm 2 in Williams et al., 2017 24 | 'Information Theoretic MPC for Model-Based Reinforcement Learning', 25 | based off of https://github.com/ferreirafabio/mppi_pendulum 26 | """ 27 | 28 | def __init__(self, dynamics, running_cost, nx, nu, noise_sigma, num_samples=100, horizon=15, device="cpu", 29 | terminal_state_cost=None, 30 | lambda_=1., 31 | gamma_=1., 32 | w_action_seq_cost=1., 33 | noise_mu=None, 34 | u_min=None, 35 | u_max=None, 36 | action_min=None, 37 | action_max=None, 38 | u_init=None, 39 | U_init=None, 40 | u_scale=1, 41 | u_per_command=1, 42 | step_dependent_dynamics=False, 43 | rollout_var_cost=0, 44 | rollout_var_discount=0.95, 45 | sample_null_action=False): 46 | """ 47 | :param dynamics: function(state, action) -> next_state (K x nx) taking in batch state (K x nx) and action (K x nu) 48 | :param running_cost: function(state) -> cost (K) taking in batch state (K x nx); the action is not passed in 49 | :param nx: state dimension 50 | :param noise_sigma: (nu x nu) control noise covariance (assume v_t ~ N(u_t, noise_sigma)) 51 | :param num_samples: K, number of trajectories to sample 52 | :param horizon: T, length of each trajectory 53 | :param device: pytorch device 54 | :param terminal_state_cost: function(states, actions) -> cost (K) taking in the rollout states (K x T x nx) and actions (K x T x nu) 55 | :param lambda_: temperature, positive scalar where larger values will allow more exploration 56 | :param gamma_: running action cost parameter 57 | :param w_action_seq_cost: scalar weight for the action sequence cost (multiplies the summed squared action differences) 58 | :param noise_mu: (nu) control noise mean (used to bias control samples); defaults to zero mean 59 | 
:param u_min: (nu) minimum values for each dimension of control 60 | :param u_max: (nu) 
maximum values for each dimension of control 61 | :param action_min: (nu) minimum values for each dimension of action to pass into dynamics 62 | :param action_max: (nu) maximum values for each dimension of action to pass into dynamics 63 | :param u_init: (nu) what to initialize new end of trajectory control to be; defaults to zero 64 | :param U_init: (T x nu) initial control sequence; defaults to zeros 65 | :param step_dependent_dynamics: whether the passed in dynamics needs horizon step passed in (as 3rd arg) 66 | :param rollout_var_cost: Cost attached to the variance of costs across trajectory rollouts 67 | :param rollout_var_discount: Discount of variance cost over control horizon 68 | :param sample_null_action: Whether to explicitly sample a null action (bad for starting in a local minimum) 69 | 70 | * option 71 | :param rollout_samples: M, number of state trajectories to rollout for each control trajectory 72 | (should be 1 for deterministic dynamics and more for models that output a distribution) 73 | """ 74 | 75 | self.device = device 76 | self.dtype = noise_sigma.dtype 77 | self.K = num_samples # N_SAMPLES 78 | self.T = horizon # TIMESTEPS 79 | 80 | # dimensions of state and control 81 | self.nx = nx 82 | self.nu = 1 if len(noise_sigma.shape) == 0 else noise_sigma.shape[0] 83 | self.lambda_ = lambda_ 84 | self.gamma_ = gamma_ 85 | 86 | self.w_action_seq_cost = w_action_seq_cost 87 | 88 | if noise_mu is None: 89 | noise_mu = torch.zeros(self.nu, dtype=self.dtype) 90 | 91 | if u_init is None: 92 | u_init = torch.zeros_like(noise_mu) 93 | 94 | if U_init is None: 95 | U_init = torch.zeros(self.T, self.nu).to(device) 96 | 97 | # handle 1D edge case 98 | if self.nu == 1: 99 | noise_mu = noise_mu.view(-1) 100 | noise_sigma = noise_sigma.view(-1, 1) 101 | 102 | # bounds 103 | self.u_min = u_min 104 | self.u_max = u_max 105 | self.action_min = action_min 106 | self.action_max = action_max 107 | self.u_scale = u_scale 108 | self.u_per_command = u_per_command 109 | # 
make sure if any of them is specified, both are specified 110 | if self.u_max is not None and self.u_min is None: 111 | if not torch.is_tensor(self.u_max): 112 | self.u_max = torch.tensor(self.u_max) 113 | self.u_min = -self.u_max 114 | if self.u_min is not None and self.u_max is None: 115 | if not torch.is_tensor(self.u_min): 116 | self.u_min = torch.tensor(self.u_min) 117 | self.u_max = -self.u_min 118 | if self.u_min is not None: 119 | self.u_min = self.u_min.to(device=self.device) 120 | self.u_max = self.u_max.to(device=self.device) 121 | self.action_min = self.action_min.to(device=self.device) 122 | self.action_max = self.action_max.to(device=self.device) 123 | 124 | self.noise_mu = noise_mu.to(self.device) 125 | self.noise_sigma = noise_sigma.to(self.device) 126 | self.noise_sigma_inv = torch.inverse(self.noise_sigma) 127 | self.noise_dist = MultivariateNormal( 128 | self.noise_mu, covariance_matrix=self.noise_sigma) 129 | # T x nu control sequence 130 | self.U = U_init 131 | self.action_sequence = U_init 132 | self.u_init = u_init.to(self.device) 133 | 134 | if self.U is None: 135 | self.U = self.noise_dist.sample((self.T,)) 136 | self.U = torch.zeros_like(self.U) 137 | self.action_sequence = torch.zeros_like(self.U) 138 | 139 | self.step_dependency = step_dependent_dynamics 140 | self.F = dynamics 141 | self.running_cost = running_cost 142 | self.terminal_state_cost = terminal_state_cost 143 | self.sample_null_action = sample_null_action 144 | self.state = None 145 | 146 | # handling dynamics models that output a distribution (take multiple trajectory samples) 147 | self.rollout_var_cost = rollout_var_cost 148 | self.rollout_var_discount = rollout_var_discount 149 | 150 | # sampled results from last command 151 | self.cost_total = None 152 | self.cost_total_non_zero = None 153 | self.omega = None 154 | self.states = None 155 | self.actions = None 156 | 157 | def _dynamics(self, state, u, t): 158 | return self.F(state, u, t) if self.step_dependency else 
self.F(state, u) 159 | 160 | def _running_cost(self, state): 161 | return self.running_cost(state) 162 | 163 | def command(self, state): 164 | """ 165 | :param state: (nx) or (K x nx) current state, or samples of states (for propagating a distribution of states) 166 | :returns action: (nu) best action 167 | """ 168 | # shift command 1 time step 169 | self.U = torch.roll(self.U, -1, dims=0) 170 | self.U[-1] = self.u_init 171 | self.action_sequence = torch.roll(self.action_sequence, -1, dims=0) 172 | self.action_sequence[-1] = self.action_sequence[-2] # add T-1 action to T 173 | 174 | perturbed_action = self.noise_sampling() 175 | 176 | cost_total, states = self._compute_batch_rollout_costs( 177 | perturbed_action, state) 178 | self.omega = self._compute_weighting(cost_total) 179 | 180 | weighted_noise = torch.sum( 181 | self.omega.view(-1, 1, 1) * self.noise, dim=0) 182 | self.U += weighted_noise 183 | 184 | self.action_sequence += self.U 185 | 186 | action = self.action_sequence[0] 187 | 188 | return action 189 | 190 | def reset(self): 191 | """ 192 | Clear controller state after finishing a trial 193 | """ 194 | self.U = torch.zeros_like(self.U) 195 | self.action_sequence = torch.zeros_like(self.U) 196 | 197 | def _compute_weighting(self, cost_total): 198 | beta = torch.min(cost_total) 199 | cost_total_non_zero = torch.exp(-1/self.lambda_ * (cost_total - beta)) 200 | eta = torch.sum(cost_total_non_zero) 201 | omega = (1. / eta) * cost_total_non_zero 202 | return omega 203 | 204 | def _compute_batch_rollout_costs(self, perturbed_actions, state): 205 | K, T, nu = perturbed_actions.shape 206 | assert nu == self.nu 207 | 208 | cost_total = torch.zeros(K, device=self.device, dtype=self.dtype) 209 | cost_samples = torch.zeros(K, device=self.device, dtype=self.dtype) 210 | 211 | # allow propagation of a sample of states (ex. 
to carry a distribution), or to start with a single state 212 | # state -> nx 213 | if state.shape == (K, self.nx): 214 | state = state 215 | else: 216 | state = state.view(1, -1).repeat(K, 1) 217 | # state -> K*nx 218 | 219 | states = [] 220 | actions = [] 221 | 222 | for t in range(T): 223 | # perturbed_actions -> K*T*nu 224 | # perturbed_actions[:, t] -> K*nu 225 | action = self.u_scale * perturbed_actions[:, t] # action -> K*nu 226 | 227 | state = self._dynamics(state, action, t) 228 | c = self._running_cost(state) # c -> K 229 | 230 | cost_samples += c # cost_samples -> K 231 | 232 | # Save all states/actions 233 | states.append(state) 234 | actions.append(action) 235 | 236 | # actions -> [K*nu, K*nu ...] with size T 237 | # torch.stack(actions, dim=-2) -> K*T*nu 238 | actions = torch.stack(actions, dim=-2) 239 | states = torch.stack(states, dim=-2) 240 | 241 | # terminal state cost 242 | if self.terminal_state_cost: 243 | phi = self.terminal_state_cost(states, actions) 244 | cost_samples += phi 245 | 246 | control_cost = self.gamma_ * self.noise @ self.noise_sigma_inv 247 | 248 | # control_cost -> K*T*nu 249 | # U -> T*nu 250 | control_cost = torch.sum(self.U * control_cost, dim=(1, 2)) 251 | 252 | # action difference as cost 253 | action_diff = self.u_scale * \ 254 | (perturbed_actions[:, 1:] - perturbed_actions[:, :-1]) 255 | action_sequence_cost = torch.sum(torch.square(action_diff), dim=(1, 2)) 256 | action_sequence_cost *= self.w_action_seq_cost 257 | 258 | cost_total = cost_samples + control_cost + action_sequence_cost # K dim 259 | 260 | return cost_total, states 261 | 262 | def _bound_d_action(self, control): 263 | return torch.max(torch.min(control, self.u_max), self.u_min) # derivative action (= control) 264 | 265 | def _bound_action(self, action): 266 | return torch.max(torch.min(action, self.action_max), self.action_min) # action 267 | 268 | def noise_sampling(self): 269 | # parallelize sampling across trajectories 270 | # resample noise each 
time we 
take an action 271 | 272 | # A small portion (1%) of the samples are pure Gaussian perturbations around zero 273 | self.noise = self.noise_dist.sample( 274 | (round(self.K*0.99), self.T)) # K*T*nu (noise_dist has nu-dim) 275 | # broadcast own control to noise over samples; now it's K x T x nu 276 | perturbed_control = self.U + self.noise 277 | perturbed_control = torch.cat( 278 | [perturbed_control, self.noise_dist.sample((round(self.K*0.01), self.T))]) 279 | 280 | perturbed_control = self._bound_d_action(perturbed_control) 281 | 282 | perturbed_action = perturbed_control + self.action_sequence 283 | 284 | if self.sample_null_action: 285 | perturbed_action[self.K - 1] = 0 286 | 287 | perturbed_action = self._bound_action(perturbed_action) 288 | 289 | # subtract the action sequence and U to recover the doubly bounded noise 290 | self.noise = perturbed_action - self.action_sequence - self.U 291 | 292 | return perturbed_action 293 | 294 | 295 | def angle_normalize(x): 296 | return (((x + math.pi) % (2 * math.pi)) - math.pi) 297 | 298 | def run_mppi_episode(mppi, env, dataset_append, retrain_dynamics, cost, model_save, cost_tolerance, SUCCESS_CRITERION, retrain_after_iter=50, num_episode=30, render=True): 299 | dataset_count = 0 300 | cost_history = [] 301 | cost_ = 0. 
302 | for ep in range(num_episode): 303 | env.reset() 304 | success_count = 0 305 | cost_episode = [] 306 | 307 | while True: 308 | if render: 309 | env.render() 310 | state = env.state 311 | state = torch.tensor(state, dtype=mppi.noise_sigma.dtype).to(device=mppi.device) 312 | command_start = time.perf_counter() 313 | action = mppi.command(state) 314 | elapsed = time.perf_counter() - command_start 315 | s, _, done, _ = env.step(action.cpu().numpy()) 316 | next_state = env.state 317 | next_state = torch.tensor(next_state, dtype=mppi.noise_sigma.dtype).to(device=mppi.device) 318 | 319 | # Collect training data 320 | dataset_append(state, action, next_state) 321 | 322 | logger.debug( 323 | "action taken: %.4f cost received: %.4f time taken: %.5fs", action, cost_, elapsed) 324 | 325 | dataset_count += 1 326 | di = dataset_count % retrain_after_iter 327 | if di == 0 and dataset_count > 0: 328 | retrain_dynamics() 329 | 330 | cost_ = cost(next_state.view(1, -1)) 331 | cost_episode.append(cost_.item()) 332 | 333 | if cost_ < cost_tolerance: 334 | success_count += 1 335 | if success_count >= SUCCESS_CRITERION: 336 | print("Task completed") 337 | cost_history.append(cost_episode) 338 | model_save() 339 | return cost_history 340 | else: 341 | success_count = 0 342 | 343 | if done: 344 | print("Episode {} terminated".format(ep + 1)) 345 | break 346 | wr.writerow(cost_episode) 347 | cost_history.append(cost_episode) 348 | return cost_history 349 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='pytorch_mppi', 5 | version='0.1.0', 6 | packages=['pytorch_mppi'], 7 | url='https://github.com/ktk1501/smooth-mppi-pytorch', 8 | license='MIT', 9 | author='Taekyung Kim', 10 | author_email='ktk1501@kakao.com', 11 | description='Smooth Model Predictive Path Integral without Smoothing (SMPPI) implemented in
pytorch', 12 | install_requires=[ 13 | 'torch', 14 | 'numpy', 15 | 'scipy' 16 | ], 17 | tests_require=[ 18 | 'gym' 19 | ] 20 | ) 21 | --------------------------------------------------------------------------------
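The `noise_sampling` method in the controller source listed before `setup.py` bounds the perturbed derivative action first, then bounds the resulting action, and finally recomputes the noise that actually took effect. The following is a minimal scalar sketch of that sequence, written in plain Python rather than torch. The helper names `clamp` and `double_bounded_noise` and all numeric values are illustrative assumptions, not part of the repository.

```python
# Scalar sketch of the "double bounded noise" step; all names here are hypothetical.

def clamp(value, lower, upper):
    """Clamp value into [lower, upper], like torch.max(torch.min(x, upper), lower)."""
    return max(min(value, upper), lower)


def double_bounded_noise(action, u, noise, u_min, u_max, action_min, action_max):
    # 1) Bound the perturbed derivative action (the control).
    control = clamp(u + noise, u_min, u_max)
    # 2) Add it to the current action and bound the result.
    perturbed_action = clamp(action + control, action_min, action_max)
    # 3) Recover the noise that survived both bounds.
    effective_noise = perturbed_action - action - u
    return perturbed_action, effective_noise


perturbed_action, effective_noise = double_bounded_noise(
    action=1.5, u=0.2, noise=1.0,
    u_min=-0.5, u_max=0.5, action_min=-2.0, action_max=2.0)
print(perturbed_action, effective_noise)
```

With these assumed values, the sampled noise of 1.0 is cut down to roughly 0.3 by the control bound, so a cost term computed from the stored noise reflects the perturbation that was actually applied rather than the one that was drawn.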