├── .gitignore ├── Data Aug Visualization.ipynb ├── README.md ├── TransformLayer.py ├── conda_env.yml ├── curl_sac.py ├── data_augs.py ├── data_sample.npy ├── encoder.py ├── logger.py ├── rad_thumb.png ├── scripts ├── cheetah_test.sh └── run.sh ├── train.py ├── utils.py └── video.py /.gitignore: -------------------------------------------------------------------------------- 1 | tmp/ 2 | notebooks/ 3 | __pycache__/ 4 | .ipynb_checkpoints/ 5 | scripts/run_*.sh 6 | test* 7 | *frame2state* 8 | cartpole-swingup-04-05-im84-b128-s23-pixel/ 9 | cartpole-swingup-04-06-im84-b128-s23-pixel/ 10 | walker-walk-04-06-im84-b128-s23-pixel/ 11 | single_aug_040520.npy -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning with Augmented Data (RAD) 2 | 3 | Official codebase for [Reinforcement Learning with Augmented Data](https://mishalaskin.github.io/rad). This codebase was originally forked from [CURL](https://mishalaskin.github.io/curl). 4 | 5 | Additionally, here is the [codebase link for ProcGen experiments](https://github.com/pokaxpoka/rad_procgen) and [codebase link for OpenAI Gym experiments](https://github.com/pokaxpoka/rad_openaigym). 6 | 7 | 8 | ## BibTex 9 | 10 | ``` 11 | @article{laskin2020reinforcement, 12 | title={Reinforcement learning with augmented data}, 13 | author={Laskin, Michael and Lee, Kimin and Stooke, Adam and Pinto, Lerrel and Abbeel, Pieter and Srinivas, Aravind}, 14 | journal={arXiv preprint arXiv:2004.14990}, 15 | year={2020} 16 | } 17 | ``` 18 | 19 | ## Installation 20 | 21 | All of the dependencies are in the `conda_env.yml` file. They can be installed manually or with the following command: 22 | 23 | ``` 24 | conda env create -f conda_env.yml 25 | ``` 26 | 27 | ## Instructions 28 | To train a RAD agent on the `cartpole swingup` task from image-based observations, run `bash scripts/run.sh` from the root of this directory. The `run.sh` file contains the following command, which you can modify to try different environments / augmentations / hyperparameters. 29 | 30 | ``` 31 | CUDA_VISIBLE_DEVICES=0 python train.py \ 32 | --domain_name cartpole \ 33 | --task_name swingup \ 34 | --encoder_type pixel --work_dir ./tmp/cartpole \ 35 | --action_repeat 8 --num_eval_episodes 10 \ 36 | --pre_transform_image_size 100 --image_size 84 \ 37 | --agent rad_sac --frame_stack 3 --data_augs flip \ 38 | --seed 23 --critic_lr 1e-3 --actor_lr 1e-3 --eval_freq 10000 --batch_size 128 --num_train_steps 200000 & 39 | ``` 40 | 41 | ## Data Augmentations 42 | 43 | Augmentations can be specified through the `--data_augs` flag. This codebase supports the augmentations specified in `data_augs.py`. To chain multiple data augmentations, simply separate the augmentation names with a `-`. For example, to apply `crop -> rotate -> flip`, pass `--data_augs crop-rotate-flip`. 44 | 45 | All data augmentations can be visualized in `Data Aug Visualization.ipynb`. You can also test the efficiency of our modules by running `python data_augs.py`. 
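For reference, the snippet below is a minimal sketch (not part of the repo) showing how the functions in `data_augs.py` can also be called directly on a batch of frame-stacked observations. The batch shape `(128, 9, 100, 100)` is an assumption matching the default `--frame_stack 3` (three stacked RGB frames) and `--pre_transform_image_size 100`.

```
import numpy as np
import torch
import data_augs as rad

# fake batch: 128 observations, 3 stacked RGB frames (9 channels), 100x100 pixels
obs = np.random.randint(0, 256, size=(128, 9, 100, 100), dtype=np.uint8)

# random_crop operates on numpy arrays and returns an 84x84 batch
cropped = rad.random_crop(obs, out=84)  # -> (128, 9, 84, 84)

# random_flip / random_grayscale operate on torch tensors normalized to [0, 1]
obs_t = torch.from_numpy(cropped).float() / 255.
flipped = rad.random_flip(obs_t, p=0.5)
gray = rad.random_grayscale(obs_t, p=0.5)
```

During training you do not call these functions yourself: the agent looks up each name passed to `--data_augs` in the `aug_to_func` dictionary in `curl_sac.py` and applies the corresponding functions when sampling from the replay buffer.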
46 | 47 | 48 | ## Logging 49 | 50 | In your console, you should see printouts that look like this: 51 | 52 | ``` 53 | | train | E: 13 | S: 2000 | D: 9.1 s | R: 48.3056 | BR: 0.8279 | A_LOSS: -3.6559 | CR_LOSS: 2.7563 54 | | train | E: 17 | S: 2500 | D: 9.1 s | R: 146.5945 | BR: 0.9066 | A_LOSS: -5.8576 | CR_LOSS: 6.0176 55 | | train | E: 21 | S: 3000 | D: 7.7 s | R: 138.7537 | BR: 1.0354 | A_LOSS: -7.8795 | CR_LOSS: 7.3928 56 | | train | E: 25 | S: 3500 | D: 9.0 s | R: 181.5103 | BR: 1.0764 | A_LOSS: -10.9712 | CR_LOSS: 8.8753 57 | | train | E: 29 | S: 4000 | D: 8.9 s | R: 240.6485 | BR: 1.2042 | A_LOSS: -13.8537 | CR_LOSS: 9.4001 58 | ``` 59 | The above output decodes as: 60 | 61 | ``` 62 | train - training episode 63 | E - total number of episodes 64 | S - total number of environment steps 65 | D - duration in seconds to train 1 episode 66 | R - episode reward 67 | BR - average reward of sampled batch 68 | A_LOSS - average loss of actor 69 | CR_LOSS - average loss of critic 70 | ``` 71 | 72 | All data related to the run is stored in the specified `working_dir`. To enable model or video saving, use the `--save_model` or `--save_video` flags. For all available flags, inspect `train.py`. To visualize progress with tensorboard run: 73 | 74 | ``` 75 | tensorboard --logdir log --port 6006 76 | ``` 77 | 78 | and go to `localhost:6006` in your browser. If you're running headlessly, try port forwarding with ssh. 79 | 80 | -------------------------------------------------------------------------------- /TransformLayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import numpy as np 6 | import numbers 7 | import random 8 | import time 9 | 10 | def rgb2hsv(rgb, eps=1e-8): 11 | # Reference: https://www.rapidtables.com/convert/color/rgb-to-hsv.html 12 | # Reference: https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287 13 | 14 | _device = rgb.device 15 | r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :] 16 | 17 | Cmax = rgb.max(1)[0] 18 | Cmin = rgb.min(1)[0] 19 | delta = Cmax - Cmin 20 | 21 | hue = torch.zeros((rgb.shape[0], rgb.shape[2], rgb.shape[3])).to(_device) 22 | hue[Cmax== r] = (((g - b)/(delta + eps)) % 6)[Cmax == r] 23 | hue[Cmax == g] = ((b - r)/(delta + eps) + 2)[Cmax == g] 24 | hue[Cmax == b] = ((r - g)/(delta + eps) + 4)[Cmax == b] 25 | hue[Cmax == 0] = 0.0 26 | hue = hue / 6. # making hue range as [0, 1.0) 27 | hue = hue.unsqueeze(dim=1) 28 | 29 | saturation = (delta) / (Cmax + eps) 30 | saturation[Cmax == 0.] = 0. 31 | saturation = saturation.to(_device) 32 | saturation = saturation.unsqueeze(dim=1) 33 | 34 | value = Cmax 35 | value = value.to(_device) 36 | value = value.unsqueeze(dim=1) 37 | 38 | return torch.cat((hue, saturation, value), dim=1)#.type(torch.FloatTensor).to(_device) 39 | # return hue, saturation, value 40 | 41 | def hsv2rgb(hsv): 42 | # Reference: https://www.rapidtables.com/convert/color/hsv-to-rgb.html 43 | # Reference: https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287 44 | 45 | _device = hsv.device 46 | 47 | hsv = torch.clamp(hsv, 0, 1) 48 | hue = hsv[:, 0, :, :] * 360. 49 | saturation = hsv[:, 1, :, :] 50 | value = hsv[:, 2, :, :] 51 | 52 | c = value * saturation 53 | x = - c * (torch.abs((hue / 60.) 
% 2 - 1) - 1) 54 | m = (value - c).unsqueeze(dim=1) 55 | 56 | rgb_prime = torch.zeros_like(hsv).to(_device) 57 | 58 | inds = (hue < 60) * (hue >= 0) 59 | rgb_prime[:, 0, :, :][inds] = c[inds] 60 | rgb_prime[:, 1, :, :][inds] = x[inds] 61 | 62 | inds = (hue < 120) * (hue >= 60) 63 | rgb_prime[:, 0, :, :][inds] = x[inds] 64 | rgb_prime[:, 1, :, :][inds] = c[inds] 65 | 66 | inds = (hue < 180) * (hue >= 120) 67 | rgb_prime[:, 1, :, :][inds] = c[inds] 68 | rgb_prime[:, 2, :, :][inds] = x[inds] 69 | 70 | inds = (hue < 240) * (hue >= 180) 71 | rgb_prime[:, 1, :, :][inds] = x[inds] 72 | rgb_prime[:, 2, :, :][inds] = c[inds] 73 | 74 | inds = (hue < 300) * (hue >= 240) 75 | rgb_prime[:, 2, :, :][inds] = c[inds] 76 | rgb_prime[:, 0, :, :][inds] = x[inds] 77 | 78 | inds = (hue < 360) * (hue >= 300) 79 | rgb_prime[:, 2, :, :][inds] = x[inds] 80 | rgb_prime[:, 0, :, :][inds] = c[inds] 81 | 82 | rgb = rgb_prime + torch.cat((m, m, m), dim=1) 83 | rgb = rgb.to(_device) 84 | 85 | return torch.clamp(rgb, 0, 1) 86 | 87 | class ColorJitterLayer(nn.Module): 88 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0, batch_size=128, stack_size=3): 89 | super(ColorJitterLayer, self).__init__() 90 | self.brightness = self._check_input(brightness, 'brightness') 91 | self.contrast = self._check_input(contrast, 'contrast') 92 | self.saturation = self._check_input(saturation, 'saturation') 93 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 94 | clip_first_on_zero=False) 95 | self.prob = p 96 | self.batch_size = batch_size 97 | self.stack_size = stack_size 98 | 99 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 100 | if isinstance(value, numbers.Number): 101 | if value < 0: 102 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 103 | value = [center - value, center + value] 104 | if clip_first_on_zero: 105 | value[0] = max(value[0], 0) 106 | elif isinstance(value, (tuple, list)) and len(value) == 2: 107 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 108 | raise ValueError("{} values should be between {}".format(name, bound)) 109 | else: 110 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 111 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 112 | # or (0., 0.) for hue, do nothing 113 | if value[0] == value[1] == center: 114 | value = None 115 | return value 116 | 117 | def adjust_contrast(self, x): 118 | """ 119 | Args: 120 | x: torch tensor img (rgb type) 121 | Factor: torch tensor with same length as x 122 | 0 gives gray solid image, 1 gives original image, 123 | Returns: 124 | torch tensor image: Brightness adjusted 125 | """ 126 | _device = x.device 127 | factor = torch.empty(self.batch_size, device=_device).uniform_(*self.contrast) 128 | factor = factor.reshape(-1,1).repeat(1, self.stack_size).reshape(-1) 129 | means = torch.mean(x, dim=(2, 3), keepdim=True) 130 | return torch.clamp((x - means) 131 | * factor.view(len(x), 1, 1, 1) + means, 0, 1) 132 | 133 | def adjust_hue(self, x): 134 | _device = x.device 135 | factor = torch.empty(self.batch_size, device=_device).uniform_(*self.hue) 136 | factor = factor.reshape(-1,1).repeat(1, self.stack_size).reshape(-1) 137 | h = x[:, 0, :, :] 138 | h += (factor.view(len(x), 1, 1) * 255. / 360.) 
139 | h = (h % 1) 140 | x[:, 0, :, :] = h 141 | return x 142 | 143 | def adjust_brightness(self, x): 144 | """ 145 | Args: 146 | x: torch tensor img (hsv type) 147 | Factor: 148 | torch tensor with same length as x 149 | 0 gives black image, 1 gives original image, 150 | 2 gives the brightness factor of 2. 151 | Returns: 152 | torch tensor image: Brightness adjusted 153 | """ 154 | _device = x.device 155 | factor = torch.empty(self.batch_size, device=_device).uniform_(*self.brightness) 156 | factor = factor.reshape(-1,1).repeat(1, self.stack_size).reshape(-1) 157 | x[:, 2, :, :] = torch.clamp(x[:, 2, :, :] 158 | * factor.view(len(x), 1, 1), 0, 1) 159 | return torch.clamp(x, 0, 1) 160 | 161 | def adjust_saturate(self, x): 162 | """ 163 | Args: 164 | x: torch tensor img (hsv type) 165 | Factor: 166 | torch tensor with same length as x 167 | 0 gives black image and white, 1 gives original image, 168 | 2 gives the brightness factor of 2. 169 | Returns: 170 | torch tensor image: Brightness adjusted 171 | """ 172 | _device = x.device 173 | factor = torch.empty(self.batch_size, device=_device).uniform_(*self.saturation) 174 | factor = factor.reshape(-1,1).repeat(1, self.stack_size).reshape(-1) 175 | x[:, 1, :, :] = torch.clamp(x[:, 1, :, :] 176 | * factor.view(len(x), 1, 1), 0, 1) 177 | return torch.clamp(x, 0, 1) 178 | 179 | def transform(self, inputs): 180 | hsv_transform_list = [rgb2hsv, self.adjust_brightness, 181 | self.adjust_hue, self.adjust_saturate, 182 | hsv2rgb] 183 | rgb_transform_list = [self.adjust_contrast] 184 | # Shuffle transform 185 | if random.uniform(0,1) >= 0.5: 186 | transform_list = rgb_transform_list + hsv_transform_list 187 | else: 188 | transform_list = hsv_transform_list + rgb_transform_list 189 | for t in transform_list: 190 | inputs = t(inputs) 191 | return inputs 192 | 193 | def forward(self, inputs): 194 | _device = inputs.device 195 | random_inds = np.random.choice( 196 | [True, False], len(inputs), p=[self.prob, 1 - self.prob]) 197 | inds = torch.tensor(random_inds).to(_device) 198 | if random_inds.sum() > 0: 199 | inputs[inds] = self.transform(inputs[inds]) 200 | return inputs -------------------------------------------------------------------------------- /conda_env.yml: -------------------------------------------------------------------------------- 1 | name: rad 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.6 6 | - pytorch 7 | - torchvision 8 | - cudatoolkit=9.2 9 | - absl-py 10 | - pyparsing 11 | - pillow=6.1 12 | - pip: 13 | - termcolor 14 | - git+git://github.com/deepmind/dm_control.git 15 | - git+git://github.com/1nadequacy/dmc2gym.git 16 | - tb-nightly 17 | - imageio 18 | - imageio-ffmpeg 19 | - torchvision 20 | - scikit-image 21 | - tabulate -------------------------------------------------------------------------------- /curl_sac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import copy 6 | import math 7 | 8 | import utils 9 | from encoder import make_encoder 10 | import data_augs as rad 11 | 12 | LOG_FREQ = 10000 13 | 14 | 15 | def gaussian_logprob(noise, log_std): 16 | """Compute Gaussian log probability.""" 17 | residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True) 18 | return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1) 19 | 20 | 21 | def squash(mu, pi, log_pi): 22 | """Apply squashing function. 23 | See appendix C from https://arxiv.org/pdf/1812.05905.pdf. 
24 | """ 25 | mu = torch.tanh(mu) 26 | if pi is not None: 27 | pi = torch.tanh(pi) 28 | if log_pi is not None: 29 | log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) 30 | return mu, pi, log_pi 31 | 32 | 33 | def weight_init(m): 34 | """Custom weight init for Conv2D and Linear layers.""" 35 | if isinstance(m, nn.Linear): 36 | nn.init.orthogonal_(m.weight.data) 37 | m.bias.data.fill_(0.0) 38 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 39 | # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf 40 | assert m.weight.size(2) == m.weight.size(3) 41 | m.weight.data.fill_(0.0) 42 | m.bias.data.fill_(0.0) 43 | mid = m.weight.size(2) // 2 44 | gain = nn.init.calculate_gain('relu') 45 | nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) 46 | 47 | 48 | class Actor(nn.Module): 49 | """MLP actor network.""" 50 | def __init__( 51 | self, obs_shape, action_shape, hidden_dim, encoder_type, 52 | encoder_feature_dim, log_std_min, log_std_max, num_layers, num_filters 53 | ): 54 | super().__init__() 55 | 56 | self.encoder = make_encoder( 57 | encoder_type, obs_shape, encoder_feature_dim, num_layers, 58 | num_filters, output_logits=True 59 | ) 60 | 61 | self.log_std_min = log_std_min 62 | self.log_std_max = log_std_max 63 | 64 | self.trunk = nn.Sequential( 65 | nn.Linear(self.encoder.feature_dim, hidden_dim), nn.ReLU(), 66 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), 67 | nn.Linear(hidden_dim, 2 * action_shape[0]) 68 | ) 69 | 70 | self.outputs = dict() 71 | self.apply(weight_init) 72 | 73 | def forward( 74 | self, obs, compute_pi=True, compute_log_pi=True, detach_encoder=False 75 | ): 76 | obs = self.encoder(obs, detach=detach_encoder) 77 | 78 | mu, log_std = self.trunk(obs).chunk(2, dim=-1) 79 | 80 | # constrain log_std inside [log_std_min, log_std_max] 81 | log_std = torch.tanh(log_std) 82 | log_std = self.log_std_min + 0.5 * ( 83 | self.log_std_max - self.log_std_min 84 | ) * (log_std + 1) 85 | 86 | self.outputs['mu'] = mu 87 | self.outputs['std'] = log_std.exp() 88 | 89 | if compute_pi: 90 | std = log_std.exp() 91 | noise = torch.randn_like(mu) 92 | pi = mu + noise * std 93 | else: 94 | pi = None 95 | entropy = None 96 | 97 | if compute_log_pi: 98 | log_pi = gaussian_logprob(noise, log_std) 99 | else: 100 | log_pi = None 101 | 102 | mu, pi, log_pi = squash(mu, pi, log_pi) 103 | 104 | return mu, pi, log_pi, log_std 105 | 106 | def log(self, L, step, log_freq=LOG_FREQ): 107 | if step % log_freq != 0: 108 | return 109 | 110 | for k, v in self.outputs.items(): 111 | L.log_histogram('train_actor/%s_hist' % k, v, step) 112 | 113 | L.log_param('train_actor/fc1', self.trunk[0], step) 114 | L.log_param('train_actor/fc2', self.trunk[2], step) 115 | L.log_param('train_actor/fc3', self.trunk[4], step) 116 | 117 | 118 | class QFunction(nn.Module): 119 | """MLP for q-function.""" 120 | def __init__(self, obs_dim, action_dim, hidden_dim): 121 | super().__init__() 122 | 123 | self.trunk = nn.Sequential( 124 | nn.Linear(obs_dim + action_dim, hidden_dim), nn.ReLU(), 125 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), 126 | nn.Linear(hidden_dim, 1) 127 | ) 128 | 129 | def forward(self, obs, action): 130 | assert obs.size(0) == action.size(0) 131 | 132 | obs_action = torch.cat([obs, action], dim=1) 133 | return self.trunk(obs_action) 134 | 135 | 136 | class Critic(nn.Module): 137 | """Critic network, employes two q-functions.""" 138 | def __init__( 139 | self, obs_shape, action_shape, hidden_dim, encoder_type, 140 | encoder_feature_dim, num_layers, num_filters 
141 | ): 142 | super().__init__() 143 | 144 | 145 | self.encoder = make_encoder( 146 | encoder_type, obs_shape, encoder_feature_dim, num_layers, 147 | num_filters, output_logits=True 148 | ) 149 | 150 | self.Q1 = QFunction( 151 | self.encoder.feature_dim, action_shape[0], hidden_dim 152 | ) 153 | self.Q2 = QFunction( 154 | self.encoder.feature_dim, action_shape[0], hidden_dim 155 | ) 156 | 157 | self.outputs = dict() 158 | self.apply(weight_init) 159 | 160 | def forward(self, obs, action, detach_encoder=False): 161 | # detach_encoder allows to stop gradient propogation to encoder 162 | obs = self.encoder(obs, detach=detach_encoder) 163 | 164 | q1 = self.Q1(obs, action) 165 | q2 = self.Q2(obs, action) 166 | 167 | self.outputs['q1'] = q1 168 | self.outputs['q2'] = q2 169 | 170 | return q1, q2 171 | 172 | def log(self, L, step, log_freq=LOG_FREQ): 173 | if step % log_freq != 0: 174 | return 175 | 176 | self.encoder.log(L, step, log_freq) 177 | 178 | for k, v in self.outputs.items(): 179 | L.log_histogram('train_critic/%s_hist' % k, v, step) 180 | 181 | for i in range(3): 182 | L.log_param('train_critic/q1_fc%d' % i, self.Q1.trunk[i * 2], step) 183 | L.log_param('train_critic/q2_fc%d' % i, self.Q2.trunk[i * 2], step) 184 | 185 | 186 | class CURL(nn.Module): 187 | """ 188 | CURL 189 | """ 190 | 191 | def __init__(self, obs_shape, z_dim, batch_size, critic, critic_target, output_type="continuous"): 192 | super(CURL, self).__init__() 193 | self.batch_size = batch_size 194 | 195 | self.encoder = critic.encoder 196 | 197 | self.encoder_target = critic_target.encoder 198 | 199 | self.W = nn.Parameter(torch.rand(z_dim, z_dim)) 200 | self.output_type = output_type 201 | 202 | def encode(self, x, detach=False, ema=False): 203 | """ 204 | Encoder: z_t = e(x_t) 205 | :param x: x_t, x y coordinates 206 | :return: z_t, value in r2 207 | """ 208 | if ema: 209 | with torch.no_grad(): 210 | z_out = self.encoder_target(x) 211 | else: 212 | z_out = self.encoder(x) 213 | 214 | if detach: 215 | z_out = z_out.detach() 216 | return z_out 217 | 218 | #def update_target(self): 219 | # utils.soft_update_params(self.encoder, self.encoder_target, 0.05) 220 | 221 | def compute_logits(self, z_a, z_pos): 222 | """ 223 | Uses logits trick for CURL: 224 | - compute (B,B) matrix z_a (W z_pos.T) 225 | - positives are all diagonal elements 226 | - negatives are all other elements 227 | - to compute loss use multiclass cross entropy with identity matrix for labels 228 | """ 229 | Wz = torch.matmul(self.W, z_pos.T) # (z_dim,B) 230 | logits = torch.matmul(z_a, Wz) # (B,B) 231 | logits = logits - torch.max(logits, 1)[0][:, None] 232 | return logits 233 | 234 | class RadSacAgent(object): 235 | """RAD with SAC.""" 236 | def __init__( 237 | self, 238 | obs_shape, 239 | action_shape, 240 | device, 241 | hidden_dim=256, 242 | discount=0.99, 243 | init_temperature=0.01, 244 | alpha_lr=1e-3, 245 | alpha_beta=0.9, 246 | actor_lr=1e-3, 247 | actor_beta=0.9, 248 | actor_log_std_min=-10, 249 | actor_log_std_max=2, 250 | actor_update_freq=2, 251 | critic_lr=1e-3, 252 | critic_beta=0.9, 253 | critic_tau=0.005, 254 | critic_target_update_freq=2, 255 | encoder_type='pixel', 256 | encoder_feature_dim=50, 257 | encoder_lr=1e-3, 258 | encoder_tau=0.005, 259 | num_layers=4, 260 | num_filters=32, 261 | cpc_update_freq=1, 262 | log_interval=100, 263 | detach_encoder=False, 264 | latent_dim=128, 265 | data_augs = '', 266 | ): 267 | self.device = device 268 | self.discount = discount 269 | self.critic_tau = critic_tau 270 | self.encoder_tau = encoder_tau 
271 | self.actor_update_freq = actor_update_freq 272 | self.critic_target_update_freq = critic_target_update_freq 273 | self.cpc_update_freq = cpc_update_freq 274 | self.log_interval = log_interval 275 | self.image_size = obs_shape[-1] 276 | self.latent_dim = latent_dim 277 | self.detach_encoder = detach_encoder 278 | self.encoder_type = encoder_type 279 | self.data_augs = data_augs 280 | 281 | self.augs_funcs = {} 282 | 283 | aug_to_func = { 284 | 'crop':rad.random_crop, 285 | 'grayscale':rad.random_grayscale, 286 | 'cutout':rad.random_cutout, 287 | 'cutout_color':rad.random_cutout_color, 288 | 'flip':rad.random_flip, 289 | 'rotate':rad.random_rotation, 290 | 'rand_conv':rad.random_convolution, 291 | 'color_jitter':rad.random_color_jitter, 292 | 'translate':rad.random_translate, 293 | 'no_aug':rad.no_aug, 294 | } 295 | 296 | for aug_name in self.data_augs.split('-'): 297 | assert aug_name in aug_to_func, 'invalid data aug string' 298 | self.augs_funcs[aug_name] = aug_to_func[aug_name] 299 | 300 | self.actor = Actor( 301 | obs_shape, action_shape, hidden_dim, encoder_type, 302 | encoder_feature_dim, actor_log_std_min, actor_log_std_max, 303 | num_layers, num_filters 304 | ).to(device) 305 | 306 | self.critic = Critic( 307 | obs_shape, action_shape, hidden_dim, encoder_type, 308 | encoder_feature_dim, num_layers, num_filters 309 | ).to(device) 310 | 311 | self.critic_target = Critic( 312 | obs_shape, action_shape, hidden_dim, encoder_type, 313 | encoder_feature_dim, num_layers, num_filters 314 | ).to(device) 315 | 316 | self.critic_target.load_state_dict(self.critic.state_dict()) 317 | 318 | # tie encoders between actor and critic, and CURL and critic 319 | self.actor.encoder.copy_conv_weights_from(self.critic.encoder) 320 | 321 | self.log_alpha = torch.tensor(np.log(init_temperature)).to(device) 322 | self.log_alpha.requires_grad = True 323 | # set target entropy to -|A| 324 | self.target_entropy = -np.prod(action_shape) 325 | 326 | # optimizers 327 | self.actor_optimizer = torch.optim.Adam( 328 | self.actor.parameters(), lr=actor_lr, betas=(actor_beta, 0.999) 329 | ) 330 | 331 | self.critic_optimizer = torch.optim.Adam( 332 | self.critic.parameters(), lr=critic_lr, betas=(critic_beta, 0.999) 333 | ) 334 | 335 | self.log_alpha_optimizer = torch.optim.Adam( 336 | [self.log_alpha], lr=alpha_lr, betas=(alpha_beta, 0.999) 337 | ) 338 | 339 | if self.encoder_type == 'pixel': 340 | # create CURL encoder (the 128 batch size is probably unnecessary) 341 | self.CURL = CURL(obs_shape, encoder_feature_dim, 342 | self.latent_dim, self.critic,self.critic_target, output_type='continuous').to(self.device) 343 | 344 | # optimizer for critic encoder for reconstruction loss 345 | self.encoder_optimizer = torch.optim.Adam( 346 | self.critic.encoder.parameters(), lr=encoder_lr 347 | ) 348 | 349 | self.cpc_optimizer = torch.optim.Adam( 350 | self.CURL.parameters(), lr=encoder_lr 351 | ) 352 | self.cross_entropy_loss = nn.CrossEntropyLoss() 353 | 354 | self.train() 355 | self.critic_target.train() 356 | 357 | def train(self, training=True): 358 | self.training = training 359 | self.actor.train(training) 360 | self.critic.train(training) 361 | if self.encoder_type == 'pixel': 362 | self.CURL.train(training) 363 | 364 | @property 365 | def alpha(self): 366 | return self.log_alpha.exp() 367 | 368 | def select_action(self, obs): 369 | with torch.no_grad(): 370 | obs = torch.FloatTensor(obs).to(self.device) 371 | obs = obs.unsqueeze(0) 372 | mu, _, _, _ = self.actor( 373 | obs, compute_pi=False, compute_log_pi=False 
374 | ) 375 | return mu.cpu().data.numpy().flatten() 376 | 377 | def sample_action(self, obs): 378 | if obs.shape[-1] != self.image_size: 379 | obs = utils.center_crop_image(obs, self.image_size) 380 | 381 | with torch.no_grad(): 382 | obs = torch.FloatTensor(obs).to(self.device) 383 | obs = obs.unsqueeze(0) 384 | mu, pi, _, _ = self.actor(obs, compute_log_pi=False) 385 | return pi.cpu().data.numpy().flatten() 386 | 387 | def update_critic(self, obs, action, reward, next_obs, not_done, L, step): 388 | with torch.no_grad(): 389 | _, policy_action, log_pi, _ = self.actor(next_obs) 390 | target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) 391 | target_V = torch.min(target_Q1, 392 | target_Q2) - self.alpha.detach() * log_pi 393 | target_Q = reward + (not_done * self.discount * target_V) 394 | 395 | # get current Q estimates 396 | current_Q1, current_Q2 = self.critic( 397 | obs, action, detach_encoder=self.detach_encoder) 398 | critic_loss = F.mse_loss(current_Q1, 399 | target_Q) + F.mse_loss(current_Q2, target_Q) 400 | if step % self.log_interval == 0: 401 | L.log('train_critic/loss', critic_loss, step) 402 | 403 | 404 | # Optimize the critic 405 | self.critic_optimizer.zero_grad() 406 | critic_loss.backward() 407 | self.critic_optimizer.step() 408 | 409 | self.critic.log(L, step) 410 | 411 | def update_actor_and_alpha(self, obs, L, step): 412 | # detach encoder, so we don't update it with the actor loss 413 | _, pi, log_pi, log_std = self.actor(obs, detach_encoder=True) 414 | actor_Q1, actor_Q2 = self.critic(obs, pi, detach_encoder=True) 415 | 416 | actor_Q = torch.min(actor_Q1, actor_Q2) 417 | actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean() 418 | 419 | if step % self.log_interval == 0: 420 | L.log('train_actor/loss', actor_loss, step) 421 | L.log('train_actor/target_entropy', self.target_entropy, step) 422 | entropy = 0.5 * log_std.shape[1] * \ 423 | (1.0 + np.log(2 * np.pi)) + log_std.sum(dim=-1) 424 | if step % self.log_interval == 0: 425 | L.log('train_actor/entropy', entropy.mean(), step) 426 | 427 | # optimize the actor 428 | self.actor_optimizer.zero_grad() 429 | actor_loss.backward() 430 | self.actor_optimizer.step() 431 | 432 | self.actor.log(L, step) 433 | 434 | self.log_alpha_optimizer.zero_grad() 435 | alpha_loss = (self.alpha * 436 | (-log_pi - self.target_entropy).detach()).mean() 437 | if step % self.log_interval == 0: 438 | L.log('train_alpha/loss', alpha_loss, step) 439 | L.log('train_alpha/value', self.alpha, step) 440 | alpha_loss.backward() 441 | self.log_alpha_optimizer.step() 442 | 443 | def update_cpc(self, obs_anchor, obs_pos, cpc_kwargs, L, step): 444 | 445 | # time flips 446 | """ 447 | time_pos = cpc_kwargs["time_pos"] 448 | time_anchor= cpc_kwargs["time_anchor"] 449 | obs_anchor = torch.cat((obs_anchor, time_anchor), 0) 450 | obs_pos = torch.cat((obs_anchor, time_pos), 0) 451 | """ 452 | z_a = self.CURL.encode(obs_anchor) 453 | z_pos = self.CURL.encode(obs_pos, ema=True) 454 | 455 | logits = self.CURL.compute_logits(z_a, z_pos) 456 | labels = torch.arange(logits.shape[0]).long().to(self.device) 457 | loss = self.cross_entropy_loss(logits, labels) 458 | 459 | self.encoder_optimizer.zero_grad() 460 | self.cpc_optimizer.zero_grad() 461 | loss.backward() 462 | 463 | self.encoder_optimizer.step() 464 | self.cpc_optimizer.step() 465 | if step % self.log_interval == 0: 466 | L.log('train/curl_loss', loss, step) 467 | 468 | 469 | def update(self, replay_buffer, L, step): 470 | if self.encoder_type == 'pixel': 471 | obs, action, reward, next_obs, 
not_done = replay_buffer.sample_rad(self.augs_funcs) 472 | else: 473 | obs, action, reward, next_obs, not_done = replay_buffer.sample_proprio() 474 | 475 | if step % self.log_interval == 0: 476 | L.log('train/batch_reward', reward.mean(), step) 477 | 478 | self.update_critic(obs, action, reward, next_obs, not_done, L, step) 479 | 480 | if step % self.actor_update_freq == 0: 481 | self.update_actor_and_alpha(obs, L, step) 482 | 483 | if step % self.critic_target_update_freq == 0: 484 | utils.soft_update_params( 485 | self.critic.Q1, self.critic_target.Q1, self.critic_tau 486 | ) 487 | utils.soft_update_params( 488 | self.critic.Q2, self.critic_target.Q2, self.critic_tau 489 | ) 490 | utils.soft_update_params( 491 | self.critic.encoder, self.critic_target.encoder, 492 | self.encoder_tau 493 | ) 494 | 495 | #if step % self.cpc_update_freq == 0 and self.encoder_type == 'pixel': 496 | # obs_anchor, obs_pos = cpc_kwargs["obs_anchor"], cpc_kwargs["obs_pos"] 497 | # self.update_cpc(obs_anchor, obs_pos,cpc_kwargs, L, step) 498 | 499 | def save(self, model_dir, step): 500 | torch.save( 501 | self.actor.state_dict(), '%s/actor_%s.pt' % (model_dir, step) 502 | ) 503 | torch.save( 504 | self.critic.state_dict(), '%s/critic_%s.pt' % (model_dir, step) 505 | ) 506 | 507 | def save_curl(self, model_dir, step): 508 | torch.save( 509 | self.CURL.state_dict(), '%s/curl_%s.pt' % (model_dir, step) 510 | ) 511 | 512 | def load(self, model_dir, step): 513 | self.actor.load_state_dict( 514 | torch.load('%s/actor_%s.pt' % (model_dir, step)) 515 | ) 516 | self.critic.load_state_dict( 517 | torch.load('%s/critic_%s.pt' % (model_dir, step)) 518 | ) 519 | 520 | -------------------------------------------------------------------------------- /data_augs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn as nn 5 | from TransformLayer import ColorJitterLayer 6 | 7 | 8 | def random_crop(imgs, out=84): 9 | """ 10 | args: 11 | imgs: np.array shape (B,C,H,W) 12 | out: output size (e.g. 84) 13 | returns np.array 14 | """ 15 | n, c, h, w = imgs.shape 16 | crop_max = h - out + 1 17 | w1 = np.random.randint(0, crop_max, n) 18 | h1 = np.random.randint(0, crop_max, n) 19 | cropped = np.empty((n, c, out, out), dtype=imgs.dtype) 20 | for i, (img, w11, h11) in enumerate(zip(imgs, w1, h1)): 21 | 22 | cropped[i] = img[:, h11:h11 + out, w11:w11 + out] 23 | return cropped 24 | 25 | 26 | def grayscale(imgs): 27 | # imgs: b x c x h x w 28 | device = imgs.device 29 | b, c, h, w = imgs.shape 30 | frames = c // 3 31 | 32 | imgs = imgs.view([b,frames,3,h,w]) 33 | imgs = imgs[:, :, 0, ...] * 0.2989 + imgs[:, :, 1, ...] * 0.587 + imgs[:, :, 2, ...] * 0.114 34 | 35 | imgs = imgs.type(torch.uint8).float() 36 | # assert len(imgs.shape) == 3, imgs.shape 37 | imgs = imgs[:, :, None, :, :] 38 | imgs = imgs * torch.ones([1, 1, 3, 1, 1], dtype=imgs.dtype).float().to(device) # broadcast tiling 39 | return imgs 40 | 41 | def random_grayscale(images,p=.3): 42 | """ 43 | args: 44 | imgs: torch.tensor shape (B,C,H,W) 45 | device: cpu or cuda 46 | returns torch.tensor 47 | """ 48 | device = images.device 49 | in_type = images.type() 50 | images = images * 255. 
51 | images = images.type(torch.uint8) 52 | # images: [B, C, H, W] 53 | bs, channels, h, w = images.shape 54 | images = images.to(device) 55 | gray_images = grayscale(images) 56 | rnd = np.random.uniform(0., 1., size=(images.shape[0],)) 57 | mask = rnd <= p 58 | mask = torch.from_numpy(mask) 59 | frames = images.shape[1] // 3 60 | images = images.view(*gray_images.shape) 61 | mask = mask[:, None] * torch.ones([1, frames]).type(mask.dtype) 62 | mask = mask.type(images.dtype).to(device) 63 | mask = mask[:, :, None, None, None] 64 | out = mask * gray_images + (1 - mask) * images 65 | out = out.view([bs, -1, h, w]).type(in_type) / 255. 66 | return out 67 | 68 | # random cutout 69 | # TODO: should mask this 70 | 71 | def random_cutout(imgs, min_cut=10,max_cut=30): 72 | """ 73 | args: 74 | imgs: np.array shape (B,C,H,W) 75 | min / max cut: int, min / max size of cutout 76 | returns np.array 77 | """ 78 | 79 | n, c, h, w = imgs.shape 80 | w1 = np.random.randint(min_cut, max_cut, n) 81 | h1 = np.random.randint(min_cut, max_cut, n) 82 | 83 | cutouts = np.empty((n, c, h, w), dtype=imgs.dtype) 84 | for i, (img, w11, h11) in enumerate(zip(imgs, w1, h1)): 85 | cut_img = img.copy() 86 | cut_img[:, h11:h11 + h11, w11:w11 + w11] = 0 87 | #print(img[:, h11:h11 + h11, w11:w11 + w11].shape) 88 | cutouts[i] = cut_img 89 | return cutouts 90 | 91 | def random_cutout_color(imgs, min_cut=10,max_cut=30): 92 | """ 93 | args: 94 | imgs: shape (B,C,H,W) 95 | out: output size (e.g. 84) 96 | """ 97 | 98 | n, c, h, w = imgs.shape 99 | w1 = np.random.randint(min_cut, max_cut, n) 100 | h1 = np.random.randint(min_cut, max_cut, n) 101 | 102 | cutouts = np.empty((n, c, h, w), dtype=imgs.dtype) 103 | rand_box = np.random.randint(0, 255, size=(n, c)) / 255. 104 | for i, (img, w11, h11) in enumerate(zip(imgs, w1, h1)): 105 | cut_img = img.copy() 106 | 107 | # add random box 108 | cut_img[:, h11:h11 + h11, w11:w11 + w11] = np.tile( 109 | rand_box[i].reshape(-1,1,1), 110 | (1,) + cut_img[:, h11:h11 + h11, w11:w11 + w11].shape[1:]) 111 | 112 | cutouts[i] = cut_img 113 | return cutouts 114 | 115 | # random flip 116 | 117 | def random_flip(images,p=.2): 118 | """ 119 | args: 120 | imgs: torch.tensor shape (B,C,H,W) 121 | device: cpu or gpu, 122 | p: prob of applying aug, 123 | returns torch.tensor 124 | """ 125 | # images: [B, C, H, W] 126 | device = images.device 127 | bs, channels, h, w = images.shape 128 | 129 | images = images.to(device) 130 | 131 | flipped_images = images.flip([3]) 132 | 133 | rnd = np.random.uniform(0., 1., size=(images.shape[0],)) 134 | mask = rnd <= p 135 | mask = torch.from_numpy(mask) 136 | frames = images.shape[1] #// 3 137 | images = images.view(*flipped_images.shape) 138 | mask = mask[:, None] * torch.ones([1, frames]).type(mask.dtype) 139 | 140 | mask = mask.type(images.dtype).to(device) 141 | mask = mask[:, :, None, None] 142 | 143 | out = mask * flipped_images + (1 - mask) * images 144 | 145 | out = out.view([bs, -1, h, w]) 146 | return out 147 | 148 | # random rotation 149 | 150 | def random_rotation(images,p=.3): 151 | """ 152 | args: 153 | imgs: torch.tensor shape (B,C,H,W) 154 | device: str, cpu or gpu, 155 | p: float, prob of applying aug, 156 | returns torch.tensor 157 | """ 158 | device = images.device 159 | # images: [B, C, H, W] 160 | bs, channels, h, w = images.shape 161 | 162 | images = images.to(device) 163 | 164 | rot90_images = images.rot90(1,[2,3]) 165 | rot180_images = images.rot90(2,[2,3]) 166 | rot270_images = images.rot90(3,[2,3]) 167 | 168 | rnd = np.random.uniform(0., 1., 
size=(images.shape[0],)) 169 | rnd_rot = np.random.randint(1, 4, size=(images.shape[0],)) 170 | mask = rnd <= p 171 | mask = rnd_rot * mask 172 | mask = torch.from_numpy(mask).to(device) 173 | 174 | frames = images.shape[1] 175 | masks = [torch.zeros_like(mask) for _ in range(4)] 176 | for i,m in enumerate(masks): 177 | m[torch.where(mask==i)] = 1 178 | m = m[:, None] * torch.ones([1, frames]).type(mask.dtype).type(images.dtype).to(device) 179 | m = m[:,:,None,None] 180 | masks[i] = m 181 | 182 | 183 | out = masks[0] * images + masks[1] * rot90_images + masks[2] * rot180_images + masks[3] * rot270_images 184 | 185 | out = out.view([bs, -1, h, w]) 186 | return out 187 | 188 | 189 | # random color 190 | 191 | 192 | 193 | def random_convolution(imgs): 194 | ''' 195 | random covolution in "network randomization" 196 | 197 | (imbs): B x (C x stack) x H x W, note: imgs should be normalized and torch tensor 198 | ''' 199 | _device = imgs.device 200 | 201 | img_h, img_w = imgs.shape[2], imgs.shape[3] 202 | num_stack_channel = imgs.shape[1] 203 | num_batch = imgs.shape[0] 204 | num_trans = num_batch 205 | batch_size = int(num_batch / num_trans) 206 | 207 | # initialize random covolution 208 | rand_conv = nn.Conv2d(3, 3, kernel_size=3, bias=False, padding=1).to(_device) 209 | 210 | for trans_index in range(num_trans): 211 | torch.nn.init.xavier_normal_(rand_conv.weight.data) 212 | temp_imgs = imgs[trans_index*batch_size:(trans_index+1)*batch_size] 213 | temp_imgs = temp_imgs.reshape(-1, 3, img_h, img_w) # (batch x stack, channel, h, w) 214 | rand_out = rand_conv(temp_imgs) 215 | if trans_index == 0: 216 | total_out = rand_out 217 | else: 218 | total_out = torch.cat((total_out, rand_out), 0) 219 | total_out = total_out.reshape(-1, num_stack_channel, img_h, img_w) 220 | return total_out 221 | 222 | 223 | def random_color_jitter(imgs): 224 | """ 225 | inputs np array outputs tensor 226 | """ 227 | b,c,h,w = imgs.shape 228 | imgs = imgs.view(-1,3,h,w) 229 | transform_module = nn.Sequential(ColorJitterLayer(brightness=0.4, 230 | contrast=0.4, 231 | saturation=0.4, 232 | hue=0.5, 233 | p=1.0, 234 | batch_size=128)) 235 | 236 | imgs = transform_module(imgs).view(b,c,h,w) 237 | return imgs 238 | 239 | 240 | def random_translate(imgs, size, return_random_idxs=False, h1s=None, w1s=None): 241 | n, c, h, w = imgs.shape 242 | assert size >= h and size >= w 243 | outs = np.zeros((n, c, size, size), dtype=imgs.dtype) 244 | h1s = np.random.randint(0, size - h + 1, n) if h1s is None else h1s 245 | w1s = np.random.randint(0, size - w + 1, n) if w1s is None else w1s 246 | for out, img, h1, w1 in zip(outs, imgs, h1s, w1s): 247 | out[:, h1:h1 + h, w1:w1 + w] = img 248 | if return_random_idxs: # So can do the same to another set of imgs. 249 | return outs, dict(h1s=h1s, w1s=w1s) 250 | return outs 251 | 252 | 253 | def no_aug(x): 254 | return x 255 | 256 | 257 | if __name__ == '__main__': 258 | import time 259 | from tabulate import tabulate 260 | def now(): 261 | return time.time() 262 | def secs(t): 263 | s = now() - t 264 | tot = round((1e5 * s)/60,1) 265 | return round(s,3),tot 266 | 267 | x = np.load('data_sample.npy',allow_pickle=True) 268 | x = np.concatenate([x,x,x],1) 269 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 270 | 271 | x = torch.from_numpy(x).to(device) 272 | x = x.float() / 255. 
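# NOTE: secs(t) returns (seconds for one batch, projected minutes for 100k update steps).
# random_crop, random_cutout and random_cutout_color operate on numpy arrays, hence the
# x.cpu().numpy() calls below; the remaining augmentations take the normalized torch tensor x.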
273 | 274 | # crop 275 | t = now() 276 | random_crop(x.cpu().numpy(),64) 277 | s1,tot1 = secs(t) 278 | # grayscale 279 | t = now() 280 | random_grayscale(x,p=.5) 281 | s2,tot2 = secs(t) 282 | # normal cutout 283 | t = now() 284 | random_cutout(x.cpu().numpy(),10,30) 285 | s3,tot3 = secs(t) 286 | # color cutout 287 | t = now() 288 | random_cutout_color(x.cpu().numpy(),10,30) 289 | s4,tot4 = secs(t) 290 | # flip 291 | t = now() 292 | random_flip(x,p=.5) 293 | s5,tot5 = secs(t) 294 | # rotate 295 | t = now() 296 | random_rotation(x,p=.5) 297 | s6,tot6 = secs(t) 298 | # rand conv 299 | t = now() 300 | random_convolution(x) 301 | s7,tot7 = secs(t) 302 | # rand color jitter 303 | t = now() 304 | random_color_jitter(x) 305 | s8,tot8 = secs(t) 306 | 307 | print(tabulate([['Crop', s1,tot1], 308 | ['Grayscale', s2,tot2], 309 | ['Normal Cutout', s3,tot3], 310 | ['Color Cutout', s4,tot4], 311 | ['Flip', s5,tot5], 312 | ['Rotate', s6,tot6], 313 | ['Rand Conv', s7,tot7], 314 | ['Color Jitter', s8,tot8]], 315 | headers=['Data Aug', 'Time / batch (secs)', 'Time / 100k steps (mins)'])) 316 | 317 | -------------------------------------------------------------------------------- /data_sample.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MishaLaskin/rad/18d079e677398c70ff2eefefcc81d5a99662103d/data_sample.npy -------------------------------------------------------------------------------- /encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def tie_weights(src, trg): 6 | assert type(src) == type(trg) 7 | trg.weight = src.weight 8 | trg.bias = src.bias 9 | 10 | 11 | OUT_DIM = {2: 39, 4: 35, 6: 31} 12 | OUT_DIM_64 = {2: 29, 4: 25, 6: 21} 13 | OUT_DIM_108 = {4: 47} 14 | 15 | 16 | class PixelEncoder(nn.Module): 17 | """Convolutional encoder of pixels observations.""" 18 | def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32,output_logits=False): 19 | super().__init__() 20 | 21 | assert len(obs_shape) == 3 22 | self.obs_shape = obs_shape 23 | self.feature_dim = feature_dim 24 | self.num_layers = num_layers 25 | # try 2 5x5s with strides 2x2. with samep adding, it should reduce 84 to 21, so with valid, it should be even smaller than 21. 26 | self.convs = nn.ModuleList( 27 | [nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)] 28 | ) 29 | for i in range(num_layers - 1): 30 | self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1)) 31 | 32 | if obs_shape[-1] == 108: 33 | assert num_layers in OUT_DIM_108 34 | out_dim = OUT_DIM_108[num_layers] 35 | elif obs_shape[-1] == 64: 36 | out_dim = OUT_DIM_64[num_layers] 37 | else: 38 | out_dim = OUT_DIM[num_layers] 39 | 40 | self.fc = nn.Linear(num_filters * out_dim * out_dim, self.feature_dim) 41 | self.ln = nn.LayerNorm(self.feature_dim) 42 | 43 | self.outputs = dict() 44 | self.output_logits = output_logits 45 | 46 | def reparameterize(self, mu, logstd): 47 | std = torch.exp(logstd) 48 | eps = torch.randn_like(std) 49 | return mu + eps * std 50 | 51 | def forward_conv(self, obs): 52 | if obs.max() > 1.: 53 | obs = obs / 255. 
54 | 55 | self.outputs['obs'] = obs 56 | 57 | conv = torch.relu(self.convs[0](obs)) 58 | self.outputs['conv1'] = conv 59 | 60 | for i in range(1, self.num_layers): 61 | conv = torch.relu(self.convs[i](conv)) 62 | self.outputs['conv%s' % (i + 1)] = conv 63 | 64 | h = conv.view(conv.size(0), -1) 65 | return h 66 | 67 | def forward(self, obs, detach=False): 68 | h = self.forward_conv(obs) 69 | 70 | if detach: 71 | h = h.detach() 72 | 73 | h_fc = self.fc(h) 74 | self.outputs['fc'] = h_fc 75 | 76 | h_norm = self.ln(h_fc) 77 | self.outputs['ln'] = h_norm 78 | 79 | if self.output_logits: 80 | out = h_norm 81 | else: 82 | out = torch.tanh(h_norm) 83 | self.outputs['tanh'] = out 84 | 85 | return out 86 | 87 | def copy_conv_weights_from(self, source): 88 | """Tie convolutional layers""" 89 | # only tie conv layers 90 | for i in range(self.num_layers): 91 | tie_weights(src=source.convs[i], trg=self.convs[i]) 92 | 93 | def log(self, L, step, log_freq): 94 | if step % log_freq != 0: 95 | return 96 | 97 | for k, v in self.outputs.items(): 98 | L.log_histogram('train_encoder/%s_hist' % k, v, step) 99 | if len(v.shape) > 2: 100 | L.log_image('train_encoder/%s_img' % k, v[0], step) 101 | 102 | for i in range(self.num_layers): 103 | L.log_param('train_encoder/conv%s' % (i + 1), self.convs[i], step) 104 | L.log_param('train_encoder/fc', self.fc, step) 105 | L.log_param('train_encoder/ln', self.ln, step) 106 | 107 | 108 | class IdentityEncoder(nn.Module): 109 | def __init__(self, obs_shape, feature_dim, num_layers, num_filters,*args): 110 | super().__init__() 111 | 112 | assert len(obs_shape) == 1 113 | self.feature_dim = obs_shape[0] 114 | 115 | def forward(self, obs, detach=False): 116 | return obs 117 | 118 | def copy_conv_weights_from(self, source): 119 | pass 120 | 121 | def log(self, L, step, log_freq): 122 | pass 123 | 124 | 125 | _AVAILABLE_ENCODERS = {'pixel': PixelEncoder, 'identity': IdentityEncoder} 126 | 127 | 128 | def make_encoder( 129 | encoder_type, obs_shape, feature_dim, num_layers, num_filters, output_logits=False 130 | ): 131 | assert encoder_type in _AVAILABLE_ENCODERS 132 | return _AVAILABLE_ENCODERS[encoder_type]( 133 | obs_shape, feature_dim, num_layers, num_filters, output_logits 134 | ) 135 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from collections import defaultdict 3 | import json 4 | import os 5 | import shutil 6 | import torch 7 | import torchvision 8 | import numpy as np 9 | from termcolor import colored 10 | 11 | FORMAT_CONFIG = { 12 | 'rl': { 13 | 'train': [ 14 | ('episode', 'E', 'int'), ('step', 'S', 'int'), 15 | ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'), 16 | ('batch_reward', 'BR', 'float'), ('actor_loss', 'A_LOSS', 'float'), 17 | ('critic_loss', 'CR_LOSS', 'float') 18 | ], 19 | 'eval': [('step', 'S', 'int'), ('episode_reward', 'ER', 'float')] 20 | } 21 | } 22 | 23 | 24 | class AverageMeter(object): 25 | def __init__(self): 26 | self._sum = 0 27 | self._count = 0 28 | 29 | def update(self, value, n=1): 30 | self._sum += value 31 | self._count += n 32 | 33 | def value(self): 34 | return self._sum / max(1, self._count) 35 | 36 | 37 | class MetersGroup(object): 38 | def __init__(self, file_name, formating): 39 | self._file_name = file_name 40 | if os.path.exists(file_name): 41 | os.remove(file_name) 42 | self._formating = formating 43 | self._meters = 
defaultdict(AverageMeter) 44 | 45 | def log(self, key, value, n=1): 46 | self._meters[key].update(value, n) 47 | 48 | def _prime_meters(self): 49 | data = dict() 50 | for key, meter in self._meters.items(): 51 | if key.startswith('train'): 52 | key = key[len('train') + 1:] 53 | else: 54 | key = key[len('eval') + 1:] 55 | key = key.replace('/', '_') 56 | data[key] = meter.value() 57 | return data 58 | 59 | def _dump_to_file(self, data): 60 | with open(self._file_name, 'a') as f: 61 | f.write(json.dumps(data) + '\n') 62 | 63 | def _format(self, key, value, ty): 64 | template = '%s: ' 65 | if ty == 'int': 66 | template += '%d' 67 | elif ty == 'float': 68 | template += '%.04f' 69 | elif ty == 'time': 70 | template += '%.01f s' 71 | else: 72 | raise 'invalid format type: %s' % ty 73 | return template % (key, value) 74 | 75 | def _dump_to_console(self, data, prefix): 76 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 77 | pieces = ['{:5}'.format(prefix)] 78 | for key, disp_key, ty in self._formating: 79 | value = data.get(key, 0) 80 | pieces.append(self._format(disp_key, value, ty)) 81 | print('| %s' % (' | '.join(pieces))) 82 | 83 | def dump(self, step, prefix): 84 | if len(self._meters) == 0: 85 | return 86 | data = self._prime_meters() 87 | data['step'] = step 88 | self._dump_to_file(data) 89 | self._dump_to_console(data, prefix) 90 | self._meters.clear() 91 | 92 | 93 | class Logger(object): 94 | def __init__(self, log_dir, use_tb=True, config='rl'): 95 | self._log_dir = log_dir 96 | if use_tb: 97 | tb_dir = os.path.join(log_dir, 'tb') 98 | if os.path.exists(tb_dir): 99 | shutil.rmtree(tb_dir) 100 | self._sw = SummaryWriter(tb_dir) 101 | else: 102 | self._sw = None 103 | self._train_mg = MetersGroup( 104 | os.path.join(log_dir, 'train.log'), 105 | formating=FORMAT_CONFIG[config]['train'] 106 | ) 107 | self._eval_mg = MetersGroup( 108 | os.path.join(log_dir, 'eval.log'), 109 | formating=FORMAT_CONFIG[config]['eval'] 110 | ) 111 | 112 | def _try_sw_log(self, key, value, step): 113 | if self._sw is not None: 114 | self._sw.add_scalar(key, value, step) 115 | 116 | def _try_sw_log_image(self, key, image, step): 117 | if self._sw is not None: 118 | assert image.dim() == 3 119 | grid = torchvision.utils.make_grid(image.unsqueeze(1)) 120 | self._sw.add_image(key, grid, step) 121 | 122 | def _try_sw_log_video(self, key, frames, step): 123 | if self._sw is not None: 124 | frames = torch.from_numpy(np.array(frames)) 125 | frames = frames.unsqueeze(0) 126 | self._sw.add_video(key, frames, step, fps=30) 127 | 128 | def _try_sw_log_histogram(self, key, histogram, step): 129 | if self._sw is not None: 130 | self._sw.add_histogram(key, histogram, step) 131 | 132 | def log(self, key, value, step, n=1): 133 | assert key.startswith('train') or key.startswith('eval') 134 | if type(value) == torch.Tensor: 135 | value = value.item() 136 | self._try_sw_log(key, value / n, step) 137 | mg = self._train_mg if key.startswith('train') else self._eval_mg 138 | mg.log(key, value, n) 139 | 140 | def log_param(self, key, param, step): 141 | self.log_histogram(key + '_w', param.weight.data, step) 142 | if hasattr(param.weight, 'grad') and param.weight.grad is not None: 143 | self.log_histogram(key + '_w_g', param.weight.grad.data, step) 144 | if hasattr(param, 'bias'): 145 | self.log_histogram(key + '_b', param.bias.data, step) 146 | if hasattr(param.bias, 'grad') and param.bias.grad is not None: 147 | self.log_histogram(key + '_b_g', param.bias.grad.data, step) 148 | 149 | def log_image(self, key, 
image, step): 150 | assert key.startswith('train') or key.startswith('eval') 151 | self._try_sw_log_image(key, image, step) 152 | 153 | def log_video(self, key, frames, step): 154 | assert key.startswith('train') or key.startswith('eval') 155 | self._try_sw_log_video(key, frames, step) 156 | 157 | def log_histogram(self, key, histogram, step): 158 | assert key.startswith('train') or key.startswith('eval') 159 | self._try_sw_log_histogram(key, histogram, step) 160 | 161 | def dump(self, step): 162 | self._train_mg.dump(step, 'train') 163 | self._eval_mg.dump(step, 'eval') 164 | -------------------------------------------------------------------------------- /rad_thumb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MishaLaskin/rad/18d079e677398c70ff2eefefcc81d5a99662103d/rad_thumb.png -------------------------------------------------------------------------------- /scripts/cheetah_test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --domain_name cheetah \ 3 | --task_name run \ 4 | --encoder_type pixel --work_dir ./tmp/translation \ 5 | --action_repeat 4 --num_eval_episodes 10 \ 6 | --pre_transform_image_size 100 --image_size 108 \ 7 | --agent rad_sac --frame_stack 3 --data_augs translate \ 8 | --seed 1208 --critic_lr 2e-4 --actor_lr 2e-4 --eval_freq 10000 \ 9 | --batch_size 128 --num_train_steps 600000 --init_steps 10000 \ 10 | --num_filters 32 --encoder_feature_dim 64 --replay_buffer_capacity 100000 \ 11 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --domain_name cartpole \ 3 | --task_name swingup \ 4 | --encoder_type pixel --work_dir ./tmp \ 5 | --action_repeat 8 --num_eval_episodes 10 \ 6 | --pre_transform_image_size 100 --image_size 84 \ 7 | --agent rad_sac --frame_stack 3 --data_augs crop \ 8 | --seed 23 --critic_lr 1e-3 --actor_lr 1e-3 --eval_freq 10000 --batch_size 128 --num_train_steps 200000 9 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import argparse 4 | import os 5 | import math 6 | import gym 7 | import sys 8 | import random 9 | import time 10 | import json 11 | import dmc2gym 12 | import copy 13 | 14 | import utils 15 | from logger import Logger 16 | from video import VideoRecorder 17 | 18 | from curl_sac import RadSacAgent 19 | from torchvision import transforms 20 | import data_augs as rad 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser() 24 | # environment 25 | parser.add_argument('--domain_name', default='cartpole') 26 | parser.add_argument('--task_name', default='swingup') 27 | parser.add_argument('--pre_transform_image_size', default=100, type=int) 28 | 29 | parser.add_argument('--image_size', default=84, type=int) 30 | parser.add_argument('--action_repeat', default=1, type=int) 31 | parser.add_argument('--frame_stack', default=3, type=int) 32 | # replay buffer 33 | parser.add_argument('--replay_buffer_capacity', default=100000, type=int) 34 | # train 35 | parser.add_argument('--agent', default='rad_sac', type=str) 36 | parser.add_argument('--init_steps', default=1000, type=int) 37 | parser.add_argument('--num_train_steps', default=1000000, 
type=int) 38 | parser.add_argument('--batch_size', default=32, type=int) 39 | parser.add_argument('--hidden_dim', default=1024, type=int) 40 | # eval 41 | parser.add_argument('--eval_freq', default=1000, type=int) 42 | parser.add_argument('--num_eval_episodes', default=10, type=int) 43 | # critic 44 | parser.add_argument('--critic_lr', default=1e-3, type=float) 45 | parser.add_argument('--critic_beta', default=0.9, type=float) 46 | parser.add_argument('--critic_tau', default=0.01, type=float) # try 0.05 or 0.1 47 | parser.add_argument('--critic_target_update_freq', default=2, type=int) # try to change it to 1 and retain 0.01 above 48 | # actor 49 | parser.add_argument('--actor_lr', default=1e-3, type=float) 50 | parser.add_argument('--actor_beta', default=0.9, type=float) 51 | parser.add_argument('--actor_log_std_min', default=-10, type=float) 52 | parser.add_argument('--actor_log_std_max', default=2, type=float) 53 | parser.add_argument('--actor_update_freq', default=2, type=int) 54 | # encoder 55 | parser.add_argument('--encoder_type', default='pixel', type=str) 56 | parser.add_argument('--encoder_feature_dim', default=50, type=int) 57 | parser.add_argument('--encoder_lr', default=1e-3, type=float) 58 | parser.add_argument('--encoder_tau', default=0.05, type=float) 59 | parser.add_argument('--num_layers', default=4, type=int) 60 | parser.add_argument('--num_filters', default=32, type=int) 61 | parser.add_argument('--latent_dim', default=128, type=int) 62 | # sac 63 | parser.add_argument('--discount', default=0.99, type=float) 64 | parser.add_argument('--init_temperature', default=0.1, type=float) 65 | parser.add_argument('--alpha_lr', default=1e-4, type=float) 66 | parser.add_argument('--alpha_beta', default=0.5, type=float) 67 | # misc 68 | parser.add_argument('--seed', default=1, type=int) 69 | parser.add_argument('--work_dir', default='.', type=str) 70 | parser.add_argument('--save_tb', default=False, action='store_true') 71 | parser.add_argument('--save_buffer', default=False, action='store_true') 72 | parser.add_argument('--save_video', default=False, action='store_true') 73 | parser.add_argument('--save_model', default=False, action='store_true') 74 | parser.add_argument('--detach_encoder', default=False, action='store_true') 75 | # data augs 76 | parser.add_argument('--data_augs', default='crop', type=str) 77 | 78 | 79 | parser.add_argument('--log_interval', default=100, type=int) 80 | args = parser.parse_args() 81 | return args 82 | 83 | 84 | def evaluate(env, agent, video, num_episodes, L, step, args): 85 | all_ep_rewards = [] 86 | 87 | def run_eval_loop(sample_stochastically=True): 88 | start_time = time.time() 89 | prefix = 'stochastic_' if sample_stochastically else '' 90 | for i in range(num_episodes): 91 | obs = env.reset() 92 | video.init(enabled=(i == 0)) 93 | done = False 94 | episode_reward = 0 95 | while not done: 96 | # center crop image 97 | if args.encoder_type == 'pixel' and 'crop' in args.data_augs: 98 | obs = utils.center_crop_image(obs,args.image_size) 99 | if args.encoder_type == 'pixel' and 'translate' in args.data_augs: 100 | # first crop the center with pre_image_size 101 | obs = utils.center_crop_image(obs, args.pre_transform_image_size) 102 | # then translate cropped to center 103 | obs = utils.center_translate(obs, args.image_size) 104 | with utils.eval_mode(agent): 105 | if sample_stochastically: 106 | action = agent.sample_action(obs / 255.) 107 | else: 108 | action = agent.select_action(obs / 255.) 
109 | obs, reward, done, _ = env.step(action) 110 | video.record(env) 111 | episode_reward += reward 112 | 113 | video.save('%d.mp4' % step) 114 | L.log('eval/' + prefix + 'episode_reward', episode_reward, step) 115 | all_ep_rewards.append(episode_reward) 116 | 117 | L.log('eval/' + prefix + 'eval_time', time.time()-start_time , step) 118 | mean_ep_reward = np.mean(all_ep_rewards) 119 | best_ep_reward = np.max(all_ep_rewards) 120 | std_ep_reward = np.std(all_ep_rewards) 121 | L.log('eval/' + prefix + 'mean_episode_reward', mean_ep_reward, step) 122 | L.log('eval/' + prefix + 'best_episode_reward', best_ep_reward, step) 123 | 124 | filename = args.work_dir + '/' + args.domain_name + '--'+args.task_name + '-' + args.data_augs + '--s' + str(args.seed) + '--eval_scores.npy' 125 | key = args.domain_name + '-' + args.task_name + '-' + args.data_augs 126 | try: 127 | log_data = np.load(filename,allow_pickle=True) 128 | log_data = log_data.item() 129 | except: 130 | log_data = {} 131 | 132 | if key not in log_data: 133 | log_data[key] = {} 134 | 135 | log_data[key][step] = {} 136 | log_data[key][step]['step'] = step 137 | log_data[key][step]['mean_ep_reward'] = mean_ep_reward 138 | log_data[key][step]['max_ep_reward'] = best_ep_reward 139 | log_data[key][step]['std_ep_reward'] = std_ep_reward 140 | log_data[key][step]['env_step'] = step * args.action_repeat 141 | 142 | np.save(filename,log_data) 143 | 144 | run_eval_loop(sample_stochastically=False) 145 | L.dump(step) 146 | 147 | 148 | def make_agent(obs_shape, action_shape, args, device): 149 | if args.agent == 'rad_sac': 150 | return RadSacAgent( 151 | obs_shape=obs_shape, 152 | action_shape=action_shape, 153 | device=device, 154 | hidden_dim=args.hidden_dim, 155 | discount=args.discount, 156 | init_temperature=args.init_temperature, 157 | alpha_lr=args.alpha_lr, 158 | alpha_beta=args.alpha_beta, 159 | actor_lr=args.actor_lr, 160 | actor_beta=args.actor_beta, 161 | actor_log_std_min=args.actor_log_std_min, 162 | actor_log_std_max=args.actor_log_std_max, 163 | actor_update_freq=args.actor_update_freq, 164 | critic_lr=args.critic_lr, 165 | critic_beta=args.critic_beta, 166 | critic_tau=args.critic_tau, 167 | critic_target_update_freq=args.critic_target_update_freq, 168 | encoder_type=args.encoder_type, 169 | encoder_feature_dim=args.encoder_feature_dim, 170 | encoder_lr=args.encoder_lr, 171 | encoder_tau=args.encoder_tau, 172 | num_layers=args.num_layers, 173 | num_filters=args.num_filters, 174 | log_interval=args.log_interval, 175 | detach_encoder=args.detach_encoder, 176 | latent_dim=args.latent_dim, 177 | data_augs=args.data_augs 178 | 179 | ) 180 | else: 181 | assert 'agent is not supported: %s' % args.agent 182 | 183 | def main(): 184 | args = parse_args() 185 | if args.seed == -1: 186 | args.__dict__["seed"] = np.random.randint(1,1000000) 187 | utils.set_seed_everywhere(args.seed) 188 | 189 | pre_transform_image_size = args.pre_transform_image_size if 'crop' in args.data_augs else args.image_size 190 | pre_image_size = args.pre_transform_image_size # record the pre transform image size for translation 191 | 192 | env = dmc2gym.make( 193 | domain_name=args.domain_name, 194 | task_name=args.task_name, 195 | seed=args.seed, 196 | visualize_reward=False, 197 | from_pixels=(args.encoder_type == 'pixel'), 198 | height=pre_transform_image_size, 199 | width=pre_transform_image_size, 200 | frame_skip=args.action_repeat 201 | ) 202 | 203 | env.seed(args.seed) 204 | 205 | # stack several consecutive frames together 206 | if args.encoder_type == 
207 |         env = utils.FrameStack(env, k=args.frame_stack)
208 | 
209 |     # make directory
210 |     ts = time.gmtime()
211 |     ts = time.strftime("%m-%d", ts)
212 |     env_name = args.domain_name + '-' + args.task_name
213 |     exp_name = env_name + '-' + ts + '-im' + str(args.image_size) + '-b' \
214 |     + str(args.batch_size) + '-s' + str(args.seed) + '-' + args.encoder_type
215 |     args.work_dir = args.work_dir + '/' + exp_name
216 | 
217 |     utils.make_dir(args.work_dir)
218 |     video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
219 |     model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
220 |     buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))
221 | 
222 |     video = VideoRecorder(video_dir if args.save_video else None)
223 | 
224 |     with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
225 |         json.dump(vars(args), f, sort_keys=True, indent=4)
226 | 
227 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
228 | 
229 |     action_shape = env.action_space.shape
230 | 
231 |     if args.encoder_type == 'pixel':
232 |         obs_shape = (3*args.frame_stack, args.image_size, args.image_size)
233 |         pre_aug_obs_shape = (3*args.frame_stack, pre_transform_image_size, pre_transform_image_size)
234 |     else:
235 |         obs_shape = env.observation_space.shape
236 |         pre_aug_obs_shape = obs_shape
237 | 
238 |     replay_buffer = utils.ReplayBuffer(
239 |         obs_shape=pre_aug_obs_shape,
240 |         action_shape=action_shape,
241 |         capacity=args.replay_buffer_capacity,
242 |         batch_size=args.batch_size,
243 |         device=device,
244 |         image_size=args.image_size,
245 |         pre_image_size=pre_image_size,
246 |     )
247 | 
248 |     agent = make_agent(
249 |         obs_shape=obs_shape,
250 |         action_shape=action_shape,
251 |         args=args,
252 |         device=device
253 |     )
254 | 
255 | 
256 |     L = Logger(args.work_dir, use_tb=args.save_tb)
257 | 
258 |     episode, episode_reward, done = 0, 0, True
259 |     start_time = time.time()
260 | 
261 |     for step in range(args.num_train_steps):
262 |         # evaluate agent periodically
263 | 
264 |         if step % args.eval_freq == 0:
265 |             L.log('eval/episode', episode, step)
266 |             evaluate(env, agent, video, args.num_eval_episodes, L, step, args)
267 |             if args.save_model:
268 |                 agent.save_curl(model_dir, step)
269 |             if args.save_buffer:
270 |                 replay_buffer.save(buffer_dir)
271 | 
272 |         if done:
273 |             if step > 0:
274 |                 if step % args.log_interval == 0:
275 |                     L.log('train/duration', time.time() - start_time, step)
276 |                     L.dump(step)
277 |                 start_time = time.time()
278 |             if step % args.log_interval == 0:
279 |                 L.log('train/episode_reward', episode_reward, step)
280 | 
281 |             obs = env.reset()
282 |             done = False
283 |             episode_reward = 0
284 |             episode_step = 0
285 |             episode += 1
286 |             if step % args.log_interval == 0:
287 |                 L.log('train/episode', episode, step)
288 | 
289 |         # sample action for data collection
290 |         if step < args.init_steps:
291 |             action = env.action_space.sample()
292 |         else:
293 |             with utils.eval_mode(agent):
294 |                 action = agent.sample_action(obs / 255.)
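        # For the first `--init_steps` environment steps the action above is drawn
        # uniformly from the action space to seed the replay buffer; after that the
        # stochastic policy is queried on the uint8 observation scaled to [0, 1].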
295 | 
296 |         # run training update
297 |         if step >= args.init_steps:
298 |             num_updates = 1
299 |             for _ in range(num_updates):
300 |                 agent.update(replay_buffer, L, step)
301 | 
302 |         next_obs, reward, done, _ = env.step(action)
303 | 
304 |         # allow infinite bootstrap
305 |         done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
306 |             done
307 |         )
308 |         episode_reward += reward
309 |         replay_buffer.add(obs, action, reward, next_obs, done_bool)
310 | 
311 |         obs = next_obs
312 |         episode_step += 1
313 | 
314 | 
315 | if __name__ == '__main__':
316 |     torch.multiprocessing.set_start_method('spawn')
317 | 
318 |     main()
319 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import gym
5 | import os
6 | from collections import deque
7 | import random
8 | from torch.utils.data import Dataset, DataLoader
9 | import time
10 | from skimage.util.shape import view_as_windows
11 | 
12 | class eval_mode(object):
13 |     def __init__(self, *models):
14 |         self.models = models
15 | 
16 |     def __enter__(self):
17 |         self.prev_states = []
18 |         for model in self.models:
19 |             self.prev_states.append(model.training)
20 |             model.train(False)
21 | 
22 |     def __exit__(self, *args):
23 |         for model, state in zip(self.models, self.prev_states):
24 |             model.train(state)
25 |         return False
26 | 
27 | 
28 | def soft_update_params(net, target_net, tau):
29 |     for param, target_param in zip(net.parameters(), target_net.parameters()):
30 |         target_param.data.copy_(
31 |             tau * param.data + (1 - tau) * target_param.data
32 |         )
33 | 
34 | 
35 | def set_seed_everywhere(seed):
36 |     torch.manual_seed(seed)
37 |     if torch.cuda.is_available():
38 |         torch.cuda.manual_seed_all(seed)
39 |     np.random.seed(seed)
40 |     random.seed(seed)
41 | 
42 | 
43 | def module_hash(module):
44 |     result = 0
45 |     for tensor in module.state_dict().values():
46 |         result += tensor.sum().item()
47 |     return result
48 | 
49 | 
50 | def make_dir(dir_path):
51 |     try:
52 |         os.mkdir(dir_path)
53 |     except OSError:
54 |         pass
55 |     return dir_path
56 | 
57 | 
58 | def preprocess_obs(obs, bits=5):
59 |     """Preprocessing image, see https://arxiv.org/abs/1807.03039."""
60 |     bins = 2**bits
61 |     assert obs.dtype == torch.float32
62 |     if bits < 8:
63 |         obs = torch.floor(obs / 2**(8 - bits))
64 |     obs = obs / bins
65 |     obs = obs + torch.rand_like(obs) / bins
66 |     obs = obs - 0.5
67 |     return obs
68 | 
69 | 
70 | class ReplayBuffer(Dataset):
71 |     """Buffer to store environment transitions."""
72 |     def __init__(self, obs_shape, action_shape, capacity, batch_size, device, image_size=84,
73 |                  pre_image_size=84, transform=None):
74 |         self.capacity = capacity
75 |         self.batch_size = batch_size
76 |         self.device = device
77 |         self.image_size = image_size
78 |         self.pre_image_size = pre_image_size # for translation
79 |         self.transform = transform
80 |         # the proprioceptive obs is stored as float32, pixels obs as uint8
81 |         obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8
82 | 
83 |         self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
84 |         self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
85 |         self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
86 |         self.rewards = np.empty((capacity, 1), dtype=np.float32)
87 |         self.not_dones = np.empty((capacity, 1), dtype=np.float32)
88 | 
89 |         self.idx = 0
90 |         self.last_save = 0
91 |         self.full = False
92 | 
93 | 
94 | 
95 | 
96 |     def add(self, obs, action, reward, next_obs, done):
97 | 
98 |         np.copyto(self.obses[self.idx], obs)
99 |         np.copyto(self.actions[self.idx], action)
100 |         np.copyto(self.rewards[self.idx], reward)
101 |         np.copyto(self.next_obses[self.idx], next_obs)
102 |         np.copyto(self.not_dones[self.idx], not done)
103 | 
104 |         self.idx = (self.idx + 1) % self.capacity
105 |         self.full = self.full or self.idx == 0
106 | 
107 |     def sample_proprio(self):
108 | 
109 |         idxs = np.random.randint(
110 |             0, self.capacity if self.full else self.idx, size=self.batch_size
111 |         )
112 | 
113 |         obses = self.obses[idxs]
114 |         next_obses = self.next_obses[idxs]
115 | 
116 |         obses = torch.as_tensor(obses, device=self.device).float()
117 |         actions = torch.as_tensor(self.actions[idxs], device=self.device)
118 |         rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
119 |         next_obses = torch.as_tensor(
120 |             next_obses, device=self.device
121 |         ).float()
122 |         not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
123 |         return obses, actions, rewards, next_obses, not_dones
124 | 
125 |     def sample_cpc(self):
126 | 
127 |         start = time.time()
128 |         idxs = np.random.randint(
129 |             0, self.capacity if self.full else self.idx, size=self.batch_size
130 |         )
131 | 
132 |         obses = self.obses[idxs]
133 |         next_obses = self.next_obses[idxs]
134 |         pos = obses.copy()
135 | 
136 |         obses = fast_random_crop(obses, self.image_size)
137 |         next_obses = fast_random_crop(next_obses, self.image_size)
138 |         pos = fast_random_crop(pos, self.image_size)
139 | 
140 |         obses = torch.as_tensor(obses, device=self.device).float()
141 |         next_obses = torch.as_tensor(
142 |             next_obses, device=self.device
143 |         ).float()
144 |         actions = torch.as_tensor(self.actions[idxs], device=self.device)
145 |         rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
146 |         not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
147 | 
148 |         pos = torch.as_tensor(pos, device=self.device).float()
149 |         cpc_kwargs = dict(obs_anchor=obses, obs_pos=pos,
150 |                           time_anchor=None, time_pos=None)
151 | 
152 |         return obses, actions, rewards, next_obses, not_dones, cpc_kwargs
153 | 
154 |     def sample_rad(self, aug_funcs):
155 | 
156 |         # augs specified as flags
157 |         # curl_sac organizes flags into aug funcs
158 |         # passes aug funcs into sampler
159 | 
160 | 
161 |         idxs = np.random.randint(
162 |             0, self.capacity if self.full else self.idx, size=self.batch_size
163 |         )
164 | 
165 |         obses = self.obses[idxs]
166 |         next_obses = self.next_obses[idxs]
167 |         if aug_funcs:
168 |             for aug, func in aug_funcs.items():
169 |                 # apply crop and cutout first
170 |                 if 'crop' in aug or 'cutout' in aug:
171 |                     obses = func(obses)
172 |                     next_obses = func(next_obses)
173 |                 elif 'translate' in aug:
174 |                     og_obses = center_crop_images(obses, self.pre_image_size)
175 |                     og_next_obses = center_crop_images(next_obses, self.pre_image_size)
176 |                     obses, rndm_idxs = func(og_obses, self.image_size, return_random_idxs=True)
177 |                     next_obses = func(og_next_obses, self.image_size, **rndm_idxs)
178 | 
179 |         obses = torch.as_tensor(obses, device=self.device).float()
180 |         next_obses = torch.as_tensor(next_obses, device=self.device).float()
181 |         actions = torch.as_tensor(self.actions[idxs], device=self.device)
182 |         rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
183 |         not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
184 | 
185 |         obses = obses / 255.
186 |         next_obses = next_obses / 255.
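        # At this point the size-changing / pixel-level augmentations (crop, cutout,
        # translate) have already been applied to the raw uint8 numpy frames, and the
        # batch has been moved to `self.device` and normalized to [0, 1]. Any remaining
        # augmentations in `aug_funcs` (e.g. the flip / rotate style transforms chosen
        # via --data_augs) are applied below on the normalized tensors.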
187 | 
188 |         # augmentations go here
189 |         if aug_funcs:
190 |             for aug, func in aug_funcs.items():
191 |                 # skip crop and cutout augs
192 |                 if 'crop' in aug or 'cutout' in aug or 'translate' in aug:
193 |                     continue
194 |                 obses = func(obses)
195 |                 next_obses = func(next_obses)
196 | 
197 |         return obses, actions, rewards, next_obses, not_dones
198 | 
199 |     def save(self, save_dir):
200 |         if self.idx == self.last_save:
201 |             return
202 |         path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
203 |         payload = [
204 |             self.obses[self.last_save:self.idx],
205 |             self.next_obses[self.last_save:self.idx],
206 |             self.actions[self.last_save:self.idx],
207 |             self.rewards[self.last_save:self.idx],
208 |             self.not_dones[self.last_save:self.idx]
209 |         ]
210 |         self.last_save = self.idx
211 |         torch.save(payload, path)
212 | 
213 |     def load(self, save_dir):
214 |         chunks = os.listdir(save_dir)
215 |         chunks = sorted(chunks, key=lambda x: int(x.split('_')[0]))
216 |         for chunk in chunks:
217 |             start, end = [int(x) for x in chunk.split('.')[0].split('_')]
218 |             path = os.path.join(save_dir, chunk)
219 |             payload = torch.load(path)
220 |             assert self.idx == start
221 |             self.obses[start:end] = payload[0]
222 |             self.next_obses[start:end] = payload[1]
223 |             self.actions[start:end] = payload[2]
224 |             self.rewards[start:end] = payload[3]
225 |             self.not_dones[start:end] = payload[4]
226 |             self.idx = end
227 | 
228 |     def __getitem__(self, idx):
229 |         idx = np.random.randint(
230 |             0, self.capacity if self.full else self.idx, size=1
231 |         )
232 |         idx = idx[0]
233 |         obs = self.obses[idx]
234 |         action = self.actions[idx]
235 |         reward = self.rewards[idx]
236 |         next_obs = self.next_obses[idx]
237 |         not_done = self.not_dones[idx]
238 | 
239 |         if self.transform:
240 |             obs = self.transform(obs)
241 |             next_obs = self.transform(next_obs)
242 | 
243 |         return obs, action, reward, next_obs, not_done
244 | 
245 |     def __len__(self):
246 |         return self.capacity
247 | 
248 | class FrameStack(gym.Wrapper):
249 |     def __init__(self, env, k):
250 |         gym.Wrapper.__init__(self, env)
251 |         self._k = k
252 |         self._frames = deque([], maxlen=k)
253 |         shp = env.observation_space.shape
254 |         self.observation_space = gym.spaces.Box(
255 |             low=0,
256 |             high=1,
257 |             shape=((shp[0] * k,) + shp[1:]),
258 |             dtype=env.observation_space.dtype
259 |         )
260 |         self._max_episode_steps = env._max_episode_steps
261 | 
262 |     def reset(self):
263 |         obs = self.env.reset()
264 |         for _ in range(self._k):
265 |             self._frames.append(obs)
266 |         return self._get_obs()
267 | 
268 |     def step(self, action):
269 |         obs, reward, done, info = self.env.step(action)
270 |         self._frames.append(obs)
271 |         return self._get_obs(), reward, done, info
272 | 
273 |     def _get_obs(self):
274 |         assert len(self._frames) == self._k
275 |         return np.concatenate(list(self._frames), axis=0)
276 | 
277 | 
278 | def center_crop_image(image, output_size):
279 |     h, w = image.shape[1:]
280 |     new_h, new_w = output_size, output_size
281 | 
282 |     top = (h - new_h)//2
283 |     left = (w - new_w)//2
284 | 
285 |     image = image[:, top:top + new_h, left:left + new_w]
286 |     return image
287 | 
288 | 
289 | def center_crop_images(image, output_size):
290 |     h, w = image.shape[2:]
291 |     new_h, new_w = output_size, output_size
292 | 
293 |     top = (h - new_h)//2
294 |     left = (w - new_w)//2
295 | 
296 |     image = image[:, :, top:top + new_h, left:left + new_w]
297 |     return image
298 | 
299 | 
300 | def center_translate(image, size):
301 |     c, h, w = image.shape
302 |     assert size >= h and size >= w
303 |     outs = np.zeros((c, size, size), dtype=image.dtype)
304 |     h1 = (size - h) // 2
305 |     w1 = (size - w) // 2
306 |     outs[:, h1:h1 + h, w1:w1 + w] = image
307 |     return outs
308 | 
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import os
3 | import numpy as np
4 | 
5 | 
6 | class VideoRecorder(object):
7 |     def __init__(self, dir_name, height=256, width=256, camera_id=0, fps=30):
8 |         self.dir_name = dir_name
9 |         self.height = height
10 |         self.width = width
11 |         self.camera_id = camera_id
12 |         self.fps = fps
13 |         self.frames = []
14 | 
15 |     def init(self, enabled=True):
16 |         self.frames = []
17 |         self.enabled = self.dir_name is not None and enabled
18 | 
19 |     def record(self, env):
20 |         if self.enabled:
21 |             try:
22 |                 frame = env.render(
23 |                     mode='rgb_array',
24 |                     height=self.height,
25 |                     width=self.width,
26 |                     camera_id=self.camera_id
27 |                 )
28 |             except Exception:  # fall back for envs whose render() lacks size/camera kwargs
29 |                 frame = env.render(
30 |                     mode='rgb_array',
31 |                 )
32 | 
33 |             self.frames.append(frame)
34 | 
35 |     def save(self, file_name):
36 |         if self.enabled:
37 |             path = os.path.join(self.dir_name, file_name)
38 |             imageio.mimsave(path, self.frames, fps=self.fps)
39 | 
--------------------------------------------------------------------------------
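For orientation, the sketch below (not part of the repository) shows how the utilities above fit together: `dmc2gym.make` and `utils.FrameStack` produce stacked pixel observations, `utils.center_crop_image` matches the 84x84 training resolution, and `VideoRecorder` saves a rollout. A random action stands in for `agent.select_action(obs / 255.)`, and the environment arguments mirror those in `scripts/run.sh`.

```
import dmc2gym
import utils
from video import VideoRecorder

# Pixel-based cartpole swingup, rendered at 100x100 as in run.sh.
env = dmc2gym.make(
    domain_name='cartpole', task_name='swingup', seed=1,
    visualize_reward=False, from_pixels=True,
    height=100, width=100, frame_skip=8,
)
env = utils.FrameStack(env, k=3)  # observations become (3 * 3, 100, 100) uint8 arrays

video = VideoRecorder(utils.make_dir('./tmp'))
video.init(enabled=True)

obs, done, episode_reward = env.reset(), False, 0.0
while not done:
    obs_cropped = utils.center_crop_image(obs, 84)  # 84x84 input, as at eval time
    action = env.action_space.sample()  # stand-in for agent.select_action(obs_cropped / 255.)
    obs, reward, done, _ = env.step(action)
    video.record(env)
    episode_reward += reward

video.save('random_policy.mp4')
print('episode reward:', episode_reward)
```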