├── .DS_Store ├── README.md ├── mcs ├── .DS_Store ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── globals.cpython-37.pyc │ └── globals.cpython-38.pyc ├── actor │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── ac_eval.cpython-37.pyc │ │ ├── ac_eval.cpython-38.pyc │ │ ├── ac_rollout.cpython-37.pyc │ │ ├── ac_rollout.cpython-38.pyc │ │ ├── impala.cpython-37.pyc │ │ └── impala.cpython-38.pyc │ ├── ac_eval.py │ ├── ac_rollout.py │ ├── base │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── ac_helper.cpython-37.pyc │ │ │ ├── ac_helper.cpython-38.pyc │ │ │ ├── actor_module.cpython-37.pyc │ │ │ └── actor_module.cpython-38.pyc │ │ ├── ac_helper.py │ │ └── actor_module.py │ └── impala.py ├── agent │ ├── .DS_Store │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── actor_critic.cpython-37.pyc │ │ └── actor_critic.cpython-38.pyc │ ├── actor_critic.py │ └── base │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── agent_module.cpython-37.pyc │ │ └── agent_module.cpython-38.pyc │ │ └── agent_module.py ├── config │ └── default.json ├── container │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── distrib.cpython-37.pyc │ │ ├── distrib.cpython-38.pyc │ │ ├── evaluation.cpython-37.pyc │ │ ├── evaluation.cpython-38.pyc │ │ ├── init.cpython-37.pyc │ │ ├── init.cpython-38.pyc │ │ ├── local.cpython-37.pyc │ │ └── local.cpython-38.pyc │ ├── actorlearner │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── learner_container.cpython-37.pyc │ │ │ ├── learner_container.cpython-38.pyc │ │ │ ├── rollout_queuer.cpython-37.pyc │ │ │ ├── rollout_queuer.cpython-38.pyc │ │ │ 
├── rollout_worker.cpython-37.pyc │ │ │ └── rollout_worker.cpython-38.pyc │ │ ├── learner_container.py │ │ ├── rollout_queuer.py │ │ └── rollout_worker.py │ ├── base │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── container.cpython-37.pyc │ │ │ ├── container.cpython-38.pyc │ │ │ ├── nccl_optimizer.cpython-37.pyc │ │ │ ├── nccl_optimizer.cpython-38.pyc │ │ │ ├── updater.cpython-37.pyc │ │ │ └── updater.cpython-38.pyc │ │ ├── container.py │ │ ├── nccl_optimizer.py │ │ └── updater.py │ ├── distrib.py │ ├── evaluation.py │ ├── evaluation_thread.py │ ├── init.py │ ├── local.py │ └── render.py ├── env │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── _gym_wrappers.cpython-37.pyc │ │ ├── _gym_wrappers.cpython-38.pyc │ │ ├── _spaces.cpython-37.pyc │ │ ├── _spaces.cpython-38.pyc │ │ ├── openai_gym.cpython-37.pyc │ │ └── openai_gym.cpython-38.pyc │ ├── _gym_wrappers.py │ ├── _spaces.py │ ├── base │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── _env.cpython-37.pyc │ │ │ ├── _env.cpython-38.pyc │ │ │ ├── env_module.cpython-37.pyc │ │ │ └── env_module.cpython-38.pyc │ │ ├── _env.py │ │ └── env_module.py │ └── openai_gym.py ├── evaluate.py ├── exp │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── repetitive_buffer.cpython-37.pyc │ │ ├── repetitive_buffer.cpython-38.pyc │ │ ├── replay.cpython-37.pyc │ │ ├── replay.cpython-38.pyc │ │ ├── rollout.cpython-37.pyc │ │ └── rollout.cpython-38.pyc │ ├── base │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── exp_module.cpython-37.pyc │ │ │ ├── exp_module.cpython-38.pyc │ │ │ ├── spec_builder.cpython-37.pyc │ │ │ └── spec_builder.cpython-38.pyc │ │ ├── exp_module.py │ │ └── spec_builder.py │ ├── repetitive_buffer.py │ ├── 
replay.py │ └── rollout.py ├── globals.py ├── learner │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── ac_rollout.cpython-37.pyc │ │ ├── ac_rollout.cpython-38.pyc │ │ ├── impala.cpython-37.pyc │ │ └── impala.cpython-38.pyc │ ├── ac_rollout.py │ ├── base │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── dm_return_scale.cpython-37.pyc │ │ │ ├── dm_return_scale.cpython-38.pyc │ │ │ ├── learner_module.cpython-37.pyc │ │ │ └── learner_module.cpython-38.pyc │ │ ├── dm_return_scale.py │ │ └── learner_module.py │ └── impala.py ├── logs │ └── .DS_Store ├── manager │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── simple_env_manager.cpython-37.pyc │ │ ├── simple_env_manager.cpython-38.pyc │ │ ├── subproc_env_manager.cpython-37.pyc │ │ └── subproc_env_manager.cpython-38.pyc │ ├── base │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── manager_module.cpython-37.pyc │ │ │ └── manager_module.cpython-38.pyc │ │ └── manager_module.py │ ├── simple_env_manager.py │ └── subproc_env_manager.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── attention.cpython-37.pyc │ │ ├── attention.cpython-38.pyc │ │ ├── norm.cpython-37.pyc │ │ ├── norm.cpython-38.pyc │ │ ├── sequence.cpython-37.pyc │ │ ├── sequence.cpython-38.pyc │ │ ├── spatial.cpython-37.pyc │ │ └── spatial.cpython-38.pyc │ ├── attention.py │ ├── memory.py │ ├── mlp.py │ ├── norm.py │ ├── sequence.py │ └── spatial.py ├── network │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── modular_network.cpython-37.pyc │ │ └── modular_network.cpython-38.pyc │ ├── base │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── 
__init__.cpython-38.pyc │ │ │ ├── base.cpython-37.pyc │ │ │ ├── base.cpython-38.pyc │ │ │ ├── network_module.cpython-37.pyc │ │ │ ├── network_module.cpython-38.pyc │ │ │ ├── submodule.cpython-37.pyc │ │ │ └── submodule.cpython-38.pyc │ │ ├── base.py │ │ ├── network_module.py │ │ └── submodule.py │ ├── modular_network.py │ ├── net1d │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── identity_1d.cpython-37.pyc │ │ │ ├── identity_1d.cpython-38.pyc │ │ │ ├── linear.cpython-37.pyc │ │ │ ├── linear.cpython-38.pyc │ │ │ ├── lstm.cpython-37.pyc │ │ │ ├── lstm.cpython-38.pyc │ │ │ ├── submodule_1d.cpython-37.pyc │ │ │ └── submodule_1d.cpython-38.pyc │ │ ├── identity_1d.py │ │ ├── linear.py │ │ ├── lstm.py │ │ └── submodule_1d.py │ ├── net2d │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── deconv.cpython-37.pyc │ │ │ ├── deconv.cpython-38.pyc │ │ │ ├── identity_2d.cpython-37.pyc │ │ │ ├── identity_2d.cpython-38.pyc │ │ │ ├── submodule_2d.cpython-37.pyc │ │ │ └── submodule_2d.cpython-38.pyc │ │ ├── deconv.py │ │ ├── identity_2d.py │ │ └── submodule_2d.py │ ├── net3d │ │ ├── RelationalMHDPA.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── RelationalMHDPA.cpython-38.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── deconv.cpython-38.pyc │ │ │ ├── four_conv.cpython-37.pyc │ │ │ ├── four_conv.cpython-38.pyc │ │ │ ├── identity_3d.cpython-37.pyc │ │ │ ├── identity_3d.cpython-38.pyc │ │ │ ├── submodule_3d.cpython-37.pyc │ │ │ └── submodule_3d.cpython-38.pyc │ │ ├── _resnets.py │ │ ├── four_conv.py │ │ ├── identity_3d.py │ │ ├── networks.py │ │ ├── rmc.py │ │ └── submodule_3d.py │ └── net4d │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── identity_4d.cpython-37.pyc │ │ ├── identity_4d.cpython-38.pyc │ │ ├── submodule_4d.cpython-37.pyc │ │ └── 
submodule_4d.cpython-38.pyc │ │ ├── identity_4d.py │ │ └── submodule_4d.py ├── preprocess │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── observation.cpython-37.pyc │ │ ├── observation.cpython-38.pyc │ │ ├── ops.cpython-37.pyc │ │ └── ops.cpython-38.pyc │ ├── observation.py │ └── ops.py ├── registry │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── registry.cpython-37.pyc │ │ └── registry.cpython-38.pyc │ └── registry.py ├── rewardnorm │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── normalizers.cpython-37.pyc │ │ └── normalizers.cpython-38.pyc │ ├── base │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── rewnorm_module.cpython-37.pyc │ │ │ └── rewnorm_module.cpython-38.pyc │ │ └── rewnorm_module.py │ └── normalizers.py ├── train.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── logging.cpython-37.pyc │ ├── logging.cpython-38.pyc │ ├── requires_args.cpython-37.pyc │ ├── requires_args.cpython-38.pyc │ ├── script_helpers.cpython-37.pyc │ ├── script_helpers.cpython-38.pyc │ ├── util.cpython-37.pyc │ └── util.cpython-38.pyc │ ├── logging.py │ ├── requires_args.py │ ├── script_helpers.py │ └── util.py ├── requirements.txt └── setup.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/.DS_Store -------------------------------------------------------------------------------- /mcs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/.DS_Store 
def _registry():
    """Return the ``REGISTRY`` object from ``mcs.registry``.

    The import is performed at call time, not at module import time, which
    is exactly when each public ``register_*`` function performed it before.
    (Presumably this avoids an import cycle with ``mcs.registry`` -- confirm.)
    """
    from mcs.registry import REGISTRY

    return REGISTRY


def register_agent(agent_cls):
    """Forward ``agent_cls`` to ``REGISTRY.register_agent``."""
    _registry().register_agent(agent_cls)


def register_actor(actor_cls):
    """Forward ``actor_cls`` to ``REGISTRY.register_actor``."""
    _registry().register_actor(actor_cls)


def register_exp(exp_cls):
    """Forward ``exp_cls`` to ``REGISTRY.register_exp``."""
    _registry().register_exp(exp_cls)


def register_learner(learner_cls):
    """Forward ``learner_cls`` to ``REGISTRY.register_learner``."""
    _registry().register_learner(learner_cls)


def register_env(env_cls):
    """Forward ``env_cls`` to ``REGISTRY.register_env``."""
    _registry().register_env(env_cls)


def register_reward_norm(rwd_norm_cls):
    """Forward ``rwd_norm_cls`` to ``REGISTRY.register_reward_normalizer``."""
    _registry().register_reward_normalizer(rwd_norm_cls)


def register_network(network_cls):
    """Forward ``network_cls`` to ``REGISTRY.register_network``."""
    _registry().register_network(network_cls)


def register_submodule(submod_cls):
    """Forward ``submod_cls`` to ``REGISTRY.register_submodule``."""
    _registry().register_submodule(submod_cls)


def register_manager(manager_cls):
    """Forward ``manager_cls`` to ``REGISTRY.register_manager``."""
    _registry().register_manager(manager_cls)
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/__pycache__/globals.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/__pycache__/globals.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/__pycache__/globals.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/__pycache__/globals.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/actor/__init__.py: -------------------------------------------------------------------------------- 1 | from .base.actor_module import ActorModule 2 | from .ac_rollout import ACRolloutActorTrain 3 | from mcs.actor.ac_eval import ACActorEval, ACActorEvalSample 4 | from .impala import ImpalaHostActor, ImpalaWorkerActor,ImpalaHostTargetActor 5 | 6 | ACTOR_REG = [ 7 | ACRolloutActorTrain, 8 | ACActorEval, 9 | ACActorEvalSample, 10 | ImpalaHostActor, 11 | ImpalaWorkerActor, 12 | ImpalaHostTargetActor 13 | ] 14 | -------------------------------------------------------------------------------- /mcs/actor/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/actor/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/actor/__pycache__/ac_eval.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/ac_eval.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/actor/__pycache__/ac_eval.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/ac_eval.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/actor/__pycache__/ac_rollout.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/ac_rollout.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/actor/__pycache__/ac_rollout.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/ac_rollout.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/actor/__pycache__/impala.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/actor/__pycache__/impala.cpython-37.pyc -------------------------------------------------------------------------------- 
from collections import OrderedDict

from mcs.actor import ActorModule
from mcs.actor.base.ac_helper import ACActorHelperMixin


class ACActorEval(ActorModule, ACActorHelperMixin):
    """
    Actor-critic actor that only produces actions and a value estimate.

    Unlike a training actor it records no log-probabilities or entropies:
    ``_exp_spec`` is empty and the only extra output is ``"value"``.
    """

    args = {}

    @classmethod
    def from_args(cls, action_space):
        """
        :param action_space: mapping of action key -> shape
        :return: a new instance built from ``action_space`` alone
        """
        return cls(action_space)

    @staticmethod
    def output_space(action_space):
        """
        :param action_space: mapping of action key -> shape
        :return: ``action_space`` extended with a ``"critic"`` head of
            shape ``(1,)``
        """
        return {"critic": (1,), **action_space}

    def _choose_action(self, softmax):
        """
        Turn a probability table into one action index per row.

        :param softmax: torch.Tensor (N, X)
        :return: torch.Tensor (N)
        """
        # NOTE(review): this samples from the distribution, which makes
        # ACActorEval behave exactly like ACActorEvalSample. The helper mixin
        # also offers select_action (argmax) -- confirm whether greedy
        # evaluation was intended here.
        return self.sample_action(softmax)

    def compute_action_exp(self, preds, internals, obs, available_actions):
        """
        Pick one action per action key from the network predictions.

        ``internals``, ``obs`` and ``available_actions`` are accepted for
        interface compatibility and are not read.

        :param preds: mapping holding one logit tensor per action key plus
            a ``"critic"`` tensor
        :return:
            actions: OrderedDict[action key, Tensor] (moved to CPU)
            experience: ``{"value": preds["critic"]`` with its last
            dimension squeezed ``}``
        """
        actions = OrderedDict()

        for key in self.action_keys:
            probs = self.softmax(self.flatten_logits(preds[key]))
            actions[key] = self._choose_action(probs).cpu()

        return actions, {"value": preds["critic"].squeeze(-1)}

    @classmethod
    def _exp_spec(
        cls, rollout_len, batch_sz, obs_space, act_space, internal_space
    ):
        """Evaluation stores no per-step experience, so the spec is empty."""
        return {}


class ACActorEvalSample(ACActorEval):
    """
    Evaluation actor that samples its actions from the softmax.

    The action loop is inherited from ``ACActorEval``; only the selection
    rule is pinned here so it stays sampling even if the parent changes.
    """

    def _choose_action(self, softmax):
        """
        :param softmax: torch.Tensor (N, X)
        :return: torch.Tensor (N), drawn with ``sample_action``
        """
        return self.sample_action(softmax)
# Copyright (C) 2018 Heron Systems, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see https://www.gnu.org/licenses/.
from collections import OrderedDict

import torch

from mcs.actor.base.ac_helper import ACActorHelperMixin
from mcs.actor.base.actor_module import ActorModule


class ACRolloutActorTrain(ActorModule, ACActorHelperMixin):
    """
    Training-time actor-critic actor.

    Besides the sampled actions it returns, per action key, the
    log-probability of the sampled action and the entropy of the
    distribution, together with the critic output.
    """

    args = {}

    @classmethod
    def from_args(cls, action_space):
        """Build the actor from ``action_space`` alone."""
        return cls(action_space)

    @staticmethod
    def output_space(action_space):
        """Return ``action_space`` plus a ``"critic"`` head of shape (1,)."""
        return {"critic": (1,), **action_space}

    def compute_action_exp(self, preds, internals, obs, available_actions):
        """
        Sample one action per action key and collect training statistics.

        ``internals``, ``obs`` and ``available_actions`` are not read.

        :param preds: mapping with one logit tensor per action key and a
            ``"critic"`` tensor
        :return:
            actions: OrderedDict[action key, Tensor] (moved to CPU)
            experience: dict with ``"log_probs"`` and ``"entropies"``
            (per-key columns concatenated along dim 1) and ``"values"``
            (``preds["critic"]`` with dim 1 squeezed)
        """
        critic_values = preds["critic"].squeeze(1)

        chosen = OrderedDict()
        log_prob_columns = []
        entropy_columns = []

        for act_key in self.action_keys:
            flat_logit = self.flatten_logits(preds[act_key])
            log_p = self.log_softmax(flat_logit)
            p = self.softmax(flat_logit)

            entropy_columns.append(self.entropy(log_p, p))
            picked = self.sample_action(p)
            log_prob_columns.append(self.log_probability(log_p, picked))
            chosen[act_key] = picked.cpu()

        experience = {
            "log_probs": torch.cat(log_prob_columns, dim=1),
            "entropies": torch.cat(entropy_columns, dim=1),
            "values": critic_values,
        }
        return chosen, experience

    @classmethod
    def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
        """
        Describe the shapes of the experience this actor emits.

        :return: dict of key -> shape; ``log_probs`` and ``entropies`` carry
            one column per key of ``act_space``
        """
        n_action_keys = len(act_space.keys())
        per_key_shape = (exp_len, batch_sz, n_action_keys)

        return {
            "log_probs": per_key_shape,
            "entropies": per_key_shape,
            "values": (exp_len, batch_sz),
        }
import abc

import torch
from torch.nn import functional as F


class ACActorHelperMixin(metaclass=abc.ABCMeta):
    """
    A helper class for actor critic actors.

    Every helper is a ``staticmethod`` and works on tensors laid out as
    (N, X): N rows, X scores (or probabilities) per row.
    """

    @staticmethod
    def flatten_logits(logit):
        """
        Collapse every dimension after the first into a single one.

        Tensors with any number of trailing dimensions are handled, and
        non-contiguous inputs are accepted (they are copied when a view is
        not possible).

        :param logit: Tensor of arbitrary dim
        :return: logits flattened to (N, X); a tensor with fewer than three
            dimensions is returned unchanged
        """
        if logit.dim() <= 2:
            return logit
        # flatten() yields a view whenever the memory layout allows it, so
        # contiguous inputs behave exactly as they did with .view().
        return logit.flatten(start_dim=1)

    @staticmethod
    def softmax(logit):
        """
        :param logit: torch.Tensor (N, X)
        :return: torch.Tensor (N, X), softmax taken over dim 1
        """
        return F.softmax(logit, dim=1)

    @staticmethod
    def log_softmax(logit):
        """
        :param logit: torch.Tensor (N, X)
        :return: torch.Tensor (N, X), log-softmax taken over dim 1
        """
        return F.log_softmax(logit, dim=1)

    @staticmethod
    def log_probability(log_softmax, action):
        """
        Look up the log-probability of the chosen action in each row.

        :param log_softmax: Tensor (N, X)
        :param action: LongTensor (N)
        :return: Tensor (N, 1)
        """
        return log_softmax.gather(1, action.unsqueeze(1))

    @staticmethod
    def entropy(log_softmax, softmax):
        """
        Row-wise entropy, ``-sum(p * log p)`` over dim 1.

        :param log_softmax: Tensor (N, X)
        :param softmax: Tensor (N, X)
        :return: Tensor (N, 1)
        """
        return -(log_softmax * softmax).sum(1, keepdim=True)

    @staticmethod
    def sample_action(softmax):
        """
        Samples an action from a softmax distribution.

        :param softmax: torch.Tensor (N, X)
        :return: torch.Tensor (N)
        """
        return softmax.multinomial(1).squeeze(1)

    @staticmethod
    def select_action(softmax):
        """
        Selects the action with the highest probability.

        :param softmax: torch.Tensor (N, X)
        :return: torch.Tensor (N), the argmax over dim 1
        """
        return torch.argmax(softmax, dim=1)
