├── .gitignore
├── DDPG
│   ├── FiGAR.py
│   ├── TempoRL.py
│   ├── __init__.py
│   ├── utils.py
│   └── vanilla.py
├── LICENSE
├── README.md
├── TempoRL_Appendix.pdf
├── experiments
│   └── .gitkeep
├── grid_envs.py
├── mountain_car.py
├── plot_atari_results.ipynb
├── plot_ddpg.ipynb
├── plot_featurized_results.ipynb
├── plot_tabular_results.ipynb
├── requirements.txt
├── run_atari_experiments.py
├── run_ddpg_experiments.py
├── run_featurized_experiments.py
├── run_tabular_experiments.py
├── tabular_requirements.txt
└── utils
    ├── __init__.py
    ├── config
    ├── data_handling.py
    ├── env_wrappers.py
    ├── experiments.py
    └── plotting.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pdf
2 | experiments/*
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | *.idea
10 | *.ipynb
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
--------------------------------------------------------------------------------
/DDPG/FiGAR.py:
--------------------------------------------------------------------------------
1 | """
2 | Adaptation of the vanilla DDPG code to allow for the FiGAR modification as presented
3 | in the FiGAR paper https://arxiv.org/pdf/1702.06054.pdf
4 | """
5 |
6 |
7 | import copy
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 |
14 | # We have to modify the Actor to also generate the repetition output
15 | class Actor(nn.Module):
16 | def __init__(self, state_dim, action_dim, max_action, rep_dim):
17 | super(Actor, self).__init__()
18 |
19 | self.l1 = nn.Linear(state_dim, 400)
20 | self.l2 = nn.Linear(400, 300)
21 | self.l3 = nn.Linear(300, action_dim)
22 |
23 | self.max_action = max_action
24 |
25 | self.l5 = nn.Linear(400, 300)
26 | self.l6 = nn.Linear(300, rep_dim)
27 |
28 | def forward(self, state):
29 | # As suggested by the FiGAR authors, the input layer is shared
30 | shared = F.relu(self.l1(state))
31 | a = F.relu(self.l2(shared))
32 |
33 | r = F.relu(self.l5(shared))
34 | return self.max_action * torch.tanh(self.l3(a)), F.log_softmax(self.l6(r), dim=1)
35 |
36 |
37 | # The Critic has to be modified to be able to accept the additional repetition output of the actor
38 | class Critic(nn.Module):
39 | def __init__(self, state_dim, action_dim, repetition_dim):
40 | super(Critic, self).__init__()
41 |
42 | self.l1 = nn.Linear(state_dim + action_dim + repetition_dim, 400)
43 | self.l2 = nn.Linear(400, 300)
44 | self.l3 = nn.Linear(300, 1)
45 |
46 | def forward(self, state, action, repetition):
47 | q = F.relu(self.l1(torch.cat([state, action, repetition], 1)))
48 | q = F.relu(self.l2(q))
49 | return self.l3(q)
50 |
51 |
52 | class DDPG(object):
53 | def __init__(self, state_dim, action_dim, max_action, repetition_dim, discount=0.99, tau=0.005):
54 | self.actor = Actor(state_dim, action_dim, max_action, repetition_dim).to(device)
55 | self.actor_target = copy.deepcopy(self.actor)
56 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters())
57 |
58 | self.critic = Critic(state_dim, action_dim, repetition_dim).to(device)
59 | self.critic_target = copy.deepcopy(self.critic)
60 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters())
61 |
62 | self.discount = discount
63 | self.tau = tau
64 |
65 | def select_action(self, state):
66 | # The select action method has to be adjusted to also sample from the repetition distribution
67 | state = torch.FloatTensor(state.reshape(1, -1)).to(device)
68 | action, repetition_prob = self.actor(state)
69 | repetition_dist = torch.distributions.Categorical(repetition_prob)
70 | repetition = repetition_dist.sample()
71 | return (action.cpu().data.numpy().flatten(),
72 | repetition.cpu().data.numpy().flatten(),
73 | repetition_prob.cpu().data.numpy().flatten())
74 |
75 | def train(self, replay_buffer, batch_size=100):
76 | # The train method has to be adapted to take changes to Actor and Critic into account
77 | # Sample replay buffer
78 | state, action, repetition, next_state, reward, not_done = replay_buffer.sample(batch_size)
79 |
80 | # Compute the target Q value
81 | target_Q = self.critic_target(next_state, *self.actor_target(next_state))
82 | target_Q = reward + (not_done * self.discount * target_Q).detach()
83 |
84 | # Get current Q estimate
85 | current_Q = self.critic(state, action, repetition)
86 |
87 | # Compute critic loss
88 | critic_loss = F.mse_loss(current_Q, target_Q)
89 |
90 | # Optimize the critic
91 | self.critic_optimizer.zero_grad()
92 | critic_loss.backward()
93 | self.critic_optimizer.step()
94 |
95 | # Compute actor loss
96 | actor_loss = -self.critic(state, *self.actor(state)).mean()
97 |
98 | # Optimize the actor
99 | self.actor_optimizer.zero_grad()
100 | actor_loss.backward()
101 | self.actor_optimizer.step()
102 |
103 | # Update the frozen target models
104 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
105 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
106 |
107 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
108 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
109 |
110 | def save(self, filename):
111 | torch.save(self.critic.state_dict(), filename + "_critic")
112 | torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
113 |
114 | torch.save(self.actor.state_dict(), filename + "_actor")
115 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")
116 |
117 | def load(self, filename):
118 | self.critic.load_state_dict(torch.load(filename + "_critic"))
119 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
120 | self.critic_target = copy.deepcopy(self.critic)
121 |
122 | self.actor.load_state_dict(torch.load(filename + "_actor"))
123 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
124 | self.actor_target = copy.deepcopy(self.actor)
125 |
--------------------------------------------------------------------------------
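The following is an illustrative sketch (not part of the repository) of how the three outputs of `select_action` could be used in an interaction loop together with the `FiGARReplayBuffer` from `DDPG/utils.py`. The environment, the repetition range and the hyperparameters are assumptions; the actual experiments are driven by `run_ddpg_experiments.py`.

```python
import gym
import numpy as np

from DDPG.FiGAR import DDPG as FiGARDDPG
from DDPG.utils import FiGARReplayBuffer

env = gym.make("Pendulum-v0")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
rep_dim = 16  # assumed number of available repetition lengths (repeat 1..16 times)

agent = FiGARDDPG(state_dim, action_dim, max_action, rep_dim)
buffer = FiGARReplayBuffer(state_dim, action_dim, rep_dim)

state, done = env.reset(), False
while not done:
    action, repetition, repetition_probs = agent.select_action(np.array(state))
    # Repeat the chosen action for the sampled number of steps (index 0 -> execute once).
    for _ in range(int(repetition[0]) + 1):
        next_state, reward, done, _ = env.step(action)
        # Store the full repetition output so the critic sees the same input as during training.
        buffer.add(state, action, repetition_probs, next_state, reward, float(done))
        state = next_state
        if done:
            break
    if buffer.size > 100:
        agent.train(buffer, batch_size=100)
```
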
/DDPG/TempoRL.py:
--------------------------------------------------------------------------------
1 | """
2 | Adaptation of the vanilla DDPG code to allow for TempoRL modification.
3 | """
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | # We use exactly the same Actor and Critic networks, and the same training method for both, as in the vanilla implementation
10 | from DDPG.vanilla import DDPG as VanillaDDPG
11 |
12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 |
14 |
15 | class Q(nn.Module):
16 | """
17 | Simple fully connected Q function. Also used for skip-Q when concatenating behaviour action and state together.
18 | Used for simpler environments such as mountain-car or lunar-lander.
19 | """
20 |
21 | def __init__(self, state_dim, action_dim, skip_dim, non_linearity=F.relu):
22 | super(Q, self).__init__()
23 | # We follow the architecture of the Actor and Critic networks in terms of depth and hidden units
24 | self.fc1 = nn.Linear(state_dim + action_dim, 400)
25 | self.fc2 = nn.Linear(400, 300)
26 | self.fc3 = nn.Linear(300, skip_dim)
27 | self._non_linearity = non_linearity
28 |
29 | def forward(self, x):
30 | x = self._non_linearity(self.fc1(x))
31 | x = self._non_linearity(self.fc2(x))
32 | return self.fc3(x)
33 |
34 |
35 | class DDPG(VanillaDDPG):
36 | def __init__(self, state_dim, action_dim, max_action, skip_dim, discount=0.99, tau=0.005):
37 | # We can fully reuse the vanilla DDPG and simply stack TempoRL on top
38 | super(DDPG, self).__init__(state_dim, action_dim, max_action, discount, tau)
39 |
40 | # Create Skip Q network
41 | self.skip_Q = Q(state_dim, action_dim, skip_dim).to(device)
42 | self.skip_optimizer = torch.optim.Adam(self.skip_Q.parameters())
43 |
44 | def select_skip(self, state, action):
45 | """
46 | Select the skip action.
47 | Has to be called after select_action
48 | """
49 | state = torch.FloatTensor(state.reshape(1, -1)).to(device)
50 | action = torch.FloatTensor(action.reshape(1, -1)).to(device)
51 | return self.skip_Q(torch.cat([state, action], 1)).cpu().data.numpy().flatten()
52 |
53 | def train_skip(self, replay_buffer, batch_size=100):
54 | """
55 | Train the skip network
56 | """
57 | # Sample replay buffer
58 | state, action, skip, next_state, reward, not_done = replay_buffer.sample(batch_size)
59 |
60 | # Compute the target Q value
61 | target_Q = self.critic_target(next_state, self.actor_target(next_state))
62 | target_Q = reward + (not_done * np.power(self.discount, skip + 1) * target_Q).detach()
63 |
64 | # Get current Q estimate
65 | current_Q = self.skip_Q(torch.cat([state, action], 1)).gather(1, skip.long())
66 |
67 | # Compute critic loss
68 | critic_loss = F.mse_loss(current_Q, target_Q)
69 |
70 | # Optimize the critic
71 | self.skip_optimizer.zero_grad()
72 | critic_loss.backward()
73 | self.skip_optimizer.step()
74 |
75 | def save(self, filename):
76 | super().save(filename)
77 |
78 | torch.save(self.skip_Q.state_dict(), filename + "_skip")
79 | torch.save(self.skip_optimizer.state_dict(), filename + "_skip_optimizer")
80 |
81 | def load(self, filename):
82 | super().load(filename)
83 |
84 | self.skip_Q.load_state_dict(torch.load(filename + "_skip"))
85 | self.skip_optimizer.load_state_dict(torch.load(filename + "_skip_optimizer"))
86 |
--------------------------------------------------------------------------------
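An illustrative rollout sketch (not part of the repository) of how the skip-Q could be combined with the underlying DDPG policy: the behaviour action is chosen first, a skip length is then selected epsilon-greedily via `select_skip`, and the action is repeated for `skip + 1` steps. The buffers, epsilon value and environment below are assumptions; `run_ddpg_experiments.py` implements the full training procedure.

```python
import gym
import numpy as np

from DDPG.TempoRL import DDPG as TempoRLDDPG
from DDPG.utils import ReplayBuffer, FiGARReplayBuffer

env = gym.make("Pendulum-v0")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
max_skip = 16  # assumed maximal number of action repetitions

agent = TempoRLDDPG(state_dim, action_dim, max_action, skip_dim=max_skip)
action_buffer = ReplayBuffer(state_dim, action_dim)        # trains actor and critic
skip_buffer = FiGARReplayBuffer(state_dim, action_dim, 1)  # reused here to store skip transitions

epsilon = 0.1
state, done = env.reset(), False
while not done:
    action = agent.select_action(np.array(state))
    # Epsilon-greedy skip selection on top of the deterministic behaviour action.
    if np.random.random() < epsilon:
        skip = np.random.randint(max_skip)
    else:
        skip = int(np.argmax(agent.select_skip(np.array(state), action)))

    skip_state, skip_reward = state, 0.0
    for t in range(skip + 1):
        next_state, reward, done, _ = env.step(action)
        action_buffer.add(state, action, next_state, reward, float(done))
        skip_reward += (agent.discount ** t) * reward  # discounted n-step reward for the skip-Q target
        state = next_state
        if done:
            break
    skip_buffer.add(skip_state, action, skip, next_state, skip_reward, float(done))

    if action_buffer.size > 100:
        agent.train(action_buffer)      # standard DDPG update
        agent.train_skip(skip_buffer)   # skip-Q update bootstraps with discount ** (skip + 1)
```
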
/DDPG/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/automl/TempoRL/b8f4b0648489dbcc4895374df56a0d051379f2a8/DDPG/__init__.py
--------------------------------------------------------------------------------
/DDPG/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | The ReplayBuffer was originally implemented by Scott Fujimoto https://github.com/sfujim/TD3/blob/master/utils.py
3 |
4 | We added a second replay buffer that can take repetitions/skips into account and created a version of the
5 | Pendulum-v0 environment that has a different rendering style to display whether actions were reactive or proactive
6 | """
7 |
8 | import numpy as np
9 | import torch
10 |
11 |
12 | class ReplayBuffer(object):
13 | def __init__(self, state_dim, action_dim, max_size=int(1e6)):
14 | self.max_size = max_size
15 | self.ptr = 0
16 | self.size = 0
17 |
18 | self.state = np.zeros((max_size, state_dim))
19 | self.action = np.zeros((max_size, action_dim))
20 | self.next_state = np.zeros((max_size, state_dim))
21 | self.reward = np.zeros((max_size, 1))
22 | self.not_done = np.zeros((max_size, 1))
23 |
24 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25 |
26 | def add(self, state, action, next_state, reward, done):
27 | self.state[self.ptr] = state
28 | self.action[self.ptr] = action
29 | self.next_state[self.ptr] = next_state
30 | self.reward[self.ptr] = reward
31 | self.not_done[self.ptr] = 1. - done
32 |
33 | self.ptr = (self.ptr + 1) % self.max_size
34 | self.size = min(self.size + 1, self.max_size)
35 |
36 | def sample(self, batch_size):
37 | ind = np.random.randint(0, self.size, size=batch_size)
38 |
39 | return (
40 | torch.FloatTensor(self.state[ind]).to(self.device),
41 | torch.FloatTensor(self.action[ind]).to(self.device),
42 | torch.FloatTensor(self.next_state[ind]).to(self.device),
43 | torch.FloatTensor(self.reward[ind]).to(self.device),
44 | torch.FloatTensor(self.not_done[ind]).to(self.device)
45 | )
46 |
47 |
48 | class FiGARReplayBuffer(object):
49 | def __init__(self, state_dim, action_dim, rep_dim, max_size=int(1e6)):
50 | self.max_size = max_size
51 | self.ptr = 0
52 | self.size = 0
53 |
54 | self.state = np.zeros((max_size, state_dim))
55 | self.action = np.zeros((max_size, action_dim))
56 | self.rep = np.zeros((max_size, rep_dim))
57 | self.next_state = np.zeros((max_size, state_dim))
58 | self.reward = np.zeros((max_size, 1))
59 | self.not_done = np.zeros((max_size, 1))
60 |
61 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62 |
63 | def add(self, state, action, rep, next_state, reward, done):
64 | self.state[self.ptr] = state
65 | self.action[self.ptr] = action
66 | self.rep[self.ptr] = rep
67 | self.next_state[self.ptr] = next_state
68 | self.reward[self.ptr] = reward
69 | self.not_done[self.ptr] = 1. - done
70 |
71 | self.ptr = (self.ptr + 1) % self.max_size
72 | self.size = min(self.size + 1, self.max_size)
73 |
74 | def sample(self, batch_size):
75 | ind = np.random.randint(0, self.size, size=batch_size)
76 |
77 | return (
78 | torch.FloatTensor(self.state[ind]).to(self.device),
79 | torch.FloatTensor(self.action[ind]).to(self.device),
80 | torch.FloatTensor(self.rep[ind]).to(self.device),
81 | torch.FloatTensor(self.next_state[ind]).to(self.device),
82 | torch.FloatTensor(self.reward[ind]).to(self.device),
83 | torch.FloatTensor(self.not_done[ind]).to(self.device)
84 | )
85 |
86 |
87 | import gym
88 |
89 |
90 | class Render(gym.Wrapper):
91 | """Render env by calling its render method.
92 |
93 | Args:
94 | env (gym.Env): Env to wrap.
95 | **kwargs: Keyword arguments passed to the render method.
96 | """
97 |
98 | def __init__(self, env, episode_modulo=1, **kwargs):
99 | super().__init__(env)
100 | self._kwargs = kwargs
101 | self.render_every_nth_episode = episode_modulo
102 | self._episode_counter = -1
103 |
104 | def reset(self, **kwargs):
105 | self._episode_counter += 1
106 | ret = self.env.reset(**kwargs)
107 | if self._episode_counter % self.render_every_nth_episode == 0:
108 | self.env.render(**self._kwargs)
109 | return ret
110 |
111 | def step(self, action):
112 | ret = self.env.step(action)
113 | if self._episode_counter % self.render_every_nth_episode == 0:
114 | self.env.render(**self._kwargs)
115 | return ret
116 |
117 | def close(self):
118 | self.env.close()
119 |
120 |
121 | from gym.envs.classic_control import PendulumEnv
122 | from os import path
123 | import time
124 | import inspect
125 |
126 |
127 | class MyPendulum(PendulumEnv):
128 |
129 | def __init__(self, **kwargs):
130 | self.dec = False
131 | super().__init__(**kwargs)
132 |
133 | def is_decision_point(self):
134 | return self.dec
135 |
136 | def set_decision_point(self, b):
137 | self.dec = b
138 |
139 | def reset(self):
140 | self.dec = True
141 | return super().reset()
142 |
143 | def render(self, mode='human'):
144 | if self.viewer is None:
145 | from gym.envs.classic_control import rendering
146 | self.viewer = rendering.Viewer(500, 500)
147 | self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
148 | self.rod = rendering.make_capsule(1, .2)
149 | if self.is_decision_point():
150 | self.rod.set_color(.3, .3, .8)
151 | time.sleep(0.5)
152 | self.set_decision_point(False)
153 | else:
154 | self.rod.set_color(.8, .3, .3)
155 | self.pole_transform = rendering.Transform()
156 | self.rod.add_attr(self.pole_transform)
157 | self.viewer.add_geom(self.rod)
158 | axle = rendering.make_circle(.05)
159 | axle.set_color(0, 0, 0)
160 | self.viewer.add_geom(axle)
161 | fname = path.join(path.dirname(inspect.getfile(PendulumEnv)), "assets/clockwise.png")
162 | self.img = rendering.Image(fname, 1., 1.)
163 | self.imgtrans = rendering.Transform()
164 | self.img.add_attr(self.imgtrans)
165 |
166 | self.viewer.add_onetime(self.img)
167 |
168 | if self.is_decision_point():
169 | self.rod.set_color(.3, .3, .8)
170 | # time.sleep(0.5)
171 | self.set_decision_point(False)
172 | else:
173 | self.rod.set_color(.8, .3, .3)
174 |
175 | self.pole_transform.set_rotation(self.state[0] + np.pi / 2)
176 | if self.last_u:
177 | self.imgtrans.scale = (-self.last_u / 2, np.abs(self.last_u) / 2)
178 |
179 | return self.viewer.render(return_rgb_array=mode == 'rgb_array')
180 |
181 | def close(self):
182 | if self.viewer:
183 | self.viewer.close()
184 | self.viewer = None
185 |
186 |
187 | from gym.envs.registration import register
188 |
189 | register(
190 | id='PendulumDecs-v0',
191 | entry_point='DDPG.utils:MyPendulum',
192 | max_episode_steps=200,
193 | )
194 |
--------------------------------------------------------------------------------
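A small usage sketch (not part of the repository) for the `Render` wrapper and the registered `PendulumDecs-v0` environment; importing `DDPG.utils` performs the registration. It assumes a display and the gym classic-control rendering dependencies are available.

```python
import gym

from DDPG.utils import Render  # importing DDPG.utils also registers PendulumDecs-v0

env = Render(gym.make("PendulumDecs-v0"), episode_modulo=10)  # only render every 10th episode
state, done = env.reset(), False
while not done:
    state, reward, done, info = env.step(env.action_space.sample())
env.close()
```
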
/DDPG/vanilla.py:
--------------------------------------------------------------------------------
1 | """
2 | Copy of code by Scott Fujimoto released under the MIT license
3 | https://github.com/sfujim/TD3
4 | """
5 | import copy
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 |
13 |
14 | # Re-tuned version of Deep Deterministic Policy Gradients (DDPG)
15 | # Paper: https://arxiv.org/abs/1509.02971
16 |
17 |
18 | class Actor(nn.Module):
19 | def __init__(self, state_dim, action_dim, max_action):
20 | super(Actor, self).__init__()
21 |
22 | self.l1 = nn.Linear(state_dim, 400)
23 | self.l2 = nn.Linear(400, 300)
24 | self.l3 = nn.Linear(300, action_dim)
25 |
26 | self.max_action = max_action
27 |
28 | def forward(self, state):
29 | a = F.relu(self.l1(state))
30 | a = F.relu(self.l2(a))
31 | return self.max_action * torch.tanh(self.l3(a))
32 |
33 |
34 | class Critic(nn.Module):
35 | def __init__(self, state_dim, action_dim):
36 | super(Critic, self).__init__()
37 |
38 | self.l1 = nn.Linear(state_dim + action_dim, 400)
39 | self.l2 = nn.Linear(400, 300)
40 | self.l3 = nn.Linear(300, 1)
41 |
42 | def forward(self, state, action):
43 | q = F.relu(self.l1(torch.cat([state, action], 1)))
44 | q = F.relu(self.l2(q))
45 | return self.l3(q)
46 |
47 |
48 | class DDPG(object):
49 | def __init__(self, state_dim, action_dim, max_action, discount=0.99, tau=0.005):
50 | self.actor = Actor(state_dim, action_dim, max_action).to(device)
51 | self.actor_target = copy.deepcopy(self.actor)
52 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters())
53 |
54 | self.critic = Critic(state_dim, action_dim).to(device)
55 | self.critic_target = copy.deepcopy(self.critic)
56 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters())
57 |
58 | self.discount = discount
59 | self.tau = tau
60 |
61 | def select_action(self, state):
62 | state = torch.FloatTensor(state.reshape(1, -1)).to(device)
63 | return self.actor(state).cpu().data.numpy().flatten()
64 |
65 | def train(self, replay_buffer, batch_size=100):
66 | # Sample replay buffer
67 | state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
68 |
69 | # Compute the target Q value
70 | target_Q = self.critic_target(next_state, self.actor_target(next_state))
71 | target_Q = reward + (not_done * self.discount * target_Q).detach()
72 |
73 | # Get current Q estimate
74 | current_Q = self.critic(state, action)
75 |
76 | # Compute critic loss
77 | critic_loss = F.mse_loss(current_Q, target_Q)
78 |
79 | # Optimize the critic
80 | self.critic_optimizer.zero_grad()
81 | critic_loss.backward()
82 | self.critic_optimizer.step()
83 |
84 | # Compute actor loss
85 | actor_loss = -self.critic(state, self.actor(state)).mean()
86 |
87 | # Optimize the actor
88 | self.actor_optimizer.zero_grad()
89 | actor_loss.backward()
90 | self.actor_optimizer.step()
91 |
92 | # Update the frozen target models
93 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
94 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
95 |
96 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
97 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
98 |
99 | def save(self, filename):
100 | torch.save(self.critic.state_dict(), filename + "_critic")
101 | torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
102 |
103 | torch.save(self.actor.state_dict(), filename + "_actor")
104 | torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")
105 |
106 | def load(self, filename):
107 | self.critic.load_state_dict(torch.load(filename + "_critic"))
108 | self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
109 | self.critic_target = copy.deepcopy(self.critic)
110 |
111 | self.actor.load_state_dict(torch.load(filename + "_actor"))
112 | self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
113 | self.actor_target = copy.deepcopy(self.actor)
114 |
--------------------------------------------------------------------------------
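A minimal training-loop sketch (not part of the repository) showing how the vanilla `DDPG` class and the `ReplayBuffer` from `DDPG/utils.py` fit together. The environment, noise scale and step counts are untuned assumptions; `run_ddpg_experiments.py` is the script actually used for the experiments.

```python
import gym
import numpy as np

from DDPG.vanilla import DDPG
from DDPG.utils import ReplayBuffer

env = gym.make("Pendulum-v0")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

agent = DDPG(state_dim, action_dim, max_action)
buffer = ReplayBuffer(state_dim, action_dim)

state, done = env.reset(), False
for step in range(10000):
    # Deterministic policy plus Gaussian exploration noise, clipped to the valid action range.
    action = agent.select_action(np.array(state))
    action = (action + np.random.normal(0, 0.1 * max_action, size=action_dim)).clip(-max_action, max_action)

    next_state, reward, done, _ = env.step(action)
    buffer.add(state, action, next_state, reward, float(done))
    state = next_state

    if buffer.size > 1000:
        agent.train(buffer, batch_size=100)
    if done:
        state, done = env.reset(), False
```
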
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TempoRL
2 |
3 | This repository contains the code for the ICML'21 paper "[TempoRL: Learning When to Act](https://ml.informatik.uni-freiburg.de/papers/21-ICML-TempoRL.pdf)".
4 |
5 | If you use TempoRL in your research or application, please cite us:
6 |
7 | ```bibtex
8 | @inproceedings{biedenkapp-icml21,
9 | author = {André Biedenkapp and Raghu Rajan and Frank Hutter and Marius Lindauer},
10 | title = {{T}empo{RL}: Learning When to Act},
11 | booktitle = {Proceedings of the 38th International Conference on Machine Learning (ICML 2021)},
12 | year = {2021},
13 | month = jul,
14 | }
15 | ```
16 |
17 | ## Appendix
18 | The appendix PDF has been uploaded to this repository and can be accessed [here](TempoRL_Appendix.pdf).
19 |
20 | ## Setup
21 | This code was developed with Python 3.6.12 and torch 1.4.0.
22 | If you have the correct Python version, install the dependencies via
23 | ```bash
24 | pip install -r requirements.txt
25 | ```
26 |
27 | If you only want to run quick experiments with the tabular agents, you can install the minimal requirements in `tabular_requirements.txt` via
28 | ```bash
29 | pip install -r tabular_requirements.txt
30 | ```
31 |
32 | To make use of the provided jupyter notebooks, you additionally need to install jupyter
33 | ```bash
34 | pip install jupyter
35 | ```
36 |
37 | ## How to train tabular agents
38 | To run an agent on any of the environments listed below, run
39 | ```bash
40 | python run_tabular_experiments.py -e 10000 --agent Agent --env env_name --eval-eps 500
41 | ```
42 | Replace `Agent` with `q` for vanilla Q-learning or with `sq` for our method.
43 |
44 | ## Envs
45 | Currently, 3 simple environments are available.
46 | By default, all environments give a reward of 1 when the goal (X) is reached.
47 | The agents start in state (S) and can traverse open fields (o).
48 | When falling into "lava" (.), the agent receives a reward of -1.
49 | No other transitions generate rewards. (When rendering environments, the agent is marked with *.)
50 | An agent can use at most 100 steps to reach the goal.
51 |
52 | Modified versions of the environments listed below can run without goal rewards (env_name ends in _ng)
53 | or reduce the goal reward proportionally to the number of steps taken (env_name ends in _perc).
54 | * lava (Cliff)
55 | ```console
56 | S o . . . . . . o X
57 | o o . . . . . . o o
58 | o o . . . . . . o o
59 | o o o o o o o o o o
60 | o o o o o o o o o o
61 | o o o o o o o o o o
62 | ```
63 |
64 | * lava2 (Bridge)
65 | ```console
66 | S o . . . . . . o X
67 | o o . . . . . . o o
68 | o o o o o o o o o o
69 | o o o o o o o o o o
70 | o o . . . . . . o o
71 | o o . . . . . . o o
72 | ```
73 |
74 | * lava3 (ZigZag)
75 | ```console
76 | S o . . o o o o o o
77 | o o . . o o o o o o
78 | o o . . o o . . o o
79 | o o . . o o . . o o
80 | o o o o o o . . o o
81 | o o o o o o . . o X
82 | ```
83 |
84 | ## How to train deep agents
85 | To train an agent on featurized environments, run e.g.
86 | ```bash
87 | python run_featurized_experiments.py -e 10000 -t 1000000 --eval-after-n-steps 200 -s 1 --agent tdqn --skip-net-max-skips 10 --out-dir . --sparse
88 | ```
89 | Replace `tdqn` (our agent with a shared network architecture) with `dqn` or `dar` to run the respective baseline agents.
90 |
91 | To train a DDPG agent, run e.g.
92 | ```bash
93 | python run_ddpg_experiments.py --policy TempoRLDDPG --env Pendulum-v0 --start_timesteps 1000 --max_timesteps 30000 --eval_freq 250 --max-skip 16 --save_model --out-dir . --seed 1
94 | ```
95 | Replace `TempoRLDDPG` with `FiGARDDPG` or `DDPG` to run the baseline agents.
96 |
97 | To train an agent on Atari environments, run e.g.
98 | ```bash
99 | python run_atari_experiments.py --env freeway --env-max-steps 10000 --agent tdqn --out-dir experiments/atari_new/freeway/tdqn_3 --episodes 20000 --training-steps 2500000 --eval-after-n-steps 10000 --seed 12345 --84x84 --eval-n-episodes 3
100 | ```
101 | Replace `tdqn` (our agent with a shared network architecture) with `dqn` or `dar` to run the respective baseline agents.
102 |
103 | ## Experiment data
104 | ##### Note: This data is required to run the plotting jupyter notebooks
105 | We provide all learning-curve data and final policy network weights, as well as the commands to generate that data, at:
106 | https://figshare.com/s/d9264dc125ca8ba8efd8
107 |
108 | (Download this data and move it into the `experiments` folder.)
109 |
--------------------------------------------------------------------------------
/TempoRL_Appendix.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/automl/TempoRL/b8f4b0648489dbcc4895374df56a0d051379f2a8/TempoRL_Appendix.pdf
--------------------------------------------------------------------------------
/experiments/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/automl/TempoRL/b8f4b0648489dbcc4895374df56a0d051379f2a8/experiments/.gitkeep
--------------------------------------------------------------------------------
/grid_envs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | from io import StringIO
4 | from typing import Tuple
5 | from gym.envs.toy_text.discrete import DiscreteEnv
6 | import time
7 | from scipy.spatial.distance import cityblock
8 |
9 | LEFT = 0
10 | UP = 1
11 | RIGHT = 2
12 | DOWN = 3
13 |
14 |
15 | class Specs:
16 | def __init__(self, max_steps):
17 | self.max_episode_steps = max_steps
18 |
19 |
20 | class GridCore(DiscreteEnv):
21 | metadata = {'render.modes': ['human', 'ansi']}
22 |
23 | def __init__(self, shape: Tuple[int] = (5, 10), start: Tuple[int] = (0, 0),
24 | goal: Tuple[int] = (0, 9), max_steps: int = 100,
25 | percentage_reward: bool = False, no_goal_rew: bool = False,
26 | dense_reward: bool = False, numpy_state: bool = True,
27 | xy_state: bool = False):
28 | try:
29 | self.shape = self._shape
30 | except AttributeError:
31 | self.shape = shape
32 | self.nS = np.prod(self.shape, dtype=int) # type: int
33 | self.nA = 4
34 | self.start = start
35 | self.goal = goal
36 | self.max_steps = max_steps
37 | self._steps = 0
38 | self._pr = percentage_reward
39 | self._no_goal_rew = no_goal_rew
40 | self.total_steps = 0
41 | self._dr = dense_reward
42 | self.spec = Specs(max_steps)
43 | self._nps = numpy_state
44 | self.xy = xy_state
45 |
46 | P = self._init_transition_probability()
47 |
48 | # If a start state is given, the agent always starts there; otherwise the start is sampled uniformly from all non-goal, non-pit states
49 | if start is not None:
50 | isd = np.zeros(self.nS)
51 | isd[np.ravel_multi_index(start, self.shape)] = 1.0
52 | else:
53 | isd = np.ones(self.nS)
54 | isd[np.ravel_multi_index(goal, self.shape)] = 0.0
55 | for pit in self._pits:
56 | isd[pit] = 0.0
57 | isd *= 1 / (self.nS - len(self._pits) - 1)
58 |
59 | super(GridCore, self).__init__(self.nS, self.nA, P, isd)
60 |
61 | def step(self, a):
62 | self._steps += 1
63 | s, r, d, i = super(GridCore, self).step(a)
64 | if self.xy:
65 | self.s = np.ravel_multi_index(self.s, self.shape)
66 | if self._steps >= self.max_steps:
67 | d = True
68 | i['early'] = True
69 | i['needs_reset'] = True
70 | self.total_steps += 1
71 | if self._nps:
72 | return np.array([s], dtype=np.float32), r, d, i
73 | return s, r, d, i
74 |
75 | def reset(self):
76 | self._steps = 0
77 | self.s = np.random.choice(range(len(self.isd)), p=self.isd)
78 | self.lastaction = None
79 | if self.xy:
80 | s = np.unravel_index(self.s, self.shape)
81 | if self._nps:
82 | return np.array([s], dtype=np.float32)
83 | else:
84 | return s
85 | if self._nps:
86 | return np.array([self.s], dtype=np.float32)
87 | return self.s
88 |
89 | def _init_transition_probability(self):
90 | raise NotImplementedError
91 |
92 | def _check_bounds(self, coord):
93 | coord[0] = min(coord[0], self.shape[0] - 1)
94 | coord[0] = max(coord[0], 0)
95 | coord[1] = min(coord[1], self.shape[1] - 1)
96 | coord[1] = max(coord[1], 0)
97 | return coord
98 |
99 | def print_T(self):
100 | print(self.P[self.s])
101 |
102 | def map_output(self, s, pos):
103 | if self.s == s:
104 | output = " x "
105 | elif pos == self.goal:
106 | output = " T "
107 | else:
108 | output = " o "
109 | return output
110 |
111 | def map_control_output(self, s, pos):
112 | return self.map_output(s, pos)
113 |
114 | def map_with_inbetween_goal(self, s, pos, in_between_goal):
115 | return self.map_output(s, pos)
116 |
117 | def render(self, mode='human', close=False, in_control=None, in_between_goal=None):
118 | self._render(mode, close, in_control, in_between_goal)
119 |
120 | def _render(self, mode='human', close=False, in_control=None, in_between_goal=None):
121 | if close:
122 | return
123 | outfile = StringIO() if mode == 'ansi' else sys.stdout
124 | if mode == 'human':
125 | print('\033[2;0H')
126 |
127 | for s in range(self.nS):
128 | position = np.unravel_index(s, self.shape)
129 | # print(self.s)
130 | if in_control:
131 | output = self.map_control_output(s, position)
132 | elif in_between_goal:
133 | output = self.map_with_inbetween_goal(s, position, in_between_goal)
134 | else:
135 | output = self.map_output(s, position)
136 | if position[1] == 0:
137 | output = output.lstrip()
138 | if position[1] == self.shape[1] - 1:
139 | output = output.rstrip()
140 | output += "\n"
141 | outfile.write(output)
142 | outfile.write("\n")
143 | if mode == 'human':
144 | if in_control:
145 | time.sleep(0.2)
146 | else:
147 | time.sleep(0.05)
148 |
149 |
150 | class FallEnv(GridCore):
151 | _pits = []
152 |
153 | def __init__(self, act_fail_prob: float = 1.0, **kwargs):
154 | self.afp = act_fail_prob
155 | super(FallEnv, self).__init__(**kwargs)
156 |
157 | def _calculate_transition_prob(self, current, delta, prob):
158 | transitions = []
159 | for d, p in zip(delta, prob):
160 | new_position = np.array(current) + np.array(d)
161 | new_position = self._check_bounds(new_position).astype(int)
162 | new_state = np.ravel_multi_index(tuple(new_position), self.shape)
163 | if not self._dr:
164 | reward = 0.0
165 | is_done = False
166 | if tuple(new_position) == self.goal:
167 | if self._pr:
168 | reward = 1 - (self._steps / self.max_steps)
169 | elif not self._no_goal_rew:
170 | reward = 1.0
171 | is_done = True
172 | elif new_state in self._pits:
173 | reward = -1.
174 | is_done = True
175 | transitions.append((p, new_position if self.xy else new_state, reward, is_done))
176 | else:
177 | reward = -cityblock(new_position, self.goal)
178 | is_done = False
179 | if tuple(new_position) == self.goal:
180 | if self._pr:
181 | reward = 1 - (self._steps / self.max_steps)
182 | elif not self._no_goal_rew:
183 | reward = 100.0
184 | is_done = True
185 | elif new_state in self._pits:
186 | reward = -100
187 | is_done = True
188 | transitions.append((p, new_position if self.xy else new_state, reward, is_done))
189 | return transitions
190 |
191 | def _init_transition_probability(self):
192 | for idx, p in enumerate(self._pits):
193 | try:
194 | self._pits[idx] = np.ravel_multi_index(p, self.shape)
195 | except:
196 | pass  # the pit may already be a flat index: the _pits class attribute is shared across instances and converted in place
197 | # Calculate transition probabilities
198 | P = {}
199 | for s in range(self.nS):
200 | position = np.unravel_index(s, self.shape)
201 | P[s] = {a: [] for a in range(self.nA)}
202 | other_prob = self.afp / 3.
203 | tmp = [[UP, DOWN, LEFT, RIGHT],
204 | [DOWN, LEFT, RIGHT, UP],
205 | [LEFT, RIGHT, UP, DOWN],
206 | [RIGHT, UP, DOWN, LEFT]]
207 | tmp_dirs = [[[-1, 0], [1, 0], [0, -1], [0, 1]],
208 | [[1, 0], [0, -1], [0, 1], [-1, 0]],
209 | [[0, -1], [0, 1], [-1, 0], [1, 0]],
210 | [[0, 1], [-1, 0], [1, 0], [0, -1]]]
211 | tmp_pros = [[1 - self.afp, other_prob, other_prob, other_prob],
212 | [1 - self.afp, other_prob, other_prob, other_prob],
213 | [1 - self.afp, other_prob, other_prob, other_prob],
214 | [1 - self.afp, other_prob, other_prob, other_prob], ]
215 | for acts, dirs, probs in zip(tmp, tmp_dirs, tmp_pros):
216 | P[s][acts[0]] = self._calculate_transition_prob(position, dirs, probs)
217 | return P
218 |
219 | def map_output(self, s, pos):
220 | if self.s == s:
221 | output = " \u001b[33m*\u001b[37m "
222 | elif pos == self.goal:
223 | output = " \u001b[37mX\u001b[37m "
224 | elif s in self._pits:
225 | output = " \u001b[31m.\u001b[37m "
226 | else:
227 | output = " \u001b[30mo\u001b[37m "
228 | return output
229 |
230 | def map_control_output(self, s, pos):
231 | if self.s == s:
232 | return " \u001b[34m*\u001b[37m "
233 | else:
234 | return self.map_output(s, pos)
235 |
236 | def map_with_inbetween_goal(self, s, pos, in_between_goal):
237 | if s == in_between_goal:
238 | return " \u001b[34mx\u001b[37m "
239 | else:
240 | return self.map_output(s, pos)
241 |
242 |
243 | class Bridge6x10Env(FallEnv):
244 | _pits = [[0, 2], [0, 3], [0, 4], [0, 5], [0, 6], [0, 7],
245 | [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7],
246 | [4, 2], [4, 3], [4, 4], [4, 5], [4, 6], [4, 7],
247 | [5, 2], [5, 3], [5, 4], [5, 5], [5, 6], [5, 7]]
248 | _shape = (6, 10)
249 |
250 |
251 | class Pit6x10Env(FallEnv):
252 | _pits = [[0, 2], [0, 3], [0, 4], [0, 5], [0, 6], [0, 7],
253 | [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7],
254 | [2, 2], [2, 3], [2, 4], [2, 5], [2, 6], [2, 7]]
255 | # [3, 2], [3, 3], [3, 4], [3, 5], [3, 6], [3, 7]]
256 | _shape = (6, 10)
257 |
258 |
259 | class ZigZag6x10(FallEnv):
260 | _pits = [[0, 2], [0, 3],
261 | [1, 2], [1, 3],
262 | [2, 2], [2, 3],
263 | [3, 2], [3, 3],
264 | [5, 7], [5, 6],
265 | [4, 7], [4, 6],
266 | [3, 7], [3, 6],
267 | [2, 7], [2, 6],
268 | ]
269 | _shape = (6, 10)
270 |
271 |
272 | class ZigZag6x10H(FallEnv):
273 | _pits = [[0, 2], [0, 3],
274 | [1, 2], [1, 3],
275 | [2, 2], [2, 3],
276 | [3, 2], [3, 3],
277 | [5, 7], [5, 6],
278 | [4, 7], [4, 6],
279 | [3, 7], [3, 6],
280 | [2, 7], [2, 6],
281 | [4, 4], [5, 2]
282 | ]
283 | _shape = (6, 10)
284 |
--------------------------------------------------------------------------------
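An illustrative sketch (not part of the repository) of interacting with one of the grid worlds above; the constructor arguments are assumptions. With deterministic transitions (`act_fail_prob=0.0`), open fields give no reward, pits give -1 and end the episode, and reaching the goal gives +1 (unless the percentage-reward or no-goal-reward variants are used).

```python
import numpy as np

from grid_envs import ZigZag6x10

env = ZigZag6x10(max_steps=100, act_fail_prob=0.0, numpy_state=False)
state, done = env.reset(), False
while not done:
    action = np.random.randint(env.nA)  # LEFT=0, UP=1, RIGHT=2, DOWN=3
    state, reward, done, info = env.step(action)
    env.render(mode='human')
```
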
/mountain_car.py:
--------------------------------------------------------------------------------
1 | """
2 | http://incompleteideas.net/sutton/MountainCar/MountainCar1.cp
3 | permalink: https://perma.cc/6Z2N-PFWC
4 | """
5 |
6 | import math
7 | import gym
8 | from gym import spaces
9 | from gym.utils import seeding
10 | import numpy as np
11 |
12 | class MountainCarEnv(gym.Env):
13 | metadata = {
14 | 'render.modes': ['human', 'rgb_array'],
15 | 'video.frames_per_second': 30
16 | }
17 |
18 | def __init__(self):
19 | self.min_position = -1.2
20 | self.max_position = 0.6
21 | self.max_speed = 0.07
22 | self.goal_position = 0.5
23 |
24 | self.low = np.array([self.min_position, -self.max_speed])
25 | self.high = np.array([self.max_position, self.max_speed])
26 |
27 | self.viewer = None
28 |
29 | self.action_space = spaces.Discrete(3)
30 | self.observation_space = spaces.Box(self.low, self.high, dtype=np.float32)
31 |
32 | self.seed()
33 | self.reset()
34 |
35 | def seed(self, seed=None):
36 | self.np_random, seed = seeding.np_random(seed)
37 | return [seed]
38 |
39 | def step(self, action):
40 | for _ in range(4):
41 | assert self.action_space.contains(action), "%r (%s) invalid" % (action, type(action))
42 |
43 | position, velocity = self.state
44 | velocity += (action-1)*0.001 + math.cos(3*position)*(-0.0025)
45 | velocity = np.clip(velocity, -self.max_speed, self.max_speed)
46 | position += velocity
47 | position = np.clip(position, self.min_position, self.max_position)
48 | if (position==self.min_position and velocity<0): velocity = 0
49 |
50 | done = bool(position >= self.goal_position)
51 | if not done:
52 | reward = -0.1*np.abs(self.goal_position - position)
53 | else:
54 | reward = 10
55 | self.state = (position, velocity)
56 |
57 | return np.array(self.state), reward, done, {}
58 |
59 | def reset(self):
60 | self.state = np.array([-0.5, 0])
61 | return np.array(self.state)
62 |
63 | def _height(self, xs):
64 | return np.sin(3 * xs) * .45 + .55
65 |
66 | def close(self):
67 | if self.viewer is not None:
68 | self.viewer.close()
69 | self.viewer = None
70 |
71 | def render(self, mode='human', close=False):
72 | if close:
73 | if self.viewer is not None:
74 | self.viewer.close()
75 | self.viewer = None
76 | return
77 |
78 | screen_width = 600
79 | screen_height = 400
80 |
81 | world_width = self.max_position - self.min_position
82 | scale = screen_width/world_width
83 | carwidth=40
84 | carheight=20
85 |
86 |
87 | if self.viewer is None:
88 | from gym.envs.classic_control import rendering
89 | self.viewer = rendering.Viewer(screen_width, screen_height)
90 | xs = np.linspace(self.min_position, self.max_position, 100)
91 | ys = self._height(xs)
92 | xys = list(zip((xs-self.min_position)*scale, ys*scale))
93 |
94 | self.track = rendering.make_polyline(xys)
95 | self.track.set_linewidth(4)
96 | self.viewer.add_geom(self.track)
97 |
98 | clearance = 10
99 |
100 | l,r,t,b = -carwidth/2, carwidth/2, carheight, 0
101 | car = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
102 | car.add_attr(rendering.Transform(translation=(0, clearance)))
103 | self.cartrans = rendering.Transform()
104 | car.add_attr(self.cartrans)
105 | self.viewer.add_geom(car)
106 | frontwheel = rendering.make_circle(carheight/2.5)
107 | frontwheel.set_color(.5, .5, .5)
108 | frontwheel.add_attr(rendering.Transform(translation=(carwidth/4,clearance)))
109 | frontwheel.add_attr(self.cartrans)
110 | self.viewer.add_geom(frontwheel)
111 | backwheel = rendering.make_circle(carheight/2.5)
112 | backwheel.add_attr(rendering.Transform(translation=(-carwidth/4,clearance)))
113 | backwheel.add_attr(self.cartrans)
114 | backwheel.set_color(.5, .5, .5)
115 | self.viewer.add_geom(backwheel)
116 | flagx = (self.goal_position-self.min_position)*scale
117 | flagy1 = self._height(self.goal_position)*scale
118 | flagy2 = flagy1 + 50
119 | flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
120 | self.viewer.add_geom(flagpole)
121 | flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2-10), (flagx+25, flagy2-5)])
122 | flag.set_color(.8,.8,0)
123 | self.viewer.add_geom(flag)
124 |
125 | pos = self.state[0]
126 | self.cartrans.set_translation((pos-self.min_position)*scale, self._height(pos)*scale)
127 | self.cartrans.set_rotation(math.cos(3 * pos))
128 |
129 | return self.viewer.render(return_rgb_array = mode=='rgb_array')
130 |
--------------------------------------------------------------------------------
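A short usage sketch (not part of the repository) for the modified MountainCar environment above, which replaces the standard -1 per-step reward with a shaped distance penalty and a +10 goal reward. The step cap below is an assumption, since the class itself does not enforce a time limit.

```python
from mountain_car import MountainCarEnv

env = MountainCarEnv()
state, total_reward = env.reset(), 0.0
for t in range(1000):  # manual step cap; the raw env has no time limit
    state, reward, done, _ = env.step(env.action_space.sample())  # 0: push left, 1: no-op, 2: push right
    total_reward += reward
    if done:
        break
print("return:", total_reward)
```
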
/plot_atari_results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "from utils.plotting import get_colors, load_config, plot\n",
12 | "from utils.data_handling import load_dqn_data\n",
13 | "import numpy as np"
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {},
19 | "source": [
20 | "#### Name explanations\n",
21 | "* DQN -> standard DQN\n",
22 | "* DAR_min^max -> Dynamic action repetition with small repetition and long repetition values\n",
23 | "* tqn -> TempoRL DQN with separate skip-DQN that expects the behaviour action to be concatenated to the state\n",
24 | "* t-dqn -> TempoRL DQN with separate skip-DQN that expects the behaviour action as contextual input\n",
25 | "* tdqn -> TempoRL DQN with shared state representation between the behavoiur and skip action outputs."
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "import json\n",
35 | "import glob\n",
36 | "import os\n",
37 | "import pandas as pd\n",
38 | "from matplotlib import pyplot as plt\n",
39 | "import seaborn as sb\n",
40 | "\n",
41 | "from scipy.signal import savgol_filter\n",
42 | " \n",
43 | "\n",
44 | "# Somehow the plotting functionallity I ended up with was already covered for the tabular case.\n",
45 | "# I should have just used the plot function from that.\n",
46 | "def plotMultiple(data, ylim=None, title='', logStepY=False, max_steps=200, xlim=None, figsize=None,\n",
47 | " alphas=None, smooth=5, savename=None, rewyticks=None, lenyticks=None,\n",
48 | " skip_stdevs=[], dont_label=[], dont_plot=[], min_steps=None,\n",
49 | " logRewY=False):\n",
50 | " \"\"\"\n",
51 | " Simple plotting method that shows the test reward on the y-axis and the number of performed training steps\n",
52 | " on the x-axis.\n",
53 | " \n",
54 | " data -> (dict[agent name] -> list([rewards, lens, decs, train_steps, train_episodes])) the data to plot\n",
55 | " ylim -> (list) y-axis limit\n",
56 | " title -> (str) title on top of plot\n",
57 | " logStepY -> (bool) flag that indicates if the y-axis should be on log scale.\n",
58 | " max_steps -> (int) maximal episode length\n",
59 | " min_steps -> (int) optional minimum episode length. If not set assumes 1 as min\n",
60 | " xlim -> (list) x-axis limits\n",
61 | " figsize -> (list) dimensions of the figure\n",
62 | " alphas -> (dict[agent name] -> float) the alpha value to use for plotting of specific agents\n",
63 | " smooth -> (int) the window size for smoothing (has to be odd if used. < 0 deactivates this option)\n",
64 | " savename -> (str) filename to save the figure\n",
65 | " rewyticks -> (list) yticks for the reward plot\n",
66 | " lenyticks -> (list) yticks for the decisions plot\n",
67 | " skip_sdevs -> (list) list of names to not plot standard deviations for.\n",
68 | " dont_label -> (list) list of names to not label.\n",
69 | " dont_plot -> (list) list of names to not plot.\n",
70 | " logRewY -> (bool) flag that indicates if the reward y-axis should be on log scale.\n",
71 | " \"\"\"\n",
72 | " \n",
73 | " if smooth and smooth > 0:\n",
74 | " degree = 2\n",
75 | " for agent in data:\n",
76 | " data[agent] = list(data[agent]) # we have to convert the tuple to lists\n",
77 | " data[agent][0] = list(data[agent][0])\n",
78 | " data[agent][0][0] = savgol_filter(data[agent][0][0], smooth, degree) # smooth the mean reward\n",
79 | " data[agent][0][1] = savgol_filter(data[agent][0][1], smooth, degree) # smooth the stdev reward\n",
80 | " data[agent][1] = list(data[agent][1])\n",
81 | " data[agent][1][0] = savgol_filter(data[agent][1][0], smooth, degree) # smooth mean num steps\n",
82 | " data[agent][1][1] = savgol_filter(data[agent][1][1], smooth, degree)\n",
83 | " data[agent][2] = list(data[agent][2])\n",
84 | " data[agent][2][0] = savgol_filter(data[agent][2][0], smooth, degree) # smooth mean decisions\n",
85 | " data[agent][2][1] = savgol_filter(data[agent][2][1], smooth, degree)\n",
86 | "\n",
87 | " colors, color_map = get_colors()\n",
88 | " \n",
89 | "\n",
90 | " cfg = load_config()\n",
91 | " sb.set_style(cfg['plotting']['seaborn']['style'])\n",
92 | " sb.set_context(cfg['plotting']['seaborn']['context']['context'],\n",
93 | " font_scale=cfg['plotting']['seaborn']['context']['font scale'],\n",
94 | " rc=cfg['plotting']['seaborn']['context']['rc2'])\n",
95 | "\n",
96 | " if figsize:\n",
97 | " fig, ax = plt.subplots(2, figsize=figsize, dpi=100, sharex=True)\n",
98 | " else:\n",
99 | " fig, ax = plt.subplots(2, figsize=(20, 10), dpi=100,sharex=True)\n",
100 | " ax[0].set_title(title)\n",
101 | "\n",
102 | " for agent in list(data.keys())[::-1]:\n",
103 | " if agent in dont_plot:\n",
104 | " continue\n",
105 | " try:\n",
106 | " alph = alphas[agent]\n",
107 | " except:\n",
108 | " alph = 1.\n",
109 | " color_name = color_map['dar'] if 'dar' in agent else color_map[agent]\n",
110 | " rew, lens, decs, train_steps, train_eps = data[agent]\n",
111 | " \n",
112 | " label = agent.upper()\n",
113 | " if agent in ['t-dqn', 'tdqn', 'tqn']:\n",
114 | " label = 't-DQN'\n",
115 | " elif agent in dont_label:\n",
116 | " label = None\n",
117 | "\n",
118 | " #### Plot rewards\n",
119 | " ax[0].step(train_steps[0], rew[0], where='post', c=colors[color_name], label=label,\n",
120 | " alpha=alph, ls='-' if agent != 't-dqn' else '-.')\n",
121 | " if agent not in skip_stdevs:\n",
122 | " ax[0].fill_between(train_steps[0], rew[0]-rew[1], rew[0]+rew[1],\n",
123 | " alpha=0.25 * alph, step='post',\n",
124 | " color=colors[color_name])\n",
125 | " #### Plot lens\n",
126 | " ax[1].step(train_steps[0], decs[0], where='post',\n",
127 | " c=np.array(colors[color_name]), ls='-',\n",
128 | " alpha=alph)\n",
129 | " if agent not in skip_stdevs:\n",
130 | " ax[1].fill_between(train_steps[0], decs[0]-decs[1], decs[0]+decs[1],\n",
131 | " alpha=0.125 * alph, step='post',\n",
132 | " color=np.array(colors[color_name]))\n",
133 | " ax[1].step(train_steps[0], lens[0], where='post',\n",
134 | " c=np.array(colors[color_name]) * .75, alpha=alph,\n",
135 | " ls=':')\n",
136 | " \n",
137 | " if agent not in skip_stdevs:\n",
138 | " ax[1].fill_between(train_steps[0], lens[0]-lens[1], lens[0]+lens[1],\n",
139 | " alpha=0.25 * alph, step='post',\n",
140 | " color=np.array(colors[color_name]) * .75)\n",
141 | " #ax[0].semilogx()\n",
142 | " if rewyticks is not None:\n",
143 | " ax[0].set_yticks(rewyticks)\n",
144 | " if ylim:\n",
145 | " ax[0].set_ylim(ylim)\n",
146 | " if xlim:\n",
147 | " ax[0].set_xlim(xlim)\n",
148 | " ax[0].set_ylabel('Reward')\n",
149 | " if len(data) - len(dont_label) < 5:\n",
150 | " ax[0].legend(ncol=1, loc='best', handlelength=.75)\n",
151 | " ax[1].semilogx()\n",
152 | " if logStepY:\n",
153 | " ax[1].semilogy()\n",
154 | " if logRewY:\n",
155 | " ax[0].semilogy()\n",
156 | " \n",
157 | " ax[1].plot([-999, -999], [-999, -999], ls=':', c='k', label='all')\n",
158 | " ax[1].plot([-999, -999], [-999, -999], ls='-', c='k', label='dec')\n",
159 | " ax[1].legend(loc='best', ncol=1, handlelength=.75)\n",
160 | " if not min_steps:\n",
161 | " ax[1].set_ylim([1, max_steps])\n",
162 | " else:\n",
163 | " ax[1].set_ylim([min_steps, max_steps])\n",
164 | " if xlim:\n",
165 | " ax[1].set_xlim(xlim)\n",
166 | " ax[1].set_ylabel('#Actions')\n",
167 | " ax[1].set_xlabel('#Train Steps')\n",
168 | " if lenyticks is not None:\n",
169 | " ax[1].set_yticks(lenyticks)\n",
170 | " plt.tight_layout()\n",
171 | " if savename:\n",
172 | " plt.savefig(savename)\n",
173 | "\n",
174 | " plt.show()\n",
175 | "\n",
176 | "\n",
177 | "def get_best_to_plot(data, aucs, tempoRL=None):\n",
178 | " \"\"\"\n",
179 | " Simple method to filter which lines to plot.\n",
180 | " \"\"\"\n",
181 | " to_plot = dict()\n",
182 | "\n",
183 | " if tempoRL is None:\n",
184 | " aucs = list(sorted(aucs, key=lambda x: x[1], reverse=True))\n",
185 | " for idx, auc in enumerate(aucs):\n",
186 | " if 't' in auc[0]:\n",
187 | " break\n",
188 | " to_plot[aucs[idx][0]] = data[aucs[idx][0]] # the absolute best\n",
189 | " else:\n",
190 | " to_plot[tempoRL] = data[tempoRL]\n",
191 | "\n",
192 | " bv = -np.inf\n",
193 | " b = None\n",
194 | " for elem in aucs:\n",
195 | " if 'dar' not in elem[0]:\n",
196 | " continue\n",
197 | " elif elem[1] > bv:\n",
198 | " b, bv = elem[0], elem[1]\n",
199 | " to_plot[b] = data[b]\n",
200 | " \n",
201 | " \n",
202 | " to_plot['dqn'] = data['dqn']\n",
203 | " return to_plot"
204 | ]
205 | },
206 | {
207 | "cell_type": "markdown",
208 | "metadata": {},
209 | "source": [
210 | " "
211 | ]
212 | },
213 | {
214 | "cell_type": "markdown",
215 | "metadata": {},
216 | "source": [
217 | " "
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": null,
223 | "metadata": {
224 | "scrolled": false
225 | },
226 | "outputs": [],
227 | "source": [
228 | "data = {}\n",
229 | "\n",
230 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/pong/tdqn',\n",
231 | " #debug=True,\n",
232 | " max_steps=2.5*10**6)\n",
233 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/pong/dqn',\n",
234 | " #debug=True,\n",
235 | " max_steps=2.5*10**6)\n",
236 | "data[r'dar$_{1}^{10}$'] = load_dqn_data('*', 'experiments/atari/pong/dar',\n",
237 | " #debug=True,\n",
238 | " max_steps=2.5*10**6)\n",
239 | "\n",
240 | "plotMultiple(data, title='Pong',\n",
241 | " ylim=[-22, 22], xlim=[10**4, 2.5*10**6],\n",
242 | " min_steps=10**2, max_steps=3000, lenyticks=[10**2, 10**3, 2*10**3, 3*10**3],\n",
243 | " smooth=7, savename='pong_50_seeds.pdf') #, logStepY=True)"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "metadata": {
250 | "scrolled": false
251 | },
252 | "outputs": [],
253 | "source": [
254 | "data = {}\n",
255 | "\n",
256 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/beam_rider/tdqn_3',\n",
257 | " #debug=True,\n",
258 | " max_steps=2.5*10**6)\n",
259 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/beam_rider/dqn_3',\n",
260 | " #debug=True,\n",
261 | " max_steps=2.5*10**6)\n",
262 | "data[r'dar$_{1}^{10}$'] = load_dqn_data('*', 'experiments/atari/beam_rider/dar_3',\n",
263 | " #debug=True,\n",
264 | " max_steps=2.5*10**6)\n",
265 | "\n",
266 | "plotMultiple(data, title='BeamRider',\n",
267 | " ylim=[0, 600],\n",
268 | " xlim=[10**4, 2.5*10**6],\n",
269 | " max_steps=1000, rewyticks=[0, 150, 300, 450, 600], #lenyticks=[10**2, 10**3, 2*10**3, 3*10**3],\n",
270 | " smooth=7, savename='beamrider_15_seeds.pdf') #, logStepY=True)"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": null,
276 | "metadata": {
277 | "scrolled": false
278 | },
279 | "outputs": [],
280 | "source": [
281 | "data = {}\n",
282 | "\n",
283 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/freeway/tdqn_3',\n",
284 | " #debug=True,\n",
285 | " max_steps=2.5*10**6)\n",
286 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/freeway/dqn_3',\n",
287 | " #debug=True,\n",
288 | " max_steps=2.5*10**6)\n",
289 | "data[r'dar$_{1}^{10}$'] = load_dqn_data('*', 'experiments/atari/freeway/dar_3',\n",
290 | " #debug=True,\n",
291 | " max_steps=2.5*10**6)\n",
292 | "\n",
293 | "plotMultiple(data, title='Freeway',\n",
294 | " ylim=[0, 35],\n",
295 | " xlim=[10**4, 2.5*10**6],\n",
296 | " max_steps=2100, rewyticks=[0, 11, 22, 33], #lenyticks=[10**2, 10**3, 2*10**3, 3*10**3],\n",
297 | " smooth=7, savename='freeway_15_seeds.pdf') #, logStepY=True)"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": null,
303 | "metadata": {
304 | "scrolled": false
305 | },
306 | "outputs": [],
307 | "source": [
308 | "data = {}\n",
309 | "\n",
310 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/ms_pacman/tdqn_3',\n",
311 | " #debug=True,\n",
312 | " max_steps=2.5*10**6)\n",
313 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/ms_pacman/dqn_3',\n",
314 | " #debug=True,\n",
315 | " max_steps=2.5*10**6)\n",
316 | "data[r'dar$_{1}^{10}$'] = load_dqn_data('*', 'experiments/atari/ms_pacman/dar_3',\n",
317 | " #debug=True,\n",
318 | " max_steps=2.5*10**6)\n",
319 | "\n",
320 | "plotMultiple(data, title='MsPacman',\n",
321 | " ylim=[0, 750],\n",
322 | " xlim=[10**4, 2.5*10**6],\n",
323 | " max_steps=300, rewyticks=[0, 250, 500, 750], #lenyticks=[10**2, 10**3, 2*10**3, 3*10**3],\n",
324 | " smooth=7, savename='mspacman_15_seeds.pdf') #, logStepY=True)"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": null,
330 | "metadata": {
331 | "scrolled": false
332 | },
333 | "outputs": [],
334 | "source": [
335 | "data = {}\n",
336 | "\n",
337 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/qbert_long/tdqn_3',\n",
338 | " #debug=True,\n",
339 | " max_steps=5*10**6)\n",
340 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/qbert_long/dqn_3',\n",
341 | " #debug=True,\n",
342 | " max_steps=5*10**6)\n",
343 | "data[r'dar$_{1}^{10}$'] = load_dqn_data('*', 'experiments/atari/qbert_long/dar_3',\n",
344 | " #debug=True,\n",
345 | " max_steps=5*10**6)\n",
346 | "\n",
347 | "plotMultiple(data, title='QBert',\n",
348 | " ylim=[0, 1000],\n",
349 | " xlim=[10**4, 5*10**6], logRewY=False,\n",
350 | " max_steps=225, min_steps=0, rewyticks=[0, 250, 500, 750, 1000], lenyticks=[0, 50, 100, 150, 200],\n",
351 | " smooth=7, savename='qbert_15_sees.pdf') #, logStepY=True)"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": null,
357 | "metadata": {
358 | "scrolled": false
359 | },
360 | "outputs": [],
361 | "source": [
362 | "data = {}\n",
363 | "\n",
364 | "data['tdqn'] = load_dqn_data('*', 'experiments/atari/qbert/tdqn_3',\n",
365 | " #debug=True,\n",
366 | " max_steps=2.5*10**6)\n",
367 | "data['dqn'] = load_dqn_data('*', 'experiments/atari/qbert/dqn_3',\n",
368 | " #debug=True,\n",
369 | " max_steps=2.5*10**6)\n",
370 | "\n",
371 | "plotMultiple(data, title='QBert',\n",
372 | " ylim=[0, 1000],\n",
373 | " xlim=[10**4, 2.5*10**6], logRewY=False,\n",
374 | " max_steps=225, min_steps=0, rewyticks=[0, 250, 500, 750, 1000], lenyticks=[0, 50, 100, 150, 200],\n",
375 | " smooth=7) #, logStepY=True)"
376 | ]
377 | }
378 | ],
379 | "metadata": {
380 | "kernelspec": {
381 | "display_name": "Python 3",
382 | "language": "python",
383 | "name": "python3"
384 | },
385 | "language_info": {
386 | "codemirror_mode": {
387 | "name": "ipython",
388 | "version": 3
389 | },
390 | "file_extension": ".py",
391 | "mimetype": "text/x-python",
392 | "name": "python",
393 | "nbconvert_exporter": "python",
394 | "pygments_lexer": "ipython3",
395 | "version": "3.7.7"
396 | }
397 | },
398 | "nbformat": 4,
399 | "nbformat_minor": 4
400 | }
401 |
--------------------------------------------------------------------------------
/plot_ddpg.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "import os\n",
12 | "from utils.plotting import get_colors, load_config, plot\n",
13 | "import numpy as np\n",
14 | "from matplotlib import pyplot as plt"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import json\n",
24 | "import glob\n",
25 | "import os\n",
26 | "import pandas as pd\n",
27 | "from matplotlib import pyplot as plt\n",
28 | "import seaborn as sb\n",
29 | "\n",
30 | "from scipy.signal import savgol_filter\n",
31 | " \n",
32 | "\n",
33 | "# Somehow the plotting functionallity I ended up with was already covered for the tabular case.\n",
34 | "# I should have just used the plot function from that.\n",
35 | "def plotMultiple(data, ylim=None, title='', logStepY=False, max_steps=200, xlim=None, figsize=None,\n",
36 | " alphas=None, smooth=5, savename=None, rewyticks=None, lenyticks=None,\n",
37 | " skip_stdevs=[], dont_label=[], dont_plot=[], min_steps=None):\n",
38 | " \"\"\"\n",
39 | " Simple plotting method that shows the test reward on the y-axis and the number of performed training steps\n",
40 | " on the x-axis.\n",
41 | " \n",
42 | " data -> (dict[agent name] -> list([rewards, lens, decs, train_steps, train_episodes])) the data to plot\n",
43 | " ylim -> (list) y-axis limit\n",
44 | " title -> (str) title on top of plot\n",
45 | " logStepY -> (bool) flag that indicates if the y-axis should be on log scale.\n",
46 | " max_steps -> (int) maximal episode length\n",
47 | " xlim -> (list) x-axis limits\n",
48 | " figsize -> (list) dimensions of the figure\n",
49 | " alphas -> (dict[agent name] -> float) the alpha value to use for plotting of specific agents\n",
50 | " smooth -> (int) the window size for smoothing (has to be odd if used. < 0 deactivates this option)\n",
51 | " savename -> (str) filename to save the figure\n",
52 | " rewyticks -> (list) yticks for the reward plot\n",
53 | " lenyticks -> (list) yticks for the decisions plot\n",
54 | " skip_sdevs -> (list) list of names to not plot standard deviations for.\n",
55 | " dont_label -> (list) list of names to not label.\n",
56 | " dont_plot -> (list) list of names to not plot.\n",
57 | " \"\"\"\n",
58 | " \n",
59 | " if smooth and smooth > 0:\n",
60 | " degree = 2\n",
61 | " for agent in data:\n",
62 | " data[agent] = list(data[agent]) # we have to convert the tuple to lists\n",
63 | " data[agent][0] = list(data[agent][0])\n",
64 | " data[agent][0][0] = savgol_filter(data[agent][0][0], smooth, degree) # smooth the mean reward\n",
65 | " data[agent][0][1] = savgol_filter(data[agent][0][1], smooth, degree) # smooth the stdev reward\n",
66 | " data[agent][1] = list(data[agent][1])\n",
67 | " data[agent][1][0] = savgol_filter(data[agent][1][0], smooth, degree) # smooth mean num steps\n",
68 | " data[agent][1][1] = savgol_filter(data[agent][1][1], smooth, degree)\n",
69 | " data[agent][2] = list(data[agent][2])\n",
70 | " data[agent][2][0] = savgol_filter(data[agent][2][0], smooth, degree) # smooth mean decisions\n",
71 | " data[agent][2][1] = savgol_filter(data[agent][2][1], smooth, degree)\n",
72 | "\n",
73 | " colors, color_map = get_colors()\n",
74 | " \n",
75 | "\n",
76 | " cfg = load_config()\n",
77 | " sb.set_style(cfg['plotting']['seaborn']['style'])\n",
78 | " sb.set_context(cfg['plotting']['seaborn']['context']['context'],\n",
79 | " font_scale=cfg['plotting']['seaborn']['context']['font scale'],\n",
80 | " rc=cfg['plotting']['seaborn']['context']['rc2'])\n",
81 | "\n",
82 | " if figsize:\n",
83 | " fig, ax = plt.subplots(2, figsize=figsize, dpi=100, sharex=True)\n",
84 | " else:\n",
85 | " fig, ax = plt.subplots(2, figsize=(20, 10), dpi=100,sharex=True)\n",
86 | " ax[0].set_title(title)\n",
87 | "\n",
88 | " for agent in list(data.keys())[::-1]:\n",
89 | " if agent in dont_plot:\n",
90 | " continue\n",
91 | " try:\n",
92 | " alph = alphas[agent]\n",
93 | " except:\n",
94 | " alph = 1.\n",
95 | " color_name = None\n",
96 | " if 'dar' in agent:\n",
97 | " color_name = color_map['dar']\n",
98 | " elif agent.startswith('t'):\n",
99 | " color_name = color_map['t-DDPG']\n",
100 | " elif agent.startswith('f'):\n",
101 | " color_name = color_map['f-DDPG']\n",
102 | " else:\n",
103 | " color_name = color_map[agent]\n",
104 | " rew, lens, decs, train_steps, train_eps = data[agent]\n",
105 | " \n",
106 | " label = agent.upper()\n",
107 | " if agent.startswith('t'):\n",
108 | " label = 't-DDPG'\n",
109 | " elif agent.startswith('f'):\n",
110 | " label = 'FiGAR'\n",
111 | " elif agent.startswith('e'):\n",
112 | " label = r'$\\epsilon$z-DQN'\n",
113 | " elif agent in dont_label:\n",
114 | " label = None\n",
115 | "\n",
116 | " #### Plot rewards\n",
117 | " ax[0].step(train_steps[0], rew[0], where='post', c=colors[color_name], label=label,\n",
118 | " alpha=alph)\n",
119 | " if agent not in skip_stdevs:\n",
120 | " ax[0].fill_between(train_steps[0], rew[0]-rew[1], rew[0]+rew[1], alpha=0.25 * alph, step='post',\n",
121 | " color=colors[color_name])\n",
122 | " #### Plot lens\n",
123 | " ax[1].step(train_steps[0], decs[0], where='post', c=np.array(colors[color_name]), ls='-',\n",
124 | " alpha=alph)\n",
125 | " if agent not in skip_stdevs:\n",
126 | " ax[1].fill_between(train_steps[0], decs[0]-decs[1], decs[0]+decs[1], alpha=0.125 * alph, step='post',\n",
127 | " color=np.array(colors[color_name]))\n",
128 | " ax[1].step(train_steps[0], lens[0], where='post',\n",
129 | " c=np.array(colors[color_name]) * .75, alpha=alph, ls=':')\n",
130 | " \n",
131 | " if agent not in skip_stdevs:\n",
132 | " ax[1].fill_between(train_steps[0], lens[0]-lens[1], lens[0]+lens[1], alpha=0.25 * alph, step='post',\n",
133 | " color=np.array(colors[color_name]) * .75)\n",
134 | " ax[0].semilogx()\n",
135 | " if rewyticks is not None:\n",
136 | " ax[0].set_yticks(rewyticks)\n",
137 | " if ylim:\n",
138 | " ax[0].set_ylim(ylim)\n",
139 | " if xlim:\n",
140 | " ax[0].set_xlim(xlim)\n",
141 | " ax[0].set_ylabel('Reward')\n",
142 | " if len(data) - len(dont_label) < 5:\n",
143 | " ax[0].legend(ncol=1, loc='best', handlelength=.75)\n",
144 | " ax[1].semilogx()\n",
145 | " if logStepY:\n",
146 | " ax[1].semilogy()\n",
147 | " \n",
148 | " ax[1].plot([-999, -999], [-999, -999], ls=':', c='k', label='all')\n",
149 | " ax[1].plot([-999, -999], [-999, -999], ls='-', c='k', label='dec')\n",
150 | " ax[1].legend(loc='best', ncol=1, handlelength=.75)\n",
151 | " ax[1].set_ylim([min_steps if min_steps is not None else 1, max_steps])\n",
152 | " if xlim:\n",
153 | " ax[1].set_xlim(xlim)\n",
154 | " ax[1].set_ylabel('#Actions')\n",
155 | " ax[1].set_xlabel('#Train Steps')\n",
156 | " if lenyticks is not None:\n",
157 | " ax[1].set_yticks(lenyticks)\n",
158 | " plt.tight_layout()\n",
159 | " if savename:\n",
160 | " plt.savefig(savename)\n",
161 | "\n",
162 | " plt.show()"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": null,
168 | "metadata": {
169 | "scrolled": false
170 | },
171 | "outputs": [],
172 | "source": [
173 | "results = {}\n",
174 | "ddpg_datas = []\n",
175 | "for i in sorted(os.listdir('experiments/ddpg/DDPG')):\n",
176 | " ddpg_datas.append(np.load(f'experiments/ddpg/DDPG/{i}/DDPG_Pendulum-v0_{i}.npy'))\n",
177 | "\n",
178 | "\n",
179 | "ddpg_mean = np.mean(ddpg_datas, axis=0)\n",
180 | "ddpg_stdev = np.std(ddpg_datas, axis=0)\n",
181 | "results['DDPG'] = [[ddpg_mean[:, 1], ddpg_stdev[:, 1]],\n",
182 | " [ddpg_mean[:, 3], ddpg_stdev[:, 3]], \n",
183 | " [ddpg_mean[:, 2], ddpg_stdev[:, 2]],\n",
184 | " [ddpg_mean[:, 0], ddpg_mean[:, 0]],\n",
185 | " [ddpg_mean[:, 0], ddpg_mean[:, 0]]]\n",
186 | "\n",
187 | "for max_len in [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]:\n",
188 | " temporl_datas = []\n",
189 | " for i in sorted(os.listdir(f'experiments/ddpg/TempoRLDDPG/{max_len}')):\n",
190 | " temporl_datas.append(np.load(f'experiments/ddpg/TempoRLDDPG/{max_len}/{i}/TempoRLDDPG_Pendulum-v0_{i}.npy'))\n",
191 | "\n",
192 | " figar_datas = []\n",
193 | " for i in sorted(os.listdir(f'experiments/ddpg/FiGARDDPG/{max_len}')):\n",
194 | " figar_datas.append(np.load(f'experiments/ddpg/FiGARDDPG/{max_len}/{i}/FiGARDDPG_Pendulum-v0_{i}.npy'))\n",
195 | "\n",
196 | " temporl_mean = np.mean(temporl_datas, axis=0)\n",
197 | " figar_mean = np.mean(figar_datas, axis=0)\n",
198 | " temporl_stdev = np.std(temporl_datas, axis=0)\n",
199 | " figar_stdev = np.std(figar_datas, axis=0)\n",
200 | " \n",
201 | " # (dict[agent name] -> list([rewards, lens, decs, train_steps, train_episodes]))\n",
202 | " results['t-DDPG'] = [[temporl_mean[:, 1], temporl_stdev[:, 1]],\n",
203 | " [temporl_mean[:, 3], temporl_stdev[:, 3]],\n",
204 | " [temporl_mean[:, 2], temporl_stdev[:, 2]],\n",
205 | " [temporl_mean[:, 0], temporl_mean[:, 0]],\n",
206 | " [temporl_mean[:, 0], temporl_mean[:, 0]]]\n",
207 | " results['f-DDPG'] = [[figar_mean[:, 1], figar_stdev[:, 1]],\n",
208 | " [figar_mean[:, 3], figar_stdev[:, 3]],\n",
209 | " [figar_mean[:, 2], figar_stdev[:, 2]],\n",
210 | " [figar_mean[:, 0], figar_mean[:, 0]],\n",
211 | " [figar_mean[:, 0], figar_mean[:, 0]]]\n",
212 | " print(min(min(results['DDPG'][0][0]), min(results['t-DDPG'][0][0]), min(results['f-DDPG'][0][0])),\n",
213 | " max(max(results['DDPG'][0][0]), max(results['t-DDPG'][0][0]), max(results['f-DDPG'][0][0])))\n",
214 | " print(' DDPG AUC:', np.mean((results['DDPG'][0][0] + 1800) / (-145 + 1800)))\n",
215 | " print('t-DDPG AUC:', np.mean((results['t-DDPG'][0][0] + 1800) / (-145 + 1800)))\n",
216 | " print(' FiGAR AUC:', np.mean((results['f-DDPG'][0][0] + 1800) / (-145 + 1800)))\n",
217 | " plotMultiple(results, title=r'Pendulum-v0 -- $\\mathcal{J}=' + f'{max_len}$',\n",
218 | " smooth=0, ylim=[-1800, -50], min_steps=10, max_steps=210, xlim=[10**3, 3*10**4],\n",
219 | " lenyticks=[50, 125, 200], rewyticks=[-1800, -1000, -200],\n",
220 | " savename=f'ddpg_{max_len}.pdf')"
221 | ]
222 | }
223 | ],
224 | "metadata": {
225 | "kernelspec": {
226 | "display_name": "Python 3",
227 | "language": "python",
228 | "name": "python3"
229 | },
230 | "language_info": {
231 | "codemirror_mode": {
232 | "name": "ipython",
233 | "version": 3
234 | },
235 | "file_extension": ".py",
236 | "mimetype": "text/x-python",
237 | "name": "python",
238 | "nbconvert_exporter": "python",
239 | "pygments_lexer": "ipython3",
240 | "version": "3.7.7"
241 | }
242 | },
243 | "nbformat": 4,
244 | "nbformat_minor": 4
245 | }
246 |
--------------------------------------------------------------------------------
/plot_featurized_results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "from utils.plotting import get_colors, load_config, plot\n",
12 | "from utils.data_handling import load_dqn_data\n",
13 | "import numpy as np"
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {},
19 | "source": [
20 | "#### Name explanations\n",
21 | "* DQN -> standard DQN\n",
22 | "* DAR_min^max -> Dynamic action repetition with small repetition and long repetition values\n",
23 | "* tqn -> TempoRL DQN with separate skip-DQN that expects the behaviour action to be concatenated to the state\n",
24 | "* t-dqn -> TempoRL DQN with separate skip-DQN that expects the behaviour action as contextual input\n",
25 | "* tdqn -> TempoRL DQN with shared state representation between the behavoiur and skip action outputs."
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "import json\n",
35 | "import glob\n",
36 | "import os\n",
37 | "import pandas as pd\n",
38 | "from matplotlib import pyplot as plt\n",
39 | "import seaborn as sb\n",
40 | "\n",
41 | "from scipy.signal import savgol_filter\n",
42 | " \n",
43 | "\n",
44 | "# Somehow the plotting functionallity I ended up with was already covered for the tabular case.\n",
45 | "# I should have just used the plot function from that.\n",
46 | "def plotMultiple(data, ylim=None, title='', logStepY=False, max_steps=200, xlim=None, figsize=None,\n",
47 | " alphas=None, smooth=5, savename=None, rewyticks=None, lenyticks=None,\n",
48 | " skip_stdevs=[], dont_label=[], dont_plot=[]):\n",
49 | " \"\"\"\n",
50 | " Simple plotting method that shows the test reward on the y-axis and the number of performed training steps\n",
51 | " on the x-axis.\n",
52 | " \n",
53 | " data -> (dict[agent name] -> list([rewards, lens, decs, train_steps, train_episodes])) the data to plot\n",
54 | " ylim -> (list) y-axis limit\n",
55 | " title -> (str) title on top of plot\n",
56 | " logStepY -> (bool) flag that indicates if the y-axis should be on log scale.\n",
57 | " max_steps -> (int) maximal episode length\n",
58 | " xlim -> (list) x-axis limits\n",
59 | " figsize -> (list) dimensions of the figure\n",
60 | " alphas -> (dict[agent name] -> float) the alpha value to use for plotting of specific agents\n",
61 | " smooth -> (int) the window size for smoothing (has to be odd if used. < 0 deactivates this option)\n",
62 | " savename -> (str) filename to save the figure\n",
63 | " rewyticks -> (list) yticks for the reward plot\n",
64 | " lenyticks -> (list) yticks for the decisions plot\n",
65 | " skip_sdevs -> (list) list of names to not plot standard deviations for.\n",
66 | " dont_label -> (list) list of names to not label.\n",
67 | " dont_plot -> (list) list of names to not plot.\n",
68 | " \"\"\"\n",
69 | " \n",
70 | " if smooth and smooth > 0:\n",
71 | " degree = 2\n",
72 | " for agent in data:\n",
73 | " data[agent] = list(data[agent]) # we have to convert the tuple to lists\n",
74 | " data[agent][0] = list(data[agent][0])\n",
75 | " data[agent][0][0] = savgol_filter(data[agent][0][0], smooth, degree) # smooth the mean reward\n",
76 | " data[agent][0][1] = savgol_filter(data[agent][0][1], smooth, degree) # smooth the stdev reward\n",
77 | " data[agent][1] = list(data[agent][1])\n",
78 | " data[agent][1][0] = savgol_filter(data[agent][1][0], smooth, degree) # smooth mean num steps\n",
79 | " data[agent][1][1] = savgol_filter(data[agent][1][1], smooth, degree)\n",
80 | " data[agent][2] = list(data[agent][2])\n",
81 | " data[agent][2][0] = savgol_filter(data[agent][2][0], smooth, degree) # smooth mean decisions\n",
82 | " data[agent][2][1] = savgol_filter(data[agent][2][1], smooth, degree)\n",
83 | "\n",
84 | " colors, color_map = get_colors()\n",
85 | " \n",
86 | "\n",
87 | " cfg = load_config()\n",
88 | " sb.set_style(cfg['plotting']['seaborn']['style'])\n",
89 | " sb.set_context(cfg['plotting']['seaborn']['context']['context'],\n",
90 | " font_scale=cfg['plotting']['seaborn']['context']['font scale'],\n",
91 | " rc=cfg['plotting']['seaborn']['context']['rc2'])\n",
92 | "\n",
93 | " if figsize:\n",
94 | " fig, ax = plt.subplots(2, figsize=figsize, dpi=100, sharex=True)\n",
95 | " else:\n",
96 | " fig, ax = plt.subplots(2, figsize=(20, 10), dpi=100,sharex=True)\n",
97 | " ax[0].set_title(title)\n",
98 | "\n",
99 | " for agent in list(data.keys())[::-1]:\n",
100 | " if agent in dont_plot:\n",
101 | " continue\n",
102 | " try:\n",
103 | " alph = alphas[agent]\n",
104 | " except:\n",
105 | " alph = 1.\n",
106 | " color_name = None\n",
107 | " if 'dar' in agent:\n",
108 | " color_name = color_map['dar']\n",
109 | " elif agent.startswith('t'):\n",
110 | " color_name = color_map['t-dqn']\n",
111 | " else:\n",
112 | " color_name = color_map[agent]\n",
113 | " rew, lens, decs, train_steps, train_eps = data[agent]\n",
114 | " \n",
115 | " label = agent.upper()\n",
116 | " if agent.startswith('t'):\n",
117 | " label = 't-DQN'\n",
118 | " elif agent in dont_label:\n",
119 | " label = None\n",
120 | "\n",
121 | " #### Plot rewards\n",
122 | " ax[0].step(train_steps[0][::5], rew[0][::5], where='post', c=colors[color_name], label=label,\n",
123 | " alpha=alph)\n",
124 | " if agent not in skip_stdevs:\n",
125 | " ax[0].fill_between(train_steps[0][::5], rew[0][::5]-rew[1][::5], rew[0][::5]+rew[1][::5], alpha=0.25 * alph, step='post',\n",
126 | " color=colors[color_name])\n",
127 | " #### Plot lens\n",
128 | " ax[1].step(train_steps[0], decs[0], where='post', c=np.array(colors[color_name]), ls='-',\n",
129 | " alpha=alph)\n",
130 | " if agent not in skip_stdevs:\n",
131 | " ax[1].fill_between(train_steps[0][::5], decs[0][::5]-decs[1][::5], decs[0][::5]+decs[1][::5], alpha=0.125 * alph, step='post',\n",
132 | " color=np.array(colors[color_name]))\n",
133 | " ax[1].step(train_steps[0][::5], lens[0][::5], where='post',\n",
134 | " c=np.array(colors[color_name]) * .75, alpha=alph, ls=':')\n",
135 | " \n",
136 | " if agent not in skip_stdevs:\n",
137 | " ax[1].fill_between(train_steps[0][::5], lens[0][::5]-lens[1][::5], lens[0][::5]+lens[1][::5], alpha=0.25 * alph, step='post',\n",
138 | " color=np.array(colors[color_name]) * .75)\n",
139 | " ax[0].semilogx()\n",
140 | " if rewyticks is not None:\n",
141 | " ax[0].set_yticks(rewyticks)\n",
142 | " if ylim:\n",
143 | " ax[0].set_ylim(ylim)\n",
144 | " if xlim:\n",
145 | " ax[0].set_xlim(xlim)\n",
146 | " ax[0].set_ylabel('Reward')\n",
147 | " if len(data) - len(dont_label) < 5:\n",
148 | " ax[0].legend(ncol=1, loc='best', handlelength=.75)\n",
149 | " ax[1].semilogx()\n",
150 | " if logStepY:\n",
151 | " ax[1].semilogy()\n",
152 | " \n",
153 | " ax[1].plot([-999, -999], [-999, -999], ls=':', c='k', label='all')\n",
154 | " ax[1].plot([-999, -999], [-999, -999], ls='-', c='k', label='dec')\n",
155 | " ax[1].legend(loc='best', ncol=1, handlelength=.75)\n",
156 | " ax[1].set_ylim([1, max_steps])\n",
157 | " if xlim:\n",
158 | " ax[1].set_xlim(xlim)\n",
159 | " ax[1].set_ylabel('#Actions')\n",
160 | " ax[1].set_xlabel('#Train Steps')\n",
161 | " if lenyticks is not None:\n",
162 | " ax[1].set_yticks(lenyticks)\n",
163 | " plt.tight_layout()\n",
164 | " if savename:\n",
165 | " plt.savefig(savename)\n",
166 | "\n",
167 | " plt.show()\n",
168 | "\n",
169 | "\n",
170 | "def get_best_to_plot(data, aucs, tempoRL=None):\n",
171 | " \"\"\"\n",
172 | " Simple method to filter which lines to plot.\n",
173 | " \"\"\"\n",
174 | " to_plot = dict()\n",
175 | "\n",
176 | " if tempoRL is None:\n",
177 | " aucs = list(sorted(aucs, key=lambda x: x[1], reverse=True))\n",
178 | " for idx, auc in enumerate(aucs):\n",
179 | " if 't' in auc[0]:\n",
180 | " break\n",
181 | " to_plot[aucs[idx][0]] = data[aucs[idx][0]] # the absolute best\n",
182 | " else:\n",
183 | " to_plot[tempoRL] = data[tempoRL]\n",
184 | "\n",
185 | " bv = -np.inf\n",
186 | " b = None\n",
187 | " for elem in aucs:\n",
188 | " if 'dar' not in elem[0]:\n",
189 | " continue\n",
190 | " elif elem[1] > bv:\n",
191 | " b, bv = elem[0], elem[1]\n",
192 | " to_plot[b] = data[b]\n",
193 | " \n",
194 | " \n",
195 | " to_plot['dqn'] = data['dqn']\n",
196 | " print('Only plotting:', list(to_plot.keys()))\n",
197 | " return to_plot"
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 | "
"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {
211 | "scrolled": false
212 | },
213 | "outputs": [],
214 | "source": [
215 | "mountain_sparse_data = {}\n",
216 | "mountain_sparse_alphas = {}\n",
217 | "mountain_sparse_aucs = []\n",
218 | "max_steps=10**6\n",
219 | "thresh = -110\n",
220 | "\n",
221 | "for pairs in [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9)]:\n",
222 | " dar_fm_str = r'$dar_{' + '{}'.format(pairs[0] + 1) + '}^{' + '{}'.format(pairs[1] + 1) + '}$'\n",
223 | " mountain_sparse_alphas[dar_fm_str] = 1/5\n",
224 | " mountain_sparse_data[dar_fm_str] = load_dqn_data(\n",
225 | " '*', 'experiments/featurized_results/sparsemountain/dar_orig_%d_%d' % (pairs[0], pairs[1]), max_steps=max_steps,\n",
226 | " )\n",
227 | " try:\n",
228 | " mountain_sparse_aucs.append([dar_fm_str, np.trapz((mountain_sparse_data[dar_fm_str][0][0] + 200)/110,\n",
229 | " x=(mountain_sparse_data[dar_fm_str][3][0]/max(\n",
230 | " mountain_sparse_data[dar_fm_str][3][0])))])\n",
231 | " except:\n",
232 | " del mountain_sparse_data[dar_fm_str]\n",
233 | "\n",
234 | "\n",
235 | "\n",
236 | "mountain_sparse_data['dqn'] = load_dqn_data('*', 'experiments/featurized_results/sparsemountain/dqn', max_steps=max_steps,\n",
237 | " )\n",
238 | "mountain_sparse_aucs.append(['dqn', np.trapz((mountain_sparse_data['dqn'][0][0] + 200)/110,\n",
239 | " x=(mountain_sparse_data['dqn'][3][0]/max(\n",
240 | " mountain_sparse_data['dqn'][3][0])))])\n",
241 | "\n",
242 | "for i in [2, 4, 6, 8, 10]:\n",
243 | " mountain_sparse_data['tqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/sparsemountain/tqn_%d' % i, max_steps=max_steps,\n",
244 | " )\n",
245 | " mountain_sparse_aucs.append(['tqn_%d' % i, np.trapz((mountain_sparse_data['tqn_%d' % i][0][0] + 200)/110,\n",
246 | " x=(mountain_sparse_data['tqn_%d' % i][3][0]/max(\n",
247 | " mountain_sparse_data['tqn_%d' % i][3][0])))])\n",
248 | "\n",
249 | " mountain_sparse_data['t-dqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/sparsemountain/t-dqn_%d' % i,\n",
250 | " max_steps=max_steps,\n",
251 | " )\n",
252 | " mountain_sparse_aucs.append(['t-dqn_%d' % i, np.trapz((mountain_sparse_data['t-dqn_%d' % i][0][0] + 200)/110,\n",
253 | " x=(mountain_sparse_data['t-dqn_%d' % i][3][0]/max(\n",
254 | " mountain_sparse_data['t-dqn_%d' % i][3][0])))])\n",
255 | "\n",
256 | " mountain_sparse_data['tdqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/sparsemountain/tdqn_%d' % i, max_steps=max_steps,\n",
257 | " )\n",
258 | " mountain_sparse_aucs.append(['tdqn_%d' % i, np.trapz((mountain_sparse_data['tdqn_%d' % i][0][0] + 200)/110,\n",
259 | " x=(mountain_sparse_data['tdqn_%d' % i][3][0]/max(\n",
260 | " mountain_sparse_data['tdqn_%d' % i][3][0])))])"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "metadata": {
267 | "scrolled": false
268 | },
269 | "outputs": [],
270 | "source": [
271 | "mountain_sparse_plot = get_best_to_plot(mountain_sparse_data, mountain_sparse_aucs)\n",
272 | "\n",
273 | "plotMultiple(mountain_sparse_plot, title='MountainCar-v0',\n",
274 | " ylim=[-200, -100], max_steps=200, xlim=[10**3, 10**6], smooth=11,\n",
275 | " savename='mcv0-sparse.pdf', rewyticks=[-190, -150, -110], lenyticks=[0, 75, 150])"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": null,
281 | "metadata": {},
282 | "outputs": [],
283 | "source": [
284 | "list(sorted(mountain_sparse_aucs, key=lambda x: x[1], reverse=True))"
285 | ]
286 | },
287 | {
288 | "cell_type": "markdown",
289 | "metadata": {},
290 | "source": [
291 | "
"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": null,
297 | "metadata": {
298 | "scrolled": false
299 | },
300 | "outputs": [],
301 | "source": [
302 | "moon_dense_data = {}\n",
303 | "moon_dense_alphas = {}\n",
304 | "moon_dense_aucs = []\n",
305 | "max_steps=10**6\n",
306 | "thresh=200\n",
307 | "\n",
308 | "for pairs in [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9)]:\n",
309 | " dar_fm_str = r'$dar_{' + '{}'.format(pairs[0] + 1) + '}^{' + '{}'.format(pairs[1] + 1) + '}$'\n",
310 | " moon_dense_alphas[dar_fm_str] = 1/5\n",
311 | " moon_dense_data[dar_fm_str] = load_dqn_data(\n",
312 | " '*', 'experiments/featurized_results/moon/dar_orig_%d_%d' % (pairs[0], pairs[1]), max_steps=max_steps,\n",
313 | " )\n",
314 | " moon_dense_aucs.append([dar_fm_str, np.trapz((moon_dense_data[dar_fm_str][0][0] + 250) / 500,\n",
315 | " x=(moon_dense_data[dar_fm_str][3][0]/max(\n",
316 | " moon_dense_data[dar_fm_str][3][0])))])\n",
317 | "\n",
318 | " \n",
319 | "moon_dense_data['dqn'] = load_dqn_data('*', 'experiments/featurized_results/moon/dqn', max_steps=max_steps,\n",
320 | " )\n",
321 | "moon_dense_aucs.append(['dqn', np.trapz((moon_dense_data['dqn'][0][0] + 250) / 500,\n",
322 | " x=(moon_dense_data['dqn'][3][0]/max(moon_dense_data['dqn'][3][0])))])\n",
323 | "\n",
324 | "\n",
325 | "for i in [2, 4, 6, 8, 10]:\n",
326 | " moon_dense_data['tqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/moon/tqn_%d' % i, max_steps=max_steps,\n",
327 | " )\n",
328 | " # compute normalized AUC\n",
329 | " moon_dense_aucs.append(['tqn_%d' % i, np.trapz((moon_dense_data['tqn_%d' % i][0][0] + 250) / 500,\n",
330 | " x=(moon_dense_data['tqn_%d' % i][3][0]/max(moon_dense_data['tqn_%d' % i][3][0])))])\n",
331 | "\n",
332 | " moon_dense_data['t-dqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/moon/t-dqn_%d' % i, max_steps=max_steps,\n",
333 | " )\n",
334 | " moon_dense_aucs.append(['t-dqn_%d' % i, np.trapz((moon_dense_data['t-dqn_%d' % i][0][0] + 250)/500,\n",
335 | " x=(moon_dense_data['t-dqn_%d' % i][3][0]/max(moon_dense_data['t-dqn_%d' % i][3][0])))])\n",
336 | "\n",
337 | " moon_dense_data['tdqn_%d' % i] = load_dqn_data('*', 'experiments/featurized_results/moon/tdqn_%d' % i, max_steps=max_steps,\n",
338 | " )\n",
339 | " moon_dense_aucs.append(['tdqn_%d' % i, np.trapz((moon_dense_data['tdqn_%d' % i][0][0] + 250)/500,\n",
340 | " x=(moon_dense_data['tdqn_%d' % i][3][0]/max(moon_dense_data['tdqn_%d' % i][3][0])))])"
341 | ]
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": null,
346 | "metadata": {},
347 | "outputs": [],
348 | "source": [
349 | "moon_plot = get_best_to_plot(moon_dense_data, moon_dense_aucs)\n",
350 | "\n",
351 | "plotMultiple(moon_plot, title='LunarLander-v2', ylim=[-250, 200], max_steps=1000, xlim=[10**3, 10**6],\n",
352 | " smooth=11, savename='llv2-dense.pdf', rewyticks=[-250, 0, 200], lenyticks=[200, 500, 800])"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": null,
358 | "metadata": {},
359 | "outputs": [],
360 | "source": [
361 | "list(sorted(moon_dense_aucs, key=lambda x: x[1], reverse=True))"
362 | ]
363 | }
364 | ],
365 | "metadata": {
366 | "kernelspec": {
367 | "display_name": "Python 3",
368 | "language": "python",
369 | "name": "python3"
370 | },
371 | "language_info": {
372 | "codemirror_mode": {
373 | "name": "ipython",
374 | "version": 3
375 | },
376 | "file_extension": ".py",
377 | "mimetype": "text/x-python",
378 | "name": "python",
379 | "nbconvert_exporter": "python",
380 | "pygments_lexer": "ipython3",
381 | "version": "3.7.7"
382 | }
383 | },
384 | "nbformat": 4,
385 | "nbformat_minor": 4
386 | }
387 |
--------------------------------------------------------------------------------
/plot_tabular_results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "from utils.data_handling import *\n",
12 | "from utils.plotting import *"
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "metadata": {},
18 | "source": [
19 | "
\n",
20 | "\n",
21 | "# Cliff"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {
28 | "scrolled": false
29 | },
30 | "outputs": [],
31 | "source": [
32 | "max_skip = 7\n",
33 | "episodes = 10_000\n",
34 | "max_steps = 100\n",
35 | "\n",
36 | "for exp_version in ['-1.0-linear', '-0.1-const', '-1.0-log']:\n",
37 | " print(exp_version[1:].upper())\n",
38 | " methods = ['q', 'sq']\n",
39 | " rews, lens, steps = load_data(\"experiments/tabular_results/cliff\", methods, exp_version,\n",
40 | " episodes, max_skip, max_steps, local=True)\n",
41 | " for m in methods:\n",
42 | " print(len(rews[m]))\n",
43 | " title = '{:s}'.format(\"Cliff\")\n",
44 | " plot(methods, rews, lens, steps, title, episodes, logrewy=False,\n",
45 | " logleny=False, logx=True, annotate=True, savefig=\"cliff{:s}.pdf\".format(exp_version.replace('.', '_')),\n",
46 | " individual=False)"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "
\n",
54 | "\n",
55 | "# Bridge"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {
62 | "scrolled": false
63 | },
64 | "outputs": [],
65 | "source": [
66 | "max_skip = 7\n",
67 | "episodes = 10_000\n",
68 | "max_steps = 100\n",
69 | "\n",
70 | "for exp_version in ['-1.0-linear', '-0.1-const', '-1.0-log']:\n",
71 | " print(exp_version[1:].upper())\n",
72 | " methods = ['q', 'sq']\n",
73 | " rews, lens, steps = load_data(\"experiments/tabular_results/bridge\", methods, exp_version,\n",
74 | " episodes, max_skip, max_steps, local=True)\n",
75 | " for m in methods:\n",
76 | " print(len(rews[m]))\n",
77 | " title = '{:s}'.format(\"Bridge\")\n",
78 | " plot(methods, rews, lens, steps, title, episodes, annotate=True,\n",
79 | " logrewy=False, logleny=False, savefig=\"bridge{:s}.pdf\".format(exp_version.replace('.', '_')))"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {},
85 | "source": [
86 | "
\n",
87 | "\n",
88 | "# ZigZag"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {
95 | "scrolled": false
96 | },
97 | "outputs": [],
98 | "source": [
99 | "max_skip = 7\n",
100 | "episodes = 10_000\n",
101 | "max_steps = 100\n",
102 | "\n",
103 | "for exp_version in ['-1.0-linear', '-0.1-const', '-1.0-log']:\n",
104 | " print(exp_version[1:].upper())\n",
105 | " methods = ['q', 'sq']\n",
106 | " rews, lens, steps = load_data(\"experiments/tabular_results/zigzag\", methods, exp_version,\n",
107 | " episodes, max_skip, max_steps, local=True)\n",
108 | " for m in methods:\n",
109 | " print(len(rews[m]))\n",
110 | " title = '{:s}'.format(\"ZigZag\")\n",
111 | " plot(methods, rews, lens, steps, title, episodes, annotate=True,\n",
112 | " logrewy=False, logleny=False, savefig=\"zigzag{:s}.pdf\".format(exp_version.replace('.', '_')))"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "metadata": {},
118 | "source": [
119 | "
\n",
120 | "\n",
121 | " \n",
122 | "# Influence of skip-lenth on tempoRL"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "metadata": {},
128 | "source": [
129 | "## ZigZag - Linear"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {
136 | "scrolled": false
137 | },
138 | "outputs": [],
139 | "source": [
140 | "max_skip = 2\n",
141 | "episodes = 10_000\n",
142 | "max_steps = 100\n",
143 | "\n",
144 | "for exp_version in ['-1.0-linear']:\n",
145 | " print(exp_version[1:].upper())\n",
146 | " for max_skip in range(2, 17):\n",
147 | " rews, lens, steps = load_data(\"experiments/tabular_results/j_ablation/zigzag\", ['sq'], exp_version,\n",
148 | " episodes, max_skip, max_steps, local=True)\n",
149 | " # the q we compare to is the same as in standard zigzag. max_skip does not influence q\n",
150 | " rews_, lens_, steps_ = load_data(\"experiments/tabular_results/zigzag\", ['q'], exp_version,\n",
151 | " episodes, 7, max_steps, local=True)\n",
152 | " rews.update(rews_)\n",
153 | " lens.update(lens_)\n",
154 | " steps.update(steps_)\n",
155 | " for m in methods:\n",
156 | " print(len(rews[m]))\n",
157 | " title = '{:s} - {:d}'.format(\"ZigZag\", max_skip)\n",
158 | " plot(['sq', 'q'], rews, lens, steps, title, episodes, annotate=False,\n",
159 | " logrewy=False, logleny=False, savefig=None)"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | "## ZigZag - Log"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {
173 | "scrolled": false
174 | },
175 | "outputs": [],
176 | "source": [
177 | "max_skip = 2\n",
178 | "episodes = 10_000\n",
179 | "max_steps = 100\n",
180 | "\n",
181 | "for exp_version in ['-1.0-log']:\n",
182 | " print(exp_version[1:].upper())\n",
183 | " for max_skip in range(2, 17):\n",
184 | " rews, lens, steps = load_data(\"experiments/tabular_results/j_ablation/zigzag\", ['sq'], exp_version,\n",
185 | " episodes, max_skip, max_steps, local=True)\n",
186 | " rews_, lens_, steps_ = load_data(\"experiments/tabular_results/zigzag\", ['q'], exp_version,\n",
187 | " episodes, 7, max_steps, local=True)\n",
188 | " rews.update(rews_)\n",
189 | " lens.update(lens_)\n",
190 | " steps.update(steps_)\n",
191 | " for m in methods:\n",
192 | " print(len(rews[m]))\n",
193 | " title = '{:s} - {:d}'.format(\"ZigZag\", max_skip)\n",
194 | " plot(['sq', 'q'], rews, lens, steps, title, episodes, annotate=False,\n",
195 | " logrewy=False, logleny=False, savefig=None)"
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {},
201 | "source": [
202 | "## ZigZag - Constant"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": null,
208 | "metadata": {
209 | "scrolled": false
210 | },
211 | "outputs": [],
212 | "source": [
213 | "max_skip = 2\n",
214 | "episodes = 10_000\n",
215 | "max_steps = 100\n",
216 | "\n",
217 | "for exp_version in ['-0.1-const']:\n",
218 | " print(exp_version[1:].upper())\n",
219 | " for max_skip in range(2, 17):\n",
220 | " rews, lens, steps = load_data(\"experiments/tabular_results/j_ablation/zigzag\", ['sq'], exp_version,\n",
221 | " episodes, max_skip, max_steps, local=True)\n",
222 | " rews_, lens_, steps_ = load_data(\"experiments/tabular_results/zigzag\", ['q'], exp_version,\n",
223 | " episodes, 7, max_steps, local=True)\n",
224 | " rews.update(rews_)\n",
225 | " lens.update(lens_)\n",
226 | " steps.update(steps_)\n",
227 | " for m in methods:\n",
228 | " print(len(rews[m]))\n",
229 | " title = '{:s} - {:d}'.format(\"ZigZag\", max_skip)\n",
230 | " plot(['sq', 'q'], rews, lens, steps, title, episodes, annotate=False,\n",
231 | " logrewy=False, logleny=False, savefig=None)"
232 | ]
233 | }
234 | ],
235 | "metadata": {
236 | "kernelspec": {
237 | "display_name": "Python 3",
238 | "language": "python",
239 | "name": "python3"
240 | },
241 | "language_info": {
242 | "codemirror_mode": {
243 | "name": "ipython",
244 | "version": 3
245 | },
246 | "file_extension": ".py",
247 | "mimetype": "text/x-python",
248 | "name": "python",
249 | "nbconvert_exporter": "python",
250 | "pygments_lexer": "ipython3",
251 | "version": "3.7.7"
252 | }
253 | },
254 | "nbformat": 4,
255 | "nbformat_minor": 2
256 | }
257 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyyaml~=5.4
2 | seaborn~=0.10.1
3 | matplotlib~=3.3.0
4 | gym~=0.17.2
5 | numpy~=1.19.1
6 | scipy~=1.5.2
7 | torchvision~=0.5.0
8 | pillow~=7.2.0
9 | ray[rllib]~=1.0.1
10 | opencv-python~=4.4.0.46
11 | future~=0.18.2
12 | torch~=1.4.0
13 |
--------------------------------------------------------------------------------
/run_atari_experiments.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import numpy as np
5 | import gym
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | from torch.autograd import Variable
11 | from itertools import count
12 | from collections import namedtuple
13 | import time
14 | from mountain_car import MountainCarEnv
15 | from utils import experiments
16 |
17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18 | device = torch.device('cpu')  # force CPU; remove this line to fall back to the CUDA device selected above
19 |
20 |
21 | def tt(ndarray):
22 | """
23 | Helper Function to cast observation to correct type/device
24 | """
25 | if device.type == "cuda":
26 | return Variable(torch.from_numpy(ndarray).float().cuda(), requires_grad=False)
27 | else:
28 | return Variable(torch.from_numpy(ndarray).float(), requires_grad=False)
29 |
30 |
31 | def soft_update(target, source, tau):
32 | """
33 | Simple Helper for updating target-network parameters
34 | :param target: target network
35 | :param source: policy network
36 | :param tau: weight to regulate how strongly to update (1 -> copy over weights)
37 | """
38 | for target_param, param in zip(target.parameters(), source.parameters()):
39 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
40 |
41 |
42 | def hard_update(target, source):
43 | """
44 | See soft_update
45 | """
46 | soft_update(target, source, 1.0)
47 |
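# Illustrative usage sketch (assumed variable names, not part of the original script):
# a target network is usually initialised as an exact copy of the policy network and
# afterwards tracked slowly with a small tau, e.g.
#     hard_update(target_q, policy_q)         # once, right after constructing both networks
#     soft_update(target_q, policy_q, 0.01)   # during training; tau=0.01 is an assumed value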
48 |
49 | class NatureDQN(nn.Module):
50 | """
51 | DQN following the DQN implementation from
52 | https://storage.googleapis.com/deepmind-data/assets/papers/DeepMindNature14236Paper.pdf
53 | """
54 |
55 | def __init__(self, in_channels=4, num_actions=18):
56 | """
57 |         :param in_channels: number of input channels (i.e. how many stacked frames are used)
58 |         :param num_actions: number of discrete actions (size of the Q-value output)
59 | """
60 | super(NatureDQN, self).__init__()
61 | if env.observation_space.shape[-1] == 84: #hack
62 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
63 | elif env.observation_space.shape[-1] == 42: #hack
64 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=4, stride=2)
65 | else:
66 | raise ValueError("Check state space dimensionality. Expected nx42x42 or nx84x84. Was:", env.observation_space.shape)
67 | self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
68 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
69 | self.fc4 = nn.Linear(7 * 7 * 64, 512)
70 | self.fc5 = nn.Linear(512, num_actions)
71 |
72 | def forward(self, x):
73 | x = F.relu(self.conv1(x))
74 | x = F.relu(self.conv2(x))
75 | x = F.relu(self.conv3(x))
76 | x = F.relu(self.fc4(x.reshape(x.size(0), -1)))
77 | return self.fc5(x)
78 |
79 |
80 | class NatureTQN(nn.Module):
81 | """
82 | Network to learn the skip behaviour using the same architecture as the original DQN but with additional context.
83 | The context is expected to be the chosen behaviour action on which the skip-Q is conditioned.
84 |
85 | This Q function is expected to be used solely to learn the skip-Q function
86 | """
87 |
88 | def __init__(self, in_channels=4, num_actions=18):
89 | """
90 |         :param in_channels: number of input channels (i.e. how many stacked frames are used)
91 |         :param num_actions: number of outputs (skip values when this network is used as the skip-Q)
92 | """
93 | super(NatureTQN, self).__init__()
94 | if env.observation_space.shape[-1] == 84: # hack
95 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
96 | elif env.observation_space.shape[-1] == 42: # hack
97 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=4, stride=2)
98 | else:
99 | raise ValueError("Check state space dimensionality. Expected nx42x42 or nx84x84. Was:",
100 | env.observation_space.shape)
101 | self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
102 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
103 |
104 | self.skip = nn.Linear(1, 10) # Context layer
105 |
106 | self.fc4 = nn.Linear(7 * 7 * 64 + 10, 512) # Combination layer
107 | self.fc5 = nn.Linear(512, num_actions) # Output
108 |
109 | def forward(self, x, action_val=None):
110 | # Process input image
111 | x = F.relu(self.conv1(x))
112 | x = F.relu(self.conv2(x))
113 | x = F.relu(self.conv3(x))
114 |
115 | # Process behaviour context
116 | x_ = F.relu(self.skip(action_val))
117 |
118 | # Combine both streams
119 | x = F.relu(self.fc4(
120 | torch.cat([x.reshape(x.size(0), -1), x_], 1))) # This layer concatenates the context and CNN part
121 | return self.fc5(x)
122 |
123 |
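# Illustrative usage sketch (assumed shapes and values, not part of the original script):
# the skip-Q above is queried with the already chosen behaviour action as context.
# For a batch of stacked frames `states` of shape (B, 4, 84, 84) and chosen behaviour
# actions `actions` of shape (B,), a forward pass could look like
#     skip_q = NatureTQN(in_channels=4, num_actions=10)  # 10 possible skip lengths is an assumption
#     skip_values = skip_q(tt(states), action_val=tt(actions[:, None].astype(np.float32)))
# where the (B, 1)-shaped action_val matches the nn.Linear(1, 10) context layer.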
124 | class NatureWeightsharingTQN(nn.Module):
125 | """
126 | Network to learn the skip behaviour using the same architecture as the original DQN but with additional context.
127 | The context is expected to be the chosen behaviour action on which the skip-Q is conditioned.
128 |     This implementation allows the behaviour network and the skip network to share weights
129 | """
130 |
131 | def __init__(self, in_channels=4, num_actions=18, num_skip_actions=10):
132 | """
133 |         :param in_channels: number of input channels (i.e. how many stacked frames are used)
134 |         :param num_actions: number of discrete behaviour actions
135 | """
136 | super(NatureWeightsharingTQN, self).__init__()
137 | # shared input-layers
138 | if env.observation_space.shape[-1] == 84: #hack
139 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
140 | elif env.observation_space.shape[-1] == 42: #hack
141 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=4, stride=2)
142 | else:
143 | raise ValueError("Check state space dimensionality. Expected nx42x42 or nx84x84. Was:", env.observation_space.shape)
144 | self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
145 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
146 |
147 | # skip-layers
148 | self.skip = nn.Linear(1, 10) # Context layer
149 | self.skip_fc4 = nn.Linear(7 * 7 * 64 + 10, 512)
150 | self.skip_fc5 = nn.Linear(512, num_skip_actions)
151 |
152 | # behaviour-layers
153 | self.action_fc4 = nn.Linear(7 * 7 * 64, 512)
154 | self.action_fc5 = nn.Linear(512, num_actions)
155 |
156 | def forward(self, x, action_val=None):
157 | # Process input image
158 | x = F.relu(self.conv1(x))
159 | x = F.relu(self.conv2(x))
160 | x = F.relu(self.conv3(x))
161 |
162 | if action_val is not None: # Only if an action_value was provided we evaluate the skip output layers Q(s,j|a)
163 | x_ = F.relu(self.skip(action_val))
164 | x = F.relu(self.skip_fc4(
165 | torch.cat([x.reshape(x.size(0), -1), x_], 1))) # This layer concatenates the context and CNN part
166 | return self.skip_fc5(x)
167 | else: # otherwise we simply continue as in standard DQN and compute Q(s,a)
168 | x = F.relu(self.action_fc4(x.reshape(x.size(0), -1)))
169 | return self.action_fc5(x)
170 |
171 |
172 | class Q(nn.Module):
173 | """
174 | Simple fully connected Q function. Also used for skip-Q when concatenating behaviour action and state together.
175 | Used for simpler environments such as mountain-car or lunar-lander.
176 | """
177 |
178 | def __init__(self, state_dim, action_dim, non_linearity=F.relu, hidden_dim=50):
179 | super(Q, self).__init__()
180 | self.fc1 = nn.Linear(state_dim, hidden_dim)
181 | self.fc2 = nn.Linear(hidden_dim, hidden_dim)
182 | self.fc3 = nn.Linear(hidden_dim, action_dim)
183 | self._non_linearity = non_linearity
184 |
185 | def forward(self, x):
186 | x = self._non_linearity(self.fc1(x))
187 | x = self._non_linearity(self.fc2(x))
188 | return self.fc3(x)
189 |
190 |
191 | class TQ(nn.Module):
192 | """
193 | Q-Function that takes the behaviour action as context.
194 |     This Q is solely intended to be used for computing the skip-Q Q(s,j|a).
195 |
196 | Basically the same architecture as Q but with context input layer.
197 | """
198 |
199 | def __init__(self, state_dim, skip_dim, non_linearity=F.relu, hidden_dim=50):
200 | super(TQ, self).__init__()
201 | self.fc1 = nn.Linear(state_dim, hidden_dim)
202 | self.skip_fc2 = nn.Linear(1, 10)
203 | self.fc2 = nn.Linear(hidden_dim, hidden_dim)
204 | self.skip_fc3 = nn.Linear(hidden_dim + 10, skip_dim) # output layer taking context and state into account
205 | self._non_linearity = non_linearity
206 |
207 | def forward(self, x, a=None):
208 | # Process the input state
209 | x = self._non_linearity(self.fc1(x))
210 | x = self._non_linearity(self.fc2(x))
211 |
212 | # Process behaviour-action as context
213 | x_ = self._non_linearity(self.skip_fc2(a))
214 | return self.skip_fc3(torch.cat([x, x_], -1)) # Concatenate both to produce the final output
215 |
216 |
217 | class WeightSharingTQ(nn.Module):
218 | """
219 | Q-function with shared state representation but two independent output streams (action, skip)
220 | """
221 |
222 | def __init__(self, state_dim, action_dim, skip_dim, non_linearity=F.relu, hidden_dim=50):
223 | super(WeightSharingTQ, self).__init__()
224 | self.fc1 = nn.Linear(state_dim, hidden_dim)
225 | self.skip_fc2 = nn.Linear(1, 10)
226 | self.fc2 = nn.Linear(hidden_dim, hidden_dim)
227 | self.action_fc3 = nn.Linear(hidden_dim, action_dim)
228 | self.skip_fc3 = nn.Linear(hidden_dim + 10, skip_dim)
229 | self._non_linearity = non_linearity
230 |
231 | def forward(self, x, a=None):
232 | # Process input state with shared layers
233 | x = self._non_linearity(self.fc1(x))
234 | x = self._non_linearity(self.fc2(x))
235 |
236 | if a is not None: # Only compute Skip Output if the behaviour action is given as context
237 | x_ = self._non_linearity(self.skip_fc2(a))
238 | return self.skip_fc3(torch.cat([x, x_], -1))
239 |
240 | # Only compute Behaviour output
241 | return self.action_fc3(x)
242 |
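# Illustrative usage sketch (assumed dimensions, not part of the original script): the
# shared network above is called in two modes. Without a context action it returns the
# behaviour Q-values Q(s, a); with the behaviour action passed as `a` it returns the
# skip Q-values Q(s, j|a):
#     q = WeightSharingTQ(state_dim=2, action_dim=3, skip_dim=10)
#     behaviour_q = q(tt(states))                     # shape (B, action_dim)
#     skip_q = q(tt(states), a=tt(actions[:, None]))  # shape (B, skip_dim)
# where `actions` holds the already chosen behaviour actions as a float array of shape (B,).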
243 |
244 | class ReplayBuffer:
245 | """
246 | Simple Replay Buffer. Used for standard DQN learning.
247 | """
248 |
249 | def __init__(self, max_size):
250 | self._data = namedtuple("ReplayBuffer", ["states", "actions", "next_states", "rewards", "terminal_flags"])
251 | self._data = self._data(states=[], actions=[], next_states=[], rewards=[], terminal_flags=[])
252 | self._size = 0
253 | self._max_size = max_size
254 |
255 | def add_transition(self, state, action, next_state, reward, done):
256 | self._data.states.append(state)
257 | self._data.actions.append(action)
258 | self._data.next_states.append(next_state)
259 | self._data.rewards.append(reward)
260 | self._data.terminal_flags.append(done)
261 | self._size += 1
262 |
263 | if self._size > self._max_size:
264 | self._data.states.pop(0)
265 | self._data.actions.pop(0)
266 | self._data.next_states.pop(0)
267 | self._data.rewards.pop(0)
268 | self._data.terminal_flags.pop(0)
269 |
270 | def random_next_batch(self, batch_size):
271 | batch_indices = np.random.choice(len(self._data.states), batch_size)
272 | batch_states = np.array([self._data.states[i] for i in batch_indices])
273 | batch_actions = np.array([self._data.actions[i] for i in batch_indices])
274 | batch_next_states = np.array([self._data.next_states[i] for i in batch_indices])
275 | batch_rewards = np.array([self._data.rewards[i] for i in batch_indices])
276 | batch_terminal_flags = np.array([self._data.terminal_flags[i] for i in batch_indices])
277 | return tt(batch_states), tt(batch_actions), tt(batch_next_states), tt(batch_rewards), tt(batch_terminal_flags)
278 |
279 |
280 | class SkipReplayBuffer:
281 | """
282 | Replay Buffer for training the skip-Q.
283 | Expects "concatenated states" which already contain the behaviour-action for the skip-Q.
284 | Stores transitions as usual but with additional skip-length. The skip-length is used to properly discount.
285 | """
286 |
287 | def __init__(self, max_size):
288 | self._data = namedtuple("ReplayBuffer", ["states", "actions", "next_states",
289 | "rewards", "terminal_flags", "lengths"])
290 | self._data = self._data(states=[], actions=[], next_states=[], rewards=[], terminal_flags=[], lengths=[])
291 | self._size = 0
292 | self._max_size = max_size
293 |
294 | def add_transition(self, state, action, next_state, reward, done, length):
295 | self._data.states.append(state)
296 | self._data.actions.append(action)
297 | self._data.next_states.append(next_state)
298 | self._data.rewards.append(reward)
299 | self._data.terminal_flags.append(done)
300 | self._data.lengths.append(length) # Observed skip-length of the transition
301 | self._size += 1
302 |
303 | if self._size > self._max_size:
304 | self._data.states.pop(0)
305 | self._data.actions.pop(0)
306 | self._data.next_states.pop(0)
307 | self._data.rewards.pop(0)
308 | self._data.terminal_flags.pop(0)
309 | self._data.lengths.pop(0)
310 |
311 | def random_next_batch(self, batch_size):
312 | batch_indices = np.random.choice(len(self._data.states), batch_size)
313 | batch_states = np.array([self._data.states[i] for i in batch_indices])
314 | batch_actions = np.array([self._data.actions[i] for i in batch_indices])
315 | batch_next_states = np.array([self._data.next_states[i] for i in batch_indices])
316 | batch_rewards = np.array([self._data.rewards[i] for i in batch_indices])
317 | batch_terminal_flags = np.array([self._data.terminal_flags[i] for i in batch_indices])
318 | batch_lengths = np.array([self._data.lengths[i] for i in batch_indices])
319 | return tt(batch_states), tt(batch_actions), tt(batch_next_states),\
320 | tt(batch_rewards), tt(batch_terminal_flags), tt(batch_lengths)
321 |
322 |
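# Illustrative sketch: why the skip buffers store a skip length. A skip transition
# carries the already-discounted sum of the intermediate rewards, and the bootstrap
# term in the skip-Q target is discounted by gamma ** length. Rewards below are
# made-up numbers; `_skip_discount_sketch` is a hypothetical helper.
def _skip_discount_sketch(gamma=0.99):
    import numpy as np
    rewards = [1.0, 0.5, 0.25]   # rewards observed while repeating one behaviour action
    length = len(rewards)        # skip length stored alongside the transition
    skip_reward = sum(np.power(gamma, i) * r for i, r in enumerate(rewards))
    # n-step target: skip_reward + gamma ** length * Q_target(next_state, argmax_a Q(next_state, a))
    return skip_reward, gamma ** length
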
323 | class NoneConcatSkipReplayBuffer:
324 | """
325 | Replay Buffer for training the skip-Q.
326 |     Expects states in which the behaviour-action is not simply concatenated for the skip-Q.
327 | Stores transitions as usual but with additional skip-length. The skip-length is used to properly discount.
328 | Additionally stores the behaviour_action which is the context for this skip-transition.
329 | """
330 |
331 | def __init__(self, max_size):
332 | self._data = namedtuple("ReplayBuffer", ["states", "actions", "next_states",
333 | "rewards", "terminal_flags", "lengths", "behaviour_action"])
334 | self._data = self._data(states=[], actions=[], next_states=[], rewards=[], terminal_flags=[], lengths=[],
335 | behaviour_action=[])
336 | self._size = 0
337 | self._max_size = max_size
338 |
339 | def add_transition(self, state, action, next_state, reward, done, length, behaviour):
340 | self._data.states.append(state)
341 | self._data.actions.append(action)
342 | self._data.next_states.append(next_state)
343 | self._data.rewards.append(reward)
344 | self._data.terminal_flags.append(done)
345 | self._data.lengths.append(length) # Observed skip-length
346 | self._data.behaviour_action.append(behaviour) # Behaviour action to condition skip on
347 | self._size += 1
348 |
349 | if self._size > self._max_size:
350 | self._data.states.pop(0)
351 | self._data.actions.pop(0)
352 | self._data.next_states.pop(0)
353 | self._data.rewards.pop(0)
354 | self._data.terminal_flags.pop(0)
355 | self._data.lengths.pop(0)
356 | self._data.behaviour_action.pop(0)
357 |
358 | def random_next_batch(self, batch_size):
359 | batch_indices = np.random.choice(len(self._data.states), batch_size)
360 | batch_states = np.array([self._data.states[i] for i in batch_indices])
361 | batch_actions = np.array([self._data.actions[i] for i in batch_indices])
362 | batch_next_states = np.array([self._data.next_states[i] for i in batch_indices])
363 | batch_rewards = np.array([self._data.rewards[i] for i in batch_indices])
364 | batch_terminal_flags = np.array([self._data.terminal_flags[i] for i in batch_indices])
365 | batch_lengths = np.array([self._data.lengths[i] for i in batch_indices])
366 |         batch_behaviours = np.array([self._data.behaviour_action[i] for i in batch_indices])
367 |         return tt(batch_states), tt(batch_actions), tt(batch_next_states),\
368 |                tt(batch_rewards), tt(batch_terminal_flags), tt(batch_lengths), tt(batch_behaviours)
369 |
370 |
371 | class DQN:
372 | """
373 | Simple double DQN Agent
374 | """
375 |
376 | def __init__(self, state_dim: int, action_dim: int, gamma: float,
377 | env: gym.Env, eval_env: gym.Env, vision: bool = False):
378 | """
379 | Initialize the DQN Agent
380 | :param state_dim: dimensionality of the input states
381 | :param action_dim: dimensionality of the output actions
382 | :param gamma: discount factor
383 | :param env: environment to train on
384 | :param eval_env: environment to evaluate on
385 | :param vision: boolean flag to indicate if the input state is an image or not
386 | """
387 | if not vision: # For featurized states
388 | self._q = Q(state_dim, action_dim).to(device)
389 | self._q_target = Q(state_dim, action_dim).to(device)
390 | else: # For image states, i.e. Atari
391 | self._q = NatureDQN(state_dim, action_dim).to(device)
392 | self._q_target = NatureDQN(state_dim, action_dim).to(device)
393 |
394 | self._gamma = gamma
395 |
396 | self.batch_size = 32
397 | self.grad_clip_val = 40.0
398 | self.target_net_upd_freq = 500
399 | self.learning_starts = 10_000
400 | self.initial_epsilon = 1.0
401 | self.final_epsilon = 0.01
402 | self.epsilon_timesteps = 200_000
403 | self.train_freq = 4
404 |
405 | self._loss_function = nn.SmoothL1Loss() # huber loss # nn.MSELoss()
406 | self._q_optimizer = optim.Adam(self._q.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08)
407 | # self._q_optimizer = optim.RMSprop(self._q.parameters(), lr=0.00025, alpha=0.01, eps=1e-08, momentum=0.95)
408 | self._action_dim = action_dim
409 |
410 | self._replay_buffer = ReplayBuffer(5e4)
411 | self._env = env
412 | self._eval_env = eval_env
413 |
414 | def get_action(self, x: np.ndarray, epsilon: float) -> int:
415 | """
416 | Simple helper to get action epsilon-greedy based on observation x
417 | """
418 | u = np.argmax(self._q(tt(x[None, :])).cpu().detach().numpy())
419 | r = np.random.uniform()
420 | if r < epsilon:
421 | return np.random.randint(self._action_dim)
422 | return u
423 |
424 | def train(self, episodes: int, max_env_time_steps: int, epsilon: float, eval_eps: int = 1,
425 | eval_every_n_steps: int = 1, max_train_time_steps: int = 1_000_000):
426 | """
427 | Training loop
428 | :param episodes: maximum number of episodes to train for
429 | :param max_env_time_steps: maximum number of steps in the environment to perform per episode
430 | :param epsilon: constant epsilon for exploration when selecting actions
431 |         :param eval_eps: number of episodes to run for evaluation
432 | :param eval_every_n_steps: interval of steps after which to evaluate the trained agent
433 | :param max_train_time_steps: maximum number of steps to train
434 | :return:
435 | """
436 | total_steps = 0
437 | num_update_steps = 0
438 | batch_size = self.batch_size
439 | grad_clip_val = self.grad_clip_val
440 | target_net_upd_freq = self.target_net_upd_freq
441 | learning_starts = self.learning_starts
442 |
443 | start_time = time.time()
444 |
445 | for e in range(episodes):
446 | print("# Episode: %s/%s" % (e + 1, episodes))
447 | s = self._env.reset()
448 |
449 | for t in range(max_env_time_steps):
450 | # s = s
451 | # s = torch.from_numpy(s).unsqueeze(0)
452 | if total_steps > self.epsilon_timesteps:
453 | epsilon = self.final_epsilon
454 | else:
455 | epsilon = self.initial_epsilon - (self.initial_epsilon - self.final_epsilon) * (
456 | total_steps / self.epsilon_timesteps)
457 |
458 | a = self.get_action(s, epsilon)
459 | ns, r, d, _ = self._env.step(a)
460 | total_steps += 1
461 |
462 | ########### Begin Evaluation
463 | if (total_steps % eval_every_n_steps) == 0:
464 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps)
465 | eval_stats = dict(
466 | elapsed_time=time.time() - start_time,
467 | training_steps=total_steps,
468 | training_eps=e,
469 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)),
470 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)),
471 | avg_rew_per_eval_ep=float(np.mean(eval_r)),
472 | std_rew_per_eval_ep=float(np.std(eval_r)),
473 | eval_eps=eval_eps
474 | )
475 |
476 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh:
477 | json.dump(eval_stats, out_fh)
478 | out_fh.write('\n')
479 | ########### End Evaluation
480 |
481 | # Update replay buffer
482 | self._replay_buffer.add_transition(s, a, ns, r, d)
483 |
484 | batch_states, batch_actions, batch_next_states, batch_rewards, batch_terminal_flags = \
485 | self._replay_buffer.random_next_batch(batch_size)
486 |
487 | ########### Begin double Q-learning update
488 | target = batch_rewards + (1 - batch_terminal_flags) * self._gamma * \
489 | self._q_target(batch_next_states)[torch.arange(batch_size).long(), torch.argmax(
490 | self._q(batch_next_states), dim=1)]
491 | current_prediction = self._q(batch_states)[torch.arange(batch_size).long(), batch_actions.long()]
492 |
493 | loss = self._loss_function(current_prediction, target.detach())
494 |
495 |
496 | if (total_steps > learning_starts) and (total_steps % self.train_freq == 0):
497 | num_update_steps += 1
498 | self._q_optimizer.zero_grad()
499 | loss.backward()
500 | for param in self._q.parameters():
501 | param.grad.data.clamp_(-grad_clip_val, grad_clip_val)
502 | self._q_optimizer.step()
503 |
504 |
505 | if (total_steps % target_net_upd_freq) == 0:
506 | hard_update(self._q_target, self._q)
507 | # soft_update(self._q_target, self._q, 0.01)
508 | ########### End double Q-learning update
509 | if d:
510 | break
511 | s = ns
512 | if total_steps >= max_train_time_steps:
513 | break
514 | if total_steps >= max_train_time_steps:
515 | break
516 |
517 | # Final evaluation
518 | if (total_steps % eval_every_n_steps) != 0:
519 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps)
520 | eval_stats = dict(
521 | elapsed_time=time.time() - start_time,
522 | training_steps=total_steps,
523 | training_eps=e,
524 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)),
525 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)),
526 | avg_rew_per_eval_ep=float(np.mean(eval_r)),
527 | std_rew_per_eval_ep=float(np.std(eval_r)),
528 | eval_eps=eval_eps
529 | )
530 |
531 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh:
532 | json.dump(eval_stats, out_fh)
533 | out_fh.write('\n')
534 |
535 | def eval(self, episodes: int, max_env_time_steps: int):
536 | """
537 | Simple method that evaluates the agent with fixed epsilon = 0
538 | :param episodes: max number of episodes to play
539 |         :param max_env_time_steps: maximum number of environment steps per evaluation episode
540 |
541 | :returns (steps per episode), (reward per episode), (decisions per episode)
542 | """
543 | steps, rewards, decisions = [], [], []
544 | with torch.no_grad():
545 | for e in range(episodes):
546 | ed, es, er = 0, 0, 0
547 |
548 | s = self._eval_env.reset()
549 | for _ in count():
550 | a = self.get_action(s, 0)
551 | ed += 1
552 |
553 | ns, r, d, _ = self._eval_env.step(a)
554 | # print(r, d)
555 | er += r
556 | es += 1
557 | if es >= max_env_time_steps or d:
558 | break
559 | s = ns
560 | steps.append(es)
561 | rewards.append(er)
562 | decisions.append(ed)
563 |
564 | return steps, rewards, decisions
565 |
566 | def save_model(self, path):
567 | torch.save(self._q.state_dict(), os.path.join(path, 'Q'))
568 |
569 |
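# Illustrative sketch: the double Q-learning target used in DQN.train() above.
# The online network selects the argmax action in the next state and the target
# network evaluates it:
#     target = r + (1 - done) * gamma * Q_target(s', argmax_a Q_online(s', a))
# Tensors below are random stand-ins; `_double_dqn_target_sketch` is hypothetical.
def _double_dqn_target_sketch(gamma=0.99):
    import torch
    batch, n_actions = 4, 3
    q_online_next = torch.randn(batch, n_actions)   # Q(s', .) from the online network
    q_target_next = torch.randn(batch, n_actions)   # Q(s', .) from the target network
    rewards = torch.randn(batch)
    dones = torch.zeros(batch)
    best_actions = torch.argmax(q_online_next, dim=1)
    target = rewards + (1 - dones) * gamma * q_target_next[torch.arange(batch), best_actions]
    return target
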
570 | class DAR:
571 | """
572 | Simple Dynamic Action Repetition Agent based on double DQN
573 | """
574 |
575 | def __init__(self, state_dim: int, action_dim: int,
576 | num_output_duplication: int, skip_map: dict,
577 | gamma: float, env: gym.Env, eval_env: gym.Env, vision=False):
578 | """
579 | Initialize the DQN Agent
580 | :param state_dim: dimensionality of the input states
581 | :param action_dim: dimensionality of the output actions
582 | :param num_output_duplication: integer that determines how often to duplicate output heads (original is 2)
583 | :param skip_map: determines the skip value associated with each output head
584 | :param gamma: discount factor
585 | :param env: environment to train on
586 | :param eval_env: environment to evaluate on
587 | """
588 | if not vision: # For featurized states
589 | self._q = Q(state_dim, action_dim * num_output_duplication).to(device)
590 | self._q_target = Q(state_dim, action_dim * num_output_duplication).to(device)
591 | else: # For image states, i.e. Atari
592 | self._q = NatureDQN(state_dim, action_dim * num_output_duplication).to(device)
593 | self._q_target = NatureDQN(state_dim, action_dim * num_output_duplication).to(device)
594 |
595 | self._gamma = gamma
596 |
597 | self.batch_size = 32
598 | self.grad_clip_val = 40.0
599 | self.target_net_upd_freq = 500
600 | self.learning_starts = 10_000
601 | self.initial_epsilon = 1.0
602 | self.final_epsilon = 0.01
603 | self.epsilon_timesteps = 200_000
604 | self.train_freq = 4
605 |
606 | self._loss_function = nn.SmoothL1Loss() # huber loss # nn.MSELoss()
607 | self._q_optimizer = optim.Adam(self._q.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08)
608 | self._action_dim = action_dim
609 |
610 | self._replay_buffer = ReplayBuffer(5e4)
611 | self._skip_map = skip_map
612 | self._dup_vals = num_output_duplication
613 | self._env = env
614 | self._eval_env = eval_env
615 |
616 | def get_action(self, x: np.ndarray, epsilon: float) -> int:
617 | """
618 | Simple helper to get action epsilon-greedy based on observation x
619 | """
620 |         u = np.argmax(self._q(tt(x[None, :])).cpu().detach().numpy())
621 | r = np.random.uniform()
622 | if r < epsilon:
623 | return np.random.randint(self._action_dim)
624 | return u
625 |
626 | def train(self, episodes: int, max_env_time_steps: int, epsilon: float, eval_eps: int = 1,
627 | eval_every_n_steps: int = 1, max_train_time_steps: int = 1_000_000):
628 | """
629 | Training loop
630 | :param episodes: maximum number of episodes to train for
631 | :param max_env_time_steps: maximum number of steps in the environment to perform per episode
632 | :param epsilon: constant epsilon for exploration when selecting actions
633 |         :param eval_eps: number of episodes to run for evaluation
634 | :param eval_every_n_steps: interval of steps after which to evaluate the trained agent
635 | :param max_train_time_steps: maximum number of steps to train
636 | """
637 | total_steps = 0
638 | num_update_steps = 0
639 | batch_size = self.batch_size
640 | grad_clip_val = self.grad_clip_val
641 | target_net_upd_freq = self.target_net_upd_freq
642 | learning_starts = self.learning_starts
643 |
644 | start_time = time.time()
645 | for e in range(episodes):
646 | print("%s/%s" % (e + 1, episodes))
647 | s = self._env.reset()
648 | es = 0
649 | for t in range(max_env_time_steps):
650 | if total_steps > self.epsilon_timesteps:
651 | epsilon = self.final_epsilon
652 | else:
653 | epsilon = self.initial_epsilon - (self.initial_epsilon - self.final_epsilon) * (
654 | total_steps / self.epsilon_timesteps)
655 | a = self.get_action(s, epsilon)
656 |
657 |                 # convert the action id into the corresponding behaviour action and skip value
658 | act = a // self._dup_vals # behaviour
659 | rep = a // self._env.action_space.n # skip id
660 | skip = self._skip_map[rep] # skip id to corresponding skip value
661 |
662 |                 for _ in range(skip + 1):  # repeat the chosen behaviour action for skip + 1 steps
663 | ns, r, d, _ = self._env.step(act)
664 | total_steps += 1
665 | es += 1
666 |
667 | ########### Begin Evaluation
668 | if (total_steps % eval_every_n_steps) == 0:
669 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps)
670 | eval_stats = dict(
671 | elapsed_time=time.time() - start_time,
672 | training_steps=total_steps,
673 | training_eps=e,
674 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)),
675 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)),
676 | avg_rew_per_eval_ep=float(np.mean(eval_r)),
677 | std_rew_per_eval_ep=float(np.std(eval_r)),
678 | eval_eps=eval_eps
679 | )
680 |
681 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh:
682 | json.dump(eval_stats, out_fh)
683 | out_fh.write('\n')
684 | ########### End Evaluation
685 |
686 | ### Q-update based double Q learning
687 | self._replay_buffer.add_transition(s, a, ns, r, d)
688 | batch_states, batch_actions, batch_next_states, batch_rewards, batch_terminal_flags = \
689 | self._replay_buffer.random_next_batch(batch_size)
690 |
691 | target = batch_rewards + (1 - batch_terminal_flags) * self._gamma * \
692 | self._q_target(batch_next_states)[torch.arange(batch_size).long(), torch.argmax(
693 | self._q(batch_next_states), dim=1)]
694 | current_prediction = self._q(batch_states)[torch.arange(batch_size).long(), batch_actions.long()]
695 |
696 | loss = self._loss_function(current_prediction, target.detach())
697 | if (total_steps > learning_starts) and (total_steps % self.train_freq == 0):
698 | num_update_steps += 1
699 | self._q_optimizer.zero_grad()
700 | loss.backward()
701 | for param in self._q.parameters():
702 | param.grad.data.clamp_(-grad_clip_val, grad_clip_val)
703 | self._q_optimizer.step()
704 |
705 | if (total_steps % target_net_upd_freq) == 0:
706 | hard_update(self._q_target, self._q)
707 | # soft_update(self._q_target, self._q, 0.01)
708 | ########### End double Q-learning update
709 | if es >= max_env_time_steps or d or total_steps >= max_train_time_steps:
710 | break
711 |
712 | s = ns
713 | if es >= max_env_time_steps or d or total_steps >= max_train_time_steps:
714 | break
715 | if total_steps >= max_train_time_steps:
716 | break
717 |
718 | # Final evaluation
719 | if (total_steps % eval_every_n_steps) != 0:
720 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps)
721 | eval_stats = dict(
722 | elapsed_time=time.time() - start_time,
723 | training_steps=total_steps,
724 | training_eps=e,
725 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)),
726 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)),
727 | avg_rew_per_eval_ep=float(np.mean(eval_r)),
728 | std_rew_per_eval_ep=float(np.std(eval_r)),
729 | eval_eps=eval_eps
730 | )
731 |
732 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh:
733 | json.dump(eval_stats, out_fh)
734 | out_fh.write('\n')
735 |
736 | def eval(self, episodes: int, max_env_time_steps: int):
737 | """
738 | Simple method that evaluates the agent with fixed epsilon = 0
739 | :param episodes: max number of episodes to play
740 |         :param max_env_time_steps: maximum number of environment steps per evaluation episode
741 |
742 | :returns (steps per episode), (reward per episode), (decisions per episode)
743 | """
744 | steps, rewards, decisions = [], [], []
745 | with torch.no_grad():
746 | for e in range(episodes):
747 | ed, es, er = 0, 0, 0
748 |
749 | s = self._eval_env.reset()
750 | for _ in count():
751 | # print(self._q(tt(s)))
752 | a = self.get_action(s, 0)
753 | act = a // self._dup_vals
754 | rep = a // self._eval_env.action_space.n
755 | skip = self._skip_map[rep]
756 |
757 | ed += 1
758 |
759 | d = False
760 | for _ in range(skip + 1):
761 | ns, r, d, _ = self._eval_env.step(act)
762 | er += r
763 | es += 1
764 | if es >= max_env_time_steps or d:
765 | break
766 | s = ns
767 | if es >= max_env_time_steps or d:
768 | break
769 | steps.append(es)
770 | rewards.append(er)
771 | decisions.append(ed)
772 |
773 | return steps, rewards, decisions
774 |
775 | def save_model(self, path):
776 | torch.save(self._q.state_dict(), os.path.join(path, 'Q'))
777 |
778 |
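# Illustrative sketch: DAR enlarges the Q-network to action_dim * num_output_duplication
# outputs, so one greedy choice selects both a behaviour action and a skip id; the skip
# id is then mapped to an actual repetition count via skip_map. Below are the three
# skip_map variants built in the __main__ block at the bottom of this file (the concrete
# values are examples only); `_dar_skip_map_sketch` is a hypothetical helper.
def _dar_skip_map_sketch(max_skips=10, dar_base=2, dar_A=1, dar_B=4):
    fixed_pair = {0: dar_A, 1: dar_B}                            # --dar-A / --dar-B
    exponential = {a: dar_base ** a for a in range(max_skips)}   # --dar-base
    identity = {a: a for a in range(max_skips)}                  # default
    return fixed_pair, exponential, identity
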
779 | class TDQN:
780 | """
781 | TempoRL DQN agent capable of handling more complex state inputs through use of contextualized behaviour actions.
782 | """
783 |
784 | def __init__(self, state_dim, action_dim, skip_dim, gamma, env, eval_env, vision=False, shared=True):
785 | """
786 | Initialize the DQN Agent
787 | :param state_dim: dimensionality of the input states
788 | :param action_dim: dimensionality of the action output
789 |         :param skip_dim: dimensionality of the skip output
790 | :param gamma: discount factor
791 | :param env: environment to train on
792 | :param eval_env: environment to evaluate on
793 | :param vision: boolean flag to indicate if the input state is an image or not
794 | :param shared: boolean flag to indicate if a weight sharing input representation is used or not.
795 | """
796 | if not vision:
797 | if shared:
798 | self._q = WeightSharingTQ(state_dim, action_dim, skip_dim).to(device)
799 | self._q_target = WeightSharingTQ(state_dim, action_dim, skip_dim).to(device)
800 | else:
801 | self._q = Q(state_dim, action_dim).to(device)
802 | self._q_target = Q(state_dim, action_dim).to(device)
803 | else:
804 | if shared:
805 | self._q = NatureWeightsharingTQN(state_dim, action_dim, skip_dim).to(device)
806 | self._q_target = NatureWeightsharingTQN(state_dim, action_dim, skip_dim).to(device)
807 | else:
808 | self._q = NatureDQN(state_dim, action_dim).to(device)
809 | self._q_target = NatureDQN(state_dim, action_dim).to(device)
810 |
811 | if shared:
812 | self._skip_q = self._q
813 | else:
814 | if not vision:
815 | self._skip_q = TQ(state_dim, skip_dim).to(device)
816 | else:
817 | self._skip_q = NatureTQN(state_dim, skip_dim).to(device)
818 | print('Using {} as Q'.format(str(self._q)))
819 | print('Using {} as skip-Q\n{}'.format(str(self._skip_q), '#' * 80))
820 |
821 | self._gamma = gamma
822 | self._action_dim = action_dim
823 | self._skip_dim = skip_dim
824 |
825 | self.batch_size = 32
826 | self.grad_clip_val = 40.0
827 | self.target_net_upd_freq = 500
828 | self.learning_starts = 10_000
829 | self.initial_epsilon = 1.0
830 | self.final_epsilon = 0.01
831 | self.epsilon_timesteps = 200_000
832 | self.train_freq = 4
833 |
834 | self._loss_function = nn.SmoothL1Loss() # huber loss # nn.MSELoss()
835 | self._skip_loss_function = nn.SmoothL1Loss() # nn.MSELoss()
836 | self._q_optimizer = optim.Adam(self._q.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08)
837 | self._skip_q_optimizer = optim.Adam(self._skip_q.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08)
838 |
839 | self._replay_buffer = ReplayBuffer(5e4)
840 | self._skip_replay_buffer = NoneConcatSkipReplayBuffer(5e4)
841 | self._env = env
842 | self._eval_env = eval_env
843 |
844 | def get_action(self, x: np.ndarray, epsilon: float) -> int:
845 | """
846 | Simple helper to get action epsilon-greedy based on observation x
847 | """
848 | u = np.argmax(self._q(tt(x[None, :])).cpu().detach().numpy())
849 | r = np.random.uniform()
850 | if r < epsilon:
851 | return np.random.randint(self._action_dim)
852 | return u
853 |
854 | def get_skip(self, x: np.ndarray, a: np.ndarray, epsilon: float) -> int:
855 | """
856 | Simple helper to get the skip epsilon-greedy based on observation x conditioned on behaviour action a
857 | """
858 |         u = np.argmax(self._skip_q(tt(x[None, :]), tt(a[None, :])).cpu().detach().numpy())
859 | r = np.random.uniform()
860 | if r < epsilon:
861 | return np.random.randint(self._skip_dim)
862 | return u
863 |
864 | def eval(self, episodes: int, max_env_time_steps: int):
865 | """
866 | Simple method that evaluates the agent with fixed epsilon = 0
867 | :param episodes: max number of episodes to play
868 |         :param max_env_time_steps: maximum number of environment steps per evaluation episode
869 |
870 | :returns (steps per episode), (reward per episode), (decisions per episode)
871 | """
872 | steps, rewards, decisions = [], [], []
873 | with torch.no_grad():
874 | for e in range(episodes):
875 | ed, es, er = 0, 0, 0
876 |
877 | s = self._eval_env.reset()
878 | for _ in count():
879 | a = self.get_action(s, 0)
880 | skip = self.get_skip(s, np.array([a]), 0)
881 | ed += 1
882 |
883 | d = False
884 | for _ in range(skip + 1):
885 | ns, r, d, _ = self._eval_env.step(a)
886 | er += r
887 | es += 1
888 | if es >= max_env_time_steps or d:
889 | break
890 | s = ns
891 | if es >= max_env_time_steps or d:
892 | break
893 | steps.append(es)
894 | rewards.append(er)
895 | decisions.append(ed)
896 |
897 | return steps, rewards, decisions
898 |
899 | def train(self, episodes: int, max_env_time_steps: int, epsilon: float, eval_eps: int = 1,
900 | eval_every_n_steps: int = 1, max_train_time_steps: int = 1_000_000):
901 | """
902 | Training loop
903 | :param episodes: maximum number of episodes to train for
904 | :param max_env_time_steps: maximum number of steps in the environment to perform per episode
905 | :param epsilon: constant epsilon for exploration when selecting actions
906 |         :param eval_eps: number of episodes to run for evaluation
907 | :param eval_every_n_steps: interval of steps after which to evaluate the trained agent
908 | :param max_train_time_steps: maximum number of steps to train
909 | """
910 | total_steps = 0
911 | num_update_steps = 0
912 | batch_size = self.batch_size
913 | grad_clip_val = self.grad_clip_val
914 | target_net_upd_freq = self.target_net_upd_freq
915 | learning_starts = self.learning_starts
916 |
917 | start_time = time.time()
918 |
919 | for e in range(episodes):
920 | print("# Episode: %s/%s" % (e + 1, episodes))
921 | s = self._env.reset()
922 | es = 0
923 | for _ in count():
924 |
925 | if total_steps > self.epsilon_timesteps:
926 | epsilon = self.final_epsilon
927 | else:
928 | epsilon = self.initial_epsilon - (self.initial_epsilon - self.final_epsilon) * (
929 | total_steps / self.epsilon_timesteps)
930 |
931 | a = self.get_action(s, epsilon)
932 | skip = self.get_skip(s, np.array([a]), epsilon) # get skip with the selected action as context
933 |
934 | d = False
935 | skip_states, skip_rewards = [], []
936 |                 for curr_skip in range(skip + 1):  # repeat the selected action for skip + 1 steps
937 | ns, r, d, info_ = self._env.step(a)
938 | total_steps += 1
939 | es += 1
940 | skip_states.append(s) # keep track of all observed skips
941 | skip_rewards.append(r)
942 |
943 | #### Begin Evaluation
944 | if (total_steps % eval_every_n_steps) == 0:
945 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps)
946 | eval_stats = dict(
947 | elapsed_time=time.time() - start_time,
948 | training_steps=total_steps,
949 | training_eps=e,
950 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)),
951 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)),
952 | avg_rew_per_eval_ep=float(np.mean(eval_r)),
953 | std_rew_per_eval_ep=float(np.std(eval_r)),
954 | eval_eps=eval_eps
955 | )
956 |
957 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh:
958 | json.dump(eval_stats, out_fh)
959 | out_fh.write('\n')
960 | #### End Evaluation
961 |
962 | # Update the skip replay buffer with all observed skips.
963 | # if curr_skip == (skip + 1):
964 | skip_id = 0
965 | for start_state in skip_states:
966 | skip_reward = 0
967 | for exp, r in enumerate(skip_rewards[skip_id:]): # make sure to properly discount
968 | skip_reward += np.power(self._gamma, exp) * r
969 |
970 | self._skip_replay_buffer.add_transition(start_state, curr_skip - skip_id, ns,
971 | skip_reward, d, curr_skip - skip_id + 1,
972 | np.array([a])) # also keep track of the behavior action
973 | skip_id += 1
974 |
975 | # Skip Q update based on double DQN where target is behavior Q
976 | batch_states, batch_actions, batch_next_states, batch_rewards,\
977 | batch_terminal_flags, batch_lengths, batch_behaviours = \
978 | self._skip_replay_buffer.random_next_batch(batch_size)
979 |
980 | target = batch_rewards + (1 - batch_terminal_flags) * np.power(self._gamma, batch_lengths) * \
981 | self._q_target(batch_next_states)[torch.arange(batch_size).long(), torch.argmax(
982 | self._q(batch_next_states), dim=1)]
983 | current_prediction = self._skip_q(batch_states, batch_behaviours)[
984 | torch.arange(batch_size).long(), batch_actions.long()]
985 |
986 | loss = self._skip_loss_function(current_prediction, target.detach())
987 |
988 | if (total_steps > learning_starts) and (total_steps % self.train_freq == 0):
989 | num_update_steps += 1
990 | self._skip_q_optimizer.zero_grad()
991 | loss.backward()
992 | for param in self._skip_q.parameters():
993 | if param.grad is None:
994 | pass
995 | # print("##### Skip Q Parameter with grad = None:", param.name)
996 | else:
997 | param.grad.data.clamp_(-grad_clip_val, grad_clip_val)
998 | self._skip_q_optimizer.step()
999 |
1000 | # Action Q update based on double DQN with normal target
1001 | self._replay_buffer.add_transition(s, a, ns, r, d)
1002 | batch_states, batch_actions, batch_next_states, batch_rewards, batch_terminal_flags = \
1003 | self._replay_buffer.random_next_batch(batch_size)
1004 |
1005 | target = batch_rewards + (1 - batch_terminal_flags) * self._gamma * \
1006 | self._q_target(batch_next_states)[torch.arange(batch_size).long(), torch.argmax(
1007 | self._q(batch_next_states), dim=1)]
1008 | current_prediction = self._q(batch_states)[torch.arange(batch_size).long(), batch_actions.long()]
1009 |
1010 | loss = self._loss_function(current_prediction, target.detach())
1011 |
1012 | if (total_steps > learning_starts) and (total_steps % self.train_freq == 0):
1013 | self._q_optimizer.zero_grad()
1014 | loss.backward()
1015 | for param in self._q.parameters():
1016 | if param.grad is None:
1017 | pass
1018 | # print("##### Q Parameter with grad = None:", param.name)
1019 | else:
1020 | param.grad.data.clamp_(-grad_clip_val, grad_clip_val)
1021 | self._q_optimizer.step()
1022 |
1023 | if (total_steps % target_net_upd_freq) == 0:
1024 | hard_update(self._q_target, self._q)
1025 | # soft_update(self._q_target, self._q, 0.01)
1026 |
1027 | if es >= max_env_time_steps or d or total_steps >= max_train_time_steps:
1028 | break
1029 |
1030 | s = ns
1031 | if es >= max_env_time_steps or d or total_steps >= max_train_time_steps:
1032 | break
1033 | if total_steps >= max_train_time_steps:
1034 | break
1035 |
1036 | # final evaluation
1037 | if (total_steps % eval_every_n_steps) != 0:
1038 | eval_s, eval_r, eval_d = self.eval(eval_eps, max_env_time_steps)
1039 | eval_stats = dict(
1040 | elapsed_time=time.time() - start_time,
1041 | training_steps=total_steps,
1042 | training_eps=e,
1043 | avg_num_steps_per_eval_ep=float(np.mean(eval_s)),
1044 | avg_num_decs_per_eval_ep=float(np.mean(eval_d)),
1045 | avg_rew_per_eval_ep=float(np.mean(eval_r)),
1046 | std_rew_per_eval_ep=float(np.std(eval_r)),
1047 | eval_eps=eval_eps
1048 | )
1049 |
1050 | with open(os.path.join(out_dir, 'eval_scores.json'), 'a+') as out_fh:
1051 | json.dump(eval_stats, out_fh)
1052 | out_fh.write('\n')
1053 |
1054 | def save_model(self, path):
1055 | torch.save(self._q.state_dict(), os.path.join(path, 'Q'))
1056 | torch.save(self._skip_q.state_dict(), os.path.join(path, 'TQ'))
1057 |
1058 |
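# Illustrative sketch: how TDQN.train() turns one executed skip into several skip
# transitions. Every intermediate state becomes the start of its own, shorter skip,
# with the remaining rewards discounted towards it (mirroring the loop over
# `skip_states` above). Rewards and gamma are made-up numbers and
# `_tdqn_skip_connection_sketch` is a hypothetical helper.
def _tdqn_skip_connection_sketch(gamma=0.99):
    import numpy as np
    skip_rewards = [0.0, 0.0, 1.0]            # rewards while repeating one behaviour action
    transitions = []
    for skip_id in range(len(skip_rewards)):  # one transition per intermediate start state
        skip_reward = sum(np.power(gamma, exp) * r
                          for exp, r in enumerate(skip_rewards[skip_id:]))
        length = len(skip_rewards) - skip_id  # discounts the bootstrap term as gamma ** length
        transitions.append((skip_reward, length))
    return transitions
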
1059 | if __name__ == "__main__":
1060 | import argparse
1061 |
1062 | outdir_suffix_dict = {'none': '', 'empty': '', 'time': '%Y_%m_%d_%H%M%S',
1063 | 'seed': '{:d}', 'params': '{:d}_{:d}_{:d}',
1064 | 'paramsseed': '{:d}_{:d}_{:d}_{:d}'}
1065 | parser = argparse.ArgumentParser('TempoRL')
1066 | parser.add_argument('--episodes', '-e',
1067 | default=100,
1068 | type=int,
1069 | help='Number of training episodes.')
1070 | parser.add_argument('--training-steps', '-t',
1071 | default=1_000_000,
1072 | type=int,
1073 |                         help='Maximum number of training steps.')
1074 |
1075 | parser.add_argument('--out-dir',
1076 | default=None,
1077 | type=str,
1078 | help='Directory to save results. Defaults to tmp dir.')
1079 | parser.add_argument('--out-dir-suffix',
1080 | default='paramsseed',
1081 | type=str,
1082 | choices=list(outdir_suffix_dict.keys()),
1083 | help='Created suffix of directory to save results.')
1084 | parser.add_argument('--seed', '-s',
1085 | default=12345,
1086 | type=int,
1087 | help='Seed')
1088 | parser.add_argument('--eval-after-n-steps',
1089 | default=10 ** 3,
1090 | type=int,
1091 | help='After how many steps to evaluate')
1092 | parser.add_argument('--eval-n-episodes',
1093 | default=1,
1094 | type=int,
1095 | help='How many episodes to evaluate')
1096 |
1097 | # DQN -> normal double DQN agent
1098 | # DAR -> Dynamic action repetition agent based on normal DDQN with repeated output heads for different skip values
1099 | # tqn -> Not usable with vision states
1100 | # tdqn -> TempoRL DDQN with shared state representation for behaviour and skip Qs
1101 | # t-dqn -> TempoRL DDQN without shared state representation for behaviour and skip Qs
1102 | parser.add_argument('--agent',
1103 | choices=['dqn', 'dar', 'tdqn', 't-dqn'],
1104 | type=str.lower,
1105 | help='Which agent to train',
1106 | default='tdqn')
1107 | parser.add_argument('--skip-net-max-skips',
1108 | type=int,
1109 | default=10,
1110 | help='Maximum skip-size')
1111 | parser.add_argument('--env-max-steps',
1112 | default=200,
1113 | type=int,
1114 | help='Maximal steps in environment before termination.',
1115 | dest='env_ms')
1116 | parser.add_argument('--dar-base', default=None,
1117 | type=int,
1118 | help='DAR base')
1119 | parser.add_argument('--sparse', action='store_true')
1120 | parser.add_argument('--no-frame-skip', action='store_true')
1121 | parser.add_argument('--84x84', action='store_true', dest='large_image')
1122 | parser.add_argument('--dar-A', default=None, type=int)
1123 | parser.add_argument('--dar-B', default=None, type=int)
1124 | parser.add_argument('--env',
1125 | type=str,
1126 | help="Possible envs = 'mountain', 'moon' or any Atari env",
1127 | default='mountain')
1128 |
1129 | # setup output dir
1130 | args = parser.parse_args()
1131 | torch.manual_seed(args.seed)
1132 | np.random.seed(args.seed)
1133 | random.seed(args.seed)
1134 | outdir_suffix_dict['seed'] = outdir_suffix_dict['seed'].format(args.seed)
1135 | epis = args.episodes if args.episodes else -1
1136 | outdir_suffix_dict['params'] = outdir_suffix_dict['params'].format(
1137 | epis, args.skip_net_max_skips, args.env_ms)
1138 | outdir_suffix_dict['paramsseed'] = outdir_suffix_dict['paramsseed'].format(
1139 | epis, args.skip_net_max_skips, args.env_ms, args.seed)
1140 |
1141 | out_dir = experiments.prepare_output_dir(args, user_specified_dir=args.out_dir,
1142 | time_format=outdir_suffix_dict[args.out_dir_suffix])
1143 |
1144 | if args.env not in ['mountain', 'moon']:
1145 | from utils.env_wrappers import make_env, make_env_old
1146 |
1147 | # Setup Envs
1148 | game = ''.join([g.capitalize() for g in args.env.split('_')])
1149 |
1150 | if args.no_frame_skip:
1151 | eval_game = '{}NoFrameskip-v4'.format(game)
1152 | game = '{}NoFrameskip-v0'.format(game)
1153 | env = make_env_old(game, dim=84 if args.large_image else 42)
1154 | eval_env = make_env_old(eval_game, dim=84 if args.large_image else 42)
1155 | else:
1156 | eval_game = '{}Deterministic-v4'.format(game)
1157 | game = '{}Deterministic-v0'.format(game)
1158 | env = make_env(game, dim=84 if args.large_image else 42)
1159 | eval_env = make_env(eval_game, dim=84 if args.large_image else 42)
1160 |
1161 | # Setup Agent
1162 | state_dim = env.observation_space.shape[0] # (4, 42, 42) or (4, 84, 84) for PyTorch order
1163 | action_dim = env.action_space.n
1164 | if args.agent == 'dqn':
1165 | agent = DQN(state_dim, action_dim, gamma=0.99, env=env, eval_env=eval_env, vision=True)
1166 | elif args.agent == 'tdqn':
1167 | agent = TDQN(state_dim, action_dim, args.skip_net_max_skips, gamma=0.99, env=env, eval_env=eval_env,
1168 | vision=True)
1169 | elif args.agent == 't-dqn':
1170 | agent = TDQN(state_dim, action_dim, args.skip_net_max_skips, gamma=0.99, env=env,
1171 | eval_env=eval_env, shared=False, vision=True)
1172 | elif args.agent == 'dar':
1173 | if args.dar_A is not None and args.dar_B is not None:
1174 | skip_map = {0: args.dar_A, 1: args.dar_B}
1175 | elif args.dar_base:
1176 | skip_map = {a: args.dar_base ** a for a in range(args.skip_net_max_skips)}
1177 | else:
1178 | skip_map = {a: a for a in range(args.skip_net_max_skips)}
1179 | agent = DAR(state_dim, action_dim, args.skip_net_max_skips, skip_map, gamma=0.99, env=env,
1180 | eval_env=eval_env, vision=True)
1181 | else: # Simple featurized environments
1182 | # Setup Env
1183 | if args.env == 'mountain':
1184 | if args.sparse:
1185 | from gym.envs.classic_control import MountainCarEnv
1186 | env = MountainCarEnv()
1187 | eval_env = MountainCarEnv()
1188 | elif args.env == 'moon':
1189 | env = gym.make('LunarLander-v2')
1190 | eval_env = gym.make('LunarLander-v2')
1191 |
1192 | # Setup agent
1193 | state_dim = env.observation_space.shape[0]
1194 | action_dim = env.action_space.n
1195 | if args.agent == 'dqn':
1196 | agent = DQN(state_dim, action_dim, gamma=0.99, env=env, eval_env=eval_env)
1197 | elif args.agent == 'tdqn':
1198 | agent = TDQN(state_dim, action_dim, args.skip_net_max_skips, gamma=0.99, env=env, eval_env=eval_env)
1199 | elif args.agent == 't-dqn':
1200 | agent = TDQN(state_dim, action_dim, args.skip_net_max_skips, gamma=0.99, env=env,
1201 | eval_env=eval_env, shared=False)
1202 | elif args.agent == 'dar':
1203 | if args.dar_A is not None and args.dar_B is not None:
1204 | skip_map = {0: args.dar_A, 1: args.dar_B}
1205 | elif args.dar_base:
1206 | skip_map = {a: args.dar_base ** a for a in range(args.skip_net_max_skips)}
1207 | else:
1208 | skip_map = {a: a for a in range(args.skip_net_max_skips)}
1209 | agent = DAR(state_dim, action_dim, args.skip_net_max_skips, skip_map, gamma=0.99, env=env,
1210 | eval_env=eval_env)
1211 | else:
1212 | raise NotImplementedError
1213 |
1214 | episodes = args.episodes
1215 | max_env_time_steps = args.env_ms
1216 | epsilon = 0.2
1217 |
1218 | agent.train(episodes, max_env_time_steps, epsilon, args.eval_n_episodes, args.eval_after_n_steps,
1219 | max_train_time_steps=args.training_steps)
1220 | os.mkdir(os.path.join(out_dir, 'final'))
1221 | agent.save_model(os.path.join(out_dir, 'final'))
1222 |
--------------------------------------------------------------------------------
/run_ddpg_experiments.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on code from Scott Fujimoto https://github.com/sfujim/TD3
3 | We adapted the DDPG code he provides to allow for FiGAR and TempoRL variants
4 | This code is originally under the MIT license https://github.com/sfujim/TD3/blob/master/LICENSE
5 | """
6 |
7 | import argparse
8 |
9 | import gym
10 | import numpy as np
11 | import torch
12 |
13 | from DDPG import utils
14 | from DDPG.FiGAR import DDPG as FiGARDDPG
15 | from DDPG.TempoRL import DDPG as TempoRLDDPG
16 | from DDPG.vanilla import DDPG
17 | from utils import experiments
18 |
19 |
20 | # Runs policy for X episodes and returns average reward
21 | # A fixed seed is used for the eval environment
22 | def eval_policy(policy, env_name, seed, eval_episodes=10, FiGAR=False, TempoRL=False):
23 | eval_env = gym.make(env_name)
24 | special = 'PendulumDecs-v0' == env_name
25 | if special:
26 | eval_env = utils.Render(eval_env, episode_modulo=10)
27 | eval_env.seed(seed + 100)
28 |
29 | avg_reward = 0.
30 | avg_steps = 0.
31 | avg_decs = 0.
32 | for _ in range(eval_episodes):
33 | state, done = eval_env.reset(), False
34 | repetition = 1
35 | while not done:
36 | if FiGAR:
37 | action, repetition, rps = policy.select_action(np.array(state))
38 | repetition = repetition[0] + 1
39 |
40 | elif TempoRL:
41 | action = policy.select_action(np.array(state))
42 | repetition = np.argmax(policy.select_skip(np.array(state), action)) + 1
43 |
44 | else:
45 | action = policy.select_action(np.array(state))
46 |
47 | if special:
48 | eval_env.set_decision_point(True)
49 | avg_decs += 1
50 |
51 | for _ in range(repetition):
52 | state, reward, done, _ = eval_env.step(action)
53 | avg_reward += reward
54 | avg_steps += 1
55 | if done:
56 | break
57 | eval_env.close()
58 |
59 | avg_reward /= eval_episodes
60 | avg_decs /= eval_episodes
61 | avg_steps /= eval_episodes
62 |
63 | print("---------------------------------------")
64 | print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
65 | print("---------------------------------------")
66 | return avg_reward, avg_decs, avg_steps
67 |
68 |
69 | if __name__ == "__main__":
70 |
71 | parser = argparse.ArgumentParser()
72 | parser.add_argument('--out-dir',
73 | default=None,
74 | type=str,
75 | help='Directory to save results. Defaults to tmp dir.')
76 | parser.add_argument("--policy", default="TempoRLDDPG") # Policy name (DDPG, FiGARDDPG or our TempoRLDDPG)
77 | parser.add_argument("--env", default="Pendulum-v0") # OpenAI gym environment name
78 | parser.add_argument("--seed", default=0, type=int) # Sets Gym, PyTorch and Numpy seeds
79 | parser.add_argument("--start_timesteps", default=1e3, type=int) # Time steps initial random policy is used
80 | parser.add_argument("--eval_freq", default=500, type=int) # How often (time steps) we evaluate
81 | parser.add_argument("--max_timesteps", default=2e4, type=int) # Max time steps to run environment
82 | parser.add_argument("--expl_noise", default=0.1) # Std of Gaussian exploration noise
83 | parser.add_argument("--batch_size", default=256, type=int) # Batch size for both actor and critic
84 | parser.add_argument("--discount", default=0.99) # Discount factor
85 | parser.add_argument("--tau", default=0.005) # Target network update rate
86 | parser.add_argument("--max-skip", "--max-rep", default=20, type=int,
87 | dest='max_rep') # Maximum Skip length to use with FiGAR or TempoRL
88 | parser.add_argument("--save_model", action="store_true") # Save model and optimizer parameters
89 | parser.add_argument("--load_model", default="") # Model load file name, "" doesn't load, "default" uses file_name
90 | args = parser.parse_args()
91 |
92 | outdir_suffix_dict = dict()
93 | outdir_suffix_dict['seed'] = '{:d}'.format(args.seed)
94 | out_dir = experiments.prepare_output_dir(args, user_specified_dir=args.out_dir,
95 | time_format=outdir_suffix_dict['seed'])
96 |
97 | file_name = f"{args.policy}_{args.env}_{args.seed}"
98 | print("---------------------------------------")
99 | print(f"Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}")
100 | print("---------------------------------------")
101 |
102 | env = gym.make(args.env)
103 |
104 | # Set seeds
105 | env.seed(args.seed)
106 | torch.manual_seed(args.seed)
107 | np.random.seed(args.seed)
108 |
109 | state_dim = env.observation_space.shape[0]
110 | action_dim = env.action_space.shape[0]
111 | max_action = float(env.action_space.high[0])
112 |
113 | kwargs = {
114 | "state_dim": state_dim,
115 | "action_dim": action_dim,
116 | "max_action": max_action,
117 | "discount": args.discount,
118 | "tau": args.tau,
119 | }
120 | max_rep = args.max_rep
121 |
122 | # Initialize policy
123 | if args.policy == "DDPG":
124 | policy = DDPG(**kwargs)
125 | elif args.policy.startswith('FiGAR'):
126 | kwargs['repetition_dim'] = max_rep
127 | policy = FiGARDDPG(**kwargs)
128 | elif args.policy.startswith('TempoRL'):
129 | kwargs['skip_dim'] = max_rep
130 | policy = TempoRLDDPG(**kwargs)
131 | else:
132 | raise NotImplementedError
133 |
134 | if args.load_model != "":
135 | policy_file = args.load_model
136 | policy.load(f"{out_dir}/{policy_file}")
137 |
138 | skip_replay_buffer = None
139 | if 'FiGAR' in args.policy:
140 | replay_buffer = utils.FiGARReplayBuffer(state_dim, action_dim, rep_dim=max_rep)
141 | else:
142 | replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
143 | if 'TempoRL' in args.policy:
144 | skip_replay_buffer = utils.FiGARReplayBuffer(state_dim, action_dim, rep_dim=1)
145 |
146 | # Evaluate untrained policy
147 | evaluations = [[0, *eval_policy(policy, args.env, args.seed, FiGAR='FiGAR' in args.policy,
148 | TempoRL='TempoRL' in args.policy)]]
149 |
150 | state, done = env.reset(), False
151 | episode_reward = 0
152 | episode_timesteps = 0
153 | episode_num = 0
154 |
155 | t = 0
156 | while t < int(args.max_timesteps):
157 |
158 | episode_timesteps += 1
159 |
160 | # Select action randomly or according to policy
161 | if t < args.start_timesteps: # Before learning starts we sample actions uniformly at random
162 | action = env.action_space.sample()
163 | if args.policy.startswith('FiGAR'):
164 |                 # FiGAR uses a second actor network to learn the repetition value, so we have to create
165 |                 # an initial distribution over the possible repetition values
166 | repetition_probs = np.random.random(max_rep)
167 |
168 |
169 | def softmax(x):
170 | """Compute softmax values for each sets of scores in x."""
171 | e_x = np.exp(x - np.max(x))
172 | return e_x / e_x.sum()
173 |
174 |
175 | repetition_probs = softmax(repetition_probs)
176 | repetition = np.argmax(repetition_probs)
177 | elif args.policy.startswith('TempoRL'):
178 | # TempoRL uses a simple DQN for which we can simply sample from the possible skip values
179 | repetition = np.random.randint(max_rep) + 1
180 | else:
181 | repetition = 1
182 | else:
183 | # Get Action and skip values
184 | if 'FiGAR' in args.policy:
185 | # For FiGAR we treat the action policy exploration as in standard DDPG
186 | action, repetition, repetition_probs = policy.select_action(np.array(state))
187 | action = (
188 | action + np.random.normal(0, max_action * args.expl_noise, size=action_dim)
189 | ).clip(-max_action, max_action)
190 | # The Repetition policy however uses epsilon greedy exploration as described in the original paper
191 | # https://arxiv.org/pdf/1702.06054.pdf
192 | if np.random.random() < args.expl_noise:
193 | repetition = np.random.randint(max_rep) + 1 # + 1 since randint samples from [0, max_rep)
194 | else:
195 | repetition = repetition[0]
196 | elif 'TempoRL' in args.policy:
197 | # TempoRL does not interfere with the action policy and its exploration
198 | action = (
199 | policy.select_action(np.array(state))
200 | + np.random.normal(0, max_action * args.expl_noise, size=action_dim)
201 | ).clip(-max_action, max_action)
202 |
203 | # the skip policy uses epsilon greedy exploration for learning
204 | repetition = policy.select_skip(state, action)
205 | if np.random.random() < args.expl_noise:
206 |                     repetition = np.random.randint(max_rep) + 1  # + 1 since randint samples from [0, max_rep)
207 | else:
208 | repetition = np.argmax(repetition) + 1 # + 1 since indices start at 0
209 | else:
210 | # Standard DDPG
211 | action = (
212 | policy.select_action(np.array(state))
213 | + np.random.normal(0, max_action * args.expl_noise, size=action_dim)
214 | ).clip(-max_action, max_action)
215 | repetition = 1 # Never skip with vanilla DDPG
216 |
217 | # Perform action
218 |         skip_states, skip_rewards = [], []  # only used for TempoRL to build the local connectedness graph
219 | for curr_skip in range(repetition):
220 | next_state, reward, done, _ = env.step(action)
221 | t += 1
222 | done_bool = float(done) if episode_timesteps < env._max_episode_steps else 0
223 | skip_states.append(state)
224 | skip_rewards.append(reward)
225 |
226 | # Store data in replay buffer
227 | if 'FiGAR' in args.policy:
228 | # To train the second actor with FiGAR, we need to keep track of its output "repetition_probs"
229 | replay_buffer.add(state, action, repetition_probs, next_state, reward, done_bool)
230 | else:
231 | # Vanilla DDPG
232 | replay_buffer.add(state, action, next_state, reward, done_bool)
233 | # In addition to the normal replay_buffer
234 | # TempoRL uses a second replay buffer that is only used for training the skip network
235 | if 'TempoRL' in args.policy:
236 | # Update the skip buffer with all observed transitions in the local connectedness graph
237 | skip_id = 0
238 | for start_state in skip_states:
239 | skip_reward = 0
240 | for exp, r in enumerate(skip_rewards[skip_id:]):
241 | skip_reward += np.power(policy.discount, exp) * r # make sure to properly discount rewards
242 | skip_replay_buffer.add(start_state, action, curr_skip - skip_id, next_state, skip_reward, done)
243 | skip_id += 1
244 |
245 | state = next_state
246 | episode_reward += reward
247 |
248 | # Train agent after collecting sufficient data
249 | if t >= args.start_timesteps:
250 | policy.train(replay_buffer, args.batch_size)
251 | if 'TempoRL' in args.policy:
252 | policy.train_skip(skip_replay_buffer, args.batch_size)
253 |
254 | if done:
255 | # +1 to account for 0 indexing. +0 on ep_timesteps since it will increment +1 even if done=True
256 | print(
257 | f"Total T: {t + 1} Episode Num: {episode_num + 1} Episode T: {episode_timesteps} Reward: {episode_reward:.3f}")
258 | # Reset environment
259 | state, done = env.reset(), False
260 | episode_reward = 0
261 | episode_timesteps = 0
262 | episode_num += 1
263 | break
264 |
265 | # Evaluate episode
266 | if (t + 1) % args.eval_freq == 0:
267 | evaluations.append([t, *eval_policy(policy, args.env, args.seed, FiGAR='FiGAR' in args.policy,
268 | TempoRL='TempoRL' in args.policy)])
269 | np.save(f"{out_dir}/{file_name}", evaluations)
270 | if args.save_model:
271 | policy.save(f"{out_dir}/{file_name}")
272 |
--------------------------------------------------------------------------------
/run_tabular_experiments.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | from utils import experiments
7 |
8 | from grid_envs import GridCore
9 |
10 |
11 | def make_epsilon_greedy_policy(Q: defaultdict, epsilon: float, nA: int) -> callable:
12 | """
13 | Creates an epsilon-greedy policy based on a given Q-function and epsilon.
14 |     I.e. creates a weight vector from which actions are sampled.
15 |
16 | :param Q: tabular state-action lookup function
17 | :param epsilon: exploration factor
18 | :param nA: size of action space to consider for this policy
19 | """
20 |
21 | def policy_fn(observation):
22 | policy = np.ones(nA) * epsilon / nA
23 | best_action = np.random.choice(np.flatnonzero( # random choice for tie-breaking only
24 | Q[observation] == Q[observation].max()
25 | ))
26 | policy[best_action] += (1 - epsilon)
27 | return policy
28 |
29 | return policy_fn
30 |
31 |
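# Illustrative sketch: the epsilon-greedy weight vector returned by the policy above.
# Every action receives epsilon / nA probability mass and the (randomly tie-broken)
# greedy action receives the remaining 1 - epsilon. Q-values are made up;
# `_epsilon_greedy_sketch` is a hypothetical helper.
def _epsilon_greedy_sketch():
    import numpy as np
    from collections import defaultdict
    Q = defaultdict(lambda: np.zeros(4))
    Q['s0'] = np.array([0.0, 1.0, 0.5, 1.0])   # two greedy actions -> random tie-break between them
    policy = make_epsilon_greedy_policy(Q, epsilon=0.2, nA=4)
    probs = policy('s0')                       # e.g. [0.05, 0.85, 0.05, 0.05]
    assert np.isclose(probs.sum(), 1.0)
    return probs
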
32 | def get_decay_schedule(start_val: float, decay_start: int, num_steps: int, type_: str):
33 | """
34 | Create epsilon decay schedule
35 |
36 | :param start_val: Start decay from this value (i.e. 1)
37 | :param decay_start: number of iterations to start epsilon decay after
38 | :param num_steps: Total number of steps to decay over
39 | :param type_: Which strategy to use. Implemented choices: 'const', 'log', 'linear'
40 | :return:
41 | """
42 | if type_ == 'const':
43 | return np.array([start_val for _ in range(num_steps)])
44 | elif type_ == 'log':
45 | return np.hstack([[start_val for _ in range(decay_start)],
46 | np.logspace(np.log10(start_val), np.log10(0.000001), (num_steps - decay_start))])
47 | elif type_ == 'linear':
48 | return np.hstack([[start_val for _ in range(decay_start)],
49 | np.linspace(start_val, 0, (num_steps - decay_start), endpoint=True)])
50 | else:
51 | raise NotImplementedError
52 |
53 |
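# Illustrative sketch: the three decay schedules produced by get_decay_schedule for a
# short run. 'const' stays flat, 'linear' ramps to 0 and 'log' decays exponentially
# towards 1e-6, both starting after `decay_start` episodes. `_decay_schedule_sketch`
# is a hypothetical helper.
def _decay_schedule_sketch():
    const = get_decay_schedule(1.0, decay_start=2, num_steps=6, type_='const')
    linear = get_decay_schedule(1.0, decay_start=2, num_steps=6, type_='linear')
    log = get_decay_schedule(1.0, decay_start=2, num_steps=6, type_='log')
    # const  -> [1, 1, 1, 1, 1, 1]
    # linear -> [1, 1, 1.0, 0.667, 0.333, 0.0]
    # log    -> [1, 1, 1.0, 1e-2, 1e-4, 1e-6]
    return const, linear, log
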
54 | def td_update(q: defaultdict, state: int, action: int, reward: float, next_state: int, gamma: float, alpha: float):
55 | """ Simple TD update rule """
56 | # TD update
57 | best_next_action = np.random.choice(np.flatnonzero(q[next_state] == q[next_state].max())) # greedy best next
58 | td_target = reward + gamma * q[next_state][best_next_action]
59 | td_delta = td_target - q[state][action]
60 | return q[state][action] + alpha * td_delta
61 |
62 |
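# Illustrative sketch: td_update implements the standard Q-learning backup
#     Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
# Values below are made up; `_td_update_sketch` is a hypothetical helper.
def _td_update_sketch():
    import numpy as np
    from collections import defaultdict
    q = defaultdict(lambda: np.zeros(2))
    q['s'] = np.array([0.0, 0.0])
    q['s_next'] = np.array([1.0, 2.0])
    new_value = td_update(q, state='s', action=0, reward=1.0,
                          next_state='s_next', gamma=0.9, alpha=0.5)
    # td_target = 1.0 + 0.9 * 2.0 = 2.8, so new_value = 0.0 + 0.5 * (2.8 - 0.0) = 1.4
    return new_value
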
63 | def q_learning(
64 | environment: GridCore,
65 | num_episodes: int,
66 | discount_factor: float = 1.0,
67 | alpha: float = 0.5,
68 | epsilon: float = 0.1,
69 | epsilon_decay: str = 'const',
70 | decay_starts: int = 0,
71 | eval_every: int = 10,
72 | render_eval: bool = True):
73 | """
74 | Vanilla tabular Q-learning algorithm
75 | :param environment: which environment to use
76 | :param num_episodes: number of episodes to train
77 | :param discount_factor: discount factor used in TD updates
78 | :param alpha: learning rate used in TD updates
79 | :param epsilon: exploration fraction (either constant or starting value for schedule)
80 | :param epsilon_decay: determine type of exploration (constant, linear/exponential decay schedule)
81 | :param decay_starts: After how many episodes epsilon decay starts
82 | :param eval_every: Number of episodes between evaluations
83 | :param render_eval: Flag to activate/deactivate rendering of evaluation runs
84 | :return: training and evaluation statistics (i.e. rewards and episode lengths)
85 | """
86 |     assert 0 <= discount_factor <= 1, 'Discount factor has to be in [0, 1]'
87 | assert 0 <= epsilon <= 1, 'epsilon has to be in [0, 1]'
88 | assert alpha > 0, 'Learning rate has to be positive'
89 | # The action-value function.
90 | # Nested dict that maps state -> (action -> action-value).
91 | Q = defaultdict(lambda: np.zeros(environment.action_space.n))
92 |
93 | # Keeps track of episode lengths and rewards
94 | rewards = []
95 | lens = []
96 | test_rewards = []
97 | test_lens = []
98 | train_steps_list = []
99 | test_steps_list = []
100 |
101 | epsilon_schedule = get_decay_schedule(epsilon, decay_starts, num_episodes, epsilon_decay)
102 | for i_episode in range(num_episodes + 1):
103 | # print('#' * 100)
104 | epsilon = epsilon_schedule[min(i_episode, num_episodes - 1)]
105 | # The policy we're following
106 | policy = make_epsilon_greedy_policy(Q, epsilon, environment.action_space.n)
107 | policy_state = environment.reset()
108 | episode_length, cummulative_reward = 0, 0
109 | while True: # roll out episode
110 | policy_action = np.random.choice(list(range(environment.action_space.n)), p=policy(policy_state))
111 | s_, policy_reward, policy_done, _ = environment.step(policy_action)
112 | cummulative_reward += policy_reward
113 | episode_length += 1
114 |
115 | Q[policy_state][policy_action] = td_update(Q, policy_state, policy_action,
116 | policy_reward, s_, discount_factor, alpha)
117 |
118 | if policy_done:
119 | break
120 | policy_state = s_
121 | rewards.append(cummulative_reward)
122 | lens.append(episode_length)
123 | train_steps_list.append(environment.total_steps)
124 |
125 | # evaluation with greedy policy
126 | test_steps = 0
127 | if i_episode % eval_every == 0:
128 | policy_state = environment.reset()
129 | episode_length, cummulative_reward = 0, 0
130 | if render_eval:
131 | environment.render()
132 | while True: # roll out episode
133 | policy_action = np.random.choice(np.flatnonzero(Q[policy_state] == Q[policy_state].max()))
134 | environment.total_steps -= 1 # don't count evaluation steps
135 | s_, policy_reward, policy_done, _ = environment.step(policy_action)
136 | test_steps += 1
137 | if render_eval:
138 | environment.render()
139 | s_ = s_
140 | cummulative_reward += policy_reward
141 | episode_length += 1
142 | if policy_done:
143 | break
144 | policy_state = s_
145 | test_rewards.append(cummulative_reward)
146 | test_lens.append(episode_length)
147 | test_steps_list.append(test_steps)
148 | print('Done %4d/%4d episodes' % (i_episode, num_episodes))
149 | return (rewards, lens), (test_rewards, test_lens), (train_steps_list, test_steps_list), Q
150 |
151 |
152 | class SkipTransition:
153 | """
154 | Simple helper class to keep track of all transitions observed when skipping through an MDP
155 | """
156 |
157 | def __init__(self, skips, df):
158 | self.state_mat = np.full((skips, skips), -1, dtype=int) # might need to change type for other envs
159 | self.reward_mat = np.full((skips, skips), np.nan, dtype=float)
160 | self.idx = 0
161 | self.df = df
162 |
163 | def add(self, reward, next_state):
164 | """
165 | Add reward and next_state to triangular matrix
166 | :param reward: received reward
167 | :param next_state: state reached
168 | """
169 | self.idx += 1
170 | for i in range(self.idx):
171 | self.state_mat[self.idx - i - 1, i] = next_state
172 | # Automatically discount rewards when adding to corresponding skip
173 | self.reward_mat[self.idx - i - 1, i] = reward * self.df ** i + np.nansum(self.reward_mat[self.idx - i - 1])
174 |
175 |
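# Illustrative sketch: the triangular bookkeeping done by SkipTransition. After k calls
# to add(), row i of reward_mat holds the discounted return accumulated from the
# (i+1)-th step onwards, and state_mat records the states passed to add(). Rewards,
# states and df below are made up; `_skip_transition_sketch` is a hypothetical helper.
def _skip_transition_sketch():
    st = SkipTransition(skips=2, df=0.5)
    st.add(reward=1.0, next_state=1)   # after the first environment step
    st.add(reward=2.0, next_state=2)   # after the second step with the same action
    # state_mat  -> [[1, 2], [2, -1]]
    # reward_mat -> [[1.0, 2.0], [2.0, nan]]: 1-step return and discounted 2-step return (1 + 0.5 * 2)
    return st.state_mat, st.reward_mat
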
176 | def temporl_q_learning(
177 | environment: GridCore,
178 | num_episodes: int,
179 | discount_factor: float = 1.0,
180 | alpha: float = 0.5,
181 | epsilon: float = 0.1,
182 | epsilon_decay: str = 'const',
183 | decay_starts: int = 0,
184 | decay_stops: int = None,
185 | eval_every: int = 10,
186 | render_eval: bool = True,
187 | max_skip: int = 7):
188 | """
189 | Implementation of tabular TempoRL
190 | :param environment: which environment to use
191 | :param num_episodes: number of episodes to train
192 | :param discount_factor: discount factor used in TD updates
193 | :param alpha: learning rate used in TD updates
194 | :param epsilon: exploration fraction (either constant or starting value for schedule)
195 | :param epsilon_decay: determine type of exploration (constant, linear/exponential decay schedule)
196 | :param decay_starts: After how many episodes epsilon decay starts
197 | :param decay_stops: Episode after which to stop epsilon decay
198 | :param eval_every: Number of episodes between evaluations
199 | :param render_eval: Flag to activate/deactivate rendering of evaluation runs
200 | :param max_skip: Maximum skip size to use.
201 | :return: training and evaluation statistics (i.e. rewards and episode lengths)
202 | """
203 | temporal_actions = max_skip
204 | action_Q = defaultdict(lambda: np.zeros(environment.action_space.n))
205 | temporal_Q = defaultdict(lambda: np.zeros(temporal_actions))
206 | if not decay_stops:
207 | decay_stops = num_episodes
208 |
209 | epsilon_schedule_action = get_decay_schedule(epsilon, decay_starts, decay_stops, epsilon_decay)
210 | epsilon_schedule_temporal = get_decay_schedule(epsilon, decay_starts, decay_stops, epsilon_decay)
211 | rewards = []
212 | lens = []
213 | test_rewards = []
214 | test_lens = []
215 | train_steps_list = []
216 | test_steps_list = []
217 | for i_episode in range(num_episodes + 1):
218 |
219 | # setup exploration policy for this episode
220 | epsilon_action = epsilon_schedule_action[min(i_episode, num_episodes - 1)]
221 | epsilon_temporal = epsilon_schedule_temporal[min(i_episode, num_episodes - 1)]
222 | action_policy = make_epsilon_greedy_policy(action_Q, epsilon_action, environment.action_space.n)
223 | temporal_policy = make_epsilon_greedy_policy(temporal_Q, epsilon_temporal, temporal_actions)
224 |
225 | episode_r = 0
226 | state = environment.reset() # type: list
227 | action_pol_len = 0
228 | while True: # roll out episode
229 | action = np.random.choice(list(range(environment.action_space.n)), p=action_policy(state))
230 | temporal_state = (state, action)
231 | action_pol_len += 1
232 | temporal_action = np.random.choice(list(range(temporal_actions)), p=temporal_policy(temporal_state))
233 |
234 | s_ = None
235 | done = False
236 | tmp_state = state
237 | skip_transition = SkipTransition(temporal_action + 1, discount_factor)
238 | reward = 0
239 | for tmp_temporal_action in range(temporal_action + 1):
240 | if not done:
241 | # Only step the environment while the episode is not done. If the episode ended
242 | # before we finished "skipping", we still add the last reward and state to the skip_transition.
243 | s_, reward, done, _ = environment.step(action)
244 | skip_transition.add(reward, tmp_state)
245 |
246 | # 1-step update of action Q (like in vanilla Q)
247 | action_Q[tmp_state][action] = td_update(action_Q, tmp_state, action,
248 | reward, s_, discount_factor, alpha)
249 |
250 | count = 0
251 | # For all transitions observed so far, compute all forward skip updates
252 | for skip_num in range(skip_transition.idx):
253 | skip = skip_transition.state_mat[skip_num]
254 | rew = skip_transition.reward_mat[skip_num]
255 | skip_start_state = (skip[0], action)
256 |
257 | # Temporal TD update
258 | best_next_action = np.random.choice(
259 | np.flatnonzero(action_Q[s_] == action_Q[s_].max())) # greedy best next
260 | td_target = rew[skip_transition.idx - 1 - count] + (
261 | discount_factor ** (skip_transition.idx - 1)) * action_Q[s_][best_next_action]
262 | td_delta = td_target - temporal_Q[skip_start_state][skip_transition.idx - count - 1]
263 | temporal_Q[skip_start_state][skip_transition.idx - count - 1] += alpha * td_delta
264 | count += 1
265 |
266 | tmp_state = s_
267 | state = s_
268 | if done:
269 | break
270 | rewards.append(episode_r)
271 | lens.append(action_pol_len)
272 | train_steps_list.append(environment.total_steps)
273 |
274 | # ---------------------------------------------- EVALUATION -------------------------------------------------
275 | # ---------------------------------------------- EVALUATION -------------------------------------------------
276 | test_steps = 0
277 | if i_episode % eval_every == 0:
278 | episode_r = 0
279 | state = environment.reset() # type: list
280 | if render_eval:
281 | environment.render(in_control=True)
282 | action_pol_len = 0
283 | while True: # roll out episode
284 | action = np.random.choice(np.flatnonzero(action_Q[state] == action_Q[state].max()))
285 | temporal_state = (state, action)
286 | action_pol_len += 1
287 |
288 | # Examples of different action selection schemes when greedily following a policy
289 | # temporal_action = np.random.choice(
290 | # np.flatnonzero(temporal_Q[temporal_state] == temporal_Q[temporal_state].max()))
291 | temporal_action = np.max( # if there are ties use the larger action
292 | np.flatnonzero(temporal_Q[temporal_state] == temporal_Q[temporal_state].max()))
293 | # temporal_action = np.min( # if there are ties use the smaller action
294 | # np.flatnonzero(temporal_Q[temporal_state] == temporal_Q[temporal_state].max()))
295 |
296 | for i in range(temporal_action + 1):
297 | environment.total_steps -= 1 # don't count evaluation steps
298 | s_, reward, done, _ = environment.step(action)
299 | test_steps += 1
300 | if render_eval:
301 | environment.render(in_control=False)
302 | episode_r += reward
303 | if done:
304 | break
305 | if render_eval:
306 | environment.render(in_control=True)
307 | state = s_
308 | if done:
309 | break
310 | test_rewards.append(episode_r)
311 | test_lens.append(action_pol_len)
312 | test_steps_list.append(test_steps)
313 | print('Done %4d/%4d episodes' % (i_episode, num_episodes))
314 | return (rewards, lens), (test_rewards, test_lens), (train_steps_list, test_steps_list), (action_Q, temporal_Q)
315 |
316 |
317 | if __name__ == '__main__':
318 | import argparse
319 |
320 | outdir_suffix_dict = {'none': '', 'empty': '', 'time': '%Y%m%dT%H%M%S.%f',
321 | 'seed': '{:d}', 'params': '{:d}_{:d}_{:d}',
322 | 'paramsseed': '{:d}_{:d}_{:d}_{:d}'}
323 | parser = argparse.ArgumentParser('Skip-MDP Tabular-Q')
324 | parser.add_argument('--episodes', '-e',
325 | default=10_000,
326 | type=int,
327 | help='Number of training episodes')
328 | parser.add_argument('--out-dir',
329 | default=None,
330 | type=str,
331 | help='Directory to save results. Defaults to tmp dir.')
332 | parser.add_argument('--out-dir-suffix',
333 | default='paramsseed',
334 | type=str,
335 | choices=list(outdir_suffix_dict.keys()),
336 | help='Suffix scheme used to name the results sub-directory.')
337 | parser.add_argument('--seed', '-s',
338 | default=12345,
339 | type=int,
340 | help='Seed')
341 | parser.add_argument('--env-max-steps',
342 | default=100,
343 | type=int,
344 | help='Maximal steps in environment before termination.',
345 | dest='env_ms')
346 | parser.add_argument('--agent-eps-decay',
347 | default='linear',
348 | choices={'linear', 'log', 'const'},
349 | help='Epsilon decay schedule',
350 | dest='agent_eps_d')
351 | parser.add_argument('--agent-eps',
352 | default=1.0,
353 | type=float,
354 | help='Epsilon value. Used as the start value for linear or log decay, otherwise as a constant.',
355 | dest='agent_eps')
356 | parser.add_argument('--agent',
357 | default='sq',
358 | choices={'sq', 'q'},
359 | type=str.lower,
360 | help='Agent type to train')
361 | parser.add_argument('--env',
362 | default='lava',
363 | choices={'lava', 'lava2',
364 | 'lava_perc', 'lava2_perc',
365 | 'lava_ng', 'lava2_ng',
366 | 'lava3', 'lava3_perc', 'lava3_ng'},
367 | type=str.lower,
368 | help='Environment to use')
369 | parser.add_argument('--eval-eps',
370 | default=100,
371 | type=int,
372 | help='After how many episodes to evaluate')
373 | parser.add_argument('--stochasticity',
374 | default=0,
375 | type=float,
376 | help='Probability of the selected action failing and instead executing any of the remaining 3')
377 | parser.add_argument('--no-render',
378 | action='store_true',
379 | help='Deactivate rendering of environment evaluation')
380 | parser.add_argument('--max-skips',
381 | type=int,
382 | default=7,
383 | help='Max skip size for tempoRL')
384 |
385 | # setup output dir
386 | args = parser.parse_args()
387 | outdir_suffix_dict['seed'] = outdir_suffix_dict['seed'].format(args.seed)
388 | outdir_suffix_dict['params'] = outdir_suffix_dict['params'].format(
389 | args.episodes, args.max_skips, args.env_ms)
390 | outdir_suffix_dict['paramsseed'] = outdir_suffix_dict['paramsseed'].format(
391 | args.episodes, args.max_skips, args.env_ms, args.seed)
392 |
393 | if not args.no_render:
394 | # Clear screen in ANSI terminal
395 | print('\033c')
396 | print('\x1bc')
397 |
398 | out_dir = experiments.prepare_output_dir(args, user_specified_dir=args.out_dir,
399 | time_format=outdir_suffix_dict[args.out_dir_suffix])
400 |
401 | np.random.seed(args.seed)  # seed numpy
402 | d = None
403 |
404 | if args.env.startswith('lava'):
405 | import gym
406 | from grid_envs import Bridge6x10Env, Pit6x10Env, ZigZag6x10, ZigZag6x10H
407 |
408 | perc = args.env.endswith('perc')
409 | ng = args.env.endswith('ng')
410 | if args.env.startswith('lava2'):
411 | d = Bridge6x10Env(max_steps=args.env_ms, percentage_reward=perc, no_goal_rew=ng,
412 | act_fail_prob=args.stochasticity, numpy_state=False)
413 | elif args.env.startswith('lava3'):
414 | d = ZigZag6x10(max_steps=args.env_ms, percentage_reward=perc, no_goal_rew=ng, goal=(5, 9),
415 | act_fail_prob=args.stochasticity, numpy_state=False)
416 | elif args.env.startswith('lava4'):
417 | d = ZigZag6x10H(max_steps=args.env_ms, percentage_reward=perc, no_goal_rew=ng, goal=(5, 9),
418 | act_fail_prob=args.stochasticity, numpy_state=False)
419 | else:
420 | d = Pit6x10Env(max_steps=args.env_ms, percentage_reward=perc, no_goal_rew=ng,
421 | act_fail_prob=args.stochasticity, numpy_state=False)
422 |
423 | # setup agent
424 | if args.agent == 'sq':
425 | train_data, test_data, num_steps, (action_Q, t_Q) = temporl_q_learning(d, args.episodes,
426 | epsilon_decay=args.agent_eps_d,
427 | epsilon=args.agent_eps,
428 | discount_factor=.99, alpha=.5,
429 | eval_every=args.eval_eps,
430 | render_eval=not args.no_render,
431 | max_skip=args.max_skips)
432 | elif args.agent == 'q':
433 | train_data, test_data, num_steps, Q = q_learning(d, args.episodes,
434 | epsilon_decay=args.agent_eps_d,
435 | epsilon=args.agent_eps,
436 | discount_factor=.99,
437 | alpha=.5, eval_every=args.eval_eps,
438 | render_eval=not args.no_render)
439 | else:
440 | raise NotImplementedError
441 |
442 | # TODO save resulting Q-function for easy reuse
443 | with open(os.path.join(out_dir, 'train_data.pkl'), 'wb') as outfh:
444 | pickle.dump(train_data, outfh)
445 | with open(os.path.join(out_dir, 'test_data.pkl'), 'wb') as outfh:
446 | pickle.dump(test_data, outfh)
447 | with open(os.path.join(out_dir, 'steps_per_episode.pkl'), 'wb') as outfh:
448 | pickle.dump(num_steps, outfh)
449 |
450 | if args.agent == 'q':
451 | with open(os.path.join(out_dir, 'Q.pkl'), 'wb') as outfh:
452 | pickle.dump(dict(Q), outfh)
453 | elif args.agent == 'sq':
454 | with open(os.path.join(out_dir, 'Q.pkl'), 'wb') as outfh:
455 | pickle.dump(dict(action_Q), outfh)
456 | with open(os.path.join(out_dir, 'J.pkl'), 'wb') as outfh:
457 | pickle.dump(dict(t_Q), outfh)
458 |
--------------------------------------------------------------------------------
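
For orientation, the sketch below shows one way the tabular script above can be launched and how its pickled outputs are read back. It is only an illustration: the flag values and the experiments/tabular output location are assumptions, while the CLI flags and file names (test_data.pkl etc.) are taken from the script itself.

import glob
import os
import pickle
import subprocess
import sys

# Train tabular TempoRL ('sq'); '--agent q' would run the vanilla Q-learning baseline instead.
subprocess.run([
    sys.executable, 'run_tabular_experiments.py',
    '--agent', 'sq',
    '--env', 'lava',
    '--episodes', '10000',
    '--max-skips', '7',
    '--eval-eps', '100',
    '--no-render',
    '--out-dir', 'experiments/tabular',   # assumed output location; any writable directory works
], check=True)

# With the default --out-dir-suffix 'paramsseed', each run lands in
# experiments/tabular/<episodes>_<max_skips>_<env_max_steps>_<seed>/ .
run_dir = sorted(glob.glob(os.path.join('experiments', 'tabular', '*')))[-1]
with open(os.path.join(run_dir, 'test_data.pkl'), 'rb') as fh:
    test_rewards, test_lens = pickle.load(fh)   # per-evaluation returns and episode lengths
print('last evaluation reward: %.2f, episode length: %d' % (test_rewards[-1], test_lens[-1]))
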
/tabular_requirements.txt:
--------------------------------------------------------------------------------
1 | pyyaml~=5.4
2 | seaborn~=0.10.1
3 | matplotlib~=3.3.0
4 | gym~=0.17.2
5 | numpy~=1.19.1
6 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/automl/TempoRL/b8f4b0648489dbcc4895374df56a0d051379f2a8/utils/__init__.py
--------------------------------------------------------------------------------
/utils/config:
--------------------------------------------------------------------------------
1 | data:
2 | local: "."
3 | remote: "path_to_remote"
4 | plotting:
5 | color_map:
6 | sq: 4
7 | q: 2
8 | dqn: 2
9 | dar: 1
10 | tqn: 4
11 | tdqn: 4
12 | t-dqn: 4
13 | DDPG: 5
14 | t-DDPG: 9
15 | f-DDPG: 7
16 | palette: "colorblind"
17 | seaborn:
18 | style: "darkgrid"
19 | context:
20 | context: "paper"
21 | font scale: 1
22 | font: "Arial"
23 | rc:
24 | grid.linewidth: 4
25 | axes.labelsize: 32
26 | axes.titlesize: 32
27 | legend.fontsize: 28
28 | lines.linewidth: 4
29 | xtick.labelsize: 32
30 | ytick.labelsize: 32
31 | rc2:
32 | grid.linewidth: 6
33 | axes.labelsize: 64
34 | axes.titlesize: 64
35 | legend.fontsize: 35
36 | lines.linewidth: 6
37 | xtick.labelsize: 62
38 | ytick.labelsize: 62
39 |
--------------------------------------------------------------------------------
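
The YAML above is read verbatim by utils/data_handling.load_config (shown next) and drives both the data-loading paths and the seaborn styling in utils/plotting. A minimal sketch of how it can be inspected, assuming the repository root as working directory:

# Minimal sketch: load utils/config the same way load_config() does (PyYAML, FullLoader).
import os
import yaml

with open(os.path.join('utils', 'config'), 'r') as ymlf:
    cfg = yaml.load(ymlf, Loader=yaml.FullLoader)

print(cfg['data']['local'])                 # "."        -> data root used when local=True
print(cfg['plotting']['color_map']['sq'])   # 4          -> palette index for tabular TempoRL curves
print(cfg['plotting']['seaborn']['style'])  # "darkgrid" -> seaborn style applied in get_colors()
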
/utils/data_handling.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import os
3 | import glob
4 | import pickle
5 | import numpy as np
6 |
7 | import json
8 | import pandas as pd
9 |
10 |
11 | def load_config():
12 | with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config'), 'r') as ymlf:
13 | cfg = yaml.load(ymlf, Loader=yaml.FullLoader)
14 | return cfg
15 |
16 |
17 | def load_data(experiment_dir="experiments_01_28", methods=['sq', 'q'], exp_version='-1.0-linear',
18 | episodes=50_000, max_skip=6, max_steps=100, local=True, debug=False):
19 | cfg = load_config()
20 | method_test_rewards = {}
21 | method_test_lengths = {}
22 | method_steps_per_episodes = {}
23 | if not local:
24 | print('Loading from')
25 | print(os.path.join(cfg['data']['local' if local else 'remote'], experiment_dir))
26 | for method in methods:
27 | print(method)
28 | files = glob.glob(
29 | os.path.join(cfg['data']['local' if local else 'remote'], experiment_dir,
30 | '{:s}-experiments{:s}'.format(method, exp_version),
31 | '{:d}_{:d}_{:d}_*', 'test_data.pkl'
32 | ).format(
33 | episodes,
34 | max_skip,
35 | max_steps
36 | ))
37 | test_rewards, test_lens, steps_per_eps = [], [], []
38 | for file in files:
39 | if debug:
40 | print('Loading', file)
41 | with open(file, 'rb') as fh:
42 | data = pickle.load(fh)
43 | test_rewards.append(data[0])
44 | test_lens.append(data[1])
45 | try:
46 | with open(file.replace('test_data', 'steps_per_episode'), 'rb') as fh:
47 | data = pickle.load(fh)
48 | steps_per_eps.append(data[1])
49 | except FileNotFoundError:
50 | print('No steps data found')
51 |
52 | method_test_rewards[method] = np.array(test_rewards)
53 | method_test_lengths[method] = np.array(test_lens)
54 | method_steps_per_episodes[method] = np.array(steps_per_eps)
55 | return method_test_rewards, method_test_lengths, method_steps_per_episodes
56 |
57 |
58 | def load_dqn_data(experiment_dir, method, max_steps=None, succ_threashold=None, debug=False):
59 | cfg = load_config()
60 | print(os.path.abspath(os.path.join(method, experiment_dir, 'eval_scores.json')))
61 | files = glob.glob(
62 | os.path.abspath(os.path.join(method, experiment_dir, 'eval_scores.json')),
63 | recursive=True
64 | )
65 | frames = []
66 | max_len = 0
67 | succ_count = 0
68 | for file in sorted(files):
69 | if debug:
70 | print('Loading', file)
71 | data = []
72 | with open(file, 'r') as fh:
73 | for line in fh:
74 | loaded = json.loads(line)
75 | data.append(loaded)
76 | if max_steps and loaded['training_steps'] >= max_steps:
77 | break
78 | frame = pd.DataFrame(data)
79 | max_len = max(max_len, frame.shape[0])
80 | if succ_threashold:
81 | if loaded['avg_rew_per_eval_ep'] > succ_threashold:
82 | succ_count += 1
83 | frames.append(frame)
84 | rews, lens, decs, training_steps, training_eps = [], [], [], [], []
85 | for frame in frames:
86 | for (list_, array) in [(rews, frame.avg_rew_per_eval_ep), (lens, frame.avg_num_steps_per_eval_ep),
87 | (decs, frame.avg_num_decs_per_eval_ep), (training_steps, frame.training_steps),
88 | (training_eps, frame.training_eps)]:
89 | data = np.full((max_len,), np.nan)
90 | data[:len(array)] = array
91 | list_.append(data)
92 | mean_r, std_r = np.nanmean(rews, axis=0), np.nanstd(rews, axis=0)
93 | mean_l, std_l = np.nanmean(lens, axis=0), np.nanstd(lens, axis=0)
94 | mean_d, std_d = np.nanmean(decs, axis=0), np.nanstd(decs, axis=0)
95 | mean_ts, std_ts = np.nanmean(training_steps, axis=0), np.nanstd(training_steps, axis=0)
96 | mean_te, std_te = np.nanmean(training_eps, axis=0), np.nanstd(training_eps, axis=0)
97 | if succ_threashold:
98 | try:
99 | print('\t {}/{} ({}%) runs exceeded a final performance of {} after {} training steps'.format(
100 | succ_count, len(frames), succ_count / len(frames) * 100, succ_threashold, max_steps
101 | ))
102 | except ZeroDivisionError:
103 | pass
104 | if debug:
105 | print('#' * 80, '\n')
106 |
107 | return (mean_r, std_r), (mean_l, std_l), (mean_d, std_d), (mean_ts, std_ts), (mean_te, std_te)
108 |
--------------------------------------------------------------------------------
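
A hedged usage sketch for load_data above. The experiment_dir value and the <method>-experiments<exp_version> directory level are assumptions about how runs are organised on disk; the per-seed sub-directories must follow the <episodes>_<max_skip>_<max_steps>_<seed> naming produced by run_tabular_experiments.py with --out-dir-suffix paramsseed.

# Hypothetical call: aggregate tabular results over all seeds found on disk.
from utils.data_handling import load_data

test_rewards, test_lens, steps_per_ep = load_data(
    experiment_dir='experiments/tabular',  # assumption: root holding sq-experiments-1.0-linear/ etc.
    methods=['sq', 'q'],
    exp_version='-1.0-linear',             # default naming: epsilon start 1.0, linear decay
    episodes=10_000, max_skip=7, max_steps=100,
    local=True,
)
# Each entry is an array of shape (n_seeds, n_evaluations).
print(test_rewards['sq'].shape, test_rewards['q'].shape)
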
/utils/env_wrappers.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import gym
3 | import gym.spaces
4 | import numpy as np
5 | import collections
6 | from ray.rllib.env.atari_wrappers import wrap_deepmind, WarpFrame
7 |
8 |
9 | class FireResetEnv(gym.Wrapper):
10 | def __init__(self, env=None):
11 | """For environments where the user needs to press FIRE for the game to start."""
12 | super(FireResetEnv, self).__init__(env)
13 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
14 | assert len(env.unwrapped.get_action_meanings()) >= 3
15 |
16 | def step(self, action):
17 | return self.env.step(action)
18 |
19 | def reset(self):
20 | self.env.reset()
21 | obs, _, done, _ = self.env.step(1)
22 | if done:
23 | self.env.reset()
24 | obs, _, done, _ = self.env.step(2)
25 | if done:
26 | self.env.reset()
27 | return obs
28 |
29 |
30 | class MaxAndSkipEnv(gym.Wrapper):
31 | def __init__(self, env=None, skip=4):
32 | """Return only every `skip`-th frame"""
33 | super(MaxAndSkipEnv, self).__init__(env)
34 | # most recent raw observations (for max pooling across time steps)
35 | self._obs_buffer = collections.deque(maxlen=2)
36 | self._skip = skip
37 |
38 | def step(self, action):
39 | total_reward = 0.0
40 | done = None
41 | for _ in range(self._skip):
42 | obs, reward, done, info = self.env.step(action)
43 | self._obs_buffer.append(obs)
44 | total_reward += reward
45 | if done:
46 | break
47 | max_frame = np.max(np.stack(self._obs_buffer), axis=0)
48 | return max_frame, total_reward, done, info
49 |
50 | def reset(self):
51 | """Clear past frame buffer and init. to first obs. from inner env."""
52 | self._obs_buffer.clear()
53 | obs = self.env.reset()
54 | self._obs_buffer.append(obs)
55 | return obs
56 |
57 |
58 | class ProcessFrame84(gym.ObservationWrapper):
59 | def __init__(self, env=None):
60 | super(ProcessFrame84, self).__init__(env)
61 | self.observation_space = gym.spaces.Box(
62 | low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
63 |
64 | def observation(self, obs):
65 | return ProcessFrame84.process(obs)
66 |
67 | @staticmethod
68 | def process(frame):
69 | if frame.size == 210 * 160 * 3:
70 | img = np.reshape(frame, [210, 160, 3]).astype(
71 | np.float32)
72 | elif frame.size == 250 * 160 * 3:
73 | img = np.reshape(frame, [250, 160, 3]).astype(
74 | np.float32)
75 | else:
76 | assert False, "Unknown resolution."
77 | img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + \
78 | img[:, :, 2] * 0.114
79 | resized_screen = cv2.resize(
80 | img, (84, 110), interpolation=cv2.INTER_AREA)
81 | x_t = resized_screen[18:102, :]
82 | x_t = np.reshape(x_t, [84, 84, 1])
83 | return x_t.astype(np.uint8)
84 |
85 |
86 | class ImageToPyTorch(gym.ObservationWrapper):
87 | def __init__(self, env):
88 | super(ImageToPyTorch, self).__init__(env)
89 | old_shape = self.observation_space.shape
90 | new_shape = (old_shape[-1], old_shape[0], old_shape[1])
91 | self.observation_space = gym.spaces.Box(
92 | low=0.0, high=1.0, shape=new_shape, dtype=np.float32)
93 |
94 | def observation(self, observation):
95 | return np.moveaxis(observation, 2, 0)
96 |
97 |
98 | class ScaledFloatFrame(gym.ObservationWrapper):
99 | def observation(self, obs):
100 | return np.array(obs).astype(np.float32) / 255.0
101 |
102 |
103 | class BufferWrapper(gym.ObservationWrapper):
104 | def __init__(self, env, n_steps, dtype=np.float32):
105 | super(BufferWrapper, self).__init__(env)
106 | self.dtype = dtype
107 | old_space = env.observation_space
108 | self.observation_space = gym.spaces.Box(
109 | old_space.low.repeat(n_steps, axis=0),
110 | old_space.high.repeat(n_steps, axis=0), dtype=dtype)
111 |
112 | def reset(self):
113 | self.buffer = np.zeros_like(
114 | self.observation_space.low, dtype=self.dtype)
115 | return self.observation(self.env.reset())
116 |
117 | def observation(self, observation):
118 | self.buffer[:-1] = self.buffer[1:]
119 | self.buffer[-1] = observation
120 | return self.buffer
121 |
122 |
123 | def make_env(env_name, dim=42):
124 | env = gym.make(env_name)
125 | env = wrap_deepmind(env, dim=dim, framestack=True) # grayscales already
126 | env = ImageToPyTorch(env)
127 | return env
128 |
129 | def make_env_old(env_name, dim=42):
130 | env = gym.make(env_name)
131 | env = MaxAndSkipEnv(env, skip=1)
132 | env = FireResetEnv(env)
133 | # env = ProcessFrame84(env)
134 | env = WarpFrame(env, dim=dim)
135 | env = ImageToPyTorch(env)
136 | env = BufferWrapper(env, 4) # Stacks frames
137 | return ScaledFloatFrame(env) # Normalises observations
138 |
--------------------------------------------------------------------------------
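
A small sketch of what make_env above returns for an Atari task. It assumes gym with the Atari ROMs installed and a ray release that still ships ray.rllib.env.atari_wrappers; the environment name is only illustrative.

# Sketch: build the wrapped Atari env and inspect its channels-first observation layout.
from utils.env_wrappers import make_env

env = make_env('PongNoFrameskip-v4', dim=42)
obs = env.reset()
# wrap_deepmind warps to dim x dim grayscale and stacks 4 frames; ImageToPyTorch
# then moves the channel axis to the front, so (4, 42, 42) is the expected shape here.
print(env.observation_space.shape)
print(obs.shape, obs.dtype)

obs, reward, done, info = env.step(env.action_space.sample())
print(reward, done)
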
/utils/experiments.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import os
5 | import sys
6 | import tempfile
7 |
8 |
9 | def prepare_output_dir(args, user_specified_dir=None, argv=None,
10 | time_format='%Y%m%dT%H%M%S.%f'):
11 | """
12 | Largely a copy of chainerRL's prepare_output_dir
13 | See (https://github.com/chainer/chainerrl/blob/018a29132d77e5af0f92161250c72aba10c6ce29/chainerrl/experiments/prepare_output_dir.py)
14 | Prepare a directory for outputting training results.
15 |
16 | An output directory, which ends with the current datetime string,
17 | is created. Then the following information is saved into the directory:
18 |
19 | args.txt: command line arguments
20 | command.txt: command itself
21 | environ.txt: environmental variables
22 |
23 | Args:
24 | args (dict or argparse.Namespace): Arguments to save
25 | user_specified_dir (str or None): If str is specified, the output
26 | directory is created under that path. If not specified, it is
27 | created as a new temporary directory instead.
28 | argv (list or None): The list of command line arguments passed to a
29 | script. If not specified, sys.argv is used instead.
30 | time_format (str): Format used to represent the current datetime. The
31 | default format is the basic format of ISO 8601.
32 | Returns:
33 | Path of the output directory created by this function (str).
34 | """
35 | time_str = datetime.datetime.now().strftime(time_format)
36 | if user_specified_dir is not None:
37 | if os.path.exists(user_specified_dir):
38 | if not os.path.isdir(user_specified_dir):
39 | raise RuntimeError(
40 | '{} is not a directory'.format(user_specified_dir))
41 | outdir = os.path.join(user_specified_dir, time_str)
42 | if os.path.exists(outdir):
43 | raise RuntimeError('{} exists'.format(outdir))
44 | else:
45 | os.makedirs(outdir)
46 | else:
47 | outdir = tempfile.mkdtemp(prefix=time_str)
48 |
49 | # Save all the arguments
50 | with open(os.path.join(outdir, 'args.txt'), 'w') as f:
51 | if isinstance(args, argparse.Namespace):
52 | args = vars(args)
53 | f.write(json.dumps(args))
54 |
55 | # Save all the environment variables
56 | with open(os.path.join(outdir, 'environ.txt'), 'w') as f:
57 | f.write(json.dumps(dict(os.environ)))
58 |
59 | # Save the command
60 | with open(os.path.join(outdir, 'command.txt'), 'w') as f:
61 | if argv is None:
62 | argv = sys.argv
63 | f.write(' '.join(argv))
64 |
65 | print('Results stored in {:s}'.format(os.path.abspath(outdir)))
66 | return outdir
67 |
--------------------------------------------------------------------------------
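
A short usage sketch for prepare_output_dir; the directory name and Namespace fields are placeholders.

# Sketch: create a run directory and record args/command/environ inside it.
import argparse
from utils.experiments import prepare_output_dir

args = argparse.Namespace(episodes=10_000, seed=0)   # any dict or argparse.Namespace works
out_dir = prepare_output_dir(args, user_specified_dir='experiments/demo')
# out_dir is experiments/demo/<timestamp> and now contains args.txt, command.txt and environ.txt.
print(out_dir)
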
/utils/plotting.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | import seaborn as sb
4 |
5 | from utils.data_handling import load_config
6 |
7 |
8 | def get_colors():
9 | cfg = load_config()
10 | sb.set_style(cfg['plotting']['seaborn']['style'])
11 | sb.set_context(cfg['plotting']['seaborn']['context']['context'],
12 | font_scale=cfg['plotting']['seaborn']['context']['font scale'],
13 | rc=cfg['plotting']['seaborn']['context']['rc'])
14 | colors = list(sb.color_palette(cfg['plotting']['palette']))
15 | color_map = cfg['plotting']['color_map']
16 | return colors, color_map
17 |
18 |
19 | def plot_lens(ax, methods, lens, steps_per_ep, title, episodes=50_000,
20 | log=False, logx=False, eval_steps=1):
21 | colors, color_map = get_colors()
22 | print('AVG pol length')
23 | mv = 0
24 | created_steps_leg = False
25 | for method in methods:
26 | if method != 'hq':
27 | m, s = lens[method].mean(axis=0), lens[method].std(axis=0)
28 | total_ones = np.ones(m.shape) * 100
29 | # print("{:>4s}: {:>5.2f}".format(method, 1 - np.sum(total_ones - m) / np.sum(total_ones)))
30 | print("{:>4s}: {:>5.2f}".format(method, np.mean(m)))
31 | ax.step(np.arange(1, m.shape[0] + 1) * eval_steps, m, where='post',
32 | c=colors[color_map[method]])
33 | ax.fill_between(
34 | np.arange(1, m.shape[0] + 1) * eval_steps, m - s, m + s, alpha=0.25, step='post',
35 | color=colors[color_map[method]])
36 | if 'hq' not in methods:
37 | mv = max(mv, max(m) + max(m) * .05)
38 | else:
39 | raise NotImplementedError
40 | if steps_per_ep:
41 | m, s = steps_per_ep[method].mean(axis=0), steps_per_ep[method].std(axis=0)
42 | ax.step(np.arange(1, m.shape[0] + 1) * eval_steps, m, where='post',
43 | c=np.array(colors[color_map[method]]) * .75, ls=':')
44 | ax.fill_between(
45 | np.arange(1, m.shape[0] + 1) * eval_steps, m - s, m + s, alpha=0.125, step='post',
46 | color=np.array(colors[color_map[method]]) * .75)
47 | mv = max(mv, max(m))
48 | if not created_steps_leg:
49 | ax.plot([-999, -999], [-999, -999], ls=':', c='k', label='all')
50 | ax.plot([-999, -999], [-999, -999], ls='-', c='k', label='dec')
51 | created_steps_leg = True
52 | if log:
53 | ax.set_ylim([1, max(mv, 100)])
54 | ax.semilogy()
55 | else:
56 | ax.set_ylim([0, mv])
57 | if logx:
58 | ax.set_ylim([1, max(mv, 100)])
59 | ax.set_xlim([1, episodes * eval_steps])
60 | ax.semilogx()
61 | else:
62 | ax.set_xlim([0, episodes * eval_steps])
63 | ax.set_ylabel('#Steps')
64 | if steps_per_ep:
65 | ax.legend(loc='upper right', ncol=1, handlelength=.75)
66 | ax.set_ylabel('#Steps')
67 | ax.set_xlabel('#Episodes')
68 | ax.set_title(title)
69 | return ax
70 |
71 |
72 | def _annotate(ax, rewards, max_reward, eval_steps):
73 | qxy = ((np.where(rewards['q'].mean(axis=0) >= .5 * max_reward)[0])[0] * eval_steps, .5)
74 | sqvxy = ((np.where(rewards['sq'].mean(axis=0) >= .5 * max_reward)[0])[0] * eval_steps, .5)
75 | ax.annotate("", # '{:d}x speedup'.format(int(np.round(qxy[0]/sqvxy[0]))),
76 | xy=qxy,
77 | xycoords='data', xytext=sqvxy, textcoords='data',
78 | arrowprops=dict(arrowstyle="<->", color="0.",
79 | connectionstyle="arc3,rad=0.", lw=5,
80 | ), )
81 |
82 | speedup = qxy[0] / sqvxy[0]
83 | qxy = (qxy[0], .5 * max_reward)
84 | sqvxy = (sqvxy[0], .25)
85 | ax.annotate(r'${:.2f}\times$ speedup'.format(speedup),
86 | xy=qxy,
87 | xycoords='data', xytext=sqvxy, textcoords='data',
88 | arrowprops=dict(arrowstyle="-", color="0.",
89 | connectionstyle="arc3,rad=0.", lw=0
90 | ),
91 | fontsize=22)
92 |
93 | try:
94 | qxy = ((np.where(rewards['q'].mean(axis=0) >= max_reward)[0])[0] * eval_steps, max_reward)
95 | sqvxy = ((np.where(rewards['sq'].mean(axis=0) >= max_reward)[0])[0] * eval_steps, max_reward)
96 | ax.annotate("", # '{:d}x speedup'.format(int(np.round(qxy[0]/sqvxy[0]))),
97 | xy=qxy,
98 | xycoords='data', xytext=sqvxy, textcoords='data',
99 | arrowprops=dict(arrowstyle="<->", color="0.",
100 | connectionstyle="arc3,rad=0.", lw=5,
101 | ), )
102 |
103 | speedup = qxy[0] / sqvxy[0]
104 | qxy = (qxy[0], max_reward)
105 | sqvxy = (sqvxy[0], .75)
106 | ax.annotate(r'${:.2f}\times$ speedup'.format(speedup),
107 | xy=qxy,
108 | xycoords='data', xytext=sqvxy, textcoords='data',
109 | arrowprops=dict(arrowstyle="-", color="0.",
110 | connectionstyle="arc3,rad=0.", lw=0
111 | ),
112 | fontsize=22)
113 | except:
114 | pass
115 |
116 |
117 | def plot_rewards(ax, methods, rewards, title, episodes=50_000,
118 | xlabel='#Episodes', log=False, logx=False, annotate=False, eval_steps=1):
119 | colors, color_map = get_colors()
120 | print('AUC')
121 | min_m = np.inf
122 | max_m = -np.inf
123 | for method in methods:
124 | m, s = rewards[method].mean(axis=0), rewards[method].std(axis=0)
125 | # used for AUC computation
126 | m_, s_ = ((rewards[method] + 1) / 2).mean(axis=0), ((rewards[method] + 1) / 2).std(axis=0)
127 | min_m = min(min(m), min_m)
128 | max_m = max(max(m), max_m)
129 | total_ones = np.ones(m.shape)
130 | label = method
131 | if method == 'sqv3':
132 | label = r"sn-$\mathcal{Q}$"
133 | elif method == 'sq':
134 | label = r"t-$\mathcal{Q}$"
135 | label = label.replace('q', r'$\mathcal{Q}$')
136 | label = r'{:s}'.format(label)
137 | print("{:>2s}: {:>5.2f}".format(method, 1 - np.sum(total_ones - m_) / np.sum(total_ones)))
138 | ax.step(np.arange(1, m.shape[0] + 1) * eval_steps, m, where='post', c=colors[color_map[method]],
139 | label=label)
140 | ax.fill_between(np.arange(1, m.shape[0] + 1) * eval_steps, m - s, m + s, alpha=0.25, step='post',
141 | color=colors[color_map[method]])
142 | if annotate:
143 | _annotate(ax, rewards, max_m, eval_steps)
144 | if log:
145 | ax.set_ylim([.01, max_m])
146 | ax.semilogy()
147 | else:
148 | ax.set_ylim([min(-1, min_m - .1), max(1, max_m + .1)])
149 | if logx:
150 | ax.set_xlim([1, episodes * eval_steps])
151 | ax.semilogx()
152 | else:
153 | ax.set_xlim([0, episodes * eval_steps])
154 | ax.set_ylabel('Reward')
155 | ax.set_xlabel(xlabel)
156 | ax.set_title(title)
157 | ax.legend(ncol=1, loc='lower right', handlelength=.75)
158 | return ax
159 |
160 |
161 | def plot(methods, rewards, lens, steps_per_ep, title, episodes=50_000,
162 | show=True, savefig=None, logleny=True,
163 | logrewy=True, logx=False, annotate=False, eval_steps=1, horizontal=False,
164 | individual=False):
165 | _, _ = get_colors()
166 | if not individual:
167 | if horizontal:
168 | fig, ax = plt.subplots(1, 2, figsize=(32, 5), dpi=100)
169 | else:
170 | fig, ax = plt.subplots(2, figsize=(20, 10), dpi=100)
171 | ax[0] = plot_rewards(ax[0], methods, rewards, title, episodes,
172 | xlabel='#Episodes' if horizontal else '',
173 | log=logrewy, logx=logx, annotate=annotate, eval_steps=eval_steps)
174 | print()
175 | ax[1] = plot_lens(ax[1], methods, lens, steps_per_ep, '', episodes, log=logleny,
176 | logx=logx, eval_steps=eval_steps)
177 | if savefig:
178 | plt.savefig(savefig, dpi=100)
179 | if show:
180 | plt.show()
181 | else:
182 | try:
183 | name, suffix = savefig.split('.')
184 | except:
185 | name = savefig
186 | suffix = 'pdf'
187 | fig, ax = plt.subplots(1, figsize=(10, 4), dpi=100)
188 | ax = plot_rewards(ax, methods, rewards, '', episodes,
189 | xlabel='#Episodes',
190 | log=logrewy, logx=logx, annotate=annotate, eval_steps=eval_steps)
191 | plt.tight_layout()
192 | if savefig:
193 | plt.savefig(name + '_rewards' + '.' + suffix, dpi=100)
194 | if show:
195 | plt.show()
196 | fig, ax = plt.subplots(1, figsize=(10, 4), dpi=100)
197 | ax = plot_lens(ax, methods, lens, steps_per_ep, '', episodes, log=logleny,
198 | logx=logx, eval_steps=eval_steps)
199 | plt.tight_layout()
200 | if savefig:
201 | plt.savefig(name + '_lens' + '.' + suffix, dpi=100)
202 | if show:
203 | plt.show()
204 |
--------------------------------------------------------------------------------
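
Finally, a sketch tying the loaders and the plotting helpers together. Paths, episode counts and the evaluation interval are assumptions that must match the actual runs; note that plot() interprets episodes as the number of evaluation points and scales the x-axis by eval_steps.

# Hedged end-to-end sketch: load tabular results and plot the reward / #steps curves.
from utils.data_handling import load_data
from utils.plotting import plot

rewards, lens, steps_per_ep = load_data(
    experiment_dir='experiments/tabular', methods=['sq', 'q'],
    exp_version='-1.0-linear', episodes=10_000, max_skip=7, max_steps=100)

plot(methods=['sq', 'q'],
     rewards=rewards, lens=lens, steps_per_ep=steps_per_ep,
     title='Tabular TempoRL (t-Q) vs. vanilla Q-learning',
     episodes=100,        # number of evaluation points (10,000 episodes / eval_every=100)
     eval_steps=100,      # evaluation interval -> x-axis spans the full 10,000 episodes
     logrewy=False, logx=True, annotate=False,
     savefig='tabular_comparison.pdf', show=True)
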