├── .DS_Store
├── README.md
├── mcs
├── .DS_Store
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-37.pyc
│ ├── __init__.cpython-38.pyc
│ ├── globals.cpython-37.pyc
│ └── globals.cpython-38.pyc
├── actor
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── ac_eval.cpython-37.pyc
│ │ ├── ac_eval.cpython-38.pyc
│ │ ├── ac_rollout.cpython-37.pyc
│ │ ├── ac_rollout.cpython-38.pyc
│ │ ├── impala.cpython-37.pyc
│ │ └── impala.cpython-38.pyc
│ ├── ac_eval.py
│ ├── ac_rollout.py
│ ├── base
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── ac_helper.cpython-37.pyc
│ │ │ ├── ac_helper.cpython-38.pyc
│ │ │ ├── actor_module.cpython-37.pyc
│ │ │ └── actor_module.cpython-38.pyc
│ │ ├── ac_helper.py
│ │ └── actor_module.py
│ └── impala.py
├── agent
│ ├── .DS_Store
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── actor_critic.cpython-37.pyc
│ │ └── actor_critic.cpython-38.pyc
│ ├── actor_critic.py
│ └── base
│ │ ├── .DS_Store
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── agent_module.cpython-37.pyc
│ │ └── agent_module.cpython-38.pyc
│ │ └── agent_module.py
├── config
│ └── default.json
├── container
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── distrib.cpython-37.pyc
│ │ ├── distrib.cpython-38.pyc
│ │ ├── evaluation.cpython-37.pyc
│ │ ├── evaluation.cpython-38.pyc
│ │ ├── init.cpython-37.pyc
│ │ ├── init.cpython-38.pyc
│ │ ├── local.cpython-37.pyc
│ │ └── local.cpython-38.pyc
│ ├── actorlearner
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── learner_container.cpython-37.pyc
│ │ │ ├── learner_container.cpython-38.pyc
│ │ │ ├── rollout_queuer.cpython-37.pyc
│ │ │ ├── rollout_queuer.cpython-38.pyc
│ │ │ ├── rollout_worker.cpython-37.pyc
│ │ │ └── rollout_worker.cpython-38.pyc
│ │ ├── learner_container.py
│ │ ├── rollout_queuer.py
│ │ └── rollout_worker.py
│ ├── base
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── container.cpython-37.pyc
│ │ │ ├── container.cpython-38.pyc
│ │ │ ├── nccl_optimizer.cpython-37.pyc
│ │ │ ├── nccl_optimizer.cpython-38.pyc
│ │ │ ├── updater.cpython-37.pyc
│ │ │ └── updater.cpython-38.pyc
│ │ ├── container.py
│ │ ├── nccl_optimizer.py
│ │ └── updater.py
│ ├── distrib.py
│ ├── evaluation.py
│ ├── evaluation_thread.py
│ ├── init.py
│ ├── local.py
│ └── render.py
├── env
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── _gym_wrappers.cpython-37.pyc
│ │ ├── _gym_wrappers.cpython-38.pyc
│ │ ├── _spaces.cpython-37.pyc
│ │ ├── _spaces.cpython-38.pyc
│ │ ├── openai_gym.cpython-37.pyc
│ │ └── openai_gym.cpython-38.pyc
│ ├── _gym_wrappers.py
│ ├── _spaces.py
│ ├── base
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── _env.cpython-37.pyc
│ │ │ ├── _env.cpython-38.pyc
│ │ │ ├── env_module.cpython-37.pyc
│ │ │ └── env_module.cpython-38.pyc
│ │ ├── _env.py
│ │ └── env_module.py
│ └── openai_gym.py
├── evaluate.py
├── exp
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── repetitive_buffer.cpython-37.pyc
│ │ ├── repetitive_buffer.cpython-38.pyc
│ │ ├── replay.cpython-37.pyc
│ │ ├── replay.cpython-38.pyc
│ │ ├── rollout.cpython-37.pyc
│ │ └── rollout.cpython-38.pyc
│ ├── base
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── exp_module.cpython-37.pyc
│ │ │ ├── exp_module.cpython-38.pyc
│ │ │ ├── spec_builder.cpython-37.pyc
│ │ │ └── spec_builder.cpython-38.pyc
│ │ ├── exp_module.py
│ │ └── spec_builder.py
│ ├── repetitive_buffer.py
│ ├── replay.py
│ └── rollout.py
├── globals.py
├── learner
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── ac_rollout.cpython-37.pyc
│ │ ├── ac_rollout.cpython-38.pyc
│ │ ├── impala.cpython-37.pyc
│ │ └── impala.cpython-38.pyc
│ ├── ac_rollout.py
│ ├── base
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── dm_return_scale.cpython-37.pyc
│ │ │ ├── dm_return_scale.cpython-38.pyc
│ │ │ ├── learner_module.cpython-37.pyc
│ │ │ └── learner_module.cpython-38.pyc
│ │ ├── dm_return_scale.py
│ │ └── learner_module.py
│ └── impala.py
├── logs
│ └── .DS_Store
├── manager
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── simple_env_manager.cpython-37.pyc
│ │ ├── simple_env_manager.cpython-38.pyc
│ │ ├── subproc_env_manager.cpython-37.pyc
│ │ └── subproc_env_manager.cpython-38.pyc
│ ├── base
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── manager_module.cpython-37.pyc
│ │ │ └── manager_module.cpython-38.pyc
│ │ └── manager_module.py
│ ├── simple_env_manager.py
│ └── subproc_env_manager.py
├── modules
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── attention.cpython-37.pyc
│ │ ├── attention.cpython-38.pyc
│ │ ├── norm.cpython-37.pyc
│ │ ├── norm.cpython-38.pyc
│ │ ├── sequence.cpython-37.pyc
│ │ ├── sequence.cpython-38.pyc
│ │ ├── spatial.cpython-37.pyc
│ │ └── spatial.cpython-38.pyc
│ ├── attention.py
│ ├── memory.py
│ ├── mlp.py
│ ├── norm.py
│ ├── sequence.py
│ └── spatial.py
├── network
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── modular_network.cpython-37.pyc
│ │ └── modular_network.cpython-38.pyc
│ ├── base
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── base.cpython-37.pyc
│ │ │ ├── base.cpython-38.pyc
│ │ │ ├── network_module.cpython-37.pyc
│ │ │ ├── network_module.cpython-38.pyc
│ │ │ ├── submodule.cpython-37.pyc
│ │ │ └── submodule.cpython-38.pyc
│ │ ├── base.py
│ │ ├── network_module.py
│ │ └── submodule.py
│ ├── modular_network.py
│ ├── net1d
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── identity_1d.cpython-37.pyc
│ │ │ ├── identity_1d.cpython-38.pyc
│ │ │ ├── linear.cpython-37.pyc
│ │ │ ├── linear.cpython-38.pyc
│ │ │ ├── lstm.cpython-37.pyc
│ │ │ ├── lstm.cpython-38.pyc
│ │ │ ├── submodule_1d.cpython-37.pyc
│ │ │ └── submodule_1d.cpython-38.pyc
│ │ ├── identity_1d.py
│ │ ├── linear.py
│ │ ├── lstm.py
│ │ └── submodule_1d.py
│ ├── net2d
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── deconv.cpython-37.pyc
│ │ │ ├── deconv.cpython-38.pyc
│ │ │ ├── identity_2d.cpython-37.pyc
│ │ │ ├── identity_2d.cpython-38.pyc
│ │ │ ├── submodule_2d.cpython-37.pyc
│ │ │ └── submodule_2d.cpython-38.pyc
│ │ ├── deconv.py
│ │ ├── identity_2d.py
│ │ └── submodule_2d.py
│ ├── net3d
│ │ ├── RelationalMHDPA.py
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── RelationalMHDPA.cpython-38.pyc
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── deconv.cpython-38.pyc
│ │ │ ├── four_conv.cpython-37.pyc
│ │ │ ├── four_conv.cpython-38.pyc
│ │ │ ├── identity_3d.cpython-37.pyc
│ │ │ ├── identity_3d.cpython-38.pyc
│ │ │ ├── submodule_3d.cpython-37.pyc
│ │ │ └── submodule_3d.cpython-38.pyc
│ │ ├── _resnets.py
│ │ ├── four_conv.py
│ │ ├── identity_3d.py
│ │ ├── networks.py
│ │ ├── rmc.py
│ │ └── submodule_3d.py
│ └── net4d
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── identity_4d.cpython-37.pyc
│ │ ├── identity_4d.cpython-38.pyc
│ │ ├── submodule_4d.cpython-37.pyc
│ │ └── submodule_4d.cpython-38.pyc
│ │ ├── identity_4d.py
│ │ └── submodule_4d.py
├── preprocess
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── observation.cpython-37.pyc
│ │ ├── observation.cpython-38.pyc
│ │ ├── ops.cpython-37.pyc
│ │ └── ops.cpython-38.pyc
│ ├── observation.py
│ └── ops.py
├── registry
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── registry.cpython-37.pyc
│ │ └── registry.cpython-38.pyc
│ └── registry.py
├── rewardnorm
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-37.pyc
│ │ ├── __init__.cpython-38.pyc
│ │ ├── normalizers.cpython-37.pyc
│ │ └── normalizers.cpython-38.pyc
│ ├── base
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── __init__.cpython-38.pyc
│ │ │ ├── rewnorm_module.cpython-37.pyc
│ │ │ └── rewnorm_module.cpython-38.pyc
│ │ └── rewnorm_module.py
│ └── normalizers.py
├── train.py
└── utils
│ ├── __init__.py
│ ├── __pycache__
│ ├── __init__.cpython-37.pyc
│ ├── __init__.cpython-38.pyc
│ ├── logging.cpython-37.pyc
│ ├── logging.cpython-38.pyc
│ ├── requires_args.cpython-37.pyc
│ ├── requires_args.cpython-38.pyc
│ ├── script_helpers.cpython-37.pyc
│ ├── script_helpers.cpython-38.pyc
│ ├── util.cpython-37.pyc
│ └── util.cpython-38.pyc
│ ├── logging.py
│ ├── requires_args.py
│ ├── script_helpers.py
│ └── util.py
├── requirements.txt
└── setup.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/.DS_Store
--------------------------------------------------------------------------------
/mcs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/.DS_Store
--------------------------------------------------------------------------------
/mcs/__init__.py:
--------------------------------------------------------------------------------
def register_agent(agent_cls):
    """Register an agent class with the global component registry.

    :param agent_cls: the agent class to add.
    """
    # The registry is imported at call time rather than module import time
    # (presumably to avoid an import cycle with the ``mcs`` package).
    from mcs.registry import REGISTRY as registry

    registry.register_agent(agent_cls)
5 |
6 |
def register_actor(actor_cls):
    """Register an actor class with the global component registry.

    :param actor_cls: the actor class to add.
    """
    # Deferred import; see ``register_agent`` for the same pattern.
    from mcs.registry import REGISTRY as registry

    registry.register_actor(actor_cls)
11 |
12 |
def register_exp(exp_cls):
    """Register an experience-storage class with the global component registry.

    :param exp_cls: the experience class to add.
    """
    # Registry is resolved lazily, only when a registration happens.
    from mcs.registry import REGISTRY as registry

    registry.register_exp(exp_cls)
17 |
18 |
def register_learner(learner_cls):
    """Register a learner class with the global component registry.

    :param learner_cls: the learner class to add.
    """
    # Registry is resolved lazily, only when a registration happens.
    from mcs.registry import REGISTRY as registry

    registry.register_learner(learner_cls)
23 |
24 |
def register_env(env_cls):
    """Register an environment class with the global component registry.

    :param env_cls: the environment class to add.
    """
    # Registry is resolved lazily, only when a registration happens.
    from mcs.registry import REGISTRY as registry

    registry.register_env(env_cls)
29 |
30 |
def register_reward_norm(rwd_norm_cls):
    """Register a reward-normalizer class with the global component registry.

    Note the registry method is named ``register_reward_normalizer``.

    :param rwd_norm_cls: the reward normalizer class to add.
    """
    # Registry is resolved lazily, only when a registration happens.
    from mcs.registry import REGISTRY as registry

    registry.register_reward_normalizer(rwd_norm_cls)
35 |
36 |
def register_network(network_cls):
    """Register a network class with the global component registry.

    :param network_cls: the network class to add.
    """
    # Registry is resolved lazily, only when a registration happens.
    from mcs.registry import REGISTRY as registry

    registry.register_network(network_cls)
41 |
42 |
def register_submodule(submod_cls):
    """Register a network submodule class with the global component registry.

    :param submod_cls: the submodule class to add.
    """
    # Registry is resolved lazily, only when a registration happens.
    from mcs.registry import REGISTRY as registry

    registry.register_submodule(submod_cls)
47 |
48 |
def register_manager(manager_cls):
    """Register an environment-manager class with the global component registry.

    :param manager_cls: the manager class to add.
    """
    # Registry is resolved lazily, only when a registration happens.
    from mcs.registry import REGISTRY as registry

    registry.register_manager(manager_cls)
53 |
--------------------------------------------------------------------------------
/mcs/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/__pycache__/globals.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/__pycache__/globals.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/__pycache__/globals.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/__pycache__/globals.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/actor/__init__.py:
--------------------------------------------------------------------------------
1 | from .base.actor_module import ActorModule
2 | from .ac_rollout import ACRolloutActorTrain
3 | from mcs.actor.ac_eval import ACActorEval, ACActorEvalSample
4 | from .impala import ImpalaHostActor, ImpalaWorkerActor,ImpalaHostTargetActor
5 |
# Actor classes exported by ``mcs.actor``, kept in a fixed order.
# Presumably iterated when the registry is populated -- confirm with callers.
ACTOR_REG = [
    ACRolloutActorTrain,
    ACActorEval,
    ACActorEvalSample,
    ImpalaHostActor,
    ImpalaWorkerActor,
    ImpalaHostTargetActor,
]
14 |
--------------------------------------------------------------------------------
/mcs/actor/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/actor/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/actor/__pycache__/ac_eval.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/ac_eval.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/actor/__pycache__/ac_eval.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/ac_eval.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/actor/__pycache__/ac_rollout.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/ac_rollout.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/actor/__pycache__/ac_rollout.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/ac_rollout.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/actor/__pycache__/impala.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/impala.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/actor/__pycache__/impala.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/impala.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/actor/ac_eval.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from mcs.actor import ActorModule
4 | from mcs.actor.base.ac_helper import ACActorHelperMixin
5 |
6 |
class ACActorEval(ActorModule, ACActorHelperMixin):
    """
    Actor-critic actor used for evaluation.

    Produces one action per action key together with the critic's value
    estimate, and declares no experience to store (``_exp_spec`` is empty).
    """

    args = {}

    @classmethod
    def from_args(cls, action_space):
        """Build the actor from ``action_space`` alone; it has no arguments."""
        return cls(action_space)

    @staticmethod
    def output_space(action_space):
        """
        Heads the network must provide: a ``critic`` head of shape (1,) plus
        every entry of ``action_space`` (which wins on a key collision).
        """
        heads = {"critic": (1,)}
        heads.update(action_space)
        return heads

    def compute_action_exp(self, preds, internals, obs, available_actions):
        """
        Choose an action for every action key from the policy logits.

        NOTE(review): actions are *sampled* (``sample_action``), exactly as in
        ``ACActorEvalSample``; the greedy ``select_action`` helper is unused.
        Confirm that stochastic evaluation is intended for this class.

        :param preds: Dict[str, Tensor] of network outputs, incl. "critic"
        :return: (OrderedDict of CPU action tensors keyed by action key,
                  {"value": critic prediction with its last dim squeezed})
        """
        chosen = OrderedDict()
        for action_key in self.action_keys:
            probs = self.softmax(self.flatten_logits(preds[action_key]))
            chosen[action_key] = self.sample_action(probs).cpu()

        value = preds["critic"].squeeze(-1)
        return chosen, {"value": value}

    @classmethod
    def _exp_spec(
        cls, rollout_len, batch_sz, obs_space, act_space, internal_space
    ):
        """Evaluation records no experience, so the spec is empty."""
        return {}
36 |
37 |
class ACActorEvalSample(ACActorEval):
    """Evaluation actor that draws each action stochastically from the policy."""

    def compute_action_exp(self, preds, internals, obs, available_actions):
        """
        Sample one action per action key from softmax(flattened logits).

        :param preds: Dict[str, Tensor] of network outputs, incl. "critic"
        :return: (OrderedDict of CPU action tensors keyed by action key,
                  {"value": critic prediction with its last dim squeezed})
        """

        def _draw(key):
            # Flatten the head, turn it into probabilities, then sample.
            probs = self.softmax(self.flatten_logits(preds[key]))
            return self.sample_action(probs).cpu()

        sampled = OrderedDict((key, _draw(key)) for key in self.action_keys)
        return sampled, {"value": preds["critic"].squeeze(-1)}
50 |
--------------------------------------------------------------------------------
/mcs/actor/ac_rollout.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see https://www.gnu.org/licenses/.
15 | from collections import OrderedDict
16 |
17 | import torch
18 |
19 | from mcs.actor.base.ac_helper import ACActorHelperMixin
20 | from mcs.actor.base.actor_module import ActorModule
21 |
22 |
class ACRolloutActorTrain(ActorModule, ACActorHelperMixin):
    """
    Actor-critic actor used while collecting training rollouts.

    For each action key it samples an action and records that action's
    log-probability and the policy entropy; the critic's value estimate is
    recorded alongside them.
    """

    args = {}

    @classmethod
    def from_args(cls, action_space):
        """Build the actor from ``action_space`` alone; it has no arguments."""
        return cls(action_space)

    @staticmethod
    def output_space(action_space):
        """
        Heads the network must provide: a ``critic`` head of shape (1,) plus
        every entry of ``action_space`` (which wins on a key collision).
        """
        heads = {"critic": (1,)}
        heads.update(action_space)
        return heads

    def compute_action_exp(self, preds, internals, obs, available_actions):
        """
        Sample actions and gather the experience fields declared in
        ``_exp_spec``.

        :param preds: Dict[str, Tensor] of network outputs, incl. "critic"
        :return: (OrderedDict of CPU action tensors keyed by action key,
                  {"log_probs": (B, num_keys), "entropies": (B, num_keys),
                   "values": critic output with dim 1 squeezed})
        """
        values = preds["critic"].squeeze(1)

        actions = OrderedDict()
        per_key_log_probs, per_key_entropies = [], []
        for key in self.action_keys:
            flat = self.flatten_logits(preds[key])
            log_sm = self.log_softmax(flat)
            sm = self.softmax(flat)
            action = self.sample_action(sm)

            per_key_entropies.append(self.entropy(log_sm, sm))
            per_key_log_probs.append(self.log_probability(log_sm, action))
            actions[key] = action.cpu()

        # Each per-key tensor is (B, 1); concatenating on dim 1 gives one
        # column per action key, in sorted key order.
        experience = {
            "log_probs": torch.cat(per_key_log_probs, dim=1),
            "entropies": torch.cat(per_key_entropies, dim=1),
            "values": values,
        }
        return actions, experience

    @classmethod
    def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
        """Shapes of the experience tensors this actor emits per rollout."""
        n_keys = len(act_space.keys())
        per_key_shape = (exp_len, batch_sz, n_keys)
        return {
            "log_probs": per_key_shape,
            "entropies": per_key_shape,
            "values": (exp_len, batch_sz),
        }
72 |
--------------------------------------------------------------------------------
/mcs/actor/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor_module import ActorModule
2 |
--------------------------------------------------------------------------------
/mcs/actor/base/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/base/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/actor/base/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/base/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/actor/base/__pycache__/ac_helper.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/base/__pycache__/ac_helper.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/actor/base/__pycache__/ac_helper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/base/__pycache__/ac_helper.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/actor/base/__pycache__/actor_module.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/base/__pycache__/actor_module.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/actor/base/__pycache__/actor_module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/base/__pycache__/actor_module.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/actor/base/ac_helper.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | from torch.nn import functional as F
4 |
5 |
class ACActorHelperMixin(metaclass=abc.ABCMeta):
    """
    A helper class for actor critic actors.

    Stateless tensor utilities that turn policy-head logits into
    probabilities, sampled or greedy actions, log-probabilities and
    entropies. Every method treats dim 0 as the batch and dim 1 as the
    action choices (after ``flatten_logits``).
    """

    @staticmethod
    def flatten_logits(logit):
        """
        Flatten every non-batch dimension of ``logit`` into a single one.

        :param logit: Tensor of arbitrary dim; dim 0 is the batch
        :return: logits flattened to (N, X) when ``logit.dim() > 2``;
            1-D and 2-D inputs are returned unchanged (same object)
        """
        if logit.dim() <= 2:
            return logit
        # torch.flatten handles any rank and non-contiguous inputs (it copies
        # only when a view is impossible). Previously only ranks 3/4/5 were
        # flattened, so higher-rank logits reached the dim=1 softmax intact
        # and were normalized over the wrong axis.
        return torch.flatten(logit, start_dim=1)

    @staticmethod
    def softmax(logit):
        """
        :param logit: torch.Tensor (N, X)
        :return: torch.Tensor (N, X), probabilities summing to 1 along dim 1
        """
        return F.softmax(logit, dim=1)

    @staticmethod
    def log_softmax(logit):
        """
        :param logit: torch.Tensor (N, X)
        :return: torch.Tensor (N, X), log-probabilities along dim 1
        """
        return F.log_softmax(logit, dim=1)

    @staticmethod
    def log_probability(log_softmax, action):
        """
        Pick out the log-probability of each chosen action.

        :param log_softmax: Tensor (N, X)
        :param action: LongTensor (N)
        :return: Tensor (N, 1)
        """
        return log_softmax.gather(1, action.unsqueeze(1))

    @staticmethod
    def entropy(log_softmax, softmax):
        """
        Entropy of each row's distribution, -sum(p * log p).

        :param log_softmax: Tensor (N, X)
        :param softmax: Tensor (N, X)
        :return: Tensor (N, 1)
        """
        return -(log_softmax * softmax).sum(1, keepdim=True)

    @staticmethod
    def sample_action(softmax):
        """
        Samples an action from a softmax distribution.

        :param softmax: torch.Tensor (N, X)
        :return: torch.Tensor (N), one sampled index per row
        """
        return softmax.multinomial(1).squeeze(1)

    @staticmethod
    def select_action(softmax):
        """
        Selects the action with the highest probability.

        :param softmax: torch.Tensor (N, X)
        :return: LongTensor (N), index of the row-wise maximum
        """
        return torch.argmax(softmax, dim=1)
