├── .gitignore ├── DDPG ├── FiGAR.py ├── TempoRL.py ├── __init__.py ├── utils.py └── vanilla.py ├── LICENSE ├── README.md ├── TempoRL_Appendix.pdf ├── experiments └── .gitkeep ├── grid_envs.py ├── mountain_car.py ├── plot_atari_results.ipynb ├── plot_ddpg.ipynb ├── plot_featurized_results.ipynb ├── plot_tabular_results.ipynb ├── requirements.txt ├── run_atari_experiments.py ├── run_ddpg_experiments.py ├── run_featurized_experiments.py ├── run_tabular_experiments.py ├── tabular_requirements.txt └── utils ├── __init__.py ├── config ├── data_handling.py ├── env_wrappers.py ├── experiments.py └── plotting.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pdf 2 | experiments/* 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | *.idea 10 | *.ipynb 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | -------------------------------------------------------------------------------- /DDPG/FiGAR.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adaptation of the vanilla DDPG code to allow for FiGAR modification as presented 3 | in the FiGAR paper https://arxiv.org/pdf/1702.06054.pdf 4 | """ 5 | 6 | 7 | import copy 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | # We have to modify the Actor to also generate the repetition output 15 | class Actor(nn.Module): 16 | def __init__(self, state_dim, action_dim, max_action, rep_dim): 17 | super(Actor, self).__init__() 18 | 19 | self.l1 = nn.Linear(state_dim, 400) 20 | self.l2 = nn.Linear(400, 300) 21 | self.l3 = nn.Linear(300, action_dim) 22 | 23 | self.max_action = max_action 24 | 25 | self.l5 = nn.Linear(400, 300) 26 | self.l6 = nn.Linear(300, rep_dim) 27 | 28 | def forward(self, state): 29 | # As suggested by the FiGAR authors, the input layer is shared 30 | shared = F.relu(self.l1(state)) 31 | a = F.relu(self.l2(shared)) 32 | 33 | r = F.relu(self.l5(shared)) 34 | return self.max_action * torch.tanh(self.l3(a)), F.log_softmax(self.l6(r), dim=1) 35 | 36 | 37 | # The Critic has to be modified to be able to accept the additional repetition output of the actor 38 | class Critic(nn.Module): 39 | def __init__(self, state_dim, action_dim, repetition_dim): 40 | super(Critic, self).__init__() 41 | 42 | self.l1 = nn.Linear(state_dim + action_dim + repetition_dim, 400) 43 | self.l2 = nn.Linear(400, 300) 44 | self.l3 = nn.Linear(300, 1) 45 | 46 | def forward(self, state, action, repetition): 47 | q = F.relu(self.l1(torch.cat([state, action, repetition], 1))) 48 | q = F.relu(self.l2(q)) 49 | return self.l3(q) 50 | 51 | 52 | class DDPG(object): 53 | def __init__(self, state_dim, action_dim, max_action, repetition_dim, discount=0.99, tau=0.005): 54 | self.actor = Actor(state_dim, action_dim, max_action, repetition_dim).to(device) 55 | self.actor_target = copy.deepcopy(self.actor) 56 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters()) 57 | 58 | self.critic = Critic(state_dim, action_dim, repetition_dim).to(device) 59 | self.critic_target = copy.deepcopy(self.critic) 60 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters()) 61 | 62 | self.discount = discount 63 | self.tau = tau 64 | 65 | def select_action(self, state): 66 | # The select action method has to be adjusted to also sample from the repetition distribution 67 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 68 | action, repetition_prob = self.actor(state) 69 | repetition_dist = torch.distributions.Categorical(repetition_prob) 70 | repetition = repetition_dist.sample() 71 | return (action.cpu().data.numpy().flatten(), 72 | repetition.cpu().data.numpy().flatten(), 73 | 
repetition_prob.cpu().data.numpy().flatten()) 74 | 75 | def train(self, replay_buffer, batch_size=100): 76 | # The train method has to be adapted to take changes to Actor and Critic into account 77 | # Sample replay buffer 78 | state, action, repetition, next_state, reward, not_done = replay_buffer.sample(batch_size) 79 | 80 | # Compute the target Q value 81 | target_Q = self.critic_target(next_state, *self.actor_target(next_state)) 82 | target_Q = reward + (not_done * self.discount * target_Q).detach() 83 | 84 | # Get current Q estimate 85 | current_Q = self.critic(state, action, repetition) 86 | 87 | # Compute critic loss 88 | critic_loss = F.mse_loss(current_Q, target_Q) 89 | 90 | # Optimize the critic 91 | self.critic_optimizer.zero_grad() 92 | critic_loss.backward() 93 | self.critic_optimizer.step() 94 | 95 | # Compute actor loss 96 | actor_loss = -self.critic(state, *self.actor(state)).mean() 97 | 98 | # Optimize the actor 99 | self.actor_optimizer.zero_grad() 100 | actor_loss.backward() 101 | self.actor_optimizer.step() 102 | 103 | # Update the frozen target models 104 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 105 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 106 | 107 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 108 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 109 | 110 | def save(self, filename): 111 | torch.save(self.critic.state_dict(), filename + "_critic") 112 | torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer") 113 | 114 | torch.save(self.actor.state_dict(), filename + "_actor") 115 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer") 116 | 117 | def load(self, filename): 118 | self.critic.load_state_dict(torch.load(filename + "_critic")) 119 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer")) 120 | self.critic_target = copy.deepcopy(self.critic) 121 | 122 | self.actor.load_state_dict(torch.load(filename + "_actor")) 123 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer")) 124 | self.actor_target = copy.deepcopy(self.actor) 125 | -------------------------------------------------------------------------------- /DDPG/TempoRL.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adaptation of the vanilla DDPG code to allow for TempoRL modification. 3 | """ 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | # we use exactly the same Actor and Critic networks and training methods for both as in the vanilla implementation 10 | from DDPG.vanilla import DDPG as VanillaDDPG 11 | 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | 15 | class Q(nn.Module): 16 | """ 17 | Simple fully connected Q function. Also used for skip-Q when concatenating behaviour action and state together. 18 | Used for simpler environments such as mountain-car or lunar-lander. 
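    A rough usage sketch (the dimensions below are placeholders, not values taken from this repo):
        q = Q(state_dim=3, action_dim=1, skip_dim=10)
        skip_qs = q(torch.cat([state, action], 1))  # one Q-value per skip length, shape [batch_size, 10]
    This mirrors how DDPG.select_skip below feeds the concatenated state and behaviour action into the skip-Q.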
19 | """ 20 | 21 | def __init__(self, state_dim, action_dim, skip_dim, non_linearity=F.relu): 22 | super(Q, self).__init__() 23 | # We follow the architecture of the Actor and Critic networks in terms of depth and hidden units 24 | self.fc1 = nn.Linear(state_dim + action_dim, 400) 25 | self.fc2 = nn.Linear(400, 300) 26 | self.fc3 = nn.Linear(300, skip_dim) 27 | self._non_linearity = non_linearity 28 | 29 | def forward(self, x): 30 | x = self._non_linearity(self.fc1(x)) 31 | x = self._non_linearity(self.fc2(x)) 32 | return self.fc3(x) 33 | 34 | 35 | class DDPG(VanillaDDPG): 36 | def __init__(self, state_dim, action_dim, max_action, skip_dim, discount=0.99, tau=0.005): 37 | # We can fully reuse the vanilla DDPG and simply stack TempoRL on top 38 | super(DDPG, self).__init__(state_dim, action_dim, max_action, discount, tau) 39 | 40 | # Create Skip Q network 41 | self.skip_Q = Q(state_dim, action_dim, skip_dim) 42 | self.skip_optimizer = torch.optim.Adam(self.skip_Q.parameters()) 43 | 44 | def select_skip(self, state, action): 45 | """ 46 | Select the skip action. 47 | Has to be called after select_action 48 | """ 49 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 50 | action = torch.FloatTensor(action.reshape(1, -1)).to(device) 51 | return self.skip_Q(torch.cat([state, action], 1)).cpu().data.numpy().flatten() 52 | 53 | def train_skip(self, replay_buffer, batch_size=100): 54 | """ 55 | Train the skip network 56 | """ 57 | # Sample replay buffer 58 | state, action, skip, next_state, reward, not_done = replay_buffer.sample(batch_size) 59 | 60 | # Compute the target Q value 61 | target_Q = self.critic_target(next_state, self.actor_target(next_state)) 62 | target_Q = reward + (not_done * np.power(self.discount, skip + 1) * target_Q).detach() 63 | 64 | # Get current Q estimate 65 | current_Q = self.skip_Q(torch.cat([state, action], 1)).gather(1, skip.long()) 66 | 67 | # Compute critic loss 68 | critic_loss = F.mse_loss(current_Q, target_Q) 69 | 70 | # Optimize the critic 71 | self.skip_optimizer.zero_grad() 72 | critic_loss.backward() 73 | self.skip_optimizer.step() 74 | 75 | def save(self, filename): 76 | super().save(filename) 77 | 78 | torch.save(self.skip_Q.state_dict(), filename + "_skip") 79 | torch.save(self.skip_optimizer.state_dict(), filename + "_skip_optimizer") 80 | 81 | def load(self, filename): 82 | super().load(filename) 83 | 84 | self.skip_Q.load_state_dict(torch.load(filename + "_skip")) 85 | self.skip_optimizer.load_state_dict(torch.load(filename + "_skip_optimizer")) 86 | -------------------------------------------------------------------------------- /DDPG/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/TempoRL/b8f4b0648489dbcc4895374df56a0d051379f2a8/DDPG/__init__.py -------------------------------------------------------------------------------- /DDPG/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | The ReplayBuffer was originally implemented by Scott Fujimoto https://github.com/sfujim/TD3/blob/master/utils.py 3 | 4 | We added a second replay buffer that can take repetitions/skips into account and created a version of the 5 | Pendulum-v0 environment that has a different rendering style to display if actions were reactive or proactive 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class ReplayBuffer(object): 13 | def __init__(self, state_dim, action_dim, max_size=int(1e6)): 14 | self.max_size = max_size 15 
| self.ptr = 0 16 | self.size = 0 17 | 18 | self.state = np.zeros((max_size, state_dim)) 19 | self.action = np.zeros((max_size, action_dim)) 20 | self.next_state = np.zeros((max_size, state_dim)) 21 | self.reward = np.zeros((max_size, 1)) 22 | self.not_done = np.zeros((max_size, 1)) 23 | 24 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | def add(self, state, action, next_state, reward, done): 27 | self.state[self.ptr] = state 28 | self.action[self.ptr] = action 29 | self.next_state[self.ptr] = next_state 30 | self.reward[self.ptr] = reward 31 | self.not_done[self.ptr] = 1. - done 32 | 33 | self.ptr = (self.ptr + 1) % self.max_size 34 | self.size = min(self.size + 1, self.max_size) 35 | 36 | def sample(self, batch_size): 37 | ind = np.random.randint(0, self.size, size=batch_size) 38 | 39 | return ( 40 | torch.FloatTensor(self.state[ind]).to(self.device), 41 | torch.FloatTensor(self.action[ind]).to(self.device), 42 | torch.FloatTensor(self.next_state[ind]).to(self.device), 43 | torch.FloatTensor(self.reward[ind]).to(self.device), 44 | torch.FloatTensor(self.not_done[ind]).to(self.device) 45 | ) 46 | 47 | 48 | class FiGARReplayBuffer(object): 49 | def __init__(self, state_dim, action_dim, rep_dim, max_size=int(1e6)): 50 | self.max_size = max_size 51 | self.ptr = 0 52 | self.size = 0 53 | 54 | self.state = np.zeros((max_size, state_dim)) 55 | self.action = np.zeros((max_size, action_dim)) 56 | self.rep = np.zeros((max_size, rep_dim)) 57 | self.next_state = np.zeros((max_size, state_dim)) 58 | self.reward = np.zeros((max_size, 1)) 59 | self.not_done = np.zeros((max_size, 1)) 60 | 61 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 62 | 63 | def add(self, state, action, rep, next_state, reward, done): 64 | self.state[self.ptr] = state 65 | self.action[self.ptr] = action 66 | self.rep[self.ptr] = rep 67 | self.next_state[self.ptr] = next_state 68 | self.reward[self.ptr] = reward 69 | self.not_done[self.ptr] = 1. - done 70 | 71 | self.ptr = (self.ptr + 1) % self.max_size 72 | self.size = min(self.size + 1, self.max_size) 73 | 74 | def sample(self, batch_size): 75 | ind = np.random.randint(0, self.size, size=batch_size) 76 | 77 | return ( 78 | torch.FloatTensor(self.state[ind]).to(self.device), 79 | torch.FloatTensor(self.action[ind]).to(self.device), 80 | torch.FloatTensor(self.rep[ind]).to(self.device), 81 | torch.FloatTensor(self.next_state[ind]).to(self.device), 82 | torch.FloatTensor(self.reward[ind]).to(self.device), 83 | torch.FloatTensor(self.not_done[ind]).to(self.device) 84 | ) 85 | 86 | 87 | import gym 88 | 89 | 90 | class Render(gym.Wrapper): 91 | """Render env by calling its render method. 92 | 93 | Args: 94 | env (gym.Env): Env to wrap. 95 | **kwargs: Keyword arguments passed to the render method. 
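        episode_modulo (int): Only render every episode_modulo-th episode (stored as render_every_nth_episode below; defaults to 1, i.e. render every episode).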
96 | """ 97 | 98 | def __init__(self, env, episode_modulo=1, **kwargs): 99 | super().__init__(env) 100 | self._kwargs = kwargs 101 | self.render_every_nth_episode = episode_modulo 102 | self._episode_counter = -1 103 | 104 | def reset(self, **kwargs): 105 | self._episode_counter += 1 106 | ret = self.env.reset(**kwargs) 107 | if self._episode_counter % self.render_every_nth_episode == 0: 108 | self.env.render(**self._kwargs) 109 | return ret 110 | 111 | def step(self, action): 112 | ret = self.env.step(action) 113 | if self._episode_counter % self.render_every_nth_episode == 0: 114 | self.env.render(**self._kwargs) 115 | return ret 116 | 117 | def close(self): 118 | self.env.close() 119 | 120 | 121 | from gym.envs.classic_control import PendulumEnv 122 | from os import path 123 | import time 124 | import inspect 125 | 126 | 127 | class MyPendulum(PendulumEnv): 128 | 129 | def __init__(self, **kwargs): 130 | self.dec = False 131 | super().__init__(**kwargs) 132 | 133 | def is_decision_point(self): 134 | return self.dec 135 | 136 | def set_decision_point(self, b): 137 | self.dec = b 138 | 139 | def reset(self): 140 | self.dec = True 141 | return super().reset() 142 | 143 | def render(self, mode='human'): 144 | if self.viewer is None: 145 | from gym.envs.classic_control import rendering 146 | self.viewer = rendering.Viewer(500, 500) 147 | self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2) 148 | self.rod = rendering.make_capsule(1, .2) 149 | if self.is_decision_point(): 150 | self.rod.set_color(.3, .3, .8) 151 | time.sleep(0.5) 152 | self.set_decision_point(False) 153 | else: 154 | self.rod.set_color(.8, .3, .3) 155 | self.pole_transform = rendering.Transform() 156 | self.rod.add_attr(self.pole_transform) 157 | self.viewer.add_geom(self.rod) 158 | axle = rendering.make_circle(.05) 159 | axle.set_color(0, 0, 0) 160 | self.viewer.add_geom(axle) 161 | fname = path.join(path.dirname(inspect.getfile(PendulumEnv)), "assets/clockwise.png") 162 | self.img = rendering.Image(fname, 1., 1.) 
163 | self.imgtrans = rendering.Transform() 164 | self.img.add_attr(self.imgtrans) 165 | 166 | self.viewer.add_onetime(self.img) 167 | 168 | if self.is_decision_point(): 169 | self.rod.set_color(.3, .3, .8) 170 | # time.sleep(0.5) 171 | self.set_decision_point(False) 172 | else: 173 | self.rod.set_color(.8, .3, .3) 174 | 175 | self.pole_transform.set_rotation(self.state[0] + np.pi / 2) 176 | if self.last_u: 177 | self.imgtrans.scale = (-self.last_u / 2, np.abs(self.last_u) / 2) 178 | 179 | return self.viewer.render(return_rgb_array=mode == 'rgb_array') 180 | 181 | def close(self): 182 | if self.viewer: 183 | self.viewer.close() 184 | self.viewer = None 185 | 186 | 187 | from gym.envs.registration import register 188 | 189 | register( 190 | id='PendulumDecs-v0', 191 | entry_point='DDPG.utils:MyPendulum', 192 | max_episode_steps=200, 193 | ) 194 | -------------------------------------------------------------------------------- /DDPG/vanilla.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy of code by Scott Fujimoto released under the MIT license 3 | https://github.com/sfujim/TD3 4 | """ 5 | import copy 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | # Re-tuned version of Deep Deterministic Policy Gradients (DDPG) 15 | # Paper: https://arxiv.org/abs/1509.02971 16 | 17 | 18 | class Actor(nn.Module): 19 | def __init__(self, state_dim, action_dim, max_action): 20 | super(Actor, self).__init__() 21 | 22 | self.l1 = nn.Linear(state_dim, 400) 23 | self.l2 = nn.Linear(400, 300) 24 | self.l3 = nn.Linear(300, action_dim) 25 | 26 | self.max_action = max_action 27 | 28 | def forward(self, state): 29 | a = F.relu(self.l1(state)) 30 | a = F.relu(self.l2(a)) 31 | return self.max_action * torch.tanh(self.l3(a)) 32 | 33 | 34 | class Critic(nn.Module): 35 | def __init__(self, state_dim, action_dim): 36 | super(Critic, self).__init__() 37 | 38 | self.l1 = nn.Linear(state_dim + action_dim, 400) 39 | self.l2 = nn.Linear(400, 300) 40 | self.l3 = nn.Linear(300, 1) 41 | 42 | def forward(self, state, action): 43 | q = F.relu(self.l1(torch.cat([state, action], 1))) 44 | q = F.relu(self.l2(q)) 45 | return self.l3(q) 46 | 47 | 48 | class DDPG(object): 49 | def __init__(self, state_dim, action_dim, max_action, discount=0.99, tau=0.005): 50 | self.actor = Actor(state_dim, action_dim, max_action).to(device) 51 | self.actor_target = copy.deepcopy(self.actor) 52 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters()) 53 | 54 | self.critic = Critic(state_dim, action_dim).to(device) 55 | self.critic_target = copy.deepcopy(self.critic) 56 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters()) 57 | 58 | self.discount = discount 59 | self.tau = tau 60 | 61 | def select_action(self, state): 62 | state = torch.FloatTensor(state.reshape(1, -1)).to(device) 63 | return self.actor(state).cpu().data.numpy().flatten() 64 | 65 | def train(self, replay_buffer, batch_size=100): 66 | # Sample replay buffer 67 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size) 68 | 69 | # Compute the target Q value 70 | target_Q = self.critic_target(next_state, self.actor_target(next_state)) 71 | target_Q = reward + (not_done * self.discount * target_Q).detach() 72 | 73 | # Get current Q estimate 74 | current_Q = self.critic(state, action) 75 | 76 | # Compute critic loss 77 | critic_loss = 
F.mse_loss(current_Q, target_Q) 78 | 79 | # Optimize the critic 80 | self.critic_optimizer.zero_grad() 81 | critic_loss.backward() 82 | self.critic_optimizer.step() 83 | 84 | # Compute actor loss 85 | actor_loss = -self.critic(state, self.actor(state)).mean() 86 | 87 | # Optimize the actor 88 | self.actor_optimizer.zero_grad() 89 | actor_loss.backward() 90 | self.actor_optimizer.step() 91 | 92 | # Update the frozen target models 93 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 94 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 95 | 96 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 97 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 98 | 99 | def save(self, filename): 100 | torch.save(self.critic.state_dict(), filename + "_critic") 101 | torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer") 102 | 103 | torch.save(self.actor.state_dict(), filename + "_actor") 104 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer") 105 | 106 | def load(self, filename): 107 | self.critic.load_state_dict(torch.load(filename + "_critic")) 108 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer")) 109 | self.critic_target = copy.deepcopy(self.critic) 110 | 111 | self.actor.load_state_dict(torch.load(filename + "_actor")) 112 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer")) 113 | self.actor_target = copy.deepcopy(self.actor) 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TempoRL 2 | 3 | This repository contains the code for the ICML'21 paper "[TempoRL: Learning When to Act](https://ml.informatik.uni-freiburg.de/papers/21-ICML-TempoRL.pdf)". 
4 | 
5 | If you use TempoRL in your research or application, please cite us:
6 | 
7 | ```bibtex
8 | @inproceedings{biedenkapp-icml21,
9 |   author    = {André Biedenkapp and Raghu Rajan and Frank Hutter and Marius Lindauer},
10 |   title     = {{T}empo{RL}: Learning When to Act},
11 |   booktitle = {Proceedings of the 38th International Conference on Machine Learning (ICML 2021)},
12 |   year      = {2021},
13 |   month     = jul,
14 | }
15 | ```
16 | 
17 | ## Appendix
18 | The appendix PDF has been uploaded to this repository and can be accessed [here](TempoRL_Appendix.pdf).
19 | 
20 | ## Setup
21 | This code was developed with Python 3.6.12 and torch 1.4.0.
22 | If you have the correct Python version, install the dependencies via
23 | ```bash
24 | pip install -r requirements.txt
25 | ```
26 | 
27 | If you only want to run quick experiments with the tabular agents, you can install the minimal requirements from `tabular_requirements.txt` via
28 | ```bash
29 | pip install -r tabular_requirements.txt
30 | ```
31 | 
32 | To use the provided Jupyter notebooks, you additionally have to install Jupyter:
33 | ```bash
34 | pip install jupyter
35 | ```
36 | 
37 | ## How to train tabular agents
38 | To train an agent on any of the environments listed below, run
39 | ```bash
40 | python run_tabular_experiments.py -e 10000 --agent Agent --env env_name --eval-eps 500
41 | ```
42 | Replace `Agent` with `q` for vanilla Q-learning and `sq` for our method.
43 | 
44 | ## Envs
45 | Currently, three simple environments are available.
46 | By default, all environments give a reward of 1 for reaching the goal (X).
47 | The agent starts in state (S) and can traverse open fields (o).
48 | When falling into "lava" (.), the agent receives a reward of -1.
49 | No other transitions generate rewards. (When rendering environments, the agent is marked with *.)
50 | An agent can use at most 100 steps to reach the goal.
51 | 
52 | Variants of the environments listed below can run without goal rewards (env_name ends in _ng)
53 | or reduce the goal reward by the number of steps taken (env_name ends in _perc).
54 | * lava (Cliff)
55 | ```console
56 | S o . . . . . . o X
57 | o o . . . . . . o o
58 | o o . . . . . . o o
59 | o o o o o o o o o o
60 | o o o o o o o o o o
61 | o o o o o o o o o o
62 | ```
63 | 
64 | * lava2 (Bridge)
65 | ```console
66 | S o . . . . . . o X
67 | o o . . . . . . o o
68 | o o o o o o o o o o
69 | o o o o o o o o o o
70 | o o . . . . . . o o
71 | o o . . . . . . o o
72 | ```
73 | 
74 | * lava3 (ZigZag)
75 | ```console
76 | S o . . o o o o o o
77 | o o . . o o o o o o
78 | o o . . o o . . o o
79 | o o . . o o . . o o
80 | o o o o o o . . o o
81 | o o o o o o . . o X
82 | ```
83 | 
84 | ## How to train deep agents
85 | To train an agent on featurized environments, run e.g.
86 | ```bash
87 | python run_featurized_experiments.py -e 10000 -t 1000000 --eval-after-n-steps 200 -s 1 --agent tdqn --skip-net-max-skips 10 --out-dir . --sparse
88 | ```
89 | Replace tdqn (our agent with a shared network architecture) with dqn or dar to run the respective baseline agents.
90 | 
91 | To train a DDPG agent, run e.g.
92 | ```bash
93 | python run_ddpg_experiments.py --policy TempoRLDDPG --env Pendulum-v0 --start_timesteps 1000 --max_timesteps 30000 --eval_freq 250 --max-skip 16 --save_model --out-dir . --seed 1
94 | ```
95 | Replace TempoRLDDPG with FiGARDDPG or DDPG to run the baseline agents.
96 | 
97 | To train an agent on Atari environments, run e.g.
98 | ```bash 99 | run_atari_experiments.py --env freeway --env-max-steps 10000 --agent tdqn --out-dir experiments/atari_new/freeway/tdqn_3 --episodes 20000 --training-steps 2500000 --eval-after-n-steps 10000 --seed 12345 --84x84 --eval-n-episodes 3 100 | ``` 101 | replace tdqn (our agent with shared network architecture) with dqn or dar to run the respective baseline agents. 102 | 103 | ## Experiment data 104 | ##### Note: This data is required to run the plotting jupyter notebooks 105 | We provide all learning curve data, final policy network weights as well as commands to generate that data at: 106 | https://figshare.com/s/d9264dc125ca8ba8efd8 107 | 108 | (Download this data and move it into the experiments folder) 109 | -------------------------------------------------------------------------------- /TempoRL_Appendix.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/TempoRL/b8f4b0648489dbcc4895374df56a0d051379f2a8/TempoRL_Appendix.pdf -------------------------------------------------------------------------------- /experiments/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/TempoRL/b8f4b0648489dbcc4895374df56a0d051379f2a8/experiments/.gitkeep -------------------------------------------------------------------------------- /grid_envs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | from io import StringIO 4 | from typing import Tuple 5 | from gym.envs.toy_text.discrete import DiscreteEnv 6 | import time 7 | from scipy.spatial.distance import cityblock 8 | 9 | LEFT = 0 10 | UP = 1 11 | RIGHT = 2 12 | DOWN = 3 13 | 14 | 15 | class Specs: 16 | def __init__(self, max_steps): 17 | self.max_episode_steps = max_steps 18 | 19 | 20 | class GridCore(DiscreteEnv): 21 | metadata = {'render.modes': ['human', 'ansi']} 22 | 23 | def __init__(self, shape: Tuple[int] = (5, 10), start: Tuple[int] = (0, 0), 24 | goal: Tuple[int] = (0, 9), max_steps: int = 100, 25 | percentage_reward: bool = False, no_goal_rew: bool = False, 26 | dense_reward: bool = False, numpy_state: bool = True, 27 | xy_state: bool = False): 28 | try: 29 | self.shape = self._shape 30 | except AttributeError: 31 | self.shape = shape 32 | self.nS = np.prod(self.shape, dtype=int) # type: int 33 | self.nA = 4 34 | self.start = start 35 | self.goal = goal 36 | self.max_steps = max_steps 37 | self._steps = 0 38 | self._pr = percentage_reward 39 | self._no_goal_rew = no_goal_rew 40 | self.total_steps = 0 41 | self._dr = dense_reward 42 | self.spec = Specs(max_steps) 43 | self._nps = numpy_state 44 | self.xy = xy_state 45 | 46 | P = self._init_transition_probability() 47 | 48 | # We always start in state (3, 0) 49 | if start is not None: 50 | isd = np.zeros(self.nS) 51 | isd[np.ravel_multi_index(start, self.shape)] = 1.0 52 | else: 53 | isd = np.ones(self.nS) 54 | isd[np.ravel_multi_index(goal, self.shape)] = 0.0 55 | for pit in self._pits: 56 | isd[pit] = 0.0 57 | isd *= 1 / (self.nS - len(self._pits) - 1) 58 | 59 | super(GridCore, self).__init__(self.nS, self.nA, P, isd) 60 | 61 | def step(self, a): 62 | self._steps += 1 63 | s, r, d, i = super(GridCore, self).step(a) 64 | if self.xy: 65 | self.s = np.ravel_multi_index(self.s, self.shape) 66 | if self._steps >= self.max_steps: 67 | d = True 68 | i['early'] = True 69 | i['needs_reset'] = True 70 | self.total_steps += 1 71 | if self._nps: 72 | return 
np.array([s], dtype=np.float32), r, d, i 73 | return s, r, d, i 74 | 75 | def reset(self): 76 | self._steps = 0 77 | self.s = np.random.choice(range(len(self.isd)), p=self.isd) 78 | self.lastaction = None 79 | if self.xy: 80 | s = np.unravel_index(self.s, self.shape) 81 | if self._nps: 82 | return np.array([s], dtype=np.float32) 83 | else: 84 | return s 85 | if self._nps: 86 | return np.array([self.s], dtype=np.float32) 87 | return self.s 88 | 89 | def _init_transition_probability(self): 90 | raise NotImplementedError 91 | 92 | def _check_bounds(self, coord): 93 | coord[0] = min(coord[0], self.shape[0] - 1) 94 | coord[0] = max(coord[0], 0) 95 | coord[1] = min(coord[1], self.shape[1] - 1) 96 | coord[1] = max(coord[1], 0) 97 | return coord 98 | 99 | def print_T(self): 100 | print(self.P[self.s]) 101 | 102 | def map_output(self, s, pos): 103 | if self.s == s: 104 | output = " x " 105 | elif pos == self.goal: 106 | output = " T " 107 | else: 108 | output = " o " 109 | return output 110 | 111 | def map_control_output(self, s, pos): 112 | return self.map_output(s, pos) 113 | 114 | def map_with_inbetween_goal(self, s, pos, in_between_goal): 115 | return self.map_output(s, pos) 116 | 117 | def render(self, mode='human', close=False, in_control=None, in_between_goal=None): 118 | self._render(mode, close, in_control, in_between_goal) 119 | 120 | def _render(self, mode='human', close=False, in_control=None, in_between_goal=None): 121 | if close: 122 | return 123 | outfile = StringIO() if mode == 'ansi' else sys.stdout 124 | if mode == 'human': 125 | print('\033[2;0H') 126 | 127 | for s in range(self.nS): 128 | position = np.unravel_index(s, self.shape) 129 | # print(self.s) 130 | if in_control: 131 | output = self.map_control_output(s, position) 132 | elif in_between_goal: 133 | output = self.map_with_inbetween_goal(s, position, in_between_goal) 134 | else: 135 | output = self.map_output(s, position) 136 | if position[1] == 0: 137 | output = output.lstrip() 138 | if position[1] == self.shape[1] - 1: 139 | output = output.rstrip() 140 | output += "\n" 141 | outfile.write(output) 142 | outfile.write("\n") 143 | if mode == 'human': 144 | if in_control: 145 | time.sleep(0.2) 146 | else: 147 | time.sleep(0.05) 148 | 149 | 150 | class FallEnv(GridCore): 151 | _pits = [] 152 | 153 | def __init__(self, act_fail_prob: float = 1.0, **kwargs): 154 | self.afp = act_fail_prob 155 | super(FallEnv, self).__init__(**kwargs) 156 | 157 | def _calculate_transition_prob(self, current, delta, prob): 158 | transitions = [] 159 | for d, p in zip(delta, prob): 160 | new_position = np.array(current) + np.array(d) 161 | new_position = self._check_bounds(new_position).astype(int) 162 | new_state = np.ravel_multi_index(tuple(new_position), self.shape) 163 | if not self._dr: 164 | reward = 0.0 165 | is_done = False 166 | if tuple(new_position) == self.goal: 167 | if self._pr: 168 | reward = 1 - (self._steps / self.max_steps) 169 | elif not self._no_goal_rew: 170 | reward = 1.0 171 | is_done = True 172 | elif new_state in self._pits: 173 | reward = -1. 
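                        # stepping into a pit state terminates the episode with the -1 penalty assigned above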
174 | is_done = True 175 | transitions.append((p, new_position if self.xy else new_state, reward, is_done)) 176 | else: 177 | reward = -cityblock(new_position, self.goal) 178 | is_done = False 179 | if tuple(new_position) == self.goal: 180 | if self._pr: 181 | reward = 1 - (self._steps / self.max_steps) 182 | elif not self._no_goal_rew: 183 | reward = 100.0 184 | is_done = True 185 | elif new_state in self._pits: 186 | reward = -100 187 | is_done = True 188 | transitions.append((p, new_position if self.xy else new_state, reward, is_done)) 189 | return transitions 190 | 191 | def _init_transition_probability(self): 192 | for idx, p in enumerate(self._pits): 193 | try: 194 | self._pits[idx] = np.ravel_multi_index(p, self.shape) 195 | except: 196 | pass # <- this has to be here for the agent. Otherwise it throws an unexplainable error 197 | # Calculate transition probabilities 198 | P = {} 199 | for s in range(self.nS): 200 | position = np.unravel_index(s, self.shape) 201 | P[s] = {a: [] for a in range(self.nA)} 202 | other_prob = self.afp / 3. 203 | tmp = [[UP, DOWN, LEFT, RIGHT], 204 | [DOWN, LEFT, RIGHT, UP], 205 | [LEFT, RIGHT, UP, DOWN], 206 | [RIGHT, UP, DOWN, LEFT]] 207 | tmp_dirs = [[[-1, 0], [1, 0], [0, -1], [0, 1]], 208 | [[1, 0], [0, -1], [0, 1], [-1, 0]], 209 | [[0, -1], [0, 1], [-1, 0], [1, 0]], 210 | [[0, 1], [-1, 0], [1, 0], [0, -1]]] 211 | tmp_pros = [[1 - self.afp, other_prob, other_prob, other_prob], 212 | [1 - self.afp, other_prob, other_prob, other_prob], 213 | [1 - self.afp, other_prob, other_prob, other_prob], 214 | [1 - self.afp, other_prob, other_prob, other_prob], ] 215 | for acts, dirs, probs in zip(tmp, tmp_dirs, tmp_pros): 216 | P[s][acts[0]] = self._calculate_transition_prob(position, dirs, probs) 217 | return P 218 | 219 | def map_output(self, s, pos): 220 | if self.s == s: 221 | output = " \u001b[33m*\u001b[37m " 222 | elif pos == self.goal: 223 | output = " \u001b[37mX\u001b[37m " 224 | elif s in self._pits: 225 | output = " \u001b[31m.\u001b[37m " 226 | else: 227 | output = " \u001b[30mo\u001b[37m " 228 | return output 229 | 230 | def map_control_output(self, s, pos): 231 | if self.s == s: 232 | return " \u001b[34m*\u001b[37m " 233 | else: 234 | return self.map_output(s, pos) 235 | 236 | def map_with_inbetween_goal(self, s, pos, in_between_goal): 237 | if s == in_between_goal: 238 | return " \u001b[34mx\u001b[37m " 239 | else: 240 | return self.map_output(s, pos) 241 | 242 | 243 | class Bridge6x10Env(FallEnv): 244 | _pits = [[0, 2], [0, 3], [0, 4], [0, 5], [0, 6], [0, 7], 245 | [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7], 246 | [4, 2], [4, 3], [4, 4], [4, 5], [4, 6], [4, 7], 247 | [5, 2], [5, 3], [5, 4], [5, 5], [5, 6], [5, 7]] 248 | _shape = (6, 10) 249 | 250 | 251 | class Pit6x10Env(FallEnv): 252 | _pits = [[0, 2], [0, 3], [0, 4], [0, 5], [0, 6], [0, 7], 253 | [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7], 254 | [2, 2], [2, 3], [2, 4], [2, 5], [2, 6], [2, 7]] 255 | # [3, 2], [3, 3], [3, 4], [3, 5], [3, 6], [3, 7]] 256 | _shape = (6, 10) 257 | 258 | 259 | class ZigZag6x10(FallEnv): 260 | _pits = [[0, 2], [0, 3], 261 | [1, 2], [1, 3], 262 | [2, 2], [2, 3], 263 | [3, 2], [3, 3], 264 | [5, 7], [5, 6], 265 | [4, 7], [4, 6], 266 | [3, 7], [3, 6], 267 | [2, 7], [2, 6], 268 | ] 269 | _shape = (6, 10) 270 | 271 | 272 | class ZigZag6x10H(FallEnv): 273 | _pits = [[0, 2], [0, 3], 274 | [1, 2], [1, 3], 275 | [2, 2], [2, 3], 276 | [3, 2], [3, 3], 277 | [5, 7], [5, 6], 278 | [4, 7], [4, 6], 279 | [3, 7], [3, 6], 280 | [2, 7], [2, 6], 281 | [4, 4], [5, 2] 282 | ] 283 | 
_shape = (6, 10) 284 | -------------------------------------------------------------------------------- /mountain_car.py: -------------------------------------------------------------------------------- 1 | """ 2 | http://incompleteideas.net/sutton/MountainCar/MountainCar1.cp 3 | permalink: https://perma.cc/6Z2N-PFWC 4 | """ 5 | 6 | import math 7 | import gym 8 | from gym import spaces 9 | from gym.utils import seeding 10 | import numpy as np 11 | 12 | class MountainCarEnv(gym.Env): 13 | metadata = { 14 | 'render.modes': ['human', 'rgb_array'], 15 | 'video.frames_per_second': 30 16 | } 17 | 18 | def __init__(self): 19 | self.min_position = -1.2 20 | self.max_position = 0.6 21 | self.max_speed = 0.07 22 | self.goal_position = 0.5 23 | 24 | self.low = np.array([self.min_position, -self.max_speed]) 25 | self.high = np.array([self.max_position, self.max_speed]) 26 | 27 | self.viewer = None 28 | 29 | self.action_space = spaces.Discrete(3) 30 | self.observation_space = spaces.Box(self.low, self.high, dtype=np.float32) 31 | 32 | self.seed() 33 | self.reset() 34 | 35 | def seed(self, seed=None): 36 | self.np_random, seed = seeding.np_random(seed) 37 | return [seed] 38 | 39 | def step(self, action): 40 | for _ in range(4): 41 | assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action)) 42 | 43 | position, velocity = self.state 44 | velocity += (action-1)*0.001 + math.cos(3*position)*(-0.0025) 45 | velocity = np.clip(velocity, -self.max_speed, self.max_speed) 46 | position += velocity 47 | position = np.clip(position, self.min_position, self.max_position) 48 | if (position==self.min_position and velocity<0): velocity = 0 49 | 50 | done = bool(position >= self.goal_position) 51 | if not done: 52 | reward = -0.1*np.abs(self.goal_position - position) 53 | else: 54 | reward = 10 55 | self.state = (position, velocity) 56 | 57 | return np.array(self.state), reward, done, {} 58 | 59 | def reset(self): 60 | self.state = np.array([-0.5, 0]) 61 | return np.array(self.state) 62 | 63 | def _height(self, xs): 64 | return np.sin(3 * xs) * .45 + .55 65 | 66 | def close(self): 67 | if self.viewer is not None: 68 | self.viewer.close() 69 | self.viewer = None 70 | 71 | def render(self, mode='human', close=False): 72 | if close: 73 | if self.viewer is not None: 74 | self.viewer.close() 75 | self.viewer = None 76 | return 77 | 78 | screen_width = 600 79 | screen_height = 400 80 | 81 | world_width = self.max_position - self.min_position 82 | scale = screen_width/world_width 83 | carwidth=40 84 | carheight=20 85 | 86 | 87 | if self.viewer is None: 88 | from gym.envs.classic_control import rendering 89 | self.viewer = rendering.Viewer(screen_width, screen_height) 90 | xs = np.linspace(self.min_position, self.max_position, 100) 91 | ys = self._height(xs) 92 | xys = list(zip((xs-self.min_position)*scale, ys*scale)) 93 | 94 | self.track = rendering.make_polyline(xys) 95 | self.track.set_linewidth(4) 96 | self.viewer.add_geom(self.track) 97 | 98 | clearance = 10 99 | 100 | l,r,t,b = -carwidth/2, carwidth/2, carheight, 0 101 | car = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)]) 102 | car.add_attr(rendering.Transform(translation=(0, clearance))) 103 | self.cartrans = rendering.Transform() 104 | car.add_attr(self.cartrans) 105 | self.viewer.add_geom(car) 106 | frontwheel = rendering.make_circle(carheight/2.5) 107 | frontwheel.set_color(.5, .5, .5) 108 | frontwheel.add_attr(rendering.Transform(translation=(carwidth/4,clearance))) 109 | frontwheel.add_attr(self.cartrans) 110 | 
self.viewer.add_geom(frontwheel) 111 | backwheel = rendering.make_circle(carheight/2.5) 112 | backwheel.add_attr(rendering.Transform(translation=(-carwidth/4,clearance))) 113 | backwheel.add_attr(self.cartrans) 114 | backwheel.set_color(.5, .5, .5) 115 | self.viewer.add_geom(backwheel) 116 | flagx = (self.goal_position-self.min_position)*scale 117 | flagy1 = self._height(self.goal_position)*scale 118 | flagy2 = flagy1 + 50 119 | flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2)) 120 | self.viewer.add_geom(flagpole) 121 | flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2-10), (flagx+25, flagy2-5)]) 122 | flag.set_color(.8,.8,0) 123 | self.viewer.add_geom(flag) 124 | 125 | pos = self.state[0] 126 | self.cartrans.set_translation((pos-self.min_position)*scale, self._height(pos)*scale) 127 | self.cartrans.set_rotation(math.cos(3 * pos)) 128 | 129 | return self.viewer.render(return_rgb_array = mode=='rgb_array') 130 | -------------------------------------------------------------------------------- /plot_atari_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "from utils.plotting import get_colors, load_config, plot\n", 12 | "from utils.data_handling import load_dqn_data\n", 13 | "import numpy as np" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "#### Name explanations\n", 21 | "* DQN -> standard DQN\n", 22 | "* DAR_min^max -> Dynamic action repetition with small repetition and long repetition values\n", 23 | "* tqn -> TempoRL DQN with separate skip-DQN that expects the behaviour action to be concatenated to the state\n", 24 | "* t-dqn -> TempoRL DQN with separate skip-DQN that expects the behaviour action as contextual input\n", 25 | "* tdqn -> TempoRL DQN with shared state representation between the behavoiur and skip action outputs." 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import json\n", 35 | "import glob\n", 36 | "import os\n", 37 | "import pandas as pd\n", 38 | "from matplotlib import pyplot as plt\n", 39 | "import seaborn as sb\n", 40 | "\n", 41 | "from scipy.signal import savgol_filter\n", 42 | " \n", 43 | "\n", 44 | "# Somehow the plotting functionallity I ended up with was already covered for the tabular case.\n", 45 | "# I should have just used the plot function from that.\n", 46 | "def plotMultiple(data, ylim=None, title='', logStepY=False, max_steps=200, xlim=None, figsize=None,\n", 47 | " alphas=None, smooth=5, savename=None, rewyticks=None, lenyticks=None,\n", 48 | " skip_stdevs=[], dont_label=[], dont_plot=[], min_steps=None,\n", 49 | " logRewY=False):\n", 50 | " \"\"\"\n", 51 | " Simple plotting method that shows the test reward on the y-axis and the number of performed training steps\n", 52 | " on the x-axis.\n", 53 | " \n", 54 | " data -> (dict[agent name] -> list([rewards, lens, decs, train_steps, train_episodes])) the data to plot\n", 55 | " ylim -> (list) y-axis limit\n", 56 | " title -> (str) title on top of plot\n", 57 | " logStepY -> (bool) flag that indicates if the y-axis should be on log scale.\n", 58 | " max_steps -> (int) maximal episode length\n", 59 | " min_steps -> (int) optional minimum episode length. 
If not set assumes 1 as min\n", 60 | " xlim -> (list) x-axis limits\n", 61 | " figsize -> (list) dimensions of the figure\n", 62 | " alphas -> (dict[agent name] -> float) the alpha value to use for plotting of specific agents\n", 63 | " smooth -> (int) the window size for smoothing (has to be odd if used. < 0 deactivates this option)\n", 64 | " savename -> (str) filename to save the figure\n", 65 | " rewyticks -> (list) yticks for the reward plot\n", 66 | " lenyticks -> (list) yticks for the decisions plot\n", 67 | " skip_sdevs -> (list) list of names to not plot standard deviations for.\n", 68 | " dont_label -> (list) list of names to not label.\n", 69 | " dont_plot -> (list) list of names to not plot.\n", 70 | " logRewY -> (bool) flag that indicates if the reward y-axis should be on log scale.\n", 71 | " \"\"\"\n", 72 | " \n", 73 | " if smooth and smooth > 0:\n", 74 | " degree = 2\n", 75 | " for agent in data:\n", 76 | " data[agent] = list(data[agent]) # we have to convert the tuple to lists\n", 77 | " data[agent][0] = list(data[agent][0])\n", 78 | " data[agent][0][0] = savgol_filter(data[agent][0][0], smooth, degree) # smooth the mean reward\n", 79 | " data[agent][0][1] = savgol_filter(data[agent][0][1], smooth, degree) # smooth the stdev reward\n", 80 | " data[agent][1] = list(data[agent][1])\n", 81 | " data[agent][1][0] = savgol_filter(data[agent][1][0], smooth, degree) # smooth mean num steps\n", 82 | " data[agent][1][1] = savgol_filter(data[agent][1][1], smooth, degree)\n", 83 | " data[agent][2] = list(data[agent][2])\n", 84 | " data[agent][2][0] = savgol_filter(data[agent][2][0], smooth, degree) # smooth mean decisions\n", 85 | " data[agent][2][1] = savgol_filter(data[agent][2][1], smooth, degree)\n", 86 | "\n", 87 | " colors, color_map = get_colors()\n", 88 | " \n", 89 | "\n", 90 | " cfg = load_config()\n", 91 | " sb.set_style(cfg['plotting']['seaborn']['style'])\n", 92 | " sb.set_context(cfg['plotting']['seaborn']['context']['context'],\n", 93 | " font_scale=cfg['plotting']['seaborn']['context']['font scale'],\n", 94 | " rc=cfg['plotting']['seaborn']['context']['rc2'])\n", 95 | "\n", 96 | " if figsize:\n", 97 | " fig, ax = plt.subplots(2, figsize=figsize, dpi=100, sharex=True)\n", 98 | " else:\n", 99 | " fig, ax = plt.subplots(2, figsize=(20, 10), dpi=100,sharex=True)\n", 100 | " ax[0].set_title(title)\n", 101 | "\n", 102 | " for agent in list(data.keys())[::-1]:\n", 103 | " if agent in dont_plot:\n", 104 | " continue\n", 105 | " try:\n", 106 | " alph = alphas[agent]\n", 107 | " except:\n", 108 | " alph = 1.\n", 109 | " color_name = color_map['dar'] if 'dar' in agent else color_map[agent]\n", 110 | " rew, lens, decs, train_steps, train_eps = data[agent]\n", 111 | " \n", 112 | " label = agent.upper()\n", 113 | " if agent in ['t-dqn', 'tdqn', 'tqn']:\n", 114 | " label = 't-DQN'\n", 115 | " elif agent in dont_label:\n", 116 | " label = None\n", 117 | "\n", 118 | " #### Plot rewards\n", 119 | " ax[0].step(train_steps[0], rew[0], where='post', c=colors[color_name], label=label,\n", 120 | " alpha=alph, ls='-' if agent != 't-dqn' else '-.')\n", 121 | " if agent not in skip_stdevs:\n", 122 | " ax[0].fill_between(train_steps[0], rew[0]-rew[1], rew[0]+rew[1],\n", 123 | " alpha=0.25 * alph, step='post',\n", 124 | " color=colors[color_name])\n", 125 | " #### Plot lens\n", 126 | " ax[1].step(train_steps[0], decs[0], where='post',\n", 127 | " c=np.array(colors[color_name]), ls='-',\n", 128 | " alpha=alph)\n", 129 | " if agent not in skip_stdevs:\n", 130 | " ax[1].fill_between(train_steps[0], 
decs[0]-decs[1], decs[0]+decs[1],\n", 131 | " alpha=0.125 * alph, step='post',\n", 132 | " color=np.array(colors[color_name]))\n", 133 | " ax[1].step(train_steps[0], lens[0], where='post',\n", 134 | " c=np.array(colors[color_name]) * .75, alpha=alph,\n", 135 | " ls=':')\n", 136 | " \n", 137 | " if agent not in skip_stdevs:\n", 138 | " ax[1].fill_between(train_steps[0], lens[0]-lens[1], lens[0]+lens[1],\n", 139 | " alpha=0.25 * alph, step='post',\n", 140 | " color=np.array(colors[color_name]) * .75)\n", 141 | " #ax[0].semilogx()\n", 142 | " if rewyticks is not None:\n", 143 | " ax[0].set_yticks(rewyticks)\n", 144 | " if ylim:\n", 145 | " ax[0].set_ylim(ylim)\n", 146 | " if xlim:\n", 147 | " ax[0].set_xlim(xlim)\n", 148 | " ax[0].set_ylabel('Reward')\n", 149 | " if len(data) - len(dont_label) < 5:\n", 150 | " ax[0].legend(ncol=1, loc='best', handlelength=.75)\n", 151 | " ax[1].semilogx()\n", 152 | " if logStepY:\n", 153 | " ax[1].semilogy()\n", 154 | " if logRewY:\n", 155 | " ax[0].semilogy()\n", 156 | " \n", 157 | " ax[1].plot([-999, -999], [-999, -999], ls=':', c='k', label='all')\n", 158 | " ax[1].plot([-999, -999], [-999, -999], ls='-', c='k', label='dec')\n", 159 | " ax[1].legend(loc='best', ncol=1, handlelength=.75)\n", 160 | " if not min_steps:\n", 161 | " ax[1].set_ylim([1, max_steps])\n", 162 | " else:\n", 163 | " ax[1].set_ylim([min_steps, max_steps])\n", 164 | " if xlim:\n", 165 | " ax[1].set_xlim(xlim)\n", 166 | " ax[1].set_ylabel('#Actions')\n", 167 | " ax[1].set_xlabel('#Train Steps')\n", 168 | " if lenyticks is not None:\n", 169 | " ax[1].set_yticks(lenyticks)\n", 170 | " plt.tight_layout()\n", 171 | " if savename:\n", 172 | " plt.savefig(savename)\n", 173 | "\n", 174 | " plt.show()\n", 175 | "\n", 176 | "\n", 177 | "def get_best_to_plot(data, aucs, tempoRL=None):\n", 178 | " \"\"\"\n", 179 | " Simple method to filter which lines to plot.\n", 180 | " \"\"\"\n", 181 | " to_plot = dict()\n", 182 | "\n", 183 | " if tempoRL is None:\n", 184 | " aucs = list(sorted(aucs, key=lambda x: x[1], reverse=True))\n", 185 | " for idx, auc in enumerate(aucs):\n", 186 | " if 't' in auc[0]:\n", 187 | " break\n", 188 | " to_plot[aucs[idx][0]] = data[aucs[idx][0]] # the absolute best\n", 189 | " else:\n", 190 | " to_plot[tempoRL] = data[tempoRL]\n", 191 | "\n", 192 | " bv = -np.inf\n", 193 | " b = None\n", 194 | " for elem in aucs:\n", 195 | " if 'dar' not in elem[0]:\n", 196 | " continue\n", 197 | " elif elem[1] > bv:\n", 198 | " b, bv = elem[0], elem[1]\n", 199 | " to_plot[b] = data[b]\n", 200 | " \n", 201 | " \n", 202 | " to_plot['dqn'] = data['dqn']\n", 203 | " return to_plot" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | " " 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | " " 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": { 224 | "scrolled": false 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "data = {}\n", 229 | "\n", 230 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/pong/tdqn',\n", 231 | " #debug=True,\n", 232 | " max_steps=2.5*10**6)\n", 233 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/pong/dqn',\n", 234 | " #debug=True,\n", 235 | " max_steps=2.5*10**6)\n", 236 | "data[r'dar$_{1}^{10}$'] = load_dqn_data('*', 'experiments/atari/pong/dar',\n", 237 | " #debug=True,\n", 238 | " max_steps=2.5*10**6)\n", 239 | "\n", 240 | "plotMultiple(data, title='Pong',\n", 241 | " ylim=[-22, 22], xlim=[10**4, 
2.5*10**6],\n", 242 | " min_steps=10**2, max_steps=3000, lenyticks=[10**2, 10**3, 2*10**3, 3*10**3],\n", 243 | " smooth=7, savename='pong_50_seeds.pdf') #, logStepY=True)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": { 250 | "scrolled": false 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "data = {}\n", 255 | "\n", 256 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/beam_rider/tdqn_3',\n", 257 | " #debug=True,\n", 258 | " max_steps=2.5*10**6)\n", 259 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/beam_rider/dqn_3',\n", 260 | " #debug=True,\n", 261 | " max_steps=2.5*10**6)\n", 262 | "data[r'dar$_{1}^{10}$'] = load_dqn_data('*', 'experiments/atari/beam_rider/dar_3',\n", 263 | " #debug=True,\n", 264 | " max_steps=2.5*10**6)\n", 265 | "\n", 266 | "plotMultiple(data, title='BeamRider',\n", 267 | " ylim=[0, 600],\n", 268 | " xlim=[10**4, 2.5*10**6],\n", 269 | " max_steps=1000, rewyticks=[0, 150, 300, 450, 600], #lenyticks=[10**2, 10**3, 2*10**3, 3*10**3],\n", 270 | " smooth=7, savename='beamrider_15_seeds.pdf') #, logStepY=True)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": { 277 | "scrolled": false 278 | }, 279 | "outputs": [], 280 | "source": [ 281 | "data = {}\n", 282 | "\n", 283 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/freeway/tdqn_3',\n", 284 | " #debug=True,\n", 285 | " max_steps=2.5*10**6)\n", 286 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/freeway/dqn_3',\n", 287 | " #debug=True,\n", 288 | " max_steps=2.5*10**6)\n", 289 | "data[r'dar$_{1}^{10}$'] = load_dqn_data('*', 'experiments/atari/freeway/dar_3',\n", 290 | " #debug=True,\n", 291 | " max_steps=2.5*10**6)\n", 292 | "\n", 293 | "plotMultiple(data, title='Freeway',\n", 294 | " ylim=[0, 35],\n", 295 | " xlim=[10**4, 2.5*10**6],\n", 296 | " max_steps=2100, rewyticks=[0, 11, 22, 33], #lenyticks=[10**2, 10**3, 2*10**3, 3*10**3],\n", 297 | " smooth=7, savename='freeway_15_seeds.pdf') #, logStepY=True)" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": { 304 | "scrolled": false 305 | }, 306 | "outputs": [], 307 | "source": [ 308 | "data = {}\n", 309 | "\n", 310 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/ms_pacman/tdqn_3',\n", 311 | " #debug=True,\n", 312 | " max_steps=2.5*10**6)\n", 313 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/ms_pacman/dqn_3',\n", 314 | " #debug=True,\n", 315 | " max_steps=2.5*10**6)\n", 316 | "data[r'dar$_{1}^{10}$'] = load_dqn_data('*', 'experiments/atari/ms_pacman/dar_3',\n", 317 | " #debug=True,\n", 318 | " max_steps=2.5*10**6)\n", 319 | "\n", 320 | "plotMultiple(data, title='MsPacman',\n", 321 | " ylim=[0, 750],\n", 322 | " xlim=[10**4, 2.5*10**6],\n", 323 | " max_steps=300, rewyticks=[0, 250, 500, 750], #lenyticks=[10**2, 10**3, 2*10**3, 3*10**3],\n", 324 | " smooth=7, savename='mspacman_15_seeds.pdf') #, logStepY=True)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "metadata": { 331 | "scrolled": false 332 | }, 333 | "outputs": [], 334 | "source": [ 335 | "data = {}\n", 336 | "\n", 337 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/qbert_long/tdqn_3',\n", 338 | " #debug=True,\n", 339 | " max_steps=5*10**6)\n", 340 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/qbert_long/dqn_3',\n", 341 | " #debug=True,\n", 342 | " max_steps=5*10**6)\n", 343 | "data[r'dar$_{1}^{10}$'] = load_dqn_data('*', 
'experiments/atari/qbert_long/dar_3',\n", 344 | " #debug=True,\n", 345 | " max_steps=5*10**6)\n", 346 | "\n", 347 | "plotMultiple(data, title='QBert',\n", 348 | " ylim=[0, 1000],\n", 349 | " xlim=[10**4, 5*10**6], logRewY=False,\n", 350 | " max_steps=225, min_steps=0, rewyticks=[0, 250, 500, 750, 1000], lenyticks=[0, 50, 100, 150, 200],\n", 351 | " smooth=7, savename='qbert_15_sees.pdf') #, logStepY=True)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": { 358 | "scrolled": false 359 | }, 360 | "outputs": [], 361 | "source": [ 362 | "data = {}\n", 363 | "\n", 364 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/qbert/tdqn_3',\n", 365 | " #debug=True,\n", 366 | " max_steps=2.5*10**6)\n", 367 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/qbert/dqn_3',\n", 368 | " #debug=True,\n", 369 | " max_steps=2.5*10**6)\n", 370 | "\n", 371 | "plotMultiple(data, title='QBert',\n", 372 | " ylim=[0, 1000],\n", 373 | " xlim=[10**4, 2.5*10**6], logRewY=False,\n", 374 | " max_steps=225, min_steps=0, rewyticks=[0, 250, 500, 750, 1000], lenyticks=[0, 50, 100, 150, 200],\n", 375 | " smooth=7) #, logStepY=True)" 376 | ] 377 | } 378 | ], 379 | "metadata": { 380 | "kernelspec": { 381 | "display_name": "Python 3", 382 | "language": "python", 383 | "name": "python3" 384 | }, 385 | "language_info": { 386 | "codemirror_mode": { 387 | "name": "ipython", 388 | "version": 3 389 | }, 390 | "file_extension": ".py", 391 | "mimetype": "text/x-python", 392 | "name": "python", 393 | "nbconvert_exporter": "python", 394 | "pygments_lexer": "ipython3", 395 | "version": "3.7.7" 396 | } 397 | }, 398 | "nbformat": 4, 399 | "nbformat_minor": 4 400 | } 401 | -------------------------------------------------------------------------------- /plot_ddpg.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "import os\n", 12 | "from utils.plotting import get_colors, load_config, plot\n", 13 | "import numpy as np\n", 14 | "from matplotlib import pyplot as plt" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import json\n", 24 | "import glob\n", 25 | "import os\n", 26 | "import pandas as pd\n", 27 | "from matplotlib import pyplot as plt\n", 28 | "import seaborn as sb\n", 29 | "\n", 30 | "from scipy.signal import savgol_filter\n", 31 | " \n", 32 | "\n", 33 | "# Somehow the plotting functionallity I ended up with was already covered for the tabular case.\n", 34 | "# I should have just used the plot function from that.\n", 35 | "def plotMultiple(data, ylim=None, title='', logStepY=False, max_steps=200, xlim=None, figsize=None,\n", 36 | " alphas=None, smooth=5, savename=None, rewyticks=None, lenyticks=None,\n", 37 | " skip_stdevs=[], dont_label=[], dont_plot=[], min_steps=None):\n", 38 | " \"\"\"\n", 39 | " Simple plotting method that shows the test reward on the y-axis and the number of performed training steps\n", 40 | " on the x-axis.\n", 41 | " \n", 42 | " data -> (dict[agent name] -> list([rewards, lens, decs, train_steps, train_episodes])) the data to plot\n", 43 | " ylim -> (list) y-axis limit\n", 44 | " title -> (str) title on top of plot\n", 45 | " logStepY -> (bool) flag that indicates if the y-axis should be on log scale.\n", 46 | " max_steps -> (int) 
maximal episode length\n", 47 | " xlim -> (list) x-axis limits\n", 48 | " figsize -> (list) dimensions of the figure\n", 49 | " alphas -> (dict[agent name] -> float) the alpha value to use for plotting of specific agents\n", 50 | " smooth -> (int) the window size for smoothing (has to be odd if used. < 0 deactivates this option)\n", 51 | " savename -> (str) filename to save the figure\n", 52 | " rewyticks -> (list) yticks for the reward plot\n", 53 | " lenyticks -> (list) yticks for the decisions plot\n", 54 | " skip_sdevs -> (list) list of names to not plot standard deviations for.\n", 55 | " dont_label -> (list) list of names to not label.\n", 56 | " dont_plot -> (list) list of names to not plot.\n", 57 | " \"\"\"\n", 58 | " \n", 59 | " if smooth and smooth > 0:\n", 60 | " degree = 2\n", 61 | " for agent in data:\n", 62 | " data[agent] = list(data[agent]) # we have to convert the tuple to lists\n", 63 | " data[agent][0] = list(data[agent][0])\n", 64 | " data[agent][0][0] = savgol_filter(data[agent][0][0], smooth, degree) # smooth the mean reward\n", 65 | " data[agent][0][1] = savgol_filter(data[agent][0][1], smooth, degree) # smooth the stdev reward\n", 66 | " data[agent][1] = list(data[agent][1])\n", 67 | " data[agent][1][0] = savgol_filter(data[agent][1][0], smooth, degree) # smooth mean num steps\n", 68 | " data[agent][1][1] = savgol_filter(data[agent][1][1], smooth, degree)\n", 69 | " data[agent][2] = list(data[agent][2])\n", 70 | " data[agent][2][0] = savgol_filter(data[agent][2][0], smooth, degree) # smooth mean decisions\n", 71 | " data[agent][2][1] = savgol_filter(data[agent][2][1], smooth, degree)\n", 72 | "\n", 73 | " colors, color_map = get_colors()\n", 74 | " \n", 75 | "\n", 76 | " cfg = load_config()\n", 77 | " sb.set_style(cfg['plotting']['seaborn']['style'])\n", 78 | " sb.set_context(cfg['plotting']['seaborn']['context']['context'],\n", 79 | " font_scale=cfg['plotting']['seaborn']['context']['font scale'],\n", 80 | " rc=cfg['plotting']['seaborn']['context']['rc2'])\n", 81 | "\n", 82 | " if figsize:\n", 83 | " fig, ax = plt.subplots(2, figsize=figsize, dpi=100, sharex=True)\n", 84 | " else:\n", 85 | " fig, ax = plt.subplots(2, figsize=(20, 10), dpi=100,sharex=True)\n", 86 | " ax[0].set_title(title)\n", 87 | "\n", 88 | " for agent in list(data.keys())[::-1]:\n", 89 | " if agent in dont_plot:\n", 90 | " continue\n", 91 | " try:\n", 92 | " alph = alphas[agent]\n", 93 | " except:\n", 94 | " alph = 1.\n", 95 | " color_name = None\n", 96 | " if 'dar' in agent:\n", 97 | " color_name = color_map['dar']\n", 98 | " elif agent.startswith('t'):\n", 99 | " color_name = color_map['t-DDPG']\n", 100 | " elif agent.startswith('f'):\n", 101 | " color_name = color_map['f-DDPG']\n", 102 | " else:\n", 103 | " color_name = color_map[agent]\n", 104 | " rew, lens, decs, train_steps, train_eps = data[agent]\n", 105 | " \n", 106 | " label = agent.upper()\n", 107 | " if agent.startswith('t'):\n", 108 | " label = 't-DDPG'\n", 109 | " elif agent.startswith('f'):\n", 110 | " label = 'FiGAR'\n", 111 | " elif agent.startswith('e'):\n", 112 | " label = r'$\\epsilon$z-DQN'\n", 113 | " elif agent in dont_label:\n", 114 | " label = None\n", 115 | "\n", 116 | " #### Plot rewards\n", 117 | " ax[0].step(train_steps[0], rew[0], where='post', c=colors[color_name], label=label,\n", 118 | " alpha=alph)\n", 119 | " if agent not in skip_stdevs:\n", 120 | " ax[0].fill_between(train_steps[0], rew[0]-rew[1], rew[0]+rew[1], alpha=0.25 * alph, step='post',\n", 121 | " color=colors[color_name])\n", 122 | " #### Plot 
lens\n", 123 | " ax[1].step(train_steps[0], decs[0], where='post', c=np.array(colors[color_name]), ls='-',\n", 124 | " alpha=alph)\n", 125 | " if agent not in skip_stdevs:\n", 126 | " ax[1].fill_between(train_steps[0], decs[0]-decs[1], decs[0]+decs[1], alpha=0.125 * alph, step='post',\n", 127 | " color=np.array(colors[color_name]))\n", 128 | " ax[1].step(train_steps[0], lens[0], where='post',\n", 129 | " c=np.array(colors[color_name]) * .75, alpha=alph, ls=':')\n", 130 | " \n", 131 | " if agent not in skip_stdevs:\n", 132 | " ax[1].fill_between(train_steps[0], lens[0]-lens[1], lens[0]+lens[1], alpha=0.25 * alph, step='post',\n", 133 | " color=np.array(colors[color_name]) * .75)\n", 134 | " ax[0].semilogx()\n", 135 | " if rewyticks is not None:\n", 136 | " ax[0].set_yticks(rewyticks)\n", 137 | " if ylim:\n", 138 | " ax[0].set_ylim(ylim)\n", 139 | " if xlim:\n", 140 | " ax[0].set_xlim(xlim)\n", 141 | " ax[0].set_ylabel('Reward')\n", 142 | " if len(data) - len(dont_label) < 5:\n", 143 | " ax[0].legend(ncol=1, loc='best', handlelength=.75)\n", 144 | " ax[1].semilogx()\n", 145 | " if logStepY:\n", 146 | " ax[1].semilogy()\n", 147 | " \n", 148 | " ax[1].plot([-999, -999], [-999, -999], ls=':', c='k', label='all')\n", 149 | " ax[1].plot([-999, -999], [-999, -999], ls='-', c='k', label='dec')\n", 150 | " ax[1].legend(loc='best', ncol=1, handlelength=.75)\n", 151 | " ax[1].set_ylim([min_steps if min_steps is not None else 1, max_steps])\n", 152 | " if xlim:\n", 153 | " ax[1].set_xlim(xlim)\n", 154 | " ax[1].set_ylabel('#Actions')\n", 155 | " ax[1].set_xlabel('#Train Steps')\n", 156 | " if lenyticks is not None:\n", 157 | " ax[1].set_yticks(lenyticks)\n", 158 | " plt.tight_layout()\n", 159 | " if savename:\n", 160 | " plt.savefig(savename)\n", 161 | "\n", 162 | " plt.show()" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": { 169 | "scrolled": false 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "results = {}\n", 174 | "ddpg_datas = []\n", 175 | "for i in sorted(os.listdir('experiments/ddpg/DDPG')):\n", 176 | " ddpg_datas.append(np.load(f'experiments/ddpg/DDPG/{i}/DDPG_Pendulum-v0_{i}.npy'))\n", 177 | "\n", 178 | "\n", 179 | "ddpg_mean = np.mean(ddpg_datas, axis=0)\n", 180 | "ddpg_stdev = np.std(ddpg_datas, axis=0)\n", 181 | "results['DDPG'] = [[ddpg_mean[:, 1], ddpg_stdev[:, 1]],\n", 182 | " [ddpg_mean[:, 3], ddpg_stdev[:, 3]], \n", 183 | " [ddpg_mean[:, 2], ddpg_stdev[:, 2]],\n", 184 | " [ddpg_mean[:, 0], ddpg_mean[:, 0]],\n", 185 | " [ddpg_mean[:, 0], ddpg_mean[:, 0]]]\n", 186 | "\n", 187 | "for max_len in [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]:\n", 188 | " temporl_datas = []\n", 189 | " for i in sorted(os.listdir(f'experiments/ddpg/TempoRLDDPG/{max_len}')):\n", 190 | " temporl_datas.append(np.load(f'experiments/ddpg/TempoRLDDPG/{max_len}/{i}/TempoRLDDPG_Pendulum-v0_{i}.npy'))\n", 191 | "\n", 192 | " figar_datas = []\n", 193 | " for i in sorted(os.listdir(f'experiments/ddpg/FiGARDDPG/{max_len}')):\n", 194 | " figar_datas.append(np.load(f'experiments/ddpg/FiGARDDPG/{max_len}/{i}/FiGARDDPG_Pendulum-v0_{i}.npy'))\n", 195 | "\n", 196 | " temporl_mean = np.mean(temporl_datas, axis=0)\n", 197 | " figar_mean = np.mean(figar_datas, axis=0)\n", 198 | " temporl_stdev = np.std(temporl_datas, axis=0)\n", 199 | " figar_stdev = np.std(figar_datas, axis=0)\n", 200 | " \n", 201 | " # (dict[agent name] -> list([rewards, lens, decs, train_steps, train_episodes]))\n", 202 | " results['t-DDPG'] = [[temporl_mean[:, 1], temporl_stdev[:, 1]],\n", 203 | " 
[temporl_mean[:, 3], temporl_stdev[:, 3]],\n", 204 | " [temporl_mean[:, 2], temporl_stdev[:, 2]],\n", 205 | " [temporl_mean[:, 0], temporl_mean[:, 0]],\n", 206 | " [temporl_mean[:, 0], temporl_mean[:, 0]]]\n", 207 | " results['f-DDPG'] = [[figar_mean[:, 1], figar_stdev[:, 1]],\n", 208 | " [figar_mean[:, 3], figar_stdev[:, 3]],\n", 209 | " [figar_mean[:, 2], figar_stdev[:, 2]],\n", 210 | " [figar_mean[:, 0], figar_mean[:, 0]],\n", 211 | " [figar_mean[:, 0], figar_mean[:, 0]]]\n", 212 | " print(min(min(results['DDPG'][0][0]), min(results['t-DDPG'][0][0]), min(results['f-DDPG'][0][0])),\n", 213 | " max(max(results['DDPG'][0][0]), max(results['t-DDPG'][0][0]), max(results['f-DDPG'][0][0])))\n", 214 | " print(' DDPG AUC:', np.mean((results['DDPG'][0][0] + 1800) / (-145 + 1800)))\n", 215 | " print('t-DDPG AUC:', np.mean((results['t-DDPG'][0][0] + 1800) / (-145 + 1800)))\n", 216 | " print(' FiGAR AUC:', np.mean((results['f-DDPG'][0][0] + 1800) / (-145 + 1800)))\n", 217 | " plotMultiple(results, title=r'Pendulum-v0 -- $\\mathcal{J}=' + f'{max_len}$',\n", 218 | " smooth=0, ylim=[-1800, -50], min_steps=10, max_steps=210, xlim=[10**3, 3*10**4],\n", 219 | " lenyticks=[50, 125, 200], rewyticks=[-1800, -1000, -200],\n", 220 | " savename=f'ddpg_{max_len}.pdf')" 221 | ] 222 | } 223 | ], 224 | "metadata": { 225 | "kernelspec": { 226 | "display_name": "Python 3", 227 | "language": "python", 228 | "name": "python3" 229 | }, 230 | "language_info": { 231 | "codemirror_mode": { 232 | "name": "ipython", 233 | "version": 3 234 | }, 235 | "file_extension": ".py", 236 | "mimetype": "text/x-python", 237 | "name": "python", 238 | "nbconvert_exporter": "python", 239 | "pygments_lexer": "ipython3", 240 | "version": "3.7.7" 241 | } 242 | }, 243 | "nbformat": 4, 244 | "nbformat_minor": 4 245 | } 246 | -------------------------------------------------------------------------------- /plot_featurized_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "from utils.plotting import get_colors, load_config, plot\n", 12 | "from utils.data_handling import load_dqn_data\n", 13 | "import numpy as np" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "#### Name explanations\n", 21 | "* DQN -> standard DQN\n", 22 | "* DAR_min^max -> Dynamic action repetition with small repetition and long repetition values\n", 23 | "* tqn -> TempoRL DQN with separate skip-DQN that expects the behaviour action to be concatenated to the state\n", 24 | "* t-dqn -> TempoRL DQN with separate skip-DQN that expects the behaviour action as contextual input\n", 25 | "* tdqn -> TempoRL DQN with shared state representation between the behavoiur and skip action outputs." 
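For intuition, a minimal sketch of how these three variants feed the behaviour action to the skip-Q (illustrative Python only; skip_q, shared_q, state, a and a_encoding are placeholder names, not the exact networks defined in run_atari_experiments.py):

    # tqn:   separate skip-Q; the behaviour action is concatenated to the state
    skip_values = skip_q(torch.cat([state, a_encoding], dim=1))
    # t-dqn: separate skip-Q; the behaviour action enters through a dedicated context layer
    skip_values = skip_q(state, action_val=a.float())
    # tdqn:  shared torso with two output heads; with context it returns Q(s, j | a), without it Q(s, a)
    action_values = shared_q(state)
    skip_values = shared_q(state, action_val=a.float())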
26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import json\n", 35 | "import glob\n", 36 | "import os\n", 37 | "import pandas as pd\n", 38 | "from matplotlib import pyplot as plt\n", 39 | "import seaborn as sb\n", 40 | "\n", 41 | "from scipy.signal import savgol_filter\n", 42 | " \n", 43 | "\n", 44 | "# Somehow the plotting functionallity I ended up with was already covered for the tabular case.\n", 45 | "# I should have just used the plot function from that.\n", 46 | "def plotMultiple(data, ylim=None, title='', logStepY=False, max_steps=200, xlim=None, figsize=None,\n", 47 | " alphas=None, smooth=5, savename=None, rewyticks=None, lenyticks=None,\n", 48 | " skip_stdevs=[], dont_label=[], dont_plot=[]):\n", 49 | " \"\"\"\n", 50 | " Simple plotting method that shows the test reward on the y-axis and the number of performed training steps\n", 51 | " on the x-axis.\n", 52 | " \n", 53 | " data -> (dict[agent name] -> list([rewards, lens, decs, train_steps, train_episodes])) the data to plot\n", 54 | " ylim -> (list) y-axis limit\n", 55 | " title -> (str) title on top of plot\n", 56 | " logStepY -> (bool) flag that indicates if the y-axis should be on log scale.\n", 57 | " max_steps -> (int) maximal episode length\n", 58 | " xlim -> (list) x-axis limits\n", 59 | " figsize -> (list) dimensions of the figure\n", 60 | " alphas -> (dict[agent name] -> float) the alpha value to use for plotting of specific agents\n", 61 | " smooth -> (int) the window size for smoothing (has to be odd if used. < 0 deactivates this option)\n", 62 | " savename -> (str) filename to save the figure\n", 63 | " rewyticks -> (list) yticks for the reward plot\n", 64 | " lenyticks -> (list) yticks for the decisions plot\n", 65 | " skip_sdevs -> (list) list of names to not plot standard deviations for.\n", 66 | " dont_label -> (list) list of names to not label.\n", 67 | " dont_plot -> (list) list of names to not plot.\n", 68 | " \"\"\"\n", 69 | " \n", 70 | " if smooth and smooth > 0:\n", 71 | " degree = 2\n", 72 | " for agent in data:\n", 73 | " data[agent] = list(data[agent]) # we have to convert the tuple to lists\n", 74 | " data[agent][0] = list(data[agent][0])\n", 75 | " data[agent][0][0] = savgol_filter(data[agent][0][0], smooth, degree) # smooth the mean reward\n", 76 | " data[agent][0][1] = savgol_filter(data[agent][0][1], smooth, degree) # smooth the stdev reward\n", 77 | " data[agent][1] = list(data[agent][1])\n", 78 | " data[agent][1][0] = savgol_filter(data[agent][1][0], smooth, degree) # smooth mean num steps\n", 79 | " data[agent][1][1] = savgol_filter(data[agent][1][1], smooth, degree)\n", 80 | " data[agent][2] = list(data[agent][2])\n", 81 | " data[agent][2][0] = savgol_filter(data[agent][2][0], smooth, degree) # smooth mean decisions\n", 82 | " data[agent][2][1] = savgol_filter(data[agent][2][1], smooth, degree)\n", 83 | "\n", 84 | " colors, color_map = get_colors()\n", 85 | " \n", 86 | "\n", 87 | " cfg = load_config()\n", 88 | " sb.set_style(cfg['plotting']['seaborn']['style'])\n", 89 | " sb.set_context(cfg['plotting']['seaborn']['context']['context'],\n", 90 | " font_scale=cfg['plotting']['seaborn']['context']['font scale'],\n", 91 | " rc=cfg['plotting']['seaborn']['context']['rc2'])\n", 92 | "\n", 93 | " if figsize:\n", 94 | " fig, ax = plt.subplots(2, figsize=figsize, dpi=100, sharex=True)\n", 95 | " else:\n", 96 | " fig, ax = plt.subplots(2, figsize=(20, 10), dpi=100,sharex=True)\n", 97 | " 
ax[0].set_title(title)\n", 98 | "\n", 99 | " for agent in list(data.keys())[::-1]:\n", 100 | " if agent in dont_plot:\n", 101 | " continue\n", 102 | " try:\n", 103 | " alph = alphas[agent]\n", 104 | " except:\n", 105 | " alph = 1.\n", 106 | " color_name = None\n", 107 | " if 'dar' in agent:\n", 108 | " color_name = color_map['dar']\n", 109 | " elif agent.startswith('t'):\n", 110 | " color_name = color_map['t-dqn']\n", 111 | " else:\n", 112 | " color_name = color_map[agent]\n", 113 | " rew, lens, decs, train_steps, train_eps = data[agent]\n", 114 | " \n", 115 | " label = agent.upper()\n", 116 | " if agent.startswith('t'):\n", 117 | " label = 't-DQN'\n", 118 | " elif agent in dont_label:\n", 119 | " label = None\n", 120 | "\n", 121 | " #### Plot rewards\n", 122 | " ax[0].step(train_steps[0][::5], rew[0][::5], where='post', c=colors[color_name], label=label,\n", 123 | " alpha=alph)\n", 124 | " if agent not in skip_stdevs:\n", 125 | " ax[0].fill_between(train_steps[0][::5], rew[0][::5]-rew[1][::5], rew[0][::5]+rew[1][::5], alpha=0.25 * alph, step='post',\n", 126 | " color=colors[color_name])\n", 127 | " #### Plot lens\n", 128 | " ax[1].step(train_steps[0], decs[0], where='post', c=np.array(colors[color_name]), ls='-',\n", 129 | " alpha=alph)\n", 130 | " if agent not in skip_stdevs:\n", 131 | " ax[1].fill_between(train_steps[0][::5], decs[0][::5]-decs[1][::5], decs[0][::5]+decs[1][::5], alpha=0.125 * alph, step='post',\n", 132 | " color=np.array(colors[color_name]))\n", 133 | " ax[1].step(train_steps[0][::5], lens[0][::5], where='post',\n", 134 | " c=np.array(colors[color_name]) * .75, alpha=alph, ls=':')\n", 135 | " \n", 136 | " if agent not in skip_stdevs:\n", 137 | " ax[1].fill_between(train_steps[0][::5], lens[0][::5]-lens[1][::5], lens[0][::5]+lens[1][::5], alpha=0.25 * alph, step='post',\n", 138 | " color=np.array(colors[color_name]) * .75)\n", 139 | " ax[0].semilogx()\n", 140 | " if rewyticks is not None:\n", 141 | " ax[0].set_yticks(rewyticks)\n", 142 | " if ylim:\n", 143 | " ax[0].set_ylim(ylim)\n", 144 | " if xlim:\n", 145 | " ax[0].set_xlim(xlim)\n", 146 | " ax[0].set_ylabel('Reward')\n", 147 | " if len(data) - len(dont_label) < 5:\n", 148 | " ax[0].legend(ncol=1, loc='best', handlelength=.75)\n", 149 | " ax[1].semilogx()\n", 150 | " if logStepY:\n", 151 | " ax[1].semilogy()\n", 152 | " \n", 153 | " ax[1].plot([-999, -999], [-999, -999], ls=':', c='k', label='all')\n", 154 | " ax[1].plot([-999, -999], [-999, -999], ls='-', c='k', label='dec')\n", 155 | " ax[1].legend(loc='best', ncol=1, handlelength=.75)\n", 156 | " ax[1].set_ylim([1, max_steps])\n", 157 | " if xlim:\n", 158 | " ax[1].set_xlim(xlim)\n", 159 | " ax[1].set_ylabel('#Actions')\n", 160 | " ax[1].set_xlabel('#Train Steps')\n", 161 | " if lenyticks is not None:\n", 162 | " ax[1].set_yticks(lenyticks)\n", 163 | " plt.tight_layout()\n", 164 | " if savename:\n", 165 | " plt.savefig(savename)\n", 166 | "\n", 167 | " plt.show()\n", 168 | "\n", 169 | "\n", 170 | "def get_best_to_plot(data, aucs, tempoRL=None):\n", 171 | " \"\"\"\n", 172 | " Simple method to filter which lines to plot.\n", 173 | " \"\"\"\n", 174 | " to_plot = dict()\n", 175 | "\n", 176 | " if tempoRL is None:\n", 177 | " aucs = list(sorted(aucs, key=lambda x: x[1], reverse=True))\n", 178 | " for idx, auc in enumerate(aucs):\n", 179 | " if 't' in auc[0]:\n", 180 | " break\n", 181 | " to_plot[aucs[idx][0]] = data[aucs[idx][0]] # the absolute best\n", 182 | " else:\n", 183 | " to_plot[tempoRL] = data[tempoRL]\n", 184 | "\n", 185 | " bv = -np.inf\n", 186 | " b = 
None\n", 187 | " for elem in aucs:\n", 188 | " if 'dar' not in elem[0]:\n", 189 | " continue\n", 190 | " elif elem[1] > bv:\n", 191 | " b, bv = elem[0], elem[1]\n", 192 | " to_plot[b] = data[b]\n", 193 | " \n", 194 | " \n", 195 | " to_plot['dqn'] = data['dqn']\n", 196 | " print('Only plotting:', list(to_plot.keys()))\n", 197 | " return to_plot" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": {}, 203 | "source": [ 204 | "











" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": { 211 | "scrolled": false 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "mountain_sparse_data = {}\n", 216 | "mountain_sparse_alphas = {}\n", 217 | "mountain_sparse_aucs = []\n", 218 | "max_steps=10**6\n", 219 | "thresh = -110\n", 220 | "\n", 221 | "for pairs in [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9)]:\n", 222 | " dar_fm_str = r'$dar_{' + '{}'.format(pairs[0] + 1) + '}^{' + '{}'.format(pairs[1] + 1) + '}$'\n", 223 | " mountain_sparse_alphas[dar_fm_str] = 1/5\n", 224 | " mountain_sparse_data[dar_fm_str] = load_dqn_data(\n", 225 | " '*', 'experiments/featurized_results/sparsemountain/dar_orig_%d_%d' % (pairs[0], pairs[1]), max_steps=max_steps,\n", 226 | " )\n", 227 | " try:\n", 228 | " mountain_sparse_aucs.append([dar_fm_str, np.trapz((mountain_sparse_data[dar_fm_str][0][0] + 200)/110,\n", 229 | " x=(mountain_sparse_data[dar_fm_str][3][0]/max(\n", 230 | " mountain_sparse_data[dar_fm_str][3][0])))])\n", 231 | " except:\n", 232 | " del mountain_sparse_data[dar_fm_str]\n", 233 | "\n", 234 | "\n", 235 | "\n", 236 | "mountain_sparse_data['dqn'] = load_dqn_data('*', 'experiments/featurized_results/sparsemountain/dqn', max_steps=max_steps,\n", 237 | " )\n", 238 | "mountain_sparse_aucs.append(['dqn', np.trapz((mountain_sparse_data['dqn'][0][0] + 200)/110,\n", 239 | " x=(mountain_sparse_data['dqn'][3][0]/max(\n", 240 | " mountain_sparse_data['dqn'][3][0])))])\n", 241 | "\n", 242 | "for i in [2, 4, 6, 8, 10]:\n", 243 | " mountain_sparse_data['tqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/sparsemountain/tqn_%d' % i, max_steps=max_steps,\n", 244 | " )\n", 245 | " mountain_sparse_aucs.append(['tqn_%d' % i, np.trapz((mountain_sparse_data['tqn_%d' % i][0][0] + 200)/110,\n", 246 | " x=(mountain_sparse_data['tqn_%d' % i][3][0]/max(\n", 247 | " mountain_sparse_data['tqn_%d' % i][3][0])))])\n", 248 | "\n", 249 | " mountain_sparse_data['t-dqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/sparsemountain/t-dqn_%d' % i,\n", 250 | " max_steps=max_steps,\n", 251 | " )\n", 252 | " mountain_sparse_aucs.append(['t-dqn_%d' % i, np.trapz((mountain_sparse_data['t-dqn_%d' % i][0][0] + 200)/110,\n", 253 | " x=(mountain_sparse_data['t-dqn_%d' % i][3][0]/max(\n", 254 | " mountain_sparse_data['t-dqn_%d' % i][3][0])))])\n", 255 | "\n", 256 | " mountain_sparse_data['tdqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/sparsemountain/tdqn_%d' % i, max_steps=max_steps,\n", 257 | " )\n", 258 | " mountain_sparse_aucs.append(['tdqn_%d' % i, np.trapz((mountain_sparse_data['tdqn_%d' % i][0][0] + 200)/110,\n", 259 | " x=(mountain_sparse_data['tdqn_%d' % i][3][0]/max(\n", 260 | " mountain_sparse_data['tdqn_%d' % i][3][0])))])" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": { 267 | "scrolled": false 268 | }, 269 | "outputs": [], 270 | "source": [ 271 | "mountain_sparse_plot = get_best_to_plot(mountain_sparse_data, mountain_sparse_aucs)\n", 272 | "\n", 273 | "plotMultiple(mountain_sparse_plot, title='MountainCar-v0',\n", 274 | " ylim=[-200, -100], max_steps=200, xlim=[10**3, 10**6], smooth=11,\n", 275 | " savename='mcv0-sparse.pdf', rewyticks=[-190, -150, -110], lenyticks=[0, 75, 150])" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "list(sorted(mountain_sparse_aucs, key=lambda x: x[1], reverse=True))" 285 | ] 286 | }, 287 | { 
288 | "cell_type": "markdown", 289 | "metadata": {}, 290 | "source": [ 291 | "











" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": { 298 | "scrolled": false 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "moon_dense_data = {}\n", 303 | "moon_dense_alphas = {}\n", 304 | "moon_dense_aucs = []\n", 305 | "max_steps=10**6\n", 306 | "thresh=200\n", 307 | "\n", 308 | "for pairs in [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9)]:\n", 309 | " dar_fm_str = r'$dar_{' + '{}'.format(pairs[0] + 1) + '}^{' + '{}'.format(pairs[1] + 1) + '}$'\n", 310 | " moon_dense_alphas[dar_fm_str] = 1/5\n", 311 | " moon_dense_data[dar_fm_str] = load_dqn_data(\n", 312 | " '*', 'experiments/featurized_results/moon/dar_orig_%d_%d' % (pairs[0], pairs[1]), max_steps=max_steps,\n", 313 | " )\n", 314 | " moon_dense_aucs.append([dar_fm_str, np.trapz((moon_dense_data[dar_fm_str][0][0] + 250) / 500,\n", 315 | " x=(moon_dense_data[dar_fm_str][3][0]/max(\n", 316 | " moon_dense_data[dar_fm_str][3][0])))])\n", 317 | "\n", 318 | " \n", 319 | "moon_dense_data['dqn'] = load_dqn_data('*', 'experiments/featurized_results/moon/dqn', max_steps=max_steps,\n", 320 | " )\n", 321 | "moon_dense_aucs.append(['dqn', np.trapz((moon_dense_data['dqn'][0][0] + 250) / 500,\n", 322 | " x=(moon_dense_data['dqn'][3][0]/max(moon_dense_data['dqn'][3][0])))])\n", 323 | "\n", 324 | "\n", 325 | "for i in [2, 4, 6, 8, 10]:\n", 326 | " moon_dense_data['tqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/moon/tqn_%d' % i, max_steps=max_steps,\n", 327 | " )\n", 328 | " # compute normalized AUC\n", 329 | " moon_dense_aucs.append(['tqn_%d' % i, np.trapz((moon_dense_data['tqn_%d' % i][0][0] + 250) / 500,\n", 330 | " x=(moon_dense_data['tqn_%d' % i][3][0]/max(moon_dense_data['tqn_%d' % i][3][0])))])\n", 331 | "\n", 332 | " moon_dense_data['t-dqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/moon/t-dqn_%d' % i, max_steps=max_steps,\n", 333 | " )\n", 334 | " moon_dense_aucs.append(['t-dqn_%d' % i, np.trapz((moon_dense_data['t-dqn_%d' % i][0][0] + 250)/500,\n", 335 | " x=(moon_dense_data['t-dqn_%d' % i][3][0]/max(moon_dense_data['t-dqn_%d' % i][3][0])))])\n", 336 | "\n", 337 | " moon_dense_data['tdqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/moon/tdqn_%d' % i, max_steps=max_steps,\n", 338 | " )\n", 339 | " moon_dense_aucs.append(['tdqn_%d' % i, np.trapz((moon_dense_data['tdqn_%d' % i][0][0] + 250)/500,\n", 340 | " x=(moon_dense_data['tdqn_%d' % i][3][0]/max(moon_dense_data['tdqn_%d' % i][3][0])))])" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "moon_plot = get_best_to_plot(moon_dense_data, moon_dense_aucs)\n", 350 | "\n", 351 | "plotMultiple(moon_plot, title='LunarLander-v2', ylim=[-250, 200], max_steps=1000, xlim=[10**3, 10**6],\n", 352 | " smooth=11, savename='llv2-dense.pdf', rewyticks=[-250, 0, 200], lenyticks=[200, 500, 800])" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": null, 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "list(sorted(moon_dense_aucs, key=lambda x: x[1], reverse=True))" 362 | ] 363 | } 364 | ], 365 | "metadata": { 366 | "kernelspec": { 367 | "display_name": "Python 3", 368 | "language": "python", 369 | "name": "python3" 370 | }, 371 | "language_info": { 372 | "codemirror_mode": { 373 | "name": "ipython", 374 | "version": 3 375 | }, 376 | "file_extension": ".py", 377 | "mimetype": "text/x-python", 378 | "name": "python", 379 | "nbconvert_exporter": "python", 380 | 
"pygments_lexer": "ipython3", 381 | "version": "3.7.7" 382 | } 383 | }, 384 | "nbformat": 4, 385 | "nbformat_minor": 4 386 | } 387 | -------------------------------------------------------------------------------- /plot_tabular_results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "from utils.data_handling import *\n", 12 | "from utils.plotting import *" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "









\n", 20 | "\n", 21 | "# Cliff" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": { 28 | "scrolled": false 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "max_skip = 7\n", 33 | "episodes = 10_000\n", 34 | "max_steps = 100\n", 35 | "\n", 36 | "for exp_version in ['-1.0-linear', '-0.1-const', '-1.0-log']:\n", 37 | " print(exp_version[1:].upper())\n", 38 | " methods = ['q', 'sq']\n", 39 | " rews, lens, steps = load_data(\"experiments/tabular_results/cliff\", methods, exp_version,\n", 40 | " episodes, max_skip, max_steps, local=True)\n", 41 | " for m in methods:\n", 42 | " print(len(rews[m]))\n", 43 | " title = '{:s}'.format(\"Cliff\")\n", 44 | " plot(methods, rews, lens, steps, title, episodes, logrewy=False,\n", 45 | " logleny=False, logx=True, annotate=True, savefig=\"cliff{:s}.pdf\".format(exp_version.replace('.', '_')),\n", 46 | " individual=False)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "









\n", 54 | "\n", 55 | "# Bridge" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 62 | "scrolled": false 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "max_skip = 7\n", 67 | "episodes = 10_000\n", 68 | "max_steps = 100\n", 69 | "\n", 70 | "for exp_version in ['-1.0-linear', '-0.1-const', '-1.0-log']:\n", 71 | " print(exp_version[1:].upper())\n", 72 | " methods = ['q', 'sq']\n", 73 | " rews, lens, steps = load_data(\"experiments/tabular_results/bridge\", methods, exp_version,\n", 74 | " episodes, max_skip, max_steps, local=True)\n", 75 | " for m in methods:\n", 76 | " print(len(rews[m]))\n", 77 | " title = '{:s}'.format(\"Bridge\")\n", 78 | " plot(methods, rews, lens, steps, title, episodes, annotate=True,\n", 79 | " logrewy=False, logleny=False, savefig=\"bridge{:s}.pdf\".format(exp_version.replace('.', '_')))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "









\n", 87 | "\n", 88 | "# ZigZag" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": { 95 | "scrolled": false 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "max_skip = 7\n", 100 | "episodes = 10_000\n", 101 | "max_steps = 100\n", 102 | "\n", 103 | "for exp_version in ['-1.0-linear', '-0.1-const', '-1.0-log']:\n", 104 | " print(exp_version[1:].upper())\n", 105 | " methods = ['q', 'sq']\n", 106 | " rews, lens, steps = load_data(\"experiments/tabular_results/zigzag\", methods, exp_version,\n", 107 | " episodes, max_skip, max_steps, local=True)\n", 108 | " for m in methods:\n", 109 | " print(len(rews[m]))\n", 110 | " title = '{:s}'.format(\"ZigZag\")\n", 111 | " plot(methods, rews, lens, steps, title, episodes, annotate=True,\n", 112 | " logrewy=False, logleny=False, savefig=\"zigzag{:s}.pdf\".format(exp_version.replace('.', '_')))" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "









\n", 120 | "\n", 121 | " \n", 122 | "# Influence of skip-lenth on tempoRL" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "## ZigZag - Linear" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": { 136 | "scrolled": false 137 | }, 138 | "outputs": [], 139 | "source": [ 140 | "max_skip = 2\n", 141 | "episodes = 10_000\n", 142 | "max_steps = 100\n", 143 | "\n", 144 | "for exp_version in ['-1.0-linear']:\n", 145 | " print(exp_version[1:].upper())\n", 146 | " for max_skip in range(2, 17):\n", 147 | " rews, lens, steps = load_data(\"experiments/tabular_results/j_ablation/zigzag\", ['sq'], exp_version,\n", 148 | " episodes, max_skip, max_steps, local=True)\n", 149 | " # the q we compare to is the same as in standard zigzag. max_skip does not influence q\n", 150 | " rews_, lens_, steps_ = load_data(\"experiments/tabular_results/zigzag\", ['q'], exp_version,\n", 151 | " episodes, 7, max_steps, local=True)\n", 152 | " rews.update(rews_)\n", 153 | " lens.update(lens_)\n", 154 | " steps.update(steps_)\n", 155 | " for m in methods:\n", 156 | " print(len(rews[m]))\n", 157 | " title = '{:s} - {:d}'.format(\"ZigZag\", max_skip)\n", 158 | " plot(['sq', 'q'], rews, lens, steps, title, episodes, annotate=False,\n", 159 | " logrewy=False, logleny=False, savefig=None)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "## ZigZag - Log" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": { 173 | "scrolled": false 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "max_skip = 2\n", 178 | "episodes = 10_000\n", 179 | "max_steps = 100\n", 180 | "\n", 181 | "for exp_version in ['-1.0-log']:\n", 182 | " print(exp_version[1:].upper())\n", 183 | " for max_skip in range(2, 17):\n", 184 | " rews, lens, steps = load_data(\"experiments/tabular_results/j_ablation/zigzag\", ['sq'], exp_version,\n", 185 | " episodes, max_skip, max_steps, local=True)\n", 186 | " rews_, lens_, steps_ = load_data(\"experiments/tabular_results/zigzag\", ['q'], exp_version,\n", 187 | " episodes, 7, max_steps, local=True)\n", 188 | " rews.update(rews_)\n", 189 | " lens.update(lens_)\n", 190 | " steps.update(steps_)\n", 191 | " for m in methods:\n", 192 | " print(len(rews[m]))\n", 193 | " title = '{:s} - {:d}'.format(\"ZigZag\", max_skip)\n", 194 | " plot(['sq', 'q'], rews, lens, steps, title, episodes, annotate=False,\n", 195 | " logrewy=False, logleny=False, savefig=None)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": {}, 201 | "source": [ 202 | "## ZigZag - Constant" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": { 209 | "scrolled": false 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "max_skip = 2\n", 214 | "episodes = 10_000\n", 215 | "max_steps = 100\n", 216 | "\n", 217 | "for exp_version in ['-0.1-const']:\n", 218 | " print(exp_version[1:].upper())\n", 219 | " for max_skip in range(2, 17):\n", 220 | " rews, lens, steps = load_data(\"experiments/tabular_results/j_ablation/zigzag\", ['sq'], exp_version,\n", 221 | " episodes, max_skip, max_steps, local=True)\n", 222 | " rews_, lens_, steps_ = load_data(\"experiments/tabular_results/zigzag\", ['q'], exp_version,\n", 223 | " episodes, 7, max_steps, local=True)\n", 224 | " rews.update(rews_)\n", 225 | " lens.update(lens_)\n", 226 | " steps.update(steps_)\n", 227 | " for m in methods:\n", 228 | " 
print(len(rews[m]))\n", 229 | " title = '{:s} - {:d}'.format(\"ZigZag\", max_skip)\n", 230 | " plot(['sq', 'q'], rews, lens, steps, title, episodes, annotate=False,\n", 231 | " logrewy=False, logleny=False, savefig=None)" 232 | ] 233 | } 234 | ], 235 | "metadata": { 236 | "kernelspec": { 237 | "display_name": "Python 3", 238 | "language": "python", 239 | "name": "python3" 240 | }, 241 | "language_info": { 242 | "codemirror_mode": { 243 | "name": "ipython", 244 | "version": 3 245 | }, 246 | "file_extension": ".py", 247 | "mimetype": "text/x-python", 248 | "name": "python", 249 | "nbconvert_exporter": "python", 250 | "pygments_lexer": "ipython3", 251 | "version": "3.7.7" 252 | } 253 | }, 254 | "nbformat": 4, 255 | "nbformat_minor": 2 256 | } 257 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml~=5.4 2 | seaborn~=0.10.1 3 | matplotlib~=3.3.0 4 | gym~=0.17.2 5 | numpy~=1.19.1 6 | scipy~=1.5.2 7 | torchvision~=0.5.0 8 | pillow~=7.2.0 9 | ray[rllib]~=1.0.1 10 | opencv-python~=4.4.0.46 11 | future~=0.18.2 12 | torch~=1.4.0 13 | -------------------------------------------------------------------------------- /run_atari_experiments.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import numpy as np 5 | import gym 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | from itertools import count 12 | from collections import namedtuple 13 | import time 14 | from mountain_car import MountainCarEnv 15 | from utils import experiments 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | device = torch.device('cpu') 19 | 20 | 21 | def tt(ndarray): 22 | """ 23 | Helper Function to cast observation to correct type/device 24 | """ 25 | if device.type == "cuda": 26 | return Variable(torch.from_numpy(ndarray).float().cuda(), requires_grad=False) 27 | else: 28 | return Variable(torch.from_numpy(ndarray).float(), requires_grad=False) 29 | 30 | 31 | def soft_update(target, source, tau): 32 | """ 33 | Simple Helper for updating target-network parameters 34 | :param target: target network 35 | :param source: policy network 36 | :param tau: weight to regulate how strongly to update (1 -> copy over weights) 37 | """ 38 | for target_param, param in zip(target.parameters(), source.parameters()): 39 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 40 | 41 | 42 | def hard_update(target, source): 43 | """ 44 | See soft_update 45 | """ 46 | soft_update(target, source, 1.0) 47 | 48 | 49 | class NatureDQN(nn.Module): 50 | """ 51 | DQN following the DQN implementation from 52 | https://storage.googleapis.com/deepmind-data/assets/papers/DeepMindNature14236Paper.pdf 53 | """ 54 | 55 | def __init__(self, in_channels=4, num_actions=18): 56 | """ 57 | :param in_channels: number of channel of input. (how many stacked images are used) 58 | :param num_actions: action values 59 | """ 60 | super(NatureDQN, self).__init__() 61 | if env.observation_space.shape[-1] == 84: #hack 62 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4) 63 | elif env.observation_space.shape[-1] == 42: #hack 64 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=4, stride=2) 65 | else: 66 | raise ValueError("Check state space dimensionality. 
Expected nx42x42 or nx84x84. Was:", env.observation_space.shape) 67 | self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) 68 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) 69 | self.fc4 = nn.Linear(7 * 7 * 64, 512) 70 | self.fc5 = nn.Linear(512, num_actions) 71 | 72 | def forward(self, x): 73 | x = F.relu(self.conv1(x)) 74 | x = F.relu(self.conv2(x)) 75 | x = F.relu(self.conv3(x)) 76 | x = F.relu(self.fc4(x.reshape(x.size(0), -1))) 77 | return self.fc5(x) 78 | 79 | 80 | class NatureTQN(nn.Module): 81 | """ 82 | Network to learn the skip behaviour using the same architecture as the original DQN but with additional context. 83 | The context is expected to be the chosen behaviour action on which the skip-Q is conditioned. 84 | 85 | This Q function is expected to be used solely to learn the skip-Q function 86 | """ 87 | 88 | def __init__(self, in_channels=4, num_actions=18): 89 | """ 90 | :param in_channels: number of channel of input. (how many stacked images are used) 91 | :param num_actions: action values 92 | """ 93 | super(NatureTQN, self).__init__() 94 | if env.observation_space.shape[-1] == 84: # hack 95 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4) 96 | elif env.observation_space.shape[-1] == 42: # hack 97 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=4, stride=2) 98 | else: 99 | raise ValueError("Check state space dimensionality. Expected nx42x42 or nx84x84. Was:", 100 | env.observation_space.shape) 101 | self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) 102 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) 103 | 104 | self.skip = nn.Linear(1, 10) # Context layer 105 | 106 | self.fc4 = nn.Linear(7 * 7 * 64 + 10, 512) # Combination layer 107 | self.fc5 = nn.Linear(512, num_actions) # Output 108 | 109 | def forward(self, x, action_val=None): 110 | # Process input image 111 | x = F.relu(self.conv1(x)) 112 | x = F.relu(self.conv2(x)) 113 | x = F.relu(self.conv3(x)) 114 | 115 | # Process behaviour context 116 | x_ = F.relu(self.skip(action_val)) 117 | 118 | # Combine both streams 119 | x = F.relu(self.fc4( 120 | torch.cat([x.reshape(x.size(0), -1), x_], 1))) # This layer concatenates the context and CNN part 121 | return self.fc5(x) 122 | 123 | 124 | class NatureWeightsharingTQN(nn.Module): 125 | """ 126 | Network to learn the skip behaviour using the same architecture as the original DQN but with additional context. 127 | The context is expected to be the chosen behaviour action on which the skip-Q is conditioned. 128 | This implementation allows to share weights between the behaviour network and the skip network 129 | """ 130 | 131 | def __init__(self, in_channels=4, num_actions=18, num_skip_actions=10): 132 | """ 133 | :param in_channels: number of channel of input. (how many stacked images are used) 134 | :param num_actions: action values 135 | """ 136 | super(NatureWeightsharingTQN, self).__init__() 137 | # shared input-layers 138 | if env.observation_space.shape[-1] == 84: #hack 139 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4) 140 | elif env.observation_space.shape[-1] == 42: #hack 141 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=4, stride=2) 142 | else: 143 | raise ValueError("Check state space dimensionality. Expected nx42x42 or nx84x84. 
Was:", env.observation_space.shape) 144 | self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) 145 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) 146 | 147 | # skip-layers 148 | self.skip = nn.Linear(1, 10) # Context layer 149 | self.skip_fc4 = nn.Linear(7 * 7 * 64 + 10, 512) 150 | self.skip_fc5 = nn.Linear(512, num_skip_actions) 151 | 152 | # behaviour-layers 153 | self.action_fc4 = nn.Linear(7 * 7 * 64, 512) 154 | self.action_fc5 = nn.Linear(512, num_actions) 155 | 156 | def forward(self, x, action_val=None): 157 | # Process input image 158 | x = F.relu(self.conv1(x)) 159 | x = F.relu(self.conv2(x)) 160 | x = F.relu(self.conv3(x)) 161 | 162 | if action_val is not None: # Only if an action_value was provided we evaluate the skip output layers Q(s,j|a) 163 | x_ = F.relu(self.skip(action_val)) 164 | x = F.relu(self.skip_fc4( 165 | torch.cat([x.reshape(x.size(0), -1), x_], 1))) # This layer concatenates the context and CNN part 166 | return self.skip_fc5(x) 167 | else: # otherwise we simply continue as in standard DQN and compute Q(s,a) 168 | x = F.relu(self.action_fc4(x.reshape(x.size(0), -1))) 169 | return self.action_fc5(x) 170 | 171 | 172 | class Q(nn.Module): 173 | """ 174 | Simple fully connected Q function. Also used for skip-Q when concatenating behaviour action and state together. 175 | Used for simpler environments such as mountain-car or lunar-lander. 176 | """ 177 | 178 | def __init__(self, state_dim, action_dim, non_linearity=F.relu, hidden_dim=50): 179 | super(Q, self).__init__() 180 | self.fc1 = nn.Linear(state_dim, hidden_dim) 181 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 182 | self.fc3 = nn.Linear(hidden_dim, action_dim) 183 | self._non_linearity = non_linearity 184 | 185 | def forward(self, x): 186 | x = self._non_linearity(self.fc1(x)) 187 | x = self._non_linearity(self.fc2(x)) 188 | return self.fc3(x) 189 | 190 | 191 | class TQ(nn.Module): 192 | """ 193 | Q-Function that takes the behaviour action as context. 194 | This Q is solely inteded to be used for computing the skip-Q Q(s,j|a). 195 | 196 | Basically the same architecture as Q but with context input layer. 
197 | """ 198 | 199 | def __init__(self, state_dim, skip_dim, non_linearity=F.relu, hidden_dim=50): 200 | super(TQ, self).__init__() 201 | self.fc1 = nn.Linear(state_dim, hidden_dim) 202 | self.skip_fc2 = nn.Linear(1, 10) 203 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 204 | self.skip_fc3 = nn.Linear(hidden_dim + 10, skip_dim) # output layer taking context and state into account 205 | self._non_linearity = non_linearity 206 | 207 | def forward(self, x, a=None): 208 | # Process the input state 209 | x = self._non_linearity(self.fc1(x)) 210 | x = self._non_linearity(self.fc2(x)) 211 | 212 | # Process behaviour-action as context 213 | x_ = self._non_linearity(self.skip_fc2(a)) 214 | return self.skip_fc3(torch.cat([x, x_], -1)) # Concatenate both to produce the final output 215 | 216 | 217 | class WeightSharingTQ(nn.Module): 218 | """ 219 | Q-function with shared state representation but two independent output streams (action, skip) 220 | """ 221 | 222 | def __init__(self, state_dim, action_dim, skip_dim, non_linearity=F.relu, hidden_dim=50): 223 | super(WeightSharingTQ, self).__init__() 224 | self.fc1 = nn.Linear(state_dim, hidden_dim) 225 | self.skip_fc2 = nn.Linear(1, 10) 226 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 227 | self.action_fc3 = nn.Linear(hidden_dim, action_dim) 228 | self.skip_fc3 = nn.Linear(hidden_dim + 10, skip_dim) 229 | self._non_linearity = non_linearity 230 | 231 | def forward(self, x, a=None): 232 | # Process input state with shared layers 233 | x = self._non_linearity(self.fc1(x)) 234 | x = self._non_linearity(self.fc2(x)) 235 | 236 | if a is not None: # Only compute Skip Output if the behaviour action is given as context 237 | x_ = self._non_linearity(self.skip_fc2(a)) 238 | return self.skip_fc3(torch.cat([x, x_], -1)) 239 | 240 | # Only compute Behaviour output 241 | return self.action_fc3(x) 242 | 243 | 244 | class ReplayBuffer: 245 | """ 246 | Simple Replay Buffer. Used for standard DQN learning. 
247 | """ 248 | 249 | def __init__(self, max_size): 250 | self._data = namedtuple("ReplayBuffer", ["states", "actions", "next_states", "rewards", "terminal_flags"]) 251 | self._data = self._data(states=[], actions=[], next_states=[], rewards=[], terminal_flags=[]) 252 | self._size = 0 253 | self._max_size = max_size 254 | 255 | def add_transition(self, state, action, next_state, reward, done): 256 | self._data.states.append(state) 257 | self._data.actions.append(action) 258 | self._data.next_states.append(next_state) 259 | self._data.rewards.append(reward) 260 | self._data.terminal_flags.append(done) 261 | self._size += 1 262 | 263 | if self._size > self._max_size: 264 | self._data.states.pop(0) 265 | self._data.actions.pop(0) 266 | self._data.next_states.pop(0) 267 | self._data.rewards.pop(0) 268 | self._data.terminal_flags.pop(0) 269 | 270 | def random_next_batch(self, batch_size): 271 | batch_indices = np.random.choice(len(self._data.states), batch_size) 272 | batch_states = np.array([self._data.states[i] for i in batch_indices]) 273 | batch_actions = np.array([self._data.actions[i] for i in batch_indices]) 274 | batch_next_states = np.array([self._data.next_states[i] for i in batch_indices]) 275 | batch_rewards = np.array([self._data.rewards[i] for i in batch_indices]) 276 | batch_terminal_flags = np.array([self._data.terminal_flags[i] for i in batch_indices]) 277 | return tt(batch_states), tt(batch_actions), tt(batch_next_states), tt(batch_rewards), tt(batch_terminal_flags) 278 | 279 | 280 | class SkipReplayBuffer: 281 | """ 282 | Replay Buffer for training the skip-Q. 283 | Expects "concatenated states" which already contain the behaviour-action for the skip-Q. 284 | Stores transitions as usual but with additional skip-length. The skip-length is used to properly discount. 
285 | """ 286 | 287 | def __init__(self, max_size): 288 | self._data = namedtuple("ReplayBuffer", ["states", "actions", "next_states", 289 | "rewards", "terminal_flags", "lengths"]) 290 | self._data = self._data(states=[], actions=[], next_states=[], rewards=[], terminal_flags=[], lengths=[]) 291 | self._size = 0 292 | self._max_size = max_size 293 | 294 | def add_transition(self, state, action, next_state, reward, done, length): 295 | self._data.states.append(state) 296 | self._data.actions.append(action) 297 | self._data.next_states.append(next_state) 298 | self._data.rewards.append(reward) 299 | self._data.terminal_flags.append(done) 300 | self._data.lengths.append(length) # Observed skip-length of the transition 301 | self._size += 1 302 | 303 | if self._size > self._max_size: 304 | self._data.states.pop(0) 305 | self._data.actions.pop(0) 306 | self._data.next_states.pop(0) 307 | self._data.rewards.pop(0) 308 | self._data.terminal_flags.pop(0) 309 | self._data.lengths.pop(0) 310 | 311 | def random_next_batch(self, batch_size): 312 | batch_indices = np.random.choice(len(self._data.states), batch_size) 313 | batch_states = np.array([self._data.states[i] for i in batch_indices]) 314 | batch_actions = np.array([self._data.actions[i] for i in batch_indices]) 315 | batch_next_states = np.array([self._data.next_states[i] for i in batch_indices]) 316 | batch_rewards = np.array([self._data.rewards[i] for i in batch_indices]) 317 | batch_terminal_flags = np.array([self._data.terminal_flags[i] for i in batch_indices]) 318 | batch_lengths = np.array([self._data.lengths[i] for i in batch_indices]) 319 | return tt(batch_states), tt(batch_actions), tt(batch_next_states),\ 320 | tt(batch_rewards), tt(batch_terminal_flags), tt(batch_lengths) 321 | 322 | 323 | class NoneConcatSkipReplayBuffer: 324 | """ 325 | Replay Buffer for training the skip-Q. 326 | Expects states in which the behaviour-action is not siply concatenated for the skip-Q. 327 | Stores transitions as usual but with additional skip-length. The skip-length is used to properly discount. 328 | Additionally stores the behaviour_action which is the context for this skip-transition. 
329 | """ 330 | 331 | def __init__(self, max_size): 332 | self._data = namedtuple("ReplayBuffer", ["states", "actions", "next_states", 333 | "rewards", "terminal_flags", "lengths", "behaviour_action"]) 334 | self._data = self._data(states=[], actions=[], next_states=[], rewards=[], terminal_flags=[], lengths=[], 335 | behaviour_action=[]) 336 | self._size = 0 337 | self._max_size = max_size 338 | 339 | def add_transition(self, state, action, next_state, reward, done, length, behaviour): 340 | self._data.states.append(state) 341 | self._data.actions.append(action) 342 | self._data.next_states.append(next_state) 343 | self._data.rewards.append(reward) 344 | self._data.terminal_flags.append(done) 345 | self._data.lengths.append(length) # Observed skip-length 346 | self._data.behaviour_action.append(behaviour) # Behaviour action to condition skip on 347 | self._size += 1 348 | 349 | if self._size > self._max_size: 350 | self._data.states.pop(0) 351 | self._data.actions.pop(0) 352 | self._data.next_states.pop(0) 353 | self._data.rewards.pop(0) 354 | self._data.terminal_flags.pop(0) 355 | self._data.lengths.pop(0) 356 | self._data.behaviour_action.pop(0) 357 | 358 | def random_next_batch(self, batch_size): 359 | batch_indices = np.random.choice(len(self._data.states), batch_size) 360 | batch_states = np.array([self._data.states[i] for i in batch_indices]) 361 | batch_actions = np.array([self._data.actions[i] for i in batch_indices]) 362 | batch_next_states = np.array([self._data.next_states[i] for i in batch_indices]) 363 | batch_rewards = np.array([self._data.rewards[i] for i in batch_indices]) 364 | batch_terminal_flags = np.array([self._data.terminal_flags[i] for i in batch_indices]) 365 | batch_lengths = np.array([self._data.lengths[i] for i in batch_indices]) 366 | batch_behavoiurs = np.array([self._data.behaviour_action[i] for i in batch_indices]) 367 | return tt(batch_states), tt(batch_actions), tt(batch_next_states),\ 368 | tt(batch_rewards), tt(batch_terminal_flags), tt(batch_lengths), tt(batch_behavoiurs) 369 | 370 | 371 | class DQN: 372 | """ 373 | Simple double DQN Agent 374 | """ 375 | 376 | def __init__(self, state_dim: int, action_dim: int, gamma: float, 377 | env: gym.Env, eval_env: gym.Env, vision: bool = False): 378 | """ 379 | Initialize the DQN Agent 380 | :param state_dim: dimensionality of the input states 381 | :param action_dim: dimensionality of the output actions 382 | :param gamma: discount factor 383 | :param env: environment to train on 384 | :param eval_env: environment to evaluate on 385 | :param vision: boolean flag to indicate if the input state is an image or not 386 | """ 387 | if not vision: # For featurized states 388 | self._q = Q(state_dim, action_dim).to(device) 389 | self._q_target = Q(state_dim, action_dim).to(device) 390 | else: # For image states, i.e. 
Atari 391 | self._q = NatureDQN(state_dim, action_dim).to(device) 392 | self._q_target = NatureDQN(state_dim, action_dim).to(device) 393 | 394 | self._gamma = gamma 395 | 396 | self.batch_size = 32 397 | self.grad_clip_val = 40.0 398 | self.target_net_upd_freq = 500 399 | self.learning_starts = 10_000 400 | self.initial_epsilon = 1.0 401 | self.final_epsilon = 0.01 402 | self.epsilon_timesteps = 200_000 403 | self.train_freq = 4 404 | 405 | self._loss_function = nn.SmoothL1Loss() # huber loss # nn.MSELoss() 406 | self._q_optimizer = optim.Adam(self._q.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08) 407 | # self._q_optimizer = optim.RMSprop(self._q.parameters(), lr=0.00025, alpha=0.01, eps=1e-08, momentum=0.95) 408 | self._action_dim = action_dim 409 | 410 | self._replay_buffer = ReplayBuffer(5e4) 411 | self._env = env 412 | self._eval_env = eval_env 413 | 414 | def get_action(self, x: np.ndarray, epsilon: float) -> int: 415 | """ 416 | Simple helper to get action epsilon-greedy based on observation x 417 | """ 418 | u = np.argmax(self._q(tt(x[None, :])).cpu().detach().numpy()) 419 | r = np.random.uniform() 420 | if r < epsilon: 421 | return np.random.randint(self._action_dim) 422 | return u 423 | 424 | def train(self, episodes: int, max_env_time_steps: int, epsilon: float, eval_eps: int = 1, 425 | eval_every_n_steps: int = 1, max_train_time_steps: int = 1_000_000): 426 | """ 427 | Training loop 428 | :param episodes: maximum number of episodes to train for 429 | :param max_env_time_steps: maximum number of steps in the environment to perform per episode 430 | :param epsilon: constant epsilon for exploration when selecting actions 431 | :param eval_eps: numper of episodes to run for evaluation 432 | :param eval_every_n_steps: interval of steps after which to evaluate the trained agent 433 | :param max_train_time_steps: maximum number of steps to train 434 | :return: 435 | """ 436 | total_steps = 0 437 | num_update_steps = 0 438 | batch_size = self.batch_size 439 | grad_clip_val = self.grad_clip_val 440 | target_net_upd_freq = self.target_net_upd_freq 441 | learning_starts = self.learning_starts 442 | 443 | start_time = time.time() 444 | 445 | for e in range(episodes): 446 | print("# Episode: %s/%s" % (e + 1, episodes)) 447 | s = self._env.reset() 448 | 449 | for t in range(max_env_time_steps): 450 | # s = s 451 | # s = torch.from_numpy(s).unsqueeze(0) 452 | if total_steps > self.epsilon_timesteps: 453 | epsilon = self.final_epsilon 454 | else: 455 | epsilon = self.initial_epsilon - (self.initial_epsilon - self.final_epsilon) * ( 456 | total_steps / self.epsilon_timesteps) 457 | 458 | a = self.get_action(s, epsilon) 459 | ns, r, d, _ = self._env.step(a) 460 | total_steps += 1 461 | 462 | ########### Begin Evaluation 463 | if (total_steps % eval_every_n_steps) == 0: 464 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps) 465 | eval_stats = dict( 466 | elapsed_time=time.time() - start_time, 467 | training_steps=total_steps, 468 | training_eps=e, 469 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)), 470 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)), 471 | avg_rew_per_eval_ep=float(np.mean(eval_r)), 472 | std_rew_per_eval_ep=float(np.std(eval_r)), 473 | eval_eps=eval_eps 474 | ) 475 | 476 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh: 477 | json.dump(eval_stats, out_fh) 478 | out_fh.write('\n') 479 | ########### End Evaluation 480 | 481 | # Update replay buffer 482 | self._replay_buffer.add_transition(s, a, ns, r, d) 483 | 484 | batch_states, 
batch_actions, batch_next_states, batch_rewards, batch_terminal_flags = \ 485 | self._replay_buffer.random_next_batch(batch_size) 486 | 487 | ########### Begin double Q-learning update 488 | target = batch_rewards + (1 - batch_terminal_flags) * self._gamma * \ 489 | self._q_target(batch_next_states)[torch.arange(batch_size).long(), torch.argmax( 490 | self._q(batch_next_states), dim=1)] 491 | current_prediction = self._q(batch_states)[torch.arange(batch_size).long(), batch_actions.long()] 492 | 493 | loss = self._loss_function(current_prediction, target.detach()) 494 | 495 | 496 | if (total_steps > learning_starts) and (total_steps % self.train_freq == 0): 497 | num_update_steps += 1 498 | self._q_optimizer.zero_grad() 499 | loss.backward() 500 | for param in self._q.parameters(): 501 | param.grad.data.clamp_(-grad_clip_val, grad_clip_val) 502 | self._q_optimizer.step() 503 | 504 | 505 | if (total_steps % target_net_upd_freq) == 0: 506 | hard_update(self._q_target, self._q) 507 | # soft_update(self._q_target, self._q, 0.01) 508 | ########### End double Q-learning update 509 | if d: 510 | break 511 | s = ns 512 | if total_steps >= max_train_time_steps: 513 | break 514 | if total_steps >= max_train_time_steps: 515 | break 516 | 517 | # Final evaluation 518 | if (total_steps % eval_every_n_steps) != 0: 519 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps) 520 | eval_stats = dict( 521 | elapsed_time=time.time() - start_time, 522 | training_steps=total_steps, 523 | training_eps=e, 524 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)), 525 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)), 526 | avg_rew_per_eval_ep=float(np.mean(eval_r)), 527 | std_rew_per_eval_ep=float(np.std(eval_r)), 528 | eval_eps=eval_eps 529 | ) 530 | 531 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh: 532 | json.dump(eval_stats, out_fh) 533 | out_fh.write('\n') 534 | 535 | def eval(self, episodes: int, max_env_time_steps: int): 536 | """ 537 | Simple method that evaluates the agent with fixed epsilon = 0 538 | :param episodes: max number of episodes to play 539 | :param max_env_time_steps: max number of max_env_time_steps to play 540 | 541 | :returns (steps per episode), (reward per episode), (decisions per episode) 542 | """ 543 | steps, rewards, decisions = [], [], [] 544 | with torch.no_grad(): 545 | for e in range(episodes): 546 | ed, es, er = 0, 0, 0 547 | 548 | s = self._eval_env.reset() 549 | for _ in count(): 550 | a = self.get_action(s, 0) 551 | ed += 1 552 | 553 | ns, r, d, _ = self._eval_env.step(a) 554 | # print(r, d) 555 | er += r 556 | es += 1 557 | if es >= max_env_time_steps or d: 558 | break 559 | s = ns 560 | steps.append(es) 561 | rewards.append(er) 562 | decisions.append(ed) 563 | 564 | return steps, rewards, decisions 565 | 566 | def save_model(self, path): 567 | torch.save(self._q.state_dict(), os.path.join(path, 'Q')) 568 | 569 | 570 | class DAR: 571 | """ 572 | Simple Dynamic Action Repetition Agent based on double DQN 573 | """ 574 | 575 | def __init__(self, state_dim: int, action_dim: int, 576 | num_output_duplication: int, skip_map: dict, 577 | gamma: float, env: gym.Env, eval_env: gym.Env, vision=False): 578 | """ 579 | Initialize the DQN Agent 580 | :param state_dim: dimensionality of the input states 581 | :param action_dim: dimensionality of the output actions 582 | :param num_output_duplication: integer that determines how often to duplicate output heads (original is 2) 583 | :param skip_map: determines the skip value associated with each output 
head 584 | :param gamma: discount factor 585 | :param env: environment to train on 586 | :param eval_env: environment to evaluate on 587 | """ 588 | if not vision: # For featurized states 589 | self._q = Q(state_dim, action_dim * num_output_duplication).to(device) 590 | self._q_target = Q(state_dim, action_dim * num_output_duplication).to(device) 591 | else: # For image states, i.e. Atari 592 | self._q = NatureDQN(state_dim, action_dim * num_output_duplication).to(device) 593 | self._q_target = NatureDQN(state_dim, action_dim * num_output_duplication).to(device) 594 | 595 | self._gamma = gamma 596 | 597 | self.batch_size = 32 598 | self.grad_clip_val = 40.0 599 | self.target_net_upd_freq = 500 600 | self.learning_starts = 10_000 601 | self.initial_epsilon = 1.0 602 | self.final_epsilon = 0.01 603 | self.epsilon_timesteps = 200_000 604 | self.train_freq = 4 605 | 606 | self._loss_function = nn.SmoothL1Loss() # huber loss # nn.MSELoss() 607 | self._q_optimizer = optim.Adam(self._q.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08) 608 | self._action_dim = action_dim 609 | 610 | self._replay_buffer = ReplayBuffer(5e4) 611 | self._skip_map = skip_map 612 | self._dup_vals = num_output_duplication 613 | self._env = env 614 | self._eval_env = eval_env 615 | 616 | def get_action(self, x: np.ndarray, epsilon: float) -> int: 617 | """ 618 | Simple helper to get action epsilon-greedy based on observation x 619 | """ 620 | u = np.argmax(self._q(tt(x[None, :])).detach().numpy()) 621 | r = np.random.uniform() 622 | if r < epsilon: 623 | return np.random.randint(self._action_dim) 624 | return u 625 | 626 | def train(self, episodes: int, max_env_time_steps: int, epsilon: float, eval_eps: int = 1, 627 | eval_every_n_steps: int = 1, max_train_time_steps: int = 1_000_000): 628 | """ 629 | Training loop 630 | :param episodes: maximum number of episodes to train for 631 | :param max_env_time_steps: maximum number of steps in the environment to perform per episode 632 | :param epsilon: constant epsilon for exploration when selecting actions 633 | :param eval_eps: number of episodes to run for evaluation 634 | :param eval_every_n_steps: interval of steps after which to evaluate the trained agent 635 | :param max_train_time_steps: maximum number of steps to train 636 | """ 637 | total_steps = 0 638 | num_update_steps = 0 639 | batch_size = self.batch_size 640 | grad_clip_val = self.grad_clip_val 641 | target_net_upd_freq = self.target_net_upd_freq 642 | learning_starts = self.learning_starts 643 | 644 | start_time = time.time() 645 | for e in range(episodes): 646 | print("%s/%s" % (e + 1, episodes)) 647 | s = self._env.reset() 648 | es = 0 649 | for t in range(max_env_time_steps): 650 | if total_steps > self.epsilon_timesteps: 651 | epsilon = self.final_epsilon 652 | else: 653 | epsilon = self.initial_epsilon - (self.initial_epsilon - self.final_epsilon) * ( 654 | total_steps / self.epsilon_timesteps) 655 | a = self.get_action(s, epsilon) 656 | 657 | # convert action id into corresponding behaviour action and skip value 658 | act = a // self._dup_vals # behaviour 659 | rep = a // self._env.action_space.n # skip id 660 | skip = self._skip_map[rep] # skip id to corresponding skip value 661 | 662 | for _ in range(skip + 1): # repeat chosen behaviour action for "skip" steps 663 | ns, r, d, _ = self._env.step(act) 664 | total_steps += 1 665 | es += 1 666 | 667 | ########### Begin Evaluation 668 | if (total_steps % eval_every_n_steps) == 0: 669 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps) 670
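                        # Note: self.eval returns per-episode environment steps, rewards and decision counts;
                        # with action repetition the number of decisions is at most the number of steps,
                        # which is exactly what avg_num_decs_per_eval_ep below captures.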
| eval_stats = dict( 671 | elapsed_time=time.time() - start_time, 672 | training_steps=total_steps, 673 | training_eps=e, 674 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)), 675 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)), 676 | avg_rew_per_eval_ep=float(np.mean(eval_r)), 677 | std_rew_per_eval_ep=float(np.std(eval_r)), 678 | eval_eps=eval_eps 679 | ) 680 | 681 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh: 682 | json.dump(eval_stats, out_fh) 683 | out_fh.write('\n') 684 | ########### End Evaluation 685 | 686 | ### Q-update based double Q learning 687 | self._replay_buffer.add_transition(s, a, ns, r, d) 688 | batch_states, batch_actions, batch_next_states, batch_rewards, batch_terminal_flags = \ 689 | self._replay_buffer.random_next_batch(batch_size) 690 | 691 | target = batch_rewards + (1 - batch_terminal_flags) * self._gamma * \ 692 | self._q_target(batch_next_states)[torch.arange(batch_size).long(), torch.argmax( 693 | self._q(batch_next_states), dim=1)] 694 | current_prediction = self._q(batch_states)[torch.arange(batch_size).long(), batch_actions.long()] 695 | 696 | loss = self._loss_function(current_prediction, target.detach()) 697 | if (total_steps > learning_starts) and (total_steps % self.train_freq == 0): 698 | num_update_steps += 1 699 | self._q_optimizer.zero_grad() 700 | loss.backward() 701 | for param in self._q.parameters(): 702 | param.grad.data.clamp_(-grad_clip_val, grad_clip_val) 703 | self._q_optimizer.step() 704 | 705 | if (total_steps % target_net_upd_freq) == 0: 706 | hard_update(self._q_target, self._q) 707 | # soft_update(self._q_target, self._q, 0.01) 708 | ########### End double Q-learning update 709 | if es >= max_env_time_steps or d or total_steps >= max_train_time_steps: 710 | break 711 | 712 | s = ns 713 | if es >= max_env_time_steps or d or total_steps >= max_train_time_steps: 714 | break 715 | if total_steps >= max_train_time_steps: 716 | break 717 | 718 | # Final evaluation 719 | if (total_steps % eval_every_n_steps) != 0: 720 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps) 721 | eval_stats = dict( 722 | elapsed_time=time.time() - start_time, 723 | training_steps=total_steps, 724 | training_eps=e, 725 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)), 726 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)), 727 | avg_rew_per_eval_ep=float(np.mean(eval_r)), 728 | std_rew_per_eval_ep=float(np.std(eval_r)), 729 | eval_eps=eval_eps 730 | ) 731 | 732 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh: 733 | json.dump(eval_stats, out_fh) 734 | out_fh.write('\n') 735 | 736 | def eval(self, episodes: int, max_env_time_steps: int): 737 | """ 738 | Simple method that evaluates the agent with fixed epsilon = 0 739 | :param episodes: max number of episodes to play 740 | :param max_env_time_steps: max number of max_env_time_steps to play 741 | 742 | :returns (steps per episode), (reward per episode), (decisions per episode) 743 | """ 744 | steps, rewards, decisions = [], [], [] 745 | with torch.no_grad(): 746 | for e in range(episodes): 747 | ed, es, er = 0, 0, 0 748 | 749 | s = self._eval_env.reset() 750 | for _ in count(): 751 | # print(self._q(tt(s))) 752 | a = self.get_action(s, 0) 753 | act = a // self._dup_vals 754 | rep = a // self._eval_env.action_space.n 755 | skip = self._skip_map[rep] 756 | 757 | ed += 1 758 | 759 | d = False 760 | for _ in range(skip + 1): 761 | ns, r, d, _ = self._eval_env.step(act) 762 | er += r 763 | es += 1 764 | if es >= max_env_time_steps or d: 765 | break 766 | 
s = ns 767 | if es >= max_env_time_steps or d: 768 | break 769 | steps.append(es) 770 | rewards.append(er) 771 | decisions.append(ed) 772 | 773 | return steps, rewards, decisions 774 | 775 | def save_model(self, path): 776 | torch.save(self._q.state_dict(), os.path.join(path, 'Q')) 777 | 778 | 779 | class TDQN: 780 | """ 781 | TempoRL DQN agent capable of handling more complex state inputs through use of contextualized behaviour actions. 782 | """ 783 | 784 | def __init__(self, state_dim, action_dim, skip_dim, gamma, env, eval_env, vision=False, shared=True): 785 | """ 786 | Initialize the DQN Agent 787 | :param state_dim: dimensionality of the input states 788 | :param action_dim: dimensionality of the action output 789 | :param skip_dim: dimenionality of the skip output 790 | :param gamma: discount factor 791 | :param env: environment to train on 792 | :param eval_env: environment to evaluate on 793 | :param vision: boolean flag to indicate if the input state is an image or not 794 | :param shared: boolean flag to indicate if a weight sharing input representation is used or not. 795 | """ 796 | if not vision: 797 | if shared: 798 | self._q = WeightSharingTQ(state_dim, action_dim, skip_dim).to(device) 799 | self._q_target = WeightSharingTQ(state_dim, action_dim, skip_dim).to(device) 800 | else: 801 | self._q = Q(state_dim, action_dim).to(device) 802 | self._q_target = Q(state_dim, action_dim).to(device) 803 | else: 804 | if shared: 805 | self._q = NatureWeightsharingTQN(state_dim, action_dim, skip_dim).to(device) 806 | self._q_target = NatureWeightsharingTQN(state_dim, action_dim, skip_dim).to(device) 807 | else: 808 | self._q = NatureDQN(state_dim, action_dim).to(device) 809 | self._q_target = NatureDQN(state_dim, action_dim).to(device) 810 | 811 | if shared: 812 | self._skip_q = self._q 813 | else: 814 | if not vision: 815 | self._skip_q = TQ(state_dim, skip_dim).to(device) 816 | else: 817 | self._skip_q = NatureTQN(state_dim, skip_dim).to(device) 818 | print('Using {} as Q'.format(str(self._q))) 819 | print('Using {} as skip-Q\n{}'.format(str(self._skip_q), '#' * 80)) 820 | 821 | self._gamma = gamma 822 | self._action_dim = action_dim 823 | self._skip_dim = skip_dim 824 | 825 | self.batch_size = 32 826 | self.grad_clip_val = 40.0 827 | self.target_net_upd_freq = 500 828 | self.learning_starts = 10_000 829 | self.initial_epsilon = 1.0 830 | self.final_epsilon = 0.01 831 | self.epsilon_timesteps = 200_000 832 | self.train_freq = 4 833 | 834 | self._loss_function = nn.SmoothL1Loss() # huber loss # nn.MSELoss() 835 | self._skip_loss_function = nn.SmoothL1Loss() # nn.MSELoss() 836 | self._q_optimizer = optim.Adam(self._q.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08) 837 | self._skip_q_optimizer = optim.Adam(self._skip_q.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08) 838 | 839 | self._replay_buffer = ReplayBuffer(5e4) 840 | self._skip_replay_buffer = NoneConcatSkipReplayBuffer(5e4) 841 | self._env = env 842 | self._eval_env = eval_env 843 | 844 | def get_action(self, x: np.ndarray, epsilon: float) -> int: 845 | """ 846 | Simple helper to get action epsilon-greedy based on observation x 847 | """ 848 | u = np.argmax(self._q(tt(x[None, :])).cpu().detach().numpy()) 849 | r = np.random.uniform() 850 | if r < epsilon: 851 | return np.random.randint(self._action_dim) 852 | return u 853 | 854 | def get_skip(self, x: np.ndarray, a: np.ndarray, epsilon: float) -> int: 855 | """ 856 | Simple helper to get the skip epsilon-greedy based on observation x conditioned on behaviour action 
a 857 | """ 858 | u = np.argmax(self._skip_q(tt(x[None, :]), tt(a[None, :])).detach().numpy()) 859 | r = np.random.uniform() 860 | if r < epsilon: 861 | return np.random.randint(self._skip_dim) 862 | return u 863 | 864 | def eval(self, episodes: int, max_env_time_steps: int): 865 | """ 866 | Simple method that evaluates the agent with fixed epsilon = 0 867 | :param episodes: max number of episodes to play 868 | :param max_env_time_steps: max number of max_env_time_steps to play 869 | 870 | :returns (steps per episode), (reward per episode), (decisions per episode) 871 | """ 872 | steps, rewards, decisions = [], [], [] 873 | with torch.no_grad(): 874 | for e in range(episodes): 875 | ed, es, er = 0, 0, 0 876 | 877 | s = self._eval_env.reset() 878 | for _ in count(): 879 | a = self.get_action(s, 0) 880 | skip = self.get_skip(s, np.array([a]), 0) 881 | ed += 1 882 | 883 | d = False 884 | for _ in range(skip + 1): 885 | ns, r, d, _ = self._eval_env.step(a) 886 | er += r 887 | es += 1 888 | if es >= max_env_time_steps or d: 889 | break 890 | s = ns 891 | if es >= max_env_time_steps or d: 892 | break 893 | steps.append(es) 894 | rewards.append(er) 895 | decisions.append(ed) 896 | 897 | return steps, rewards, decisions 898 | 899 | def train(self, episodes: int, max_env_time_steps: int, epsilon: float, eval_eps: int = 1, 900 | eval_every_n_steps: int = 1, max_train_time_steps: int = 1_000_000): 901 | """ 902 | Training loop 903 | :param episodes: maximum number of episodes to train for 904 | :param max_env_time_steps: maximum number of steps in the environment to perform per episode 905 | :param epsilon: constant epsilon for exploration when selecting actions 906 | :param eval_eps: numper of episodes to run for evaluation 907 | :param eval_every_n_steps: interval of steps after which to evaluate the trained agent 908 | :param max_train_time_steps: maximum number of steps to train 909 | """ 910 | total_steps = 0 911 | num_update_steps = 0 912 | batch_size = self.batch_size 913 | grad_clip_val = self.grad_clip_val 914 | target_net_upd_freq = self.target_net_upd_freq 915 | learning_starts = self.learning_starts 916 | 917 | start_time = time.time() 918 | 919 | for e in range(episodes): 920 | print("# Episode: %s/%s" % (e + 1, episodes)) 921 | s = self._env.reset() 922 | es = 0 923 | for _ in count(): 924 | 925 | if total_steps > self.epsilon_timesteps: 926 | epsilon = self.final_epsilon 927 | else: 928 | epsilon = self.initial_epsilon - (self.initial_epsilon - self.final_epsilon) * ( 929 | total_steps / self.epsilon_timesteps) 930 | 931 | a = self.get_action(s, epsilon) 932 | skip = self.get_skip(s, np.array([a]), epsilon) # get skip with the selected action as context 933 | 934 | d = False 935 | skip_states, skip_rewards = [], [] 936 | for curr_skip in range(skip + 1): # repeat the selected action for "skip" times 937 | ns, r, d, info_ = self._env.step(a) 938 | total_steps += 1 939 | es += 1 940 | skip_states.append(s) # keep track of all observed skips 941 | skip_rewards.append(r) 942 | 943 | #### Begin Evaluation 944 | if (total_steps % eval_every_n_steps) == 0: 945 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps) 946 | eval_stats = dict( 947 | elapsed_time=time.time() - start_time, 948 | training_steps=total_steps, 949 | training_eps=e, 950 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)), 951 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)), 952 | avg_rew_per_eval_ep=float(np.mean(eval_r)), 953 | std_rew_per_eval_ep=float(np.std(eval_r)), 954 | eval_eps=eval_eps 955 | ) 956 | 
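                            # eval_scores.json is written as JSON Lines: one eval_stats record appended per
                            # evaluation. A minimal way to load it afterwards could be (sketch only; pandas
                            # is an assumption here, not a stated dependency of this repository):
                            #   import pandas as pd
                            #   scores = pd.read_json('eval_scores.json', lines=True)
                            #   scores.plot(x='training_steps', y='avg_rew_per_eval_ep')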
957 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh: 958 | json.dump(eval_stats, out_fh) 959 | out_fh.write('\n') 960 | #### End Evaluation 961 | 962 | # Update the skip replay buffer with all observed skips. 963 | # if curr_skip == (skip + 1): 964 | skip_id = 0 965 | for start_state in skip_states: 966 | skip_reward = 0 967 | for exp, r in enumerate(skip_rewards[skip_id:]): # make sure to properly discount 968 | skip_reward += np.power(self._gamma, exp) * r 969 | 970 | self._skip_replay_buffer.add_transition(start_state, curr_skip - skip_id, ns, 971 | skip_reward, d, curr_skip - skip_id + 1, 972 | np.array([a])) # also keep track of the behavior action 973 | skip_id += 1 974 | 975 | # Skip Q update based on double DQN where target is behavior Q 976 | batch_states, batch_actions, batch_next_states, batch_rewards,\ 977 | batch_terminal_flags, batch_lengths, batch_behaviours = \ 978 | self._skip_replay_buffer.random_next_batch(batch_size) 979 | 980 | target = batch_rewards + (1 - batch_terminal_flags) * np.power(self._gamma, batch_lengths) * \ 981 | self._q_target(batch_next_states)[torch.arange(batch_size).long(), torch.argmax( 982 | self._q(batch_next_states), dim=1)] 983 | current_prediction = self._skip_q(batch_states, batch_behaviours)[ 984 | torch.arange(batch_size).long(), batch_actions.long()] 985 | 986 | loss = self._skip_loss_function(current_prediction, target.detach()) 987 | 988 | if (total_steps > learning_starts) and (total_steps % self.train_freq == 0): 989 | num_update_steps += 1 990 | self._skip_q_optimizer.zero_grad() 991 | loss.backward() 992 | for param in self._skip_q.parameters(): 993 | if param.grad is None: 994 | pass 995 | # print("##### Skip Q Parameter with grad = None:", param.name) 996 | else: 997 | param.grad.data.clamp_(-grad_clip_val, grad_clip_val) 998 | self._skip_q_optimizer.step() 999 | 1000 | # Action Q update based on double DQN with normal target 1001 | self._replay_buffer.add_transition(s, a, ns, r, d) 1002 | batch_states, batch_actions, batch_next_states, batch_rewards, batch_terminal_flags = \ 1003 | self._replay_buffer.random_next_batch(batch_size) 1004 | 1005 | target = batch_rewards + (1 - batch_terminal_flags) * self._gamma * \ 1006 | self._q_target(batch_next_states)[torch.arange(batch_size).long(), torch.argmax( 1007 | self._q(batch_next_states), dim=1)] 1008 | current_prediction = self._q(batch_states)[torch.arange(batch_size).long(), batch_actions.long()] 1009 | 1010 | loss = self._loss_function(current_prediction, target.detach()) 1011 | 1012 | if (total_steps > learning_starts) and (total_steps % self.train_freq == 0): 1013 | self._q_optimizer.zero_grad() 1014 | loss.backward() 1015 | for param in self._q.parameters(): 1016 | if param.grad is None: 1017 | pass 1018 | # print("##### Q Parameter with grad = None:", param.name) 1019 | else: 1020 | param.grad.data.clamp_(-grad_clip_val, grad_clip_val) 1021 | self._q_optimizer.step() 1022 | 1023 | if (total_steps % target_net_upd_freq) == 0: 1024 | hard_update(self._q_target, self._q) 1025 | # soft_update(self._q_target, self._q, 0.01) 1026 | 1027 | if es >= max_env_time_steps or d or total_steps >= max_train_time_steps: 1028 | break 1029 | 1030 | s = ns 1031 | if es >= max_env_time_steps or d or total_steps >= max_train_time_steps: 1032 | break 1033 | if total_steps >= max_train_time_steps: 1034 | break 1035 | 1036 | # final evaluation 1037 | if (total_steps % eval_every_n_steps) != 0: 1038 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps) 1039 | 
eval_stats = dict( 1040 | elapsed_time=time.time() - start_time, 1041 | training_steps=total_steps, 1042 | training_eps=e, 1043 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)), 1044 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)), 1045 | avg_rew_per_eval_ep=float(np.mean(eval_r)), 1046 | std_rew_per_eval_ep=float(np.std(eval_r)), 1047 | eval_eps=eval_eps 1048 | ) 1049 | 1050 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh: 1051 | json.dump(eval_stats, out_fh) 1052 | out_fh.write('\n') 1053 | 1054 | def save_model(self, path): 1055 | torch.save(self._q.state_dict(), os.path.join(path, 'Q')) 1056 | torch.save(self._skip_q.state_dict(), os.path.join(path, 'TQ')) 1057 | 1058 | 1059 | if __name__ == "__main__": 1060 | import argparse 1061 | 1062 | outdir_suffix_dict = {'none': '', 'empty': '', 'time': '%Y_%m_%d_%H%M%S', 1063 | 'seed': '{:d}', 'params': '{:d}_{:d}_{:d}', 1064 | 'paramsseed': '{:d}_{:d}_{:d}_{:d}'} 1065 | parser = argparse.ArgumentParser('TempoRL') 1066 | parser.add_argument('--episodes', '-e', 1067 | default=100, 1068 | type=int, 1069 | help='Number of training episodes.') 1070 | parser.add_argument('--training-steps', '-t', 1071 | default=1_000_000, 1072 | type=int, 1073 | help='Number of training episodes.') 1074 | 1075 | parser.add_argument('--out-dir', 1076 | default=None, 1077 | type=str, 1078 | help='Directory to save results. Defaults to tmp dir.') 1079 | parser.add_argument('--out-dir-suffix', 1080 | default='paramsseed', 1081 | type=str, 1082 | choices=list(outdir_suffix_dict.keys()), 1083 | help='Created suffix of directory to save results.') 1084 | parser.add_argument('--seed', '-s', 1085 | default=12345, 1086 | type=int, 1087 | help='Seed') 1088 | parser.add_argument('--eval-after-n-steps', 1089 | default=10 ** 3, 1090 | type=int, 1091 | help='After how many steps to evaluate') 1092 | parser.add_argument('--eval-n-episodes', 1093 | default=1, 1094 | type=int, 1095 | help='How many episodes to evaluate') 1096 | 1097 | # DQN -> normal double DQN agent 1098 | # DAR -> Dynamic action repetition agent based on normal DDQN with repeated output heads for different skip values 1099 | # tqn -> Not usable with vision states 1100 | # tdqn -> TempoRL DDQN with shared state representation for behaviour and skip Qs 1101 | # t-dqn -> TempoRL DDQN without shared state representation for behaviour and skip Qs 1102 | parser.add_argument('--agent', 1103 | choices=['dqn', 'dar', 'tdqn', 't-dqn'], 1104 | type=str.lower, 1105 | help='Which agent to train', 1106 | default='tdqn') 1107 | parser.add_argument('--skip-net-max-skips', 1108 | type=int, 1109 | default=10, 1110 | help='Maximum skip-size') 1111 | parser.add_argument('--env-max-steps', 1112 | default=200, 1113 | type=int, 1114 | help='Maximal steps in environment before termination.', 1115 | dest='env_ms') 1116 | parser.add_argument('--dar-base', default=None, 1117 | type=int, 1118 | help='DAR base') 1119 | parser.add_argument('--sparse', action='store_true') 1120 | parser.add_argument('--no-frame-skip', action='store_true') 1121 | parser.add_argument('--84x84', action='store_true', dest='large_image') 1122 | parser.add_argument('--dar-A', default=None, type=int) 1123 | parser.add_argument('--dar-B', default=None, type=int) 1124 | parser.add_argument('--env', 1125 | type=str, 1126 | help="Possible envs = 'mountain', 'moon' or any Atari env", 1127 | default='mountain') 1128 | 1129 | # setup output dir 1130 | args = parser.parse_args() 1131 | torch.manual_seed(args.seed) 1132 | np.random.seed(args.seed) 
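    # With the default arguments (--episodes 100, --skip-net-max-skips 10, --env-max-steps 200, --seed 12345)
    # the 'paramsseed' entry assembled just below evaluates to "100_10_200_12345", which is handed to
    # experiments.prepare_output_dir as the suffix of the result directory.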
1133 | random.seed(args.seed) 1134 | outdir_suffix_dict['seed'] = outdir_suffix_dict['seed'].format(args.seed) 1135 | epis = args.episodes if args.episodes else -1 1136 | outdir_suffix_dict['params'] = outdir_suffix_dict['params'].format( 1137 | epis, args.skip_net_max_skips, args.env_ms) 1138 | outdir_suffix_dict['paramsseed'] = outdir_suffix_dict['paramsseed'].format( 1139 | epis, args.skip_net_max_skips, args.env_ms, args.seed) 1140 | 1141 | out_dir = experiments.prepare_output_dir(args, user_specified_dir=args.out_dir, 1142 | time_format=outdir_suffix_dict[args.out_dir_suffix]) 1143 | 1144 | if args.env not in ['mountain', 'moon']: 1145 | from utils.env_wrappers import make_env, make_env_old 1146 | 1147 | # Setup Envs 1148 | game = ''.join([g.capitalize() for g in args.env.split('_')]) 1149 | 1150 | if args.no_frame_skip: 1151 | eval_game = '{}NoFrameskip-v4'.format(game) 1152 | game = '{}NoFrameskip-v0'.format(game) 1153 | env = make_env_old(game, dim=84 if args.large_image else 42) 1154 | eval_env = make_env_old(eval_game, dim=84 if args.large_image else 42) 1155 | else: 1156 | eval_game = '{}Deterministic-v4'.format(game) 1157 | game = '{}Deterministic-v0'.format(game) 1158 | env = make_env(game, dim=84 if args.large_image else 42) 1159 | eval_env = make_env(eval_game, dim=84 if args.large_image else 42) 1160 | 1161 | # Setup Agent 1162 | state_dim = env.observation_space.shape[0] # (4, 42, 42) or (4, 84, 84) for PyTorch order 1163 | action_dim = env.action_space.n 1164 | if args.agent == 'dqn': 1165 | agent = DQN(state_dim, action_dim, gamma=0.99, env=env, eval_env=eval_env, vision=True) 1166 | elif args.agent == 'tdqn': 1167 | agent = TDQN(state_dim, action_dim, args.skip_net_max_skips, gamma=0.99, env=env, eval_env=eval_env, 1168 | vision=True) 1169 | elif args.agent == 't-dqn': 1170 | agent = TDQN(state_dim, action_dim, args.skip_net_max_skips, gamma=0.99, env=env, 1171 | eval_env=eval_env, shared=False, vision=True) 1172 | elif args.agent == 'dar': 1173 | if args.dar_A is not None and args.dar_B is not None: 1174 | skip_map = {0: args.dar_A, 1: args.dar_B} 1175 | elif args.dar_base: 1176 | skip_map = {a: args.dar_base ** a for a in range(args.skip_net_max_skips)} 1177 | else: 1178 | skip_map = {a: a for a in range(args.skip_net_max_skips)} 1179 | agent = DAR(state_dim, action_dim, args.skip_net_max_skips, skip_map, gamma=0.99, env=env, 1180 | eval_env=eval_env, vision=True) 1181 | else: # Simple featurized environments 1182 | # Setup Env 1183 | if args.env == 'mountain': 1184 | if args.sparse: 1185 | from gym.envs.classic_control import MountainCarEnv 1186 | env = MountainCarEnv() 1187 | eval_env = MountainCarEnv() 1188 | elif args.env == 'moon': 1189 | env = gym.make('LunarLander-v2') 1190 | eval_env = gym.make('LunarLander-v2') 1191 | 1192 | # Setup agent 1193 | state_dim = env.observation_space.shape[0] 1194 | action_dim = env.action_space.n 1195 | if args.agent == 'dqn': 1196 | agent = DQN(state_dim, action_dim, gamma=0.99, env=env, eval_env=eval_env) 1197 | elif args.agent == 'tdqn': 1198 | agent = TDQN(state_dim, action_dim, args.skip_net_max_skips, gamma=0.99, env=env, eval_env=eval_env) 1199 | elif args.agent == 't-dqn': 1200 | agent = TDQN(state_dim, action_dim, args.skip_net_max_skips, gamma=0.99, env=env, 1201 | eval_env=eval_env, shared=False) 1202 | elif args.agent == 'dar': 1203 | if args.dar_A is not None and args.dar_B is not None: 1204 | skip_map = {0: args.dar_A, 1: args.dar_B} 1205 | elif args.dar_base: 1206 | skip_map = {a: args.dar_base ** a for a in 
range(args.skip_net_max_skips)} 1207 | else: 1208 | skip_map = {a: a for a in range(args.skip_net_max_skips)} 1209 | agent = DAR(state_dim, action_dim, args.skip_net_max_skips, skip_map, gamma=0.99, env=env, 1210 | eval_env=eval_env) 1211 | else: 1212 | raise NotImplementedError 1213 | 1214 | episodes = args.episodes 1215 | max_env_time_steps = args.env_ms 1216 | epsilon = 0.2 1217 | 1218 | agent.train(episodes, max_env_time_steps, epsilon, args.eval_n_episodes, args.eval_after_n_steps, 1219 | max_train_time_steps=args.training_steps) 1220 | os.mkdir(os.path.join(out_dir, 'final')) 1221 | agent.save_model(os.path.join(out_dir, 'final')) 1222 | -------------------------------------------------------------------------------- /run_ddpg_experiments.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on code from Scott Fujimoto https://github.com/sfujim/TD3 3 | We adapted the DDPG code he provides to allow for FiGAR and TempoRL variants 4 | This code is originally under the MIT license https://github.com/sfujim/TD3/blob/master/LICENSE 5 | """ 6 | 7 | import argparse 8 | 9 | import gym 10 | import numpy as np 11 | import torch 12 | 13 | from DDPG import utils 14 | from DDPG.FiGAR import DDPG as FiGARDDPG 15 | from DDPG.TempoRL import DDPG as TempoRLDDPG 16 | from DDPG.vanilla import DDPG 17 | from utils import experiments 18 | 19 | 20 | # Runs policy for X episodes and returns average reward 21 | # A fixed seed is used for the eval environment 22 | def eval_policy(policy, env_name, seed, eval_episodes=10, FiGAR=False, TempoRL=False): 23 | eval_env = gym.make(env_name) 24 | special = 'PendulumDecs-v0' == env_name 25 | if special: 26 | eval_env = utils.Render(eval_env, episode_modulo=10) 27 | eval_env.seed(seed + 100) 28 | 29 | avg_reward = 0. 30 | avg_steps = 0. 31 | avg_decs = 0. 32 | for _ in range(eval_episodes): 33 | state, done = eval_env.reset(), False 34 | repetition = 1 35 | while not done: 36 | if FiGAR: 37 | action, repetition, rps = policy.select_action(np.array(state)) 38 | repetition = repetition[0] + 1 39 | 40 | elif TempoRL: 41 | action = policy.select_action(np.array(state)) 42 | repetition = np.argmax(policy.select_skip(np.array(state), action)) + 1 43 | 44 | else: 45 | action = policy.select_action(np.array(state)) 46 | 47 | if special: 48 | eval_env.set_decision_point(True) 49 | avg_decs += 1 50 | 51 | for _ in range(repetition): 52 | state, reward, done, _ = eval_env.step(action) 53 | avg_reward += reward 54 | avg_steps += 1 55 | if done: 56 | break 57 | eval_env.close() 58 | 59 | avg_reward /= eval_episodes 60 | avg_decs /= eval_episodes 61 | avg_steps /= eval_episodes 62 | 63 | print("---------------------------------------") 64 | print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}") 65 | print("---------------------------------------") 66 | return avg_reward, avg_decs, avg_steps 67 | 68 | 69 | if __name__ == "__main__": 70 | 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument('--out-dir', 73 | default=None, 74 | type=str, 75 | help='Directory to save results. 
Defaults to tmp dir.') 76 | parser.add_argument("--policy", default="TempoRLDDPG") # Policy name (DDPG, FiGARDDPG or our TempoRLDDPG) 77 | parser.add_argument("--env", default="Pendulum-v0") # OpenAI gym environment name 78 | parser.add_argument("--seed", default=0, type=int) # Sets Gym, PyTorch and Numpy seeds 79 | parser.add_argument("--start_timesteps", default=1e3, type=int) # Time steps initial random policy is used 80 | parser.add_argument("--eval_freq", default=500, type=int) # How often (time steps) we evaluate 81 | parser.add_argument("--max_timesteps", default=2e4, type=int) # Max time steps to run environment 82 | parser.add_argument("--expl_noise", default=0.1) # Std of Gaussian exploration noise 83 | parser.add_argument("--batch_size", default=256, type=int) # Batch size for both actor and critic 84 | parser.add_argument("--discount", default=0.99) # Discount factor 85 | parser.add_argument("--tau", default=0.005) # Target network update rate 86 | parser.add_argument("--max-skip", "--max-rep", default=20, type=int, 87 | dest='max_rep') # Maximum Skip length to use with FiGAR or TempoRL 88 | parser.add_argument("--save_model", action="store_true") # Save model and optimizer parameters 89 | parser.add_argument("--load_model", default="") # Model load file name, "" doesn't load, "default" uses file_name 90 | args = parser.parse_args() 91 | 92 | outdir_suffix_dict = dict() 93 | outdir_suffix_dict['seed'] = '{:d}'.format(args.seed) 94 | out_dir = experiments.prepare_output_dir(args, user_specified_dir=args.out_dir, 95 | time_format=outdir_suffix_dict['seed']) 96 | 97 | file_name = f"{args.policy}_{args.env}_{args.seed}" 98 | print("---------------------------------------") 99 | print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}") 100 | print("---------------------------------------") 101 | 102 | env = gym.make(args.env) 103 | 104 | # Set seeds 105 | env.seed(args.seed) 106 | torch.manual_seed(args.seed) 107 | np.random.seed(args.seed) 108 | 109 | state_dim = env.observation_space.shape[0] 110 | action_dim = env.action_space.shape[0] 111 | max_action = float(env.action_space.high[0]) 112 | 113 | kwargs = { 114 | "state_dim": state_dim, 115 | "action_dim": action_dim, 116 | "max_action": max_action, 117 | "discount": args.discount, 118 | "tau": args.tau, 119 | } 120 | max_rep = args.max_rep 121 | 122 | # Initialize policy 123 | if args.policy == "DDPG": 124 | policy = DDPG(**kwargs) 125 | elif args.policy.startswith('FiGAR'): 126 | kwargs['repetition_dim'] = max_rep 127 | policy = FiGARDDPG(**kwargs) 128 | elif args.policy.startswith('TempoRL'): 129 | kwargs['skip_dim'] = max_rep 130 | policy = TempoRLDDPG(**kwargs) 131 | else: 132 | raise NotImplementedError 133 | 134 | if args.load_model != "": 135 | policy_file = args.load_model 136 | policy.load(f"{out_dir}/{policy_file}") 137 | 138 | skip_replay_buffer = None 139 | if 'FiGAR' in args.policy: 140 | replay_buffer = utils.FiGARReplayBuffer(state_dim, action_dim, rep_dim=max_rep) 141 | else: 142 | replay_buffer = utils.ReplayBuffer(state_dim, action_dim) 143 | if 'TempoRL' in args.policy: 144 | skip_replay_buffer = utils.FiGARReplayBuffer(state_dim, action_dim, rep_dim=1) 145 | 146 | # Evaluate untrained policy 147 | evaluations = [[0, *eval_policy(policy, args.env, args.seed, FiGAR='FiGAR' in args.policy, 148 | TempoRL='TempoRL' in args.policy)]] 149 | 150 | state, done = env.reset(), False 151 | episode_reward = 0 152 | episode_timesteps = 0 153 | episode_num = 0 154 | 155 | t = 0 156 | while t < 
int(args.max_timesteps): 157 | 158 | episode_timesteps += 1 159 | 160 | # Select action randomly or according to policy 161 | if t < args.start_timesteps: # Before learning starts we sample actions uniformly at random 162 | action = env.action_space.sample() 163 | if args.policy.startswith('FiGAR'): 164 | # FiGAR uses a second actor network to learn the repetition value so we have to create 165 | # initial distribution over the possible repetition values 166 | repetition_probs = np.random.random(max_rep) 167 | 168 | 169 | def softmax(x): 170 | """Compute softmax values for each set of scores in x.""" 171 | e_x = np.exp(x - np.max(x)) 172 | return e_x / e_x.sum() 173 | 174 | 175 | repetition_probs = softmax(repetition_probs) 176 | repetition = np.argmax(repetition_probs) 177 | elif args.policy.startswith('TempoRL'): 178 | # TempoRL uses a simple DQN for which we can simply sample from the possible skip values 179 | repetition = np.random.randint(max_rep) + 1 180 | else: 181 | repetition = 1 182 | else: 183 | # Get Action and skip values 184 | if 'FiGAR' in args.policy: 185 | # For FiGAR we treat the action policy exploration as in standard DDPG 186 | action, repetition, repetition_probs = policy.select_action(np.array(state)) 187 | action = ( 188 | action + np.random.normal(0, max_action * args.expl_noise, size=action_dim) 189 | ).clip(-max_action, max_action) 190 | # The Repetition policy however uses epsilon greedy exploration as described in the original paper 191 | # https://arxiv.org/pdf/1702.06054.pdf 192 | if np.random.random() < args.expl_noise: 193 | repetition = np.random.randint(max_rep) + 1 # + 1 since randint samples from [0, max_rep) 194 | else: 195 | repetition = repetition[0] 196 | elif 'TempoRL' in args.policy: 197 | # TempoRL does not interfere with the action policy and its exploration 198 | action = ( 199 | policy.select_action(np.array(state)) 200 | + np.random.normal(0, max_action * args.expl_noise, size=action_dim) 201 | ).clip(-max_action, max_action) 202 | 203 | # the skip policy uses epsilon greedy exploration for learning 204 | repetition = policy.select_skip(state, action) 205 | if np.random.random() < args.expl_noise: 206 | repetition = np.random.randint(max_rep) + 1 # + 1 since randint samples from [0, max_rep) 207 | else: 208 | repetition = np.argmax(repetition) + 1 # + 1 since indices start at 0 209 | else: 210 | # Standard DDPG 211 | action = ( 212 | policy.select_action(np.array(state)) 213 | + np.random.normal(0, max_action * args.expl_noise, size=action_dim) 214 | ).clip(-max_action, max_action) 215 | repetition = 1 # Never skip with vanilla DDPG 216 | 217 | # Perform action 218 | skip_states, skip_rewards = [], [] # only used for TempoRL to build the local connectedness graph 219 | for curr_skip in range(repetition): 220 | next_state, reward, done, _ = env.step(action) 221 | t += 1 222 | done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0 223 | skip_states.append(state) 224 | skip_rewards.append(reward) 225 | 226 | # Store data in replay buffer 227 | if 'FiGAR' in args.policy: 228 | # To train the second actor with FiGAR, we need to keep track of its output "repetition_probs" 229 | replay_buffer.add(state, action, repetition_probs, next_state, reward, done_bool) 230 | else: 231 | # Vanilla DDPG 232 | replay_buffer.add(state, action, next_state, reward, done_bool) 233 | # In addition to the normal replay_buffer 234 | # TempoRL uses a second replay buffer that is only used for training the skip network 235 | if 'TempoRL' in
args.policy: 236 | # Update the skip buffer with all observed transitions in the local connectedness graph 237 | skip_id = 0 238 | for start_state in skip_states: 239 | skip_reward = 0 240 | for exp, r in enumerate(skip_rewards[skip_id:]): 241 | skip_reward += np.power(policy.discount, exp) * r # make sure to properly discount rewards 242 | skip_replay_buffer.add(start_state, action, curr_skip - skip_id, next_state, skip_reward, done) 243 | skip_id += 1 244 | 245 | state = next_state 246 | episode_reward += reward 247 | 248 | # Train agent after collecting sufficient data 249 | if t >= args.start_timesteps: 250 | policy.train(replay_buffer, args.batch_size) 251 | if 'TempoRL' in args.policy: 252 | policy.train_skip(skip_replay_buffer, args.batch_size) 253 | 254 | if done: 255 | # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True 256 | print( 257 | f"Total T: {t + 1} Episode Num: {episode_num + 1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}") 258 | # Reset environment 259 | state, done = env.reset(), False 260 | episode_reward = 0 261 | episode_timesteps = 0 262 | episode_num += 1 263 | break 264 | 265 | # Evaluate episode 266 | if (t + 1) % args.eval_freq == 0: 267 | evaluations.append([t, *eval_policy(policy, args.env, args.seed, FiGAR='FiGAR' in args.policy, 268 | TempoRL='TempoRL' in args.policy)]) 269 | np.save(f"{out_dir}/{file_name}", evaluations) 270 | if args.save_model: 271 | policy.save(f"{out_dir}/{file_name}") 272 | -------------------------------------------------------------------------------- /run_tabular_experiments.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | from utils import experiments 7 | 8 | from grid_envs import GridCore 9 | 10 | 11 | def make_epsilon_greedy_policy(Q: defaultdict, epsilon: float, nA: int) -> callable: 12 | """ 13 | Creates an epsilon-greedy policy based on a given Q-function and epsilon. 14 | I.e. create weight vector from which actions get sampled. 15 | 16 | :param Q: tabular state-action lookup function 17 | :param epsilon: exploration factor 18 | :param nA: size of action space to consider for this policy 19 | """ 20 | 21 | def policy_fn(observation): 22 | policy = np.ones(nA) * epsilon / nA 23 | best_action = np.random.choice(np.flatnonzero( # random choice for tie-breaking only 24 | Q[observation] == Q[observation].max() 25 | )) 26 | policy[best_action] += (1 - epsilon) 27 | return policy 28 | 29 | return policy_fn 30 | 31 | 32 | def get_decay_schedule(start_val: float, decay_start: int, num_steps: int, type_: str): 33 | """ 34 | Create epsilon decay schedule 35 | 36 | :param start_val: Start decay from this value (i.e. 1) 37 | :param decay_start: number of iterations to start epsilon decay after 38 | :param num_steps: Total number of steps to decay over 39 | :param type_: Which strategy to use. 
Implemented choices: 'const', 'log', 'linear' 40 | :return: 41 | """ 42 | if type_ == 'const': 43 | return np.array([start_val for _ in range(num_steps)]) 44 | elif type_ == 'log': 45 | return np.hstack([[start_val for _ in range(decay_start)], 46 | np.logspace(np.log10(start_val), np.log10(0.000001), (num_steps - decay_start))]) 47 | elif type_ == 'linear': 48 | return np.hstack([[start_val for _ in range(decay_start)], 49 | np.linspace(start_val, 0, (num_steps - decay_start), endpoint=True)]) 50 | else: 51 | raise NotImplementedError 52 | 53 | 54 | def td_update(q: defaultdict, state: int, action: int, reward: float, next_state: int, gamma: float, alpha: float): 55 | """ Simple TD update rule """ 56 | # TD update 57 | best_next_action = np.random.choice(np.flatnonzero(q[next_state] == q[next_state].max())) # greedy best next 58 | td_target = reward + gamma * q[next_state][best_next_action] 59 | td_delta = td_target - q[state][action] 60 | return q[state][action] + alpha * td_delta 61 | 62 | 63 | def q_learning( 64 | environment: GridCore, 65 | num_episodes: int, 66 | discount_factor: float = 1.0, 67 | alpha: float = 0.5, 68 | epsilon: float = 0.1, 69 | epsilon_decay: str = 'const', 70 | decay_starts: int = 0, 71 | eval_every: int = 10, 72 | render_eval: bool = True): 73 | """ 74 | Vanilla tabular Q-learning algorithm 75 | :param environment: which environment to use 76 | :param num_episodes: number of episodes to train 77 | :param discount_factor: discount factor used in TD updates 78 | :param alpha: learning rate used in TD updates 79 | :param epsilon: exploration fraction (either constant or starting value for schedule) 80 | :param epsilon_decay: determine type of exploration (constant, linear/exponential decay schedule) 81 | :param decay_starts: After how many episodes epsilon decay starts 82 | :param eval_every: Number of episodes between evaluations 83 | :param render_eval: Flag to activate/deactivate rendering of evaluation runs 84 | :return: training and evaluation statistics (i.e. rewards and episode lengths) 85 | """ 86 | assert 0 <= discount_factor <= 1, 'Lambda should be in [0, 1]' 87 | assert 0 <= epsilon <= 1, 'epsilon has to be in [0, 1]' 88 | assert alpha > 0, 'Learning rate has to be positive' 89 | # The action-value function. 90 | # Nested dict that maps state -> (action -> action-value). 
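    # For illustration (hypothetical ids, 4-action grid world): Q[17] returns array([0., 0., 0., 0.])
    # the first time state 17 is queried, and single entries such as Q[17][2] are then overwritten by
    # the TD update in the rollout loop below.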
91 | Q = defaultdict(lambda: np.zeros(environment.action_space.n)) 92 | 93 | # Keeps track of episode lengths and rewards 94 | rewards = [] 95 | lens = [] 96 | test_rewards = [] 97 | test_lens = [] 98 | train_steps_list = [] 99 | test_steps_list = [] 100 | 101 | epsilon_schedule = get_decay_schedule(epsilon, decay_starts, num_episodes, epsilon_decay) 102 | for i_episode in range(num_episodes + 1): 103 | # print('#' * 100) 104 | epsilon = epsilon_schedule[min(i_episode, num_episodes - 1)] 105 | # The policy we're following 106 | policy = make_epsilon_greedy_policy(Q, epsilon, environment.action_space.n) 107 | policy_state = environment.reset() 108 | episode_length, cummulative_reward = 0, 0 109 | while True: # roll out episode 110 | policy_action = np.random.choice(list(range(environment.action_space.n)), p=policy(policy_state)) 111 | s_, policy_reward, policy_done, _ = environment.step(policy_action) 112 | cummulative_reward += policy_reward 113 | episode_length += 1 114 | 115 | Q[policy_state][policy_action] = td_update(Q, policy_state, policy_action, 116 | policy_reward, s_, discount_factor, alpha) 117 | 118 | if policy_done: 119 | break 120 | policy_state = s_ 121 | rewards.append(cummulative_reward) 122 | lens.append(episode_length) 123 | train_steps_list.append(environment.total_steps) 124 | 125 | # evaluation with greedy policy 126 | test_steps = 0 127 | if i_episode % eval_every == 0: 128 | policy_state = environment.reset() 129 | episode_length, cummulative_reward = 0, 0 130 | if render_eval: 131 | environment.render() 132 | while True: # roll out episode 133 | policy_action = np.random.choice(np.flatnonzero(Q[policy_state] == Q[policy_state].max())) 134 | environment.total_steps -= 1 # don't count evaluation steps 135 | s_, policy_reward, policy_done, _ = environment.step(policy_action) 136 | test_steps += 1 137 | if render_eval: 138 | environment.render() 139 | s_ = s_ 140 | cummulative_reward += policy_reward 141 | episode_length += 1 142 | if policy_done: 143 | break 144 | policy_state = s_ 145 | test_rewards.append(cummulative_reward) 146 | test_lens.append(episode_length) 147 | test_steps_list.append(test_steps) 148 | print('Done %4d/%4d episodes' % (i_episode, num_episodes)) 149 | return (rewards, lens), (test_rewards, test_lens), (train_steps_list, test_steps_list), Q 150 | 151 | 152 | class SkipTransition: 153 | """ 154 | Simple helper class to keep track of all transitions observed when skipping through an MDP 155 | """ 156 | 157 | def __init__(self, skips, df): 158 | self.state_mat = np.full((skips, skips), -1, dtype=int) # might need to change type for other envs 159 | self.reward_mat = np.full((skips, skips), np.nan, dtype=float) 160 | self.idx = 0 161 | self.df = df 162 | 163 | def add(self, reward, next_state): 164 | """ 165 | Add reward and next_state to triangular matrix 166 | :param reward: received reward 167 | :param next_state: state reached 168 | """ 169 | self.idx += 1 170 | for i in range(self.idx): 171 | self.state_mat[self.idx - i - 1, i] = next_state 172 | # Automatically discount rewards when adding to corresponding skip 173 | self.reward_mat[self.idx - i - 1, i] = reward * self.df ** i + np.nansum(self.reward_mat[self.idx - i - 1]) 174 | 175 | 176 | def temporl_q_learning( 177 | environment: GridCore, 178 | num_episodes: int, 179 | discount_factor: float = 1.0, 180 | alpha: float = 0.5, 181 | epsilon: float = 0.1, 182 | epsilon_decay: str = 'const', 183 | decay_starts: int = 0, 184 | decay_stops: int = None, 185 | eval_every: int = 10, 186 | 
render_eval: bool = True, 187 | max_skip: int = 7): 188 | """ 189 | Implementation of tabular TempoRL 190 | :param environment: which environment to use 191 | :param num_episodes: number of episodes to train 192 | :param discount_factor: discount factor used in TD updates 193 | :param alpha: learning rate used in TD updates 194 | :param epsilon: exploration fraction (either constant or starting value for schedule) 195 | :param epsilon_decay: determine type of exploration (constant, linear/exponential decay schedule) 196 | :param decay_starts: After how many episodes epsilon decay starts 197 | :param decay_stops: Episode after which to stop epsilon decay 198 | :param eval_every: Number of episodes between evaluations 199 | :param render_eval: Flag to activate/deactivate rendering of evaluation runs 200 | :param max_skip: Maximum skip size to use. 201 | :return: training and evaluation statistics (i.e. rewards and episode lengths) 202 | """ 203 | temporal_actions = max_skip 204 | action_Q = defaultdict(lambda: np.zeros(environment.action_space.n)) 205 | temporal_Q = defaultdict(lambda: np.zeros(temporal_actions)) 206 | if not decay_stops: 207 | decay_stops = num_episodes 208 | 209 | epsilon_schedule_action = get_decay_schedule(epsilon, decay_starts, decay_stops, epsilon_decay) 210 | epsilon_schedule_temporal = get_decay_schedule(epsilon, decay_starts, decay_stops, epsilon_decay) 211 | rewards = [] 212 | lens = [] 213 | test_rewards = [] 214 | test_lens = [] 215 | train_steps_list = [] 216 | test_steps_list = [] 217 | for i_episode in range(num_episodes + 1): 218 | 219 | # setup exploration policy for this episode 220 | epsilon_action = epsilon_schedule_action[min(i_episode, num_episodes - 1)] 221 | epsilon_temporal = epsilon_schedule_temporal[min(i_episode, num_episodes - 1)] 222 | action_policy = make_epsilon_greedy_policy(action_Q, epsilon_action, environment.action_space.n) 223 | temporal_policy = make_epsilon_greedy_policy(temporal_Q, epsilon_temporal, temporal_actions) 224 | 225 | episode_r = 0 226 | state = environment.reset() # type: list 227 | action_pol_len = 0 228 | while True: # roll out episode 229 | action = np.random.choice(list(range(environment.action_space.n)), p=action_policy(state)) 230 | temporal_state = (state, action) 231 | action_pol_len += 1 232 | temporal_action = np.random.choice(list(range(temporal_actions)), p=temporal_policy(temporal_state)) 233 | 234 | s_ = None 235 | done = False 236 | tmp_state = state 237 | skip_transition = SkipTransition(temporal_action + 1, discount_factor) 238 | reward = 0 239 | for tmp_temporal_action in range(temporal_action + 1): 240 | if not done: 241 | # only perform action if we are not done. If we are not done "skipping" though we have to 242 | # still add reward and same state to the skip_transition. 
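                    # In other words: once the episode has terminated mid-skip, the remaining iterations do
                    # not step the environment again but re-add the last reward and reached state, so every
                    # skip length up to the sampled one still receives an entry in the SkipTransition matrices.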
243 | s_, reward, done, _ = environment.step(action) 244 | skip_transition.add(reward, tmp_state) 245 | 246 | # 1-step update of action Q (like in vanilla Q) 247 | action_Q[tmp_state][action] = td_update(action_Q, tmp_state, action, 248 | reward, s_, discount_factor, alpha) 249 | 250 | count = 0 251 | # For all sofar observed transitions compute all forward skip updates 252 | for skip_num in range(skip_transition.idx): 253 | skip = skip_transition.state_mat[skip_num] 254 | rew = skip_transition.reward_mat[skip_num] 255 | skip_start_state = (skip[0], action) 256 | 257 | # Temporal TD update 258 | best_next_action = np.random.choice( 259 | np.flatnonzero(action_Q[s_] == action_Q[s_].max())) # greedy best next 260 | td_target = rew[skip_transition.idx - 1 - count] + ( 261 | discount_factor ** (skip_transition.idx - 1)) * action_Q[s_][best_next_action] 262 | td_delta = td_target - temporal_Q[skip_start_state][skip_transition.idx - count - 1] 263 | temporal_Q[skip_start_state][skip_transition.idx - count - 1] += alpha * td_delta 264 | count += 1 265 | 266 | tmp_state = s_ 267 | state = s_ 268 | if done: 269 | break 270 | rewards.append(episode_r) 271 | lens.append(action_pol_len) 272 | train_steps_list.append(environment.total_steps) 273 | 274 | # ---------------------------------------------- EVALUATION ------------------------------------------------- 275 | # ---------------------------------------------- EVALUATION ------------------------------------------------- 276 | test_steps = 0 277 | if i_episode % eval_every == 0: 278 | episode_r = 0 279 | state = environment.reset() # type: list 280 | if render_eval: 281 | environment.render(in_control=True) 282 | action_pol_len = 0 283 | while True: # roll out episode 284 | action = np.random.choice(np.flatnonzero(action_Q[state] == action_Q[state].max())) 285 | temporal_state = (state, action) 286 | action_pol_len += 1 287 | 288 | # Examples of different action selection schemes when greedily following a policy 289 | # temporal_action = np.random.choice( 290 | # np.flatnonzero(temporal_Q[temporal_state] == temporal_Q[temporal_state].max())) 291 | temporal_action = np.max( # if there are ties use the larger action 292 | np.flatnonzero(temporal_Q[temporal_state] == temporal_Q[temporal_state].max())) 293 | # temporal_action = np.min( # if there are ties use the smaller action 294 | # np.flatnonzero(temporal_Q[temporal_state] == temporal_Q[temporal_state].max())) 295 | 296 | for i in range(temporal_action + 1): 297 | environment.total_steps -= 1 # don't count evaluation steps 298 | s_, reward, done, _ = environment.step(action) 299 | test_steps += 1 300 | if render_eval: 301 | environment.render(in_control=False) 302 | episode_r += reward 303 | if done: 304 | break 305 | if render_eval: 306 | environment.render(in_control=True) 307 | state = s_ 308 | if done: 309 | break 310 | test_rewards.append(episode_r) 311 | test_lens.append(action_pol_len) 312 | test_steps_list.append(test_steps) 313 | print('Done %4d/%4d episodes' % (i_episode, num_episodes)) 314 | return (rewards, lens), (test_rewards, test_lens), (train_steps_list, test_steps_list), (action_Q, temporal_Q) 315 | 316 | 317 | if __name__ == '__main__': 318 | import argparse 319 | 320 | outdir_suffix_dict = {'none': '', 'empty': '', 'time': '%Y%m%dT%H%M%S.%f', 321 | 'seed': '{:d}', 'params': '{:d}_{:d}_{:d}', 322 | 'paramsseed': '{:d}_{:d}_{:d}_{:d}'} 323 | parser = argparse.ArgumentParser('Skip-MDP Tabular-Q') 324 | parser.add_argument('--episodes', '-e', 325 | default=10_000, 326 | type=int, 
317 | if __name__ == '__main__': 318 | import argparse 319 | 320 | outdir_suffix_dict = {'none': '', 'empty': '', 'time': '%Y%m%dT%H%M%S.%f', 321 | 'seed': '{:d}', 'params': '{:d}_{:d}_{:d}', 322 | 'paramsseed': '{:d}_{:d}_{:d}_{:d}'} 323 | parser = argparse.ArgumentParser('Skip-MDP Tabular-Q') 324 | parser.add_argument('--episodes', '-e', 325 | default=10_000, 326 | type=int, 327 | help='Number of training episodes') 328 | parser.add_argument('--out-dir', 329 | default=None, 330 | type=str, 331 | help='Directory to save results. Defaults to tmp dir.') 332 | parser.add_argument('--out-dir-suffix', 333 | default='paramsseed', 334 | type=str, 335 | choices=list(outdir_suffix_dict.keys()), 336 | help='Suffix to append to the output directory name.') 337 | parser.add_argument('--seed', '-s', 338 | default=12345, 339 | type=int, 340 | help='Seed') 341 | parser.add_argument('--env-max-steps', 342 | default=100, 343 | type=int, 344 | help='Maximum number of steps in the environment before termination.', 345 | dest='env_ms') 346 | parser.add_argument('--agent-eps-decay', 347 | default='linear', 348 | choices={'linear', 'log', 'const'}, 349 | help='Epsilon decay schedule', 350 | dest='agent_eps_d') 351 | parser.add_argument('--agent-eps', 352 | default=1.0, 353 | type=float, 354 | help='Epsilon value. Used as start value when decay is linear or log. Otherwise constant value.', 355 | dest='agent_eps') 356 | parser.add_argument('--agent', 357 | default='sq', 358 | choices={'sq', 'q'}, 359 | type=str.lower, 360 | help='Agent type to train') 361 | parser.add_argument('--env', 362 | default='lava', 363 | choices={'lava', 'lava2', 364 | 'lava_perc', 'lava2_perc', 365 | 'lava_ng', 'lava2_ng', 366 | 'lava3', 'lava3_perc', 'lava3_ng'}, 367 | type=str.lower, 368 | help='Environment to use') 369 | parser.add_argument('--eval-eps', 370 | default=100, 371 | type=int, 372 | help='After how many episodes to evaluate') 373 | parser.add_argument('--stochasticity', 374 | default=0, 375 | type=float, 376 | help='Probability that the selected action fails and one of the remaining 3 actions is executed instead') 377 | parser.add_argument('--no-render', 378 | action='store_true', 379 | help='Deactivate rendering of environment evaluation') 380 | parser.add_argument('--max-skips', 381 | type=int, 382 | default=7, 383 | help='Max skip size for TempoRL') 384 | 385 | # setup output dir 386 | args = parser.parse_args() 387 | outdir_suffix_dict['seed'] = outdir_suffix_dict['seed'].format(args.seed) 388 | outdir_suffix_dict['params'] = outdir_suffix_dict['params'].format( 389 | args.episodes, args.max_skips, args.env_ms) 390 | outdir_suffix_dict['paramsseed'] = outdir_suffix_dict['paramsseed'].format( 391 | args.episodes, args.max_skips, args.env_ms, args.seed) 392 | 393 | if not args.no_render: 394 | # Clear screen in ANSI terminal 395 | print('\033c') 396 | print('\x1bc') 397 | 398 | out_dir = experiments.prepare_output_dir(args, user_specified_dir=args.out_dir, 399 | time_format=outdir_suffix_dict[args.out_dir_suffix]) 400 | 401 | np.random.seed(args.seed) # seed numpy 402 | d = None 403 | 404 | if args.env.startswith('lava'): 405 | import gym 406 | from grid_envs import Bridge6x10Env, Pit6x10Env, ZigZag6x10, ZigZag6x10H 407 | 408 | perc = args.env.endswith('perc') 409 | ng = args.env.endswith('ng') 410 | if args.env.startswith('lava2'): 411 | d = Bridge6x10Env(max_steps=args.env_ms, percentage_reward=perc, no_goal_rew=ng, 412 | act_fail_prob=args.stochasticity, numpy_state=False) 413 | elif args.env.startswith('lava3'): 414 | d = ZigZag6x10(max_steps=args.env_ms, percentage_reward=perc, no_goal_rew=ng, goal=(5, 9), 415 | act_fail_prob=args.stochasticity, numpy_state=False) 416 | elif args.env.startswith('lava4'): 417 | d = ZigZag6x10H(max_steps=args.env_ms, percentage_reward=perc, no_goal_rew=ng, goal=(5, 9), 418 | act_fail_prob=args.stochasticity, numpy_state=False) 419 | else: 420 | d = Pit6x10Env(max_steps=args.env_ms,
percentage_reward=perc, no_goal_rew=ng, 421 | act_fail_prob=args.stochasticity, numpy_state=False) 422 | 423 | # setup agent 424 | if args.agent == 'sq': 425 | train_data, test_data, num_steps, (action_Q, t_Q) = temporl_q_learning(d, args.episodes, 426 | epsilon_decay=args.agent_eps_d, 427 | epsilon=args.agent_eps, 428 | discount_factor=.99, alpha=.5, 429 | eval_every=args.eval_eps, 430 | render_eval=not args.no_render, 431 | max_skip=args.max_skips) 432 | elif args.agent == 'q': 433 | train_data, test_data, num_steps, Q = q_learning(d, args.episodes, 434 | epsilon_decay=args.agent_eps_d, 435 | epsilon=args.agent_eps, 436 | discount_factor=.99, 437 | alpha=.5, eval_every=args.eval_eps, 438 | render_eval=not args.no_render) 439 | else: 440 | raise NotImplementedError 441 | 442 | # TODO save resulting Q-function for easy reuse 443 | with open(os.path.join(out_dir, 'train_data.pkl'), 'wb') as outfh: 444 | pickle.dump(train_data, outfh) 445 | with open(os.path.join(out_dir, 'test_data.pkl'), 'wb') as outfh: 446 | pickle.dump(test_data, outfh) 447 | with open(os.path.join(out_dir, 'steps_per_episode.pkl'), 'wb') as outfh: 448 | pickle.dump(num_steps, outfh) 449 | 450 | if args.agent == 'q': 451 | with open(os.path.join(out_dir, 'Q.pkl'), 'wb') as outfh: 452 | pickle.dump(dict(Q), outfh) 453 | elif args.agent == 'sq': 454 | with open(os.path.join(out_dir, 'Q.pkl'), 'wb') as outfh: 455 | pickle.dump(dict(action_Q), outfh) 456 | with open(os.path.join(out_dir, 'J.pkl'), 'wb') as outfh: 457 | pickle.dump(dict(t_Q), outfh) 458 | -------------------------------------------------------------------------------- /tabular_requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml~=5.4 2 | seaborn~=0.10.1 3 | matplotlib~=3.3.0 4 | gym~=0.17.2 5 | numpy~=1.19.1 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/automl/TempoRL/b8f4b0648489dbcc4895374df56a0d051379f2a8/utils/__init__.py -------------------------------------------------------------------------------- /utils/config: -------------------------------------------------------------------------------- 1 | data: 2 | local: "."
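# both 'local' and 'remote' are base directories that utils/data_handling.load_data prepends
# to its experiment_dir argument; "path_to_remote" is a placeholder to adapt for non-local results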
3 | remote: "path_to_remote" 4 | plotting: 5 | color_map: 6 | sq: 4 7 | q: 2 8 | dqn: 2 9 | dar: 1 10 | tqn: 4 11 | tdqn: 4 12 | t-dqn: 4 13 | DDPG: 5 14 | t-DDPG: 9 15 | f-DDPG: 7 16 | palette: "colorblind" 17 | seaborn: 18 | style: "darkgrid" 19 | context: 20 | context: "paper" 21 | font scale: 1 22 | font: "Arial" 23 | rc: 24 | grid.linewidth: 4 25 | axes.labelsize: 32 26 | axes.titlesize: 32 27 | legend.fontsize: 28 28 | lines.linewidth: 4 29 | xtick.labelsize: 32 30 | ytick.labelsize: 32 31 | rc2: 32 | grid.linewidth: 6 33 | axes.labelsize: 64 34 | axes.titlesize: 64 35 | legend.fontsize: 35 36 | lines.linewidth: 6 37 | xtick.labelsize: 62 38 | ytick.labelsize: 62 39 | -------------------------------------------------------------------------------- /utils/data_handling.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | import glob 4 | import pickle 5 | import numpy as np 6 | 7 | import json 8 | import pandas as pd 9 | 10 | 11 | def load_config(): 12 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config'), 'r') as ymlf: 13 | cfg = yaml.load(ymlf, Loader=yaml.FullLoader) 14 | return cfg 15 | 16 | 17 | def load_data(experiment_dir="experiments_01_28", methods=['sq', 'q'], exp_version='-1.0-linear', 18 | episodes=50_000, max_skip=6, max_steps=100, local=True, debug=False): 19 | cfg = load_config() 20 | method_test_rewards = {} 21 | method_test_lengths = {} 22 | method_steps_per_episodes = {} 23 | if not local: 24 | print('Loading from') 25 | print(os.path.join(cfg['data']['local' if local else 'remote'], experiment_dir)) 26 | for method in methods: 27 | print(method) 28 | files = glob.glob( 29 | os.path.join(cfg['data']['local' if local else 'remote'], experiment_dir, 30 | '{:s}-experiments{:s}'.format(method, exp_version), 31 | '{:d}_{:d}_{:d}_*', 'test_data.pkl' 32 | ).format( 33 | episodes, 34 | max_skip, 35 | max_steps 36 | )) 37 | test_rewards, test_lens, steps_per_eps = [], [], [] 38 | for file in files: 39 | if debug: 40 | print('Loading', file) 41 | with open(file, 'rb') as fh: 42 | data = pickle.load(fh) 43 | test_rewards.append(data[0]) 44 | test_lens.append(data[1]) 45 | try: 46 | with open(file.replace('test_data', 'steps_per_episode'), 'rb') as fh: 47 | data = pickle.load(fh) 48 | steps_per_eps.append(data[1]) 49 | except FileNotFoundError: 50 | print('No steps data found') 51 | 52 | method_test_rewards[method] = np.array(test_rewards) 53 | method_test_lengths[method] = np.array(test_lens) 54 | method_steps_per_episodes[method] = np.array(steps_per_eps) 55 | return method_test_rewards, method_test_lengths, method_steps_per_episodes 56 | 57 | 58 | def load_dqn_data(experiment_dir, method, max_steps=None, succ_threashold=None, debug=False): 59 | cfg = load_config() 60 | print(os.path.abspath(os.path.join(method, experiment_dir, 'eval_scores.json'))) 61 | files = glob.glob( 62 | os.path.abspath(os.path.join(method, experiment_dir, 'eval_scores.json')), 63 | recursive=True 64 | ) 65 | frames = [] 66 | max_len = 0 67 | succ_count = 0 68 | for file in sorted(files): 69 | if debug: 70 | print('Loading', file) 71 | data = [] 72 | with open(file, 'r') as fh: 73 | for line in fh: 74 | loaded = json.loads(line) 75 | data.append(loaded) 76 | if max_steps and loaded['training_steps'] >= max_steps: 77 | break 78 | frame = pd.DataFrame(data) 79 | max_len = max(max_len, frame.shape[0]) 80 | if succ_threashold: 81 | if loaded['avg_rew_per_eval_ep'] > succ_threashold: 82 | succ_count += 1 83 | 
frames.append(frame) 84 | rews, lens, decs, training_steps, training_eps = [], [], [], [], [] 85 | for frame in frames: 86 | for (list_, array) in [(rews, frame.avg_rew_per_eval_ep), (lens, frame.avg_num_steps_per_eval_ep), 87 | (decs, frame.avg_num_decs_per_eval_ep), (training_steps, frame.training_steps), 88 | (training_eps, frame.training_eps)]: 89 | data = np.full((max_len,), np.nan) 90 | data[:len(array)] = array 91 | list_.append(data) 92 | mean_r, std_r = np.nanmean(rews, axis=0), np.nanstd(rews, axis=0) 93 | mean_l, std_l = np.nanmean(lens, axis=0), np.nanstd(lens, axis=0) 94 | mean_d, std_d = np.nanmean(decs, axis=0), np.nanstd(decs, axis=0) 95 | mean_ts, std_ts = np.nanmean(training_steps, axis=0), np.nanstd(training_steps, axis=0) 96 | mean_te, std_te = np.nanmean(training_eps, axis=0), np.nanstd(training_eps, axis=0) 97 | if succ_threashold: 98 | try: 99 | print('\t {}/{} ({}\%) runs exceeded a final performance of {} after {} training steps'.format( 100 | succ_count, len(frames), succ_count / len(frames) * 100, succ_threashold, max_steps 101 | )) 102 | except ZeroDivisionError: 103 | pass 104 | if debug: 105 | print('#' * 80, '\n') 106 | 107 | return (mean_r, std_r), (mean_l, std_l), (mean_d, std_d), (mean_ts, std_ts), (mean_te, std_te) 108 | -------------------------------------------------------------------------------- /utils/env_wrappers.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import gym 3 | import gym.spaces 4 | import numpy as np 5 | import collections 6 | from ray.rllib.env.atari_wrappers import wrap_deepmind, WarpFrame 7 | 8 | 9 | class FireResetEnv(gym.Wrapper): 10 | def __init__(self, env=None): 11 | """For environments where the user need to press FIRE for the game to start.""" 12 | super(FireResetEnv, self).__init__(env) 13 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 14 | assert len(env.unwrapped.get_action_meanings()) >= 3 15 | 16 | def step(self, action): 17 | return self.env.step(action) 18 | 19 | def reset(self): 20 | self.env.reset() 21 | obs, _, done, _ = self.env.step(1) 22 | if done: 23 | self.env.reset() 24 | obs, _, done, _ = self.env.step(2) 25 | if done: 26 | self.env.reset() 27 | return obs 28 | 29 | 30 | class MaxAndSkipEnv(gym.Wrapper): 31 | def __init__(self, env=None, skip=4): 32 | """Return only every `skip`-th frame""" 33 | super(MaxAndSkipEnv, self).__init__(env) 34 | # most recent raw observations (for max pooling across time steps) 35 | self._obs_buffer = collections.deque(maxlen=2) 36 | self._skip = skip 37 | 38 | def step(self, action): 39 | total_reward = 0.0 40 | done = None 41 | for _ in range(self._skip): 42 | obs, reward, done, info = self.env.step(action) 43 | self._obs_buffer.append(obs) 44 | total_reward += reward 45 | if done: 46 | break 47 | max_frame = np.max(np.stack(self._obs_buffer), axis=0) 48 | return max_frame, total_reward, done, info 49 | 50 | def reset(self): 51 | """Clear past frame buffer and init. to first obs. 
from inner env.""" 52 | self._obs_buffer.clear() 53 | obs = self.env.reset() 54 | self._obs_buffer.append(obs) 55 | return obs 56 | 57 | 58 | class ProcessFrame84(gym.ObservationWrapper): 59 | def __init__(self, env=None): 60 | super(ProcessFrame84, self).__init__(env) 61 | self.observation_space = gym.spaces.Box( 62 | low=0, high=255, shape=(84, 84, 1), dtype=np.uint8) 63 | 64 | def observation(self, obs): 65 | return ProcessFrame84.process(obs) 66 | 67 | @staticmethod 68 | def process(frame): 69 | if frame.size == 210 * 160 * 3: 70 | img = np.reshape(frame, [210, 160, 3]).astype( 71 | np.float32) 72 | elif frame.size == 250 * 160 * 3: 73 | img = np.reshape(frame, [250, 160, 3]).astype( 74 | np.float32) 75 | else: 76 | assert False, "Unknown resolution." 77 | img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + \ 78 | img[:, :, 2] * 0.114 79 | resized_screen = cv2.resize( 80 | img, (84, 110), interpolation=cv2.INTER_AREA) 81 | x_t = resized_screen[18:102, :] 82 | x_t = np.reshape(x_t, [84, 84, 1]) 83 | return x_t.astype(np.uint8) 84 | 85 | 86 | class ImageToPyTorch(gym.ObservationWrapper): 87 | def __init__(self, env): 88 | super(ImageToPyTorch, self).__init__(env) 89 | old_shape = self.observation_space.shape 90 | new_shape = (old_shape[-1], old_shape[0], old_shape[1]) 91 | self.observation_space = gym.spaces.Box( 92 | low=0.0, high=1.0, shape=new_shape, dtype=np.float32) 93 | 94 | def observation(self, observation): 95 | return np.moveaxis(observation, 2, 0) 96 | 97 | 98 | class ScaledFloatFrame(gym.ObservationWrapper): 99 | def observation(self, obs): 100 | return np.array(obs).astype(np.float32) / 255.0 101 | 102 | 103 | class BufferWrapper(gym.ObservationWrapper): 104 | def __init__(self, env, n_steps, dtype=np.float32): 105 | super(BufferWrapper, self).__init__(env) 106 | self.dtype = dtype 107 | old_space = env.observation_space 108 | self.observation_space = gym.spaces.Box( 109 | old_space.low.repeat(n_steps, axis=0), 110 | old_space.high.repeat(n_steps, axis=0), dtype=dtype) 111 | 112 | def reset(self): 113 | self.buffer = np.zeros_like( 114 | self.observation_space.low, dtype=self.dtype) 115 | return self.observation(self.env.reset()) 116 | 117 | def observation(self, observation): 118 | self.buffer[:-1] = self.buffer[1:] 119 | self.buffer[-1] = observation 120 | return self.buffer 121 | 122 | 123 | def make_env(env_name, dim=42): 124 | env = gym.make(env_name) 125 | env = wrap_deepmind(env, dim=dim, framestack=True) # grayscales already 126 | env = ImageToPyTorch(env) 127 | return env 128 | 129 | def make_env_old(env_name, dim=42): 130 | env = gym.make(env_name) 131 | env = MaxAndSkipEnv(env, skip=1) 132 | env = FireResetEnv(env) 133 | # env = ProcessFrame84(env) 134 | env = WarpFrame(env, dim=dim) 135 | env = ImageToPyTorch(env) 136 | env = BufferWrapper(env, 4) # Stacks frames 137 | return ScaledFloatFrame(env) # Normalises observations 138 | -------------------------------------------------------------------------------- /utils/experiments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import os 5 | import sys 6 | import tempfile 7 | 8 | 9 | def prepare_output_dir(args, user_specified_dir=None, argv=None, 10 | time_format='%Y%m%dT%H%M%S.%f'): 11 | """ 12 | Largely a copy of chainerRLs prepare output dir 13 | See (https://github.com/chainer/chainerrl/blob/018a29132d77e5af0f92161250c72aba10c6ce29/chainerrl/experiments/prepare_output_dir.py) 14 | Prepare a directory for outputting 
training results. 15 | 16 | An output directory, which ends with the current datetime string, 17 | is created. Then the following infomation is saved into the directory: 18 | 19 | args.txt: command line arguments 20 | command.txt: command itself 21 | environ.txt: environmental variables 22 | 23 | Args: 24 | args (dict or argparse.Namespace): Arguments to save 25 | user_specified_dir (str or None): If str is specified, the output 26 | directory is created under that path. If not specified, it is 27 | created as a new temporary directory instead. 28 | argv (list or None): The list of command line arguments passed to a 29 | script. If not specified, sys.argv is used instead. 30 | time_format (str): Format used to represent the current datetime. The 31 | default format is the basic format of ISO 8601. 32 | Returns: 33 | Path of the output directory created by this function (str). 34 | """ 35 | time_str = datetime.datetime.now().strftime(time_format) 36 | if user_specified_dir is not None: 37 | if os.path.exists(user_specified_dir): 38 | if not os.path.isdir(user_specified_dir): 39 | raise RuntimeError( 40 | '{} is not a directory'.format(user_specified_dir)) 41 | outdir = os.path.join(user_specified_dir, time_str) 42 | if os.path.exists(outdir): 43 | raise RuntimeError('{} exists'.format(outdir)) 44 | else: 45 | os.makedirs(outdir) 46 | else: 47 | outdir = tempfile.mkdtemp(prefix=time_str) 48 | 49 | # Save all the arguments 50 | with open(os.path.join(outdir, 'args.txt'), 'w') as f: 51 | if isinstance(args, argparse.Namespace): 52 | args = vars(args) 53 | f.write(json.dumps(args)) 54 | 55 | # Save all the environment variables 56 | with open(os.path.join(outdir, 'environ.txt'), 'w') as f: 57 | f.write(json.dumps(dict(os.environ))) 58 | 59 | # Save the command 60 | with open(os.path.join(outdir, 'command.txt'), 'w') as f: 61 | if argv is None: 62 | argv = sys.argv 63 | f.write(' '.join(argv)) 64 | 65 | print('Results stored in {:s}'.format(os.path.abspath(outdir))) 66 | return outdir 67 | -------------------------------------------------------------------------------- /utils/plotting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | import seaborn as sb 4 | 5 | from utils.data_handling import load_config 6 | 7 | 8 | def get_colors(): 9 | cfg = load_config() 10 | sb.set_style(cfg['plotting']['seaborn']['style']) 11 | sb.set_context(cfg['plotting']['seaborn']['context']['context'], 12 | font_scale=cfg['plotting']['seaborn']['context']['font scale'], 13 | rc=cfg['plotting']['seaborn']['context']['rc']) 14 | colors = list(sb.color_palette(cfg['plotting']['palette'])) 15 | color_map = cfg['plotting']['color_map'] 16 | return colors, color_map 17 | 18 | 19 | def plot_lens(ax, methods, lens, steps_per_ep, title, episodes=50_000, 20 | log=False, logx=False, eval_steps=1): 21 | colors, color_map = get_colors() 22 | print('AVG pol length') 23 | mv = 0 24 | created_steps_leg = False 25 | for method in methods: 26 | if method != 'hq': 27 | m, s = lens[method].mean(axis=0), lens[method].std(axis=0) 28 | total_ones = np.ones(m.shape) * 100 29 | # print("{:>4s}: {:>5.2f}".format(method, 1 - np.sum(total_ones - m) / np.sum(total_ones))) 30 | print("{:>4s}: {:>5.2f}".format(method, np.mean(m))) 31 | ax.step(np.arange(1, m.shape[0] + 1) * eval_steps, m, where='post', 32 | c=colors[color_map[method]]) 33 | ax.fill_between( 34 | np.arange(1, m.shape[0] + 1) * eval_steps, m - s, m + s, alpha=0.25, step='post', 35 
| color=colors[color_map[method]]) 36 | if 'hq' not in methods: 37 | mv = max(mv, max(m) + max(m) * .05) 38 | else: 39 | raise NotImplementedError 40 | if steps_per_ep: 41 | m, s = steps_per_ep[method].mean(axis=0), steps_per_ep[method].std(axis=0) 42 | ax.step(np.arange(1, m.shape[0] + 1) * eval_steps, m, where='post', 43 | c=np.array(colors[color_map[method]]) * .75, ls=':') 44 | ax.fill_between( 45 | np.arange(1, m.shape[0] + 1) * eval_steps, m - s, m + s, alpha=0.125, step='post', 46 | color=np.array(colors[color_map[method]]) * .75) 47 | mv = max(mv, max(m)) 48 | if not created_steps_leg: 49 | ax.plot([-999, -999], [-999, -999], ls=':', c='k', label='all') 50 | ax.plot([-999, -999], [-999, -999], ls='-', c='k', label='dec') 51 | created_steps_leg = True 52 | if log: 53 | ax.set_ylim([1, max(mv, 100)]) 54 | ax.semilogy() 55 | else: 56 | ax.set_ylim([0, mv]) 57 | if logx: 58 | ax.set_ylim([1, max(mv, 100)]) 59 | ax.set_xlim([1, episodes * eval_steps]) 60 | ax.semilogx() 61 | else: 62 | ax.set_xlim([0, episodes * eval_steps]) 63 | ax.set_ylabel('#Steps') 64 | if steps_per_ep: 65 | ax.legend(loc='upper right', ncol=1, handlelength=.75) 66 | ax.set_ylabel('#Steps') 67 | ax.set_xlabel('#Episodes') 68 | ax.set_title(title) 69 | return ax 70 | 71 | 72 | def _annotate(ax, rewards, max_reward, eval_steps): 73 | qxy = ((np.where(rewards['q'].mean(axis=0) >= .5 * max_reward)[0])[0] * eval_steps, .5) 74 | sqvxy = ((np.where(rewards['sq'].mean(axis=0) >= .5 * max_reward)[0])[0] * eval_steps, .5) 75 | ax.annotate("", # '{:d}x speedup'.format(int(np.round(qxy[0]/sqvxy[0]))), 76 | xy=qxy, 77 | xycoords='data', xytext=sqvxy, textcoords='data', 78 | arrowprops=dict(arrowstyle="<->", color="0.", 79 | connectionstyle="arc3,rad=0.", lw=5, 80 | ), ) 81 | 82 | speedup = qxy[0] / sqvxy[0] 83 | qxy = (qxy[0], .5 * max_reward) 84 | sqvxy = (sqvxy[0], .25) 85 | ax.annotate(r'${:.2f}\times$ speedup'.format(speedup), 86 | xy=qxy, 87 | xycoords='data', xytext=sqvxy, textcoords='data', 88 | arrowprops=dict(arrowstyle="-", color="0.", 89 | connectionstyle="arc3,rad=0.", lw=0 90 | ), 91 | fontsize=22) 92 | 93 | try: 94 | qxy = ((np.where(rewards['q'].mean(axis=0) >= max_reward)[0])[0] * eval_steps, max_reward) 95 | sqvxy = ((np.where(rewards['sq'].mean(axis=0) >= max_reward)[0])[0] * eval_steps, max_reward) 96 | ax.annotate("", # '{:d}x speedup'.format(int(np.round(qxy[0]/sqvxy[0]))), 97 | xy=qxy, 98 | xycoords='data', xytext=sqvxy, textcoords='data', 99 | arrowprops=dict(arrowstyle="<->", color="0.", 100 | connectionstyle="arc3,rad=0.", lw=5, 101 | ), ) 102 | 103 | speedup = qxy[0] / sqvxy[0] 104 | qxy = (qxy[0], max_reward) 105 | sqvxy = (sqvxy[0], .75) 106 | ax.annotate(r'${:.2f}\times$ speedup'.format(speedup), 107 | xy=qxy, 108 | xycoords='data', xytext=sqvxy, textcoords='data', 109 | arrowprops=dict(arrowstyle="-", color="0.", 110 | connectionstyle="arc3,rad=0.", lw=0 111 | ), 112 | fontsize=22) 113 | except: 114 | pass 115 | 116 | 117 | def plot_rewards(ax, methods, rewards, title, episodes=50_000, 118 | xlabel='#Episodes', log=False, logx=False, annotate=False, eval_steps=1): 119 | colors, color_map = get_colors() 120 | print('AUC') 121 | min_m = np.inf 122 | max_m = -np.inf 123 | for method in methods: 124 | m, s = rewards[method].mean(axis=0), rewards[method].std(axis=0) 125 | # used for AUC computation 126 | m_, s_ = ((rewards[method] + 1) / 2).mean(axis=0), ((rewards[method] + 1) / 2).std(axis=0) 127 | min_m = min(min(m), min_m) 128 | max_m = max(max(m), max_m) 129 | total_ones = np.ones(m.shape) 130 | 
label = method 131 | if method == 'sqv3': 132 | label = "sn-$\mathcal{Q}$" 133 | elif method == 'sq': 134 | label = "t-$\mathcal{Q}$" 135 | label = label.replace('q', '$\mathcal{Q}$') 136 | label = r'{:s}'.format(label) 137 | print("{:>2s}: {:>5.2f}".format(method, 1 - np.sum(total_ones - m_) / np.sum(total_ones))) 138 | ax.step(np.arange(1, m.shape[0] + 1) * eval_steps, m, where='post', c=colors[color_map[method]], 139 | label=label) 140 | ax.fill_between(np.arange(1, m.shape[0] + 1) * eval_steps, m - s, m + s, alpha=0.25, step='post', 141 | color=colors[color_map[method]]) 142 | if annotate: 143 | _annotate(ax, rewards, max_m, eval_steps) 144 | if log: 145 | ax.set_ylim([.01, max_m]) 146 | ax.semilogy() 147 | else: 148 | ax.set_ylim([min(-1, min_m - .1), max(1, max_m + .1)]) 149 | if logx: 150 | ax.set_xlim([1, episodes * eval_steps]) 151 | ax.semilogx() 152 | else: 153 | ax.set_xlim([0, episodes * eval_steps]) 154 | ax.set_ylabel('Reward') 155 | ax.set_xlabel(xlabel) 156 | ax.set_title(title) 157 | ax.legend(ncol=1, loc='lower right', handlelength=.75) 158 | return ax 159 | 160 | 161 | def plot(methods, rewards, lens, steps_per_ep, title, episodes=50_000, 162 | show=True, savefig=None, logleny=True, 163 | logrewy=True, logx=False, annotate=False, eval_steps=1, horizontal=False, 164 | individual=False): 165 | _, _ = get_colors() 166 | if not individual: 167 | if horizontal: 168 | fig, ax = plt.subplots(1, 2, figsize=(32, 5), dpi=100) 169 | else: 170 | fig, ax = plt.subplots(2, figsize=(20, 10), dpi=100) 171 | ax[0] = plot_rewards(ax[0], methods, rewards, title, episodes, 172 | xlabel='#Episodes' if horizontal else '', 173 | log=logrewy, logx=logx, annotate=annotate, eval_steps=eval_steps) 174 | print() 175 | ax[1] = plot_lens(ax[1], methods, lens, steps_per_ep, '', episodes, log=logleny, 176 | logx=logx, eval_steps=eval_steps) 177 | if savefig: 178 | plt.savefig(savefig, dpi=100) 179 | if show: 180 | plt.show() 181 | else: 182 | try: 183 | name, suffix = savefig.split('.') 184 | except: 185 | name = savefig 186 | suffix = '.pdf' 187 | fig, ax = plt.subplots(1, figsize=(10, 4), dpi=100) 188 | ax = plot_rewards(ax, methods, rewards, '', episodes, 189 | xlabel='#Episodes', 190 | log=logrewy, logx=logx, annotate=annotate, eval_steps=eval_steps) 191 | plt.tight_layout() 192 | if savefig: 193 | plt.savefig(name + '_rewards' + '.' + suffix, dpi=100) 194 | if show: 195 | plt.show() 196 | fig, ax = plt.subplots(1, figsize=(10, 4), dpi=100) 197 | ax = plot_lens(ax, methods, lens, steps_per_ep, '', episodes, log=logleny, 198 | logx=logx, eval_steps=eval_steps) 199 | plt.tight_layout() 200 | if savefig: 201 | plt.savefig(name + '_lens' + '.' + suffix, dpi=100) 202 | if show: 203 | plt.show() 204 | --------------------------------------------------------------------------------
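Taken together, run_tabular_experiments.py writes train_data.pkl, test_data.pkl and steps_per_episode.pkl (plus Q.pkl/J.pkl) per run, utils/data_handling.py aggregates those pickles across seeds, and utils/plotting.py turns them into the reward and episode-length curves. The snippet below is a minimal, hypothetical usage sketch rather than part of the repository: it assumes runs for both agents ('sq' and 'q') were produced with --episodes 10000 --max-skips 7 --env-max-steps 100 --eval-eps 100, and that the per-seed output directories are grouped into '<method>-experiments-1.0-linear' folders inside an experiment directory called 'my_tabular_experiments', which is the layout load_data's glob expects; adjust the names and numbers to your own setup.

from utils.data_handling import load_data
from utils.plotting import plot

methods = ['sq', 'q']  # 'sq' = tabular TempoRL, 'q' = vanilla Q-learning
rewards, lens, steps = load_data(experiment_dir='my_tabular_experiments',  # hypothetical results folder
                                 methods=methods,
                                 exp_version='-1.0-linear',  # matches --agent-eps 1.0 with linear decay
                                 episodes=10_000, max_skip=7, max_steps=100,
                                 local=True)
plot(methods, rewards, lens, steps, 'ZigZag',
     episodes=rewards['q'].shape[1],  # number of evaluation points shown on the x-axis
     eval_steps=100,                  # evaluation interval used during training
     logx=True)

plot() opens a matplotlib window by default; pass show=False and savefig='curves.pdf' to write the figure to disk instead.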