├── common
│   ├── __init__.py
│   ├── layers.py
│   ├── wrappers.py
│   ├── replay_buffer.py
│   └── replay_buffer_dtm.py
├── requirements.txt
├── README.md
├── LICENSE
├── .gitignore
├── controller.py
├── em.py
├── dqn_mbec.py
└── dqn.ipynb
/common/__init__.py: -------------------------------------------------------------------------------- 1 | from . import layers 2 | from . import wrappers 3 | from . import replay_buffer -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ale_py==0.7.5 2 | gym==0.26.2 3 | ipython==8.7.0 4 | matplotlib==3.6.2 5 | numpy==1.19.2 6 | opencv_python==4.6.0.66 7 | pyflann==1.6.14 8 | pyflann_py3==0.1.0 9 | tensorboard_logger==0.1.0 10 | torch==1.4.0 11 | tqdm==4.63.0 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AJCAI22-Tutorial 2 | ### Demo code for AJCAI22-Tutorial 3 | [Tutorial website](https://thaihungle.github.io/talks/2022-12-05-AJCAI) 4 | [Conference website](https://ajcai2022.org/tutorials/) 5 | 6 | Install packages following requirements.txt 7 | ``` 8 | pip install -r requirements.txt 9 | ``` 10 | ### DQN 11 | Follow dqn.ipynb 12 | Reference: https://github.com/higgsfield/RL-Adventure/blob/master/1.dqn.ipynb 13 | 14 | 15 | ### MBEC 16 | - Create log and model folders to save training info 17 | ``` 18 | mkdir log 19 | mkdir model 20 | ``` 21 | - Run the script dqn_mbec.py using 22 | 23 | ``` 24 | python dqn_mbec.py --task MountainCar-v0 --rnoise 0.5 --render 0 --task2 mountaincar --n_epochs 100000 --max_episode 1000000 --model_name DTM --update_interval 100 --decay 1 --memory_size 3000 --k 15 --write_interval 10 --td_interval 1 --write_lr .5 --rec_rate .1 --rec_noise .1 --batch_size_plan 4 --rec_period 9999999999 --num_warm_up -1 --lr 0.0005 25 | ``` 26 | - Monitor training using tensorboard 27 | ``` 28 | cd log 29 | tensorboard --logdir=./ 30 | ``` 31 | Reference: https://github.com/thaihungle/MBEC-plus 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tony 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
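The README monitors training with TensorBoard. As a minimal, self-contained sketch (not code from this repo; the run name and logged quantity are illustrative), this is how the `configure`/`log_value` calls that dqn_mbec.py imports from tensorboard_logger write event files under `log/`, which the README's `cd log` + `tensorboard --logdir=./` step then displays:
```
from tensorboard_logger import configure, log_value

configure("log/demo-run")          # point the logger at an illustrative run directory under log/
for step in range(100):
    episode_reward = step * 0.1    # placeholder value; a real run would log returns or losses
    log_value("reward", episode_reward, step)
```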
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | log/ 131 | model/ -------------------------------------------------------------------------------- /common/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | class NoisyLinear(nn.Module): 8 | def __init__(self, in_features, out_features, use_cuda, std_init=0.4): 9 | super(NoisyLinear, self).__init__() 10 | 11 | self.use_cuda = use_cuda 12 | self.in_features = in_features 13 | self.out_features = out_features 14 | self.std_init = std_init 15 | 16 | self.weight_mu = nn.Parameter(torch.FloatTensor(out_features, in_features)) 17 | self.weight_sigma = nn.Parameter(torch.FloatTensor(out_features, in_features)) 18 | self.register_buffer('weight_epsilon', torch.FloatTensor(out_features, in_features)) 19 | 20 | self.bias_mu = nn.Parameter(torch.FloatTensor(out_features)) 21 | self.bias_sigma = nn.Parameter(torch.FloatTensor(out_features)) 22 | self.register_buffer('bias_epsilon', torch.FloatTensor(out_features)) 23 | 24 | self.reset_parameters() 25 | self.reset_noise() 26 | 27 | def forward(self, x): 28 | if self.use_cuda: 29 | weight_epsilon = self.weight_epsilon.cuda() 30 | bias_epsilon = self.bias_epsilon.cuda() 31 | else: 32 | weight_epsilon = self.weight_epsilon 33 | bias_epsilon = self.bias_epsilon 34 | 35 | if self.training: 36 | weight = self.weight_mu + self.weight_sigma.mul(Variable(weight_epsilon)) 37 | bias = self.bias_mu + self.bias_sigma.mul(Variable(bias_epsilon)) 38 | else: 39 | weight = self.weight_mu 40 | bias = self.bias_mu 41 | 42 | return F.linear(x, weight, bias) 43 | 44 | def reset_parameters(self): 45 | mu_range = 1 / math.sqrt(self.weight_mu.size(1)) 46 | 47 | self.weight_mu.data.uniform_(-mu_range, mu_range) 48 | self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.weight_sigma.size(1))) 49 | 50 | self.bias_mu.data.uniform_(-mu_range, mu_range) 51 | self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.bias_sigma.size(0))) 52 | 53 | def reset_noise(self): 54 | epsilon_in = self._scale_noise(self.in_features) 55 | epsilon_out = self._scale_noise(self.out_features) 56 | 57 | self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in)) 58 | self.bias_epsilon.copy_(self._scale_noise(self.out_features)) 59 | 60 | def _scale_noise(self, size): 61 | x = torch.randn(size) 62 | x = x.sign().mul(x.abs().sqrt()) 63 | return x -------------------------------------------------------------------------------- /controller.py: -------------------------------------------------------------------------------- 1 | """LSTM Controller.""" 2 | import torch 3 | from torch import nn 4 | from torch.nn import Parameter 5 | import numpy as np 6 | 7 | 8 | class FFWController(nn.Module): 9 | """An NTM controller based on LSTM.""" 10 | def __init__(self, num_inputs, num_outputs, num_layers): 11 | super(FFWController, self).__init__() 12 
| 13 | self.num_inputs = num_inputs 14 | self.num_outputs = num_outputs 15 | self.num_layers = num_layers 16 | 17 | 18 | 19 | def create_new_state(self, batch_size): 20 | h = torch.zeros(batch_size, self.num_outputs) 21 | if torch.cuda.is_available(): 22 | h = h.cuda() 23 | return h 24 | 25 | def reset_parameters(self): 26 | pass 27 | 28 | def size(self): 29 | return self.num_inputs, self.num_outputs 30 | 31 | def forward(self, x, prev_state): 32 | return x, prev_state 33 | 34 | class LSTMController(nn.Module): 35 | """An NTM controller based on LSTM.""" 36 | def __init__(self, num_inputs, num_outputs, num_layers): 37 | super(LSTMController, self).__init__() 38 | 39 | self.num_inputs = num_inputs 40 | self.num_outputs = num_outputs 41 | self.num_layers = num_layers 42 | 43 | self.lstm = nn.LSTM(input_size=num_inputs, 44 | hidden_size=num_outputs, 45 | num_layers=num_layers) 46 | # The hidden state is a learned parameter 47 | if torch.cuda.is_available(): 48 | self.lstm_h_bias = Parameter(torch.randn(self.num_layers, 1, self.num_outputs).cuda() * 0, requires_grad=False) 49 | self.lstm_c_bias = Parameter(torch.randn(self.num_layers, 1, self.num_outputs).cuda() * 0, requires_grad=False) 50 | 51 | else: 52 | self.lstm_h_bias = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0, requires_grad=False) 53 | self.lstm_c_bias = Parameter(torch.randn(self.num_layers, 1, self.num_outputs) * 0, requires_grad=False) 54 | 55 | self.reset_parameters() 56 | 57 | def create_new_state(self, batch_size): 58 | # Dimension: (num_layers * num_directions, batch, hidden_size) 59 | lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1) 60 | lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1) 61 | # h = torch.zeros(self.num_layers, batch_size, self.num_outputs) 62 | # c = torch.zeros(self.num_layers, batch_size, self.num_outputs) 63 | # if torch.cuda.is_available(): 64 | # h = h.cuda() 65 | # c = c.cuda() 66 | # return h,c 67 | 68 | 69 | return lstm_h, lstm_c 70 | 71 | def reset_parameters(self): 72 | for p in self.lstm.parameters(): 73 | if p.dim() == 1: 74 | nn.init.constant_(p, 0) 75 | else: 76 | stdev = 5 / (np.sqrt(self.num_inputs + self.num_outputs)) 77 | nn.init.uniform_(p, -stdev, stdev) 78 | 79 | def size(self): 80 | return self.num_inputs, self.num_outputs 81 | 82 | def forward(self, x, prev_state): 83 | x = x.unsqueeze(0) 84 | outp, state = self.lstm(x, prev_state) 85 | return outp.squeeze(0), state -------------------------------------------------------------------------------- /common/wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque 3 | import gym 4 | from gym import spaces 5 | import cv2 6 | cv2.ocl.setUseOpenCL(False) 7 | 8 | class NoopResetEnv(gym.Wrapper): 9 | def __init__(self, env, noop_max=30): 10 | """Sample initial states by taking random number of no-ops on reset. 11 | No-op is assumed to be action 0. 
12 | """ 13 | gym.Wrapper.__init__(self, env) 14 | self.noop_max = noop_max 15 | self.override_num_noops = None 16 | self.noop_action = 0 17 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 18 | 19 | def reset(self, **kwargs): 20 | """ Do no-op action for a number of steps in [1, noop_max].""" 21 | self.env.reset(**kwargs) 22 | if self.override_num_noops is not None: 23 | noops = self.override_num_noops 24 | else: 25 | noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) #pylint: disable=E1101 26 | assert noops > 0 27 | obs = None 28 | for _ in range(noops): 29 | obs, _, done, _ = self.env.step(self.noop_action) 30 | if done: 31 | obs = self.env.reset(**kwargs) 32 | return obs 33 | 34 | def step(self, ac): 35 | return self.env.step(ac) 36 | 37 | class FireResetEnv(gym.Wrapper): 38 | def __init__(self, env): 39 | """Take action on reset for environments that are fixed until firing.""" 40 | gym.Wrapper.__init__(self, env) 41 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 42 | assert len(env.unwrapped.get_action_meanings()) >= 3 43 | 44 | def reset(self, **kwargs): 45 | self.env.reset(**kwargs) 46 | obs, _, done, _ = self.env.step(1) 47 | if done: 48 | self.env.reset(**kwargs) 49 | obs, _, done, _ = self.env.step(2) 50 | if done: 51 | self.env.reset(**kwargs) 52 | return obs 53 | 54 | def step(self, ac): 55 | return self.env.step(ac) 56 | 57 | class EpisodicLifeEnv(gym.Wrapper): 58 | def __init__(self, env): 59 | """Make end-of-life == end-of-episode, but only reset on true game over. 60 | Done by DeepMind for the DQN and co. since it helps value estimation. 61 | """ 62 | gym.Wrapper.__init__(self, env) 63 | self.lives = 0 64 | self.was_real_done = True 65 | 66 | def step(self, action): 67 | obs, reward, done, info = self.env.step(action) 68 | self.was_real_done = done 69 | # check current lives, make loss of life terminal, 70 | # then update lives to handle bonus lives 71 | lives = self.env.unwrapped.ale.lives() 72 | if lives < self.lives and lives > 0: 73 | # for Qbert sometimes we stay in lives == 0 condtion for a few frames 74 | # so its important to keep lives > 0, so that we only reset once 75 | # the environment advertises done. 76 | done = True 77 | self.lives = lives 78 | return obs, reward, done, info 79 | 80 | def reset(self, **kwargs): 81 | """Reset only when lives are exhausted. 82 | This way all states are still reachable even though lives are episodic, 83 | and the learner need not know about any of this behind-the-scenes. 
84 | """ 85 | if self.was_real_done: 86 | obs = self.env.reset(**kwargs) 87 | else: 88 | # no-op step to advance from terminal/lost life state 89 | obs, _, _, _ = self.env.step(0) 90 | self.lives = self.env.unwrapped.ale.lives() 91 | return obs 92 | 93 | class MaxAndSkipEnv(gym.Wrapper): 94 | def __init__(self, env, skip=4): 95 | """Return only every `skip`-th frame""" 96 | gym.Wrapper.__init__(self, env) 97 | # most recent raw observations (for max pooling across time steps) 98 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) 99 | self._skip = skip 100 | 101 | def reset(self): 102 | return self.env.reset() 103 | 104 | def step(self, action): 105 | """Repeat action, sum reward, and max over last observations.""" 106 | total_reward = 0.0 107 | done = None 108 | for i in range(self._skip): 109 | obs, reward, done, info = self.env.step(action) 110 | if i == self._skip - 2: self._obs_buffer[0] = obs 111 | if i == self._skip - 1: self._obs_buffer[1] = obs 112 | total_reward += reward 113 | if done: 114 | break 115 | # Note that the observation on the done=True frame 116 | # doesn't matter 117 | max_frame = self._obs_buffer.max(axis=0) 118 | 119 | return max_frame, total_reward, done, info 120 | 121 | def reset(self, **kwargs): 122 | return self.env.reset(**kwargs) 123 | 124 | class ClipRewardEnv(gym.RewardWrapper): 125 | def __init__(self, env): 126 | gym.RewardWrapper.__init__(self, env) 127 | 128 | def reward(self, reward): 129 | """Bin reward to {+1, 0, -1} by its sign.""" 130 | return np.sign(reward) 131 | 132 | class WarpFrame(gym.ObservationWrapper): 133 | def __init__(self, env): 134 | """Warp frames to 84x84 as done in the Nature paper and later work.""" 135 | gym.ObservationWrapper.__init__(self, env) 136 | self.width = 84 137 | self.height = 84 138 | self.observation_space = spaces.Box(low=0, high=255, 139 | shape=(self.height, self.width, 1), dtype=np.uint8) 140 | 141 | def observation(self, frame): 142 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 143 | frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) 144 | return frame[:, :, None] 145 | 146 | class FrameStack(gym.Wrapper): 147 | def __init__(self, env, k): 148 | """Stack k last frames. 149 | Returns lazy array, which is much more memory efficient. 150 | See Also 151 | -------- 152 | baselines.common.atari_wrappers.LazyFrames 153 | """ 154 | gym.Wrapper.__init__(self, env) 155 | self.k = k 156 | self.frames = deque([], maxlen=k) 157 | shp = env.observation_space.shape 158 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8) 159 | 160 | def reset(self): 161 | ob = self.env.reset() 162 | for _ in range(self.k): 163 | self.frames.append(ob) 164 | return self._get_ob() 165 | 166 | def step(self, action): 167 | ob, reward, done, info = self.env.step(action) 168 | self.frames.append(ob) 169 | return self._get_ob(), reward, done, info 170 | 171 | def _get_ob(self): 172 | assert len(self.frames) == self.k 173 | return LazyFrames(list(self.frames)) 174 | 175 | class ScaledFloatFrame(gym.ObservationWrapper): 176 | def __init__(self, env): 177 | gym.ObservationWrapper.__init__(self, env) 178 | 179 | def observation(self, observation): 180 | # careful! This undoes the memory optimization, use 181 | # with smaller replay buffers only. 
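# Rough, illustrative numbers behind the warning above: one stacked 84x84x4
# uint8 observation is about 28 KB, so a DQN-sized buffer of ~1M transitions is
# already tens of GB; the float32 cast below multiplies that footprint by 4.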
182 | return np.array(observation).astype(np.float32) / 255.0 183 | 184 | class LazyFrames(object): 185 | def __init__(self, frames): 186 | """This object ensures that common frames between the observations are only stored once. 187 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 188 | buffers. 189 | This object should only be converted to numpy array before being passed to the model. 190 | You'd not believe how complex the previous solution was.""" 191 | self._frames = frames 192 | self._out = None 193 | 194 | def _force(self): 195 | if self._out is None: 196 | self._out = np.concatenate(self._frames, axis=2) 197 | self._frames = None 198 | return self._out 199 | 200 | def __array__(self, dtype=None): 201 | out = self._force() 202 | if dtype is not None: 203 | out = out.astype(dtype) 204 | return out 205 | 206 | def __len__(self): 207 | return len(self._force()) 208 | 209 | def __getitem__(self, i): 210 | return self._force()[i] 211 | 212 | def make_atari(env_id): 213 | env = gym.make(env_id) 214 | assert 'NoFrameskip' in env.spec.id 215 | env = NoopResetEnv(env, noop_max=30) 216 | env = MaxAndSkipEnv(env, skip=4) 217 | return env 218 | 219 | def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False): 220 | """Configure environment for DeepMind-style Atari. 221 | """ 222 | if episode_life: 223 | env = EpisodicLifeEnv(env) 224 | if 'FIRE' in env.unwrapped.get_action_meanings(): 225 | env = FireResetEnv(env) 226 | env = WarpFrame(env) 227 | if scale: 228 | env = ScaledFloatFrame(env) 229 | if clip_rewards: 230 | env = ClipRewardEnv(env) 231 | if frame_stack: 232 | env = FrameStack(env, 4) 233 | return env 234 | 235 | 236 | 237 | class ImageToPyTorch(gym.ObservationWrapper): 238 | """ 239 | Image shape to num_channels x weight x height 240 | """ 241 | def __init__(self, env): 242 | super(ImageToPyTorch, self).__init__(env) 243 | old_shape = self.observation_space.shape 244 | self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.uint8) 245 | 246 | def observation(self, observation): 247 | return np.swapaxes(observation, 2, 0) 248 | 249 | 250 | def wrap_pytorch(env): 251 | return ImageToPyTorch(env) -------------------------------------------------------------------------------- /common/replay_buffer.py: -------------------------------------------------------------------------------- 1 | #code from openai 2 | #https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py 3 | 4 | import numpy as np 5 | import random 6 | 7 | import operator 8 | 9 | 10 | class SegmentTree(object): 11 | def __init__(self, capacity, operation, neutral_element): 12 | """Build a Segment Tree data structure. 13 | https://en.wikipedia.org/wiki/Segment_tree 14 | Can be used as regular array, but with two 15 | important differences: 16 | a) setting item's value is slightly slower. 17 | It is O(lg capacity) instead of O(1). 18 | b) user has access to an efficient `reduce` 19 | operation which reduces `operation` over 20 | a contiguous subsequence of items in the 21 | array. 22 | Paramters 23 | --------- 24 | capacity: int 25 | Total size of the array - must be a power of two. 26 | operation: lambda obj, obj -> obj 27 | and operation for combining elements (eg. sum, max) 28 | must for a mathematical group together with the set of 29 | possible values for array elements. 30 | neutral_element: obj 31 | neutral element for the operation above. eg. 
float('-inf') 32 | for max and 0 for sum. 33 | """ 34 | assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2." 35 | self._capacity = capacity 36 | self._value = [neutral_element for _ in range(2 * capacity)] 37 | self._operation = operation 38 | 39 | def _reduce_helper(self, start, end, node, node_start, node_end): 40 | if start == node_start and end == node_end: 41 | return self._value[node] 42 | mid = (node_start + node_end) // 2 43 | if end <= mid: 44 | return self._reduce_helper(start, end, 2 * node, node_start, mid) 45 | else: 46 | if mid + 1 <= start: 47 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) 48 | else: 49 | return self._operation( 50 | self._reduce_helper(start, mid, 2 * node, node_start, mid), 51 | self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) 52 | ) 53 | 54 | def reduce(self, start=0, end=None): 55 | """Returns result of applying `self.operation` 56 | to a contiguous subsequence of the array. 57 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) 58 | Parameters 59 | ---------- 60 | start: int 61 | beginning of the subsequence 62 | end: int 63 | end of the subsequences 64 | Returns 65 | ------- 66 | reduced: obj 67 | result of reducing self.operation over the specified range of array elements. 68 | """ 69 | if end is None: 70 | end = self._capacity 71 | if end < 0: 72 | end += self._capacity 73 | end -= 1 74 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1) 75 | 76 | def __setitem__(self, idx, val): 77 | # index of the leaf 78 | idx += self._capacity 79 | self._value[idx] = val 80 | idx //= 2 81 | while idx >= 1: 82 | self._value[idx] = self._operation( 83 | self._value[2 * idx], 84 | self._value[2 * idx + 1] 85 | ) 86 | idx //= 2 87 | 88 | def __getitem__(self, idx): 89 | assert 0 <= idx < self._capacity 90 | return self._value[self._capacity + idx] 91 | 92 | 93 | class SumSegmentTree(SegmentTree): 94 | def __init__(self, capacity): 95 | super(SumSegmentTree, self).__init__( 96 | capacity=capacity, 97 | operation=operator.add, 98 | neutral_element=0.0 99 | ) 100 | 101 | def sum(self, start=0, end=None): 102 | """Returns arr[start] + ... + arr[end]""" 103 | return super(SumSegmentTree, self).reduce(start, end) 104 | 105 | def find_prefixsum_idx(self, prefixsum): 106 | """Find the highest index `i` in the array such that 107 | sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum 108 | if array values are probabilities, this function 109 | allows to sample indexes according to the discrete 110 | probability efficiently. 
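For example (illustrative values, not taken from the code): with leaf values [1, 3, 4, 2] (total sum 10), find_prefixsum_idx(0.5) returns 0, find_prefixsum_idx(3.5) returns 1, and find_prefixsum_idx(9.9) returns 3, so drawing prefixsum uniformly from [0, sum()) selects index i with probability arr[i] / sum().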
111 | Parameters 112 | ---------- 113 | perfixsum: float 114 | upperbound on the sum of array prefix 115 | Returns 116 | ------- 117 | idx: int 118 | highest index satisfying the prefixsum constraint 119 | """ 120 | assert 0 <= prefixsum <= self.sum() + 1e-5 121 | idx = 1 122 | while idx < self._capacity: # while non-leaf 123 | if self._value[2 * idx] > prefixsum: 124 | idx = 2 * idx 125 | else: 126 | prefixsum -= self._value[2 * idx] 127 | idx = 2 * idx + 1 128 | return idx - self._capacity 129 | 130 | 131 | class MinSegmentTree(SegmentTree): 132 | def __init__(self, capacity): 133 | super(MinSegmentTree, self).__init__( 134 | capacity=capacity, 135 | operation=min, 136 | neutral_element=float('inf') 137 | ) 138 | 139 | def min(self, start=0, end=None): 140 | """Returns min(arr[start], ..., arr[end])""" 141 | 142 | return super(MinSegmentTree, self).reduce(start, end) 143 | 144 | 145 | class ReplayBuffer(object): 146 | def __init__(self, size): 147 | """Create Replay buffer. 148 | Parameters 149 | ---------- 150 | size: int 151 | Max number of transitions to store in the buffer. When the buffer 152 | overflows the old memories are dropped. 153 | """ 154 | self._storage = [] 155 | self._maxsize = size 156 | self._next_idx = 0 157 | 158 | def __len__(self): 159 | return len(self._storage) 160 | 161 | def push(self, state, action, reward, next_state, done): 162 | data = (state, action, reward, next_state, done) 163 | 164 | if self._next_idx >= len(self._storage): 165 | self._storage.append(data) 166 | else: 167 | self._storage[self._next_idx] = data 168 | self._next_idx = (self._next_idx + 1) % self._maxsize 169 | 170 | def _encode_sample(self, idxes): 171 | obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] 172 | for i in idxes: 173 | data = self._storage[i] 174 | obs_t, action, reward, obs_tp1, done = data 175 | obses_t.append(np.array(obs_t, copy=False)) 176 | actions.append(np.array(action, copy=False)) 177 | rewards.append(reward) 178 | obses_tp1.append(np.array(obs_tp1, copy=False)) 179 | dones.append(done) 180 | return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones) 181 | 182 | def sample(self, batch_size): 183 | """Sample a batch of experiences. 184 | Parameters 185 | ---------- 186 | batch_size: int 187 | How many transitions to sample. 188 | Returns 189 | ------- 190 | obs_batch: np.array 191 | batch of observations 192 | act_batch: np.array 193 | batch of actions executed given obs_batch 194 | rew_batch: np.array 195 | rewards received as results of executing act_batch 196 | next_obs_batch: np.array 197 | next set of observations seen after executing act_batch 198 | done_mask: np.array 199 | done_mask[i] = 1 if executing act_batch[i] resulted in 200 | the end of an episode and 0 otherwise. 201 | """ 202 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 203 | return self._encode_sample(idxes) 204 | 205 | 206 | class PrioritizedReplayBuffer(ReplayBuffer): 207 | def __init__(self, size, alpha): 208 | """Create Prioritized Replay buffer. 209 | Parameters 210 | ---------- 211 | size: int 212 | Max number of transitions to store in the buffer. When the buffer 213 | overflows the old memories are dropped. 
214 | alpha: float 215 | how much prioritization is used 216 | (0 - no prioritization, 1 - full prioritization) 217 | See Also 218 | -------- 219 | ReplayBuffer.__init__ 220 | """ 221 | super(PrioritizedReplayBuffer, self).__init__(size) 222 | assert alpha > 0 223 | self._alpha = alpha 224 | 225 | it_capacity = 1 226 | while it_capacity < size: 227 | it_capacity *= 2 228 | 229 | self._it_sum = SumSegmentTree(it_capacity) 230 | self._it_min = MinSegmentTree(it_capacity) 231 | self._max_priority = 1.0 232 | 233 | def push(self, *args, **kwargs): 234 | """See ReplayBuffer.store_effect""" 235 | idx = self._next_idx 236 | super(PrioritizedReplayBuffer, self).push(*args, **kwargs) 237 | self._it_sum[idx] = self._max_priority ** self._alpha 238 | self._it_min[idx] = self._max_priority ** self._alpha 239 | 240 | def _sample_proportional(self, batch_size): 241 | res = [] 242 | for _ in range(batch_size): 243 | # TODO(szymon): should we ensure no repeats? 244 | mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1) 245 | idx = self._it_sum.find_prefixsum_idx(mass) 246 | res.append(idx) 247 | return res 248 | 249 | def sample(self, batch_size, beta): 250 | """Sample a batch of experiences. 251 | compared to ReplayBuffer.sample 252 | it also returns importance weights and idxes 253 | of sampled experiences. 254 | Parameters 255 | ---------- 256 | batch_size: int 257 | How many transitions to sample. 258 | beta: float 259 | To what degree to use importance weights 260 | (0 - no corrections, 1 - full correction) 261 | Returns 262 | ------- 263 | obs_batch: np.array 264 | batch of observations 265 | act_batch: np.array 266 | batch of actions executed given obs_batch 267 | rew_batch: np.array 268 | rewards received as results of executing act_batch 269 | next_obs_batch: np.array 270 | next set of observations seen after executing act_batch 271 | done_mask: np.array 272 | done_mask[i] = 1 if executing act_batch[i] resulted in 273 | the end of an episode and 0 otherwise. 274 | weights: np.array 275 | Array of shape (batch_size,) and dtype np.float32 276 | denoting importance weight of each sampled transition 277 | idxes: np.array 278 | Array of shape (batch_size,) and dtype np.int32 279 | idexes in buffer of sampled experiences 280 | """ 281 | assert beta > 0 282 | 283 | idxes = self._sample_proportional(batch_size) 284 | 285 | weights = [] 286 | p_min = self._it_min.min() / self._it_sum.sum() 287 | max_weight = (p_min * len(self._storage)) ** (-beta) 288 | 289 | for idx in idxes: 290 | p_sample = self._it_sum[idx] / self._it_sum.sum() 291 | weight = (p_sample * len(self._storage)) ** (-beta) 292 | weights.append(weight / max_weight) 293 | weights = np.array(weights) 294 | encoded_sample = self._encode_sample(idxes) 295 | return tuple(list(encoded_sample) + [weights, idxes]) 296 | 297 | def update_priorities(self, idxes, priorities): 298 | """Update priorities of sampled transitions. 299 | sets priority of transition at index idxes[i] in buffer 300 | to priorities[i]. 301 | Parameters 302 | ---------- 303 | idxes: [int] 304 | List of idxes of sampled transitions 305 | priorities: [float] 306 | List of updated priorities corresponding to 307 | transitions at the sampled idxes denoted by 308 | variable `idxes`. 
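A typical caller (illustrative usage, not shown in this repo) recomputes priorities from the TD errors of the sampled batch, e.g. buffer.update_priorities(idxes, np.abs(td_errors) + 1e-6), keeping every priority strictly positive as the assertion below requires.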
309 | """ 310 | assert len(idxes) == len(priorities) 311 | for idx, priority in zip(idxes, priorities): 312 | assert priority > 0 313 | assert 0 <= idx < len(self._storage) 314 | self._it_sum[idx] = priority ** self._alpha 315 | self._it_min[idx] = priority ** self._alpha 316 | 317 | self._max_priority = max(self._max_priority, priority) -------------------------------------------------------------------------------- /common/replay_buffer_dtm.py: -------------------------------------------------------------------------------- 1 | #code from openai 2 | #https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py 3 | 4 | import numpy as np 5 | import random 6 | 7 | import operator 8 | 9 | 10 | class SegmentTree(object): 11 | def __init__(self, capacity, operation, neutral_element): 12 | """Build a Segment Tree data structure. 13 | https://en.wikipedia.org/wiki/Segment_tree 14 | Can be used as regular array, but with two 15 | important differences: 16 | a) setting item's value is slightly slower. 17 | It is O(lg capacity) instead of O(1). 18 | b) user has access to an efficient `reduce` 19 | operation which reduces `operation` over 20 | a contiguous subsequence of items in the 21 | array. 22 | Paramters 23 | --------- 24 | capacity: int 25 | Total size of the array - must be a power of two. 26 | operation: lambda obj, obj -> obj 27 | and operation for combining elements (eg. sum, max) 28 | must for a mathematical group together with the set of 29 | possible values for array elements. 30 | neutral_element: obj 31 | neutral element for the operation above. eg. float('-inf') 32 | for max and 0 for sum. 33 | """ 34 | assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2." 35 | self._capacity = capacity 36 | self._value = [neutral_element for _ in range(2 * capacity)] 37 | self._operation = operation 38 | 39 | def _reduce_helper(self, start, end, node, node_start, node_end): 40 | if start == node_start and end == node_end: 41 | return self._value[node] 42 | mid = (node_start + node_end) // 2 43 | if end <= mid: 44 | return self._reduce_helper(start, end, 2 * node, node_start, mid) 45 | else: 46 | if mid + 1 <= start: 47 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) 48 | else: 49 | return self._operation( 50 | self._reduce_helper(start, mid, 2 * node, node_start, mid), 51 | self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) 52 | ) 53 | 54 | def reduce(self, start=0, end=None): 55 | """Returns result of applying `self.operation` 56 | to a contiguous subsequence of the array. 57 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) 58 | Parameters 59 | ---------- 60 | start: int 61 | beginning of the subsequence 62 | end: int 63 | end of the subsequences 64 | Returns 65 | ------- 66 | reduced: obj 67 | result of reducing self.operation over the specified range of array elements. 
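For example (illustrative values): in a SumSegmentTree of capacity 4 holding [1, 3, 4, 2], sum() returns 10 and sum(0, 2) returns arr[0] + arr[1] = 4 (an explicit `end` is treated as exclusive by the implementation).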
68 | """ 69 | if end is None: 70 | end = self._capacity 71 | if end < 0: 72 | end += self._capacity 73 | end -= 1 74 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1) 75 | 76 | def __setitem__(self, idx, val): 77 | # index of the leaf 78 | idx += self._capacity 79 | self._value[idx] = val 80 | idx //= 2 81 | while idx >= 1: 82 | self._value[idx] = self._operation( 83 | self._value[2 * idx], 84 | self._value[2 * idx + 1] 85 | ) 86 | idx //= 2 87 | 88 | def __getitem__(self, idx): 89 | assert 0 <= idx < self._capacity 90 | return self._value[self._capacity + idx] 91 | 92 | 93 | class SumSegmentTree(SegmentTree): 94 | def __init__(self, capacity): 95 | super(SumSegmentTree, self).__init__( 96 | capacity=capacity, 97 | operation=operator.add, 98 | neutral_element=0.0 99 | ) 100 | 101 | def sum(self, start=0, end=None): 102 | """Returns arr[start] + ... + arr[end]""" 103 | return super(SumSegmentTree, self).reduce(start, end) 104 | 105 | def find_prefixsum_idx(self, prefixsum): 106 | """Find the highest index `i` in the array such that 107 | sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum 108 | if array values are probabilities, this function 109 | allows to sample indexes according to the discrete 110 | probability efficiently. 111 | Parameters 112 | ---------- 113 | perfixsum: float 114 | upperbound on the sum of array prefix 115 | Returns 116 | ------- 117 | idx: int 118 | highest index satisfying the prefixsum constraint 119 | """ 120 | assert 0 <= prefixsum <= self.sum() + 1e-5 121 | idx = 1 122 | while idx < self._capacity: # while non-leaf 123 | if self._value[2 * idx] > prefixsum: 124 | idx = 2 * idx 125 | else: 126 | prefixsum -= self._value[2 * idx] 127 | idx = 2 * idx + 1 128 | return idx - self._capacity 129 | 130 | 131 | class MinSegmentTree(SegmentTree): 132 | def __init__(self, capacity): 133 | super(MinSegmentTree, self).__init__( 134 | capacity=capacity, 135 | operation=min, 136 | neutral_element=float('inf') 137 | ) 138 | 139 | def min(self, start=0, end=None): 140 | """Returns min(arr[start], ..., arr[end])""" 141 | 142 | return super(MinSegmentTree, self).reduce(start, end) 143 | 144 | 145 | class ReplayBuffer(object): 146 | def __init__(self, size): 147 | """Create Replay buffer. 148 | Parameters 149 | ---------- 150 | size: int 151 | Max number of transitions to store in the buffer. When the buffer 152 | overflows the old memories are dropped. 
153 | """ 154 | self._storage = [] 155 | self._maxsize = size 156 | self._next_idx = 0 157 | 158 | def __len__(self): 159 | return len(self._storage) 160 | 161 | 162 | def push(self, state, h_trj, action, reward, next_state, nh_trj, done): 163 | data = (state, h_trj, action, reward, next_state, nh_trj, done) 164 | 165 | if self._next_idx >= len(self._storage): 166 | self._storage.append(data) 167 | else: 168 | self._storage[self._next_idx] = data 169 | self._next_idx = (self._next_idx + 1) % self._maxsize 170 | 171 | def _encode_sample(self, idxes): 172 | states, h_trjs, actions, rewards, next_states, nh_trjs, dones = [], [], [], [], [], [], [] 173 | for i in idxes: 174 | data = self._storage[i] 175 | state, h_trj, action, reward, next_state, nh_trj, done = data 176 | states.append(np.array(state, copy=False)) 177 | h_trjs.append(np.array(h_trj, copy=False)) 178 | actions.append(np.array(action, copy=False)) 179 | rewards.append(reward) 180 | next_states.append(np.array(next_state, copy=False)) 181 | nh_trjs.append(np.array(nh_trj, copy=False)) 182 | dones.append(done) 183 | 184 | return np.array(states), np.array(h_trjs), np.array(actions), np.array(rewards), \ 185 | np.array(next_states), np.array(nh_trjs), np.array(dones) 186 | 187 | def sample(self, batch_size): 188 | """Sample a batch of experiences. 189 | Parameters 190 | ---------- 191 | batch_size: int 192 | How many transitions to sample. 193 | Returns 194 | ------- 195 | obs_batch: np.array 196 | batch of observations 197 | act_batch: np.array 198 | batch of actions executed given obs_batch 199 | rew_batch: np.array 200 | rewards received as results of executing act_batch 201 | next_obs_batch: np.array 202 | next set of observations seen after executing act_batch 203 | done_mask: np.array 204 | done_mask[i] = 1 if executing act_batch[i] resulted in 205 | the end of an episode and 0 otherwise. 206 | """ 207 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 208 | return self._encode_sample(idxes) 209 | 210 | 211 | class PrioritizedReplayBuffer(ReplayBuffer): 212 | def __init__(self, size, alpha): 213 | """Create Prioritized Replay buffer. 214 | Parameters 215 | ---------- 216 | size: int 217 | Max number of transitions to store in the buffer. When the buffer 218 | overflows the old memories are dropped. 219 | alpha: float 220 | how much prioritization is used 221 | (0 - no prioritization, 1 - full prioritization) 222 | See Also 223 | -------- 224 | ReplayBuffer.__init__ 225 | """ 226 | super(PrioritizedReplayBuffer, self).__init__(size) 227 | assert alpha > 0 228 | self._alpha = alpha 229 | 230 | it_capacity = 1 231 | while it_capacity < size: 232 | it_capacity *= 2 233 | 234 | self._it_sum = SumSegmentTree(it_capacity) 235 | self._it_min = MinSegmentTree(it_capacity) 236 | self._max_priority = 1.0 237 | 238 | def push(self, *args, **kwargs): 239 | """See ReplayBuffer.store_effect""" 240 | idx = self._next_idx 241 | super(PrioritizedReplayBuffer, self).push(*args, **kwargs) 242 | self._it_sum[idx] = self._max_priority ** self._alpha 243 | self._it_min[idx] = self._max_priority ** self._alpha 244 | 245 | def _sample_proportional(self, batch_size): 246 | res = [] 247 | for _ in range(batch_size): 248 | # TODO(szymon): should we ensure no repeats? 249 | mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1) 250 | idx = self._it_sum.find_prefixsum_idx(mass) 251 | res.append(idx) 252 | return res 253 | 254 | def sample(self, batch_size, beta): 255 | """Sample a batch of experiences. 
256 | compared to ReplayBuffer.sample 257 | it also returns importance weights and idxes 258 | of sampled experiences. 259 | Parameters 260 | ---------- 261 | batch_size: int 262 | How many transitions to sample. 263 | beta: float 264 | To what degree to use importance weights 265 | (0 - no corrections, 1 - full correction) 266 | Returns 267 | ------- 268 | obs_batch: np.array 269 | batch of observations 270 | act_batch: np.array 271 | batch of actions executed given obs_batch 272 | rew_batch: np.array 273 | rewards received as results of executing act_batch 274 | next_obs_batch: np.array 275 | next set of observations seen after executing act_batch 276 | done_mask: np.array 277 | done_mask[i] = 1 if executing act_batch[i] resulted in 278 | the end of an episode and 0 otherwise. 279 | weights: np.array 280 | Array of shape (batch_size,) and dtype np.float32 281 | denoting importance weight of each sampled transition 282 | idxes: np.array 283 | Array of shape (batch_size,) and dtype np.int32 284 | idexes in buffer of sampled experiences 285 | """ 286 | assert beta > 0 287 | 288 | idxes = self._sample_proportional(batch_size) 289 | 290 | weights = [] 291 | p_min = self._it_min.min() / self._it_sum.sum() 292 | max_weight = (p_min * len(self._storage)) ** (-beta) 293 | 294 | for idx in idxes: 295 | p_sample = self._it_sum[idx] / self._it_sum.sum() 296 | weight = (p_sample * len(self._storage)) ** (-beta) 297 | weights.append(weight / max_weight) 298 | weights = np.array(weights) 299 | encoded_sample = self._encode_sample(idxes) 300 | return tuple(list(encoded_sample) + [weights, idxes]) 301 | 302 | def update_priorities(self, idxes, priorities): 303 | """Update priorities of sampled transitions. 304 | sets priority of transition at index idxes[i] in buffer 305 | to priorities[i]. 306 | Parameters 307 | ---------- 308 | idxes: [int] 309 | List of idxes of sampled transitions 310 | priorities: [float] 311 | List of updated priorities corresponding to 312 | transitions at the sampled idxes denoted by 313 | variable `idxes`. 
314 | """ 315 | assert len(idxes) == len(priorities) 316 | for idx, priority in zip(idxes, priorities): 317 | assert priority > 0 318 | assert 0 <= idx < len(self._storage) 319 | self._it_sum[idx] = priority ** self._alpha 320 | self._it_min[idx] = priority ** self._alpha 321 | 322 | self._max_priority = max(self._max_priority, priority) -------------------------------------------------------------------------------- /em.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.nn import Parameter 4 | import pyflann 5 | import numpy as np 6 | from torch.autograd import Variable 7 | import random 8 | 9 | def inverse_distance(d, epsilon=1e-3): 10 | return 1 / (d + epsilon) 11 | 12 | class DND: 13 | def __init__(self, kernel, num_neighbors, max_memory, lr): 14 | self.kernel = kernel 15 | self.num_neighbors = num_neighbors 16 | self.max_memory = max_memory 17 | self.lr = lr 18 | self.keys = None 19 | self.values = None 20 | pyflann.set_distance_type('euclidean') # squared euclidean actually 21 | self.kdtree = pyflann.FLANN() 22 | # key_cache stores a cache of all keys that exist in the DND 23 | # This makes DND updates efficient 24 | self.key_cache = {} 25 | # stale_index is a flag that indicates whether or not the index in self.kdtree is stale 26 | # This allows us to only rebuild the kdtree index when necessary 27 | self.stale_index = True 28 | # indexes_to_be_updated is the set of indexes to be updated on a call to update_params 29 | # This allows us to rebuild only the keys of key_cache that need to be rebuilt when necessary 30 | self.indexes_to_be_updated = set() 31 | 32 | # Keys and value to be inserted into self.keys and self.values when commit_insert is called 33 | self.keys_to_be_inserted = None 34 | self.values_to_be_inserted = None 35 | 36 | # Move recently used lookup indexes 37 | # These should be moved to the back of self.keys and self.values to get LRU property 38 | self.move_to_back = set() 39 | 40 | def get_mem_size(self): 41 | if self.keys is not None: 42 | return len(self.keys) 43 | return 0 44 | 45 | def get_index(self, key): 46 | """ 47 | If key exists in the DND, return its index 48 | Otherwise, return None 49 | """ 50 | # print(key.data.cpu().numpy().shape) 51 | if self.key_cache.get(tuple(key.data.cpu().numpy()[0])) is not None: 52 | if self.stale_index: 53 | self.commit_insert() 54 | return int(self.kdtree.nn_index(key.data.cpu().numpy(), 1)[0][0]) 55 | else: 56 | return None 57 | 58 | def update(self, value, index): 59 | """ 60 | Set self.values[index] = value 61 | """ 62 | values = self.values.data 63 | values[index] = value[0].data 64 | self.values = Parameter(values) 65 | # self.optimizer = optim.RMSprop([self.keys, self.values], lr=self.lr) 66 | 67 | def insert(self, key, value): 68 | """ 69 | Insert key, value pair into DND 70 | """ 71 | 72 | if torch.cuda.is_available(): 73 | if self.keys_to_be_inserted is None: 74 | # Initial insert 75 | self.keys_to_be_inserted = Variable(key.data.cuda(), requires_grad=True) 76 | self.values_to_be_inserted = value.data.cuda() 77 | else: 78 | self.keys_to_be_inserted = torch.cat( 79 | [self.keys_to_be_inserted.cuda(), Variable(key.data.cuda(), requires_grad=True)], 0) 80 | self.values_to_be_inserted = torch.cat( 81 | [self.values_to_be_inserted.cuda(), value.data.cuda()], 0) 82 | else: 83 | if self.keys_to_be_inserted is None: 84 | # Initial insert 85 | self.keys_to_be_inserted = Variable(key.data, requires_grad=True) 86 | 
self.values_to_be_inserted = value.data 87 | else: 88 | self.keys_to_be_inserted = torch.cat( 89 | [self.keys_to_be_inserted,Variable(key.data, requires_grad=True)], 0) 90 | self.values_to_be_inserted = torch.cat( 91 | [self.values_to_be_inserted, value.data], 0) 92 | self.key_cache[tuple(key.data.cpu().numpy()[0])] = 0 93 | self.stale_index = True 94 | 95 | def commit_insert(self): 96 | if self.keys is None: 97 | self.keys = self.keys_to_be_inserted 98 | self.values = self.values_to_be_inserted 99 | elif self.keys_to_be_inserted is not None: 100 | self.keys = torch.cat([self.keys.data, self.keys_to_be_inserted], 0) 101 | self.values = torch.cat([self.values.data, self.values_to_be_inserted], 0) 102 | # Move most recently used key-value pairs to the back 103 | if len(self.move_to_back) != 0: 104 | self.keys = torch.cat([self.keys.data[list(set(range(len( 105 | self.keys))) - self.move_to_back)], self.keys.data[list(self.move_to_back)]], 0) 106 | self.values = torch.cat([self.values.data[list(set(range(len( 107 | self.values))) - self.move_to_back)], self.values.data[list(self.move_to_back)]], 0) 108 | self.move_to_back = set() 109 | 110 | if len(self.keys) > self.max_memory: 111 | # Expel oldest key to maintain total memory 112 | for key in self.keys[:-self.max_memory]: 113 | del self.key_cache[tuple(key.data.cpu().numpy())] 114 | self.keys = self.keys[-self.max_memory:].data 115 | self.values = self.values[-self.max_memory:].data 116 | self.keys_to_be_inserted = None 117 | self.values_to_be_inserted = None 118 | # self.optimizer = optim.RMSprop([self.keys, self.values], lr=self.lr) 119 | self.kdtree.build_index(self.keys.data.cpu().numpy(), algorithm='kdtree') 120 | self.stale_index = False 121 | 122 | 123 | # def lookup_batch(self, lookup_key, update_flag=False): 124 | # """ 125 | # Perform DND lookup 126 | # If update_flag == True, add the nearest neighbor indexes to self.indexes_to_be_updated 127 | # """ 128 | # lookup_indexesb, dists = self.kdtree.nn_index( 129 | # lookup_key.data.cpu().numpy(), min(self.num_neighbors, len(self.keys))) 130 | # 131 | # self.values[lookup_indexesb] 132 | # outs = [] 133 | # for b, lookup_indexes in enumerate(lookup_indexesb): 134 | # output = 0 135 | # kernel_sum = 0 136 | # for i, index in enumerate(lookup_indexes): 137 | # if i == 0 and self.key_cache.get(tuple(lookup_key[0].data.cpu().numpy())) is not None: 138 | # # If a key exactly equal to lookup_key is used in the DND lookup calculation 139 | # # then the loss becomes non-differentiable. Just skip this case to avoid the issue. 
140 | # continue 141 | # if update_flag: 142 | # self.indexes_to_be_updated.add(int(index)) 143 | # else: 144 | # self.move_to_back.add(int(index)) 145 | # kernel_val = self.kernel(self.keys[int(index)], lookup_key[b]) 146 | # output += kernel_val * self.values[int(index)] 147 | # kernel_sum += kernel_val 148 | # output = output / kernel_sum 149 | # outs.append(output) 150 | # return torch.stack(outs) 151 | 152 | def nearest(self, lookup_key): 153 | lookup_indexesb, distb = self.kdtree.nn_index( 154 | lookup_key.data.cpu().numpy(), 1) 155 | indexes = torch.LongTensor(lookup_indexesb).view(-1) 156 | values = torch.tensor(self.values).gather(0, indexes.to(lookup_key.device)) 157 | return values.reshape(lookup_key.shape[0], -1) 158 | 159 | 160 | 161 | def lookup(self, lookup_key, update_flag=False, is_learning=False, p=0.7, K=0): 162 | """ 163 | Perform DND lookup 164 | If update_flag == True, add the nearest neighbor indexes to self.indexes_to_be_updated 165 | """ 166 | if K<=0: 167 | K=self.num_neighbors 168 | lookup_indexesb, distb = self.kdtree.nn_index( 169 | lookup_key.data.cpu().numpy(), min(K, len(self.keys))) 170 | indexes = torch.LongTensor(lookup_indexesb).view(-1) 171 | 172 | # print(self.kdtree.nn_index( 173 | # lookup_key.data.cpu().numpy(), min(self.num_neighbors, len(self.keys)))[0]) 174 | # print("----") 175 | old_indexes = indexes 176 | vshape = torch.tensor(self.values).shape 177 | if len(vshape)==2: 178 | indexes=indexes.unsqueeze(-1).repeat(1, vshape[1]) 179 | 180 | kvalues = inverse_distance(torch.tensor(distb).to(lookup_key.device)) 181 | kvalues = kvalues/torch.sum(kvalues, dim=-1, keepdim=True) 182 | #print(torch.tensor(self.values)) 183 | #print(indexes) 184 | values = torch.tensor(self.values).gather(0, indexes.to(lookup_key.device)) 185 | 186 | #print(values) 187 | #raise False 188 | if len(vshape)==2: 189 | values = values.reshape(lookup_key.shape[0],vshape[1], -1) 190 | kvalues = kvalues.unsqueeze(1) 191 | else: 192 | values = values.reshape(lookup_key.shape[0],-1) 193 | if random.random() > p: 194 | return torch.max(values, dim=-1)[0].detach() 195 | if not is_learning: 196 | self.move_to_back.update(old_indexes.numpy()) 197 | 198 | return torch.sum(kvalues*values, dim=-1).detach() 199 | 200 | # outs = [] 201 | # for b, lookup_indexes in enumerate(lookup_indexesb): 202 | # output = 0 203 | # kernel_sum = 0 204 | # for i, index in enumerate(lookup_indexes): 205 | # if i == 0 and self.key_cache.get(tuple(lookup_key[0].data.cpu().numpy())) is not None: 206 | # # If a key exactly equal to lookup_key is used in the DND lookup calculation 207 | # # then the loss becomes non-differentiable. Just skip this case to avoid the issue. 
208 | # continue 209 | # # if update_flag: 210 | # # self.indexes_to_be_updated.add(int(index)) 211 | # # else: 212 | # # self.move_to_back.add(int(index)) 213 | # # kernel_val = self.kernel(self.keys[int(index)], lookup_key[b]) 214 | # kernel_val = inverse_distance(distb[b][i]) 215 | # output += kernel_val * self.values[int(index)] 216 | # kernel_sum += kernel_val 217 | # output = output / kernel_sum 218 | # outs.append(output) 219 | # return torch.stack(outs) 220 | 221 | def lookup_grad(self, lookup_key, update_flag=False, is_learning=False): 222 | """ 223 | Perform DND lookup 224 | If update_flag == True, add the nearest neighbor indexes to self.indexes_to_be_updated 225 | """ 226 | lookup_indexesb, distb = self.kdtree.nn_index( 227 | lookup_key.data.cpu().numpy(), min(self.num_neighbors, len(self.keys))) 228 | 229 | train_var = [] 230 | outs = [] 231 | for b, lookup_indexes in enumerate(lookup_indexesb): 232 | output = 0 233 | kernel_sum = 0 234 | for i, index in enumerate(lookup_indexes): 235 | kkkk = self.keys[int(index)].detach().requires_grad_(True) 236 | train_var.append((kkkk, index)) 237 | kernel_val = self.kernel(kkkk, lookup_key[b]) 238 | #kernel_val = inverse_distance(distb[b][i]) 239 | output += kernel_val * self.values[int(index)] 240 | kernel_sum += kernel_val 241 | output = output / kernel_sum 242 | outs.append(output) 243 | return torch.stack(outs), train_var 244 | 245 | 246 | def lookup2write(self, lookup_key, R, update_flag=False, K=0): 247 | """ 248 | Perform DND lookup 249 | If update_flag == True, add the nearest neighbor indexes to self.indexes_to_be_updated 250 | """ 251 | if K<=0: 252 | K=self.num_neighbors 253 | lookup_indexesb, distb = self.kdtree.nn_index( 254 | lookup_key.data.cpu().numpy(), min(K, len(self.keys))) 255 | #print(lookup_indexesb.shape) 256 | 257 | for b, lookup_indexes in enumerate(lookup_indexesb): 258 | ks = [] 259 | kernel_sum = 0 260 | if K == 1 and len(lookup_indexes.shape)==1: 261 | lookup_indexes=[lookup_indexes] 262 | distb=[distb] 263 | if isinstance(lookup_indexes, np.int32): 264 | lookup_indexes = [lookup_indexes] 265 | distb=[distb] 266 | 267 | for i, index in enumerate(lookup_indexes): 268 | # if i == 0 and self.key_cache.get(tuple(lookup_key[0].data.cpu().numpy())) is not None: 269 | # If a key exactly equal to lookup_key is used in the DND lookup calculation 270 | # then the loss becomes non-differentiable. Just skip this case to avoid the issue. 
271 | # continue 272 | if update_flag: 273 | self.indexes_to_be_updated.add(int(index)) 274 | else: 275 | self.move_to_back.add(int(index)) 276 | curv = self.values[int(index)] 277 | # kernel_val = self.kernel(self.keys[int(index)], lookup_key[b]) 278 | kernel_val = inverse_distance(distb[b][i]) 279 | kernel_sum += kernel_val 280 | ks.append((index,kernel_val, curv)) 281 | # self.update((R - curv) * kernel_val * self.lr + curv, index) 282 | # kernel_val = 1 283 | for index, kernel_val, curv in ks: 284 | self.update((R-curv)*kernel_val/kernel_sum*self.lr + curv, index) 285 | #self.update(torch.max(R,curv), index) 286 | 287 | 288 | def update_params(self): 289 | """ 290 | Update self.keys and self.values via backprop 291 | Use self.indexes_to_be_updated to update self.key_cache accordingly and rebuild the index of self.kdtree 292 | """ 293 | for index in self.indexes_to_be_updated: 294 | del self.key_cache[tuple(self.keys[index].data.cpu().numpy())] 295 | # self.optimizer.step() 296 | # self.optimizer.zero_grad() 297 | for index in self.indexes_to_be_updated: 298 | self.key_cache[tuple(self.keys[index].data.cpu().numpy())] = 0 299 | self.indexes_to_be_updated = set() 300 | self.kdtree.build_index(self.keys.data.cpu().numpy()) 301 | self.stale_index = False 302 | -------------------------------------------------------------------------------- /dqn_mbec.py: -------------------------------------------------------------------------------- 1 | import math, random 2 | 3 | import gym 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.autograd as autograd 10 | import torch.nn.functional as F 11 | from torch.optim.lr_scheduler import StepLR 12 | import os, json, copy, pickle 13 | from IPython.display import clear_output 14 | import matplotlib.pyplot as plt 15 | from tensorboard_logger import configure, log_value 16 | from argparse import ArgumentParser 17 | from collections import deque 18 | from tqdm import tqdm 19 | import em as dnd 20 | import controller 21 | 22 | USE_CUDA = torch.cuda.is_available() 23 | # USE_CUDA= False 24 | Variable = lambda *args, **kwargs: autograd.Variable(*args, **kwargs).cuda() if USE_CUDA else autograd.Variable(*args, **kwargs) 25 | 26 | 27 | 28 | 29 | 30 | class ReplayBuffer(object): 31 | def __init__(self, capacity): 32 | self.buffer = deque(maxlen=capacity) 33 | 34 | def push(self, state, h_trj, action, reward, old_reward, next_state, nh_trj, done): 35 | state = np.expand_dims(state, 0) 36 | next_state = np.expand_dims(next_state, 0) 37 | h_trj = np.expand_dims(h_trj, 0) 38 | nh_trj = np.expand_dims(nh_trj, 0) 39 | 40 | self.buffer.append((state, h_trj, action, reward, old_reward, next_state, nh_trj, done)) 41 | 42 | def sample(self, batch_size): 43 | state, h_trj, action, reward, old_reward, next_state, nh_trj, done = zip(*random.sample(self.buffer, batch_size)) 44 | return np.concatenate(state), np.concatenate(h_trj), action, reward, old_reward, \ 45 | np.concatenate(next_state), np.concatenate(nh_trj), done 46 | 47 | def __len__(self): 48 | return len(self.buffer) 49 | 50 | 51 | 52 | 53 | mse_criterion = nn.MSELoss() 54 | 55 | 56 | # plt.plot([epsilon_by_frame(i) for i in range(10000)]) 57 | # plt.show() 58 | def inverse_distance(h, h_i, epsilon=1e-3): 59 | return 1 / (torch.dist(h, h_i) + epsilon) 60 | 61 | def gauss_kernel(h, h_i, w=0.5): 62 | return torch.exp(-torch.dist(h, h_i)**2/w) 63 | 64 | def no_distance(h, h_i, epsilon=1e-3): 65 | return 1 66 | 67 | 68 | cos = nn.CosineSimilarity(dim=0, 
eps=1e-6) 69 | def cosine_distance(h, h_i): 70 | return max(cos(h, h_i),0) 71 | 72 | def weights_init(m): 73 | classname = m.__class__.__name__ 74 | if classname.find('Conv') != -1: 75 | weight_shape = list(m.weight.data.size()) 76 | fan_in = np.prod(weight_shape[1:4]) 77 | fan_out = np.prod(weight_shape[2:4]) * weight_shape[0] 78 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 79 | m.weight.data.uniform_(-w_bound, w_bound) 80 | m.bias.data.fill_(0) 81 | elif classname.find('Linear') != -1: 82 | weight_shape = list(m.weight.data.size()) 83 | fan_in = weight_shape[1] 84 | fan_out = weight_shape[0] 85 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 86 | m.weight.data.uniform_(-w_bound, w_bound) 87 | m.bias.data.fill_(0) 88 | 89 | class DQN_DTM(nn.Module): 90 | def __init__(self, env, args): 91 | super(DQN_DTM, self).__init__() 92 | if "maze_img" in args.task2: 93 | if "cnn" not in args.task2: 94 | self.num_inputs = 8 95 | self.proj= nn.Linear(514, 8) 96 | else: 97 | self.cnn = img_featurize.CNN(64) 98 | self.num_inputs = 64 99 | 100 | elif "world" in args.task2: 101 | self.cnn = img_featurize.CNN2(256) 102 | self.num_inputs = 256 103 | for param in self.cnn.parameters(): 104 | param.requires_grad = False 105 | else: 106 | self.num_inputs = env.observation_space.shape[0] 107 | if "trap" in args.task2: 108 | self.num_inputs = self.num_inputs + 2 109 | self.num_actions = env.action_space.n 110 | self.model_name=args.model_name 111 | self.num_warm_up=args.num_warm_up 112 | self.replay_buffer = args.replay_buffer 113 | self.gamma = args.gamma 114 | self.last_inserts=[] 115 | self.insert_size = args.insert_size 116 | self.args= args 117 | 118 | self.qnet = nn.Sequential( 119 | nn.Linear(self.num_inputs, args.qnet_size), 120 | nn.ReLU(), 121 | nn.Linear(args.qnet_size, args.qnet_size), 122 | nn.ReLU(), 123 | nn.Linear(args.qnet_size, env.action_space.n) 124 | ) 125 | 126 | self.qnet_target = nn.Sequential( 127 | nn.Linear(self.num_inputs, args.qnet_size), 128 | nn.ReLU(), 129 | nn.Linear(args.qnet_size, args.qnet_size), 130 | nn.ReLU(), 131 | nn.Linear(args.qnet_size, env.action_space.n) 132 | ) 133 | 134 | 135 | if args.write_interval<0: 136 | self.state2key = nn.Linear(self.num_inputs, args.hidden_size) 137 | self.dnd = dnd.DND(no_distance, num_neighbors=args.k, max_memory=args.memory_size, lr=args.write_lr) 138 | else: 139 | self.dnd = dnd.DND(inverse_distance, num_neighbors=args.k, max_memory=args.memory_size, lr=args.write_lr) 140 | 141 | self.emb_index2count = {} 142 | self.act_net = nn.Linear(self.num_actions, self.num_actions) 143 | self.act_net_target = nn.Linear(self.num_actions, self.num_actions) 144 | 145 | self.choice_net = nn.Sequential( 146 | nn.Linear(args.hidden_size, args.hidden_size),nn.ReLU(), 147 | nn.Linear(args.hidden_size, 1)) 148 | self.choice_net_target = nn.Sequential( 149 | nn.Linear(args.hidden_size, args.hidden_size),nn.ReLU(), 150 | nn.Linear(args.hidden_size, 1)) 151 | self.alpha = nn.Parameter(torch.tensor(1.0), 152 | requires_grad=True) 153 | self.alpha_target = nn.Parameter(torch.tensor(1.0), 154 | requires_grad=True) 155 | self.beta = nn.Parameter(torch.ones(self.num_actions), 156 | requires_grad=True) 157 | self.trj_model = controller.LSTMController(self.num_inputs+self.num_actions, args.hidden_size, num_layers=1) 158 | self.trj_out = nn.Linear(args.hidden_size, self.num_inputs+self.num_actions+1) 159 | self.reward_model = nn.Sequential( 160 | nn.Linear(self.num_inputs+self.num_actions+args.hidden_size, args.reward_hidden_size), 161 | nn.ReLU(), 162 | 
nn.Linear(args.reward_hidden_size,args.reward_hidden_size), 163 | nn.ReLU(), 164 | nn.Linear(args.reward_hidden_size, 1), 165 | ) 166 | self.best_trj = [] 167 | self.optimizer_dnd = torch.optim.Adam(self.trj_model.parameters()) 168 | # self.future_net = nn.Sequential( 169 | # nn.Linear(args.hidden_size, args.hidden_size), 170 | # nn.ReLU(), 171 | # nn.Linear(args.hidden_size, args.hidden_size), 172 | # ) 173 | self.future_net = nn.Sequential( 174 | nn.Linear(args.hidden_size, args.mem_dim), 175 | ) 176 | 177 | for param in self.future_net.parameters(): 178 | param.requires_grad = False 179 | if args.rec_period<0: 180 | for param in self.trj_model.parameters(): 181 | param.requires_grad = False 182 | 183 | self.apply(weights_init) 184 | 185 | def forward(self, x, h_trj, episode=0, use_mem=1.0, target=0, r=None, a=None, is_learning=False): 186 | 187 | q_value_semantic = self.semantic_net(x, target) 188 | z = torch.zeros(x.shape[0], self.num_actions) 189 | if USE_CUDA: 190 | z = z.cuda() 191 | q_episodic = qt= z 192 | if episode > self.num_warm_up and random.random() 1: 193 | plan_step = 1 194 | lx = x 195 | lh_trj = h_trj 196 | a0=a 197 | 198 | if self.args.write_interval<0: 199 | fh_trj = self.state2key(x) 200 | q_episodic = self.dnd.lookup(fh_trj, is_learning=is_learning, p=args.pread) 201 | #print(q_episodic) 202 | # print("q lookup", q_estimates) 203 | else: 204 | if a is not None: 205 | for i in range(plan_step): 206 | lx, h_trj_a = self.trj_model(self.make_trj_input(lx, a, r), lh_trj) 207 | if plan_step>1 or r is not None: 208 | lx = self.trj_out(lx) 209 | lh_trj = h_trj_a 210 | if len(lx.shape)>1: 211 | lx = lx[:,:self.num_inputs] 212 | else: 213 | lx = lx[:self.num_inputs] 214 | 215 | if r is None: 216 | r = self.reward_model( 217 | torch.cat([self.make_trj_input(lx, a), lh_trj[0][0].detach()], dim=-1)) 218 | 219 | q_episodic[:, a0] = r + self.args.gamma*self.episodic_net(h_trj_a, is_learning) 220 | if plan_step>1: 221 | for aa in range(self.num_actions): 222 | _, h_trj_aa = self.trj_model(self.make_trj_input(lx, aa, r), h_trj_a) 223 | qt[:, aa] = self.episodic_net(h_trj_aa, is_learning) 224 | a= qt.max(1)[1] 225 | 226 | else: 227 | # if random.random()<0.1: 228 | # if self.best_trj: 229 | # q_episodic = self.exploit(lx) 230 | # else: 231 | #print("plan") 232 | for a in range(self.num_actions): 233 | #print(a) 234 | # print(self.make_trj_input(lx, a)) 235 | lxx, h_trj_aa = self.trj_model(self.make_trj_input(lx, a), lh_trj) 236 | if r is None: 237 | # lxx = self.trj_out(lxx) 238 | # if len(lx.shape)>1: 239 | # pr = lxx[:,self.num_inputs+self.num_actions:] 240 | # lxx = lxx[:,:self.num_inputs] 241 | # else: 242 | # pr = lxx[self.num_inputs+self.num_actions:] 243 | # lxx = lxx[:self.num_inputs] 244 | pr = self.reward_model(torch.cat([self.make_trj_input(lx, a),lh_trj[0][0].detach()], dim=-1)) 245 | #print('predicted r: ',pr) 246 | else: 247 | pr = r 248 | 249 | pr = pr.to(device=lxx.device).squeeze(-1) 250 | 251 | q_episodic[:,a] = pr+self.args.gamma*self.episodic_net(h_trj_aa, is_learning) 252 | #print(lx, a, q_episodic[:,a]) 253 | #print('next v', self.episodic_net(h_trj_aa, is_learning)) 254 | #print('cur v est, ', q_episodic[:,a]) 255 | if is_learning is False and random.random()0: 287 | a = args.fix_alpha 288 | else: 289 | if target == 0: 290 | a = self.choice_net(h_trj[0][0]) 291 | # a = self.alpha 292 | else: 293 | a = self.choice_net_target(h_trj[0][0]) 294 | # a = self.alpha_target 295 | # a = self.alpha 296 | a = F.sigmoid(a) 297 | #q_value_semantic = 
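# Value combination used in this part of forward(): a gate a = sigmoid(choice_net(h)) (or the
# fixed value --fix_alpha when it is > 0) scales the episodic estimate before it is added to the
# parametric Q-network output. Sketch of the blend (illustrative, not part of the original file):
#     a = torch.sigmoid(self.choice_net(h_trj[0][0]))   # one gate per state, in [0, 1]
#     q = a * q_episodic + q_value_semantic             # returned as the behaviour value
# With the default --fix_alpha -1 the gate is learned; a positive --fix_alpha keeps the episodic
# contribution constant throughout training.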
self.semantic_net(torch.cat([x, a*q_episodic], dim=-1), target) 298 | 299 | #if random.random()<0.0003: 300 | # print(q_value_semantic[0], q_episodic[0]*a[0]) 301 | #a = F.sigmoid(self.choice_net(x)) 302 | #a = F.tanh(self.alpha) 303 | # a=0 304 | # return self.act_net(q_episodic) 305 | # if target == 0: 306 | if self.args.td_interval>0: 307 | if target == 0: 308 | return q_episodic*a+q_value_semantic,q_value_semantic, q_episodic*a 309 | else: 310 | return q_episodic*a+q_value_semantic,q_value_semantic, q_episodic*a 311 | # if target == 0: 312 | # return self.act_net(q_episodic*a+q_value_semantic),q_value_semantic, q_episodic*a 313 | # else: 314 | # return self.act_net_target(q_episodic*a+q_value_semantic),q_value_semantic, q_episodic*a 315 | 316 | 317 | # return q_episodic*a+ q_value_semantic,q_value_semantic, q_episodic 318 | # else: 319 | # return q_value_semantic, q_value_semantic, q_value_semantic 320 | # return q_value 321 | return q_episodic, q_episodic, q_episodic 322 | # return q_value_semantic*F.sigmoid(q_episodic),q_value_semantic, q_episodic 323 | # op = torch.matmul(q_value_semantic.unsqueeze(-1),F.sigmoid(q_episodic).unsqueeze(1)) 324 | # q = torch.matmul(op, self.beta.unsqueeze(0).repeat(q_value_semantic.shape[0], 1).unsqueeze(-1)) 325 | # return q.squeeze(-1),q_value_semantic, q_episodic 326 | 327 | 328 | def exploit(self, x): 329 | batch_size = x.shape[0] 330 | q_episodics = [] 331 | 332 | for (h_trj,v) in self.best_trj: 333 | z = torch.zeros(x.shape[0], self.num_actions) 334 | if USE_CUDA: 335 | z = z.cuda() 336 | kw = z 337 | last_h_trj = (h_trj[0].repeat(1, batch_size, 1), h_trj[1].repeat(1, batch_size, 1)) 338 | 339 | for a in range(self.num_actions): 340 | X = self.make_trj_input(x, a) 341 | y_p, nh = self.trj_model(X, last_h_trj) 342 | y_p = self.trj_out(y_p) 343 | rec_loss = torch.norm(y_p[:,:self.num_actions+self.num_inputs]- 344 | X[:, :self.num_actions + self.num_inputs], dim=-1, keepdim=True) 345 | kw[:,a] = -rec_loss 346 | 347 | kw = torch.exp(kw) 348 | kw = kw/torch.sum(kw, dim=-1, keepdim=True) 349 | vs = kw*1.6**v 350 | q_episodics.append(vs) 351 | return torch.mean(torch.stack(q_episodics, dim=0), dim=0) 352 | 353 | def value(self, s, h=None): 354 | if h is None: 355 | h = self.trj_model.create_new_state(1) 356 | q, qs, qe = self.forward(s, h) 357 | return torch.max(q, dim=-1)[0] 358 | else: 359 | return self.episodic_net(h) 360 | 361 | def planning(self, x, h_trj, a=None, plan_step=1): 362 | actions = [] 363 | qa = 0 364 | qsa = 0 365 | qea = 0 366 | 367 | t = 0 368 | for s in range(plan_step): 369 | q, qs, qe = self.forward(x, h_trj, a=a) 370 | action = q.max(1)[1].item() 371 | y_trj, h_trj = self.trj_model(self.make_trj_input(x, action), h_trj) 372 | actions.append(action) 373 | qa+=q*(self.args.gamma**t) 374 | qsa+=qs*(self.args.gamma**t) 375 | qea+=qe*(self.args.gamma**t) 376 | 377 | 378 | t+=1 379 | x = y_trj[:,:self.num_inputs] 380 | 381 | return qa, qsa, qea, actions 382 | 383 | def get_pivot_lastinsert(self): 384 | if len(self.last_inserts)>0: 385 | return min(self.last_inserts) 386 | else: 387 | return -10000000 388 | 389 | def semantic_net(self, x, target=0): 390 | if "maze_img" in self.args.task2: 391 | if "cnn" not in self.args.task2: 392 | x = self.proj(x).detach() 393 | else: 394 | x = self.cnn(x) 395 | elif "world" in self.args.task2: 396 | x = self.cnn(x) 397 | 398 | if target == 0: 399 | return self.qnet(x) 400 | else: 401 | return self.qnet_target(x) 402 | 403 | def episodic_net(self, h_trj, is_learning=False, K=0): 404 | fh_trj = 
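# episodic_net (this method) turns the LSTM trajectory state into a memory key and reads the DND:
# key = future_net(h), then the DND returns an inverse-distance-weighted average of the values
# stored at the K nearest keys (K = --k, kd-tree search). Rough sketch of such a read
# (illustrative only, not the actual em.DND API):
#     def dnd_read(keys, values, query, K=5, eps=1e-3):
#         d = ((keys - query) ** 2).sum(-1).sqrt()
#         idx = d.topk(K, largest=False).indices
#         w = 1.0 / (d[idx] + eps)
#         return (w * values[idx]).sum() / w.sum()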
self.future_net(h_trj[0][0]) 405 | # fh_trj = h_trj[0][0].detach() 406 | q_estimates = self.dnd.lookup(fh_trj, is_learning=is_learning, K=K, p=args.pread) 407 | # print("q lookup", q_estimates) 408 | return q_estimates 409 | 410 | def make_trj_input(self, x, a, r=None): 411 | 412 | if "maze_img" in self.args.task2: 413 | if "cnn" not in self.args.task2: 414 | x = self.proj(x).detach() 415 | else: 416 | if len(x.shape)==3: 417 | x = x.unsqueeze(0) 418 | 419 | x = self.cnn(x) 420 | elif "world" in self.args.task2: 421 | x = self.cnn(x) 422 | 423 | a_vec = torch.zeros(x.shape[0],self.num_actions) 424 | a_vec[:,a] = 1 425 | 426 | if USE_CUDA: 427 | a_vec = a_vec.cuda() 428 | # r = r.cuda() 429 | x = torch.cat([x, a_vec],dim=-1) 430 | 431 | return x 432 | 433 | 434 | def add_trj(self, h_trj, R, step, episode, action): 435 | # print(f"add R {R}") 436 | if self.args.write_interval<0: 437 | h_trj = h_trj 438 | else: 439 | h_trj = h_trj[0][0] 440 | hkey = torch.as_tensor(h_trj).float()#.detach() 441 | if USE_CUDA: 442 | hkey = hkey.cuda() 443 | hkey = self.future_net(hkey) 444 | # t = torch.Tensor([0.5]) 445 | # hkey = (F.sigmoid(hkey) > t).float() * 1 446 | if self.args.write_interval<0: 447 | hkey = self.state2key(hkey.cuda()) 448 | rvec = torch.zeros(1, self.num_actions) 449 | rvec[0,action] = R 450 | #print(hkey, rvec) 451 | #raise False 452 | else: 453 | rvec = R.unsqueeze(0) 454 | if USE_CUDA: 455 | rvec = rvec.cuda() 456 | 457 | # print(hkey) 458 | embedding_index = self.dnd.get_index(hkey) 459 | if embedding_index is None: 460 | self.dnd.insert(hkey, rvec.detach()) 461 | 462 | if self.insert_size>0: 463 | if len(self.last_inserts) > self.insert_size: 464 | self.last_inserts.sort() 465 | self.last_inserts = self.last_inserts[1:-1] 466 | self.last_inserts.append(rvec.detach()) 467 | if episode>self.num_warm_up and self.dnd.keys is not None: 468 | #try: 469 | self.dnd.lookup2write(hkey, rvec.detach()) 470 | #except Exception as e: 471 | # print(e) 472 | else: 473 | #print("dssssssss") 474 | if embedding_index not in self.emb_index2count: 475 | self.emb_index2count[embedding_index] = 1 476 | 477 | 478 | if self.args.write_interval<0: 479 | rvec = torch.zeros(1, self.num_actions) 480 | rvec[0, action] = R 481 | rvec = R.unsqueeze(0) 482 | if USE_CUDA: 483 | rvec = rvec.cuda() 484 | self.dnd.update(torch.max(rvec, 485 | torch.tensor(self.emb_index2count[embedding_index]).float().unsqueeze(0).to( 486 | device=rvec.device)), 487 | embedding_index) 488 | else: 489 | #R = self.dnd.values[embedding_index]*self.emb_index2count[embedding_index]+ R 490 | #R = R/(self.emb_index2count[embedding_index]+1) 491 | #self.emb_index2count[embedding_index]+=1 492 | #self.dnd.update(R.unsqueeze(0).detach(), embedding_index) 493 | 494 | # self.dnd.update(torch.max(R.unsqueeze(0), 495 | # torch.tensor(self.emb_index2count[embedding_index]).float().unsqueeze(0).to(device=R.device)), 496 | # embedding_index) 497 | 498 | if episode > self.num_warm_up and self.dnd.keys is not None: 499 | # try: 500 | self.dnd.lookup2write(hkey, rvec.detach(), K=args.k_write) 501 | 502 | def compute_rec_loss(self, last_h_trj, traj_buffer, optimizer, batch_size, noise=0.1): 503 | # print('len ', len(traj_buffer)) 504 | sasr = random.choices(traj_buffer, k=batch_size-1) 505 | sasr.append(traj_buffer[-1]) 506 | 507 | X = [] 508 | y = [] 509 | y2 = [] 510 | hs1 = [] 511 | hs2 = [] 512 | for s1,h, a,s2,r,o_r in sasr: 513 | s1 = torch.as_tensor(s1).float() 514 | s2 = torch.as_tensor(s2).float() 515 | if USE_CUDA: 516 | s1 = s1.cuda() 517 | s2 = 
s2.cuda() 518 | if len(s1.shape)==1 or len(s1.shape)==3: 519 | s1 = s1.unsqueeze(0) 520 | s2 = s2.unsqueeze(0) 521 | 522 | o_r = torch.FloatTensor([o_r]).unsqueeze(0) 523 | r = torch.FloatTensor([r]).unsqueeze(0) 524 | 525 | x = self.make_trj_input(s1, a, o_r) 526 | x2 = self.make_trj_input(s2, a, o_r) 527 | 528 | #X.append(x) 529 | if noise>0: 530 | if random.random()>0.5: 531 | X.append(F.dropout(x, p=noise)) 532 | else: 533 | noise_tensor = ((torch.max(torch.abs(x))*noise)**0.5)*torch.randn(x.shape) 534 | if USE_CUDA: 535 | noise_tensor = torch.tensor(noise_tensor).cuda() 536 | X.append(x + noise_tensor.float()) 537 | else: 538 | X.append(x) 539 | y.append(x) 540 | 541 | y2.append(torch.cat([x2, r.to(device=x2.device)], dim=-1)) 542 | 543 | hs1.append(torch.tensor(h[0]).to(device=last_h_trj[0].device)) 544 | hs2.append(torch.tensor(h[1]).to(device=last_h_trj[0].device)) 545 | X = torch.stack(X, dim=0) 546 | y = torch.stack(y, dim=0) 547 | y2 = torch.stack(y2, dim=0).squeeze(1) 548 | 549 | last_h_trj = (last_h_trj[0].repeat(1, batch_size, 1), last_h_trj[1].repeat(1, batch_size, 1)) 550 | cur_h_trj = (torch.cat(hs1, dim=1), 551 | torch.cat(hs2, dim=1)) 552 | 553 | if args.rec_type=="pred": 554 | last_h_trj = cur_h_trj 555 | y_p, _ = self.trj_model(X.squeeze(1), last_h_trj) 556 | y_p = self.trj_out(y_p) 557 | # pr = self.reward_model(torch.cat([X.squeeze(1), cur_h_trj],dim=-1)) 558 | _, h_trj = self.trj_model(y.squeeze(1), cur_h_trj) 559 | # h_pred = self.future_net(h_trj[0][0]) 560 | h_pred = h_trj[0][0] 561 | #print(y_p.shape) 562 | #print(X) 563 | #print(y2[:, self.num_inputs + self.num_actions:]) 564 | l1 = mse_criterion(y_p[:, :self.num_inputs + self.num_actions], y2[:, :self.num_inputs + self.num_actions]) 565 | # l2 = mse_criterion(y_p[:,self.num_inputs+self.num_actions:], y2[:,self.num_inputs+self.num_actions:]) 566 | #l2 = mse_criterion(pr, y2[:, self.num_inputs + self.num_actions:]) 567 | l3 = mse_criterion(cur_h_trj[0][0], last_h_trj[0][0]) 568 | loss = l1 569 | #loss = loss + l2 570 | # loss = loss + l3 571 | # loss = mse_criterion(h_pred, last_h_trj[0][0]) 572 | optimizer.zero_grad() 573 | loss.backward(retain_graph=True) 574 | torch.nn.utils.clip_grad_norm(self.parameters(), 10) 575 | optimizer.step() 576 | 577 | return loss, l1, 0, l3 578 | 579 | def compute_reward_loss(self, last_h_trj, traj_buffer, optimizer, batch_size, noise=0.1): 580 | # print('len ', len(traj_buffer)) 581 | sasr = random.choices(traj_buffer, k=batch_size-1) 582 | sasr.append(traj_buffer[-1]) 583 | 584 | X = [] 585 | y = [] 586 | y2 = [] 587 | hs1 = [] 588 | hs2 = [] 589 | for s1,h, a,s2,r,o_r in sasr: 590 | s1 = torch.as_tensor(s1).float() 591 | s2 = torch.as_tensor(s2).float() 592 | if USE_CUDA: 593 | s1 = s1.cuda() 594 | s2 = s2.cuda() 595 | if len(s1.shape)==1 or len(s1.shape)==3: 596 | s1 = s1.unsqueeze(0) 597 | s2 = s2.unsqueeze(0) 598 | 599 | o_r = torch.FloatTensor([o_r]).unsqueeze(0) 600 | r = torch.FloatTensor([r]).unsqueeze(0) 601 | 602 | x = self.make_trj_input(s1, a, o_r) 603 | x2 = self.make_trj_input(s2, a, o_r) 604 | 605 | #X.append(x) 606 | if noise>0: 607 | if random.random()>0.5: 608 | X.append(F.dropout(x, p=noise)) 609 | else: 610 | noise_tensor = ((torch.max(torch.abs(x))*noise)**0.5)*torch.randn(x.shape) 611 | if USE_CUDA: 612 | noise_tensor = torch.tensor(noise_tensor).cuda() 613 | X.append(x + noise_tensor.float()) 614 | else: 615 | X.append(x) 616 | 617 | y.append(x) 618 | 619 | y2.append(torch.cat([x2, r.to(device=x2.device)], dim=-1)) 620 | 621 | 
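# Input corruption used in both compute_rec_loss and compute_reward_loss (the branch above): with
# probability 0.5 the trajectory input is dropout-masked, otherwise additive Gaussian noise scaled
# by the input magnitude is added, so the LSTM trajectory model is trained as a denoising
# predictor. Stand-alone sketch (illustrative, not part of the original file):
#     def corrupt(x, noise=0.1):
#         if random.random() > 0.5:
#             return F.dropout(x, p=noise)
#         sigma = (torch.abs(x).max() * noise) ** 0.5
#         return x + sigma * torch.randn_like(x)
# Passing --rec_noise 0 disables the corruption and the model sees clean inputs only.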
hs1.append(torch.tensor(h[0]).to(device=last_h_trj[0].device)) 622 | hs2.append(torch.tensor(h[1]).to(device=last_h_trj[0].device)) 623 | X = torch.stack(X, dim=0) 624 | y = torch.stack(y, dim=0) 625 | y2 = torch.stack(y2, dim=0).squeeze(1) 626 | cur_h_trj = (torch.cat(hs1, dim=1), 627 | torch.cat(hs2, dim=1)) 628 | 629 | 630 | # print(X) 631 | # print(y2[:, self.num_inputs + self.num_actions:]) 632 | 633 | pr = self.reward_model(torch.cat([X.squeeze(1), cur_h_trj[0][0]],dim=-1)) 634 | l2 = mse_criterion(pr, y2[:, self.num_inputs + self.num_actions:]) 635 | optimizer.zero_grad() 636 | l2.backward() 637 | optimizer.step() 638 | return l2 639 | 640 | def compute_td_loss(self, optimizer, batch_size, episode=0): 641 | state, h_trj, action, reward, old_reward, next_state, nh_trj, done = self.replay_buffer.sample(batch_size) 642 | 643 | state = Variable(torch.FloatTensor(np.float32(state))) 644 | next_state = Variable(torch.FloatTensor(np.float32(next_state)), volatile=True) 645 | action = Variable(torch.LongTensor(action)) 646 | reward = Variable(torch.FloatTensor(reward)) 647 | old_reward = Variable(torch.FloatTensor(old_reward)) 648 | 649 | done = Variable(torch.FloatTensor(done)) 650 | 651 | if USE_CUDA: 652 | state = state.cuda() 653 | next_state = next_state.cuda() 654 | action = action.cuda() 655 | reward = reward.cuda() 656 | old_reward = old_reward.cuda() 657 | done = done.cuda() 658 | 659 | 660 | # print(h_trj) 661 | # print(torch.tensor(h_trj[:,0,0,0])) 662 | hx = torch.tensor(h_trj[:,0,0,0]).to(device=state.device).unsqueeze(0)#torch.cat(torch.tensor(h_trj[:,0,0]).tolist(), dim=1) 663 | cx = torch.tensor(h_trj[:,1,0,0]).to(device=state.device).unsqueeze(0)#torch.cat(torch.tensor(h_trj[:,1,0]).tolist(), dim=1) 664 | q_values, q1, q2 = self.forward(state, (hx, cx), episode, use_mem=1, target=0, r=reward.unsqueeze(-1), a=action, is_learning=True) 665 | nhx = torch.tensor(nh_trj[:,0,0,0]).to(device=state.device).unsqueeze(0)#torch.cat(torch.tensor(nh_trj[:, 0,0]).tolist(), dim=1) 666 | ncx = torch.tensor(nh_trj[:,1,0,0]).to(device=state.device).unsqueeze(0)#torch.cat(torch.tensor(nh_trj[:, 1,0]).tolist(), dim=1) 667 | # raction = torch.randint(0, self.num_actions, action.shape) 668 | # if USE_CUDA: 669 | # raction = raction.cuda() 670 | # next_q_values = self.forward(next_state, (nhx, ncx), episode, use_mem=1) 671 | next_q_values, qn1, qn2 = self.forward(next_state, (nhx, ncx), episode, use_mem=1, target=1, r=None, is_learning=True) 672 | q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1) 673 | # q_value2 = q2.gather(1, action.unsqueeze(1)).squeeze(1) 674 | next_q_value = next_q_values.max(1)[0] 675 | # next_q_value = next_q_values.gather(1, raction.unsqueeze(1)).squeeze(1) 676 | 677 | expected_q_value = reward + self.gamma * next_q_value * (1 - done) 678 | # next_q_value2 = qn2.max(1)[0] 679 | # expected_q_value2 = reward + self.gamma * next_q_value2 * (1 - done) 680 | # print(q_value) 681 | # print(expected_q_value) 682 | loss = (q_value - Variable(expected_q_value.data)).pow(2).mean() 683 | # loss = (expected_q_value-q_value2).pow(2).mean() 684 | # loss = loss + (q1-q2).pow(2).mean() 685 | # loss = loss + (qn1/qn2-1).pow(2).mean() 686 | 687 | optimizer.zero_grad() 688 | loss.backward() 689 | torch.nn.utils.clip_grad_norm(self.parameters(), self.args.clip) 690 | optimizer.step() 691 | 692 | return loss 693 | 694 | 695 | def act(self, state, h_trj, epsilon, r=0, episode=0): 696 | state = Variable(torch.FloatTensor(state).unsqueeze(0), volatile=True) 697 | actions = actione = 
None 698 | if random.random() > epsilon and self.dnd.get_mem_size()>1: 699 | q_value, qs, qe = self.forward(state, h_trj, episode, 700 | r=None, 701 | use_mem=1) 702 | 703 | #q_value, qs, qe, _ = self.planning(state, h_trj, plan_step=2) 704 | 705 | action = q_value.max(1)[1].data[0] 706 | actions = qs.max(1)[1].data[0] 707 | actione = qe.max(1)[1].data[0] 708 | else: 709 | action = random.randrange(self.num_actions) 710 | 711 | y_trj, h_trj = self.trj_model(self.make_trj_input(state, action, 712 | torch.FloatTensor([r]).unsqueeze(0)), 713 | h_trj) 714 | # print("act {}".format(h_trj[0].shape)) 715 | # y_trj = self.trj_out(y_trj) 716 | return action, h_trj, y_trj, actions, actione 717 | 718 | def update_target(self): 719 | self.qnet_target.load_state_dict(self.qnet.state_dict()) 720 | self.choice_net_target.load_state_dict(self.choice_net.state_dict()) 721 | self.act_net_target.load_state_dict(self.act_net.state_dict()) 722 | self.alpha_target = self.alpha 723 | 724 | 725 | # def plot(frame_idx, rewards, td_losses, rec_losses): 726 | # # clear_output(True) 727 | # # plt.figure(figsize=(20,5)) 728 | # # plt.subplot(121) 729 | # # plt.title('frame %s. reward: %s' % (frame_idx, np.mean(rewards[-10:]))) 730 | # # plt.plot(rewards) 731 | # # plt.subplot(122) 732 | # # plt.title('loss') 733 | # # plt.plot(losses) 734 | # # plt.show() 735 | 736 | 737 | 738 | 739 | 740 | 741 | 742 | def run(args): 743 | epsilon_start = args.epsilon_start 744 | epsilon_final = args.epsilon_final 745 | epsilon_decay = args.epsilon_decay 746 | 747 | epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * math.exp( 748 | -1. * frame_idx / epsilon_decay) 749 | 750 | batch_size = args.batch_size 751 | args.replay_buffer = ReplayBuffer(args.replay_size) 752 | env_id = args.task 753 | 754 | 755 | if "world" in args.task2: 756 | env = gym.make(env_id) 757 | 758 | env=img_featurize.FrameStack(env, 4) 759 | env = img_featurize.ImageToPyTorch(env) 760 | else: 761 | env = gym.make(env_id) 762 | 763 | 764 | if args.mode=="train": 765 | log_dir = f'./log/{args.task}{args.task2}new/{args.model_name}' \ 766 | f'k{args.k}n{args.memory_size}i{args.write_interval}w{args.num_warm_up}h{args.hidden_size}' \ 767 | f'u{args.update_interval}b{args.replay_size}l{args.insert_size}a{args.fix_alpha}' \ 768 | f'm{args.min_reward}q{args.qnet_size}-{args.run_id}/' 769 | print(log_dir) 770 | configure(log_dir) 771 | 772 | model = DQN_DTM(env, args) 773 | num_params = 0 774 | for p in model.parameters(): 775 | num_params += p.data.view(-1).size(0) 776 | print(f"no params {num_params}") 777 | 778 | if USE_CUDA: 779 | model = model.cuda() 780 | 781 | model.update_target() 782 | 783 | num_frames = args.n_epochs 784 | optimizer_td = optim.Adam(model.parameters(), lr=args.lr) 785 | scheduler_td = StepLR(optimizer_td, step_size=500, gamma=args.decay) 786 | 787 | optimizer_reward = optim.Adam(model.reward_model.parameters()) 788 | 789 | optimizer_rec = optim.Adam(model.parameters(), lr=args.lr) 790 | 791 | td_losses = [] 792 | rec_losses = [] 793 | rec_l1s = [] 794 | rec_l2s = [] 795 | rec_l3s = [] 796 | 797 | all_rewards = [] 798 | episode_reward = mepisode_reward = 0 799 | traj_buffer = [] 800 | state_buffer = [] 801 | best_reward = -1000000000 802 | 803 | state = env.reset() 804 | ostate = state.copy() 805 | 806 | 807 | 808 | if "maze_img" in args.task2: 809 | state_map = {} 810 | img = env.render("rgb_array") 811 | if "view" in args.task2: 812 | img = img_featurize.get_viewport(img, state, 
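# Note: the image-based branches ("maze_img", "world") call an img_featurize module
# (CNN/CNN2, FrameStack, ImageToPyTorch, get_viewport, get_vector, get_image, get_image2)
# that is not imported at the top of this file, so it must be supplied separately.
# The MountainCar command in the README (--task2 mountaincar) never enters these branches,
# which is why that run works without it.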
num_state=(env.observation_space.high[0]+1)**2) 813 | if "cnn" not in args.task2: 814 | # Load the pretrained model 815 | stateim = img_featurize.get_vector(img) 816 | stateim = np.concatenate([stateim, state]) 817 | 818 | state_map[tuple(state.tolist())] = stateim 819 | state = state_map[tuple(state.tolist())] 820 | else: 821 | state_map[tuple(state.tolist())] = img_featurize.get_image(img) 822 | state = state_map[tuple(state.tolist())] 823 | 824 | if "world" in args.task2: 825 | state = img_featurize.get_image2(state) 826 | 827 | 828 | if "trap" in args.task2: 829 | trap = [0, 0] 830 | while trap[0] == 0 and trap[1] == 0 or trap[0] == env.observation_space.high[0] and trap[1] == \ 831 | env.observation_space.high[0]: 832 | trap = [random.randint(0, env.observation_space.high[0]), 833 | random.randint(0, env.observation_space.high[0])] 834 | print("trap ", trap) 835 | state = np.concatenate([state, trap]) 836 | 837 | episode_num = 0 838 | best_episode_reward = -1e9 839 | step = 0 840 | h_trj = model.trj_model.create_new_state(1) 841 | reward = old_reward = 0 842 | m_contr = [] 843 | s_contr = [] 844 | 845 | for frame_idx in tqdm(range(1, num_frames + 1)): 846 | 847 | epsilon = epsilon_by_frame(frame_idx) 848 | action, nh_trj, y_trj, action_s, action_e = model.act(state, h_trj, epsilon, r=reward, episode=episode_num) 849 | 850 | try: 851 | action = action.item() 852 | except: 853 | pass 854 | if action_e is not None: 855 | if action == action_e: 856 | m_contr.append(1) 857 | else: 858 | m_contr.append(0) 859 | 860 | if action == action_s: 861 | s_contr.append(1) 862 | else: 863 | s_contr.append(0) 864 | 865 | old_reward = reward 866 | next_state, reward, done, _ = env.step(action) 867 | reward = float(reward) 868 | 869 | 870 | 871 | 872 | if "maze" in args.task: 873 | if step > 1000 and "hard" in args.task2: 874 | done = 1 875 | if next_state[0] == ostate[0] and next_state[1] == ostate[1]: 876 | reward = reward - 1 877 | if "hard" in args.task2: 878 | next_state = env.reset() 879 | if "trap" in args.task2 and next_state[0]==trap[0] and next_state[1]==trap[1]: 880 | reward = reward - 2 881 | if "trap_key" in args.task2 and next_state[0]==1 and next_state[1]==1: 882 | trap = [-1, -1] 883 | print("free trap") 884 | 885 | #print(next_state,action, state, reward) 886 | 887 | 888 | m_reward = reward 889 | 890 | #if "world" in args.task2: 891 | # if reward<1.0: 892 | # reward=-0.01 893 | 894 | if args.rnoise>0: 895 | reward += np.random.normal(0, args.rnoise, 1)[0] 896 | if -1 -args.rnoise: 898 | reward = -reward 899 | 900 | if random.random()args.min_reward: 964 | if len(model.last_inserts)>=args.insert_size>0 and RRbest_episode_reward: 974 | best_episode_reward = episode_reward 975 | 976 | # model.best_trj.append((h_trj, episode_reward)) 977 | # if len(model.best_trj)>10: 978 | # model.best_trj = sorted(model.best_trj, key=lambda tup: tup[1]) 979 | # model.best_trj = model.best_trj[1:] 980 | # 981 | 982 | l2 = model.compute_reward_loss(h_trj, traj_buffer, optimizer_reward, args.batch_size_reward, 983 | noise=0) 984 | rec_l2s.append(l2.data.item()) 985 | 986 | if frame_idx < args.rec_period and args.rec_rate>0: 987 | loss, l1, l2, l3 = model.compute_rec_loss(h_trj, traj_buffer, optimizer_rec, args.batch_size_plan, 988 | noise=args.rec_noise) 989 | 990 | rec_losses.append(loss.data.item()) 991 | rec_l1s.append(l1.data.item()) 992 | rec_l3s.append(l3.data.item()) 993 | 994 | if add_mem==1: 995 | model.dnd.commit_insert() 996 | 997 | h_trj = model.trj_model.create_new_state(1) 998 | state = 
env.reset() 999 | if "maze_img" in args.task2: 1000 | state = state_map[tuple(state.tolist())] 1001 | 1002 | 1003 | all_rewards.append(mepisode_reward) 1004 | #print("episode reward", episode_reward) 1005 | 1006 | episode_reward = mepisode_reward = 0 1007 | del traj_buffer 1008 | del state_buffer 1009 | 1010 | traj_buffer = [] 1011 | state_buffer = [] 1012 | episode_num+=1 1013 | step=0 1014 | if "random" in args.task2: 1015 | env.close() 1016 | del env 1017 | env = gym.make(env_id) 1018 | state = env.reset() 1019 | ostate = state.copy() 1020 | if "maze_img" in args.task2: 1021 | state_map = {} 1022 | img = env.render("rgb_array") 1023 | if "view" in args.task2: 1024 | img = img_featurize.get_viewport(img, state, num_state=(env.observation_space.high[0]+1)**2) 1025 | if "cnn" not in args.task2: 1026 | # Load the pretrained model 1027 | stateim = img_featurize.get_vector(img) 1028 | stateim = np.concatenate([stateim, state]) 1029 | 1030 | state_map[tuple(state.tolist())] = stateim 1031 | state = state_map[tuple(state.tolist())] 1032 | else: 1033 | state_map[tuple(state.tolist())] = img_featurize.get_image(img) 1034 | state = state_map[tuple(state.tolist())] 1035 | 1036 | if "world" in args.task2: 1037 | state = img_featurize.get_image2(state) 1038 | 1039 | if "trap" in args.task2: 1040 | trap = [0, 0] 1041 | while trap[0] == 0 and trap[1] == 0 or trap[0] == env.observation_space.high[0] and trap[1] == \ 1042 | env.observation_space.high[0]: 1043 | trap = [random.randint(0, env.observation_space.high[0]), 1044 | random.randint(0, env.observation_space.high[0])] 1045 | print("trap ", trap) 1046 | state = np.concatenate([state, trap]) 1047 | log_value('Reward/episode reward', np.mean(all_rewards[-args.num_avg_reward:]), episode_num) 1048 | 1049 | 1050 | else: 1051 | if args.write_interval<0: 1052 | state_buffer.append(([state.copy()], episode_reward, step, action)) 1053 | 1054 | elif frame_idx % args.write_interval == 0 and len(traj_buffer) > 0: 1055 | #if random.random() < args.rec_rate: 1056 | l2 = model.compute_reward_loss(h_trj, traj_buffer, optimizer_reward, args.batch_size_reward, 1057 | noise=0) 1058 | rec_l2s.append(l2.data.item()) 1059 | if random.random() < args.rec_rate and frame_idx0 and frame_idx % args.td_interval == 0 and frame_idx > args.td_start and len(args.replay_buffer) > batch_size: 1070 | loss = model.compute_td_loss(optimizer_td, batch_size, episode_num) 1071 | scheduler_td.step() 1072 | td_losses.append(loss.data.item()) 1073 | 1074 | 1075 | 1076 | # if frame_idx % 2000 == 0 and frame_idx > 0: 1077 | # for param_group in optimizer_rec.param_groups: 1078 | # param_group['lr'] = param_group['lr'] / 2 1079 | # for param_group in optimizer_td.param_groups: 1080 | # param_group['lr'] = param_group['lr'] / 2 1081 | 1082 | if frame_idx % args.update_interval == 0: 1083 | model.update_target() 1084 | 1085 | if frame_idx % args.plot_interval == 0: 1086 | #print(optimizer_td.param_groups[0]['lr']) 1087 | log_value('Mem/num stored', model.dnd.get_mem_size(), int(frame_idx)) 1088 | # log_value('Mem/alpha', model.alpha.detach().item(), int(frame_idx)) 1089 | log_value('Mem/min last', model.get_pivot_lastinsert(), int(frame_idx)) 1090 | log_value('Mem/contrib', np.mean(m_contr), int(frame_idx)) 1091 | log_value('Mem/scontrib', np.mean(s_contr), int(frame_idx)) 1092 | 1093 | log_value('Loss/rec loss', np.mean(rec_losses), int(frame_idx)) 1094 | log_value('Loss/rec l1 loss', np.mean(rec_l1s), int(frame_idx)) 1095 | log_value('Loss/rec l2 loss', np.mean(rec_l2s), int(frame_idx)) 1096 | 
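# compute_td_loss (called in the loop above) uses the standard one-step target with a frozen copy
# of the networks: y = r + gamma * max_a' Q_target(s', a') * (1 - done), and qnet_target /
# choice_net_target are refreshed from the online weights every --update_interval frames via
# model.update_target(). Sketch of the target computation (illustrative only):
#     with torch.no_grad():
#         next_q, _, _ = model(next_state, (nhx, ncx), target=1)
#         y = reward + args.gamma * next_q.max(1)[0] * (1 - done)
#     loss = (q_value - y).pow(2).mean()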
log_value('Loss/rec l3 loss', np.mean(rec_l3s), int(frame_idx)) 1097 | log_value('Loss/td loss', np.mean(td_losses), int(frame_idx)) 1098 | log_value('Episode/num episode', episode_num, int(frame_idx)) 1099 | 1100 | currw = np.mean(all_rewards[-args.num_avg_reward:]) 1101 | 1102 | log_value('Reward/frame reward', currw, int(frame_idx)) 1103 | print(f'episode {episode_num} step {frame_idx} avg rewards {currw} vs best {best_reward}') 1104 | 1105 | if best_reward=currw and frame_idx>num_frames-2* args.plot_interval: 1106 | best_reward = currw 1107 | if os.path.isdir(args.save_model) is False: 1108 | os.mkdir(args.save_model) 1109 | save_dir = os.path.join(args.save_model, args.task) 1110 | if os.path.isdir(save_dir) is False: 1111 | os.mkdir(save_dir) 1112 | with open(os.path.join(save_dir, f'args.jon'), 'w') as fp: 1113 | sa = copy.copy(args) 1114 | sa.device = None 1115 | sa.replay_buffer = None 1116 | json.dump(vars(sa), fp) 1117 | save_dir = os.path.join(save_dir, f"{args.model_name}.pt") 1118 | torch.save(model.state_dict(), save_dir) 1119 | print(f"save model to {save_dir}!") 1120 | 1121 | with open(save_dir + 'mem', 'wb') as output: # Overwrites any existing file. 1122 | pickle.dump(model.dnd, output) 1123 | print(f"save memory to {save_dir + 'mem'}!") 1124 | 1125 | if len(all_rewards) > args.num_avg_reward*1.1: 1126 | all_rewards = all_rewards[-args.num_avg_reward:] 1127 | td_losses = [] 1128 | rec_losses = [] 1129 | rec_l1s = [] 1130 | rec_l2s = [] 1131 | rec_l3s = [] 1132 | m_contr = [] 1133 | s_contr = [] 1134 | 1135 | if episode_num > args.max_episode > 0: 1136 | break 1137 | 1138 | 1139 | 1140 | if __name__ == "__main__": 1141 | parser = ArgumentParser(description="Training script for TEM.") 1142 | parser.add_argument("--mode", default="train", 1143 | help="train or test") 1144 | parser.add_argument("--save_model", default="./model/", 1145 | help="save model dir") 1146 | parser.add_argument("--rnoise", type=float, default=0, 1147 | help="add noise to reward") 1148 | parser.add_argument("--pnoise", type=float, default=0, 1149 | help="add noise to state") 1150 | parser.add_argument("--task", default="MiniWorld-CollectHealth-v0", 1151 | help="task name") 1152 | parser.add_argument("--task2", default="miniworld", 1153 | help="task name") 1154 | parser.add_argument("--render", type=int, default=0, 1155 | help="render or not") 1156 | parser.add_argument("--n_epochs", type=int, default=1000000, 1157 | help="number of epochs") 1158 | parser.add_argument("--max_episode", type=int, default=10000, 1159 | help="number of episode allowed") 1160 | parser.add_argument("--model_name", default="DTM", 1161 | help="DTM") 1162 | parser.add_argument("--lr", type=float, default=0.0005, 1163 | help="lr") 1164 | parser.add_argument("--decay", type=float, default=1, 1165 | help=" decay lr") 1166 | parser.add_argument("--clip", type=float, default=10, 1167 | help="clip gradient") 1168 | parser.add_argument("--epsilon_start", type=float, default=1.0, 1169 | help="exploration start") 1170 | parser.add_argument("--epsilon_final", type=float, default=0.01, 1171 | help="exploration final") 1172 | parser.add_argument("--epsilon_decay", type=float, default=500.0, 1173 | help="exploration decay") 1174 | parser.add_argument("--gamma", type=float, default=0.99, 1175 | help="gamma") 1176 | parser.add_argument("--batch_size", type=int, default=32, 1177 | help="batch size value model") 1178 | parser.add_argument("--batch_size_plan", type=int, default=4, 1179 | help="batch size planning model") 1180 | 
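# The exploration schedule built in run() from the three epsilon arguments above is
#     epsilon(frame) = epsilon_final + (epsilon_start - epsilon_final) * exp(-frame / epsilon_decay).
# With the defaults (1.0 -> 0.01, decay 500) this gives roughly
#     epsilon(0) = 1.00, epsilon(500) ~ 0.37, epsilon(2500) ~ 0.017,
# so the agent is nearly greedy after a few thousand frames; raise --epsilon_decay to explore longer.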
parser.add_argument("--batch_size_reward", type=int, default=32, 1181 | help="batch size planning model") 1182 | parser.add_argument("--reward_hidden_size", type=int, default=32, 1183 | help="batch size planning model") 1184 | parser.add_argument("--replay_size", type=int, default=1000000, 1185 | help="replay buffer size") 1186 | parser.add_argument("--qnet_size", type=int, default=128, 1187 | help="MLP hidden") 1188 | parser.add_argument("--hidden_size", type=int, default=16, 1189 | help="RNN hidden") 1190 | parser.add_argument("--mem_dim", type=int, default=5, 1191 | help="memory dimesntion") 1192 | parser.add_argument("--memory_size", type=int, default=10000, 1193 | help="memory size") 1194 | parser.add_argument("--insert_size", type=int, default=-1, 1195 | help="insert size") 1196 | parser.add_argument("--k", type=int, default=5, 1197 | help="num neighbor") 1198 | parser.add_argument("--k_write", type=int, default=-1, 1199 | help="num neighbor") 1200 | parser.add_argument("--mem_mode", type=int, default=0, 1201 | help="memory_mode") 1202 | parser.add_argument("--pread", type=float, default=0.7, 1203 | help="minimum reward of env") 1204 | parser.add_argument("--min_reward", type=float, default=-1e8, 1205 | help="minimum reward of env") 1206 | parser.add_argument("--write_interval", type=int, default=10, 1207 | help="interval for memory writing") 1208 | parser.add_argument("--write_lr", type=float, default=.5, 1209 | help="learning rate of writing") 1210 | parser.add_argument("--fix_alpha", type=float, default=-1, 1211 | help="fix alpha") 1212 | parser.add_argument("--bstr_rate", type=float, default=0.1, 1213 | help="learning rate of writing") 1214 | parser.add_argument("--td_interval", type=int, default=1, 1215 | help="interval for td update") 1216 | parser.add_argument("--td_start", type=int, default=-1, 1217 | help="interval for td update") 1218 | parser.add_argument("--rec_rate", type=float, default=0.1, 1219 | help="rate of reconstruction learning") 1220 | parser.add_argument("--rec_noise", type=float, default=0.1, 1221 | help="rate of reconstruction learning") 1222 | parser.add_argument("--rec_type", type=str, default="mem", 1223 | help="rec type") 1224 | parser.add_argument("--rec_period", type=int, default=1e40, 1225 | help="period of reconstruction learning") 1226 | parser.add_argument("--update_interval", type=int, default=100, 1227 | help="interval for update target Q network") 1228 | parser.add_argument("--plot_interval", type=int, default=200, 1229 | help="interval for plotting") 1230 | parser.add_argument("--num_avg_reward", type=int, default=10, 1231 | help="interval for plotting") 1232 | parser.add_argument("--num_warm_up", type=int, default=-1, 1233 | help="number of episode warming up memory") 1234 | parser.add_argument("--run_id", default="no_td", 1235 | help="r1,r2,r3") 1236 | 1237 | global args 1238 | args = parser.parse_args() 1239 | #import sys 1240 | #with open(f'./log_screen_mem{args.model_name}.txt', 'w') as f: 1241 | if args.k_write<0: 1242 | args.k_write=args.k 1243 | run(args) -------------------------------------------------------------------------------- /dqn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 15, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import math, random\n", 10 | "\n", 11 | "import gym\n", 12 | "import numpy as np\n", 13 | "\n", 14 | "import torch\n", 15 | "import torch.nn as nn\n", 16 | "import torch.optim as 
optim\n", 17 | "import torch.autograd as autograd \n", 18 | "import torch.nn.functional as F" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 16, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from IPython.display import clear_output\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "%matplotlib inline\n", 30 | "import os\n", 31 | "\n", 32 | "os.environ['KMP_DUPLICATE_LIB_OK']='True'" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "

Use Cuda

" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 17, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "USE_CUDA = torch.cuda.is_available()\n", 49 | "Variable = lambda *args, **kwargs: autograd.Variable(*args, **kwargs).cuda() if USE_CUDA else autograd.Variable(*args, **kwargs)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "

Replay Buffer

" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 18, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "from collections import deque\n", 66 | "\n", 67 | "class ReplayBuffer(object):\n", 68 | " def __init__(self, capacity):\n", 69 | " self.buffer = deque(maxlen=capacity)\n", 70 | " \n", 71 | " def push(self, state, action, reward, next_state, done):\n", 72 | " state = np.expand_dims(state, 0)\n", 73 | " next_state = np.expand_dims(next_state, 0)\n", 74 | " \n", 75 | " self.buffer.append((state, action, reward, next_state, done))\n", 76 | " \n", 77 | " def sample(self, batch_size):\n", 78 | " state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))\n", 79 | " return np.concatenate(state), action, reward, np.concatenate(next_state), done\n", 80 | " \n", 81 | " def __len__(self):\n", 82 | " return len(self.buffer)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "

Cart Pole Environment

" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 19, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "env_id = \"MountainCar-v0\"\n", 99 | "env = gym.make(env_id)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "

Epsilon greedy exploration

" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 20, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "epsilon_start = 1.0\n", 116 | "epsilon_final = 0.01\n", 117 | "epsilon_decay = 500\n", 118 | "\n", 119 | "epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * math.exp(-1. * frame_idx / epsilon_decay)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 21, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "text/plain": [ 130 | "[]" 131 | ] 132 | }, 133 | "execution_count": 21, 134 | "metadata": {}, 135 | "output_type": "execute_result" 136 | }, 137 | { 138 | "data": { 139 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZ9UlEQVR4nO3dfXQd9X3n8ff3XulK1pP1aNmWbWQb24lJAgaFmIeT0IQQYFvc3W0b3GVDSFL6RFua7u6Bk57Qsv80TdvdZuOS0DakTSgOoWniEojbJpA2FFMLHMAPGGSDbQkbyY+yLcuypO/+MSNzLWTr2rrSaGY+r3Pu0cxvRvd+RyN/PPrNb2bM3RERkfjLRF2AiIgUhwJdRCQhFOgiIgmhQBcRSQgFuohIQpRE9cGNjY3e2toa1ceLiMTS888/v9/dm8ZaFlmgt7a20t7eHtXHi4jEkpntOtsydbmIiCSEAl1EJCEU6CIiCaFAFxFJCAW6iEhCjBvoZvY1M+s2s81nWW5m9iUz6zCzl8zs8uKXKSIi4ynkCP3rwI3nWH4TsCR83Qk8MPGyRETkfI0b6O7+r8DBc6yyCvhbD2wAas1sTrEKHG3jGwf5wg9eQbf9FRE5UzH60FuAPXnznWHbO5jZnWbWbmbtPT09F/RhL+45zANP76D3xOAFfb+ISFJN6UlRd3/Q3dvcva2pacwrV8fVUJUD4MDxk8UsTUQk9ooR6F3A/Lz5eWHbpKivLAPg4PGByfoIEZFYKkagrwM+EY52WQkccfe9RXjfMTVUjhyhK9BFRPKNe3MuM3sEuA5oNLNO4D6gFMDdvwI8AdwMdAB9wB2TVSxAfRjoOkIXETnTuIHu7qvHWe7AbxatonEo0EVExha7K0XLS7NU5rIcOKZAFxHJF7tAB6irzHFQo1xERM4Qy0BvqMzppKiIyCixDPT6ypz60EVERolpoJcp0EVERolloDdUBUfoup+LiMjbYhno9ZU5Tg4O0zcwFHUpIiLTRmwDHTQWXUQkXywDXZf/i4i8UywD/e0jdI1FFxEZEctAbwjvuKirRUVE3hbLQK+vUh+6iMhosQz0ylyWXElGgS4ikieWgW5muvxfRGSUWAY66PJ/EZHRYh3oOkIXEXlbbAO9QbfQFRE5Q3wDvapMwxZFRPLENtAbq8roGxiib2Aw6lJERKaF2AZ6U3VwcdH+ozpKFxGBBAR6z7H+iCsREZkeYhvojeHVoj1HdWJURARiHOinj9AV6CIiQIwDvaGyjIxBj0a6iIgAMQ70bMaoryzTEbqISCi2gQ5BP7oCXUQkEOtAb6ouo+eYAl1EBBIQ6Pt1hC4iAsQ90KuCI3R3j7oUEZHIxTvQq8sYGBymt1+X/4uIxD7QAfarH11EpLBAN7MbzWy7mXWY2T1jLF9gZk+Z2SYze8nMbi5+qe/UVKWLi0RERowb6GaWBdYANwHLgdVmtnzUar8PPOruK4Bbgb8odqFjadTVoiIipxVyhH4l0OHuO919AFgLrBq1jgM14fRM4M3ilXh2OkIXEXlbIYHeAuzJm+8M2/L9AXCbmXUCTwC/NdYbmdmdZtZuZu09PT0XUO6ZZs4opTRr6kMXEaF4J0VXA19393nAzcA3zOwd7+3uD7p7m7u3NTU1TfhDMxmjQZf/i4gAhQV6FzA/b35e2Jbv08CjAO7+LFAONBajwPHoalERkUAhgb4RWGJmC80sR3DSc92odXYDHwEws3cTBPrE+1QK0FRdRnevAl1EZNxAd/dB4C5gPbCNYDTLFjO738xuCVf7PeBXzOxF4BHgkz5Fl28215TRrS4XERFKClnJ3Z8gONmZ3/b5vOmtwDXFLa0wzTXlHDh+klNDw5RmY32dlIjIhMQ+AZtrynFHR+kiknqxD/TZNeUA7Duih0WLSLrFPtCbw0B/q1eBLiLpFvtAnz1TR+giIpCAQK+rKCVXktERuoikXuwD3cxoriljnwJdRFIu9oEOwYlRHaGLSNolItCba8p5S1eLikjKJSLQZ9eUs+9Iv54tKiKplohAb64p58SpIT1bVERSLRmBPlNj0UVEEhHoulpURCRpga4jdBFJsUQE+qya4Nmi3Qp0EUmxRAR6eWmWuopSHaGLSKolItAhGOmiPnQRSbPEBHpL7QzePKxAF5H0Skygz62dQdfhE1GXISISmUQF+pETpzh2UhcXiUg6JSbQW+pmALBXR+kiklLJCfTaYCx6pwJdRFIqMYE+tzY4Qn9TgS4iKZWYQJ9VXU5Jxug6pEAXkXRKTKBnM8bsmeU6QheR1EpMoIPGootIuiUu0DUWXUTSKlGBPrd2Bvt6+xkcGo66FBGRKZeoQG+pm8HQsPPWUT1fVETSJ1GBrqGLIpJmiQr0kYuLNHRRRNIoUYE+coSuE6MikkYFBbqZ3Whm282sw8zuOcs6v2RmW81si5n9XXHLLExFroS6ilI6dYQuIilUMt4KZpYF1gAfBTqBjWa2zt235q2zBLgXuMbdD5nZrMkqeDwL6ivoPNQX1ceLiESmkCP0K4EOd9/p7gPAWmDVqHV+BVjj7ocA3L27uGUWbn59BbsPKtBFJH0KCfQWYE/efGfYlm8psNTMnjGzDWZ241hvZGZ3mlm7mbX39PRcWMXjuKihgq5DJzQWXURSp1gnRUuAJcB1wGrgL82sdvRK7v6gu7e5e1tTU1ORPvpMC+orGBx29ur5oiKSMoUEehcwP29+XtiWrxNY5+6n3P114FWCgJ9yC+orAdh1QN0uIpIuhQT6RmCJmS00sxxwK7Bu1DrfJTg6x8waCbpgdhavzMItaKgAUD+6iKTOuIHu7oPAXcB6YBvwqL
tvMbP7zeyWcLX1wAEz2wo8BfxPdz8wWUWfy+yacnLZDLsOHo/i40VEIjPusEUAd38CeGJU2+fzph34bPiKVDZjzKubwR4doYtIyiTqStERCxo0dFFE0ieZgV5fwa4DfQR/OIiIpENiA/1o/yBHTpyKuhQRkSmT2EAHDV0UkXRJZqBr6KKIpFAyA/30EbqGLopIeiQy0CtyJcyuKWfnfgW6iKRHIgMdYFFTJTt7FOgikh6JDfSFjZXs7DmmoYsikhqJDfRFTVX09g9y4PhA1KWIiEyJBAd6cNdFdbuISFokNtAXN1YB8Pr+YxFXIiIyNRIb6C11M8iVZHSELiKpkdhAz2aM1oYKdijQRSQlEhvoAIsaq9ipLhcRSYlkB3pTJbsP9HFKD4wWkRRIeKBXMTjsetiFiKRCwgNdQxdFJD0SHeiLm4Khi691qx9dRJIv0YE+c0Ypc2aW8+pbR6MuRURk0iU60AGWNlezfZ8CXUSSL/GBvmx2NR09xxjUSBcRSbjkB3pzNQODw+zSSBcRSbjkB/rsagB1u4hI4iU+0C+eVYWZAl1Eki/xgV5emqW1oVIjXUQk8RIf6ABLm6vYrkAXkYRLRaAva67mjf3H6T81FHUpIiKTJhWBvnR2NcMOHbpiVEQSLBWB/u45NQBs29sbcSUiIpMnFYG+sKGSylyWLW8q0EUkuQoKdDO70cy2m1mHmd1zjvX+q5m5mbUVr8SJy2SM5XNr2Nx1JOpSREQmzbiBbmZZYA1wE7AcWG1my8dYrxr4HeC5YhdZDJfMncnWvb0MDXvUpYiITIpCjtCvBDrcfae7DwBrgVVjrPe/gS8A/UWsr2je0zKTvoEhXt+ve6OLSDIVEugtwJ68+c6w7TQzuxyY7+7fP9cbmdmdZtZuZu09PT3nXexEvKclODG65U11u4hIMk34pKiZZYA/A35vvHXd/UF3b3P3tqampol+9HlZ3FRFriSjfnQRSaxCAr0LmJ83Py9sG1ENvAd42szeAFYC66bbidHSbIZ3z65mc5dGuohIMhUS6BuBJWa20MxywK3AupGF7n7E3RvdvdXdW4ENwC3u3j4pFU/AJS0z2fzmEdx1YlREkmfcQHf3QeAuYD2wDXjU3beY2f1mdstkF1hM75k7k6P9g+zWvdFFJIFKClnJ3Z8AnhjV9vmzrHvdxMuaHJfNrwXgp3sOc1FDZbTFiIgUWSquFB2xtLmKilyWF3YdiroUEZGiS1Wgl2QzXDqvlhd2H466FBGRoktVoAOsWFDLtr29nBjQrXRFJFlSF+iXL6hjcNh5WePRRSRhUhfoly2oBWDTbvWji0iypC7QG6vKuKihghcU6CKSMKkLdIAV84MTo7rASESSJJWBfsVFdfQcPcmegyeiLkVEpGhSGegrFzUAsGHngYgrEREpnlQG+sWzqmisyvGsAl1EEiSVgW5mfGBRAxt2HlA/uogkRioDHYJul71H+nWjLhFJjNQG+lWL6gH1o4tIcqQ20Bc3Bf3oG3YejLoUEZGiSG2gj/SjP7tD/egikgypDXSAay9uZF9vP691H4u6FBGRCUt1oH9oafCg6h9v74m4EhGRiUt1oM+tncHS5iqefrU76lJERCYs1YEOcN2yWWx8/RDHTw5GXYqIyISkPtA/tLSJgaFhnt2h4YsiEm+pD/S21joqcll1u4hI7KU+0MtKsly9uIGnXunR8EURibXUBzrADctn03X4BFve7I26FBGRC6ZAB65f3kzG4Aeb90VdiojIBVOgA/WVOT6wsIEnN++NuhQRkQumQA/d9N7Z7Og5Tkf30ahLERG5IAr00A3LZwPw5MvqdhGReFKgh2bPLGfFglq+/7K6XUQknhToeX7+shZe2XeUbXs12kVE4keBnufnLp1LScb4h01dUZciInLeFOh56itzXLdsFt/d1MXQsC4yEpF4KSjQzexGM9tuZh1mds8Yyz9rZlvN7CUz+6GZXVT8UqfGf7m8he6jJ3mmY3/UpYiInJdxA93MssAa4CZgObDazJaPWm0T0Obu7wMeA/642IVOlQ+/axY15SU89nxn1KWIiJyXQo7QrwQ63H2nuw8Aa4FV+Su4+1Pu3hfObgDmFbfMqVNemuU/r2jhB5v3ceDYyajLEREpWCGB3gLsyZvvDNvO5tPAk2MtMLM7zazdzNp7eqbvU4JuW3kRA0PDPNquo3QRiY+inhQ1s9uANuCLYy139wfdvc3d25qamor50UW1pLmalYvqefi5XTo5KiKxUUigdwHz8+bnhW1nMLPrgc8Bt7h77Psq/vvKVjoPneDHuk+6iMREIYG+EVhiZgvNLAfcCqzLX8HMVgBfJQjzRCTgDZc0M6u6jIeeeSPqUkRECjJuoLv7IHAXsB7YBjzq7lvM7H4zuyVc7YtAFfBtM/upma07y9vFRmk2w6euXci/vbaflzuPRF2OiMi4LKqn9LS1tXl7e3skn12oo/2nuPqPfsS1FzfywG1XRF2OiAhm9ry7t421TFeKnkN1eSmfuOoifrBlHx3dx6IuR0TknBTo47jjmoXkshn+4umOqEsRETknBfo4GqvKuP3qVv5hUxfb9+nhFyIyfSnQC/Ab1y2mqqyEL65/JepSRETOSoFegNqKHL/2ocX8y7ZuNr5xMOpyRETGpEAv0KeuWUhzTRl/+I9bdPWoiExLCvQCzchl+f3/tJzNXb18c8OuqMsREXkHBfp5+Nn3zeHaixv5k/Xb6e7tj7ocEZEzKNDPg5lx/6pLODk4zH3rthDVRVkiImNRoJ+nRU1V3P3RJTy5eR/feUHPHhWR6UOBfgF+9YOLubK1nvvWbWHPwb7xv0FEZAoo0C9ANmP86S9dCsBvr93EycGhiCsSEVGgX7D59RV88Rfex6bdh7nve+pPF5HoKdAn4Kb3zuE3f2Yxazfu4ZvP7Y66HBFJuZKoC4i7z350Gdv2HuW+721mVnUZH7tkdtQliUhK6Qh9grIZ48u/vIJL59fyW49s4t937I+6JBFJKQV6EVTkSnjok++ntaGCz/xNO890KNRFZOop0IuktiLHNz/zAebXVXDHQxv5py37oi5JRFJGgV5Es6rL+davrmT53Bp+/eEX+NpPXtfoFxGZMgr0IqutyPHwZz7AR941i/sf38r/+PZL9J/SOHURmXwK9ElQWVbCV267gruvX8Lfv9DJqi8/w5Y3j0RdlogknAJ9kmQyxt3XL+WhO97Pwb4Bfn7NM/y/H76mq0pFZNIo0CfZzyybxT/d/UFuuGQ2f/rPr/Kx//Ov/OiVt6IuS0QSSIE+Beoqc6z55cv5+h3vJ5MxPvX1dlY/uIFndxyIujQRSRCLahRGW1ubt7e3R/LZURoYHOabG3bxwI930HP0JFe21nPHNa1cv7yZ0qz+fxWRczOz5929bcxlCvRo9J8aYu1/7OYv/+11ug6fYFZ1GR9//3xWXTaXi2dVR12eiExTCvRpbGjYeeqVbh5+bhdPv9qDOyxrrubm987hw++axSVza8hkLOoyRWSaUKDHxFu9/Tz58l6+//Je2ncdwh3qKkq5enEjVy1u4LL5t
SybXa2uGZEUU6DHUPfRfv694wA/6djPT17bz77wodS5kgyXzK3hvS0zuXhWFYubgldzTRlmOpIXSToFesy5O3sOnuDFzsO81HmYF/ccYeveXo6dHDy9TlVZCfPqZjC3dgZza8uZMzP42lxdTl1ljvrKHLUVpZSVZCPcEhGZqHMFuu6HHgNmxoKGChY0VPBzl84FgpDvPnqSHd3H6Og5xo7uY3QeOsGbR/p5YfchDvedGvO9qspKqKsspa4iR2WuhMqyEirLssHX3MjXEirKspSVZCnNGmUlGXIlGXLZYD4XzpeVZCjNBq9sxsiYkc0YWTMyGU63jbRnDP0VITKJCgp0M7sR+HMgC/yVu//RqOVlwN8CVwAHgI+7+xvFLVXymRnNNeU015Rz9cWN71jeNzDI3iP9vNXbz+G+Uxw8PsCh4wMc6jvFob4BDh4f4PjJQboOn6BvYJDjJwc5fnKIE5N835mM8Y7wtzDoR7Lewu0bif6g3U5P57fbmO2W933nXm/a/fcyzQqaZuVMuwOCC63mtz+y5PTBWTGNG+hmlgXWAB8FOoGNZrbO3bfmrfZp4JC7X2xmtwJfAD5e9GqlYBW5ktP96+djaNjpGxikb2CIgcFhTg4Oc2pomIHBYQZGvo5qPzU0zJA7w8PO0LAz5ATTHsy7O0PDvL3OGes67py+K6VDME847zDSKRisktceLnA8b/rM7+eM7/cz3mu63Qdzut2Zc3pVw7QryCdQ0MwZpUWs5G2FHKFfCXS4+04AM1sLrALyA30V8Afh9GPAl83MfLr9hsq4shmjuryU6vLJ+YUTkclTyPi3FmBP3nxn2DbmOu4+CBwBGka/kZndaWbtZtbe09NzYRWLiMiYpnRAs7s/6O5t7t7W1NQ0lR8tIpJ4hQR6FzA/b35e2DbmOmZWAswkODkqIiJTpJBA3wgsMbOFZpYDbgXWjVpnHXB7OP0LwI/Ufy4iMrXGPSnq7oNmdhewnmDY4tfcfYuZ3Q+0u/s64K+Bb5hZB3CQIPRFRGQKFTQO3d2fAJ4Y1fb5vOl+4BeLW5qIiJwP3eVJRCQhFOgiIgkR2c25zKwH2HWB394I7C9iOXGgbU4HbXM6TGSbL3L3Mcd9RxboE2Fm7We721hSaZvTQducDpO1zepyERFJCAW6iEhCxDXQH4y6gAhom9NB25wOk7LNsexDFxGRd4rrEbqIiIyiQBcRSYjYBbqZ3Whm282sw8zuibqeC2Vm883sKTPbamZbzOx3wvZ6M/tnM3st/FoXtpuZfSnc7pfM7PK897o9XP81M7v9bJ85XZhZ1sw2mdnj4fxCM3su3LZvhTeBw8zKwvmOcHlr3nvcG7ZvN7OPRbQpBTGzWjN7zMxeMbNtZnZV0vezmf1u+Hu92cweMbPypO1nM/uamXWb2ea8tqLtVzO7wsxeDr/nS2YFPH/P3WPzIrg52A5gEZADXgSWR13XBW7LHODycLoaeBVYDvwxcE/Yfg/whXD6ZuBJgscYrgSeC9vrgZ3h17pwui7q7Rtn2z8L/B3weDj/KHBrOP0V4NfD6d8AvhJO3wp8K5xeHu77MmBh+DuRjXq7zrG9fwN8JpzOAbVJ3s8ED7x5HZiRt38/mbT9DHwQuBzYnNdWtP0K/Ee4roXfe9O4NUX9QznPH+BVwPq8+XuBe6Ouq0jb9j2C57ZuB+aEbXOA7eH0V4HVeetvD5evBr6a137GetPtRXA//R8CHwYeD39Z9wMlo/cxwR0+rwqnS8L1bPR+z19vur0Ing3wOuEAhNH7L4n7mbefYFYf7rfHgY8lcT8DraMCvSj7NVz2Sl77Geud7RW3LpdCHocXO+GfmCuA54Bmd98bLtoHNIfTZ9v2uP1M/i/wv4DhcL4BOOzBowvhzPrP9mjDOG3zQqAHeCjsZvorM6skwfvZ3buAPwF2A3sJ9tvzJHs/jyjWfm0Jp0e3n1PcAj1xzKwK+HvgbnfvzV/mwX/NiRlXamY/C3S7+/NR1zKFSgj+LH/A3VcAxwn+FD8tgfu5juDB8QuBuUAlcGOkRUUgiv0at0Av5HF4sWFmpQRh/rC7fydsfsvM5oTL5wDdYfvZtj1OP5NrgFvM7A1gLUG3y58DtRY8uhDOrP9sjzaM0zZ3Ap3u/lw4/xhBwCd5P18PvO7uPe5+CvgOwb5P8n4eUaz92hVOj24/p7gFeiGPw4uF8Iz1XwPb3P3P8hblP87vdoK+9ZH2T4Rny1cCR8I/7dYDN5hZXXhkdEPYNu24+73uPs/dWwn23Y/c/b8BTxE8uhDeuc1jPdpwHXBrODpiIbCE4ATStOPu+4A9ZrYsbPoIsJUE72eCrpaVZlYR/p6PbHNi93OeouzXcFmvma0Mf4afyHuvs4v6pMIFnIS4mWBEyA7gc1HXM4HtuJbgz7GXgJ+Gr5sJ+g5/CLwG/AtQH65vwJpwu18G2vLe61NAR/i6I+ptK3D7r+PtUS6LCP6hdgDfBsrC9vJwviNcvijv+z8X/iy2U8DZ/4i39TKgPdzX3yUYzZDo/Qz8IfAKsBn4BsFIlUTtZ+ARgnMEpwj+Evt0Mfcr0Bb+/HYAX2bUifWxXrr0X0QkIeLW5SIiImehQBcRSQgFuohIQijQRUQSQoEuIpIQCnQRkYRQoIuIJMT/B3B4SePGsjO/AAAAAElFTkSuQmCC\n", 140 | "text/plain": [ 141 | "
" 142 | ] 143 | }, 144 | "metadata": { 145 | "needs_background": "light" 146 | }, 147 | "output_type": "display_data" 148 | } 149 | ], 150 | "source": [ 151 | "plt.plot([epsilon_by_frame(i) for i in range(10000)])" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "

Deep Q Network

" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 22, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "class DQN(nn.Module):\n", 168 | " def __init__(self, num_inputs, num_actions):\n", 169 | " super(DQN, self).__init__()\n", 170 | " \n", 171 | " self.layers = nn.Sequential(\n", 172 | " nn.Linear(env.observation_space.shape[0], 128),\n", 173 | " nn.ReLU(),\n", 174 | " nn.Linear(128, 128),\n", 175 | " nn.ReLU(),\n", 176 | " nn.Linear(128, env.action_space.n)\n", 177 | " )\n", 178 | " \n", 179 | " def forward(self, x):\n", 180 | " return self.layers(x)\n", 181 | " \n", 182 | " def act(self, state, epsilon):\n", 183 | " if random.random() > epsilon:\n", 184 | " state = Variable(torch.FloatTensor(state).unsqueeze(0), volatile=True)\n", 185 | " q_value = self.forward(state)\n", 186 | " action = q_value.max(1)[1].data[0]\n", 187 | " else:\n", 188 | " action = random.randrange(env.action_space.n)\n", 189 | " return action" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 23, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "model = DQN(env.observation_space.shape[0], env.action_space.n)\n", 199 | "\n", 200 | "if USE_CUDA:\n", 201 | " model = model.cuda()\n", 202 | " \n", 203 | "optimizer = optim.Adam(model.parameters())\n", 204 | "\n", 205 | "replay_buffer = ReplayBuffer(1000)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "

Computing Temporal Difference Loss

" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 24, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "def compute_td_loss(batch_size):\n", 222 | " state, action, reward, next_state, done = replay_buffer.sample(batch_size)\n", 223 | "\n", 224 | " state = Variable(torch.FloatTensor(np.float32(state)))\n", 225 | " next_state = Variable(torch.FloatTensor(np.float32(next_state)), volatile=True)\n", 226 | " action = Variable(torch.LongTensor(action))\n", 227 | " reward = Variable(torch.FloatTensor(reward))\n", 228 | " done = Variable(torch.FloatTensor(done))\n", 229 | "\n", 230 | " q_values = model(state)\n", 231 | " next_q_values = model(next_state)\n", 232 | "\n", 233 | " q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)\n", 234 | " next_q_value = next_q_values.max(1)[0]\n", 235 | " expected_q_value = reward + gamma * next_q_value * (1 - done)\n", 236 | " \n", 237 | " loss = (q_value - Variable(expected_q_value.data)).pow(2).mean()\n", 238 | " \n", 239 | " optimizer.zero_grad()\n", 240 | " loss.backward()\n", 241 | " optimizer.step()\n", 242 | " \n", 243 | " return loss" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 25, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "def plot(frame_idx, rewards, losses):\n", 253 | " clear_output(True)\n", 254 | " plt.figure(figsize=(20,5))\n", 255 | " plt.subplot(131)\n", 256 | " plt.title('frame %s. reward: %s' % (frame_idx, np.mean(rewards[-10:])))\n", 257 | " plt.plot(rewards)\n", 258 | " plt.subplot(132)\n", 259 | " plt.title('loss')\n", 260 | " plt.plot(losses)\n", 261 | " plt.show()" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": {}, 267 | "source": [ 268 | "

Training

" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 26, 274 | "metadata": {}, 275 | "outputs": [ 276 | { 277 | "data": { 278 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAwQAAAE/CAYAAAD42QSlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAAsTAAALEwEAmpwYAABKCUlEQVR4nO3deZxcVZn/8c/TK4EgYYmILAYU9YeOokaEcUdUUEfQ0fmh/hSXGXSUUWec0aAzDDqioI4igigqAoogIkgkEWQJ+5qE7At09oSE7Et3p7u25/fHvdWpdKqqq7qq7q2+9X2/XvXqqrvVufdWV93nnuecY+6OiIiIiIi0pra4CyAiIiIiIvFRQCAiIiIi0sIUEIiIiIiItDAFBCIiIiIiLUwBgYiIiIhIC1NAICIiIiLSwhQQJJyZvczM5pjZLjP7YtzlkcYxs0+a2UNxl0NEJKnMbKWZnRZ3OUTqTQFB8n0VmOHuB7r7ZXEXZjgzu8rMlppZzsw+WWT+v5rZBjPbaWZXm1l3wbxJZjbDzPrNbMnwL+la1m1FZvYDM3smDB6XmNknhs0/0cxmhcdslpmdWDDPzOwSM9sSPi4xMyvzXh81s1Vm1mdmfzKzQxq4ayIiIlKGAoLkexGwsNRMM2uPsCzFzAU+D8wePsPM3g1MAd5BsB/HAd8sWOQG4CngUOAbwM1mNrHWdathZh3VrlMPDXrfPuDvgIOAc4Afm9nfhu/XBdwG/BY4GLgWuC2cDnAucBbwauBV4XY+W6LsrwB+DnwcOBzoB37agP0RERGRCiggSDAzuxd4O3C5mfWa2UvN7Bozu9LMpptZH/B2M3uvmT0V3klfY2YXFmxjkpm5mX0qnLfNzD5nZq83s3lmtt3MLh/2vp82s8Xhsnea2YtKldHdr3D3e4CBIrPPAX7l7gvdfRvwP8Anw/d4KfBa4L/dfbe7/xGYD/x9HdYd6biuNLOvmdk8oM/MOszsZDN7JDwec83sbeGybzez+QXr3mVmTxa8ftDMzgqfTzGzZeEd+kVm9oGC5T5pZg+b2Y/MbAtwoZkdamZTw/P2BPDiSspfirv/t7svcfecuz8OPAicEs5+G9ABXOrug2FtkwGnhvPPAf7X3de6+zrgfwmPdxEfA/7s7g+4ey/wX8AHzezAWsovIhIVM+s2s0vN7NnwcWm+FtrMDjOz28Pfg63h93xbOO9rZrYu/J5fambviHdPRAIKCBLM3U8luKg7z93Hu/vT4ayPAhcBBwIPEdwZ/gQwAXgv8M/5i9QCbwCOB/4vcCnBXfXTgFcA/2BmbwUwszOBrwMfBCaG73/DKHfhFQQ1CHlzgcPN7NBw3nJ33zVs/ivqsG4lPkJwrCYQ3OWeBnwbOAT4d+CPYY3DY8Dx4Q9EJ8Hd8xea2YFmNg6YTHCMAJYBbya4Q/9N4LdmdkTBe74BWB6+30XAFQSB1BHAp8PHkPAHaUoV+1S47jjg9eypXXoFMM/dvWCxeZQ/3qWO517LuvsyIAW8dDRlFRGJwTeAk4ETCWpGTwL+M5z3FWAtwW/g4QS/iW5mLwPOA17v7gcC7wZWRlpqkRIUELSm29z94fBO8IC73+fu88PX8wgu4N86bJ3/CZf9K0EAcYO7bwzvBj8IvCZc7nPAd919sbtngO8AJ5arJShjPLCj4HX++YFF5uXn5+8y17JuJS5z9zXuvhv4f8B0d58eHsO7gJnAe8L5TwJvAV5HcCH8MPBGgh+TZ9x9C4C7/8Hdnw238XvgGYIfmbxn3f0n4XFNEdRoXODufe6+gCCNZ4i7v8/dL65inwr9LCzrneHr0Rzv8SXaEdTj+IuIxOljwLfC38FNBDdxPh7OSxPcqHmRu6fd/cHwZkoW6AZOMLNOd18Z3hARiZ0Cgta0pvCFmb3Bgga2m8xsB8FF/WHD1nmu4PnuIq/Hh89fRJB7vt3MtgNbCVJLjhxFOXuB5xW8zj/fVWRefn7+rn8t61ai8Bi+CPhwfp/D/X4TwQ8CwP0EKTdvCZ/fRxBwvTV8DYCZfcKCHqHy23gle5+HwvecSJDCUzhtVaWFN7OfhWlkvWb29WHzvh++9z8U1AiM5nj3DqtRoMSyw7clItLsXsje37mrwmkA3wd6gL+a2fJ8Ta279wBfBi4ENprZjWb2QkSagAKC1jT8Iu13wFTgaHc/iODucMkeYkawBvisu08oeIxz90dGsa2FBFWxea8GngvvqC8EjhuWd/5q9qS41LJuJQqP4RrgN8P2+YCCu/PDA4L7GRYQhDUovyCoTj7U3ScAC9j7PBS+5yYgAxxdMO2Yigvv/rkwjWy8u38nP93MvgmcAbzL3XcWrLIQeNWwO/6vovzxLnU891rWzI4juGv2dInlRUSazbMEN4Pyjgmn4e673P0r7n4c8H7g3/JtBdz9d+7+pnBdBy6JttgixSkgEAhSNba6+4CZnUTQxmC0fgacb0FPMpjZQWb24VILm1mXme1HcOHbaWb75RtfAdcBnzGzE8xsAkF+5jUAYXuIOcB/h+t8gOAC9Y91WLdavwX+zszebWbt4TbfZmZHhfMfAV5GkP7zhLsvJPgxeAPwQLjMAQQ/DpvC4/Ipgrv0Rbl7FriFoHHx/mZ2AkHD3lEzs/MJzv1p+TSmAvcRVHd/MWxMd144/d7w73UEP3pHhne8vkJ4vIu4nuB4vdnMDgC+BdwyrE2HiEgzuwH4TzObaGaHARcQ/BZgZu8zs5eEN1B2EHx35iwYF+jUsPHxAEHtei6m8ovsRQGBQNDt57fMbBfBl9pNo92Qu99KcMfjRjPbSXCX+4wyq/yV4Evxb4GrwudvCbd1B/A9YAawmqBK9r8L1j2boFHuNuBi4ENhLmdN65rZx8ys4toCd18D5BtTbyKoMfgPwv8vd+8j6FZ1obunwtUeBVa5+8ZwmUUEPfM8SpCO9TcEbQ3KOY8gVWsDwcX3rwtnmtlfhqcDjeA7BHe5eoanE4XlPoug8fl2ggbMZxXsz8+BPxP01rSAoJH1zwvK0mtmbw63tZAgLe16YCNBQPr5KsopIhK3bxO0FZtH8L03O5wGQQccdxOkRz4K/NTdZxDUhF4MbCb43n4+cH60xRYpzoqn+IqIiIiISCtQDYGIiIiISAtTQCAiIiIi0sIUEIiIiIiItDAFBCIiIiIiLUwBgYiIiIhIC+uIuwCVOuyww3zSpElxF0NEpOnMmjVrs7tPjLsccdJvhIhIaSP9ToyZgGDSpEnMnDkz7mKIiDQdM1sVdxnipt8IEZHSRvqdUMqQiIiIiEgLU0AgIiIiItLCFBCIiIiIiLQwBQQiIiIiIi1MAYGIiIiISAtTQCAiIiIi0sIUEIiIiIiItDAFBCIiIiIiLUwBgYiIiIhIC1NA0CSe2znAkg074y6
[base64 PNG data truncated: reward/loss training plot]\n", 279 | "text/plain": [ 280 | "
" 281 | ] 282 | }, 283 | "metadata": { 284 | "needs_background": "light" 285 | }, 286 | "output_type": "display_data" 287 | } 288 | ], 289 | "source": [ 290 | "num_frames = 100000\n", 291 | "batch_size = 32\n", 292 | "gamma = 0.99\n", 293 | "\n", 294 | "losses = []\n", 295 | "all_rewards = []\n", 296 | "episode_reward = 0\n", 297 | "\n", 298 | "state = env.reset()\n", 299 | "for frame_idx in range(1, num_frames + 1):\n", 300 | " epsilon = epsilon_by_frame(frame_idx)\n", 301 | " action = model.act(state, epsilon)\n", 302 | " try:\n", 303 | " action = action.item()\n", 304 | " except:\n", 305 | " pass\n", 306 | " next_state, reward, done, _ = env.step(action)\n", 307 | " replay_buffer.push(state, action, reward, next_state, done)\n", 308 | " state = next_state\n", 309 | " episode_reward += reward\n", 310 | " \n", 311 | " if done:\n", 312 | " state= env.reset()\n", 313 | " all_rewards.append(episode_reward)\n", 314 | " episode_reward = 0\n", 315 | " \n", 316 | " if len(replay_buffer) > batch_size:\n", 317 | " loss = compute_td_loss(batch_size)\n", 318 | " losses.append(loss.data.item())\n", 319 | " \n", 320 | " if frame_idx % 200 == 0:\n", 321 | " plot(frame_idx, all_rewards, losses)" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "


" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "

Atari Environment

" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 12, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "from common.wrappers import make_atari, wrap_deepmind, wrap_pytorch" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 13, 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "env_id = \"PongNoFrameskip-v4\"\n", 354 | "env = make_atari(env_id)\n", 355 | "env = wrap_deepmind(env)\n", 356 | "env = wrap_pytorch(env)" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 14, 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "class CnnDQN(nn.Module):\n", 366 | " def __init__(self, input_shape, num_actions):\n", 367 | " super(CnnDQN, self).__init__()\n", 368 | " \n", 369 | " self.input_shape = input_shape\n", 370 | " self.num_actions = num_actions\n", 371 | " \n", 372 | " self.features = nn.Sequential(\n", 373 | " nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),\n", 374 | " nn.ReLU(),\n", 375 | " nn.Conv2d(32, 64, kernel_size=4, stride=2),\n", 376 | " nn.ReLU(),\n", 377 | " nn.Conv2d(64, 64, kernel_size=3, stride=1),\n", 378 | " nn.ReLU()\n", 379 | " )\n", 380 | " \n", 381 | " self.fc = nn.Sequential(\n", 382 | " nn.Linear(self.feature_size(), 512),\n", 383 | " nn.ReLU(),\n", 384 | " nn.Linear(512, self.num_actions)\n", 385 | " )\n", 386 | " \n", 387 | " def forward(self, x):\n", 388 | " x = self.features(x)\n", 389 | " x = x.view(x.size(0), -1)\n", 390 | " x = self.fc(x)\n", 391 | " return x\n", 392 | " \n", 393 | " def feature_size(self):\n", 394 | " return self.features(autograd.Variable(torch.zeros(1, *self.input_shape))).view(1, -1).size(1)\n", 395 | " \n", 396 | " def act(self, state, epsilon):\n", 397 | " if random.random() > epsilon:\n", 398 | " state = Variable(torch.FloatTensor(np.float32(state)).unsqueeze(0), volatile=True)\n", 399 | " q_value = self.forward(state)\n", 400 | " action = q_value.max(1)[1].data[0]\n", 401 | " else:\n", 402 | " action = random.randrange(env.action_space.n)\n", 403 | " return action" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 15, 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "model = CnnDQN(env.observation_space.shape, env.action_space.n)\n", 413 | "\n", 414 | "if USE_CUDA:\n", 415 | " model = model.cuda()\n", 416 | " \n", 417 | "optimizer = optim.Adam(model.parameters(), lr=0.00001)\n", 418 | "\n", 419 | "replay_initial = 10000\n", 420 | "replay_buffer = ReplayBuffer(100000)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 16, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "epsilon_start = 1.0\n", 430 | "epsilon_final = 0.01\n", 431 | "epsilon_decay = 30000\n", 432 | "\n", 433 | "epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * math.exp(-1. 
* frame_idx / epsilon_decay)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 17, 439 | "metadata": {}, 440 | "outputs": [ 441 | { 442 | "data": { 443 | "text/plain": [ 444 | "[]" 445 | ] 446 | }, 447 | "execution_count": 17, 448 | "metadata": {}, 449 | "output_type": "execute_result" 450 | }, 451 | { 452 | "data": { 453 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEFCAYAAADzHRw3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAAsTAAALEwEAmpwYAAAX3ElEQVR4nO3df3BV553f8ffn6reQhACJ8NPGsDg2Zm3H1vjH7uzGu8k22E1gptvs2hN3u60bT7LrtDNJM+NOWjfr/LPbtJnpNrgJTbPZ7KzjtdOZlJ2QMu2GTLauySKvjWOw8QiMjQAbGcQvC5CQvv3jXpGLEOiCru7ROefzmtHo3HMe7vkervjo4TnnPEcRgZmZpV8h6QLMzKw6HOhmZhnhQDczywgHuplZRjjQzcwyoj6pHXd1dcWKFSuS2r2ZWSq9+OKL70VE92TbEgv0FStW0Nvbm9TuzcxSSdJbl9vmIRczs4xwoJuZZYQD3cwsIxzoZmYZ4UA3M8uIKQNd0rclHZH06mW2S9KfSuqT9IqkO6pfppmZTaWSHvp3gHVX2H4/sLr09SjwX6dflpmZXa0pAz0ifgocu0KTDcB3o2g70ClpcbUKnGjH/mN8devrjI552l8zs3LVGENfChwoe91fWncJSY9K6pXUOzAwcE07e/nt42zctpeh4fPX9OfNzLKqpidFI2JTRPRERE9396R3rk5pTlPx5tb3z41WszQzs9SrRqAfBJaXvV5WWjcj5jTVAXD6nHvoZmblqhHom4HfK13tcg9wIiIOV+F9J9V2oYfuQDczKzfl5FySvgfcB3RJ6gf+PdAAEBHfALYADwB9wBDwz2aqWCgbcvEYupnZRaYM9Ih4aIrtAfxh1SqawpxGj6GbmU0mdXeKjo+he8jFzOxiqQv08TF0nxQ1M7tY6gJ9jk+KmplNKnWB3tJQhwTvD3sM3cysXOoCvVAQrQ117qGbmU2QukCH4rCLA93M7GKpDPS2pnqfFDUzmyCVge4eupnZpVIa6HU+KWpmNkE6A73RPXQzs4nSGegecjEzu0RqA/2053IxM7tIKgO9ranOTywyM5sglYE+p6meoeFRxvxcUTOzC9IZ6I2eE93MbKJ0BrqfK2pmdomUBrqfK2pmNlEqA318TnSfGDUz+4VUBnprox9yYWY2USoDvc1j6GZml0hloPu5omZml0ploPu5omZml0ploM/xSVEzs0ukMtDHnyt6+qwD3cxsXCoDvVAQbY31nPKQi5nZBakMdID25npOuYduZnZBigO9gVNnR5Iuw8xs1khxoLuHbmZWLtWBftI9dDOzC1Ic6A3uoZuZlUlxoHvIxcysXEWBLmmdpD2S+iQ9Psn26yRtk/SSpFckPVD9Ui82flI0wk8tMjODCgJdUh2wEbgfWAM8JGnNhGb/Fng2Ij4EPAg8Ve1CJ2pvrmdkNDh3fmymd2VmlgqV9NDvAvoiYl9EDAPPABsmtAmgo7Q8FzhUvRIn19FcvP3fJ0bNzIoqCfSlwIGy1/2ldeW+DDwsqR/YAnxusjeS9KikXkm9AwMD11DuL3S0NAB4HN3MrKRaJ0UfAr4TEcuAB4C/kHTJe0fEpojoiYie7u7uae2wvdRDd6CbmRVVEugHgeVlr5eV1pV7BHgWICJeAJqBrmoUeDntzeM9dA+5mJlBZYG+A1gt6QZJjRRPem6e0OZt4CMAkm6mGOjTG1OZgnvoZmYXmzLQI+I88BiwFXiN4tUsuyQ9KWl9qdkXgE9L2gl8D/j9mOHrCd1DNzO7WH0ljSJiC8WTneXrnihb3g38anVLuzL30M3MLpbaO0XbGuuR4KQD3cwMSHGgX3jIhYdczMyAFAc6eD4XM7NyKQ90P+TCzGxcygPdPXQzs3EOdDOzjEh5oHvIxcxsXMoD3T10M7NxKQ/0Bk76IRdmZkDKA72ztYGR0eDMyGjSpZiZJS7dgV6aE/34kMfRzcxSHehzS4F+4owD3cws3YHe6h66mdm4dAf6hR76cMKVmJklL9WB3tnaCHjIxcwM0h7oPilqZnZBqgO9tbGO+oLcQzczI+WBLonO1gaOO9DNzNId6FA8MXrCQy5mZhkJdPfQzczSH+idrY0c92WLZmYZCPSWBl/lYmZGBgK9w0MuZmZABgK9s7WBU2fPc350LOlSzMwSlfpAH7/9/6QfdGFmOZf6QO9s9YyLZmaQhUBvKc7ncnzIV7qYWb6lPtA7xudzcQ/dzHIu9YE+PuRy0oFuZjmX/kAv9dAH3/eQi5nlW+oDffwql0HfXGRmOVdRoEtaJ2mPpD5Jj1+mze9I2i1pl6Snq1vm5dXXFehsbeCYe+hmlnP1UzWQVAdsBH4L6Ad2SNocEbvL2qwG/g3wqxExKGnhTBU8mflzGh3oZpZ7lfTQ7wL6ImJfRAwDzwAbJrT5NLAxIgYBIuJIdcu8svmtDnQzs0oCfSlwoOx1f2lduRuBGyU9L2m7pHWTvZGkRyX1SuodGBi4toon4R66mVn1TorWA6uB+4CHgP8mqXNio4jYFBE9EdHT3d1dpV3DgrZGjjrQzSznKgn0g8DystfLSuvK9QObI2IkIt4E3qAY8DUxf04jg0PDREStdmlmNutUEug7gNWSbpDUCDwIbJ7Q5gcUe+dI6qI4BLOvemVe2bzWRkbHgpNnPEGXmeXXlIEeEeeBx4CtwGvAsxGxS9KTktaXmm0FjkraDWwDvhgRR2eq6IkWtBXnczn6/rla7dLMbNaZ8rJFgIjYAmyZsO6JsuUAPl/6qrn5c5oAGPQEXWaWY6m/UxSKly0CHD3tQDez/MpGoJeGXHzpopnlWTYCvdRDP+YhFzPLsUwEektjHS0NdRzzkIuZ5VgmAh18t6iZWWYCfUFbo4dczCzXMhPo8zxBl5nlXGYCfcGcRl+2aGa5lplA72pv4r3T5zyfi5nlVmYCfWF7E+fOj3HqnOdzMbN8ykygd7cXb/8fOOX5XMwsn7IT6G3FQD9y0oFuZvmUnUAf76GfdqCbWT5lJtAXtjcDHnIxs/zKTKB3tNTTWFfgyKmzSZdiZpaIzAS6JLrbm9xDN7PcykygQ/FadAe6meVVpgJ9oQPdzHIsU4HuIRczy7NsBXpbE8eGhhkZHUu6FDOzmstUoC/saCLCzxY1s3zKVKCP3y3qYRczy6NsBfqFu0V9LbqZ5U+mAn1hR/Fu0Xc9n4u
[base64 PNG data truncated: epsilon-by-frame decay schedule plot]", 454 | "text/plain": [ 455 | "
" 456 | ] 457 | }, 458 | "metadata": { 459 | "needs_background": "light" 460 | }, 461 | "output_type": "display_data" 462 | } 463 | ], 464 | "source": [ 465 | "plt.plot([epsilon_by_frame(i) for i in range(1000000)])" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 21, 471 | "metadata": {}, 472 | "outputs": [ 473 | { 474 | "name": "stderr", 475 | "output_type": "stream", 476 | "text": [ 477 | ":2: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n", 478 | " Variable = lambda *args, **kwargs: autograd.Variable(*args, **kwargs).cuda() if USE_CUDA else autograd.Variable(*args, **kwargs)\n" 479 | ] 480 | }, 481 | { 482 | "ename": "KeyboardInterrupt", 483 | "evalue": "", 484 | "output_type": "error", 485 | "traceback": [ 486 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 487 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 488 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplay_buffer\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mreplay_initial\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_td_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mlosses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 489 | "\u001b[0;32m\u001b[0m in \u001b[0;36mcompute_td_loss\u001b[0;34m(batch_size)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 490 | "\u001b[0;32m~/anaconda3/envs/pytorch-gat/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m create_graph=create_graph)\n\u001b[0;32m--> 221\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 491 | "\u001b[0;32m~/anaconda3/envs/pytorch-gat/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 130\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 131\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 132\u001b[0m allow_unreachable=True) # allow_unreachable flag\n", 492 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 493 | ] 494 | } 495 | ], 496 | "source": [ 497 | "num_frames = 1400000\n", 498 | "batch_size = 32\n", 499 | "gamma = 0.99\n", 500 | "\n", 501 | "losses = []\n", 502 | "all_rewards = []\n", 503 | "episode_reward = 0\n", 504 | "\n", 505 | "state= env.reset()\n", 506 | "for frame_idx in range(1, num_frames + 1):\n", 507 | " epsilon = epsilon_by_frame(frame_idx)\n", 508 | " action = model.act(state, epsilon)\n", 509 | " \n", 510 | " next_state, reward, done, _ = env.step(action)\n", 511 | " replay_buffer.push(state, action, reward, next_state, done)\n", 512 | " \n", 513 | " state = next_state\n", 514 | " episode_reward += reward\n", 515 | " \n", 516 | " if done:\n", 517 | " state = env.reset()\n", 518 | " all_rewards.append(episode_reward)\n", 519 | " episode_reward = 0\n", 520 | " \n", 521 | " if len(replay_buffer) > replay_initial:\n", 522 | " loss = compute_td_loss(batch_size)\n", 523 | " losses.append(loss.item())\n", 524 | " \n", 525 | " if frame_idx % 10000 == 0:\n", 526 | " plot(frame_idx, all_rewards, losses)" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": null, 532 | "metadata": {}, 533 | "outputs": [], 534 | "source": [] 535 | } 536 | ], 537 | "metadata": { 538 | "kernelspec": { 539 | "display_name": "pytorch-gat", 540 | "language": "python", 541 | "name": "pytorch-gat" 542 | }, 543 | "language_info": { 544 | "codemirror_mode": { 545 | "name": "ipython", 546 | "version": 3 547 | }, 548 | "file_extension": ".py", 549 | "mimetype": "text/x-python", 550 | "name": "python", 551 | "nbconvert_exporter": "python", 552 | "pygments_lexer": "ipython3", 553 | "version": "3.8.5" 554 | } 555 | }, 556 | "nbformat": 4, 557 | "nbformat_minor": 2 558 | } 559 | --------------------------------------------------------------------------------