├── .gitignore ├── LICENSE ├── README.md ├── atari ├── atari.py ├── common │ ├── __init__.py │ ├── load_cfg.py │ ├── optim.py │ ├── pad_dim.py │ └── timer.py ├── default.yaml ├── dqn │ ├── __init__.py │ ├── actor.py │ ├── algo.py │ ├── buffer.py │ ├── learner.py │ ├── model.py │ ├── prepare_obs.py │ └── sampler.py ├── emb_vis.py ├── eval.py ├── models │ └── .gitkeep ├── render.py ├── repre │ ├── cpc.py │ ├── inverse_dynamics.py │ ├── predictor.py │ ├── w_mse.py │ └── whitening.py ├── train.py └── train_emb.py ├── docker └── Dockerfile ├── montezuma.png └── pol ├── common ├── __init__.py ├── load_cfg.py ├── optim.py └── timer.py ├── default.yaml ├── dqn ├── __init__.py ├── actor.py ├── algo.py ├── buffer.py ├── learner.py └── model.py ├── env.py ├── eval.py ├── models └── .gitkeep ├── pol_env.py ├── predictor.py ├── render.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | **/wandb/** 3 | */models/*.pt 4 | */monitor_output/ 5 | */monitor_output/** 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Alexander Ermolov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Latent World Models For Intrinsically Motivated Exploration 2 | Official repository | [arXiv:2010.02302](https://arxiv.org/abs/2010.02302) | NeurIPS 2020 Spotlight 3 | 4 | [10m video presentation from NeurIPS](https://slideslive.com/38937965/latent-world-models-for-intrinsically-motivated-exploration) 5 | 6 | ![montezuma's revenge t-sne](https://raw.githubusercontent.com/htdt/lwm/master/montezuma.png) 7 | 8 | ## Installation 9 | The implementation is based on PyTorch. Logging works on [wandb.ai](https://wandb.ai/). See `docker/Dockerfile`. 10 | 11 | ## Usage 12 | After training, the resulting models will be saved as `models/dqn.pt`, `models/predictor.pt` etc. 13 | For evaluation, models will be loaded from the same filenames. 14 | 15 | #### Atari 16 | To reproduce LWM results from [Table 2](https://arxiv.org/abs/2010.02302): 17 | ```sh 18 | cd atari 19 | python -m train --env MontezumaRevenge --seed 0 20 | python -m eval --env MontezumaRevenge --seed 0 21 | ``` 22 | See `default.yaml` for detailed configuration. 
23 | 24 | To get trajectory plots as on Figure 3: 25 | ```sh 26 | cd atari 27 | # first train encoders for random agent 28 | python -m train_emb 29 | # next play the game with keyboard 30 | python -m emb_vis 31 | # see plot_*.png 32 | ``` 33 | 34 | #### Partially Observable Labyrinth 35 | To reproduce scores from Table 1: 36 | ```sh 37 | cd pol 38 | # DQN agent 39 | python -m train --size 3 40 | python -m eval --size 3 41 | 42 | # DQN + WM agent 43 | python -m train --size 3 --add_ri 44 | python -m eval --size 3 --add_ri 45 | 46 | # random agent 47 | python -m eval --size 3 --random 48 | ``` 49 | 50 | Code of the environment is in [pol/pol_env.py](https://github.com/htdt/lwm/blob/master/pol/pol_env.py), it extends `gym.Env` and can be used as usual: 51 | ```python 52 | from pol_env import PolEnv 53 | env = PolEnv(size=3) 54 | obs = env.reset() 55 | action = env.observation_space.sample() 56 | obs, reward, done, infos = env.step(action) 57 | env.render() 58 | ####### 59 | # # # 60 | # ### # 61 | # #@ # 62 | # # # # 63 | # # # 64 | ####### 65 | ``` 66 | 67 | ## Bibtex 68 | ``` 69 | @inproceedings{LWM, 70 | author = {Ermolov, Aleksandr and Sebe, Nicu}, 71 | booktitle = {Advances in Neural Information Processing Systems}, 72 | editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. 
Lin}, 73 | pages = {5565--5575}, 74 | publisher = {Curran Associates, Inc.}, 75 | title = {Latent World Models For Intrinsically Motivated Exploration}, 76 | volume = {33}, 77 | year = {2020} 78 | } 79 | ``` 80 | -------------------------------------------------------------------------------- /atari/atari.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from gym.spaces.box import Box 3 | from baselines import bench 4 | from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv 5 | from baselines.common.atari_wrappers import wrap_deepmind, make_atari 6 | from baselines.common.vec_env import VecEnvWrapper 7 | 8 | 9 | def make_vec_envs(name, num, seed=0, max_ep_len=100000): 10 | def make_env(rank): 11 | def _thunk(): 12 | full_name = f"{name}NoFrameskip-v4" 13 | env = make_atari(full_name, max_episode_steps=max_ep_len) 14 | env.seed(seed + rank) 15 | env = bench.Monitor(env, None) 16 | env = wrap_deepmind(env, episode_life=True, clip_rewards=False) 17 | return env 18 | 19 | return _thunk 20 | 21 | envs = [make_env(i) for i in range(num)] 22 | envs = ShmemVecEnv(envs, context="fork") 23 | envs = VecTorch(envs) 24 | return envs 25 | 26 | 27 | class VecTorch(VecEnvWrapper): 28 | def __init__(self, env): 29 | super(VecTorch, self).__init__(env) 30 | obs = self.observation_space.shape 31 | self.observation_space = Box(0, 255, [obs[2], obs[0], obs[1]], 32 | dtype=self.observation_space.dtype) 33 | 34 | def _convert_obs(self, x): 35 | return torch.from_numpy(x).permute(0, 3, 1, 2) 36 | 37 | def reset(self): 38 | return self._convert_obs(self.venv.reset()) 39 | 40 | def step_async(self, actions): 41 | assert len(actions.shape) == 2 42 | actions = actions[:, 0].cpu().numpy() 43 | self.venv.step_async(actions) 44 | 45 | def step_wait(self): 46 | obs, reward, done, info = self.venv.step_wait() 47 | obs = self._convert_obs(obs) 48 | reward = torch.from_numpy(reward)[..., None].float() 49 | done = torch.tensor(done, 
dtype=torch.uint8)[..., None] 50 | return obs, reward, done, info 51 | -------------------------------------------------------------------------------- /atari/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htdt/lwm/b370828a2c99a7aae8b37f25319eba351a2768cb/atari/common/__init__.py -------------------------------------------------------------------------------- /atari/common/load_cfg.py: -------------------------------------------------------------------------------- 1 | import re 2 | import yaml 3 | 4 | 5 | def replace_e_float(d): 6 | p = re.compile(r"^-?\d+(\.\d+)?e-?\d+$") 7 | for name, val in d.items(): 8 | if type(val) == dict: 9 | replace_e_float(val) 10 | elif type(val) == str and p.match(val): 11 | d[name] = float(val) 12 | 13 | 14 | def load_cfg(name, prefix="."): 15 | with open(f"{prefix}/{name}.yaml") as f: 16 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 17 | replace_e_float(cfg) 18 | return cfg 19 | -------------------------------------------------------------------------------- /atari/common/optim.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | import torch 4 | from torch.nn.utils import clip_grad_norm_ 5 | from torch.optim import Adam 6 | 7 | 8 | @dataclass 9 | class ParamOptim: 10 | params: List[torch.Tensor] 11 | lr: float = 1e-3 12 | eps: float = 1e-8 13 | clip_grad: float = None 14 | 15 | def __post_init__(self): 16 | self.optim = Adam(self.params, lr=self.lr, eps=self.eps) 17 | 18 | def step(self, loss): 19 | self.optim.zero_grad() 20 | loss.backward() 21 | if self.clip_grad is not None: 22 | clip_grad_norm_(self.params, self.clip_grad) 23 | self.optim.step() 24 | return loss 25 | -------------------------------------------------------------------------------- /atari/common/pad_dim.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def pad_dim(x, dim, size=1, value=0, left=False): 5 | p = [0] * len(x.shape) * 2 6 | p[-(dim + 1) * 2 + int(not left)] = size 7 | return F.pad(x, p, value=value) 8 | -------------------------------------------------------------------------------- /atari/common/timer.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | 6 | def timer_log(num_iter=1000): 7 | log = {} 8 | mean_t = defaultdict(list) 9 | t = mark = None 10 | while True: 11 | prev_t, prev_mark = t, mark 12 | mark = yield log 13 | t = time() 14 | if prev_mark is not None: 15 | mean_t[prev_mark].append(t - prev_t) 16 | 17 | if mark is None and len(list(mean_t.values())[0]) >= num_iter: 18 | log = {"time/" + k: np.mean(v) * 1000 for k, v in mean_t.items()} 19 | mean_t = defaultdict(list) 20 | else: 21 | log = {} 22 | -------------------------------------------------------------------------------- /atari/default.yaml: -------------------------------------------------------------------------------- 1 | optim: 2 | lr: 1e-4 3 | eps: 1e-3 4 | clip_grad: 40 5 | 6 | agent: 7 | actors: 128 8 | unroll: 80 9 | burnin: 40 10 | batch_size: 16 11 | frame_stack: 1 12 | n_step: 5 13 | gamma: 0.99 14 | target_tau: 0.005 15 | # eps: 0.01 16 | 17 | w_mse: 18 | lr: 5e-4 19 | frame_stack: 1 20 | spatial_shift: 4 21 | temporal_shift: 2 22 | emb_size: 32 23 | rnn_size: 256 24 | ri_momentum: 0.999 25 | ri_clamp: 10 26 | 27 | buffer: 28 | device: cpu 29 | size: 1e6 30 | warmup: 4e5 31 | prior_exp: 0.9 32 | importance_sampling_exp: 0.6 33 | 34 | train: 35 | frames: 5e7 36 | max_ep_len: 10000 37 | learner_every: 1 38 | w_mse_every: 1 39 | log_every: 100 40 | checkpoint_every: 1000 41 | -------------------------------------------------------------------------------- 
/atari/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import actor_iter 2 | from .model import DQN 3 | from .learner import Learner 4 | 5 | 6 | __all__ = ["actor_iter", "Learner", "DQN"] 7 | -------------------------------------------------------------------------------- /atari/dqn/actor.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | import torch 3 | import numpy as np 4 | from common.timer import timer_log 5 | 6 | 7 | def actor_iter(env, model, predictor, warmup, eps=None): 8 | minstep = int(warmup / env.num_envs) 9 | hx = ( 10 | torch.zeros(env.num_envs, 512, device=model.device) 11 | if model is not None 12 | else None 13 | ) 14 | hx_pred = None 15 | step = {"obs": env.reset()} 16 | 17 | timer = timer_log(100) 18 | next(timer) 19 | if eps is None: 20 | eps = (0.4 ** torch.linspace(1, 8, env.num_envs))[..., None] 21 | mean_reward, mean_len = [], [] 22 | log = {} 23 | 24 | for n_iter in count(): 25 | full_step = yield step, hx, log 26 | 27 | timer.send("actor/action") 28 | action = torch.randint(env.action_space.n, (env.num_envs, 1)) 29 | if n_iter >= minstep: 30 | with torch.no_grad(): 31 | _, ri, hx_pred = predictor.get_error(full_step, hx_pred) 32 | full_step = full_step[1:] 33 | full_step["reward"] += ri 34 | 35 | qs, hx = model(**full_step, hx=hx) 36 | action_greedy = qs[0].argmax(1)[..., None].cpu() 37 | x = torch.rand(env.num_envs, 1) > eps 38 | action[x] = action_greedy[x] 39 | 40 | timer.send("actor/env") 41 | # done = 1 means obs is first step of next episode 42 | # prev_obs + action = obs 43 | obs, reward, done, infos = env.step(action) 44 | 45 | log = timer.send(None) 46 | 47 | ep = [info["episode"] for info in infos if "episode" in info] 48 | mean_reward += [x["r"] for x in ep] 49 | mean_len += [x["l"] for x in ep] 50 | if len(mean_reward) >= env.num_envs: 51 | log = { 52 | "reward": np.mean(mean_reward), 53 | 
"len": np.mean(mean_len), 54 | **log, 55 | } 56 | mean_reward, mean_len, = [], [] 57 | if "episode" in infos[-1]: 58 | log = {"reward_last": infos[-1]["episode"]["r"], **log} 59 | 60 | step = {"obs": obs, "action": action, "reward": reward, "done": done} 61 | -------------------------------------------------------------------------------- /atari/dqn/algo.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | from common.pad_dim import pad_dim 4 | 5 | 6 | def vf_rescaling(x): 7 | eps = 1e-3 8 | return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x 9 | 10 | 11 | def inv_vf_rescaling(x): 12 | eps = 1e-3 13 | return torch.sign(x) * ( 14 | (((torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps))) - 1) / (2 * eps)) ** 2 15 | - 1 16 | ) 17 | 18 | 19 | def n_step_bellman_target(reward, done, q, gamma, n_step): 20 | mask = 1 - pad_dim(done, dim=0, size=n_step - 1) 21 | reward = pad_dim(reward, dim=0, size=n_step - 1) 22 | for i in range(n_step): 23 | q[i:] *= gamma * mask[: len(mask) - i] 24 | q[i:] += reward[: len(reward) - i] 25 | return q[n_step - 1 :] 26 | 27 | 28 | def get_td_error(batch, hx_start, model, model_t, cfg, need_stat=False): 29 | n_step = cfg["agent"]["n_step"] 30 | pf = cfg["agent"]["frame_stack"] - 1 # prefix for frame stack 31 | burnin = cfg["agent"]["burnin"] 32 | bellman_target = partial( 33 | n_step_bellman_target, 34 | gamma=cfg["agent"]["gamma"], 35 | n_step=cfg["agent"]["n_step"], 36 | ) 37 | 38 | if burnin > 0: 39 | with torch.no_grad(): 40 | hx = model(**batch[: burnin + pf], hx=hx_start, only_hx=True) 41 | hx_target = model_t(**batch[: burnin + 1 + pf], hx=hx_start, only_hx=True) 42 | else: 43 | hx = hx_target = None 44 | 45 | qs, _ = model(**batch[burnin:], hx=hx) 46 | 47 | with torch.no_grad(): 48 | qs_target, _ = model_t(**batch[burnin + 1 :], hx=hx_target) 49 | 50 | action = batch["action"][burnin + pf + 1 : -n_step + 1] 51 | reward = 
batch["reward"][burnin + pf + 1 : -n_step + 1] 52 | done = batch["done"][burnin + pf + 1 : -n_step + 1].float() 53 | 54 | q = qs[:-n_step].gather(2, action) 55 | ns_action = qs[1:].argmax(2)[..., None].detach() 56 | next_q = qs_target.gather(2, ns_action) 57 | next_q = inv_vf_rescaling(next_q) 58 | target_q = bellman_target(reward, done, next_q) 59 | target_q = vf_rescaling(target_q) 60 | td_error = (q - target_q).abs() 61 | 62 | if need_stat: 63 | log = { 64 | "loss": td_error.mean().item(), 65 | "q_mean": qs.mean().item(), 66 | "q_std": qs.std().item(), 67 | } 68 | else: 69 | log = {} 70 | return td_error, log 71 | -------------------------------------------------------------------------------- /atari/dqn/buffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DictWithSlicing(dict): 5 | def __getitem__(self, key): 6 | if isinstance(key, slice): 7 | return {k: v[key] for k, v in self.items()} 8 | return super().__getitem__(key) 9 | 10 | 11 | class Buffer: 12 | def __init__(self, maxlen, num_env, obs_shape, device): 13 | self.maxlen, self.num_env, self.device = maxlen, num_env, device 14 | self.reset() 15 | 16 | def tensor(shape=(1,), dtype=torch.float): 17 | return torch.empty( 18 | self.maxlen, self.num_env, *shape, dtype=dtype, device=self.device 19 | ) 20 | 21 | self._buffer = { 22 | "obs": tensor(obs_shape, torch.uint8), 23 | "action": tensor(dtype=torch.long), 24 | "reward": tensor(), 25 | "done": tensor(dtype=torch.uint8), 26 | } 27 | 28 | def query(self, idx, idx_env, steps, device="cuda"): 29 | qsize = len(idx) 30 | s = torch.arange(steps - 1, -1, -1) 31 | q0 = (idx[None, ...].repeat(steps, 1) - s[..., None]).flatten() 32 | q1 = idx_env[None, ...].repeat(steps, 1).flatten() 33 | return DictWithSlicing( 34 | { 35 | k: v[q0, q1].view(steps, qsize, *v.shape[2:]).to(device) 36 | for k, v in self._buffer.items() 37 | } 38 | ) 39 | 40 | def append(self, step): 41 | for k in self._buffer: 42 | 
if k not in step: 43 | self._buffer[k][self.cursor] = 0 44 | else: 45 | assert step[k].dtype == self._buffer[k].dtype 46 | assert step[k].shape == self._buffer[k].shape[1:] 47 | self._buffer[k][self.cursor] = step[k].to(self.device) 48 | self.cursor = (self.cursor + 1) % self.maxlen 49 | self._size = min(self.maxlen, self._size + 1) 50 | 51 | def get_recent(self, steps, device="cuda"): 52 | if len(self) == 0: 53 | return None 54 | idx = torch.tensor([self.cursor - 1] * self.num_env) 55 | idx_env = torch.arange(self.num_env) 56 | step = self.query(idx, idx_env, steps, device) 57 | if len(self) < steps: 58 | for el in step.values(): 59 | el[: steps - len(self)] = 0 60 | return step 61 | 62 | def reset(self): 63 | self.cursor = self._size = 0 64 | 65 | def __len__(self): 66 | return self._size 67 | -------------------------------------------------------------------------------- /atari/dqn/learner.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from functools import partial 3 | import random 4 | import torch 5 | 6 | from common.optim import ParamOptim 7 | from dqn.algo import get_td_error 8 | from dqn.sampler import Sampler 9 | 10 | 11 | def tde_to_prior(x, eta=0.9): 12 | return (eta * x.max(0).values + (1 - eta) * x.mean(0)).detach().cpu() 13 | 14 | 15 | class Learner: 16 | def __init__(self, model, buffer, predictor, cfg): 17 | num_env = cfg["agent"]["actors"] 18 | model_t = deepcopy(model) 19 | model_t = model_t.cuda().eval() 20 | self.model, self.model_t = model, model_t 21 | self.buffer = buffer 22 | self.predictor = predictor 23 | self.optim = ParamOptim(params=model.parameters(), **cfg["optim"]) 24 | 25 | self.batch_size = cfg["agent"]["batch_size"] 26 | self.unroll = cfg["agent"]["unroll"] 27 | self.unroll_prefix = ( 28 | cfg["agent"]["burnin"] 29 | + cfg["agent"]["n_step"] 30 | + cfg["agent"]["frame_stack"] 31 | - 1 32 | ) 33 | self.sample_steps = self.unroll_prefix + self.unroll 34 | 
self.hx_shift = cfg["agent"]["frame_stack"] - 1 35 | num_unrolls = (self.buffer.maxlen - self.unroll_prefix) // self.unroll 36 | 37 | if cfg["buffer"]["prior_exp"] > 0: 38 | self.sampler = Sampler( 39 | num_env=num_env, 40 | maxlen=num_unrolls, 41 | prior_exp=cfg["buffer"]["prior_exp"], 42 | importance_sampling_exp=cfg["buffer"]["importance_sampling_exp"], 43 | ) 44 | self.s2b = torch.empty(num_unrolls, dtype=torch.long) 45 | self.hxs = torch.empty(num_unrolls, num_env, 512, device="cuda") 46 | self.hx_cursor = 0 47 | else: 48 | self.sampler = None 49 | 50 | self.target_tau = cfg["agent"]["target_tau"] 51 | self.td_error = partial(get_td_error, model=model, model_t=model_t, cfg=cfg) 52 | 53 | def _update_target(self): 54 | for t, s in zip(self.model_t.parameters(), self.model.parameters()): 55 | t.data.copy_(t.data * (1.0 - self.target_tau) + s.data * self.target_tau) 56 | 57 | def append(self, step, hx, n_iter): 58 | self.buffer.append(step) 59 | 60 | if self.sampler is not None: 61 | if (n_iter + 1) % self.unroll == self.hx_shift: 62 | self.hxs[self.hx_cursor] = hx 63 | self.hx_cursor = (self.hx_cursor + 1) % len(self.hxs) 64 | 65 | k = n_iter - self.unroll_prefix 66 | if k > 0 and (k + 1) % self.unroll == 0: 67 | self.s2b[self.sampler.cursor] = self.buffer.cursor - 1 68 | x = self.buffer.get_recent(self.sample_steps) 69 | hx = self.hxs[self.sampler.cursor] 70 | with torch.no_grad(): 71 | loss, _ = self.td_error(x, hx) 72 | self.sampler.append(tde_to_prior(loss)) 73 | 74 | if len(self.sampler) == self.sampler.maxlen: 75 | idx_new = self.s2b[self.sampler.cursor - 1] 76 | idx_old = self.s2b[self.sampler.cursor] 77 | d = (idx_old - idx_new) % self.buffer.maxlen 78 | assert self.unroll_prefix + self.unroll <= d 79 | assert d < self.unroll_prefix + self.unroll * 2 80 | 81 | def loss_sampler(self, need_stat): 82 | idx0, idx1, weights = self.sampler.sample(self.batch_size) 83 | weights = weights.cuda() 84 | batch = self.buffer.query(self.s2b[idx0], idx1, 
self.sample_steps) 85 | hx = self.hxs[idx0, idx1] 86 | loss_pred, ri, _ = self.predictor.get_error(batch, update_stats=True) 87 | batch["reward"][1:] += ri 88 | td_error, log = self.td_error(batch, hx, need_stat=need_stat) 89 | self.sampler.update_prior(idx0, idx1, tde_to_prior(td_error)) 90 | loss = td_error.pow(2).sum(0) * weights[..., None] 91 | loss_pred = loss_pred.sum(0) * weights[..., None] 92 | return loss, loss_pred, ri, log 93 | 94 | def loss_uniform(self, need_stat): 95 | if len(self.buffer) < self.buffer.maxlen: 96 | no_prev = set(range(self.sample_steps)) 97 | else: 98 | no_prev = set( 99 | (self.buffer.cursor + i) % self.buffer.maxlen 100 | for i in range(self.sample_steps) 101 | ) 102 | all_idx = list(set(range(len(self.buffer))) - no_prev) 103 | idx0 = torch.tensor(random.choices(all_idx, k=self.batch_size)) 104 | idx1 = torch.tensor( 105 | random.choices(range(self.buffer.num_env), k=self.batch_size) 106 | ) 107 | batch = self.buffer.query(idx0, idx1, self.sample_steps) 108 | loss_pred, ri, _ = self.predictor.get_error(batch, update_stats=True) 109 | batch["reward"][1:] += ri 110 | td_error, log = self.td_error(batch, None, need_stat=need_stat) 111 | loss = td_error.pow(2).sum(0) 112 | loss_pred = loss_pred.sum(0) 113 | return loss, loss_pred, ri, log 114 | 115 | def train(self, need_stat=True): 116 | loss_f = self.loss_uniform if self.sampler is None else self.loss_sampler 117 | loss, loss_pred, ri, log = loss_f(need_stat) 118 | self.optim.step(loss.mean()) 119 | self.predictor.optim.step(loss_pred.mean()) 120 | self._update_target() 121 | 122 | if need_stat: 123 | log.update( 124 | { 125 | "ri_std": ri.std(), 126 | "ri_mean": ri.mean(), 127 | "ri_run_mean": self.predictor.ri_mean, 128 | "ri_run_std": self.predictor.ri_std, 129 | "loss_predictor": loss_pred.mean(), 130 | } 131 | ) 132 | if self.sampler is not None: 133 | log.update(self.sampler.stats()) 134 | return log 135 | 
-------------------------------------------------------------------------------- /atari/dqn/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import one_hot, relu 4 | from dqn.prepare_obs import prepare_obs 5 | 6 | 7 | def mnih_cnn(size_in, size_out): 8 | return nn.Sequential( 9 | nn.Conv2d(size_in, 32, 8, 4), 10 | nn.ReLU(), 11 | nn.Conv2d(32, 64, 4, 2), 12 | nn.ReLU(), 13 | nn.Conv2d(64, 64, 3, 1), 14 | nn.ReLU(), 15 | nn.Flatten(), 16 | nn.Linear(64 * 7 * 7, size_out), 17 | ) 18 | 19 | 20 | class DQN(nn.Module): 21 | def __init__(self, size_out, size_stack, device="cuda"): 22 | super(DQN, self).__init__() 23 | self.size_out = size_out 24 | self.size_stack = size_stack 25 | self.conv = mnih_cnn(size_stack, 512) 26 | self.rnn = nn.GRUCell(512 + 1 + size_out, 512) 27 | self.adv = nn.Sequential( 28 | nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, size_out, bias=False) 29 | ) 30 | self.val = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 1)) 31 | self.device = device 32 | 33 | def forward(self, obs, action, reward, done, hx=None, only_hx=False): 34 | obs = prepare_obs(obs, done, self.size_stack) 35 | steps, batch, *img_shape = obs.shape 36 | obs = obs.view(steps * batch, *img_shape) 37 | x = relu(self.conv(obs)) 38 | x = x.view(steps, batch, 512) 39 | 40 | pf = self.size_stack - 1 41 | mask = (1 - done[pf:]).float() 42 | a = one_hot(action[pf:, :, 0], self.size_out).float() * mask 43 | r = reward[pf:] * mask 44 | x = torch.cat([x, a, r], 2) 45 | 46 | y = torch.empty(steps, batch, 512, device=self.device) 47 | for i in range(steps): 48 | if hx is not None: 49 | hx *= mask[i] 50 | y[i] = hx = self.rnn(x[i], hx) 51 | hx = hx.clone().detach() 52 | if only_hx: 53 | return hx 54 | 55 | y = y.view(steps * batch, 512) 56 | adv, val = self.adv(y), self.val(y) 57 | q = val + adv - adv.mean(1, keepdim=True) 58 | q = q.view(steps, batch, self.size_out) 
59 | return q, hx 60 | -------------------------------------------------------------------------------- /atari/dqn/prepare_obs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def prepare_obs(obs, done, fstack): 5 | assert obs.dtype == torch.uint8 6 | assert obs.shape[2] == 1 7 | 8 | if fstack > 1: 9 | obs = stack_frames(obs, fstack) 10 | done_stacked = stack_frames(done, fstack) 11 | obs = obs * obs_mask(done_stacked) 12 | return obs.float() / 128 - 1 13 | 14 | 15 | def stack_frames(x, stack=4): 16 | """ 17 | Args: 18 | x: [steps + stack - 1, batch, 1, ...] - flat trajectory with prefix = stack - 1 19 | Returns: 20 | [steps, batch, stack, ...] - each step (dim 0) includes stack of frames (dim 2) 21 | """ 22 | shape = (x.shape[0] - stack + 1, x.shape[1], stack, *x.shape[3:]) 23 | y = torch.empty(shape, dtype=x.dtype, device=x.device) 24 | for i in range(stack): 25 | y[:, :, i] = x[i : shape[0] + i, :, 0] 26 | return y 27 | 28 | 29 | def obs_mask(done): 30 | """ 31 | mask to zero out observations in 4-frame stack when done = 1 32 | """ 33 | mask = 1 - done[:, :, 1:] 34 | for i in reversed(range(mask.shape[2] - 1)): 35 | mask[:, :, i] *= mask[:, :, i + 1] 36 | mask = torch.cat([mask, torch.ones_like(mask[:, :, -1:])], 2) 37 | mask = mask[..., None, None] 38 | return mask 39 | -------------------------------------------------------------------------------- /atari/dqn/sampler.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | 4 | 5 | @dataclass 6 | class Sampler: 7 | num_env: int 8 | maxlen: int 9 | prior_exp: float = 0.9 10 | importance_sampling_exp: float = 0.6 11 | 12 | def __post_init__(self): 13 | self.cursor = self._size = 0 14 | self._prior = torch.empty(self.maxlen, self.num_env, 1) 15 | 16 | def append(self, prior): 17 | assert prior.shape == self._prior.shape[1:] 18 | self._prior[self.cursor] = prior 19 
| self.cursor = (self.cursor + 1) % self.maxlen 20 | self._size = min(self.maxlen, self._size + 1) 21 | 22 | def sample(self, batch_size): 23 | p = self._prior[: len(self)].view(-1) ** self.prior_exp 24 | idx_flat = p.multinomial(batch_size, replacement=True) 25 | weights = p[idx_flat] ** -self.importance_sampling_exp 26 | weights /= weights.max() 27 | idx0 = idx_flat // self.num_env 28 | idx1 = idx_flat % self.num_env 29 | return idx0, idx1, weights 30 | 31 | def update_prior(self, idx0, idx1, prior): 32 | self._prior[idx0, idx1] = prior 33 | 34 | def stats(self): 35 | x = self._prior[: len(self)] 36 | return { 37 | "prior/mean": x.mean(), 38 | "prior/std": x.std(), 39 | "prior/max": x.max(), 40 | } 41 | 42 | def __len__(self): 43 | return self._size 44 | -------------------------------------------------------------------------------- /atari/emb_vis.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import time 4 | from gym.envs.atari.atari_env import ACTION_MEANING 5 | from sklearn.manifold import TSNE 6 | import matplotlib.pyplot as plt 7 | 8 | from atari import make_vec_envs 9 | from dqn.model import mnih_cnn 10 | 11 | 12 | ACTION_ID = {v: k for k, v in ACTION_MEANING.items()} 13 | KEY2ACTION = { 14 | "w": ACTION_ID["UP"], 15 | "s": ACTION_ID["DOWN"], 16 | "a": ACTION_ID["LEFT"], 17 | "d": ACTION_ID["RIGHT"], 18 | "f": ACTION_ID["FIRE"], 19 | "i": ACTION_ID["UPFIRE"], 20 | "k": ACTION_ID["DOWNFIRE"], 21 | "j": ACTION_ID["LEFTFIRE"], 22 | "l": ACTION_ID["RIGHTFIRE"], 23 | "u": ACTION_ID["UPLEFTFIRE"], 24 | "o": ACTION_ID["UPRIGHTFIRE"], 25 | } 26 | 27 | 28 | def convert_key(a): 29 | return KEY2ACTION.get(chr(a), ACTION_ID["NOOP"]) 30 | 31 | 32 | def key_press(key, mod): 33 | global cur_action, restart, pause 34 | if key == 0xFF0D: 35 | restart = True 36 | if key == 32: 37 | pause = not pause 38 | cur_action = convert_key(key) 39 | 40 | 41 | def key_release(key, mod): 42 | global cur_action 43 
| a = convert_key(key) 44 | if cur_action == a: 45 | cur_action = 0 46 | 47 | 48 | def plot(x, fname): 49 | path2d = TSNE().fit_transform(x) 50 | x, y = tuple(zip(*path2d)) 51 | plt.figure(figsize=(12, 12)) 52 | 53 | plt.scatter(x, y, c=list(range(len(x))), alpha=0.5, s=200) 54 | plt.savefig(fname) 55 | for i in range(0, len(x), 10): 56 | plt.text( 57 | x[i], 58 | y[i], 59 | str(i), 60 | horizontalalignment="center", 61 | verticalalignment="center", 62 | fontsize=14, 63 | ) 64 | plt.savefig("num_" + fname) 65 | 66 | 67 | if __name__ == "__main__": 68 | name = "MontezumaRevenge" if len(sys.argv) < 2 else sys.argv[1] 69 | env = make_vec_envs(name, 1) 70 | 71 | conv_wmse = mnih_cnn(1, 32) 72 | conv_idf = mnih_cnn(1, 32) 73 | conv_cpc = mnih_cnn(1, 32) 74 | conv_rnd = mnih_cnn(1, 32) 75 | conv_wmse.load_state_dict(torch.load("models/conv_wmse.pt", map_location="cpu")) 76 | conv_idf.load_state_dict(torch.load("models/conv_idf.pt", map_location="cpu")) 77 | conv_cpc.load_state_dict(torch.load("models/conv_cpc.pt", map_location="cpu")) 78 | conv_wmse.eval(), conv_idf.eval(), conv_cpc.eval(), conv_rnd.eval() 79 | 80 | mem = torch.empty(4, 1000, 32) 81 | cursor = 0 82 | 83 | env.render() 84 | env.unwrapped.viewer.window.on_key_press = key_press 85 | env.unwrapped.viewer.window.on_key_release = key_release 86 | window_still_open = True 87 | 88 | while window_still_open: 89 | cur_action = 0 90 | restart = False 91 | pause = False 92 | 93 | obs = env.reset() 94 | total_reward = steps = 0 95 | while 1: 96 | 97 | steps += 1 98 | if steps == 1000: 99 | break 100 | a = torch.tensor([[cur_action]]) 101 | obs, r, done, info = env.step(a) 102 | 103 | with torch.no_grad(): 104 | obs = obs.float() / 128 - 1 105 | mem[0, cursor] = conv_wmse(obs)[0] 106 | mem[1, cursor] = conv_idf(obs)[0] 107 | mem[2, cursor] = conv_cpc(obs)[0] 108 | mem[3, cursor] = conv_rnd(obs)[0] 109 | 110 | if cursor % 10 == 0: 111 | print(cursor) 112 | cursor += 1 113 | 114 | if r != 0: 115 | print(f"reward 
import argparse
import torch
import wandb

from dqn.buffer import Buffer
from common.load_cfg import load_cfg
from atari import make_vec_envs
from dqn import actor_iter, DQN
from repre.w_mse import WMSE
from repre.predictor import Predictor


if __name__ == "__main__":
    # Evaluate a trained LWM agent: load the dqn / w_mse / predictor
    # checkpoints, run the actor until an aggregated episode statistic is
    # available, and report it to wandb as "final_reward".
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--cfg", type=str, default="default")
    parser.add_argument("--env", type=str, default="MontezumaRevenge")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--ri_scale", type=float, default=1)
    p = parser.parse_args()
    cfg = load_cfg(p.cfg)
    # CLI flags override the YAML config
    cfg.update(vars(p))
    wandb.init(project="lwm", config=cfg)

    num_env = cfg["agent"]["actors"]
    fstack = cfg["agent"]["frame_stack"]
    envs = make_vec_envs(cfg["env"], num_env, cfg["seed"])

    # during evaluation the buffer only serves as a rolling window of
    # recent frames for frame stacking / the recurrent predictor
    buffer = Buffer(
        num_env=num_env,
        maxlen=int(cfg["buffer"]["size"] / num_env),
        obs_shape=envs.observation_space.shape,
        device=cfg["buffer"]["device"],
    )
    model = DQN(envs.action_space.n, fstack).cuda().train()
    wmse = WMSE(buffer, cfg)
    pred = Predictor(buffer, wmse.encoder, envs.action_space.n, cfg)
    # warmup=0: act (near-)greedily with eps=0.001 from the first step
    actor = actor_iter(envs, model, pred, 0, eps=0.001)

    wmse.load(), pred.load()
    cp = torch.load("models/dqn.pt", map_location="cuda")
    model.load_state_dict(cp)
    model.eval()

    while True:
        # NOTE(review): on the very first pass the buffer is empty, so
        # get_recent presumably returns None and send(None) doubles as
        # priming of the actor generator — confirm against Buffer.get_recent
        full_step = buffer.get_recent(fstack + 1)
        step, hx, log = actor.send(full_step)
        buffer.append(step)
        # the actor emits "reward" once enough episodes have finished
        if "reward" in log:
            wandb.log({"final_reward": log["reward"]})
            break

    wandb.save("models/dqn.pt")
    wandb.save("models/w_mse.pt")
    wandb.save("models/predictor.pt")
class CPCModel(nn.Module):
    """Contrastive Predictive Coding model: a CNN frame encoder plus a GRU
    that predicts the embedding of the next frame from the current
    embedding and the taken action. Used as a baseline encoder
    (see train_emb.py)."""

    def __init__(self, num_action, size_emb, size_stack, device="cuda"):
        super(CPCModel, self).__init__()
        self.size_emb = size_emb
        self.size_stack = size_stack
        self.num_action = num_action
        self.device = device

        self.conv = mnih_cnn(size_stack, size_emb)
        # GRU input: current embedding concatenated with one-hot action
        self.rnn = nn.GRUCell(num_action + size_emb, 512)
        self.fc = nn.Linear(512, size_emb)

    def forward(self, obs, action, done, hx=None, only_hx=False):
        """Return (target embeddings z[1:], predicted embeddings, final hx).

        obs/action/done are time-major batches sampled from the buffer;
        hx is the recurrent state carried across calls. With only_hx=True
        only the state is rolled forward and returned (burn-in).
        """
        # frame stacking; prepare_obs presumably handles episode
        # boundaries via done — confirm against dqn/prepare_obs.py
        obs = prepare_obs(obs, done, self.size_stack)
        steps, batch, *img_shape = obs.shape
        obs = obs.view(steps * batch, *img_shape)
        z = self.conv(obs).view(steps, batch, self.size_emb)

        # the first size_stack-1 "done" entries belong to the stacking prefix
        pf = self.size_stack - 1
        mask = (1 - done[pf:]).float()
        a = one_hot(action[:, :, 0], self.num_action).float()
        x = torch.cat([relu(z[:-1]), a], 2)

        steps -= 1
        y = torch.empty(steps, batch, 512, device=self.device)
        for i in range(steps):
            if hx is not None:
                # zero the recurrent state where an episode ended
                hx *= mask[i]
            y[i] = hx = self.rnn(x[i], hx)
        # detach so the state can be carried across training iterations
        # without growing the autograd graph
        hx = hx.clone().detach()
        if only_hx:
            return hx

        y = y.view(steps * batch, 512)
        z_pred = self.fc(y).view(steps, batch, self.size_emb)
        return z[1:], z_pred, hx
acc} 106 | 107 | def load(self): 108 | cp = torch.load("models/cpc.pt", map_location=self.device) 109 | self.model.load_state_dict(cp) 110 | 111 | def save(self): 112 | torch.save(self.model.state_dict(), "models/cpc.pt") 113 | -------------------------------------------------------------------------------- /atari/repre/inverse_dynamics.py: -------------------------------------------------------------------------------- 1 | import random 2 | from itertools import chain 3 | from dataclasses import dataclass 4 | import torch 5 | from torch import nn 6 | from torch.nn.functional import cross_entropy, relu 7 | from common.optim import ParamOptim 8 | from dqn.model import mnih_cnn 9 | from dqn.buffer import Buffer 10 | from dqn.prepare_obs import prepare_obs 11 | 12 | 13 | @dataclass 14 | class IDF: 15 | buffer: Buffer 16 | num_action: int 17 | emb_size: int = 32 18 | batch_size: int = 256 19 | lr: float = 5e-4 20 | frame_stack: int = 1 21 | device: str = "cuda" 22 | 23 | def __post_init__(self): 24 | self.encoder = mnih_cnn(self.frame_stack, self.emb_size) 25 | self.encoder = self.encoder.to(self.device).train() 26 | self.clf = nn.Sequential( 27 | nn.Linear(self.emb_size * 2, 128), 28 | nn.ReLU(), 29 | nn.Linear(128, self.num_action), 30 | ) 31 | self.clf = self.clf.to(self.device).train() 32 | params = chain(self.encoder.parameters(), self.clf.parameters()) 33 | self.optim = ParamOptim(lr=self.lr, params=params) 34 | 35 | def train(self): 36 | # 0 1 2 3 4 37 | # p p p o o 38 | # a 39 | 40 | sample_steps = self.frame_stack + 1 41 | if len(self.buffer) < self.buffer.maxlen: 42 | no_prev = set(range(sample_steps)) 43 | else: 44 | no_prev = set( 45 | (self.buffer.cursor + i) % self.buffer.maxlen 46 | for i in range(sample_steps) 47 | ) 48 | all_idx = list(set(range(len(self.buffer))) - no_prev) 49 | idx0 = torch.tensor(random.choices(all_idx, k=self.batch_size)) 50 | idx1 = torch.tensor( 51 | random.choices(range(self.buffer.num_env), k=self.batch_size) 52 | ) 53 | batch 
class PredictorModel(nn.Module):
    """Recurrent latent world model.

    Given a sequence of frame embeddings and one-hot actions it predicts
    the embedding of the next frame; the prediction error serves as the
    intrinsic reward.
    """

    def __init__(self, num_action, fstack, emb_size, rnn_size):
        super(PredictorModel, self).__init__()
        self.fstack = fstack
        self.num_action = num_action
        self.rnn_size = rnn_size
        # module creation order (emb_fc -> rnn -> fc) kept stable so that
        # parameter initialization under a fixed seed is unchanged
        self.emb_fc = nn.Linear(emb_size + num_action, 128)
        self.rnn = nn.GRUCell(128, rnn_size)
        self.fc = nn.Sequential(
            nn.Linear(rnn_size, rnn_size),
            nn.ReLU(),
            nn.Linear(rnn_size, emb_size),
        )

    def forward(self, z, action, done, hx=None):
        """Return (predicted next embeddings, detached final hidden state).

        z: (steps, batch, emb_size) embeddings; action: (steps, batch, 1)
        long; done: (steps, batch, 1) flags used to reset the state.
        """
        n_steps, n_batch, n_emb = z.shape
        onehot_a = one_hot(action[:, :, 0], self.num_action).float()
        inp = torch.cat([z, onehot_a], dim=2)
        inp = self.emb_fc(inp.view(n_steps * n_batch, n_emb + self.num_action))
        inp = inp.view(n_steps, n_batch, 128)

        keep = 1 - done.float()
        hidden_seq = torch.empty(n_steps, n_batch, self.rnn_size, device=z.device)
        for t in range(n_steps):
            if hx is not None:
                # zero the state at episode boundaries
                hx *= keep[t]
            hx = self.rnn(inp[t], hx)
            hidden_seq[t] = hx
        # detach so the state can be carried across iterations
        hx = hx.clone().detach()
        out = self.fc(hidden_seq.view(n_steps * n_batch, self.rnn_size))
        return out.view(n_steps, n_batch, n_emb), hx
class WMSE:
    """W-MSE self-supervised representation learning.

    Trains the CNN encoder so that whitened embeddings of two augmented
    views (temporal shift + random spatial roll) of the same state agree
    under MSE.
    """

    batch_size: int = 256
    lr: float = 5e-4

    def __init__(self, buffer, cfg, device="cuda"):
        self.device = device
        self.buffer = buffer
        w_cfg = cfg["w_mse"]
        self.emb_size = w_cfg["emb_size"]
        self.temporal_shift = w_cfg["temporal_shift"]
        self.spatial_shift = w_cfg["spatial_shift"]
        self.frame_stack = w_cfg["frame_stack"]

        self.encoder = mnih_cnn(self.frame_stack, self.emb_size)
        self.encoder = self.encoder.to(self.device).train()
        self.optim = ParamOptim(lr=self.lr, params=self.encoder.parameters())
        # batch statistics only; no running stats at inference
        self.w = Whitening2d(self.emb_size, track_running_stats=False)

    def load(self):
        state = torch.load("models/w_mse.pt", map_location=self.device)
        self.encoder.load_state_dict(state)

    def save(self):
        torch.save(self.encoder.state_dict(), "models/w_mse.pt")

    def train(self):
        """One optimization step on a random batch of two augmented views."""
        span = self.frame_stack + self.temporal_shift
        # exclude positions whose history would wrap over the ring-buffer
        # cursor (their preceding frames belong to a different trajectory)
        if len(self.buffer) < self.buffer.maxlen:
            invalid = set(range(span))
        else:
            invalid = set(
                (self.buffer.cursor + i) % self.buffer.maxlen
                for i in range(span)
            )
        valid = list(set(range(len(self.buffer))) - invalid)
        # RNG draws kept in the original order: time idx, env idx, offsets
        time_idx = torch.tensor(random.choices(valid, k=self.batch_size))
        env_idx = torch.tensor(
            random.choices(range(self.buffer.num_env), k=self.batch_size)
        )
        batch = self.buffer.query(time_idx, env_idx, span)
        obs = prepare_obs(batch["obs"], batch["done"], self.frame_stack)
        # second view: the same state a few steps later
        offsets = random.choices(range(1, self.temporal_shift + 1), k=self.batch_size)
        view_a = obs[0]
        view_b = obs[offsets, range(self.batch_size)]

        if self.spatial_shift > 0:
            lim = self.spatial_shift
            # independent random spatial roll per sample and per view
            for n in range(self.batch_size):
                for view in [view_a, view_b]:
                    dy = random.randrange(-lim, lim + 1)
                    dx = random.randrange(-lim, lim + 1)
                    view[n] = torch.roll(view[n], shifts=(dy, dx), dims=(-2, -1))

        emb_a, emb_b = self.encoder(view_a), self.encoder(view_b)
        z = self.w(torch.cat([emb_a, emb_b], dim=0))
        half = len(view_a)
        loss = (z[:half] - z[half:]).pow(2).mean()
        self.optim.step(loss)
        return {"loss_wmse": loss.item()}
class Whitening2d(nn.Module):
    """Batch whitening layer (ZCA-style via Cholesky) used by the W-MSE loss.

    Centers the batch and multiplies by the inverse Cholesky factor of the
    (shrunk) covariance, so the output features are decorrelated with unit
    variance. With track_running_stats=True, running mean/covariance are
    maintained during training and used at inference.
    """

    def __init__(self, num_features, momentum=0.01, track_running_stats=True, eps=0):
        super(Whitening2d, self).__init__()
        self.num_features = num_features
        # weight of the *current* batch in the running-stats update
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        # shrinkage coefficient toward the identity (0 = no shrinkage)
        self.eps = eps

        if self.track_running_stats:
            self.register_buffer(
                "running_mean", torch.zeros([1, self.num_features, 1, 1])
            )
            self.register_buffer("running_variance", torch.eye(self.num_features))

    def forward(self, x):
        # input is (batch, num_features); add fake spatial dims so the
        # whitening transform can be applied with a 1x1 conv2d
        x = x.unsqueeze(2).unsqueeze(3)
        m = x.mean(0).view(self.num_features, -1).mean(-1).view(1, -1, 1, 1)
        if not self.training and self.track_running_stats:  # for inference
            m = self.running_mean
        xn = x - m

        # feature covariance of the centered batch
        T = xn.permute(1, 0, 2, 3).contiguous().view(self.num_features, -1)
        f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1)

        eye = torch.eye(self.num_features).type(f_cov.type())

        if not self.training and self.track_running_stats:  # for inference
            f_cov = self.running_variance

        # shrink toward identity for numerical stability of the Cholesky
        f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye

        # inverse of the lower Cholesky factor: L^{-1} from L L^T = cov
        # NOTE(review): torch.cholesky / torch.triangular_solve are the
        # pre-1.8 APIs, matching the torch 1.4 image in docker/Dockerfile
        inv_sqrt = torch.triangular_solve(
            eye, torch.cholesky(f_cov_shrinked), upper=False
        )[0]
        inv_sqrt = inv_sqrt.contiguous().view(
            self.num_features, self.num_features, 1, 1
        )

        # apply the whitening matrix as a 1x1 convolution
        decorrelated = conv2d(xn, inv_sqrt)

        if self.training and self.track_running_stats:
            # exponential moving average, written in place via out=
            self.running_mean = torch.add(
                self.momentum * m.detach(),
                (1 - self.momentum) * self.running_mean,
                out=self.running_mean,
            )
            self.running_variance = torch.add(
                self.momentum * f_cov.detach(),
                (1 - self.momentum) * self.running_variance,
                out=self.running_variance,
            )

        # drop the fake spatial dims again -> (batch, num_features)
        return decorrelated.squeeze(2).squeeze(2)

    def extra_repr(self):
        return "features={}, eps={}, momentum={}".format(
            self.num_features, self.eps, self.momentum
        )
import argparse
from tqdm import trange
import wandb
import torch

from dqn.buffer import Buffer
from common.load_cfg import load_cfg
from atari import make_vec_envs
from dqn import actor_iter
from repre.w_mse import WMSE
from repre.inverse_dynamics import IDF
from repre.cpc import CPC


if __name__ == "__main__":
    # Train the three representation-learning encoders (W-MSE, inverse
    # dynamics, CPC) on a buffer filled by a purely random policy, then
    # save each conv encoder separately for visualization (emb_vis.py).
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--cfg", type=str, default="default")
    parser.add_argument("--env", type=str, default="MontezumaRevenge")
    p = parser.parse_args()
    cfg = load_cfg(p.cfg)
    # CLI flags override the YAML config
    cfg.update(vars(p))
    wandb.init(project="lwm", config=cfg)

    num_env = cfg["agent"]["actors"]
    fstack = cfg["agent"]["frame_stack"]
    envs = make_vec_envs(cfg["env"], num_env, max_ep_len=cfg["train"]["max_ep_len"])
    num_action = envs.action_space.n

    buffer = Buffer(
        num_env=num_env,
        maxlen=int(cfg["buffer"]["size"] / num_env),
        obs_shape=envs.observation_space.shape,
        device=cfg["buffer"]["device"],
    )
    wmse = WMSE(buffer, cfg)
    idf = IDF(buffer=buffer, num_action=num_action)
    cpc = CPC(buffer=buffer, num_action=num_action)
    # eps=1 with no model/predictor -> uniformly random actions throughout
    actor = actor_iter(envs, None, None, cfg["buffer"]["warmup"], eps=1)

    # fill the buffer with random-policy transitions
    pretrain = int(cfg["buffer"]["warmup"] / num_env)
    for n_iter in trange(pretrain):
        step, hx, log = next(actor)
        buffer.append(step)

    # batch = 256
    for i in trange(20000):
        cur_log = wmse.train()
        if i % 200 == 0:
            wandb.log(cur_log)
    torch.save(wmse.encoder.state_dict(), "models/conv_wmse.pt")

    # batch = 256
    for i in trange(20000):
        cur_log = idf.train()
        if i % 200 == 0:
            wandb.log(cur_log)
    torch.save(idf.encoder.state_dict(), "models/conv_idf.pt")

    # batch = 32 * 32 (batch_size x unroll contrastive targets)
    for i in trange(5000):
        cur_log = cpc.train()
        if i % 50 == 0:
            wandb.log(cur_log)
    torch.save(cpc.model.conv.state_dict(), "models/conv_cpc.pt")
def replace_e_float(d):
    """Recursively convert scientific-notation strings in *d* to floats.

    YAML 1.1 (PyYAML SafeLoader) only recognizes floats written with a
    dot and explicit sign (e.g. "1.0e+5"); plain values such as "1e5" or
    "5e-4" are loaded as strings. This walks the config dict in place and
    converts such strings to float. The pattern also accepts an uppercase
    exponent marker and an explicit "+" ("2E-3", "1e+5"), which the
    previous regex missed; all previously converted values still convert.
    """
    p = re.compile(r"^-?\d+(\.\d+)?[eE][+-]?\d+$")
    for name, val in d.items():
        if isinstance(val, dict):
            # recurse into nested config sections
            replace_e_float(val)
        elif isinstance(val, str) and p.match(val):
            d[name] = float(val)


def load_cfg(name, prefix="."):
    """Load `{prefix}/{name}.yaml` and normalize scientific-notation floats.

    Returns the parsed config dict.
    """
    with open(f"{prefix}/{name}.yaml") as f:
        cfg = yaml.load(f, Loader=yaml.SafeLoader)
    replace_e_float(cfg)
    return cfg
def timer_log(num_iter=1000):
    """Generator-based section timer.

    Protocol: prime with next(); then send("section/name") at the start of
    each timed section and send(None) at the end of an iteration. Each
    send() returns a dict — empty most of the time, and, once the first
    section has accumulated num_iter samples, a mapping of
    "time/<section>" to the mean duration in milliseconds (stats are then
    reset).
    """
    log = {}
    mean_t = defaultdict(list)
    t = mark = None
    while True:
        prev_t, prev_mark = t, mark
        mark = yield log
        t = time()
        if prev_mark is not None:
            # the elapsed time since the previous mark belongs to it
            mean_t[prev_mark].append(t - prev_t)

        # emit averaged stats only on an end-of-iteration tick (mark=None)
        # once the first section has num_iter samples; the `mean_t` guard
        # fixes an IndexError when send(None) arrives before any named
        # mark (previously list(mean_t.values())[0] on an empty dict)
        if mark is None and mean_t and len(next(iter(mean_t.values()))) >= num_iter:
            log = {"time/" + k: np.mean(v) * 1000 for k, v in mean_t.items()}
            mean_t = defaultdict(list)
        else:
            log = {}
from itertools import count
import torch
import numpy as np
from common.timer import timer_log


def actor_iter(env, model, predictor, warmup, eps=None):
    """Generator that steps the vectorized env with an eps-greedy policy.

    Protocol: the caller send()s the most recent buffer slice
    ("full_step") and receives (step, hx, log), where step is the new
    transition to append to the buffer. During the first
    warmup/num_envs iterations actions are uniformly random. With
    eps=None each env gets its own exploration rate 0.4**linspace(1, 8);
    a scalar eps fixes one rate for all envs.
    """
    minstep = int(warmup / env.num_envs)
    hx = None  # torch.zeros(env.num_envs, model.rnn_size, device=model.device)
    hx_pred = None  # recurrent state of the intrinsic-reward predictor
    step = {"obs": env.reset()}

    timer = timer_log(100)
    next(timer)
    if eps is None:
        # per-env exploration rates, one per actor
        eps = (0.4 ** torch.linspace(1, 8, env.num_envs))[..., None]
    mean_reward, mean_len = [], []
    log = {}

    for n_iter in count():
        full_step = yield step, hx, log

        timer.send("actor/action")
        # default: random action (always used during warmup)
        action = torch.randint(env.action_space.n, (env.num_envs, 1))
        if n_iter >= minstep:
            with torch.no_grad():
                # intrinsic reward = world-model prediction error
                _, ri, hx_pred = predictor.get_error(full_step, hx_pred)
                full_step = full_step[1:]
                full_step["reward"] += ri

                qs, hx = model(**full_step, hx=hx)
            action_greedy = qs[0].argmax(1)[..., None].cpu()
            # per-env eps-greedy mix of random and greedy actions
            x = torch.rand(env.num_envs, 1) > eps
            action[x] = action_greedy[x]

        timer.send("actor/env")
        # done = 1 means obs is first step of next episode
        # prev_obs + action = obs
        obs, reward, done, infos = env.step(action)

        log = timer.send(None)

        # aggregate episode statistics reported by the env wrappers
        ep = [info["episode"] for info in infos if "episode" in info]
        mean_reward += [x["r"] for x in ep]
        mean_len += [x["l"] for x in ep]
        if len(mean_reward) >= env.num_envs:
            log = {
                "reward": np.mean(mean_reward),
                "len": np.mean(mean_len),
                **log,
            }
            mean_reward, mean_len, = [], []
        if "episode" in infos[-1]:
            log = {"reward_last": infos[-1]["episode"]["r"], **log}

        step = {"obs": obs, "action": action, "reward": reward, "done": done}
import torch


def vf_rescaling(x):
    """Value-function rescaling h(x) = sign(x)(sqrt(|x|+1)-1) + eps*x."""
    eps = 1e-3
    shrunk = torch.sqrt(x.abs() + 1) - 1
    return torch.sign(x) * shrunk + eps * x


def inv_vf_rescaling(x):
    """Inverse of vf_rescaling: h^{-1}(h(x)) == x."""
    eps = 1e-3
    root = torch.sqrt(1 + 4 * eps * (x.abs() + 1 + eps))
    inner = (root - 1) / (2 * eps)
    return torch.sign(x) * (inner ** 2 - 1)


def get_td_error(batch, hx_start, model, model_t, cfg):
    """Double-DQN TD error with value rescaling and recurrent burn-in.

    batch is a time-major slice of transitions; hx_start is the stored
    recurrent state at the start of the slice (or None). Returns the
    per-step absolute TD error and a stats dict for logging.
    """
    burnin = cfg["agent"]["burnin"]
    gamma = cfg["agent"]["gamma"]

    if burnin > 0:
        # roll the hidden states through the burn-in prefix without grads;
        # the target net consumes one extra step (its sequence is shifted)
        with torch.no_grad():
            hx = model(**batch[:burnin], hx=hx_start, only_hx=True)
            hx_target = model_t(**batch[: burnin + 1], hx=hx_start, only_hx=True)
    else:
        hx = hx_target = None

    qs, _ = model(**batch[burnin:], hx=hx)

    with torch.no_grad():
        qs_target, _ = model_t(**batch[burnin + 1 :], hx=hx_target)

    action = batch["action"][burnin + 1 :]
    reward = batch["reward"][burnin + 1 :]
    done = batch["done"][burnin + 1 :].float()

    q = qs[:-1].gather(2, action)
    # double DQN: the online net picks the action, the target net rates it
    greedy = qs[1:].argmax(2)[..., None].detach()
    next_q = inv_vf_rescaling(qs_target.gather(2, greedy))
    target_q = vf_rescaling(next_q * gamma * (1 - done) + reward)
    td_error = (q - target_q).abs()

    log = {
        "loss": td_error.mean().item(),
        "q_mean": qs.mean().item(),
        "q_std": qs.std().item(),
    }
    return td_error, log
| self.maxlen, self.num_env, self.device = maxlen, num_env, device 14 | self.cursor = self._size = 0 15 | 16 | def tensor(shape=(1,), dtype=torch.float): 17 | return torch.empty( 18 | self.maxlen, self.num_env, *shape, dtype=dtype, device=self.device 19 | ) 20 | 21 | self._buffer = { 22 | "obs": tensor(obs_shape, torch.uint8), 23 | "action": tensor(dtype=torch.long), 24 | "reward": tensor(), 25 | "done": tensor(dtype=torch.uint8), 26 | } 27 | 28 | def query(self, idx, idx_env, steps, device="cuda"): 29 | qsize = len(idx) 30 | s = torch.arange(steps - 1, -1, -1) 31 | q0 = (idx[None, ...].repeat(steps, 1) - s[..., None]).flatten() 32 | q1 = idx_env[None, ...].repeat(steps, 1).flatten() 33 | return DictWithSlicing( 34 | { 35 | k: v[q0, q1].view(steps, qsize, *v.shape[2:]).to(device) 36 | for k, v in self._buffer.items() 37 | } 38 | ) 39 | 40 | def append(self, step): 41 | for k in self._buffer: 42 | if k not in step: 43 | self._buffer[k][self.cursor] = 0 44 | else: 45 | assert step[k].dtype == self._buffer[k].dtype 46 | assert step[k].shape == self._buffer[k].shape[1:] 47 | self._buffer[k][self.cursor] = step[k].to(self.device) 48 | self.cursor = (self.cursor + 1) % self.maxlen 49 | self._size = min(self.maxlen, self._size + 1) 50 | 51 | def get_recent(self, steps, device="cuda"): 52 | if len(self) == 0: 53 | return None 54 | idx = torch.tensor([self.cursor - 1] * self.num_env) 55 | idx_env = torch.arange(self.num_env) 56 | step = self.query(idx, idx_env, steps, device) 57 | if len(self) < steps: 58 | for el in step.values(): 59 | el[: steps - len(self)] = 0 60 | return step 61 | 62 | def __len__(self): 63 | return self._size 64 | -------------------------------------------------------------------------------- /pol/dqn/learner.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from functools import partial 3 | import random 4 | import torch 5 | 6 | from common.optim import ParamOptim 7 | from dqn.algo 
import get_td_error 8 | 9 | 10 | class Learner: 11 | def __init__(self, model, buffer, predictor, cfg): 12 | model_t = deepcopy(model) 13 | model_t = model_t.cuda().eval() 14 | self.model, self.model_t = model, model_t 15 | self.buffer = buffer 16 | self.predictor = predictor 17 | self.optim = ParamOptim(params=model.parameters(), **cfg["optim"]) 18 | 19 | self.batch_size = cfg["agent"]["batch_size"] 20 | self.unroll = cfg["agent"]["unroll"] 21 | self.unroll_prefix = cfg["agent"]["burnin"] + 1 22 | self.sample_steps = self.unroll_prefix + self.unroll 23 | 24 | self.target_tau = cfg["agent"]["target_tau"] 25 | self.td_error = partial(get_td_error, model=model, model_t=model_t, cfg=cfg) 26 | self.add_ri = cfg["add_ri"] 27 | 28 | def _update_target(self): 29 | for t, s in zip(self.model_t.parameters(), self.model.parameters()): 30 | t.data.copy_(t.data * (1.0 - self.target_tau) + s.data * self.target_tau) 31 | 32 | def loss_uniform(self): 33 | if len(self.buffer) < self.buffer.maxlen: 34 | no_prev = set(range(self.sample_steps)) 35 | else: 36 | no_prev = set( 37 | (self.buffer.cursor + i) % self.buffer.maxlen 38 | for i in range(self.sample_steps) 39 | ) 40 | all_idx = list(set(range(len(self.buffer))) - no_prev) 41 | idx0 = torch.tensor(random.choices(all_idx, k=self.batch_size)) 42 | idx1 = torch.tensor( 43 | random.choices(range(self.buffer.num_env), k=self.batch_size) 44 | ) 45 | batch = self.buffer.query(idx0, idx1, self.sample_steps) 46 | loss_pred, ri, _ = self.predictor.get_error(batch, update_stats=True) 47 | if self.add_ri: 48 | batch["reward"][1:] += ri 49 | td_error, log = self.td_error(batch, None) 50 | loss = td_error.pow(2).sum(0) 51 | return loss, loss_pred, ri, log 52 | 53 | def train(self, need_stat=True): 54 | loss, loss_pred, ri, log = self.loss_uniform() 55 | self.optim.step(loss.mean()) 56 | self.predictor.optim.step(loss_pred) 57 | self._update_target() 58 | 59 | if need_stat: 60 | log.update( 61 | { 62 | "ri_std": ri.std(), 63 | "ri_mean": 
ri.mean(), 64 | "ri_run_mean": self.predictor.ri_mean, 65 | "ri_run_std": self.predictor.ri_std, 66 | "loss_predictor": loss_pred.mean().detach(), 67 | } 68 | ) 69 | return log 70 | -------------------------------------------------------------------------------- /pol/dqn/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import one_hot 4 | 5 | 6 | class DQN(nn.Module): 7 | def __init__(self, rnn_size, device="cuda"): 8 | super(DQN, self).__init__() 9 | self.device = device 10 | self.rnn_size = rnn_size 11 | 12 | self.encoder = nn.Sequential(nn.Linear(4 + 4 + 1, 32), nn.ReLU()) 13 | # self.rnn = nn.GRUCell(32, self.rnn_size) 14 | self.rnn = nn.GRU(32, self.rnn_size) 15 | self.adv = nn.Sequential( 16 | nn.Linear(self.rnn_size, self.rnn_size), 17 | nn.ReLU(), 18 | nn.Linear(self.rnn_size, 4, bias=False), 19 | ) 20 | self.val = nn.Sequential( 21 | nn.Linear(self.rnn_size, self.rnn_size), 22 | nn.ReLU(), 23 | nn.Linear(self.rnn_size, 1), 24 | ) 25 | 26 | def forward(self, obs, action, reward, done, hx=None, only_hx=False): 27 | mask = (1 - done).float() 28 | a = one_hot(action[:, :, 0], 4).float() * mask 29 | r = reward * mask 30 | x = torch.cat([obs.float(), a, r], 2) 31 | 32 | steps, batch, *rest = x.shape 33 | x = x.view(steps * batch, *rest) 34 | x = self.encoder(x).view(steps, batch, 32) 35 | 36 | # y = torch.empty(steps, batch, self.rnn_size, device=self.device) 37 | # for i in range(steps): 38 | # if hx is not None: 39 | # hx *= mask[i] 40 | # y[i] = hx = self.rnn(x[i], hx) 41 | # hx = hx.clone().detach() 42 | y, hx = self.rnn(x, hx) 43 | if only_hx: 44 | return hx 45 | 46 | y = y.view(steps * batch, self.rnn_size) 47 | adv, val = self.adv(y), self.val(y) 48 | q = val + adv - adv.mean(1, keepdim=True) 49 | q = q.view(steps, batch, 4) 50 | return q, hx 51 | -------------------------------------------------------------------------------- /pol/env.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from baselines import bench 3 | from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 4 | from baselines.common.vec_env import VecEnvWrapper 5 | from baselines.common.wrappers import TimeLimit 6 | from pol_env import PolEnv 7 | 8 | 9 | def make_vec_envs(num, size, seed=0, max_ep_len=1000): 10 | def make_env(rank): 11 | def _thunk(): 12 | env = PolEnv(size) 13 | env = TimeLimit(env, max_episode_steps=max_ep_len) 14 | env.seed(seed + rank) 15 | env = bench.Monitor(env, None) 16 | return env 17 | 18 | return _thunk 19 | 20 | envs = [make_env(i) for i in range(num)] 21 | envs = SubprocVecEnv(envs, context="fork") 22 | envs = VecTorch(envs) 23 | return envs 24 | 25 | 26 | class VecTorch(VecEnvWrapper): 27 | def reset(self): 28 | return torch.from_numpy(self.venv.reset()) 29 | 30 | def step_async(self, actions): 31 | assert len(actions.shape) == 2 32 | actions = actions[:, 0].cpu().numpy() 33 | self.venv.step_async(actions) 34 | 35 | def step_wait(self): 36 | obs, reward, done, info = self.venv.step_wait() 37 | obs = torch.from_numpy(obs) 38 | reward = torch.from_numpy(reward)[..., None].float() 39 | done = torch.tensor(done, dtype=torch.uint8)[..., None] 40 | return obs, reward, done, info 41 | -------------------------------------------------------------------------------- /pol/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import wandb 5 | 6 | from dqn.buffer import Buffer 7 | from common.load_cfg import load_cfg 8 | from env import make_vec_envs 9 | from dqn import actor_iter, DQN 10 | from predictor import Predictor 11 | 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser(description="") 15 | parser.add_argument("--size", type=int, default=3) 16 | parser.add_argument("--add_ri", action="store_true") 17 | 
parser.add_argument("--random", action="store_true") 18 | p = parser.parse_args() 19 | cfg = load_cfg("default") 20 | cfg.update(vars(p)) 21 | cfg["env"] = "pol" 22 | wandb.init(project="lwm", config=cfg) 23 | 24 | num_env = 1 25 | envs = make_vec_envs( 26 | num=num_env, 27 | size=cfg["size"], 28 | max_ep_len=cfg["train"]["max_ep_len"], 29 | ) 30 | buffer = Buffer( 31 | num_env=num_env, 32 | maxlen=int(cfg["buffer"]["size"] / num_env), 33 | obs_shape=(4,), 34 | device=cfg["buffer"]["device"], 35 | ) 36 | model = DQN(cfg["agent"]["rnn_size"]).cuda() 37 | pred = Predictor(buffer, cfg) 38 | if cfg["random"]: 39 | warmup = 1e8 40 | else: 41 | cp = torch.load("models/dqn.pt") 42 | model.load_state_dict(cp) 43 | model.eval() 44 | pred.load() 45 | warmup = 0 46 | actor = actor_iter(envs, model, pred, warmup, eps=0.01) 47 | 48 | reward = [] 49 | while len(reward) < 128: 50 | full_step = buffer.get_recent(2) 51 | step, hx, log = actor.send(full_step) 52 | buffer.append(step) 53 | if "reward" in log: 54 | reward.append(log["reward"]) 55 | 56 | wandb.log({"final_reward": np.mean(reward)}) 57 | -------------------------------------------------------------------------------- /pol/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/htdt/lwm/b370828a2c99a7aae8b37f25319eba351a2768cb/pol/models/.gitkeep -------------------------------------------------------------------------------- /pol/pol_env.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import gym 4 | from gym import spaces 5 | from gym.utils import seeding 6 | 7 | LEFT, DOWN, RIGHT, UP = 0, 1, 2, 3 8 | STEPS = [LEFT, DOWN, RIGHT, UP] 9 | OPPOSITE = {LEFT: RIGHT, RIGHT: LEFT, UP: DOWN, DOWN: UP} 10 | 11 | 12 | def step_grid(cur, d, size): 13 | x, y = cur 14 | if d == LEFT: 15 | x -= 1 16 | elif d == RIGHT: 17 | x += 1 18 | elif d == UP: 19 | y -= 1 20 | elif d == DOWN: 
21 | y += 1 22 | if x < 0 or y < 0 or x >= size or y >= size: 23 | return cur 24 | return (x, y) 25 | 26 | 27 | def gen_labyrinth(size, np_random): 28 | edges = np.zeros((size, size, 4), dtype=bool) 29 | visit = np.zeros((size, size), dtype=bool) 30 | stack = [(0, 0)] 31 | 32 | while len(stack): 33 | cur = stack.pop() 34 | visit[cur] = 1 35 | neib = [d for d in STEPS if not visit[step_grid(cur, d, size)]] 36 | if len(neib): 37 | stack.append(cur) 38 | next_d = np_random.choice(neib) 39 | next_pos = step_grid(cur, next_d, size) 40 | edges[cur][next_d] = edges[next_pos][OPPOSITE[next_d]] = 1 41 | stack.append(next_pos) 42 | return edges 43 | 44 | 45 | class PolEnv(gym.Env): 46 | metadata = {"render.modes": ["human"]} 47 | 48 | def __init__(self, size): 49 | self.size = size 50 | self.action_space = spaces.Discrete(4) 51 | self.observation_space = spaces.Discrete(4) 52 | self.seed() 53 | 54 | def seed(self, seed=None): 55 | self.np_random, seed = seeding.np_random(seed) 56 | return [seed] 57 | 58 | def step(self, action): 59 | assert action in STEPS 60 | if self.map[self.pos][action]: 61 | self.pos = step_grid(self.pos, action, self.size) 62 | self.visit[self.pos] = 1 63 | done = self.visit.all() 64 | state = self.map[self.pos].astype(np.uint8) 65 | reward = -1.0 66 | return state, reward, done, {} 67 | 68 | def reset(self): 69 | self.map = gen_labyrinth(self.size, self.np_random) 70 | self.visit = np.zeros((self.size, self.size), dtype=bool) 71 | # self.pos = (0, 0) 72 | self.pos = tuple(self.np_random.randint(self.size, size=2)) 73 | self.visit[self.pos] = 1 74 | return self.map[self.pos].astype(np.uint8) 75 | 76 | def render(self, mode="human"): 77 | m2 = np.zeros((self.size * 2 + 1, self.size * 2 + 1), dtype=int) 78 | m2[1::2, 1::2] = 1 79 | m2[:-1:2, 1::2] = self.map[:, :, LEFT] 80 | m2[1::2, :-1:2] = self.map[:, :, UP] 81 | m2[self.pos[0] * 2 + 1, self.pos[1] * 2 + 1] = 2 82 | 83 | for s in m2.astype(str): 84 | s = "".join(s) 85 | s = s.replace("0", "#") 86 | s 
= s.replace("1", " ") 87 | s = s.replace("2", "@") 88 | sys.stdout.write(s + "\n") 89 | 90 | def close(self): 91 | self.map = None 92 | -------------------------------------------------------------------------------- /pol/predictor.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.functional import one_hot 5 | from common.optim import ParamOptim 6 | 7 | 8 | class PredictorModel(nn.Module): 9 | def __init__(self, rnn_size): 10 | super(PredictorModel, self).__init__() 11 | self.rnn_size = rnn_size 12 | self.encoder = nn.Sequential(nn.Linear(4 + 4, 32), nn.ReLU()) 13 | # self.rnn = nn.GRUCell(32, self.rnn_size) 14 | self.rnn = nn.GRU(32, self.rnn_size) 15 | self.fc = nn.Sequential( 16 | nn.Linear(self.rnn_size, self.rnn_size), 17 | nn.ReLU(), 18 | nn.Linear(self.rnn_size, 4), 19 | nn.Sigmoid(), 20 | ) 21 | 22 | def forward(self, z, action, done, hx=None): 23 | unroll, batch, emb_size = z.shape 24 | a = one_hot(action[:, :, 0], 4).float() 25 | z = torch.cat([z, a], dim=2) 26 | z = self.encoder(z.view(unroll * batch, 4 + 4)) 27 | z = z.view(unroll, batch, 32) 28 | 29 | # mask = 1 - done.float() 30 | # x = torch.empty(unroll, batch, self.rnn_size, device=z.device) 31 | # for i in range(unroll): 32 | # if hx is not None: 33 | # hx *= mask[i] 34 | # x[i] = hx = self.rnn(z[i], hx) 35 | # hx = hx.clone().detach() 36 | 37 | x, hx = self.rnn(z, hx) 38 | 39 | x = self.fc(x.view(unroll * batch, self.rnn_size)) 40 | z_pred = x.view(unroll, batch, 4) 41 | return z_pred, hx 42 | 43 | 44 | class Predictor: 45 | def __init__(self, buffer, cfg, device="cuda"): 46 | self.device = device 47 | self.buffer = buffer 48 | 49 | self.model = PredictorModel(cfg["agent"]["rnn_size"]) 50 | self.model = self.model.to(device).train() 51 | lr = cfg["self_sup"]["lr"] 52 | self.optim = ParamOptim(params=self.model.parameters(), lr=lr) 53 | self.ri_mean = self.ri_std = None 54 | 
self.ri_momentum = cfg["self_sup"]["ri_momentum"] 55 | 56 | def get_error(self, batch, hx=None, update_stats=False): 57 | z = batch["obs"].float() 58 | action = batch["action"][1:] 59 | done = batch["done"][:-1] 60 | z_pred, hx = self.model(z[:-1], action, done, hx) 61 | err = (z[1:] - z_pred).pow(2).mean(2) 62 | 63 | ri = err.detach() 64 | if update_stats: 65 | if self.ri_mean is None: 66 | self.ri_mean = ri.mean() 67 | self.ri_std = ri.std() 68 | else: 69 | m = self.ri_momentum 70 | self.ri_mean = m * self.ri_mean + (1 - m) * ri.mean() 71 | self.ri_std = m * self.ri_std + (1 - m) * ri.std() 72 | if self.ri_mean is not None: 73 | ri = (ri[..., None] - self.ri_mean) / self.ri_std 74 | else: 75 | ri = 0 76 | return err.mean(), ri, hx 77 | 78 | def train(self): 79 | # this function is used only for pretrain, main training loop is in dqn learner 80 | batch_size = 64 81 | sample_steps = 100 82 | if len(self.buffer) < self.buffer.maxlen: 83 | no_prev = set(range(sample_steps)) 84 | else: 85 | no_prev = set( 86 | (self.buffer.cursor + i) % self.buffer.maxlen 87 | for i in range(sample_steps) 88 | ) 89 | all_idx = list(set(range(len(self.buffer))) - no_prev) 90 | idx0 = torch.tensor(random.choices(all_idx, k=batch_size)) 91 | idx1 = torch.tensor(random.choices(range(self.buffer.num_env), k=batch_size)) 92 | batch = self.buffer.query(idx0, idx1, sample_steps) 93 | loss = self.get_error(batch, update_stats=True)[0] 94 | self.optim.step(loss) 95 | return {"loss_predictor": loss.item()} 96 | 97 | def load(self): 98 | cp = torch.load("models/predictor.pt", map_location=self.device) 99 | self.ri_mean, self.ri_std, model = cp 100 | self.model.load_state_dict(model) 101 | 102 | def save(self): 103 | data = [self.ri_mean, self.ri_std, self.model.state_dict()] 104 | torch.save(data, "models/predictor.pt") 105 | -------------------------------------------------------------------------------- /pol/render.py: 
import os
import time
import torch

from common.load_cfg import load_cfg
from env import make_vec_envs
from dqn import actor_iter, DQN
from dqn.buffer import Buffer
from predictor import Predictor


if __name__ == "__main__":
    # Watch a trained agent solve 3x3 mazes in the terminal, on CPU.
    cfg = load_cfg("default")
    cfg["env"] = "pol"

    num_env = cfg["agent"]["actors"]
    env = make_vec_envs(
        num=1,
        size=3,
        max_ep_len=cfg["train"]["max_ep_len"],
        seed=10,
    )
    model = DQN(cfg["agent"]["rnn_size"], device="cpu")
    pred = Predictor(None, cfg, device="cpu")
    actor = actor_iter(env, model, pred, 0, eps=0)
    buffer = Buffer(num_env=1, maxlen=2, obs_shape=(4,), device="cpu")

    cp = torch.load("models/dqn.pt", map_location="cpu")
    model.load_state_dict(cp)
    model.eval()
    pred.load()

    for n_iter in range(2000):
        full_step = buffer.get_recent(2, "cpu")
        step, hx, log_a = actor.send(full_step)
        buffer.append(step)
        os.system("clear")
        # Render inside the subprocess worker directly: the vec-env wrapper
        # exposes no render() of its own.
        env.remotes[0].send(('render', None))
        env.remotes[0].recv()
        time.sleep(1)


# --- pol/train.py ---
import argparse
from tqdm import trange
import torch
import wandb

from dqn.buffer import Buffer
from common.load_cfg import load_cfg
from env import make_vec_envs
from dqn import actor_iter, Learner, DQN
from predictor import Predictor


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--size", type=int, default=3)
    parser.add_argument("--add_ri", action="store_true")
    p = parser.parse_args()
    cfg = load_cfg("default")
    cfg.update(vars(p))
    cfg["env"] = "pol"
    wandb.init(project="lwm", config=cfg)

    num_env = cfg["agent"]["actors"]
    envs = make_vec_envs(
        num=num_env,
        size=cfg["size"],
        max_ep_len=cfg["train"]["max_ep_len"],
    )
    buffer = Buffer(
        num_env=num_env,
        maxlen=int(cfg["buffer"]["size"] / num_env),
        obs_shape=(4,),
        device=cfg["buffer"]["device"],
    )
    model = DQN(cfg["agent"]["rnn_size"]).cuda().train()
    pred = Predictor(buffer, cfg)
    learner = Learner(model, buffer, pred, cfg)
    eps = cfg["agent"].get("eps")
    actor = actor_iter(envs, model, pred, cfg["buffer"]["warmup"], eps=eps)

    # All counters are per-actor-iteration (each iteration = num_env frames).
    start_train = int(cfg["buffer"]["warmup"] / num_env)
    log_every = cfg["train"]["log_every"]
    train_every = cfg["train"]["learner_every"]

    count = trange(int(cfg["train"]["frames"] / num_env), smoothing=0.05)
    for n_iter in count:
        full_step = buffer.get_recent(2)
        step, hx, log = actor.send(full_step)
        buffer.append(step)

        # Once warmup data is collected, pretrain the world model (if used).
        if n_iter == start_train and cfg["add_ri"]:
            for i in trange(1000):
                cur_log = pred.train()
                if i % 100 == 0:
                    wandb.log(cur_log)
            pred.save()

        if n_iter > start_train and (n_iter + 1) % train_every == 0:
            cur_log = learner.train()
            # Attach learner stats only around the logging boundary.
            if (n_iter + 1) % log_every < train_every:
                log.update(cur_log)

        if len(log):
            wandb.log({"frame": n_iter * num_env, **log})

    torch.save(model.state_dict(), "models/dqn.pt")
    pred.save()