84 |
--------------------------------------------------------------------------------
/mcs/actor/base/actor_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see https://www.gnu.org/licenses/.
15 | """
16 | An actor observes the environment and decides actions. It also outputs extra
17 | info necessary for model updates (learning) to occur.
18 | """
19 | import abc
20 | from collections import defaultdict
21 |
22 | from mcs.exp.base.spec_builder import ExpSpecBuilder
23 | from mcs.utils.requires_args import RequiresArgsMixin
24 |
25 |
class ActorModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    Base class for actors.

    An actor turns network predictions into actions plus the extra
    experience a learner needs. Subclasses implement ``output_space``,
    ``from_args``, ``_exp_spec`` and ``compute_action_exp``.
    """

    def __init__(self, action_space):
        # Stored as given; ``action_keys`` iterates its keys.
        self._action_space = action_space

    @property
    def action_space(self):
        """The action space passed at construction."""
        return self._action_space

    @property
    def action_keys(self):
        """Action-space keys in sorted order, giving a deterministic iteration order."""
        return sorted(self.action_space.keys())

    @staticmethod
    @abc.abstractmethod
    def output_space(action_space):
        """Return the network output heads this actor requires."""
        raise NotImplementedError

    @classmethod
    def exp_spec_builder(cls, obs_space, act_space, internal_space, batch_sz):
        """
        Create an ``ExpSpecBuilder`` describing this actor's experience.

        The builder's ``build_fn(exp_len)`` merges the actor-specific spec
        from ``_exp_spec`` with the environment fields "rewards" and
        "terminals", each shaped (exp_len, batch_sz).
        """

        def build_fn(exp_len):
            exp_space = cls._exp_spec(
                exp_len, batch_sz, obs_space, act_space, internal_space
            )
            env_space = {
                "rewards": (exp_len, batch_sz),
                "terminals": (exp_len, batch_sz),
            }
            # Environment fields take precedence on a key collision.
            return {**exp_space, **env_space}

        key_types = cls._key_types(obs_space, act_space, internal_space)
        exp_keys = cls._exp_keys(obs_space, act_space, internal_space)
        return ExpSpecBuilder(
            obs_space, act_space, internal_space, key_types, exp_keys, build_fn
        )

    @classmethod
    @abc.abstractmethod
    def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
        """Return a dict mapping experience field name -> tensor shape."""
        raise NotImplementedError

    @classmethod
    def _exp_keys(cls, obs_space, act_space, internal_space):
        """Experience field names, read off a throwaway 1x1 spec."""
        dummy = cls._exp_spec(1, 1, obs_space, act_space, internal_space)
        return dummy.keys()

    @classmethod
    def _key_types(cls, obs_space, act_space, internal_space):
        """Type tag per experience key; every key defaults to "float"."""
        return defaultdict(lambda: "float")

    @classmethod
    @abc.abstractmethod
    def from_args(cls, args, action_space):
        """
        Alternate constructor.

        Declared as a classmethod (like ``_exp_spec``) because the
        implementations are classmethods invoked on the class; it was
        previously declared as an instance method taking ``self``.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def compute_action_exp(self, preds, internals, obs, available_actions):
        """
        B = Batch Size

        :param preds: Dict[str, torch.Tensor]
        :return:
            actions: Dict[ActionKey, Tensor (B)]
            experience: Dict[str, Tensor (B, X)]
        """
        raise NotImplementedError

    def act(self, network, obs, prev_internals):
        """
        Run ``network`` on ``obs`` and convert its predictions into actions.

        :param network: callable returning
            (predictions, internal_states, processed_obs)
        :param obs: Dict[str, Tensor]; an optional "available_actions" entry
            is forwarded to ``compute_action_exp``
        :param prev_internals: previous internal states. Dict[str, Tensor]
        :return:
            actions: Dict[ActionKey, Tensor (B)]
            experience: Dict[str, Tensor (B, X)]
            internal_states: Dict[str, Tensor]
        """
        predictions, internal_states, pobs = network(obs, prev_internals)

        av_actions = (
            obs["available_actions"] if "available_actions" in obs else None
        )

        # Note: the *processed* observations from the network are passed on,
        # while available actions come from the raw ``obs``.
        actions, exp = self.compute_action_exp(
            predictions, prev_internals, pobs, av_actions
        )
        return actions, exp, internal_states
113 |
--------------------------------------------------------------------------------
/mcs/agent/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/.DS_Store
--------------------------------------------------------------------------------
/mcs/agent/__init__.py:
--------------------------------------------------------------------------------
1 | from .base.agent_module import AgentModule
2 | from .actor_critic import ActorCritic
3 |
# Agent classes exported by ``mcs.agent``.
# Presumably iterated when the registry is populated -- confirm with callers.
AGENT_REG = [ActorCritic]
7 |
--------------------------------------------------------------------------------
/mcs/agent/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/agent/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/agent/__pycache__/actor_critic.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/__pycache__/actor_critic.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/agent/__pycache__/actor_critic.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/__pycache__/actor_critic.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/agent/actor_critic.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see https://www.gnu.org/licenses/.
15 | from mcs.actor import ACRolloutActorTrain
16 | from mcs.exp import Rollout
17 | from mcs.learner import ACRolloutLearner
18 | from .base.agent_module import AgentModule
19 |
20 |
class ActorCritic(AgentModule):
    """
    Actor-critic agent assembled from three parts.

    The parts are a ``Rollout`` experience cache, an ``ACRolloutActorTrain``
    actor and an ``ACRolloutLearner`` learner. Most methods simply delegate
    to one of them.
    """

    args = {**Rollout.args, **ACRolloutActorTrain.args, **ACRolloutLearner.args}

    def __init__(
        self,
        reward_normalizer,
        action_space,
        spec_builder,
        rollout_len,
        discount,
        normalize_advantage,
        entropy_weight,
        return_scale,
    ):
        super().__init__(reward_normalizer, action_space)
        # Hyperparameters mirrored on the agent; return_scale is only
        # forwarded to the learner.
        self.discount = discount
        self.normalize_advantage = normalize_advantage
        self.entropy_weight = entropy_weight

        self._exp_cache = Rollout(spec_builder, rollout_len)
        self._actor = ACRolloutActorTrain(action_space)
        learner_params = (
            reward_normalizer,
            discount,
            normalize_advantage,
            entropy_weight,
            return_scale,
        )
        self._learner = ACRolloutLearner(*learner_params)

    @classmethod
    def from_args(
        cls, args, reward_normalizer, action_space, spec_builder, **kwargs
    ):
        """Build the agent, reading its hyperparameters from ``args`` attributes."""
        hyperparams = {
            name: getattr(args, name)
            for name in (
                "rollout_len",
                "discount",
                "normalize_advantage",
                "entropy_weight",
                "return_scale",
            )
        }
        return cls(reward_normalizer, action_space, spec_builder, **hyperparams)

    @property
    def exp_cache(self):
        """The ``Rollout`` cache holding experience between learn steps."""
        return self._exp_cache

    @classmethod
    def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
        """Experience spec is exactly the training actor's."""
        return ACRolloutActorTrain._exp_spec(
            exp_len, batch_sz, obs_space, act_space, internal_space
        )

    @staticmethod
    def output_space(action_space):
        """Network heads are exactly those of the training actor."""
        return ACRolloutActorTrain.output_space(action_space)

    def compute_action_exp(
        self, predictions, internals, obs, available_actions
    ):
        """Delegate action selection to the wrapped ``ACRolloutActorTrain``."""
        actor = self._actor
        return actor.compute_action_exp(
            predictions, internals, obs, available_actions
        )

    def learn_step(self, updater, network, next_obs, internals):
        """Run the learner on the experience currently held in the cache."""
        experience = self.exp_cache.read()
        return self._learner.learn_step(
            updater, network, experience, next_obs, internals
        )
90 |
--------------------------------------------------------------------------------
/mcs/agent/base/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/base/.DS_Store
--------------------------------------------------------------------------------
/mcs/agent/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/base/__init__.py
--------------------------------------------------------------------------------
/mcs/agent/base/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/base/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/agent/base/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/base/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/agent/base/__pycache__/agent_module.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/base/__pycache__/agent_module.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/agent/base/__pycache__/agent_module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/base/__pycache__/agent_module.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/agent/base/agent_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see https://www.gnu.org/licenses/.
15 | """
16 | An Agent interacts with the environment and accumulates experience.
17 | """
18 | from collections import defaultdict
19 |
20 | import abc
21 |
22 | from mcs.exp import ExpSpecBuilder
23 | from mcs.utils.requires_args import RequiresArgsMixin
24 |
25 |
class AgentModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    An Agent is an Actor (chooses actions) and a Learner (updates parameters)
    and maintains a cache to store a rollout or experience replay.

    Actors and Learners are treated separately for Actor-Learner architectures
    where multiple actors send their experience to a single learner.

    Concrete agents implement ``from_args``, ``exp_cache``, ``_exp_spec``,
    ``output_space``, ``compute_action_exp`` and ``learn_step``; the other
    methods here are built on top of those.
    """

    def __init__(self, reward_normalizer, action_space):
        self._reward_normalizer = reward_normalizer
        self._action_space = action_space

    @classmethod
    @abc.abstractmethod
    def from_args(
        cls, args, reward_normalizer, action_space, spec_builder, **kwargs
    ):
        """Construct an agent from a parsed arguments object."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def exp_cache(self):
        """Get experience cache"""
        raise NotImplementedError

    @property
    def action_space(self):
        """Action space given at construction."""
        return self._action_space

    @property
    def action_keys(self):
        """Keys of ``action_space`` as a new, sorted list."""
        # sorted() already returns a fresh list; no list() wrapper needed.
        return sorted(self.action_space.keys())

    @classmethod
    def exp_spec_builder(cls, obs_space, act_space, internal_space, batch_sz):
        """Create an ``ExpSpecBuilder`` describing this agent's experience.

        :param obs_space: observation space
        :param act_space: action space
        :param internal_space: internal-state space
        :param batch_sz: batch size used for every entry's second dimension
        :return: ExpSpecBuilder whose build function, given ``exp_len``,
            returns ``cls._exp_spec(...)`` merged with ``rewards`` and
            ``terminals`` entries of shape ``(exp_len, batch_sz)``.
        """

        def build_fn(exp_len):
            exp_space = cls._exp_spec(
                exp_len, batch_sz, obs_space, act_space, internal_space
            )
            env_space = {
                "rewards": (exp_len, batch_sz),
                "terminals": (exp_len, batch_sz),
            }
            # env entries are merged last, so they win on a key collision
            return {**exp_space, **env_space}

        key_types = cls._key_types(obs_space, act_space, internal_space)
        exp_keys = cls._exp_keys(obs_space, act_space, internal_space)
        return ExpSpecBuilder(
            obs_space, act_space, internal_space, key_types, exp_keys, build_fn
        )

    @classmethod
    @abc.abstractmethod
    def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
        """Return a dict describing the actor's experience entries.

        NOTE(review): values are presumably shape tuples, like the
        ``rewards``/``terminals`` entries added in ``exp_spec_builder``.
        """
        raise NotImplementedError

    @classmethod
    def _exp_keys(cls, obs_space, act_space, internal_space):
        """Keys produced by ``_exp_spec``, found by building a 1x1 spec."""
        dummy = cls._exp_spec(1, 1, obs_space, act_space, internal_space)
        return dummy.keys()

    @classmethod
    def _key_types(cls, obs_space, act_space, internal_space):
        """Per-key type mapping; every key defaults to ``"float"``."""
        return defaultdict(lambda: "float")

    @staticmethod
    @abc.abstractmethod
    def output_space(action_space):
        """Output space derived from ``action_space``."""
        raise NotImplementedError

    @abc.abstractmethod
    def compute_action_exp(
        self, predictions, internals, obs, available_actions
    ):
        """Return ``(actions, experience)`` for one step (see ``act``)."""
        raise NotImplementedError

    @abc.abstractmethod
    def learn_step(self, updater, network, next_obs, internals):
        """Perform one learning step; the return value is subclass-defined."""
        raise NotImplementedError

    def is_ready(self):
        """Whether the experience cache reports itself ready."""
        return self.exp_cache.is_ready()

    def clear(self):
        """Clear the experience cache."""
        self.exp_cache.clear()

    def act(self, network, obs, prev_internals):
        """
        :param network: NetworkModule
        :param obs: Dict[str, Tensor]
        :param prev_internals: previous internal states. Dict[str, Tensor]
        :return:
            actions: Dict[ActionKey, LongTensor (B)]
            internal_states: Dict[str, Tensor]
        """
        predictions, internal_states, pobs = network(obs, prev_internals)

        # an action mask is optional; pass None when obs does not carry one
        if "available_actions" in obs:
            av_actions = obs["available_actions"]
        else:
            av_actions = None

        # experience is built from the internals the network consumed for
        # this step (prev_internals), not the newly returned ones
        actions, experience = self.compute_action_exp(
            predictions, prev_internals, pobs, av_actions
        )
        self.exp_cache.write_actor(experience)
        return actions, internal_states

    def observe(self, obs, rewards, terminals, infos):
        """Record environment feedback in the cache and return it unchanged."""
        self.exp_cache.write_env(obs, rewards, terminals, infos)
        return rewards, terminals, infos

    def to(self, device):
        """Move the experience cache to ``device``; returns ``self``."""
        self.exp_cache.to(device)
        return self
