├── .gitignore ├── LICENSE ├── README.md ├── conda_env.yml ├── config.yaml ├── data ├── dmc_dreamer_bench.csv ├── dmc_dreamer_bench.ipynb ├── dmc_planet_bench.csv └── dmc_planet_bench.ipynb ├── drq.py ├── logger.py ├── pngs ├── dreamer_bench.png └── planet_bench.png ├── replay_buffer.py ├── train.py ├── utils.py └── video.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints/ 3 | runs 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Denis Yarats, Ilya Kostrikov, Rob Fergus 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DrQ: Data regularized Q 2 | 3 | This is a PyTorch implementation of **DrQ** from 4 | 5 | **Image Augmentation Is All You Need: Regularizing Deep Reinforcement Learning from Pixels** by 6 | 7 | [Denis Yarats*](https://cs.nyu.edu/~dy1042/), [Ilya Kostrikov*](https://github.com/ikostrikov), [Rob Fergus](https://cs.nyu.edu/~fergus/pmwiki/pmwiki.php). 8 | 9 | *Equal contribution. Author ordering determined by coin flip. 10 | 11 | [[Paper]](https://arxiv.org/abs/2004.13649) [[Webpage]](https://sites.google.com/view/data-regularized-q) 12 | 13 | **Update**: we released a newer version **DrQ-v2**, please check it out [here](https://github.com/facebookresearch/drqv2). 14 | 15 | Implementations in other frameworks: [jax/flax](https://github.com/ikostrikov/jax-rl). 16 | 17 | ## Citation 18 | If you use this repo in your research, please consider citing the paper as follows 19 | ``` 20 | @inproceedings{yarats2021image, 21 | title={Image Augmentation Is All You Need: Regularizing Deep Reinforcement Learning from Pixels}, 22 | author={Denis Yarats and Ilya Kostrikov and Rob Fergus}, 23 | booktitle={International Conference on Learning Representations}, 24 | year={2021}, 25 | url={https://openreview.net/forum?id=GY6-6sTvGaf} 26 | } 27 | ``` 28 | 29 | ## Requirements 30 | We assume you have access to a gpu that can run CUDA 9.2. 
Then, the simplest way to install all required dependencies is to create an Anaconda environment by running 31 | ``` 32 | conda env create -f conda_env.yml 33 | ``` 34 | After the installation ends, you can activate your environment with 35 | ``` 36 | conda activate drq 37 | ``` 38 | 39 | ## Instructions 40 | To train the DrQ agent on the `Cartpole Swingup` task, run 41 | ``` 42 | python train.py env=cartpole_swingup 43 | ``` 44 | **You can get state-of-the-art performance in under 3 hours.** 45 | 46 | To reproduce the results from the paper, run 47 | ``` 48 | python train.py env=cartpole_swingup batch_size=512 action_repeat=8 49 | ``` 50 | 51 | This will produce the `runs` folder, where all the outputs are stored, including train/eval logs, TensorBoard event files, and evaluation episode videos. To launch TensorBoard, run 52 | ``` 53 | tensorboard --logdir runs 54 | ``` 55 | 56 | The console output is also available in the following form: 57 | ``` 58 | | train | E: 5 | S: 5000 | R: 11.4359 | D: 66.8 s | BR: 0.0581 | ALOSS: -1.0640 | CLOSS: 0.0996 | TLOSS: -23.1683 | TVAL: 0.0945 | AENT: 3.8132 59 | ``` 60 | A training entry decodes as 61 | ``` 62 | train - training episode 63 | E - total number of episodes 64 | S - total number of environment steps 65 | R - episode return 66 | D - duration in seconds 67 | BR - average reward of a sampled batch 68 | ALOSS - average loss of the actor 69 | CLOSS - average loss of the critic 70 | TLOSS - average loss of the temperature parameter 71 | TVAL - the value of temperature 72 | AENT - the actor's entropy 73 | ``` 74 | while an evaluation entry 75 | ``` 76 | | eval | E: 20 | S: 20000 | R: 10.9356 77 | ``` 78 | contains 79 | ``` 80 | E - evaluation was performed after E episodes 81 | S - evaluation was performed after S environment steps 82 | R - average episode return computed over `num_eval_episodes` (usually 10) 83 | ``` 84 | 85 | ## The PlaNet Benchmark 86 | **DrQ** demonstrates state-of-the-art performance on a set of challenging image-based tasks from the DeepMind Control Suite (Tassa et al., 2018). We compare against PlaNet (Hafner et al., 2018), SAC-AE (Yarats et al., 2019), SLAC (Lee et al., 2019), CURL (Srinivas et al., 2020), and an upper-bound performance, SAC States (Haarnoja et al., 2018). This follows the benchmark protocol established in PlaNet (Hafner et al., 2018). 87 | ![The PlaNet Benchmark](pngs/planet_bench.png) 88 | 89 | ## The Dreamer Benchmark 90 | **DrQ** demonstrates state-of-the-art performance on an extended set of challenging image-based tasks from the DeepMind Control Suite (Tassa et al., 2018), following the benchmark protocol from Dreamer (Hafner et al., 2019). We compare against Dreamer (Hafner et al., 2019) and an upper-bound performance, SAC States (Haarnoja et al., 2018). 91 | ![The Dreamer Benchmark](pngs/dreamer_bench.png) 92 | 93 | 94 | ## Acknowledgements 95 | We used [kornia](https://github.com/kornia/kornia) for data augmentation.
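The augmentation is the pad-and-random-crop implemented in `replay_buffer.py`: each observation is replication-padded by `image_pad` pixels and randomly cropped back to `image_size`, and every sampled batch is augmented twice so the two resulting Q-targets can be averaged (see `update_critic` in `drq.py`). A minimal sketch of just the transform, using the defaults from `config.yaml` (the batch shape below is only illustrative):
```
import kornia
import torch
import torch.nn as nn

image_pad, image_size = 4, 84  # defaults from config.yaml
aug = nn.Sequential(
    nn.ReplicationPad2d(image_pad),
    kornia.augmentation.RandomCrop((image_size, image_size)))

# a batch of 3 stacked RGB frames (3 * 3 = 9 channels), as produced by utils.FrameStack
obs = torch.randint(0, 256, (32, 9, image_size, image_size)).float()
obs_aug = aug(obs)  # same shape, each image randomly shifted by up to image_pad pixels
assert obs_aug.shape == obs.shape
```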
96 | -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: drq 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.7 6 | - pip 7 | - pytorch 8 | - cudatoolkit 9 | - absl-py 10 | - pyparsing 11 | - jupyterlab 12 | - scikit-image 13 | - pip: 14 | - termcolor 15 | - git+git://github.com/deepmind/dm_control.git 16 | - git+git://github.com/denisyarats/dmc2gym.git 17 | - tb-nightly 18 | - imageio 19 | - imageio-ffmpeg 20 | - git+git://github.com/facebookresearch/hydra@0.11_branch 21 | - pandas 22 | - ipdb 23 | - tqdm 24 | - torchvision 25 | - yapf 26 | - mujoco_py 27 | - sklearn 28 | - matplotlib 29 | - kornia 30 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | # env 2 | env: cartpole_swingup 3 | # IMPORTANT: if action_repeat is used, the effective number of env steps needs to be 4 | # multiplied by action_repeat in the result graphs. 5 | # This is a common practice for a fair comparison. 6 | # See the 2nd paragraph in Appendix C of SLAC: https://arxiv.org/pdf/1907.00953.pdf 7 | # See Dreamer TF2's implementation: https://github.com/danijar/dreamer/blob/02f0210f5991c7710826ca7881f19c64a012290c/dreamer.py#L340 8 | action_repeat: 4 9 | # train 10 | num_train_steps: 1000000 11 | num_train_iters: 1 12 | num_seed_steps: 1000 13 | replay_buffer_capacity: 100000 14 | seed: 1 15 | # eval 16 | eval_frequency: 5000 17 | num_eval_episodes: 10 18 | # misc 19 | log_frequency_step: 10000 20 | log_save_tb: true 21 | save_video: true 22 | device: cuda 23 | # observation 24 | image_size: 84 25 | image_pad: 4 26 | frame_stack: 3 27 | # global params 28 | lr: 1e-3 29 | # IMPORTANT: please use a batch size of 512 to reproduce the results in the paper. However, it still works well with a smaller batch size. 30 | batch_size: 128 31 | 32 | # agent configuration 33 | agent: 34 | name: drq 35 | class: drq.DRQAgent 36 | params: 37 | obs_shape: ??? # to be specified later 38 | action_shape: ??? # to be specified later 39 | action_range: ???
# to be specified later 40 | device: ${device} 41 | encoder_cfg: ${encoder} 42 | critic_cfg: ${critic} 43 | actor_cfg: ${actor} 44 | discount: 0.99 45 | init_temperature: 0.1 46 | lr: ${lr} 47 | actor_update_frequency: 2 48 | critic_tau: 0.01 49 | critic_target_update_frequency: 2 50 | batch_size: ${batch_size} 51 | 52 | critic: 53 | class: drq.Critic 54 | params: 55 | encoder_cfg: ${agent.params.encoder_cfg} 56 | action_shape: ${agent.params.action_shape} 57 | hidden_dim: 1024 58 | hidden_depth: 2 59 | 60 | actor: 61 | class: drq.Actor 62 | params: 63 | encoder_cfg: ${agent.params.encoder_cfg} 64 | action_shape: ${agent.params.action_shape} 65 | hidden_depth: 2 66 | hidden_dim: 1024 67 | log_std_bounds: [-10, 2] 68 | 69 | encoder: 70 | class: drq.Encoder 71 | params: 72 | obs_shape: ${agent.params.obs_shape} 73 | feature_dim: 50 74 | 75 | 76 | # hydra configuration 77 | hydra: 78 | name: ${env} 79 | run: 80 | dir: ./runs/${now:%Y.%m.%d}/${now:%H%M%S}_${hydra.job.override_dirname} 81 | -------------------------------------------------------------------------------- /drq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import copy 6 | import math 7 | 8 | import utils 9 | import hydra 10 | 11 | 12 | class Encoder(nn.Module): 13 | """Convolutional encoder for image-based observations.""" 14 | def __init__(self, obs_shape, feature_dim): 15 | super().__init__() 16 | 17 | assert len(obs_shape) == 3 18 | self.num_layers = 4 19 | self.num_filters = 32 20 | self.output_dim = 35 21 | self.output_logits = False 22 | self.feature_dim = feature_dim 23 | 24 | self.convs = nn.ModuleList([ 25 | nn.Conv2d(obs_shape[0], self.num_filters, 3, stride=2), 26 | nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1), 27 | nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1), 28 | nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1) 29 | ]) 30 | 31 | self.head = nn.Sequential( 32 | nn.Linear(self.num_filters * 35 * 35, self.feature_dim), 33 | nn.LayerNorm(self.feature_dim)) 34 | 35 | self.outputs = dict() 36 | 37 | def forward_conv(self, obs): 38 | obs = obs / 255. 
39 | self.outputs['obs'] = obs 40 | 41 | conv = torch.relu(self.convs[0](obs)) 42 | self.outputs['conv1'] = conv 43 | 44 | for i in range(1, self.num_layers): 45 | conv = torch.relu(self.convs[i](conv)) 46 | self.outputs['conv%s' % (i + 1)] = conv 47 | 48 | h = conv.view(conv.size(0), -1) 49 | return h 50 | 51 | def forward(self, obs, detach=False): 52 | h = self.forward_conv(obs) 53 | 54 | if detach: 55 | h = h.detach() 56 | 57 | out = self.head(h) 58 | if not self.output_logits: 59 | out = torch.tanh(out) 60 | 61 | self.outputs['out'] = out 62 | 63 | return out 64 | 65 | def copy_conv_weights_from(self, source): 66 | """Tie convolutional layers""" 67 | for i in range(self.num_layers): 68 | utils.tie_weights(src=source.convs[i], trg=self.convs[i]) 69 | 70 | def log(self, logger, step): 71 | for k, v in self.outputs.items(): 72 | logger.log_histogram(f'train_encoder/{k}_hist', v, step) 73 | if len(v.shape) > 2: 74 | logger.log_image(f'train_encoder/{k}_img', v[0], step) 75 | 76 | for i in range(self.num_layers): 77 | logger.log_param(f'train_encoder/conv{i + 1}', self.convs[i], step) 78 | 79 | 80 | class Actor(nn.Module): 81 | """torch.distributions implementation of a diagonal Gaussian policy.""" 82 | def __init__(self, encoder_cfg, action_shape, hidden_dim, hidden_depth, 83 | log_std_bounds): 84 | super().__init__() 85 | 86 | self.encoder = hydra.utils.instantiate(encoder_cfg) 87 | 88 | self.log_std_bounds = log_std_bounds 89 | self.trunk = utils.mlp(self.encoder.feature_dim, hidden_dim, 90 | 2 * action_shape[0], hidden_depth) 91 | 92 | self.outputs = dict() 93 | self.apply(utils.weight_init) 94 | 95 | def forward(self, obs, detach_encoder=False): 96 | obs = self.encoder(obs, detach=detach_encoder) 97 | 98 | mu, log_std = self.trunk(obs).chunk(2, dim=-1) 99 | 100 | # constrain log_std inside [log_std_min, log_std_max] 101 | log_std = torch.tanh(log_std) 102 | log_std_min, log_std_max = self.log_std_bounds 103 | log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 104 | 1) 105 | std = log_std.exp() 106 | 107 | self.outputs['mu'] = mu 108 | self.outputs['std'] = std 109 | 110 | dist = utils.SquashedNormal(mu, std) 111 | return dist 112 | 113 | def log(self, logger, step): 114 | for k, v in self.outputs.items(): 115 | logger.log_histogram(f'train_actor/{k}_hist', v, step) 116 | 117 | for i, m in enumerate(self.trunk): 118 | if type(m) == nn.Linear: 119 | logger.log_param(f'train_actor/fc{i}', m, step) 120 | 121 | 122 | class Critic(nn.Module): 123 | """Critic network, employs double Q-learning.""" 124 | def __init__(self, encoder_cfg, action_shape, hidden_dim, hidden_depth): 125 | super().__init__() 126 | 127 | self.encoder = hydra.utils.instantiate(encoder_cfg) 128 | 129 | self.Q1 = utils.mlp(self.encoder.feature_dim + action_shape[0], 130 | hidden_dim, 1, hidden_depth) 131 | self.Q2 = utils.mlp(self.encoder.feature_dim + action_shape[0], 132 | hidden_dim, 1, hidden_depth) 133 | 134 | self.outputs = dict() 135 | self.apply(utils.weight_init) 136 | 137 | def forward(self, obs, action, detach_encoder=False): 138 | assert obs.size(0) == action.size(0) 139 | obs = self.encoder(obs, detach=detach_encoder) 140 | 141 | obs_action = torch.cat([obs, action], dim=-1) 142 | q1 = self.Q1(obs_action) 143 | q2 = self.Q2(obs_action) 144 | 145 | self.outputs['q1'] = q1 146 | self.outputs['q2'] = q2 147 | 148 | return q1, q2 149 | 150 | def log(self, logger, step): 151 | self.encoder.log(logger, step) 152 | 153 | for k, v in self.outputs.items(): 154 |
logger.log_histogram(f'train_critic/{k}_hist', v, step) 155 | 156 | assert len(self.Q1) == len(self.Q2) 157 | for i, (m1, m2) in enumerate(zip(self.Q1, self.Q2)): 158 | assert type(m1) == type(m2) 159 | if type(m1) is nn.Linear: 160 | logger.log_param(f'train_critic/q1_fc{i}', m1, step) 161 | logger.log_param(f'train_critic/q2_fc{i}', m2, step) 162 | 163 | 164 | class DRQAgent(object): 165 | """Data regularized Q: actor-critic method for learning from pixels.""" 166 | def __init__(self, obs_shape, action_shape, action_range, device, 167 | encoder_cfg, critic_cfg, actor_cfg, discount, 168 | init_temperature, lr, actor_update_frequency, critic_tau, 169 | critic_target_update_frequency, batch_size): 170 | self.action_range = action_range 171 | self.device = device 172 | self.discount = discount 173 | self.critic_tau = critic_tau 174 | self.actor_update_frequency = actor_update_frequency 175 | self.critic_target_update_frequency = critic_target_update_frequency 176 | self.batch_size = batch_size 177 | 178 | self.actor = hydra.utils.instantiate(actor_cfg).to(self.device) 179 | 180 | self.critic = hydra.utils.instantiate(critic_cfg).to(self.device) 181 | self.critic_target = hydra.utils.instantiate(critic_cfg).to( 182 | self.device) 183 | self.critic_target.load_state_dict(self.critic.state_dict()) 184 | 185 | # tie conv layers between actor and critic 186 | self.actor.encoder.copy_conv_weights_from(self.critic.encoder) 187 | 188 | self.log_alpha = torch.tensor(np.log(init_temperature)).to(device) 189 | self.log_alpha.requires_grad = True 190 | # set target entropy to -|A| 191 | self.target_entropy = -action_shape[0] 192 | 193 | # optimizers 194 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=lr) 195 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 196 | lr=lr) 197 | self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=lr) 198 | 199 | self.train() 200 | self.critic_target.train() 201 | 202 | def train(self, training=True): 203 | self.training = training 204 | self.actor.train(training) 205 | self.critic.train(training) 206 | 207 | @property 208 | def alpha(self): 209 | return self.log_alpha.exp() 210 | 211 | def act(self, obs, sample=False): 212 | obs = torch.FloatTensor(obs).to(self.device) 213 | obs = obs.unsqueeze(0) 214 | dist = self.actor(obs) 215 | action = dist.sample() if sample else dist.mean 216 | action = action.clamp(*self.action_range) 217 | assert action.ndim == 2 and action.shape[0] == 1 218 | return utils.to_np(action[0]) 219 | 220 | def update_critic(self, obs, obs_aug, action, reward, next_obs, 221 | next_obs_aug, not_done, logger, step): 222 | with torch.no_grad(): 223 | dist = self.actor(next_obs) 224 | next_action = dist.rsample() 225 | log_prob = dist.log_prob(next_action).sum(-1, keepdim=True) 226 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 227 | target_V = torch.min(target_Q1, 228 | target_Q2) - self.alpha.detach() * log_prob 229 | target_Q = reward + (not_done * self.discount * target_V) 230 | 231 | dist_aug = self.actor(next_obs_aug) 232 | next_action_aug = dist_aug.rsample() 233 | log_prob_aug = dist_aug.log_prob(next_action_aug).sum(-1, 234 | keepdim=True) 235 | target_Q1, target_Q2 = self.critic_target(next_obs_aug, 236 | next_action_aug) 237 | target_V = torch.min( 238 | target_Q1, target_Q2) - self.alpha.detach() * log_prob_aug 239 | target_Q_aug = reward + (not_done * self.discount * target_V) 240 | 241 | target_Q = (target_Q + target_Q_aug) / 2 242 | 243 | # get current Q estimates 244 | 
current_Q1, current_Q2 = self.critic(obs, action) 245 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss( 246 | current_Q2, target_Q) 247 | 248 | Q1_aug, Q2_aug = self.critic(obs_aug, action) 249 | 250 | critic_loss += F.mse_loss(Q1_aug, target_Q) + F.mse_loss( 251 | Q2_aug, target_Q) 252 | 253 | logger.log('train_critic/loss', critic_loss, step) 254 | 255 | # Optimize the critic 256 | self.critic_optimizer.zero_grad() 257 | critic_loss.backward() 258 | self.critic_optimizer.step() 259 | 260 | self.critic.log(logger, step) 261 | 262 | def update_actor_and_alpha(self, obs, logger, step): 263 | # detach conv filters, so we don't update them with the actor loss 264 | dist = self.actor(obs, detach_encoder=True) 265 | action = dist.rsample() 266 | log_prob = dist.log_prob(action).sum(-1, keepdim=True) 267 | # detach conv filters, so we don't update them with the actor loss 268 | actor_Q1, actor_Q2 = self.critic(obs, action, detach_encoder=True) 269 | 270 | actor_Q = torch.min(actor_Q1, actor_Q2) 271 | 272 | actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean() 273 | 274 | logger.log('train_actor/loss', actor_loss, step) 275 | logger.log('train_actor/target_entropy', self.target_entropy, step) 276 | logger.log('train_actor/entropy', -log_prob.mean(), step) 277 | 278 | # optimize the actor 279 | self.actor_optimizer.zero_grad() 280 | actor_loss.backward() 281 | self.actor_optimizer.step() 282 | 283 | self.actor.log(logger, step) 284 | 285 | self.log_alpha_optimizer.zero_grad() 286 | alpha_loss = (self.alpha * 287 | (-log_prob - self.target_entropy).detach()).mean() 288 | logger.log('train_alpha/loss', alpha_loss, step) 289 | logger.log('train_alpha/value', self.alpha, step) 290 | alpha_loss.backward() 291 | self.log_alpha_optimizer.step() 292 | 293 | def update(self, replay_buffer, logger, step): 294 | obs, action, reward, next_obs, not_done, obs_aug, next_obs_aug = replay_buffer.sample( 295 | self.batch_size) 296 | 297 | logger.log('train/batch_reward', reward.mean(), step) 298 | 299 | self.update_critic(obs, obs_aug, action, reward, next_obs, 300 | next_obs_aug, not_done, logger, step) 301 | 302 | if step % self.actor_update_frequency == 0: 303 | self.update_actor_and_alpha(obs, logger, step) 304 | 305 | if step % self.critic_target_update_frequency == 0: 306 | utils.soft_update_params(self.critic, self.critic_target, 307 | self.critic_tau) 308 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | import shutil 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | 9 | import torch 10 | import torchvision 11 | from termcolor import colored 12 | from torch.utils.tensorboard import SummaryWriter 13 | 14 | COMMON_TRAIN_FORMAT = [('episode', 'E', 'int'), ('step', 'S', 'int'), 15 | ('episode_reward', 'R', 'float'), 16 | ('duration', 'D', 'time')] 17 | 18 | COMMON_EVAL_FORMAT = [('episode', 'E', 'int'), ('step', 'S', 'int'), 19 | ('episode_reward', 'R', 'float')] 20 | 21 | AGENT_TRAIN_FORMAT = { 22 | 'drq': [('batch_reward', 'BR', 'float'), ('actor_loss', 'ALOSS', 'float'), 23 | ('critic_loss', 'CLOSS', 'float'), 24 | ('alpha_loss', 'TLOSS', 'float'), ('alpha_value', 'TVAL', 'float'), 25 | ('actor_entropy', 'AENT', 'float')] 26 | } 27 | 28 | 29 | class AverageMeter(object): 30 | def __init__(self): 31 | self._sum = 0 32 | self._count = 0 33 | 34 | def update(self, value, n=1): 35 | self._sum += 
value 36 | self._count += n 37 | 38 | def value(self): 39 | return self._sum / max(1, self._count) 40 | 41 | 42 | class MetersGroup(object): 43 | def __init__(self, file_name, formating): 44 | self._csv_file_name = self._prepare_file(file_name, 'csv') 45 | self._formating = formating 46 | self._meters = defaultdict(AverageMeter) 47 | self._csv_file = open(self._csv_file_name, 'w') 48 | self._csv_writer = None 49 | 50 | def _prepare_file(self, prefix, suffix): 51 | file_name = f'{prefix}.{suffix}' 52 | if os.path.exists(file_name): 53 | os.remove(file_name) 54 | return file_name 55 | 56 | def log(self, key, value, n=1): 57 | self._meters[key].update(value, n) 58 | 59 | def _prime_meters(self): 60 | data = dict() 61 | for key, meter in self._meters.items(): 62 | if key.startswith('train'): 63 | key = key[len('train') + 1:] 64 | else: 65 | key = key[len('eval') + 1:] 66 | key = key.replace('/', '_') 67 | data[key] = meter.value() 68 | return data 69 | 70 | def _dump_to_csv(self, data): 71 | if self._csv_writer is None: 72 | self._csv_writer = csv.DictWriter(self._csv_file, 73 | fieldnames=sorted(data.keys()), 74 | restval=0.0) 75 | self._csv_writer.writeheader() 76 | self._csv_writer.writerow(data) 77 | self._csv_file.flush() 78 | 79 | def _format(self, key, value, ty): 80 | if ty == 'int': 81 | value = int(value) 82 | return f'{key}: {value}' 83 | elif ty == 'float': 84 | return f'{key}: {value:.04f}' 85 | elif ty == 'time': 86 | return f'{key}: {value:04.1f} s' 87 | else: 88 | raise ValueError(f'invalid format type: {ty}') 89 | 90 | def _dump_to_console(self, data, prefix): 91 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 92 | pieces = [f'| {prefix: <14}'] 93 | for key, disp_key, ty in self._formating: 94 | value = data.get(key, 0) 95 | pieces.append(self._format(disp_key, value, ty)) 96 | print(' | '.join(pieces)) 97 | 98 | def dump(self, step, prefix, save=True): 99 | if len(self._meters) == 0: 100 | return 101 | if save: 102 | data = self._prime_meters() 103 | data['step'] = step 104 | self._dump_to_csv(data) 105 | self._dump_to_console(data, prefix) 106 | self._meters.clear() 107 | 108 | 109 | class Logger(object): 110 | def __init__(self, 111 | log_dir, 112 | save_tb=False, 113 | log_frequency=10000, 114 | action_repeat=1, 115 | agent='drq'): 116 | self._log_dir = log_dir 117 | self._log_frequency = log_frequency 118 | self._action_repeat = action_repeat 119 | if save_tb: 120 | tb_dir = os.path.join(log_dir, 'tb') 121 | if os.path.exists(tb_dir): 122 | try: 123 | shutil.rmtree(tb_dir) 124 | except: 125 | print("logger.py warning: Unable to remove tb directory") 126 | pass 127 | self._sw = SummaryWriter(tb_dir) 128 | else: 129 | self._sw = None 130 | # each agent has specific output format for training 131 | assert agent in AGENT_TRAIN_FORMAT 132 | train_format = COMMON_TRAIN_FORMAT + AGENT_TRAIN_FORMAT[agent] 133 | self._train_mg = MetersGroup(os.path.join(log_dir, 'train'), 134 | formating=train_format) 135 | self._eval_mg = MetersGroup(os.path.join(log_dir, 'eval'), 136 | formating=COMMON_EVAL_FORMAT) 137 | 138 | def _should_log(self, step, log_frequency): 139 | log_frequency = log_frequency or self._log_frequency 140 | return step % log_frequency == 0 141 | 142 | def _update_step(self, step): 143 | return step * self._action_repeat 144 | 145 | def _try_sw_log(self, key, value, step): 146 | step = self._update_step(step) 147 | if self._sw is not None: 148 | self._sw.add_scalar(key, value, step) 149 | 150 | def _try_sw_log_image(self, key, image, step): 151 | step = 
self._update_step(step) 152 | if self._sw is not None: 153 | assert image.dim() == 3 154 | grid = torchvision.utils.make_grid(image.unsqueeze(1)) 155 | self._sw.add_image(key, grid, step) 156 | 157 | def _try_sw_log_video(self, key, frames, step): 158 | step = self._update_step(step) 159 | if self._sw is not None: 160 | frames = torch.from_numpy(np.array(frames)) 161 | frames = frames.unsqueeze(0) 162 | self._sw.add_video(key, frames, step, fps=30) 163 | 164 | def _try_sw_log_histogram(self, key, histogram, step): 165 | step = self._update_step(step) 166 | if self._sw is not None: 167 | self._sw.add_histogram(key, histogram, step) 168 | 169 | def log(self, key, value, step, n=1, log_frequency=1): 170 | if not self._should_log(step, log_frequency): 171 | return 172 | assert key.startswith('train') or key.startswith('eval') 173 | if type(value) == torch.Tensor: 174 | value = value.item() 175 | self._try_sw_log(key, value / n, step) 176 | mg = self._train_mg if key.startswith('train') else self._eval_mg 177 | mg.log(key, value, n) 178 | 179 | def log_param(self, key, param, step, log_frequency=None): 180 | if not self._should_log(step, log_frequency): 181 | return 182 | self.log_histogram(key + '_w', param.weight.data, step) 183 | if hasattr(param.weight, 'grad') and param.weight.grad is not None: 184 | self.log_histogram(key + '_w_g', param.weight.grad.data, step) 185 | if hasattr(param, 'bias') and hasattr(param.bias, 'data'): 186 | self.log_histogram(key + '_b', param.bias.data, step) 187 | if hasattr(param.bias, 'grad') and param.bias.grad is not None: 188 | self.log_histogram(key + '_b_g', param.bias.grad.data, step) 189 | 190 | def log_image(self, key, image, step, log_frequency=None): 191 | if not self._should_log(step, log_frequency): 192 | return 193 | assert key.startswith('train') or key.startswith('eval') 194 | self._try_sw_log_image(key, image, step) 195 | 196 | def log_video(self, key, frames, step, log_frequency=None): 197 | if not self._should_log(step, log_frequency): 198 | return 199 | assert key.startswith('train') or key.startswith('eval') 200 | self._try_sw_log_video(key, frames, step) 201 | 202 | def log_histogram(self, key, histogram, step, log_frequency=None): 203 | if not self._should_log(step, log_frequency): 204 | return 205 | assert key.startswith('train') or key.startswith('eval') 206 | self._try_sw_log_histogram(key, histogram, step) 207 | 208 | def dump(self, step, save=True, ty=None): 209 | step = self._update_step(step) 210 | if ty is None: 211 | self._train_mg.dump(step, 'train', save) 212 | self._eval_mg.dump(step, 'eval', save) 213 | elif ty == 'eval': 214 | self._eval_mg.dump(step, 'eval', save) 215 | elif ty == 'train': 216 | self._train_mg.dump(step, 'train', save) 217 | else: 218 | raise ValueError(f'invalid log type: {ty}') 219 | -------------------------------------------------------------------------------- /pngs/dreamer_bench.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/denisyarats/drq/dd040f144ed8f696c7db35087cd80c87372edb64/pngs/dreamer_bench.png -------------------------------------------------------------------------------- /pngs/planet_bench.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/denisyarats/drq/dd040f144ed8f696c7db35087cd80c87372edb64/pngs/planet_bench.png -------------------------------------------------------------------------------- /replay_buffer.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import kornia 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import utils 8 | 9 | 10 | class ReplayBuffer(object): 11 | """Buffer to store environment transitions.""" 12 | def __init__(self, obs_shape, action_shape, capacity, image_pad, device): 13 | self.capacity = capacity 14 | self.device = device 15 | 16 | self.aug_trans = nn.Sequential( 17 | nn.ReplicationPad2d(image_pad), 18 | kornia.augmentation.RandomCrop((obs_shape[-1], obs_shape[-1]))) 19 | 20 | self.obses = np.empty((capacity, *obs_shape), dtype=np.uint8) 21 | self.next_obses = np.empty((capacity, *obs_shape), dtype=np.uint8) 22 | self.actions = np.empty((capacity, *action_shape), dtype=np.float32) 23 | self.rewards = np.empty((capacity, 1), dtype=np.float32) 24 | self.not_dones = np.empty((capacity, 1), dtype=np.float32) 25 | self.not_dones_no_max = np.empty((capacity, 1), dtype=np.float32) 26 | 27 | self.idx = 0 28 | self.full = False 29 | 30 | def __len__(self): 31 | return self.capacity if self.full else self.idx 32 | 33 | def add(self, obs, action, reward, next_obs, done, done_no_max): 34 | np.copyto(self.obses[self.idx], obs) 35 | np.copyto(self.actions[self.idx], action) 36 | np.copyto(self.rewards[self.idx], reward) 37 | np.copyto(self.next_obses[self.idx], next_obs) 38 | np.copyto(self.not_dones[self.idx], not done) 39 | np.copyto(self.not_dones_no_max[self.idx], not done_no_max) 40 | 41 | self.idx = (self.idx + 1) % self.capacity 42 | self.full = self.full or self.idx == 0 43 | 44 | def sample(self, batch_size): 45 | idxs = np.random.randint(0, 46 | self.capacity if self.full else self.idx, 47 | size=batch_size) 48 | 49 | obses = self.obses[idxs] 50 | next_obses = self.next_obses[idxs] 51 | obses_aug = obses.copy() 52 | next_obses_aug = next_obses.copy() 53 | 54 | obses = torch.as_tensor(obses, device=self.device).float() 55 | next_obses = torch.as_tensor(next_obses, device=self.device).float() 56 | obses_aug = torch.as_tensor(obses_aug, device=self.device).float() 57 | next_obses_aug = torch.as_tensor(next_obses_aug, 58 | device=self.device).float() 59 | actions = torch.as_tensor(self.actions[idxs], device=self.device) 60 | rewards = torch.as_tensor(self.rewards[idxs], device=self.device) 61 | not_dones_no_max = torch.as_tensor(self.not_dones_no_max[idxs], 62 | device=self.device) 63 | 64 | obses = self.aug_trans(obses) 65 | next_obses = self.aug_trans(next_obses) 66 | 67 | obses_aug = self.aug_trans(obses_aug) 68 | next_obses_aug = self.aug_trans(next_obses_aug) 69 | 70 | return obses, actions, rewards, next_obses, not_dones_no_max, obses_aug, next_obses_aug 71 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import os 4 | import pickle as pkl 5 | import sys 6 | import time 7 | 8 | import numpy as np 9 | 10 | import dmc2gym 11 | import hydra 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import utils 16 | from logger import Logger 17 | from replay_buffer import ReplayBuffer 18 | from video import VideoRecorder 19 | 20 | torch.backends.cudnn.benchmark = True 21 | 22 | 23 | def make_env(cfg): 24 | """Helper function to create dm_control environment""" 25 | if cfg.env == 'ball_in_cup_catch': 26 | domain_name = 'ball_in_cup' 27 | task_name = 'catch' 28 | elif cfg.env == 
'point_mass_easy': 29 | domain_name = 'point_mass' 30 | task_name = 'easy' 31 | else: 32 | domain_name = cfg.env.split('_')[0] 33 | task_name = '_'.join(cfg.env.split('_')[1:]) 34 | 35 | # per dreamer: https://github.com/danijar/dreamer/blob/02f0210f5991c7710826ca7881f19c64a012290c/wrappers.py#L26 36 | camera_id = 2 if domain_name == 'quadruped' else 0 37 | 38 | env = dmc2gym.make(domain_name=domain_name, 39 | task_name=task_name, 40 | seed=cfg.seed, 41 | visualize_reward=False, 42 | from_pixels=True, 43 | height=cfg.image_size, 44 | width=cfg.image_size, 45 | frame_skip=cfg.action_repeat, 46 | camera_id=camera_id) 47 | 48 | env = utils.FrameStack(env, k=cfg.frame_stack) 49 | 50 | env.seed(cfg.seed) 51 | assert env.action_space.low.min() >= -1 52 | assert env.action_space.high.max() <= 1 53 | 54 | return env 55 | 56 | 57 | class Workspace(object): 58 | def __init__(self, cfg): 59 | self.work_dir = os.getcwd() 60 | print(f'workspace: {self.work_dir}') 61 | 62 | self.cfg = cfg 63 | 64 | self.logger = Logger(self.work_dir, 65 | save_tb=cfg.log_save_tb, 66 | log_frequency=cfg.log_frequency_step, 67 | agent=cfg.agent.name, 68 | action_repeat=cfg.action_repeat) 69 | 70 | utils.set_seed_everywhere(cfg.seed) 71 | self.device = torch.device(cfg.device) 72 | self.env = make_env(cfg) 73 | 74 | cfg.agent.params.obs_shape = self.env.observation_space.shape 75 | cfg.agent.params.action_shape = self.env.action_space.shape 76 | cfg.agent.params.action_range = [ 77 | float(self.env.action_space.low.min()), 78 | float(self.env.action_space.high.max()) 79 | ] 80 | self.agent = hydra.utils.instantiate(cfg.agent) 81 | 82 | self.replay_buffer = ReplayBuffer(self.env.observation_space.shape, 83 | self.env.action_space.shape, 84 | cfg.replay_buffer_capacity, 85 | self.cfg.image_pad, self.device) 86 | 87 | self.video_recorder = VideoRecorder( 88 | self.work_dir if cfg.save_video else None) 89 | self.step = 0 90 | 91 | def evaluate(self): 92 | average_episode_reward = 0 93 | for episode in range(self.cfg.num_eval_episodes): 94 | obs = self.env.reset() 95 | self.video_recorder.init(enabled=(episode == 0)) 96 | done = False 97 | episode_reward = 0 98 | episode_step = 0 99 | while not done: 100 | with utils.eval_mode(self.agent): 101 | action = self.agent.act(obs, sample=False) 102 | obs, reward, done, info = self.env.step(action) 103 | self.video_recorder.record(self.env) 104 | episode_reward += reward 105 | episode_step += 1 106 | 107 | average_episode_reward += episode_reward 108 | self.video_recorder.save(f'{self.step}.mp4') 109 | average_episode_reward /= self.cfg.num_eval_episodes 110 | self.logger.log('eval/episode_reward', average_episode_reward, 111 | self.step) 112 | self.logger.dump(self.step) 113 | 114 | def run(self): 115 | episode, episode_reward, episode_step, done = 0, 0, 1, True 116 | start_time = time.time() 117 | while self.step < self.cfg.num_train_steps: 118 | if done: 119 | if self.step > 0: 120 | self.logger.log('train/duration', 121 | time.time() - start_time, self.step) 122 | start_time = time.time() 123 | self.logger.dump( 124 | self.step, save=(self.step > self.cfg.num_seed_steps)) 125 | 126 | # evaluate agent periodically 127 | if self.step % self.cfg.eval_frequency == 0: 128 | self.logger.log('eval/episode', episode, self.step) 129 | self.evaluate() 130 | 131 | self.logger.log('train/episode_reward', episode_reward, 132 | self.step) 133 | 134 | obs = self.env.reset() 135 | done = False 136 | episode_reward = 0 137 | episode_step = 0 138 | episode += 1 139 | 140 | 
self.logger.log('train/episode', episode, self.step) 141 | 142 | # sample action for data collection 143 | if self.step < self.cfg.num_seed_steps: 144 | action = self.env.action_space.sample() 145 | else: 146 | with utils.eval_mode(self.agent): 147 | action = self.agent.act(obs, sample=True) 148 | 149 | # run training update 150 | if self.step >= self.cfg.num_seed_steps: 151 | for _ in range(self.cfg.num_train_iters): 152 | self.agent.update(self.replay_buffer, self.logger, 153 | self.step) 154 | 155 | next_obs, reward, done, info = self.env.step(action) 156 | 157 | # allow infinite bootstrap 158 | done = float(done) 159 | done_no_max = 0 if episode_step + 1 == self.env._max_episode_steps else done 160 | episode_reward += reward 161 | 162 | self.replay_buffer.add(obs, action, reward, next_obs, done, 163 | done_no_max) 164 | 165 | obs = next_obs 166 | episode_step += 1 167 | self.step += 1 168 | 169 | 170 | @hydra.main(config_path='config.yaml', strict=True) 171 | def main(cfg): 172 | from train import Workspace as W 173 | workspace = W(cfg) 174 | workspace.run() 175 | 176 | 177 | if __name__ == '__main__': 178 | main() 179 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | from collections import deque 5 | 6 | import numpy as np 7 | import scipy.linalg as sp_la 8 | 9 | import gym 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from skimage.util.shape import view_as_windows 14 | from torch import distributions as pyd 15 | 16 | 17 | class eval_mode(object): 18 | def __init__(self, *models): 19 | self.models = models 20 | 21 | def __enter__(self): 22 | self.prev_states = [] 23 | for model in self.models: 24 | self.prev_states.append(model.training) 25 | model.train(False) 26 | 27 | def __exit__(self, *args): 28 | for model, state in zip(self.models, self.prev_states): 29 | model.train(state) 30 | return False 31 | 32 | 33 | def soft_update_params(net, target_net, tau): 34 | for param, target_param in zip(net.parameters(), target_net.parameters()): 35 | target_param.data.copy_(tau * param.data + 36 | (1 - tau) * target_param.data) 37 | 38 | 39 | def set_seed_everywhere(seed): 40 | torch.manual_seed(seed) 41 | if torch.cuda.is_available(): 42 | torch.cuda.manual_seed_all(seed) 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | 46 | 47 | def make_dir(*path_parts): 48 | dir_path = os.path.join(*path_parts) 49 | try: 50 | os.mkdir(dir_path) 51 | except OSError: 52 | pass 53 | return dir_path 54 | 55 | 56 | def tie_weights(src, trg): 57 | assert type(src) == type(trg) 58 | trg.weight = src.weight 59 | trg.bias = src.bias 60 | 61 | 62 | def weight_init(m): 63 | """Custom weight init for Conv2D and Linear layers.""" 64 | if isinstance(m, nn.Linear): 65 | nn.init.orthogonal_(m.weight.data) 66 | if hasattr(m.bias, 'data'): 67 | m.bias.data.fill_(0.0) 68 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 69 | gain = nn.init.calculate_gain('relu') 70 | nn.init.orthogonal_(m.weight.data, gain) 71 | if hasattr(m.bias, 'data'): 72 | m.bias.data.fill_(0.0) 73 | 74 | 75 | def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None): 76 | if hidden_depth == 0: 77 | mods = [nn.Linear(input_dim, output_dim)] 78 | else: 79 | mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] 80 | for i in range(hidden_depth - 1): 81 | mods += [nn.Linear(hidden_dim, 
hidden_dim), nn.ReLU(inplace=True)] 82 | mods.append(nn.Linear(hidden_dim, output_dim)) 83 | if output_mod is not None: 84 | mods.append(output_mod) 85 | trunk = nn.Sequential(*mods) 86 | return trunk 87 | 88 | 89 | def to_np(t): 90 | if t is None: 91 | return None 92 | elif t.nelement() == 0: 93 | return np.array([]) 94 | else: 95 | return t.cpu().detach().numpy() 96 | 97 | 98 | class FrameStack(gym.Wrapper): 99 | def __init__(self, env, k): 100 | gym.Wrapper.__init__(self, env) 101 | self._k = k 102 | self._frames = deque([], maxlen=k) 103 | shp = env.observation_space.shape 104 | self.observation_space = gym.spaces.Box( 105 | low=0, 106 | high=1, 107 | shape=((shp[0] * k,) + shp[1:]), 108 | dtype=env.observation_space.dtype) 109 | self._max_episode_steps = env._max_episode_steps 110 | 111 | def reset(self): 112 | obs = self.env.reset() 113 | for _ in range(self._k): 114 | self._frames.append(obs) 115 | return self._get_obs() 116 | 117 | def step(self, action): 118 | obs, reward, done, info = self.env.step(action) 119 | self._frames.append(obs) 120 | return self._get_obs(), reward, done, info 121 | 122 | def _get_obs(self): 123 | assert len(self._frames) == self._k 124 | return np.concatenate(list(self._frames), axis=0) 125 | 126 | 127 | class TanhTransform(pyd.transforms.Transform): 128 | domain = pyd.constraints.real 129 | codomain = pyd.constraints.interval(-1.0, 1.0) 130 | bijective = True 131 | sign = +1 132 | 133 | def __init__(self, cache_size=1): 134 | super().__init__(cache_size=cache_size) 135 | 136 | @staticmethod 137 | def atanh(x): 138 | return 0.5 * (x.log1p() - (-x).log1p()) 139 | 140 | def __eq__(self, other): 141 | return isinstance(other, TanhTransform) 142 | 143 | def _call(self, x): 144 | return x.tanh() 145 | 146 | def _inverse(self, y): 147 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. 148 | # one should use `cache_size=1` instead 149 | return self.atanh(y) 150 | 151 | def log_abs_det_jacobian(self, x, y): 152 | # We use a formula that is more numerically stable, see details in the following link 153 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 154 | return 2. * (math.log(2.) - x - F.softplus(-2. 
* x)) 155 | 156 | 157 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution): 158 | def __init__(self, loc, scale): 159 | self.loc = loc 160 | self.scale = scale 161 | 162 | self.base_dist = pyd.Normal(loc, scale) 163 | transforms = [TanhTransform()] 164 | super().__init__(self.base_dist, transforms) 165 | 166 | @property 167 | def mean(self): 168 | mu = self.loc 169 | for tr in self.transforms: 170 | mu = tr(mu) 171 | return mu -------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import imageio 5 | import numpy as np 6 | 7 | import utils 8 | 9 | 10 | class VideoRecorder(object): 11 | def __init__(self, root_dir, height=256, width=256, fps=10): 12 | self.save_dir = utils.make_dir(root_dir, 'video') if root_dir else None 13 | self.height = height 14 | self.width = width 15 | self.fps = fps 16 | self.frames = [] 17 | 18 | def init(self, enabled=True): 19 | self.frames = [] 20 | self.enabled = self.save_dir is not None and enabled 21 | 22 | def record(self, env): 23 | if self.enabled: 24 | frame = env.render(mode='rgb_array', 25 | height=self.height, 26 | width=self.width) 27 | self.frames.append(frame) 28 | 29 | def save(self, file_name): 30 | if self.enabled: 31 | path = os.path.join(self.save_dir, file_name) 32 | imageio.mimsave(path, self.frames, fps=self.fps) 33 | --------------------------------------------------------------------------------