import abc
from collections import defaultdict

from mcs.exp.base.spec_builder import ExpSpecBuilder
from mcs.utils.requires_args import RequiresArgsMixin


class ActorModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    Abstract base for actors.

    An actor turns network predictions into actions and the extra
    experience that concrete subclasses declare through ``_exp_spec``.
    """

    def __init__(self, action_space):
        self._action_space = action_space

    @property
    def action_space(self):
        """The action space passed to the constructor."""
        return self._action_space

    @property
    def action_keys(self):
        """Keys of the action space in sorted order."""
        keys = self.action_space.keys()
        return sorted(keys)

    @staticmethod
    @abc.abstractmethod
    def output_space(action_space):
        """Subclasses return the heads the network has to produce."""
        raise NotImplementedError

    @classmethod
    def exp_spec_builder(cls, obs_space, act_space, internal_space, batch_sz):
        """
        Create the ``ExpSpecBuilder`` for this actor.

        The builder is given a function that, for an experience length,
        returns the actor's own ``_exp_spec`` entries plus ``"rewards"`` and
        ``"terminals"``, both shaped ``(exp_len, batch_sz)``.
        """
        spaces = (obs_space, act_space, internal_space)

        def build_fn(exp_len):
            spec = dict(cls._exp_spec(exp_len, batch_sz, *spaces))
            # Environment entries are written last so they win on a clash.
            spec["rewards"] = (exp_len, batch_sz)
            spec["terminals"] = (exp_len, batch_sz)
            return spec

        key_types = cls._key_types(*spaces)
        exp_keys = cls._exp_keys(*spaces)
        return ExpSpecBuilder(
            obs_space, act_space, internal_space, key_types, exp_keys, build_fn
        )

    @classmethod
    @abc.abstractmethod
    def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
        """Subclasses return a dict of experience key -> shape."""
        raise NotImplementedError

    @classmethod
    def _exp_keys(cls, obs_space, act_space, internal_space):
        """Keys of ``_exp_spec``, obtained from a 1 x 1 dummy spec."""
        return cls._exp_spec(1, 1, obs_space, act_space, internal_space).keys()

    @classmethod
    def _key_types(cls, obs_space, act_space, internal_space):
        """Mapping of experience key -> type name; every key is "float"."""

        def _default_type():
            return "float"

        return defaultdict(_default_type)

    # NOTE(review): declared as an instance method taking ``args``, while the
    # actors visible alongside this class implement it as
    # ``@classmethod from_args(cls, action_space)`` -- confirm which
    # signature is the intended contract.
    @abc.abstractmethod
    def from_args(self, args, action_space):
        raise NotImplementedError

    @abc.abstractmethod
    def compute_action_exp(self, preds, internals, obs, available_actions):
        """
        B = Batch Size

        :param preds: Dict[str, torch.Tensor]
        :return:
            actions: Dict[ActionKey, Tensor (B)]
            experience: Dict[str, Tensor (B, X)]
        """
        raise NotImplementedError

    def act(self, network, obs, prev_internals):
        """
        Run the network once and convert its predictions into actions.

        :param obs: Dict[str, Tensor]
        :param prev_internals: previous internal states. Dict[str, Tensor]
        :return:
            actions: Dict[ActionKey, Tensor (B)]
            experience: Dict[str, Tensor (B, X)]
            internal_states: Dict[str, Tensor]
        """
        predictions, internal_states, pobs = network(obs, prev_internals)

        # The mask is optional; None is forwarded when the key is absent.
        av_actions = (
            obs["available_actions"] if "available_actions" in obs else None
        )

        # The previous internals, not the new ones, go to the subclass.
        actions, exp = self.compute_action_exp(
            predictions, prev_internals, pobs, av_actions
        )
        return actions, exp, internal_states
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/agent/__pycache__/actor_critic.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/__pycache__/actor_critic.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/agent/__pycache__/actor_critic.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/agent/__pycache__/actor_critic.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/agent/actor_critic.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 Heron Systems, Inc. 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see https://www.gnu.org/licenses/. 
from mcs.actor import ACRolloutActorTrain
from mcs.exp import Rollout
from mcs.learner import ACRolloutLearner
from .base.agent_module import AgentModule


class ActorCritic(AgentModule):
    """
    Agent assembled from a ``Rollout`` cache, an ``ACRolloutActorTrain``
    actor and an ``ACRolloutLearner`` learner.

    Acting and learning are delegated to those parts; ``args`` is the union
    of their argument dictionaries.
    """

    args = {**Rollout.args, **ACRolloutActorTrain.args, **ACRolloutLearner.args}

    def __init__(
        self,
        reward_normalizer,
        action_space,
        spec_builder,
        rollout_len,
        discount,
        normalize_advantage,
        entropy_weight,
        return_scale,
    ):
        super().__init__(reward_normalizer, action_space)

        # Hyper-parameters are also kept on the agent itself.
        self.discount = discount
        self.normalize_advantage = normalize_advantage
        self.entropy_weight = entropy_weight

        self._exp_cache = Rollout(spec_builder, rollout_len)
        self._actor = ACRolloutActorTrain(action_space)
        self._learner = ACRolloutLearner(
            reward_normalizer,
            discount,
            normalize_advantage,
            entropy_weight,
            return_scale,
        )

    @classmethod
    def from_args(
        cls, args, reward_normalizer, action_space, spec_builder, **kwargs
    ):
        """
        Build the agent from a parsed argument object.

        Reads ``rollout_len``, ``discount``, ``normalize_advantage``,
        ``entropy_weight`` and ``return_scale`` from ``args``; extra keyword
        arguments are accepted and ignored.
        """
        hyper_params = {
            "rollout_len": args.rollout_len,
            "discount": args.discount,
            "normalize_advantage": args.normalize_advantage,
            "entropy_weight": args.entropy_weight,
            "return_scale": args.return_scale,
        }
        return cls(reward_normalizer, action_space, spec_builder, **hyper_params)

    @property
    def exp_cache(self):
        """The ``Rollout`` created in the constructor."""
        return self._exp_cache

    @classmethod
    def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
        """Same experience spec as ``ACRolloutActorTrain``."""
        return ACRolloutActorTrain._exp_spec(
            exp_len, batch_sz, obs_space, act_space, internal_space
        )

    @staticmethod
    def output_space(action_space):
        """Same output space as ``ACRolloutActorTrain``."""
        return ACRolloutActorTrain.output_space(action_space)

    def compute_action_exp(
        self, predictions, internals, obs, available_actions
    ):
        """Delegate action selection to the wrapped actor."""
        actor = self._actor
        return actor.compute_action_exp(
            predictions, internals, obs, available_actions
        )

    def learn_step(self, updater, network, next_obs, internals):
        """Run one learner step on the contents read from the cache."""
        learner = self._learner
        return learner.learn_step(
            updater, network, self.exp_cache.read(), next_obs, internals
        )
class AgentModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    Base class for agents.

    An Agent is an Actor (chooses actions) plus a Learner (updates
    parameters), and it owns a cache holding a rollout or experience replay.

    Actors and Learners are kept separate so that Actor-Learner
    architectures can have many actors feeding one learner.
    """

    def __init__(self, reward_normalizer, action_space):
        self._reward_normalizer = reward_normalizer
        self._action_space = action_space

    # ------------------------------------------------------------------
    # Interface that subclasses must provide
    # ------------------------------------------------------------------
    @classmethod
    @abc.abstractmethod
    def from_args(
        cls, args, reward_normalizer, action_space, spec_builder, **kwargs
    ):
        """Alternate constructor taking an ``args`` namespace."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def exp_cache(self):
        """Get experience cache"""
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def _exp_spec(cls, exp_len, batch_sz, obs_space, act_space, internal_space):
        """Return a mapping describing the experience entries of the agent."""
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def output_space(action_space):
        raise NotImplementedError

    @abc.abstractmethod
    def compute_action_exp(
        self, predictions, internals, obs, available_actions
    ):
        """Must return an ``(actions, experience)`` pair (see ``act``)."""
        raise NotImplementedError

    @abc.abstractmethod
    def learn_step(self, updater, network, next_obs, internals):
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Accessors
    # ------------------------------------------------------------------
    @property
    def action_space(self):
        return self._action_space

    @property
    def action_keys(self):
        """Keys of the action space as a sorted list."""
        return sorted(self.action_space.keys())

    # ------------------------------------------------------------------
    # Experience spec helpers
    # ------------------------------------------------------------------
    @classmethod
    def exp_spec_builder(cls, obs_space, act_space, internal_space, batch_sz):
        """
        Create an ``ExpSpecBuilder`` whose build function extends the
        agent-specific spec with ``rewards`` and ``terminals`` entries of
        shape ``(exp_len, batch_sz)``.
        """

        def build_fn(exp_len):
            spec = dict(
                cls._exp_spec(
                    exp_len, batch_sz, obs_space, act_space, internal_space
                )
            )
            # Environment-provided entries override same-named agent entries.
            spec["rewards"] = (exp_len, batch_sz)
            spec["terminals"] = (exp_len, batch_sz)
            return spec

        key_types = cls._key_types(obs_space, act_space, internal_space)
        exp_keys = cls._exp_keys(obs_space, act_space, internal_space)
        return ExpSpecBuilder(
            obs_space, act_space, internal_space, key_types, exp_keys, build_fn
        )

    @classmethod
    def _exp_keys(cls, obs_space, act_space, internal_space):
        """Keys of a throwaway spec built with length 1 and batch size 1."""
        return cls._exp_spec(
            1, 1, obs_space, act_space, internal_space
        ).keys()

    @classmethod
    def _key_types(cls, obs_space, act_space, internal_space):
        """Every key maps to ``"float"`` unless a subclass overrides this."""
        return defaultdict(lambda: "float")

    # ------------------------------------------------------------------
    # Cache pass-throughs
    # ------------------------------------------------------------------
    def is_ready(self):
        return self.exp_cache.is_ready()

    def clear(self):
        self.exp_cache.clear()

    def to(self, device):
        """Move the experience cache to ``device`` and return ``self``."""
        self.exp_cache.to(device)
        return self

    # ------------------------------------------------------------------
    # Acting / observing
    # ------------------------------------------------------------------
    def act(self, network, obs, prev_internals):
        """
        Run the network, pick actions and record the actor's experience.

        :param network: NetworkModule
        :param obs: Dict[str, Tensor]
        :param prev_internals: previous internal states. Dict[str, Tensor]
        :return:
            actions: Dict[ActionKey, LongTensor (B)]
            internal_states: Dict[str, Tensor]
        """
        predictions, internal_states, pobs = network(obs, prev_internals)

        # Optional action mask supplied by the environment.
        av_actions = (
            obs["available_actions"] if "available_actions" in obs else None
        )

        # Note: the *previous* internals are what gets passed on here.
        actions, experience = self.compute_action_exp(
            predictions, prev_internals, pobs, av_actions
        )
        self.exp_cache.write_actor(experience)
        return actions, internal_states

    def observe(self, obs, rewards, terminals, infos):
        """Write the environment step to the cache and echo part of it back."""
        self.exp_cache.write_env(obs, rewards, terminals, infos)
        return rewards, terminals, infos
"Identity4D", 21 | "learner": "ImpalaLearner", 22 | "load_network": null, 23 | "load_optim": null, 24 | "logdir": "./logs", 25 | "lr": 0.0007, 26 | "lstm_nb_hidden": 512, 27 | "lstm_normalize": true, 28 | "manager": "SubProcEnvManager", 29 | "max_episode_length": 10000, 30 | "minimum_importance_policy": 1.0, 31 | "minimum_importance_value": 1.0, 32 | "nb_env": 32, 33 | "nb_learn_batch": 2, 34 | "nb_learners": 1, 35 | "nb_step": 50000000, 36 | "nb_workers": 2, 37 | "net1d": "Identity1D", 38 | "net2d": "Identity2D", 39 | "net3d": "FourConv", 40 | "net4d": "Identity4D", 41 | "netbody": "Linear", 42 | "noop_max": 30, 43 | "profile": false, 44 | "prompt": false, 45 | "ray_addr": null, 46 | "resume": null, 47 | "rollout_len": 20, 48 | "rollout_queue_size": 8, 49 | "rwd_norm": "Clip", 50 | "seed": 99, 51 | "skip_rate": 4, 52 | "summary_freq": 10, 53 | "tag": "default", 54 | "learner_cpu_alloc": 5, 55 | "learner_gpu_alloc": 0.5, 56 | "worker_cpu_alloc": 1, 57 | "worker_gpu_alloc": 0.05, 58 | "debug_mode": false, 59 | "test_mode": false, 60 | "use_pixel_control": false, 61 | "cell_size": 4, 62 | "pixel_control_loss_gamma": 0.99, 63 | "action_space": 6, 64 | "gae_gamma": 0.99, 65 | "gae_lambda": 0.995, 66 | "use_mhra": false, 67 | "num_head": 4, 68 | "probs_clip": 0.4, 69 | "target_worker_clip_rho": 2, 70 | "minibatch_buffer_size": 4, 71 | "num_sgd": 1, 72 | "linear_normalize": "not", 73 | "linear_nb_hidden": 512, 74 | "nb_layer": 3 75 | } 76 | -------------------------------------------------------------------------------- /mcs/container/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 Heron Systems, Inc. 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 
7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 15 | from .local import Local 16 | from .distrib import DistribHost, DistribWorker 17 | from .actorlearner import ActorLearnerHost, ActorLearnerWorker 18 | from .evaluation import EvalContainer 19 | from .init import Init 20 | from .base.updater import Updater 21 | -------------------------------------------------------------------------------- /mcs/container/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/container/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/container/__pycache__/distrib.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/distrib.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/container/__pycache__/distrib.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/distrib.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/container/__pycache__/evaluation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/evaluation.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/container/__pycache__/evaluation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/evaluation.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/container/__pycache__/init.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/init.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/container/__pycache__/init.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/init.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/container/__pycache__/local.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/local.cpython-37.pyc 
-------------------------------------------------------------------------------- /mcs/container/__pycache__/local.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/__pycache__/local.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/container/actorlearner/__init__.py: -------------------------------------------------------------------------------- 1 | from .learner_container import ActorLearnerHost 2 | from .rollout_worker import ActorLearnerWorker 3 | 4 | -------------------------------------------------------------------------------- /mcs/container/actorlearner/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/container/actorlearner/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/container/actorlearner/__pycache__/learner_container.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/learner_container.cpython-37.pyc -------------------------------------------------------------------------------- 
/mcs/container/actorlearner/__pycache__/learner_container.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/learner_container.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/container/actorlearner/__pycache__/rollout_queuer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/rollout_queuer.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/container/actorlearner/__pycache__/rollout_queuer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/rollout_queuer.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/container/actorlearner/__pycache__/rollout_worker.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/rollout_worker.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/container/actorlearner/__pycache__/rollout_worker.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/actorlearner/__pycache__/rollout_worker.cpython-38.pyc 
class RolloutQueuerAsync:
    """
    Pulls finished rollouts from Ray workers on a background thread and
    buffers them in a bounded queue for the host to consume.

    Each worker is expected to expose ``run.remote()`` returning a Ray
    future whose result is a dict with the keys ``"rollout"``,
    ``"terminal_rewards"`` and ``"terminal_infos"`` (see ``get``).
    """

    def __init__(self, workers, num_rollouts, queue_max_size, timeout=15.0):
        """
        :param workers: sequence of Ray worker handles
        :param num_rollouts: number of queue items returned by one ``get``
        :param queue_max_size: capacity of the internal queue
        :param timeout: seconds to wait when putting into a full queue and
            when waiting for outstanding rollouts at shutdown
        """
        self.workers = workers
        self.num_rollouts = num_rollouts
        self.queue_max_size = queue_max_size
        self.queue_timeout = timeout

        # One rollout is requested from every worker right away;
        # future_inds[i] is the worker index that owns futures[i].
        self.futures = [w.run.remote() for w in self.workers]
        self.future_inds = list(range(len(self.workers)))
        # Workers whose rollout was consumed and that need a new request.
        self._idle_workers = []
        self._should_stop = True
        self.rollout_queue = queue.Queue(self.queue_max_size)
        # Created by start(); None until then so close() can be called safely.
        self.background_thread = None
        self._worker_wait_time = 0
        self._host_wait_time = 0

    def _background_queing_thread(self):
        """Thread body: move ready rollouts into the queue until stopped."""
        while not self._should_stop:
            ready_ids, _ = ray.wait(self.futures, num_returns=1)

            # if ray returns an empty list that means all objects have been
            # gotten; this should never happen?
            if not ready_ids:
                print("WARNING: ray returned no ready rollouts")
                continue

            for ready in ready_ids:
                # get object and add to queue; blocks (up to the timeout)
                # while the queue is at max size
                self._add_to_queue(ray.get(ready))

            # remove the consumed futures and remember their workers
            self._idle_workers = []
            for ready in ready_ids:
                index = self.futures.index(ready)
                del self.futures[index]
                self._idle_workers.append(self.future_inds.pop(index))

            # tell the worker(s) to start another rollout
            self._restart_idle_workers()

        # done, wait for all remaining to finish. num_returns is passed by
        # keyword, consistent with the call above.
        _, not_dones = ray.wait(
            self.futures,
            num_returns=len(self.futures),
            timeout=self.queue_timeout,
        )

        if len(not_dones) > 0:
            print("WARNING: Not all rollout workers finished")

    def _add_to_queue(self, rollout):
        """
        Put ``rollout`` into the queue and account the time spent waiting.

        :raises queue.Full: if the queue stays full for ``queue_timeout``
            seconds
        """
        st = time()
        self.rollout_queue.put(rollout, timeout=self.queue_timeout)
        et = time()
        self._worker_wait_time += et - st

    def start(self):
        """Start the background thread."""
        self._should_stop = False
        self.background_thread = threading.Thread(
            target=self._background_queing_thread
        )
        self.background_thread.start()

    def get(self):
        """
        Block until ``num_rollouts`` items are available and return them.

        :return: ``(rollouts, terminal_rewards, terminal_infos)``, three
            lists with one entry per queue item
        """
        st = time()
        worker_data = [
            self.rollout_queue.get(True) for _ in range(self.num_rollouts)
        ]
        et = time()
        self._host_wait_time += et - st

        rollouts = []
        terminal_rewards = []
        terminal_infos = []
        for w in worker_data:
            rollouts.append(w["rollout"])
            terminal_rewards.append(w["terminal_rewards"])
            terminal_infos.append(w["terminal_infos"])

        return rollouts, terminal_rewards, terminal_infos

    def close(self):
        """Ask the background thread to stop and wait for it to exit."""
        self._should_stop = True

        # try to join background thread (nothing to join if never started)
        if self.background_thread is not None:
            self.background_thread.join()

    def metrics(self):
        """Accumulated waiting times in seconds."""
        return {
            "Host wait time": self._host_wait_time,
            "Worker wait time": self._worker_wait_time,
        }

    def _restart_idle_workers(self):
        """Request a new rollout from every idle worker, then empty the list."""
        for w_ind in self._idle_workers:
            worker = self.workers[w_ind]
            self.futures.append(worker.run.remote())
            self.future_inds.append(w_ind)

        # Clearing in one go avoids deleting by ascending index from a
        # shrinking list, which fails as soon as two workers are idle.
        self._idle_workers.clear()

    def __len__(self):
        """Approximate number of rollouts currently buffered."""
        # queue.Queue does not implement __len__; qsize() is its size query.
        return self.rollout_queue.qsize()
/mcs/container/base/__pycache__/container.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/container.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/container/base/__pycache__/nccl_optimizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/nccl_optimizer.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/container/base/__pycache__/nccl_optimizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/nccl_optimizer.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/container/base/__pycache__/updater.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/updater.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/container/base/__pycache__/updater.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/container/base/__pycache__/updater.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/container/base/container.py: -------------------------------------------------------------------------------- 1 
class Container:
    """Static helpers for loading state, scheduling saves and logging."""

    @staticmethod
    def load_network(network, path):
        """Load a state dict from ``path`` into ``network`` and return it."""
        # The map_location callable returns the storage unchanged, i.e. it
        # does not move tensors to the device they were saved from.
        network.load_state_dict(
            torch.load(path, map_location=lambda storage, loc: storage)
        )
        return network

    @staticmethod
    def load_optim(optimizer, path):
        """Load a state dict from ``path`` into ``optimizer`` and return it."""
        optimizer.load_state_dict(
            torch.load(path, map_location=lambda storage, loc: storage)
        )
        return optimizer

    @staticmethod
    def init_next_save(initial_step_count, epoch_len):
        """
        Return the first multiple of ``epoch_len`` that is strictly greater
        than ``initial_step_count``, or 0 when ``initial_step_count`` is not
        positive.

        :raises ValueError: if a search is needed but ``epoch_len`` is not
            positive (the search could never terminate)
        """
        next_save = 0
        if initial_step_count > 0:
            if epoch_len <= 0:
                raise ValueError(
                    "epoch_len must be positive, got {}".format(epoch_len)
                )
            while next_save <= initial_step_count:
                next_save += epoch_len
        return next_save

    @staticmethod
    def count_parameters(net):
        """Number of elements over all parameters that require gradients."""
        return sum(p.numel() for p in net.parameters() if p.requires_grad)

    @staticmethod
    def write_summaries(
        writer, step_count, total_loss, loss_dict, metric_dict, n_params
    ):
        """
        Write losses, metrics and parameter norms to ``writer``.

        :param writer: object with ``add_scalar`` and ``add_histogram``
        :param step_count: step value attached to every entry
        :param total_loss: object with ``.item()``
        :param loss_dict: name -> object with ``.item()``
        :param metric_dict: name -> metric; the ``"action"`` entry is logged
            as a histogram via ``getData``, all others as scalars via
            ``.item()``
        :param n_params: iterable of ``(name, parameter)`` pairs
        """
        writer.add_scalar("loss/total_loss", total_loss.item(), step_count)
        for l_name, loss in loss_dict.items():
            writer.add_scalar("loss/" + l_name, loss.item(), step_count)
        for m_name, metric in metric_dict.items():
            if m_name == "action":
                writer.add_histogram(
                    "metric/" + m_name, getData(metric), step_count
                )
            else:
                writer.add_scalar(
                    "metric/" + m_name, metric.item(), step_count
                )
        for p_name, param in n_params:
            p_name = p_name.replace(".", "/")
            writer.add_scalar(p_name, torch.norm(param).item(), step_count)
            if param.grad is not None:
                writer.add_scalar(
                    p_name + ".grad", torch.norm(param.grad).item(), step_count
                )


def getData(data, nb_bins=27, scale=100):
    """
    Expand per-bin weights into a flat sample array for histogram logging.

    Bin ``i`` (as a float) is repeated ``int(data[i] * scale)`` times.

    :param data: indexable with at least ``nb_bins`` entries
    :param nb_bins: number of leading entries of ``data`` to use
    :param scale: multiplier applied to each entry before truncating to int
    :return: 1-D float ``numpy`` array
    """
    bins = np.arange(nb_bins, dtype=float)
    counts = np.asarray(
        [int(data[i] * scale) for i in range(nb_bins)], dtype=np.intp
    )
    return bins.repeat(counts)
class NCCLOptimizer:
    """
    Wraps an optimizer so that gradients are averaged over a process group
    before every step, and parameters are re-averaged periodically.
    """

    def __init__(self, optimizer_fn, network, world_size, param_sync_rate=1000):
        """
        :param optimizer_fn: callable taking the network parameters and
            returning an optimizer
        :param network: module whose parameters are optimized
        :param world_size: divisor used to turn all-reduced sums into means
        :param param_sync_rate: parameters are synchronized every this many
            calls to ``step``
        """
        self.network = network
        self.optimizer = optimizer_fn(self.network.parameters())
        self.param_sync_rate = param_sync_rate
        self._opt_count = 0
        # Must be provided through set_process_group() before step().
        self.process_group = None
        self.world_size = world_size

    def set_process_group(self, pg):
        self.process_group = pg

    def _allreduce_mean(self, tensors):
        """All-reduce every tensor in place, then divide it by world_size."""
        tensors = list(tensors)
        # Launch all reductions first, then wait, so they can overlap.
        handles = [self.process_group.allreduce(t) for t in tensors]
        for handle in handles:
            handle.wait()
        # Scale only after the reductions have completed.
        for t in tensors:
            t.mul_(1.0 / self.world_size)

    def step(self):
        """Average gradients across the group and apply one optimizer step."""
        self._allreduce_mean(
            param.grad for param in self.network.parameters()
        )
        self.optimizer.step()
        self._opt_count += 1

        # sync params every once in a while to reduce numerical errors
        if self._opt_count % self.param_sync_rate == 0:
            self.sync_parameters()
            # can't just sync buffers, some are int and don't mean well
            # self.sync_buffers()

    def sync_parameters(self):
        """Replace every parameter by its mean over the process group."""
        # Waits on the all-reduce handles before scaling, like step() does.
        self._allreduce_mean(
            param.data for param in self.network.parameters()
        )

    def sync_buffers(self):
        """Replace every buffer by its mean over the process group."""
        self._allreduce_mean(b.data for b in self.network.buffers())

    def zero_grad(self):
        self.optimizer.zero_grad()

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, d):
        return self.optimizer.load_state_dict(d)


class Updater(metaclass=abc.ABCMeta):
    """Base class holding what is needed to apply a loss to a network."""

    def __init__(self, optimizer, network, grad_norm_clip):
        self.optimizer = optimizer
        self.network = network
        self.grad_norm_clip = grad_norm_clip

    def step(self, loss):
        """To be overridden; the base implementation always raises."""
        raise NotImplementedError
class EvaluationThread(LogsAndSummarizesRewards):
    """
    Steps an evaluation environment on a separate thread, logging episode
    results and reward summaries as episodes finish.
    """

    def __init__(
        self,
        training_network,
        agent,
        env,
        nb_env,
        logger,
        summary_writer,
        step_rate_limit,
        override_step_count_fn=None,
    ):
        """
        :param training_network: network whose ``state_dict`` is copied
            into the evaluation network after terminal steps
        :param agent: object providing ``act_eval`` and ``reset_internals``
        :param env: object providing ``reset`` and ``step``
        :param nb_env: exposed through the ``nb_env`` property
        :param logger: logger used by ``log_episode_results``
        :param summary_writer: exposed through ``summary_writer``
        :param step_rate_limit: if positive, the loop sleeps
            ``1 / step_rate_limit`` seconds before each step
        :param override_step_count_fn: optional callable that supplies the
            step count instead of the locally stored one
        """
        self._training_network = training_network
        self._agent = agent
        self._environment = env
        self._nb_env = nb_env
        self._logger = logger
        self._summary_writer = summary_writer
        self._step_rate_limit = step_rate_limit
        self._override_step_count_fn = override_step_count_fn
        # Backing field of the local_step_count property; without this the
        # property raises AttributeError when no override function is given.
        self._local_step_count = 0
        self._thread = Thread(target=self._run)
        self._should_stop = False

    def start(self):
        """Start the evaluation thread."""
        self._thread.start()

    def stop(self):
        """Signal the loop to finish and wait for the thread to exit."""
        self._should_stop = True
        self._thread.join()

    def _run(self):
        """Thread body: act, step the environment and log until stopped."""
        next_obs = self.environment.reset()
        self.start_time = time()
        while not self._should_stop:
            if self._step_rate_limit > 0:
                sleep(1 / self._step_rate_limit)
            obs = next_obs
            actions = self.agent.act_eval(obs)
            next_obs, rewards, terminals, infos = self.environment.step(actions)

            self.agent.reset_internals(terminals)
            # Perform state updates
            terminal_rewards, terminal_infos = self.update_buffers(
                rewards, terminals, infos
            )
            self.log_episode_results(
                terminal_rewards, terminal_infos, self.local_step_count
            )
            self.write_reward_summaries(terminal_rewards, self.local_step_count)

            if np.any(terminals) and np.any(infos):
                # NOTE(review): `self.network` is not defined in this class;
                # presumably the base class provides it. Confirm.
                self.network.load_state_dict(
                    self._training_network.state_dict()
                )

    def log_episode_results(
        self, terminal_rewards, terminal_infos, step_count, initial_step_count=0
    ):
        """
        Log the mean terminal reward and average frames per second.

        Nothing is logged when ``terminal_rewards`` is empty.

        :return: ``terminal_rewards`` unchanged
        """
        if terminal_rewards:
            ep_reward = np.mean(terminal_rewards)
            self.logger.info(
                "eval_frames: {} reward: {} avg_eval_fps: {}".format(
                    step_count,
                    ep_reward,
                    (step_count - initial_step_count)
                    / (time() - self.start_time),
                )
            )
        return terminal_rewards

    @property
    def agent(self):
        return self._agent

    @property
    def environment(self):
        return self._environment

    @environment.setter
    def environment(self, new_env):
        self._environment = new_env

    @property
    def nb_env(self):
        return self._nb_env

    @property
    def logger(self):
        return self._logger

    @property
    def summary_writer(self):
        return self._summary_writer

    @property
    def summary_name(self):
        return "reward/eval"

    @property
    def local_step_count(self):
        """Step count from the override function if set, else the local one."""
        if self._override_step_count_fn is not None:
            return self._override_step_count_fn()
        else:
            return self._local_step_count

    @local_step_count.setter
    def local_step_count(self, step_count):
        self._local_step_count = step_count
class RenderContainer:
    """
    Replays saved networks from a training log directory in a single
    rendered environment and prints the accumulated reward.

    NOTE(review): the source text of this class had lost its indentation;
    the nesting below (in particular inside ``run``) is reconstructed and
    should be confirmed against the original file.
    """

    def __init__(
        self,
        actor,
        epoch_id,
        start,
        end,
        logger,
        log_id_dir,
        gpu_id,
        seed,
        manager,
    ):
        """
        :param actor: name used to look up the actor class in the registry
        :param epoch_id: single epoch to render; when falsy, all epochs in
            the log directory within [start, end] are used
        :param start: lowest epoch id to include
        :param end: highest epoch id to include; -1.0 means no upper bound
        :param logger: logger stored on the container
        :param log_id_dir: training log directory
        :param gpu_id: CUDA device id; a negative value selects the CPU
        :param seed: seed passed to the environment manager
        :param manager: name used to look up the environment manager class
        """
        self.log_dir_helper = log_dir_helper = LogDirHelper(log_id_dir)
        self.train_args = train_args = log_dir_helper.load_args()
        self.device = device = self._device_from_gpu_id(gpu_id)
        self.logger = logger

        # NOTE(review): a falsy epoch_id (None, but also 0) selects the
        # "all epochs" branch.
        if epoch_id:
            epoch_ids = [epoch_id]
        else:
            epoch_ids = self.log_dir_helper.epochs()
            epoch_ids = filter(lambda eid: eid >= start, epoch_ids)
            if end != -1.0:
                epoch_ids = filter(lambda eid: eid <= end, epoch_ids)
            epoch_ids = list(epoch_ids)
        self.epoch_ids = epoch_ids

        engine = REGISTRY.lookup_engine(train_args.env)
        env_cls = REGISTRY.lookup_env(train_args.env)
        manager_cls = REGISTRY.lookup_manager(manager)
        # Rendering uses exactly one environment.
        self.env_mgr = manager_cls.from_args(
            self.train_args, engine, env_cls, seed=seed, nb_env=1
        )
        # Fall back to the actor_host name when no agent was recorded.
        if train_args.agent:
            agent = train_args.agent
        else:
            agent = train_args.actor_host
        output_space = REGISTRY.lookup_output_space(
            agent, self.env_mgr.action_space
        )
        actor_cls = REGISTRY.lookup_actor(actor)
        self.actor = actor_cls.from_args(
            actor_cls.prompt(), self.env_mgr.action_space
        )

        self.network = self._init_network(
            train_args,
            self.env_mgr.observation_space,
            self.env_mgr.gpu_preprocessor,
            output_space,
            REGISTRY,
        ).to(device)

    @staticmethod
    def _device_from_gpu_id(gpu_id):
        """Return a CUDA device for a non-negative id if CUDA is available, else the CPU."""
        return torch.device(
            "cuda:{}".format(gpu_id)
            if (torch.cuda.is_available() and gpu_id >= 0)
            else "cpu"
        )

    @staticmethod
    def _init_network(
        train_args, obs_space, gpu_preprocessor, output_space, net_reg
    ):
        """
        Build the network from the training arguments, using the registered
        custom network class when ``train_args.custom_network`` is set and
        ``ModularNetwork`` otherwise.
        """
        if train_args.custom_network:
            net_cls = net_reg.lookup_network(train_args.custom_network)
        else:
            net_cls = ModularNetwork

        return net_cls.from_args(
            train_args, obs_space, output_space, gpu_preprocessor, net_reg
        )

    def run(self):
        """
        For every selected epoch, load each saved network, play one rendered
        episode to termination and print the accumulated reward.
        """
        for epoch_id in self.epoch_ids:
            # NOTE(review): reset once per epoch, so the reward accumulates
            # across all network paths of the same epoch - confirm intended.
            reward_buf = 0
            for net_path in self.log_dir_helper.network_paths_at_epoch(
                epoch_id
            ):
                # map_location keeps tensors on the storage they are
                # deserialized to; the network was moved to self.device in
                # __init__.
                self.network.load_state_dict(
                    torch.load(
                        net_path, map_location=lambda storage, loc: storage
                    )
                )
                self.network.eval()

                internals = listd_to_dlist(
                    [self.network.new_internals(self.device)]
                )
                next_obs = dtensor_to_dev(self.env_mgr.reset(), self.device)
                self.env_mgr.render()

                episode_complete = False
                while not episode_complete:
                    obs = next_obs
                    with torch.no_grad():
                        actions, _, internals = self.actor.act(
                            self.network, obs, internals
                        )
                    next_obs, rewards, terminals, infos = self.env_mgr.step(
                        actions
                    )
                    self.env_mgr.render()
                    next_obs = dtensor_to_dev(next_obs, self.device)

                    # Index 0: the single environment created in __init__.
                    reward_buf += rewards[0]

                    if terminals[0]:
                        episode_complete = True

                print(f"EPOCH_ID: {epoch_id} REWARD: {reward_buf}")

    def close(self):
        """Close the environment manager."""
        self.env_mgr.close()
7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 15 | from mcs.env.openai_gym import ATARI_ENVS 16 | from .base.env_module import EnvModule 17 | from .openai_gym import AdeptGymEnv 18 | 19 | ENV_REG = [AdeptGymEnv] 20 | -------------------------------------------------------------------------------- /mcs/env/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/env/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/env/__pycache__/_gym_wrappers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/_gym_wrappers.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/env/__pycache__/_gym_wrappers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/_gym_wrappers.cpython-38.pyc 
-------------------------------------------------------------------------------- /mcs/env/__pycache__/_spaces.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/_spaces.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/env/__pycache__/_spaces.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/_spaces.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/env/__pycache__/openai_gym.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/openai_gym.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/env/__pycache__/openai_gym.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/__pycache__/openai_gym.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/env/_spaces.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 Heron Systems, Inc. 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 
class Space(dict):
    """
    Dictionary describing a space as ``{entry name: shape}``.

    Simple gym spaces map to a single entry keyed by the gym class name
    ("Discrete", "MultiDiscrete", "MultiBinary", "Box"). ``Dict`` and
    ``Tuple`` spaces map each sub-space's name (or index) to the first value
    of that sub-space's own description.
    """

    def __init__(self, entries_by_name):
        """
        :param entries_by_name: mapping (or iterable of pairs) used to
            populate the dictionary
        """
        super(Space, self).__init__(entries_by_name)

    @classmethod
    def from_gym(cls, gym_space):
        """
        Build a Space from a gym space.

        :param gym_space: gym space instance
        :return: Space instance
        :raises NotImplementedError: if the gym space type is not supported
        """
        entries_by_name = Space._detect_gym_spaces(gym_space)
        return cls(entries_by_name)

    @staticmethod
    def _detect_gym_spaces(gym_space):
        """
        Translate a gym space into a ``{name: shape}`` dictionary.

        :param gym_space: gym space instance
        :return: dict of shapes by entry name
        :raises NotImplementedError: if the gym space type is not supported
        """
        if isinstance(gym_space, spaces.Discrete):
            return {"Discrete": (gym_space.n,)}
        elif isinstance(gym_space, spaces.MultiDiscrete):
            return {"MultiDiscrete": gym_space.nvec.tolist()}
        elif isinstance(gym_space, spaces.MultiBinary):
            return {"MultiBinary": (gym_space.n,)}
        elif isinstance(gym_space, spaces.Box):
            return {"Box": gym_space.shape}
        elif isinstance(gym_space, spaces.Dict):
            # Only the first value of each sub-space description is kept.
            return {
                name: list(Space._detect_gym_spaces(s).values())[0]
                for name, s in gym_space.spaces.items()
            }
        elif isinstance(gym_space, spaces.Tuple):
            return {
                idx: list(Space._detect_gym_spaces(s).values())[0]
                for idx, s in enumerate(gym_space.spaces)
            }
        else:
            # Previously fell through and returned None, which surfaced later
            # as an unrelated TypeError/AttributeError. Fail the same way
            # dtypes_from_gym does.
            raise NotImplementedError(
                "Unsupported gym space: {}".format(type(gym_space).__name__)
            )

    @staticmethod
    def dtypes_from_gym(gym_space):
        """
        Translate a gym space into a ``{name: dtype}`` dictionary, using the
        same keys as ``_detect_gym_spaces``.

        :param gym_space: gym space instance
        :return: dict of dtypes by entry name
        :raises NotImplementedError: if the gym space type is not supported
        """
        if isinstance(gym_space, spaces.Discrete):
            return {"Discrete": gym_space.dtype}
        elif isinstance(gym_space, spaces.MultiDiscrete):
            return {"MultiDiscrete": gym_space.dtype}
        elif isinstance(gym_space, spaces.MultiBinary):
            return {"MultiBinary": gym_space.dtype}
        elif isinstance(gym_space, spaces.Box):
            return {"Box": gym_space.dtype}
        elif isinstance(gym_space, spaces.Dict):
            return {
                name: list(Space.dtypes_from_gym(s).values())[0]
                for name, s in gym_space.spaces.items()
            }
        elif isinstance(gym_space, spaces.Tuple):
            return {
                idx: list(Space.dtypes_from_gym(s).values())[0]
                for idx, s in enumerate(gym_space.spaces)
            }
        else:
            raise NotImplementedError(
                "Unsupported gym space: {}".format(type(gym_space).__name__)
            )
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/base/__pycache__/_env.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/env/base/__pycache__/env_module.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/base/__pycache__/env_module.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/env/base/__pycache__/env_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/env/base/__pycache__/env_module.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/env/base/_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 Heron Systems, Inc. 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 
class HasEnvMetaData(metaclass=abc.ABCMeta):
    """Interface for objects exposing environment spaces and preprocessors."""

    @property
    @abc.abstractmethod
    def observation_space(self):
        """Observation space of the environment."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def action_space(self):
        """Action space of the environment."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def cpu_preprocessor(self):
        """Preprocessor exposed under the ``cpu_preprocessor`` name."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def gpu_preprocessor(self):
        """Preprocessor exposed under the ``gpu_preprocessor`` name."""
        raise NotImplementedError