142 |
--------------------------------------------------------------------------------
/mcs/config/default.json:
--------------------------------------------------------------------------------
1 | {
2 | "actor_host": "ImpalaHostActor",
3 | "actor_worker": "ImpalaWorkerActor",
4 | "actor_target": "ImpalaHostTargetActor",
5 | "ceil": 1,
6 | "custom_network": null,
7 | "discount": 0.99,
8 | "entropy_weight": 0.01,
9 | "env": "SpaceInvadersNoFrameskip-v4",
10 | "epoch_len": 1000000,
11 | "eval": false,
12 | "exp": "Rollout",
13 | "floor": -1,
14 | "fourconv_norm": "bn",
15 | "frame_stack": false,
16 | "grad_norm_clip": "0.5",
17 | "head1d": "Identity1D",
18 | "head2d": "DeConv2D",
19 | "head3d": "Identity3D",
20 | "head4d": "Identity4D",
21 | "learner": "ImpalaLearner",
22 | "load_network": null,
23 | "load_optim": null,
24 | "logdir": "./logs",
25 | "lr": 0.0007,
26 | "lstm_nb_hidden": 512,
27 | "lstm_normalize": true,
28 | "manager": "SubProcEnvManager",
29 | "max_episode_length": 10000,
30 | "minimum_importance_policy": 1.0,
31 | "minimum_importance_value": 1.0,
32 | "nb_env": 32,
33 | "nb_learn_batch": 2,
34 | "nb_learners": 1,
35 | "nb_step": 50000000,
36 | "nb_workers": 2,
37 | "net1d": "Identity1D",
38 | "net2d": "Identity2D",
39 | "net3d": "FourConv",
40 | "net4d": "Identity4D",
41 | "netbody": "Linear",
42 | "noop_max": 30,
43 | "profile": false,
44 | "prompt": false,
45 | "ray_addr": null,
46 | "resume": null,
47 | "rollout_len": 20,
48 | "rollout_queue_size": 8,
49 | "rwd_norm": "Clip",
50 | "seed": 99,
51 | "skip_rate": 4,
52 | "summary_freq": 10,
53 | "tag": "default",
54 | "learner_cpu_alloc": 5,
55 | "learner_gpu_alloc": 0.5,
56 | "worker_cpu_alloc": 1,
57 | "worker_gpu_alloc": 0.05,
58 | "debug_mode": false,
59 | "test_mode": false,
60 | "use_pixel_control": false,
61 | "cell_size": 4,
62 | "pixel_control_loss_gamma": 0.99,
63 | "action_space": 6,
64 | "gae_gamma": 0.99,
65 | "gae_lambda": 0.995,
66 | "use_mhra": false,
67 | "num_head": 4,
68 | "probs_clip": 0.4,
69 | "target_worker_clip_rho": 2,
70 | "minibatch_buffer_size": 4,
71 | "num_sgd": 1,
72 | "linear_normalize": "not",
73 | "linear_nb_hidden": 512,
74 | "nb_layer": 3
75 | }
76 |
--------------------------------------------------------------------------------
/mcs/container/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see https://www.gnu.org/licenses/.
15 | from .local import Local
16 | from .distrib import DistribHost, DistribWorker
17 | from .actorlearner import ActorLearnerHost, ActorLearnerWorker
18 | from .evaluation import EvalContainer
19 | from .init import Init
20 | from .base.updater import Updater
21 |
--------------------------------------------------------------------------------
/mcs/container/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/__pycache__/distrib.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/distrib.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/__pycache__/distrib.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/distrib.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/__pycache__/evaluation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/evaluation.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/__pycache__/evaluation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/evaluation.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/__pycache__/init.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/init.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/__pycache__/init.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/init.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/__pycache__/local.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/local.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/__pycache__/local.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/local.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/actorlearner/__init__.py:
--------------------------------------------------------------------------------
1 | from .learner_container import ActorLearnerHost
2 | from .rollout_worker import ActorLearnerWorker
3 |
4 |
--------------------------------------------------------------------------------
/mcs/container/actorlearner/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/actorlearner/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/actorlearner/__pycache__/learner_container.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/learner_container.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/actorlearner/__pycache__/learner_container.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/learner_container.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/actorlearner/__pycache__/rollout_queuer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/rollout_queuer.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/actorlearner/__pycache__/rollout_queuer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/rollout_queuer.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/actorlearner/__pycache__/rollout_worker.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/rollout_worker.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/actorlearner/__pycache__/rollout_worker.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/rollout_worker.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/actorlearner/rollout_queuer.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see https://www.gnu.org/licenses/.
15 | from time import time
16 |
17 | import queue
18 | import ray
19 | import threading
20 | import torch
21 |
22 |
23 | class RolloutQueuerAsync:
24 | def __init__(self, workers, num_rollouts, queue_max_size, timeout=15.0):
25 | self.workers = workers
26 | self.num_rollouts = num_rollouts
27 | self.queue_max_size = queue_max_size
28 | self.queue_timeout = timeout
29 |
30 | self.futures = [w.run.remote() for w in self.workers]
31 | self.future_inds = [w for w in range(len(self.workers))]
32 | self._should_stop = True
33 | self.rollout_queue = queue.Queue(self.queue_max_size)
34 | self._worker_wait_time = 0
35 | self._host_wait_time = 0
36 |
37 | def _background_queing_thread(self):
38 | while not self._should_stop:
39 | ready_ids, remaining_ids = ray.wait(self.futures, num_returns=1)
40 |
41 | # if ray returns an empty list that means all objects have been gotten
42 | # this should never happen?
43 | if len(ready_ids) == 0:
44 | print("WARNING: ray returned no ready rollouts")
45 | # otherwise rollout was returned
46 | else:
47 | for ready in ready_ids:
48 | # get object and add to queue
49 | rollouts = ray.get(ready)
50 | # this will block if queue is at max size
51 | self._add_to_queue(rollouts)
52 |
53 | # remove from futures
54 | self._idle_workers = []
55 | for ready in ready_ids:
56 | index = self.futures.index(ready)
57 | del self.futures[index]
58 | self._idle_workers.append(self.future_inds[index])
59 | del self.future_inds[index]
60 |
61 | # tell the worker(s) to start another rollout
62 | self._restart_idle_workers()
63 |
64 | # done, wait for all remaining to finish
65 | dones, not_dones = ray.wait(
66 | self.futures, len(self.futures), timeout=self.queue_timeout
67 | )
68 |
69 | if len(not_dones) > 0:
70 | print("WARNING: Not all rollout workers finished")
71 |
72 | def _add_to_queue(self, rollout):
73 | st = time()
74 | self.rollout_queue.put(rollout, timeout=self.queue_timeout)
75 | et = time()
76 | self._worker_wait_time += et - st
77 |
78 | def start(self):
79 | self._should_stop = False
80 | self.background_thread = threading.Thread(
81 | target=self._background_queing_thread
82 | )
83 | self.background_thread.start()
84 |
85 | def get(self):
86 | st = time()
87 | worker_data = [
88 | self.rollout_queue.get(True) for _ in range(self.num_rollouts)
89 | ]
90 | et = time()
91 | self._host_wait_time += et - st
92 |
93 | rollouts = []
94 | terminal_rewards = []
95 | terminal_infos = []
96 | for w in worker_data:
97 | r, t, i = w["rollout"], w["terminal_rewards"], w["terminal_infos"]
98 | rollouts.append(r)
99 | terminal_rewards.append(t)
100 | terminal_infos.append(i)
101 |
102 | return rollouts, terminal_rewards, terminal_infos
103 |
104 | def close(self):
105 | self._should_stop = True
106 |
107 | # try to join background thread
108 | self.background_thread.join()
109 |
110 | def metrics(self):
111 | return {
112 | "Host wait time": self._host_wait_time,
113 | "Worker wait time": self._worker_wait_time,
114 | }
115 |
116 | def _restart_idle_workers(self):
117 | del_inds = []
118 | for d_ind, w_ind in enumerate(self._idle_workers):
119 | worker = self.workers[w_ind]
120 | self.futures.append(worker.run.remote())
121 | self.future_inds.append(w_ind)
122 | del_inds.append(d_ind)
123 |
124 | for d in del_inds:
125 | del self._idle_workers[d]
126 |
127 | def __len__(self):
128 | return len(self.rollout_queue)
129 |
--------------------------------------------------------------------------------
/mcs/container/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .container import Container
2 | from .nccl_optimizer import NCCLOptimizer
3 |
--------------------------------------------------------------------------------
/mcs/container/base/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/base/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/base/__pycache__/container.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/container.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/base/__pycache__/container.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/container.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/base/__pycache__/nccl_optimizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/nccl_optimizer.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/base/__pycache__/nccl_optimizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/nccl_optimizer.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/base/__pycache__/updater.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/updater.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/container/base/__pycache__/updater.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/updater.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/container/base/container.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
class Container:
    """Static helpers shared by containers: checkpoint loading, save-step
    scheduling, parameter counting and summary writing."""

    @staticmethod
    def load_network(network, path):
        """Load the state dict at ``path`` into ``network`` and return it.

        ``map_location`` returns each deserialized storage unchanged, so the
        tensors are loaded onto CPU whatever device they were saved from.
        """
        network.load_state_dict(
            torch.load(path, map_location=lambda storage, loc: storage)
        )
        return network

    @staticmethod
    def load_optim(optimizer, path):
        """Load the state dict at ``path`` into ``optimizer`` and return it."""
        optimizer.load_state_dict(
            torch.load(path, map_location=lambda storage, loc: storage)
        )
        return optimizer

    @staticmethod
    def init_next_save(initial_step_count, epoch_len):
        """Return the first save step after ``initial_step_count``.

        :return: 0 if ``initial_step_count`` is not positive; otherwise the
            smallest multiple of ``epoch_len`` strictly greater than it.
        :raises ValueError: if resuming (positive step count) with a
            non-positive ``epoch_len``; the previous loop never terminated.
        """
        if initial_step_count <= 0:
            return 0
        if epoch_len <= 0:
            raise ValueError(
                "epoch_len must be positive, got {}".format(epoch_len)
            )
        return (initial_step_count // epoch_len + 1) * epoch_len

    @staticmethod
    def count_parameters(net):
        """Number of trainable (``requires_grad``) scalar parameters."""
        return sum(p.numel() for p in net.parameters() if p.requires_grad)

    @staticmethod
    def write_summaries(
        writer, step_count, total_loss, loss_dict, metric_dict, n_params
    ):
        """Log losses, metrics and parameter/gradient norms to ``writer``.

        :param writer: object with ``add_scalar``/``add_histogram``
        :param total_loss: scalar tensor logged as ``loss/total_loss``
        :param loss_dict: name -> scalar tensor, logged under ``loss/``
        :param metric_dict: name -> tensor, logged under ``metric/``; the
            ``"action"`` entry is logged as a histogram via ``getData``,
            every other entry as a scalar
        :param n_params: iterable of ``(name, parameter)`` pairs; the norm of
            each parameter (and of its gradient, when present) is logged
        """
        writer.add_scalar("loss/total_loss", total_loss.item(), step_count)
        for l_name, loss in loss_dict.items():
            writer.add_scalar("loss/" + l_name, loss.item(), step_count)
        for m_name, metric in metric_dict.items():
            tag = "metric/" + m_name
            if m_name == "action":
                writer.add_histogram(tag, getData(metric), step_count)
            else:
                # "importance" used to have its own, identical branch
                writer.add_scalar(tag, metric.item(), step_count)
        for p_name, param in n_params:
            p_name = p_name.replace(".", "/")
            writer.add_scalar(p_name, torch.norm(param).item(), step_count)
            if param.grad is not None:
                writer.add_scalar(
                    p_name + ".grad", torch.norm(param.grad).item(), step_count
                )
52 |
53 |
54 |
def getData(data, nb_bins=None, scale=100):
    """Expand a per-bin weight vector into samples for a histogram summary.

    Bin ``i`` (value ``float(i)``) is repeated ``round(data[i] * scale)``
    times, so a histogram of the result reproduces the shape of ``data``.

    :param data: indexable sequence of non-negative weights (e.g. action
        probabilities; presumably a 1-D tensor or array -- confirm at caller).
    :param nb_bins: number of leading entries to use; defaults to
        ``len(data)``. The previous version hard-coded 27 and raised
        IndexError for shorter inputs.
    :param scale: samples per unit of weight.
    :return: 1-D float ``np.ndarray`` of bin values.
    """
    if nb_bins is None:
        nb_bins = len(data)
    bins = np.arange(nb_bins, dtype=float)
    # round() instead of int(): int(0.29 * 100) truncates to 28
    counts = [int(round(float(data[i]) * scale)) for i in range(nb_bins)]
    return bins.repeat(counts)
61 |
62 |
--------------------------------------------------------------------------------
/mcs/container/base/nccl_optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see https://www.gnu.org/licenses/.
15 |
16 |
class NCCLOptimizer:
    """Wrap a local optimizer with cross-rank gradient averaging.

    ``step`` all-reduces every gradient, divides by ``world_size`` and then
    steps the wrapped optimizer. Every ``param_sync_rate`` steps it also
    averages the parameters themselves to limit drift between ranks.
    ``set_process_group`` must be called before ``step``/``sync_*``.
    """

    def __init__(self, optimizer_fn, network, world_size, param_sync_rate=1000):
        self.network = network
        self.optimizer = optimizer_fn(self.network.parameters())
        self.param_sync_rate = param_sync_rate
        self._opt_count = 0
        # assigned later via set_process_group()
        self.process_group = None
        self.world_size = world_size

    def set_process_group(self, pg):
        """Set the process group used for all-reduce operations."""
        self.process_group = pg

    def _allreduce_mean(self, tensors):
        """All-reduce each tensor in place, wait for every operation, then
        divide by ``world_size`` to turn the sum into a mean.

        ``allreduce`` returns an asynchronous work handle, so the division
        must happen only after ``wait()``. ``sync_parameters`` and
        ``sync_buffers`` previously skipped the wait and raced with the
        reduction.
        """
        tensors = list(tensors)
        handles = [self.process_group.allreduce(t) for t in tensors]
        for handle in handles:
            handle.wait()
        scale = 1.0 / self.world_size
        for t in tensors:
            t.mul_(scale)

    def step(self):
        """Average gradients across ranks, then step the wrapped optimizer."""
        self._allreduce_mean(p.grad for p in self.network.parameters())
        self.optimizer.step()
        self._opt_count += 1

        # sync params every once in a while to reduce numerical errors
        if self._opt_count % self.param_sync_rate == 0:
            self.sync_parameters()
        # can't just sync buffers, some are int and don't mean well
        # self.sync_buffers()

    def sync_parameters(self):
        """Replace every parameter with its mean across ranks."""
        self._allreduce_mean(p.data for p in self.network.parameters())

    def sync_buffers(self):
        """Replace every buffer with its mean across ranks.

        Only meaningful for floating-point buffers (see the note in ``step``).
        """
        self._allreduce_mean(b.data for b in self.network.buffers())

    def zero_grad(self):
        """Zero the wrapped optimizer's gradients."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """State dict of the wrapped optimizer."""
        return self.optimizer.state_dict()

    def load_state_dict(self, d):
        """Load ``d`` into the wrapped optimizer."""
        return self.optimizer.load_state_dict(d)
64 |
--------------------------------------------------------------------------------
/mcs/container/base/updater.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
class Updater(metaclass=abc.ABCMeta):
    """Base for objects that turn a loss into a parameter update.

    Holds the optimizer, the network and the gradient-norm clipping value
    for subclasses to use; ``step`` raises until a subclass overrides it.
    """

    def __init__(self, optimizer, network, grad_norm_clip):
        self.optimizer, self.network, self.grad_norm_clip = (
            optimizer,
            network,
            grad_norm_clip,
        )

    def step(self, loss):
        """Apply one update derived from ``loss`` (subclass responsibility)."""
        raise NotImplementedError
12 |
--------------------------------------------------------------------------------
/mcs/container/evaluation_thread.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import numpy as np
16 | from threading import Thread
17 | from time import time, sleep
18 |
19 |
class EvaluationThread(LogsAndSummarizesRewards):
    """Background thread that repeatedly steps an evaluation environment.

    Actions come from ``agent.act_eval``. Rewards of finished episodes go
    through the mixin helpers ``update_buffers`` and
    ``write_reward_summaries``, and are logged by ``log_episode_results``.
    """

    # NOTE(review): LogsAndSummarizesRewards is not imported in this module;
    # the correct import must be added for this class definition to load.

    def __init__(
        self,
        training_network,
        agent,
        env,
        nb_env,
        logger,
        summary_writer,
        step_rate_limit,
        override_step_count_fn=None,
    ):
        """
        :param training_network: network whose ``state_dict`` is copied in
            ``_run`` when an episode terminates
        :param agent: object providing ``act_eval`` and ``reset_internals``
        :param env: environment providing ``reset`` and ``step``
        :param nb_env: number of environments (exposed as ``nb_env``)
        :param logger: logger used by ``log_episode_results``
        :param summary_writer: exposed as ``summary_writer``
        :param step_rate_limit: if > 0, the loop sleeps
            ``1 / step_rate_limit`` seconds before every step; <= 0 disables
            throttling
        :param override_step_count_fn: optional zero-arg callable whose result
            is reported as ``local_step_count``
        """
        self._training_network = training_network
        self._agent = agent
        self._environment = env
        self._nb_env = nb_env
        self._logger = logger
        self._summary_writer = summary_writer
        self._step_rate_limit = step_rate_limit
        self._override_step_count_fn = override_step_count_fn
        # Default so `local_step_count` is readable without an override fn;
        # previously this attribute was never set and the getter raised
        # AttributeError.
        self._local_step_count = 0
        self._thread = Thread(target=self._run)
        self._should_stop = False

    def start(self):
        """Start the evaluation loop on its own thread."""
        self._thread.start()

    def stop(self):
        """Ask the loop to exit after its current iteration and join the thread."""
        self._should_stop = True
        self._thread.join()

    def _run(self):
        """Thread body: step the environment until ``stop()`` is called."""
        next_obs = self.environment.reset()
        self.start_time = time()
        while not self._should_stop:
            if self._step_rate_limit > 0:
                sleep(1 / self._step_rate_limit)
            obs = next_obs
            actions = self.agent.act_eval(obs)
            next_obs, rewards, terminals, infos = self.environment.step(actions)

            self.agent.reset_internals(terminals)
            # Perform state updates
            terminal_rewards, terminal_infos = self.update_buffers(
                rewards, terminals, infos
            )
            self.log_episode_results(
                terminal_rewards, terminal_infos, self.local_step_count
            )
            self.write_reward_summaries(terminal_rewards, self.local_step_count)

            if np.any(terminals) and np.any(infos):
                # NOTE(review): `self.network` is not defined in this class;
                # confirm it is provided by the mixin, otherwise this raises
                # AttributeError on the first terminal step.
                self.network.load_state_dict(
                    self._training_network.state_dict()
                )

    def log_episode_results(
        self, terminal_rewards, terminal_infos, step_count, initial_step_count=0
    ):
        """Log the mean terminal reward and eval FPS, if any episode finished.

        :return: ``terminal_rewards`` unchanged
        """
        if terminal_rewards:
            ep_reward = np.mean(terminal_rewards)
            self.logger.info(
                "eval_frames: {} reward: {} avg_eval_fps: {}".format(
                    step_count,
                    ep_reward,
                    (step_count - initial_step_count)
                    / (time() - self.start_time),
                )
            )
        return terminal_rewards

    @property
    def agent(self):
        return self._agent

    @property
    def environment(self):
        return self._environment

    @environment.setter
    def environment(self, new_env):
        self._environment = new_env

    @property
    def nb_env(self):
        return self._nb_env

    @property
    def logger(self):
        return self._logger

    @property
    def summary_writer(self):
        return self._summary_writer

    @property
    def summary_name(self):
        return "reward/eval"

    @property
    def local_step_count(self):
        """Step count from the override fn if given, else the stored value."""
        if self._override_step_count_fn is not None:
            return self._override_step_count_fn()
        else:
            return self._local_step_count

    @local_step_count.setter
    def local_step_count(self, step_count):
        self._local_step_count = step_count
128 |
--------------------------------------------------------------------------------
/mcs/container/render.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | import torch
16 |
17 | from mcs.manager import SimpleEnvManager
18 | from mcs.network import ModularNetwork
19 | from mcs.registry import REGISTRY
20 | from mcs.utils import dtensor_to_dev, listd_to_dlist
21 | from mcs.utils.script_helpers import LogDirHelper
22 |
23 |
class RenderContainer:
    """Replay saved network checkpoints from a training log dir with rendering.

    For every selected epoch, each checkpoint saved at that epoch is loaded
    and run for one rendered episode in a single environment (``nb_env=1``).
    The episode's total reward is printed.
    """

    def __init__(
        self,
        actor,
        epoch_id,
        start,
        end,
        logger,
        log_id_dir,
        gpu_id,
        seed,
        manager,
    ):
        """
        :param actor: registry name of the actor used to choose actions
        :param epoch_id: a single epoch to render, or None to render every
            saved epoch in ``[start, end]``
        :param start: first epoch included when ``epoch_id`` is None
        :param end: last epoch included when ``epoch_id`` is None;
            -1.0 means no upper bound
        :param logger: stored as ``self.logger``
        :param log_id_dir: training log directory holding args and checkpoints
        :param gpu_id: CUDA device index; negative (or no CUDA) selects CPU
        :param seed: environment seed
        :param manager: registry name of the environment manager class
        """
        self.log_dir_helper = log_dir_helper = LogDirHelper(log_id_dir)
        self.train_args = train_args = log_dir_helper.load_args()
        self.device = device = self._device_from_gpu_id(gpu_id)
        self.logger = logger

        self.epoch_ids = self._select_epoch_ids(
            epoch_id, log_dir_helper.epochs, start, end
        )

        engine = REGISTRY.lookup_engine(train_args.env)
        env_cls = REGISTRY.lookup_env(train_args.env)
        manager_cls = REGISTRY.lookup_manager(manager)
        self.env_mgr = manager_cls.from_args(
            self.train_args, engine, env_cls, seed=seed, nb_env=1
        )
        if train_args.agent:
            agent = train_args.agent
        else:
            agent = train_args.actor_host
        output_space = REGISTRY.lookup_output_space(
            agent, self.env_mgr.action_space
        )
        actor_cls = REGISTRY.lookup_actor(actor)
        self.actor = actor_cls.from_args(
            actor_cls.prompt(), self.env_mgr.action_space
        )

        self.network = self._init_network(
            train_args,
            self.env_mgr.observation_space,
            self.env_mgr.gpu_preprocessor,
            output_space,
            REGISTRY,
        ).to(device)

    @staticmethod
    def _select_epoch_ids(epoch_id, list_epochs, start, end):
        """Return the epochs to render.

        An explicit ``epoch_id`` (including 0) is used as-is. Otherwise
        ``list_epochs()`` is called and filtered to ``start <= eid`` and,
        unless ``end == -1.0``, ``eid <= end``.
        """
        # `is not None` so that epoch 0 is not mistaken for "unspecified".
        if epoch_id is not None:
            return [epoch_id]
        return [
            eid
            for eid in list_epochs()
            if eid >= start and (end == -1.0 or eid <= end)
        ]

    @staticmethod
    def _device_from_gpu_id(gpu_id):
        return torch.device(
            "cuda:{}".format(gpu_id)
            if (torch.cuda.is_available() and gpu_id >= 0)
            else "cpu"
        )

    @staticmethod
    def _init_network(
        train_args, obs_space, gpu_preprocessor, output_space, net_reg
    ):
        """Build the custom network named in ``train_args`` or a ModularNetwork."""
        if train_args.custom_network:
            net_cls = net_reg.lookup_network(train_args.custom_network)
        else:
            net_cls = ModularNetwork

        return net_cls.from_args(
            train_args, obs_space, output_space, gpu_preprocessor, net_reg
        )

    def run(self):
        """Render one episode per checkpoint per selected epoch and print rewards."""
        for epoch_id in self.epoch_ids:
            for net_path in self.log_dir_helper.network_paths_at_epoch(
                epoch_id
            ):
                self.network.load_state_dict(
                    torch.load(
                        net_path, map_location=lambda storage, loc: storage
                    )
                )
                self.network.eval()
                # Reward is per episode/checkpoint; it used to accumulate
                # across all checkpoints of the same epoch.
                reward = self._render_episode()
                print(f"EPOCH_ID: {epoch_id} REWARD: {reward}")

    def _render_episode(self):
        """Run one rendered episode with the loaded network; return total reward."""
        internals = listd_to_dlist([self.network.new_internals(self.device)])
        next_obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
        self.env_mgr.render()

        total_reward = 0
        while True:
            obs = next_obs
            with torch.no_grad():
                actions, _, internals = self.actor.act(
                    self.network, obs, internals
                )
            next_obs, rewards, terminals, _ = self.env_mgr.step(actions)
            self.env_mgr.render()
            next_obs = dtensor_to_dev(next_obs, self.device)

            # Index 0: the manager was built with a single environment.
            total_reward += rewards[0]
            if terminals[0]:
                return total_reward

    def close(self):
        self.env_mgr.close()
140 |
--------------------------------------------------------------------------------
/mcs/env/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | from mcs.env.openai_gym import ATARI_ENVS
16 | from .base.env_module import EnvModule
17 | from .openai_gym import AdeptGymEnv
18 |
# Environment classes exported by this package (presumably consumed by the
# registry).
ENV_REG = [
    AdeptGymEnv,
]
20 |
--------------------------------------------------------------------------------
/mcs/env/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/env/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/env/__pycache__/_gym_wrappers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/_gym_wrappers.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/env/__pycache__/_gym_wrappers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/_gym_wrappers.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/env/__pycache__/_spaces.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/_spaces.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/env/__pycache__/_spaces.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/_spaces.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/env/__pycache__/openai_gym.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/openai_gym.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/env/__pycache__/openai_gym.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/openai_gym.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/env/_spaces.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | from gym import spaces
16 |
17 |
class Space(dict):
    """Dict summary of a gym space.

    ``from_gym`` maps each leaf space to a size/shape:
    Discrete -> (n,), MultiDiscrete -> nvec as a list, MultiBinary -> (n,),
    Box -> shape. A leaf space yields one entry keyed by its type name. A
    Dict space yields one entry per sub-space name; a Tuple space yields one
    entry per sub-space index.
    """

    def __init__(self, entries_by_name):
        super(Space, self).__init__(entries_by_name)

    @classmethod
    def from_gym(cls, gym_space):
        """Build a ``Space`` from a gym space.

        :raises NotImplementedError: if the space (or a sub-space) has an
            unsupported type.
        """
        entries_by_name = Space._detect_gym_spaces(gym_space)
        return cls(entries_by_name)

    @staticmethod
    def _detect_gym_spaces(gym_space):
        """Return ``{entry_name: shape}`` for ``gym_space``.

        NOTE(review): a nested Dict/Tuple sub-space contributes only the first
        entry of its own mapping; its remaining entries are dropped.
        """
        if isinstance(gym_space, spaces.Discrete):
            return {"Discrete": (gym_space.n,)}
        elif isinstance(gym_space, spaces.MultiDiscrete):
            return {"MultiDiscrete": gym_space.nvec.tolist()}
        elif isinstance(gym_space, spaces.MultiBinary):
            return {"MultiBinary": (gym_space.n,)}
        elif isinstance(gym_space, spaces.Box):
            return {"Box": gym_space.shape}
        elif isinstance(gym_space, spaces.Dict):
            return {
                name: list(Space._detect_gym_spaces(s).values())[0]
                for name, s in gym_space.spaces.items()
            }
        elif isinstance(gym_space, spaces.Tuple):
            return {
                idx: list(Space._detect_gym_spaces(s).values())[0]
                for idx, s in enumerate(gym_space.spaces)
            }
        # Fail explicitly instead of returning None, which would otherwise
        # surface later as an unrelated TypeError/AttributeError. This matches
        # dtypes_from_gym.
        raise NotImplementedError(
            "Unsupported gym space type: {}".format(type(gym_space).__name__)
        )

    @staticmethod
    def dtypes_from_gym(gym_space):
        """Return ``{entry_name: dtype}`` with the same keys as ``from_gym``.

        :raises NotImplementedError: for unsupported gym space types.
        """
        if isinstance(gym_space, spaces.Discrete):
            return {"Discrete": gym_space.dtype}
        elif isinstance(gym_space, spaces.MultiDiscrete):
            return {"MultiDiscrete": gym_space.dtype}
        elif isinstance(gym_space, spaces.MultiBinary):
            return {"MultiBinary": gym_space.dtype}
        elif isinstance(gym_space, spaces.Box):
            return {"Box": gym_space.dtype}
        elif isinstance(gym_space, spaces.Dict):
            return {
                name: list(Space.dtypes_from_gym(s).values())[0]
                for name, s in gym_space.spaces.items()
            }
        elif isinstance(gym_space, spaces.Tuple):
            return {
                idx: list(Space.dtypes_from_gym(s).values())[0]
                for idx, s in enumerate(gym_space.spaces)
            }
        else:
            raise NotImplementedError(
                "Unsupported gym space type: {}".format(
                    type(gym_space).__name__
                )
            )
70 |
--------------------------------------------------------------------------------
/mcs/env/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/base/__init__.py
--------------------------------------------------------------------------------
/mcs/env/base/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/base/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/env/base/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/base/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/env/base/__pycache__/_env.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/base/__pycache__/_env.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/env/base/__pycache__/_env.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/base/__pycache__/_env.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/env/base/__pycache__/env_module.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/base/__pycache__/env_module.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/env/base/__pycache__/env_module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/base/__pycache__/env_module.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/env/base/_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | import abc
16 |
17 |
class HasEnvMetaData(metaclass=abc.ABCMeta):
    """Interface exposing an environment's spaces and preprocessors.

    All four members are abstract read-only properties. A subclass cannot be
    instantiated until it overrides every one of them.
    """

    @property
    @abc.abstractmethod
    def observation_space(self):
        """Observation space of the environment."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def action_space(self):
        """Action space of the environment."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def cpu_preprocessor(self):
        """Observation preprocessor applied on the CPU side."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def gpu_preprocessor(self):
        """Observation preprocessor applied on the GPU side."""
        raise NotImplementedError()
38 |
39 |
class EnvBase(HasEnvMetaData, metaclass=abc.ABCMeta):
    """Abstract environment: the HasEnvMetaData members plus step/reset/close."""

    @abc.abstractmethod
    def step(self, action):
        """Advance the environment by one step.

        :param action: Dict[ActionID, Any] Action dictionary
        :return: Tuple[Observation, Reward, Terminal, Info]
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def reset(self, **kwargs):
        """Start a new episode.

        :param kwargs: implementation-specific reset options
        :return: Dict[ObservationID, Any] Observation dictionary
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def close(self):
        """Close the environment and release its resources.

        :return: None
        """
        raise NotImplementedError()
65 |
--------------------------------------------------------------------------------
/mcs/env/base/env_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | import abc
16 |
17 | from mcs.utils.requires_args import RequiresArgsMixin
18 | from mcs.env.base._env import EnvBase
19 |
20 |
class EnvModule(EnvBase, RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    Implement this class to add your custom environment. Don't forget to
    implement args.
    """

    # Subclasses must override this (presumably with the environment ids they
    # support); enforced by `check_ids_implemented`.
    ids = None

    def __init__(self, action_space, cpu_preprocessor, gpu_preprocessor):
        """
        :param action_space: ._spaces.Space
        :param cpu_preprocessor: mcs.preprocess.observation.ObsPreprocessor
        :param gpu_preprocessor: mcs.preprocess.observation.ObsPreprocessor;
            ``observation_space`` is read from this preprocessor.
        """
        self._action_space = action_space
        self._cpu_preprocessor = cpu_preprocessor
        self._gpu_preprocessor = gpu_preprocessor

    @classmethod
    @abc.abstractmethod
    def from_args(cls, args, seed, **kwargs):
        """
        Construct from arguments. For convenience.

        :param args: Arguments object
        :param seed: Integer used to seed this environment.
        :param kwargs: Any custom arguments are passed through kwargs.
        :return: EnvModule instance.
        """
        raise NotImplementedError

    @classmethod
    def from_args_curry(cls, args, seed, **kwargs):
        """Return a zero-arg callable that calls ``from_args(args, seed, **kwargs)``."""

        def _f():
            return cls.from_args(args, seed, **kwargs)

        return _f

    @property
    def observation_space(self):
        # Observation space after GPU preprocessing.
        return self._gpu_preprocessor.observation_space

    @property
    def action_space(self):
        return self._action_space

    @property
    def cpu_preprocessor(self):
        return self._cpu_preprocessor

    @property
    def gpu_preprocessor(self):
        return self._gpu_preprocessor

    @classmethod
    def check_ids_implemented(cls):
        """Raise NotImplementedError if the subclass left ``ids`` as None."""
        if cls.ids is None:
            raise NotImplementedError(
                'Subclass must define class attribute "ids"'
            )
82 |
--------------------------------------------------------------------------------
/mcs/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (C) 2018 Heron Systems, Inc.
3 | #
4 | # This program is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # This program is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with this program. If not, see .
"""
Usage:
    evaluate (--logdir <path>) [options]
    evaluate (-h | --help)

Required:
    --logdir <path>     Path to train logs (.../logs/<env-id>/<log-id>)

Options:
    --epoch <int>           Epoch number to load [default: None]
    --actor <str>           Name of the eval actor [default: ACActorEval]
    --gpu-id <int>          CUDA device ID of GPU [default: 0]
    --nb-episode <int>      Number of episodes to average [default: 30]
    --start <float>         Epoch to start from [default: 0]
    --end <float>           Epoch to end on [default: -1]
    --seed <int>            Seed for random variables [default: 512]
    --custom-network <str>  Name of custom network class
"""
34 | from mcs.container import EvalContainer
35 | from mcs.container import Init
36 | from mcs.registry import REGISTRY as R
37 | from mcs.utils.script_helpers import parse_path, parse_none
38 | from mcs.utils.util import DotDict
39 |
40 | import os
# Default to GPU 0 only when the caller has not already chosen visible
# devices; an existing CUDA_VISIBLE_DEVICES from the environment is respected
# instead of being silently overwritten.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
42 |
def parse_args():
    """Parse the command line according to the module docstring.

    Option names lose their leading dashes and use underscores; the help keys
    are dropped. ``epoch`` becomes an int (or stays as ``parse_none``
    returned it when falsy); ``gpu_id``, ``nb_episode`` and ``seed`` become
    ints; ``start`` and ``end`` become floats.

    :return: DotDict of parsed options
    """
    from docopt import docopt

    raw = docopt(__doc__)
    cleaned = {
        key.strip("-").replace("-", "_"): value for key, value in raw.items()
    }
    for help_key in ("h", "help"):
        del cleaned[help_key]

    args = DotDict(cleaned)
    args.logdir = parse_path(args.logdir)
    # TODO implement Option utility
    epoch_option = parse_none(args.epoch)
    args.epoch = int(float(epoch_option)) if epoch_option else epoch_option

    args.gpu_id, args.nb_episode = int(args.gpu_id), int(args.nb_episode)
    args.start, args.end = float(args.start), float(args.end)
    args.seed = int(args.seed)
    return args
65 |
66 |
def main(args):
    """
    Run an evaluation.

    :param args: Dict[str, Any] of parsed options (see ``parse_args``)
    :return: None
    """
    args = DotDict(args)

    Init.print_ascii_logo()
    logger = Init.setup_logger(args.logdir, "eval")
    Init.log_args(logger, args)
    R.load_extern_classes(args.logdir)

    ctor_args = (
        args.actor,
        args.epoch,
        logger,
        args.logdir,
        args.gpu_id,
        args.nb_episode,
        args.start,
        args.end,
        args.seed,
    )
    container = EvalContainer(*ctor_args)
    try:
        container.run()
    finally:
        # Always close the container, even when run() raises.
        container.close()
95 |
96 |
if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)
99 |
--------------------------------------------------------------------------------
/mcs/exp/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import ExpModule
2 | from .base import ExpSpecBuilder
3 | from .replay import ExperienceReplay, PrioritizedExperienceReplay
4 | from .rollout import Rollout
5 |
# Experience modules exported by this package. Only Rollout is registered;
# ExperienceReplay and PrioritizedExperienceReplay are imported but are not
# listed here.
EXP_REG = [Rollout]
9 |
--------------------------------------------------------------------------------
/mcs/exp/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/exp/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/exp/__pycache__/repetitive_buffer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/repetitive_buffer.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/exp/__pycache__/repetitive_buffer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/repetitive_buffer.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/exp/__pycache__/replay.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/replay.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/exp/__pycache__/replay.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/replay.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/exp/__pycache__/rollout.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/rollout.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/exp/__pycache__/rollout.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/rollout.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/exp/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .spec_builder import ExpSpecBuilder
2 | from .exp_module import ExpModule
3 |
--------------------------------------------------------------------------------
/mcs/exp/base/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/exp/base/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/exp/base/__pycache__/exp_module.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/exp_module.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/exp/base/__pycache__/exp_module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/exp_module.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/exp/base/__pycache__/spec_builder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/spec_builder.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/exp/base/__pycache__/spec_builder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/spec_builder.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/exp/base/exp_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 | import abc
16 |
17 | from mcs.utils.requires_args import RequiresArgsMixin
18 |
19 |
class ExpModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
    """Abstract experience module.

    Subclasses implement writes from the actor and from the environment,
    reading, and a readiness check.
    """

    @abc.abstractmethod
    def write_actor(self, experience):
        """Record experience produced by the actor."""
        raise NotImplementedError()

    @abc.abstractmethod
    def write_env(self, obs, rewards, terminals, infos):
        """Record the environment's outputs for one step."""
        raise NotImplementedError()

    @abc.abstractmethod
    def read(self):
        """Return stored experience."""
        raise NotImplementedError()

    @abc.abstractmethod
    def is_ready(self):
        """Return whether the module is ready to be read (criterion is subclass-defined)."""
        raise NotImplementedError()
36 |
--------------------------------------------------------------------------------
/mcs/exp/base/spec_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 |
16 |
class ExpSpecBuilder:
    """Key layout of an experience spec bundled with the function that builds it.

    Only the key names of ``obs_keys``, ``act_keys`` and ``internal_keys`` are
    kept (via ``.keys()``), stored as sorted lists for a deterministic order.
    ``exp_keys`` may be any iterable of keys and is sorted as well.
    """

    def __init__(
        self, obs_keys, act_keys, internal_keys, key_types, exp_keys, build_fn
    ):
        # Sorted key names, computed in the order obs -> action -> internal.
        self.obs_keys, self.action_keys, self.internal_keys = (
            sorted(mapping.keys())
            for mapping in (obs_keys, act_keys, internal_keys)
        )
        self.key_types = key_types
        self.exp_keys = sorted(exp_keys)
        self.build_fn = build_fn

    def __call__(self, rollout_len):
        """Delegate to ``build_fn`` to build a spec for ``rollout_len`` steps."""
        build = self.build_fn
        return build(rollout_len)
30 |
--------------------------------------------------------------------------------
/mcs/exp/repetitive_buffer.py:
--------------------------------------------------------------------------------
1 | from mcs.exp.rollout import Rollout
2 | import copy
3 | import torch
4 |
5 |
class RepetitiveBuffer(object):
    """Serve each batch pulled from ``inqueue`` several times in a row.

    On a fetch, a batch is read from ``inqueue`` and written into a private
    copy of ``exp``. The target actor's outputs are written on top of it, and
    a deep copy is stored in ``self.buffers[self.idx]``. That copy is then
    returned by the next ``max_visit_times`` calls to :meth:`get` before
    another batch is fetched.
    """


    def __init__(self, inqueue, buffer_size, num_k, exp):
        """
        :param inqueue: queue whose ``get()`` yields
            ``(rollouts, terminal_rewards, terminal_infos)``
        :param buffer_size: number of slots; ``self.idx`` cycles through them
        :param num_k: how many times each fetched batch is returned by ``get``
        :param exp: experience module; deep-copied so that ``clear()`` and
            writes done here never touch the caller's instance
        """
        self.inqueue = inqueue
        self.size = buffer_size

        self.max_visit_times = num_k
        # NOTE(review): cur_max_ttl is not read anywhere in this class and is
        # not updated by setMaxTimes.
        self.cur_max_ttl = num_k

        # NOTE(review): a slot is refilled only when its count is <= 0, and idx
        # advances only after the current slot is released (set to None). At
        # most one slot therefore holds a batch at any time.
        self.buffers = [None] * buffer_size
        # Remaining visits per slot; <= 0 means "fetch a new batch".
        self.buffer_count = [0] * buffer_size
        self.idx = 0
        self.exp = copy.deepcopy(exp)

    def get(self, target_actor, target_network, device):
        """Return the current batch, fetching a new one if the slot is spent.

        :param target_actor: object whose ``act(network, obs, internals)``
            returns a 3-tuple ``(_, experience, internals)``
        :param target_network: network passed through to ``target_actor.act``
        :param device: device handed to ``self.exp.to`` before reading
        :return: ``(buf, terminal_rewards, terminal_infos)``. The terminal
            lists come from the queue only on the call that fetched the batch;
            repeat visits return empty lists.
        """
        terminal_rewards = []
        terminal_infos = []
        if self.buffer_count[self.idx] <= 0:
            self.exp.clear()
            # Get batch from queue
            rollouts, terminal_rewards, terminal_infos = self.inqueue.get()

            # Iterate forward on batch
            self.exp.write_exps(rollouts)

            self.exp.to(device)
            r = self.exp.read()
            # Initial internal state: element 0 of each internal key, unbound
            # along dim 0 (presumably one entry per env — confirm).
            internals = {k: ts[0].unbind(0) for k, ts in r.internals.items()}
            with torch.no_grad():
                # NOTE(review): `rewards` is unused, and the returned
                # `t_internals` is discarded. Every step is therefore evaluated
                # with the initial `internals`, which may be unintended for a
                # recurrent target network — confirm.
                for obs, rewards in zip(
                    r.observations, r.rewards
                ):
                    _, t_h_exp, t_internals = target_actor.act(
                        target_network, obs, internals
                    )
                    # Write the target actor's experience into self.exp
                    # (no_env semantics are defined by the exp module).
                    self.exp.write_actor(t_h_exp, no_env=True)
            self.exp.reset_index()
            # Snapshot, so the next clear()/write on self.exp cannot alter it.
            self.buffers[self.idx] = copy.deepcopy(self.exp)
            self.buffer_count[self.idx] = self.max_visit_times

        buf = self.buffers[self.idx]
        self.buffer_count[self.idx] -= 1
        released = self.buffer_count[self.idx] <= 0
        if released:
            # Last visit: drop the snapshot and move on to the next slot.
            self.buffers[self.idx] = None
            self.idx = (self.idx + 1) % self.size
        return buf, terminal_rewards, terminal_infos

    def setMaxTimes(self, times):
        """Set how many times each subsequently fetched batch is served.

        A batch already stored in a slot keeps its remaining count.
        """
        self.max_visit_times = times
57 |
--------------------------------------------------------------------------------
/mcs/exp/replay.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | from mcs.exp.base.exp_module import ExpModule
16 |
17 |
class ExperienceReplay(ExpModule):
    """Placeholder experience replay: every method is a no-op returning None."""

    def write_actor(self, items):
        """Not implemented yet; ``items`` is ignored."""
        return None

    def write_env(self, obs, rewards, terminals, infos):
        """Not implemented yet; all arguments are ignored."""
        return None

    def read(self):
        """Not implemented yet; always returns None."""
        return None

    def is_ready(self):
        """Not implemented yet; always returns None (falsy)."""
        return None
30 |
31 |
class PrioritizedExperienceReplay(ExpModule):
    """Placeholder prioritized replay: every method is a no-op returning None."""

    def write_actor(self, items):
        """Not implemented yet; ``items`` is ignored."""
        return None

    def write_env(self, obs, rewards, terminals, infos):
        """Not implemented yet; all arguments are ignored."""
        return None

    def read(self):
        """Not implemented yet; always returns None."""
        return None

    def is_ready(self):
        """Not implemented yet; always returns None (falsy)."""
        return None
44 |
--------------------------------------------------------------------------------
/mcs/globals.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
# Version string of the mcs package.
VERSION = "1.0"
16 |
--------------------------------------------------------------------------------
/mcs/learner/__init__.py:
--------------------------------------------------------------------------------
1 | from mcs.learner.base.learner_module import LearnerModule
2 | from .ac_rollout import ACRolloutLearner
3 | from .impala import ImpalaLearner
4 |
# Learner classes exposed by this package, in registration order.
LEARNER_REG = [ACRolloutLearner, ImpalaLearner]
6 |
--------------------------------------------------------------------------------
/mcs/learner/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/learner/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/learner/__pycache__/ac_rollout.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/__pycache__/ac_rollout.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/learner/__pycache__/ac_rollout.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/__pycache__/ac_rollout.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/learner/__pycache__/impala.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/__pycache__/impala.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/learner/__pycache__/impala.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/__pycache__/impala.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/learner/ac_rollout.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .base.learner_module import LearnerModule
4 | from .base.dm_return_scale import DeepMindReturnScaler
5 |
6 |
class ACRolloutLearner(LearnerModule):
    """
    Actor Critic Rollout Learner

    Computes an n-step advantage actor-critic loss (policy, value and entropy
    terms) over one rollout and passes the summed loss to ``updater.step``.
    """

    args = {
        "discount": 0.99,
        "normalize_advantage": False,
        "entropy_weight": 0.01,
        "return_scale": False,
    }

    def __init__(
        self,
        reward_normalizer,
        discount,
        normalize_advantage,
        entropy_weight,
        return_scale,
    ):
        """
        :param reward_normalizer: callable applied to the stacked rollout rewards
        :param discount: per-step discount used for the n-step returns
        :param normalize_advantage: standardize advantages before the policy loss
        :param entropy_weight: multiplier of the (negated) mean entropy term
        :param return_scale: run the return recursion through
            DeepMindReturnScaler (R2D2-style value rescaling)
        """
        self.reward_normalizer = reward_normalizer
        self.discount = discount
        self.normalize_advantage = normalize_advantage
        self.entropy_weight = entropy_weight
        self.return_scale = return_scale
        if return_scale:
            self.dm_scaler = DeepMindReturnScaler(10.0 ** -3)

    @classmethod
    def from_args(cls, args, reward_normalizer):
        """Build from an ``args`` object carrying the keys listed in ``cls.args``."""
        return cls(
            reward_normalizer,
            args.discount,
            args.normalize_advantage,
            args.entropy_weight,
            args.return_scale,
        )

    def learn_step(self, updater, network, experiences, next_obs, internals):
        """
        Compute the actor-critic losses for one rollout and apply an update.

        NOTE(review): this signature omits ``target_network`` and ``avg_net``,
        which ``LearnerModule.learn_step`` declares; callers use this form.

        :param updater: object whose ``step(loss)`` applies the update
        :param network: called as ``network(next_obs, internals)`` to bootstrap
            the value of the state following the rollout
        :param experiences: rollout exposing per-step ``rewards``,
            ``log_probs``, ``values``, ``entropies`` and ``terminals``
        :return: ``(losses, metrics)``; ``metrics`` is always empty here
        """
        # normalize rewards
        rewards = self.reward_normalizer(torch.stack(experiences.rewards))

        # torch stack rollouts
        r_log_probs_action = torch.stack(experiences.log_probs)
        r_values = torch.stack(experiences.values)
        r_entropies = torch.stack(experiences.entropies)

        # estimate value of next state (no gradient through the bootstrap)
        with torch.no_grad():
            results, _, _ = network(next_obs, internals)
            last_values = results["critic"].squeeze(1).data

        # compute nstep return and advantage over batch
        r_tgt_returns = self.compute_returns(
            last_values, rewards, experiences.terminals
        )
        # .data detaches the advantage from the critic's graph
        r_advantages = r_tgt_returns - r_values.data

        # Normalize advantage so that an even number of actions are reinforced
        # and penalized. The sample std of a single element is NaN, which would
        # turn every loss into NaN, so at least two values are required.
        if self.normalize_advantage and r_advantages.numel() > 1:
            r_advantages = (r_advantages - r_advantages.mean()) / (
                r_advantages.std() + 1e-5
            )

        # batched losses; advantage broadcast over the trailing (action) dim
        policy_loss = -(r_log_probs_action) * r_advantages.unsqueeze(-1)
        # mean over actions, seq, batch
        policy_loss = policy_loss.mean()
        entropy_loss = -r_entropies.mean() * self.entropy_weight
        value_loss = 0.5 * (r_tgt_returns - r_values).pow(2).mean()

        updater.step(value_loss + policy_loss + entropy_loss)

        losses = {
            "value_loss": value_loss,
            "policy_loss": policy_loss,
            "entropy_loss": entropy_loss,
        }
        metrics = {}
        return losses, metrics

    def compute_returns(self, bootstrap_value, rewards, terminals):
        """
        Discounted n-step returns, accumulated backwards from ``bootstrap_value``.

        A terminal at step ``i`` masks out everything after it. With
        ``return_scale``, the running return is inverse-scaled, combined, then
        re-scaled at every step.

        :return: stacked per-step returns in forward order, detached via ``.data``
        """
        # First step of nstep reward target is estimated value of t+1
        target_return = bootstrap_value
        nstep_target_returns = []
        for i in reversed(range(len(rewards))):
            reward = rewards[i]
            terminal_mask = 1.0 - terminals[i].float()

            if self.return_scale:
                target_return = self.dm_scaler.calc_scale(
                    reward
                    + self.discount
                    * self.dm_scaler.calc_inverse_scale(target_return)
                    * terminal_mask
                )
            else:
                target_return = reward + (
                    self.discount * target_return * terminal_mask
                )
            nstep_target_returns.append(target_return)

        # Returns were collected last-step-first; restore forward order.
        nstep_target_returns.reverse()
        return torch.stack(nstep_target_returns).data
117 |
--------------------------------------------------------------------------------
/mcs/learner/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .learner_module import LearnerModule
2 |
--------------------------------------------------------------------------------
/mcs/learner/base/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/base/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/learner/base/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/base/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/learner/base/__pycache__/dm_return_scale.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/base/__pycache__/dm_return_scale.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/learner/base/__pycache__/dm_return_scale.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/base/__pycache__/dm_return_scale.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/learner/base/__pycache__/learner_module.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/base/__pycache__/learner_module.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/learner/base/__pycache__/learner_module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/base/__pycache__/learner_module.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/learner/base/dm_return_scale.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class DeepMindReturnScaler:
    """
    Invertible value rescaling from R2D2 (https://openreview.net/pdf?id=r1lyTjAqYX).

    ``h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x`` with ``eps = scale``.
    ``calc_inverse_scale`` is its closed-form inverse.
    """

    def __init__(self, scale):
        # eps in the formula above; must be non-zero for the inverse.
        self.scale = scale

    def calc_scale(self, x):
        """Apply ``h`` element-wise to tensor ``x``."""
        compressed = torch.sqrt(torch.abs(x) + 1) - 1
        return torch.sign(x) * compressed + self.scale * x

    def calc_inverse_scale(self, x):
        """Apply ``h^-1`` element-wise to tensor ``x``."""
        eps = self.scale
        root = torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps))
        base = (root - 1) / (2 * eps)
        return torch.sign(x) * (base ** 2 - 1)
