├── src
│   ├── __init__.py
│   ├── algos
│   │   ├── __init__.py
│   │   ├── torchbeast.py
│   │   ├── count.py
│   │   ├── rnd.py
│   │   ├── bebold.py
│   │   ├── curiosity.py
│   │   └── ride.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── prof.py
│   │   ├── environment.py
│   │   ├── vtrace.py
│   │   ├── file_writer.py
│   │   └── vtrace_test.py
│   ├── losses.py
│   ├── env_utils.py
│   ├── arguments.py
│   ├── models.py
│   ├── atari_wrappers.py
│   └── utils.py
├── requirements.txt
├── main.py
├── README.md
└── LICENSE
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/src/algos/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/src/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | appdirs==1.4.4
2 | certifi==2019.11.28
3 | cffi==1.13.2
4 | cfgv==3.1.0
5 | cloudpickle==1.2.2
6 | distlib==0.3.0
7 | filelock==3.0.12
8 | future==0.18.2
9 | gitdb2==2.0.6
10 | GitPython==3.0.5
11 | gym==0.15.4
12 | -e git+https://github.com/maximecb/gym-minigrid.git@b84d99a2fbb481977172a85661eee812f022f130#egg=gym_minigrid
13 | gym-super-mario-bros==7.3.0
14 | identify==1.4.19
15 | importlib-metadata==1.6.1
16 | importlib-resources==2.0.0
17 | nes-py==8.1.1
18 | nodeenv==1.4.0
19 | numpy==1.17.4
20 | olefile==0.46
21 | opencv-python==4.1.2.30
22 | Pillow==8.1.1
23 | pre-commit==2.5.1
24 | pycparser==2.19
25 | pyglet==1.3.2
26 | PyQt5==5.13.2
27 | PyQt5-sip==12.7.0
28 | PyYAML==5.4
29 | scipy==1.3.3
30 | six==1.13.0
31 | smmap2==2.0.5
32 | toml==0.10.1
33 | torch==1.1.0
34 | torchvision==0.3.0
35 | tqdm==4.40.2
36 | virtualenv==20.0.21
37 | vizdoom==1.1.8
38 | zipp==3.1.0
39 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from src.arguments import parser
8 |
9 | from src.algos.torchbeast import train as train_vanilla
10 | from src.algos.count import train as train_count
11 | from src.algos.curiosity import train as train_curiosity
12 | from src.algos.rnd import train as train_rnd
13 | from src.algos.ride import train as train_ride
14 | from src.algos.bebold import train as train_bebold
15 |
16 | def main(flags):
17 | print(flags)
18 | if flags.model == 'vanilla':
19 | train_vanilla(flags)
20 | elif flags.model == 'count':
21 | train_count(flags)
22 | elif flags.model == 'curiosity':
23 | train_curiosity(flags)
24 | elif flags.model == 'rnd':
25 | train_rnd(flags)
26 | elif flags.model == 'ride':
27 | train_ride(flags)
28 | elif flags.model == 'bebold':
29 | train_bebold(flags)
30 | else:
31 |         raise NotImplementedError("This model has not been implemented. "
32 |                                   "The available options are: vanilla, count, "
33 |                                   "curiosity, rnd, ride, and bebold.")
34 |
35 | if __name__ == '__main__':
36 | flags = parser.parse_args()
37 | main(flags)
38 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NovelD: A Simple yet Effective Exploration Criterion
2 |
3 | ## Intro
4 |
5 | This is an implementation of the methods proposed in
6 |
7 | *NovelD: A Simple yet Effective Exploration Criterion* and *BeBold: Exploration Beyond the Boundary of Explored Regions*.
8 |
9 | ## Citation
10 | If you use this code in your own work, please cite our paper:
11 | ```
12 | @article{zhang2021noveld,
13 | title={NovelD: A Simple yet Effective Exploration Criterion},
14 | author={Zhang, Tianjun and Xu, Huazhe and Wang, Xiaolong and Wu, Yi and Keutzer, Kurt and Gonzalez, Joseph E and Tian, Yuandong},
15 | journal={Advances in Neural Information Processing Systems},
16 | volume={34},
17 | year={2021}
18 | }
19 | ```
20 | or
21 | ```
22 | @article{zhang2020bebold,
23 | title={BeBold: Exploration Beyond the Boundary of Explored Regions},
24 | author={Zhang, Tianjun and Xu, Huazhe and Wang, Xiaolong and Wu, Yi and Keutzer, Kurt and Gonzalez, Joseph E and Tian, Yuandong},
25 | journal={arXiv preprint arXiv:2012.08621},
26 | year={2020}
27 | }
28 | ```
29 |
30 | ## Installation
31 |
32 | ```
33 | # Install Instructions
34 | conda create -n noveld python=3.7
35 | conda activate noveld
36 | git clone git@github.com:tianjunz/NovelD.git
37 | cd NovelD
38 | pip install -r requirements.txt
39 | ```
40 |
41 | ## Train NovelD on MiniGrid
42 | ```
43 | OMP_NUM_THREADS=1 python main.py --model bebold --env MiniGrid-ObstructedMaze-2Dlhb-v0 --total_frames 500000000 --intrinsic_reward_coef 0.05 --entropy_cost 0.0005
44 | ```
45 |
46 | ## Acknowledgements
47 | Our vanilla RL algorithm is based on [RIDE](https://github.com/facebookresearch/impact-driven-exploration).
48 |
49 | ## License
50 | This code is under the CC-BY-NC 4.0 (Attribution-NonCommercial 4.0 International) license.
51 |
--------------------------------------------------------------------------------
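
The command above can also be reproduced programmatically; a minimal sketch that mirrors the `main.py` dispatch (assuming it is run from the repository root so that `src` is importable):

```
from src.arguments import parser
from src.algos.bebold import train as train_bebold

# Same flags as the MiniGrid command in the README; everything else
# falls back to the defaults defined in src/arguments.py.
flags = parser.parse_args([
    '--model', 'bebold',
    '--env', 'MiniGrid-ObstructedMaze-2Dlhb-v0',
    '--intrinsic_reward_coef', '0.05',
    '--entropy_cost', '0.0005',
])
train_bebold(flags)
```
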
/src/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 | import numpy as np
11 |
12 |
13 | def compute_baseline_loss(advantages):
14 | return 0.5 * torch.sum(torch.mean(advantages**2, dim=1))
15 |
16 | def compute_entropy_loss(logits):
17 | policy = F.softmax(logits, dim=-1)
18 | log_policy = F.log_softmax(logits, dim=-1)
19 | entropy_per_timestep = torch.sum(-policy * log_policy, dim=-1)
20 | return -torch.sum(torch.mean(entropy_per_timestep, dim=1))
21 |
22 | def compute_policy_gradient_loss(logits, actions, advantages):
23 | cross_entropy = F.nll_loss(
24 | F.log_softmax(torch.flatten(logits, 0, 1), dim=-1),
25 | target=torch.flatten(actions, 0, 1),
26 | reduction='none')
27 | cross_entropy = cross_entropy.view_as(advantages)
28 | advantages.requires_grad = False
29 | policy_gradient_loss_per_timestep = cross_entropy * advantages
30 | return torch.sum(torch.mean(policy_gradient_loss_per_timestep, dim=1))
31 |
32 |
33 | def compute_forward_dynamics_loss(pred_next_emb, next_emb):
34 | forward_dynamics_loss = torch.norm(pred_next_emb - next_emb, dim=2, p=2)
35 | return torch.sum(torch.mean(forward_dynamics_loss, dim=1))
36 |
37 |
38 | def compute_inverse_dynamics_loss(pred_actions, true_actions):
39 | inverse_dynamics_loss = F.nll_loss(
40 | F.log_softmax(torch.flatten(pred_actions, 0, 1), dim=-1),
41 | target=torch.flatten(true_actions, 0, 1),
42 | reduction='none')
43 | inverse_dynamics_loss = inverse_dynamics_loss.view_as(true_actions)
44 | return torch.sum(torch.mean(inverse_dynamics_loss, dim=1))
45 |
46 | def compute_rnd_loss(pred_next_emb, next_emb):
47 | forward_dynamics_loss = torch.norm(pred_next_emb - next_emb, dim=2, p=2)
48 | return torch.mean(forward_dynamics_loss)
49 |
50 |
--------------------------------------------------------------------------------
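
A small shape sketch of how these losses are combined (made-up tensors; `T`, `B`, `A` are the unroll length, batch size and number of actions, and the 0.5/0.001 weights mirror the default `--baseline_cost`/`--entropy_cost` flags):

```
import torch
from src.losses import (compute_policy_gradient_loss, compute_baseline_loss,
                        compute_entropy_loss)

T, B, A = 100, 32, 7
logits = torch.randn(T, B, A)            # [T, B, num_actions]
actions = torch.randint(A, (T, B))       # [T, B]
advantages = torch.randn(T, B)           # [T, B], e.g. vs - baseline

pg_loss = compute_policy_gradient_loss(logits, actions, advantages)
baseline_loss = 0.5 * compute_baseline_loss(advantages)
entropy_loss = 0.001 * compute_entropy_loss(logits)
total_loss = pg_loss + baseline_loss + entropy_loss
print(total_loss.item())
```
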
/src/core/prof.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Naive profiling using timeit."""
8 |
9 | import collections
10 | import timeit
11 |
12 |
13 | class Timings:
14 | """Not thread-safe."""
15 |
16 | def __init__(self):
17 | self._means = collections.defaultdict(int)
18 | self._vars = collections.defaultdict(int)
19 | self._counts = collections.defaultdict(int)
20 | self.reset()
21 |
22 | def reset(self):
23 | self.last_time = timeit.default_timer()
24 |
25 | def time(self, name):
26 | """Save an update for event `name`.
27 |
28 | Nerd alarm: We could just store a
29 | collections.defaultdict(list)
30 | and compute means and standard deviations at the end. But thanks to the
31 | clever math in Sutton-Barto
32 | (http://www.incompleteideas.net/book/first/ebook/node19.html) and
33 | https://math.stackexchange.com/a/103025/5051 we can update both the
34 | means and the stds online. O(1) FTW!
35 | """
36 | now = timeit.default_timer()
37 | x = now - self.last_time
38 | self.last_time = now
39 |
40 | n = self._counts[name]
41 |
42 | mean = self._means[name] + (x - self._means[name]) / (n + 1)
43 | var = (n * self._vars[name] + n * (self._means[name] - mean)**2 +
44 | (x - mean)**2) / (n + 1)
45 |
46 | self._means[name] = mean
47 | self._vars[name] = var
48 | self._counts[name] += 1
49 |
50 | def means(self):
51 | return self._means
52 |
53 | def vars(self):
54 | return self._vars
55 |
56 | def stds(self):
57 | return {k: v**0.5 for k, v in self._vars.items()}
58 |
59 | def summary(self, prefix=''):
60 | means = self.means()
61 | stds = self.stds()
62 | total = sum(means.values())
63 |
64 | result = prefix
65 | for k in sorted(means, key=means.get, reverse=True):
66 |             result += '\n %s: %.6fms +- %.6fms (%.2f%%) ' % (
67 | k, 1000 * means[k], 1000 * stds[k], 100 * means[k] / total)
68 | result += '\nTotal: %.6fms' % (1000 * total)
69 | return result
70 |
--------------------------------------------------------------------------------
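
A minimal usage sketch of `Timings` (the same reset/time/summary pattern used by `batch_and_learn` in `src/algos/torchbeast.py`):

```
import time
from src.core import prof

timings = prof.Timings()
for _ in range(3):
    timings.reset()
    time.sleep(0.01)
    timings.time('act')      # time elapsed since reset()
    time.sleep(0.02)
    timings.time('learn')    # time elapsed since the previous time() call
print(timings.summary('Example loop:'))
```
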
/src/core/environment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 |
10 | def _format_frame(frame):
11 | frame = torch.from_numpy(frame)
12 | return frame.view((1, 1) + frame.shape) # (...) -> (T,B,...).
13 |
14 |
15 | class Environment:
16 |
17 | def __init__(self, gym_env, fix_seed=False):
18 | self.gym_env = gym_env
19 | self.episode_return = None
20 | self.episode_step = None
21 | self.fix_seed = fix_seed
22 |
23 | def initial(self):
24 | initial_reward = torch.zeros(1, 1)
25 | # This supports only single-tensor actions ATM.
26 | initial_last_action = torch.zeros(1, 1, dtype=torch.int64)
27 | self.episode_return = torch.zeros(1, 1)
28 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32)
29 | initial_done = torch.zeros(1, 1, dtype=torch.uint8)
30 | if self.fix_seed:
31 | self.gym_env.seed(seed=1)
32 | initial_frame = _format_frame(self.gym_env.reset())
33 | return dict(
34 | frame=initial_frame,
35 | reward=initial_reward,
36 | done=initial_done,
37 | episode_return=self.episode_return,
38 | episode_step=self.episode_step,
39 | last_action=initial_last_action)
40 |
41 |
42 | def step(self, action):
43 | frame, reward, done, unused_info = self.gym_env.step(action.item())
44 | self.episode_step += 1
45 | self.episode_return += reward
46 | episode_step = self.episode_step
47 | episode_return = self.episode_return
48 | if done:
49 | if self.fix_seed:
50 | self.gym_env.seed(seed=1)
51 | frame = self.gym_env.reset()
52 | self.episode_return = torch.zeros(1, 1)
53 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32)
54 |
55 | frame = _format_frame(frame)
56 | reward = torch.tensor(reward).view(1, 1)
57 | done = torch.tensor(done).view(1, 1)
58 |
59 | return dict(
60 | frame=frame,
61 | reward=reward,
62 | done=done,
63 | episode_return=episode_return,
64 | episode_step=episode_step,
65 | last_action=action)
66 |
67 |
68 | def close(self):
69 | self.gym_env.close()
--------------------------------------------------------------------------------
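
A minimal driving sketch (using a toy Gym environment purely for illustration; actions go in as `[1, 1]` tensors and every field comes back time/batch-major):

```
import gym
import torch
from src.core.environment import Environment

env = Environment(gym.make('CartPole-v0'), fix_seed=True)
obs = env.initial()                          # dict of [1, 1, ...] tensors
for _ in range(5):
    action = torch.tensor([[env.gym_env.action_space.sample()]])
    obs = env.step(action)
    print(obs['reward'].item(), obs['done'].item(), obs['episode_step'].item())
env.close()
```
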
/src/core/vtrace.py:
--------------------------------------------------------------------------------
1 | # This file taken from
2 | # https://github.com/deepmind/scalable_agent/blob/
3 | # cd66d00914d56c8ba2f0615d9cdeefcb169a8d70/vtrace.py
4 | # and modified.
5 |
6 | # Copyright 2018 Google LLC
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # https://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | """Functions to compute V-trace off-policy actor critic targets.
21 |
22 | For details and theory see:
23 |
24 | "IMPALA: Scalable Distributed Deep-RL with
25 | Importance Weighted Actor-Learner Architectures"
26 | by Espeholt, Soyer, Munos et al.
27 |
28 | See https://arxiv.org/abs/1802.01561 for the full paper.
29 | """
30 |
31 | import collections
32 |
33 | import torch
34 | import torch.nn.functional as F
35 |
36 | VTraceFromLogitsReturns = collections.namedtuple('VTraceFromLogitsReturns', [
37 | 'vs', 'pg_advantages', 'log_rhos', 'behavior_action_log_probs',
38 | 'target_action_log_probs'
39 | ])
40 |
41 | VTraceReturns = collections.namedtuple('VTraceReturns', 'vs pg_advantages')
42 |
43 |
44 | def action_log_probs(policy_logits, actions):
45 | return -F.nll_loss(
46 | F.log_softmax(torch.flatten(policy_logits, 0, 1), dim=-1),
47 | torch.flatten(actions, 0, 1),
48 | reduction='none').view_as(actions)
49 |
50 |
51 | def from_logits(behavior_policy_logits,
52 | target_policy_logits,
53 | actions,
54 | discounts,
55 | rewards,
56 | values,
57 | bootstrap_value,
58 | clip_rho_threshold=1.0,
59 | clip_pg_rho_threshold=1.0):
60 | """V-trace for softmax policies."""
61 |
62 | target_action_log_probs = action_log_probs(target_policy_logits, actions)
63 | behavior_action_log_probs = action_log_probs(behavior_policy_logits,
64 | actions)
65 | log_rhos = target_action_log_probs - behavior_action_log_probs
66 | vtrace_returns = from_importance_weights(
67 | log_rhos=log_rhos,
68 | discounts=discounts,
69 | rewards=rewards,
70 | values=values,
71 | bootstrap_value=bootstrap_value,
72 | clip_rho_threshold=clip_rho_threshold,
73 | clip_pg_rho_threshold=clip_pg_rho_threshold)
74 | return VTraceFromLogitsReturns(
75 | log_rhos=log_rhos,
76 | behavior_action_log_probs=behavior_action_log_probs,
77 | target_action_log_probs=target_action_log_probs,
78 | **vtrace_returns._asdict())
79 |
80 |
81 | @torch.no_grad()
82 | def from_importance_weights(log_rhos,
83 | discounts,
84 | rewards,
85 | values,
86 | bootstrap_value,
87 | clip_rho_threshold=1.0,
88 | clip_pg_rho_threshold=1.0):
89 | """V-trace from log importance weights."""
90 | with torch.no_grad():
91 | rhos = torch.exp(log_rhos)
92 | if clip_rho_threshold is not None:
93 | clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold)
94 | else:
95 | clipped_rhos = rhos
96 |
97 | cs = torch.clamp(rhos, max=1.0)
98 | # Append bootstrapped value to get [v1, ..., v_t+1]
99 | values_t_plus_1 = torch.cat(
100 | [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
101 | deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)
102 |
103 | acc = torch.zeros_like(bootstrap_value)
104 | result = []
105 | for t in range(discounts.shape[0] - 1, -1, -1):
106 | acc = deltas[t] + discounts[t] * cs[t] * acc
107 | result.append(acc)
108 | result.reverse()
109 | vs_minus_v_xs = torch.stack(result)
110 |
111 | # Add V(x_s) to get v_s.
112 | vs = torch.add(vs_minus_v_xs, values)
113 |
114 | # Advantage for policy gradient.
115 | vs_t_plus_1 = torch.cat(
116 | [vs[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
117 | if clip_pg_rho_threshold is not None:
118 | clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold)
119 | else:
120 | clipped_pg_rhos = rhos
121 | pg_advantages = (clipped_pg_rhos *
122 | (rewards + discounts * vs_t_plus_1 - values))
123 |
124 | # Make sure no gradients backpropagated through the returned values.
125 | return VTraceReturns(vs=vs, pg_advantages=pg_advantages)
126 |
--------------------------------------------------------------------------------
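
A toy invocation as a sanity check (a sketch: with all importance ratios equal to one, V-trace reduces to the plain discounted return towards the bootstrap value):

```
import torch
from src.core import vtrace

T, B = 5, 2
log_rhos = torch.zeros(T, B)            # on-policy: rho = c = 1 everywhere
discounts = torch.full((T, B), 0.99)
rewards = torch.ones(T, B)
values = torch.zeros(T, B)
bootstrap_value = torch.zeros(B)

out = vtrace.from_importance_weights(
    log_rhos=log_rhos, discounts=discounts, rewards=rewards,
    values=values, bootstrap_value=bootstrap_value)
print(out.vs[0])             # 1 + 0.99 + ... + 0.99**4 = 4.901 for every batch entry
print(out.pg_advantages.shape)
```
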
/src/env_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import gym
8 | import torch
9 | from collections import deque, defaultdict
10 | from gym import spaces
11 | import numpy as np
12 | from gym_minigrid.minigrid import OBJECT_TO_IDX, COLOR_TO_IDX
13 |
14 |
15 | def _format_observation(obs):
16 | obs = torch.tensor(obs)
17 | return obs.view((1, 1) + obs.shape)
18 |
19 |
20 | class Minigrid2Image(gym.ObservationWrapper):
21 | def __init__(self, env):
22 | gym.ObservationWrapper.__init__(self, env)
23 | self.observation_space = env.observation_space.spaces['image']
24 |
25 | def observation(self, observation):
26 | return observation['image']
27 |
28 |
29 | class Environment:
30 | def __init__(self, gym_env, fix_seed=False, env_seed=1):
31 | self.gym_env = gym_env
32 | self.episode_return = None
33 | self.episode_step = None
34 | self.episode_win = None
35 | self.fix_seed = fix_seed
36 | self.env_seed = env_seed
37 |
38 | def get_partial_obs(self):
39 | return self.gym_env.env.env.gen_obs()['image']
40 |
41 | def initial(self):
42 | initial_reward = torch.zeros(1, 1)
43 | self.episode_return = torch.zeros(1, 1)
44 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32)
45 | self.episode_win = torch.zeros(1, 1, dtype=torch.int32)
46 | initial_done = torch.ones(1, 1, dtype=torch.uint8)
47 | if self.fix_seed:
48 | self.gym_env.seed(seed=self.env_seed)
49 | initial_frame = _format_observation(self.gym_env.reset())
50 | partial_obs = _format_observation(self.get_partial_obs())
51 |
52 | if self.gym_env.env.env.carrying:
53 | carried_col, carried_obj = torch.LongTensor([[COLOR_TO_IDX[self.gym_env.env.env.carrying.color]]]), torch.LongTensor([[OBJECT_TO_IDX[self.gym_env.env.env.carrying.type]]])
54 | else:
55 | carried_col, carried_obj = torch.LongTensor([[5]]), torch.LongTensor([[1]])
56 |
57 | return dict(
58 | frame=initial_frame,
59 | reward=initial_reward,
60 | done=initial_done,
61 | episode_return=self.episode_return,
62 | episode_step=self.episode_step,
63 | episode_win=self.episode_win,
64 | carried_col = carried_col,
65 | carried_obj = carried_obj,
66 | partial_obs=partial_obs
67 | )
68 |
69 | def step(self, action):
70 | frame, reward, done, _ = self.gym_env.step(action.item())
71 |
72 | self.episode_step += 1
73 | episode_step = self.episode_step
74 |
75 | self.episode_return += reward
76 | episode_return = self.episode_return
77 |
78 | if done and reward > 0:
79 | self.episode_win[0][0] = 1
80 | else:
81 | self.episode_win[0][0] = 0
82 | episode_win = self.episode_win
83 |
84 | if done:
85 | if self.fix_seed:
86 | self.gym_env.seed(seed=self.env_seed)
87 | frame = self.gym_env.reset()
88 | self.episode_return = torch.zeros(1, 1)
89 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32)
90 | self.episode_win = torch.zeros(1, 1, dtype=torch.int32)
91 |
92 | frame = _format_observation(frame)
93 | reward = torch.tensor(reward).view(1, 1)
94 | done = torch.tensor(done).view(1, 1)
95 | partial_obs = _format_observation(self.get_partial_obs())
96 |
97 | if self.gym_env.env.env.carrying:
98 | carried_col, carried_obj = torch.LongTensor([[COLOR_TO_IDX[self.gym_env.env.env.carrying.color]]]), torch.LongTensor([[OBJECT_TO_IDX[self.gym_env.env.env.carrying.type]]])
99 | else:
100 | carried_col, carried_obj = torch.LongTensor([[5]]), torch.LongTensor([[1]])
101 |
102 |
103 | return dict(
104 | frame=frame,
105 | reward=reward,
106 | done=done,
107 | episode_return=episode_return,
108 | episode_step = episode_step,
109 | episode_win = episode_win,
110 | carried_col = carried_col,
111 | carried_obj = carried_obj,
112 | partial_obs=partial_obs
113 | )
114 |
115 | def get_full_obs(self):
116 | env = self.gym_env.unwrapped
117 | full_grid = env.grid.encode()
118 | full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([
119 | OBJECT_TO_IDX['agent'],
120 | COLOR_TO_IDX['red'],
121 | env.agent_dir
122 | ])
123 | return full_grid
124 |
125 | def close(self):
126 | self.gym_env.close()
127 |
128 |
129 | class FrameStack(gym.Wrapper):
130 | def __init__(self, env, k):
131 | """Stack k last frames.
132 | Returns lazy array, which is much more memory efficient.
133 | See Also
134 | --------
135 | baselines.common.atari_wrappers.LazyFrames
136 | """
137 | gym.Wrapper.__init__(self, env)
138 | self.k = k
139 | self.frames = deque([], maxlen=k)
140 | shp = env.observation_space.shape
141 | self.observation_space = spaces.Box(
142 | low=0,
143 | high=255,
144 | shape=(shp[:-1] + (shp[-1] * k,)),
145 | dtype=env.observation_space.dtype)
146 |
147 | def reset(self):
148 | ob = self.env.reset()
149 | for _ in range(self.k):
150 | self.frames.append(ob)
151 | return self._get_ob()
152 |
153 | def step(self, action):
154 | ob, reward, done, info = self.env.step(action)
155 | self.frames.append(ob)
156 | return self._get_ob(), reward, done, info
157 |
158 | def _get_ob(self):
159 | assert len(self.frames) == self.k
160 | return LazyFrames(list(self.frames))
161 |
162 |
163 | class LazyFrames(object):
164 | def __init__(self, frames):
165 | """This object ensures that common frames between the observations are only stored once.
166 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
167 | buffers.
168 | This object should only be converted to numpy array before being passed to the model.
169 | You'd not believe how complex the previous solution was."""
170 | self._frames = frames
171 | self._out = None
172 |
173 | def _force(self):
174 | if self._out is None:
175 | self._out = np.concatenate(self._frames, axis=-1)
176 | self._frames = None
177 | return self._out
178 |
179 | def __array__(self, dtype=None):
180 | out = self._force()
181 | if dtype is not None:
182 | out = out.astype(dtype)
183 | return out
184 |
185 | def __len__(self):
186 | return len(self._force())
187 |
188 | def __getitem__(self, i):
189 | return self._force()[i]
190 |
191 |
192 |
--------------------------------------------------------------------------------
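
A minimal sketch of wrapping a MiniGrid environment the way the training code expects (assuming `Minigrid2Image(gym.make(...))`, which yields the two `.env` levels that `get_partial_obs()` unwraps; `MiniGrid-Empty-8x8-v0` is used only as a small example):

```
import gym
import gym_minigrid  # noqa: F401  (registers the MiniGrid environments)
import torch
from src.env_utils import Minigrid2Image, Environment

env = Environment(Minigrid2Image(gym.make('MiniGrid-Empty-8x8-v0')),
                  fix_seed=True, env_seed=1)
obs = env.initial()
print(obs['partial_obs'].shape)          # torch.Size([1, 1, 7, 7, 3])
action = torch.tensor([[env.gym_env.action_space.sample()]])
obs = env.step(action)
print(obs['reward'].item(), obs['episode_win'].item())
env.close()
```
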
/src/core/file_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import copy
8 | import datetime
9 | import csv
10 | import json
11 | import logging
12 | import os
13 | import time
14 | from typing import Dict
15 |
16 | import git
17 |
18 |
19 | def gather_metadata() -> Dict:
20 | date_start = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
21 | # gathering git metadata
22 | try:
23 | repo = git.Repo(search_parent_directories=True)
24 | git_sha = repo.commit().hexsha
25 | git_data = dict(
26 | commit=git_sha,
27 | branch=repo.active_branch.name,
28 | is_dirty=repo.is_dirty(),
29 | path=repo.git_dir,
30 | )
31 | except git.InvalidGitRepositoryError:
32 | git_data = None
33 | # gathering slurm metadata
34 | if 'SLURM_JOB_ID' in os.environ:
35 | slurm_env_keys = [k for k in os.environ if k.startswith('SLURM')]
36 | slurm_data = {}
37 | for k in slurm_env_keys:
38 | d_key = k.replace('SLURM_', '').replace('SLURMD_', '').lower()
39 | slurm_data[d_key] = os.environ[k]
40 | else:
41 | slurm_data = None
42 | return dict(
43 | date_start=date_start,
44 | date_end=None,
45 | successful=False,
46 | git=git_data,
47 | slurm=slurm_data,
48 | env=os.environ.copy(),
49 | )
50 |
51 |
52 | class FileWriter:
53 | def __init__(self,
54 | xpid: str = None,
55 | xp_args: dict = None,
56 | rootdir: str = '~/palaas'):
57 | if not xpid:
58 | # make unique id
59 | xpid = '{proc}_{unixtime}'.format(
60 | proc=os.getpid(), unixtime=int(time.time()))
61 | self.xpid = xpid
62 | self._tick = 0
63 |
64 | # metadata gathering
65 | if xp_args is None:
66 | xp_args = {}
67 | self.metadata = gather_metadata()
68 | # we need to copy the args, otherwise when we close the file writer
69 | # (and rewrite the args) we might have non-serializable objects (or
70 | # other nasty stuff).
71 | self.metadata['args'] = copy.deepcopy(xp_args)
72 | self.metadata['xpid'] = self.xpid
73 |
74 | formatter = logging.Formatter('%(message)s')
75 | self._logger = logging.getLogger('palaas/out')
76 |
77 | # to stdout handler
78 | shandle = logging.StreamHandler()
79 | shandle.setFormatter(formatter)
80 | self._logger.addHandler(shandle)
81 | self._logger.setLevel(logging.INFO)
82 |
83 | rootdir = os.path.expandvars(os.path.expanduser(rootdir))
84 | # to file handler
85 | self.basepath = os.path.join(rootdir, self.xpid)
86 |
87 | if not os.path.exists(self.basepath):
88 | self._logger.info('Creating log directory: %s', self.basepath)
89 | os.makedirs(self.basepath, exist_ok=True)
90 | else:
91 | self._logger.info('Found log directory: %s', self.basepath)
92 |
93 | # NOTE: remove latest because it creates errors when running on slurm
94 | # multiple jobs trying to write to latest but cannot find it
95 | # Add 'latest' as symlink unless it exists and is no symlink.
96 | # symlink = os.path.join(rootdir, 'latest')
97 | # if os.path.islink(symlink):
98 | # os.remove(symlink)
99 | # if not os.path.exists(symlink):
100 | # os.symlink(self.basepath, symlink)
101 | # self._logger.info('Symlinked log directory: %s', symlink)
102 |
103 | self.paths = dict(
104 | msg='{base}/out.log'.format(base=self.basepath),
105 | logs='{base}/logs.csv'.format(base=self.basepath),
106 | fields='{base}/fields.csv'.format(base=self.basepath),
107 | meta='{base}/meta.json'.format(base=self.basepath),
108 | )
109 |
110 | self._logger.info('Saving arguments to %s', self.paths['meta'])
111 | if os.path.exists(self.paths['meta']):
112 | self._logger.warning('Path to meta file already exists. '
113 | 'Not overriding meta.')
114 | else:
115 | self._save_metadata()
116 |
117 | self._logger.info('Saving messages to %s', self.paths['msg'])
118 | if os.path.exists(self.paths['msg']):
119 | self._logger.warning('Path to message file already exists. '
120 | 'New data will be appended.')
121 |
122 | fhandle = logging.FileHandler(self.paths['msg'])
123 | fhandle.setFormatter(formatter)
124 | self._logger.addHandler(fhandle)
125 |
126 | self._logger.info('Saving logs data to %s', self.paths['logs'])
127 | self._logger.info('Saving logs\' fields to %s', self.paths['fields'])
128 | if os.path.exists(self.paths['logs']):
129 | self._logger.warning('Path to log file already exists. '
130 | 'New data will be appended.')
131 | with open(self.paths['fields'], 'r') as csvfile:
132 | reader = csv.reader(csvfile)
133 | self.fieldnames = list(reader)[0]
134 | else:
135 | self.fieldnames = ['_tick', '_time']
136 |
137 | def log(self, to_log: Dict, tick: int = None,
138 | verbose: bool = False) -> None:
139 | if tick is not None:
140 | raise NotImplementedError
141 | else:
142 | to_log['_tick'] = self._tick
143 | self._tick += 1
144 | to_log['_time'] = time.time()
145 |
146 | old_len = len(self.fieldnames)
147 | for k in to_log:
148 | if k not in self.fieldnames:
149 | self.fieldnames.append(k)
150 | if old_len != len(self.fieldnames):
151 | with open(self.paths['fields'], 'w') as csvfile:
152 | writer = csv.writer(csvfile)
153 | writer.writerow(self.fieldnames)
154 | self._logger.info('Updated log fields: %s', self.fieldnames)
155 |
156 | if to_log['_tick'] == 0:
157 | # print("\ncreating logs file ")
158 | with open(self.paths['logs'], 'a') as f:
159 | f.write('# %s\n' % ','.join(self.fieldnames))
160 |
161 | if verbose:
162 | self._logger.info('LOG | %s', ', '.join(
163 | ['{}: {}'.format(k, to_log[k]) for k in sorted(to_log)]))
164 |
165 | with open(self.paths['logs'], 'a') as f:
166 | writer = csv.DictWriter(f, fieldnames=self.fieldnames)
167 | writer.writerow(to_log)
168 | # print("\nadded to log file")
169 |
170 | def close(self, successful: bool = True) -> None:
171 | self.metadata['date_end'] = datetime.datetime.now().strftime(
172 | '%Y-%m-%d %H:%M:%S.%f')
173 | self.metadata['successful'] = successful
174 | self._save_metadata()
175 |
176 | def _save_metadata(self) -> None:
177 | with open(self.paths['meta'], 'w') as jsonfile:
178 | json.dump(self.metadata, jsonfile, indent=4, sort_keys=True)
179 |
--------------------------------------------------------------------------------
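
A minimal usage sketch of `FileWriter` (logging a couple of made-up metrics to `./experiments/demo-run`):

```
from src.core.file_writer import FileWriter

plogger = FileWriter(xpid='demo-run',
                     xp_args={'learning_rate': 1e-4},
                     rootdir='./experiments')
for step in range(3):
    plogger.log({'frames': step * 3200, 'total_loss': 1.0 / (step + 1)},
                verbose=True)
plogger.close()
# Writes out.log, logs.csv, fields.csv and meta.json under ./experiments/demo-run/.
```
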
/src/arguments.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 |
9 | parser = argparse.ArgumentParser(description='PyTorch Scalable Agent')
10 |
11 | # General Settings.
12 | parser.add_argument('--env', type=str, default='MiniGrid-ObstructedMaze-2Dlh-v0',
13 | help='Gym environment. Other options are: SuperMarioBros-1-1-v0 \
14 | or VizdoomMyWayHomeDense-v0 etc.')
15 | parser.add_argument('--xpid', default=None,
16 | help='Experiment id (default: None).')
17 | parser.add_argument('--num_input_frames', default=1, type=int,
18 |                     help='Number of input frames to the model and state embedding, including the current frame. \
19 |                     When num_input_frames > 1, it will also take the previous num_input_frames - 1 frames as input.')
20 | parser.add_argument('--run_id', default=0, type=int,
21 | help='Run id used for running multiple instances of the same HP set \
22 | (instead of a different random seed since torchbeast does not accept this).')
23 | parser.add_argument('--seed', default=0, type=int,
24 | help='Environment seed.')
25 | parser.add_argument('--save_interval', default=10, type=int, metavar='N',
26 | help='Time interval (in minutes) at which to save the model.')
27 | parser.add_argument('--checkpoint_num_frames', default=10000000, type=int,
28 | help='Number of frames for checkpoint to load.')
29 |
30 | # Training settings.
31 | parser.add_argument('--disable_checkpoint', action='store_true',
32 | help='Disable saving checkpoint.')
33 | parser.add_argument('--savedir', default='./experiments/',
34 | help='Root dir where experiment data will be saved.')
35 | parser.add_argument('--num_actors', default=40, type=int, metavar='N',
36 | help='Number of actors.')
37 | parser.add_argument('--total_frames', default=500000000, type=int, metavar='T',
38 | help='Total environment frames to train for.')
39 | parser.add_argument('--batch_size', default=32, type=int, metavar='B',
40 | help='Learner batch size.')
41 | parser.add_argument('--unroll_length', default=100, type=int, metavar='T',
42 | help='The unroll length (time dimension).')
43 | parser.add_argument('--queue_timeout', default=1, type=int,
44 | metavar='S', help='Error timeout for queue.')
45 | parser.add_argument('--num_buffers', default=80, type=int,
46 | metavar='N', help='Number of shared-memory buffers.')
47 | parser.add_argument('--num_threads', default=4, type=int,
48 |                     metavar='N', help='Number of learner threads.')
49 | parser.add_argument('--disable_cuda', action='store_true',
50 | help='Disable CUDA.')
51 | parser.add_argument('--max_grad_norm', default=40., type=float,
52 | metavar='MGN', help='Max norm of gradients.')
53 |
54 | # Loss settings.
55 | parser.add_argument('--entropy_cost', default=0.001, type=float,
56 | help='Entropy cost/multiplier.')
57 | parser.add_argument('--baseline_cost', default=0.5, type=float,
58 | help='Baseline cost/multiplier.')
59 | parser.add_argument('--discounting', default=0.99, type=float,
60 | help='Discounting factor.')
61 |
62 | # Optimizer settings.
63 | parser.add_argument('--learning_rate', default=0.0001, type=float,
64 | metavar='LR', help='Learning rate.')
65 | parser.add_argument('--predictor_learning_rate', default=0.0001, type=float,
66 | metavar='LR', help='Learning rate for RND predictor.')
67 | parser.add_argument('--alpha', default=0.99, type=float,
68 | help='RMSProp smoothing constant.')
69 | parser.add_argument('--momentum', default=0, type=float,
70 | help='RMSProp momentum.')
71 | parser.add_argument('--epsilon', default=1e-5, type=float,
72 | help='RMSProp epsilon.')
73 |
74 | # Exploration Settings.
75 | parser.add_argument('--forward_loss_coef', default=10.0, type=float,
76 |                     help='Coefficient for the forward dynamics loss. \
77 |                     This weighs the forward model loss against the inverse model loss. \
78 |                     Should be between 0 and 1.')
79 | parser.add_argument('--inverse_loss_coef', default=0.1, type=float,
80 |                     help='Coefficient for the inverse dynamics loss. \
81 |                     This weighs the inverse model loss against the forward model loss. \
82 |                     Should be between 0 and 1.')
83 | parser.add_argument('--intrinsic_reward_coef', default=0.5, type=float,
84 |                     help='Coefficient for the intrinsic reward. \
85 |                     This weighs the intrinsic reward against the extrinsic one. \
86 |                     Should be larger than 0.')
87 | parser.add_argument('--rnd_loss_coef', default=1.0, type=float,
88 |                     help='Coefficient for the RND loss relative to the IMPALA loss.')
89 |
90 | # Singleton Environments.
91 | parser.add_argument('--fix_seed', action='store_true',
92 | help='Fix the environment seed so that it is \
93 | no longer procedurally generated but rather the same layout every episode.')
94 | parser.add_argument('--env_seed', default=1, type=int,
95 | help='The seed used to generate the environment if we are using a \
96 | singleton (i.e. not procedurally generated) environment.')
97 | parser.add_argument('--no_reward', action='store_true',
98 | help='No extrinsic reward. The agent uses only intrinsic reward to learn.')
99 |
100 | # Training Models.
101 | parser.add_argument('--model', default='vanilla',
102 |                     choices=['vanilla', 'count', 'curiosity', 'rnd', 'ride', 'bebold'],
103 |                     help='Model used for training the agent.')
104 |
105 | # Baselines for AMIGo paper.
106 | parser.add_argument('--use_fullobs_policy', action='store_true',
107 | help='Use a full view of the environment as input to the policy network.')
108 | parser.add_argument('--use_fullobs_intrinsic', action='store_true',
109 | help='Use a full view of the environment for computing the intrinsic reward.')
110 | parser.add_argument('--target_update_freq', default=2, type=int,
111 |                     help='Number of time steps between target network updates.')
112 | parser.add_argument('--init_num_frames', default=1e6, type=int,
113 |                     help='Number of frames for updating the teacher network.')
114 | parser.add_argument('--planning_intrinsic_reward_coef', default=0.5, type=float,
115 |                     help='Coefficient for the planning intrinsic reward. \
116 |                     This weighs the intrinsic reward against the extrinsic one. \
117 |                     Should be larger than 0.')
118 | parser.add_argument('--ema_momentum', default=1.0, type=float,
119 |                     help='Coefficient for the EMA update of the RND network.')
120 | parser.add_argument('--use_lstm', action='store_true',
121 |                     help='Use an LSTM version of the policy network.')
122 | parser.add_argument('--use_lstm_intrinsic', action='store_true',
123 |                     help='Use an LSTM version of the intrinsic embedding network.')
124 | parser.add_argument('--state_embedding_dim', default=256, type=int,
125 |                     help='Embedding dimension of the last layer of the network.')
126 | parser.add_argument('--scale_fac', default=0.5, type=float,
127 |                     help='Coefficient for scaling the visitation count difference.')
128 |
--------------------------------------------------------------------------------
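
For reference, a sketch of inspecting the parsed defaults (the learner consumes `unroll_length * batch_size` frames per optimization step, which is what the `frames` counter in the training loops advances by):

```
from src.arguments import parser

flags = parser.parse_args([])            # no overrides: inspect the defaults
print(flags.model, flags.env)            # vanilla MiniGrid-ObstructedMaze-2Dlh-v0
print(flags.unroll_length * flags.batch_size)   # 3200 frames per learner step
```
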
/src/algos/torchbeast.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import threading
10 | import time
11 | import timeit
12 | import pprint
13 |
14 | import numpy as np
15 |
16 | import torch
17 | from torch import multiprocessing as mp
18 | from torch import nn
19 | from torch.nn import functional as F
20 |
21 | from src.core import file_writer
22 | from src.core import prof
23 | from src.core import vtrace
24 |
25 | import src.models as models
26 | import src.losses as losses
27 |
28 | from src.env_utils import FrameStack
29 | from src.utils import get_batch, log, create_env, create_buffers, act
30 |
31 |
32 | MinigridPolicyNet = models.MinigridPolicyNet
33 | # MarioDoomPolicyNet = models.MarioDoomPolicyNet  # only needed for the non-MiniGrid branches below
34 |
35 |
36 | def learn(actor_model,
37 | model,
38 | batch,
39 | initial_agent_state,
40 | optimizer,
41 | scheduler,
42 | flags,
43 | lock=threading.Lock()):
44 | """Performs a learning (optimization) step."""
45 | with lock:
46 | learner_outputs, unused_state = model(batch, initial_agent_state)
47 |
48 | bootstrap_value = learner_outputs['baseline'][-1]
49 |
50 | batch = {key: tensor[1:] for key, tensor in batch.items()}
51 | learner_outputs = {
52 | key: tensor[:-1]
53 | for key, tensor in learner_outputs.items()
54 | }
55 |
56 | rewards = batch['reward']
57 | clipped_rewards = torch.clamp(rewards, -1, 1)
58 |
59 | discounts = (~batch['done']).float() * flags.discounting
60 |
61 | vtrace_returns = vtrace.from_logits(
62 | behavior_policy_logits=batch['policy_logits'],
63 | target_policy_logits=learner_outputs['policy_logits'],
64 | actions=batch['action'],
65 | discounts=discounts,
66 | rewards=clipped_rewards,
67 | values=learner_outputs['baseline'],
68 | bootstrap_value=bootstrap_value)
69 |
70 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'],
71 | batch['action'],
72 | vtrace_returns.pg_advantages)
73 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
74 | vtrace_returns.vs - learner_outputs['baseline'])
75 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
76 | learner_outputs['policy_logits'])
77 |
78 | total_loss = pg_loss + baseline_loss + entropy_loss
79 |
80 | episode_returns = batch['episode_return'][batch['done']]
81 | stats = {
82 | 'mean_episode_return': torch.mean(episode_returns).item(),
83 | 'total_loss': total_loss.item(),
84 | 'pg_loss': pg_loss.item(),
85 | 'baseline_loss': baseline_loss.item(),
86 | 'entropy_loss': entropy_loss.item(),
87 | }
88 |
89 | scheduler.step()
90 | optimizer.zero_grad()
91 | total_loss.backward()
92 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
93 | optimizer.step()
94 |
95 | actor_model.load_state_dict(model.state_dict())
96 | return stats
97 |
98 |
99 | def train(flags):
100 | if flags.xpid is None:
101 | flags.xpid = 'torchbeast-%s' % time.strftime('%Y%m%d-%H%M%S')
102 | plogger = file_writer.FileWriter(
103 | xpid=flags.xpid,
104 | xp_args=flags.__dict__,
105 | rootdir=flags.savedir,
106 | )
107 | checkpointpath = os.path.expandvars(
108 | os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid,
109 | 'model.tar')))
110 |
111 | T = flags.unroll_length
112 | B = flags.batch_size
113 |
114 | flags.device = None
115 | if not flags.disable_cuda and torch.cuda.is_available():
116 | log.info('Using CUDA.')
117 | flags.device = torch.device('cuda')
118 | else:
119 | log.info('Not using CUDA.')
120 | flags.device = torch.device('cpu')
121 | env = create_env(flags)
122 | if flags.num_input_frames > 1:
123 | env = FrameStack(env, flags.num_input_frames)
124 |
125 | if 'MiniGrid' in flags.env:
126 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)
127 | else:
128 | model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)
129 |
130 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags)
131 |
132 | model.share_memory()
133 |
134 | initial_agent_state_buffers = []
135 | for _ in range(flags.num_buffers):
136 | state = model.initial_state(batch_size=1)
137 | for t in state:
138 | t.share_memory_()
139 | initial_agent_state_buffers.append(state)
140 |
141 | actor_processes = []
142 | ctx = mp.get_context('fork')
143 | free_queue = ctx.SimpleQueue()
144 | full_queue = ctx.SimpleQueue()
145 |
146 | episode_state_count_dict = dict()
147 | train_state_count_dict = dict()
148 | for i in range(flags.num_actors):
149 | actor = ctx.Process(
150 | target=act,
151 | args=(i, free_queue, full_queue, model, buffers,
152 | episode_state_count_dict, train_state_count_dict,
153 | initial_agent_state_buffers, flags))
154 | actor.start()
155 | actor_processes.append(actor)
156 |
157 | if 'MiniGrid' in flags.env:
158 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
159 | .to(device=flags.device)
160 | else:
161 | learner_model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)\
162 | .to(device=flags.device)
163 |
164 | optimizer = torch.optim.RMSprop(
165 | learner_model.parameters(),
166 | lr=flags.learning_rate,
167 | momentum=flags.momentum,
168 | eps=flags.epsilon,
169 | alpha=flags.alpha)
170 |
171 |
172 | def lr_lambda(epoch):
173 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames
174 |
175 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
176 |
177 | logger = logging.getLogger('logfile')
178 | stat_keys = [
179 | 'mean_episode_return',
180 | 'total_loss',
181 | 'pg_loss',
182 | 'baseline_loss',
183 | 'entropy_loss',
184 | ]
185 | logger.info('# Step\t%s', '\t'.join(stat_keys))
186 | frames, stats = 0, {}
187 |
188 |
189 | def batch_and_learn(i, lock=threading.Lock()):
190 | """Thread target for the learning process."""
191 | nonlocal frames, stats
192 | timings = prof.Timings()
193 | while frames < flags.total_frames:
194 | timings.reset()
195 | batch, agent_state = get_batch(free_queue, full_queue, buffers,
196 | initial_agent_state_buffers, flags, timings)
197 | stats = learn(model, learner_model, batch, agent_state,
198 | optimizer, scheduler, flags)
199 | timings.time('learn')
200 | with lock:
201 | to_log = dict(frames=frames)
202 | to_log.update({k: stats[k] for k in stat_keys})
203 | plogger.log(to_log)
204 | frames += T * B
205 |
206 | if i == 0:
207 | log.info('Batch and learn: %s', timings.summary())
208 |
209 | for m in range(flags.num_buffers):
210 | free_queue.put(m)
211 |
212 | threads = []
213 | for i in range(flags.num_threads):
214 | thread = threading.Thread(
215 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,))
216 | thread.start()
217 | threads.append(thread)
218 |
219 | def checkpoint(frames):
220 | if flags.disable_checkpoint:
221 | return
222 | checkpointpath = os.path.expandvars(os.path.expanduser(
223 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model.tar')))
224 | log.info('Saving checkpoint to %s', checkpointpath)
225 | torch.save({
226 | 'model_state_dict': model.state_dict(),
227 | 'optimizer_state_dict': optimizer.state_dict(),
228 | 'scheduler_state_dict': scheduler.state_dict(),
229 | 'flags': vars(flags),
230 | }, checkpointpath)
231 |
232 | timer = timeit.default_timer
233 | try:
234 | last_checkpoint_time = timer()
235 | while frames < flags.total_frames:
236 | start_frames = frames
237 | start_time = timer()
238 | time.sleep(5)
239 |
240 | if timer() - last_checkpoint_time > flags.save_interval * 60:
241 | checkpoint(frames)
242 | last_checkpoint_time = timer()
243 |
244 | fps = (frames - start_frames) / (timer() - start_time)
245 |             if stats.get('mean_episode_return') is not None:
246 | mean_return = 'Return per episode: %.1f. ' % stats[
247 | 'mean_episode_return']
248 | else:
249 | mean_return = ''
250 | total_loss = stats.get('total_loss', float('inf'))
251 | log.info('After %i frames: loss %f @ %.1f fps. %sStats:\n%s',
252 | frames, total_loss, fps, mean_return,
253 | pprint.pformat(stats))
254 |
255 | except KeyboardInterrupt:
256 | return
257 | else:
258 | for thread in threads:
259 | thread.join()
260 | log.info('Learning finished after %d frames.', frames)
261 |
262 | finally:
263 | for _ in range(flags.num_actors):
264 | free_queue.put(None)
265 | for actor in actor_processes:
266 | actor.join(timeout=1)
267 | checkpoint(frames)
268 | plogger.close()
269 |
270 |
--------------------------------------------------------------------------------
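
For reference, `lr_lambda` above anneals the learning rate linearly from its initial value to zero over `total_frames`; a standalone sketch with the default flag values:

```
# T = unroll_length, B = batch_size; one scheduler step consumes T * B frames.
T, B, total_frames = 100, 32, 500_000_000

def lr_lambda(epoch):
    return 1 - min(epoch * T * B, total_frames) / total_frames

for step in (0, 50_000, 100_000, 156_250):
    print(step, round(lr_lambda(step), 2))   # 1.0, 0.68, 0.36, 0.0
```
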
/src/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 | import numpy as np
11 |
12 | def init(module, weight_init, bias_init, gain=1):
13 | weight_init(module.weight.data, gain=gain)
14 | bias_init(module.bias.data)
15 | return module
16 |
17 |
18 | class MinigridPolicyNet(nn.Module):
19 | def __init__(self, observation_shape, num_actions):
20 | super(MinigridPolicyNet, self).__init__()
21 | self.observation_shape = observation_shape
22 | self.num_actions = num_actions
23 |
24 | init_ = lambda m: init(m, nn.init.orthogonal_,
25 | lambda x: nn.init.constant_(x, 0),
26 | nn.init.calculate_gain('relu'))
27 |
28 | self.feat_extract = nn.Sequential(
29 | init_(nn.Conv2d(in_channels=self.observation_shape[2], out_channels=32, kernel_size=(3, 3), stride=1, padding=1)),
30 | nn.ELU(),
31 | init_(nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(3, 3), stride=2, padding=1)),
32 | nn.ELU(),
33 | init_(nn.Conv2d(in_channels=128, out_channels=512, kernel_size=(3, 3), stride=2, padding=1)),
34 | nn.ELU(),
35 | )
36 |
37 | self.fc = nn.Sequential(
38 | init_(nn.Linear(2048, 1024)),
39 | nn.ReLU(),
40 | init_(nn.Linear(1024, 1024)),
41 | nn.ReLU(),
42 | )
43 |
44 | self.core = nn.LSTM(1024, 1024, 2)
45 |
46 | init_ = lambda m: init(m, nn.init.orthogonal_,
47 | lambda x: nn.init.constant_(x, 0))
48 |
49 | self.policy = init_(nn.Linear(1024, self.num_actions))
50 | self.baseline = init_(nn.Linear(1024, 1))
51 |
52 |
53 | def initial_state(self, batch_size):
54 | return tuple(torch.zeros(self.core.num_layers, batch_size,
55 | self.core.hidden_size) for _ in range(2))
56 |
57 |
58 | def forward(self, inputs, core_state=()):
59 | # -- [unroll_length x batch_size x height x width x channels]
60 | x = inputs['partial_obs']
61 | T, B, *_ = x.shape
62 |
63 | # -- [unroll_length*batch_size x height x width x channels]
64 | x = torch.flatten(x, 0, 1) # Merge time and batch.
65 | x = x.float()
66 |
67 | # -- [unroll_length*batch_size x channels x width x height]
68 | x = x.permute(0, 3, 1, 2)
69 | x = self.feat_extract(x)
70 | x = x.reshape(T * B, -1)
71 | core_input = self.fc(x)
72 |
73 | core_input = core_input.view(T, B, -1)
74 | core_output_list = []
75 | notdone = (~inputs['done']).float()
76 | for input, nd in zip(core_input.unbind(), notdone.unbind()):
77 | nd = nd.view(1, -1, 1)
78 | core_state = tuple(nd * s for s in core_state)
79 | output, core_state = self.core(input.unsqueeze(0), core_state)
80 | core_output_list.append(output)
81 | core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
82 |
83 | policy_logits = self.policy(core_output)
84 | baseline = self.baseline(core_output)
85 |
86 | if self.training:
87 | action = torch.multinomial(
88 | F.softmax(policy_logits, dim=1), num_samples=1)
89 | else:
90 | action = torch.argmax(policy_logits, dim=1)
91 |
92 | policy_logits = policy_logits.view(T, B, self.num_actions)
93 | baseline = baseline.view(T, B)
94 | action = action.view(T, B)
95 |
96 | return dict(policy_logits=policy_logits, baseline=baseline,
97 | action=action), core_state
98 |
99 |
100 | class MinigridStateEmbeddingNet(nn.Module):
101 | def __init__(self, observation_shape, use_lstm=False):
102 | super(MinigridStateEmbeddingNet, self).__init__()
103 | self.observation_shape = observation_shape
104 | self.use_lstm = use_lstm
105 |
106 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
107 | constant_(x, 0), nn.init.calculate_gain('relu'))
108 |
109 | self.feat_extract = nn.Sequential(
110 | init_(nn.Conv2d(in_channels=self.observation_shape[2], out_channels=32, kernel_size=(3, 3), stride=1, padding=1)),
111 | nn.ELU(),
112 | init_(nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(3, 3), stride=2, padding=1)),
113 | nn.ELU(),
114 | init_(nn.Conv2d(in_channels=128, out_channels=512, kernel_size=(3, 3), stride=2, padding=1)),
115 | nn.ELU(),
116 | )
117 |
118 | self.fc = nn.Sequential(
119 | init_(nn.Linear(2048, 1024)),
120 | nn.ReLU(),
121 | init_(nn.Linear(1024, 1024)),
122 | nn.ReLU(),
123 | )
124 |
125 | if self.use_lstm:
126 | self.core = nn.LSTM(1024, 1024, 2)
127 |
128 | def initial_state(self, batch_size):
129 | #TODO: we might need to change this
130 | return tuple(torch.zeros(2, batch_size,
131 | 1024) for _ in range(2))
132 |
133 | def forward(self, inputs, core_state=(), done=None):
134 | # -- [unroll_length x batch_size x height x width x channels]
135 | x = inputs
136 | T, B, *_ = x.shape
137 |
138 | # -- [unroll_length*batch_size x height x width x channels]
139 | x = torch.flatten(x, 0, 1) # Merge time and batch.
140 |
141 | x = x.float()
142 |
143 | # -- [unroll_length*batch_size x channels x width x height]
144 | x = x.permute(0, 3, 1, 2)
145 | x = self.feat_extract(x)
146 | x = x.reshape(T * B, -1)
147 | x = self.fc(x)
148 |
149 | if self.use_lstm:
150 | core_input = x.view(T, B, -1)
151 | core_output_list = []
152 | notdone = (~done).float()
153 | for input, nd in zip(core_input.unbind(), notdone.unbind()):
154 | nd = nd.view(1, -1, 1)
155 | core_state = tuple(nd * s for s in core_state)
156 | output, core_state = self.core(input.unsqueeze(0), core_state)
157 | core_output_list.append(output)
158 | x = torch.flatten(torch.cat(core_output_list), 0, 1)
159 |
160 | state_embedding = x.view(T, B, -1)
161 | return state_embedding, core_state
162 |
163 | class MinigridMLPEmbeddingNet(nn.Module):
164 | def __init__(self):
165 | super(MinigridMLPEmbeddingNet, self).__init__()
166 |
167 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
168 | constant_(x, 0), nn.init.calculate_gain('relu'))
169 |
170 | self.fc = nn.Sequential(
171 | nn.Linear(1024, 1024),
172 | nn.ReLU(),
173 | nn.Linear(1024, 1024),
174 | nn.ReLU(),
175 | nn.Linear(1024, 128),
176 | nn.ReLU(),
177 | nn.Linear(128, 128),
178 | nn.ReLU(),
179 | nn.Linear(128, 128),
180 | )
181 |
182 | def forward(self, inputs, core_state=()):
183 | x = inputs
184 | T, B, *_ = x.shape
185 |
186 | x = self.fc(x)
187 |
188 | state_embedding = x.reshape(T, B, -1)
189 |
190 | return state_embedding, tuple()
191 |
192 |
193 | class MinigridMLPTargetEmbeddingNet(nn.Module):
194 | def __init__(self):
195 | super(MinigridMLPTargetEmbeddingNet, self).__init__()
196 |
197 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
198 | constant_(x, 0), nn.init.calculate_gain('relu'))
199 |
200 | self.fc = nn.Sequential(
201 | nn.Linear(1024, 128),
202 | nn.ReLU(),
203 | nn.Linear(128, 128),
204 | nn.ReLU(),
205 | nn.Linear(128, 128),
206 | nn.ReLU(),
207 | nn.Linear(128, 128),
208 | )
209 |
210 | def forward(self, inputs, core_state=()):
211 | x = inputs
212 | T, B, *_ = x.shape
213 |
214 | x = self.fc(x)
215 |
216 | state_embedding = x.reshape(T, B, -1)
217 |
218 | return state_embedding, tuple()
219 |
220 |
221 | class MinigridInverseDynamicsNet(nn.Module):
222 | def __init__(self, num_actions):
223 | super(MinigridInverseDynamicsNet, self).__init__()
224 | self.num_actions = num_actions
225 |
226 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
227 | constant_(x, 0), nn.init.calculate_gain('relu'))
228 | self.inverse_dynamics = nn.Sequential(
229 | init_(nn.Linear(2 * 128, 256)),
230 | nn.ReLU(),
231 | )
232 |
233 | init_ = lambda m: init(m, nn.init.orthogonal_,
234 | lambda x: nn.init.constant_(x, 0))
235 | self.id_out = init_(nn.Linear(256, self.num_actions))
236 |
237 |
238 | def forward(self, state_embedding, next_state_embedding):
239 | inputs = torch.cat((state_embedding, next_state_embedding), dim=2)
240 | action_logits = self.id_out(self.inverse_dynamics(inputs))
241 | return action_logits
242 |
243 |
244 | class MinigridForwardDynamicsNet(nn.Module):
245 | def __init__(self, num_actions):
246 | super(MinigridForwardDynamicsNet, self).__init__()
247 | self.num_actions = num_actions
248 |
249 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
250 | constant_(x, 0), nn.init.calculate_gain('relu'))
251 |
252 | self.forward_dynamics = nn.Sequential(
253 | init_(nn.Linear(128 + self.num_actions, 256)),
254 | nn.ReLU(),
255 | )
256 |
257 | init_ = lambda m: init(m, nn.init.orthogonal_,
258 | lambda x: nn.init.constant_(x, 0))
259 |
260 | self.fd_out = init_(nn.Linear(256, 128))
261 |
262 | def forward(self, state_embedding, action):
263 | action_one_hot = F.one_hot(action, num_classes=self.num_actions).float()
264 | inputs = torch.cat((state_embedding, action_one_hot), dim=2)
265 | next_state_emb = self.fd_out(self.forward_dynamics(inputs))
266 | return next_state_emb
267 |
268 |
269 |
--------------------------------------------------------------------------------
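
A shape sketch for `MinigridPolicyNet`, assuming the standard 7x7x3 MiniGrid partial observation (which is what makes the flattened convolutional output 2048-dimensional and match the first linear layer):

```
import torch
from src.models import MinigridPolicyNet

T, B, num_actions = 4, 2, 7
net = MinigridPolicyNet(observation_shape=(7, 7, 3), num_actions=num_actions)

inputs = {
    'partial_obs': torch.randint(0, 11, (T, B, 7, 7, 3)),
    'done': torch.zeros(T, B, dtype=torch.bool),
}
core_state = net.initial_state(batch_size=B)
out, core_state = net(inputs, core_state)
print(out['policy_logits'].shape, out['baseline'].shape, out['action'].shape)
# torch.Size([4, 2, 7]) torch.Size([4, 2]) torch.Size([4, 2])
```
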
/src/core/vtrace_test.py:
--------------------------------------------------------------------------------
1 | # This file taken from
2 | # https://github.com/deepmind/scalable_agent/blob/
3 | # d24bd74bd53d454b7222b7f0bea57a358e4ca33e/vtrace_test.py
4 | # and modified.
5 |
6 | # Copyright 2018 Google LLC
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # https://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | """Tests for V-trace.
21 |
22 | For details and theory see:
23 |
24 | "IMPALA: Scalable Distributed Deep-RL with
25 | Importance Weighted Actor-Learner Architectures"
26 | by Espeholt, Soyer, Munos et al.
27 | """
28 |
29 | import unittest
30 |
31 | import numpy as np
32 | import torch
33 |
34 | import vtrace
35 |
36 |
37 | def _shaped_arange(*shape):
38 | """Runs np.arange, converts to float and reshapes."""
39 | return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape)
40 |
41 |
42 | def _softmax(logits):
43 | """Applies softmax non-linearity on inputs."""
44 | return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
45 |
46 |
47 | def _ground_truth_calculation(discounts, log_rhos, rewards, values,
48 | bootstrap_value, clip_rho_threshold,
49 | clip_pg_rho_threshold):
50 | """Calculates the ground truth for V-trace in Python/Numpy."""
51 | vs = []
52 | seq_len = len(discounts)
53 | rhos = np.exp(log_rhos)
54 | cs = np.minimum(rhos, 1.0)
55 | clipped_rhos = rhos
56 | if clip_rho_threshold:
57 | clipped_rhos = np.minimum(rhos, clip_rho_threshold)
58 | clipped_pg_rhos = rhos
59 | if clip_pg_rho_threshold:
60 | clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold)
61 |
62 | # This is a very inefficient way to calculate the V-trace ground truth.
63 | # We calculate it this way because it is close to the mathematical notation
64 | # of V-trace.
65 | # v_s = V(x_s)
66 | # + \sum^{T-1}_{t=s} \gamma^{t-s}
67 | # * \prod_{i=s}^{t-1} c_i
68 | # * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t))
69 | # Note that when we take the product over c_i, we write `s:t` as the
70 | # notation of the paper is inclusive of the `t-1`, but Python is exclusive.
71 | # Also note that np.prod([]) == 1.
72 | values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]], axis=0)
73 | for s in range(seq_len):
74 | v_s = np.copy(values[s]) # Very important copy.
75 | for t in range(s, seq_len):
76 | v_s += (
77 | np.prod(discounts[s:t], axis=0) * np.prod(
78 | cs[s:t], axis=0) * clipped_rhos[t] *
79 | (rewards[t] + discounts[t] * values_t_plus_1[t + 1] - values[t]
80 | ))
81 | vs.append(v_s)
82 | vs = np.stack(vs, axis=0)
83 | pg_advantages = (clipped_pg_rhos * (rewards + discounts * np.concatenate(
84 | [vs[1:], bootstrap_value[None, :]], axis=0) - values))
85 |
86 | return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages)
87 |
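# Illustrative helper (not used by the tests above): the same vs targets can be
# computed with the O(T) backward recursion from the IMPALA paper,
#     v_s = V(x_s) + delta_s + gamma_s * c_s * (v_{s+1} - V(x_{s+1})),
#     delta_s = clipped_rho_s * (r_s + gamma_s * V(x_{s+1}) - V(x_s)),
# with the recursion started from v_T = bootstrap_value.
def _recursive_ground_truth_vs(discounts, log_rhos, rewards, values,
                               bootstrap_value, clip_rho_threshold=None,
                               **unused_kwargs):
    rhos = np.exp(log_rhos)
    cs = np.minimum(rhos, 1.0)
    clipped_rhos = np.minimum(rhos, clip_rho_threshold) if clip_rho_threshold else rhos
    values_t_plus_1 = np.concatenate([values[1:], bootstrap_value[None, :]], axis=0)
    deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)
    vs_minus_v = np.zeros_like(values)
    acc = np.zeros_like(bootstrap_value)
    for t in reversed(range(len(discounts))):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        vs_minus_v[t] = acc
    return values + vs_minus_v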
88 |
89 | def assert_allclose(actual, desired):
90 | return np.testing.assert_allclose(actual, desired, rtol=1e-06, atol=1e-05)
91 |
92 |
93 | class ActionLogProbsTest(unittest.TestCase):
94 |
95 | def test_action_log_probs(self, batch_size=2):
96 | seq_len = 7
97 | num_actions = 3
98 |
99 | policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10
100 | actions = np.random.randint(
101 | 0, num_actions, size=(seq_len, batch_size), dtype=np.int64)
102 |
103 | action_log_probs_tensor = vtrace.action_log_probs(
104 | torch.from_numpy(policy_logits), torch.from_numpy(actions))
105 |
106 | # Ground Truth
107 | # Using broadcasting to create a mask that indexes action logits
108 | action_index_mask = actions[..., None] == np.arange(num_actions)
109 |
110 | def index_with_mask(array, mask):
111 | return array[mask].reshape(*array.shape[:-1])
112 |
113 | # Note: Normally log(softmax) is not a good idea because it's not
114 | # numerically stable. However, in this test we have well-behaved values.
115 | ground_truth_v = index_with_mask(
116 | np.log(_softmax(policy_logits)), action_index_mask)
117 |
118 | assert_allclose(ground_truth_v, action_log_probs_tensor)
119 |
120 | def test_action_log_probs_batch_1(self):
121 | self.test_action_log_probs(1)
122 |
123 |
124 | class VtraceTest(unittest.TestCase):
125 |
126 | def test_vtrace(self, batch_size=5):
127 | """Tests V-trace against ground truth data calculated in python."""
128 | seq_len = 5
129 |
130 | # Create log_rhos such that rho will span from near-zero to above the
131 | # clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5),
132 | # so that rho is in approx [0.08, 12.2).
133 | log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len)
134 | log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5).
135 | values = {
136 | 'log_rhos':
137 | log_rhos,
138 | # T, B where B_i: [0.9 / (i+1)] * T
139 | 'discounts':
140 | np.array(
141 | [[0.9 / (b + 1)
142 | for b in range(batch_size)]
143 | for _ in range(seq_len)],
144 | dtype=np.float32),
145 | 'rewards':
146 | _shaped_arange(seq_len, batch_size),
147 | 'values':
148 | _shaped_arange(seq_len, batch_size) / batch_size,
149 | 'bootstrap_value':
150 | _shaped_arange(batch_size) + 1.0,
151 | 'clip_rho_threshold':
152 | 3.7,
153 | 'clip_pg_rho_threshold':
154 | 2.2,
155 | }
156 |
157 | ground_truth = _ground_truth_calculation(**values)
158 |
159 | values = {key: torch.tensor(value) for key, value in values.items()}
160 | output = vtrace.from_importance_weights(**values)
161 |
162 | for a, b in zip(ground_truth, output):
163 | assert_allclose(a, b)
164 |
165 | def test_vtrace_batch_1(self):
166 | self.test_vtrace(1)
167 |
168 | def test_vtrace_from_logits(self, batch_size=2):
169 | """Tests V-trace calculated from logits."""
170 | seq_len = 5
171 | num_actions = 3
172 | clip_rho_threshold = None # No clipping.
173 | clip_pg_rho_threshold = None # No clipping.
174 |
175 | values = {
176 | 'behavior_policy_logits':
177 | _shaped_arange(seq_len, batch_size, num_actions),
178 | 'target_policy_logits':
179 | _shaped_arange(seq_len, batch_size, num_actions),
180 | 'actions':
181 | np.random.randint(0, num_actions - 1, size=(seq_len, batch_size)),
182 | 'discounts':
183 | np.array( # T, B where B_i: [0.9 / (i+1)] * T
184 | [[0.9 / (b + 1)
185 | for b in range(batch_size)]
186 | for _ in range(seq_len)],
187 | dtype=np.float32),
188 | 'rewards':
189 | _shaped_arange(seq_len, batch_size),
190 | 'values':
191 | _shaped_arange(seq_len, batch_size) / batch_size,
192 | 'bootstrap_value':
193 | _shaped_arange(batch_size) + 1.0, # B
194 | }
195 | values = {k: torch.from_numpy(v) for k, v in values.items()}
196 |
197 | from_logits_output = vtrace.from_logits(
198 | clip_rho_threshold=clip_rho_threshold,
199 | clip_pg_rho_threshold=clip_pg_rho_threshold,
200 | **values)
201 |
202 | target_log_probs = vtrace.action_log_probs(
203 | values['target_policy_logits'], values['actions'])
204 | behavior_log_probs = vtrace.action_log_probs(
205 | values['behavior_policy_logits'], values['actions'])
206 | log_rhos = target_log_probs - behavior_log_probs
207 |
208 | # Calculate V-trace using the ground truth logits.
209 | from_iw = vtrace.from_importance_weights(
210 | log_rhos=log_rhos,
211 | discounts=values['discounts'],
212 | rewards=values['rewards'],
213 | values=values['values'],
214 | bootstrap_value=values['bootstrap_value'],
215 | clip_rho_threshold=clip_rho_threshold,
216 | clip_pg_rho_threshold=clip_pg_rho_threshold)
217 |
218 | assert_allclose(from_iw.vs, from_logits_output.vs)
219 | assert_allclose(from_iw.pg_advantages, from_logits_output.pg_advantages)
220 | assert_allclose(behavior_log_probs,
221 | from_logits_output.behavior_action_log_probs)
222 | assert_allclose(target_log_probs,
223 | from_logits_output.target_action_log_probs)
224 | assert_allclose(log_rhos, from_logits_output.log_rhos)
225 |
226 | def test_vtrace_from_logits_batch_1(self):
227 | self.test_vtrace_from_logits(1)
228 |
229 | def test_higher_rank_inputs_for_importance_weights(self):
230 | """Checks support for additional dimensions in inputs."""
231 | T = 3 # pylint: disable=invalid-name
232 | B = 2 # pylint: disable=invalid-name
233 | values = {
234 | 'log_rhos': torch.zeros(T, B, 1),
235 | 'discounts': torch.zeros(T, B, 1),
236 | 'rewards': torch.zeros(T, B, 42),
237 | 'values': torch.zeros(T, B, 42),
238 | 'bootstrap_value': torch.zeros(B, 42),
239 | }
240 | output = vtrace.from_importance_weights(**values)
241 | self.assertSequenceEqual(output.vs.shape, (T, B, 42))
242 |
243 | def test_inconsistent_rank_inputs_for_importance_weights(self):
244 | """Test one of many possible errors in shape of inputs."""
245 | T = 3 # pylint: disable=invalid-name
246 | B = 2 # pylint: disable=invalid-name
247 |
248 | values = {
249 | 'log_rhos': torch.zeros(T, B, 1),
250 | 'discounts': torch.zeros(T, B, 1),
251 | 'rewards': torch.zeros(T, B, 42),
252 | 'values': torch.zeros(T, B, 42),
253 | # Should be [B, 42].
254 | 'bootstrap_value': torch.zeros(B),
255 | }
256 |
257 | with self.assertRaisesRegex(RuntimeError,
258 | 'same number of dimensions: got 3 and 2'):
259 | vtrace.from_importance_weights(**values)
260 |
261 |
262 | if __name__ == '__main__':
263 | unittest.main()
264 |
--------------------------------------------------------------------------------
/src/algos/count.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import threading
10 | import time
11 | import timeit
12 | import pprint
13 |
14 | import numpy as np
15 |
16 | import torch
17 | from torch import multiprocessing as mp
18 | from torch import nn
19 | from torch.nn import functional as F
20 |
21 | from src.core import file_writer
22 | from src.core import prof
23 | from src.core import vtrace
24 |
25 | import src.models as models
26 | import src.losses as losses
27 |
28 | from src.env_utils import FrameStack
29 | from src.utils import get_batch, log, create_env, create_buffers, act
30 |
31 | MinigridPolicyNet = models.MinigridPolicyNet
32 |
33 | def learn(actor_model,
34 | model,
35 | batch,
36 | initial_agent_state,
37 | optimizer,
38 | scheduler,
39 | flags,
40 | lock=threading.Lock()):
41 | """Performs a learning (optimization) step."""
42 | with lock:
43 | intrinsic_rewards = torch.ones((flags.unroll_length, flags.batch_size),
44 | dtype=torch.float32).to(device=flags.device)
45 |
46 | intrinsic_rewards = batch['train_state_count'][1:].float().to(device=flags.device)
47 |
48 | intrinsic_reward_coef = flags.intrinsic_reward_coef
49 | intrinsic_rewards *= intrinsic_reward_coef
50 |
51 | learner_outputs, unused_state = model(batch, initial_agent_state)
52 |
53 | bootstrap_value = learner_outputs['baseline'][-1]
54 |
55 | batch = {key: tensor[1:] for key, tensor in batch.items()}
56 | learner_outputs = {
57 | key: tensor[:-1]
58 | for key, tensor in learner_outputs.items()
59 | }
60 |
61 | rewards = batch['reward']
62 | if flags.no_reward:
63 | total_rewards = intrinsic_rewards
64 | else:
65 | total_rewards = rewards + intrinsic_rewards
66 | clipped_rewards = torch.clamp(total_rewards, -1, 1)
67 |
68 | discounts = (~batch['done']).float() * flags.discounting
69 |
70 | vtrace_returns = vtrace.from_logits(
71 | behavior_policy_logits=batch['policy_logits'],
72 | target_policy_logits=learner_outputs['policy_logits'],
73 | actions=batch['action'],
74 | discounts=discounts,
75 | rewards=clipped_rewards,
76 | values=learner_outputs['baseline'],
77 | bootstrap_value=bootstrap_value)
78 |
79 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'],
80 | batch['action'],
81 | vtrace_returns.pg_advantages)
82 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
83 | vtrace_returns.vs - learner_outputs['baseline'])
84 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
85 | learner_outputs['policy_logits'])
86 |
87 | total_loss = pg_loss + baseline_loss + entropy_loss
88 |
89 | episode_returns = batch['episode_return'][batch['done']]
90 | stats = {
91 | 'mean_episode_return': torch.mean(episode_returns).item(),
92 | 'total_loss': total_loss.item(),
93 | 'pg_loss': pg_loss.item(),
94 | 'baseline_loss': baseline_loss.item(),
95 | 'entropy_loss': entropy_loss.item(),
96 | 'mean_rewards': torch.mean(rewards).item(),
97 | 'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
98 | 'mean_total_rewards': torch.mean(total_rewards).item(),
99 | }
100 |
101 | scheduler.step()
102 | optimizer.zero_grad()
103 | total_loss.backward()
104 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
105 | optimizer.step()
106 |
107 | actor_model.load_state_dict(model.state_dict())
108 | return stats
109 |
110 |
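# Count-based bonus as a standalone sketch (the training loop above does not
# call this; the actual counting happens in the actor loop in src/utils.py and
# is read back here through batch['train_state_count']):
#     r_int(s) = intrinsic_reward_coef / sqrt(N(s)),
# where N(s) is the number of times state s has been visited during training.
def count_bonus(state_key, state_count_dict, coef=1.0):
    """Hypothetical helper: bump the visit count for state_key and return the bonus."""
    state_count_dict[state_key] = state_count_dict.get(state_key, 0) + 1
    return coef / np.sqrt(state_count_dict[state_key])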
111 | def train(flags):
112 | if flags.xpid is None:
113 | flags.xpid = 'count-%s' % time.strftime('%Y%m%d-%H%M%S')
114 | plogger = file_writer.FileWriter(
115 | xpid=flags.xpid,
116 | xp_args=flags.__dict__,
117 | rootdir=flags.savedir,
118 | )
119 | checkpointpath = os.path.expandvars(
120 | os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid,
121 | 'model.tar')))
122 |
123 | T = flags.unroll_length
124 | B = flags.batch_size
125 |
126 | flags.device = None
127 | if not flags.disable_cuda and torch.cuda.is_available():
128 | log.info('Using CUDA.')
129 | flags.device = torch.device('cuda')
130 | else:
131 | log.info('Not using CUDA.')
132 | flags.device = torch.device('cpu')
133 |
134 | env = create_env(flags)
135 | if flags.num_input_frames > 1:
136 | env = FrameStack(env, flags.num_input_frames)
137 |
138 | if 'MiniGrid' in flags.env:
139 | if flags.use_fullobs_policy:
140 |             model = models.FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)
141 | else:
142 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)
143 | else:
144 |         model = models.MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)
145 |
146 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags)
147 |
148 | model.share_memory()
149 |
150 | initial_agent_state_buffers = []
151 | for _ in range(flags.num_buffers):
152 | state = model.initial_state(batch_size=1)
153 | for t in state:
154 | t.share_memory_()
155 | initial_agent_state_buffers.append(state)
156 |
157 | actor_processes = []
158 | ctx = mp.get_context('fork')
159 | free_queue = ctx.SimpleQueue()
160 | full_queue = ctx.SimpleQueue()
161 |
162 | episode_state_count_dict = dict()
163 | train_state_count_dict = dict()
164 | for i in range(flags.num_actors):
165 | actor = ctx.Process(
166 | target=act,
167 | args=(i, free_queue, full_queue, model, buffers,
168 | episode_state_count_dict, train_state_count_dict,
169 | initial_agent_state_buffers, flags))
170 | actor.start()
171 | actor_processes.append(actor)
172 |
173 | if 'MiniGrid' in flags.env:
174 | if flags.use_fullobs_policy:
175 |             learner_model = models.FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
176 | .to(device=flags.device)
177 | else:
178 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
179 | .to(device=flags.device)
180 | else:
181 |         learner_model = models.MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)\
182 | .to(device=flags.device)
183 |
184 | optimizer = torch.optim.RMSprop(
185 | learner_model.parameters(),
186 | lr=flags.learning_rate,
187 | momentum=flags.momentum,
188 | eps=flags.epsilon,
189 | alpha=flags.alpha)
190 |
191 |
192 | def lr_lambda(epoch):
193 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames
194 |
195 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
196 |
197 | logger = logging.getLogger('logfile')
198 | stat_keys = [
199 | 'total_loss',
200 | 'mean_episode_return',
201 | 'pg_loss',
202 | 'baseline_loss',
203 | 'entropy_loss',
204 | 'mean_rewards',
205 | 'mean_intrinsic_rewards',
206 | 'mean_total_rewards',
207 | ]
208 |
209 | logger.info('# Step\t%s', '\t'.join(stat_keys))
210 | frames, stats = 0, {}
211 |
212 |
213 | def batch_and_learn(i, lock=threading.Lock()):
214 | """Thread target for the learning process."""
215 | nonlocal frames, stats
216 | timings = prof.Timings()
217 | while frames < flags.total_frames:
218 | timings.reset()
219 | batch, agent_state = get_batch(free_queue, full_queue, buffers,
220 | initial_agent_state_buffers, flags, timings)
221 | stats = learn(model, learner_model, batch, agent_state,
222 | optimizer, scheduler, flags)
223 | timings.time('learn')
224 | with lock:
225 | to_log = dict(frames=frames)
226 | to_log.update({k: stats[k] for k in stat_keys})
227 | plogger.log(to_log)
228 | frames += T * B
229 |
230 | if i == 0:
231 | log.info('Batch and learn: %s', timings.summary())
232 |
233 | for m in range(flags.num_buffers):
234 | free_queue.put(m)
235 |
236 | threads = []
237 | for i in range(flags.num_threads):
238 | thread = threading.Thread(
239 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,))
240 | thread.start()
241 | threads.append(thread)
242 |
243 | def checkpoint(frames):
244 | if flags.disable_checkpoint:
245 | return
246 | checkpointpath = os.path.expandvars(os.path.expanduser(
247 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model.tar')))
248 | log.info('Saving checkpoint to %s', checkpointpath)
249 | torch.save({
250 | 'model_state_dict': model.state_dict(),
251 | 'optimizer_state_dict': optimizer.state_dict(),
252 | 'scheduler_state_dict': scheduler.state_dict(),
253 | 'flags': vars(flags),
254 | }, checkpointpath)
255 |
256 | timer = timeit.default_timer
257 | try:
258 | last_checkpoint_time = timer()
259 | while frames < flags.total_frames:
260 | start_frames = frames
261 | start_time = timer()
262 | time.sleep(5)
263 |
264 | if timer() - last_checkpoint_time > flags.save_interval * 60:
265 | checkpoint(frames)
266 | last_checkpoint_time = timer()
267 |
268 | fps = (frames - start_frames) / (timer() - start_time)
269 |             if stats.get('mean_episode_return') is not None and not np.isnan(stats['mean_episode_return']):
270 | mean_return = 'Return per episode: %.1f. ' % stats[
271 | 'mean_episode_return']
272 | else:
273 | mean_return = ''
274 | total_loss = stats.get('total_loss', float('inf'))
275 | log.info('After %i frames: loss %f @ %.1f fps. %sStats:\n%s',
276 | frames, total_loss, fps, mean_return,
277 | pprint.pformat(stats))
278 |
279 | except KeyboardInterrupt:
280 | return
281 | else:
282 | for thread in threads:
283 | thread.join()
284 | log.info('Learning finished after %d frames.', frames)
285 |
286 | finally:
287 | for _ in range(flags.num_actors):
288 | free_queue.put(None)
289 | for actor in actor_processes:
290 | actor.join(timeout=1)
291 | checkpoint(frames)
292 | plogger.close()
293 |
294 |
--------------------------------------------------------------------------------
/src/atari_wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # This code was taken from
8 | # https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
9 | # and modified.
10 |
11 | from collections import deque, defaultdict
12 |
13 | import numpy as np
14 |
15 | import gym
16 | from gym import spaces
17 | import cv2
18 | cv2.ocl.setUseOpenCL(False)
19 |
20 |
21 | class NoNegativeRewardEnv(gym.RewardWrapper):
22 | """Clip reward in negative direction."""
23 | def __init__(self, env=None, neg_clip=0.0):
24 | super(NoNegativeRewardEnv, self).__init__(env)
25 | self.neg_clip = neg_clip
26 |
27 |     def reward(self, reward):
28 | new_reward = self.neg_clip if reward < self.neg_clip else reward
29 | return new_reward
30 |
31 |
32 | class NoopResetEnv(gym.Wrapper):
33 |
34 | def __init__(self, env, noop_max=30):
35 | """Sample initial states by taking random number of no-ops on reset.
36 | No-op is assumed to be action 0.
37 | """
38 | gym.Wrapper.__init__(self, env)
39 | self.noop_max = noop_max
40 | self.override_num_noops = None
41 | self.noop_action = 0
42 |
43 | def reset(self, **kwargs):
44 | """ Do no-op action for a number of steps in [1, noop_max]."""
45 | self.env.reset(**kwargs)
46 | if self.override_num_noops is not None:
47 | noops = self.override_num_noops
48 | else:
49 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
50 | assert noops > 0
51 | obs = None
52 | for _ in range(noops):
53 | obs, _, done, _ = self.env.step(self.noop_action)
54 | if done:
55 | obs = self.env.reset(**kwargs)
56 | return obs
57 |
58 | def step(self, ac):
59 | return self.env.step(ac)
60 |
61 |
62 | class FireResetEnv(gym.Wrapper):
63 |
64 | def __init__(self, env):
65 | """Take action on reset for environments that are fixed until firing."""
66 | gym.Wrapper.__init__(self, env)
67 |
68 | def reset(self, **kwargs):
69 | self.env.reset(**kwargs)
70 | obs, _, done, _ = self.env.step(1)
71 | if done:
72 | self.env.reset(**kwargs)
73 | obs, _, done, _ = self.env.step(2)
74 | if done:
75 | self.env.reset(**kwargs)
76 | return obs
77 |
78 | def step(self, ac):
79 | return self.env.step(ac)
80 |
81 |
82 | class EpisodicLifeEnv(gym.Wrapper):
83 |
84 | def __init__(self, env):
85 | """Make end-of-life == end-of-episode, but only reset on true game over.
86 | Done by DeepMind for the DQN and co. since it helps value estimation.
87 | """
88 | gym.Wrapper.__init__(self, env)
89 | self.lives = 0
90 | self.was_real_done = True
91 |
92 | def step(self, action):
93 | obs, reward, done, info = self.env.step(action)
94 | self.was_real_done = done
95 | # check current lives, make loss of life terminal,
96 | # then update lives to handle bonus lives
97 | lives = self.env.unwrapped.ale.lives()
98 | if lives < self.lives and lives > 0:
99 | # for Qbert sometimes we stay in lives == 0 condition for a few frames
100 | # so it's important to keep lives > 0, so that we only reset once
101 | # the environment advertises done.
102 | done = True
103 | self.lives = lives
104 | return obs, reward, done, info
105 |
106 | def reset(self, **kwargs):
107 | """Reset only when lives are exhausted.
108 | This way all states are still reachable even though lives are episodic,
109 | and the learner need not know about any of this behind-the-scenes.
110 | """
111 | if self.was_real_done:
112 | obs = self.env.reset(**kwargs)
113 | else:
114 | # no-op step to advance from terminal/lost life state
115 | obs, _, _, _ = self.env.step(0)
116 | self.lives = self.env.unwrapped.ale.lives()
117 | return obs
118 |
119 |
120 | class MaxAndSkipEnv(gym.Wrapper):
121 |
122 | def __init__(self, env, skip=4):
123 | """Return only every `skip`-th frame"""
124 | gym.Wrapper.__init__(self, env)
125 | # most recent raw observations (for max pooling across time steps)
126 | self._obs_buffer = np.zeros(
127 | (2,) + env.observation_space.shape, dtype=np.uint8)
128 | self._skip = skip
129 |
130 | def step(self, action):
131 | """Repeat action, sum reward, and max over last observations."""
132 | total_reward = 0.0
133 | done = None
134 | for i in range(self._skip):
135 | obs, reward, done, info = self.env.step(action)
136 | if i == self._skip - 2: self._obs_buffer[0] = obs
137 | if i == self._skip - 1: self._obs_buffer[1] = obs
138 | total_reward += reward
139 | if done:
140 | break
141 | # Note that the observation on the done=True frame
142 | # doesn't matter
143 | max_frame = self._obs_buffer.max(axis=0)
144 |
145 | return max_frame, total_reward, done, info
146 |
147 | def reset(self, **kwargs):
148 | return self.env.reset(**kwargs)
149 |
150 |
151 | class ClipRewardEnv(gym.RewardWrapper):
152 |
153 | def __init__(self, env):
154 | gym.RewardWrapper.__init__(self, env)
155 |
156 | def reward(self, reward):
157 | """Bin reward to {+1, 0, -1} by its sign."""
158 | return np.sign(reward)
159 |
160 |
161 | class WarpFrame(gym.ObservationWrapper):
162 |
163 | def __init__(self, env, width=84, height=84, grayscale=True):
164 | """Warp frames to 84x84 as done in the Nature paper and later work."""
165 | gym.ObservationWrapper.__init__(self, env)
166 | self.width = width
167 | self.height = height
168 | self.grayscale = grayscale
169 | if self.grayscale:
170 | self.observation_space = spaces.Box(
171 | low=0,
172 | high=255,
173 | shape=(self.height, self.width, 1),
174 | dtype=np.uint8)
175 | else:
176 | self.observation_space = spaces.Box(
177 | low=0,
178 | high=255,
179 | shape=(self.height, self.width, 3),
180 | dtype=np.uint8)
181 |
182 | def observation(self, frame):
183 | if self.grayscale:
184 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
185 | frame = cv2.resize(
186 | frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
187 | if self.grayscale:
188 | frame = np.expand_dims(frame, -1)
189 | return frame
190 |
191 |
192 | class FrameStack(gym.Wrapper):
193 |
194 | def __init__(self, env, k):
195 | """Stack k last frames.
196 |
197 | Returns lazy array, which is much more memory efficient.
198 |
199 | See Also
200 | --------
201 | baselines.common.atari_wrappers.LazyFrames
202 | """
203 | gym.Wrapper.__init__(self, env)
204 | self.k = k
205 | self.frames = deque([], maxlen=k)
206 | shp = env.observation_space.shape
207 | self.observation_space = spaces.Box(
208 | low=0,
209 | high=255,
210 | shape=(shp[:-1] + (shp[-1] * k,)),
211 | dtype=env.observation_space.dtype)
212 |
213 | def reset(self):
214 | ob = self.env.reset()
215 | for _ in range(self.k):
216 | self.frames.append(ob)
217 | return self._get_ob()
218 |
219 | def step(self, action):
220 | ob, reward, done, info = self.env.step(action)
221 | self.frames.append(ob)
222 | return self._get_ob(), reward, done, info
223 |
224 | def _get_ob(self):
225 | assert len(self.frames) == self.k
226 | return LazyFrames(list(self.frames))
227 |
228 |
229 | class ScaledFloatFrame(gym.ObservationWrapper):
230 |
231 | def __init__(self, env):
232 | gym.ObservationWrapper.__init__(self, env)
233 | self.observation_space = gym.spaces.Box(
234 | low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)
235 |
236 | def observation(self, observation):
237 | # careful! This undoes the memory optimization, use
238 | # with smaller replay buffers only.
239 | return np.array(observation).astype(np.float32) / 255.0
240 |
241 |
242 | class LazyFrames(object):
243 |
244 | def __init__(self, frames):
245 | """This object ensures that common frames between the observations are only stored once.
246 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
247 | buffers.
248 |
249 | This object should only be converted to numpy array before being passed to the model.
250 |
251 | You'd not believe how complex the previous solution was."""
252 | self._frames = frames
253 | self._out = None
254 |
255 | def _force(self):
256 | if self._out is None:
257 | self._out = np.concatenate(self._frames, axis=-1)
258 | self._frames = None
259 | return self._out
260 |
261 | def __array__(self, dtype=None):
262 | out = self._force()
263 | if dtype is not None:
264 | out = out.astype(dtype)
265 | return out
266 |
267 | def __len__(self):
268 | return len(self._force())
269 |
270 | def __getitem__(self, i):
271 | return self._force()[i]
272 |
273 |
274 | def make_atari(env_id, timelimit=True, noop=False):
275 | # XXX(john): remove timelimit argument after gym is upgraded to allow double wrapping
276 | env = gym.make(env_id)
277 | if not timelimit:
278 | env = env.env
279 |
280 | # NOTE: this is changed from original atari implementation
281 | # assert 'NoFrameskip' in env.spec.id
282 | if noop:
283 | env = NoopResetEnv(env, noop_max=30)
284 | env = MaxAndSkipEnv(env, skip=6)
285 | return env
286 |
287 |
288 | # NOTE: this was changed so that episode_life is False by default
289 | def wrap_deepmind(env,
290 | episode_life=False,
291 | clip_rewards=True,
292 | frame_stack=False,
293 | scale=True,
294 | fire=False): # FYI scale=False in openai/baselines
295 | """Configure environment for DeepMind-style Atari.
296 | """
297 | if episode_life:
298 | env = EpisodicLifeEnv(env)
299 | if fire:
300 | if 'FIRE' in env.unwrapped.get_action_meanings():
301 | env = FireResetEnv(env)
302 | env = WarpFrame(env, width=42, height=42)
303 | if scale:
304 | env = ScaledFloatFrame(env)
305 | if clip_rewards:
306 | env = ClipRewardEnv(env)
307 | if frame_stack:
308 | env = FrameStack(env, 4)
309 | return env
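
# Typical composition (mirrors create_env in src/utils.py); the env id below is
# only an example and assumes gym-super-mario-bros is installed:
#     env = make_atari('SuperMarioBros-1-1-v0', noop=True)
#     env = wrap_deepmind(env, clip_rewards=False, frame_stack=True,
#                         scale=False, fire=True)
#     env = wrap_pytorch(env)  # HWC uint8 frames -> CHW tensors for PyTorch models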
310 |
311 | _game_envs = defaultdict(set)  # env ids grouped by type; populated in get_env_type()
312 | # Taken from https://github.com/openai/baselines/blob/master/baselines/run.py
313 | def get_env_type(env_id):
314 | # Re-parse the gym registry, since we could have new envs since last time.
315 | for env in gym.envs.registry.all():
316 | env_type = env._entry_point.split(':')[0].split('.')[-1]
317 | _game_envs[env_type].add(env.id) # This is a set so add is idempotent
318 |
319 | if env_id in _game_envs.keys():
320 | env_type = env_id
321 | env_id = [g for g in _game_envs[env_type]][0]
322 | else:
323 | env_type = None
324 | for g, e in _game_envs.items():
325 | if env_id in e:
326 | env_type = g
327 | break
328 |     assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(
329 |         env_id, _game_envs.keys())
330 |
331 | return env_type, env_id
332 |
333 |
334 | class ImageToPyTorch(gym.ObservationWrapper):
335 | """
336 |     Image shape to channels x width x height
337 | """
338 |
339 | def __init__(self, env):
340 | super(ImageToPyTorch, self).__init__(env)
341 | old_shape = self.observation_space.shape
342 | self.observation_space = gym.spaces.Box(
343 | low=0.0,
344 | high=1.0,
345 | shape=(old_shape[-1], old_shape[0], old_shape[1]),
346 | dtype=np.uint8)
347 |
348 | def observation(self, observation):
349 | return np.swapaxes(observation, 2, 0)
350 |
351 |
352 | def wrap_pytorch(env):
353 | return ImageToPyTorch(env)
354 |
355 |
--------------------------------------------------------------------------------
/src/algos/rnd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import sys
10 | import threading
11 | import time
12 | import timeit
13 | import pprint
14 |
15 | import numpy as np
16 |
17 | import torch
18 | from torch import multiprocessing as mp
19 | from torch import nn
20 | from torch.nn import functional as F
21 |
22 | from src.core import file_writer
23 | from src.core import prof
24 | from src.core import vtrace
25 |
26 | import src.models as models
27 | import src.losses as losses
28 |
29 | from src.env_utils import FrameStack
30 | from src.utils import get_batch, log, create_env, create_buffers, act
31 |
32 | MinigridPolicyNet = models.MinigridPolicyNet
33 | MinigridStateEmbeddingNet = models.MinigridStateEmbeddingNet
34 |
35 | def learn(actor_model,
36 | model,
37 | random_target_network,
38 | predictor_network,
39 | batch,
40 | initial_agent_state,
41 | optimizer,
42 | predictor_optimizer,
43 | scheduler,
44 | flags,
45 | frames=None,
46 | lock=threading.Lock()):
47 | """Performs a learning (optimization) step."""
48 | with lock:
49 | if flags.use_fullobs_intrinsic:
50 | random_embedding = random_target_network(batch, next_state=True)\
51 | .reshape(flags.unroll_length, flags.batch_size, 128)
52 | predicted_embedding = predictor_network(batch, next_state=True)\
53 | .reshape(flags.unroll_length, flags.batch_size, 128)
54 | else:
55 | random_embedding = random_target_network(batch['partial_obs'][1:].to(device=flags.device))
56 | predicted_embedding = predictor_network(batch['partial_obs'][1:].to(device=flags.device))
57 |
58 | intrinsic_rewards = torch.norm(predicted_embedding.detach() - random_embedding.detach(), dim=2, p=2)
59 |
60 | intrinsic_reward_coef = flags.intrinsic_reward_coef
61 | intrinsic_rewards *= intrinsic_reward_coef
62 |
63 | num_samples = flags.unroll_length * flags.batch_size
64 | actions_flat = batch['action'][1:].reshape(num_samples).cpu().detach().numpy()
65 | intrinsic_rewards_flat = intrinsic_rewards.reshape(num_samples).cpu().detach().numpy()
66 |
67 | rnd_loss = flags.rnd_loss_coef * \
68 | losses.compute_forward_dynamics_loss(predicted_embedding, random_embedding.detach())
69 |
70 | learner_outputs, unused_state = model(batch, initial_agent_state)
71 |
72 | bootstrap_value = learner_outputs['baseline'][-1]
73 |
74 | batch = {key: tensor[1:] for key, tensor in batch.items()}
75 | learner_outputs = {
76 | key: tensor[:-1]
77 | for key, tensor in learner_outputs.items()
78 | }
79 |
80 | rewards = batch['reward']
81 |
82 | if flags.no_reward:
83 | total_rewards = intrinsic_rewards
84 | else:
85 | total_rewards = rewards + intrinsic_rewards
86 | clipped_rewards = torch.clamp(total_rewards, -1, 1)
87 |
88 | discounts = (~batch['done']).float() * flags.discounting
89 |
90 | vtrace_returns = vtrace.from_logits(
91 | behavior_policy_logits=batch['policy_logits'],
92 | target_policy_logits=learner_outputs['policy_logits'],
93 | actions=batch['action'],
94 | discounts=discounts,
95 | rewards=clipped_rewards,
96 | values=learner_outputs['baseline'],
97 | bootstrap_value=bootstrap_value)
98 |
99 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'],
100 | batch['action'],
101 | vtrace_returns.pg_advantages)
102 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
103 | vtrace_returns.vs - learner_outputs['baseline'])
104 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
105 | learner_outputs['policy_logits'])
106 |
107 | total_loss = pg_loss + baseline_loss + entropy_loss + rnd_loss
108 |
109 | episode_returns = batch['episode_return'][batch['done']]
110 | stats = {
111 | 'mean_episode_return': torch.mean(episode_returns).item(),
112 | 'total_loss': total_loss.item(),
113 | 'pg_loss': pg_loss.item(),
114 | 'baseline_loss': baseline_loss.item(),
115 | 'entropy_loss': entropy_loss.item(),
116 | 'rnd_loss': rnd_loss.item(),
117 | 'mean_rewards': torch.mean(rewards).item(),
118 | 'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
119 | 'mean_total_rewards': torch.mean(total_rewards).item(),
120 | }
121 |
122 | scheduler.step()
123 | optimizer.zero_grad()
124 | predictor_optimizer.zero_grad()
125 | total_loss.backward()
126 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
127 | nn.utils.clip_grad_norm_(predictor_network.parameters(), flags.max_grad_norm)
128 | optimizer.step()
129 | predictor_optimizer.step()
130 |
131 | actor_model.load_state_dict(model.state_dict())
132 | return stats
133 |
134 |
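# RND bonus as a standalone sketch (learn() above computes it inline): the
# intrinsic reward is the L2 distance between the embeddings of a fixed,
# randomly initialised target network and a trained predictor network,
#     r_int(s) = intrinsic_reward_coef * || f_predictor(s) - f_target(s) ||_2.
# States the predictor has already fit well (i.e. frequently visited ones)
# therefore receive a small bonus.
def rnd_bonus(predicted_embedding, random_embedding, coef=1.0):
    """Hypothetical helper; both embeddings have shape (T, B, embedding_dim)."""
    return coef * torch.norm(predicted_embedding.detach() - random_embedding.detach(),
                             dim=2, p=2)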
135 | def train(flags):
136 | if flags.xpid is None:
137 | flags.xpid = 'rnd-%s' % time.strftime('%Y%m%d-%H%M%S')
138 | plogger = file_writer.FileWriter(
139 | xpid=flags.xpid,
140 | xp_args=flags.__dict__,
141 | rootdir=flags.savedir,
142 | )
143 |
144 | checkpointpath = os.path.expandvars(
145 | os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid,
146 | 'model.tar')))
147 |
148 | T = flags.unroll_length
149 | B = flags.batch_size
150 |
151 | flags.device = None
152 | if not flags.disable_cuda and torch.cuda.is_available():
153 | log.info('Using CUDA.')
154 | flags.device = torch.device('cuda')
155 | else:
156 | log.info('Not using CUDA.')
157 | flags.device = torch.device('cpu')
158 |
159 | env = create_env(flags)
160 | if flags.num_input_frames > 1:
161 | env = FrameStack(env, flags.num_input_frames)
162 |
163 | if 'MiniGrid' in flags.env:
164 | if flags.use_fullobs_policy:
165 |             model = models.FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)
166 | else:
167 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)
168 | if flags.use_fullobs_intrinsic:
169 |             random_target_network = models.FullObsMinigridStateEmbeddingNet(env.observation_space.shape).to(device=flags.device)
170 |             predictor_network = models.FullObsMinigridStateEmbeddingNet(env.observation_space.shape).to(device=flags.device)
171 | else:
172 | random_target_network = MinigridStateEmbeddingNet(env.observation_space.shape).to(device=flags.device)
173 | predictor_network = MinigridStateEmbeddingNet(env.observation_space.shape).to(device=flags.device)
174 | else:
175 |         model = models.MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)
176 |         random_target_network = models.MarioDoomStateEmbeddingNet(env.observation_space.shape).to(device=flags.device)
177 |         predictor_network = models.MarioDoomStateEmbeddingNet(env.observation_space.shape).to(device=flags.device)
178 |
179 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags)
180 |
181 | model.share_memory()
182 |
183 | initial_agent_state_buffers = []
184 | for _ in range(flags.num_buffers):
185 | state = model.initial_state(batch_size=1)
186 | for t in state:
187 | t.share_memory_()
188 | initial_agent_state_buffers.append(state)
189 |
190 | actor_processes = []
191 | ctx = mp.get_context('fork')
192 | free_queue = ctx.SimpleQueue()
193 | full_queue = ctx.SimpleQueue()
194 |
195 | episode_state_count_dict = dict()
196 | train_state_count_dict = dict()
197 | for i in range(flags.num_actors):
198 | actor = ctx.Process(
199 | target=act,
200 | args=(i, free_queue, full_queue, model, buffers,
201 | episode_state_count_dict, train_state_count_dict,
202 | initial_agent_state_buffers, flags))
203 | actor.start()
204 | actor_processes.append(actor)
205 |
206 | if 'MiniGrid' in flags.env:
207 | if flags.use_fullobs_policy:
208 |             learner_model = models.FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
209 | .to(device=flags.device)
210 | else:
211 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
212 | .to(device=flags.device)
213 | else:
214 |         learner_model = models.MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)\
215 | .to(device=flags.device)
216 |
217 | optimizer = torch.optim.RMSprop(
218 | learner_model.parameters(),
219 | lr=flags.learning_rate,
220 | momentum=flags.momentum,
221 | eps=flags.epsilon,
222 | alpha=flags.alpha)
223 |
224 | predictor_optimizer = torch.optim.RMSprop(
225 | predictor_network.parameters(),
226 | lr=flags.learning_rate,
227 | momentum=flags.momentum,
228 | eps=flags.epsilon,
229 | alpha=flags.alpha)
230 |
231 |
232 | def lr_lambda(epoch):
233 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames
234 |
235 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
236 |
237 | logger = logging.getLogger('logfile')
238 | stat_keys = [
239 | 'total_loss',
240 | 'mean_episode_return',
241 | 'pg_loss',
242 | 'baseline_loss',
243 | 'entropy_loss',
244 | 'rnd_loss',
245 | 'mean_rewards',
246 | 'mean_intrinsic_rewards',
247 | 'mean_total_rewards',
248 | ]
249 |
250 | logger.info('# Step\t%s', '\t'.join(stat_keys))
251 |
252 | frames, stats = 0, {}
253 |
254 |
255 | def batch_and_learn(i, lock=threading.Lock()):
256 | """Thread target for the learning process."""
257 | nonlocal frames, stats
258 | timings = prof.Timings()
259 | while frames < flags.total_frames:
260 | timings.reset()
261 | batch, agent_state = get_batch(free_queue, full_queue, buffers,
262 | initial_agent_state_buffers, flags, timings)
263 | stats = learn(model, learner_model, random_target_network, predictor_network,
264 | batch, agent_state, optimizer, predictor_optimizer, scheduler,
265 | flags, frames=frames)
266 | timings.time('learn')
267 | with lock:
268 | to_log = dict(frames=frames)
269 | to_log.update({k: stats[k] for k in stat_keys})
270 | plogger.log(to_log)
271 | frames += T * B
272 |
273 | if i == 0:
274 | log.info('Batch and learn: %s', timings.summary())
275 |
276 | for m in range(flags.num_buffers):
277 | free_queue.put(m)
278 |
279 | threads = []
280 | for i in range(flags.num_threads):
281 | thread = threading.Thread(
282 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,))
283 | thread.start()
284 | threads.append(thread)
285 |
286 |
287 | def checkpoint(frames):
288 | if flags.disable_checkpoint:
289 | return
290 | checkpointpath = os.path.expandvars(os.path.expanduser(
291 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model.tar')))
292 | log.info('Saving checkpoint to %s', checkpointpath)
293 | torch.save({
294 | 'model_state_dict': model.state_dict(),
295 | 'random_target_network_state_dict': random_target_network.state_dict(),
296 | 'predictor_network_state_dict': predictor_network.state_dict(),
297 | 'optimizer_state_dict': optimizer.state_dict(),
298 | 'predictor_optimizer_state_dict': predictor_optimizer.state_dict(),
299 | 'scheduler_state_dict': scheduler.state_dict(),
300 | 'flags': vars(flags),
301 | }, checkpointpath)
302 |
303 | timer = timeit.default_timer
304 | try:
305 | last_checkpoint_time = timer()
306 | while frames < flags.total_frames:
307 | start_frames = frames
308 | start_time = timer()
309 | time.sleep(5)
310 |
311 | if timer() - last_checkpoint_time > flags.save_interval * 60:
312 | checkpoint(frames)
313 | last_checkpoint_time = timer()
314 |
315 | fps = (frames - start_frames) / (timer() - start_time)
316 |
317 | if stats.get('episode_returns', None):
318 | mean_return = 'Return per episode: %.1f. ' % stats[
319 | 'mean_episode_return']
320 | else:
321 | mean_return = ''
322 |
323 | total_loss = stats.get('total_loss', float('inf'))
324 | if stats:
325 | log.info('After %i frames: loss %f @ %.1f fps. Mean Return %.1f. \n Stats \n %s', \
326 | frames, total_loss, fps, stats['mean_episode_return'], pprint.pformat(stats))
327 |
328 | except KeyboardInterrupt:
329 | return
330 | else:
331 | for thread in threads:
332 | thread.join()
333 | log.info('Learning finished after %d frames.', frames)
334 | finally:
335 | for _ in range(flags.num_actors):
336 | free_queue.put(None)
337 | for actor in actor_processes:
338 | actor.join(timeout=1)
339 |
340 | checkpoint(frames)
341 | plogger.close()
342 |
343 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from __future__ import division
7 | import torch.nn as nn
8 | import torch
9 | import typing
10 | import gym
11 | import threading
12 | from torch import multiprocessing as mp
13 | import logging
14 | import traceback
15 | import os
16 | import numpy as np
17 | import copy
18 |
19 | from src.core import prof
20 | from src.env_utils import FrameStack, Environment, Minigrid2Image
21 | from src import atari_wrappers as atari_wrappers
22 |
23 | from gym_minigrid import wrappers as wrappers
24 |
25 | # from nes_py.wrappers import JoypadSpace
26 | # from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
27 |
28 | # import vizdoomgym
29 | OBJECT_TO_IDX = {
30 | 'unseen' : 0,
31 | 'empty' : 1,
32 | 'wall' : 2,
33 | 'floor' : 3,
34 | 'door' : 4,
35 | 'key' : 5,
36 | 'ball' : 6,
37 | 'box' : 7,
38 | 'goal' : 8,
39 | 'lava' : 9,
40 | 'agent' : 10,
41 | }
42 |
43 | # Placeholder for an augmentation based on a random walk of the agent (currently the identity)
44 | def augmentation(frames):
45 | # agent_loc = agent_loc(frames)
46 | return frames
47 |
48 | # Numerically stable entropy of a categorical distribution parameterized by logits
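# With a = logits - max(logits) and z = sum_i exp(a_i), the probabilities are
# p_i = exp(a_i) / z, so H = -sum_i p_i * log p_i = sum_i p_i * (log z - a_i);
# shifting by the max keeps exp() from overflowing for large logits.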
49 | def catentropy(logits):
50 | a = logits - torch.max(logits, dim=-1, keepdim=True)[0]
51 | e = torch.exp(a)
52 | z = torch.sum(e, dim=-1, keepdim=True)
53 | p = e / z
54 | entropy = torch.sum(p * (torch.log(z) - a), dim=-1)
55 | return torch.mean(entropy)
56 |
57 | # Mark grid cells containing task objects (door, key, ball, box, goal) in each frame
58 | def num_objects(frames):
59 | T, B, H, W, *_ = frames.shape
60 | num_objects = frames[:, :, :, :, 0]
61 | num_objects = (num_objects == 4).long() + (num_objects == 5).long() + \
62 | (num_objects == 6).long() + (num_objects == 7).long() + (num_objects == 8).long()
63 | return num_objects
64 |
65 | # Polyak (EMA) update: move target_net parameters towards net with rate tau
66 | def soft_update_params(net, target_net, tau):
67 | for param, target_param in zip(net.parameters(), target_net.parameters()):
68 | target_param.data.copy_(
69 | tau * param.data + (1 - tau) * target_param.data
70 | )
71 |
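# agent_loc returns the (row, col) grid coordinates of the agent tile (object
# id 10 in OBJECT_TO_IDX) for every frame in a (T, B, H, W, C) batch, flattened
# into one Python list; act() uses it as the key for the exploration heatmap
# buffers.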
72 | def agent_loc(frames):
73 | T, B, H, W, *_ = frames.shape
74 | agent_location = torch.flatten(frames, 2, 3)
75 | agent_location = agent_location[:,:,:,0]
76 | agent_location = (agent_location == 10).nonzero() #select object id
77 | agent_location = agent_location[:,2]
78 | agent_location = torch.cat(((agent_location//W).unsqueeze(-1), (agent_location%W).unsqueeze(-1)), dim=-1)
79 | agent_location = agent_location.view(-1).tolist()
80 | return agent_location
81 |
82 | COMPLETE_MOVEMENT = [
83 | ['NOOP'],
84 | ['up'],
85 | ['down'],
86 | ['left'],
87 | ['left', 'A'],
88 | ['left', 'B'],
89 | ['left', 'A', 'B'],
90 | ['right'],
91 | ['right', 'A'],
92 | ['right', 'B'],
93 | ['right', 'A', 'B'],
94 | ['A'],
95 | ['B'],
96 | ['A', 'B'],
97 | ]
98 |
99 | shandle = logging.StreamHandler()
100 | shandle.setFormatter(
101 | logging.Formatter(
102 | '[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] '
103 | '%(message)s'))
104 | log = logging.getLogger('torchbeast')
105 | log.propagate = False
106 | log.addHandler(shandle)
107 | log.setLevel(logging.INFO)
108 |
109 | Buffers = typing.Dict[str, typing.List[torch.Tensor]]
110 |
111 |
112 | def create_env(flags):
113 | if 'MiniGrid' in flags.env:
114 | return Minigrid2Image(wrappers.FullyObsWrapper(gym.make(flags.env)))
115 | elif 'Mario' in flags.env:
116 | env = atari_wrappers.wrap_pytorch(
117 | atari_wrappers.wrap_deepmind(
118 | atari_wrappers.make_atari(flags.env, noop=True),
119 | clip_rewards=False,
120 | frame_stack=True,
121 | scale=False,
122 | fire=True))
123 | env = JoypadSpace(env, COMPLETE_MOVEMENT)
124 | return env
125 | else:
126 | env = atari_wrappers.wrap_pytorch(
127 | atari_wrappers.wrap_deepmind(
128 | atari_wrappers.make_atari(flags.env, noop=False),
129 | clip_rewards=False,
130 | frame_stack=True,
131 | scale=False,
132 | fire=False))
133 | return env
134 |
135 |
136 | def get_batch(free_queue: mp.Queue,
137 | full_queue: mp.Queue,
138 | buffers: Buffers,
139 | initial_agent_state_buffers,
140 | initial_encoder_state_buffers,
141 | flags,
142 | timings,
143 | lock=threading.Lock()):
144 | with lock:
145 | timings.time('lock')
146 | indices = [full_queue.get() for _ in range(flags.batch_size)]
147 | timings.time('dequeue')
148 | batch = {
149 | key: torch.stack([buffers[key][m] for m in indices], dim=1)
150 | for key in buffers
151 | }
152 | initial_agent_state = (
153 | torch.cat(ts, dim=1)
154 | for ts in zip(*[initial_agent_state_buffers[m] for m in indices])
155 | )
156 | initial_encoder_state = (
157 | torch.cat(ts, dim=1)
158 | for ts in zip(*[initial_encoder_state_buffers[m] for m in indices])
159 | )
160 | timings.time('batch')
161 | for m in indices:
162 | free_queue.put(m)
163 | timings.time('enqueue')
164 | batch = {
165 | k: t.to(device=flags.device, non_blocking=True)
166 | for k, t in batch.items()
167 | }
168 | initial_agent_state = tuple(t.to(device=flags.device, non_blocking=True)
169 | for t in initial_agent_state)
170 | initial_encoder_state = tuple(t.to(device=flags.device, non_blocking=True)
171 | for t in initial_encoder_state)
172 | timings.time('device')
173 | return batch, initial_agent_state, initial_encoder_state
174 |
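# Buffer layout note: create_buffers allocates flags.num_buffers shared-memory
# rollout slots per key, each with leading dimension T + 1. Index 0 of a slot
# holds the step that ended the previous rollout (written in act() before a new
# unroll starts) and indices 1..T hold the T steps of the new unroll.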
175 | def create_buffers(obs_shape, num_actions, flags) -> Buffers:
176 | T = flags.unroll_length
177 | specs = dict(
178 | frame=dict(size=(T + 1, *obs_shape), dtype=torch.uint8),
179 | reward=dict(size=(T + 1,), dtype=torch.float32),
180 | done=dict(size=(T + 1,), dtype=torch.uint8),
181 | episode_return=dict(size=(T + 1,), dtype=torch.float32),
182 | episode_step=dict(size=(T + 1,), dtype=torch.int32),
183 | last_action=dict(size=(T + 1,), dtype=torch.int64),
184 | policy_logits=dict(size=(T + 1, num_actions), dtype=torch.float32),
185 | baseline=dict(size=(T + 1,), dtype=torch.float32),
186 | action=dict(size=(T + 1,), dtype=torch.int64),
187 | episode_win=dict(size=(T + 1,), dtype=torch.int32),
188 | carried_obj=dict(size=(T + 1,), dtype=torch.int32),
189 | carried_col=dict(size=(T + 1,), dtype=torch.int32),
190 | partial_obs=dict(size=(T + 1, 7, 7, 3), dtype=torch.uint8),
191 | episode_state_count=dict(size=(T + 1, ), dtype=torch.float32),
192 | train_state_count=dict(size=(T + 1, ), dtype=torch.float32),
193 | partial_state_count=dict(size=(T + 1, ), dtype=torch.float32),
194 | encoded_state_count=dict(size=(T + 1, ), dtype=torch.float32),
195 | )
196 | buffers: Buffers = {key: [] for key in specs}
197 | for _ in range(flags.num_buffers):
198 | for key in buffers:
199 | buffers[key].append(torch.empty(**specs[key]).share_memory_())
200 | return buffers
201 |
202 | def create_heatmap_buffers(obs_shape):
203 | specs = []
204 | for r in range(obs_shape[0]):
205 | for c in range(obs_shape[1]):
206 | specs.append(tuple([r, c]))
207 | buffers: Buffers = {key: torch.zeros(1).share_memory_() for key in specs}
208 | return buffers
209 |
210 | def act(i: int, free_queue: mp.Queue, full_queue: mp.Queue,
211 | model: torch.nn.Module,
212 | encoder: torch.nn.Module,
213 | buffers: Buffers,
214 | # inv_dynamics_buf: Buffers,
215 | # inv_dynamics_dict: dict,
216 | episode_state_count_dict: dict, train_state_count_dict: dict,
217 | partial_state_count_dict: dict, encoded_state_count_dict: dict,
218 | heatmap_dict: dict,
219 | heatmap_buffers: Buffers,
220 | initial_agent_state_buffers,
221 | initial_encoder_state_buffers,
222 | flags):
223 | try:
224 | log.info('Actor %i started.', i)
225 | timings = prof.Timings()
226 |
227 | gym_env = create_env(flags)
228 | seed = i ^ int.from_bytes(os.urandom(4), byteorder='little')
229 | gym_env.seed(seed)
230 |
231 | if flags.num_input_frames > 1:
232 | gym_env = FrameStack(gym_env, flags.num_input_frames)
233 |
234 | env = Environment(gym_env, fix_seed=flags.fix_seed, env_seed=flags.env_seed)
235 |
236 | env_output = env.initial()
237 | agent_state = model.initial_state(batch_size=1)
238 | encoder_state = encoder.initial_state(batch_size=1)
239 | agent_output, unused_state = model(env_output, agent_state)
240 | prev_env_output = None
241 |
242 | while True:
243 | index = free_queue.get()
244 | if index is None:
245 | break
246 |
247 | # Write old rollout end.
248 | for key in env_output:
249 | buffers[key][index][0, ...] = env_output[key]
250 | for key in agent_output:
251 | buffers[key][index][0, ...] = agent_output[key]
252 |             for j, tensor in enumerate(agent_state):  # j avoids shadowing the actor index i
253 |                 initial_agent_state_buffers[index][j][...] = tensor
254 |             for j, tensor in enumerate(encoder_state):
255 |                 initial_encoder_state_buffers[index][j][...] = tensor
256 |
257 |
258 | # Update the episodic state counts
259 | episode_state_key = tuple(env_output['frame'].view(-1).tolist())
260 | if episode_state_key in episode_state_count_dict:
261 | episode_state_count_dict[episode_state_key] += 1
262 | else:
263 | episode_state_count_dict.update({episode_state_key: 1})
264 | buffers['episode_state_count'][index][0, ...] = \
265 | torch.tensor(1 / np.sqrt(episode_state_count_dict.get(episode_state_key)))
266 |
267 | # Reset the episode state counts when the episode is over
268 | if env_output['done'][0][0]:
269 |                 episode_state_count_dict = dict()
270 | 
271 |
272 | # Update the training state counts
273 | train_state_key = tuple(env_output['frame'].view(-1).tolist())
274 | if train_state_key in train_state_count_dict:
275 | train_state_count_dict[train_state_key] += 1
276 | else:
277 | train_state_count_dict.update({train_state_key: 1})
278 | buffers['train_state_count'][index][0, ...] = \
279 | torch.tensor(1 / np.sqrt(train_state_count_dict.get(train_state_key)))
280 | partial_state_key = tuple(env_output['partial_obs'].view(-1).tolist())
281 | if partial_state_key in partial_state_count_dict:
282 | partial_state_count_dict[partial_state_key] += 1
283 | else:
284 | partial_state_count_dict.update({partial_state_key: 1})
285 | buffers['partial_state_count'][index][0, ...] = \
286 | torch.tensor(1 / np.sqrt(partial_state_count_dict.get(partial_state_key)))
287 | # Update the agent position counts
288 | heatmap_key = tuple(agent_loc(env_output['frame']))
289 | heatmap_buffers[heatmap_key] += 1
290 |
291 | # Do new rollout
292 | for t in range(flags.unroll_length):
293 | timings.reset()
294 |
295 | with torch.no_grad():
296 | agent_output, agent_state = model(env_output, agent_state)
297 | _, encoder_state = encoder(env_output['partial_obs'], encoder_state, env_output['done'])
298 |
299 | timings.time('model')
300 |
301 | prev_env_output = copy.deepcopy(env_output)
302 | env_output = env.step(agent_output['action'])
303 |
304 | timings.time('step')
305 |
306 | for key in env_output:
307 | buffers[key][index][t + 1, ...] = env_output[key]
308 |
309 | for key in agent_output:
310 | buffers[key][index][t + 1, ...] = agent_output[key]
311 |
312 | # Update the episodic state counts
313 | episode_state_key = tuple(env_output['frame'].view(-1).tolist())
314 | if episode_state_key in episode_state_count_dict:
315 | episode_state_count_dict[episode_state_key] += 1
316 | else:
317 | episode_state_count_dict.update({episode_state_key: 1})
318 | buffers['episode_state_count'][index][t + 1, ...] = \
319 | torch.tensor(1 / np.sqrt(episode_state_count_dict.get(episode_state_key)))
320 |
321 | # Reset the episode state counts when the episode is over
322 | if env_output['done'][0][0]:
323 | episode_state_count_dict = dict()
324 |
325 | # Update the training state counts
326 | train_state_key = tuple(env_output['frame'].view(-1).tolist())
327 | if train_state_key in train_state_count_dict:
328 | train_state_count_dict[train_state_key] += 1
329 | else:
330 | train_state_count_dict.update({train_state_key: 1})
331 | buffers['train_state_count'][index][t + 1, ...] = \
332 | torch.tensor(1 / np.sqrt(train_state_count_dict.get(train_state_key)))
333 | partial_state_key = tuple(env_output['partial_obs'].view(-1).tolist())
334 | if partial_state_key in partial_state_count_dict:
335 | partial_state_count_dict[partial_state_key] += 1
336 | else:
337 | partial_state_count_dict.update({partial_state_key: 1})
338 | buffers['partial_state_count'][index][t + 1, ...] = \
339 | torch.tensor(1 / np.sqrt(partial_state_count_dict.get(partial_state_key)))
340 | # Update the agent position counts
341 | heatmap_key = tuple(agent_loc(env_output['frame']))
342 | heatmap_buffers[heatmap_key] += 1
343 |
344 | timings.time('write')
345 | full_queue.put(index)
346 |
347 | if i == 0:
348 | log.info('Actor %i: %s', i, timings.summary())
349 |
350 | except KeyboardInterrupt:
351 | pass
352 | except Exception as e:
353 | logging.error('Exception in worker process %i', i)
354 | traceback.print_exc()
355 | print()
356 | raise e
357 |
358 |
--------------------------------------------------------------------------------
/src/algos/bebold.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import sys
10 | import threading
11 | import time
12 | import timeit
13 | import pprint
14 | import json
15 |
16 | import numpy as np
17 |
18 | import torch
19 | from torch import multiprocessing as mp
20 | from torch import nn
21 | from torch.nn import functional as F
22 |
23 | from src.core import file_writer
24 | from src.core import prof
25 | from src.core import vtrace
26 |
27 | import src.models as models
28 | import src.losses as losses
29 |
30 | from src.env_utils import FrameStack
31 | from src.utils import get_batch, log, create_env, create_buffers, act, create_heatmap_buffers
32 |
33 | MinigridPolicyNet = models.MinigridPolicyNet
34 | MinigridStateEmbeddingNet = models.MinigridStateEmbeddingNet
35 | MinigridMLPEmbeddingNet = models.MinigridMLPEmbeddingNet
36 | MinigridMLPTargetEmbeddingNet = models.MinigridMLPTargetEmbeddingNet
37 |
38 | def momentum_update(model, target, ema_momentum):
39 |     '''
40 |     Exponential moving average (momentum) update of `model` towards `target`:
41 |     model_params = ema_momentum * model_params + (1 - ema_momentum) * target_params
42 |     '''
43 | # For each of the parameters in each encoder
44 | for p_m, p_t in zip(model.parameters(), target.parameters()):
45 | p_m.data = p_m.data * ema_momentum + p_t.detach().data * (1. - ema_momentum)
46 | # For each of the buffers in each encoder
47 | for b_m, b_t in zip(model.buffers(), target.buffers()):
48 | b_m.data = b_m.data * ema_momentum + b_t.detach().data * (1. - ema_momentum)
49 |
50 | def learn(actor_model,
51 | model,
52 | random_target_network,
53 | predictor_network,
54 | actor_encoder,
55 | encoder,
56 | batch,
57 | initial_agent_state,
58 | initial_encoder_state,
59 | optimizer,
60 | predictor_optimizer,
61 | scheduler,
62 | flags,
63 | frames=None,
64 | lock=threading.Lock()):
65 | """Performs a learning (optimization) step."""
66 | with lock:
67 | count_rewards = torch.ones((flags.unroll_length, flags.batch_size),
68 | dtype=torch.float32).to(device=flags.device)
69 |         # episode_state_count already stores 1 / sqrt(N_ep(s)), the episodic visit-count scale
70 | count_rewards = batch['episode_state_count'][1:].float().to(device=flags.device)
71 |
72 | encoded_states, unused_state = encoder(batch['partial_obs'].to(device=flags.device), initial_encoder_state, batch['done'])
73 | random_embedding_next, unused_state = random_target_network(encoded_states[1:].detach(), initial_agent_state)
74 | predicted_embedding_next, unused_state = predictor_network(encoded_states[1:].detach(), initial_agent_state)
75 | random_embedding, unused_state = random_target_network(encoded_states[:-1].detach(), initial_agent_state)
76 | predicted_embedding, unused_state = predictor_network(encoded_states[:-1].detach(), initial_agent_state)
77 |
78 | intrinsic_rewards_next = torch.norm(predicted_embedding_next.detach() - random_embedding_next.detach(), dim=2, p=2)
79 | intrinsic_rewards = torch.norm(predicted_embedding.detach() - random_embedding.detach(), dim=2, p=2)
80 | intrinsic_rewards = torch.clamp(intrinsic_rewards_next - flags.scale_fac * intrinsic_rewards, min=0)
81 | intrinsic_rewards *= (count_rewards == 1).float()
82 |
83 | intrinsic_reward_coef = flags.intrinsic_reward_coef
84 | intrinsic_rewards *= count_rewards * intrinsic_reward_coef
85 |
86 | num_samples = flags.unroll_length * flags.batch_size
87 | actions_flat = batch['action'][1:].reshape(num_samples).cpu().detach().numpy()
88 | intrinsic_rewards_flat = intrinsic_rewards.reshape(num_samples).cpu().detach().numpy()
89 | rnd_loss = flags.rnd_loss_coef * \
90 | losses.compute_rnd_loss(predicted_embedding_next, random_embedding_next.detach())
91 |
92 | learner_outputs, unused_state = model(batch, initial_agent_state)
93 |
94 | bootstrap_value = learner_outputs['baseline'][-1]
95 |
96 | batch = {key: tensor[1:] for key, tensor in batch.items()}
97 | learner_outputs = {
98 | key: tensor[:-1]
99 | for key, tensor in learner_outputs.items()
100 | }
101 |
102 | rewards = batch['reward']
103 |
104 | if flags.no_reward:
105 | total_rewards = intrinsic_rewards
106 | else:
107 | total_rewards = rewards + intrinsic_rewards
108 | clipped_rewards = torch.clamp(total_rewards, -1, 1)
109 |
110 | discounts = (~batch['done']).float() * flags.discounting
111 |
112 | vtrace_returns = vtrace.from_logits(
113 | behavior_policy_logits=batch['policy_logits'],
114 | target_policy_logits=learner_outputs['policy_logits'],
115 | actions=batch['action'],
116 | discounts=discounts,
117 | rewards=clipped_rewards,
118 | values=learner_outputs['baseline'],
119 | bootstrap_value=bootstrap_value)
120 |
121 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'],
122 | batch['action'],
123 | vtrace_returns.pg_advantages)
124 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
125 | vtrace_returns.vs - learner_outputs['baseline'])
126 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
127 | learner_outputs['policy_logits'])
128 |
129 | total_loss = pg_loss + baseline_loss + entropy_loss + rnd_loss
130 |
131 | episode_returns = batch['episode_return'][batch['done']]
132 | stats = {
133 | 'mean_episode_return': torch.mean(episode_returns).item(),
134 | 'total_loss': total_loss.item(),
135 | 'pg_loss': pg_loss.item(),
136 | 'baseline_loss': baseline_loss.item(),
137 | 'entropy_loss': entropy_loss.item(),
138 | 'rnd_loss': rnd_loss.item(),
139 | 'mean_rewards': torch.mean(rewards).item(),
140 | 'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
141 | 'mean_total_rewards': torch.mean(total_rewards).item(),
142 | }
143 |
144 | scheduler.step()
145 | optimizer.zero_grad()
146 | predictor_optimizer.zero_grad()
147 | total_loss.backward()
148 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
149 | nn.utils.clip_grad_norm_(predictor_network.parameters(), flags.max_grad_norm)
150 | optimizer.step()
151 | predictor_optimizer.step()
152 |
153 | actor_model.load_state_dict(model.state_dict())
154 | actor_encoder.load_state_dict(encoder.state_dict())
155 | return stats
156 |
157 |
158 | def train(flags):
159 | if flags.xpid is None:
160 | flags.xpid = flags.env + '-bebold-%s' % time.strftime('%Y%m%d-%H%M%S')
161 | plogger = file_writer.FileWriter(
162 | xpid=flags.xpid,
163 | xp_args=flags.__dict__,
164 | rootdir=flags.savedir,
165 | )
166 |
167 | checkpointpath = os.path.expandvars(
168 | os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid,
169 | 'model.tar')))
170 |
171 | T = flags.unroll_length
172 | B = flags.batch_size
173 |
174 | flags.device = None
175 | if not flags.disable_cuda and torch.cuda.is_available():
176 | log.info('Using CUDA.')
177 | flags.device = torch.device('cuda')
178 | else:
179 | log.info('Not using CUDA.')
180 | flags.device = torch.device('cpu')
181 |
182 | env = create_env(flags)
183 | if flags.num_input_frames > 1:
184 | env = FrameStack(env, flags.num_input_frames)
185 |
186 | if 'MiniGrid' in flags.env:
187 | if flags.use_fullobs_policy:
188 |             raise Exception('Fully observable policy is not implemented!')
189 | else:
190 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)
191 | random_target_network = MinigridMLPTargetEmbeddingNet().to(device=flags.device)
192 | predictor_network = MinigridMLPEmbeddingNet().to(device=flags.device)
193 | encoder = MinigridStateEmbeddingNet(env.observation_space.shape, flags.use_lstm)
194 | else:
195 |         raise Exception('Only MiniGrid is supported now!')
196 |
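   |     # Initialise the actor-side encoder from the policy network: with
   |     # ema_momentum = 0, momentum_update copies the second argument's
   |     # weights and buffers into the first.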
197 | momentum_update(encoder.feat_extract, model.feat_extract, 0)
198 | momentum_update(encoder.fc, model.fc, 0)
199 | if flags.use_lstm:
200 | momentum_update(encoder.core, model.core, 0)
201 |
202 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags)
203 | heatmap_buffers = create_heatmap_buffers(env.observation_space.shape)
204 | model.share_memory()
205 | encoder.share_memory()
206 |
207 | initial_agent_state_buffers = []
208 | for _ in range(flags.num_buffers):
209 | state = model.initial_state(batch_size=1)
210 | for t in state:
211 | t.share_memory_()
212 | initial_agent_state_buffers.append(state)
213 | initial_encoder_state_buffers = []
214 | for _ in range(flags.num_buffers):
215 | state = encoder.initial_state(batch_size=1)
216 | for t in state:
217 | t.share_memory_()
218 | initial_encoder_state_buffers.append(state)
219 |
220 | actor_processes = []
221 | ctx = mp.get_context('fork')
222 | free_queue = ctx.Queue()
223 | full_queue = ctx.Queue()
224 |
225 | episode_state_count_dict = dict()
226 | train_state_count_dict = dict()
227 | partial_state_count_dict = dict()
228 | encoded_state_count_dict = dict()
229 | heatmap_dict = dict()
230 | for i in range(flags.num_actors):
231 | actor = ctx.Process(
232 | target=act,
233 | args=(i, free_queue, full_queue, model, encoder, buffers,
234 | episode_state_count_dict, train_state_count_dict, partial_state_count_dict, encoded_state_count_dict,
235 | heatmap_dict, heatmap_buffers, initial_agent_state_buffers, initial_encoder_state_buffers, flags))
236 | actor.start()
237 | actor_processes.append(actor)
238 |
239 | if 'MiniGrid' in flags.env:
240 | if flags.use_fullobs_policy:
241 |             raise Exception('Fully observable policy is not implemented!')
242 | else:
243 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
244 | .to(device=flags.device)
245 | learner_encoder = MinigridStateEmbeddingNet(env.observation_space.shape, flags.use_lstm)\
246 | .to(device=flags.device)
247 | else:
248 |         raise Exception('Only MiniGrid is supported now!')
249 | learner_encoder.load_state_dict(encoder.state_dict())
250 |
251 | optimizer = torch.optim.RMSprop(
252 | learner_model.parameters(),
253 | lr=flags.learning_rate,
254 | momentum=flags.momentum,
255 | eps=flags.epsilon,
256 | alpha=flags.alpha)
257 |
258 | predictor_optimizer = torch.optim.Adam(
259 | predictor_network.parameters(),
260 | lr=flags.predictor_learning_rate)
261 |
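   |     # Linear learning-rate decay for the policy optimizer: `epoch` counts learner
   |     # updates, each consuming T * B frames, so the LR anneals to zero over
   |     # flags.total_frames.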
262 | def lr_lambda(epoch):
263 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames
264 |
265 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
266 |
267 | logger = logging.getLogger('logfile')
268 | stat_keys = [
269 | 'total_loss',
270 | 'mean_episode_return',
271 | 'pg_loss',
272 | 'baseline_loss',
273 | 'entropy_loss',
274 | 'rnd_loss',
275 | 'mean_rewards',
276 | 'mean_intrinsic_rewards',
277 | 'mean_total_rewards',
278 | ]
279 |
280 | logger.info('# Step\t%s', '\t'.join(stat_keys))
281 |
282 | frames, stats = 0, {}
283 |
284 | def batch_and_learn(i, lock=threading.Lock()):
285 | """Thread target for the learning process."""
286 | nonlocal frames, stats
287 | timings = prof.Timings()
288 | while frames < flags.total_frames:
289 | timings.reset()
290 | batch, agent_state, encoder_state = get_batch(free_queue, full_queue, buffers,
291 | initial_agent_state_buffers, initial_encoder_state_buffers, flags, timings)
292 | stats = learn(model, learner_model, random_target_network, predictor_network,
293 | encoder, learner_encoder, batch, agent_state, encoder_state, optimizer,
294 | predictor_optimizer, scheduler, flags, frames=frames)
295 | timings.time('learn')
296 | with lock:
297 | to_log = dict(frames=frames)
298 | to_log.update({k: stats[k] for k in stat_keys})
299 | plogger.log(to_log)
300 | frames += T * B
301 |
302 | if i == 0:
303 | log.info('Batch and learn: %s', timings.summary())
304 |
305 | for m in range(flags.num_buffers):
306 | free_queue.put(m)
307 |
308 | threads = []
309 | for i in range(flags.num_threads):
310 | thread = threading.Thread(
311 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,))
312 | thread.start()
313 | threads.append(thread)
314 |
315 |
316 | def checkpoint(frames):
317 | if flags.disable_checkpoint:
318 | return
319 | checkpointpath = os.path.expandvars(os.path.expanduser(
320 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model_'+str(frames)+'.tar')))
321 | log.info('Saving checkpoint to %s', checkpointpath)
322 | torch.save({
323 | 'model_state_dict': model.state_dict(),
324 | 'encoder': encoder.state_dict(),
325 | 'random_target_network_state_dict': random_target_network.state_dict(),
326 | 'predictor_network_state_dict': predictor_network.state_dict(),
327 | 'optimizer_state_dict': optimizer.state_dict(),
328 | 'predictor_optimizer_state_dict': predictor_optimizer.state_dict(),
329 | 'scheduler_state_dict': scheduler.state_dict(),
330 | 'flags': vars(flags),
331 | }, checkpointpath)
332 |
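   |     def save_heatmap(frames):
   |         # Mirrors save_heatmap() in ride.py so that the checkpointing loop below
   |         # can dump the shared heatmap buffers alongside the model checkpoints.
   |         import json  # local import; only needed for writing the heatmap file
   |         heatmap_path = os.path.expandvars(os.path.expanduser(
   |             '%s/%s/%s' % (flags.savedir, flags.xpid, 'heatmap_'+str(frames)+'.json')))
   |         log.info('Saving heatmap to %s', heatmap_path)
   |         heatmap_dict = dict()
   |         for i, key in enumerate(heatmap_buffers.keys()):
   |             heatmap_dict.update({i: heatmap_buffers[key].item()})
   |         with open(heatmap_path, 'w') as fp:
   |             json.dump(heatmap_dict, fp)
   | 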
333 | timer = timeit.default_timer
334 | try:
335 | last_checkpoint_time = timer()
336 | while frames < flags.total_frames:
337 | start_frames = frames
338 | start_time = timer()
339 | time.sleep(5)
340 |
341 | if timer() - last_checkpoint_time > flags.save_interval * 60:
342 | checkpoint(frames)
343 | save_heatmap(frames)
344 | last_checkpoint_time = timer()
345 |
346 | fps = (frames - start_frames) / (timer() - start_time)
347 |
348 | if stats.get('episode_returns', None):
349 | mean_return = 'Return per episode: %.1f. ' % stats[
350 | 'mean_episode_return']
351 | else:
352 | mean_return = ''
353 |
354 | total_loss = stats.get('total_loss', float('inf'))
355 | if stats:
356 | log.info('After %i frames: loss %f @ %.1f fps. Mean Return %.1f. \n Stats \n %s', \
357 | frames, total_loss, fps, stats['mean_episode_return'], pprint.pformat(stats))
358 |
359 | except KeyboardInterrupt:
360 | return
361 | else:
362 | for thread in threads:
363 | thread.join()
364 | log.info('Learning finished after %d frames.', frames)
365 | finally:
366 | for _ in range(flags.num_actors):
367 | free_queue.put(None)
368 | for actor in actor_processes:
369 | actor.join(timeout=1)
370 |
371 | checkpoint(frames)
372 | plogger.close()
373 |
374 |
--------------------------------------------------------------------------------
/src/algos/curiosity.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import sys
10 | import threading
11 | import time
12 | import timeit
13 | import pprint
14 |
15 | import numpy as np
16 |
17 | import torch
18 | from torch import multiprocessing as mp
19 | from torch import nn
20 | from torch.nn import functional as F
21 |
22 | from src.core import file_writer
23 | from src.core import prof
24 | from src.core import vtrace
25 |
26 | import src.models as models
27 | import src.losses as losses
28 |
29 | from src.env_utils import FrameStack
30 | from src.utils import get_batch, log, create_env, create_buffers, act
31 |
32 | MinigridStateEmbeddingNet = models.MinigridStateEmbeddingNet
33 | MinigridForwardDynamicsNet = models.MinigridForwardDynamicsNet
34 | MinigridInverseDynamicsNet = models.MinigridInverseDynamicsNet
35 | MinigridPolicyNet = models.MinigridPolicyNet
36 |
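   | # Note: the fully-observable and Mario/Doom branches in train() below reference
   | # FullObsMinigrid* / MarioDoom* networks that are not aliased in this module;
   | # only the partially observable MiniGrid path runs as-is.
   | 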
37 | def learn(actor_model,
38 | model,
39 | state_embedding_model,
40 | forward_dynamics_model,
41 | inverse_dynamics_model,
42 | batch,
43 | initial_agent_state,
44 | optimizer,
45 | state_embedding_optimizer,
46 | forward_dynamics_optimizer,
47 | inverse_dynamics_optimizer,
48 | scheduler,
49 | flags,
50 | frames=None,
51 | lock=threading.Lock()):
52 | """Performs a learning (optimization) step."""
53 | with lock:
54 | if flags.use_fullobs_intrinsic:
55 | state_emb = state_embedding_model(batch, next_state=False)\
56 | .reshape(flags.unroll_length, flags.batch_size, 128)
57 | next_state_emb = state_embedding_model(batch, next_state=True)\
58 | .reshape(flags.unroll_length, flags.batch_size, 128)
59 | else:
60 | state_emb = state_embedding_model(batch['partial_obs'][:-1].to(device=flags.device))
61 | next_state_emb = state_embedding_model(batch['partial_obs'][1:].to(device=flags.device))
62 |
63 | pred_next_state_emb = forward_dynamics_model(\
64 | state_emb, batch['action'][1:].to(device=flags.device))
65 | pred_actions = inverse_dynamics_model(state_emb, next_state_emb)
66 | entropy_emb_actions = losses.compute_entropy_loss(pred_actions)
67 |
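   |         # ICM-style curiosity: the intrinsic reward is the forward model's
   |         # prediction error in embedding space, ||phi_hat(s') - phi(s')||_2.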
68 | intrinsic_rewards = torch.norm(pred_next_state_emb - next_state_emb, dim=2, p=2)
69 |
70 | intrinsic_reward_coef = flags.intrinsic_reward_coef
71 | intrinsic_rewards *= intrinsic_reward_coef
72 |
73 | forward_dynamics_loss = flags.forward_loss_coef * \
74 | losses.compute_forward_dynamics_loss(pred_next_state_emb, next_state_emb)
75 |
76 | inverse_dynamics_loss = flags.inverse_loss_coef * \
77 | losses.compute_inverse_dynamics_loss(pred_actions, batch['action'][1:])
78 |
79 | num_samples = flags.unroll_length * flags.batch_size
80 | actions_flat = batch['action'][1:].reshape(num_samples).cpu().detach().numpy()
81 | intrinsic_rewards_flat = intrinsic_rewards.reshape(num_samples).cpu().detach().numpy()
82 |
83 |
84 | learner_outputs, unused_state = model(batch, initial_agent_state)
85 |
86 | bootstrap_value = learner_outputs['baseline'][-1]
87 |
88 | batch = {key: tensor[1:] for key, tensor in batch.items()}
89 | learner_outputs = {
90 | key: tensor[:-1]
91 | for key, tensor in learner_outputs.items()
92 | }
93 |
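   |         # Per-action selection frequencies (diagnostics only; not used in the loss).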
94 | actions = batch['action'].reshape(flags.unroll_length * flags.batch_size).cpu().numpy()
95 | action_percentage = [0 for _ in range(model.num_actions)]
96 | for i in range(model.num_actions):
97 | action_percentage[i] = np.sum([a == i for a in actions]) / len(actions)
98 |
99 | rewards = batch['reward']
100 |
101 | if flags.no_reward:
102 | total_rewards = intrinsic_rewards
103 | else:
104 | total_rewards = rewards + intrinsic_rewards
105 | clipped_rewards = torch.clamp(total_rewards, -1, 1)
106 |
107 | discounts = (~batch['done']).float() * flags.discounting
108 |
109 | vtrace_returns = vtrace.from_logits(
110 | behavior_policy_logits=batch['policy_logits'],
111 | target_policy_logits=learner_outputs['policy_logits'],
112 | actions=batch['action'],
113 | discounts=discounts,
114 | rewards=clipped_rewards,
115 | values=learner_outputs['baseline'],
116 | bootstrap_value=bootstrap_value)
117 |
118 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'],
119 | batch['action'],
120 | vtrace_returns.pg_advantages)
121 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
122 | vtrace_returns.vs - learner_outputs['baseline'])
123 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
124 | learner_outputs['policy_logits'])
125 |
126 | total_loss = pg_loss + baseline_loss + entropy_loss \
127 | + forward_dynamics_loss + inverse_dynamics_loss
128 |
129 | episode_returns = batch['episode_return'][batch['done']]
130 | episode_lengths = batch['episode_step'][batch['done']]
131 | episode_wins = batch['episode_win'][batch['done']]
132 | stats = {
133 | 'mean_episode_return': torch.mean(episode_returns).item(),
134 | 'total_loss': total_loss.item(),
135 | 'pg_loss': pg_loss.item(),
136 | 'baseline_loss': baseline_loss.item(),
137 | 'entropy_loss': entropy_loss.item(),
138 | 'forward_dynamics_loss': forward_dynamics_loss.item(),
139 | 'inverse_dynamics_loss': inverse_dynamics_loss.item(),
140 | 'mean_rewards': torch.mean(rewards).item(),
141 | 'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
142 | 'mean_total_rewards': torch.mean(total_rewards).item(),
143 | }
144 |
145 | scheduler.step()
146 | optimizer.zero_grad()
147 | state_embedding_optimizer.zero_grad()
148 | forward_dynamics_optimizer.zero_grad()
149 | inverse_dynamics_optimizer.zero_grad()
150 | total_loss.backward()
151 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
152 | nn.utils.clip_grad_norm_(state_embedding_model.parameters(), flags.max_grad_norm)
153 | nn.utils.clip_grad_norm_(forward_dynamics_model.parameters(), flags.max_grad_norm)
154 | nn.utils.clip_grad_norm_(inverse_dynamics_model.parameters(), flags.max_grad_norm)
155 | optimizer.step()
156 | state_embedding_optimizer.step()
157 | forward_dynamics_optimizer.step()
158 | inverse_dynamics_optimizer.step()
159 |
160 | actor_model.load_state_dict(model.state_dict())
161 | return stats
162 |
163 |
164 | def train(flags):
165 | if flags.xpid is None:
166 | flags.xpid = 'curiosity-%s' % time.strftime('%Y%m%d-%H%M%S')
167 | plogger = file_writer.FileWriter(
168 | xpid=flags.xpid,
169 | xp_args=flags.__dict__,
170 | rootdir=flags.savedir,
171 | )
172 |
173 | checkpointpath = os.path.expandvars(os.path.expanduser(
174 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model.tar')))
175 |
176 | T = flags.unroll_length
177 | B = flags.batch_size
178 |
179 | flags.device = None
180 | if not flags.disable_cuda and torch.cuda.is_available():
181 | log.info('Using CUDA.')
182 | flags.device = torch.device('cuda')
183 | else:
184 | log.info('Not using CUDA.')
185 | flags.device = torch.device('cpu')
186 |
187 | env = create_env(flags)
188 | if flags.num_input_frames > 1:
189 | env = FrameStack(env, flags.num_input_frames)
190 |
191 | if 'MiniGrid' in flags.env:
192 | if flags.use_fullobs_policy:
193 | model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)
194 | else:
195 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)
196 | if flags.use_fullobs_intrinsic:
197 | state_embedding_model = FullObsMinigridStateEmbeddingNet(env.observation_space.shape)\
198 | .to(device=flags.device)
199 | else:
200 | state_embedding_model = MinigridStateEmbeddingNet(env.observation_space.shape)\
201 | .to(device=flags.device)
202 | forward_dynamics_model = MinigridForwardDynamicsNet(env.action_space.n)\
203 | .to(device=flags.device)
204 | inverse_dynamics_model = MinigridInverseDynamicsNet(env.action_space.n)\
205 | .to(device=flags.device)
206 | else:
207 | model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)
208 | state_embedding_model = MarioDoomStateEmbeddingNet(env.observation_space.shape)\
209 | .to(device=flags.device)
210 | forward_dynamics_model = MarioDoomForwardDynamicsNet(env.action_space.n)\
211 | .to(device=flags.device)
212 | inverse_dynamics_model = MarioDoomInverseDynamicsNet(env.action_space.n)\
213 | .to(device=flags.device)
214 |
215 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags)
216 | model.share_memory()
217 |
218 | initial_agent_state_buffers = []
219 | for _ in range(flags.num_buffers):
220 | state = model.initial_state(batch_size=1)
221 | for t in state:
222 | t.share_memory_()
223 | initial_agent_state_buffers.append(state)
224 |
225 | actor_processes = []
226 | ctx = mp.get_context('fork')
227 | free_queue = ctx.SimpleQueue()
228 | full_queue = ctx.SimpleQueue()
229 |
230 | episode_state_count_dict = dict()
231 | train_state_count_dict = dict()
232 | for i in range(flags.num_actors):
233 | actor = ctx.Process(
234 | target=act,
235 | args=(i, free_queue, full_queue, model, buffers,
236 | episode_state_count_dict, train_state_count_dict,
237 | initial_agent_state_buffers, flags))
238 | actor.start()
239 | actor_processes.append(actor)
240 |
241 | if 'MiniGrid' in flags.env:
242 | if flags.use_fullobs_policy:
243 | learner_model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
244 | .to(device=flags.device)
245 | else:
246 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
247 | .to(device=flags.device)
248 | else:
249 | learner_model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)\
250 | .to(device=flags.device)
251 |
252 | optimizer = torch.optim.RMSprop(
253 | learner_model.parameters(),
254 | lr=flags.learning_rate,
255 | momentum=flags.momentum,
256 | eps=flags.epsilon,
257 | alpha=flags.alpha)
258 |
259 | state_embedding_optimizer = torch.optim.RMSprop(
260 | state_embedding_model.parameters(),
261 | lr=flags.learning_rate,
262 | momentum=flags.momentum,
263 | eps=flags.epsilon,
264 | alpha=flags.alpha)
265 |
266 | inverse_dynamics_optimizer = torch.optim.RMSprop(
267 | inverse_dynamics_model.parameters(),
268 | lr=flags.learning_rate,
269 | momentum=flags.momentum,
270 | eps=flags.epsilon,
271 | alpha=flags.alpha)
272 |
273 | forward_dynamics_optimizer = torch.optim.RMSprop(
274 | forward_dynamics_model.parameters(),
275 | lr=flags.learning_rate,
276 | momentum=flags.momentum,
277 | eps=flags.epsilon,
278 | alpha=flags.alpha)
279 |
280 |
281 | def lr_lambda(epoch):
282 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames
283 |
284 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
285 |
286 | logger = logging.getLogger('logfile')
287 | stat_keys = [
288 | 'total_loss',
289 | 'mean_episode_return',
290 | 'pg_loss',
291 | 'baseline_loss',
292 | 'entropy_loss',
293 | 'forward_dynamics_loss',
294 | 'inverse_dynamics_loss',
295 | 'mean_rewards',
296 | 'mean_intrinsic_rewards',
297 | 'mean_total_rewards',
298 | ]
299 |
300 | logger.info('# Step\t%s', '\t'.join(stat_keys))
301 |
302 | frames, stats = 0, {}
303 |
304 |
305 | def batch_and_learn(i, lock=threading.Lock()):
306 | """Thread target for the learning process."""
307 | nonlocal frames, stats
308 | timings = prof.Timings()
309 | while frames < flags.total_frames:
310 | timings.reset()
311 | batch, agent_state = get_batch(free_queue, full_queue, buffers,
312 | initial_agent_state_buffers, flags, timings)
313 | stats = learn(model, learner_model, state_embedding_model, forward_dynamics_model,
314 | inverse_dynamics_model, batch, agent_state, optimizer,
315 | state_embedding_optimizer, forward_dynamics_optimizer,
316 | inverse_dynamics_optimizer, scheduler, flags, frames=frames)
317 | timings.time('learn')
318 | with lock:
319 | to_log = dict(frames=frames)
320 | to_log.update({k: stats[k] for k in stat_keys})
321 | plogger.log(to_log)
322 | frames += T * B
323 |
324 | if i == 0:
325 | log.info('Batch and learn: %s', timings.summary())
326 |
327 | for m in range(flags.num_buffers):
328 | free_queue.put(m)
329 |
330 | threads = []
331 | for i in range(flags.num_threads):
332 | thread = threading.Thread(
333 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,))
334 | thread.start()
335 | threads.append(thread)
336 |
337 |
338 | def checkpoint(frames):
339 | if flags.disable_checkpoint:
340 | return
341 | checkpointpath = os.path.expandvars(
342 | os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid,
343 | 'model.tar')))
344 | log.info('Saving checkpoint to %s', checkpointpath)
345 | torch.save({
346 | 'model_state_dict': model.state_dict(),
347 | 'state_embedding_model_state_dict': state_embedding_model.state_dict(),
348 | 'forward_dynamics_model_state_dict': forward_dynamics_model.state_dict(),
349 | 'inverse_dynamics_model_state_dict': inverse_dynamics_model.state_dict(),
350 | 'optimizer_state_dict': optimizer.state_dict(),
351 | 'state_embedding_optimizer_state_dict': state_embedding_optimizer.state_dict(),
352 | 'forward_dynamics_optimizer_state_dict': forward_dynamics_optimizer.state_dict(),
353 | 'inverse_dynamics_optimizer_state_dict': inverse_dynamics_optimizer.state_dict(),
354 | 'scheduler_state_dict': scheduler.state_dict(),
355 | 'flags': vars(flags),
356 | }, checkpointpath)
357 |
358 | timer = timeit.default_timer
359 | try:
360 | last_checkpoint_time = timer()
361 | while frames < flags.total_frames:
362 | start_frames = frames
363 | start_time = timer()
364 | time.sleep(5)
365 |
366 | if timer() - last_checkpoint_time > flags.save_interval * 60:
367 | checkpoint(frames)
368 | last_checkpoint_time = timer()
369 |
370 | fps = (frames - start_frames) / (timer() - start_time)
371 |
372 | if stats.get('episode_returns', None):
373 | mean_return = 'Return per episode: %.1f. ' % stats[
374 | 'mean_episode_return']
375 | else:
376 | mean_return = ''
377 |
378 | total_loss = stats.get('total_loss', float('inf'))
379 | if stats:
380 | log.info('After %i frames: loss %f @ %.1f fps. Mean Return %.1f. \n Stats \n %s', \
381 | frames, total_loss, fps, stats['mean_episode_return'], pprint.pformat(stats))
382 |
383 | except KeyboardInterrupt:
384 | return
385 | else:
386 | for thread in threads:
387 | thread.join()
388 | log.info('Learning finished after %d frames.', frames)
389 | finally:
390 | for _ in range(flags.num_actors):
391 | free_queue.put(None)
392 | for actor in actor_processes:
393 | actor.join(timeout=1)
394 |
395 | checkpoint(frames)
396 | plogger.close()
397 |
398 |
--------------------------------------------------------------------------------
/src/algos/ride.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | import threading
10 | import time
11 | import timeit
12 | import pprint
13 | import json
14 |
15 | import numpy as np
16 |
17 | import torch
18 | from torch import multiprocessing as mp
19 | from torch import nn
20 | from torch.nn import functional as F
21 |
22 | from src.core import file_writer
23 | from src.core import prof
24 | from src.core import vtrace
25 |
26 | import src.models as models
27 | import src.losses as losses
28 |
29 | from src.env_utils import FrameStack
30 | from src.utils import get_batch, log, create_env, create_buffers, act, create_heatmap_buffers
31 |
32 | MinigridStateEmbeddingNet = models.MinigridStateEmbeddingNet
33 | MinigridForwardDynamicsNet = models.MinigridForwardDynamicsNet
34 | MinigridInverseDynamicsNet = models.MinigridInverseDynamicsNet
35 | MinigridPolicyNet = models.MinigridPolicyNet
36 |
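   | # Note: the fully-observable and Mario/Doom branches in train() below reference
   | # FullObsMinigrid* / MarioDoom* networks that are not aliased in this module;
   | # only the partially observable MiniGrid path runs as-is.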
37 |
38 | def learn(actor_model,
39 | model,
40 | state_embedding_model,
41 | forward_dynamics_model,
42 | inverse_dynamics_model,
43 | batch,
44 | initial_agent_state,
45 | optimizer,
46 | state_embedding_optimizer,
47 | forward_dynamics_optimizer,
48 | inverse_dynamics_optimizer,
49 | scheduler,
50 | flags,
51 | frames=None,
52 | lock=threading.Lock()):
53 | """Performs a learning (optimization) step."""
54 | with lock:
55 |         # Episodic-count term produced by the actors (see utils.act); it modulates
56 |         # the impact bonus according to how often a state has been revisited
57 |         # within the current episode.
   |         count_rewards = batch['episode_state_count'][1:].float().to(device=flags.device)
58 |
59 | if flags.use_fullobs_intrinsic:
60 | state_emb = state_embedding_model(batch, next_state=False)\
61 | .reshape(flags.unroll_length, flags.batch_size, 128)
62 | next_state_emb = state_embedding_model(batch, next_state=True)\
63 | .reshape(flags.unroll_length, flags.batch_size, 128)
64 | else:
65 | state_emb = state_embedding_model(batch['partial_obs'][:-1].to(device=flags.device))
66 | next_state_emb = state_embedding_model(batch['partial_obs'][1:].to(device=flags.device))
67 |
68 | pred_next_state_emb = forward_dynamics_model(
69 | state_emb, batch['action'][1:].to(device=flags.device))
70 | pred_actions = inverse_dynamics_model(state_emb, next_state_emb)
71 |
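   |         # RIDE: the impact of a transition is the change in state embedding,
   |         # ||phi(s_t+1) - phi(s_t)||_2, which is combined with the episodic-count
   |         # term above to form the intrinsic reward.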
72 | control_rewards = torch.norm(next_state_emb - state_emb, dim=2, p=2)
73 |
74 | intrinsic_rewards = count_rewards * control_rewards
75 |
76 | intrinsic_reward_coef = flags.intrinsic_reward_coef
77 | intrinsic_rewards *= intrinsic_reward_coef
78 |
79 | forward_dynamics_loss = flags.forward_loss_coef * \
80 | losses.compute_forward_dynamics_loss(pred_next_state_emb, next_state_emb)
81 |
82 | inverse_dynamics_loss = flags.inverse_loss_coef * \
83 | losses.compute_inverse_dynamics_loss(pred_actions, batch['action'][1:])
84 |
85 | learner_outputs, unused_state = model(batch, initial_agent_state)
86 |
87 | bootstrap_value = learner_outputs['baseline'][-1]
88 |
89 | batch = {key: tensor[1:] for key, tensor in batch.items()}
90 | learner_outputs = {
91 | key: tensor[:-1]
92 | for key, tensor in learner_outputs.items()
93 | }
94 |
95 |
96 | rewards = batch['reward']
97 | if flags.no_reward:
98 | total_rewards = intrinsic_rewards
99 | else:
100 | total_rewards = rewards + intrinsic_rewards
101 | clipped_rewards = torch.clamp(total_rewards, -1, 1)
102 |
103 | discounts = (~batch['done']).float() * flags.discounting
104 |
105 | vtrace_returns = vtrace.from_logits(
106 | behavior_policy_logits=batch['policy_logits'],
107 | target_policy_logits=learner_outputs['policy_logits'],
108 | actions=batch['action'],
109 | discounts=discounts,
110 | rewards=clipped_rewards,
111 | values=learner_outputs['baseline'],
112 | bootstrap_value=bootstrap_value)
113 |
114 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'],
115 | batch['action'],
116 | vtrace_returns.pg_advantages)
117 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
118 | vtrace_returns.vs - learner_outputs['baseline'])
119 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
120 | learner_outputs['policy_logits'])
121 |
122 | total_loss = pg_loss + baseline_loss + entropy_loss + \
123 | forward_dynamics_loss + inverse_dynamics_loss
124 |
125 | episode_returns = batch['episode_return'][batch['done']]
126 | stats = {
127 | 'mean_episode_return': torch.mean(episode_returns).item(),
128 | 'total_loss': total_loss.item(),
129 | 'pg_loss': pg_loss.item(),
130 | 'baseline_loss': baseline_loss.item(),
131 | 'entropy_loss': entropy_loss.item(),
132 | 'mean_rewards': torch.mean(rewards).item(),
133 | 'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
134 | 'mean_total_rewards': torch.mean(total_rewards).item(),
135 | 'mean_control_rewards': torch.mean(control_rewards).item(),
136 | 'mean_count_rewards': torch.mean(count_rewards).item(),
137 | 'forward_dynamics_loss': forward_dynamics_loss.item(),
138 | 'inverse_dynamics_loss': inverse_dynamics_loss.item(),
139 | }
140 |
141 | scheduler.step()
142 | optimizer.zero_grad()
143 | state_embedding_optimizer.zero_grad()
144 | forward_dynamics_optimizer.zero_grad()
145 | inverse_dynamics_optimizer.zero_grad()
146 | total_loss.backward()
147 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
148 | nn.utils.clip_grad_norm_(state_embedding_model.parameters(), flags.max_grad_norm)
149 | nn.utils.clip_grad_norm_(forward_dynamics_model.parameters(), flags.max_grad_norm)
150 | nn.utils.clip_grad_norm_(inverse_dynamics_model.parameters(), flags.max_grad_norm)
151 | optimizer.step()
152 | state_embedding_optimizer.step()
153 | forward_dynamics_optimizer.step()
154 | inverse_dynamics_optimizer.step()
155 |
156 | actor_model.load_state_dict(model.state_dict())
157 | return stats
158 |
159 |
160 | def train(flags):
161 | if flags.xpid is None:
162 | flags.xpid = 'ride-%s' % time.strftime('%Y%m%d-%H%M%S')
163 | plogger = file_writer.FileWriter(
164 | xpid=flags.xpid,
165 | xp_args=flags.__dict__,
166 | rootdir=flags.savedir,
167 | )
168 | checkpointpath = os.path.expandvars(os.path.expanduser(
169 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model.tar')))
170 |
171 | T = flags.unroll_length
172 | B = flags.batch_size
173 |
174 | flags.device = None
175 | if not flags.disable_cuda and torch.cuda.is_available():
176 | log.info('Using CUDA.')
177 | flags.device = torch.device('cuda')
178 | else:
179 | log.info('Not using CUDA.')
180 | flags.device = torch.device('cpu')
181 |
182 | env = create_env(flags)
183 | if flags.num_input_frames > 1:
184 | env = FrameStack(env, flags.num_input_frames)
185 |
186 | if 'MiniGrid' in flags.env:
187 | if flags.use_fullobs_policy:
188 | model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)
189 | else:
190 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)
191 | if flags.use_fullobs_intrinsic:
192 | state_embedding_model = FullObsMinigridStateEmbeddingNet(env.observation_space.shape)\
193 | .to(device=flags.device)
194 | else:
195 | state_embedding_model = MinigridStateEmbeddingNet(env.observation_space.shape)\
196 | .to(device=flags.device)
197 | forward_dynamics_model = MinigridForwardDynamicsNet(env.action_space.n)\
198 | .to(device=flags.device)
199 | inverse_dynamics_model = MinigridInverseDynamicsNet(env.action_space.n)\
200 | .to(device=flags.device)
201 | else:
202 | model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)
203 | state_embedding_model = MarioDoomStateEmbeddingNet(env.observation_space.shape)\
204 | .to(device=flags.device)
205 | forward_dynamics_model = MarioDoomForwardDynamicsNet(env.action_space.n)\
206 | .to(device=flags.device)
207 | inverse_dynamics_model = MarioDoomInverseDynamicsNet(env.action_space.n)\
208 | .to(device=flags.device)
209 |
210 |
211 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags)
212 | heatmap_buffers = create_heatmap_buffers(env.observation_space.shape)
213 |
214 | model.share_memory()
215 |
216 | initial_agent_state_buffers = []
217 | for _ in range(flags.num_buffers):
218 | state = model.initial_state(batch_size=1)
219 | for t in state:
220 | t.share_memory_()
221 | initial_agent_state_buffers.append(state)
222 |
223 | actor_processes = []
224 | ctx = mp.get_context('fork')
225 | free_queue = ctx.Queue()
226 | full_queue = ctx.Queue()
227 |
228 | episode_state_count_dict = dict()
229 | train_state_count_dict = dict()
230 | partial_state_count_dict = dict()
231 | encoded_state_count_dict = dict()
232 | heatmap_dict = dict()
233 | for i in range(flags.num_actors):
234 | actor = ctx.Process(
235 | target=act,
236 | args=(i, free_queue, full_queue, model, buffers,
237 | episode_state_count_dict, train_state_count_dict, partial_state_count_dict, encoded_state_count_dict,
238 | heatmap_dict, heatmap_buffers, initial_agent_state_buffers, flags))
239 | actor.start()
240 | actor_processes.append(actor)
241 |
242 | if 'MiniGrid' in flags.env:
243 | if flags.use_fullobs_policy:
244 | learner_model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
245 | .to(device=flags.device)
246 | else:
247 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\
248 | .to(device=flags.device)
249 | else:
250 | learner_model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)\
251 | .to(device=flags.device)
252 |
253 | optimizer = torch.optim.RMSprop(
254 | learner_model.parameters(),
255 | lr=flags.learning_rate,
256 | momentum=flags.momentum,
257 | eps=flags.epsilon,
258 | alpha=flags.alpha)
259 |
260 | state_embedding_optimizer = torch.optim.RMSprop(
261 | state_embedding_model.parameters(),
262 | lr=flags.learning_rate,
263 | momentum=flags.momentum,
264 | eps=flags.epsilon,
265 | alpha=flags.alpha)
266 |
267 | inverse_dynamics_optimizer = torch.optim.RMSprop(
268 | inverse_dynamics_model.parameters(),
269 | lr=flags.learning_rate,
270 | momentum=flags.momentum,
271 | eps=flags.epsilon,
272 | alpha=flags.alpha)
273 |
274 | forward_dynamics_optimizer = torch.optim.RMSprop(
275 | forward_dynamics_model.parameters(),
276 | lr=flags.learning_rate,
277 | momentum=flags.momentum,
278 | eps=flags.epsilon,
279 | alpha=flags.alpha)
280 |
281 |
282 | def lr_lambda(epoch):
283 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames
284 |
285 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
286 |
287 | logger = logging.getLogger('logfile')
288 | stat_keys = [
289 | 'total_loss',
290 | 'mean_episode_return',
291 | 'pg_loss',
292 | 'baseline_loss',
293 | 'entropy_loss',
294 | 'mean_rewards',
295 | 'mean_intrinsic_rewards',
296 | 'mean_total_rewards',
297 | 'mean_control_rewards',
298 | 'mean_count_rewards',
299 | 'forward_dynamics_loss',
300 | 'inverse_dynamics_loss',
301 | ]
302 | logger.info('# Step\t%s', '\t'.join(stat_keys))
303 | frames, stats = 0, {}
304 |
305 |
306 | def batch_and_learn(i, lock=threading.Lock()):
307 | """Thread target for the learning process."""
308 | nonlocal frames, stats
309 | timings = prof.Timings()
310 | while frames < flags.total_frames:
311 | timings.reset()
312 | batch, agent_state = get_batch(free_queue, full_queue, buffers,
313 | initial_agent_state_buffers, flags, timings)
314 | stats = learn(model, learner_model, state_embedding_model, forward_dynamics_model,
315 | inverse_dynamics_model, batch, agent_state, optimizer,
316 | state_embedding_optimizer, forward_dynamics_optimizer,
317 | inverse_dynamics_optimizer, scheduler, flags, frames=frames)
318 | timings.time('learn')
319 | with lock:
320 | to_log = dict(frames=frames)
321 | to_log.update({k: stats[k] for k in stat_keys})
322 | plogger.log(to_log)
323 | frames += T * B
324 |
325 | if i == 0:
326 | log.info('Batch and learn: %s', timings.summary())
327 |
328 | for m in range(flags.num_buffers):
329 | free_queue.put(m)
330 |
331 | threads = []
332 | for i in range(flags.num_threads):
333 | thread = threading.Thread(
334 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,))
335 | thread.start()
336 | threads.append(thread)
337 |
338 |
339 | def checkpoint(frames):
340 | if flags.disable_checkpoint:
341 | return
342 | checkpointpath = os.path.expandvars(os.path.expanduser(
343 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model_'+str(frames)+'.tar')))
344 | log.info('Saving checkpoint to %s', checkpointpath)
345 | torch.save({
346 | 'model_state_dict': model.state_dict(),
347 | 'state_embedding_model_state_dict': state_embedding_model.state_dict(),
348 | 'forward_dynamics_model_state_dict': forward_dynamics_model.state_dict(),
349 | 'inverse_dynamics_model_state_dict': inverse_dynamics_model.state_dict(),
350 | 'optimizer_state_dict': optimizer.state_dict(),
351 | 'state_embedding_optimizer_state_dict': state_embedding_optimizer.state_dict(),
352 | 'forward_dynamics_optimizer_state_dict': forward_dynamics_optimizer.state_dict(),
353 | 'inverse_dynamics_optimizer_state_dict': inverse_dynamics_optimizer.state_dict(),
354 | 'scheduler_state_dict': scheduler.state_dict(),
355 | 'flags': vars(flags),
356 | }, checkpointpath)
357 |
358 | def save_heatmap(frames):
359 | checkpoint_path = os.path.expandvars(os.path.expanduser(
360 | '%s/%s/%s' % (flags.savedir, flags.xpid,'heatmap_'+str(frames)+'.json')))
361 | log.info('Saving heatmap to %s', checkpoint_path)
362 | heatmap_dict = dict()
363 | for i, key in enumerate(heatmap_buffers.keys()):
364 | heatmap_dict.update({i: heatmap_buffers[key].item()})
365 | with open(checkpoint_path, 'w') as fp:
366 | json.dump(heatmap_dict, fp)
367 |
368 | timer = timeit.default_timer
369 | try:
370 | last_checkpoint_time = timer()
371 | while frames < flags.total_frames:
372 | start_frames = frames
373 | start_time = timer()
374 | time.sleep(5)
375 |
376 | if timer() - last_checkpoint_time > flags.save_interval * 60:
377 | checkpoint(frames)
378 | save_heatmap(frames)
379 | last_checkpoint_time = timer()
380 |
381 | fps = (frames - start_frames) / (timer() - start_time)
382 | if stats.get('episode_returns', None):
383 | mean_return = 'Return per episode: %.1f. ' % stats[
384 | 'mean_episode_return']
385 | else:
386 | mean_return = ''
387 | total_loss = stats.get('total_loss', float('inf'))
388 | log.info('After %i frames: loss %f @ %.1f fps. %sStats:\n%s',
389 | frames, total_loss, fps, mean_return,
390 | pprint.pformat(stats))
391 |
392 | except KeyboardInterrupt:
393 | return
394 | else:
395 | for thread in threads:
396 | thread.join()
397 | log.info('Learning finished after %d frames.', frames)
398 |
399 | finally:
400 | for _ in range(flags.num_actors):
401 | free_queue.put(None)
402 | for actor in actor_processes:
403 | actor.join(timeout=1)
404 | checkpoint(frames)
405 | plogger.close()
406 |
407 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More_considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 | Section 1 -- Definitions.
71 |
72 | a. Adapted Material means material subject to Copyright and Similar
73 | Rights that is derived from or based upon the Licensed Material
74 | and in which the Licensed Material is translated, altered,
75 | arranged, transformed, or otherwise modified in a manner requiring
76 | permission under the Copyright and Similar Rights held by the
77 | Licensor. For purposes of this Public License, where the Licensed
78 | Material is a musical work, performance, or sound recording,
79 | Adapted Material is always produced where the Licensed Material is
80 | synched in timed relation with a moving image.
81 |
82 | b. Adapter's License means the license You apply to Your Copyright
83 | and Similar Rights in Your contributions to Adapted Material in
84 | accordance with the terms and conditions of this Public License.
85 |
86 | c. Copyright and Similar Rights means copyright and/or similar rights
87 | closely related to copyright including, without limitation,
88 | performance, broadcast, sound recording, and Sui Generis Database
89 | Rights, without regard to how the rights are labeled or
90 | categorized. For purposes of this Public License, the rights
91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
92 | Rights.
93 | d. Effective Technological Measures means those measures that, in the
94 | absence of proper authority, may not be circumvented under laws
95 | fulfilling obligations under Article 11 of the WIPO Copyright
96 | Treaty adopted on December 20, 1996, and/or similar international
97 | agreements.
98 |
99 | e. Exceptions and Limitations means fair use, fair dealing, and/or
100 | any other exception or limitation to Copyright and Similar Rights
101 | that applies to Your use of the Licensed Material.
102 |
103 | f. Licensed Material means the artistic or literary work, database,
104 | or other material to which the Licensor applied this Public
105 | License.
106 |
107 | g. Licensed Rights means the rights granted to You subject to the
108 | terms and conditions of this Public License, which are limited to
109 | all Copyright and Similar Rights that apply to Your use of the
110 | Licensed Material and that the Licensor has authority to license.
111 |
112 | h. Licensor means the individual(s) or entity(ies) granting rights
113 | under this Public License.
114 |
115 | i. NonCommercial means not primarily intended for or directed towards
116 | commercial advantage or monetary compensation. For purposes of
117 | this Public License, the exchange of the Licensed Material for
118 | other material subject to Copyright and Similar Rights by digital
119 | file-sharing or similar means is NonCommercial provided there is
120 | no payment of monetary compensation in connection with the
121 | exchange.
122 |
123 | j. Share means to provide material to the public by any means or
124 | process that requires permission under the Licensed Rights, such
125 | as reproduction, public display, public performance, distribution,
126 | dissemination, communication, or importation, and to make material
127 | available to the public including in ways that members of the
128 | public may access the material from a place and at a time
129 | individually chosen by them.
130 |
131 | k. Sui Generis Database Rights means rights other than copyright
132 | resulting from Directive 96/9/EC of the European Parliament and of
133 | the Council of 11 March 1996 on the legal protection of databases,
134 | as amended and/or succeeded, as well as other essentially
135 | equivalent rights anywhere in the world.
136 |
137 | l. You means the individual or entity exercising the Licensed Rights
138 | under this Public License. Your has a corresponding meaning.
139 |
140 | Section 2 -- Scope.
141 |
142 | a. License grant.
143 |
144 | 1. Subject to the terms and conditions of this Public License,
145 | the Licensor hereby grants You a worldwide, royalty-free,
146 | non-sublicensable, non-exclusive, irrevocable license to
147 | exercise the Licensed Rights in the Licensed Material to:
148 |
149 | a. reproduce and Share the Licensed Material, in whole or
150 | in part, for NonCommercial purposes only; and
151 |
152 | b. produce, reproduce, and Share Adapted Material for
153 | NonCommercial purposes only.
154 |
155 | 2. Exceptions and Limitations. For the avoidance of doubt, where
156 | Exceptions and Limitations apply to Your use, this Public
157 | License does not apply, and You do not need to comply with
158 | its terms and conditions.
159 |
160 | 3. Term. The term of this Public License is specified in Section
161 | 6(a).
162 |
163 | 4. Media and formats; technical modifications allowed. The
164 | Licensor authorizes You to exercise the Licensed Rights in
165 | all media and formats whether now known or hereafter created,
166 | and to make technical modifications necessary to do so. The
167 | Licensor waives and/or agrees not to assert any right or
168 | authority to forbid You from making technical modifications
169 | necessary to exercise the Licensed Rights, including
170 | technical modifications necessary to circumvent Effective
171 | Technological Measures. For purposes of this Public License,
172 | simply making modifications authorized by this Section 2(a)
173 | (4) never produces Adapted Material.
174 |
175 | 5. Downstream recipients.
176 |
177 | a. Offer from the Licensor -- Licensed Material. Every
178 | recipient of the Licensed Material automatically
179 | receives an offer from the Licensor to exercise the
180 | Licensed Rights under the terms and conditions of this
181 | Public License.
182 |
183 | b. No downstream restrictions. You may not offer or impose
184 | any additional or different terms or conditions on, or
185 | apply any Effective Technological Measures to, the
186 | Licensed Material if doing so restricts exercise of the
187 | Licensed Rights by any recipient of the Licensed
188 | Material.
189 |
190 | 6. No endorsement. Nothing in this Public License constitutes or
191 | may be construed as permission to assert or imply that You
192 | are, or that Your use of the Licensed Material is, connected
193 | with, or sponsored, endorsed, or granted official status by,
194 | the Licensor or others designated to receive attribution as
195 | provided in Section 3(a)(1)(A)(i).
196 |
197 | b. Other rights.
198 |
199 | 1. Moral rights, such as the right of integrity, are not
200 | licensed under this Public License, nor are publicity,
201 | privacy, and/or other similar personality rights; however, to
202 | the extent possible, the Licensor waives and/or agrees not to
203 | assert any such rights held by the Licensor to the limited
204 | extent necessary to allow You to exercise the Licensed
205 | Rights, but not otherwise.
206 |
207 | 2. Patent and trademark rights are not licensed under this
208 | Public License.
209 |
210 | 3. To the extent possible, the Licensor waives any right to
211 | collect royalties from You for the exercise of the Licensed
212 | Rights, whether directly or through a collecting society
213 | under any voluntary or waivable statutory or compulsory
214 | licensing scheme. In all other cases the Licensor expressly
215 | reserves any right to collect such royalties, including when
216 | the Licensed Material is used other than for NonCommercial
217 | purposes.
218 |
219 | Section 3 -- License Conditions.
220 |
221 | Your exercise of the Licensed Rights is expressly made subject to the
222 | following conditions.
223 |
224 | a. Attribution.
225 |
226 | 1. If You Share the Licensed Material (including in modified
227 | form), You must:
228 |
229 | a. retain the following if it is supplied by the Licensor
230 | with the Licensed Material:
231 |
232 | i. identification of the creator(s) of the Licensed
233 | Material and any others designated to receive
234 | attribution, in any reasonable manner requested by
235 | the Licensor (including by pseudonym if
236 | designated);
237 |
238 | ii. a copyright notice;
239 |
240 | iii. a notice that refers to this Public License;
241 |
242 | iv. a notice that refers to the disclaimer of
243 | warranties;
244 |
245 | v. a URI or hyperlink to the Licensed Material to the
246 | extent reasonably practicable;
247 |
248 | b. indicate if You modified the Licensed Material and
249 | retain an indication of any previous modifications; and
250 |
251 | c. indicate the Licensed Material is licensed under this
252 | Public License, and include the text of, or the URI or
253 | hyperlink to, this Public License.
254 |
255 | 2. You may satisfy the conditions in Section 3(a)(1) in any
256 | reasonable manner based on the medium, means, and context in
257 | which You Share the Licensed Material. For example, it may be
258 | reasonable to satisfy the conditions by providing a URI or
259 | hyperlink to a resource that includes the required
260 | information.
261 |
262 | 3. If requested by the Licensor, You must remove any of the
263 | information required by Section 3(a)(1)(A) to the extent
264 | reasonably practicable.
265 |
266 | 4. If You Share Adapted Material You produce, the Adapter's
267 | License You apply must not prevent recipients of the Adapted
268 | Material from complying with this Public License.
269 |
270 | Section 4 -- Sui Generis Database Rights.
271 |
272 | Where the Licensed Rights include Sui Generis Database Rights that
273 | apply to Your use of the Licensed Material:
274 |
275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276 | to extract, reuse, reproduce, and Share all or a substantial
277 | portion of the contents of the database for NonCommercial purposes
278 | only;
279 |
280 | b. if You include all or a substantial portion of the database
281 | contents in a database in which You have Sui Generis Database
282 | Rights, then the database in which You have Sui Generis Database
283 | Rights (but not its individual contents) is Adapted Material; and
284 |
285 | c. You must comply with the conditions in Section 3(a) if You Share
286 | all or a substantial portion of the contents of the database.
287 |
288 | For the avoidance of doubt, this Section 4 supplements and does not
289 | replace Your obligations under this Public License where the Licensed
290 | Rights include other Copyright and Similar Rights.
291 |
292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293 |
294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304 |
305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314 |
315 | c. The disclaimer of warranties and limitation of liability provided
316 | above shall be interpreted in a manner that, to the extent
317 | possible, most closely approximates an absolute disclaimer and
318 | waiver of all liability.
319 |
320 | Section 6 -- Term and Termination.
321 |
322 | a. This Public License applies for the term of the Copyright and
323 | Similar Rights licensed here. However, if You fail to comply with
324 | this Public License, then Your rights under this Public License
325 | terminate automatically.
326 |
327 | b. Where Your right to use the Licensed Material has terminated under
328 | Section 6(a), it reinstates:
329 |
330 | 1. automatically as of the date the violation is cured, provided
331 | it is cured within 30 days of Your discovery of the
332 | violation; or
333 |
334 | 2. upon express reinstatement by the Licensor.
335 |
336 | For the avoidance of doubt, this Section 6(b) does not affect any
337 | right the Licensor may have to seek remedies for Your violations
338 | of this Public License.
339 |
340 | c. For the avoidance of doubt, the Licensor may also offer the
341 | Licensed Material under separate terms or conditions or stop
342 | distributing the Licensed Material at any time; however, doing so
343 | will not terminate this Public License.
344 |
345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346 | License.
347 |
348 | Section 7 -- Other Terms and Conditions.
349 |
350 | a. The Licensor shall not be bound by any additional or different
351 | terms or conditions communicated by You unless expressly agreed.
352 |
353 | b. Any arrangements, understandings, or agreements regarding the
354 | Licensed Material not stated herein are separate from and
355 | independent of the terms and conditions of this Public License.
356 |
357 | Section 8 -- Interpretation.
358 |
359 | a. For the avoidance of doubt, this Public License does not, and
360 | shall not be interpreted to, reduce, limit, restrict, or impose
361 | conditions on any use of the Licensed Material that could lawfully
362 | be made without permission under this Public License.
363 |
364 | b. To the extent possible, if any provision of this Public License is
365 | deemed unenforceable, it shall be automatically reformed to the
366 | minimum extent necessary to make it enforceable. If the provision
367 | cannot be reformed, it shall be severed from this Public License
368 | without affecting the enforceability of the remaining terms and
369 | conditions.
370 |
371 | c. No term or condition of this Public License will be waived and no
372 | failure to comply consented to unless expressly agreed to by the
373 | Licensor.
374 |
375 | d. Nothing in this Public License constitutes or may be interpreted
376 | as a limitation upon, or waiver of, any privileges and immunities
377 | that apply to the Licensor or You, including from the legal
378 | processes of any jurisdiction or authority.
379 |
380 | =======================================================================
381 |
382 | Creative Commons is not a party to its public
383 | licenses. Notwithstanding, Creative Commons may elect to apply one of
384 | its public licenses to material it publishes and in those instances
385 | will be considered the “Licensor.” The text of the Creative Commons
386 | public licenses is dedicated to the public domain under the CC0 Public
387 | Domain Dedication. Except for the limited purpose of indicating that
388 | material is shared under a Creative Commons public license or as
389 | otherwise permitted by the Creative Commons policies published at
390 | creativecommons.org/policies, Creative Commons does not authorize the
391 | use of the trademark "Creative Commons" or any other trademark or logo
392 | of Creative Commons without its prior written consent including,
393 | without limitation, in connection with any unauthorized modifications
394 | to any of its public licenses or any other arrangements,
395 | understandings, or agreements concerning use of licensed material. For
396 | the avoidance of doubt, this paragraph does not form part of the
397 | public licenses.
398 |
399 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------