# The ABCMeta metaclass is inherited from HasEnvMetaData.
class EnvBase(HasEnvMetaData):
    """Abstract environment: metadata plus ``step``, ``reset`` and ``close``."""

    @abc.abstractmethod
    def step(self, action):
        """
        :param action: Dict[ActionID, Any] Action dictionary
        :return: Tuple[Observation, Reward, Terminal, Info]
        """
        raise NotImplementedError

    @abc.abstractmethod
    def reset(self, **kwargs):
        """
        :param kwargs:
        :return: Dict[ObservationID, Any] Observation dictionary
        """
        raise NotImplementedError

    @abc.abstractmethod
    def close(self):
        """
        Close environment. Release resources.

        :return:
        """
        raise NotImplementedError
class EnvModule(EnvBase, RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    Implement this class to add your custom environment. Don't forget to
    implement args.

    Subclasses must define the class attribute ``ids``; see
    ``check_ids_implemented``.
    """

    # Must be overridden by subclasses; None marks it as not implemented.
    ids = None

    def __init__(self, action_space, cpu_preprocessor, gpu_preprocessor):
        """
        The observation space is not passed in; it is read from
        ``gpu_preprocessor.observation_space``.

        :param action_space: ._spaces.Spaces
        :param cpu_preprocessor: mcs.preprocess.observation.ObsPreprocessor
        :param gpu_preprocessor: mcs.preprocess.observation.ObsPreprocessor
        """
        self._action_space = action_space
        self._cpu_preprocessor = cpu_preprocessor
        self._gpu_preprocessor = gpu_preprocessor

    @classmethod
    @abc.abstractmethod
    def from_args(cls, args, seed, **kwargs):
        """
        Construct from arguments. For convenience.

        :param args: Arguments object
        :param seed: Integer used to seed this environment.
        :param kwargs: Any custom arguments are passed through kwargs.
        :return: EnvModule instance.
        """
        raise NotImplementedError

    @classmethod
    def from_args_curry(cls, args, seed, **kwargs):
        """
        Return a zero-argument callable that constructs the environment via
        ``from_args`` with the given arguments when called.

        :param args: Arguments object
        :param seed: Integer used to seed this environment.
        :param kwargs: Any custom arguments are passed through kwargs.
        :return: Callable[[], EnvModule]
        """

        def _f():
            return cls.from_args(args, seed, **kwargs)

        return _f

    @property
    def observation_space(self):
        return self._gpu_preprocessor.observation_space

    @property
    def action_space(self):
        return self._action_space

    @property
    def cpu_preprocessor(self):
        return self._cpu_preprocessor

    @property
    def gpu_preprocessor(self):
        return self._gpu_preprocessor

    @classmethod
    def check_ids_implemented(cls):
        """
        :raises NotImplementedError: if the class attribute ``ids`` is None
        """
        if cls.ids is None:
            raise NotImplementedError(
                'Subclass must define class attribute "ids"'
            )
"""
Usage:
    evaluate (--logdir <path>) [options]
    evaluate (-h | --help)

Required:
    --logdir <path>     Path to train logs (.../logs/<env-id>/<log-id>)

Options:
    --epoch <int>           Epoch number to load [default: None]
    --actor <str>           Name of the eval actor [default: ACActorEval]
    --gpu-id <int>          CUDA device ID of GPU [default: 0]
    --nb-episode <int>      Number of episodes to average [default: 30]
    --start <float>         Epoch to start from [default: 0]
    --end <float>           Epoch to end on [default: -1]
    --seed <int>            Seed for random variables [default: 512]
    --custom-network <str>  Name of custom network class
"""
from mcs.container import EvalContainer
from mcs.container import Init
from mcs.registry import REGISTRY as R
from mcs.utils.script_helpers import parse_path, parse_none
from mcs.utils.util import DotDict

import os

# NOTE(review): only GPU 0 is made visible, so --gpu-id values other than 0
# cannot select a different physical device - confirm this is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def parse_args():
    """
    Parse the command line with docopt using the module docstring.

    Option names are normalized by dropping the leading dashes and replacing
    inner dashes with underscores (``--gpu-id`` -> ``gpu_id``); values are
    converted to their numeric types.

    :return: DotDict of parsed arguments
    """
    from docopt import docopt

    args = docopt(__doc__)
    # lstrip removes only the leading dashes of the option name.
    args = {k.lstrip("-").replace("-", "_"): v for k, v in args.items()}
    # Help flags are handled by docopt itself and are not arguments.
    args.pop("h", None)
    args.pop("help", None)

    args = DotDict(args)
    args.logdir = parse_path(args.logdir)
    # TODO implement Option utility
    epoch_option = parse_none(args.epoch)
    if epoch_option:
        # Accepts values such as "10" and "10.0".
        args.epoch = int(float(epoch_option))
    else:
        args.epoch = epoch_option

    args.gpu_id = int(args.gpu_id)
    args.nb_episode = int(args.nb_episode)
    args.start = float(args.start)
    args.end = float(args.end)
    args.seed = int(args.seed)
    return args


def main(args):
    """
    Run an evaluation.

    NOTE(review): ``args.custom_network`` is parsed but not passed on here.

    :param args: Dict[str, Any]
    :return:
    """
    args = DotDict(args)

    Init.print_ascii_logo()
    logger = Init.setup_logger(args.logdir, "eval")
    Init.log_args(logger, args)
    R.load_extern_classes(args.logdir)

    eval_container = EvalContainer(
        args.actor,
        args.epoch,
        logger,
        args.logdir,
        args.gpu_id,
        args.nb_episode,
        args.start,
        args.end,
        args.seed,
    )
    try:
        eval_container.run()
    finally:
        # Always release the container's resources, even if run() fails.
        eval_container.close()


if __name__ == "__main__":
    main(parse_args())
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/repetitive_buffer.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/exp/__pycache__/repetitive_buffer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/repetitive_buffer.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/exp/__pycache__/replay.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/replay.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/exp/__pycache__/replay.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/replay.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/exp/__pycache__/rollout.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/rollout.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/exp/__pycache__/rollout.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/__pycache__/rollout.cpython-38.pyc -------------------------------------------------------------------------------- 
/mcs/exp/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .spec_builder import ExpSpecBuilder 2 | from .exp_module import ExpModule 3 | -------------------------------------------------------------------------------- /mcs/exp/base/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/exp/base/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/exp/base/__pycache__/exp_module.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/exp_module.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/exp/base/__pycache__/exp_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/exp_module.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/exp/base/__pycache__/spec_builder.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/spec_builder.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/exp/base/__pycache__/spec_builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/exp/base/__pycache__/spec_builder.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/exp/base/exp_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 Heron Systems, Inc. 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 
class ExpModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    Abstract interface of an experience store: actor and environment outputs
    are written to it and the stored experience is read back from it.
    """

    @abc.abstractmethod
    def write_actor(self, experience):
        """
        Store experience produced by the actor.

        :param experience: actor experience to store
        """
        raise NotImplementedError

    @abc.abstractmethod
    def write_env(self, obs, rewards, terminals, infos):
        """
        Store the result of an environment step.

        :param obs: observations
        :param rewards: rewards
        :param terminals: terminal flags
        :param infos: info objects
        """
        raise NotImplementedError

    @abc.abstractmethod
    def read(self):
        """Return the stored experience."""
        raise NotImplementedError

    @abc.abstractmethod
    def is_ready(self):
        """Report whether the stored experience is ready to be read."""
        raise NotImplementedError
class ExpSpecBuilder:
    """
    Holds the sorted key names of an experience spec and a function that
    builds the spec for a given rollout length.
    """

    def __init__(
        self, obs_keys, act_keys, internal_keys, key_types, exp_keys, build_fn
    ):
        """
        :param obs_keys: mapping whose keys are the observation names
        :param act_keys: mapping whose keys are the action names
        :param internal_keys: mapping whose keys are the internal-state names
        :param key_types: stored as given
        :param exp_keys: iterable of experience key names
        :param build_fn: callable taking the rollout length
        """
        self.obs_keys = sorted(obs_keys.keys())
        self.action_keys = sorted(act_keys.keys())
        self.internal_keys = sorted(internal_keys.keys())
        self.key_types = key_types
        self.exp_keys = sorted(exp_keys)
        self.build_fn = build_fn

    def __call__(self, rollout_len):
        """Return ``build_fn(rollout_len)``."""
        return self.build_fn(rollout_len)


class RepetitiveBuffer(object):
    """
    Ring of experience buffers filled from a queue; each filled buffer is
    handed out ``max_visit_times`` times before its slot is released.

    NOTE(review): the source text of this class had lost its indentation;
    the nesting inside ``get`` is reconstructed and should be confirmed
    against the original file.
    """

    def __init__(self, inqueue, buffer_size, num_k, exp):  # (q,4,2_)
        """
        :param inqueue: queue whose ``get`` returns
            (rollouts, terminal_rewards, terminal_infos)
        :param buffer_size: number of buffer slots
        :param num_k: number of times each filled buffer is returned
        :param exp: experience object; a deep copy is kept as working buffer
        """
        self.inqueue = inqueue
        self.size = buffer_size  # 4

        self.max_visit_times = num_k  # 2
        # NOTE(review): cur_max_ttl is not read anywhere in this class.
        self.cur_max_ttl = num_k  # 2

        self.buffers = [None] * buffer_size
        # Remaining number of reads per slot.
        self.buffer_count = [0] * buffer_size  # [0,0,0,0]
        self.idx = 0
        self.exp = copy.deepcopy(exp)

    def get(self, target_actor, target_network, device):
        """
        Return the buffer of the current slot, refilling it from the queue
        first when it has no reads left.

        :param target_actor: actor whose ``act`` is run over a new batch
        :param target_network: network passed to ``target_actor.act``
        :param device: device the working experience is moved to
        :return: (buffer, terminal_rewards, terminal_infos); the last two
            are empty lists unless a new batch was fetched in this call
        """
        terminal_rewards = []
        terminal_infos = []
        if self.buffer_count[self.idx] <= 0:
            self.exp.clear()
            # Get batch from queue
            rollouts, terminal_rewards, terminal_infos = self.inqueue.get()

            # Iterate forward on batch
            self.exp.write_exps(rollouts)

            self.exp.to(device)
            r = self.exp.read()
            internals = {k: ts[0].unbind(0) for k, ts in r.internals.items()}
            with torch.no_grad():
                # NOTE(review): ``rewards`` is unused and ``t_internals`` is
                # discarded, so every step is evaluated with the initial
                # internals - confirm this is intended.
                for obs, rewards in zip(
                    r.observations, r.rewards
                ):
                    _, t_h_exp, t_internals = target_actor.act(
                        target_network, obs, internals
                    )
                    self.exp.write_actor(t_h_exp, no_env=True)
            self.exp.reset_index()
            # Snapshot the working buffer so later refills do not alter it.
            self.buffers[self.idx] = copy.deepcopy(self.exp)
            self.buffer_count[self.idx] = self.max_visit_times

        buf = self.buffers[self.idx]
        self.buffer_count[self.idx] -= 1
        released = self.buffer_count[self.idx] <= 0
        if released:
            # Slot exhausted: drop the reference and advance to the next slot.
            self.buffers[self.idx] = None
            self.idx = (self.idx + 1) % self.size
        return buf, terminal_rewards, terminal_infos

    def setMaxTimes(self, times):
        """Set how many times buffers filled from now on are returned."""
        self.max_visit_times = times
class ExperienceReplay(ExpModule):
    """Placeholder replay store: every hook is a no-op that returns None."""

    def write_actor(self, items):
        """Not implemented yet; does nothing."""
        return None

    def write_env(self, obs, rewards, terminals, infos):
        """Not implemented yet; does nothing."""
        return None

    def read(self):
        """Not implemented yet; returns None."""
        return None

    def is_ready(self):
        """Not implemented yet; returns None (falsy)."""
        return None


class PrioritizedExperienceReplay(ExpModule):
    """Placeholder prioritized replay store: every hook is a no-op."""

    def write_actor(self, items):
        """Not implemented yet; does nothing."""
        return None

    def write_env(self, obs, rewards, terminals, infos):
        """Not implemented yet; does nothing."""
        return None

    def read(self):
        """Not implemented yet; returns None."""
        return None

    def is_ready(self):
        """Not implemented yet; returns None (falsy)."""
        return None
15 | VERSION = "1.0" 16 | -------------------------------------------------------------------------------- /mcs/learner/__init__.py: -------------------------------------------------------------------------------- 1 | from mcs.learner.base.learner_module import LearnerModule 2 | from .ac_rollout import ACRolloutLearner 3 | from .impala import ImpalaLearner 4 | 5 | LEARNER_REG = [ACRolloutLearner, ImpalaLearner] 6 | -------------------------------------------------------------------------------- /mcs/learner/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/learner/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/learner/__pycache__/ac_rollout.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/__pycache__/ac_rollout.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/learner/__pycache__/ac_rollout.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/learner/__pycache__/ac_rollout.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/learner/__pycache__/impala.cpython-37.pyc: 
class ACRolloutLearner(LearnerModule):
    """
    Actor Critic Rollout Learner.

    Builds n-step bootstrapped return targets over a rollout and steps the
    updater on the sum of the value, policy-gradient and entropy losses.
    """

    args = {
        "discount": 0.99,
        "normalize_advantage": False,
        "entropy_weight": 0.01,
        "return_scale": False,
    }

    def __init__(
        self,
        reward_normalizer,
        discount,
        normalize_advantage,
        entropy_weight,
        return_scale,
    ):
        """
        :param reward_normalizer: callable applied to the stacked rewards.
        :param discount: per-step discount factor.
        :param normalize_advantage: standardize advantages when truthy.
        :param entropy_weight: multiplier on the entropy bonus.
        :param return_scale: when truthy, returns are computed in the
            squashed space of ``DeepMindReturnScaler``.
        """
        self.reward_normalizer = reward_normalizer
        self.discount = discount
        self.normalize_advantage = normalize_advantage
        self.entropy_weight = entropy_weight
        self.return_scale = return_scale
        # dm_scaler only exists when return scaling is enabled.
        if return_scale:
            self.dm_scaler = DeepMindReturnScaler(10.0 ** -3)

    @classmethod
    def from_args(cls, args, reward_normalizer):
        """Build a learner from an args namespace carrying the ``args`` keys."""
        return cls(
            reward_normalizer,
            args.discount,
            args.normalize_advantage,
            args.entropy_weight,
            args.return_scale,
        )

    def learn_step(self, updater, network, experiences, next_obs, internals):
        """
        Compute the actor-critic losses for one rollout and step ``updater``.

        NOTE(review): ``LearnerModule.learn_step`` is declared with extra
        ``target_network`` and ``avg_net`` parameters; this override keeps
        its narrower signature so existing callers are unaffected.

        :return: ``(losses, metrics)`` where ``losses`` maps loss names to
            scalar tensors and ``metrics`` is an empty dict.
        """
        # normalize rewards
        rewards = self.reward_normalizer(torch.stack(experiences.rewards))

        # stack per-step rollout entries along a new leading (time) dim
        r_log_probs_action = torch.stack(experiences.log_probs)
        r_values = torch.stack(experiences.values)
        r_entropies = torch.stack(experiences.entropies)

        # estimate value of next state; no graph is needed for the bootstrap
        with torch.no_grad():
            results, _, _ = network(next_obs, internals)
            last_values = results["critic"].squeeze(1).detach()

        # compute nstep return and advantage over batch
        r_tgt_returns = self.compute_returns(
            last_values, rewards, experiences.terminals
        )
        # detach() (rather than .data) so autograd still tracks misuse
        r_advantages = r_tgt_returns - r_values.detach()

        # normalize advantage so that an even number
        # of actions are reinforced and penalized
        if self.normalize_advantage:
            r_advantages = (r_advantages - r_advantages.mean()) / (
                r_advantages.std() + 1e-5
            )

        # batched losses; mean over actions, seq, batch
        policy_loss = (
            -r_log_probs_action * r_advantages.unsqueeze(-1)
        ).mean()
        entropy_loss = -r_entropies.mean() * self.entropy_weight
        value_loss = 0.5 * (r_tgt_returns - r_values).pow(2).mean()

        updater.step(value_loss + policy_loss + entropy_loss)

        losses = {
            "value_loss": value_loss,
            "policy_loss": policy_loss,
            "entropy_loss": entropy_loss,
        }
        metrics = {}
        return losses, metrics

    def compute_returns(self, bootstrap_value, rewards, terminals):
        """
        Walk the rollout backwards accumulating discounted returns.

        :param bootstrap_value: value estimate for the state after the
            last step; seeds the recursion.
        :param rewards: indexable sequence of per-step rewards.
        :param terminals: indexable sequence of per-step terminal flags;
            each must support ``.float()``.
        :return: stacked, graph-detached returns in forward time order.
        """
        # First step of nstep reward target is estimated value of t+1
        target_return = bootstrap_value
        nstep_target_returns = []
        for i in reversed(range(len(rewards))):
            reward = rewards[i]
            # 0 where the episode ended, so no value leaks across episodes
            terminal_mask = 1.0 - terminals[i].float()

            if self.return_scale:
                # recurse in unscaled space, then squash the result again
                target_return = self.dm_scaler.calc_scale(
                    reward
                    + self.discount
                    * self.dm_scaler.calc_inverse_scale(target_return)
                    * terminal_mask
                )
            else:
                target_return = reward + (
                    self.discount * target_return * terminal_mask
                )
            nstep_target_returns.append(target_return)

        # collected last-to-first; flip back to forward time order
        nstep_target_returns.reverse()
        return torch.stack(nstep_target_returns).detach()
class DeepMindReturnScaler:
    """
    Scale returns as in R2D2:
    h(x) = sign(x) * (sqrt(|x| + 1) - 1) + scale * x.
    https://openreview.net/pdf?id=r1lyTjAqYX
    """

    def __init__(self, scale):
        self.scale = scale

    def calc_scale(self, x):
        """Apply the squashing transform h to tensor ``x``."""
        compressed = torch.sqrt(torch.abs(x) + 1) - 1
        return torch.sign(x) * compressed + self.scale * x

    def calc_inverse_scale(self, x):
        """Apply the closed-form inverse of h to tensor ``x``."""
        eps = self.scale
        discriminant = 1 + 4 * eps * (torch.abs(x) + 1 + eps)
        # Positive root of eps*u^2 + u - (|x| + 1 + eps) = 0, u = sqrt(|y|+1)
        root = (torch.sqrt(discriminant) - 1) / (2 * eps)
        return torch.sign(x) * ((root ** 2) - 1)
class LearnerModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    This is one of the modules to use for custom Actor-Learner code.

    A learner receives agent-environment interactions and uses them to
    compute a loss. Subclasses must provide both hooks below.
    """

    @classmethod
    @abc.abstractmethod
    def from_args(cls, args, reward_normalizer):
        """Alternate constructor: build the learner from parsed ``args``."""
        raise NotImplementedError

    @abc.abstractmethod
    def learn_step(
        self,
        updater,
        network,
        target_network,
        experiences,
        next_obs,
        internals,
        avg_net,
    ):
        """Run one learning update; must be overridden by subclasses."""
        raise NotImplementedError
| MANAGER_REG = [SimpleEnvManager, SubProcEnvManager] 7 | -------------------------------------------------------------------------------- /mcs/manager/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/manager/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/manager/__pycache__/simple_env_manager.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/simple_env_manager.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/manager/__pycache__/simple_env_manager.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/simple_env_manager.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/manager/__pycache__/subproc_env_manager.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/subproc_env_manager.cpython-37.pyc -------------------------------------------------------------------------------- 
/mcs/manager/__pycache__/subproc_env_manager.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/__pycache__/subproc_env_manager.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/manager/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/base/__init__.py -------------------------------------------------------------------------------- /mcs/manager/base/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/base/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/manager/base/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/base/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/manager/base/__pycache__/manager_module.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/base/__pycache__/manager_module.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/manager/base/__pycache__/manager_module.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/manager/base/__pycache__/manager_module.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/manager/base/manager_module.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 Heron Systems, Inc. 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see https://www.gnu.org/licenses/.
class EnvManagerModule(EnvBase, RequiresArgsMixin, metaclass=abc.ABCMeta):
    """Base class for managers that own a list of environment factories."""

    def __init__(self, env_fns, engine):
        self._env_fns = env_fns
        self._engine = engine

    @property
    def env_fns(self):
        """The environment factory callables given at construction."""
        return self._env_fns

    @property
    def engine(self):
        """The engine value given at construction."""
        return self._engine

    @property
    def nb_env(self):
        """Number of managed environments (one per factory)."""
        return len(self._env_fns)

    @classmethod
    def from_args(cls, args, engine, env_cls, seed=None, nb_env=None, **kwargs):
        """
        Build a manager whose i-th environment factory is seeded ``seed + i``.

        ``seed`` and ``nb_env`` fall back to ``args.seed`` (cast to int) and
        ``args.nb_env`` when left as None; extra keyword arguments are
        forwarded to ``env_cls.from_args_curry``.
        """
        base_seed = int(args.seed) if seed is None else seed
        env_count = args.nb_env if nb_env is None else nb_env

        factories = [
            env_cls.from_args_curry(args, base_seed + offset, **kwargs)
            for offset in range(env_count)
        ]
        return cls(factories, engine)
class SimpleEnvManager(EnvManagerModule):
    """
    Manages multiple env in the same process. This is slower than a
    SubProcEnvManager but allows debugging.
    """

    args = {}

    def __init__(self, env_fns, engine):
        super().__init__(env_fns, engine)
        # Instantiate every environment eagerly, in this process.
        self.envs = [fn() for fn in env_fns]
        # Spaces and preprocessors are taken from the first environment.
        env = self.envs[0]
        self._observation_space = env.observation_space
        self._action_space = env.action_space
        self._cpu_preprocessor = env.cpu_preprocessor
        self._gpu_preprocessor = env.gpu_preprocessor

        self.buf_obs = [None] * self.nb_env
        self.buf_dones = [None] * self.nb_env
        self.buf_rews = [None] * self.nb_env
        self.buf_infos = [None] * self.nb_env
        self.actions = None

    @property
    def cpu_preprocessor(self):
        return self._cpu_preprocessor

    @property
    def gpu_preprocessor(self):
        return self._gpu_preprocessor

    @property
    def observation_space(self):
        return self._observation_space

    @property
    def action_space(self):
        return self._action_space

    def step(self, actions):
        """Synchronous step: stage ``actions`` then run every environment."""
        self.step_async(actions)
        return self.step_wait()

    def step_async(self, actions):
        """Stage actions: split the dict-of-lists per env and unwrap each
        value to a Python scalar with ``.item()``."""
        self.actions = [
            {k: v.item() for k, v in per_env.items()}
            for per_env in dlist_to_listd(actions)
        ]

    def step_wait(self):
        """
        Step every environment with its staged action.

        An environment that reports done is reset immediately and its reset
        observation replaces the terminal one.

        :return: ``(obs, rewards, dones, infos)`` with rewards and dones
            converted to tensors.
        """
        obs = []
        for e, env in enumerate(self.envs):
            (
                ob,
                self.buf_rews[e],
                self.buf_dones[e],
                self.buf_infos[e],
            ) = env.step(self.actions[e])
            if self.buf_dones[e]:
                ob = env.reset()
            obs.append(ob)
        self.buf_obs = self._collate_obs(obs)

        return (
            self.buf_obs,
            torch.tensor(self.buf_rews),
            torch.tensor(self.buf_dones),
            self.buf_infos,
        )

    def reset(self):
        """Reset every environment and return the collated observations."""
        self.buf_obs = self._collate_obs([env.reset() for env in self.envs])
        return self.buf_obs

    def close(self):
        return [e.close() for e in self.envs]

    def render(self, mode="human"):
        return [e.render(mode=mode) for e in self.envs]

    def _collate_obs(self, obs_list):
        """Merge per-env observation dicts into one dict keyed by name,
        stacking the values of tensor keys and passing the rest through."""
        merged = dummy_handle_ob(listd_to_dlist(obs_list))
        return {
            k: torch.stack(v) if self._is_tensor_key(k) else v
            for k, v in merged.items()
        }

    def _is_tensor_key(self, key):
        # A key is stackable when its preprocessed space contains no None.
        return None not in self.cpu_preprocessor.observation_space[key]


def dummy_handle_ob(ob):
    """Return a copy of ``ob`` with numpy-array values converted to torch
    tensors via ``torch.from_numpy``; other values are passed through."""
    return {
        k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v
        for k, v in ob.items()
    }
12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 15 | from .norm import Identity 16 | from .attention import MultiHeadSelfAttention, RMCCell 17 | from .sequence import LSTMCellLayerNorm 18 | from .spatial import Residual2DPreact 19 | -------------------------------------------------------------------------------- /mcs/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/modules/__pycache__/attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/attention.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/modules/__pycache__/norm.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/norm.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/modules/__pycache__/norm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/norm.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/modules/__pycache__/sequence.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/sequence.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/modules/__pycache__/sequence.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/sequence.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/modules/__pycache__/spatial.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/spatial.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/modules/__pycache__/spatial.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/modules/__pycache__/spatial.cpython-38.pyc 
class GaussianLinear(nn.Module):
    """
    Linear layer with a learned Gaussian output: in training mode it
    returns ``mu(x) + eps * exp(0.5 * std(x))`` with ``eps ~ N(0, 1)``;
    in eval mode it returns ``mu(x)`` only.
    """

    def __init__(self, fan_in, nodes):
        super().__init__()
        self.mu = nn.Linear(fan_in, nodes)
        self.std = nn.Linear(fan_in, nodes)

    def forward(self, x):
        mu = self.mu(x)
        if not self.training:
            return mu
        # The std head output is halved and exponentiated before use.
        std = torch.exp(0.5 * self.std(x))
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def get_parameter_names(self):
        return ["Mu_W", "Mu_b", "Std_W", "Std_b"]


class NoisyLinear(nn.Linear):
    """
    Linear layer with externally supplied parameter noise.

    Reference implementation:
    https://github.com/Kaixhin/NoisyNet-A3C/blob/master/model.py
    """

    def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
        # NOTE(review): ``bias`` is accepted but ignored; the layer always
        # has a bias. Left unchanged so existing callers and checkpoints
        # keep the same parameter set.
        super().__init__(in_features, out_features, bias=True)
        # mu^w and mu^b reuse self.weight and self.bias
        self.sigma_init = sigma_init
        self.sigma_weight = Parameter(
            torch.empty(out_features, in_features)
        )  # sigma^w
        self.sigma_bias = Parameter(torch.empty(out_features))  # sigma^b
        self.init_params()

    def init_params(self):
        """Uniformly initialize mu in +-sqrt(3 / in_features) and fill
        sigma with ``sigma_init``."""
        limit = math.sqrt(3 / self.in_features)

        with torch.no_grad():
            self.weight.uniform_(-limit, limit)
            self.bias.uniform_(-limit, limit)
            self.sigma_weight.fill_(self.sigma_init)
            self.sigma_bias.fill_(self.sigma_init)

    def forward(self, x, internals):
        """
        :param x: input batch.
        :param internals: ``(eps_w, eps_b)`` noise pair as returned by
            ``reset``; only read in training mode.
        """
        if self.training:
            w = self.weight + self.sigma_weight * internals[0]
            b = self.bias + self.sigma_bias * internals[1]
        else:
            w = self.weight + self.sigma_weight
            b = self.bias + self.sigma_bias
        return F.linear(x, w, b)

    def batch_forward(self, x, internals, batch_size=None):
        """
        Apply a different noise sample to every row of ``x``.

        Row ``i`` equals the training-mode ``forward(x[i:i+1], internals[i])``:
        the noise is multiplied by sigma, not added to it.

        :param x: input of shape (batch, in_features).
        :param internals: sequence of ``(eps_w, eps_b)`` pairs, one per row.
        :param batch_size: overrides ``x.shape[0]`` when given.
        """
        print(
            "WARNING: calling forward multiple times is actually "
            "faster than this and takes less memory"
        )
        batch_size = batch_size if batch_size is not None else x.shape[0]
        x = x.unsqueeze(1)
        # internals come in as [[w, b], ...] reshape to [w, ...], [b, ...]
        eps_w, eps_b = zip(*internals)
        eps_w = torch.stack(eps_w)
        eps_b = torch.stack(eps_b)

        # per-sample weights: mu^w + sigma^w * eps^w -> (batch, out, in)
        batch_w = self.weight.unsqueeze(0).expand(
            batch_size, -1, -1
        ) + self.sigma_weight.unsqueeze(0).expand(batch_size, -1, -1) * eps_w
        # permute to b x m x p
        batch_w = batch_w.permute(0, 2, 1)
        # per-sample biases: mu^b + sigma^b * eps^b -> (batch, out)
        batch_b = (
            self.bias.expand(batch_size, -1)
            + self.sigma_bias.expand(batch_size, -1) * eps_b
        )

        return torch.bmm(x, batch_w).squeeze(1) + batch_b

    def reset(self, gpu=False, device=None):
        """Sample a fresh ``(eps_w, eps_b)`` noise pair, optionally moved
        to CUDA ``device``."""
        eps_w = torch.randn(self.out_features, self.in_features)
        eps_b = torch.randn(self.out_features)
        if gpu:
            eps_w = eps_w.cuda(device, non_blocking=True)
            eps_b = eps_b.cuda(device, non_blocking=True)
        return eps_w.detach(), eps_b.detach()

    def get_parameter_names(self):
        return ["W", "b", "sigma_W", "sigma_b"]
class Identity(torch.nn.Module):
    """No-op module: ``forward`` hands its input back untouched."""

    def forward(self, x):
        return x


class LSTMCellLayerNorm(Module):
    """
    A lstm cell that layer norms the cell state
    https://github.com/seba-1511/lstms.pth/blob/master/lstms/lstm.py for reference.
    Original License Apache 2.0
    """

    def __init__(self, input_size, hidden_size, bias=True, forget_bias=0):
        """
        :param input_size: number of input features
        :param hidden_size: number of hidden / cell features
        :param bias: whether the two projections carry a bias
        :param forget_bias: value written to the forget-gate slice of BOTH
            projection biases (their contributions add up in ``forward``)
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Each projection emits the four gate pre-activations side by side.
        self.ih = Linear(input_size, 4 * hidden_size, bias=bias)
        self.hh = Linear(hidden_size, 4 * hidden_size, bias=bias)

        if bias:
            forget_gate = slice(hidden_size, hidden_size * 2)
            for projection in (self.ih, self.hh):
                projection.bias.data.fill_(0)
                # forget bias init
                projection.bias.data[forget_gate].fill_(forget_bias)

        self.ln_cell = LayerNorm(hidden_size)

    def forward(self, x, hidden):
        """
        LSTM Cell that layer normalizes the cell state.
        :param x: Tensor{B, C}
        :param hidden: A Tuple[Tensor{B, C}, Tensor{B, C}] of (previous output, cell state)
        :return: Tuple of (new output, new layer-normalized cell state)
        """
        prev_h, prev_c = hidden

        # Linear mappings
        preact = self.ih(x) + self.hh(prev_h)

        # Gate layout along dim 1: input, forget, output, candidate.
        i_pre, f_pre, o_pre, g_pre = preact.split(self.hidden_size, dim=1)
        i_t = i_pre.sigmoid()
        f_t = f_pre.sigmoid()
        o_t = o_pre.sigmoid()
        g_t = g_pre.tanh()

        # cell computations
        c_t = self.ln_cell(prev_c * f_t + i_t * g_t)
        h_t = o_t * c_t.tanh()

        return h_t, c_t
class Residual2DPreact(nn.Module):
    """
    Pre-activation 2D residual block: BN -> ReLU -> conv, twice, plus a
    shortcut.

    The shortcut is the raw input unless the channel count changes or the
    block is strided, in which case a 3x3 convolution of the first
    pre-activation is used instead.
    """

    def __init__(self, nb_in_chan, nb_out_chan, stride=1):
        """
        :param nb_in_chan: number of input channels
        :param nb_out_chan: number of output channels
        :param stride: stride of the first convolution and of the projection
        """
        super().__init__()

        self.nb_in_chan = nb_in_chan
        self.nb_out_chan = nb_out_chan
        self.stride = stride

        self.bn1 = nn.BatchNorm2d(nb_in_chan)
        self.conv1 = nn.Conv2d(
            nb_in_chan, nb_out_chan, 3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(nb_out_chan)
        self.conv2 = nn.Conv2d(
            nb_out_chan, nb_out_chan, 3, stride=1, padding=1, bias=False
        )

        # A projection is required whenever the identity shortcut would not
        # match the residual branch in shape.
        self.do_projection = (
            self.nb_in_chan != self.nb_out_chan or self.stride > 1
        )
        to_rescale = [self.conv1, self.conv2]
        if self.do_projection:
            self.projection = nn.Conv2d(
                nb_in_chan, nb_out_chan, 3, stride=stride, padding=1
            )
            to_rescale.append(self.projection)

        # Scale the default-initialised weights by the ReLU gain.
        relu_gain = nn.init.calculate_gain("relu")
        for conv in to_rescale:
            conv.weight.data.mul_(relu_gain)

    def forward(self, x):
        """
        :param x: input feature map with ``nb_in_chan`` channels
        :return: residual branch output plus the shortcut
        """
        preact = F.relu(self.bn1(x))
        shortcut = self.projection(preact) if self.do_projection else x
        residual = self.conv1(preact)
        residual = self.conv2(F.relu(self.bn2(residual)))
        return residual + shortcut
.net1d.submodule_1d import SubModule1D 4 | from .net2d.submodule_2d import SubModule2D 5 | from .net3d.submodule_3d import SubModule3D 6 | from .net4d.submodule_4d import SubModule4D 7 | 8 | from .net1d.linear import Linear 9 | from .net1d.identity_1d import Identity1D 10 | from .net1d.lstm import LSTM 11 | 12 | from .net2d.identity_2d import Identity2D 13 | from .net3d.identity_3d import Identity3D 14 | from .net3d.four_conv import FourConv 15 | 16 | from .net4d.identity_4d import Identity4D 17 | from .net2d.deconv import DeConv2D 18 | 19 | 20 | 21 | NET_REG = [] 22 | SUBMOD_REG = [ 23 | Identity1D, 24 | Linear, 25 | LSTM, 26 | Identity2D, 27 | Identity3D, 28 | FourConv, 29 | Identity4D, 30 | DeConv2D 31 | 32 | ] 33 | -------------------------------------------------------------------------------- /mcs/network/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/__pycache__/modular_network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/__pycache__/modular_network.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/__pycache__/modular_network.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/__pycache__/modular_network.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__init__.py -------------------------------------------------------------------------------- /mcs/network/base/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/base/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/base/__pycache__/base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/base.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/base/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/base/__pycache__/network_module.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/network_module.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/base/__pycache__/network_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/network_module.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/base/__pycache__/submodule.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/submodule.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/base/__pycache__/submodule.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/base/__pycache__/submodule.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/base/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 Heron Systems, Inc. 
class BaseNetwork(torch.nn.Module):
    """
    Interface shared by all networks: construction from args, creation of
    fresh internal state, a forward pass, and parameter synchronisation
    through ``torch.distributed``.
    """

    @classmethod
    @abc.abstractmethod
    def from_args(
        cls, args, observation_space, output_space, gpu_preprocessor, net_reg
    ):
        """Build a network from parsed arguments; subclasses must override."""
        raise NotImplementedError

    @abc.abstractmethod
    def new_internals(self, device):
        """
        :return: Dict[InternalKey, torch.Tensor (ND)]
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, observation, internals):
        """Run the network; subclasses must override."""
        raise NotImplementedError

    def internal_space(self):
        """
        :return: Dict[InternalKey, torch.Size] describing the tensors
            produced by ``new_internals`` on the CPU
        """
        cpu_internals = self.new_internals("cpu")
        return {key: tensor.shape for key, tensor in cpu_internals.items()}

    def sync(self, src, grp=None, async_op=False):
        """
        Broadcast every tensor of ``state_dict()`` from rank ``src``.

        :param src: source rank of the broadcast
        :param grp: optional process group; omitted from the call when None
        :param async_op: when false, wait for all broadcasts before returning
        :return: list of the work handles returned by ``dist.broadcast``
        """
        handles = []

        # Broadcasts are always launched asynchronously so that they can
        # overlap; blocking callers wait for all of them afterwards.
        for tensor in self.state_dict().values():
            if grp is None:
                handle = dist.broadcast(tensor, src, async_op=True)
            else:
                handle = dist.broadcast(tensor, src, grp, async_op=True)
            handles.append(handle)

        if not async_op:
            for handle in handles:
                handle.wait()

        return handles
class NetworkModule(BaseNetwork, RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    Abstract network base combining ``BaseNetwork`` with
    ``RequiresArgsMixin``.

    Adds no members of its own; using ``abc.ABCMeta`` as metaclass is what
    makes the ``abc.abstractmethod`` markers inherited from ``BaseNetwork``
    enforced at instantiation time.
    """
-------------------------------------------------------------------------------- /mcs/network/net1d/__pycache__/identity_1d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/identity_1d.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net1d/__pycache__/identity_1d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/identity_1d.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net1d/__pycache__/linear.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/linear.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net1d/__pycache__/linear.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/linear.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net1d/__pycache__/lstm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/lstm.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net1d/__pycache__/lstm.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/lstm.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net1d/__pycache__/submodule_1d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/submodule_1d.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net1d/__pycache__/submodule_1d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net1d/__pycache__/submodule_1d.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net1d/identity_1d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 Heron Systems, Inc. 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 
class Identity1D(SubModule1D):
    """1D submodule that passes its input through unchanged and is stateless."""

    # No configurable arguments.
    args = {}

    def __init__(self, input_shape, id):
        super(Identity1D, self).__init__(input_shape, id)

    @classmethod
    def from_args(cls, args, input_shape, id):
        """Build the module; ``args`` is accepted but not read."""
        return cls(input_shape, id)

    @property
    def _output_shape(self):
        """Output shape equals the input shape."""
        return self.input_shape

    def _forward(self, input, internals, **kwargs):
        """
        :return: the untouched input and an empty internals dict
        """
        no_internals = {}
        return input, no_internals

    def _new_internals(self):
        """:return: an empty dict, since there is no recurrent state"""
        return {}
class Linear(SubModule1D):
    """
    Stack of ``nb_layer`` fully connected layers, each followed by an
    optional normalisation and a ReLU.
    """

    args = {"linear_normalize": "bn", "linear_nb_hidden": 512, "nb_layer": 3}

    def __init__(self, input_shape, id, normalize, nb_hidden, nb_layer):
        """
        :param input_shape: shape of the input; only ``input_shape[0]`` is read
        :param id: identifier forwarded to ``SubModule1D``
        :param normalize: ``"bn"`` for BatchNorm1d, ``"gn"`` for GroupNorm
            (groups of 16 channels); anything else disables normalisation
        :param nb_hidden: width of every layer
        :param nb_layer: number of layers
        :raises ValueError: if ``normalize == "gn"`` and ``nb_hidden`` is
            not divisible by 16
        """
        super().__init__(input_shape, id)
        self._nb_hidden = nb_hidden

        nb_input_channel = input_shape[0]

        # Drop the bias only when a normalisation layer actually follows.
        # The previous ``not normalize`` also removed it for unrecognised
        # truthy values, which fall through to Identity below and would
        # leave the layer with neither bias nor normalisation.
        bias = normalize not in ("bn", "gn")
        self.linears = nn.ModuleList(
            [
                nn.Linear(
                    nb_input_channel if i == 0 else nb_hidden, nb_hidden, bias
                )
                for i in range(nb_layer)
            ]
        )
        if normalize == "bn":
            self.norms = nn.ModuleList(
                [nn.BatchNorm1d(nb_hidden) for _ in range(nb_layer)]
            )
        elif normalize == "gn":
            if nb_hidden % 16 != 0:
                raise ValueError(
                    "linear_nb_hidden must be divisible by 16 for Group Norm"
                )
            self.norms = nn.ModuleList(
                [
                    nn.GroupNorm(nb_hidden // 16, nb_hidden)
                    for _ in range(nb_layer)
                ]
            )
        else:
            self.norms = nn.ModuleList([Identity() for _ in range(nb_layer)])

    @classmethod
    def from_args(cls, args, input_shape, id):
        """Build the module from ``linear_normalize``, ``linear_nb_hidden``
        and ``nb_layer`` on ``args``."""
        return cls(
            input_shape,
            id,
            args.linear_normalize,
            args.linear_nb_hidden,
            args.nb_layer,
        )

    def _forward(self, xs, internals, **kwargs):
        """
        :param xs: input batch; ``internals`` is not read
        :return: the activations of the last layer and an empty internals dict
        """
        for linear, norm in zip(self.linears, self.norms):
            xs = F.relu(norm(linear(xs)))
        return xs, {}

    def _new_internals(self):
        """:return: an empty dict, since there is no recurrent state"""
        return {}

    @property
    def _output_shape(self):
        """:return: ``(nb_hidden,)``"""
        return (self._nb_hidden,)
class LSTM(SubModule1D):
    """
    Single LSTM cell wrapped as a 1D submodule, with optional layer
    normalisation of the cell state.
    """

    args = {"lstm_normalize": True, "lstm_nb_hidden": 512}

    def __init__(self, input_shape, id, normalize, nb_hidden):
        """
        :param input_shape: shape of the input; only ``input_shape[0]`` is read
        :param id: identifier forwarded to ``SubModule1D``
        :param normalize: use ``LSTMCellLayerNorm`` when truthy, otherwise a
            plain ``LSTMCell`` with zeroed biases
        :param nb_hidden: size of the hidden and cell states
        """
        super().__init__(input_shape, id)
        self._nb_hidden = nb_hidden
        nb_input = input_shape[0]

        if not normalize:
            cell = LSTMCell(nb_input, nb_hidden)
            cell.bias_ih.data.fill_(0)
            cell.bias_hh.data.fill_(0)
        else:
            cell = LSTMCellLayerNorm(nb_input, nb_hidden)
        self.lstm = cell

    @classmethod
    def from_args(cls, args, input_shape, id):
        """Build the module from ``lstm_normalize`` and ``lstm_nb_hidden``."""
        return cls(input_shape, id, args.lstm_normalize, args.lstm_nb_hidden)

    @property
    def _output_shape(self):
        """:return: ``(nb_hidden,)``"""
        return (self._nb_hidden,)

    def _forward(self, xs, internals, **kwargs):
        """
        Advance the cell by one step.

        :param xs: input batch
        :param internals: holds the ``"hx"`` and ``"cx"`` states, batched via
            ``stacked_internals`` (defined outside this file section)
        :return: the new hidden state and a dict with the per-sample
            ``"hx"`` / ``"cx"`` tensors unbound along dim 0
        """
        prev_h = self.stacked_internals("hx", internals)
        prev_c = self.stacked_internals("cx", internals)
        next_h, next_c = self.lstm(xs, (prev_h, prev_c))
        next_internals = {
            "hx": list(torch.unbind(next_h, dim=0)),
            "cx": list(torch.unbind(next_c, dim=0)),
        }
        return next_h, next_internals

    def _new_internals(self):
        """:return: zero-initialised ``"hx"`` and ``"cx"`` of size nb_hidden"""
        return {
            "hx": torch.zeros(self._nb_hidden),
            "cx": torch.zeros(self._nb_hidden),
        }

    def _to_2d_shape(self):
        """:return: the output shape with one trailing singleton dimension"""
        return self._output_shape + (1,)
class SubModule1D(SubModule, metaclass=abc.ABCMeta):
    """
    Base class for submodules with a 1D output shape.

    The ``_to_Nd_shape`` helpers express that shape in N dimensions by
    appending singleton dimensions.
    """

    dim = 1

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)

    def _to_1d_shape(self):
        """:return: the output shape as-is"""
        return self._output_shape

    def _to_2d_shape(self):
        """:return: the output shape followed by one singleton dimension"""
        return self._output_shape + (1,) * 1

    def _to_3d_shape(self):
        """:return: the output shape followed by two singleton dimensions"""
        return self._output_shape + (1,) * 2

    def _to_4d_shape(self):
        """:return: the output shape followed by three singleton dimensions"""
        return self._output_shape + (1,) * 3
-------------------------------------------------------------------------------- /mcs/network/net2d/__pycache__/deconv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/deconv.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net2d/__pycache__/identity_2d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/identity_2d.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net2d/__pycache__/identity_2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/identity_2d.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net2d/__pycache__/submodule_2d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/submodule_2d.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net2d/__pycache__/submodule_2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net2d/__pycache__/submodule_2d.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net2d/deconv.py: 
class DeConv2D(SubModule2D):
    """
    Dueling-style deconvolution head.

    A linear layer expands the input features into a
    ``_pc_layers x 9 x 9`` map; two transposed convolutions turn it into a
    one-channel value map and a ``num_action``-channel advantage map, which
    are combined as ``v + a - mean(a)`` over the action channels.
    """

    args = {}

    # Geometry of the deconvolution; the spatial output size is derived
    # from these instead of being hard-coded.
    _FC_SIZE = 9
    _KERNEL_SIZE = 5
    _STRIDE = 2

    def __init__(self, input_shape, id):
        """
        :param input_shape: shape of the input; ``input_shape[0]`` is the
            number of features fed to the linear layer
        :param id: identifier forwarded to ``SubModule2D``
        """
        super().__init__(input_shape, id)
        self._input_shape = input_shape
        self._out_shape = None
        self._pc_layers = 16
        self.num_action = 6

        self.fc = nn.Sequential(
            nn.Linear(
                in_features=input_shape[0],
                out_features=self._FC_SIZE * self._FC_SIZE * self._pc_layers,
            ),
            nn.ReLU(),
        )

        self.deconv_v = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=self._pc_layers,
                out_channels=1,
                kernel_size=self._KERNEL_SIZE,
                stride=self._STRIDE,
            ),
            nn.ReLU(),
        )
        self.deconv_a = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=self._pc_layers,
                out_channels=self.num_action,
                kernel_size=self._KERNEL_SIZE,
                stride=self._STRIDE,
            ),
            nn.ReLU(),
        )

    @classmethod
    def from_args(cls, args, input_shape, id):
        """Build the module; ``args`` is accepted but not read."""
        return cls(input_shape, id)

    @property
    def _map_size(self):
        """Side length of the deconvolved map (21 for the 9/5/2 geometry)."""
        return calc_output_dim(self._FC_SIZE, self._KERNEL_SIZE, self._STRIDE)

    @property
    def _output_shape(self):
        """:return: ``(map_size * map_size, num_action)`` unless overridden"""
        if self._out_shape is None:
            self._out_shape = self._map_size * self._map_size, self.num_action
        return self._out_shape

    @_output_shape.setter
    def _output_shape(self, value):
        # Store into the attribute the getter reads. The previous
        # ``self.__output_shape`` was name-mangled and never read back.
        self._out_shape = value

    def _forward(self, input, internals=None, **kwargs):
        """
        :param input: feature tensor, flattened to ``(-1, input_shape[0])``
            (the original debug comment suggests e.g. 32 x 512 x 1)
        :param internals: not read
        :return: Tensor{B, map_size * map_size, num_action} and an empty
            internals dict
        """
        # Flatten by the configured feature size rather than a literal 512
        # so the view always agrees with ``self.fc``.
        feats = self.fc(input.view(-1, self._input_shape[0]))
        feats = feats.view([-1, self._pc_layers, self._FC_SIZE, self._FC_SIZE])
        v = self.deconv_v(feats)
        a = self.deconv_a(feats)
        a_mean = torch.mean(a, dim=1, keepdim=True)
        q = v + a - a_mean
        out = q.reshape(-1, self.num_action, self._map_size * self._map_size)
        return out.permute(0, 2, 1).contiguous(), {}

    def _new_internals(self):
        """:return: an empty dict, since there is no recurrent state"""
        return {}