22 |
--------------------------------------------------------------------------------
/mcs/learner/base/learner_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | """
16 | A Learner receives agent-environment interactions which are used to compute
17 | loss.
18 | """
19 | import abc
20 | from mcs.utils.requires_args import RequiresArgsMixin
21 |
22 |
class LearnerModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    This is one of the modules to use for custom Actor-Learner code.

    Subclasses implement ``from_args`` (alternate constructor) and
    ``learn_step`` (one learning update).
    """

    @classmethod
    @abc.abstractmethod
    def from_args(cls, args, reward_normalizer):
        """Construct a learner from parsed ``args`` and a ``reward_normalizer``."""
        raise NotImplementedError

    @abc.abstractmethod
    def learn_step(self, updater, network, target_network, experiences, next_obs, internals, avg_net):
        """
        Perform one learning update.

        NOTE(review): some subclasses in this package override this with a
        shorter signature (without ``target_network``/``avg_net``); confirm
        which form callers use.
        """
        raise NotImplementedError
36 |
--------------------------------------------------------------------------------
/mcs/logs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/logs/.DS_Store
--------------------------------------------------------------------------------
/mcs/manager/__init__.py:
--------------------------------------------------------------------------------
1 | from mcs.manager.simple_env_manager import SimpleEnvManager
2 | from mcs.manager.subproc_env_manager import SubProcEnvManager
3 | from mcs.manager.base.manager_module import EnvManagerModule
4 |
5 |
# Environment-manager classes exposed by this package, in registration order.
MANAGER_REG = [SimpleEnvManager, SubProcEnvManager]
7 |
--------------------------------------------------------------------------------
/mcs/manager/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/manager/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/manager/__pycache__/simple_env_manager.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/simple_env_manager.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/manager/__pycache__/simple_env_manager.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/simple_env_manager.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/manager/__pycache__/subproc_env_manager.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/subproc_env_manager.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/manager/__pycache__/subproc_env_manager.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/subproc_env_manager.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/manager/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/base/__init__.py
--------------------------------------------------------------------------------
/mcs/manager/base/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/base/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/manager/base/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/base/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/manager/base/__pycache__/manager_module.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/base/__pycache__/manager_module.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/manager/base/__pycache__/manager_module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/base/__pycache__/manager_module.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/manager/base/manager_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | import abc
16 |
17 | from mcs.env.base._env import EnvBase
18 | from mcs.utils.requires_args import RequiresArgsMixin
19 |
20 |
class EnvManagerModule(EnvBase, RequiresArgsMixin, metaclass=abc.ABCMeta):
    """Base class for objects that own a collection of environments.

    Stores the environment factories (``env_fns``) and an ``engine`` as given.
    ``nb_env`` is the number of factories.
    """

    def __init__(self, env_fns, engine):
        self._env_fns = env_fns
        self._engine = engine

    @property
    def env_fns(self):
        """The environment factories supplied at construction."""
        return self._env_fns

    @property
    def engine(self):
        """The engine supplied at construction."""
        return self._engine

    @property
    def nb_env(self):
        """Number of managed environments (one per factory)."""
        return len(self._env_fns)

    @classmethod
    def from_args(cls, args, engine, env_cls, seed=None, nb_env=None, **kwargs):
        """Build a manager with one factory per environment.

        The ``i``-th factory comes from ``env_cls.from_args_curry`` with seed
        ``seed + i``. ``seed`` defaults to ``int(args.seed)`` and ``nb_env``
        defaults to ``args.nb_env``. Extra ``kwargs`` are forwarded to
        ``from_args_curry``.
        """
        base_seed = int(args.seed) if seed is None else seed
        count = args.nb_env if nb_env is None else nb_env

        factories = [
            env_cls.from_args_curry(args, base_seed + offset, **kwargs)
            for offset in range(count)
        ]
        return cls(factories, engine)
49 |
--------------------------------------------------------------------------------
/mcs/manager/simple_env_manager.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | import numpy as np
16 | import torch
17 |
18 | from mcs.utils import listd_to_dlist
19 | from mcs.utils.util import dlist_to_listd
20 | from .base.manager_module import EnvManagerModule
21 |
22 |
class SimpleEnvManager(EnvManagerModule):
    """
    Manages multiple env in the same process. This is slower than a
    SubProcEnvManager but allows debugging.
    """

    args = {}

    def __init__(self, env_fns, engine):
        super(SimpleEnvManager, self).__init__(env_fns, engine)
        # Every factory is called eagerly; the first env supplies the spaces
        # and preprocessors used for the whole manager.
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        self._observation_space = env.observation_space
        self._action_space = env.action_space
        self._cpu_preprocessor = env.cpu_preprocessor
        self._gpu_preprocessor = env.gpu_preprocessor

        self.buf_obs = [None for _ in range(self.nb_env)]
        self.buf_dones = [None for _ in range(self.nb_env)]
        self.buf_rews = [None for _ in range(self.nb_env)]
        self.buf_infos = [None for _ in range(self.nb_env)]
        self.actions = None

    @property
    def cpu_preprocessor(self):
        return self._cpu_preprocessor

    @property
    def gpu_preprocessor(self):
        return self._gpu_preprocessor

    @property
    def observation_space(self):
        return self._observation_space

    @property
    def action_space(self):
        return self._action_space

    def step(self, actions):
        """Synchronously step every env; see ``step_wait`` for the result."""
        self.step_async(actions)
        return self.step_wait()

    def step_async(self, actions):
        """Record ``actions`` as one plain-Python action dict per env.

        ``actions`` is split with ``dlist_to_listd`` (presumably dict-of-lists
        to list-of-dicts), and each tensor value is converted with ``.item()``.
        """
        actions_tensor = dlist_to_listd(actions)
        self.actions = [
            {k: v.item() for k, v in a_ten.items()} for a_ten in actions_tensor
        ]

    def step_wait(self):
        """Step every env with the pending actions and batch the results.

        An env that reports done is reset immediately. Its returned observation
        is therefore the first of a new episode, while its reward, done and
        info still describe the finished step.

        :return: ``(obs, rewards, dones, infos)``; rewards and dones are
            tensors built from the per-env values, infos is a list
        """
        obs = []
        for e in range(self.nb_env):
            (
                ob,
                self.buf_rews[e],
                self.buf_dones[e],
                self.buf_infos[e],
            ) = self.envs[e].step(self.actions[e])
            if self.buf_dones[e]:
                ob = self.envs[e].reset()
            obs.append(ob)
        self.buf_obs = self._collate_obs(obs)

        return (
            self.buf_obs,
            torch.tensor(self.buf_rews),
            torch.tensor(self.buf_dones),
            self.buf_infos,
        )

    def reset(self):
        """Reset every env and return the batched initial observations."""
        obs = []
        for e in range(self.nb_env):
            obs.append(self.envs[e].reset())
        self.buf_obs = self._collate_obs(obs)
        return self.buf_obs

    def close(self):
        return [e.close() for e in self.envs]

    def render(self, mode="human"):
        return [e.render(mode=mode) for e in self.envs]

    def _collate_obs(self, obs):
        """Merge a list of per-env observation dicts into one batched dict.

        The merged values pass through ``dummy_handle_ob``. Keys for which
        ``_is_tensor_key`` is true are combined with ``torch.stack``; all
        other values are kept as-is.
        """
        new_obs = {}
        for k, v in dummy_handle_ob(listd_to_dlist(obs)).items():
            if self._is_tensor_key(k):
                new_obs[k] = torch.stack(v)
            else:
                new_obs[k] = v
        return new_obs

    def _is_tensor_key(self, key):
        """True if the preprocessor's observation space for ``key`` has no None.

        Presumably this means the shape is fully known and stackable (confirm).
        """
        return None not in self.cpu_preprocessor.observation_space[key]
123 |
124 |
def dummy_handle_ob(ob):
    """Return a new dict like ``ob`` with numpy-array values turned into tensors.

    Arrays go through ``torch.from_numpy``, so the tensor shares memory with
    the array. Every other value is passed through unchanged (same object).
    """
    return {
        key: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
        for key, value in ob.items()
    }
133 |
--------------------------------------------------------------------------------
/mcs/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | from .norm import Identity
16 | from .attention import MultiHeadSelfAttention, RMCCell
17 | from .sequence import LSTMCellLayerNorm
18 | from .spatial import Residual2DPreact
19 |
--------------------------------------------------------------------------------
/mcs/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/modules/__pycache__/attention.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/attention.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/modules/__pycache__/attention.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/attention.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/modules/__pycache__/norm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/norm.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/modules/__pycache__/norm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/norm.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/modules/__pycache__/sequence.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/sequence.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/modules/__pycache__/sequence.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/sequence.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/modules/__pycache__/spatial.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/spatial.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/modules/__pycache__/spatial.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/spatial.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/modules/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | import math
16 |
17 | import torch
18 | from torch import nn
19 | from torch.nn import Parameter, functional as F
20 |
21 |
class GaussianLinear(nn.Module):
    """Linear layer whose output is drawn from a learned diagonal Gaussian.

    Two independent affine heads produce, per output unit, a mean (``mu``)
    and a value that is treated as a log-variance (``std``): the standard
    deviation is ``exp(0.5 * std(x))``. In training mode a reparameterised
    sample ``mu + eps * sigma`` with ``eps ~ N(0, 1)`` is returned; in eval
    mode the mean alone is returned.
    """

    def __init__(self, fan_in, nodes):
        super().__init__()
        # Mean head.
        self.mu = nn.Linear(fan_in, nodes)
        # Log-variance head.
        self.std = nn.Linear(fan_in, nodes)

    def forward(self, x):
        mean = self.mu(x)
        if not self.training:
            return mean
        sigma = (0.5 * self.std(x)).exp()
        noise = torch.randn_like(sigma)
        return noise * sigma + mean

    def get_parameter_names(self):
        """Display names for the mean and log-variance weights/biases."""
        return ["Mu_W", "Mu_b", "Std_W", "Std_b"]
42 |
43 |
class NoisyLinear(nn.Linear):
    """
    Linear layer with learnable Gaussian parameter noise (NoisyNet).

    The effective parameters are ``w = weight + sigma_weight * eps_w`` and
    ``b = bias + sigma_bias * eps_b``. ``weight``/``bias`` (inherited from
    ``nn.Linear``) act as the means, and ``eps`` is caller-supplied noise,
    normally sampled with :meth:`reset`. In eval mode the noise term is
    dropped and the mean parameters are used.

    Reference implementation:
    https://github.com/Kaixhin/NoisyNet-A3C/blob/master/model.py
    """

    def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
        # ``bias`` is accepted for signature compatibility but ignored: the
        # layer always has a bias because forward() reads ``self.bias``.
        super(NoisyLinear, self).__init__(in_features, out_features, bias=True)
        # µ^w and µ^b reuse self.weight and self.bias
        self.sigma_init = sigma_init
        self.sigma_weight = Parameter(
            torch.Tensor(out_features, in_features)
        )  # σ^w
        self.sigma_bias = Parameter(torch.Tensor(out_features))  # σ^b
        self.init_params()

    def init_params(self):
        """Init means ~ U(-sqrt(3/in), sqrt(3/in)); sigmas = ``sigma_init``."""
        limit = math.sqrt(3 / self.in_features)

        self.weight.data.uniform_(-limit, limit)
        self.bias.data.uniform_(-limit, limit)
        self.sigma_weight.data.fill_(self.sigma_init)
        self.sigma_bias.data.fill_(self.sigma_init)

    def forward(self, x, internals):
        """
        :param x: input for ``F.linear`` (last dim = in_features)
        :param internals: ``(eps_w, eps_b)`` noise pair shaped like
            ``(weight, bias)``, e.g. from :meth:`reset`; unused in eval mode
        :return: ``F.linear(x, w, b)`` with the noisy (train) or mean (eval)
            parameters
        """
        if not self.training:
            # The noise has zero mean, so the expected parameters are the
            # means themselves. Adding sigma without scaling it by noise
            # would equal eps == 1 for every weight: a systematic
            # one-standard-deviation shift, not the noise-free layer.
            return F.linear(x, self.weight, self.bias)
        eps_w, eps_b = internals[0], internals[1]
        w = self.weight + self.sigma_weight * eps_w
        b = self.bias + self.sigma_bias * eps_b
        return F.linear(x, w, b)

    def batch_forward(self, x, internals, batch_size=None):
        """
        Apply an independent noise sample to every row of ``x``.

        Produces the same result as calling :meth:`forward` once per row
        with ``internals[i]``.

        :param x: Tensor{B, in_features}
        :param internals: sequence of B ``(eps_w, eps_b)`` pairs
        :param batch_size: B; defaults to ``x.shape[0]``
        :return: Tensor{B, out_features}
        """
        print(
            "WARNING: calling forward multiple times is actually "
            "faster than this and takes less memory"
        )
        if not self.training:
            # Match forward(): no noise at evaluation time.
            return F.linear(x, self.weight, self.bias)
        batch_size = batch_size if batch_size is not None else x.shape[0]
        x = x.unsqueeze(1)  # B x 1 x in
        # internals come in as [[w, b], ...] reshape to [w, ...], [b, ...]
        eps_w, eps_b = zip(*internals)
        eps_w = torch.stack(eps_w)  # B x out x in
        eps_b = torch.stack(eps_b)  # B x out
        # The noise must be scaled by sigma, exactly as in forward().
        batch_w = self.weight.expand(batch_size, -1, -1) + (
            self.sigma_weight.expand(batch_size, -1, -1) * eps_w
        )
        # permute to B x in x out so bmm maps B x 1 x in -> B x 1 x out
        batch_w = batch_w.permute(0, 2, 1)
        batch_b = self.bias.expand(batch_size, -1) + (
            self.sigma_bias.expand(batch_size, -1) * eps_b
        )
        return torch.bmm(x, batch_w).squeeze(1) + batch_b

    def reset(self, gpu=False, device=None):
        """
        Sample a fresh standard-normal noise pair.

        :param gpu: if True, move the noise with ``Tensor.cuda(device)``
        :param device: forwarded to ``Tensor.cuda`` (None = current device)
        :return: Tuple[Tensor{out, in}, Tensor{out}] of ``(eps_w, eps_b)``
        """
        eps_w = torch.randn(self.out_features, self.in_features)
        eps_b = torch.randn(self.out_features)
        if gpu:
            eps_w = eps_w.cuda(device, non_blocking=True)
            eps_b = eps_b.cuda(device, non_blocking=True)
        return eps_w.detach(), eps_b.detach()

    def get_parameter_names(self):
        return ["W", "b", "sigma_W", "sigma_b"]
122 |
--------------------------------------------------------------------------------
/mcs/modules/norm.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | import torch
16 |
17 |
class Identity(torch.nn.Module):
    """No-op module: ``forward`` returns the very object it was given."""

    def forward(self, x):
        # No copy, no transformation.
        result = x
        return result
21 |
--------------------------------------------------------------------------------
/mcs/modules/sequence.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | import torch
16 | from torch.nn import Module, Linear, LayerNorm
17 |
18 |
class LSTMCellLayerNorm(Module):
    """
    LSTM cell that applies LayerNorm to the new cell state.

    See https://github.com/seba-1511/lstms.pth/blob/master/lstms/lstm.py for
    reference. Original License Apache 2.0

    The 4 * hidden_size pre-activation is laid out as
    [input | forget | output | candidate].
    """

    def __init__(self, input_size, hidden_size, bias=True, forget_bias=0):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.ih = Linear(input_size, 4 * hidden_size, bias=bias)
        self.hh = Linear(hidden_size, 4 * hidden_size, bias=bias)

        if bias:
            forget_slice = slice(hidden_size, 2 * hidden_size)
            for projection in (self.ih, self.hh):
                projection.bias.data.fill_(0)
                # NOTE: written into both the ih and hh biases, so the
                # effective forget-gate bias is 2 * forget_bias.
                projection.bias.data[forget_slice].fill_(forget_bias)

        self.ln_cell = LayerNorm(hidden_size)

    def forward(self, x, hidden):
        """
        Run one LSTM step and layer-normalize the resulting cell state.
        :param x: Tensor{B, C}
        :param hidden: A Tuple[Tensor{B, C}, Tensor{B, C}] of (previous output, cell state)
        :return: Tuple (h_t, c_t); c_t is the layer-normalized cell state
        """
        h_prev, c_prev = hidden

        preact = self.ih(x) + self.hh(h_prev)
        in_pre, forget_pre, out_pre, cand_pre = preact.chunk(4, dim=1)

        input_gate = torch.sigmoid(in_pre)
        forget_gate = torch.sigmoid(forget_pre)
        output_gate = torch.sigmoid(out_pre)
        candidate = torch.tanh(cand_pre)

        cell = self.ln_cell(c_prev * forget_gate + input_gate * candidate)
        return output_gate * torch.tanh(cell), cell
69 |
--------------------------------------------------------------------------------
/mcs/modules/spatial.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | from torch import nn as nn
16 | from torch.nn import functional as F
17 |
18 |
class Residual2DPreact(nn.Module):
    """
    Pre-activation residual block.

    Main path: BN -> ReLU -> conv3x3(stride) -> BN -> ReLU -> conv3x3.
    Skip path: the raw input, unless the channel count changes or
    ``stride > 1``. In that case it is a biased 3x3 conv (same stride)
    applied to the pre-activated input. Conv weights are scaled by the ReLU
    gain after default initialisation.
    """

    def __init__(self, nb_in_chan, nb_out_chan, stride=1):
        super(Residual2DPreact, self).__init__()

        self.nb_in_chan = nb_in_chan
        self.nb_out_chan = nb_out_chan
        self.stride = stride

        # Creation order matters for RNG-driven weight init; keep it fixed.
        self.bn1 = nn.BatchNorm2d(nb_in_chan)
        self.conv1 = nn.Conv2d(
            nb_in_chan, nb_out_chan, 3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(nb_out_chan)
        self.conv2 = nn.Conv2d(
            nb_out_chan, nb_out_chan, 3, stride=1, padding=1, bias=False
        )

        gain = nn.init.calculate_gain("relu")
        for conv in (self.conv1, self.conv2):
            conv.weight.data.mul_(gain)

        self.do_projection = nb_in_chan != nb_out_chan or stride > 1
        if self.do_projection:
            self.projection = nn.Conv2d(
                nb_in_chan, nb_out_chan, 3, stride=stride, padding=1
            )
            self.projection.weight.data.mul_(gain)

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        shortcut = self.projection(pre) if self.do_projection else x
        residual = self.conv2(F.relu(self.bn2(self.conv1(pre))))
        return residual + shortcut
58 |
--------------------------------------------------------------------------------
/mcs/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .base.network_module import NetworkModule
2 | from .modular_network import ModularNetwork
3 | from .net1d.submodule_1d import SubModule1D
4 | from .net2d.submodule_2d import SubModule2D
5 | from .net3d.submodule_3d import SubModule3D
6 | from .net4d.submodule_4d import SubModule4D
7 |
8 | from .net1d.linear import Linear
9 | from .net1d.identity_1d import Identity1D
10 | from .net1d.lstm import LSTM
11 |
12 | from .net2d.identity_2d import Identity2D
13 | from .net3d.identity_3d import Identity3D
14 | from .net3d.four_conv import FourConv
15 |
16 | from .net4d.identity_4d import Identity4D
17 | from .net2d.deconv import DeConv2D
18 |
19 |
20 |
# Registry of complete network classes; nothing is registered here.
NET_REG = []

# Sub-module classes exposed by this package, annotated with the net*d
# package each comes from. Order is kept exactly as originally listed
# (DeConv2D sits last even though it is a 2D module).
SUBMOD_REG = [
    Identity1D,  # net1d
    Linear,  # net1d
    LSTM,  # net1d
    Identity2D,  # net2d
    Identity3D,  # net3d
    FourConv,  # net3d
    Identity4D,  # net4d
    DeConv2D,  # net2d
]
33 |
--------------------------------------------------------------------------------
/mcs/network/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/__pycache__/modular_network.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/__pycache__/modular_network.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/__pycache__/modular_network.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/__pycache__/modular_network.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__init__.py
--------------------------------------------------------------------------------
/mcs/network/base/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/base/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/base/__pycache__/base.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/base.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/base/__pycache__/base.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/base.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/base/__pycache__/network_module.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/network_module.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/base/__pycache__/network_module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/network_module.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/base/__pycache__/submodule.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/submodule.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/base/__pycache__/submodule.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/submodule.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/base/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | import abc
16 |
17 | import torch
18 | from torch import distributed as dist
19 |
20 |
class BaseNetwork(torch.nn.Module):
    """
    Abstract interface shared by networks.

    Subclasses provide construction from arguments, creation of their
    internal (e.g. recurrent) tensors, and the forward pass. This base adds
    helpers to describe internal shapes and to broadcast the module state
    across ``torch.distributed`` ranks.
    """

    @classmethod
    @abc.abstractmethod
    def from_args(
        cls, args, observation_space, output_space, gpu_preprocessor, net_reg
    ):
        """Build a network instance from parsed arguments."""
        raise NotImplementedError

    @abc.abstractmethod
    def new_internals(self, device):
        """
        :return: Dict[InternalKey, torch.Tensor (ND)]
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, observation, internals):
        raise NotImplementedError

    def internal_space(self):
        """Map each internal key to the shape of a fresh CPU internal tensor."""
        fresh = self.new_internals("cpu")
        return {key: tensor.shape for key, tensor in fresh.items()}

    def sync(self, src, grp=None, async_op=False):
        """
        Broadcast every ``state_dict`` tensor from rank ``src``.

        ``state_dict`` tensors share storage with the module's parameters and
        buffers, so receiving ranks are updated in place.

        :param src: rank whose values are sent
        :param grp: process group; omitted from the broadcast call when None
        :param async_op: if False, wait for every broadcast before returning
        :return: list of work handles, one per ``state_dict`` entry
        """
        tensors = self.state_dict().values()
        if grp is None:
            handles = [dist.broadcast(t, src, async_op=True) for t in tensors]
        else:
            handles = [
                dist.broadcast(t, src, grp, async_op=True) for t in tensors
            ]

        if not async_op:
            for handle in handles:
                handle.wait()

        return handles
