├── brain_agent
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── agents
│   │   │   ├── __init__.py
│   │   │   ├── agent_utils.py
│   │   │   ├── agent_abc.py
│   │   │   └── dmlab_multitask_agent.py
│   │   ├── algos
│   │   │   ├── __init__.py
│   │   │   ├── popart.py
│   │   │   ├── vtrace.py
│   │   │   └── aux_future_predict.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── model_abc.py
│   │   │   ├── model_utils.py
│   │   │   ├── action_distributions.py
│   │   │   ├── causal_transformer.py
│   │   │   ├── resnet.py
│   │   │   ├── rnn.py
│   │   │   └── transformer.py
│   │   ├── core_utils.py
│   │   ├── shared_buffer.py
│   │   ├── policy_worker.py
│   │   └── actor_worker.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── dmlab
│   │   │   ├── __init__.py
│   │   │   ├── dmlab_env.py
│   │   │   ├── dmlab_wrappers.py
│   │   │   ├── dmlab_model.py
│   │   │   ├── dmlab_level_cache.py
│   │   │   ├── dmlab_gym.py
│   │   │   └── dmlab30.py
│   │   └── env_utils.py
│   └── utils
│       ├── __init__.py
│       ├── cfg.py
│       ├── logger.py
│       ├── timing.py
│       ├── utils.py
│       └── dist_utils.py
├── requirements.txt
├── assets
│   ├── learning_curve.png
│   └── system_overview.png
├── configs
│   ├── trxl_baseline_train.yaml
│   ├── rnn_baseline_train.yaml
│   ├── lstm_baseline_train.yaml
│   ├── trxl_future_pred_train.yaml
│   ├── lstm_baseline_eval.yaml
│   ├── trxl_baseline_eval.yaml
│   ├── trxl_recon_train.yaml
│   ├── trxl_recon_eval.yaml
│   ├── trxl_future_pred_eval.yaml
│   └── default.yaml
├── LICENSE
├── .gitignore
├── dist_launch.py
├── eval.py
├── train.py
└── README.md

/brain_agent/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/envs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/core/agents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/core/algos/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/core/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/envs/dmlab/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | faster_fifo
2 | tensorboardX
3 | threadpoolctl
4 | colorlog
5 | gym
6 | omegaconf
--------------------------------------------------------------------------------
/assets/learning_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/brain-agent/HEAD/assets/learning_curve.png
--------------------------------------------------------------------------------
/assets/system_overview.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kakaobrain/brain-agent/HEAD/assets/system_overview.png -------------------------------------------------------------------------------- /brain_agent/envs/env_utils.py: -------------------------------------------------------------------------------- 1 | from brain_agent.envs.dmlab.dmlab_env import make_dmlab_env 2 | 3 | def create_env(cfg=None, env_config=None): 4 | if 'dmlab' in cfg.env.name: 5 | env = make_dmlab_env(cfg, env_config) 6 | else: 7 | raise NotImplementedError 8 | return env 9 | -------------------------------------------------------------------------------- /configs/trxl_baseline_train.yaml: -------------------------------------------------------------------------------- 1 | train_dir: ??? 2 | experiment: ??? 3 | 4 | model: 5 | agent: dmlab_multitask_agent 6 | encoder: 7 | encoder_subtype: resnet_impala 8 | core: 9 | core_type: trxl 10 | use_half_policy_worker: True 11 | 12 | optim: 13 | type: adam 14 | learning_rate: 0.0001 15 | batch_size: 1536 16 | rollout: 96 17 | max_grad_norm: 2.5 18 | 19 | -------------------------------------------------------------------------------- /brain_agent/core/agents/agent_utils.py: -------------------------------------------------------------------------------- 1 | from brain_agent.core.agents.dmlab_multitask_agent import DMLabMultiTaskAgent 2 | 3 | def create_agent(cfg, action_space, obs_space, num_levels=1, need_half=False): 4 | if cfg.model.agent == 'dmlab_multitask_agent': 5 | agent = DMLabMultiTaskAgent(cfg, action_space, obs_space, num_levels, need_half) 6 | else: 7 | raise NotImplementedError 8 | return agent 9 | -------------------------------------------------------------------------------- /configs/rnn_baseline_train.yaml: -------------------------------------------------------------------------------- 1 | train_dir: ??? 2 | experiment: ??? 3 | 4 | model: 5 | agent: dmlab_multitask_agent 6 | core: 7 | core_type: rnn 8 | n_rnn_layer: 3 9 | core_init: orthogonal 10 | 11 | optim: 12 | type: adam 13 | learning_rate: 0.0002 14 | batch_size: 1536 15 | rollout: 96 16 | max_grad_norm: 2.5 17 | 18 | learner: 19 | use_ppo: True 20 | use_adv_normalization: True 21 | -------------------------------------------------------------------------------- /configs/lstm_baseline_train.yaml: -------------------------------------------------------------------------------- 1 | train_dir: ??? 2 | experiment: ??? 3 | 4 | model: 5 | agent: dmlab_multitask_agent 6 | core: 7 | core_type: rnn 8 | n_rnn_layer: 3 9 | core_init: tensorflow_default 10 | 11 | optim: 12 | type: adam 13 | learning_rate: 0.0002 14 | batch_size: 576 15 | rollout: 96 16 | max_grad_norm: 2.5 17 | warmup_optimizer: 0 18 | 19 | learner: 20 | use_ppo: False 21 | use_adv_normalization: False 22 | exploration_loss_coeff: 0.003 23 | -------------------------------------------------------------------------------- /configs/trxl_future_pred_train.yaml: -------------------------------------------------------------------------------- 1 | train_dir: ??? 2 | experiment: ??? 
3 | 4 | model: 5 | agent: dmlab_multitask_agent 6 | encoder: 7 | encoder_subtype: resnet_impala_large 8 | encoder_pooling: stride 9 | core: 10 | core_type: trxl 11 | n_layer: 6 12 | hidden_size: 512 13 | 14 | 15 | learner: 16 | exploration_loss_coeff: 0.003 17 | psychlab_gamma: 0.7 18 | use_decoder: False 19 | use_aux_future_pred_loss : True 20 | 21 | env: 22 | action_set: extended_action_set_large 23 | -------------------------------------------------------------------------------- /configs/lstm_baseline_eval.yaml: -------------------------------------------------------------------------------- 1 | train_dir: ??? 2 | experiment: ??? 3 | 4 | test: 5 | is_test: True 6 | checkpoint: ??? 7 | test_num_episodes: 100 8 | 9 | model: 10 | agent: dmlab_multitask_agent 11 | core: 12 | core_type: rnn 13 | n_rnn_layer: 3 14 | core_init: tensorflow_default 15 | use_half_policy_worker: False 16 | 17 | actor: 18 | num_workers: 30 19 | num_envs_per_worker: 10 20 | num_splits: 2 21 | 22 | env: 23 | name: dmlab_30_test 24 | use_level_cache: True 25 | decorrelate_envs_on_one_worker: False 26 | decorrelate_experience_max_seconds: 0 27 | one_task_per_worker: True -------------------------------------------------------------------------------- /configs/trxl_baseline_eval.yaml: -------------------------------------------------------------------------------- 1 | train_dir: ??? 2 | experiment: ??? 3 | 4 | test: 5 | is_test: True 6 | checkpoint: ??? 7 | test_num_episodes: 100 8 | 9 | model: 10 | agent: dmlab_multitask_agent 11 | encoder: 12 | encoder_subtype: resnet_impala 13 | core: 14 | core_type: trxl 15 | use_half_policy_worker: False 16 | 17 | actor: 18 | num_workers: 30 19 | num_envs_per_worker: 10 20 | num_splits: 2 21 | 22 | env: 23 | name: dmlab_30_test 24 | use_level_cache: True 25 | decorrelate_envs_on_one_worker: False 26 | decorrelate_experience_max_seconds: 0 27 | one_task_per_worker: True 28 | 29 | -------------------------------------------------------------------------------- /configs/trxl_recon_train.yaml: -------------------------------------------------------------------------------- 1 | train_dir: ??? 2 | experiment: ??? 
3 | 4 | model: 5 | agent: dmlab_multitask_agent 6 | encoder: 7 | encoder_subtype: resnet_impala_large 8 | encoder_pooling: stride 9 | core: 10 | core_type: trxl 11 | n_layer: 6 12 | hidden_size: 512 13 | use_half_policy_worker: True 14 | 15 | optim: 16 | type: adam 17 | learning_rate: 0.0001 18 | batch_size: 1536 19 | rollout: 96 20 | max_grad_norm: 2.5 21 | 22 | learner: 23 | exploration_loss_coeff: 0.003 24 | psychlab_gamma: 0.7 25 | use_decoder: True 26 | reconstruction_loss_coeff: 0.01 27 | 28 | env: 29 | action_set: extended_action_set_large 30 | -------------------------------------------------------------------------------- /brain_agent/core/agents/agent_abc.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from brain_agent.core.models.action_distributions import DiscreteActionsParameterization 3 | 4 | class ActorCriticBase(nn.Module): 5 | def __init__(self, action_space, cfg): 6 | super().__init__() 7 | self.cfg = cfg 8 | self.action_space = action_space 9 | 10 | def get_action_parameterization(self, core_output_size): 11 | action_parameterization = DiscreteActionsParameterization(self.cfg, core_output_size, self.action_space) 12 | return action_parameterization 13 | 14 | def model_to_device(self, device): 15 | self.to(device) 16 | 17 | def device_and_type_for_input_tensor(self, input_tensor_name): 18 | return self.encoder.device_and_type_for_input_tensor(input_tensor_name) 19 | -------------------------------------------------------------------------------- /configs/trxl_recon_eval.yaml: -------------------------------------------------------------------------------- 1 | train_dir: ??? 2 | experiment: ??? 3 | 4 | test: 5 | is_test: True 6 | checkpoint: ??? 7 | test_num_episodes: 100 8 | 9 | model: 10 | agent: dmlab_multitask_agent 11 | encoder: 12 | encoder_subtype: resnet_impala_large 13 | encoder_pooling: stride 14 | core: 15 | core_type: trxl 16 | n_layer: 6 17 | hidden_size: 512 18 | use_half_policy_worker: False 19 | 20 | learner: 21 | use_decoder: True 22 | 23 | actor: 24 | num_workers: 30 25 | num_envs_per_worker: 10 26 | num_splits: 2 27 | 28 | env: 29 | name: dmlab_30_test 30 | use_level_cache: True 31 | decorrelate_envs_on_one_worker: False 32 | decorrelate_experience_max_seconds: 0 33 | one_task_per_worker: True 34 | action_set: extended_action_set_large 35 | -------------------------------------------------------------------------------- /configs/trxl_future_pred_eval.yaml: -------------------------------------------------------------------------------- 1 | train_dir: ??? 2 | experiment: ??? 3 | 4 | test: 5 | is_test: True 6 | checkpoint: ??? 
7 | test_num_episodes: 100 8 | 9 | model: 10 | agent: dmlab_multitask_agent 11 | encoder: 12 | encoder_subtype: resnet_impala_large 13 | encoder_pooling: stride 14 | core: 15 | core_type: trxl 16 | n_layer: 6 17 | hidden_size: 512 18 | use_half_policy_worker: False 19 | 20 | learner: 21 | use_aux_future_pred_loss : True 22 | 23 | actor: 24 | num_workers: 30 25 | num_envs_per_worker: 10 26 | num_splits: 2 27 | 28 | env: 29 | name: dmlab_30_test 30 | use_level_cache: True 31 | decorrelate_envs_on_one_worker: False 32 | decorrelate_experience_max_seconds: 0 33 | one_task_per_worker: True 34 | action_set: extended_action_set_large 35 | -------------------------------------------------------------------------------- /brain_agent/utils/cfg.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from brain_agent.utils.utils import AttrDict 3 | 4 | class Configs(OmegaConf): 5 | @classmethod 6 | def get_defaults(cls): 7 | cfg = cls.load('configs/default.yaml') 8 | cls.set_struct(cfg, True) 9 | return cfg 10 | 11 | @classmethod 12 | def override_from_file_name(cls, cfg): 13 | c = cls.override_from_cli(cfg) 14 | if not cls.is_missing(c, 'cfg'): 15 | c = cls.load(c.cfg) 16 | cfg = cls.merge(cfg, c) 17 | return cfg 18 | 19 | @classmethod 20 | def override_from_cli(cls, cfg): 21 | c = cls.from_cli() 22 | cfg = cls.merge(cfg, c) 23 | return cfg 24 | 25 | @classmethod 26 | def to_attr_dict(cls, cfg): 27 | c = cls.to_container(cfg) 28 | c = AttrDict.from_nested_dicts(c) 29 | return c 30 | 31 | -------------------------------------------------------------------------------- /brain_agent/core/algos/popart.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def update_mu_sigma(nu, mu, vs, task_ids, popart_clip_min, clamp_max, beta): 4 | oldnu = nu.clone() 5 | oldsigma = torch.sqrt(oldnu - mu ** 2) 6 | oldsigma[torch.isnan(oldsigma)] = popart_clip_min 7 | oldsigma = torch.clamp(oldsigma, min=popart_clip_min, max=clamp_max) 8 | oldmu = mu.clone() 9 | 10 | for i in range(len(task_ids)): 11 | task_id = task_ids[i] 12 | v = torch.mean(vs[i]) 13 | 14 | mu[task_id] = (1 - beta) * mu[task_id] + beta * v 15 | nu[task_id] = (1 - beta) * nu[task_id] + beta * (v ** 2) 16 | 17 | sigma = torch.sqrt(nu - mu ** 2) 18 | sigma[torch.isnan(sigma)] = popart_clip_min 19 | sigma = torch.clamp(sigma, min=popart_clip_min, max=clamp_max) 20 | 21 | return mu, nu, sigma, oldmu, oldsigma 22 | 23 | def update_parameters(weight, bias, mu, sigma, oldmu, oldsigma): 24 | new_weight = (weight.t() * oldsigma / sigma).t() 25 | new_bias = (oldsigma * bias + oldmu - mu) / sigma 26 | return new_weight, new_bias -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Kakao Brain Corp. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /brain_agent/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from colorlog import ColoredFormatter 3 | 4 | log = logging.getLogger('rl') 5 | log.setLevel(logging.DEBUG) 6 | log.handlers = [] 7 | log.propagate = False 8 | log_level = logging.DEBUG 9 | 10 | stream_handler = logging.StreamHandler() 11 | stream_handler.setLevel(log_level) 12 | 13 | stream_formatter = ColoredFormatter( 14 | '%(log_color)s[%(asctime)s][%(process)05d] %(message)s', 15 | datefmt=None, 16 | reset=True, 17 | log_colors={ 18 | 'DEBUG': 'cyan', 19 | 'INFO': 'white,bold', 20 | 'INFOV': 'cyan,bold', 21 | 'WARNING': 'yellow', 22 | 'ERROR': 'red,bold', 23 | 'CRITICAL': 'red,bg_white', 24 | }, 25 | secondary_log_colors={}, 26 | style='%' 27 | ) 28 | stream_handler.setFormatter(stream_formatter) 29 | log.addHandler(stream_handler) 30 | 31 | def init_logger(log_level='debug', file_path=None): 32 | log.setLevel(logging.getLevelName(str.upper(log_level))) 33 | if file_path is not None: 34 | file_handler = logging.FileHandler(file_path) 35 | file_formatter = logging.Formatter(fmt='[%(asctime)s][%(process)05d] %(message)s', datefmt=None, style='%') 36 | file_handler.setFormatter(file_formatter) 37 | log.addHandler(file_handler) 38 | for h in log.handlers: 39 | h.setLevel(logging.getLevelName(str.upper(log_level))) 40 | 41 | 42 | -------------------------------------------------------------------------------- /brain_agent/core/models/model_abc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from brain_agent.core.models.model_utils import nonlinearity 4 | 5 | class ActionsParameterizationBase(nn.Module): 6 | def __init__(self, cfg, action_space): 7 | super().__init__() 8 | self.cfg = cfg 9 | self.action_space = action_space 10 | 11 | class EncoderBase(nn.Module): 12 | def __init__(self, cfg): 13 | super().__init__() 14 | 15 | self.cfg = cfg 16 | 17 | self.fc_after_enc = None 18 | self.encoder_out_size = -1 # to be initialized in the constuctor of derived class 19 | 20 | def get_encoder_out_size(self): 21 | return self.encoder_out_size 22 | 23 | def init_fc_blocks(self, input_size): 24 | layers = [] 25 | fc_layer_size = self.cfg.model.core.hidden_size 26 | 27 | for i in 
range(self.cfg.model.encoder.encoder_extra_fc_layers): 28 | size = input_size if i == 0 else fc_layer_size 29 | 30 | layers.extend([ 31 | nn.Linear(size, fc_layer_size), 32 | nonlinearity(self.cfg), 33 | ]) 34 | 35 | if len(layers) > 0: 36 | self.fc_after_enc = nn.Sequential(*layers) 37 | self.encoder_out_size = fc_layer_size 38 | else: 39 | self.encoder_out_size = input_size 40 | 41 | def model_to_device(self, device): 42 | self.to(device) 43 | 44 | def device_and_type_for_input_tensor(self, _): 45 | return self.model_device(), torch.float32 46 | 47 | def model_device(self): 48 | return next(self.parameters()).device 49 | 50 | def forward_fc_blocks(self, x): 51 | if self.fc_after_enc is not None: 52 | x = self.fc_after_enc(x) 53 | 54 | return x -------------------------------------------------------------------------------- /brain_agent/core/algos/vtrace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calculate_vtrace(values, rewards, dones, vtrace_rho, vtrace_c, 5 | num_trajectories, recurrence, gamma, exclude_last=False): 6 | values_cpu = values.cpu() 7 | rewards_cpu = rewards.cpu() 8 | dones_cpu = dones.cpu() 9 | vtrace_rho_cpu = vtrace_rho.cpu() 10 | vtrace_c_cpu = vtrace_c.cpu() 11 | 12 | vs = torch.zeros((num_trajectories * recurrence)) 13 | adv = torch.zeros((num_trajectories * recurrence)) 14 | 15 | bootstrap_values = values_cpu[recurrence - 1::recurrence] 16 | values_BT = values_cpu.view(-1, recurrence) 17 | next_values = torch.cat([values_BT[:, 1:], bootstrap_values.view(-1, 1)], dim=1).view(-1) 18 | next_vs = next_values[recurrence - 1::recurrence] 19 | 20 | masked_gammas = (1.0 - dones_cpu) * gamma 21 | 22 | if exclude_last: 23 | rollout_recurrence = recurrence - 1 24 | adv[recurrence - 1::recurrence] = rewards_cpu[recurrence - 1::recurrence] + (masked_gammas[recurrence - 1::recurrence] - 1) * next_vs 25 | vs[recurrence - 1::recurrence] = next_vs * vtrace_rho_cpu[recurrence - 1::recurrence] * adv[recurrence - 1::recurrence] 26 | else: 27 | rollout_recurrence = recurrence 28 | 29 | for i in reversed(range(rollout_recurrence)): 30 | rewards = rewards_cpu[i::recurrence] 31 | not_done_times_gamma = masked_gammas[i::recurrence] 32 | 33 | curr_values = values_cpu[i::recurrence] 34 | curr_next_values = next_values[i::recurrence] 35 | curr_vtrace_rho = vtrace_rho_cpu[i::recurrence] 36 | curr_vtrace_c = vtrace_c_cpu[i::recurrence] 37 | 38 | delta_s = curr_vtrace_rho * (rewards + not_done_times_gamma * curr_next_values - curr_values) 39 | adv[i::recurrence] = rewards + not_done_times_gamma * next_vs - curr_values 40 | next_vs = curr_values + delta_s + not_done_times_gamma * curr_vtrace_c * (next_vs - curr_next_values) 41 | vs[i::recurrence] = next_vs 42 | 43 | return vs, adv 44 | -------------------------------------------------------------------------------- /brain_agent/utils/timing.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import deque 3 | 4 | from brain_agent.utils.utils import AttrDict 5 | 6 | 7 | class AvgTime: 8 | def __init__(self, num_values_to_avg): 9 | self.values = deque([], maxlen=num_values_to_avg) 10 | 11 | def __str__(self): 12 | avg_time = sum(self.values) / max(1, len(self.values)) 13 | return f'{avg_time:.4f}' 14 | 15 | 16 | class TimingContext: 17 | def __init__(self, timer, key, additive=False, average=None): 18 | self._timer = timer 19 | self._key = key 20 | self._additive = additive 21 | self._average = average 22 | 
self._time_enter = None 23 | 24 | def __enter__(self): 25 | self._time_enter = time.time() 26 | 27 | def __exit__(self, type_, value, traceback): 28 | if self._key not in self._timer: 29 | if self._average is not None: 30 | self._timer[self._key] = AvgTime(num_values_to_avg=self._average) 31 | else: 32 | self._timer[self._key] = 0 33 | 34 | time_passed = max(time.time() - self._time_enter, 1e-8) # EPS to prevent div by zero 35 | 36 | if self._additive: 37 | self._timer[self._key] += time_passed 38 | elif self._average is not None: 39 | self._timer[self._key].values.append(time_passed) 40 | else: 41 | self._timer[self._key] = time_passed 42 | 43 | 44 | class Timing(AttrDict): 45 | def timeit(self, key): 46 | return TimingContext(self, key) 47 | 48 | def add_time(self, key): 49 | return TimingContext(self, key, additive=True) 50 | 51 | def time_avg(self, key, average=10): 52 | return TimingContext(self, key, average=average) 53 | 54 | def __str__(self): 55 | s = '' 56 | i = 0 57 | for key, value in self.items(): 58 | str_value = f'{value:.4f}' if isinstance(value, float) else str(value) 59 | s += f'{key}: {str_value}' 60 | if i < len(self) - 1: 61 | s += ', ' 62 | i += 1 63 | return s -------------------------------------------------------------------------------- /brain_agent/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from brain_agent.utils.logger import log 4 | import glob 5 | 6 | def dict_of_list_put(d, k, v, max_len=100): 7 | if d.get(k) is None: 8 | d[k] = [] 9 | d[k].append(v) 10 | if len(d[k]) > max_len: 11 | d[k].pop(0) 12 | 13 | def list_of_dicts_to_dict_of_lists(list_of_dicts): 14 | dict_of_lists = dict() 15 | 16 | for d in list_of_dicts: 17 | for key, x in d.items(): 18 | if key not in dict_of_lists: 19 | dict_of_lists[key] = [] 20 | 21 | dict_of_lists[key].append(x) 22 | 23 | return dict_of_lists 24 | 25 | def get_checkpoint_dir(cfg): 26 | checkpoint_dir = os.path.join(get_experiment_dir(cfg=cfg), f'checkpoint') 27 | os.makedirs(checkpoint_dir, exist_ok=True) 28 | return checkpoint_dir 29 | 30 | 31 | def get_checkpoints(checkpoints_dir): 32 | checkpoints = glob.glob(os.path.join(checkpoints_dir, 'checkpoint_*')) 33 | return sorted(checkpoints) 34 | 35 | 36 | def get_experiment_dir(cfg): 37 | exp_dir = os.path.join(cfg.train_dir, cfg.experiment) 38 | os.makedirs(exp_dir, exist_ok=True) 39 | return exp_dir 40 | 41 | def get_log_path(cfg): 42 | exp_dir = os.path.join(cfg.train_dir, cfg.experiment) 43 | os.makedirs(exp_dir, exist_ok=True) 44 | log_dir = os.path.join(cfg.train_dir, cfg.experiment, 'logs') 45 | os.makedirs(log_dir, exist_ok=True) 46 | date = datetime.now().strftime("%Y%m%d_%I%M%S%P") 47 | return os.path.join(log_dir, f'log-r{cfg.dist.world_rank:02d}-{date}.txt') 48 | 49 | def get_summary_dir(cfg, postfix=None): 50 | summary_dir = os.path.join(cfg.train_dir, cfg.experiment, 'summary') 51 | os.makedirs(summary_dir, exist_ok=True) 52 | if postfix is not None: 53 | summary_dir = os.path.join(summary_dir, postfix) 54 | os.makedirs(summary_dir, exist_ok=True) 55 | return summary_dir 56 | 57 | class AttrDict(dict): 58 | __setattr__ = dict.__setitem__ 59 | 60 | def __getattribute__(self, item): 61 | if item in self: 62 | return self[item] 63 | else: 64 | return super().__getattribute__(item) 65 | 66 | @classmethod 67 | def from_nested_dicts(cls, data): 68 | if not isinstance(data, dict): 69 | return data 70 | else: 71 | return cls({key: cls.from_nested_dicts(data[key]) for 
key in data}) 72 | -------------------------------------------------------------------------------- /brain_agent/envs/dmlab/dmlab_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from brain_agent.envs.dmlab.dmlab30 import DMLAB_LEVELS_BY_ENVNAME 3 | from brain_agent.envs.dmlab.dmlab_gym import DmlabGymEnv 4 | from brain_agent.envs.dmlab.dmlab_level_cache import dmlab_ensure_global_cache_initialized 5 | from brain_agent.envs.dmlab.dmlab_wrappers import PixelFormatChwWrapper, EpisodicStatWrapper, RewardShapingWrapper 6 | from brain_agent.utils.utils import get_experiment_dir 7 | from brain_agent.utils.logger import log 8 | 9 | DMLAB_INITIALIZED = False 10 | 11 | def get_task_id(env_config, levels, cfg): 12 | if env_config is None: 13 | return 0 14 | 15 | num_envs = len(levels) 16 | 17 | if cfg.env.one_task_per_worker: 18 | return env_config['worker_index'] % num_envs 19 | else: 20 | return env_config['env_id'] % num_envs 21 | 22 | 23 | def make_dmlab_env_impl(levels, cfg, env_config, extra_cfg=None): 24 | skip_frames = cfg.env.frameskip 25 | 26 | task_id = get_task_id(env_config, levels, cfg) 27 | level = levels[task_id] 28 | log.debug('%r level %s task id %d', env_config, level, task_id) 29 | 30 | env = DmlabGymEnv( 31 | task_id, level, skip_frames, cfg.env.res_w, cfg.env.res_h, 32 | cfg.env.dataset_path, cfg.env.action_set, 33 | cfg.env.use_level_cache, cfg.env.level_cache_path, extra_cfg, 34 | ) 35 | all_levels = [] 36 | for l in levels: 37 | all_levels.append(l.replace('contributed/dmlab30/', '')) 38 | 39 | env.level_info = dict( 40 | num_levels=len(levels), 41 | all_levels=all_levels 42 | ) 43 | 44 | env = PixelFormatChwWrapper(env) 45 | 46 | env = EpisodicStatWrapper(env) 47 | 48 | env = RewardShapingWrapper(env) 49 | 50 | return env 51 | 52 | 53 | def make_dmlab_env(cfg, env_config=None): 54 | levels = DMLAB_LEVELS_BY_ENVNAME[cfg.env.name] 55 | extra_cfg = None 56 | if cfg.test.is_test and 'test' in cfg.env.name: 57 | extra_cfg = dict(allowHoldOutLevels='true') 58 | ensure_initialized(cfg, levels) 59 | return make_dmlab_env_impl(levels, cfg, env_config, extra_cfg=extra_cfg) 60 | 61 | 62 | def ensure_initialized(cfg, levels): 63 | global DMLAB_INITIALIZED 64 | if DMLAB_INITIALIZED: 65 | return 66 | 67 | level_cache_dir = cfg.env.level_cache_path 68 | os.makedirs(level_cache_dir, exist_ok=True) 69 | 70 | dmlab_ensure_global_cache_initialized(get_experiment_dir(cfg=cfg), levels, level_cache_dir) 71 | DMLAB_INITIALIZED = True 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # additional 132 | .idea -------------------------------------------------------------------------------- /brain_agent/core/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch import nn 4 | from brain_agent.utils.utils import AttrDict 5 | 6 | EPS = 1e-8 7 | 8 | def to_scalar(value): 9 | if isinstance(value, torch.Tensor): 10 | return value.item() 11 | else: 12 | return value 13 | 14 | def get_hidden_size(cfg, action_space): 15 | if cfg.model.core.core_type == 'trxl': 16 | size = cfg.model.core.hidden_size * (cfg.model.core.n_layer + 1) 17 | size += 64 * (cfg.model.core.n_layer + 1) 18 | if cfg.model.extended_input: 19 | size += (action_space.n + 1) * (cfg.model.core.n_layer + 1) 20 | elif cfg.model.core.core_type == 'rnn': 21 | size = cfg.model.core.hidden_size * cfg.model.core.n_rnn_layer * 2 22 | else: 23 | raise NotImplementedError 24 | return size 25 | 26 | 27 | def nonlinearity(cfg): 28 | if cfg.model.encoder.nonlinearity == 'elu': 29 | return nn.ELU(inplace=cfg.model.encoder.nonlinear_inplace) 30 | elif cfg.model.encoder.nonlinearity == 'relu': 31 | return nn.ReLU(inplace=cfg.model.encoder.nonlinear_inplace) 32 | elif cfg.model.encoder.nonlinearity == 'tanh': 33 | return nn.Tanh() 34 | else: 35 | raise Exception('Unknown nonlinearity') 36 | 37 | 38 | def get_obs_shape(obs_space): 39 | obs_shape = AttrDict() 40 | if hasattr(obs_space, 'spaces'): 41 | for key, space in obs_space.spaces.items(): 42 | obs_shape[key] = space.shape 43 | else: 44 | obs_shape.obs = obs_space.shape 45 | 46 | return obs_shape 47 | 48 | 49 | def calc_num_elements(module, module_input_shape): 50 | shape_with_batch_dim = 
(1,) + module_input_shape 51 | some_input = torch.rand(shape_with_batch_dim) 52 | num_elements = module(some_input).numel() 53 | return num_elements 54 | 55 | def normalize_obs_return(obs_dict, cfg, half=False): 56 | with torch.no_grad(): 57 | mean = cfg.env.obs_subtract_mean 58 | scale = cfg.env.obs_scale 59 | 60 | normalized_obs_dict = copy.deepcopy(obs_dict) 61 | 62 | if normalized_obs_dict['obs'].dtype != torch.float: 63 | normalized_obs_dict['obs'] = normalized_obs_dict['obs'].float() 64 | 65 | if abs(mean) > EPS: 66 | normalized_obs_dict['obs'].sub_(mean) 67 | 68 | if abs(scale - 1.0) > EPS: 69 | normalized_obs_dict['obs'].mul_(1.0 / scale) 70 | 71 | if half: 72 | normalized_obs_dict['obs'] = normalized_obs_dict['obs'].half() 73 | 74 | return normalized_obs_dict 75 | 76 | -------------------------------------------------------------------------------- /brain_agent/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | 4 | import torch 5 | from brain_agent.utils.logger import log 6 | 7 | DistEnv = collections.namedtuple('DistEnv', ['world_size', 'world_rank', 'local_rank', 'num_gpus', 'master']) 8 | 9 | 10 | def dist_init(cfg): 11 | if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) > 1: 12 | log.debug('[dist] Distributed: wait dist process group:%d', cfg.dist.local_rank) 13 | torch.distributed.init_process_group(backend=cfg.dist.dist_backend, init_method='env://', 14 | world_size=int(os.environ['WORLD_SIZE'])) 15 | assert (int(os.environ['WORLD_SIZE']) == torch.distributed.get_world_size()) 16 | log.debug('[dist] Distributed: success device:%d (%d/%d)', 17 | cfg.dist.local_rank, torch.distributed.get_rank(), torch.distributed.get_world_size()) 18 | distenv = DistEnv(torch.distributed.get_world_size(), torch.distributed.get_rank(), cfg.dist.local_rank, 1, torch.distributed.get_rank() == 0) 19 | else: 20 | log.debug('[dist] Single processed') 21 | distenv = DistEnv(1, 0, 0, torch.cuda.device_count(), True) 22 | log.debug('[dist] %s', distenv) 23 | return distenv 24 | 25 | 26 | def dist_all_reduce_gradient(model): 27 | torch.distributed.barrier() 28 | world_size = float(torch.distributed.get_world_size()) 29 | for p in model.parameters(): 30 | if type(p.grad) is not type(None): 31 | torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM) 32 | p.grad.data /= world_size 33 | 34 | 35 | def dist_reduce_gradient(model, grads=None): 36 | torch.distributed.barrier() 37 | world_size = float(torch.distributed.get_world_size()) 38 | if grads is None: 39 | for p in model.parameters(): 40 | if type(p.grad) is not type(None): 41 | torch.distributed.reduce(p.grad.data, 0, op=torch.distributed.ReduceOp.SUM) 42 | p.grad.data /= world_size 43 | else: 44 | for grad in grads: 45 | if type(grad) is not type(None): 46 | torch.distributed.reduce(grad.data, 0, op=torch.distributed.ReduceOp.SUM) 47 | grad.data /= world_size 48 | 49 | 50 | def dist_all_reduce_buffers(model): 51 | torch.distributed.barrier() 52 | world_size = float(torch.distributed.get_world_size()) 53 | for n, b in model.named_buffers(): 54 | torch.distributed.all_reduce(b.data, op=torch.distributed.ReduceOp.SUM) 55 | b.data /= world_size 56 | 57 | 58 | def dist_broadcast_model(model): 59 | torch.distributed.barrier() 60 | for _, param in model.state_dict().items(): 61 | torch.distributed.broadcast(param, 0) 62 | torch.distributed.barrier() 63 | torch.cuda.synchronize() 
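
# --- Usage sketch (editor's illustration; not part of dist_utils.py) -----------
# Shows how the helpers above are typically combined around a single training
# step. The tiny nn.Linear model, the hand-built AttrDict cfg and the literal
# learning rate are placeholders only; the real entry point (train.py, launched
# via dist_launch.py) builds cfg from configs/default.yaml and trains the agent.
import torch
from brain_agent.utils.utils import AttrDict
from brain_agent.utils.dist_utils import dist_init, dist_broadcast_model, dist_all_reduce_gradient

cfg = AttrDict.from_nested_dicts({'dist': {'local_rank': 0, 'dist_backend': 'nccl'}})
distenv = dist_init(cfg)  # DistEnv(world_size, world_rank, local_rank, num_gpus, master)

model = torch.nn.Linear(8, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

if distenv.world_size > 1:
    dist_broadcast_model(model)      # copy rank-0 parameters to every worker

loss = model(torch.randn(2, 8)).mean()
loss.backward()
if distenv.world_size > 1:
    dist_all_reduce_gradient(model)  # average gradients across ranks before stepping
optimizer.step()
# -------------------------------------------------------------------------------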
-------------------------------------------------------------------------------- /brain_agent/core/models/action_distributions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from brain_agent.utils.logger import log 6 | from brain_agent.core.models.model_abc import ActionsParameterizationBase 7 | 8 | 9 | class DiscreteActionsParameterization(ActionsParameterizationBase): 10 | def __init__(self, cfg, core_out_size, action_space): 11 | super().__init__(cfg, action_space) 12 | 13 | num_action_outputs = action_space.n 14 | self.distribution_linear = nn.Linear(core_out_size, num_action_outputs) 15 | 16 | def forward(self, actor_core_output): 17 | action_distribution_params = self.distribution_linear(actor_core_output) 18 | action_distribution = CategoricalActionDistribution(raw_logits=action_distribution_params) 19 | return action_distribution_params, action_distribution 20 | 21 | 22 | class CategoricalActionDistribution: 23 | def __init__(self, raw_logits): 24 | self.raw_logits = raw_logits 25 | self.log_p = self.p = None 26 | 27 | @property 28 | def probs(self): 29 | if self.p is None: 30 | self.p = F.softmax(self.raw_logits, dim=-1) 31 | return self.p 32 | 33 | @property 34 | def log_probs(self): 35 | if self.log_p is None: 36 | self.log_p = F.log_softmax(self.raw_logits, dim=-1) 37 | return self.log_p 38 | 39 | def sample_gumbel(self): 40 | sample = torch.argmax(self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_(), -1) 41 | return sample 42 | 43 | def sample(self): 44 | samples = torch.multinomial(self.probs, 1, True).squeeze(dim=-1) 45 | return samples 46 | 47 | def sample_max(self): 48 | samples = torch.argmax(self.probs, dim=-1) 49 | return samples 50 | 51 | def log_prob(self, value): 52 | value = value.long().unsqueeze(-1) 53 | log_probs = torch.gather(self.log_probs, -1, value).view(-1) 54 | return log_probs 55 | 56 | def entropy(self): 57 | p_log_p = self.log_probs * self.probs 58 | return -p_log_p.sum(-1) 59 | 60 | def _kl(self, other_log_probs): 61 | probs, log_probs = self.probs, self.log_probs 62 | kl = probs * (log_probs - other_log_probs) 63 | kl = kl.sum(dim=-1) 64 | return kl 65 | 66 | def _kl_inverse(self, other_log_probs): 67 | probs, log_probs = self.probs, self.log_probs 68 | kl = torch.exp(other_log_probs) * (other_log_probs - log_probs) 69 | kl = kl.sum(dim=-1) 70 | return kl 71 | 72 | def _kl_symmetric(self, other_log_probs): 73 | return 0.5 * (self._kl(other_log_probs) + self._kl_inverse(other_log_probs)) 74 | 75 | def symmetric_kl_with_uniform_prior(self): 76 | probs, log_probs = self.probs, self.log_probs 77 | num_categories = log_probs.shape[-1] 78 | uniform_prob = 1 / num_categories 79 | log_uniform_prob = math.log(uniform_prob) 80 | 81 | return 0.5 * ((probs * (log_probs - log_uniform_prob)).sum(dim=-1) 82 | + (uniform_prob * (log_uniform_prob - log_probs)).sum(dim=-1)) 83 | 84 | def kl_divergence(self, other): 85 | return self._kl(other.log_probs) 86 | 87 | def dbg_print(self): 88 | dbg_info = dict( 89 | entropy=self.entropy().mean(), 90 | min_logit=self.raw_logits.min(), 91 | max_logit=self.raw_logits.max(), 92 | min_prob=self.probs.min(), 93 | max_prob=self.probs.max(), 94 | ) 95 | 96 | msg = '' 97 | for key, value in dbg_info.items(): 98 | msg += f'{key}={value.cpu().item():.3f} ' 99 | log.debug(msg) 100 | 101 | 102 | def sample_actions_log_probs(distribution): 103 | actions = distribution.sample() 104 | 
log_prob_actions = distribution.log_prob(actions) 105 | return actions, log_prob_actions 106 | -------------------------------------------------------------------------------- /brain_agent/envs/dmlab/dmlab_wrappers.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | from gym import spaces, ObservationWrapper 4 | from math import tanh 5 | from brain_agent.envs.dmlab.dmlab30 import RANDOM_SCORES, HUMAN_SCORES 6 | 7 | 8 | def has_image_observations(observation_space): 9 | return len(observation_space.shape) >= 2 10 | 11 | def compute_hns(r, h, s): 12 | return (s-r) / (h-r) * 100 13 | 14 | class PixelFormatChwWrapper(ObservationWrapper): 15 | def __init__(self, env): 16 | super().__init__(env) 17 | 18 | if isinstance(env.observation_space, gym.spaces.Dict): 19 | img_obs_space = env.observation_space['obs'] 20 | self.dict_obs_space = True 21 | else: 22 | img_obs_space = env.observation_space 23 | self.dict_obs_space = False 24 | 25 | if not has_image_observations(img_obs_space): 26 | raise Exception('Pixel format wrapper only works with image-based envs') 27 | 28 | obs_shape = img_obs_space.shape 29 | max_num_img_channels = 4 30 | 31 | if len(obs_shape) <= 2: 32 | raise Exception('Env obs do not have channel dimension?') 33 | 34 | if obs_shape[0] <= max_num_img_channels: 35 | raise Exception('Env obs already in CHW format?') 36 | 37 | h, w, c = obs_shape 38 | low, high = img_obs_space.low.flat[0], img_obs_space.high.flat[0] 39 | new_shape = [c, h, w] 40 | 41 | if self.dict_obs_space: 42 | dtype = env.observation_space.spaces['obs'].dtype if env.observation_space.spaces['obs'].dtype is not None else np.float32 43 | else: 44 | dtype = env.observation_space.dtype if env.observation_space.dtype is not None else np.float32 45 | 46 | new_img_obs_space = spaces.Box(low, high, shape=new_shape, dtype=dtype) 47 | 48 | if self.dict_obs_space: 49 | self.observation_space = env.observation_space 50 | self.observation_space.spaces['obs'] = new_img_obs_space 51 | else: 52 | self.observation_space = new_img_obs_space 53 | 54 | self.action_space = env.action_space 55 | 56 | @staticmethod 57 | def _transpose(obs): 58 | return np.transpose(obs, (2, 0, 1)) 59 | 60 | def observation(self, observation): 61 | if observation is None: 62 | return observation 63 | 64 | if self.dict_obs_space: 65 | observation['obs'] = self._transpose(observation['obs']) 66 | else: 67 | observation = self._transpose(observation) 68 | return observation 69 | 70 | class EpisodicStatWrapper(gym.Wrapper): 71 | def __init__(self, env): 72 | super().__init__(env) 73 | self.raw_episode_return = self.episode_return = self.episode_length = 0 74 | 75 | def reset(self): 76 | obs = self.env.reset() 77 | self.raw_episode_return = self.episode_return = self.episode_length = 0 78 | return obs 79 | 80 | def step(self, action): 81 | obs, rew, done, info = self.env.step(action) 82 | self.episode_return += rew 83 | self.raw_episode_return += info.get('raw_rew', rew) 84 | self.episode_length += info.get('num_frames', 1) 85 | 86 | if done: 87 | level_name = self.unwrapped.level_name 88 | hns = compute_hns( 89 | RANDOM_SCORES[level_name.replace('train', 'test')], 90 | HUMAN_SCORES[level_name.replace('train', 'test')], 91 | self.raw_episode_return 92 | ) 93 | 94 | info['episodic_stats'] = { 95 | 'level_name': self.unwrapped.level_name, 96 | 'task_id': self.unwrapped.task_id, 97 | 'episode_return': self.episode_return, 98 | 'episode_length': self.episode_length, 99 | 
'raw_episode_return': self.raw_episode_return, 100 | 'hns': hns, 101 | } 102 | self.episode_return = 0 103 | self.raw_episode_return = 0 104 | self.episode_length = 0 105 | 106 | return obs, rew, done, info 107 | 108 | class RewardShapingWrapper(gym.Wrapper): 109 | def __init__(self, env): 110 | super().__init__(env) 111 | 112 | def reset(self): 113 | obs = self.env.reset() 114 | return obs 115 | 116 | def step(self, action): 117 | obs, rew, done, info = self.env.step(action) 118 | 119 | info['raw_rew'] = rew 120 | squeezed = tanh(rew / 5.0) 121 | clipped = (1.5 * squeezed) if rew < 0.0 else (5.0 * squeezed) 122 | rew = clipped 123 | 124 | return obs, rew, done, info 125 | -------------------------------------------------------------------------------- /brain_agent/envs/dmlab/dmlab_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | 5 | from brain_agent.core.models.model_abc import EncoderBase 6 | from brain_agent.core.models.resnet import ResnetEncoder 7 | from brain_agent.envs.dmlab.dmlab30 import DMLAB_VOCABULARY_SIZE, DMLAB_INSTRUCTIONS 8 | from brain_agent.utils.logger import log 9 | 10 | 11 | class DmlabEncoder(EncoderBase): 12 | def __init__(self, cfg, obs_space): 13 | super().__init__(cfg) 14 | 15 | if self.cfg.model.encoder.encoder_type == 'resnet': 16 | self.basic_encoder = ResnetEncoder(cfg, obs_space) 17 | else: 18 | raise NotImplementedError 19 | self.encoder_out_size = self.basic_encoder.encoder_out_size 20 | 21 | self.embedding_size = 20 22 | self.instructions_lstm_units = 64 23 | self.instructions_lstm_layers = 1 24 | 25 | padding_idx = 0 26 | self.word_embedding = nn.Embedding( 27 | num_embeddings=DMLAB_VOCABULARY_SIZE, 28 | embedding_dim=self.embedding_size, 29 | padding_idx=padding_idx 30 | ) 31 | 32 | self.instructions_lstm = nn.LSTM( 33 | input_size=self.embedding_size, 34 | hidden_size=self.instructions_lstm_units, 35 | num_layers=self.instructions_lstm_layers, 36 | batch_first=True, 37 | ) 38 | 39 | self.encoder_out_size += self.instructions_lstm_units 40 | log.debug('Policy head output size: %r', self.encoder_out_size) 41 | 42 | self.instructions_lstm.apply(self.initialize) 43 | 44 | def initialize(self, layer): 45 | gain = 1.0 46 | if hasattr(layer, 'bias') and isinstance(layer.bias, torch.nn.parameter.Parameter): 47 | layer.bias.data.fill_(0) 48 | 49 | if self.cfg.model.encoder.encoder_init == 'orthogonal': 50 | if type(layer) == nn.Conv2d or type(layer) == nn.Linear: 51 | nn.init.orthogonal_(layer.weight.data, gain=gain) 52 | elif type(layer) == nn.GRUCell or type(layer) == nn.LSTMCell: 53 | nn.init.orthogonal_(layer.weight_ih, gain=gain) 54 | nn.init.orthogonal_(layer.weight_hh, gain=gain) 55 | layer.bias_ih.data.fill_(0) 56 | layer.bias_hh.data.fill_(0) 57 | elif self.cfg.model.encoder.encoder_init == 'xavier_uniform': 58 | if type(layer) == nn.Conv2d or type(layer) == nn.Linear: 59 | nn.init.xavier_uniform_(layer.weight.data, gain=gain) 60 | layer.bias.data.fill_(0) 61 | elif type(layer) == nn.GRUCell or type(layer) == nn.LSTMCell: 62 | nn.init.xavier_uniform_(layer.weight_ih, gain=gain) 63 | nn.init.xavier_uniform_(layer.weight_hh, gain=gain) 64 | layer.bias_ih.data.fill_(0) 65 | layer.bias_hh.data.fill_(0) 66 | elif self.cfg.model.encoder.encoder_init == 'torch_default': 67 | pass 68 | else: 69 | raise NotImplementedError 70 | 71 | def model_to_device(self, device): 72 | self.to(device) 73 | self.word_embedding.to(self.device) 74 | self.instructions_lstm.to(self.device) 
75 | 76 | def device_and_type_for_input_tensor(self, input_tensor_name): 77 | if input_tensor_name == DMLAB_INSTRUCTIONS: 78 | return self.model_device(), torch.int64 79 | else: 80 | return self.model_device(), torch.float32 81 | 82 | def forward(self, obs_dict, **kwargs): 83 | x = self.basic_encoder(obs_dict, **kwargs) 84 | 85 | with torch.no_grad(): 86 | instr = obs_dict[DMLAB_INSTRUCTIONS] 87 | instr_lengths = (instr != 0).sum(axis=1) 88 | instr_lengths = torch.clamp(instr_lengths, min=1) 89 | max_instr_len = torch.max(instr_lengths).item() 90 | instr = instr[:, :max_instr_len] 91 | instr_lengths_cpu = instr_lengths.to('cpu') 92 | 93 | instr_embed = self.word_embedding(instr) 94 | instr_packed = torch.nn.utils.rnn.pack_padded_sequence( 95 | instr_embed, instr_lengths_cpu, batch_first=True, enforce_sorted=False, 96 | ) 97 | rnn_output, _ = self.instructions_lstm(instr_packed) 98 | rnn_outputs, sequence_lengths = torch.nn.utils.rnn.pad_packed_sequence(rnn_output, batch_first=True) 99 | 100 | first_dim_idx = torch.arange(rnn_outputs.shape[0]) 101 | last_output_idx = sequence_lengths - 1 102 | last_outputs = rnn_outputs[first_dim_idx, last_output_idx] 103 | 104 | last_outputs = last_outputs.to(x.device) 105 | 106 | x = torch.cat((x, last_outputs), dim=1) 107 | return x 108 | 109 | -------------------------------------------------------------------------------- /brain_agent/core/models/causal_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from brain_agent.core.models.transformer import PositionalEmbedding, RelPartialLearnableDecoderLayer 5 | 6 | 7 | class CausalTransformer(nn.Module): 8 | def __init__(self, core_out_size, n_action, pre_lnorm=False): 9 | super().__init__() 10 | self.n_layer = 4 11 | self.n_head = 3 12 | self.d_head = 64 13 | self.d_inner = 512 14 | self.d_model = 196 15 | self.mem_len = 64 16 | 17 | self.blocks = nn.ModuleList() 18 | for i in range(self.n_layer): 19 | self.blocks.append(RelPartialLearnableDecoderLayer( 20 | n_head=self.n_head, d_model=self.d_model, d_head=self.d_head, d_inner=self.d_inner, 21 | pre_lnorm=pre_lnorm)) 22 | # decoder head 23 | self.ln_f = nn.LayerNorm(self.d_model) 24 | self.head = nn.Linear(self.d_model, core_out_size, bias=True) 25 | 26 | self.apply(self._init_weights) 27 | 28 | self.state_encoder = nn.Sequential(nn.Linear(core_out_size+n_action, self.d_model), nn.Tanh()) 29 | 30 | self.pos_emb = PositionalEmbedding(self.d_model) 31 | self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 32 | self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 33 | 34 | 35 | def _init_weights(self, module): 36 | if isinstance(module, (nn.Linear, nn.Embedding)): 37 | module.weight.data.normal_(mean=0.0, std=0.02) 38 | if isinstance(module, nn.Linear) and module.bias is not None: 39 | module.bias.data.zero_() 40 | elif isinstance(module, nn.LayerNorm): 41 | module.bias.data.zero_() 42 | module.weight.data.fill_(1.0) 43 | 44 | def forward(self, states, mem_begin_index, num_traj, mems=None): 45 | state_embeddings = self.state_encoder(states) # (batch * block_size, n_embd) 46 | x = state_embeddings 47 | 48 | qlen, bsz, _ = x.size() # qlen is number of characters in input ex 49 | 50 | if mems is not None: 51 | mlen = mems[0].size(0) 52 | klen = mlen + qlen 53 | dec_attn_mask_triu = torch.triu(state_embeddings.new_ones(qlen, klen), diagonal=1 + mlen) 54 | dec_attn_mask = 
dec_attn_mask_triu.bool().unsqueeze(-1).repeat(1, 1, bsz) 55 | else: 56 | mlen = self.mem_len 57 | klen = self.mem_len 58 | dec_attn_mask_triu = torch.triu(state_embeddings.new_ones(qlen, klen), diagonal=1) 59 | dec_attn_mask = dec_attn_mask_triu.bool().unsqueeze(-1).repeat(1, 1, bsz) 60 | 61 | for b in range(bsz): 62 | if mlen-mem_begin_index[b] > 0: 63 | dec_attn_mask[:, :mlen-mem_begin_index[b], b] = True 64 | 65 | dec_attn_mask = dec_attn_mask.transpose(1,2) 66 | temp = torch.logical_not(dec_attn_mask) 67 | temp = torch.sum(temp, dim=2, keepdim=True) 68 | temp = torch.ge(temp, 0.1) 69 | 70 | dec_attn_mask = torch.logical_and(temp, dec_attn_mask) 71 | dec_attn_mask = dec_attn_mask.transpose(1, 2) 72 | 73 | pos_seq = torch.arange(klen - 1, -1, -1.0, device=states.device, dtype=states.dtype) # [99,...0] 74 | pos_emb = self.pos_emb(pos_seq) # T x 1 x dim 75 | 76 | hids = [x] 77 | for i, layer in enumerate(self.blocks): 78 | if mems is not None: 79 | x = layer(x, pos_emb, self.r_w_bias, self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems[i]) 80 | else: 81 | x = layer(x, pos_emb, self.r_w_bias, self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=None) 82 | hids.append(x) 83 | 84 | if mems is None: 85 | new_mems = hids 86 | else: 87 | new_mems = self._update_mems(hids, mems, mlen, qlen) 88 | mem_begin_index = mem_begin_index + 1 89 | x = self.ln_f(x) 90 | logits = self.head(x) 91 | return logits, new_mems, mem_begin_index 92 | 93 | 94 | def _update_mems(self, hids, mems, mlen, qlen): 95 | # does not deal with None 96 | if mems is None: return None 97 | 98 | # mems is not None 99 | assert len(hids) == len(mems), 'len(hids) != len(mems)' 100 | 101 | new_mems = [] 102 | end_idx = mlen + max(0, qlen) # ext_len looks to usually be 0 (in their experiments anyways 103 | beg_idx = max(0, end_idx - self.mem_len) #if hids[0].shape[0] > 1 else 0 104 | for i in range(len(hids)): 105 | cat = torch.cat([mems[i], hids[i]], dim=0) # (m_len + q) x B x dim 106 | aa=1 107 | if beg_idx == end_idx: # cfg.mem_len=0 108 | new_mems.append(torch.zeros(cat[0:1].size())) 109 | else: # cfg.mem_len > 0 110 | new_mems.append(cat[beg_idx:end_idx]) 111 | 112 | return new_mems -------------------------------------------------------------------------------- /dist_launch.py: -------------------------------------------------------------------------------- 1 | # codebase: torch.distributed.launch.py 2 | 3 | import sys 4 | import subprocess 5 | import os 6 | from argparse import ArgumentParser, REMAINDER 7 | 8 | 9 | def parse_args(): 10 | """ 11 | Helper function parsing the command line options 12 | @retval ArgumentParser 13 | """ 14 | parser = ArgumentParser(description="PyTorch distributed training launch " 15 | "helper utility that will spawn up " 16 | "multiple distributed processes") 17 | 18 | # Optional arguments for the launch helper 19 | parser.add_argument("--nnodes", type=int, default=1, 20 | help="The number of nodes to use for distributed " 21 | "training") 22 | parser.add_argument("--node_rank", type=int, default=0, 23 | help="The rank of the node for multi-node distributed " 24 | "training") 25 | parser.add_argument("--nproc_per_node", type=int, default=1, 26 | help="The number of processes to launch on each node, " 27 | "for GPU training, this is recommended to be set " 28 | "to the number of GPUs in your system so that " 29 | "each process can be bound to a single GPU.") 30 | parser.add_argument("--master_addr", default="127.0.0.1", type=str, 31 | help="Master node (rank 0)'s address, should be either " 32 | "the 
IP address or the hostname of node 0, for " 33 | "single node multi-proc training, the " 34 | "--master_addr can simply be 127.0.0.1") 35 | parser.add_argument("--master_port", default=2901, type=int, 36 | help="Master node (rank 0)'s free port that needs to " 37 | "be used for communication during distributed " 38 | "training") 39 | parser.add_argument("-m", "--module", default=False, action="store_true", 40 | help="Changes each process to interpret the launch script " 41 | "as a python module, executing with the same behavior as" 42 | "'python -m'.") 43 | 44 | # positional 45 | parser.add_argument("training_script", type=str, 46 | help="The full path to the single GPU training " 47 | "program/script to be launched in parallel, " 48 | "followed by all the arguments for the " 49 | "training script") 50 | 51 | # rest from the training program 52 | parser.add_argument('training_script_args', nargs=REMAINDER) 53 | return parser.parse_args() 54 | 55 | def main(): 56 | args = parse_args() 57 | 58 | # world size in terms of number of processes 59 | dist_world_size = args.nproc_per_node * args.nnodes 60 | 61 | # set PyTorch distributed related environmental variables 62 | current_env = os.environ.copy() 63 | current_env["MASTER_ADDR"] = args.master_addr 64 | current_env["MASTER_PORT"] = str(args.master_port) 65 | current_env["WORLD_SIZE"] = str(dist_world_size) 66 | 67 | processes = [] 68 | 69 | if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1: 70 | current_env["OMP_NUM_THREADS"] = str(1) 71 | print("*****************************************\n" 72 | "Setting OMP_NUM_THREADS environment variable for each process " 73 | "to be {} in default, to avoid your system being overloaded, " 74 | "please further tune the variable for optimal performance in " 75 | "your application as needed. 
\n" 76 | "*****************************************".format(current_env["OMP_NUM_THREADS"])) 77 | 78 | for local_rank in range(0, args.nproc_per_node): 79 | # each process's rank 80 | dist_rank = args.nproc_per_node * args.node_rank + local_rank 81 | current_env["RANK"] = str(dist_rank) 82 | current_env["LOCAL_RANK"] = str(local_rank) 83 | 84 | # spawn the processes 85 | cmd = [sys.executable, "-u"] 86 | if args.module: 87 | cmd.append("-m") 88 | 89 | cmd.append(args.training_script) 90 | 91 | cmd.append("dist.local_rank={}".format(local_rank)) 92 | cmd.append("dist.nproc_per_node={}".format(args.nproc_per_node)) 93 | cmd.append("dist.world_rank={}".format(dist_rank)) 94 | cmd.append("dist.world_size={}".format(dist_world_size)) 95 | 96 | cmd.extend(args.training_script_args) 97 | 98 | process = subprocess.Popen(cmd, env=current_env) 99 | processes.append(process) 100 | 101 | for process in processes: 102 | process.wait() 103 | if process.returncode != 0: 104 | raise subprocess.CalledProcessError(returncode=process.returncode, 105 | cmd=cmd) 106 | 107 | 108 | if __name__ == "__main__": 109 | main() -------------------------------------------------------------------------------- /brain_agent/core/core_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import psutil 3 | from collections import OrderedDict 4 | from brain_agent.utils.logger import log 5 | from faster_fifo import Full, Empty 6 | 7 | class TaskType: 8 | INIT, TERMINATE, RESET, ROLLOUT_STEP, POLICY_STEP, TRAIN, INIT_MODEL, EMPTY = range(8) 9 | 10 | def dict_of_lists_append(dict_of_lists, new_data, index): 11 | for key, x in new_data.items(): 12 | if key in dict_of_lists: 13 | dict_of_lists[key].append(x[index]) 14 | else: 15 | dict_of_lists[key] = [x[index]] 16 | 17 | def copy_dict_structure(d): 18 | d_copy = type(d)() 19 | _copy_dict_structure_func(d, d_copy) 20 | return d_copy 21 | 22 | def _copy_dict_structure_func(d, d_copy): 23 | for key, value in d.items(): 24 | if isinstance(value, (dict, OrderedDict)): 25 | d_copy[key] = type(value)() 26 | _copy_dict_structure_func(value, d_copy[key]) 27 | else: 28 | d_copy[key] = None 29 | 30 | def iterate_recursively(d): 31 | for k, v in d.items(): 32 | if isinstance(v, (dict, OrderedDict)): 33 | yield from iterate_recursively(v) 34 | else: 35 | yield d, k, v 36 | 37 | def iter_dicts_recursively(d1, d2): 38 | for k, v in d1.items(): 39 | assert k in d2 40 | 41 | if isinstance(v, (dict, OrderedDict)): 42 | yield from iter_dicts_recursively(d1[k], d2[k]) 43 | else: 44 | yield d1, d2, k, d1[k], d2[k] 45 | 46 | 47 | def slice_mems(mems_buffer, mems_dones_buffer, mems_actions_buffer, actor_idx, split_idx, env_idx, s_idx, e_idx): 48 | # Slice given mems buffers in a cyclic queue manner 49 | if s_idx > e_idx: 50 | mems = torch.cat( 51 | [mems_buffer[actor_idx, split_idx, env_idx, s_idx:], 52 | mems_buffer[actor_idx, split_idx, env_idx, :e_idx]]) 53 | mems_dones = torch.cat( 54 | [mems_dones_buffer[actor_idx, split_idx, env_idx, s_idx:], 55 | mems_dones_buffer[actor_idx, split_idx, env_idx, :e_idx]]) 56 | mems_actions = torch.cat( 57 | [mems_actions_buffer[actor_idx, split_idx, env_idx, s_idx:], 58 | mems_actions_buffer[actor_idx, split_idx, env_idx, :e_idx]]) 59 | else: 60 | mems = mems_buffer[actor_idx, split_idx, env_idx, s_idx:e_idx] 61 | mems_dones = mems_dones_buffer[actor_idx, split_idx, env_idx, s_idx:e_idx] 62 | mems_actions = mems_actions_buffer[actor_idx, split_idx, env_idx, s_idx:e_idx] 63 | return mems, mems_dones, 
mems_actions 64 | 65 | def join_or_kill(process, timeout=1.0): 66 | process.join(timeout) 67 | if process.is_alive(): 68 | log.warning('Process %r could not join, kill it with fire!', process) 69 | process.kill() 70 | log.warning('Process %r is dead (%r)', process, process.is_alive()) 71 | 72 | def set_process_cpu_affinity(worker_idx, num_workers, local_rank=0, nproc_per_node=0): 73 | curr_process = psutil.Process() 74 | available_cores = curr_process.cpu_affinity() 75 | cpu_count = len(available_cores) 76 | if nproc_per_node > 1: 77 | worker_idx = worker_idx * nproc_per_node + local_rank 78 | num_workers = num_workers * nproc_per_node 79 | core_indices = cores_for_worker_process(worker_idx, num_workers, cpu_count) 80 | if core_indices is not None: 81 | curr_process_cores = [available_cores[c] for c in core_indices] 82 | curr_process.cpu_affinity(curr_process_cores) 83 | 84 | log.debug('Worker %d uses CPU cores %r', worker_idx, curr_process.cpu_affinity()) 85 | 86 | def cores_for_worker_process(worker_idx, num_workers, cpu_count): 87 | """ 88 | Returns core indices, assuming available cores are [0, ..., cpu_count). 89 | If this is not the case (e.g. SLURM) use these as indices in the array of actual available cores. 90 | """ 91 | 92 | worker_idx_modulo = worker_idx % cpu_count 93 | 94 | cores = None 95 | whole_workers_per_core = num_workers // cpu_count 96 | if worker_idx < whole_workers_per_core * cpu_count: 97 | cores = [worker_idx_modulo] 98 | else: 99 | remaining_workers = num_workers % cpu_count 100 | if cpu_count % remaining_workers == 0: 101 | cores_to_use = cpu_count // remaining_workers 102 | cores = list(range(worker_idx_modulo * cores_to_use, (worker_idx_modulo + 1) * cores_to_use, 1)) 103 | 104 | return cores 105 | 106 | def safe_put(q, msg, attempts=3, queue_name=''): 107 | safe_put_many(q, [msg], attempts, queue_name) 108 | 109 | 110 | def safe_put_many(q, msgs, attempts=3, queue_name=''): 111 | for attempt in range(attempts): 112 | try: 113 | q.put_many(msgs) 114 | return 115 | except Full: 116 | log.warning('Could not put msgs to queue, the queue %s is full! Attempt %d', queue_name, attempt) 117 | 118 | log.error('Failed to put msgs to queue %s after %d attempts. 
Messages are lost!', queue_name, attempts) 119 | 120 | def safe_get(q, timeout=1e6, msg='Queue timeout'): 121 | """Using queue.get() with timeout is necessary, otherwise KeyboardInterrupt is not handled.""" 122 | while True: 123 | try: 124 | return q.get(timeout=timeout) 125 | except Empty: 126 | log.info('Queue timed out (%s), timeout %.3f', msg, timeout) -------------------------------------------------------------------------------- /brain_agent/core/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from brain_agent.core.models.model_abc import EncoderBase 4 | from brain_agent.core.models.model_utils import get_obs_shape, calc_num_elements, nonlinearity 5 | from brain_agent.utils.logger import log 6 | 7 | class ResnetEncoder(EncoderBase): 8 | def __init__(self, cfg, obs_space): 9 | super().__init__(cfg) 10 | 11 | obs_shape = get_obs_shape(obs_space) 12 | input_ch = obs_shape.obs[0] 13 | self.input_ch = input_ch 14 | log.debug('Num input channels: %d', input_ch) 15 | 16 | if cfg.model.encoder.encoder_subtype == 'resnet_impala': 17 | resnet_conf = [[16, 2], [32, 2], [32, 2]] 18 | elif cfg.model.encoder.encoder_subtype == 'resnet_impala_large': 19 | resnet_conf = [[32, 2], [64, 2], [64, 2]] 20 | else: 21 | raise NotImplementedError(f'Unknown resnet subtype {cfg.model.encoder.encoder_subtype}') 22 | 23 | curr_input_channels = input_ch 24 | layers = [] 25 | if cfg.learner.use_decoder: 26 | layers_decoder = [] 27 | for i, (out_channels, res_blocks) in enumerate(resnet_conf): 28 | if cfg.model.encoder.encoder_pooling == 'stride': 29 | enc_stride = 2 30 | pool = nn.Identity 31 | else: 32 | enc_stride = 1 33 | pool = nn.MaxPool2d if cfg.model.encoder.encoder_pooling == 'max' else nn.AvgPool2d 34 | layers.extend([ 35 | nn.Conv2d(curr_input_channels, out_channels, kernel_size=3, stride=enc_stride, padding=1), 36 | pool(kernel_size=3, stride=2, padding=1), # padding SAME 37 | ]) 38 | 39 | for j in range(res_blocks): 40 | layers.append(ResBlock(cfg, out_channels, out_channels)) 41 | 42 | if cfg.learner.use_decoder: 43 | for j in range(res_blocks): 44 | layers_decoder.append(ResBlock(cfg, curr_input_channels, curr_input_channels)) 45 | layers_decoder.append( 46 | nn.ConvTranspose2d(out_channels, curr_input_channels, kernel_size=3, stride=2, 47 | padding=1, output_padding=1) 48 | ) 49 | curr_input_channels = out_channels 50 | 51 | layers.append(nonlinearity(cfg)) 52 | 53 | self.conv_head = nn.Sequential(*layers) 54 | self.conv_head_out_size = calc_num_elements(self.conv_head, obs_shape.obs) 55 | log.debug('Convolutional layer output size: %r', self.conv_head_out_size) 56 | self.init_fc_blocks(self.conv_head_out_size) 57 | 58 | if cfg.learner.use_decoder: 59 | layers_decoder.reverse() 60 | self.deconv_head = nn.Sequential(*layers_decoder) 61 | 62 | self.apply(self.initialize) 63 | 64 | def initialize(self, layer): 65 | gain = 1.0 66 | if hasattr(layer, 'bias') and isinstance(layer.bias, torch.nn.parameter.Parameter): 67 | layer.bias.data.fill_(0) 68 | 69 | if self.cfg.model.encoder.encoder_init == 'orthogonal': 70 | if type(layer) == nn.Conv2d or type(layer) == nn.Linear: 71 | nn.init.orthogonal_(layer.weight.data, gain=gain) 72 | elif type(layer) == nn.GRUCell or type(layer) == nn.LSTMCell: # TODO: test for LSTM 73 | nn.init.orthogonal_(layer.weight_ih, gain=gain) 74 | nn.init.orthogonal_(layer.weight_hh, gain=gain) 75 | layer.bias_ih.data.fill_(0) 76 | layer.bias_hh.data.fill_(0) 77 | elif 
self.cfg.model.encoder.encoder_init == 'xavier_uniform': 78 | if type(layer) == nn.Conv2d or type(layer) == nn.Linear: 79 | nn.init.xavier_uniform_(layer.weight.data, gain=gain) 80 | layer.bias.data.fill_(0) 81 | elif type(layer) == nn.GRUCell or type(layer) == nn.LSTMCell: 82 | nn.init.xavier_uniform_(layer.weight_ih, gain=gain) 83 | nn.init.xavier_uniform_(layer.weight_hh, gain=gain) 84 | layer.bias_ih.data.fill_(0) 85 | layer.bias_hh.data.fill_(0) 86 | elif self.cfg.model.encoder.encoder_init == 'torch_default': 87 | pass 88 | else: 89 | raise NotImplementedError 90 | 91 | def forward(self, obs_dict, decode=False): 92 | x = self.conv_head(obs_dict['obs']) 93 | 94 | if decode: 95 | self.reconstruction = torch.tanh(self.deconv_head(x)) 96 | 97 | x = x.contiguous().view(-1, self.conv_head_out_size) 98 | x = self.forward_fc_blocks(x) 99 | return x 100 | 101 | 102 | class ResBlock(nn.Module): 103 | def __init__(self, cfg, input_ch, output_ch): 104 | super().__init__() 105 | 106 | layers = [ 107 | nonlinearity(cfg), 108 | nn.Conv2d(input_ch, output_ch, kernel_size=3, stride=1, padding=1), # padding SAME 109 | nonlinearity(cfg), 110 | nn.Conv2d(output_ch, output_ch, kernel_size=3, stride=1, padding=1), # padding SAME 111 | ] 112 | 113 | self.res_block_core = nn.Sequential(*layers) 114 | 115 | def forward(self, x): 116 | identity = x 117 | out = self.res_block_core(x) 118 | out = out + identity 119 | return out 120 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import numpy as np 4 | from faster_fifo import Queue, Empty 5 | from tensorboardX import SummaryWriter 6 | from brain_agent.core.actor_worker import ActorWorker 7 | from brain_agent.core.policy_worker import PolicyWorker 8 | from brain_agent.core.shared_buffer import SharedBuffer 9 | from brain_agent.utils.cfg import Configs 10 | from brain_agent.utils.utils import get_log_path, dict_of_list_put, AttrDict, get_summary_dir 11 | from brain_agent.core.core_utils import TaskType 12 | from brain_agent.utils.logger import log, init_logger 13 | from brain_agent.envs.env_utils import create_env 14 | 15 | def main(): 16 | 17 | cfg = Configs.get_defaults() 18 | cfg = Configs.override_from_file_name(cfg) 19 | cfg = Configs.override_from_cli(cfg) 20 | 21 | cfg_str = Configs.to_yaml(cfg) 22 | cfg = Configs.to_attr_dict(cfg) 23 | 24 | init_logger(cfg.log.log_level, get_log_path(cfg)) 25 | 26 | log.info(f'Experiment configuration:\n{cfg_str}') 27 | 28 | tmp_env = create_env(cfg, env_config=None) 29 | action_space = tmp_env.action_space 30 | obs_space = tmp_env.observation_space 31 | level_info = tmp_env.level_info 32 | num_levels = level_info['num_levels'] 33 | tmp_env.close() 34 | 35 | assert cfg.env.one_task_per_worker 36 | assert cfg.test.is_test 37 | assert cfg.actor.num_workers >= level_info['num_levels'] 38 | 39 | shared_buffer = SharedBuffer(cfg, obs_space, action_space) 40 | shared_buffer.stop_experience_collection.fill_(False) 41 | 42 | policy_worker_queue = Queue() 43 | actor_worker_queues = [Queue(2 * 1000 * 1000) for _ in range(cfg.actor.num_workers)] 44 | policy_queue = Queue() 45 | report_queue = Queue(40 * 1000 * 1000) 46 | 47 | policy_worker = PolicyWorker(cfg, obs_space, action_space, tmp_env.level_info, shared_buffer, 48 | policy_queue, actor_worker_queues, policy_worker_queue, report_queue) 49 | policy_worker.start_process() 50 | policy_worker.init() 51 | 
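    # The policy worker presumably restores the weights pointed to by test.checkpoint
    # (see configs/default.yaml) before any actor workers start collecting episodes.
    # An assumed example invocation, using the same key=value override style that
    # dist_launch.py appends to the training script:
    #   python eval.py cfg=configs/trxl_baseline_eval.yaml train_dir=/tmp/exp experiment=eval_run test.checkpoint=/path/to/checkpoint.pth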
policy_worker.load_model() 52 | 53 | actor_workers = [] 54 | for i in range(cfg.actor.num_workers): 55 | w = ActorWorker(cfg, obs_space, action_space, i, shared_buffer, actor_worker_queues[i], policy_queue, 56 | report_queue) 57 | w.init() 58 | w.request_reset() 59 | actor_workers.append(w) 60 | 61 | writer = SummaryWriter(get_summary_dir(cfg, postfix='test')) 62 | 63 | stats = AttrDict() 64 | stats['episodic_stats'] = AttrDict() 65 | actor_worker_task_id = AttrDict() 66 | 67 | env_steps = 0 68 | num_collected = 0 69 | terminate = False 70 | 71 | while not terminate: 72 | try: 73 | reports = report_queue.get_many(timeout=0.1) 74 | for report in reports: 75 | if 'terminate' in report: 76 | terminate = True 77 | if 'learner_env_steps' in report: 78 | env_steps = report['learner_env_steps'] 79 | if 'initialized_env' in report: 80 | actor_idx, split_idx, _, task_id = report['initialized_env'] 81 | actor_worker_task_id[actor_idx] = task_id[0] 82 | if 'episodic_stats' in report: 83 | s = report['episodic_stats'] 84 | level_name = s['level_name'].replace('_contributed/dmlab30/', '') 85 | level_id = s['task_id'] 86 | 87 | tag = f'_dmlab/{level_id:02d}_{level_name}_human_norm_score' 88 | dict_of_list_put(stats.episodic_stats, tag, s['hns'], cfg.test.test_num_episodes) 89 | 90 | hns = s['hns'] 91 | log.info(f'[{num_collected} / {num_levels * cfg.test.test_num_episodes}] {level_id:02d}_' 92 | f'{level_name}: {hns}') 93 | 94 | if len(stats.episodic_stats[tag]) >= cfg.test.test_num_episodes: 95 | for i, w in enumerate(actor_workers): 96 | if actor_worker_task_id[i] == level_id and w.process.is_alive: 97 | actor_worker_queues[i].put((TaskType.TERMINATE, None)) 98 | 99 | hns = [] 100 | num_collected = 0 101 | for i, l in enumerate(level_info['all_levels']): 102 | tag = f'_dmlab/{i:02d}_{l}_human_norm_score' 103 | h = stats.episodic_stats.get(tag, None) 104 | if h is not None: 105 | num_collected += len(h) 106 | hns.append(np.array(h).mean()) 107 | 108 | if num_collected >= num_levels * cfg.test.test_num_episodes: 109 | hns = np.array(hns) 110 | capped_hns = np.clip(hns, None, 100) 111 | log.info('-' * 100) 112 | log.info(f'num_collected: {num_collected}') 113 | log.info(f'mean_human_norm_score: {hns.mean()}') 114 | log.info(f'mean_capped_human_norm_score: {capped_hns.mean()}') 115 | log.info(f'median_human_norm_score: {np.median(hns)}') 116 | for i, l in enumerate(level_info['all_levels']): 117 | tag = f'_dmlab/{i:02d}_{l}_human_norm_score' 118 | h = stats.episodic_stats[tag] 119 | log.info(f'{tag}: {np.array(h).mean()}') 120 | 121 | writer.add_scalar(f'_dmlab/000_mean_human_norm_score', hns.mean(), env_steps) 122 | writer.add_scalar(f'_dmlab/000_mean_capped_human_norm_score', capped_hns.mean(), env_steps) 123 | writer.add_scalar(f'_dmlab/000_median_human_norm_score', np.median(hns), env_steps) 124 | for tag, scalar in stats.episodic_stats.items(): 125 | writer.add_scalar(tag, np.array(scalar).mean(), env_steps) 126 | 127 | terminate = True 128 | 129 | except Empty: 130 | time.sleep(1.0) 131 | pass 132 | 133 | if __name__ == '__main__': 134 | sys.exit(main()) -------------------------------------------------------------------------------- /brain_agent/core/models/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Tuple 4 | from torch.nn.utils.rnn import PackedSequence, invert_permutation 5 | 6 | def _build_pack_info_from_dones(dones: torch.Tensor, T: int) -> Tuple[torch.Tensor, torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]: 7 | num_samples = len(dones) 8 | 9 | rollout_boundaries = dones.clone().detach() 10 | rollout_boundaries[T - 1::T] = 1 11 | rollout_boundaries = rollout_boundaries.nonzero().squeeze(dim=1) + 1 12 | 13 | first_len = rollout_boundaries[0].unsqueeze(0) 14 | 15 | if len(rollout_boundaries) <= 1: 16 | rollout_lengths = first_len 17 | else: 18 | rollout_lengths = rollout_boundaries[1:] - rollout_boundaries[:-1] 19 | rollout_lengths = torch.cat([first_len, rollout_lengths]) 20 | 21 | rollout_starts_orig = rollout_boundaries - rollout_lengths 22 | 23 | is_new_episode = dones.clone().detach().view((-1, T)) 24 | is_new_episode = is_new_episode.roll(1, 1) 25 | 26 | is_new_episode[:, 0] = 0 27 | is_new_episode = is_new_episode.view((-1, )) 28 | 29 | lengths, sorted_indices = torch.sort(rollout_lengths, descending=True) 30 | 31 | cpu_lengths = lengths.to(device='cpu', non_blocking=True) 32 | 33 | rollout_starts_sorted = rollout_starts_orig.index_select(0, sorted_indices) 34 | 35 | select_inds = torch.empty(num_samples, device=dones.device, dtype=torch.int64) 36 | 37 | max_length = int(cpu_lengths[0].item()) 38 | 39 | batch_sizes = torch.empty((max_length,), device='cpu', dtype=torch.int64) 40 | 41 | offset = 0 42 | prev_len = 0 43 | num_valid_for_length = lengths.size(0) 44 | 45 | unique_lengths = torch.unique_consecutive(cpu_lengths) 46 | 47 | for i in range(len(unique_lengths) - 1, -1, -1): 48 | valids = lengths[0:num_valid_for_length] > prev_len 49 | num_valid_for_length = int(valids.float().sum().item()) 50 | 51 | next_len = int(unique_lengths[i]) 52 | 53 | batch_sizes[prev_len:next_len] = num_valid_for_length 54 | 55 | new_inds = ( 56 | rollout_starts_sorted[0:num_valid_for_length].view(1, num_valid_for_length) 57 | + torch.arange(prev_len, next_len, device=rollout_starts_sorted.device).view(next_len - prev_len, 1) 58 | ).view(-1) 59 | 60 | select_inds[offset:offset + new_inds.numel()] = new_inds 61 | 62 | offset += new_inds.numel() 63 | 64 | prev_len = next_len 65 | 66 | assert offset == num_samples 67 | assert is_new_episode.shape[0] == num_samples 68 | 69 | return rollout_starts_orig, is_new_episode, select_inds, batch_sizes, sorted_indices 70 | 71 | def _build_rnn_inputs(x, dones_cpu, rnn_states, T: int): 72 | rollout_starts, is_new_episode, select_inds, batch_sizes, sorted_indices = _build_pack_info_from_dones( 73 | dones_cpu, T) 74 | inverted_select_inds = invert_permutation(select_inds) 75 | 76 | def device(t): 77 | return t.to(device=x.device) 78 | 79 | select_inds = device(select_inds) 80 | inverted_select_inds = device(inverted_select_inds) 81 | sorted_indices = device(sorted_indices) 82 | rollout_starts = device(rollout_starts) 83 | is_new_episode = device(is_new_episode) 84 | 85 | x_seq = PackedSequence(x.index_select(0, select_inds), batch_sizes, sorted_indices) 86 | 87 | rnn_states = rnn_states.index_select(0, rollout_starts) 88 | is_same_episode = (1 - is_new_episode.view(-1, 1)).index_select(0, rollout_starts) 89 | rnn_states = rnn_states * is_same_episode 90 | 91 | return x_seq, rnn_states, inverted_select_inds 92 | 93 | 94 | class LSTM(nn.Module): 95 | def __init__(self, cfg, input_size): 96 | super().__init__() 97 | 98 | self.cfg = cfg 99 | self.core = nn.LSTM(input_size, cfg.model.core.hidden_size, cfg.model.core.n_rnn_layer) 100 | 101 | self.core_output_size = cfg.model.core.hidden_size 102 | self.n_rnn_layer = cfg.model.core.n_rnn_layer 103 | self.apply(self.initialize) 104 | 105 | def initialize(self, layer): 106 | gain = 1.0 107 | 
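        # The 'tensorflow_default' branch below mirrors TensorFlow's default LSTM
        # initialization: Xavier-uniform input-to-hidden weights, orthogonal
        # hidden-to-hidden weights, zero biases, and a forget-gate bias of 1
        # (the second quarter of PyTorch's packed bias vector), which helps the
        # LSTM retain memory early in training.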
108 | if self.cfg.model.core.core_init == 'tensorflow_default': 109 | if type(layer) == nn.LSTM: 110 | for n, p in layer.named_parameters(): 111 | if 'weight_ih' in n: 112 | nn.init.xavier_uniform_(p.data, gain=gain) 113 | elif 'weight_hh' in n: 114 | nn.init.orthogonal_(p.data) 115 | elif 'bias_ih' in n: 116 | p.data.fill_(0) 117 | # Set forget-gate bias to 1 118 | n = p.size(0) 119 | p.data[(n // 4):(n // 2)].fill_(1) 120 | elif 'bias_hh' in n: 121 | p.data.fill_(0) 122 | elif self.cfg.model.core.core_init == 'torch_default': 123 | pass 124 | else: 125 | raise NotImplementedError 126 | 127 | def forward(self, head_output, rnn_states, dones, is_seq): 128 | if not is_seq: 129 | head_output = head_output.unsqueeze(0) 130 | 131 | if self.n_rnn_layer > 1: 132 | rnn_states = rnn_states.view(rnn_states.size(0), self.cfg.model.core.n_rnn_layer, -1) 133 | rnn_states = rnn_states.permute(1, 0, 2) 134 | else: 135 | rnn_states = rnn_states.unsqueeze(0) 136 | 137 | h, c = torch.split(rnn_states, self.cfg.model.core.hidden_size, dim=2) 138 | 139 | x, (h, c) = self.core(head_output, (h.contiguous(), c.contiguous())) 140 | new_rnn_states = torch.cat((h, c), dim=2) 141 | 142 | if not is_seq: 143 | x = x.squeeze(0) 144 | 145 | if self.n_rnn_layer > 1: 146 | new_rnn_states = new_rnn_states.permute(1, 0, 2) 147 | new_rnn_states = new_rnn_states.reshape(new_rnn_states.size(0), -1) 148 | else: 149 | new_rnn_states = new_rnn_states.squeeze(0) 150 | 151 | return x, new_rnn_states 152 | 153 | def get_core_out_size(self): 154 | return self.core_output_size 155 | 156 | @classmethod 157 | def build_rnn_inputs(cls, x, dones_cpu, rnn_states, T: int): 158 | return _build_rnn_inputs(x, dones_cpu, rnn_states, T) 159 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | cfg: None # path to configuration (.yaml) file 2 | train_dir: ??? # path to train dir 3 | experiment: ??? # experiment name. logs will be saved train_dir/experiment 4 | 5 | seed: 0 6 | 7 | # dist arguments will be automatically set from dist_launch.py 8 | dist: 9 | world_size: 1 10 | world_rank: 0 11 | local_rank: 0 12 | nproc_per_node: 1 13 | dist_backend: nccl 14 | 15 | model: 16 | agent: dmlab_multitask_agent # [dmlab_multitask_agent] 17 | encoder: 18 | encoder_type: resnet # [resnet]. encoder type. only resnet is currently implemented 19 | encoder_subtype: resnet_impala # [resnet_impala, resnet_impala_large]. specific encoder architecture. see resnet.py 20 | encoder_pooling: max # [max, stride, ???] pooling type. max pooling or conv2d with stride 2. Use avg pooling if not specified. 21 | nonlinearity: relu # [relu, elu, tanh] activation function used in encoder 22 | nonlinear_inplace: False # nonlinearity will be computed "inplace" if True 23 | encoder_extra_fc_layers: 1 # the number of extra fully connected layers that connect encoder outputs and core inputs 24 | encoder_init: orthogonal # [orthogonal, xavier_uniform, torch_default] initialization method used in encoder 25 | core: 26 | core_type: trxl # [trxl, rnn] core type. 
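    # The trxl core reads the Transformer-XL parameters below (mem_len, n_layer,
    # n_heads, d_head, d_inner); the rnn core reads the LSTM parameters
    # (n_rnn_layer, core_init). Any field can be overridden from the command line
    # with dot-separated key=value arguments, the same form dist_launch.py appends
    # (e.g. dist.world_size=...). An assumed example invocation:
    #   python train.py cfg=configs/trxl_baseline_train.yaml train_dir=/tmp/exp experiment=trxl_run model.core.mem_len=256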
27 |     hidden_size: 256 # size of hidden layer in the model 28 |     # trxl params 29 |     mem_len: 512 # memory length 30 |     n_layer: 12 # number of layers 31 |     n_heads: 8 # number of MHA heads 32 |     d_head: 64 # MHA head dimension 33 |     d_inner: 2048 # the position-wise FF network dimension 34 |     # rnn params 35 |     n_rnn_layer: 3 36 |     core_init: tensorflow_default # [tensorflow_default, torch_default] initialization method used in core 37 | 38 |   extended_input: True # concatenate one-hot action and reward to encoder output if set True 39 |   use_popart: True # whether to use PopArt normalization or not 40 |   popart_beta: 0.0003 # decay rate to track the mean and standard deviation of the values 41 |   popart_clip_min: 0.0001 # popart minimum clip value for numerical stability 42 |   device: cuda 43 |   use_half_policy_worker: True # use half-precision in Policy Worker 44 | 45 | test: 46 |   is_test: False # set True when testing 47 |   checkpoint: ??? # full path to checkpoint (.pth) 48 |   test_num_episodes: 100 # the number of test episodes per level 49 | 50 | optim: 51 |   type: adam # [adam, adamw] type of optimizer 52 |   learning_rate: 0.0001 # learning rate 53 |   adam_beta1: 0.9 # adam first momentum decay coefficient 54 |   adam_beta2: 0.999 # adam second momentum decay coefficient 55 |   adam_eps: 1e-06 # adam epsilon parameter (1e-8 to 1e-5 seem to reliably work okay) 56 |   max_grad_norm: 2.5 # maximum L2 norm of the gradient vector (for clipping) 57 |   rollout: 96 # length of the rollout from each environment in timesteps 58 |   batch_size: 1536 # total batch size (batch x rollout) 59 |   train_for_env_steps: 20000000000 # stop after the policy has been trained for this many env steps 60 |   warmup_optimizer: 100 # number of warm-up optimizer steps for the learner 61 | 62 | shared_buffer: 63 |   min_traj_buffers_per_worker: 4 # how many shared buffers to allocate per actor worker 64 | 65 | learner: 66 |   exploration_loss_coeff: 1e-3 # coefficient for the exploration component of the loss function 67 |   vtrace_rho: 1.0 # rho_hat clipping parameter of the V-trace algorithm 68 |   vtrace_c: 1.0 # c_hat clipping parameter of the V-trace algorithm 69 |   exclude_last: True # exclude the last timestep's v_s (V-trace value targets) when computing V-trace 70 |   use_ppo: False # whether to use PPO or not 71 |   ppo_clip_ratio: 0.1 # clipping ratio for the PPO algorithm 72 |   ppo_clip_value: 0.2 # maximum absolute change in value estimate before it is clipped 73 |   gamma: 0.99 # discount factor 74 |   value_loss_coeff: 0.5 # coefficient for the critic loss 75 |   keep_checkpoints: 3 # number of model checkpoints to keep 76 |   use_adv_normalization: False # whether to use advantage normalization or not 77 |   psychlab_gamma: -1.0 # specific gamma value for psychlab levels; use "gamma" if < 0
78 |   use_decoder: False # whether to use a decoder for the auxiliary reconstruction component or not 79 |   reconstruction_loss_coeff: 0.0 # coefficient for the auxiliary reconstruction component of the loss function 80 |   use_aux_future_pred_loss: False # whether to add an auxiliary loss for predicting the obs of 2 to 10 future steps using an autoregressive transformer 81 |   aux_future_pred_loss_coeff: 1.0 # coefficient for the auxiliary future-prediction component of the loss function 82 |   resume_training: False # load latest checkpoint if set to True 83 | 84 | actor: 85 |   num_workers: 30 # the number of actor processes for this node 86 |   num_envs_per_worker: 6 # number of envs sampled sequentially on each worker 87 |   num_splits: 2 # set 2 for double buffering 88 |   set_workers_cpu_affinity: True # whether to assign workers to specific CPU cores or not 89 | 90 | env: 91 |   name: dmlab_30 # name of environment. see DMLAB_LEVELS_BY_ENVNAME in dmlab30.py 92 |   res_w: 96 # width of image frame 93 |   res_h: 72 # height of image frame 94 |   dataset_path: ./brady_konkle_oliva2008 # path to dataset needed for some of the environments in DMLab-30 95 |   use_level_cache: True # whether to use the local level cache (highly recommended) 96 |   level_cache_path: ./dmlab_cache # location to store the cached levels 97 |   action_set: extended_action_set # [impala_action_set, extended_action_set, extended_action_set_large]. see dmlab30.py 98 |   frameskip: 4 # the number of frames for action repeat (frame skipping) 99 |   obs_subtract_mean: 128.0 # mean value to subtract from observation 100 |   obs_scale: 128.0 # scale value for normalization 101 |   one_task_per_worker: False # every env in a VectorEnvRunner will run the same level if set to True 102 |   decorrelate_envs_on_one_worker: True # decorrelation of worker processes 103 |   decorrelate_experience_max_seconds: 10 # maximum seconds for decorrelation 104 | 105 | log: 106 |   save_milestones_step: 1000000000 # save intermediate checkpoints in a separate folder for later evaluation 107 |   save_every_sec: 3600 # checkpointing rate 108 |   log_level: debug # logging level 109 |   report_interval: 300.0 # how often in seconds we write summaries 110 |   num_stats_average: 100 # the number of stats to average (for tensorboard logging) -------------------------------------------------------------------------------- /brain_agent/envs/dmlab/dmlab_level_cache.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ctypes 3 | import multiprocessing 4 | import random 5 | import shutil 6 | from pathlib import Path 7 | 8 | from brain_agent.utils.logger import log 9 | 10 | LEVEL_SEEDS_FILE_EXT = 'dm_lvl_seeds' 11 | 12 | 13 | def filename_to_level(filename): 14 |     level = filename.split('.')[0] 15 |     level = level[1:] 16 |     return level 17 | 18 | 19 | def level_to_filename(level): 20 |     filename = f'_{level}.{LEVEL_SEEDS_FILE_EXT}' 21 |     return filename 22 | 23 | 24 | def read_seeds_file(filename, has_keys): 25 |     seeds = [] 26 | 27 |     with open(filename, 'r') as seed_file: 28 |         lines = seed_file.readlines() 29 |         for line in lines: 30 |             try: 31 |                 if has_keys: 32 |                     seed, cache_key = line.split(' ') 33 |                 else: 34 |                     seed = line 35 | 36 |                 seed = int(seed) 37 |                 seeds.append(seed) 38 |             except Exception: 39 |                 pass 40 | 41 |     return seeds 42 | 43 | 44 | class DmlabLevelCacheGlobal: 45 | 46 |     def __init__(self, cache_dir, experiment_dir, levels): 47 |         self.cache_dir = cache_dir 48 |         self.experiment_dir = experiment_dir 49 | 50 |         self.all_seeds =
dict() 51 | self.available_seeds = dict() 52 | self.used_seeds = dict() 53 | self.num_seeds_used_in_current_run = dict() 54 | self.locks = dict() 55 | 56 | for lvl in levels: 57 | self.all_seeds[lvl] = [] 58 | self.available_seeds[lvl] = [] 59 | self.num_seeds_used_in_current_run[lvl] = multiprocessing.RawValue(ctypes.c_int32, 0) 60 | self.locks[lvl] = multiprocessing.Lock() 61 | 62 | lvl_seed_files = Path(os.path.join(cache_dir, '_contributed')).rglob(f'*.{LEVEL_SEEDS_FILE_EXT}') 63 | for lvl_seed_file in lvl_seed_files: 64 | lvl_seed_file = str(lvl_seed_file) 65 | level = filename_to_level(os.path.relpath(lvl_seed_file, cache_dir)) 66 | self.all_seeds[level] = read_seeds_file(lvl_seed_file, has_keys=True) 67 | self.all_seeds[level] = list(set(self.all_seeds[level])) 68 | 69 | used_lvl_seeds_dir = os.path.join(self.experiment_dir, f'dmlab_used_lvl_seeds') 70 | os.makedirs(used_lvl_seeds_dir, exist_ok=True) 71 | 72 | used_seeds_files = Path(used_lvl_seeds_dir).rglob(f'*.{LEVEL_SEEDS_FILE_EXT}') 73 | self.used_seeds = dict() 74 | for used_seeds_file in used_seeds_files: 75 | used_seeds_file = str(used_seeds_file) 76 | level = filename_to_level(os.path.relpath(used_seeds_file, used_lvl_seeds_dir)) 77 | self.used_seeds[level] = read_seeds_file(used_seeds_file, has_keys=False) 78 | 79 | self.used_seeds[level] = set(self.used_seeds[level]) 80 | 81 | for lvl in levels: 82 | lvl_seeds = self.all_seeds.get(lvl, []) 83 | lvl_used_seeds = self.used_seeds.get(lvl, []) 84 | 85 | lvl_remaining_seeds = set(lvl_seeds) - set(lvl_used_seeds) 86 | self.available_seeds[lvl] = list(lvl_remaining_seeds) 87 | 88 | random.shuffle(self.available_seeds[lvl]) 89 | log.debug('Env %s has %d remaining unused seeds', lvl, len(self.available_seeds[lvl])) 90 | 91 | def record_used_seed(self, level, seed): 92 | self.num_seeds_used_in_current_run[level].value += 1 93 | 94 | used_lvl_seeds_dir = os.path.join(self.experiment_dir, f'dmlab_used_lvl_seeds') 95 | used_seeds_filename = os.path.join(used_lvl_seeds_dir, level_to_filename(level)) 96 | os.makedirs(os.path.dirname(used_seeds_filename), exist_ok=True) 97 | 98 | with open(used_seeds_filename, 'a') as fobj: 99 | fobj.write(f'{seed}\n') 100 | 101 | if level not in self.used_seeds: 102 | self.used_seeds[level] = {seed} 103 | else: 104 | self.used_seeds[level].add(seed) 105 | 106 | def get_unused_seed(self, level, random_state=None): 107 | with self.locks[level]: 108 | num_used_seeds = self.num_seeds_used_in_current_run[level].value 109 | if num_used_seeds >= len(self.available_seeds.get(level, [])): 110 | 111 | while True: 112 | if random_state is not None: 113 | new_seed = random_state.randint(0, 2 ** 31 - 1) 114 | else: 115 | new_seed = random.randint(0, 2 ** 31 - 1) 116 | 117 | if level not in self.used_seeds: 118 | break 119 | 120 | if new_seed in self.used_seeds[level]: 121 | pass 122 | else: 123 | break 124 | else: 125 | new_seed = self.available_seeds[level][num_used_seeds] 126 | 127 | self.record_used_seed(level, new_seed) 128 | return new_seed 129 | 130 | def add_new_level(self, level, seed, key, pk3_path): 131 | with self.locks[level]: 132 | num_used_seeds = self.num_seeds_used_in_current_run[level].value 133 | if num_used_seeds < len(self.available_seeds.get(level, [])): 134 | log.warning('We should only add new levels to cache if we ran out of pre-generated levels (seeds)') 135 | log.warning( 136 | 'Num used seeds: %d, available seeds: %d, level: %s, seed %r, key %r', 137 | num_used_seeds, len(self.available_seeds.get(level, [])), level, seed, key, 138 | ) 139 
| 140 | path = os.path.join(self.cache_dir, key) 141 | if not os.path.isfile(path): 142 | shutil.copyfile(pk3_path, path) 143 | 144 | lvl_seeds_filename = os.path.join(self.cache_dir, level_to_filename(level)) 145 | os.makedirs(os.path.dirname(lvl_seeds_filename), exist_ok=True) 146 | with open(lvl_seeds_filename, 'a') as fobj: 147 | fobj.write(f'{seed} {key}\n') 148 | 149 | 150 | def dmlab_ensure_global_cache_initialized(experiment_dir, levels, level_cache_dir): 151 | global DMLAB_GLOBAL_LEVEL_CACHE 152 | 153 | assert multiprocessing.current_process().name == 'MainProcess', \ 154 | 'make sure you initialize DMLab cache before child processes are forked' 155 | 156 | log.info('Initializing level cache...') 157 | DMLAB_GLOBAL_LEVEL_CACHE = DmlabLevelCacheGlobal(level_cache_dir, experiment_dir, levels) 158 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | from faster_fifo import Queue, Empty 4 | import multiprocessing 5 | from tensorboardX import SummaryWriter 6 | import numpy as np 7 | 8 | from brain_agent.core.actor_worker import ActorWorker 9 | from brain_agent.core.policy_worker import PolicyWorker 10 | from brain_agent.core.shared_buffer import SharedBuffer 11 | from brain_agent.core.learner_worker import LearnerWorker 12 | from brain_agent.utils.cfg import Configs 13 | from brain_agent.utils.logger import log, init_logger 14 | from brain_agent.utils.utils import get_log_path, dict_of_list_put, get_summary_dir, AttrDict 15 | from brain_agent.envs.env_utils import create_env 16 | 17 | 18 | def main(): 19 | 20 | cfg = Configs.get_defaults() 21 | cfg = Configs.override_from_file_name(cfg) 22 | cfg = Configs.override_from_cli(cfg) 23 | 24 | cfg_str = Configs.to_yaml(cfg) 25 | cfg = Configs.to_attr_dict(cfg) 26 | 27 | init_logger(cfg.log.log_level, get_log_path(cfg)) 28 | 29 | log.info(f'Experiment configuration:\n{cfg_str}') 30 | 31 | tmp_env = create_env(cfg, env_config=None) 32 | action_space = tmp_env.action_space 33 | obs_space = tmp_env.observation_space 34 | level_info = tmp_env.level_info 35 | tmp_env.close() 36 | 37 | shared_buffer = SharedBuffer(cfg, obs_space, action_space) 38 | 39 | learner_worker_queue = Queue() 40 | policy_worker_queue = Queue() 41 | actor_worker_queues = [Queue(2 * 1000 * 1000) for _ in range(cfg.actor.num_workers)] 42 | policy_queue = Queue() 43 | report_queue = Queue(40 * 1000 * 1000) 44 | 45 | policy_lock = multiprocessing.Lock() 46 | resume_experience_collection_cv = multiprocessing.Condition() 47 | 48 | learner_worker = LearnerWorker(cfg, obs_space, action_space, level_info, report_queue, learner_worker_queue, 49 | policy_worker_queue, 50 | shared_buffer, policy_lock, resume_experience_collection_cv) 51 | learner_worker.start_process() 52 | learner_worker.init() 53 | 54 | policy_worker = PolicyWorker(cfg, obs_space, action_space, level_info, shared_buffer, 55 | policy_queue, actor_worker_queues, policy_worker_queue, report_queue, policy_lock, resume_experience_collection_cv) 56 | policy_worker.start_process() # init(), init_model() will be triggered from learner worker 57 | 58 | actor_workers = [] 59 | for i in range(cfg.actor.num_workers): 60 | w = ActorWorker(cfg, obs_space, action_space, i, shared_buffer, actor_worker_queues[i], policy_queue, 61 | report_queue, learner_worker_queue) 62 | w.init() 63 | w.request_reset() 64 | actor_workers.append(w) 65 | 66 | summary_dir = get_summary_dir(cfg=cfg) 
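    # Only the rank-0 process creates a TensorBoard writer; the other ranks keep
    # writer=None so a multi-process run does not emit duplicate event files.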
67 | writer = SummaryWriter(summary_dir) if cfg.dist.world_rank == 0 else None 68 | 69 | # Add configuration in tensorboard 70 | if cfg.dist.world_rank == 0: 71 | cfg_str = cfg_str.replace(' ', ' ').replace('\n', ' \n') 72 | writer.add_text('cfg', cfg_str, 0) 73 | 74 | stats = AttrDict() 75 | stats['episodic_stats'] = AttrDict() 76 | 77 | last_report = time.time() 78 | last_env_steps = 0 79 | terminate = False 80 | reports = [] 81 | 82 | while not terminate: 83 | try: 84 | reports.extend(report_queue.get_many(timeout=0.1)) 85 | 86 | if time.time() - last_report > cfg.log.report_interval: 87 | interval = time.time() - last_report 88 | last_report = time.time() 89 | terminate, last_env_steps = process_report(cfg, reports, writer, stats, last_env_steps, level_info, 90 | interval) 91 | reports = [] 92 | except Empty: 93 | time.sleep(1.0) 94 | pass 95 | 96 | def process_report(cfg, reports, writer, stats, last_env_steps, level_info, interval): 97 | terminate = False 98 | env_steps = last_env_steps 99 | 100 | for report in reports: 101 | if report is not None: 102 | if 'terminate' in report: 103 | terminate = True 104 | if 'learner_env_steps' in report: 105 | env_steps = report['learner_env_steps'] 106 | if 'train' in report: 107 | s = report['train'] 108 | for k, v in s.items(): 109 | dict_of_list_put(stats, f'train/{k}', v, cfg.log.num_stats_average) 110 | if 'episodic_stats' in report: 111 | s = report['episodic_stats'] 112 | level_name = s['level_name'] 113 | level_id = s['task_id'] 114 | 115 | tag = f'_dmlab/{level_id:02d}_{level_name}_human_norm_score' 116 | dict_of_list_put(stats.episodic_stats, tag, s['hns'], cfg.log.num_stats_average) 117 | 118 | fps = (env_steps - last_env_steps) / interval 119 | dict_of_list_put(stats, f'train/_fps', fps, cfg.log.num_stats_average) 120 | 121 | key_timings = ['times_learner_worker', 'times_actor_worker', 'times_policy_worker'] 122 | for key in key_timings: 123 | if key in report: 124 | for k, v in report[key].items(): 125 | tag = key+'/'+k 126 | dict_of_list_put(stats, tag, v, cfg.log.num_stats_average) 127 | 128 | 129 | if writer is not None: 130 | for k, v in stats.items(): 131 | if k == 'episodic_stats': 132 | hns = [] 133 | for kk, vv in v.items(): 134 | writer.add_scalar(kk, np.array(vv).mean(), env_steps) 135 | hns.append(np.array(vv).mean()) 136 | 137 | if len(v.keys()) == level_info['num_levels']: 138 | hns = np.array(hns) 139 | capped_hns = np.clip(hns, None, 100) 140 | writer.add_scalar(f'_dmlab/000_mean_human_norm_score', hns.mean(), env_steps) 141 | writer.add_scalar(f'_dmlab/000_mean_capped_human_norm_score', capped_hns.mean(), env_steps) 142 | writer.add_scalar(f'_dmlab/000_median_human_norm_score', np.median(hns), env_steps) 143 | else: 144 | writer.add_scalar(k, np.array(v).mean(), env_steps) 145 | 146 | if env_steps >= cfg.optim.train_for_env_steps: 147 | terminate = True 148 | return terminate, env_steps 149 | 150 | 151 | if __name__ == '__main__': 152 | sys.exit(main()) 153 | -------------------------------------------------------------------------------- /brain_agent/envs/dmlab/dmlab_gym.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import hashlib 5 | import deepmind_lab 6 | import gym 7 | import numpy as np 8 | 9 | from brain_agent.envs.dmlab import dmlab_level_cache 10 | from brain_agent.envs.dmlab.dmlab30 import DMLAB_INSTRUCTIONS, DMLAB_MAX_INSTRUCTION_LEN, DMLAB_VOCABULARY_SIZE, \ 11 | IMPALA_ACTION_SET, EXTENDED_ACTION_SET, 
EXTENDED_ACTION_SET_LARGE 12 | from brain_agent.utils.logger import log 13 | 14 | 15 | def string_to_hash_bucket(s, vocabulary_size): 16 | return (int(hashlib.md5(s.encode('utf-8')).hexdigest(), 16) % (vocabulary_size - 1)) + 1 17 | 18 | 19 | def dmlab_level_to_level_name(level): 20 | level_name = level.split('/')[-1] 21 | return level_name 22 | 23 | 24 | class DmlabGymEnv(gym.Env): 25 | def __init__( 26 | self, task_id, level, action_repeat, res_w, res_h, dataset_path, 27 | action_set, use_level_cache, level_cache_path, extra_cfg=None, 28 | ): 29 | self.width = res_w 30 | self.height = res_h 31 | 32 | self.main_observation = 'RGB_INTERLEAVED' 33 | self.instructions_observation = DMLAB_INSTRUCTIONS 34 | 35 | self.action_repeat = action_repeat 36 | 37 | self.random_state = None 38 | 39 | self.task_id = task_id 40 | self.level = level 41 | self.level_name = dmlab_level_to_level_name(self.level) 42 | 43 | self.cache = dmlab_level_cache.DMLAB_GLOBAL_LEVEL_CACHE 44 | 45 | self.instructions = np.zeros([DMLAB_MAX_INSTRUCTION_LEN], dtype=np.int32) 46 | 47 | observation_format = [self.main_observation] 48 | observation_format += [self.instructions_observation] 49 | 50 | config = { 51 | 'width': self.width, 52 | 'height': self.height, 53 | 'datasetPath': dataset_path, 54 | 'logLevel': 'error', 55 | } 56 | 57 | if extra_cfg is not None: 58 | config.update(extra_cfg) 59 | config = {k: str(v) for k, v in config.items()} 60 | 61 | self.use_level_cache = use_level_cache 62 | self.level_cache_path = level_cache_path 63 | 64 | env_level_cache = self if use_level_cache else None 65 | self.env_uses_level_cache = False # will be set to True when this env instance queries the cache 66 | self.last_reset_seed = None 67 | 68 | if env_level_cache is not None: 69 | if not isinstance(self.cache, dmlab_level_cache.DmlabLevelCacheGlobal): 70 | raise Exception( 71 | 'DMLab global level cache object is not initialized! 
Make sure to call' 72 | 'dmlab_ensure_global_cache_initialized() in the main thread before you fork any child processes' 73 | 'or create any DMLab envs' 74 | ) 75 | 76 | self.dmlab = deepmind_lab.Lab( 77 | level, observation_format, config=config, renderer='software', level_cache=env_level_cache, 78 | ) 79 | 80 | if action_set == 'impala_action_set': 81 | self.action_set = IMPALA_ACTION_SET 82 | elif action_set == 'extended_action_set': 83 | self.action_set = EXTENDED_ACTION_SET 84 | elif action_set == 'extended_action_set_large': 85 | self.action_set = EXTENDED_ACTION_SET_LARGE 86 | self.action_list = np.array(self.action_set, dtype=np.intc) # DMLAB requires intc type for actions 87 | 88 | self.last_observation = None 89 | 90 | self.action_space = gym.spaces.Discrete(len(self.action_set)) 91 | 92 | self.observation_space = gym.spaces.Dict( 93 | obs=gym.spaces.Box(low=0, high=255, shape=(self.height, self.width, 3), dtype=np.uint8) 94 | ) 95 | self.observation_space.spaces[self.instructions_observation] = gym.spaces.Box( 96 | low=0, high=DMLAB_VOCABULARY_SIZE, shape=[DMLAB_MAX_INSTRUCTION_LEN], dtype=np.int32, 97 | ) 98 | 99 | self.seed() 100 | 101 | def seed(self, seed=None): 102 | if seed is None: 103 | initial_seed = random.randint(0, int(1e9)) 104 | else: 105 | initial_seed = seed 106 | 107 | self.random_state = np.random.RandomState(seed=initial_seed) 108 | return [initial_seed] 109 | 110 | def format_obs_dict(self, env_obs_dict): 111 | """We traditionally uses 'obs' key for the 'main' observation.""" 112 | 113 | env_obs_dict['obs'] = env_obs_dict.pop(self.main_observation) 114 | 115 | instr = env_obs_dict.get(self.instructions_observation) 116 | self.instructions[:] = 0 117 | if instr is not None: 118 | instr_words = instr.split() 119 | for i, word in enumerate(instr_words): 120 | self.instructions[i] = string_to_hash_bucket(word, DMLAB_VOCABULARY_SIZE) 121 | 122 | env_obs_dict[self.instructions_observation] = self.instructions 123 | 124 | return env_obs_dict 125 | 126 | def reset(self): 127 | if self.use_level_cache: 128 | self.last_reset_seed = self.cache.get_unused_seed(self.level, self.random_state) 129 | else: 130 | self.last_reset_seed = self.random_state.randint(0, 2 ** 31 - 1) 131 | 132 | self.dmlab.reset(seed=self.last_reset_seed) 133 | self.last_observation = self.format_obs_dict(self.dmlab.observations()) 134 | self.episodic_reward = 0 135 | return self.last_observation 136 | 137 | def step(self, action): 138 | 139 | reward = self.dmlab.step(self.action_list[action], num_steps=self.action_repeat) 140 | done = not self.dmlab.is_running() 141 | 142 | self.episodic_reward += reward 143 | info = {'num_frames': self.action_repeat} 144 | 145 | if not done: 146 | obs_dict = self.format_obs_dict(self.dmlab.observations()) 147 | self.last_observation = obs_dict 148 | if done: 149 | self.reset() 150 | 151 | return self.last_observation, reward, done, info 152 | 153 | 154 | def close(self): 155 | self.dmlab.close() 156 | 157 | 158 | def fetch(self, key, pk3_path): 159 | if not self.env_uses_level_cache: 160 | self.env_uses_level_cache = True 161 | 162 | path = os.path.join(self.level_cache_path, key) 163 | 164 | if os.path.isfile(path): 165 | shutil.copyfile(path, pk3_path) 166 | return True 167 | else: 168 | log.warning('Cache miss in environment %s key: %s!', self.level_name, key) 169 | return False 170 | 171 | def write(self, key, pk3_path): 172 | log.debug('Add new level to cache! 
Level %s seed %r key %s', self.level_name, self.last_reset_seed, key) 173 | self.cache.add_new_level(self.level, self.last_reset_seed, key, pk3_path) 174 | -------------------------------------------------------------------------------- /brain_agent/envs/dmlab/dmlab30.py: -------------------------------------------------------------------------------- 1 | DMLAB_INSTRUCTIONS = 'INSTR' 2 | DMLAB_VOCABULARY_SIZE = 1000 3 | DMLAB_MAX_INSTRUCTION_LEN = 16 4 | 5 | 6 | HUMAN_SCORES = { 7 | 'rooms_collect_good_objects_test': 10, 8 | 'rooms_exploit_deferred_effects_test': 85.65, 9 | 'rooms_select_nonmatching_object': 65.9, 10 | 'rooms_watermaze': 54, 11 | 'rooms_keys_doors_puzzle': 53.8, 12 | 'language_select_described_object': 389.5, 13 | 'language_select_located_object': 280.7, 14 | 'language_execute_random_task': 254.05, 15 | 'language_answer_quantitative_question': 184.5, 16 | 'lasertag_one_opponent_small': 12.65, 17 | 'lasertag_three_opponents_small': 18.55, 18 | 'lasertag_one_opponent_large': 18.6, 19 | 'lasertag_three_opponents_large': 31.5, 20 | 'natlab_fixed_large_map': 36.9, 21 | 'natlab_varying_map_regrowth': 24.45, 22 | 'natlab_varying_map_randomized': 42.35, 23 | 'skymaze_irreversible_path_hard': 100, 24 | 'skymaze_irreversible_path_varied': 100, 25 | 'psychlab_arbitrary_visuomotor_mapping': 58.75, 26 | 'psychlab_continuous_recognition': 58.3, 27 | 'psychlab_sequential_comparison': 39.5, 28 | 'psychlab_visual_search': 78.5, 29 | 'explore_object_locations_small': 74.45, 30 | 'explore_object_locations_large': 65.65, 31 | 'explore_obstructed_goals_small': 206, 32 | 'explore_obstructed_goals_large': 119.5, 33 | 'explore_goal_locations_small': 267.5, 34 | 'explore_goal_locations_large': 194.5, 35 | 'explore_object_rewards_few': 77.7, 36 | 'explore_object_rewards_many': 106.7, 37 | } 38 | 39 | 40 | RANDOM_SCORES = { 41 | 'rooms_collect_good_objects_test': 0.073, 42 | 'rooms_exploit_deferred_effects_test': 8.501, 43 | 'rooms_select_nonmatching_object': 0.312, 44 | 'rooms_watermaze': 4.065, 45 | 'rooms_keys_doors_puzzle': 4.135, 46 | 'language_select_described_object': -0.07, 47 | 'language_select_located_object': 1.929, 48 | 'language_execute_random_task': -5.913, 49 | 'language_answer_quantitative_question': -0.33, 50 | 'lasertag_one_opponent_small': -0.224, 51 | 'lasertag_three_opponents_small': -0.214, 52 | 'lasertag_one_opponent_large': -0.083, 53 | 'lasertag_three_opponents_large': -0.102, 54 | 'natlab_fixed_large_map': 2.173, 55 | 'natlab_varying_map_regrowth': 2.989, 56 | 'natlab_varying_map_randomized': 7.346, 57 | 'skymaze_irreversible_path_hard': 0.1, 58 | 'skymaze_irreversible_path_varied': 14.4, 59 | 'psychlab_arbitrary_visuomotor_mapping': 0.163, 60 | 'psychlab_continuous_recognition': 0.224, 61 | 'psychlab_sequential_comparison': 0.129, 62 | 'psychlab_visual_search': 0.085, 63 | 'explore_object_locations_small': 3.575, 64 | 'explore_object_locations_large': 4.673, 65 | 'explore_obstructed_goals_small': 6.76, 66 | 'explore_obstructed_goals_large': 2.61, 67 | 'explore_goal_locations_small': 7.66, 68 | 'explore_goal_locations_large': 3.14, 69 | 'explore_object_rewards_few': 2.073, 70 | 'explore_object_rewards_many': 2.438, 71 | } 72 | 73 | DMLAB_LEVELS_BY_ENVNAME = { 74 | 'dmlab_30': 75 | ['contributed/dmlab30/rooms_collect_good_objects_train', 76 | 'contributed/dmlab30/rooms_exploit_deferred_effects_train', 77 | 'contributed/dmlab30/rooms_select_nonmatching_object', 78 | 'contributed/dmlab30/rooms_watermaze', 79 | 'contributed/dmlab30/rooms_keys_doors_puzzle', 80 | 
'contributed/dmlab30/language_select_described_object', 81 | 'contributed/dmlab30/language_select_located_object', 82 | 'contributed/dmlab30/language_execute_random_task', 83 | 'contributed/dmlab30/language_answer_quantitative_question', 84 | 'contributed/dmlab30/lasertag_one_opponent_small', 85 | 'contributed/dmlab30/lasertag_three_opponents_small', 86 | 'contributed/dmlab30/lasertag_one_opponent_large', 87 | 'contributed/dmlab30/lasertag_three_opponents_large', 88 | 'contributed/dmlab30/natlab_fixed_large_map', 89 | 'contributed/dmlab30/natlab_varying_map_regrowth', 90 | 'contributed/dmlab30/natlab_varying_map_randomized', 91 | 'contributed/dmlab30/skymaze_irreversible_path_hard', 92 | 'contributed/dmlab30/skymaze_irreversible_path_varied', 93 | 'contributed/dmlab30/psychlab_arbitrary_visuomotor_mapping', 94 | 'contributed/dmlab30/psychlab_continuous_recognition', 95 | 'contributed/dmlab30/psychlab_sequential_comparison', 96 | 'contributed/dmlab30/psychlab_visual_search', 97 | 'contributed/dmlab30/explore_object_locations_small', 98 | 'contributed/dmlab30/explore_object_locations_large', 99 | 'contributed/dmlab30/explore_obstructed_goals_small', 100 | 'contributed/dmlab30/explore_obstructed_goals_large', 101 | 'contributed/dmlab30/explore_goal_locations_small', 102 | 'contributed/dmlab30/explore_goal_locations_large', 103 | 'contributed/dmlab30/explore_object_rewards_few', 104 | 'contributed/dmlab30/explore_object_rewards_many' 105 | ], 106 | 'dmlab_30_test': 107 | ['contributed/dmlab30/rooms_collect_good_objects_test', 108 | 'contributed/dmlab30/rooms_exploit_deferred_effects_test', 109 | 'contributed/dmlab30/rooms_select_nonmatching_object', 110 | 'contributed/dmlab30/rooms_watermaze', 111 | 'contributed/dmlab30/rooms_keys_doors_puzzle', 112 | 'contributed/dmlab30/language_select_described_object', 113 | 'contributed/dmlab30/language_select_located_object', 114 | 'contributed/dmlab30/language_execute_random_task', 115 | 'contributed/dmlab30/language_answer_quantitative_question', 116 | 'contributed/dmlab30/lasertag_one_opponent_small', 117 | 'contributed/dmlab30/lasertag_three_opponents_small', 118 | 'contributed/dmlab30/lasertag_one_opponent_large', 119 | 'contributed/dmlab30/lasertag_three_opponents_large', 120 | 'contributed/dmlab30/natlab_fixed_large_map', 121 | 'contributed/dmlab30/natlab_varying_map_regrowth', 122 | 'contributed/dmlab30/natlab_varying_map_randomized', 123 | 'contributed/dmlab30/skymaze_irreversible_path_hard', 124 | 'contributed/dmlab30/skymaze_irreversible_path_varied', 125 | 'contributed/dmlab30/psychlab_arbitrary_visuomotor_mapping', 126 | 'contributed/dmlab30/psychlab_continuous_recognition', 127 | 'contributed/dmlab30/psychlab_sequential_comparison', 128 | 'contributed/dmlab30/psychlab_visual_search', 129 | 'contributed/dmlab30/explore_object_locations_small', 130 | 'contributed/dmlab30/explore_object_locations_large', 131 | 'contributed/dmlab30/explore_obstructed_goals_small', 132 | 'contributed/dmlab30/explore_obstructed_goals_large', 133 | 'contributed/dmlab30/explore_goal_locations_small', 134 | 'contributed/dmlab30/explore_goal_locations_large', 135 | 'contributed/dmlab30/explore_object_rewards_few', 136 | 'contributed/dmlab30/explore_object_rewards_many' 137 | ] 138 | } 139 | 140 | IMPALA_ACTION_SET = ( 141 | (0, 0, 0, 1, 0, 0, 0), # Forward 142 | (0, 0, 0, -1, 0, 0, 0), # Backward 143 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left 144 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right 145 | (-20, 0, 0, 0, 0, 0, 0), # Look Left 146 | (20, 0, 0, 0, 0, 0, 0), # Look Right 
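    # Each action is a 7-tuple in DMLab's raw action space, ordered here as
    # (look left/right, look down/up, strafe left/right, move back/forward, fire,
    # plus two trailing entries, presumably jump and crouch, which stay 0 in all
    # of these action sets).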
147 | (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward 148 | (20, 0, 0, 1, 0, 0, 0), # Look Right + Forward 149 | (0, 0, 0, 0, 1, 0, 0), # Fire. 150 | ) 151 | 152 | 153 | EXTENDED_ACTION_SET = ( 154 | (0, 0, 0, 1, 0, 0, 0), # Forward 155 | (0, 0, 0, -1, 0, 0, 0), # Backward 156 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left 157 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right 158 | (-10, 0, 0, 0, 0, 0, 0), # Small Look Left 159 | (10, 0, 0, 0, 0, 0, 0), # Small Look Right 160 | (-60, 0, 0, 0, 0, 0, 0), # Large Look Left 161 | (60, 0, 0, 0, 0, 0, 0), # Large Look Right 162 | (0, 10, 0, 0, 0, 0, 0), # Look Down 163 | (0, -10, 0, 0, 0, 0, 0), # Look Up 164 | (-10, 0, 0, 1, 0, 0, 0), # Forward + Small Look Left 165 | (10, 0, 0, 1, 0, 0, 0), # Forward + Small Look Right 166 | (-60, 0, 0, 1, 0, 0, 0), # Forward + Large Look Left 167 | (60, 0, 0, 1, 0, 0, 0), # Forward + Large Look Right 168 | (0, 0, 0, 0, 1, 0, 0), # Fire. 169 | ) 170 | 171 | EXTENDED_ACTION_SET_LARGE = ( 172 | (0, 0, 0, 1, 0, 0, 0), # Forward 173 | (0, 0, 0, -1, 0, 0, 0), # Backward 174 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left 175 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right 176 | (-10, 0, 0, 0, 0, 0, 0), # Small Look Left 177 | (10, 0, 0, 0, 0, 0, 0), # Small Look Right 178 | (-60, 0, 0, 0, 0, 0, 0), # Large Look Left 179 | (60, 0, 0, 0, 0, 0, 0), # Large Look Right 180 | (0, 10, 0, 0, 0, 0, 0), # Look Down 181 | (0, -10, 0, 0, 0, 0, 0), # Look Up 182 | (0, 60, 0, 0, 0, 0, 0), # Large Look Down 183 | (0, -60, 0, 0, 0, 0, 0), # Large Look Up 184 | (-10, 0, 0, 1, 0, 0, 0), # Forward + Small Look Left 185 | (10, 0, 0, 1, 0, 0, 0), # Forward + Small Look Right 186 | (-60, 0, 0, 1, 0, 0, 0), # Forward + Large Look Left 187 | (60, 0, 0, 1, 0, 0, 0), # Forward + Large Look Right 188 | (0, 0, 0, 0, 1, 0, 0), # Fire. 189 | ) 190 | -------------------------------------------------------------------------------- /brain_agent/core/agents/dmlab_multitask_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from brain_agent.core.agents.agent_abc import ActorCriticBase 4 | from brain_agent.envs.dmlab.dmlab_model import DmlabEncoder 5 | from brain_agent.core.models.transformer import MemTransformerLM 6 | from brain_agent.core.models.rnn import LSTM 7 | from brain_agent.core.algos.popart import update_parameters, update_mu_sigma 8 | from brain_agent.core.models.model_utils import normalize_obs_return 9 | from brain_agent.core.models.action_distributions import sample_actions_log_probs 10 | from brain_agent.utils.utils import AttrDict 11 | 12 | from brain_agent.core.algos.aux_future_predict import FuturePredict 13 | 14 | 15 | class DMLabMultiTaskAgent(ActorCriticBase): 16 | def __init__(self, cfg, action_space, obs_space, num_levels, need_half=False): 17 | super().__init__(action_space, cfg) 18 | 19 | self.encoder = DmlabEncoder(cfg, obs_space) 20 | 21 | core_input_size = self.encoder.get_encoder_out_size() 22 | if cfg.model.extended_input: 23 | core_input_size += action_space.n + 1 24 | 25 | if cfg.model.core.core_type == "trxl": 26 | self.core = MemTransformerLM(cfg, n_layer=cfg.model.core.n_layer, n_head=cfg.model.core.n_heads, 27 | d_head=cfg.model.core.d_head, d_model=core_input_size, 28 | d_inner=cfg.model.core.d_inner, 29 | mem_len=cfg.model.core.mem_len, pre_lnorm=True) 30 | 31 | elif cfg.model.core.core_type == "rnn": 32 | self.core = LSTM(cfg, core_input_size) 33 | else: 34 | raise Exception('Error: Not support given model core_type') 35 | 36 | core_out_size = 
self.core.get_core_out_size() 37 | 38 | if self.cfg.model.use_popart: 39 | self.register_buffer('mu', torch.zeros(num_levels, requires_grad=False)) 40 | self.register_buffer('nu', torch.ones(num_levels, requires_grad=False)) 41 | self.critic_linear = nn.Linear(core_out_size, num_levels) 42 | self.beta = self.cfg.model.popart_beta 43 | else: 44 | self.critic_linear = nn.Linear(core_out_size, 1) 45 | 46 | if cfg.learner.use_aux_future_pred_loss and cfg.model.core.core_type == "trxl": 47 | self.future_pred_module = FuturePredict(self.cfg, self.encoder.basic_encoder.input_ch, 48 | self.encoder.basic_encoder.conv_head_out_size, 49 | core_out_size, 50 | action_space) 51 | 52 | self.action_parameterization = self.get_action_parameterization(core_out_size) 53 | 54 | self.action_parameterization.apply(self.initialize) 55 | self.critic_linear.apply(self.initialize) 56 | self.train() 57 | 58 | self.need_half = need_half 59 | if self.need_half: 60 | self.half() 61 | 62 | def initialize(self, layer): 63 | def init_weight(weight): 64 | nn.init.normal_(weight, 0.0, 0.02) 65 | 66 | def init_bias(bias): 67 | nn.init.constant_(bias, 0.0) 68 | 69 | classname = layer.__class__.__name__ 70 | if classname.find('Linear') != -1: 71 | if hasattr(layer, 'weight') and layer.weight is not None: 72 | init_weight(layer.weight) 73 | if hasattr(layer, 'bias') and layer.bias is not None: 74 | init_bias(layer.bias) 75 | 76 | # TODO: functionalize popart 77 | # def update_parameters(self, mu, sigma, oldmu, oldsigma): 78 | # self.critic_linear.weight.data, self.critic_linear.bias.data = \ 79 | # update_parameters(self.critic_linear.weight, self.critic_linear.bias, mu, sigma, oldmu, oldsigma) 80 | 81 | def update_parameters(self, mu, sigma, oldmu, oldsigma): 82 | self.critic_linear.weight.data = (self.critic_linear.weight.t() * oldsigma / sigma).t() 83 | self.critic_linear.bias.data = (oldsigma * self.critic_linear.bias + oldmu - mu) / sigma 84 | 85 | def update_mu_sigma(self, vs, task_ids, cfg=None): 86 | oldnu = self.nu.clone() 87 | oldsigma = torch.sqrt(oldnu - self.mu ** 2) 88 | oldsigma[torch.isnan(oldsigma)] = self.cfg.model.popart_clip_min 89 | clamp_max = 1e4 if hasattr(self, 'need_half') and self.need_half else 1e6 90 | oldsigma = torch.clamp(oldsigma, min=cfg.model.popart_clip_min, max=clamp_max) 91 | oldmu = self.mu.clone() 92 | 93 | vs = vs.reshape(-1, self.cfg.optim.rollout) 94 | # same task ids over all time steps within a single episode 95 | task_ids_per_epi = task_ids.reshape(-1, self.cfg.optim.rollout)[:, 0] 96 | for i in range(len(task_ids_per_epi)): 97 | task_id = task_ids_per_epi[i] 98 | v = torch.mean(vs[i]) 99 | self.mu[task_id] = (1 - self.beta) * self.mu[task_id] + self.beta * v 100 | self.nu[task_id] = (1 - self.beta) * self.nu[task_id] + self.beta * (v ** 2) 101 | 102 | sigma = torch.sqrt(self.nu - self.mu ** 2) 103 | sigma[torch.isnan(sigma)] = self.cfg.model.popart_clip_min 104 | sigma = torch.clamp(sigma, min=cfg.model.popart_clip_min, max=clamp_max) 105 | 106 | return self.mu, sigma, oldmu, oldsigma 107 | 108 | def forward_head(self, obs_dict, actions, rewards, decode=False): 109 | obs_dict = normalize_obs_return(obs_dict, self.cfg) # before normalize, [0,255] 110 | if self.need_half: 111 | obs_dict['obs'] = obs_dict['obs'].half() 112 | 113 | x = self.encoder(obs_dict, decode=decode) 114 | if decode: 115 | self.reconstruction = self.encoder.basic_encoder.reconstruction 116 | else: 117 | self.reconstruction = None 118 | 119 | x_extended = [] 120 | if self.cfg.model.extended_input: 121 | assert 
torch.min(actions) >= -1 and torch.max(actions) < self.action_space.n 122 | done_ids = actions.eq(-1).nonzero(as_tuple=False) 123 | actions[done_ids] = 0 124 | prev_actions = nn.functional.one_hot(actions, self.action_space.n).float() 125 | prev_actions[done_ids] = 0. 126 | x_extended.append(prev_actions) 127 | x_extended.append(rewards.clamp(-1, 1).unsqueeze(1)) 128 | 129 | x = torch.cat([x] + x_extended, dim=-1) 130 | 131 | if self.need_half: 132 | x = x.half() 133 | 134 | return x 135 | 136 | def forward_core_transformer(self, head_output, mems=None, mem_begin_index=None, dones=None, from_learner=False): 137 | x, new_mems = self.core(head_output, mems, mem_begin_index, dones=dones, from_learner=from_learner) 138 | return x, new_mems 139 | 140 | def forward_core_rnn(self, head_output, rnn_states, dones, is_seq=None): 141 | x, new_rnn_states = self.core(head_output, rnn_states, dones, is_seq) 142 | return x, new_rnn_states 143 | 144 | def forward_tail(self, core_output, task_ids, with_action_distribution=False): 145 | values = self.critic_linear(core_output) 146 | normalized_values = values.clone() 147 | sigmas = torch.ones((values.size(0), 1), requires_grad=False) 148 | mus = torch.zeros((values.size(0), 1), requires_grad=False) 149 | if self.cfg.model.use_popart: 150 | normalized_values = normalized_values.gather(dim=1, index=task_ids) 151 | with torch.no_grad(): 152 | nus = self.nu.index_select(dim=0, index=task_ids.squeeze(1)).unsqueeze(1) 153 | mus = self.mu.index_select(dim=0, index=task_ids.squeeze(1)).unsqueeze(1) 154 | sigmas = torch.sqrt(nus - mus ** 2) 155 | sigmas[torch.isnan(sigmas)] = self.cfg.model.popart_clip_min 156 | clamp_max = 1e4 if self.need_half else 1e6 157 | sigmas = torch.clamp(sigmas, min=self.cfg.model.popart_clip_min, max=clamp_max) 158 | values = normalized_values * sigmas + mus 159 | 160 | action_distribution_params, action_distribution = self.action_parameterization(core_output) 161 | 162 | actions, log_prob_actions = sample_actions_log_probs(action_distribution) 163 | 164 | result = AttrDict(dict( 165 | actions=actions, 166 | action_logits=action_distribution_params, 167 | log_prob_actions=log_prob_actions, 168 | values=values, 169 | normalized_values=normalized_values, 170 | sigmas=sigmas, 171 | mus=mus 172 | )) 173 | 174 | if with_action_distribution: 175 | result.action_distribution = action_distribution 176 | 177 | return result 178 | 179 | def forward(self, obs_dict, actions, rewards, mems=None, mem_begin_index=None, rnn_states=None, dones=None, is_seq=None, 180 | task_ids=None, with_action_distribution=False, from_learner=False): 181 | x = self.forward_head(obs_dict, actions, rewards) 182 | 183 | if self.cfg.model.core.core_type == 'trxl': 184 | x, new_mems = self.forward_core_transformer(x, mems, mem_begin_index, from_learner=from_learner) 185 | elif self.cfg.model.core.core_type == 'rnn': 186 | x, new_rnn_states = self.forward_core_rnn(x, rnn_states, dones, is_seq) 187 | 188 | assert not x.isnan().any() 189 | 190 | result = self.forward_tail(x, task_ids, with_action_distribution=with_action_distribution) 191 | 192 | if self.cfg.model.core.core_type == "trxl": 193 | result.mems = new_mems 194 | elif self.cfg.model.core.core_type == 'rnn': 195 | result.rnn_states = new_rnn_states 196 | 197 | return result 198 | -------------------------------------------------------------------------------- /brain_agent/core/algos/aux_future_predict.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 
| import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from brain_agent.core.models.model_utils import normalize_obs_return
7 | from brain_agent.core.models.resnet import ResBlock
8 | from brain_agent.core.models.causal_transformer import CausalTransformer
9 |
10 | '''
11 | To learn useful representations, this module minimizes the difference between the predicted future
12 | observations and the real observations over k (2~10) steps. It uses a causal transformer in an autoregressive
13 | manner to predict state transitions, given the actual action sequence and the state embedding of the current step.
14 |
15 | For the current state embedding, we use the outputs of the TrXL core. These state embeddings are concatenated
16 | with the actions actually taken and then fed into the causal transformer. The causal transformer iteratively
17 | produces the next state embeddings in an autoregressive manner.
18 |
19 | After producing all the future state embeddings for K (2~10) steps, we apply a transposed convolutional
20 | decoder to map each state embedding back into the original image observation space. Finally, we compute
21 | the L2 distance between the decoded images and the real images.
22 |
23 | With this auxiliary loss added, the mean HNS (human normalized score) over the 30 tasks increased from 123.6 to 128.0,
24 | although the capped mean HNS decreased from 91.25 to 90.53.
25 | You can enable this module by setting the argument learner.use_aux_future_pred_loss to True.
26 |
27 | '''
28 |
29 | class FuturePredict(nn.Module):
30 | def __init__(self, cfg, encoder_input_ch, conv_head_out_size, core_out_size, action_space):
31 | super(FuturePredict, self).__init__()
32 | self.cfg = cfg
33 | self.action_space = action_space
34 | self.n_action = action_space.n
35 | self.horizon_k: int = 10
36 | self.time_subsample: int = 6
37 | self.forward_subsample: int = 2
38 | self.core_out_size = core_out_size
39 | self.conv_head_out_size = conv_head_out_size
40 |
41 | if isinstance(action_space, gym.spaces.Discrete):
42 | self.action_sizes = [action_space.n]
43 | else:
44 | self.action_sizes = [space.n for space in action_space.spaces]
45 |
46 | self.g = nn.Sequential(
47 | nn.Linear(core_out_size, 256),
48 | nn.ReLU(),
49 | nn.Linear(256, conv_head_out_size)
50 | )
51 |
52 | self.mse_loss = nn.MSELoss(reduction='none')
53 | self.causal_transformer = CausalTransformer(core_out_size, action_space.n, pre_lnorm=True)
54 | self.causal_transformer_window = self.causal_transformer.mem_len
55 |
56 | if cfg.model.encoder.encoder_subtype == 'resnet_impala':
57 | resnet_conf = [[16, 2], [32, 2], [32, 2]]
58 | elif cfg.model.encoder.encoder_subtype == 'resnet_impala_large':
59 | resnet_conf = [[32, 2], [64, 2], [64, 2]]
60 | self.conv_out_ch = resnet_conf[-1][0]
61 | layers_decoder = list()
62 | curr_input_channels = encoder_input_ch
63 | for i, (out_channels, res_blocks) in enumerate(resnet_conf):
64 |
65 | for j in range(res_blocks):
66 | layers_decoder.append(ResBlock(cfg, curr_input_channels, curr_input_channels))
67 | layers_decoder.append(
68 | nn.ConvTranspose2d(out_channels, curr_input_channels, kernel_size=3, stride=2,
69 | padding=1, output_padding=1)
70 | )
71 | curr_input_channels = out_channels
72 |
73 | layers_decoder.reverse()
74 | self.layers_decoder = nn.Sequential(*layers_decoder)
75 |
76 | def calc_loss(self, mb, b_t, mems, mems_actions, mems_dones, mem_begin_index, num_traj, recurrence):
77 | not_dones = (1.0 - mb.dones).view(num_traj, recurrence, 1)
78 | mask_res =
self._build_mask_and_subsample(not_dones) 79 | (forward_mask, unroll_subsample, time_subsample, max_k) = mask_res 80 | 81 | actions_raw = mb.actions.view(num_traj, recurrence, -1).long() 82 | mems = torch.split(mems, self.core_out_size, dim=-1)[-1].transpose(0,1) 83 | mem_len = mems.size(1) 84 | mems_actions = mems_actions.transpose(0,1) 85 | b_t = b_t.reshape(num_traj, recurrence, -1) 86 | obs = normalize_obs_return(mb.obs, self.cfg) 87 | obs = obs['obs'].reshape(num_traj, recurrence, 3, 72, 96) 88 | 89 | cat_out = torch.cat([mems, b_t], dim=1) 90 | cat_action = torch.cat([mems_actions, actions_raw], dim=1) 91 | mems_not_dones = 1.0 - mems_dones.float() 92 | cat_not_dones = torch.cat([mems_not_dones, not_dones], dim=1) 93 | mem_begin_index = torch.tensor(mem_begin_index, device=cat_out.device) 94 | 95 | x, y, z, w, forward_targets = self._get_transfomrer_input(obs, cat_out, cat_action, cat_not_dones, time_subsample, max_k, mem_len, mem_begin_index) 96 | h_pred_stack, done_mask = self._make_transformer_pred(x,y,z,w, num_traj, max_k) 97 | h_pred_stack = h_pred_stack.index_select(0, unroll_subsample) 98 | final_pred = self.g(h_pred_stack) 99 | 100 | x_dec = final_pred.view(-1, self.conv_out_ch, 9, 12) 101 | for i in range(len(self.layers_decoder)): 102 | layer_decoder = self.layers_decoder[i] 103 | x_dec = layer_decoder(x_dec) 104 | x_dec = torch.tanh(x_dec) 105 | 106 | with torch.no_grad(): 107 | forward_targets = forward_targets.flatten(0, 1) 108 | forward_targets = forward_targets.transpose(0, 1) 109 | forward_targets = forward_targets.index_select(0, unroll_subsample) 110 | forward_targets = forward_targets.flatten(0, 1) 111 | 112 | loss = self.mse_loss(x_dec, forward_targets) 113 | loss = loss.view(loss.size()[0], -1).mean(-1, keepdim=True) 114 | 115 | final_mask = torch.logical_and(forward_mask, done_mask) 116 | loss = torch.masked_select(loss, final_mask.flatten(0,1)).mean() 117 | loss = loss * self.cfg.learner.aux_future_pred_loss_coeff 118 | return loss 119 | 120 | def _make_transformer_pred(self, states, actions, not_dones, mem_begin_index, num_traj, max_k): 121 | actions = nn.functional.one_hot(actions.long(), num_classes=self.n_action).squeeze(3).float() 122 | actions = actions.view(self.time_subsample*num_traj, -1, self.n_action).transpose(0,1) 123 | states = states.view(self.time_subsample*num_traj, -1, self.core_out_size).transpose(0,1) 124 | not_dones = not_dones.view(self.time_subsample*num_traj, -1, 1).transpose(0,1) 125 | dones = 1.0 - not_dones 126 | tokens = torch.cat([states, actions], dim=2) 127 | input_tokens_past = tokens[:self.causal_transformer_window-1,:,:] 128 | input_token_current = tokens[self.causal_transformer_window-1,:,:].unsqueeze(0) 129 | input_token = torch.cat([input_tokens_past, input_token_current], dim=0) 130 | 131 | mem_begin_index = mem_begin_index.view(-1) 132 | 133 | y, mems, mem_begin_index = self.causal_transformer(input_token, mem_begin_index, num_traj, mems=None) 134 | 135 | lst = [] 136 | for mem in mems: 137 | lst.append(mem) 138 | mems = lst 139 | 140 | y = y[-1].unsqueeze(0) 141 | out = [y] 142 | for i in range(max_k-1): 143 | new_input = torch.cat([y, actions[self.causal_transformer_window+i].unsqueeze(0)], dim=-1) 144 | y, mems, mem_begin_index = self.causal_transformer(new_input, mem_begin_index, num_traj, mems=mems) 145 | out.append(y) 146 | 147 | done_mask = torch.ge(dones.sum(dim=0), 1.0) 148 | done_mask = torch.logical_not(done_mask) # False means masking. 
149 | 150 | return torch.stack(out).squeeze(1), done_mask 151 | 152 | def _get_transfomrer_input(self, obs, cat_out, cat_action, cat_not_dones, time_subsample, max_k, mem_len, mem_begin_index): 153 | out_lst = [] 154 | actions_lst = [] 155 | not_dones_lst = [] 156 | mem_begin_index_lst = [] 157 | target_lst = [] 158 | 159 | for i in range(self.time_subsample): 160 | first_idx = mem_len + time_subsample[i] 161 | max_idx = first_idx + max_k 162 | min_idx = first_idx - self.causal_transformer_window + 1 163 | x = cat_out[:, min_idx:max_idx, :] 164 | y = cat_action[:, min_idx:max_idx, :] 165 | z = cat_not_dones[:, min_idx:max_idx, :] 166 | w = mem_begin_index + time_subsample[i] + 1 167 | k = obs[:, time_subsample[i]+1:time_subsample[i]+1+max_k] 168 | out_lst.append(x) 169 | actions_lst.append(y) 170 | not_dones_lst.append(z) 171 | mem_begin_index_lst.append(w) 172 | target_lst.append(k) 173 | 174 | 175 | return torch.stack(out_lst), torch.stack(actions_lst), torch.stack(not_dones_lst), \ 176 | torch.stack(mem_begin_index_lst), torch.stack(target_lst) 177 | 178 | def _build_mask_and_subsample(self, not_dones): 179 | t = not_dones.size(1) - self.horizon_k 180 | 181 | not_dones_unfolded = self._build_unfolded(not_dones[:, :-1].to(dtype=torch.bool), self.horizon_k) # 10, 32, 24, 1 182 | time_subsample = torch.randperm(t - 2, device=not_dones.device, dtype=torch.long)[0:self.time_subsample] 183 | 184 | forward_mask = torch.cumprod(not_dones_unfolded.index_select(2, time_subsample), dim=0).to(dtype=torch.bool) # 10, 32, 6, 1 185 | forward_mask = forward_mask.flatten(1, 2) # 10, 192, 1 186 | 187 | max_k = forward_mask.flatten(1).any(-1).nonzero().max().item() + 1 188 | 189 | unroll_subsample = torch.randperm(max_k, dtype=torch.long)[0:self.forward_subsample] 190 | 191 | max_k = unroll_subsample.max().item() + 1 192 | 193 | unroll_subsample = unroll_subsample.to(device=not_dones.device) 194 | forward_mask = forward_mask.index_select(0, unroll_subsample) 195 | 196 | return forward_mask, unroll_subsample, time_subsample, max_k 197 | 198 | def _build_unfolded(self, x, k: int): 199 | tobe_cat = x.new_zeros(x.size(0), k, x.size(2)) 200 | cat = torch.cat((x, tobe_cat), 1) 201 | cat = cat.unfold(1, size=k, step=1) 202 | cat = cat.permute(3, 0, 1, 2) 203 | return cat 204 | 205 | -------------------------------------------------------------------------------- /brain_agent/core/shared_buffer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import gym 5 | 6 | from brain_agent.utils.logger import log 7 | from brain_agent.core.core_utils import iter_dicts_recursively, copy_dict_structure, iterate_recursively 8 | from brain_agent.core.models.model_utils import get_hidden_size 9 | 10 | 11 | def ensure_memory_shared(*tensors): 12 | for tensor_dict in tensors: 13 | for _, _, t in iterate_recursively(tensor_dict): 14 | assert t.is_shared() 15 | 16 | def to_torch_dtype(numpy_dtype): 17 | """from_numpy automatically infers type, so we leverage that.""" 18 | x = np.zeros([1], dtype=numpy_dtype) 19 | t = torch.from_numpy(x) 20 | return t.dtype 21 | 22 | def to_numpy(t, num_dimensions): 23 | arr_shape = t.shape[:num_dimensions] 24 | arr = np.ndarray(arr_shape, dtype=object) 25 | to_numpy_func(t, arr) 26 | return arr 27 | 28 | 29 | def to_numpy_func(t, arr): 30 | if len(arr.shape) == 1: 31 | for i in range(t.shape[0]): 32 | arr[i] = t[i] 33 | else: 34 | for i in range(t.shape[0]): 35 | to_numpy_func(t[i], arr[i]) 36 | 37 | 38 | 
class SharedBuffer: 39 | """ 40 | Shared buffer stores data from different processes in a single place for efficient sharing of information. 41 | The tensors are stored according to the index scheme of 42 | [ 43 | actor worker index, 44 | split index, 45 | env index, 46 | trajectory buffer index, 47 | time step 48 | ] 49 | actor worker index: 50 | Index of the actor worker. There can be multiple actor workers. 51 | split index: 52 | Index of the split. Each worker has many env objects which are grouped into splits. 53 | Each split is inferenced iteratively, for faster performance. (See Sample-factory for detail) 54 | env index: 55 | Inside a split, there are multiple env objects. 56 | trajectory buffer index: 57 | In case learner is significantly slow, data may pile and extra buffers takes care of them in a circular queue manner. 58 | time step: 59 | Rollout data are stored as a trajectory. A tensor resides in time step t within the trajectory. 60 | """ 61 | def __init__(self, cfg, obs_space, action_space): 62 | 63 | self.cfg = cfg 64 | assert not cfg.actor.num_envs_per_worker % cfg.actor.num_splits, \ 65 | f'actor.num_envs_per_worker ({cfg.actor.num_envs_per_worker}) ' \ 66 | f'is not divided by actor.num_splits ({cfg.actor.num_splits})' 67 | 68 | self.obs_space = obs_space 69 | self.action_space = action_space 70 | 71 | self.envs_per_split = cfg.actor.num_envs_per_worker // cfg.actor.num_splits 72 | self.num_traj_buffers = self.calc_num_trajectory_buffers() 73 | 74 | core_hidden_size = get_hidden_size(self.cfg, action_space) 75 | 76 | log.debug('Allocating shared memory for trajectories') 77 | self.tensors = TensorDict() 78 | 79 | obs_dict = TensorDict() 80 | self.tensors['obs'] = obs_dict 81 | if isinstance(obs_space, gym.spaces.Dict): 82 | for name, space in obs_space.spaces.items(): 83 | obs_dict[name] = self.init_tensor(space.dtype, space.shape) 84 | else: 85 | raise Exception('Only Dict observations spaces are supported') 86 | self.tensors['prev_actions'] = self.init_tensor(torch.int64, [1]) 87 | self.tensors['prev_rewards'] = self.init_tensor(torch.float32, [1]) 88 | 89 | self.tensors['rewards'] = self.init_tensor(torch.float32, [1]) 90 | self.tensors['dones'] = self.init_tensor(torch.bool, [1]) 91 | self.tensors['raw_rewards'] = self.init_tensor(torch.float32, [1]) 92 | 93 | # policy outputs 94 | if self.cfg.model.core.core_type == 'trxl': 95 | policy_outputs = [ 96 | ('actions', 1), 97 | ('action_logits', action_space.n), 98 | ('log_prob_actions', 1), 99 | ('values', 1), 100 | ('normalized_values', 1), 101 | ('policy_version', 1), 102 | ] 103 | elif self.cfg.model.core.core_type == 'rnn': 104 | policy_outputs = [ 105 | ('actions', 1), 106 | ('action_logits', action_space.n), 107 | ('log_prob_actions', 1), 108 | ('values', 1), 109 | ('normalized_values', 1), 110 | ('policy_version', 1), 111 | ('rnn_states', core_hidden_size) 112 | ] 113 | 114 | policy_outputs = [PolicyOutput(*po) for po in policy_outputs] 115 | policy_outputs = sorted(policy_outputs, key=lambda policy_output: policy_output.name) 116 | 117 | for po in policy_outputs: 118 | self.tensors[po.name] = self.init_tensor(torch.float32, [po.size]) 119 | 120 | ensure_memory_shared(self.tensors) 121 | 122 | self.tensors_individual_transitions = self.tensor_dict_to_numpy(len(self.tensor_dimensions())) 123 | 124 | self.tensor_trajectories = self.tensor_dict_to_numpy(len(self.tensor_dimensions()) - 1) 125 | 126 | traj_buffer_available_shape = [ 127 | self.cfg.actor.num_workers, 128 | self.cfg.actor.num_splits, 129 | 
self.envs_per_split, 130 | self.num_traj_buffers, 131 | ] 132 | self.is_traj_tensor_available = torch.ones(traj_buffer_available_shape, dtype=torch.uint8) 133 | self.is_traj_tensor_available.share_memory_() 134 | self.is_traj_tensor_available = to_numpy(self.is_traj_tensor_available, 2) 135 | 136 | policy_outputs_combined_size = sum(po.size for po in policy_outputs) 137 | policy_outputs_shape = [ 138 | self.cfg.actor.num_workers, 139 | self.cfg.actor.num_splits, 140 | self.envs_per_split, 141 | policy_outputs_combined_size, 142 | ] 143 | 144 | self.policy_outputs = policy_outputs 145 | self.policy_output_tensors = torch.zeros(policy_outputs_shape, dtype=torch.float32) 146 | self.policy_output_tensors.share_memory_() 147 | self.policy_output_tensors = to_numpy(self.policy_output_tensors, 3) 148 | 149 | self.policy_versions = torch.zeros([1], dtype=torch.int32) 150 | self.policy_versions.share_memory_() 151 | 152 | self.stop_experience_collection = torch.ones([1], dtype=torch.bool) 153 | self.stop_experience_collection.share_memory_() 154 | 155 | self.task_ids = torch.zeros([self.cfg.actor.num_workers, self.cfg.actor.num_splits, self.envs_per_split], 156 | dtype=torch.uint8) 157 | self.task_ids.share_memory_() 158 | 159 | self.max_mems_buffer_len = self.cfg.model.core.mem_len + self.cfg.optim.rollout * (self.num_traj_buffers + 1) 160 | self.mems_dimensions = [self.cfg.actor.num_workers, self.cfg.actor.num_splits, self.envs_per_split, self.max_mems_buffer_len] 161 | self.mems_dimensions.append(core_hidden_size) 162 | self.mems_dones_dimensions = [self.cfg.actor.num_workers, self.cfg.actor.num_splits, 163 | self.envs_per_split, self.max_mems_buffer_len] 164 | self.mems_dones_dimensions.append(1) 165 | self.mems_actions_dimensions = [self.cfg.actor.num_workers, self.cfg.actor.num_splits, 166 | self.envs_per_split, self.max_mems_buffer_len] 167 | self.mems_actions_dimensions.append(1) 168 | 169 | def calc_num_trajectory_buffers(self): 170 | num_traj_buffers = self.cfg.optim.batch_size / ( 171 | self.cfg.actor.num_workers * self.cfg.actor.num_envs_per_worker * self.cfg.optim.rollout) 172 | 173 | num_traj_buffers *= 3 174 | 175 | num_traj_buffers = math.ceil(max(num_traj_buffers, self.cfg.shared_buffer.min_traj_buffers_per_worker)) 176 | log.info('Using %d sets of trajectory buffers', num_traj_buffers) 177 | return num_traj_buffers 178 | 179 | def tensor_dimensions(self): 180 | dimensions = [ 181 | self.cfg.actor.num_workers, 182 | self.cfg.actor.num_splits, 183 | self.envs_per_split, 184 | self.num_traj_buffers, 185 | self.cfg.optim.rollout, 186 | ] 187 | return dimensions 188 | 189 | def init_tensor(self, tensor_type, tensor_shape): 190 | if not isinstance(tensor_type, torch.dtype): 191 | tensor_type = to_torch_dtype(tensor_type) 192 | 193 | dimensions = self.tensor_dimensions() 194 | final_shape = dimensions + list(tensor_shape) 195 | t = torch.zeros(final_shape, dtype=tensor_type) 196 | t.share_memory_() 197 | return t 198 | 199 | def tensor_dict_to_numpy(self, num_dimensions): 200 | numpy_dict = copy_dict_structure(self.tensors) 201 | for d1, d2, key, curr_t, value2 in iter_dicts_recursively(self.tensors, numpy_dict): 202 | assert isinstance(curr_t, torch.Tensor) 203 | assert value2 is None 204 | d2[key] = to_numpy(curr_t, num_dimensions) 205 | assert isinstance(d2[key], np.ndarray) 206 | return numpy_dict 207 | 208 | 209 | class TensorDict(dict): 210 | def index(self, indices): 211 | return self.index_func(self, indices) 212 | 213 | def index_func(self, x, indices): 214 | if isinstance(x, 
(dict, TensorDict)): 215 | res = TensorDict()
216 | for key, value in x.items():
217 | res[key] = self.index_func(value, indices)
218 | return res
219 | else:
220 | t = x[indices]
221 | return t
222 |
223 | def set_data(self, index, new_data):
224 | self.set_data_func(self, index, new_data)
225 |
226 | def set_data_func(self, x, index, new_data):
227 | if isinstance(new_data, (dict, TensorDict)):
228 | for new_data_key, new_data_value in new_data.items():
229 | self.set_data_func(x[new_data_key], index, new_data_value)
230 | elif isinstance(new_data, torch.Tensor):
231 | x[index].copy_(new_data)
232 | elif isinstance(new_data, np.ndarray):
233 | t = torch.from_numpy(new_data)
234 | x[index].copy_(t)
235 | else:
236 | raise Exception(f'Type {type(new_data)} not supported in set_data_func')
237 |
238 |
239 | class PolicyOutput:
240 | def __init__(self, name, size):
241 | self.name = name
242 | self.size = size
243 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Brain Agent
2 | ***Brain Agent*** is a distributed agent learning system for large-scale and multi-task reinforcement learning, developed by [Kakao Brain](https://www.kakaobrain.com/).
3 | Brain Agent is based on the V-trace actor-critic framework [IMPALA](https://arxiv.org/abs/1802.01561) and modifies [Sample Factory](https://github.com/alex-petrenko/sample-factory) to utilize multiple GPUs and CPUs for a much higher throughput rate during training.
4 | ## Features
5 |
6 | 1. The first publicly available implementation reproducing SOTA results on [DMLAB30](https://github.com/deepmind/lab).
7 | 2. Scalable & massive throughput.
8 | Brain Agent can produce and train on 20B frames per week (about 34K fps) with 16 V100 GPUs by scaling up the high-throughput single-node system [Sample Factory](https://github.com/alex-petrenko/sample-factory).
9 | 3. Based on the following algorithms and architectures.
10 | * [TransformerXL-I](https://proceedings.mlr.press/v119/parisotto20a.html) core and [ResNet](https://arxiv.org/abs/1512.03385?context=cs) encoder
11 | * [V-trace](https://arxiv.org/abs/1802.01561) for the update algorithm (sketched below)
12 | * [IMPALA](https://arxiv.org/abs/1802.01561) for the system framework
13 | * [PopArt](https://arxiv.org/abs/1809.04474) for multitask handling
14 | 4. For self-supervised representation learning, we include two additional features.
15 | * ResNet-based decoder to reconstruct the original input image ([trxl_recon](https://github.com/kakaobrain/brain_agent/blob/main/configs/trxl_recon_train.yaml))
16 | * Additional autoregressive transformer to predict the images of future steps from the current state embedding and future action sequence ([trxl_future_pred](https://github.com/kakaobrain/brain_agent/blob/main/configs/trxl_future_pred_train.yaml))
17 | 5. Provides code for both training and evaluation, along with a SOTA model checkpoint with 28M parameters.
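The V-trace correction mentioned above can be summarized in a few lines. The repository's own implementation lives in `brain_agent/core/algos/vtrace.py`; the snippet below is only a minimal PyTorch sketch of the clipped importance-weighted targets from the IMPALA paper, and its tensor shapes and argument names are illustrative assumptions, not the project's API.

```python
import torch

def vtrace_targets(behaviour_log_probs, target_log_probs, rewards, values,
                   bootstrap_value, discounts, clip_rho=1.0, clip_c=1.0):
    """Minimal V-trace sketch. Inputs are [T, B] tensors except bootstrap_value,
    which is [B]; `discounts` is gamma * (1 - done) per step."""
    rhos = torch.exp(target_log_probs - behaviour_log_probs)   # importance ratios pi/mu
    clipped_rhos = rhos.clamp(max=clip_rho)
    clipped_cs = rhos.clamp(max=clip_c)

    values_t1 = torch.cat([values[1:], bootstrap_value.unsqueeze(0)], dim=0)
    deltas = clipped_rhos * (rewards + discounts * values_t1 - values)

    # Backward recursion: vs_t - V(x_t) = delta_t + gamma_t * c_t * (vs_{t+1} - V(x_{t+1}))
    acc = torch.zeros_like(bootstrap_value)
    vs_minus_v = torch.zeros_like(values)
    for t in reversed(range(values.size(0))):
        acc = deltas[t] + discounts[t] * clipped_cs[t] * acc
        vs_minus_v[t] = acc
    vs = vs_minus_v + values

    # Policy-gradient advantages bootstrap from vs_{t+1}.
    vs_t1 = torch.cat([vs[1:], bootstrap_value.unsqueeze(0)], dim=0)
    pg_advantages = clipped_rhos * (rewards + discounts * vs_t1 - values)
    return vs.detach(), pg_advantages.detach()
```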
18 |
19 | ## How to Install
20 | - `Python 3.7`
21 | - `Pytorch 1.9.0`
22 | - `CUDA 11.1`
23 | - Install the DMLab environment - [DMLab Github](https://github.com/deepmind/lab/blob/master/docs/users/build.md)
24 | - `pip install -r requirements.txt`
25 |
26 | ## Description of Codes
27 | - `dist_launch.py` -> distributed training launcher
28 | - `eval.py` -> entry point for evaluation
29 | - `train.py` -> entry point for training
30 | - `brain_agent`
31 | - `core`
32 | - `agents`
33 | - `dmlab_multitask_agent.py`
34 | - `algos`
35 | - `aux_future_predict.py` -> Computes an auxiliary loss by predicting future state transitions with an autoregressive transformer. Used only for [trxl_future_pred](https://github.com/kakaobrain/brain_agent/blob/main/configs/trxl_future_pred_train.yaml).
36 | - `popart.py`
37 | - `vtrace.py`
38 | - `actor_worker.py`
39 | - `learner_worker.py`
40 | - `policy_worker.py`
41 | - `shared_buffer.py` -> Defines the SharedBuffer class for zero-copy communication between workers.
42 | - `envs`
43 | - `dmlab`
44 | - `utils`
45 | - ...
46 | - `configs`
47 | - ... -> Hyperparameter configs for training and evaluation.
48 |
49 |
50 | ## How to Run
51 | ### Training
52 | - 1 node x 1 GPU
53 | ```bash
54 | python train.py cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
55 | ```
56 |
57 | - 1 node x 4 GPUs = 4 GPUs
58 | ```bash
59 | python -m dist_launch --nnodes=1 --node_rank=0 --nproc_per_node=4 -m train \
60 | cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
61 | ```
62 |
63 | - 4 nodes x 4 GPUs each = 16 GPUs
64 | ```bash
65 | sleep 120; python -m dist_launch --nnodes=4 --node_rank=0 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
66 | cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
67 | sleep 120; python -m dist_launch --nnodes=4 --node_rank=1 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
68 | cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
69 | sleep 120; python -m dist_launch --nnodes=4 --node_rank=2 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
70 | cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
71 | sleep 120; python -m dist_launch --nnodes=4 --node_rank=3 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
72 | cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
73 | ```
74 |
75 | ### Evaluation
76 | ```bash
77 | python eval.py cfg=configs/trxl_recon_eval.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR test.checkpoint=$CHECKPOINT_FILE_PATH
78 | ```
79 |
80 | ### Setting Hyperparameters
81 | - All the default hyperparameters are defined in `configs/default.yaml`.
82 | - Other config files override `configs/default.yaml`.
83 | - You can use the pre-defined hyperparameters from our experiments with `configs/trxl_recon_train.yaml` or `configs/trxl_future_pred_train.yaml`; the layering is illustrated in the sketch below.
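As a rough illustration of this config layering, the merge can be reproduced with plain OmegaConf (listed in `requirements.txt`). The project's actual loading logic lives in `brain_agent/utils/cfg.py` and may differ in details, so treat this as a hedged sketch: the file names are real, but the exact merge order shown is an assumption based on the description above.

```python
from omegaconf import OmegaConf

# Assumed override chain: defaults <- experiment config <- command-line key=value overrides.
base = OmegaConf.load('configs/default.yaml')            # all default hyperparameters
exp = OmegaConf.load('configs/trxl_recon_train.yaml')    # experiment-specific overrides
cli = OmegaConf.from_dotlist(['train_dir=/tmp/train_dir', 'experiment=my_experiment'])

cfg = OmegaConf.merge(base, exp, cli)                    # later sources take precedence
print(OmegaConf.to_yaml(cfg))                            # inspect the merged config
```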
84 |
85 |
86 |
87 | ## Results for DMLAB30
88 |
89 | - Settings
90 | - 3 runs with different seeds
91 | - 100 episodes per run
92 | - HNS: Human Normalised Score
93 | - Results
94 |
95 | | Model | Mean HNS | Median HNS | Mean Capped HNS |
96 | |:----------:|:--------:|:----------:|:---------------:|
97 | | [MERLIN](https://arxiv.org/pdf/1803.10760.pdf) | 115.2 | - | 89.4 |
98 | | [GTrXL](https://proceedings.mlr.press/v119/parisotto20a.html) | 117.6 | - | 89.1 |
99 | | [CoBERL](https://arxiv.org/pdf/2107.05431.pdf) | 115.47 | 110.86 | - |
100 | | [R2D2+](https://openreview.net/pdf?id=r1lyTjAqYX) | - | 99.5 | 85.7 |
101 | | [LASER](https://arxiv.org/abs/1909.11583) | - | 97.2 | 81.7 |
102 | | [PBL](https://arxiv.org/pdf/2004.14646.pdf) | 104.16 | - | 81.5 |
103 | | [PopArt-IMPALA](https://arxiv.org/abs/1809.04474) | - | - | 72.8 |
104 | | [IMPALA](https://arxiv.org/abs/1802.01561) | - | - | 58.4 |
105 | | Ours (lstm_baseline, [20B ckpt](https://twg.kakaocdn.net/brainrepo/models/brain_agent/ed7f0e5a8dc57ad72c8c38319f58000e/rnn_baseline_20b.pth)) | 103.03 ± 0.37 | 92.04 ± 0.73 | 81.35 ± 0.25 |
106 | | Ours (trxl_baseline, [20B ckpt](https://twg.kakaocdn.net/brainrepo/models/brain_agent/c0e90a4e3555a12b58e60729d13e2e02/trxl_baseline_20b.pth)) | 111.95 ± 1.00 | 105.43 ± 2.61 | 85.57 ± 0.20 |
107 | | Ours (trxl_recon, [20B ckpt](https://twg.kakaocdn.net/brainrepo/models/brain_agent/84ac1c594b8eb95e7fd9879d6172f99b/trxl_recon_20b.pth)) | 123.60 ± 0.84 | 108.63 ± 1.20 | **91.25 ± 0.41** |
108 | | Ours (trxl_future_pred, [20B ckpt](https://twg.kakaocdn.net/brainrepo/models/brain_agent/a54acdd9d3a14f2905d295c9c63bf31d/trxl_future_pred_20b.pth)) | **128.00 ± 0.43** | 108.80 ± 0.99 | 90.53 ± 0.26 |
109 |
110 |
111 |
112 | Results for all 30 tasks 113 |
114 | 115 | | Level | lstm_baseline |  trxl_baseline  |     trxl_recon     | trxl_future_pred | 116 | | :-----: | :-----: | :-----: | :-----: | :---: | 117 | | rooms_collect_good_objects_(train / test) | 94.22 ± 0.84 / 95.13 ± 0.61 | 97.85 ± 0.31 / 95.20 ± 1.26 | 97.58 ± 0.20 / 89.39 ± 1.42 | 98.19 ± 0.18 / 98.52 ± 0.95 | 118 | | rooms_exploit_deferred_effects_(train / test) | 37.84 ± 2.23 / 4.36 ± 1.84 | 38.40 ± 3.82 / 1.73 ± 0.63 | 38.86 ± 3.48 / 4.04 ± 0.89 | 40.93 ± 3.12 / 2.26 ± 0.71 | 119 | | rooms_select_nonmatching_object | 50.13 ± 2.95 | 98.78 ± 1.38 | 99.52 ± 0.97 | 113.20 ± 1.14 | 120 | | rooms_watermaze | 45.09 ± 4.70 | 36.92 ± 6.90 | 111.20 ± 2.29 | 55.82 ± 0.74 | 121 | | rooms_keys_doors_puzzle | 51.75 ± 8.90 | 55.86 ± 4.25 | 61.24 ± 9.09 | 64.95 ± 8.43 | 122 | | language_select_described_object | 150.57 ± 0.58 | 154.90 ± 0.22 | 155.35 ± 0.17 | 158.23 ± 0.90 | 123 | | language_select_located_object | 225.97 ± 1.93 | 244.46 ± 1.56 | 252.04 ± 0.31 | 261.20 ± 1.15 | 124 | | language_execute_random_task | 126.49 ± 2.35 | 139.63 ± 1.23 | 145.21 ± 0.36 | 150.20 ± 1.35 | 125 | | language_answer_quantitative_question | 153.92 ± 2.35 | 162.99 ± 2.42 | 163.72 ± 1.36 | 166.07 ± 1.72 | 126 | | lasertag_one_opponent_small | 234.90 ± 6.19 | 243.52 ± 3.96 | 249.99 ± 6.64 | 279.54 ± 4.14 | 127 | | lasertag_three_opponents_small | 235.61 ± 1.92 | 242.61 ± 3.75 | 246.68 ± 5.99 | 264.20 ± 3.76 | 128 | | lasertag_one_opponent_large | 74.88 ± 5.06 | 83.51 ± 1.31 | 82.55 ± 2.15 | 94.86 ± 3.64 | 129 | | lasertag_three_opponents_large | 84.78 ± 2.42 | 92.04 ± 2.17 | 96.54 ± 0.67 | 105.83 ± 0.47 | 130 | | natlab_fixed_large_map | 98.10 ± 1.77 | 110.74 ± 1.34 | 120.53 ± 1.79 | 118.17 ± 1.79 | 131 | | natlab_varying_map_regrowth | 108.54 ± 1.20 | 107.16 ± 2.68 | 108.14 ± 1.25 | 104.83 ± 1.26 | 132 | | natlab_varying_map_randomized | 85.33 ± 6.52 | 86.33 ± 7.30 | 85.53 ± 6.69 | 77.74 ± 0.84 | 133 | | skymaze_irreversible_path_hard | 55.29 ± 9.08 | 60.63 ± 4.73 | 61.63 ± 2.52 | 66.30 ± 5.69 | 134 | | skymaze_irreversible_path_varied | 77.02 ± 3.57 | 77.41 ± 0.67 | 81.31 ± 2.34 | 79.36 ± 7.95 | 135 | | psychlab_arbitrary_visuomotor_mapping | 52.17 ± 2.06 | 51.46 ± 0.45 | 101.82 ± 0.19 | 101.80 ± 0.00 | 136 | | psychlab_continuous_recognition | 52.57 ± 0.46 | 52.41 ± 0.92 | 102.46 ± 0.32 | 102.30 ± 0.00 | 137 | | psychlab_sequential_comparison | 76.82 ± 0.45 | 75.48 ± 1.16 | 75.74 ± 0.58 | 76.13 ± 0.77 | 138 | | psychlab_visual_search | 101.54 ± 0.10 | 101.58 ± 0.04 | 101.91 ± 0.00 | 101.90 ± 0.00 | 139 | | explore_object_locations_small | 118.89 ± 0.93 | 121.47 ± 0.26 | 123.54 ± 2.61 | 126.67 ± 2.08 | 140 | | explore_object_locations_large | 111.46 ± 2.91 | 120.70 ± 2.12 | 115.43 ± 1.64 | 129.83 ± 2.41 | 141 | | explore_obstructed_goals_small | 136.92 ± 6.02 | 148.05 ± 1.96 | 166.75 ± 3.63 | 174.30 ± 3.72 | 142 | | explore_obstructed_goals_large | 92.36 ± 5.81 | 106.73 ± 7.86 | 153.44 ± 3.20 | 176.43 ± 1.50 | 143 | | explore_goal_locations_small | 143.21 ± 8.21 | 154.87 ± 4.41 | 177.16 ± 0.37 | 193.00 ± 3.75 | 144 | | explore_goal_locations_large | 98.50 ± 9.61 | 117.33 ± 6.75 | 160.39 ± 3.32 | 178.13 ± 7.15 | 145 | | explore_object_rewards_few | 76.29 ± 1.52 | 108.64 ± 0.89 | 109.58 ± 3.53 | 110.07 ± 1.42 | 146 | | explore_object_rewards_many | 72.33 ± 0.87 | 105.33 ± 1.52 | 105.15 ± 0.75 | 107.23 ± 1.59 | 147 | 148 |
149 |
150 | 151 | - Learning curves 152 |
153 | ![Learning Curve](assets/learning_curve.png) 154 |
155 | 156 | 157 | ## Distributed RL System Overview 158 |
159 | ![System Overview](assets/system_overview.png) 160 |
161 |
162 |
163 | ## Notes
164 | - Acknowledgement
165 | - [Sample Factory](https://github.com/alex-petrenko/sample-factory) for optimized single-node training, [Transformer-XL](https://github.com/kimiyoung/transformer-xl) for the neural net core architecture.
166 |
167 | - License
168 | - This repository is released under the MIT license, included [here](LICENSE).
169 | - This repository includes some code from [sample-factory](https://github.com/alex-petrenko/sample-factory)
170 | (MIT license) and [transformer-xl](https://github.com/kimiyoung/transformer-xl) (Apache 2.0 License).
171 |
172 | - Contact
173 | - Agent learning team, [Kakao Brain](https://www.kakaobrain.com/).
174 | - If you have any questions or feedback regarding this repository, please email contact@kakaobrain.com.
175 |
176 | ## Citation
177 | ```
178 | @misc{kakaobrain2022brain_agent, title = {Brain Agent},
179 | author = {Donghoon Lee, Taehwan Kwon, Seungeun Rho, Daniel Wontae Nam, Jongmin Kim, Daejin Jo, and Sungwoong Kim},
180 | year = {2022}, howpublished = {\url{https://github.com/kakaobrain/brain_agent}} }
181 | ```
182 |
-------------------------------------------------------------------------------- /brain_agent/core/models/transformer.py: --------------------------------------------------------------------------------
1 | """
2 | Copyright 2019 kimiyoung
3 | Copyright 2022 Kakao Brain Corp.
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 |
17 | Note:
18 | This specific file is a modification of kimiyoung's version of TrXL
19 | https://github.com/kimiyoung/transformer-xl which implements TrXL-I in https://arxiv.org/pdf/1910.06764.pdf for
20 | RL environments.
21 | """ 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | class PositionalEmbedding(nn.Module): 28 | def __init__(self, demb): 29 | super(PositionalEmbedding, self).__init__() 30 | 31 | self.demb = demb 32 | 33 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 34 | self.register_buffer('inv_freq', inv_freq) 35 | 36 | def forward(self, pos_seq, bsz=None): 37 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 38 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 39 | 40 | if bsz is not None: 41 | return pos_emb[:,None,:].expand(-1, bsz, -1) 42 | else: 43 | return pos_emb[:,None,:] 44 | 45 | 46 | class PositionwiseFF(nn.Module): 47 | def __init__(self, d_model, d_inner, pre_lnorm=False): 48 | super(PositionwiseFF, self).__init__() 49 | 50 | self.d_model = d_model 51 | self.d_inner = d_inner 52 | 53 | self.CoreNet = nn.Sequential( 54 | nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), 55 | nn.Linear(d_inner, d_model), 56 | ) 57 | 58 | self.layer_norm = nn.LayerNorm(d_model) 59 | self.pre_lnorm = pre_lnorm 60 | 61 | def forward(self, inp): 62 | if self.pre_lnorm: 63 | ##### layer normalization + positionwise feed-forward 64 | core_out = self.CoreNet(self.layer_norm(inp)) 65 | 66 | ##### residual connection 67 | output = F.relu(core_out) + inp 68 | else: 69 | ##### positionwise feed-forward 70 | core_out = self.CoreNet(inp) 71 | 72 | ##### residual connection + layer normalization 73 | output = self.layer_norm(inp + core_out) 74 | 75 | return output 76 | 77 | 78 | class RelPartialLearnableDecoderLayer(nn.Module): 79 | def __init__(self, n_head, d_model, d_head, d_inner, **kwargs): 80 | super(RelPartialLearnableDecoderLayer, self).__init__() 81 | 82 | self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, 83 | d_head, **kwargs) 84 | self.pos_ff = PositionwiseFF(d_model, d_inner, pre_lnorm=kwargs.get('pre_lnorm')) 85 | 86 | def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None): 87 | 88 | output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias, 89 | attn_mask=dec_attn_mask, 90 | mems=mems) 91 | output = self.pos_ff(output) 92 | 93 | return output 94 | 95 | 96 | class RelMultiHeadAttn(nn.Module): 97 | def __init__(self, n_head, d_model, d_head, mem_len=None, pre_lnorm=False): 98 | super(RelMultiHeadAttn, self).__init__() 99 | 100 | self.n_head = n_head 101 | self.d_model = d_model 102 | self.d_head = d_head 103 | 104 | self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False) 105 | 106 | self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) 107 | 108 | self.layer_norm = nn.LayerNorm(d_model) 109 | 110 | self.scale = 1 / (d_head ** 0.5) 111 | 112 | self.pre_lnorm = pre_lnorm 113 | 114 | def _parallelogram_mask(self, h, w, left=False): 115 | mask = torch.ones((h, w)).bool() 116 | m = min(h, w) 117 | mask[:m,:m] = torch.triu(mask[:m,:m]) 118 | mask[-m:,-m:] = torch.tril(mask[-m:,-m:]) 119 | 120 | if left: 121 | return mask 122 | else: 123 | return mask.flip(0) 124 | 125 | def _shift(self, x, qlen, klen, mask, left=False): 126 | if qlen > 1: 127 | zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)), 128 | device=x.device, dtype=x.dtype) 129 | else: 130 | zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype) 131 | 132 | if left: 133 | mask = mask.flip(1) 134 | x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1) 135 | else: 136 | x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1) 137 | 138 | x = 
x_padded.masked_select(mask[:,:,None,None]) \ 139 | .view(qlen, klen, x.size(2), x.size(3)) 140 | 141 | return x 142 | 143 | def _rel_shift(self, x, zero_triu=False): 144 | zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), 145 | device=x.device, dtype=x.dtype) 146 | x_padded = torch.cat([zero_pad, x], dim=1) 147 | 148 | x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:]) 149 | 150 | x = x_padded[1:].view_as(x) 151 | 152 | if zero_triu: 153 | ones = torch.ones((x.size(0), x.size(1))) 154 | x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None] 155 | 156 | return x 157 | 158 | def forward(self, w, r, attn_mask=None, mems=None): 159 | raise NotImplementedError 160 | 161 | 162 | class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): 163 | def __init__(self, *args, **kwargs): 164 | super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs) 165 | 166 | self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False) 167 | 168 | 169 | def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None): 170 | qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) 171 | 172 | if (mems is not None) and mems.size(0) > 0: # TODO: check 173 | cat = torch.cat([mems, w], 0) 174 | if self.pre_lnorm: 175 | w_heads = self.qkv_net(self.layer_norm(cat)) 176 | else: 177 | w_heads = self.qkv_net(cat) 178 | 179 | if mems.dtype == torch.float16: 180 | r = r.half() # TODO: should be handled with cfg 181 | r_head_k = self.r_net(r) 182 | 183 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 184 | w_head_q = w_head_q[-qlen:] 185 | else: 186 | if self.pre_lnorm: 187 | w_heads = self.qkv_net(self.layer_norm(w)) 188 | else: 189 | w_heads = self.qkv_net(w) 190 | 191 | r_head_k = self.r_net(r) 192 | 193 | w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) 194 | 195 | klen = w_head_k.size(0) 196 | 197 | w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head 198 | w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head 199 | w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # klen x bsz x n_head x d_head 200 | 201 | r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head 202 | 203 | #### compute attention score 204 | rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head 205 | AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head 206 | 207 | rr_head_q = w_head_q + r_r_bias 208 | BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head 209 | BD = self._rel_shift(BD) 210 | 211 | # [qlen x klen x bsz x n_head] 212 | attn_score = AC + BD 213 | attn_score.mul_(self.scale) 214 | 215 | #### compute attention probability 216 | if attn_mask is not None and attn_mask.any().item(): 217 | if attn_mask.dim() == 2: 218 | attn_score = attn_score.float().masked_fill( 219 | attn_mask[None,:,:,None], -float("inf")).type_as(attn_score) 220 | elif attn_mask.dim() == 3: 221 | attn_score = attn_score.float().masked_fill( 222 | attn_mask[:,:,:,None], -float("inf")).type_as(attn_score) 223 | 224 | # [qlen x klen x bsz x n_head] 225 | attn_prob = F.softmax(attn_score, dim=1) 226 | 227 | #### compute attention vector 228 | attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) 229 | 230 | # [qlen x bsz x n_head x d_head] 231 | attn_vec = attn_vec.contiguous().view( 232 | attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) 233 | 234 | ##### linear projection 235 | 
attn_out = self.o_net(attn_vec) 236 | 237 | if self.pre_lnorm: 238 | ##### residual connection 239 | # modified. applying ReLU before residual connection 240 | output = w + F.relu(attn_out) 241 | else: 242 | ##### residual connection + layer normalization 243 | output = self.layer_norm(w + attn_out) 244 | 245 | return output 246 | 247 | 248 | class MemTransformerLM(nn.Module): 249 | def __init__(self, cfg, n_layer, n_head, d_model, d_head, d_inner, 250 | mem_len=1, pre_lnorm=False): 251 | super(MemTransformerLM, self).__init__() 252 | self.cfg = cfg 253 | 254 | self.d_embed = d_model 255 | self.d_model = d_model 256 | self.n_head = n_head 257 | self.d_head = d_head 258 | 259 | self.n_layer = n_layer 260 | 261 | self.mem_len = mem_len 262 | 263 | self.layers = nn.ModuleList() 264 | 265 | for i in range(n_layer): 266 | self.layers.append( 267 | RelPartialLearnableDecoderLayer( 268 | n_head, d_model, d_head, d_inner, 269 | mem_len=mem_len, pre_lnorm=pre_lnorm) 270 | ) 271 | 272 | # create positional encoding-related parameters 273 | self.pos_emb = PositionalEmbedding(self.d_model) 274 | self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 275 | self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) 276 | 277 | self.apply(self.initialize) 278 | 279 | def initialize(self, layer): 280 | def init_weight(weight): 281 | nn.init.normal_(weight, 0.0, 0.02) # args.init_std) 282 | 283 | def init_bias(bias): 284 | nn.init.constant_(bias, 0.0) 285 | 286 | classname = layer.__class__.__name__ 287 | if classname.find('Linear') != -1: 288 | if hasattr(layer, 'weight') and layer.weight is not None: 289 | init_weight(layer.weight) 290 | if hasattr(layer, 'bias') and layer.bias is not None: 291 | init_bias(layer.bias) 292 | elif classname.find('AdaptiveEmbedding') != -1: 293 | if hasattr(layer, 'emb_projs'): 294 | for i in range(len(layer.emb_projs)): 295 | if layer.emb_projs[i] is not None: 296 | nn.init.normal_(layer.emb_projs[i], 0.0, 0.01) # args.proj_init_std) 297 | elif classname.find('Embedding') != -1: 298 | if hasattr(layer, 'weight'): 299 | init_weight(layer.weight) 300 | elif classname.find('ProjectedAdaptiveLogSoftmax') != -1: 301 | if hasattr(layer, 'cluster_weight') and layer.cluster_weight is not None: 302 | init_weight(layer.cluster_weight) 303 | if hasattr(layer, 'cluster_bias') and layer.cluster_bias is not None: 304 | init_bias(layer.cluster_bias) 305 | if hasattr(layer, 'out_projs'): 306 | for i in range(len(layer.out_projs)): 307 | if layer.out_projs[i] is not None: 308 | nn.init.normal_(layer.out_projs[i], 0.0, 0.01) # args.proj_init_std) 309 | elif classname.find('LayerNorm') != -1: 310 | if hasattr(layer, 'weight'): 311 | nn.init.normal_(layer.weight, 1.0, 0.02) # args.init_std) 312 | if hasattr(layer, 'bias') and layer.bias is not None: 313 | init_bias(layer.bias) 314 | elif classname.find('TransformerLM') != -1: 315 | if hasattr(layer, 'r_emb'): 316 | init_weight(layer.r_emb) 317 | if hasattr(layer, 'r_w_bias'): 318 | init_weight(layer.r_w_bias) 319 | if hasattr(layer, 'r_r_bias'): 320 | init_weight(layer.r_r_bias) 321 | if hasattr(layer, 'r_bias'): 322 | init_bias(layer.r_bias) 323 | 324 | def get_core_out_size(self): 325 | return self.d_model 326 | 327 | def init_mems(self): 328 | if self.mem_len > 0: 329 | mems = [] 330 | param = next(self.parameters()) 331 | for i in range(self.n_layer + 1): 332 | empty = torch.empty(0, dtype=param.dtype, device=param.device) 333 | mems.append(empty) 334 | 335 | return mems 336 | else: 337 | return None 338 | 339 | 
def _update_mems(self, hids, mems, mlen, qlen): 340 | # does not deal with None 341 | if mems is None: return None 342 | 343 | # mems is not None 344 | assert len(hids) == len(mems), 'len(hids) != len(mems)' 345 | 346 | with torch.no_grad(): 347 | new_mems = [] 348 | end_idx = mlen + max(0, qlen) 349 | beg_idx = max(0, end_idx - self.mem_len) 350 | 351 | for i in range(len(hids)): 352 | 353 | cat = torch.cat([mems[i], hids[i]], dim=0) 354 | new_mems.append(cat[beg_idx:end_idx].detach()) 355 | 356 | # only return last step mem 357 | new_mems = [m[-1] for m in new_mems] 358 | new_mems = torch.cat(new_mems, dim=-1) 359 | 360 | return new_mems 361 | 362 | def _forward(self, obs_emb, mems=None, mem_begin_index=None, dones=None, from_learner=False): 363 | qlen, bsz, _ = obs_emb.size() 364 | 365 | mlen = mems[0].size(0) if mems is not None else 0 366 | 367 | klen = mlen + qlen 368 | 369 | dec_attn_mask = (torch.triu( 370 | obs_emb.new_ones(qlen, klen), diagonal=1+mlen) 371 | + torch.tril( 372 | obs_emb.new_ones(qlen, klen), diagonal=-1)).bool().unsqueeze(-1).repeat(1, 1, bsz) 373 | 374 | for b in range(bsz): 375 | dec_attn_mask[:, :(mlen - max(0, mem_begin_index[b])), b] = True 376 | if dones is not None: 377 | query_done_index = torch.where(dones[:, b] > 0) 378 | for q in query_done_index[0]: 379 | # Going to mask out elements before done for new episode 380 | dec_attn_mask[q + 1:, :(mlen + q + 1), b] = True 381 | 382 | hids = [] 383 | pos_seq = torch.arange(klen-1, -1, -1.0, device=obs_emb.device, 384 | dtype=obs_emb.dtype) 385 | 386 | pos_emb = self.pos_emb(pos_seq) 387 | core_out = obs_emb 388 | 389 | hids.append(core_out) 390 | 391 | for i, layer in enumerate(self.layers): 392 | mems_i = None if mems is None else mems[i] 393 | core_out = layer(core_out, pos_emb, self.r_w_bias, 394 | self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) 395 | hids.append(core_out) 396 | 397 | new_mems = self._update_mems(hids, mems, mlen, qlen) if not from_learner else None 398 | 399 | return core_out, new_mems 400 | 401 | def forward(self, data, mems, mem_begin_index=None, dones=None, from_learner=False): 402 | if mems is None: 403 | mems = self.init_mems() 404 | else: 405 | mems = torch.split(mems, self.d_model, dim=-1) 406 | 407 | if from_learner: 408 | data = data.reshape( 409 | int(self.cfg.optim.batch_size // self.cfg.optim.rollout), 410 | self.cfg.optim.rollout, 411 | -1).transpose(0, 1) 412 | else: 413 | data = data.unsqueeze(0) 414 | 415 | # input observation should be either (1 x B x dim) or (T x B x dim) 416 | hidden, new_mems = self._forward(data, mems=mems, mem_begin_index=mem_begin_index, dones=dones, 417 | from_learner=from_learner) 418 | 419 | # reshape hidden: T x B x dim -> TB x dim 420 | hidden = hidden.transpose(0, 1).reshape(hidden.size(0) * hidden.size(1), -1) 421 | 422 | return hidden, new_mems 423 | 424 | def get_mem_begin_index(self, mems_dones, actor_env_step): 425 | # mems_dones: (n_batch, n_seq, 1) 426 | # actor_env_step: (n_batch) 427 | assert mems_dones.shape[0] == actor_env_step.shape[0], ( 428 | f'The number of batches should be same for mems_done ({mems_dones.shape[0]})' 429 | + f' and actor_env_step ({actor_env_step.shape[0]})' 430 | ) 431 | mems_dones = mems_dones.squeeze(-1).cpu() 432 | actor_env_step = actor_env_step.cpu() 433 | 434 | arange = torch.arange(1, self.cfg.model.core.mem_len + 1, 1).unsqueeze(0) # 0 ~ self.cfg.mem_len - 1, (1, n_seq) 435 | step_count_dones = mems_dones * arange # (n_batch, n_seq) 436 | step_count_last_dones = 
step_count_dones.max(dim=-1).values # (n_batch)
437 | numel_to_be_attentioned = self.cfg.model.core.mem_len - step_count_last_dones
438 | mem_begin_index = torch.min(numel_to_be_attentioned, actor_env_step)
439 | mem_begin_index = mem_begin_index.int().tolist()
440 |
441 | return mem_begin_index
442 |
-------------------------------------------------------------------------------- /brain_agent/core/policy_worker.py: --------------------------------------------------------------------------------
1 | import multiprocessing
2 | import signal
3 | import time
4 | from collections import deque
5 | from queue import Empty
6 | import torch
7 | import numpy as np
8 | import psutil
9 | import os
10 |
11 | from torch.multiprocessing import Process as TorchProcess
12 | from brain_agent.core.agents.agent_utils import create_agent
13 | from brain_agent.utils.logger import log
14 | from brain_agent.utils.utils import AttrDict
15 | from brain_agent.utils.timing import Timing
16 |
17 | from brain_agent.core.core_utils import TaskType, dict_of_lists_append, slice_mems, join_or_kill
18 |
19 |
20 | class PolicyWorker:
21 | """
22 | Policy worker is a separate process that uses the policy to generate actions for the collected rollout data.
23 |
24 | Args:
25 | cfg (['brain_agent.utils.utils.AttrDict'], 'AttrDict'):
26 | Global configuration in the form of an AttrDict, a dictionary whose values can be accessed as attributes.
27 | obs_space ('gym.spaces'):
28 | Observation space.
29 | action_space ('gym.spaces.discrete.Discrete'):
30 | Action space object. Currently only supports discrete action spaces.
31 | level_info ('dict'):
32 | Dictionary of level info, from the DMLab env.
33 | shared_buffer (['brain_agent.core.shared_buffer.SharedBuffer']):
34 | Shared buffer object that stores collected rollouts.
35 | policy_queue ('faster_fifo.Queue'):
36 | Action request queue for the policy in the policy worker. Not to be confused with policy_worker_queue.
37 | actor_worker_queues ('faster_fifo.Queue'):
38 | Task queues for the actor workers.
39 | policy_worker_queue ('faster_fifo.Queue'):
40 | Task queue for the policy worker. Not to be confused with policy_queue.
41 | report_queue ('faster_fifo.Queue'):
42 | Task queue for reporting. This is where various workers dump information to log.
43 | policy_lock ('multiprocessing.synchronize.Lock', *optional*):
44 | This will be used to apply a lock when updating and broadcasting model parameters.
45 | resume_experiment_collection_cv(*optional*) 46 | """ 47 | def __init__(self, cfg, obs_space, action_space, level_info, shared_buffer, policy_queue, actor_worker_queues, 48 | policy_worker_queue, report_queue, policy_lock=None, resume_experience_collection_cv=None): 49 | log.info('Initializing policy worker %d', cfg.dist.world_rank) 50 | 51 | self.cfg = cfg 52 | 53 | self.obs_space = obs_space 54 | self.action_space = action_space 55 | self.level_info = level_info 56 | 57 | self.device = None 58 | self.actor_critic = None 59 | self.shared_model_weights = None 60 | 61 | self.policy_queue = policy_queue 62 | self.actor_worker_queues = actor_worker_queues 63 | self.report_queue = report_queue 64 | self.policy_worker_queue = policy_worker_queue 65 | 66 | self.policy_lock = policy_lock 67 | self.resume_experience_collection_cv = resume_experience_collection_cv 68 | 69 | self.initialized = False 70 | self.terminate = False 71 | self.initialized_event = multiprocessing.Event() 72 | self.initialized_event.clear() 73 | 74 | self.shared_buffer = shared_buffer 75 | 76 | self.tensors_individual_transitions = self.shared_buffer.tensors_individual_transitions 77 | self.policy_outputs = self.shared_buffer.policy_output_tensors 78 | 79 | self.latest_policy_version = -1 80 | self.num_policy_updates = 0 81 | 82 | self.requests = [] 83 | 84 | self.total_num_samples = 0 85 | 86 | self.num_traj_buffers = shared_buffer.num_traj_buffers 87 | self.timing = Timing() 88 | 89 | self.initialized = False 90 | self.timing = Timing() 91 | 92 | self.process = TorchProcess(target=self._run, daemon=True) 93 | 94 | def start_process(self): 95 | self.process.start() 96 | 97 | def init(self): 98 | self.policy_worker_queue.put((TaskType.INIT, None)) 99 | self.initialized_event.wait() 100 | 101 | def init_model(self, data): 102 | self.policy_worker_queue.put((TaskType.INIT_MODEL, data)) 103 | 104 | def load_model(self): 105 | self.policy_worker_queue.put((TaskType.INIT_MODEL, None)) 106 | 107 | def _init(self): 108 | if self.cfg.model.device == 'cuda': 109 | assert torch.cuda.device_count() == 1 110 | self.device = torch.device('cuda', index=0) 111 | else: 112 | self.device = torch.device('cpu') 113 | 114 | log.info('Policy worker %d initialized', self.cfg.dist.world_rank) 115 | self.initialized_event.set() 116 | 117 | def _init_model(self, init_model_data): 118 | 119 | self.actor_critic = create_agent(self.cfg, self.action_space, self.obs_space, self.level_info[ 120 | 'num_levels'], need_half=self.cfg.model.use_half_policy_worker) 121 | 122 | self.actor_critic.model_to_device(self.device) 123 | for p in self.actor_critic.parameters(): 124 | p.requires_grad = False 125 | 126 | if self.cfg.model.core.core_type == 'trxl': 127 | max_batch_size = self.cfg.actor.num_workers * self.cfg.actor.num_envs_per_worker * 2 128 | mem_T_dim = self.cfg.model.core.mem_len 129 | mem_D_dim = self.shared_buffer.mems_dimensions[-1] 130 | mem_dones_D_dim = 1 131 | self.mems = torch.zeros([max_batch_size, mem_T_dim, mem_D_dim]).float().to(self.device) 132 | self.mems_dones = torch.zeros([max_batch_size, mem_T_dim, mem_dones_D_dim]).float().to(self.device) 133 | self.mems_actions = torch.zeros([max_batch_size, mem_T_dim, mem_dones_D_dim]).short().to(self.device) 134 | self.max_mems_buffer_len = self.shared_buffer.max_mems_buffer_len 135 | 136 | self.mems_buffer = None 137 | self.mems_dones_buffer = None 138 | 139 | if init_model_data is None: 140 | self._load_model() 141 | else: 142 | policy_version, state_dict, mems_buffer, mems_dones_buffer, 
mems_actions_buffer = init_model_data 143 | self.actor_critic.load_state_dict(state_dict) 144 | 145 | self.mems_buffer = mems_buffer 146 | self.mems_actions_buffer = mems_actions_buffer 147 | self.mems_dones_buffer = mems_dones_buffer 148 | 149 | self.shared_model_weights = state_dict 150 | self.latest_policy_version = policy_version 151 | 152 | log.info('Initialized model on the policy worker %d!', self.cfg.dist.world_rank) 153 | self.initialized = True 154 | 155 | def _load_model(self): 156 | ckpt = torch.load(self.cfg.test.checkpoint) 157 | env_step = ckpt['env_steps'] 158 | policy_version = ckpt['train_step'] 159 | state_dict = ckpt['model'] 160 | 161 | self.actor_critic.load_state_dict(state_dict) 162 | self.shared_model_weights = state_dict 163 | self.latest_policy_version = policy_version 164 | 165 | self.mems_buffer = torch.zeros(self.shared_buffer.mems_dimensions).to(self.cfg.model.device) 166 | self.mems_dones_buffer = torch.zeros(self.shared_buffer.mems_dones_dimensions, dtype=torch.bool).to( 167 | self.cfg.model.device) 168 | self.mems_actions_buffer = torch.zeros(self.shared_buffer.mems_actions_dimensions).short().to( 169 | self.cfg.model.device) 170 | 171 | self.report_queue.put(dict(learner_env_steps=env_step)) 172 | 173 | def _write_done_on_mems_dones_buffer(self, raw_index, actor_env_step): 174 | traj_tensors = self.shared_buffer.tensors_individual_transitions 175 | 176 | actor_idx, split_idx, env_idx, traj_buffer_idx, rollout_step = raw_index 177 | if rollout_step == 0: 178 | index_for_done = actor_idx, split_idx, env_idx, (traj_buffer_idx - 1) % self.num_traj_buffers, rollout_step - 1 179 | else: 180 | index_for_done = actor_idx, split_idx, env_idx, traj_buffer_idx, rollout_step - 1 181 | 182 | done = traj_tensors['dones'][index_for_done] 183 | index_for_mems_dones = actor_idx, split_idx, env_idx, (actor_env_step - 1) % self.max_mems_buffer_len 184 | self.mems_dones_buffer[index_for_mems_dones] = bool(done) 185 | 186 | def _handle_policy_steps(self): 187 | """ 188 | Forwards policy for the indexed location in self.requests. 189 | The resulting actions are stored back into the respective index in the shared buffer. 190 | """ 191 | with torch.no_grad(): 192 | with self.timing.timeit('deserialize'): 193 | observations = AttrDict() 194 | r_idx = 0 195 | first_rollout_list = [] 196 | rollout_step_list = [] 197 | actor_env_step_list = [] 198 | actions = [] 199 | rewards = [] 200 | task_ids = [] 201 | rnn_states = [] 202 | dones = [] 203 | 204 | traj_tensors = self.shared_buffer.tensors_individual_transitions 205 | 206 | # Run through the request to read necessary data to inference action using policy. 
207 | for request in self.requests: 208 | actor_idx, split_idx, request_data = request 209 | if self.cfg.model.core.core_type == 'trxl': 210 | for env_idx, traj_buffer_idx, rollout_step, first_rollout, actor_env_step in request_data: 211 | index = actor_idx, split_idx, env_idx, traj_buffer_idx, rollout_step 212 | with self.timing.timeit('write_done_on_mems_dones_buffer'): 213 | self._write_done_on_mems_dones_buffer(index, actor_env_step) 214 | self.timing['write_done_on_mems_dones_buffer'] *= len(self.requests) * len(request_data) 215 | dict_of_lists_append(observations, traj_tensors['obs'], index) 216 | 217 | s_idx = (actor_env_step - self.cfg.model.core.mem_len) % self.max_mems_buffer_len 218 | e_idx = actor_env_step % self.max_mems_buffer_len 219 | with self.timing.timeit('mems_copy'): 220 | self.mems[r_idx], self.mems_dones[r_idx], self.mems_actions[r_idx] = slice_mems( 221 | self.mems_buffer, self.mems_dones_buffer, self.mems_actions_buffer, *index[:3], s_idx, e_idx) 222 | self.timing['mems_copy'] *= len(self.requests) * len(request_data) 223 | r_idx += 1 224 | first_rollout_list.append(first_rollout) 225 | rollout_step_list.append(rollout_step) 226 | actor_env_step_list.append(actor_env_step) 227 | 228 | actions.append(traj_tensors['prev_actions'][index]) 229 | rewards.append(traj_tensors['prev_rewards'][index]) 230 | task_ids.append(self.shared_buffer.task_ids[actor_idx][split_idx][env_idx].unsqueeze(0)) 231 | self.total_num_samples += 1 232 | elif self.cfg.model.core.core_type == 'rnn': 233 | for env_idx, traj_buffer_idx, rollout_step, first_rollout, actor_env_step in request_data: 234 | index = actor_idx, split_idx, env_idx, traj_buffer_idx, rollout_step 235 | rnn_states.append(traj_tensors['rnn_states'][index]) 236 | dones.append(traj_tensors['dones'][index]) 237 | 238 | first_rollout_list.append(first_rollout) 239 | rollout_step_list.append(rollout_step) 240 | actor_env_step_list.append(actor_env_step) 241 | dict_of_lists_append(observations, traj_tensors['obs'], index) 242 | 243 | actions.append(traj_tensors['prev_actions'][index]) 244 | rewards.append(traj_tensors['prev_rewards'][index]) 245 | task_ids.append(self.shared_buffer.task_ids[actor_idx][split_idx][env_idx].unsqueeze(0)) 246 | self.total_num_samples += 1 247 | 248 | with self.timing.timeit('reordering_mems'): 249 | if self.cfg.model.core.core_type == 'trxl': 250 | n_batch = len(actor_env_step_list) 251 | if self.cfg.model.core.mem_len > 0: 252 | mems_dones = self.mems_dones[:n_batch] 253 | actor_env_step = torch.tensor(actor_env_step_list) # (n_batch) 254 | mem_begin_index = self.actor_critic.core.get_mem_begin_index(mems_dones, actor_env_step) 255 | self.actor_critic.actor_env_step = actor_env_step 256 | else: 257 | mem_begin_index = [0] * n_batch 258 | 259 | 260 | with self.timing.timeit('stack'): 261 | for key, x in observations.items(): 262 | observations[key] = torch.stack(x) 263 | actions = torch.stack(actions) 264 | rewards = torch.stack(rewards) 265 | task_ids = torch.stack(task_ids) 266 | if self.cfg.model.core.core_type == 'rnn': 267 | dones = torch.stack(dones) 268 | rnn_states = torch.stack(rnn_states) 269 | 270 | with self.timing.timeit('obs_to_device'): 271 | for key, x in observations.items(): 272 | device, dtype = self.actor_critic.device_and_type_for_input_tensor(key) 273 | observations[key] = x.to(device).type(dtype) 274 | 275 | actions = actions.to(self.device).long() 276 | rewards = rewards.to(self.device).float() 277 | task_ids = task_ids.to(self.device).long() 278 | 279 | if 
self.cfg.model.core.core_type == 'rnn': 280 | rnn_states = rnn_states.to(self.device).float() 281 | dones = dones.to(self.device).float() 282 | if self.cfg.model.use_half_policy_worker: 283 | rnn_states = rnn_states.half() 284 | 285 | 286 | num_samples = actions.shape[0] 287 | if self.cfg.model.core.core_type == 'trxl': 288 | mems = self.mems[:num_samples] 289 | 290 | with self.timing.timeit('forward'): 291 | if self.cfg.model.use_half_policy_worker: 292 | rewards = rewards.half() 293 | if self.cfg.model.core.core_type == 'trxl': 294 | mems = mems.half() 295 | for key in observations: 296 | obs = observations[key] 297 | if obs.dtype == torch.float32: 298 | observations[key] = obs.half() 299 | 300 | if self.cfg.model.core.core_type == 'trxl': 301 | policy_outputs = self.actor_critic(observations, actions.squeeze(1), rewards.squeeze(1), 302 | mems=mems.transpose(0, 1), mem_begin_index=mem_begin_index, 303 | task_ids=task_ids, 304 | with_action_distribution=False, from_learner=False) 305 | elif self.cfg.model.core.core_type == 'rnn': 306 | policy_outputs = self.actor_critic(observations, actions.squeeze(1), rewards.squeeze(1), 307 | rnn_states=rnn_states, dones=dones.squeeze(1), is_seq=False, 308 | task_ids=task_ids, with_action_distribution=False) 309 | 310 | 311 | if self.cfg.model.core.core_type == 'trxl': 312 | midx = 0 313 | for request in self.requests: 314 | actor_idx, split_idx, request_data = request 315 | for env_idx, traj_buffer_idx, rollout_step, first_rollout, actor_env_step in request_data: 316 | mem_index = actor_idx, split_idx, env_idx, actor_env_step % self.max_mems_buffer_len 317 | self.mems_buffer[mem_index] = policy_outputs['mems'][midx] 318 | midx += 1 319 | del policy_outputs['mems'] 320 | 321 | 322 | for key, output_value in policy_outputs.items(): 323 | policy_outputs[key] = output_value.cpu() 324 | 325 | policy_outputs.policy_version = torch.empty([num_samples]).fill_(self.latest_policy_version) 326 | 327 | # concat all tensors into a single tensor for performance 328 | output_tensors = [] 329 | for policy_output in self.shared_buffer.policy_outputs: 330 | tensor_name = policy_output.name 331 | output_value = policy_outputs[tensor_name].float() 332 | if len(output_value.shape) == 1: 333 | output_value.unsqueeze_(dim=1) 334 | output_tensors.append(output_value) 335 | 336 | output_tensors = torch.cat(output_tensors, dim=1) 337 | 338 | self._enqueue_policy_outputs(self.requests, output_tensors) 339 | 340 | self.requests = [] 341 | 342 | def _enqueue_policy_outputs(self, requests, output_tensors): 343 | output_idx = 0 344 | 345 | outputs_ready = set() 346 | policy_outputs = self.shared_buffer.policy_output_tensors 347 | 348 | for request in requests: 349 | actor_idx, split_idx, request_data = request 350 | worker_outputs = policy_outputs[actor_idx, split_idx] 351 | for env_idx, traj_buffer_idx, rollout_step, _, _ in request_data: # writing at shared buffer 352 | worker_outputs[env_idx].copy_(output_tensors[output_idx]) 353 | output_idx += 1 354 | 355 | outputs_ready.add((actor_idx, split_idx)) 356 | 357 | for actor_idx, split_idx in outputs_ready: 358 | advance_rollout_request = dict(split_idx=split_idx) 359 | self.actor_worker_queues[actor_idx].put((TaskType.ROLLOUT_STEP, advance_rollout_request)) 360 | 361 | def _update_weights(self): 362 | learner_policy_version = self.shared_buffer.policy_versions[0].item() 363 | 364 | if self.latest_policy_version < learner_policy_version and self.shared_model_weights is not None: 365 | if self.policy_lock is not None: 366 | with 
self.timing.timeit('weight_update'): 367 | with self.policy_lock: 368 | self.actor_critic.load_state_dict(self.shared_model_weights) 369 | 370 | self.latest_policy_version = learner_policy_version 371 | 372 | if self.num_policy_updates % 10 == 0: 373 | log.info( 374 | 'Updated weights on worker %d, policy_version %d', 375 | self.cfg.dist.world_rank, self.latest_policy_version, 376 | ) 377 | 378 | self.num_policy_updates += 1 379 | 380 | def _run(self): 381 | signal.signal(signal.SIGINT, signal.SIG_IGN) 382 | 383 | psutil.Process().nice(2) 384 | 385 | torch.multiprocessing.set_sharing_strategy('file_system') 386 | os.environ['CUDA_VISIBLE_DEVICES'] = f'{self.cfg.dist.local_rank}' 387 | 388 | log.info('Initializing model on the policy worker %d...', self.cfg.dist.world_rank) 389 | log.info(f'POLICY worker {self.cfg.dist.world_rank}\tpid {os.getpid()}\tparent {os.getppid()}') 390 | 391 | torch.set_num_threads(1) 392 | 393 | last_report = last_cache_cleanup = time.time() 394 | last_report_samples = 0 395 | request_count = deque(maxlen=50) 396 | 397 | min_num_requests = self.cfg.actor.num_workers 398 | min_num_requests //= 3 399 | min_num_requests = max(1, min_num_requests) 400 | if self.cfg.test.is_test: 401 | min_num_requests = 1 402 | log.info('Min num requests: %d', min_num_requests) 403 | 404 | # Again, very conservative timer. Only wait a little bit, then continue operation. 405 | wait_for_min_requests = 0.025 406 | 407 | while not self.terminate: 408 | try: 409 | while self.shared_buffer.stop_experience_collection: 410 | if self.resume_experience_collection_cv is not None: 411 | with self.resume_experience_collection_cv: 412 | self.resume_experience_collection_cv.wait(timeout=0.05) 413 | 414 | waiting_started = time.time() 415 | while len(self.requests) < min_num_requests and time.time() - waiting_started < wait_for_min_requests: 416 | try: 417 | policy_requests = self.policy_queue.get_many(timeout=0.005) 418 | self.requests.extend(policy_requests) 419 | except Empty: 420 | pass 421 | 422 | self._update_weights() 423 | 424 | with self.timing.timeit('one_step'): 425 | if self.initialized: 426 | if len(self.requests) > 0: 427 | request_count.append(len(self.requests)) 428 | self._handle_policy_steps() 429 | 430 | try: 431 | task_type, data = self.policy_worker_queue.get_nowait() 432 | if task_type == TaskType.INIT: 433 | self._init() 434 | elif task_type == TaskType.TERMINATE: 435 | self.terminate = True 436 | break 437 | elif task_type == TaskType.INIT_MODEL: 438 | self._init_model(data) 439 | 440 | except Empty: 441 | pass 442 | 443 | if time.time() - last_report > 3.0: 444 | samples_since_last_report = self.total_num_samples - last_report_samples 445 | stats = dict() 446 | if len(request_count) > 0: 447 | stats['avg_request_count'] = np.mean(request_count) 448 | 449 | stats['times_policy_worker'] = {} 450 | for key, value in self.timing.items(): 451 | stats['times_policy_worker'][key] = value 452 | 453 | # self.report_queue.put(dict( 454 | # samples=samples_since_last_report, stats=stats, 455 | # )) 456 | self.report_queue.put(stats) 457 | 458 | last_report = time.time() 459 | last_report_samples = self.total_num_samples 460 | 461 | if time.time() - last_cache_cleanup > 300.0 or ( 462 | self.total_num_samples < 1000): 463 | if self.cfg.model.device == 'cuda': 464 | torch.cuda.empty_cache() 465 | last_cache_cleanup = time.time() 466 | 467 | except KeyboardInterrupt: 468 | log.warning('Keyboard interrupt detected on worker %d', self.cfg.dist.world_rank) 469 | self.terminate = True 470 | 
self.report_queue.put(('terminate', 'policy_worker')) 471 | 472 | except: 473 | log.exception('Unknown exception on policy worker') 474 | self.terminate = True 475 | self.report_queue.put(('terminate', 'policy_worker')) 476 | 477 | time.sleep(0.2) 478 | 479 | def close(self): 480 | self.task_queue.put((TaskType.TERMINATE, None)) 481 | 482 | def join(self): 483 | join_or_kill(self.process) 484 | -------------------------------------------------------------------------------- /brain_agent/core/actor_worker.py: -------------------------------------------------------------------------------- 1 | import psutil 2 | import random 3 | import time 4 | import os 5 | import signal 6 | import torch 7 | from queue import Empty, Full 8 | from threadpoolctl import threadpool_limits 9 | from torch.multiprocessing import Process as TorchProcess 10 | from brain_agent.utils.logger import log 11 | from brain_agent.utils.utils import AttrDict 12 | from brain_agent.core.core_utils import set_process_cpu_affinity, safe_put, safe_put_many, join_or_kill 13 | from brain_agent.envs.env_utils import create_env 14 | from brain_agent.core.core_utils import TaskType 15 | from brain_agent.utils.timing import Timing 16 | 17 | 18 | # TODO: actions -> action 19 | 20 | class ActorWorker: 21 | """ 22 | ActorWorker is responsible for stepping the environment(s) with actions provided by the policy workers. 23 | 24 | Args: 25 | cfg (['brain_agent.utils.utils.AttrDict'], 'AttrDict'): 26 | Global configuration in the form of an AttrDict, a dictionary whose values can also be accessed as attributes. 27 | obs_space ('gym.spaces'): 28 | Observation space. 29 | action_space ('gym.spaces.discrete.Discrete'): 30 | Action space object. Currently only supports discrete action spaces. 31 | shared_buffer (['brain_agent.core.shared_buffer.SharedBuffer']): 32 | Shared buffer object that stores collected rollouts. 33 | actor_worker_queue ('faster_fifo.Queue'): 34 | Task queue for the actor worker. 35 | policy_queue ('faster_fifo.Queue'): 36 | Action request queue for the policy in the policy worker. Not to be confused with policy_worker_queue. 37 | report_queue ('faster_fifo.Queue'): 38 | Task queue for reporting. This is where various workers dump information to log. 39 | learner_worker_queue ('faster_fifo.Queue', *optional*): 40 | Task queue for the learner worker. This is where other processes dump tasks for the learner.
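            actor_idx ('int'):
                Index of this actor worker among the cfg.actor.num_workers workers spawned by the current process; used, e.g., for CPU affinity and for deriving global environment ids.

        Example (illustrative sketch only; assumes cfg, obs_space, action_space, a SharedBuffer and the faster_fifo queues have already been constructed elsewhere):
            worker = ActorWorker(cfg, obs_space, action_space, actor_idx=0, shared_buffer=shared_buffer,
                                 actor_worker_queue=actor_worker_queue, policy_queue=policy_queue,
                                 report_queue=report_queue, learner_worker_queue=learner_worker_queue)
            worker.init()           # enqueues TaskType.INIT so the worker process builds its VectorEnvRunners
            worker.request_reset()  # enqueues TaskType.RESET; the runners reset their envs and emit the first policy requests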
41 | """ 42 | def __init__(self, cfg, obs_space, action_space, actor_idx, shared_buffer, actor_worker_queue, policy_queue, 43 | report_queue, learner_worker_queue=None): 44 | self.cfg = cfg 45 | self.obs_space = obs_space 46 | self.action_space = action_space 47 | self.actor_idx = actor_idx 48 | self.shared_buffer = shared_buffer 49 | self.actor_worker_queue = actor_worker_queue 50 | self.policy_queue = policy_queue 51 | self.learner_worker_queue = learner_worker_queue 52 | self.report_queue = report_queue 53 | 54 | self.env_runners = None 55 | 56 | self.num_splits = self.cfg.actor.num_splits 57 | self.num_envs_per_worker = self.cfg.actor.num_envs_per_worker 58 | self.num_envs_per_vector = self.cfg.actor.num_envs_per_worker // self.num_splits 59 | assert self.num_envs_per_worker >= self.num_splits 60 | assert self.num_envs_per_worker % self.num_splits == 0, 'Vector size should be divisible by num_splits' 61 | 62 | self.terminate = False 63 | self.num_complete_rollouts = 0 64 | self.timing = Timing() 65 | 66 | 67 | self.process = TorchProcess(target=self._run, daemon=True) 68 | self.process.start() 69 | 70 | def _init(self): 71 | log.info('Initializing envs for env runner %d...', self.actor_idx) 72 | 73 | threadpool_limits(limits=1, user_api=None) 74 | 75 | if self.cfg.actor.set_workers_cpu_affinity: 76 | set_process_cpu_affinity(self.actor_idx, self.cfg.actor.num_workers, self.cfg.dist.local_rank, 77 | self.cfg.dist.nproc_per_node) 78 | psutil.Process().nice(10) # learner: 0 79 | 80 | self.env_runners = [] 81 | for split_idx in range(self.num_splits): 82 | env_runner = VectorEnvRunner( 83 | self.cfg, self.num_envs_per_vector, self.actor_idx, split_idx, 84 | self.shared_buffer 85 | ) 86 | env_runner.init() 87 | self.env_runners.append(env_runner) 88 | 89 | def _terminate(self): 90 | for env_runner in self.env_runners: 91 | env_runner.close() 92 | self.terminate = True 93 | 94 | def _enqueue_policy_request(self, split_idx, requests): 95 | """Distribute action requests to their corresponding queues.""" 96 | policy_request = (self.actor_idx, split_idx, requests) 97 | self.policy_queue.put(policy_request) 98 | 99 | def _enqueue_complete_rollouts(self, split_idx, complete_rollouts): 100 | """Send complete rollouts from VectorEnv to the learner.""" 101 | if self.cfg.test.is_test: 102 | return 103 | 104 | traj_buffer_idx = complete_rollouts['traj_buffer_idx'] 105 | 106 | env_runner = self.env_runners[split_idx] 107 | env_runner.traj_tensors_available[:, traj_buffer_idx] = 0 108 | 109 | self.learner_worker_queue.put((TaskType.TRAIN, complete_rollouts)) 110 | 111 | def _report_stats(self, stats): 112 | safe_put_many(self.report_queue, stats, queue_name='report') 113 | 114 | def _handle_reset(self): 115 | for split_idx, env_runner in enumerate(self.env_runners): 116 | policy_inputs = env_runner.reset(self.report_queue) 117 | self._enqueue_policy_request(split_idx, policy_inputs) 118 | 119 | log.info('Finished reset for worker %d', self.actor_idx) 120 | safe_put(self.report_queue, dict(finished_reset=self.actor_idx), queue_name='report') 121 | 122 | def _advance_rollouts(self, data): 123 | split_idx = data['split_idx'] 124 | 125 | runner = self.env_runners[split_idx] 126 | policy_request, complete_rollouts, episodic_stats = runner.advance_rollouts(data, self.timing) 127 | 128 | if complete_rollouts: 129 | self._enqueue_complete_rollouts(split_idx, complete_rollouts) 130 | 131 | if self.num_complete_rollouts == 0: 132 | 133 | delay = (float(self.actor_idx) / self.cfg.actor.num_workers) * \ 134 | 
self.cfg.env.decorrelate_experience_max_seconds 135 | log.info( 136 | 'Worker %d, sleep for %.3f sec to decorrelate experience collection', 137 | self.actor_idx, delay, 138 | ) 139 | time.sleep(delay) 140 | log.info('Worker %d awakens!', self.actor_idx) 141 | 142 | self.num_complete_rollouts += len(complete_rollouts) 143 | 144 | if policy_request is not None: 145 | self._enqueue_policy_request(split_idx, policy_request) 146 | 147 | if episodic_stats: 148 | self._report_stats(episodic_stats) 149 | 150 | def _run(self): 151 | log.info('Initializing vector env runner %d...', self.actor_idx) 152 | log.info(f'ACTOR worker {self.actor_idx}\tpid {os.getpid()}\tparent {os.getppid()}') 153 | 154 | # workers should ignore Ctrl+C because the termination is handled in the event loop by a special msg 155 | signal.signal(signal.SIGINT, signal.SIG_IGN) 156 | 157 | torch.multiprocessing.set_sharing_strategy('file_system') 158 | 159 | last_report = time.time() 160 | with torch.no_grad(): 161 | while not self.terminate: 162 | try: 163 | try: 164 | with self.timing.timeit('wait_actor'): 165 | tasks = self.actor_worker_queue.get_many(timeout=0.1) 166 | except Empty: 167 | tasks = [] 168 | 169 | for task in tasks: 170 | task_type, data = task 171 | 172 | if task_type == TaskType.INIT: 173 | self._init() 174 | continue 175 | 176 | if task_type == TaskType.TERMINATE: 177 | self._terminate() 178 | break 179 | 180 | # handling actual workload 181 | if task_type == TaskType.ROLLOUT_STEP: 182 | with self.timing.timeit('one_step'): 183 | self._advance_rollouts(data) 184 | 185 | elif task_type == TaskType.RESET: 186 | self._handle_reset() 187 | 188 | if time.time() - last_report > 5.0 and 'one_step' in self.timing: 189 | stats = {} 190 | stats['times_actor_worker'] = {} 191 | for key, value in self.timing.items(): 192 | stats['times_actor_worker'][key] = value 193 | safe_put(self.report_queue, stats) 194 | 195 | last_report = time.time() 196 | 197 | except RuntimeError as exc: 198 | log.warning('Error while processing data w: %d, exception: %s', self.actor_idx, exc) 199 | log.warning('Terminate process...') 200 | self.terminate = True 201 | self.report_queue.put(('terminate', 'actor_worker')) 202 | except KeyboardInterrupt: 203 | self.terminate = True 204 | except: 205 | log.exception('Unknown exception in rollout worker') 206 | self.report_queue.put(('terminate', 'actor_worker')) 207 | self.terminate = True 208 | 209 | if self.actor_idx <= 1: 210 | time.sleep(0.1) 211 | log.info( 212 | 'Env runner %d, CPU aff. 
%r, rollouts %d', 213 | self.actor_idx, psutil.Process().cpu_affinity(), self.num_complete_rollouts, 214 | ) 215 | 216 | def init(self): 217 | self.actor_worker_queue.put((TaskType.INIT, None)) 218 | 219 | def request_reset(self): 220 | self.actor_worker_queue.put((TaskType.RESET, None)) 221 | 222 | def request_step(self, split, actions): 223 | data = (split, actions) 224 | self.actor_worker_queue.put((TaskType.ROLLOUT_STEP, data)) 225 | 226 | def close(self): 227 | self.actor_worker_queue.put((TaskType.TERMINATE, None)) 228 | 229 | def join(self): 230 | join_or_kill(self.process) 231 | 232 | 233 | class VectorEnvRunner: 234 | def __init__(self, cfg, num_envs, actor_idx, split_idx, shared_buffer): 235 | self.cfg = cfg 236 | 237 | self.num_envs = num_envs 238 | self.actor_idx = actor_idx 239 | self.actor_idx_node = actor_idx * self.cfg.dist.nproc_per_node + self.cfg.dist.local_rank 240 | self.actor_idx_world = actor_idx * self.cfg.dist.world_size + self.cfg.dist.world_rank 241 | self.split_idx = split_idx 242 | 243 | self.rollout_step = 0 244 | self.traj_buffer_idx = 0 # current shared trajectory buffer to use 245 | 246 | self.first_rollout = [True] * shared_buffer.num_traj_buffers 247 | self.shared_buffer = shared_buffer 248 | 249 | index = (actor_idx, split_idx) 250 | 251 | self.traj_tensors = shared_buffer.tensors_individual_transitions.index(index) 252 | self.traj_tensors_available = shared_buffer.is_traj_tensor_available[index] 253 | self.num_traj_buffers = shared_buffer.num_traj_buffers 254 | self.policy_outputs = shared_buffer.policy_outputs 255 | self.policy_output_tensors = shared_buffer.policy_output_tensors[index] 256 | self.task_id = shared_buffer.task_ids[index] 257 | 258 | self.envs, self.actor_states, self.episode_rewards = [], [], [] 259 | 260 | def init(self): 261 | 262 | for env_i in range(self.num_envs): 263 | vector_idx = self.split_idx * self.num_envs + env_i 264 | 265 | # global env id within the entire system 266 | env_id = self.actor_idx_world * self.cfg.actor.num_envs_per_worker + vector_idx 267 | 268 | env_config = AttrDict( 269 | worker_index=self.actor_idx_world, vector_index=vector_idx, env_id=env_id, 270 | ) 271 | 272 | env = create_env(self.cfg, env_config=env_config) 273 | 274 | env.seed(env_id + self.cfg.seed * self.cfg.actor.num_workers * self.cfg.actor.num_envs_per_worker * 275 | self.cfg.dist.world_size) 276 | 277 | self.envs.append(env) 278 | self.task_id[env_i] = env.unwrapped.task_id 279 | 280 | traj_tensors = self.traj_tensors.index(env_i) 281 | actor_state = ActorState( 282 | self.cfg, env, self.actor_idx, self.split_idx, env_i, traj_tensors, 283 | self.num_traj_buffers, self.policy_outputs, self.policy_output_tensors[ 284 | env_i] 285 | ) 286 | episode_rewards_env = 0.0 287 | 288 | self.actor_states.append(actor_state) 289 | self.episode_rewards.append(episode_rewards_env) 290 | 291 | def _process_policy_outputs(self): 292 | 293 | for env_i in range(self.num_envs): 294 | actor_state = self.actor_states[env_i] 295 | 296 | # via shared memory mechanism the new data should already be copied into the shared tensors 297 | 298 | policy_outputs = torch.split( 299 | actor_state.policy_output_tensors, 300 | split_size_or_sections=actor_state.policy_output_sizes, 301 | dim=0, 302 | ) 303 | policy_outputs_dict = dict() 304 | for tensor_idx, name in enumerate(actor_state.policy_output_names): 305 | if name == 'rnn_states' and self.cfg.model.core.core_type == 'rnn': 306 | new_rnn_state = policy_outputs[tensor_idx] 307 | else: 308 | policy_outputs_dict[name] = 
policy_outputs[tensor_idx] 309 | 310 | actor_state.set_trajectory_data(policy_outputs_dict, self.traj_buffer_idx, self.rollout_step) 311 | actor_state.last_actions = policy_outputs_dict['actions'] 312 | actor_state.prev_actions = policy_outputs_dict['actions'] 313 | 314 | if self.cfg.model.core.core_type == 'rnn': 315 | actor_state.last_rnn_state = new_rnn_state 316 | 317 | def _process_env_step(self, new_obs, rewards, dones, infos, env_i): 318 | 319 | episodic_stats = [] 320 | actor_state = self.actor_states[env_i] 321 | 322 | actor_state.record_env_step( 323 | rewards, dones, infos, self.traj_buffer_idx, self.rollout_step, 324 | ) 325 | actor_state.last_obs = new_obs 326 | actor_state.prev_rewards = float(rewards) 327 | 328 | actor_state.actor_env_step += 1 329 | if self.cfg.model.core.core_type=='rnn': 330 | actor_state.update_rnn_state(dones) 331 | 332 | if dones: 333 | actor_state.prev_rewards = 0. 334 | actor_state.prev_actions = torch.Tensor([-1]) 335 | episodic_stat = dict() 336 | episodic_stat['episodic_stats'] = infos['episodic_stats'] 337 | episodic_stats.append(episodic_stat) 338 | 339 | return episodic_stats 340 | 341 | def _finalize_trajectories(self): 342 | 343 | rollouts = [] 344 | for env_i in range(self.num_envs): 345 | actor_state = self.actor_states[env_i] 346 | rollout = actor_state.finalize_trajectory(self.rollout_step, 347 | self.first_rollout[self.traj_buffer_idx]) 348 | rollout['task_idx'] = self.task_id[env_i] 349 | rollouts.append(rollout) 350 | 351 | return dict(rollouts=rollouts, traj_buffer_idx=self.traj_buffer_idx) 352 | 353 | def _format_policy_request(self): 354 | 355 | policy_request = [] 356 | 357 | for env_i in range(self.num_envs): 358 | actor_state = self.actor_states[env_i] 359 | data = (env_i, self.traj_buffer_idx, self.rollout_step, self.first_rollout[self.traj_buffer_idx], 360 | actor_state.actor_env_step) 361 | policy_request.append(data) 362 | 363 | return policy_request 364 | 365 | def _prepare_next_step(self): 366 | 367 | for env_i in range(len(self.envs)): 368 | actor_state = self.actor_states[env_i] 369 | if self.cfg.model.core.core_type == 'trxl': 370 | policy_inputs = dict(obs=actor_state.last_obs) 371 | elif self.cfg.model.core.core_type == 'rnn': 372 | policy_inputs = dict(obs=actor_state.last_obs, rnn_states=actor_state.last_rnn_state) 373 | actor_state.traj_tensors['dones'][self.traj_buffer_idx, self.rollout_step].fill_(actor_state.done) 374 | actor_state.traj_tensors['prev_actions'][self.traj_buffer_idx, self.rollout_step].copy_( 375 | actor_state.prev_actions.type( 376 | actor_state.traj_tensors['prev_actions'][self.traj_buffer_idx, self.rollout_step].type())) 377 | actor_state.traj_tensors['prev_rewards'][self.traj_buffer_idx, self.rollout_step].fill_( 378 | actor_state.prev_rewards) 379 | 380 | actor_state.set_trajectory_data(policy_inputs, self.traj_buffer_idx, self.rollout_step) 381 | 382 | if self.rollout_step == 0 and self.first_rollout[self.traj_buffer_idx]: # start of the new trajectory, 383 | self.first_rollout[self.traj_buffer_idx] = False 384 | 385 | def reset(self, report_queue): 386 | 387 | for env_i, e in enumerate(self.envs): 388 | obs = e.reset() 389 | 390 | env_i_split = self.num_envs * self.split_idx + env_i 391 | if self.cfg.env.decorrelate_envs_on_one_worker and not self.cfg.test.is_test: 392 | decorrelate_steps = self.cfg.optim.rollout * env_i_split + self.cfg.optim.rollout * random.randint(0, 4) 393 | 394 | log.info('Decorrelating experience for %d frames...', decorrelate_steps) 395 | for decorrelate_step in 
range(decorrelate_steps): 396 | action = e.action_space.sample() 397 | obs, rew, done, info = e.step(action) 398 | 399 | actor_state = self.actor_states[env_i] 400 | actor_state.set_trajectory_data(dict(obs=obs), self.traj_buffer_idx, self.rollout_step) 401 | actor_state.traj_tensors['prev_actions'][self.traj_buffer_idx, self.rollout_step][0].fill_(-1) 402 | actor_state.traj_tensors['prev_rewards'][self.traj_buffer_idx, self.rollout_step][0] = 0. 403 | actor_state.traj_tensors['dones'][self.traj_buffer_idx, self.rollout_step][0].fill_(False) 404 | 405 | safe_put(report_queue, dict(initialized_env=(self.actor_idx, self.split_idx, env_i, self.task_id.tolist())), 406 | queue_name='report') 407 | 408 | policy_request = self._format_policy_request() 409 | return policy_request 410 | 411 | def advance_rollouts(self, data, timing): 412 | 413 | self._process_policy_outputs() 414 | 415 | complete_rollouts, episodic_stats = [], [] 416 | timing['env_step'], timing['overhead'] = 0, 0 417 | 418 | 419 | for env_i, e in enumerate(self.envs): 420 | with timing.timeit('_env_step'): 421 | actions = self.actor_states[env_i].last_actions.type(torch.int32).numpy().item() 422 | new_obs, rewards, dones, infos = e.step(actions) 423 | 424 | timing['env_step'] += timing['_env_step'] 425 | 426 | with timing.timeit('_overhead'): 427 | stats = self._process_env_step(new_obs, rewards, dones, infos, env_i) 428 | episodic_stats.extend(stats) 429 | timing['overhead'] += timing['_overhead'] 430 | 431 | self.rollout_step += 1 432 | if self.rollout_step == self.cfg.optim.rollout: 433 | # finalize and serialize the trajectory if we have a complete rollout 434 | complete_rollouts = self._finalize_trajectories() 435 | self.rollout_step = 0 436 | self.traj_buffer_idx = (self.traj_buffer_idx + 1) % self.num_traj_buffers 437 | 438 | if self.traj_tensors_available[:, self.traj_buffer_idx].min() == 0: 439 | with timing.timeit('wait_traj_buffer'): 440 | self.wait_for_traj_buffers() 441 | 442 | self._prepare_next_step() 443 | policy_request = self._format_policy_request() 444 | 445 | return policy_request, complete_rollouts, episodic_stats 446 | 447 | def wait_for_traj_buffers(self): 448 | print_warning = True 449 | while self.traj_tensors_available[:, self.traj_buffer_idx].min() == 0: 450 | if print_warning: 451 | log.warning( 452 | 'Waiting for trajectory buffer %d on actor %d-%d', 453 | self.traj_buffer_idx, self.actor_idx, self.split_idx, 454 | ) 455 | print_warning = False 456 | time.sleep(0.002) 457 | 458 | def close(self): 459 | for e in self.envs: 460 | e.close() 461 | 462 | 463 | class ActorState: 464 | def __init__(self, cfg, env, actor_idx, split_idx, env_idx, traj_tensors, num_traj_buffers, policy_outputs_info, 465 | policy_output_tensors): 466 | 467 | self.cfg = cfg 468 | self.env = env 469 | self.actor_idx = actor_idx 470 | self.split_idx = split_idx 471 | self.env_idx = env_idx 472 | 473 | if not self.cfg.model.core.core_type == 'rnn': 474 | self.last_rnn_state = None 475 | 476 | self.traj_tensors = traj_tensors 477 | self.num_traj_buffers = num_traj_buffers 478 | 479 | self.policy_output_names = [p.name for p in policy_outputs_info] 480 | self.policy_output_sizes = [p.size for p in policy_outputs_info] 481 | self.policy_output_tensors = policy_output_tensors 482 | 483 | self.prev_actions = None 484 | self.prev_rewards = None 485 | 486 | self.last_actions = None 487 | self.last_policy_steps = None 488 | 489 | self.num_trajectories = 0 490 | self.rollout_env_steps = 0 491 | 492 | self.last_episode_reward = 0 493 | 
self.last_episode_duration = 0 494 | self.last_episode_true_reward = 0 495 | self.last_episode_extra_stats = dict() 496 | 497 | self.actor_env_step = 0 498 | 499 | def set_trajectory_data(self, data, traj_buffer_idx, rollout_step): 500 | 501 | index = (traj_buffer_idx, rollout_step) 502 | self.traj_tensors.set_data(index, data) 503 | 504 | def record_env_step(self, reward, done, info, traj_buffer_idx, rollout_step): 505 | 506 | self.traj_tensors['rewards'][traj_buffer_idx, rollout_step][0] = float(reward) 507 | self.traj_tensors['dones'][traj_buffer_idx, rollout_step][0] = done 508 | 509 | env_steps = info.get('num_frames', 1) 510 | self.rollout_env_steps += env_steps 511 | self.last_episode_duration += env_steps 512 | 513 | if done: 514 | self.done = True 515 | self.last_episode_extra_stats = info.get('episode_extra_stats', dict()) 516 | else: 517 | self.done = False 518 | 519 | def finalize_trajectory(self, rollout_step, first_rollout=None): 520 | 521 | t_id = f'{self.actor_idx}_{self.split_idx}_{self.env_idx}_{self.num_trajectories}' 522 | mem_idx = ( 523 | self.actor_idx, 524 | self.split_idx, 525 | self.env_idx, 526 | rollout_step, 527 | self.actor_env_step, 528 | first_rollout 529 | ) 530 | traj_dict = dict( 531 | t_id=t_id, length=rollout_step, env_steps=self.rollout_env_steps, mem_idx=mem_idx, 532 | actor_idx=self.actor_idx, split_idx=self.split_idx, env_idx=self.env_idx 533 | ) 534 | 535 | self.num_trajectories += 1 536 | self.rollout_env_steps = 0 537 | 538 | return traj_dict 539 | 540 | def episodic_stats(self): 541 | stats = dict(reward=self.last_episode_reward, len=self.last_episode_duration) 542 | 543 | stats['episode_extra_stats'] = self.last_episode_extra_stats 544 | 545 | report = dict(episodic=stats) 546 | self.last_episode_reward = self.last_episode_duration = self.last_episode_true_reward = 0 547 | self.last_episode_extra_stats = dict() 548 | return report 549 | 550 | def update_rnn_state(self, done): 551 | """If we encountered an episode boundary, reset rnn states to their default values.""" 552 | if done: 553 | self.last_rnn_state.fill_(0.0) 554 | --------------------------------------------------------------------------------