def calc_output_dim(dim_size, kernel_size, stride=1, input_padding=0, output_padding=0, dilation=1):
    """
    Output length of a transposed convolution along one dimension.

    :param dim_size: input length
    :param kernel_size: kernel length
    :param stride: stride of the transposed convolution
    :param input_padding: padding applied to the input
    :param output_padding: extra size added to one side of the output
    :param dilation: spacing between kernel elements
    :return: resulting output length
    """
    return (
        (dim_size - 1) * stride
        - 2 * input_padding
        + dilation * (kernel_size - 1)
        + output_padding
        + 1
    )


if __name__ == "__main__":
    out = [4, 5, 1]
    for output_dim in out:
        output_dim = calc_output_dim(output_dim, 5, 2)
    print(output_dim)
from .submodule_2d import SubModule2D


class Identity2D(SubModule2D):
    """Pass-through 2D submodule: hands its input back unchanged."""

    args = {}

    def __init__(self, input_shape, id):
        super(Identity2D, self).__init__(input_shape, id)

    @classmethod
    def from_args(cls, args, input_shape, id):
        """Build an instance; ``args`` is accepted but not used."""
        module = cls(input_shape, id)
        return module

    @property
    def _output_shape(self):
        """Output shape, reported as ``self.input_shape``."""
        shape = self.input_shape
        return shape

    def _forward(self, input, internals, **kwargs):
        """Return ``input`` untouched together with an empty internals dict."""
        return input, dict()

    def _new_internals(self):
        """Return an empty dict (no internal state is created here)."""
        return dict()