62 |
--------------------------------------------------------------------------------
/mcs/network/base/network_module.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | import abc
16 |
17 | from mcs.network.base.base import BaseNetwork
18 | from mcs.utils.requires_args import RequiresArgsMixin
19 |
20 |
class NetworkModule(BaseNetwork, RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    Abstract base for argument-configured networks.

    Combines ``BaseNetwork`` with ``RequiresArgsMixin``. With ``abc.ABCMeta``
    as the metaclass, a subclass that leaves any ``@abc.abstractmethod``
    method unimplemented cannot be instantiated.
    """
23 |
--------------------------------------------------------------------------------
/mcs/network/net1d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__init__.py
--------------------------------------------------------------------------------
/mcs/network/net1d/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net1d/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net1d/__pycache__/identity_1d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/identity_1d.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net1d/__pycache__/identity_1d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/identity_1d.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net1d/__pycache__/linear.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/linear.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net1d/__pycache__/linear.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/linear.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net1d/__pycache__/lstm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/lstm.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net1d/__pycache__/lstm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/lstm.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net1d/__pycache__/submodule_1d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/submodule_1d.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net1d/__pycache__/submodule_1d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/submodule_1d.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net1d/identity_1d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | from .submodule_1d import SubModule1D
16 |
17 |
class Identity1D(SubModule1D):
    """
    Pass-through 1D sub-module.

    Returns its input unchanged, creates no internals, has no configurable
    arguments, and reports an output shape equal to its input shape.
    """

    args = {}

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)

    @classmethod
    def from_args(cls, args, input_shape, id):
        # ``args`` is ignored: there is nothing to configure.
        return cls(input_shape, id)

    def _new_internals(self):
        # Stateless: no internal tensors.
        return {}

    def _forward(self, input, internals, **kwargs):
        # The input object itself, plus an empty dict of updated internals.
        return input, {}

    @property
    def _output_shape(self):
        return self.input_shape
37 |
--------------------------------------------------------------------------------
/mcs/network/net1d/linear.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | from __future__ import division
16 |
17 | from torch import nn
18 | from torch.nn import functional as F
19 |
20 | from mcs.modules import Identity
21 |
22 | from .submodule_1d import SubModule1D
23 |
24 |
class Linear(SubModule1D):
    """
    Stack of ``nb_layer`` fully-connected layers, each followed by a
    normalization layer and a ReLU.

    ``normalize`` selects the normalization:
      * ``"bn"``: BatchNorm1d
      * ``"gn"``: GroupNorm with ``nb_hidden // 16`` groups
      * any other value: none (``Identity``)
    """

    args = {"linear_normalize": "bn", "linear_nb_hidden": 512, "nb_layer": 3}

    def __init__(self, input_shape, id, normalize, nb_hidden, nb_layer):
        super().__init__(input_shape, id)
        self._nb_hidden = nb_hidden

        nb_input_channel = input_shape[0]

        # Drop the linear bias only when a BN/GN layer follows, since those
        # re-centre the activations and carry their own shift. Testing
        # ``not normalize`` is wrong: a truthy value such as "none" would
        # remove the bias even though no normalization is applied.
        bias = normalize not in ("bn", "gn")
        self.linears = nn.ModuleList(
            [
                nn.Linear(
                    nb_input_channel if i == 0 else nb_hidden, nb_hidden, bias
                )
                for i in range(nb_layer)
            ]
        )
        if normalize == "bn":
            self.norms = nn.ModuleList(
                [nn.BatchNorm1d(nb_hidden) for _ in range(nb_layer)]
            )
        elif normalize == "gn":
            if nb_hidden % 16 != 0:
                # ValueError subclasses Exception, so existing handlers
                # that catch Exception still work.
                raise ValueError(
                    "linear_nb_hidden must be divisible by 16 for Group Norm"
                )
            self.norms = nn.ModuleList(
                [
                    nn.GroupNorm(nb_hidden // 16, nb_hidden)
                    for _ in range(nb_layer)
                ]
            )
        else:
            self.norms = nn.ModuleList([Identity() for _ in range(nb_layer)])

    @classmethod
    def from_args(cls, args, input_shape, id):
        return cls(
            input_shape,
            id,
            args.linear_normalize,
            args.linear_nb_hidden,
            args.nb_layer,
        )

    def _forward(self, xs, internals, **kwargs):
        """Apply linear -> norm -> ReLU per layer; no internals are updated."""
        for linear, norm in zip(self.linears, self.norms):
            xs = F.relu(norm(linear(xs)))
        return xs, {}

    def _new_internals(self):
        return {}

    @property
    def _output_shape(self):
        return (self._nb_hidden,)
82 |
--------------------------------------------------------------------------------
/mcs/network/net1d/lstm.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | import torch
16 | from torch.nn import LSTMCell, Sequential, Linear, ReLU
17 |
18 | from mcs.modules import LSTMCellLayerNorm
19 | from .submodule_1d import SubModule1D
20 |
21 |
class LSTM(SubModule1D):
    """
    Single-step LSTM sub-module that carries ``hx``/``cx`` internals.

    ``normalize`` selects ``LSTMCellLayerNorm`` (layer-normalized cell
    state). Otherwise it uses ``torch.nn.LSTMCell`` with zero-initialized
    biases.
    """

    args = {"lstm_normalize": True, "lstm_nb_hidden": 512}

    def __init__(self, input_shape, id, normalize, nb_hidden):
        super().__init__(input_shape, id)
        self._nb_hidden = nb_hidden
        nb_input = input_shape[0]

        if not normalize:
            self.lstm = LSTMCell(nb_input, nb_hidden)
            self.lstm.bias_ih.data.fill_(0)
            self.lstm.bias_hh.data.fill_(0)
        else:
            self.lstm = LSTMCellLayerNorm(nb_input, nb_hidden)

    @classmethod
    def from_args(cls, args, input_shape, id):
        return cls(input_shape, id, args.lstm_normalize, args.lstm_nb_hidden)

    @property
    def _output_shape(self):
        return (self._nb_hidden,)

    def _forward(self, xs, internals, **kwargs):
        """
        Advance the cell one step.

        Returns the new hidden state and the updated internals. Each of
        ``hx``/``cx`` is a list of per-row tensors, produced by unbinding
        dim 0.
        """
        # ``stacked_internals`` comes from the base class; presumably it
        # batches the per-sample tensors (TODO confirm).
        prev_h = self.stacked_internals("hx", internals)
        prev_c = self.stacked_internals("cx", internals)
        next_h, next_c = self.lstm(xs, (prev_h, prev_c))
        updated = {
            "hx": list(next_h.unbind(dim=0)),
            "cx": list(next_c.unbind(dim=0)),
        }
        return next_h, updated

    def _new_internals(self):
        return {
            "hx": torch.zeros(self._nb_hidden),
            "cx": torch.zeros(self._nb_hidden),
        }

    def _to_2d_shape(self):
        return (*self._output_shape, 1)
--------------------------------------------------------------------------------
/mcs/network/net1d/submodule_1d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | from mcs.network.base.submodule import SubModule
16 | import abc
17 |
18 |
class SubModule1D(SubModule, metaclass=abc.ABCMeta):
    """
    Base class for sub-modules that produce a 1D output.

    Each ``_to_Nd_shape`` helper returns the 1D output shape padded with
    trailing singleton dimensions up to N dimensions in total.
    """

    dim = 1

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)

    def _to_1d_shape(self):
        return self._output_shape

    def _to_2d_shape(self):
        return self._output_shape + (1,)

    def _to_3d_shape(self):
        return self._output_shape + (1,) * 2

    def _to_4d_shape(self):
        return self._output_shape + (1,) * 3
36 |
--------------------------------------------------------------------------------
/mcs/network/net2d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__init__.py
--------------------------------------------------------------------------------
/mcs/network/net2d/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net2d/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net2d/__pycache__/deconv.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/deconv.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net2d/__pycache__/deconv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/deconv.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net2d/__pycache__/identity_2d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/identity_2d.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net2d/__pycache__/identity_2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/identity_2d.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net2d/__pycache__/submodule_2d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/submodule_2d.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net2d/__pycache__/submodule_2d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/submodule_2d.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net2d/deconv.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 |
16 | import torch.nn as nn
17 | import torch
18 | from ..net2d.submodule_2d import SubModule2D
19 |
20 |
class DeConv2D(SubModule2D):
    """Dueling deconvolution head that produces one Q-value per map cell and action.

    A feature vector of length ``input_shape[0]`` goes through a linear layer
    with ReLU. The result is reshaped to a ``(pc_layers, 9, 9)`` map and fed
    to two transposed convolutions (kernel 5, stride 2, so 9x9 -> 21x21):
    a 1-channel value map ``V`` and a ``num_action``-channel advantage map
    ``A``. They are combined as ``Q = V + A - mean_over_actions(A)``.
    The output has shape ``(batch, 21 * 21, num_action)``.

    NOTE(review): both heads end in ReLU, so V and A are non-negative;
    confirm that is intended for Q-value estimates.
    """

    args = {}

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)
        self._input_shape = input_shape
        self._out_shape = None
        self._pc_layers = 16
        self.num_action = 6

        self.fc = nn.Sequential(
            nn.Linear(in_features=input_shape[0], out_features=9 * 9 * self._pc_layers),
            nn.ReLU(),
        )

        # Both heads upsample 9x9 -> 21x21: (9 - 1) * 2 + (5 - 1) + 1 = 21.
        self.deconv_v = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self._pc_layers, out_channels=1, kernel_size=5, stride=2),
            nn.ReLU(),
        )
        self.deconv_a = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=self._pc_layers, out_channels=self.num_action, kernel_size=5, stride=2
            ),
            nn.ReLU(),
        )

    @classmethod
    def from_args(cls, args, input_shape, id):
        return cls(input_shape, id)

    @property
    def _output_shape(self):
        """``(21 * 21, num_action)``: one row of action values per output cell."""
        if self._out_shape is None:
            self._out_shape = (21 * 21, self.num_action)
        return self._out_shape

    @_output_shape.setter
    def _output_shape(self, value):
        # Store into the cache the getter reads. The previous setter wrote the
        # name-mangled ``self.__output_shape``, which nothing ever read.
        self._out_shape = value

    def _forward(self, input, internals=None, **kwargs):
        """Return ``(q, {})`` with ``q`` of shape ``(batch, 21 * 21, num_action)``."""
        # Flatten to (batch, in_features) using the width fc was built with
        # instead of a hard-coded 512 (an observed input was (32, 512, 1)).
        hidden = self.fc(input.view(-1, self._input_shape[0]))
        hidden = hidden.view(-1, self._pc_layers, 9, 9)
        v = self.deconv_v(hidden)  # (batch, 1, 21, 21)
        a = self.deconv_a(hidden)  # (batch, num_action, 21, 21)
        # Dueling aggregation: subtract the mean advantage across actions.
        q = v + a - torch.mean(a, dim=1, keepdim=True)
        out = q.reshape(-1, self.num_action, 21 * 21)
        return out.permute(0, 2, 1).contiguous(), {}

    def _new_internals(self):
        return {}
69 |
70 |
71 |
def calc_output_dim(dim_size, kernel_size, stride=1, input_padding=0, output_padding=0, dilation=1):
    """Return the length a ``ConvTranspose`` layer produces along one axis.

    Implements PyTorch's formula::

        (dim_size - 1) * stride - 2 * input_padding
            + dilation * (kernel_size - 1) + output_padding + 1
    """
    upsampled = (dim_size - 1) * stride
    kernel_extent = dilation * (kernel_size - 1) + 1
    return upsampled - 2 * input_padding + kernel_extent + output_padding
75 |
76 |
if __name__ == "__main__":
    # Print the ConvTranspose2d(kernel_size=5, stride=2) output size for a few inputs.
    for size in (4, 5, 1):
        print(calc_output_dim(size, 5, 2))
82 |
--------------------------------------------------------------------------------
/mcs/network/net2d/identity_2d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from .submodule_2d import SubModule2D
16 |
17 |
class Identity2D(SubModule2D):
    """Pass-through 2-D submodule: returns its input unchanged and keeps no internals."""

    args = {}

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)

    @classmethod
    def from_args(cls, args, input_shape, id):
        # There are no configurable options; `args` is accepted for interface parity.
        return cls(input_shape, id)

    @property
    def _output_shape(self):
        # An identity mapping does not change the shape.
        return self.input_shape

    def _forward(self, input, internals, **kwargs):
        updated_internals = {}
        return input, updated_internals

    def _new_internals(self):
        return dict()
37 |
--------------------------------------------------------------------------------
/mcs/network/net2d/submodule_2d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from ..base.submodule import SubModule
16 | import abc
17 | import math
18 |
19 |
class SubModule2D(SubModule, metaclass=abc.ABCMeta):
    """Base class for submodules whose output shape is a 2-tuple ``(f, s)``.

    The ``_to_Nd_shape`` helpers reinterpret that pair for consumers of
    other ranks.
    """

    dim = 2

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)

    def _to_1d_shape(self):
        # Flatten both axes into one.
        first, second = self._output_shape
        return (first * second,)

    def _to_2d_shape(self):
        return self._output_shape

    def _to_3d_shape(self):
        # Fold everything into the channel axis with a 1x1 spatial extent.
        first, second = self._output_shape
        return (first * second, 1, 1)

    def _to_4d_shape(self):
        # Keep both axes and append two singleton axes.
        first, second = self._output_shape
        return (first, second, 1, 1)
40 |
--------------------------------------------------------------------------------
/mcs/network/net3d/RelationalMHDPA.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from mcs.network.net2d.submodule_2d import SubModule2D
16 | import torch.nn as nn
17 | import torch
18 | import numpy as np
19 | import math
20 |
21 |
class RelationalMHDPA(nn.Module):
    """Multi-head dot-product attention over a sequence of feature vectors.

    Each input position is projected to query/key/value (one linear layer of
    width ``3 * nb_channel``) and split into ``nb_head`` heads. Attention runs
    under a lower-triangular mask, so position ``i`` only attends to positions
    ``<= i``. A final linear layer is applied to the merged heads.

    NOTE(review): a causal mask over flattened image positions is unusual for
    relational attention; confirm it is intended.
    """

    args = {}

    def __init__(self, input_shape, nb_head, scale=False):
        """
        :param input_shape: ``(nb_channel, height, width)``; the mask is built
            for ``height * width`` positions.
        :param nb_head: number of attention heads; must divide ``nb_channel``.
        :param scale: if True, divide attention logits by sqrt(head width).
        """
        super(RelationalMHDPA, self).__init__()
        self._out_shape = input_shape

        nb_channel = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]

        assert nb_channel % nb_head == 0, "nb_channel must be divisible by nb_head"
        seq_len = height * width
        # Lower-triangular (causal) mask, broadcast over batch and heads.
        self.register_buffer(
            "b",
            torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len),
        )
        self.nb_head = nb_head
        self.split_size = nb_channel
        self.scale = scale
        self.projection = nn.Linear(nb_channel, nb_channel * 3)
        self.re = nn.ReLU()  # currently unused (see commented-out call in forward)
        self.mlp = nn.Linear(nb_channel, nb_channel)

    def forward(self, x):
        """
        :param x: A tensor with a shape of [batch, seq_len, nb_channel]
        :return: A tensor with a shape of [batch, seq_len, nb_channel]
        """
        size_out = x.size()[:-1] + (self.split_size * 3,)  # [B, T, 3C]

        # reshape (not view) so non-contiguous inputs such as slices also work.
        x = self.projection(x.reshape(-1, x.size(-1)))  # [B*T, 3C]
        # x = self.re(x)
        x = x.view(*size_out)  # [B, T, 3C]

        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)  # [B, H, T, C/H]
        key = self.split_heads(key, k=True)  # [B, H, C/H, T]
        value = self.split_heads(value)  # [B, H, T, C/H]

        a = self._attn(query, key, value)
        e = self.merge_heads(a)  # [B, T, C]

        return self.mlp(e)

    def _new_internals(self):
        return {}

    def _attn(self, q, k, v):
        """Masked scaled-dot-product attention; returns ``[B, H, T, C/H]``."""
        w = torch.matmul(q, k)  # [B, H, T, T]

        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # Slice the mask to the actual sequence length so sequences shorter
        # than the one the buffer was built for also work.
        mask = self.b[:, :, : w.size(-2), : w.size(-1)]
        w = w * mask + -1e9 * (1 - mask)  # TF implem method: mask_attn_weights
        w = torch.softmax(w, dim=-1)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        """``[B, H, T, C/H]`` -> ``[B, T, C]``."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        """``[B, T, C]`` -> ``[B, H, T, C/H]``, or ``[B, H, C/H, T]`` when ``k``."""
        new_x_shape = x.size()[:-1] + (self.nb_head, x.size(-1) // self.nb_head)  # [B,T,H,C//H]
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            # batch, head, channel, attend
            return x.permute(0, 2, 3, 1)
        # batch, head, attend, channel
        return x.permute(0, 2, 1, 3)

    def get_parameter_names(self, layer):
        return [
            "Proj{}_W".format(layer),
            "Proj{}_b".format(layer),
            "MLP{}_W".format(layer),
            "MLP{}_b".format(layer),
        ]
110 |
--------------------------------------------------------------------------------
/mcs/network/net3d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__init__.py
--------------------------------------------------------------------------------
/mcs/network/net3d/__pycache__/RelationalMHDPA.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/RelationalMHDPA.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net3d/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net3d/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net3d/__pycache__/deconv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/deconv.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net3d/__pycache__/four_conv.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/four_conv.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net3d/__pycache__/four_conv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/four_conv.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net3d/__pycache__/identity_3d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/identity_3d.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net3d/__pycache__/identity_3d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/identity_3d.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net3d/__pycache__/submodule_3d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/submodule_3d.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net3d/__pycache__/submodule_3d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/submodule_3d.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net3d/four_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from torch.nn import Conv2d, BatchNorm2d, GroupNorm, init, functional as F
16 |
17 | from mcs.modules import Identity
18 | from mcs.network.net3d.submodule_3d import SubModule3D
19 | from mcs.network.net3d.RelationalMHDPA import RelationalMHDPA
20 |
21 |
class FourConv(SubModule3D):
    """Four-layer convolutional encoder with optional residual self-attention.

    The layers are a 7x7 conv followed by three 3x3 convs. Every conv has
    stride 2, padding 1 and 32 output channels, and is followed by a norm
    layer and ReLU. When ``args.use_mhra`` is truthy, ``RelationalMHDPA`` is
    applied across the spatial positions of the last feature map and added
    back residually.
    """

    args = {"fourconv_norm": "bn"}

    def __init__(self, in_shape, id, normalize, args):
        """
        :param in_shape: ``(channels, height, width)`` of the input.
        :param id: submodule id forwarded to ``SubModule3D``.
        :param normalize: ``"bn"`` for BatchNorm, ``"gn"`` for GroupNorm with
            8 groups; any other value selects no normalization.
        :param args: config object; reads ``use_mhra`` and, when that is set,
            ``num_head``.
        """
        super().__init__(in_shape, id)
        # NOTE(review): bias is disabled for ANY truthy `normalize`, including
        # strings that fall through to the Identity (no-norm) branch; confirm.
        bias = not normalize
        self._in_shape = in_shape
        self._out_shape = None
        self._args = args
        self.conv1 = Conv2d(in_shape[0], 32, 7, stride=2, padding=1, bias=bias)
        self.conv2 = Conv2d(32, 32, 3, stride=2, padding=1, bias=bias)
        self.conv3 = Conv2d(32, 32, 3, stride=2, padding=1, bias=bias)
        self.conv4 = Conv2d(32, 32, 3, stride=2, padding=1, bias=bias)

        self.use_mhra = args.use_mhra
        if normalize == "bn":
            self.bn1 = BatchNorm2d(32)
            self.bn2 = BatchNorm2d(32)
            self.bn3 = BatchNorm2d(32)
            self.bn4 = BatchNorm2d(32)
        elif normalize == "gn":
            self.bn1 = GroupNorm(8, 32)
            self.bn2 = GroupNorm(8, 32)
            self.bn3 = GroupNorm(8, 32)
            self.bn4 = GroupNorm(8, 32)
        else:
            self.bn1 = Identity()
            self.bn2 = Identity()
            self.bn3 = Identity()
            self.bn4 = Identity()
        if self.use_mhra:
            # Size the attention (and its causal mask) from the real conv output
            # instead of assuming an 84x84 input (which gives 32x5x5).
            self.att = RelationalMHDPA(input_shape=self._output_shape, nb_head=args.num_head)

        # Scale the conv weights by the ReLU gain.
        relu_gain = init.calculate_gain("relu")
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            conv.weight.data.mul_(relu_gain)

    @classmethod
    def from_args(cls, args, in_shape, id):
        # NOTE(review): the declared option is `fourconv_norm`, but this reads
        # `linear_normalize`; confirm which config key is intended.
        return cls(in_shape, id, args.linear_normalize, args)

    @property
    def _output_shape(self):
        """``(32, out_height, out_width)``; e.g. 84x84 -> 40 -> 20 -> 10 -> 5, i.e. (32, 5, 5)."""
        if self._out_shape is None:
            # Track height and width separately so non-square inputs are correct.
            height, width = self._in_shape[1], self._in_shape[2]
            for kernel_size in (7, 3, 3, 3):
                height = calc_output_dim(height, kernel_size, 2, 1, 1)
                width = calc_output_dim(width, kernel_size, 2, 1, 1)
            self._out_shape = 32, height, width
        return self._out_shape

    def _forward(self, xs, internals, **kwargs):
        xs = F.relu(self.bn1(self.conv1(xs)))
        xs = F.relu(self.bn2(self.conv2(xs)))
        xs = F.relu(self.bn3(self.conv3(xs)))
        xs = F.relu(self.bn4(self.conv4(xs)))
        if self.use_mhra:
            # Residual attention across spatial positions.
            xs = xs + self._atten(xs, self.att)
        return xs, {}

    def _new_internals(self):
        return {}

    def _atten(self, x, att):
        """Apply sequence attention ``att`` to a ``(B, C, H, W)`` map; returns the same shape."""
        channels, height, width = x.shape[-3:]
        # (B, C, H, W) -> (B, H*W, C): positions become the sequence axis.
        h = x.view(-1, channels, height * width).permute(0, 2, 1).contiguous()
        h = att(h)
        h = h.permute(0, 2, 1).contiguous()
        return h.view(-1, channels, height, width)
97 |
98 |
def calc_output_dim(dim_size, kernel_size, stride, padding, dilation):
    """Return the output length of a ``Conv2d`` along one spatial axis.

    Uses PyTorch's formula::

        floor((dim_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1
    """
    effective_kernel = dilation * (kernel_size - 1) + 1
    span = dim_size + 2 * padding - effective_kernel
    return span // stride + 1
102 |
103 |
if __name__ == "__main__":
    # Trace an 84x84 input through FourConv's layers; print the size after
    # conv3 (10) and after conv4 (should be 5).
    size = 84
    for step, (kernel, stride) in enumerate(((7, 2), (3, 2), (3, 2), (3, 2))):
        size = calc_output_dim(size, kernel, stride, 1, 1)
        if step >= 2:
            print(size)
112 |
--------------------------------------------------------------------------------
/mcs/network/net3d/identity_3d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from .submodule_3d import SubModule3D
16 |
17 |
class Identity3D(SubModule3D):
    """Pass-through 3-D submodule: output equals input, with no internals."""

    args = {}

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)

    @classmethod
    def from_args(cls, args, input_shape, id):
        # Nothing to configure, so `args` is ignored.
        return cls(input_shape, id)

    @property
    def _output_shape(self):
        return self.input_shape  # identity: shape is preserved

    def _forward(self, input, internals, **kwargs):
        result = input
        return result, {}

    def _new_internals(self):
        return dict()
37 |
--------------------------------------------------------------------------------
/mcs/network/net3d/rmc.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import torch
16 | from torch.nn import Conv2d, Linear, BatchNorm2d, BatchNorm1d, functional as F
17 |
18 | from mcs.modules import RMCCell, Identity
19 |
20 | from mcs.network.net3d.submodule_3d import SubModule3D
21 |
# TODO: work in progress -- see the NOTE(review) comments below.
class RMC(SubModule3D):
    """
    Relational Memory Core
    https://arxiv.org/pdf/1806.01822.pdf

    The pipeline is:
    1. three stride-2 3x3 convs (32 channels each);
    2. two appended coordinate channels;
    3. an ``RMCCell`` step over the flattened spatial positions;
    4. a stride-1 conv down to 8 channels;
    5. a 512-unit linear layer.
    """

    def __init__(self, nb_in_chan, output_shape_dict, normalize):
        """
        :param nb_in_chan: number of input channels for ``conv1``.
        :param output_shape_dict: forwarded to ``SubModule3D.__init__`` as its
            ``id`` argument.
        :param normalize: truthy -> BatchNorm layers and bias-free layers;
            falsy -> Identity norms and biased layers.
        """
        self.embedding_size = 512
        # NOTE(review): 512 is passed as SubModule3D's input_shape; confirm.
        super(RMC, self).__init__(self.embedding_size, output_shape_dict)
        bias = not normalize
        self.conv1 = Conv2d(nb_in_chan, 32, kernel_size=3, stride=2, padding=1, bias=bias)
        self.conv2 = Conv2d(32, 32, kernel_size=3, stride=2, padding=1, bias=bias)
        self.conv3 = Conv2d(32, 32, kernel_size=3, stride=2, padding=1, bias=bias)
        # 34 = 32 conv channels + the 2 coordinate channels added in forward().
        # NOTE(review): the two 100s presumably match the 10x10 = 100 grid
        # positions; confirm against RMCCell's signature.
        self.attention = RMCCell(100, 100, 34)
        self.conv4 = Conv2d(34, 8, kernel_size=3, stride=1, padding=1, bias=bias)
        # BATCH x 8 x 10 x 10 -> 800 features
        self.linear = Linear(800, 512, bias=bias)

        if normalize:
            self.bn1 = BatchNorm2d(32)
            self.bn2 = BatchNorm2d(32)
            self.bn3 = BatchNorm2d(32)
            self.bn4 = BatchNorm2d(8)
            self.bn_linear = BatchNorm1d(512)
        else:
            self.bn1 = Identity()
            self.bn2 = Identity()
            self.bn3 = Identity()
            # bn4 is used unconditionally after conv4; it was previously only
            # defined in the normalize branch.
            self.bn4 = Identity()
            self.bn_linear = Identity()

    def forward(self, input, prev_memories):
        """
        :param input: Tensor{B, C, H, W}
        :param prev_memories: Tuple{B}[Tensor{C}]
        :return: ``(features, next_memories)``. ``features`` has shape
            ``(B, 512)``; ``next_memories`` is the new memory tensor unbound
            along dim 0 into a list.
        """
        x = F.relu(self.bn1(self.conv1(input)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        batch, h, w = x.size(0), x.size(2), x.size(3)
        # Append x/y coordinate channels spanning [-1, 1] over the (h, w) grid.
        # These were previously expanded to (w, w) and (h, h), which only
        # worked for square feature maps; dtype now follows the activations.
        xs_chan = (
            torch.linspace(-1, 1, w, device=x.device, dtype=x.dtype)
            .view(1, 1, 1, w)
            .expand(batch, 1, h, w)
        )
        ys_chan = (
            torch.linspace(-1, 1, h, device=x.device, dtype=x.dtype)
            .view(1, 1, h, 1)
            .expand(batch, 1, h, w)
        )
        x = torch.cat([x, xs_chan, ys_chan], dim=1)

        # need to transpose because attention expects
        # attention dim before channel dim
        x = x.view(x.size(0), x.size(1), h * w).transpose(1, 2)
        prev_memories = torch.stack(prev_memories)
        x = next_memories = self.attention(x.contiguous(), prev_memories)
        # need to undo the transpose before output
        x = x.transpose(1, 2)
        x = x.view(x.size(0), x.size(1), h, w)

        x = F.relu(self.bn4(self.conv4(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.bn_linear(self.linear(x)))
        return x, list(torch.unbind(next_memories, 0))

    @classmethod
    def from_args(cls, args, in_shape, id):
        # NOTE(review): the full in_shape is passed as nb_in_chan, but conv1
        # needs an int channel count; in_shape[0] was likely intended.
        return cls(in_shape, id, args.fourconv_norm)

    @property
    def _output_shape(self):
        # NOTE(review): _in_shape and _out_shape are never set in __init__, so
        # this raises AttributeError as written. The arithmetic gives a side of
        # 11 for 84 (10 for 80), not (32, 5, 5) as an older comment claimed,
        # and it reports 32 channels although conv4 produces 8.
        if self._out_shape is None:
            output_dim = calc_output_dim(self._in_shape[1], 3, 2, 1, 1)
            output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
            output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
            output_dim = calc_output_dim(output_dim, 3, 1, 1, 1)
            self._out_shape = 32, output_dim, output_dim
        return self._out_shape

    def _forward(self, xs, internals, **kwargs):
        # NOTE(review): conv4 expects 34 input channels (conv output plus the
        # coordinate channels added in forward()), but conv3 yields 32 here,
        # so this path fails. It also skips the memory and the linear layer.
        xs = F.relu(self.bn1(self.conv1(xs)))
        xs = F.relu(self.bn2(self.conv2(xs)))
        xs = F.relu(self.bn3(self.conv3(xs)))
        xs = F.relu(self.bn4(self.conv4(xs)))
        return xs, {}

    def _new_internals(self):
        return {}
127 |
128 |
def calc_output_dim(dim_size, kernel_size, stride, padding, dilation):
    """Return the output length of a ``Conv2d`` along one axis (PyTorch floor formula)."""
    padded = dim_size + 2 * padding
    receptive = dilation * (kernel_size - 1)
    return (padded - receptive - 1) // stride + 1
132 |
if __name__ == "__main__":
    # Trace an 84-pixel side through RMC's four convs: 84 -> 42 -> 21 -> 11 -> 11.
    output_dim = 84
    for kernel_size, stride in ((3, 2), (3, 2), (3, 2), (3, 1)):
        output_dim = calc_output_dim(output_dim, kernel_size, stride, 1, 1)
    print(output_dim)  # 11 (not 5: the last conv has stride 1)
140 |
141 |
--------------------------------------------------------------------------------
/mcs/network/net3d/submodule_3d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from ..base.submodule import SubModule
16 | import abc
17 |
18 |
class SubModule3D(SubModule, metaclass=abc.ABCMeta):
    """Base class for submodules whose output shape is ``(channels, height, width)``."""

    dim = 3

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)

    def _to_1d_shape(self):
        channels, height, width = self._output_shape
        return (channels * height * width,)

    def _to_2d_shape(self):
        # Keep channels; flatten the spatial grid.
        channels, height, width = self._output_shape
        return (channels, height * width)

    def _to_3d_shape(self):
        return self._output_shape

    def _to_4d_shape(self):
        # Insert a singleton depth axis after channels.
        channels, height, width = self._output_shape
        return (channels, 1, height, width)
39 |
--------------------------------------------------------------------------------
/mcs/network/net4d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__init__.py
--------------------------------------------------------------------------------
/mcs/network/net4d/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net4d/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net4d/__pycache__/identity_4d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/identity_4d.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net4d/__pycache__/identity_4d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/identity_4d.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net4d/__pycache__/submodule_4d.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/submodule_4d.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/network/net4d/__pycache__/submodule_4d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/submodule_4d.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/network/net4d/identity_4d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from .submodule_4d import SubModule4D
16 |
17 |
class Identity4D(SubModule4D):
    """Pass-through 4-D submodule that returns its input untouched."""

    args = {}

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)

    @classmethod
    def from_args(cls, args, input_shape, id):
        # No options to read from `args`.
        return cls(input_shape, id)

    @property
    def _output_shape(self):
        # Same as the input shape, since nothing is transformed.
        return self.input_shape

    def _forward(self, input, internals, **kwargs):
        no_internals = {}
        return input, no_internals

    def _new_internals(self):
        return dict()
37 |
--------------------------------------------------------------------------------
/mcs/network/net4d/submodule_4d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from ..base.submodule import SubModule
16 | import abc
17 |
18 |
class SubModule4D(SubModule, metaclass=abc.ABCMeta):
    """
    Base for submodules whose output has four dimensions.

    ``self._output_shape`` (defined by concrete subclasses, e.g. as a
    property) must unpack into exactly four sizes, referred to here as
    (f, d, h, w). The ``_to_*d_shape`` helpers regroup those four sizes into
    lower-rank shapes with the same total element count.
    """

    dim = 4

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)

    def _to_1d_shape(self):
        # Merge all four axes into one.
        n_f, n_d, n_h, n_w = self._output_shape
        flat = n_f * n_d * n_h * n_w
        return (flat,)

    def _to_2d_shape(self):
        # Keep the first axis; merge the remaining three.
        n_f, n_d, n_h, n_w = self._output_shape
        return n_f, n_d * n_h * n_w

    def _to_3d_shape(self):
        # Merge the first two axes; keep the last two.
        n_f, n_d, n_h, n_w = self._output_shape
        return n_f * n_d, n_h, n_w

    def _to_4d_shape(self):
        # Already four-dimensional: report the shape unchanged.
        return self._output_shape
39 |
--------------------------------------------------------------------------------
/mcs/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 Heron Systems, Inc.
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <http://www.gnu.org/licenses/>.
16 | """
17 |
--------------------------------------------------------------------------------
/mcs/preprocess/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/preprocess/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/preprocess/__pycache__/observation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/observation.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/preprocess/__pycache__/observation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/observation.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/preprocess/__pycache__/ops.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/ops.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/preprocess/__pycache__/ops.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/ops.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/preprocess/observation.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from copy import deepcopy
16 |
17 |
class ObsPreprocessor:
    """
    Apply a sequence of operations to observations and track the resulting
    observation shapes (and, optionally, dtypes).
    """

    def __init__(self, ops, observation_space, observation_dtypes=None):
        """
        Derive the post-processing observation spec by running every op's
        shape (and dtype) update in order.

        Each op selects the observation names it affects: ``op.name_filters``
        if non-empty, else every name whose *current* rank is in
        ``op.rank_filters``, else all names.

        :param ops: List[Operation]
        :param observation_space: Dict[ObsKey, Shape] (not modified)
        :param observation_dtypes: Dict[ObsKey, dtype_str] (not modified)
        """
        cur_space = deepcopy(observation_space)
        cur_dtypes = deepcopy(observation_dtypes)

        rank_to_names = self._bld_rank_to_names(cur_space)

        for op in ops:
            if op.name_filters:
                names = op.name_filters
            elif op.rank_filters:
                names = []
                for rank in op.rank_filters:
                    names += rank_to_names.get(rank, [])
            else:
                names = list(cur_space.keys())

            cur_space = self._update(names, cur_space, op.update_shape)
            if observation_dtypes:
                cur_dtypes = self._update(names, cur_dtypes, op.update_dtype)
            # Rebuild from the *updated* space so that later rank filters see
            # the shapes produced by earlier ops (e.g. after a flatten).
            rank_to_names = self._bld_rank_to_names(cur_space)

        self.ops = ops
        self.observation_space = cur_space
        self.observation_dtypes = cur_dtypes
        self.rank_to_names = rank_to_names

    def __call__(self, obs):
        """Run ``obs`` through every op's ``update_obs`` in order."""
        for op in self.ops:
            obs = op.update_obs(obs)
        return obs

    def reset(self):
        """Reset every op."""
        for o in self.ops:
            o.reset()

    def _bld_rank_to_names(self, obs_space):
        """
        Group observation names by rank (``len(shape)``).

        Ranks 1-4 are always present (possibly empty). Any other rank, e.g.
        0 for scalar shapes ``()``, is added on demand instead of raising
        KeyError.
        """
        d = {1: [], 2: [], 3: [], 4: []}
        for name, shape in obs_space.items():
            d.setdefault(len(shape), []).append(name)
        return d

    def _update(self, names, prev, fn):
        """
        Return a new dict: entries of ``prev`` not in ``names`` unchanged,
        merged with ``fn`` applied to the selected entries.

        ``prev`` itself is not mutated.
        """
        selected = {name: prev[name] for name in names}
        untouched = {k: v for k, v in prev.items() if k not in selected}
        update = fn(selected)
        return {**untouched, **update}
72 |
--------------------------------------------------------------------------------
/mcs/registry/__init__.py:
--------------------------------------------------------------------------------
1 | from .registry import Registry
2 |
# Package-level Registry instance, created once when this package is imported.
REGISTRY = Registry()
4 |
--------------------------------------------------------------------------------
/mcs/registry/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/registry/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/registry/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/registry/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/registry/__pycache__/registry.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/registry/__pycache__/registry.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/registry/__pycache__/registry.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/registry/__pycache__/registry.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/rewardnorm/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import RewardNormModule
2 | from .normalizers import Scale, Clip, Identity
3 |
# List of the reward-normalizer classes this package exposes; presumably
# consumed by a registry elsewhere — TODO confirm against callers.
REWARD_NORM_REG = [Scale, Clip, Identity]
5 |
--------------------------------------------------------------------------------
/mcs/rewardnorm/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/rewardnorm/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/rewardnorm/__pycache__/normalizers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/__pycache__/normalizers.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/rewardnorm/__pycache__/normalizers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/__pycache__/normalizers.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/rewardnorm/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .rewnorm_module import RewardNormModule
2 |
--------------------------------------------------------------------------------
/mcs/rewardnorm/base/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/base/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/rewardnorm/base/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/base/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/rewardnorm/base/__pycache__/rewnorm_module.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/base/__pycache__/rewnorm_module.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/rewardnorm/base/__pycache__/rewnorm_module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/base/__pycache__/rewnorm_module.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/rewardnorm/base/rewnorm_module.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | from mcs.utils.requires_args import RequiresArgsMixin
4 |
5 |
class RewardNormModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
    """Base class for callables that transform (normalize) reward tensors."""

    def __call__(self, reward):
        """
        Normalize a reward tensor.

        :param reward: torch.Tensor (1D)
        :return: the normalized reward; the transform is subclass-defined.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
15 |
--------------------------------------------------------------------------------
/mcs/rewardnorm/normalizers.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import torch
16 | from .base import RewardNormModule
17 |
18 |
class Clip(RewardNormModule):
    """Clamp rewards element-wise into the closed range [floor, ceil]."""

    args = {"floor": -1, "ceil": 1}

    def __init__(self, floor, ceil):
        self.floor, self.ceil = floor, ceil

    @classmethod
    def from_args(cls, args):
        return cls(floor=args.floor, ceil=args.ceil)

    def __call__(self, reward):
        return torch.clamp(reward, min=self.floor, max=self.ceil)
32 |
33 |
class Scale(RewardNormModule):
    """Multiply rewards by a constant ``coefficient``."""

    args = {"coefficient": 0.1}

    def __init__(self, coefficient):
        self.coefficient = coefficient

    @classmethod
    def from_args(cls, args):
        coefficient = args.coefficient
        return cls(coefficient)

    def __call__(self, reward):
        scaled = self.coefficient * reward
        return scaled
46 |
47 |
class Identity(RewardNormModule):
    """No-op normalizer: returns its input unchanged."""

    args = {}

    @classmethod
    def from_args(cls, args):
        # Nothing is configurable; ``args`` is accepted for interface parity.
        instance = cls()
        return instance

    def __call__(self, item):
        return item
57 |
--------------------------------------------------------------------------------
/mcs/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | from .util import listd_to_dlist, dlist_to_listd, dtensor_to_dev
16 |
--------------------------------------------------------------------------------
/mcs/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/utils/__pycache__/logging.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/logging.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/utils/__pycache__/logging.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/logging.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/utils/__pycache__/requires_args.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/requires_args.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/utils/__pycache__/requires_args.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/requires_args.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/utils/__pycache__/script_helpers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/script_helpers.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/utils/__pycache__/script_helpers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/script_helpers.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/utils/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/mcs/utils/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/mcs/utils/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import os
16 | from collections import namedtuple
17 | from time import time
18 |
19 | import torch
20 |
21 | from mcs.utils.util import HeapQueue
22 |
23 |
class ModelSaver:
    """
    Buffer the ``nb_top_model`` best (network, optimizer) checkpoints by
    reward and write them to ``<log_id_dir>/<epoch_id>/`` on demand.
    """

    # Entries compare as tuples: by ``reward`` first, then ``priority``.
    # ``priority`` must never tie, otherwise the comparison falls through to
    # the state dicts, which are not orderable.
    BufferEntry = namedtuple(
        "BufferEntry", ["reward", "priority", "network", "optimizer"]
    )

    def __init__(self, nb_top_model, log_id_dir):
        self.nb_top_model = nb_top_model
        self._buffer = HeapQueue(nb_top_model)
        self._log_id_dir = log_id_dir
        # Strictly increasing push counter used to break ``time()`` ties.
        self._n_pushed = 0

    def append_if_better(self, reward, network, optimizer):
        """
        Offer a checkpoint to the buffer.

        The buffer keeps the entries that compare largest. Once it is full,
        the smallest entry (possibly this one) is dropped. Among equal
        rewards, the more recently pushed entry ranks higher.
        """
        import copy

        self._n_pushed += 1
        self._buffer.push(
            self.BufferEntry(
                reward,
                (time(), self._n_pushed),
                # ``state_dict()`` returns references to the live tensors;
                # without a copy, every buffered entry would end up holding
                # the weights as of write time rather than as of this call.
                copy.deepcopy(network.state_dict()),
                copy.deepcopy(optimizer.state_dict()),
            )
        )

    def write_state_dicts(self, epoch_id):
        """
        Save and clear every buffered entry under ``<log_id_dir>/<epoch_id>/``.

        Files are named ``model_{j}_{int(reward)}.pth`` and
        ``optimizer_{j}_{int(reward)}.pth``. ``j`` is the 1-based position in
        the flushed heap list, so ``j == 1`` has the lowest reward and later
        positions are not sorted. When the buffer is empty, nothing is
        written and no directory is created. ``os.makedirs`` raises if the
        directory already exists.
        """
        if len(self._buffer) == 0:
            return
        save_dir = os.path.join(self._log_id_dir, str(epoch_id))
        os.makedirs(save_dir)
        for j, buff_entry in enumerate(self._buffer.flush(), start=1):
            reward_tag = int(buff_entry.reward)
            torch.save(
                buff_entry.network,
                os.path.join(
                    save_dir, "model_{}_{}.pth".format(j, reward_tag)
                ),
            )
            torch.save(
                buff_entry.optimizer,
                os.path.join(
                    save_dir, "optimizer_{}_{}.pth".format(j, reward_tag)
                ),
            )
60 |
61 |
class SimpleModelSaver:
    """
    Write a network's state dict (and optionally an optimizer's) into a fresh
    ``<log_id_dir>/<step_count>/`` directory.
    """

    def __init__(self, log_id_dir):
        self._log_id_dir = log_id_dir

    def save_state_dicts(self, network, step_count, optimizer=None):
        """
        Save ``model_<step>.pth`` and, if an optimizer is given,
        ``optimizer_<step>.pth``. ``os.makedirs`` raises if the target
        directory already exists.
        """
        target_dir = os.path.join(self._log_id_dir, str(step_count))
        os.makedirs(target_dir)

        to_save = [(network, "model_{}.pth")]
        if optimizer is not None:
            to_save.append((optimizer, "optimizer_{}.pth"))

        for module, filename_fmt in to_save:
            torch.save(
                module.state_dict(),
                os.path.join(target_dir, filename_fmt.format(step_count)),
            )
78 |
--------------------------------------------------------------------------------
/mcs/utils/requires_args.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import json
16 | import abc
17 |
18 |
class RequiresArgsMixin(metaclass=abc.ABCMeta):
    """
    This mixin makes it so that subclasses must implement an ``args`` class
    attribute (a dict of default settings). These arguments are parsed at
    runtime and the user is offered a chance to change any desired args.
    Classes that use this mixin must implement the ``from_args()`` class
    method, which is essentially a secondary constructor.
    """

    args = None  # Dict[str, Any]; subclasses must override with a dict.

    @classmethod
    def check_args_implemented(cls):
        """
        :raises NotImplementedError: if the subclass did not define ``args``.
        """
        if cls.args is None:
            raise NotImplementedError(
                'Subclass must define class attribute "args"'
            )

    @classmethod
    def prompt(cls, provided=None):
        """
        Display defaults as JSON, prompt user for changes.

        :param provided: Dict[str, Any], Override default prompts. Keys not
            present in ``cls.args`` are ignored.
        :return: Dict[str, Any] Updated config dictionary. It is always a new
            dict, so callers may mutate it without altering ``cls.args``.
        :raises NotImplementedError: if ``cls.args`` was never defined.
        """
        cls.check_args_implemented()
        # Work on a copy so that neither overrides nor later caller mutations
        # leak into the shared class-level defaults.
        args = dict(cls.args)
        if provided is not None:
            for key, value in provided.items():
                if key in cls.args:
                    args[key] = value
        return cls._prompt(cls.__name__, args)

    @staticmethod
    def _prompt(name, args):
        """
        Display defaults as JSON, prompt user for changes.

        :param name: str Name of class
        :param args: Dict[str, Any]
        :return: Dict[str, Any] Updated config dictionary. ``args`` is
            returned as-is when it is empty or the user just presses ENTER.
        :raises json.JSONDecodeError: if the user enters invalid JSON.
        """
        if not args:
            return args

        user_input = input(
            "\n{} Defaults:\n{}\n"
            "Press ENTER to use defaults. Otherwise, "
            "modify JSON keys then press ENTER.\n".format(
                name, json.dumps(args, indent=2, sort_keys=True)
            )
            # JSON booleans are lowercase; ``True`` would fail json.loads.
            + 'Example: {"x": true, "gamma": 0.001}\n'
        )

        # use defaults if no changes specified
        if user_input == "":
            return args

        updates = json.loads(user_input)
        return {**args, **updates}

    @classmethod
    @abc.abstractmethod
    def from_args(cls, *args, **kwargs):
        """Secondary constructor: build an instance from a config mapping."""
        raise NotImplementedError
84 |
--------------------------------------------------------------------------------
/mcs/utils/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2018 Heron Systems, Inc.
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 | import json
16 | import heapq
17 | from collections import OrderedDict
18 |
19 | import numpy as np
20 | import torch
21 |
22 |
def listd_to_dlist(list_of_dicts):
    """
    Transpose a list of dicts into a dict of lists, keeping first-seen key
    order.

    A key missing from some dicts simply gets a shorter list.

    K is type of key.
    V is type of value.
    :param list_of_dicts: List[Dict[K, V]]
    :return: OrderedDict[K, List[V]]
    """
    merged = OrderedDict()
    for entry in list_of_dicts:
        for key, value in entry.items():
            merged.setdefault(key, []).append(value)
    return merged
41 |
42 |
def dlist_to_listd(dict_of_lists):
    """
    Converts a dictionary of lists to a list of dictionaries. Preserves key
    order.

    The output length is the length of the first key's list. Every other
    list must be at least that long; a shorter one raises IndexError.

    K is type of key.
    V is type of value.
    :param dict_of_lists: Dict[K, List[V]]
    :return: List[OrderedDict[K, V]] (empty list for an empty input dict)
    """
    if not dict_of_lists:
        # Nothing to transpose; the original ``next(iter(...))`` would raise
        # StopIteration here.
        return []
    keys = list(dict_of_lists.keys())
    list_len = len(dict_of_lists[keys[0]])
    return [
        OrderedDict((k, dict_of_lists[k][i]) for k in keys)
        for i in range(list_len)
    ]
62 |
63 |
def dtensor_to_dev(d_tensor, device):
    """
    Return a new dict with every tensor moved to ``device`` via ``.to``.

    :param d_tensor: Dict[str, Tensor]
    :param device: torch.device
    :return: Dict[str, Tensor], same keys in the same order.
    """
    moved = {}
    for name, tensor in d_tensor.items():
        moved[name] = tensor.to(device)
    return moved
73 |
74 |
def json_to_dict(file_path):
    """
    Read a JSON config file and return the parsed object.

    The file is decoded as UTF-8 (the encoding JSON requires) and closed
    before returning.

    :param file_path: path to the JSON file
    :return: the decoded JSON value (a dict for config files)
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)
79 |
80 |
# Bidirectional numpy scalar type <-> torch dtype mapping.
_numpy_to_torch_dtype = {
    np.bool_: torch.bool,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
}
_torch_to_numpy_dtype = {v: k for k, v in _numpy_to_torch_dtype.items()}


def numpy_to_torch_dtype(dtype):
    """
    Map a numpy dtype to the equivalent torch dtype.

    :param dtype: a numpy scalar type (e.g. ``np.float32``) or a
        ``np.dtype`` instance (e.g. ``arr.dtype``)
    :return: torch.dtype
    :raises ValueError: if there is no torch equivalent in the table
    """
    # Normalize dtype instances to their scalar type. On numpy >= 1.20,
    # dtype instances belong to subclasses of ``np.dtype``, so an exact
    # ``type(...) == np.dtype`` check would miss them.
    if isinstance(dtype, np.dtype):
        dtype = dtype.type

    if dtype not in _numpy_to_torch_dtype:
        raise ValueError(
            "Could not convert numpy dtype {} to a torch dtype.".format(
                dtype
            )
        )

    return _numpy_to_torch_dtype[dtype]


def torch_to_numpy_dtype(dtype):
    """
    Map a torch dtype to the equivalent numpy scalar type.

    :param dtype: torch.dtype
    :return: numpy scalar type (e.g. ``np.float32``)
    :raises ValueError: if there is no numpy equivalent in the table
    """
    if dtype not in _torch_to_numpy_dtype:
        raise ValueError(
            "Could not convert torch dtype {} to a numpy dtype.".format(
                dtype
            )
        )

    return _torch_to_numpy_dtype[dtype]
119 |
120 |
class CircularBuffer(object):
    """
    Fixed-capacity list that overwrites its oldest slot once full.

    Indexing (``buf[i]``) addresses the underlying storage directly, not
    insertion order. ``index`` is the slot the next ``append`` will write
    once the buffer is full.
    """

    def __init__(self, size):
        self.index = 0
        self.size = size
        self._data = []

    def append(self, value):
        if len(self._data) != self.size:
            # Still filling up: grow the backing list.
            self._data.append(value)
        else:
            # Full: overwrite the slot at the write cursor.
            self._data[self.index] = value
        self.index = (self.index + 1) % self.size

    def is_empty(self):
        return not self._data

    def not_empty(self):
        return not self.is_empty()

    def is_full(self):
        return len(self) == self.size

    def not_full(self):
        return not self.is_full()

    def __getitem__(self, key):
        """Return the element(s) at ``key`` in storage order."""
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __repr__(self):
        """Show the stored items followed by their count."""
        return "{!r} ({} items)".format(self._data, len(self._data))

    def __len__(self):
        return len(self._data)
159 |
160 |
class HeapQueue:
    """
    Bounded min-heap that retains the ``maxlen`` largest items pushed.

    ``q`` is a ``heapq``-managed list, so ``q[0]`` is always the smallest
    retained item.
    """

    def __init__(self, maxlen):
        self.q = []
        self.maxlen = maxlen

    def push(self, item):
        if len(self.q) >= self.maxlen:
            # Full: insert, then evict the smallest (possibly ``item`` itself).
            heapq.heappushpop(self.q, item)
            return
        heapq.heappush(self.q, item)

    def flush(self):
        """Return the retained items in heap order and empty the queue."""
        drained, self.q = self.q, []
        return drained

    def __len__(self):
        return len(self.q)
179 |
180 |
class DotDict(dict):
    """
    Dictionary whose keys can also be read, set, and deleted as attributes.

    Reading a missing attribute returns ``None`` (``dict.get`` semantics)
    rather than raising ``AttributeError``. Deleting a missing attribute
    raises ``KeyError``.
    """

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    # Support pickling
    def __getstate__(self):
        return dict(self.items())

    def __setstate__(self, attributes):
        # Restore in place: pickle/copy call this on the already-created
        # instance and ignore any return value, so the state must be
        # applied to ``self`` rather than to a new object.
        self.update(attributes)
196 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | aiohttp==3.7.0
3 | aiohttp-cors==0.7.0
4 | aioredis==1.3.1
5 | aiosignal==1.2.0
6 | ale-py==0.7.4
7 | argon2-cffi==21.3.0
8 | argon2-cffi-bindings==21.2.0
9 | asttokens==2.0.5
10 | astunparse==1.6.3
11 | async-generator==1.10
12 | async-timeout==3.0.1
13 | atari-py==0.2.6
14 | attrs==21.2.0
15 | autopep8==1.6.0
16 | backcall==0.2.0
17 | backports==1.0
18 | backports.functools_lru_cache==1.6.4
19 | beautifulsoup4==4.10.0
20 | bleach==4.1.0
21 | blessings==1.7
22 | bokeh==2.4.2
23 | box2d==2.3.10
24 | branca==0.4.2
25 | ca-certificates==2021.10.8
26 | cachetools==4.2.1
27 | certifi==2021.10.8
28 | cffi==1.14.0
29 | chardet==3.0.4
30 | charset-normalizer==2.0.9
31 | click==8.0.3
32 | click-plugins==1.1.1
33 | cligj==0.7.2
34 | cloudpickle==1.2.0
35 | colorama==0.4.4
36 | colorcet==3.0.0
37 | cryptography==36.0.2
38 | cvxpy==1.2.0
39 | cycler==0.11.0
40 | debugpy==1.5.1
41 | decorator==5.1.0
42 | deepdiff==5.8.0
43 | defusedxml==0.6.0
44 | deprecated==1.2.13
45 | docker-pycreds==0.4.0
46 | docopt==0.6.2
47 | drain3==0.9.10
48 | ecos==2.0.10
49 | entrypoints==0.3
50 | enum34==1.1.10
51 | executing==0.8.3
52 | filelock==3.4.0
53 | fiona==1.8.21
54 | flatbuffers==2.0
55 | folium==0.12.1.post1
56 | fonttools==4.29.1
57 | frozenlist==1.2.0
58 | future==0.18.2
59 | gast==0.4.0
60 | gitdb==4.0.9
61 | gitpython==3.1.27
62 | glfw==1.11.0
63 | google==3.0.0
64 | google-api-core==2.3.0
65 | google-auth==2.3.3
66 | google-auth-oauthlib==0.4.6
67 | google-pasta==0.2.0
68 | googleapis-common-protos==1.54.0
69 | gpustat==0.6.0
70 | grpcio==1.43.0rc1
71 | gym==0.23.1
72 | gym-notices==0.0.6
73 | gym-super-mario-bros==7.3.2
74 | h11==0.13.0
75 | h5py==3.6.0
76 | hiredis==2.0.0
77 | holoviews==1.14.9a1
78 | hvplot==0.8.0a10
79 | idna==3.3
80 | importlib-metadata==4.11.3
81 | importlib-resources==5.4.0
82 | ipykernel==6.6.0
83 | ipython==7.30.1
84 | ipython-genutils==0.2.0
85 | jedi==0.18.1
86 | jinja2==3.0.3
87 | joblib==1.1.0
88 | json5==0.9.6
89 | jsonpickle==1.5.1
90 | jsonschema==4.2.1
91 | jupyter-client==7.1.0
92 | jupyter-core==4.9.1
93 | jupyter_client==7.1.2
94 | jupyter_core==4.9.2
95 | jupyterlab-pygments==0.1.2
96 | jupyterlab-widgets==1.0.2
97 | keras==2.7.0
98 | keras-preprocessing==1.1.2
99 | kiwisolver==1.4.0
100 | kmeans-pytorch==0.3
101 | ld_impl_linux-64==2.36.1
102 | libclang==12.0.0
103 | libffi==3.4.2
104 | libgcc-ng==11.2.0
105 | libgomp==11.2.0
106 | libnsl==2.0.0
107 | libsodium==1.0.18
108 | libstdcxx-ng==11.2.0
109 | libzlib==1.2.11
110 | markdown==3.3.6
111 | markupsafe==2.0.1
112 | matplotlib==3.5.1
113 | matplotlib-inline==0.1.3
114 | mc-bin-client==1.0.1
115 | mistune==0.8.4
116 | mock==4.0.3
117 | mpyq==0.2.5
118 | msgpack==1.0.3
119 | msgpack-numpy==0.4.7.1
120 | multidict==5.2.0
121 | munch==2.5.0
122 | nbclient==0.5.9
123 | nbconvert==6.3.0
124 | nbformat==5.1.3
125 | ncurses==6.3
126 | nes-py==8.1.8
127 | nest-asyncio==1.5.4
128 | networkx==2.7.1
129 | notebook==6.4.6
130 | numpy==1.19.5
131 | numpy-stl==2.16.3
132 | nvidia-ml-py3==7.352.0
133 | oauthlib==3.1.1
134 | opencensus==0.8.0
135 | opencensus-context==0.1.2
136 | opencv-python==4.5.4.60
137 | openssl==3.0.0
138 | opt-einsum==3.3.0
139 | ordered-set==4.1.0
140 | osqp==0.6.2.post5
141 | outcome==1.1.0
142 | packaging==21.3
143 | pandas==1.3.5
144 | pandocfilters==1.5.0
145 | panel==0.12.6
146 | param==1.12.0
147 | parso==0.8.3
148 | pathtools==0.1.2
149 | pexpect==4.8.0
150 | pickleshare==0.7.5
151 | pillow==9.0.1
152 | pip==22.0.4
153 | portpicker==1.5.0
154 | prometheus-client==0.12.0
155 | promise==2.3
156 | prompt-toolkit==3.0.24
157 | protobuf==3.19.1
158 | psutil==5.8.0
159 | ptyprocess==0.7.0
160 | pure_eval==0.2.2
161 | py-spy==0.3.11
162 | pyasn1==0.4.8
163 | pyasn1-modules==0.2.8
164 | pybullet==2.7.2
165 | pycodestyle==2.8.0
166 | pycparser==2.21
167 | pyct==0.4.8
168 | pygame==2.1.2
169 | pyglet==1.5.11
170 | pygments==2.10.0
171 | pyopenssl==22.0.0
172 | pyparsing==3.0.7
173 | pyrsistent==0.18.0
174 | pysc2==3.0.0
175 | pysocks==1.7.1
176 | python==3.8.12
177 | python-dateutil==2.8.2
178 | python-tsp==0.2.1
179 | python-utils==3.1.0
180 | python_abi==3.8
181 | pytz==2021.3
182 | pyviz-comms==2.1.0
183 | pyyaml==6.0
184 | pyzmq==22.3.0
185 | qdldl==0.1.5.post2
186 | qtpy==1.11.3
187 | ray==1.3.0
188 | readline==8.1
189 | redis==4.0.2
190 | requests==2.26.0
191 | requests-oauthlib==1.3.0
192 | rsa==4.8
193 | s2clientprotocol==5.0.9.87702.0
194 | s2protocol==5.0.9.87702.0
195 | scikit-learn==1.0.2
196 | scipy==1.8.0
197 | scs==3.2.0
198 | seaborn==0.11.2
199 | selenium==4.1.3
200 | send2trash==1.8.0
201 | sentry-sdk==1.5.9
202 | setproctitle==1.2.2
203 | setuptools==59.5.0
204 | shapely==1.8.1
205 | shortuuid==1.0.8
206 | six==1.12.0
207 | sk-video==1.1.10
208 | sklearn==0.0
209 | smmap==5.0.0
210 | sniffio==1.2.0
211 | sortedcontainers==2.4.0
212 | soupsieve==2.3.1
213 | sqlite==3.37.1
214 | stack_data==0.2.0
215 | tabulate==0.8.9
216 | tap==0.2
217 | tensorboard==2.7.0
218 | tensorboard-data-server==0.6.1
219 | tensorboard-plugin-wit==1.8.0
220 | tensorboardx==2.5
221 | tensorflow==2.7.0
222 | tensorflow-estimator==2.7.0
223 | tensorflow-io-gcs-filesystem==0.22.0
224 | termcolor==1.1.0
225 | terminado==0.12.1
226 | testpath==0.5.0
227 | threadpoolctl==3.1.0
228 | tk==8.6.12
229 | toml==0.10.2
230 | torch==1.11.0+cu113
231 | torchaudio==0.11.0+cu113
232 | torchvision==0.12.0+cu113
233 | tornado==6.1
234 | tqdm==4.63.0
235 | traitlets==5.1.1
236 | trio==0.20.0
237 | trio-websocket==0.9.2
238 | tsplib95==0.7.1
239 | typing-extensions==4.0.1
240 | urllib3==1.26.7
241 | v==0.0.0
242 | wcwidth==0.2.5
243 | webencodings==0.5.1
244 | websocket-client==1.3.2
245 | werkzeug==2.0.2
246 | wheel==0.37.0
247 | whichcraft==0.6.1
248 | widgetsnbextension==3.5.2
249 | wrapt==1.13.3
250 | wsproto==1.1.0
251 | xz==5.2.5
252 | yaml==0.2.5
253 | yarl==1.7.2
254 | zeromq==4.3.4
255 | zipp==3.6.0
256 | zlib==1.2.11
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Packaging script for the ``mcs`` package.

Based on https://github.com/kennethreitz/setup.py/blob/master/setup.py
"""
from pathlib import Path

from setuptools import setup, find_packages

# NOTE(review): importing ``mcs.globals`` executes ``mcs/__init__.py`` at
# build time; if that package imports third-party dependencies, installation
# fails before they are installed. Confirm it stays import-light.
from mcs.globals import VERSION

# Resolve files relative to this script so the build works no matter which
# directory setup.py is invoked from.
HERE = Path(__file__).resolve().parent

# Read explicitly as UTF-8: the platform default encoding may not be UTF-8.
long_description = (HERE / "README.md").read_text(encoding="utf-8")

# Optional dependency groups, installable as ``pip install mcs[<group>]``.
extras = {
    "profiler": ["pyinstrument>=2.0"],
}
test_deps = ["pytest"]

# ``all`` pulls in every optional group plus the test dependencies. The
# comprehension is fully evaluated before the "all" key is inserted.
extras["all"] = [dep for deps in extras.values() for dep in deps] + test_deps

setup(
    name="mcs",
    version=VERSION,
    long_description=long_description,
    long_description_content_type="text/markdown",
    python_requires=">=3.5.0",
    packages=find_packages(),
    extras_require=extras,
)

--------------------------------------------------------------------------------