├── src ├── __init__.py ├── algos │ ├── __init__.py │ ├── torchbeast.py │ ├── count.py │ ├── rnd.py │ ├── bebold.py │ ├── curiosity.py │ └── ride.py ├── core │ ├── __init__.py │ ├── prof.py │ ├── environment.py │ ├── vtrace.py │ ├── file_writer.py │ └── vtrace_test.py ├── losses.py ├── env_utils.py ├── arguments.py ├── models.py ├── atari_wrappers.py └── utils.py ├── requirements.txt ├── main.py ├── README.md └── LICENSE /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /src/algos/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appdirs==1.4.4 2 | certifi==2019.11.28 3 | cffi==1.13.2 4 | cfgv==3.1.0 5 | cloudpickle==1.2.2 6 | distlib==0.3.0 7 | filelock==3.0.12 8 | future==0.18.2 9 | gitdb2==2.0.6 10 | GitPython==3.0.5 11 | gym==0.15.4 12 | -e git+https://github.com/maximecb/gym-minigrid.git@b84d99a2fbb481977172a85661eee812f022f130#egg=gym_minigrid 13 | gym-super-mario-bros==7.3.0 14 | identify==1.4.19 15 | importlib-metadata==1.6.1 16 | importlib-resources==2.0.0 17 | nes-py==8.1.1 18 | nodeenv==1.4.0 19 | numpy==1.17.4 20 | olefile==0.46 21 | opencv-python==4.1.2.30 22 | Pillow==8.1.1 23 | pre-commit==2.5.1 24 | pycparser==2.19 25 | pyglet==1.3.2 26 | PyQt5==5.13.2 27 | PyQt5-sip==12.7.0 28 | PyYAML==5.4 29 | scipy==1.3.3 30 | six==1.13.0 31 | smmap2==2.0.5 32 | toml==0.10.1 33 | torch==1.1.0 34 | torchvision==0.3.0 35 | tqdm==4.40.2 36 | virtualenv==20.0.21 37 | vizdoom==1.1.8 38 | zipp==3.1.0 39 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from src.arguments import parser 8 | 9 | from src.algos.torchbeast import train as train_vanilla 10 | from src.algos.count import train as train_count 11 | from src.algos.curiosity import train as train_curiosity 12 | from src.algos.rnd import train as train_rnd 13 | from src.algos.ride import train as train_ride 14 | from src.algos.bebold import train as train_bebold 15 | 16 | def main(flags): 17 | print(flags) 18 | if flags.model == 'vanilla': 19 | train_vanilla(flags) 20 | elif flags.model == 'count': 21 | train_count(flags) 22 | elif flags.model == 'curiosity': 23 | train_curiosity(flags) 24 | elif flags.model == 'rnd': 25 | train_rnd(flags) 26 | elif flags.model == 'ride': 27 | train_ride(flags) 28 | elif flags.model == 'bebold': 29 | train_bebold(flags) 30 | else: 31 | raise NotImplementedError("This model has not been implemented. "\ 32 | "The available options are: vanilla, count, curiosity, rnd, ride, \ 33 | no-episodic-counts, and only-episodic-count.") 34 | 35 | if __name__ == '__main__': 36 | flags = parser.parse_args() 37 | main(flags) 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NovelD: A Simple yet Effective Exploration Criterion 2 | 3 | ## Intro 4 | 5 | This is an implementation of the method proposed in 6 | 7 | NovelD: A Simple yet Effective Exploration Criterion and BeBold: Exploration Beyond the Boundary of Explored Regions 8 | 9 | ## Citation 10 | If you use this code in your own work, please cite our paper: 11 | ``` 12 | @article{zhang2021noveld, 13 | title={NovelD: A Simple yet Effective Exploration Criterion}, 14 | author={Zhang, Tianjun and Xu, Huazhe and Wang, Xiaolong and Wu, Yi and Keutzer, Kurt and Gonzalez, Joseph E and Tian, Yuandong}, 15 | journal={Advances in Neural Information Processing Systems}, 16 | volume={34}, 17 | year={2021} 18 | } 19 | ``` 20 | or 21 | ``` 22 | @article{zhang2020bebold, 23 | title={BeBold: Exploration Beyond the Boundary of Explored Regions}, 24 | author={Zhang, Tianjun and Xu, Huazhe and Wang, Xiaolong and Wu, Yi and Keutzer, Kurt and Gonzalez, Joseph E and Tian, Yuandong}, 25 | journal={arXiv preprint arXiv:2012.08621}, 26 | year={2020} 27 | } 28 | ``` 29 | 30 | ## Installation 31 | 32 | ``` 33 | # Install Instructions 34 | conda create -n ride python=3.7 35 | conda activate noveld 36 | git clone git@github.com:tianjunz/NovelD.git 37 | cd NovelD 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | ## Train NovelD on MiniGrid 42 | ``` 43 | OMP_NUM_THREADS=1 python main.py --model bebold --env MiniGrid-ObstructedMaze-2Dlhb-v0 --total_frames 500000000 --intrinsic_reward_coef 0.05 --entropy_cost 0.0005 44 | ``` 45 | 46 | ## Acknowledgements 47 | Our vanilla RL algorithm is based on [RIDE](https://github.com/facebookresearch/impact-driven-exploration). 48 | 49 | ## License 50 | This code is under the CC-BY-NC 4.0 (Attribution-NonCommercial 4.0 International) license. 51 | -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | import numpy as np 11 | 12 | 13 | def compute_baseline_loss(advantages): 14 | return 0.5 * torch.sum(torch.mean(advantages**2, dim=1)) 15 | 16 | def compute_entropy_loss(logits): 17 | policy = F.softmax(logits, dim=-1) 18 | log_policy = F.log_softmax(logits, dim=-1) 19 | entropy_per_timestep = torch.sum(-policy * log_policy, dim=-1) 20 | return -torch.sum(torch.mean(entropy_per_timestep, dim=1)) 21 | 22 | def compute_policy_gradient_loss(logits, actions, advantages): 23 | cross_entropy = F.nll_loss( 24 | F.log_softmax(torch.flatten(logits, 0, 1), dim=-1), 25 | target=torch.flatten(actions, 0, 1), 26 | reduction='none') 27 | cross_entropy = cross_entropy.view_as(advantages) 28 | advantages.requires_grad = False 29 | policy_gradient_loss_per_timestep = cross_entropy * advantages 30 | return torch.sum(torch.mean(policy_gradient_loss_per_timestep, dim=1)) 31 | 32 | 33 | def compute_forward_dynamics_loss(pred_next_emb, next_emb): 34 | forward_dynamics_loss = torch.norm(pred_next_emb - next_emb, dim=2, p=2) 35 | return torch.sum(torch.mean(forward_dynamics_loss, dim=1)) 36 | 37 | 38 | def compute_inverse_dynamics_loss(pred_actions, true_actions): 39 | inverse_dynamics_loss = F.nll_loss( 40 | F.log_softmax(torch.flatten(pred_actions, 0, 1), dim=-1), 41 | target=torch.flatten(true_actions, 0, 1), 42 | reduction='none') 43 | inverse_dynamics_loss = inverse_dynamics_loss.view_as(true_actions) 44 | return torch.sum(torch.mean(inverse_dynamics_loss, dim=1)) 45 | 46 | def compute_rnd_loss(pred_next_emb, next_emb): 47 | forward_dynamics_loss = torch.norm(pred_next_emb - next_emb, dim=2, p=2) 48 | return torch.mean(forward_dynamics_loss) 49 | 50 | -------------------------------------------------------------------------------- /src/core/prof.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Naive profiling using timeit.""" 8 | 9 | import collections 10 | import timeit 11 | 12 | 13 | class Timings: 14 | """Not thread-safe.""" 15 | 16 | def __init__(self): 17 | self._means = collections.defaultdict(int) 18 | self._vars = collections.defaultdict(int) 19 | self._counts = collections.defaultdict(int) 20 | self.reset() 21 | 22 | def reset(self): 23 | self.last_time = timeit.default_timer() 24 | 25 | def time(self, name): 26 | """Save an update for event `name`. 27 | 28 | Nerd alarm: We could just store a 29 | collections.defaultdict(list) 30 | and compute means and standard deviations at the end. But thanks to the 31 | clever math in Sutton-Barto 32 | (http://www.incompleteideas.net/book/first/ebook/node19.html) and 33 | https://math.stackexchange.com/a/103025/5051 we can update both the 34 | means and the stds online. O(1) FTW! 35 | """ 36 | now = timeit.default_timer() 37 | x = now - self.last_time 38 | self.last_time = now 39 | 40 | n = self._counts[name] 41 | 42 | mean = self._means[name] + (x - self._means[name]) / (n + 1) 43 | var = (n * self._vars[name] + n * (self._means[name] - mean)**2 + 44 | (x - mean)**2) / (n + 1) 45 | 46 | self._means[name] = mean 47 | self._vars[name] = var 48 | self._counts[name] += 1 49 | 50 | def means(self): 51 | return self._means 52 | 53 | def vars(self): 54 | return self._vars 55 | 56 | def stds(self): 57 | return {k: v**0.5 for k, v in self._vars.items()} 58 | 59 | def summary(self, prefix=''): 60 | means = self.means() 61 | stds = self.stds() 62 | total = sum(means.values()) 63 | 64 | result = prefix 65 | for k in sorted(means, key=means.get, reverse=True): 66 | result += f'\n %s: %.6fms +- %.6fms (%.2f%%) ' % ( 67 | k, 1000 * means[k], 1000 * stds[k], 100 * means[k] / total) 68 | result += '\nTotal: %.6fms' % (1000 * total) 69 | return result 70 | -------------------------------------------------------------------------------- /src/core/environment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | 10 | def _format_frame(frame): 11 | frame = torch.from_numpy(frame) 12 | return frame.view((1, 1) + frame.shape) # (...) -> (T,B,...). 13 | 14 | 15 | class Environment: 16 | 17 | def __init__(self, gym_env, fix_seed=False): 18 | self.gym_env = gym_env 19 | self.episode_return = None 20 | self.episode_step = None 21 | self.fix_seed = fix_seed 22 | 23 | def initial(self): 24 | initial_reward = torch.zeros(1, 1) 25 | # This supports only single-tensor actions ATM. 26 | initial_last_action = torch.zeros(1, 1, dtype=torch.int64) 27 | self.episode_return = torch.zeros(1, 1) 28 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32) 29 | initial_done = torch.zeros(1, 1, dtype=torch.uint8) 30 | if self.fix_seed: 31 | self.gym_env.seed(seed=1) 32 | initial_frame = _format_frame(self.gym_env.reset()) 33 | return dict( 34 | frame=initial_frame, 35 | reward=initial_reward, 36 | done=initial_done, 37 | episode_return=self.episode_return, 38 | episode_step=self.episode_step, 39 | last_action=initial_last_action) 40 | 41 | 42 | def step(self, action): 43 | frame, reward, done, unused_info = self.gym_env.step(action.item()) 44 | self.episode_step += 1 45 | self.episode_return += reward 46 | episode_step = self.episode_step 47 | episode_return = self.episode_return 48 | if done: 49 | if self.fix_seed: 50 | self.gym_env.seed(seed=1) 51 | frame = self.gym_env.reset() 52 | self.episode_return = torch.zeros(1, 1) 53 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32) 54 | 55 | frame = _format_frame(frame) 56 | reward = torch.tensor(reward).view(1, 1) 57 | done = torch.tensor(done).view(1, 1) 58 | 59 | return dict( 60 | frame=frame, 61 | reward=reward, 62 | done=done, 63 | episode_return=episode_return, 64 | episode_step=episode_step, 65 | last_action=action) 66 | 67 | 68 | def close(self): 69 | self.gym_env.close() -------------------------------------------------------------------------------- /src/core/vtrace.py: -------------------------------------------------------------------------------- 1 | # This file taken from 2 | # https://github.com/deepmind/scalable_agent/blob/ 3 | # cd66d00914d56c8ba2f0615d9cdeefcb169a8d70/vtrace.py 4 | # and modified. 5 | 6 | # Copyright 2018 Google LLC 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # https://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | """Functions to compute V-trace off-policy actor critic targets. 21 | 22 | For details and theory see: 23 | 24 | "IMPALA: Scalable Distributed Deep-RL with 25 | Importance Weighted Actor-Learner Architectures" 26 | by Espeholt, Soyer, Munos et al. 27 | 28 | See https://arxiv.org/abs/1802.01561 for the full paper. 29 | """ 30 | 31 | import collections 32 | 33 | import torch 34 | import torch.nn.functional as F 35 | 36 | VTraceFromLogitsReturns = collections.namedtuple('VTraceFromLogitsReturns', [ 37 | 'vs', 'pg_advantages', 'log_rhos', 'behavior_action_log_probs', 38 | 'target_action_log_probs' 39 | ]) 40 | 41 | VTraceReturns = collections.namedtuple('VTraceReturns', 'vs pg_advantages') 42 | 43 | 44 | def action_log_probs(policy_logits, actions): 45 | return -F.nll_loss( 46 | F.log_softmax(torch.flatten(policy_logits, 0, 1), dim=-1), 47 | torch.flatten(actions, 0, 1), 48 | reduction='none').view_as(actions) 49 | 50 | 51 | def from_logits(behavior_policy_logits, 52 | target_policy_logits, 53 | actions, 54 | discounts, 55 | rewards, 56 | values, 57 | bootstrap_value, 58 | clip_rho_threshold=1.0, 59 | clip_pg_rho_threshold=1.0): 60 | """V-trace for softmax policies.""" 61 | 62 | target_action_log_probs = action_log_probs(target_policy_logits, actions) 63 | behavior_action_log_probs = action_log_probs(behavior_policy_logits, 64 | actions) 65 | log_rhos = target_action_log_probs - behavior_action_log_probs 66 | vtrace_returns = from_importance_weights( 67 | log_rhos=log_rhos, 68 | discounts=discounts, 69 | rewards=rewards, 70 | values=values, 71 | bootstrap_value=bootstrap_value, 72 | clip_rho_threshold=clip_rho_threshold, 73 | clip_pg_rho_threshold=clip_pg_rho_threshold) 74 | return VTraceFromLogitsReturns( 75 | log_rhos=log_rhos, 76 | behavior_action_log_probs=behavior_action_log_probs, 77 | target_action_log_probs=target_action_log_probs, 78 | **vtrace_returns._asdict()) 79 | 80 | 81 | @torch.no_grad() 82 | def from_importance_weights(log_rhos, 83 | discounts, 84 | rewards, 85 | values, 86 | bootstrap_value, 87 | clip_rho_threshold=1.0, 88 | clip_pg_rho_threshold=1.0): 89 | """V-trace from log importance weights.""" 90 | with torch.no_grad(): 91 | rhos = torch.exp(log_rhos) 92 | if clip_rho_threshold is not None: 93 | clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold) 94 | else: 95 | clipped_rhos = rhos 96 | 97 | cs = torch.clamp(rhos, max=1.0) 98 | # Append bootstrapped value to get [v1, ..., v_t+1] 99 | values_t_plus_1 = torch.cat( 100 | [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0) 101 | deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) 102 | 103 | acc = torch.zeros_like(bootstrap_value) 104 | result = [] 105 | for t in range(discounts.shape[0] - 1, -1, -1): 106 | acc = deltas[t] + discounts[t] * cs[t] * acc 107 | result.append(acc) 108 | result.reverse() 109 | vs_minus_v_xs = torch.stack(result) 110 | 111 | # Add V(x_s) to get v_s. 112 | vs = torch.add(vs_minus_v_xs, values) 113 | 114 | # Advantage for policy gradient. 115 | vs_t_plus_1 = torch.cat( 116 | [vs[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0) 117 | if clip_pg_rho_threshold is not None: 118 | clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold) 119 | else: 120 | clipped_pg_rhos = rhos 121 | pg_advantages = (clipped_pg_rhos * 122 | (rewards + discounts * vs_t_plus_1 - values)) 123 | 124 | # Make sure no gradients backpropagated through the returned values. 125 | return VTraceReturns(vs=vs, pg_advantages=pg_advantages) 126 | -------------------------------------------------------------------------------- /src/env_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import gym 8 | import torch 9 | from collections import deque, defaultdict 10 | from gym import spaces 11 | import numpy as np 12 | from gym_minigrid.minigrid import OBJECT_TO_IDX, COLOR_TO_IDX 13 | 14 | 15 | def _format_observation(obs): 16 | obs = torch.tensor(obs) 17 | return obs.view((1, 1) + obs.shape) 18 | 19 | 20 | class Minigrid2Image(gym.ObservationWrapper): 21 | def __init__(self, env): 22 | gym.ObservationWrapper.__init__(self, env) 23 | self.observation_space = env.observation_space.spaces['image'] 24 | 25 | def observation(self, observation): 26 | return observation['image'] 27 | 28 | 29 | class Environment: 30 | def __init__(self, gym_env, fix_seed=False, env_seed=1): 31 | self.gym_env = gym_env 32 | self.episode_return = None 33 | self.episode_step = None 34 | self.episode_win = None 35 | self.fix_seed = fix_seed 36 | self.env_seed = env_seed 37 | 38 | def get_partial_obs(self): 39 | return self.gym_env.env.env.gen_obs()['image'] 40 | 41 | def initial(self): 42 | initial_reward = torch.zeros(1, 1) 43 | self.episode_return = torch.zeros(1, 1) 44 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32) 45 | self.episode_win = torch.zeros(1, 1, dtype=torch.int32) 46 | initial_done = torch.ones(1, 1, dtype=torch.uint8) 47 | if self.fix_seed: 48 | self.gym_env.seed(seed=self.env_seed) 49 | initial_frame = _format_observation(self.gym_env.reset()) 50 | partial_obs = _format_observation(self.get_partial_obs()) 51 | 52 | if self.gym_env.env.env.carrying: 53 | carried_col, carried_obj = torch.LongTensor([[COLOR_TO_IDX[self.gym_env.env.env.carrying.color]]]), torch.LongTensor([[OBJECT_TO_IDX[self.gym_env.env.env.carrying.type]]]) 54 | else: 55 | carried_col, carried_obj = torch.LongTensor([[5]]), torch.LongTensor([[1]]) 56 | 57 | return dict( 58 | frame=initial_frame, 59 | reward=initial_reward, 60 | done=initial_done, 61 | episode_return=self.episode_return, 62 | episode_step=self.episode_step, 63 | episode_win=self.episode_win, 64 | carried_col = carried_col, 65 | carried_obj = carried_obj, 66 | partial_obs=partial_obs 67 | ) 68 | 69 | def step(self, action): 70 | frame, reward, done, _ = self.gym_env.step(action.item()) 71 | 72 | self.episode_step += 1 73 | episode_step = self.episode_step 74 | 75 | self.episode_return += reward 76 | episode_return = self.episode_return 77 | 78 | if done and reward > 0: 79 | self.episode_win[0][0] = 1 80 | else: 81 | self.episode_win[0][0] = 0 82 | episode_win = self.episode_win 83 | 84 | if done: 85 | if self.fix_seed: 86 | self.gym_env.seed(seed=self.env_seed) 87 | frame = self.gym_env.reset() 88 | self.episode_return = torch.zeros(1, 1) 89 | self.episode_step = torch.zeros(1, 1, dtype=torch.int32) 90 | self.episode_win = torch.zeros(1, 1, dtype=torch.int32) 91 | 92 | frame = _format_observation(frame) 93 | reward = torch.tensor(reward).view(1, 1) 94 | done = torch.tensor(done).view(1, 1) 95 | partial_obs = _format_observation(self.get_partial_obs()) 96 | 97 | if self.gym_env.env.env.carrying: 98 | carried_col, carried_obj = torch.LongTensor([[COLOR_TO_IDX[self.gym_env.env.env.carrying.color]]]), torch.LongTensor([[OBJECT_TO_IDX[self.gym_env.env.env.carrying.type]]]) 99 | else: 100 | carried_col, carried_obj = torch.LongTensor([[5]]), torch.LongTensor([[1]]) 101 | 102 | 103 | return dict( 104 | frame=frame, 105 | reward=reward, 106 | done=done, 107 | episode_return=episode_return, 108 | episode_step = episode_step, 109 | episode_win = episode_win, 110 | carried_col = carried_col, 111 | carried_obj = carried_obj, 112 | partial_obs=partial_obs 113 | ) 114 | 115 | def get_full_obs(self): 116 | env = self.gym_env.unwrapped 117 | full_grid = env.grid.encode() 118 | full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array([ 119 | OBJECT_TO_IDX['agent'], 120 | COLOR_TO_IDX['red'], 121 | env.agent_dir 122 | ]) 123 | return full_grid 124 | 125 | def close(self): 126 | self.gym_env.close() 127 | 128 | 129 | class FrameStack(gym.Wrapper): 130 | def __init__(self, env, k): 131 | """Stack k last frames. 132 | Returns lazy array, which is much more memory efficient. 133 | See Also 134 | -------- 135 | baselines.common.atari_wrappers.LazyFrames 136 | """ 137 | gym.Wrapper.__init__(self, env) 138 | self.k = k 139 | self.frames = deque([], maxlen=k) 140 | shp = env.observation_space.shape 141 | self.observation_space = spaces.Box( 142 | low=0, 143 | high=255, 144 | shape=(shp[:-1] + (shp[-1] * k,)), 145 | dtype=env.observation_space.dtype) 146 | 147 | def reset(self): 148 | ob = self.env.reset() 149 | for _ in range(self.k): 150 | self.frames.append(ob) 151 | return self._get_ob() 152 | 153 | def step(self, action): 154 | ob, reward, done, info = self.env.step(action) 155 | self.frames.append(ob) 156 | return self._get_ob(), reward, done, info 157 | 158 | def _get_ob(self): 159 | assert len(self.frames) == self.k 160 | return LazyFrames(list(self.frames)) 161 | 162 | 163 | class LazyFrames(object): 164 | def __init__(self, frames): 165 | """This object ensures that common frames between the observations are only stored once. 166 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 167 | buffers. 168 | This object should only be converted to numpy array before being passed to the model. 169 | You'd not believe how complex the previous solution was.""" 170 | self._frames = frames 171 | self._out = None 172 | 173 | def _force(self): 174 | if self._out is None: 175 | self._out = np.concatenate(self._frames, axis=-1) 176 | self._frames = None 177 | return self._out 178 | 179 | def __array__(self, dtype=None): 180 | out = self._force() 181 | if dtype is not None: 182 | out = out.astype(dtype) 183 | return out 184 | 185 | def __len__(self): 186 | return len(self._force()) 187 | 188 | def __getitem__(self, i): 189 | return self._force()[i] 190 | 191 | 192 | -------------------------------------------------------------------------------- /src/core/file_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import copy 8 | import datetime 9 | import csv 10 | import json 11 | import logging 12 | import os 13 | import time 14 | from typing import Dict 15 | 16 | import git 17 | 18 | 19 | def gather_metadata() -> Dict: 20 | date_start = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') 21 | # gathering git metadata 22 | try: 23 | repo = git.Repo(search_parent_directories=True) 24 | git_sha = repo.commit().hexsha 25 | git_data = dict( 26 | commit=git_sha, 27 | branch=repo.active_branch.name, 28 | is_dirty=repo.is_dirty(), 29 | path=repo.git_dir, 30 | ) 31 | except git.InvalidGitRepositoryError: 32 | git_data = None 33 | # gathering slurm metadata 34 | if 'SLURM_JOB_ID' in os.environ: 35 | slurm_env_keys = [k for k in os.environ if k.startswith('SLURM')] 36 | slurm_data = {} 37 | for k in slurm_env_keys: 38 | d_key = k.replace('SLURM_', '').replace('SLURMD_', '').lower() 39 | slurm_data[d_key] = os.environ[k] 40 | else: 41 | slurm_data = None 42 | return dict( 43 | date_start=date_start, 44 | date_end=None, 45 | successful=False, 46 | git=git_data, 47 | slurm=slurm_data, 48 | env=os.environ.copy(), 49 | ) 50 | 51 | 52 | class FileWriter: 53 | def __init__(self, 54 | xpid: str = None, 55 | xp_args: dict = None, 56 | rootdir: str = '~/palaas'): 57 | if not xpid: 58 | # make unique id 59 | xpid = '{proc}_{unixtime}'.format( 60 | proc=os.getpid(), unixtime=int(time.time())) 61 | self.xpid = xpid 62 | self._tick = 0 63 | 64 | # metadata gathering 65 | if xp_args is None: 66 | xp_args = {} 67 | self.metadata = gather_metadata() 68 | # we need to copy the args, otherwise when we close the file writer 69 | # (and rewrite the args) we might have non-serializable objects (or 70 | # other nasty stuff). 71 | self.metadata['args'] = copy.deepcopy(xp_args) 72 | self.metadata['xpid'] = self.xpid 73 | 74 | formatter = logging.Formatter('%(message)s') 75 | self._logger = logging.getLogger('palaas/out') 76 | 77 | # to stdout handler 78 | shandle = logging.StreamHandler() 79 | shandle.setFormatter(formatter) 80 | self._logger.addHandler(shandle) 81 | self._logger.setLevel(logging.INFO) 82 | 83 | rootdir = os.path.expandvars(os.path.expanduser(rootdir)) 84 | # to file handler 85 | self.basepath = os.path.join(rootdir, self.xpid) 86 | 87 | if not os.path.exists(self.basepath): 88 | self._logger.info('Creating log directory: %s', self.basepath) 89 | os.makedirs(self.basepath, exist_ok=True) 90 | else: 91 | self._logger.info('Found log directory: %s', self.basepath) 92 | 93 | # NOTE: remove latest because it creates errors when running on slurm 94 | # multiple jobs trying to write to latest but cannot find it 95 | # Add 'latest' as symlink unless it exists and is no symlink. 96 | # symlink = os.path.join(rootdir, 'latest') 97 | # if os.path.islink(symlink): 98 | # os.remove(symlink) 99 | # if not os.path.exists(symlink): 100 | # os.symlink(self.basepath, symlink) 101 | # self._logger.info('Symlinked log directory: %s', symlink) 102 | 103 | self.paths = dict( 104 | msg='{base}/out.log'.format(base=self.basepath), 105 | logs='{base}/logs.csv'.format(base=self.basepath), 106 | fields='{base}/fields.csv'.format(base=self.basepath), 107 | meta='{base}/meta.json'.format(base=self.basepath), 108 | ) 109 | 110 | self._logger.info('Saving arguments to %s', self.paths['meta']) 111 | if os.path.exists(self.paths['meta']): 112 | self._logger.warning('Path to meta file already exists. ' 113 | 'Not overriding meta.') 114 | else: 115 | self._save_metadata() 116 | 117 | self._logger.info('Saving messages to %s', self.paths['msg']) 118 | if os.path.exists(self.paths['msg']): 119 | self._logger.warning('Path to message file already exists. ' 120 | 'New data will be appended.') 121 | 122 | fhandle = logging.FileHandler(self.paths['msg']) 123 | fhandle.setFormatter(formatter) 124 | self._logger.addHandler(fhandle) 125 | 126 | self._logger.info('Saving logs data to %s', self.paths['logs']) 127 | self._logger.info('Saving logs\' fields to %s', self.paths['fields']) 128 | if os.path.exists(self.paths['logs']): 129 | self._logger.warning('Path to log file already exists. ' 130 | 'New data will be appended.') 131 | with open(self.paths['fields'], 'r') as csvfile: 132 | reader = csv.reader(csvfile) 133 | self.fieldnames = list(reader)[0] 134 | else: 135 | self.fieldnames = ['_tick', '_time'] 136 | 137 | def log(self, to_log: Dict, tick: int = None, 138 | verbose: bool = False) -> None: 139 | if tick is not None: 140 | raise NotImplementedError 141 | else: 142 | to_log['_tick'] = self._tick 143 | self._tick += 1 144 | to_log['_time'] = time.time() 145 | 146 | old_len = len(self.fieldnames) 147 | for k in to_log: 148 | if k not in self.fieldnames: 149 | self.fieldnames.append(k) 150 | if old_len != len(self.fieldnames): 151 | with open(self.paths['fields'], 'w') as csvfile: 152 | writer = csv.writer(csvfile) 153 | writer.writerow(self.fieldnames) 154 | self._logger.info('Updated log fields: %s', self.fieldnames) 155 | 156 | if to_log['_tick'] == 0: 157 | # print("\ncreating logs file ") 158 | with open(self.paths['logs'], 'a') as f: 159 | f.write('# %s\n' % ','.join(self.fieldnames)) 160 | 161 | if verbose: 162 | self._logger.info('LOG | %s', ', '.join( 163 | ['{}: {}'.format(k, to_log[k]) for k in sorted(to_log)])) 164 | 165 | with open(self.paths['logs'], 'a') as f: 166 | writer = csv.DictWriter(f, fieldnames=self.fieldnames) 167 | writer.writerow(to_log) 168 | # print("\nadded to log file") 169 | 170 | def close(self, successful: bool = True) -> None: 171 | self.metadata['date_end'] = datetime.datetime.now().strftime( 172 | '%Y-%m-%d %H:%M:%S.%f') 173 | self.metadata['successful'] = successful 174 | self._save_metadata() 175 | 176 | def _save_metadata(self) -> None: 177 | with open(self.paths['meta'], 'w') as jsonfile: 178 | json.dump(self.metadata, jsonfile, indent=4, sort_keys=True) 179 | -------------------------------------------------------------------------------- /src/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser(description='PyTorch Scalable Agent') 10 | 11 | # General Settings. 12 | parser.add_argument('--env', type=str, default='MiniGrid-ObstructedMaze-2Dlh-v0', 13 | help='Gym environment. Other options are: SuperMarioBros-1-1-v0 \ 14 | or VizdoomMyWayHomeDense-v0 etc.') 15 | parser.add_argument('--xpid', default=None, 16 | help='Experiment id (default: None).') 17 | parser.add_argument('--num_input_frames', default=1, type=int, 18 | help='Number of input frames to the model and state embedding including the current frame \ 19 | When num_input_frames > 1, it will also take the previous num_input_frames - 1 frames as input.') 20 | parser.add_argument('--run_id', default=0, type=int, 21 | help='Run id used for running multiple instances of the same HP set \ 22 | (instead of a different random seed since torchbeast does not accept this).') 23 | parser.add_argument('--seed', default=0, type=int, 24 | help='Environment seed.') 25 | parser.add_argument('--save_interval', default=10, type=int, metavar='N', 26 | help='Time interval (in minutes) at which to save the model.') 27 | parser.add_argument('--checkpoint_num_frames', default=10000000, type=int, 28 | help='Number of frames for checkpoint to load.') 29 | 30 | # Training settings. 31 | parser.add_argument('--disable_checkpoint', action='store_true', 32 | help='Disable saving checkpoint.') 33 | parser.add_argument('--savedir', default='./experiments/', 34 | help='Root dir where experiment data will be saved.') 35 | parser.add_argument('--num_actors', default=40, type=int, metavar='N', 36 | help='Number of actors.') 37 | parser.add_argument('--total_frames', default=500000000, type=int, metavar='T', 38 | help='Total environment frames to train for.') 39 | parser.add_argument('--batch_size', default=32, type=int, metavar='B', 40 | help='Learner batch size.') 41 | parser.add_argument('--unroll_length', default=100, type=int, metavar='T', 42 | help='The unroll length (time dimension).') 43 | parser.add_argument('--queue_timeout', default=1, type=int, 44 | metavar='S', help='Error timeout for queue.') 45 | parser.add_argument('--num_buffers', default=80, type=int, 46 | metavar='N', help='Number of shared-memory buffers.') 47 | parser.add_argument('--num_threads', default=4, type=int, 48 | metavar='N', help='Number learner threads.') 49 | parser.add_argument('--disable_cuda', action='store_true', 50 | help='Disable CUDA.') 51 | parser.add_argument('--max_grad_norm', default=40., type=float, 52 | metavar='MGN', help='Max norm of gradients.') 53 | 54 | # Loss settings. 55 | parser.add_argument('--entropy_cost', default=0.001, type=float, 56 | help='Entropy cost/multiplier.') 57 | parser.add_argument('--baseline_cost', default=0.5, type=float, 58 | help='Baseline cost/multiplier.') 59 | parser.add_argument('--discounting', default=0.99, type=float, 60 | help='Discounting factor.') 61 | 62 | # Optimizer settings. 63 | parser.add_argument('--learning_rate', default=0.0001, type=float, 64 | metavar='LR', help='Learning rate.') 65 | parser.add_argument('--predictor_learning_rate', default=0.0001, type=float, 66 | metavar='LR', help='Learning rate for RND predictor.') 67 | parser.add_argument('--alpha', default=0.99, type=float, 68 | help='RMSProp smoothing constant.') 69 | parser.add_argument('--momentum', default=0, type=float, 70 | help='RMSProp momentum.') 71 | parser.add_argument('--epsilon', default=1e-5, type=float, 72 | help='RMSProp epsilon.') 73 | 74 | # Exploration Settings. 75 | parser.add_argument('--forward_loss_coef', default=10.0, type=float, 76 | help='Coefficient for the forward dynamics loss. \ 77 | This weighs the inverse model loss agains the forward model loss. \ 78 | Should be between 0 and 1.') 79 | parser.add_argument('--inverse_loss_coef', default=0.1, type=float, 80 | help='Coefficient for the forward dynamics loss. \ 81 | This weighs the inverse model loss agains the forward model loss. \ 82 | Should be between 0 and 1.') 83 | parser.add_argument('--intrinsic_reward_coef', default=0.5, type=float, 84 | help='Coefficient for the intrinsic reward. \ 85 | This weighs the intrinsic reaward against the extrinsic one. \ 86 | Should be larger than 0.') 87 | parser.add_argument('--rnd_loss_coef', default=1.0, type=float, 88 | help='Coefficient for the RND loss coefficient relative to the IMPALA one.') 89 | 90 | # Singleton Environments. 91 | parser.add_argument('--fix_seed', action='store_true', 92 | help='Fix the environment seed so that it is \ 93 | no longer procedurally generated but rather the same layout every episode.') 94 | parser.add_argument('--env_seed', default=1, type=int, 95 | help='The seed used to generate the environment if we are using a \ 96 | singleton (i.e. not procedurally generated) environment.') 97 | parser.add_argument('--no_reward', action='store_true', 98 | help='No extrinsic reward. The agent uses only intrinsic reward to learn.') 99 | 100 | # Training Models. 101 | parser.add_argument('--model', default='vanilla', 102 | choices=['bebold'], 103 | help='Model used for training the agent.') 104 | 105 | # Baselines for AMIGo paper. 106 | parser.add_argument('--use_fullobs_policy', action='store_true', 107 | help='Use a full view of the environment as input to the policy network.') 108 | parser.add_argument('--use_fullobs_intrinsic', action='store_true', 109 | help='Use a full view of the environment for computing the intrinsic reward.') 110 | parser.add_argument('--target_update_freq', default=2, type=int, 111 | help='Number of time steps for updating target') 112 | parser.add_argument('--init_num_frames', default=1e6, type=int, 113 | help='Number of frames for updating teacher network') 114 | parser.add_argument('--planning_intrinsic_reward_coef', default=0.5, type=float, 115 | help='Coefficient for the planning intrinsic reward. \ 116 | This weighs the intrinsic reaward against the extrinsic one. \ 117 | Should be larger than 0.') 118 | parser.add_argument('--ema_momentum', default=1.0, type=float, 119 | help='Coefficient for the EMA update of the RND network') 120 | parser.add_argument('--use_lstm', action='store_true', 121 | help='Use a lstm version of policy network.') 122 | parser.add_argument('--use_lstm_intrinsic', action='store_true', 123 | help='Use a lstm version of intrinsic embedding network.') 124 | parser.add_argument('--state_embedding_dim', default=256, type=int, 125 | help='Embedding dimension of last layer of network') 126 | parser.add_argument('--scale_fac', default=0.5, type=float, 127 | help='coefficient for scaling visitation count difference') 128 | -------------------------------------------------------------------------------- /src/algos/torchbeast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import threading 10 | import time 11 | import timeit 12 | import pprint 13 | 14 | import numpy as np 15 | 16 | import torch 17 | from torch import multiprocessing as mp 18 | from torch import nn 19 | from torch.nn import functional as F 20 | 21 | from src.core import file_writer 22 | from src.core import prof 23 | from src.core import vtrace 24 | 25 | import src.models as models 26 | import src.losses as losses 27 | 28 | from src.env_utils import FrameStack 29 | from src.utils import get_batch, log, create_env, create_buffers, act 30 | 31 | 32 | MinigridPolicyNet = models.MinigridPolicyNet 33 | # MarioDoomPolicyNet = models.MarioDoomPolicyNet 34 | 35 | 36 | def learn(actor_model, 37 | model, 38 | batch, 39 | initial_agent_state, 40 | optimizer, 41 | scheduler, 42 | flags, 43 | lock=threading.Lock()): 44 | """Performs a learning (optimization) step.""" 45 | with lock: 46 | learner_outputs, unused_state = model(batch, initial_agent_state) 47 | 48 | bootstrap_value = learner_outputs['baseline'][-1] 49 | 50 | batch = {key: tensor[1:] for key, tensor in batch.items()} 51 | learner_outputs = { 52 | key: tensor[:-1] 53 | for key, tensor in learner_outputs.items() 54 | } 55 | 56 | rewards = batch['reward'] 57 | clipped_rewards = torch.clamp(rewards, -1, 1) 58 | 59 | discounts = (~batch['done']).float() * flags.discounting 60 | 61 | vtrace_returns = vtrace.from_logits( 62 | behavior_policy_logits=batch['policy_logits'], 63 | target_policy_logits=learner_outputs['policy_logits'], 64 | actions=batch['action'], 65 | discounts=discounts, 66 | rewards=clipped_rewards, 67 | values=learner_outputs['baseline'], 68 | bootstrap_value=bootstrap_value) 69 | 70 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'], 71 | batch['action'], 72 | vtrace_returns.pg_advantages) 73 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss( 74 | vtrace_returns.vs - learner_outputs['baseline']) 75 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss( 76 | learner_outputs['policy_logits']) 77 | 78 | total_loss = pg_loss + baseline_loss + entropy_loss 79 | 80 | episode_returns = batch['episode_return'][batch['done']] 81 | stats = { 82 | 'mean_episode_return': torch.mean(episode_returns).item(), 83 | 'total_loss': total_loss.item(), 84 | 'pg_loss': pg_loss.item(), 85 | 'baseline_loss': baseline_loss.item(), 86 | 'entropy_loss': entropy_loss.item(), 87 | } 88 | 89 | scheduler.step() 90 | optimizer.zero_grad() 91 | total_loss.backward() 92 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm) 93 | optimizer.step() 94 | 95 | actor_model.load_state_dict(model.state_dict()) 96 | return stats 97 | 98 | 99 | def train(flags): 100 | if flags.xpid is None: 101 | flags.xpid = 'torchbeast-%s' % time.strftime('%Y%m%d-%H%M%S') 102 | plogger = file_writer.FileWriter( 103 | xpid=flags.xpid, 104 | xp_args=flags.__dict__, 105 | rootdir=flags.savedir, 106 | ) 107 | checkpointpath = os.path.expandvars( 108 | os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 109 | 'model.tar'))) 110 | 111 | T = flags.unroll_length 112 | B = flags.batch_size 113 | 114 | flags.device = None 115 | if not flags.disable_cuda and torch.cuda.is_available(): 116 | log.info('Using CUDA.') 117 | flags.device = torch.device('cuda') 118 | else: 119 | log.info('Not using CUDA.') 120 | flags.device = torch.device('cpu') 121 | env = create_env(flags) 122 | if flags.num_input_frames > 1: 123 | env = FrameStack(env, flags.num_input_frames) 124 | 125 | if 'MiniGrid' in flags.env: 126 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n) 127 | else: 128 | model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n) 129 | 130 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags) 131 | 132 | model.share_memory() 133 | 134 | initial_agent_state_buffers = [] 135 | for _ in range(flags.num_buffers): 136 | state = model.initial_state(batch_size=1) 137 | for t in state: 138 | t.share_memory_() 139 | initial_agent_state_buffers.append(state) 140 | 141 | actor_processes = [] 142 | ctx = mp.get_context('fork') 143 | free_queue = ctx.SimpleQueue() 144 | full_queue = ctx.SimpleQueue() 145 | 146 | episode_state_count_dict = dict() 147 | train_state_count_dict = dict() 148 | for i in range(flags.num_actors): 149 | actor = ctx.Process( 150 | target=act, 151 | args=(i, free_queue, full_queue, model, buffers, 152 | episode_state_count_dict, train_state_count_dict, 153 | initial_agent_state_buffers, flags)) 154 | actor.start() 155 | actor_processes.append(actor) 156 | 157 | if 'MiniGrid' in flags.env: 158 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\ 159 | .to(device=flags.device) 160 | else: 161 | learner_model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)\ 162 | .to(device=flags.device) 163 | 164 | optimizer = torch.optim.RMSprop( 165 | learner_model.parameters(), 166 | lr=flags.learning_rate, 167 | momentum=flags.momentum, 168 | eps=flags.epsilon, 169 | alpha=flags.alpha) 170 | 171 | 172 | def lr_lambda(epoch): 173 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames 174 | 175 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 176 | 177 | logger = logging.getLogger('logfile') 178 | stat_keys = [ 179 | 'mean_episode_return', 180 | 'total_loss', 181 | 'pg_loss', 182 | 'baseline_loss', 183 | 'entropy_loss', 184 | ] 185 | logger.info('# Step\t%s', '\t'.join(stat_keys)) 186 | frames, stats = 0, {} 187 | 188 | 189 | def batch_and_learn(i, lock=threading.Lock()): 190 | """Thread target for the learning process.""" 191 | nonlocal frames, stats 192 | timings = prof.Timings() 193 | while frames < flags.total_frames: 194 | timings.reset() 195 | batch, agent_state = get_batch(free_queue, full_queue, buffers, 196 | initial_agent_state_buffers, flags, timings) 197 | stats = learn(model, learner_model, batch, agent_state, 198 | optimizer, scheduler, flags) 199 | timings.time('learn') 200 | with lock: 201 | to_log = dict(frames=frames) 202 | to_log.update({k: stats[k] for k in stat_keys}) 203 | plogger.log(to_log) 204 | frames += T * B 205 | 206 | if i == 0: 207 | log.info('Batch and learn: %s', timings.summary()) 208 | 209 | for m in range(flags.num_buffers): 210 | free_queue.put(m) 211 | 212 | threads = [] 213 | for i in range(flags.num_threads): 214 | thread = threading.Thread( 215 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,)) 216 | thread.start() 217 | threads.append(thread) 218 | 219 | def checkpoint(frames): 220 | if flags.disable_checkpoint: 221 | return 222 | checkpointpath = os.path.expandvars(os.path.expanduser( 223 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model.tar'))) 224 | log.info('Saving checkpoint to %s', checkpointpath) 225 | torch.save({ 226 | 'model_state_dict': model.state_dict(), 227 | 'optimizer_state_dict': optimizer.state_dict(), 228 | 'scheduler_state_dict': scheduler.state_dict(), 229 | 'flags': vars(flags), 230 | }, checkpointpath) 231 | 232 | timer = timeit.default_timer 233 | try: 234 | last_checkpoint_time = timer() 235 | while frames < flags.total_frames: 236 | start_frames = frames 237 | start_time = timer() 238 | time.sleep(5) 239 | 240 | if timer() - last_checkpoint_time > flags.save_interval * 60: 241 | checkpoint(frames) 242 | last_checkpoint_time = timer() 243 | 244 | fps = (frames - start_frames) / (timer() - start_time) 245 | if stats.get('episode_returns', None): 246 | mean_return = 'Return per episode: %.1f. ' % stats[ 247 | 'mean_episode_return'] 248 | else: 249 | mean_return = '' 250 | total_loss = stats.get('total_loss', float('inf')) 251 | log.info('After %i frames: loss %f @ %.1f fps. %sStats:\n%s', 252 | frames, total_loss, fps, mean_return, 253 | pprint.pformat(stats)) 254 | 255 | except KeyboardInterrupt: 256 | return 257 | else: 258 | for thread in threads: 259 | thread.join() 260 | log.info('Learning finished after %d frames.', frames) 261 | 262 | finally: 263 | for _ in range(flags.num_actors): 264 | free_queue.put(None) 265 | for actor in actor_processes: 266 | actor.join(timeout=1) 267 | checkpoint(frames) 268 | plogger.close() 269 | 270 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | import numpy as np 11 | 12 | def init(module, weight_init, bias_init, gain=1): 13 | weight_init(module.weight.data, gain=gain) 14 | bias_init(module.bias.data) 15 | return module 16 | 17 | 18 | class MinigridPolicyNet(nn.Module): 19 | def __init__(self, observation_shape, num_actions): 20 | super(MinigridPolicyNet, self).__init__() 21 | self.observation_shape = observation_shape 22 | self.num_actions = num_actions 23 | 24 | init_ = lambda m: init(m, nn.init.orthogonal_, 25 | lambda x: nn.init.constant_(x, 0), 26 | nn.init.calculate_gain('relu')) 27 | 28 | self.feat_extract = nn.Sequential( 29 | init_(nn.Conv2d(in_channels=self.observation_shape[2], out_channels=32, kernel_size=(3, 3), stride=1, padding=1)), 30 | nn.ELU(), 31 | init_(nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(3, 3), stride=2, padding=1)), 32 | nn.ELU(), 33 | init_(nn.Conv2d(in_channels=128, out_channels=512, kernel_size=(3, 3), stride=2, padding=1)), 34 | nn.ELU(), 35 | ) 36 | 37 | self.fc = nn.Sequential( 38 | init_(nn.Linear(2048, 1024)), 39 | nn.ReLU(), 40 | init_(nn.Linear(1024, 1024)), 41 | nn.ReLU(), 42 | ) 43 | 44 | self.core = nn.LSTM(1024, 1024, 2) 45 | 46 | init_ = lambda m: init(m, nn.init.orthogonal_, 47 | lambda x: nn.init.constant_(x, 0)) 48 | 49 | self.policy = init_(nn.Linear(1024, self.num_actions)) 50 | self.baseline = init_(nn.Linear(1024, 1)) 51 | 52 | 53 | def initial_state(self, batch_size): 54 | return tuple(torch.zeros(self.core.num_layers, batch_size, 55 | self.core.hidden_size) for _ in range(2)) 56 | 57 | 58 | def forward(self, inputs, core_state=()): 59 | # -- [unroll_length x batch_size x height x width x channels] 60 | x = inputs['partial_obs'] 61 | T, B, *_ = x.shape 62 | 63 | # -- [unroll_length*batch_size x height x width x channels] 64 | x = torch.flatten(x, 0, 1) # Merge time and batch. 65 | x = x.float() 66 | 67 | # -- [unroll_length*batch_size x channels x width x height] 68 | x = x.permute(0, 3, 1, 2) 69 | x = self.feat_extract(x) 70 | x = x.reshape(T * B, -1) 71 | core_input = self.fc(x) 72 | 73 | core_input = core_input.view(T, B, -1) 74 | core_output_list = [] 75 | notdone = (~inputs['done']).float() 76 | for input, nd in zip(core_input.unbind(), notdone.unbind()): 77 | nd = nd.view(1, -1, 1) 78 | core_state = tuple(nd * s for s in core_state) 79 | output, core_state = self.core(input.unsqueeze(0), core_state) 80 | core_output_list.append(output) 81 | core_output = torch.flatten(torch.cat(core_output_list), 0, 1) 82 | 83 | policy_logits = self.policy(core_output) 84 | baseline = self.baseline(core_output) 85 | 86 | if self.training: 87 | action = torch.multinomial( 88 | F.softmax(policy_logits, dim=1), num_samples=1) 89 | else: 90 | action = torch.argmax(policy_logits, dim=1) 91 | 92 | policy_logits = policy_logits.view(T, B, self.num_actions) 93 | baseline = baseline.view(T, B) 94 | action = action.view(T, B) 95 | 96 | return dict(policy_logits=policy_logits, baseline=baseline, 97 | action=action), core_state 98 | 99 | 100 | class MinigridStateEmbeddingNet(nn.Module): 101 | def __init__(self, observation_shape, use_lstm=False): 102 | super(MinigridStateEmbeddingNet, self).__init__() 103 | self.observation_shape = observation_shape 104 | self.use_lstm = use_lstm 105 | 106 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 107 | constant_(x, 0), nn.init.calculate_gain('relu')) 108 | 109 | self.feat_extract = nn.Sequential( 110 | init_(nn.Conv2d(in_channels=self.observation_shape[2], out_channels=32, kernel_size=(3, 3), stride=1, padding=1)), 111 | nn.ELU(), 112 | init_(nn.Conv2d(in_channels=32, out_channels=128, kernel_size=(3, 3), stride=2, padding=1)), 113 | nn.ELU(), 114 | init_(nn.Conv2d(in_channels=128, out_channels=512, kernel_size=(3, 3), stride=2, padding=1)), 115 | nn.ELU(), 116 | ) 117 | 118 | self.fc = nn.Sequential( 119 | init_(nn.Linear(2048, 1024)), 120 | nn.ReLU(), 121 | init_(nn.Linear(1024, 1024)), 122 | nn.ReLU(), 123 | ) 124 | 125 | if self.use_lstm: 126 | self.core = nn.LSTM(1024, 1024, 2) 127 | 128 | def initial_state(self, batch_size): 129 | #TODO: we might need to change this 130 | return tuple(torch.zeros(2, batch_size, 131 | 1024) for _ in range(2)) 132 | 133 | def forward(self, inputs, core_state=(), done=None): 134 | # -- [unroll_length x batch_size x height x width x channels] 135 | x = inputs 136 | T, B, *_ = x.shape 137 | 138 | # -- [unroll_length*batch_size x height x width x channels] 139 | x = torch.flatten(x, 0, 1) # Merge time and batch. 140 | 141 | x = x.float() 142 | 143 | # -- [unroll_length*batch_size x channels x width x height] 144 | x = x.permute(0, 3, 1, 2) 145 | x = self.feat_extract(x) 146 | x = x.reshape(T * B, -1) 147 | x = self.fc(x) 148 | 149 | if self.use_lstm: 150 | core_input = x.view(T, B, -1) 151 | core_output_list = [] 152 | notdone = (~done).float() 153 | for input, nd in zip(core_input.unbind(), notdone.unbind()): 154 | nd = nd.view(1, -1, 1) 155 | core_state = tuple(nd * s for s in core_state) 156 | output, core_state = self.core(input.unsqueeze(0), core_state) 157 | core_output_list.append(output) 158 | x = torch.flatten(torch.cat(core_output_list), 0, 1) 159 | 160 | state_embedding = x.view(T, B, -1) 161 | return state_embedding, core_state 162 | 163 | class MinigridMLPEmbeddingNet(nn.Module): 164 | def __init__(self): 165 | super(MinigridMLPEmbeddingNet, self).__init__() 166 | 167 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 168 | constant_(x, 0), nn.init.calculate_gain('relu')) 169 | 170 | self.fc = nn.Sequential( 171 | nn.Linear(1024, 1024), 172 | nn.ReLU(), 173 | nn.Linear(1024, 1024), 174 | nn.ReLU(), 175 | nn.Linear(1024, 128), 176 | nn.ReLU(), 177 | nn.Linear(128, 128), 178 | nn.ReLU(), 179 | nn.Linear(128, 128), 180 | ) 181 | 182 | def forward(self, inputs, core_state=()): 183 | x = inputs 184 | T, B, *_ = x.shape 185 | 186 | x = self.fc(x) 187 | 188 | state_embedding = x.reshape(T, B, -1) 189 | 190 | return state_embedding, tuple() 191 | 192 | 193 | class MinigridMLPTargetEmbeddingNet(nn.Module): 194 | def __init__(self): 195 | super(MinigridMLPTargetEmbeddingNet, self).__init__() 196 | 197 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 198 | constant_(x, 0), nn.init.calculate_gain('relu')) 199 | 200 | self.fc = nn.Sequential( 201 | nn.Linear(1024, 128), 202 | nn.ReLU(), 203 | nn.Linear(128, 128), 204 | nn.ReLU(), 205 | nn.Linear(128, 128), 206 | nn.ReLU(), 207 | nn.Linear(128, 128), 208 | ) 209 | 210 | def forward(self, inputs, core_state=()): 211 | x = inputs 212 | T, B, *_ = x.shape 213 | 214 | x = self.fc(x) 215 | 216 | state_embedding = x.reshape(T, B, -1) 217 | 218 | return state_embedding, tuple() 219 | 220 | 221 | class MinigridInverseDynamicsNet(nn.Module): 222 | def __init__(self, num_actions): 223 | super(MinigridInverseDynamicsNet, self).__init__() 224 | self.num_actions = num_actions 225 | 226 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 227 | constant_(x, 0), nn.init.calculate_gain('relu')) 228 | self.inverse_dynamics = nn.Sequential( 229 | init_(nn.Linear(2 * 128, 256)), 230 | nn.ReLU(), 231 | ) 232 | 233 | init_ = lambda m: init(m, nn.init.orthogonal_, 234 | lambda x: nn.init.constant_(x, 0)) 235 | self.id_out = init_(nn.Linear(256, self.num_actions)) 236 | 237 | 238 | def forward(self, state_embedding, next_state_embedding): 239 | inputs = torch.cat((state_embedding, next_state_embedding), dim=2) 240 | action_logits = self.id_out(self.inverse_dynamics(inputs)) 241 | return action_logits 242 | 243 | 244 | class MinigridForwardDynamicsNet(nn.Module): 245 | def __init__(self, num_actions): 246 | super(MinigridForwardDynamicsNet, self).__init__() 247 | self.num_actions = num_actions 248 | 249 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 250 | constant_(x, 0), nn.init.calculate_gain('relu')) 251 | 252 | self.forward_dynamics = nn.Sequential( 253 | init_(nn.Linear(128 + self.num_actions, 256)), 254 | nn.ReLU(), 255 | ) 256 | 257 | init_ = lambda m: init(m, nn.init.orthogonal_, 258 | lambda x: nn.init.constant_(x, 0)) 259 | 260 | self.fd_out = init_(nn.Linear(256, 128)) 261 | 262 | def forward(self, state_embedding, action): 263 | action_one_hot = F.one_hot(action, num_classes=self.num_actions).float() 264 | inputs = torch.cat((state_embedding, action_one_hot), dim=2) 265 | next_state_emb = self.fd_out(self.forward_dynamics(inputs)) 266 | return next_state_emb 267 | 268 | 269 | -------------------------------------------------------------------------------- /src/core/vtrace_test.py: -------------------------------------------------------------------------------- 1 | # This file taken from 2 | # https://github.com/deepmind/scalable_agent/blob/ 3 | # d24bd74bd53d454b7222b7f0bea57a358e4ca33e/vtrace_test.py 4 | # and modified. 5 | 6 | # Copyright 2018 Google LLC 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # https://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | """Tests for V-trace. 21 | 22 | For details and theory see: 23 | 24 | "IMPALA: Scalable Distributed Deep-RL with 25 | Importance Weighted Actor-Learner Architectures" 26 | by Espeholt, Soyer, Munos et al. 27 | """ 28 | 29 | import unittest 30 | 31 | import numpy as np 32 | import torch 33 | 34 | import vtrace 35 | 36 | 37 | def _shaped_arange(*shape): 38 | """Runs np.arange, converts to float and reshapes.""" 39 | return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) 40 | 41 | 42 | def _softmax(logits): 43 | """Applies softmax non-linearity on inputs.""" 44 | return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True) 45 | 46 | 47 | def _ground_truth_calculation(discounts, log_rhos, rewards, values, 48 | bootstrap_value, clip_rho_threshold, 49 | clip_pg_rho_threshold): 50 | """Calculates the ground truth for V-trace in Python/Numpy.""" 51 | vs = [] 52 | seq_len = len(discounts) 53 | rhos = np.exp(log_rhos) 54 | cs = np.minimum(rhos, 1.0) 55 | clipped_rhos = rhos 56 | if clip_rho_threshold: 57 | clipped_rhos = np.minimum(rhos, clip_rho_threshold) 58 | clipped_pg_rhos = rhos 59 | if clip_pg_rho_threshold: 60 | clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold) 61 | 62 | # This is a very inefficient way to calculate the V-trace ground truth. 63 | # We calculate it this way because it is close to the mathematical notation 64 | # of V-trace. 65 | # v_s = V(x_s) 66 | # + \sum^{T-1}_{t=s} \gamma^{t-s} 67 | # * \prod_{i=s}^{t-1} c_i 68 | # * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t)) 69 | # Note that when we take the product over c_i, we write `s:t` as the 70 | # notation of the paper is inclusive of the `t-1`, but Python is exclusive. 71 | # Also note that np.prod([]) == 1. 72 | values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]], axis=0) 73 | for s in range(seq_len): 74 | v_s = np.copy(values[s]) # Very important copy. 75 | for t in range(s, seq_len): 76 | v_s += ( 77 | np.prod(discounts[s:t], axis=0) * np.prod( 78 | cs[s:t], axis=0) * clipped_rhos[t] * 79 | (rewards[t] + discounts[t] * values_t_plus_1[t + 1] - values[t] 80 | )) 81 | vs.append(v_s) 82 | vs = np.stack(vs, axis=0) 83 | pg_advantages = (clipped_pg_rhos * (rewards + discounts * np.concatenate( 84 | [vs[1:], bootstrap_value[None, :]], axis=0) - values)) 85 | 86 | return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages) 87 | 88 | 89 | def assert_allclose(actual, desired): 90 | return np.testing.assert_allclose(actual, desired, rtol=1e-06, atol=1e-05) 91 | 92 | 93 | class ActionLogProbsTest(unittest.TestCase): 94 | 95 | def test_action_log_probs(self, batch_size=2): 96 | seq_len = 7 97 | num_actions = 3 98 | 99 | policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10 100 | actions = np.random.randint( 101 | 0, num_actions, size=(seq_len, batch_size), dtype=np.int64) 102 | 103 | action_log_probs_tensor = vtrace.action_log_probs( 104 | torch.from_numpy(policy_logits), torch.from_numpy(actions)) 105 | 106 | # Ground Truth 107 | # Using broadcasting to create a mask that indexes action logits 108 | action_index_mask = actions[..., None] == np.arange(num_actions) 109 | 110 | def index_with_mask(array, mask): 111 | return array[mask].reshape(*array.shape[:-1]) 112 | 113 | # Note: Normally log(softmax) is not a good idea because it's not 114 | # numerically stable. However, in this test we have well-behaved values. 115 | ground_truth_v = index_with_mask( 116 | np.log(_softmax(policy_logits)), action_index_mask) 117 | 118 | assert_allclose(ground_truth_v, action_log_probs_tensor) 119 | 120 | def test_action_log_probs_batch_1(self): 121 | self.test_action_log_probs(1) 122 | 123 | 124 | class VtraceTest(unittest.TestCase): 125 | 126 | def test_vtrace(self, batch_size=5): 127 | """Tests V-trace against ground truth data calculated in python.""" 128 | seq_len = 5 129 | 130 | # Create log_rhos such that rho will span from near-zero to above the 131 | # clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5), 132 | # so that rho is in approx [0.08, 12.2). 133 | log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len) 134 | log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5). 135 | values = { 136 | 'log_rhos': 137 | log_rhos, 138 | # T, B where B_i: [0.9 / (i+1)] * T 139 | 'discounts': 140 | np.array( 141 | [[0.9 / (b + 1) 142 | for b in range(batch_size)] 143 | for _ in range(seq_len)], 144 | dtype=np.float32), 145 | 'rewards': 146 | _shaped_arange(seq_len, batch_size), 147 | 'values': 148 | _shaped_arange(seq_len, batch_size) / batch_size, 149 | 'bootstrap_value': 150 | _shaped_arange(batch_size) + 1.0, 151 | 'clip_rho_threshold': 152 | 3.7, 153 | 'clip_pg_rho_threshold': 154 | 2.2, 155 | } 156 | 157 | ground_truth = _ground_truth_calculation(**values) 158 | 159 | values = {key: torch.tensor(value) for key, value in values.items()} 160 | output = vtrace.from_importance_weights(**values) 161 | 162 | for a, b in zip(ground_truth, output): 163 | assert_allclose(a, b) 164 | 165 | def test_vtrace_batch_1(self): 166 | self.test_vtrace(1) 167 | 168 | def test_vtrace_from_logits(self, batch_size=2): 169 | """Tests V-trace calculated from logits.""" 170 | seq_len = 5 171 | num_actions = 3 172 | clip_rho_threshold = None # No clipping. 173 | clip_pg_rho_threshold = None # No clipping. 174 | 175 | values = { 176 | 'behavior_policy_logits': 177 | _shaped_arange(seq_len, batch_size, num_actions), 178 | 'target_policy_logits': 179 | _shaped_arange(seq_len, batch_size, num_actions), 180 | 'actions': 181 | np.random.randint(0, num_actions - 1, size=(seq_len, batch_size)), 182 | 'discounts': 183 | np.array( # T, B where B_i: [0.9 / (i+1)] * T 184 | [[0.9 / (b + 1) 185 | for b in range(batch_size)] 186 | for _ in range(seq_len)], 187 | dtype=np.float32), 188 | 'rewards': 189 | _shaped_arange(seq_len, batch_size), 190 | 'values': 191 | _shaped_arange(seq_len, batch_size) / batch_size, 192 | 'bootstrap_value': 193 | _shaped_arange(batch_size) + 1.0, # B 194 | } 195 | values = {k: torch.from_numpy(v) for k, v in values.items()} 196 | 197 | from_logits_output = vtrace.from_logits( 198 | clip_rho_threshold=clip_rho_threshold, 199 | clip_pg_rho_threshold=clip_pg_rho_threshold, 200 | **values) 201 | 202 | target_log_probs = vtrace.action_log_probs( 203 | values['target_policy_logits'], values['actions']) 204 | behavior_log_probs = vtrace.action_log_probs( 205 | values['behavior_policy_logits'], values['actions']) 206 | log_rhos = target_log_probs - behavior_log_probs 207 | 208 | # Calculate V-trace using the ground truth logits. 209 | from_iw = vtrace.from_importance_weights( 210 | log_rhos=log_rhos, 211 | discounts=values['discounts'], 212 | rewards=values['rewards'], 213 | values=values['values'], 214 | bootstrap_value=values['bootstrap_value'], 215 | clip_rho_threshold=clip_rho_threshold, 216 | clip_pg_rho_threshold=clip_pg_rho_threshold) 217 | 218 | assert_allclose(from_iw.vs, from_logits_output.vs) 219 | assert_allclose(from_iw.pg_advantages, from_logits_output.pg_advantages) 220 | assert_allclose(behavior_log_probs, 221 | from_logits_output.behavior_action_log_probs) 222 | assert_allclose(target_log_probs, 223 | from_logits_output.target_action_log_probs) 224 | assert_allclose(log_rhos, from_logits_output.log_rhos) 225 | 226 | def test_vtrace_from_logits_batch_1(self): 227 | self.test_vtrace_from_logits(1) 228 | 229 | def test_higher_rank_inputs_for_importance_weights(self): 230 | """Checks support for additional dimensions in inputs.""" 231 | T = 3 # pylint: disable=invalid-name 232 | B = 2 # pylint: disable=invalid-name 233 | values = { 234 | 'log_rhos': torch.zeros(T, B, 1), 235 | 'discounts': torch.zeros(T, B, 1), 236 | 'rewards': torch.zeros(T, B, 42), 237 | 'values': torch.zeros(T, B, 42), 238 | 'bootstrap_value': torch.zeros(B, 42), 239 | } 240 | output = vtrace.from_importance_weights(**values) 241 | self.assertSequenceEqual(output.vs.shape, (T, B, 42)) 242 | 243 | def test_inconsistent_rank_inputs_for_importance_weights(self): 244 | """Test one of many possible errors in shape of inputs.""" 245 | T = 3 # pylint: disable=invalid-name 246 | B = 2 # pylint: disable=invalid-name 247 | 248 | values = { 249 | 'log_rhos': torch.zeros(T, B, 1), 250 | 'discounts': torch.zeros(T, B, 1), 251 | 'rewards': torch.zeros(T, B, 42), 252 | 'values': torch.zeros(T, B, 42), 253 | # Should be [B, 42]. 254 | 'bootstrap_value': torch.zeros(B), 255 | } 256 | 257 | with self.assertRaisesRegex(RuntimeError, 258 | 'same number of dimensions: got 3 and 2'): 259 | vtrace.from_importance_weights(**values) 260 | 261 | 262 | if __name__ == '__main__': 263 | unittest.main() 264 | -------------------------------------------------------------------------------- /src/algos/count.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import threading 10 | import time 11 | import timeit 12 | import pprint 13 | 14 | import numpy as np 15 | 16 | import torch 17 | from torch import multiprocessing as mp 18 | from torch import nn 19 | from torch.nn import functional as F 20 | 21 | from src.core import file_writer 22 | from src.core import prof 23 | from src.core import vtrace 24 | 25 | import src.models as models 26 | import src.losses as losses 27 | 28 | from src.env_utils import FrameStack 29 | from src.utils import get_batch, log, create_env, create_buffers, act 30 | 31 | MinigridPolicyNet = models.MinigridPolicyNet 32 | 33 | def learn(actor_model, 34 | model, 35 | batch, 36 | initial_agent_state, 37 | optimizer, 38 | scheduler, 39 | flags, 40 | lock=threading.Lock()): 41 | """Performs a learning (optimization) step.""" 42 | with lock: 43 | intrinsic_rewards = torch.ones((flags.unroll_length, flags.batch_size), 44 | dtype=torch.float32).to(device=flags.device) 45 | 46 | intrinsic_rewards = batch['train_state_count'][1:].float().to(device=flags.device) 47 | 48 | intrinsic_reward_coef = flags.intrinsic_reward_coef 49 | intrinsic_rewards *= intrinsic_reward_coef 50 | 51 | learner_outputs, unused_state = model(batch, initial_agent_state) 52 | 53 | bootstrap_value = learner_outputs['baseline'][-1] 54 | 55 | batch = {key: tensor[1:] for key, tensor in batch.items()} 56 | learner_outputs = { 57 | key: tensor[:-1] 58 | for key, tensor in learner_outputs.items() 59 | } 60 | 61 | rewards = batch['reward'] 62 | if flags.no_reward: 63 | total_rewards = intrinsic_rewards 64 | else: 65 | total_rewards = rewards + intrinsic_rewards 66 | clipped_rewards = torch.clamp(total_rewards, -1, 1) 67 | 68 | discounts = (~batch['done']).float() * flags.discounting 69 | 70 | vtrace_returns = vtrace.from_logits( 71 | behavior_policy_logits=batch['policy_logits'], 72 | target_policy_logits=learner_outputs['policy_logits'], 73 | actions=batch['action'], 74 | discounts=discounts, 75 | rewards=clipped_rewards, 76 | values=learner_outputs['baseline'], 77 | bootstrap_value=bootstrap_value) 78 | 79 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'], 80 | batch['action'], 81 | vtrace_returns.pg_advantages) 82 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss( 83 | vtrace_returns.vs - learner_outputs['baseline']) 84 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss( 85 | learner_outputs['policy_logits']) 86 | 87 | total_loss = pg_loss + baseline_loss + entropy_loss 88 | 89 | episode_returns = batch['episode_return'][batch['done']] 90 | stats = { 91 | 'mean_episode_return': torch.mean(episode_returns).item(), 92 | 'total_loss': total_loss.item(), 93 | 'pg_loss': pg_loss.item(), 94 | 'baseline_loss': baseline_loss.item(), 95 | 'entropy_loss': entropy_loss.item(), 96 | 'mean_rewards': torch.mean(rewards).item(), 97 | 'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(), 98 | 'mean_total_rewards': torch.mean(total_rewards).item(), 99 | } 100 | 101 | scheduler.step() 102 | optimizer.zero_grad() 103 | total_loss.backward() 104 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm) 105 | optimizer.step() 106 | 107 | actor_model.load_state_dict(model.state_dict()) 108 | return stats 109 | 110 | 111 | def train(flags): 112 | if flags.xpid is None: 113 | flags.xpid = 'count-%s' % time.strftime('%Y%m%d-%H%M%S') 114 | plogger = file_writer.FileWriter( 115 | xpid=flags.xpid, 116 | xp_args=flags.__dict__, 117 | rootdir=flags.savedir, 118 | ) 119 | checkpointpath = os.path.expandvars( 120 | os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 121 | 'model.tar'))) 122 | 123 | T = flags.unroll_length 124 | B = flags.batch_size 125 | 126 | flags.device = None 127 | if not flags.disable_cuda and torch.cuda.is_available(): 128 | log.info('Using CUDA.') 129 | flags.device = torch.device('cuda') 130 | else: 131 | log.info('Not using CUDA.') 132 | flags.device = torch.device('cpu') 133 | 134 | env = create_env(flags) 135 | if flags.num_input_frames > 1: 136 | env = FrameStack(env, flags.num_input_frames) 137 | 138 | if 'MiniGrid' in flags.env: 139 | if flags.use_fullobs_policy: 140 | model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n) 141 | else: 142 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n) 143 | else: 144 | model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n) 145 | 146 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags) 147 | 148 | model.share_memory() 149 | 150 | initial_agent_state_buffers = [] 151 | for _ in range(flags.num_buffers): 152 | state = model.initial_state(batch_size=1) 153 | for t in state: 154 | t.share_memory_() 155 | initial_agent_state_buffers.append(state) 156 | 157 | actor_processes = [] 158 | ctx = mp.get_context('fork') 159 | free_queue = ctx.SimpleQueue() 160 | full_queue = ctx.SimpleQueue() 161 | 162 | episode_state_count_dict = dict() 163 | train_state_count_dict = dict() 164 | for i in range(flags.num_actors): 165 | actor = ctx.Process( 166 | target=act, 167 | args=(i, free_queue, full_queue, model, buffers, 168 | episode_state_count_dict, train_state_count_dict, 169 | initial_agent_state_buffers, flags)) 170 | actor.start() 171 | actor_processes.append(actor) 172 | 173 | if 'MiniGrid' in flags.env: 174 | if flags.use_fullobs_policy: 175 | learner_model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)\ 176 | .to(device=flags.device) 177 | else: 178 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\ 179 | .to(device=flags.device) 180 | else: 181 | learner_model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)\ 182 | .to(device=flags.device) 183 | 184 | optimizer = torch.optim.RMSprop( 185 | learner_model.parameters(), 186 | lr=flags.learning_rate, 187 | momentum=flags.momentum, 188 | eps=flags.epsilon, 189 | alpha=flags.alpha) 190 | 191 | 192 | def lr_lambda(epoch): 193 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames 194 | 195 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 196 | 197 | logger = logging.getLogger('logfile') 198 | stat_keys = [ 199 | 'total_loss', 200 | 'mean_episode_return', 201 | 'pg_loss', 202 | 'baseline_loss', 203 | 'entropy_loss', 204 | 'mean_rewards', 205 | 'mean_intrinsic_rewards', 206 | 'mean_total_rewards', 207 | ] 208 | 209 | logger.info('# Step\t%s', '\t'.join(stat_keys)) 210 | frames, stats = 0, {} 211 | 212 | 213 | def batch_and_learn(i, lock=threading.Lock()): 214 | """Thread target for the learning process.""" 215 | nonlocal frames, stats 216 | timings = prof.Timings() 217 | while frames < flags.total_frames: 218 | timings.reset() 219 | batch, agent_state = get_batch(free_queue, full_queue, buffers, 220 | initial_agent_state_buffers, flags, timings) 221 | stats = learn(model, learner_model, batch, agent_state, 222 | optimizer, scheduler, flags) 223 | timings.time('learn') 224 | with lock: 225 | to_log = dict(frames=frames) 226 | to_log.update({k: stats[k] for k in stat_keys}) 227 | plogger.log(to_log) 228 | frames += T * B 229 | 230 | if i == 0: 231 | log.info('Batch and learn: %s', timings.summary()) 232 | 233 | for m in range(flags.num_buffers): 234 | free_queue.put(m) 235 | 236 | threads = [] 237 | for i in range(flags.num_threads): 238 | thread = threading.Thread( 239 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,)) 240 | thread.start() 241 | threads.append(thread) 242 | 243 | def checkpoint(frames): 244 | if flags.disable_checkpoint: 245 | return 246 | checkpointpath = os.path.expandvars(os.path.expanduser( 247 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model.tar'))) 248 | log.info('Saving checkpoint to %s', checkpointpath) 249 | torch.save({ 250 | 'model_state_dict': model.state_dict(), 251 | 'optimizer_state_dict': optimizer.state_dict(), 252 | 'scheduler_state_dict': scheduler.state_dict(), 253 | 'flags': vars(flags), 254 | }, checkpointpath) 255 | 256 | timer = timeit.default_timer 257 | try: 258 | last_checkpoint_time = timer() 259 | while frames < flags.total_frames: 260 | start_frames = frames 261 | start_time = timer() 262 | time.sleep(5) 263 | 264 | if timer() - last_checkpoint_time > flags.save_interval * 60: 265 | checkpoint(frames) 266 | last_checkpoint_time = timer() 267 | 268 | fps = (frames - start_frames) / (timer() - start_time) 269 | if stats.get('episode_returns', None): 270 | mean_return = 'Return per episode: %.1f. ' % stats[ 271 | 'mean_episode_return'] 272 | else: 273 | mean_return = '' 274 | total_loss = stats.get('total_loss', float('inf')) 275 | log.info('After %i frames: loss %f @ %.1f fps. %sStats:\n%s', 276 | frames, total_loss, fps, mean_return, 277 | pprint.pformat(stats)) 278 | 279 | except KeyboardInterrupt: 280 | return 281 | else: 282 | for thread in threads: 283 | thread.join() 284 | log.info('Learning finished after %d frames.', frames) 285 | 286 | finally: 287 | for _ in range(flags.num_actors): 288 | free_queue.put(None) 289 | for actor in actor_processes: 290 | actor.join(timeout=1) 291 | checkpoint(frames) 292 | plogger.close() 293 | 294 | -------------------------------------------------------------------------------- /src/atari_wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # This code was taken from 8 | # https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py 9 | # and modified. 10 | 11 | from collections import deque, defaultdict 12 | 13 | import numpy as np 14 | 15 | import gym 16 | from gym import spaces 17 | import cv2 18 | cv2.ocl.setUseOpenCL(False) 19 | 20 | 21 | class NoNegativeRewardEnv(gym.RewardWrapper): 22 | """Clip reward in negative direction.""" 23 | def __init__(self, env=None, neg_clip=0.0): 24 | super(NoNegativeRewardEnv, self).__init__(env) 25 | self.neg_clip = neg_clip 26 | 27 | def _reward(self, reward): 28 | new_reward = self.neg_clip if reward < self.neg_clip else reward 29 | return new_reward 30 | 31 | 32 | class NoopResetEnv(gym.Wrapper): 33 | 34 | def __init__(self, env, noop_max=30): 35 | """Sample initial states by taking random number of no-ops on reset. 36 | No-op is assumed to be action 0. 37 | """ 38 | gym.Wrapper.__init__(self, env) 39 | self.noop_max = noop_max 40 | self.override_num_noops = None 41 | self.noop_action = 0 42 | 43 | def reset(self, **kwargs): 44 | """ Do no-op action for a number of steps in [1, noop_max].""" 45 | self.env.reset(**kwargs) 46 | if self.override_num_noops is not None: 47 | noops = self.override_num_noops 48 | else: 49 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) 50 | assert noops > 0 51 | obs = None 52 | for _ in range(noops): 53 | obs, _, done, _ = self.env.step(self.noop_action) 54 | if done: 55 | obs = self.env.reset(**kwargs) 56 | return obs 57 | 58 | def step(self, ac): 59 | return self.env.step(ac) 60 | 61 | 62 | class FireResetEnv(gym.Wrapper): 63 | 64 | def __init__(self, env): 65 | """Take action on reset for environments that are fixed until firing.""" 66 | gym.Wrapper.__init__(self, env) 67 | 68 | def reset(self, **kwargs): 69 | self.env.reset(**kwargs) 70 | obs, _, done, _ = self.env.step(1) 71 | if done: 72 | self.env.reset(**kwargs) 73 | obs, _, done, _ = self.env.step(2) 74 | if done: 75 | self.env.reset(**kwargs) 76 | return obs 77 | 78 | def step(self, ac): 79 | return self.env.step(ac) 80 | 81 | 82 | class EpisodicLifeEnv(gym.Wrapper): 83 | 84 | def __init__(self, env): 85 | """Make end-of-life == end-of-episode, but only reset on true game over. 86 | Done by DeepMind for the DQN and co. since it helps value estimation. 87 | """ 88 | gym.Wrapper.__init__(self, env) 89 | self.lives = 0 90 | self.was_real_done = True 91 | 92 | def step(self, action): 93 | obs, reward, done, info = self.env.step(action) 94 | self.was_real_done = done 95 | # check current lives, make loss of life terminal, 96 | # then update lives to handle bonus lives 97 | lives = self.env.unwrapped.ale.lives() 98 | if lives < self.lives and lives > 0: 99 | # for Qbert sometimes we stay in lives == 0 condition for a few frames 100 | # so it's important to keep lives > 0, so that we only reset once 101 | # the environment advertises done. 102 | done = True 103 | self.lives = lives 104 | return obs, reward, done, info 105 | 106 | def reset(self, **kwargs): 107 | """Reset only when lives are exhausted. 108 | This way all states are still reachable even though lives are episodic, 109 | and the learner need not know about any of this behind-the-scenes. 110 | """ 111 | if self.was_real_done: 112 | obs = self.env.reset(**kwargs) 113 | else: 114 | # no-op step to advance from terminal/lost life state 115 | obs, _, _, _ = self.env.step(0) 116 | self.lives = self.env.unwrapped.ale.lives() 117 | return obs 118 | 119 | 120 | class MaxAndSkipEnv(gym.Wrapper): 121 | 122 | def __init__(self, env, skip=4): 123 | """Return only every `skip`-th frame""" 124 | gym.Wrapper.__init__(self, env) 125 | # most recent raw observations (for max pooling across time steps) 126 | self._obs_buffer = np.zeros( 127 | (2,) + env.observation_space.shape, dtype=np.uint8) 128 | self._skip = skip 129 | 130 | def step(self, action): 131 | """Repeat action, sum reward, and max over last observations.""" 132 | total_reward = 0.0 133 | done = None 134 | for i in range(self._skip): 135 | obs, reward, done, info = self.env.step(action) 136 | if i == self._skip - 2: self._obs_buffer[0] = obs 137 | if i == self._skip - 1: self._obs_buffer[1] = obs 138 | total_reward += reward 139 | if done: 140 | break 141 | # Note that the observation on the done=True frame 142 | # doesn't matter 143 | max_frame = self._obs_buffer.max(axis=0) 144 | 145 | return max_frame, total_reward, done, info 146 | 147 | def reset(self, **kwargs): 148 | return self.env.reset(**kwargs) 149 | 150 | 151 | class ClipRewardEnv(gym.RewardWrapper): 152 | 153 | def __init__(self, env): 154 | gym.RewardWrapper.__init__(self, env) 155 | 156 | def reward(self, reward): 157 | """Bin reward to {+1, 0, -1} by its sign.""" 158 | return np.sign(reward) 159 | 160 | 161 | class WarpFrame(gym.ObservationWrapper): 162 | 163 | def __init__(self, env, width=84, height=84, grayscale=True): 164 | """Warp frames to 84x84 as done in the Nature paper and later work.""" 165 | gym.ObservationWrapper.__init__(self, env) 166 | self.width = width 167 | self.height = height 168 | self.grayscale = grayscale 169 | if self.grayscale: 170 | self.observation_space = spaces.Box( 171 | low=0, 172 | high=255, 173 | shape=(self.height, self.width, 1), 174 | dtype=np.uint8) 175 | else: 176 | self.observation_space = spaces.Box( 177 | low=0, 178 | high=255, 179 | shape=(self.height, self.width, 3), 180 | dtype=np.uint8) 181 | 182 | def observation(self, frame): 183 | if self.grayscale: 184 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 185 | frame = cv2.resize( 186 | frame, (self.width, self.height), interpolation=cv2.INTER_AREA) 187 | if self.grayscale: 188 | frame = np.expand_dims(frame, -1) 189 | return frame 190 | 191 | 192 | class FrameStack(gym.Wrapper): 193 | 194 | def __init__(self, env, k): 195 | """Stack k last frames. 196 | 197 | Returns lazy array, which is much more memory efficient. 198 | 199 | See Also 200 | -------- 201 | baselines.common.atari_wrappers.LazyFrames 202 | """ 203 | gym.Wrapper.__init__(self, env) 204 | self.k = k 205 | self.frames = deque([], maxlen=k) 206 | shp = env.observation_space.shape 207 | self.observation_space = spaces.Box( 208 | low=0, 209 | high=255, 210 | shape=(shp[:-1] + (shp[-1] * k,)), 211 | dtype=env.observation_space.dtype) 212 | 213 | def reset(self): 214 | ob = self.env.reset() 215 | for _ in range(self.k): 216 | self.frames.append(ob) 217 | return self._get_ob() 218 | 219 | def step(self, action): 220 | ob, reward, done, info = self.env.step(action) 221 | self.frames.append(ob) 222 | return self._get_ob(), reward, done, info 223 | 224 | def _get_ob(self): 225 | assert len(self.frames) == self.k 226 | return LazyFrames(list(self.frames)) 227 | 228 | 229 | class ScaledFloatFrame(gym.ObservationWrapper): 230 | 231 | def __init__(self, env): 232 | gym.ObservationWrapper.__init__(self, env) 233 | self.observation_space = gym.spaces.Box( 234 | low=0, high=1, shape=env.observation_space.shape, dtype=np.float32) 235 | 236 | def observation(self, observation): 237 | # careful! This undoes the memory optimization, use 238 | # with smaller replay buffers only. 239 | return np.array(observation).astype(np.float32) / 255.0 240 | 241 | 242 | class LazyFrames(object): 243 | 244 | def __init__(self, frames): 245 | """This object ensures that common frames between the observations are only stored once. 246 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 247 | buffers. 248 | 249 | This object should only be converted to numpy array before being passed to the model. 250 | 251 | You'd not believe how complex the previous solution was.""" 252 | self._frames = frames 253 | self._out = None 254 | 255 | def _force(self): 256 | if self._out is None: 257 | self._out = np.concatenate(self._frames, axis=-1) 258 | self._frames = None 259 | return self._out 260 | 261 | def __array__(self, dtype=None): 262 | out = self._force() 263 | if dtype is not None: 264 | out = out.astype(dtype) 265 | return out 266 | 267 | def __len__(self): 268 | return len(self._force()) 269 | 270 | def __getitem__(self, i): 271 | return self._force()[i] 272 | 273 | 274 | def make_atari(env_id, timelimit=True, noop=False): 275 | # XXX(john): remove timelimit argument after gym is upgraded to allow double wrapping 276 | env = gym.make(env_id) 277 | if not timelimit: 278 | env = env.env 279 | 280 | # NOTE: this is changed from original atari implementation 281 | # assert 'NoFrameskip' in env.spec.id 282 | if noop: 283 | env = NoopResetEnv(env, noop_max=30) 284 | env = MaxAndSkipEnv(env, skip=6) 285 | return env 286 | 287 | 288 | # NOTE: this was changed so that episode_life is False by default 289 | def wrap_deepmind(env, 290 | episode_life=False, 291 | clip_rewards=True, 292 | frame_stack=False, 293 | scale=True, 294 | fire=False): # FYI scale=False in openai/baselines 295 | """Configure environment for DeepMind-style Atari. 296 | """ 297 | if episode_life: 298 | env = EpisodicLifeEnv(env) 299 | if fire: 300 | if 'FIRE' in env.unwrapped.get_action_meanings(): 301 | env = FireResetEnv(env) 302 | env = WarpFrame(env, width=42, height=42) 303 | if scale: 304 | env = ScaledFloatFrame(env) 305 | if clip_rewards: 306 | env = ClipRewardEnv(env) 307 | if frame_stack: 308 | env = FrameStack(env, 4) 309 | return env 310 | 311 | 312 | # Taken from https://github.com/openai/baselines/blob/master/baselines/run.py 313 | def get_env_type(env_id): 314 | # Re-parse the gym registry, since we could have new envs since last time. 315 | for env in gym.envs.registry.all(): 316 | env_type = env._entry_point.split(':')[0].split('.')[-1] 317 | _game_envs[env_type].add(env.id) # This is a set so add is idempotent 318 | 319 | if env_id in _game_envs.keys(): 320 | env_type = env_id 321 | env_id = [g for g in _game_envs[env_type]][0] 322 | else: 323 | env_type = None 324 | for g, e in _game_envs.items(): 325 | if env_id in e: 326 | env_type = g 327 | break 328 | assert env_type is not None, 'env_id {} is not recognized in env types'.format( 329 | env_id, _game_envs.keys()) 330 | 331 | return env_type, env_id 332 | 333 | 334 | class ImageToPyTorch(gym.ObservationWrapper): 335 | """ 336 | Image shape to channels x weight x height 337 | """ 338 | 339 | def __init__(self, env): 340 | super(ImageToPyTorch, self).__init__(env) 341 | old_shape = self.observation_space.shape 342 | self.observation_space = gym.spaces.Box( 343 | low=0.0, 344 | high=1.0, 345 | shape=(old_shape[-1], old_shape[0], old_shape[1]), 346 | dtype=np.uint8) 347 | 348 | def observation(self, observation): 349 | return np.swapaxes(observation, 2, 0) 350 | 351 | 352 | def wrap_pytorch(env): 353 | return ImageToPyTorch(env) 354 | 355 | -------------------------------------------------------------------------------- /src/algos/rnd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | import threading 11 | import time 12 | import timeit 13 | import pprint 14 | 15 | import numpy as np 16 | 17 | import torch 18 | from torch import multiprocessing as mp 19 | from torch import nn 20 | from torch.nn import functional as F 21 | 22 | from src.core import file_writer 23 | from src.core import prof 24 | from src.core import vtrace 25 | 26 | import src.models as models 27 | import src.losses as losses 28 | 29 | from src.env_utils import FrameStack 30 | from src.utils import get_batch, log, create_env, create_buffers, act 31 | 32 | MinigridPolicyNet = models.MinigridPolicyNet 33 | MinigridStateEmbeddingNet = models.MinigridStateEmbeddingNet 34 | 35 | def learn(actor_model, 36 | model, 37 | random_target_network, 38 | predictor_network, 39 | batch, 40 | initial_agent_state, 41 | optimizer, 42 | predictor_optimizer, 43 | scheduler, 44 | flags, 45 | frames=None, 46 | lock=threading.Lock()): 47 | """Performs a learning (optimization) step.""" 48 | with lock: 49 | if flags.use_fullobs_intrinsic: 50 | random_embedding = random_target_network(batch, next_state=True)\ 51 | .reshape(flags.unroll_length, flags.batch_size, 128) 52 | predicted_embedding = predictor_network(batch, next_state=True)\ 53 | .reshape(flags.unroll_length, flags.batch_size, 128) 54 | else: 55 | random_embedding = random_target_network(batch['partial_obs'][1:].to(device=flags.device)) 56 | predicted_embedding = predictor_network(batch['partial_obs'][1:].to(device=flags.device)) 57 | 58 | intrinsic_rewards = torch.norm(predicted_embedding.detach() - random_embedding.detach(), dim=2, p=2) 59 | 60 | intrinsic_reward_coef = flags.intrinsic_reward_coef 61 | intrinsic_rewards *= intrinsic_reward_coef 62 | 63 | num_samples = flags.unroll_length * flags.batch_size 64 | actions_flat = batch['action'][1:].reshape(num_samples).cpu().detach().numpy() 65 | intrinsic_rewards_flat = intrinsic_rewards.reshape(num_samples).cpu().detach().numpy() 66 | 67 | rnd_loss = flags.rnd_loss_coef * \ 68 | losses.compute_forward_dynamics_loss(predicted_embedding, random_embedding.detach()) 69 | 70 | learner_outputs, unused_state = model(batch, initial_agent_state) 71 | 72 | bootstrap_value = learner_outputs['baseline'][-1] 73 | 74 | batch = {key: tensor[1:] for key, tensor in batch.items()} 75 | learner_outputs = { 76 | key: tensor[:-1] 77 | for key, tensor in learner_outputs.items() 78 | } 79 | 80 | rewards = batch['reward'] 81 | 82 | if flags.no_reward: 83 | total_rewards = intrinsic_rewards 84 | else: 85 | total_rewards = rewards + intrinsic_rewards 86 | clipped_rewards = torch.clamp(total_rewards, -1, 1) 87 | 88 | discounts = (~batch['done']).float() * flags.discounting 89 | 90 | vtrace_returns = vtrace.from_logits( 91 | behavior_policy_logits=batch['policy_logits'], 92 | target_policy_logits=learner_outputs['policy_logits'], 93 | actions=batch['action'], 94 | discounts=discounts, 95 | rewards=clipped_rewards, 96 | values=learner_outputs['baseline'], 97 | bootstrap_value=bootstrap_value) 98 | 99 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'], 100 | batch['action'], 101 | vtrace_returns.pg_advantages) 102 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss( 103 | vtrace_returns.vs - learner_outputs['baseline']) 104 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss( 105 | learner_outputs['policy_logits']) 106 | 107 | total_loss = pg_loss + baseline_loss + entropy_loss + rnd_loss 108 | 109 | episode_returns = batch['episode_return'][batch['done']] 110 | stats = { 111 | 'mean_episode_return': torch.mean(episode_returns).item(), 112 | 'total_loss': total_loss.item(), 113 | 'pg_loss': pg_loss.item(), 114 | 'baseline_loss': baseline_loss.item(), 115 | 'entropy_loss': entropy_loss.item(), 116 | 'rnd_loss': rnd_loss.item(), 117 | 'mean_rewards': torch.mean(rewards).item(), 118 | 'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(), 119 | 'mean_total_rewards': torch.mean(total_rewards).item(), 120 | } 121 | 122 | scheduler.step() 123 | optimizer.zero_grad() 124 | predictor_optimizer.zero_grad() 125 | total_loss.backward() 126 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm) 127 | nn.utils.clip_grad_norm_(predictor_network.parameters(), flags.max_grad_norm) 128 | optimizer.step() 129 | predictor_optimizer.step() 130 | 131 | actor_model.load_state_dict(model.state_dict()) 132 | return stats 133 | 134 | 135 | def train(flags): 136 | if flags.xpid is None: 137 | flags.xpid = 'rnd-%s' % time.strftime('%Y%m%d-%H%M%S') 138 | plogger = file_writer.FileWriter( 139 | xpid=flags.xpid, 140 | xp_args=flags.__dict__, 141 | rootdir=flags.savedir, 142 | ) 143 | 144 | checkpointpath = os.path.expandvars( 145 | os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 146 | 'model.tar'))) 147 | 148 | T = flags.unroll_length 149 | B = flags.batch_size 150 | 151 | flags.device = None 152 | if not flags.disable_cuda and torch.cuda.is_available(): 153 | log.info('Using CUDA.') 154 | flags.device = torch.device('cuda') 155 | else: 156 | log.info('Not using CUDA.') 157 | flags.device = torch.device('cpu') 158 | 159 | env = create_env(flags) 160 | if flags.num_input_frames > 1: 161 | env = FrameStack(env, flags.num_input_frames) 162 | 163 | if 'MiniGrid' in flags.env: 164 | if flags.use_fullobs_policy: 165 | model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n) 166 | else: 167 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n) 168 | if flags.use_fullobs_intrinsic: 169 | random_target_network = FullObsMinigridStateEmbeddingNet(env.observation_space.shape).to(device=flags.device) 170 | predictor_network = FullObsMinigridStateEmbeddingNet(env.observation_space.shape).to(device=flags.device) 171 | else: 172 | random_target_network = MinigridStateEmbeddingNet(env.observation_space.shape).to(device=flags.device) 173 | predictor_network = MinigridStateEmbeddingNet(env.observation_space.shape).to(device=flags.device) 174 | else: 175 | model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n) 176 | random_target_network = MarioDoomStateEmbeddingNet(env.observation_space.shape).to(device=flags.device) 177 | predictor_network = MarioDoomStateEmbeddingNet(env.observation_space.shape).to(device=flags.device) 178 | 179 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags) 180 | 181 | model.share_memory() 182 | 183 | initial_agent_state_buffers = [] 184 | for _ in range(flags.num_buffers): 185 | state = model.initial_state(batch_size=1) 186 | for t in state: 187 | t.share_memory_() 188 | initial_agent_state_buffers.append(state) 189 | 190 | actor_processes = [] 191 | ctx = mp.get_context('fork') 192 | free_queue = ctx.SimpleQueue() 193 | full_queue = ctx.SimpleQueue() 194 | 195 | episode_state_count_dict = dict() 196 | train_state_count_dict = dict() 197 | for i in range(flags.num_actors): 198 | actor = ctx.Process( 199 | target=act, 200 | args=(i, free_queue, full_queue, model, buffers, 201 | episode_state_count_dict, train_state_count_dict, 202 | initial_agent_state_buffers, flags)) 203 | actor.start() 204 | actor_processes.append(actor) 205 | 206 | if 'MiniGrid' in flags.env: 207 | if flags.use_fullobs_policy: 208 | learner_model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)\ 209 | .to(device=flags.device) 210 | else: 211 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\ 212 | .to(device=flags.device) 213 | else: 214 | learner_model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)\ 215 | .to(device=flags.device) 216 | 217 | optimizer = torch.optim.RMSprop( 218 | learner_model.parameters(), 219 | lr=flags.learning_rate, 220 | momentum=flags.momentum, 221 | eps=flags.epsilon, 222 | alpha=flags.alpha) 223 | 224 | predictor_optimizer = torch.optim.RMSprop( 225 | predictor_network.parameters(), 226 | lr=flags.learning_rate, 227 | momentum=flags.momentum, 228 | eps=flags.epsilon, 229 | alpha=flags.alpha) 230 | 231 | 232 | def lr_lambda(epoch): 233 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames 234 | 235 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 236 | 237 | logger = logging.getLogger('logfile') 238 | stat_keys = [ 239 | 'total_loss', 240 | 'mean_episode_return', 241 | 'pg_loss', 242 | 'baseline_loss', 243 | 'entropy_loss', 244 | 'rnd_loss', 245 | 'mean_rewards', 246 | 'mean_intrinsic_rewards', 247 | 'mean_total_rewards', 248 | ] 249 | 250 | logger.info('# Step\t%s', '\t'.join(stat_keys)) 251 | 252 | frames, stats = 0, {} 253 | 254 | 255 | def batch_and_learn(i, lock=threading.Lock()): 256 | """Thread target for the learning process.""" 257 | nonlocal frames, stats 258 | timings = prof.Timings() 259 | while frames < flags.total_frames: 260 | timings.reset() 261 | batch, agent_state = get_batch(free_queue, full_queue, buffers, 262 | initial_agent_state_buffers, flags, timings) 263 | stats = learn(model, learner_model, random_target_network, predictor_network, 264 | batch, agent_state, optimizer, predictor_optimizer, scheduler, 265 | flags, frames=frames) 266 | timings.time('learn') 267 | with lock: 268 | to_log = dict(frames=frames) 269 | to_log.update({k: stats[k] for k in stat_keys}) 270 | plogger.log(to_log) 271 | frames += T * B 272 | 273 | if i == 0: 274 | log.info('Batch and learn: %s', timings.summary()) 275 | 276 | for m in range(flags.num_buffers): 277 | free_queue.put(m) 278 | 279 | threads = [] 280 | for i in range(flags.num_threads): 281 | thread = threading.Thread( 282 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,)) 283 | thread.start() 284 | threads.append(thread) 285 | 286 | 287 | def checkpoint(frames): 288 | if flags.disable_checkpoint: 289 | return 290 | checkpointpath = os.path.expandvars(os.path.expanduser( 291 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model.tar'))) 292 | log.info('Saving checkpoint to %s', checkpointpath) 293 | torch.save({ 294 | 'model_state_dict': model.state_dict(), 295 | 'random_target_network_state_dict': random_target_network.state_dict(), 296 | 'predictor_network_state_dict': predictor_network.state_dict(), 297 | 'optimizer_state_dict': optimizer.state_dict(), 298 | 'predictor_optimizer_state_dict': predictor_optimizer.state_dict(), 299 | 'scheduler_state_dict': scheduler.state_dict(), 300 | 'flags': vars(flags), 301 | }, checkpointpath) 302 | 303 | timer = timeit.default_timer 304 | try: 305 | last_checkpoint_time = timer() 306 | while frames < flags.total_frames: 307 | start_frames = frames 308 | start_time = timer() 309 | time.sleep(5) 310 | 311 | if timer() - last_checkpoint_time > flags.save_interval * 60: 312 | checkpoint(frames) 313 | last_checkpoint_time = timer() 314 | 315 | fps = (frames - start_frames) / (timer() - start_time) 316 | 317 | if stats.get('episode_returns', None): 318 | mean_return = 'Return per episode: %.1f. ' % stats[ 319 | 'mean_episode_return'] 320 | else: 321 | mean_return = '' 322 | 323 | total_loss = stats.get('total_loss', float('inf')) 324 | if stats: 325 | log.info('After %i frames: loss %f @ %.1f fps. Mean Return %.1f. \n Stats \n %s', \ 326 | frames, total_loss, fps, stats['mean_episode_return'], pprint.pformat(stats)) 327 | 328 | except KeyboardInterrupt: 329 | return 330 | else: 331 | for thread in threads: 332 | thread.join() 333 | log.info('Learning finished after %d frames.', frames) 334 | finally: 335 | for _ in range(flags.num_actors): 336 | free_queue.put(None) 337 | for actor in actor_processes: 338 | actor.join(timeout=1) 339 | 340 | checkpoint(frames) 341 | plogger.close() 342 | 343 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from __future__ import division 7 | import torch.nn as nn 8 | import torch 9 | import typing 10 | import gym 11 | import threading 12 | from torch import multiprocessing as mp 13 | import logging 14 | import traceback 15 | import os 16 | import numpy as np 17 | import copy 18 | 19 | from src.core import prof 20 | from src.env_utils import FrameStack, Environment, Minigrid2Image 21 | from src import atari_wrappers as atari_wrappers 22 | 23 | from gym_minigrid import wrappers as wrappers 24 | 25 | # from nes_py.wrappers import JoypadSpace 26 | # from gym_super_mario_bros.actions import COMPLEX_MOVEMENT 27 | 28 | # import vizdoomgym 29 | OBJECT_TO_IDX = { 30 | 'unseen' : 0, 31 | 'empty' : 1, 32 | 'wall' : 2, 33 | 'floor' : 3, 34 | 'door' : 4, 35 | 'key' : 5, 36 | 'ball' : 6, 37 | 'box' : 7, 38 | 'goal' : 8, 39 | 'lava' : 9, 40 | 'agent' : 10, 41 | } 42 | 43 | # This augmentation is based on random walk of agents 44 | def augmentation(frames): 45 | # agent_loc = agent_loc(frames) 46 | return frames 47 | 48 | # Entropy loss on categorical distribution 49 | def catentropy(logits): 50 | a = logits - torch.max(logits, dim=-1, keepdim=True)[0] 51 | e = torch.exp(a) 52 | z = torch.sum(e, dim=-1, keepdim=True) 53 | p = e / z 54 | entropy = torch.sum(p * (torch.log(z) - a), dim=-1) 55 | return torch.mean(entropy) 56 | 57 | # Here is computing how many objects 58 | def num_objects(frames): 59 | T, B, H, W, *_ = frames.shape 60 | num_objects = frames[:, :, :, :, 0] 61 | num_objects = (num_objects == 4).long() + (num_objects == 5).long() + \ 62 | (num_objects == 6).long() + (num_objects == 7).long() + (num_objects == 8).long() 63 | return num_objects 64 | 65 | # EMA of the 2 networks 66 | def soft_update_params(net, target_net, tau): 67 | for param, target_param in zip(net.parameters(), target_net.parameters()): 68 | target_param.data.copy_( 69 | tau * param.data + (1 - tau) * target_param.data 70 | ) 71 | 72 | def agent_loc(frames): 73 | T, B, H, W, *_ = frames.shape 74 | agent_location = torch.flatten(frames, 2, 3) 75 | agent_location = agent_location[:,:,:,0] 76 | agent_location = (agent_location == 10).nonzero() #select object id 77 | agent_location = agent_location[:,2] 78 | agent_location = torch.cat(((agent_location//W).unsqueeze(-1), (agent_location%W).unsqueeze(-1)), dim=-1) 79 | agent_location = agent_location.view(-1).tolist() 80 | return agent_location 81 | 82 | COMPLETE_MOVEMENT = [ 83 | ['NOOP'], 84 | ['up'], 85 | ['down'], 86 | ['left'], 87 | ['left', 'A'], 88 | ['left', 'B'], 89 | ['left', 'A', 'B'], 90 | ['right'], 91 | ['right', 'A'], 92 | ['right', 'B'], 93 | ['right', 'A', 'B'], 94 | ['A'], 95 | ['B'], 96 | ['A', 'B'], 97 | ] 98 | 99 | shandle = logging.StreamHandler() 100 | shandle.setFormatter( 101 | logging.Formatter( 102 | '[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] ' 103 | '%(message)s')) 104 | log = logging.getLogger('torchbeast') 105 | log.propagate = False 106 | log.addHandler(shandle) 107 | log.setLevel(logging.INFO) 108 | 109 | Buffers = typing.Dict[str, typing.List[torch.Tensor]] 110 | 111 | 112 | def create_env(flags): 113 | if 'MiniGrid' in flags.env: 114 | return Minigrid2Image(wrappers.FullyObsWrapper(gym.make(flags.env))) 115 | elif 'Mario' in flags.env: 116 | env = atari_wrappers.wrap_pytorch( 117 | atari_wrappers.wrap_deepmind( 118 | atari_wrappers.make_atari(flags.env, noop=True), 119 | clip_rewards=False, 120 | frame_stack=True, 121 | scale=False, 122 | fire=True)) 123 | env = JoypadSpace(env, COMPLETE_MOVEMENT) 124 | return env 125 | else: 126 | env = atari_wrappers.wrap_pytorch( 127 | atari_wrappers.wrap_deepmind( 128 | atari_wrappers.make_atari(flags.env, noop=False), 129 | clip_rewards=False, 130 | frame_stack=True, 131 | scale=False, 132 | fire=False)) 133 | return env 134 | 135 | 136 | def get_batch(free_queue: mp.Queue, 137 | full_queue: mp.Queue, 138 | buffers: Buffers, 139 | initial_agent_state_buffers, 140 | initial_encoder_state_buffers, 141 | flags, 142 | timings, 143 | lock=threading.Lock()): 144 | with lock: 145 | timings.time('lock') 146 | indices = [full_queue.get() for _ in range(flags.batch_size)] 147 | timings.time('dequeue') 148 | batch = { 149 | key: torch.stack([buffers[key][m] for m in indices], dim=1) 150 | for key in buffers 151 | } 152 | initial_agent_state = ( 153 | torch.cat(ts, dim=1) 154 | for ts in zip(*[initial_agent_state_buffers[m] for m in indices]) 155 | ) 156 | initial_encoder_state = ( 157 | torch.cat(ts, dim=1) 158 | for ts in zip(*[initial_encoder_state_buffers[m] for m in indices]) 159 | ) 160 | timings.time('batch') 161 | for m in indices: 162 | free_queue.put(m) 163 | timings.time('enqueue') 164 | batch = { 165 | k: t.to(device=flags.device, non_blocking=True) 166 | for k, t in batch.items() 167 | } 168 | initial_agent_state = tuple(t.to(device=flags.device, non_blocking=True) 169 | for t in initial_agent_state) 170 | initial_encoder_state = tuple(t.to(device=flags.device, non_blocking=True) 171 | for t in initial_encoder_state) 172 | timings.time('device') 173 | return batch, initial_agent_state, initial_encoder_state 174 | 175 | def create_buffers(obs_shape, num_actions, flags) -> Buffers: 176 | T = flags.unroll_length 177 | specs = dict( 178 | frame=dict(size=(T + 1, *obs_shape), dtype=torch.uint8), 179 | reward=dict(size=(T + 1,), dtype=torch.float32), 180 | done=dict(size=(T + 1,), dtype=torch.uint8), 181 | episode_return=dict(size=(T + 1,), dtype=torch.float32), 182 | episode_step=dict(size=(T + 1,), dtype=torch.int32), 183 | last_action=dict(size=(T + 1,), dtype=torch.int64), 184 | policy_logits=dict(size=(T + 1, num_actions), dtype=torch.float32), 185 | baseline=dict(size=(T + 1,), dtype=torch.float32), 186 | action=dict(size=(T + 1,), dtype=torch.int64), 187 | episode_win=dict(size=(T + 1,), dtype=torch.int32), 188 | carried_obj=dict(size=(T + 1,), dtype=torch.int32), 189 | carried_col=dict(size=(T + 1,), dtype=torch.int32), 190 | partial_obs=dict(size=(T + 1, 7, 7, 3), dtype=torch.uint8), 191 | episode_state_count=dict(size=(T + 1, ), dtype=torch.float32), 192 | train_state_count=dict(size=(T + 1, ), dtype=torch.float32), 193 | partial_state_count=dict(size=(T + 1, ), dtype=torch.float32), 194 | encoded_state_count=dict(size=(T + 1, ), dtype=torch.float32), 195 | ) 196 | buffers: Buffers = {key: [] for key in specs} 197 | for _ in range(flags.num_buffers): 198 | for key in buffers: 199 | buffers[key].append(torch.empty(**specs[key]).share_memory_()) 200 | return buffers 201 | 202 | def create_heatmap_buffers(obs_shape): 203 | specs = [] 204 | for r in range(obs_shape[0]): 205 | for c in range(obs_shape[1]): 206 | specs.append(tuple([r, c])) 207 | buffers: Buffers = {key: torch.zeros(1).share_memory_() for key in specs} 208 | return buffers 209 | 210 | def act(i: int, free_queue: mp.Queue, full_queue: mp.Queue, 211 | model: torch.nn.Module, 212 | encoder: torch.nn.Module, 213 | buffers: Buffers, 214 | # inv_dynamics_buf: Buffers, 215 | # inv_dynamics_dict: dict, 216 | episode_state_count_dict: dict, train_state_count_dict: dict, 217 | partial_state_count_dict: dict, encoded_state_count_dict: dict, 218 | heatmap_dict: dict, 219 | heatmap_buffers: Buffers, 220 | initial_agent_state_buffers, 221 | initial_encoder_state_buffers, 222 | flags): 223 | try: 224 | log.info('Actor %i started.', i) 225 | timings = prof.Timings() 226 | 227 | gym_env = create_env(flags) 228 | seed = i ^ int.from_bytes(os.urandom(4), byteorder='little') 229 | gym_env.seed(seed) 230 | 231 | if flags.num_input_frames > 1: 232 | gym_env = FrameStack(gym_env, flags.num_input_frames) 233 | 234 | env = Environment(gym_env, fix_seed=flags.fix_seed, env_seed=flags.env_seed) 235 | 236 | env_output = env.initial() 237 | agent_state = model.initial_state(batch_size=1) 238 | encoder_state = encoder.initial_state(batch_size=1) 239 | agent_output, unused_state = model(env_output, agent_state) 240 | prev_env_output = None 241 | 242 | while True: 243 | index = free_queue.get() 244 | if index is None: 245 | break 246 | 247 | # Write old rollout end. 248 | for key in env_output: 249 | buffers[key][index][0, ...] = env_output[key] 250 | for key in agent_output: 251 | buffers[key][index][0, ...] = agent_output[key] 252 | for i, tensor in enumerate(agent_state): 253 | initial_agent_state_buffers[index][i][...] = tensor 254 | for i, tensor in enumerate(encoder_state): 255 | initial_encoder_state_buffers[index][i][...] = tensor 256 | 257 | 258 | # Update the episodic state counts 259 | episode_state_key = tuple(env_output['frame'].view(-1).tolist()) 260 | if episode_state_key in episode_state_count_dict: 261 | episode_state_count_dict[episode_state_key] += 1 262 | else: 263 | episode_state_count_dict.update({episode_state_key: 1}) 264 | buffers['episode_state_count'][index][0, ...] = \ 265 | torch.tensor(1 / np.sqrt(episode_state_count_dict.get(episode_state_key))) 266 | 267 | # Reset the episode state counts when the episode is over 268 | if env_output['done'][0][0]: 269 | for episode_state_key in episode_state_count_dict: 270 | episode_state_count_dict = dict() 271 | 272 | # Update the training state counts 273 | train_state_key = tuple(env_output['frame'].view(-1).tolist()) 274 | if train_state_key in train_state_count_dict: 275 | train_state_count_dict[train_state_key] += 1 276 | else: 277 | train_state_count_dict.update({train_state_key: 1}) 278 | buffers['train_state_count'][index][0, ...] = \ 279 | torch.tensor(1 / np.sqrt(train_state_count_dict.get(train_state_key))) 280 | partial_state_key = tuple(env_output['partial_obs'].view(-1).tolist()) 281 | if partial_state_key in partial_state_count_dict: 282 | partial_state_count_dict[partial_state_key] += 1 283 | else: 284 | partial_state_count_dict.update({partial_state_key: 1}) 285 | buffers['partial_state_count'][index][0, ...] = \ 286 | torch.tensor(1 / np.sqrt(partial_state_count_dict.get(partial_state_key))) 287 | # Update the agent position counts 288 | heatmap_key = tuple(agent_loc(env_output['frame'])) 289 | heatmap_buffers[heatmap_key] += 1 290 | 291 | # Do new rollout 292 | for t in range(flags.unroll_length): 293 | timings.reset() 294 | 295 | with torch.no_grad(): 296 | agent_output, agent_state = model(env_output, agent_state) 297 | _, encoder_state = encoder(env_output['partial_obs'], encoder_state, env_output['done']) 298 | 299 | timings.time('model') 300 | 301 | prev_env_output = copy.deepcopy(env_output) 302 | env_output = env.step(agent_output['action']) 303 | 304 | timings.time('step') 305 | 306 | for key in env_output: 307 | buffers[key][index][t + 1, ...] = env_output[key] 308 | 309 | for key in agent_output: 310 | buffers[key][index][t + 1, ...] = agent_output[key] 311 | 312 | # Update the episodic state counts 313 | episode_state_key = tuple(env_output['frame'].view(-1).tolist()) 314 | if episode_state_key in episode_state_count_dict: 315 | episode_state_count_dict[episode_state_key] += 1 316 | else: 317 | episode_state_count_dict.update({episode_state_key: 1}) 318 | buffers['episode_state_count'][index][t + 1, ...] = \ 319 | torch.tensor(1 / np.sqrt(episode_state_count_dict.get(episode_state_key))) 320 | 321 | # Reset the episode state counts when the episode is over 322 | if env_output['done'][0][0]: 323 | episode_state_count_dict = dict() 324 | 325 | # Update the training state counts 326 | train_state_key = tuple(env_output['frame'].view(-1).tolist()) 327 | if train_state_key in train_state_count_dict: 328 | train_state_count_dict[train_state_key] += 1 329 | else: 330 | train_state_count_dict.update({train_state_key: 1}) 331 | buffers['train_state_count'][index][t + 1, ...] = \ 332 | torch.tensor(1 / np.sqrt(train_state_count_dict.get(train_state_key))) 333 | partial_state_key = tuple(env_output['partial_obs'].view(-1).tolist()) 334 | if partial_state_key in partial_state_count_dict: 335 | partial_state_count_dict[partial_state_key] += 1 336 | else: 337 | partial_state_count_dict.update({partial_state_key: 1}) 338 | buffers['partial_state_count'][index][t + 1, ...] = \ 339 | torch.tensor(1 / np.sqrt(partial_state_count_dict.get(partial_state_key))) 340 | # Update the agent position counts 341 | heatmap_key = tuple(agent_loc(env_output['frame'])) 342 | heatmap_buffers[heatmap_key] += 1 343 | 344 | timings.time('write') 345 | full_queue.put(index) 346 | 347 | if i == 0: 348 | log.info('Actor %i: %s', i, timings.summary()) 349 | 350 | except KeyboardInterrupt: 351 | pass 352 | except Exception as e: 353 | logging.error('Exception in worker process %i', i) 354 | traceback.print_exc() 355 | print() 356 | raise e 357 | 358 | -------------------------------------------------------------------------------- /src/algos/bebold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | import threading 11 | import time 12 | import timeit 13 | import pprint 14 | import json 15 | 16 | import numpy as np 17 | 18 | import torch 19 | from torch import multiprocessing as mp 20 | from torch import nn 21 | from torch.nn import functional as F 22 | 23 | from src.core import file_writer 24 | from src.core import prof 25 | from src.core import vtrace 26 | 27 | import src.models as models 28 | import src.losses as losses 29 | 30 | from src.env_utils import FrameStack 31 | from src.utils import get_batch, log, create_env, create_buffers, act, create_heatmap_buffers 32 | 33 | MinigridPolicyNet = models.MinigridPolicyNet 34 | MinigridStateEmbeddingNet = models.MinigridStateEmbeddingNet 35 | MinigridMLPEmbeddingNet = models.MinigridMLPEmbeddingNet 36 | MinigridMLPTargetEmbeddingNet = models.MinigridMLPTargetEmbeddingNet 37 | 38 | def momentum_update(model, target, ema_momentum): 39 | ''' 40 | Update the key_encoder parameters through the momentum update: 41 | key_params = momentum * key_params + (1 - momentum) * query_params 42 | ''' 43 | # For each of the parameters in each encoder 44 | for p_m, p_t in zip(model.parameters(), target.parameters()): 45 | p_m.data = p_m.data * ema_momentum + p_t.detach().data * (1. - ema_momentum) 46 | # For each of the buffers in each encoder 47 | for b_m, b_t in zip(model.buffers(), target.buffers()): 48 | b_m.data = b_m.data * ema_momentum + b_t.detach().data * (1. - ema_momentum) 49 | 50 | def learn(actor_model, 51 | model, 52 | random_target_network, 53 | predictor_network, 54 | actor_encoder, 55 | encoder, 56 | batch, 57 | initial_agent_state, 58 | initial_encoder_state, 59 | optimizer, 60 | predictor_optimizer, 61 | scheduler, 62 | flags, 63 | frames=None, 64 | lock=threading.Lock()): 65 | """Performs a learning (optimization) step.""" 66 | with lock: 67 | count_rewards = torch.ones((flags.unroll_length, flags.batch_size), 68 | dtype=torch.float32).to(device=flags.device) 69 | # Use the scale of square root N 70 | count_rewards = batch['episode_state_count'][1:].float().to(device=flags.device) 71 | 72 | encoded_states, unused_state = encoder(batch['partial_obs'].to(device=flags.device), initial_encoder_state, batch['done']) 73 | random_embedding_next, unused_state = random_target_network(encoded_states[1:].detach(), initial_agent_state) 74 | predicted_embedding_next, unused_state = predictor_network(encoded_states[1:].detach(), initial_agent_state) 75 | random_embedding, unused_state = random_target_network(encoded_states[:-1].detach(), initial_agent_state) 76 | predicted_embedding, unused_state = predictor_network(encoded_states[:-1].detach(), initial_agent_state) 77 | 78 | intrinsic_rewards_next = torch.norm(predicted_embedding_next.detach() - random_embedding_next.detach(), dim=2, p=2) 79 | intrinsic_rewards = torch.norm(predicted_embedding.detach() - random_embedding.detach(), dim=2, p=2) 80 | intrinsic_rewards = torch.clamp(intrinsic_rewards_next - flags.scale_fac * intrinsic_rewards, min=0) 81 | intrinsic_rewards *= (count_rewards == 1).float() 82 | 83 | intrinsic_reward_coef = flags.intrinsic_reward_coef 84 | intrinsic_rewards *= count_rewards * intrinsic_reward_coef 85 | 86 | num_samples = flags.unroll_length * flags.batch_size 87 | actions_flat = batch['action'][1:].reshape(num_samples).cpu().detach().numpy() 88 | intrinsic_rewards_flat = intrinsic_rewards.reshape(num_samples).cpu().detach().numpy() 89 | rnd_loss = flags.rnd_loss_coef * \ 90 | losses.compute_rnd_loss(predicted_embedding_next, random_embedding_next.detach()) 91 | 92 | learner_outputs, unused_state = model(batch, initial_agent_state) 93 | 94 | bootstrap_value = learner_outputs['baseline'][-1] 95 | 96 | batch = {key: tensor[1:] for key, tensor in batch.items()} 97 | learner_outputs = { 98 | key: tensor[:-1] 99 | for key, tensor in learner_outputs.items() 100 | } 101 | 102 | rewards = batch['reward'] 103 | 104 | if flags.no_reward: 105 | total_rewards = intrinsic_rewards 106 | else: 107 | total_rewards = rewards + intrinsic_rewards 108 | clipped_rewards = torch.clamp(total_rewards, -1, 1) 109 | 110 | discounts = (~batch['done']).float() * flags.discounting 111 | 112 | vtrace_returns = vtrace.from_logits( 113 | behavior_policy_logits=batch['policy_logits'], 114 | target_policy_logits=learner_outputs['policy_logits'], 115 | actions=batch['action'], 116 | discounts=discounts, 117 | rewards=clipped_rewards, 118 | values=learner_outputs['baseline'], 119 | bootstrap_value=bootstrap_value) 120 | 121 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'], 122 | batch['action'], 123 | vtrace_returns.pg_advantages) 124 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss( 125 | vtrace_returns.vs - learner_outputs['baseline']) 126 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss( 127 | learner_outputs['policy_logits']) 128 | 129 | total_loss = pg_loss + baseline_loss + entropy_loss + rnd_loss 130 | 131 | episode_returns = batch['episode_return'][batch['done']] 132 | stats = { 133 | 'mean_episode_return': torch.mean(episode_returns).item(), 134 | 'total_loss': total_loss.item(), 135 | 'pg_loss': pg_loss.item(), 136 | 'baseline_loss': baseline_loss.item(), 137 | 'entropy_loss': entropy_loss.item(), 138 | 'rnd_loss': rnd_loss.item(), 139 | 'mean_rewards': torch.mean(rewards).item(), 140 | 'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(), 141 | 'mean_total_rewards': torch.mean(total_rewards).item(), 142 | } 143 | 144 | scheduler.step() 145 | optimizer.zero_grad() 146 | predictor_optimizer.zero_grad() 147 | total_loss.backward() 148 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm) 149 | nn.utils.clip_grad_norm_(predictor_network.parameters(), flags.max_grad_norm) 150 | optimizer.step() 151 | predictor_optimizer.step() 152 | 153 | actor_model.load_state_dict(model.state_dict()) 154 | actor_encoder.load_state_dict(encoder.state_dict()) 155 | return stats 156 | 157 | 158 | def train(flags): 159 | if flags.xpid is None: 160 | flags.xpid = flags.env + '-bebold-%s' % time.strftime('%Y%m%d-%H%M%S') 161 | plogger = file_writer.FileWriter( 162 | xpid=flags.xpid, 163 | xp_args=flags.__dict__, 164 | rootdir=flags.savedir, 165 | ) 166 | 167 | checkpointpath = os.path.expandvars( 168 | os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 169 | 'model.tar'))) 170 | 171 | T = flags.unroll_length 172 | B = flags.batch_size 173 | 174 | flags.device = None 175 | if not flags.disable_cuda and torch.cuda.is_available(): 176 | log.info('Using CUDA.') 177 | flags.device = torch.device('cuda') 178 | else: 179 | log.info('Not using CUDA.') 180 | flags.device = torch.device('cpu') 181 | 182 | env = create_env(flags) 183 | if flags.num_input_frames > 1: 184 | env = FrameStack(env, flags.num_input_frames) 185 | 186 | if 'MiniGrid' in flags.env: 187 | if flags.use_fullobs_policy: 188 | raise Exception('We have not implemented full ob policy!') 189 | else: 190 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n) 191 | random_target_network = MinigridMLPTargetEmbeddingNet().to(device=flags.device) 192 | predictor_network = MinigridMLPEmbeddingNet().to(device=flags.device) 193 | encoder = MinigridStateEmbeddingNet(env.observation_space.shape, flags.use_lstm) 194 | else: 195 | raise Exception('Only MiniGrid is suppported Now!') 196 | 197 | momentum_update(encoder.feat_extract, model.feat_extract, 0) 198 | momentum_update(encoder.fc, model.fc, 0) 199 | if flags.use_lstm: 200 | momentum_update(encoder.core, model.core, 0) 201 | 202 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags) 203 | heatmap_buffers = create_heatmap_buffers(env.observation_space.shape) 204 | model.share_memory() 205 | encoder.share_memory() 206 | 207 | initial_agent_state_buffers = [] 208 | for _ in range(flags.num_buffers): 209 | state = model.initial_state(batch_size=1) 210 | for t in state: 211 | t.share_memory_() 212 | initial_agent_state_buffers.append(state) 213 | initial_encoder_state_buffers = [] 214 | for _ in range(flags.num_buffers): 215 | state = encoder.initial_state(batch_size=1) 216 | for t in state: 217 | t.share_memory_() 218 | initial_encoder_state_buffers.append(state) 219 | 220 | actor_processes = [] 221 | ctx = mp.get_context('fork') 222 | free_queue = ctx.Queue() 223 | full_queue = ctx.Queue() 224 | 225 | episode_state_count_dict = dict() 226 | train_state_count_dict = dict() 227 | partial_state_count_dict = dict() 228 | encoded_state_count_dict = dict() 229 | heatmap_dict = dict() 230 | for i in range(flags.num_actors): 231 | actor = ctx.Process( 232 | target=act, 233 | args=(i, free_queue, full_queue, model, encoder, buffers, 234 | episode_state_count_dict, train_state_count_dict, partial_state_count_dict, encoded_state_count_dict, 235 | heatmap_dict, heatmap_buffers, initial_agent_state_buffers, initial_encoder_state_buffers, flags)) 236 | actor.start() 237 | actor_processes.append(actor) 238 | 239 | if 'MiniGrid' in flags.env: 240 | if flags.use_fullobs_policy: 241 | raise Exception('We have not implemented full ob policy!') 242 | else: 243 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\ 244 | .to(device=flags.device) 245 | learner_encoder = MinigridStateEmbeddingNet(env.observation_space.shape, flags.use_lstm)\ 246 | .to(device=flags.device) 247 | else: 248 | raise Exception('Only MiniGrid is suppported Now!') 249 | learner_encoder.load_state_dict(encoder.state_dict()) 250 | 251 | optimizer = torch.optim.RMSprop( 252 | learner_model.parameters(), 253 | lr=flags.learning_rate, 254 | momentum=flags.momentum, 255 | eps=flags.epsilon, 256 | alpha=flags.alpha) 257 | 258 | predictor_optimizer = torch.optim.Adam( 259 | predictor_network.parameters(), 260 | lr=flags.predictor_learning_rate) 261 | 262 | def lr_lambda(epoch): 263 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames 264 | 265 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 266 | 267 | logger = logging.getLogger('logfile') 268 | stat_keys = [ 269 | 'total_loss', 270 | 'mean_episode_return', 271 | 'pg_loss', 272 | 'baseline_loss', 273 | 'entropy_loss', 274 | 'rnd_loss', 275 | 'mean_rewards', 276 | 'mean_intrinsic_rewards', 277 | 'mean_total_rewards', 278 | ] 279 | 280 | logger.info('# Step\t%s', '\t'.join(stat_keys)) 281 | 282 | frames, stats = 0, {} 283 | 284 | def batch_and_learn(i, lock=threading.Lock()): 285 | """Thread target for the learning process.""" 286 | nonlocal frames, stats 287 | timings = prof.Timings() 288 | while frames < flags.total_frames: 289 | timings.reset() 290 | batch, agent_state, encoder_state = get_batch(free_queue, full_queue, buffers, 291 | initial_agent_state_buffers, initial_encoder_state_buffers, flags, timings) 292 | stats = learn(model, learner_model, random_target_network, predictor_network, 293 | encoder, learner_encoder, batch, agent_state, encoder_state, optimizer, 294 | predictor_optimizer, scheduler, flags, frames=frames) 295 | timings.time('learn') 296 | with lock: 297 | to_log = dict(frames=frames) 298 | to_log.update({k: stats[k] for k in stat_keys}) 299 | plogger.log(to_log) 300 | frames += T * B 301 | 302 | if i == 0: 303 | log.info('Batch and learn: %s', timings.summary()) 304 | 305 | for m in range(flags.num_buffers): 306 | free_queue.put(m) 307 | 308 | threads = [] 309 | for i in range(flags.num_threads): 310 | thread = threading.Thread( 311 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,)) 312 | thread.start() 313 | threads.append(thread) 314 | 315 | 316 | def checkpoint(frames): 317 | if flags.disable_checkpoint: 318 | return 319 | checkpointpath = os.path.expandvars(os.path.expanduser( 320 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model_'+str(frames)+'.tar'))) 321 | log.info('Saving checkpoint to %s', checkpointpath) 322 | torch.save({ 323 | 'model_state_dict': model.state_dict(), 324 | 'encoder': encoder.state_dict(), 325 | 'random_target_network_state_dict': random_target_network.state_dict(), 326 | 'predictor_network_state_dict': predictor_network.state_dict(), 327 | 'optimizer_state_dict': optimizer.state_dict(), 328 | 'predictor_optimizer_state_dict': predictor_optimizer.state_dict(), 329 | 'scheduler_state_dict': scheduler.state_dict(), 330 | 'flags': vars(flags), 331 | }, checkpointpath) 332 | 333 | timer = timeit.default_timer 334 | try: 335 | last_checkpoint_time = timer() 336 | while frames < flags.total_frames: 337 | start_frames = frames 338 | start_time = timer() 339 | time.sleep(5) 340 | 341 | if timer() - last_checkpoint_time > flags.save_interval * 60: 342 | checkpoint(frames) 343 | save_heatmap(frames) 344 | last_checkpoint_time = timer() 345 | 346 | fps = (frames - start_frames) / (timer() - start_time) 347 | 348 | if stats.get('episode_returns', None): 349 | mean_return = 'Return per episode: %.1f. ' % stats[ 350 | 'mean_episode_return'] 351 | else: 352 | mean_return = '' 353 | 354 | total_loss = stats.get('total_loss', float('inf')) 355 | if stats: 356 | log.info('After %i frames: loss %f @ %.1f fps. Mean Return %.1f. \n Stats \n %s', \ 357 | frames, total_loss, fps, stats['mean_episode_return'], pprint.pformat(stats)) 358 | 359 | except KeyboardInterrupt: 360 | return 361 | else: 362 | for thread in threads: 363 | thread.join() 364 | log.info('Learning finished after %d frames.', frames) 365 | finally: 366 | for _ in range(flags.num_actors): 367 | free_queue.put(None) 368 | for actor in actor_processes: 369 | actor.join(timeout=1) 370 | 371 | checkpoint(frames) 372 | plogger.close() 373 | 374 | -------------------------------------------------------------------------------- /src/algos/curiosity.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import sys 10 | import threading 11 | import time 12 | import timeit 13 | import pprint 14 | 15 | import numpy as np 16 | 17 | import torch 18 | from torch import multiprocessing as mp 19 | from torch import nn 20 | from torch.nn import functional as F 21 | 22 | from src.core import file_writer 23 | from src.core import prof 24 | from src.core import vtrace 25 | 26 | import src.models as models 27 | import src.losses as losses 28 | 29 | from src.env_utils import FrameStack 30 | from src.utils import get_batch, log, create_env, create_buffers, act 31 | 32 | MinigridStateEmbeddingNet = models.MinigridStateEmbeddingNet 33 | MinigridForwardDynamicsNet = models.MinigridForwardDynamicsNet 34 | MinigridInverseDynamicsNet = models.MinigridInverseDynamicsNet 35 | MinigridPolicyNet = models.MinigridPolicyNet 36 | 37 | def learn(actor_model, 38 | model, 39 | state_embedding_model, 40 | forward_dynamics_model, 41 | inverse_dynamics_model, 42 | batch, 43 | initial_agent_state, 44 | optimizer, 45 | state_embedding_optimizer, 46 | forward_dynamics_optimizer, 47 | inverse_dynamics_optimizer, 48 | scheduler, 49 | flags, 50 | frames=None, 51 | lock=threading.Lock()): 52 | """Performs a learning (optimization) step.""" 53 | with lock: 54 | if flags.use_fullobs_intrinsic: 55 | state_emb = state_embedding_model(batch, next_state=False)\ 56 | .reshape(flags.unroll_length, flags.batch_size, 128) 57 | next_state_emb = state_embedding_model(batch, next_state=True)\ 58 | .reshape(flags.unroll_length, flags.batch_size, 128) 59 | else: 60 | state_emb = state_embedding_model(batch['partial_obs'][:-1].to(device=flags.device)) 61 | next_state_emb = state_embedding_model(batch['partial_obs'][1:].to(device=flags.device)) 62 | 63 | pred_next_state_emb = forward_dynamics_model(\ 64 | state_emb, batch['action'][1:].to(device=flags.device)) 65 | pred_actions = inverse_dynamics_model(state_emb, next_state_emb) 66 | entropy_emb_actions = losses.compute_entropy_loss(pred_actions) 67 | 68 | intrinsic_rewards = torch.norm(pred_next_state_emb - next_state_emb, dim=2, p=2) 69 | 70 | intrinsic_reward_coef = flags.intrinsic_reward_coef 71 | intrinsic_rewards *= intrinsic_reward_coef 72 | 73 | forward_dynamics_loss = flags.forward_loss_coef * \ 74 | losses.compute_forward_dynamics_loss(pred_next_state_emb, next_state_emb) 75 | 76 | inverse_dynamics_loss = flags.inverse_loss_coef * \ 77 | losses.compute_inverse_dynamics_loss(pred_actions, batch['action'][1:]) 78 | 79 | num_samples = flags.unroll_length * flags.batch_size 80 | actions_flat = batch['action'][1:].reshape(num_samples).cpu().detach().numpy() 81 | intrinsic_rewards_flat = intrinsic_rewards.reshape(num_samples).cpu().detach().numpy() 82 | 83 | 84 | learner_outputs, unused_state = model(batch, initial_agent_state) 85 | 86 | bootstrap_value = learner_outputs['baseline'][-1] 87 | 88 | batch = {key: tensor[1:] for key, tensor in batch.items()} 89 | learner_outputs = { 90 | key: tensor[:-1] 91 | for key, tensor in learner_outputs.items() 92 | } 93 | 94 | actions = batch['action'].reshape(flags.unroll_length * flags.batch_size).cpu().numpy() 95 | action_percentage = [0 for _ in range(model.num_actions)] 96 | for i in range(model.num_actions): 97 | action_percentage[i] = np.sum([a == i for a in actions]) / len(actions) 98 | 99 | rewards = batch['reward'] 100 | 101 | if flags.no_reward: 102 | total_rewards = intrinsic_rewards 103 | else: 104 | total_rewards = rewards + intrinsic_rewards 105 | clipped_rewards = torch.clamp(total_rewards, -1, 1) 106 | 107 | discounts = (~batch['done']).float() * flags.discounting 108 | 109 | vtrace_returns = vtrace.from_logits( 110 | behavior_policy_logits=batch['policy_logits'], 111 | target_policy_logits=learner_outputs['policy_logits'], 112 | actions=batch['action'], 113 | discounts=discounts, 114 | rewards=clipped_rewards, 115 | values=learner_outputs['baseline'], 116 | bootstrap_value=bootstrap_value) 117 | 118 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'], 119 | batch['action'], 120 | vtrace_returns.pg_advantages) 121 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss( 122 | vtrace_returns.vs - learner_outputs['baseline']) 123 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss( 124 | learner_outputs['policy_logits']) 125 | 126 | total_loss = pg_loss + baseline_loss + entropy_loss \ 127 | + forward_dynamics_loss + inverse_dynamics_loss 128 | 129 | episode_returns = batch['episode_return'][batch['done']] 130 | episode_lengths = batch['episode_step'][batch['done']] 131 | episode_wins = batch['episode_win'][batch['done']] 132 | stats = { 133 | 'mean_episode_return': torch.mean(episode_returns).item(), 134 | 'total_loss': total_loss.item(), 135 | 'pg_loss': pg_loss.item(), 136 | 'baseline_loss': baseline_loss.item(), 137 | 'entropy_loss': entropy_loss.item(), 138 | 'forward_dynamics_loss': forward_dynamics_loss.item(), 139 | 'inverse_dynamics_loss': inverse_dynamics_loss.item(), 140 | 'mean_rewards': torch.mean(rewards).item(), 141 | 'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(), 142 | 'mean_total_rewards': torch.mean(total_rewards).item(), 143 | } 144 | 145 | scheduler.step() 146 | optimizer.zero_grad() 147 | state_embedding_optimizer.zero_grad() 148 | forward_dynamics_optimizer.zero_grad() 149 | inverse_dynamics_optimizer.zero_grad() 150 | total_loss.backward() 151 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm) 152 | nn.utils.clip_grad_norm_(state_embedding_model.parameters(), flags.max_grad_norm) 153 | nn.utils.clip_grad_norm_(forward_dynamics_model.parameters(), flags.max_grad_norm) 154 | nn.utils.clip_grad_norm_(inverse_dynamics_model.parameters(), flags.max_grad_norm) 155 | optimizer.step() 156 | state_embedding_optimizer.step() 157 | forward_dynamics_optimizer.step() 158 | inverse_dynamics_optimizer.step() 159 | 160 | actor_model.load_state_dict(model.state_dict()) 161 | return stats 162 | 163 | 164 | def train(flags): 165 | if flags.xpid is None: 166 | flags.xpid = 'curiosity-%s' % time.strftime('%Y%m%d-%H%M%S') 167 | plogger = file_writer.FileWriter( 168 | xpid=flags.xpid, 169 | xp_args=flags.__dict__, 170 | rootdir=flags.savedir, 171 | ) 172 | 173 | checkpointpath = os.path.expandvars(os.path.expanduser( 174 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model.tar'))) 175 | 176 | T = flags.unroll_length 177 | B = flags.batch_size 178 | 179 | flags.device = None 180 | if not flags.disable_cuda and torch.cuda.is_available(): 181 | log.info('Using CUDA.') 182 | flags.device = torch.device('cuda') 183 | else: 184 | log.info('Not using CUDA.') 185 | flags.device = torch.device('cpu') 186 | 187 | env = create_env(flags) 188 | if flags.num_input_frames > 1: 189 | env = FrameStack(env, flags.num_input_frames) 190 | 191 | if 'MiniGrid' in flags.env: 192 | if flags.use_fullobs_policy: 193 | model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n) 194 | else: 195 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n) 196 | if flags.use_fullobs_intrinsic: 197 | state_embedding_model = FullObsMinigridStateEmbeddingNet(env.observation_space.shape)\ 198 | .to(device=flags.device) 199 | else: 200 | state_embedding_model = MinigridStateEmbeddingNet(env.observation_space.shape)\ 201 | .to(device=flags.device) 202 | forward_dynamics_model = MinigridForwardDynamicsNet(env.action_space.n)\ 203 | .to(device=flags.device) 204 | inverse_dynamics_model = MinigridInverseDynamicsNet(env.action_space.n)\ 205 | .to(device=flags.device) 206 | else: 207 | model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n) 208 | state_embedding_model = MarioDoomStateEmbeddingNet(env.observation_space.shape)\ 209 | .to(device=flags.device) 210 | forward_dynamics_model = MarioDoomForwardDynamicsNet(env.action_space.n)\ 211 | .to(device=flags.device) 212 | inverse_dynamics_model = MarioDoomInverseDynamicsNet(env.action_space.n)\ 213 | .to(device=flags.device) 214 | 215 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags) 216 | model.share_memory() 217 | 218 | initial_agent_state_buffers = [] 219 | for _ in range(flags.num_buffers): 220 | state = model.initial_state(batch_size=1) 221 | for t in state: 222 | t.share_memory_() 223 | initial_agent_state_buffers.append(state) 224 | 225 | actor_processes = [] 226 | ctx = mp.get_context('fork') 227 | free_queue = ctx.SimpleQueue() 228 | full_queue = ctx.SimpleQueue() 229 | 230 | episode_state_count_dict = dict() 231 | train_state_count_dict = dict() 232 | for i in range(flags.num_actors): 233 | actor = ctx.Process( 234 | target=act, 235 | args=(i, free_queue, full_queue, model, buffers, 236 | episode_state_count_dict, train_state_count_dict, 237 | initial_agent_state_buffers, flags)) 238 | actor.start() 239 | actor_processes.append(actor) 240 | 241 | if 'MiniGrid' in flags.env: 242 | if flags.use_fullobs_policy: 243 | learner_model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)\ 244 | .to(device=flags.device) 245 | else: 246 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\ 247 | .to(device=flags.device) 248 | else: 249 | learner_model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)\ 250 | .to(device=flags.device) 251 | 252 | optimizer = torch.optim.RMSprop( 253 | learner_model.parameters(), 254 | lr=flags.learning_rate, 255 | momentum=flags.momentum, 256 | eps=flags.epsilon, 257 | alpha=flags.alpha) 258 | 259 | state_embedding_optimizer = torch.optim.RMSprop( 260 | state_embedding_model.parameters(), 261 | lr=flags.learning_rate, 262 | momentum=flags.momentum, 263 | eps=flags.epsilon, 264 | alpha=flags.alpha) 265 | 266 | inverse_dynamics_optimizer = torch.optim.RMSprop( 267 | inverse_dynamics_model.parameters(), 268 | lr=flags.learning_rate, 269 | momentum=flags.momentum, 270 | eps=flags.epsilon, 271 | alpha=flags.alpha) 272 | 273 | forward_dynamics_optimizer = torch.optim.RMSprop( 274 | forward_dynamics_model.parameters(), 275 | lr=flags.learning_rate, 276 | momentum=flags.momentum, 277 | eps=flags.epsilon, 278 | alpha=flags.alpha) 279 | 280 | 281 | def lr_lambda(epoch): 282 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames 283 | 284 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 285 | 286 | logger = logging.getLogger('logfile') 287 | stat_keys = [ 288 | 'total_loss', 289 | 'mean_episode_return', 290 | 'pg_loss', 291 | 'baseline_loss', 292 | 'entropy_loss', 293 | 'forward_dynamics_loss', 294 | 'inverse_dynamics_loss', 295 | 'mean_rewards', 296 | 'mean_intrinsic_rewards', 297 | 'mean_total_rewards', 298 | ] 299 | 300 | logger.info('# Step\t%s', '\t'.join(stat_keys)) 301 | 302 | frames, stats = 0, {} 303 | 304 | 305 | def batch_and_learn(i, lock=threading.Lock()): 306 | """Thread target for the learning process.""" 307 | nonlocal frames, stats 308 | timings = prof.Timings() 309 | while frames < flags.total_frames: 310 | timings.reset() 311 | batch, agent_state = get_batch(free_queue, full_queue, buffers, 312 | initial_agent_state_buffers, flags, timings) 313 | stats = learn(model, learner_model, state_embedding_model, forward_dynamics_model, 314 | inverse_dynamics_model, batch, agent_state, optimizer, 315 | state_embedding_optimizer, forward_dynamics_optimizer, 316 | inverse_dynamics_optimizer, scheduler, flags, frames=frames) 317 | timings.time('learn') 318 | with lock: 319 | to_log = dict(frames=frames) 320 | to_log.update({k: stats[k] for k in stat_keys}) 321 | plogger.log(to_log) 322 | frames += T * B 323 | 324 | if i == 0: 325 | log.info('Batch and learn: %s', timings.summary()) 326 | 327 | for m in range(flags.num_buffers): 328 | free_queue.put(m) 329 | 330 | threads = [] 331 | for i in range(flags.num_threads): 332 | thread = threading.Thread( 333 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,)) 334 | thread.start() 335 | threads.append(thread) 336 | 337 | 338 | def checkpoint(frames): 339 | if flags.disable_checkpoint: 340 | return 341 | checkpointpath = os.path.expandvars( 342 | os.path.expanduser('%s/%s/%s' % (flags.savedir, flags.xpid, 343 | 'model.tar'))) 344 | log.info('Saving checkpoint to %s', checkpointpath) 345 | torch.save({ 346 | 'model_state_dict': model.state_dict(), 347 | 'state_embedding_model_state_dict': state_embedding_model.state_dict(), 348 | 'forward_dynamics_model_state_dict': forward_dynamics_model.state_dict(), 349 | 'inverse_dynamics_model_state_dict': inverse_dynamics_model.state_dict(), 350 | 'optimizer_state_dict': optimizer.state_dict(), 351 | 'state_embedding_optimizer_state_dict': state_embedding_optimizer.state_dict(), 352 | 'forward_dynamics_optimizer_state_dict': forward_dynamics_optimizer.state_dict(), 353 | 'inverse_dynamics_optimizer_state_dict': inverse_dynamics_optimizer.state_dict(), 354 | 'scheduler_state_dict': scheduler.state_dict(), 355 | 'flags': vars(flags), 356 | }, checkpointpath) 357 | 358 | timer = timeit.default_timer 359 | try: 360 | last_checkpoint_time = timer() 361 | while frames < flags.total_frames: 362 | start_frames = frames 363 | start_time = timer() 364 | time.sleep(5) 365 | 366 | if timer() - last_checkpoint_time > flags.save_interval * 60: 367 | checkpoint(frames) 368 | last_checkpoint_time = timer() 369 | 370 | fps = (frames - start_frames) / (timer() - start_time) 371 | 372 | if stats.get('episode_returns', None): 373 | mean_return = 'Return per episode: %.1f. ' % stats[ 374 | 'mean_episode_return'] 375 | else: 376 | mean_return = '' 377 | 378 | total_loss = stats.get('total_loss', float('inf')) 379 | if stats: 380 | log.info('After %i frames: loss %f @ %.1f fps. Mean Return %.1f. \n Stats \n %s', \ 381 | frames, total_loss, fps, stats['mean_episode_return'], pprint.pformat(stats)) 382 | 383 | except KeyboardInterrupt: 384 | return 385 | else: 386 | for thread in threads: 387 | thread.join() 388 | log.info('Learning finished after %d frames.', frames) 389 | finally: 390 | for _ in range(flags.num_actors): 391 | free_queue.put(None) 392 | for actor in actor_processes: 393 | actor.join(timeout=1) 394 | 395 | checkpoint(frames) 396 | plogger.close() 397 | 398 | -------------------------------------------------------------------------------- /src/algos/ride.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import threading 10 | import time 11 | import timeit 12 | import pprint 13 | import json 14 | 15 | import numpy as np 16 | 17 | import torch 18 | from torch import multiprocessing as mp 19 | from torch import nn 20 | from torch.nn import functional as F 21 | 22 | from src.core import file_writer 23 | from src.core import prof 24 | from src.core import vtrace 25 | 26 | import src.models as models 27 | import src.losses as losses 28 | 29 | from src.env_utils import FrameStack 30 | from src.utils import get_batch, log, create_env, create_buffers, act, create_heatmap_buffers 31 | 32 | MinigridStateEmbeddingNet = models.MinigridStateEmbeddingNet 33 | MinigridForwardDynamicsNet = models.MinigridForwardDynamicsNet 34 | MinigridInverseDynamicsNet = models.MinigridInverseDynamicsNet 35 | MinigridPolicyNet = models.MinigridPolicyNet 36 | 37 | 38 | def learn(actor_model, 39 | model, 40 | state_embedding_model, 41 | forward_dynamics_model, 42 | inverse_dynamics_model, 43 | batch, 44 | initial_agent_state, 45 | optimizer, 46 | state_embedding_optimizer, 47 | forward_dynamics_optimizer, 48 | inverse_dynamics_optimizer, 49 | scheduler, 50 | flags, 51 | frames=None, 52 | lock=threading.Lock()): 53 | """Performs a learning (optimization) step.""" 54 | with lock: 55 | count_rewards = torch.ones((flags.unroll_length, flags.batch_size), 56 | dtype=torch.float32).to(device=flags.device) 57 | count_rewards = batch['episode_state_count'][1:].float().to(device=flags.device) 58 | 59 | if flags.use_fullobs_intrinsic: 60 | state_emb = state_embedding_model(batch, next_state=False)\ 61 | .reshape(flags.unroll_length, flags.batch_size, 128) 62 | next_state_emb = state_embedding_model(batch, next_state=True)\ 63 | .reshape(flags.unroll_length, flags.batch_size, 128) 64 | else: 65 | state_emb = state_embedding_model(batch['partial_obs'][:-1].to(device=flags.device)) 66 | next_state_emb = state_embedding_model(batch['partial_obs'][1:].to(device=flags.device)) 67 | 68 | pred_next_state_emb = forward_dynamics_model( 69 | state_emb, batch['action'][1:].to(device=flags.device)) 70 | pred_actions = inverse_dynamics_model(state_emb, next_state_emb) 71 | 72 | control_rewards = torch.norm(next_state_emb - state_emb, dim=2, p=2) 73 | 74 | intrinsic_rewards = count_rewards * control_rewards 75 | 76 | intrinsic_reward_coef = flags.intrinsic_reward_coef 77 | intrinsic_rewards *= intrinsic_reward_coef 78 | 79 | forward_dynamics_loss = flags.forward_loss_coef * \ 80 | losses.compute_forward_dynamics_loss(pred_next_state_emb, next_state_emb) 81 | 82 | inverse_dynamics_loss = flags.inverse_loss_coef * \ 83 | losses.compute_inverse_dynamics_loss(pred_actions, batch['action'][1:]) 84 | 85 | learner_outputs, unused_state = model(batch, initial_agent_state) 86 | 87 | bootstrap_value = learner_outputs['baseline'][-1] 88 | 89 | batch = {key: tensor[1:] for key, tensor in batch.items()} 90 | learner_outputs = { 91 | key: tensor[:-1] 92 | for key, tensor in learner_outputs.items() 93 | } 94 | 95 | 96 | rewards = batch['reward'] 97 | if flags.no_reward: 98 | total_rewards = intrinsic_rewards 99 | else: 100 | total_rewards = rewards + intrinsic_rewards 101 | clipped_rewards = torch.clamp(total_rewards, -1, 1) 102 | 103 | discounts = (~batch['done']).float() * flags.discounting 104 | 105 | vtrace_returns = vtrace.from_logits( 106 | behavior_policy_logits=batch['policy_logits'], 107 | target_policy_logits=learner_outputs['policy_logits'], 108 | actions=batch['action'], 109 | discounts=discounts, 110 | rewards=clipped_rewards, 111 | values=learner_outputs['baseline'], 112 | bootstrap_value=bootstrap_value) 113 | 114 | pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'], 115 | batch['action'], 116 | vtrace_returns.pg_advantages) 117 | baseline_loss = flags.baseline_cost * losses.compute_baseline_loss( 118 | vtrace_returns.vs - learner_outputs['baseline']) 119 | entropy_loss = flags.entropy_cost * losses.compute_entropy_loss( 120 | learner_outputs['policy_logits']) 121 | 122 | total_loss = pg_loss + baseline_loss + entropy_loss + \ 123 | forward_dynamics_loss + inverse_dynamics_loss 124 | 125 | episode_returns = batch['episode_return'][batch['done']] 126 | stats = { 127 | 'mean_episode_return': torch.mean(episode_returns).item(), 128 | 'total_loss': total_loss.item(), 129 | 'pg_loss': pg_loss.item(), 130 | 'baseline_loss': baseline_loss.item(), 131 | 'entropy_loss': entropy_loss.item(), 132 | 'mean_rewards': torch.mean(rewards).item(), 133 | 'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(), 134 | 'mean_total_rewards': torch.mean(total_rewards).item(), 135 | 'mean_control_rewards': torch.mean(control_rewards).item(), 136 | 'mean_count_rewards': torch.mean(count_rewards).item(), 137 | 'forward_dynamics_loss': forward_dynamics_loss.item(), 138 | 'inverse_dynamics_loss': inverse_dynamics_loss.item(), 139 | } 140 | 141 | scheduler.step() 142 | optimizer.zero_grad() 143 | state_embedding_optimizer.zero_grad() 144 | forward_dynamics_optimizer.zero_grad() 145 | inverse_dynamics_optimizer.zero_grad() 146 | total_loss.backward() 147 | nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm) 148 | nn.utils.clip_grad_norm_(state_embedding_model.parameters(), flags.max_grad_norm) 149 | nn.utils.clip_grad_norm_(forward_dynamics_model.parameters(), flags.max_grad_norm) 150 | nn.utils.clip_grad_norm_(inverse_dynamics_model.parameters(), flags.max_grad_norm) 151 | optimizer.step() 152 | state_embedding_optimizer.step() 153 | forward_dynamics_optimizer.step() 154 | inverse_dynamics_optimizer.step() 155 | 156 | actor_model.load_state_dict(model.state_dict()) 157 | return stats 158 | 159 | 160 | def train(flags): 161 | if flags.xpid is None: 162 | flags.xpid = 'ride-%s' % time.strftime('%Y%m%d-%H%M%S') 163 | plogger = file_writer.FileWriter( 164 | xpid=flags.xpid, 165 | xp_args=flags.__dict__, 166 | rootdir=flags.savedir, 167 | ) 168 | checkpointpath = os.path.expandvars(os.path.expanduser( 169 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model.tar'))) 170 | 171 | T = flags.unroll_length 172 | B = flags.batch_size 173 | 174 | flags.device = None 175 | if not flags.disable_cuda and torch.cuda.is_available(): 176 | log.info('Using CUDA.') 177 | flags.device = torch.device('cuda') 178 | else: 179 | log.info('Not using CUDA.') 180 | flags.device = torch.device('cpu') 181 | 182 | env = create_env(flags) 183 | if flags.num_input_frames > 1: 184 | env = FrameStack(env, flags.num_input_frames) 185 | 186 | if 'MiniGrid' in flags.env: 187 | if flags.use_fullobs_policy: 188 | model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n) 189 | else: 190 | model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n) 191 | if flags.use_fullobs_intrinsic: 192 | state_embedding_model = FullObsMinigridStateEmbeddingNet(env.observation_space.shape)\ 193 | .to(device=flags.device) 194 | else: 195 | state_embedding_model = MinigridStateEmbeddingNet(env.observation_space.shape)\ 196 | .to(device=flags.device) 197 | forward_dynamics_model = MinigridForwardDynamicsNet(env.action_space.n)\ 198 | .to(device=flags.device) 199 | inverse_dynamics_model = MinigridInverseDynamicsNet(env.action_space.n)\ 200 | .to(device=flags.device) 201 | else: 202 | model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n) 203 | state_embedding_model = MarioDoomStateEmbeddingNet(env.observation_space.shape)\ 204 | .to(device=flags.device) 205 | forward_dynamics_model = MarioDoomForwardDynamicsNet(env.action_space.n)\ 206 | .to(device=flags.device) 207 | inverse_dynamics_model = MarioDoomInverseDynamicsNet(env.action_space.n)\ 208 | .to(device=flags.device) 209 | 210 | 211 | buffers = create_buffers(env.observation_space.shape, model.num_actions, flags) 212 | heatmap_buffers = create_heatmap_buffers(env.observation_space.shape) 213 | 214 | model.share_memory() 215 | 216 | initial_agent_state_buffers = [] 217 | for _ in range(flags.num_buffers): 218 | state = model.initial_state(batch_size=1) 219 | for t in state: 220 | t.share_memory_() 221 | initial_agent_state_buffers.append(state) 222 | 223 | actor_processes = [] 224 | ctx = mp.get_context('fork') 225 | free_queue = ctx.Queue() 226 | full_queue = ctx.Queue() 227 | 228 | episode_state_count_dict = dict() 229 | train_state_count_dict = dict() 230 | partial_state_count_dict = dict() 231 | encoded_state_count_dict = dict() 232 | heatmap_dict = dict() 233 | for i in range(flags.num_actors): 234 | actor = ctx.Process( 235 | target=act, 236 | args=(i, free_queue, full_queue, model, buffers, 237 | episode_state_count_dict, train_state_count_dict, partial_state_count_dict, encoded_state_count_dict, 238 | heatmap_dict, heatmap_buffers, initial_agent_state_buffers, flags)) 239 | actor.start() 240 | actor_processes.append(actor) 241 | 242 | if 'MiniGrid' in flags.env: 243 | if flags.use_fullobs_policy: 244 | learner_model = FullObsMinigridPolicyNet(env.observation_space.shape, env.action_space.n)\ 245 | .to(device=flags.device) 246 | else: 247 | learner_model = MinigridPolicyNet(env.observation_space.shape, env.action_space.n)\ 248 | .to(device=flags.device) 249 | else: 250 | learner_model = MarioDoomPolicyNet(env.observation_space.shape, env.action_space.n)\ 251 | .to(device=flags.device) 252 | 253 | optimizer = torch.optim.RMSprop( 254 | learner_model.parameters(), 255 | lr=flags.learning_rate, 256 | momentum=flags.momentum, 257 | eps=flags.epsilon, 258 | alpha=flags.alpha) 259 | 260 | state_embedding_optimizer = torch.optim.RMSprop( 261 | state_embedding_model.parameters(), 262 | lr=flags.learning_rate, 263 | momentum=flags.momentum, 264 | eps=flags.epsilon, 265 | alpha=flags.alpha) 266 | 267 | inverse_dynamics_optimizer = torch.optim.RMSprop( 268 | inverse_dynamics_model.parameters(), 269 | lr=flags.learning_rate, 270 | momentum=flags.momentum, 271 | eps=flags.epsilon, 272 | alpha=flags.alpha) 273 | 274 | forward_dynamics_optimizer = torch.optim.RMSprop( 275 | forward_dynamics_model.parameters(), 276 | lr=flags.learning_rate, 277 | momentum=flags.momentum, 278 | eps=flags.epsilon, 279 | alpha=flags.alpha) 280 | 281 | 282 | def lr_lambda(epoch): 283 | return 1 - min(epoch * T * B, flags.total_frames) / flags.total_frames 284 | 285 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 286 | 287 | logger = logging.getLogger('logfile') 288 | stat_keys = [ 289 | 'total_loss', 290 | 'mean_episode_return', 291 | 'pg_loss', 292 | 'baseline_loss', 293 | 'entropy_loss', 294 | 'mean_rewards', 295 | 'mean_intrinsic_rewards', 296 | 'mean_total_rewards', 297 | 'mean_control_rewards', 298 | 'mean_count_rewards', 299 | 'forward_dynamics_loss', 300 | 'inverse_dynamics_loss', 301 | ] 302 | logger.info('# Step\t%s', '\t'.join(stat_keys)) 303 | frames, stats = 0, {} 304 | 305 | 306 | def batch_and_learn(i, lock=threading.Lock()): 307 | """Thread target for the learning process.""" 308 | nonlocal frames, stats 309 | timings = prof.Timings() 310 | while frames < flags.total_frames: 311 | timings.reset() 312 | batch, agent_state = get_batch(free_queue, full_queue, buffers, 313 | initial_agent_state_buffers, flags, timings) 314 | stats = learn(model, learner_model, state_embedding_model, forward_dynamics_model, 315 | inverse_dynamics_model, batch, agent_state, optimizer, 316 | state_embedding_optimizer, forward_dynamics_optimizer, 317 | inverse_dynamics_optimizer, scheduler, flags, frames=frames) 318 | timings.time('learn') 319 | with lock: 320 | to_log = dict(frames=frames) 321 | to_log.update({k: stats[k] for k in stat_keys}) 322 | plogger.log(to_log) 323 | frames += T * B 324 | 325 | if i == 0: 326 | log.info('Batch and learn: %s', timings.summary()) 327 | 328 | for m in range(flags.num_buffers): 329 | free_queue.put(m) 330 | 331 | threads = [] 332 | for i in range(flags.num_threads): 333 | thread = threading.Thread( 334 | target=batch_and_learn, name='batch-and-learn-%d' % i, args=(i,)) 335 | thread.start() 336 | threads.append(thread) 337 | 338 | 339 | def checkpoint(frames): 340 | if flags.disable_checkpoint: 341 | return 342 | checkpointpath = os.path.expandvars(os.path.expanduser( 343 | '%s/%s/%s' % (flags.savedir, flags.xpid,'model_'+str(frames)+'.tar'))) 344 | log.info('Saving checkpoint to %s', checkpointpath) 345 | torch.save({ 346 | 'model_state_dict': model.state_dict(), 347 | 'state_embedding_model_state_dict': state_embedding_model.state_dict(), 348 | 'forward_dynamics_model_state_dict': forward_dynamics_model.state_dict(), 349 | 'inverse_dynamics_model_state_dict': inverse_dynamics_model.state_dict(), 350 | 'optimizer_state_dict': optimizer.state_dict(), 351 | 'state_embedding_optimizer_state_dict': state_embedding_optimizer.state_dict(), 352 | 'forward_dynamics_optimizer_state_dict': forward_dynamics_optimizer.state_dict(), 353 | 'inverse_dynamics_optimizer_state_dict': inverse_dynamics_optimizer.state_dict(), 354 | 'scheduler_state_dict': scheduler.state_dict(), 355 | 'flags': vars(flags), 356 | }, checkpointpath) 357 | 358 | def save_heatmap(frames): 359 | checkpoint_path = os.path.expandvars(os.path.expanduser( 360 | '%s/%s/%s' % (flags.savedir, flags.xpid,'heatmap_'+str(frames)+'.json'))) 361 | log.info('Saving heatmap to %s', checkpoint_path) 362 | heatmap_dict = dict() 363 | for i, key in enumerate(heatmap_buffers.keys()): 364 | heatmap_dict.update({i: heatmap_buffers[key].item()}) 365 | with open(checkpoint_path, 'w') as fp: 366 | json.dump(heatmap_dict, fp) 367 | 368 | timer = timeit.default_timer 369 | try: 370 | last_checkpoint_time = timer() 371 | while frames < flags.total_frames: 372 | start_frames = frames 373 | start_time = timer() 374 | time.sleep(5) 375 | 376 | if timer() - last_checkpoint_time > flags.save_interval * 60: 377 | checkpoint(frames) 378 | save_heatmap(frames) 379 | last_checkpoint_time = timer() 380 | 381 | fps = (frames - start_frames) / (timer() - start_time) 382 | if stats.get('episode_returns', None): 383 | mean_return = 'Return per episode: %.1f. ' % stats[ 384 | 'mean_episode_return'] 385 | else: 386 | mean_return = '' 387 | total_loss = stats.get('total_loss', float('inf')) 388 | log.info('After %i frames: loss %f @ %.1f fps. %sStats:\n%s', 389 | frames, total_loss, fps, mean_return, 390 | pprint.pformat(stats)) 391 | 392 | except KeyboardInterrupt: 393 | return 394 | else: 395 | for thread in threads: 396 | thread.join() 397 | log.info('Learning finished after %d frames.', frames) 398 | 399 | finally: 400 | for _ in range(flags.num_actors): 401 | free_queue.put(None) 402 | for actor in actor_processes: 403 | actor.join(timeout=1) 404 | checkpoint(frames) 405 | plogger.close() 406 | 407 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. --------------------------------------------------------------------------------