├── .gitignore ├── README.md ├── _assets ├── during_train.gif ├── tensorboard.jpg └── trained.gif ├── envs ├── __init__.py ├── payload_env_g2g.py ├── test.py └── torch_env.py ├── main.py ├── memory.py ├── rollout_generator.py ├── rssm_main_latent_overshooting.py ├── rssm_model.py ├── rssm_policy.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | */__pycache__ 3 | results/* 4 | *.png 5 | *.pth 6 | *.gv 7 | *.html 8 | stm* 9 | .vscode -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dreamer-pytorch 2 | [![MIT License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE.md) 3 | 4 | A PyTorch Implementation of PlaNet: A Deep Planning Network for Reinforcement Learning [[1]](#references) by Danijar Hafner et al. 5 | 6 | ### Usage 7 | - Run `main.py` for training. 8 | - Run `eval.py` for evaluation of a saved checkpoint. 9 | - Metrics are displayed and stored using Tensorboard and can be viewed by running the following: 10 | ```shell 11 | $ tensorboard --logdir results 12 | ``` 13 | - Visit tensorboard in your browser! By default tensorboard launches at `localhost:6006`. You might see a screen similar to this: 14 | ![Tensorboard](_assets/tensorboard.jpg) 15 | 16 | ### Results 17 | The video on the **left** is the downscaled version of the **gym render**. 18 | The one on the **right** is **generated** by the decoder model. 19 | #### During Training 20 | ![training](_assets/during_train.gif) 21 | 22 | #### After Training 23 | ![training](_assets/trained.gif) 24 | 25 | 26 | ### Installation and running! 27 | Install the following dependencies: 28 | - `pytorch==1.4.0` 29 | - `tensorboard-pytorch==0.7.1` 30 | - `tqdm==4.42.1` 31 | - `torchvision==0.5.0` 32 | - `gym==0.16.0` 33 | 34 | References & Acknowledgements 35 | ----------------------------- 36 | - [Learning Latent Dynamics for Planning from Pixels][paper] 37 | - [google-research/planet] by [@danijar] 38 | - [PlaNet] by [@Kaixhin] 39 | 40 | [Website]: https://danijar.com/project/planet/ 41 | [paper]: https://arxiv.org/abs/1811.04551 42 | [@danijar]: https://github.com/danijar 43 | [@Kaixhin]: https://github.com/Kaixhin 44 | [PlaNet]: https://github.com/Kaixhin/PlaNet 45 | [google-research/planet]: https://github.com/google-research/planet 46 | -------------------------------------------------------------------------------- /_assets/during_train.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhayraw1/planet-torch/656189e060f24e262c09eb754e4bf99bed8662b7/_assets/during_train.gif -------------------------------------------------------------------------------- /_assets/tensorboard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhayraw1/planet-torch/656189e060f24e262c09eb754e4bf99bed8662b7/_assets/tensorboard.jpg -------------------------------------------------------------------------------- /_assets/trained.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhayraw1/planet-torch/656189e060f24e262c09eb754e4bf99bed8662b7/_assets/trained.gif -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import payload_env_g2g 2 | from . import payload_env_gym 3 | 4 | CUSTOM_ENVS = ['PayloadEnv-v0', 'PayloadEnvG2G-v0'] -------------------------------------------------------------------------------- /envs/payload_env_g2g.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import sys 4 | import gym 5 | import pdb 6 | import yaml 7 | import time 8 | import enum 9 | import matplotlib 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | 13 | from numpy import pi 14 | from copy import deepcopy 15 | from gym.spaces import Dict, Box 16 | from numpy.linalg import norm, inv as inverse 17 | from shapely.geometry import MultiPoint, Point 18 | from matplotlib.patches import Polygon, Circle, FancyArrowPatch 19 | 20 | from payload_manipulation.utils.utils import * 21 | from payload_manipulation.utils.visualizer import Visualizer 22 | from payload_manipulation.utils.payload_boundaries import * 23 | from payload_manipulation.utils.transformations import compose_matrix 24 | from payload_manipulation.utils.payload_transform import PayloadTransform 25 | 26 | 27 | # matplotlib.use('Agg') 28 | matplotlib.rcParams['toolbar'] = 'None' 29 | 30 | def cossin(x): 31 | return np.cos(x), np.sin(x) 32 | 33 | class StateType(enum.Enum): 34 | RADIANS = 1 35 | COS_SIN = 2 36 | SIX_DIM = 3 37 | 38 | 39 | class ColorScheme: 40 | def __init__(self, payload=None, obstacles=None, freespace=None, goal=None): 41 | goal = goal if goal is not None else (231, 76, 60) 42 | payload = payload if payload is not None else (231, 76, 60) 43 | obstacles = obstacles if obstacles is not None else (52, 152, 219) 44 | freespace = freespace if freespace is not None else (96, 174, 39) 45 | self.payload = self.pld = tuple(map(int, payload)) 46 | self.obstacles = self.obs = tuple(map(int, obstacles)) 47 | self.freespace = self.fsp = tuple(map(int, freespace)) 48 | self.goal = self.gol = tuple(map(int, goal)) 49 | 50 | 51 | OBS_STATE_SPACE = { 52 | StateType.RADIANS: 3, StateType.COS_SIN: 8, StateType.SIX_DIM: 8 53 | } 54 | 55 | class PayloadEnv(gym.Env): 56 | def __init__(self, config=None): 57 | super().__init__() 58 | config = config or {} 59 | self.config = config 60 | self.__dict__.update(config) 61 | self.dt = config.get('dt', 0.1) 62 | self.n_obs = config.get('n_obs', 15) 63 | self.goal_state = np.zeros(2) 64 | self.payload_dims = config.get('payload_dims', np.array([4., 3., 0.])) 65 | self.obstacle_radius = config.get('obstacle_radius', 0.3) 66 | self.payload_b = np.array([200, 200, np.pi/4, np.pi/4, np.pi, np.pi]) 67 | self.transform = PayloadTransform(dims=self.payload_dims) 68 | self.mean_height = 1.0 69 | self.epsr = 7.5 70 | self.arena_dims = config.get('arena_dims', np.array([self.epsr]*2)) 71 | self.twoD = config.get('twoD', True) 72 | self.color_scheme = config.get('color_scheme', ColorScheme()) 73 | # for spinup algos... 
should have no effect 74 | self.init_vx() 75 | self.set_image_size(config.get('image_size', np.array([192, 192]))) 76 | self.state_type = config.get('state_type', StateType.RADIANS) 77 | assert isinstance(self.state_type, StateType) 78 | self.to_render = config.get('to_render', True) 79 | self.fig = None 80 | self.prev_state = None 81 | self.curr_img = None 82 | self.alpha = 0.7 83 | self.a, self.b, _ = self.payload_dims/2 + self.obstacle_radius*1.35 84 | self.action_space = Box( 85 | np.array([-0.3, -0.3, -np.pi/6, -np.pi/6, -np.pi/4]), 86 | np.array([+0.3, +0.3, +np.pi/6, +np.pi/6, +np.pi/4]) 87 | ) 88 | self.observation_space = Box( 89 | -np.ones(OBS_STATE_SPACE[self.state_type]), 90 | np.ones(OBS_STATE_SPACE[self.state_type]) 91 | ) 92 | 93 | 94 | def init_plot(self): 95 | self.fig = plt.figure(figsize=self.image_size/100) 96 | self.axs = plt.axes([0,0,1,1], frameon=False) 97 | plt.ion() 98 | self.im = plt.imshow(np.zeros(self.image_size), interpolation='none') 99 | self.axs.set_xticks([]) 100 | self.axs.set_yticks([]) 101 | plt.show() 102 | 103 | 104 | def render(self, mode='human', forced=False): 105 | if not self.to_render: 106 | return 107 | if self.fig is None: 108 | self.init_plot() 109 | if self.curr_img is None: 110 | raise ValueError('Can\'t render without reset()') 111 | if forced: 112 | self.get_observation() 113 | self.im.set_array(self.curr_img/255) 114 | plt.gcf().canvas.draw_idle() 115 | plt.gcf().canvas.start_event_loop(0.0001) 116 | if mode == 'human': 117 | return 118 | return self.curr_img 119 | 120 | 121 | def get_random_state(self, no_collide=[], r=1): 122 | while True: 123 | x = (2*np.random.random(2) - 1.0)*(self.arena_dims + 2**.5) 124 | aex = abs(self.to_ego_frame(x)) 125 | # check if obstacle is inside payload (don't want that) 126 | if all(aex < self.payload_dims[:2]/2 + self.obstacle_radius*2): 127 | # print('Fail 1') 128 | continue 129 | # check if obstacle is outside the eps neighbourhood 130 | # Can go super strict by using any() in place of all() 131 | if all(aex > self.epsr): 132 | # print('Fail 2') 133 | continue 134 | break 135 | return x 136 | 137 | 138 | def reset(self, total_chaos=False): 139 | self.agent_state = np.zeros(5) 140 | self.agent_state[-1] = (2*np.random.random() - 1)*np.pi 141 | self.obstacle_state = np.zeros((self.n_obs, 2)) 142 | elems = [] 143 | for i in range(self.n_obs): 144 | self.obstacle_state[i] = self.get_random_state() 145 | elems.append(self.obstacle_state[i]) 146 | r = 0.1 if total_chaos else max(self.payload_dims) 147 | self.goal_state = self.get_random_state() 148 | self.prev_action = np.zeros_like(self.action_space.sample()) 149 | self.success = False 150 | return self.get_observation() 151 | 152 | 153 | def init_vx(self): 154 | self.vxs = [] 155 | pseudo_vxs = np.array([ 156 | [1, 1, 1], [1, 1, -1], [-1, 1, -1], [-1, 1, 1], 157 | [-1, -1, 1], [-1, -1, -1], [1, -1, -1], [1, -1, 1] 158 | ]) 159 | if self.twoD: 160 | pseudo_vxs = pseudo_vxs[range(0, 8, 2)] 161 | for vx in pseudo_vxs: 162 | self.vxs.append(self.payload_dims/2*vx) 163 | self.vxs = np.stack(self.vxs).transpose(0, 1) 164 | 165 | @property 166 | def vertices(self): 167 | x, r = np.split(self.agent_state, [2]) 168 | x = np.concatenate([x, np.ones(1)*self.mean_height]) 169 | return np.matmul( 170 | compose_matrix(translate=x, angles=r), 171 | np.concatenate([self.vxs, np.ones((4, 1))], axis=-1).T 172 | )[:-1].T 173 | 174 | def to_ego_frame(self, pts): 175 | if len(pts.shape) == 1: 176 | return self.to_ego_frame(pts[None]).flatten() 177 | pts = 
np.concatenate([pts.T, OIS(pts.shape[0])]) 178 | x, y, _, _, r = self.agent_state 179 | tf = compose_matrix(translate=[x, y, 0], angles=[0, 0, r]) 180 | return np.matmul(np.linalg.inv(tf), pts)[:-2].T 181 | 182 | def from_ego_frame(self, pts): 183 | if len(pts.shape) == 1: 184 | return self.from_ego_frame(pts[None]).flatten() 185 | pts = np.concatenate([pts.T, OIS(pts.shape[0])]) 186 | x, y, _, _, r = self.agent_state 187 | tf = compose_matrix(translate=[x, y, 0], angles=[0, 0, r]) 188 | return np.matmul(tf, pts)[:-2].T 189 | 190 | def get_observation(self): 191 | obs = np.ones((*self.image_size, 3)) 192 | obs = (obs*self.color_scheme.fsp).astype(np.uint8) 193 | for obstacle_xy in self.obstacle_state: 194 | center = self.to_ego_frame(obstacle_xy) 195 | if (abs(center) > self.epsr).all(): 196 | continue 197 | center = ((center*[1, -1])*self.scale + self.image_size/2) 198 | cv2.circle( 199 | obs, tuple(map(int, center)), self.radius_px, 200 | self.color_scheme.obs, -1 201 | ) 202 | center = self.to_ego_frame(self.goal_state) 203 | center = ((center*[1, -1])*self.scale + self.image_size/2).astype('i') 204 | cv2.circle( 205 | obs, tuple(center), self.radius_px, self.color_scheme.gol, -1 206 | ) 207 | vxs = self.to_ego_frame(self.vertices[:, :-1])*[[1, -1]] 208 | vxs = (vxs*self.scale + self.image_size/2).astype('i') 209 | pld = np.copy(obs) 210 | cv2.fillConvexPoly(pld, vxs, self.color_scheme.pld) 211 | # alpha = 0.5 212 | cv2.addWeighted(pld, self.alpha, obs, 1 - self.alpha, 0, obs) 213 | self.curr_img = np.copy(obs) 214 | state = np.copy(self.agent_state) 215 | state[:2] = self.to_ego_frame(self.goal_state) 216 | if self.state_type == StateType.COS_SIN: 217 | state = np.concatenate( 218 | [state[:2]] + list(zip(*cossin(self.agent_state[2:]))) 219 | ) 220 | if self.state_type == StateType.SIX_DIM: 221 | state = np.concatenate([ 222 | state[:2], 223 | compose_matrix(angles=self.agent_state[2:])[:3, :2].flatten() 224 | ]) 225 | return self.curr_img, state 226 | 227 | 228 | def get_reward_and_done(self, action, info): 229 | pld = MultiPoint(self.to_ego_frame(self.vertices[:, :-1])).convex_hull 230 | reward = 0 231 | done, success = False, False 232 | d2g = self.to_ego_frame(self.goal_state) 233 | 234 | if any(abs(d2g) > self.epsr): 235 | return -1, False, True 236 | 237 | for obstacle_xy in self.to_ego_frame(self.obstacle_state): 238 | if any(obstacle_xy > self.epsr): 239 | continue 240 | obstacle = Point(*obstacle_xy).buffer(self.obstacle_radius) 241 | collision = obstacle.intersects(pld) 242 | if collision: 243 | print('COLLISION!') 244 | return -1, False, True 245 | # r1 = -norm(self.agent_state[2:4]) 246 | r2 = norm(self.prev_state[:2] - self.goal_state)\ 247 | - norm(self.agent_state[:2] - self.goal_state) 248 | # r3 = -norm(self.prev_action - action) 249 | # reward = 0.05*r1 + r2 + 0.1*r3 + (-0.1*self.dt) # step_penalty 250 | reward = r2 - 0.1*self.dt 251 | if norm(d2g) < 0.3: 252 | reward = 1 253 | success = True 254 | info['d2g'] = norm(d2g) 255 | return reward, success, done or success 256 | 257 | def step(self, controls): 258 | self.prev_state = np.copy(self.agent_state) 259 | if len(controls.shape) == 1: 260 | controls = controls[None] 261 | controls = controls.clip(self.action_space.low, self.action_space.high) 262 | for control in controls: 263 | self.agent_state[0] += control[0]*np.cos(self.agent_state[-1]) 264 | self.agent_state[0] -= control[1]*np.sin(self.agent_state[-1]) 265 | self.agent_state[1] += control[0]*np.sin(self.agent_state[-1]) 266 | self.agent_state[1] += 
control[1]*np.cos(self.agent_state[-1]) 267 | self.agent_state[2:] += control[2:]*self.dt 268 | self.agent_state[2:4] = self.agent_state[2:4].clip(-pi/4, pi/4) 269 | self.agent_state[-1] = mod_angle(self.agent_state[-1]) 270 | obs = self.get_observation() 271 | info = {} 272 | reward, success, done = self.get_reward_and_done(controls, info) 273 | self.prev_action = np.copy(controls) 274 | self.success = success 275 | info['success'] = success 276 | return obs, reward, done, info 277 | 278 | def set_rendering(self, enable): 279 | self.to_render = enable 280 | 281 | def set_image_size(self, image_size): 282 | assert image_size.size == 2 283 | self.image_size = image_size 284 | self.scale = self.image_size/(2*self.epsr) 285 | self.radius_px = int((self.obstacle_radius*self.scale).mean()*1.5) 286 | 287 | def close(self): 288 | plt.close(self.fig) 289 | super().close() 290 | 291 | 292 | 293 | def register_with_config(env_name, config=None): 294 | config = config or {} 295 | gym.envs.registration.register( 296 | id=env_name, 297 | entry_point=PayloadEnv, 298 | kwargs={'config': config} 299 | ) 300 | 301 | env_config = { 302 | 'n_obs': 8, 303 | 'state_type': StateType.SIX_DIM, 304 | 'color_scheme': ColorScheme(*(np.eye(3)*255), 255*np.ones(3)), 305 | } 306 | register_with_config('PayloadEnvG2G-v0', env_config) 307 | 308 | if __name__ == '__main__': 309 | set_seed(55723) 310 | env = gym.make('PayloadEnvG2G-v0') 311 | 312 | obs = env.reset(total_chaos=False) 313 | env.goal_state = env.agent_state[:2] 314 | # print('Goal RN: ', env.goal_state) 315 | env.to_ego_frame(env.goal_state) 316 | t = time.time() 317 | d = False 318 | for i in range(300): 319 | env.render(forced=True) 320 | pdb.set_trace() 321 | if d: 322 | break 323 | # u = np.array([1, 0, 0, 0, 0]) 324 | u = env.action_space.sample() 325 | u = env.action_space.high 326 | print(env.agent_state) 327 | obs, r, d, ii = env.step(u) 328 | time.sleep(.1) 329 | # print(obs[0].shape) 330 | # input() 331 | print(f'Time taken: {(time.time() - t)/300}') 332 | -------------------------------------------------------------------------------- /envs/test.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import pdb 3 | import time 4 | import numpy as np 5 | from payload_manipulation.utils.utils import * 6 | set_seed(41176) 7 | env = gym.make('PayloadEnvG2G-v0') 8 | obs = env.reset(total_chaos=False) 9 | print(np.degrees(env.agent_state[-1])) 10 | # env.agent_state[-1] = -np.pi 11 | env.render() 12 | action = np.array([1, 0, 0, 0, 0]) 13 | for i in range(100): 14 | _, r, d, info = env.step(action) 15 | env.render() 16 | print({'r': r, 'd': d, **info}) 17 | time.sleep(0.05) 18 | pdb.set_trace() 19 | -------------------------------------------------------------------------------- /envs/torch_env.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abhayraw1/planet-torch/656189e060f24e262c09eb754e4bf99bed8662b7/envs/torch_env.py -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch 3 | from tqdm import trange 4 | from functools import partial 5 | from collections import defaultdict 6 | 7 | 8 | from torch.distributions import Normal, kl 9 | from torch.distributions.kl import kl_divergence 10 | 11 | from utils import * 12 | from memory import * 13 | from rssm_model import * 14 | from rssm_policy 
import * 15 | from rollout_generator import RolloutGenerator 16 | 17 | def train(memory, rssm, optimizer, device, N=32, H=50, beta=1.0, grads=False): 18 | """ 19 | Training implementation as indicated in: 20 | Learning Latent Dynamics for Planning from Pixels 21 | arXiv:1811.04551 22 | 23 | (a.) The Standard Variational Bound Method 24 | using only single step predictions. 25 | """ 26 | free_nats = torch.ones(1, device=device)*3.0 27 | batch = memory.sample(N, H, time_first=True) 28 | x, u, r, t = [torch.tensor(x).float().to(device) for x in batch] 29 | preprocess_img(x, depth=5) 30 | e_t = bottle(rssm.encoder, x) 31 | h_t, s_t = rssm.get_init_state(e_t[0]) 32 | kl_loss, rc_loss, re_loss = 0, 0, 0 33 | states, priors, posteriors, posterior_samples = [], [], [], [] 34 | for i, a_t in enumerate(torch.unbind(u, dim=0)): 35 | h_t = rssm.deterministic_state_fwd(h_t, s_t, a_t) 36 | states.append(h_t) 37 | priors.append(rssm.state_prior(h_t)) 38 | posteriors.append(rssm.state_posterior(h_t, e_t[i + 1])) 39 | posterior_samples.append(Normal(*posteriors[-1]).rsample()) 40 | s_t = posterior_samples[-1] 41 | prior_dist = Normal(*map(torch.stack, zip(*priors))) 42 | posterior_dist = Normal(*map(torch.stack, zip(*posteriors))) 43 | states, posterior_samples = map(torch.stack, (states, posterior_samples)) 44 | rec_loss = F.mse_loss( 45 | bottle(rssm.decoder, states, posterior_samples), x[1:], 46 | reduction='none' 47 | ).sum((2, 3, 4)).mean() 48 | kld_loss = torch.max( 49 | kl_divergence(posterior_dist, prior_dist).sum(-1), 50 | free_nats 51 | ).mean() 52 | rew_loss = F.mse_loss( 53 | bottle(rssm.pred_reward, states, posterior_samples), r 54 | ) 55 | optimizer.zero_grad() 56 | (beta*kld_loss + rec_loss + rew_loss).backward() 57 | nn.utils.clip_grad_norm_(rssm.parameters(), 1000., norm_type=2)  # clip after backward() so gradients exist 58 | optimizer.step() 59 | metrics = { 60 | 'losses': { 61 | 'kl': kld_loss.item(), 62 | 'reconstruction': rec_loss.item(), 63 | 'reward_pred': rew_loss.item() 64 | }, 65 | } 66 | if grads: 67 | metrics['grad_norms'] = { 68 | k: 0 if v.grad is None else v.grad.norm().item() 69 | for k, v in rssm.named_parameters() 70 | } 71 | return metrics 72 | 73 | 74 | def main(): 75 | env = TorchImageEnvWrapper('Pendulum-v0', bit_depth=5) 76 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 77 | rssm_model = RecurrentStateSpaceModel(env.action_size).to(device) 78 | optimizer = torch.optim.Adam(rssm_model.parameters(), lr=1e-3, eps=1e-4) 79 | policy = RSSMPolicy( 80 | rssm_model, 81 | planning_horizon=20, 82 | num_candidates=1000, 83 | num_iterations=10, 84 | top_candidates=100, 85 | device=device 86 | ) 87 | rollout_gen = RolloutGenerator( 88 | env, 89 | device, 90 | policy=policy, 91 | episode_gen=lambda : Episode(partial(postprocess_img, depth=5)), 92 | max_episode_steps=100, 93 | ) 94 | mem = Memory(100) 95 | mem.append(rollout_gen.rollout_n(1, random_policy=True)) 96 | res_dir = 'results/' 97 | summary = TensorBoardMetrics(f'{res_dir}/') 98 | for i in trange(100, desc='Epoch', leave=False): 99 | metrics = {} 100 | for _ in trange(150, desc='Iter ', leave=False): 101 | train_metrics = train(mem, rssm_model.train(), optimizer, device) 102 | for k, v in flatten_dict(train_metrics).items(): 103 | if k not in metrics.keys(): 104 | metrics[k] = [] 105 | metrics[k].append(v) 106 | metrics[f'{k}_mean'] = np.array(metrics[k]).mean() 107 | 108 | summary.update(metrics) 109 | mem.append(rollout_gen.rollout_once(explore=True)) 110 | eval_episode, eval_frames, eval_metrics = rollout_gen.rollout_eval() 111 | 
mem.append(eval_episode) 112 | save_video(eval_frames, res_dir, f'vid_{i+1}') 113 | summary.update(eval_metrics) 114 | 115 | if (i + 1) % 25 == 0: 116 | torch.save(rssm_model.state_dict(), f'{res_dir}/ckpt_{i+1}.pth') 117 | 118 | pdb.set_trace() 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /memory.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch 3 | import numpy as np 4 | 5 | from utils import * 6 | from collections import deque 7 | from numpy.random import choice 8 | from torch import float32 as F32 9 | from torch.nn.utils.rnn import pad_sequence 10 | 11 | class Episode: 12 | """Records the agent's interaction with the environment for a single 13 | episode. At termination, it converts all the data to Numpy arrays. 14 | """ 15 | def __init__(self, postprocess_fn=lambda x: x): 16 | self.x = [] 17 | self.u = [] 18 | self.t = [] 19 | self.r = [] 20 | self.postprocess_fn = postprocess_fn 21 | self._size = 0 22 | 23 | @property 24 | def size(self): 25 | return self._size 26 | 27 | def append(self, obs, act, reward, terminal): 28 | self._size += 1 29 | self.x.append(self.postprocess_fn(obs.numpy())) 30 | self.u.append(act.cpu().numpy()) 31 | self.r.append(reward) 32 | self.t.append(terminal) 33 | 34 | def terminate(self, obs): 35 | self.x.append(self.postprocess_fn(obs.numpy())) 36 | self.x = np.stack(self.x) 37 | self.u = np.stack(self.u) 38 | self.r = np.stack(self.r) 39 | self.t = np.stack(self.t) 40 | 41 | 42 | class Memory(deque): 43 | def __init__(self, size): 44 | """Maintains a FIFO list of `size` number of episodes. 45 | """ 46 | self.episodes = deque(maxlen=size) 47 | self.eps_lengths = deque(maxlen=size) 48 | print(f'Creating memory with len {size} episodes.') 49 | 50 | @property 51 | def size(self): 52 | return sum(self.eps_lengths) 53 | 54 | def _append(self, episode: Episode): 55 | if isinstance(episode, Episode): 56 | self.episodes.append(episode) 57 | self.eps_lengths.append(episode.size) 58 | else: 59 | raise ValueError('can only append an Episode or a list of Episodes') 60 | 61 | def append(self, episodes: [Episode]): 62 | if isinstance(episodes, Episode): 63 | episodes = [episodes] 64 | if isinstance(episodes, list): 65 | for e in episodes: 66 | self._append(e) 67 | else: 68 | raise ValueError('can only append an Episode or a list of Episodes') 69 | 70 | def sample(self, batch_size, tracelen=1, time_first=False): 71 | episode_idx = choice(len(self.episodes), batch_size) 72 | init_st_idx = [ 73 | choice(self.eps_lengths[i] - tracelen + 1) 74 | for i in episode_idx 75 | ] 76 | x, u, r, t = [], [], [], [] 77 | for n, (i, s) in enumerate(zip(episode_idx, init_st_idx)): 78 | x.append(self.episodes[i].x[s: s + tracelen + 1]) 79 | u.append(self.episodes[i].u[s: s + tracelen]) 80 | r.append(self.episodes[i].r[s: s + tracelen]) 81 | t.append(self.episodes[i].t[s: s + tracelen]) 82 | if tracelen == 1: 83 | rets = [np.stack(x)] + [np.stack(i)[:, 0] for i in (u, r, t)] 84 | else: 85 | rets = [np.stack(i) for i in (x, u, r, t)] 86 | if time_first: 87 | rets = [a.swapaxes(1, 0) for a in rets] 88 | return rets 89 | -------------------------------------------------------------------------------- /rollout_generator.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import numpy as np 3 | import torch 4 | from collections import defaultdict 5 | 6 | from tqdm import trange 7 | from torchvision.utils import make_grid 8 | 9 | 
from memory import Episode # this needs modification! 10 | 11 | class RolloutGenerator: 12 | """Rollout generator class.""" 13 | def __init__(self, 14 | env, 15 | device, 16 | policy=None, 17 | max_episode_steps=None, 18 | episode_gen=None, 19 | name=None, 20 | ): 21 | self.env = env 22 | self.device = device 23 | self.policy = policy 24 | self.episode_gen = episode_gen or Episode 25 | self.name = name or 'Rollout Generator' 26 | self.max_episode_steps = max_episode_steps 27 | if self.max_episode_steps is None: 28 | self.max_episode_steps = self.env.max_episode_steps 29 | 30 | def rollout_once(self, random_policy=False, explore=False) -> Episode: 31 | """Performs a single rollout of an environment given a policy 32 | and returns an Episode instance. 33 | """ 34 | if self.policy is None and not random_policy: 35 | random_policy = True 36 | print('Policy is None. Using random policy instead!!') 37 | if not random_policy: 38 | self.policy.reset() 39 | eps = self.episode_gen() 40 | obs = self.env.reset() 41 | des = f'{self.name} Ts' 42 | for _ in trange(self.max_episode_steps, desc=des, leave=False): 43 | if random_policy: 44 | act = self.env.sample_random_action() 45 | else: 46 | act = self.policy.poll(obs.to(self.device), explore).flatten() 47 | nobs, reward, terminal, _ = self.env.step(act) 48 | eps.append(obs, act, reward, terminal) 49 | obs = nobs 50 | eps.terminate(nobs) 51 | return eps 52 | 53 | def rollout_n(self, n=1, random_policy=False) -> [Episode]: 54 | """ 55 | Performs n rollouts. 56 | """ 57 | if self.policy is None and not random_policy: 58 | random_policy = True 59 | print('Policy is None. Using random policy instead!!') 60 | des = f'{self.name} EPS' 61 | ret = [] 62 | for _ in trange(n, desc=des, leave=False): 63 | ret.append(self.rollout_once(random_policy=random_policy)) 64 | return ret 65 | 66 | def rollout_eval_n(self, n): 67 | metrics = defaultdict(list) 68 | episodes, frames = [], [] 69 | for _ in range(n): 70 | e, f, m = self.rollout_eval() 71 | episodes.append(e) 72 | frames.append(f) 73 | for k, v in m.items(): 74 | metrics[k].append(v) 75 | return episodes, frames, metrics 76 | 77 | def rollout_eval(self): 78 | assert self.policy is not None, 'Policy is None!!' 
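# Evaluation rollout: at each step the planner picks an action from the current
# belief, the decoder reconstructs an image from that belief (h, s), and the
# reconstruction error and predicted-vs-actual reward are logged. Each stored
# frame pairs the true observation (left) with the decoded image (right),
# which is what the GIFs in the README show.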
79 | self.policy.reset() 80 | eps = self.episode_gen() 81 | obs = self.env.reset() 82 | des = f'{self.name} Eval Ts' 83 | frames = [] 84 | metrics = {} 85 | rec_losses = [] 86 | pred_r, act_r = [], [] 87 | eps_reward = 0 88 | for _ in trange(self.max_episode_steps, desc=des, leave=False): 89 | with torch.no_grad(): 90 | act = self.policy.poll(obs.to(self.device)).flatten() 91 | dec = self.policy.rssm.decoder( 92 | self.policy.h, 93 | self.policy.s 94 | ).squeeze().cpu().clamp_(-0.5, 0.5) 95 | rec_losses.append(((obs - dec).abs()).sum().item()) 96 | frames.append(make_grid([obs + 0.5, dec + 0.5], nrow=2).numpy()) 97 | pred_r.append(self.policy.rssm.pred_reward( 98 | self.policy.h, self.policy.s 99 | ).cpu().flatten().item()) 100 | nobs, reward, terminal, _ = self.env.step(act) 101 | eps.append(obs, act, reward, terminal) 102 | act_r.append(reward) 103 | eps_reward += reward 104 | obs = nobs 105 | eps.terminate(nobs) 106 | metrics['eval/episode_reward'] = eps_reward 107 | metrics['eval/reconstruction_loss'] = rec_losses 108 | metrics['eval/reward_pred_loss'] = abs( 109 | np.array(act_r)[:-1] - np.array(pred_r)[1:] 110 | ) 111 | return eps, np.stack(frames), metrics 112 | -------------------------------------------------------------------------------- /rssm_main_latent_overshooting.py: -------------------------------------------------------------------------------- 1 | # import gym 2 | import pdb 3 | import pickle 4 | 5 | from utils import * 6 | from memory import * 7 | from rssm_model import * 8 | from tqdm import trange 9 | from collections import defaultdict 10 | from torch.distributions import Normal 11 | from torch.distributions.kl import kl_divergence as kl_div 12 | 13 | from pprint import pprint 14 | BIT_DEPTH = 5 15 | FREE_NATS = 2 16 | STATE_SIZE = 200 17 | LATENT_SIZE = 30 18 | EMBEDDING_SIZE = 1024 19 | 20 | """ 21 | Training implementation as indicated in: 22 | Learning Latent Dynamics for Planning from Pixels 23 | arXiv:1811.04551 24 | 25 | (c.) The Latent Overshooting Method 26 | using only single step predictions. 27 | """ 28 | 29 | def train(memory, model, optimizer, record_grads=True): 30 | """ 31 | Trained using the Standard Variational Bound method indicated in Fig. 
3a 32 | """ 33 | model.train() 34 | metrics = defaultdict(list) 35 | if record_grads: 36 | metrics['grads'] = defaultdict(list) 37 | for _ in trange(10, desc='# Epoch: ', leave=False): 38 | (x, u, _, _), lens = memory.sample(32) 39 | states, priors, posteriors = model(x, u) 40 | prior_dists = [Normal(*p) for p in priors] 41 | posterior_dists = [Normal(*p) for p in posteriors] 42 | posterior_samples = [d.rsample() for d in posterior_dists] 43 | # Reconstruction Loss 44 | rx = model.decoder(states[0], posterior_samples[0]) 45 | iloss = (((x[:, 0] - rx)**2).sum((1, 2, 3))).mean() 46 | # KL Divergence 47 | kl = kl_div(prior_dists[0], posterior_dists[0]).sum(-1) 48 | kloss = torch.max(FREE_NATS, kl).mean() 49 | mask = get_mask(u[..., 0], lens).T 50 | for i in range(1, len(states)): 51 | rx = model.decoder(states[i], posterior_samples[i]) 52 | iloss += (((x[:, i] - rx)**2).sum((1, 2, 3))*mask[i-1]).mean() 53 | kl = kl_div(prior_dists[i], posterior_dists[i]).sum(-1) 54 | kloss += torch.max(FREE_NATS, (kl*mask[i-1])).mean() 55 | kloss /= len(states) 56 | iloss /= len(states) 57 | optimizer.zero_grad() 58 | (iloss + kloss).backward() 59 | nn.utils.clip_grad_norm_(model.parameters(), 100, norm_type=2) 60 | if record_grads: 61 | pprint({ 62 | k: 0 if x.grad is None else x.grad.mean().item() 63 | for k, x in dict(model.named_parameters()).items() 64 | }) 65 | metrics['kl_losses'].append(kloss.item()) 66 | metrics['rec_losses'].append(iloss.item()) 67 | optimizer.step() 68 | return metrics 69 | 70 | 71 | def evaluate(memory, model, path, eps): 72 | model.eval() 73 | (x, u, _, _), lens = memory.sample(1) 74 | states, priors, posteriors = model(x, u) 75 | states = torch.stack(states).squeeze() 76 | priors = Normal(*map(lambda x: torch.stack(x).squeeze(), zip(*priors))) 77 | posts = Normal(*map(lambda x: torch.stack(x).squeeze(), zip(*posteriors))) 78 | pred1 = model.decoder(states, priors.mean) 79 | pred2 = model.decoder(states, posts.mean) 80 | save_frames(x[0], pred1, pred2, f'{path}_{eps}') 81 | 82 | 83 | def main(): 84 | global FREE_NATS 85 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 86 | FREE_NATS = torch.full((1, ), FREE_NATS).to(device) 87 | rssm = RecurrentStateSpaceModel(1, STATE_SIZE, LATENT_SIZE, EMBEDDING_SIZE) 88 | rssm = rssm.to(device) 89 | optimizer = torch.optim.Adam(get_combined_params(rssm), lr=1e-3) 90 | 91 | test_data = load_memory('test_exp_replay.pth', device) 92 | train_data = load_memory('train_exp_replay.pth', device) 93 | 94 | global_metrics = defaultdict(list) 95 | for i in trange(1000, desc='# Episode: ', leave=False): 96 | metrics = train(train_data, rssm, optimizer, record_grads=False) 97 | for k, v in metrics.items(): 98 | global_metrics[k].extend(metrics[k]) 99 | plot_metrics(global_metrics, path='results/test_rssm', prefix='TRAIN_') 100 | if (i + 1) % 10 == 0: 101 | evaluate(test_data, rssm, 'results/test_rssm/eps', i + 1) 102 | if (i + 1) % 25 == 0: 103 | torch.save(rssm.state_dict(), f'results/test_rssm/ckpt_{i+1}.pth') 104 | pdb.set_trace() 105 | 106 | 107 | if __name__ == '__main__': 108 | main() -------------------------------------------------------------------------------- /rssm_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.distributions import Normal 6 | 7 | 8 | class VisualEncoder(nn.Module): 9 | def __init__(self, embedding_size, activation_function='relu'): 10 | super().__init__() 11 | 
self.act_fn = getattr(F, activation_function) 12 | self.embedding_size = embedding_size 13 | self.conv1 = nn.Conv2d(3, 32, 4, stride=2) 14 | self.conv2 = nn.Conv2d(32, 64, 4, stride=2) 15 | self.conv3 = nn.Conv2d(64, 128, 4, stride=2) 16 | self.conv4 = nn.Conv2d(128, 256, 4, stride=2) 17 | if embedding_size == 1024: 18 | self.fc = nn.Identity() 19 | else: 20 | self.fc = nn.Linear(1024, embedding_size) 21 | 22 | def forward(self, observation): 23 | hidden = self.act_fn(self.conv1(observation)) 24 | hidden = self.act_fn(self.conv2(hidden)) 25 | hidden = self.act_fn(self.conv3(hidden)) 26 | hidden = self.act_fn(self.conv4(hidden)) 27 | hidden = hidden.view(-1, 1024) 28 | hidden = self.fc(hidden) 29 | return hidden 30 | 31 | 32 | class VisualDecoder(nn.Module): 33 | def __init__(self, 34 | state_size, 35 | latent_size, 36 | embedding_size, 37 | activation_function='relu' 38 | ): 39 | super().__init__() 40 | self.act_fn = getattr(F, activation_function) 41 | self.embedding_size = embedding_size 42 | self.fc1 = nn.Linear(latent_size + state_size, embedding_size) 43 | self.conv1 = nn.ConvTranspose2d(embedding_size, 128, 5, stride=2) 44 | self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2) 45 | self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2) 46 | self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2) 47 | 48 | def forward(self, state, latent): 49 | hidden = self.fc1(torch.cat([state, latent], dim=1)) 50 | hidden = hidden.view(-1, self.embedding_size, 1, 1) 51 | hidden = self.act_fn(self.conv1(hidden)) 52 | hidden = self.act_fn(self.conv2(hidden)) 53 | hidden = self.act_fn(self.conv3(hidden)) 54 | observation = self.conv4(hidden) 55 | return observation 56 | 57 | 58 | class RecurrentStateSpaceModel(nn.Module): 59 | """Recurrent State Space Model 60 | """ 61 | 62 | def __init__(self, 63 | action_size, 64 | state_size=200, 65 | latent_size=30, 66 | hidden_size=200, 67 | embed_size=1024, 68 | activation_function='relu' 69 | ): 70 | super().__init__() 71 | self.state_size = state_size 72 | self.action_size = action_size 73 | self.latent_size = latent_size 74 | self.act_fn = getattr(F, activation_function) 75 | self.encoder = VisualEncoder(embed_size) 76 | self.decoder = VisualDecoder(state_size, latent_size, embed_size) 77 | self.grucell = nn.GRUCell(state_size, state_size) 78 | self.lat_act_layer = nn.Linear(latent_size + action_size, state_size) 79 | self.fc_prior_1 = nn.Linear(state_size, hidden_size) 80 | self.fc_prior_m = nn.Linear(hidden_size, latent_size) 81 | self.fc_prior_s = nn.Linear(hidden_size, latent_size) 82 | self.fc_posterior_1 = nn.Linear(state_size + embed_size, hidden_size) 83 | self.fc_posterior_m = nn.Linear(hidden_size, latent_size) 84 | self.fc_posterior_s = nn.Linear(hidden_size, latent_size) 85 | self.fc_reward_1 = nn.Linear(state_size + latent_size, hidden_size) 86 | self.fc_reward_2 = nn.Linear(hidden_size, hidden_size) 87 | self.fc_reward_3 = nn.Linear(hidden_size, 1) 88 | 89 | 90 | def get_init_state(self, enc, h_t=None, s_t=None, a_t=None, mean=False): 91 | """Returns the initial posterior given the observation.""" 92 | N, dev = enc.size(0), enc.device 93 | h_t = torch.zeros(N, self.state_size).to(dev) if h_t is None else h_t 94 | s_t = torch.zeros(N, self.latent_size).to(dev) if s_t is None else s_t 95 | a_t = torch.zeros(N, self.action_size).to(dev) if a_t is None else a_t 96 | h_tp1 = self.deterministic_state_fwd(h_t, s_t, a_t) 97 | if mean: 98 | s_tp1 = self.state_posterior(h_t, enc, sample=True) 99 | else: 100 | s_tp1, _ = self.state_posterior(h_t, enc) 101 | 
return h_tp1, s_tp1 102 | 103 | def deterministic_state_fwd(self, h_t, s_t, a_t): 104 | """Returns the deterministic state given the previous states 105 | and action. 106 | """ 107 | h = torch.cat([s_t, a_t], dim=-1) 108 | h = self.act_fn(self.lat_act_layer(h)) 109 | return self.grucell(h, h_t) 110 | 111 | def state_prior(self, h_t, sample=False): 112 | """Returns the state prior given the deterministic state.""" 113 | z = self.act_fn(self.fc_prior_1(h_t)) 114 | m = self.fc_prior_m(z) 115 | s = F.softplus(self.fc_prior_s(z)) + 1e-1 116 | if sample: 117 | return m + torch.randn_like(m) * s 118 | return m, s 119 | 120 | def state_posterior(self, h_t, e_t, sample=False): 121 | """Returns the state prior given the deterministic state and obs.""" 122 | z = torch.cat([h_t, e_t], dim=-1) 123 | z = self.act_fn(self.fc_posterior_1(z)) 124 | m = self.fc_posterior_m(z) 125 | s = F.softplus(self.fc_posterior_s(z)) + 1e-1 126 | if sample: 127 | return m + torch.randn_like(m) * s 128 | return m, s 129 | 130 | def pred_reward(self, h_t, s_t): 131 | r = self.act_fn(self.fc_reward_1(torch.cat([h_t, s_t], dim=-1))) 132 | r = self.act_fn(self.fc_reward_2(r)) 133 | return self.fc_reward_3(r).squeeze() 134 | 135 | def rollout_prior(self, act, h_t, s_t): 136 | states, latents = [], [] 137 | for a_t in torch.unbind(act, dim=0): 138 | h_t = self.deterministic_state_fwd(h_t, s_t, a_t) 139 | s_t = self.state_prior(h_t) 140 | states.append(h_t) 141 | latents.append(s_t) 142 | Normal(*map(torch.stack, zip(*s_t))) 143 | return torch.stack(states), torch.stack(latents) 144 | -------------------------------------------------------------------------------- /rssm_policy.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch 3 | 4 | from torch.distributions import Normal 5 | 6 | 7 | class RSSMPolicy: 8 | def __init__(self, 9 | model, 10 | planning_horizon, 11 | num_candidates, 12 | num_iterations, 13 | top_candidates, 14 | device 15 | ): 16 | super().__init__() 17 | self.rssm = model 18 | self.N = num_candidates 19 | self.K = top_candidates 20 | self.T = num_iterations 21 | self.H = planning_horizon 22 | self.d = self.rssm.action_size 23 | self.device = device 24 | self.state_size = self.rssm.state_size 25 | self.latent_size = self.rssm.latent_size 26 | 27 | def reset(self): 28 | self.h = torch.zeros(1, self.state_size).to(self.device) 29 | self.s = torch.zeros(1, self.latent_size).to(self.device) 30 | self.a = torch.zeros(1, self.d).to(self.device) 31 | 32 | def _poll(self, obs): 33 | self.mu = torch.zeros(self.H, self.d).to(self.device) 34 | self.stddev = torch.ones(self.H, self.d).to(self.device) 35 | # observation could be of shape [CHW] but only 1 timestep 36 | assert len(obs.shape) == 3, 'obs should be [CHW]' 37 | self.h, self.s = self.rssm.get_init_state( 38 | self.rssm.encoder(obs[None]), 39 | self.h, self.s, self.a 40 | ) 41 | for _ in range(self.T): 42 | rwds = torch.zeros(self.N).to(self.device) 43 | actions = Normal(self.mu, self.stddev).sample((self.N,)) 44 | h_t = self.h.clone().expand(self.N, -1) 45 | s_t = self.s.clone().expand(self.N, -1) 46 | for a_t in torch.unbind(actions, dim=1): 47 | h_t = self.rssm.deterministic_state_fwd(h_t, s_t, a_t) 48 | s_t = self.rssm.state_prior(h_t, sample=True) 49 | rwds += self.rssm.pred_reward(h_t, s_t) 50 | _, k = torch.topk(rwds, self.K, dim=0, largest=True, sorted=False) 51 | self.mu = actions[k].mean(dim=0) 52 | self.stddev = actions[k].std(dim=0, unbiased=False) 53 | self.a = self.mu[0:1] 54 | 55 | def 
poll(self, observation, explore=False): 56 | with torch.no_grad(): 57 | self._poll(observation) 58 | if explore: 59 | self.a += torch.randn_like(self.a)*0.3 60 | return self.a 61 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pdb 4 | import cv2 5 | import gym 6 | import torch 7 | import pickle 8 | import pathlib 9 | import numpy as np 10 | 11 | from collections import defaultdict 12 | 13 | from torch.utils.tensorboard import SummaryWriter 14 | from torchvision.utils import make_grid, save_image 15 | 16 | 17 | def to_tensor_obs(image): 18 | """ 19 | Converts the input np img to channel first 64x64 dim torch img. 20 | """ 21 | image = cv2.resize(image, (64, 64), interpolation=cv2.INTER_LINEAR) 22 | image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) 23 | return image 24 | 25 | 26 | def postprocess_img(image, depth): 27 | """ 28 | Postprocess an image observation for storage. 29 | From float32 numpy array [-0.5, 0.5] to uint8 numpy array [0, 255]) 30 | """ 31 | image = np.floor((image + 0.5) * 2 ** depth) 32 | return np.clip(image * 2**(8 - depth), 0, 2**8 - 1).astype(np.uint8) 33 | 34 | 35 | def preprocess_img(image, depth): 36 | """ 37 | Preprocesses an observation inplace. 38 | From float32 Tensor [0, 255] to [-0.5, 0.5] 39 | Also adds some noise to the observations !! 40 | """ 41 | image.div_(2 ** (8 - depth)).floor_().div_(2 ** depth).sub_(0.5) 42 | image.add_(torch.randn_like(image).div_(2 ** depth)).clamp_(-0.5, 0.5) 43 | 44 | 45 | def bottle(func, *tensors): 46 | """ 47 | Evaluates a func that operates in N x D with inputs of shape N x T x D 48 | """ 49 | n, t = tensors[0].shape[:2] 50 | out = func(*[x.view(n*t, *x.shape[2:]) for x in tensors]) 51 | return out.view(n, t, *out.shape[1:]) 52 | 53 | 54 | def get_combined_params(*models): 55 | """ 56 | Returns the combine parameter list of all the models given as input. 57 | """ 58 | params = [] 59 | for model in models: 60 | params.extend(list(model.parameters())) 61 | return params 62 | 63 | 64 | def save_video(frames, path, name): 65 | """ 66 | Saves a video containing frames. 67 | """ 68 | frames = (frames*255).clip(0, 255).astype('uint8').transpose(0, 2, 3, 1) 69 | _, H, W, _ = frames.shape 70 | writer = cv2.VideoWriter( 71 | str(pathlib.Path(path)/f'{name}.mp4'), 72 | cv2.VideoWriter_fourcc(*'mp4v'), 25., (W, H), True 73 | ) 74 | for frame in frames[..., ::-1]: 75 | writer.write(frame) 76 | writer.release() 77 | 78 | 79 | def save_frames(target, pred_prior, pred_posterior, name, n_rows=5): 80 | """ 81 | Saves the target images with the generated prior and posterior predictions. 82 | """ 83 | image = torch.cat([target, pred_prior, pred_posterior], dim=3) 84 | save_image(make_grid(image + 0.5, nrow=n_rows), f'{name}.png') 85 | 86 | 87 | def get_mask(tensor, lengths): 88 | """ 89 | Generates the masks for batches of sequences. 90 | Time should be the first axis. 91 | input: 92 | tensor: the tensor for which to generate the mask [N x T x ...] 93 | lengths: lengths of the seq. [N] 94 | """ 95 | mask = torch.zeros_like(tensor) 96 | for i in range(len(lengths)): 97 | mask[i, :lengths[i]] = 1. 98 | return mask 99 | 100 | 101 | def load_memory(path, device): 102 | """ 103 | Loads an experience replay buffer. 
104 | """ 105 | with open(path, 'rb') as f: 106 | memory = pickle.load(f) 107 | memory.device = device 108 | for e in memory.data: 109 | e.device = device 110 | return memory 111 | 112 | 113 | def flatten_dict(data, sep='.', prefix=''): 114 | """Flattens a nested dict into a dict. 115 | eg. {'a': 2, 'b': {'c': 20}} -> {'a': 2, 'b.c': 20} 116 | """ 117 | x = {} 118 | for key, val in data.items(): 119 | if isinstance(val, dict): 120 | x.update(flatten_dict(val, sep=sep, prefix=key)) 121 | else: 122 | x[f'{prefix}{sep}{key}'] = val 123 | return x 124 | 125 | 126 | class TensorBoardMetrics: 127 | """Plots and (optionally) stores metrics for an experiment. 128 | """ 129 | def __init__(self, path): 130 | self.writer = SummaryWriter(path) 131 | self.steps = defaultdict(lambda: 0) 132 | self.summary = {} 133 | 134 | def assign_type(self, key, val): 135 | if isinstance(val, (list, tuple)): 136 | fun = lambda k, x, s: self.writer.add_histogram(k, np.array(x), s) 137 | self.summary[key] = fun 138 | elif isinstance(val, (np.ndarray, torch.Tensor)): 139 | self.summary[key] = self.writer.add_histogram 140 | elif isinstance(val, float) or isinstance(val, int): 141 | self.summary[key] = self.writer.add_scalar 142 | else: 143 | raise ValueError(f'Datatype {type(val)} not allowed') 144 | 145 | def update(self, metrics: dict): 146 | metrics = flatten_dict(metrics) 147 | for key_dots, val in metrics.items(): 148 | key = key_dots.replace('.', '/') 149 | if self.summary.get(key, None) is None: 150 | self.assign_type(key, val) 151 | self.summary[key](key, val, self.steps[key]) 152 | self.steps[key] += 1 153 | 154 | 155 | def apply_model(model, inputs, ignore_dim=None): 156 | pass 157 | 158 | def plot_metrics(metrics, path, prefix): 159 | for key, val in metrics.items(): 160 | lineplot(np.arange(len(val)), val, f'{prefix}{key}', path) 161 | 162 | def lineplot(xs, ys, title, path='', xaxis='episode'): 163 | MAX_LINE = Line(color='rgb(0, 132, 180)', dash='dash') 164 | MIN_LINE = Line(color='rgb(0, 132, 180)', dash='dash') 165 | NO_LINE = Line(color='rgba(0, 0, 0, 0)') 166 | MEAN_LINE = Line(color='rgb(0, 172, 237)') 167 | std_colour = 'rgba(29, 202, 255, 0.2)' 168 | if isinstance(ys, dict): 169 | data = [] 170 | for key, val in ys.items(): 171 | xs = np.arange(len(val)) 172 | data.append(Scatter(x=xs, y=np.array(val), name=key)) 173 | elif np.asarray(ys, dtype=np.float32).ndim == 2: 174 | ys = np.asarray(ys, dtype=np.float32) 175 | ys_mean, ys_std = ys.mean(-1), ys.std(-1) 176 | ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std 177 | l_max = Scatter(x=xs, y=ys.max(-1), line=MAX_LINE, name='Max') 178 | l_min = Scatter(x=xs, y=ys.min(-1), line=MIN_LINE, name='Min') 179 | l_stu = Scatter(x=xs, y=ys_upper, line=NO_LINE, showlegend=False) 180 | l_mean = Scatter( 181 | x=xs, y=ys_mean, fill='tonexty', fillcolor=std_colour, 182 | line=MEAN_LINE, name='Mean' 183 | ) 184 | l_stl = Scatter( 185 | x=xs, y=ys_lower, fill='tonexty', fillcolor=std_colour, 186 | line=NO_LINE, name='-1 Std. Dev.', showlegend=False 187 | ) 188 | data = [l_stu, l_mean, l_stl, l_min, l_max] 189 | else: 190 | data = [Scatter(x=xs, y=ys, line=MEAN_LINE)] 191 | plotly.offline.plot({ 192 | 'data': data, 193 | 'layout': dict( 194 | title=title, 195 | xaxis={'title': xaxis}, 196 | yaxis={'title': title} 197 | ) 198 | }, filename=os.path.join(path, title + '.html'), auto_open=False 199 | ) 200 | 201 | 202 | 203 | class TorchImageEnvWrapper: 204 | """ 205 | Torch Env Wrapper that wraps a gym env and makes interactions using Tensors. 
206 | Also returns observations in image form. 207 | """ 208 | def __init__(self, env, bit_depth, observation_shape=None, act_rep=2): 209 | self.env = gym.make(env) 210 | self.bit_depth = bit_depth 211 | self.action_repeats = act_rep 212 | 213 | def reset(self): 214 | self.env.reset() 215 | x = to_tensor_obs(self.env.render(mode='rgb_array')) 216 | preprocess_img(x, self.bit_depth) 217 | return x 218 | 219 | def step(self, u): 220 | u, rwds = u.cpu().detach().numpy(), 0 221 | for _ in range(self.action_repeats): 222 | _, r, d, i = self.env.step(u) 223 | rwds += r 224 | x = to_tensor_obs(self.env.render(mode='rgb_array')) 225 | preprocess_img(x, self.bit_depth) 226 | return x, rwds, d, i 227 | 228 | def render(self): 229 | self.env.render() 230 | 231 | def close(self): 232 | self.env.close() 233 | 234 | @property 235 | def observation_size(self): 236 | return (3, 64, 64) 237 | 238 | @property 239 | def action_size(self): 240 | return self.env.action_space.shape[0] 241 | 242 | def sample_random_action(self): 243 | return torch.tensor(self.env.action_space.sample()) 244 | --------------------------------------------------------------------------------
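The README refers to an `eval.py` for evaluating a saved checkpoint, but no such file appears in the tree above. A minimal evaluation sketch using the pieces in this repository might look like the following; the file name `eval_sketch.py`, the checkpoint path `results/ckpt_100.pth`, and the output video name are illustrative assumptions, not part of the repository, while the rest reuses the `Pendulum-v0` setup from `main.py`.

```python
# eval_sketch.py -- hypothetical evaluation script (not part of the repo).
# Loads an RSSM checkpoint saved by main.py and runs one evaluation rollout,
# saving the side-by-side (observation | reconstruction) video.
import torch
from functools import partial

from utils import TorchImageEnvWrapper, postprocess_img, save_video
from memory import Episode
from rssm_model import RecurrentStateSpaceModel
from rssm_policy import RSSMPolicy
from rollout_generator import RolloutGenerator

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = TorchImageEnvWrapper('Pendulum-v0', bit_depth=5)
rssm = RecurrentStateSpaceModel(env.action_size).to(device)
# Checkpoint path is illustrative; main.py saves results/ckpt_<epoch>.pth.
rssm.load_state_dict(torch.load('results/ckpt_100.pth', map_location=device))
rssm.eval()

policy = RSSMPolicy(
    rssm, planning_horizon=20, num_candidates=1000,
    num_iterations=10, top_candidates=100, device=device
)
rollout_gen = RolloutGenerator(
    env, device, policy=policy,
    episode_gen=lambda: Episode(partial(postprocess_img, depth=5)),
    max_episode_steps=100,
)
# rollout_eval returns the episode, the stacked comparison frames, and metrics.
episode, frames, metrics = rollout_gen.rollout_eval()
save_video(frames, 'results', 'eval_vid')
print('episode reward:', metrics['eval/episode_reward'])
```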