├── .gitignore ├── Readme.md ├── agents ├── __init__.py ├── base_agent.py └── ppo.py ├── common ├── __init__.py ├── env │ ├── __init__.py │ ├── atari_wrappers.py │ ├── parallel_env.py │ └── procgen_wrappers.py ├── logger.py ├── misc_util.py ├── model.py ├── policy.py └── storage.py ├── compute_metrics.py ├── config.py ├── experiments └── scripts │ ├── get-fig2-data.sh │ ├── plot-figure2.py │ ├── train-coinrun.sh │ ├── train-keys-and-chests.sh │ └── train-maze-I.sh ├── hyperparams └── procgen │ └── config.yml ├── plot_training_csv.py ├── plot_value_coin_barchart.py ├── render.py ├── run_coinrun.py ├── run_utils.py ├── run_vanilla_coinrun.sh ├── train-interleave-envs.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *,cover 45 | .hypothesis/ 46 | 47 | # Translations 48 | *.mo 49 | *.pot 50 | 51 | # Django stuff: 52 | local_settings.py 53 | 54 | # Flask stuff: 55 | instance/ 56 | .webassets-cache 57 | 58 | # Scrapy stuff: 59 | .scrapy 60 | 61 | # Sphinx documentation 62 | docs/_build/ 63 | 64 | # PyBuilder 65 | target/ 66 | 67 | # IPython Notebook 68 | .ipynb_checkpoints 69 | 70 | # pyenv 71 | .python-version 72 | 73 | # celery beat schedule file 74 | celerybeat-schedule 75 | 76 | # Spyder project settings 77 | .spyderproject 78 | 79 | # Rope project settings 80 | .ropeproject 81 | 82 | 83 | 84 | .idea/ 85 | logs/ 86 | experiments/results 87 | videos 88 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Training code for Goal Misgeneralization in Deep Reinforcement Learning 2 | 3 | This code is based on a fork of [this repository](https://github.com/joonleesky/train-procgen-pytorch) by Hojoon Lee. 4 | It includes scripts for training RL agents on [modified procgen environments](https://github.com/JacobPfau/procgenAISC) and producing the figures for the paper [Goal Misgeneralization in Deep Reinforcement Learning](https://arxiv.org/abs/2105.14111). 5 | 6 | 7 | ## Requirements 8 | 9 | - python>=3.6 10 | - torch 1.3 11 | - procgen (you will need to install our custom procgen linked above) 12 | - pyyaml 13 | - pandas 14 | - tensorboard==2.5 15 | 16 | ## Reproducing experiments 17 | 18 | ### Coinrun 19 | 20 | Train: 21 | 22 | ``` 23 | python train.py --exp_name coinrun --env_name coinrun --num_levels 100000 --distribution_mode hard --param_name hard-500 --num_timesteps 200000000 --num_checkpoints 5 --seed 6033 --random_percent 0 24 | ``` 25 | 26 | In order to reproduce the experiments from the ablation, change the `random_percent` variable. 
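For example, to train a variant in which the coin is placed at a random location in (say) 10% of training levels, pass a different value to `--random_percent`. The experiment name below is just an illustrative label, and the exact percentages used for the paper's ablation may differ:

```
python train.py --exp_name coinrun_rand10 --env_name coinrun --num_levels 100000 --distribution_mode hard --param_name hard-500 --num_timesteps 200000000 --num_checkpoints 5 --seed 6033 --random_percent 10
```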
27 |
28 | Test:
29 | ```
30 | python render.py --exp_name coinrun_test --env_name coinrun_aisc --distribution_mode hard --param_name hard-500 --model_file PATH_TO_MODEL_FILE
31 | ```
32 | where `PATH_TO_MODEL_FILE` is the path to the model file generated by the above training command.
33 |
34 | ### Maze (Variant 1)
35 |
36 | ```
37 | python train.py --exp_name maze1 --env_name maze_aisc --num_levels 100000 --distribution_mode hard --param_name hard-500 --num_timesteps 200000000 --num_checkpoints 5 --seed 1080
38 | ```
39 |
40 | ```
41 | python render.py --exp_name maze1_test --env_name maze --distribution_mode hard --param_name hard-500 --model_file PATH_TO_MODEL_FILE
42 | ```
43 |
44 | ### Maze (Variant 2)
45 |
46 | ```
47 | python train.py --exp_name maze2 --env_name maze_yellowgem --num_levels 100000 --distribution_mode hard --param_name hard-500 --num_timesteps 200000000 --num_checkpoints 5 --seed 2809
48 | ```
49 |
50 | ```
51 | python render.py --exp_name maze2_test --env_name maze_redgem_yellowstar --distribution_mode hard --param_name hard-500 --model_file PATH_TO_MODEL_FILE
52 | ```
53 |
54 | ### Keys and Chests
55 |
56 | ```
57 | python train.py --exp_name keys_chests --env_name heist_aisc_many_chests --num_levels 100000 --distribution_mode hard --param_name hard-500 --num_timesteps 200000000 --num_checkpoints 5 --seed 1111
58 | ```
59 |
60 | ```
61 | python render.py --exp_name keys_chests_test --env_name heist_aisc_many_keys --distribution_mode hard --param_name hard-500 --model_file PATH_TO_MODEL_FILE
62 |
63 |
64 |
65 | ```
66 |
67 | The original Readme (not our work) is reproduced below.
68 |
69 |
70 |
71 | ---
72 |
73 | Training Procgen environment with Pytorch
74 | ===============
75 |
76 | 🆕✅🎉 *updated code: 10th September 2020: bug fixes + support recurrent policy.*
77 |
78 | ## Introduction
79 |
80 | This repository contains code to train a baseline PPO agent on Procgen, implemented in PyTorch.
81 |
82 | This implementation is intended to accelerate research on the Procgen environments.
83 | It aims to reproduce the results of the Procgen paper.
84 | The code is designed for both readability and productivity. I tried to match the code as closely as possible to [OpenAI baselines's](https://github.com/openai/train-procgen) while following the coding style of [ikostrikov's](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail).
85 |
86 | There are several Procgen-specific points to watch out for, which differ from typical RL implementations:
87 |
88 | - Xavier uniform initialization is used for conv layers rather than orthogonal initialization.
89 | - Observation normalization is not used.
90 | - Gradient accumulation is used to [handle large mini-batch sizes](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255).
91 |
92 | Training logs for `starpilot` can be found in `logs/procgen/starpilot`.
93 |
94 | ## Requirements
95 |
96 | - python>=3.6
97 | - torch 1.3
98 | - procgen
99 | - pyyaml
100 |
101 | ## Train
102 |
103 | Use `train.py` to train an agent in a Procgen environment. It takes the following arguments:
104 | - `--exp_name`: ID to designate your experiment.
105 | - `--env_name`: Name of the Procgen environment.
106 | - `--start_level`: Start level for the environment.
107 | - `--num_levels`: Number of training levels for the environment.
108 | - `--distribution_mode`: Difficulty mode of the environment (e.g. `easy` or `hard`).
109 | - `--param_name`: Name of the hyperparameter configuration for your training run.
By default, training loads the hyperparameters from the `param_name` entry in `hyperparams/procgen/config.yml`.
110 | - `--num_timesteps`: Total number of timesteps to train your agent.
111 |
112 | After you start training your agent, logs and parameters are automatically stored in `logs/procgen/env-name/exp-name/`.
113 |
114 | ## Try it out
115 |
116 | Sample efficiency on easy environments:
117 |
118 | `python train.py --exp_name easy-run-all --env_name ENV_NAME --param_name easy --num_levels 0 --distribution_mode easy --num_timesteps 25000000`
119 |
120 | Sample efficiency on hard environments:
121 |
122 | `python train.py --exp_name hard-run-all --env_name ENV_NAME --param_name hard --num_levels 0 --distribution_mode hard --num_timesteps 200000000`
123 |
124 | Generalization on easy environments:
125 |
126 | `python train.py --exp_name easy-run-200 --env_name ENV_NAME --param_name easy-200 --num_levels 200 --distribution_mode easy --num_timesteps 25000000`
127 |
128 | Generalization on hard environments:
129 |
130 | `python train.py --exp_name hard-run-500 --env_name ENV_NAME --param_name hard-500 --num_levels 500 --distribution_mode hard --num_timesteps 200000000`
131 |
132 | If your GPU has more than 5 GB of memory, increase the mini-batch size to speed up training.
133 |
134 | ## TODO
135 |
136 | - [ ] Implement Data Augmentation from [RAD](https://mishalaskin.github.io/rad/).
137 | - [ ] Create evaluation code to measure the test performance.
138 |
139 | ## References
140 |
141 | [1] [PPO: Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347)
142 | [2] [GAE: High-Dimensional Continuous Control Using Generalized Advantage Estimation](https://arxiv.org/abs/1506.02438)
143 | [3] [IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/abs/1802.01561)
144 | [4] [Implementation Matters in Deep Policy Gradients: A Case Study on PPO and TRPO](https://arxiv.org/abs/2005.12729)
145 | [5] [Leveraging Procedural Generation to Benchmark Reinforcement Learning](https://arxiv.org/abs/1912.01588) 146 | 147 | -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbkjr/train-procgen-pytorch/2906e6f77a70ff09a1b5ffac33773bfe96c722d9/agents/__init__.py -------------------------------------------------------------------------------- /agents/base_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class BaseAgent(object): 4 | """ 5 | Class for the basic agent objects. 6 | To define your own agent, subclass this class and implement the functions below. 7 | """ 8 | 9 | def __init__(self, 10 | env, 11 | policy, 12 | logger, 13 | storage, 14 | device, 15 | num_checkpoints, 16 | env_valid=None, 17 | storage_valid=None): 18 | """ 19 | env: (gym.Env) environment following the openAI Gym API 20 | """ 21 | self.env = env 22 | self.policy = policy 23 | self.logger = logger 24 | self.storage = storage 25 | self.device = device 26 | self.num_checkpoints = num_checkpoints 27 | self.env_valid = env_valid 28 | self.storage_valid = storage_valid 29 | self.t = 0 30 | 31 | def predict(self, obs): 32 | """ 33 | Predict the action with the given input 34 | """ 35 | pass 36 | 37 | def update_policy(self): 38 | """ 39 | Train the neural network model 40 | """ 41 | pass 42 | 43 | def train(self, num_timesteps): 44 | """ 45 | Train the agent with collecting the trajectories 46 | """ 47 | pass 48 | 49 | def evaluate(self): 50 | """ 51 | Evaluate the agent 52 | """ 53 | pass 54 | -------------------------------------------------------------------------------- /agents/ppo.py: -------------------------------------------------------------------------------- 1 | from .base_agent import BaseAgent 2 | from common.misc_util import adjust_lr, get_n_params 3 | import torch 4 | import torch.optim as optim 5 | import numpy as np 6 | 7 | 8 | class PPO(BaseAgent): 9 | def __init__(self, 10 | env, 11 | policy, 12 | logger, 13 | storage, 14 | device, 15 | n_checkpoints, 16 | env_valid=None, 17 | storage_valid=None, 18 | n_steps=128, 19 | n_envs=8, 20 | epoch=3, 21 | mini_batch_per_epoch=8, 22 | mini_batch_size=32*8, 23 | gamma=0.99, 24 | lmbda=0.95, 25 | learning_rate=2.5e-4, 26 | grad_clip_norm=0.5, 27 | eps_clip=0.2, 28 | value_coef=0.5, 29 | entropy_coef=0.01, 30 | normalize_adv=True, 31 | normalize_rew=True, 32 | use_gae=True, 33 | **kwargs): 34 | 35 | super(PPO, self).__init__(env, policy, logger, storage, device, 36 | n_checkpoints, env_valid, storage_valid) 37 | 38 | self.n_steps = n_steps 39 | self.n_envs = n_envs 40 | self.epoch = epoch 41 | self.mini_batch_per_epoch = mini_batch_per_epoch 42 | self.mini_batch_size = mini_batch_size 43 | self.gamma = gamma 44 | self.lmbda = lmbda 45 | self.learning_rate = learning_rate 46 | self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate, eps=1e-5) 47 | self.grad_clip_norm = grad_clip_norm 48 | self.eps_clip = eps_clip 49 | self.value_coef = value_coef 50 | self.entropy_coef = entropy_coef 51 | self.normalize_adv = normalize_adv 52 | self.normalize_rew = normalize_rew 53 | self.use_gae = use_gae 54 | 55 | def predict(self, obs, hidden_state, done): 56 | with torch.no_grad(): 57 | obs = torch.FloatTensor(obs).to(device=self.device) 58 | hidden_state = torch.FloatTensor(hidden_state).to(device=self.device) 59 | mask = 
torch.FloatTensor(1-done).to(device=self.device) 60 | dist, value, hidden_state = self.policy(obs, hidden_state, mask) 61 | act = dist.sample() 62 | log_prob_act = dist.log_prob(act) 63 | 64 | return act.cpu().numpy(), log_prob_act.cpu().numpy(), value.cpu().numpy(), hidden_state.cpu().numpy() 65 | 66 | def predict_w_value_saliency(self, obs, hidden_state, done): 67 | obs = torch.FloatTensor(obs).to(device=self.device) 68 | obs.requires_grad_() 69 | obs.retain_grad() 70 | hidden_state = torch.FloatTensor(hidden_state).to(device=self.device) 71 | mask = torch.FloatTensor(1-done).to(device=self.device) 72 | dist, value, hidden_state = self.policy(obs, hidden_state, mask) 73 | value.backward(retain_graph=True) 74 | act = dist.sample() 75 | log_prob_act = dist.log_prob(act) 76 | 77 | return act.detach().cpu().numpy(), log_prob_act.detach().cpu().numpy(), value.detach().cpu().numpy(), hidden_state.detach().cpu().numpy(), obs.grad.data.detach().cpu().numpy() 78 | 79 | def optimize(self): 80 | pi_loss_list, value_loss_list, entropy_loss_list = [], [], [] 81 | batch_size = self.n_steps * self.n_envs // self.mini_batch_per_epoch 82 | if batch_size < self.mini_batch_size: 83 | self.mini_batch_size = batch_size 84 | grad_accumulation_steps = batch_size / self.mini_batch_size 85 | grad_accumulation_cnt = 1 86 | 87 | self.policy.train() 88 | for e in range(self.epoch): 89 | recurrent = self.policy.is_recurrent() 90 | generator = self.storage.fetch_train_generator(mini_batch_size=self.mini_batch_size, 91 | recurrent=recurrent) 92 | for sample in generator: 93 | obs_batch, hidden_state_batch, act_batch, done_batch, \ 94 | old_log_prob_act_batch, old_value_batch, return_batch, adv_batch = sample 95 | mask_batch = (1-done_batch) 96 | dist_batch, value_batch, _ = self.policy(obs_batch, hidden_state_batch, mask_batch) 97 | 98 | # Clipped Surrogate Objective 99 | log_prob_act_batch = dist_batch.log_prob(act_batch) 100 | ratio = torch.exp(log_prob_act_batch - old_log_prob_act_batch) 101 | surr1 = ratio * adv_batch 102 | surr2 = torch.clamp(ratio, 1.0 - self.eps_clip, 1.0 + self.eps_clip) * adv_batch 103 | pi_loss = -torch.min(surr1, surr2).mean() 104 | 105 | # Clipped Bellman-Error 106 | clipped_value_batch = old_value_batch + (value_batch - old_value_batch).clamp(-self.eps_clip, self.eps_clip) 107 | v_surr1 = (value_batch - return_batch).pow(2) 108 | v_surr2 = (clipped_value_batch - return_batch).pow(2) 109 | value_loss = 0.5 * torch.max(v_surr1, v_surr2).mean() 110 | 111 | # Policy Entropy 112 | entropy_loss = dist_batch.entropy().mean() 113 | loss = pi_loss + self.value_coef * value_loss - self.entropy_coef * entropy_loss 114 | loss.backward() 115 | 116 | # Let model to handle the large batch-size with small gpu-memory 117 | if grad_accumulation_cnt % grad_accumulation_steps == 0: 118 | torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.grad_clip_norm) 119 | self.optimizer.step() 120 | self.optimizer.zero_grad() 121 | grad_accumulation_cnt += 1 122 | pi_loss_list.append(-pi_loss.item()) 123 | value_loss_list.append(-value_loss.item()) 124 | entropy_loss_list.append(entropy_loss.item()) 125 | 126 | summary = {'Loss/pi': np.mean(pi_loss_list), 127 | 'Loss/v': np.mean(value_loss_list), 128 | 'Loss/entropy': np.mean(entropy_loss_list)} 129 | return summary 130 | 131 | def train(self, num_timesteps): 132 | save_every = num_timesteps // self.num_checkpoints 133 | checkpoint_cnt = 0 134 | obs = self.env.reset() 135 | hidden_state = np.zeros((self.n_envs, self.storage.hidden_state_size)) 136 | done = 
np.zeros(self.n_envs) 137 | 138 | if self.env_valid is not None: 139 | obs_v = self.env_valid.reset() 140 | hidden_state_v = np.zeros((self.n_envs, self.storage.hidden_state_size)) 141 | done_v = np.zeros(self.n_envs) 142 | 143 | while self.t < num_timesteps: 144 | # Run Policy 145 | self.policy.eval() 146 | for _ in range(self.n_steps): 147 | act, log_prob_act, value, next_hidden_state = self.predict(obs, hidden_state, done) 148 | next_obs, rew, done, info = self.env.step(act) 149 | self.storage.store(obs, hidden_state, act, rew, done, info, log_prob_act, value) 150 | obs = next_obs 151 | hidden_state = next_hidden_state 152 | value_batch = self.storage.value_batch[:self.n_steps] 153 | _, _, last_val, hidden_state = self.predict(obs, hidden_state, done) 154 | self.storage.store_last(obs, hidden_state, last_val) 155 | # Compute advantage estimates 156 | self.storage.compute_estimates(self.gamma, self.lmbda, self.use_gae, self.normalize_adv) 157 | 158 | #valid 159 | if self.env_valid is not None: 160 | for _ in range(self.n_steps): 161 | act_v, log_prob_act_v, value_v, next_hidden_state_v = self.predict(obs_v, hidden_state_v, done_v) 162 | next_obs_v, rew_v, done_v, info_v = self.env_valid.step(act_v) 163 | self.storage_valid.store(obs_v, hidden_state_v, act_v, 164 | rew_v, done_v, info_v, 165 | log_prob_act_v, value_v) 166 | obs_v = next_obs_v 167 | hidden_state_v = next_hidden_state_v 168 | _, _, last_val_v, hidden_state_v = self.predict(obs_v, hidden_state_v, done_v) 169 | self.storage_valid.store_last(obs_v, hidden_state_v, last_val_v) 170 | self.storage_valid.compute_estimates(self.gamma, self.lmbda, self.use_gae, self.normalize_adv) 171 | 172 | # Optimize policy & valueq 173 | summary = self.optimize() 174 | # Log the training-procedure 175 | self.t += self.n_steps * self.n_envs 176 | rew_batch, done_batch = self.storage.fetch_log_data() 177 | if self.storage_valid is not None: 178 | rew_batch_v, done_batch_v = self.storage_valid.fetch_log_data() 179 | else: 180 | rew_batch_v = done_batch_v = None 181 | self.logger.feed(rew_batch, done_batch, rew_batch_v, done_batch_v) 182 | self.logger.dump() 183 | self.optimizer = adjust_lr(self.optimizer, self.learning_rate, self.t, num_timesteps) 184 | # Save the model 185 | if self.t > ((checkpoint_cnt+1) * save_every): 186 | print("Saving model.") 187 | torch.save({'model_state_dict': self.policy.state_dict(), 188 | 'optimizer_state_dict': self.optimizer.state_dict()}, 189 | self.logger.logdir + '/model_' + str(self.t) + '.pth') 190 | checkpoint_cnt += 1 191 | self.env.close() 192 | if self.env_valid is not None: 193 | self.env_valid.close() 194 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc_util import * -------------------------------------------------------------------------------- /common/env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jbkjr/train-procgen-pytorch/2906e6f77a70ff09a1b5ffac33773bfe96c722d9/common/env/__init__.py -------------------------------------------------------------------------------- /common/env/atari_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | os.environ.setdefault('PATH', '') 5 | from collections import deque 6 | import gym 7 | from gym import spaces 8 | import cv2 9 | 
cv2.ocl.setUseOpenCL(False) 10 | 11 | class NoopResetEnv(gym.Wrapper): 12 | def __init__(self, env, noop_max=30): 13 | """Sample initial states by taking random number of no-ops on reset. 14 | No-op is assumed to be action 0. 15 | """ 16 | gym.Wrapper.__init__(self, env) 17 | self.noop_max = noop_max 18 | self.override_num_noops = None 19 | self.noop_action = 0 20 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 21 | 22 | def reset(self, **kwargs): 23 | """ Do no-op action for a number of steps in [1, noop_max].""" 24 | self.env.reset(**kwargs) 25 | if self.override_num_noops is not None: 26 | noops = self.override_num_noops 27 | else: 28 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101 29 | assert noops > 0 30 | obs = None 31 | for _ in range(noops): 32 | obs, _, done, _ = self.env.step(self.noop_action) 33 | if done: 34 | obs = self.env.reset(**kwargs) 35 | return obs 36 | 37 | def step(self, ac): 38 | return self.env.step(ac) 39 | 40 | class FireResetEnv(gym.Wrapper): 41 | def __init__(self, env): 42 | """Take action on reset for environments that are fixed until firing.""" 43 | gym.Wrapper.__init__(self, env) 44 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 45 | assert len(env.unwrapped.get_action_meanings()) >= 3 46 | 47 | def reset(self, **kwargs): 48 | self.env.reset(**kwargs) 49 | obs, _, done, _ = self.env.step(1) 50 | if done: 51 | self.env.reset(**kwargs) 52 | obs, _, done, _ = self.env.step(2) 53 | if done: 54 | self.env.reset(**kwargs) 55 | return obs 56 | 57 | def step(self, ac): 58 | return self.env.step(ac) 59 | 60 | class EpisodicLifeEnv(gym.Wrapper): 61 | def __init__(self, env): 62 | """Make end-of-life == end-of-episode, but only reset on true game over. 63 | Done by DeepMind for the DQN and co. since it helps value estimation. 64 | """ 65 | gym.Wrapper.__init__(self, env) 66 | self.lives = 0 67 | self.was_real_done = True 68 | 69 | def step(self, action): 70 | obs, reward, done, info = self.env.step(action) 71 | self.was_real_done = done 72 | # check current lives, make loss of life terminal, 73 | # then update lives to handle bonus lives 74 | lives = self.env.unwrapped.ale.lives() 75 | if lives < self.lives and lives > 0: 76 | # for Qbert sometimes we stay in lives == 0 condition for a few frames 77 | # so it's important to keep lives > 0, so that we only reset once 78 | # the environment advertises done. 79 | done = True 80 | self.lives = lives 81 | info['env_done'] = self.was_real_done 82 | return obs, reward, done, info 83 | 84 | def reset(self, **kwargs): 85 | """Reset only when lives are exhausted. 86 | This way all states are still reachable even though lives are episodic, 87 | and the learner need not know about any of this behind-the-scenes. 
88 | """ 89 | if self.was_real_done: 90 | obs = self.env.reset(**kwargs) 91 | else: 92 | # no-op step to advance from terminal/lost life state 93 | obs, _, _, _ = self.env.step(0) 94 | self.lives = self.env.unwrapped.ale.lives() 95 | return obs 96 | 97 | class MaxAndSkipEnv(gym.Wrapper): 98 | def __init__(self, env, skip=4): 99 | """Return only every `skip`-th frame""" 100 | gym.Wrapper.__init__(self, env) 101 | # most recent raw observations (for max pooling across time steps) 102 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) 103 | self._skip = skip 104 | 105 | def step(self, action): 106 | """Repeat action, sum reward, and max over last observations.""" 107 | total_reward = 0.0 108 | done = None 109 | for i in range(self._skip): 110 | obs, reward, done, info = self.env.step(action) 111 | if i == self._skip - 2: self._obs_buffer[0] = obs 112 | if i == self._skip - 1: self._obs_buffer[1] = obs 113 | total_reward += reward 114 | if done: 115 | break 116 | # Note that the observation on the done=True frame 117 | # doesn't matter 118 | max_frame = self._obs_buffer.max(axis=0) 119 | 120 | return max_frame, total_reward, done, info 121 | 122 | def reset(self, **kwargs): 123 | return self.env.reset(**kwargs) 124 | 125 | class ClipRewardEnv(gym.RewardWrapper): 126 | def __init__(self, env): 127 | gym.RewardWrapper.__init__(self, env) 128 | 129 | def reward(self, reward): 130 | """Bin reward to {+1, 0, -1} by its sign.""" 131 | return np.sign(reward) 132 | 133 | def step(self, act): 134 | """Bin reward to {+1, 0, -1} by its sign.""" 135 | s, rew, done, info = self.env.step(act) 136 | info['env_reward'] = rew 137 | return s, rew, done, info 138 | 139 | 140 | class WarpFrame(gym.ObservationWrapper): 141 | def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None): 142 | """ 143 | Warp frames to 84x84 as done in the Nature paper and later work. 144 | If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which 145 | observation should be warped. 146 | """ 147 | super().__init__(env) 148 | self._width = width 149 | self._height = height 150 | self._grayscale = grayscale 151 | self._key = dict_space_key 152 | if self._grayscale: 153 | num_colors = 1 154 | else: 155 | num_colors = 3 156 | 157 | new_space = gym.spaces.Box( 158 | low=0, 159 | high=255, 160 | shape=(self._height, self._width, num_colors), 161 | dtype=np.uint8, 162 | ) 163 | if self._key is None: 164 | original_space = self.observation_space 165 | self.observation_space = new_space 166 | else: 167 | original_space = self.observation_space.spaces[self._key] 168 | self.observation_space.spaces[self._key] = new_space 169 | assert original_space.dtype == np.uint8 and len(original_space.shape) == 3 170 | 171 | def observation(self, obs): 172 | if self._key is None: 173 | frame = obs 174 | else: 175 | frame = obs[self._key] 176 | 177 | if self._grayscale: 178 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 179 | frame = cv2.resize( 180 | frame, (self._width, self._height), interpolation=cv2.INTER_AREA 181 | ) 182 | if self._grayscale: 183 | frame = np.expand_dims(frame, -1) 184 | 185 | if self._key is None: 186 | obs = frame 187 | else: 188 | obs = obs.copy() 189 | obs[self._key] = frame 190 | return obs 191 | 192 | 193 | class FrameStack(gym.Wrapper): 194 | def __init__(self, env, k): 195 | """Stack k last frames. 196 | Returns lazy array, which is much more memory efficient. 
197 | See Also 198 | -------- 199 | baselines.common.atari_wrappers.LazyFrames 200 | """ 201 | gym.Wrapper.__init__(self, env) 202 | self.k = k 203 | self.frames = deque([], maxlen=k) 204 | shp = env.observation_space.shape 205 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype) 206 | 207 | def reset(self): 208 | ob = self.env.reset() 209 | for _ in range(self.k): 210 | self.frames.append(ob) 211 | return self._get_ob() 212 | 213 | def step(self, action): 214 | ob, reward, done, info = self.env.step(action) 215 | self.frames.append(ob) 216 | return self._get_ob(), reward, done, info 217 | 218 | def _get_ob(self): 219 | assert len(self.frames) == self.k 220 | return LazyFrames(list(self.frames)) 221 | 222 | class ScaledFloatFrame(gym.ObservationWrapper): 223 | def __init__(self, env): 224 | gym.ObservationWrapper.__init__(self, env) 225 | self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32) 226 | 227 | def observation(self, observation): 228 | # careful! This undoes the memory optimization, use 229 | # with smaller replay buffers only. 230 | return np.array(observation).astype(np.float32) / 255.0 231 | 232 | class LazyFrames(object): 233 | def __init__(self, frames): 234 | """This object ensures that common frames between the observations are only stored once. 235 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 236 | buffers. 237 | This object should only be converted to numpy array before being passed to the model. 238 | You'd not believe how complex the previous solution was.""" 239 | self._frames = frames 240 | self._out = None 241 | 242 | def _force(self): 243 | if self._out is None: 244 | self._out = np.concatenate(self._frames, axis=-1) 245 | self._frames = None 246 | return self._out 247 | 248 | def __array__(self, dtype=None): 249 | out = self._force() 250 | if dtype is not None: 251 | out = out.astype(dtype) 252 | return out 253 | 254 | def __len__(self): 255 | return len(self._force()) 256 | 257 | def __getitem__(self, i): 258 | return self._force()[i] 259 | 260 | def count(self): 261 | frames = self._force() 262 | return frames.shape[frames.ndim - 1] 263 | 264 | def frame(self, i): 265 | return self._force()[..., i] 266 | 267 | 268 | class TransposeFrame(gym.ObservationWrapper): 269 | def __init__(self, env): 270 | gym.ObservationWrapper.__init__(self, env) 271 | obs_shape = self.observation_space.shape 272 | self.observation_space = gym.spaces.Box(low=0, high=1, shape=(obs_shape[2], obs_shape[0], obs_shape[1]), dtype=np.float32) 273 | 274 | def observation(self, observation): 275 | return observation.transpose(2, 0, 1) 276 | 277 | 278 | def wrap_deepmind(env, episode_life=True, preprocess=True, max_and_skip=True, 279 | clip_rewards=True, no_op_reset=True, history_length=4, scale=True, transpose=True): 280 | """Configure environment for DeepMind-style Atari.""" 281 | if no_op_reset: 282 | env = NoopResetEnv(env, noop_max=30) 283 | if max_and_skip: 284 | env = MaxAndSkipEnv(env, skip=4) 285 | if episode_life: 286 | env = EpisodicLifeEnv(env) 287 | if 'FIRE' in env.unwrapped.get_action_meanings(): 288 | env = FireResetEnv(env) 289 | if preprocess: 290 | env = WarpFrame(env) 291 | if clip_rewards: 292 | env = ClipRewardEnv(env) 293 | if history_length > 1: 294 | env = FrameStack(env, history_length) 295 | if scale: 296 | env = ScaledFloatFrame(env) 297 | if transpose: 298 | env = TransposeFrame(env) 299 | return env 
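if __name__ == '__main__':
    # Illustrative usage sketch: wrap a raw Atari environment with the DeepMind-style
    # preprocessing defined above. The env id is just an example; any *NoFrameskip-v4
    # Atari id should work, provided gym's Atari dependencies are installed.
    env = gym.make('PongNoFrameskip-v4')
    env = wrap_deepmind(env)
    obs = env.reset()
    # After WarpFrame -> FrameStack(4) -> ScaledFloatFrame -> TransposeFrame, observations
    # are float32 arrays of shape (4, 84, 84) with values scaled to [0, 1].
    print(np.asarray(obs).shape)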
-------------------------------------------------------------------------------- /common/env/parallel_env.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | from multiprocessing import Process, Pipe, Value 4 | 5 | 6 | def worker(worker_id, env, master_end, worker_end): 7 | master_end.close() # Forbid worker to use the master end for messaging 8 | 9 | while True: 10 | cmd, data = worker_end.recv() 11 | if cmd == 'step': 12 | ob, reward, done, info = env.step(data) 13 | if done: 14 | ob = env.reset() 15 | worker_end.send((ob, reward, done, info)) 16 | elif cmd == 'seed': 17 | worker_end.send(env.seed(data)) 18 | elif cmd == 'reset': 19 | ob = env.reset() 20 | worker_end.send(ob) 21 | elif cmd == 'close': 22 | worker_end.close() 23 | break 24 | else: 25 | raise NotImplementedError 26 | 27 | 28 | class ParallelEnv(object): 29 | """ 30 | This class 31 | """ 32 | 33 | def __init__(self, num_processes, env): 34 | self.nenvs = num_processes 35 | self.waiting = False 36 | self.closed = False 37 | self.workers = [] 38 | self.observation_space = env.observation_space 39 | self.action_space = env.action_space 40 | 41 | self.master_ends, self.send_ends = zip(*[Pipe() for _ in range(self.nenvs)]) 42 | for worker_id, (master_end, send_end) in enumerate(zip(self.master_ends, self.send_ends)): 43 | p = Process(target=worker, 44 | args=(worker_id, copy.deepcopy(env), master_end, send_end)) 45 | p.start() 46 | self.workers.append(p) 47 | 48 | def step(self, actions): 49 | """ 50 | Perform step for each environment and return the stacked transitions 51 | """ 52 | for master_end, action in zip(self.master_ends, actions): 53 | master_end.send(('step', action)) 54 | self.waiting = True 55 | 56 | results = [master_end.recv() for master_end in self.master_ends] 57 | self.waiting = False 58 | obs, rews, dones, infos = zip(*results) 59 | 60 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 61 | 62 | def seed(self, seed=None): 63 | for idx, master_end in enumerate(self.master_ends): 64 | master_end.send(('seed', seed + idx)) 65 | return [master_end.recv() for master_end in self.master_ends] 66 | 67 | def reset(self): 68 | for master_end in self.master_ends: 69 | master_end.send(('reset', None)) 70 | results = [master_end.recv() for master_end in self.master_ends] 71 | 72 | return np.stack(results) 73 | 74 | def close(self): 75 | if self.closed: 76 | return 77 | if self.waiting: 78 | [master_end.recv() for master_end in self.master_ends] 79 | for master_end in self.master_ends: 80 | master_end.send(('close', None)) 81 | for worker in self.workers: 82 | worker.join() 83 | self.closed = True -------------------------------------------------------------------------------- /common/env/procgen_wrappers.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | from abc import ABC, abstractmethod 4 | import numpy as np 5 | import gym 6 | from gym import spaces 7 | import time 8 | from collections import deque 9 | import torch 10 | 11 | 12 | """ 13 | Copy-pasted from OpenAI to obviate dependency on Baselines. Required for vectorized environments. 14 | """ 15 | 16 | class AlreadySteppingError(Exception): 17 | """ 18 | Raised when an asynchronous step is running while 19 | step_async() is called again. 
20 | """ 21 | 22 | def __init__(self): 23 | msg = 'already running an async step' 24 | Exception.__init__(self, msg) 25 | 26 | 27 | class NotSteppingError(Exception): 28 | """ 29 | Raised when an asynchronous step is not running but 30 | step_wait() is called. 31 | """ 32 | 33 | def __init__(self): 34 | msg = 'not running an async step' 35 | Exception.__init__(self, msg) 36 | 37 | 38 | class VecEnv(ABC): 39 | """ 40 | An abstract asynchronous, vectorized environment. 41 | Used to batch data from multiple copies of an environment, so that 42 | each observation becomes an batch of observations, and expected action is a batch of actions to 43 | be applied per-environment. 44 | """ 45 | closed = False 46 | viewer = None 47 | 48 | metadata = { 49 | 'render.modes': ['human', 'rgb_array'] 50 | } 51 | 52 | def __init__(self, num_envs, observation_space, action_space): 53 | self.num_envs = num_envs 54 | self.observation_space = observation_space 55 | self.action_space = action_space 56 | 57 | @abstractmethod 58 | def reset(self): 59 | """ 60 | Reset all the environments and return an array of 61 | observations, or a dict of observation arrays. 62 | 63 | If step_async is still doing work, that work will 64 | be cancelled and step_wait() should not be called 65 | until step_async() is invoked again. 66 | """ 67 | pass 68 | 69 | @abstractmethod 70 | def step_async(self, actions): 71 | """ 72 | Tell all the environments to start taking a step 73 | with the given actions. 74 | Call step_wait() to get the results of the step. 75 | 76 | You should not call this if a step_async run is 77 | already pending. 78 | """ 79 | pass 80 | 81 | @abstractmethod 82 | def step_wait(self): 83 | """ 84 | Wait for the step taken with step_async(). 85 | 86 | Returns (obs, rews, dones, infos): 87 | - obs: an array of observations, or a dict of 88 | arrays of observations. 89 | - rews: an array of rewards 90 | - dones: an array of "episode done" booleans 91 | - infos: a sequence of info objects 92 | """ 93 | pass 94 | 95 | def close_extras(self): 96 | """ 97 | Clean up the extra resources, beyond what's in this base class. 98 | Only runs when not self.closed. 99 | """ 100 | pass 101 | 102 | def close(self): 103 | if self.closed: 104 | return 105 | if self.viewer is not None: 106 | self.viewer.close() 107 | self.close_extras() 108 | self.closed = True 109 | 110 | def step(self, actions): 111 | """ 112 | Step the environments synchronously. 113 | 114 | This is available for backwards compatibility. 115 | """ 116 | self.step_async(actions) 117 | return self.step_wait() 118 | 119 | def render(self, mode='human'): 120 | imgs = self.get_images() 121 | bigimg = "ARGHH" #tile_images(imgs) 122 | if mode == 'human': 123 | self.get_viewer().imshow(bigimg) 124 | return self.get_viewer().isopen 125 | elif mode == 'rgb_array': 126 | return bigimg 127 | else: 128 | raise NotImplementedError 129 | 130 | def get_images(self): 131 | """ 132 | Return RGB images from each environment 133 | """ 134 | raise NotImplementedError 135 | 136 | @property 137 | def unwrapped(self): 138 | if isinstance(self, VecEnvWrapper): 139 | return self.venv.unwrapped 140 | else: 141 | return self 142 | 143 | def get_viewer(self): 144 | if self.viewer is None: 145 | from gym.envs.classic_control import rendering 146 | self.viewer = rendering.SimpleImageViewer() 147 | return self.viewer 148 | 149 | 150 | class VecEnvWrapper(VecEnv): 151 | """ 152 | An environment wrapper that applies to an entire batch 153 | of environments at once. 
154 | """ 155 | 156 | def __init__(self, venv, observation_space=None, action_space=None): 157 | self.venv = venv 158 | super().__init__(num_envs=venv.num_envs, 159 | observation_space=observation_space or venv.observation_space, 160 | action_space=action_space or venv.action_space) 161 | 162 | def step_async(self, actions): 163 | self.venv.step_async(actions) 164 | 165 | @abstractmethod 166 | def reset(self): 167 | pass 168 | 169 | @abstractmethod 170 | def step_wait(self): 171 | pass 172 | 173 | def close(self): 174 | return self.venv.close() 175 | 176 | def render(self, mode='human'): 177 | return self.venv.render(mode=mode) 178 | 179 | def get_images(self): 180 | return self.venv.get_images() 181 | 182 | def __getattr__(self, name): 183 | if name.startswith('_'): 184 | raise AttributeError("attempted to get missing private attribute '{}'".format(name)) 185 | return getattr(self.venv, name) 186 | 187 | 188 | class VecEnvObservationWrapper(VecEnvWrapper): 189 | @abstractmethod 190 | def process(self, obs): 191 | pass 192 | 193 | def reset(self): 194 | obs = self.venv.reset() 195 | return self.process(obs) 196 | 197 | def step_wait(self): 198 | obs, rews, dones, infos = self.venv.step_wait() 199 | return self.process(obs), rews, dones, infos 200 | 201 | 202 | class CloudpickleWrapper(object): 203 | """ 204 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 205 | """ 206 | 207 | def __init__(self, x): 208 | self.x = x 209 | 210 | def __getstate__(self): 211 | import cloudpickle 212 | return cloudpickle.dumps(self.x) 213 | 214 | def __setstate__(self, ob): 215 | import pickle 216 | self.x = pickle.loads(ob) 217 | 218 | 219 | @contextlib.contextmanager 220 | def clear_mpi_env_vars(): 221 | """ 222 | from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang. 223 | This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing 224 | Processes. 225 | """ 226 | removed_environment = {} 227 | for k, v in list(os.environ.items()): 228 | for prefix in ['OMPI_', 'PMI_']: 229 | if k.startswith(prefix): 230 | removed_environment[k] = v 231 | del os.environ[k] 232 | try: 233 | yield 234 | finally: 235 | os.environ.update(removed_environment) 236 | 237 | 238 | class VecFrameStack(VecEnvWrapper): 239 | def __init__(self, venv, nstack): 240 | self.venv = venv 241 | self.nstack = nstack 242 | wos = venv.observation_space # wrapped ob space 243 | low = np.repeat(wos.low, self.nstack, axis=-1) 244 | high = np.repeat(wos.high, self.nstack, axis=-1) 245 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) 246 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) 247 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 248 | 249 | def step_wait(self): 250 | obs, rews, news, infos = self.venv.step_wait() 251 | self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1) 252 | for (i, new) in enumerate(news): 253 | if new: 254 | self.stackedobs[i] = 0 255 | self.stackedobs[..., -obs.shape[-1]:] = obs 256 | return self.stackedobs, rews, news, infos 257 | 258 | def reset(self): 259 | obs = self.venv.reset() 260 | self.stackedobs[...] 
= 0 261 | self.stackedobs[..., -obs.shape[-1]:] = obs 262 | return self.stackedobs 263 | 264 | class VecExtractDictObs(VecEnvObservationWrapper): 265 | def __init__(self, venv, key): 266 | self.key = key 267 | super().__init__(venv=venv, 268 | observation_space=venv.observation_space.spaces[self.key]) 269 | 270 | def process(self, obs): 271 | return obs[self.key] 272 | 273 | 274 | class RunningMeanStd(object): 275 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 276 | def __init__(self, epsilon=1e-4, shape=()): 277 | self.mean = np.zeros(shape, 'float64') 278 | self.var = np.ones(shape, 'float64') 279 | self.count = epsilon 280 | 281 | def update(self, x): 282 | batch_mean = np.mean(x, axis=0) 283 | batch_var = np.var(x, axis=0) 284 | batch_count = x.shape[0] 285 | self.update_from_moments(batch_mean, batch_var, batch_count) 286 | 287 | def update_from_moments(self, batch_mean, batch_var, batch_count): 288 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 289 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count) 290 | 291 | 292 | def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count): 293 | delta = batch_mean - mean 294 | tot_count = count + batch_count 295 | 296 | new_mean = mean + delta * batch_count / tot_count 297 | m_a = var * count 298 | m_b = batch_var * batch_count 299 | M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count 300 | new_var = M2 / tot_count 301 | new_count = tot_count 302 | 303 | return new_mean, new_var, new_count 304 | 305 | 306 | class VecNormalize(VecEnvWrapper): 307 | """ 308 | A vectorized wrapper that normalizes the observations 309 | and returns from an environment. 310 | """ 311 | 312 | def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8): 313 | VecEnvWrapper.__init__(self, venv) 314 | 315 | self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None 316 | self.ret_rms = RunningMeanStd(shape=()) if ret else None 317 | 318 | self.clipob = clipob 319 | self.cliprew = cliprew 320 | self.ret = np.zeros(self.num_envs) 321 | self.gamma = gamma 322 | self.epsilon = epsilon 323 | 324 | def step_wait(self): 325 | obs, rews, news, infos = self.venv.step_wait() 326 | for i in range(len(infos)): 327 | infos[i]['env_reward'] = rews[i] 328 | self.ret = self.ret * self.gamma + rews 329 | obs = self._obfilt(obs) 330 | if self.ret_rms: 331 | self.ret_rms.update(self.ret) 332 | rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew) 333 | self.ret[news] = 0. 
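        # Note: when return normalization is enabled (ret=True), `rews` above is scaled by the
        # running standard deviation of the discounted return (tracked in self.ret), not of the
        # raw rewards, matching OpenAI Baselines' VecNormalize. The unscaled reward remains
        # available in info['env_reward'], and observations are normalized separately in _obfilt().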
334 | return obs, rews, news, infos 335 | 336 | def _obfilt(self, obs): 337 | if self.ob_rms: 338 | self.ob_rms.update(obs) 339 | obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob) 340 | return obs 341 | else: 342 | return obs 343 | 344 | def reset(self): 345 | self.ret = np.zeros(self.num_envs) 346 | obs = self.venv.reset() 347 | return self._obfilt(obs) 348 | 349 | 350 | class TransposeFrame(VecEnvWrapper): 351 | def __init__(self, env): 352 | super().__init__(venv=env) 353 | obs_shape = self.observation_space.shape 354 | self.observation_space = gym.spaces.Box(low=0, high=255, shape=(obs_shape[2], obs_shape[0], obs_shape[1]), dtype=np.float32) 355 | 356 | def step_wait(self): 357 | obs, reward, done, info = self.venv.step_wait() 358 | return obs.transpose(0,3,1,2), reward, done, info 359 | 360 | def reset(self): 361 | obs = self.venv.reset() 362 | return obs.transpose(0,3,1,2) 363 | 364 | 365 | class ScaledFloatFrame(VecEnvWrapper): 366 | def __init__(self, env): 367 | super().__init__(venv=env) 368 | obs_shape = self.observation_space.shape 369 | self.observation_space = gym.spaces.Box(low=0, high=1, shape=obs_shape, dtype=np.float32) 370 | 371 | def step_wait(self): 372 | obs, reward, done, info = self.venv.step_wait() 373 | return obs/255.0, reward, done, info 374 | 375 | def reset(self): 376 | obs = self.venv.reset() 377 | return obs/255.0 378 | -------------------------------------------------------------------------------- /common/logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from collections import deque 4 | import time 5 | import csv 6 | 7 | try: 8 | import wandb 9 | except ImportError: 10 | pass 11 | 12 | class Logger(object): 13 | 14 | def __init__(self, n_envs, logdir, use_wandb=False): 15 | self.start_time = time.time() 16 | self.n_envs = n_envs 17 | self.logdir = logdir 18 | self.use_wandb = use_wandb 19 | 20 | # training 21 | self.episode_rewards = [] 22 | for _ in range(n_envs): 23 | self.episode_rewards.append([]) 24 | 25 | self.episode_timeout_buffer = deque(maxlen = 40) 26 | self.episode_len_buffer = deque(maxlen = 40) 27 | self.episode_reward_buffer = deque(maxlen = 40) 28 | 29 | # validation 30 | self.episode_rewards_v = [] 31 | for _ in range(n_envs): 32 | self.episode_rewards_v.append([]) 33 | 34 | self.episode_timeout_buffer_v = deque(maxlen = 40) 35 | self.episode_len_buffer_v = deque(maxlen = 40) 36 | self.episode_reward_buffer_v = deque(maxlen = 40) 37 | 38 | time_metrics = ["timesteps", "wall_time", "num_episodes"] # only collected once 39 | episode_metrics = ["max_episode_rewards", "mean_episode_rewards", "min_episode_rewards", 40 | "max_episode_len", "mean_episode_len", "min_episode_len", 41 | "mean_timeouts"] # collected for both train and val envs 42 | self.log = pd.DataFrame(columns = time_metrics + episode_metrics + \ 43 | ["val_"+m for m in episode_metrics]) 44 | 45 | self.timesteps = 0 46 | self.num_episodes = 0 47 | 48 | def feed(self, rew_batch, done_batch, rew_batch_v=None, done_batch_v=None): 49 | steps = rew_batch.shape[0] 50 | rew_batch = rew_batch.T 51 | done_batch = done_batch.T 52 | 53 | valid = rew_batch_v is not None and done_batch_v is not None 54 | if valid: 55 | rew_batch_v = rew_batch_v.T 56 | done_batch_v = done_batch_v.T 57 | 58 | for i in range(self.n_envs): 59 | for j in range(steps): 60 | self.episode_rewards[i].append(rew_batch[i][j]) 61 | if valid: 62 | 
self.episode_rewards_v[i].append(rew_batch_v[i][j]) 63 | 64 | if done_batch[i][j]: 65 | self.episode_timeout_buffer.append(1 if j == steps-1 else 0) 66 | self.episode_len_buffer.append(len(self.episode_rewards[i])) 67 | self.episode_reward_buffer.append(np.sum(self.episode_rewards[i])) 68 | self.episode_rewards[i] = [] 69 | self.num_episodes += 1 70 | if valid and done_batch_v[i][j]: 71 | self.episode_timeout_buffer_v.append(1 if j == steps-1 else 0) 72 | self.episode_len_buffer_v.append(len(self.episode_rewards_v[i])) 73 | self.episode_reward_buffer_v.append(np.sum(self.episode_rewards_v[i])) 74 | self.episode_rewards_v[i] = [] 75 | 76 | self.timesteps += (self.n_envs * steps) 77 | 78 | def dump(self): 79 | wall_time = time.time() - self.start_time 80 | episode_statistics = self._get_episode_statistics() 81 | episode_statistics_list = list(episode_statistics.values()) 82 | log = [self.timesteps, wall_time, self.num_episodes] + episode_statistics_list 83 | self.log.loc[len(self.log)] = log 84 | 85 | with open(self.logdir + '/log-append.csv', 'a') as f: 86 | writer = csv.writer(f) 87 | if f.tell() == 0: 88 | writer.writerow(self.log.columns) 89 | writer.writerow(log) 90 | 91 | print(self.log.loc[len(self.log)-1]) 92 | 93 | if self.use_wandb: 94 | wandb.log({k: v for k, v in zip(self.log.columns, log)}) 95 | 96 | def _get_episode_statistics(self): 97 | episode_statistics = {} 98 | episode_statistics['Rewards/max_episodes'] = np.max(self.episode_reward_buffer, initial=0) 99 | episode_statistics['Rewards/mean_episodes'] = np.mean(self.episode_reward_buffer) 100 | episode_statistics['Rewards/min_episodes'] = np.min(self.episode_reward_buffer, initial=0) 101 | episode_statistics['Len/max_episodes'] = np.max(self.episode_len_buffer, initial=0) 102 | episode_statistics['Len/mean_episodes'] = np.mean(self.episode_len_buffer) 103 | episode_statistics['Len/min_episodes'] = np.min(self.episode_len_buffer, initial=0) 104 | episode_statistics['Len/mean_timeout'] = np.mean(self.episode_timeout_buffer) 105 | 106 | # valid 107 | episode_statistics['[Valid] Rewards/max_episodes'] = np.max(self.episode_reward_buffer_v, initial=0) 108 | episode_statistics['[Valid] Rewards/mean_episodes'] = np.mean(self.episode_reward_buffer_v) 109 | episode_statistics['[Valid] Rewards/min_episodes'] = np.min(self.episode_reward_buffer_v, initial=0) 110 | episode_statistics['[Valid] Len/max_episodes'] = np.max(self.episode_len_buffer_v, initial=0) 111 | episode_statistics['[Valid] Len/mean_episodes'] = np.mean(self.episode_len_buffer_v) 112 | episode_statistics['[Valid] Len/min_episodes'] = np.min(self.episode_len_buffer_v, initial=0) 113 | episode_statistics['[Valid] Len/mean_timeout'] = np.mean(self.episode_timeout_buffer_v) 114 | return episode_statistics 115 | -------------------------------------------------------------------------------- /common/misc_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import gym 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def set_global_seeds(seed): 9 | torch.manual_seed(seed) 10 | torch.cuda.manual_seed_all(seed) 11 | torch.backends.cudnn.benchmark = False 12 | torch.backends.cudnn.deterministic = True 13 | 14 | 15 | def set_global_log_levels(level): 16 | gym.logger.set_level(level) 17 | 18 | 19 | def orthogonal_init(module, gain=nn.init.calculate_gain('relu')): 20 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 21 | nn.init.orthogonal_(module.weight.data, gain) 22 | 
nn.init.constant_(module.bias.data, 0) 23 | return module 24 | 25 | 26 | def xavier_uniform_init(module, gain=1.0): 27 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 28 | nn.init.xavier_uniform_(module.weight.data, gain) 29 | nn.init.constant_(module.bias.data, 0) 30 | return module 31 | 32 | 33 | def adjust_lr(optimizer, init_lr, timesteps, max_timesteps): 34 | lr = init_lr * (1 - (timesteps / max_timesteps)) 35 | for param_group in optimizer.param_groups: 36 | param_group['lr'] = lr 37 | return optimizer 38 | 39 | 40 | def get_n_params(model): 41 | return str(np.round(np.array([p.numel() for p in model.parameters()]).sum() / 1e6, 3)) + ' M params' -------------------------------------------------------------------------------- /common/model.py: -------------------------------------------------------------------------------- 1 | from .misc_util import orthogonal_init, xavier_uniform_init 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | class Flatten(nn.Module): 7 | def forward(self, x): 8 | return torch.flatten(x, start_dim=1) 9 | 10 | 11 | class MlpModel(nn.Module): 12 | def __init__(self, 13 | input_dims=4, 14 | hidden_dims=[64, 64], 15 | **kwargs): 16 | """ 17 | input_dim: (int) number of the input dimensions 18 | hidden_dims: (list) list of the dimensions for the hidden layers 19 | use_batchnorm: (bool) whether to use batchnorm 20 | """ 21 | super(MlpModel, self).__init__() 22 | 23 | # Hidden layers 24 | hidden_dims = [input_dims] + hidden_dims 25 | layers = [] 26 | for i in range(len(hidden_dims) - 1): 27 | in_features = hidden_dims[i] 28 | out_features = hidden_dims[i + 1] 29 | layers.append(nn.Linear(in_features, out_features)) 30 | layers.append(nn.ReLU()) 31 | self.layers = nn.Sequential(*layers) 32 | self.output_dim = hidden_dims[-1] 33 | self.apply(orthogonal_init) 34 | 35 | def forward(self, x): 36 | for layer in self.layers: 37 | x = layer(x) 38 | return x 39 | 40 | 41 | class NatureModel(nn.Module): 42 | def __init__(self, 43 | in_channels, 44 | **kwargs): 45 | """ 46 | input_shape: (tuple) tuple of the input dimension shape (channel, height, width) 47 | filters: (list) list of the tuples consists of (number of channels, kernel size, and strides) 48 | use_batchnorm: (bool) whether to use batchnorm 49 | """ 50 | super(NatureModel, self).__init__() 51 | self.layers = nn.Sequential( 52 | nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=8, stride=4), nn.ReLU(), 53 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2), nn.ReLU(), 54 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), nn.ReLU(), 55 | Flatten(), 56 | nn.Linear(in_features=64*7*7, out_features=512), nn.ReLU() 57 | ) 58 | self.output_dim = 512 59 | self.apply(orthogonal_init) 60 | 61 | def forward(self, x): 62 | x = self.layers(x) 63 | return x 64 | 65 | 66 | class ResidualBlock(nn.Module): 67 | def __init__(self, 68 | in_channels): 69 | super(ResidualBlock, self).__init__() 70 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1) 71 | self.conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1) 72 | 73 | def forward(self, x): 74 | out = nn.ReLU()(x) 75 | out = self.conv1(out) 76 | out = nn.ReLU()(out) 77 | out = self.conv2(out) 78 | return out + x 79 | 80 | class ImpalaBlock(nn.Module): 81 | def __init__(self, in_channels, out_channels): 82 | super(ImpalaBlock, self).__init__() 83 | self.conv = nn.Conv2d(in_channels=in_channels, 
out_channels=out_channels, kernel_size=3, stride=1, padding=1) 84 | self.res1 = ResidualBlock(out_channels) 85 | self.res2 = ResidualBlock(out_channels) 86 | 87 | def forward(self, x): 88 | x = self.conv(x) 89 | x = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)(x) 90 | x = self.res1(x) 91 | x = self.res2(x) 92 | return x 93 | 94 | scale = 1 95 | class ImpalaModel(nn.Module): 96 | def __init__(self, 97 | in_channels, 98 | **kwargs): 99 | super(ImpalaModel, self).__init__() 100 | self.block1 = ImpalaBlock(in_channels=in_channels, out_channels=16*scale) 101 | self.block2 = ImpalaBlock(in_channels=16*scale, out_channels=32*scale) 102 | self.block3 = ImpalaBlock(in_channels=32*scale, out_channels=32*scale) 103 | self.fc = nn.Linear(in_features=32*scale * 8 * 8, out_features=256) 104 | 105 | self.output_dim = 256 106 | self.apply(xavier_uniform_init) 107 | 108 | def forward(self, x): 109 | x = self.block1(x) 110 | x = self.block2(x) 111 | x = self.block3(x) 112 | x = nn.ReLU()(x) 113 | x = Flatten()(x) 114 | x = self.fc(x) 115 | x = nn.ReLU()(x) 116 | return x 117 | 118 | 119 | class GRU(nn.Module): 120 | def __init__(self, input_size, hidden_size): 121 | super(GRU, self).__init__() 122 | self.gru = orthogonal_init(nn.GRU(input_size, hidden_size), gain=1.0) 123 | 124 | def forward(self, x, hxs, masks): 125 | # Prediction 126 | if x.size(0) == hxs.size(0): 127 | # input for GRU-CELL: (L=sequence_length, N, H) 128 | # output for GRU-CELL: (output: (L, N, H), hidden: (L, N, H)) 129 | masks = masks.unsqueeze(-1) 130 | x, hxs = self.gru(x.unsqueeze(0), (hxs * masks).unsqueeze(0)) 131 | x = x.squeeze(0) 132 | hxs = hxs.squeeze(0) 133 | # Training 134 | # We will recompute the hidden state to allow gradient to be back-propagated through time 135 | else: 136 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1) 137 | N = hxs.size(0) 138 | T = int(x.size(0) / N) 139 | 140 | # unflatten 141 | x = x.view(T, N, x.size(1)) 142 | 143 | # Same deal with masks 144 | masks = masks.view(T, N) 145 | 146 | # Let's figure out which steps in the sequence have a zero for any agent 147 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 148 | # (can be interpreted as a truncated back-propagation through time) 149 | has_zeros = ((masks[1:] == 0.0) \ 150 | .any(dim=-1) 151 | .nonzero() 152 | .squeeze() 153 | .cpu()) 154 | 155 | # +1 to correct the masks[1:] 156 | if has_zeros.dim() == 0: 157 | # Deal with scalar 158 | has_zeros = [has_zeros.item() + 1] 159 | else: 160 | has_zeros = (has_zeros + 1).numpy().tolist() 161 | 162 | # add t=0 and t=T to the list 163 | has_zeros = [0] + has_zeros + [T] 164 | 165 | hxs = hxs.unsqueeze(0) 166 | outputs = [] 167 | for i in range(len(has_zeros) - 1): 168 | # We can now process steps that don't have any zeros in masks together! 
169 | # This is much faster 170 | start_idx = has_zeros[i] 171 | end_idx = has_zeros[i + 1] 172 | 173 | rnn_scores, hxs = self.gru( 174 | x[start_idx:end_idx], 175 | hxs * masks[start_idx].view(1, -1, 1)) 176 | 177 | outputs.append(rnn_scores) 178 | 179 | # assert len(outputs) == T 180 | # x is a (T, N, -1) tensor 181 | x = torch.cat(outputs, dim=0) 182 | # flatten 183 | x = x.view(T * N, -1) 184 | hxs = hxs.squeeze(0) 185 | 186 | return x, hxs 187 | -------------------------------------------------------------------------------- /common/policy.py: -------------------------------------------------------------------------------- 1 | from .misc_util import orthogonal_init 2 | from .model import GRU 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.distributions import Categorical, Normal 6 | 7 | class CategoricalPolicy(nn.Module): 8 | def __init__(self, 9 | embedder, 10 | recurrent, 11 | action_size): 12 | """ 13 | embedder: (torch.Tensor) model to extract the embedding for observation 14 | action_size: number of the categorical actions 15 | """ 16 | super(CategoricalPolicy, self).__init__() 17 | self.embedder = embedder 18 | # small scale weight-initialization in policy enhances the stability 19 | self.fc_policy = orthogonal_init(nn.Linear(self.embedder.output_dim, action_size), gain=0.01) 20 | self.fc_value = orthogonal_init(nn.Linear(self.embedder.output_dim, 1), gain=1.0) 21 | 22 | self.recurrent = recurrent 23 | if self.recurrent: 24 | self.gru = GRU(self.embedder.output_dim, self.embedder.output_dim) 25 | 26 | def is_recurrent(self): 27 | return self.recurrent 28 | 29 | def forward(self, x, hx, masks): 30 | hidden = self.embedder(x) 31 | if self.recurrent: 32 | hidden, hx = self.gru(hidden, hx, masks) 33 | logits = self.fc_policy(hidden) 34 | log_probs = F.log_softmax(logits, dim=1) 35 | p = Categorical(logits=log_probs) 36 | v = self.fc_value(hidden).reshape(-1) 37 | return p, v, hx 38 | -------------------------------------------------------------------------------- /common/storage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 3 | import numpy as np 4 | from collections import deque 5 | 6 | class Storage(): 7 | 8 | def __init__(self, obs_shape, hidden_state_size, num_steps, num_envs, device): 9 | self.obs_shape = obs_shape 10 | self.hidden_state_size = hidden_state_size 11 | self.num_steps = num_steps 12 | self.num_envs = num_envs 13 | self.device = device 14 | self.reset() 15 | 16 | def reset(self): 17 | self.obs_batch = torch.zeros(self.num_steps+1, self.num_envs, *self.obs_shape) 18 | self.hidden_states_batch = torch.zeros(self.num_steps+1, self.num_envs, self.hidden_state_size) 19 | self.act_batch = torch.zeros(self.num_steps, self.num_envs) 20 | self.rew_batch = torch.zeros(self.num_steps, self.num_envs) 21 | self.done_batch = torch.zeros(self.num_steps, self.num_envs) 22 | self.log_prob_act_batch = torch.zeros(self.num_steps, self.num_envs) 23 | self.value_batch = torch.zeros(self.num_steps+1, self.num_envs) 24 | self.return_batch = torch.zeros(self.num_steps, self.num_envs) 25 | self.adv_batch = torch.zeros(self.num_steps, self.num_envs) 26 | self.info_batch = deque(maxlen=self.num_steps) 27 | self.step = 0 28 | 29 | def store(self, obs, hidden_state, act, rew, done, info, log_prob_act, value): 30 | self.obs_batch[self.step] = torch.from_numpy(obs.copy()) 31 | self.hidden_states_batch[self.step] = 
torch.from_numpy(hidden_state.copy()) 32 | self.act_batch[self.step] = torch.from_numpy(act.copy()) 33 | self.rew_batch[self.step] = torch.from_numpy(rew.copy()) 34 | self.done_batch[self.step] = torch.from_numpy(done.copy()) 35 | self.log_prob_act_batch[self.step] = torch.from_numpy(log_prob_act.copy()) 36 | self.value_batch[self.step] = torch.from_numpy(value.copy()) 37 | self.info_batch.append(info) 38 | 39 | self.step = (self.step + 1) % self.num_steps 40 | 41 | def store_last(self, last_obs, last_hidden_state, last_value): 42 | self.obs_batch[-1] = torch.from_numpy(last_obs.copy()) 43 | self.hidden_states_batch[-1] = torch.from_numpy(last_hidden_state.copy()) 44 | self.value_batch[-1] = torch.from_numpy(last_value.copy()) 45 | 46 | def compute_estimates(self, gamma=0.99, lmbda=0.95, use_gae=True, normalize_adv=True): 47 | rew_batch = self.rew_batch 48 | if use_gae: 49 | A = 0 50 | for i in reversed(range(self.num_steps)): 51 | rew = rew_batch[i] 52 | done = self.done_batch[i] 53 | value = self.value_batch[i] 54 | next_value = self.value_batch[i+1] 55 | 56 | delta = (rew + gamma * next_value * (1 - done)) - value 57 | self.adv_batch[i] = A = gamma * lmbda * A * (1 - done) + delta 58 | else: 59 | G = self.value_batch[-1] 60 | for i in reversed(range(self.num_steps)): 61 | rew = rew_batch[i] 62 | done = self.done_batch[i] 63 | 64 | G = rew + gamma * G * (1 - done) 65 | self.return_batch[i] = G 66 | 67 | self.return_batch = self.adv_batch + self.value_batch[:-1] 68 | if normalize_adv: 69 | self.adv_batch = (self.adv_batch - torch.mean(self.adv_batch)) / (torch.std(self.adv_batch) + 1e-8) 70 | 71 | 72 | def fetch_train_generator(self, mini_batch_size=None, recurrent=False): 73 | batch_size = self.num_steps * self.num_envs 74 | if mini_batch_size is None: 75 | mini_batch_size = batch_size 76 | # If agent's policy is not recurrent, data could be sampled without considering the time-horizon 77 | if not recurrent: 78 | sampler = BatchSampler(SubsetRandomSampler(range(batch_size)), 79 | mini_batch_size, 80 | drop_last=True) 81 | for indices in sampler: 82 | obs_batch = torch.FloatTensor(self.obs_batch[:-1]).reshape(-1, *self.obs_shape)[indices].to(self.device) 83 | hidden_state_batch = torch.FloatTensor(self.hidden_states_batch[:-1]).reshape(-1, self.hidden_state_size).to(self.device) 84 | act_batch = torch.FloatTensor(self.act_batch).reshape(-1)[indices].to(self.device) 85 | done_batch = torch.FloatTensor(self.done_batch).reshape(-1)[indices].to(self.device) 86 | log_prob_act_batch = torch.FloatTensor(self.log_prob_act_batch).reshape(-1)[indices].to(self.device) 87 | value_batch = torch.FloatTensor(self.value_batch[:-1]).reshape(-1)[indices].to(self.device) 88 | return_batch = torch.FloatTensor(self.return_batch).reshape(-1)[indices].to(self.device) 89 | adv_batch = torch.FloatTensor(self.adv_batch).reshape(-1)[indices].to(self.device) 90 | yield obs_batch, hidden_state_batch, act_batch, done_batch, log_prob_act_batch, value_batch, return_batch, adv_batch 91 | # If agent's policy is recurrent, data should be sampled along the time-horizon 92 | else: 93 | num_mini_batch_per_epoch = batch_size // mini_batch_size 94 | num_envs_per_batch = self.num_envs // num_mini_batch_per_epoch 95 | perm = torch.randperm(self.num_envs) 96 | for start_ind in range(0, self.num_envs, num_envs_per_batch): 97 | idxes = perm[start_ind:start_ind+num_envs_per_batch] 98 | obs_batch = torch.FloatTensor(self.obs_batch[:-1, idxes]).reshape(-1, *self.obs_shape).to(self.device) 99 | # [0:1] instead of [0] to keep 
two-dimensional array 100 | hidden_state_batch = torch.FloatTensor(self.hidden_states_batch[0:1, idxes]).reshape(-1, self.hidden_state_size).to(self.device) 101 | act_batch = torch.FloatTensor(self.act_batch[:, idxes]).reshape(-1).to(self.device) 102 | done_batch = torch.FloatTensor(self.done_batch[:, idxes]).reshape(-1).to(self.device) 103 | log_prob_act_batch = torch.FloatTensor(self.log_prob_act_batch[:, idxes]).reshape(-1).to(self.device) 104 | value_batch = torch.FloatTensor(self.value_batch[:-1, idxes]).reshape(-1).to(self.device) 105 | return_batch = torch.FloatTensor(self.return_batch[:, idxes]).reshape(-1).to(self.device) 106 | adv_batch = torch.FloatTensor(self.adv_batch[:, idxes]).reshape(-1).to(self.device) 107 | yield obs_batch, hidden_state_batch, act_batch, done_batch, log_prob_act_batch, value_batch, return_batch, adv_batch 108 | 109 | def fetch_log_data(self): 110 | if 'env_reward' in self.info_batch[0][0]: 111 | rew_batch = [] 112 | for step in range(self.num_steps): 113 | infos = self.info_batch[step] 114 | rew_batch.append([info['env_reward'] for info in infos]) 115 | rew_batch = np.array(rew_batch) 116 | else: 117 | rew_batch = self.rew_batch.numpy() 118 | if 'env_done' in self.info_batch[0][0]: 119 | done_batch = [] 120 | for step in range(self.num_steps): 121 | infos = self.info_batch[step] 122 | done_batch.append([info['env_done'] for info in infos]) 123 | done_batch = np.array(done_batch) 124 | else: 125 | done_batch = self.done_batch.numpy() 126 | return rew_batch, done_batch 127 | -------------------------------------------------------------------------------- /compute_metrics.py: -------------------------------------------------------------------------------- 1 | from torch._C import Value 2 | from test import load_env_and_agent, run 3 | import argparse 4 | from common import set_global_seeds, set_global_log_levels 5 | import numpy as np 6 | from pathlib import Path 7 | import csv 8 | 9 | if __name__=='__main__': 10 | raise NotImplementedError("I made changes to test.py, so now this script needs to be overhauled") 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--random_percent_model_dir', type=str, default=None, 15 | help="directory with saved coinrun models trained on " 16 | "environments with coin position randomized " 17 | "0, 1, 2, 5, and 10 percent of the time.") 18 | 19 | parser.add_argument('--num_levels_model_dir', type=str, default=None, 20 | help="directory with saved coinrun models trained on " 21 | "environments with different numbers of " 22 | "distinct levels.") 23 | 24 | parser.add_argument('--results_dir', type=str, default=None) 25 | 26 | parser.add_argument('--num_timesteps', type=int, default = 10_000) 27 | parser.add_argument('--exp_name', type=str, default = 'compute_metrics', help='experiment name') 28 | parser.add_argument('--start_level', type=int, default = np.random.randint(0, 10**9), help='start-level for environment') 29 | parser.add_argument('--distribution_mode',type=str, default = 'hard', help='distribution mode for environment') 30 | parser.add_argument('--param_name', type=str, default = 'hard', help='hyper-parameter ID') 31 | parser.add_argument('--device', type=str, default = 'cpu', required = False, help='whether to use gpu') 32 | parser.add_argument('--gpu_device', type=int, default = int(0), required = False, help = 'visible device in CUDA') 33 | parser.add_argument('--seed', type=int, default = np.random.randint(0,9999), help='Random generator seed') 34 | parser.add_argument('--log_level', 
type=int, default = int(40), help='[10,20,30,40]') 35 | parser.add_argument('--logdir', type=str, default = None) 36 | 37 | parser.add_argument('--num_threads', type=int, default=8) 38 | parser.add_argument('--num_envs', type=int, default=1) 39 | 40 | args = parser.parse_args() 41 | 42 | set_global_seeds(args.seed) 43 | set_global_log_levels(args.log_level) 44 | 45 | 46 | # deploy saved models in --random_percent_model_dir and compute 47 | # how often the models navigate to the end of the level instead of getting 48 | # the coin 49 | 50 | def random_percent_ablation(): 51 | def get_agent_path(random_percent): 52 | """return path of saved agent trained on env with coin randomized 53 | random_percent of the time""" 54 | model_path = Path(args.random_percent_model_dir) 55 | model_path = model_path / f"random_percent_{random_percent}" / "model_200015872.pth" 56 | return model_path 57 | 58 | if args.results_dir: 59 | results_dir = Path(args.results_dir) 60 | else: 61 | results_dir = Path(args.random_percent_model_dir) 62 | 63 | for random_percent in [0, 1, 2, 5, 10]: 64 | model_path = get_agent_path(random_percent) 65 | print(f"Loading agent trained on distribution random_percent_{random_percent}") 66 | print(f"Loading from {model_path}...") 67 | print() 68 | 69 | agent = load_env_and_agent(exp_name=args.exp_name, 70 | env_name="coinrun", 71 | num_envs=args.num_envs, 72 | logdir=args.logdir, 73 | model_file=model_path, 74 | start_level=args.start_level, 75 | num_levels=0, # this means start_level is meaningless (level seeds are drawn randomly) 76 | distribution_mode=args.distribution_mode, 77 | param_name=args.param_name, 78 | device=args.device, 79 | gpu_device=args.gpu_device, 80 | seed=args.seed, 81 | num_checkpoints=0, 82 | random_percent=100, 83 | num_threads=args.num_threads) 84 | 85 | print() 86 | print("Running...") 87 | results = run(agent, args.num_timesteps, args.logdir) 88 | results.update({"random_percent": random_percent}) 89 | 90 | results_file = str(results_dir / "results.csv") 91 | print() 92 | print(f"Saving results to {results_file}") 93 | print() 94 | # write results to csv 95 | if random_percent == 0: 96 | with open(results_file, "w") as f: 97 | w = csv.DictWriter(f, results.keys()) 98 | w.writeheader() 99 | w.writerow(results) 100 | else: 101 | with open(results_file, "a") as f: 102 | w = csv.DictWriter(f, results.keys()) 103 | w.writerow(results) 104 | 105 | 106 | def num_levels_ablation(): 107 | def get_agent_path(num_levels): 108 | model_path = Path(args.num_levels_model_dir) 109 | model_path = model_path / f"nr_levels_{num_levels}" / "model_200015872.pth" 110 | return model_path 111 | 112 | if args.results_dir: 113 | results_dir = Path(args.results_dir) 114 | else: 115 | results_dir = Path(args.num_levels_model_dir) 116 | 117 | for num_levels in [100, 316, 1000, 3160, 10_000, 31_600, 100_000]: 118 | model_path = get_agent_path(num_levels) 119 | print(f"Loading agent trained on distribution nr_levels_{num_levels}") 120 | print(f"Loading from {model_path}...") 121 | print() 122 | 123 | agent = load_env_and_agent(exp_name=args.exp_name, 124 | env_name="coinrun", 125 | num_envs=args.num_envs, 126 | logdir=args.logdir, 127 | model_file=model_path, 128 | start_level=args.start_level, 129 | num_levels=0, 130 | distribution_mode=args.distribution_mode, 131 | param_name=args.param_name, 132 | device=args.device, 133 | gpu_device=args.gpu_device, 134 | seed=args.seed, 135 | num_checkpoints=0, 136 | random_percent=100, 137 | num_threads=args.num_threads) 138 | 139 | print() 140 | 
print("Running...") 141 | 142 | results = run(agent, args.num_timesteps, args.logdir) 143 | results.update({"num_levels": num_levels}) 144 | results_file = str(results_dir / "results.csv") 145 | 146 | print() 147 | print(f"Saving results to {results_file}") 148 | print() 149 | 150 | # write results to csv 151 | if num_levels == 100: 152 | with open(results_file, "w") as f: 153 | w = csv.DictWriter(f, results.keys()) 154 | w.writeheader() 155 | w.writerow(results) 156 | else: 157 | with open(results_file, "a") as f: 158 | w = csv.DictWriter(f, results.keys()) 159 | w.writerow(results) 160 | 161 | if args.random_percent_model_dir: 162 | random_percent_ablation() 163 | 164 | if args.num_levels_model_dir: 165 | num_levels_ablation() 166 | 167 | 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | results_dir = "experiments/results/" 4 | on_cluster = os.environ["HOME"] == "/home/lsl38" 5 | -------------------------------------------------------------------------------- /experiments/scripts/get-fig2-data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -o errexit 3 | 4 | # gather metrics for figure 2 5 | # load trained coinrun models and deploy them in the test environment. 6 | # to do this without specifying the model_file every time, trained coinrun 7 | # models must be stored in logs with exp_name 'freq-sweep-random-percent-$random_percent' 8 | # write output metrics to csv files in ./experiments/results/ 9 | 10 | num_seeds=10000 11 | #num_seeds=10 12 | 13 | random_percent=$SLURM_ARRAY_TASK_ID 14 | 15 | if [[ $1 = 'standard' ]] 16 | then 17 | python run_coinrun.py --model_file $random_percent --start_level_seed 0 --num_seeds $num_seeds --random_percent 100 18 | elif [[ $1 = 'joint' ]] 19 | then 20 | python run_coinrun.py --model_file $random_percent --start_level_seed 0 --num_seeds $num_seeds --random_percent $random_percent --reset_mode "complete" 21 | fi 22 | -------------------------------------------------------------------------------- /experiments/scripts/plot-figure2.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import numpy as np 4 | 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | 8 | sns.set() 9 | 10 | # Change this to the folder where the output of get-fig2-data.sh is stored 11 | results_dir = "/home/lauro/projects/aisc2021/train-procgen-pytorch/experiments/hpc-results/" 12 | vanilla_coinrun_resdir = results_dir + "test_rand_percent_0/train_rand_percent_0/" 13 | van_df = pd.read_csv(vanilla_coinrun_resdir + "metrics.csv") 14 | max_collect_freq = 0.1 15 | 16 | 17 | def listdir(path): 18 | return [os.path.join(path, d) for d in os.listdir(path)] 19 | 20 | 21 | def path_to_rand_percent(path): 22 | """extract integer rand_percent from path info 23 | path must end with integer rand_percent""" 24 | out = path[-3:] 25 | while not out.isdigit(): 26 | out = out[1:] 27 | return int(out) 28 | 29 | 30 | def seed_collect_freq(seed): 31 | """returns frequency at which agent collects the inv coin 32 | in the vanilla environment, for env seed""" 33 | idx = van_df['seed'] == seed 34 | return np.mean(van_df.loc[idx]["inv_coin_collected"]) 35 | 36 | 37 | collect_freqs = list(map(seed_collect_freq, range(np.max(van_df['seed'])))) 38 | collect_freqs = np.array(collect_freqs) 39 | 
(good_seeds,) = np.nonzero(collect_freqs < max_collect_freq) 40 | 41 | 42 | def get_good_seed_df(df): 43 | """ 44 | given a dataframe with column 'seed', return new dataframe that is a subset 45 | of the old one, and includes only rows with good seeds. 46 | """ 47 | good_seed_idx = [seed in good_seeds for seed in df['seed']] 48 | return df.loc[good_seed_idx] 49 | 50 | 51 | test_rp100_resdir = os.path.join(results_dir, "test_rand_percent_100") 52 | rand_percents = [path_to_rand_percent(file) for file in os.listdir(results_dir) if file.startswith("test_rand")] 53 | rand_percents.sort() 54 | joint_rp_paths = [os.path.join(results_dir, f"test_rand_percent_{rp}", f"train_rand_percent_{rp}") for rp in rand_percents[:-1]] 55 | 56 | 57 | # sweep over training rand_percent 58 | csv_files = {path_to_rand_percent(path): os.path.join(path, "metrics.csv") for path in listdir(test_rp100_resdir)} 59 | dataframes = {k: pd.read_csv(v, on_bad_lines="skip") for k, v in csv_files.items()} 60 | dataframes = {k: get_good_seed_df(df) for k, df in dataframes.items()} 61 | reach_end_freqs = {k: np.mean(df["inv_coin_collected"]) for k, df in dataframes.items()} 62 | 63 | data = list(reach_end_freqs.items()) 64 | data.sort() 65 | data = np.array(data) 66 | 67 | 68 | # sweep over training & test rand_percent jointly 69 | # measure how often model dies or times out, ie not gets coin 70 | csv_files = {path_to_rand_percent(path): os.path.join(path, "metrics.csv") for path in joint_rp_paths} 71 | dataframes = {k: pd.read_csv(v) for k, v in csv_files.items()} 72 | dataframes = {k: get_good_seed_df(df) for k, df in dataframes.items()} 73 | fail_to_get_coin_freq = {k: 1 - np.mean(df["coin_collected"]) for k, df in dataframes.items()} 74 | 75 | joint_data = list(fail_to_get_coin_freq.items()) 76 | joint_data.sort() 77 | joint_data = np.array(joint_data) 78 | 79 | 80 | baseline_vanilla_df = get_good_seed_df(van_df) 81 | prob_of_reaching_end_without_inv_coin = np.mean(baseline_vanilla_df["coin_collected"] == 1) 82 | 83 | 84 | figpath = "./" 85 | 86 | fig, ax = plt.subplots(figsize=[6, 2.5]) 87 | plt.axhline(y=prob_of_reaching_end_without_inv_coin * 100, linestyle="--", color="tab:grey", label="Maximum possible OR frequency") 88 | 89 | x, y = joint_data.T 90 | ax.plot(x, y*100, "--o", label="IID Robustness Failure", color="tab:orange") 91 | 92 | x, y = data.T 93 | ax.plot(x, y*100, "--o", label="Objective Robustness Failure", color="tab:blue") 94 | 95 | # plt.ylim(50, 101) 96 | plt.ylabel("Frequency (%)") 97 | plt.xlabel("Probability (%) of a level with randomized coin.") 98 | plt.legend() 99 | 100 | plt.savefig(figpath + "coinrun_freq.pdf") 101 | -------------------------------------------------------------------------------- /experiments/scripts/train-coinrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -o errexit 3 | 4 | # This script trains on the coinrun environment, with the coin 5 | # placed randomly $random_percent % of the time, and otherwise 6 | # placed at the end of the level. 
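# Example invocation (an assumption, for a SLURM cluster; adjust to your setup):
#   sbatch --array=0,1,2,5,10 experiments/scripts/train-coinrun.sh
# Each array task reads its coin-randomization percentage from
# $SLURM_ARRAY_TASK_ID below.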
7 | 8 | random_percent=$SLURM_ARRAY_TASK_ID 9 | experiment_name="freq-sweep-random-percent-$random_percent" 10 | 11 | let n_steps=80*10**6 12 | n_checkpoints=4 13 | n_threads=32 # 32 CPUs per GPU 14 | wandb_tags="hpc large-model" 15 | 16 | 17 | coinrun_opt="--env_name coinrun --random_percent $random_percent --val_env_name coinrun_aisc" # coinrun_aisc ignores random_percent arg 18 | 19 | # include the option 20 | # --model_file auto 21 | # when resuming a run from a saved checkpoint. This should load the latest model 22 | # saved under $experiment_name 23 | 24 | options=" 25 | $coinrun_opt 26 | --use_wandb 27 | --param_name A100 28 | --distribution_mode hard 29 | --num_timesteps $n_steps 30 | --num_checkpoints $n_checkpoints 31 | --num_threads $n_threads 32 | --wandb_tags $wandb_tags 33 | --exp_name $experiment_name 34 | " 35 | 36 | 37 | python train.py $options 38 | -------------------------------------------------------------------------------- /experiments/scripts/train-keys-and-chests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -o errexit 3 | 4 | # this script trains on the heist_aisc_many_chests environment 5 | # applying a small penalty for picking up keys. 6 | 7 | experiment_name="key-penalty" 8 | key_penalty=3 9 | 10 | let n_steps=80*10**6 11 | n_checkpoints=4 12 | n_threads=32 # 32 CPUs per GPU 13 | wandb_tags="hpc large-model" 14 | 15 | keys_and_chests_opt="--env_name heist_aisc_many_chests --val_env_name heist_aisc_many_keys --key_penalty $key_penalty" 16 | 17 | 18 | # include the option 19 | # --model_file auto 20 | # when resuming a run from a saved checkpoint. This should load the latest model 21 | # saved under $experiment_name 22 | 23 | options=" 24 | $keys_and_chests_opt 25 | --param_name A100 26 | --use_wandb 27 | --distribution_mode hard 28 | --num_timesteps $n_steps 29 | --num_checkpoints $n_checkpoints 30 | --num_threads $n_threads 31 | --wandb_tags $wandb_tags 32 | --exp_name $experiment_name 33 | " 34 | 35 | python train.py $options 36 | -------------------------------------------------------------------------------- /experiments/scripts/train-maze-I.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -o errexit 3 | 4 | # This script trains on maze_aisc, where goal position 5 | # is randomized within a region of size --rand-region 6 | 7 | 8 | rand_region=$SLURM_ARRAY_TASK_ID 9 | experiment_name="maze-I-sweep-rand-region-$rand_region" 10 | 11 | let n_steps=80*10**6 12 | n_checkpoints=4 13 | n_threads=32 # 32 CPUs per GPU 14 | 15 | wandb_tags="hpc large-model rand-region-sweep" 16 | export WANDB_RUN_ID="maze-sweep-rand-region-$rand_region" 17 | 18 | 19 | maze_opt="--env_name maze_aisc --rand_region $rand_region --val_env_name maze" # coinrun_aisc ignores random_percent arg 20 | 21 | # include the option 22 | # --model_file auto 23 | # when resuming a run from a saved checkpoint. 
This should load the latest model 24 | # saved under $experiment_name 25 | 26 | options=" 27 | $maze_opt 28 | --use_wandb 29 | --param_name A100 30 | --distribution_mode hard 31 | --num_timesteps $n_steps 32 | --num_checkpoints $n_checkpoints 33 | --num_threads $n_threads 34 | --wandb_tags $wandb_tags 35 | --exp_name $experiment_name 36 | " 37 | 38 | 39 | python train.py $options 40 | -------------------------------------------------------------------------------- /hyperparams/procgen/config.yml: -------------------------------------------------------------------------------- 1 | debug: 2 | algo: ppo 3 | n_envs: 2 4 | n_steps: 64 5 | epoch: 1 6 | mini_batch_per_epoch: 4 7 | mini_batch_size: 512 8 | gamma: 0.999 9 | lmbda: 0.95 10 | learning_rate: 0.0005 11 | grad_clip_norm: 0.5 12 | eps_clip: 0.2 13 | value_coef: 0.5 14 | entropy_coef: 0.01 15 | normalize_adv: True 16 | normalize_rew: True 17 | use_gae: True 18 | architecture: impala 19 | recurrent: False 20 | 21 | easy: 22 | algo: ppo 23 | n_envs: 64 24 | n_steps: 256 25 | epoch: 3 26 | mini_batch_per_epoch: 8 27 | mini_batch_size: 2048 28 | gamma: 0.999 29 | lmbda: 0.95 30 | learning_rate: 0.0005 31 | grad_clip_norm: 0.5 32 | eps_clip: 0.2 33 | value_coef: 0.5 34 | entropy_coef: 0.01 35 | normalize_adv: True 36 | normalize_rew: True 37 | use_gae: True 38 | architecture: impala 39 | recurrent: False 40 | 41 | easy-200: 42 | algo: ppo 43 | n_envs: 128 44 | n_steps: 256 45 | epoch: 3 46 | mini_batch_per_epoch: 8 47 | mini_batch_size: 2048 48 | gamma: 0.999 49 | lmbda: 0.95 50 | learning_rate: 0.0005 51 | grad_clip_norm: 0.5 52 | eps_clip: 0.2 53 | value_coef: 0.5 54 | entropy_coef: 0.01 55 | normalize_adv: True 56 | normalize_rew: True 57 | use_gae: True 58 | architecture: impala 59 | recurrent: False 60 | 61 | hard: 62 | algo: ppo 63 | n_envs: 128 64 | n_steps: 256 65 | epoch: 3 66 | mini_batch_per_epoch: 8 67 | mini_batch_size: 4096 68 | gamma: 0.999 69 | lmbda: 0.95 70 | learning_rate: 0.0005 71 | grad_clip_norm: 0.5 72 | eps_clip: 0.2 73 | value_coef: 0.5 74 | entropy_coef: 0.01 75 | normalize_adv: True 76 | normalize_rew: True 77 | use_gae: True 78 | architecture: impala 79 | recurrent: False 80 | 81 | hard-500: 82 | algo: ppo 83 | n_envs: 256 84 | n_steps: 256 85 | epoch: 3 86 | mini_batch_per_epoch: 8 87 | mini_batch_size: 8192 88 | gamma: 0.999 89 | lmbda: 0.95 90 | learning_rate: 0.0005 91 | grad_clip_norm: 0.5 92 | eps_clip: 0.2 93 | value_coef: 0.5 94 | entropy_coef: 0.01 95 | normalize_adv: True 96 | normalize_rew: True 97 | use_gae: True 98 | architecture: impala 99 | recurrent: False 100 | 101 | hard-500-mem: 102 | algo: ppo 103 | n_envs: 256 104 | n_steps: 256 105 | epoch: 3 106 | mini_batch_per_epoch: 8 107 | mini_batch_size: 8192 108 | gamma: 0.999 109 | lmbda: 0.95 110 | learning_rate: 0.0005 111 | grad_clip_norm: 0.5 112 | eps_clip: 0.2 113 | value_coef: 0.5 114 | entropy_coef: 0.01 115 | normalize_adv: True 116 | normalize_rew: True 117 | use_gae: True 118 | architecture: impala 119 | recurrent: False 120 | 121 | hard-rec: 122 | algo: ppo 123 | n_envs: 256 124 | n_steps: 256 125 | epoch: 3 126 | mini_batch_per_epoch: 8 127 | mini_batch_size: 8192 128 | gamma: 0.999 129 | lmbda: 0.95 130 | learning_rate: 0.0005 131 | grad_clip_norm: 0.5 132 | eps_clip: 0.2 133 | value_coef: 0.5 134 | entropy_coef: 0.01 135 | normalize_adv: True 136 | normalize_rew: True 137 | use_gae: True 138 | architecture: impala 139 | recurrent: True 140 | 141 | hard-local-dev: 142 | algo: ppo 143 | n_envs: 16 144 | n_steps: 256 145 | epoch: 3 146 | 
mini_batch_per_epoch: 8 147 | mini_batch_size: 8192 148 | gamma: 0.999 149 | lmbda: 0.95 150 | learning_rate: 0.0005 151 | grad_clip_norm: 0.5 152 | eps_clip: 0.2 153 | value_coef: 0.5 154 | entropy_coef: 0.01 155 | normalize_adv: True 156 | normalize_rew: True 157 | use_gae: True 158 | architecture: impala 159 | recurrent: False 160 | 161 | hard-local-dev-rec: 162 | algo: ppo 163 | n_envs: 16 164 | n_steps: 256 165 | epoch: 3 166 | mini_batch_per_epoch: 8 167 | mini_batch_size: 8192 168 | gamma: 0.999 169 | lmbda: 0.95 170 | learning_rate: 0.0005 171 | grad_clip_norm: 0.5 172 | eps_clip: 0.2 173 | value_coef: 0.5 174 | entropy_coef: 0.01 175 | normalize_adv: True 176 | normalize_rew: True 177 | use_gae: True 178 | architecture: impala 179 | recurrent: True 180 | 181 | A100: 182 | algo: ppo 183 | n_envs: 512 184 | n_steps: 256 185 | epoch: 3 186 | mini_batch_per_epoch: 16 187 | mini_batch_size: 32768 # 32768 # this is just a maximum 188 | gamma: 0.999 189 | lmbda: 0.95 190 | learning_rate: 0.0005 # should make larger? 191 | grad_clip_norm: 0.5 192 | eps_clip: 0.2 193 | value_coef: 0.5 194 | entropy_coef: 0.01 195 | normalize_adv: True 196 | normalize_rew: True 197 | use_gae: True 198 | architecture: impala 199 | recurrent: False 200 | 201 | 202 | A100-large: # for larger model (16x params) 203 | algo: ppo 204 | n_envs: 512 205 | n_steps: 256 206 | epoch: 3 207 | mini_batch_per_epoch: 16 208 | mini_batch_size: 2048 # vary this param to adjust for memory 209 | gamma: 0.999 210 | lmbda: 0.95 211 | learning_rate: 0.0005 # scale by 1 / sqrt(channel_scale) 212 | grad_clip_norm: 0.5 213 | eps_clip: 0.2 214 | value_coef: 0.5 215 | entropy_coef: 0.01 216 | normalize_adv: True 217 | normalize_rew: True 218 | use_gae: True 219 | architecture: impala 220 | recurrent: False 221 | -------------------------------------------------------------------------------- /plot_training_csv.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import argparse 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser( 9 | description='args for plotting') 10 | parser.add_argument( 11 | '--datapath', type=str) 12 | args = parser.parse_args() 13 | return args 14 | 15 | def plot(): 16 | args = parse_args() 17 | 18 | data = pd.read_csv(args.datapath + "/log.csv") 19 | cols = list(data.columns) 20 | 21 | # Get rid of columns we don't plot 22 | cols.remove('wall_time') 23 | cols.remove('num_episodes') 24 | cols.remove('num_episodes_val') 25 | cols.remove('max_episode_len') 26 | cols.remove('min_episode_len') 27 | cols.remove('val_max_episode_len') 28 | cols.remove('val_min_episode_len') 29 | cols.remove('max_episode_rewards') 30 | cols.remove('min_episode_rewards') 31 | cols.remove('val_max_episode_rewards') 32 | cols.remove('val_min_episode_rewards') 33 | cols.remove('timesteps') 34 | 35 | # Plot episode len with some preprocessing, then remove 36 | plt.subplots(figsize=[10, 7]) 37 | for name in ['mean_episode_len', 'val_mean_episode_len']: 38 | plt.plot(data['timesteps'], data[name]/100, label=name+"/100", 39 | alpha=0.8) 40 | cols.remove('mean_episode_len') 41 | cols.remove('val_mean_episode_len') 42 | 43 | # Plot the rest 44 | for name in cols: 45 | plt.plot(data['timesteps'], data[name], label=name) 46 | plt.xlabel('timesteps') 47 | plt.legend() 48 | plt.savefig(args.datapath + "/plot.png") 49 | 50 | if __name__ == '__main__': 51 | plot() 
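# Optional plotting helper (a sketch, not part of the original script): the raw
# curves in log.csv are noisy at logging resolution, so a rolling-mean overlay
# can make trends easier to read. The column name 'mean_episode_rewards' is an
# assumption; 'timesteps' matches the column used by plot() above.
import pandas as pd
import matplotlib.pyplot as plt

def plot_smoothed(csv_path, column='mean_episode_rewards', window=10):
    """Overlay a rolling mean on a raw training-log curve."""
    data = pd.read_csv(csv_path)
    smoothed = data[column].rolling(window, min_periods=1).mean()
    plt.plot(data['timesteps'], data[column], alpha=0.3, label=column + ' (raw)')
    plt.plot(data['timesteps'], smoothed, label=column + ' (rolling mean)')
    plt.xlabel('timesteps')
    plt.legend()
    plt.show()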
-------------------------------------------------------------------------------- /plot_value_coin_barchart.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import argparse 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser( 9 | description='args for plotting') 10 | parser.add_argument( 11 | '--datapath', type=str) 12 | args = parser.parse_args() 13 | return args 14 | 15 | def plot(): 16 | args = parse_args() 17 | 18 | num_evaluated_ims = 1001 19 | 20 | metadata = pd.read_csv(args.datapath + "/obs_metadata.csv") 21 | cols = list(metadata.columns) 22 | metadata = metadata[0:num_evaluated_ims] 23 | 24 | # Add value data 25 | value_files = os.listdir(args.datapath) 26 | value_file_cond = lambda file: ('val' in file) and ('.npy' in file) 27 | value_files = [file for file in value_files if value_file_cond(file)] 28 | value_files = sorted(value_files) 29 | value_files = value_files[0:len(metadata)] 30 | values = [np.load(os.path.join(args.datapath, file)) for file in value_files] 31 | values = np.array(values) 32 | metadata['value'] = values.tolist() 33 | 34 | 35 | # Make the labels and indices you'll use to identify different categories 36 | coin_categs_inds = [0, 1] #set(metadata['coin_visible']) 37 | bme_categs_inds = [0, 1, 2, 3] #set(metadata['begin_middle_end']) 38 | coin_categs_strs = ['(No coin)', '(with coin)'] 39 | bme_categs_strs = ['Beginning\n', 'Middle\n', 'End\n', 'After End\n'] 40 | full_categs_strs = [] 41 | full_categs_inds = [] 42 | for bme_categ in bme_categs_strs: 43 | for coin_categ in coin_categs_strs: 44 | full_categs_strs.append(f"{bme_categ} {coin_categ}") 45 | 46 | for bme_categ in bme_categs_inds: 47 | for coin_categ in coin_categs_inds: 48 | full_categs_inds.append( (bme_categ, coin_categ) ) 49 | 50 | # Separate the data into their categories 51 | match_dict = {arr: label for (arr, label) in zip(full_categs_inds, full_categs_strs)} 52 | categ_metadata = {} 53 | for coin_categs_ind in coin_categs_inds: 54 | for bme_categs_ind in bme_categs_inds: 55 | condition = (metadata['coin_visible'] == coin_categs_ind) & (metadata['begin_middle_end'] == bme_categs_ind) 56 | label = match_dict[ (bme_categs_ind, coin_categs_ind) ] 57 | categ_metadata[label] = metadata[condition] 58 | 59 | 60 | # Get the mean value for each category and compute confidence intervals 61 | results_low_mean_high = {} 62 | for full_categ in full_categs_strs: 63 | # Draw 15000 bootstrap replicates 64 | categ_values = categ_metadata[full_categ]['value'] 65 | bs_replicates_values = draw_bs_replicates(categ_values, np.mean, 10000)#15000) 66 | 67 | categ_mean = np.mean(categ_values) 68 | # Print empirical mean 69 | print(f"{full_categ} | Empirical mean: " + str(categ_mean)) 70 | 71 | # Print the mean of bootstrap replicates 72 | print(f"{full_categ} | Bootstrap replicates mean: " + str(np.mean(bs_replicates_values))) 73 | categ_low = np.percentile(bs_replicates_values,[5.]) 74 | categ_high = np.percentile(bs_replicates_values,[95.]) 75 | results_low_mean_high[full_categ] = (categ_low, categ_mean, categ_high) 76 | 77 | # Count how many samples are actually used 78 | sample_count_list = [] 79 | for full_categ in full_categs_strs: 80 | # Draw 15000 bootstrap replicates 81 | categ_values = categ_metadata[full_categ]['value'] 82 | sample_count_list.append(categ_values) 83 | print(f"n={sum([len(y) for y in sample_count_list])}") 84 | 85 | # # And we do 'post end-wall' 
manually because we collected that data later 86 | # value_files_post = os.listdir(args.datapath + "/for_post_endwall_bars") 87 | # value_files_post = [file for file in value_files_post if value_file_cond(file)] 88 | # value_files_post = sorted(value_files_post) 89 | # values_post = [np.load(os.path.join(args.datapath + "/for_post_endwall_bars", file)) for file in value_files_post] 90 | # values_post = np.array(values_post) 91 | # bs_replicates_values_post = draw_bs_replicates(values_post, np.mean, 92 | # 10000) # 15000) 93 | # values_post_mean = np.mean(values_post) 94 | # print(f"Post | Empirical mean: " + str(values_post)) 95 | # # Print the mean of bootstrap replicates 96 | # print(f"Post | Bootstrap replicates mean: " + str( 97 | # np.mean(bs_replicates_values_post))) 98 | # values_post_low = np.percentile(bs_replicates_values_post, [5.]) 99 | # values_post_high = np.percentile(bs_replicates_values_post, [95.]) 100 | # post_name = 'After End\n(No coin)' 101 | # results_low_mean_high[post_name] = (values_post_low, values_post_mean, values_post_high) 102 | # 103 | # full_categs_strs.append(post_name) 104 | 105 | # Then do a bit of processing 106 | xticks = np.arange(0,len(full_categs_strs)) 107 | means = [v[1] for v in results_low_mean_high.values()] 108 | lows = [v[0] for v in results_low_mean_high.values()] 109 | highs = [v[2] for v in results_low_mean_high.values()] 110 | errs = [(v[0], v[2]) for v in results_low_mean_high.values()] 111 | errs = [np.concatenate(e) for e in errs] 112 | errs = np.stack(errs, axis=0).transpose() 113 | errs = errs - means 114 | errs = np.abs(errs) 115 | 116 | 117 | fig, ax = plt.subplots(figsize=(7, 4)) 118 | ax.yaxis.grid(True) 119 | ax.bar(xticks, means, yerr=errs, align='center', alpha=0.95, 120 | ecolor='black', color=['orangered', 'lightsalmon', 'olivedrab', 'yellowgreen', 'deepskyblue', 'skyblue', 'darkslategray', 'cadetblue'], capsize=10) 121 | plt.ylim([5, 10]) 122 | ax.set_ylabel('Value function output') 123 | plt.box(False) 124 | ax.set_xticks(xticks) 125 | ax.set_xticklabels(full_categs_strs) 126 | # Save the figure and show 127 | plt.tight_layout() 128 | plt.savefig(args.datapath + '/bar_plot_with_error_bars.png') 129 | plt.close() 130 | 131 | 132 | 133 | def draw_bs_replicates(data, func, size): 134 | """creates a bootstrap sample, computes replicates and returns replicates array""" 135 | # Create an empty array to store replicates 136 | bs_replicates = np.empty(size) 137 | 138 | # Create bootstrap replicates as much as size 139 | for i in range(size): 140 | # Create a bootstrap sample 141 | bs_sample = np.random.choice(data, size=len(data)) 142 | # Get bootstrap replicate and append to bs_replicates 143 | bs_replicates[i] = func(bs_sample) 144 | 145 | return bs_replicates 146 | 147 | 148 | 149 | 150 | if __name__ == '__main__': 151 | plot() 152 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | from common.env.procgen_wrappers import * 2 | from common.logger import Logger 3 | from common.storage import Storage 4 | from common.model import NatureModel, ImpalaModel 5 | from common.policy import CategoricalPolicy 6 | from common import set_global_seeds, set_global_log_levels 7 | 8 | import os, time, yaml, argparse 9 | import gym 10 | from procgen import ProcgenGym3Env 11 | import random 12 | import torch 13 | 14 | from PIL import Image 15 | import torchvision as tv 16 | 17 | from gym3 import ViewerWrapper, 
VideoRecorderWrapper, ToBaselinesVecEnv 18 | 19 | 20 | if __name__=='__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--exp_name', type=str, default = 'render', help='experiment name') 23 | parser.add_argument('--env_name', type=str, default = 'coinrun', help='environment ID') 24 | parser.add_argument('--start_level', type=int, default = int(0), help='start-level for environment') 25 | parser.add_argument('--num_levels', type=int, default = int(0), help='number of training levels for environment') 26 | parser.add_argument('--distribution_mode',type=str, default = 'hard', help='distribution mode for environment') 27 | parser.add_argument('--param_name', type=str, default = 'easy-200', help='hyper-parameter ID') 28 | parser.add_argument('--device', type=str, default = 'cpu', required = False, help='whether to use gpu') 29 | parser.add_argument('--gpu_device', type=int, default = int(0), required = False, help = 'visible device in CUDA') 30 | parser.add_argument('--seed', type=int, default = random.randint(0,9999), help='Random generator seed') 31 | parser.add_argument('--log_level', type=int, default = int(40), help='[10,20,30,40]') 32 | parser.add_argument('--num_checkpoints', type=int, default = int(1), help='number of checkpoints to store') 33 | parser.add_argument('--logdir', type=str, default = None) 34 | 35 | #multi threading 36 | parser.add_argument('--num_threads', type=int, default=8) 37 | 38 | #render parameters 39 | parser.add_argument('--tps', type=int, default=15, help="env fps") 40 | parser.add_argument('--vid_dir', type=str, default=None) 41 | parser.add_argument('--model_file', type=str) 42 | parser.add_argument('--save_value', action='store_true') 43 | parser.add_argument('--save_value_individual', action='store_true') 44 | parser.add_argument('--value_saliency', action='store_true') 45 | 46 | 47 | 48 | parser.add_argument('--random_percent', type=float, default=0., help='percent of environments in which coin is randomized (only for coinrun)') 49 | parser.add_argument('--corruption_type', type=str, default = None) 50 | parser.add_argument('--corruption_severity', type=str, default = 1) 51 | parser.add_argument('--agent_view', action="store_true", help="see what the agent sees") 52 | parser.add_argument('--continue_after_coin', action="store_true", help="level doesnt end when agent gets coin") 53 | parser.add_argument('--noview', action="store_true", help="just take vids") 54 | 55 | 56 | 57 | args = parser.parse_args() 58 | exp_name = args.exp_name 59 | env_name = args.env_name 60 | start_level = args.start_level 61 | num_levels = args.num_levels 62 | distribution_mode = args.distribution_mode 63 | param_name = args.param_name 64 | device = args.device 65 | gpu_device = args.gpu_device 66 | seed = args.seed 67 | log_level = args.log_level 68 | num_checkpoints = args.num_checkpoints 69 | 70 | set_global_seeds(seed) 71 | set_global_log_levels(log_level) 72 | 73 | #################### 74 | ## HYPERPARAMETERS # 75 | #################### 76 | print('[LOADING HYPERPARAMETERS...]') 77 | with open('hyperparams/procgen/config.yml', 'r') as f: 78 | hyperparameters = yaml.safe_load(f)[param_name] 79 | for key, value in hyperparameters.items(): 80 | print(key, ':', value) 81 | 82 | ############ 83 | ## DEVICE ## 84 | ############ 85 | if args.device == 'gpu': 86 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_device) 87 | device = torch.device('cuda') 88 | else: 89 | device = torch.device('cpu') 90 | 91 | ################# 92 | ## ENVIRONMENT ## 93 | 
################# 94 | print('INITIALIZAING ENVIRONMENTS...') 95 | n_envs = 1 96 | 97 | def create_venv_render(args, hyperparameters, is_valid=False): 98 | venv = ProcgenGym3Env(num=n_envs, 99 | env_name=args.env_name, 100 | num_levels=0 if is_valid else args.num_levels, 101 | start_level=0 if is_valid else args.start_level, 102 | distribution_mode=args.distribution_mode, 103 | num_threads=1, 104 | render_mode="rgb_array", 105 | random_percent=args.random_percent, 106 | corruption_type=args.corruption_type, 107 | corruption_severity=int(args.corruption_severity), 108 | continue_after_coin=args.continue_after_coin, 109 | ) 110 | info_key = None if args.agent_view else "rgb" 111 | ob_key = "rgb" if args.agent_view else None 112 | if not args.noview: 113 | venv = ViewerWrapper(venv, tps=args.tps, info_key=info_key, ob_key=ob_key) # N.B. this line caused issues for me. I just commented it out, but it's uncommented in the pushed version in case it's just me (Lee). 114 | if args.vid_dir is not None: 115 | venv = VideoRecorderWrapper(venv, directory=args.vid_dir, 116 | info_key=info_key, ob_key=ob_key, fps=args.tps) 117 | venv = ToBaselinesVecEnv(venv) 118 | venv = VecExtractDictObs(venv, "rgb") 119 | normalize_rew = hyperparameters.get('normalize_rew', True) 120 | if normalize_rew: 121 | venv = VecNormalize(venv, ob=False) # normalizing returns, but not 122 | #the img frames 123 | venv = TransposeFrame(venv) 124 | venv = ScaledFloatFrame(venv) 125 | 126 | return venv 127 | n_steps = hyperparameters.get('n_steps', 256) 128 | 129 | #env = create_venv(args, hyperparameters) 130 | #env_valid = create_venv(args, hyperparameters, is_valid=True) 131 | env = create_venv_render(args, hyperparameters, is_valid=True) 132 | 133 | ############ 134 | ## LOGGER ## 135 | ############ 136 | print('INITIALIZAING LOGGER...') 137 | if args.logdir is None: 138 | logdir = 'procgen/' + env_name + '/' + exp_name + '/' + 'RENDER_seed' + '_' + \ 139 | str(seed) + '_' + time.strftime("%d-%m-%Y_%H-%M-%S") 140 | logdir = os.path.join('logs', logdir) 141 | else: 142 | logdir = args.logdir 143 | if not (os.path.exists(logdir)): 144 | os.makedirs(logdir) 145 | logdir_indiv_value = os.path.join(logdir, 'value_individual') 146 | if not (os.path.exists(logdir_indiv_value)) and args.save_value_individual: 147 | os.makedirs(logdir_indiv_value) 148 | logdir_saliency_value = os.path.join(logdir, 'value_saliency') 149 | if not (os.path.exists(logdir_saliency_value)) and args.value_saliency: 150 | os.makedirs(logdir_saliency_value) 151 | print(f'Logging to {logdir}') 152 | logger = Logger(n_envs, logdir) 153 | 154 | ########### 155 | ## MODEL ## 156 | ########### 157 | print('INTIALIZING MODEL...') 158 | observation_space = env.observation_space 159 | observation_shape = observation_space.shape 160 | architecture = hyperparameters.get('architecture', 'impala') 161 | in_channels = observation_shape[0] 162 | action_space = env.action_space 163 | 164 | # Model architecture 165 | if architecture == 'nature': 166 | model = NatureModel(in_channels=in_channels) 167 | elif architecture == 'impala': 168 | model = ImpalaModel(in_channels=in_channels) 169 | 170 | # Discrete action space 171 | recurrent = hyperparameters.get('recurrent', False) 172 | if isinstance(action_space, gym.spaces.Discrete): 173 | action_size = action_space.n 174 | policy = CategoricalPolicy(model, recurrent, action_size) 175 | else: 176 | raise NotImplementedError 177 | policy.to(device) 178 | 179 | ############# 180 | ## STORAGE ## 181 | ############# 182 | 
print('INITIALIZAING STORAGE...') 183 | hidden_state_dim = model.output_dim 184 | storage = Storage(observation_shape, hidden_state_dim, n_steps, n_envs, device) 185 | #storage_valid = Storage(observation_shape, hidden_state_dim, n_steps, n_envs, device) 186 | 187 | ########### 188 | ## AGENT ## 189 | ########### 190 | print('INTIALIZING AGENT...') 191 | algo = hyperparameters.get('algo', 'ppo') 192 | if algo == 'ppo': 193 | from agents.ppo import PPO as AGENT 194 | else: 195 | raise NotImplementedError 196 | agent = AGENT(env, policy, logger, storage, device, num_checkpoints, **hyperparameters) 197 | 198 | agent.policy.load_state_dict(torch.load(args.model_file, map_location=device)["model_state_dict"]) 199 | agent.n_envs = n_envs 200 | 201 | ############ 202 | ## RENDER ## 203 | ############ 204 | 205 | # save observations and value estimates 206 | def save_value_estimates(storage, epoch_idx): 207 | """write observations and value estimates to npy / csv file""" 208 | print(f"Saving observations and values to {logdir}") 209 | np.save(logdir + f"/observations_{epoch_idx}", storage.obs_batch) 210 | np.save(logdir + f"/value_{epoch_idx}", storage.value_batch) 211 | return 212 | 213 | def save_value_estimates_individual(storage, epoch_idx, individual_value_idx): 214 | """write individual observations and value estimates to npy / csv file""" 215 | print(f"Saving random samples of observations and values to {logdir}") 216 | obs = storage.obs_batch.clone().detach().squeeze().permute(0, 2, 3, 1) 217 | obs = (obs * 255 ).cpu().numpy().astype(np.uint8) 218 | vals = storage.value_batch.squeeze() 219 | 220 | random_idxs = np.random.choice(obs.shape[0], 5, replace=False) 221 | for rand_id in random_idxs: 222 | im = obs[rand_id] 223 | val = vals[rand_id] 224 | im = Image.fromarray(im) 225 | im.save(logdir_indiv_value + f"/obs_{individual_value_idx:05d}.png") 226 | np.save(logdir_indiv_value + f"/val_{individual_value_idx:05d}.npy", val) 227 | individual_value_idx += 1 228 | return individual_value_idx 229 | 230 | def write_scalar(scalar, filename): 231 | """write scalar to filename""" 232 | with open(logdir + "/" + filename, "w") as f: 233 | f.write(str(scalar)) 234 | 235 | 236 | obs = agent.env.reset() 237 | hidden_state = np.zeros((agent.n_envs, agent.storage.hidden_state_size)) 238 | done = np.zeros(agent.n_envs) 239 | 240 | 241 | individual_value_idx = 1001 242 | save_frequency = 1 243 | saliency_save_idx = 0 244 | epoch_idx = 0 245 | while True: 246 | agent.policy.eval() 247 | for _ in range(agent.n_steps): # = 256 248 | if not args.value_saliency: 249 | act, log_prob_act, value, next_hidden_state = agent.predict(obs, hidden_state, done) 250 | else: 251 | act, log_prob_act, value, next_hidden_state, value_saliency_obs = agent.predict_w_value_saliency(obs, hidden_state, done) 252 | if saliency_save_idx % save_frequency == 0: 253 | 254 | value_saliency_obs = value_saliency_obs.swapaxes(1, 3) 255 | obs_copy = obs.swapaxes(1, 3) 256 | 257 | ims_grad = value_saliency_obs.mean(axis=-1) 258 | 259 | percentile = np.percentile(np.abs(ims_grad), 99.9999999) 260 | ims_grad = ims_grad.clip(-percentile, percentile) / percentile 261 | ims_grad = torch.tensor(ims_grad) 262 | blurrer = tv.transforms.GaussianBlur( 263 | kernel_size=5, 264 | sigma=5.) 
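# The blurred value-gradient map is split below into its positive and negative
# parts, written into the blue and red image channels respectively, and overlaid
# on a faint greyscale copy of the observation, so the saved saliency images
# highlight pixels that push the value estimate up (blue) or down (red).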
# (5, sigma=(5., 6.)) 265 | ims_grad = blurrer(ims_grad).squeeze().unsqueeze(-1) 266 | 267 | pos_grads = ims_grad.where(ims_grad > 0., 268 | torch.zeros_like(ims_grad)) 269 | neg_grads = ims_grad.where(ims_grad < 0., 270 | torch.zeros_like(ims_grad)).abs() 271 | 272 | 273 | # Make a couple of copies of the original im for later 274 | sample_ims_faint = torch.tensor(obs_copy.mean(-1)) * 0.2 275 | sample_ims_faint = torch.stack([sample_ims_faint] * 3, axis=-1) 276 | sample_ims_faint = sample_ims_faint * 255 277 | sample_ims_faint = sample_ims_faint.clone().detach().type( 278 | torch.uint8).cpu().numpy() 279 | 280 | grad_scale = 9. 281 | grad_vid = np.zeros_like(sample_ims_faint) 282 | pos_grads = pos_grads * grad_scale * 255 283 | neg_grads = neg_grads * grad_scale * 255 284 | grad_vid[:, :, :, 2] = pos_grads.squeeze().clone().detach().type( 285 | torch.uint8).cpu().numpy() 286 | grad_vid[:, :, :, 0] = neg_grads.squeeze().clone().detach().type( 287 | torch.uint8).cpu().numpy() 288 | 289 | grad_vid = grad_vid + sample_ims_faint 290 | 291 | grad_vid = Image.fromarray(grad_vid.swapaxes(0,2).squeeze()) 292 | grad_vid.save(logdir_saliency_value + f"/sal_obs_{saliency_save_idx:05d}_grad.png") 293 | 294 | obs_copy = (obs_copy * 255).astype(np.uint8) 295 | obs_copy = Image.fromarray(obs_copy.swapaxes(0,2).squeeze()) 296 | obs_copy.save(logdir_saliency_value + f"/sal_obs_{saliency_save_idx:05d}_raw.png") 297 | 298 | saliency_save_idx += 1 299 | 300 | 301 | 302 | next_obs, rew, done, info = agent.env.step(act) 303 | 304 | agent.storage.store(obs, hidden_state, act, rew, done, info, log_prob_act, value) 305 | obs = next_obs 306 | hidden_state = next_hidden_state 307 | 308 | _, _, last_val, hidden_state = agent.predict(obs, hidden_state, done) 309 | agent.storage.store_last(obs, hidden_state, last_val) 310 | 311 | if args.save_value_individual: 312 | individual_value_idx = save_value_estimates_individual(agent.storage, epoch_idx, individual_value_idx) 313 | 314 | if args.save_value: 315 | save_value_estimates(agent.storage, epoch_idx) 316 | epoch_idx += 1 317 | 318 | agent.storage.compute_estimates(agent.gamma, agent.lmbda, agent.use_gae, 319 | agent.normalize_adv) 320 | -------------------------------------------------------------------------------- /run_coinrun.py: -------------------------------------------------------------------------------- 1 | from common.env.procgen_wrappers import * 2 | from common import set_global_seeds, set_global_log_levels 3 | import os, argparse 4 | import random 5 | from tqdm import tqdm 6 | import config 7 | import numpy as np 8 | 9 | from run_utils import run_env 10 | 11 | if __name__=='__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--exp_name', type=str, default = 'metrics', help='experiment name') 14 | parser.add_argument('--start_level', type=int, default = int(0), help='start-level for environment') 15 | parser.add_argument('--device', type=str, default = 'cpu', required = False, help='whether to use gpu') 16 | parser.add_argument('--gpu_device', type=int, default = int(0), required = False, help = 'visible device in CUDA') 17 | parser.add_argument('--agent_seed', type=int, default = random.randint(0,999999), help='Seed for pytorch') 18 | parser.add_argument('--log_level', type=int, default = int(40), help='[10,20,30,40]') 19 | parser.add_argument('--logdir', type=str, default = None) 20 | parser.add_argument('--start_level_seed', type=int, default = 0) 21 | parser.add_argument('--num_seeds', type=int, default = 10) 22 | 
parser.add_argument('--random_percent', type=int, default = 0) 23 | parser.add_argument('--seed_file', type=str, help="path to text file with env seeds to run on.") 24 | parser.add_argument('--reset_mode', type=str, default="inv_coin", help="Reset modes:" 25 | "- inv_coin returns when agent gets the inv coin OR finishes the level" 26 | "- complete returns when the agent finishes the level") 27 | 28 | #multi threading 29 | parser.add_argument('--num_threads', type=int, default=8) 30 | 31 | #render parameters 32 | parser.add_argument('--num_envs', type=int, default=1) 33 | parser.add_argument('--vid_dir', type=str, default=None) 34 | parser.add_argument('--model_file', type=str, help="Can be either a path to a model file, or an " 35 | "integer. Integer is interpreted as random_percent in training") 36 | 37 | args = parser.parse_args() 38 | 39 | # Seeds 40 | set_global_seeds(args.agent_seed) 41 | set_global_log_levels(args.log_level) 42 | 43 | if args.seed_file: 44 | print(f"Loading env seeds from {args.seed_file}") 45 | with open(args.seed_file, 'r') as f: 46 | seeds = f.read() 47 | seeds = [int(s) for s in seeds.split()] 48 | else: 49 | print(f"Running on env seeds {args.start_level_seed} to {args.start_level_seed + args.num_seeds}.") 50 | seeds = np.arange(args.num_seeds) + args.start_level_seed 51 | 52 | # Model file 53 | def get_model_path(random_percent): 54 | """return path of saved model trained with random_percent""" 55 | assert random_percent in range(101) 56 | logpath = "./logs" if config.on_cluster else "./hpc-logs" 57 | logpath = os.path.join(logpath, f"train/coinrun/freq-sweep-random-percent-{random_percent}") 58 | run = list(os.listdir(logpath))[0] 59 | return os.path.join(logpath, run, "model_80084992.pth") 60 | 61 | datestr = time.strftime("%Y-%m-%d_%H:%M:%S") 62 | logpath = os.path.join(config.results_dir, f"test_rand_percent_{args.random_percent}") 63 | try: 64 | path_to_model_file = get_model_path(int(args.model_file)) 65 | logpath = os.path.join(logpath, f"train_rand_percent_{args.model_file}") 66 | except (ValueError, AssertionError): 67 | path_to_model_file = args.model_file 68 | logpath = os.path.join(logpath, f"unkown_model__" + datestr) 69 | 70 | os.makedirs(logpath, exist_ok=True) 71 | with open(os.path.join(logpath, "metadata.txt"), "a") as f: 72 | f.write(time.strftime("%Y-%m-%d %H:%M:%S") + f", modelfile {path_to_model_file}\n") 73 | 74 | logfile = os.path.join(logpath, f"metrics_agent_seed_{args.agent_seed}.csv") 75 | print(f"Saving metrics to {logfile}.") 76 | print(f"Running coinrun with random_percent={args.random_percent}...") 77 | for env_seed in tqdm(seeds, disable=True): 78 | run_env(exp_name=args.exp_name, 79 | logfile=logfile, 80 | model_file=path_to_model_file, 81 | level_seed=env_seed, 82 | device=args.device, 83 | gpu_device=args.gpu_device, 84 | random_percent=args.random_percent, 85 | reset_mode=args.reset_mode) 86 | -------------------------------------------------------------------------------- /run_utils.py: -------------------------------------------------------------------------------- 1 | from numpy.lib.npyio import save 2 | from common.env.procgen_wrappers import * 3 | from common.logger import Logger 4 | from common.storage import Storage 5 | from common.model import NatureModel, ImpalaModel 6 | from common.policy import CategoricalPolicy 7 | from common import set_global_seeds, set_global_log_levels 8 | 9 | from pathlib import Path 10 | import os, time, yaml 11 | import gym 12 | from procgen import ProcgenEnv 13 | import torch 14 | import 
csv 15 | import numpy as np 16 | 17 | 18 | def load_env_and_agent(exp_name, 19 | env_name, 20 | num_envs, 21 | model_file, 22 | start_level, 23 | num_levels, 24 | distribution_mode="hard", 25 | param_name="hard", 26 | device="cpu", 27 | gpu_device=0, 28 | random_percent=0, 29 | logdir=None, 30 | num_threads=10): 31 | 32 | if env_name != "coinrun": 33 | raise ValueError("Environment must be coinrun") 34 | 35 | #################### 36 | ## HYPERPARAMETERS # 37 | #################### 38 | with open('hyperparams/procgen/config.yml', 'r') as f: 39 | hyperparameters = yaml.safe_load(f)[param_name] 40 | 41 | ############ 42 | ## DEVICE ## 43 | ############ 44 | if device == 'gpu': 45 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_device) 46 | device = torch.device('cuda') 47 | else: 48 | device = torch.device('cpu') 49 | 50 | ################# 51 | ## ENVIRONMENT ## 52 | ################# 53 | def create_venv(hyperparameters): 54 | venv = ProcgenEnv(num_envs=num_envs, 55 | env_name=env_name, 56 | num_levels=num_levels, 57 | start_level=int(start_level), 58 | distribution_mode=distribution_mode, 59 | num_threads=num_threads, 60 | random_percent=random_percent) 61 | venv = VecExtractDictObs(venv, "rgb") 62 | normalize_rew = hyperparameters.get('normalize_rew', True) 63 | if normalize_rew: 64 | venv = VecNormalize(venv, ob=False) # normalizing returns, but not 65 | #the img frames 66 | venv = TransposeFrame(venv) 67 | venv = ScaledFloatFrame(venv) 68 | return venv 69 | n_steps = hyperparameters.get('n_steps', 256) 70 | 71 | env = create_venv(hyperparameters) 72 | 73 | ############ 74 | ## LOGGER ## 75 | ############ 76 | logger = Logger(num_envs, "/dev/null") 77 | 78 | ########### 79 | ## MODEL ## 80 | ########### 81 | observation_space = env.observation_space 82 | observation_shape = observation_space.shape 83 | architecture = hyperparameters.get('architecture', 'impala') 84 | in_channels = observation_shape[0] 85 | action_space = env.action_space 86 | 87 | # Model architecture 88 | if architecture == 'nature': 89 | model = NatureModel(in_channels=in_channels) 90 | elif architecture == 'impala': 91 | model = ImpalaModel(in_channels=in_channels) 92 | 93 | # Discrete action space 94 | recurrent = hyperparameters.get('recurrent', False) 95 | if isinstance(action_space, gym.spaces.Discrete): 96 | action_size = action_space.n 97 | policy = CategoricalPolicy(model, recurrent, action_size) 98 | else: 99 | raise NotImplementedError 100 | policy.to(device) 101 | 102 | ############# 103 | ## STORAGE ## 104 | ############# 105 | hidden_state_dim = model.output_dim 106 | storage = Storage(observation_shape, hidden_state_dim, n_steps, num_envs, device) 107 | 108 | ########### 109 | ## AGENT ## 110 | ########### 111 | algo = hyperparameters.get('algo', 'ppo') 112 | if algo == 'ppo': 113 | from agents.ppo import PPO as AGENT 114 | else: 115 | raise NotImplementedError 116 | agent = AGENT(env, policy, logger, storage, device, 0, **hyperparameters) 117 | 118 | agent.policy.load_state_dict(torch.load(model_file, map_location=device)["model_state_dict"]) 119 | agent.n_envs = num_envs 120 | return agent 121 | 122 | 123 | def load_episode(exp_name, level_seed, **kwargs): 124 | """Load a single coinrun level with fixed seed. 
Same level layout after each reset. 125 | logdir is just for agent logs.""" 126 | return load_env_and_agent( 127 | exp_name=exp_name, 128 | env_name="coinrun", 129 | num_envs=1, 130 | num_levels=1, 131 | start_level=level_seed, 132 | num_threads=1, 133 | **kwargs) 134 | 135 | 136 | ############## 137 | ## DEPLOY ## 138 | ############## 139 | 140 | def run_env( 141 | exp_name, 142 | level_seed, 143 | logfile=None, 144 | reset_mode="inv_coin", 145 | max_num_timesteps=10_000, 146 | save_value=False, 147 | **kwargs): 148 | """ 149 | Runs one coinrun level. 150 | Reset modes: 151 | - inv_coin returns when agent gets the inv coin OR finishes the level 152 | - complete returns when the agent finishes the level 153 | - off resets only when max_num_timesteps is reached (always repeating the same level) 154 | 155 | Returns level metrics. If a logfile (csv) is supplied, metrics are also 156 | appended there. 157 | """ 158 | if save_value: 159 | raise NotImplementedError 160 | 161 | # define unconditionally so log_to_csv never hits a NameError when no logfile is given 162 | append_to_csv = logfile is not None 163 | 164 | agent = load_episode(exp_name, level_seed, **kwargs) 165 | 166 | obs = agent.env.reset() 167 | hidden_state = np.zeros((agent.n_envs, agent.storage.hidden_state_size)) 168 | done = np.zeros(agent.n_envs) 169 | 170 | 171 | def log_to_csv(metrics): 172 | """write metrics to csv""" 173 | if not metrics: 174 | return 175 | column_names = ["seed", "steps", "rand_coin", "coin_collected", "inv_coin_collected", "died", "timed_out"] 176 | metrics = [int(m) for m in metrics] 177 | if append_to_csv: 178 | with open(logfile, "a") as f: 179 | w = csv.writer(f) 180 | if f.tell() == 0: # write header first 181 | w.writerow(column_names) 182 | w.writerow(metrics) 183 | 184 | 185 | def log_metrics(done: bool, info: dict): 186 | """ 187 | When a run is complete, log metrics in the 188 | following format: 189 | seed, steps, randomize_goal, collected_coin, collected_inv_coin, died, timed_out 190 | """ 191 | metrics = None 192 | if done: 193 | keys = ["prev_level_seed", "prev_level/total_steps", "prev_level/randomize_goal", "prev_level_complete", "prev_level/invisible_coin_collected"] 194 | metrics = [info[key] for key in keys] 195 | if info["prev_level_complete"]: 196 | metrics.extend([False, False]) 197 | else: 198 | timed_out = info["prev_level/total_steps"] > 999 199 | metrics.extend([not timed_out, timed_out]) 200 | elif info["invisible_coin_collected"]: 201 | keys = ["level_seed", "total_steps", "randomize_goal"] 202 | metrics = [info[key] for key in keys] 203 | metrics.extend([-1, True, -1, -1]) 204 | else: 205 | raise ValueError("log_metrics called, but the episode is neither done nor has the invisible coin been collected") 206 | log_to_csv(metrics) 207 | return metrics 208 | 209 | 210 | def check_if_break(done: bool, info: dict): 211 | if reset_mode == "inv_coin": 212 | return done or info["invisible_coin_collected"] 213 | elif reset_mode == "complete": 214 | return done 215 | elif reset_mode == "off": 216 | return False 217 | else: 218 | raise ValueError("Reset mode must be one of inv_coin, complete, off. "
219 | f"Instead got {reset_mode}") 220 | 221 | step = 0 222 | while step < max_num_timesteps: 223 | agent.policy.eval() 224 | for _ in range(agent.n_steps): # = 256 225 | step += 1 226 | act, log_prob_act, value, next_hidden_state = agent.predict(obs, hidden_state, done) 227 | next_obs, rew, done, info = agent.env.step(act) 228 | 229 | agent.storage.store(obs, hidden_state, act, rew, done, info, log_prob_act, value) 230 | obs = next_obs 231 | hidden_state = next_hidden_state 232 | 233 | if check_if_break(done[0], info[0]): 234 | log_metrics(done[0], info[0]) 235 | return 236 | return 237 | 238 | -------------------------------------------------------------------------------- /run_vanilla_coinrun.sh: -------------------------------------------------------------------------------- 1 | 2 | for _ in {0..10} 3 | do 4 | python run_vanilla_coinrun.py --model_file /home/lauro/projects/aisc2021/model-files/coinrun.pth --start_level_seed 0 --num_seeds 5000 5 | done 6 | -------------------------------------------------------------------------------- /train-interleave-envs.py: -------------------------------------------------------------------------------- 1 | from common.env.procgen_wrappers import * 2 | from common.logger import Logger 3 | from common.storage import Storage 4 | from common.model import NatureModel, ImpalaModel 5 | from common.policy import CategoricalPolicy 6 | from common import set_global_seeds, set_global_log_levels 7 | 8 | import os, time, yaml, argparse 9 | import gym 10 | from procgen import ProcgenGym3Env 11 | import random 12 | import torch 13 | 14 | import gym3 15 | 16 | if __name__=='__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--exp_name', type=str, default = 'test', help='experiment name') 19 | parser.add_argument('--envs', type=str, default = ['maze', 'heist'], nargs="+", help='list of environment IDs') 20 | parser.add_argument('--start_level', type=int, default = int(0), help='start-level for environment') 21 | parser.add_argument('--num_levels', type=int, default = int(0), help='number of training levels for environment') 22 | parser.add_argument('--distribution_mode',type=str, default = 'easy', help='distribution mode for environment') 23 | parser.add_argument('--param_name', type=str, default = 'easy-200', help='hyper-parameter ID') 24 | parser.add_argument('--device', type=str, default = 'gpu', required = False, help='whether to use gpu') 25 | parser.add_argument('--gpu_device', type=int, default = int(0), required = False, help = 'visible device in CUDA') 26 | parser.add_argument('--num_timesteps', type=int, default = int(25000000), help = 'number of training timesteps') 27 | parser.add_argument('--seed', type=int, default = random.randint(0,9999), help='Random generator seed') 28 | parser.add_argument('--log_level', type=int, default = int(40), help='[10,20,30,40]') 29 | parser.add_argument('--num_checkpoints', type=int, default = int(1), help='number of checkpoints to store') 30 | parser.add_argument('--model_file', type=str) 31 | 32 | #multi threading 33 | parser.add_argument('--num_threads', type=int, default=8) 34 | 35 | args = parser.parse_args() 36 | 37 | set_global_seeds(args.seed) 38 | set_global_log_levels(args.log_level) 39 | 40 | #################### 41 | ## HYPERPARAMETERS # 42 | #################### 43 | print('[LOADING HYPERPARAMETERS...]') 44 | with open('hyperparams/procgen/config.yml', 'r') as f: 45 | hyperparameters = yaml.safe_load(f)[args.param_name] 46 | for key, value in hyperparameters.items(): 47 | 
print(key, ':', value) 48 | 49 | ############ 50 | ## DEVICE ## 51 | ############ 52 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_device) 53 | if args.device == 'gpu': 54 | device = torch.device('cuda') 55 | elif args.device == 'cpu': 56 | device = torch.device('cpu') 57 | 58 | ################# 59 | ## ENVIRONMENT ## 60 | ################# 61 | print('INITIALIZING ENVIRONMENTS...') 62 | print("Training on concatenation of the following envs:") 63 | print("\n".join(args.envs)) 64 | print() 65 | 66 | n_steps = hyperparameters.get('n_steps', 256) 67 | n_envs = hyperparameters.get('n_envs', 256) 68 | distribution_modes = [args.distribution_mode] * len(args.envs) 69 | num_threads_per_env = args.num_threads // len(args.envs) 70 | num_envs_per_env = n_envs // len(args.envs) 71 | 72 | def create_concat_env(hyperparameters): 73 | envs = [ProcgenGym3Env(num=num_envs_per_env, 74 | env_name=env_name, 75 | num_levels=args.num_levels, 76 | start_level=args.start_level, 77 | distribution_mode=distribution_mode, 78 | num_threads=num_threads_per_env) 79 | for env_name, distribution_mode in zip(args.envs, distribution_modes)] 80 | 81 | venv = gym3.ConcatEnv(envs) 82 | venv = gym3.ToBaselinesVecEnv(venv) 83 | venv = VecExtractDictObs(venv, "rgb") 84 | normalize_rew = hyperparameters.get('normalize_rew', True) 85 | if normalize_rew: 86 | venv = VecNormalize(venv, ob=False) # normalizing returns, but not 87 | #the img frames 88 | venv = TransposeFrame(venv) 89 | venv = ScaledFloatFrame(venv) 90 | return venv 91 | 92 | env = create_concat_env(hyperparameters) 93 | 94 | ############ 95 | ## LOGGER ## 96 | ############ 97 | print('INITIALIZING LOGGER...') 98 | logdir = 'procgen-august-2021/' + 'concat' + '/' + args.exp_name + '/' + 'seed' + '_' + \ 99 | str(args.seed) + '_' + time.strftime("%d-%m-%Y_%H-%M-%S") 100 | logdir = os.path.join('logs', logdir) 101 | if not (os.path.exists(logdir)): 102 | os.makedirs(logdir) 103 | print(f'Logging to {logdir}') 104 | logger = Logger(n_envs, logdir) 105 | 106 | ########### 107 | ## MODEL ## 108 | ########### 109 | print('INITIALIZING MODEL...') 110 | observation_space = env.observation_space 111 | observation_shape = observation_space.shape 112 | architecture = hyperparameters.get('architecture', 'impala') 113 | in_channels = observation_shape[0] 114 | action_space = env.action_space 115 | 116 | # Model architecture 117 | if architecture == 'nature': 118 | model = NatureModel(in_channels=in_channels) 119 | elif architecture == 'impala': 120 | model = ImpalaModel(in_channels=in_channels) 121 | 122 | # Discrete action space 123 | recurrent = hyperparameters.get('recurrent', False) 124 | if isinstance(action_space, gym.spaces.Discrete): 125 | action_size = action_space.n 126 | policy = CategoricalPolicy(model, recurrent, action_size) 127 | else: 128 | raise NotImplementedError 129 | policy.to(device) 130 | 131 | ############# 132 | ## STORAGE ## 133 | ############# 134 | print('INITIALIZING STORAGE...') 135 | hidden_state_dim = model.output_dim 136 | storage = Storage(observation_shape, hidden_state_dim, n_steps, n_envs, device) 137 | 138 | ########### 139 | ## AGENT ## 140 | ########### 141 | print('INITIALIZING AGENT...') 142 | algo = hyperparameters.get('algo', 'ppo') 143 | if algo == 'ppo': 144 | from agents.ppo import PPO as AGENT 145 | else: 146 | raise NotImplementedError 147 | agent = AGENT(env, policy, logger, storage, device, args.num_checkpoints, **hyperparameters) 148 | if args.model_file is not None: 149 | print("Loading agent from %s" % args.model_file)
150 | checkpoint = torch.load(args.model_file) 151 | agent.policy.load_state_dict(checkpoint["model_state_dict"]) 152 | agent.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 153 | 154 | ############## 155 | ## TRAINING ## 156 | ############## 157 | print('START TRAINING...') 158 | agent.train(args.num_timesteps) 159 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from common.env.procgen_wrappers import * 2 | from common.logger import Logger 3 | from common.storage import Storage 4 | from common.model import NatureModel, ImpalaModel 5 | from common.policy import CategoricalPolicy 6 | from common import set_global_seeds, set_global_log_levels 7 | 8 | import os, time, yaml, argparse 9 | import gym 10 | from procgen import ProcgenEnv 11 | import random 12 | import torch 13 | 14 | try: 15 | import wandb 16 | except ImportError: 17 | pass 18 | 19 | 20 | if __name__=='__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--exp_name', type=str, default = 'test', help='experiment name') 23 | parser.add_argument('--env_name', type=str, default = 'coinrun', help='environment ID') 24 | parser.add_argument('--val_env_name', type=str, default = None, help='optional validation environment ID') 25 | parser.add_argument('--start_level', type=int, default = int(0), help='start-level for environment') 26 | parser.add_argument('--num_levels', type=int, default = int(0), help='number of training levels for environment') 27 | parser.add_argument('--distribution_mode',type=str, default = 'easy', help='distribution mode for environment') 28 | parser.add_argument('--param_name', type=str, default = 'easy-200', help='hyper-parameter ID') 29 | parser.add_argument('--device', type=str, default = 'gpu', required = False, help='whether to use gpu') 30 | parser.add_argument('--gpu_device', type=int, default = int(0), required = False, help = 'visible device in CUDA') 31 | parser.add_argument('--num_timesteps', type=int, default = int(25000000), help = 'number of training timesteps') 32 | parser.add_argument('--seed', type=int, default = random.randint(0,9999), help='Random generator seed') 33 | parser.add_argument('--log_level', type=int, default = int(40), help='[10,20,30,40]') 34 | parser.add_argument('--num_checkpoints', type=int, default = int(1), help='number of checkpoints to store') 35 | parser.add_argument('--model_file', type=str) 36 | parser.add_argument('--use_wandb', action="store_true") 37 | 38 | parser.add_argument('--wandb_tags', type=str, nargs='+') 39 | 40 | 41 | parser.add_argument('--random_percent', type=int, default=0, help='COINRUN: percent of environments in which coin is randomized (only for coinrun)') 42 | parser.add_argument('--key_penalty', type=int, default=0, help='HEIST_AISC: Penalty for picking up keys (divided by 10)') 43 | parser.add_argument('--step_penalty', type=int, default=0, help='HEIST_AISC: Time penalty per step (divided by 1000)') 44 | parser.add_argument('--rand_region', type=int, default=0, help='MAZE: size of region (in upper left corner) in which goal is sampled.') 45 | 46 | 47 | #multi threading 48 | parser.add_argument('--num_threads', type=int, default=8) 49 | 50 | args = parser.parse_args() 51 | exp_name = args.exp_name 52 | env_name = args.env_name 53 | val_env_name = args.val_env_name if args.val_env_name else args.env_name 54 | start_level = args.start_level 55 | start_level_val = random.randint(0, 
9999) 56 | num_levels = args.num_levels 57 | distribution_mode = args.distribution_mode 58 | param_name = args.param_name 59 | gpu_device = args.gpu_device 60 | num_timesteps = int(args.num_timesteps) 61 | seed = args.seed 62 | log_level = args.log_level 63 | num_checkpoints = args.num_checkpoints 64 | 65 | set_global_seeds(seed) 66 | set_global_log_levels(log_level) 67 | 68 | if args.start_level == start_level_val: 69 | raise ValueError("Seeds for training and validation envs are equal.") 70 | 71 | #################### 72 | ## HYPERPARAMETERS # 73 | #################### 74 | print('[LOADING HYPERPARAMETERS...]') 75 | with open('hyperparams/procgen/config.yml', 'r') as f: 76 | hyperparameters = yaml.safe_load(f)[param_name] 77 | for key, value in hyperparameters.items(): 78 | print(key, ':', value) 79 | 80 | ############ 81 | ## DEVICE ## 82 | ############ 83 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_device) 84 | if args.device == 'gpu': 85 | device = torch.device('cuda') 86 | elif args.device == 'cpu': 87 | device = torch.device('cpu') 88 | 89 | ################# 90 | ## ENVIRONMENT ## 91 | ################# 92 | print('INITIALIZING ENVIRONMENTS...') 93 | 94 | n_steps = hyperparameters.get('n_steps', 256) 95 | n_envs = hyperparameters.get('n_envs', 256) 96 | 97 | def create_venv(args, hyperparameters, is_valid=False): 98 | venv = ProcgenEnv(num_envs=n_envs, 99 | env_name=val_env_name if is_valid else env_name, 100 | num_levels=0 if is_valid else args.num_levels, 101 | start_level=start_level_val if is_valid else args.start_level, 102 | distribution_mode=args.distribution_mode, 103 | num_threads=args.num_threads, 104 | random_percent=args.random_percent, 105 | step_penalty=args.step_penalty, 106 | key_penalty=args.key_penalty, 107 | rand_region=args.rand_region) 108 | venv = VecExtractDictObs(venv, "rgb") 109 | normalize_rew = hyperparameters.get('normalize_rew', True) 110 | if normalize_rew: 111 | venv = VecNormalize(venv, ob=False) # normalizing returns, but not 112 | #the img frames 113 | venv = TransposeFrame(venv) 114 | venv = ScaledFloatFrame(venv) 115 | return venv 116 | 117 | env = create_venv(args, hyperparameters) 118 | env_valid = create_venv(args, hyperparameters, is_valid=True) 119 | 120 | 121 | 122 | 123 | ############ 124 | ## LOGGER ## 125 | ############ 126 | def listdir(path): 127 | return [os.path.join(path, d) for d in os.listdir(path)] 128 | 129 | def get_latest_model(model_dir): 130 | """given model_dir with files named model_n.pth where n is an integer, 131 | return the filename with largest n""" 132 | model_files = [filename for filename in os.listdir(model_dir) if filename.startswith("model_")] 133 | return max(model_files, key=lambda fname: int(fname[6:-4])) 134 | 135 | print('INITIALIZING LOGGER...') 136 | 137 | logdir = os.path.join('logs', 'train', env_name, exp_name) 138 | if args.model_file == "auto": # try to figure out which file to load 139 | logdirs_with_model = [d for d in listdir(logdir) if any(['model' in filename for filename in os.listdir(d)])] 140 | if len(logdirs_with_model) > 1: 141 | raise ValueError("Received args.model_file = 'auto', but there are multiple experiments" 142 | f" with saved models under experiment_name {exp_name}.") 143 | elif len(logdirs_with_model) == 0: 144 | raise ValueError("Received args.model_file = 'auto', but there are" 145 | f" no saved models under experiment_name {exp_name}.") 146 | model_dir = logdirs_with_model[0] 147 | args.model_file = os.path.join(model_dir, get_latest_model(model_dir)) 148 | logdir = 
model_dir # reuse logdir 149 | else: 150 | run_name = time.strftime("%Y-%m-%d__%H-%M-%S") + f'__seed_{seed}' 151 | logdir = os.path.join(logdir, run_name) 152 | if not (os.path.exists(logdir)): 153 | os.makedirs(logdir) 154 | 155 | print(f'Logging to {logdir}') 156 | if args.use_wandb: 157 | cfg = vars(args) 158 | cfg.update(hyperparameters) 159 | wb_resume = "allow" if args.model_file is None else "must" 160 | wandb.init(project="objective-robustness", config=cfg, tags=args.wandb_tags, resume=wb_resume) 161 | logger = Logger(n_envs, logdir, use_wandb=args.use_wandb) 162 | 163 | ########### 164 | ## MODEL ## 165 | ########### 166 | print('INITIALIZING MODEL...') 167 | observation_space = env.observation_space 168 | observation_shape = observation_space.shape 169 | architecture = hyperparameters.get('architecture', 'impala') 170 | in_channels = observation_shape[0] 171 | action_space = env.action_space 172 | 173 | # Model architecture 174 | if architecture == 'nature': 175 | model = NatureModel(in_channels=in_channels) 176 | elif architecture == 'impala': 177 | model = ImpalaModel(in_channels=in_channels) 178 | 179 | # Discrete action space 180 | recurrent = hyperparameters.get('recurrent', False) 181 | if isinstance(action_space, gym.spaces.Discrete): 182 | action_size = action_space.n 183 | policy = CategoricalPolicy(model, recurrent, action_size) 184 | else: 185 | raise NotImplementedError 186 | policy.to(device) 187 | 188 | ############# 189 | ## STORAGE ## 190 | ############# 191 | print('INITIALIZING STORAGE...') 192 | hidden_state_dim = model.output_dim 193 | storage = Storage(observation_shape, hidden_state_dim, n_steps, n_envs, device) 194 | storage_valid = Storage(observation_shape, hidden_state_dim, n_steps, n_envs, device) 195 | 196 | ########### 197 | ## AGENT ## 198 | ########### 199 | print('INITIALIZING AGENT...') 200 | algo = hyperparameters.get('algo', 'ppo') 201 | if algo == 'ppo': 202 | from agents.ppo import PPO as AGENT 203 | else: 204 | raise NotImplementedError 205 | agent = AGENT(env, policy, logger, storage, device, 206 | num_checkpoints, 207 | env_valid=env_valid, 208 | storage_valid=storage_valid, 209 | **hyperparameters) 210 | if args.model_file is not None: 211 | print("Loading agent from %s" % args.model_file) 212 | checkpoint = torch.load(args.model_file) 213 | agent.policy.load_state_dict(checkpoint["model_state_dict"]) 214 | agent.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 215 | 216 | ############## 217 | ## TRAINING ## 218 | ############## 219 | print('START TRAINING...') 220 | agent.train(num_timesteps) 221 | --------------------------------------------------------------------------------
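Note on resuming training: when `train.py` is given `--model_file auto`, it searches `logs/train/<env_name>/<exp_name>` for the single run directory that already contains saved `model_*.pth` checkpoints, picks the checkpoint with the largest `n` in `model_n.pth`, loads its model and optimizer state, and keeps logging to that same directory (it raises an error if zero or more than one matching run is found). A minimal sketch of such a resume call; the experiment name and flag values below are placeholders, not values taken from the paper's runs:

```
python train.py --exp_name my_coinrun_run --env_name coinrun --distribution_mode hard --param_name hard-500 --model_file auto
```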