├── brain_agent
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── agents
│   │   │   ├── __init__.py
│   │   │   ├── agent_utils.py
│   │   │   ├── agent_abc.py
│   │   │   └── dmlab_multitask_agent.py
│   │   ├── algos
│   │   │   ├── __init__.py
│   │   │   ├── popart.py
│   │   │   ├── vtrace.py
│   │   │   └── aux_future_predict.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── model_abc.py
│   │   │   ├── model_utils.py
│   │   │   ├── action_distributions.py
│   │   │   ├── causal_transformer.py
│   │   │   ├── resnet.py
│   │   │   ├── rnn.py
│   │   │   └── transformer.py
│   │   ├── core_utils.py
│   │   ├── shared_buffer.py
│   │   ├── policy_worker.py
│   │   └── actor_worker.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── dmlab
│   │   │   ├── __init__.py
│   │   │   ├── dmlab_env.py
│   │   │   ├── dmlab_wrappers.py
│   │   │   ├── dmlab_model.py
│   │   │   ├── dmlab_level_cache.py
│   │   │   ├── dmlab_gym.py
│   │   │   └── dmlab30.py
│   │   └── env_utils.py
│   └── utils
│       ├── __init__.py
│       ├── cfg.py
│       ├── logger.py
│       ├── timing.py
│       ├── utils.py
│       └── dist_utils.py
├── requirements.txt
├── assets
│   ├── learning_curve.png
│   └── system_overview.png
├── configs
│   ├── trxl_baseline_train.yaml
│   ├── rnn_baseline_train.yaml
│   ├── lstm_baseline_train.yaml
│   ├── trxl_future_pred_train.yaml
│   ├── lstm_baseline_eval.yaml
│   ├── trxl_baseline_eval.yaml
│   ├── trxl_recon_train.yaml
│   ├── trxl_recon_eval.yaml
│   ├── trxl_future_pred_eval.yaml
│   └── default.yaml
├── LICENSE
├── .gitignore
├── dist_launch.py
├── eval.py
├── train.py
└── README.md
/brain_agent/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/envs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/core/agents/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/core/algos/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/core/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/brain_agent/envs/dmlab/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | faster_fifo
2 | tensorboardX
3 | threadpoolctl
4 | colorlog
5 | gym
6 | omegaconf
7 |
--------------------------------------------------------------------------------
/assets/learning_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/brain-agent/HEAD/assets/learning_curve.png
--------------------------------------------------------------------------------
/assets/system_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kakaobrain/brain-agent/HEAD/assets/system_overview.png
--------------------------------------------------------------------------------
/brain_agent/envs/env_utils.py:
--------------------------------------------------------------------------------
1 | from brain_agent.envs.dmlab.dmlab_env import make_dmlab_env
2 |
3 | def create_env(cfg=None, env_config=None):
4 | if 'dmlab' in cfg.env.name:
5 | env = make_dmlab_env(cfg, env_config)
6 | else:
7 | raise NotImplementedError
8 | return env
9 |
--------------------------------------------------------------------------------
/configs/trxl_baseline_train.yaml:
--------------------------------------------------------------------------------
1 | train_dir: ???
2 | experiment: ???
3 |
4 | model:
5 | agent: dmlab_multitask_agent
6 | encoder:
7 | encoder_subtype: resnet_impala
8 | core:
9 | core_type: trxl
10 | use_half_policy_worker: True
11 |
12 | optim:
13 | type: adam
14 | learning_rate: 0.0001
15 | batch_size: 1536
16 | rollout: 96
17 | max_grad_norm: 2.5
18 |
19 |
--------------------------------------------------------------------------------
/brain_agent/core/agents/agent_utils.py:
--------------------------------------------------------------------------------
1 | from brain_agent.core.agents.dmlab_multitask_agent import DMLabMultiTaskAgent
2 |
3 | def create_agent(cfg, action_space, obs_space, num_levels=1, need_half=False):
4 | if cfg.model.agent == 'dmlab_multitask_agent':
5 | agent = DMLabMultiTaskAgent(cfg, action_space, obs_space, num_levels, need_half)
6 | else:
7 | raise NotImplementedError
8 | return agent
9 |
--------------------------------------------------------------------------------
/configs/rnn_baseline_train.yaml:
--------------------------------------------------------------------------------
1 | train_dir: ???
2 | experiment: ???
3 |
4 | model:
5 | agent: dmlab_multitask_agent
6 | core:
7 | core_type: rnn
8 | n_rnn_layer: 3
9 | core_init: orthogonal
10 |
11 | optim:
12 | type: adam
13 | learning_rate: 0.0002
14 | batch_size: 1536
15 | rollout: 96
16 | max_grad_norm: 2.5
17 |
18 | learner:
19 | use_ppo: True
20 | use_adv_normalization: True
21 |
--------------------------------------------------------------------------------
/configs/lstm_baseline_train.yaml:
--------------------------------------------------------------------------------
1 | train_dir: ???
2 | experiment: ???
3 |
4 | model:
5 | agent: dmlab_multitask_agent
6 | core:
7 | core_type: rnn
8 | n_rnn_layer: 3
9 | core_init: tensorflow_default
10 |
11 | optim:
12 | type: adam
13 | learning_rate: 0.0002
14 | batch_size: 576
15 | rollout: 96
16 | max_grad_norm: 2.5
17 | warmup_optimizer: 0
18 |
19 | learner:
20 | use_ppo: False
21 | use_adv_normalization: False
22 | exploration_loss_coeff: 0.003
23 |
--------------------------------------------------------------------------------
/configs/trxl_future_pred_train.yaml:
--------------------------------------------------------------------------------
1 | train_dir: ???
2 | experiment: ???
3 |
4 | model:
5 | agent: dmlab_multitask_agent
6 | encoder:
7 | encoder_subtype: resnet_impala_large
8 | encoder_pooling: stride
9 | core:
10 | core_type: trxl
11 | n_layer: 6
12 | hidden_size: 512
13 |
14 |
15 | learner:
16 | exploration_loss_coeff: 0.003
17 | psychlab_gamma: 0.7
18 | use_decoder: False
19 | use_aux_future_pred_loss : True
20 |
21 | env:
22 | action_set: extended_action_set_large
23 |
--------------------------------------------------------------------------------
/configs/lstm_baseline_eval.yaml:
--------------------------------------------------------------------------------
1 | train_dir: ???
2 | experiment: ???
3 |
4 | test:
5 | is_test: True
6 | checkpoint: ???
7 | test_num_episodes: 100
8 |
9 | model:
10 | agent: dmlab_multitask_agent
11 | core:
12 | core_type: rnn
13 | n_rnn_layer: 3
14 | core_init: tensorflow_default
15 | use_half_policy_worker: False
16 |
17 | actor:
18 | num_workers: 30
19 | num_envs_per_worker: 10
20 | num_splits: 2
21 |
22 | env:
23 | name: dmlab_30_test
24 | use_level_cache: True
25 | decorrelate_envs_on_one_worker: False
26 | decorrelate_experience_max_seconds: 0
27 | one_task_per_worker: True
--------------------------------------------------------------------------------
/configs/trxl_baseline_eval.yaml:
--------------------------------------------------------------------------------
1 | train_dir: ???
2 | experiment: ???
3 |
4 | test:
5 | is_test: True
6 | checkpoint: ???
7 | test_num_episodes: 100
8 |
9 | model:
10 | agent: dmlab_multitask_agent
11 | encoder:
12 | encoder_subtype: resnet_impala
13 | core:
14 | core_type: trxl
15 | use_half_policy_worker: False
16 |
17 | actor:
18 | num_workers: 30
19 | num_envs_per_worker: 10
20 | num_splits: 2
21 |
22 | env:
23 | name: dmlab_30_test
24 | use_level_cache: True
25 | decorrelate_envs_on_one_worker: False
26 | decorrelate_experience_max_seconds: 0
27 | one_task_per_worker: True
28 |
29 |
--------------------------------------------------------------------------------
/configs/trxl_recon_train.yaml:
--------------------------------------------------------------------------------
1 | train_dir: ???
2 | experiment: ???
3 |
4 | model:
5 | agent: dmlab_multitask_agent
6 | encoder:
7 | encoder_subtype: resnet_impala_large
8 | encoder_pooling: stride
9 | core:
10 | core_type: trxl
11 | n_layer: 6
12 | hidden_size: 512
13 | use_half_policy_worker: True
14 |
15 | optim:
16 | type: adam
17 | learning_rate: 0.0001
18 | batch_size: 1536
19 | rollout: 96
20 | max_grad_norm: 2.5
21 |
22 | learner:
23 | exploration_loss_coeff: 0.003
24 | psychlab_gamma: 0.7
25 | use_decoder: True
26 | reconstruction_loss_coeff: 0.01
27 |
28 | env:
29 | action_set: extended_action_set_large
30 |
--------------------------------------------------------------------------------
/brain_agent/core/agents/agent_abc.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from brain_agent.core.models.action_distributions import DiscreteActionsParameterization
3 |
4 | class ActorCriticBase(nn.Module):
5 | def __init__(self, action_space, cfg):
6 | super().__init__()
7 | self.cfg = cfg
8 | self.action_space = action_space
9 |
10 | def get_action_parameterization(self, core_output_size):
11 | action_parameterization = DiscreteActionsParameterization(self.cfg, core_output_size, self.action_space)
12 | return action_parameterization
13 |
14 | def model_to_device(self, device):
15 | self.to(device)
16 |
17 | def device_and_type_for_input_tensor(self, input_tensor_name):
18 | return self.encoder.device_and_type_for_input_tensor(input_tensor_name)
19 |
--------------------------------------------------------------------------------
/configs/trxl_recon_eval.yaml:
--------------------------------------------------------------------------------
1 | train_dir: ???
2 | experiment: ???
3 |
4 | test:
5 | is_test: True
6 | checkpoint: ???
7 | test_num_episodes: 100
8 |
9 | model:
10 | agent: dmlab_multitask_agent
11 | encoder:
12 | encoder_subtype: resnet_impala_large
13 | encoder_pooling: stride
14 | core:
15 | core_type: trxl
16 | n_layer: 6
17 | hidden_size: 512
18 | use_half_policy_worker: False
19 |
20 | learner:
21 | use_decoder: True
22 |
23 | actor:
24 | num_workers: 30
25 | num_envs_per_worker: 10
26 | num_splits: 2
27 |
28 | env:
29 | name: dmlab_30_test
30 | use_level_cache: True
31 | decorrelate_envs_on_one_worker: False
32 | decorrelate_experience_max_seconds: 0
33 | one_task_per_worker: True
34 | action_set: extended_action_set_large
35 |
--------------------------------------------------------------------------------
/configs/trxl_future_pred_eval.yaml:
--------------------------------------------------------------------------------
1 | train_dir: ???
2 | experiment: ???
3 |
4 | test:
5 | is_test: True
6 | checkpoint: ???
7 | test_num_episodes: 100
8 |
9 | model:
10 | agent: dmlab_multitask_agent
11 | encoder:
12 | encoder_subtype: resnet_impala_large
13 | encoder_pooling: stride
14 | core:
15 | core_type: trxl
16 | n_layer: 6
17 | hidden_size: 512
18 | use_half_policy_worker: False
19 |
20 | learner:
21 | use_aux_future_pred_loss : True
22 |
23 | actor:
24 | num_workers: 30
25 | num_envs_per_worker: 10
26 | num_splits: 2
27 |
28 | env:
29 | name: dmlab_30_test
30 | use_level_cache: True
31 | decorrelate_envs_on_one_worker: False
32 | decorrelate_experience_max_seconds: 0
33 | one_task_per_worker: True
34 | action_set: extended_action_set_large
35 |
--------------------------------------------------------------------------------
/brain_agent/utils/cfg.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | from brain_agent.utils.utils import AttrDict
3 |
4 | class Configs(OmegaConf):
5 | @classmethod
6 | def get_defaults(cls):
7 | cfg = cls.load('configs/default.yaml')
8 | cls.set_struct(cfg, True)
9 | return cfg
10 |
11 | @classmethod
12 | def override_from_file_name(cls, cfg):
13 | c = cls.override_from_cli(cfg)
14 | if not cls.is_missing(c, 'cfg'):
15 | c = cls.load(c.cfg)
16 | cfg = cls.merge(cfg, c)
17 | return cfg
18 |
19 | @classmethod
20 | def override_from_cli(cls, cfg):
21 | c = cls.from_cli()
22 | cfg = cls.merge(cfg, c)
23 | return cfg
24 |
25 | @classmethod
26 | def to_attr_dict(cls, cfg):
27 | c = cls.to_container(cfg)
28 | c = AttrDict.from_nested_dicts(c)
29 | return c
30 |
31 |
--------------------------------------------------------------------------------
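A minimal sketch of how the Configs helper above can be driven (assuming the script runs from the repository root so configs/default.yaml resolves, and that default.yaml defines the keys being overridden; the exact call order used by train.py is not shown in this section):

# Hypothetical driver for brain_agent.utils.cfg.Configs, e.g. invoked as:
#   python demo.py cfg=configs/trxl_baseline_train.yaml train_dir=/tmp/runs experiment=demo
from brain_agent.utils.cfg import Configs

cfg = Configs.get_defaults()                # configs/default.yaml, struct mode
cfg = Configs.override_from_file_name(cfg)  # merge the YAML passed as cfg=... (if any)
cfg = Configs.override_from_cli(cfg)        # remaining key=value dotlist overrides win
cfg = Configs.to_attr_dict(cfg)             # plain nested AttrDict for attribute access
print(cfg.train_dir, cfg.experiment)
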
/brain_agent/core/algos/popart.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def update_mu_sigma(nu, mu, vs, task_ids, popart_clip_min, clamp_max, beta):
4 | oldnu = nu.clone()
5 | oldsigma = torch.sqrt(oldnu - mu ** 2)
6 | oldsigma[torch.isnan(oldsigma)] = popart_clip_min
7 | oldsigma = torch.clamp(oldsigma, min=popart_clip_min, max=clamp_max)
8 | oldmu = mu.clone()
9 |
10 | for i in range(len(task_ids)):
11 | task_id = task_ids[i]
12 | v = torch.mean(vs[i])
13 |
14 | mu[task_id] = (1 - beta) * mu[task_id] + beta * v
15 | nu[task_id] = (1 - beta) * nu[task_id] + beta * (v ** 2)
16 |
17 | sigma = torch.sqrt(nu - mu ** 2)
18 | sigma[torch.isnan(sigma)] = popart_clip_min
19 | sigma = torch.clamp(sigma, min=popart_clip_min, max=clamp_max)
20 |
21 | return mu, nu, sigma, oldmu, oldsigma
22 |
23 | def update_parameters(weight, bias, mu, sigma, oldmu, oldsigma):
24 | new_weight = (weight.t() * oldsigma / sigma).t()
25 | new_bias = (oldsigma * bias + oldmu - mu) / sigma
26 | return new_weight, new_bias
--------------------------------------------------------------------------------
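A toy, self-contained sketch of how these PopArt helpers might be exercised; the shapes, hyperparameters and value-head rescaling step are hypothetical (the real learner code is not part of this section):

# Hypothetical PopArt sketch: update per-task return statistics, then rescale
# a linear value head so its normalized outputs stay consistent.
import torch
from brain_agent.core.algos.popart import update_mu_sigma, update_parameters

num_tasks = 4
mu = torch.zeros(num_tasks)                  # running mean of returns per task
nu = torch.ones(num_tasks)                   # running second moment per task
value_head = torch.nn.Linear(16, num_tasks)

task_ids = torch.tensor([0, 2])              # tasks present in this batch
vs = torch.randn(2, 8) * 5.0                 # toy value targets, one row per task

mu, nu, sigma, oldmu, oldsigma = update_mu_sigma(
    nu, mu, vs, task_ids, popart_clip_min=1e-4, clamp_max=1e6, beta=3e-4)

with torch.no_grad():
    new_w, new_b = update_parameters(value_head.weight, value_head.bias,
                                     mu, sigma, oldmu, oldsigma)
    value_head.weight.copy_(new_w)
    value_head.bias.copy_(new_b)
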
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Kakao Brain Corp.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/brain_agent/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from colorlog import ColoredFormatter
3 |
4 | log = logging.getLogger('rl')
5 | log.setLevel(logging.DEBUG)
6 | log.handlers = []
7 | log.propagate = False
8 | log_level = logging.DEBUG
9 |
10 | stream_handler = logging.StreamHandler()
11 | stream_handler.setLevel(log_level)
12 |
13 | stream_formatter = ColoredFormatter(
14 | '%(log_color)s[%(asctime)s][%(process)05d] %(message)s',
15 | datefmt=None,
16 | reset=True,
17 | log_colors={
18 | 'DEBUG': 'cyan',
19 | 'INFO': 'white,bold',
20 | 'INFOV': 'cyan,bold',
21 | 'WARNING': 'yellow',
22 | 'ERROR': 'red,bold',
23 | 'CRITICAL': 'red,bg_white',
24 | },
25 | secondary_log_colors={},
26 | style='%'
27 | )
28 | stream_handler.setFormatter(stream_formatter)
29 | log.addHandler(stream_handler)
30 |
31 | def init_logger(log_level='debug', file_path=None):
32 | log.setLevel(logging.getLevelName(str.upper(log_level)))
33 | if file_path is not None:
34 | file_handler = logging.FileHandler(file_path)
35 | file_formatter = logging.Formatter(fmt='[%(asctime)s][%(process)05d] %(message)s', datefmt=None, style='%')
36 | file_handler.setFormatter(file_formatter)
37 | log.addHandler(file_handler)
38 | for h in log.handlers:
39 | h.setLevel(logging.getLevelName(str.upper(log_level)))
40 |
41 |
42 |
--------------------------------------------------------------------------------
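A short sketch of how this shared logger is typically initialised from an entry point (the level and file path below are hypothetical):

# Hypothetical usage of the shared 'rl' logger defined above.
from brain_agent.utils.logger import log, init_logger

init_logger(log_level='info', file_path=None)  # or a path such as '/tmp/demo-log.txt'
log.info('policy worker %d started', 0)
log.debug('filtered out at the info level')
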
/brain_agent/core/models/model_abc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from brain_agent.core.models.model_utils import nonlinearity
4 |
5 | class ActionsParameterizationBase(nn.Module):
6 | def __init__(self, cfg, action_space):
7 | super().__init__()
8 | self.cfg = cfg
9 | self.action_space = action_space
10 |
11 | class EncoderBase(nn.Module):
12 | def __init__(self, cfg):
13 | super().__init__()
14 |
15 | self.cfg = cfg
16 |
17 | self.fc_after_enc = None
18 | self.encoder_out_size = -1 # to be initialized in the constructor of the derived class
19 |
20 | def get_encoder_out_size(self):
21 | return self.encoder_out_size
22 |
23 | def init_fc_blocks(self, input_size):
24 | layers = []
25 | fc_layer_size = self.cfg.model.core.hidden_size
26 |
27 | for i in range(self.cfg.model.encoder.encoder_extra_fc_layers):
28 | size = input_size if i == 0 else fc_layer_size
29 |
30 | layers.extend([
31 | nn.Linear(size, fc_layer_size),
32 | nonlinearity(self.cfg),
33 | ])
34 |
35 | if len(layers) > 0:
36 | self.fc_after_enc = nn.Sequential(*layers)
37 | self.encoder_out_size = fc_layer_size
38 | else:
39 | self.encoder_out_size = input_size
40 |
41 | def model_to_device(self, device):
42 | self.to(device)
43 |
44 | def device_and_type_for_input_tensor(self, _):
45 | return self.model_device(), torch.float32
46 |
47 | def model_device(self):
48 | return next(self.parameters()).device
49 |
50 | def forward_fc_blocks(self, x):
51 | if self.fc_after_enc is not None:
52 | x = self.fc_after_enc(x)
53 |
54 | return x
--------------------------------------------------------------------------------
/brain_agent/core/algos/vtrace.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def calculate_vtrace(values, rewards, dones, vtrace_rho, vtrace_c,
5 | num_trajectories, recurrence, gamma, exclude_last=False):
6 | values_cpu = values.cpu()
7 | rewards_cpu = rewards.cpu()
8 | dones_cpu = dones.cpu()
9 | vtrace_rho_cpu = vtrace_rho.cpu()
10 | vtrace_c_cpu = vtrace_c.cpu()
11 |
12 | vs = torch.zeros((num_trajectories * recurrence))
13 | adv = torch.zeros((num_trajectories * recurrence))
14 |
15 | bootstrap_values = values_cpu[recurrence - 1::recurrence]
16 | values_BT = values_cpu.view(-1, recurrence)
17 | next_values = torch.cat([values_BT[:, 1:], bootstrap_values.view(-1, 1)], dim=1).view(-1)
18 | next_vs = next_values[recurrence - 1::recurrence]
19 |
20 | masked_gammas = (1.0 - dones_cpu) * gamma
21 |
22 | if exclude_last:
23 | rollout_recurrence = recurrence - 1
24 | adv[recurrence - 1::recurrence] = rewards_cpu[recurrence - 1::recurrence] + (masked_gammas[recurrence - 1::recurrence] - 1) * next_vs
25 | vs[recurrence - 1::recurrence] = next_vs * vtrace_rho_cpu[recurrence - 1::recurrence] * adv[recurrence - 1::recurrence]
26 | else:
27 | rollout_recurrence = recurrence
28 |
29 | for i in reversed(range(rollout_recurrence)):
30 | rewards = rewards_cpu[i::recurrence]
31 | not_done_times_gamma = masked_gammas[i::recurrence]
32 |
33 | curr_values = values_cpu[i::recurrence]
34 | curr_next_values = next_values[i::recurrence]
35 | curr_vtrace_rho = vtrace_rho_cpu[i::recurrence]
36 | curr_vtrace_c = vtrace_c_cpu[i::recurrence]
37 |
38 | delta_s = curr_vtrace_rho * (rewards + not_done_times_gamma * curr_next_values - curr_values)
39 | adv[i::recurrence] = rewards + not_done_times_gamma * next_vs - curr_values
40 | next_vs = curr_values + delta_s + not_done_times_gamma * curr_vtrace_c * (next_vs - curr_next_values)
41 | vs[i::recurrence] = next_vs
42 |
43 | return vs, adv
44 |
--------------------------------------------------------------------------------
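A toy invocation sketch for calculate_vtrace: two made-up trajectories of four steps each, flattened trajectory-major as the strided slicing above implies; all numbers are hypothetical:

# Hypothetical toy call: 2 trajectories x 4 steps, every tensor flattened to length 8.
import torch
from brain_agent.core.algos.vtrace import calculate_vtrace

num_trajectories, recurrence = 2, 4
n = num_trajectories * recurrence

values = torch.randn(n)    # value predictions V(x_t)
rewards = torch.randn(n)
dones = torch.zeros(n)     # no episode boundaries in this toy batch
rhos = torch.ones(n)       # clipped importance weights rho_t
cs = torch.ones(n)         # clipped importance weights c_t

vs, adv = calculate_vtrace(values, rewards, dones, rhos, cs,
                           num_trajectories, recurrence, gamma=0.99)
print(vs.shape, adv.shape)  # both torch.Size([8])
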
/brain_agent/utils/timing.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import deque
3 |
4 | from brain_agent.utils.utils import AttrDict
5 |
6 |
7 | class AvgTime:
8 | def __init__(self, num_values_to_avg):
9 | self.values = deque([], maxlen=num_values_to_avg)
10 |
11 | def __str__(self):
12 | avg_time = sum(self.values) / max(1, len(self.values))
13 | return f'{avg_time:.4f}'
14 |
15 |
16 | class TimingContext:
17 | def __init__(self, timer, key, additive=False, average=None):
18 | self._timer = timer
19 | self._key = key
20 | self._additive = additive
21 | self._average = average
22 | self._time_enter = None
23 |
24 | def __enter__(self):
25 | self._time_enter = time.time()
26 |
27 | def __exit__(self, type_, value, traceback):
28 | if self._key not in self._timer:
29 | if self._average is not None:
30 | self._timer[self._key] = AvgTime(num_values_to_avg=self._average)
31 | else:
32 | self._timer[self._key] = 0
33 |
34 | time_passed = max(time.time() - self._time_enter, 1e-8) # EPS to prevent div by zero
35 |
36 | if self._additive:
37 | self._timer[self._key] += time_passed
38 | elif self._average is not None:
39 | self._timer[self._key].values.append(time_passed)
40 | else:
41 | self._timer[self._key] = time_passed
42 |
43 |
44 | class Timing(AttrDict):
45 | def timeit(self, key):
46 | return TimingContext(self, key)
47 |
48 | def add_time(self, key):
49 | return TimingContext(self, key, additive=True)
50 |
51 | def time_avg(self, key, average=10):
52 | return TimingContext(self, key, average=average)
53 |
54 | def __str__(self):
55 | s = ''
56 | i = 0
57 | for key, value in self.items():
58 | str_value = f'{value:.4f}' if isinstance(value, float) else str(value)
59 | s += f'{key}: {str_value}'
60 | if i < len(self) - 1:
61 | s += ', '
62 | i += 1
63 | return s
--------------------------------------------------------------------------------
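A short sketch of the three Timing modes above (last value, additive, moving average); the keys and sleep durations are hypothetical:

# Hypothetical usage of the Timing helper defined above.
import time
from brain_agent.utils.timing import Timing

timing = Timing()

with timing.timeit('rollout'):           # stores the last elapsed time under 'rollout'
    time.sleep(0.01)

for _ in range(3):
    with timing.add_time('env_step'):    # accumulates across iterations
        time.sleep(0.002)
    with timing.time_avg('forward'):     # moving average over the last 10 values
        time.sleep(0.001)

print(timing)  # e.g. rollout: 0.0101, env_step: 0.0062, forward: 0.0011
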
/brain_agent/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 | from brain_agent.utils.logger import log
4 | import glob
5 |
6 | def dict_of_list_put(d, k, v, max_len=100):
7 | if d.get(k) is None:
8 | d[k] = []
9 | d[k].append(v)
10 | if len(d[k]) > max_len:
11 | d[k].pop(0)
12 |
13 | def list_of_dicts_to_dict_of_lists(list_of_dicts):
14 | dict_of_lists = dict()
15 |
16 | for d in list_of_dicts:
17 | for key, x in d.items():
18 | if key not in dict_of_lists:
19 | dict_of_lists[key] = []
20 |
21 | dict_of_lists[key].append(x)
22 |
23 | return dict_of_lists
24 |
25 | def get_checkpoint_dir(cfg):
26 | checkpoint_dir = os.path.join(get_experiment_dir(cfg=cfg), f'checkpoint')
27 | os.makedirs(checkpoint_dir, exist_ok=True)
28 | return checkpoint_dir
29 |
30 |
31 | def get_checkpoints(checkpoints_dir):
32 | checkpoints = glob.glob(os.path.join(checkpoints_dir, 'checkpoint_*'))
33 | return sorted(checkpoints)
34 |
35 |
36 | def get_experiment_dir(cfg):
37 | exp_dir = os.path.join(cfg.train_dir, cfg.experiment)
38 | os.makedirs(exp_dir, exist_ok=True)
39 | return exp_dir
40 |
41 | def get_log_path(cfg):
42 | exp_dir = os.path.join(cfg.train_dir, cfg.experiment)
43 | os.makedirs(exp_dir, exist_ok=True)
44 | log_dir = os.path.join(cfg.train_dir, cfg.experiment, 'logs')
45 | os.makedirs(log_dir, exist_ok=True)
46 | date = datetime.now().strftime("%Y%m%d_%I%M%S%P")
47 | return os.path.join(log_dir, f'log-r{cfg.dist.world_rank:02d}-{date}.txt')
48 |
49 | def get_summary_dir(cfg, postfix=None):
50 | summary_dir = os.path.join(cfg.train_dir, cfg.experiment, 'summary')
51 | os.makedirs(summary_dir, exist_ok=True)
52 | if postfix is not None:
53 | summary_dir = os.path.join(summary_dir, postfix)
54 | os.makedirs(summary_dir, exist_ok=True)
55 | return summary_dir
56 |
57 | class AttrDict(dict):
58 | __setattr__ = dict.__setitem__
59 |
60 | def __getattribute__(self, item):
61 | if item in self:
62 | return self[item]
63 | else:
64 | return super().__getattribute__(item)
65 |
66 | @classmethod
67 | def from_nested_dicts(cls, data):
68 | if not isinstance(data, dict):
69 | return data
70 | else:
71 | return cls({key: cls.from_nested_dicts(data[key]) for key in data})
72 |
--------------------------------------------------------------------------------
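A small sketch exercising AttrDict and the experiment-directory helpers above (the train_dir and experiment values are hypothetical):

# Hypothetical usage of AttrDict and the directory helpers above.
from brain_agent.utils.utils import AttrDict, get_experiment_dir, get_checkpoint_dir

cfg = AttrDict.from_nested_dicts({
    'train_dir': '/tmp/brain_agent_runs',
    'experiment': 'demo_experiment',
})

print(cfg.experiment)           # attribute-style access into the nested dict
print(get_experiment_dir(cfg))  # /tmp/brain_agent_runs/demo_experiment (created if absent)
print(get_checkpoint_dir(cfg))  # .../demo_experiment/checkpoint
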
/brain_agent/envs/dmlab/dmlab_env.py:
--------------------------------------------------------------------------------
1 | import os
2 | from brain_agent.envs.dmlab.dmlab30 import DMLAB_LEVELS_BY_ENVNAME
3 | from brain_agent.envs.dmlab.dmlab_gym import DmlabGymEnv
4 | from brain_agent.envs.dmlab.dmlab_level_cache import dmlab_ensure_global_cache_initialized
5 | from brain_agent.envs.dmlab.dmlab_wrappers import PixelFormatChwWrapper, EpisodicStatWrapper, RewardShapingWrapper
6 | from brain_agent.utils.utils import get_experiment_dir
7 | from brain_agent.utils.logger import log
8 |
9 | DMLAB_INITIALIZED = False
10 |
11 | def get_task_id(env_config, levels, cfg):
12 | if env_config is None:
13 | return 0
14 |
15 | num_envs = len(levels)
16 |
17 | if cfg.env.one_task_per_worker:
18 | return env_config['worker_index'] % num_envs
19 | else:
20 | return env_config['env_id'] % num_envs
21 |
22 |
23 | def make_dmlab_env_impl(levels, cfg, env_config, extra_cfg=None):
24 | skip_frames = cfg.env.frameskip
25 |
26 | task_id = get_task_id(env_config, levels, cfg)
27 | level = levels[task_id]
28 | log.debug('%r level %s task id %d', env_config, level, task_id)
29 |
30 | env = DmlabGymEnv(
31 | task_id, level, skip_frames, cfg.env.res_w, cfg.env.res_h,
32 | cfg.env.dataset_path, cfg.env.action_set,
33 | cfg.env.use_level_cache, cfg.env.level_cache_path, extra_cfg,
34 | )
35 | all_levels = []
36 | for l in levels:
37 | all_levels.append(l.replace('contributed/dmlab30/', ''))
38 |
39 | env.level_info = dict(
40 | num_levels=len(levels),
41 | all_levels=all_levels
42 | )
43 |
44 | env = PixelFormatChwWrapper(env)
45 |
46 | env = EpisodicStatWrapper(env)
47 |
48 | env = RewardShapingWrapper(env)
49 |
50 | return env
51 |
52 |
53 | def make_dmlab_env(cfg, env_config=None):
54 | levels = DMLAB_LEVELS_BY_ENVNAME[cfg.env.name]
55 | extra_cfg = None
56 | if cfg.test.is_test and 'test' in cfg.env.name:
57 | extra_cfg = dict(allowHoldOutLevels='true')
58 | ensure_initialized(cfg, levels)
59 | return make_dmlab_env_impl(levels, cfg, env_config, extra_cfg=extra_cfg)
60 |
61 |
62 | def ensure_initialized(cfg, levels):
63 | global DMLAB_INITIALIZED
64 | if DMLAB_INITIALIZED:
65 | return
66 |
67 | level_cache_dir = cfg.env.level_cache_path
68 | os.makedirs(level_cache_dir, exist_ok=True)
69 |
70 | dmlab_ensure_global_cache_initialized(get_experiment_dir(cfg=cfg), levels, level_cache_dir)
71 | DMLAB_INITIALIZED = True
72 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # additional
132 | .idea
--------------------------------------------------------------------------------
/brain_agent/core/models/model_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from torch import nn
4 | from brain_agent.utils.utils import AttrDict
5 |
6 | EPS = 1e-8
7 |
8 | def to_scalar(value):
9 | if isinstance(value, torch.Tensor):
10 | return value.item()
11 | else:
12 | return value
13 |
14 | def get_hidden_size(cfg, action_space):
15 | if cfg.model.core.core_type == 'trxl':
16 | size = cfg.model.core.hidden_size * (cfg.model.core.n_layer + 1)
17 | size += 64 * (cfg.model.core.n_layer + 1)
18 | if cfg.model.extended_input:
19 | size += (action_space.n + 1) * (cfg.model.core.n_layer + 1)
20 | elif cfg.model.core.core_type == 'rnn':
21 | size = cfg.model.core.hidden_size * cfg.model.core.n_rnn_layer * 2
22 | else:
23 | raise NotImplementedError
24 | return size
25 |
26 |
27 | def nonlinearity(cfg):
28 | if cfg.model.encoder.nonlinearity == 'elu':
29 | return nn.ELU(inplace=cfg.model.encoder.nonlinear_inplace)
30 | elif cfg.model.encoder.nonlinearity == 'relu':
31 | return nn.ReLU(inplace=cfg.model.encoder.nonlinear_inplace)
32 | elif cfg.model.encoder.nonlinearity == 'tanh':
33 | return nn.Tanh()
34 | else:
35 | raise Exception('Unknown nonlinearity')
36 |
37 |
38 | def get_obs_shape(obs_space):
39 | obs_shape = AttrDict()
40 | if hasattr(obs_space, 'spaces'):
41 | for key, space in obs_space.spaces.items():
42 | obs_shape[key] = space.shape
43 | else:
44 | obs_shape.obs = obs_space.shape
45 |
46 | return obs_shape
47 |
48 |
49 | def calc_num_elements(module, module_input_shape):
50 | shape_with_batch_dim = (1,) + module_input_shape
51 | some_input = torch.rand(shape_with_batch_dim)
52 | num_elements = module(some_input).numel()
53 | return num_elements
54 |
55 | def normalize_obs_return(obs_dict, cfg, half=False):
56 | with torch.no_grad():
57 | mean = cfg.env.obs_subtract_mean
58 | scale = cfg.env.obs_scale
59 |
60 | normalized_obs_dict = copy.deepcopy(obs_dict)
61 |
62 | if normalized_obs_dict['obs'].dtype != torch.float:
63 | normalized_obs_dict['obs'] = normalized_obs_dict['obs'].float()
64 |
65 | if abs(mean) > EPS:
66 | normalized_obs_dict['obs'].sub_(mean)
67 |
68 | if abs(scale - 1.0) > EPS:
69 | normalized_obs_dict['obs'].mul_(1.0 / scale)
70 |
71 | if half:
72 | normalized_obs_dict['obs'] = normalized_obs_dict['obs'].half()
73 |
74 | return normalized_obs_dict
75 |
76 |
--------------------------------------------------------------------------------
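A toy sketch of calc_num_elements and normalize_obs_return; the cfg fields mirror the ones this module reads, filled with hypothetical values:

# Hypothetical sketch: measure a conv stack's flattened output size and
# normalize a batch of uint8 image observations.
import torch
from torch import nn
from brain_agent.utils.utils import AttrDict
from brain_agent.core.models.model_utils import calc_num_elements, normalize_obs_return

conv = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2), nn.ReLU(), nn.Flatten())
print(calc_num_elements(conv, (3, 72, 96)))  # flattened feature size for a 3x72x96 input

cfg = AttrDict.from_nested_dicts({'env': {'obs_subtract_mean': 128.0, 'obs_scale': 255.0}})
obs = {'obs': torch.randint(0, 256, (4, 3, 72, 96), dtype=torch.uint8)}
normalized = normalize_obs_return(obs, cfg)
print(normalized['obs'].dtype, normalized['obs'].mean().item())
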
/brain_agent/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 |
4 | import torch
5 | from brain_agent.utils.logger import log
6 |
7 | DistEnv = collections.namedtuple('DistEnv', ['world_size', 'world_rank', 'local_rank', 'num_gpus', 'master'])
8 |
9 |
10 | def dist_init(cfg):
11 | if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) > 1:
12 | log.debug('[dist] Distributed: wait dist process group:%d', cfg.dist.local_rank)
13 | torch.distributed.init_process_group(backend=cfg.dist.dist_backend, init_method='env://',
14 | world_size=int(os.environ['WORLD_SIZE']))
15 | assert (int(os.environ['WORLD_SIZE']) == torch.distributed.get_world_size())
16 | log.debug('[dist] Distributed: success device:%d (%d/%d)',
17 | cfg.dist.local_rank, torch.distributed.get_rank(), torch.distributed.get_world_size())
18 | distenv = DistEnv(torch.distributed.get_world_size(), torch.distributed.get_rank(), cfg.dist.local_rank, 1, torch.distributed.get_rank() == 0)
19 | else:
20 | log.debug('[dist] Single processed')
21 | distenv = DistEnv(1, 0, 0, torch.cuda.device_count(), True)
22 | log.debug('[dist] %s', distenv)
23 | return distenv
24 |
25 |
26 | def dist_all_reduce_gradient(model):
27 | torch.distributed.barrier()
28 | world_size = float(torch.distributed.get_world_size())
29 | for p in model.parameters():
30 | if type(p.grad) is not type(None):
31 | torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM)
32 | p.grad.data /= world_size
33 |
34 |
35 | def dist_reduce_gradient(model, grads=None):
36 | torch.distributed.barrier()
37 | world_size = float(torch.distributed.get_world_size())
38 | if grads is None:
39 | for p in model.parameters():
40 | if type(p.grad) is not type(None):
41 | torch.distributed.reduce(p.grad.data, 0, op=torch.distributed.ReduceOp.SUM)
42 | p.grad.data /= world_size
43 | else:
44 | for grad in grads:
45 | if type(grad) is not type(None):
46 | torch.distributed.reduce(grad.data, 0, op=torch.distributed.ReduceOp.SUM)
47 | grad.data /= world_size
48 |
49 |
50 | def dist_all_reduce_buffers(model):
51 | torch.distributed.barrier()
52 | world_size = float(torch.distributed.get_world_size())
53 | for n, b in model.named_buffers():
54 | torch.distributed.all_reduce(b.data, op=torch.distributed.ReduceOp.SUM)
55 | b.data /= world_size
56 |
57 |
58 | def dist_broadcast_model(model):
59 | torch.distributed.barrier()
60 | for _, param in model.state_dict().items():
61 | torch.distributed.broadcast(param, 0)
62 | torch.distributed.barrier()
63 | torch.cuda.synchronize()
--------------------------------------------------------------------------------
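A minimal single-process sketch of the gradient-averaging helper: a one-rank gloo group, so the all-reduce is effectively a no-op, but the call pattern matches multi-GPU training (backend, port and model below are hypothetical):

# Hypothetical single-process sketch for dist_all_reduce_gradient.
import torch
import torch.distributed as dist
from torch import nn
from brain_agent.utils.dist_utils import dist_all_reduce_gradient

dist.init_process_group(backend='gloo', init_method='tcp://127.0.0.1:29513',
                        world_size=1, rank=0)

model = nn.Linear(8, 2)
model(torch.randn(4, 8)).sum().backward()

dist_all_reduce_gradient(model)  # sum grads across ranks, then divide by world size
print(model.weight.grad.shape)

dist.destroy_process_group()
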
/brain_agent/core/models/action_distributions.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from brain_agent.utils.logger import log
6 | from brain_agent.core.models.model_abc import ActionsParameterizationBase
7 |
8 |
9 | class DiscreteActionsParameterization(ActionsParameterizationBase):
10 | def __init__(self, cfg, core_out_size, action_space):
11 | super().__init__(cfg, action_space)
12 |
13 | num_action_outputs = action_space.n
14 | self.distribution_linear = nn.Linear(core_out_size, num_action_outputs)
15 |
16 | def forward(self, actor_core_output):
17 | action_distribution_params = self.distribution_linear(actor_core_output)
18 | action_distribution = CategoricalActionDistribution(raw_logits=action_distribution_params)
19 | return action_distribution_params, action_distribution
20 |
21 |
22 | class CategoricalActionDistribution:
23 | def __init__(self, raw_logits):
24 | self.raw_logits = raw_logits
25 | self.log_p = self.p = None
26 |
27 | @property
28 | def probs(self):
29 | if self.p is None:
30 | self.p = F.softmax(self.raw_logits, dim=-1)
31 | return self.p
32 |
33 | @property
34 | def log_probs(self):
35 | if self.log_p is None:
36 | self.log_p = F.log_softmax(self.raw_logits, dim=-1)
37 | return self.log_p
38 |
39 | def sample_gumbel(self):
40 | sample = torch.argmax(self.raw_logits - torch.empty_like(self.raw_logits).exponential_().log_(), -1)
41 | return sample
42 |
43 | def sample(self):
44 | samples = torch.multinomial(self.probs, 1, True).squeeze(dim=-1)
45 | return samples
46 |
47 | def sample_max(self):
48 | samples = torch.argmax(self.probs, dim=-1)
49 | return samples
50 |
51 | def log_prob(self, value):
52 | value = value.long().unsqueeze(-1)
53 | log_probs = torch.gather(self.log_probs, -1, value).view(-1)
54 | return log_probs
55 |
56 | def entropy(self):
57 | p_log_p = self.log_probs * self.probs
58 | return -p_log_p.sum(-1)
59 |
60 | def _kl(self, other_log_probs):
61 | probs, log_probs = self.probs, self.log_probs
62 | kl = probs * (log_probs - other_log_probs)
63 | kl = kl.sum(dim=-1)
64 | return kl
65 |
66 | def _kl_inverse(self, other_log_probs):
67 | probs, log_probs = self.probs, self.log_probs
68 | kl = torch.exp(other_log_probs) * (other_log_probs - log_probs)
69 | kl = kl.sum(dim=-1)
70 | return kl
71 |
72 | def _kl_symmetric(self, other_log_probs):
73 | return 0.5 * (self._kl(other_log_probs) + self._kl_inverse(other_log_probs))
74 |
75 | def symmetric_kl_with_uniform_prior(self):
76 | probs, log_probs = self.probs, self.log_probs
77 | num_categories = log_probs.shape[-1]
78 | uniform_prob = 1 / num_categories
79 | log_uniform_prob = math.log(uniform_prob)
80 |
81 | return 0.5 * ((probs * (log_probs - log_uniform_prob)).sum(dim=-1)
82 | + (uniform_prob * (log_uniform_prob - log_probs)).sum(dim=-1))
83 |
84 | def kl_divergence(self, other):
85 | return self._kl(other.log_probs)
86 |
87 | def dbg_print(self):
88 | dbg_info = dict(
89 | entropy=self.entropy().mean(),
90 | min_logit=self.raw_logits.min(),
91 | max_logit=self.raw_logits.max(),
92 | min_prob=self.probs.min(),
93 | max_prob=self.probs.max(),
94 | )
95 |
96 | msg = ''
97 | for key, value in dbg_info.items():
98 | msg += f'{key}={value.cpu().item():.3f} '
99 | log.debug(msg)
100 |
101 |
102 | def sample_actions_log_probs(distribution):
103 | actions = distribution.sample()
104 | log_prob_actions = distribution.log_prob(actions)
105 | return actions, log_prob_actions
106 |
--------------------------------------------------------------------------------
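A toy sketch of the categorical distribution utilities above, as a policy head would use them (batch and action-space sizes are hypothetical):

# Hypothetical sketch: build a categorical action distribution from raw logits,
# sample actions, and query log-probs, entropy and KL.
import torch
from brain_agent.core.models.action_distributions import (
    CategoricalActionDistribution, sample_actions_log_probs)

logits = torch.randn(5, 15)            # batch of 5, 15 discrete actions
dist = CategoricalActionDistribution(raw_logits=logits)

actions, log_probs = sample_actions_log_probs(dist)
print(actions.shape, log_probs.shape)  # torch.Size([5]) torch.Size([5])
print(dist.entropy().mean().item())

uniform = CategoricalActionDistribution(torch.zeros(5, 15))
print(dist.kl_divergence(uniform).shape)  # KL to a uniform policy, per batch element
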
/brain_agent/envs/dmlab/dmlab_wrappers.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | from gym import spaces, ObservationWrapper
4 | from math import tanh
5 | from brain_agent.envs.dmlab.dmlab30 import RANDOM_SCORES, HUMAN_SCORES
6 |
7 |
8 | def has_image_observations(observation_space):
9 | return len(observation_space.shape) >= 2
10 |
11 | def compute_hns(r, h, s):
12 | return (s-r) / (h-r) * 100
13 |
14 | class PixelFormatChwWrapper(ObservationWrapper):
15 | def __init__(self, env):
16 | super().__init__(env)
17 |
18 | if isinstance(env.observation_space, gym.spaces.Dict):
19 | img_obs_space = env.observation_space['obs']
20 | self.dict_obs_space = True
21 | else:
22 | img_obs_space = env.observation_space
23 | self.dict_obs_space = False
24 |
25 | if not has_image_observations(img_obs_space):
26 | raise Exception('Pixel format wrapper only works with image-based envs')
27 |
28 | obs_shape = img_obs_space.shape
29 | max_num_img_channels = 4
30 |
31 | if len(obs_shape) <= 2:
32 | raise Exception('Env obs do not have channel dimension?')
33 |
34 | if obs_shape[0] <= max_num_img_channels:
35 | raise Exception('Env obs already in CHW format?')
36 |
37 | h, w, c = obs_shape
38 | low, high = img_obs_space.low.flat[0], img_obs_space.high.flat[0]
39 | new_shape = [c, h, w]
40 |
41 | if self.dict_obs_space:
42 | dtype = env.observation_space.spaces['obs'].dtype if env.observation_space.spaces['obs'].dtype is not None else np.float32
43 | else:
44 | dtype = env.observation_space.dtype if env.observation_space.dtype is not None else np.float32
45 |
46 | new_img_obs_space = spaces.Box(low, high, shape=new_shape, dtype=dtype)
47 |
48 | if self.dict_obs_space:
49 | self.observation_space = env.observation_space
50 | self.observation_space.spaces['obs'] = new_img_obs_space
51 | else:
52 | self.observation_space = new_img_obs_space
53 |
54 | self.action_space = env.action_space
55 |
56 | @staticmethod
57 | def _transpose(obs):
58 | return np.transpose(obs, (2, 0, 1))
59 |
60 | def observation(self, observation):
61 | if observation is None:
62 | return observation
63 |
64 | if self.dict_obs_space:
65 | observation['obs'] = self._transpose(observation['obs'])
66 | else:
67 | observation = self._transpose(observation)
68 | return observation
69 |
70 | class EpisodicStatWrapper(gym.Wrapper):
71 | def __init__(self, env):
72 | super().__init__(env)
73 | self.raw_episode_return = self.episode_return = self.episode_length = 0
74 |
75 | def reset(self):
76 | obs = self.env.reset()
77 | self.raw_episode_return = self.episode_return = self.episode_length = 0
78 | return obs
79 |
80 | def step(self, action):
81 | obs, rew, done, info = self.env.step(action)
82 | self.episode_return += rew
83 | self.raw_episode_return += info.get('raw_rew', rew)
84 | self.episode_length += info.get('num_frames', 1)
85 |
86 | if done:
87 | level_name = self.unwrapped.level_name
88 | hns = compute_hns(
89 | RANDOM_SCORES[level_name.replace('train', 'test')],
90 | HUMAN_SCORES[level_name.replace('train', 'test')],
91 | self.raw_episode_return
92 | )
93 |
94 | info['episodic_stats'] = {
95 | 'level_name': self.unwrapped.level_name,
96 | 'task_id': self.unwrapped.task_id,
97 | 'episode_return': self.episode_return,
98 | 'episode_length': self.episode_length,
99 | 'raw_episode_return': self.raw_episode_return,
100 | 'hns': hns,
101 | }
102 | self.episode_return = 0
103 | self.raw_episode_return = 0
104 | self.episode_length = 0
105 |
106 | return obs, rew, done, info
107 |
108 | class RewardShapingWrapper(gym.Wrapper):
109 | def __init__(self, env):
110 | super().__init__(env)
111 |
112 | def reset(self):
113 | obs = self.env.reset()
114 | return obs
115 |
116 | def step(self, action):
117 | obs, rew, done, info = self.env.step(action)
118 |
119 | info['raw_rew'] = rew
120 | squeezed = tanh(rew / 5.0)
121 | clipped = (1.5 * squeezed) if rew < 0.0 else (5.0 * squeezed)
122 | rew = clipped
123 |
124 | return obs, rew, done, info
125 |
--------------------------------------------------------------------------------
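A tiny numeric sketch of the human-normalised score and the reward squashing used above; the random/human scores are made-up numbers, not actual DMLab-30 reference values:

# Hypothetical numbers only: human-normalised score and reward squashing.
from math import tanh
from brain_agent.envs.dmlab.dmlab_wrappers import compute_hns

# hns = (score - random) / (human - random) * 100
print(compute_hns(5.0, 40.0, 22.5))  # 50.0: halfway between random (5) and human (40)

raw_rew = 3.0
squeezed = tanh(raw_rew / 5.0)
shaped = (1.5 * squeezed) if raw_rew < 0.0 else (5.0 * squeezed)
print(raw_rew, shaped)  # positive rewards become 5 * tanh(r / 5)
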
/brain_agent/envs/dmlab/dmlab_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch import nn
4 |
5 | from brain_agent.core.models.model_abc import EncoderBase
6 | from brain_agent.core.models.resnet import ResnetEncoder
7 | from brain_agent.envs.dmlab.dmlab30 import DMLAB_VOCABULARY_SIZE, DMLAB_INSTRUCTIONS
8 | from brain_agent.utils.logger import log
9 |
10 |
11 | class DmlabEncoder(EncoderBase):
12 | def __init__(self, cfg, obs_space):
13 | super().__init__(cfg)
14 |
15 | if self.cfg.model.encoder.encoder_type == 'resnet':
16 | self.basic_encoder = ResnetEncoder(cfg, obs_space)
17 | else:
18 | raise NotImplementedError
19 | self.encoder_out_size = self.basic_encoder.encoder_out_size
20 |
21 | self.embedding_size = 20
22 | self.instructions_lstm_units = 64
23 | self.instructions_lstm_layers = 1
24 |
25 | padding_idx = 0
26 | self.word_embedding = nn.Embedding(
27 | num_embeddings=DMLAB_VOCABULARY_SIZE,
28 | embedding_dim=self.embedding_size,
29 | padding_idx=padding_idx
30 | )
31 |
32 | self.instructions_lstm = nn.LSTM(
33 | input_size=self.embedding_size,
34 | hidden_size=self.instructions_lstm_units,
35 | num_layers=self.instructions_lstm_layers,
36 | batch_first=True,
37 | )
38 |
39 | self.encoder_out_size += self.instructions_lstm_units
40 | log.debug('Policy head output size: %r', self.encoder_out_size)
41 |
42 | self.instructions_lstm.apply(self.initialize)
43 |
44 | def initialize(self, layer):
45 | gain = 1.0
46 | if hasattr(layer, 'bias') and isinstance(layer.bias, torch.nn.parameter.Parameter):
47 | layer.bias.data.fill_(0)
48 |
49 | if self.cfg.model.encoder.encoder_init == 'orthogonal':
50 | if type(layer) == nn.Conv2d or type(layer) == nn.Linear:
51 | nn.init.orthogonal_(layer.weight.data, gain=gain)
52 | elif type(layer) == nn.GRUCell or type(layer) == nn.LSTMCell:
53 | nn.init.orthogonal_(layer.weight_ih, gain=gain)
54 | nn.init.orthogonal_(layer.weight_hh, gain=gain)
55 | layer.bias_ih.data.fill_(0)
56 | layer.bias_hh.data.fill_(0)
57 | elif self.cfg.model.encoder.encoder_init == 'xavier_uniform':
58 | if type(layer) == nn.Conv2d or type(layer) == nn.Linear:
59 | nn.init.xavier_uniform_(layer.weight.data, gain=gain)
60 | layer.bias.data.fill_(0)
61 | elif type(layer) == nn.GRUCell or type(layer) == nn.LSTMCell:
62 | nn.init.xavier_uniform_(layer.weight_ih, gain=gain)
63 | nn.init.xavier_uniform_(layer.weight_hh, gain=gain)
64 | layer.bias_ih.data.fill_(0)
65 | layer.bias_hh.data.fill_(0)
66 | elif self.cfg.model.encoder.encoder_init == 'torch_default':
67 | pass
68 | else:
69 | raise NotImplementedError
70 |
71 | def model_to_device(self, device):
72 | self.to(device)
73 | self.word_embedding.to(device)
74 | self.instructions_lstm.to(device)
75 |
76 | def device_and_type_for_input_tensor(self, input_tensor_name):
77 | if input_tensor_name == DMLAB_INSTRUCTIONS:
78 | return self.model_device(), torch.int64
79 | else:
80 | return self.model_device(), torch.float32
81 |
82 | def forward(self, obs_dict, **kwargs):
83 | x = self.basic_encoder(obs_dict, **kwargs)
84 |
85 | with torch.no_grad():
86 | instr = obs_dict[DMLAB_INSTRUCTIONS]
87 | instr_lengths = (instr != 0).sum(axis=1)
88 | instr_lengths = torch.clamp(instr_lengths, min=1)
89 | max_instr_len = torch.max(instr_lengths).item()
90 | instr = instr[:, :max_instr_len]
91 | instr_lengths_cpu = instr_lengths.to('cpu')
92 |
93 | instr_embed = self.word_embedding(instr)
94 | instr_packed = torch.nn.utils.rnn.pack_padded_sequence(
95 | instr_embed, instr_lengths_cpu, batch_first=True, enforce_sorted=False,
96 | )
97 | rnn_output, _ = self.instructions_lstm(instr_packed)
98 | rnn_outputs, sequence_lengths = torch.nn.utils.rnn.pad_packed_sequence(rnn_output, batch_first=True)
99 |
100 | first_dim_idx = torch.arange(rnn_outputs.shape[0])
101 | last_output_idx = sequence_lengths - 1
102 | last_outputs = rnn_outputs[first_dim_idx, last_output_idx]
103 |
104 | last_outputs = last_outputs.to(x.device)
105 |
106 | x = torch.cat((x, last_outputs), dim=1)
107 | return x
108 |
109 |
--------------------------------------------------------------------------------
/brain_agent/core/models/causal_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from brain_agent.core.models.transformer import PositionalEmbedding, RelPartialLearnableDecoderLayer
5 |
6 |
7 | class CausalTransformer(nn.Module):
8 | def __init__(self, core_out_size, n_action, pre_lnorm=False):
9 | super().__init__()
10 | self.n_layer = 4
11 | self.n_head = 3
12 | self.d_head = 64
13 | self.d_inner = 512
14 | self.d_model = 196
15 | self.mem_len = 64
16 |
17 | self.blocks = nn.ModuleList()
18 | for i in range(self.n_layer):
19 | self.blocks.append(RelPartialLearnableDecoderLayer(
20 | n_head=self.n_head, d_model=self.d_model, d_head=self.d_head, d_inner=self.d_inner,
21 | pre_lnorm=pre_lnorm))
22 | # decoder head
23 | self.ln_f = nn.LayerNorm(self.d_model)
24 | self.head = nn.Linear(self.d_model, core_out_size, bias=True)
25 |
26 | self.apply(self._init_weights)
27 |
28 | self.state_encoder = nn.Sequential(nn.Linear(core_out_size+n_action, self.d_model), nn.Tanh())
29 |
30 | self.pos_emb = PositionalEmbedding(self.d_model)
31 | self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
32 | self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
33 |
34 |
35 | def _init_weights(self, module):
36 | if isinstance(module, (nn.Linear, nn.Embedding)):
37 | module.weight.data.normal_(mean=0.0, std=0.02)
38 | if isinstance(module, nn.Linear) and module.bias is not None:
39 | module.bias.data.zero_()
40 | elif isinstance(module, nn.LayerNorm):
41 | module.bias.data.zero_()
42 | module.weight.data.fill_(1.0)
43 |
44 | def forward(self, states, mem_begin_index, num_traj, mems=None):
45 | state_embeddings = self.state_encoder(states) # (batch * block_size, n_embd)
46 | x = state_embeddings
47 |
48 | qlen, bsz, _ = x.size() # qlen is the number of time steps in the input
49 |
50 | if mems is not None:
51 | mlen = mems[0].size(0)
52 | klen = mlen + qlen
53 | dec_attn_mask_triu = torch.triu(state_embeddings.new_ones(qlen, klen), diagonal=1 + mlen)
54 | dec_attn_mask = dec_attn_mask_triu.bool().unsqueeze(-1).repeat(1, 1, bsz)
55 | else:
56 | mlen = self.mem_len
57 | klen = self.mem_len
58 | dec_attn_mask_triu = torch.triu(state_embeddings.new_ones(qlen, klen), diagonal=1)
59 | dec_attn_mask = dec_attn_mask_triu.bool().unsqueeze(-1).repeat(1, 1, bsz)
60 |
61 | for b in range(bsz):
62 | if mlen-mem_begin_index[b] > 0:
63 | dec_attn_mask[:, :mlen-mem_begin_index[b], b] = True
64 |
65 | dec_attn_mask = dec_attn_mask.transpose(1,2)
66 | temp = torch.logical_not(dec_attn_mask)
67 | temp = torch.sum(temp, dim=2, keepdim=True)
68 | temp = torch.ge(temp, 0.1)
69 |
70 | dec_attn_mask = torch.logical_and(temp, dec_attn_mask)
71 | dec_attn_mask = dec_attn_mask.transpose(1, 2)
72 |
73 | pos_seq = torch.arange(klen - 1, -1, -1.0, device=states.device, dtype=states.dtype) # [99,...0]
74 | pos_emb = self.pos_emb(pos_seq) # T x 1 x dim
75 |
76 | hids = [x]
77 | for i, layer in enumerate(self.blocks):
78 | if mems is not None:
79 | x = layer(x, pos_emb, self.r_w_bias, self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems[i])
80 | else:
81 | x = layer(x, pos_emb, self.r_w_bias, self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=None)
82 | hids.append(x)
83 |
84 | if mems is None:
85 | new_mems = hids
86 | else:
87 | new_mems = self._update_mems(hids, mems, mlen, qlen)
88 | mem_begin_index = mem_begin_index + 1
89 | x = self.ln_f(x)
90 | logits = self.head(x)
91 | return logits, new_mems, mem_begin_index
92 |
93 |
94 | def _update_mems(self, hids, mems, mlen, qlen):
95 | # does not deal with None
96 | if mems is None: return None
97 |
98 | # mems is not None
99 | assert len(hids) == len(mems), 'len(hids) != len(mems)'
100 |
101 | new_mems = []
102 | end_idx = mlen + max(0, qlen) # ext_len is usually 0 in their experiments
103 | beg_idx = max(0, end_idx - self.mem_len) #if hids[0].shape[0] > 1 else 0
104 | for i in range(len(hids)):
105 | cat = torch.cat([mems[i], hids[i]], dim=0) # (m_len + q) x B x dim
106 |
107 | if beg_idx == end_idx: # cfg.mem_len=0
108 | new_mems.append(torch.zeros(cat[0:1].size()))
109 | else: # cfg.mem_len > 0
110 | new_mems.append(cat[beg_idx:end_idx])
111 |
112 | return new_mems
--------------------------------------------------------------------------------
/dist_launch.py:
--------------------------------------------------------------------------------
1 | # codebase: torch.distributed.launch.py
2 |
3 | import sys
4 | import subprocess
5 | import os
6 | from argparse import ArgumentParser, REMAINDER
7 |
8 |
9 | def parse_args():
10 | """
11 | Helper function parsing the command line options
12 | @retval ArgumentParser
13 | """
14 | parser = ArgumentParser(description="PyTorch distributed training launch "
15 | "helper utility that will spawn up "
16 | "multiple distributed processes")
17 |
18 | # Optional arguments for the launch helper
19 | parser.add_argument("--nnodes", type=int, default=1,
20 | help="The number of nodes to use for distributed "
21 | "training")
22 | parser.add_argument("--node_rank", type=int, default=0,
23 | help="The rank of the node for multi-node distributed "
24 | "training")
25 | parser.add_argument("--nproc_per_node", type=int, default=1,
26 | help="The number of processes to launch on each node, "
27 | "for GPU training, this is recommended to be set "
28 | "to the number of GPUs in your system so that "
29 | "each process can be bound to a single GPU.")
30 | parser.add_argument("--master_addr", default="127.0.0.1", type=str,
31 | help="Master node (rank 0)'s address, should be either "
32 | "the IP address or the hostname of node 0, for "
33 | "single node multi-proc training, the "
34 | "--master_addr can simply be 127.0.0.1")
35 | parser.add_argument("--master_port", default=2901, type=int,
36 | help="Master node (rank 0)'s free port that needs to "
37 | "be used for communication during distributed "
38 | "training")
39 | parser.add_argument("-m", "--module", default=False, action="store_true",
40 | help="Changes each process to interpret the launch script "
41 | "as a python module, executing with the same behavior as"
42 | "'python -m'.")
43 |
44 | # positional
45 | parser.add_argument("training_script", type=str,
46 | help="The full path to the single GPU training "
47 | "program/script to be launched in parallel, "
48 | "followed by all the arguments for the "
49 | "training script")
50 |
51 | # rest from the training program
52 | parser.add_argument('training_script_args', nargs=REMAINDER)
53 | return parser.parse_args()
54 |
55 | def main():
56 | args = parse_args()
57 |
58 | # world size in terms of number of processes
59 | dist_world_size = args.nproc_per_node * args.nnodes
60 |
61 | # set PyTorch distributed related environmental variables
62 | current_env = os.environ.copy()
63 | current_env["MASTER_ADDR"] = args.master_addr
64 | current_env["MASTER_PORT"] = str(args.master_port)
65 | current_env["WORLD_SIZE"] = str(dist_world_size)
66 |
67 | processes = []
68 |
69 | if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1:
70 | current_env["OMP_NUM_THREADS"] = str(1)
71 | print("*****************************************\n"
72 | "Setting OMP_NUM_THREADS environment variable for each process "
73 | "to be {} in default, to avoid your system being overloaded, "
74 | "please further tune the variable for optimal performance in "
75 | "your application as needed. \n"
76 | "*****************************************".format(current_env["OMP_NUM_THREADS"]))
77 |
78 | for local_rank in range(0, args.nproc_per_node):
79 | # each process's rank
80 | dist_rank = args.nproc_per_node * args.node_rank + local_rank
81 | current_env["RANK"] = str(dist_rank)
82 | current_env["LOCAL_RANK"] = str(local_rank)
83 |
84 | # spawn the processes
85 | cmd = [sys.executable, "-u"]
86 | if args.module:
87 | cmd.append("-m")
88 |
89 | cmd.append(args.training_script)
90 |
91 | cmd.append("dist.local_rank={}".format(local_rank))
92 | cmd.append("dist.nproc_per_node={}".format(args.nproc_per_node))
93 | cmd.append("dist.world_rank={}".format(dist_rank))
94 | cmd.append("dist.world_size={}".format(dist_world_size))
95 |
96 | cmd.extend(args.training_script_args)
97 |
98 | process = subprocess.Popen(cmd, env=current_env)
99 | processes.append(process)
100 |
101 | for process in processes:
102 | process.wait()
103 | if process.returncode != 0:
104 | raise subprocess.CalledProcessError(returncode=process.returncode,
105 | cmd=cmd)
106 |
107 |
108 | if __name__ == "__main__":
109 | main()
--------------------------------------------------------------------------------
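For reference, a hypothetical single-node launch of the train.py entry point (not shown in this section) might look like

python dist_launch.py --nproc_per_node=8 train.py cfg=configs/trxl_baseline_train.yaml train_dir=/tmp/runs experiment=trxl_demo

assuming train.py accepts cfg=... and key=value overrides in the style of brain_agent/utils/cfg.py. As the loop above shows, the launcher appends dist.local_rank, dist.nproc_per_node, dist.world_rank and dist.world_size overrides to every spawned process and sets MASTER_ADDR, MASTER_PORT and WORLD_SIZE in its environment.
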
/brain_agent/core/core_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import psutil
3 | from collections import OrderedDict
4 | from brain_agent.utils.logger import log
5 | from faster_fifo import Full, Empty
6 |
7 | class TaskType:
8 | INIT, TERMINATE, RESET, ROLLOUT_STEP, POLICY_STEP, TRAIN, INIT_MODEL, EMPTY = range(8)
9 |
10 | def dict_of_lists_append(dict_of_lists, new_data, index):
11 | for key, x in new_data.items():
12 | if key in dict_of_lists:
13 | dict_of_lists[key].append(x[index])
14 | else:
15 | dict_of_lists[key] = [x[index]]
16 |
17 | def copy_dict_structure(d):
18 | d_copy = type(d)()
19 | _copy_dict_structure_func(d, d_copy)
20 | return d_copy
21 |
22 | def _copy_dict_structure_func(d, d_copy):
23 | for key, value in d.items():
24 | if isinstance(value, (dict, OrderedDict)):
25 | d_copy[key] = type(value)()
26 | _copy_dict_structure_func(value, d_copy[key])
27 | else:
28 | d_copy[key] = None
29 |
30 | def iterate_recursively(d):
31 | for k, v in d.items():
32 | if isinstance(v, (dict, OrderedDict)):
33 | yield from iterate_recursively(v)
34 | else:
35 | yield d, k, v
36 |
37 | def iter_dicts_recursively(d1, d2):
38 | for k, v in d1.items():
39 | assert k in d2
40 |
41 | if isinstance(v, (dict, OrderedDict)):
42 | yield from iter_dicts_recursively(d1[k], d2[k])
43 | else:
44 | yield d1, d2, k, d1[k], d2[k]
45 |
46 |
47 | def slice_mems(mems_buffer, mems_dones_buffer, mems_actions_buffer, actor_idx, split_idx, env_idx, s_idx, e_idx):
48 | # Slice given mems buffers in a cyclic queue manner
49 | if s_idx > e_idx:
50 | mems = torch.cat(
51 | [mems_buffer[actor_idx, split_idx, env_idx, s_idx:],
52 | mems_buffer[actor_idx, split_idx, env_idx, :e_idx]])
53 | mems_dones = torch.cat(
54 | [mems_dones_buffer[actor_idx, split_idx, env_idx, s_idx:],
55 | mems_dones_buffer[actor_idx, split_idx, env_idx, :e_idx]])
56 | mems_actions = torch.cat(
57 | [mems_actions_buffer[actor_idx, split_idx, env_idx, s_idx:],
58 | mems_actions_buffer[actor_idx, split_idx, env_idx, :e_idx]])
59 | else:
60 | mems = mems_buffer[actor_idx, split_idx, env_idx, s_idx:e_idx]
61 | mems_dones = mems_dones_buffer[actor_idx, split_idx, env_idx, s_idx:e_idx]
62 | mems_actions = mems_actions_buffer[actor_idx, split_idx, env_idx, s_idx:e_idx]
63 | return mems, mems_dones, mems_actions
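    # Minimal illustration (hypothetical indices): with a memory buffer of length 8 along the
    # sliced dimension, s_idx=6 and e_idx=2 wrap around the end, so the slice contains rows
    # [6, 7, 0, 1]; with s_idx=2 and e_idx=6 it is simply rows [2, 3, 4, 5].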
64 |
65 | def join_or_kill(process, timeout=1.0):
66 | process.join(timeout)
67 | if process.is_alive():
68 | log.warning('Process %r could not join, kill it with fire!', process)
69 | process.kill()
70 | log.warning('Process %r is dead (%r)', process, process.is_alive())
71 |
72 | def set_process_cpu_affinity(worker_idx, num_workers, local_rank=0, nproc_per_node=0):
73 | curr_process = psutil.Process()
74 | available_cores = curr_process.cpu_affinity()
75 | cpu_count = len(available_cores)
76 | if nproc_per_node > 1:
77 | worker_idx = worker_idx * nproc_per_node + local_rank
78 | num_workers = num_workers * nproc_per_node
79 | core_indices = cores_for_worker_process(worker_idx, num_workers, cpu_count)
80 | if core_indices is not None:
81 | curr_process_cores = [available_cores[c] for c in core_indices]
82 | curr_process.cpu_affinity(curr_process_cores)
83 |
84 | log.debug('Worker %d uses CPU cores %r', worker_idx, curr_process.cpu_affinity())
85 |
86 | def cores_for_worker_process(worker_idx, num_workers, cpu_count):
87 | """
88 | Returns core indices, assuming available cores are [0, ..., cpu_count).
89 |     If this is not the case (e.g. on SLURM), use these as indices into the array of actually available cores.
90 | """
91 |
92 | worker_idx_modulo = worker_idx % cpu_count
93 |
94 | cores = None
95 | whole_workers_per_core = num_workers // cpu_count
96 | if worker_idx < whole_workers_per_core * cpu_count:
97 | cores = [worker_idx_modulo]
98 | else:
99 | remaining_workers = num_workers % cpu_count
100 | if cpu_count % remaining_workers == 0:
101 | cores_to_use = cpu_count // remaining_workers
102 | cores = list(range(worker_idx_modulo * cores_to_use, (worker_idx_modulo + 1) * cores_to_use, 1))
103 |
104 | return cores
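    # Worked example (hypothetical numbers): with num_workers=10 and cpu_count=4,
    # workers 0..7 each get the single core [worker_idx % 4]; for the 2 remaining workers
    # the 4 cores split evenly, so worker 8 gets [0, 1] and worker 9 gets [2, 3].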
105 |
106 | def safe_put(q, msg, attempts=3, queue_name=''):
107 | safe_put_many(q, [msg], attempts, queue_name)
108 |
109 |
110 | def safe_put_many(q, msgs, attempts=3, queue_name=''):
111 | for attempt in range(attempts):
112 | try:
113 | q.put_many(msgs)
114 | return
115 | except Full:
116 | log.warning('Could not put msgs to queue, the queue %s is full! Attempt %d', queue_name, attempt)
117 |
118 | log.error('Failed to put msgs to queue %s after %d attempts. Messages are lost!', queue_name, attempts)
119 |
120 | def safe_get(q, timeout=1e6, msg='Queue timeout'):
121 | """Using queue.get() with timeout is necessary, otherwise KeyboardInterrupt is not handled."""
122 | while True:
123 | try:
124 | return q.get(timeout=timeout)
125 | except Empty:
126 | log.info('Queue timed out (%s), timeout %.3f', msg, timeout)
--------------------------------------------------------------------------------
/brain_agent/core/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from brain_agent.core.models.model_abc import EncoderBase
4 | from brain_agent.core.models.model_utils import get_obs_shape, calc_num_elements, nonlinearity
5 | from brain_agent.utils.logger import log
6 |
7 | class ResnetEncoder(EncoderBase):
8 | def __init__(self, cfg, obs_space):
9 | super().__init__(cfg)
10 |
11 | obs_shape = get_obs_shape(obs_space)
12 | input_ch = obs_shape.obs[0]
13 | self.input_ch = input_ch
14 | log.debug('Num input channels: %d', input_ch)
15 |
16 | if cfg.model.encoder.encoder_subtype == 'resnet_impala':
17 | resnet_conf = [[16, 2], [32, 2], [32, 2]]
18 | elif cfg.model.encoder.encoder_subtype == 'resnet_impala_large':
19 | resnet_conf = [[32, 2], [64, 2], [64, 2]]
20 | else:
21 | raise NotImplementedError(f'Unknown resnet subtype {cfg.model.encoder.encoder_subtype}')
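        # Each (channels, blocks) pair defines one stage: a 3x3 conv, a 2x downsampling step
        # (max/avg pool or stride-2 conv, depending on encoder_pooling), then `blocks` residual
        # blocks; three stages reduce the default 72x96 input to 9x12 spatial resolution.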
22 |
23 | curr_input_channels = input_ch
24 | layers = []
25 | if cfg.learner.use_decoder:
26 | layers_decoder = []
27 | for i, (out_channels, res_blocks) in enumerate(resnet_conf):
28 | if cfg.model.encoder.encoder_pooling == 'stride':
29 | enc_stride = 2
30 | pool = nn.Identity
31 | else:
32 | enc_stride = 1
33 | pool = nn.MaxPool2d if cfg.model.encoder.encoder_pooling == 'max' else nn.AvgPool2d
34 | layers.extend([
35 | nn.Conv2d(curr_input_channels, out_channels, kernel_size=3, stride=enc_stride, padding=1),
36 | pool(kernel_size=3, stride=2, padding=1), # padding SAME
37 | ])
38 |
39 | for j in range(res_blocks):
40 | layers.append(ResBlock(cfg, out_channels, out_channels))
41 |
42 | if cfg.learner.use_decoder:
43 | for j in range(res_blocks):
44 | layers_decoder.append(ResBlock(cfg, curr_input_channels, curr_input_channels))
45 | layers_decoder.append(
46 | nn.ConvTranspose2d(out_channels, curr_input_channels, kernel_size=3, stride=2,
47 | padding=1, output_padding=1)
48 | )
49 | curr_input_channels = out_channels
50 |
51 | layers.append(nonlinearity(cfg))
52 |
53 | self.conv_head = nn.Sequential(*layers)
54 | self.conv_head_out_size = calc_num_elements(self.conv_head, obs_shape.obs)
55 | log.debug('Convolutional layer output size: %r', self.conv_head_out_size)
56 | self.init_fc_blocks(self.conv_head_out_size)
57 |
58 | if cfg.learner.use_decoder:
59 | layers_decoder.reverse()
60 | self.deconv_head = nn.Sequential(*layers_decoder)
61 |
62 | self.apply(self.initialize)
63 |
64 | def initialize(self, layer):
65 | gain = 1.0
66 | if hasattr(layer, 'bias') and isinstance(layer.bias, torch.nn.parameter.Parameter):
67 | layer.bias.data.fill_(0)
68 |
69 | if self.cfg.model.encoder.encoder_init == 'orthogonal':
70 | if type(layer) == nn.Conv2d or type(layer) == nn.Linear:
71 | nn.init.orthogonal_(layer.weight.data, gain=gain)
72 | elif type(layer) == nn.GRUCell or type(layer) == nn.LSTMCell: # TODO: test for LSTM
73 | nn.init.orthogonal_(layer.weight_ih, gain=gain)
74 | nn.init.orthogonal_(layer.weight_hh, gain=gain)
75 | layer.bias_ih.data.fill_(0)
76 | layer.bias_hh.data.fill_(0)
77 | elif self.cfg.model.encoder.encoder_init == 'xavier_uniform':
78 | if type(layer) == nn.Conv2d or type(layer) == nn.Linear:
79 | nn.init.xavier_uniform_(layer.weight.data, gain=gain)
80 | layer.bias.data.fill_(0)
81 | elif type(layer) == nn.GRUCell or type(layer) == nn.LSTMCell:
82 | nn.init.xavier_uniform_(layer.weight_ih, gain=gain)
83 | nn.init.xavier_uniform_(layer.weight_hh, gain=gain)
84 | layer.bias_ih.data.fill_(0)
85 | layer.bias_hh.data.fill_(0)
86 | elif self.cfg.model.encoder.encoder_init == 'torch_default':
87 | pass
88 | else:
89 | raise NotImplementedError
90 |
91 | def forward(self, obs_dict, decode=False):
92 | x = self.conv_head(obs_dict['obs'])
93 |
94 | if decode:
95 | self.reconstruction = torch.tanh(self.deconv_head(x))
96 |
97 | x = x.contiguous().view(-1, self.conv_head_out_size)
98 | x = self.forward_fc_blocks(x)
99 | return x
100 |
101 |
102 | class ResBlock(nn.Module):
103 | def __init__(self, cfg, input_ch, output_ch):
104 | super().__init__()
105 |
106 | layers = [
107 | nonlinearity(cfg),
108 | nn.Conv2d(input_ch, output_ch, kernel_size=3, stride=1, padding=1), # padding SAME
109 | nonlinearity(cfg),
110 | nn.Conv2d(output_ch, output_ch, kernel_size=3, stride=1, padding=1), # padding SAME
111 | ]
112 |
113 | self.res_block_core = nn.Sequential(*layers)
114 |
115 | def forward(self, x):
116 | identity = x
117 | out = self.res_block_core(x)
118 | out = out + identity
119 | return out
120 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | import numpy as np
4 | from faster_fifo import Queue, Empty
5 | from tensorboardX import SummaryWriter
6 | from brain_agent.core.actor_worker import ActorWorker
7 | from brain_agent.core.policy_worker import PolicyWorker
8 | from brain_agent.core.shared_buffer import SharedBuffer
9 | from brain_agent.utils.cfg import Configs
10 | from brain_agent.utils.utils import get_log_path, dict_of_list_put, AttrDict, get_summary_dir
11 | from brain_agent.core.core_utils import TaskType
12 | from brain_agent.utils.logger import log, init_logger
13 | from brain_agent.envs.env_utils import create_env
14 |
15 | def main():
16 |
17 | cfg = Configs.get_defaults()
18 | cfg = Configs.override_from_file_name(cfg)
19 | cfg = Configs.override_from_cli(cfg)
20 |
21 | cfg_str = Configs.to_yaml(cfg)
22 | cfg = Configs.to_attr_dict(cfg)
23 |
24 | init_logger(cfg.log.log_level, get_log_path(cfg))
25 |
26 | log.info(f'Experiment configuration:\n{cfg_str}')
27 |
28 | tmp_env = create_env(cfg, env_config=None)
29 | action_space = tmp_env.action_space
30 | obs_space = tmp_env.observation_space
31 | level_info = tmp_env.level_info
32 | num_levels = level_info['num_levels']
33 | tmp_env.close()
34 |
35 | assert cfg.env.one_task_per_worker
36 | assert cfg.test.is_test
37 | assert cfg.actor.num_workers >= level_info['num_levels']
38 |
39 | shared_buffer = SharedBuffer(cfg, obs_space, action_space)
40 | shared_buffer.stop_experience_collection.fill_(False)
41 |
42 | policy_worker_queue = Queue()
43 | actor_worker_queues = [Queue(2 * 1000 * 1000) for _ in range(cfg.actor.num_workers)]
44 | policy_queue = Queue()
45 | report_queue = Queue(40 * 1000 * 1000)
46 |
47 |     policy_worker = PolicyWorker(cfg, obs_space, action_space, level_info, shared_buffer,
48 | policy_queue, actor_worker_queues, policy_worker_queue, report_queue)
49 | policy_worker.start_process()
50 | policy_worker.init()
51 | policy_worker.load_model()
52 |
53 | actor_workers = []
54 | for i in range(cfg.actor.num_workers):
55 | w = ActorWorker(cfg, obs_space, action_space, i, shared_buffer, actor_worker_queues[i], policy_queue,
56 | report_queue)
57 | w.init()
58 | w.request_reset()
59 | actor_workers.append(w)
60 |
61 | writer = SummaryWriter(get_summary_dir(cfg, postfix='test'))
62 |
63 | stats = AttrDict()
64 | stats['episodic_stats'] = AttrDict()
65 | actor_worker_task_id = AttrDict()
66 |
67 | env_steps = 0
68 | num_collected = 0
69 | terminate = False
70 |
71 | while not terminate:
72 | try:
73 | reports = report_queue.get_many(timeout=0.1)
74 | for report in reports:
75 | if 'terminate' in report:
76 | terminate = True
77 | if 'learner_env_steps' in report:
78 | env_steps = report['learner_env_steps']
79 | if 'initialized_env' in report:
80 | actor_idx, split_idx, _, task_id = report['initialized_env']
81 | actor_worker_task_id[actor_idx] = task_id[0]
82 | if 'episodic_stats' in report:
83 | s = report['episodic_stats']
84 | level_name = s['level_name'].replace('_contributed/dmlab30/', '')
85 | level_id = s['task_id']
86 |
87 | tag = f'_dmlab/{level_id:02d}_{level_name}_human_norm_score'
88 | dict_of_list_put(stats.episodic_stats, tag, s['hns'], cfg.test.test_num_episodes)
89 |
90 | hns = s['hns']
91 | log.info(f'[{num_collected} / {num_levels * cfg.test.test_num_episodes}] {level_id:02d}_'
92 | f'{level_name}: {hns}')
93 |
94 | if len(stats.episodic_stats[tag]) >= cfg.test.test_num_episodes:
95 | for i, w in enumerate(actor_workers):
96 |                             if actor_worker_task_id[i] == level_id and w.process.is_alive():
97 | actor_worker_queues[i].put((TaskType.TERMINATE, None))
98 |
99 | hns = []
100 | num_collected = 0
101 | for i, l in enumerate(level_info['all_levels']):
102 | tag = f'_dmlab/{i:02d}_{l}_human_norm_score'
103 | h = stats.episodic_stats.get(tag, None)
104 | if h is not None:
105 | num_collected += len(h)
106 | hns.append(np.array(h).mean())
107 |
108 | if num_collected >= num_levels * cfg.test.test_num_episodes:
109 | hns = np.array(hns)
110 | capped_hns = np.clip(hns, None, 100)
111 | log.info('-' * 100)
112 | log.info(f'num_collected: {num_collected}')
113 | log.info(f'mean_human_norm_score: {hns.mean()}')
114 | log.info(f'mean_capped_human_norm_score: {capped_hns.mean()}')
115 | log.info(f'median_human_norm_score: {np.median(hns)}')
116 | for i, l in enumerate(level_info['all_levels']):
117 | tag = f'_dmlab/{i:02d}_{l}_human_norm_score'
118 | h = stats.episodic_stats[tag]
119 | log.info(f'{tag}: {np.array(h).mean()}')
120 |
121 | writer.add_scalar(f'_dmlab/000_mean_human_norm_score', hns.mean(), env_steps)
122 | writer.add_scalar(f'_dmlab/000_mean_capped_human_norm_score', capped_hns.mean(), env_steps)
123 | writer.add_scalar(f'_dmlab/000_median_human_norm_score', np.median(hns), env_steps)
124 | for tag, scalar in stats.episodic_stats.items():
125 | writer.add_scalar(tag, np.array(scalar).mean(), env_steps)
126 |
127 | terminate = True
128 |
129 | except Empty:
130 | time.sleep(1.0)
131 | pass
132 |
133 | if __name__ == '__main__':
134 | sys.exit(main())
--------------------------------------------------------------------------------
/brain_agent/core/models/rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Tuple
4 | from torch.nn.utils.rnn import PackedSequence, invert_permutation
5 |
6 | def _build_pack_info_from_dones(dones: torch.Tensor, T: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
7 | num_samples = len(dones)
8 |
9 | rollout_boundaries = dones.clone().detach()
10 | rollout_boundaries[T - 1::T] = 1
11 | rollout_boundaries = rollout_boundaries.nonzero().squeeze(dim=1) + 1
12 |
13 | first_len = rollout_boundaries[0].unsqueeze(0)
14 |
15 | if len(rollout_boundaries) <= 1:
16 | rollout_lengths = first_len
17 | else:
18 | rollout_lengths = rollout_boundaries[1:] - rollout_boundaries[:-1]
19 | rollout_lengths = torch.cat([first_len, rollout_lengths])
20 |
21 | rollout_starts_orig = rollout_boundaries - rollout_lengths
22 |
23 | is_new_episode = dones.clone().detach().view((-1, T))
24 | is_new_episode = is_new_episode.roll(1, 1)
25 |
26 | is_new_episode[:, 0] = 0
27 | is_new_episode = is_new_episode.view((-1, ))
28 |
29 | lengths, sorted_indices = torch.sort(rollout_lengths, descending=True)
30 |
31 | cpu_lengths = lengths.to(device='cpu', non_blocking=True)
32 |
33 | rollout_starts_sorted = rollout_starts_orig.index_select(0, sorted_indices)
34 |
35 | select_inds = torch.empty(num_samples, device=dones.device, dtype=torch.int64)
36 |
37 | max_length = int(cpu_lengths[0].item())
38 |
39 | batch_sizes = torch.empty((max_length,), device='cpu', dtype=torch.int64)
40 |
41 | offset = 0
42 | prev_len = 0
43 | num_valid_for_length = lengths.size(0)
44 |
45 | unique_lengths = torch.unique_consecutive(cpu_lengths)
46 |
47 | for i in range(len(unique_lengths) - 1, -1, -1):
48 | valids = lengths[0:num_valid_for_length] > prev_len
49 | num_valid_for_length = int(valids.float().sum().item())
50 |
51 | next_len = int(unique_lengths[i])
52 |
53 | batch_sizes[prev_len:next_len] = num_valid_for_length
54 |
55 | new_inds = (
56 | rollout_starts_sorted[0:num_valid_for_length].view(1, num_valid_for_length)
57 | + torch.arange(prev_len, next_len, device=rollout_starts_sorted.device).view(next_len - prev_len, 1)
58 | ).view(-1)
59 |
60 | select_inds[offset:offset + new_inds.numel()] = new_inds
61 |
62 | offset += new_inds.numel()
63 |
64 | prev_len = next_len
65 |
66 | assert offset == num_samples
67 | assert is_new_episode.shape[0] == num_samples
68 |
69 | return rollout_starts_orig, is_new_episode, select_inds, batch_sizes, sorted_indices
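# In brief: the flat (B*T) rollout tensor is split into fragments at 'done' flags and at rollout
# boundaries, fragments are sorted by length, and the returned select_inds / batch_sizes describe
# how to gather them into a PackedSequence so a single LSTM call can process all fragments;
# rollout_starts and is_new_episode are later used to select (and zero out, at episode starts)
# the initial hidden states.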
70 |
71 | def _build_rnn_inputs(x, dones_cpu, rnn_states, T: int):
72 | rollout_starts, is_new_episode, select_inds, batch_sizes, sorted_indices = _build_pack_info_from_dones(
73 | dones_cpu, T)
74 | inverted_select_inds = invert_permutation(select_inds)
75 |
76 | def device(t):
77 | return t.to(device=x.device)
78 |
79 | select_inds = device(select_inds)
80 | inverted_select_inds = device(inverted_select_inds)
81 | sorted_indices = device(sorted_indices)
82 | rollout_starts = device(rollout_starts)
83 | is_new_episode = device(is_new_episode)
84 |
85 | x_seq = PackedSequence(x.index_select(0, select_inds), batch_sizes, sorted_indices)
86 |
87 | rnn_states = rnn_states.index_select(0, rollout_starts)
88 | is_same_episode = (1 - is_new_episode.view(-1, 1)).index_select(0, rollout_starts)
89 | rnn_states = rnn_states * is_same_episode
90 |
91 | return x_seq, rnn_states, inverted_select_inds
92 |
93 |
94 | class LSTM(nn.Module):
95 | def __init__(self, cfg, input_size):
96 | super().__init__()
97 |
98 | self.cfg = cfg
99 | self.core = nn.LSTM(input_size, cfg.model.core.hidden_size, cfg.model.core.n_rnn_layer)
100 |
101 | self.core_output_size = cfg.model.core.hidden_size
102 | self.n_rnn_layer = cfg.model.core.n_rnn_layer
103 | self.apply(self.initialize)
104 |
105 | def initialize(self, layer):
106 | gain = 1.0
107 |
108 | if self.cfg.model.core.core_init == 'tensorflow_default':
109 | if type(layer) == nn.LSTM:
110 | for n, p in layer.named_parameters():
111 | if 'weight_ih' in n:
112 | nn.init.xavier_uniform_(p.data, gain=gain)
113 | elif 'weight_hh' in n:
114 | nn.init.orthogonal_(p.data)
115 | elif 'bias_ih' in n:
116 | p.data.fill_(0)
117 | # Set forget-gate bias to 1
118 | n = p.size(0)
119 | p.data[(n // 4):(n // 2)].fill_(1)
120 | elif 'bias_hh' in n:
121 | p.data.fill_(0)
122 | elif self.cfg.model.core.core_init == 'torch_default':
123 | pass
124 | else:
125 | raise NotImplementedError
126 |
127 | def forward(self, head_output, rnn_states, dones, is_seq):
128 | if not is_seq:
129 | head_output = head_output.unsqueeze(0)
130 |
131 | if self.n_rnn_layer > 1:
132 | rnn_states = rnn_states.view(rnn_states.size(0), self.cfg.model.core.n_rnn_layer, -1)
133 | rnn_states = rnn_states.permute(1, 0, 2)
134 | else:
135 | rnn_states = rnn_states.unsqueeze(0)
136 |
137 | h, c = torch.split(rnn_states, self.cfg.model.core.hidden_size, dim=2)
138 |
139 | x, (h, c) = self.core(head_output, (h.contiguous(), c.contiguous()))
140 | new_rnn_states = torch.cat((h, c), dim=2)
141 |
142 | if not is_seq:
143 | x = x.squeeze(0)
144 |
145 | if self.n_rnn_layer > 1:
146 | new_rnn_states = new_rnn_states.permute(1, 0, 2)
147 | new_rnn_states = new_rnn_states.reshape(new_rnn_states.size(0), -1)
148 | else:
149 | new_rnn_states = new_rnn_states.squeeze(0)
150 |
151 | return x, new_rnn_states
152 |
153 | def get_core_out_size(self):
154 | return self.core_output_size
155 |
156 | @classmethod
157 | def build_rnn_inputs(cls, x, dones_cpu, rnn_states, T: int):
158 | return _build_rnn_inputs(x, dones_cpu, rnn_states, T)
159 |
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | cfg: None # path to configuration (.yaml) file
2 | train_dir: ??? # path to train dir
3 | experiment: ??? # experiment name. logs will be saved to train_dir/experiment
4 |
5 | seed: 0
6 |
7 | # dist arguments will be automatically set from dist_launch.py
8 | dist:
9 | world_size: 1
10 | world_rank: 0
11 | local_rank: 0
12 | nproc_per_node: 1
13 | dist_backend: nccl
14 |
15 | model:
16 | agent: dmlab_multitask_agent # [dmlab_multitask_agent]
17 | encoder:
18 | encoder_type: resnet # [resnet]. encoder type. only resnet is currently implemented
19 | encoder_subtype: resnet_impala # [resnet_impala, resnet_impala_large]. specific encoder architecture. see resnet.py
20 |     encoder_pooling: max # [max, stride] pooling type: max pooling or conv2d with stride 2; any other value falls back to average pooling
21 | nonlinearity: relu # [relu, elu, tanh] activation function used in encoder
22 | nonlinear_inplace: False # nonlinearity will be computed "inplace" if True
23 | encoder_extra_fc_layers: 1 # the number of extra fully connected layers that connect encoder outputs and core inputs
24 | encoder_init: orthogonal # [orthogonal, xavier_uniform, torch_default] initialization method used in encoder
25 | core:
26 | core_type: trxl # [trxl, rnn] core type.
27 | hidden_size: 256 # size of hidden layer in the model
28 | # trxl params
29 | mem_len: 512 # memory length
30 | n_layer: 12 # number of layers
31 | n_heads: 8 # number of MHA heads
32 | d_head: 64 # MHA head dimension
33 | d_inner: 2048 # the position wise ff network dimension
34 | # rnn_params
35 | n_rnn_layer: 3
36 | core_init: tensorflow_default # [tensorflow_default, torch_default] initialization method used in core
37 |
38 |   extended_input: True # concatenate one_hot action and reward to encoder output if set True
39 | use_popart: True # whether to use PopArt normalization or not
40 |   popart_beta: 0.0003 # decay rate to track mean and standard deviation of the values
41 | popart_clip_min: 0.0001 # popart minimum clip value for numerical stability
42 | device: cuda
43 | use_half_policy_worker: True # use half-precision in Policy Worker
44 |
45 | test:
46 | is_test: False # set True when testing
47 | checkpoint: ??? # full-path to checkpoint (.pth)
48 | test_num_episodes: 100 # the number of test episodes per level
49 |
50 | optim:
51 | type: adam # [adam, adamw] type of optimizer
52 | learning_rate: 0.0001 # learning rate
53 | adam_beta1: 0.9 # adam momentum decay coefficient
54 | adam_beta2: 0.999 # adam second momentum decay coefficient
55 | adam_eps: 1e-06 # adam epsilon parameter (1e-8 to 1e-5 seem to reliably work okay)
56 | max_grad_norm: 2.5 # maximum L2 norm of the gradient vector (for clipping)
57 | rollout: 96 # length of the rollout from each environment in timesteps
58 | batch_size: 1536 # total batch size (batch X rollout)
59 | train_for_env_steps: 20000000000 # stop after a policy is trained for this many env steps
60 |   warmup_optimizer: 100 # number of warm-up steps for the learner, counted in optimizer steps
61 |
62 | shared_buffer:
63 | min_traj_buffers_per_worker: 4 # how many shared buffers to allocate per actor worker
64 |
65 | learner:
66 | exploration_loss_coeff: 1e-3 # coefficient for the exploration component of the loss function
67 | vtrace_rho: 1.0 # rho_hat clipping parameter of the V-trace algorithm
68 | vtrace_c: 1.0 # c_hat clipping parameter of the V-trace algorithm
69 |   exclude_last: True # exclude the last timestep's v_s when computing V-trace
70 | use_ppo: False # whether to use ppo or not
71 | ppo_clip_ratio: 0.1 # clipping ratio value for ppo algorithm
72 |   ppo_clip_value: 0.2 # maximum absolute change in value estimate before it is clipped
73 | gamma: 0.99 # discount factor
74 | value_loss_coeff: 0.5 # coefficient for the critic loss
75 | keep_checkpoints: 3 # number of model checkpoints to keep
76 | use_adv_normalization: False # whether to use advantage normalization or not
77 | psychlab_gamma: -1.0 # specific gamma value for psychlab levels. use "gamma" if <0
78 | use_decoder: False # whether to use decoder for auxiliary reconstruction component or not
79 | reconstruction_loss_coeff: 0.0 # coefficient for auxiliary reconstruction component of the loss function
80 |   use_aux_future_pred_loss: False # whether to add an auxiliary loss for predicting observations 2 to 10 steps ahead with an autoregressive transformer
81 |   aux_future_pred_loss_coeff: 1.0 # coefficient for the auxiliary future-prediction component of the loss function
82 |   resume_training: False # load the latest checkpoint if set to True
83 |
84 | actor:
85 | num_workers: 30 # the number of actor processes for this node
86 |   num_envs_per_worker: 6 # number of envs per worker, sampled sequentially
87 | num_splits: 2 # set 2 for double buffering
88 | set_workers_cpu_affinity: True # whether to assign workers to specific CPU cores or not
89 |
90 | env:
91 | name: dmlab_30 # name of environment. see DMLAB_LEVELS_BY_ENVNAME in dmlab30.py
92 | res_w: 96 # width of image frame
93 | res_h: 72 # height of image frame
94 | dataset_path: ./brady_konkle_oliva2008 # path to dataset needed for some of the environments in DMLab-30
95 | use_level_cache: True # whether to use the local level cache (highly recommended)
96 | level_cache_path: ./dmlab_cache # location to store the cached levels
97 | action_set: extended_action_set # [impala_action_set, extended_action_set, extended_action_set_large]. see dmlab30.py
98 | frameskip: 4 # the number of frames for action repeat (frame skipping)
99 | obs_subtract_mean: 128.0 # mean value to subtract from observation
100 |   obs_scale: 128.0 # scale value for normalization
101 |   one_task_per_worker: False # every env in a VectorEnvRunner will run the same level if set to True
102 |   decorrelate_envs_on_one_worker: True # decorrelate envs within the same worker process at startup
103 | decorrelate_experience_max_seconds: 10 # maximum seconds for decorrelation
104 |
105 | log:
106 |   save_milestones_step: 1000000000 # interval (in env steps) for saving milestone checkpoints in a separate folder for later evaluation
107 |   save_every_sec: 3600 # checkpointing interval in seconds
108 | log_level: debug # logging level
109 | report_interval: 300.0 # how often in seconds we write summaries
110 | num_stats_average: 100 # the number of stats to average (for tensorboard logging)
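
# Example usage (a sketch; paths and the experiment name are placeholders, assuming the
# key=value CLI override style that dist_launch.py also uses for the dist.* arguments):
#   python train.py cfg=configs/trxl_baseline_train.yaml train_dir=/tmp/train_dir experiment=trxl_baseline
#   python eval.py cfg=configs/trxl_baseline_eval.yaml train_dir=/tmp/train_dir experiment=trxl_baseline test.checkpoint=/path/to/checkpoint.pth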
--------------------------------------------------------------------------------
/brain_agent/envs/dmlab/dmlab_level_cache.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ctypes
3 | import multiprocessing
4 | import random
5 | import shutil
6 | from pathlib import Path
7 |
8 | from brain_agent.utils.logger import log
9 |
10 | LEVEL_SEEDS_FILE_EXT = 'dm_lvl_seeds'
11 | DMLAB_GLOBAL_LEVEL_CACHE = None  # populated by dmlab_ensure_global_cache_initialized()
12 | 
13 | def filename_to_level(filename):
14 | level = filename.split('.')[0]
15 | level = level[1:]
16 | return level
17 |
18 |
19 | def level_to_filename(level):
20 | filename = f'_{level}.{LEVEL_SEEDS_FILE_EXT}'
21 | return filename
22 |
23 |
24 | def read_seeds_file(filename, has_keys):
25 | seeds = []
26 |
27 | with open(filename, 'r') as seed_file:
28 | lines = seed_file.readlines()
29 | for line in lines:
30 | try:
31 | if has_keys:
32 | seed, cache_key = line.split(' ')
33 | else:
34 | seed = line
35 |
36 | seed = int(seed)
37 | seeds.append(seed)
38 | except Exception:
39 | pass
40 |
41 | return seeds
42 |
43 |
44 | class DmlabLevelCacheGlobal:
45 |
46 | def __init__(self, cache_dir, experiment_dir, levels):
47 | self.cache_dir = cache_dir
48 | self.experiment_dir = experiment_dir
49 |
50 | self.all_seeds = dict()
51 | self.available_seeds = dict()
52 | self.used_seeds = dict()
53 | self.num_seeds_used_in_current_run = dict()
54 | self.locks = dict()
55 |
56 | for lvl in levels:
57 | self.all_seeds[lvl] = []
58 | self.available_seeds[lvl] = []
59 | self.num_seeds_used_in_current_run[lvl] = multiprocessing.RawValue(ctypes.c_int32, 0)
60 | self.locks[lvl] = multiprocessing.Lock()
61 |
62 | lvl_seed_files = Path(os.path.join(cache_dir, '_contributed')).rglob(f'*.{LEVEL_SEEDS_FILE_EXT}')
63 | for lvl_seed_file in lvl_seed_files:
64 | lvl_seed_file = str(lvl_seed_file)
65 | level = filename_to_level(os.path.relpath(lvl_seed_file, cache_dir))
66 | self.all_seeds[level] = read_seeds_file(lvl_seed_file, has_keys=True)
67 | self.all_seeds[level] = list(set(self.all_seeds[level]))
68 |
69 | used_lvl_seeds_dir = os.path.join(self.experiment_dir, f'dmlab_used_lvl_seeds')
70 | os.makedirs(used_lvl_seeds_dir, exist_ok=True)
71 |
72 | used_seeds_files = Path(used_lvl_seeds_dir).rglob(f'*.{LEVEL_SEEDS_FILE_EXT}')
73 | self.used_seeds = dict()
74 | for used_seeds_file in used_seeds_files:
75 | used_seeds_file = str(used_seeds_file)
76 | level = filename_to_level(os.path.relpath(used_seeds_file, used_lvl_seeds_dir))
77 | self.used_seeds[level] = read_seeds_file(used_seeds_file, has_keys=False)
78 |
79 | self.used_seeds[level] = set(self.used_seeds[level])
80 |
81 | for lvl in levels:
82 | lvl_seeds = self.all_seeds.get(lvl, [])
83 | lvl_used_seeds = self.used_seeds.get(lvl, [])
84 |
85 | lvl_remaining_seeds = set(lvl_seeds) - set(lvl_used_seeds)
86 | self.available_seeds[lvl] = list(lvl_remaining_seeds)
87 |
88 | random.shuffle(self.available_seeds[lvl])
89 | log.debug('Env %s has %d remaining unused seeds', lvl, len(self.available_seeds[lvl]))
90 |
91 | def record_used_seed(self, level, seed):
92 | self.num_seeds_used_in_current_run[level].value += 1
93 |
94 | used_lvl_seeds_dir = os.path.join(self.experiment_dir, f'dmlab_used_lvl_seeds')
95 | used_seeds_filename = os.path.join(used_lvl_seeds_dir, level_to_filename(level))
96 | os.makedirs(os.path.dirname(used_seeds_filename), exist_ok=True)
97 |
98 | with open(used_seeds_filename, 'a') as fobj:
99 | fobj.write(f'{seed}\n')
100 |
101 | if level not in self.used_seeds:
102 | self.used_seeds[level] = {seed}
103 | else:
104 | self.used_seeds[level].add(seed)
105 |
106 | def get_unused_seed(self, level, random_state=None):
107 | with self.locks[level]:
108 | num_used_seeds = self.num_seeds_used_in_current_run[level].value
109 | if num_used_seeds >= len(self.available_seeds.get(level, [])):
110 |
111 | while True:
112 | if random_state is not None:
113 | new_seed = random_state.randint(0, 2 ** 31 - 1)
114 | else:
115 | new_seed = random.randint(0, 2 ** 31 - 1)
116 |
117 | if level not in self.used_seeds:
118 | break
119 |
120 | if new_seed in self.used_seeds[level]:
121 | pass
122 | else:
123 | break
124 | else:
125 | new_seed = self.available_seeds[level][num_used_seeds]
126 |
127 | self.record_used_seed(level, new_seed)
128 | return new_seed
129 |
130 | def add_new_level(self, level, seed, key, pk3_path):
131 | with self.locks[level]:
132 | num_used_seeds = self.num_seeds_used_in_current_run[level].value
133 | if num_used_seeds < len(self.available_seeds.get(level, [])):
134 | log.warning('We should only add new levels to cache if we ran out of pre-generated levels (seeds)')
135 | log.warning(
136 | 'Num used seeds: %d, available seeds: %d, level: %s, seed %r, key %r',
137 | num_used_seeds, len(self.available_seeds.get(level, [])), level, seed, key,
138 | )
139 |
140 | path = os.path.join(self.cache_dir, key)
141 | if not os.path.isfile(path):
142 | shutil.copyfile(pk3_path, path)
143 |
144 | lvl_seeds_filename = os.path.join(self.cache_dir, level_to_filename(level))
145 | os.makedirs(os.path.dirname(lvl_seeds_filename), exist_ok=True)
146 | with open(lvl_seeds_filename, 'a') as fobj:
147 | fobj.write(f'{seed} {key}\n')
148 |
149 |
150 | def dmlab_ensure_global_cache_initialized(experiment_dir, levels, level_cache_dir):
151 | global DMLAB_GLOBAL_LEVEL_CACHE
152 |
153 | assert multiprocessing.current_process().name == 'MainProcess', \
154 | 'make sure you initialize DMLab cache before child processes are forked'
155 |
156 | log.info('Initializing level cache...')
157 | DMLAB_GLOBAL_LEVEL_CACHE = DmlabLevelCacheGlobal(level_cache_dir, experiment_dir, levels)
158 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 | from faster_fifo import Queue, Empty
4 | import multiprocessing
5 | from tensorboardX import SummaryWriter
6 | import numpy as np
7 |
8 | from brain_agent.core.actor_worker import ActorWorker
9 | from brain_agent.core.policy_worker import PolicyWorker
10 | from brain_agent.core.shared_buffer import SharedBuffer
11 | from brain_agent.core.learner_worker import LearnerWorker
12 | from brain_agent.utils.cfg import Configs
13 | from brain_agent.utils.logger import log, init_logger
14 | from brain_agent.utils.utils import get_log_path, dict_of_list_put, get_summary_dir, AttrDict
15 | from brain_agent.envs.env_utils import create_env
16 |
17 |
18 | def main():
19 |
20 | cfg = Configs.get_defaults()
21 | cfg = Configs.override_from_file_name(cfg)
22 | cfg = Configs.override_from_cli(cfg)
23 |
24 | cfg_str = Configs.to_yaml(cfg)
25 | cfg = Configs.to_attr_dict(cfg)
26 |
27 | init_logger(cfg.log.log_level, get_log_path(cfg))
28 |
29 | log.info(f'Experiment configuration:\n{cfg_str}')
30 |
31 | tmp_env = create_env(cfg, env_config=None)
32 | action_space = tmp_env.action_space
33 | obs_space = tmp_env.observation_space
34 | level_info = tmp_env.level_info
35 | tmp_env.close()
36 |
37 | shared_buffer = SharedBuffer(cfg, obs_space, action_space)
38 |
39 | learner_worker_queue = Queue()
40 | policy_worker_queue = Queue()
41 | actor_worker_queues = [Queue(2 * 1000 * 1000) for _ in range(cfg.actor.num_workers)]
42 | policy_queue = Queue()
43 | report_queue = Queue(40 * 1000 * 1000)
44 |
45 | policy_lock = multiprocessing.Lock()
46 | resume_experience_collection_cv = multiprocessing.Condition()
47 |
48 | learner_worker = LearnerWorker(cfg, obs_space, action_space, level_info, report_queue, learner_worker_queue,
49 | policy_worker_queue,
50 | shared_buffer, policy_lock, resume_experience_collection_cv)
51 | learner_worker.start_process()
52 | learner_worker.init()
53 |
54 | policy_worker = PolicyWorker(cfg, obs_space, action_space, level_info, shared_buffer,
55 | policy_queue, actor_worker_queues, policy_worker_queue, report_queue, policy_lock, resume_experience_collection_cv)
56 | policy_worker.start_process() # init(), init_model() will be triggered from learner worker
57 |
58 | actor_workers = []
59 | for i in range(cfg.actor.num_workers):
60 | w = ActorWorker(cfg, obs_space, action_space, i, shared_buffer, actor_worker_queues[i], policy_queue,
61 | report_queue, learner_worker_queue)
62 | w.init()
63 | w.request_reset()
64 | actor_workers.append(w)
65 |
66 | summary_dir = get_summary_dir(cfg=cfg)
67 | writer = SummaryWriter(summary_dir) if cfg.dist.world_rank == 0 else None
68 |
69 | # Add configuration in tensorboard
70 | if cfg.dist.world_rank == 0:
71 | cfg_str = cfg_str.replace(' ', ' ').replace('\n', ' \n')
72 | writer.add_text('cfg', cfg_str, 0)
73 |
74 | stats = AttrDict()
75 | stats['episodic_stats'] = AttrDict()
76 |
77 | last_report = time.time()
78 | last_env_steps = 0
79 | terminate = False
80 | reports = []
81 |
82 | while not terminate:
83 | try:
84 | reports.extend(report_queue.get_many(timeout=0.1))
85 |
86 | if time.time() - last_report > cfg.log.report_interval:
87 | interval = time.time() - last_report
88 | last_report = time.time()
89 | terminate, last_env_steps = process_report(cfg, reports, writer, stats, last_env_steps, level_info,
90 | interval)
91 | reports = []
92 | except Empty:
93 | time.sleep(1.0)
94 | pass
95 |
96 | def process_report(cfg, reports, writer, stats, last_env_steps, level_info, interval):
97 | terminate = False
98 | env_steps = last_env_steps
99 |
100 | for report in reports:
101 | if report is not None:
102 | if 'terminate' in report:
103 | terminate = True
104 | if 'learner_env_steps' in report:
105 | env_steps = report['learner_env_steps']
106 | if 'train' in report:
107 | s = report['train']
108 | for k, v in s.items():
109 | dict_of_list_put(stats, f'train/{k}', v, cfg.log.num_stats_average)
110 | if 'episodic_stats' in report:
111 | s = report['episodic_stats']
112 | level_name = s['level_name']
113 | level_id = s['task_id']
114 |
115 | tag = f'_dmlab/{level_id:02d}_{level_name}_human_norm_score'
116 | dict_of_list_put(stats.episodic_stats, tag, s['hns'], cfg.log.num_stats_average)
117 |
118 | fps = (env_steps - last_env_steps) / interval
119 | dict_of_list_put(stats, f'train/_fps', fps, cfg.log.num_stats_average)
120 |
121 | key_timings = ['times_learner_worker', 'times_actor_worker', 'times_policy_worker']
122 | for key in key_timings:
123 | if key in report:
124 | for k, v in report[key].items():
125 | tag = key+'/'+k
126 | dict_of_list_put(stats, tag, v, cfg.log.num_stats_average)
127 |
128 |
129 | if writer is not None:
130 | for k, v in stats.items():
131 | if k == 'episodic_stats':
132 | hns = []
133 | for kk, vv in v.items():
134 | writer.add_scalar(kk, np.array(vv).mean(), env_steps)
135 | hns.append(np.array(vv).mean())
136 |
137 | if len(v.keys()) == level_info['num_levels']:
138 | hns = np.array(hns)
139 | capped_hns = np.clip(hns, None, 100)
140 | writer.add_scalar(f'_dmlab/000_mean_human_norm_score', hns.mean(), env_steps)
141 | writer.add_scalar(f'_dmlab/000_mean_capped_human_norm_score', capped_hns.mean(), env_steps)
142 | writer.add_scalar(f'_dmlab/000_median_human_norm_score', np.median(hns), env_steps)
143 | else:
144 | writer.add_scalar(k, np.array(v).mean(), env_steps)
145 |
146 | if env_steps >= cfg.optim.train_for_env_steps:
147 | terminate = True
148 | return terminate, env_steps
149 |
150 |
151 | if __name__ == '__main__':
152 | sys.exit(main())
153 |
--------------------------------------------------------------------------------
/brain_agent/envs/dmlab/dmlab_gym.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import shutil
4 | import hashlib
5 | import deepmind_lab
6 | import gym
7 | import numpy as np
8 |
9 | from brain_agent.envs.dmlab import dmlab_level_cache
10 | from brain_agent.envs.dmlab.dmlab30 import DMLAB_INSTRUCTIONS, DMLAB_MAX_INSTRUCTION_LEN, DMLAB_VOCABULARY_SIZE, \
11 | IMPALA_ACTION_SET, EXTENDED_ACTION_SET, EXTENDED_ACTION_SET_LARGE
12 | from brain_agent.utils.logger import log
13 |
14 |
15 | def string_to_hash_bucket(s, vocabulary_size):
16 | return (int(hashlib.md5(s.encode('utf-8')).hexdigest(), 16) % (vocabulary_size - 1)) + 1
17 |
18 |
19 | def dmlab_level_to_level_name(level):
20 | level_name = level.split('/')[-1]
21 | return level_name
22 |
23 |
24 | class DmlabGymEnv(gym.Env):
25 | def __init__(
26 | self, task_id, level, action_repeat, res_w, res_h, dataset_path,
27 | action_set, use_level_cache, level_cache_path, extra_cfg=None,
28 | ):
29 | self.width = res_w
30 | self.height = res_h
31 |
32 | self.main_observation = 'RGB_INTERLEAVED'
33 | self.instructions_observation = DMLAB_INSTRUCTIONS
34 |
35 | self.action_repeat = action_repeat
36 |
37 | self.random_state = None
38 |
39 | self.task_id = task_id
40 | self.level = level
41 | self.level_name = dmlab_level_to_level_name(self.level)
42 |
43 | self.cache = dmlab_level_cache.DMLAB_GLOBAL_LEVEL_CACHE
44 |
45 | self.instructions = np.zeros([DMLAB_MAX_INSTRUCTION_LEN], dtype=np.int32)
46 |
47 | observation_format = [self.main_observation]
48 | observation_format += [self.instructions_observation]
49 |
50 | config = {
51 | 'width': self.width,
52 | 'height': self.height,
53 | 'datasetPath': dataset_path,
54 | 'logLevel': 'error',
55 | }
56 |
57 | if extra_cfg is not None:
58 | config.update(extra_cfg)
59 | config = {k: str(v) for k, v in config.items()}
60 |
61 | self.use_level_cache = use_level_cache
62 | self.level_cache_path = level_cache_path
63 |
64 | env_level_cache = self if use_level_cache else None
65 | self.env_uses_level_cache = False # will be set to True when this env instance queries the cache
66 | self.last_reset_seed = None
67 |
68 | if env_level_cache is not None:
69 | if not isinstance(self.cache, dmlab_level_cache.DmlabLevelCacheGlobal):
70 | raise Exception(
71 |                     'DMLab global level cache object is not initialized! Make sure to call '
72 |                     'dmlab_ensure_global_cache_initialized() in the main thread before you fork any child processes '
73 | 'or create any DMLab envs'
74 | )
75 |
76 | self.dmlab = deepmind_lab.Lab(
77 | level, observation_format, config=config, renderer='software', level_cache=env_level_cache,
78 | )
79 |
80 | if action_set == 'impala_action_set':
81 | self.action_set = IMPALA_ACTION_SET
82 | elif action_set == 'extended_action_set':
83 | self.action_set = EXTENDED_ACTION_SET
84 | elif action_set == 'extended_action_set_large':
85 | self.action_set = EXTENDED_ACTION_SET_LARGE
86 | self.action_list = np.array(self.action_set, dtype=np.intc) # DMLAB requires intc type for actions
87 |
88 | self.last_observation = None
89 |
90 | self.action_space = gym.spaces.Discrete(len(self.action_set))
91 |
92 | self.observation_space = gym.spaces.Dict(
93 | obs=gym.spaces.Box(low=0, high=255, shape=(self.height, self.width, 3), dtype=np.uint8)
94 | )
95 | self.observation_space.spaces[self.instructions_observation] = gym.spaces.Box(
96 | low=0, high=DMLAB_VOCABULARY_SIZE, shape=[DMLAB_MAX_INSTRUCTION_LEN], dtype=np.int32,
97 | )
98 |
99 | self.seed()
100 |
101 | def seed(self, seed=None):
102 | if seed is None:
103 | initial_seed = random.randint(0, int(1e9))
104 | else:
105 | initial_seed = seed
106 |
107 | self.random_state = np.random.RandomState(seed=initial_seed)
108 | return [initial_seed]
109 |
110 | def format_obs_dict(self, env_obs_dict):
111 |         """We traditionally use the 'obs' key for the 'main' observation."""
112 |
113 | env_obs_dict['obs'] = env_obs_dict.pop(self.main_observation)
114 |
115 | instr = env_obs_dict.get(self.instructions_observation)
116 | self.instructions[:] = 0
117 | if instr is not None:
118 | instr_words = instr.split()
119 | for i, word in enumerate(instr_words):
120 | self.instructions[i] = string_to_hash_bucket(word, DMLAB_VOCABULARY_SIZE)
121 |
122 | env_obs_dict[self.instructions_observation] = self.instructions
123 |
124 | return env_obs_dict
125 |
126 | def reset(self):
127 | if self.use_level_cache:
128 | self.last_reset_seed = self.cache.get_unused_seed(self.level, self.random_state)
129 | else:
130 | self.last_reset_seed = self.random_state.randint(0, 2 ** 31 - 1)
131 |
132 | self.dmlab.reset(seed=self.last_reset_seed)
133 | self.last_observation = self.format_obs_dict(self.dmlab.observations())
134 | self.episodic_reward = 0
135 | return self.last_observation
136 |
137 | def step(self, action):
138 |
139 | reward = self.dmlab.step(self.action_list[action], num_steps=self.action_repeat)
140 | done = not self.dmlab.is_running()
141 |
142 | self.episodic_reward += reward
143 | info = {'num_frames': self.action_repeat}
144 |
145 | if not done:
146 | obs_dict = self.format_obs_dict(self.dmlab.observations())
147 | self.last_observation = obs_dict
148 | if done:
149 | self.reset()
150 |
151 | return self.last_observation, reward, done, info
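        # Note: when the episode ends, reset() has just been called above, so the returned
        # observation is already the first frame of the next episode (auto-reset behavior).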
152 |
153 |
154 | def close(self):
155 | self.dmlab.close()
156 |
157 |
158 | def fetch(self, key, pk3_path):
159 | if not self.env_uses_level_cache:
160 | self.env_uses_level_cache = True
161 |
162 | path = os.path.join(self.level_cache_path, key)
163 |
164 | if os.path.isfile(path):
165 | shutil.copyfile(path, pk3_path)
166 | return True
167 | else:
168 | log.warning('Cache miss in environment %s key: %s!', self.level_name, key)
169 | return False
170 |
171 | def write(self, key, pk3_path):
172 | log.debug('Add new level to cache! Level %s seed %r key %s', self.level_name, self.last_reset_seed, key)
173 | self.cache.add_new_level(self.level, self.last_reset_seed, key, pk3_path)
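
        # Note on the cache protocol: DeepMind Lab invokes fetch()/write() on the object passed
        # as level_cache. fetch() copies a previously compiled level (.pk3) from the local cache
        # into pk3_path and returns True on a hit; write() stores a newly compiled level under
        # its key and records the (seed, key) pair so the level can be reused in later runs.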
174 |
--------------------------------------------------------------------------------
/brain_agent/envs/dmlab/dmlab30.py:
--------------------------------------------------------------------------------
1 | DMLAB_INSTRUCTIONS = 'INSTR'
2 | DMLAB_VOCABULARY_SIZE = 1000
3 | DMLAB_MAX_INSTRUCTION_LEN = 16
4 |
5 |
6 | HUMAN_SCORES = {
7 | 'rooms_collect_good_objects_test': 10,
8 | 'rooms_exploit_deferred_effects_test': 85.65,
9 | 'rooms_select_nonmatching_object': 65.9,
10 | 'rooms_watermaze': 54,
11 | 'rooms_keys_doors_puzzle': 53.8,
12 | 'language_select_described_object': 389.5,
13 | 'language_select_located_object': 280.7,
14 | 'language_execute_random_task': 254.05,
15 | 'language_answer_quantitative_question': 184.5,
16 | 'lasertag_one_opponent_small': 12.65,
17 | 'lasertag_three_opponents_small': 18.55,
18 | 'lasertag_one_opponent_large': 18.6,
19 | 'lasertag_three_opponents_large': 31.5,
20 | 'natlab_fixed_large_map': 36.9,
21 | 'natlab_varying_map_regrowth': 24.45,
22 | 'natlab_varying_map_randomized': 42.35,
23 | 'skymaze_irreversible_path_hard': 100,
24 | 'skymaze_irreversible_path_varied': 100,
25 | 'psychlab_arbitrary_visuomotor_mapping': 58.75,
26 | 'psychlab_continuous_recognition': 58.3,
27 | 'psychlab_sequential_comparison': 39.5,
28 | 'psychlab_visual_search': 78.5,
29 | 'explore_object_locations_small': 74.45,
30 | 'explore_object_locations_large': 65.65,
31 | 'explore_obstructed_goals_small': 206,
32 | 'explore_obstructed_goals_large': 119.5,
33 | 'explore_goal_locations_small': 267.5,
34 | 'explore_goal_locations_large': 194.5,
35 | 'explore_object_rewards_few': 77.7,
36 | 'explore_object_rewards_many': 106.7,
37 | }
38 |
39 |
40 | RANDOM_SCORES = {
41 | 'rooms_collect_good_objects_test': 0.073,
42 | 'rooms_exploit_deferred_effects_test': 8.501,
43 | 'rooms_select_nonmatching_object': 0.312,
44 | 'rooms_watermaze': 4.065,
45 | 'rooms_keys_doors_puzzle': 4.135,
46 | 'language_select_described_object': -0.07,
47 | 'language_select_located_object': 1.929,
48 | 'language_execute_random_task': -5.913,
49 | 'language_answer_quantitative_question': -0.33,
50 | 'lasertag_one_opponent_small': -0.224,
51 | 'lasertag_three_opponents_small': -0.214,
52 | 'lasertag_one_opponent_large': -0.083,
53 | 'lasertag_three_opponents_large': -0.102,
54 | 'natlab_fixed_large_map': 2.173,
55 | 'natlab_varying_map_regrowth': 2.989,
56 | 'natlab_varying_map_randomized': 7.346,
57 | 'skymaze_irreversible_path_hard': 0.1,
58 | 'skymaze_irreversible_path_varied': 14.4,
59 | 'psychlab_arbitrary_visuomotor_mapping': 0.163,
60 | 'psychlab_continuous_recognition': 0.224,
61 | 'psychlab_sequential_comparison': 0.129,
62 | 'psychlab_visual_search': 0.085,
63 | 'explore_object_locations_small': 3.575,
64 | 'explore_object_locations_large': 4.673,
65 | 'explore_obstructed_goals_small': 6.76,
66 | 'explore_obstructed_goals_large': 2.61,
67 | 'explore_goal_locations_small': 7.66,
68 | 'explore_goal_locations_large': 3.14,
69 | 'explore_object_rewards_few': 2.073,
70 | 'explore_object_rewards_many': 2.438,
71 | }
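
# A minimal sketch of how these tables are combined into the human-normalized score ('hns')
# reported during training/evaluation; the helper below is illustrative, the actual computation
# lives in the env wrappers rather than in this file:
def human_normalized_score(level_name, raw_score):
    """Scale a raw episode return so that random play maps to 0 and human play to 100."""
    r, h = RANDOM_SCORES[level_name], HUMAN_SCORES[level_name]
    return 100.0 * (raw_score - r) / (h - r)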
72 |
73 | DMLAB_LEVELS_BY_ENVNAME = {
74 | 'dmlab_30':
75 | ['contributed/dmlab30/rooms_collect_good_objects_train',
76 | 'contributed/dmlab30/rooms_exploit_deferred_effects_train',
77 | 'contributed/dmlab30/rooms_select_nonmatching_object',
78 | 'contributed/dmlab30/rooms_watermaze',
79 | 'contributed/dmlab30/rooms_keys_doors_puzzle',
80 | 'contributed/dmlab30/language_select_described_object',
81 | 'contributed/dmlab30/language_select_located_object',
82 | 'contributed/dmlab30/language_execute_random_task',
83 | 'contributed/dmlab30/language_answer_quantitative_question',
84 | 'contributed/dmlab30/lasertag_one_opponent_small',
85 | 'contributed/dmlab30/lasertag_three_opponents_small',
86 | 'contributed/dmlab30/lasertag_one_opponent_large',
87 | 'contributed/dmlab30/lasertag_three_opponents_large',
88 | 'contributed/dmlab30/natlab_fixed_large_map',
89 | 'contributed/dmlab30/natlab_varying_map_regrowth',
90 | 'contributed/dmlab30/natlab_varying_map_randomized',
91 | 'contributed/dmlab30/skymaze_irreversible_path_hard',
92 | 'contributed/dmlab30/skymaze_irreversible_path_varied',
93 | 'contributed/dmlab30/psychlab_arbitrary_visuomotor_mapping',
94 | 'contributed/dmlab30/psychlab_continuous_recognition',
95 | 'contributed/dmlab30/psychlab_sequential_comparison',
96 | 'contributed/dmlab30/psychlab_visual_search',
97 | 'contributed/dmlab30/explore_object_locations_small',
98 | 'contributed/dmlab30/explore_object_locations_large',
99 | 'contributed/dmlab30/explore_obstructed_goals_small',
100 | 'contributed/dmlab30/explore_obstructed_goals_large',
101 | 'contributed/dmlab30/explore_goal_locations_small',
102 | 'contributed/dmlab30/explore_goal_locations_large',
103 | 'contributed/dmlab30/explore_object_rewards_few',
104 | 'contributed/dmlab30/explore_object_rewards_many'
105 | ],
106 | 'dmlab_30_test':
107 | ['contributed/dmlab30/rooms_collect_good_objects_test',
108 | 'contributed/dmlab30/rooms_exploit_deferred_effects_test',
109 | 'contributed/dmlab30/rooms_select_nonmatching_object',
110 | 'contributed/dmlab30/rooms_watermaze',
111 | 'contributed/dmlab30/rooms_keys_doors_puzzle',
112 | 'contributed/dmlab30/language_select_described_object',
113 | 'contributed/dmlab30/language_select_located_object',
114 | 'contributed/dmlab30/language_execute_random_task',
115 | 'contributed/dmlab30/language_answer_quantitative_question',
116 | 'contributed/dmlab30/lasertag_one_opponent_small',
117 | 'contributed/dmlab30/lasertag_three_opponents_small',
118 | 'contributed/dmlab30/lasertag_one_opponent_large',
119 | 'contributed/dmlab30/lasertag_three_opponents_large',
120 | 'contributed/dmlab30/natlab_fixed_large_map',
121 | 'contributed/dmlab30/natlab_varying_map_regrowth',
122 | 'contributed/dmlab30/natlab_varying_map_randomized',
123 | 'contributed/dmlab30/skymaze_irreversible_path_hard',
124 | 'contributed/dmlab30/skymaze_irreversible_path_varied',
125 | 'contributed/dmlab30/psychlab_arbitrary_visuomotor_mapping',
126 | 'contributed/dmlab30/psychlab_continuous_recognition',
127 | 'contributed/dmlab30/psychlab_sequential_comparison',
128 | 'contributed/dmlab30/psychlab_visual_search',
129 | 'contributed/dmlab30/explore_object_locations_small',
130 | 'contributed/dmlab30/explore_object_locations_large',
131 | 'contributed/dmlab30/explore_obstructed_goals_small',
132 | 'contributed/dmlab30/explore_obstructed_goals_large',
133 | 'contributed/dmlab30/explore_goal_locations_small',
134 | 'contributed/dmlab30/explore_goal_locations_large',
135 | 'contributed/dmlab30/explore_object_rewards_few',
136 | 'contributed/dmlab30/explore_object_rewards_many'
137 | ]
138 | }
139 |
140 | IMPALA_ACTION_SET = (
141 | (0, 0, 0, 1, 0, 0, 0), # Forward
142 | (0, 0, 0, -1, 0, 0, 0), # Backward
143 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left
144 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right
145 | (-20, 0, 0, 0, 0, 0, 0), # Look Left
146 | (20, 0, 0, 0, 0, 0, 0), # Look Right
147 | (-20, 0, 0, 1, 0, 0, 0), # Look Left + Forward
148 | (20, 0, 0, 1, 0, 0, 0), # Look Right + Forward
149 | (0, 0, 0, 0, 1, 0, 0), # Fire.
150 | )
151 |
152 |
153 | EXTENDED_ACTION_SET = (
154 | (0, 0, 0, 1, 0, 0, 0), # Forward
155 | (0, 0, 0, -1, 0, 0, 0), # Backward
156 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left
157 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right
158 | (-10, 0, 0, 0, 0, 0, 0), # Small Look Left
159 | (10, 0, 0, 0, 0, 0, 0), # Small Look Right
160 | (-60, 0, 0, 0, 0, 0, 0), # Large Look Left
161 | (60, 0, 0, 0, 0, 0, 0), # Large Look Right
162 | (0, 10, 0, 0, 0, 0, 0), # Look Down
163 | (0, -10, 0, 0, 0, 0, 0), # Look Up
164 | (-10, 0, 0, 1, 0, 0, 0), # Forward + Small Look Left
165 | (10, 0, 0, 1, 0, 0, 0), # Forward + Small Look Right
166 | (-60, 0, 0, 1, 0, 0, 0), # Forward + Large Look Left
167 | (60, 0, 0, 1, 0, 0, 0), # Forward + Large Look Right
168 | (0, 0, 0, 0, 1, 0, 0), # Fire.
169 | )
170 |
171 | EXTENDED_ACTION_SET_LARGE = (
172 | (0, 0, 0, 1, 0, 0, 0), # Forward
173 | (0, 0, 0, -1, 0, 0, 0), # Backward
174 | (0, 0, -1, 0, 0, 0, 0), # Strafe Left
175 | (0, 0, 1, 0, 0, 0, 0), # Strafe Right
176 | (-10, 0, 0, 0, 0, 0, 0), # Small Look Left
177 | (10, 0, 0, 0, 0, 0, 0), # Small Look Right
178 | (-60, 0, 0, 0, 0, 0, 0), # Large Look Left
179 | (60, 0, 0, 0, 0, 0, 0), # Large Look Right
180 | (0, 10, 0, 0, 0, 0, 0), # Look Down
181 | (0, -10, 0, 0, 0, 0, 0), # Look Up
182 | (0, 60, 0, 0, 0, 0, 0), # Large Look Down
183 | (0, -60, 0, 0, 0, 0, 0), # Large Look Up
184 | (-10, 0, 0, 1, 0, 0, 0), # Forward + Small Look Left
185 | (10, 0, 0, 1, 0, 0, 0), # Forward + Small Look Right
186 | (-60, 0, 0, 1, 0, 0, 0), # Forward + Large Look Left
187 | (60, 0, 0, 1, 0, 0, 0), # Forward + Large Look Right
188 | (0, 0, 0, 0, 1, 0, 0), # Fire.
189 | )
190 |
--------------------------------------------------------------------------------
/brain_agent/core/agents/dmlab_multitask_agent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from brain_agent.core.agents.agent_abc import ActorCriticBase
4 | from brain_agent.envs.dmlab.dmlab_model import DmlabEncoder
5 | from brain_agent.core.models.transformer import MemTransformerLM
6 | from brain_agent.core.models.rnn import LSTM
7 | from brain_agent.core.algos.popart import update_parameters, update_mu_sigma
8 | from brain_agent.core.models.model_utils import normalize_obs_return
9 | from brain_agent.core.models.action_distributions import sample_actions_log_probs
10 | from brain_agent.utils.utils import AttrDict
11 |
12 | from brain_agent.core.algos.aux_future_predict import FuturePredict
13 |
14 |
15 | class DMLabMultiTaskAgent(ActorCriticBase):
16 | def __init__(self, cfg, action_space, obs_space, num_levels, need_half=False):
17 | super().__init__(action_space, cfg)
18 |
19 | self.encoder = DmlabEncoder(cfg, obs_space)
20 |
21 | core_input_size = self.encoder.get_encoder_out_size()
22 | if cfg.model.extended_input:
23 | core_input_size += action_space.n + 1
24 |
25 | if cfg.model.core.core_type == "trxl":
26 | self.core = MemTransformerLM(cfg, n_layer=cfg.model.core.n_layer, n_head=cfg.model.core.n_heads,
27 | d_head=cfg.model.core.d_head, d_model=core_input_size,
28 | d_inner=cfg.model.core.d_inner,
29 | mem_len=cfg.model.core.mem_len, pre_lnorm=True)
30 |
31 | elif cfg.model.core.core_type == "rnn":
32 | self.core = LSTM(cfg, core_input_size)
33 | else:
34 |             raise NotImplementedError(f'Unsupported model core_type: {cfg.model.core.core_type}')
35 |
36 | core_out_size = self.core.get_core_out_size()
37 |
38 | if self.cfg.model.use_popart:
39 | self.register_buffer('mu', torch.zeros(num_levels, requires_grad=False))
40 | self.register_buffer('nu', torch.ones(num_levels, requires_grad=False))
41 | self.critic_linear = nn.Linear(core_out_size, num_levels)
42 | self.beta = self.cfg.model.popart_beta
43 | else:
44 | self.critic_linear = nn.Linear(core_out_size, 1)
45 |
46 | if cfg.learner.use_aux_future_pred_loss and cfg.model.core.core_type == "trxl":
47 | self.future_pred_module = FuturePredict(self.cfg, self.encoder.basic_encoder.input_ch,
48 | self.encoder.basic_encoder.conv_head_out_size,
49 | core_out_size,
50 | action_space)
51 |
52 | self.action_parameterization = self.get_action_parameterization(core_out_size)
53 |
54 | self.action_parameterization.apply(self.initialize)
55 | self.critic_linear.apply(self.initialize)
56 | self.train()
57 |
58 | self.need_half = need_half
59 | if self.need_half:
60 | self.half()
61 |
62 | def initialize(self, layer):
63 | def init_weight(weight):
64 | nn.init.normal_(weight, 0.0, 0.02)
65 |
66 | def init_bias(bias):
67 | nn.init.constant_(bias, 0.0)
68 |
69 | classname = layer.__class__.__name__
70 | if classname.find('Linear') != -1:
71 | if hasattr(layer, 'weight') and layer.weight is not None:
72 | init_weight(layer.weight)
73 | if hasattr(layer, 'bias') and layer.bias is not None:
74 | init_bias(layer.bias)
75 |
76 | # TODO: functionalize popart
77 | # def update_parameters(self, mu, sigma, oldmu, oldsigma):
78 | # self.critic_linear.weight.data, self.critic_linear.bias.data = \
79 | # update_parameters(self.critic_linear.weight, self.critic_linear.bias, mu, sigma, oldmu, oldsigma)
80 |
81 | def update_parameters(self, mu, sigma, oldmu, oldsigma):
82 | self.critic_linear.weight.data = (self.critic_linear.weight.t() * oldsigma / sigma).t()
83 | self.critic_linear.bias.data = (oldsigma * self.critic_linear.bias + oldmu - mu) / sigma
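        # Why this preserves outputs: with W' = W * sigma_old / sigma_new and
        # b' = (sigma_old * b + mu_old - mu_new) / sigma_new, the denormalized value
        # sigma_new * (W' x + b') + mu_new equals sigma_old * (W x + b) + mu_old,
        # so updating the PopArt statistics does not change the unnormalized predictions.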
84 |
85 | def update_mu_sigma(self, vs, task_ids, cfg=None):
86 | oldnu = self.nu.clone()
87 | oldsigma = torch.sqrt(oldnu - self.mu ** 2)
88 | oldsigma[torch.isnan(oldsigma)] = self.cfg.model.popart_clip_min
89 | clamp_max = 1e4 if hasattr(self, 'need_half') and self.need_half else 1e6
90 | oldsigma = torch.clamp(oldsigma, min=cfg.model.popart_clip_min, max=clamp_max)
91 | oldmu = self.mu.clone()
92 |
93 | vs = vs.reshape(-1, self.cfg.optim.rollout)
94 | # same task ids over all time steps within a single episode
95 | task_ids_per_epi = task_ids.reshape(-1, self.cfg.optim.rollout)[:, 0]
96 | for i in range(len(task_ids_per_epi)):
97 | task_id = task_ids_per_epi[i]
98 | v = torch.mean(vs[i])
99 | self.mu[task_id] = (1 - self.beta) * self.mu[task_id] + self.beta * v
100 | self.nu[task_id] = (1 - self.beta) * self.nu[task_id] + self.beta * (v ** 2)
101 |
102 | sigma = torch.sqrt(self.nu - self.mu ** 2)
103 | sigma[torch.isnan(sigma)] = self.cfg.model.popart_clip_min
104 | sigma = torch.clamp(sigma, min=cfg.model.popart_clip_min, max=clamp_max)
105 |
106 | return self.mu, sigma, oldmu, oldsigma
107 |
108 | def forward_head(self, obs_dict, actions, rewards, decode=False):
109 | obs_dict = normalize_obs_return(obs_dict, self.cfg) # before normalize, [0,255]
110 | if self.need_half:
111 | obs_dict['obs'] = obs_dict['obs'].half()
112 |
113 | x = self.encoder(obs_dict, decode=decode)
114 | if decode:
115 | self.reconstruction = self.encoder.basic_encoder.reconstruction
116 | else:
117 | self.reconstruction = None
118 |
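        # When model.extended_input is set, the encoder output is extended with the one-hot previous
        # action and the previous reward clipped to [-1, 1]; actions equal to -1 (no previous action,
        # e.g. right after a reset) are zeroed out in the one-hot encoding.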
119 | x_extended = []
120 | if self.cfg.model.extended_input:
121 | assert torch.min(actions) >= -1 and torch.max(actions) < self.action_space.n
122 | done_ids = actions.eq(-1).nonzero(as_tuple=False)
123 | actions[done_ids] = 0
124 | prev_actions = nn.functional.one_hot(actions, self.action_space.n).float()
125 | prev_actions[done_ids] = 0.
126 | x_extended.append(prev_actions)
127 | x_extended.append(rewards.clamp(-1, 1).unsqueeze(1))
128 |
129 | x = torch.cat([x] + x_extended, dim=-1)
130 |
131 | if self.need_half:
132 | x = x.half()
133 |
134 | return x
135 |
136 | def forward_core_transformer(self, head_output, mems=None, mem_begin_index=None, dones=None, from_learner=False):
137 | x, new_mems = self.core(head_output, mems, mem_begin_index, dones=dones, from_learner=from_learner)
138 | return x, new_mems
139 |
140 | def forward_core_rnn(self, head_output, rnn_states, dones, is_seq=None):
141 | x, new_rnn_states = self.core(head_output, rnn_states, dones, is_seq)
142 | return x, new_rnn_states
143 |
144 | def forward_tail(self, core_output, task_ids, with_action_distribution=False):
145 | values = self.critic_linear(core_output)
146 | normalized_values = values.clone()
147 | sigmas = torch.ones((values.size(0), 1), requires_grad=False)
148 | mus = torch.zeros((values.size(0), 1), requires_grad=False)
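        # PopArt de-normalization: gather each sample's per-task statistics and map the normalized
        # critic output back to the return scale, values = normalized_values * sigma + mu,
        # where sigma = sqrt(nu - mu^2) is clamped to avoid degenerate values.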
149 | if self.cfg.model.use_popart:
150 | normalized_values = normalized_values.gather(dim=1, index=task_ids)
151 | with torch.no_grad():
152 | nus = self.nu.index_select(dim=0, index=task_ids.squeeze(1)).unsqueeze(1)
153 | mus = self.mu.index_select(dim=0, index=task_ids.squeeze(1)).unsqueeze(1)
154 | sigmas = torch.sqrt(nus - mus ** 2)
155 | sigmas[torch.isnan(sigmas)] = self.cfg.model.popart_clip_min
156 | clamp_max = 1e4 if self.need_half else 1e6
157 | sigmas = torch.clamp(sigmas, min=self.cfg.model.popart_clip_min, max=clamp_max)
158 | values = normalized_values * sigmas + mus
159 |
160 | action_distribution_params, action_distribution = self.action_parameterization(core_output)
161 |
162 | actions, log_prob_actions = sample_actions_log_probs(action_distribution)
163 |
164 | result = AttrDict(dict(
165 | actions=actions,
166 | action_logits=action_distribution_params,
167 | log_prob_actions=log_prob_actions,
168 | values=values,
169 | normalized_values=normalized_values,
170 | sigmas=sigmas,
171 | mus=mus
172 | ))
173 |
174 | if with_action_distribution:
175 | result.action_distribution = action_distribution
176 |
177 | return result
178 |
179 | def forward(self, obs_dict, actions, rewards, mems=None, mem_begin_index=None, rnn_states=None, dones=None, is_seq=None,
180 | task_ids=None, with_action_distribution=False, from_learner=False):
181 | x = self.forward_head(obs_dict, actions, rewards)
182 |
183 | if self.cfg.model.core.core_type == 'trxl':
184 | x, new_mems = self.forward_core_transformer(x, mems, mem_begin_index, from_learner=from_learner)
185 | elif self.cfg.model.core.core_type == 'rnn':
186 | x, new_rnn_states = self.forward_core_rnn(x, rnn_states, dones, is_seq)
187 |
188 | assert not x.isnan().any()
189 |
190 | result = self.forward_tail(x, task_ids, with_action_distribution=with_action_distribution)
191 |
192 | if self.cfg.model.core.core_type == "trxl":
193 | result.mems = new_mems
194 | elif self.cfg.model.core.core_type == 'rnn':
195 | result.rnn_states = new_rnn_states
196 |
197 | return result
198 |
--------------------------------------------------------------------------------
/brain_agent/core/algos/aux_future_predict.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from brain_agent.core.models.model_utils import normalize_obs_return
7 | from brain_agent.core.models.resnet import ResBlock
8 | from brain_agent.core.models.causal_transformer import CausalTransformer
9 |
10 | '''
11 | To learn useful representations, this module minimizes the difference between the predicted future
12 | observations and the real observations for k (2~10) steps ahead. It uses a causal transformer in an
13 | autoregressive manner to predict the state transitions, given the actual action sequence and the state embedding of the current step.
14 |
15 | For the current state embedding, we use the outputs of the TrXL core. These state embeddings are concatenated
16 | with the actions actually taken and then fed into the causal transformer. The causal transformer iteratively
17 | produces the next state embeddings in an autoregressive manner.
18 |
19 | After producing all the future state embeddings for K (2~10) steps, we apply a transposed convolutional
20 | decoding layer to map each state embedding back into the original image observation space. Finally, we calculate
21 | the L2 distance between the decoded image outputs and the real images.
22 |
23 | With this auxiliary loss added, the mean HNS (human normalized score) over the 30 tasks increased from 123.6 to 128.0,
24 | although the capped mean HNS decreased from 91.25 to 90.53.
25 | You can enable this module by setting the argument learner.use_aux_future_pred_loss to True.
26 |
27 | '''
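# Rough shape of the computation described above (illustrative pseudocode only; names such as
# trxl_core_output and conv_transpose_decoder are schematic, the real entry point is calc_loss):
#   h_t = trxl_core_output[t]                                   # current state embedding
#   for i in range(K):                                          # K sampled from [2, 10]
#       token = concat(h_{t+i}, one_hot(a_{t+i}))               # embedding + actual action taken
#       h_{t+i+1} = causal_transformer(token)                   # autoregressive next-state prediction
#   obs_hat = tanh(conv_transpose_decoder(g(h_{t+1..t+K})))     # decode back to image space
#   loss = mean((obs_hat - obs_{t+1..t+K}) ** 2)                # masked at episode boundaries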
28 |
29 | class FuturePredict(nn.Module):
30 | def __init__(self, cfg, encoder_input_ch, conv_head_out_size, core_out_size, action_space):
31 | super(FuturePredict, self).__init__()
32 | self.cfg = cfg
33 | self.action_space = action_space
34 | self.n_action = action_space.n
35 | self.horizon_k: int = 10
36 | self.time_subsample: int = 6
37 | self.forward_subsample: int = 2
38 | self.core_out_size = core_out_size
39 | self.conv_head_out_size = conv_head_out_size
40 |
41 | if isinstance(action_space, gym.spaces.Discrete):
42 | self.action_sizes = [action_space.n]
43 | else:
44 | self.action_sizes = [space.n for space in action_space.spaces]
45 |
46 | self.g = nn.Sequential(
47 | nn.Linear(core_out_size, 256),
48 | nn.ReLU(),
49 | nn.Linear(256, conv_head_out_size)
50 | )
51 |
52 | self.mse_loss = nn.MSELoss(reduction='none')
53 | self.causal_transformer = CausalTransformer(core_out_size, action_space.n, pre_lnorm=True)
54 | self.causal_transformer_window = self.causal_transformer.mem_len
55 |
56 | if cfg.model.encoder.encoder_subtype == 'resnet_impala':
57 | resnet_conf = [[16, 2], [32, 2], [32, 2]]
58 | elif cfg.model.encoder.encoder_subtype == 'resnet_impala_large':
59 | resnet_conf = [[32, 2], [64, 2], [64, 2]]
60 | self.conv_out_ch = resnet_conf[-1][0]
61 | layers_decoder = list()
62 | curr_input_channels = encoder_input_ch
63 | for i, (out_channels, res_blocks) in enumerate(resnet_conf):
64 |
65 | for j in range(res_blocks):
66 | layers_decoder.append(ResBlock(cfg, curr_input_channels, curr_input_channels))
67 | layers_decoder.append(
68 | nn.ConvTranspose2d(out_channels, curr_input_channels, kernel_size=3, stride=2,
69 | padding=1, output_padding=1)
70 | )
71 | curr_input_channels = out_channels
72 |
73 | layers_decoder.reverse()
74 | self.layers_decoder = nn.Sequential(*layers_decoder)
75 |
76 | def calc_loss(self, mb, b_t, mems, mems_actions, mems_dones, mem_begin_index, num_traj, recurrence):
77 | not_dones = (1.0 - mb.dones).view(num_traj, recurrence, 1)
78 | mask_res = self._build_mask_and_subsample(not_dones)
79 | (forward_mask, unroll_subsample, time_subsample, max_k) = mask_res
80 |
81 | actions_raw = mb.actions.view(num_traj, recurrence, -1).long()
82 | mems = torch.split(mems, self.core_out_size, dim=-1)[-1].transpose(0,1)
83 | mem_len = mems.size(1)
84 | mems_actions = mems_actions.transpose(0,1)
85 | b_t = b_t.reshape(num_traj, recurrence, -1)
86 | obs = normalize_obs_return(mb.obs, self.cfg)
87 | obs = obs['obs'].reshape(num_traj, recurrence, 3, 72, 96)
88 |
89 | cat_out = torch.cat([mems, b_t], dim=1)
90 | cat_action = torch.cat([mems_actions, actions_raw], dim=1)
91 | mems_not_dones = 1.0 - mems_dones.float()
92 | cat_not_dones = torch.cat([mems_not_dones, not_dones], dim=1)
93 | mem_begin_index = torch.tensor(mem_begin_index, device=cat_out.device)
94 |
95 |         x, y, z, w, forward_targets = self._get_transformer_input(obs, cat_out, cat_action, cat_not_dones, time_subsample, max_k, mem_len, mem_begin_index)
96 |         h_pred_stack, done_mask = self._make_transformer_pred(x, y, z, w, num_traj, max_k)
97 | h_pred_stack = h_pred_stack.index_select(0, unroll_subsample)
98 | final_pred = self.g(h_pred_stack)
99 |
100 | x_dec = final_pred.view(-1, self.conv_out_ch, 9, 12)
101 | for i in range(len(self.layers_decoder)):
102 | layer_decoder = self.layers_decoder[i]
103 | x_dec = layer_decoder(x_dec)
104 | x_dec = torch.tanh(x_dec)
105 |
106 | with torch.no_grad():
107 | forward_targets = forward_targets.flatten(0, 1)
108 | forward_targets = forward_targets.transpose(0, 1)
109 | forward_targets = forward_targets.index_select(0, unroll_subsample)
110 | forward_targets = forward_targets.flatten(0, 1)
111 |
112 | loss = self.mse_loss(x_dec, forward_targets)
113 | loss = loss.view(loss.size()[0], -1).mean(-1, keepdim=True)
114 |
115 | final_mask = torch.logical_and(forward_mask, done_mask)
116 | loss = torch.masked_select(loss, final_mask.flatten(0,1)).mean()
117 | loss = loss * self.cfg.learner.aux_future_pred_loss_coeff
118 | return loss
119 |
120 | def _make_transformer_pred(self, states, actions, not_dones, mem_begin_index, num_traj, max_k):
121 | actions = nn.functional.one_hot(actions.long(), num_classes=self.n_action).squeeze(3).float()
122 | actions = actions.view(self.time_subsample*num_traj, -1, self.n_action).transpose(0,1)
123 | states = states.view(self.time_subsample*num_traj, -1, self.core_out_size).transpose(0,1)
124 | not_dones = not_dones.view(self.time_subsample*num_traj, -1, 1).transpose(0,1)
125 | dones = 1.0 - not_dones
126 | tokens = torch.cat([states, actions], dim=2)
127 | input_tokens_past = tokens[:self.causal_transformer_window-1,:,:]
128 | input_token_current = tokens[self.causal_transformer_window-1,:,:].unsqueeze(0)
129 | input_token = torch.cat([input_tokens_past, input_token_current], dim=0)
130 |
131 | mem_begin_index = mem_begin_index.view(-1)
132 |
133 | y, mems, mem_begin_index = self.causal_transformer(input_token, mem_begin_index, num_traj, mems=None)
134 |
135 | lst = []
136 | for mem in mems:
137 | lst.append(mem)
138 | mems = lst
139 |
140 | y = y[-1].unsqueeze(0)
141 | out = [y]
142 | for i in range(max_k-1):
143 | new_input = torch.cat([y, actions[self.causal_transformer_window+i].unsqueeze(0)], dim=-1)
144 | y, mems, mem_begin_index = self.causal_transformer(new_input, mem_begin_index, num_traj, mems=mems)
145 | out.append(y)
146 |
147 | done_mask = torch.ge(dones.sum(dim=0), 1.0)
148 | done_mask = torch.logical_not(done_mask) # False means masking.
149 |
150 | return torch.stack(out).squeeze(1), done_mask
151 |
152 |     def _get_transformer_input(self, obs, cat_out, cat_action, cat_not_dones, time_subsample, max_k, mem_len, mem_begin_index):
153 | out_lst = []
154 | actions_lst = []
155 | not_dones_lst = []
156 | mem_begin_index_lst = []
157 | target_lst = []
158 |
159 | for i in range(self.time_subsample):
160 | first_idx = mem_len + time_subsample[i]
161 | max_idx = first_idx + max_k
162 | min_idx = first_idx - self.causal_transformer_window + 1
163 | x = cat_out[:, min_idx:max_idx, :]
164 | y = cat_action[:, min_idx:max_idx, :]
165 | z = cat_not_dones[:, min_idx:max_idx, :]
166 | w = mem_begin_index + time_subsample[i] + 1
167 | k = obs[:, time_subsample[i]+1:time_subsample[i]+1+max_k]
168 | out_lst.append(x)
169 | actions_lst.append(y)
170 | not_dones_lst.append(z)
171 | mem_begin_index_lst.append(w)
172 | target_lst.append(k)
173 |
174 |
175 | return torch.stack(out_lst), torch.stack(actions_lst), torch.stack(not_dones_lst), \
176 | torch.stack(mem_begin_index_lst), torch.stack(target_lst)
177 |
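    # Subsampling note: to keep the auxiliary loss cheap, only `time_subsample` random start steps
    # and `forward_subsample` random prediction offsets (out of up to horizon_k) are used per update;
    # the cumulative product of not_dones masks out predictions that would cross an episode boundary.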
178 | def _build_mask_and_subsample(self, not_dones):
179 | t = not_dones.size(1) - self.horizon_k
180 |
181 | not_dones_unfolded = self._build_unfolded(not_dones[:, :-1].to(dtype=torch.bool), self.horizon_k) # 10, 32, 24, 1
182 | time_subsample = torch.randperm(t - 2, device=not_dones.device, dtype=torch.long)[0:self.time_subsample]
183 |
184 | forward_mask = torch.cumprod(not_dones_unfolded.index_select(2, time_subsample), dim=0).to(dtype=torch.bool) # 10, 32, 6, 1
185 | forward_mask = forward_mask.flatten(1, 2) # 10, 192, 1
186 |
187 | max_k = forward_mask.flatten(1).any(-1).nonzero().max().item() + 1
188 |
189 | unroll_subsample = torch.randperm(max_k, dtype=torch.long)[0:self.forward_subsample]
190 |
191 | max_k = unroll_subsample.max().item() + 1
192 |
193 | unroll_subsample = unroll_subsample.to(device=not_dones.device)
194 | forward_mask = forward_mask.index_select(0, unroll_subsample)
195 |
196 | return forward_mask, unroll_subsample, time_subsample, max_k
197 |
198 | def _build_unfolded(self, x, k: int):
199 | tobe_cat = x.new_zeros(x.size(0), k, x.size(2))
200 | cat = torch.cat((x, tobe_cat), 1)
201 | cat = cat.unfold(1, size=k, step=1)
202 | cat = cat.permute(3, 0, 1, 2)
203 | return cat
204 |
205 |
--------------------------------------------------------------------------------
/brain_agent/core/shared_buffer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import gym
5 |
6 | from brain_agent.utils.logger import log
7 | from brain_agent.core.core_utils import iter_dicts_recursively, copy_dict_structure, iterate_recursively
8 | from brain_agent.core.models.model_utils import get_hidden_size
9 |
10 |
11 | def ensure_memory_shared(*tensors):
12 | for tensor_dict in tensors:
13 | for _, _, t in iterate_recursively(tensor_dict):
14 | assert t.is_shared()
15 |
16 | def to_torch_dtype(numpy_dtype):
17 | """from_numpy automatically infers type, so we leverage that."""
18 | x = np.zeros([1], dtype=numpy_dtype)
19 | t = torch.from_numpy(x)
20 | return t.dtype
21 |
22 | def to_numpy(t, num_dimensions):
23 | arr_shape = t.shape[:num_dimensions]
24 | arr = np.ndarray(arr_shape, dtype=object)
25 | to_numpy_func(t, arr)
26 | return arr
27 |
28 |
29 | def to_numpy_func(t, arr):
30 | if len(arr.shape) == 1:
31 | for i in range(t.shape[0]):
32 | arr[i] = t[i]
33 | else:
34 | for i in range(t.shape[0]):
35 | to_numpy_func(t[i], arr[i])
36 |
37 |
38 | class SharedBuffer:
39 | """
40 | Shared buffer stores data from different processes in a single place for efficient sharing of information.
41 | The tensors are stored according to the index scheme of
42 | [
43 | actor worker index,
44 | split index,
45 | env index,
46 | trajectory buffer index,
47 | time step
48 | ]
49 | actor worker index:
50 | Index of the actor worker. There can be multiple actor workers.
51 | split index:
52 |         Index of the split. Each worker holds many env objects, which are grouped into splits.
53 |         Inference is run on the splits alternately for faster performance. (See Sample Factory for details.)
54 | env index:
55 | Inside a split, there are multiple env objects.
56 | trajectory buffer index:
57 |         In case the learner is significantly slow, data may pile up; extra trajectory buffers absorb it in a circular-queue manner.
58 | time step:
59 |         Rollout data are stored as a trajectory; a tensor resides at time step t within the trajectory.
60 | """
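    # Indexing example (illustrative): self.tensors['rewards'][w, s, e, b, t] is the reward written
    # by actor worker w, split s, env e within that split, trajectory buffer b, at rollout step t.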
61 | def __init__(self, cfg, obs_space, action_space):
62 |
63 | self.cfg = cfg
64 | assert not cfg.actor.num_envs_per_worker % cfg.actor.num_splits, \
65 | f'actor.num_envs_per_worker ({cfg.actor.num_envs_per_worker}) ' \
66 |             f'is not divisible by actor.num_splits ({cfg.actor.num_splits})'
67 |
68 | self.obs_space = obs_space
69 | self.action_space = action_space
70 |
71 | self.envs_per_split = cfg.actor.num_envs_per_worker // cfg.actor.num_splits
72 | self.num_traj_buffers = self.calc_num_trajectory_buffers()
73 |
74 | core_hidden_size = get_hidden_size(self.cfg, action_space)
75 |
76 | log.debug('Allocating shared memory for trajectories')
77 | self.tensors = TensorDict()
78 |
79 | obs_dict = TensorDict()
80 | self.tensors['obs'] = obs_dict
81 | if isinstance(obs_space, gym.spaces.Dict):
82 | for name, space in obs_space.spaces.items():
83 | obs_dict[name] = self.init_tensor(space.dtype, space.shape)
84 | else:
85 | raise Exception('Only Dict observations spaces are supported')
86 | self.tensors['prev_actions'] = self.init_tensor(torch.int64, [1])
87 | self.tensors['prev_rewards'] = self.init_tensor(torch.float32, [1])
88 |
89 | self.tensors['rewards'] = self.init_tensor(torch.float32, [1])
90 | self.tensors['dones'] = self.init_tensor(torch.bool, [1])
91 | self.tensors['raw_rewards'] = self.init_tensor(torch.float32, [1])
92 |
93 | # policy outputs
94 | if self.cfg.model.core.core_type == 'trxl':
95 | policy_outputs = [
96 | ('actions', 1),
97 | ('action_logits', action_space.n),
98 | ('log_prob_actions', 1),
99 | ('values', 1),
100 | ('normalized_values', 1),
101 | ('policy_version', 1),
102 | ]
103 | elif self.cfg.model.core.core_type == 'rnn':
104 | policy_outputs = [
105 | ('actions', 1),
106 | ('action_logits', action_space.n),
107 | ('log_prob_actions', 1),
108 | ('values', 1),
109 | ('normalized_values', 1),
110 | ('policy_version', 1),
111 | ('rnn_states', core_hidden_size)
112 | ]
113 |
114 | policy_outputs = [PolicyOutput(*po) for po in policy_outputs]
115 | policy_outputs = sorted(policy_outputs, key=lambda policy_output: policy_output.name)
116 |
117 | for po in policy_outputs:
118 | self.tensors[po.name] = self.init_tensor(torch.float32, [po.size])
119 |
120 | ensure_memory_shared(self.tensors)
121 |
122 | self.tensors_individual_transitions = self.tensor_dict_to_numpy(len(self.tensor_dimensions()))
123 |
124 | self.tensor_trajectories = self.tensor_dict_to_numpy(len(self.tensor_dimensions()) - 1)
125 |
126 | traj_buffer_available_shape = [
127 | self.cfg.actor.num_workers,
128 | self.cfg.actor.num_splits,
129 | self.envs_per_split,
130 | self.num_traj_buffers,
131 | ]
132 | self.is_traj_tensor_available = torch.ones(traj_buffer_available_shape, dtype=torch.uint8)
133 | self.is_traj_tensor_available.share_memory_()
134 | self.is_traj_tensor_available = to_numpy(self.is_traj_tensor_available, 2)
135 |
136 | policy_outputs_combined_size = sum(po.size for po in policy_outputs)
137 | policy_outputs_shape = [
138 | self.cfg.actor.num_workers,
139 | self.cfg.actor.num_splits,
140 | self.envs_per_split,
141 | policy_outputs_combined_size,
142 | ]
143 |
144 | self.policy_outputs = policy_outputs
145 | self.policy_output_tensors = torch.zeros(policy_outputs_shape, dtype=torch.float32)
146 | self.policy_output_tensors.share_memory_()
147 | self.policy_output_tensors = to_numpy(self.policy_output_tensors, 3)
148 |
149 | self.policy_versions = torch.zeros([1], dtype=torch.int32)
150 | self.policy_versions.share_memory_()
151 |
152 | self.stop_experience_collection = torch.ones([1], dtype=torch.bool)
153 | self.stop_experience_collection.share_memory_()
154 |
155 | self.task_ids = torch.zeros([self.cfg.actor.num_workers, self.cfg.actor.num_splits, self.envs_per_split],
156 | dtype=torch.uint8)
157 | self.task_ids.share_memory_()
158 |
159 | self.max_mems_buffer_len = self.cfg.model.core.mem_len + self.cfg.optim.rollout * (self.num_traj_buffers + 1)
160 | self.mems_dimensions = [self.cfg.actor.num_workers, self.cfg.actor.num_splits, self.envs_per_split, self.max_mems_buffer_len]
161 | self.mems_dimensions.append(core_hidden_size)
162 | self.mems_dones_dimensions = [self.cfg.actor.num_workers, self.cfg.actor.num_splits,
163 | self.envs_per_split, self.max_mems_buffer_len]
164 | self.mems_dones_dimensions.append(1)
165 | self.mems_actions_dimensions = [self.cfg.actor.num_workers, self.cfg.actor.num_splits,
166 | self.envs_per_split, self.max_mems_buffer_len]
167 | self.mems_actions_dimensions.append(1)
168 |
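    # Sizing note: each (worker, split, env) slot keeps enough trajectory buffers to hold roughly
    # three batches' worth of rollouts, so actors can keep collecting while the learner is still
    # consuming earlier data; shared_buffer.min_traj_buffers_per_worker enforces a lower bound.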
169 | def calc_num_trajectory_buffers(self):
170 | num_traj_buffers = self.cfg.optim.batch_size / (
171 | self.cfg.actor.num_workers * self.cfg.actor.num_envs_per_worker * self.cfg.optim.rollout)
172 |
173 | num_traj_buffers *= 3
174 |
175 | num_traj_buffers = math.ceil(max(num_traj_buffers, self.cfg.shared_buffer.min_traj_buffers_per_worker))
176 | log.info('Using %d sets of trajectory buffers', num_traj_buffers)
177 | return num_traj_buffers
178 |
179 | def tensor_dimensions(self):
180 | dimensions = [
181 | self.cfg.actor.num_workers,
182 | self.cfg.actor.num_splits,
183 | self.envs_per_split,
184 | self.num_traj_buffers,
185 | self.cfg.optim.rollout,
186 | ]
187 | return dimensions
188 |
189 | def init_tensor(self, tensor_type, tensor_shape):
190 | if not isinstance(tensor_type, torch.dtype):
191 | tensor_type = to_torch_dtype(tensor_type)
192 |
193 | dimensions = self.tensor_dimensions()
194 | final_shape = dimensions + list(tensor_shape)
195 | t = torch.zeros(final_shape, dtype=tensor_type)
196 | t.share_memory_()
197 | return t
198 |
199 | def tensor_dict_to_numpy(self, num_dimensions):
200 | numpy_dict = copy_dict_structure(self.tensors)
201 | for d1, d2, key, curr_t, value2 in iter_dicts_recursively(self.tensors, numpy_dict):
202 | assert isinstance(curr_t, torch.Tensor)
203 | assert value2 is None
204 | d2[key] = to_numpy(curr_t, num_dimensions)
205 | assert isinstance(d2[key], np.ndarray)
206 | return numpy_dict
207 |
208 |
209 | class TensorDict(dict):
210 | def index(self, indices):
211 | return self.index_func(self, indices)
212 |
213 | def index_func(self, x, indices):
214 | if isinstance(x, (dict, TensorDict)):
215 | res = TensorDict()
216 | for key, value in x.items():
217 | res[key] = self.index_func(value, indices)
218 | return res
219 | else:
220 | t = x[indices]
221 | return t
222 |
223 | def set_data(self, index, new_data):
224 | self.set_data_func(self, index, new_data)
225 |
226 | def set_data_func(self, x, index, new_data):
227 | if isinstance(new_data, (dict, TensorDict)):
228 | for new_data_key, new_data_value in new_data.items():
229 | self.set_data_func(x[new_data_key], index, new_data_value)
230 | elif isinstance(new_data, torch.Tensor):
231 | x[index].copy_(new_data)
232 | elif isinstance(new_data, np.ndarray):
233 | t = torch.from_numpy(new_data)
234 | x[index].copy_(t)
235 | else:
236 | raise Exception(f'Type {type(new_data)} not supported in set_data_func')
237 |
238 |
239 | class PolicyOutput:
240 | def __init__(self, name, size):
241 | self.name = name
242 | self.size = size
243 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Brain Agent
2 | ***Brain Agent*** is a distributed agent learning system for large-scale and multi-task reinforcement learning, developed by [Kakao Brain](https://www.kakaobrain.com/).
3 | Brain Agent is based on the V-trace actor-critic framework [IMPALA](https://arxiv.org/abs/1802.01561) and modifies [Sample Factory](https://github.com/alex-petrenko/sample-factory) to utilize multiple GPUs and CPUs for a much higher throughput rate during training.
4 | ## Features
5 |
6 | 1. First publicly available implementation reproducing SOTA results on [DMLAB30](https://github.com/deepmind/lab).
7 | 2. Scalable & massive throughput.
8 | Brain Agent can collect and train on 20B frames per week (about 34K FPS) with 16 V100 GPUs by scaling up the high-throughput single-node system [Sample Factory](https://github.com/alex-petrenko/sample-factory).
9 | 3. Based on the following algorithms and architectures:
10 | * [TransformerXL-I](https://proceedings.mlr.press/v119/parisotto20a.html) core and [ResNet](https://arxiv.org/abs/1512.03385?context=cs) encoder
11 | * [V-trace](https://arxiv.org/abs/1802.01561) for update algorithm
12 | * [IMPALA](https://arxiv.org/abs/1802.01561) for system framework
13 | * [PopArt](https://arxiv.org/abs/1809.04474) for multitask handling
14 | 4. For self-supervised representation learning, we include two additional features:
15 | * ResNet-based decoder to reconstruct the original input image ([trxl_recon](https://github.com/kakaobrain/brain_agent/blob/main/configs/trxl_recon_train.yaml))
16 | * Additional autoregressive transformer to predict the images of future steps from the current state embedding and future action sequence ([trxl_future_pred](https://github.com/kakaobrain/brain_agent/blob/main/configs/trxl_future_pred_train.yaml))
17 | 5. Provides code for both training and evaluation, along with a SOTA model checkpoint with 28M parameters.
18 |
19 | ## How to Install
20 | - `Python 3.7`
21 | - `Pytorch 1.9.0`
22 | - `CUDA 11.1`
23 | - Install the DMLab environment - [DMLab Github](https://github.com/deepmind/lab/blob/master/docs/users/build.md)
24 | - `pip install -r requirements.txt`
25 |
26 | ## Description of Codes
27 | - `dist_launch.py` -> distributed training launcher
28 | - `eval.py` -> entry point for evaluation
29 | - `train.py` -> entry point for training
30 | - `brain_agent`
31 | - `core`
32 | - `agents`
33 | - `dmlab_multitask_agent.py`
34 | - `algos`
35 |       - `aux_future_predict.py` -> Computes an auxiliary loss by predicting future state transitions with an autoregressive transformer. Used only for [trxl_future_pred](https://github.com/kakaobrain/brain_agent/blob/main/configs/trxl_future_pred_train.yaml).
36 | - `popart.py`
37 | - `vtrace.py`
38 | - `actor_worker.py`
39 | - `learner_worker.py`
40 | - `policy_worker.py`
41 | - `shared_buffer.py` -> Defines SharedBuffer class for zero-copy communication between workers.
42 | - `envs`
43 | - `dmlab`
44 | - `utils`
45 | - ...
46 | - `configs`
47 |   - ... -> Hyperparameter configs for each training/evaluation setting.
48 |
49 |
50 | ## How to Run
51 | ### Training
52 | - 1 node x 1 GPU
53 | ```bash
54 | python train.py cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
55 | ```
56 |
57 | - 1 node x 4 GPUs = 4 GPUs
58 | ```bash
59 | python -m dist_launch --nnodes=1 --node_rank=0 --nproc_per_node=4 -m train \
60 | cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
61 | ```
62 |
63 | - 4 nodes x 4 GPUs each = 16 GPUs
64 | ```bash
65 | sleep 120; python -m dist_launch --nnodes=4 --node_rank=0 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
66 | cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
67 | sleep 120; python -m dist_launch --nnodes=4 --node_rank=1 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
68 | cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
69 | sleep 120; python -m dist_launch --nnodes=4 --node_rank=2 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
70 | cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
71 | sleep 120; python -m dist_launch --nnodes=4 --node_rank=3 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
72 | cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
73 | ```
74 |
75 | ### Evaluation
76 | ```bash
77 | python eval.py cfg=configs/trxl_recon_eval.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR test.checkpoint=$CHECKPOINT_FILE_PATH
78 | ```
79 |
80 | ### Setting Hyperparameters
81 | - All the default hyperparameters are defined in `configs/default.yaml`.
82 | - Other config files override the values in `configs/default.yaml` (see the sketch below).
83 | - You can use the pre-defined hyperparameters from our experiments with `configs/trxl_recon_train.yaml` or `configs/trxl_future_pred_train.yaml`.
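For reference, the override behaviour can be reproduced with plain [OmegaConf](https://omegaconf.readthedocs.io/) (a minimal sketch, not the project's actual config loader — that is presumably handled in `brain_agent/utils/cfg.py` and may differ in details; the file names and overrides below are only examples):
```python
from omegaconf import OmegaConf

# Load the defaults first, then merge an experiment config on top of them.
# Keys present in both files take the value from the experiment config.
defaults = OmegaConf.load('configs/default.yaml')
experiment = OmegaConf.load('configs/trxl_recon_train.yaml')
cfg = OmegaConf.merge(defaults, experiment)

# Command-line style key=value overrides (e.g. train_dir=..., experiment=...) win last.
cli = OmegaConf.from_dotlist(['train_dir=/tmp/train_dir', 'experiment=my_experiment'])
cfg = OmegaConf.merge(cfg, cli)

print(OmegaConf.to_yaml(cfg))
```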
84 |
85 |
86 |
87 | ## Results for DMLAB30
88 |
89 | - Settings
90 | - 3 runs with different seeds
91 |   - 100 episodes per run
92 |   - HNS: Human Normalised Score (an agent's score rescaled so that a random policy scores 0 and average human performance scores 100; capped HNS clips each task's score at 100 before averaging)
93 | - Results
94 |
95 | | Model | Mean HNS | Median HNS | Mean Capped HNS |
96 | |:----------:|:--------:|:----------:|:---------------:|
97 | | [MERLIN](https://arxiv.org/pdf/1803.10760.pdf) | 115.2 | - | 89.4 |
98 | | [GTrXL](https://proceedings.mlr.press/v119/parisotto20a.html) | 117.6 | - | 89.1 |
99 | | [CoBERL](https://arxiv.org/pdf/2107.05431.pdf) | 115.47 | 110.86 |- |
100 | | [R2D2+](https://openreview.net/pdf?id=r1lyTjAqYX) | - | 99.5 | 85.7 |
101 | | [LASER](https://arxiv.org/abs/1909.11583) | - | 97.2 | 81.7 |
102 | | [PBL](https://arxiv.org/pdf/2004.14646.pdf) | 104.16 | - | 81.5 |
103 | | [PopArt-IMPALA](https://arxiv.org/abs/1809.04474) | - | - | 72.8 |
104 | | [IMPALA](https://arxiv.org/abs/1802.01561) | - | - | 58.4 |
105 | | Ours (lstm_baseline, [20B ckpt](https://twg.kakaocdn.net/brainrepo/models/brain_agent/ed7f0e5a8dc57ad72c8c38319f58000e/rnn_baseline_20b.pth)) | 103.03 ± 0.37 | 92.04 ± 0.73 | 81.35 ± 0.25 |
106 | | Ours (trxl_baseline, [20B ckpt](https://twg.kakaocdn.net/brainrepo/models/brain_agent/c0e90a4e3555a12b58e60729d13e2e02/trxl_baseline_20b.pth)) | 111.95 ± 1.00 | 105.43 ± 2.61 | 85.57 ± 0.20 |
107 | | Ours (trxl_recon, [20B ckpt](https://twg.kakaocdn.net/brainrepo/models/brain_agent/84ac1c594b8eb95e7fd9879d6172f99b/trxl_recon_20b.pth)) | 123.60 ± 0.84 | 108.63 ± 1.20 | **91.25 ± 0.41** |
108 | | Ours (trxl_future_pred, [20B ckpt](https://twg.kakaocdn.net/brainrepo/models/brain_agent/a54acdd9d3a14f2905d295c9c63bf31d/trxl_future_pred_20b.pth)) | **128.00 ± 0.43** | 108.80 ± 0.99 | 90.53 ± 0.26 |
109 |
110 |
111 | Results for all 30 tasks