├── common ├── __init__.py ├── layers.py ├── wrappers.py ├── replay_buffer.py └── replay_buffer_dtm.py ├── requirements.txt ├── README.md ├── LICENSE ├── .gitignore ├── controller.py ├── em.py ├── dqn_mbec.py └── dqn.ipynb /common/__init__.py: -------------------------------------------------------------------------------- 1 | from . import layers 2 | from . import wrappers 3 | from . import replay_buffer -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ale_py==0.7.5 2 | gym==0.26.2 3 | ipython==8.7.0 4 | matplotlib==3.6.2 5 | numpy==1.19.2 6 | opencv_python==4.6.0.66 7 | pyflann==1.6.14 8 | pyflann_py3==0.1.0 9 | tensorboard_logger==0.1.0 10 | torch==1.4.0 11 | tqdm==4.63.0 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AJCAI22-Tutorial 2 | ### Demo code for AJCAI22-Tutorial 3 | [Tutorial website](https://thaihungle.github.io/talks/2022-12-05-AJCAI) 4 | [Conference website](https://ajcai2022.org/tutorials/) 5 | 6 | Install packages following requirements.txt 7 | ``` 8 | pip install -r requirements.txt 9 | ``` 10 | ### DQN 11 | Follow dqn.ipynb 12 | Reference: https://github.com/higgsfield/RL-Adventure/blob/master/1.dqn.ipynb 13 | 14 | 15 | ### MBEC 16 | - Create log and model folders to save training info 17 | ``` 18 | mkdir log 19 | mkdir model 20 | ``` 21 | - Run the script dqn_mbec.py using 22 | 23 | ``` 24 | python dqn_mbec.py --task MountainCar-v0 --rnoise 0.5 --render 0 --task2 mountaincar --n_epochs 100000 --max_episode 1000000 --model_name DTM --update_interval 100 --decay 1 --memory_size 3000 --k 15 --write_interval 10 --td_interval 1 --write_lr .5 --rec_rate .1 --rec_noise .1 --batch_size_plan 4 --rec_period 9999999999 --num_warm_up -1 --lr 0.0005 25 | ``` 26 | - Monitor training using 
tensorboard 27 | ``` 28 | cd log 29 | tensorboard --logdir=./ 30 | ``` 31 | Reference: https://github.com/thaihungle/MBEC-plus 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tony 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
class NoisyLinear(nn.Module):
    """Linear layer whose weights and biases are perturbed by factorised noise.

    During training the effective parameters are ``mu + sigma * epsilon``,
    where ``mu`` and ``sigma`` are learnable and ``epsilon`` is a noise buffer
    refreshed by :meth:`reset_noise`. In evaluation mode only ``mu`` is used.

    Parameters
    ----------
    in_features: int
        Size of each input sample.
    out_features: int
        Size of each output sample.
    use_cuda: bool
        Stored on the instance for backward compatibility. The noise buffers
        now follow the device of the layer's parameters, so the flag no longer
        forces a ``.cuda()`` call (which crashed when the module lived on CPU).
    std_init: float
        Scale used to initialise the ``sigma`` parameters.
    """

    def __init__(self, in_features, out_features, use_cuda, std_init=0.4):
        super(NoisyLinear, self).__init__()

        self.use_cuda = use_cuda
        self.in_features = in_features
        self.out_features = out_features
        self.std_init = std_init

        # Learnable mean / scale of the weights, plus a non-learnable noise buffer.
        self.weight_mu = nn.Parameter(torch.FloatTensor(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.FloatTensor(out_features, in_features))
        self.register_buffer('weight_epsilon', torch.FloatTensor(out_features, in_features))

        # Same triple for the bias.
        self.bias_mu = nn.Parameter(torch.FloatTensor(out_features))
        self.bias_sigma = nn.Parameter(torch.FloatTensor(out_features))
        self.register_buffer('bias_epsilon', torch.FloatTensor(out_features))

        # The tensors above are uninitialised until these two calls run.
        self.reset_parameters()
        self.reset_noise()

    def forward(self, x):
        """Apply the (noisy, when training) affine map to ``x``."""
        if not self.training:
            return F.linear(x, self.weight_mu, self.bias_mu)

        # Keep the noise on the same device as the parameters instead of
        # trusting ``use_cuda``; this is a no-op when they already agree.
        device = self.weight_mu.device
        weight_epsilon = self.weight_epsilon.to(device)
        bias_epsilon = self.bias_epsilon.to(device)

        weight = self.weight_mu + self.weight_sigma.mul(weight_epsilon)
        bias = self.bias_mu + self.bias_sigma.mul(bias_epsilon)
        return F.linear(x, weight, bias)

    def reset_parameters(self):
        """Initialise ``mu`` uniformly and ``sigma`` to a constant."""
        mu_range = 1 / math.sqrt(self.weight_mu.size(1))

        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.weight_sigma.size(1)))

        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.bias_sigma.size(0)))

    def reset_noise(self):
        """Draw fresh noise for the weight and bias buffers."""
        epsilon_in = self._scale_noise(self.in_features)
        epsilon_out = self._scale_noise(self.out_features)

        # Weight noise is the outer product of per-output and per-input noise.
        self.weight_epsilon.copy_(epsilon_out.ger(epsilon_in))
        # The bias uses its own independent draw, not ``epsilon_out``.
        self.bias_epsilon.copy_(self._scale_noise(self.out_features))

    def _scale_noise(self, size):
        """Return ``sign(x) * sqrt(|x|)`` for ``x`` drawn from N(0, 1)."""
        x = torch.randn(size)
        x = x.sign().mul(x.abs().sqrt())
        return x
class LSTMController(nn.Module):
    """Controller that wraps an ``nn.LSTM`` and processes one time step per call."""

    def __init__(self, num_inputs, num_outputs, num_layers):
        """Build the LSTM and its fixed zero initial state.

        Parameters
        ----------
        num_inputs: int
            Feature size of each input vector (``input_size`` of the LSTM).
        num_outputs: int
            Hidden size of the LSTM, which is also the output size.
        num_layers: int
            Number of stacked LSTM layers.
        """
        super(LSTMController, self).__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size=num_inputs,
                            hidden_size=num_outputs,
                            num_layers=num_layers)

        # Initial hidden/cell state: all zeros and frozen (requires_grad=False),
        # i.e. it is NOT learned. Built once instead of via ``randn(...) * 0``
        # in duplicated CUDA/CPU branches.
        zero_state = torch.zeros(self.num_layers, 1, self.num_outputs)
        if torch.cuda.is_available():
            zero_state = zero_state.cuda()
        self.lstm_h_bias = Parameter(zero_state.clone(), requires_grad=False)
        self.lstm_c_bias = Parameter(zero_state.clone(), requires_grad=False)

        self.reset_parameters()

    def create_new_state(self, batch_size):
        """Return a fresh ``(h, c)`` pair for ``batch_size`` sequences.

        Each tensor has shape ``(num_layers, batch_size, num_outputs)``.
        """
        lstm_h = self.lstm_h_bias.clone().repeat(1, batch_size, 1)
        lstm_c = self.lstm_c_bias.clone().repeat(1, batch_size, 1)
        return lstm_h, lstm_c

    def reset_parameters(self):
        """Zero the LSTM biases and draw the weight matrices uniformly."""
        # Loop-invariant bound, hoisted out of the parameter loop.
        stdev = 5 / (np.sqrt(self.num_inputs + self.num_outputs))
        for p in self.lstm.parameters():
            if p.dim() == 1:
                # 1-D parameters of nn.LSTM are its bias vectors.
                nn.init.constant_(p, 0)
            else:
                nn.init.uniform_(p, -stdev, stdev)

    def size(self):
        """Return ``(num_inputs, num_outputs)``."""
        return self.num_inputs, self.num_outputs

    def forward(self, x, prev_state):
        """Run a single step.

        ``x`` gets a leading axis of length 1 before the LSTM call and that
        axis is removed from the output again, so the caller passes and
        receives tensors without a sequence dimension. ``prev_state`` is the
        ``(h, c)`` tuple from :meth:`create_new_state` or a previous call.
        """
        x = x.unsqueeze(0)
        outp, state = self.lstm(x, prev_state)
        return outp.squeeze(0), state
self.env.step(ac) 56 | 57 | class EpisodicLifeEnv(gym.Wrapper): 58 | def __init__(self, env): 59 | """Make end-of-life == end-of-episode, but only reset on true game over. 60 | Done by DeepMind for the DQN and co. since it helps value estimation. 61 | """ 62 | gym.Wrapper.__init__(self, env) 63 | self.lives = 0 64 | self.was_real_done = True 65 | 66 | def step(self, action): 67 | obs, reward, done, info = self.env.step(action) 68 | self.was_real_done = done 69 | # check current lives, make loss of life terminal, 70 | # then update lives to handle bonus lives 71 | lives = self.env.unwrapped.ale.lives() 72 | if lives < self.lives and lives > 0: 73 | # for Qbert sometimes we stay in lives == 0 condtion for a few frames 74 | # so its important to keep lives > 0, so that we only reset once 75 | # the environment advertises done. 76 | done = True 77 | self.lives = lives 78 | return obs, reward, done, info 79 | 80 | def reset(self, **kwargs): 81 | """Reset only when lives are exhausted. 82 | This way all states are still reachable even though lives are episodic, 83 | and the learner need not know about any of this behind-the-scenes. 
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame.

        Parameters
        ----------
        env: gym.Env
            Environment to wrap.
        skip: int
            Number of times each action is repeated; must be at least 1.
        """
        gym.Wrapper.__init__(self, env)
        if skip < 1:
            # With skip < 1 the loop in step() never runs and `info` would be
            # unbound; fail early with a clear message instead.
            raise ValueError("skip must be >= 1, got {}".format(skip))
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            # Only the last two sub-steps of a full window are buffered.
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter. If the loop broke early, the buffer may still hold
        # frames written by an earlier call.
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

    # The class used to define reset() twice; the first definition (without
    # **kwargs) was silently overridden by this one, so it has been removed.
    def reset(self, **kwargs):
        """Reset the wrapped environment, forwarding any keyword arguments."""
        return self.env.reset(**kwargs)
class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.
        Returns lazy array, which is much more memory efficient.
        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        # Bounded deque: appending the (k+1)-th frame drops the oldest one.
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        # Frames are concatenated along the last axis, so it grows k-fold.
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8)

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the stack with k copies of the first frame.

        Keyword arguments are forwarded to the wrapped environment's reset(),
        matching the other wrappers in this module (previously reset() took no
        arguments, so e.g. ``reset(seed=...)`` failed once frames were stacked).
        """
        ob = self.env.reset(**kwargs)
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        """Step the wrapped env and return the updated frame stack."""
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        """Wrap the current k frames in a LazyFrames without concatenating yet."""
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))
class ImageToPyTorch(gym.ObservationWrapper):
    """
    Image shape to num_channels x width x height
    """
    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        # observation() swaps axes 0 and 2, turning (a, b, c) into (c, b, a).
        # The declared shape must mirror that; the previous (c, a, b) was only
        # correct for square frames.
        # NOTE(review): low=0.0/high=1.0 with dtype=np.uint8 is kept unchanged,
        # but looks inconsistent with 0..255 uint8 frames — confirm intent.
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=(old_shape[-1], old_shape[1], old_shape[0]),
            dtype=np.uint8,
        )

    def observation(self, observation):
        """Return the observation with its first and last axes swapped."""
        return np.swapaxes(observation, 2, 0)
# code from openai
# https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py

import numpy as np
import random

import operator


class SegmentTree(object):
    def __init__(self, capacity, operation, neutral_element):
        """Build a Segment Tree data structure.
        https://en.wikipedia.org/wiki/Segment_tree
        Can be used as regular array, but with two
        important differences:
            a) setting item's value is slightly slower.
               It is O(lg capacity) instead of O(1).
            b) user has access to an efficient `reduce`
               operation which reduces `operation` over
               a contiguous subsequence of items in the
               array.
        Parameters
        ----------
        capacity: int
            Total size of the array - must be a power of two.
        operation: lambda obj, obj -> obj
            an operation for combining elements (eg. sum, max);
            it must form a mathematical group together with the set of
            possible values for array elements.
        neutral_element: obj
            neutral element for the operation above. eg. float('-inf')
            for max and 0 for sum.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
        self._capacity = capacity
        # Kept so that reducing an empty range has a well-defined result.
        self._neutral_element = neutral_element
        # Heap layout: node 1 is the root, leaves live at [capacity, 2 * capacity).
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation

    def _reduce_helper(self, start, end, node, node_start, node_end):
        """Reduce the inclusive range [start, end] inside the subtree rooted at `node`."""
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        else:
            if mid + 1 <= start:
                return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
            else:
                return self._operation(
                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
                    self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
                )

    def reduce(self, start=0, end=None):
        """Returns result of applying `self.operation`
        to a contiguous subsequence of the array.
            self.operation(arr[start], operation(arr[start+1], operation(... arr[end - 1])))
        Parameters
        ----------
        start: int
            beginning of the subsequence (inclusive)
        end: int
            end of the subsequence (exclusive); None means the whole
            capacity, negative values count from the end
        Returns
        -------
        reduced: obj
            result of reducing self.operation over the specified range of
            array elements; the neutral element when the range is empty.
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        end -= 1  # the helper works on an inclusive upper bound
        if end < start:
            # Empty range: the recursive helper would never terminate here.
            return self._neutral_element
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        if not 0 <= idx < self._capacity:
            # A negative idx would land on an internal node and corrupt the tree.
            raise IndexError("segment tree index out of range: {}".format(idx))
        # index of the leaf
        idx += self._capacity
        self._value[idx] = val
        # Recompute every ancestor up to the root.
        idx //= 2
        while idx >= 1:
            self._value[idx] = self._operation(
                self._value[2 * idx],
                self._value[2 * idx + 1]
            )
            idx //= 2

    def __getitem__(self, idx):
        assert 0 <= idx < self._capacity
        return self._value[self._capacity + idx]


class SumSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super(SumSegmentTree, self).__init__(
            capacity=capacity,
            operation=operator.add,
            neutral_element=0.0
        )

    def sum(self, start=0, end=None):
        """Returns arr[start] + ... + arr[end - 1]"""
        return super(SumSegmentTree, self).reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """Find the highest index `i` in the array such that
            sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum
        if array values are probabilities, this function
        allows to sample indexes according to the discrete
        probability efficiently.
        Parameters
        ----------
        prefixsum: float
            upperbound on the sum of array prefix
        Returns
        -------
        idx: int
            highest index satisfying the prefixsum constraint
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        idx = 1
        while idx < self._capacity:  # while non-leaf
            if self._value[2 * idx] > prefixsum:
                idx = 2 * idx
            else:
                prefixsum -= self._value[2 * idx]
                idx = 2 * idx + 1
        return idx - self._capacity


class MinSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super(MinSegmentTree, self).__init__(
            capacity=capacity,
            operation=min,
            neutral_element=float('inf')
        )

    def min(self, start=0, end=None):
        """Returns min(arr[start], ..., arr[end - 1])"""
        return super(MinSegmentTree, self).reduce(start, end)


class ReplayBuffer(object):
    def __init__(self, size):
        """Create Replay buffer.
        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0

    def __len__(self):
        return len(self._storage)

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest slot once full."""
        data = (state, action, reward, next_state, done)

        if self._next_idx >= len(self._storage):
            self._storage.append(data)
        else:
            self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self._maxsize

    def _encode_sample(self, idxes):
        """Gather the transitions at `idxes` into five stacked arrays."""
        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
        for i in idxes:
            obs_t, action, reward, obs_tp1, done = self._storage[i]
            # np.asarray avoids a copy when possible and, unlike
            # np.array(..., copy=False), does not raise on NumPy 2 when a
            # copy is unavoidable.
            obses_t.append(np.asarray(obs_t))
            actions.append(np.asarray(action))
            rewards.append(reward)
            obses_tp1.append(np.asarray(obs_tp1))
            dones.append(done)
        return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones)

    def sample(self, batch_size):
        """Sample a batch of experiences.
        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        """
        idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
        return self._encode_sample(idxes)


class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.
        Parameters
        ----------
        size: int
            Max number of transitions to store in the buffer. When the buffer
            overflows the old memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)
        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        # alpha == 0 is the documented "no prioritization" setting.
        assert alpha >= 0
        self._alpha = alpha

        # Segment trees need a power-of-two capacity >= size.
        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def push(self, *args, **kwargs):
        """See ReplayBuffer.push; new transitions get the maximum priority seen so far."""
        idx = self._next_idx
        super(PrioritizedReplayBuffer, self).push(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        """Draw `batch_size` indexes with probability proportional to priority."""
        n_stored = len(self._storage)
        # The end bound of sum() is exclusive, so it must be n_stored. The old
        # `n_stored - 1` excluded the last slot from sampling and recursed
        # forever when the buffer held a single transition.
        total = self._it_sum.sum(0, n_stored)
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            mass = random.random() * total
            idx = self._it_sum.find_prefixsum_idx(mass)
            # Guard against float rounding pushing the search past the last
            # stored transition into an unused zero-priority leaf.
            res.append(min(idx, n_stored - 1))
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.
        compared to ReplayBuffer.sample
        it also returns importance weights and idxes
        of sampled experiences.
        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)
        Returns
        -------
        obs_batch: np.array
            batch of observations
        act_batch: np.array
            batch of actions executed given obs_batch
        rew_batch: np.array
            rewards received as results of executing act_batch
        next_obs_batch: np.array
            next set of observations seen after executing act_batch
        done_mask: np.array
            done_mask[i] = 1 if executing act_batch[i] resulted in
            the end of an episode and 0 otherwise.
        weights: np.array
            Array of shape (batch_size,) denoting the importance weight
            of each sampled transition
        idxes: list of int
            indexes in buffer of the sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        # Loop-invariant quantities, computed once.
        total = self._it_sum.sum()
        n_stored = len(self._storage)
        p_min = self._it_min.min() / total
        max_weight = (p_min * n_stored) ** (-beta)

        weights = []
        for idx in idxes:
            p_sample = self._it_sum[idx] / total
            weight = (p_sample * n_stored) ** (-beta)
            # Normalise so the largest possible weight equals 1.
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.
        sets priority of transition at index idxes[i] in buffer
        to priorities[i].
        Parameters
        ----------
        idxes: [int]
            List of idxes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to
            transitions at the sampled idxes denoted by
            variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
35 | self._capacity = capacity 36 | self._value = [neutral_element for _ in range(2 * capacity)] 37 | self._operation = operation 38 | 39 | def _reduce_helper(self, start, end, node, node_start, node_end): 40 | if start == node_start and end == node_end: 41 | return self._value[node] 42 | mid = (node_start + node_end) // 2 43 | if end <= mid: 44 | return self._reduce_helper(start, end, 2 * node, node_start, mid) 45 | else: 46 | if mid + 1 <= start: 47 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) 48 | else: 49 | return self._operation( 50 | self._reduce_helper(start, mid, 2 * node, node_start, mid), 51 | self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) 52 | ) 53 | 54 | def reduce(self, start=0, end=None): 55 | """Returns result of applying `self.operation` 56 | to a contiguous subsequence of the array. 57 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) 58 | Parameters 59 | ---------- 60 | start: int 61 | beginning of the subsequence 62 | end: int 63 | end of the subsequence 64 | Returns 65 | ------- 66 | reduced: obj 67 | result of reducing self.operation over the specified range of array elements. 
68 | """ 69 | if end is None: 70 | end = self._capacity 71 | if end < 0: 72 | end += self._capacity 73 | end -= 1 74 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1) 75 | 76 | def __setitem__(self, idx, val): 77 | # index of the leaf 78 | idx += self._capacity 79 | self._value[idx] = val 80 | idx //= 2 81 | while idx >= 1: 82 | self._value[idx] = self._operation( 83 | self._value[2 * idx], 84 | self._value[2 * idx + 1] 85 | ) 86 | idx //= 2 87 | 88 | def __getitem__(self, idx): 89 | assert 0 <= idx < self._capacity 90 | return self._value[self._capacity + idx] 91 | 92 | 93 | class SumSegmentTree(SegmentTree): 94 | def __init__(self, capacity): 95 | super(SumSegmentTree, self).__init__( 96 | capacity=capacity, 97 | operation=operator.add, 98 | neutral_element=0.0 99 | ) 100 | 101 | def sum(self, start=0, end=None): 102 | """Returns arr[start] + ... + arr[end]""" 103 | return super(SumSegmentTree, self).reduce(start, end) 104 | 105 | def find_prefixsum_idx(self, prefixsum): 106 | """Find the highest index `i` in the array such that 107 | sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum 108 | if array values are probabilities, this function 109 | allows to sample indexes according to the discrete 110 | probability efficiently. 
111 | Parameters 112 | ---------- 113 | prefixsum: float 114 | upper bound on the sum of array prefix 115 | Returns 116 | ------- 117 | idx: int 118 | highest index satisfying the prefixsum constraint 119 | """ 120 | assert 0 <= prefixsum <= self.sum() + 1e-5 121 | idx = 1 122 | while idx < self._capacity: # while non-leaf 123 | if self._value[2 * idx] > prefixsum: 124 | idx = 2 * idx 125 | else: 126 | prefixsum -= self._value[2 * idx] 127 | idx = 2 * idx + 1 128 | return idx - self._capacity 129 | 130 | 131 | class MinSegmentTree(SegmentTree): 132 | def __init__(self, capacity): 133 | super(MinSegmentTree, self).__init__( 134 | capacity=capacity, 135 | operation=min, 136 | neutral_element=float('inf') 137 | ) 138 | 139 | def min(self, start=0, end=None): 140 | """Returns min(arr[start], ..., arr[end])""" 141 | 142 | return super(MinSegmentTree, self).reduce(start, end) 143 | 144 | 145 | class ReplayBuffer(object): 146 | def __init__(self, size): 147 | """Create Replay buffer. 148 | Parameters 149 | ---------- 150 | size: int 151 | Max number of transitions to store in the buffer. When the buffer 152 | overflows the old memories are dropped. 
153 | """ 154 | self._storage = [] 155 | self._maxsize = size 156 | self._next_idx = 0 157 | 158 | def __len__(self): 159 | return len(self._storage) 160 | 161 | 162 | def push(self, state, h_trj, action, reward, next_state, nh_trj, done): 163 | data = (state, h_trj, action, reward, next_state, nh_trj, done) 164 | 165 | if self._next_idx >= len(self._storage): 166 | self._storage.append(data) 167 | else: 168 | self._storage[self._next_idx] = data 169 | self._next_idx = (self._next_idx + 1) % self._maxsize 170 | 171 | def _encode_sample(self, idxes): 172 | states, h_trjs, actions, rewards, next_states, nh_trjs, dones = [], [], [], [], [], [], [] 173 | for i in idxes: 174 | data = self._storage[i] 175 | state, h_trj, action, reward, next_state, nh_trj, done = data 176 | states.append(np.array(state, copy=False)) 177 | h_trjs.append(np.array(h_trj, copy=False)) 178 | actions.append(np.array(action, copy=False)) 179 | rewards.append(reward) 180 | next_states.append(np.array(next_state, copy=False)) 181 | nh_trjs.append(np.array(nh_trj, copy=False)) 182 | dones.append(done) 183 | 184 | return np.array(states), np.array(h_trjs), np.array(actions), np.array(rewards), \ 185 | np.array(next_states), np.array(nh_trjs), np.array(dones) 186 | 187 | def sample(self, batch_size): 188 | """Sample a batch of experiences. 189 | Parameters 190 | ---------- 191 | batch_size: int 192 | How many transitions to sample. 193 | Returns 194 | ------- 195 | obs_batch: np.array 196 | batch of observations 197 | act_batch: np.array 198 | batch of actions executed given obs_batch 199 | rew_batch: np.array 200 | rewards received as results of executing act_batch 201 | next_obs_batch: np.array 202 | next set of observations seen after executing act_batch 203 | done_mask: np.array 204 | done_mask[i] = 1 if executing act_batch[i] resulted in 205 | the end of an episode and 0 otherwise. 
206 | """ 207 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 208 | return self._encode_sample(idxes) 209 | 210 | 211 | class PrioritizedReplayBuffer(ReplayBuffer): 212 | def __init__(self, size, alpha): 213 | """Create Prioritized Replay buffer. 214 | Parameters 215 | ---------- 216 | size: int 217 | Max number of transitions to store in the buffer. When the buffer 218 | overflows the old memories are dropped. 219 | alpha: float 220 | how much prioritization is used 221 | (0 - no prioritization, 1 - full prioritization) 222 | See Also 223 | -------- 224 | ReplayBuffer.__init__ 225 | """ 226 | super(PrioritizedReplayBuffer, self).__init__(size) 227 | assert alpha > 0 228 | self._alpha = alpha 229 | 230 | it_capacity = 1 231 | while it_capacity < size: 232 | it_capacity *= 2 233 | 234 | self._it_sum = SumSegmentTree(it_capacity) 235 | self._it_min = MinSegmentTree(it_capacity) 236 | self._max_priority = 1.0 237 | 238 | def push(self, *args, **kwargs): 239 | """See ReplayBuffer.push""" 240 | idx = self._next_idx 241 | super(PrioritizedReplayBuffer, self).push(*args, **kwargs) 242 | self._it_sum[idx] = self._max_priority ** self._alpha 243 | self._it_min[idx] = self._max_priority ** self._alpha 244 | 245 | def _sample_proportional(self, batch_size): 246 | res = [] 247 | for _ in range(batch_size): 248 | # TODO(szymon): should we ensure no repeats? 249 | mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1) 250 | idx = self._it_sum.find_prefixsum_idx(mass) 251 | res.append(idx) 252 | return res 253 | 254 | def sample(self, batch_size, beta): 255 | """Sample a batch of experiences. 256 | compared to ReplayBuffer.sample 257 | it also returns importance weights and idxes 258 | of sampled experiences. 259 | Parameters 260 | ---------- 261 | batch_size: int 262 | How many transitions to sample. 
263 | beta: float 264 | To what degree to use importance weights 265 | (0 - no corrections, 1 - full correction) 266 | Returns 267 | ------- 268 | obs_batch: np.array 269 | batch of observations 270 | act_batch: np.array 271 | batch of actions executed given obs_batch 272 | rew_batch: np.array 273 | rewards received as results of executing act_batch 274 | next_obs_batch: np.array 275 | next set of observations seen after executing act_batch 276 | done_mask: np.array 277 | done_mask[i] = 1 if executing act_batch[i] resulted in 278 | the end of an episode and 0 otherwise. 279 | weights: np.array 280 | Array of shape (batch_size,) and dtype np.float32 281 | denoting importance weight of each sampled transition 282 | idxes: np.array 283 | Array of shape (batch_size,) and dtype np.int32 284 | indexes in buffer of sampled experiences 285 | """ 286 | assert beta > 0 287 | 288 | idxes = self._sample_proportional(batch_size) 289 | 290 | weights = [] 291 | p_min = self._it_min.min() / self._it_sum.sum() 292 | max_weight = (p_min * len(self._storage)) ** (-beta) 293 | 294 | for idx in idxes: 295 | p_sample = self._it_sum[idx] / self._it_sum.sum() 296 | weight = (p_sample * len(self._storage)) ** (-beta) 297 | weights.append(weight / max_weight) 298 | weights = np.array(weights) 299 | encoded_sample = self._encode_sample(idxes) 300 | return tuple(list(encoded_sample) + [weights, idxes]) 301 | 302 | def update_priorities(self, idxes, priorities): 303 | """Update priorities of sampled transitions. 304 | sets priority of transition at index idxes[i] in buffer 305 | to priorities[i]. 306 | Parameters 307 | ---------- 308 | idxes: [int] 309 | List of idxes of sampled transitions 310 | priorities: [float] 311 | List of updated priorities corresponding to 312 | transitions at the sampled idxes denoted by 313 | variable `idxes`. 
314 | """ 315 | assert len(idxes) == len(priorities) 316 | for idx, priority in zip(idxes, priorities): 317 | assert priority > 0 318 | assert 0 <= idx < len(self._storage) 319 | self._it_sum[idx] = priority ** self._alpha 320 | self._it_min[idx] = priority ** self._alpha 321 | 322 | self._max_priority = max(self._max_priority, priority) -------------------------------------------------------------------------------- /em.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.nn import Parameter 4 | import pyflann 5 | import numpy as np 6 | from torch.autograd import Variable 7 | import random 8 | 9 | def inverse_distance(d, epsilon=1e-3): 10 | return 1 / (d + epsilon) 11 | 12 | class DND: 13 | def __init__(self, kernel, num_neighbors, max_memory, lr): 14 | self.kernel = kernel 15 | self.num_neighbors = num_neighbors 16 | self.max_memory = max_memory 17 | self.lr = lr 18 | self.keys = None 19 | self.values = None 20 | pyflann.set_distance_type('euclidean') # squared euclidean actually 21 | self.kdtree = pyflann.FLANN() 22 | # key_cache stores a cache of all keys that exist in the DND 23 | # This makes DND updates efficient 24 | self.key_cache = {} 25 | # stale_index is a flag that indicates whether or not the index in self.kdtree is stale 26 | # This allows us to only rebuild the kdtree index when necessary 27 | self.stale_index = True 28 | # indexes_to_be_updated is the set of indexes to be updated on a call to update_params 29 | # This allows us to rebuild only the keys of key_cache that need to be rebuilt when necessary 30 | self.indexes_to_be_updated = set() 31 | 32 | # Keys and value to be inserted into self.keys and self.values when commit_insert is called 33 | self.keys_to_be_inserted = None 34 | self.values_to_be_inserted = None 35 | 36 | # Move recently used lookup indexes 37 | # These should be moved to the back of self.keys and self.values to get LRU property 38 
| self.move_to_back = set() 39 | 40 | def get_mem_size(self): 41 | if self.keys is not None: 42 | return len(self.keys) 43 | return 0 44 | 45 | def get_index(self, key): 46 | """ 47 | If key exists in the DND, return its index 48 | Otherwise, return None 49 | """ 50 | # print(key.data.cpu().numpy().shape) 51 | if self.key_cache.get(tuple(key.data.cpu().numpy()[0])) is not None: 52 | if self.stale_index: 53 | self.commit_insert() 54 | return int(self.kdtree.nn_index(key.data.cpu().numpy(), 1)[0][0]) 55 | else: 56 | return None 57 | 58 | def update(self, value, index): 59 | """ 60 | Set self.values[index] = value 61 | """ 62 | values = self.values.data 63 | values[index] = value[0].data 64 | self.values = Parameter(values) 65 | # self.optimizer = optim.RMSprop([self.keys, self.values], lr=self.lr) 66 | 67 | def insert(self, key, value): 68 | """ 69 | Insert key, value pair into DND 70 | """ 71 | 72 | if torch.cuda.is_available(): 73 | if self.keys_to_be_inserted is None: 74 | # Initial insert 75 | self.keys_to_be_inserted = Variable(key.data.cuda(), requires_grad=True) 76 | self.values_to_be_inserted = value.data.cuda() 77 | else: 78 | self.keys_to_be_inserted = torch.cat( 79 | [self.keys_to_be_inserted.cuda(), Variable(key.data.cuda(), requires_grad=True)], 0) 80 | self.values_to_be_inserted = torch.cat( 81 | [self.values_to_be_inserted.cuda(), value.data.cuda()], 0) 82 | else: 83 | if self.keys_to_be_inserted is None: 84 | # Initial insert 85 | self.keys_to_be_inserted = Variable(key.data, requires_grad=True) 86 | self.values_to_be_inserted = value.data 87 | else: 88 | self.keys_to_be_inserted = torch.cat( 89 | [self.keys_to_be_inserted,Variable(key.data, requires_grad=True)], 0) 90 | self.values_to_be_inserted = torch.cat( 91 | [self.values_to_be_inserted, value.data], 0) 92 | self.key_cache[tuple(key.data.cpu().numpy()[0])] = 0 93 | self.stale_index = True 94 | 95 | def commit_insert(self): 96 | if self.keys is None: 97 | self.keys = self.keys_to_be_inserted 98 | 
self.values = self.values_to_be_inserted 99 | elif self.keys_to_be_inserted is not None: 100 | self.keys = torch.cat([self.keys.data, self.keys_to_be_inserted], 0) 101 | self.values = torch.cat([self.values.data, self.values_to_be_inserted], 0) 102 | # Move most recently used key-value pairs to the back 103 | if len(self.move_to_back) != 0: 104 | self.keys = torch.cat([self.keys.data[list(set(range(len( 105 | self.keys))) - self.move_to_back)], self.keys.data[list(self.move_to_back)]], 0) 106 | self.values = torch.cat([self.values.data[list(set(range(len( 107 | self.values))) - self.move_to_back)], self.values.data[list(self.move_to_back)]], 0) 108 | self.move_to_back = set() 109 | 110 | if len(self.keys) > self.max_memory: 111 | # Expel oldest key to maintain total memory 112 | for key in self.keys[:-self.max_memory]: 113 | del self.key_cache[tuple(key.data.cpu().numpy())] 114 | self.keys = self.keys[-self.max_memory:].data 115 | self.values = self.values[-self.max_memory:].data 116 | self.keys_to_be_inserted = None 117 | self.values_to_be_inserted = None 118 | # self.optimizer = optim.RMSprop([self.keys, self.values], lr=self.lr) 119 | self.kdtree.build_index(self.keys.data.cpu().numpy(), algorithm='kdtree') 120 | self.stale_index = False 121 | 122 | 123 | # def lookup_batch(self, lookup_key, update_flag=False): 124 | # """ 125 | # Perform DND lookup 126 | # If update_flag == True, add the nearest neighbor indexes to self.indexes_to_be_updated 127 | # """ 128 | # lookup_indexesb, dists = self.kdtree.nn_index( 129 | # lookup_key.data.cpu().numpy(), min(self.num_neighbors, len(self.keys))) 130 | # 131 | # self.values[lookup_indexesb] 132 | # outs = [] 133 | # for b, lookup_indexes in enumerate(lookup_indexesb): 134 | # output = 0 135 | # kernel_sum = 0 136 | # for i, index in enumerate(lookup_indexes): 137 | # if i == 0 and self.key_cache.get(tuple(lookup_key[0].data.cpu().numpy())) is not None: 138 | # # If a key exactly equal to lookup_key is used in the DND 
lookup calculation 139 | # # then the loss becomes non-differentiable. Just skip this case to avoid the issue. 140 | # continue 141 | # if update_flag: 142 | # self.indexes_to_be_updated.add(int(index)) 143 | # else: 144 | # self.move_to_back.add(int(index)) 145 | # kernel_val = self.kernel(self.keys[int(index)], lookup_key[b]) 146 | # output += kernel_val * self.values[int(index)] 147 | # kernel_sum += kernel_val 148 | # output = output / kernel_sum 149 | # outs.append(output) 150 | # return torch.stack(outs) 151 | 152 | def nearest(self, lookup_key): 153 | lookup_indexesb, distb = self.kdtree.nn_index( 154 | lookup_key.data.cpu().numpy(), 1) 155 | indexes = torch.LongTensor(lookup_indexesb).view(-1) 156 | values = torch.tensor(self.values).gather(0, indexes.to(lookup_key.device)) 157 | return values.reshape(lookup_key.shape[0], -1) 158 | 159 | 160 | 161 | def lookup(self, lookup_key, update_flag=False, is_learning=False, p=0.7, K=0): 162 | """ 163 | Perform DND lookup 164 | If update_flag == True, add the nearest neighbor indexes to self.indexes_to_be_updated 165 | """ 166 | if K<=0: 167 | K=self.num_neighbors 168 | lookup_indexesb, distb = self.kdtree.nn_index( 169 | lookup_key.data.cpu().numpy(), min(K, len(self.keys))) 170 | indexes = torch.LongTensor(lookup_indexesb).view(-1) 171 | 172 | # print(self.kdtree.nn_index( 173 | # lookup_key.data.cpu().numpy(), min(self.num_neighbors, len(self.keys)))[0]) 174 | # print("----") 175 | old_indexes = indexes 176 | vshape = torch.tensor(self.values).shape 177 | if len(vshape)==2: 178 | indexes=indexes.unsqueeze(-1).repeat(1, vshape[1]) 179 | 180 | kvalues = inverse_distance(torch.tensor(distb).to(lookup_key.device)) 181 | kvalues = kvalues/torch.sum(kvalues, dim=-1, keepdim=True) 182 | #print(torch.tensor(self.values)) 183 | #print(indexes) 184 | values = torch.tensor(self.values).gather(0, indexes.to(lookup_key.device)) 185 | 186 | #print(values) 187 | #raise False 188 | if len(vshape)==2: 189 | values = 
values.reshape(lookup_key.shape[0],vshape[1], -1) 190 | kvalues = kvalues.unsqueeze(1) 191 | else: 192 | values = values.reshape(lookup_key.shape[0],-1) 193 | if random.random() > p: 194 | return torch.max(values, dim=-1)[0].detach() 195 | if not is_learning: 196 | self.move_to_back.update(old_indexes.numpy()) 197 | 198 | return torch.sum(kvalues*values, dim=-1).detach() 199 | 200 | # outs = [] 201 | # for b, lookup_indexes in enumerate(lookup_indexesb): 202 | # output = 0 203 | # kernel_sum = 0 204 | # for i, index in enumerate(lookup_indexes): 205 | # if i == 0 and self.key_cache.get(tuple(lookup_key[0].data.cpu().numpy())) is not None: 206 | # # If a key exactly equal to lookup_key is used in the DND lookup calculation 207 | # # then the loss becomes non-differentiable. Just skip this case to avoid the issue. 208 | # continue 209 | # # if update_flag: 210 | # # self.indexes_to_be_updated.add(int(index)) 211 | # # else: 212 | # # self.move_to_back.add(int(index)) 213 | # # kernel_val = self.kernel(self.keys[int(index)], lookup_key[b]) 214 | # kernel_val = inverse_distance(distb[b][i]) 215 | # output += kernel_val * self.values[int(index)] 216 | # kernel_sum += kernel_val 217 | # output = output / kernel_sum 218 | # outs.append(output) 219 | # return torch.stack(outs) 220 | 221 | def lookup_grad(self, lookup_key, update_flag=False, is_learning=False): 222 | """ 223 | Perform DND lookup 224 | If update_flag == True, add the nearest neighbor indexes to self.indexes_to_be_updated 225 | """ 226 | lookup_indexesb, distb = self.kdtree.nn_index( 227 | lookup_key.data.cpu().numpy(), min(self.num_neighbors, len(self.keys))) 228 | 229 | train_var = [] 230 | outs = [] 231 | for b, lookup_indexes in enumerate(lookup_indexesb): 232 | output = 0 233 | kernel_sum = 0 234 | for i, index in enumerate(lookup_indexes): 235 | kkkk = self.keys[int(index)].detach().requires_grad_(True) 236 | train_var.append((kkkk, index)) 237 | kernel_val = self.kernel(kkkk, lookup_key[b]) 238 | 
#kernel_val = inverse_distance(distb[b][i]) 239 | output += kernel_val * self.values[int(index)] 240 | kernel_sum += kernel_val 241 | output = output / kernel_sum 242 | outs.append(output) 243 | return torch.stack(outs), train_var 244 | 245 | 246 | def lookup2write(self, lookup_key, R, update_flag=False, K=0): 247 | """ 248 | Perform DND lookup 249 | If update_flag == True, add the nearest neighbor indexes to self.indexes_to_be_updated 250 | """ 251 | if K<=0: 252 | K=self.num_neighbors 253 | lookup_indexesb, distb = self.kdtree.nn_index( 254 | lookup_key.data.cpu().numpy(), min(K, len(self.keys))) 255 | #print(lookup_indexesb.shape) 256 | 257 | for b, lookup_indexes in enumerate(lookup_indexesb): 258 | ks = [] 259 | kernel_sum = 0 260 | if K == 1 and len(lookup_indexes.shape)==1: 261 | lookup_indexes=[lookup_indexes] 262 | distb=[distb] 263 | if isinstance(lookup_indexes, np.int32): 264 | lookup_indexes = [lookup_indexes] 265 | distb=[distb] 266 | 267 | for i, index in enumerate(lookup_indexes): 268 | # if i == 0 and self.key_cache.get(tuple(lookup_key[0].data.cpu().numpy())) is not None: 269 | # If a key exactly equal to lookup_key is used in the DND lookup calculation 270 | # then the loss becomes non-differentiable. Just skip this case to avoid the issue. 
271 | # continue 272 | if update_flag: 273 | self.indexes_to_be_updated.add(int(index)) 274 | else: 275 | self.move_to_back.add(int(index)) 276 | curv = self.values[int(index)] 277 | # kernel_val = self.kernel(self.keys[int(index)], lookup_key[b]) 278 | kernel_val = inverse_distance(distb[b][i]) 279 | kernel_sum += kernel_val 280 | ks.append((index,kernel_val, curv)) 281 | # self.update((R - curv) * kernel_val * self.lr + curv, index) 282 | # kernel_val = 1 283 | for index, kernel_val, curv in ks: 284 | self.update((R-curv)*kernel_val/kernel_sum*self.lr + curv, index) 285 | #self.update(torch.max(R,curv), index) 286 | 287 | 288 | def update_params(self): 289 | """ 290 | Update self.keys and self.values via backprop 291 | Use self.indexes_to_be_updated to update self.key_cache accordingly and rebuild the index of self.kdtree 292 | """ 293 | for index in self.indexes_to_be_updated: 294 | del self.key_cache[tuple(self.keys[index].data.cpu().numpy())] 295 | # self.optimizer.step() 296 | # self.optimizer.zero_grad() 297 | for index in self.indexes_to_be_updated: 298 | self.key_cache[tuple(self.keys[index].data.cpu().numpy())] = 0 299 | self.indexes_to_be_updated = set() 300 | self.kdtree.build_index(self.keys.data.cpu().numpy()) 301 | self.stale_index = False 302 | -------------------------------------------------------------------------------- /dqn_mbec.py: -------------------------------------------------------------------------------- 1 | import math, random 2 | 3 | import gym 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.autograd as autograd 10 | import torch.nn.functional as F 11 | from torch.optim.lr_scheduler import StepLR 12 | import os, json, copy, pickle 13 | from IPython.display import clear_output 14 | import matplotlib.pyplot as plt 15 | from tensorboard_logger import configure, log_value 16 | from argparse import ArgumentParser 17 | from collections import deque 18 | from tqdm import 
tqdm 19 | import em as dnd 20 | import controller 21 | 22 | USE_CUDA = torch.cuda.is_available() 23 | # USE_CUDA= False 24 | Variable = lambda *args, **kwargs: autograd.Variable(*args, **kwargs).cuda() if USE_CUDA else autograd.Variable(*args, **kwargs) 25 | 26 | 27 | 28 | 29 | 30 | class ReplayBuffer(object): 31 | def __init__(self, capacity): 32 | self.buffer = deque(maxlen=capacity) 33 | 34 | def push(self, state, h_trj, action, reward, old_reward, next_state, nh_trj, done): 35 | state = np.expand_dims(state, 0) 36 | next_state = np.expand_dims(next_state, 0) 37 | h_trj = np.expand_dims(h_trj, 0) 38 | nh_trj = np.expand_dims(nh_trj, 0) 39 | 40 | self.buffer.append((state, h_trj, action, reward, old_reward, next_state, nh_trj, done)) 41 | 42 | def sample(self, batch_size): 43 | state, h_trj, action, reward, old_reward, next_state, nh_trj, done = zip(*random.sample(self.buffer, batch_size)) 44 | return np.concatenate(state), np.concatenate(h_trj), action, reward, old_reward, \ 45 | np.concatenate(next_state), np.concatenate(nh_trj), done 46 | 47 | def __len__(self): 48 | return len(self.buffer) 49 | 50 | 51 | 52 | 53 | mse_criterion = nn.MSELoss() 54 | 55 | 56 | # plt.plot([epsilon_by_frame(i) for i in range(10000)]) 57 | # plt.show() 58 | def inverse_distance(h, h_i, epsilon=1e-3): 59 | return 1 / (torch.dist(h, h_i) + epsilon) 60 | 61 | def gauss_kernel(h, h_i, w=0.5): 62 | return torch.exp(-torch.dist(h, h_i)**2/w) 63 | 64 | def no_distance(h, h_i, epsilon=1e-3): 65 | return 1 66 | 67 | 68 | cos = nn.CosineSimilarity(dim=0, eps=1e-6) 69 | def cosine_distance(h, h_i): 70 | return max(cos(h, h_i),0) 71 | 72 | def weights_init(m): 73 | classname = m.__class__.__name__ 74 | if classname.find('Conv') != -1: 75 | weight_shape = list(m.weight.data.size()) 76 | fan_in = np.prod(weight_shape[1:4]) 77 | fan_out = np.prod(weight_shape[2:4]) * weight_shape[0] 78 | w_bound = np.sqrt(6. 
/ (fan_in + fan_out)) 79 | m.weight.data.uniform_(-w_bound, w_bound) 80 | m.bias.data.fill_(0) 81 | elif classname.find('Linear') != -1: 82 | weight_shape = list(m.weight.data.size()) 83 | fan_in = weight_shape[1] 84 | fan_out = weight_shape[0] 85 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 86 | m.weight.data.uniform_(-w_bound, w_bound) 87 | m.bias.data.fill_(0) 88 | 89 | class DQN_DTM(nn.Module): 90 | def __init__(self, env, args): 91 | super(DQN_DTM, self).__init__() 92 | if "maze_img" in args.task2: 93 | if "cnn" not in args.task2: 94 | self.num_inputs = 8 95 | self.proj= nn.Linear(514, 8) 96 | else: 97 | self.cnn = img_featurize.CNN(64) 98 | self.num_inputs = 64 99 | 100 | elif "world" in args.task2: 101 | self.cnn = img_featurize.CNN2(256) 102 | self.num_inputs = 256 103 | for param in self.cnn.parameters(): 104 | param.requires_grad = False 105 | else: 106 | self.num_inputs = env.observation_space.shape[0] 107 | if "trap" in args.task2: 108 | self.num_inputs = self.num_inputs + 2 109 | self.num_actions = env.action_space.n 110 | self.model_name=args.model_name 111 | self.num_warm_up=args.num_warm_up 112 | self.replay_buffer = args.replay_buffer 113 | self.gamma = args.gamma 114 | self.last_inserts=[] 115 | self.insert_size = args.insert_size 116 | self.args= args 117 | 118 | self.qnet = nn.Sequential( 119 | nn.Linear(self.num_inputs, args.qnet_size), 120 | nn.ReLU(), 121 | nn.Linear(args.qnet_size, args.qnet_size), 122 | nn.ReLU(), 123 | nn.Linear(args.qnet_size, env.action_space.n) 124 | ) 125 | 126 | self.qnet_target = nn.Sequential( 127 | nn.Linear(self.num_inputs, args.qnet_size), 128 | nn.ReLU(), 129 | nn.Linear(args.qnet_size, args.qnet_size), 130 | nn.ReLU(), 131 | nn.Linear(args.qnet_size, env.action_space.n) 132 | ) 133 | 134 | 135 | if args.write_interval<0: 136 | self.state2key = nn.Linear(self.num_inputs, args.hidden_size) 137 | self.dnd = dnd.DND(no_distance, num_neighbors=args.k, max_memory=args.memory_size, lr=args.write_lr) 138 | else: 139 | 
self.dnd = dnd.DND(inverse_distance, num_neighbors=args.k, max_memory=args.memory_size, lr=args.write_lr) 140 | 141 | self.emb_index2count = {} 142 | self.act_net = nn.Linear(self.num_actions, self.num_actions) 143 | self.act_net_target = nn.Linear(self.num_actions, self.num_actions) 144 | 145 | self.choice_net = nn.Sequential( 146 | nn.Linear(args.hidden_size, args.hidden_size),nn.ReLU(), 147 | nn.Linear(args.hidden_size, 1)) 148 | self.choice_net_target = nn.Sequential( 149 | nn.Linear(args.hidden_size, args.hidden_size),nn.ReLU(), 150 | nn.Linear(args.hidden_size, 1)) 151 | self.alpha = nn.Parameter(torch.tensor(1.0), 152 | requires_grad=True) 153 | self.alpha_target = nn.Parameter(torch.tensor(1.0), 154 | requires_grad=True) 155 | self.beta = nn.Parameter(torch.ones(self.num_actions), 156 | requires_grad=True) 157 | self.trj_model = controller.LSTMController(self.num_inputs+self.num_actions, args.hidden_size, num_layers=1) 158 | self.trj_out = nn.Linear(args.hidden_size, self.num_inputs+self.num_actions+1) 159 | self.reward_model = nn.Sequential( 160 | nn.Linear(self.num_inputs+self.num_actions+args.hidden_size, args.reward_hidden_size), 161 | nn.ReLU(), 162 | nn.Linear(args.reward_hidden_size,args.reward_hidden_size), 163 | nn.ReLU(), 164 | nn.Linear(args.reward_hidden_size, 1), 165 | ) 166 | self.best_trj = [] 167 | self.optimizer_dnd = torch.optim.Adam(self.trj_model.parameters()) 168 | # self.future_net = nn.Sequential( 169 | # nn.Linear(args.hidden_size, args.hidden_size), 170 | # nn.ReLU(), 171 | # nn.Linear(args.hidden_size, args.hidden_size), 172 | # ) 173 | self.future_net = nn.Sequential( 174 | nn.Linear(args.hidden_size, args.mem_dim), 175 | ) 176 | 177 | for param in self.future_net.parameters(): 178 | param.requires_grad = False 179 | if args.rec_period<0: 180 | for param in self.trj_model.parameters(): 181 | param.requires_grad = False 182 | 183 | self.apply(weights_init) 184 | 185 | def forward(self, x, h_trj, episode=0, use_mem=1.0, target=0, 
r=None, a=None, is_learning=False): 186 | 187 | q_value_semantic = self.semantic_net(x, target) 188 | z = torch.zeros(x.shape[0], self.num_actions) 189 | if USE_CUDA: 190 | z = z.cuda() 191 | q_episodic = qt= z 192 | if episode > self.num_warm_up and random.random() 1: 193 | plan_step = 1 194 | lx = x 195 | lh_trj = h_trj 196 | a0=a 197 | 198 | if self.args.write_interval<0: 199 | fh_trj = self.state2key(x) 200 | q_episodic = self.dnd.lookup(fh_trj, is_learning=is_learning, p=args.pread) 201 | #print(q_episodic) 202 | # print("q lookup", q_estimates) 203 | else: 204 | if a is not None: 205 | for i in range(plan_step): 206 | lx, h_trj_a = self.trj_model(self.make_trj_input(lx, a, r), lh_trj) 207 | if plan_step>1 or r is not None: 208 | lx = self.trj_out(lx) 209 | lh_trj = h_trj_a 210 | if len(lx.shape)>1: 211 | lx = lx[:,:self.num_inputs] 212 | else: 213 | lx = lx[:self.num_inputs] 214 | 215 | if r is None: 216 | r = self.reward_model( 217 | torch.cat([self.make_trj_input(lx, a), lh_trj[0][0].detach()], dim=-1)) 218 | 219 | q_episodic[:, a0] = r + self.args.gamma*self.episodic_net(h_trj_a, is_learning) 220 | if plan_step>1: 221 | for aa in range(self.num_actions): 222 | _, h_trj_aa = self.trj_model(self.make_trj_input(lx, aa, r), h_trj_a) 223 | qt[:, aa] = self.episodic_net(h_trj_aa, is_learning) 224 | a= qt.max(1)[1] 225 | 226 | else: 227 | # if random.random()<0.1: 228 | # if self.best_trj: 229 | # q_episodic = self.exploit(lx) 230 | # else: 231 | #print("plan") 232 | for a in range(self.num_actions): 233 | #print(a) 234 | # print(self.make_trj_input(lx, a)) 235 | lxx, h_trj_aa = self.trj_model(self.make_trj_input(lx, a), lh_trj) 236 | if r is None: 237 | # lxx = self.trj_out(lxx) 238 | # if len(lx.shape)>1: 239 | # pr = lxx[:,self.num_inputs+self.num_actions:] 240 | # lxx = lxx[:,:self.num_inputs] 241 | # else: 242 | # pr = lxx[self.num_inputs+self.num_actions:] 243 | # lxx = lxx[:self.num_inputs] 244 | pr = self.reward_model(torch.cat([self.make_trj_input(lx, 
a),lh_trj[0][0].detach()], dim=-1)) 245 | #print('predicted r: ',pr) 246 | else: 247 | pr = r 248 | 249 | pr = pr.to(device=lxx.device).squeeze(-1) 250 | 251 | q_episodic[:,a] = pr+self.args.gamma*self.episodic_net(h_trj_aa, is_learning) 252 | #print(lx, a, q_episodic[:,a]) 253 | #print('next v', self.episodic_net(h_trj_aa, is_learning)) 254 | #print('cur v est, ', q_episodic[:,a]) 255 | if is_learning is False and random.random()0: 287 | a = args.fix_alpha 288 | else: 289 | if target == 0: 290 | a = self.choice_net(h_trj[0][0]) 291 | # a = self.alpha 292 | else: 293 | a = self.choice_net_target(h_trj[0][0]) 294 | # a = self.alpha_target 295 | # a = self.alpha 296 | a = F.sigmoid(a) 297 | #q_value_semantic = self.semantic_net(torch.cat([x, a*q_episodic], dim=-1), target) 298 | 299 | #if random.random()<0.0003: 300 | # print(q_value_semantic[0], q_episodic[0]*a[0]) 301 | #a = F.sigmoid(self.choice_net(x)) 302 | #a = F.tanh(self.alpha) 303 | # a=0 304 | # return self.act_net(q_episodic) 305 | # if target == 0: 306 | if self.args.td_interval>0: 307 | if target == 0: 308 | return q_episodic*a+q_value_semantic,q_value_semantic, q_episodic*a 309 | else: 310 | return q_episodic*a+q_value_semantic,q_value_semantic, q_episodic*a 311 | # if target == 0: 312 | # return self.act_net(q_episodic*a+q_value_semantic),q_value_semantic, q_episodic*a 313 | # else: 314 | # return self.act_net_target(q_episodic*a+q_value_semantic),q_value_semantic, q_episodic*a 315 | 316 | 317 | # return q_episodic*a+ q_value_semantic,q_value_semantic, q_episodic 318 | # else: 319 | # return q_value_semantic, q_value_semantic, q_value_semantic 320 | # return q_value 321 | return q_episodic, q_episodic, q_episodic 322 | # return q_value_semantic*F.sigmoid(q_episodic),q_value_semantic, q_episodic 323 | # op = torch.matmul(q_value_semantic.unsqueeze(-1),F.sigmoid(q_episodic).unsqueeze(1)) 324 | # q = torch.matmul(op, self.beta.unsqueeze(0).repeat(q_value_semantic.shape[0], 1).unsqueeze(-1)) 325 | # return 
q.squeeze(-1),q_value_semantic, q_episodic 326 | 327 | 328 | def exploit(self, x): 329 | batch_size = x.shape[0] 330 | q_episodics = [] 331 | 332 | for (h_trj,v) in self.best_trj: 333 | z = torch.zeros(x.shape[0], self.num_actions) 334 | if USE_CUDA: 335 | z = z.cuda() 336 | kw = z 337 | last_h_trj = (h_trj[0].repeat(1, batch_size, 1), h_trj[1].repeat(1, batch_size, 1)) 338 | 339 | for a in range(self.num_actions): 340 | X = self.make_trj_input(x, a) 341 | y_p, nh = self.trj_model(X, last_h_trj) 342 | y_p = self.trj_out(y_p) 343 | rec_loss = torch.norm(y_p[:,:self.num_actions+self.num_inputs]- 344 | X[:, :self.num_actions + self.num_inputs], dim=-1, keepdim=True) 345 | kw[:,a] = -rec_loss 346 | 347 | kw = torch.exp(kw) 348 | kw = kw/torch.sum(kw, dim=-1, keepdim=True) 349 | vs = kw*1.6**v 350 | q_episodics.append(vs) 351 | return torch.mean(torch.stack(q_episodics, dim=0), dim=0) 352 | 353 | def value(self, s, h=None): 354 | if h is None: 355 | h = self.trj_model.create_new_state(1) 356 | q, qs, qe = self.forward(s, h) 357 | return torch.max(q, dim=-1)[0] 358 | else: 359 | return self.episodic_net(h) 360 | 361 | def planning(self, x, h_trj, a=None, plan_step=1): 362 | actions = [] 363 | qa = 0 364 | qsa = 0 365 | qea = 0 366 | 367 | t = 0 368 | for s in range(plan_step): 369 | q, qs, qe = self.forward(x, h_trj, a=a) 370 | action = q.max(1)[1].item() 371 | y_trj, h_trj = self.trj_model(self.make_trj_input(x, action), h_trj) 372 | actions.append(action) 373 | qa+=q*(self.args.gamma**t) 374 | qsa+=qs*(self.args.gamma**t) 375 | qea+=qe*(self.args.gamma**t) 376 | 377 | 378 | t+=1 379 | x = y_trj[:,:self.num_inputs] 380 | 381 | return qa, qsa, qea, actions 382 | 383 | def get_pivot_lastinsert(self): 384 | if len(self.last_inserts)>0: 385 | return min(self.last_inserts) 386 | else: 387 | return -10000000 388 | 389 | def semantic_net(self, x, target=0): 390 | if "maze_img" in self.args.task2: 391 | if "cnn" not in self.args.task2: 392 | x = self.proj(x).detach() 393 | 
else: 394 | x = self.cnn(x) 395 | elif "world" in self.args.task2: 396 | x = self.cnn(x) 397 | 398 | if target == 0: 399 | return self.qnet(x) 400 | else: 401 | return self.qnet_target(x) 402 | 403 | def episodic_net(self, h_trj, is_learning=False, K=0): 404 | fh_trj = self.future_net(h_trj[0][0]) 405 | # fh_trj = h_trj[0][0].detach() 406 | q_estimates = self.dnd.lookup(fh_trj, is_learning=is_learning, K=K, p=args.pread) 407 | # print("q lookup", q_estimates) 408 | return q_estimates 409 | 410 | def make_trj_input(self, x, a, r=None): 411 | 412 | if "maze_img" in self.args.task2: 413 | if "cnn" not in self.args.task2: 414 | x = self.proj(x).detach() 415 | else: 416 | if len(x.shape)==3: 417 | x = x.unsqueeze(0) 418 | 419 | x = self.cnn(x) 420 | elif "world" in self.args.task2: 421 | x = self.cnn(x) 422 | 423 | a_vec = torch.zeros(x.shape[0],self.num_actions) 424 | a_vec[:,a] = 1 425 | 426 | if USE_CUDA: 427 | a_vec = a_vec.cuda() 428 | # r = r.cuda() 429 | x = torch.cat([x, a_vec],dim=-1) 430 | 431 | return x 432 | 433 | 434 | def add_trj(self, h_trj, R, step, episode, action): 435 | # print(f"add R {R}") 436 | if self.args.write_interval<0: 437 | h_trj = h_trj 438 | else: 439 | h_trj = h_trj[0][0] 440 | hkey = torch.as_tensor(h_trj).float()#.detach() 441 | if USE_CUDA: 442 | hkey = hkey.cuda() 443 | hkey = self.future_net(hkey) 444 | # t = torch.Tensor([0.5]) 445 | # hkey = (F.sigmoid(hkey) > t).float() * 1 446 | if self.args.write_interval<0: 447 | hkey = self.state2key(hkey.cuda()) 448 | rvec = torch.zeros(1, self.num_actions) 449 | rvec[0,action] = R 450 | #print(hkey, rvec) 451 | #raise False 452 | else: 453 | rvec = R.unsqueeze(0) 454 | if USE_CUDA: 455 | rvec = rvec.cuda() 456 | 457 | # print(hkey) 458 | embedding_index = self.dnd.get_index(hkey) 459 | if embedding_index is None: 460 | self.dnd.insert(hkey, rvec.detach()) 461 | 462 | if self.insert_size>0: 463 | if len(self.last_inserts) > self.insert_size: 464 | self.last_inserts.sort() 465 | 
self.last_inserts = self.last_inserts[1:-1] 466 | self.last_inserts.append(rvec.detach()) 467 | if episode>self.num_warm_up and self.dnd.keys is not None: 468 | #try: 469 | self.dnd.lookup2write(hkey, rvec.detach()) 470 | #except Exception as e: 471 | # print(e) 472 | else: 473 | #print("dssssssss") 474 | if embedding_index not in self.emb_index2count: 475 | self.emb_index2count[embedding_index] = 1 476 | 477 | 478 | if self.args.write_interval<0: 479 | rvec = torch.zeros(1, self.num_actions) 480 | rvec[0, action] = R 481 | rvec = R.unsqueeze(0) 482 | if USE_CUDA: 483 | rvec = rvec.cuda() 484 | self.dnd.update(torch.max(rvec, 485 | torch.tensor(self.emb_index2count[embedding_index]).float().unsqueeze(0).to( 486 | device=rvec.device)), 487 | embedding_index) 488 | else: 489 | #R = self.dnd.values[embedding_index]*self.emb_index2count[embedding_index]+ R 490 | #R = R/(self.emb_index2count[embedding_index]+1) 491 | #self.emb_index2count[embedding_index]+=1 492 | #self.dnd.update(R.unsqueeze(0).detach(), embedding_index) 493 | 494 | # self.dnd.update(torch.max(R.unsqueeze(0), 495 | # torch.tensor(self.emb_index2count[embedding_index]).float().unsqueeze(0).to(device=R.device)), 496 | # embedding_index) 497 | 498 | if episode > self.num_warm_up and self.dnd.keys is not None: 499 | # try: 500 | self.dnd.lookup2write(hkey, rvec.detach(), K=args.k_write) 501 | 502 | def compute_rec_loss(self, last_h_trj, traj_buffer, optimizer, batch_size, noise=0.1): 503 | # print('len ', len(traj_buffer)) 504 | sasr = random.choices(traj_buffer, k=batch_size-1) 505 | sasr.append(traj_buffer[-1]) 506 | 507 | X = [] 508 | y = [] 509 | y2 = [] 510 | hs1 = [] 511 | hs2 = [] 512 | for s1,h, a,s2,r,o_r in sasr: 513 | s1 = torch.as_tensor(s1).float() 514 | s2 = torch.as_tensor(s2).float() 515 | if USE_CUDA: 516 | s1 = s1.cuda() 517 | s2 = s2.cuda() 518 | if len(s1.shape)==1 or len(s1.shape)==3: 519 | s1 = s1.unsqueeze(0) 520 | s2 = s2.unsqueeze(0) 521 | 522 | o_r = 
torch.FloatTensor([o_r]).unsqueeze(0) 523 | r = torch.FloatTensor([r]).unsqueeze(0) 524 | 525 | x = self.make_trj_input(s1, a, o_r) 526 | x2 = self.make_trj_input(s2, a, o_r) 527 | 528 | #X.append(x) 529 | if noise>0: 530 | if random.random()>0.5: 531 | X.append(F.dropout(x, p=noise)) 532 | else: 533 | noise_tensor = ((torch.max(torch.abs(x))*noise)**0.5)*torch.randn(x.shape) 534 | if USE_CUDA: 535 | noise_tensor = torch.tensor(noise_tensor).cuda() 536 | X.append(x + noise_tensor.float()) 537 | else: 538 | X.append(x) 539 | y.append(x) 540 | 541 | y2.append(torch.cat([x2, r.to(device=x2.device)], dim=-1)) 542 | 543 | hs1.append(torch.tensor(h[0]).to(device=last_h_trj[0].device)) 544 | hs2.append(torch.tensor(h[1]).to(device=last_h_trj[0].device)) 545 | X = torch.stack(X, dim=0) 546 | y = torch.stack(y, dim=0) 547 | y2 = torch.stack(y2, dim=0).squeeze(1) 548 | 549 | last_h_trj = (last_h_trj[0].repeat(1, batch_size, 1), last_h_trj[1].repeat(1, batch_size, 1)) 550 | cur_h_trj = (torch.cat(hs1, dim=1), 551 | torch.cat(hs2, dim=1)) 552 | 553 | if args.rec_type=="pred": 554 | last_h_trj = cur_h_trj 555 | y_p, _ = self.trj_model(X.squeeze(1), last_h_trj) 556 | y_p = self.trj_out(y_p) 557 | # pr = self.reward_model(torch.cat([X.squeeze(1), cur_h_trj],dim=-1)) 558 | _, h_trj = self.trj_model(y.squeeze(1), cur_h_trj) 559 | # h_pred = self.future_net(h_trj[0][0]) 560 | h_pred = h_trj[0][0] 561 | #print(y_p.shape) 562 | #print(X) 563 | #print(y2[:, self.num_inputs + self.num_actions:]) 564 | l1 = mse_criterion(y_p[:, :self.num_inputs + self.num_actions], y2[:, :self.num_inputs + self.num_actions]) 565 | # l2 = mse_criterion(y_p[:,self.num_inputs+self.num_actions:], y2[:,self.num_inputs+self.num_actions:]) 566 | #l2 = mse_criterion(pr, y2[:, self.num_inputs + self.num_actions:]) 567 | l3 = mse_criterion(cur_h_trj[0][0], last_h_trj[0][0]) 568 | loss = l1 569 | #loss = loss + l2 570 | # loss = loss + l3 571 | # loss = mse_criterion(h_pred, last_h_trj[0][0]) 572 | 
optimizer.zero_grad() 573 | loss.backward(retain_graph=True) 574 | torch.nn.utils.clip_grad_norm(self.parameters(), 10) 575 | optimizer.step() 576 | 577 | return loss, l1, 0, l3 578 | 579 | def compute_reward_loss(self, last_h_trj, traj_buffer, optimizer, batch_size, noise=0.1): 580 | # print('len ', len(traj_buffer)) 581 | sasr = random.choices(traj_buffer, k=batch_size-1) 582 | sasr.append(traj_buffer[-1]) 583 | 584 | X = [] 585 | y = [] 586 | y2 = [] 587 | hs1 = [] 588 | hs2 = [] 589 | for s1,h, a,s2,r,o_r in sasr: 590 | s1 = torch.as_tensor(s1).float() 591 | s2 = torch.as_tensor(s2).float() 592 | if USE_CUDA: 593 | s1 = s1.cuda() 594 | s2 = s2.cuda() 595 | if len(s1.shape)==1 or len(s1.shape)==3: 596 | s1 = s1.unsqueeze(0) 597 | s2 = s2.unsqueeze(0) 598 | 599 | o_r = torch.FloatTensor([o_r]).unsqueeze(0) 600 | r = torch.FloatTensor([r]).unsqueeze(0) 601 | 602 | x = self.make_trj_input(s1, a, o_r) 603 | x2 = self.make_trj_input(s2, a, o_r) 604 | 605 | #X.append(x) 606 | if noise>0: 607 | if random.random()>0.5: 608 | X.append(F.dropout(x, p=noise)) 609 | else: 610 | noise_tensor = ((torch.max(torch.abs(x))*noise)**0.5)*torch.randn(x.shape) 611 | if USE_CUDA: 612 | noise_tensor = torch.tensor(noise_tensor).cuda() 613 | X.append(x + noise_tensor.float()) 614 | else: 615 | X.append(x) 616 | 617 | y.append(x) 618 | 619 | y2.append(torch.cat([x2, r.to(device=x2.device)], dim=-1)) 620 | 621 | hs1.append(torch.tensor(h[0]).to(device=last_h_trj[0].device)) 622 | hs2.append(torch.tensor(h[1]).to(device=last_h_trj[0].device)) 623 | X = torch.stack(X, dim=0) 624 | y = torch.stack(y, dim=0) 625 | y2 = torch.stack(y2, dim=0).squeeze(1) 626 | cur_h_trj = (torch.cat(hs1, dim=1), 627 | torch.cat(hs2, dim=1)) 628 | 629 | 630 | # print(X) 631 | # print(y2[:, self.num_inputs + self.num_actions:]) 632 | 633 | pr = self.reward_model(torch.cat([X.squeeze(1), cur_h_trj[0][0]],dim=-1)) 634 | l2 = mse_criterion(pr, y2[:, self.num_inputs + self.num_actions:]) 635 | optimizer.zero_grad() 
636 | l2.backward() 637 | optimizer.step() 638 | return l2 639 | 640 | def compute_td_loss(self, optimizer, batch_size, episode=0): 641 | state, h_trj, action, reward, old_reward, next_state, nh_trj, done = self.replay_buffer.sample(batch_size) 642 | 643 | state = Variable(torch.FloatTensor(np.float32(state))) 644 | next_state = Variable(torch.FloatTensor(np.float32(next_state)), volatile=True) 645 | action = Variable(torch.LongTensor(action)) 646 | reward = Variable(torch.FloatTensor(reward)) 647 | old_reward = Variable(torch.FloatTensor(old_reward)) 648 | 649 | done = Variable(torch.FloatTensor(done)) 650 | 651 | if USE_CUDA: 652 | state = state.cuda() 653 | next_state = next_state.cuda() 654 | action = action.cuda() 655 | reward = reward.cuda() 656 | old_reward = old_reward.cuda() 657 | done = done.cuda() 658 | 659 | 660 | # print(h_trj) 661 | # print(torch.tensor(h_trj[:,0,0,0])) 662 | hx = torch.tensor(h_trj[:,0,0,0]).to(device=state.device).unsqueeze(0)#torch.cat(torch.tensor(h_trj[:,0,0]).tolist(), dim=1) 663 | cx = torch.tensor(h_trj[:,1,0,0]).to(device=state.device).unsqueeze(0)#torch.cat(torch.tensor(h_trj[:,1,0]).tolist(), dim=1) 664 | q_values, q1, q2 = self.forward(state, (hx, cx), episode, use_mem=1, target=0, r=reward.unsqueeze(-1), a=action, is_learning=True) 665 | nhx = torch.tensor(nh_trj[:,0,0,0]).to(device=state.device).unsqueeze(0)#torch.cat(torch.tensor(nh_trj[:, 0,0]).tolist(), dim=1) 666 | ncx = torch.tensor(nh_trj[:,1,0,0]).to(device=state.device).unsqueeze(0)#torch.cat(torch.tensor(nh_trj[:, 1,0]).tolist(), dim=1) 667 | # raction = torch.randint(0, self.num_actions, action.shape) 668 | # if USE_CUDA: 669 | # raction = raction.cuda() 670 | # next_q_values = self.forward(next_state, (nhx, ncx), episode, use_mem=1) 671 | next_q_values, qn1, qn2 = self.forward(next_state, (nhx, ncx), episode, use_mem=1, target=1, r=None, is_learning=True) 672 | q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1) 673 | # q_value2 = q2.gather(1, 
action.unsqueeze(1)).squeeze(1) 674 | next_q_value = next_q_values.max(1)[0] 675 | # next_q_value = next_q_values.gather(1, raction.unsqueeze(1)).squeeze(1) 676 | 677 | expected_q_value = reward + self.gamma * next_q_value * (1 - done) 678 | # next_q_value2 = qn2.max(1)[0] 679 | # expected_q_value2 = reward + self.gamma * next_q_value2 * (1 - done) 680 | # print(q_value) 681 | # print(expected_q_value) 682 | loss = (q_value - Variable(expected_q_value.data)).pow(2).mean() 683 | # loss = (expected_q_value-q_value2).pow(2).mean() 684 | # loss = loss + (q1-q2).pow(2).mean() 685 | # loss = loss + (qn1/qn2-1).pow(2).mean() 686 | 687 | optimizer.zero_grad() 688 | loss.backward() 689 | torch.nn.utils.clip_grad_norm(self.parameters(), self.args.clip) 690 | optimizer.step() 691 | 692 | return loss 693 | 694 | 695 | def act(self, state, h_trj, epsilon, r=0, episode=0): 696 | state = Variable(torch.FloatTensor(state).unsqueeze(0), volatile=True) 697 | actions = actione = None 698 | if random.random() > epsilon and self.dnd.get_mem_size()>1: 699 | q_value, qs, qe = self.forward(state, h_trj, episode, 700 | r=None, 701 | use_mem=1) 702 | 703 | #q_value, qs, qe, _ = self.planning(state, h_trj, plan_step=2) 704 | 705 | action = q_value.max(1)[1].data[0] 706 | actions = qs.max(1)[1].data[0] 707 | actione = qe.max(1)[1].data[0] 708 | else: 709 | action = random.randrange(self.num_actions) 710 | 711 | y_trj, h_trj = self.trj_model(self.make_trj_input(state, action, 712 | torch.FloatTensor([r]).unsqueeze(0)), 713 | h_trj) 714 | # print("act {}".format(h_trj[0].shape)) 715 | # y_trj = self.trj_out(y_trj) 716 | return action, h_trj, y_trj, actions, actione 717 | 718 | def update_target(self): 719 | self.qnet_target.load_state_dict(self.qnet.state_dict()) 720 | self.choice_net_target.load_state_dict(self.choice_net.state_dict()) 721 | self.act_net_target.load_state_dict(self.act_net.state_dict()) 722 | self.alpha_target = self.alpha 723 | 724 | 725 | # def plot(frame_idx, rewards, 
td_losses, rec_losses): 726 | # # clear_output(True) 727 | # # plt.figure(figsize=(20,5)) 728 | # # plt.subplot(121) 729 | # # plt.title('frame %s. reward: %s' % (frame_idx, np.mean(rewards[-10:]))) 730 | # # plt.plot(rewards) 731 | # # plt.subplot(122) 732 | # # plt.title('loss') 733 | # # plt.plot(losses) 734 | # # plt.show() 735 | 736 | 737 | 738 | 739 | 740 | 741 | 742 | def run(args): 743 | epsilon_start = args.epsilon_start 744 | epsilon_final = args.epsilon_final 745 | epsilon_decay = args.epsilon_decay 746 | 747 | epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * math.exp( 748 | -1. * frame_idx / epsilon_decay) 749 | 750 | batch_size = args.batch_size 751 | args.replay_buffer = ReplayBuffer(args.replay_size) 752 | env_id = args.task 753 | 754 | 755 | if "world" in args.task2: 756 | env = gym.make(env_id) 757 | 758 | env=img_featurize.FrameStack(env, 4) 759 | env = img_featurize.ImageToPyTorch(env) 760 | else: 761 | env = gym.make(env_id) 762 | 763 | 764 | if args.mode=="train": 765 | log_dir = f'./log/{args.task}{args.task2}new/{args.model_name}' \ 766 | f'k{args.k}n{args.memory_size}i{args.write_interval}w{args.num_warm_up}h{args.hidden_size}' \ 767 | f'u{args.update_interval}b{args.replay_size}l{args.insert_size}a{args.fix_alpha}' \ 768 | f'm{args.min_reward}q{args.qnet_size}-{args.run_id}/' 769 | print(log_dir) 770 | configure(log_dir) 771 | 772 | model = DQN_DTM(env, args) 773 | num_params = 0 774 | for p in model.parameters(): 775 | num_params += p.data.view(-1).size(0) 776 | print(f"no params {num_params}") 777 | 778 | if USE_CUDA: 779 | model = model.cuda() 780 | 781 | model.update_target() 782 | 783 | num_frames = args.n_epochs 784 | optimizer_td = optim.Adam(model.parameters(), lr=args.lr) 785 | scheduler_td = StepLR(optimizer_td, step_size=500, gamma=args.decay) 786 | 787 | optimizer_reward = optim.Adam(model.reward_model.parameters()) 788 | 789 | optimizer_rec = optim.Adam(model.parameters(), lr=args.lr) 790 | 
791 | td_losses = [] 792 | rec_losses = [] 793 | rec_l1s = [] 794 | rec_l2s = [] 795 | rec_l3s = [] 796 | 797 | all_rewards = [] 798 | episode_reward = mepisode_reward = 0 799 | traj_buffer = [] 800 | state_buffer = [] 801 | best_reward = -1000000000 802 | 803 | state = env.reset() 804 | ostate = state.copy() 805 | 806 | 807 | 808 | if "maze_img" in args.task2: 809 | state_map = {} 810 | img = env.render("rgb_array") 811 | if "view" in args.task2: 812 | img = img_featurize.get_viewport(img, state, num_state=(env.observation_space.high[0]+1)**2) 813 | if "cnn" not in args.task2: 814 | # Load the pretrained model 815 | stateim = img_featurize.get_vector(img) 816 | stateim = np.concatenate([stateim, state]) 817 | 818 | state_map[tuple(state.tolist())] = stateim 819 | state = state_map[tuple(state.tolist())] 820 | else: 821 | state_map[tuple(state.tolist())] = img_featurize.get_image(img) 822 | state = state_map[tuple(state.tolist())] 823 | 824 | if "world" in args.task2: 825 | state = img_featurize.get_image2(state) 826 | 827 | 828 | if "trap" in args.task2: 829 | trap = [0, 0] 830 | while trap[0] == 0 and trap[1] == 0 or trap[0] == env.observation_space.high[0] and trap[1] == \ 831 | env.observation_space.high[0]: 832 | trap = [random.randint(0, env.observation_space.high[0]), 833 | random.randint(0, env.observation_space.high[0])] 834 | print("trap ", trap) 835 | state = np.concatenate([state, trap]) 836 | 837 | episode_num = 0 838 | best_episode_reward = -1e9 839 | step = 0 840 | h_trj = model.trj_model.create_new_state(1) 841 | reward = old_reward = 0 842 | m_contr = [] 843 | s_contr = [] 844 | 845 | for frame_idx in tqdm(range(1, num_frames + 1)): 846 | 847 | epsilon = epsilon_by_frame(frame_idx) 848 | action, nh_trj, y_trj, action_s, action_e = model.act(state, h_trj, epsilon, r=reward, episode=episode_num) 849 | 850 | try: 851 | action = action.item() 852 | except: 853 | pass 854 | if action_e is not None: 855 | if action == action_e: 856 | m_contr.append(1) 
857 | else: 858 | m_contr.append(0) 859 | 860 | if action == action_s: 861 | s_contr.append(1) 862 | else: 863 | s_contr.append(0) 864 | 865 | old_reward = reward 866 | next_state, reward, done, _ = env.step(action) 867 | reward = float(reward) 868 | 869 | 870 | 871 | 872 | if "maze" in args.task: 873 | if step > 1000 and "hard" in args.task2: 874 | done = 1 875 | if next_state[0] == ostate[0] and next_state[1] == ostate[1]: 876 | reward = reward - 1 877 | if "hard" in args.task2: 878 | next_state = env.reset() 879 | if "trap" in args.task2 and next_state[0]==trap[0] and next_state[1]==trap[1]: 880 | reward = reward - 2 881 | if "trap_key" in args.task2 and next_state[0]==1 and next_state[1]==1: 882 | trap = [-1, -1] 883 | print("free trap") 884 | 885 | #print(next_state,action, state, reward) 886 | 887 | 888 | m_reward = reward 889 | 890 | #if "world" in args.task2: 891 | # if reward<1.0: 892 | # reward=-0.01 893 | 894 | if args.rnoise>0: 895 | reward += np.random.normal(0, args.rnoise, 1)[0] 896 | if -1 -args.rnoise: 898 | reward = -reward 899 | 900 | if random.random()args.min_reward: 964 | if len(model.last_inserts)>=args.insert_size>0 and RRbest_episode_reward: 974 | best_episode_reward = episode_reward 975 | 976 | # model.best_trj.append((h_trj, episode_reward)) 977 | # if len(model.best_trj)>10: 978 | # model.best_trj = sorted(model.best_trj, key=lambda tup: tup[1]) 979 | # model.best_trj = model.best_trj[1:] 980 | # 981 | 982 | l2 = model.compute_reward_loss(h_trj, traj_buffer, optimizer_reward, args.batch_size_reward, 983 | noise=0) 984 | rec_l2s.append(l2.data.item()) 985 | 986 | if frame_idx < args.rec_period and args.rec_rate>0: 987 | loss, l1, l2, l3 = model.compute_rec_loss(h_trj, traj_buffer, optimizer_rec, args.batch_size_plan, 988 | noise=args.rec_noise) 989 | 990 | rec_losses.append(loss.data.item()) 991 | rec_l1s.append(l1.data.item()) 992 | rec_l3s.append(l3.data.item()) 993 | 994 | if add_mem==1: 995 | model.dnd.commit_insert() 996 | 997 | 
h_trj = model.trj_model.create_new_state(1) 998 | state = env.reset() 999 | if "maze_img" in args.task2: 1000 | state = state_map[tuple(state.tolist())] 1001 | 1002 | 1003 | all_rewards.append(mepisode_reward) 1004 | #print("episode reward", episode_reward) 1005 | 1006 | episode_reward = mepisode_reward = 0 1007 | del traj_buffer 1008 | del state_buffer 1009 | 1010 | traj_buffer = [] 1011 | state_buffer = [] 1012 | episode_num+=1 1013 | step=0 1014 | if "random" in args.task2: 1015 | env.close() 1016 | del env 1017 | env = gym.make(env_id) 1018 | state = env.reset() 1019 | ostate = state.copy() 1020 | if "maze_img" in args.task2: 1021 | state_map = {} 1022 | img = env.render("rgb_array") 1023 | if "view" in args.task2: 1024 | img = img_featurize.get_viewport(img, state, num_state=(env.observation_space.high[0]+1)**2) 1025 | if "cnn" not in args.task2: 1026 | # Load the pretrained model 1027 | stateim = img_featurize.get_vector(img) 1028 | stateim = np.concatenate([stateim, state]) 1029 | 1030 | state_map[tuple(state.tolist())] = stateim 1031 | state = state_map[tuple(state.tolist())] 1032 | else: 1033 | state_map[tuple(state.tolist())] = img_featurize.get_image(img) 1034 | state = state_map[tuple(state.tolist())] 1035 | 1036 | if "world" in args.task2: 1037 | state = img_featurize.get_image2(state) 1038 | 1039 | if "trap" in args.task2: 1040 | trap = [0, 0] 1041 | while trap[0] == 0 and trap[1] == 0 or trap[0] == env.observation_space.high[0] and trap[1] == \ 1042 | env.observation_space.high[0]: 1043 | trap = [random.randint(0, env.observation_space.high[0]), 1044 | random.randint(0, env.observation_space.high[0])] 1045 | print("trap ", trap) 1046 | state = np.concatenate([state, trap]) 1047 | log_value('Reward/episode reward', np.mean(all_rewards[-args.num_avg_reward:]), episode_num) 1048 | 1049 | 1050 | else: 1051 | if args.write_interval<0: 1052 | state_buffer.append(([state.copy()], episode_reward, step, action)) 1053 | 1054 | elif frame_idx % 
args.write_interval == 0 and len(traj_buffer) > 0: 1055 | #if random.random() < args.rec_rate: 1056 | l2 = model.compute_reward_loss(h_trj, traj_buffer, optimizer_reward, args.batch_size_reward, 1057 | noise=0) 1058 | rec_l2s.append(l2.data.item()) 1059 | if random.random() < args.rec_rate and frame_idx0 and frame_idx % args.td_interval == 0 and frame_idx > args.td_start and len(args.replay_buffer) > batch_size: 1070 | loss = model.compute_td_loss(optimizer_td, batch_size, episode_num) 1071 | scheduler_td.step() 1072 | td_losses.append(loss.data.item()) 1073 | 1074 | 1075 | 1076 | # if frame_idx % 2000 == 0 and frame_idx > 0: 1077 | # for param_group in optimizer_rec.param_groups: 1078 | # param_group['lr'] = param_group['lr'] / 2 1079 | # for param_group in optimizer_td.param_groups: 1080 | # param_group['lr'] = param_group['lr'] / 2 1081 | 1082 | if frame_idx % args.update_interval == 0: 1083 | model.update_target() 1084 | 1085 | if frame_idx % args.plot_interval == 0: 1086 | #print(optimizer_td.param_groups[0]['lr']) 1087 | log_value('Mem/num stored', model.dnd.get_mem_size(), int(frame_idx)) 1088 | # log_value('Mem/alpha', model.alpha.detach().item(), int(frame_idx)) 1089 | log_value('Mem/min last', model.get_pivot_lastinsert(), int(frame_idx)) 1090 | log_value('Mem/contrib', np.mean(m_contr), int(frame_idx)) 1091 | log_value('Mem/scontrib', np.mean(s_contr), int(frame_idx)) 1092 | 1093 | log_value('Loss/rec loss', np.mean(rec_losses), int(frame_idx)) 1094 | log_value('Loss/rec l1 loss', np.mean(rec_l1s), int(frame_idx)) 1095 | log_value('Loss/rec l2 loss', np.mean(rec_l2s), int(frame_idx)) 1096 | log_value('Loss/rec l3 loss', np.mean(rec_l3s), int(frame_idx)) 1097 | log_value('Loss/td loss', np.mean(td_losses), int(frame_idx)) 1098 | log_value('Episode/num episode', episode_num, int(frame_idx)) 1099 | 1100 | currw = np.mean(all_rewards[-args.num_avg_reward:]) 1101 | 1102 | log_value('Reward/frame reward', currw, int(frame_idx)) 1103 | print(f'episode 
{episode_num} step {frame_idx} avg rewards {currw} vs best {best_reward}') 1104 | 1105 | if best_reward=currw and frame_idx>num_frames-2* args.plot_interval: 1106 | best_reward = currw 1107 | if os.path.isdir(args.save_model) is False: 1108 | os.mkdir(args.save_model) 1109 | save_dir = os.path.join(args.save_model, args.task) 1110 | if os.path.isdir(save_dir) is False: 1111 | os.mkdir(save_dir) 1112 | with open(os.path.join(save_dir, f'args.jon'), 'w') as fp: 1113 | sa = copy.copy(args) 1114 | sa.device = None 1115 | sa.replay_buffer = None 1116 | json.dump(vars(sa), fp) 1117 | save_dir = os.path.join(save_dir, f"{args.model_name}.pt") 1118 | torch.save(model.state_dict(), save_dir) 1119 | print(f"save model to {save_dir}!") 1120 | 1121 | with open(save_dir + 'mem', 'wb') as output: # Overwrites any existing file. 1122 | pickle.dump(model.dnd, output) 1123 | print(f"save memory to {save_dir + 'mem'}!") 1124 | 1125 | if len(all_rewards) > args.num_avg_reward*1.1: 1126 | all_rewards = all_rewards[-args.num_avg_reward:] 1127 | td_losses = [] 1128 | rec_losses = [] 1129 | rec_l1s = [] 1130 | rec_l2s = [] 1131 | rec_l3s = [] 1132 | m_contr = [] 1133 | s_contr = [] 1134 | 1135 | if episode_num > args.max_episode > 0: 1136 | break 1137 | 1138 | 1139 | 1140 | if __name__ == "__main__": 1141 | parser = ArgumentParser(description="Training script for TEM.") 1142 | parser.add_argument("--mode", default="train", 1143 | help="train or test") 1144 | parser.add_argument("--save_model", default="./model/", 1145 | help="save model dir") 1146 | parser.add_argument("--rnoise", type=float, default=0, 1147 | help="add noise to reward") 1148 | parser.add_argument("--pnoise", type=float, default=0, 1149 | help="add noise to state") 1150 | parser.add_argument("--task", default="MiniWorld-CollectHealth-v0", 1151 | help="task name") 1152 | parser.add_argument("--task2", default="miniworld", 1153 | help="task name") 1154 | parser.add_argument("--render", type=int, default=0, 1155 | 
help="render or not") 1156 | parser.add_argument("--n_epochs", type=int, default=1000000, 1157 | help="number of epochs") 1158 | parser.add_argument("--max_episode", type=int, default=10000, 1159 | help="number of episode allowed") 1160 | parser.add_argument("--model_name", default="DTM", 1161 | help="DTM") 1162 | parser.add_argument("--lr", type=float, default=0.0005, 1163 | help="lr") 1164 | parser.add_argument("--decay", type=float, default=1, 1165 | help=" decay lr") 1166 | parser.add_argument("--clip", type=float, default=10, 1167 | help="clip gradient") 1168 | parser.add_argument("--epsilon_start", type=float, default=1.0, 1169 | help="exploration start") 1170 | parser.add_argument("--epsilon_final", type=float, default=0.01, 1171 | help="exploration final") 1172 | parser.add_argument("--epsilon_decay", type=float, default=500.0, 1173 | help="exploration decay") 1174 | parser.add_argument("--gamma", type=float, default=0.99, 1175 | help="gamma") 1176 | parser.add_argument("--batch_size", type=int, default=32, 1177 | help="batch size value model") 1178 | parser.add_argument("--batch_size_plan", type=int, default=4, 1179 | help="batch size planning model") 1180 | parser.add_argument("--batch_size_reward", type=int, default=32, 1181 | help="batch size planning model") 1182 | parser.add_argument("--reward_hidden_size", type=int, default=32, 1183 | help="batch size planning model") 1184 | parser.add_argument("--replay_size", type=int, default=1000000, 1185 | help="replay buffer size") 1186 | parser.add_argument("--qnet_size", type=int, default=128, 1187 | help="MLP hidden") 1188 | parser.add_argument("--hidden_size", type=int, default=16, 1189 | help="RNN hidden") 1190 | parser.add_argument("--mem_dim", type=int, default=5, 1191 | help="memory dimesntion") 1192 | parser.add_argument("--memory_size", type=int, default=10000, 1193 | help="memory size") 1194 | parser.add_argument("--insert_size", type=int, default=-1, 1195 | help="insert size") 1196 | 
parser.add_argument("--k", type=int, default=5, 1197 | help="num neighbor") 1198 | parser.add_argument("--k_write", type=int, default=-1, 1199 | help="num neighbor") 1200 | parser.add_argument("--mem_mode", type=int, default=0, 1201 | help="memory_mode") 1202 | parser.add_argument("--pread", type=float, default=0.7, 1203 | help="minimum reward of env") 1204 | parser.add_argument("--min_reward", type=float, default=-1e8, 1205 | help="minimum reward of env") 1206 | parser.add_argument("--write_interval", type=int, default=10, 1207 | help="interval for memory writing") 1208 | parser.add_argument("--write_lr", type=float, default=.5, 1209 | help="learning rate of writing") 1210 | parser.add_argument("--fix_alpha", type=float, default=-1, 1211 | help="fix alpha") 1212 | parser.add_argument("--bstr_rate", type=float, default=0.1, 1213 | help="learning rate of writing") 1214 | parser.add_argument("--td_interval", type=int, default=1, 1215 | help="interval for td update") 1216 | parser.add_argument("--td_start", type=int, default=-1, 1217 | help="interval for td update") 1218 | parser.add_argument("--rec_rate", type=float, default=0.1, 1219 | help="rate of reconstruction learning") 1220 | parser.add_argument("--rec_noise", type=float, default=0.1, 1221 | help="rate of reconstruction learning") 1222 | parser.add_argument("--rec_type", type=str, default="mem", 1223 | help="rec type") 1224 | parser.add_argument("--rec_period", type=int, default=1e40, 1225 | help="period of reconstruction learning") 1226 | parser.add_argument("--update_interval", type=int, default=100, 1227 | help="interval for update target Q network") 1228 | parser.add_argument("--plot_interval", type=int, default=200, 1229 | help="interval for plotting") 1230 | parser.add_argument("--num_avg_reward", type=int, default=10, 1231 | help="interval for plotting") 1232 | parser.add_argument("--num_warm_up", type=int, default=-1, 1233 | help="number of episode warming up memory") 1234 | 
parser.add_argument("--run_id", default="no_td", 1235 | help="r1,r2,r3") 1236 | 1237 | global args 1238 | args = parser.parse_args() 1239 | #import sys 1240 | #with open(f'./log_screen_mem{args.model_name}.txt', 'w') as f: 1241 | if args.k_write<0: 1242 | args.k_write=args.k 1243 | run(args) -------------------------------------------------------------------------------- /dqn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 15, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import math, random\n", 10 | "\n", 11 | "import gym\n", 12 | "import numpy as np\n", 13 | "\n", 14 | "import torch\n", 15 | "import torch.nn as nn\n", 16 | "import torch.optim as optim\n", 17 | "import torch.autograd as autograd \n", 18 | "import torch.nn.functional as F" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 16, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from IPython.display import clear_output\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "%matplotlib inline\n", 30 | "import os\n", 31 | "\n", 32 | "os.environ['KMP_DUPLICATE_LIB_OK']='True'" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "

Use CUDA

" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 17, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "USE_CUDA = torch.cuda.is_available()\n", 49 | "Variable = lambda *args, **kwargs: autograd.Variable(*args, **kwargs).cuda() if USE_CUDA else autograd.Variable(*args, **kwargs)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "

Replay Buffer

" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 18, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "from collections import deque\n", 66 | "\n", 67 | "class ReplayBuffer(object):\n", 68 | " def __init__(self, capacity):\n", 69 | " self.buffer = deque(maxlen=capacity)\n", 70 | " \n", 71 | " def push(self, state, action, reward, next_state, done):\n", 72 | " state = np.expand_dims(state, 0)\n", 73 | " next_state = np.expand_dims(next_state, 0)\n", 74 | " \n", 75 | " self.buffer.append((state, action, reward, next_state, done))\n", 76 | " \n", 77 | " def sample(self, batch_size):\n", 78 | " state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))\n", 79 | " return np.concatenate(state), action, reward, np.concatenate(next_state), done\n", 80 | " \n", 81 | " def __len__(self):\n", 82 | " return len(self.buffer)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "

Mountain Car Environment

" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 19, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "env_id = \"MountainCar-v0\"\n", 99 | "env = gym.make(env_id)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "metadata": {}, 105 | "source": [ 106 | "

