├── .dockerignore
├── .gitignore
├── .jenkins
│   ├── Jenkinsfile
│   └── run_jenkins.sh
├── Dockerfile
├── LICENSE
├── README.md
├── adept
│   ├── __init__.py
│   ├── actor
│   │   ├── __init__.py
│   │   ├── ac_eval.py
│   │   ├── ac_rollout.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── ac_helper.py
│   │   │   └── actor_module.py
│   │   ├── impala.py
│   │   └── ppo.py
│   ├── agent
│   │   ├── __init__.py
│   │   ├── actor_critic.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   └── agent_module.py
│   │   └── ppo.py
│   ├── app.py
│   ├── container
│   │   ├── __init__.py
│   │   ├── actorlearner
│   │   │   ├── __init__.py
│   │   │   ├── learner_container.py
│   │   │   ├── rollout_queuer.py
│   │   │   └── rollout_worker.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── container.py
│   │   │   ├── nccl_optimizer.py
│   │   │   └── updater.py
│   │   ├── distrib.py
│   │   ├── evaluation.py
│   │   ├── evaluation_thread.py
│   │   ├── init.py
│   │   ├── local.py
│   │   └── render.py
│   ├── env
│   │   ├── __init__.py
│   │   ├── _gym_wrappers.py
│   │   ├── _spaces.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── _env.py
│   │   │   └── env_module.py
│   │   └── openai_gym.py
│   ├── exp
│   │   ├── __init__.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── exp_module.py
│   │   │   └── spec_builder.py
│   │   ├── replay.py
│   │   └── rollout.py
│   ├── globals.py
│   ├── learner
│   │   ├── __init__.py
│   │   ├── ac_rollout.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── dm_return_scale.py
│   │   │   └── learner_module.py
│   │   └── impala.py
│   ├── manager
│   │   ├── __init__.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   └── manager_module.py
│   │   ├── simple_env_manager.py
│   │   └── subproc_env_manager.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── memory.py
│   │   ├── mlp.py
│   │   ├── norm.py
│   │   ├── sequence.py
│   │   └── spatial.py
│   ├── network
│   │   ├── __init__.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── network_module.py
│   │   │   └── submodule.py
│   │   ├── modular_network.py
│   │   ├── net1d
│   │   │   ├── __init__.py
│   │   │   ├── identity_1d.py
│   │   │   ├── linear.py
│   │   │   ├── lstm.py
│   │   │   └── submodule_1d.py
│   │   ├── net2d
│   │   │   ├── __init__.py
│   │   │   ├── identity_2d.py
│   │   │   └── submodule_2d.py
│   │   ├── net3d
│   │   │   ├── __init__.py
│   │   │   ├── _resnets.py
│   │   │   ├── four_conv.py
│   │   │   ├── identity_3d.py
│   │   │   ├── networks.py
│   │   │   ├── rmc.py
│   │   │   └── submodule_3d.py
│   │   └── net4d
│   │       ├── __init__.py
│   │       ├── identity_4d.py
│   │       └── submodule_4d.py
│   ├── preprocess
│   │   ├── __init__.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── ops.py
│   │   │   └── preprocessor.py
│   │   └── ops.py
│   ├── registry
│   │   ├── __init__.py
│   │   └── registry.py
│   ├── rewardnorm
│   │   ├── __init__.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   └── rewnorm_module.py
│   │   └── normalizers.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── _distrib.py
│   │   ├── actorlearner.py
│   │   ├── distrib.py
│   │   ├── evaluate.py
│   │   ├── local.py
│   │   └── render.py
│   └── utils
│       ├── __init__.py
│       ├── logging.py
│       ├── requires_args.py
│       ├── script_helpers.py
│       └── util.py
├── docker
│   ├── Dockerfile
│   ├── Dockerfile.nogpu
│   ├── README.md
│   ├── connect.py
│   └── startup.sh
├── docs
│   ├── api_overview.md
│   ├── modular_network.md
│   ├── new_api.md
│   └── resume_training.md
├── examples
│   ├── custom_agent_stub.py
│   ├── custom_environment_stub.py
│   ├── custom_network_stub.py
│   └── custom_submodule_stub.py
├── images
│   ├── architecture.png
│   ├── banner.png
│   ├── benchmark.png
│   └── modular_network.png
├── setup.py
└── tests
    ├── __init__.py
    ├── distrib
    │   ├── __init__.py
    │   ├── allreduce.py
    │   ├── container_sync.py
    │   ├── control_flow_zmq.py
    │   ├── exp_sync_broadcast.py
    │   ├── hello_ray.py
    │   ├── launch.py
    │   ├── multi_group.py
    │   ├── nccl_typecheck.py
    │   └── ray_container.py
    ├── exp
    │   └── test_rollout.py
    ├── learner
    │   ├── __init__.py
    │   └── nstep.py
    ├── network
    │   ├── __init__.py
    │   └── test_modular_network.py
    ├── registry
    │   ├── __init__.py
    │   └── test_registry.py
    └── utils
        ├── __init__.py
        ├── test_requires_args.py
        └── test_util.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | logs
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | .DS_Store
3 | *.xml
4 | .idea/
5 | .vscode/
6 | __pycache__/
7 | logs/
8 | trained_models/
9 | *egg-info
10 | dist/
11 | archive/
12 | build/
13 | .pytest_cache/
14 | test_reports/
--------------------------------------------------------------------------------
/.jenkins/Jenkinsfile:
--------------------------------------------------------------------------------
1 | pipeline {
2 | agent {
3 | dockerfile {
4 | filename 'Dockerfile'
5 | args("""--runtime nvidia \
6 | -v /tmp/adept_logs:/tmp/adept_logs \
7 | --net host \
8 | --cap-add SYS_PTRACE""")
9 | }
10 | }
11 | triggers { pollSCM('H/15 * * * *') }
12 |
13 | stages {
14 | stage('Build') {
15 | steps {
16 | echo 'Checking build...'
17 | sh 'nvidia-smi'
18 | sh 'python -m adept.scripts.local -h'
19 | }
20 | }
21 | stage('Test') {
22 | steps {
23 | echo 'Running unit tests...'
24 | sh 'pytest --verbose'
25 | }
26 | }
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/.jenkins/run_jenkins.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | docker run \
4 | --rm \
5 | -u root \
6 | -p 8080:8080 \
7 | -v /var/jenkins_home:/var/jenkins_home \
8 | -v /var/run/docker.sock:/var/run/docker.sock \
9 | -v "$HOME":/home \
10 | jenkinsci/blueocean
11 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | ./docker/Dockerfile
--------------------------------------------------------------------------------
/adept/__init__.py:
--------------------------------------------------------------------------------
1 | def register_agent(agent_cls):
2 | from adept.registry import REGISTRY
3 |
4 | REGISTRY.register_agent(agent_cls)
5 |
6 |
7 | def register_actor(actor_cls):
8 | from adept.registry import REGISTRY
9 |
10 | REGISTRY.register_actor(actor_cls)
11 |
12 |
13 | def register_exp(exp_cls):
14 | from adept.registry import REGISTRY
15 |
16 | REGISTRY.register_exp(exp_cls)
17 |
18 |
19 | def register_learner(learner_cls):
20 | from adept.registry import REGISTRY
21 |
22 | REGISTRY.register_learner(learner_cls)
23 |
24 |
25 | def register_env(env_cls):
26 | from adept.registry import REGISTRY
27 |
28 | REGISTRY.register_env(env_cls)
29 |
30 |
31 | def register_reward_norm(rwd_norm_cls):
32 | from adept.registry import REGISTRY
33 |
34 | REGISTRY.register_reward_normalizer(rwd_norm_cls)
35 |
36 |
37 | def register_network(network_cls):
38 | from adept.registry import REGISTRY
39 |
40 | REGISTRY.register_network(network_cls)
41 |
42 |
43 | def register_submodule(submod_cls):
44 | from adept.registry import REGISTRY
45 |
46 | REGISTRY.register_submodule(submod_cls)
47 |
48 |
49 | def register_manager(manager_cls):
50 | from adept.registry import REGISTRY
51 |
52 | REGISTRY.register_manager(manager_cls)
53 |
--------------------------------------------------------------------------------
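
These module-level functions are the public hooks for plugging custom components into the global registry. A minimal sketch of the intended pattern, assuming a hypothetical MyCustomEnv defined outside this repository that implements the EnvModule interface (see adept/env/base/env_module.py and examples/custom_environment_stub.py):

# Sketch only: my_project and MyCustomEnv are hypothetical placeholders for a
# user-defined EnvModule subclass living outside this repository.
import adept
from my_project.envs import MyCustomEnv  # hypothetical import

# Registering makes the class resolvable by name from the command-line scripts
# (adept.scripts.local, adept.scripts.distrib, ...).
adept.register_env(MyCustomEnv)
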
/adept/actor/__init__.py:
--------------------------------------------------------------------------------
1 | from .base.actor_module import ActorModule
2 | from .ac_rollout import ACRolloutActorTrain
3 | from .ppo import PPOActorTrain
4 | from adept.actor.ac_eval import ACActorEval, ACActorEvalSample
5 | from .impala import ImpalaHostActor, ImpalaWorkerActor
6 |
7 | ACTOR_REG = [
8 | ACRolloutActorTrain,
9 | ACActorEval,
10 | ACActorEvalSample,
11 | PPOActorTrain,
12 | ImpalaHostActor,
13 | ImpalaWorkerActor,
14 | ]
15 |
--------------------------------------------------------------------------------
/adept/actor/ac_eval.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from adept.actor import ActorModule
4 | from adept.actor.base.ac_helper import ACActorHelperMixin
5 |
6 |
7 | class ACActorEval(ActorModule, ACActorHelperMixin):
8 | args = {}
9 |
10 | @classmethod
11 | def from_args(cls, args, action_space):
12 | return cls(action_space)
13 |
14 | @staticmethod
15 | def output_space(action_space):
16 | head_dict = {"critic": (1,), **action_space}
17 | return head_dict
18 |
19 | def compute_action_exp(self, preds, internals, obs, available_actions):
20 | actions = OrderedDict()
21 |
22 | for key in self.action_keys:
23 | logit = self.flatten_logits(preds[key])
24 |
25 | softmax = self.softmax(logit)
26 | action = self.select_action(softmax)
27 |
28 | actions[key] = action.cpu()
29 | return actions, {"value": preds["critic"].squeeze(-1)}
30 |
31 | @classmethod
32 | def _exp_spec(
33 | cls, rollout_len, batch_sz, obs_space, act_space, internal_space
34 | ):
35 | return {}
36 |
37 |
38 | class ACActorEvalSample(ACActorEval):
39 | def compute_action_exp(self, preds, internals, obs, available_actions):
40 | actions = OrderedDict()
41 |
42 | for key in self.action_keys:
43 | logit = self.flatten_logits(preds[key])
44 |
45 | softmax = self.softmax(logit)
46 | action = self.sample_action(softmax)
47 |
48 | actions[key] = action.cpu()
49 | return actions, {"value": preds["critic"].squeeze(-1)}
50 |
--------------------------------------------------------------------------------
/adept/actor/ac_rollout.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from collections import OrderedDict
16 |
17 | import torch
18 |
19 | from adept.actor.base.ac_helper import ACActorHelperMixin
20 | from adept.actor.base.actor_module import ActorModule
21 |
22 |
23 | class ACRolloutActorTrain(ActorModule, ACActorHelperMixin):
24 | args = {}
25 |
26 | @classmethod
27 | def from_args(cls, args, action_space):
28 | return cls(action_space)
29 |
30 | @staticmethod
31 | def output_space(action_space):
32 | head_dict = {"critic": (1,), **action_space}
33 | return head_dict
34 |
35 | def compute_action_exp(self, preds, internals, obs, available_actions):
36 | values = preds["critic"].squeeze(1)
37 |
38 | actions = OrderedDict()
39 | log_probs = []
40 | entropies = []
41 |
42 | for key in self.action_keys:
43 | logit = self.flatten_logits(preds[key])
44 |
45 | log_softmax, softmax = self.log_softmax(logit), self.softmax(logit)
46 | entropy = self.entropy(log_softmax, softmax)
47 | action = self.sample_action(softmax)
48 |
49 | entropies.append(entropy)
50 | log_probs.append(self.log_probability(log_softmax, action))
51 | actions[key] = action.cpu()
52 |
53 | log_probs = torch.cat(log_probs, dim=1)
54 | entropies = torch.cat(entropies, dim=1)
55 |
56 | return (
57 | actions,
58 | {"log_probs": log_probs, "entropies": entropies, "values": values},
59 | )
60 |
61 | @classmethod
62 | def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
63 | act_key_len = len(act_space.keys())
64 |
65 | spec = {
66 | "log_probs": (exp_len, batch_sz, act_key_len),
67 | "entropies": (exp_len, batch_sz, act_key_len),
68 | "values": (exp_len, batch_sz),
69 | }
70 |
71 | return spec
72 |
--------------------------------------------------------------------------------
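
As a concrete reading of _exp_spec above, the shapes below are what a 20-step rollout over a batch of 32 environments with one discrete action key produces (the space dicts are illustrative values, not taken from the repository):

from adept.actor import ACRolloutActorTrain

spec = ACRolloutActorTrain._exp_spec(
    exp_len=20,
    batch_sz=32,
    obs_space={"Box": (4, 84, 84)},   # unused by this actor's spec
    act_space={"Discrete": (6,)},     # one action key
    internal_space={"hx": (512,)},    # unused by this actor's spec
)
# spec == {
#     "log_probs": (20, 32, 1),
#     "entropies": (20, 32, 1),
#     "values": (20, 32),
# }
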
/adept/actor/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor_module import ActorModule
2 |
--------------------------------------------------------------------------------
/adept/actor/base/ac_helper.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | from torch.nn import functional as F
4 |
5 |
6 | class ACActorHelperMixin(metaclass=abc.ABCMeta):
7 | """
8 | A helper class for actor critic actors.
9 | """
10 |
11 | @staticmethod
12 | def flatten_logits(logit):
13 | """
14 | :param logit: Tensor of arbitrary dim
15 | :return: logits flattened to (N, X)
16 | """
17 | size = logit.size()
18 | dim = logit.dim()
19 |
20 | if dim == 3:
21 | n, f, l = size
22 | logit = logit.view(n, f * l)
23 | elif dim == 4:
24 | n, f, h, w = size
25 | logit = logit.view(n, f * h * w)
26 | elif dim == 5:
27 | n, f, d, h, w = size
28 | logit = logit.view(n, f * d * h * w)
29 | return logit
30 |
31 | @staticmethod
32 | def softmax(logit):
33 | """
34 | :param logit: torch.Tensor (N, X)
35 | :return: torch.Tensor (N, X)
36 | """
37 | return F.softmax(logit, dim=1)
38 |
39 | @staticmethod
40 | def log_softmax(logit):
41 | """
42 | :param logit: torch.Tensor (N, X)
43 | :return: torch.Tensor (N, X)
44 | """
45 | return F.log_softmax(logit, dim=1)
46 |
47 | @staticmethod
48 | def log_probability(log_softmax, action):
49 | """
50 | :param log_softmax: Tensor (N, X)
51 | :param action: LongTensor (N)
52 | :return: Tensor (N, 1)
53 | """
54 | return log_softmax.gather(1, action.unsqueeze(1))
55 |
56 | @staticmethod
57 | def entropy(log_softmax, softmax):
58 | """
59 | :param log_softmax: Tensor (N, X)
60 | :param softmax: Tensor (N, X)
61 | :return: Tensor (N, 1)
62 | """
63 | return -(log_softmax * softmax).sum(1, keepdim=True)
64 |
65 | @staticmethod
66 | def sample_action(softmax):
67 | """
68 | Samples an action from a softmax distribution.
69 |
70 | :param softmax: torch.Tensor (N, X)
71 | :return: torch.Tensor (N)
72 | """
73 | return softmax.multinomial(1).squeeze(1)
74 |
75 | @staticmethod
76 | def select_action(softmax):
77 | """
78 | Selects the action with the highest probability.
79 |
80 | :param softmax:
81 | :return:
82 | """
83 | return torch.argmax(softmax, dim=1)
84 |
--------------------------------------------------------------------------------
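
The helpers above are standard categorical-policy operations. A standalone illustration in plain torch, with the shapes from the docstrings annotated:

import torch
from torch.nn import functional as F

logit = torch.randn(8, 6)                    # (N, X): batch of 8, 6 actions
softmax = F.softmax(logit, dim=1)            # (8, 6)
log_softmax = F.log_softmax(logit, dim=1)    # (8, 6)

action = softmax.multinomial(1).squeeze(1)               # sample_action -> (8,)
log_prob = log_softmax.gather(1, action.unsqueeze(1))    # log_probability -> (8, 1)
entropy = -(log_softmax * softmax).sum(1, keepdim=True)  # entropy -> (8, 1)
greedy = torch.argmax(softmax, dim=1)                    # select_action -> (8,)
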
/adept/actor/base/actor_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | """
16 | An actor observes the environment and decides actions. It also outputs extra
17 | info necessary for model updates (learning) to occur.
18 | """
19 | import abc
20 | from collections import defaultdict
21 |
22 | from adept.exp.base.spec_builder import ExpSpecBuilder
23 | from adept.utils.requires_args import RequiresArgsMixin
24 |
25 |
26 | class ActorModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
27 | def __init__(self, action_space):
28 | self._action_space = action_space
29 |
30 | @property
31 | def action_space(self):
32 | return self._action_space
33 |
34 | @property
35 | def action_keys(self):
36 | return sorted(self.action_space.keys())
37 |
38 | @staticmethod
39 | @abc.abstractmethod
40 | def output_space(action_space):
41 | raise NotImplementedError
42 |
43 | @classmethod
44 | def exp_spec_builder(cls, obs_space, act_space, internal_space, batch_sz):
45 | def build_fn(exp_len):
46 | exp_space = cls._exp_spec(
47 | exp_len, batch_sz, obs_space, act_space, internal_space
48 | )
49 | env_space = {
50 | "rewards": (exp_len, batch_sz),
51 | "terminals": (exp_len, batch_sz),
52 | }
53 | return {**exp_space, **env_space}
54 |
55 | key_types = cls._key_types(obs_space, act_space, internal_space)
56 | exp_keys = cls._exp_keys(obs_space, act_space, internal_space)
57 | return ExpSpecBuilder(
58 | obs_space, act_space, internal_space, key_types, exp_keys, build_fn
59 | )
60 |
61 | @classmethod
62 | @abc.abstractmethod
63 | def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
64 | raise NotImplementedError
65 |
66 | @classmethod
67 | def _exp_keys(cls, obs_space, act_space, internal_space):
68 | dummy = cls._exp_spec(1, 1, obs_space, act_space, internal_space)
69 | return dummy.keys()
70 |
71 | @classmethod
72 | def _key_types(cls, obs_space, act_space, internal_space):
73 | return defaultdict(lambda: "float")
74 |
75 | @abc.abstractmethod
76 | def from_args(self, args, action_space):
77 | raise NotImplementedError
78 |
79 | @abc.abstractmethod
80 | def compute_action_exp(self, preds, internals, obs, available_actions):
81 | """
82 | B = Batch Size
83 |
84 | :param preds: Dict[str, torch.Tensor]
85 | :return:
86 | actions: Dict[ActionKey, Tensor (B)]
87 | experience: Dict[str, Tensor (B, X)]
88 | """
89 | raise NotImplementedError
90 |
91 | def act(self, network, obs, prev_internals):
92 | """
93 | :param obs: Dict[str, Tensor]
94 | :param prev_internals: previous interal states. Dict[str, Tensor]
95 | :return:
96 | actions: Dict[ActionKey, Tensor (B)]
97 | experience: Dict[str, Tensor (B, X)]
98 | internal_states: Dict[str, Tensor]
99 | """
100 |
101 | predictions, internal_states, pobs = network(obs, prev_internals)
102 |
103 | if "available_actions" in obs:
104 | av_actions = obs["available_actions"]
105 | else:
106 | av_actions = None
107 |
108 | actions, exp = self.compute_action_exp(
109 | predictions, prev_internals, pobs, av_actions
110 | )
111 | return actions, exp, internal_states
112 |
--------------------------------------------------------------------------------
/adept/actor/impala.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from collections import OrderedDict, defaultdict
16 | from functools import reduce
17 |
18 | import torch
19 |
20 | from adept.actor.base.ac_helper import ACActorHelperMixin
21 | from adept.actor.base.actor_module import ActorModule
22 |
23 |
24 | class ImpalaHostActor(ActorModule, ACActorHelperMixin):
25 | args = {}
26 |
27 | @classmethod
28 | def from_args(cls, args, action_space):
29 | return cls(action_space)
30 |
31 | @staticmethod
32 | def output_space(action_space):
33 | head_dict = {"critic": (1,), **action_space}
34 | return head_dict
35 |
36 | def compute_action_exp(self, preds, internals, obs, available_actions):
37 | values = preds["critic"].squeeze(1)
38 |
39 | log_softmaxes = []
40 | entropies = []
41 |
42 | for key in self.action_keys:
43 | logit = self.flatten_logits(preds[key])
44 |
45 | # print(logit) NaNs
46 |
47 | log_softmax, softmax = self.log_softmax(logit), self.softmax(logit)
48 | entropy = self.entropy(log_softmax, softmax)
49 |
50 | entropies.append(entropy)
51 | log_softmaxes.append(log_softmax)
52 |
53 | log_softmaxes = torch.stack(log_softmaxes, dim=1)
54 | entropies = torch.cat(entropies, dim=1)
55 |
56 | return (
57 | None,
58 | {
59 | "log_softmaxes": log_softmaxes,
60 | "entropies": entropies,
61 | "values": values,
62 | },
63 | )
64 |
65 | @classmethod
66 | def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
67 | flat_act_space = 0
68 | for k, shape in act_space.items():
69 | flat_act_space += reduce(lambda a, b: a * b, shape)
70 | act_key_len = len(act_space.keys())
71 |
72 | obs_spec = {
73 | k: (exp_len + 1, batch_sz, *shape) for k, shape in obs_space.items()
74 | }
75 | action_spec = {k: (exp_len, batch_sz) for k in act_space.keys()}
76 | internal_spec = {
77 | k: (exp_len, batch_sz, *shape)
78 | for k, shape in internal_space.items()
79 | }
80 |
81 | spec = {
82 | "log_softmaxes": (exp_len, batch_sz, act_key_len, flat_act_space),
83 | "entropies": (exp_len, batch_sz, act_key_len),
84 | "values": (exp_len, batch_sz),
85 | # From Workers
86 | "log_probs": (exp_len, batch_sz, act_key_len),
87 | **obs_spec,
88 | **action_spec,
89 | **internal_spec,
90 | }
91 |
92 | return spec
93 |
94 | @classmethod
95 | def _key_types(cls, obs_space, act_space, internal_space):
96 | d = defaultdict(lambda: "float")
97 | for k in act_space.keys():
98 | d[k] = "long"
99 | # TODO this needs a better solution
100 | for k in obs_space.keys():
101 | d[k] = "byte"
102 | return d
103 |
104 |
105 | class ImpalaWorkerActor(ActorModule, ACActorHelperMixin):
106 | args = {}
107 |
108 | @classmethod
109 | def from_args(cls, args, action_space):
110 | return cls(action_space)
111 |
112 | @staticmethod
113 | def output_space(action_space):
114 | head_dict = {"critic": (1,), **action_space}
115 | return head_dict
116 |
117 | def compute_action_exp(self, preds, internals, obs, available_actions):
118 | log_probs = []
119 | actions_gpu = OrderedDict()
120 | actions_cpu = OrderedDict()
121 |
122 | for key in self.action_keys:
123 | logit = self.flatten_logits(preds[key])
124 |
125 | log_softmax, softmax = self.log_softmax(logit), self.softmax(logit)
126 | action = self.sample_action(softmax)
127 |
128 | log_probs.append(self.log_probability(log_softmax, action))
129 | actions_gpu[key] = action
130 | actions_cpu[key] = action.cpu()
131 |
132 | log_probs = torch.cat(log_probs, dim=1)
133 |
134 | internals = {k: torch.stack(vs) for k, vs in internals.items()}
135 |
136 | return actions_cpu, {"log_probs": log_probs, **actions_gpu, **internals}
137 |
138 | @classmethod
139 | def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
140 | act_key_len = len(act_space.keys())
141 |
142 | obs_spec = {
143 | k: (exp_len + 1, batch_sz, *shape) for k, shape in obs_space.items()
144 | }
145 | action_spec = {k: (exp_len, batch_sz) for k in act_space.keys()}
146 | internal_spec = {
147 | k: (exp_len, batch_sz, *shape)
148 | for k, shape in internal_space.items()
149 | }
150 |
151 | spec = {
152 | "log_probs": (exp_len, batch_sz, act_key_len),
153 | **obs_spec,
154 | **action_spec,
155 | **internal_spec,
156 | }
157 |
158 | return spec
159 |
160 | @classmethod
161 | def _key_types(cls, obs_space, act_space, internal_space):
162 | d = defaultdict(lambda: "float")
163 | for k in act_space.keys():
164 | d[k] = "long"
165 | # TODO this needs a better solution
166 | for k in obs_space.keys():
167 | d[k] = "byte"
168 | return d
169 |
--------------------------------------------------------------------------------
/adept/actor/ppo.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from collections import OrderedDict
16 |
17 | import torch
18 |
19 | from adept.actor.base.ac_helper import ACActorHelperMixin
20 | from adept.actor.base.actor_module import ActorModule
21 |
22 |
23 | class PPOActorTrain(ActorModule, ACActorHelperMixin):
24 | args = {}
25 |
26 | @classmethod
27 | def from_args(cls, args, action_space):
28 | return cls(action_space)
29 |
30 | @staticmethod
31 | def output_space(action_space):
32 | head_dict = {'critic': (1,), **action_space}
33 | return head_dict
34 |
35 | def compute_action_exp(self, preds, internals, obs, available_actions):
36 | values = preds['critic'].squeeze(1)
37 |
38 | actions = OrderedDict()
39 | actions_gpu = OrderedDict()
40 | log_probs = []
41 |
42 | for key in self.action_keys:
43 | logit = self.flatten_logits(preds[key])
44 |
45 | log_softmax, softmax = self.log_softmax(logit), self.softmax(logit)
46 | action = self.sample_action(softmax)
47 |
48 | log_probs.append(self.log_probability(log_softmax, action))
49 | actions_gpu[key] = action
50 | actions[key] = action.cpu()
51 |
52 | log_probs = torch.cat(log_probs, dim=1)
53 | internals = {k: torch.stack(vs) for k, vs in internals.items()}
54 |
55 | return actions, {
56 | 'log_probs': log_probs,
57 | 'values': values,
58 | **internals,
59 | **actions_gpu
60 | }
61 |
62 | @classmethod
63 | def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
64 | act_key_len = len(act_space.keys())
65 | action_spec = {k: (exp_len, batch_sz) for k in act_space.keys()}
66 | internal_spec = {
67 | k: (exp_len, batch_sz, *shape) for k, shape in internal_space.items()
68 | }
69 | obs_spec = {k: (exp_len + 1, batch_sz, *shape) for k, shape in
70 | obs_space.items()}
71 |
72 | spec = {
73 | 'log_probs': (exp_len, batch_sz, act_key_len),
74 | 'values': (exp_len, batch_sz),
75 | **action_spec,
76 | **obs_spec,
77 | **internal_spec
78 | }
79 |
80 | return spec
81 |
--------------------------------------------------------------------------------
/adept/agent/__init__.py:
--------------------------------------------------------------------------------
1 | from .base.agent_module import AgentModule
2 | from .actor_critic import ActorCritic
3 | from .ppo import PPO
4 |
5 | AGENT_REG = [
6 | ActorCritic,
7 | PPO
8 | ]
9 |
--------------------------------------------------------------------------------
/adept/agent/actor_critic.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from adept.actor import ACRolloutActorTrain
16 | from adept.exp import Rollout
17 | from adept.learner import ACRolloutLearner
18 | from .base.agent_module import AgentModule
19 |
20 |
21 | class ActorCritic(AgentModule):
22 | args = {**Rollout.args, **ACRolloutActorTrain.args, **ACRolloutLearner.args}
23 |
24 | def __init__(
25 | self,
26 | reward_normalizer,
27 | action_space,
28 | spec_builder,
29 | rollout_len,
30 | discount,
31 | normalize_advantage,
32 | entropy_weight,
33 | return_scale,
34 | ):
35 | super(ActorCritic, self).__init__(reward_normalizer, action_space)
36 | self.discount = discount
37 | self.normalize_advantage = normalize_advantage
38 | self.entropy_weight = entropy_weight
39 |
40 | self._exp_cache = Rollout(spec_builder, rollout_len)
41 | self._actor = ACRolloutActorTrain(action_space)
42 | self._learner = ACRolloutLearner(
43 | reward_normalizer,
44 | discount,
45 | normalize_advantage,
46 | entropy_weight,
47 | return_scale,
48 | )
49 |
50 | @classmethod
51 | def from_args(
52 | cls, args, reward_normalizer, action_space, spec_builder, **kwargs
53 | ):
54 | return cls(
55 | reward_normalizer,
56 | action_space,
57 | spec_builder,
58 | rollout_len=args.rollout_len,
59 | discount=args.discount,
60 | normalize_advantage=args.normalize_advantage,
61 | entropy_weight=args.entropy_weight,
62 | return_scale=args.return_scale,
63 | )
64 |
65 | @property
66 | def exp_cache(self):
67 | return self._exp_cache
68 |
69 | @classmethod
70 | def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
71 | return ACRolloutActorTrain._exp_spec(
72 | exp_len, batch_sz, obs_space, act_space, internal_space
73 | )
74 |
75 | @staticmethod
76 | def output_space(action_space):
77 | return ACRolloutActorTrain.output_space(action_space)
78 |
79 | def compute_action_exp(
80 | self, predictions, internals, obs, available_actions
81 | ):
82 | return self._actor.compute_action_exp(
83 | predictions, internals, obs, available_actions
84 | )
85 |
86 | def learn_step(self, updater, network, next_obs, internals):
87 | return self._learner.learn_step(
88 | updater, network, self.exp_cache.read(), next_obs, internals
89 | )
90 |
--------------------------------------------------------------------------------
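
ActorCritic composes a Rollout cache, an ACRolloutActorTrain, and an ACRolloutLearner behind the AgentModule interface. A small check of the delegated output_space (the action space value is illustrative):

from adept.agent import ActorCritic

# The critic head is appended alongside whatever action heads the env defines.
space = ActorCritic.output_space({"Discrete": (6,)})
assert space == {"critic": (1,), "Discrete": (6,)}
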
/adept/agent/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/adept/agent/base/__init__.py
--------------------------------------------------------------------------------
/adept/agent/base/agent_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | """
16 | An Agent interacts with the environment and accumulates experience.
17 | """
18 | from collections import defaultdict
19 |
20 | import abc
21 |
22 | from adept.exp import ExpSpecBuilder
23 | from adept.utils.requires_args import RequiresArgsMixin
24 |
25 |
26 | class AgentModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
27 | """
28 | An Agent is an Actor (chooses actions) and a Learner (updates parameters)
29 | and maintains a cache to store a rollout or experience replay.
30 |
31 | Actors and Learners are treated separately for Actor-Learner architectures
32 | where multiple actors send their experience to a single learner.
33 | """
34 |
35 | def __init__(self, reward_normalizer, action_space):
36 | self._reward_normalizer = reward_normalizer
37 | self._action_space = action_space
38 |
39 | @classmethod
40 | @abc.abstractmethod
41 | def from_args(
42 | cls, args, reward_normalizer, action_space, spec_builder, **kwargs
43 | ):
44 | raise NotImplementedError
45 |
46 | @property
47 | @abc.abstractmethod
48 | def exp_cache(self):
49 | """Get experience cache"""
50 | raise NotImplementedError
51 |
52 | @property
53 | def action_space(self):
54 | return self._action_space
55 |
56 | @property
57 | def action_keys(self):
58 | return list(sorted(self.action_space.keys()))
59 |
60 | @classmethod
61 | def exp_spec_builder(cls, obs_space, act_space, internal_space, batch_sz):
62 | def build_fn(exp_len):
63 | exp_space = cls._exp_spec(
64 | exp_len, batch_sz, obs_space, act_space, internal_space
65 | )
66 | env_space = {
67 | "rewards": (exp_len, batch_sz),
68 | "terminals": (exp_len, batch_sz),
69 | }
70 | return {**exp_space, **env_space}
71 |
72 | key_types = cls._key_types(obs_space, act_space, internal_space)
73 | exp_keys = cls._exp_keys(obs_space, act_space, internal_space)
74 | return ExpSpecBuilder(
75 | obs_space, act_space, internal_space, key_types, exp_keys, build_fn
76 | )
77 |
78 | @classmethod
79 | @abc.abstractmethod
80 | def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
81 | raise NotImplementedError
82 |
83 | @classmethod
84 | def _exp_keys(cls, obs_space, act_space, internal_space):
85 | dummy = cls._exp_spec(1, 1, obs_space, act_space, internal_space)
86 | return dummy.keys()
87 |
88 | @classmethod
89 | def _key_types(cls, obs_space, act_space, internal_space):
90 | return defaultdict(lambda: "float")
91 |
92 | @staticmethod
93 | @abc.abstractmethod
94 | def output_space(action_space):
95 | raise NotImplementedError
96 |
97 | @abc.abstractmethod
98 | def compute_action_exp(
99 | self, predictions, internals, obs, available_actions
100 | ):
101 | raise NotImplementedError
102 |
103 | @abc.abstractmethod
104 | def learn_step(self, updater, network, next_obs, internals):
105 | raise NotImplementedError
106 |
107 | def is_ready(self):
108 | return self.exp_cache.is_ready()
109 |
110 | def clear(self):
111 | self.exp_cache.clear()
112 |
113 | def act(self, network, obs, prev_internals):
114 | """
115 | :param network: NetworkModule
116 | :param obs: Dict[str, Tensor]
117 | :param prev_internals: previous internal states. Dict[str, Tensor]
118 | :return:
119 | actions: Dict[ActionKey, LongTensor (B)]
120 | internal_states: Dict[str, Tensor]
121 | """
122 | predictions, internal_states, pobs = network(obs, prev_internals)
123 |
124 | if "available_actions" in obs:
125 | av_actions = obs["available_actions"]
126 | else:
127 | av_actions = None
128 |
129 | actions, experience = self.compute_action_exp(
130 | predictions, prev_internals, pobs, av_actions
131 | )
132 | self.exp_cache.write_actor(experience)
133 | return actions, internal_states
134 |
135 | def observe(self, obs, rewards, terminals, infos):
136 | self.exp_cache.write_env(obs, rewards, terminals, infos)
137 | return rewards, terminals, infos
138 |
139 | def to(self, device):
140 | self.exp_cache.to(device)
141 | return self
142 |
--------------------------------------------------------------------------------
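
The methods above define the per-step contract a container drives: act() caches actor experience, observe() caches environment experience, and once is_ready() reports a full rollout, learn_step() consumes the cache. A simplified sketch of that loop, assuming the env manager exposes an nb_env attribute and that all objects are already constructed (the real loop, with device transfers and logging, lives under adept/container/):

from adept.utils import listd_to_dlist


def train_sketch(agent, network, env_mgr, updater, device, nb_step):
    """Simplified sketch of the AgentModule contract; not the repo's loop."""
    obs = env_mgr.reset()
    internals = listd_to_dlist(
        [network.new_internals(device) for _ in range(env_mgr.nb_env)]
    )
    for _ in range(nb_step):
        actions, internals = agent.act(network, obs, internals)  # caches actor exp
        obs, rewards, terminals, infos = env_mgr.step(actions)
        agent.observe(obs, rewards, terminals, infos)            # caches env exp
        if agent.is_ready():                                     # full rollout cached
            agent.learn_step(updater, network, obs, internals)
            agent.clear()
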
/adept/app.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (C) 2018 Heron Systems, Inc.
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 | __ __
18 | ____ _____/ /__ ____ / /_
19 | / __ `/ __ / _ \/ __ \/ __/
20 | / /_/ / /_/ / __/ /_/ / /_
21 | \__,_/\__,_/\___/ .___/\__/
22 | /_/
23 |
24 | Usage:
25 | adept.app <command> [<args>...]
26 | adept.app (-h | --help)
27 | adept.app --version
28 |
29 | Commands:
30 | local Train an agent on a single GPU.
31 | distrib Train an agent on multiple machines and/or GPUs.
32 | actorlearner Train an agent on multiple machines and/or GPUs.
33 | evaluate Evaluate a trained agent.
34 | render Visualize an agent playing an Atari game.
35 | replay_gen_sc2 Generate SC2 replay files of an agent playing SC2.
36 |
37 | See 'adept.app help <command>' for more information on a specific command.
38 | """
39 | from docopt import docopt
40 | from adept.globals import VERSION
41 | from subprocess import call
42 | import os
43 |
44 |
45 | def parse_args():
46 | args = docopt(
47 | __doc__, version="adept version " + VERSION, options_first=True
48 | )
49 |
50 | env = os.environ
51 | argv = args["<args>"]
52 | if args["<command>"] == "local":
53 | exit(call(["python", "-m", "adept.scripts.local"] + argv, env=env))
54 | elif args["<command>"] == "distrib":
55 | exit(call(["python", "-m", "adept.scripts.distrib"] + argv, env=env))
56 | elif args["<command>"] == "actorlearner":
57 | exit(
58 | call(["python", "-m", "adept.scripts.actorlearner"] + argv, env=env)
59 | )
60 | elif args["<command>"] == "evaluate":
61 | exit(call(["python", "-m", "adept.scripts.evaluate"] + argv, env=env))
62 | elif args["<command>"] == "render":
63 | exit(call(["python", "-m", "adept.scripts.render"] + argv, env=env))
64 | elif args["<command>"] == "replay_gen_sc2":
65 | exit(
66 | call(
67 | ["python", "-m", "adept.scripts.replay_gen_sc2"] + argv, env=env
68 | )
69 | )
70 | elif args["<command>"] == "help":
71 | if "local" in args["<args>"]:
72 | exit(call(["python", "-m", "adept.scripts.local", "-h"]))
73 | elif "distrib" in args[""]:
74 | exit(call(["python", "-m", "adept.scripts.distrib", "-h"]))
75 | elif "actorlearner" in args[""]:
76 | exit(call(["python", "-m", "adept.scripts.actorlearner", "-h"]))
77 | elif "evaluate" in args[""]:
78 | exit(call(["python", "-m", "adept.scripts.evaluate", "-h"]))
79 | elif "render" in args[""]:
80 | exit(call(["python", "-m", "adept.scripts.render", "-h"]))
81 | elif "replay_gen_sc2" in args[""]:
82 | exit(call(["python", "-m", "adept.scripts.replay_gen_sc2", "-h"]))
83 | else:
84 | exit(
85 | "{} is not a valid command. See 'adept.app --help'.".format(
86 | args[""]
87 | )
88 | )
89 |
90 |
91 | if __name__ == "__main__":
92 | parse_args()
93 |
--------------------------------------------------------------------------------
/adept/container/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from .local import Local
16 | from .distrib import DistribHost, DistribWorker
17 | from .actorlearner import ActorLearnerHost, ActorLearnerWorker
18 | from .evaluation import EvalContainer
19 | from .init import Init
20 | from .base.updater import Updater
21 |
--------------------------------------------------------------------------------
/adept/container/actorlearner/__init__.py:
--------------------------------------------------------------------------------
1 | from .learner_container import ActorLearnerHost
2 | from .rollout_worker import ActorLearnerWorker
3 |
--------------------------------------------------------------------------------
/adept/container/actorlearner/rollout_queuer.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from time import time
16 |
17 | import queue
18 | import ray
19 | import threading
20 | import torch
21 |
22 |
23 | class RolloutQueuerAsync:
24 | def __init__(self, workers, num_rollouts, queue_max_size, timeout=15.0):
25 | self.workers = workers
26 | self.num_rollouts = num_rollouts
27 | self.queue_max_size = queue_max_size
28 | self.queue_timeout = timeout
29 |
30 | self.futures = [w.run.remote() for w in self.workers]
31 | self.future_inds = [w for w in range(len(self.workers))]
32 | self._should_stop = True
33 | self.rollout_queue = queue.Queue(self.queue_max_size)
34 | self._worker_wait_time = 0
35 | self._host_wait_time = 0
36 |
37 | def _background_queing_thread(self):
38 | while not self._should_stop:
39 | ready_ids, remaining_ids = ray.wait(self.futures, num_returns=1)
40 |
41 | # if ray returns an empty list, every object has already been retrieved
42 | # (this should never happen)
43 | if len(ready_ids) == 0:
44 | print("WARNING: ray returned no ready rollouts")
45 | # otherwise rollout was returned
46 | else:
47 | for ready in ready_ids:
48 | # get object and add to queue
49 | rollouts = ray.get(ready)
50 | # this will block if queue is at max size
51 | self._add_to_queue(rollouts)
52 |
53 | # remove from futures
54 | self._idle_workers = []
55 | for ready in ready_ids:
56 | index = self.futures.index(ready)
57 | del self.futures[index]
58 | self._idle_workers.append(self.future_inds[index])
59 | del self.future_inds[index]
60 |
61 | # tell the worker(s) to start another rollout
62 | self._restart_idle_workers()
63 |
64 | # done, wait for all remaining to finish
65 | dones, not_dones = ray.wait(
66 | self.futures, len(self.futures), timeout=self.queue_timeout
67 | )
68 |
69 | if len(not_dones) > 0:
70 | print("WARNING: Not all rollout workers finished")
71 |
72 | def _add_to_queue(self, rollout):
73 | st = time()
74 | self.rollout_queue.put(rollout, timeout=self.queue_timeout)
75 | et = time()
76 | self._worker_wait_time += et - st
77 |
78 | def start(self):
79 | self._should_stop = False
80 | self.background_thread = threading.Thread(
81 | target=self._background_queing_thread
82 | )
83 | self.background_thread.start()
84 |
85 | def get(self):
86 | st = time()
87 | worker_data = [
88 | self.rollout_queue.get(True) for _ in range(self.num_rollouts)
89 | ]
90 | et = time()
91 | self._host_wait_time += et - st
92 |
93 | rollouts = []
94 | terminal_rewards = []
95 | terminal_infos = []
96 | for w in worker_data:
97 | r, t, i = w["rollout"], w["terminal_rewards"], w["terminal_infos"]
98 | rollouts.append(r)
99 | terminal_rewards.append(t)
100 | terminal_infos.append(i)
101 |
102 | return rollouts, terminal_rewards, terminal_infos
103 |
104 | def close(self):
105 | self._should_stop = True
106 |
107 | # try to join background thread
108 | self.background_thread.join()
109 |
110 | def metrics(self):
111 | return {
112 | "Host wait time": self._host_wait_time,
113 | "Worker wait time": self._worker_wait_time,
114 | }
115 |
116 | def _restart_idle_workers(self):
117 | del_inds = []
118 | for d_ind, w_ind in enumerate(self._idle_workers):
119 | worker = self.workers[w_ind]
120 | self.futures.append(worker.run.remote())
121 | self.future_inds.append(w_ind)
122 | del_inds.append(d_ind)
123 |
124 | for d in del_inds:
125 | del self._idle_workers[d]
126 |
127 | def __len__(self):
128 | return self.rollout_queue.qsize()  # queue.Queue has no __len__
129 |
--------------------------------------------------------------------------------
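
RolloutQueuerAsync pulls finished rollouts off ray futures on a background thread and buffers them in a bounded queue for the learner. A hedged usage sketch, assuming `workers` is a list of ray actor handles whose run.remote() returns the dict this class expects:

from adept.container.actorlearner.rollout_queuer import RolloutQueuerAsync

# `workers` is assumed to exist: ray actor handles whose run.remote() yields a
# dict with "rollout", "terminal_rewards", and "terminal_infos" keys.
queuer = RolloutQueuerAsync(workers, num_rollouts=2, queue_max_size=4)
queuer.start()                     # spawns the background queueing thread
try:
    rollouts, terminal_rewards, terminal_infos = queuer.get()  # blocks for 2 rollouts
finally:
    queuer.close()                 # signals the thread to stop and joins it
print(queuer.metrics())            # host/worker wait-time accounting
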
/adept/container/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .container import Container
2 | from .nccl_optimizer import NCCLOptimizer
3 |
--------------------------------------------------------------------------------
/adept/container/base/container.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Container:
5 | @staticmethod
6 | def load_network(network, path):
7 | network.load_state_dict(
8 | torch.load(path, map_location=lambda storage, loc: storage)
9 | )
10 | return network
11 |
12 | @staticmethod
13 | def load_optim(optimizer, path):
14 | optimizer.load_state_dict(
15 | torch.load(path, map_location=lambda storage, loc: storage)
16 | )
17 | return optimizer
18 |
19 | @staticmethod
20 | def init_next_save(initial_step_count, epoch_len):
21 | next_save = 0
22 | if initial_step_count > 0:
23 | while next_save <= initial_step_count:
24 | next_save += epoch_len
25 | return next_save
26 |
27 | @staticmethod
28 | def count_parameters(net):
29 | return sum(p.numel() for p in net.parameters() if p.requires_grad)
30 |
31 | @staticmethod
32 | def write_summaries(
33 | writer, step_count, total_loss, loss_dict, metric_dict, n_params
34 | ):
35 | writer.add_scalar("loss/total_loss", total_loss.item(), step_count)
36 | for l_name, loss in loss_dict.items():
37 | writer.add_scalar("loss/" + l_name, loss.item(), step_count)
38 | for m_name, metric in metric_dict.items():
39 | writer.add_scalar("metric/" + m_name, metric.item(), step_count)
40 | for p_name, param in n_params:
41 | p_name = p_name.replace(".", "/")
42 | writer.add_scalar(p_name, torch.norm(param).item(), step_count)
43 | if param.grad is not None:
44 | writer.add_scalar(
45 | p_name + ".grad", torch.norm(param.grad).item(), step_count
46 | )
47 |
--------------------------------------------------------------------------------
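
init_next_save rounds the resume step up to the next epoch boundary so checkpointing stays aligned after a restart. A quick worked example (the numbers are mine, purely illustrative):

from adept.container.base import Container

assert Container.init_next_save(initial_step_count=0, epoch_len=1000) == 0
# Resuming from step 2500 with 1000-step epochs: next checkpoint at step 3000.
assert Container.init_next_save(initial_step_count=2500, epoch_len=1000) == 3000
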
/adept/container/base/nccl_optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 |
16 |
17 | class NCCLOptimizer:
18 | def __init__(self, optimizer_fn, network, world_size, param_sync_rate=1000):
19 | self.network = network
20 | self.optimizer = optimizer_fn(self.network.parameters())
21 | self.param_sync_rate = param_sync_rate
22 | self._opt_count = 0
23 | self.process_group = None
24 | self.world_size = world_size
25 |
26 | def set_process_group(self, pg):
27 | self.process_group = pg
28 |
29 | def step(self):
30 | handles = []
31 | for param in self.network.parameters():
32 | handles.append(self.process_group.allreduce(param.grad))
33 | for handle in handles:
34 | handle.wait()
35 | for param in self.network.parameters():
36 | param.grad.mul_(1.0 / self.world_size)
37 | self.optimizer.step()
38 | self._opt_count += 1
39 |
40 | # sync params every once in a while to reduce numerical errors
41 | if self._opt_count % self.param_sync_rate == 0:
42 | self.sync_parameters()
43 | # can't just sync buffers, some are int and don't mean well
44 | # self.sync_buffers()
45 |
46 | def sync_parameters(self):
47 | for param in self.network.parameters():
48 | self.process_group.allreduce(param.data)
49 | param.data.mul_(1.0 / self.world_size)
50 |
51 | def sync_buffers(self):
52 | for b in self.network.buffers():
53 | self.process_group.allreduce(b.data)
54 | b.data.mul_(1.0 / self.world_size)
55 |
56 | def zero_grad(self):
57 | self.optimizer.zero_grad()
58 |
59 | def state_dict(self):
60 | return self.optimizer.state_dict()
61 |
62 | def load_state_dict(self, d):
63 | return self.optimizer.load_state_dict(d)
64 |
--------------------------------------------------------------------------------
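
NCCLOptimizer wraps a plain torch optimizer: step() all-reduces and averages gradients across the process group before applying the update, and periodically re-syncs parameters to bound numerical drift. A hedged wiring sketch, assuming torch.distributed has already been initialized with the NCCL backend (as the distrib containers do) and that `network` and `loss` exist:

import torch
import torch.distributed as dist

from adept.container.base import NCCLOptimizer

# Assumes dist.init_process_group("nccl", ...) was called elsewhere, and that
# `network` (an nn.Module) and `loss` (a scalar tensor) already exist.
world_size = dist.get_world_size()
optimizer_fn = lambda params: torch.optim.Adam(params, lr=3e-4)

nccl_optim = NCCLOptimizer(optimizer_fn, network, world_size)
nccl_optim.set_process_group(dist.new_group(ranks=list(range(world_size))))

# One update: backward() fills grads; step() all-reduces, averages, and applies.
nccl_optim.zero_grad()
loss.backward()
nccl_optim.step()
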
/adept/container/base/updater.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Updater(metaclass=abc.ABCMeta):
5 | def __init__(self, optimizer, network, grad_norm_clip):
6 | self.optimizer = optimizer
7 | self.network = network
8 | self.grad_norm_clip = grad_norm_clip
9 |
10 | def step(self, loss):
11 | raise NotImplementedError
12 |
--------------------------------------------------------------------------------
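
Updater is the abstract seam between a learner's loss and the optimizer. A minimal concrete subclass (my sketch, not a class from this repository) that clips by global gradient norm before stepping:

import torch

from adept.container.base.updater import Updater


class ClippedUpdater(Updater):
    """Sketch: zero grads, backprop the loss, clip, then step the optimizer."""

    def step(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.network.parameters(), self.grad_norm_clip
        )
        self.optimizer.step()
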
/adept/container/evaluation_thread.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import numpy as np
16 | from threading import Thread
17 | from time import time, sleep
18 |
19 |
20 | class EvaluationThread(LogsAndSummarizesRewards):
21 | def __init__(
22 | self,
23 | training_network,
24 | agent,
25 | env,
26 | nb_env,
27 | logger,
28 | summary_writer,
29 | step_rate_limit,
30 | override_step_count_fn=None,
31 | ):
32 | self._training_network = training_network
33 | self._agent = agent
34 | self._environment = env
35 | self._nb_env = nb_env
36 | self._logger = logger
37 | self._summary_writer = summary_writer
38 | self._step_rate_limit = step_rate_limit
39 | self._override_step_count_fn = override_step_count_fn
40 | self._thread = Thread(target=self._run)
41 | self._should_stop = False
42 |
43 | def start(self):
44 | self._thread.start()
45 |
46 | def stop(self):
47 | self._should_stop = True
48 | self._thread.join()
49 |
50 | def _run(self):
51 | next_obs = self.environment.reset()
52 | self.start_time = time()
53 | while not self._should_stop:
54 | if self._step_rate_limit > 0:
55 | sleep(1 / self._step_rate_limit)
56 | obs = next_obs
57 | actions = self.agent.act_eval(obs)
58 | next_obs, rewards, terminals, infos = self.environment.step(actions)
59 |
60 | self.agent.reset_internals(terminals)
61 | # Perform state updates
62 | terminal_rewards, terminal_infos = self.update_buffers(
63 | rewards, terminals, infos
64 | )
65 | self.log_episode_results(
66 | terminal_rewards, terminal_infos, self.local_step_count
67 | )
68 | self.write_reward_summaries(terminal_rewards, self.local_step_count)
69 |
70 | if np.any(terminals) and np.any(infos):
71 | self.network.load_state_dict(
72 | self._training_network.state_dict()
73 | )
74 |
75 | def log_episode_results(
76 | self, terminal_rewards, terminal_infos, step_count, initial_step_count=0
77 | ):
78 | if terminal_rewards:
79 | ep_reward = np.mean(terminal_rewards)
80 | self.logger.info(
81 | "eval_frames: {} reward: {} avg_eval_fps: {}".format(
82 | step_count,
83 | ep_reward,
84 | (step_count - initial_step_count)
85 | / (time() - self.start_time),
86 | )
87 | )
88 | return terminal_rewards
89 |
90 | @property
91 | def agent(self):
92 | return self._agent
93 |
94 | @property
95 | def environment(self):
96 | return self._environment
97 |
98 | @environment.setter
99 | def environment(self, new_env):
100 | self._environment = new_env
101 |
102 | @property
103 | def nb_env(self):
104 | return self._nb_env
105 |
106 | @property
107 | def logger(self):
108 | return self._logger
109 |
110 | @property
111 | def summary_writer(self):
112 | return self._summary_writer
113 |
114 | @property
115 | def summary_name(self):
116 | return "reward/eval"
117 |
118 | @property
119 | def local_step_count(self):
120 | if self._override_step_count_fn is not None:
121 | return self._override_step_count_fn()
122 | else:
123 | return self._local_step_count
124 |
125 | @local_step_count.setter
126 | def local_step_count(self, step_count):
127 | self._local_step_count = step_count
128 |
--------------------------------------------------------------------------------
/adept/container/render.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import torch
16 |
17 | from adept.network import ModularNetwork
18 | from adept.registry import REGISTRY
19 | from adept.utils import dtensor_to_dev, listd_to_dlist
20 | from adept.utils.script_helpers import LogDirHelper
21 | from adept.utils.util import DotDict
22 |
23 |
24 | class RenderContainer:
25 | def __init__(
26 | self,
27 | actor,
28 | epoch_id,
29 | start,
30 | end,
31 | logger,
32 | log_id_dir,
33 | gpu_id,
34 | seed,
35 | manager,
36 | extra_args={},
37 | ):
38 | self.log_dir_helper = log_dir_helper = LogDirHelper(log_id_dir)
39 | self.train_args = train_args = log_dir_helper.load_args()
40 | self.train_args = DotDict({**self.train_args, **extra_args})
41 | self.device = device = self._device_from_gpu_id(gpu_id)
42 | self.logger = logger
43 |
44 | if epoch_id:
45 | epoch_ids = [epoch_id]
46 | else:
47 | epoch_ids = self.log_dir_helper.epochs()
48 | epoch_ids = filter(lambda eid: eid >= start, epoch_ids)
49 | if end != -1.0:
50 | epoch_ids = filter(lambda eid: eid <= end, epoch_ids)
51 | epoch_ids = list(epoch_ids)
52 | self.epoch_ids = epoch_ids
53 |
54 | engine = REGISTRY.lookup_engine(train_args.env)
55 | env_cls = REGISTRY.lookup_env(train_args.env)
56 | manager_cls = REGISTRY.lookup_manager(manager)
57 | self.env_mgr = manager_cls.from_args(
58 | self.train_args, engine, env_cls, seed=seed, nb_env=1
59 | )
60 | if train_args.agent:
61 | agent = train_args.agent
62 | else:
63 | agent = train_args.actor_host
64 | output_space = REGISTRY.lookup_output_space(
65 | agent, self.env_mgr.action_space
66 | )
67 | actor_cls = REGISTRY.lookup_actor(actor)
68 | self.actor = actor_cls.from_args(
69 | actor_cls.prompt(), self.env_mgr.action_space
70 | )
71 |
72 | self.network = self._init_network(
73 | train_args,
74 | self.env_mgr.observation_space,
75 | self.env_mgr.gpu_preprocessor,
76 | output_space,
77 | REGISTRY,
78 | ).to(device)
79 |
80 | @staticmethod
81 | def _device_from_gpu_id(gpu_id):
82 | return torch.device(
83 | "cuda:{}".format(gpu_id)
84 | if (torch.cuda.is_available() and gpu_id >= 0)
85 | else "cpu"
86 | )
87 |
88 | @staticmethod
89 | def _init_network(
90 | train_args, obs_space, gpu_preprocessor, output_space, net_reg
91 | ):
92 | if train_args.custom_network:
93 | net_cls = net_reg.lookup_network(train_args.custom_network)
94 | else:
95 | net_cls = ModularNetwork
96 |
97 | return net_cls.from_args(
98 | train_args, obs_space, output_space, gpu_preprocessor, net_reg
99 | )
100 |
101 | def run(self):
102 | for epoch_id in self.epoch_ids:
103 | reward_buf = 0
104 | for net_path in self.log_dir_helper.network_paths_at_epoch(
105 | epoch_id
106 | ):
107 | self.network.load_state_dict(
108 | torch.load(
109 | net_path, map_location=lambda storage, loc: storage
110 | )
111 | )
112 | self.network.eval()
113 |
114 | internals = listd_to_dlist(
115 | [self.network.new_internals(self.device)]
116 | )
117 | next_obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
118 | self.env_mgr.render()
119 |
120 | episode_complete = False
121 | while not episode_complete:
122 | obs = next_obs
123 | with torch.no_grad():
124 | actions, _, internals = self.actor.act(
125 | self.network, obs, internals
126 | )
127 | next_obs, rewards, terminals, infos = self.env_mgr.step(
128 | actions
129 | )
130 | self.env_mgr.render()
131 | next_obs = dtensor_to_dev(next_obs, self.device)
132 |
133 | reward_buf += rewards[0]
134 |
135 | if terminals[0]:
136 | episode_complete = True
137 |
138 | print(f"EPOCH_ID: {epoch_id} REWARD: {reward_buf}")
139 |
140 | def close(self):
141 | self.env_mgr.close()
142 |
--------------------------------------------------------------------------------
/adept/env/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from adept.env.openai_gym import ATARI_ENVS
16 | from .base.env_module import EnvModule
17 | from .openai_gym import AdeptGymEnv
18 |
19 | ENV_REG = [AdeptGymEnv]
20 |
--------------------------------------------------------------------------------
/adept/env/_spaces.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from gym import spaces
16 |
17 |
18 | class Space(dict):
19 | def __init__(self, entries_by_name):
20 | super(Space, self).__init__(entries_by_name)
21 |
22 | @classmethod
23 | def from_gym(cls, gym_space):
24 | entries_by_name = Space._detect_gym_spaces(gym_space)
25 | return cls(entries_by_name)
26 |
27 | @staticmethod
28 | def _detect_gym_spaces(gym_space):
29 | if isinstance(gym_space, spaces.Discrete):
30 | return {"Discrete": (gym_space.n,)}
31 | elif isinstance(gym_space, spaces.MultiDiscrete):
32 | raise NotImplementedError
33 | elif isinstance(gym_space, spaces.MultiBinary):
34 | return {"MultiBinary": (gym_space.n,)}
35 | elif isinstance(gym_space, spaces.Box):
36 | return {"Box": gym_space.shape}
37 | elif isinstance(gym_space, spaces.Dict):
38 | return {
39 | name: list(Space._detect_gym_spaces(s).values())[0]
40 | for name, s in gym_space.spaces.items()
41 | }
42 | elif isinstance(gym_space, spaces.Tuple):
43 | return {
44 | idx: list(Space._detect_gym_spaces(s).values())[0]
45 | for idx, s in enumerate(gym_space.spaces)
46 | }
47 |
48 | @staticmethod
49 | def dtypes_from_gym(gym_space):
50 | if isinstance(gym_space, spaces.Discrete):
51 | return {"Discrete": gym_space.dtype}
52 | elif isinstance(gym_space, spaces.MultiDiscrete):
53 | raise NotImplementedError
54 | elif isinstance(gym_space, spaces.MultiBinary):
55 | return {"MultiBinary": gym_space.dtype}
56 | elif isinstance(gym_space, spaces.Box):
57 | return {"Box": gym_space.dtype}
58 | elif isinstance(gym_space, spaces.Dict):
59 | return {
60 | name: list(Space.dtypes_from_gym(s).values())[0]
61 | for name, s in gym_space.spaces.items()
62 | }
63 | elif isinstance(gym_space, spaces.Tuple):
64 | return {
65 | idx: list(Space.dtypes_from_gym(s).values())[0]
66 | for idx, s in enumerate(gym_space.spaces)
67 | }
68 | else:
69 | raise NotImplementedError
70 |
--------------------------------------------------------------------------------
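A minimal usage sketch for the Space helpers above, assuming gym is installed; the space names and shapes here are illustrative only:

from gym import spaces

from adept.env._spaces import Space

# A gym Dict space flattens to a name -> shape mapping; leaf spaces use
# their conventional keys ("Box", "Discrete", ...) when passed directly.
obs_space = spaces.Dict({
    "screen": spaces.Box(low=0, high=255, shape=(3, 84, 84), dtype="uint8"),
    "inventory": spaces.Discrete(10),
})
print(Space.from_gym(obs_space))         # {'screen': (3, 84, 84), 'inventory': (10,)}
print(Space.dtypes_from_gym(obs_space))  # {'screen': dtype('uint8'), 'inventory': dtype('int64')}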
/adept/env/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/adept/env/base/__init__.py
--------------------------------------------------------------------------------
/adept/env/base/_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import abc
16 |
17 |
18 | class HasEnvMetaData(metaclass=abc.ABCMeta):
19 | @property
20 | @abc.abstractmethod
21 | def observation_space(self):
22 | raise NotImplementedError
23 |
24 | @property
25 | @abc.abstractmethod
26 | def action_space(self):
27 | raise NotImplementedError
28 |
29 | @property
30 | @abc.abstractmethod
31 | def cpu_preprocessor(self):
32 | raise NotImplementedError
33 |
34 | @property
35 | @abc.abstractmethod
36 | def gpu_preprocessor(self):
37 | raise NotImplementedError
38 |
39 |
40 | class EnvBase(HasEnvMetaData, metaclass=abc.ABCMeta):
41 | @abc.abstractmethod
42 | def step(self, action):
43 | """
44 | :param action: Dict[ActionID, Any] Action dictionary
45 | :return: Tuple[Observation, Reward, Terminal, Info]
46 | """
47 | raise NotImplementedError
48 |
49 | @abc.abstractmethod
50 | def reset(self, **kwargs):
51 | """
52 | :param kwargs:
53 | :return: Dict[ObservationID, Any] Observation dictionary
54 | """
55 | raise NotImplementedError
56 |
57 | @abc.abstractmethod
58 | def close(self):
59 | """
60 | Close environment. Release resources.
61 |
62 | :return:
63 | """
64 | raise NotImplementedError
65 |
--------------------------------------------------------------------------------
/adept/env/base/env_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import abc
16 |
17 | from adept.utils.requires_args import RequiresArgsMixin
18 | from adept.env.base._env import EnvBase
19 |
20 |
21 | class EnvModule(EnvBase, RequiresArgsMixin, metaclass=abc.ABCMeta):
22 | ids = None
23 |
24 | """
25 | Implement this class to add your custom environment. Don't forget to
26 | implement args.
27 | """
28 |
29 | def __init__(self, action_space, cpu_preprocessor, gpu_preprocessor):
30 | """
31 | Note: observation_space is derived from the gpu_preprocessor.
32 | :param action_space: ._spaces.Space
33 | :param cpu_preprocessor: adept.preprocess.observation.ObsPreprocessor
34 | :param gpu_preprocessor: adept.preprocess.observation.ObsPreprocessor
35 | """
36 | self._action_space = action_space
37 | self._cpu_preprocessor = cpu_preprocessor
38 | self._gpu_preprocessor = gpu_preprocessor
39 |
40 | @classmethod
41 | @abc.abstractmethod
42 | def from_args(cls, args, seed, **kwargs):
43 | """
44 | Construct from arguments. For convenience.
45 |
46 | :param args: Arguments object
47 | :param seed: Integer used to seed this environment.
48 | :param kwargs: Any custom arguments are passed through kwargs.
49 | :return: EnvModule instance.
50 | """
51 | raise NotImplementedError
52 |
53 | @classmethod
54 | def from_args_curry(cls, args, seed, **kwargs):
55 | def _f():
56 | return cls.from_args(args, seed, **kwargs)
57 |
58 | return _f
59 |
60 | @property
61 | def observation_space(self):
62 | return self._gpu_preprocessor.observation_space
63 |
64 | @property
65 | def action_space(self):
66 | return self._action_space
67 |
68 | @property
69 | def cpu_preprocessor(self):
70 | return self._cpu_preprocessor
71 |
72 | @property
73 | def gpu_preprocessor(self):
74 | return self._gpu_preprocessor
75 |
76 | @classmethod
77 | def check_ids_implemented(cls):
78 | if cls.ids is None:
79 | raise NotImplementedError(
80 | 'Subclass must define class attribute "ids"'
81 | )
82 |
--------------------------------------------------------------------------------
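To make the required surface concrete, a bare-bones subclass needs the two class attributes plus the abstract methods inherited from EnvModule and EnvBase. The sketch below is illustrative only; the class name and id are hypothetical, and the space/preprocessor construction is omitted.

from adept.env.base.env_module import EnvModule


class MyEnv(EnvModule):  # hypothetical name
    args = {}            # RequiresArgsMixin: default tunables for this module
    ids = ["MyEnv-v0"]   # environment ids this module can construct

    def __init__(self, action_space, cpu_preprocessor, gpu_preprocessor):
        super().__init__(action_space, cpu_preprocessor, gpu_preprocessor)

    @classmethod
    def from_args(cls, args, seed, **kwargs):
        # Build action_space and the CPU/GPU preprocessors here (omitted).
        raise NotImplementedError

    def step(self, action):
        # -> (obs_dict, reward, terminal, info)
        raise NotImplementedError

    def reset(self, **kwargs):
        # -> obs_dict
        raise NotImplementedError

    def close(self):
        pass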
/adept/exp/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import ExpModule
2 | from .base import ExpSpecBuilder
3 | from .replay import ExperienceReplay, PrioritizedExperienceReplay
4 | from .rollout import Rollout
5 |
6 | EXP_REG = [
7 | Rollout, # , ExperienceReplay, PrioritizedExperienceReplay
8 | ]
9 |
--------------------------------------------------------------------------------
/adept/exp/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .spec_builder import ExpSpecBuilder
2 | from .exp_module import ExpModule
3 |
--------------------------------------------------------------------------------
/adept/exp/base/exp_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import abc
16 |
17 | from adept.utils.requires_args import RequiresArgsMixin
18 |
19 |
20 | class ExpModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
21 | @abc.abstractmethod
22 | def write_actor(self, experience):
23 | raise NotImplementedError
24 |
25 | @abc.abstractmethod
26 | def write_env(self, obs, rewards, terminals, infos):
27 | raise NotImplementedError
28 |
29 | @abc.abstractmethod
30 | def read(self):
31 | raise NotImplementedError
32 |
33 | @abc.abstractmethod
34 | def is_ready(self):
35 | raise NotImplementedError
36 |
--------------------------------------------------------------------------------
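As a hedged illustration of the four hooks (not part of the library), a trivial list-backed cache could look like the following; the class name and the cache_len cutoff are invented for the example:

from adept.exp.base.exp_module import ExpModule


class ListCache(ExpModule):  # illustrative only
    args = {"cache_len": 20}

    def __init__(self, cache_len):
        self.cache_len = cache_len
        self.actor_xp = []
        self.env_xp = []

    @classmethod
    def from_args(cls, args, *unused_args, **unused_kwargs):
        return cls(args.cache_len)

    def write_actor(self, experience):
        self.actor_xp.append(experience)

    def write_env(self, obs, rewards, terminals, infos):
        self.env_xp.append((obs, rewards, terminals, infos))

    def read(self):
        actor_xp, env_xp = self.actor_xp, self.env_xp
        self.actor_xp, self.env_xp = [], []
        return actor_xp, env_xp

    def is_ready(self):
        return len(self.env_xp) >= self.cache_len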
/adept/exp/base/spec_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 |
16 |
17 | class ExpSpecBuilder:
18 | def __init__(
19 | self, obs_keys, act_keys, internal_keys, key_types, exp_keys, build_fn
20 | ):
21 | self.obs_keys = sorted(obs_keys.keys())
22 | self.action_keys = sorted(act_keys.keys())
23 | self.internal_keys = sorted(internal_keys.keys())
24 | self.key_types = key_types
25 | self.exp_keys = sorted(exp_keys)
26 | self.build_fn = build_fn
27 |
28 | def __call__(self, rollout_len):
29 | return self.build_fn(rollout_len)
30 |
--------------------------------------------------------------------------------
/adept/exp/replay.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from adept.exp.base.exp_module import ExpModule
16 |
17 |
18 | class ExperienceReplay(ExpModule):
19 | def write_actor(self, items):
20 | pass
21 |
22 | def write_env(self, obs, rewards, terminals, infos):
23 | pass
24 |
25 | def read(self):
26 | pass
27 |
28 | def is_ready(self):
29 | pass
30 |
31 |
32 | class PrioritizedExperienceReplay(ExpModule):
33 | def write_actor(self, items):
34 | pass
35 |
36 | def write_env(self, obs, rewards, terminals, infos):
37 | pass
38 |
39 | def read(self):
40 | pass
41 |
42 | def is_ready(self):
43 | pass
44 |
--------------------------------------------------------------------------------
/adept/globals.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | VERSION = "0.4.0dev0"
16 |
--------------------------------------------------------------------------------
/adept/learner/__init__.py:
--------------------------------------------------------------------------------
1 | from adept.learner.base.learner_module import LearnerModule
2 | from .ac_rollout import ACRolloutLearner
3 | from .impala import ImpalaLearner
4 |
5 | LEARNER_REG = [ACRolloutLearner, ImpalaLearner]
6 |
--------------------------------------------------------------------------------
/adept/learner/ac_rollout.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base.learner_module import LearnerModule
4 | from .base.dm_return_scale import DeepMindReturnScaler
5 |
6 |
7 | class ACRolloutLearner(LearnerModule):
8 | """
9 | Actor Critic Rollout Learner
10 | """
11 |
12 | args = {
13 | "discount": 0.99,
14 | "normalize_advantage": False,
15 | "entropy_weight": 0.01,
16 | "return_scale": False,
17 | }
18 |
19 | def __init__(
20 | self,
21 | reward_normalizer,
22 | discount,
23 | normalize_advantage,
24 | entropy_weight,
25 | return_scale,
26 | ):
27 | self.reward_normalizer = reward_normalizer
28 | self.discount = discount
29 | self.normalize_advantage = normalize_advantage
30 | self.entropy_weight = entropy_weight
31 | self.return_scale = return_scale
32 | if return_scale:
33 | self.dm_scaler = DeepMindReturnScaler(10.0 ** -3)
34 |
35 | @classmethod
36 | def from_args(cls, args, reward_normalizer):
37 | return cls(
38 | reward_normalizer,
39 | args.discount,
40 | args.normalize_advantage,
41 | args.entropy_weight,
42 | args.return_scale,
43 | )
44 |
45 | def learn_step(self, updater, network, experiences, next_obs, internals):
46 | # normalize rewards
47 | rewards = self.reward_normalizer(torch.stack(experiences.rewards))
48 |
49 | # torch stack rollouts
50 | r_log_probs_action = torch.stack(experiences.log_probs)
51 | r_values = torch.stack(experiences.values)
52 | r_entropies = torch.stack(experiences.entropies)
53 |
54 | # estimate value of next state
55 | with torch.no_grad():
56 | results, _, _ = network(next_obs, internals)
57 | last_values = results["critic"].squeeze(1).data
58 |
59 | # compute nstep return and advantage over batch
60 | r_tgt_returns = self.compute_returns(
61 | last_values, rewards, experiences.terminals
62 | )
63 | r_advantages = r_tgt_returns - r_values.data
64 |
65 | # normalize advantage so that roughly equal numbers
66 | # of actions are reinforced and penalized
67 | if self.normalize_advantage:
68 | r_advantages = (r_advantages - r_advantages.mean()) / (
69 | r_advantages.std() + 1e-5
70 | )
71 |
72 | # batched losses
73 | policy_loss = -(r_log_probs_action) * r_advantages.unsqueeze(-1)
74 | # mean over actions, seq, batch
75 | policy_loss = policy_loss.mean()
76 | entropy_loss = -r_entropies.mean() * self.entropy_weight
77 | value_loss = 0.5 * (r_tgt_returns - r_values).pow(2).mean()
78 |
79 | updater.step(value_loss + policy_loss + entropy_loss)
80 |
81 | losses = {
82 | "value_loss": value_loss,
83 | "policy_loss": policy_loss,
84 | "entropy_loss": entropy_loss,
85 | }
86 | metrics = {}
87 | return losses, metrics
88 |
89 | def compute_returns(self, bootstrap_value, rewards, terminals):
90 | # First step of nstep reward target is estimated value of t+1
91 | target_return = bootstrap_value
92 | rollout_len = len(rewards)
93 | nstep_target_returns = []
94 | for i in reversed(range(rollout_len)):
95 | reward = rewards[i]
96 | terminal_mask = 1.0 - terminals[i].float()
97 |
98 | if self.return_scale:
99 | target_return = self.dm_scaler.calc_scale(
100 | reward
101 | + self.discount
102 | * self.dm_scaler.calc_inverse_scale(target_return)
103 | * terminal_mask
104 | )
105 | else:
106 | target_return = reward + (
107 | self.discount * target_return * terminal_mask
108 | )
109 | nstep_target_returns.append(target_return)
110 |
111 | # reverse lists
112 | nstep_target_returns = torch.stack(
113 | list(reversed(nstep_target_returns))
114 | ).data
115 |
116 | return nstep_target_returns
117 |
--------------------------------------------------------------------------------
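A small worked example of the compute_returns recursion above, assuming an identity reward normalizer; the numbers are arbitrary. Each step's target is reward + discount * (next target), and a terminal zeroes the carried value so the bootstrap never leaks across episode boundaries:

import torch

from adept.learner.ac_rollout import ACRolloutLearner

learner = ACRolloutLearner(
    reward_normalizer=lambda r: r,  # identity normalizer, assumed for the example
    discount=0.99,
    normalize_advantage=False,
    entropy_weight=0.01,
    return_scale=False,
)
# 3-step rollout, 2 parallel envs; env 1 terminates at step index 1.
rewards = [torch.ones(2), torch.ones(2), torch.ones(2)]
terminals = [torch.tensor([0.0, 0.0]),
             torch.tensor([0.0, 1.0]),
             torch.tensor([0.0, 0.0])]
bootstrap = torch.tensor([10.0, 10.0])  # value estimate of the state after the rollout
print(learner.compute_returns(bootstrap, rewards, terminals))
# tensor([[12.6731,  1.9900],
#         [11.7910,  1.0000],
#         [10.9000, 10.9000]])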
/adept/learner/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .learner_module import LearnerModule
2 |
--------------------------------------------------------------------------------
/adept/learner/base/dm_return_scale.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class DeepMindReturnScaler:
5 | """
6 | Scale returns as in R2D2.
7 | https://openreview.net/pdf?id=r1lyTjAqYX
8 | """
9 |
10 | def __init__(self, scale):
11 | self.scale = scale
12 |
13 | def calc_scale(self, x):
14 | return (
15 | torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + self.scale * x
16 | )
17 |
18 | def calc_inverse_scale(self, x):
19 | sign = torch.sign(x)
20 | sqrt = torch.sqrt(1 + 4 * self.scale * (torch.abs(x) + 1 + self.scale))
21 | return sign * ((((sqrt - 1) / (2 * self.scale)) ** 2) - 1)
22 |
--------------------------------------------------------------------------------
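A quick numerical sanity check (a sketch, not part of the test suite) that calc_inverse_scale exactly inverts calc_scale:

import torch

from adept.learner.base.dm_return_scale import DeepMindReturnScaler

scaler = DeepMindReturnScaler(10.0 ** -3)
x = torch.tensor([-50.0, -1.0, 0.0, 1.0, 50.0], dtype=torch.float64)
y = scaler.calc_scale(x)
print(torch.allclose(scaler.calc_inverse_scale(y), x))  # True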
/adept/learner/base/learner_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | """
16 | A Learner receives agent-environment interactions which are used to compute
17 | loss.
18 | """
19 | import abc
20 | from adept.utils.requires_args import RequiresArgsMixin
21 |
22 |
23 | class LearnerModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
24 | """
25 | This is one of the modules to implement for custom Actor-Learner code.
26 | """
27 |
28 | @classmethod
29 | @abc.abstractmethod
30 | def from_args(cls, args, reward_normalizer):
31 | raise NotImplementedError
32 |
33 | @abc.abstractmethod
34 | def learn_step(self, updater, network, experiences, next_obs, internals):
35 | raise NotImplementedError
36 |
--------------------------------------------------------------------------------
/adept/manager/__init__.py:
--------------------------------------------------------------------------------
1 | from adept.manager.simple_env_manager import SimpleEnvManager
2 | from adept.manager.subproc_env_manager import SubProcEnvManager
3 | from adept.manager.base.manager_module import EnvManagerModule
4 |
5 |
6 | MANAGER_REG = [SimpleEnvManager, SubProcEnvManager]
7 |
--------------------------------------------------------------------------------
/adept/manager/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/adept/manager/base/__init__.py
--------------------------------------------------------------------------------
/adept/manager/base/manager_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import abc
16 |
17 | from adept.env.base._env import EnvBase
18 | from adept.utils.requires_args import RequiresArgsMixin
19 |
20 |
21 | class EnvManagerModule(EnvBase, RequiresArgsMixin, metaclass=abc.ABCMeta):
22 | def __init__(self, env_fns, engine):
23 | self._env_fns = env_fns
24 | self._engine = engine
25 |
26 | @property
27 | def env_fns(self):
28 | return self._env_fns
29 |
30 | @property
31 | def engine(self):
32 | return self._engine
33 |
34 | @property
35 | def nb_env(self):
36 | return len(self._env_fns)
37 |
38 | @classmethod
39 | def from_args(cls, args, engine, env_cls, seed=None, nb_env=None, **kwargs):
40 | if seed is None:
41 | seed = int(args.seed)
42 | if nb_env is None:
43 | nb_env = args.nb_env
44 |
45 | env_fns = []
46 | for i in range(nb_env):
47 | env_fns.append(env_cls.from_args_curry(args, seed + i, **kwargs))
48 | return cls(env_fns, engine)
49 |
--------------------------------------------------------------------------------
/adept/manager/simple_env_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import numpy as np
16 | import torch
17 |
18 | from adept.utils import listd_to_dlist
19 | from adept.utils.util import dlist_to_listd
20 | from .base.manager_module import EnvManagerModule
21 |
22 |
23 | class SimpleEnvManager(EnvManagerModule):
24 | """
25 | Manages multiple envs in the same process. This is slower than a
26 | SubProcEnvManager but makes debugging easier.
27 | """
28 |
29 | args = {}
30 |
31 | def __init__(self, env_fns, engine):
32 | super(SimpleEnvManager, self).__init__(env_fns, engine)
33 | self.envs = [fn() for fn in env_fns]
34 | env = self.envs[0]
35 | self._observation_space = env.observation_space
36 | self._action_space = env.action_space
37 | self._cpu_preprocessor = env.cpu_preprocessor
38 | self._gpu_preprocessor = env.gpu_preprocessor
39 |
40 | self.buf_obs = [None for _ in range(self.nb_env)]
41 | self.buf_dones = [None for _ in range(self.nb_env)]
42 | self.buf_rews = [None for _ in range(self.nb_env)]
43 | self.buf_infos = [None for _ in range(self.nb_env)]
44 | self.actions = None
45 |
46 | @property
47 | def cpu_preprocessor(self):
48 | return self._cpu_preprocessor
49 |
50 | @property
51 | def gpu_preprocessor(self):
52 | return self._gpu_preprocessor
53 |
54 | @property
55 | def observation_space(self):
56 | return self._observation_space
57 |
58 | @property
59 | def action_space(self):
60 | return self._action_space
61 |
62 | def step(self, actions):
63 | self.step_async(actions)
64 | return self.step_wait()
65 |
66 | def step_async(self, actions):
67 | self.actions = dlist_to_listd(actions)
68 |
69 | def step_wait(self):
70 | obs = []
71 | for e in range(self.nb_env):
72 | (
73 | ob,
74 | self.buf_rews[e],
75 | self.buf_dones[e],
76 | self.buf_infos[e],
77 | ) = self.envs[e].step(self.actions[e])
78 | if self.buf_dones[e]:
79 | ob = self.envs[e].reset()
80 | obs.append(ob)
81 | obs = listd_to_dlist(obs)
82 | new_obs = {}
83 | for k, v in dummy_handle_ob(obs).items():
84 | if self._is_tensor_key(k):
85 | new_obs[k] = torch.stack(v)
86 | else:
87 | new_obs[k] = v
88 | self.buf_obs = new_obs
89 |
90 | return (
91 | self.buf_obs,
92 | torch.tensor(self.buf_rews),
93 | torch.tensor(self.buf_dones),
94 | self.buf_infos,
95 | )
96 |
97 | def reset(self):
98 | obs = []
99 | for e in range(self.nb_env):
100 | ob = self.envs[e].reset()
101 | obs.append(ob)
102 | obs = listd_to_dlist(obs)
103 | new_obs = {}
104 | for k, v in dummy_handle_ob(obs).items():
105 | if self._is_tensor_key(k):
106 | new_obs[k] = torch.stack(v)
107 | else:
108 | new_obs[k] = v
109 | self.buf_obs = new_obs
110 | return self.buf_obs
111 |
112 | def close(self):
113 | return [e.close() for e in self.envs]
114 |
115 | def render(self, mode="human"):
116 | return [e.render(mode=mode) for e in self.envs]
117 |
118 | def _is_tensor_key(self, key):
119 | return None not in self.cpu_preprocessor.observation_space[key]
120 |
121 |
122 | def dummy_handle_ob(ob):
123 | new_ob = {}
124 | for k, v in ob.items():
125 | if isinstance(v, np.ndarray):
126 | new_ob[k] = torch.from_numpy(v)
127 | else:
128 | new_ob[k] = v
129 | return new_ob
130 |
--------------------------------------------------------------------------------
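step_wait and reset batch the per-env observation dicts by stacking each key; below is a hedged sketch of that flow with a single illustrative "Box" key (the real listd_to_dlist helper lives in adept.utils and is not shown here):

import numpy as np
import torch

from adept.manager.simple_env_manager import dummy_handle_ob

# Two per-env observation dicts -> one batched dict, mirroring step_wait.
per_env = [{"Box": np.zeros((3, 84, 84), dtype=np.float32)} for _ in range(2)]
as_tensors = [dummy_handle_ob(ob) for ob in per_env]      # ndarray -> torch.Tensor
batched = {k: torch.stack([ob[k] for ob in as_tensors])   # stack along a new env dim
           for k in as_tensors[0]}
print(batched["Box"].shape)  # torch.Size([2, 3, 84, 84])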
/adept/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from .norm import Identity
16 | from .attention import MultiHeadSelfAttention, RMCCell
17 | from .sequence import LSTMCellLayerNorm
18 | from .spatial import Residual2DPreact
19 |
--------------------------------------------------------------------------------
/adept/modules/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import math
16 |
17 | import torch
18 | from torch import nn
19 | from torch.nn import Parameter, functional as F
20 |
21 |
22 | class GaussianLinear(nn.Module):
23 | def __init__(self, fan_in, nodes):
24 | super().__init__()
25 | self.mu = nn.Linear(fan_in, nodes)
26 | # init_tflearn_fc_(self.mu)
27 | self.std = nn.Linear(fan_in, nodes)
28 | # init_tflearn_fc_(self.std)
29 |
30 | def forward(self, x):
31 | mu = self.mu(x)
32 | if self.training:
33 | std = self.std(x)
34 | std = torch.exp(0.5 * std)
35 | eps = torch.randn_like(std)
36 | return eps.mul(std).add_(mu)
37 | else:
38 | return mu
39 |
40 | def get_parameter_names(self):
41 | return ["Mu_W", "Mu_b", "Std_W", "Std_b"]
42 |
43 |
44 | class NoisyLinear(nn.Linear):
45 | """
46 | Reference implementation:
47 | https://github.com/Kaixhin/NoisyNet-A3C/blob/master/model.py
48 | """
49 |
50 | def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
51 | super(NoisyLinear, self).__init__(in_features, out_features, bias=True)
52 | # µ^w and µ^b reuse self.weight and self.bias
53 | self.sigma_init = sigma_init
54 | self.sigma_weight = Parameter(
55 | torch.Tensor(out_features, in_features)
56 | ) # σ^w
57 | self.sigma_bias = Parameter(torch.Tensor(out_features)) # σ^b
58 | self.init_params()
59 |
60 | def init_params(self):
61 | limit = math.sqrt(3 / self.in_features)
62 |
63 | self.weight.data.uniform_(-limit, limit)
64 | self.bias.data.uniform_(-limit, limit)
65 | self.sigma_weight.data.fill_(self.sigma_init)
66 | self.sigma_bias.data.fill_(self.sigma_init)
67 |
68 | def forward(self, x, internals):
69 | if self.training:
70 | w = self.weight + self.sigma_weight * internals[0]
71 | b = self.bias + self.sigma_bias * internals[1]
72 | else:
73 | w = self.weight + self.sigma_weight
74 | b = self.bias + self.sigma_bias
75 | return F.linear(x, w, b)
76 |
77 | def batch_forward(self, x, internals, batch_size=None):
78 | print(
79 | "WARNING: calling forward multiple times is actually"
80 | "faster than this and takes less memory"
81 | )
82 | batch_size = batch_size if batch_size is not None else x.shape[0]
83 | x = x.unsqueeze(1)
84 | # internals come in as [[w, b], ...] reshape to [w, ...], [b, ...]
85 | eps_w, eps_b = zip(*internals)
86 | eps_w = torch.stack(eps_w)
87 | eps_b = torch.stack(eps_b)
88 | batch_w = self.weight.unsqueeze(0).expand(
89 | batch_size, -1, -1
90 | ) + self.sigma_weight.unsqueeze(0).expand(batch_size, -1, -1)
91 | batch_w += eps_w
92 | # permute to b x m x p
93 | batch_w = batch_w.permute(0, 2, 1)
94 | batch_b = self.bias.expand(batch_size, -1) + self.sigma_bias.expand(
95 | batch_size, -1
96 | )
97 | batch_b += eps_b
98 |
99 | bmm = torch.bmm(x, batch_w).squeeze(1)
100 |
101 | return bmm + batch_b
102 |
103 | def reset(self, gpu=False, device=None):
104 | # sample new noise
105 | if not gpu:
106 | return (
107 | torch.randn(self.out_features, self.in_features).detach(),
108 | torch.randn(self.out_features).detach(),
109 | )
110 | else:
111 | return (
112 | torch.randn(self.out_features, self.in_features)
113 | .cuda(device, non_blocking=True)
114 | .detach(),
115 | torch.randn(self.out_features)
116 | .cuda(device, non_blocking=True)
117 | .detach(),
118 | )
119 |
120 | def get_parameter_names(self):
121 | return ["W", "b", "sigma_W", "sigma_b"]
122 |
--------------------------------------------------------------------------------
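A short usage sketch for NoisyLinear; reset() samples the noise tensors that forward expects as internals (sizes and batch shape here are arbitrary):

import torch

from adept.modules.mlp import NoisyLinear

layer = NoisyLinear(in_features=64, out_features=32)
noise = layer.reset()            # (eps_w, eps_b) noise, sampled outside forward
x = torch.randn(8, 64)

layer.train()
y_train = layer(x, noise)        # weights perturbed by sigma * noise
layer.eval()
y_eval = layer(x, noise)         # deterministic sigma offset, sampled noise unused
print(y_train.shape, y_eval.shape)  # torch.Size([8, 32]) torch.Size([8, 32])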
/adept/modules/norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import torch
16 |
17 |
18 | class Identity(torch.nn.Module):
19 | def forward(self, x):
20 | return x
21 |
--------------------------------------------------------------------------------
/adept/modules/sequence.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import torch
16 | from torch.nn import Module, Linear, LayerNorm
17 |
18 |
19 | class LSTMCellLayerNorm(Module):
20 | """
21 | An LSTM cell that applies layer normalization to the cell state.
22 | Reference: https://github.com/seba-1511/lstms.pth/blob/master/lstms/lstm.py
23 | Original license: Apache 2.0
24 | """
25 |
26 | def __init__(self, input_size, hidden_size, bias=True, forget_bias=0):
27 | super().__init__()
28 | self.input_size = input_size
29 | self.hidden_size = hidden_size
30 | self.ih = Linear(input_size, 4 * hidden_size, bias=bias)
31 | self.hh = Linear(hidden_size, 4 * hidden_size, bias=bias)
32 |
33 | if bias:
34 | self.ih.bias.data.fill_(0)
35 | self.hh.bias.data.fill_(0)
36 | # forget bias init
37 | self.ih.bias.data[hidden_size : hidden_size * 2].fill_(forget_bias)
38 | self.hh.bias.data[hidden_size : hidden_size * 2].fill_(forget_bias)
39 |
40 | self.ln_cell = LayerNorm(hidden_size)
41 |
42 | def forward(self, x, hidden):
43 | """
44 | LSTM Cell that layer normalizes the cell state.
45 | :param x: Tensor{B, C}
46 | :param hidden: A Tuple[Tensor{B, C}, Tensor{B, C}] of (previous output, cell state)
47 | :return:
48 | """
49 | h, c = hidden
50 |
51 | # Linear mappings
52 | i2h = self.ih(x)
53 | h2h = self.hh(h)
54 | preact = i2h + h2h
55 |
56 | # activations
57 | gates = preact[:, : 3 * self.hidden_size].sigmoid()
58 | g_t = preact[:, 3 * self.hidden_size :].tanh()
59 | i_t = gates[:, : self.hidden_size]
60 | f_t = gates[:, self.hidden_size : 2 * self.hidden_size]
61 | o_t = gates[:, -self.hidden_size :]
62 |
63 | # cell computations
64 | c_t = torch.mul(c, f_t) + torch.mul(i_t, g_t)
65 | c_t = self.ln_cell(c_t)
66 | h_t = torch.mul(o_t, c_t.tanh())
67 |
68 | return h_t, c_t
69 |
--------------------------------------------------------------------------------
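Usage mirrors torch.nn.LSTMCell; a minimal sketch with arbitrary sizes:

import torch

from adept.modules.sequence import LSTMCellLayerNorm

cell = LSTMCellLayerNorm(input_size=32, hidden_size=64)
x = torch.randn(8, 32)                             # batch of 8
hidden = (torch.zeros(8, 64), torch.zeros(8, 64))  # (previous output, cell state)
h_t, c_t = cell(x, hidden)
print(h_t.shape, c_t.shape)  # torch.Size([8, 64]) torch.Size([8, 64])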
/adept/modules/spatial.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from torch import nn as nn
16 | from torch.nn import functional as F
17 |
18 |
19 | class Residual2DPreact(nn.Module):
20 | def __init__(self, nb_in_chan, nb_out_chan, stride=1):
21 | super(Residual2DPreact, self).__init__()
22 |
23 | self.nb_in_chan = nb_in_chan
24 | self.nb_out_chan = nb_out_chan
25 | self.stride = stride
26 |
27 | self.bn1 = nn.BatchNorm2d(nb_in_chan)
28 | self.conv1 = nn.Conv2d(
29 | nb_in_chan, nb_out_chan, 3, stride=stride, padding=1, bias=False
30 | )
31 | self.bn2 = nn.BatchNorm2d(nb_out_chan)
32 | self.conv2 = nn.Conv2d(
33 | nb_out_chan, nb_out_chan, 3, stride=1, padding=1, bias=False
34 | )
35 |
36 | relu_gain = nn.init.calculate_gain("relu")
37 | self.conv1.weight.data.mul_(relu_gain)
38 | self.conv2.weight.data.mul_(relu_gain)
39 |
40 | self.do_projection = (
41 | self.nb_in_chan != self.nb_out_chan or self.stride > 1
42 | )
43 | if self.do_projection:
44 | self.projection = nn.Conv2d(
45 | nb_in_chan, nb_out_chan, 3, stride=stride, padding=1
46 | )
47 | self.projection.weight.data.mul_(relu_gain)
48 |
49 | def forward(self, x):
50 | first = F.relu(self.bn1(x))
51 | if self.do_projection:
52 | projection = self.projection(first)
53 | else:
54 | projection = x
55 | x = self.conv1(first)
56 | x = self.conv2(F.relu(self.bn2(x)))
57 | return x + projection
58 |
--------------------------------------------------------------------------------
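A shape check for Residual2DPreact; when channels change or stride > 1 the projection branch kicks in, so input and output shapes can differ (sizes below are arbitrary):

import torch

from adept.modules.spatial import Residual2DPreact

block = Residual2DPreact(nb_in_chan=32, nb_out_chan=64, stride=2)
x = torch.randn(4, 32, 42, 42)
print(block(x).shape)  # torch.Size([4, 64, 21, 21])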
/adept/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .base.network_module import NetworkModule
2 | from .modular_network import ModularNetwork
3 | from .net1d.submodule_1d import SubModule1D
4 | from .net2d.submodule_2d import SubModule2D
5 | from .net3d.submodule_3d import SubModule3D
6 | from .net4d.submodule_4d import SubModule4D
7 |
8 | from .net1d.linear import Linear
9 | from .net1d.identity_1d import Identity1D
10 | from .net1d.lstm import LSTM
11 |
12 | from .net2d.identity_2d import Identity2D
13 |
14 | from .net3d.identity_3d import Identity3D
15 | from .net3d.four_conv import FourConv
16 |
17 | from .net4d.identity_4d import Identity4D
18 |
19 | NET_REG = []
20 | SUBMOD_REG = [
21 | Identity1D,
22 | Linear,
23 | LSTM,
24 | Identity2D,
25 | Identity3D,
26 | FourConv,
27 | Identity4D,
28 | ]
29 |
--------------------------------------------------------------------------------
/adept/network/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/adept/network/base/__init__.py
--------------------------------------------------------------------------------
/adept/network/base/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import abc
16 |
17 | import torch
18 | from torch import distributed as dist
19 |
20 |
21 | class BaseNetwork(torch.nn.Module):
22 | @classmethod
23 | @abc.abstractmethod
24 | def from_args(
25 | cls, args, observation_space, output_space, gpu_preprocessor, net_reg
26 | ):
27 | raise NotImplementedError
28 |
29 | @abc.abstractmethod
30 | def new_internals(self, device):
31 | """
32 | :return: Dict[InternalKey, torch.Tensor (ND)]
33 | """
34 | raise NotImplementedError
35 |
36 | @abc.abstractmethod
37 | def forward(self, observation, internals):
38 | raise NotImplementedError
39 |
40 | def internal_space(self):
41 | return {k: t.shape for k, t in self.new_internals("cpu").items()}
42 |
43 | def sync(self, src, grp=None, async_op=False):
44 |
45 | keys = []
46 | handles = []
47 |
48 | for k, t in self.state_dict().items():
49 | if grp is None:
50 | h = dist.broadcast(t, src, async_op=True)
51 | else:
52 | h = dist.broadcast(t, src, grp, async_op=True)
53 |
54 | keys.append(k)
55 | handles.append(h)
56 |
57 | if not async_op:
58 | for k, h in zip(keys, handles):
59 | h.wait()
60 |
61 | return handles
62 |
--------------------------------------------------------------------------------
/adept/network/base/network_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import abc
16 |
17 | from adept.network.base.base import BaseNetwork
18 | from adept.utils.requires_args import RequiresArgsMixin
19 |
20 |
21 | class NetworkModule(BaseNetwork, RequiresArgsMixin, metaclass=abc.ABCMeta):
22 | pass
23 |
--------------------------------------------------------------------------------
/adept/network/net1d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/adept/network/net1d/__init__.py
--------------------------------------------------------------------------------
/adept/network/net1d/identity_1d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from .submodule_1d import SubModule1D
16 |
17 |
18 | class Identity1D(SubModule1D):
19 | args = {}
20 |
21 | def __init__(self, input_shape, id):
22 | super().__init__(input_shape, id)
23 |
24 | @classmethod
25 | def from_args(cls, args, input_shape, id):
26 | return cls(input_shape, id)
27 |
28 | @property
29 | def _output_shape(self):
30 | return self.input_shape
31 |
32 | def _forward(self, input, internals, **kwargs):
33 | return input, {}
34 |
35 | def _new_internals(self):
36 | return {}
37 |
--------------------------------------------------------------------------------
/adept/network/net1d/linear.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from __future__ import division
16 |
17 | from torch import nn
18 | from torch.nn import functional as F
19 |
20 | from adept.modules import Identity
21 |
22 | from .submodule_1d import SubModule1D
23 |
24 |
25 | class Linear(SubModule1D):
26 | args = {"linear_normalize": "bn", "linear_nb_hidden": 512, "nb_layer": 3}
27 |
28 | def __init__(self, input_shape, id, normalize, nb_hidden, nb_layer):
29 | super().__init__(input_shape, id)
30 | self._nb_hidden = nb_hidden
31 |
32 | nb_input_channel = input_shape[0]
33 |
34 | bias = not normalize
35 | self.linears = nn.ModuleList(
36 | [
37 | nn.Linear(
38 | nb_input_channel if i == 0 else nb_hidden, nb_hidden, bias
39 | )
40 | for i in range(nb_layer)
41 | ]
42 | )
43 | if normalize == "bn":
44 | self.norms = nn.ModuleList(
45 | [nn.BatchNorm1d(nb_hidden) for _ in range(nb_layer)]
46 | )
47 | elif normalize == "gn":
48 | if nb_hidden % 16 != 0:
49 | raise Exception(
50 | "linear_nb_hidden must be divisible by 16 for Group Norm"
51 | )
52 | self.norms = nn.ModuleList(
53 | [
54 | nn.GroupNorm(nb_hidden // 16, nb_hidden)
55 | for _ in range(nb_layer)
56 | ]
57 | )
58 | else:
59 | self.norms = nn.ModuleList([Identity() for _ in range(nb_layer)])
60 |
61 | @classmethod
62 | def from_args(cls, args, input_shape, id):
63 | return cls(
64 | input_shape,
65 | id,
66 | args.linear_normalize,
67 | args.linear_nb_hidden,
68 | args.nb_layer,
69 | )
70 |
71 | def _forward(self, xs, internals, **kwargs):
72 | for linear, norm in zip(self.linears, self.norms):
73 | xs = F.relu(norm(linear(xs)))
74 | return xs, {}
75 |
76 | def _new_internals(self):
77 | return {}
78 |
79 | @property
80 | def _output_shape(self):
81 | return (self._nb_hidden,)
82 |
--------------------------------------------------------------------------------
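A hedged shape sketch for the Linear submodule, calling the internal _forward hook directly for brevity (in normal use the SubModule base class wraps it); the id string and sizes are illustrative:

import torch

from adept.network.net1d.linear import Linear

mod = Linear((64,), "body", normalize="bn", nb_hidden=512, nb_layer=3)
xs, new_internals = mod._forward(torch.randn(8, 64), internals={})
print(xs.shape, new_internals)  # torch.Size([8, 512]) {}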
/adept/network/net1d/lstm.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import torch
16 | from torch.nn import LSTMCell
17 |
18 | from adept.modules import LSTMCellLayerNorm
19 | from .submodule_1d import SubModule1D
20 |
21 |
22 | class LSTM(SubModule1D):
23 | args = {"lstm_normalize": True, "lstm_nb_hidden": 512}
24 |
25 | def __init__(self, input_shape, id, normalize, nb_hidden):
26 | super().__init__(input_shape, id)
27 | self._nb_hidden = nb_hidden
28 |
29 | if normalize:
30 | self.lstm = LSTMCellLayerNorm(input_shape[0], nb_hidden)
31 | else:
32 | self.lstm = LSTMCell(input_shape[0], nb_hidden)
33 | self.lstm.bias_ih.data.fill_(0)
34 | self.lstm.bias_hh.data.fill_(0)
35 |
36 | @classmethod
37 | def from_args(cls, args, input_shape, id):
38 | return cls(input_shape, id, args.lstm_normalize, args.lstm_nb_hidden)
39 |
40 | @property
41 | def _output_shape(self):
42 | return (self._nb_hidden,)
43 |
44 | def _forward(self, xs, internals, **kwargs):
45 | hxs = self.stacked_internals("hx", internals)
46 | cxs = self.stacked_internals("cx", internals)
47 | hxs, cxs = self.lstm(xs, (hxs, cxs))
48 |
49 | return (
50 | hxs,
51 | {
52 | "hx": list(torch.unbind(hxs, dim=0)),
53 | "cx": list(torch.unbind(cxs, dim=0)),
54 | },
55 | )
56 |
57 | def _new_internals(self):
58 | return {
59 | "hx": torch.zeros(self._nb_hidden),
60 | "cx": torch.zeros(self._nb_hidden),
61 | }
62 |
--------------------------------------------------------------------------------
/adept/network/net1d/submodule_1d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from adept.network.base.submodule import SubModule
16 | import abc
17 |
18 |
19 | class SubModule1D(SubModule, metaclass=abc.ABCMeta):
20 | dim = 1
21 |
22 | def __init__(self, input_shape, id):
23 | super(SubModule1D, self).__init__(input_shape, id)
24 |
25 | def _to_1d_shape(self):
26 | return self._output_shape
27 |
28 | def _to_2d_shape(self):
29 | return self._output_shape + (1,)
30 |
31 | def _to_3d_shape(self):
32 | return self._output_shape + (1, 1)
33 |
34 | def _to_4d_shape(self):
35 | return self._output_shape + (1, 1, 1)
36 |
--------------------------------------------------------------------------------
/adept/network/net2d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/adept/network/net2d/__init__.py
--------------------------------------------------------------------------------
/adept/network/net2d/identity_2d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from .submodule_2d import SubModule2D
16 |
17 |
18 | class Identity2D(SubModule2D):
19 | args = {}
20 |
21 | def __init__(self, input_shape, id):
22 | super().__init__(input_shape, id)
23 |
24 | @classmethod
25 | def from_args(cls, args, input_shape, id):
26 | return cls(input_shape, id)
27 |
28 | @property
29 | def _output_shape(self):
30 | return self.input_shape
31 |
32 | def _forward(self, input, internals, **kwargs):
33 | return input, {}
34 |
35 | def _new_internals(self):
36 | return {}
37 |
--------------------------------------------------------------------------------
/adept/network/net2d/submodule_2d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from ..base.submodule import SubModule
16 | import abc
17 | import math
18 |
19 |
20 | class SubModule2D(SubModule, metaclass=abc.ABCMeta):
21 | dim = 2
22 |
23 | def __init__(self, input_shape, id):
24 | super(SubModule2D, self).__init__(input_shape, id)
25 |
26 | def _to_1d_shape(self):
27 | f, s = self._output_shape
28 | return (f * s,)
29 |
30 | def _to_2d_shape(self):
31 | return self._output_shape
32 |
33 | def _to_3d_shape(self):
34 | f, s = self._output_shape
35 | return (f * s, 1, 1)
36 |
37 | def _to_4d_shape(self):
38 | f, s = self._output_shape
39 | return (f, s, 1, 1)
40 |
--------------------------------------------------------------------------------
/adept/network/net3d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/adept/network/net3d/__init__.py
--------------------------------------------------------------------------------
/adept/network/net3d/four_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from torch.nn import Conv2d, BatchNorm2d, GroupNorm, init, functional as F
16 |
17 | from adept.modules import Identity
18 | from .submodule_3d import SubModule3D
19 |
20 |
21 | class FourConv(SubModule3D):
22 | args = {"fourconv_norm": "bn"}
23 |
24 | def __init__(self, in_shape, id, normalize):
25 | super().__init__(in_shape, id)
26 | bias = not normalize
27 | self._in_shape = in_shape
28 | self._out_shape = None
29 | self.conv1 = Conv2d(in_shape[0], 32, 7, stride=2, padding=1, bias=bias)
30 | self.conv2 = Conv2d(32, 32, 3, stride=2, padding=1, bias=bias)
31 | self.conv3 = Conv2d(32, 32, 3, stride=2, padding=1, bias=bias)
32 | self.conv4 = Conv2d(32, 32, 3, stride=2, padding=1, bias=bias)
33 |
34 | if normalize == "bn":
35 | self.bn1 = BatchNorm2d(32)
36 | self.bn2 = BatchNorm2d(32)
37 | self.bn3 = BatchNorm2d(32)
38 | self.bn4 = BatchNorm2d(32)
39 | elif normalize == "gn":
40 | self.bn1 = GroupNorm(8, 32)
41 | self.bn2 = GroupNorm(8, 32)
42 | self.bn3 = GroupNorm(8, 32)
43 | self.bn4 = GroupNorm(8, 32)
44 | else:
45 | self.bn1 = Identity()
46 | self.bn2 = Identity()
47 | self.bn3 = Identity()
48 | self.bn4 = Identity()
49 |
50 | relu_gain = init.calculate_gain("relu")
51 | self.conv1.weight.data.mul_(relu_gain)
52 | self.conv2.weight.data.mul_(relu_gain)
53 | self.conv3.weight.data.mul_(relu_gain)
54 | self.conv4.weight.data.mul_(relu_gain)
55 |
56 | @classmethod
57 | def from_args(cls, args, in_shape, id):
58 | return cls(in_shape, id, args.fourconv_norm)
59 |
60 | @property
61 | def _output_shape(self):
62 | # For 84x84, (32, 5, 5)
63 | if self._out_shape is None:
64 | output_dim = calc_output_dim(self._in_shape[1], 7, 2, 1, 1)
65 | output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
66 | output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
67 | output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
68 | self._out_shape = 32, output_dim, output_dim
69 | return self._out_shape
70 |
71 | def _forward(self, xs, internals, **kwargs):
72 | xs = F.relu(self.bn1(self.conv1(xs)))
73 | xs = F.relu(self.bn2(self.conv2(xs)))
74 | xs = F.relu(self.bn3(self.conv3(xs)))
75 | xs = F.relu(self.bn4(self.conv4(xs)))
76 | return xs, {}
77 |
78 | def _new_internals(self):
79 | return {}
80 |
81 |
82 | def calc_output_dim(dim_size, kernel_size, stride, padding, dilation):
83 | numerator = dim_size + 2 * padding - dilation * (kernel_size - 1) - 1
84 | return numerator // stride + 1
85 |
86 |
87 | if __name__ == "__main__":
88 | output_dim = 84
89 | output_dim = calc_output_dim(output_dim, 7, 2, 1, 1)
90 | output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
91 | output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
92 | output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
93 | print(output_dim) # should be 5
94 |
--------------------------------------------------------------------------------
/adept/network/net3d/identity_3d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | from .submodule_3d import SubModule3D
16 |
17 |
18 | class Identity3D(SubModule3D):
19 | args = {}
20 |
21 | def __init__(self, input_shape, id):
22 | super().__init__(input_shape, id)
23 |
24 | @classmethod
25 | def from_args(cls, args, input_shape, id):
26 | return cls(input_shape, id)
27 |
28 | @property
29 | def _output_shape(self):
30 | return self.input_shape
31 |
32 | def _forward(self, input, internals, **kwargs):
33 | return input, {}
34 |
35 | def _new_internals(self):
36 | return {}
37 |
--------------------------------------------------------------------------------
/adept/network/net3d/rmc.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | import torch
16 | from torch.nn import Conv2d, Linear, BatchNorm2d, BatchNorm1d, functional as F
17 |
18 | from adept.modules import RMCCell, Identity
19 | from ..base.network_module import NetworkModule
20 |
21 |
22 | # TODO
23 | class RMC(NetworkModule):
24 | """
25 | Relational Memory Core
26 | https://arxiv.org/pdf/1806.01822.pdf
27 | """
28 |
29 | def __init__(self, nb_in_chan, output_shape_dict, normalize):
30 | self.embedding_size = 512
31 | super(RMC, self).__init__(self.embedding_size, output_shape_dict)
32 | bias = not normalize
33 | self.conv1 = Conv2d(
34 | nb_in_chan, 32, kernel_size=3, stride=2, padding=1, bias=bias
35 | )
36 | self.conv2 = Conv2d(
37 | 32, 32, kernel_size=3, stride=2, padding=1, bias=bias
38 | )
39 | self.conv3 = Conv2d(
40 | 32, 32, kernel_size=3, stride=2, padding=1, bias=bias
41 | )
42 | self.attention = RMCCell(100, 100, 34)
43 | self.conv4 = Conv2d(
44 | 34, 8, kernel_size=3, stride=1, padding=1, bias=bias
45 | )
46 | # BATCH x 8 x 10 x 10
47 | self.linear = Linear(800, 512, bias=bias)
48 |
49 | if normalize:
50 | self.bn1 = BatchNorm2d(32)
51 | self.bn2 = BatchNorm2d(32)
52 | self.bn3 = BatchNorm2d(32)
53 | self.bn4 = BatchNorm2d(8)
54 | self.bn_linear = BatchNorm1d(512)
55 |         else:
56 |             self.bn1 = Identity()
57 |             self.bn2 = Identity()
58 |             self.bn3 = Identity()
59 |             self.bn4 = self.bn_linear = Identity()  # Identity is stateless, so bn4 and bn_linear can share one instance
60 |
61 | def forward(self, input, prev_memories):
62 | """
63 | :param input: Tensor{B, C, H, W}
64 | :param prev_memories: Tuple{B}[Tensor{C}]
65 | :return:
66 | """
67 |
68 | x = F.relu(self.bn1(self.conv1(input)))
69 | x = F.relu(self.bn2(self.conv2(x)))
70 | x = F.relu(self.bn3(self.conv3(x)))
71 |
72 | h = x.size(2)
73 | w = x.size(3)
74 | xs_chan = (
75 | torch.linspace(-1, 1, w)
76 | .view(1, 1, 1, w)
77 | .expand(input.size(0), 1, w, w)
78 | .to(input.device)
79 | )
80 | ys_chan = (
81 | torch.linspace(-1, 1, h)
82 | .view(1, 1, h, 1)
83 | .expand(input.size(0), 1, h, h)
84 | .to(input.device)
85 | )
86 | x = torch.cat([x, xs_chan, ys_chan], dim=1)
87 |
88 | # need to transpose because attention expects
89 | # attention dim before channel dim
90 | x = x.view(x.size(0), x.size(1), h * w).transpose(1, 2)
91 | prev_memories = torch.stack(prev_memories)
92 | x = next_memories = self.attention(x.contiguous(), prev_memories)
93 | # need to undo the transpose before output
94 | x = x.transpose(1, 2)
95 | x = x.view(x.size(0), x.size(1), h, w)
96 |
97 | x = F.relu(self.bn4(self.conv4(x)))
98 | x = x.view(x.size(0), -1)
99 | x = F.relu(self.bn_linear(self.linear(x)))
100 | return x, list(torch.unbind(next_memories, 0))
101 |
102 | def new_internals(self, device):
103 | pass
104 |
--------------------------------------------------------------------------------
/adept/network/net3d/submodule_3d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | from ..base.submodule import SubModule
16 | import abc
17 |
18 |
19 | class SubModule3D(SubModule, metaclass=abc.ABCMeta):
20 | dim = 3
21 |
22 | def __init__(self, input_shape, id):
23 | super(SubModule3D, self).__init__(input_shape, id)
24 |
25 | def _to_1d_shape(self):
26 | f, h, w = self._output_shape
27 | return (f * h * w,)
28 |
29 | def _to_2d_shape(self):
30 | f, h, w = self._output_shape
31 | return (f, h * w)
32 |
33 | def _to_3d_shape(self):
34 | return self._output_shape
35 |
36 | def _to_4d_shape(self):
37 | f, h, w = self._output_shape
38 | return (f, 1, h, w)
39 |
--------------------------------------------------------------------------------
/adept/network/net4d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/adept/network/net4d/__init__.py
--------------------------------------------------------------------------------
/adept/network/net4d/identity_4d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | from .submodule_4d import SubModule4D
16 |
17 |
18 | class Identity4D(SubModule4D):
19 | args = {}
20 |
21 | def __init__(self, input_shape, id):
22 | super().__init__(input_shape, id)
23 |
24 | @classmethod
25 | def from_args(cls, args, input_shape, id):
26 | return cls(input_shape, id)
27 |
28 | @property
29 | def _output_shape(self):
30 | return self.input_shape
31 |
32 | def _forward(self, input, internals, **kwargs):
33 | return input, {}
34 |
35 | def _new_internals(self):
36 | return {}
37 |
--------------------------------------------------------------------------------
/adept/network/net4d/submodule_4d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | from ..base.submodule import SubModule
16 | import abc
17 |
18 |
19 | class SubModule4D(SubModule, metaclass=abc.ABCMeta):
20 | dim = 4
21 |
22 | def __init__(self, input_shape, id):
23 | super(SubModule4D, self).__init__(input_shape, id)
24 |
25 | def _to_1d_shape(self):
26 | f, d, h, w = self._output_shape
27 | return (f * d * h * w,)
28 |
29 | def _to_2d_shape(self):
30 | f, d, h, w = self._output_shape
31 | return (f, d * h * w)
32 |
33 | def _to_3d_shape(self):
34 | f, d, h, w = self._output_shape
35 | return (f * d, h, w)
36 |
37 | def _to_4d_shape(self):
38 | return self._output_shape
39 |
--------------------------------------------------------------------------------
/adept/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 Heron Systems, Inc.
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 | from .base import Operation, SimpleOperation, MultiOperation
18 | from .base import CPUPreprocessor, GPUPreprocessor
19 |
--------------------------------------------------------------------------------
/adept/preprocess/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .ops import Operation, SimpleOperation, MultiOperation
2 | from .preprocessor import CPUPreprocessor, GPUPreprocessor
--------------------------------------------------------------------------------
/adept/preprocess/base/ops.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Operation(abc.ABC):
5 | @abc.abstractmethod
6 | def update_shape(self, old_shape):
7 | raise NotImplementedError
8 |
9 | @abc.abstractmethod
10 | def update_dtype(self, old_dtype):
11 | raise NotImplementedError
12 |
13 | def reset(self):
14 | pass
15 |
16 | def to(self, device):
17 | return self
18 |
19 |
20 | class MultiOperation(Operation, metaclass=abc.ABCMeta):
21 |     """Modifies multiple keys of an observation dictionary."""
22 |
23 | def __init__(self, input_fields, output_fields):
24 | self.input_fields = input_fields
25 | self.output_fields = output_fields
26 |
27 | @abc.abstractmethod
28 | def preprocess_cpu(self, tensors):
29 | """Preprocess multiple observation fields on the CPU.
30 |
31 | Parameters
32 | ----------
33 | tensors : list[torch.Tensor]
34 |
35 | Returns
36 | -------
37 | list[torch.Tensor]
38 | """
39 |         raise NotImplementedError
40 |
41 | @abc.abstractmethod
42 | def preprocess_gpu(self, tensors):
43 | """Preprocess multiple observation fields on the GPU.
44 |
45 | Parameters
46 | ----------
47 | tensors : list[torch.Tensor]
48 |
49 | Returns
50 | -------
51 | list[torch.Tensor]
52 | """
53 |         raise NotImplementedError
54 |
55 |
56 | class SimpleOperation(Operation, metaclass=abc.ABCMeta):
57 | """Modifies a single key in the observation dictionary."""
58 |
59 | def __init__(self, input_field, output_field):
60 | self.input_field = input_field
61 | self.output_field = output_field
62 |
63 | @abc.abstractmethod
64 | def preprocess_cpu(self, tensor):
65 | """Preprocess a specific field of an observation on the CPU.
66 |
67 | Parameters
68 | ----------
69 | tensor : torch.Tensor
70 |
71 | Returns
72 | -------
73 | torch.Tensor
74 | """
75 |         raise NotImplementedError
76 |
77 | @abc.abstractmethod
78 | def preprocess_gpu(self, tensor):
79 | """Preprocess a specific field of an observation on the GPU
80 |
81 | Parameters
82 | ----------
83 | tensor : torch.Tensor
84 |
85 | Returns
86 | -------
87 | torch.Tensor
88 | """
89 |         raise NotImplementedError
90 |
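For clarity, here is a minimal, hypothetical `SimpleOperation` written against the contract above, assuming adept is importable; the `CastToFloat` name and the `"screen"` field are illustrative, not part of the library.

```python
import torch

from adept.preprocess.base import SimpleOperation


class CastToFloat(SimpleOperation):
    """Cast a single observation field from uint8 to float32."""

    def update_shape(self, old_shape):
        # Casting does not change the tensor shape.
        return old_shape

    def update_dtype(self, old_dtype):
        # Downstream code should now expect float32.
        return torch.float32

    def preprocess_cpu(self, tensor):
        return tensor.float()

    def preprocess_gpu(self, tensor):
        return tensor.float()


if __name__ == "__main__":
    op = CastToFloat("screen", "screen")
    frame = torch.zeros(3, 84, 84, dtype=torch.uint8)
    print(op.preprocess_cpu(frame).dtype)  # torch.float32
```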
--------------------------------------------------------------------------------
/adept/preprocess/base/preprocessor.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | from copy import copy, deepcopy
16 | from adept.preprocess.base import MultiOperation, SimpleOperation
17 |
18 |
19 | class _Preprocessor:
20 | def __init__(self, ops, observation_space, observation_dtypes=None):
21 | """
22 | Parameters
23 | ----------
24 |         ops : list[adept.preprocess.base.Operation]
25 | observation_space : dict[str, Shape]
26 | observation_dtypes : dict[str, dtype]
27 | """
28 | cur_space = deepcopy(observation_space)
29 | cur_dtypes = deepcopy(observation_dtypes)
30 |
31 | self.ops = ops
32 | self.observation_space, self.observation_dtypes = self._update(
33 | cur_space, cur_dtypes
34 | )
35 |
36 | def _update(self, cur_space, cur_dtypes):
37 | cur_space = copy(cur_space)
38 | cur_dtypes = copy(cur_dtypes)
39 | for op in self.ops:
40 | if isinstance(op, SimpleOperation):
41 | output_shape = op.update_shape(
42 | cur_space[op.input_field]
43 | )
44 | if output_shape:
45 | cur_space[op.output_field] = output_shape
46 | else:
47 | del cur_space[op.output_field]
48 | if cur_dtypes:
49 | output_dtype = op.update_dtype(
50 | cur_dtypes[op.input_field]
51 | )
52 | if output_dtype:
53 | cur_dtypes[op.output_field] = output_dtype
54 | else:
55 | del cur_dtypes[op.output_field]
56 | elif isinstance(op, MultiOperation):
57 | input_shapes = [cur_space[k] for k in op.input_fields]
58 |                 input_dtypes = [cur_dtypes[k] for k in op.input_fields] if cur_dtypes else []
59 |                 output_shapes = op.update_shape(input_shapes)
60 |                 output_dtypes = op.update_dtype(input_dtypes) if cur_dtypes else [None] * len(op.output_fields)
61 | for k, shape, dtype in zip(
62 | op.output_fields, output_shapes, output_dtypes
63 | ):
64 | if shape:
65 | cur_space[k] = shape
66 | else:
67 | del cur_space[k]
68 | if cur_dtypes:
69 | if dtype:
70 | cur_dtypes[k] = dtype
71 | else:
72 | del cur_dtypes[k]
73 | return cur_space, cur_dtypes
74 |
75 |
76 | class CPUPreprocessor(_Preprocessor):
77 | def __call__(self, obs):
78 | obs = copy(obs)
79 | for op in self.ops:
80 | if isinstance(op, SimpleOperation):
81 | output_tensor = op.preprocess_cpu(obs[op.input_field])
82 | if output_tensor is not None:
83 | obs[op.output_field] = output_tensor
84 | else:
85 | del obs[op.output_field]
86 | elif isinstance(op, MultiOperation):
87 | input_tensors = [obs[k] for k in op.input_fields]
88 | output_tensors = op.preprocess_cpu(input_tensors)
89 | for k, tensor in zip(op.output_fields, output_tensors):
90 | if tensor is not None:
91 | obs[k] = tensor
92 | else:
93 | del obs[k]
94 | return obs
95 |
96 | def reset(self):
97 | for o in self.ops:
98 | o.reset()
99 |
100 |
101 | class GPUPreprocessor(_Preprocessor):
102 | def __call__(self, obs):
103 | obs = copy(obs)
104 | for op in self.ops:
105 | if isinstance(op, SimpleOperation):
106 | output_tensor = op.preprocess_gpu(obs[op.input_field])
107 | if output_tensor is not None:
108 | obs[op.output_field] = output_tensor
109 | else:
110 | del obs[op.output_field]
111 | elif isinstance(op, MultiOperation):
112 | input_tensors = [obs[k] for k in op.input_fields]
113 | output_tensors = op.preprocess_gpu(input_tensors)
114 | for k, tensor in zip(op.output_fields, output_tensors):
115 | if tensor is not None:
116 | obs[k] = tensor
117 | else:
118 | del obs[k]
119 | return obs
120 |
121 | def to(self, device):
122 | self.ops = [op.to(device) for op in self.ops]
123 | return self
124 |
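As a usage sketch (assuming adept is importable), the preprocessors above can be driven like this; `CatFields` and the observation field names are made up for illustration.

```python
import torch

from adept.preprocess.base import CPUPreprocessor, MultiOperation


class CatFields(MultiOperation):
    """Hypothetical op: concatenate two 1D fields into one output field."""

    def update_shape(self, old_shapes):
        (a,), (b,) = old_shapes
        # One output field whose length is the sum of the inputs.
        return [(a + b,)]

    def update_dtype(self, old_dtypes):
        return [old_dtypes[0]]

    def preprocess_cpu(self, tensors):
        return [torch.cat(tensors, dim=0)]

    def preprocess_gpu(self, tensors):
        return [torch.cat(tensors, dim=0)]


preproc = CPUPreprocessor(
    [CatFields(["pos", "vel"], ["state"])],
    observation_space={"pos": (3,), "vel": (3,)},
    observation_dtypes={"pos": torch.float32, "vel": torch.float32},
)
obs = {"pos": torch.zeros(3), "vel": torch.ones(3)}
print(preproc(obs)["state"].shape)         # torch.Size([6])
print(preproc.observation_space["state"])  # (6,)
```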
--------------------------------------------------------------------------------
/adept/registry/__init__.py:
--------------------------------------------------------------------------------
1 | from .registry import Registry
2 |
3 | REGISTRY = Registry()
4 |
--------------------------------------------------------------------------------
/adept/rewardnorm/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import RewardNormModule
2 | from .normalizers import Scale, Clip, Identity
3 |
4 | REWARD_NORM_REG = [Scale, Clip, Identity]
5 |
--------------------------------------------------------------------------------
/adept/rewardnorm/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .rewnorm_module import RewardNormModule
2 |
--------------------------------------------------------------------------------
/adept/rewardnorm/base/rewnorm_module.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | from adept.utils.requires_args import RequiresArgsMixin
4 |
5 |
6 | class RewardNormModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
7 | def __call__(self, reward):
8 | """
9 | Normalizes a reward tensor.
10 |
11 | :param reward: torch.Tensor (1D)
12 | :return:
13 | """
14 | raise NotImplementedError
15 |
--------------------------------------------------------------------------------
/adept/rewardnorm/normalizers.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | import torch
16 | from .base import RewardNormModule
17 |
18 |
19 | class Clip(RewardNormModule):
20 | args = {"floor": -1, "ceil": 1}
21 |
22 | def __init__(self, floor, ceil):
23 | self.floor = floor
24 | self.ceil = ceil
25 |
26 | @classmethod
27 | def from_args(cls, args):
28 | return cls(args.floor, args.ceil)
29 |
30 | def __call__(self, reward):
31 | return torch.clamp(reward, self.floor, self.ceil)
32 |
33 |
34 | class Scale(RewardNormModule):
35 | args = {"coefficient": 0.1}
36 |
37 | def __init__(self, coefficient):
38 | self.coefficient = coefficient
39 |
40 | @classmethod
41 | def from_args(cls, args):
42 | return cls(args.coefficient)
43 |
44 | def __call__(self, reward):
45 | return self.coefficient * reward
46 |
47 |
48 | class Identity(RewardNormModule):
49 | args = {}
50 |
51 | @classmethod
52 | def from_args(cls, args):
53 | return cls()
54 |
55 | def __call__(self, item):
56 | return item
57 |
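As a small sketch of how these normalizers might be composed by hand (the coefficient and bounds here are arbitrary, not recommended defaults):

```python
import torch

from adept.rewardnorm.normalizers import Clip, Scale

rewards = torch.tensor([-30.0, -0.5, 0.2, 70.0])

# One possible composition: scale rewards down, then clip to [-1, 1].
scaled = Scale(coefficient=0.1)(rewards)
clipped = Clip(floor=-1, ceil=1)(scaled)
print(clipped)  # tensor([-1.0000, -0.0500,  0.0200,  1.0000])
```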
--------------------------------------------------------------------------------
/adept/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 Heron Systems, Inc.
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see .
16 | """
17 |
--------------------------------------------------------------------------------
/adept/scripts/_distrib.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (C) 2018 Heron Systems, Inc.
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see .
16 | """
17 | Distributed worker script. Called from launcher (distrib.py).
18 | """
19 | import argparse
20 | import json
21 | import os
22 | import torch.distributed as dist
23 |
24 | from adept.container import Init, DistribHost, DistribWorker
25 | from adept.registry import REGISTRY as R
26 | from adept.utils.script_helpers import LogDirHelper
27 | from adept.utils.util import DotDict
28 |
29 | MODE = "Distrib"
30 | WORLD_SIZE = int(os.environ["WORLD_SIZE"])
31 | GLOBAL_RANK = int(os.environ["RANK"])
32 | LOCAL_RANK = int(os.environ["LOCAL_RANK"])
33 |
34 |
35 | def str2bool(v):
36 | if v.lower() in ("yes", "true", "t", "y", "1"):
37 | return True
38 | elif v.lower() in ("no", "false", "f", "n", "0"):
39 | return False
40 | else:
41 | raise argparse.ArgumentTypeError("Boolean value expected.")
42 |
43 |
44 | def parse_args():
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--log-id-dir", required=True)
47 | parser.add_argument(
48 | "--resume", type=str2bool, nargs="?", const=True, default=False
49 | )
50 | parser.add_argument("--load-network", default=None)
51 | parser.add_argument("--load-optim", default=None)
52 | parser.add_argument("--initial-step-count", type=int, default=0)
53 | parser.add_argument("--init-method", default=None)
54 | parser.add_argument("--custom-network", default=None)
55 | args = parser.parse_args()
56 | return args
57 |
58 |
59 | def main(local_args):
60 | """
61 | Run distributed training.
62 |
63 | :param local_args: Dict[str, Any]
64 | :return:
65 | """
66 | log_id_dir = local_args.log_id_dir
67 | initial_step_count = local_args.initial_step_count
68 |
69 | R.load_extern_classes(log_id_dir)
70 | logger = Init.setup_logger(
71 | log_id_dir, "train{}".format(GLOBAL_RANK)
72 | )
73 |
74 | helper = LogDirHelper(log_id_dir)
75 | with open(helper.args_file_path(), "r") as args_file:
76 | args = DotDict(json.load(args_file))
77 |
78 | if local_args.resume:
79 | args = DotDict({**args, **vars(local_args)})
80 |
81 | dist.init_process_group(
82 | backend="nccl",
83 | init_method=args.init_method,
84 | world_size=WORLD_SIZE,
85 | rank=LOCAL_RANK,
86 | )
87 | logger.info("Rank {} initialized.".format(GLOBAL_RANK))
88 |
89 | if LOCAL_RANK == 0:
90 | container = DistribHost(
91 | args,
92 | logger,
93 | log_id_dir,
94 | initial_step_count,
95 | LOCAL_RANK,
96 | GLOBAL_RANK,
97 | WORLD_SIZE,
98 | )
99 | else:
100 | container = DistribWorker(
101 | args,
102 | logger,
103 | log_id_dir,
104 | initial_step_count,
105 | LOCAL_RANK,
106 | GLOBAL_RANK,
107 | WORLD_SIZE,
108 | )
109 |
110 | try:
111 | container.run()
112 | finally:
113 | container.close()
114 |
115 |
116 | if __name__ == "__main__":
117 | main(parse_args())
118 |
--------------------------------------------------------------------------------
/adept/scripts/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (C) 2018 Heron Systems, Inc.
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see .
16 | """
17 | __ __
18 | ____ _____/ /__ ____ / /_
19 | / __ `/ __ / _ \/ __ \/ __/
20 | / /_/ / /_/ / __/ /_/ / /_
21 | \__,_/\__,_/\___/ .___/\__/
22 | /_/
23 | Evaluate
24 |
25 | Evaluates an agent after training. Computes N-episode average reward by
26 | loading a saved model from each epoch. N-episode averages are computed by
27 | running N env in parallel.
28 |
29 | Usage:
30 | evaluate (--logdir ) [options]
31 | evaluate (-h | --help)
32 |
33 | Required:
34 | --logdir Path to train logs (.../logs//)
35 |
36 | Options:
37 | --epoch Epoch number to load [default: None]
38 | --actor Name of the eval actor [default: ACActorEval]
39 | --gpu-id CUDA device ID of GPU [default: 0]
40 | --nb-episode Number of episodes to average [default: 30]
41 | --start Epoch to start from [default: 0]
42 | --end Epoch to end on [default: -1]
43 | --seed Seed for random variables [default: 512]
44 | --custom-network Name of custom network class
45 | """
46 | from adept.container import EvalContainer
47 | from adept.container import Init
48 | from adept.registry import REGISTRY as R
49 | from adept.utils.script_helpers import parse_path, parse_none
50 | from adept.utils.util import DotDict
51 |
52 |
53 | def parse_args():
54 | from docopt import docopt
55 |
56 | args = docopt(__doc__)
57 | args = {k.strip("--").replace("-", "_"): v for k, v in args.items()}
58 | del args["h"]
59 | del args["help"]
60 | args = DotDict(args)
61 | args.logdir = parse_path(args.logdir)
62 | # TODO implement Option utility
63 | epoch_option = parse_none(args.epoch)
64 | if epoch_option:
65 | args.epoch = int(float(epoch_option))
66 | else:
67 | args.epoch = epoch_option
68 | args.gpu_id = int(args.gpu_id)
69 | args.nb_episode = int(args.nb_episode)
70 | args.start = float(args.start)
71 | args.end = float(args.end)
72 | args.seed = int(args.seed)
73 | return args
74 |
75 |
76 | def main(args):
77 | """
78 | Run an evaluation.
79 | :param args: Dict[str, Any]
80 | :return:
81 | """
82 | args = DotDict(args)
83 |
84 | Init.print_ascii_logo()
85 | logger = Init.setup_logger(args.logdir, "eval")
86 | Init.log_args(logger, args)
87 | R.load_extern_classes(args.logdir)
88 |
89 | eval_container = EvalContainer(
90 | args.actor,
91 | args.epoch,
92 | logger,
93 | args.logdir,
94 | args.gpu_id,
95 | args.nb_episode,
96 | args.start,
97 | args.end,
98 | args.seed,
99 | )
100 | try:
101 | eval_container.run()
102 | finally:
103 | eval_container.close()
104 |
105 |
106 | if __name__ == "__main__":
107 | main(parse_args())
108 |
--------------------------------------------------------------------------------
/adept/scripts/render.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (C) 2018 Heron Systems, Inc.
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see .
16 | """
17 | __ __
18 | ____ _____/ /__ ____ / /_
19 | / __ `/ __ / _ \/ __ \/ __/
20 | / /_/ / /_/ / __/ /_/ / /_
21 | \__,_/\__,_/\___/ .___/\__/
22 | /_/
23 |
24 | Render Atari
25 |
26 | Renders an agent interacting with an environment.
27 |
28 | Usage:
29 | render --logdir [options]
30 | render (-h | --help)
31 |
32 | Required:
33 | --logdir Path to train logs (.../logs//)
34 |
35 | Options:
36 | --epoch Epoch number to load [default: None]
37 | --actor Name of the Actor [default: ACActorEval]
38 | --start Epoch number to start from [default: 0]
39 | --end Epoch number to end on [default: -1]
40 | --gpu-id CUDA device ID of GPU [default: 0]
41 | --seed Seed for random variables [default: 512]
42 | --manager Manager to use [default: SimpleEnvManager]
43 | """
44 |
45 | from adept.container import Init
46 | from adept.container.render import RenderContainer
47 | from adept.registry import REGISTRY as R
48 | from adept.utils.script_helpers import parse_path, parse_none
49 | from adept.utils.util import DotDict
50 |
51 |
52 | def parse_args():
53 | from docopt import docopt
54 |
55 | args = docopt(__doc__)
56 | args = {k.strip("--").replace("-", "_"): v for k, v in args.items()}
57 | del args["h"]
58 | del args["help"]
59 | args = DotDict(args)
60 |
61 | args.logdir = parse_path(args.logdir)
62 |
63 | # TODO implement Option utility
64 | epoch_option = parse_none(args.epoch)
65 | if epoch_option:
66 | args.epoch = int(float(epoch_option))
67 | else:
68 | args.epoch = epoch_option
69 | args.start = int(float(args.start))
70 | args.end = int(float(args.end))
71 | args.gpu_id = int(args.gpu_id)
72 | args.seed = int(args.seed)
73 | return args
74 |
75 |
76 | def main(args):
77 | """
78 |     Render a trained agent interacting with its environment.
79 |
80 | :param args: Dict[str, Any]
81 | :return:
82 | """
83 | # construct logging objects
84 | args = DotDict(args)
85 |
86 | Init.print_ascii_logo()
87 | logger = Init.setup_logger(args.logdir, "eval")
88 | Init.log_args(logger, args)
89 | R.load_extern_classes(args.logdir)
90 |
91 | container = RenderContainer(
92 | args.actor,
93 | args.epoch,
94 | args.start,
95 | args.end,
96 | logger,
97 | args.logdir,
98 | args.gpu_id,
99 | args.seed,
100 | args.manager,
101 | )
102 | try:
103 | container.run()
104 | finally:
105 | container.close()
106 |
107 |
108 | if __name__ == "__main__":
109 | main(parse_args())
110 |
--------------------------------------------------------------------------------
/adept/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | from .util import listd_to_dlist, dlist_to_listd, dtensor_to_dev
16 |
--------------------------------------------------------------------------------
/adept/utils/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | import os
16 | from collections import namedtuple
17 | from time import time
18 |
19 | import torch
20 |
21 | from adept.utils.util import HeapQueue
22 |
23 |
24 | class ModelSaver:
25 | BufferEntry = namedtuple(
26 | "BufferEntry", ["reward", "priority", "network", "optimizer"]
27 | )
28 |
29 | def __init__(self, nb_top_model, log_id_dir):
30 | self.nb_top_model = nb_top_model
31 | self._buffer = HeapQueue(nb_top_model)
32 | self._log_id_dir = log_id_dir
33 |
34 | def append_if_better(self, reward, network, optimizer):
35 | self._buffer.push(
36 | self.BufferEntry(
37 | reward, time(), network.state_dict(), optimizer.state_dict()
38 | )
39 | )
40 |
41 | def write_state_dicts(self, epoch_id):
42 | save_dir = os.path.join(self._log_id_dir, str(epoch_id))
43 | if len(self._buffer) > 0:
44 | os.makedirs(save_dir)
45 | for j, buff_entry in enumerate(self._buffer.flush()):
46 | torch.save(
47 | buff_entry.network,
48 | os.path.join(
49 | save_dir,
50 | "model_{}_{}.pth".format(j + 1, int(buff_entry.reward)),
51 | ),
52 | )
53 | torch.save(
54 | buff_entry.optimizer,
55 | os.path.join(
56 | save_dir,
57 | "optimizer_{}_{}.pth".format(j + 1, int(buff_entry.reward)),
58 | ),
59 | )
60 |
61 |
62 | class SimpleModelSaver:
63 | def __init__(self, log_id_dir):
64 | self._log_id_dir = log_id_dir
65 |
66 | def save_state_dicts(self, network, step_count, optimizer=None):
67 | save_dir = os.path.join(self._log_id_dir, str(step_count))
68 | os.makedirs(save_dir)
69 | torch.save(
70 | network.state_dict(),
71 | os.path.join(save_dir, "model_{}.pth".format(step_count)),
72 | )
73 | if optimizer is not None:
74 | torch.save(
75 | optimizer.state_dict(),
76 | os.path.join(save_dir, "optimizer_{}.pth".format(step_count)),
77 | )
78 |
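A brief sketch of SimpleModelSaver with a throwaway network and step count; the temporary directory stands in for a real log-id directory.

```python
import tempfile

import torch
from torch.nn import Linear

from adept.utils.logging import SimpleModelSaver

log_id_dir = tempfile.mkdtemp()
saver = SimpleModelSaver(log_id_dir)

net = Linear(4, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

# Creates <log_id_dir>/1000/ containing model_1000.pth and optimizer_1000.pth.
saver.save_state_dicts(net, 1000, optimizer=optimizer)
```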
--------------------------------------------------------------------------------
/adept/utils/requires_args.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | import json
16 | import abc
17 |
18 |
19 | class RequiresArgsMixin(metaclass=abc.ABCMeta):
20 | """
21 | This mixin makes it so that subclasses must implement an args class
22 | attribute. These arguments are parsed at runtime and the user is offered a
23 |     chance to change any desired args. Classes that use this mixin must
24 | implement the from_args() class method. from_args() is essentially a
25 | secondary constructor.
26 | """
27 |
28 | args = None # Dict[str, Any]
29 |
30 | @classmethod
31 | def check_args_implemented(cls):
32 | if cls.args is None:
33 | raise NotImplementedError(
34 | 'Subclass must define class attribute "args"'
35 | )
36 |
37 | @classmethod
38 | def prompt(cls, provided=None):
39 | """
40 | Display defaults as JSON, prompt user for changes.
41 |
42 | :param provided: Dict[str, Any], Override default prompts.
43 | :return: Dict[str, Any] Updated config dictionary.
44 | """
45 | if provided is not None:
46 | overrides = {k: v for k, v in provided.items() if k in cls.args}
47 | args = {**cls.args, **overrides}
48 | else:
49 | args = cls.args
50 | return cls._prompt(cls.__name__, args)
51 |
52 | @staticmethod
53 | def _prompt(name, args):
54 | """
55 | Display defaults as JSON, prompt user for changes.
56 |
57 | :param name: str Name of class
58 | :param args: Dict[str, Any]
59 | :return: Dict[str, Any] Updated config dictionary.
60 | """
61 | if not args:
62 | return args
63 |
64 | user_input = input(
65 | "\n{} Defaults:\n{}\n"
66 | "Press ENTER to use defaults. Otherwise, "
67 | "modify JSON keys then press ENTER.\n".format(
68 | name, json.dumps(args, indent=2, sort_keys=True)
69 | )
70 |             + 'Example: {"x": true, "gamma": 0.001}\n'
71 | )
72 |
73 | # use defaults if no changes specified
74 | if user_input == "":
75 | return args
76 |
77 | updates = json.loads(user_input)
78 | return {**args, **updates}
79 |
80 | @classmethod
81 | @abc.abstractmethod
82 |     def from_args(cls, *args, **kwargs):
83 | raise NotImplementedError
84 |
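To make the contract concrete, here is a minimal, hypothetical class that uses the mixin; `ExampleModule` and its args are illustrative only, and `prompt()` waits for interactive input.

```python
from adept.utils.requires_args import RequiresArgsMixin


class ExampleModule(RequiresArgsMixin):
    """Declares default args and a from_args secondary constructor."""

    args = {"gamma": 0.99, "use_gae": True}

    def __init__(self, gamma, use_gae):
        self.gamma = gamma
        self.use_gae = use_gae

    @classmethod
    def from_args(cls, args):
        return cls(args["gamma"], args["use_gae"])


if __name__ == "__main__":
    # prompt() prints the defaults as JSON and lets the user override them.
    config = ExampleModule.prompt(provided={"gamma": 0.95})
    module = ExampleModule.from_args(config)
    print(module.gamma, module.use_gae)
```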
--------------------------------------------------------------------------------
/adept/utils/script_helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | import json
16 | import os
17 | from time import sleep
18 |
19 | from adept.utils.util import DotDict
20 |
21 |
22 | def parse_bool_str(bool_str):
23 | """
24 | Convert string to boolean.
25 |
26 | :param bool_str: str
27 | :return: Bool
28 | """
29 | if bool_str.lower() == "false":
30 | return False
31 | elif bool_str.lower() == "true":
32 | return True
33 | else:
34 | raise ValueError('Unable to parse "{}"'.format(bool_str))
35 |
36 |
37 | def parse_list_str(list_str, item_type):
38 | items = list_str.split(",")
39 | return [item_type(item) for item in items]
40 |
41 |
42 | def parse_none(none_str):
43 | if none_str == "None":
44 | return None
45 | else:
46 | return none_str
47 |
48 |
49 | def parse_path(rel_path):
50 | """
51 | :param rel_path: (str) relative path
52 | :return: (str) absolute path
53 | """
54 | return os.path.abspath(rel_path)
55 |
56 |
57 | class LogDirHelper:
58 | def __init__(self, log_id_path):
59 | """
60 | :param log_id_path: str Path to Log ID
61 | """
62 |
63 | self._log_id_path = log_id_path
64 |
65 | def epochs(self):
66 | epochs = []
67 | for item in os.listdir(self._log_id_path):
68 | item_path = os.path.join(self._log_id_path, item)
69 | if os.path.isdir(item_path):
70 | if item.isnumeric():
71 | item_int = int(item)
72 | if item_int >= 0:
73 | epochs.append(item_int)
74 | return list(sorted(epochs))
75 |
76 | def latest_epoch(self):
77 | epochs = self.epochs()
78 | return max(epochs) if epochs else 0
79 |
80 | def latest_epoch_path(self):
81 | return os.path.join(self._log_id_path, str(self.latest_epoch()))
82 |
83 | def latest_network_path(self):
84 | network_file = [
85 | f for f in os.listdir(self.latest_epoch_path()) if ("model" in f)
86 | ][0]
87 | return os.path.join(self.latest_epoch_path(), network_file)
88 |
89 | def latest_optim_path(self):
90 | optim_file = [
91 | f for f in os.listdir(self.latest_epoch_path()) if ("optim" in f)
92 | ][0]
93 | return os.path.join(self.latest_epoch_path(), optim_file)
94 |
95 | def epoch_path_at_epoch(self, epoch):
96 | return os.path.join(self._log_id_path, str(epoch))
97 |
98 | def network_path_at_epoch(self, epoch, num_tries=1, retry_delay=3):
99 | """Find network path at epoch
100 |
101 | Parameters
102 | ----------
103 | epoch : int
104 | epoch to find network path for
105 | num_tries: int, optional
106 | number of tries to do, by default 1 (no retries)
107 | retry_delay : int, optional
108 | delay between retry attempts, by default 3 (seconds)
109 |
110 | Returns
111 | -------
112 | str
113 | path to network file
114 | """
115 | assert num_tries, "num_tries must be greater than 0"
116 |
117 | epoch_path = self.epoch_path_at_epoch(epoch)
118 |
119 | for try_idx in range(num_tries):
120 | if try_idx > 0:
121 | sleep(retry_delay)
122 |
123 | network_files = [f for f in os.listdir(epoch_path) if ("model" in f)]
124 |
125 | if len(network_files):
126 | break
127 | else:
128 | raise AssertionError(
129 |                 f"No network files found at epoch {epoch} for {self._log_id_path} after {num_tries} tries"
130 | )
131 |
132 | assert len(network_files) <= 1, (
133 |             f"More than one network path at epoch {epoch}, "
134 | "maybe you want network_paths_at_epoch()"
135 | )
136 |
137 | network_file = network_files[0]
138 | return os.path.join(epoch_path, network_file)
139 |
140 | def network_paths_at_epoch(self, epoch):
141 | epoch_path = self.epoch_path_at_epoch(epoch)
142 | return [
143 | os.path.join(epoch_path, f)
144 | for f in os.listdir(epoch_path)
145 | if ("model" in f)
146 | ]
147 |
148 | def optim_path_at_epoch(self, epoch):
149 | epoch_path = self.epoch_path_at_epoch(epoch)
150 | optim_file = [f for f in os.listdir(epoch_path) if ("optim" in f)][0]
151 | return os.path.join(epoch_path, optim_file)
152 |
153 | def timestamp(self):
154 | splits = self._log_id_path.split("_")
155 | timestamp = splits[-2] + "_" + splits[-1]
156 | return timestamp
157 |
158 | def args_file_path(self):
159 | return os.path.join(self._log_id_path, "args.json")
160 |
161 | def load_args(self):
162 | with open(self.args_file_path()) as args_file:
163 | return DotDict(json.load(args_file))
164 |
165 | def eval_path(self):
166 | return os.path.join(self._log_id_path, "eval.csv")
167 |
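A short sketch of LogDirHelper against a fabricated log directory; the layout below only mimics what a training run writes, and the directory naming is an assumption for illustration.

```python
import json
import os
import tempfile

from adept.utils.script_helpers import LogDirHelper

# Fabricate a minimal log-id dir: args.json plus one "epoch" with a model file.
log_id_dir = os.path.join(
    tempfile.mkdtemp(), "PongNoFrameskip-v4_2020-01-01_12-00-00"
)
os.makedirs(os.path.join(log_id_dir, "1000"))
with open(os.path.join(log_id_dir, "args.json"), "w") as f:
    json.dump({"lr": 3e-4}, f)
open(os.path.join(log_id_dir, "1000", "model_1000.pth"), "w").close()

helper = LogDirHelper(log_id_dir)
print(helper.epochs())               # [1000]
print(helper.latest_network_path())  # .../1000/model_1000.pth
print(helper.load_args().lr)         # 0.0003
print(helper.timestamp())            # 2020-01-01_12-00-00
```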
--------------------------------------------------------------------------------
/adept/utils/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | import json
16 | import heapq
17 | import inspect
18 | from collections import OrderedDict
19 |
20 | import numpy as np
21 | import torch
22 |
23 |
24 | def listd_to_dlist(list_of_dicts):
25 | """
26 | Converts a list of dictionaries to a dictionary of lists. Preserves key
27 | order.
28 |
29 | K is type of key.
30 | V is type of value.
31 | :param list_of_dicts: List[Dict[K, V]]
32 | :return: Dict[K, List[V]]
33 | """
34 | new_dict = OrderedDict()
35 | for d in list_of_dicts:
36 | for k, v in d.items():
37 | if k not in new_dict:
38 | new_dict[k] = [v]
39 | else:
40 | new_dict[k].append(v)
41 | return new_dict
42 |
43 |
44 | def dlist_to_listd(dict_of_lists):
45 | """
46 | Converts a dictionary of lists to a list of dictionaries. Preserves key
47 | order.
48 |
49 | K is type of key.
50 | V is type of value.
51 | :param dict_of_lists: Dict[K, List[V]]
52 | :return: List[Dict[K, V]]
53 | """
54 | keys = dict_of_lists.keys()
55 | list_len = len(dict_of_lists[next(iter(keys))])
56 | new_list = []
57 | for i in range(list_len):
58 | temp_d = OrderedDict()
59 | for k in keys:
60 | temp_d[k] = dict_of_lists[k][i]
61 | new_list.append(temp_d)
62 | return new_list
63 |
64 |
65 | def dtensor_to_dev(d_tensor, device):
66 | """
67 | Move a dictionary of tensors to a device.
68 |
69 | :param d_tensor: Dict[str, Tensor]
70 | :param device: torch.device
71 | :return: Dict[str, Tensor] on desired device.
72 | """
73 | return {k: v.to(device) for k, v in d_tensor.items()}
74 |
75 |
76 | def json_to_dict(file_path):
77 | """Read JSON config."""
78 |     with open(file_path, "r") as f:
79 |         return json.load(f)
80 |
81 |
82 | _numpy_to_torch_dtype = {
83 | np.float16: torch.float16,
84 |     np.float32: torch.float32,
85 | np.float64: torch.float64,
86 | np.uint8: torch.uint8,
87 | np.int8: torch.int8,
88 | np.int16: torch.int16,
89 | np.int32: torch.int32,
90 | np.int64: torch.int64,
91 | }
92 | _torch_to_numpy_dtype = {v: k for k, v in _numpy_to_torch_dtype.items()}
93 |
94 |
95 | def numpy_to_torch_dtype(dtype):
96 | if inspect.isclass(dtype):
97 | name = dtype
98 | else:
99 | name = type(dtype)
100 | if name not in _numpy_to_torch_dtype:
101 | raise ValueError(
102 | f"Could not convert numpy dtype {dtype.name} to a torch dtype."
103 | )
104 | return _numpy_to_torch_dtype[name]
105 |
106 |
107 | def torch_to_numpy_dtype(dtype):
108 | if dtype not in _torch_to_numpy_dtype:
109 | raise ValueError(
110 | "Could not convert torch dtype {} to a numpy dtype.".format(
111 | dtype
112 | )
113 | )
114 |
115 | return _torch_to_numpy_dtype[dtype]
116 |
117 |
118 | class CircularBuffer(object):
119 | def __init__(self, size):
120 | self.index = 0
121 | self.size = size
122 | self._data = []
123 |
124 | def append(self, value):
125 | if len(self._data) == self.size:
126 | self._data[self.index] = value
127 | else:
128 | self._data.append(value)
129 | self.index = (self.index + 1) % self.size
130 |
131 | def is_empty(self):
132 | return self._data == []
133 |
134 | def not_empty(self):
135 | return not self.is_empty()
136 |
137 | def is_full(self):
138 | return len(self) == self.size
139 |
140 | def not_full(self):
141 | return not self.is_full()
142 |
143 | def __getitem__(self, key):
144 | """get element by index like a regular array"""
145 | return self._data[key]
146 |
147 | def __setitem__(self, key, value):
148 | self._data[key] = value
149 |
150 | def __repr__(self):
151 | """return string representation"""
152 | return self._data.__repr__() + " (" + str(len(self._data)) + " items)"
153 |
154 | def __len__(self):
155 | return len(self._data)
156 |
157 |
158 | class HeapQueue:
159 | def __init__(self, maxlen):
160 | self.q = []
161 | self.maxlen = maxlen
162 |
163 | def push(self, item):
164 | if len(self.q) < self.maxlen:
165 | heapq.heappush(self.q, item)
166 | else:
167 | heapq.heappushpop(self.q, item)
168 |
169 | def flush(self):
170 | q = self.q
171 | self.q = []
172 | return q
173 |
174 | def __len__(self):
175 | return len(self.q)
176 |
177 |
178 | class DotDict(dict):
179 | """
180 | Dictionary to access attributes
181 | """
182 |
183 | __getattr__ = dict.get
184 | __setattr__ = dict.__setitem__
185 | __delattr__ = dict.__delitem__
186 |
187 | # Support pickling
188 | def __getstate__(obj):
189 | return dict(obj.items())
190 |
191 | def __setstate__(cls, attributes):
192 | return DotDict(**attributes)
193 |
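A quick round-trip sketch for the list/dict helpers and DotDict above:

```python
from adept.utils.util import DotDict, dlist_to_listd, listd_to_dlist

steps = [
    {"obs": 1, "reward": 0.0},
    {"obs": 2, "reward": 1.0},
]

by_key = listd_to_dlist(steps)       # one list per key: obs -> [1, 2], reward -> [0.0, 1.0]
round_trip = dlist_to_listd(by_key)  # back to one dict per timestep
assert list(by_key["obs"]) == [1, 2]
assert round_trip[0]["reward"] == 0.0

config = DotDict({"lr": 1e-3, "gamma": 0.99})
print(config.lr, config.gamma)       # attribute access into a plain dict
```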
--------------------------------------------------------------------------------
/docker/connect.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2020 Heron Systems, Inc.
2 | """
3 | Sample command:
4 | python connect.py --dockerfile ./Dockerfile --username gburdell \
5 | --email gburdell@heronsystems.com \
6 | --fullname "George P. Burdell"
7 | Given a path to a Dockerfile, this script (1) builds a Docker image from the
8 | Dockerfile and (2) outputs a command that can directly be pasted into a shell
9 | that connects to a Docker instance spawned from the Docker image built in (1).
10 | """
11 | import argparse
12 | import os
13 | import sys
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument(
19 | "--dockerfile",
20 | default="./Dockerfile",
21 | help=(
22 | "Path to the Dockerfile you wish to work with, "
23 | "e.g., /some/path/to/Dockerfile"
24 | ),
25 | )
26 | parser.add_argument(
27 | "--username",
28 | required=True,
29 | help="Your local username in the docker instance that you connect to.",
30 | )
31 | parser.add_argument(
32 | "--email",
33 | required=True,
34 | help=(
35 | "Your email address, used for git purposes, e.g., gburdell@heronsystems.com."
36 | ),
37 | )
38 | parser.add_argument(
39 | "--fullname",
40 | required=True,
41 | help="Your full name, used for git purposes.",
42 | )
43 | return parser.parse_args()
44 |
45 |
46 | def runcmd(cmd):
47 | print("Running --\n{}\n--".format(cmd))
48 | return os.system(cmd)
49 |
50 |
51 | def connect_local(args):
52 | # Check for .ssh under /mnt directory; this is important to be able to pull
53 | # from github within Docker instances.
54 | if not os.path.exists(
55 | "/mnt/users/{username}".format(username=args.username)
56 | ):
57 | print(
58 | (
59 |                 "Please create the directory /mnt/users/{username} and run:\n"
60 |                 "chmod -R a+rw /mnt/users/{username}."
61 | ).format(username=args.username)
62 | )
63 |
64 | # Check for .ssh under /mnt directory; this is important to be able to pull
65 | # from github within Docker instances.
66 | if not os.path.exists(
67 | "/mnt/users/{username}/.ssh".format(username=args.username)
68 | ):
69 | print(
70 | "Please run the following before running connect.py. This is to "
71 | "allow your Docker instance to be able to pull/push from "
72 | "private github repositories:\n"
73 | "ln -s ~/.ssh /mnt/users/{username}/.ssh".format(
74 | username=args.username
75 | )
76 | )
77 | sys.exit(1)
78 |
79 | # Build the docker image
80 | dockerfile_dir = os.path.dirname(os.path.abspath(args.dockerfile))
81 | cmd = (
82 | "cd {dockerfile_dir}; docker build "
83 | "-f {dockerfile} "
84 | "--build-arg USERNAME={username} "
85 | "--build-arg EMAIL={email} "
86 | '--build-arg FULLNAME="{fullname}" '
87 | "-t {username}-dev .".format(
88 | dockerfile=args.dockerfile,
89 | dockerfile_dir=dockerfile_dir,
90 | username=args.username,
91 | email=args.email,
92 | fullname=args.fullname,
93 | )
94 | )
95 | success = runcmd(cmd)
96 |
97 | if success != 0:
98 | print("Please fix the errors that occurred above.")
99 | sys.exit(1)
100 |
101 | # Construct the docker instance creation command
102 | cmd = (
103 | 'xhost +"local:docker"; docker run -it --rm '
104 | # Attach gpus
105 | "--gpus=all "
106 | # Take care of ctrl-p issues
107 | '--detach-keys="ctrl-@" '
108 | # Configure X
109 | "-e DISPLAY=$DISPLAY "
110 | "-v /tmp/.X11-unix:/tmp/.X11-unix:ro "
111 | # Expose all ports
112 | "--network host "
113 | # Mount volumes
114 | "-v /mnt/:/mnt/ "
115 | '{}-dev; xhost -"local:docker"'.format(args.username)
116 | )
117 | print("Run this to connect:\n{cmd}".format(cmd=cmd))
118 |
119 | # For convenience, copy to the clipboard if possible
120 | try:
121 | import pyperclip
122 |
123 | pyperclip.copy(cmd)
124 | except:
125 | pass
126 |
127 |
128 | def main():
129 | connect_local(parse_args())
130 |
131 |
132 | if __name__ == "__main__":
133 | main()
134 |
--------------------------------------------------------------------------------
/docker/startup.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/zsh
2 | # Copyright (C) 2020 Heron Systems, Inc.
3 | #
4 | # This script's purpose is to run preliminary setup for the dev environment
5 | # once an image has been created.
6 |
7 | # retain state across sessions
8 | mkdir -p "/mnt/users/$USERNAME"
9 | sudo chown -R $USERNAME "/mnt/users/$USERNAME"
10 | cd /home/$USERNAME/
11 | persistents=".bashrc .bash_logout .zshrc"
12 | for f in ${persistents};
13 | do if [ ! -e /mnt/users/$USERNAME/${f} ];
14 | then
15 | cp ${f} /mnt/users/$USERNAME/${f};
16 | fi;
17 | done
18 |
19 | # zhistory and clients aren't created by default, so create them if they don't
20 | # exist already
21 | test -e /mnt/users/$USERNAME/.zhistory || touch /mnt/users/$USERNAME/.zhistory
22 | mkdir -p /mnt/users/$USERNAME/clients
23 | mkdir -p /mnt/users/$USERNAME/data
24 |
25 | # symlink these files to the persistent versions
26 | persistents=".bashrc .bash_logout .zshrc .zhistory .tmux.conf .ssh clients data"
27 | for f in ${persistents};
28 | do if [ -e "/mnt/users/${USERNAME}/$f" ];
29 | then
30 | rm -f ${f} && ln -s "/mnt/users/${USERNAME}/$f" $f;
31 | fi
32 | done
33 |
34 | # git setup
35 | export GIT_SSL_NO_VERIFY=1
36 | git config --global user.email "$EMAIL"
37 | git config --global user.name "$FULLNAME"
38 |
--------------------------------------------------------------------------------
/docs/api_overview.md:
--------------------------------------------------------------------------------
1 | # API Overview
2 | 
3 | ### Containers
4 | Containers hold all of the application state. Each subprocess gets a container
5 | in Distributed and IMPALA modes.
6 | ### Agents
7 | An Agent acts on and observes the environment.
8 | Currently only ActorCritic is supported. Other agents, such as DQN or ACER,
9 | may be added later.
10 | ### Networks
11 | Networks are not plain PyTorch modules; they need to implement our abstract
12 | NetworkModule or ModularNetwork classes. A ModularNetwork consists of
13 | source nets, a body, and heads.
14 | ### Environments
15 | Environments run in subprocesses and send their observations, rewards,
16 | terminals, and infos to the host process. They work much the same way as
17 | OpenAI's code.
18 | ### Experience Caches
19 | An Experience Cache is a Rollout or Experience Replay that is written to after
20 | stepping and read before learning.
21 |
22 | Inheritance Tree:
23 | * [Container](#container)
24 | * LocalWorker
25 | * DistribWorker
26 | * ImpalaHost
27 | * ImpalaWorkerCPU
28 | * ImpalaWorkerGPU
29 | * [Agent](#agent)
30 | * ActorCritic
31 | * ActorCriticVTrace
32 | * ACER
33 | * ACERVTrace
34 | * [ExperienceCache](#experiencecache)
35 | * ExperienceReplay
36 | * RolloutCache
37 | * [Environment](#environment)
38 | * Gym Env
39 | * SC2Feat1DEnv (1d action space)
40 | * SC2Feat3DEnv (3d action space)
41 | * SC2RawEnv (new proto)
42 | * [EnvironmentManager](#environmentmanager)
43 | * SimpleEnvManager (synchronous, same process, for debugging / rendering)
44 | * SubProcEnvManager (use torch.multiprocessing.Pipe)
45 | * [Network](#network)
46 | * ModularNetwork
47 | * CustomNetwork
48 | * [SubModule](#submodule)
49 | * SubModule1D
50 | * Identity1D
51 | * LSTM
52 | * SubModule2D
53 | * Identity2D
54 | * TransformerEnc
55 | * TransformerDec
56 | * SubModule3D
57 | * Identity3D
58 | * SC2LEEncoder - encoder from SC2LE paper
59 | * SC2LEDecoder - decoder from SC2LE paper
60 | * ConvLSTM
61 | * SubModule4D
62 | * Identity4D
63 |
64 | ### Container
65 | * `agent:` [Agent](#agent)
66 | * `env:` [EnvironmentManager](#environmentmanager)
67 | * `exp_cache:` [ExperienceCache](#experiencecache)
68 | * `network:` [Network](#network)
69 | * `local_step_count: int`
70 | * `reward_buffer: List[float], keep track of running rewards`
71 | * `hidden_state: Dict[HSName, torch.Tensor], keep track of current LSTM states`
72 | ```python
73 | def run(nb_step, initial_step=0): ...
74 | """
75 | args:
76 | nb_step: int, number of steps to train for
77 | return:
78 | self
79 | """
80 | ```
81 |
82 | ### Agent
83 | * [network](#network)
84 | ```python
85 | def observe(observations, rewards, terminals, infos): ...
86 | """
87 | args:
88 | observations: Dict[ObsName, List[torch.Tensor]]
89 | rewards: List[int]
90 | terminals: List[Bool]
91 | infos: List[Any]
92 | return:
93 | None
94 | """
95 | def act(observation): ...
96 | """
97 | legend:
98 | ObsName = str
99 | ActionName = str
100 | args:
101 | observation: Dict[ObsName, torch.Tensor]
102 | return:
103 | actions: Dict[ActionName, torch.Tensor],
104 | experience: Dict[ExpName, torch.Tensor]
105 | """
106 | def compute_loss(experience, next_obs): ...
107 | """
108 | args:
109 | experience: torch.Tensor
110 | next_obs: Dict[ObsName, torch.Tensor]
111 | return:
112 | losses: Dict[LossName, torch.Tensor]
113 | metrics: Dict[MetricName, torch.Tensor]
114 | """
115 | ```
116 |
117 | ### EnvironmentManager
118 | * `obs_preprocessor_gpu: ObsPreprocessor`
119 | * `env_cls: type, environment class`
120 | ```python
121 | def step(actions): ...
122 | """
123 | args:
124 | actions: Dict[str, List[np.ndarray]]
125 | return:
126 | Tuple[Observations, Rewards, Terminals, Infos]
127 | """
128 | def reset(): ...
129 | """
130 | description:
131 | Reset the environment to its initial state.
132 | return:
133 | observation: Dict[ObsName, torch.Tensor]
134 | """
135 | def close(): ...
136 | """
137 | description:
138 | Close environments.
139 | return:
140 | None
141 | """
142 | ```
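
To make the relationship between Agent and EnvironmentManager concrete, here is a
minimal, self-contained sketch of the step loop a Container drives. The classes
below are stand-ins with made-up shapes, not adept's implementations.

```python
import torch


class DummyEnvManager:
    """Stand-in for an EnvironmentManager running nb_env batched environments."""

    nb_env = 2

    def reset(self):
        return {"screen": torch.zeros(self.nb_env, 3, 84, 84)}

    def step(self, actions):
        obs = {"screen": torch.zeros(self.nb_env, 3, 84, 84)}
        rewards = [0.0] * self.nb_env
        terminals = [False] * self.nb_env
        infos = [{}] * self.nb_env
        return obs, rewards, terminals, infos


class DummyAgent:
    """Stand-in for an Agent: random discrete actions, no learning."""

    def act(self, observation):
        actions = {"action": torch.randint(0, 8, (DummyEnvManager.nb_env,))}
        experience = {"log_prob": torch.zeros(DummyEnvManager.nb_env)}
        return actions, experience

    def observe(self, observations, rewards, terminals, infos):
        pass  # a real agent writes to its experience cache here


env_manager, agent = DummyEnvManager(), DummyAgent()
obs = env_manager.reset()
for _ in range(10):
    actions, _ = agent.act(obs)
    obs, rewards, terminals, infos = env_manager.step(actions)
    agent.observe(obs, rewards, terminals, infos)
```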
143 |
144 | ### Environment
145 | * [obs_preprocessor_cpu](#)
146 | * [action_preprocessor](#)
147 | ```python
148 | def step(action): ...
149 | """
150 | args:
151 | action: Dict[str, np.ndarray]
152 | return:
153 | Tuple[Observation, Reward, Terminal, Info]
154 | """
155 | ```
156 |
157 | ### ExperienceCache
158 | ```python
159 | def read(): ...
160 | def env_write(observations, rewards, terminals): ...
161 | def agent_write(log_prob, entropy): ...
162 | def write(observation, reward, terminal, value, log_prob, entropy, hidden_state): ...
163 | def ready(): ...
164 | def submit(): ...
165 | def receive(): ...
166 | ```
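
As a rough illustration of the write/ready/read cycle described above, here is a
toy stand-in cache. It is not an adept class; only the method names follow the
interface listed, and everything else is illustrative.

```python
class ListRolloutCache:
    """Toy cache: store per-step data until a fixed rollout length is reached."""

    def __init__(self, rollout_len):
        self.rollout_len = rollout_len
        self._steps = []

    def write(self, **step_data):
        self._steps.append(step_data)

    def ready(self):
        return len(self._steps) >= self.rollout_len

    def read(self):
        steps, self._steps = self._steps, []
        return steps


cache = ListRolloutCache(rollout_len=20)
for t in range(20):
    cache.write(reward=0.0, terminal=False)  # written to after stepping
if cache.ready():
    rollout = cache.read()  # read before learning
```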
167 |
168 | ### Network
169 | ```python
170 | def forward(observation, hidden_state): ...
171 | """
172 | args:
173 | observation: Dict[ObsName, torch.Tensor]
174 | hidden_state: Dict[HSName, torch.Tensor]
175 | """
176 | ```
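
For reference, assembling and calling a ModularNetwork looks roughly like the
sketch below. The identity submodules and shapes mirror
`tests/network/test_modular_network.py`; the pass-through `gpu_preprocessor`
argument is just a stand-in.

```python
import torch

from adept.network.modular_network import ModularNetwork
from adept.network.net1d.identity_1d import Identity1D
from adept.network.net2d.identity_2d import Identity2D
from adept.network.net3d.identity_3d import Identity3D
from adept.network.net4d.identity_4d import Identity4D

source_nets = {
    "source_1d": Identity1D((16,), "source_1d"),
    "source_2d": Identity2D((16, 8 * 8), "source_2d"),
    "source_3d": Identity3D((16, 8, 8), "source_3d"),
    "source_4d": Identity4D((16, 8, 8, 8), "source_4d"),
}
body = Identity3D((176, 8, 8), "body")
heads = {
    "1": Identity1D((11264,), "head1d"),
    "2": Identity2D((176, 64), "head2d"),
    "3": Identity3D((176, 8, 8), "head3d"),
}
output_space = {
    "output_1d": (16,),
    "output_2d": (16, 8 * 8),
    "output_3d": (16, 8, 8),
}

net = ModularNetwork(source_nets, body, heads, output_space, lambda obs: obs)

batch = 32
obs = {
    "source_1d": torch.zeros(batch, 16),
    "source_2d": torch.zeros(batch, 16, 8 * 8),
    "source_3d": torch.zeros(batch, 16, 8, 8),
    "source_4d": torch.zeros(batch, 16, 8, 8, 8),
}
outputs, _ = net.forward(obs, {})
```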
177 |
178 | ### SubModule
179 | * `input_shape: Tuple[int]`
180 | ```python
181 | def forward(): ...
182 | def output_shape(dim): ...
183 | """
184 | args:
185 | dim: int, desired dimensionality of the output shape
186 | """
187 | ```
188 |
189 | ### HiddenState
190 | ```python
191 | def detach(): ...
192 | ```
193 |
--------------------------------------------------------------------------------
/docs/modular_network.md:
--------------------------------------------------------------------------------
1 | 
2 |
--------------------------------------------------------------------------------
/docs/new_api.md:
--------------------------------------------------------------------------------
1 | Inheritance Tree:
2 | * [Container](#container)
3 | * LocalWorker
4 | * DistribWorker
5 | * ImpalaHost
6 | * ImpalaWorkerCPU
7 | * ImpalaWorkerGPU
8 | * [Agent](#agent)
9 | * ActorCritic
10 | * VTrace
11 | * [TrajectoryCache](#trajectorycache)
12 | * ExperienceReplay
13 | * RolloutCache
14 | * [Environment](#environment)
15 | * GymEnv
16 | * SC2Feat1DEnv (1d action space)
17 | * SC2Feat3DEnv (3d action space)
18 | * SC2RawEnv (new proto)
19 | * [EnvironmentManager](#environmentmanager)
20 | * SimpleEnvManager (synchronous, same process, for debugging / rendering)
21 | * SubProcEnvManager (use torch.multiprocessing.Pipe, default start method)
22 | * SelfPlayEnvManager
23 | * [Network](#network)
24 | * ModularNetwork
25 | * CustomNetwork
26 | * [SubModule](#submodule)
27 | * SubModule1D
28 | * Identity1D
29 | * LSTM
30 | * SubModule2D
31 | * Identity2D
32 | * TransformerEnc
33 | * TransformerDec
34 | * SubModule3D
35 | * Identity3D
36 | * SC2LEEncoder - encoder from SC2LE paper
37 | * SC2LEDecoder - decoder from SC2LE paper
38 | * ConvLSTM
39 | * SubModule4D
40 | * Identity4D
41 |
42 | ### Container
43 | * `_agent:` [Agent](#agent)
44 | * `_environment:` [Environment](#environment)
45 |
46 | * `_network:` [Network](#network)
47 | * `_local_step_count: int`
48 | * `_reward_buffer: List[float], keep track of running rewards`
49 | * `_hidden_state: Dict[Id, torch.Tensor], keep track of current LSTM states`
50 | ```python
51 | def run(nb_step, initial_step=0): ...
52 | """
53 | args:
54 | nb_step: int, number of steps to train for
55 | return:
56 | None
57 | """
58 | ```
59 |
60 | ### Agent
61 | * `_trajectory_cache:` [TrajectoryCache](#trajectorycache)
62 | ```python
63 |
64 | @staticmethod
65 | def output_space(action_space): ...
66 | """
67 | args:
68 | action_space: Dict[ActionName, Shape]
69 |
70 | """
71 | def observe(cls): ...
72 | def compute_loss(prediction, trajectory): ... # ?
73 | ```
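
As a rough example (not the adept implementation), `output_space` might merge the
action space with an extra agent-specific output such as a critic head; the
`"critic"` key and its shape here are illustrative.

```python
def output_space(action_space):
    # e.g. {"move": (8,)} -> {"move": (8,), "critic": (1,)}
    return {**action_space, "critic": (1,)}
```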
74 |
75 | ### EnvironmentManager
76 | * _obs_preprocessor_gpu: [ObsPreprocessor](#)
77 | * `env_cls: type, environment class`
78 | ```python
79 | def step_train(policy_logits): ...
80 | """
81 | args:
82 | policy_logits: Dict[ActionName, torch.Tensor]
83 | return:
84 | Tuple[Obs, Reward, Done, Info, LogProb, Entropy]
85 | """
86 | def step_eval(policy_logits): ...
87 | """
88 | args:
89 | policy_logits: Dict[ActionName, torch.Tensor]
90 | return:
91 | Tuple[Obs, Reward, Done, Info]
92 | """
93 |
94 | ```
95 |
96 | ### Environment
97 | * _obs_preprocessor_cpu: [ObsPreprocessor](#)
98 | * _action_preprocessor: [ActionPreprocessor](#)
99 | * `_action_space: Dict[]`
100 | ```python
101 | def step(observation): ...
102 | def reset(): ...
103 | def close(): ...
104 |
105 |
106 | ```
107 |
108 | ### TrajectoryCache
109 | ```python
110 | def read(): ...
111 | def write(): ...
112 | def ready(): ...
113 | """
114 | return:
115 | bool, whether the cache is ready
116 | """
117 | ```
118 |
119 | ### Network
120 | ```python
121 | def forward(observation, hidden_state): ...
122 | """
123 | args:
124 | observation: Dict[ObsName, torch.Tensor]
125 | hidden_state: Dict[Id, torch.Tensor]
126 | """
127 | ```
128 |
129 | ### SubModule
130 | * `input_shape: Tuple[int]`
131 | ```python
132 | def forward(): ...
133 | def output_shape(dim): ...
134 | """
135 | args:
136 | dim: int, desired dimensionality of the output shape
137 | """
138 | ```
139 |
140 | ### ObsPreprocessor
141 |
--------------------------------------------------------------------------------
/docs/resume_training.md:
--------------------------------------------------------------------------------
1 | To resume training, you need to navigate to the log folder of the training job.
2 | By default, logs are saved in `/tmp/adept_logs`.
3 | ```bash
4 | # Change directory to the desired log directory
5 | cd /tmp/adept_logs///
6 | # To continue training on a single GPU
7 | python -m adept.app local --resume .
8 | # To continue training on multiple GPUs
9 | python -m adept.app distrib --resume .
10 | ```
11 |
--------------------------------------------------------------------------------
/examples/custom_agent_stub.py:
--------------------------------------------------------------------------------
1 | """
2 | Use a custom agent.
3 | """
4 | from adept.agents import AgentModule, AgentRegistry
5 | from adept.scripts.local import parse_args, main
6 |
7 |
8 | class MyCustomAgent(AgentModule):
9 | # You will be prompted for these when training script starts
10 | args = {"example_arg1": True, "example_arg2": 5}
11 |
12 | def __init__(
13 | self,
14 | network,
15 | device,
16 | reward_normalizer,
17 | gpu_preprocessor,
18 | engine,
19 | action_space,
20 | nb_env,
21 | *args,
22 | **kwargs
23 | ):
24 | super(MyCustomAgent, self).__init__(
25 | network,
26 | device,
27 | reward_normalizer,
28 | gpu_preprocessor,
29 | engine,
30 | action_space,
31 | nb_env,
32 | )
33 |
34 | @classmethod
35 | def from_args(
36 | cls,
37 | args,
38 | network,
39 | device,
40 | reward_normalizer,
41 | gpu_preprocessor,
42 | engine,
43 | action_space,
44 | **kwargs
45 | ):
46 | """
47 |
48 | ArgName = str
49 |
50 | :param args: Dict[ArgName, Any]
51 | :param network: BaseNetwork
52 | :param device: torch.device
53 | :param reward_normalizer: Callable[[float], float]
54 | :param gpu_preprocessor: ObsPreprocessor
55 | :param engine: env_registry.Engines
56 | :param action_space: Dict[ActionKey, torch.Tensor]
57 | :param kwargs:
58 | :return: MyCustomAgent
59 | """
60 | pass
61 |
62 | @property
63 | def exp_cache(self):
64 | """
65 | Experience cache, probably a RolloutCache or ExperienceReplay.
66 |
67 | :return: BaseExperience
68 | """
69 | pass
70 |
71 | @staticmethod
72 | def output_space(action_space):
73 | """
74 | Merge action space with any agent-based outputs to get an output_space.
75 |
76 | ActionKey = str
77 | Shape = Tuple[*int]
78 |
79 | :param action_space: Dict[ActionKey, Shape]
80 | :return:
81 | """
82 | pass
83 |
84 | def compute_loss(self, experience, next_obs):
85 | """
86 | Compute losses.
87 |
88 | ObsKey = str
89 | LossKey = str
90 |
91 | :param experience: Tuple[*Any]
92 | :param next_obs: Dict[ObsKey, torch.Tensor]
93 | :return: Dict[LossKey, torch.Tensor (0D)]
94 | """
95 | pass
96 |
97 | def act(self, obs):
98 | """
99 | Generate an action.
100 |
101 | ObsKey = str
102 | ActionKey = str
103 |
104 | :param obs: Dict[ObsKey, torch.Tensor]
105 | :return: Dict[ActionKey, np.ndarray]
106 | """
107 | pass
108 |
109 | def act_eval(self, obs):
110 | """
111 | Generate an action in an evaluation.
112 |
113 | ObsKey = str
114 | ActionKey = str
115 |
116 | :param obs: Dict[ObsKey, torch.Tensor]
117 | :return: Dict[ActionKey, np.ndarray]
118 | """
119 | pass
120 |
121 |
122 | if __name__ == "__main__":
123 | args = parse_args()
124 | agent_reg = AgentRegistry()
125 | agent_reg.register_agent(MyCustomAgent)
126 |
127 | main(args, agent_registry=agent_reg)
128 |
129 | # Call script like this to train agent:
130 | # python custom_agent_stub.py --agent MyCustomAgent
131 |
--------------------------------------------------------------------------------
/examples/custom_environment_stub.py:
--------------------------------------------------------------------------------
1 | """
2 | Use a custom environment.
3 | """
4 | from adept.env import EnvModule
5 | from adept.scripts.local import parse_args, main
6 |
7 |
8 | class MyCustomEnv(EnvModule):
9 | # You will be prompted for these when training script starts
10 | args = {"example_arg1": True, "example_arg2": 5}
11 | ids = ["scenario1", "scenario2"]
12 |
13 | def __init__(self, action_space, cpu_ops, gpu_ops, *args, **kwargs):
14 | super(MyCustomEnv, self).__init__(action_space, cpu_ops, gpu_ops)
15 |
16 | @classmethod
17 | def from_args(cls, args, seed, **kwargs):
18 | """
19 | Construct from arguments. For convenience.
20 |
21 | :param args: Arguments object
22 | :param seed: Integer used to seed this environment.
23 | :param kwargs: Any custom arguments are passed through kwargs.
24 | :return: EnvModule instance.
25 | """
26 | pass
27 |
28 | def step(self, action):
29 | """
30 | Perform action.
31 |
32 | ActionID = str
33 | Observation = Dict[ObsKey, Any]
34 | Reward = np.ndarray
35 | Terminal = bool
36 | Info = Dict[Any, Any]
37 |
38 | :param action: Dict[ActionID, Any] Action dictionary
39 | :return: Tuple[Observation, Reward, Terminal, Info]
40 | """
41 | pass
42 |
43 | def reset(self, **kwargs):
44 | """
45 | Reset environment.
46 |
47 | ObsKey = str
48 |
49 | :param kwargs:
50 | :return: Dict[ObsKey, Any] Observation dictionary
51 | """
52 | pass
53 |
54 | def close(self):
55 | """
56 | Close any connections / resources.
57 |
58 | :return:
59 | """
60 | pass
61 |
62 |
63 | if __name__ == "__main__":
64 | import adept
65 |
66 | adept.register_env(MyCustomEnv)
67 | main(parse_args())
68 |
69 | # Call script like this to train agent:
70 | # python custom_environment_stub.py --env scenario1
71 |
--------------------------------------------------------------------------------
/examples/custom_network_stub.py:
--------------------------------------------------------------------------------
1 | """
2 | Use a custom network.
3 | """
4 | from adept.networks import NetworkModule, NetworkRegistry
5 | from adept.scripts.local import parse_args, main
6 |
7 |
8 | class MyCustomNetwork(NetworkModule):
9 | # You will be prompted for these when training script starts
10 | args = {"example_arg1": True, "example_arg2": 5}
11 |
12 | def __init__(self):
13 | super(MyCustomNetwork, self).__init__()
14 | # Set properties and whatnot here
15 |
16 | @classmethod
17 | def from_args(cls, args, observation_space, output_space, net_reg):
18 | """
19 | Construct a MyCustomNetwork from arguments.
20 | ArgName = str
21 | ObsKey = str
22 | OutputKey = str
23 | Shape = Tuple[*int]
24 | :param args: Dict[ArgName, Any]
25 | :param observation_space: Dict[ObsKey, Shape]
26 | :param output_space: Dict[OutputKey, Shape]
27 | :param net_reg: NetworkRegistry
28 | :return: MyCustomNetwork
29 | """
30 | pass
31 |
32 | def new_internals(self, device):
33 | """
34 | Define any initial hidden states here, move them to device if necessary.
35 | InternalKey=str
36 | :return: Dict[InternalKey, torch.Tensor (ND)]
37 | """
38 | pass
39 |
40 | def forward(self, observation, internals):
41 | """
42 | Compute forward pass.
43 | ObsKey = str
44 | InternalKey = str
45 | :param observation: Dict[ObsKey, torch.Tensor (1D | 2D | 3D | 4D)]
46 | :param internals: Dict[InternalKey, torch.Tensor (ND)]
47 | :return: torch.Tensor
48 | """
49 | pass
50 |
51 |
52 | if __name__ == "__main__":
53 | args = parse_args()
54 | network_reg = NetworkRegistry()
55 | network_reg.register_network(MyCustomNetwork)
56 |
57 | main(args, net_registry=network_reg)
58 |
59 | # Call script like this to train agent:
60 | # python custom_network_stub.py --custom-network MyCustomNetwork
61 |
--------------------------------------------------------------------------------
/examples/custom_submodule_stub.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom submodule stub
3 | """
4 | from adept.networks import (
5 | SubModule1D,
6 | SubModule2D,
7 | SubModule3D,
8 | SubModule4D,
9 | NetworkRegistry,
10 | )
11 | from adept.scripts.local import parse_args, main
12 |
13 |
14 | # If your Module processes 2D, then inherit SubModule2D and so on.
15 | # Dimensionality refers to feature map dimensions not including batch.
16 | # ie. (F, ) = 1D, (F, L) = 2D, (F, H, W) = 3D, (F, D, H, W) = 4D
17 | class MyCustomSubModule1D(SubModule1D):
18 | # You will be prompted for these when training script starts
19 | args = {"example_arg1": True, "example_arg2": 5}
20 |
21 | def __init__(self, input_shape, id):
22 | super(MyCustomSubModule1D, self).__init__(input_shape, id)
23 |
24 | @classmethod
25 | def from_args(cls, args, input_shape, id):
26 | """
27 | Construct a MyCustomSubModule1D from arguments.
28 |
29 | :param args: Dict[ArgName, Any]
30 | :param input_shape: Tuple[*int]
31 | :param id: str
32 | :return: MyCustomSubModule1D
33 | """
34 | pass
35 |
36 | @property
37 | def _output_shape(self):
38 | """
39 | Return the output shape. If it's a function of the input shape, you can
40 | access the input shape via ``self.input_shape``.
41 |
42 | :return: Tuple[*int]
43 | """
44 | pass
45 |
46 | def _forward(self, input, internals, **kwargs):
47 | """
48 | Compute forward pass.
49 |
50 | ObsKey = str
51 | InternalKey = str
52 |
53 | :param input: torch.Tensor
54 | :param internals: Dict[InternalKey, torch.Tensor (ND)]
55 | :return: torch.Tensor
56 | """
57 | pass
58 |
59 | def _new_internals(self):
60 | """
61 | Define any initial hidden states here, move them to device if necessary.
62 |
63 | InternalKey=str
64 |
65 | :return: Dict[InternalKey, torch.Tensor (ND)]
66 | """
67 | pass
68 |
69 |
70 | if __name__ == "__main__":
71 | args = parse_args()
72 | network_reg = NetworkRegistry()
73 | network_reg.register_submodule(MyCustomSubModule1D)
74 |
75 | main(args, net_registry=network_reg)
76 |
77 | # Call script like this to train agent:
78 | # python custom_submodule_stub.py --net1d MyCustomSubModule1D
79 |
--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/images/architecture.png
--------------------------------------------------------------------------------
/images/banner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/images/banner.png
--------------------------------------------------------------------------------
/images/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/images/benchmark.png
--------------------------------------------------------------------------------
/images/modular_network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/images/modular_network.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from adept.globals import VERSION
3 |
4 | # https://github.com/kennethreitz/setup.py/blob/master/setup.py
5 |
6 |
7 | with open("README.md", "r") as fh:
8 | long_description = fh.read()
9 |
10 | extras = {
11 | "profiler": ["pyinstrument>=2.0"],
12 | "atari": [
13 | "gym[atari]>=0.10",
14 | "opencv-python-headless<4,>=3.4",
15 | ]
16 | }
17 | test_deps = ["pytest"]
18 |
19 | all_deps = []
20 | for group_name in extras:
21 | all_deps += extras[group_name]
22 | all_deps = all_deps + test_deps
23 | extras["all"] = all_deps
24 |
25 |
26 | setup(
27 | name="adeptRL",
28 | version=VERSION,
29 | author="heron",
30 | author_email="adept@heronsystems.com",
31 | description="Reinforcement Learning Framework",
32 | long_description=long_description,
33 | long_description_content_type="text/markdown",
34 | url="https://github.com/heronsystems/adeptRL",
35 | license="GNU",
36 | python_requires=">=3.5.0",
37 | packages=find_packages(),
38 | install_requires=[
39 | "protobuf>=3.15.3",
40 | "numpy>=1.14",
41 | "tensorflow<3,>=2.4.0",
42 | "cloudpickle>=0.5",
43 | "pyzmq>=17.1.2",
44 | "docopt>=0.6",
45 | "torch>=1.3.1",
46 | "torchvision>=0.4.2",
47 | "ray>=1.3.0",
48 | "pandas>=1.0.5",
49 | "msgpack<2,>=1.0.2",
50 | "msgpack-numpy<1,>=0.4.7",
51 | ],
52 | test_requires=test_deps,
53 | extras_require=extras,
54 | include_package_data=True,
55 | )
56 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/tests/__init__.py
--------------------------------------------------------------------------------
/tests/distrib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/tests/distrib/__init__.py
--------------------------------------------------------------------------------
/tests/distrib/allreduce.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 | WORLD_SIZE = int(os.environ["WORLD_SIZE"])
7 | GLOBAL_RANK = int(os.environ["RANK"])
8 | LOCAL_RANK = int(os.environ["LOCAL_RANK"])
9 | NB_NODE = int(os.environ["NB_NODE"])
10 | LOCAL_SIZE = WORLD_SIZE // NB_NODE
11 |
12 | print("w", WORLD_SIZE)
13 | print("g", GLOBAL_RANK)
14 | print("l", LOCAL_RANK)
15 | print("n", NB_NODE)
16 |
17 |
18 | def on_worker():
19 | return LOCAL_RANK != 0
20 |
21 |
22 | def on_host():
23 | return LOCAL_RANK == 0
24 |
25 |
26 | if __name__ == "__main__":
27 | nb_gpu = torch.cuda.device_count()
28 | print("Device Count", nb_gpu)
29 |
30 | dist.init_process_group(
31 | backend="nccl", world_size=WORLD_SIZE, rank=LOCAL_RANK
32 | )
33 |
34 | print("LOCAL_RANK", LOCAL_RANK, "initialized.")
35 | t = torch.tensor([1.0, 2.0, 3.0]).to(f"cuda:{LOCAL_RANK}")
36 |
37 | # tags to identify tensors
38 | # loop thru workers
39 | dist.barrier()
40 | handle = dist.all_reduce(t, async_op=True)
41 | handle.wait()
42 |
43 | print(t)
44 |
--------------------------------------------------------------------------------
/tests/distrib/control_flow_zmq.py:
--------------------------------------------------------------------------------
1 | import zmq
2 | import os
3 | import time
4 | from collections import deque
5 |
6 | WORLD_SIZE = int(os.environ["WORLD_SIZE"])
7 | GLOBAL_RANK = int(os.environ["RANK"])
8 | LOCAL_RANK = int(os.environ["LOCAL_RANK"])
9 | NB_NODE = int(os.environ["NB_NODE"])
10 | LOCAL_SIZE = WORLD_SIZE // NB_NODE
11 |
12 | if __name__ == "__main__":
13 | if LOCAL_RANK == 0:
14 | context = zmq.Context()
15 | h_to_w = context.socket(zmq.PUB)
16 | w_to_h = context.socket(zmq.PULL)
17 |
18 | h_to_w.bind("tcp://*:5556")
19 | w_to_h.bind("tcp://*:5557")
20 |
21 | step_count = 0
22 | nb_batch = 2
23 | # while step_count < 100:
24 | # q, q_lookup = deque(), set()
25 | # while len(q) < nb_batch:
26 | # for i, hand in enumerate(handles):
27 | # if i not in q_lookup:
28 |
29 | else:
30 | context = zmq.Context()
31 | h_to_w = context.socket(zmq.SUB)
32 | w_to_h = context.socket(zmq.PUSH)
33 |
34 | h_to_w.connect("tcp://localhost:5556")
35 | w_to_h.connect("tcp://localhost:5557")
36 |
37 | done = False
38 |
39 | while not done:
40 |
41 | print("worker received")
42 |
43 | time.sleep(1)
44 |
45 | # Host event loop
46 | # check for rollouts
47 | # batch rollouts
48 | # tell q workers to do another
49 | # learn on batch
50 | # send new model
51 |
52 | # Worker event loop
53 | # step the actor, write to exp
54 | # if new model, receive new model params
55 | # if exp ready, notify host
56 | # wait for host to be ready for another
57 |
58 | # Commands:
59 | # CALC_EXPS
60 | # GET_ROLLOUT_i
61 | # GET_
62 |
--------------------------------------------------------------------------------
/tests/distrib/exp_sync_broadcast.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import chain
3 |
4 | import torch
5 | import torch.distributed as dist
6 |
7 | WORLD_SIZE = int(os.environ["WORLD_SIZE"])
8 | GLOBAL_RANK = int(os.environ["RANK"])
9 | LOCAL_RANK = int(os.environ["LOCAL_RANK"])
10 | NB_NODE = int(os.environ["NB_NODE"])
11 | LOCAL_SIZE = WORLD_SIZE // NB_NODE
12 |
13 | print("w", WORLD_SIZE)
14 | print("g", GLOBAL_RANK)
15 | print("l", LOCAL_RANK)
16 | print("n", NB_NODE)
17 |
18 |
19 | cache_spec = {"xs": (2, 6, 3, 3), "ys": (2, 6, 12), "rewards": (2, 6, 16)}
20 |
21 |
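# Map a local rank to a GPU: the host (local rank 0) always uses GPU 0, and
# workers are spread round-robin over the remaining GPUs (sharing GPU 0 only
# on a single-GPU machine).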
22 | def gpu_id(local_rank, device_count):
23 | if local_rank == 0:
24 | return 0
25 | elif device_count == 1:
26 | return 0
27 | else:
28 | return (local_rank % (device_count - 1)) + 1
29 |
30 |
31 | class WorkerCache(dict):
32 | def __init__(self, cache_spec, gpu_id):
33 | super(WorkerCache, self).__init__()
34 | self.sorted_keys = sorted(cache_spec.keys())
35 | self.gpu_id = gpu_id
36 |
37 | for k in self.sorted_keys:
38 | self[k] = self._init_rollout(cache_spec, k)
39 |
40 | def _init_rollout(self, spec, key):
41 | return [
42 | torch.ones(*spec[key][1:]).to(f"cuda:{self.gpu_id}")
43 | for _ in range(spec[key][0])
44 | ]
45 |
46 | def sync(self, src, grp):
47 | handles = []
48 | for k in self.sorted_keys:
49 | for t in self[k]:
50 | handles.append(
51 | dist.broadcast(t, src=src, group=grp, async_op=True)
52 | )
53 | return handles
54 |
55 | def iter_tensors(self):
56 | return chain(*[self[k] for k in self.sorted_keys])
57 |
58 |
59 | class HostCache(WorkerCache):
60 | def _init_rollout(self, spec, key):
61 | return [
62 | torch.zeros(*spec[key][1:]).to(f"cuda:{self.gpu_id}")
63 | for _ in range(spec[key][0])
64 | ]
65 |
66 |
67 | def on_worker():
68 | return LOCAL_RANK != 0
69 |
70 |
71 | def on_host():
72 | return LOCAL_RANK == 0
73 |
74 |
75 | if __name__ == "__main__":
76 | nb_gpu = torch.cuda.device_count()
77 | print("Device Count", nb_gpu)
78 |
79 | dist.init_process_group(
80 | backend="nccl", world_size=WORLD_SIZE, rank=LOCAL_RANK
81 | )
82 |
83 | groups = []
84 | for i in range(1, LOCAL_SIZE):
85 | grp = [0, i]
86 | groups.append(dist.new_group(grp))
87 |
88 | print("LOCAL_RANK", LOCAL_RANK, "initialized.")
89 | if on_worker():
90 | cache = WorkerCache(cache_spec, gpu_id(LOCAL_RANK, nb_gpu))
91 | [t.fill_(LOCAL_RANK) for t in cache.iter_tensors()]
92 | else:
93 | caches = [
94 | HostCache(cache_spec, gpu_id(LOCAL_RANK, nb_gpu))
95 | for _ in range(LOCAL_SIZE - 1)
96 | ]
97 |
98 | # tags to identify tensors
99 | # loop thru workers
100 |
101 | if on_worker():
102 | handle = cache.sync(LOCAL_RANK, groups[LOCAL_RANK - 1])
103 | else:
104 | handles = []
105 | for i, cache in enumerate(caches):
106 | handles.append(cache.sync(i + 1, groups[i]))
107 |
108 | if on_worker():
109 | [h.wait() for h in handle]
110 | else:
111 | for handle in handles:
112 | [h.wait() for h in handle]
113 |
114 | if on_host():
115 | for cache in caches:
116 | for t in cache.iter_tensors():
117 | print(t)
118 |
--------------------------------------------------------------------------------
/tests/distrib/hello_ray.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import namedtuple
3 |
4 | import ray
5 |
6 |
7 | @ray.remote(num_cpus=2)
8 | class Worker:
9 | State = namedtuple("State", ["asdf"])
10 |
11 | def __init__(self):
12 | pass
13 |
14 | def sleep(self, t):
15 | time.sleep(t)
16 | print(f"slept for {t}")
17 |
18 | def sleep5(self):
19 | time.sleep(5)
20 |
21 | def sleep10(self):
22 | time.sleep(10)
23 |
24 |
25 | def main():
26 | ray.init(num_cpus=4)
27 | remote_worker = Worker.remote()
28 |
29 | t_zero = time.time()
30 |
31 | f5 = remote_worker.sleep5.remote()
32 | f10 = remote_worker.sleep10.remote()
33 |
34 | ray.wait([f5, f10], num_returns=2)
35 | print("delta", time.time() - t_zero)
36 |
37 |
38 | def main_async():
39 | import asyncio
40 | from ray.experimental import async_api
41 |
42 | ray.init(num_cpus=4)
43 | remote_worker = Worker.remote()
44 | loop = asyncio.get_event_loop()
45 |
46 | t_zero = time.time()
47 |
48 | tasks = [
49 | async_api.as_future(remote_worker.sleep.remote(i)) for i in range(1, 3)
50 | ]
51 | loop.run_until_complete(asyncio.gather(*tasks))
52 |
53 | print("delta", time.time() - t_zero)
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/tests/distrib/launch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import subprocess
4 |
5 | NB_NODE = 1
6 | NODE_RANK = 0
7 | MASTER_ADDR = "127.0.0.1"
8 | MASTER_PORT = "29500"
9 |
10 |
11 | def main(args):
12 | processes = []
13 | for local_rank in range(args.nb_proc):
14 | # each process's rank
15 | dist_rank = args.nb_proc * NODE_RANK + local_rank
16 | current_env["RANK"] = str(dist_rank)
17 | current_env["LOCAL_RANK"] = str(local_rank)
18 |
19 | cmd = [sys.executable, "-u", args.script]
20 |
21 | proc = subprocess.Popen(cmd, env=current_env)
22 | processes.append(proc)
23 |
24 | for process in processes:
25 | process.wait()
26 |
27 |
28 | if __name__ == "__main__":
29 | from argparse import ArgumentParser
30 |
31 | parser = ArgumentParser()
32 | parser.add_argument("--script", default="exp_sync_broadcast.py")
33 | parser.add_argument("--nb-proc", type=int, default=2)
34 | args = parser.parse_args()
35 |
36 | dist_world_size = args.nb_proc * NB_NODE
37 |
38 | current_env = os.environ.copy()
39 | current_env["MASTER_ADDR"] = MASTER_ADDR
40 | current_env["MASTER_PORT"] = MASTER_PORT
41 | current_env["WORLD_SIZE"] = str(dist_world_size)
42 | current_env["NB_NODE"] = str(NB_NODE)
43 |
44 | main(args)
45 |
--------------------------------------------------------------------------------
/tests/distrib/multi_group.py:
--------------------------------------------------------------------------------
1 | import os
2 | from threading import Thread
3 |
4 | import torch
5 | import torch.distributed as dist
6 |
7 | WORLD_SIZE = int(os.environ["WORLD_SIZE"])
8 | GLOBAL_RANK = int(os.environ["RANK"])
9 | LOCAL_RANK = int(os.environ["LOCAL_RANK"])
10 | NB_NODE = int(os.environ["NB_NODE"])
11 | LOCAL_SIZE = WORLD_SIZE // NB_NODE
12 |
13 | print("w", WORLD_SIZE)
14 | print("g", GLOBAL_RANK)
15 | print("l", LOCAL_RANK)
16 | print("n", NB_NODE)
17 |
18 |
19 | def on_worker():
20 | return LOCAL_RANK != 0
21 |
22 |
23 | def on_host():
24 | return LOCAL_RANK == 0
25 |
26 |
27 | if __name__ == "__main__":
28 | nb_gpu = torch.cuda.device_count()
29 | print("Device Count", nb_gpu)
30 |
31 | dist.init_process_group(
32 | backend="nccl", world_size=WORLD_SIZE, rank=LOCAL_RANK
33 | )
34 | wh_group = dist.new_group([0, 1])
35 | hw_group = dist.new_group([0, 1])
36 |
37 | print("LOCAL_RANK", LOCAL_RANK, "initialized.")
38 |
39 | def wh_loop():
40 | count = 0
41 | while count < 10:
42 | dist.barrier(wh_group)
43 | dist.broadcast(t_a, 1, wh_group)
44 | count += 1
45 | print(f"wh {count}")
46 | return True
47 |
48 | def hw_loop():
49 | count = 0
50 | while count < 5:
51 | dist.barrier(hw_group)
52 | dist.broadcast(t_b, 0, hw_group)
53 | count += 1
54 | print(f"hw {count}")
55 | return True
56 |
57 | if on_host():
58 | t_a = torch.tensor([1, 2, 3]).to("cuda:0")
59 | t_b = torch.tensor([1, 2, 3]).to("cuda:0")
60 |
61 | if on_worker():
62 | t_a = torch.tensor([4, 5, 6]).to("cuda:0")
63 | t_b = torch.tensor([4, 5, 6]).to("cuda:0")
64 |
65 | thread_wh = Thread(target=wh_loop)
66 | thread_hw = Thread(target=hw_loop)
67 |
68 | thread_wh.start()
69 | thread_hw.start()
70 |
71 | thread_wh.join()
72 | thread_hw.join()
73 |
74 | print("t_a", t_a, "should be [4, 5, 6]")
75 | print("t_b", t_b, "should be [1, 2, 3]")
76 |
--------------------------------------------------------------------------------
/tests/distrib/nccl_typecheck.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import chain
3 |
4 | import torch
5 | import torch.distributed as dist
6 |
7 | WORLD_SIZE = int(os.environ["WORLD_SIZE"])
8 | GLOBAL_RANK = int(os.environ["RANK"])
9 | LOCAL_RANK = int(os.environ["LOCAL_RANK"])
10 | NB_NODE = int(os.environ["NB_NODE"])
11 | LOCAL_SIZE = WORLD_SIZE // NB_NODE
12 |
13 | print("w", WORLD_SIZE)
14 | print("g", GLOBAL_RANK)
15 | print("l", LOCAL_RANK)
16 | print("n", NB_NODE)
17 |
18 |
19 | def on_worker():
20 | return LOCAL_RANK != 0
21 |
22 |
23 | def on_host():
24 | return LOCAL_RANK == 0
25 |
26 |
27 | if __name__ == "__main__":
28 | nb_gpu = torch.cuda.device_count()
29 | print("Device Count", nb_gpu)
30 |
31 | dist.init_process_group(
32 | backend="nccl", world_size=WORLD_SIZE, rank=LOCAL_RANK
33 | )
34 |
35 | print("LOCAL_RANK", LOCAL_RANK, "initialized.")
36 | if on_host():
37 | t = torch.tensor([1, 2, 3]).to("cuda:0")
38 | if on_worker():
39 | t = torch.tensor([1.0, 2.0, 3.0]).to("cuda:0")
40 |
41 | # tags to identify tensors
42 | # loop thru workers
43 | dist.barrier()
44 | handle = dist.broadcast(t, 0, async_op=True)
45 | handle.wait()
46 |
47 | print(t.long())
48 | print(t.type())
49 |
--------------------------------------------------------------------------------
/tests/exp/test_rollout.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import torch
4 | from adept.exp import Rollout, ExpSpecBuilder
5 |
6 | obs_space = {"obs_a": (2, 2), "obs_b": (3, 3)}
7 | act_space = {"act_a": (5,), "act_b": (6,)}
8 | internal_space = {"internal_a": (2,), "internal_b": (3,)}
9 | obs_keys = ["obs_a", "obs_b"]
10 | act_keys = ["act_a", "act_b"]
11 | internal_keys = ["internal_a", "internal_b"]
12 | batch_size = 16
13 | exp_len = 20
14 |
15 |
16 | def build_fn(exp_len):
17 | return {
18 | "obs_a": (exp_len + 1, batch_size, 2, 2),
19 | "obs_b": (exp_len + 1, batch_size, 3, 3),
20 | "act_a": (exp_len, batch_size),
21 | "act_b": (exp_len, batch_size),
22 | "internal_a": (exp_len, batch_size, 2),
23 | "internal_b": (exp_len, batch_size, 3),
24 | "rewards": (exp_len, batch_size),
25 | "terminals": (exp_len, batch_size),
26 | }
27 |
28 |
29 | spec_builder = ExpSpecBuilder(
30 | obs_keys=obs_space,
31 | act_keys=act_space,
32 | internal_keys=internal_space,
33 | key_types={
34 | "obs_a": "long",
35 | "obs_b": "long",
36 | "act_a": "long",
37 | "act_b": "long",
38 | "internal_a": "float",
39 | "internal_b": "float",
40 | "rewards": "float",
41 | "terminals": "float",
42 | },
43 | exp_keys=obs_keys + act_keys + internal_keys + ["rewards", "terminals"],
44 | build_fn=build_fn,
45 | )
46 |
47 |
48 | class TestRollout(unittest.TestCase):
49 | def test_next_obs(self):
50 | r = Rollout(spec_builder, 20)
51 | next_obs = {
52 | "obs_a": torch.ones(batch_size, 2, 2),
53 | "obs_b": torch.ones(batch_size, 3, 3),
54 | }
55 | r.write_next_obs(next_obs)
56 | next_obs = r.read().next_observation
57 | print(next_obs["obs_a"].shape)
58 | print(next_obs["obs_b"].shape)
59 | # print(next_obs)
60 | self.assertEqual(next_obs["obs_a"][0][0][0].item(), 1)
61 | self.assertEqual(next_obs["obs_b"][0][0][0].item(), 1)
62 |
--------------------------------------------------------------------------------
/tests/learner/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/tests/learner/__init__.py
--------------------------------------------------------------------------------
/tests/learner/nstep.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | rollout_len = 20
4 |
5 | rewards = torch.zeros(rollout_len)
6 | rewards[-1] = 1.0
7 | terminals = torch.zeros(rollout_len)
8 | terminals[-1] = 1.0
9 | terminal_masks = 1.0 - terminals
10 | bootstrap_value = 0.0
11 |
12 | target = bootstrap_value
13 | nsteps = []
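# Walk the rollout backwards, accumulating discounted (gamma = 0.99) n-step
# returns; terminal_masks zeroes the carried value across episode boundaries.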
14 | for i in reversed(range(rollout_len)):
15 | target = rewards[i] + 0.99 * target * terminal_masks[i]
16 | nsteps.append(target)
17 |
18 | print(nsteps)
19 |
--------------------------------------------------------------------------------
/tests/network/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/tests/network/__init__.py
--------------------------------------------------------------------------------
/tests/network/test_modular_network.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from adept.network.modular_network import ModularNetwork
4 | from adept.network.net1d.identity_1d import Identity1D
5 | from adept.network.net2d.identity_2d import Identity2D
6 | from adept.network.net3d.identity_3d import Identity3D
7 | from adept.network.net4d.identity_4d import Identity4D
8 |
9 |
10 | def dummy_gpu_preprocessor(obs):
11 | return obs
12 |
13 |
14 | class TestModularNetwork(unittest.TestCase):
15 | # Example of valid structure
16 | source_nets = {
17 | "source_1d": Identity1D((16,), "source_1d"),
18 | "source_2d": Identity2D((16, 8 * 8), "source_2d"),
19 | "source_3d": Identity3D((16, 8, 8), "source_3d"),
20 | "source_4d": Identity4D((16, 8, 8, 8), "source_4d"),
21 | }
22 | body = Identity3D((176, 8, 8), "body")
23 | heads = {
24 | "1": Identity1D((11264,), "head1d"),
25 | "2": Identity2D((176, 64), "head2d"),
26 | "3": Identity3D((176, 8, 8), "head3d"),
27 | }
28 | output_space = {
29 | "output_1d": (16,),
30 | "output_2d": (16, 8 * 8),
31 | "output_3d": (16, 8, 8),
32 | }
33 |
34 | def test_heads_not_higher_dim_than_body(self):
35 | stub_1d = Identity1D((32,), "stub_1d")
36 | stub_2d = Identity2D((32, 32), "stub_2d")
37 |
38 | source_nets = {"source": stub_1d}
39 | body = stub_1d
40 | heads = {"2": stub_2d}
41 | output_space = {"output": (32, 32)}
42 |
43 | with self.assertRaises(AssertionError):
44 | ModularNetwork(
45 | source_nets, body, heads, output_space, dummy_gpu_preprocessor
46 | )
47 |
48 | def test_source_nets_match_body(self):
49 | stub_32 = Identity2D((32, 32), "stub_32")
50 | stub_64 = Identity2D((32, 64), "stub_64")
51 |
52 | source_nets = {"source": stub_32}
53 | body = stub_64 # should error
54 | heads = {"2": stub_64}
55 | output_space = {"output": (32, 64)}
56 |
57 | with self.assertRaises(AssertionError):
58 | ModularNetwork(
59 | source_nets, body, heads, output_space, dummy_gpu_preprocessor
60 | )
61 |
62 | def test_body_matches_heads(self):
63 | stub_32 = Identity2D((32, 32), "stub_32")
64 | stub_64 = Identity2D((32, 64), "stub_64")
65 |
66 | source_nets = {"source": stub_32}
67 | body = stub_32
68 | heads = {"2": stub_64} # should error
69 | output_space = {"output": (32, 64)}
70 |
71 | with self.assertRaises(AssertionError):
72 | ModularNetwork(
73 | source_nets, body, heads, output_space, dummy_gpu_preprocessor
74 | )
75 |
76 | def test_output_has_a_head(self):
77 | stub_2d = Identity2D((32, 32), "stub_2d")
78 |
79 | source_nets = {"source": stub_2d}
80 | body = stub_2d
81 | heads = {"2": stub_2d}
82 | output_space = {"output": (32, 32, 32)} # should error
83 | with self.assertRaises(AssertionError):
84 | ModularNetwork(
85 | source_nets, body, heads, output_space, dummy_gpu_preprocessor
86 | )
87 |
88 | def test_heads_match_out_shapes(self):
89 | stub_2d = Identity2D((32, 32), "stub_2d")
90 |
91 | source_nets = {"source": stub_2d}
92 | body = stub_2d
93 | heads = {"2": stub_2d}
94 | output_space = {"output": (32, 64)} # should error
95 | with self.assertRaises(AssertionError):
96 | ModularNetwork(
97 | source_nets, body, heads, output_space, dummy_gpu_preprocessor
98 | )
99 |
100 | def test_valid_structure(self):
101 | try:
102 | ModularNetwork(
103 | self.source_nets,
104 | self.body,
105 | self.heads,
106 | self.output_space,
107 | dummy_gpu_preprocessor,
108 | )
109 | except Exception:
110 | self.fail("Unexpected exception")
111 |
112 | def test_forward(self):
113 | import torch
114 |
115 | BATCH = 32
116 | obs = {
117 | "source_1d": torch.zeros((BATCH, 16,)),
118 | "source_2d": torch.zeros((BATCH, 16, 8 * 8)),
119 | "source_3d": torch.zeros((BATCH, 16, 8, 8)),
120 | "source_4d": torch.zeros((BATCH, 16, 8, 8, 8)),
121 | }
122 | try:
123 | net = ModularNetwork(
124 | self.source_nets,
125 | self.body,
126 | self.heads,
127 | self.output_space,
128 | dummy_gpu_preprocessor,
129 | )
130 | outputs, _ = net.forward(obs, {})
131 | except Exception:
132 | self.fail("Unexpected exception")
133 |
134 |
135 | if __name__ == "__main__":
136 | unittest.main(verbosity=1)
137 |
--------------------------------------------------------------------------------
/tests/registry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/tests/registry/__init__.py
--------------------------------------------------------------------------------
/tests/registry/test_registry.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import torch
3 |
4 | from adept.agent import AgentModule
5 | from adept.env import EnvModule
6 | from adept.preprocess.base.preprocessor import ObsPreprocessor
7 | from adept.preprocess import CPUPreprocessor, GPUPreprocessor
8 | from adept.registry import Registry
9 |
10 |
11 | class NotSubclass:
12 | pass
13 |
14 |
15 | class ArgsNotImplemented(AgentModule):
16 | pass
17 |
18 |
19 | class DummyEnv(EnvModule):
20 |
21 | args = {}
22 | ids = ["dummy"]
23 |
24 | def __init__(self):
25 | obs_space = {"screen": (3, 84, 84)}
26 | action_space = {"action": (8,)}
27 | cpu_preprocessor = CPUPreprocessor([], obs_space)
28 | gpu_preprocessor = GPUPreprocessor(
29 | [], cpu_preprocessor.observation_space
30 | )
31 | super(DummyEnv, self).__init__(
32 | action_space, cpu_preprocessor, gpu_preprocessor
33 | )
34 |
35 | @classmethod
36 | def from_args(cls, args, seed, **kwargs):
37 | return cls()
38 |
39 | def step(self, action):
40 | return {"screen": torch.rand((3, 84, 84))}, 1, False, {}
41 |
42 | def reset(self, **kwargs):
43 | return torch.rand((3, 84, 84))
44 |
45 | def close(self):
46 | return None
47 |
48 |
49 | class TestRegistry(unittest.TestCase):
50 | def test_register_invalid_class(self):
51 | registry = Registry()
52 | with self.assertRaises(AssertionError):
53 | registry.register_agent(NotSubclass)
54 |
55 | def test_register_args_not_implemented(self):
56 | registry = Registry()
57 | with self.assertRaises(NotImplementedError):
58 | registry.register_agent(ArgsNotImplemented)
59 |
60 | def test_save_classes(self):
61 | dummy_log_id_dir = "/tmp/adept_test/test_save_classes"
62 | registry = Registry()
63 | registry.register_env(DummyEnv)
64 | registry.save_extern_classes(dummy_log_id_dir)
65 |
66 | other = Registry()
67 | other.load_extern_classes(dummy_log_id_dir)
68 | env_cls = other.lookup_env("dummy")
69 | env = env_cls()
70 | env.reset()
71 |
72 |
73 | if __name__ == "__main__":
74 | unittest.main(verbosity=1)
75 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/heronsystems/adeptRL/d8554d134c1dfee6659baafd886684351c1dd982/tests/utils/__init__.py
--------------------------------------------------------------------------------
/tests/utils/test_requires_args.py:
--------------------------------------------------------------------------------
1 | import io
2 | import sys
3 | import unittest
4 |
5 | from adept.utils.requires_args import RequiresArgsMixin
6 |
7 |
8 | class Stub(RequiresArgsMixin):
9 | args = {"k1": 0, "k2": False, "k3": 1.5, "k4": "hello"}
10 |
11 |
12 | class TestRequiresArgs(unittest.TestCase):
13 | stub = Stub()
14 |
15 | def test_prompt_no_changes(self):
16 | sys.stdin = io.StringIO("\n")
17 | new_conf = self.stub.prompt()
18 | assert new_conf == self.stub.args
19 |
20 | def test_prompt_modify(self):
21 | sys.stdin = io.StringIO('{"k1": 5}')
22 | new_conf = self.stub.prompt()
23 | assert new_conf["k1"] == 5
24 | assert new_conf["k2"] == self.stub.args["k2"]
25 | assert new_conf["k3"] == self.stub.args["k3"]
26 | assert new_conf["k4"] == self.stub.args["k4"]
27 |
28 |
29 | if __name__ == "__main__":
30 | unittest.main(verbosity=1)
31 |
--------------------------------------------------------------------------------
/tests/utils/test_util.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from adept.utils.util import listd_to_dlist, dlist_to_listd
4 |
5 |
6 | class TestUtil(unittest.TestCase):
7 | def test_dlist_to_listd(self):
8 | assert dlist_to_listd({"a": [1]}) == [{"a": 1}]
9 |
10 | def test_listd_to_dlist(self):
11 | assert listd_to_dlist([{"a": 1}]) == {"a": [1]}
12 |
13 |
14 | if __name__ == "__main__":
15 | unittest.main(verbosity=2)
16 |
--------------------------------------------------------------------------------