├── interp ├── common │ ├── __init__.py │ ├── models.py │ ├── atari_interp.py │ ├── atari_no_score.py │ └── wrappers.py ├── __init__.py └── utils.py ├── agents └── .gitignore ├── videos └── .gitignore ├── datasets └── .gitignore ├── reward-models └── .gitignore ├── .gitmodules ├── setup.py ├── requirements.txt ├── execute_maze_runs ├── prepare ├── README.md ├── .gitignore ├── scripts ├── maze_train_reward_model.py ├── atari_train_reward_model.py ├── maze_create_dataset.py ├── atari_create_dataset.py ├── atari_record_videos.py ├── maze_figures.py └── maze_train_agent.py └── paper-notebooks ├── gridworld-training-curves.ipynb └── atari-figures.ipynb /interp/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /agents/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /videos/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /datasets/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore hdf5 files in this directory 2 | *.hdf5 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /reward-models/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "rl-baselines3-zoo"] 2 | path = rl-baselines3-zoo 3 | url = https://github.com/ejmichaud/rl-baselines3-zoo 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="interp", 5 | version='0.1', 6 | author='Eric J. Michaud', 7 | packages=['interp'] 8 | ) 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | stable_baselines3 2 | gym[all] 3 | git+https://github.com/ejmichaud/mazelab.git 4 | tqdm 5 | h5py 6 | PyYAML 7 | numpy 8 | matplotlib 9 | scipy 10 | scikit-image 11 | jupyter 12 | -e . 
13 | -------------------------------------------------------------------------------- /interp/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import gym 3 | from .common.atari_no_score import AtariEnvNoScore, supported_games 4 | 5 | for game in supported_games: 6 | name = ''.join([g.capitalize() for g in game.split('_')]) 7 | obs_type = 'image' 8 | nondeterministic = False 9 | frameskip = 4 10 | gym.register( 11 | id='{}NoFrameskipNoScore-v4'.format(name), 12 | entry_point=AtariEnvNoScore, 13 | kwargs={'game': game, 'obs_type': obs_type, 'frameskip': 1}, # A frameskip of 1 means we get every frame 14 | max_episode_steps=frameskip * 100000, 15 | nondeterministic=nondeterministic, 16 | ) 17 | -------------------------------------------------------------------------------- /execute_maze_runs: -------------------------------------------------------------------------------- 1 | 2 | cd rl-baselines3-zoo 3 | 4 | python train.py --algo ppo \ 5 | --env EmptyMaze-10x10-CoinFlipGoal-v3 \ 6 | -f ../agents-custom \ 7 | --gym-packages mazelab \ 8 | --eval-freq 500 \ 9 | --seed 0 \ 10 | --hyperparam-title CoinFlipGoalWithGroundTruth 11 | 12 | python train.py --algo ppo \ 13 | --env EmptyMaze-10x10-CoinFlipGoal-v3 \ 14 | -f ../agents-custom \ 15 | --gym-packages mazelab \ 16 | --eval-freq 500 \ 17 | --seed 0 \ 18 | --hyperparam-title CoinFlipGoalWithCoinFlipGoalRewardModel 19 | 20 | python train.py --algo ppo \ 21 | --env EmptyMaze-10x10-TwoGoals-v3 \ 22 | -f ../agents-custom \ 23 | --gym-packages mazelab \ 24 | --eval-freq 500 \ 25 | --seed 0 \ 26 | --hyperparam-title TwoGoalsWithGroundTruth 27 | 28 | python train.py --algo ppo \ 29 | --env EmptyMaze-10x10-TwoGoals-v3 \ 30 | -f ../agents-custom \ 31 | --gym-packages mazelab \ 32 | --eval-freq 500 \ 33 | --seed 0 \ 34 | --hyperparam-title TwoGoalsWithCoinFlipGoalRewardModel 35 | 36 | cd .. 37 | -------------------------------------------------------------------------------- /interp/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | 4 | import numpy as np 5 | 6 | from gym.spaces import Box 7 | from gym.wrappers import FrameStack 8 | 9 | def get_latest_run_id(log_path, env_id): 10 | """ 11 | Returns the latest run number for the given log name and log path, 12 | by finding the greatest number in the directories. 13 | 14 | :param log_path: (str) path to log folder 15 | :param env_id: (str) 16 | :return: (int) latest run number 17 | """ 18 | max_run_id = 0 19 | for path in glob.glob(log_path + "/{}_[0-9]*".format(env_id)): 20 | file_name = path.split("/")[-1] 21 | ext = file_name.split("_")[-1] 22 | if env_id == "_".join(file_name.split("_")[:-1]) and ext.isdigit() and int(ext) > max_run_id: 23 | max_run_id = int(ext) 24 | return max_run_id 25 | 26 | 27 | class AtariFrameStack(FrameStack): 28 | """Stacks frames but without using LazyFrame.""" 29 | 30 | def __init__(self, env, n_stack=4): 31 | """ 32 | Wraper for `env` which stacks `n_stack` frames. 
33 | 34 | Args: 35 | env (gym.Env): Environment to wrap 36 | n_stack (int): Number of observations to stack 37 | """ 38 | super(AtariFrameStack, self).__init__(env, n_stack) 39 | # TODO: modify this line to make the wrapper work for more general types of environments, not just 84x84 Atari 40 | self.observation_space = Box(0, 255, shape=[84, 84, 4], dtype=env.observation_space.dtype) 41 | 42 | def _get_observation(self): 43 | assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack) 44 | stack = np.array(list(self.frames)) # (4, 84, 84) 45 | stack = np.transpose(stack, axes=(1, 2, 0)) # (84, 84, 4) 46 | return stack 47 | 48 | -------------------------------------------------------------------------------- /interp/common/models.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch as th 5 | import torch.nn as nn 6 | 7 | from stable_baselines3 import A2C 8 | from stable_baselines3.common.vec_env import VecTransposeImage 9 | 10 | 11 | class MazeRewardModel(nn.Module): 12 | """A simple 2-hidden-layer MLP reward model for mazelab environments.""" 13 | def __init__(self, env, device): 14 | """Iniitalize a reward model with a Gym environment and a PyTorch device.""" 15 | super(MazeRewardModel, self).__init__() 16 | w, h = env.observation_space.shape 17 | features = 2 * w * h 18 | self.net = nn.Sequential( 19 | nn.Flatten(start_dim=1), 20 | nn.Linear(features, 64, bias=True), 21 | nn.Tanh(), 22 | nn.Linear(64, 64, bias=True), 23 | nn.Tanh(), 24 | nn.Linear(64, 1, bias=False), 25 | ).to(device) 26 | self.device = device 27 | 28 | def forward(self, obs): 29 | """Evaluate the reward model on an np.ndarray observation (s, s').""" 30 | return self.net(th.tensor(obs).to(self.device)) 31 | 32 | def tforward(self, ss): 33 | """Evaluate the reward model on an a th.Tensor observation (s, s').""" 34 | return self.net(ss) 35 | 36 | 37 | class AtariRewardModel(nn.Module): 38 | """A reward model for Atari, using the CNN feature extractor that SB3 policies use.""" 39 | def __init__(self, env, device): 40 | super(AtariRewardModel, self).__init__() 41 | self.ac_model = A2C('CnnPolicy', env).policy 42 | self.reward_net = nn.Linear(512, 1).to(device) 43 | self.device = device 44 | 45 | def forward(self, obs): 46 | obs_transposed = VecTransposeImage.transpose_image(obs) 47 | latent, _, _= self.ac_model._get_latent(th.tensor(obs_transposed).to(self.device)) 48 | return self.reward_net(latent) 49 | 50 | def forward_tensor(self, obs): 51 | """obs is a tensor which has already been transposed correctly.""" 52 | latent, _, _= self.ac_model._get_latent(obs.to(self.device)) 53 | return self.reward_net(latent) 54 | 55 | def freeze_extractor(self): 56 | for p in self.ac_model.policy.features_extractor.parameters(): 57 | p.requires_grad = False 58 | -------------------------------------------------------------------------------- /prepare: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # -------------------------------------------------------------------- 4 | # Prepare the rl-baselines3-zoo submodule for use with mazelab. 
5 | # -------------------------------------------------------------------- 6 | 7 | hyperparams = r""" 8 | CoinFlipGoalWithGroundTruth: 9 | policy: 'MlpPolicy' 10 | gamma: 0.97 11 | n_timesteps: !!float 45000 12 | normalize: false 13 | learning_rate: lin_1.0e-4 14 | 15 | CoinFlipGoalWithCoinFlipGoalRewardModel: 16 | policy: 'MlpPolicy' 17 | gamma: 0.97 18 | n_timesteps: !!float 45000 19 | normalize: false 20 | learning_rate: lin_1.0e-4 21 | env_wrapper: 22 | - interp.common.wrappers.WrapWithRewardModel: 23 | model_path: "../reward-models/EmptyMaze-10x10-CoinFlipGoal-v3-reward_model.pt" 24 | 25 | TwoGoalsWithGroundTruth: 26 | policy: 'MlpPolicy' 27 | gamma: 0.97 28 | n_timesteps: !!float 45000 29 | normalize: false 30 | learning_rate: lin_1.0e-4 31 | 32 | TwoGoalsWithCoinFlipGoalRewardModel: 33 | policy: 'MlpPolicy' 34 | gamma: 0.97 35 | n_timesteps: !!float 45000 36 | normalize: false 37 | learning_rate: lin_1.0e-4 38 | env_wrapper: 39 | - interp.common.wrappers.WrapWithRewardModel: 40 | model_path: "../reward-models/EmptyMaze-10x10-CoinFlipGoal-v3-reward_model.pt" 41 | 42 | 43 | 44 | EmptyMaze-10x10-FixedGoal-v3: 45 | policy: 'MlpPolicy' 46 | gamma: 0.97 47 | n_timesteps: !!float 30000 48 | normalize: false 49 | learning_rate: lin_1.0e-4 50 | 51 | EmptyMaze-10x10-FixedGoal-v3: 52 | policy: 'MlpPolicy' 53 | gamma: 0.97 54 | n_timesteps: !!float 30000 55 | normalize: false 56 | learning_rate: lin_1.0e-4 57 | 58 | EmptyMaze-10x10-CoinFlipGoal-v3: 59 | policy: 'MlpPolicy' 60 | gamma: 0.97 61 | n_timesteps: !!float 45000 62 | normalize: false 63 | learning_rate: lin_1.0e-4 64 | 65 | EmptyMaze-10x10-RandomGoal-v3: 66 | policy: 'MlpPolicy' 67 | gamma: 0.97 68 | n_timesteps: !!float 250000 69 | normalize: false 70 | learning_rate: lin_1.0e-4 71 | 72 | """ 73 | 74 | if __name__ == '__main__': 75 | with open('rl-baselines3-zoo/hyperparams/ppo.yml', 'r') as f: 76 | original = f.read() 77 | new = hyperparams + original 78 | with open('rl-baselines3-zoo/hyperparams/ppo.yml', 'w') as f: 79 | f.write(new) 80 | -------------------------------------------------------------------------------- /interp/common/atari_interp.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import scipy.ndimage 5 | import skimage.transform 6 | 7 | """ 8 | Note that this code closely follows that of: 9 | https://github.com/greydanus/visualize_atari/blob/master/saliency.py 10 | See Greydanus et al., Visualizing and Understanding Atari Agents. Url: https://arxiv.org/abs/1711.00138 11 | """ 12 | 13 | def get_mask(center, size, r): 14 | """Creates a normalized mask (np.ndarray) of shape `size`, radius `r`, centered at `center`.""" 15 | y,x = np.ogrid[-center[0]:size[0]-center[0], -center[1]:size[1]-center[1]] 16 | keep = x*x + y*y <= 1 17 | mask = np.zeros(size) ; mask[keep] = 1 # select a circle of pixels 18 | mask = scipy.ndimage.gaussian_filter(mask, sigma=r) # blur the circle of pixels. 
this is a 2D Gaussian for r=r^2=1 19 | return mask/mask.max() 20 | 21 | def occlude(img, mask): 22 | """Uses `mask` to occlude a region of `img`.""" 23 | assert len(image.shape) == 4, and image.shape[0] == 1, "`img` must have shape (1, h, w, n) where k >= 1" 24 | img = np.copy(img) 25 | n = img.shape[-1] 26 | for k in range(n): 27 | I = img[0, :, :, k] 28 | img[0, :, :, k] = I*(1-mask) + scipy.ndimage.gaussian_filter(I, sigma=3)*mask 29 | return img 30 | 31 | def compute_saliency_map(reward_model, obs, output_shape=(210, 160), stride=5, radius=5): 32 | """Computes a saliency map of `reward_model` over `obs`, reshaped to `ouput_shape`.""" 33 | baseline = reward_model(obs).detach().cpu().numpy() 34 | _, h, w, _ = obs.shape 35 | scores = np.zeros((h // stride + 1, w // stride + 1)) 36 | for i in range(0, h, stride): 37 | for j in range(0, w, stride): 38 | mask = get_mask(center=(i, j), size=(h, w), r=radius) 39 | obs_perturbed = occlude(obs, mask) 40 | perturbed_reward = reward_model(obs_perturbed).detach().cpu().numpy() 41 | scores[i // stride, j // stride] = 0.5 * np.abs(perturbed_reward - baseline) ** 2 42 | # pmax = scores.max() 43 | scores = skimage.transform.resize(scores, output_shape=output_shape) 44 | scores = scores.astype(np.float32) 45 | # return pmax * scores / scores.max() 46 | return scores / scores.max() 47 | 48 | def add_saliency_to_frame(frame, saliency, channel=1): 49 | """Impose saliency map `saliency` over image `frame`.""" 50 | pmax = saliency.max() 51 | I = frame.astype('uint16') 52 | I[:, :, channel] += (frame.max() * saliency).astype('uint16') 53 | I = I.clip(1,255).astype('uint8') 54 | return I 55 | 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # interpreting-rewards 2 | 3 | This repository accompanies the paper [Understanding Learned Reward Functions](https://ericjmichaud.com/rewards.pdf) by [Eric J. Michaud](https://ericjmichaud.com), [Adam Gleave](https://gleave.me) and [Stuart Russell](https://people.eecs.berkeley.edu/~russell/). It aims to enable easy reproduction of the results from that paper and to serve as a branching-off point for future iterations of the work. **Note that this repository is still very much a work in progress. Although you should be able to replicate many of the figures from scratch, the pipeline for training agents using a learned reward function (which Figure 3 and Table 1 need) is currently broken.** 4 | 5 | 6 | ## Installation 7 | 8 | First, clone the repository: 9 | ``` 10 | git clone --recurse-submodules https://github.com/HumanCompatibleAI/interpreting-rewards.git 11 | ``` 12 | Which will also clone a special version of `rl-baselines3-zoo` as a submodule. 13 | 14 | To install dependencies, run 15 | ``` 16 | pip install -r requirements.txt 17 | pip install -r rl-baselines3-zoo/requirements.txt 18 | ``` 19 | 20 | ## Usage 21 | 22 | **To replicate gridworld results from the very beginning**, first train a policy on the gridworld environment: 23 | ``` 24 | ./prepare 25 | cd rl-baselines3-zoo 26 | python train.py --algo ppo --env EmptyMaze-10x10-CoinFlipGoal-v3 -f ../agents --seed 0 --gym-packages mazelab 27 | cd .. 
28 | ``` 29 | Then use this policy to create a dataset of (transition, reward) pairs, to train a reward model on via regression: 30 | ``` 31 | cd scripts 32 | python maze_create_dataset.py --algo ppo --env EmptyMaze-10x10-CoinFlipGoal-v3 -f ../agents --seed 0 33 | ``` 34 | And train the reward model: 35 | ``` 36 | python maze_train_reward_model.py --env EmptyMaze-10x10-CoinFlipGoal-v3 --epochs 5 --seed 0 37 | ``` 38 | From here, the saliency map figures from the paper can be created with a single command: 39 | ``` 40 | python maze_figures.py 41 | ``` 42 | Which will create and save the figures to a `./figures` directory of the repository root directory. 43 | 44 | **To replicate Atari results**, first train policies on Breakout and Seaquest: 45 | ``` 46 | cd rl-baselines3-zoo 47 | python train.py --algo ppo --env BreakoutNoFrameskip-v4 -f ../agents 48 | python train.py --algo ppo --env SeaquestNoFrameskip-v4 -f ../agents 49 | ``` 50 | Create datasets for each: 51 | ``` 52 | cd ../scripts 53 | python atari_create_dataset.py --algo ppo --env BreakoutNoFrameskip-v4 -f ../agents --seed 0 54 | python atari_create_dataset.py --algo ppo --env SeaquestNoFrameskip-v4 -f ../agents --seed 0 55 | ``` 56 | Train reward models: 57 | ``` 58 | python atari_train_reward_model.py --algo ppo --env BreakoutNoFrameskip-v4 -f ../agents --seed 0 --epochs 5 59 | python atari_train_reward_model.py --algo ppo --env SeaquestNoFrameskip-v4 -f ../agents --seed 0 --epochs 5 60 | ``` 61 | 62 | ## TODO: 63 | * Finish testing the Atari pipeline, add a script for creating Atari figures 64 | * Fix the scripts which train maze agents and Atari agents using a learned reward function. Due to changes in `rl-baselines3-zoo` since I last ran these, they no longer work. The Atari script will need a code review, as the results that it generated before were weird (large differences the performance of the policy trained on ground-truth reward vs. reward model, despite the reward model having close to 0 test error. This could be an issue with the reward model only being trained on transitions from an expert policy, but I worry it could be something squirrelly with the script or my custom wrappers. 65 | * Add scripts for downloading parts of the pipeline from AWS without having to run the policy training, dataset creation, reward model training scripts, etc. 
66 | 67 | 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,macos 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks,macos 4 | 5 | ### JupyterNotebooks ### 6 | # gitignore template for Jupyter Notebooks 7 | # website: http://jupyter.org/ 8 | 9 | .ipynb_checkpoints 10 | */.ipynb_checkpoints/* 11 | 12 | # IPython 13 | profile_default/ 14 | ipython_config.py 15 | 16 | # Remove previous ipynb_checkpoints 17 | # git rm -r .ipynb_checkpoints/ 18 | 19 | ### macOS ### 20 | # General 21 | .DS_Store 22 | .AppleDouble 23 | .LSOverride 24 | 25 | # Icon must end with two \r 26 | Icon 27 | 28 | # Thumbnails 29 | ._* 30 | 31 | # Files that might appear in the root of a volume 32 | .DocumentRevisions-V100 33 | .fseventsd 34 | .Spotlight-V100 35 | .TemporaryItems 36 | .Trashes 37 | .VolumeIcon.icns 38 | .com.apple.timemachine.donotpresent 39 | 40 | # Directories potentially created on remote AFP share 41 | .AppleDB 42 | .AppleDesktop 43 | Network Trash Folder 44 | Temporary Items 45 | .apdisk 46 | 47 | ### Python ### 48 | # Byte-compiled / optimized / DLL files 49 | __pycache__/ 50 | *.py[cod] 51 | *$py.class 52 | 53 | # C extensions 54 | *.so 55 | 56 | # Distribution / packaging 57 | .Python 58 | build/ 59 | develop-eggs/ 60 | dist/ 61 | downloads/ 62 | eggs/ 63 | .eggs/ 64 | lib/ 65 | lib64/ 66 | parts/ 67 | sdist/ 68 | var/ 69 | wheels/ 70 | pip-wheel-metadata/ 71 | share/python-wheels/ 72 | *.egg-info/ 73 | .installed.cfg 74 | *.egg 75 | MANIFEST 76 | 77 | # PyInstaller 78 | # Usually these files are written by a python script from a template 79 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 80 | *.manifest 81 | *.spec 82 | 83 | # Installer logs 84 | pip-log.txt 85 | pip-delete-this-directory.txt 86 | 87 | # Unit test / coverage reports 88 | htmlcov/ 89 | .tox/ 90 | .nox/ 91 | .coverage 92 | .coverage.* 93 | .cache 94 | nosetests.xml 95 | coverage.xml 96 | *.cover 97 | *.py,cover 98 | .hypothesis/ 99 | .pytest_cache/ 100 | 101 | # Translations 102 | *.mo 103 | *.pot 104 | 105 | # Django stuff: 106 | *.log 107 | local_settings.py 108 | db.sqlite3 109 | db.sqlite3-journal 110 | 111 | # Flask stuff: 112 | instance/ 113 | .webassets-cache 114 | 115 | # Scrapy stuff: 116 | .scrapy 117 | 118 | # Sphinx documentation 119 | docs/_build/ 120 | 121 | # PyBuilder 122 | target/ 123 | 124 | # Jupyter Notebook 125 | 126 | # IPython 127 | 128 | # pyenv 129 | .python-version 130 | 131 | # pipenv 132 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 133 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 134 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 135 | # install all needed dependencies. 136 | #Pipfile.lock 137 | 138 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 139 | __pypackages__/ 140 | 141 | # Celery stuff 142 | celerybeat-schedule 143 | celerybeat.pid 144 | 145 | # SageMath parsed files 146 | *.sage.py 147 | 148 | # Environments 149 | .env 150 | .venv 151 | env/ 152 | venv/ 153 | ENV/ 154 | env.bak/ 155 | venv.bak/ 156 | 157 | # Spyder project settings 158 | .spyderproject 159 | .spyproject 160 | 161 | # Rope project settings 162 | .ropeproject 163 | 164 | # mkdocs documentation 165 | /site 166 | 167 | # mypy 168 | .mypy_cache/ 169 | .dmypy.json 170 | dmypy.json 171 | 172 | # Pyre type checker 173 | .pyre/ 174 | 175 | # pytype static type analyzer 176 | .pytype/ 177 | 178 | # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,macos 179 | 180 | 181 | -------------------------------------------------------------------------------- /scripts/maze_train_reward_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import difflib 3 | import os 4 | import sys 5 | import importlib 6 | import time 7 | import uuid 8 | import random 9 | import warnings 10 | from collections import OrderedDict 11 | from pprint import pprint 12 | 13 | from tqdm.auto import tqdm 14 | import h5py 15 | 16 | import yaml 17 | import gym 18 | import numpy as np 19 | import torch as th 20 | # For custom activation fn 21 | import torch.nn as nn # noqa: F401 pytype: disable=unused-import 22 | 23 | from stable_baselines3.common.utils import set_random_seed 24 | # from stable_baselines3.common.cmd_util import make_atari_env 25 | from stable_baselines3.common.vec_env import VecFrameStack, VecNormalize, DummyVecEnv, VecTransposeImage 26 | from stable_baselines3.common.atari_wrappers import AtariWrapper 27 | from stable_baselines3.common.preprocessing import is_image_space 28 | from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise 29 | from stable_baselines3.common.utils import constant_fn, get_device 30 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 31 | from stable_baselines3 import A2C 32 | 33 | import mazelab 34 | 35 | from interp.common.models import MazeRewardModel 36 | 37 | class RewardData(th.utils.data.Dataset): 38 | def __init__(self, env_id, train=True): 39 | self.f = h5py.File(f"../datasets/rewards_{env_id}.hdf5", 'r') 40 | if train: 41 | self.group = self.f['train'] 42 | else: 43 | self.group = self.f['test'] 44 | 45 | def __getitem__(self, k): 46 | input = self.group['inputs'][k] 47 | output = self.group['outputs'][k] 48 | return (input, output) 49 | 50 | def __len__(self): 51 | return self.group['inputs'].shape[0] 52 | 53 | def close(self): 54 | self.f.close() 55 | 56 | 57 | if __name__ == '__main__': # noqa: C901 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('--env', type=str, default="BreakoutNoFrameskip-v4", help='environment ID') 60 | parser.add_argument('-e', '--epochs', help='Number of epochs to train for', default=5, 61 | type=int) 62 | parser.add_argument('-s', '--seed', help="Random seed", default=0, type=int) 63 | args = parser.parse_args() 64 | 65 | device = get_device() 66 | print(f"Using {device} device.") 67 | 68 | seed = args.seed 69 | random.seed(seed) 70 | np.random.seed(seed) 71 | th.manual_seed(seed) 72 | th.backends.cudnn.deterministic = True 73 | th.backends.cudnn.benchmark = False 74 | set_random_seed(seed) 75 | 76 | env_id = args.env 77 | if 'maze' not in env_id.lower(): 78 | raise Exception(f"env {env_id} is not a maze env") 79 | env = 
gym.make(env_id) 80 | print(f"Created env with obs.shape = {env.reset().shape}.") 81 | 82 | train = RewardData(env_id, train=True) 83 | test = RewardData(env_id, train=False) 84 | 85 | train_loader = th.utils.data.DataLoader(train, batch_size=20, shuffle=True, num_workers=0) 86 | test_loader = th.utils.data.DataLoader(test, batch_size=20, shuffle=False, num_workers=0) 87 | 88 | reward_model = MazeRewardModel(env, device) 89 | optimizer = th.optim.Adam(reward_model.parameters()) 90 | loss_fn = th.nn.MSELoss(reduction="sum") 91 | 92 | num_batches = 0 93 | for e in range(args.epochs): 94 | for samples, targets in tqdm(train_loader): 95 | optimizer.zero_grad() 96 | batch_loss = loss_fn(reward_model(samples), targets.to(device)) 97 | batch_loss.backward() 98 | optimizer.step() 99 | num_batches += 1 100 | test_loss = 0 101 | for samples, targets in test_loader: 102 | with th.no_grad(): 103 | test_loss += loss_fn(reward_model(samples), targets.to(device)) 104 | print("Epoch {:3d} | Test Loss: {:.6f}".format(e, float(test_loss) / len(test))) 105 | 106 | th.save(reward_model.state_dict(), f"../reward-models/{env_id}-reward_model.pt") 107 | 108 | -------------------------------------------------------------------------------- /scripts/atari_train_reward_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import difflib 3 | import os 4 | import sys 5 | import importlib 6 | import time 7 | import random 8 | import uuid 9 | import warnings 10 | from collections import OrderedDict 11 | from pprint import pprint 12 | 13 | from tqdm.auto import tqdm 14 | import h5py 15 | 16 | import yaml 17 | import gym 18 | import numpy as np 19 | import torch as th 20 | # For custom activation fn 21 | import torch.nn as nn # noqa: F401 pytype: disable=unused-import 22 | 23 | from stable_baselines3.common.utils import set_random_seed 24 | # from stable_baselines3.common.cmd_util import make_atari_env 25 | from stable_baselines3.common.vec_env import VecFrameStack, VecNormalize, DummyVecEnv, VecTransposeImage 26 | from stable_baselines3.common.atari_wrappers import AtariWrapper 27 | from stable_baselines3.common.preprocessing import is_image_space 28 | from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise 29 | from stable_baselines3.common.utils import constant_fn, get_device 30 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 31 | from stable_baselines3 import A2C 32 | 33 | from interp.common.models import AtariRewardModel 34 | 35 | class RewardData(th.utils.data.Dataset): 36 | def __init__(self, env_id, train=True): 37 | self.f = h5py.File(f"../datasets/rewards_{env_id}.hdf5", 'r') 38 | if train: 39 | self.group = self.f['train'] 40 | else: 41 | self.group = self.f['test'] 42 | 43 | def __getitem__(self, k): 44 | if k % 2 == 0: 45 | input = self.group['zeros-inputs'][k // 2] 46 | label = self.group['zeros-labels'][k // 2] 47 | return (input, label) 48 | else: 49 | input = self.group['ones-inputs'][k // 2] 50 | label = self.group['ones-labels'][k // 2] 51 | return (input, label) 52 | 53 | def __len__(self): 54 | return self.group['ones-labels'].shape[0] + self.group['zeros-labels'].shape[0] 55 | 56 | def close(self): 57 | self.f.close() 58 | 59 | 60 | if __name__ == '__main__': # noqa: C901 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('--env', type=str, default="BreakoutNoFrameskip-v4", help='environment ID') 63 | parser.add_argument('-e', '--epochs', help='Number of 
epochs to train for', default=5, 64 | type=int) 65 | parser.add_argument('-s', '--seed', help="Random seed", default=0, type=int) 66 | args = parser.parse_args() 67 | 68 | device = get_device() 69 | print(f"Using {device} device.") 70 | 71 | seed = args.seed 72 | random.seed(seed) 73 | np.random.seed(seed) 74 | th.manual_seed(seed) 75 | th.backends.cudnn.deterministic = True 76 | th.backends.cudnn.benchmark = False 77 | set_random_seed(seed) 78 | 79 | env_id = args.env 80 | if 'NoFrameskip' not in env_id: 81 | raise Exception(f"env {env_id} is not an Atari env") 82 | env = gym.make(args.env) 83 | env = AtariWrapper(env) 84 | env = DummyVecEnv([lambda: env]) 85 | env = VecFrameStack(env, n_stack=4) 86 | print(f"Created env with obs.shape = {env.reset().shape}.") 87 | 88 | 89 | train = RewardData(env_id, train=True) 90 | test = RewardData(env_id, train=False) 91 | 92 | train_loader = th.utils.data.DataLoader(train, batch_size=20, shuffle=True, num_workers=0) 93 | test_loader = th.utils.data.DataLoader(test, batch_size=20, shuffle=False, num_workers=0) 94 | 95 | reward_model = AtariRewardModel(env, device) 96 | optimizer = th.optim.Adam(reward_model.parameters()) 97 | loss_fn = th.nn.MSELoss(reduction="sum") 98 | 99 | num_batches = 0 100 | for e in range(args.epochs): 101 | for samples, targets in tqdm(train_loader): 102 | optimizer.zero_grad() 103 | batch_loss = loss_fn(reward_model(samples), targets.to(device)) 104 | batch_loss.backward() 105 | optimizer.step() 106 | num_batches += 1 107 | test_loss = 0 108 | for samples, targets in test_loader: 109 | with th.no_grad(): 110 | test_loss += loss_fn(reward_model(samples), targets.to(device)) 111 | print("Epoch {:3d} | Test Loss: {:.4f}".format(e, float(test_loss) / len(test))) 112 | 113 | th.save(reward_model.state_dict(), f"../reward-models/{env_id}-reward_model.pt") 114 | 115 | -------------------------------------------------------------------------------- /interp/common/atari_no_score.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import scipy.ndimage 5 | 6 | import gym 7 | from gym.envs.atari import AtariEnv 8 | 9 | supported_games = { 10 | 'breakout': { 11 | 'shape': (210, 160, 3), 12 | 'regions': [ 13 | ( 14 | (4, 16), # row range to modify 15 | (35, 81) # column range to modify 16 | ) 17 | ], 18 | 'sample_background_loc': (10, 10) 19 | }, 20 | 'seaquest': { 21 | 'shape': (210, 160, 3), 22 | 'regions': [ 23 | ( 24 | (7, 19), 25 | (70, 120) 26 | ) 27 | ], 28 | 'sample_background_loc': (8, 65) 29 | }, 30 | 'pong': { 31 | 'shape': (210, 160, 3), 32 | 'regions': [ 33 | ( 34 | (0, 22), 35 | (15, 145) 36 | ) 37 | ], 38 | 'sample_background_loc': (10, 10) 39 | }, 40 | 'space_invaders': { 41 | 'shape': (210, 160, 3), 42 | 'regions': [ 43 | ( 44 | (8, 22), 45 | (0, 70) 46 | ), 47 | ( 48 | (8, 22), 49 | (80, 150) 50 | ) 51 | ], 52 | 'sample_background_loc': (7, 7) 53 | }, 54 | 'tennis': { 55 | 'shape': (250, 160, 3), 56 | 'regions': [ 57 | ( 58 | (30, 39), 59 | (30, 72) 60 | ), 61 | ( 62 | (30, 39), 63 | (94, 136) 64 | ), 65 | ], 66 | 'sample_background_loc': (31, 29) 67 | } 68 | } 69 | 70 | 71 | class AtariEnvModifiedScoreBase(AtariEnv, ABC): 72 | """ 73 | Base, abstract class for Atari environments which obscure their score in some way. 
74 | """ 75 | 76 | def __init__(self, *args, **kwargs): 77 | if kwargs.get('obs_type') and kwargs['obs_type'] == 'ram': 78 | raise Exception("Only image-based observations can have their score obscurred.") 79 | super(AtariEnvModifiedScoreBase, self).__init__(*args, **kwargs) 80 | assert self.game in supported_games 81 | self.mask = np.zeros(supported_games[self.game]['shape']) 82 | for ((r0, r1), (c0, c1)) in supported_games[self.game]['regions']: 83 | self.mask[r0:r1, c0:c1, :] = 1 84 | 85 | @abstractmethod 86 | def _get_image(self): 87 | pass 88 | 89 | class AtariEnvNoScore(AtariEnvModifiedScoreBase): 90 | """ 91 | An Atari Environment without the score displayed. 92 | """ 93 | 94 | def __init__(self, *args, **kwargs): 95 | """ 96 | Create Atari Environment without the score displayed. 97 | 98 | Args: 99 | game (str): The name of the Atari game, in lowercase. Ex: 'breakout' 100 | mode: As far as I can tell, this parameter is never used. Defaults to None. 101 | difficulty: Also appears to not be used. Defaults to None. 102 | obs_type (str): 'image' or 'ram' 103 | frameskip (tuple or int): Set to 1 for NoFrameskip. 104 | repeat_action_probability (float): Does what you think. 105 | full_action_space (bool): whether to use the full action space. 106 | """ 107 | super(AtariEnvNoScore, self).__init__(*args, **kwargs) 108 | 109 | def _get_image(self): 110 | image = self.ale.getScreenRGB2() 111 | assert image.shape == self.mask.shape, "Game observation dimensions don't mask mask dimensions" 112 | sx, sy = supported_games[self.game]['sample_background_loc'] 113 | color = image[sx, sy] 114 | image = image * (1 - self.mask) + (self.mask * color) 115 | return image.astype(np.uint8) 116 | 117 | 118 | class AtariEnvBlurScore(AtariEnvModifiedScoreBase): 119 | """ 120 | An Atari Environment with the score blurred out. 121 | """ 122 | 123 | def __init__(self, *args, **kwargs): 124 | """ 125 | Create Atari Environment with the score blurred out. 126 | 127 | Args: 128 | game (str): The name of the Atari game, in lowercase. Ex: 'breakout' 129 | mode: As far as I can tell, this parameter is never used. Defaults to None. 130 | difficulty: Also appears to not be used. Defaults to None. 131 | obs_type (str): 'image' or 'ram' 132 | frameskip (tuple or int): Set to 1 for NoFrameskip. 133 | repeat_action_probability (float): Does what you think. Defaults to 0. 134 | full_action_space (bool): whether to use the full action space. Defaults to False. 135 | """ 136 | super(AtariEnvBlurScore, self).__init__(*args, **kwargs) 137 | 138 | def _get_image(self): 139 | image = self.ale.getScreenRGB2() 140 | assert image.shape == self.mask.shape, "Game observation dimensions don't mask mask dimensions" 141 | A = scipy.ndimage.gaussian_filter(image, sigma=(3, 3, 0)) 142 | image = image * (1 - self.mask) + (A * self.mask) 143 | return image.astype(np.uint8) 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /interp/common/wrappers.py: -------------------------------------------------------------------------------- 1 | 2 | from abc import ABC, abstractmethod 3 | 4 | import numpy as np 5 | import scipy.ndimage 6 | 7 | import torch as th 8 | 9 | import gym 10 | from stable_baselines3.common.type_aliases import GymStepReturn 11 | from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper 12 | 13 | from .models import MazeRewardModel, AtariRewardModel 14 | 15 | 16 | class DummyWrapper(VecEnvWrapper): 17 | """ 18 | Wraps the venv but does absolutely nothing. 
19 | """ 20 | 21 | def __init__(self, venv): 22 | super(DummyWrapper, self).__init__(venv) 23 | 24 | def step_wait(self) -> 'GymStepReturn': 25 | observations, rewards, dones, infos = self.venv.step_wait() 26 | return observations, rewards, dones, infos 27 | 28 | def reset(self) -> np.ndarray: 29 | return self.venv.reset() 30 | 31 | 32 | class CustomRewardVecWrapper(VecEnvWrapper): 33 | """ 34 | Wrapper for overriding the environment reward with a custom reward function. 35 | """ 36 | 37 | def __init__(self, venv, reward_function): 38 | super(CustomRewardVecWrapper, self).__init__(venv) 39 | self.reward_function = reward_function 40 | 41 | def step_wait(self) -> 'GymStepReturn': 42 | observations, rewards, dones, infos = self.venv.step_wait() 43 | custom_rewards = self.reward_function(observations) 44 | if type(custom_rewards) is th.Tensor: 45 | custom_rewards = custom_rewards.cpu().detach().numpy() 46 | if len(custom_rewards.shape) == 2: 47 | custom_rewards = custom_rewards.flatten() 48 | elif len(custom_rewards.shape) == 1: 49 | pass 50 | else: 51 | raise Exception("Weirdly shaped reward from custom reward function") 52 | return observations, custom_rewards, dones, infos 53 | 54 | def reset(self) -> np.ndarray: 55 | return self.venv.reset() 56 | 57 | 58 | class CustomRewardWrapper(gym.Wrapper): 59 | def __init__(self, env, reward_function): 60 | super(CustomRewardWrapper, self).__init__(env) 61 | self.reward_function = reward_function 62 | 63 | def step(self, action): 64 | next_state, reward, done, info = self.env.step(action) 65 | custom_reward = self.reward_function(next_state) 66 | if type(custom_reward) is th.Tensor: 67 | custom_reward = custom_reward.cpu().detach().numpy() 68 | if len(custom_reward.shape) == 2: 69 | custom_reward = custom_reward.flatten() 70 | elif len(custom_reward.shape) == 1: 71 | pass 72 | else: 73 | raise Exception("Weirdly shaped reward from custom reward function") 74 | return observations, custom_reward, dones, infos 75 | 76 | 77 | class CustomRewardSSVecWrapper(VecEnvWrapper): 78 | """ 79 | Overrides environment reward with a given reward function R(s, s'). 
80 | """ 81 | 82 | def __init__(self, venv, reward_function): 83 | super(CustomRewardSSVecWrapper, self).__init__(venv) 84 | self.reward_function = reward_function 85 | self.prev_obs = None 86 | 87 | def step_wait(self) -> 'GymStepReturn': 88 | obs, rewards, dones, infos = self.venv.step_wait() 89 | custom_rewards = [] 90 | for k in range(len(rewards)): 91 | if dones[k]: 92 | reward_input = np.array((self.prev_obs[k], infos[k]['terminal_observation'])).astype(np.float32) 93 | reward_input = np.expand_dims(reward_input, axis=0) 94 | custom_rewards.append(self.reward_function(reward_input)) 95 | else: 96 | reward_input = np.array((self.prev_obs[k], obs[k])).astype(np.float32) 97 | reward_input = np.expand_dims(reward_input, axis=0) 98 | custom_rewards.append(self.reward_function(reward_input)) 99 | custom_rewards = np.array(custom_rewards) 100 | if type(custom_rewards) is th.Tensor: 101 | custom_rewards = custom_rewards.cpu().detach().numpy() 102 | if len(custom_rewards.shape) == 2: 103 | custom_rewards = custom_rewards.flatten() 104 | elif len(custom_rewards.shape) == 1: 105 | pass 106 | else: 107 | raise Exception("Weirdly shaped reward from custom reward function") 108 | self.prev_obs = obs 109 | return obs, custom_rewards, dones, infos 110 | 111 | def reset(self) -> np.ndarray: 112 | obs = self.venv.reset() 113 | self.prev_obs = obs 114 | return obs 115 | 116 | 117 | class CustomRewardSSWrapper(gym.Wrapper): 118 | """ 119 | Overrides environment reward with a given reward function R(s, s'). 120 | """ 121 | 122 | def __init__(self, env, reward_function): 123 | super(CustomRewardSSWrapper, self).__init__(env) 124 | self.reward_function = reward_function 125 | self.prev_obs = None 126 | 127 | def step(self, action): 128 | next_state, reward, done, info = self.env.step(action) 129 | # if done: 130 | # reward_input = np.array((self.prev_obs, info['terminal_observation'])).astype(np.float32) 131 | # reward_input = np.expand_dims(reward_input, axis=0) 132 | # custom_reward = self.reward_function(reward_input) 133 | # else: 134 | reward_input = np.array((self.prev_obs, next_state)).astype(np.float32) 135 | reward_input = np.expand_dims(reward_input, axis=0) 136 | custom_reward = self.reward_function(reward_input) 137 | self.prev_obs = next_state 138 | return next_state, custom_reward, done, info 139 | 140 | def reset(self): 141 | obs = self.env.reset() 142 | self.prev_obs = obs 143 | return obs 144 | 145 | 146 | def WrapWithRewardModel(env, model_path, device='cuda'): 147 | if 'Maze' in model_path: 148 | rm = MazeRewardModel(env, device) 149 | rm.load_state_dict(th.load(model_path)) 150 | return CustomRewardSSWrapper(env, rm) 151 | elif 'NoFrameskip' in model_path: 152 | rm = AtariRewardModel(env, device) 153 | rm.load_state_dict(th.load(model_path)) 154 | return CustomRewardWrapper(env, rm) 155 | else: 156 | raise Exception("Only Maze or Atari environments are supported by this 'wrapper'.") 157 | 158 | -------------------------------------------------------------------------------- /scripts/maze_create_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import importlib 4 | import random 5 | from pathlib import Path 6 | from itertools import product 7 | import argparse 8 | 9 | import h5py 10 | 11 | import gym 12 | import mazelab 13 | 14 | import numpy as np 15 | import matplotlib 16 | import matplotlib.pyplot as plt 17 | import scipy.ndimage 18 | import torch as th 19 | import torch.nn as nn 20 | 21 | from tqdm import 
tqdm 22 | 23 | from stable_baselines3.common.utils import set_random_seed 24 | from stable_baselines3.common.vec_env import VecEnvWrapper, VecEnv, DummyVecEnv 25 | from stable_baselines3.common.vec_env import VecTransposeImage 26 | 27 | sys.path.insert(1, "../rl-baselines3-zoo") 28 | import utils.import_envs # noqa: F401 pylint: disable=unused-import 29 | from utils.utils import StoreDict 30 | from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams 31 | 32 | 33 | if __name__ == '__main__': 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--algo', help='RL Algorithm', default='ppo', 37 | type=str, required=False, choices=list(ALGOS.keys())) 38 | parser.add_argument('--env', type=str, default="EmptyMaze-10x10-CoinFlipGoal-v3", help='environment ID') 39 | parser.add_argument('--exp-id', type=int, default=1, help="experiment ID") 40 | parser.add_argument('-n', '--n-timesteps', help='Overwrite the number of timesteps', default=25000, 41 | type=int) 42 | parser.add_argument('-f', '--log-folder', help='Log folder', type=str, default='../agents') 43 | parser.add_argument('-s', '--seed', help="Random seed", default=0, type=int) 44 | args = parser.parse_args() 45 | 46 | 47 | ########### Set Device ############ 48 | device = th.device('cuda' if th.cuda.is_available() else 'cpu') 49 | dtype = th.float32 50 | th.set_default_dtype(dtype) 51 | print("Using device: {}".format(device)) 52 | 53 | seed = args.seed 54 | random.seed(seed) 55 | np.random.seed(seed) 56 | th.manual_seed(seed) 57 | th.backends.cudnn.deterministic = True 58 | th.backends.cudnn.benchmark = False 59 | set_random_seed(seed) 60 | 61 | 62 | ########### Set Params ############ 63 | env_id = args.env 64 | folder = args.log_folder 65 | algo = args.algo 66 | num_threads = -1 67 | n_envs = 1 68 | exp_id = args.exp_id 69 | verbose = 1 70 | no_render = False 71 | deterministic = False 72 | load_best = True 73 | load_checkpoint = None 74 | norm_reward = False 75 | reward_log = '' 76 | env_kwargs = None 77 | 78 | 79 | # Sanity checks 80 | if exp_id > 0: 81 | log_path = os.path.join(folder, algo, '{}_{}'.format(env_id, exp_id)) 82 | else: 83 | log_path = os.path.join(folder, algo) 84 | 85 | found = False 86 | for ext in ['zip']: 87 | model_path = os.path.join(log_path, f'{env_id}.{ext}') 88 | found = os.path.isfile(model_path) 89 | if found: 90 | break 91 | 92 | if load_best: 93 | model_path = os.path.join(log_path, "best_model.zip") 94 | found = os.path.isfile(model_path) 95 | 96 | if load_checkpoint is not None: 97 | model_path = os.path.join(log_path, f"rl_model_{load_checkpoint}_steps.zip") 98 | found = os.path.isfile(model_path) 99 | 100 | if not found: 101 | raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}") 102 | 103 | if algo in ['dqn', 'ddpg', 'sac', 'td3']: 104 | n_envs = 1 105 | 106 | set_random_seed(seed) 107 | 108 | if num_threads > 0: 109 | if verbose > 1: 110 | print(f"Setting torch.num_threads to {num_threads}") 111 | th.set_num_threads(num_threads) 112 | 113 | is_atari = 'NoFrameskip' in env_id 114 | 115 | stats_path = os.path.join(log_path, env_id) 116 | hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=norm_reward, test_mode=True) 117 | env_kwargs = {} if env_kwargs is None else env_kwargs 118 | 119 | log_dir = reward_log if reward_log != '' else None 120 | 121 | # env = create_test_env(env_id, n_envs=n_envs, 122 | # stats_path=stats_path, seed=seed, log_dir=log_dir, 123 | # should_render=not no_render, 124 | # 
hyperparams=hyperparams, 125 | # env_kwargs=env_kwargs) 126 | env = gym.make(env_id) 127 | obs_shape = env.reset().shape 128 | 129 | model = ALGOS[algo].load(model_path, env=env) 130 | 131 | 132 | database = h5py.File(f"../datasets/rewards_{env_id}.hdf5", 'a') 133 | 134 | SAMPLES = args.n_timesteps 135 | train = database.create_group('train') 136 | inputs = train.create_dataset('inputs', (SAMPLES, 2, *obs_shape)) 137 | outputs = train.create_dataset('outputs', (SAMPLES, 1)) 138 | # zeros_inputs = train.create_dataset('zeros-inputs', (SAMPLES//2, 84, 84, 4)) 139 | # zeros_labels = train.create_dataset('zeros-labels', (SAMPLES//2, 1)) 140 | # ones_inputs = train.create_dataset('ones-inputs', (SAMPLES//2, 84, 84, 4)) 141 | # ones_labels = train.create_dataset('ones-labels', (SAMPLES//2, 1)) 142 | state = env.reset() 143 | next_state = None 144 | for i in tqdm(range(SAMPLES)): 145 | action, _states = model.predict(state, deterministic=False) 146 | next_state, reward, done, info = env.step(action) 147 | inputs[i] = np.array((state, next_state)) 148 | outputs[i] = reward 149 | if done: 150 | state = env.reset() 151 | else: 152 | state = next_state 153 | 154 | 155 | SAMPLES = args.n_timesteps // 2 156 | test = database.create_group('test') 157 | inputs = test.create_dataset('inputs', (SAMPLES, 2, *obs_shape)) 158 | outputs = test.create_dataset('outputs', (SAMPLES, 1)) 159 | # zeros_inputs = train.create_dataset('zeros-inputs', (SAMPLES//2, 84, 84, 4)) 160 | # zeros_labels = train.create_dataset('zeros-labels', (SAMPLES//2, 1)) 161 | # ones_inputs = train.create_dataset('ones-inputs', (SAMPLES//2, 84, 84, 4)) 162 | # ones_labels = train.create_dataset('ones-labels', (SAMPLES//2, 1)) 163 | state = env.reset() 164 | next_state = None 165 | for i in tqdm(range(SAMPLES)): 166 | action, _states = model.predict(state, deterministic=False) 167 | next_state, reward, done, info = env.step(action) 168 | inputs[i] = np.array((state, next_state)) 169 | outputs[i] = reward 170 | if done: 171 | state = env.reset() 172 | else: 173 | state = next_state 174 | 175 | 176 | database.close() 177 | 178 | 179 | -------------------------------------------------------------------------------- /scripts/atari_create_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import importlib 4 | from pathlib import Path 5 | import random 6 | from itertools import product 7 | import argparse 8 | 9 | import h5py 10 | 11 | import gym 12 | import numpy as np 13 | import matplotlib 14 | import matplotlib.pyplot as plt 15 | import scipy.ndimage 16 | import torch as th 17 | import torch.nn as nn 18 | 19 | from tqdm import tqdm 20 | 21 | from stable_baselines3.common.utils import set_random_seed 22 | from stable_baselines3.common.vec_env import VecEnvWrapper, VecEnv, DummyVecEnv 23 | from stable_baselines3.common.vec_env import VecTransposeImage 24 | 25 | sys.path.insert(1, "../rl-baselines3-zoo") 26 | import utils.import_envs # noqa: F401 pylint: disable=unused-import 27 | from utils.utils import StoreDict 28 | from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams 29 | 30 | 31 | if __name__ == '__main__': 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--algo', help='RL Algorithm', default='ppo', 35 | type=str, required=False, choices=list(ALGOS.keys())) 36 | parser.add_argument('--env', type=str, default="BreakoutNoFrameskip-v4", help='environment ID') 37 | parser.add_argument('--exp-id', type=int, default=1, 
help="experiment ID") 38 | parser.add_argument('-n', '--n-timesteps', help='Overwrite the number of timesteps', default=50000, 39 | type=int) 40 | parser.add_argument('-f', '--log-folder', help='Log folder', type=str, default='../agents') 41 | parser.add_argument('-s', '--seed', help="Random seed", default=0, type=int) 42 | args = parser.parse_args() 43 | 44 | ########### Set Device ############ 45 | device = th.device('cuda' if th.cuda.is_available() else 'cpu') 46 | dtype = th.float32 47 | th.set_default_dtype(dtype) 48 | print("Using device: {}".format(device)) 49 | 50 | seed = args.seed 51 | random.seed(seed) 52 | np.random.seed(seed) 53 | th.manual_seed(seed) 54 | th.backends.cudnn.deterministic = True 55 | th.backends.cudnn.benchmark = False 56 | set_random_seed(seed) 57 | 58 | ########### Set Device ############ 59 | device = th.device('cuda' if th.cuda.is_available() else 'cpu') 60 | dtype = th.float32 61 | th.set_default_dtype(dtype) 62 | print("Using device: {}".format(device)) 63 | 64 | ########### Set Params ############ 65 | env_id = args.env 66 | folder = args.log_folder 67 | algo = args.algo 68 | num_threads = -1 69 | n_envs = 1 70 | exp_id = args.exp_id 71 | verbose = 1 72 | no_render = False 73 | deterministic = False 74 | load_best = True 75 | load_checkpoint = None 76 | norm_reward = False 77 | reward_log = '' 78 | env_kwargs = None 79 | 80 | 81 | # Sanity checks 82 | if exp_id > 0: 83 | log_path = os.path.join(folder, algo, '{}_{}'.format(env_id, exp_id)) 84 | else: 85 | log_path = os.path.join(folder, algo) 86 | 87 | found = False 88 | for ext in ['zip']: 89 | model_path = os.path.join(log_path, f'{env_id}.{ext}') 90 | found = os.path.isfile(model_path) 91 | if found: 92 | break 93 | 94 | if load_best: 95 | model_path = os.path.join(log_path, "best_model.zip") 96 | found = os.path.isfile(model_path) 97 | 98 | if load_checkpoint is not None: 99 | model_path = os.path.join(log_path, f"rl_model_{load_checkpoint}_steps.zip") 100 | found = os.path.isfile(model_path) 101 | 102 | if not found: 103 | raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}") 104 | 105 | if algo in ['dqn', 'ddpg', 'sac', 'td3']: 106 | n_envs = 1 107 | 108 | set_random_seed(seed) 109 | 110 | if num_threads > 0: 111 | if verbose > 1: 112 | print(f"Setting torch.num_threads to {num_threads}") 113 | th.set_num_threads(num_threads) 114 | 115 | is_atari = 'NoFrameskip' in env_id 116 | 117 | stats_path = os.path.join(log_path, env_id) 118 | hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=norm_reward, test_mode=True) 119 | env_kwargs = {} if env_kwargs is None else env_kwargs 120 | 121 | log_dir = reward_log if reward_log != '' else None 122 | 123 | env = create_test_env(env_id, n_envs=n_envs, 124 | stats_path=stats_path, seed=seed, log_dir=log_dir, 125 | should_render=not no_render, 126 | hyperparams=hyperparams, 127 | env_kwargs=env_kwargs) 128 | 129 | model = ALGOS[algo].load(model_path, env=env) 130 | 131 | 132 | database = h5py.File(f"../datasets/rewards_{env_id}.hdf5", 'a') 133 | 134 | SAMPLES = args.n_timesteps 135 | bar = tqdm(total=SAMPLES//2) 136 | train = database.create_group('train') 137 | zeros_inputs = train.create_dataset('zeros-inputs', (SAMPLES//2, 84, 84, 4)) 138 | zeros_labels = train.create_dataset('zeros-labels', (SAMPLES//2, 1)) 139 | ones_inputs = train.create_dataset('ones-inputs', (SAMPLES//2, 84, 84, 4)) 140 | ones_labels = train.create_dataset('ones-labels', (SAMPLES//2, 1)) 141 | i = j = 0 142 | obs = env.reset() 143 | while i < 
SAMPLES//2 or j < SAMPLES//2: 144 | action, _states = model.predict(obs, deterministic=False) 145 | obs, reward, done, info = env.step(action) 146 | reward = reward[0] 147 | if reward and j < SAMPLES//2: 148 | ones_inputs[j] = obs 149 | ones_labels[j] = reward 150 | j += 1 151 | bar.update() 152 | if not reward and i < SAMPLES//2: 153 | zeros_inputs[i] = obs 154 | zeros_labels[i] = reward 155 | i += 1 156 | 157 | 158 | SAMPLES = args.n_timesteps // 2 159 | bar = tqdm(total=SAMPLES//2) 160 | test = database.create_group('test') 161 | zeros_inputs = test.create_dataset('zeros-inputs', (SAMPLES//2, 84, 84, 4)) 162 | zeros_labels = test.create_dataset('zeros-labels', (SAMPLES//2, 1)) 163 | ones_inputs = test.create_dataset('ones-inputs', (SAMPLES//2, 84, 84, 4)) 164 | ones_labels = test.create_dataset('ones-labels', (SAMPLES//2, 1)) 165 | i = j = 0 166 | obs = env.reset() 167 | while i < SAMPLES//2 or j < SAMPLES//2: 168 | action, _states = model.predict(obs, deterministic=False) 169 | obs, reward, done, info = env.step(action) 170 | reward = reward[0] 171 | if reward and j < SAMPLES//2: 172 | ones_inputs[j] = obs 173 | ones_labels[j] = reward 174 | j += 1 175 | bar.update() 176 | if not reward and i < SAMPLES//2: 177 | zeros_inputs[i] = obs 178 | zeros_labels[i] = reward 179 | i += 1 180 | 181 | database.close() 182 | 183 | 184 | -------------------------------------------------------------------------------- /scripts/atari_record_videos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import importlib 4 | from pathlib import Path 5 | from itertools import product 6 | import h5py 7 | import random 8 | 9 | import gym 10 | import numpy as np 11 | import matplotlib 12 | import matplotlib.cm 13 | import matplotlib.pyplot as plt 14 | import scipy.ndimage 15 | import skimage.transform 16 | import torch as th 17 | import torch.nn as nn 18 | 19 | from tqdm.auto import tqdm 20 | 21 | from stable_baselines3.common.utils import set_random_seed 22 | from stable_baselines3.common.atari_wrappers import AtariWrapper 23 | from stable_baselines3.common.vec_env import VecEnvWrapper, VecEnv, DummyVecEnv, VecFrameStack 24 | from stable_baselines3.common.vec_env import VecTransposeImage 25 | 26 | sys.path.insert(1, "../rl-baselines3-zoo") 27 | import utils.import_envs # noqa: F401 pylint: disable=unused-import 28 | from utils.utils import StoreDict 29 | from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams 30 | 31 | import interp 32 | from interp.common.models import AtariRewardModel 33 | from celluloid import Camera 34 | 35 | 36 | ########### Set Device ############ 37 | device = th.device('cuda' if th.cuda.is_available() else 'cpu') 38 | dtype = th.float32 39 | th.set_default_dtype(dtype) 40 | print("Using device: {}".format(device)) 41 | 42 | def get_mask(center, size, r): 43 | y,x = np.ogrid[-center[0]:size[0]-center[0], -center[1]:size[1]-center[1]] 44 | keep = x*x + y*y <= 1 45 | mask = np.zeros(size) ; mask[keep] = 1 # select a circle of pixels 46 | mask = scipy.ndimage.gaussian_filter(mask, sigma=r) # blur the circle of pixels. 
this is a 2D Gaussian for r=r^2=1 47 | return mask/mask.max() 48 | 49 | def occlude(img, mask): 50 | assert img.shape[1:] == (84, 84, 4) 51 | img = np.copy(img) 52 | for k in range(4): 53 | I = img[0, :, :, k] 54 | img[0, :, :, k] = I*(1-mask) + scipy.ndimage.gaussian_filter(I, sigma=3)*mask 55 | return img 56 | 57 | def compute_saliency_map(reward_model, obs, stride=5, radius=5): 58 | baseline = reward_model(obs).detach().cpu().numpy() 59 | scores = np.zeros((84 // stride + 1, 84 // stride + 1)) 60 | for i in range(0, 84, stride): 61 | for j in range(0, 84, stride): 62 | mask = get_mask(center=(i, j), size=(84, 84), r=radius) 63 | obs_perturbed = occlude(obs, mask) 64 | perturbed_reward = reward_model(obs_perturbed).detach().cpu().numpy() 65 | scores[i // stride, j // stride] = 0.5 * np.abs(perturbed_reward - baseline) ** 2 66 | pmax = scores.max() 67 | scores = skimage.transform.resize(scores, output_shape=(210, 160)) 68 | scores = scores.astype(np.float32) 69 | # return pmax * scores / scores.max() 70 | return scores / scores.max() 71 | 72 | def add_saliency_to_frame(frame, saliency, channel=1): 73 | # def saliency_on_atari_frame(saliency, atari, fudge_factor, channel=2, sigma=0): 74 | # sometimes saliency maps are a bit clearer if you blur them 75 | # slightly...sigma adjusts the radius of that blur 76 | pmax = saliency.max() 77 | I = frame.astype('uint16') 78 | I[:, :, channel] += (frame.max() * saliency).astype('uint16') 79 | I = I.clip(1,255).astype('uint8') 80 | return I 81 | 82 | env_id = "BreakoutNoFrameskip-v4" 83 | folder = "../agents" 84 | algo = "ppo" 85 | n_timesteps = 10000 86 | num_threads = -1 87 | n_envs = 1 88 | exp_id = 1 89 | verbose = 1 90 | no_render = False 91 | deterministic = False 92 | load_best = True 93 | load_checkpoint = None 94 | norm_reward = False 95 | seed = 0 96 | reward_log = '' 97 | env_kwargs = None 98 | 99 | # Sanity checks 100 | if exp_id > 0: 101 | log_path = os.path.join(folder, algo, '{}_{}'.format(env_id, exp_id)) 102 | else: 103 | log_path = os.path.join(folder, algo) 104 | 105 | found = False 106 | for ext in ['zip']: 107 | if found: 108 | break 109 | 110 | if load_best: 111 | model_path = os.path.join(log_path, "best_model.zip") 112 | found = os.path.isfile(model_path) 113 | 114 | if load_checkpoint is not None: 115 | model_path = os.path.join(log_path, f"rl_model_{load_checkpoint}_steps.zip") 116 | found = os.path.isfile(model_path) 117 | 118 | if not found: 119 | raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}") 120 | 121 | if algo in ['dqn', 'ddpg', 'sac', 'td3']: 122 | n_envs = 1 123 | 124 | set_random_seed(seed) 125 | 126 | if num_threads > 0: 127 | if verbose > 1: 128 | print(f"Setting torch.num_threads to {num_threads}") 129 | th.set_num_threads(num_threads) 130 | 131 | is_atari = 'NoFrameskip' in env_id 132 | 133 | stats_path = os.path.join(log_path, env_id) 134 | hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=norm_reward, test_mode=True) 135 | env_kwargs = {} if env_kwargs is None else env_kwargs 136 | 137 | log_dir = reward_log if reward_log != '' else None 138 | 139 | env = create_test_env(env_id, n_envs=n_envs, 140 | stats_path=stats_path, seed=seed, log_dir=log_dir, 141 | should_render=not no_render, 142 | hyperparams=hyperparams, 143 | env_kwargs=env_kwargs) 144 | 145 | model = ALGOS[algo].load(model_path, env=env, device=device) 146 | 147 | obs = env.reset() 148 | 149 | rm = AtariRewardModel(env, device) 150 | 
rm.load_state_dict(th.load(f"../reward-models/BreakoutNoFrameskip-v4-reward_model.pt")) 151 | rm = rm.to(device) 152 | 153 | 154 | random.seed(0) 155 | np.random.seed(0) 156 | th.manual_seed(0) 157 | th.backends.cudnn.deterministic = True 158 | th.backends.cudnn.benchmark = False 159 | 160 | breakout_images = [] 161 | obs = env.reset() 162 | for _ in tqdm(range(120)): 163 | action, _states = model.predict(obs, deterministic=False) 164 | obs, reward, done, info = env.step(action) 165 | if done: 166 | obs = env.reset() 167 | sal = compute_saliency_map(rm, obs) 168 | screenshot = env.render(mode='rgb_array') 169 | image = add_saliency_to_frame(screenshot, sal) 170 | breakout_images.append(image) 171 | 172 | 173 | 174 | env_id = "SeaquestNoFrameskip-v4" 175 | folder = "../agents" 176 | algo = "ppo" 177 | n_timesteps = 10000 178 | num_threads = -1 179 | n_envs = 1 180 | exp_id = 1 181 | verbose = 1 182 | no_render = False 183 | deterministic = False 184 | load_best = True 185 | load_checkpoint = None 186 | norm_reward = False 187 | seed = 0 188 | reward_log = '' 189 | env_kwargs = None 190 | 191 | 192 | # Sanity checks 193 | if exp_id > 0: 194 | log_path = os.path.join(folder, algo, '{}_{}'.format(env_id, exp_id)) 195 | else: 196 | log_path = os.path.join(folder, algo) 197 | 198 | found = False 199 | for ext in ['zip']: 200 | model_path = os.path.join(log_path, f'{env_id}.{ext}') 201 | found = os.path.isfile(model_path) 202 | if found: 203 | break 204 | 205 | if load_best: 206 | model_path = os.path.join(log_path, "best_model.zip") 207 | found = os.path.isfile(model_path) 208 | 209 | if load_checkpoint is not None: 210 | model_path = os.path.join(log_path, f"rl_model_{load_checkpoint}_steps.zip") 211 | found = os.path.isfile(model_path) 212 | 213 | if not found: 214 | raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}") 215 | 216 | if algo in ['dqn', 'ddpg', 'sac', 'td3']: 217 | n_envs = 1 218 | 219 | set_random_seed(seed) 220 | 221 | if num_threads > 0: 222 | if verbose > 1: 223 | print(f"Setting torch.num_threads to {num_threads}") 224 | th.set_num_threads(num_threads) 225 | 226 | is_atari = 'NoFrameskip' in env_id 227 | 228 | stats_path = os.path.join(log_path, env_id) 229 | hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=norm_reward, test_mode=True) 230 | env_kwargs = {} if env_kwargs is None else env_kwargs 231 | 232 | log_dir = reward_log if reward_log != '' else None 233 | 234 | env = create_test_env(env_id, n_envs=n_envs, 235 | stats_path=stats_path, seed=seed, log_dir=log_dir, 236 | should_render=not no_render, 237 | hyperparams=hyperparams, 238 | env_kwargs=env_kwargs) 239 | 240 | model = ALGOS[algo].load(model_path, env=env, device=device) 241 | 242 | obs = env.reset() 243 | 244 | rm = AtariRewardModel(env, device) 245 | rm.load_state_dict(th.load(f"../reward-models/SeaquestNoFrameskip-v4-reward_model.pt")) 246 | rm = rm.to(device) 247 | 248 | random.seed(0) 249 | np.random.seed(0) 250 | th.manual_seed(0) 251 | th.backends.cudnn.deterministic = True 252 | th.backends.cudnn.benchmark = False 253 | 254 | seaquest_images = [] 255 | obs = env.reset() 256 | for _ in tqdm(range(122)): 257 | action, _states = model.predict(obs, deterministic=False) 258 | obs, reward, done, info = env.step(action) 259 | if done: 260 | obs = env.reset() 261 | sal = compute_saliency_map(rm, obs, radius=6, stride=6) 262 | screenshot = env.render(mode='rgb_array') 263 | image = add_saliency_to_frame(screenshot, sal) 264 | seaquest_images.append(image) 265 | 266 | 
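# ---------------------------------------------------------------------------
# Editor's note: the Breakout and Seaquest blocks above repeat the same
# agent/reward-model setup verbatim. Below is a minimal sketch of a helper
# the two blocks could share; `load_agent_and_reward_model` is a hypothetical
# name, not defined elsewhere in this repository, and it assumes the
# `../agents` / `../reward-models` layout and exp_id=1 convention used above.
def load_agent_and_reward_model(env_id, folder="../agents", algo="ppo", seed=0):
    """Load a trained rl-zoo agent and its regressed reward model for env_id."""
    log_path = os.path.join(folder, algo, "{}_{}".format(env_id, 1))
    model_path = os.path.join(log_path, "best_model.zip")
    if not os.path.isfile(model_path):
        raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}")
    stats_path = os.path.join(log_path, env_id)
    hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=False,
                                                    test_mode=True)
    env = create_test_env(env_id, n_envs=1, stats_path=stats_path, seed=seed,
                          log_dir=None, should_render=True,
                          hyperparams=hyperparams, env_kwargs={})
    model = ALGOS[algo].load(model_path, env=env, device=device)
    rm = AtariRewardModel(env, device)
    rm.load_state_dict(th.load(f"../reward-models/{env_id}-reward_model.pt"))
    return env, model, rm.to(device)
# Usage (equivalent to the two setup blocks above):
#   env, model, rm = load_agent_and_reward_model("SeaquestNoFrameskip-v4")
# ---------------------------------------------------------------------------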
267 | fig, ax = plt.subplots(1, 1, figsize=(5, 3.09375)) 268 | camera = Camera(fig) 269 | ax.axis('off') 270 | for img in breakout_images: 271 | ax.imshow(img) 272 | camera.snap() 273 | 274 | animation = camera.animate(interval=80) 275 | animation.save('../videos/breakout.mp4') 276 | 277 | fig, ax = plt.subplots(1, 1, figsize=(5, 3.09375)) 278 | camera = Camera(fig) 279 | ax.axis('off') 280 | for img in seaquest_images[20:]: 281 | ax.imshow(img) 282 | camera.snap() 283 | 284 | animation = camera.animate(interval=80) 285 | animation.save('../videos/seaquest.mp4') 286 | 287 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(5, 3.09375)) 288 | camera = Camera(fig) 289 | ax1.axis('off') 290 | ax2.axis('off') 291 | for i in range(100): 292 | ax1.imshow(breakout_images[i]) 293 | ax2.imshow(seaquest_images[i+20]) 294 | camera.snap() 295 | 296 | animation = camera.animate(interval=80) 297 | animation.save('../videos/breakout-and-seaquest.mp4') 298 | 299 | 300 | 301 | 302 | -------------------------------------------------------------------------------- /scripts/maze_figures.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import os 5 | import sys 6 | import importlib 7 | from pathlib import Path 8 | from itertools import product 9 | import h5py 10 | import random 11 | 12 | import numpy as np 13 | import matplotlib 14 | import matplotlib.pyplot as plt 15 | import scipy.ndimage 16 | import skimage.transform 17 | 18 | import gym 19 | import mazelab 20 | import torch as th 21 | import torch.nn as nn 22 | 23 | from tqdm.auto import tqdm 24 | import imageio 25 | from IPython.display import Image 26 | 27 | from stable_baselines3 import A2C 28 | sys.path.insert(1, "../rl-baselines3-zoo") 29 | import utils.import_envs # noqa: F401 pylint: disable=unused-import 30 | from utils.utils import StoreDict 31 | from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams 32 | 33 | from captum.attr import ( 34 | Saliency, 35 | IntegratedGradients, 36 | FeatureAblation, 37 | FeaturePermutation, 38 | Occlusion, 39 | ShapleyValueSampling 40 | ) 41 | 42 | from interp.common.models import MazeRewardModel 43 | 44 | 45 | if __name__ == '__main__': 46 | 47 | fig_path = Path('../figures/') 48 | if not fig_path.exists(): 49 | fig_path.mkdir() 50 | 51 | ########### Set Device ############ 52 | # device = th.device('cuda' if th.cuda.is_available() else 'cpu') 53 | device = 'cpu' 54 | dtype = th.float32 55 | th.set_default_dtype(dtype) 56 | print(f"Using {device} device") 57 | 58 | 59 | env_id = 'EmptyMaze-10x10-CoinFlipGoal-v3' 60 | print(env_id) 61 | folder = "../agents" 62 | algo = "ppo" 63 | n_timesteps = 10000 64 | num_threads = -1 65 | n_envs = 1 66 | exp_id = 1 67 | verbose = 1 68 | no_render = False 69 | deterministic = False 70 | load_best = True 71 | load_checkpoint = None 72 | norm_reward = False 73 | seed = 0 74 | reward_log = '' 75 | env_kwargs = None 76 | 77 | 78 | # In[7]: 79 | 80 | 81 | # Sanity checks 82 | if exp_id > 0: 83 | log_path = os.path.join(folder, algo, '{}_{}'.format(env_id, exp_id)) 84 | else: 85 | log_path = os.path.join(folder, algo) 86 | 87 | found = False 88 | for ext in ['zip']: 89 | model_path = os.path.join(log_path, f'{env_id}.{ext}') 90 | found = os.path.isfile(model_path) 91 | if found: 92 | break 93 | 94 | if load_best: 95 | model_path = os.path.join(log_path, "best_model.zip") 96 | found = os.path.isfile(model_path) 97 | 98 | if load_checkpoint is not None: 99 | model_path = 
os.path.join(log_path, f"rl_model_{load_checkpoint}_steps.zip") 100 | found = os.path.isfile(model_path) 101 | 102 | if not found: 103 | raise ValueError(f"No model found for {algo} on {env_id}, path: {model_path}") 104 | 105 | if algo in ['dqn', 'ddpg', 'sac', 'td3']: 106 | n_envs = 1 107 | 108 | 109 | if num_threads > 0: 110 | if verbose > 1: 111 | print(f"Setting torch.num_threads to {num_threads}") 112 | th.set_num_threads(num_threads) 113 | 114 | is_atari = 'NoFrameskip' in env_id 115 | 116 | stats_path = os.path.join(log_path, env_id) 117 | hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=norm_reward, test_mode=True) 118 | env_kwargs = {} if env_kwargs is None else env_kwargs 119 | 120 | log_dir = reward_log if reward_log != '' else None 121 | 122 | # env = create_test_env(env_id, n_envs=n_envs, 123 | # stats_path=stats_path, seed=seed, log_dir=log_dir, 124 | # should_render=not no_render, 125 | # hyperparams=hyperparams, 126 | # env_kwargs=env_kwargs) 127 | 128 | env = gym.make(env_id) 129 | model = ALGOS[algo].load(model_path, env=env, device=device) 130 | 131 | obs = env.reset() 132 | 133 | 134 | rm = MazeRewardModel(env, 'cpu').to('cpu') 135 | rm.load_state_dict(th.load(f"../reward-models/{env_id}-reward_model.pt", map_location='cpu')) 136 | 137 | 138 | # In[10]: 139 | 140 | 141 | plt.rcParams.update({ 142 | # "text.usetex": True, 143 | # "font.family": "serif", 144 | "font.serif": ["Times"], 145 | }) 146 | 147 | 148 | # In[11]: 149 | 150 | 151 | random.seed(0) 152 | np.random.seed(0) 153 | th.manual_seed(0) 154 | th.backends.cudnn.deterministic = True 155 | th.backends.cudnn.benchmark = False 156 | 157 | plt.figure(figsize=(5.5, 2.5)) 158 | 159 | sal = Saliency(rm.tforward) 160 | w = 6 161 | i = 1 162 | 163 | while i <= w: 164 | obs = env.reset() 165 | action, _states = model.predict(obs, deterministic=True) 166 | next_obs, reward, done, info = env.step(action) 167 | if i == w: 168 | obs = env.reset() 169 | goal_pos = env.maze.objects.goal.positions[0] 170 | if goal_pos == [1, 1]: 171 | agent_pos = [2, 1] 172 | else: 173 | agent_pos = [8, 9] 174 | env.maze.objects.agent.positions[0] = agent_pos 175 | obs = env.maze.to_value() 176 | env.maze.objects.agent.positions[0] = goal_pos 177 | next_obs = env.maze.to_value() 178 | 179 | top_left = (env.maze.objects.goal.positions[0] == [1, 1]) 180 | if (i in [1, 2, 3] and top_left) or (i in [4, 5, 6] and not top_left): 181 | continue 182 | 183 | ax = plt.subplot(3, w, i) 184 | screenshot = env.maze.to_rgb() 185 | ax.imshow(screenshot) 186 | ax.set_xticks([]) 187 | ax.set_yticks([]) 188 | ax.set_yticklabels([]) 189 | ax.set_xticklabels([]) 190 | # plt.axis('off') 191 | if i == 1: 192 | ax.set_ylabel(r"$s'$", rotation=0, fontsize=10, fontfamily="Times New Roman") 193 | ax.yaxis.set_label_coords(-0.22, 0.4) 194 | 195 | input = np.array((obs, next_obs)).astype(np.float32) 196 | input = th.tensor(np.expand_dims(input, axis=0)).to('cpu').to(dtype) 197 | input.requires_grad = True 198 | attributions = sal.attribute(input) 199 | attributions = np.abs(attributions.detach()[0, ...]) 200 | 201 | ax = plt.subplot(3, w, w+i) 202 | ax.set_xticks([]) 203 | ax.set_yticks([]) 204 | ax.set_yticklabels([]) 205 | ax.set_xticklabels([]) 206 | ax.imshow(attributions[0, ...], cmap='gray', vmin=attributions.min(), vmax=attributions.max()) 207 | if i == 1: 208 | ax.set_ylabel(r"$\left\vert \frac{dR}{ds} \right\vert$", rotation=0, fontsize=10, fontfamily="Times New Roman") 209 | ax.yaxis.set_label_coords(-0.28, 0.3) 210 | 211 | 212 | ax = 
plt.subplot(3, w, 2*w + i) 213 | ax.set_xticks([]) 214 | ax.set_yticks([]) 215 | ax.set_yticklabels([]) 216 | ax.set_xticklabels([]) 217 | ax.imshow(attributions[1, ...], cmap='gray', vmin=attributions.min(), vmax=attributions.max()) 218 | if i == 1: 219 | ax.set_ylabel(r"$\left\vert \frac{dR}{ds'} \right\vert$", rotation=0, fontsize=10, fontfamily="Times New Roman") 220 | ax.yaxis.set_label_coords(-0.28, 0.26) 221 | 222 | i += 1 223 | 224 | plt.subplots_adjust(left=0.07, right=0.93, bottom=0.05, top=0.95) 225 | 226 | plt.savefig(f"{fig_path}/CoinFlipsaliencymapssixstates.pdf", dpi=350) 227 | plt.savefig(f"{fig_path}/CoinFlipsaliencymapssixstates.png", dpi=350) 228 | 229 | 230 | 231 | # In[12]: 232 | 233 | 234 | random.seed(0) 235 | np.random.seed(0) 236 | th.manual_seed(0) 237 | th.backends.cudnn.deterministic = True 238 | th.backends.cudnn.benchmark = False 239 | 240 | plt.figure(figsize=(5.5, 1.7)) 241 | 242 | # No Goal Object: 243 | env.reset() 244 | obs = env.maze.to_value() 245 | goal_pos = tuple(map(int, np.where(obs==3))) 246 | empty_pos = (np.where(obs == 0)[0][0], np.where(obs == 0)[1][0]) 247 | s = env.maze.to_rgb() 248 | obs[goal_pos] = obs[empty_pos] 249 | s[goal_pos] = s[empty_pos] 250 | 251 | action, _states = model.predict(obs, deterministic=True) 252 | next_obs, reward, done, info = env.step(action) 253 | goal_pos = tuple(map(int, np.where(next_obs==3))) 254 | empty_pos = (np.where(next_obs == 0)[0][0], np.where(next_obs == 0)[1][0]) 255 | sp = env.maze.to_rgb() 256 | next_obs[goal_pos] = next_obs[empty_pos] 257 | sp[goal_pos] = sp[empty_pos] 258 | 259 | ax = plt.subplot(1, 6, 1) 260 | ax.imshow(s) 261 | ax.set_xticks([]) 262 | ax.set_yticks([]) 263 | ax.set_yticklabels([]) 264 | ax.set_xticklabels([]) 265 | ax.set_title(r"$s$", fontsize=10, fontfamily="Times New Roman") 266 | 267 | ax = plt.subplot(1, 6, 2) 268 | ax.imshow(sp) 269 | ax.set_xticks([]) 270 | ax.set_yticks([]) 271 | ax.set_yticklabels([]) 272 | ax.set_xticklabels([]) 273 | ax.set_title(r"$s'$", fontsize=10, fontfamily="Times New Roman") 274 | 275 | input = np.array((obs, next_obs)).astype(np.float32) 276 | input = th.tensor(np.expand_dims(input, axis=0)).to('cpu').to(dtype) 277 | ax.set_xlabel("$R(s, s')$ = {:.3f}".format(rm.tforward(input).item()), fontsize=10, fontfamily="Times New Roman") 278 | ax.xaxis.set_label_coords(-0.1, -0.05) 279 | ax.text(-3.2, 16.7, '(a)', fontsize=10, weight="bold", fontfamily="Times New Roman") 280 | 281 | 282 | # Two Goal Objects: 283 | env.reset() 284 | env.maze.objects.goal.positions = [[1, 1], [9, 9]] 285 | env.maze.objects.agent.positions = [[2, 1]] 286 | obs = env.maze.to_value() 287 | s = env.maze.to_rgb() 288 | 289 | action, _states = model.predict(obs, deterministic=True) 290 | next_obs, reward, done, info = env.step(action) 291 | sp = env.maze.to_rgb() 292 | 293 | 294 | ax = plt.subplot(1, 6, 3) 295 | ax.imshow(s) 296 | ax.set_xticks([]) 297 | ax.set_yticks([]) 298 | ax.set_yticklabels([]) 299 | ax.set_xticklabels([]) 300 | ax.set_title(r"$s$", fontsize=10, fontfamily="Times New Roman") 301 | 302 | ax = plt.subplot(1, 6, 4) 303 | ax.imshow(sp) 304 | ax.set_xticks([]) 305 | ax.set_yticks([]) 306 | ax.set_yticklabels([]) 307 | ax.set_xticklabels([]) 308 | ax.set_title(r"$s'$", fontsize=10, fontfamily="Times New Roman") 309 | 310 | input = np.array((obs, next_obs)).astype(np.float32) 311 | input = th.tensor(np.expand_dims(input, axis=0)).to('cpu').to(dtype) 312 | ax.set_xlabel("$R(s, s')$ = {:.3f}".format(rm.tforward(input).item()), fontsize=10, fontfamily="Times New 
Roman") 313 | ax.xaxis.set_label_coords(-0.1, -0.05) 314 | ax.text(-3.2, 16.7, '(b)', fontsize=10, weight="bold", fontfamily="Times New Roman") 315 | # ax.plot([12, 12], [-3, 13], color='black', lw=1) 316 | 317 | 318 | # Spamming goal objectives: 319 | env.reset() 320 | agent_pos = env.maze.objects.agent.positions 321 | for i in range(0, 9*9, 1): 322 | if i % 2 == 0: 323 | x, y = i // 9, i % 9 324 | x, y = x+1, y+1 325 | if agent_pos != [x, y]: 326 | env.maze.objects.goal.positions.append([x, y]) 327 | obs = env.maze.to_value() 328 | s = env.maze.to_rgb() 329 | 330 | action, _states = model.predict(obs, deterministic=True) 331 | next_obs, reward, done, info = env.step(action) 332 | sp = env.maze.to_rgb() 333 | 334 | 335 | ax = plt.subplot(1, 6, 5) 336 | ax.imshow(s) 337 | ax.set_xticks([]) 338 | ax.set_yticks([]) 339 | ax.set_yticklabels([]) 340 | ax.set_xticklabels([]) 341 | ax.set_title(r"$s$", fontsize=10, fontfamily="Times New Roman") 342 | 343 | ax = plt.subplot(1, 6, 6) 344 | ax.imshow(sp) 345 | ax.set_xticks([]) 346 | ax.set_yticks([]) 347 | ax.set_yticklabels([]) 348 | ax.set_xticklabels([]) 349 | ax.set_title(r"$s'$", fontsize=10, fontfamily="Times New Roman") 350 | 351 | input = np.array((obs, next_obs)).astype(np.float32) 352 | input = th.tensor(np.expand_dims(input, axis=0)).to('cpu').to(dtype) 353 | ax.set_xlabel("$R(s, s')$ = {:.3f}".format(rm.tforward(input).item()), fontsize=10, fontfamily="Times New Roman") 354 | ax.xaxis.set_label_coords(-0.1, -0.05) 355 | 356 | ax.text(-3.2, 16.7, '(c)', fontsize=10, weight="bold", fontfamily="Times New Roman") 357 | 358 | plt.subplots_adjust(left=0.05, right=0.95) 359 | 360 | plt.savefig(f"{fig_path}/CoinFlipcounterfactuals.pdf", dpi=350) 361 | plt.savefig(f"{fig_path}/CoinFlipcounterfactuals.png", dpi=350) 362 | 363 | 364 | 365 | 366 | 367 | custom_trained_agents_path = Path('../agents-custom') 368 | if custom_trained_agents_path.exists(): 369 | evaluations = { 370 | 'CoinFlipGoal': { 371 | 'true': Path(custom_trained_agents_path, 'ppo/EmptyMaze-10x10-CoinFlipGoal-v3_1/evaluations.npz'), 372 | 'rm': Path(custom_trained_agents_path, 'ppo/EmptyMaze-10x10-CoinFlipGoal-v3_2/evaluations.npz') 373 | }, 374 | 'TwoGoals': { 375 | 'true': Path(custom_trained_agents_path, 'ppo/EmptyMaze-10x10-TwoGoals-v3_1/evaluations.npz'), 376 | 'rm': Path(custom_trained_agents_path, 'ppo/EmptyMaze-10x10-TwoGoals-v3_2/evaluations.npz') 377 | } 378 | } 379 | 380 | plt.rcParams.update({ 381 | # "text.usetex": True, 382 | # "font.family": "serif", 383 | "font.serif": ["Times"], 384 | }) 385 | 386 | plt.figure(figsize=(5.5, 2.7)) 387 | 388 | ax = plt.subplot(2, 2, 1) 389 | with np.load(evaluations['CoinFlipGoal']['true']) as data: 390 | timesteps = data['timesteps'] 391 | results = data['results'] 392 | means = np.mean(results, axis=1).flatten() 393 | stds = np.std(results, axis=1).flatten() 394 | ax.plot(timesteps, means, color="blue") 395 | ax.fill_between(timesteps, means-stds, means+stds, alpha=0.1, color="blue") 396 | # plt.ylim(-5, 40) 397 | ax.set_xlabel("Timesteps", fontsize=9, fontfamily='Times New Roman') 398 | ax.set_ylabel("Episode Return", fontsize=9, fontfamily='Times New Roman') 399 | ax.set_ylim(0, 1.1) 400 | plt.xticks(fontsize=7, fontfamily='Times New Roman') 401 | plt.yticks(fontsize=7, fontfamily='Times New Roman') 402 | plt.yticks([0, 0.5, 1], fontsize=7, fontfamily='Times New Roman') 403 | ax.xaxis.labelpad=1 404 | # plt.legend(loc='upper left', prop={'size': 6}) 405 | ax.set_title("CoinFlipGoal w/ True Reward", fontsize=9, fontfamily='Times 
New Roman') 406 | ax.text(0.5, -0.75, "(a)", size=10, ha="center", weight="bold", fontfamily='Times New Roman', 407 | transform=ax.transAxes) 408 | 409 | ax = plt.subplot(2, 2, 3) 410 | with np.load(evaluations['CoinFlipGoal']['rm']) as data: 411 | timesteps = data['timesteps'] 412 | results = data['results'] 413 | means = np.mean(results, axis=1).flatten() 414 | stds = np.std(results, axis=1).flatten() 415 | ax.plot(timesteps, means, color="blue") 416 | ax.fill_between(timesteps, means-stds, means+stds, alpha=0.1, color="blue") 417 | # plt.ylim(-5, 40) 418 | ax.set_xlabel("Timesteps", fontsize=9, fontfamily='Times New Roman') 419 | ax.set_ylabel("Episode Return", fontsize=9, fontfamily='Times New Roman') 420 | ax.set_ylim(0, 1.1) 421 | plt.xticks(fontsize=7, fontfamily='Times New Roman') 422 | plt.yticks([0, 0.5, 1], fontsize=7, fontfamily='Times New Roman') 423 | ax.xaxis.labelpad=1 424 | # plt.legend(loc='upper left', prop={'size': 6}) 425 | ax.set_title("CoinFlipGoal w/ Regressed Reward", fontsize=9, fontfamily='Times New Roman') 426 | ax.text(0.5, -0.75, "(c)", size=10, ha="center", weight="bold", fontfamily='Times New Roman', 427 | transform=ax.transAxes) 428 | 429 | ax = plt.subplot(2, 2, 2) 430 | with np.load(evaluations['TwoGoals']['true']) as data: 431 | timesteps = data['timesteps'] 432 | results = data['results'] 433 | means = np.mean(results, axis=1).flatten() 434 | stds = np.std(results, axis=1).flatten() 435 | ax.plot(timesteps, means, color="blue") 436 | ax.fill_between(timesteps, means-stds, means+stds, alpha=0.1, color="blue") 437 | # plt.ylim(-5, 40) 438 | ax.set_xlabel("Timesteps", fontsize=9, fontfamily='Times New Roman') 439 | ax.set_ylabel("Episode Return", fontsize=9, fontfamily='Times New Roman') 440 | ax.set_ylim(0, 1.1) 441 | plt.xticks(fontsize=7, fontfamily='Times New Roman') 442 | plt.yticks([0, 0.5, 1], fontsize=7, fontfamily='Times New Roman') 443 | ax.xaxis.labelpad=1 444 | # plt.legend(loc='upper left', prop={'size': 6}) 445 | ax.set_title("TwoGoals w/ True Reward", fontsize=9, fontfamily='Times New Roman') 446 | ax.text(0.5, -0.75, "(b)", size=10, ha="center", weight="bold", fontfamily='Times New Roman', 447 | transform=ax.transAxes) 448 | 449 | ax = plt.subplot(2, 2, 4) 450 | with np.load(evaluations['TwoGoals']['rm']) as data: 451 | timesteps = data['timesteps'] 452 | results = data['results'] 453 | means = np.mean(results, axis=1).flatten() 454 | stds = np.std(results, axis=1).flatten() 455 | ax.plot(timesteps, means, color="blue") 456 | ax.fill_between(timesteps, means-stds, means+stds, alpha=0.1, color="blue") 457 | # plt.ylim(-5, 40) 458 | ax.set_xlabel("Timesteps", fontsize=9, fontfamily='Times New Roman') 459 | ax.set_ylabel("Episode Return", fontsize=9, fontfamily='Times New Roman') 460 | ax.set_ylim(0, 1.1) 461 | plt.xticks(fontsize=7, fontfamily='Times New Roman') 462 | plt.yticks([0, 0.5, 1], fontsize=7, fontfamily='Times New Roman') 463 | ax.xaxis.labelpad=1 464 | # plt.legend(loc='upper left', prop={'size': 6}) 465 | ax.set_title("TwoGoals w/ CoinFlipGoal's Regressed Reward", fontsize=9, fontfamily='Times New Roman') 466 | ax.text(0.5, -0.75, "(d)", size=10, ha="center", weight="bold", fontfamily='Times New Roman', 467 | transform=ax.transAxes) 468 | 469 | 470 | plt.subplots_adjust(hspace=1.2, wspace=0.4, bottom=0.2, top=0.92) 471 | 472 | plt.savefig(f"{fig_path}/maze-training-curves.pdf", dpi=100) 473 | plt.savefig(f"{fig_path}/maze-training-curves.png", dpi=350) 474 |
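# ---------------------------------------------------------------------------
# Editor's note: the four training-curve panels above repeat the same
# load-and-plot logic. A minimal sketch of a helper they could share;
# `plot_evaluations` is a hypothetical name, not defined elsewhere in this
# repository. It assumes the stable-baselines3 EvalCallback `evaluations.npz`
# layout used above ('timesteps', plus 'results' of shape [n_evals, n_episodes]).
def plot_evaluations(ax, npz_path, title):
    """Plot mean episode return with a +/- 1 std band from an evaluations.npz."""
    with np.load(npz_path) as data:
        timesteps = data['timesteps']
        results = data['results']
    means = np.mean(results, axis=1).flatten()
    stds = np.std(results, axis=1).flatten()
    ax.plot(timesteps, means, color="blue")
    ax.fill_between(timesteps, means - stds, means + stds, alpha=0.1, color="blue")
    ax.set_xlabel("Timesteps", fontsize=9, fontfamily='Times New Roman')
    ax.set_ylabel("Episode Return", fontsize=9, fontfamily='Times New Roman')
    ax.set_ylim(0, 1.1)
    ax.set_title(title, fontsize=9, fontfamily='Times New Roman')
# Usage (equivalent to panel (a) above):
#   plot_evaluations(plt.subplot(2, 2, 1), evaluations['CoinFlipGoal']['true'],
#                    "CoinFlipGoal w/ True Reward")
# ---------------------------------------------------------------------------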
-------------------------------------------------------------------------------- /scripts/maze_train_agent.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import difflib 3 | import os 4 | import sys 5 | import importlib 6 | import time 7 | import uuid 8 | import warnings 9 | from collections import OrderedDict 10 | from pprint import pprint 11 | 12 | import yaml 13 | import gym 14 | import seaborn 15 | import numpy as np 16 | import torch as th 17 | # For custom activation fn 18 | import torch.nn as nn # noqa: F401 pytype: disable=unused-import 19 | 20 | from stable_baselines3.common.utils import set_random_seed 21 | # from stable_baselines3.common.cmd_util import make_atari_env 22 | from stable_baselines3.common.vec_env import VecFrameStack, VecNormalize, DummyVecEnv, VecTransposeImage 23 | from stable_baselines3.common.preprocessing import is_image_space 24 | from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise 25 | from stable_baselines3.common.utils import constant_fn 26 | from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback 27 | 28 | sys.path.append('../rl-baselines3-zoo/') 29 | # Register custom envs 30 | import utils.import_envs # noqa: F401 pytype: disable=import-error 31 | from utils import make_env, ALGOS, linear_schedule, get_latest_run_id, get_wrapper_class 32 | from utils.hyperparams_opt import hyperparam_optimization 33 | from utils.callbacks import SaveVecNormalizeCallback 34 | from utils.noise import LinearNormalActionNoise 35 | from utils.utils import StoreDict, get_callback_class 36 | 37 | seaborn.set() 38 | 39 | if __name__ == '__main__': # noqa: C901 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('--algo', help='RL Algorithm', default='ppo', 42 | type=str, required=False, choices=list(ALGOS.keys())) 43 | parser.add_argument('--env', type=str, default="CartPole-v1", help='environment ID') 44 | parser.add_argument('-tb', '--tensorboard-log', help='Tensorboard log dir', default='', type=str) 45 | parser.add_argument('-i', '--trained-agent', help='Path to a pretrained agent to continue training', 46 | default='', type=str) 47 | parser.add_argument('-n', '--n-timesteps', help='Overwrite the number of timesteps', default=-1, 48 | type=int) 49 | parser.add_argument('--num-threads', help='Number of threads for PyTorch (-1 to use default)', default=-1, 50 | type=int) 51 | parser.add_argument('--log-interval', help='Override log interval (default: -1, no change)', default=-1, 52 | type=int) 53 | parser.add_argument('--eval-freq', help='Evaluate the agent every n steps (if negative, no evaluation)', 54 | default=10000, type=int) 55 | parser.add_argument('--eval-episodes', help='Number of episodes to use for evaluation', 56 | default=5, type=int) 57 | parser.add_argument('--save-freq', help='Save the model every n steps (if negative, no checkpoint)', 58 | default=-1, type=int) 59 | parser.add_argument('--save-replay-buffer', help='Save the replay buffer too (when applicable)', 60 | action='store_true', default=False) 61 | parser.add_argument('-f', '--log-folder', help='Log folder', type=str, default='logs') 62 | parser.add_argument('--seed', help='Random generator seed', type=int, default=-1) 63 | parser.add_argument('--n-trials', help='Number of trials for optimizing hyperparameters', type=int, default=10) 64 | parser.add_argument('-optimize', '--optimize-hyperparameters', action='store_true', default=False, 65 | help='Run hyperparameters 
search') 66 | parser.add_argument('--n-jobs', help='Number of parallel jobs when optimizing hyperparameters', type=int, default=1) 67 | parser.add_argument('--sampler', help='Sampler to use when optimizing hyperparameters', type=str, 68 | default='tpe', choices=['random', 'tpe', 'skopt']) 69 | parser.add_argument('--pruner', help='Pruner to use when optimizing hyperparameters', type=str, 70 | default='median', choices=['halving', 'median', 'none']) 71 | parser.add_argument('--n-startup-trials', help='Number of trials before using optuna sampler', 72 | type=int, default=10) 73 | parser.add_argument('--n-evaluations', help='Number of evaluations for hyperparameter optimization', 74 | type=int, default=20) 75 | parser.add_argument('--storage', help='Database storage path if distributed optimization should be used', type=str, 76 | default=None) 77 | parser.add_argument('--study-name', help='Study name for distributed optimization', type=str, 78 | default=None) 79 | parser.add_argument('--verbose', help='Verbose mode (0: no output, 1: INFO)', default=1, 80 | type=int) 81 | parser.add_argument('--gym-packages', type=str, nargs='+', default=[], 82 | help='Additional external Gym environment package modules to import (e.g. gym_minigrid)') 83 | parser.add_argument('--env-kwargs', type=str, nargs='+', action=StoreDict, 84 | help='Optional keyword argument to pass to the env constructor') 85 | parser.add_argument('-params', '--hyperparams', type=str, nargs='+', action=StoreDict, 86 | help='Overwrite hyperparameter (e.g. learning_rate:0.01 train_freq:10)') 87 | parser.add_argument('-uuid', '--uuid', action='store_true', default=False, 88 | help='Ensure that the run has a unique ID') 89 | args = parser.parse_args() 90 | 91 | # Going through custom gym packages to let them register in the global registory 92 | for env_module in args.gym_packages: 93 | importlib.import_module(env_module) 94 | 95 | env_id = args.env 96 | registered_envs = set(gym.envs.registry.env_specs.keys()) # pytype: disable=module-attr 97 | 98 | # If the environment is not found, suggest the closest match 99 | if env_id not in registered_envs: 100 | try: 101 | closest_match = difflib.get_close_matches(env_id, registered_envs, n=1)[0] 102 | except IndexError: 103 | closest_match = "'no close match found...'" 104 | raise ValueError(f'{env_id} not found in gym registry, you maybe meant {closest_match}?') 105 | 106 | # Unique id to ensure there is no race condition for the folder creation 107 | uuid_str = f'_{uuid.uuid4()}' if args.uuid else '' 108 | if args.seed < 0: 109 | # Seed but with a random one 110 | args.seed = np.random.randint(2**32 - 1, dtype='int64').item() 111 | 112 | set_random_seed(args.seed) 113 | 114 | # Setting num threads to 1 makes things run faster on cpu 115 | if args.num_threads > 0: 116 | if args.verbose > 1: 117 | print(f"Setting torch.num_threads to {args.num_threads}") 118 | th.set_num_threads(args.num_threads) 119 | 120 | if args.trained_agent != "": 121 | assert args.trained_agent.endswith('.zip') and os.path.isfile(args.trained_agent), \ 122 | "The trained_agent must be a valid path to a .zip file" 123 | 124 | tensorboard_log = None if args.tensorboard_log == '' else os.path.join(args.tensorboard_log, env_id) 125 | 126 | is_atari = False 127 | if 'NoFrameskip' in env_id: 128 | is_atari = True 129 | 130 | is_maze = False 131 | if 'Maze' in env_id: 132 | is_maze=True 133 | 134 | print("=" * 10, env_id, "=" * 10) 135 | print(f"Seed: {args.seed}") 136 | 137 | # Load hyperparameters from yaml file 138 | with 
open(f'hyperparams/{args.algo}.yml', 'r') as f: 139 | hyperparams_dict = yaml.safe_load(f) 140 | if env_id in list(hyperparams_dict.keys()): 141 | hyperparams = hyperparams_dict[env_id] 142 | elif is_atari: 143 | hyperparams = hyperparams_dict['atari'] 144 | elif is_maze: 145 | hyperparams = hyperparams_dict['maze'] 146 | else: 147 | raise ValueError(f"Hyperparameters not found for {args.algo}-{env_id}") 148 | 149 | if args.hyperparams is not None: 150 | # Overwrite hyperparams if needed 151 | hyperparams.update(args.hyperparams) 152 | # Sort hyperparams that will be saved 153 | saved_hyperparams = OrderedDict([(key, hyperparams[key]) for key in sorted(hyperparams.keys())]) 154 | env_kwargs = {} if args.env_kwargs is None else args.env_kwargs 155 | 156 | algo_ = args.algo 157 | # HER is only a wrapper around an algo 158 | if args.algo == 'her': 159 | algo_ = saved_hyperparams['model_class'] 160 | assert algo_ in {'sac', 'ddpg', 'dqn', 'td3'}, "{} is not compatible with HER".format(algo_) 161 | # Retrieve the model class 162 | hyperparams['model_class'] = ALGOS[saved_hyperparams['model_class']] 163 | 164 | if args.verbose > 0: 165 | pprint(saved_hyperparams) 166 | 167 | n_envs = hyperparams.get('n_envs', 1) 168 | 169 | if args.verbose > 0: 170 | print(f"Using {n_envs} environments") 171 | 172 | # Create schedules 173 | for key in ['learning_rate', 'clip_range', 'clip_range_vf']: 174 | if key not in hyperparams: 175 | continue 176 | if isinstance(hyperparams[key], str): 177 | schedule, initial_value = hyperparams[key].split('_') 178 | initial_value = float(initial_value) 179 | hyperparams[key] = linear_schedule(initial_value) 180 | elif isinstance(hyperparams[key], (float, int)): 181 | # Negative value: ignore (ex: for clipping) 182 | if hyperparams[key] < 0: 183 | continue 184 | hyperparams[key] = constant_fn(float(hyperparams[key])) 185 | else: 186 | raise ValueError(f'Invalid value for {key}: {hyperparams[key]}') 187 | 188 | # Should we overwrite the number of timesteps? 
189 | if args.n_timesteps > 0: 190 | if args.verbose: 191 | print(f"Overwriting n_timesteps with n={args.n_timesteps}") 192 | n_timesteps = args.n_timesteps 193 | else: 194 | n_timesteps = int(hyperparams['n_timesteps']) 195 | 196 | normalize = False 197 | normalize_kwargs = {} 198 | if 'normalize' in hyperparams.keys(): 199 | normalize = hyperparams['normalize'] 200 | if isinstance(normalize, str): 201 | print('normalize hyperparam is apparently a string') 202 | normalize_kwargs = eval(normalize) 203 | normalize = True 204 | del hyperparams['normalize'] 205 | 206 | if 'policy_kwargs' in hyperparams.keys(): 207 | # Convert to python object if needed 208 | if isinstance(hyperparams['policy_kwargs'], str): 209 | hyperparams['policy_kwargs'] = eval(hyperparams['policy_kwargs']) 210 | 211 | # Delete keys so the dict can be pass to the model constructor 212 | if 'n_envs' in hyperparams.keys(): 213 | del hyperparams['n_envs'] 214 | del hyperparams['n_timesteps'] 215 | 216 | # obtain a class object from a wrapper name string in hyperparams 217 | # and delete the entry 218 | env_wrapper = get_wrapper_class(hyperparams) 219 | if 'env_wrapper' in hyperparams.keys(): 220 | del hyperparams['env_wrapper'] 221 | 222 | log_path = f"{args.log_folder}/{args.algo}/" 223 | save_path = os.path.join(log_path, f"{env_id}_{get_latest_run_id(log_path, env_id) + 1}{uuid_str}") 224 | params_path = f"{save_path}/{env_id}" 225 | os.makedirs(params_path, exist_ok=True) 226 | 227 | callbacks = get_callback_class(hyperparams) 228 | if 'callback' in hyperparams.keys(): 229 | del hyperparams['callback'] 230 | 231 | if args.save_freq > 0: 232 | # Account for the number of parallel environments 233 | args.save_freq = max(args.save_freq // n_envs, 1) 234 | callbacks.append(CheckpointCallback(save_freq=args.save_freq, 235 | save_path=save_path, name_prefix='rl_model', verbose=1)) 236 | 237 | def create_env(n_envs, eval_env=False, no_log=False): 238 | """ 239 | Create the environment and wrap it if necessary 240 | :param n_envs: (int) 241 | :param eval_env: (bool) Whether is it an environment used for evaluation or not 242 | :param no_log: (bool) Do not log training when doing hyperparameter optim 243 | (issue with writing the same file) 244 | :return: (Union[gym.Env, VecEnv]) 245 | """ 246 | global hyperparams 247 | global env_kwargs 248 | 249 | # Do not log eval env (issue with writing the same file) 250 | log_dir = None if eval_env or no_log else save_path 251 | 252 | if n_envs == 1: 253 | env = DummyVecEnv([make_env(env_id, 0, args.seed, 254 | wrapper_class=env_wrapper, log_dir=log_dir, 255 | env_kwargs=env_kwargs)]) 256 | else: 257 | # env = SubprocVecEnv([make_env(env_id, i, args.seed) for i in range(n_envs)]) 258 | # On most env, SubprocVecEnv does not help and is quite memory hungry 259 | env = DummyVecEnv([make_env(env_id, i, args.seed, log_dir=log_dir, env_kwargs=env_kwargs, 260 | wrapper_class=env_wrapper) for i in range(n_envs)]) 261 | # if normalize: 262 | # # Copy to avoid changing default values by reference 263 | # local_normalize_kwargs = normalize_kwargs.copy() 264 | # # Do not normalize reward for env used for evaluation 265 | # if eval_env: 266 | # if len(local_normalize_kwargs) > 0: 267 | # local_normalize_kwargs['norm_reward'] = False 268 | # else: 269 | # local_normalize_kwargs = {'norm_reward': False} 270 | 271 | # if args.verbose > 0: 272 | # if len(local_normalize_kwargs) > 0: 273 | # print(f"Normalization activated: {local_normalize_kwargs}") 274 | # else: 275 | # print("Normalizing input and 
reward") 276 | # env = VecNormalize(env, **local_normalize_kwargs) 277 | 278 | # Optional Frame-stacking 279 | if hyperparams.get('frame_stack', False): 280 | n_stack = hyperparams['frame_stack'] 281 | env = VecFrameStack(env, n_stack) 282 | print(f"Stacking {n_stack} frames") 283 | 284 | if is_image_space(env.observation_space): 285 | if args.verbose > 0: 286 | print("Wrapping into a VecTransposeImage") 287 | env = VecTransposeImage(env) 288 | return env 289 | 290 | env = create_env(n_envs) 291 | 292 | # Create test env if needed, do not normalize reward 293 | eval_env = None 294 | if args.eval_freq > 0 and not args.optimize_hyperparameters: 295 | # Account for the number of parallel environments 296 | args.eval_freq = max(args.eval_freq // n_envs, 1) 297 | 298 | if 'NeckEnv' in env_id: 299 | # Use the training env as eval env when using the neck 300 | # because there is only one robot 301 | # there will be an issue with the reset 302 | eval_callback = EvalCallback(env, callback_on_new_best=None, 303 | best_model_save_path=save_path, 304 | log_path=save_path, eval_freq=args.eval_freq) 305 | callbacks.append(eval_callback) 306 | else: 307 | if args.verbose > 0: 308 | print("Creating test environment") 309 | 310 | # save_vec_normalize = SaveVecNormalizeCallback(save_freq=1, save_path=params_path) 311 | # eval_callback = EvalCallback(create_env(1, eval_env=True), callback_on_new_best=save_vec_normalize, 312 | # best_model_save_path=save_path, n_eval_episodes=args.eval_episodes, 313 | # log_path=save_path, eval_freq=args.eval_freq, 314 | # deterministic=not is_atari) 315 | # save_vec_normalize = SaveVecNormalizeCallback(save_freq=1, save_path=params_path) 316 | eval_callback = EvalCallback(env, 317 | best_model_save_path=save_path, n_eval_episodes=args.eval_episodes, 318 | log_path=save_path, eval_freq=args.eval_freq, 319 | deterministic=not is_atari) 320 | callbacks.append(eval_callback) 321 | 322 | # TODO: check for hyperparameters optimization 323 | # TODO: check What happens with the eval env when using frame stack 324 | if 'frame_stack' in hyperparams: 325 | del hyperparams['frame_stack'] 326 | 327 | # Stop env processes to free memory 328 | if args.optimize_hyperparameters and n_envs > 1: 329 | env.close() 330 | 331 | # Parse noise string for DDPG and SAC 332 | if algo_ in ['ddpg', 'sac', 'td3'] and hyperparams.get('noise_type') is not None: 333 | noise_type = hyperparams['noise_type'].strip() 334 | noise_std = hyperparams['noise_std'] 335 | n_actions = env.action_space.shape[0] 336 | if 'normal' in noise_type: 337 | if 'lin' in noise_type: 338 | final_sigma = hyperparams.get('noise_std_final', 0.0) * np.ones(n_actions) 339 | hyperparams['action_noise'] = LinearNormalActionNoise(mean=np.zeros(n_actions), 340 | sigma=noise_std * np.ones(n_actions), 341 | final_sigma=final_sigma, 342 | max_steps=n_timesteps) 343 | else: 344 | hyperparams['action_noise'] = NormalActionNoise(mean=np.zeros(n_actions), 345 | sigma=noise_std * np.ones(n_actions)) 346 | elif 'ornstein-uhlenbeck' in noise_type: 347 | hyperparams['action_noise'] = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions), 348 | sigma=noise_std * np.ones(n_actions)) 349 | else: 350 | raise RuntimeError(f'Unknown noise type "{noise_type}"') 351 | print(f"Applying {noise_type} noise with std {noise_std}") 352 | del hyperparams['noise_type'] 353 | del hyperparams['noise_std'] 354 | if 'noise_std_final' in hyperparams: 355 | del hyperparams['noise_std_final'] 356 | 357 | if args.trained_agent.endswith('.zip') and 
os.path.isfile(args.trained_agent): 358 | # Continue training 359 | print("Loading pretrained agent") 360 | # Policy should not be changed 361 | del hyperparams['policy'] 362 | 363 | if 'policy_kwargs' in hyperparams.keys(): 364 | del hyperparams['policy_kwargs'] 365 | 366 | model = ALGOS[args.algo].load(args.trained_agent, env=env, seed=args.seed, 367 | tensorboard_log=tensorboard_log, verbose=args.verbose, **hyperparams) 368 | 369 | exp_folder = args.trained_agent.split('.zip')[0] 370 | # if normalize: 371 | # print("Loading saved running average") 372 | # stats_path = os.path.join(exp_folder, env_id) 373 | # if os.path.exists(os.path.join(stats_path, 'vecnormalize.pkl')): 374 | # env = VecNormalize.load(os.path.join(stats_path, 'vecnormalize.pkl'), env) 375 | # else: 376 | # # Legacy: 377 | # env.load_running_average(exp_folder) 378 | 379 | replay_buffer_path = os.path.join(os.path.dirname(args.trained_agent), 'replay_buffer.pkl') 380 | if os.path.exists(replay_buffer_path): 381 | print("Loading replay buffer") 382 | model.load_replay_buffer(replay_buffer_path) 383 | 384 | elif args.optimize_hyperparameters: 385 | 386 | if args.verbose > 0: 387 | print("Optimizing hyperparameters") 388 | 389 | if args.storage is not None and args.study_name is None: 390 | warnings.warn(f"You passed a remote storage: {args.storage} but no `--study-name`." 391 | "The study name will be generated by Optuna, make sure to re-use the same study name " 392 | "when you want to do distributed hyperparameter optimization.") 393 | 394 | def create_model(*_args, **kwargs): 395 | """ 396 | Helper to create a model with different hyperparameters 397 | """ 398 | return ALGOS[args.algo](env=create_env(n_envs, no_log=True), tensorboard_log=tensorboard_log, 399 | verbose=0, **kwargs) 400 | 401 | data_frame = hyperparam_optimization(args.algo, create_model, create_env, n_trials=args.n_trials, 402 | n_timesteps=n_timesteps, hyperparams=hyperparams, 403 | n_jobs=args.n_jobs, seed=args.seed, 404 | sampler_method=args.sampler, pruner_method=args.pruner, 405 | n_startup_trials=args.n_startup_trials, n_evaluations=args.n_evaluations, 406 | storage=args.storage, study_name=args.study_name, 407 | verbose=args.verbose, deterministic_eval=not is_atari) 408 | 409 | report_name = (f"report_{env_id}_{args.n_trials}-trials-{n_timesteps}" 410 | f"-{args.sampler}-{args.pruner}_{int(time.time())}.csv") 411 | 412 | log_path = os.path.join(args.log_folder, args.algo, report_name) 413 | 414 | if args.verbose: 415 | print(f"Writing report to {log_path}") 416 | 417 | os.makedirs(os.path.dirname(log_path), exist_ok=True) 418 | data_frame.to_csv(log_path) 419 | exit() 420 | else: 421 | # Train an agent from scratch 422 | model = ALGOS[args.algo](env=env, tensorboard_log=tensorboard_log, 423 | seed=args.seed, verbose=args.verbose, **hyperparams) 424 | 425 | kwargs = {} 426 | if args.log_interval > -1: 427 | kwargs = {'log_interval': args.log_interval} 428 | 429 | if len(callbacks) > 0: 430 | kwargs['callback'] = callbacks 431 | 432 | # Save hyperparams 433 | with open(os.path.join(params_path, 'config.yml'), 'w') as f: 434 | yaml.dump(saved_hyperparams, f) 435 | 436 | # save command line arguments 437 | with open(os.path.join(params_path, 'args.yml'), 'w') as f: 438 | ordered_args = OrderedDict([(key, vars(args)[key]) for key in sorted(vars(args).keys())]) 439 | yaml.dump(ordered_args, f) 440 | 441 | print(f"Log path: {save_path}") 442 | 443 | try: 444 | model.learn(n_timesteps, eval_log_path=save_path, eval_env=env, eval_freq=args.eval_freq, 
**kwargs) 445 | except KeyboardInterrupt: 446 | pass 447 | 448 | # Save trained model 449 | 450 | print(f"Saving to {save_path}") 451 | model.save(f"{save_path}/{env_id}") 452 | 453 | if hasattr(model, 'save_replay_buffer') and args.save_replay_buffer: 454 | print("Saving replay buffer") 455 | model.save_replay_buffer(os.path.join(save_path, 'replay_buffer.pkl')) 456 | 457 | # if normalize: 458 | # Important: save the running average, for testing the agent we need that normalization 459 | # model.get_vec_normalize_env().save(os.path.join(params_path, 'vecnormalize.pkl')) 460 | # Deprecated saving: 461 | # env.save_running_average(params_path) 462 | -------------------------------------------------------------------------------- /paper-notebooks/gridworld-training-curves.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pathlib import Path\n", 10 | "\n", 11 | "import numpy as np\n", 12 | "import matplotlib.pyplot as plt" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 12, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "data": { 22 | "text/plain": [ 23 | "True" 24 | ] 25 | }, 26 | "execution_count": 12, 27 | "metadata": {}, 28 | "output_type": "execute_result" 29 | } 30 | ], 31 | "source": [ 32 | "Path('../scripts/output/regressed_exps')" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 15, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "evaluations = {\n", 42 | " 'CoinFlipGoal': {\n", 43 | " 'true': Path('../scripts/output/regressed_exps/ppo/EmptyMaze-10x10-CoinFlipGoal-v3_1/evaluations.npz'),\n", 44 | " 'rm': Path('../scripts/output/regressed_exps/ppo/EmptyMaze-10x10-CoinFlipGoal-v3_2/evaluations.npz')\n", 45 | " },\n", 46 | " 'TwoGoals': {\n", 47 | " 'true': Path('../scripts/output/regressed_exps/ppo/EmptyMaze-10x10-TwoGoals-v3_1/evaluations.npz'),\n", 48 | " 'rm': Path('../scripts/output/regressed_exps/ppo/EmptyMaze-10x10-TwoGoals-v3_2/evaluations.npz')\n", 49 | " }\n", 50 | "}\n" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 53, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "plt.rcParams.update({\n", 60 | "# \"text.usetex\": True,\n", 61 | "# \"font.family\": \"serif\",\n", 62 | " \"font.serif\": [\"Times\"],\n", 63 | "})" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 73, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "data": { 73 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXUAAADJCAYAAADcgqJyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2deZhcVbW331935jCEkIQQMjMaBBQQBJnBAWVSQJFJDIgDoF5ArsKVQa7fRRQUGQVE7hUREARkHgIRkEGDMiqoJEwJEEgImZMe1vfH2oeqrlRVV3e6uqq71/s8/VTXPvucs86ptddZe5+115aZEQRBEPQOGmotQBAEQdB1hFEPgiDoRYRRD4Ig6EWEUQ+CIOhFhFEPgiDoRYRRD4Ig6EX0q7UAnUHSfsAE4DVgBDDQzC4uUu944FYze62gfDDwQ2Bt4DZga+A+4HDgWjP7o6RzzeyUEuefCBwNPAcsB/YysxMqkHtd4B4z27bCS0XSgcAdZrZc0mHAcOCDwBDgcUBmdlGlx8s77ijgFuAP6ViY2ekdPU4F59kA+K2Z7dLVx+6tSDoDeAWYCtwNDALuN7OHOnCMLwBjgFnAWsAQM7usgv0+BexvZl/vpOwbAiPN7HFJGwFfx9vp4cAlwOZmdlInjjsYOBfX/z8CHweOMLPlnZGzzHmGANcB3zSzl7vy2N1FjzPqkrYGvmRmB+aV7V+sbiljZ2bLJD0DDDOzW4BbJK0F7Ako1Sll0AcA1wKfNLNFqezflchuZvMkLa6kbh6D8hT3LjObL+moJPvFkoZ38HiZLHMl/RM3Gk8Db0u6wMzmdeZ4Zc4zW1JrVx6zD3Bh+p13A+40s6c68jtL2h3Yw8y+mle2eYW7vwAc0iFp2/JRM/tN+v9t4BRgHLCbmV21Gvq6TNKTuN5fLmlv4NPA71dD1mLnWSppflces7vpcUYd+CLwYH6Bmd0qaVNg51Q0CLgc9wx+DQwAvgFMB3bJeyB8WNJUYA0z+7kkACRNAK4B9k6f04E9gDuBl4E38wz6JOBoSdfgymvAR4B7gH8ApwKvAkvN7IrCi5F0NLAT8D/ArcDHgMPSOd9N+2bXWUzZ1pR0J3AH7pWdCGwP/G+6B/9Osm8KPGdmvy1yjFHAImBhkmlq+v5p4HTg/nRvL8Eb0TTg+HSN2wMTgV/inuFXgTeBM9J1vAyMLHLOoASFv7OkBuCnkh4AlgFHAfvhvc1z0/dngQOA7+C/wbV5+x8IbCnpPOArBXU/DQwDdgC+X3Dek4B3gKFmdkkq2wL4P+CTwM3pGCuBrc3scuD9B7iZvZf2yT9sk6Tbcf0eDOyO90j2wXvdZwPHACuALYp59el+jMN1G0n7pk2HACfjuv9DXDc3Ao4DzgHOx9vHq8BS4F7gerzt/AXYHO99b1N4zp5ETxxT7wcMLFJ+Ou7JXgkcmerMxj3vfwHvmdkFuOJk/M3MrsKHMd7HzF4BWs1sMbAAH6L4IvAtoD95D0MzmwXsBTwFnGhmfwDOA36CN8DfA/OBHUtcz1XA+sDMVG88/tB4Fjf2j5S7GUnWFjM728yuARaaWRPwz1TleLxhPo4b30L2xIeg9jCzJkmbAR/Cvay5+PDStcDGuDHfDtgqXd9rwJ9Sna2AvwOLkod4DPCEmd2Uzh90EjNrBc7EdWgG/vAcBDyKDx0OMrN7gRdx3W+jo/jDfhJu5ArrPgc8k+pvVHDqvYBG4IY8WZ4F/oq3r0dwfZgAXClpO+DP7VzLIlwfbjGzs8k5Lc+lz48D6+LOwEBJjQWHmIL3Ls8ws2dS2ZG4EzIjXcMP8bbzKLAhsA7uML1LXnvMbzvAYmCwmd2Trq/H0hON+i3Ap5T3+Jc0Hr+WUaloDtCCe82FrFJmZsUUMb9eq5ktAV7HPYzNJK1XsL0VGCGp0cwW4F7vJNyjeZg0rFPk3Ja2n417EWel84CPl1eSx6EpX5Zs3/TZCCwxs9uA3xXZ9wm80R6eV39NM5uOP5xW4A+e4/EGuCk+Zvoe7mWNBJ7PO19z+hxB7gEauShWk+Q8jMDv+WXARXgPKl/v38J/r9/SdgilNf0Vq3sw7nzMZFUdPQ5/gJxZUH4V3hO7DdjXxbNWYEMze6nCS2rO+2ygrb72T/r3U/I8/8SreBv5qqTswTUGN+AXALPN7I94j3cc7oGfjLexYu0xazu9Rl97nFFPP9hNwAWSjpR0AN6VOgs4VtJBePdrJf6U3gTYDBgjaSywrqSt8C7WVpLWAEgewYbAFEmjgfHpE2C/NI59bvI0DgJOlXRgOv8Tqd73gTMlfRk4CRiKN8I90/EmAOtJmlJwWb8CVpjZE8AcM3ssdXOfLrx+SYPyZB8qaQwwIQ0DAfxd0v/gXvnGeOM/R9LvgbF5x1kv3ZutceM9XtLpJO9c0nTgWDNbmF40v4sPt1xP6vbiL5q/iDeWbfHu64aS1gYuBU5IQzlD04M3qJD0IntjYIc8B+Z6/AXqb4C5ydG4D2iUdAT+gP21md0IPCjpVEmfBnbBhyxXqYv/hl9J33fAH9rjUrv4Gd4DfTZfNjP7E+7IPJKO+0QysCtLXM5OwERJYyQNxfUwG+J4EPesdwE2SDJuJekePADB0v0YjOvYZkmmG4BrU8/yCuAx4GIgeyf0AO5x/wJ4Ox2nsD2OJ9d27gY2lvRd/GVsYRvtMagjCb0k5b+06xNIuho4s6e+CQ+K0xd1OegbtPuiVNKh+Iu2RvwJ+/FqC1UvpIiYcbgH+nJtpQlWl76sy0HfoZLol4/jQxuGdwf7DGa2EO+qBb2DPqvLQd+hEqP+Ah6RAcWjJ4KgpxC6HPR6KjHq4/DwKfAXiVdWT5wgqCqhy0GvpxKjvtLMzoL3I0SCoKcSuhz0eiox6u9J2gOPJ90Gjx2tCiNGjLCJEyd26TGXL4fGEs23pQUGDcp9b2qC1lZoaPBt/fuX3jfoep588sl3zKyas09rrsutrbAyBf4NHAgSNDe7vjWkAOOWFv8s1D0z/xtYbOpdUFd0gy6XpBKjPobc9PtJ5SquLhMnTmTGjBlddrwVK+CVV2CNNYpvX7wYxo93w97SAi+9lKu7cqU3svERXd1tSHqlyqeouS6/8QYsW+YOxMSJbqDfeMOdj8xYt6bpNg1FZpEsWgQTJrR1RoL6oxt0uSSVGPXjzKwZQFKPyonQ1FR+u+QNbNAg/8wP2R8wwBtQczP064kZcoJi1FSXW1tdp4YOdd3MPPIVK9p65cWMeUZjIyxZEkY9KE0l5upZSW/gs08Nn5FVkjSz7MfARWb2kqR1gP3xfBQ3lUhK1SGyxtDQ4Ia5FJm3XYrzz4fbb/eGss02cNppsOaavu255+DUU+Htt8ufI+g4H/kI/K5YwoLqU1Ndbm7Ojut/mR43NcGQIe3vv3y56+xtt8WwYL2w1VZw6621lqItlRj1j2XKK+lb7VU2s2ZJ75HLrXAQ8GT6/1B82vr7SDoWOBZgfAVjHYsWwZw5/v/IkTC8TCLPZctKK/+sWXD55f6jjBsHf/gDPPII7L+/e0I33ODH3m
mn8g+GoONMnlyzU9dcl7PeYENDbizdrH3H4eWXYepUePFF2HVX1/2g9tTj8GwlRv08SYZ7NxvgSXM6wkA8aU4DnjioDSld5+UA22677So5C1pafOx77bVd+efN866nmRvfzKgvW+YNJf8l0ooV/rKzGBdd5Nsuuwy22w7uvhv+8z/hN7/xBrbffvDd78Kmm8Jaa3XwioN6paa6nE9Dg/ckM++9kNZWH2vfYAP/fuaZ7sz87//CjjvChhtGDzIoTiVG/WY8sVQznuypLClUbD1gfUk74auIZPmOb+6ogC0tMHeuj3GDG+o113SjvnRpzstZsMC98lEpB11rq3dri409vvYa3HgjHH54rtFsv72XDR2aq7d4ce68Qa+gprqcT2OjG/VsCKaQ+++HL3/Ze5MTJsB998HJJ8POO+eiZoKgGGWNesoqOBwff2zEF3w4q9w+ZtaCL2EFnuISPAthp2lqcq9l4MCc5y25QW9q8rIlS7x85MhcmFgpxb/sMveUjj4659kPGpSLOsgnjHrvoF50OaPQqM+d68MrF13kUTHPP+/l3/kObL65OzJTp7pelxtyDIL2PPXd8dVrdsEbwwtVl6gI/fq5Mq9Y4cMw+WQRLplBXrHCDXR+t/aFF+Dpp+ELX/CXTTffDPvs4159ZrQLjXfm5cd4eq+hLnQ5IxtTz+ZRPPQQ/O1v8PDDbtRnzoRhw7zOo4/CN7/pur94ccSpB+Upa9TN7OqUh3sMvujDm90iVRGGDm0bcgjeMJYvb/vyaelSN8ZNTTlP/Sc/gbvugilTPG79vffg4IN9Wxau2Njohr2lxf9vagqPqDdRT7qcz9Klrm9PPeXf//Uv/5w1yz30L3/Ze5Zf+Upun+g9BuWoxA+9DDgQXwLqnOqKUxppVa+5f39vFFnDGDDAx9abm+Hdd91gr1zpXhDAz38ON90Eo0fDxz7mZfnRMWuskZvt19oascC9kLrQ5Xyam0sb9cmTYe+9PWRu+PDoPQaVUcmL0uuABWa2XNLO7dbuRvr1c4Pe1OQGvbHRo2BefdW998GDvTu7ZAl8+MNw551e5ytf8U+ztkZ9yBB/GGSER9TrqDtdbm313mE2hv6vf8H8+e6cTCqY8xq9x6ASKnnmDwa+JulefE3CuiH/ZWlmnPv39/LBg/37tGk+BvmLX/gQTksLHHRQbv98o54Z8UWL/Dgxk7TXUXNdNoOrr3bHI/v+wgveQ9xiCw8IeDYtHlcYz1+YqygIitGu2TKz6yXdhj8AhlVfpI5ROM5eqPTTpsEOO3jo4sknw1//Ch/4gDeQfv3aRsj07++TCQo9+KB3UA+6/OabPnP5m9+EE05wh+Rvf/NtBx/sBv3ee/17oacuRe8xaJ+ynrqkD0kaYWZLzWwxcEo3yVUxgweX9l5mzfIogj3T2kXHHusvncCNerEogkGD/JjReHoX9aLLzzzjn2++6fo3eLCPp6+7Luy2m2+7555Vk8mZeVk4G0F7lPTUJV2JrzS/pqTPA9+gDtfpLDdEMn26f2ZGPZ+WlrYTjYLeSz3p8nPP+edbb7mBbmz0cNuttvJJRgMG+BDMxIltHYumpph0FFRGOU/9eTPbGtgSuAS4G/hRt0jVRfzjH/5iacKEVbe1toY33oeoG13O99TB487/+U9/kd+vn0//Bx96yU8j0NxcWdKvIChn1MdK2g9frPd+PO/F2d0iVRcxc+aq45IZZvEitA9RF7rc2prz1DOj/u9/uy5OmeLfN9rIPydP9ol0y5bl9o2XpEEllDNrE4APp/+b8dXXN6m6RF3IrFmeK6MY8TK0T1EXuvziix5eO3myOxzLlvlkOPDhFoBNklSTJuWiuzJKJacLgnzKGfWpZrYwv0BSj8lXuHSpe0OlPPXCcMagV1MXupwthPSpT8Ell7h+ZqGNWdTVxhv79yycccCA3IS4MOpBJZQcfilsBKXK6pVZs/yzXO7uGH7pG9SLLs+Y4ePiO+3k3zOjPnKkly9d6jOdf/jD3IznYcO8PF6SBpXSa83azJn+WcxTb26ORhJ0PzNmwAc/CGPG+Pe33vLFL7LQRTP3zI86KrdPlvMom0wXBO3R7oxSSSdJ+rGkyZI+2x1CdQWZp54Z9fysjZUuHxb0LmqpyytWeOTLFlvAeut5Weapl4rOamhwIz90aBj1oHIqSRPQCtxuZjOBo8pVlLSNpM9IOi6v7JuSzpE0evVE9UaQLT+34Ya5McpizJzpibuyWPRFizyjI3iMejSSPknNdHnAAJ85ethhnht9yBBfrGXOnOJLorW25sbQR42KyJegciox6ouASZKOBNZop+4xZnYHMFDSByUNx6MMWoC3iu0g6VhJMyTNePvtt9ueeBEceST86U/+/dJLPeHWN77hnvd993n5vHlwxhm5hTLAPfXMSzfzBpLlXod46dRHqZkuSx6uOGaM/z96NDz5pBvvUkY9e+czZEjoa1A5lRj16/C43qHAAe3UzeJJWoBlZjbfzI4HZgFbFNvBzC43s23NbNuRBavpNjZ6LoxTTvGJRNdcA5/7HHzve7DllvD4417vhhvgyivh97/P7TtzZu4laXPzqo0iGkmfpGa6XMjo0bnMjBMm5HIRNTS4Qc836kHQEcqlCTg1b3v2SvHHwNfKHO9qSXsDS4GdJQ0BtsE9pOc7KtyQIW7Id9gB9t/fQ7tOOMG3ffSjcMUVHus7bZqX3XQTHHGEpy2dP7/tePraa8PChX6MrPEEfYN60OVCRo/OrdY1frwb9f79cwY9jHrQWcqpzRzcK/kscAu+BNhO5Q5mZo8WKX6209Lhq79873u+mvoBB+SmUX/0ox7rO306/PnPnhDpL3/xaIIsJ3rmqWcpS1tb4Z13Iid1H6QudDmf0WlUfsAA/3/FCndismUb88fUg6AjlDTqZnY1gKRRZjY9rax+dHcJls8hh7RdrQjgIx9xb/snP3GjffbZcNxxPgSTvRzNH1Pv18/L33gjXpL2NepJlzMyoz5uXG7IJQuzzVIDxOS4oDNU0sFbIOluYBJwYZXlKYoE++7btmyttdyLf/ZZn6Cxzz7wm9/AxRd7lMuWW+aMupRb9GLIkEjk1YepuS5nZEY9C2fMz0WUDcvEEGHQGSpZJOO+5NmsMLMHu0Gmitl+ezfqu+/uXs0RR8Bjj3l0zMkn58Yo8/NQDxsWRr2vUi+6vHKlhzVC28iXbInF/IXUg6CjtGvUJf0MjwCYK2ljM7u8+mJVxo47etTLHnv493339dzp+ROLmpvbDre0E5QQ9GLqRZebmmD99f3/LJEX5Iw65BbFCIKOUsnwy30pXhdJB1VZng6x115w4YWw3365ssKZoi0tOa8o6PPUhS63trpRv+oqd0wyGhtzQy+RcC7oLJUY9U1TRrtRwBTgxuqKVDmNjR63Xo5Sy9YFfZK60OUs59AnPuH/Z0OE2TChmZeHpx50hkrU5hJ8ssZS4NvVFac6RGhYkKgbXc73yltacu95Mu88YtSDzlKJUT8RuB3Ym/KTNeqW6MYGiZrrcjapaOBAN+bQ1qhnwy6hs0FnqcQf+BdwCL6u4
2vVFadrybqx0UCCRM11ORsO7Ncvt/hFFqOekYXfBkFnqMRTH4EvBXYncHh1xelasqnXkTc9SNSFLg8alAu3hZyeZvTrF2G3QeepJE79UuBSAElfr7pEXUiWHiAIoD50uaHBvXKztpEu+Z75gAHhqQedp6SnLum09HmLpAckPQgUy4dRt4RRD6C+dLlfv+IJ5fKHCMOoB6tDOdX5Wfo8y8z+BiBp0+qL1HUUdmuDPkvd6HJjY3GDnV8W8yqC1aHcwtPZkhONkq6SdD6wuHvE6hoKu7VB36SedLl/fzfsmade7L1PFrMeBJ2hEtX5KXANcCVwaLmKhUuASVpH0lGSvpJWjulWzCLyJWhDTXVZymVizPSyuTleigZdSyVG/Tdm9oCZ/R34B4CkMSXqtlkCDDgIeAZ4knYaUaW0tnpq0mXLcnG+pQhPPSigprrcvz+MGOH/Z554a2u89wm6lkpM3uGSDsEX7R0u6SRgA3y9xkLaLAGGLx3WhD88lhU7uKRjgWPT18WSXszbPAI0DwYNgJYUK9DYAM3JnPdrzJUXo0GwfGX7l9gpRgDvVOnYq0O9ygXtyzahyuevsS4XXvugAe56rGwqr8dVp151pl7lgtrrcklkWVq4UhWksWb2uqRhwEIza83KitTdEVgbGIs3gNuBLBP6rWY2v0PCSTPMbNuO7NNd1Kts9SoX1F620OXi1Kts9SoX1Lds5dYoPQNfx3GipNF41/NC4LhijQBKLgH2q64QNAg6S+hy0JcoN6b+tJktBXYB/mpmK4H7u0esIOhSQpeDPkM5o56N8T1nZq0FZd1F3SzIUYR6la1e5YLayRa6XJ56la1e5YI6lq3kmLqkmcACYFj6FDDSzMZ2n3hBsPqELgd9iXLRL18ws7/kF0j6SJXlCYJqELoc9BnajX4JgiAIeg51OTVH0jbAaGCimV1cIxn64RETFwG7AW8B84EV+bJJOrrUtirJtSvwCWA8ML1e5EqybQTsg8fwzqon2WpF6HJZuUKXq0C9ZpgonM3X7ZhZM/Aefo/WNrPbgWMKZNuizLZqyf2omZ0GDKkzuTCzfwMP4MpdV7LVkJpfX+hyx+nJulyvRr1wNl8tyWYSgieBypetocy2qshtZk1JYX5QT3Ll8Sa+EEU9ylYL6un6Qpc7Ro/U5bocU8+fzWdmV9RIhka8u3otsDnwKj4tuF++bJK+VmpbleQ6HNgOj+JoAf5SD3Il2fYCBgCDgZGlzl8L2WpF6HJZuUKXq0BdGvUgCIKgc9Tr8EsQBEHQCeoy+qWnkXKLvAJMxVeqHwTsZGZ71FSwIOggocs9nzDqXcOFZjZf0m7AnWb21OouCiKpAfgvM/tBl0gYBJURutzDCaPeBRSmYZU0AbhG0p7AebjnswnwOv5SaCcz+7ykLJXrIcBJwKfw0LMdgUuAg9Miye8BmwK7AjcCO+Pxs43AWmZ2ZMoN/g4w1Mwuqeb1Br2X0OWeT4ypVwEzewVoTdkA5wEzgHOAyWb2c2BUqnoksCht3xiYgodQXWJms4B5ZvYwcALwNvCnVO81PPPg8cAkSROBvfCGcUN3XGPQNwhd7nmEUa8ehWFFxbICjgEeBS4AZuOr3j8L/Lqg3prAa2b2W+DBguPNxhvTcel4Z66u4EFQQOhyDyKMehchaRTueewgaX1gvKSxwIZ4d3VTYINUNlLSZsAVwGPAxbgXdDqwErgnHfZdSfsB5wI3SLoWn30HsJOkfYCHzGwe3oiewhtSEHSa0OWeTcSp90AkHQVgZlfXVpIgWD1Cl7ueeFHaw0izAzdL/8viqRz0UEKXq0N46kEQBL2IGFMPgiDoRYRRD4Ig6EWEUa8xkgYVfB9cK1mCYHUo1OVUFvrczYRRrxGShks6rcimNSWdJWlAtwsVBJ0gT5ffkHRmweapafWnoJsIo147fgfcYmbL8wvNbC7wBPDzmkgVBB3nd8AteAqAQn4B/ELSOt0rUt8ljHoNSGszbmJmz0s6TdJUSS+mCRjgM+2+LGl8DcUMgnbJ1+VUtI2k30q6Q1L/tJTeP4Gv1U7KvkUY9dqwBzAnjUGegufLeAv4EoCZLcO9nt1qJWAQVMgewJy8708CX8YTen02lb0OfLyb5eqzhFGvDaOAFWnoZXvg83imuv55dZYD69dAtiDoCKPwxZnfJ+n1UmBSKgpd7kbCqNeGOcAgSWvj2eoewlON5jOYth5QENQjc/CFNDIaJPXD9XdGKgtd7kbCqNeGu4GxuGe+EvgK8BKeJGmdFAa2JvBA7UQMgorIdBngeWBd4DTgp2Y2LZWPJZfYK6gykfulBpjZXyT9BRhjZqt0SyV9GrjUzGZ3v3RBUDmZLkva0sw+U7g9ee0TgK92v3R9k/DUa8dhwGfSEMz7SFoPT3L0nZpIFQQdp6guJ74EHG1mC7tZpj5LJPSqMZIazayl1Pcg6CkU093Q5+4njHoQBEEvIoZfgiAIehFh1IMgCHoRYdSDIAh6EXVj1CXtJ+kESQdIOkbScSXqHS9pXJHywZLOl/TLdIwfSNpZ0i9SfgoknVvm/BMlnS3pC5L2l3RhhXKvK2lG+zXb7HNglqZU0ihJj0r6bpL5Bx05VlciqZ+ky7L7lcqmSHpT0iGSrpa0f5XO/QlJl1bj2CXOd4akoyQ9JOnUdO936eAxviDpP5K+HSmpovwmkj61OtcqaUNJH837PiDpz2FJlgskTSyxb9E2IGmcpH+mY3xe0hWp7KG0fQdJB5SRabXuhaT1JF1Tpt5Okl5NeniDpB0qOX41kLRBdl/S937p/t+brv1WSSOqcN5V2mdRzKzmf8DWwE0FZft34jhHAd/O+74WcCawWzv7DQAeBdbMK9u8A+ed3kE5Dyv4fjXwIUD4zNJ1a/hbrHK/gKfS51bAM1U670Tg6m68zuH59z6/rML9dwd+UVBWkc6s7rUW0Z/zgc/mfZ8MbNmJ404HhqX/18rKuuNe4BPxftxO3UwP9wf+0F26UupeFXzfDfhZ+v+nwIlVOm+79qxeJh99Ec9M+D5mdqukTYGdU9Eg4HLgEuDXuCH+Bq6Iu5jZganehyVNBdYws59LAkDSBOAaYO/0OR1PRnQn8DLwppktSnUnAUcnz2EcYMBH8Flx/wBOBV4FlprZFYUXI+loYCfgf4BbgY/hsbzTgXfTvsUYBSwCFqbjTE3fPw1MBb6Fzzw9HW8ITcCOeI/rNLxxbQo8BzwNfBTYAfgesA+eJGxHM/tOkWN/FZgF7JLkLMZk4MUk25rAIfjM1wbgqXRfsuv8OLBdupYP4NPEDwCOxZOYrQ0MxX/PrYH1SpyzKpjZ/PzvkhqAn0p6AFiGOwj7AT8Ezk3fn8Wv4Tv4/bo2b/8DgS0lnYfPEM6v+2lgGP5bfL/gvCfhD/KhZnZJKtsC+D/gk8DN6Rgrga3N7HKgteByjsB1Iru2mek4XwLeBD4BXInnY7kG2BM3PK/hcyLuN7PMS/5cOv8FJD1Mx/oqnp/oD3ha6NuAg4FDu+JemFlT8tinkKe3ZlaYPgPa6uFYvB1nen8L/ps9ire/z6XP6fhkyyyb
[... remainder of base64-encoded PNG data omitted: maze training-curve figure ...]\n", 74 | "text/plain": [ 75 | "<Figure size 550x270 with 4 Axes>
" 76 | ] 77 | }, 78 | "metadata": { 79 | "needs_background": "light" 80 | }, 81 | "output_type": "display_data" 82 | } 83 | ], 84 | "source": [ 85 | "plt.figure(figsize=(5.5, 2.7))\n", 86 | "\n", 87 | "ax = plt.subplot(2, 2, 1)\n", 88 | "with np.load(evaluations['CoinFlipGoal']['true']) as data:\n", 89 | " timesteps = data['timesteps']\n", 90 | " results = data['results']\n", 91 | "means = np.mean(results, axis=1).flatten()\n", 92 | "stds = np.std(results, axis=1).flatten()\n", 93 | "ax.plot(timesteps, means, color=\"blue\")\n", 94 | "ax.fill_between(timesteps, means-stds, means+stds, alpha=0.1, color=\"blue\")\n", 95 | "# plt.ylim(-5, 40)\n", 96 | "ax.set_xlabel(\"Timesteps\", fontsize=9, fontfamily='Times New Roman')\n", 97 | "ax.set_ylabel(\"Episode Return\", fontsize=9, fontfamily='Times New Roman')\n", 98 | "ax.set_ylim(0, 1.1)\n", 99 | "plt.xticks(fontsize=7, fontfamily='Times New Roman')\n", 100 | "plt.yticks(fontsize=7, fontfamily='Times New Roman')\n", 101 | "plt.yticks([0, 0.5, 1], fontsize=7, fontfamily='Times New Roman')\n", 102 | "ax.xaxis.labelpad=1\n", 103 | "# plt.legend(loc='upper left', prop={'size': 6})\n", 104 | "ax.set_title(\"CoinFlipGoal w/ True Reward\", fontsize=9, fontfamily='Times New Roman')\n", 105 | "ax.text(0.5, -0.75, \"(a)\", size=10, ha=\"center\", weight=\"bold\", fontfamily='Times New Roman', \n", 106 | " transform=ax.transAxes)\n", 107 | "\n", 108 | "ax = plt.subplot(2, 2, 3)\n", 109 | "with np.load(evaluations['CoinFlipGoal']['rm']) as data:\n", 110 | " timesteps = data['timesteps']\n", 111 | " results = data['results']\n", 112 | "means = np.mean(results, axis=1).flatten()\n", 113 | "stds = np.std(results, axis=1).flatten()\n", 114 | "ax.plot(timesteps, means, color=\"blue\")\n", 115 | "ax.fill_between(timesteps, means-stds, means+stds, alpha=0.1, color=\"blue\")\n", 116 | "# plt.ylim(-5, 40)\n", 117 | "ax.set_xlabel(\"Timesteps\", fontsize=9, fontfamily='Times New Roman')\n", 118 | "ax.set_ylabel(\"Episode Return\", fontsize=9, fontfamily='Times New Roman')\n", 119 | "ax.set_ylim(0, 1.1)\n", 120 | "plt.xticks(fontsize=7, fontfamily='Times New Roman')\n", 121 | "plt.yticks([0, 0.5, 1], fontsize=7, fontfamily='Times New Roman')\n", 122 | "ax.xaxis.labelpad=1\n", 123 | "# plt.legend(loc='upper left', prop={'size': 6})\n", 124 | "ax.set_title(\"CoinFlipGoal w/ Regressed Reward\", fontsize=9, fontfamily='Times New Roman')\n", 125 | "ax.text(0.5, -0.75, \"(c)\", size=10, ha=\"center\", weight=\"bold\", fontfamily='Times New Roman', \n", 126 | " transform=ax.transAxes)\n", 127 | "\n", 128 | "ax = plt.subplot(2, 2, 2)\n", 129 | "with np.load(evaluations['TwoGoals']['true']) as data:\n", 130 | " timesteps = data['timesteps']\n", 131 | " results = data['results']\n", 132 | "means = np.mean(results, axis=1).flatten()\n", 133 | "stds = np.std(results, axis=1).flatten()\n", 134 | "ax.plot(timesteps, means, color=\"blue\")\n", 135 | "ax.fill_between(timesteps, means-stds, means+stds, alpha=0.1, color=\"blue\")\n", 136 | "# plt.ylim(-5, 40)\n", 137 | "ax.set_xlabel(\"Timesteps\", fontsize=9, fontfamily='Times New Roman')\n", 138 | "ax.set_ylabel(\"Episode Return\", fontsize=9, fontfamily='Times New Roman')\n", 139 | "ax.set_ylim(0, 1.1)\n", 140 | "plt.xticks(fontsize=7, fontfamily='Times New Roman')\n", 141 | "plt.yticks([0, 0.5, 1], fontsize=7, fontfamily='Times New Roman')\n", 142 | "ax.xaxis.labelpad=1\n", 143 | "# plt.legend(loc='upper left', prop={'size': 6})\n", 144 | "ax.set_title(\"TwoGoals w/ True Reward\", fontsize=9, fontfamily='Times New Roman')\n", 145 
| "ax.text(0.5, -0.75, \"(b)\", size=10, ha=\"center\", weight=\"bold\", fontfamily='Times New Roman', \n", 146 | " transform=ax.transAxes)\n", 147 | "\n", 148 | "ax = plt.subplot(2, 2, 4)\n", 149 | "with np.load(evaluations['TwoGoals']['rm']) as data:\n", 150 | " timesteps = data['timesteps']\n", 151 | " results = data['results']\n", 152 | "means = np.mean(results, axis=1).flatten()\n", 153 | "stds = np.std(results, axis=1).flatten()\n", 154 | "ax.plot(timesteps, means, color=\"blue\")\n", 155 | "ax.fill_between(timesteps, means-stds, means+stds, alpha=0.1, color=\"blue\")\n", 156 | "# plt.ylim(-5, 40)\n", 157 | "ax.set_xlabel(\"Timesteps\", fontsize=9, fontfamily='Times New Roman')\n", 158 | "ax.set_ylabel(\"Episode Return\", fontsize=7, fontfamily='Times New Roman')\n", 159 | "ax.set_ylim(0, 1.1)\n", 160 | "plt.xticks(fontsize=7, fontfamily='Times New Roman')\n", 161 | "plt.yticks([0, 0.5, 1], fontsize=7, fontfamily='Times New Roman')\n", 162 | "ax.xaxis.labelpad=1\n", 163 | "# plt.legend(loc='upper left', prop={'size': 6})\n", 164 | "ax.set_title(\"TwoGoals w/ CoinFlipGoal's Regressed Reward\", fontsize=9, fontfamily='Times New Roman')\n", 165 | "ax.text(0.5, -0.75, \"(d)\", size=10, ha=\"center\", weight=\"bold\", fontfamily='Times New Roman', \n", 166 | " transform=ax.transAxes)\n", 167 | "\n", 168 | "\n", 169 | "plt.subplots_adjust(hspace=1.2, wspace=0.4, bottom=0.2, top=0.92)\n", 170 | "\n", 171 | "plt.savefig('figures/maze-training-curves.pdf', dpi=100)" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [] 201 | } 202 | ], 203 | "metadata": { 204 | "kernelspec": { 205 | "display_name": "Python 3", 206 | "language": "python", 207 | "name": "python3" 208 | }, 209 | "language_info": { 210 | "codemirror_mode": { 211 | "name": "ipython", 212 | "version": 3 213 | }, 214 | "file_extension": ".py", 215 | "mimetype": "text/x-python", 216 | "name": "python", 217 | "nbconvert_exporter": "python", 218 | "pygments_lexer": "ipython3", 219 | "version": "3.7.9" 220 | } 221 | }, 222 | "nbformat": 4, 223 | "nbformat_minor": 4 224 | } 225 | -------------------------------------------------------------------------------- /paper-notebooks/atari-figures.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 9, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "import importlib\n", 12 | "from pathlib import Path\n", 13 | "from itertools import product\n", 14 | "import h5py\n", 15 | "import random\n", 16 | "\n", 17 | "import gym\n", 18 | "import numpy as np\n", 19 | "import matplotlib\n", 20 | "import matplotlib.cm\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "import scipy.ndimage\n", 23 | "import skimage.transform\n", 24 | "import torch as th\n", 25 | "import torch.nn as nn\n", 26 | "\n", 27 | "from tqdm.auto import tqdm\n", 28 | "\n", 29 | "from stable_baselines3.common.utils import set_random_seed\n", 30 | "from 
stable_baselines3.common.atari_wrappers import AtariWrapper\n", 31 | "from stable_baselines3.common.vec_env import VecEnvWrapper, VecEnv, DummyVecEnv, VecFrameStack\n", 32 | "from stable_baselines3.common.vec_env import VecTransposeImage\n", 33 | "\n", 34 | "sys.path.insert(1, \"../rl-baselines3-zoo\")\n", 35 | "import utils.import_envs # noqa: F401 pylint: disable=unused-import\n", 36 | "from utils.utils import StoreDict\n", 37 | "from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams\n", 38 | "\n", 39 | "import interp\n", 40 | "from interp.common.models import AtariRewardModel" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 10, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stdout", 50 | "output_type": "stream", 51 | "text": [ 52 | "Using device: cuda\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "########### Set Device ############\n", 58 | "device = th.device('cuda' if th.cuda.is_available() else 'cpu')\n", 59 | "# device = 'cpu'\n", 60 | "dtype = th.float32\n", 61 | "th.set_default_dtype(dtype)\n", 62 | "print(\"Using device: {}\".format(device))" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 11, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "def get_mask(center, size, r):\n", 72 | " y,x = np.ogrid[-center[0]:size[0]-center[0], -center[1]:size[1]-center[1]]\n", 73 | " keep = x*x + y*y <= 1\n", 74 | " mask = np.zeros(size) ; mask[keep] = 1 # select a circle of pixels\n", 75 | " mask = scipy.ndimage.gaussian_filter(mask, sigma=r) # blur the circle of pixels. this is a 2D Gaussian for r=r^2=1\n", 76 | " return mask/mask.max()\n", 77 | "\n", 78 | "def occlude(img, mask):\n", 79 | " assert img.shape[1:] == (84, 84, 4)\n", 80 | " img = np.copy(img)\n", 81 | " for k in range(4):\n", 82 | " I = img[0, :, :, k]\n", 83 | " img[0, :, :, k] = I*(1-mask) + scipy.ndimage.gaussian_filter(I, sigma=3)*mask\n", 84 | " return img\n", 85 | "\n", 86 | "def compute_saliency_map(reward_model, obs, stride=5, radius=5):\n", 87 | " baseline = reward_model(obs).detach().cpu().numpy()\n", 88 | " scores = np.zeros((84 // stride + 1, 84 // stride + 1))\n", 89 | " for i in range(0, 84, stride):\n", 90 | " for j in range(0, 84, stride):\n", 91 | " mask = get_mask(center=(i, j), size=(84, 84), r=radius)\n", 92 | " obs_perturbed = occlude(obs, mask)\n", 93 | " perturbed_reward = reward_model(obs_perturbed).detach().cpu().numpy()\n", 94 | " scores[i // stride, j // stride] = 0.5 * np.abs(perturbed_reward - baseline) ** 2\n", 95 | " pmax = scores.max()\n", 96 | " scores = skimage.transform.resize(scores, output_shape=(210, 160))\n", 97 | " scores = scores.astype(np.float32)\n", 98 | "# return pmax * scores / scores.max()\n", 99 | " return scores / scores.max()\n", 100 | "\n", 101 | "def add_saliency_to_frame(frame, saliency, channel=1):\n", 102 | "# def saliency_on_atari_frame(saliency, atari, fudge_factor, channel=2, sigma=0):\n", 103 | " # sometimes saliency maps are a bit clearer if you blur them\n", 104 | " # slightly...sigma adjusts the radius of that blur\n", 105 | " pmax = saliency.max()\n", 106 | " I = frame.astype('uint16')\n", 107 | " I[:, :, channel] += (frame.max() * saliency).astype('uint16')\n", 108 | " I = I.clip(1,255).astype('uint8')\n", 109 | " return I" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 12, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "env_id = \"BreakoutNoFrameskip-v4\"\n", 119 | "folder = \"../agents\"\n", 120 | "algo = 
\"ppo\"\n", 121 | "n_timesteps = 10000\n", 122 | "num_threads = -1\n", 123 | "n_envs = 1\n", 124 | "exp_id = 1\n", 125 | "verbose = 1\n", 126 | "no_render = False\n", 127 | "deterministic = False\n", 128 | "load_best = True\n", 129 | "load_checkpoint = None\n", 130 | "norm_reward = False\n", 131 | "seed = 0\n", 132 | "reward_log = ''\n", 133 | "env_kwargs = None" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 13, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "name": "stdout", 143 | "output_type": "stream", 144 | "text": [ 145 | "Stacking 4 frames\n", 146 | "Wrapping the env in a VecTransposeImage.\n" 147 | ] 148 | } 149 | ], 150 | "source": [ 151 | "# Sanity checks\n", 152 | "if exp_id > 0:\n", 153 | " log_path = os.path.join(folder, algo, '{}_{}'.format(env_id, exp_id))\n", 154 | "else:\n", 155 | " log_path = os.path.join(folder, algo)\n", 156 | " \n", 157 | "found = False\n", 158 | "for ext in ['zip']:\n", 159 | " model_path = os.path.join(log_path, f'{env_id}.{ext}')\n", 160 | " found = os.path.isfile(model_path)\n", 161 | " if found:\n", 162 | " break\n", 163 | "\n", 164 | "if load_best:\n", 165 | " model_path = os.path.join(log_path, \"best_model.zip\")\n", 166 | " found = os.path.isfile(model_path)\n", 167 | "\n", 168 | "if load_checkpoint is not None:\n", 169 | " model_path = os.path.join(log_path, f\"rl_model_{load_checkpoint}_steps.zip\")\n", 170 | " found = os.path.isfile(model_path)\n", 171 | "\n", 172 | "if not found:\n", 173 | " raise ValueError(f\"No model found for {algo} on {env_id}, path: {model_path}\")\n", 174 | "\n", 175 | "if algo in ['dqn', 'ddpg', 'sac', 'td3']:\n", 176 | " n_envs = 1\n", 177 | "\n", 178 | "set_random_seed(seed)\n", 179 | "\n", 180 | "if num_threads > 0:\n", 181 | " if verbose > 1:\n", 182 | " print(f\"Setting torch.num_threads to {num_threads}\")\n", 183 | " th.set_num_threads(num_threads)\n", 184 | "\n", 185 | "is_atari = 'NoFrameskip' in env_id\n", 186 | "\n", 187 | "stats_path = os.path.join(log_path, env_id)\n", 188 | "hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=norm_reward, test_mode=True)\n", 189 | "env_kwargs = {} if env_kwargs is None else env_kwargs\n", 190 | "\n", 191 | "log_dir = reward_log if reward_log != '' else None\n", 192 | "\n", 193 | "env = create_test_env(env_id, n_envs=n_envs,\n", 194 | " stats_path=stats_path, seed=seed, log_dir=log_dir,\n", 195 | " should_render=not no_render,\n", 196 | " hyperparams=hyperparams,\n", 197 | " env_kwargs=env_kwargs)\n", 198 | "\n", 199 | "model = ALGOS[algo].load(model_path, env=env, device=device)\n", 200 | "\n", 201 | "obs = env.reset()" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 15, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "rm = AtariRewardModel(env, device)\n", 211 | "rm.load_state_dict(th.load(f\"../reward-models/BreakoutNoFrameskip-v4-reward_model.pt\"))\n", 212 | "rm = rm.to(device)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 8, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "random.seed(0)\n", 222 | "np.random.seed(0)\n", 223 | "th.manual_seed(0)\n", 224 | "th.backends.cudnn.deterministic = True\n", 225 | "th.backends.cudnn.benchmark = False\n", 226 | "\n", 227 | "breakout_images = []\n", 228 | "n = 0\n", 229 | "\n", 230 | "obs = env.reset()\n", 231 | "while n < 5:\n", 232 | " action, _states = model.predict(obs, deterministic=False)\n", 233 | " obs, reward, done, info = env.step(action)\n", 234 | " if 
done:\n", 235 | " obs = env.reset()\n", 236 | " use = (reward[0] and not random.randint(0, 9)) or \\\n", 237 | " (not reward[0] and not random.randint(0, 50))\n", 238 | " if use:\n", 239 | " n += 1\n", 240 | " sal = compute_saliency_map(rm, obs)\n", 241 | " screenshot = env.render(mode='rgb_array')\n", 242 | " image = add_saliency_to_frame(screenshot, sal)\n", 243 | " breakout_images.append(image)\n", 244 | " " 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 9, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "env_id = \"SeaquestNoFrameskip-v4\"\n", 254 | "folder = \"../agents\"\n", 255 | "algo = \"ppo\"\n", 256 | "n_timesteps = 10000\n", 257 | "num_threads = -1\n", 258 | "n_envs = 1\n", 259 | "exp_id = 1\n", 260 | "verbose = 1\n", 261 | "no_render = False\n", 262 | "deterministic = False\n", 263 | "load_best = True\n", 264 | "load_checkpoint = None\n", 265 | "norm_reward = False\n", 266 | "seed = 0\n", 267 | "reward_log = ''\n", 268 | "env_kwargs = None" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 10, 274 | "metadata": {}, 275 | "outputs": [ 276 | { 277 | "name": "stdout", 278 | "output_type": "stream", 279 | "text": [ 280 | "Stacking 4 frames\n" 281 | ] 282 | } 283 | ], 284 | "source": [ 285 | "# Sanity checks\n", 286 | "if exp_id > 0:\n", 287 | " log_path = os.path.join(folder, algo, '{}_{}'.format(env_id, exp_id))\n", 288 | "else:\n", 289 | " log_path = os.path.join(folder, algo)\n", 290 | " \n", 291 | "found = False\n", 292 | "for ext in ['zip']:\n", 293 | " model_path = os.path.join(log_path, f'{env_id}.{ext}')\n", 294 | " found = os.path.isfile(model_path)\n", 295 | " if found:\n", 296 | " break\n", 297 | "\n", 298 | "if load_best:\n", 299 | " model_path = os.path.join(log_path, \"best_model.zip\")\n", 300 | " found = os.path.isfile(model_path)\n", 301 | "\n", 302 | "if load_checkpoint is not None:\n", 303 | " model_path = os.path.join(log_path, f\"rl_model_{load_checkpoint}_steps.zip\")\n", 304 | " found = os.path.isfile(model_path)\n", 305 | "\n", 306 | "if not found:\n", 307 | " raise ValueError(f\"No model found for {algo} on {env_id}, path: {model_path}\")\n", 308 | "\n", 309 | "if algo in ['dqn', 'ddpg', 'sac', 'td3']:\n", 310 | " n_envs = 1\n", 311 | "\n", 312 | "set_random_seed(seed)\n", 313 | "\n", 314 | "if num_threads > 0:\n", 315 | " if verbose > 1:\n", 316 | " print(f\"Setting torch.num_threads to {num_threads}\")\n", 317 | " th.set_num_threads(num_threads)\n", 318 | "\n", 319 | "is_atari = 'NoFrameskip' in env_id\n", 320 | "\n", 321 | "stats_path = os.path.join(log_path, env_id)\n", 322 | "hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=norm_reward, test_mode=True)\n", 323 | "env_kwargs = {} if env_kwargs is None else env_kwargs\n", 324 | "\n", 325 | "log_dir = reward_log if reward_log != '' else None\n", 326 | "\n", 327 | "env = create_test_env(env_id, n_envs=n_envs,\n", 328 | " stats_path=stats_path, seed=seed, log_dir=log_dir,\n", 329 | " should_render=not no_render,\n", 330 | " hyperparams=hyperparams,\n", 331 | " env_kwargs=env_kwargs)\n", 332 | "\n", 333 | "model = ALGOS[algo].load(model_path, env=env, device=device)\n", 334 | "\n", 335 | "obs = env.reset()" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 11, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "rm = AtariRewardModel(env, device)\n", 345 | "rm.load_state_dict(th.load(f\"../reward-models/SeaquestNoFrameskip-v4-reward_model.pt\"))\n", 346 | "rm = 
rm.to(device)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 12, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "random.seed(0)\n", 356 | "np.random.seed(0)\n", 357 | "th.manual_seed(0)\n", 358 | "th.backends.cudnn.deterministic = True\n", 359 | "th.backends.cudnn.benchmark = False\n", 360 | "\n", 361 | "seaquest_images = []\n", 362 | "n = 0\n", 363 | "\n", 364 | "obs = env.reset()\n", 365 | "cumulative_reward = 0\n", 366 | "while n < 5:\n", 367 | " action, _states = model.predict(obs, deterministic=False)\n", 368 | " obs, reward, done, info = env.step(action)\n", 369 | " cumulative_reward += reward\n", 370 | " if done:\n", 371 | " obs = env.reset()\n", 372 | " use = (reward[0] and not random.randint(0, 9)) or \\\n", 373 | " (not reward[0] and not random.randint(0, 50)) and cumulative_reward >= 1\n", 374 | " if use:\n", 375 | " n += 1\n", 376 | " sal = compute_saliency_map(rm, obs)\n", 377 | " screenshot = env.render(mode='rgb_array')\n", 378 | " image = add_saliency_to_frame(screenshot, sal)\n", 379 | " seaquest_images.append(image)\n", 380 | " " 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 15, 386 | "metadata": {}, 387 | "outputs": [ 388 | { 389 | "data": { 390 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUEAAADACAYAAACaldH4AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO29eZAc533f/Xm659qdvQEsaFwEAQEgeIEEJSohGR02HceJTVORFYuxlcOu1+X4TSVRyrnKiZJ6X7tSlaucuCqpVA5bjq1KIluJ4iSOY8WyLNGJRImUSJAACIIASCyABbC72GOunul+8kf3s/NMb89Mzz2z83yqumZ3pqfn+fbz9Ld/z9HPI6SUGAwGw7hiDToBBoPBMEiMCRoMhrHGmKDBYBhrjAkaDIaxxpigwWAYa4wJGgyGsSbR6EPLsmKPn0kmkzz44IPYtg02lESJi+9e5NjJY0ylp6AMsiy5dP4S+c08mWSGE6dO8PaltykUCiCotWQPhBCcOnmKTDrjv+V5XLhwAcdxotNrWzz44IPcvHGTtbW1yH08zxNxNbWtHyiVSly8eJFjx44xNTUFgJSSS5cukc/nyWQynDhxgrffDvRHIITg1KlTZDIx9VuB/pud6++V9qmpKQ4cOLD93Zu3brJ8a3nHMQepPThef/TfvMny8nDpHzftXYsEp6am+OQnP8ni4iInT57khY+9QDKR5Pu+7/s4e/YsU9NTfPLFT7J3714A5ubnePHFF1lYWKh7TNuyef6HnufM42eYnJxkYnICy6qf5FQyxcc//nGOHz/eLVmx2aH/hRdIJjX9wefb+udi6Ldtnn/+ec6cCfRPNNGfGoz+VrQ/8sgjPP/888zMzPADP/gDnHnsTOQxbTvI+zMx835A2qED/T/wA5w500D/bs77IdLe1eqwlJJ33nmHpRtL/pET/nZn9Q5vvfUW0gtuMPp9RkWAYZ8WIIXkwsULFIoFjr3vGM899xyTk5PdTHJX2da/tFTz/p07gf4WB6ZLKblw4QKFQoFjx4ZbfyvaE4kEi/sXSaaTYBOZ/1JKLpwP8v74cGuHNvQvLpJMJhseb9fm/ZBpb1gdbgXP88jlczz9zNOkp9LkijlkQlIoFzjywBGm0lMUCgVc193efyu3hSe92gtA4JtkYIKFYoFEIkFapJE0N5HcVo5yudwtWbHxPI9cLsfTTz9NOp0ml8shpaRQKHDkyBGmpiL0b23heV7dY6rvJxIJ0ul0LBPN5fqvvxXtjuOQz+dBQL6Yx3Ed3wQ9am6OUgZ5nwy0I/2yocpKxKkYhHZoUz+Qz+frVvF2dd4zXNpFox84e/Zs7NBFCMHE5ARCCEiBl/Uo2AUyUxlsx4Y8UITCVgHP8bCwyGQzFEtFPAIjCEeIAiYyE1hY4AUnJ1+oRpQRTE5O4jgOlUol8vNXXnkldrtQy/onAv34haNQKJDJZLbbSwAKhQKe52FZFplMhmKx2NAI9aqAKhyN8qxb+ruifSKDbdVqt22bVCrlR4AJcFyHcrEMLv6mflXgV4FFoJ1AO9LfR7LDCEch77f1BziOU/fiHem8HyHtDSPBT33qU9EfhO/IMvR+CpgKtgS+AQYmWFPYLW2T2vGiju+xs+D3+LHnuvrHgJ5rt4Fk8KrKhCq/4eKqlwn1WVR56CIm78eH1qvDqg1PGZcKYnQjlPgFuohfyB1q7/Jo3wkfI1zg69z1a37LMHpIqmUifCPVX+vlr8l7Q5dozwSVEeoFOHynrgCl4P9KsIXafYDaCyAqGoyK/MKRJ5gLYtTQo3uCv3UDrGdyen5HVXBMOTC0SOeRoB7J6YW3gn+nh9qqSzhijDJFrXMkVnr07xp6TzeiML3c6MdVr+Gbm6zzefiYJkI0tEgsE1xbW+Pq1at+AQsatElQNbtK8Lca7gD1q7FRUR8R+0UMmYncJ9x2GHEBHD16lPn5+UhtcdjWP6J0on+H9qj2ujg0MqewuelDZlyqUaK6+arfjWojDpUnk/ddzPsRI672WCZ45coVfu3Xfs3/RzVoKxMsU23QtoNNEB39wc5qUKfEqA7/6I/+aEcXQo3+EaQT/X3VHjbBcE0jPJ6wXlOKhsn7Ecn7HhBXe+dtgnrBVAOklQmqDaKrwgZDI8KdJDpR5ceUKUMbtDdYWpmgqqaoTY8E6/XwuhHH6wbGWEeD2CP1NNQNt15V3OS7oQPaN8GoTR/3F1Vt8WjvIjDsLuL06DcqJ8b0DF2kPRNU1Vy9eht+CrkfBd
[... remainder of base64-encoded PNG data omitted: Atari saliency-map figure ...]\n", 391 | "text/plain": [ 392 | "<Figure size 550x320 with 10 Axes>
" 393 | ] 394 | }, 395 | "metadata": { 396 | "needs_background": "light" 397 | }, 398 | "output_type": "display_data" 399 | } 400 | ], 401 | "source": [ 402 | "plt.figure(figsize=(5.5, 3.2))\n", 403 | "\n", 404 | "for k in range(5):\n", 405 | " ax = plt.subplot(2, 5, k+1)\n", 406 | " ax.imshow(breakout_images[k])\n", 407 | " ax.axis('off')\n", 408 | " if k == 2:\n", 409 | " ax.text(0.5, -0.17, \"(a)\", size=9, ha=\"center\", weight=\"bold\", \n", 410 | " transform=ax.transAxes)\n", 411 | " \n", 412 | "for k in range(5):\n", 413 | " ax = plt.subplot(2, 5, 5 + k+1)\n", 414 | " ax.imshow(seaquest_images[k])\n", 415 | " ax.axis('off')\n", 416 | " if k == 2:\n", 417 | " ax.text(0.5, -0.17, \"(b)\", size=9, ha=\"center\", weight=\"bold\", \n", 418 | " transform=ax.transAxes)\n", 419 | "\n", 420 | "plt.savefig('figures/atarisaliencymaps.pdf', dpi=300)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": {}, 426 | "source": [ 427 | "# timeseries plots" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 117, 433 | "metadata": {}, 434 | "outputs": [], 435 | "source": [ 436 | "env_id = \"BreakoutNoFrameskipNoScore-v4\"\n", 437 | "folder = \"../agents-custom\"\n", 438 | "algo = \"ppo\"\n", 439 | "n_timesteps = 10000\n", 440 | "num_threads = -1\n", 441 | "n_envs = 1\n", 442 | "exp_id = 1\n", 443 | "verbose = 1\n", 444 | "no_render = False\n", 445 | "deterministic = False\n", 446 | "load_best = True\n", 447 | "load_checkpoint = None\n", 448 | "norm_reward = False\n", 449 | "seed = 0\n", 450 | "reward_log = ''\n", 451 | "env_kwargs = None" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 11, 457 | "metadata": {}, 458 | "outputs": [ 459 | { 460 | "name": "stdout", 461 | "output_type": "stream", 462 | "text": [ 463 | "Stacking 4 frames\n" 464 | ] 465 | } 466 | ], 467 | "source": [ 468 | "# Sanity checks\n", 469 | "if exp_id > 0:\n", 470 | " log_path = os.path.join(folder, algo, '{}_{}'.format(env_id, exp_id))\n", 471 | "else:\n", 472 | " log_path = os.path.join(folder, algo)\n", 473 | " \n", 474 | "found = False\n", 475 | "for ext in ['zip']:\n", 476 | " model_path = os.path.join(log_path, f'{env_id}.{ext}')\n", 477 | " found = os.path.isfile(model_path)\n", 478 | " if found:\n", 479 | " break\n", 480 | "\n", 481 | "if load_best:\n", 482 | " model_path = os.path.join(log_path, \"best_model.zip\")\n", 483 | " found = os.path.isfile(model_path)\n", 484 | "\n", 485 | "if load_checkpoint is not None:\n", 486 | " model_path = os.path.join(log_path, f\"rl_model_{load_checkpoint}_steps.zip\")\n", 487 | " found = os.path.isfile(model_path)\n", 488 | "\n", 489 | "if not found:\n", 490 | " raise ValueError(f\"No model found for {algo} on {env_id}, path: {model_path}\")\n", 491 | "\n", 492 | "if algo in ['dqn', 'ddpg', 'sac', 'td3']:\n", 493 | " n_envs = 1\n", 494 | "\n", 495 | "set_random_seed(seed)\n", 496 | "\n", 497 | "if num_threads > 0:\n", 498 | " if verbose > 1:\n", 499 | " print(f\"Setting torch.num_threads to {num_threads}\")\n", 500 | " th.set_num_threads(num_threads)\n", 501 | "\n", 502 | "is_atari = 'NoFrameskip' in env_id\n", 503 | "\n", 504 | "stats_path = os.path.join(log_path, env_id)\n", 505 | "hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=norm_reward, test_mode=True)\n", 506 | "env_kwargs = {} if env_kwargs is None else env_kwargs\n", 507 | "\n", 508 | "log_dir = reward_log if reward_log != '' else None\n", 509 | "\n", 510 | "env = create_test_env(env_id, n_envs=n_envs,\n", 511 | " 
stats_path=stats_path, seed=seed, log_dir=log_dir,\n", 512 | " should_render=not no_render,\n", 513 | " hyperparams=hyperparams,\n", 514 | " env_kwargs=env_kwargs)\n", 515 | "\n", 516 | "model = ALGOS[algo].load(model_path, env=env, device=device)\n", 517 | "\n", 518 | "obs = env.reset()" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 12, 524 | "metadata": {}, 525 | "outputs": [], 526 | "source": [ 527 | "rm = RewardModel(env, device)\n", 528 | "rm.load_state_dict(th.load(f\"../reward-models/BreakoutNoFrameskip-v4-reward_model.pt\"))\n", 529 | "rm = rm.to(device)\n" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": 20, 535 | "metadata": {}, 536 | "outputs": [ 537 | { 538 | "data": { 539 | "application/vnd.jupyter.widget-view+json": { 540 | "model_id": "72d48b30723146c39db003f84204db1b", 541 | "version_major": 2, 542 | "version_minor": 0 543 | }, 544 | "text/plain": [ 545 | "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" 546 | ] 547 | }, 548 | "metadata": {}, 549 | "output_type": "display_data" 550 | }, 551 | { 552 | "name": "stdout", 553 | "output_type": "stream", 554 | "text": [ 555 | "\n" 556 | ] 557 | } 558 | ], 559 | "source": [ 560 | "radius = 5\n", 561 | "stride = 4\n", 562 | "\n", 563 | "TIMESTEPS = 200\n", 564 | "\n", 565 | "breakout_obs = env.reset()\n", 566 | "breakout_rewards = []\n", 567 | "for i in tqdm(range(TIMESTEPS)):\n", 568 | " action, _states = model.predict(breakout_obs, deterministic=False)\n", 569 | " breakout_obs, reward, done, info = env.step(action)\n", 570 | " if done:\n", 571 | " breakout_obs = env.reset()\n", 572 | " if i == TIMESTEPS - 1:\n", 573 | " sal = compute_saliency_map(rm, breakout_obs)\n", 574 | " screenshot = env.render(mode='rgb_array')\n", 575 | " breakout_image = add_saliency_to_frame(screenshot, sal)\n", 576 | " breakout_rewards.append(rm(breakout_obs).item())" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 21, 582 | "metadata": {}, 583 | "outputs": [], 584 | "source": [ 585 | "env_id = \"SeaquestNoFrameskipNoScore-v4\"\n", 586 | "folder = \"../agents-custom\"\n", 587 | "algo = \"ppo\"\n", 588 | "n_timesteps = 10000\n", 589 | "num_threads = -1\n", 590 | "n_envs = 1\n", 591 | "exp_id = 1\n", 592 | "verbose = 1\n", 593 | "no_render = False\n", 594 | "deterministic = False\n", 595 | "load_best = True\n", 596 | "load_checkpoint = None\n", 597 | "norm_reward = False\n", 598 | "seed = 0\n", 599 | "reward_log = ''\n", 600 | "env_kwargs = None" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 22, 606 | "metadata": {}, 607 | "outputs": [ 608 | { 609 | "name": "stdout", 610 | "output_type": "stream", 611 | "text": [ 612 | "Stacking 4 frames\n" 613 | ] 614 | } 615 | ], 616 | "source": [ 617 | "# Sanity checks\n", 618 | "if exp_id > 0:\n", 619 | " log_path = os.path.join(folder, algo, '{}_{}'.format(env_id, exp_id))\n", 620 | "else:\n", 621 | " log_path = os.path.join(folder, algo)\n", 622 | " \n", 623 | "found = False\n", 624 | "for ext in ['zip']:\n", 625 | " model_path = os.path.join(log_path, f'{env_id}.{ext}')\n", 626 | " found = os.path.isfile(model_path)\n", 627 | " if found:\n", 628 | " break\n", 629 | "\n", 630 | "if load_best:\n", 631 | " model_path = os.path.join(log_path, \"best_model.zip\")\n", 632 | " found = os.path.isfile(model_path)\n", 633 | "\n", 634 | "if load_checkpoint is not None:\n", 635 | " model_path = os.path.join(log_path, f\"rl_model_{load_checkpoint}_steps.zip\")\n", 636 | " found = 
os.path.isfile(model_path)\n", 637 | "\n", 638 | "if not found:\n", 639 | " raise ValueError(f\"No model found for {algo} on {env_id}, path: {model_path}\")\n", 640 | "\n", 641 | "if algo in ['dqn', 'ddpg', 'sac', 'td3']:\n", 642 | " n_envs = 1\n", 643 | "\n", 644 | "set_random_seed(seed)\n", 645 | "\n", 646 | "if num_threads > 0:\n", 647 | " if verbose > 1:\n", 648 | " print(f\"Setting torch.num_threads to {num_threads}\")\n", 649 | " th.set_num_threads(num_threads)\n", 650 | "\n", 651 | "is_atari = 'NoFrameskip' in env_id\n", 652 | "\n", 653 | "stats_path = os.path.join(log_path, env_id)\n", 654 | "hyperparams, stats_path = get_saved_hyperparams(stats_path, norm_reward=norm_reward, test_mode=True)\n", 655 | "env_kwargs = {} if env_kwargs is None else env_kwargs\n", 656 | "\n", 657 | "log_dir = reward_log if reward_log != '' else None\n", 658 | "\n", 659 | "env = create_test_env(env_id, n_envs=n_envs,\n", 660 | " stats_path=stats_path, seed=seed, log_dir=log_dir,\n", 661 | " should_render=not no_render,\n", 662 | " hyperparams=hyperparams,\n", 663 | " env_kwargs=env_kwargs)\n", 664 | "\n", 665 | "model = ALGOS[algo].load(model_path, env=env, device=device)\n", 666 | "\n", 667 | "obs = env.reset()" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 23, 673 | "metadata": {}, 674 | "outputs": [], 675 | "source": [ 676 | "rm = RewardModel(env, device)\n", 677 | "rm.load_state_dict(th.load(f\"../reward-models/SeaquestNoFrameskip-v4-reward_model.pt\"))\n", 678 | "rm = rm.to(device)" 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "execution_count": 24, 684 | "metadata": {}, 685 | "outputs": [ 686 | { 687 | "data": { 688 | "application/vnd.jupyter.widget-view+json": { 689 | "model_id": "f2ab43304c5a4451ba4bb614007b0781", 690 | "version_major": 2, 691 | "version_minor": 0 692 | }, 693 | "text/plain": [ 694 | "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" 695 | ] 696 | }, 697 | "metadata": {}, 698 | "output_type": "display_data" 699 | }, 700 | { 701 | "name": "stdout", 702 | "output_type": "stream", 703 | "text": [ 704 | "\n" 705 | ] 706 | } 707 | ], 708 | "source": [ 709 | "radius = 5\n", 710 | "stride = 4\n", 711 | "\n", 712 | "TIMESTEPS = 200\n", 713 | "\n", 714 | "seaquest_obs = env.reset()\n", 715 | "seaquest_rewards = []\n", 716 | "for i in tqdm(range(TIMESTEPS)):\n", 717 | " action, _states = model.predict(seaquest_obs, deterministic=False)\n", 718 | " seaquest_obs, reward, done, info = env.step(action)\n", 719 | " if done:\n", 720 | " seaquest_obs = env.reset()\n", 721 | " if i == TIMESTEPS - 1:\n", 722 | " sal = compute_saliency_map(rm, seaquest_obs)\n", 723 | " screenshot = env.render(mode='rgb_array')\n", 724 | " seaquest_image = add_saliency_to_frame(screenshot, sal)\n", 725 | " seaquest_rewards.append(rm(seaquest_obs).item())" 726 | ] 727 | }, 728 | { 729 | "cell_type": "code", 730 | "execution_count": 136, 731 | "metadata": {}, 732 | "outputs": [ 733 | { 734 | "data": { 735 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXcAAACSCAYAAAC+Pop7AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO2deXwkVbn3v091dzp7MpNkJrMzzD7sMIAIKIsICgoiigrXq6BXr6Li8rrxquh7vVevXlFUPi6IiAhyFWRxYZNFZB8YYAaYgdmXzJ7JvvRS5/3jVHVXdzpJT5KqXnK+n099ulLbOd2p86unnvOc54hSCoPBYDCUF1ahK2AwGAyGiceIu8FgMJQhRtwNBoOhDDHibjAYDGWIEXeDwWAoQ4y4GwwGQxlS1OIuIleLyGmFrofBUAyY9mA4GIpa3A0Gg8EwNsKFroAXEVkOfA7YAMwH/ghsK2ilDIYCYdqDYTwUlbgDZwPrge8CJymlnihwfQyGQmLag2HMFJtb5ldAK7ASOKnAdTEYCo1pD4YxI8WUW8bpLHoUiAIvAycopfYXtFIGQ4Ew7cEwHorNLTMN+D4wADxgbmTDJMe0B8OYKSrL3WAwGAwTQ7H53A0Gg8EwARhxNxgMhjLEiLvPiEhYRK4SkV8Uui4GQzFg2kQwGHH3nxrgXrJ+axFZLiLfEpFvisjSwlTNYCgIpk0EgBF3n1FKdQK5ohw+B/wY+ImzbjBMCkybCIYRQyEty5qUoTS2bUsAxcwG9jnrc3IdICL/BvwbQE1NzXFLl+Y2Zjr64lgW1FdG/KhninhSsf1AH7OnVBMJ+f8TPffcc/uUUi2+F3QQjNgmLGcRwD1KgEZ0UON0oMk5RjlLL7DHWfY6+2rRtu1UYAGwGFjknF/lLBFgB/A68Bqwy9nm7g87ZeOUMwj0O0sf0OOU3eNsG/QsSc8CEAI7Xvg2kU972NHRT3tvjJmNVTTVVPhW0a3tfXT2x5k3tZr6Kn/bnZeDaRPFFueeF+FwmLq6Onp6erBtm7q6Ovr6+ojFYiOeF4lEqK2tpbu7m0QiEVBth2U70Oys58wXopT6BfALgBUrVqiVK1fmvNAhX/4LACu/c+6EV9LLbc9u5Uu3r6YT+D9nL+GTpy/0tTwR2eJrAUEipMV2uG3iWSwyHxaS43iFFuAEEHeOtUk/WHLhPlSyyxlJuoMz8UZsE/m0h6/+aTW3PL2Vb11wOP/yhnm+VfRjv13JfS/v5keXHMvbjpgx4df/9K2reOz1vaz6+lszth9MmyhJcZ87dy6XX345N910Ex0dHVxxxRXcfffdPPvssyOet3DhQi699FKuv/56Nm3aFEhdRUSAi4ElInIscAXwEeAHwKfQTecHE1FWImkTDvnnaZtaE02tf+++db6Le3FjUVN7BJbl/CbZQum13Bs8SyOZFnUYPURpkLRA16Gt98aspQFtlVc659U7Sx3Q7WyvcZZsY3LA2Za9VKDHvw6QtuCzLfcJvqX8bBPusB2/XzNS5fhU0N0vto37GiUp7pZlEY1GOeGEExgYGKCyshLLGv0OdM/L59iJQulRYt91FoDLnM9XgK9PZFkb9vaypLVuIi+ZQdLWd3RF2KK6IuRbOaWAWBFmzf0MkYrW3FZwtltmurNMJe2WAe0aqUULdD0Qcv6uBaYAh3oW1y3jPmMr0Za67ZwfBapJu23c8r1umT7P0uss/Whxdx8yPou7v21C/7B+ie5QAivooClJcd++fTvXXXcdAE1NTRx//PEFrlFxsKd7IBBxP/6QKaze3ulbOaWE5Ksirjsk32121qJyHO8+UMJoMQ+R6WLJ/nTXXcHPvv5wFK9+DSFtuftb6VLojCxJcR8YGGDjxo0AdHZ2smXLFnp6ekY9r6+vjy1btjAwMOB3FQuC35kkErYNQCRk+V5WWaGy1kf77UYTdkVaxEOkxT1M/la2ez3XQnfL8T4QvG8gJYLf7pJsgntDOHhGFPc3velNQdVjXKxZs4ampqa86vvSSy+xYMECFixYEEDNgsX2WXFdyz1sWb6XVfYM9/Plstyzhd1lNMudrGOzGc5yL2LBGg3lumX8LqcEbv8Rxf38888Pqh6GCcDvGy7tcxfsEri5C0a2ENvoiJYY2qft9bnHnX3JrOOTzr4Y6c7OXtLRMSHn70HnbxutaLZnHdLWd4J0VE0sq9zsB0gJE5zlHsxDZDyUpFvGkObav7+eWg/Kco+EjOWeE9fi9kbDgBbRfqDL+dvrOhlAd27GQLBZsmg7a7fOgQQsnredDf0zSR6wWBDfybbdzcSsCm2lW8ABdHz7AXS8egTdqarQ4u/WQZGObXfj293IGPfBUCYCn+7HLmbZDQYzQrXE+cEDr6XW/bamEx5xN9qehcpavGLpFff2rKUT6IXpU9o54dh1vP89DzN/1i6OXr6BD777QQ5t3slhU7bwwUUPMKPnAGwFNqFnVd1Cprj3oR8WrmXuDkryDlzqcpZe59gEuV0/JSr0Kq3ukx5juZcRwVnuYiz3XOTyiyu00PajXSADpE0qcbYNwqmnreHcM54hlghz6kmrOe2Ml1CVwglL13HGWS9ghWxa+jrYsm162tJ2RbvXKSNJupPV609PkCnuPQzttLV9+D0KQNA+97yjpQqAEfcywu+JV4xbJg9y/Syudez627NHmyZg9ar5PP/kQgYGK6isifHEk4cxYFcgUxQrX1pMfzjKjo6mTH95wrPuXs/tXHX99q7vvhc90MkVdyGz83U4a73ULPiARbd4pT1PcU8kEtxzzz15hRuWInV1dZx33nmEw6X9rPPbLZMp7v6WVewIEEVRmY/yZVvGXoEHULD11dnpUER3MFQEnV5rD1APFRVAUqWF2xVeb6oC53opi93bKTtI2m1jkQ6jdM9xv1j2Fy1mBcsiKK9MKdz+eamZbdu8+uqr7N9fnlM4Njc38/a3v73Q1Rg3flvTXp876DeFYn4t9RNLwTIVp94ezO8Eb/SK2+mabTWT9WmRttBdQQ6jUwaEyLS29dNG73NzzCRI+9xdUXejYyAdtTNaXplSUDIH9+3V79syqHLGQ2mbqoYM/Lfc3UFMkiovgOSQRYmFzTFqOy3kOSDOFVHvp3efV6Rd8fe6VeKkc8FUOkt2VI4r/t6Hgivubu4ab9hkiQn3wVDMonswjMeAMuJeRvjvc9efruVuK0WolN7ZJxCRJC3Rv9BaGR394NRJWZ9evOLuHhNCi7k7+jSKzjnTgM4lk229eyN1XDEPkX5LcK+Hcz13sbLOd2PpE3jCKj+d//csIEE9q1JDEnwuUKmxP6iMuJcR/kfL2IhAyHIt9zI1+/JAiaJv3n566/LM5T2asHs/3WMsdCKwanSmxwZ0QtwmdAIyr7/cO1DK7UDtBDpIhz66+dzjpDteI2S+SXgHPLniXkIqEVRuGRff25xSWGP8LiX0bzOMhu1zOFvCVoQtwXJMiUms7SQI8bvB06kKN+R/Uq7OyVzhk+6xVtbfEdL+80EyO2lttLC7o2C9GSC9MfD9aNEOk5kHfhRxf0
/+37KgBH1L2j6b7klbERljAlYj7iVMthsmiDh3SwTHcJ/UlrttW7y8ZT5WtHn0g10OZkIMV9yrPUsD2ho/gI6ica1u1/J2hT2GFvQOZ+lEi7v7YHDTF2QnGxvOLVNCKuG2Cb/vTffyfrtlxvM9SujfZsgm+8YKIreM13Kf7OGQylaoZJ4/wmhv1t7LeI/1DjLyim6czDzyrigPt3iv4Q52SjLUaldZ57h/lwhB+cJdkj43uvF8DyPuRYx3zsi5c+cO2Z/9VA8iFDJkSaqDZzJb7ghU1lQQGkuHaja5rHbQ1nUV2t9eTXqGJne2JTfs0bW8Y842N4DHmxjMjWl33TFeyz076iZnh2rhGa09AKnv4Lvl7n4GNHBwLOQt7g2RCHaFfxPOFpKGSKQoYz6y54zM3j9U3P2tT9JWhENW2udeJkPWx4SdpHPbk0io9uDOG87nnusYV9xdt0w9Oh+NO+2eG+/uiq93sJLbodqJ7lDN5XN3l+HE3R0oFQI9C15hGa09QDr9gN+iG5T7Zzw+/bzEPWpZ/PC444j09Y25oGImVlPDhlCo5EJ+s++rICx343N3EItEaCmEm3y4tvPpCrd36j7XJTNI5uhUSPvc3cXNN+OGWLrhkN7zvG4dr8iHcqyXAEH5wl2SPhs4vvvcRYS6SIRoRZ5hXyXGgGO5l5pUZf/j/a6/7frcTSik/3iF1rWkvZEw3gFLrjh7o2XcVMLZg5e8/nXvYCYvuWLnS4S0uAdT6SBCIcdKfm4ZUdjTBlCx/jEXVMyoCgukhO5gh6EdqgH53Icp3+AD3hj2OFq0XevbO0BJkc4j4xV4N+3AcOIOQ2Pws2dmKqH/s+uWCWJwkf702y0z9nPz71CtSkAkMfaSihgVLqFwAA/ZnS3+x9zahEOSGg7t9409qcmeyDqJFmlX0G0y4+CzxT173Rv94j3HK/DDZYgsoX9zUKLr4rdbxn/LHYhFkoiUp7jHw4lSun9TDI1z97e8hK0IiWcQk7/FTU5yiWzSs12ytrnYpN0vcTLTAntnW3Jxr+212nMlMitR/DZ0UuWUeoeqEuiLxkhEYmMuqJiJWfGSSmvqkv1/9/1GU9otYzpUfST7J/UmERtN3L0hjN6IF+9MS7n866N1OJXQvzmoOPegonICGcRkh8cXc1nM2KGJ/V4iMl8ptWlCL5qDIR2qflvuSVfczSAm30mpFJkZHL0drJJ1vHeQknfgkmKoSya7rFyZKieIoNqDl6AMD7810fc4dyWKnjlxrFB8zAUVM8lkAtU5tnNF5A/oOW5qSDeT2cDJE1W/4Qh6EJOOc/cMYjLq7g9esfX6xN2/veGP3nO8Ip79OZqrZYL+lYVsDxBctExw6QfGfm7+lrvlTT3pLbFY/BnZv0I+9XJercb3Ff5DKfWiiFymlLoBQET+fVxXzJOhce7+ljfE52603T+8v603XYDX9559nPt3dkdosP+ngrUHTbDRMkG4QsdKfjMxITxhNxOTKpRSvHLTq/Tu7AVg8UWLaFzYOOYKTBSb7t3M3hf2AjDr1FnMOnnmiMf3tPXy6m9fQSloaJjKu9+tGMsse0qpF53VE0XkSfTddeTBX+ngCdpyT/ncrWDKM5Ap0MP5xovo31DI9qDL15+lHOfu9eP775ZB2E8F/SRRSrF+yyBdm/Ro1eoeixiVY67ARLF5Z5Lt63SdEstsIqPUqaN/gLXr+kBBc3P1RLgY/ged90IB14z3YvkQeJx7UhG2LI/PvYhUZTJQWj934O0BPM/CgDpU/XhD8Aq67+JuJ212PLaDrsEuUBDrTEfN7Fm1h77dWlSnNB1Pbe2CMVfGJWkPsqvtr9jJQUTCtM48h3B45BweqnMvsB2AjvUdbL5vMwC1dQuZMnXF0O+0cwfwDBPYYn4KvF0pFVjHRPYDKYjcMjpxmOlQNYxK4O0B0gZOUP1BfnSoemPbx/OQys9yTyo2/nljzgmyd/xjR2p9yfJ3UTXnnLHXxiE22M76x39GPN6BZVXSePJJVFXPGvEce/8q4HkA9q/Zz/41uq4zZy+j/rChdbI7XgZ1AxMo7g8Ci0SkB3ibUurnE3Xh4Qg+t4xNRTicCoU0g5gMIxB4e4DgQiFT5fnQBryjUgMZxJQPsdh+ens2j/s6iXgXSrmjRm36+3Zg2yMbAIl4T+7tie6cdRoY2MkEv+eeg57lUoBjAN9v5sCzQipMKKQhXwJvD1Da0TKJpE17X4wqz9RLgaT8zYfNG25g88Ybx38hBe7bnG3HeOn5z486S6yyc6cQ2LPrYfbu+UeOE0aLDTtozlNK9QKIyBETeeHhGBrn7ncopO1M1pG7fIPBQ+DtAdL3pO8pf53PpK1SZclYZ7J2+MEDr3HdIxt4+AunpcspFstdqSSoic/TolRiHDpso/yeXFTzPhF5BzoX33LgaL8LDHqEqjuISUyHqmF0Am8PkLZ0x2tRJ23F75/dyntXzCESyh5UkBbdzv4487/yV7523nIuP2X+uMp8YVsHAM9ubs+ox1gZWmvDWLGAe5RSHwCuDaLAoHPLuB2qJs7dkAeBtwfwivv4bs5bntnKVX9aw2+e2DxiObs69bRXv3tqC+/9+ZN85Y6Xxlzm0tZ6AJ7bfCBdzji+hxH3iWMhUCciPwLOdTeKSLWI/LeIXCEi7/Fs/6CIXO8sx42lwMAtd9vkljHkTeDtAdL35HgNnf09gwB09efu63Onzk14CnpmUzu3PrNtzGW67Wp7R3pSpGBS/hpG40dKqTYRORZ43bP9QuBZpdQfRORO4A/OdgU8gZ5AbVOuCx7sHKp+a+1gPEllJGQ6VA35EHh7gLRFPVZf9f6eQda0dZFw1Dtk5bZ/3VDL/ngiVfnRuPL3q3jT4hYuPHZ2zv2DCa3kPQPp7LvjMaCM5T5xfFtEPoye9dIbujMH2OusV3m23wX8Gvgz8B+5LqiU+oVSaoVSakVLS8uQ/UOiZXxW28GETTRsmQmyDfkQeHuAtEU91nvzshuf5V9veIZdXdrdEg5ldpLu7xnklbau1EOkdzD/PsY7X2jjc//74rD7BxP6Wt0ecTdumeLgo8BK9EzCz3q2bwPcO9E7ldVSpc2L/cC0sRQYdG6ZwYSdYbmbOHfDCATaHl7f3c0fn9ueMnDG0hYeWrubF7frDIKPva6fP2FL2NnZz08fXo9Siu/dt463X/sYr+zsAqB3MPccF6/u7OIrd6zOaXD1DiY4+lv38+hrezO2u5Z7l9dyL5ZQyEnO39CRAV/35NcAuAO4WkSmA78TkRvQN/wZInIi+kb+n7EUOGQmJp/FdjCRJBq2jFvGkA+BtoezrtHhzstn6E7JsUSZ3PVCW2p9d5f2uSvgs7e9wFMb2zl9yTT2dA9mnOOKezxrSqZLrn+a9t4YV75lEdPrKzP2v7a7m46+OD+4fx1vXpx+AxmMu+Ke9vMHkhXSMCoXAjOAy0XkCqXURwGUUn3AFz3HuT7G74y3wCDj3JO2Ip5URMOhdIeqUXfD8ATeHmB8ce4D8SSLp9fSF0uy/UB/alvM9YUPJlKGjUuPI+79sUz3THtvLHW+9
xPSnbDuRPMurlvGLQ+KaBCTr1jo+SPDpNOfupMWFAd/A7YALwM3B1HgkMRhPpbl3njRiGVyyxjyIfD2ALnj3J9Yv49I2OL4Q6ailGJHRz+zp1QPOXcgblPlcTsC3PL0VmY06CSEvYMJ+mKZgtPriHqvZ/u29nS0i+uTH4inBdv7IFBKpdrTYGJoaEzZd6hWh0M6BrQFOARYjE7/3whUFLJmGXxeKXWJUuo/lVKrgyhwaJy7f2rrvjJqt0zu8g0GD4G3B0hbxd628IHrn+Y9P3sSgKvuXMMp3304FeroZSCeJBoJUV8ZSW3b0z2Y8sMf6IulxNzFfZh4xfvU/344te4+DLyW+1V36p9j1dYOvnbXmtT2XOL+id89z08fXj/idx6OkhD3lspK3rV0jvbGLQCWoPvcpwKREU8NkjNE5McicpKIXBpEgUPj3P0ry73xouGQsdwN+RBYe/C6LgYdEe3oi7Nlf2/GcbGEzS1PbwVIRcN4GXACBmorczs02ntj9A3TgToc7sPAK+7b2tP9yDc/tTVd90TuyBtv9MzBUBLivnOwn5t2bYTpaHFfBswFplBMlnsHOn73SbQDyXeC9Lm7N15lxDKDmAz5EFh72N+btsL7HBH9y+qdvPl7j2S0idd2d6fWD/S6uasUa3Z0cv1jG+kZiFMZtqgbSdxjB5dexX0YeC17L163++Awx1RXjO3nKwmfe8y22UG/ttzno0U+DuyhmMR9DjBdRKYAbwV+43eBQ+Pc/Ssrt+VuxN0wLIG0hxv+uYkDfen5JXqyrNydnWkL/bkt6WH97c45f3huG1+6Pe01OnxWA7XR4cW9N3ZwVvS9L++irjJCNJLbjvZmgMzlloEyF3cgPW+k5VmKZPpWETkKaAMeA45Dpzr1nSDzubuvlRk+d99KM5QyQbaHe15qY8eBtJsjkeUrXLW1I7X++Pp9qfUn1u/jqjtWM3tqZsdqZThENJxbTPf3xug7iEFLoMMr73qhjd9efkLO/b2xJB/+9TMsbq1jV9cAFWErI1oGoGqM4l4SbhkUOmJ2L3pg8mvoSZc60BZ84fkMunYfBGqB9wdRaJD53FOWe8Qyg5gMoxFYe5jVWDUk9tzLqq3aWq+MWNz/yu7U9t8/u43uwQSv7uzi0Oaa1PbKiEXPYG5R2dM9SCw5ttfj4dwyAA+v28vPH91ILGEzvX7oc3CslntpifsetLivQ49zawdiI5wXHI8rpW4DHlVKfZkxjjg9WIKcQzUdLePJLRNIJmVDvoSsQLp68iGw9pArpDGjIhv0jGzfeMdhwx6zbEY9jdU6MqMyEqJzmGRhOw705dyeD7s9Hbh1w7h9ABa0DJ1OtCoyNgdL6Yh7L1rcN6DFfStwgGIR98+JyP3AFSLyAHBbEIUOtdyD6VA1uWWKk+royEIXIIG1h1lTqkbc/+rOLuY1VXPekTOGPWZuUzUzGvR1opEQZyzN/Sza1zN2sVm/R6fXueuTJ/PPL58xZP9s53vkGrRU3pY76CwUe4HN6BxzOygmt8wnlFJvVUodrpQ6C/hkEIUGmc/d26Fq0g/4Q2U0wfvOXcuMabmnjByNeCJOOBQe94xAE0Bg7WF248jiDnD+UTOpqxw+ZnpGQyXNtToyoyoS4r0r5rDmm2fz7FVvYZmTzmAkwtbov/eGvfp/OqOxkoaqCFecvjC1b2lrHbd97CSAnJ25NdFyF/ckMIDOL9eFdtPEKIpePaXUoyP97RfZbpEgLPdo2MLNgmp87hNLJJzkgres54K3jG3QSjwZZ+aUmVSECxtCFmR7cN0pIzHdGWHqMq0u06/dUhulqUb/ZpXOCOzaaJiWuijLZtRlHHvqomYWTst0nXjfHn76gWNT65efMp/rLtF/u+kM3OiYL5y9JPU2MbOxilmNVfzmshP4jwsOJ2RJhsiXt1vGi/Isk5wg87m7HULeDlVjuU8cZ560hc9+6HlqquOceNQuvvaJpzjr5M2sOHxXXucLQmtjKwm7ePJxBEE+kSSt9Zni3pot9vVRptZowc+OVMm2pI+bN2VITqXZHnE/1+P+ufj4OZy1fDoAm/b1ErKESk/oo7vudqK+eXELTbVRVl/9Vp696i2p48rfLePFiAoQ7ExM7k1fETKDmPygvi5GZaUW5hktvZyyYgdXXPoCi+cfGOVM5/zqOvoG+2hrb2MwPnz0SLnhjRMfjulZ4v6Fty7hnMNaU3+31FbS5Lhl3IRfLu6DYGlrHa31lVx03OwhETMzG3K7hirDoYz5Vy89cW7G3xVhvd5Uk/kmUV0RznholX+cu2EIQeaWSYl72DKDmFzEBsuJex7nT3HX3+fz0rqpVFelO5HCYRuUgIweW93Z7zwECu5uD5axiPusKVX87F+O45Av/wWAlrooS1u1+6W5NlNoT17QDKyjvTfGM4417c7S5JLtpnGpzBq4dFpWR607dsR9sAzHWOPcjbiXMEHmlnGtlYqwZSbIBi3sU9dD1NOwx/F72MCGPnRfUjbN+3JsNMDIwnfNxUexZkdXqrM0dY7zQDikqZrN+/uoqghx5rLp3Hz5iZy0oCnj2MNnNdBUU8EXzl6S2padu33pMJ2u0awHz8KsMEd3QFRTbe4xXiFLSNqK6oqxybQR9yJmtDkjs6fg8je3jHHLZGJD/VNQldWEJvNP4jO52kPlCJb7inlTedcxQ+crdc+54xMnZ8Sfn7KoecixIUt47mtnZWy7aMVsfv7oxtTfbmfs0HIyLfdZWZE9biqD+mFy2fz106fy2Ot7CeURjZMLI+5FjFLqF8AvAFasWDFENoIMhYwnbSIhQURMhyoACqw9SFgIS9jZokgkR+/QDFthxPGfxO3RY3lDEsISLRS2skmq0d00ESuSrlMenazeOiVUoigjoXK1B68PO5vhrHrXcp9aU8HUYYR5JL509lKuPHMxy75+LzBUxF0qnLr95rITeHVn15DJOdyJPobLZbOktY4lrXU59+WDEfcSJshBTLGEnbpZzSAmsCyLU+afwqzWWRwy5RBQ0B3rZv3+9azft572vnba+9v1sWIxtWoqs+pnMbthNsumLSNiRVAo1u5dy67uXWxo30BHf0dK7Ouj9TRUNrBs2jLmNMxhWo321+7p3cO2jm28svcVuga66BrUc3lGQhEaKxtZ2LSQ1tpWlrQsQRBiyRhr965lW+c22rraaO9vx1b6LWxq1VSaqptY0LSARU2LqK3QboPNBzazq2cXr+x5ha7BLvrj/ZQiw3VERsPjiyOxLKGqIsTTXz0TkeGzObp9U29e3JIxnZ7LhcfMYtXWDuZ70h9MJEbcS5ihce7+lRVL2KnefZNbRg/1P/mQk2luSr/KN9c0M3/KfM449AwGE4O0dbWhlCJshZlVP4uQFRqSIqClpkVb43aSXT27GIgPpLbXRmsJSShjUFJzTTPLWpZx5sIz6RnsYW+vnmS5MlJJa20rIStt5bvMrJ9J0k6SsBPs6NpB0k4iIsysm0k0HB1Sp+aaZpRSnLPoHPb376droGtCf7ugqMxKAPaj9x3NLU9vHWJBjxW3o3ZPd2Zu+B+89yju
f3l3rlMyuPQN8/jAifPG7HYZDSPuJUyQ+dxzifukdMtkt0P3b89vEbJCVEeqWdi0kHywxMIKWcxpmJNfFUS7ghqrGmmsaszrHPfBcujUQ/MvIxRmeu10ptdOz+ucYiNbxM8/ehbnHz1rwsvJziJ54bGzufDYob7+bESEkI/RTUbcS5ggU/7Gkl5x97+8UqK7u5v169fT2NgICuJx7VoJh8Ikk0mqqqoIhUK0tramrHDLspg5cyaWNXYXQceqx4i17x12f3zaPJINQ90BhollOJ97oTHiXsJki+vj6/fz/NYDHDt3yoSXFUvaqc4rM81eJp2dnWzbuo22HW00TmmkfX87jY2NRMIROjo62LNnD1VVVVx00UWpc8LhMIsXL6aiYmypApRSvP7XX9L96vPDHtN93NuMuAdAxQiduoXEiHsJk0tcv3n3y9x1xSkTXpa3Q9VMkKSMrXsAAA+vSURBVJ1JXV0d06dPp3FKI9XV1cybN4/BgUH6evuor6/nsMMOI5HIjFhJJpNs2rRpXJb7gWkLGZThM0EmprQOu88wcbjGTsU4O2onGiPuJUwut4hfGQFjCTsVZZDO527EHaChoYET33BieoPK+syBK+7jYtoCvRgKzs8uPZalraNnkAwSI+4lTC7L2a9sr6ZDdQwIZlBTAFz/wRWs293N9+5bB8Dnz1pMz0HOdTpezjl8+HzxhcKIewmTS1z96nyPezpUMR2q+WF+nkB4y/LpHH/I1JS4f+rMRQWuUXFQXE4iw0GRS1wtv9wyOaJljLYbioWQnzGFJcqIlnuvpUfJxC0buwBmSEiE6ZWVQ6xRBewZGCAxQeoitk1lfz8VidLKhe1a7ucdOYM/v7QT8NktE8p2y0xedU8mkzz++ONUV40wtV0Z/Tyf/exnC12FEclnNqTJxoji/midTlFnx+IMWMHfqU3RKL888USqw5nVjNk2H336abb29k5IOdGBAY586imioaKZYDgvXJ/7f114RFrcfXLMGJ97JrZt849H/5F757h/F0EkRH3DcsLhWpRK0tX1KpZEqK1fhFI2/X3bGehvG29BZYMR96GMKO6283slUQUxQnricW7YsIFIVrhYUik6YhM8M3YyWXKpsF3LOcMV49OXGPSI+6TOLRPAV26cegz1DYeRTPYTiTSQTPZjhSqZ2nQ8fX3bsO044XAtzS0n09/fRjzWSU/369j25JmkIxu/hvCXMiOK+45HVwKgEkmS/cHfOH3JJLdv2+Z7OXsGBvjUypWpDoiXfC9xYnAtZ6+4Z08TNlHEk2m3jNuQsvNaGyaGaHQalVWtdHetI5HoJRKpZ/qMt7Jn19+prJpBNNpARcVUwuEaqmvnE4228Orqb01qcRcR3nnUTC44Zmahq1I0jCjur910T1D1KCgx2+aVzs5CV+OgcS1nr+HeHxs9HWw+tPfGuHPVDj5w4lwqI6GMDtVIyGJmQyVb9mu3nVJq1Ph6pRR3vrCDkxc0My1rZhxDJh0HVtHT/TrhcC2JRA8iFrFYB8qOEa1qJRyuIRHvJlrZykB/G7YdJ5kcGP3CZc617z+m0FUoKkwoZAkRT9pc88BrPLu5nXlNNalBRF7Lva2jn7tfbGPRtFoG4klmT6lm3a7unBMR5GLD3h6iYYtrHnid25/fzpq2TmY2VNHRF88oZ+H0Ol7f080TG/bx8d8+x5SaCr7xjuWcsVQnmRpMJKkIWfQMJvjb6l38c/0+7n6xjTccOpXf/9tJOcs+0BtjTVsnpy6a3EPmBwd2M0jurILxeNoI6el+PagqGUoQI+4lxE8eWs91j2zgqNkN/HX1TvocK90SOPeIGfxl9U66BxN8+tZVQ8592+GtKAUXHTebM5dNY+2ubq57ZAPnHtFK10CC0xa3cPPTW7n275mCccfzO1LrR85uSK0vmlbLb5/az+due5GKsEU0bHHZjStZPL2Wy0+Zz7fueQXQ/SMDnnzXT21s50BvjCk1Fezs7Gd/T4yF02rZ2z3IO3/yTxK24umvnjnmqcUMBoPGtKACIiLVwNXAVmC3UuoPIx3/kVPns2h6LecdOZMdHf3c8dx2BhJJwiGLn15yLPV3rObWZ7by4/cfw6eyBP5va3YRDVvc+/KujO33vDg04mJKdYT5zTV87z1H8d2/reXMZdO4+PjMaf7OXDaNX/1zE7u6Brj2/cdw+pIWPn3rKp7cuJ8v3b46Z/3rK8N0DSQ45v89wOGz6lmzIzNPeEXY4o8fP8kI+yTlYNuDYWRkpORPllWA+MciwLbtQLreReRSYFAp9QcRuVMpdcFwx65YsUKtXLlyxOt1D8TZ3TXAwml17OocoL03RjRi0R9Lsq9nkOPmTeHK37/A39fu4eSFTXzx7KX88MHXqAhbPL+1g6NmN3Dqohb+9Y2H5FX/+1/exZb9fXz45EMIO52tu7sG+O7f1rJsRj2XnTKfnzy0nmsefA2Amy47gYfW7uHGJzZTFw3zr288hAXTarjrhTae2dTO185bzvtPSD9EROQ5pdSKvCoTEKZN+MdEt4dy5GDaxIjibvAXEfkK8KRS6hERuU8pdXbW/tSEwMASYB3QDOwLtqaB437HeUqpye2An0SMsT1A+bcJ7/fLu02Y99/Csg1w/1FDJqr0TgjsIiIri82anWgmw3c05OSg2wOU//0y1u9nxL2w3AFcLSLTgd8VujIGQ4Ex7WECMeJeQJRSfcAXC10Pg6EYMO1hYjFZIUuPIa+lZchk+I6GiaPc75cxfT/ToWowGAxliLHcDQaDoQwx4l6kiEhYRK4SkXJ/5TQY8sK0iYPDdKgWLzXAvcC/A4jIxegwsTnoUXxCCY/mE5F3AkuBCPAa2tAom+9n8AXTJg7i+xmfexEjIocA/1cp9RERuUspdb6IvAf9z7fIczRfMSIis5RSO0SkAfgVECmn72fwB9Mm8v9+xi1TOrh5cvcCc9FP873OtqqC1GgcKKXcjGTvAr5PmX0/QyCU1T0z0W3CuGVKBzdhdwv6tcx9ZYMco/lKARE5F9gI7KAMv5/Bd8runpnINmHcMkWK6NkvvgicB3wGWER5+RcvAL4EvAjUAXdTRt/PMPGYNmF87gaDwTDpMT53g8FgKEOMuBsMBkMZYsTdYDAYyhAj7gaDwVCGGHE3GAyGMsTEuZcJIvJnYC1wLHrwwxPooczvBW7xc7SeE8L1glJqs19lGAwHg2kPRtzLiV8ope4WkSuBRqXU1SLyTqVUn4i8y+eyLwA6gM0+l2Mw5Mukbw9G3MsEpdTdubaJyIXA14GjReSr6BvvIeBE4MfAGcBhwOVKqY0isgL4GHqUXL1S6iveazrJmk4AYsBO4EHgaOBDInKoUuoGEfkyOslTPXogxhbgenQypD3AEcDHlVI7J/hnMBgA0x7cL2yWMlqAK4Grs7Y94lnfjB7pdhTwqLPtfODzzvpKYI6zfgNwata17gJOctbf6HzeCJzmrB8BPOisR4HVzvrVwIed9fcBPyz0b2WW8l8mc3swHaqTj81K31GdwCZnWyd6uDPAYuASx9pIAg1Z538RuFJEnkBbItksA2q
d8z8LbBGRCrds53MjsGQCvovBMF7Ktj0Yt4whm7XA9Uqpfc4r6f6s/bOUUheLyCL0bPX3om96EZHDgVeAvUqp76A3vkcpFdNpQZjnXGMB+pXUYCh2SrY9GHEvI0TkUOAsoEpEFiulXhOR84F5zmets34K8GbgSOcGvARYICJzgI8D3xaRNqARncjIy+kicjIwA/iZs+0h4CPAAaXUFSLyiIh8GwgBr3rOPVxEvgkc45RjMPjGZG8PJnGYIRBE5Gq0r/ORAlfFYCg4QbQH43M3+I5jQb0J+BcRCRW6PgZDIQmqPRjL3WCYZIhISCmVHGafoHXBDrhahgmm7C33kZ6Mxoo0TDZE5H1oH/J5ItIpIjdmHRIGbhGRY4c5PyIivxKRn4vID0XkIRF5h9/1PlhE5EMi8qFC16OQlLW4uzfyCIdcJCL/KSJDfgcRmSci94uIEpF7ReQ2EdkmIp+coLq9wbneIxNwrUYRuVpEjp6AqhnKFBF5N3puzquVUn8GDmQfo5SKA98FHnTcB9l8ADheKfUxpdSVwNfQMwUVGx9ylklL2Yq790Ye7hil1G3o0WjX5Ni3BbjF+fNmpdTFwAvAtSJSPd76KaWeAjaM9zoOjcA30CPjDIbh+BbwkFJqwLNtuYj8r4hsd4flK6VWAX3A53NcYyo6yuMmp42tUUpdByAiF4nIYyLyGxG5XUTqRWS2iPxdRL4rIi+KyAecY2c6x/5TRH7tGFF/dSxu5Rgrp4tIl/t2ISKHO8f/SESeFJHlIhIVkT+JyF1OGTc70S8LgYXO20WNXz9oUVPoEWQ+jkx7GbjJWY8AjwL/DTwLfM5z3EeBODAzxzU+BCjgUvQotqfQs49XoK0bBfwJuBNYj7ZgHgB+AjyDM5oNPdT4R+gY2JuBsLP9EZzRcsDj6BFwc4BTgCfRD6cngDcCy9Exs5vRQ5mf0v++1Gg3BdwPfLTQv71Zim9x7hkFfMezbTNwn7N+HzqG23L+fgZ4Jsd1pjttSDlLP3AxWvT7nTZW7+z7OnoO0BOdc3+AnvsT4FrnmGbg/c76ac4+hTOq1Knjjdl1Ap5HhxyucI7/BjANOM/Zn2pbk3UpS8vdeVIvB9qcTQr4L6XUF4HfAt/3PM13oP2MI1m9FwN/RY9Wu1IpFVNKufGuUaUzzH0R+B/gKKXUFcB24DfOMXcppT7jHHMJOvbWW9/3AbcqpS5zzvsDuhF8AZ174o/o+NhnAJRSvegHhcuNzuctSqlfjvLzGCYnVc5nImv7LudzJ1qgW5y/40CuN9QD6Dwsy4FPA13AV9CWciVwGlrgn0OP8oyhR3j+Fp2DZZpznUVo42Sfpw6jcTgwU0R+hh5FCrAabVx9zfkOficFKxnKUtwZeiNbwKki8nvgbLQV3uzsizufI7lablNKvQ24CbhZRM717HsOQCl1B/rmizo3XzXQ7jxEmkXkduAK55xpnvMPQ1s0ruA3A62k/aEH0AMkmjEYxogjoh3AlKxdjZ7PdvSbKWihfz3HpS4DLlFKvaqU+jHwv851NwKDwNNKqY+jU+s+6Bz/KeALwD8813kdHZzTCszMKqMH3Y7EUz/Qb64HnOu/FbgVmAV8D/1Quh34sIhE0G1fRKRlmL6DsqcsxT3HjXwu8FW0b/0uZ5s4n1Odz1w3cjZu1ra5nm3ekLJXABv4pFLqHOBXaBfODWir+/tZZYN2H10IvF1EPgHsI21FufVrc7Z3o5MPeesN6YeYiMjxJgrIMAzXo61uROQ8dPuY7YyePBLt0rNFZCrasv55jmu8DFzmRMv8FJ0j/dNOm/swcJqIXANchX7bfBz99nk12pBBRD4IfMfZ92fgnKwyrgHeCXwGbXytEJFlaBfqgIj8Cu363IxuD98B/i8wH/2GHneuuwDt/sl+oE0OCu0X8mtBP81XOuuHoJMC3YYWWgV83dl3DfBUjvPnAn9zjv0L8Et0B+gNaKv8EmffE8AKTzkPo62ZH6N99SFn27PoG145130TsM1Z3oDOLeH6L09B+9R/6Hy62eZWoIX+28A9zrU+6JTxhFPOLYX+7c1SnAtaCO8GPjLCMQL8GvhqgPU6DY/P3SwTs5TtICYRiaJ913crpa4f5pij0db1u5RSW4Osn8FQKERkhVJq5TD7IsAypdRLAdbnVnTa298opT4UVLnlTtmKu8soN/LhwGtKqVjA1TIYDAZfKXtxNxgMhslIWXaoGgwGw2THiLvBYDCUIUbcDQaDoQwx4m4wGAxliBF3g8FgKEP+P6jmpYO5QHdwAAAAAElFTkSuQmCC\n", 736 | "text/plain": [ 737 | "
" 738 | ] 739 | }, 740 | "metadata": { 741 | "needs_background": "light" 742 | }, 743 | "output_type": "display_data" 744 | } 745 | ], 746 | "source": [ 747 | "fig, (ax2, ax1, ax4, ax3) = plt.subplots(1, 4, figsize=(5.5, 2), gridspec_kw={'width_ratios': [1, 1, 1, 1]})\n", 748 | "\n", 749 | "ax1.plot(list(range(100, 200)), breakout_rewards[100:])\n", 750 | "ax1.set_ylim(0, 1)\n", 751 | "ax1.set_xlabel('Time step', fontsize=9, fontfamily='Serif')\n", 752 | "ax1.set_ylabel('Reward', fontsize=7, fontfamily='Serif')\n", 753 | "ax1.set_xticklabels([None, 100, 200], fontsize=7, fontfamily='Serif')\n", 754 | "ax1.set_yticklabels([0, None, 0.5, None, 1.0], fontsize=7, fontfamily='Serif')\n", 755 | "\n", 756 | "ax2.imshow(breakout_image)\n", 757 | "ax2.axis('off')\n", 758 | "ax2.set_title(\"s'\", fontsize=9, fontfamily='Serif')\n", 759 | "\n", 760 | "asp = np.diff(ax1.get_xlim())[0] / np.diff(ax1.get_ylim())[0]\n", 761 | "asp /= np.abs(np.diff(ax2.get_xlim())[0] / np.diff(ax2.get_ylim())[0])\n", 762 | "ax1.set_aspect(asp)\n", 763 | "ax1.text(-0.33, -0.45, \"(a) Breakout\", size=9, ha=\"center\", weight=\"bold\", fontfamily='Serif', \n", 764 | " transform=ax1.transAxes)\n", 765 | "ax1.yaxis.labelpad=-1\n", 766 | "\n", 767 | "ax3.plot(list(range(100, 200)), seaquest_rewards[100:])\n", 768 | "ax3.set_ylim(0, 1)\n", 769 | "ax3.set_xlabel('Time step', fontsize=9, fontfamily='Serif')\n", 770 | "ax3.set_ylabel('Reward', fontsize=7, fontfamily='Serif')\n", 771 | "ax3.set_xticklabels([None, 100, 200], fontsize=7, fontfamily='Serif')\n", 772 | "ax3.set_yticklabels([0, None, 0.5, None, 1.0], fontsize=7, fontfamily='Serif')\n", 773 | "\n", 774 | "ax4.imshow(seaquest_image)\n", 775 | "ax4.axis('off')\n", 776 | "ax4.set_title(\"s'\", fontsize=9, fontfamily='Serif')\n", 777 | "\n", 778 | "asp = np.diff(ax3.get_xlim())[0] / np.diff(ax3.get_ylim())[0]\n", 779 | "asp /= np.abs(np.diff(ax4.get_xlim())[0] / np.diff(ax4.get_ylim())[0])\n", 780 | "ax3.set_aspect(asp)\n", 781 | "ax3.text(-0.33, -0.45, \"(b) Seaquest\", size=9, ha=\"center\", weight=\"bold\", fontfamily='Serif', \n", 782 | " transform=ax3.transAxes)\n", 783 | "ax3.yaxis.labelpad=-1\n", 784 | "\n", 785 | "plt.subplots_adjust(wspace=0.6, top=0.99, bottom=0.17, left=0.05, right=0.95)\n", 786 | "\n", 787 | "plt.savefig('figures/atari-noscore-timeseries.pdf', dpi=300)" 788 | ] 789 | }, 790 | { 791 | "cell_type": "code", 792 | "execution_count": null, 793 | "metadata": {}, 794 | "outputs": [], 795 | "source": [] 796 | }, 797 | { 798 | "cell_type": "code", 799 | "execution_count": null, 800 | "metadata": {}, 801 | "outputs": [], 802 | "source": [] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": null, 807 | "metadata": {}, 808 | "outputs": [], 809 | "source": [] 810 | } 811 | ], 812 | "metadata": { 813 | "kernelspec": { 814 | "display_name": "Python 3", 815 | "language": "python", 816 | "name": "python3" 817 | }, 818 | "language_info": { 819 | "codemirror_mode": { 820 | "name": "ipython", 821 | "version": 3 822 | }, 823 | "file_extension": ".py", 824 | "mimetype": "text/x-python", 825 | "name": "python", 826 | "nbconvert_exporter": "python", 827 | "pygments_lexer": "ipython3", 828 | "version": "3.7.9" 829 | } 830 | }, 831 | "nbformat": 4, 832 | "nbformat_minor": 4 833 | } 834 | --------------------------------------------------------------------------------