# --------------------------------------------------------------------------------
# /mcs/network/net2d/submodule_2d.py:
# --------------------------------------------------------------------------------
# Copyright (C) 2018 Heron Systems, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
import abc
import math

from ..base.submodule import SubModule


class SubModule2D(SubModule, metaclass=abc.ABCMeta):
    """Base for submodules whose ``_output_shape`` unpacks into two values.

    Offers conversions of that pair into 1-, 2-, 3- and 4-element shapes.
    """

    dim = 2

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)

    def _to_1d_shape(self):
        """Collapse both entries into a single-element tuple."""
        first, second = self._output_shape
        flat = first * second
        return (flat,)

    def _to_2d_shape(self):
        """Return ``_output_shape`` as-is."""
        shape = self._output_shape
        return shape

    def _to_3d_shape(self):
        """Collapse both entries and append two singleton dimensions."""
        first, second = self._output_shape
        flat = first * second
        return (flat, 1, 1)

    def _to_4d_shape(self):
        """Keep both entries and append two singleton dimensions."""
        first, second = self._output_shape
        return (first, second, 1, 1)
# --------------------------------------------------------------------------------
# /mcs/network/net3d/RelationalMHDPA.py:
# --------------------------------------------------------------------------------
# Copyright (C) 2018 Heron Systems, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
from mcs.network.net2d.submodule_2d import SubModule2D
import torch.nn as nn
import torch
import numpy as np
import math