Epsilon-greedy exploration

" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 20, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "epsilon_start = 1.0\n", 116 | "epsilon_final = 0.01\n", 117 | "epsilon_decay = 500\n", 118 | "\n", 119 | "epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * math.exp(-1. * frame_idx / epsilon_decay)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 21, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "text/plain": [ 130 | "[]" 131 | ] 132 | }, 133 | "execution_count": 21, 134 | "metadata": {}, 135 | "output_type": "execute_result" 136 | }, 137 | { 138 | "data": { 139 | "image/png": "\n", 140 | "text/plain": [ 141 | "
" 142 | ] 143 | }, 144 | "metadata": { 145 | "needs_background": "light" 146 | }, 147 | "output_type": "display_data" 148 | } 149 | ], 150 | "source": [ 151 | "plt.plot([epsilon_by_frame(i) for i in range(10000)])" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "

Deep Q Network

" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 22, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "class DQN(nn.Module):\n", 168 | " def __init__(self, num_inputs, num_actions):\n", 169 | " super(DQN, self).__init__()\n", 170 | " \n", 171 | " self.layers = nn.Sequential(\n", 172 | " nn.Linear(env.observation_space.shape[0], 128),\n", 173 | " nn.ReLU(),\n", 174 | " nn.Linear(128, 128),\n", 175 | " nn.ReLU(),\n", 176 | " nn.Linear(128, env.action_space.n)\n", 177 | " )\n", 178 | " \n", 179 | " def forward(self, x):\n", 180 | " return self.layers(x)\n", 181 | " \n", 182 | " def act(self, state, epsilon):\n", 183 | " if random.random() > epsilon:\n", 184 | " state = Variable(torch.FloatTensor(state).unsqueeze(0), volatile=True)\n", 185 | " q_value = self.forward(state)\n", 186 | " action = q_value.max(1)[1].data[0]\n", 187 | " else:\n", 188 | " action = random.randrange(env.action_space.n)\n", 189 | " return action" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 23, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "model = DQN(env.observation_space.shape[0], env.action_space.n)\n", 199 | "\n", 200 | "if USE_CUDA:\n", 201 | " model = model.cuda()\n", 202 | " \n", 203 | "optimizer = optim.Adam(model.parameters())\n", 204 | "\n", 205 | "replay_buffer = ReplayBuffer(1000)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "

Computing Temporal Difference Loss

" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 24, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "def compute_td_loss(batch_size):\n", 222 | " state, action, reward, next_state, done = replay_buffer.sample(batch_size)\n", 223 | "\n", 224 | " state = Variable(torch.FloatTensor(np.float32(state)))\n", 225 | " next_state = Variable(torch.FloatTensor(np.float32(next_state)), volatile=True)\n", 226 | " action = Variable(torch.LongTensor(action))\n", 227 | " reward = Variable(torch.FloatTensor(reward))\n", 228 | " done = Variable(torch.FloatTensor(done))\n", 229 | "\n", 230 | " q_values = model(state)\n", 231 | " next_q_values = model(next_state)\n", 232 | "\n", 233 | " q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)\n", 234 | " next_q_value = next_q_values.max(1)[0]\n", 235 | " expected_q_value = reward + gamma * next_q_value * (1 - done)\n", 236 | " \n", 237 | " loss = (q_value - Variable(expected_q_value.data)).pow(2).mean()\n", 238 | " \n", 239 | " optimizer.zero_grad()\n", 240 | " loss.backward()\n", 241 | " optimizer.step()\n", 242 | " \n", 243 | " return loss" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 25, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "def plot(frame_idx, rewards, losses):\n", 253 | " clear_output(True)\n", 254 | " plt.figure(figsize=(20,5))\n", 255 | " plt.subplot(131)\n", 256 | " plt.title('frame %s. reward: %s' % (frame_idx, np.mean(rewards[-10:])))\n", 257 | " plt.plot(rewards)\n", 258 | " plt.subplot(132)\n", 259 | " plt.title('loss')\n", 260 | " plt.plot(losses)\n", 261 | " plt.show()" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": {}, 267 | "source": [ 268 | "

Training

" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 26, 274 | "metadata": {}, 275 | "outputs": [ 276 | { 277 | "data": { 278 | "image/png": "\n", 279 | "text/plain": [ 280 | "
" 281 | ] 282 | }, 283 | "metadata": { 284 | "needs_background": "light" 285 | }, 286 | "output_type": "display_data" 287 | } 288 | ], 289 | "source": [ 290 | "num_frames = 100000\n", 291 | "batch_size = 32\n", 292 | "gamma = 0.99\n", 293 | "\n", 294 | "losses = []\n", 295 | "all_rewards = []\n", 296 | "episode_reward = 0\n", 297 | "\n", 298 | "state = env.reset()\n", 299 | "for frame_idx in range(1, num_frames + 1):\n", 300 | " epsilon = epsilon_by_frame(frame_idx)\n", 301 | " action = model.act(state, epsilon)\n", 302 | " try:\n", 303 | " action = action.item()\n", 304 | " except:\n", 305 | " pass\n", 306 | " next_state, reward, done, _ = env.step(action)\n", 307 | " replay_buffer.push(state, action, reward, next_state, done)\n", 308 | " state = next_state\n", 309 | " episode_reward += reward\n", 310 | " \n", 311 | " if done:\n", 312 | " state= env.reset()\n", 313 | " all_rewards.append(episode_reward)\n", 314 | " episode_reward = 0\n", 315 | " \n", 316 | " if len(replay_buffer) > batch_size:\n", 317 | " loss = compute_td_loss(batch_size)\n", 318 | " losses.append(loss.data.item())\n", 319 | " \n", 320 | " if frame_idx % 200 == 0:\n", 321 | " plot(frame_idx, all_rewards, losses)" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "


" 329 | ] 330 | }, 331 | { 332 | "cell_type": "markdown", 333 | "metadata": {}, 334 | "source": [ 335 | "

Atari Environment

" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 12, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "from common.wrappers import make_atari, wrap_deepmind, wrap_pytorch" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 13, 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "env_id = \"PongNoFrameskip-v4\"\n", 354 | "env = make_atari(env_id)\n", 355 | "env = wrap_deepmind(env)\n", 356 | "env = wrap_pytorch(env)" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 14, 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "class CnnDQN(nn.Module):\n", 366 | " def __init__(self, input_shape, num_actions):\n", 367 | " super(CnnDQN, self).__init__()\n", 368 | " \n", 369 | " self.input_shape = input_shape\n", 370 | " self.num_actions = num_actions\n", 371 | " \n", 372 | " self.features = nn.Sequential(\n", 373 | " nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),\n", 374 | " nn.ReLU(),\n", 375 | " nn.Conv2d(32, 64, kernel_size=4, stride=2),\n", 376 | " nn.ReLU(),\n", 377 | " nn.Conv2d(64, 64, kernel_size=3, stride=1),\n", 378 | " nn.ReLU()\n", 379 | " )\n", 380 | " \n", 381 | " self.fc = nn.Sequential(\n", 382 | " nn.Linear(self.feature_size(), 512),\n", 383 | " nn.ReLU(),\n", 384 | " nn.Linear(512, self.num_actions)\n", 385 | " )\n", 386 | " \n", 387 | " def forward(self, x):\n", 388 | " x = self.features(x)\n", 389 | " x = x.view(x.size(0), -1)\n", 390 | " x = self.fc(x)\n", 391 | " return x\n", 392 | " \n", 393 | " def feature_size(self):\n", 394 | " return self.features(autograd.Variable(torch.zeros(1, *self.input_shape))).view(1, -1).size(1)\n", 395 | " \n", 396 | " def act(self, state, epsilon):\n", 397 | " if random.random() > epsilon:\n", 398 | " state = Variable(torch.FloatTensor(np.float32(state)).unsqueeze(0), volatile=True)\n", 399 | " q_value = self.forward(state)\n", 400 | " action = q_value.max(1)[1].data[0]\n", 401 | " 
else:\n", 402 | " action = random.randrange(env.action_space.n)\n", 403 | " return action" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 15, 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "model = CnnDQN(env.observation_space.shape, env.action_space.n)\n", 413 | "\n", 414 | "if USE_CUDA:\n", 415 | " model = model.cuda()\n", 416 | " \n", 417 | "optimizer = optim.Adam(model.parameters(), lr=0.00001)\n", 418 | "\n", 419 | "replay_initial = 10000\n", 420 | "replay_buffer = ReplayBuffer(100000)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 16, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "epsilon_start = 1.0\n", 430 | "epsilon_final = 0.01\n", 431 | "epsilon_decay = 30000\n", 432 | "\n", 433 | "epsilon_by_frame = lambda frame_idx: epsilon_final + (epsilon_start - epsilon_final) * math.exp(-1. * frame_idx / epsilon_decay)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 17, 439 | "metadata": {}, 440 | "outputs": [ 441 | { 442 | "data": { 443 | "text/plain": [ 444 | "[]" 445 | ] 446 | }, 447 | "execution_count": 17, 448 | "metadata": {}, 449 | "output_type": "execute_result" 450 | }, 451 | { 452 | "data": { 453 | "image/png": "", 454 | "text/plain": [ 455 | "
" 456 | ] 457 | }, 458 | "metadata": { 459 | "needs_background": "light" 460 | }, 461 | "output_type": "display_data" 462 | } 463 | ], 464 | "source": [ 465 | "plt.plot([epsilon_by_frame(i) for i in range(1000000)])" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 21, 471 | "metadata": {}, 472 | "outputs": [ 473 | { 474 | "name": "stderr", 475 | "output_type": "stream", 476 | "text": [ 477 | ":2: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.\n", 478 | " Variable = lambda *args, **kwargs: autograd.Variable(*args, **kwargs).cuda() if USE_CUDA else autograd.Variable(*args, **kwargs)\n" 479 | ] 480 | }, 481 | { 482 | "ename": "KeyboardInterrupt", 483 | "evalue": "", 484 | "output_type": "error", 485 | "traceback": [ 486 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 487 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 488 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplay_buffer\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mreplay_initial\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_td_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mlosses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 489 | 
"\u001b[0;32m\u001b[0m in \u001b[0;36mcompute_td_loss\u001b[0;34m(batch_size)\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 20\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 21\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 490 | "\u001b[0;32m~/anaconda3/envs/pytorch-gat/lib/python3.8/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 219\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 220\u001b[0m create_graph=create_graph)\n\u001b[0;32m--> 221\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 222\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 491 
| "\u001b[0;32m~/anaconda3/envs/pytorch-gat/lib/python3.8/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 130\u001b[0;31m Variable._execution_engine.run_backward(\n\u001b[0m\u001b[1;32m 131\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 132\u001b[0m allow_unreachable=True) # allow_unreachable flag\n", 492 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 493 | ] 494 | } 495 | ], 496 | "source": [ 497 | "num_frames = 1400000\n", 498 | "batch_size = 32\n", 499 | "gamma = 0.99\n", 500 | "\n", 501 | "losses = []\n", 502 | "all_rewards = []\n", 503 | "episode_reward = 0\n", 504 | "\n", 505 | "state= env.reset()\n", 506 | "for frame_idx in range(1, num_frames + 1):\n", 507 | " epsilon = epsilon_by_frame(frame_idx)\n", 508 | " action = model.act(state, epsilon)\n", 509 | " \n", 510 | " next_state, reward, done, _ = env.step(action)\n", 511 | " replay_buffer.push(state, action, reward, next_state, done)\n", 512 | " \n", 513 | " state = next_state\n", 514 | " episode_reward += reward\n", 515 | " \n", 516 | " if done:\n", 517 | " state = env.reset()\n", 518 | " all_rewards.append(episode_reward)\n", 519 | " episode_reward = 0\n", 520 | " \n", 521 | " if len(replay_buffer) > replay_initial:\n", 522 | " loss = compute_td_loss(batch_size)\n", 523 | " losses.append(loss.item())\n", 524 | " \n", 525 | " if frame_idx % 10000 == 0:\n", 526 | " plot(frame_idx, all_rewards, losses)" 527 | ] 528 | }, 529 | { 530 | 
"cell_type": "code", 531 | "execution_count": null, 532 | "metadata": {}, 533 | "outputs": [], 534 | "source": [] 535 | } 536 | ], 537 | "metadata": { 538 | "kernelspec": { 539 | "display_name": "pytorch-gat", 540 | "language": "python", 541 | "name": "pytorch-gat" 542 | }, 543 | "language_info": { 544 | "codemirror_mode": { 545 | "name": "ipython", 546 | "version": 3 547 | }, 548 | "file_extension": ".py", 549 | "mimetype": "text/x-python", 550 | "name": "python", 551 | "nbconvert_exporter": "python", 552 | "pygments_lexer": "ipython3", 553 | "version": "3.8.5" 554 | } 555 | }, 556 | "nbformat": 4, 557 | "nbformat_minor": 2 558 | } 559 | --------------------------------------------------------------------------------