├── L4DCMain.png ├── README.md ├── experiments ├── D_comp │ ├── main.py │ ├── plot.py │ └── results.npy └── discont_comp │ ├── main.py │ ├── plot.py │ └── plot_num_obs.py ├── main.py ├── mpc ├── cem.py ├── grad.py ├── gradcem.py ├── svgd.py └── test_energy.py └── setup.py /L4DCMain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/homangab/gradcem/02a8b36269704ab7e4c1207b6420cc788286fd67/L4DCMain.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | Model-Predictive Control via Cross-Entropy and Gradient-Based Optimization 3 | ====== 4 | 5 | Code accompanying the paper Model-Predictive Control via Cross-Entropy and Gradient-Based Optimization by [Homanga Bharadhwaj](https://homangab.github.io), Kevin Xie, and [Florian Shkurti](http://www.cs.toronto.edu/~florian/) (the first two authors contributed equally). To be presented at [L4DC 2020](https://sites.google.com/berkeley.edu/l4dc/accepted-papers). 6 | 7 | ![Overall description](L4DCMain.png) 8 | 9 | Requirements 10 | ------------ 11 | 12 | - Python 3 13 | - [DeepMind Control Suite](https://github.com/deepmind/dm_control) 14 | - [Gym](https://gym.openai.com/) 15 | - [OpenCV Python](https://pypi.python.org/pypi/opencv-python) 16 | - [MuJoCo](http://www.mujoco.org/) 17 | - [PyTorch](http://pytorch.org/) 18 | 19 | 20 | Instructions 21 | ------------ 22 | 23 | Run `python main.py` in the folder `experiments/D_comp`, followed by `python plot.py`, to reproduce Fig. 3a in the paper. 24 | 25 | Run `python main.py` in the folder `experiments/discont_comp`, followed by `python plot_num_obs.py`, to reproduce Fig. 3b in the paper. 26 | 27 | Contributors 28 | ------------ 29 | @homangab and @kevincxie (equal contribution) 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /experiments/D_comp/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from mpc.test_energy import test_energy2d, NavigateGTEnv 4 | import matplotlib.pyplot as plt 5 | from mpc.cem import CEM 6 | from mpc.grad import GradPlan 7 | from mpc.gradcem import GradCEMPlan 8 | from tqdm import tqdm 9 | 10 | def run(planner, env): 11 | planner.set_env(env) 12 | plan = planner.forward(1, return_plan=True) 13 | env.reset_state(1) 14 | r = env.rollout(plan) 15 | # print(r.size()) 16 | return r.item() 17 | 18 | def run_mult(planner, env, N): 19 | rs = [] 20 | for i in tqdm(range(N)): 21 | rs.append(run(planner,env)) 22 | rs = np.array(rs) 23 | m = np.mean(rs) 24 | std = np.std(rs) 25 | return m, std 26 | 27 | if __name__ == "__main__": 28 | # Compare CEM to GradPlan as we increase D 29 | comp_device = torch.device('cuda:0') 30 | H = 50 31 | 32 | K = 20 33 | tK = 4 34 | opt_iter = 10 35 | cem_planner = CEM(H, opt_iter, K, tK, None, device=comp_device) 36 | 37 | # K = 20 38 | # opt_iter = 10 39 | grad_planner = GradPlan(H, opt_iter, K, None, device=comp_device) 40 | 41 | gradcem_planner = GradCEMPlan(H, opt_iter, K, tK, None, device=comp_device) 42 | 43 | B = 1 44 | grad_return = [] 45 | cem_return = [] 46 | grad_std = [] 47 | cem_std = [] 48 | 49 | gradcem_return = [] 50 | gradcem_std = [] 51 | for D in range(2,21): 52 | print(D) 53 | env = NavigateGTEnv(B, D, test_energy2d, comp_device, control='force', sparse_r_step=H-1, dt=2.5/H) 54 | 55 | print("Running grad+cem planner") 56 | m, std
= run_mult(gradcem_planner, env,20) 57 | gradcem_return.append(m) 58 | gradcem_std.append(std) 59 | 60 | m, std = run_mult(grad_planner, env,20) 61 | grad_return.append(m) 62 | grad_std.append(std) 63 | 64 | m, std = run_mult(cem_planner, env,20) 65 | cem_return.append(m) 66 | cem_std.append(std) 67 | 68 | # print(grad_return) 69 | # print(grad_std) 70 | # print(cem_return) 71 | # print(cem_std) 72 | grad_return = np.array(grad_return) 73 | grad_std = np.array(grad_std) 74 | cem_return = np.array(cem_return) 75 | cem_std = np.array(cem_std) 76 | gradcem_return = np.array(gradcem_return) 77 | gradcem_std = np.array(gradcem_std) 78 | # results = np.stack((grad_return, grad_std, cem_return, cem_std)) 79 | results = np.stack((grad_return, grad_std, cem_return, cem_std, gradcem_return, gradcem_std)) 80 | 81 | np.save("results_new.npy", results) 82 | -------------------------------------------------------------------------------- /experiments/D_comp/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | results = np.load('results_new.npy') 5 | x = np.arange(2, 21) 6 | g_ret = results[0] 7 | g_std = results[1] 8 | c_ret = results[2] 9 | c_std = results[3] 10 | gc_ret = results[4] 11 | gc_std = results[5] 12 | 13 | plt.errorbar(x, g_ret, yerr=g_std, label="Grad") 14 | plt.errorbar(x, c_ret, yerr=c_std, label="CEM") 15 | plt.errorbar(x, gc_ret, yerr=gc_std, label="Grad+CEM") 16 | plt.legend() 17 | plt.xticks(x) 18 | plt.ylabel("Total Reward") 19 | plt.xlabel("Action Dimensionality") 20 | plt.savefig("./test20D_new.png") 21 | -------------------------------------------------------------------------------- /experiments/D_comp/results.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/homangab/gradcem/02a8b36269704ab7e4c1207b6420cc788286fd67/experiments/D_comp/results.npy -------------------------------------------------------------------------------- /experiments/discont_comp/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from mpc.test_energy import test_energy2d, NavigateGTEnv 4 | import matplotlib.pyplot as plt 5 | from mpc.cem import CEM 6 | from mpc.grad import GradPlan 7 | from mpc.gradcem import GradCEMPlan 8 | from tqdm import tqdm 9 | 10 | def run(planner, env): 11 | planner.set_env(env) 12 | plan = planner.forward(1, return_plan=True) 13 | env.reset_state(1) 14 | r = env.rollout(plan) 15 | # print(r.size()) 16 | return r.item() 17 | 18 | def run_mult(planner, env, N): 19 | rs = [] 20 | for i in tqdm(range(N)): 21 | rs.append(run(planner,env)) 22 | rs = np.array(rs) 23 | m = np.mean(rs) 24 | std = np.std(rs) 25 | return m, std 26 | 27 | if __name__ == "__main__": 28 | # Compare CEM to GradPlan as we increase D 29 | comp_device = torch.device('cuda:0') 30 | H = 50 31 | dt = 2.5/H 32 | 33 | K = 20 34 | tK = 4 35 | opt_iter = 10 36 | cem_planner = CEM(H, opt_iter, K, tK, None, device=comp_device) 37 | 38 | # K = 10 39 | # opt_iter = 20 40 | grad_planner = GradPlan(H, opt_iter, K, None, device=comp_device) 41 | 42 | # K = 10 43 | # tK = 4 44 | # opt_iter = 20 45 | gradcem_planner = GradCEMPlan(H, opt_iter, K, tK, None, device=comp_device) 46 | 47 | 48 | B = 1 49 | grad_return = [] 50 | cem_return = [] 51 | gradcem_return = [] 52 | grad_std = [] 53 | cem_std = [] 54 | gradcem_std = [] 55 | # for D in range(2,20): 56 | D = 2 57 | for num_obs in 
range(5,11): 58 | print(D) 59 | np.random.seed(0) 60 | env = NavigateGTEnv(B, D, test_energy2d, comp_device, control='force', sparse_r_step=H-1, dt=dt, obstacles_env=True, num_obs = num_obs) 61 | print("Running grad+cem planner") 62 | m, std = run_mult(gradcem_planner, env,50) 63 | gradcem_return.append(m) 64 | gradcem_std.append(std) 65 | 66 | np.random.seed(0) 67 | env = NavigateGTEnv(B, D, test_energy2d, comp_device, control='force', sparse_r_step=H-1, dt=dt, obstacles_env=True, num_obs = num_obs) 68 | print("Running grad planner") 69 | m, std = run_mult(grad_planner, env,50) 70 | grad_return.append(m) 71 | grad_std.append(std) 72 | 73 | np.random.seed(0) 74 | env = NavigateGTEnv(B, D, test_energy2d, comp_device, control='force', sparse_r_step=H-1, dt=dt, obstacles_env=True, num_obs = num_obs) 75 | print("Running cem planner") 76 | m, std = run_mult(cem_planner, env,50) 77 | cem_return.append(m) 78 | cem_std.append(std) 79 | 80 | 81 | 82 | # print(grad_return) 83 | # print(grad_std) 84 | # print(cem_return) 85 | # print(cem_std) 86 | # print(gradcem_return) 87 | # print(gradcem_std) 88 | grad_return = np.array(grad_return) 89 | grad_std = np.array(grad_std) 90 | cem_return = np.array(cem_return) 91 | cem_std = np.array(cem_std) 92 | gradcem_return = np.array(gradcem_return) 93 | gradcem_std = np.array(gradcem_std) 94 | results = np.stack((grad_return, grad_std, cem_return, cem_std, gradcem_return, gradcem_std)) 95 | 96 | np.save("results_new_num_obs1.npy", results) 97 | -------------------------------------------------------------------------------- /experiments/discont_comp/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | results = np.load('results_new.npy') 5 | x = np.arange(2,4) 6 | g_ret = results[0] 7 | g_std = results[1] 8 | c_ret = results[2] 9 | c_std = results[3] 10 | cg_ret = results[4] 11 | cg_std = results[5] 12 | 13 | plt.errorbar(x, g_ret, yerr=g_std, label="Grad", fmt='o', capsize=2) 14 | plt.errorbar(x, c_ret, yerr=c_std, label="CEM", fmt='o', capsize=2) 15 | plt.errorbar(x, cg_ret, yerr=cg_std, label="Grad+CEM", fmt='o', capsize=2) 16 | plt.legend() 17 | plt.xticks(x) 18 | plt.ylabel("Total Reward") 19 | plt.xlabel("Action Dimensionality") 20 | plt.savefig("./comp_plot.png") 21 | -------------------------------------------------------------------------------- /experiments/discont_comp/plot_num_obs.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | results = np.load('results_new_num_obs1.npy') 5 | x = np.arange(5,8) 6 | g_ret = results[0][0:3] 7 | g_std = results[1][0:3] 8 | c_ret = results[2][0:3] 9 | c_std = results[3][0:3] 10 | cg_ret = results[4][0:3] 11 | cg_std = results[5][0:3] 12 | print(g_ret) 13 | 14 | plt.errorbar(x, g_ret, yerr=g_std, label="Grad", fmt='-o', capsize=2) 15 | plt.errorbar(x, c_ret, yerr=c_std, label="CEM", fmt='-o', capsize=2) 16 | plt.errorbar(x, cg_ret, yerr=cg_std, label="Grad+CEM", fmt='-o', capsize=2) 17 | plt.legend() 18 | plt.xticks(x, x**2) 19 | plt.ylabel("Total Reward") 20 | plt.xlabel("Number of Obstacles") 21 | plt.savefig("./comp_plot_num_obs.png") 22 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from mpc.test_energy import test_energy2d, NavigateGTEnv 4 |
import matplotlib.pyplot as plt 5 | from mpc.cem import CEM 6 | from mpc.grad import GradPlan 7 | from mpc.gradcem import GradCEMPlan 8 | from mpc.svgd import SVGDPlan 9 | 10 | if __name__ == "__main__": 11 | 12 | # config = { 13 | # "planner": "CEM" 14 | # } 15 | # config = { 16 | # "planner": "GradPlan" 17 | # } 18 | # config = { 19 | # "planner": "GradCEMPlan" 20 | # } 21 | config = { 22 | "planner": "SVGDPlan" 23 | } 24 | 25 | B = 1 26 | # comp_device = torch.device('cuda:0') 27 | comp_device = torch.device('cpu') 28 | H = 70 29 | env = NavigateGTEnv(B, 3, test_energy2d, comp_device, control='force', sparse_r_step=H-1, dt=2.0/H, obstacles_env=True) 30 | 31 | planner = None 32 | if config["planner"] == "CEM": 33 | K = 100 34 | tK = 20 35 | opt_iter = 10 36 | planner = CEM(H, opt_iter, K, tK, env, device=comp_device) 37 | elif config["planner"] == "GradCEMPlan": 38 | K = 100 39 | tK = 20 40 | opt_iter = 10 41 | planner = GradCEMPlan(H, opt_iter, K, tK, env, device=comp_device) 42 | elif config["planner"] == "GradPlan": 43 | K = 20 44 | opt_iter = 10 45 | planner = GradPlan(H, opt_iter, K, env, device=comp_device) 46 | elif config["planner"] == "SVGDPlan": 47 | K = 20 48 | opt_iter = 10 49 | planner = SVGDPlan(H, opt_iter, K, env, device=comp_device) 50 | 51 | # actions = planner.forward(B, return_plan=True) 52 | plans = planner.forward(B, return_plan_each_iter=True) 53 | # actions = actions.cpu().numpy() 54 | 55 | # Visualize 56 | import matplotlib.pyplot as plt 57 | fig, ax = plt.subplots() 58 | 59 | ax.set_aspect('equal') 60 | 61 | env.reset_state(1) 62 | env.draw_env_2d_proj(ax) 63 | 64 | for actions in plans[0:]: 65 | env.reset_state(1) 66 | rs, ss = env.rollout(actions, return_traj=True) 67 | # ps = [s[0].cpu().numpy() for s in ss] 68 | # ps = np.array(ps) 69 | # ps = ps.squeeze() 70 | env.draw_traj_2d_proj(ax, ss) 71 | 72 | ax.set_xlim(-0.6, 1.6) 73 | ax.set_ylim(-0.6, 1.6) 74 | 75 | plt.savefig("./experiments/tmp/main.png") 76 | -------------------------------------------------------------------------------- /mpc/cem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import jit 3 | from torch import nn, optim 4 | 5 | # class CEM: 6 | # def __init__(self, candidates, top_candidates): 7 | 8 | 9 | class CEM(): # jit.ScriptModule): 10 | def __init__(self, planning_horizon, opt_iters, samples, top_samples, env, device): 11 | super().__init__() 12 | self.set_env(env) 13 | self.H = planning_horizon 14 | self.opt_iters = opt_iters 15 | self.K, self.top_K = samples, top_samples 16 | self.device = device 17 | 18 | def set_env(self, env): 19 | self.env = env 20 | if self.env is not None: 21 | self.a_size = env.a_size 22 | 23 | # @jit.script_method 24 | def forward(self, batch_size, return_plan=False, return_plan_each_iter=False): 25 | # Here batch is strictly if multiple CEMs should be performed! 
26 | B = batch_size 27 | 28 | # Initialize factorized belief over action sequences q(a_t:t+H) ~ N(0, I) 29 | a_mu = torch.zeros(self.H, B, 1, self.a_size, device=self.device) 30 | a_std = torch.ones(self.H, B, 1, self.a_size, device=self.device) 31 | 32 | plan_each_iter = [] 33 | for _ in range(self.opt_iters): 34 | self.env.reset_state(B*self.K) 35 | # Evaluate J action sequences from the current belief (over entire sequence at once, batched over particles) 36 | # Sample actions (T x (B*K) x A) 37 | actions = (a_mu + a_std * torch.randn(self.H, B, self.K, self.a_size, device=self.device)).view(self.H, B * self.K, self.a_size) 38 | 39 | # Returns (B*K) 40 | returns = self.env.rollout(actions) 41 | 42 | # Re-fit belief to the K best action sequences 43 | _, topk = returns.reshape(B, self.K).topk(self.top_K, dim=1, largest=True, sorted=False) 44 | topk += self.K * torch.arange(0, B, dtype=torch.int64, device=topk.device).unsqueeze(dim=1) 45 | best_actions = actions[:, topk.view(-1)].reshape(self.H, B, self.top_K, self.a_size) 46 | # Update belief with new means and standard deviations 47 | a_mu = best_actions.mean(dim=2, keepdim=True) 48 | a_std = best_actions.std(dim=2, unbiased=False, keepdim=True) 49 | 50 | if return_plan_each_iter: 51 | _, topk = returns.reshape(B, self.K).topk(1, dim=1, largest=True, sorted=False) 52 | best_plan = actions[:, topk[0]].reshape(self.H, B, self.a_size).detach() 53 | plan_each_iter.append(best_plan.data.clone()) 54 | # plan_each_iter.append(a_mu.squeeze(dim=2).data.clone()) 55 | 56 | if return_plan_each_iter: 57 | return plan_each_iter 58 | if return_plan: 59 | return a_mu.squeeze(dim=2) 60 | else: 61 | # Return first action mean µ_t 62 | return a_mu.squeeze(dim=2)[0] 63 | 64 | if __name__ == "__main__": 65 | from test_energy import get_test_energy2d_env 66 | B = 1 67 | K = 100 68 | tK = 10 69 | t_env = get_test_energy2d_env(B*K) 70 | H = 1 71 | planner = CEM(H, 10, K, tK, t_env, device=torch.device('cpu')) 72 | action = planner.forward(B) 73 | action = action.cpu().numpy() 74 | 75 | import matplotlib.pyplot as plt 76 | N = 30 77 | x = torch.linspace(-1,1,N) 78 | y = torch.linspace(-1,1,N) 79 | X, Y = torch.meshgrid(x,y) 80 | actions_grid = torch.stack((X,Y),dim=-1) 81 | # print(actions_grid) 82 | energies = t_env.func(actions_grid.reshape(-1,2)) 83 | 84 | plt.pcolormesh(X.numpy(), Y.numpy(), -energies.reshape(N,N).numpy(), cmap="coolwarm") 85 | plt.contour(X.numpy(), Y.numpy(), -energies.reshape(N,N).numpy(), cmap="seismic") 86 | plt.scatter(action[:, 0], action[:, 1]) 87 | plt.show() 88 | -------------------------------------------------------------------------------- /mpc/grad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import jit 3 | from torch import nn, optim 4 | 5 | 6 | class GradPlan(): # jit.ScriptModule): 7 | def __init__(self, planning_horizon, opt_iters, samples, env, device, grad_clip=True): 8 | super().__init__() 9 | self.set_env(env) 10 | self.H = planning_horizon 11 | self.opt_iters = opt_iters 12 | self.K = samples 13 | self.device = device 14 | self.grad_clip = grad_clip 15 | 16 | def set_env(self, env): 17 | self.env = env 18 | if self.env is not None: 19 | self.a_size = env.a_size 20 | 21 | # @jit.script_method 22 | def forward(self, batch_size, return_plan=False, return_plan_each_iter=False): 23 | # Here batch is strictly if multiple Plans should be performed! 
24 | B = batch_size 25 | 26 | # Initialize factorized belief over action sequences q(a_t:t+H) ~ N(0, I) 27 | a_mu = torch.zeros(self.H, B, 1, self.a_size, device=self.device) 28 | a_std = torch.ones(self.H, B, 1, self.a_size, device=self.device) 29 | 30 | # Sample actions (T x (B*K) x A) 31 | actions = (a_mu + a_std * torch.randn(self.H, B, self.K, self.a_size, device=self.device)).view(self.H, B * self.K, self.a_size) 32 | # TODO: debug 33 | # actions = actions*0 34 | actions = torch.tensor(actions, requires_grad=True) 35 | 36 | # optimizer = optim.SGD([actions], lr=0.1, momentum=0) 37 | optimizer = optim.RMSprop([actions], lr=0.1) 38 | plan_each_iter = [] 39 | for _ in range(self.opt_iters): 40 | self.env.reset_state(B*self.K) 41 | 42 | optimizer.zero_grad() 43 | 44 | # Returns (B*K) 45 | returns = self.env.rollout(actions) 46 | tot_returns = returns.sum() 47 | (-tot_returns).backward() 48 | 49 | # print(actions.grad.size()) 50 | 51 | # grad clip 52 | # Find norm across batch 53 | if self.grad_clip: 54 | epsilon = 1e-6 55 | max_grad_norm = 1.0 56 | actions_grad_norm = actions.grad.norm(2.0,dim=2,keepdim=True)+epsilon 57 | # print("before clip", actions.grad.max().cpu().numpy()) 58 | 59 | # Normalize by that 60 | actions.grad.data.div_(actions_grad_norm) 61 | actions.grad.data.mul_(actions_grad_norm.clamp(min=0, max=max_grad_norm)) 62 | # print("after clip", actions.grad.max().cpu().numpy()) 63 | 64 | # print(actions.grad) 65 | 66 | optimizer.step() 67 | 68 | if return_plan_each_iter: 69 | _, topk = returns.reshape(B, self.K).topk(1, dim=1, largest=True, sorted=False) 70 | best_plan = actions[:, topk[0]].reshape(self.H, B, self.a_size).detach() 71 | plan_each_iter.append(best_plan.data.clone()) 72 | 73 | actions = actions.detach() 74 | # Re-fit belief to the K best action sequences 75 | _, topk = returns.reshape(B, self.K).topk(1, dim=1, largest=True, sorted=False) 76 | best_plan = actions[:, topk[0]].reshape(self.H, B, self.a_size) 77 | 78 | if return_plan_each_iter: 79 | return plan_each_iter 80 | if return_plan: 81 | return best_plan 82 | else: 83 | return best_plan[0] 84 | 85 | if __name__ == "__main__": 86 | from test_energy import get_test_energy2d_env 87 | 88 | # torch.manual_seed(0) 89 | 90 | B = 1 91 | K = 1 92 | t_env = get_test_energy2d_env(B*K) 93 | H = 1 94 | planner = GradPlan(H, 10, K, t_env, device=torch.device('cpu')) 95 | action = planner.forward(B) 96 | action = action.cpu().numpy() 97 | 98 | import matplotlib.pyplot as plt 99 | N = 30 100 | x = torch.linspace(-1,1,N) 101 | y = torch.linspace(-1,1,N) 102 | X, Y = torch.meshgrid(x,y) 103 | actions_grid = torch.stack((X,Y),dim=-1) 104 | # print(actions_grid) 105 | energies = t_env.func(actions_grid.reshape(-1,2)) 106 | 107 | plt.pcolormesh(X.numpy(), Y.numpy(), -energies.reshape(N,N).numpy(), cmap="coolwarm") 108 | plt.contour(X.numpy(), Y.numpy(), -energies.reshape(N,N).numpy(), cmap="seismic") 109 | plt.scatter(action[:, 0], action[:, 1]) 110 | plt.show() 111 | -------------------------------------------------------------------------------- /mpc/gradcem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import jit 3 | from torch import nn, optim 4 | 5 | 6 | class GradCEMPlan(): # jit.ScriptModule): 7 | def __init__(self, planning_horizon, opt_iters, samples, top_samples, env, device, grad_clip=True): 8 | super().__init__() 9 | self.set_env(env) 10 | self.H = planning_horizon 11 | self.opt_iters = opt_iters 12 | self.K = samples 13 | self.top_K = 
top_samples 14 | self.device = device 15 | self.grad_clip = grad_clip 16 | 17 | def set_env(self, env): 18 | self.env = env 19 | if self.env is not None: 20 | self.a_size = env.a_size 21 | 22 | # @jit.script_method 23 | def forward(self, batch_size, return_plan=False, return_plan_each_iter=False): 24 | # Here batch is strictly if multiple Plans should be performed! 25 | B = batch_size 26 | 27 | # Initialize factorized belief over action sequences q(a_t:t+H) ~ N(0, I) 28 | a_mu = torch.zeros(self.H, B, 1, self.a_size, device=self.device) 29 | a_std = torch.ones(self.H, B, 1, self.a_size, device=self.device) 30 | 31 | # Sample actions (T x (B*K) x A) 32 | actions = (a_mu + a_std * torch.randn(self.H, B, self.K, self.a_size, device=self.device)).view(self.H, B * self.K, self.a_size) 33 | actions = torch.tensor(actions, requires_grad=True) 34 | 35 | # optimizer = optim.SGD([actions], lr=0.1, momentum=0) 36 | optimizer = optim.RMSprop([actions], lr=0.1) 37 | plan_each_iter = [] 38 | for _ in range(self.opt_iters): 39 | self.env.reset_state(B*self.K) 40 | 41 | optimizer.zero_grad() 42 | 43 | # Returns (B*K) 44 | returns = self.env.rollout(actions) 45 | tot_returns = returns.sum() 46 | (-tot_returns).backward() 47 | 48 | # grad clip 49 | # Find norm across batch 50 | if self.grad_clip: 51 | epsilon = 1e-6 52 | max_grad_norm = 1.0 53 | actions_grad_norm = actions.grad.norm(2.0,dim=2,keepdim=True)+epsilon 54 | # print("before clip", actions.grad.max().cpu().numpy()) 55 | 56 | # Normalize by that 57 | actions.grad.data.div_(actions_grad_norm) 58 | actions.grad.data.mul_(actions_grad_norm.clamp(min=0, max=max_grad_norm)) 59 | # print("after clip", actions.grad.max().cpu().numpy()) 60 | 61 | optimizer.step() 62 | 63 | _, topk = returns.reshape(B, self.K).topk(self.top_K, dim=1, largest=True, sorted=False) 64 | topk += self.K * torch.arange(0, B, dtype=torch.int64, device=topk.device).unsqueeze(dim=1) 65 | best_actions = actions[:, topk.view(-1)].reshape(self.H, B, self.top_K, self.a_size) 66 | a_mu = best_actions.mean(dim=2, keepdim=True) 67 | a_std = best_actions.std(dim=2, unbiased=False, keepdim=True) 68 | 69 | if return_plan_each_iter: 70 | _, topk = returns.reshape(B, self.K).topk(1, dim=1, largest=True, sorted=False) 71 | best_plan = actions[:, topk[0]].reshape(self.H, B, self.a_size).detach() 72 | plan_each_iter.append(best_plan.data.clone()) 73 | 74 | # There must be cleaner way to do this 75 | k_resamp = self.K-self.top_K 76 | _, botn_k = returns.reshape(B, self.K).topk(k_resamp, dim=1, largest=False, sorted=False) 77 | botn_k += self.K * torch.arange(0, B, dtype=torch.int64, device=self.device).unsqueeze(dim=1) 78 | 79 | resample_actions = (a_mu + a_std * torch.randn(self.H, B, k_resamp, self.a_size, device=self.device)).view(self.H, B * k_resamp, self.a_size) 80 | actions.data[:, botn_k.view(-1)] = resample_actions.data 81 | 82 | actions = actions.detach() 83 | # Re-fit belief to the K best action sequences 84 | _, topk = returns.reshape(B, self.K).topk(1, dim=1, largest=True, sorted=False) 85 | best_plan = actions[:, topk[0]].reshape(self.H, B, self.a_size) 86 | 87 | if return_plan_each_iter: 88 | return plan_each_iter 89 | if return_plan: 90 | return best_plan 91 | else: 92 | return best_plan[0] 93 | 94 | if __name__ == "__main__": 95 | from test_energy import get_test_energy2d_env 96 | B = 1 97 | K = 100 98 | top_K = 10 99 | t_env = get_test_energy2d_env(B*K) 100 | H = 1 101 | planner = GradCEMPlan(H, 10, K, top_K, t_env, device=torch.device('cpu')) 102 | action = planner.forward(B) 103 
| action = action.cpu().numpy() 104 | 105 | import matplotlib.pyplot as plt 106 | N = 30 107 | x = torch.linspace(-1,1,N) 108 | y = torch.linspace(-1,1,N) 109 | X, Y = torch.meshgrid(x,y) 110 | actions_grid = torch.stack((X,Y),dim=-1) 111 | # print(actions_grid) 112 | energies = t_env.func(actions_grid.reshape(-1,2)) 113 | 114 | plt.pcolormesh(X.numpy(), Y.numpy(), -energies.reshape(N,N).numpy(), cmap="coolwarm") 115 | plt.contour(X.numpy(), Y.numpy(), -energies.reshape(N,N).numpy(), cmap="seismic") 116 | plt.scatter(action[:, 0], action[:, 1]) 117 | plt.show() 118 | -------------------------------------------------------------------------------- /mpc/svgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import jit 3 | from torch import nn, optim, autograd 4 | 5 | import numpy as np 6 | 7 | def npy(tensor): 8 | return tensor.cpu().detach().numpy() 9 | 10 | def squared_dist(x): 11 | assert len(x.size()) == 2 12 | norm = (x ** 2).sum(1).view(-1, 1) 13 | dn = (norm + norm.view(1, -1)) - 2.0 * (x @ x.t()) 14 | return dn 15 | 16 | def rbf_kernel(x, h=None): 17 | """ 18 | Returns the full kernel matrix for input x 19 | x: NxC 20 | Output 21 | K: NxN where Kij = k(x_i, x_j) 22 | dK: NxC where dKi = sum_j grad_j Kij 23 | """ 24 | n, c = x.size() 25 | sq_dist_mat = squared_dist(x) 26 | 27 | if h is None: 28 | # Apply median trick for h 29 | h = torch.clamp(torch.median(sq_dist_mat)/np.log(n+1),1e-3, float('inf')).detach() 30 | 31 | K = torch.exp(-sq_dist_mat/h) 32 | dK = -(torch.matmul(K, x) - torch.sum(K, dim=-1,keepdim=True)*x)/h 33 | 34 | return K, dK 35 | 36 | 37 | class SVGDPlan(): # jit.ScriptModule): 38 | """ Plan with Stein Variational Gradient Descent """ 39 | def __init__(self, planning_horizon, opt_iters, samples, env, device, grad_clip=True): 40 | super().__init__() 41 | self.set_env(env) 42 | self.H = planning_horizon 43 | self.opt_iters = opt_iters 44 | self.K = samples 45 | self.device = device 46 | self.grad_clip = grad_clip 47 | 48 | def set_env(self, env): 49 | self.env = env 50 | if self.env is not None: 51 | self.a_size = env.a_size 52 | 53 | # @jit.script_method 54 | def forward(self, batch_size, return_plan=False, return_plan_each_iter=False, alpha=1.0): 55 | # TODO enable batching by batch the distance matrix computation 56 | assert batch_size == 1 57 | 58 | # Here batch is strictly if multiple Plans should be performed! 
59 | B = batch_size 60 | 61 | # Initialize factorized belief over action sequences q(a_t:t+H) ~ N(0, I) 62 | flat_a_mu = torch.zeros(1, self.H * self.a_size, device=self.device) 63 | flat_a_std = torch.ones(1, self.H * self.a_size, device=self.device) 64 | 65 | # actions = (a_mu + a_std * torch.randn(self.H, B, self.K, self.a_size, device=self.device)).view(self.H, B * self.K, self.a_size) 66 | # Sample actions ((B*K) x (T,A)) 67 | flat_actions = (flat_a_mu + flat_a_std * torch.randn(B*self.K, self.H*self.a_size, device=self.device)) 68 | # TODO: debug 69 | # flat_actions = flat_actions*0 70 | flat_actions = torch.tensor(flat_actions, requires_grad=True) 71 | 72 | # Dummy op to init grad 73 | flat_actions.sum().backward() 74 | 75 | # optimizer = optim.SGD([actions], lr=0.1, momentum=0) 76 | optimizer = optim.RMSprop([flat_actions], lr=0.1) 77 | plan_each_iter = [] 78 | for _ in range(self.opt_iters): 79 | self.env.reset_state(B*self.K) 80 | 81 | optimizer.zero_grad() 82 | 83 | # Get log prob gradient 84 | # Returns (B*K) 85 | # Use p propto exp(r) -> logp = r + const 86 | # Need actions in H,B*K,A format 87 | actions = flat_actions.view(B*self.K, self.H, self.a_size).transpose(0,1) 88 | returns = self.env.rollout(actions) 89 | tot_returns = returns.sum() 90 | grad_scores = autograd.grad(-tot_returns, flat_actions, retain_graph=True) 91 | assert len(grad_scores) == 1 92 | grad_scores = grad_scores[0].detach() 93 | # (-tot_returns).backward() 94 | 95 | print('grad_score', npy(grad_scores)) 96 | print(grad_scores.size()) 97 | # grad clip 98 | # Find norm across batch 99 | if self.grad_clip: 100 | epsilon = 1e-6 101 | max_grad_norm = 1.0 102 | grad_scores_norm = grad_scores.norm(2.0,dim=-1,keepdim=True)+epsilon 103 | # print("before clip", actions.grad.max().cpu().numpy()) 104 | 105 | # Normalize by that 106 | grad_scores.data.div_(grad_scores_norm) 107 | grad_scores.data.mul_(grad_scores_norm.clamp(min=0, max=max_grad_norm)) 108 | # print("after clip", actions.grad.max().cpu().numpy()) 109 | 110 | print('grad_score', npy(grad_scores)) 111 | 112 | # Get the kernel matrix and the summed kernel gradients 113 | # TODO: handle batching 114 | K, dK = rbf_kernel(flat_actions.view(self.K, self.H*self.a_size)) 115 | print("K", npy(K)) 116 | print("dK", npy(dK)) 117 | 118 | # Form SVGD gradient 119 | svgd = (torch.matmul(K, grad_scores) + alpha * dK).detach() 120 | print("svgd", npy(svgd)) 121 | 122 | # Assign gradient 123 | # flat_actions.grad = (svgd.clone()) 124 | with torch.no_grad(): 125 | flat_actions.grad.set_(svgd.clone()) 126 | 127 | optimizer.step() 128 | 129 | if return_plan_each_iter: 130 | _, topk = returns.reshape(B, self.K).topk(1, dim=1, largest=True, sorted=False) 131 | actions = flat_actions.view(B*self.K, self.H, self.a_size).transpose(0,1) 132 | best_plan = actions[:, topk[0]].reshape(self.H, B, self.a_size).detach() 133 | plan_each_iter.append(best_plan.data.clone()) 134 | 135 | actions = flat_actions.view(B*self.K, self.H, self.a_size).transpose(0,1) 136 | actions = actions.detach() 137 | # Re-fit belief to the K best action sequences 138 | _, topk = returns.reshape(B, self.K).topk(1, dim=1, largest=True, sorted=False) 139 | best_plan = actions[:, topk[0]].reshape(self.H, B, self.a_size) 140 | 141 | if return_plan_each_iter: 142 | return plan_each_iter 143 | if return_plan: 144 | return best_plan 145 | else: 146 | return best_plan[0] 147 | 148 | if __name__ == "__main__": 149 | from test_energy import get_test_energy2d_env 150 | 151 | # torch.manual_seed(0) 152 | B = 1 153 | K = 
10 154 | # to_K = 10 155 | t_env = get_test_energy2d_env(B*K) 156 | H = 1 157 | opt_iters = 10 158 | planner = SVGDPlan(H, opt_iters, K, t_env, device=torch.device('cpu')) 159 | action = planner.forward(B) 160 | action = action.cpu().numpy() 161 | 162 | import matplotlib.pyplot as plt 163 | N = 30 164 | x = torch.linspace(-1,1,N) 165 | y = torch.linspace(-1,1,N) 166 | X, Y = torch.meshgrid(x,y) 167 | actions_grid = torch.stack((X,Y),dim=-1) 168 | # print(actions_grid) 169 | energies = t_env.func(actions_grid.reshape(-1,2)) 170 | 171 | plt.pcolormesh(X.numpy(), Y.numpy(), -energies.reshape(N,N).numpy(), cmap="coolwarm") 172 | plt.contour(X.numpy(), Y.numpy(), -energies.reshape(N,N).numpy(), cmap="seismic") 173 | plt.scatter(action[:, 0], action[:, 1]) 174 | plt.show() 175 | -------------------------------------------------------------------------------- /mpc/test_energy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import torch.nn as nn 5 | 6 | def get_test_energy2d_env(batch_size): 7 | return FuncMinGTEnv(batch_size, 2, test_energy2d) 8 | 9 | def test_energy2d(action_batch): 10 | assert action_batch.dim() == 2 11 | assert action_batch.size()[1] == 2 12 | 13 | opt_point = torch.tensor([[1.0, 1.0]], 14 | requires_grad=False, 15 | device=action_batch.device) 16 | 17 | return ((action_batch-opt_point)**2).sum(-1) 18 | 19 | def get_test_energy(opt_point): 20 | def test_energy(query_batch): 21 | assert query_batch.dim() == 2 22 | assert query_batch.size(1) == opt_point.size(-1) 23 | 24 | return ((query_batch-opt_point)**2).sum(-1) 25 | return test_energy 26 | 27 | # def time_param_curve() 28 | 29 | class BatchRepulseCircle: 30 | def __init__(self, origins, radius, batch_dims=[0,1], k=1.0): 31 | self.B = origins.size(0) 32 | self.origins = origins 33 | self.device = self.origins.device 34 | self.radius = radius 35 | self.k = k 36 | self.batch_dims = batch_dims 37 | 38 | def force(self, x): 39 | # print(x.size()) 40 | x = x.unsqueeze(-2) 41 | # print(x.size()) 42 | # print(self.origins.size()) 43 | # print(self.B) 44 | self.device = self.origins.device 45 | # print(self.batch_dims) 46 | contact_vector = (x-self.origins)[..., torch.arange(self.B, device=self.device, dtype=torch.long).view(-1,1), self.batch_dims] 47 | # print('cv', contact_vector.size()) 48 | # print(contact_vector.size()) 49 | dist = contact_vector.norm(dim=-1,keepdim=True) 50 | # print(dist.size()) 51 | penetration = (self.radius - dist).clamp(0,self.radius) 52 | # penetration = torch.max(torch.min((self.radii - dist),0),self.origins) 53 | # print(penetration.size()) 54 | force = self.k*(contact_vector)*(penetration/(dist+1e-6)+1e-16).pow(0.3) 55 | # print(force.size()) 56 | tot_force = force.sum(dim=1) 57 | # print(tot_force.size()) 58 | shape = tuple(tot_force.size()[:-1]) 59 | # print(shape) 60 | return torch.cat((tot_force, torch.zeros(shape+(x.size(-1)-2,), dtype=torch.float, device=x.device)),dim=-1) 61 | 62 | class RepulseCircle: 63 | def __init__(self, origin, radius, k=1.0, dims=[0,1]): 64 | self.origin = origin 65 | self.radius = radius 66 | self.k = k 67 | self.dims = dims 68 | 69 | def force(self, x): 70 | contact_vector = (x-self.origin)[..., self.dims] 71 | # print(contact_vector.size()) 72 | dist = contact_vector.norm(dim=-1,keepdim=True) 73 | penetration = (self.radius - dist).clamp(0,self.radius) 74 | force = self.k*(contact_vector)*(penetration/(dist+1e-6)+1e-16).pow(0.3) 75 | shape = 
tuple(force.size()[:-1]) 76 | return torch.cat((force, torch.zeros(shape+(x.size(-1)-2,), dtype=torch.float, device=x.device)),dim=-1) 77 | 78 | class NavigateGTEnv(): 79 | def __init__(self, batch_size, input_size, batched_func, device, control='force', mass=1.0, sparse_r_step=None, dt=0.05, obstacles_env=False, num_obs=12): 80 | """ 81 | batch_size B: number of agents to simulate in parallel 82 | input_size A: number of dimensions for the space the agent is operating in 83 | batched_func func: a batched function for what the reward should be given a position 84 | device: torch device 85 | control: type of control can be {'vel', 'accel', 'force'} 86 | mass: if control type is force, this is the mass of the agent 87 | """ 88 | self.device = device 89 | self.a_size = input_size 90 | self.s_size = input_size 91 | self.func = batched_func 92 | self.control = control 93 | self.mass = mass 94 | self.state = None 95 | self.dt = dt 96 | self.reset_state(batch_size) 97 | self.obstacles_env = obstacles_env 98 | self.num_obs = num_obs 99 | 100 | self.primary_axis = torch.ones(self.s_size) 101 | # self.primary_axis = self.primary_axis/self.primary_axis.norm() 102 | 103 | self.opt_point = torch.tensor(self.primary_axis, 104 | requires_grad=False, 105 | device=device) 106 | 107 | #TODO: hack for now 108 | self.func = get_test_energy(self.opt_point) 109 | 110 | if obstacles_env: 111 | # self.obstacles = [] 112 | # obstacle_list = [((0.5,0.5),0.06), ((0.3,0.3),0.08), ((0.05,0.2),0.08), ((0.2,0.05),0.1), ((0.5,0.25),0.06), ((0.25,0.5),0.06), ((0.8,0.25),0.1), ((0.25,0.8),0.1), ((0.5,0.75),0.1), ((0.75,0.5),0.1), ((0.5,-0.1),0.15), ((-0.1,0.5),0.15)] 113 | origin_list = [] 114 | num_obs = self.num_obs 115 | print(num_obs) 116 | density = 0.8 117 | radius = density/num_obs 118 | for x_pos in np.linspace(-0.7,1.3,num_obs): 119 | for y_pos in np.linspace(-0.7,1.3,num_obs): 120 | x_pos = x_pos + np.random.uniform(-0.1/num_obs, 0.1/num_obs) 121 | y_pos = y_pos + np.random.uniform(-0.1/num_obs, 0.1/num_obs) 122 | origin_list.append([x_pos, y_pos]+[0]*(self.s_size-2)) 123 | circ_origins = torch.tensor(origin_list, device=self.device) 124 | 125 | # for (point, rad) in obstacle_list: 126 | # # Pad it out to the right dimensionality 127 | # circ_origin = torch.tensor(point+(0.,)*(self.s_size-2), 128 | # requires_grad=False, 129 | # device=device) 130 | # self.obstacles.append(RepulseCircle(circ_origin, rad, 100.0)) 131 | self.obstacle = BatchRepulseCircle(circ_origins, radius, batch_dims=torch.tensor([0,1], dtype=torch.long, device=self.device).view(1,2).expand(circ_origins.size(0),2), k=300.0) 132 | else: 133 | circ_origin = torch.tensor(self.primary_axis*0.5, 134 | requires_grad=False, 135 | device=device) 136 | self.obstacle = RepulseCircle(circ_origin, 0.5, k=40.0) 137 | 138 | # Step on which sparse reward is given (if None dense reward given at each time step) 139 | self.sparse_r_step=sparse_r_step 140 | 141 | def reset_state(self, batch_size): 142 | """ 143 | Starts new "episode" for all the agents 144 | resets the state of the environment to the initial state 145 | batch_size B: sets the batch_size for this run 146 | """ 147 | self.t_step = 0 148 | self.B = batch_size 149 | if(self.state is None or batch_size != self.state[0].size(0)): 150 | self.pos = torch.tensor(torch.zeros((batch_size, self.s_size), dtype=torch.float), device=self.device, requires_grad=False) 151 | self.vel = torch.tensor(torch.zeros((batch_size, self.s_size), dtype=torch.float), device=self.device, requires_grad=False) 152 | self.state = 
[self.pos, self.vel] 153 | # self.done = torch.tensor(torch.zeros((batch_size), dtype=torch.float), requires_grad=False) 154 | # Detach from graph 155 | self.state[0] = self.state[0].detach() 156 | self.state[1] = self.state[1].detach() 157 | # Reset to zero 158 | self.state[0].fill_(0) 159 | self.state[1].fill_(0) 160 | 161 | def rollout(self, actions, return_traj=False): 162 | # Uncoditional action sequence rollout 163 | # actions: shape: TxBxA (time, batch, action) 164 | assert actions.dim() == 3 165 | assert actions.size(1) == self.B, "{}, {}".format(actions.size(1), self.B) 166 | assert actions.size(2) == self.a_size 167 | T = actions.size(0) 168 | rs = [] 169 | ss = [] 170 | 171 | total_r = torch.zeros(self.B, requires_grad=True, device=actions.device) 172 | for i in range(T): 173 | _, r, done = self.step(actions[i]) 174 | rs.append(r) 175 | ss.append(self.state) 176 | total_r = total_r + r 177 | if(done): 178 | break 179 | if return_traj: 180 | return rs, ss 181 | else: 182 | return total_r 183 | 184 | def step(self, action): 185 | self.state = self.sim(self.state, action) 186 | o = self.calc_obs(self.state) 187 | r = self.calc_reward(self.state, action) 188 | return o, r, False 189 | 190 | def sim(self, state, action): 191 | # Symplectic euler 192 | next_state = [None, None] 193 | # Velocity control 194 | if self.control == 'vel': 195 | next_state[1] = nn.Tanh()(action) 196 | elif self.control == 'accel': 197 | next_state[1] = state[1] + self.dt * nn.Tanh()(action) 198 | elif self.control == 'force': 199 | if self.obstacles_env: 200 | # fext = self.obstacles[0].force(state[0]) 201 | # for i in range(1,len(self.obstacles)): 202 | # fext = fext + self.obstacles[i].force(state[0]) 203 | fext = self.obstacle.force(state[0]) 204 | else: 205 | fext = self.obstacle.force(state[0]) 206 | next_state[1] = state[1] + self.dt * (nn.Tanh()(action) + fext)/self.mass 207 | else: 208 | raise NotImplementedError() 209 | next_state[0] = state[0] + self.dt * next_state[1] 210 | self.t_step += 1 211 | return next_state 212 | 213 | def calc_obs(self, state): 214 | return None 215 | 216 | def calc_reward(self, state, action): 217 | if self.sparse_r_step is not None: 218 | if self.sparse_r_step == self.t_step: 219 | return -self.func(state[0]) 220 | else: 221 | return 0 222 | 223 | return -self.func(state[0]) 224 | 225 | # def visualize_2d(self, ) 226 | @staticmethod 227 | def repulsive_circle_force(x, origin, radius, k=1.0): 228 | contact_vector = (x-origin) 229 | # print(contact_vector.size()) 230 | dist = contact_vector.norm(dim=-1,keepdim=True) 231 | penetration = (radius - dist).clamp(0,radius) 232 | force = k*(contact_vector)*(penetration/(dist+1e-6)) 233 | return force 234 | 235 | def draw_env_2d_proj(self, ax): 236 | # Project onto principal dimension 237 | N = 30 238 | x = torch.linspace(-0.5,1.5,N) 239 | y = torch.linspace(-0.5,1.5,N) 240 | X, Y = torch.meshgrid(x,y) 241 | pos_grid = torch.stack((X,Y),dim=-1) 242 | shape = tuple(pos_grid.size()[:-1]) 243 | pos_grid = torch.cat((pos_grid, torch.ones(shape+(self.s_size-2,))), dim=-1) 244 | energies = self.func(pos_grid.reshape(-1,self.s_size).to(self.device)) 245 | ax.pcolormesh(X.numpy(), Y.numpy(), -energies.reshape(N,N).cpu().numpy(), cmap="coolwarm") 246 | ax.contour(X.numpy(), Y.numpy(), -energies.reshape(N,N).cpu().numpy(), cmap="seismic") 247 | 248 | # Draw repulsive circle 249 | 250 | if self.obstacles_env: 251 | for i in range(self.obstacle.origins.size(0)): 252 | circle = plt.Circle(self.obstacle.origins[i].cpu().numpy().squeeze(), 
self.obstacle.radius, fill=False, edgecolor='k') 253 | ax.add_artist(circle) 254 | 255 | # for i in range(len(self.obstacles)): 256 | # circle = plt.Circle(self.obstacles[i].origin.cpu().numpy().squeeze(), self.obstacles[i].radius, fill=False, edgecolor='k') 257 | # ax.add_artist(circle) 258 | else: 259 | circle = plt.Circle(self.obstacle.origin.cpu().numpy().squeeze(), self.obstacle.radius, fill=False, edgecolor='lightgreen') 260 | ax.add_artist(circle) 261 | 262 | def draw_traj_2d_proj(self, ax, states): 263 | # Project onto principal dimension 264 | ps = [s[0].cpu().numpy() for s in states] 265 | ps = np.array(ps) 266 | ps = ps.squeeze() 267 | ax.plot(ps[:, 0], ps[:, 1],'ko-',linewidth=0.5,markersize=1) 268 | 269 | 270 | class FuncMinGTEnv(): 271 | def __init__(self, batch_size, input_size, batched_func): 272 | self.a_size = input_size 273 | self.func = batched_func 274 | self.state = None 275 | self.reset_state(batch_size) 276 | 277 | def reset_state(self, batch_size): 278 | self.B = batch_size 279 | 280 | def rollout(self, actions): 281 | # Uncoditional action sequence rollout 282 | # TxBxA 283 | T = actions.size(0) 284 | total_r = torch.zeros(self.B, requires_grad=True, device=actions.device) 285 | for i in range(T): 286 | _, r, done = self.step(actions[i]) 287 | total_r = total_r + r 288 | if(done): 289 | break 290 | return total_r 291 | 292 | def step(self, action): 293 | self.state = self.sim(self.state, action) 294 | o = self.calc_obs(self.state) 295 | r = self.calc_reward(self.state, action) 296 | # always done after first step 297 | return o, r, True 298 | 299 | def sim(self, state, action): 300 | return state 301 | 302 | def calc_obs(self, state): 303 | return None 304 | 305 | def calc_reward(self, state, action): 306 | return -self.func(action) 307 | 308 | 309 | 310 | # class PMEnv(): 311 | # def __init__(): 312 | # dt = 0.1 313 | # max_v = 2.0 314 | # max_a = 1.0 315 | # start_p = torch.tensor([[-1.0,0.0]], requires_grad=False) 316 | # start_v = torch.tensor([[0.0,0.0]], requires_grad=False) 317 | # cur_p = None 318 | # cur_v = None 319 | 320 | # def reset(batch_size): 321 | # cur_p = (torch.start_p.clone()) 322 | 323 | 324 | # def batch_step(actions): 325 | # pass 326 | 327 | 328 | if __name__ == "__main__": 329 | import matplotlib.pyplot as plt 330 | N = 30 331 | x = torch.linspace(-1,1,N) 332 | y = torch.linspace(-1,1,N) 333 | X, Y = torch.meshgrid(x,y) 334 | actions_grid = torch.stack((X,Y),dim=-1) 335 | print(actions_grid) 336 | energies = test_energy2d(actions_grid.reshape(-1,2)) 337 | 338 | plt.pcolormesh(X.numpy(), Y.numpy(), energies.reshape(N,N).numpy(), cmap="coolwarm") 339 | plt.contour(X.numpy(), Y.numpy(), energies.reshape(N,N).numpy(), cmap="seismic") 340 | plt.show() 341 | 342 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from distutils.core import setup 4 | 5 | setup(name='mpc', 6 | packages=['mpc'], 7 | ) 8 | --------------------------------------------------------------------------------
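For quick reference, a minimal usage sketch of the planners on the ground-truth navigation environment, mirroring `experiments/D_comp/main.py` and the top-level `main.py`. It assumes the package has been installed in editable mode with `pip install -e .` using the `setup.py` above; the CPU device, horizon, and sample counts below are illustrative choices, not values fixed by the repository.

```python
import torch

from mpc.gradcem import GradCEMPlan
from mpc.test_energy import NavigateGTEnv, test_energy2d

# Illustrative settings (assumptions, not prescribed by the repo):
device = torch.device('cpu')   # the experiment scripts use torch.device('cuda:0')
B, D, H = 1, 2, 50             # batch size, action dimensionality, planning horizon

# Point-mass navigation task with a sparse reward late in the horizon (sparse_r_step=H-1)
env = NavigateGTEnv(B, D, test_energy2d, device, control='force',
                    sparse_r_step=H - 1, dt=2.5 / H)

# Hybrid CEM + gradient planner: 10 iterations, 20 sampled action sequences, top 4 elites
planner = GradCEMPlan(H, 10, 20, 4, env, device=device)

plan = planner.forward(B, return_plan=True)   # (H, B, a_size) open-loop action sequence
env.reset_state(B)
total_return = env.rollout(plan)              # (B,) total reward of the returned plan
print(total_return.item())
```

Swapping `GradCEMPlan` for `CEM` or `GradPlan` (with their constructor arguments adjusted accordingly) reproduces the comparison carried out in `experiments/D_comp/main.py`.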