class RelationalMHDPA(nn.Module):
    """Multi-head dot-product attention over a flattened feature map.

    A single linear projection produces queries, keys and values; attention
    is computed per head, the heads are merged and passed through one more
    linear layer.
    """

    args = {}

    def __init__(self, input_shape, nb_head, scale=False):
        """
        :param input_shape: ``(nb_channel, height, width)``
        :param nb_head: number of heads; must divide ``nb_channel``
        :param scale: if true, divide the scores by ``sqrt(channels per head)``
        :raises ValueError: if ``nb_channel`` is not divisible by ``nb_head``
        """
        super(RelationalMHDPA, self).__init__()
        self._out_shape = input_shape

        nb_channel, height, width = input_shape[0], input_shape[1], input_shape[2]

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if nb_channel % nb_head != 0:
            raise ValueError(
                "nb_channel ({}) must be divisible by nb_head ({})".format(
                    nb_channel, nb_head
                )
            )
        seq_len = height * width
        # Lower-triangular mask of ones, shaped for broadcasting over
        # (batch, head).
        # NOTE(review): this makes position i attend only to positions
        # j <= i; confirm that is intended for spatial positions.
        self.register_buffer(
            "b",
            torch.tril(torch.ones(seq_len, seq_len)).view(
                1, 1, seq_len, seq_len
            ),
        )
        self.nb_head = nb_head
        self.split_size = nb_channel
        self.scale = scale
        self.projection = nn.Linear(nb_channel, nb_channel * 3)
        # Not used by forward(); kept so the attribute stays available.
        self.re = nn.ReLU()
        self.mlp = nn.Linear(nb_channel, nb_channel)

    def forward(self, x):
        """
        :param x: A tensor with a shape of [batch, seq_len, nb_channel]
        :return: A tensor with a shape of [batch, seq_len, nb_channel]
        """
        size_out = x.size()[:-1] + (self.split_size * 3,)  # [B, T, 3C]

        # reshape (not view) so a non-contiguous input is accepted as well.
        x = self.projection(x.reshape(-1, x.size(-1)))  # [B*T, 3C]
        x = x.view(*size_out)  # [B, T, 3C]

        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        attended = self._attn(query, key, value)
        merged = self.merge_heads(attended)

        return self.mlp(merged)

    def _new_internals(self):
        """Return an empty dict (no internal state is created here)."""
        return {}

    def _attn(self, q, k, v):
        """Masked softmax attention.

        ``q`` and ``v`` are [B, H, T, C//H]; ``k`` is [B, H, C//H, T].
        """
        w = torch.matmul(q, k)  # [B, H, T, T]

        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # Crop the mask to the actual sequence length so inputs shorter than
        # the configured height * width still broadcast.
        mask = self.b[:, :, : w.size(-2), : w.size(-1)]
        # Masked-out scores are pushed to -1e9 so softmax gives them ~0.
        w = w * mask + -1e9 * (1 - mask)
        # Functional softmax: no module object is built per call.
        w = torch.softmax(w, dim=-1)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        """[B, H, T, C//H] -> [B, T, C]."""
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        """[B, T, C] -> per-head layout.

        Returns [B, H, C//H, T] when ``k`` is true (ready to be the right
        operand of the score matmul), otherwise [B, H, T, C//H].
        """
        new_x_shape = x.size()[:-1] + (self.nb_head, x.size(-1) // self.nb_head)
        x = x.view(*new_x_shape)  # [B, T, H, C//H]
        if k:
            # batch, head, channel, attend
            return x.permute(0, 2, 3, 1)
        # batch, head, attend, channel
        return x.permute(0, 2, 1, 3)

    def get_parameter_names(self, layer):
        """Return four name strings built from ``layer``."""
        return [
            "Proj{}_W".format(layer),
            "Proj{}_b".format(layer),
            "MLP{}_W".format(layer),
            "MLP{}_b".format(layer),
        ]
# --------------------------------------------------------------------------------
# /mcs/network/net3d/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__init__.py
# --------------------------------------------------------------------------------
# /mcs/network/net3d/__pycache__/RelationalMHDPA.cpython-38.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/RelationalMHDPA.cpython-38.pyc
# --------------------------------------------------------------------------------
# /mcs/network/net3d/__pycache__/__init__.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
# /mcs/network/net3d/__pycache__/__init__.cpython-38.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net3d/__pycache__/deconv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/deconv.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net3d/__pycache__/four_conv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/four_conv.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net3d/__pycache__/four_conv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/four_conv.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net3d/__pycache__/identity_3d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/identity_3d.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net3d/__pycache__/identity_3d.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/identity_3d.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net3d/__pycache__/submodule_3d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/submodule_3d.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net3d/__pycache__/submodule_3d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net3d/__pycache__/submodule_3d.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net3d/four_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 Heron Systems, Inc. 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 
from torch.nn import Conv2d, BatchNorm2d, GroupNorm, init, functional as F

from mcs.modules import Identity
from mcs.network.net3d.submodule_3d import SubModule3D
from mcs.network.net3d.RelationalMHDPA import RelationalMHDPA


class FourConv(SubModule3D):
    """Four stride-2 convolutions with optional normalization and attention.

    Each convolution is followed by a normalization layer (batch norm, group
    norm or identity) and ReLU.  When ``args.use_mhra`` is set, the output of
    a ``RelationalMHDPA`` block is added to the final feature map.
    """

    args = {"fourconv_norm": "bn"}

    def __init__(self, in_shape, id, normalize, args):
        """
        :param in_shape: ``(channels, height, width)`` of the input
        :param id: identifier forwarded to the base class
        :param normalize: ``"bn"``, ``"gn"`` or anything else for no norm
        :param args: object providing ``use_mhra`` and, if that is set,
            ``num_head``
        """
        super().__init__(in_shape, id)
        # NOTE(review): any truthy ``normalize`` disables the conv bias, even
        # when it is neither "bn" nor "gn" and the Identity branch below is
        # taken.  Left unchanged to keep existing parameter sets; confirm.
        bias = not normalize
        self._in_shape = in_shape
        self._out_shape = None
        self._args = args
        self.conv1 = Conv2d(in_shape[0], 32, 7, stride=2, padding=1, bias=bias)
        self.conv2 = Conv2d(32, 32, 3, stride=2, padding=1, bias=bias)
        self.conv3 = Conv2d(32, 32, 3, stride=2, padding=1, bias=bias)
        self.conv4 = Conv2d(32, 32, 3, stride=2, padding=1, bias=bias)

        self.use_mhra = args.use_mhra
        if normalize == "bn":
            self.bn1 = BatchNorm2d(32)
            self.bn2 = BatchNorm2d(32)
            self.bn3 = BatchNorm2d(32)
            self.bn4 = BatchNorm2d(32)
        elif normalize == "gn":
            self.bn1 = GroupNorm(8, 32)
            self.bn2 = GroupNorm(8, 32)
            self.bn3 = GroupNorm(8, 32)
            self.bn4 = GroupNorm(8, 32)
        else:
            self.bn1 = Identity()
            self.bn2 = Identity()
            self.bn3 = Identity()
            self.bn4 = Identity()
        if args.use_mhra:
            # Size the attention block from the computed output shape
            # ((32, 5, 5) for an 84x84 input, the value that used to be
            # hard-coded) so other input resolutions work too.
            self.att = RelationalMHDPA(
                input_shape=self._output_shape, nb_head=args.num_head
            )

        relu_gain = init.calculate_gain("relu")
        self.conv1.weight.data.mul_(relu_gain)
        self.conv2.weight.data.mul_(relu_gain)
        self.conv3.weight.data.mul_(relu_gain)
        self.conv4.weight.data.mul_(relu_gain)

    @classmethod
    def from_args(cls, args, in_shape, id):
        """Build an instance from ``args``.

        NOTE(review): the class-level ``args`` dict declares
        ``fourconv_norm`` but ``args.linear_normalize`` is what is read
        here; confirm which one is intended.
        """
        return cls(in_shape, id, args.linear_normalize, args)

    @property
    def _output_shape(self):
        """``(32, out_h, out_w)``; for 84x84 inputs this is ``(32, 5, 5)``."""
        if self._out_shape is None:
            dims = []
            # Height and width go through the same four convolutions
            # independently (previously the height result was reused for
            # the width, which is only right for square inputs).
            for size in (self._in_shape[1], self._in_shape[2]):
                size = calc_output_dim(size, 7, 2, 1, 1)
                size = calc_output_dim(size, 3, 2, 1, 1)
                size = calc_output_dim(size, 3, 2, 1, 1)
                size = calc_output_dim(size, 3, 2, 1, 1)
                dims.append(size)
            self._out_shape = 32, dims[0], dims[1]
        return self._out_shape

    def _forward(self, xs, internals, **kwargs):
        """Run the conv stack (plus the attention residual if enabled)."""
        xs = F.relu(self.bn1(self.conv1(xs)))
        xs = F.relu(self.bn2(self.conv2(xs)))
        xs = F.relu(self.bn3(self.conv3(xs)))
        xs = F.relu(self.bn4(self.conv4(xs)))
        if self.use_mhra:
            xs = xs + self._atten(xs, self.att)
        return xs, {}

    def _new_internals(self):
        """Return an empty dict (no internal state is created here)."""
        return {}

    def _atten(self, x, att):
        """Apply ``att`` over spatial positions of a 32-channel feature map."""
        height, width = x.shape[-2:]
        # [B, 32, H, W] -> [B, H*W, 32]: positions become the sequence axis.
        tokens = x.view(-1, 32, height * width).permute(0, 2, 1).contiguous()
        attended = att(tokens)
        # Back to channel-first and the original spatial layout.
        attended = attended.permute(0, 2, 1).contiguous()
        return attended.view(-1, 32, height, width)


def calc_output_dim(dim_size, kernel_size, stride, padding, dilation):
    """Return the output size of a convolution along one axis."""
    numerator = dim_size + 2 * padding - dilation * (kernel_size - 1) - 1
    return numerator // stride + 1


if __name__ == "__main__":
    output_dim = 84
    output_dim = calc_output_dim(output_dim, 7, 2, 1, 1)
    output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
    output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
    print(output_dim)
    output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
    print(output_dim)  # should be 5
# --------------------------------------------------------------------------------
# /mcs/network/net3d/identity_3d.py:
# --------------------------------------------------------------------------------
# Copyright (C) 2018 Heron Systems, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
from .submodule_3d import SubModule3D


class Identity3D(SubModule3D):
    """Pass-through 3D submodule: hands its input back unchanged."""

    args = {}

    def __init__(self, input_shape, id):
        super(Identity3D, self).__init__(input_shape, id)

    @classmethod
    def from_args(cls, args, input_shape, id):
        """Build an instance; ``args`` is accepted but not used."""
        module = cls(input_shape, id)
        return module

    @property
    def _output_shape(self):
        """Output shape, reported as ``self.input_shape``."""
        shape = self.input_shape
        return shape

    def _forward(self, input, internals, **kwargs):
        """Return ``input`` untouched together with an empty internals dict."""
        return input, dict()

    def _new_internals(self):
        """Return an empty dict (no internal state is created here)."""
        return dict()
# --------------------------------------------------------------------------------
# /mcs/network/net3d/rmc.py:
# --------------------------------------------------------------------------------
# Copyright (C) 2018 Heron Systems, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
import torch
from torch.nn import Conv2d, Linear, BatchNorm2d, BatchNorm1d, functional as F

from mcs.modules import RMCCell, Identity

from mcs.network.net3d.submodule_3d import SubModule3D

# TODO
class RMC(SubModule3D):
    """
    Relational Memory Core
    https://arxiv.org/pdf/1806.01822.pdf

    NOTE(review): this class is marked TODO and is internally inconsistent;
    the points below are flagged rather than guessed at:

    * ``from_args`` calls ``cls(in_shape, id, norm)`` while ``__init__``
      takes ``(nb_in_chan, output_shape_dict, normalize)``.
    * ``_output_shape`` reads ``self._in_shape`` / ``self._out_shape``,
      which this class never assigns.
    * ``_forward`` feeds the 32-channel output of ``conv3`` into ``conv4``,
      which is built for 34 input channels.
    """

    def __init__(self, nb_in_chan, output_shape_dict, normalize):
        """
        :param nb_in_chan: number of input channels for the first convolution
        :param output_shape_dict: forwarded as second argument to the base class
        :param normalize: truthy to use batch norm, falsy for no normalization
        """
        self.embedding_size = 512
        super(RMC, self).__init__(self.embedding_size, output_shape_dict)
        bias = not normalize
        self.conv1 = Conv2d(
            nb_in_chan, 32, kernel_size=3, stride=2, padding=1, bias=bias
        )
        self.conv2 = Conv2d(
            32, 32, kernel_size=3, stride=2, padding=1, bias=bias
        )
        self.conv3 = Conv2d(
            32, 32, kernel_size=3, stride=2, padding=1, bias=bias
        )
        self.attention = RMCCell(100, 100, 34)
        # 34 input channels = 32 conv channels + 2 coordinate channels.
        self.conv4 = Conv2d(
            34, 8, kernel_size=3, stride=1, padding=1, bias=bias
        )
        # BATCH x 8 x 10 x 10
        self.linear = Linear(800, 512, bias=bias)

        if normalize:
            self.bn1 = BatchNorm2d(32)
            self.bn2 = BatchNorm2d(32)
            self.bn3 = BatchNorm2d(32)
            self.bn4 = BatchNorm2d(8)
            self.bn_linear = BatchNorm1d(512)
        else:
            self.bn1 = Identity()
            self.bn2 = Identity()
            self.bn3 = Identity()
            # Was missing: forward() and _forward() both call self.bn4, so
            # the non-normalized configuration failed with AttributeError.
            self.bn4 = Identity()
            self.bn_linear = Identity()

    def forward(self, input, prev_memories):
        """
        :param input: Tensor{B, C, H, W}
        :param prev_memories: Tuple{B}[Tensor{C}]
        :return: ``(features, next_memories)`` where ``next_memories`` is a
            list obtained by unbinding the attention output along dim 0
        """

        x = F.relu(self.bn1(self.conv1(input)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        h = x.size(2)
        w = x.size(3)
        # Coordinate channels in [-1, 1].  Both must expand to (B, 1, h, w)
        # to be concatenable with x; the previous (w, w) / (h, h) targets
        # were only correct for square feature maps.
        xs_chan = (
            torch.linspace(-1, 1, w)
            .view(1, 1, 1, w)
            .expand(input.size(0), 1, h, w)
            .to(input.device)
        )
        ys_chan = (
            torch.linspace(-1, 1, h)
            .view(1, 1, h, 1)
            .expand(input.size(0), 1, h, w)
            .to(input.device)
        )
        x = torch.cat([x, xs_chan, ys_chan], dim=1)

        # need to transpose because attention expects
        # attention dim before channel dim
        x = x.view(x.size(0), x.size(1), h * w).transpose(1, 2)
        prev_memories = torch.stack(prev_memories)
        x = next_memories = self.attention(x.contiguous(), prev_memories)
        # need to undo the transpose before output
        x = x.transpose(1, 2)
        x = x.view(x.size(0), x.size(1), h, w)

        x = F.relu(self.bn4(self.conv4(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.bn_linear(self.linear(x)))
        return x, list(torch.unbind(next_memories, 0))

    @classmethod
    def from_args(cls, args, in_shape, id):
        """Build an instance from ``args`` (see the class-level review note)."""
        return cls(in_shape, id, args.fourconv_norm)

    @property
    def _output_shape(self):
        # NOTE(review): the original comment claimed "(32, 5, 5) for 84x84",
        # but the strides used below (2, 2, 2, 1) give 11 for a size of 84.
        if self._out_shape is None:
            output_dim = calc_output_dim(self._in_shape[1], 3, 2, 1, 1)
            output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
            output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
            output_dim = calc_output_dim(output_dim, 3, 1, 1, 1)
            self._out_shape = 32, output_dim, output_dim
        return self._out_shape

    def _forward(self, xs, internals, **kwargs):
        """Run the four convolutions (see the class-level review note)."""
        xs = F.relu(self.bn1(self.conv1(xs)))
        xs = F.relu(self.bn2(self.conv2(xs)))
        xs = F.relu(self.bn3(self.conv3(xs)))
        xs = F.relu(self.bn4(self.conv4(xs)))
        return xs, {}

    def _new_internals(self):
        """Return an empty dict (no internal state is created here)."""
        return {}


def calc_output_dim(dim_size, kernel_size, stride, padding, dilation):
    """Return the output size of a convolution along one axis."""
    numerator = dim_size + 2 * padding - dilation * (kernel_size - 1) - 1
    return numerator // stride + 1

if __name__ == "__main__":
    output_dim = 84
    output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
    output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
    output_dim = calc_output_dim(output_dim, 3, 2, 1, 1)
    output_dim = calc_output_dim(output_dim, 3, 1, 1, 1)
    print(output_dim)  # prints 11 for 84 (84 -> 42 -> 21 -> 11 -> 11)

# --------------------------------------------------------------------------------
# /mcs/network/net3d/submodule_3d.py:
# --------------------------------------------------------------------------------
# Copyright (C) 2018 Heron Systems, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
from ..base.submodule import SubModule
import abc


class SubModule3D(SubModule, metaclass=abc.ABCMeta):
    """Base for submodules whose ``_output_shape`` unpacks into three values."""

    dim = 3

    def __init__(self, input_shape, id):
        super(SubModule3D, self).__init__(input_shape, id)

    def _to_1d_shape(self):
        """Collapse all three entries into a single-element tuple."""
        f, h, w = self._output_shape
        return (f * h * w,)

    def _to_2d_shape(self):
        """Keep the first entry and collapse the last two."""
        f, h, w = self._output_shape
        return (f, h * w)

    def _to_3d_shape(self):
        """Return ``_output_shape`` as-is."""
        return self._output_shape

    def _to_4d_shape(self):
        """Insert a singleton dimension after the first entry."""
        f, h, w = self._output_shape
        return (f, 1, h, w)
# --------------------------------------------------------------------------------
# /mcs/network/net4d/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__init__.py
# --------------------------------------------------------------------------------
# /mcs/network/net4d/__pycache__/__init__.cpython-37.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net4d/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net4d/__pycache__/identity_4d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/identity_4d.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net4d/__pycache__/identity_4d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/identity_4d.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/network/net4d/__pycache__/submodule_4d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/submodule_4d.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/network/net4d/__pycache__/submodule_4d.cpython-38.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/network/net4d/__pycache__/submodule_4d.cpython-38.pyc
# --------------------------------------------------------------------------------
# /mcs/network/net4d/identity_4d.py:
# --------------------------------------------------------------------------------
# Copyright (C) 2018 Heron Systems, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
from .submodule_4d import SubModule4D


class Identity4D(SubModule4D):
    """Pass-through 4D submodule: hands its input back unchanged."""

    args = {}

    def __init__(self, input_shape, id):
        super(Identity4D, self).__init__(input_shape, id)

    @classmethod
    def from_args(cls, args, input_shape, id):
        """Build an instance; ``args`` is accepted but not used."""
        module = cls(input_shape, id)
        return module

    @property
    def _output_shape(self):
        """Output shape, reported as ``self.input_shape``."""
        shape = self.input_shape
        return shape

    def _forward(self, input, internals, **kwargs):
        """Return ``input`` untouched together with an empty internals dict."""
        return input, dict()

    def _new_internals(self):
        """Return an empty dict (no internal state is created here)."""
        return dict()
# --------------------------------------------------------------------------------
# /mcs/network/net4d/submodule_4d.py:
# --------------------------------------------------------------------------------
# Copyright (C) 2018 Heron Systems, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
import abc

from ..base.submodule import SubModule


class SubModule4D(SubModule, metaclass=abc.ABCMeta):
    """Base for submodules whose ``_output_shape`` unpacks into four values.

    Offers conversions of that 4-tuple into 1-, 2-, 3- and 4-element shapes.
    """

    dim = 4

    def __init__(self, input_shape, id):
        super().__init__(input_shape, id)

    def _to_1d_shape(self):
        """Collapse all four entries into a single-element tuple."""
        feat, depth, height, width = self._output_shape
        flat = feat * depth * height * width
        return (flat,)

    def _to_2d_shape(self):
        """Keep the first entry and collapse the remaining three."""
        feat, depth, height, width = self._output_shape
        rest = depth * height * width
        return (feat, rest)

    def _to_3d_shape(self):
        """Merge the first two entries and keep the last two."""
        feat, depth, height, width = self._output_shape
        merged = feat * depth
        return (merged, height, width)

    def _to_4d_shape(self):
        """Return ``_output_shape`` as-is."""
        shape = self._output_shape
        return shape
# --------------------------------------------------------------------------------
# /mcs/preprocess/__init__.py:
# --------------------------------------------------------------------------------
"""
Copyright (C) 2018 Heron Systems, Inc.

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see .
"""

# --------------------------------------------------------------------------------
# /mcs/preprocess/__pycache__/__init__.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
# /mcs/preprocess/__pycache__/__init__.cpython-38.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/__init__.cpython-38.pyc
# --------------------------------------------------------------------------------
# /mcs/preprocess/__pycache__/observation.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/observation.cpython-37.pyc
# --------------------------------------------------------------------------------
# /mcs/preprocess/__pycache__/observation.cpython-38.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/observation.cpython-38.pyc
# --------------------------------------------------------------------------------
# /mcs/preprocess/__pycache__/ops.cpython-37.pyc:
# --------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/ops.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/preprocess/__pycache__/ops.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/preprocess/__pycache__/ops.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/preprocess/observation.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 Heron Systems, Inc. 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 
from copy import deepcopy


class ObsPreprocessor:
    """
    Applies a sequence of preprocessing ops to observations and tracks how
    those ops change the observation space (and, optionally, the dtypes).
    """

    def __init__(self, ops, observation_space, observation_dtypes=None):
        """
        :param ops: List[Operation], applied in order. Each op is read for
            name_filters / rank_filters and called through update_shape,
            update_dtype, update_obs and reset.
        :param observation_space: Dict[ObsKey, Shape]
        :param observation_dtypes: Dict[ObsKey, dtype_str]
        """
        # Work on copies: _update() removes keys from the dict it is given.
        cur_space = deepcopy(observation_space)
        cur_dtypes = deepcopy(observation_dtypes)

        rank_to_names = self._bld_rank_to_names(cur_space)

        for op in ops:
            # Select the observation keys this op applies to: explicit names
            # win over rank filters; no filter at all means "every key".
            if op.name_filters:
                names = op.name_filters
            elif op.rank_filters:
                names = []
                for rank in op.rank_filters:
                    names += rank_to_names[rank]
            else:
                names = list(cur_space.keys())

            cur_space = self._update(names, cur_space, op.update_shape)
            if observation_dtypes:
                cur_dtypes = self._update(names, cur_dtypes, op.update_dtype)
            # Rebuild from the *updated* space so that rank filters of later
            # ops see the shapes produced by earlier ops.
            rank_to_names = self._bld_rank_to_names(cur_space)

        self.ops = ops
        self.observation_space = cur_space
        self.observation_dtypes = cur_dtypes
        self.rank_to_names = rank_to_names

    def __call__(self, obs):
        """
        Run an observation through every op, in order.

        :param obs: observation accepted by the first op's update_obs
        :return: whatever the last op's update_obs returns
        """
        for op in self.ops:
            obs = op.update_obs(obs)
        return obs

    def reset(self):
        """Call reset() on every op."""
        for o in self.ops:
            o.reset()

    def _bld_rank_to_names(self, obs_space):
        """
        Group observation keys by the rank (len of shape) of their shape.

        :param obs_space: Dict[ObsKey, Shape]
        :return: Dict[int, List[ObsKey]]; ranks 1-4 are always present.
        """
        d = {1: [], 2: [], 3: [], 4: []}
        for name, shape in obs_space.items():
            # setdefault: a rank outside 1-4 gets its own bucket instead of
            # raising KeyError.
            d.setdefault(len(shape), []).append(name)
        return d

    def _update(self, names, prev, fn):
        """
        Apply fn to the entries of prev selected by names.

        The selected entries are removed from prev (prev is mutated), passed
        to fn as a dict, and the dict returned by fn is merged back on top of
        the remaining entries.

        :param names: Iterable[ObsKey], keys to hand to fn
        :param prev: Dict[ObsKey, Any]
        :param fn: Callable[[Dict], Dict]
        :return: Dict[ObsKey, Any], a new dict
        :raises KeyError: if a name is not in prev (or is listed twice)
        """
        cur = {}
        for name in names:
            cur[name] = prev.pop(name)
        update = fn(cur)
        return {**prev, **update}
/mcs/registry/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/registry/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/registry/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/registry/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/registry/__pycache__/registry.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/registry/__pycache__/registry.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/registry/__pycache__/registry.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/registry/__pycache__/registry.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/rewardnorm/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import RewardNormModule 2 | from .normalizers import Scale, Clip, Identity 3 | 4 | REWARD_NORM_REG = [Scale, Clip, Identity] 5 | -------------------------------------------------------------------------------- /mcs/rewardnorm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/rewardnorm/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/rewardnorm/__pycache__/normalizers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/__pycache__/normalizers.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/rewardnorm/__pycache__/normalizers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/__pycache__/normalizers.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/rewardnorm/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .rewnorm_module import RewardNormModule 2 | -------------------------------------------------------------------------------- /mcs/rewardnorm/base/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/rewardnorm/base/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- 
class RewardNormModule(RequiresArgsMixin, metaclass=abc.ABCMeta):
    """
    Base class for reward normalizers.

    A normalizer is a callable object; concrete subclasses override
    __call__ to transform a reward tensor.
    """

    def __call__(self, reward):
        """
        Normalizes a reward tensor. The base implementation always raises
        and must be overridden.

        :param reward: torch.Tensor (1D)
        :return: defined by the overriding subclass
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()
class Clip(RewardNormModule):
    """Clamps rewards with torch.clamp, using floor and ceil as bounds."""

    # default arguments consumed through from_args()
    args = {"floor": -1, "ceil": 1}

    def __init__(self, floor, ceil):
        """
        :param floor: lower bound passed to torch.clamp
        :param ceil: upper bound passed to torch.clamp
        """
        self.floor, self.ceil = floor, ceil

    @classmethod
    def from_args(cls, args):
        """Build a Clip from an object exposing .floor and .ceil."""
        lower = args.floor
        upper = args.ceil
        return cls(lower, upper)

    def __call__(self, reward):
        """
        :param reward: tensor handed to torch.clamp
        :return: result of torch.clamp(reward, floor, ceil)
        """
        lower, upper = self.floor, self.ceil
        return torch.clamp(reward, lower, upper)


class Scale(RewardNormModule):
    """Multiplies rewards by a constant coefficient."""

    # default arguments consumed through from_args()
    args = {"coefficient": 0.1}

    def __init__(self, coefficient):
        """
        :param coefficient: factor applied to every reward
        """
        self.coefficient = coefficient

    @classmethod
    def from_args(cls, args):
        """Build a Scale from an object exposing .coefficient."""
        factor = args.coefficient
        return cls(factor)

    def __call__(self, reward):
        """
        :param reward: value to scale
        :return: coefficient * reward
        """
        factor = self.coefficient
        return factor * reward


class Identity(RewardNormModule):
    """Normalizer that returns its input unchanged."""

    # no configurable arguments
    args = {}

    @classmethod
    def from_args(cls, args):
        """Build an Identity; args is accepted but not read."""
        return cls()

    def __call__(self, item):
        """
        :param item: any value
        :return: item, unchanged
        """
        return item
2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 15 | from .util import listd_to_dlist, dlist_to_listd, dtensor_to_dev 16 | -------------------------------------------------------------------------------- /mcs/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/utils/__pycache__/logging.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/logging.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/utils/__pycache__/logging.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/logging.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/utils/__pycache__/requires_args.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/requires_args.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/utils/__pycache__/requires_args.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/requires_args.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/utils/__pycache__/script_helpers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/script_helpers.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/utils/__pycache__/script_helpers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/script_helpers.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/utils/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /mcs/utils/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-MCS/DRL-DisasterVC/6291eea53e28b32995e098ba9d15446fcd6ee261/mcs/utils/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /mcs/utils/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2018 Heron Systems, Inc. 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 
class ModelSaver:
    """
    Buffers (reward, network state, optimizer state) entries and writes the
    buffered entries to disk on request.

    Entries are held in a HeapQueue created with nb_top_model as its size
    argument.
    """

    BufferEntry = namedtuple(
        "BufferEntry", ["reward", "priority", "network", "optimizer"]
    )

    def __init__(self, nb_top_model, log_id_dir):
        """
        :param nb_top_model: int, size argument handed to HeapQueue
        :param log_id_dir: str, directory under which epoch folders are made
        """
        self.nb_top_model = nb_top_model
        self._buffer = HeapQueue(nb_top_model)
        self._log_id_dir = log_id_dir

    def append_if_better(self, reward, network, optimizer):
        """
        Push the current network/optimizer state into the buffer.

        :param reward: value stored first in the entry, so it is the primary
            sort key of the namedtuple
        :param network: object exposing state_dict()
        :param optimizer: object exposing state_dict()
        """
        # time() is the second tuple field, so it breaks ties between equal
        # rewards before the state dicts would have to be compared.
        # NOTE(review): the state dicts are stored exactly as returned, with
        # no copy. If state_dict() aliases live parameters, buffered entries
        # will follow later training updates -- confirm whether a snapshot
        # is needed here.
        self._buffer.push(
            self.BufferEntry(
                reward, time(), network.state_dict(), optimizer.state_dict()
            )
        )

    def write_state_dicts(self, epoch_id):
        """
        Flush the buffer and save every entry under <log_id_dir>/<epoch_id>.

        Files are named model_<j>_<reward>.pth and optimizer_<j>_<reward>.pth
        where j is the 1-based position in the flushed sequence and reward is
        truncated to int.

        :param epoch_id: used (via str()) as the sub-directory name
        """
        save_dir = os.path.join(self._log_id_dir, str(epoch_id))
        if len(self._buffer) > 0:
            # exist_ok: writing twice for the same epoch must not crash.
            os.makedirs(save_dir, exist_ok=True)
        for j, buff_entry in enumerate(self._buffer.flush()):
            torch.save(
                buff_entry.network,
                os.path.join(
                    save_dir,
                    "model_{}_{}.pth".format(j + 1, int(buff_entry.reward)),
                ),
            )
            torch.save(
                buff_entry.optimizer,
                os.path.join(
                    save_dir,
                    "optimizer_{}_{}.pth".format(j + 1, int(buff_entry.reward)),
                ),
            )


class SimpleModelSaver:
    """Saves network (and optionally optimizer) state dicts per step count."""

    def __init__(self, log_id_dir):
        """
        :param log_id_dir: str, directory under which step folders are made
        """
        self._log_id_dir = log_id_dir

    def save_state_dicts(self, network, step_count, optimizer=None):
        """
        Save state dicts to <log_id_dir>/<step_count>/.

        :param network: object exposing state_dict(); written to
            model_<step_count>.pth
        :param step_count: used (via str()) as the sub-directory name and in
            the file names
        :param optimizer: optional object exposing state_dict(); written to
            optimizer_<step_count>.pth when given
        """
        save_dir = os.path.join(self._log_id_dir, str(step_count))
        # exist_ok: saving again for the same step overwrites the files
        # instead of raising FileExistsError.
        os.makedirs(save_dir, exist_ok=True)
        torch.save(
            network.state_dict(),
            os.path.join(save_dir, "model_{}.pth".format(step_count)),
        )
        if optimizer is not None:
            torch.save(
                optimizer.state_dict(),
                os.path.join(save_dir, "optimizer_{}.pth".format(step_count)),
            )
class RequiresArgsMixin(metaclass=abc.ABCMeta):
    """
    This mixin makes it so that subclasses must implement an args class
    attribute. These arguments are parsed at runtime and the user is offered a
    chance to change any desired args. Classes that use this mixin must
    implement the from_args() class method. from_args() is essentially a
    secondary constructor.
    """

    args = None  # Dict[str, Any]

    @classmethod
    def check_args_implemented(cls):
        """
        :raises NotImplementedError: if the class attribute args is None
        """
        if cls.args is None:
            raise NotImplementedError(
                'Subclass must define class attribute "args"'
            )

    @classmethod
    def prompt(cls, provided=None):
        """
        Display defaults as JSON, prompt user for changes.

        :param provided: Dict[str, Any], Override default prompts. Keys that
            are not in cls.args are ignored.
        :return: Dict[str, Any] Updated config dictionary.
        """
        if provided is not None:
            overrides = {k: v for k, v in provided.items() if k in cls.args}
            args = {**cls.args, **overrides}
        else:
            args = cls.args
        return cls._prompt(cls.__name__, args)

    @staticmethod
    def _prompt(name, args):
        """
        Display defaults as JSON, prompt user for changes.

        :param name: str Name of class
        :param args: Dict[str, Any]
        :return: Dict[str, Any] Updated config dictionary. Always a new dict,
            so the caller can modify it without touching the class defaults
            (None is passed through unchanged).
        :raises json.JSONDecodeError: if the user input is not valid JSON
        """
        if not args:
            # Nothing to ask about; still return a copy of an empty dict.
            return args if args is None else dict(args)

        user_input = input(
            "\n{} Defaults:\n{}\n"
            "Press ENTER to use defaults. Otherwise, "
            "modify JSON keys then press ENTER.\n".format(
                name, json.dumps(args, indent=2, sort_keys=True)
            )
            # The input is parsed with json.loads, so the example must be
            # valid JSON (lowercase true, not Python's True).
            + 'Example: {"x": true, "gamma": 0.001}\n'
        )

        # use defaults if no changes specified (whitespace counts as none)
        if user_input.strip() == "":
            return dict(args)

        updates = json.loads(user_input)
        return {**args, **updates}

    @classmethod
    @abc.abstractmethod
    def from_args(cls, *argss, **kwargs):
        """Secondary constructor; must be implemented by subclasses."""
        raise NotImplementedError
def listd_to_dlist(list_of_dicts):
    """
    Converts a list of dictionaries to a dictionary of lists. Preserves key
    order (keys appear in the order they are first seen).

    K is type of key.
    V is type of value.
    :param list_of_dicts: List[Dict[K, V]]
    :return: Dict[K, List[V]]
    """
    new_dict = OrderedDict()
    for d in list_of_dicts:
        for k, v in d.items():
            new_dict.setdefault(k, []).append(v)
    return new_dict


def dlist_to_listd(dict_of_lists):
    """
    Converts a dictionary of lists to a list of dictionaries. Preserves key
    order.

    The length of the result is taken from the list stored under the first
    key; an empty dictionary yields an empty list.

    K is type of key.
    V is type of value.
    :param dict_of_lists: Dict[K, List[V]]
    :return: List[Dict[K, V]]
    :raises IndexError: if another list is shorter than the first one
    """
    if not dict_of_lists:
        # No keys means no rows; the original code raised StopIteration here.
        return []

    keys = list(dict_of_lists.keys())
    list_len = len(dict_of_lists[keys[0]])
    new_list = []
    for i in range(list_len):
        new_list.append(OrderedDict((k, dict_of_lists[k][i]) for k in keys))
    return new_list
def dtensor_to_dev(d_tensor, device):
    """
    Move a dictionary of tensors to a device.

    :param d_tensor: Dict[str, Tensor]
    :param device: torch.device
    :return: Dict[str, Tensor] on desired device.
    """
    return {k: v.to(device) for k, v in d_tensor.items()}


def json_to_dict(file_path):
    """
    Read JSON config.

    :param file_path: path of the JSON file to read
    :return: the object produced by json.load
    """
    # Context manager so the file handle is closed even if parsing fails.
    with open(file_path, "r") as f:
        return json.load(f)


# Supported numpy scalar types and their torch counterparts.
_numpy_to_torch_dtype = {
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
}
# Reverse lookup, derived from the table above.
_torch_to_numpy_dtype = {v: k for k, v in _numpy_to_torch_dtype.items()}


def numpy_to_torch_dtype(dtype):
    """
    Convert a numpy dtype to the matching torch dtype.

    :param dtype: numpy scalar type (e.g. np.float32) or np.dtype instance
    :return: torch.dtype
    :raises ValueError: if the dtype has no entry in the conversion table
    """
    # np.dtype instances are unwrapped to their scalar type, which is what
    # the table is keyed on. isinstance (rather than an exact type check)
    # also accepts instances of np.dtype subclasses.
    if isinstance(dtype, np.dtype):
        dtype = dtype.type

    if dtype not in _numpy_to_torch_dtype:
        raise ValueError(
            "Could not convert numpy dtype {} to a torch dtype.".format(
                dtype
            )
        )

    return _numpy_to_torch_dtype[dtype]


def torch_to_numpy_dtype(dtype):
    """
    Convert a torch dtype to the matching numpy scalar type.

    :param dtype: torch.dtype
    :return: numpy scalar type (e.g. np.float32)
    :raises ValueError: if the dtype has no entry in the conversion table
    """
    if dtype not in _torch_to_numpy_dtype:
        raise ValueError(
            "Could not convert torch dtype {} to a numpy dtype.".format(
                dtype
            )
        )

    return _torch_to_numpy_dtype[dtype]
class CircularBuffer(object):
    """
    Fixed-capacity buffer that overwrites its oldest slot once full.

    Until size values have been appended the buffer grows like a list;
    afterwards each append overwrites the slot at the write index.
    """

    def __init__(self, size):
        """
        :param size: int, maximum number of stored values
        """
        self.index = 0
        self.size = size
        self._data = []

    def append(self, value):
        """Store value, overwriting the slot at the write index when full."""
        if len(self._data) == self.size:
            self._data[self.index] = value
        else:
            self._data.append(value)
        # advance the write index, wrapping around at capacity
        self.index = (self.index + 1) % self.size

    def is_empty(self):
        """Return True when nothing has been stored yet."""
        return self._data == []

    def not_empty(self):
        """Return True when at least one value is stored."""
        return not self.is_empty()

    def is_full(self):
        """Return True when size values are stored."""
        return len(self) == self.size

    def not_full(self):
        """Return True while fewer than size values are stored."""
        return not self.is_full()

    def __getitem__(self, key):
        """get element by index like a regular array"""
        return self._data[key]

    def __setitem__(self, key, value):
        """set element by index like a regular array"""
        self._data[key] = value

    def __repr__(self):
        """return string representation"""
        return self._data.__repr__() + " (" + str(len(self._data)) + " items)"

    def __len__(self):
        """Return the number of stored values."""
        return len(self._data)


class HeapQueue:
    """
    Bounded min-heap: holds at most maxlen items and, once full, each push
    discards the smallest of the stored items and the new one.
    """

    def __init__(self, maxlen):
        """
        :param maxlen: int, maximum number of items kept
        """
        self.q = []
        self.maxlen = maxlen

    def push(self, item):
        """Add item; when full, drop the smallest item in the same step."""
        if len(self.q) < self.maxlen:
            heapq.heappush(self.q, item)
        else:
            heapq.heappushpop(self.q, item)

    def flush(self):
        """
        Empty the queue.

        :return: list of the stored items in heap order (not fully sorted)
        """
        q = self.q
        self.q = []
        return q

    def __len__(self):
        """Return the number of stored items."""
        return len(self.q)


class DotDict(dict):
    """
    Dictionary to access attributes

    Attribute reads that miss go through dict.get, so an unknown attribute
    yields None instead of raising AttributeError.
    """

    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    # Support pickling
    def __getstate__(self):
        """Return the items as a plain dict."""
        return dict(self.items())

    def __setstate__(self, attributes):
        """
        Restore items onto this instance.

        The previous version built a new DotDict and returned it, leaving
        this instance untouched.
        """
        self.update(attributes)
colorama==0.4.4 36 | colorcet==3.0.0 37 | cryptography==36.0.2 38 | cvxpy==1.2.0 39 | cycler==0.11.0 40 | debugpy==1.5.1 41 | decorator==5.1.0 42 | deepdiff==5.8.0 43 | defusedxml==0.6.0 44 | deprecated==1.2.13 45 | docker-pycreds==0.4.0 46 | docopt==0.6.2 47 | drain3==0.9.10 48 | ecos==2.0.10 49 | entrypoints==0.3 50 | enum34==1.1.10 51 | executing==0.8.3 52 | filelock==3.4.0 53 | fiona==1.8.21 54 | flatbuffers==2.0 55 | folium==0.12.1.post1 56 | fonttools==4.29.1 57 | frozenlist==1.2.0 58 | future==0.18.2 59 | gast==0.4.0 60 | gitdb==4.0.9 61 | gitpython==3.1.27 62 | glfw==1.11.0 63 | google==3.0.0 64 | google-api-core==2.3.0 65 | google-auth==2.3.3 66 | google-auth-oauthlib==0.4.6 67 | google-pasta==0.2.0 68 | googleapis-common-protos==1.54.0 69 | gpustat==0.6.0 70 | grpcio==1.43.0rc1 71 | gym==0.23.1 72 | gym-notices==0.0.6 73 | gym-super-mario-bros==7.3.2 74 | h11==0.13.0 75 | h5py==3.6.0 76 | hiredis==2.0.0 77 | holoviews==1.14.9a1 78 | hvplot==0.8.0a10 79 | idna==3.3 80 | importlib-metadata==4.11.3 81 | importlib-resources==5.4.0 82 | ipykernel==6.6.0 83 | ipython==7.30.1 84 | ipython-genutils==0.2.0 85 | jedi==0.18.1 86 | jinja2==3.0.3 87 | joblib==1.1.0 88 | json5==0.9.6 89 | jsonpickle==1.5.1 90 | jsonschema==4.2.1 91 | jupyter-client==7.1.0 92 | jupyter-core==4.9.1 93 | jupyter_client==7.1.2 94 | jupyter_core==4.9.2 95 | jupyterlab-pygments==0.1.2 96 | jupyterlab-widgets==1.0.2 97 | keras==2.7.0 98 | keras-preprocessing==1.1.2 99 | kiwisolver==1.4.0 100 | kmeans-pytorch==0.3 101 | ld_impl_linux-64==2.36.1 102 | libclang==12.0.0 103 | libffi==3.4.2 104 | libgcc-ng==11.2.0 105 | libgomp==11.2.0 106 | libnsl==2.0.0 107 | libsodium==1.0.18 108 | libstdcxx-ng==11.2.0 109 | libzlib==1.2.11 110 | markdown==3.3.6 111 | markupsafe==2.0.1 112 | matplotlib==3.5.1 113 | matplotlib-inline==0.1.3 114 | mc-bin-client==1.0.1 115 | mistune==0.8.4 116 | mock==4.0.3 117 | mpyq==0.2.5 118 | msgpack==1.0.3 119 | msgpack-numpy==0.4.7.1 120 | multidict==5.2.0 121 | 
munch==2.5.0 122 | nbclient==0.5.9 123 | nbconvert==6.3.0 124 | nbformat==5.1.3 125 | ncurses==6.3 126 | nes-py==8.1.8 127 | nest-asyncio==1.5.4 128 | networkx==2.7.1 129 | notebook==6.4.6 130 | numpy==1.19.5 131 | numpy-stl==2.16.3 132 | nvidia-ml-py3==7.352.0 133 | oauthlib==3.1.1 134 | opencensus==0.8.0 135 | opencensus-context==0.1.2 136 | opencv-python==4.5.4.60 137 | openssl==3.0.0 138 | opt-einsum==3.3.0 139 | ordered-set==4.1.0 140 | osqp==0.6.2.post5 141 | outcome==1.1.0 142 | packaging==21.3 143 | pandas==1.3.5 144 | pandocfilters==1.5.0 145 | panel==0.12.6 146 | param==1.12.0 147 | parso==0.8.3 148 | pathtools==0.1.2 149 | pexpect==4.8.0 150 | pickleshare==0.7.5 151 | pillow==9.0.1 152 | pip==22.0.4 153 | portpicker==1.5.0 154 | prometheus-client==0.12.0 155 | promise==2.3 156 | prompt-toolkit==3.0.24 157 | protobuf==3.19.1 158 | psutil==5.8.0 159 | ptyprocess==0.7.0 160 | pure_eval==0.2.2 161 | py-spy==0.3.11 162 | pyasn1==0.4.8 163 | pyasn1-modules==0.2.8 164 | pybullet==2.7.2 165 | pycodestyle==2.8.0 166 | pycparser==2.21 167 | pyct==0.4.8 168 | pygame==2.1.2 169 | pyglet==1.5.11 170 | pygments==2.10.0 171 | pyopenssl==22.0.0 172 | pyparsing==3.0.7 173 | pyrsistent==0.18.0 174 | pysc2==3.0.0 175 | pysocks==1.7.1 176 | python==3.8.12 177 | python-dateutil==2.8.2 178 | python-tsp==0.2.1 179 | python-utils==3.1.0 180 | python_abi==3.8 181 | pytz==2021.3 182 | pyviz-comms==2.1.0 183 | pyyaml==6.0 184 | pyzmq==22.3.0 185 | qdldl==0.1.5.post2 186 | qtpy==1.11.3 187 | ray==1.3.0 188 | readline==8.1 189 | redis==4.0.2 190 | requests==2.26.0 191 | requests-oauthlib==1.3.0 192 | rsa==4.8 193 | s2clientprotocol==5.0.9.87702.0 194 | s2protocol==5.0.9.87702.0 195 | scikit-learn==1.0.2 196 | scipy==1.8.0 197 | scs==3.2.0 198 | seaborn==0.11.2 199 | selenium==4.1.3 200 | send2trash==1.8.0 201 | sentry-sdk==1.5.9 202 | setproctitle==1.2.2 203 | setuptools==59.5.0 204 | shapely==1.8.1 205 | shortuuid==1.0.8 206 | six==1.12.0 207 | sk-video==1.1.10 208 | sklearn==0.0 
from setuptools import setup, find_packages
from mcs.globals import VERSION

# https://github.com/kennethreitz/setup.py/blob/master/setup.py


# The README is used as the long description of the package.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

# Optional dependency groups, installable as `mcs[<group>]`.
extras = {
    "profiler": ["pyinstrument>=2.0"],
}
test_deps = ["pytest"]

# "all" bundles every optional group plus the test dependencies. It is built
# before being added to extras so it does not include itself.
all_deps = []
for group_name in extras:
    all_deps += extras[group_name]
all_deps = all_deps + test_deps
extras["all"] = all_deps


setup(
    name="mcs",
    version=VERSION,
    python_requires=">=3.5.0",
    packages=find_packages(),
    # These were computed above but never handed to setup().
    long_description=long_description,
    long_description_content_type="text/markdown",
    extras_require=extras,
)
--------------------------------------------------------------------------------