├── .gitignore
├── README.md
├── arial.ttf
├── atari_beta_vae.py
├── atari_beta_vae_actor.py
├── atari_ccil.py
├── atari_cnn_actor.py
├── atari_cnn_actor_crlr.py
├── atari_vqvae.py
├── atari_vqvae_oreo.py
├── coordconv.py
├── download.sh
├── linear_models.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | models
2 | raw_datasets
3 | datasets
4 | *.pyc
5 | .vscode
6 | .ipynb_checkpoints
7 | ipynbs
8 | models*
9 | *.ipynb
10 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OREO: Object-Aware Regularization for Addressing Causal Confusion in Imitation Learning (NeurIPS 2021)
2 | 
3 | ## Video demo
4 | Here we provide a video demo from the confounded Enduro environment (see Figure 8 of the main draft).
5 | We also visualize the spatial attention maps of convolutional encoders trained with BC (middle) and OREO (right).
6 | 
7 | ![Enduro_total_demo_cropped](https://user-images.githubusercontent.com/33256298/120595374-38554300-c47d-11eb-97e5-afff3c5d83b9.gif)
8 | 
9 | ## Installation
10 | 
11 | OREO requires CUDA 10.1 to run.
12 | 
13 | Install the dependencies:
14 | ```sh
15 | conda install pytorch torchvision torchaudio cudatoolkit=10.1 -c pytorch
16 | pip install dopamine_rl sklearn tqdm kornia dropblock atari-py==0.2.6 gsutil
17 | ```
18 | 
19 | Download the DQN Replay dataset for expert demonstrations on Atari environments:
20 | ```sh
21 | mkdir DATAPATH
22 | cp download.sh DATAPATH
23 | cd DATAPATH
24 | sh download.sh
25 | ```
26 | 
27 | 
28 | ## Pre-training
29 | Here we provide pretraining scripts for beta-VAE (used by CCIL) and VQ-VAE (used by CRLR and OREO).
30 | For other games, change the --env option.
31 | 
32 | ### beta-VAE
33 | ```sh
34 | CUDA_VISIBLE_DEVICES=0,1,2,3 python atari_beta_vae.py --env=KungFuMaster --datapath DATAPATH --num_episodes 20 --seed 1 --ch_div 4 --lmd 10
35 | ```
36 | ### VQ-VAE
37 | ```sh
38 | CUDA_VISIBLE_DEVICES=0,1,2,3 python atari_vqvae.py --env=KungFuMaster --datapath DATAPATH --num_episodes 20 --seed 1
39 | ```
40 | 
41 | ## Training BC policy
42 | Here we provide training scripts for the baselines and OREO.
43 | For other games, change the --env, --beta_vae_path, and --vqvae_path options.
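The checkpoints passed to --beta_vae_path and --vqvae_path follow the naming scheme of the pretraining scripts (see the save_dir and save_tag variables in atari_beta_vae.py and atari_vqvae.py); con1 marks the confounded setting, i.e. pretraining without --normal. For reference, the pretraining commands above resolve to:
```sh
# beta-VAE: models_beta_vae_coord_conv_chdiv{ch_div}[_actor_lmd{lmd}]/
#   {env}_s{stack}_epi{num_episodes}_con{confounded}_seed{seed}_zdim{z_dim}_beta{beta}_kltol{kl_tolerance}_ep{epoch}_beta_vae.pth
models_beta_vae_coord_conv_chdiv4_actor_lmd10.0/KungFuMaster_s1_epi20_con1_seed1_zdim50_beta4_kltol0_ep1000_beta_vae.pth
# VQ-VAE: models_vqvae/
#   {env}_s{stack}_epi{num_episodes}_con{confounded}_seed{seed}_ne{num_embeddings}_c{commitment_cost}_ep{epoch}_vqvae.pth
models_vqvae/KungFuMaster_s1_epi20_con1_seed1_ne512_c0.25_ep1000_vqvae.pth
```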
44 | 
45 | ### Behavioral cloning
46 | ```sh
47 | CUDA_VISIBLE_DEVICES=0 python atari_cnn_actor.py --env=KungFuMaster --datapath DATAPATH --seed 1 --eval_interval 1000 --num_episodes 20 --num_eval_episodes 100
48 | ```
49 | ### Dropout
50 | ```sh
51 | CUDA_VISIBLE_DEVICES=0 python atari_cnn_actor.py --env=KungFuMaster --datapath DATAPATH --seed 1 --eval_interval 1000 --original_dropout --prob 0.5 --num_episodes 20 --num_eval_episodes 100
52 | ```
53 | ### DropBlock
54 | ```sh
55 | CUDA_VISIBLE_DEVICES=0 python atari_cnn_actor.py --env=KungFuMaster --datapath DATAPATH --seed 1 --eval_interval 1000 --dropblock --prob 0.3 --num_episodes 20 --num_eval_episodes 100
56 | ```
57 | ### Cutout
58 | ```sh
59 | CUDA_VISIBLE_DEVICES=0 python atari_cnn_actor.py --env=KungFuMaster --datapath DATAPATH --seed 1 --eval_interval 1000 --input_cutout --num_episodes 20 --num_eval_episodes 100
60 | ```
61 | ### RandomShift
62 | ```sh
63 | CUDA_VISIBLE_DEVICES=0 python atari_cnn_actor.py --env=KungFuMaster --datapath DATAPATH --seed 1 --eval_interval 1000 --random_shift --num_episodes 20 --num_eval_episodes 100
64 | ```
65 | ### CCIL (w/o interaction)
66 | ```sh
67 | CUDA_VISIBLE_DEVICES=0 python atari_beta_vae_actor.py --env=KungFuMaster --datapath DATAPATH --num_episodes 20 --num_eval_episodes 100 --seed 1 --eval_interval 1000 --prob 0.5 --ch_div 4 --beta_vae_path models_beta_vae_coord_conv_chdiv4_actor_lmd10.0/KungFuMaster_s1_epi20_con1_seed1_zdim50_beta4_kltol0_ep1000_beta_vae.pth
68 | ```
69 | ### CRLR
70 | ```sh
71 | CUDA_VISIBLE_DEVICES=0 python atari_cnn_actor_crlr.py --fixed_size 15000 --num_sub_iters 10 --eval_interval 10 --save_interval 10 --n_epochs 10 --env=KungFuMaster --datapath DATAPATH --num_episodes 20 --num_eval_episodes 100 --seed 1 --vqvae_path models_vqvae/KungFuMaster_s1_epi20_con1_seed1_ne512_c0.25_ep1000_vqvae.pth
72 | ```
73 | ### OREO
74 | ```sh
75 | CUDA_VISIBLE_DEVICES=0 python atari_vqvae_oreo.py --env=KungFuMaster --datapath DATAPATH --num_mask 5 --num_episodes 20 --num_eval_episodes 100 --seed 1 --eval_interval 1000 --prob 0.5 --vqvae_path models_vqvae/KungFuMaster_s1_epi20_con1_seed1_ne512_c0.25_ep1000_vqvae.pth
76 | ```
77 | 
--------------------------------------------------------------------------------
/arial.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alinlab/oreo/bd50824df66eaa73a31354409351dca8400b1c61/arial.ttf
--------------------------------------------------------------------------------
/atari_beta_vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import tensorflow as tf
4 | import numpy as np
5 | from tqdm import tqdm
6 | import os
7 | import logging
8 | import random
9 | import csv
10 | 
11 | logging.basicConfig(level=logging.INFO)
12 | import argparse
13 | import matplotlib.pyplot as plt
14 | 
15 | from linear_models import CoordConvBetaVAE, weight_init
16 | from utils import load_dataset, set_seed_everywhere
17 | from dopamine.discrete_domains.atari_lib import create_atari_environment
18 | import kornia
19 | 
20 | gfile = tf.io.gfile
21 | 
22 | 
23 | def compute_loss(x, x_pred, mu, logvar, kl_tolerance=0):
24 |     recon_loss = (x - x_pred).pow(2).sum([1, 2, 3]).mean(0)  # per-image squared error, averaged over the batch
25 |     kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1)  # KL(q(z|x) || N(0, I)) per sample
26 |     kl_loss = torch.clamp(kl_loss, kl_tolerance * mu.shape[1], mu.shape[1]).mean()  # free-bits-style clamp: KL below kl_tolerance * z_dim is not penalized
27 |     return recon_loss, kl_loss
28 | 
29 | 
30 | def train(args):
31 |     device = torch.device("cuda")
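    # Note on the objective optimized below (a summary of the code in this file, not extra behavior):
    #   loss = recon_loss + beta * kl_loss                          with --lmd 0
    #   loss = recon_loss + beta * kl_loss + lmd * CE(actor(z), a)  with --lmd > 0,
    # where z is sampled from the CoordConv encoder, a is the expert action, and the
    # jointly trained actor head is the variant used to pretrain the encoder for CCIL.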
32 | 33 | torch.backends.cudnn.benchmark = False 34 | torch.backends.cudnn.deterministic = True 35 | set_seed_everywhere(args.seed) 36 | 37 | observations, actions, _ = load_dataset( 38 | args.env, 39 | 1, 40 | args.datapath, 41 | args.normal, 42 | args.num_data, 43 | args.stack, 44 | args.num_episodes, 45 | ) 46 | 47 | logging.info("Building models..") 48 | beta_vae = CoordConvBetaVAE(args.z_dim, args.ch_div).to(device) 49 | 50 | if args.lmd > 0: 51 | env = create_atari_environment(args.env) 52 | action_dim = env.action_space.n 53 | 54 | actor = nn.Sequential( 55 | nn.Linear(args.z_dim, args.z_dim), 56 | nn.ReLU(), 57 | nn.Linear(args.z_dim, action_dim), 58 | ) 59 | actor.apply(weight_init) 60 | actor.to(device) 61 | if torch.cuda.device_count() > 1: 62 | actor = nn.DataParallel(actor) 63 | 64 | save_dir = "models_beta_vae" 65 | resize = kornia.geometry.Resize(64) 66 | save_dir = save_dir + "_coord_conv_chdiv{}".format(args.ch_div) 67 | if args.lmd > 0: 68 | save_dir = save_dir + "_actor_lmd{}".format(args.lmd) 69 | if args.add_path is not None: 70 | save_dir = save_dir + "_" + args.add_path 71 | 72 | if args.num_episodes is None: 73 | save_tag = "{}_s{}_data{}k_con{}_seed{}_zdim{}_beta{}_kltol{}".format( 74 | args.env, 75 | args.stack, 76 | int(args.num_data / 1000), 77 | 1 - int(args.normal), 78 | args.seed, 79 | args.z_dim, 80 | int(args.beta), 81 | args.kl_tolerance, 82 | ) 83 | else: 84 | save_tag = "{}_s{}_epi{}_con{}_seed{}_zdim{}_beta{}_kltol{}".format( 85 | args.env, 86 | args.stack, 87 | int(args.num_episodes), 88 | 1 - int(args.normal), 89 | args.seed, 90 | args.z_dim, 91 | int(args.beta), 92 | args.kl_tolerance, 93 | ) 94 | 95 | if not os.path.exists(save_dir): 96 | os.makedirs(save_dir) 97 | 98 | ## Multi-GPU 99 | if torch.cuda.device_count() > 1: 100 | beta_vae = nn.DataParallel(beta_vae) 101 | 102 | if args.lmd > 0: 103 | beta_vae_optimizer = torch.optim.Adam( 104 | list(beta_vae.parameters()) + list(actor.parameters()), lr=args.lr 105 | ) 106 | else: 107 | beta_vae_optimizer = torch.optim.Adam(beta_vae.parameters(), lr=args.lr) 108 | 109 | n_batch = len(observations) // args.batch_size + 1 110 | total_idxs = list(range(len(observations))) 111 | 112 | logging.info("Training starts..") 113 | f = open(os.path.join(save_dir, save_tag + "_beta_vae_train.csv"), "w") 114 | writer = csv.writer(f) 115 | if args.lmd > 0: 116 | writer.writerow(["Epoch", "Recon Error", "KL Loss", "Actor Loss"]) 117 | else: 118 | writer.writerow(["Epoch", "Recon Error", "KL Loss"]) 119 | 120 | criterion = nn.CrossEntropyLoss() 121 | for epoch in tqdm(range(args.n_epochs)): 122 | random.shuffle(total_idxs) 123 | recon_errors = [] 124 | kl_losses = [] 125 | actor_losses = [] 126 | for j in range(n_batch): 127 | batch_idxs = total_idxs[j * args.batch_size : (j + 1) * args.batch_size] 128 | xx = torch.as_tensor( 129 | observations[batch_idxs], device=device, dtype=torch.float32 130 | ) 131 | xx = xx / 255.0 132 | xx = resize(xx) 133 | 134 | beta_vae_optimizer.zero_grad() 135 | 136 | z, mu, logvar = beta_vae(xx, mode="encode") 137 | obs_pred = beta_vae(z, mode="decode") 138 | recon_loss, kl_loss = compute_loss( 139 | xx, obs_pred, mu, logvar, args.kl_tolerance 140 | ) 141 | 142 | if args.lmd > 0: 143 | batch_act = torch.as_tensor(actions[batch_idxs], device=device).long() 144 | logits = actor(z) 145 | actor_loss = criterion(logits, batch_act) 146 | loss = recon_loss + args.beta * kl_loss + args.lmd * actor_loss 147 | actor_losses.append(actor_loss.mean().detach().cpu().item()) 148 | else: 149 | loss = 
recon_loss + args.beta * kl_loss 150 | 151 | loss.backward() 152 | 153 | beta_vae_optimizer.step() 154 | 155 | recon_errors.append(recon_loss.mean().detach().cpu().item()) 156 | kl_losses.append(kl_loss.mean().detach().cpu().item()) 157 | 158 | if args.lmd > 0: 159 | logging.info( 160 | "Epoch {} | Recon Error: {:.4f} | KL Loss: {:.4f} | Actor Loss: {:.4f}".format( 161 | epoch + 1, 162 | np.mean(recon_errors), 163 | np.mean(kl_losses), 164 | np.mean(actor_losses), 165 | ) 166 | ) 167 | writer.writerow( 168 | [ 169 | epoch + 1, 170 | np.mean(recon_errors), 171 | np.mean(kl_losses), 172 | np.mean(actor_losses), 173 | ] 174 | ) 175 | else: 176 | logging.info( 177 | "Epoch {} | Recon Error: {:.4f} | KL Loss: {:.4f}".format( 178 | epoch + 1, np.mean(recon_errors), np.mean(kl_losses) 179 | ) 180 | ) 181 | writer.writerow([epoch + 1, np.mean(recon_errors), np.mean(kl_losses)]) 182 | 183 | if (epoch + 1) % args.save_interval == 0: 184 | torch.save( 185 | beta_vae.module.state_dict() 186 | if (torch.cuda.device_count() > 1) 187 | else beta_vae.state_dict(), 188 | os.path.join( 189 | save_dir, save_tag + "_ep{}_beta_vae.pth".format(epoch + 1) 190 | ), 191 | ) 192 | if args.lmd > 0: 193 | torch.save( 194 | actor.module.state_dict() 195 | if (torch.cuda.device_count() > 1) 196 | else actor.state_dict(), 197 | os.path.join( 198 | save_dir, save_tag + "_ep{}_actor.pth".format(epoch + 1) 199 | ), 200 | ) 201 | 202 | 203 | if __name__ == "__main__": 204 | parser = argparse.ArgumentParser() 205 | 206 | # Seed & Env 207 | parser.add_argument("--seed", default=1, type=int) 208 | parser.add_argument("--env", default="Pong", type=str) 209 | parser.add_argument("--datapath", default="/data", type=str) 210 | parser.add_argument("--save_interval", default=100, type=int) 211 | parser.add_argument("--normal", action="store_true", default=False) 212 | parser.add_argument("--num_data", default=50000, type=int) 213 | parser.add_argument("--num_episodes", default=None, type=int) 214 | parser.add_argument("--stack", default=1, type=int) 215 | parser.add_argument("--add_path", default=None, type=str) 216 | 217 | parser.add_argument("--embedding_dim", default=64, type=int) 218 | parser.add_argument("--num_hiddens", default=128, type=int) 219 | parser.add_argument("--num_residual_layers", default=2, type=int) 220 | parser.add_argument("--num_residual_hiddens", default=32, type=int) 221 | parser.add_argument("--beta", default=4, type=float) 222 | parser.add_argument("--kl_tolerance", default=0, type=float) 223 | parser.add_argument("--z_dim", default=50, type=int) 224 | parser.add_argument("--batch_size", default=1024, type=int) 225 | parser.add_argument("--n_epochs", default=1000, type=int) 226 | parser.add_argument("--lr", default=3e-4, type=float) 227 | 228 | parser.add_argument("--ch_div", default=1, type=int) 229 | 230 | parser.add_argument("--lmd", default=0, type=float) 231 | 232 | args = parser.parse_args() 233 | assert args.beta > 1.0, "beta should be larger than 1" 234 | train(args) 235 | -------------------------------------------------------------------------------- /atari_beta_vae_actor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import tensorflow as tf 4 | import numpy as np 5 | from tqdm import tqdm 6 | import os 7 | import logging 8 | import csv 9 | import random 10 | 11 | logging.basicConfig(level=logging.INFO) 12 | import argparse 13 | 14 | from linear_models import weight_init, CoordConvEncoder 15 | from utils import ( 
16 | load_dataset, 17 | evaluate, 18 | set_seed_everywhere, 19 | ) 20 | from dopamine.discrete_domains.atari_lib import create_atari_environment 21 | import kornia 22 | 23 | gfile = tf.io.gfile 24 | 25 | 26 | def train(args): 27 | device = torch.device("cuda") 28 | 29 | torch.backends.cudnn.benchmark = False 30 | torch.backends.cudnn.deterministic = True 31 | set_seed_everywhere(args.seed) 32 | 33 | ## fixed dataset 34 | observations, actions, data_variance = load_dataset( 35 | args.env, 36 | 1, 37 | args.datapath, 38 | args.normal, 39 | args.num_data, 40 | args.stack, 41 | args.num_episodes, 42 | ) 43 | ## Stage 1 44 | logging.info("Building models..") 45 | logging.info("Start stage 1...") 46 | 47 | env = create_atari_environment(args.env) 48 | action_dim = env.action_space.n 49 | 50 | n_batch = len(observations) // args.batch_size + 1 51 | total_idxs = list(range(len(observations))) 52 | 53 | logging.info("Training starts..") 54 | 55 | save_dir = "models_beta_vae_actor" 56 | 57 | if args.num_episodes is None: 58 | save_tag = "{}_s{}_data{}k_con{}_seed{}_ne{}".format( 59 | args.env, 60 | args.stack, 61 | int(args.num_data / 1000), 62 | 1 - int(args.normal), 63 | args.seed, 64 | args.num_embeddings, 65 | ) 66 | else: 67 | save_tag = "{}_s{}_epi{}_con{}_seed{}_ne{}".format( 68 | args.env, 69 | args.stack, 70 | int(args.num_episodes), 71 | 1 - int(args.normal), 72 | args.seed, 73 | args.num_embeddings, 74 | ) 75 | 76 | resize = kornia.geometry.Resize(64) 77 | save_dir = save_dir + "_coord_conv" 78 | 79 | save_dir = save_dir + "_graph_param" 80 | save_tag = save_tag + "_prob{}".format(args.prob) 81 | 82 | if args.add_path is not None: 83 | save_dir = save_dir + "_" + args.add_path 84 | if not os.path.exists(save_dir): 85 | os.makedirs(save_dir) 86 | 87 | encoder = CoordConvEncoder(1, args.z_dim * 2, args.ch_div).to(device) 88 | 89 | actor = nn.Sequential( 90 | nn.Linear(args.z_dim * 2, args.z_dim), 91 | nn.ReLU(), 92 | nn.Linear(args.z_dim, action_dim), 93 | ) 94 | actor.apply(weight_init) 95 | actor.to(device) 96 | 97 | for p in encoder.parameters(): 98 | p.requires_grad = False 99 | 100 | if args.beta_vae_path is None: 101 | assert False 102 | beta_vae_dict = torch.load(args.beta_vae_path, map_location="cpu") 103 | encoder.load_state_dict( 104 | {k[8:]: v for k, v in beta_vae_dict.items() if "encoder" in k} 105 | ) 106 | 107 | actor_optimizer = torch.optim.Adam(actor.parameters(), lr=args.lr) 108 | 109 | ## Multi-GPU 110 | if torch.cuda.device_count() > 1: 111 | encoder = nn.DataParallel(encoder) 112 | actor = nn.DataParallel(actor) 113 | 114 | criterion = nn.CrossEntropyLoss() 115 | scores = [] 116 | logging.info("Training starts..") 117 | f_tr = open(os.path.join(save_dir, save_tag + "_cnn_train.csv"), "w") 118 | writer_tr = csv.writer(f_tr) 119 | writer_tr.writerow(["Epoch", "Loss", "Accuracy"]) 120 | 121 | f_te = open(os.path.join(save_dir, save_tag + "_cnn_eval.csv"), "w") 122 | writer_te = csv.writer(f_te) 123 | writer_te.writerow(["Epoch", "Loss", "Accuracy", "Score"]) 124 | 125 | for epoch in tqdm(range(args.n_epochs)): 126 | encoder.eval() 127 | actor.train() 128 | random.shuffle(total_idxs) 129 | actor_losses = [] 130 | accuracies = [] 131 | for j in range(n_batch): 132 | batch_idxs = total_idxs[j * args.batch_size : (j + 1) * args.batch_size] 133 | xx = torch.as_tensor( 134 | observations[batch_idxs], device=device, dtype=torch.float32 135 | ) 136 | xx = xx / 255.0 137 | xx = resize(xx) 138 | 139 | batch_act = torch.as_tensor(actions[batch_idxs], device=device).long() 140 | 141 | 
actor_optimizer.zero_grad() 142 | 143 | with torch.no_grad(): 144 | z = encoder(xx) 145 | z, _ = z.chunk(2, dim=-1) # mu 146 | 147 | prob = torch.ones(z.size()) * (1 - args.prob) 148 | mask = torch.bernoulli(prob).to(device) 149 | z = torch.cat([z * mask, mask], dim=1) 150 | 151 | logits = actor(z) 152 | actor_loss = criterion(logits, batch_act) 153 | 154 | actor_loss.backward() 155 | 156 | actor_optimizer.step() 157 | 158 | accuracy = (batch_act == logits.argmax(1)).float().mean() 159 | 160 | actor_losses.append(actor_loss.mean().detach().cpu().item()) 161 | accuracies.append(accuracy.mean().detach().cpu().item()) 162 | 163 | logging.info( 164 | "(Train) Epoch {} | Actor Loss: {:.4f} | Accuracy: {:.2f}".format( 165 | epoch + 1, np.mean(actor_losses), np.mean(accuracies), 166 | ) 167 | ) 168 | writer_tr.writerow( 169 | [epoch + 1, np.mean(actor_losses), np.mean(accuracies),] 170 | ) 171 | 172 | if (epoch + 1) % args.eval_interval == 0: 173 | actor.eval() 174 | encoder.eval() 175 | score = evaluate( 176 | env, 177 | nn.Identity(), 178 | actor.module if torch.cuda.device_count() > 1 else actor, 179 | encoder.module if torch.cuda.device_count() > 1 else encoder, 180 | "beta_vae", 181 | device, 182 | args, 183 | ) 184 | logging.info("(Eval) Epoch {} | Score: {:.2f}".format(epoch + 1, score,)) 185 | scores.append(score) 186 | actor.train() 187 | writer_te.writerow( 188 | [epoch + 1, np.mean(actor_losses), np.mean(accuracies), score] 189 | ) 190 | 191 | f_tr.close() 192 | f_te.close() 193 | 194 | torch.save( 195 | encoder.module.state_dict() 196 | if torch.cuda.device_count() > 1 197 | else encoder.state_dict(), 198 | os.path.join(save_dir, save_tag + "_ep{}_encoder.pth".format(epoch + 1)), 199 | ) 200 | torch.save( 201 | actor.module.state_dict() 202 | if torch.cuda.device_count() > 1 203 | else actor.state_dict(), 204 | os.path.join(save_dir, save_tag + "_ep{}_actor.pth".format(epoch + 1),), 205 | ) 206 | 207 | 208 | if __name__ == "__main__": 209 | parser = argparse.ArgumentParser() 210 | 211 | # Seed & Env 212 | parser.add_argument("--seed", default=1, type=int) 213 | parser.add_argument("--env", default="Pong", type=str) 214 | parser.add_argument("--datapath", default="/data", type=str) 215 | parser.add_argument("--num_data", default=50000, type=int) 216 | parser.add_argument("--stack", default=1, type=int) 217 | parser.add_argument("--normal", action="store_true", default=False) 218 | parser.add_argument("--normal_eval", action="store_true", default=False) 219 | 220 | # Save & Evaluation 221 | parser.add_argument("--save_interval", default=20, type=int) 222 | parser.add_argument("--eval_interval", default=20, type=int) 223 | parser.add_argument("--num_episodes", default=None, type=int) 224 | parser.add_argument("--num_eval_episodes", default=20, type=int) 225 | parser.add_argument("--n_epochs", default=1000, type=int) 226 | parser.add_argument("--add_path", default=None, type=str) 227 | 228 | # Encoder & Hyperparams 229 | parser.add_argument("--embedding_dim", default=64, type=int) 230 | parser.add_argument("--num_embeddings", default=512, type=int) 231 | parser.add_argument("--num_hiddens", default=128, type=int) 232 | parser.add_argument("--num_residual_layers", default=2, type=int) 233 | parser.add_argument("--num_residual_hiddens", default=32, type=int) 234 | parser.add_argument("--batch_size", default=1024, type=int) 235 | parser.add_argument("--lr", default=3e-4, type=float) 236 | 237 | # Model load 238 | parser.add_argument("--beta_vae_path", default=None, type=str) 239 | # For MLP 
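    # (Reading from the code above: --z_dim must match the z_dim of the pretrained beta-VAE
    #  checkpoint, since its encoder weights are loaded; the actor input is 2 * z_dim because
    #  the masked mean vector is concatenated with the dropout mask itself.)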
240 | parser.add_argument("--z_dim", default=50, type=int) 241 | # For dropout 242 | parser.add_argument("--prob", default=0.5, type=float) 243 | 244 | parser.add_argument("--ch_div", default=1, type=int) 245 | 246 | args = parser.parse_args() 247 | if args.normal: 248 | assert args.normal_eval 249 | else: 250 | assert not args.normal_eval 251 | 252 | args.coord_conv = True 253 | train(args) 254 | -------------------------------------------------------------------------------- /atari_ccil.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | import os 4 | import random 5 | import csv 6 | import argparse 7 | from time import perf_counter 8 | from collections import deque 9 | 10 | from PIL import Image, ImageFont, ImageDraw 11 | import numpy as np 12 | from sklearn.linear_model import Ridge 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from torch.distributions import Bernoulli 18 | 19 | from dopamine.discrete_domains.atari_lib import create_atari_environment 20 | from linear_models import Encoder, CoordConvEncoder 21 | import kornia 22 | from utils import set_seed_everywhere 23 | 24 | 25 | def sample(weights, temperature): 26 | return ( 27 | Bernoulli(logits=torch.from_numpy(weights) / temperature) 28 | .sample() 29 | .long() 30 | .numpy() 31 | ) 32 | 33 | 34 | def linear_regression(masks, rewards, alpha=1.0): 35 | model = Ridge(alpha).fit(masks, rewards) 36 | return model.coef_, model.intercept_ 37 | 38 | 39 | class SoftQAlgo: 40 | def __init__( 41 | self, num_dims, reward_fn, its, temperature=1.0, device=None, evals_per_it=1, 42 | ): 43 | self.num_dims = num_dims 44 | self.reward_fn = reward_fn 45 | self.its = its 46 | self.device = device 47 | self.temperature = lambda t: temperature 48 | self.evals_per_it = evals_per_it 49 | 50 | def run(self, args, writer): 51 | t = self.temperature(0) 52 | weights = np.zeros(self.num_dims) 53 | 54 | trace = [] 55 | masks = [] 56 | rewards = [] 57 | steps = [] 58 | 59 | mode = (np.sign(weights).astype(np.int64) + 1) // 2 60 | score = np.mean( 61 | [self.reward_fn(mode)[0] for _ in range(args.num_eval_episodes)] 62 | ) 63 | writer.writerow([args.env, args.seed, 0, 0, score]) 64 | 65 | for it in range(self.its): 66 | start = perf_counter() 67 | mask = sample(weights, t) 68 | reward = [] 69 | step = [] 70 | for _ in range(self.evals_per_it): 71 | r, s = self.reward_fn(mask) 72 | reward.append(r) 73 | step.append(s) 74 | reward, step = np.mean(reward), np.sum(step) 75 | 76 | masks.append(mask) 77 | rewards.append(reward) 78 | steps.append(step) 79 | 80 | weights, _ = linear_regression(masks, rewards, alpha=1.0) 81 | 82 | mode = (np.sign(weights).astype(np.int64) + 1) // 2 83 | trace.append( 84 | { 85 | "it": it, 86 | "reward": reward, 87 | "mask": mask, 88 | "weights": weights, 89 | "mode": mode, 90 | "time": perf_counter() - start, 91 | "past_mean_reward": np.mean(rewards), 92 | } 93 | ) 94 | pprint(trace[-1]) 95 | 96 | if (it + 1) % args.eval_interval == 0: 97 | score = np.mean( 98 | [self.reward_fn(mode)[0] for _ in range(args.num_eval_episodes)] 99 | ) 100 | print() 101 | total_steps = np.sum(steps) 102 | print(f"Reward at iter {it+1}, interaction {total_steps}: {score}") 103 | print() 104 | writer.writerow([args.env, args.seed, it + 1, total_steps, score]) 105 | 106 | return trace 107 | 108 | 109 | class StackedObs: 110 | def __init__(self, stack, confounded): 111 | self._stack = stack 112 | self._confounded = confounded 113 | self._deque 
= deque(maxlen=stack) 114 | self._font = ImageFont.truetype("arial.ttf", size=16) 115 | 116 | def reset(self, obs): 117 | self._deque.clear() 118 | for _ in range(self._stack): 119 | self._deque.append(obs) 120 | prev_action = 0 121 | return self._get_stacked_obs(prev_action) 122 | 123 | def step(self, obs, prev_action): 124 | self._deque.append(obs) 125 | return self._get_stacked_obs(prev_action) 126 | 127 | def _get_stacked_obs(self, prev_action): 128 | if self._confounded: 129 | stacked_obs = [] 130 | for c in range(self._stack): 131 | img = Image.fromarray(self._deque[c][..., 0]) 132 | draw = ImageDraw.Draw(img) 133 | draw.text( 134 | (11, 55), "{}".format(prev_action), fill=255, font=self._font, 135 | ) 136 | obs = np.asarray(img)[..., None] 137 | stacked_obs.append(obs) 138 | stacked_obs = np.concatenate(stacked_obs, axis=2) 139 | else: 140 | stacked_obs = np.concatenate(self._deque, axis=2) 141 | stacked_obs = np.transpose(stacked_obs, (2, 0, 1)) 142 | return stacked_obs 143 | 144 | 145 | def evaluate(env, pre_actor, actor, model, mask, device, args, num_eval_episodes): 146 | model.eval() 147 | actor.eval() 148 | stacked_obs_factory = StackedObs(args.stack, not args.normal_eval) 149 | average_episode_reward = 0 150 | mask = torch.from_numpy(mask).unsqueeze(0).to(device) 151 | 152 | human_scores = { 153 | "Amidar": 1675.8, 154 | "Asterix": 8503.3, 155 | "CrazyClimber": 35410.5, 156 | "DemonAttack": 3401.3, 157 | "Enduro": 309.6, 158 | "Freeway": 29.6, 159 | "Gopher": 2321.0, 160 | "Jamesbond": 406.7, 161 | "Kangaroo": 3035.0, 162 | "KungFuMaster": 22736.2, 163 | "Pong": 9.3, 164 | "PrivateEye": 69571.3, 165 | "Seaquest": 20181.8, 166 | "Alien": 6875.4, 167 | "Assault": 1496.4, 168 | "BankHeist": 734.4, 169 | "BattleZone": 37800.0, 170 | "Boxing": 4.3, 171 | "Breakout": 31.8, 172 | "ChopperCommand": 9881.8, 173 | "Frostbite": 4334.7, 174 | "Hero": 25762.5, 175 | "Krull": 2394.6, 176 | "MsPacman": 15693.4, 177 | "Qbert": 13455.0, 178 | "RoadRunner": 7845.0, 179 | "UpNDown": 9082.0, 180 | } 181 | random_scores = { 182 | "Amidar": 5.8, 183 | "Asterix": 210.0, 184 | "CrazyClimber": 10780.5, 185 | "DemonAttack": 152.1, 186 | "Enduro": 0.0, 187 | "Freeway": 0.0, 188 | "Gopher": 257.6, 189 | "Jamesbond": 29.0, 190 | "Kangaroo": 52.0, 191 | "KungFuMaster": 258.5, 192 | "Pong": -20.7, 193 | "PrivateEye": 24.9, 194 | "Seaquest": 68.4, 195 | "Alien": 227.8, 196 | "Assault": 222.4, 197 | "BankHeist": 14.2, 198 | "BattleZone": 2360.0, 199 | "Boxing": 0.1, 200 | "Breakout": 1.7, 201 | "ChopperCommand": 811.0, 202 | "Frostbite": 65.2, 203 | "Hero": 1027.0, 204 | "Krull": 1598.0, 205 | "MsPacman": 307.3, 206 | "Qbert": 163.9, 207 | "RoadRunner": 11.5, 208 | "UpNDown": 533.4, 209 | } 210 | 211 | resize = kornia.geometry.Resize(64) 212 | total_step = 0 213 | for episode in range(num_eval_episodes): 214 | obs = env.reset() 215 | done = False 216 | episode_reward = 0 217 | step = 0 218 | while not done: 219 | if step == 0: 220 | stacked_obs = stacked_obs_factory.reset(obs) 221 | 222 | with torch.no_grad(): 223 | stacked_obs = ( 224 | torch.as_tensor( 225 | stacked_obs, device=device, dtype=torch.float32 226 | ).unsqueeze(0) 227 | / 255.0 228 | ) 229 | 230 | stacked_obs = resize(stacked_obs) 231 | features = model(stacked_obs) 232 | 233 | features = pre_actor(torch.flatten(features, start_dim=1)) 234 | features, _ = features.chunk(2, dim=-1) # mu 235 | # causal graph 236 | features = torch.cat( 237 | [features * mask, mask.repeat(features.shape[0], 1)], dim=1 238 | ) 239 | action = 
actor(features).argmax(1)[0].cpu().item() 240 | 241 | obs, reward, done, info = env.step(action) 242 | prev_action = action 243 | stacked_obs = stacked_obs_factory.step(obs, prev_action) 244 | episode_reward += reward 245 | step += 1 246 | if step == 27000: 247 | done = True 248 | total_step += step 249 | 250 | average_episode_reward += episode_reward 251 | average_episode_reward /= num_eval_episodes 252 | model.train() 253 | actor.train() 254 | normalized_reward = (average_episode_reward - random_scores[args.env]) / np.abs( 255 | human_scores[args.env] - random_scores[args.env] 256 | ) 257 | return normalized_reward, total_step 258 | 259 | 260 | def intervention_policy_execution(args): 261 | torch.backends.cudnn.benchmark = False 262 | torch.backends.cudnn.deterministic = True 263 | set_seed_everywhere(args.seed) 264 | 265 | device = torch.device("cuda") 266 | env = create_atari_environment(args.env) 267 | action_dim = env.action_space.n 268 | 269 | actor = nn.Sequential( 270 | nn.Linear(args.z_dim * 2, args.z_dim), 271 | nn.ReLU(), 272 | nn.Linear(args.z_dim, action_dim), 273 | ).to(device) 274 | encoder = CoordConvEncoder(1, args.z_dim * 2, args.ch_div).to(device) 275 | 276 | if args.env in [ 277 | "Amidar", 278 | "Asterix", 279 | "CrazyClimber", 280 | "DemonAttack", 281 | "Enduro", 282 | "Freeway", 283 | "Gopher", 284 | "Jamesbond", 285 | "Kangaroo", 286 | "KungFuMaster", 287 | "Pong", 288 | "PrivateEye", 289 | "Seaquest", 290 | ]: 291 | num_episodes = 20 292 | elif args.env in [ 293 | "Alien", 294 | "Assault", 295 | "BankHeist", 296 | "BattleZone", 297 | "Boxing", 298 | "Breakout", 299 | "ChopperCommand", 300 | "Frostbite", 301 | "Hero", 302 | "Krull", 303 | "MsPacman", 304 | "Qbert", 305 | "RoadRunner", 306 | "UpNDown", 307 | ]: 308 | num_episodes = 50 309 | else: 310 | raise ValueError("not a target game") 311 | 312 | encoder_path = os.path.join( 313 | args.save_path, 314 | "{}_s1_epi{}_con{}_seed{}_ne512_prob0.5_ep1000_encoder.pth".format( 315 | args.env, num_episodes, 1 - int(args.normal_eval), args.seed 316 | ), 317 | ) 318 | actor_path = os.path.join( 319 | args.save_path, 320 | "{}_s1_epi{}_con{}_seed{}_ne512_prob0.5_ep1000_actor.pth".format( 321 | args.env, num_episodes, 1 - int(args.normal_eval), args.seed 322 | ), 323 | ) 324 | encoder.load_state_dict(torch.load(encoder_path, map_location="cpu")) 325 | actor.load_state_dict(torch.load(actor_path, map_location="cpu")) 326 | 327 | ## Multi-GPU 328 | if torch.cuda.device_count() > 1: 329 | encoder = nn.DataParallel(encoder) 330 | actor = nn.DataParallel(actor) 331 | 332 | def run_step(mask): 333 | score, steps = evaluate( 334 | env, 335 | nn.Identity(), 336 | actor.module if torch.cuda.device_count() > 1 else actor, 337 | encoder.module if torch.cuda.device_count() > 1 else encoder, 338 | mask, 339 | device, 340 | args, 341 | 1, 342 | ) 343 | return score, steps 344 | 345 | save_dir = "models_beta_vae_actor_coord_conv_ccil_normalized" 346 | save_tag = "{}_s{}_epi{}_con{}_seed{}_ne{}_temp{}".format( 347 | args.env, 348 | args.stack, 349 | int(num_episodes), 350 | 1 - int(args.normal_eval), 351 | args.seed, 352 | args.num_embeddings, 353 | int(args.temperature), 354 | ) 355 | 356 | if args.add_path is not None: 357 | save_dir = save_dir + "_" + args.add_path 358 | if not os.path.exists(save_dir): 359 | os.makedirs(save_dir) 360 | 361 | f_te = open(os.path.join(save_dir, save_tag + "_cnn_eval.csv"), "w") 362 | writer_te = csv.writer(f_te) 363 | writer_te.writerow(["Game", "Seed", "Iters", "Interactions", "Score"]) 364 | 365 | 
trace = SoftQAlgo( 366 | args.z_dim, run_step, args.num_its, temperature=args.temperature 367 | ).run(args, writer_te) 368 | 369 | best_mask = trace[-1]["mode"] 370 | print(f"Final mask {best_mask.tolist()}") 371 | 372 | score, _ = evaluate( 373 | env, 374 | nn.Identity(), 375 | actor.module if torch.cuda.device_count() > 1 else actor, 376 | encoder.module if torch.cuda.device_count() > 1 else encoder, 377 | best_mask, 378 | device, 379 | args, 380 | args.num_eval_episodes, 381 | ) 382 | 383 | print(f"Final reward {score}") 384 | writer_te.writerow([args.env, args.seed, args.num_its, "final", score]) 385 | 386 | f_te.close() 387 | 388 | torch.save( 389 | torch.from_numpy(best_mask), 390 | os.path.join(save_dir, save_tag + "_best_mask.pth"), 391 | ) 392 | 393 | print(f"Final reward {score}") 394 | 395 | 396 | def main(): 397 | parser = argparse.ArgumentParser() 398 | parser.add_argument("--num_its", type=int, default=20) 399 | parser.add_argument("--temperature", type=float, default=10) 400 | 401 | parser.add_argument("--seed", default=1, type=int) 402 | parser.add_argument("--env", default="Pong", type=str) 403 | parser.add_argument("--datapath", default="/data", type=str) 404 | parser.add_argument("--stack", default=1, type=int) 405 | parser.add_argument("--normal_eval", action="store_true", default=False) 406 | 407 | # Save & Evaluation 408 | parser.add_argument("--num_eval_episodes", default=20, type=int) 409 | parser.add_argument("--eval_interval", default=20, type=int) 410 | parser.add_argument("--add_path", default=None, type=str) 411 | 412 | # Encoder & Hyperparams 413 | parser.add_argument("--num_embeddings", default=512, type=int) 414 | parser.add_argument("--embedding_dim", default=64, type=int) 415 | parser.add_argument("--num_hiddens", default=128, type=int) 416 | parser.add_argument("--num_residual_layers", default=2, type=int) 417 | parser.add_argument("--num_residual_hiddens", default=32, type=int) 418 | 419 | # Model load 420 | parser.add_argument("--save_path", default=None, type=str) 421 | # For MLP 422 | parser.add_argument("--z_dim", default=50, type=int) 423 | parser.add_argument("--ch_div", default=1, type=int) 424 | 425 | intervention_policy_execution(parser.parse_args()) 426 | 427 | 428 | if __name__ == "__main__": 429 | main() 430 | -------------------------------------------------------------------------------- /atari_cnn_actor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import tensorflow as tf 4 | import numpy as np 5 | from tqdm import tqdm 6 | import gzip 7 | import os 8 | import logging 9 | import csv 10 | import random 11 | import copy 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | import pickle 15 | import argparse 16 | import matplotlib.pyplot as plt 17 | 18 | from PIL import Image, ImageFont, ImageDraw 19 | from kornia.augmentation import RandomErasing, RandomCrop 20 | 21 | from linear_models import Encoder, weight_init 22 | from utils import ( 23 | load_dataset, 24 | evaluate, 25 | set_seed_everywhere, 26 | ) 27 | from dopamine.discrete_domains.atari_lib import create_atari_environment 28 | from dropblock import DropBlock2D, LinearScheduler 29 | 30 | gfile = tf.io.gfile 31 | 32 | 33 | def train(args): 34 | device = torch.device("cuda") 35 | 36 | torch.backends.cudnn.benchmark = False 37 | torch.backends.cudnn.deterministic = True 38 | set_seed_everywhere(args.seed) 39 | 40 | ## fixed dataset 41 | observations, actions, data_variance = load_dataset( 42 | 
args.env, 43 | 1, 44 | args.datapath, 45 | args.normal, 46 | args.num_data, 47 | args.stack, 48 | args.num_episodes, 49 | ) 50 | 51 | ## Stage 1 52 | logging.info("Building models..") 53 | logging.info("Start stage 1...") 54 | 55 | env = create_atari_environment(args.env) 56 | action_dim = env.action_space.n 57 | 58 | n_batch = len(observations) // args.batch_size + 1 59 | total_idxs = list(range(len(observations))) 60 | 61 | logging.info("Training starts..") 62 | 63 | save_dir = "models_cnn_actor" 64 | 65 | if args.num_episodes is None: 66 | save_tag = "{}_s{}_data{}k_con{}_seed{}_ne{}".format( 67 | args.env, 68 | args.stack, 69 | int(args.num_data / 1000), 70 | 1 - int(args.normal), 71 | args.seed, 72 | args.num_embeddings, 73 | ) 74 | else: 75 | save_tag = "{}_s{}_epi{}_con{}_seed{}_ne{}".format( 76 | args.env, 77 | args.stack, 78 | int(args.num_episodes), 79 | 1 - int(args.normal), 80 | args.seed, 81 | args.num_embeddings, 82 | ) 83 | 84 | if args.original_dropout: 85 | save_dir = save_dir + "_original_dropout" 86 | save_tag = save_tag + "_prob{}".format(args.prob) 87 | 88 | if args.input_cutout: 89 | save_dir = save_dir + "_input_cutout" 90 | if args.random_shift: 91 | save_dir = save_dir + "_random_shift" 92 | 93 | if args.dropblock: 94 | save_dir = save_dir + "_dropblock" 95 | save_tag = save_tag + "_prob{}".format(args.prob) 96 | 97 | if args.add_path is not None: 98 | save_dir = save_dir + "_" + args.add_path 99 | if not os.path.exists(save_dir): 100 | os.makedirs(save_dir) 101 | 102 | encoder = Encoder( 103 | args.stack, 104 | args.embedding_dim, 105 | args.num_hiddens, 106 | args.num_residual_layers, 107 | args.num_residual_hiddens, 108 | ).to(device) 109 | pre_actor = nn.Sequential( 110 | nn.Flatten(start_dim=1), nn.Linear(8 * 8 * args.embedding_dim, args.z_dim) 111 | ) 112 | actor = nn.Sequential( 113 | nn.Linear(args.z_dim, args.z_dim), nn.ReLU(), nn.Linear(args.z_dim, action_dim), 114 | ) 115 | pre_actor.apply(weight_init) 116 | pre_actor.to(device) 117 | actor.apply(weight_init) 118 | actor.to(device) 119 | 120 | actor_optimizer = torch.optim.Adam( 121 | list(encoder.parameters()) 122 | + list(pre_actor.parameters()) 123 | + list(actor.parameters()), 124 | lr=args.lr, 125 | ) 126 | 127 | if args.input_cutout: 128 | ## same size as RAD 129 | scale_min = (float(10) / 84) ** 2 130 | scale_max = (float(30) / 84) ** 2 131 | cutout = RandomErasing(scale=(scale_min, scale_max), ratio=(1.0, 1.0), p=1.0) 132 | 133 | if args.random_shift: 134 | shift = nn.Sequential(nn.ReplicationPad2d(4), RandomCrop((84, 84))) 135 | 136 | if args.dropblock: 137 | drop_block = LinearScheduler( 138 | DropBlock2D(block_size=3, drop_prob=args.prob), 139 | start_value=0.0, 140 | stop_value=args.prob, 141 | nr_steps=int(n_batch * args.n_epochs // 2), 142 | ).to(device) 143 | drop_block.train() 144 | 145 | ## Multi-GPU 146 | if torch.cuda.device_count() > 1: 147 | encoder = nn.DataParallel(encoder) 148 | pre_actor = nn.DataParallel(pre_actor) 149 | actor = nn.DataParallel(actor) 150 | 151 | criterion = nn.CrossEntropyLoss() 152 | scores = [] 153 | logging.info("Training starts..") 154 | f_tr = open(os.path.join(save_dir, save_tag + "_cnn_train.csv"), "w") 155 | writer_tr = csv.writer(f_tr) 156 | writer_tr.writerow(["Epoch", "Loss", "Accuracy"]) 157 | 158 | f_te = open(os.path.join(save_dir, save_tag + "_cnn_eval.csv"), "w") 159 | writer_te = csv.writer(f_te) 160 | writer_te.writerow(["Epoch", "Loss", "Accuracy", "Score"]) 161 | 162 | for epoch in tqdm(range(args.n_epochs)): 163 | encoder.train() 164 | 
actor.train() 165 | random.shuffle(total_idxs) 166 | actor_losses = [] 167 | accuracies = [] 168 | for j in range(n_batch): 169 | batch_idxs = total_idxs[j * args.batch_size : (j + 1) * args.batch_size] 170 | xx = torch.as_tensor( 171 | observations[batch_idxs], device=device, dtype=torch.float32 172 | ) 173 | xx = xx / 255.0 174 | 175 | if args.input_cutout: 176 | xx = cutout(xx) 177 | if args.random_shift: 178 | xx = shift(xx) 179 | batch_act = torch.as_tensor(actions[batch_idxs], device=device).long() 180 | 181 | actor_optimizer.zero_grad() 182 | 183 | z = encoder(xx) 184 | 185 | if args.dropblock: 186 | drop_block.step() 187 | z = drop_block(z) 188 | 189 | if args.original_dropout: 190 | prob = torch.ones_like(z) * (1 - args.prob) 191 | mask = torch.bernoulli(prob).to(device) 192 | z = z * mask 193 | z = z / (1.0 - args.prob) 194 | 195 | z = pre_actor(z) 196 | 197 | logits = actor(z) 198 | actor_loss = criterion(logits, batch_act) 199 | 200 | actor_loss.backward() 201 | 202 | actor_optimizer.step() 203 | 204 | accuracy = (batch_act == logits.argmax(1)).float().mean() 205 | 206 | actor_losses.append(actor_loss.mean().detach().cpu().item()) 207 | accuracies.append(accuracy.mean().detach().cpu().item()) 208 | 209 | logging.info( 210 | "(Train) Epoch {} | Actor Loss: {:.4f} | Accuracy: {:.2f}".format( 211 | epoch + 1, np.mean(actor_losses), np.mean(accuracies), 212 | ) 213 | ) 214 | writer_tr.writerow( 215 | [epoch + 1, np.mean(actor_losses), np.mean(accuracies),] 216 | ) 217 | 218 | if (epoch + 1) % args.eval_interval == 0: 219 | actor.eval() 220 | encoder.eval() 221 | score = evaluate( 222 | env, 223 | pre_actor.module if torch.cuda.device_count() > 1 else pre_actor, 224 | actor.module if torch.cuda.device_count() > 1 else actor, 225 | encoder.module if torch.cuda.device_count() > 1 else encoder, 226 | "cnn", 227 | device, 228 | args, 229 | ) 230 | logging.info("(Eval) Epoch {} | Score: {:.2f}".format(epoch + 1, score,)) 231 | scores.append(score) 232 | encoder.train() 233 | actor.train() 234 | writer_te.writerow( 235 | [epoch + 1, np.mean(actor_losses), np.mean(accuracies), score] 236 | ) 237 | 238 | f_tr.close() 239 | f_te.close() 240 | 241 | torch.save( 242 | encoder.module.state_dict() 243 | if torch.cuda.device_count() > 1 244 | else encoder.state_dict(), 245 | os.path.join(save_dir, save_tag + "_ep{}_encoder.pth".format(epoch + 1)), 246 | ) 247 | torch.save( 248 | actor.module.state_dict() 249 | if torch.cuda.device_count() > 1 250 | else actor.state_dict(), 251 | os.path.join(save_dir, save_tag + "_ep{}_actor.pth".format(epoch + 1),), 252 | ) 253 | torch.save( 254 | pre_actor.module.state_dict() 255 | if torch.cuda.device_count() > 1 256 | else pre_actor.state_dict(), 257 | os.path.join(save_dir, save_tag + "_ep{}_pre_actor.pth".format(epoch + 1),), 258 | ) 259 | 260 | 261 | if __name__ == "__main__": 262 | parser = argparse.ArgumentParser() 263 | 264 | # Seed & Env 265 | parser.add_argument("--seed", default=1, type=int) 266 | parser.add_argument("--env", default="Pong", type=str) 267 | parser.add_argument("--datapath", default="/data", type=str) 268 | parser.add_argument("--num_data", default=50000, type=int) 269 | parser.add_argument("--stack", default=1, type=int) 270 | parser.add_argument("--normal", action="store_true", default=False) 271 | parser.add_argument("--normal_eval", action="store_true", default=False) 272 | 273 | # Save & Evaluation 274 | parser.add_argument("--save_interval", default=20, type=int) 275 | parser.add_argument("--eval_interval", default=20, type=int) 
276 | parser.add_argument("--num_episodes", default=None, type=int) 277 | parser.add_argument("--num_eval_episodes", default=20, type=int) 278 | parser.add_argument("--n_epochs", default=1000, type=int) 279 | parser.add_argument("--add_path", default=None, type=str) 280 | 281 | # Encoder & Hyperparams 282 | parser.add_argument("--embedding_dim", default=64, type=int) 283 | parser.add_argument("--num_embeddings", default=512, type=int) 284 | parser.add_argument("--num_hiddens", default=128, type=int) 285 | parser.add_argument("--num_residual_layers", default=2, type=int) 286 | parser.add_argument("--num_residual_hiddens", default=32, type=int) 287 | parser.add_argument("--batch_size", default=1024, type=int) 288 | parser.add_argument("--lr", default=3e-4, type=float) 289 | 290 | # For MLP 291 | parser.add_argument("--z_dim", default=256, type=int) 292 | # For dropout 293 | parser.add_argument("--prob", default=0.5, type=float) 294 | parser.add_argument("--original_dropout", action="store_true", default=False) 295 | parser.add_argument("--code_dropout", action="store_true", default=False) 296 | parser.add_argument("--input_cutout", action="store_true", default=False) 297 | parser.add_argument("--random_shift", action="store_true", default=False) 298 | parser.add_argument("--dropblock", action="store_true", default=False) 299 | 300 | args = parser.parse_args() 301 | if args.normal: 302 | assert args.normal_eval 303 | else: 304 | assert not args.normal_eval 305 | 306 | train(args) 307 | -------------------------------------------------------------------------------- /atari_cnn_actor_crlr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import tensorflow as tf 4 | import numpy as np 5 | from tqdm import tqdm 6 | import gzip 7 | import os 8 | import logging 9 | import csv 10 | import random 11 | import copy 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | import pickle 15 | import argparse 16 | import matplotlib.pyplot as plt 17 | 18 | from PIL import Image, ImageFont, ImageDraw 19 | from kornia.augmentation import RandomErasing 20 | 21 | from linear_models import Encoder, VectorQuantizer, weight_init 22 | from utils import ( 23 | load_dataset, 24 | evaluate_crlr, 25 | set_seed_everywhere, 26 | categorical_confounder_balancing_loss, 27 | ) 28 | from dopamine.discrete_domains.atari_lib import create_atari_environment 29 | from sklearn.linear_model import LogisticRegression 30 | 31 | 32 | gfile = tf.io.gfile 33 | 34 | 35 | def train(args): 36 | device = torch.device("cuda") 37 | 38 | torch.backends.cudnn.benchmark = False 39 | torch.backends.cudnn.deterministic = True 40 | set_seed_everywhere(args.seed) 41 | 42 | ## fixed dataset 43 | observations, actions, data_variance = load_dataset( 44 | args.env, 45 | 1, 46 | args.datapath, 47 | args.normal, 48 | args.num_data, 49 | args.stack, 50 | args.num_episodes, 51 | ) 52 | 53 | ## Stage 1 54 | logging.info("Building models..") 55 | logging.info("Start stage 1...") 56 | 57 | env = create_atari_environment(args.env) 58 | action_dim = env.action_space.n 59 | 60 | n_batch = len(observations) // args.batch_size + 1 61 | total_idxs = list(range(len(observations))) 62 | 63 | logging.info("Training starts..") 64 | 65 | save_dir = "models_vqvae_cnn_actor_crlr" 66 | 67 | if args.num_episodes is None: 68 | save_tag = "{}_s{}_data{}k_con{}_seed{}_ne{}".format( 69 | args.env, 70 | args.stack, 71 | int(args.num_data / 1000), 72 | 1 - int(args.normal), 73 | args.seed, 74 | 
args.num_embeddings, 75 | ) 76 | else: 77 | save_tag = "{}_s{}_epi{}_con{}_seed{}_ne{}".format( 78 | args.env, 79 | args.stack, 80 | int(args.num_episodes), 81 | 1 - int(args.normal), 82 | args.seed, 83 | args.num_embeddings, 84 | ) 85 | 86 | if args.add_path is not None: 87 | save_dir = save_dir + "_" + args.add_path 88 | if not os.path.exists(save_dir): 89 | os.makedirs(save_dir) 90 | 91 | encoder = Encoder( 92 | args.stack, 93 | args.embedding_dim, 94 | args.num_hiddens, 95 | args.num_residual_layers, 96 | args.num_residual_hiddens, 97 | ).to(device) 98 | quantizer = VectorQuantizer(args.embedding_dim, args.num_embeddings, 0.25).to( 99 | device 100 | ) 101 | 102 | for p in encoder.parameters(): 103 | p.requires_grad = False 104 | for p in quantizer.parameters(): 105 | p.requires_grad = False 106 | vqvae_dict = torch.load(args.vqvae_path, map_location="cpu") 107 | encoder.load_state_dict( 108 | {k[9:]: v for k, v in vqvae_dict.items() if "_encoder" in k} 109 | ) 110 | quantizer.load_state_dict( 111 | {k[11:]: v for k, v in vqvae_dict.items() if "_quantizer" in k} 112 | ) 113 | 114 | ## Multi-GPU 115 | if torch.cuda.device_count() > 1: 116 | encoder = nn.DataParallel(encoder) 117 | quantizer = nn.DataParallel(quantizer) 118 | 119 | criterion = nn.CrossEntropyLoss() 120 | logging.info("Training starts..") 121 | f_tr = open(os.path.join(save_dir, save_tag + "_cnn_train.csv"), "w") 122 | writer_tr = csv.writer(f_tr) 123 | writer_tr.writerow(["Epoch", "Actor Loss", "Weight Loss", "Accuracy"]) 124 | 125 | f_te = open(os.path.join(save_dir, save_tag + "_cnn_eval.csv"), "w") 126 | writer_te = csv.writer(f_te) 127 | writer_te.writerow(["Epoch", "Actor Loss", "Weight Loss", "Accuracy", "Score"]) 128 | 129 | if args.idx_path is None: 130 | encoder.eval() 131 | quantizer.eval() 132 | total_encoding_indices = [] 133 | with torch.no_grad(): 134 | for j in range(n_batch): 135 | batch_idxs = total_idxs[j * args.batch_size : (j + 1) * args.batch_size] 136 | xx = torch.as_tensor( 137 | observations[batch_idxs], device=device, dtype=torch.float32 138 | ) 139 | xx = xx / 255.0 140 | 141 | z = encoder(xx) 142 | z, *_, encoding_indices, _ = quantizer(z) 143 | total_encoding_indices.append(encoding_indices.cpu()) 144 | total_encoding_indices = torch.cat(total_encoding_indices, dim=0) 145 | if not os.path.exists("./total_idx"): 146 | os.makedirs("./total_idx") 147 | torch.save( 148 | total_encoding_indices, 149 | os.path.join("./total_idx", save_tag + "_total_idx.pth"), 150 | ) 151 | else: 152 | total_encoding_indices = torch.load(args.idx_path, map_location="cpu") 153 | 154 | N, P = total_encoding_indices.shape 155 | total_encoding_onehot = torch.zeros( 156 | (N * P, args.num_embeddings), device=total_encoding_indices.device 157 | ) 158 | total_encoding_onehot.scatter_( 159 | 1, total_encoding_indices.reshape(-1).unsqueeze(1), 1 160 | ) 161 | total_encoding_onehot = total_encoding_onehot.view(N, P, args.num_embeddings) # NPE 162 | 163 | actor = nn.Linear(args.num_embeddings * P, action_dim,).to(device) 164 | if torch.cuda.device_count() > 1: 165 | actor = nn.DataParallel(actor) 166 | 167 | criterion = nn.CrossEntropyLoss(reduction="none") 168 | total_actions = torch.as_tensor(actions, device=device).long() 169 | 170 | x_total = torch.flatten(total_encoding_onehot, start_dim=1).detach() # ND 171 | y_total = total_actions.detach() # N 172 | x_total_np = x_total.cpu().numpy() 173 | y_total_np = y_total.cpu().numpy() 174 | if args.fixed_size is None: 175 | fixed_size = len(x_total) 176 | else: 177 | fixed_size = 
args.fixed_size 178 | 179 | if len(x_total) > fixed_size: 180 | weight = torch.full( 181 | [fixed_size], 1.0 / fixed_size, requires_grad=True, device=device 182 | ) 183 | proj = torch.eye(fixed_size) - torch.ones(fixed_size, fixed_size) / fixed_size 184 | proj = proj.to(device) 185 | 186 | sample_idx = np.random.choice(len(x_total), fixed_size) 187 | x_total = x_total[sample_idx].to(device) 188 | y_total = y_total[sample_idx].to(device) 189 | x_total_np = x_total_np[sample_idx] 190 | y_total_np = y_total_np[sample_idx] 191 | total_encoding_indices = total_encoding_indices[sample_idx].to(device) 192 | total_encoding_onehot = total_encoding_onehot[sample_idx].to(device) 193 | total_actions = total_actions[sample_idx] 194 | else: 195 | weight = torch.full([N], 1.0 / N, requires_grad=True, device=device) 196 | proj = torch.eye(N) - torch.ones(N, N) / N 197 | proj = proj.to(device) 198 | 199 | for epoch in tqdm(range(args.n_epochs)): 200 | actor_losses = [] 201 | weight_losses = [] 202 | accuracies = [] 203 | sample_weight = weight.detach().cpu().numpy() # N 204 | actor_clf = LogisticRegression(random_state=args.seed, n_jobs=-1).fit( 205 | x_total_np, y_total_np, sample_weight=sample_weight 206 | ) 207 | cls_list = actor_clf.classes_ 208 | if not ((max(cls_list) == len(cls_list) - 1) and (min(cls_list) == 0)): 209 | raise ValueError("class re-mapping is needed") 210 | if torch.cuda.device_count() > 1: 211 | actor.module.weight.data = ( 212 | torch.from_numpy(actor_clf.coef_).float().to(device) 213 | ) 214 | actor.module.bias.data = ( 215 | torch.from_numpy(actor_clf.intercept_).float().to(device) 216 | ) 217 | else: 218 | actor.weight.data = torch.from_numpy(actor_clf.coef_).float().to(device) 219 | actor.bias.data = torch.from_numpy(actor_clf.intercept_).float().to(device) 220 | with torch.no_grad(): 221 | logits = actor(x_total) 222 | 223 | for ii in tqdm(range(args.num_sub_iters)): 224 | weight_loss = categorical_confounder_balancing_loss( 225 | total_encoding_indices, 226 | weight, 227 | args.num_embeddings, 228 | total_encoding_onehot, 229 | ) 230 | actor_loss = criterion(logits, total_actions) 231 | loss = weight @ actor_loss.detach() + args.lmd * weight_loss 232 | loss.backward() 233 | with torch.no_grad(): 234 | weight -= args.lr * (proj @ weight.grad) 235 | weight.abs_() ## non-negative weight 236 | weight /= weight.sum() ## normalization 237 | weight.grad.zero_() 238 | 239 | accuracy = (total_actions == logits.argmax(1)).float().mean() 240 | actor_losses.append(actor_loss.mean().detach().cpu().item()) 241 | weight_losses.append(weight_loss.mean().detach().cpu().item()) 242 | accuracies.append(accuracy.mean().detach().cpu().item()) 243 | 244 | logging.info( 245 | "Epochs {} | Actor Loss: {:.4f} | Weight Loss: {:.4f} | Accuracy: {:.2f}".format( 246 | epoch + 1, 247 | np.mean(actor_losses), 248 | np.mean(weight_losses), 249 | np.mean(accuracies), 250 | ) 251 | ) 252 | writer_tr.writerow( 253 | [ 254 | epoch + 1, 255 | np.mean(actor_losses), 256 | np.mean(weight_losses), 257 | np.mean(accuracies), 258 | ] 259 | ) 260 | 261 | if (epoch + 1) % args.eval_interval == 0: 262 | actor.eval() 263 | encoder.eval() 264 | quantizer.eval() 265 | score = evaluate_crlr( 266 | env, 267 | actor.module if torch.cuda.device_count() > 1 else actor, 268 | encoder.module if torch.cuda.device_count() > 1 else encoder, 269 | encoder.module if torch.cuda.device_count() > 1 else encoder, 270 | quantizer.module if torch.cuda.device_count() > 1 else quantizer, 271 | device, 272 | args, 273 | ) 274 | 
logging.info("(Eval) Epoch {} | Score: {:.2f}".format(epoch + 1, score,)) 275 | actor.train() 276 | writer_te.writerow( 277 | [ 278 | epoch + 1, 279 | np.mean(actor_losses), 280 | np.mean(weight_losses), 281 | np.mean(accuracies), 282 | score, 283 | ] 284 | ) 285 | 286 | f_tr.close() 287 | f_te.close() 288 | 289 | torch.save( 290 | actor.module.state_dict() 291 | if (torch.cuda.device_count() > 1) 292 | else actor.state_dict(), 293 | os.path.join(save_dir, save_tag + "_ep{}_actor.pth".format(epoch + 1),), 294 | ) 295 | torch.save( 296 | weight, os.path.join(save_dir, save_tag + "_ep{}_weight.pth".format(epoch + 1)), 297 | ) 298 | if len(x_total) > fixed_size: 299 | torch.save( 300 | sample_idx, 301 | os.path.join(save_dir, save_tag + "_ep{}_sample_idx.pth".format(epoch + 1)), 302 | ) 303 | 304 | torch.save( 305 | encoder.module.state_dict() 306 | if torch.cuda.device_count() > 1 307 | else encoder.state_dict(), 308 | os.path.join(save_dir, save_tag + "_ep{}_encoder.pth".format(epoch + 1)), 309 | ) 310 | torch.save( 311 | quantizer.module.state_dict() 312 | if torch.cuda.device_count() > 1 313 | else quantizer.state_dict(), 314 | os.path.join(save_dir, save_tag + "_ep{}_quantizer.pth".format(epoch + 1)), 315 | ) 316 | 317 | 318 | if __name__ == "__main__": 319 | parser = argparse.ArgumentParser() 320 | 321 | # Seed & Env 322 | parser.add_argument("--seed", default=1, type=int) 323 | parser.add_argument("--env", default="Pong", type=str) 324 | parser.add_argument("--datapath", default="/data", type=str) 325 | parser.add_argument("--num_data", default=50000, type=int) 326 | parser.add_argument("--stack", default=1, type=int) 327 | parser.add_argument("--normal", action="store_true", default=False) 328 | parser.add_argument("--normal_eval", action="store_true", default=False) 329 | 330 | # Save & Evaluation 331 | parser.add_argument("--save_interval", default=20, type=int) 332 | parser.add_argument("--eval_interval", default=20, type=int) 333 | parser.add_argument("--num_episodes", default=None, type=int) 334 | parser.add_argument("--num_eval_episodes", default=20, type=int) 335 | parser.add_argument("--n_epochs", default=1000, type=int) 336 | parser.add_argument("--add_path", default=None, type=str) 337 | 338 | # Encoder & Hyperparams 339 | parser.add_argument("--embedding_dim", default=64, type=int) 340 | parser.add_argument("--num_embeddings", default=512, type=int) 341 | parser.add_argument("--num_hiddens", default=128, type=int) 342 | parser.add_argument("--num_residual_layers", default=2, type=int) 343 | parser.add_argument("--num_residual_hiddens", default=32, type=int) 344 | parser.add_argument("--batch_size", default=1024, type=int) 345 | parser.add_argument("--lr", default=3e-4, type=float) 346 | 347 | # Model load 348 | parser.add_argument("--vqvae_path", default=None, type=str) 349 | # For MLP 350 | parser.add_argument("--z_dim", default=256, type=int) 351 | # For CRLR 352 | parser.add_argument("--lmd", default=1e-1, type=float) 353 | parser.add_argument("--num_sub_iters", default=50, type=int) 354 | parser.add_argument("--fixed_size", default=None, type=int) 355 | parser.add_argument("--idx_path", default=None, type=str) 356 | 357 | args = parser.parse_args() 358 | if args.normal: 359 | assert args.normal_eval 360 | else: 361 | assert not args.normal_eval 362 | 363 | train(args) 364 | -------------------------------------------------------------------------------- /atari_vqvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import torch.nn as nn 3 | import tensorflow as tf 4 | import numpy as np 5 | from tqdm import tqdm 6 | import os 7 | import logging 8 | import csv 9 | import random 10 | 11 | logging.basicConfig(level=logging.INFO) 12 | import argparse 13 | 14 | from linear_models import ( 15 | VectorQuantizer, 16 | Encoder, 17 | Decoder, 18 | VQVAEModel, 19 | ) 20 | from utils import ( 21 | load_dataset, 22 | set_seed_everywhere, 23 | ) 24 | 25 | gfile = tf.io.gfile 26 | 27 | 28 | def train(args): 29 | device = torch.device("cuda") 30 | 31 | torch.backends.cudnn.benchmark = False 32 | torch.backends.cudnn.deterministic = True 33 | set_seed_everywhere(args.seed) 34 | 35 | ## fixed dataset 36 | observations, actions, data_variance = load_dataset( 37 | args.env, 38 | 1, 39 | args.datapath, 40 | args.normal, 41 | args.num_data, 42 | args.stack, 43 | args.num_episodes, 44 | ) 45 | 46 | ## Stage 1 47 | logging.info("Building models..") 48 | logging.info("Start stage 1...") 49 | encoder = Encoder( 50 | args.stack, 51 | args.embedding_dim, 52 | args.num_hiddens, 53 | args.num_residual_layers, 54 | args.num_residual_hiddens, 55 | ) 56 | decoder = Decoder( 57 | args.stack, 58 | args.embedding_dim, 59 | args.num_hiddens, 60 | args.num_residual_layers, 61 | args.num_residual_hiddens, 62 | ) 63 | quantizer = VectorQuantizer( 64 | args.embedding_dim, args.num_embeddings, args.commitment_cost, 65 | ) 66 | vqvae = VQVAEModel(encoder, decoder, quantizer).to(device) 67 | 68 | n_batch = len(observations) // args.batch_size + 1 69 | total_idxs = list(range(len(observations))) 70 | 71 | logging.info("Training starts..") 72 | 73 | save_dir = "models_vqvae" 74 | if args.num_episodes is None: 75 | save_tag = "{}_s{}_data{}k_con{}_seed{}_ne{}_c{}".format( 76 | args.env, 77 | args.stack, 78 | int(args.num_data / 1000), 79 | 1 - int(args.normal), 80 | args.seed, 81 | args.num_embeddings, 82 | args.commitment_cost, 83 | ) 84 | else: 85 | save_tag = "{}_s{}_epi{}_con{}_seed{}_ne{}_c{}".format( 86 | args.env, 87 | args.stack, 88 | int(args.num_episodes), 89 | 1 - int(args.normal), 90 | args.seed, 91 | args.num_embeddings, 92 | args.commitment_cost, 93 | ) 94 | 95 | if args.add_path is not None: 96 | save_dir = save_dir + "_" + args.add_path 97 | if not os.path.exists(save_dir): 98 | os.makedirs(save_dir) 99 | 100 | ## Multi-GPU 101 | if torch.cuda.device_count() > 1: 102 | vqvae = nn.DataParallel(vqvae) 103 | vqvae_optimizer = torch.optim.Adam(vqvae.parameters(), lr=args.lr) 104 | 105 | f = open(os.path.join(save_dir, save_tag + "_vqvae_train.csv"), "w") 106 | writer = csv.writer(f) 107 | writer.writerow(["Epoch", "Recon Error", "VQ Loss"]) 108 | for epoch in tqdm(range(args.n_epochs)): 109 | random.shuffle(total_idxs) 110 | recon_errors = [] 111 | vq_losses = [] 112 | vqvae.train() 113 | for j in range(n_batch): 114 | batch_idxs = total_idxs[j * args.batch_size : (j + 1) * args.batch_size] 115 | xx = torch.as_tensor( 116 | observations[batch_idxs], device=device, dtype=torch.float32 117 | ) 118 | xx = xx / 255.0 119 | 120 | vqvae_optimizer.zero_grad() 121 | 122 | z, x_recon, vq_loss, quantized, _ = vqvae(xx) 123 | vq_loss = vq_loss.mean() 124 | recon_error = torch.mean((x_recon - xx) ** 2) / data_variance 125 | loss = recon_error + vq_loss 126 | loss.backward() 127 | 128 | vqvae_optimizer.step() 129 | 130 | recon_errors.append(recon_error.mean().detach().cpu().item()) 131 | vq_losses.append(vq_loss.mean().detach().cpu().item()) 132 | logging.info( 133 | "(Train) Epoch {} | Recon Error: {:.4f} | VQ Loss: {:.4f}".format( 134 | epoch + 
1, np.mean(recon_errors), np.mean(vq_losses) 135 | ) 136 | ) 137 | writer.writerow([epoch + 1, np.mean(recon_errors), np.mean(vq_losses)]) 138 | 139 | if (epoch + 1) % args.save_interval == 0: 140 | torch.save( 141 | vqvae.module.state_dict() 142 | if (torch.cuda.device_count() > 1) 143 | else vqvae.state_dict(), 144 | os.path.join(save_dir, save_tag + "_ep{}_vqvae.pth".format(epoch + 1)), 145 | ) 146 | f.close() 147 | 148 | 149 | if __name__ == "__main__": 150 | parser = argparse.ArgumentParser() 151 | 152 | # Seed & Env 153 | parser.add_argument("--seed", default=1, type=int) 154 | parser.add_argument("--env", default="Pong", type=str) 155 | parser.add_argument("--datapath", default="/data", type=str) 156 | parser.add_argument("--num_data", default=50000, type=int) 157 | parser.add_argument("--stack", default=1, type=int) 158 | parser.add_argument("--normal", action="store_true", default=False) 159 | 160 | # Save & Evaluation 161 | parser.add_argument("--save_interval", default=100, type=int) 162 | parser.add_argument("--num_episodes", default=None, type=int) 163 | parser.add_argument("--n_epochs", default=1000, type=int) 164 | parser.add_argument("--add_path", default=None, type=str) 165 | 166 | # VQVAE & Hyperparams 167 | parser.add_argument("--embedding_dim", default=64, type=int) 168 | parser.add_argument("--num_embeddings", default=512, type=int) 169 | parser.add_argument("--num_hiddens", default=128, type=int) 170 | parser.add_argument("--num_residual_layers", default=2, type=int) 171 | parser.add_argument("--num_residual_hiddens", default=32, type=int) 172 | parser.add_argument("--commitment_cost", default=0.25, type=float) 173 | parser.add_argument("--batch_size", default=1024, type=int) 174 | parser.add_argument("--lr", default=3e-4, type=float) 175 | 176 | args = parser.parse_args() 177 | 178 | train(args) 179 | -------------------------------------------------------------------------------- /atari_vqvae_oreo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import tensorflow as tf 4 | import numpy as np 5 | from tqdm import tqdm 6 | import gzip 7 | import os 8 | import logging 9 | import csv 10 | import random 11 | import copy 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | import pickle 15 | import argparse 16 | import matplotlib.pyplot as plt 17 | 18 | from PIL import Image, ImageFont, ImageDraw 19 | 20 | from linear_models import Encoder, VectorQuantizer, weight_init 21 | from utils import ( 22 | load_dataset, 23 | evaluate, 24 | set_seed_everywhere, 25 | ) 26 | from dopamine.discrete_domains.atari_lib import create_atari_environment 27 | 28 | gfile = tf.io.gfile 29 | 30 | 31 | def train(args): 32 | device = torch.device("cuda") 33 | 34 | torch.backends.cudnn.benchmark = False 35 | torch.backends.cudnn.deterministic = True 36 | set_seed_everywhere(args.seed) 37 | 38 | ## fixed dataset 39 | observations, actions, data_variance = load_dataset( 40 | args.env, 41 | 1, 42 | args.datapath, 43 | args.normal, 44 | args.num_data, 45 | args.stack, 46 | args.num_episodes, 47 | ) 48 | 49 | ## Stage 1 50 | logging.info("Building models..") 51 | logging.info("Start stage 1...") 52 | 53 | env = create_atari_environment(args.env) 54 | action_dim = env.action_space.n 55 | 56 | n_batch = len(observations) // args.batch_size + 1 57 | total_idxs = list(range(len(observations))) 58 | 59 | logging.info("Training starts..") 60 | 61 | save_dir = "models_vqvae_cnn_actor" 62 | if args.num_episodes is None: 63 | 
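# The tag below encodes the run configuration in the checkpoint/CSV file names:
# env, frame stack, dataset size (or episode count), confounded flag, seed, and
# codebook size.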
save_tag = "{}_s{}_data{}k_con{}_seed{}_ne{}".format( 64 | args.env, 65 | args.stack, 66 | int(args.num_data / 1000), 67 | 1 - int(args.normal), 68 | args.seed, 69 | args.num_embeddings, 70 | ) 71 | else: 72 | save_tag = "{}_s{}_epi{}_con{}_seed{}_ne{}".format( 73 | args.env, 74 | args.stack, 75 | int(args.num_episodes), 76 | 1 - int(args.normal), 77 | args.seed, 78 | args.num_embeddings, 79 | ) 80 | 81 | assert args.vqvae_path is not None 82 | save_dir = save_dir + "_oreo" 83 | save_tag = save_tag + "_prob{}".format(args.prob) 84 | 85 | if args.num_mask > 1: 86 | save_dir = save_dir + "_train_mask{}".format(args.num_mask) 87 | 88 | if args.add_path is not None: 89 | save_dir = save_dir + "_" + args.add_path 90 | if not os.path.exists(save_dir): 91 | os.makedirs(save_dir) 92 | 93 | encoder = Encoder( 94 | args.stack, 95 | args.embedding_dim, 96 | args.num_hiddens, 97 | args.num_residual_layers, 98 | args.num_residual_hiddens, 99 | ).to(device) 100 | quantizer = VectorQuantizer(args.embedding_dim, args.num_embeddings, 0.25).to( 101 | device 102 | ) 103 | pre_actor = nn.Sequential( 104 | nn.Flatten(start_dim=1), nn.Linear(8 * 8 * args.embedding_dim, args.z_dim) 105 | ) 106 | actor = nn.Sequential( 107 | nn.Linear(args.z_dim, args.z_dim), nn.ReLU(), nn.Linear(args.z_dim, action_dim), 108 | ) 109 | pre_actor.apply(weight_init) 110 | pre_actor.to(device) 111 | actor.apply(weight_init) 112 | actor.to(device) 113 | 114 | for p in quantizer.parameters(): 115 | p.requires_grad = False 116 | vqvae_dict = torch.load(args.vqvae_path, map_location="cpu") 117 | encoder.load_state_dict( 118 | {k[9:]: v for k, v in vqvae_dict.items() if "_encoder" in k} 119 | ) 120 | quantizer.load_state_dict( 121 | {k[11:]: v for k, v in vqvae_dict.items() if "_quantizer" in k} 122 | ) 123 | 124 | criterion = nn.CrossEntropyLoss() 125 | scores = [] 126 | logging.info("Training starts..") 127 | f_tr = open(os.path.join(save_dir, save_tag + "_cnn_train.csv"), "w") 128 | writer_tr = csv.writer(f_tr) 129 | writer_tr.writerow(["Epoch", "Loss", "Accuracy"]) 130 | 131 | f_te = open(os.path.join(save_dir, save_tag + "_cnn_eval.csv"), "w") 132 | writer_te = csv.writer(f_te) 133 | writer_te.writerow(["Epoch", "Loss", "Accuracy", "Score"]) 134 | 135 | encoder.eval() 136 | quantizer.eval() 137 | total_encoding_indices = [] 138 | with torch.no_grad(): 139 | for j in range(n_batch): 140 | batch_idxs = total_idxs[j * args.batch_size : (j + 1) * args.batch_size] 141 | xx = torch.as_tensor( 142 | observations[batch_idxs], device=device, dtype=torch.float32 143 | ) 144 | xx = xx / 255.0 145 | 146 | z = encoder(xx) 147 | z, *_, encoding_indices, _ = quantizer(z) 148 | total_encoding_indices.append(encoding_indices.cpu()) 149 | total_encoding_indices = torch.cat(total_encoding_indices, dim=0) 150 | 151 | del quantizer 152 | 153 | actor_optimizer = torch.optim.Adam( 154 | list(encoder.parameters()) 155 | + list(pre_actor.parameters()) 156 | + list(actor.parameters()), 157 | lr=args.lr, 158 | ) 159 | 160 | ## Multi-GPU 161 | if torch.cuda.device_count() > 1: 162 | encoder = nn.DataParallel(encoder) 163 | pre_actor = nn.DataParallel(pre_actor) 164 | actor = nn.DataParallel(actor) 165 | 166 | for epoch in tqdm(range(args.n_epochs)): 167 | encoder.train() 168 | actor.train() 169 | random.shuffle(total_idxs) 170 | actor_losses = [] 171 | accuracies = [] 172 | for j in range(n_batch): 173 | batch_idxs = total_idxs[j * args.batch_size : (j + 1) * args.batch_size] 174 | xx = torch.as_tensor( 175 | observations[batch_idxs], device=device, 
dtype=torch.float32 176 | ) 177 | xx = xx / 255.0 178 | batch_act = torch.as_tensor(actions[batch_idxs], device=device).long() 179 | 180 | actor_optimizer.zero_grad() 181 | 182 | z = encoder(xx) 183 | with torch.no_grad(): 184 | encoding_indices = total_encoding_indices[batch_idxs].to( 185 | device 186 | ) # B x 64 187 | prob = torch.ones(xx.shape[0] * args.num_mask, args.num_embeddings) * ( 188 | 1 - args.prob 189 | ) 190 | code_mask = torch.bernoulli(prob).to(device) # B x 512 191 | 192 | ## one-hot encoding 193 | encoding_indices_flatten = encoding_indices.view(-1) # (Bx64) 194 | encoding_indices_onehot = torch.zeros( 195 | (len(encoding_indices_flatten), args.num_embeddings), 196 | device=encoding_indices_flatten.device, 197 | ) 198 | encoding_indices_onehot.scatter_( 199 | 1, encoding_indices_flatten.unsqueeze(1), 1 200 | ) 201 | encoding_indices_onehot = encoding_indices_onehot.view( 202 | xx.shape[0], -1, args.num_embeddings 203 | ) # B x 64 x 512 204 | 205 | mask = ( 206 | code_mask.unsqueeze(1) 207 | * torch.cat( 208 | [encoding_indices_onehot for m in range(args.num_mask)], dim=0 209 | ) 210 | ).sum(2) 211 | mask = mask.reshape(-1, 8, 8) 212 | 213 | z = torch.cat([z for m in range(args.num_mask)], dim=0) * mask.unsqueeze(1) 214 | z = z / (1.0 - args.prob) 215 | z = pre_actor(z) 216 | 217 | logits = actor(z) 218 | actor_loss = criterion( 219 | logits, torch.cat([batch_act for m in range(args.num_mask)], dim=0) 220 | ) 221 | 222 | actor_loss.backward() 223 | 224 | actor_optimizer.step() 225 | 226 | accuracy = (batch_act == logits[: xx.shape[0]].argmax(1)).float().mean() 227 | 228 | actor_losses.append(actor_loss.mean().detach().cpu().item()) 229 | accuracies.append(accuracy.mean().detach().cpu().item()) 230 | 231 | logging.info( 232 | "(Train) Epoch {} | Actor Loss: {:.4f} | Accuracy: {:.2f}".format( 233 | epoch + 1, np.mean(actor_losses), np.mean(accuracies), 234 | ) 235 | ) 236 | writer_tr.writerow( 237 | [epoch + 1, np.mean(actor_losses), np.mean(accuracies),] 238 | ) 239 | 240 | if (epoch + 1) % args.eval_interval == 0: 241 | actor.eval() 242 | encoder.eval() 243 | score = evaluate( 244 | env, 245 | pre_actor.module if torch.cuda.device_count() > 1 else pre_actor, 246 | actor.module if torch.cuda.device_count() > 1 else actor, 247 | encoder.module if torch.cuda.device_count() > 1 else encoder, 248 | "cnn", 249 | device, 250 | args, 251 | ) 252 | logging.info("(Eval) Epoch {} | Score: {:.2f}".format(epoch + 1, score,)) 253 | scores.append(score) 254 | actor.train() 255 | encoder.train() 256 | writer_te.writerow( 257 | [epoch + 1, np.mean(actor_losses), np.mean(accuracies), score] 258 | ) 259 | 260 | f_tr.close() 261 | f_te.close() 262 | torch.save( 263 | encoder.module.state_dict() 264 | if torch.cuda.device_count() > 1 265 | else encoder.state_dict(), 266 | os.path.join(save_dir, save_tag + "_ep{}_encoder.pth".format(epoch + 1)), 267 | ) 268 | torch.save( 269 | actor.module.state_dict() 270 | if torch.cuda.device_count() > 1 271 | else actor.state_dict(), 272 | os.path.join(save_dir, save_tag + "_ep{}_actor.pth".format(epoch + 1),), 273 | ) 274 | torch.save( 275 | pre_actor.module.state_dict() 276 | if torch.cuda.device_count() > 1 277 | else pre_actor.state_dict(), 278 | os.path.join(save_dir, save_tag + "_ep{}_pre_actor.pth".format(epoch + 1),), 279 | ) 280 | 281 | 282 | if __name__ == "__main__": 283 | parser = argparse.ArgumentParser() 284 | 285 | # Seed & Env 286 | parser.add_argument("--seed", default=1, type=int) 287 | parser.add_argument("--env", default="Pong", 
type=str) 288 | parser.add_argument("--datapath", default="/data", type=str) 289 | parser.add_argument("--num_data", default=50000, type=int) 290 | parser.add_argument("--stack", default=1, type=int) 291 | parser.add_argument("--normal", action="store_true", default=False) 292 | parser.add_argument("--normal_eval", action="store_true", default=False) 293 | 294 | # Save & Evaluation 295 | parser.add_argument("--save_interval", default=20, type=int) 296 | parser.add_argument("--eval_interval", default=20, type=int) 297 | parser.add_argument("--num_episodes", default=None, type=int) 298 | parser.add_argument("--num_eval_episodes", default=20, type=int) 299 | parser.add_argument("--n_epochs", default=1000, type=int) 300 | parser.add_argument("--add_path", default=None, type=str) 301 | 302 | # Encoder & Hyperparams 303 | parser.add_argument("--embedding_dim", default=64, type=int) 304 | parser.add_argument("--num_embeddings", default=512, type=int) 305 | parser.add_argument("--num_hiddens", default=128, type=int) 306 | parser.add_argument("--num_residual_layers", default=2, type=int) 307 | parser.add_argument("--num_residual_hiddens", default=32, type=int) 308 | parser.add_argument("--batch_size", default=1024, type=int) 309 | parser.add_argument("--lr", default=3e-4, type=float) 310 | 311 | # Model load 312 | parser.add_argument("--vqvae_path", default=None, type=str) 313 | # For MLP 314 | parser.add_argument("--z_dim", default=256, type=int) 315 | # For dropout 316 | parser.add_argument("--prob", default=0.5, type=float) 317 | parser.add_argument("--num_mask", default=1, type=int) 318 | 319 | args = parser.parse_args() 320 | if args.normal: 321 | assert args.normal_eval 322 | else: 323 | assert not args.normal_eval 324 | 325 | train(args) 326 | -------------------------------------------------------------------------------- /coordconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.modules.conv as conv 4 | 5 | 6 | class AddCoords(nn.Module): 7 | def __init__(self, rank, with_r=False, use_cuda=True): 8 | super(AddCoords, self).__init__() 9 | self.rank = rank 10 | self.with_r = with_r 11 | self.use_cuda = use_cuda 12 | 13 | def forward(self, input_tensor): 14 | """ 15 | :param input_tensor: shape (N, C_in, W), (N, C_in, H, W), or (N, C_in, D, H, W) depending on rank 16 | :return: input tensor with normalized coordinate channels appended 17 | """ 18 | if self.rank == 1: 19 | batch_size_shape, channel_in_shape, dim_x = input_tensor.shape 20 | xx_range = torch.arange(dim_x, dtype=torch.int32) 21 | xx_channel = xx_range[None, None, :] 22 | 23 | xx_channel = xx_channel.float() / (dim_x - 1) 24 | xx_channel = xx_channel * 2 - 1 25 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1) 26 | 27 | if torch.cuda.is_available() and self.use_cuda: 28 | input_tensor = input_tensor.cuda() 29 | xx_channel = xx_channel.cuda() 30 | out = torch.cat([input_tensor, xx_channel], dim=1) 31 | 32 | if self.with_r: 33 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)) 34 | out = torch.cat([out, rr], dim=1) 35 | 36 | elif self.rank == 2: 37 | batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape 38 | xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32) 39 | yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32) 40 | 41 | xx_range = torch.arange(dim_y, dtype=torch.int32) 42 | yy_range = torch.arange(dim_x, dtype=torch.int32) 43 | xx_range = xx_range[None, None, :, None] 44 | yy_range = yy_range[None, None, :, None] 45 | 46 | xx_channel = torch.matmul(xx_range, xx_ones) 47 | yy_channel = torch.matmul(yy_range, yy_ones)
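# xx_channel / yy_channel hold integer row and column indices; they are normalized
# to [-1, 1] below and concatenated to the input as two extra channels (the
# CoordConv trick).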
48 | 49 | # transpose y 50 | yy_channel = yy_channel.permute(0, 1, 3, 2) 51 | 52 | xx_channel = xx_channel.float() / (dim_y - 1) 53 | yy_channel = yy_channel.float() / (dim_x - 1) 54 | 55 | xx_channel = xx_channel * 2 - 1 56 | yy_channel = yy_channel * 2 - 1 57 | 58 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1) 59 | yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1) 60 | 61 | if torch.cuda.is_available() and self.use_cuda: 62 | input_tensor = input_tensor.cuda() 63 | xx_channel = xx_channel.cuda() 64 | yy_channel = yy_channel.cuda() 65 | out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) 66 | 67 | if self.with_r: 68 | rr = torch.sqrt( 69 | torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2) 70 | ) 71 | out = torch.cat([out, rr], dim=1) 72 | 73 | elif self.rank == 3: 74 | batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = input_tensor.shape 75 | xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32) 76 | yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32) 77 | zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32) 78 | 79 | xy_range = torch.arange(dim_y, dtype=torch.int32) 80 | xy_range = xy_range[None, None, None, :, None] 81 | 82 | yz_range = torch.arange(dim_z, dtype=torch.int32) 83 | yz_range = yz_range[None, None, None, :, None] 84 | 85 | zx_range = torch.arange(dim_x, dtype=torch.int32) 86 | zx_range = zx_range[None, None, None, :, None] 87 | 88 | xy_channel = torch.matmul(xy_range, xx_ones) 89 | xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2) 90 | 91 | yz_channel = torch.matmul(yz_range, yy_ones) 92 | yz_channel = yz_channel.permute(0, 1, 3, 4, 2) 93 | yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4) 94 | 95 | zx_channel = torch.matmul(zx_range, zz_ones) 96 | zx_channel = zx_channel.permute(0, 1, 4, 2, 3) 97 | zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3) 98 | 99 | if torch.cuda.is_available() and self.use_cuda: 100 | input_tensor = input_tensor.cuda() 101 | xx_channel = xx_channel.cuda() 102 | yy_channel = yy_channel.cuda() 103 | zz_channel = zz_channel.cuda() 104 | out = torch.cat([input_tensor, xx_channel, yy_channel, zz_channel], dim=1) 105 | 106 | if self.with_r: 107 | rr = torch.sqrt( 108 | torch.pow(xx_channel - 0.5, 2) 109 | + torch.pow(yy_channel - 0.5, 2) 110 | + torch.pow(zz_channel - 0.5, 2) 111 | ) 112 | out = torch.cat([out, rr], dim=1) 113 | else: 114 | raise NotImplementedError 115 | 116 | return out 117 | 118 | 119 | class CoordConv1d(conv.Conv1d): 120 | def __init__( 121 | self, 122 | in_channels, 123 | out_channels, 124 | kernel_size, 125 | stride=1, 126 | padding=0, 127 | dilation=1, 128 | groups=1, 129 | bias=True, 130 | with_r=False, 131 | use_cuda=True, 132 | ): 133 | super(CoordConv1d, self).__init__( 134 | in_channels, 135 | out_channels, 136 | kernel_size, 137 | stride, 138 | padding, 139 | dilation, 140 | groups, 141 | bias, 142 | ) 143 | self.rank = 1 144 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 145 | self.conv = nn.Conv1d( 146 | in_channels + self.rank + int(with_r), 147 | out_channels, 148 | kernel_size, 149 | stride, 150 | padding, 151 | dilation, 152 | groups, 153 | bias, 154 | ) 155 | 156 | def forward(self, input_tensor): 157 | """ 158 | input_tensor shape: (N, C_in, W) 159 | output_tensor shape: (N, C_out, W_out) 160 | :return: CoordConv1d result 161 | """ 162 | out = self.addcoords(input_tensor) 163 | out = self.conv(out) 164 | 165 | return out 166 | 167 | 
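# Minimal usage sketch (illustrative values only, not taken from the training scripts):
# the CoordConv layers behave like their nn.Conv counterparts but first append
# normalized coordinate channels via AddCoords, e.g.
#   conv = CoordConv2d(3, 16, kernel_size=3, padding=1, use_cuda=False)
#   out = conv(torch.randn(2, 3, 64, 64))  # AddCoords widens the input to 5 channels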
168 | class CoordConv2d(conv.Conv2d): 169 | def __init__( 170 | self, 171 | in_channels, 172 | out_channels, 173 | kernel_size, 174 | stride=1, 175 | padding=0, 176 | dilation=1, 177 | groups=1, 178 | bias=True, 179 | with_r=False, 180 | use_cuda=True, 181 | ): 182 | super(CoordConv2d, self).__init__( 183 | in_channels, 184 | out_channels, 185 | kernel_size, 186 | stride, 187 | padding, 188 | dilation, 189 | groups, 190 | bias, 191 | ) 192 | self.rank = 2 193 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 194 | self.conv = nn.Conv2d( 195 | in_channels + self.rank + int(with_r), 196 | out_channels, 197 | kernel_size, 198 | stride, 199 | padding, 200 | dilation, 201 | groups, 202 | bias, 203 | ) 204 | 205 | def forward(self, input_tensor): 206 | """ 207 | input_tensor shape: (N, C_in, H, W) 208 | output_tensor shape: (N, C_out, H_out, W_out) 209 | :return: CoordConv2d result 210 | """ 211 | out = self.addcoords(input_tensor) 212 | out = self.conv(out) 213 | 214 | return out 215 | 216 | 217 | class CoordConv3d(conv.Conv3d): 218 | def __init__( 219 | self, 220 | in_channels, 221 | out_channels, 222 | kernel_size, 223 | stride=1, 224 | padding=0, 225 | dilation=1, 226 | groups=1, 227 | bias=True, 228 | with_r=False, 229 | use_cuda=True, 230 | ): 231 | super(CoordConv3d, self).__init__( 232 | in_channels, 233 | out_channels, 234 | kernel_size, 235 | stride, 236 | padding, 237 | dilation, 238 | groups, 239 | bias, 240 | ) 241 | self.rank = 3 242 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 243 | self.conv = nn.Conv3d( 244 | in_channels + self.rank + int(with_r), 245 | out_channels, 246 | kernel_size, 247 | stride, 248 | padding, 249 | dilation, 250 | groups, 251 | bias, 252 | ) 253 | 254 | def forward(self, input_tensor): 255 | """ 256 | input_tensor shape: (N, C_in, D, H, W) 257 | output_tensor shape: (N, C_out, D_out, H_out, W_out) 258 | :return: CoordConv3d result 259 | """ 260 | out = self.addcoords(input_tensor) 261 | out = self.conv(out) 262 | 263 | return out 264 | 265 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for i in Alien Amidar Assault Asterix BankHeist BattleZone Boxing Breakout ChopperCommand CrazyClimber \ 3 | DemonAttack Enduro Freeway Frostbite Gopher Hero Jamesbond Kangaroo Krull KungFuMaster \ 4 | MsPacman Pong PrivateEye Qbert RoadRunner Seaquest UpNDown 5 | do 6 | echo $i 7 | mkdir -p $i 8 | cd $i 9 | gsutil -m cp -n -R gs://atari-replay-datasets/dqn/$i/1 . 10 | cd .. 
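  # Only run 1 of each game's DQN Replay logs is fetched. Other runs appear to live
  # under gs://atari-replay-datasets/dqn/$i/2 etc.; e.g. one could additionally run:
  #   gsutil -m cp -n -R gs://atari-replay-datasets/dqn/$i/2 .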
11 | done -------------------------------------------------------------------------------- /linear_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from coordconv import AddCoords 5 | 6 | 7 | def weight_init(m): 8 | """Custom weight init for Conv2D and Linear layers.""" 9 | if isinstance(m, nn.Linear): 10 | nn.init.orthogonal_(m.weight.data) 11 | if hasattr(m.bias, "data"): 12 | m.bias.data.fill_(0.0) 13 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 14 | gain = nn.init.calculate_gain("relu") 15 | nn.init.orthogonal_(m.weight.data, gain) 16 | if hasattr(m.bias, "data"): 17 | m.bias.data.fill_(0.0) 18 | 19 | 20 | ## For CCIL 21 | class CoordConvEncoder(nn.Module): 22 | def __init__(self, input_channels, embedding_dim, ch_div=1): 23 | super(CoordConvEncoder, self).__init__() 24 | self.coordconv = AddCoords(2, with_r=False, use_cuda=True) 25 | self.conv1 = nn.Conv2d(input_channels + 2, 64 // ch_div, 5, stride=2) 26 | self.conv2 = nn.Conv2d(64 // ch_div, 128 // ch_div, 5, stride=2) 27 | self.conv3 = nn.Conv2d(128 // ch_div, 256 // ch_div, 5, stride=2) 28 | self.conv4 = nn.Conv2d(256 // ch_div, 512 // ch_div, 5, stride=2) 29 | 30 | self.fc = nn.Linear(512 // ch_div, embedding_dim) 31 | 32 | def forward(self, x): 33 | x = self.coordconv(x) # Bx3x64x64 34 | x = F.leaky_relu(self.conv1(x), 0.2) # Bx64x30x30 35 | x = F.leaky_relu(self.conv2(x), 0.2) # Bx128x13x13 36 | x = F.leaky_relu(self.conv3(x), 0.2) # Bx256x5x5 37 | x = F.leaky_relu(self.conv4(x), 0.2) # Bx512x1x1 38 | 39 | x = self.fc(torch.flatten(x, start_dim=1)) # BxN 40 | return x 41 | 42 | 43 | class CoordConvDecoder(nn.Module): 44 | def __init__(self, input_channels, embedding_dim, ch_div=1): 45 | super(CoordConvDecoder, self).__init__() 46 | self.coordconv = AddCoords(2, with_r=False, use_cuda=True) 47 | self.conv1 = nn.Conv2d(embedding_dim + 2, 512 // ch_div, 1) 48 | self.conv2 = nn.Conv2d(512 // ch_div, 256 // ch_div, 1) 49 | self.conv3 = nn.Conv2d(256 // ch_div, 256 // ch_div, 1) 50 | self.conv4 = nn.Conv2d(256 // ch_div, 128 // ch_div, 1) 51 | self.conv5 = nn.Conv2d(128 // ch_div, 64 // ch_div, 1) 52 | self.conv6 = nn.Conv2d(64 // ch_div, input_channels, 1) 53 | 54 | def forward(self, x): 55 | x = x.view(-1, x.shape[1], 1, 1) # BxNx1x1 56 | x = x.repeat(1, 1, 64, 64) # BxNx64x64 57 | x = self.coordconv(x) # Bx(N+2)x64x64 58 | x = F.relu(self.conv1(x)) # Bx512x64x64 59 | x = F.relu(self.conv2(x)) # Bx256x64x64 60 | x = F.relu(self.conv3(x)) # Bx256x64x64 61 | x = F.relu(self.conv4(x)) # Bx128x64x64 62 | x = F.relu(self.conv5(x)) # Bx64x64x64 63 | 64 | x = torch.sigmoid(self.conv6(x)) # Bx1x64x64 65 | return x 66 | 67 | 68 | class VectorQuantizer(nn.Module): 69 | def __init__(self, embedding_dim, num_embeddings, commitment_cost): 70 | super(VectorQuantizer, self).__init__() 71 | 72 | self._embedding_dim = embedding_dim 73 | self._num_embeddings = num_embeddings 74 | self._commitment_cost = commitment_cost 75 | 76 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) 77 | self._embedding.weight.data.uniform_( 78 | -1 / self._num_embeddings, 1 / self._num_embeddings 79 | ) 80 | 81 | def forward(self, inputs): 82 | # convert inputs from BCHW -> BHWC 83 | inputs = inputs.permute(0, 2, 3, 1).contiguous() 84 | input_shape = inputs.shape 85 | 86 | # Flatten input 87 | flat_input = inputs.view(-1, self._embedding_dim) 88 | 89 | # Calculate distances 90 | distances = ( 91 | 
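# Squared Euclidean distance ||z - e_k||^2 expanded as ||z||^2 + ||e_k||^2 - 2 <z, e_k>:
# flat_input is (B*H*W, embedding_dim), self._embedding.weight is (num_embeddings, embedding_dim),
# so distances has shape (B*H*W, num_embeddings).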
torch.sum(flat_input ** 2, dim=1, keepdim=True) 92 | # + torch.sum(self._embedding.weight**2, dim=1) 93 | + torch.sum(self._embedding.weight.t() ** 2, dim=0, keepdim=True) 94 | - 2 * torch.matmul(flat_input, self._embedding.weight.t()) 95 | ) 96 | 97 | # Encoding 98 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) 99 | encodings = torch.zeros( 100 | encoding_indices.shape[0], self._num_embeddings, device=inputs.device 101 | ) 102 | encodings.scatter_(1, encoding_indices, 1) 103 | 104 | # Quantize and unflatten 105 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) 106 | 107 | # Loss 108 | e_latent_loss = F.mse_loss(quantized.detach(), inputs, reduction="none").mean( 109 | (1, 2, 3) 110 | ) 111 | q_latent_loss = F.mse_loss(quantized, inputs.detach(), reduction="none").mean( 112 | (1, 2, 3) 113 | ) 114 | loss = q_latent_loss + self._commitment_cost * e_latent_loss 115 | 116 | # Straight Through Estimator 117 | quantized = inputs + (quantized - inputs).detach() 118 | avg_probs = torch.mean(encodings, dim=0) 119 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 120 | 121 | # convert quantized from BHWC -> BCHW 122 | quantized = quantized.permute(0, 3, 1, 2).contiguous() 123 | encoding_indices = encoding_indices.view(input_shape[0], -1) 124 | return quantized, loss, perplexity, encodings, encoding_indices, distances 125 | 126 | 127 | class Residual(nn.Module): 128 | def __init__(self, in_channels, num_hiddens, num_residual_hiddens): 129 | super(Residual, self).__init__() 130 | self._block = nn.Sequential( 131 | nn.ReLU(True), 132 | nn.Conv2d( 133 | in_channels=in_channels, 134 | out_channels=num_residual_hiddens, 135 | kernel_size=3, 136 | stride=1, 137 | padding=1, 138 | bias=False, 139 | ), 140 | nn.ReLU(True), 141 | nn.Conv2d( 142 | in_channels=num_residual_hiddens, 143 | out_channels=num_hiddens, 144 | kernel_size=1, 145 | stride=1, 146 | bias=False, 147 | ), 148 | ) 149 | 150 | def forward(self, x): 151 | return x + self._block(x) 152 | 153 | 154 | class ResidualStack(nn.Module): 155 | def __init__( 156 | self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens 157 | ): 158 | super(ResidualStack, self).__init__() 159 | self._num_residual_layers = num_residual_layers 160 | self._layers = nn.ModuleList( 161 | [ 162 | Residual(in_channels, num_hiddens, num_residual_hiddens) 163 | for _ in range(self._num_residual_layers) 164 | ] 165 | ) 166 | 167 | def forward(self, x): 168 | for i in range(self._num_residual_layers): 169 | x = self._layers[i](x) 170 | return F.relu(x) 171 | 172 | 173 | class Encoder(nn.Module): 174 | def __init__( 175 | self, 176 | input_channels, 177 | embedding_dim, 178 | num_hiddens, 179 | num_residual_layers, 180 | num_residual_hiddens, 181 | ): 182 | super(Encoder, self).__init__() 183 | self._input_channels = input_channels 184 | self._embedding_dim = embedding_dim 185 | self._num_hiddens = num_hiddens 186 | self._num_residual_layers = num_residual_layers 187 | self._num_residual_hiddens = num_residual_hiddens 188 | 189 | ## 42 x 42 190 | self._conv_1 = nn.Conv2d( 191 | in_channels=input_channels, 192 | out_channels=num_hiddens // 4, 193 | kernel_size=4, 194 | stride=2, 195 | padding=1, 196 | ) 197 | ## 21 x 21 198 | self._conv_2 = nn.Conv2d( 199 | in_channels=num_hiddens // 4, 200 | out_channels=num_hiddens // 2, 201 | kernel_size=4, 202 | stride=2, 203 | padding=1, 204 | ) 205 | ## 10 x 10 206 | self._conv_3 = nn.Conv2d( 207 | in_channels=num_hiddens // 2, 208 | 
out_channels=num_hiddens, 209 | kernel_size=4, 210 | stride=2, 211 | padding=1, 212 | ) 213 | ## 8 x 8 214 | self._conv_4 = nn.Conv2d( 215 | in_channels=num_hiddens, 216 | out_channels=num_hiddens, 217 | kernel_size=3, 218 | stride=1, 219 | padding=0, 220 | ) 221 | ## 8 x 8 222 | self._residual_stack = ResidualStack( 223 | in_channels=num_hiddens, 224 | num_hiddens=num_hiddens, 225 | num_residual_layers=num_residual_layers, 226 | num_residual_hiddens=num_residual_hiddens, 227 | ) 228 | ## 8 x 8 229 | self._conv_5 = nn.Conv2d( 230 | in_channels=num_hiddens, out_channels=embedding_dim, kernel_size=1, stride=1 231 | ) 232 | self.apply(weight_init) 233 | 234 | def forward(self, inputs): 235 | x = self._conv_1(inputs) 236 | x = F.relu(x) 237 | x = self._conv_2(x) 238 | x = F.relu(x) 239 | x = self._conv_3(x) 240 | x = F.relu(x) 241 | 242 | x = self._conv_4(x) 243 | x = self._residual_stack(x) 244 | return self._conv_5(x) 245 | 246 | 247 | class Decoder(nn.Module): 248 | def __init__( 249 | self, 250 | out_channels, 251 | embedding_dim, 252 | num_hiddens, 253 | num_residual_layers, 254 | num_residual_hiddens, 255 | ): 256 | super(Decoder, self).__init__() 257 | self._out_channels = out_channels 258 | self._embedding_dim = embedding_dim 259 | self._num_hiddens = num_hiddens 260 | self._num_residual_layers = num_residual_layers 261 | self._num_residual_hiddens = num_residual_hiddens 262 | 263 | ## 8 x 8 264 | self._conv_1 = nn.Conv2d( 265 | in_channels=embedding_dim, 266 | out_channels=num_hiddens, 267 | kernel_size=3, 268 | stride=1, 269 | padding=1, 270 | ) 271 | ## 8 x 8 272 | self._residual_stack = ResidualStack( 273 | in_channels=num_hiddens, 274 | num_hiddens=num_hiddens, 275 | num_residual_layers=num_residual_layers, 276 | num_residual_hiddens=num_residual_hiddens, 277 | ) 278 | ## 10 x 10 279 | self._conv_trans_1 = nn.ConvTranspose2d( 280 | in_channels=num_hiddens, out_channels=num_hiddens, kernel_size=3, stride=1, 281 | ) 282 | ## 21 x 21 283 | self._conv_trans_2 = nn.ConvTranspose2d( 284 | in_channels=num_hiddens, 285 | out_channels=num_hiddens // 2, 286 | kernel_size=4, 287 | stride=2, 288 | padding=1, 289 | output_padding=1, 290 | ) 291 | ## 42 x 42 292 | self._conv_trans_3 = nn.ConvTranspose2d( 293 | in_channels=num_hiddens // 2, 294 | out_channels=num_hiddens // 4, 295 | kernel_size=4, 296 | stride=2, 297 | padding=1, 298 | ) 299 | ## 84 x 84 300 | self._conv_trans_4 = nn.ConvTranspose2d( 301 | in_channels=num_hiddens // 4, 302 | out_channels=out_channels, 303 | kernel_size=4, 304 | stride=2, 305 | padding=1, 306 | ) 307 | self.apply(weight_init) 308 | 309 | def forward(self, inputs): 310 | x = self._conv_1(inputs) 311 | 312 | x = self._residual_stack(x) 313 | 314 | x = self._conv_trans_1(x) 315 | x = F.relu(x) 316 | 317 | x = self._conv_trans_2(x) 318 | x = F.relu(x) 319 | 320 | x = self._conv_trans_3(x) 321 | x = F.relu(x) 322 | return self._conv_trans_4(x) 323 | 324 | 325 | class VQVAEModel(nn.Module): 326 | def __init__(self, encoder, decoder, quantizer): 327 | super(VQVAEModel, self).__init__() 328 | self._encoder = encoder 329 | self._decoder = decoder 330 | self._quantizer = quantizer 331 | 332 | def forward(self, x, encode_only=False): 333 | z = self._encoder(x) 334 | quantized, vq_loss, _, _, encoding_indices, _ = self._quantizer(z) 335 | if encode_only: 336 | x_recon = None 337 | else: 338 | x_recon = self._decoder(quantized) 339 | return z, x_recon, vq_loss, quantized, encoding_indices 340 | 341 | 342 | class CoordConvBetaVAE(nn.Module): 343 | def __init__(self, z_dim=32, 
ch_div=1): 344 | super(CoordConvBetaVAE, self).__init__() 345 | 346 | self.encoder = CoordConvEncoder(1, z_dim * 2, ch_div) 347 | self.decoder = CoordConvDecoder(1, z_dim, ch_div) 348 | 349 | def encode(self, x): 350 | x = self.encoder(x) 351 | mu, logvar = x.chunk(2, dim=-1) 352 | sigma = torch.exp(logvar / 2.0) 353 | epsilon = torch.randn_like(mu) 354 | z = mu + sigma * epsilon 355 | return z, mu, sigma 356 | 357 | def decode(self, x): 358 | x = self.decoder(x) 359 | return x 360 | 361 | ## for dataparallel 362 | def forward(self, x, mode="encode"): 363 | if mode == "encode": 364 | return self.encode(x) 365 | elif mode == "decode": 366 | return self.decode(x) 367 | else: 368 | raise NotImplementedError 369 | 370 | 371 | def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None): 372 | if hidden_depth == 0: 373 | mods = [nn.Linear(input_dim, output_dim)] 374 | else: 375 | mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] 376 | for i in range(hidden_depth - 1): 377 | mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] 378 | mods.append(nn.Linear(hidden_dim, output_dim)) 379 | if output_mod is not None: 380 | mods.append(output_mod) 381 | trunk = nn.Sequential(*mods) 382 | return trunk 383 | 384 | 385 | class PreActor(nn.Module): 386 | """Linear head that flattens features and projects them to a latent vector.""" 387 | 388 | def __init__(self, z_dim, out_dim): 389 | super().__init__() 390 | self.trunk = mlp(z_dim, None, out_dim, 0) 391 | self.apply(weight_init) 392 | 393 | def forward(self, h): 394 | h = torch.flatten(h, start_dim=1) 395 | logits = self.trunk(h) 396 | return logits 397 | 398 | 399 | class Actor(nn.Module): 400 | """MLP policy head that maps latent features to discrete-action logits.""" 401 | 402 | def __init__(self, z_dim, action_dim, hidden_dim, hidden_depth): 403 | super().__init__() 404 | self.trunk = mlp(z_dim, hidden_dim, action_dim, hidden_depth) 405 | self.apply(weight_init) 406 | 407 | def forward(self, h): 408 | logits = self.trunk(h) 409 | return logits 410 | 411 | 412 | class Projector(nn.Module): 413 | """MLP projection head mapping latent features to an output embedding.""" 414 | 415 | def __init__(self, z_dim, out_dim, hidden_dim, hidden_depth): 416 | super().__init__() 417 | self.trunk = mlp(z_dim, hidden_dim, out_dim, hidden_depth) 418 | self.apply(weight_init) 419 | 420 | def forward(self, h): 421 | outputs = self.trunk(h) 422 | return outputs 423 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | import logging 4 | import random 5 | from tqdm import tqdm 6 | from collections import deque 7 | import copy 8 | import time 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import tensorflow as tf 15 | import gym 16 | 17 | logging.basicConfig(level=logging.INFO) 18 | import pickle 19 | 20 | from PIL import Image, ImageFont, ImageDraw 21 | from sklearn.linear_model import Ridge 22 | from torch.distributions import Bernoulli 23 | import kornia 24 | 25 | gfile = tf.io.gfile 26 | 27 | 28 | def set_seed_everywhere(seed): 29 | torch.manual_seed(seed) 30 | if torch.cuda.is_available(): 31 | torch.cuda.manual_seed_all(seed) 32 | np.random.seed(seed) 33 | random.seed(seed) 34 | 35 | 36 | def load_dataset(env, seed, datapath, normal, num_data, stack, num_episodes=None): 37 | try: 38 | if num_episodes is not None: 39 | 
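# A preprocessed copy of the dataset (frame-stacked and, unless normal=True,
# confounded) is pickled on first use; this branch reloads that cache, and the
# except branch below rebuilds everything from the raw DQN Replay logs.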
path = os.path.join( 40 | datapath, 41 | env, 42 | str(seed), 43 | "replay_logs", 44 | "saved_episodes_{}_normal{}.pkl".format(int(num_episodes), int(normal)), 45 | ) 46 | else: 47 | path = os.path.join( 48 | datapath, 49 | env, 50 | str(seed), 51 | "replay_logs", 52 | "saved_dataset_{}_normal{}.pkl".format(int(num_data), int(normal)), 53 | ) 54 | with open(path, "rb") as f: 55 | observations, actions, data_variance = pickle.load(f) 56 | except Exception as e: 57 | print(e) 58 | path = os.path.join(datapath, env, str(seed), "replay_logs") 59 | ckpts = gfile.listdir(path) 60 | observation_lists = [os.path.join(path, p) for p in ckpts if "observation" in p] 61 | observation_lists = sorted( 62 | observation_lists, key=lambda s: int(s.split(".")[-2]) 63 | ) 64 | action_lists = [os.path.join(path, p) for p in ckpts if "action" in p] 65 | action_lists = sorted(action_lists, key=lambda s: int(s.split(".")[-2])) 66 | terminal_lists = [os.path.join(path, p) for p in ckpts if "terminal" in p] 67 | terminal_lists = sorted(terminal_lists, key=lambda s: int(s.split(".")[-2])) 68 | 69 | logging.info("Loading observations..") 70 | o_ckpt = observation_lists[-1] 71 | with tf.io.gfile.GFile(o_ckpt, "rb") as f: 72 | with gzip.GzipFile(fileobj=f) as infile: 73 | obs_chunk = np.load(infile, allow_pickle=False) 74 | 75 | logging.info("Loading actions..") 76 | a_ckpt = action_lists[-1] 77 | with tf.io.gfile.GFile(a_ckpt, "rb") as f: 78 | with gzip.GzipFile(fileobj=f) as infile: 79 | act_chunk = np.load(infile, allow_pickle=False) 80 | logging.info("Loading terminals..") 81 | t_ckpt = terminal_lists[-1] 82 | with tf.io.gfile.GFile(t_ckpt, "rb") as f: 83 | with gzip.GzipFile(fileobj=f) as infile: 84 | terminal_chunk = np.load(infile, allow_pickle=False) 85 | 86 | if num_episodes is not None: 87 | cut_idxs = np.where(terminal_chunk != 0)[0] + 1 88 | # list of episodes 89 | observations = np.split(obs_chunk, cut_idxs)[1:-1] 90 | actions = np.split(act_chunk, cut_idxs)[1:-1] 91 | terminals = np.split(terminal_chunk, cut_idxs)[1:-1] 92 | 93 | total_episodes = len(observations) 94 | num_episodes = min(int(num_episodes), total_episodes) 95 | logging.info("Number of episodes: {}".format(num_episodes)) 96 | observations = observations[: int(num_episodes)] 97 | actions = actions[: int(num_episodes)] 98 | terminals = terminals[: int(num_episodes)] 99 | 100 | observations = np.concatenate(observations, 0) 101 | actions = np.concatenate(actions, 0) 102 | terminals = np.concatenate(terminals, 0) 103 | logging.info("Number of frames: {}".format(len(observations))) 104 | 105 | data_variance = np.var( 106 | observations[: min(len(observations), 100000)] / 255.0 107 | ) 108 | else: 109 | observations = obs_chunk[: int(num_data)] 110 | actions = act_chunk[: int(num_data)] 111 | terminals = terminal_chunk[: int(num_data)] 112 | 113 | data_variance = np.var(observations[: min(int(num_data), 100000)] / 255.0) 114 | 115 | logging.info("Stacking dataset..") 116 | stacked_obs = [] 117 | stacked_actions = [] 118 | previous_actions = [] 119 | i = stack 120 | terminal_cnt = 0 121 | while True: 122 | if terminals[i] == 0: 123 | stacked_obs.append(observations[i - stack + 1 : i + 1]) 124 | stacked_actions.append(actions[i]) 125 | previous_actions.append(actions[i - 1]) 126 | i += 1 127 | else: 128 | terminal_cnt += 1 129 | i += stack 130 | if i >= len(observations): 131 | break 132 | observations = np.array(stacked_obs) 133 | actions = np.array(stacked_actions) 134 | 135 | logging.info("Number of terminals: {}".format(terminal_cnt)) 136 | 
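        # Confounded variant (the default unless --normal is passed): the previous
        # action index is drawn onto every frame as a white number at pixel (11, 55),
        # creating the spurious action cue that OREO is meant to make policies ignore.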
137 | if not normal: 138 | confounded_observations = np.empty( 139 | shape=(observations.shape[0], *observations.shape[1:]), 140 | dtype=observations.dtype, 141 | ) 142 | logging.info("Building dataset with previous actions to the images..") 143 | for i in tqdm(range(observations.shape[0])): 144 | if stack != 1: 145 | img = Image.fromarray(np.transpose(observations[i], (1, 2, 0))) 146 | else: 147 | img = Image.fromarray(observations[i][0]) 148 | draw = ImageDraw.Draw(img) 149 | font = ImageFont.truetype("arial.ttf", size=16) 150 | draw.text( 151 | (11, 55), 152 | "{}".format(previous_actions[i]), 153 | fill=(255,) * stack, 154 | font=font, 155 | ) 156 | if stack != 1: 157 | confounded_observations[i] = np.transpose( 158 | np.asarray(img), (2, 0, 1) 159 | ) 160 | else: 161 | confounded_observations[i] = np.asarray(img)[None, ...] 162 | 163 | observations = confounded_observations 164 | 165 | if num_episodes is not None: 166 | path = os.path.join( 167 | datapath, 168 | env, 169 | str(seed), 170 | "replay_logs", 171 | "saved_episodes_{}_normal{}.pkl".format(int(num_episodes), int(normal)), 172 | ) 173 | with open(path, "wb") as f: 174 | pickle.dump([observations, actions, data_variance], f, protocol=4) 175 | else: 176 | path = os.path.join( 177 | datapath, 178 | env, 179 | str(seed), 180 | "replay_logs", 181 | "saved_dataset_{}_normal{}.pkl".format(int(num_data), int(normal)), 182 | ) 183 | with open(path, "wb") as f: 184 | pickle.dump([observations, actions, data_variance], f, protocol=4) 185 | 186 | logging.info("Done!") 187 | assert observations.shape[0] == actions.shape[0], ( 188 | observations.shape, 189 | actions.shape, 190 | ) 191 | return observations, actions, data_variance 192 | 193 | 194 | class StackedObs: 195 | def __init__(self, stack, confounded): 196 | self._stack = stack 197 | self._confounded = confounded 198 | self._deque = deque(maxlen=stack) 199 | self._font = ImageFont.truetype("arial.ttf", size=16) 200 | 201 | def reset(self, obs): 202 | self._deque.clear() 203 | for _ in range(self._stack): 204 | self._deque.append(obs) 205 | prev_action = 0 206 | return self._get_stacked_obs(prev_action) 207 | 208 | def step(self, obs, prev_action): 209 | self._deque.append(obs) 210 | return self._get_stacked_obs(prev_action) 211 | 212 | def _get_stacked_obs(self, prev_action): 213 | if self._confounded: 214 | stacked_obs = [] 215 | for c in range(self._stack): 216 | img = Image.fromarray(self._deque[c][..., 0]) 217 | draw = ImageDraw.Draw(img) 218 | draw.text( 219 | (11, 55), "{}".format(prev_action), fill=255, font=self._font, 220 | ) 221 | obs = np.asarray(img)[..., None] 222 | stacked_obs.append(obs) 223 | stacked_obs = np.concatenate(stacked_obs, axis=2) 224 | else: 225 | stacked_obs = np.concatenate(self._deque, axis=2) 226 | stacked_obs = np.transpose(stacked_obs, (2, 0, 1)) 227 | return stacked_obs 228 | 229 | 230 | def sample(weights, temperature): 231 | return ( 232 | Bernoulli(logits=torch.from_numpy(weights) / temperature) 233 | .sample() 234 | .long() 235 | .numpy() 236 | ) 237 | 238 | 239 | def linear_regression(masks, rewards, alpha=1.0): 240 | model = Ridge(alpha).fit(masks, rewards) 241 | return model.coef_, model.intercept_ 242 | 243 | 244 | def evaluate( 245 | env, 246 | pre_actor, 247 | actor, 248 | model, 249 | mode, 250 | device, 251 | args, 252 | topk_index=None, 253 | mask=None, 254 | num_eval_episodes=None, 255 | quantizer=None, 256 | ): 257 | model.eval() 258 | actor.eval() 259 | stacked_obs_factory = StackedObs(args.stack, not args.normal_eval) 260 | 
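    # Greedy rollout: actions are chosen by argmax over the policy logits; episodes
    # are truncated below at 27,000 agent steps (the usual Atari evaluation cap,
    # i.e. 108K frames at frame skip 4).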
average_episode_reward = 0 261 | if num_eval_episodes is None: 262 | num_eval_episodes = args.num_eval_episodes 263 | 264 | if hasattr(args, "coord_conv"): 265 | resize = kornia.geometry.Resize(64) 266 | 267 | for episode in range(num_eval_episodes): 268 | obs = env.reset() 269 | done = False 270 | episode_reward = 0 271 | step = 0 272 | while not done: 273 | if step == 0: 274 | stacked_obs = stacked_obs_factory.reset(obs) 275 | 276 | with torch.no_grad(): 277 | stacked_obs = ( 278 | torch.as_tensor( 279 | stacked_obs, device=device, dtype=torch.float32 280 | ).unsqueeze(0) 281 | / 255.0 282 | ) 283 | if hasattr(args, "coord_conv"): 284 | if args.coord_conv: 285 | stacked_obs = resize(stacked_obs) 286 | 287 | if mode in ["cnn", "beta_vae"]: 288 | features = model(stacked_obs) 289 | else: 290 | raise NotImplementedError(mode) 291 | 292 | if mode == "cnn": 293 | if quantizer is not None: 294 | features = quantizer(features)[0] 295 | features = pre_actor(features) 296 | action = actor(features).argmax(1)[0].cpu().item() 297 | elif mode == "beta_vae": 298 | features = pre_actor(torch.flatten(features, start_dim=1)) 299 | features, _ = features.chunk(2, dim=-1) # mu 300 | features = torch.cat([features, torch.ones_like(features)], dim=1) 301 | action = actor(features).argmax(1)[0].cpu().item() 302 | else: 303 | raise NotImplementedError(mode) 304 | 305 | obs, reward, done, info = env.step(action) 306 | prev_action = action 307 | stacked_obs = stacked_obs_factory.step(obs, prev_action) 308 | episode_reward += reward 309 | step += 1 310 | if step == 27000: 311 | done = True 312 | 313 | average_episode_reward += episode_reward 314 | average_episode_reward /= num_eval_episodes 315 | model.train() 316 | actor.train() 317 | return average_episode_reward 318 | 319 | 320 | def evaluate_crlr( 321 | env, actor, model, encoder, quantizer, device, args, num_eval_episodes=None, 322 | ): 323 | model.eval() 324 | actor.eval() 325 | encoder.eval() 326 | quantizer.eval() 327 | stacked_obs_factory = StackedObs(args.stack, not args.normal_eval) 328 | average_episode_reward = 0 329 | if num_eval_episodes is None: 330 | num_eval_episodes = args.num_eval_episodes 331 | for episode in tqdm(range(num_eval_episodes)): 332 | obs = env.reset() 333 | done = False 334 | episode_reward = 0 335 | step = 0 336 | while not done: 337 | if step == 0: 338 | stacked_obs = stacked_obs_factory.reset(obs) 339 | 340 | with torch.no_grad(): 341 | stacked_obs = ( 342 | torch.as_tensor( 343 | stacked_obs, device=device, dtype=torch.float32 344 | ).unsqueeze(0) 345 | / 255.0 346 | ) 347 | 348 | z = encoder(stacked_obs) 349 | z, *_, encoding_indices, _ = quantizer(z) 350 | # features = model(stacked_obs) 351 | 352 | ## one-hot encoding 353 | encoding_indices_flatten = encoding_indices.view(-1) # (Bx64) 354 | encoding_indices_onehot = torch.zeros( 355 | (len(encoding_indices_flatten), args.num_embeddings), 356 | device=encoding_indices_flatten.device, 357 | ) 358 | encoding_indices_onehot.scatter_( 359 | 1, encoding_indices_flatten.unsqueeze(1), 1 360 | ) 361 | encoding_indices_onehot = encoding_indices_onehot.view( 362 | 1, -1, args.num_embeddings 363 | ) # B x 64 x 512 364 | 365 | logits = actor(torch.flatten(encoding_indices_onehot, start_dim=1)) 366 | action = logits.argmax(1)[0].cpu().item() 367 | 368 | obs, reward, done, info = env.step(action) 369 | prev_action = action 370 | stacked_obs = stacked_obs_factory.step(obs, prev_action) 371 | episode_reward += reward 372 | step += 1 373 | if step == 27000: 374 | done = True 375 | 376 | 
average_episode_reward += episode_reward 377 | average_episode_reward /= num_eval_episodes 378 | model.train() 379 | actor.train() 380 | return average_episode_reward 381 | 382 | 383 | def categorical_confounder_balancing_loss(x, w, num_classes, x_onehot=None): 384 | N, P = x.shape 385 | 386 | # one-hot encoding 387 | if x_onehot is None: 388 | is_treat = torch.zeros((N * P, num_classes), device=x.device) 389 | is_treat.scatter_(1, x.reshape(-1).unsqueeze(1), 1) 390 | is_treat = is_treat.view(N, P, num_classes) 391 | is_treat = is_treat.permute(2, 0, 1) # NPC -> CNP 392 | else: 393 | is_treat = x_onehot.permute(2, 0, 1) 394 | 395 | w = w.unsqueeze(0).repeat(num_classes, 1) # N -> CN 396 | 397 | ## CPN x (CN1 * CNP) * CPP = CPP 398 | target_set = torch.bmm( 399 | is_treat.permute(0, 2, 1), F.normalize(w.unsqueeze(2) * is_treat, p=1, dim=1) 400 | ) * ~torch.eye(P, dtype=bool, device=x.device).unsqueeze(0).repeat( 401 | num_classes, 1, 1 402 | ) 403 | target_set = target_set.permute(1, 2, 0) # CPP -> PPC 404 | target_set = target_set.reshape(P, -1) # P(PC) 405 | loss = torch.sum(torch.var(target_set, dim=0)) 406 | 407 | return loss 408 | --------------------------------------------------------------------------------