├── .gitmodules ├── setup.py ├── docs └── resources │ └── viskill_teaser.png ├── viskill ├── agents │ ├── __init__.py │ ├── factory.py │ ├── sc_sac_sil.py │ ├── base.py │ ├── hier_agent.py │ ├── sc_awac.py │ ├── sl_ddpgbc.py │ ├── sl_dex.py │ ├── sc_ddpg.py │ └── sc_sac.py ├── configs │ ├── hier_agent │ │ └── hier_agent.yaml │ ├── sl_agent │ │ ├── sl_ddpgbc.yaml │ │ └── sl_dex.yaml │ ├── sc_agent │ │ ├── sc_sac.yaml │ │ ├── sc_awac.yaml │ │ └── sc_sac_sil.yaml │ ├── eval.yaml │ ├── skill_learning.yaml │ └── skill_chaining.yaml ├── modules │ ├── subnetworks.py │ ├── distributions.py │ ├── critics.py │ ├── policies.py │ ├── sampler.py │ └── replay_buffer.py ├── trainers │ ├── base_trainer.py │ ├── sl_trainer.py │ └── sc_trainer.py ├── components │ ├── normalizer.py │ ├── checkpointer.py │ ├── logger.py │ └── envrionment.py └── utils │ ├── vis_utils.py │ ├── mpi.py │ ├── rl_utils.py │ └── general_utils.py ├── requirements.txt ├── eval.py ├── train_sc.py ├── train_sl.py ├── LICENSE ├── .gitignore └── README.md /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "SurRoL"] 2 | path = SurRoL 3 | url = https://github.com/TaoHuang13/SurRoL.git 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup(name='viskill', version='0.0.1', packages=['viskill']) -------------------------------------------------------------------------------- /docs/resources/viskill_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/med-air/ViSkill/HEAD/docs/resources/viskill_teaser.png -------------------------------------------------------------------------------- /viskill/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .hier_agent import HierachicalAgent 2 | 3 | 4 | def make_hier_agent(env_params, samplers, cfg): 5 | return HierachicalAgent(env_params, samplers, cfg) 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core 2 | torch 3 | torchvision 4 | moviepy 5 | imageio 6 | 7 | # RL 8 | gym==0.15.6 9 | # mpi4py 10 | 11 | # Log 12 | colorlog 13 | termcolor 14 | wandb 15 | 16 | -------------------------------------------------------------------------------- /viskill/configs/hier_agent/hier_agent.yaml: -------------------------------------------------------------------------------- 1 | name: hier_agent 2 | checkpoint_dir: ${checkpoint_dir} 3 | 4 | # SC agent (high-level agent) 5 | sc_agent: ${sc_agent} 6 | update_sc_agent: True 7 | 8 | # SL agent (low-level agent) 9 | sl_agent: ${sl_agent} 10 | update_sl_agent: False -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from viskill.trainers.sc_trainer import SkillChainingTrainer 3 | 4 | 5 | @hydra.main(version_base=None, config_path="./viskill/configs", config_name="eval") 6 | def main(cfg): 7 | exp = SkillChainingTrainer(cfg) 8 | exp.eval_ckpt() 9 | 10 | if __name__ == "__main__": 11 | main() -------------------------------------------------------------------------------- /train_sc.py: 
-------------------------------------------------------------------------------- 1 | import hydra 2 | from viskill.trainers.sc_trainer import SkillChainingTrainer 3 | 4 | 5 | @hydra.main(version_base=None, config_path="./viskill/configs", config_name="skill_chaining") 6 | def main(cfg): 7 | exp = SkillChainingTrainer(cfg) 8 | exp.train() 9 | 10 | if __name__ == "__main__": 11 | main() -------------------------------------------------------------------------------- /train_sl.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from viskill.trainers.sl_trainer import SkillLearningTrainer 3 | 4 | 5 | @hydra.main(version_base=None, config_path="./viskill/configs", config_name="skill_learning") 6 | def main(cfg): 7 | exp = SkillLearningTrainer(cfg) 8 | exp.train() 9 | 10 | if __name__ == "__main__": 11 | main() -------------------------------------------------------------------------------- /viskill/configs/sl_agent/sl_ddpgbc.yaml: -------------------------------------------------------------------------------- 1 | name: SL_DDPGBC 2 | device: ${device} 3 | discount: 0.99 4 | reward_scale: 1 5 | n_seed_steps: 1000 6 | 7 | actor_lr: 1e-3 8 | critic_lr: 1e-3 9 | noise_eps: 0.1 10 | aux_weight: 5 11 | p_dist: 2 12 | soft_target_tau: 0.005 13 | clip_obs: 200 14 | norm_clip: 5 15 | norm_eps: 0.01 16 | hidden_dim: 256 17 | sampler: 18 | type: her 19 | strategy: future 20 | k: 4 21 | update_epoch: 40 -------------------------------------------------------------------------------- /viskill/configs/sl_agent/sl_dex.yaml: -------------------------------------------------------------------------------- 1 | name: SL_DEX 2 | device: ${device} 3 | discount: 0.99 4 | reward_scale: 1 5 | n_seed_steps: 200 6 | 7 | actor_lr: 1e-3 8 | critic_lr: 1e-3 9 | noise_eps: 0.2 10 | aux_weight: 5 11 | p_dist: 2 12 | k: 5 13 | soft_target_tau: 0.05 14 | clip_obs: 200 15 | norm_clip: 5 16 | norm_eps: 0.01 17 | hidden_dim: 256 18 | sampler: 19 | type: her_seq 20 | strategy: future 21 | k: 4 22 | update_epoch: 40 -------------------------------------------------------------------------------- /viskill/modules/subnetworks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class MLP(nn.Module): 5 | def __init__(self, in_dim, out_dim, hidden_dim=256): 6 | super().__init__() 7 | 8 | self.mlp = nn.Sequential( 9 | nn.Linear(in_dim, hidden_dim), nn.ReLU(), 10 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), 11 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), 12 | nn.Linear(hidden_dim, out_dim), 13 | ) 14 | 15 | def forward(self, input): 16 | return self.mlp(input) -------------------------------------------------------------------------------- /viskill/trainers/base_trainer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from pathlib import Path 3 | 4 | 5 | class BaseTrainer: 6 | def __init__(self, cfg) -> None: 7 | self.cfg = cfg 8 | self.work_dir = Path(cfg.cwd) 9 | self._setup() 10 | 11 | @abstractmethod 12 | def _setup(self): 13 | raise NotImplementedError 14 | 15 | @abstractmethod 16 | def train(self): 17 | '''Training agent''' 18 | raise NotImplementedError 19 | 20 | @abstractmethod 21 | def eval(self): 22 | '''Evaluating agent.''' 23 | raise NotImplementedError 24 | -------------------------------------------------------------------------------- /viskill/configs/sc_agent/sc_sac.yaml: 
-------------------------------------------------------------------------------- 1 | name: SC_SAC 2 | device: ${device} 3 | discount: 0.99 4 | reward_scale: 1 5 | 6 | actor_lr: 1e-4 7 | critic_lr: 1e-4 8 | temp_lr: 1e-4 9 | random_eps: 0.3 10 | noise_eps: 0.01 11 | aux_weight: 5 12 | decay: False 13 | action_l2: 1 14 | p_dist: 2 15 | soft_target_tau: 0.005 16 | clip_obs: 200 17 | norm_clip: 5 18 | norm_eps: 0.01 19 | hidden_dim: 256 20 | sampler: 21 | type: her_seq 22 | strategy: future 23 | k: 4 24 | update_epoch: ${update_epoch} 25 | 26 | normalize: False 27 | intr_reward: True 28 | raw_env_reward: True 29 | 30 | learnable_temperature: True 31 | init_temperature: 0.1 -------------------------------------------------------------------------------- /viskill/configs/eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - skill_chaining 3 | - _self_ 4 | 5 | sc_seed: 2 6 | sc_ckpt_dir: ./exp/viskill/${task}/${sc_agent.name}_${sl_agent.name}/d${num_demo}/rs${sc_agent.reward_scale}/s${sc_seed}/model 7 | sc_ckpt_episode: best 8 | 9 | # Working space 10 | hydra: 11 | run: 12 | dir: ./exp/eval/viskill/${task}/${sc_agent.name}_${sl_agent.name}/d${num_demo}/rs${sc_agent.reward_scale}/s${seed} 13 | sweep: 14 | dir: ./exp/eval/viskill/${task}/${sc_agent.name}_${sl_agent.name}/d${num_demo}/rs${sc_agent.reward_scale} 15 | subdir: s${seed} 16 | sweeper: 17 | params: 18 | num_demo: 200 19 | seed: 1,2,3,4,5 20 | -------------------------------------------------------------------------------- /viskill/configs/sc_agent/sc_awac.yaml: -------------------------------------------------------------------------------- 1 | name: SC_AWAC 2 | device: ${device} 3 | discount: 0.99 4 | reward_scale: 1 5 | 6 | actor_lr: 1e-4 7 | critic_lr: 1e-4 8 | temp_lr: 1e-4 9 | random_eps: 0.3 10 | noise_eps: 0.01 11 | aux_weight: 5 12 | decay: False 13 | action_l2: 1 14 | p_dist: 2 15 | soft_target_tau: 0.005 16 | clip_obs: 200 17 | norm_clip: 5 18 | norm_eps: 0.01 19 | hidden_dim: 256 20 | sampler: 21 | type: her_seq 22 | strategy: future 23 | k: 4 24 | update_epoch: ${update_epoch} 25 | 26 | normalize: False 27 | intr_reward: True 28 | raw_env_reward: True 29 | 30 | n_action_samples: 1 31 | lam: 1 32 | learnable_temperature: False 33 | init_temperature: 0.001 -------------------------------------------------------------------------------- /viskill/configs/sc_agent/sc_sac_sil.yaml: -------------------------------------------------------------------------------- 1 | name: SC_SAC_SIL 2 | device: ${device} 3 | discount: 0.99 4 | reward_scale: 1 5 | 6 | actor_lr: 1e-4 7 | critic_lr: 1e-4 8 | temp_lr: 1e-4 9 | random_eps: 0.3 10 | noise_eps: 0.01 11 | aux_weight: 5 12 | decay: False 13 | action_l2: 1 14 | p_dist: 2 15 | soft_target_tau: 0.005 16 | clip_obs: 200 17 | norm_clip: 5 18 | norm_eps: 0.01 19 | hidden_dim: 256 20 | sampler: 21 | type: her_seq 22 | strategy: future 23 | k: 4 24 | update_epoch: ${update_epoch} 25 | 26 | normalize: False 27 | raw_env_reward: True 28 | 29 | n_action_samples: 1 30 | lam: 1 31 | learnable_temperature: False 32 | init_temperature: 0 33 | policy_delay: 30 -------------------------------------------------------------------------------- /viskill/configs/skill_learning.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - sl_agent@agent: sl_dex 3 | - _self_ 4 | 5 | # File path 6 | cwd: ${hydra:runtime.output_dir} 7 | 8 | # Training params 9 | n_train_steps: 2_000_001 10 | n_eval: 400 11 | n_save: 
160 12 | n_log: 4000 13 | num_demo: 200 14 | eval_frequency: 2_000 15 | n_seed_steps: 200 16 | 17 | replay_buffer_capacity: 100_000 18 | checkpoint_frequency: 20_000 19 | update_epoch: 80 20 | batch_size: 128 21 | device: cuda:0 22 | seed: 1 23 | task: BiPegTransfer-v0 24 | subtask: grasp 25 | postfix: null 26 | skill_chaining: False 27 | dont_save: False 28 | n_eval_episodes: 8 29 | save_buffer: False 30 | 31 | use_wb: True 32 | project_name: viskill 33 | entity_name: thuang22 34 | 35 | mpi: {rank: null, is_chef: null, num_workers: null} 36 | # Working space 37 | hydra: 38 | run: 39 | dir: ./exp/skill_learning/${task}/${agent.name}/d${num_demo}/s${seed}/${subtask} 40 | sweep: 41 | dir: ./exp/skill_learning/${task}/${agent.name}/d${num_demo}/s${seed} 42 | subdir: ${subtask} 43 | sweeper: 44 | params: 45 | seed: 1,2,3,4,5 46 | subtask: grasp,handover,release 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Med-AIR@CUHK 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /viskill/agents/factory.py: -------------------------------------------------------------------------------- 1 | from .sc_awac import SkillChainingAWAC 2 | from .sc_ddpg import SkillChainingDDPG 3 | from .sc_sac import SkillChainingSAC 4 | from .sc_sac_sil import SkillChainingSACSIL 5 | from .sl_ddpgbc import SkillLearningDDPGBC 6 | from .sl_dex import SkillLearningDEX 7 | 8 | AGENTS = { 9 | 'SL_DDPGBC': SkillLearningDDPGBC, 10 | 'SL_DEX': SkillLearningDEX, 11 | 'SC_DDPG': SkillChainingDDPG, 12 | 'SC_AWAC': SkillChainingAWAC, 13 | 'SC_SAC': SkillChainingSAC, 14 | 'SC_SAC_SIL': SkillChainingSACSIL, 15 | } 16 | 17 | def make_sl_agent(env_params, sampler, cfg): 18 | if cfg.name not in AGENTS.keys(): 19 | assert 'Agent is not supported: %s' % cfg.name 20 | else: 21 | assert 'SL' in cfg.name 22 | return AGENTS[cfg.name]( 23 | env_params=env_params, 24 | sampler=sampler, 25 | agent_cfg=cfg 26 | ) 27 | 28 | 29 | def make_sc_agent(env_params, cfg, sl_agent): 30 | if cfg.name not in AGENTS.keys(): 31 | assert 'Agent is not supported: %s' % cfg.name 32 | else: 33 | assert 'SC' in cfg.name 34 | return AGENTS[cfg.name]( 35 | env_params=env_params, 36 | agent_cfg=cfg, 37 | sl_agent=sl_agent 38 | ) 39 | -------------------------------------------------------------------------------- /viskill/configs/skill_chaining.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hier_agent@agent: hier_agent 3 | - sc_agent: sc_sac_sil 4 | - sl_agent: sl_dex 5 | - _self_ 6 | 7 | # File path 8 | cwd: ${hydra:runtime.output_dir} 9 | 10 | # Training params 11 | task: BiPegTransfer-v0 12 | init_subtask: grasp 13 | subtask: ${init_subtask} 14 | postfix: null 15 | skill_chaining: True 16 | dont_save: False 17 | checkpoint_dir: ./exp/skill_learning/${task}/${sl_agent.name}/d${num_demo}/s${model_seed} 18 | 19 | num_demo: 200 20 | seed: 1 21 | model_seed: 1 22 | device: cuda:0 23 | update_epoch: 20 24 | replay_buffer_capacity: 100_000 25 | batch_size: 128 26 | 27 | n_train_steps: 1_000_001 28 | n_eval: 1600 29 | n_save: 800 30 | n_log: 9600 31 | n_eval_episodes: 10 32 | eval_frequency: 2_000 33 | n_seed_steps: 200 34 | use_demo_buffer: True 35 | ckpt_episode: latest 36 | 37 | use_wb: True 38 | project_name: viskill 39 | entity_name: thuang22 40 | 41 | mpi: {rank: null, is_chef: null, num_workers: null} 42 | # Working space 43 | hydra: 44 | run: 45 | dir: ./exp/viskill/${task}/${sc_agent.name}_${sl_agent.name}/d${num_demo}/rs${sc_agent.reward_scale}/s${seed} 46 | sweep: 47 | dir: ./exp/viskill/${task}/${sc_agent.name}_${sl_agent.name}/d${num_demo}/rs${sc_agent.reward_scale} 48 | subdir: s${seed} 49 | sweeper: 50 | params: 51 | num_demo: 200 52 | seed: 1,2,3,4,5 53 | -------------------------------------------------------------------------------- /viskill/components/normalizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Normalizer: 6 | def __init__(self, size, eps=1e-2, default_clip_range=np.inf): 7 | self.size = size 8 | self.eps = eps 9 | self.default_clip_range = default_clip_range 10 | # some local information 11 | self.total_sum = np.zeros(self.size, np.float32) 12 | self.total_sumsq = np.zeros(self.size, np.float32) 13 | self.total_count = np.zeros(1, np.float32) 14 | # get the mean and std 15 | self.mean = np.zeros(self.size, np.float32) 16 | self.std = np.ones(self.size, np.float32) 17 
| 18 | # update the parameters of the normalizer 19 | def update(self, v): 20 | v = v.reshape(-1, self.size) 21 | # do the computing 22 | self.total_sum += v.sum(axis=0) 23 | self.total_sumsq += (np.square(v)).sum(axis=0) 24 | self.total_count[0] += v.shape[0] 25 | 26 | def recompute_stats(self): 27 | # calculate the new mean and std 28 | self.mean = self.total_sum / self.total_count 29 | self.std = np.sqrt(np.maximum(np.square(self.eps), (self.total_sumsq / self.total_count) - np.square(self.total_sum / self.total_count))) 30 | 31 | # normalize the observation 32 | def normalize(self, v, clip_range=None, device=None): 33 | if clip_range is None: 34 | clip_range = self.default_clip_range 35 | if isinstance(v, np.ndarray): 36 | return np.clip((v - self.mean) / (self.std), -clip_range, clip_range) 37 | elif isinstance(v, torch.Tensor): 38 | return (v - torch.tensor(self.mean, dtype=torch.float32).to(device)) / (torch.tensor(self.std, dtype=torch.float32).to(device)).clamp(-clip_range, clip_range) -------------------------------------------------------------------------------- /viskill/modules/distributions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn.functional as F 4 | from torch import distributions as pyd 5 | 6 | 7 | class TanhTransform(pyd.transforms.Transform): 8 | domain = pyd.constraints.real 9 | codomain = pyd.constraints.interval(-1.0, 1.0) 10 | bijective = True 11 | sign = +1 12 | 13 | def __init__(self, cache_size=1): 14 | super().__init__(cache_size=cache_size) 15 | 16 | @staticmethod 17 | def atanh(x): 18 | return 0.5 * (x.log1p() - (-x).log1p()) 19 | 20 | def __eq__(self, other): 21 | return isinstance(other, TanhTransform) 22 | 23 | def _call(self, x): 24 | return x.tanh() 25 | 26 | def _inverse(self, y): 27 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. 28 | # one should use `cache_size=1` instead 29 | return self.atanh(y) 30 | 31 | def log_abs_det_jacobian(self, x, y): 32 | # We use a formula that is more numerically stable, see details in the following link 33 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 34 | return 2. * (math.log(2.) - x - F.softplus(-2. * x)) 35 | 36 | 37 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution): 38 | def __init__(self, loc, scale): 39 | self.loc = loc 40 | self.scale = scale 41 | 42 | self.base_dist = pyd.Normal(loc, scale) 43 | transforms = [TanhTransform()] 44 | super().__init__(self.base_dist, transforms) 45 | 46 | @property 47 | def mean(self): 48 | mu = self.loc 49 | for tr in self.transforms: 50 | mu = tr(mu) 51 | return mu 52 | -------------------------------------------------------------------------------- /viskill/utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def add_caption_to_img(img, info, name=None, flip_rgb=False): 6 | """ Adds caption to an image. info is dict with keys and text/array. 
7 | :arg name: if given this will be printed as heading in the first line 8 | :arg flip_rgb: set to True for inputs with BGR color channels 9 | """ 10 | offset = 12 11 | 12 | frame = img * 255.0 if img.max() <= 1.0 else img 13 | if flip_rgb: 14 | frame = frame[:, :, ::-1] 15 | 16 | # make frame larger if needed 17 | if frame.shape[0] < 300: 18 | frame = cv2.resize(frame, (400, 400), interpolation=cv2.INTER_CUBIC) 19 | 20 | fheight, fwidth = frame.shape[:2] 21 | frame = np.concatenate([frame, np.zeros((offset * (len(info.keys()) + 2), fwidth, 3))], 0) 22 | 23 | font_size = 0.4 24 | thickness = 1 25 | x, y = 5, fheight + 10 26 | if name is not None: 27 | cv2.putText(frame, '[{}]'.format(name), 28 | (x, y), cv2.FONT_HERSHEY_SIMPLEX, 29 | font_size, (100, 100, 0), thickness, cv2.LINE_AA) 30 | for i, k in enumerate(info.keys()): 31 | v = info[k] 32 | key_text = '{}: '.format(k) 33 | (key_width, _), _ = cv2.getTextSize(key_text, cv2.FONT_HERSHEY_SIMPLEX, 34 | font_size, thickness) 35 | 36 | cv2.putText(frame, key_text, 37 | (x, y + offset * (i + 2)), 38 | cv2.FONT_HERSHEY_SIMPLEX, 39 | font_size, (66, 133, 244), thickness, cv2.LINE_AA) 40 | 41 | cv2.putText(frame, str(v), 42 | (x + key_width, y + offset * (i + 2)), 43 | cv2.FONT_HERSHEY_SIMPLEX, 44 | font_size, (100, 100, 100), thickness, cv2.LINE_AA) 45 | 46 | if flip_rgb: 47 | frame = frame[:, :, ::-1] 48 | 49 | return frame 50 | 51 | def add_captions_to_seq(img_seq, info_seq, **kwargs): 52 | """Adds caption to sequence of image. info_seq is list of dicts with keys and text/array.""" 53 | return [add_caption_to_img(img, info, name='Timestep {:03d}'.format(i), **kwargs) for i, (img, info) in enumerate(zip(img_seq, info_seq))] 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /viskill/modules/critics.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..modules.subnetworks import MLP 7 | 8 | 9 | class Critic(nn.Module): 10 | def __init__(self, in_dim, hidden_dim): 11 | super().__init__() 12 | 13 | self.q = MLP( 14 | in_dim=in_dim, 15 | out_dim=1, 16 | hidden_dim=hidden_dim 17 | ) 18 | 19 | def forward(self, state, action): 20 | sa = torch.cat([state, action], dim=-1) 21 | q = self.q(sa) 22 | return q 23 | 24 | 25 | class DoubleCritic(nn.Module): 26 | def __init__(self, in_dim, hidden_dim): 27 | super().__init__() 28 | 29 | self.q1 = MLP( 30 | in_dim=in_dim, 31 | out_dim=1, 32 | hidden_dim=hidden_dim 33 | ) 34 | 35 | self.q2 = MLP( 36 | in_dim=in_dim, 37 | out_dim=1, 38 | hidden_dim=hidden_dim 39 | ) 40 | 41 | def forward(self, state, action): 42 | sa = torch.cat([state, action], dim=-1) 43 | q1 = self.q1(sa) 44 | q2 = self.q2(sa) 45 | return q1, q2 46 | 47 | def q(self, state, action): 48 | sa = torch.cat([state, action], dim=-1) 49 | q1 = self.q1(sa) 50 | q2 = self.q2(sa) 51 | return torch.min(q1, q2) 52 | 53 | 54 | class SkillChainingCritic(nn.Module): 55 | def __init__(self, in_dim, hidden_dim, middle_subtasks, last_subtask): 56 | super().__init__() 57 | 58 | self.qs = nn.ModuleDict({ 59 | subtask: Critic( 60 | in_dim=in_dim, 61 | hidden_dim=hidden_dim 62 | ) for subtask in middle_subtasks}) 63 | self.qs.update({last_subtask: None}) 64 | 65 | def __getitem__(self, key): 66 | return self.qs[key] 67 | 68 | def forward(self, state, action, subtask): 69 | q = self.qs[subtask](state, action) 70 | return q 71 | 72 | def init_last_subtask_q(self, last_subtask, critic): 73 | '''Initialize with pre-trained local q-function''' 74 | assert self.qs[last_subtask] is None 75 | self.qs[last_subtask] = copy.deepcopy(critic) 76 | 77 | 78 | class SkillChainingDoubleCritic(SkillChainingCritic): 79 | def __init__(self, in_dim, hidden_dim, middle_subtasks, last_subtask): 80 | super(SkillChainingCritic, self).__init__() 81 | 82 | self.qs = nn.ModuleDict({ 83 | subtask: DoubleCritic( 84 | in_dim=in_dim, 85 | hidden_dim=hidden_dim 86 | ) for subtask in middle_subtasks}) 87 | self.qs.update({last_subtask: None}) 88 | 89 | def forward(self, state, action, subtask): 90 | q1, q2 = self.qs[subtask](state, action) 91 | return q1, q2 92 | 93 | def q(self, state, action, subtask): 94 | return self.qs[subtask].q(state, action) 95 | -------------------------------------------------------------------------------- 
/viskill/components/checkpointer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pipes 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from ..components.logger import logger 10 | from ..utils.general_utils import get_last_argmax, str2int 11 | 12 | 13 | class CheckpointHandler: 14 | @staticmethod 15 | def get_ckpt_name(episode): 16 | return 'weights_ep{}.pth'.format(episode) 17 | 18 | @staticmethod 19 | def get_episode(path): 20 | checkpoint_names = glob.glob(os.path.abspath(path) + "/*.pth") 21 | if len(checkpoint_names) == 0: 22 | logger.error("No checkpoints found at {}!".format(path)) 23 | processed_names = [file.split('/')[-1].replace('weights_ep', '').replace('.pth', '') 24 | for file in checkpoint_names] 25 | episodes = list(filter(lambda x: x is not None, [str2int(name) for name in processed_names])) 26 | return episodes 27 | 28 | @staticmethod 29 | def get_resume_ckpt_file(resume, path): 30 | episodes = CheckpointHandler.get_episode(path) 31 | file_paths = [os.path.join(path, CheckpointHandler.get_ckpt_name(episode)) for episode in episodes] 32 | scores = [torch.load(file_path)['score'] for file_path in file_paths] 33 | if resume == 'latest': 34 | max_episode = np.max(episodes) 35 | resume_file = CheckpointHandler.get_ckpt_name(max_episode) 36 | logger.info(f'Checkpoints with max episode {max_episode} with the success rate {scores[np.argmax(episodes)]}!') 37 | elif resume == 'best': 38 | max_episode = episodes[get_last_argmax(scores)] 39 | resume_file = CheckpointHandler.get_ckpt_name(max_episode) 40 | logger.info(f'Checkpoints with success rate {scores}, the highest success rate {max(scores)}!') 41 | return os.path.join(path, resume_file), max_episode 42 | 43 | @staticmethod 44 | def save_checkpoint(state, folder, filename='checkpoint.pth'): 45 | torch.save(state, os.path.join(folder, filename)) 46 | 47 | @staticmethod 48 | def load_checkpoint(checkpt_dir, agent, device, episode='best'): 49 | """Loads weigths from checkpoint.""" 50 | checkpt_path, max_episode = CheckpointHandler.get_resume_ckpt_file(episode, checkpt_dir) 51 | checkpt = torch.load(checkpt_path, map_location=device) 52 | 53 | logger.info(f'Loading pre-trained model from {checkpt_path}!') 54 | agent.load_state_dict(checkpt['state_dict']) 55 | if 'g_norm' in checkpt.keys() and 'o_norm' in checkpt.keys(): 56 | agent.g_norm = checkpt['g_norm'] 57 | agent.o_norm = checkpt['o_norm'] 58 | 59 | 60 | def save_cmd(base_dir): 61 | train_cmd = 'python ' + ' '.join([sys.argv[0]] + [pipes.quote(s) for s in sys.argv[1:]]) 62 | train_cmd += '\n\n' 63 | print('\n' + '*' * 80) 64 | print('Training command:\n' + train_cmd) 65 | print('*' * 80 + '\n') 66 | with open(os.path.join(base_dir, "cmd.txt"), "a") as f: 67 | f.write(train_cmd) 68 | -------------------------------------------------------------------------------- /viskill/utils/mpi.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from mpi4py import MPI 4 | 5 | from .general_utils import (AttrDict, joinListDict, joinListDictList, 6 | joinListList) 7 | 8 | 9 | def update_mpi_config(cfg): 10 | rank = MPI.COMM_WORLD.Get_rank() 11 | cfg.mpi.rank = rank 12 | cfg.mpi.is_chef = rank == 0 13 | cfg.mpi.num_workers = MPI.COMM_WORLD.Get_size() 14 | 15 | # update conf 16 | cfg.seed = cfg.seed + rank 17 | 18 | 19 | def mpi_sum(x): 20 | buf = np.zeros_like(np.array(x)) 21 | MPI.COMM_WORLD.Allreduce(np.array(x), buf, 
op=MPI.SUM) 22 | return buf 23 | 24 | 25 | def mpi_gather_experience_episode(experience_episode): 26 | buf = MPI.COMM_WORLD.allgather(experience_episode) 27 | return joinListDictList(buf) 28 | 29 | 30 | def mpi_gather_experience_rollots(experience_rollouts): 31 | buf = MPI.COMM_WORLD.allgather(experience_rollouts) 32 | return joinListDict(buf) 33 | 34 | 35 | def mpi_gather_experience_transitions(experience_transitions): 36 | buf = MPI.COMM_WORLD.allgather(experience_transitions) 37 | return joinListList(buf) 38 | 39 | 40 | def mpi_gather_experience_successful_transitions(experience_transitions): 41 | buf = MPI.COMM_WORLD.allgather(experience_transitions) 42 | jointLL = [] 43 | for i in experience_transitions: 44 | if i[0] is not None: 45 | jointLL.append(i) 46 | return jointLL 47 | 48 | 49 | def mpi_gather_experience(experience_episode): 50 | """Gathers data across workers, can handle hierarchical and flat experience dicts.""" 51 | return mpi_gather_experience_episode(experience_episode) 52 | 53 | 54 | def mpi_gather_rollouts(rollouts): 55 | """Gathers data across workers, can handle hierarchical and flat experience dicts.""" 56 | return mpi_gather_experience_rollots(rollouts) 57 | 58 | 59 | # sync_networks across the different cores 60 | def sync_networks(network): 61 | """ 62 | netowrk is the network you want to sync 63 | """ 64 | comm = MPI.COMM_WORLD 65 | flat_params, params_shape = _get_flat_params(network) 66 | comm.Bcast(flat_params, root=0) 67 | # set the flat params back to the network 68 | _set_flat_params(network, params_shape, flat_params) 69 | 70 | 71 | # get the flat params from the network 72 | def _get_flat_params(network): 73 | param_shape = {} 74 | flat_params = None 75 | for key_name, value in network.state_dict().items(): 76 | param_shape[key_name] = value.cpu().detach().numpy().shape 77 | if flat_params is None: 78 | flat_params = value.cpu().detach().numpy().flatten() 79 | else: 80 | flat_params = np.append(flat_params, value.cpu().detach().numpy().flatten()) 81 | return flat_params, param_shape 82 | 83 | 84 | # set the params from the network 85 | def _set_flat_params(network, params_shape, params): 86 | pointer = 0 87 | device = torch.device("cuda:0") 88 | 89 | for key_name, values in network.state_dict().items(): 90 | # get the length of the parameters 91 | len_param = int(np.prod(params_shape[key_name])) 92 | copy_params = params[pointer:pointer + len_param].reshape(params_shape[key_name]) 93 | copy_params = torch.tensor(copy_params).to(device) 94 | # copy the params 95 | values.data.copy_(copy_params.data) 96 | # update the pointer 97 | pointer += len_param 98 | -------------------------------------------------------------------------------- /viskill/agents/sc_sac_sil.py: -------------------------------------------------------------------------------- 1 | from ..utils.general_utils import AttrDict, prefix_dict 2 | from .sc_awac import SkillChainingAWAC 3 | 4 | 5 | class SkillChainingSACSIL(SkillChainingAWAC): 6 | def __init__( 7 | self, 8 | env_params, 9 | agent_cfg, 10 | sl_agent 11 | ): 12 | super().__init__(env_params, agent_cfg, sl_agent) 13 | self.policy_delay = agent_cfg.policy_delay 14 | 15 | def update_actor_and_alpha(self, obs, action, subtask, sil=False): 16 | if sil: 17 | #metrics = super(SkillChainingSACSIL, self).update_actor(obs, action, subtask) 18 | # compute log probability 19 | dist = self.actor(obs, subtask) 20 | log_probs = dist.log_prob(action).sum(-1, keepdim=True) 21 | # compute exponential weight 22 | weights = self._compute_weights(obs, 
action, subtask) 23 | actor_loss = -(log_probs * weights).sum() 24 | 25 | # optimize actor loss 26 | self.actor_optimizer[subtask].zero_grad() 27 | actor_loss.backward() 28 | self.actor_optimizer[subtask].step() 29 | 30 | metrics = AttrDict( 31 | log_probs=log_probs.mean(), 32 | actor_loss=actor_loss.item() 33 | ) 34 | metrics = prefix_dict(metrics, 'sil_') 35 | else: 36 | # compute log probability 37 | #metrics = super(SkillChainingAWAC, self).update_actor_and_alpha(obs, subtask) 38 | dist = self.actor(obs, subtask) 39 | action = dist.rsample() 40 | log_probs = dist.log_prob(action).sum(-1, keepdim=True) 41 | # compute state value 42 | actor_Q = self.critic.q(obs, action, subtask) 43 | actor_loss = (- actor_Q).mean() 44 | 45 | # optimize actor loss 46 | self.actor_optimizer[subtask].zero_grad() 47 | actor_loss.backward() 48 | self.actor_optimizer[subtask].step() 49 | 50 | metrics = AttrDict( 51 | log_probs=log_probs.mean(), 52 | actor_loss=actor_loss.item() 53 | ) 54 | 55 | return prefix_dict(metrics, subtask + '_') 56 | 57 | def update(self, replay_buffer, demo_buffer): 58 | metrics = AttrDict() 59 | 60 | for subtask in self.env_params['middle_subtasks']: 61 | for i in range(self.update_epoch): 62 | # sample from replay buffer 63 | obs, action, reward, next_obs, done, sl_norm_next_obs, raw_reward = self.get_samples(replay_buffer, subtask) 64 | action = (action / self.max_action) 65 | 66 | # update critic and actor 67 | metrics.update(self.update_critic(obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask)) 68 | if (i + 1) % self.policy_delay == 0: 69 | metrics.update(self.update_actor_and_alpha(obs, action, subtask, sil=False)) 70 | 71 | # sample from replay buffer 72 | obs, action, reward, next_obs, done, sl_norm_next_obs, raw_reward = self.get_samples(demo_buffer, subtask) 73 | action = (action / self.max_action) 74 | 75 | # update actor 76 | metrics.update(self.update_critic(obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask)) 77 | metrics.update(self.update_actor_and_alpha(obs, action, subtask, sil=True)) 78 | 79 | # update target critic and actor 80 | self.update_target() 81 | 82 | return metrics 83 | -------------------------------------------------------------------------------- /viskill/modules/policies.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .distributions import SquashedNormal 9 | from .subnetworks import MLP 10 | 11 | LOG_STD_BOUNDS = (-5, 2) 12 | 13 | class DeterministicActor(nn.Module): 14 | def __init__(self, in_dim, out_dim, hidden_dim=256, max_action=1.): 15 | super().__init__() 16 | 17 | self.trunk = MLP( 18 | in_dim=in_dim, 19 | out_dim=out_dim, 20 | hidden_dim=hidden_dim 21 | ) 22 | self.max_action = max_action 23 | 24 | def forward(self, state): 25 | a = self.trunk(state) 26 | return self.max_action * torch.tanh(a) 27 | 28 | 29 | class DiagGaussianActor(nn.Module): 30 | def __init__(self, in_dim, out_dim, hidden_dim=256, max_action=1.): 31 | super().__init__() 32 | 33 | self.trunk = MLP( 34 | in_dim=in_dim, 35 | out_dim=2*out_dim, 36 | hidden_dim=hidden_dim 37 | ) 38 | self.max_action = max_action 39 | 40 | def forward(self, obs): 41 | mu, log_std = self.trunk(obs).chunk(2, dim=-1) 42 | 43 | # constrain log_std inside [log_std_min, log_std_max] 44 | log_std = torch.tanh(log_std) 45 | log_std_min, log_std_max = LOG_STD_BOUNDS 46 | log_std = log_std_min + 0.5 * 
(log_std_max - log_std_min) * (log_std + 1) 47 | std = log_std.exp() 48 | 49 | dist = SquashedNormal(mu, std) 50 | return dist 51 | 52 | def sample_n(self, obs, n_samples): 53 | return self.forward(obs).sample_n(n_samples) 54 | 55 | 56 | class SkillChainingActor(nn.Module): 57 | def __init__(self, in_dim, out_dim, hidden_dim=256, max_action=1., 58 | middle_subtasks=None, last_subtask=None): 59 | super().__init__() 60 | 61 | self.actors = nn.ModuleDict({ 62 | subtask: DeterministicActor( 63 | in_dim=in_dim, 64 | out_dim=out_dim, 65 | hidden_dim=hidden_dim, 66 | max_action=max_action 67 | ) for subtask in middle_subtasks}) 68 | self.actors.update({last_subtask: None}) 69 | 70 | def __getitem__(self, key): 71 | return self.actors[key] 72 | 73 | def forward(self, state, subtask): 74 | a = self.actors[subtask](state) 75 | return a 76 | 77 | def init_last_subtask_actor(self, last_subtask, actor): 78 | '''Initialize with pre-trained local actor''' 79 | assert self.actors[last_subtask] is None 80 | self.actors[last_subtask] = copy.deepcopy(actor) 81 | 82 | 83 | class SkillChainingDiagGaussianActor(SkillChainingActor): 84 | def __init__(self, in_dim, out_dim, hidden_dim=256, max_action=1., 85 | middle_subtasks=None, last_subtask=None): 86 | super(SkillChainingActor, self).__init__() 87 | 88 | self.actors = nn.ModuleDict({ 89 | subtask: DiagGaussianActor( 90 | in_dim=in_dim, 91 | out_dim=out_dim, 92 | hidden_dim=hidden_dim, 93 | max_action=max_action 94 | ) for subtask in middle_subtasks}) 95 | self.actors.update({last_subtask: None}) 96 | 97 | def sample_n(self, obs, subtask, n_samples): 98 | return self.actors[subtask].sample_n(obs, n_samples) 99 | 100 | def squash_action(self, dist, raw_action): 101 | squashed_action = torch.tanh(raw_action) 102 | jacob = 2 * (math.log(2) - raw_action - F.softplus(-2 * raw_action)) 103 | log_prob = (dist.log_prob(raw_action) - jacob).sum(dim=-1, keepdims=True) 104 | return squashed_action, log_prob 105 | -------------------------------------------------------------------------------- /viskill/agents/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ..utils.mpi import sync_networks 6 | 7 | 8 | class BaseAgent(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def train(self, training=True): 13 | self.training = training 14 | self.actor.train(training) 15 | self.critic.train(training) 16 | 17 | def get_samples(self, replay_buffer): 18 | # sample replay buffer 19 | transitions = replay_buffer.sample() 20 | 21 | # preprocess 22 | o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g'] 23 | transitions['obs'], transitions['g'] = self._preproc_og(o, g) 24 | transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g) 25 | 26 | obs_norm = self.o_norm.normalize(transitions['obs']) 27 | g_norm = self.g_norm.normalize(transitions['g']) 28 | inputs_norm = np.concatenate([obs_norm, g_norm], axis=1) 29 | 30 | obs_next_norm = self.o_norm.normalize(transitions['obs_next']) 31 | g_next_norm = self.g_norm.normalize(transitions['g_next']) 32 | inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1) 33 | 34 | obs = self.to_torch(inputs_norm) 35 | next_obs = self.to_torch(inputs_next_norm) 36 | action = self.to_torch(transitions['actions']) 37 | reward = self.to_torch(transitions['r']) 38 | done = self.to_torch(transitions['dones']) 39 | 40 | return obs, action, reward, done, next_obs 41 | 42 | 
def update_target(self): 43 | # Update the frozen target models 44 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 45 | target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data) 46 | 47 | def update_normalizer(self, episode_batch): 48 | mb_obs, mb_ag, mb_g, mb_actions, dones = episode_batch.obs, episode_batch.ag, episode_batch.g, \ 49 | episode_batch.actions, episode_batch.dones 50 | mb_obs_next = mb_obs[:, 1:, :] 51 | mb_ag_next = mb_ag[:, 1:, :] 52 | # get the number of normalization transitions 53 | num_transitions = mb_actions.shape[1] 54 | # create the new buffer to store them 55 | buffer_temp = {'obs': mb_obs, 56 | 'ag': mb_ag, 57 | 'g': mb_g, 58 | 'actions': mb_actions, 59 | 'obs_next': mb_obs_next, 60 | 'ag_next': mb_ag_next, 61 | } 62 | transitions = self.her_sampler.sample_her_transitions(buffer_temp, num_transitions) 63 | obs, g = transitions['obs'], transitions['g'] 64 | # pre process the obs and g 65 | transitions['obs'], transitions['g'] = self._preproc_og(obs, g) 66 | # update 67 | self.o_norm.update(transitions['obs']) 68 | self.g_norm.update(transitions['g']) 69 | # recompute the stats 70 | self.o_norm.recompute_stats() 71 | self.g_norm.recompute_stats() 72 | 73 | def _preproc_og(self, o, g): 74 | o = np.clip(o, -self.clip_obs, self.clip_obs) 75 | g = np.clip(g, -self.clip_obs, self.clip_obs) 76 | return o, g 77 | 78 | def _preproc_inputs(self, o, g, dim=0, device=None): 79 | o_norm = self.o_norm.normalize(o, device=device) 80 | g_norm = self.g_norm.normalize(g, device=device) 81 | 82 | if isinstance(o_norm, np.ndarray): 83 | inputs = np.concatenate([o_norm, g_norm], dim) 84 | inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0).to(self.device) 85 | elif isinstance(o_norm, torch.Tensor): 86 | inputs = torch.cat([o_norm, g_norm], dim=1) 87 | return inputs 88 | 89 | def to_torch(self, array, copy=True): 90 | if copy: 91 | return torch.tensor(array, dtype=torch.float32).to(self.device) 92 | return torch.as_tensor(array).to(self.device) 93 | 94 | def sync_networks(self): 95 | sync_networks(self) -------------------------------------------------------------------------------- /viskill/agents/hier_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..components.checkpointer import CheckpointHandler 4 | from ..utils.general_utils import AttrDict, prefix_dict 5 | from .base import BaseAgent 6 | from .factory import make_sc_agent, make_sl_agent 7 | 8 | 9 | class HierachicalAgent(BaseAgent): 10 | def __init__( 11 | self, 12 | env_params, 13 | samplers, 14 | cfg, 15 | ): 16 | super().__init__() 17 | self.cfg = cfg 18 | self.device = cfg.device 19 | self.subtasks = env_params.subtasks 20 | self.goal_adaptor = env_params['adaptor_sc'] 21 | self.subtasks_steps = env_params['subtask_steps'] 22 | self.middle_subtasks = env_params['middle_subtasks'] 23 | self.last_subtask = env_params['last_subtask'] 24 | self.len_cond = env_params['len_cond'] 25 | self.curr_subtask = None 26 | 27 | self.sl_agent = torch.nn.ModuleDict( 28 | {subtask: make_sl_agent(env_params, samplers[subtask], cfg.sl_agent) for subtask in self.subtasks}) 29 | self._init_sl_agent() 30 | self.sc_agent = make_sc_agent(env_params, cfg.sc_agent, self.sl_agent) 31 | self._init_sc_agent() 32 | 33 | def _init_sl_agent(self): 34 | checkpt_dir = self.cfg.checkpoint_dir 35 | for subtask in self.subtasks: 36 | # TODO(tao): expose loading metric and dir 
speicification 37 | sub_checkpt_dir = checkpt_dir + f'/{subtask}/model' 38 | CheckpointHandler.load_checkpoint( 39 | sub_checkpt_dir, self.sl_agent[subtask], self.device, episode=self.cfg.ckpt_episode) 40 | 41 | def _init_sc_agent(self): 42 | self.sc_agent.init(self.last_subtask, self.sl_agent) 43 | 44 | def update(self, sc_buffer, sl_buffer, sc_demo_buffer=None, sl_demo_buffer=None): 45 | metrics = AttrDict() 46 | if sc_demo_buffer is None: 47 | sc_metrics = prefix_dict(self.sc_agent.update(sc_buffer), 'sc_') 48 | else: 49 | sc_metrics = prefix_dict(self.sc_agent.update(sc_buffer, sc_demo_buffer), 'sc_') 50 | metrics.update(sc_metrics) 51 | 52 | if self.cfg.agent.update_sl_agent: 53 | for subtask in self.subtasks: 54 | sl_metrics = prefix_dict(self.sl_agent[subtask].update(sl_buffer[subtask]), 'sl_' + subtask + '_') 55 | metrics.update(sl_metrics) 56 | return metrics 57 | 58 | def get_action(self, obs, subtask, noise=False): 59 | output = AttrDict() 60 | if self._perform_hl_step_now(subtask): 61 | # perform step with skill-chaining policy 62 | if subtask not in self.middle_subtasks: 63 | self._last_sc_action = obs['desired_goal'][:-self.len_cond] 64 | else: 65 | self._last_sc_action = self.sc_agent.get_action(obs['observation'], subtask, noise=noise) 66 | output.is_sc_step = True 67 | self.curr_subtask = subtask 68 | else: 69 | output.is_sc_step = False 70 | 71 | # perform step with skill-learning policy 72 | assert self._last_sc_action is not None 73 | self.goal_adaption(obs, subtask) 74 | sl_action = self.sl_agent[subtask].get_action(obs, noise=False) 75 | 76 | output.update(AttrDict( 77 | sc_action=self._last_sc_action, 78 | sl_action=sl_action 79 | )) 80 | return output 81 | 82 | def goal_adaption(self, obs, subtask): 83 | # Add contact condition to make goal compatible with surrol wrapper 84 | if subtask == 'release' and self.cfg.task == 'MatchBoard-v0': 85 | adpt_goal = self.goal_adaptor(self._last_sc_action, subtask) 86 | adpt_goal[3: 6] = obs['desired_goal'][3: 6].copy() 87 | obs['desired_goal'] = adpt_goal 88 | elif subtask in ['push', 'pull'] and self.cfg.task == 'MatchBoard-v0': 89 | adpt_goal = self.goal_adaptor(self._last_sc_action, subtask) 90 | adpt_goal[3] = obs['desired_goal'][3].copy() 91 | adpt_goal[5] = obs['desired_goal'][5].copy() 92 | obs['desired_goal'] = adpt_goal 93 | else: 94 | obs['desired_goal'] = self.goal_adaptor(self._last_sc_action, subtask) 95 | 96 | def _perform_hl_step_now(self, subtask): 97 | """Indicates whether the skill-chaining policy should be executed in the current time step.""" 98 | return subtask != self.curr_subtask 99 | 100 | def sync_networks(self): 101 | self.sc_agent.sync_networks() -------------------------------------------------------------------------------- /viskill/agents/sc_awac.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from ..utils.general_utils import AttrDict, prefix_dict 5 | from .sc_sac import SkillChainingSAC 6 | 7 | 8 | class SkillChainingAWAC(SkillChainingSAC): 9 | def __init__( 10 | self, 11 | env_params, 12 | agent_cfg, 13 | sl_agent 14 | ): 15 | super().__init__(env_params, agent_cfg, sl_agent) 16 | 17 | # AWAC parameters 18 | self.n_action_samples = agent_cfg.n_action_samples 19 | self.lam = agent_cfg.lam 20 | 21 | def update_critic(self, obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask): 22 | assert subtask != self.env_params['last_subtask'] 23 | 24 | with torch.no_grad(): 25 | next_subtask = 
self.env_params['next_subtasks'][subtask] 26 | if next_subtask != self.env_params['last_subtask']: 27 | dist = self.actor(next_obs, next_subtask) 28 | action_out = dist.rsample() 29 | target_V = self.critic_target.q(next_obs, action_out, next_subtask) 30 | 31 | if self.raw_env_reward and next_subtask == self.env_params['last_subtask']: 32 | target_Q = self.reward_scale * reward + (self.discount * raw_reward) 33 | else: 34 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach() 35 | 36 | current_Q1, current_Q2 = self.critic(obs, action, subtask) 37 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 38 | 39 | # Optimize critic loss 40 | self.critic_optimizer[subtask].zero_grad() 41 | critic_loss.backward() 42 | self.critic_optimizer[subtask].step() 43 | 44 | metrics = AttrDict( 45 | critic_target_q=target_Q.mean().item(), 46 | critic_q=current_Q1.mean().item(), 47 | critic_loss=critic_loss.item() 48 | ) 49 | return prefix_dict(metrics, subtask + '_') 50 | 51 | def update_actor(self, obs, action, subtask): 52 | # compute log probability 53 | dist = self.actor(obs, subtask) 54 | log_probs = dist.log_prob(action).sum(-1, keepdim=True) 55 | 56 | # compute exponential weight 57 | weights = self._compute_weights(obs, action, subtask) 58 | actor_loss = -(log_probs * weights).sum() 59 | 60 | self.actor_optimizer[subtask].zero_grad() 61 | actor_loss.backward() 62 | self.actor_optimizer[subtask].step() 63 | 64 | metrics = AttrDict( 65 | log_probs=log_probs.mean(), 66 | actor_loss=actor_loss.item() 67 | ) 68 | return prefix_dict(metrics, subtask + '_') 69 | 70 | def update(self, replay_buffer, demo_buffer): 71 | metrics = AttrDict() 72 | 73 | for i in range(self.update_epoch): 74 | for subtask in self.env_params['middle_subtasks']: 75 | # sample from replay buffer 76 | obs, action, reward, next_obs, done, sl_norm_next_obs, raw_reward = self.get_samples(replay_buffer, subtask) 77 | action = action / self.max_action 78 | 79 | # update critic and actor 80 | metrics.update(self.update_critic(obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask)) 81 | metrics.update(self.update_actor(obs, action, subtask)) 82 | 83 | # sample from replay buffer 84 | obs, action, reward, next_obs, done, sl_norm_next_obs, raw_reward = self.get_samples(demo_buffer, subtask) 85 | action = action / self.max_action 86 | 87 | # update critic and actor 88 | metrics.update(self.update_critic(obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask)) 89 | metrics.update(self.update_actor(obs, action, subtask)) 90 | 91 | # update target critic and actor 92 | self.update_target() 93 | 94 | return metrics 95 | 96 | def _compute_weights(self, obs, act, subtask): 97 | with torch.no_grad(): 98 | batch_size = obs.shape[0] 99 | 100 | # compute action-value 101 | q_values = self.critic.q(obs, act, subtask) 102 | 103 | # sample actions 104 | policy_actions = self.actor.sample_n(obs, subtask, self.n_action_samples) 105 | flat_actions = policy_actions.reshape(-1, self.dima) 106 | 107 | # repeat observation 108 | reshaped_obs = obs.view(batch_size, 1, *obs.shape[1:]) 109 | reshaped_obs = reshaped_obs.expand(batch_size, self.n_action_samples, *obs.shape[1:]) 110 | flat_obs = reshaped_obs.reshape(-1, *obs.shape[1:]) 111 | 112 | # compute state-value 113 | flat_v_values = self.critic.q(flat_obs, flat_actions, subtask) 114 | reshaped_v_values = flat_v_values.view(obs.shape[0], -1, 1) 115 | v_values = reshaped_v_values.mean(dim=1) 116 | 117 | # compute normalized weight 
118 | adv_values = (q_values - v_values).view(-1) 119 | weights = F.softmax(adv_values / self.lam, dim=0).view(-1, 1) 120 | 121 | return weights * adv_values.numel() 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ViSkill: Value-Informed Skill Chaining for Policy Learning of Long-Horizon Tasks with Surgical Robot 2 | This is the official PyTorch implementation of the paper "[**Value-Informed Skill Chaining for Policy Learning of Long-Horizon Tasks with Surgical Robot**](https://arxiv.org/pdf/2307.16503.pdf)" (IROS 2023). 3 | 6 | 7 |

8 | ![ViSkill teaser](docs/resources/viskill_teaser.png) 9 |

10 | 11 | # Prerequisites 12 | * Ubuntu 18.04 13 | * Python 3.7+ 14 | 15 | 16 | # Installation Instructions 17 | 18 | 1. Clone this repository. 19 | ```bash 20 | git clone --recursive https://github.com/med-air/ViSkill.git 21 | cd ViSkill 22 | git submodule update --init --recursive 23 | ``` 24 | 25 | 2. Create a virtual environment. 26 | ```bash 27 | conda create -n viskill python=3.8 28 | conda activate viskill 29 | ``` 30 | 31 | 3. Install packages. 32 | 33 | ```bash 34 | pip3 install -e SurRoL/ # install SurRoL environments 35 | pip3 install -r requirements.txt 36 | pip3 install -e . 37 | ``` 38 | 39 | 4. Then add one line of code at the top of `gym/gym/envs/__init__.py` to register SurRoL tasks: 40 | 41 | ```python 42 | # directory: anaconda3/envs/viskill/lib/python3.8/site-packages/ 43 | import surrol.gym 44 | ``` 45 | 46 | # Usage 47 | Commands for running ViSkill. Results will be logged to WandB. Before running the commands below, please change the WandB entity in [```skill_learning.yaml```](viskill/configs/skill_learning.yaml#L32) and [```skill_chaining.yaml```](viskill/configs/skill_chaining.yaml#L39) to match your account. 48 | 49 | We collect demonstration data for each subtask via the scripted controllers provided by SurRoL. Take the BiPegTransfer task as an example: 50 | ```bash 51 | mkdir SurRoL/surrol/data/demo 52 | python SurRoL/surrol/data/data_generation_bipegtransfer.py --env BiPegTransfer-v0 --subtask grasp 53 | python SurRoL/surrol/data/data_generation_bipegtransfer.py --env BiPegTransfer-v0 --subtask handover 54 | python SurRoL/surrol/data/data_generation_bipegtransfer.py --env BiPegTransfer-v0 --subtask release 55 | ``` 56 | ## Training Commands 57 | 58 | - Train subtask policies: 59 | ```bash 60 | mpirun -np 8 python -m train_sl seed=1 subtask=grasp 61 | mpirun -np 8 python -m train_sl seed=1 subtask=handover 62 | mpirun -np 8 python -m train_sl seed=1 subtask=release 63 | ``` 64 | 65 | - Train chaining policies: 66 | ```bash 67 | mpirun -np 8 python -m train_sc model_seed=1 task=BiPegTransfer-v0 68 | ``` 69 | 70 | 71 | # Starting to Modify the Code 72 | ## Modifying the hyperparameters 73 | The default hyperparameters are defined in `viskill/configs`, where [```skill_learning.yaml```](viskill/configs/skill_learning.yaml) and [```skill_chaining.yaml```](viskill/configs/skill_chaining.yaml) define the experiment settings for learning subtask policies and chaining policies, respectively, and the YAML files in the directories [```sl_agent```](viskill/configs/sl_agent) and [```sc_agent```](viskill/configs/sc_agent) define the hyperparameters of each method. These parameters can be modified directly in the experiment or agent config files, or overridden via the command line. 74 | 75 | ## Adding a new RL algorithm 76 | The core RL algorithms are implemented within the `BaseAgent` class. To add a new skill-chaining algorithm, create a new file in 77 | `viskill/agents` and subclass [```BaseAgent```](viskill/agents/base.py#L8). In particular, any required 78 | networks (actor, critic, etc.) need to be constructed, and the `update(...)` and `get_action(...)` functions need to be overridden. Once the implementation is done, register the new agent in [```factory.py```](viskill/agents/factory.py) and add a config file in [```sc_agent```](viskill/configs/sc_agent) that specifies its model parameters, as sketched below.
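A minimal sketch of such a subclass is shown below; the names `SkillChainingMyAlgo` and `SC_MYALGO` are placeholders rather than part of this codebase, and the method bodies only mark where the new algorithm's networks and losses would go.

```python
# Hypothetical sketch: SkillChainingMyAlgo / SC_MYALGO are placeholder names.
from viskill.agents.base import BaseAgent
from viskill.utils.general_utils import AttrDict


class SkillChainingMyAlgo(BaseAgent):
    """Skeleton of a new skill-chaining (high-level) agent."""

    def __init__(self, env_params, agent_cfg, sl_agent):
        super().__init__()
        self.env_params = env_params
        self.update_epoch = agent_cfg.update_epoch
        self.sl_agent = sl_agent
        # Build actor/critic networks and their optimizers here, e.g. with the
        # building blocks in viskill/modules/policies.py and viskill/modules/critics.py.

    def get_action(self, obs, subtask, noise=False):
        # Return a subtask goal (high-level action) for the given middle subtask.
        raise NotImplementedError

    def update(self, replay_buffer, demo_buffer=None):
        # Sample batches per middle subtask, take critic/actor gradient steps,
        # and return a dict of logging metrics (cf. sc_sac.py / sc_awac.py).
        metrics = AttrDict()
        for subtask in self.env_params['middle_subtasks']:
            for _ in range(self.update_epoch):
                pass  # compute and apply the new algorithm's losses here
        return metrics
```

With this in place, adding `'SC_MYALGO': SkillChainingMyAlgo` to the `AGENTS` dict in [```factory.py```](viskill/agents/factory.py) and creating a matching `viskill/configs/sc_agent/sc_myalgo.yaml` with `name: SC_MYALGO` should make the agent selectable through the `sc_agent=sc_myalgo` override when launching `train_sc.py`, with its hyperparameters overridable from the command line in the same way as the existing agents.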
79 | 80 | # Code Navigation 81 | 82 | ``` 83 | viskill 84 | |- agents # implements core algorithms in agent classes 85 | |- components # reusable infrastructure for model training 86 | | |- checkpointer.py # handles saving + loading of model checkpoints 87 | | |- environment.py # environment wrappers for SurRoL environments 88 | | |- normalizer.py # normalizer for vectorized input 89 | | |- logger.py # implements core logging functionality using WandB 90 | | 91 | |- configs # experiment configs 92 | | |- skill_learning.yaml # configs for subtask policy learning 93 | | |- skill_chaining.yaml # configs for chaining policy learning 94 | | |- sl_agent # configs for demonstration-guided RL algorithm 95 | | |- sc_agent # configs for skill chaining algorithm 96 | | |- hier_agent # configs for (synthetic) hierarchical agent 97 | | 98 | |- modules # reusable architecture components 99 | | |- critics.py # basic critic implementations (e.g. MLP-based critic) 100 | | |- distributions.py # pytorch distribution utils for density model 101 | | |- policies.py # basic actor implementations 102 | | |- replay_buffer.py # HER replay buffer with future sampling strategy 103 | | |- sampler.py # rollout sampler for collecting experience 104 | | |- subnetworks.py # basic networks 105 | | 106 | |- trainers # main model training script, builds all components + runs training loop and logging 107 | | 108 | |- utils # general and RL utilities, pytorch / visualization utilities etc 109 | |- train_sl.py # experiment launcher for skill learning 110 | |- train_sc.py # experiment launcher for skill chaining 111 | ``` 112 | 113 | # Contact 114 | For any questions, please feel free to email taou.cs13@gmail.com 115 | -------------------------------------------------------------------------------- /viskill/agents/sl_ddpgbc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from ..components.normalizer import Normalizer 8 | from ..modules.critics import Critic 9 | from ..modules.policies import DeterministicActor 10 | from .base import BaseAgent 11 | 12 | 13 | class SkillLearningDDPGBC(BaseAgent): 14 | def __init__( 15 | self, 16 | env_params, 17 | sampler, 18 | agent_cfg, 19 | ): 20 | super().__init__() 21 | 22 | self.discount = agent_cfg.discount 23 | self.reward_scale = agent_cfg.reward_scale 24 | self.update_epoch = agent_cfg.update_epoch 25 | self.her_sampler = sampler # same sampler as used in the replay buffer 26 | self.device = agent_cfg.device 27 | 28 | self.noise_eps = agent_cfg.noise_eps 29 | self.aux_weight = agent_cfg.aux_weight 30 | self.p_dist = agent_cfg.p_dist 31 | self.soft_target_tau = agent_cfg.soft_target_tau 32 | 33 | self.clip_obs = agent_cfg.clip_obs 34 | self.norm_clip = agent_cfg.norm_clip 35 | self.norm_eps = agent_cfg.norm_eps 36 | 37 | self.dima = env_params['act'] 38 | self.dimo, self.dimg = env_params['obs'], env_params['goal'] 39 | 40 | self.max_action = env_params['max_action'] 41 | self.act_sampler = env_params['act_rand_sampler'] 42 | 43 | # normalizer 44 | self.o_norm = Normalizer( 45 | size=self.dimo, 46 | default_clip_range=self.norm_clip, 47 | eps=agent_cfg.norm_eps 48 | ) 49 | self.g_norm = Normalizer( 50 | size=self.dimg, 51 | default_clip_range=self.norm_clip, 52 | eps=agent_cfg.norm_eps 53 | ) 54 | 55 | # build policy 56 | self.actor = DeterministicActor( 57 | self.dimo+self.dimg, self.dima, agent_cfg.hidden_dim).to(agent_cfg.device) 58 | self.actor_target =
copy.deepcopy(self.actor).to(agent_cfg.device) 59 | 60 | self.critic = Critic( 61 | self.dimo+self.dimg+self.dima, agent_cfg.hidden_dim).to(agent_cfg.device) 62 | self.critic_target = copy.deepcopy(self.critic).to(agent_cfg.device) 63 | 64 | # optimizer 65 | self.actor_optimizer = torch.optim.Adam( 66 | self.actor.parameters(), lr=agent_cfg.actor_lr 67 | ) 68 | self.critic_optimizer = torch.optim.Adam( 69 | self.critic.parameters(), lr=agent_cfg.critic_lr 70 | ) 71 | 72 | def get_action(self, state, noise=False): 73 | # random action at initial stage 74 | with torch.no_grad(): 75 | o, g = state['observation'], state['desired_goal'] 76 | input_tensor = self._preproc_inputs(o, g) 77 | action = self.actor(input_tensor).cpu().data.numpy().flatten() 78 | 79 | # Gaussian noise 80 | if noise: 81 | action = (action + self.max_action * self.noise_eps * np.random.randn(action.shape[0])).clip( 82 | -self.max_action, self.max_action) 83 | 84 | return action 85 | 86 | def update_critic(self, obs, action, reward, next_obs): 87 | metrics = dict() 88 | 89 | with torch.no_grad(): 90 | action_out = self.actor_target(next_obs) 91 | target_V = self.critic_target(next_obs, action_out) 92 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach() 93 | 94 | clip_return = 1 / (1 - self.discount) 95 | target_Q = torch.clamp(target_Q, -clip_return, 0).detach() 96 | 97 | Q = self.critic(obs, action) 98 | critic_loss = F.mse_loss(Q, target_Q) 99 | 100 | # optimize critic loss 101 | self.critic_optimizer.zero_grad() 102 | critic_loss.backward() 103 | self.critic_optimizer.step() 104 | 105 | metrics['critic_target_q'] = target_Q.mean().item() 106 | metrics['critic_q'] = Q.mean().item() 107 | metrics['critic_loss'] = critic_loss.item() 108 | return metrics 109 | 110 | def update_actor(self, obs, action, is_demo=False): 111 | metrics = dict() 112 | 113 | action_out = self.actor(obs) 114 | Q_out = self.critic(obs, action_out) 115 | 116 | # refer to https://arxiv.org/pdf/1709.10089.pdf 117 | if is_demo: 118 | bc_loss = self.norm_dist(action_out, action) 119 | # q-filter 120 | with torch.no_grad(): 121 | q_filter = self.critic_target(obs, action) >= self.critic_target(obs, action_out) 122 | bc_loss = q_filter * bc_loss 123 | actor_loss = -(Q_out + self.aux_weight * bc_loss).mean() 124 | else: 125 | actor_loss = -(Q_out).mean() 126 | 127 | actor_loss += action_out.pow(2).mean() 128 | 129 | # optimize actor loss 130 | self.actor_optimizer.zero_grad() 131 | actor_loss.backward() 132 | self.actor_optimizer.step() 133 | 134 | metrics['actor_loss'] = actor_loss.item() 135 | return metrics 136 | 137 | def update(self, replay_buffer, demo_buffer): 138 | metrics = dict() 139 | 140 | for i in range(self.update_epoch): 141 | # sample from replay buffer 142 | obs, action, reward, done, next_obs = self.get_samples(replay_buffer) 143 | 144 | # ppdate critic and actor 145 | metrics.update(self.update_critic(obs, action, reward, next_obs)) 146 | metrics.update(self.update_actor(obs, action)) 147 | 148 | # sample from demo buffer 149 | obs, action, reward, done, next_obs = self.get_samples(demo_buffer) 150 | 151 | # update critic and actor 152 | self.update_critic(obs, action, reward, next_obs) 153 | self.update_actor(obs, action, is_demo=True) 154 | 155 | # update target critic and actor 156 | self.update_target() 157 | return metrics 158 | 159 | def update_target(self): 160 | # update the frozen target models 161 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 162 | 
target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data) 163 | 164 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 165 | target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data) 166 | 167 | def norm_dist(self, a1, a2): 168 | self.p_dist = np.inf if self.p_dist == -1 else self.p_dist 169 | return - torch.norm(a1 - a2, p=self.p_dist, dim=1, keepdim=True).pow(2) / self.dima -------------------------------------------------------------------------------- /viskill/agents/sl_dex.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from ..utils.general_utils import AttrDict 6 | from .sl_ddpgbc import SkillLearningDDPGBC 7 | 8 | 9 | class SkillLearningDEX(SkillLearningDDPGBC): 10 | def __init__( 11 | self, 12 | env_params, 13 | sampler, 14 | agent_cfg, 15 | ): 16 | super().__init__(env_params, sampler, agent_cfg) 17 | self.k = 5 18 | 19 | def get_samples(self, replay_buffer): 20 | '''Addtionally sample next action for guidance propagation''' 21 | transitions = replay_buffer.sample() 22 | 23 | # preprocess 24 | o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g'] 25 | transitions['obs'], transitions['g'] = self._preproc_og(o, g) 26 | transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g) 27 | 28 | obs_norm = self.o_norm.normalize(transitions['obs']) 29 | g_norm = self.g_norm.normalize(transitions['g']) 30 | inputs_norm = np.concatenate([obs_norm, g_norm], axis=1) 31 | 32 | obs_next_norm = self.o_norm.normalize(transitions['obs_next']) 33 | g_next_norm = self.g_norm.normalize(transitions['g_next']) 34 | inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1) 35 | 36 | obs = self.to_torch(inputs_norm) 37 | next_obs = self.to_torch(inputs_next_norm) 38 | action = self.to_torch(transitions['actions']) 39 | next_action = self.to_torch(transitions['next_actions']) 40 | reward = self.to_torch(transitions['r']) 41 | done = self.to_torch(transitions['dones']) 42 | return obs, action, reward, done, next_obs, next_action 43 | 44 | def update_critic(self, obs, action, reward, next_obs, next_obs_demo, next_action_demo): 45 | with torch.no_grad(): 46 | next_action_out = self.actor_target(next_obs) 47 | target_V = self.critic_target(next_obs, next_action_out) 48 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach() 49 | 50 | # exploration guidance 51 | topk_actions = self.compute_propagated_actions(next_obs, next_obs_demo, next_action_demo) 52 | act_dist = self.norm_dist(topk_actions, next_action_out) 53 | target_Q += self.aux_weight * act_dist 54 | 55 | clip_return = 5 / (1 - self.discount) 56 | target_Q = torch.clamp(target_Q, -clip_return, 0).detach() 57 | 58 | Q = self.critic(obs, action) 59 | critic_loss = F.mse_loss(Q, target_Q) 60 | 61 | # optimize critic loss 62 | self.critic_optimizer.zero_grad() 63 | critic_loss.backward() 64 | self.critic_optimizer.step() 65 | 66 | metrics = AttrDict( 67 | critic_q=Q.mean().item(), 68 | critic_target_q=target_Q.mean().item(), 69 | critic_loss=critic_loss.item(), 70 | bacth_reward=reward.mean().item() 71 | ) 72 | return metrics 73 | 74 | def update_actor(self, obs, obs_demo, action_demo): 75 | action_out = self.actor(obs) 76 | Q_out = self.critic(obs, action_out) 77 | 78 | topk_actions = self.compute_propagated_actions(obs, 
obs_demo, action_demo) 79 | act_dist = self.norm_dist(action_out, topk_actions) 80 | actor_loss = -(Q_out + self.aux_weight * act_dist).mean() 81 | actor_loss += action_out.pow(2).mean() 82 | 83 | # optimize actor loss 84 | self.actor_optimizer.zero_grad() 85 | actor_loss.backward() 86 | self.actor_optimizer.step() 87 | 88 | metrics = AttrDict( 89 | actor_loss=actor_loss.item(), 90 | act_dist=act_dist.mean().item() 91 | ) 92 | return metrics 93 | 94 | def update(self, replay_buffer, demo_buffer): 95 | for i in range(self.update_epoch): 96 | # sample from replay buffer 97 | obs, action, reward, done, next_obs, next_action = self.get_samples(replay_buffer) 98 | obs_, action_, reward_, done_, next_obs_, next_action_ = self.get_samples(demo_buffer) 99 | 100 | with torch.no_grad(): 101 | next_action_out = self.actor_target(next_obs) 102 | target_V = self.critic_target(next_obs, next_action_out) 103 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach() 104 | 105 | l2_pair = torch.cdist(next_obs, next_obs_) 106 | topk_value, topk_indices = l2_pair.topk(self.k, dim=1, largest=False) 107 | topk_weight = F.softmin(topk_value.sqrt(), dim=1) 108 | topk_actions = torch.ones_like(next_action_) 109 | 110 | for i in range(topk_actions.size(0)): 111 | topk_actions[i] = torch.mm(topk_weight[i].unsqueeze(0), next_action_[topk_indices[i]]).squeeze(0) 112 | intr = self.norm_dist(topk_actions, next_action_out) 113 | target_Q += self.aux_weight * intr 114 | next_action_out_ =self.actor_target(next_obs_) 115 | target_V_ = self.critic_target(next_obs_, next_action_out_) 116 | target_Q_ = self.reward_scale * reward_ + (self.discount * target_V_).detach() 117 | intr_ = self.norm_dist(next_action_, next_action_out_) 118 | target_Q_ += self.aux_weight * intr_ 119 | 120 | clip_return = 5 / (1 - self.discount) 121 | target_Q = torch.clamp(target_Q, -clip_return, 0).detach() 122 | target_Q_ = torch.clamp(target_Q_, -clip_return, 0).detach() 123 | 124 | 125 | Q = self.critic(obs, action) 126 | Q_ = self.critic(obs_, action_) 127 | critic_loss = F.mse_loss(Q, target_Q) + F.mse_loss(Q_, target_Q_) 128 | 129 | # optimize critic loss 130 | self.critic_optimizer.zero_grad() 131 | critic_loss.backward() 132 | self.critic_optimizer.step() 133 | 134 | action_out = self.actor(obs) 135 | action_out_ = self.actor(obs_) 136 | Q_out = self.critic(obs, action_out) 137 | Q_out_ = self.critic(obs_, action_out_) 138 | 139 | with torch.no_grad(): 140 | l2_pair = torch.cdist(obs, obs_) 141 | topk_value, topk_indices = l2_pair.topk(self.k, dim=1, largest=False) 142 | topk_weight = F.softmin(topk_value.sqrt(), dim=1) 143 | topk_actions = torch.ones_like(action) 144 | 145 | for i in range(topk_actions.size(0)): 146 | topk_actions[i] = torch.mm(topk_weight[i].unsqueeze(0), action_[topk_indices[i]]).squeeze(0) 147 | 148 | intr2 = self.norm_dist(action_out, topk_actions) 149 | intr3 = self.norm_dist(action_out_, action_) 150 | 151 | # Refer to https://arxiv.org/pdf/1709.10089.pdf 152 | actor_loss = - (Q_out + self.aux_weight * intr2).mean() 153 | actor_loss += -(Q_out_ + self.aux_weight * intr3).mean() 154 | 155 | actor_loss += action_out.pow(2).mean() 156 | actor_loss += action_out_.pow(2).mean() 157 | 158 | #actor_loss += self.action_l2 * action_out.pow(2).mean() 159 | 160 | # Optimize actor loss 161 | self.actor_optimizer.zero_grad() 162 | actor_loss.backward() 163 | self.actor_optimizer.step() 164 | 165 | self.update_target() 166 | 167 | metrics = AttrDict( 168 | batch_reward=reward.mean().item(), 169 | 
critic_q=Q.mean().item(), 170 | critic_q_=Q_.mean().item(), 171 | critic_target_q=target_Q.mean().item(), 172 | critic_loss=critic_loss.item(), 173 | actor_loss=actor_loss.item() 174 | ) 175 | return metrics -------------------------------------------------------------------------------- /viskill/modules/sampler.py: -------------------------------------------------------------------------------- 1 | from ..utils.general_utils import AttrDict, listdict2dictlist 2 | from ..utils.rl_utils import ReplayCache, ReplayCacheGT 3 | 4 | 5 | class Sampler: 6 | """Collects rollouts from the environment using the given agent.""" 7 | def __init__(self, env, agent, max_episode_len): 8 | self._env = env 9 | self._agent = agent 10 | self._max_episode_len = max_episode_len 11 | 12 | self._obs = None 13 | self._episode_step = 0 14 | self._episode_cache = ReplayCacheGT(max_episode_len) 15 | 16 | def init(self): 17 | """Starts a new rollout. Render indicates whether output should contain image.""" 18 | self._episode_reset() 19 | 20 | def sample_action(self, obs, is_train): 21 | return self._agent.get_action(obs, noise=is_train) 22 | 23 | def sample_episode(self, is_train, render=False): 24 | """Samples one episode from the environment.""" 25 | self.init() 26 | episode, done = [], False 27 | while not done and self._episode_step < self._max_episode_len: 28 | action = self.sample_action(self._obs, is_train) 29 | if action is None: 30 | break 31 | if render: 32 | render_obs = self._env.render('rgb_array') 33 | obs, reward, done, info = self._env.step(action) 34 | episode.append(AttrDict( 35 | reward=reward, 36 | success=info['is_success'], 37 | info=info 38 | )) 39 | self._episode_cache.store_transition(obs, action, done, info['gt_goal']) 40 | if render: 41 | episode[-1].update(AttrDict(image=render_obs)) 42 | 43 | # update stored observation 44 | self._obs = obs 45 | self._episode_step += 1 46 | 47 | episode[-1].done = True # make sure episode is marked as done at final time step 48 | rollouts = self._episode_cache.pop() 49 | assert self._episode_step == self._max_episode_len 50 | return listdict2dictlist(episode), rollouts, self._episode_step 51 | 52 | def _episode_reset(self, global_step=None): 53 | """Resets sampler at the end of an episode.""" 54 | self._episode_step, self._episode_reward = 0, 0. 55 | self._obs = self._reset_env() 56 | self._episode_cache.store_obs(self._obs) 57 | 58 | def _reset_env(self): 59 | return self._env.reset() 60 | 61 | 62 | class HierarchicalSampler(Sampler): 63 | """Collects experience batches by rolling out a hierarchical agent. 
Aggregates low-level batches into HL batch.""" 64 | def __init__(self, env, agent, env_params): 65 | super().__init__(env, agent, env_params['max_timesteps']) 66 | 67 | self._env_params = env_params 68 | self._episode_cache = AttrDict( 69 | {subtask: ReplayCache(steps) for subtask, steps in env_params.subtask_steps.items()}) 70 | 71 | def sample_episode(self, is_train, render=False): 72 | """Samples one episode from the environment.""" 73 | self.init() 74 | sc_transitions = AttrDict({subtask: [] for subtask in self._env_params.subtasks}) 75 | sc_succ_transitions = AttrDict({subtask: [] for subtask in self._env_params.subtasks}) 76 | sc_episode, sl_episode, done, prev_subtask_succ = [], AttrDict(), False, AttrDict() 77 | while not done and self._episode_step < self._max_episode_len: 78 | agent_output = self.sample_action(self._obs, is_train, self._env.subtask) 79 | if self.last_sc_action is None: 80 | self._episode_cache[self._env.subtask].store_obs(self._obs) 81 | 82 | if render: 83 | render_obs = self._env.render('rgb_array') 84 | if agent_output.is_sc_step: 85 | self.last_sc_action = agent_output.sc_action 86 | self.reward_since_last_sc = 0 87 | 88 | obs, reward, done, info = self._env.step(agent_output.sl_action) 89 | self.reward_since_last_sc += reward 90 | if info['subtask_done']: 91 | if not done: 92 | # store skill-chaining transition 93 | sc_transitions[info['subtask']].append( 94 | [self.last_sc_obs, self.last_sc_action, self.reward_since_last_sc, obs['observation'], done, obs['desired_goal']]) 95 | 96 | if info['subtask_is_success']: 97 | sc_succ_transitions[info['subtask']].append( 98 | [self.last_sc_obs, self.last_sc_action, self.reward_since_last_sc, obs['observation'], done, obs['desired_goal']]) 99 | else: 100 | sc_succ_transitions[info['subtask']].append([None]) 101 | 102 | # middle subtask 103 | self._episode_cache[self._env.subtask].store_obs(obs) 104 | self._episode_cache[self._env.prev_subtasks[self._env.subtask]].\ 105 | store_transition(obs, agent_output.sl_action, True) 106 | self.last_sc_obs = obs['observation'] 107 | else: 108 | # terminal subtask 109 | sc_transitions[info['subtask']] = [] 110 | sc_transitions[info['subtask']].append( 111 | [self.last_sc_obs, self.last_sc_action, self.reward_since_last_sc, obs['observation'], done, obs['desired_goal']]) 112 | if info['subtask_is_success']: 113 | sc_succ_transitions[info['subtask']].append( 114 | [self.last_sc_obs, self.last_sc_action, self.reward_since_last_sc, obs['observation'], done, obs['desired_goal']]) 115 | else: 116 | sc_succ_transitions[info['subtask']].append([None]) 117 | self._episode_cache[self._env.subtask].store_transition(obs, agent_output.sl_action, True) 118 | prev_subtask_succ[self._env.subtask] = info['subtask_is_success'] 119 | else: 120 | self._episode_cache[self._env.subtask].store_transition(obs, agent_output.sl_action, False) 121 | sc_episode.append(AttrDict( 122 | reward=reward, 123 | success=info['is_success'], 124 | info=info)) 125 | if render: 126 | sc_episode[-1].update(AttrDict(image=render_obs)) 127 | 128 | # update stored observation 129 | self._obs = obs 130 | self._episode_step += 1 131 | 132 | assert self._episode_step == self._max_episode_len 133 | for subtask in self._env_params.subtasks: 134 | if subtask not in prev_subtask_succ.keys(): 135 | sl_episode[subtask] = self._episode_cache[subtask].pop() 136 | continue 137 | if prev_subtask_succ[subtask]: 138 | sl_episode[subtask] = self._episode_cache[subtask].pop() 139 | else: 140 | self._episode_cache[subtask].pop() 141 | 142 | 
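# merge the per-step records into a dict of lists and attach the collected skill-chaining (high-level) transitions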
sc_episode = listdict2dictlist(sc_episode) 143 | sc_episode.update(AttrDict( 144 | sc_transitions=sc_transitions, 145 | sc_succ_transitions=sc_succ_transitions) 146 | ) 147 | 148 | return sc_episode, sl_episode, self._episode_step 149 | 150 | def _episode_reset(self, global_step=None): 151 | """Resets sampler at the end of an episode.""" 152 | self._episode_step, self._episode_reward = 0, 0. 153 | self._obs = self._reset_env() 154 | self.last_sc_obs, self.last_sc_action = self._obs['observation'], None # stores observation when last hl action was taken 155 | self.reward_since_last_sc = 0 # accumulates the reward since the last HL step for HL transition 156 | 157 | def sample_action(self, obs, is_train, subtask): 158 | return self._agent.get_action(obs, subtask, noise=is_train) -------------------------------------------------------------------------------- /viskill/utils/rl_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from .general_utils import AttrDict, RecursiveAverageMeter 6 | 7 | 8 | def get_env_params(env, cfg): 9 | obs = env.reset() 10 | env_params = AttrDict( 11 | obs=obs['observation'].shape[0], 12 | achieved_goal=obs['achieved_goal'].shape[0], 13 | goal=obs['desired_goal'].shape[0], 14 | act=env.action_space.shape[0], 15 | act_rand_sampler=env.action_space.sample, 16 | max_timesteps=env.max_episode_steps, 17 | max_action=env.action_space.high[0], 18 | ) 19 | if cfg.skill_chaining: 20 | env_params.update(AttrDict( 21 | act_sc=obs['achieved_goal'].shape[0] - env.len_cond, # withoug contact condition 22 | max_action_sc=env.max_action_range, 23 | adaptor_sc=env.goal_adapator, 24 | subtask_order=env.subtask_order, 25 | num_subtasks=len(env.subtask_order), 26 | subtask_steps=env.subtask_steps, 27 | subtasks=env.subtasks, 28 | next_subtasks=env.next_subtasks, 29 | prev_subtasks=env.prev_subtasks, 30 | middle_subtasks=env.next_subtasks.keys(), 31 | last_subtask=env.last_subtask, 32 | reward_funcs=env.get_reward_functions(), 33 | len_cond=env.len_cond 34 | )) 35 | return env_params 36 | 37 | 38 | class ReplayCache: 39 | def __init__(self, T): 40 | self.T = T 41 | self.reset() 42 | 43 | def reset(self): 44 | self.t = 0 45 | self.obs, self.ag, self.g, self.actions, self.dones = [], [], [], [], [] 46 | 47 | def store_transition(self, obs, action, done): 48 | self.obs.append(obs['observation']) 49 | self.ag.append(obs['achieved_goal']) 50 | self.g.append(obs['desired_goal']) 51 | self.actions.append(action) 52 | self.dones.append(done) 53 | 54 | def store_obs(self, obs): 55 | self.obs.append(obs['observation']) 56 | self.ag.append(obs['achieved_goal']) 57 | 58 | def pop(self): 59 | assert len(self.obs) == self.T + 1 and len(self.actions) == self.T 60 | obs = np.expand_dims(np.array(self.obs.copy()),axis=0) 61 | ag = np.expand_dims(np.array(self.ag.copy()), axis=0) 62 | #print(self.ag) 63 | g = np.expand_dims(np.array(self.g.copy()), axis=0) 64 | actions = np.expand_dims(np.array(self.actions.copy()), axis=0) 65 | dones = np.expand_dims(np.array(self.dones.copy()), axis=1) 66 | dones = np.expand_dims(dones, axis=0) 67 | 68 | self.reset() 69 | episode = AttrDict(obs=obs, ag=ag, g=g, actions=actions, dones=dones) 70 | return episode 71 | 72 | 73 | class ReplayCacheGT(ReplayCache): 74 | def reset(self): 75 | self.t = 0 76 | self.obs, self.ag, self.g, self.actions, self.dones, self.gt_g = [], [], [], [], [], [] 77 | 78 | def store_transition(self, obs, action, done, gt_goal): 79 | 
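# same as ReplayCache.store_transition, but additionally records the ground-truth goal for each step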
self.obs.append(obs['observation']) 80 | self.ag.append(obs['achieved_goal']) 81 | self.g.append(obs['desired_goal']) 82 | self.actions.append(action) 83 | self.dones.append(done) 84 | self.gt_g.append(gt_goal) 85 | 86 | def pop(self): 87 | assert len(self.obs) == self.T + 1 and len(self.actions) == self.T 88 | obs = np.expand_dims(np.array(self.obs.copy()),axis=0) 89 | ag = np.expand_dims(np.array(self.ag.copy()), axis=0) 90 | #print(self.ag) 91 | g = np.expand_dims(np.array(self.g.copy()), axis=0) 92 | actions = np.expand_dims(np.array(self.actions.copy()), axis=0) 93 | dones = np.expand_dims(np.array(self.dones.copy()), axis=1) 94 | dones = np.expand_dims(dones, axis=0) 95 | gt_g = np.expand_dims(np.array(self.gt_g.copy()), axis=0) 96 | 97 | self.reset() 98 | episode = AttrDict(obs=obs, ag=ag, g=g, actions=actions, dones=dones, gt_g=gt_g) 99 | return episode 100 | 101 | 102 | def init_demo_buffer(cfg, buffer, agent, subtask=None, update_normalizer=True): 103 | '''Load demonstrations into buffer and initilaize normalizer''' 104 | demo_path = os.path.join(os.getcwd(),'SurRoL/surrol/data/demo') 105 | file_name = "data_" 106 | file_name += cfg.task 107 | file_name += "_" + 'random' 108 | if subtask is None: 109 | file_name += "_" + str(cfg.num_demo) + '_primitive_new' + cfg.subtask 110 | else: 111 | file_name += "_" + str(cfg.num_demo) + '_primitive_new' + subtask 112 | file_name += ".npz" 113 | 114 | demo_path = os.path.join(demo_path, file_name) 115 | demo = np.load(demo_path, allow_pickle=True) 116 | demo_obs, demo_acs, demo_gt = demo['observations'], demo['actions'], demo['gt_actions'] 117 | 118 | episode_cache = ReplayCacheGT(buffer.T) 119 | for epsd in range(cfg.num_demo): 120 | episode_cache.store_obs(demo_obs[epsd][0]) 121 | for i in range(buffer.T): 122 | episode_cache.store_transition( 123 | obs=demo_obs[epsd][i+1], 124 | action=demo_acs[epsd][i], 125 | done=i==(buffer.T-1), 126 | gt_goal=demo_gt[epsd][i] 127 | ) 128 | episode = episode_cache.pop() 129 | buffer.store_episode(episode) 130 | if update_normalizer: 131 | agent.update_normalizer(episode) 132 | 133 | 134 | def init_sc_buffer(cfg, buffer, agent, env_params): 135 | '''Load demonstrations into buffer and initilaize normalizer''' 136 | for subtask in env_params.subtasks: 137 | demo_path = os.path.join(os.getcwd(),'SurRoL/surrol/data/demo') 138 | file_name = "data_" 139 | file_name += cfg.task 140 | file_name += "_" + 'random' 141 | file_name += "_" + str(cfg.num_demo) + '_primitive_new' + subtask 142 | file_name += ".npz" 143 | 144 | demo_path = os.path.join(demo_path, file_name) 145 | demo = np.load(demo_path, allow_pickle=True) 146 | demo_obs, demo_acs, demo_gt = demo['observations'], demo['actions'], demo['gt_actions'] 147 | 148 | for epsd in range(cfg.num_demo): 149 | obs = demo_obs[epsd][0]['observation'] 150 | next_obs = demo_obs[epsd][-1]['observation'] 151 | action = demo_obs[epsd][0]['desired_goal'][:-env_params.len_cond] 152 | # reward = sum([env_params.reward_funcs[subtask](demo_obs[epsd][i+1]['achieved_goal'], demo_obs[epsd][i+1]['desired_goal']) \ 153 | # for i in range(len(demo_acs[epsd]))]) 154 | reward = env_params.reward_funcs[subtask](demo_obs[epsd][-1]['achieved_goal'], demo_obs[epsd][-1]['desired_goal']) 155 | #print(subtask, epsd, reward) 156 | done = subtask not in env_params.next_subtasks.keys() 157 | reward = done * reward 158 | gt_action = demo_gt[epsd][-1] 159 | buffer[subtask].add(obs, action, reward, next_obs, done, gt_action) 160 | if agent.sc_agent.normalize: 161 | # TODO: hide normalized 162 | 
agent.sc_agent.o_norm[subtask].update(obs) 163 | 164 | 165 | class RolloutStorage: 166 | """Can hold multiple rollouts, can compute statistics over these rollouts.""" 167 | def __init__(self): 168 | self.rollouts = [] 169 | 170 | def append(self, rollout): 171 | """Adds rollout to storage.""" 172 | self.rollouts.append(rollout) 173 | 174 | def rollout_stats(self): 175 | """Returns AttrDict of average statistics over the rollouts.""" 176 | assert self.rollouts # rollout storage should not be empty 177 | stats = RecursiveAverageMeter() 178 | for rollout in self.rollouts: 179 | stats.update(AttrDict( 180 | avg_reward=np.stack(rollout.reward).sum(), 181 | avg_success_rate=rollout.success[-1], 182 | )) 183 | return stats.avg 184 | 185 | def reset(self): 186 | del self.rollouts 187 | self.rollouts = [] 188 | 189 | def get(self): 190 | return self.rollouts 191 | 192 | def __contains__(self, key): 193 | return self.rollouts and key in self.rollouts[0] -------------------------------------------------------------------------------- /viskill/agents/sc_ddpg.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from ..components.normalizer import Normalizer 8 | from ..modules.critics import SkillChainingCritic 9 | from ..modules.policies import SkillChainingActor 10 | from ..utils.general_utils import AttrDict 11 | from .base import BaseAgent 12 | 13 | 14 | class SkillChainingDDPG(BaseAgent): 15 | def __init__( 16 | self, 17 | env_params, 18 | agent_cfg, 19 | sl_agent, 20 | ): 21 | super().__init__() 22 | 23 | self.sl_agent = sl_agent 24 | 25 | self.discount = agent_cfg.discount 26 | self.reward_scale = agent_cfg.reward_scale 27 | self.update_epoch = agent_cfg.update_epoch 28 | self.device = agent_cfg.device 29 | self.env_params = env_params 30 | 31 | self.random_eps = agent_cfg.random_eps 32 | self.noise_eps = agent_cfg.noise_eps 33 | self.soft_target_tau = agent_cfg.soft_target_tau 34 | 35 | self.normalize = agent_cfg.normalize 36 | self.clip_obs = agent_cfg.clip_obs 37 | self.norm_clip = agent_cfg.norm_clip 38 | self.norm_eps = agent_cfg.norm_eps 39 | self.intr_reward = agent_cfg.intr_reward 40 | self.raw_env_reward = agent_cfg.raw_env_reward 41 | 42 | self.dima = env_params['act_sc'] # 43 | self.dimo = env_params['obs'] 44 | self.max_action = env_params['max_action_sc'] 45 | self.goal_adapator = env_params['adaptor_sc'] 46 | 47 | # TODO: normarlizer 48 | self.o_norm = {subtask: Normalizer( 49 | size=self.dimo, 50 | default_clip_range=self.norm_clip, 51 | eps=agent_cfg.norm_eps) for subtask in env_params['middle_subtasks']} 52 | 53 | # build policy 54 | self.actor = SkillChainingActor( 55 | in_dim=self.dimo, 56 | out_dim=self.dima, 57 | hidden_dim=agent_cfg.hidden_dim, 58 | max_action=self.max_action, 59 | middle_subtasks=env_params['middle_subtasks'], 60 | last_subtask=env_params['last_subtask'] 61 | ).to(agent_cfg.device) 62 | self.actor_target = copy.deepcopy(self.actor).to(agent_cfg.device) 63 | 64 | self.critic = SkillChainingCritic( 65 | in_dim=self.dimo+self.dima, 66 | hidden_dim=agent_cfg.hidden_dim, 67 | middle_subtasks=env_params['middle_subtasks'], 68 | last_subtask=env_params['last_subtask'] 69 | ).to(agent_cfg.device) 70 | self.critic_target = copy.deepcopy(self.critic).to(agent_cfg.device) 71 | 72 | # optimizer 73 | self.actor_optimizer = {subtask: torch.optim.Adam( 74 | self.actor.parameters(), lr=agent_cfg.actor_lr) for subtask in 
env_params['middle_subtasks']} 75 | self.critic_optimizer = {subtask: torch.optim.Adam( 76 | self.critic.parameters(), lr=agent_cfg.critic_lr) for subtask in env_params['middle_subtasks']} 77 | 78 | def init(self, task, sl_agent): 79 | '''Initialize the actor, critic and normalizers of last subtask.''' 80 | self.actor.init_last_subtask_actor(task, sl_agent[task].actor) 81 | self.actor_target.init_last_subtask_actor(task, sl_agent[task].actor_target) 82 | self.critic.init_last_subtask_q(task, sl_agent[task].critic) 83 | self.critic_target.init_last_subtask_q(task, sl_agent[task].critic_target) 84 | self.sl_normarlizer = {subtask: sl_agent[subtask]._preproc_inputs 85 | for subtask in self.env_params['subtasks']} 86 | 87 | def get_samples(self, replay_buffer, subtask): 88 | next_subtask = self.env_params['next_subtasks'][subtask] 89 | 90 | obs, action, reward, next_obs, done, gt_action, idxs = replay_buffer[subtask].sample(return_idxs=True) 91 | sl_norm_next_obs = self.sl_normarlizer[next_subtask](next_obs, gt_action, dim=1) # only for terminal subtask 92 | obs = self._preproc_obs(obs, subtask) 93 | next_obs = self._preproc_obs(next_obs, next_subtask) 94 | action = self.to_torch(action) 95 | reward = self.to_torch(reward) 96 | done = self.to_torch(done) 97 | gt_action = self.to_torch(gt_action) 98 | 99 | if next_subtask == self.env_params['last_subtask'] and self.raw_env_reward: 100 | assert len(replay_buffer[subtask]) == len(replay_buffer[next_subtask]) 101 | _, _, raw_reward, _, _, _ = replay_buffer[next_subtask].sample(idxs=idxs) 102 | return obs, action, reward, next_obs, done, sl_norm_next_obs, self.to_torch(raw_reward) 103 | 104 | return obs, action, reward, next_obs, done, sl_norm_next_obs, None 105 | 106 | def get_action(self, state, subtask, noise=False): 107 | # random action at initial stage 108 | with torch.no_grad(): 109 | input_tensor = self._preproc_obs(state, subtask) 110 | action = self.actor[subtask](input_tensor).cpu().data.numpy().flatten() 111 | # Gaussian noise 112 | if noise: 113 | action = (action + self.max_action * self.noise_eps * np.random.randn(action.shape[0])).clip( 114 | -self.max_action, self.max_action) 115 | return action 116 | 117 | def update_critic(self, obs, action, reward, next_obs): 118 | metrics = AttrDict() 119 | 120 | with torch.no_grad(): 121 | action_out = self.actor_target(next_obs) 122 | target_V = self.critic_target(next_obs, action_out) 123 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach() 124 | 125 | clip_return = 1 / (1 - self.discount) 126 | target_Q = torch.clamp(target_Q, -clip_return, 0).detach() 127 | 128 | Q = self.critic(obs, action) 129 | critic_loss = F.mse_loss(Q, target_Q) 130 | 131 | # Optimize critic loss 132 | self.critic_optimizer.zero_grad() 133 | critic_loss.backward() 134 | self.critic_optimizer.step() 135 | 136 | metrics = AttrDict( 137 | critic_target_q=target_Q.mean().item(), 138 | critic_q=Q.mean().item(), 139 | critic_loss=critic_loss.item() 140 | ) 141 | return metrics 142 | 143 | def update_actor(self, obs, action, is_demo=False): 144 | action_out = self.actor(obs) 145 | Q_out = self.critic(obs, action_out) 146 | actor_loss = -(Q_out).mean() 147 | 148 | # Optimize actor loss 149 | self.actor_optimizer.zero_grad() 150 | actor_loss.backward() 151 | self.actor_optimizer.step() 152 | 153 | metrics = AttrDict( 154 | actor_loss=actor_loss.item() 155 | ) 156 | return metrics 157 | 158 | def update(self, replay_buffer): 159 | metrics = AttrDict() 160 | 161 | for i in range(self.update_epoch): 162 
| # Sample from replay buffer 163 | obs, action, reward, next_obs, done = self.get_samples(replay_buffer) 164 | # Update critic and actor 165 | metrics.update(self.update_critic(obs, action, reward, next_obs)) 166 | metrics.update(self.update_actor(obs, action)) 167 | 168 | # Update target critic and actor 169 | self.update_target() 170 | return metrics 171 | 172 | def _preproc_obs(self, o, subtask): 173 | o = np.clip(o, -self.clip_obs, self.clip_obs) 174 | if self.normalize and subtask != self.env_params.last_subtask: 175 | o = self.o_norm[subtask].normalize(o) 176 | inputs = torch.tensor(o, dtype=torch.float32).to(self.device) 177 | return inputs 178 | 179 | def update_target(self): 180 | # Update the frozen target models 181 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 182 | target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data) 183 | 184 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 185 | target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data) 186 | 187 | def update_normalizer(self, rollouts, subtask): 188 | for transition in rollouts: 189 | obs, _, _, _, _, _ = transition 190 | # update 191 | self.o_norm[subtask].update(obs) 192 | # recompute the stats 193 | self.o_norm[subtask].recompute_stats() -------------------------------------------------------------------------------- /viskill/utils/general_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import itertools 3 | import random 4 | import time 5 | from functools import reduce 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def set_seed_everywhere(seed): 12 | torch.manual_seed(seed) 13 | if torch.cuda.is_available(): 14 | torch.cuda.manual_seed_all(seed) 15 | np.random.seed(seed) 16 | random.seed(seed) 17 | 18 | 19 | class Until: 20 | def __init__(self, until, action_repeat=1): 21 | self._until = until 22 | self._action_repeat = action_repeat 23 | 24 | def __call__(self, step): 25 | if self._until is None: 26 | return True 27 | until = self._until // self._action_repeat 28 | return step < until 29 | 30 | 31 | class Every: 32 | def __init__(self, every, action_repeat=1): 33 | self._every = every 34 | self._action_repeat = action_repeat 35 | 36 | def __call__(self, step): 37 | if self._every is None: 38 | return False 39 | every = self._every // self._action_repeat 40 | if step % every == 0: 41 | return True 42 | return False 43 | 44 | 45 | class Timer: 46 | def __init__(self): 47 | self._start_time = time.time() 48 | self._last_time = time.time() 49 | 50 | def reset(self): 51 | elapsed_time = time.time() - self._last_time 52 | self._last_time = time.time() 53 | total_time = time.time() - self._start_time 54 | return elapsed_time, total_time 55 | 56 | def total_time(self): 57 | return time.time() - self._start_time 58 | 59 | 60 | class AverageMeter(object): 61 | """Computes and stores the average and current value""" 62 | 63 | def __init__(self, digits=None): 64 | """ 65 | :param digits: number of digits returned for average value 66 | """ 67 | self._digits = digits 68 | self.reset() 69 | 70 | def reset(self): 71 | self.val = 0 72 | self._avg = 0 73 | self.sum = 0 74 | self.count = 0 75 | 76 | def update(self, val, n=1): 77 | self.val = val 78 | self.sum += val * n 79 | self.count += n 80 | self._avg = self.sum / max(1, self.count) 81 | 82 | @property 83 | def 
avg(self): 84 | if self._digits is not None: 85 | return np.round(self._avg, self._digits) 86 | else: 87 | return self._avg 88 | 89 | 90 | class RecursiveAverageMeter(object): 91 | """Computes and stores the average and current value""" 92 | def __init__(self): 93 | self.reset() 94 | 95 | def reset(self): 96 | self.val = None 97 | self.avg = None 98 | self.sum = None 99 | self.count = 0 100 | 101 | def update(self, val): 102 | self.val = val 103 | if self.sum is None: 104 | self.sum = val 105 | else: 106 | self.sum = map_recursive_list(lambda x, y: x + y, [self.sum, val]) 107 | self.count += 1 108 | self.avg = map_recursive(lambda x: x / self.count, self.sum) 109 | 110 | 111 | 112 | class AttrDict(dict): 113 | __setattr__ = dict.__setitem__ 114 | 115 | def __getattr__(self, attr): 116 | # Take care that getattr() raises AttributeError, not KeyError. 117 | # Required e.g. for hasattr(), deepcopy and OrderedDict. 118 | try: 119 | return self.__getitem__(attr) 120 | except KeyError: 121 | raise AttributeError("Attribute %r not found" % attr) 122 | 123 | def __getstate__(self): 124 | return self 125 | 126 | def __setstate__(self, d): 127 | self = d 128 | 129 | 130 | def map_dict(fn, d): 131 | """takes a dictionary and applies the function to every element""" 132 | return type(d)(map(lambda kv: (kv[0], fn(kv[1])), d.items())) 133 | 134 | 135 | def map_recursive(fn, tensors): 136 | return make_recursive(fn)(tensors) 137 | 138 | 139 | def map_recursive_list(fn, tensors): 140 | return make_recursive_list(fn)(tensors) 141 | 142 | 143 | def make_recursive(fn, *argv, **kwargs): 144 | """ Takes a fn and returns a function that can apply fn on tensor structure 145 | which can be a single tensor, tuple or a list. """ 146 | 147 | def recursive_map(tensors): 148 | if tensors is None: 149 | return tensors 150 | elif isinstance(tensors, list) or isinstance(tensors, tuple): 151 | return type(tensors)(map(recursive_map, tensors)) 152 | elif isinstance(tensors, dict): 153 | return type(tensors)(map_dict(recursive_map, tensors)) 154 | elif isinstance(tensors, torch.Tensor) or isinstance(tensors, np.ndarray): 155 | return fn(tensors, *argv, **kwargs) 156 | else: 157 | try: 158 | return fn(tensors, *argv, **kwargs) 159 | except Exception as e: 160 | print("The following error was raised when recursively applying a function:") 161 | print(e) 162 | raise ValueError("Type {} not supported for recursive map".format(type(tensors))) 163 | 164 | return recursive_map 165 | 166 | 167 | def make_recursive_list(fn): 168 | """ Takes a fn and returns a function that can apply fn across tuples of tensor structures, 169 | each of which can be a single tensor, tuple or a list. 
""" 170 | 171 | def recursive_map(tensors): 172 | if tensors is None: 173 | return tensors 174 | elif isinstance(tensors[0], list) or isinstance(tensors[0], tuple): 175 | return type(tensors[0])(map(recursive_map, zip(*tensors))) 176 | elif isinstance(tensors[0], dict): 177 | return map_dict(recursive_map, listdict2dictlist(tensors)) 178 | elif isinstance(tensors[0], torch.Tensor): 179 | return fn(*tensors) 180 | else: 181 | try: 182 | return fn(*tensors) 183 | except Exception as e: 184 | print("The following error was raised when recursively applying a function:") 185 | print(e) 186 | raise ValueError("Type {} not supported for recursive map".format(type(tensors))) 187 | 188 | return recursive_map 189 | 190 | 191 | def listdict2dictlist(LD): 192 | """ Converts a list of dicts to a dict of lists """ 193 | 194 | # Take intersection of keys 195 | keys = reduce(lambda x,y: x & y, (map(lambda d: d.keys(), LD))) 196 | return AttrDict({k: [dic[k] for dic in LD] for k in keys}) 197 | 198 | 199 | # def joinListDictList(LDL): 200 | # """Joins a list of dictionaries that contain lists.""" 201 | # DLL = listdict2dictlist(LDL) 202 | # return type(LDL[0])({k: list(itertools.chain.from_iterable(DLL[k])) for k in DLL}) 203 | 204 | 205 | def joinListDictList(LDL): 206 | """Joins a list of dictionaries that contain lists.""" 207 | DLL = listdict2dictlist(LDL) 208 | return type(LDL[0])({k: np.array(list(itertools.chain.from_iterable(DLL[k]))) for k in DLL}) 209 | 210 | 211 | def joinListDict(LD): 212 | """Joins a list of dictionaries that contain lists.""" 213 | DL = listdict2dictlist(LD) 214 | return type(LD[0])({k: np.array(DL[k]) for k in DL}) 215 | 216 | 217 | def joinListList(LL): 218 | """Joins a list of dictionaries that contain lists.""" 219 | return type(LL[0])(itertools.chain.from_iterable(LL)) 220 | 221 | 222 | def obj2np(obj): 223 | """Wraps an object into an np.array.""" 224 | ar = np.zeros((1,), dtype=np.object_) 225 | ar[0] = obj 226 | return ar 227 | 228 | 229 | def flatten_dict(d, parent_key='', sep='_'): 230 | items = [] 231 | for k, v in d.items(): 232 | new_key = parent_key + sep + k if parent_key else k 233 | if isinstance(v, collections.MutableMapping): 234 | items.extend(flatten_dict(v, new_key, sep=sep).items()) 235 | else: 236 | items.append((new_key, v)) 237 | return dict(items) 238 | 239 | 240 | def prefix_dict(d, prefix): 241 | """Adds the prefix to all keys of dict d.""" 242 | return type(d)({prefix+k: v for k, v in d.items()}) 243 | 244 | 245 | def np2obj(np_array): 246 | if isinstance(np_array, list) or np_array.size > 1: 247 | return [e[0] for e in np_array] 248 | else: 249 | return np_array[0] 250 | 251 | 252 | def str2int(str): 253 | try: 254 | return int(str) 255 | except ValueError: 256 | return None 257 | 258 | 259 | def get_last_argmax(array): 260 | b = array[::-1] 261 | idx = len(b) - np.argmax(array) - 1 262 | return idx 263 | -------------------------------------------------------------------------------- /viskill/agents/sc_sac.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from ..components.normalizer import Normalizer 8 | from ..modules.critics import SkillChainingDoubleCritic 9 | from ..modules.policies import SkillChainingDiagGaussianActor 10 | from ..utils.general_utils import AttrDict, prefix_dict 11 | from .sc_ddpg import SkillChainingDDPG 12 | 13 | 14 | class SkillChainingSAC(SkillChainingDDPG): 15 | def __init__( 16 | 
self, 17 | env_params, 18 | agent_cfg, 19 | sl_agent 20 | ): 21 | super(SkillChainingDDPG, self).__init__() 22 | 23 | self.sl_agent = sl_agent 24 | 25 | self.discount = agent_cfg.discount 26 | self.reward_scale = agent_cfg.reward_scale 27 | self.update_epoch = agent_cfg.update_epoch 28 | self.device = agent_cfg.device 29 | self.raw_env_reward = agent_cfg.raw_env_reward 30 | self.env_params = env_params 31 | 32 | # SAC parameters 33 | self.learnable_temperature = agent_cfg.learnable_temperature 34 | self.soft_target_tau = agent_cfg.soft_target_tau 35 | 36 | self.normalize = agent_cfg.normalize 37 | self.clip_obs = agent_cfg.clip_obs 38 | self.norm_clip = agent_cfg.norm_clip 39 | self.norm_eps = agent_cfg.norm_eps 40 | 41 | self.dima = env_params['act_sc'] 42 | self.dimo = env_params['obs'] 43 | self.max_action = env_params['max_action_sc'] 44 | self.goal_adapator = env_params['adaptor_sc'] 45 | 46 | # normarlizer 47 | self.o_norm = Normalizer( 48 | size=self.dimo, 49 | default_clip_range=self.norm_clip, 50 | eps=agent_cfg.norm_eps 51 | ) 52 | 53 | # build policy 54 | self.actor = SkillChainingDiagGaussianActor( 55 | in_dim=self.dimo, 56 | out_dim=self.dima, 57 | hidden_dim=agent_cfg.hidden_dim, 58 | max_action=self.max_action, 59 | middle_subtasks=env_params['middle_subtasks'], 60 | last_subtask=env_params['last_subtask'] 61 | ).to(agent_cfg.device) 62 | 63 | self.critic = SkillChainingDoubleCritic( 64 | in_dim=self.dimo+self.dima, 65 | hidden_dim=agent_cfg.hidden_dim, 66 | middle_subtasks=env_params['middle_subtasks'], 67 | last_subtask=env_params['last_subtask'] 68 | ).to(agent_cfg.device) 69 | self.critic_target = copy.deepcopy(self.critic).to(agent_cfg.device) 70 | 71 | # entropy term 72 | if self.learnable_temperature: 73 | self.target_entropy = -self.dima 74 | self.log_alpha = {subtask : torch.tensor( 75 | np.log(agent_cfg.init_temperature)).to(self.device) for subtask in env_params['middle_subtasks']} 76 | for subtask in env_params['middle_subtasks']: 77 | self.log_alpha[subtask].requires_grad = True 78 | else: 79 | self.log_alpha = {subtask : torch.tensor( 80 | np.log(agent_cfg.init_temperature)).to(self.device) for subtask in env_params['middle_subtasks']} 81 | 82 | # optimizer 83 | self.actor_optimizer = {subtask: torch.optim.Adam( 84 | self.actor.parameters(), lr=agent_cfg.actor_lr) for subtask in env_params['middle_subtasks']} 85 | self.critic_optimizer = {subtask: torch.optim.Adam( 86 | self.critic.parameters(), lr=agent_cfg.critic_lr) for subtask in env_params['middle_subtasks']} 87 | self.temp_optimizer = {subtask: torch.optim.Adam( 88 | [self.log_alpha[subtask]], lr=agent_cfg.temp_lr) for subtask in env_params['middle_subtasks']} 89 | 90 | def init(self, task, sl_agent): 91 | '''Initialize the actor, critic and normalizers of last subtask.''' 92 | self.actor.init_last_subtask_actor(task, sl_agent[task].actor) 93 | self.critic.init_last_subtask_q(task, sl_agent[task].critic) 94 | self.critic_target.init_last_subtask_q(task, sl_agent[task].critic_target) 95 | self.sl_normarlizer = {subtask: sl_agent[subtask]._preproc_inputs 96 | for subtask in self.env_params['subtasks']} 97 | 98 | def alpha(self, subtask): 99 | return self.log_alpha[subtask].exp() 100 | 101 | def get_action(self, state, subtask, noise=False): 102 | with torch.no_grad(): 103 | #state = {key: self.to_torch(state[key].reshape([1, -1])) for key in state.keys()} # unsqueeze 104 | input_tensor = self._preproc_obs(state, subtask) 105 | dist = self.actor(input_tensor, subtask) 106 | if noise: 107 | action = 
dist.sample() 108 | else: 109 | action = dist.mean 110 | 111 | return action.cpu().data.numpy().flatten() * self.max_action 112 | 113 | def get_q_value(self, state, action, subtask): 114 | with torch.no_grad(): 115 | input_tensor = self._preproc_obs(state, subtask) 116 | action = self.to_torch(action) 117 | q_value = self.critic.q(input_tensor, action, subtask) 118 | 119 | return q_value 120 | 121 | def update_critic(self, obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask): 122 | assert subtask != self.env_params['last_subtask'] 123 | 124 | with torch.no_grad(): 125 | next_subtask = self.env_params['next_subtasks'][subtask] 126 | if next_subtask != self.env_params['last_subtask']: 127 | dist = self.actor(next_obs, next_subtask) 128 | action_out = dist.rsample() 129 | log_prob = dist.log_prob(action_out).sum(-1, keepdim=True) 130 | target_V = self.critic_target.q(next_obs, action_out, next_subtask) 131 | target_V = target_V - self.alpha(next_subtask).detach() * log_prob 132 | else: 133 | action_out = self.actor[next_subtask](sl_norm_next_obs) 134 | target_V = self.critic[next_subtask](sl_norm_next_obs, action_out).squeeze(0) 135 | 136 | if self.raw_env_reward and next_subtask == self.env_params['last_subtask']: 137 | target_Q = self.reward_scale * reward + (self.discount * raw_reward) 138 | else: 139 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach() 140 | 141 | current_Q1, current_Q2 = self.critic(obs, action, subtask) 142 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) 143 | 144 | # optimize critic loss 145 | self.critic_optimizer[subtask].zero_grad() 146 | critic_loss.backward() 147 | self.critic_optimizer[subtask].step() 148 | 149 | metrics = AttrDict( 150 | critic_target_q=target_Q.mean().item(), 151 | critic_q=current_Q1.mean().item(), 152 | critic_loss=critic_loss.item() 153 | ) 154 | return prefix_dict(metrics, subtask + '_') 155 | 156 | def update_actor_and_alpha(self, obs, subtask): 157 | # compute log probability 158 | dist = self.actor(obs, subtask) 159 | action = dist.rsample() 160 | log_probs = dist.log_prob(action).sum(-1, keepdim=True) 161 | 162 | # compute state value 163 | actor_Q = self.critic.q(obs, action, subtask) 164 | actor_loss = (self.alpha(subtask).detach() * log_probs - actor_Q).mean() 165 | 166 | # optimize actor loss 167 | self.actor_optimizer[subtask].zero_grad() 168 | actor_loss.backward() 169 | self.actor_optimizer[subtask].step() 170 | 171 | metrics = AttrDict( 172 | log_probs=log_probs.mean(), 173 | actor_loss=actor_loss.item() 174 | ) 175 | 176 | # compute temp loss 177 | if self.learnable_temperature: 178 | temp_loss = (self.alpha(subtask) * (-log_probs - self.target_entropy).detach()).mean() 179 | self.temp_optimizer[subtask].zero_grad() 180 | temp_loss.backward() 181 | self.temp_optimizer[subtask].step() 182 | 183 | metrics.update(AttrDict( 184 | temp_loss=temp_loss.item(), 185 | temp=self.alpha(subtask) 186 | )) 187 | return prefix_dict(metrics, subtask + '_') 188 | 189 | def update(self, replay_buffer, demo_buffer): 190 | metrics = AttrDict() 191 | 192 | for i in range(self.update_epoch): 193 | for subtask in self.env_params['middle_subtasks']: 194 | # sample from replay buffer 195 | obs, action, reward, next_obs, done, sl_norm_next_obs, raw_reward = self.get_samples(replay_buffer, subtask) 196 | action = action / self.max_action 197 | 198 | # update critic and actor 199 | metrics.update(self.update_critic( obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, 
subtask)) 200 | metrics.update(self.update_actor_and_alpha(obs, subtask)) 201 | 202 | # update target critic and actor 203 | self.update_target() 204 | 205 | return metrics 206 | 207 | def update_target(self): 208 | # update the frozen target models 209 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 210 | target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data) 211 | -------------------------------------------------------------------------------- /viskill/trainers/sl_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from ..agents.factory import make_sl_agent 6 | from ..components.checkpointer import CheckpointHandler, save_cmd 7 | from ..components.envrionment import make_env 8 | from ..components.logger import Logger, WandBLogger, logger 9 | from ..modules.replay_buffer import HerReplayBufferWithGT, get_buffer_sampler 10 | from ..modules.sampler import Sampler 11 | from ..utils.general_utils import (AverageMeter, Every, Timer, Until, 12 | set_seed_everywhere) 13 | from ..utils.mpi import (mpi_gather_experience_episode, 14 | mpi_gather_experience_rollots, mpi_sum, 15 | update_mpi_config) 16 | from ..utils.rl_utils import RolloutStorage, get_env_params, init_demo_buffer 17 | from .base_trainer import BaseTrainer 18 | 19 | 20 | class SkillLearningTrainer(BaseTrainer): 21 | def _setup(self): 22 | self._setup_env() # Environment 23 | self._setup_buffer() # Relay buffer 24 | self._setup_agent() # Agent 25 | self._setup_sampler() # Sampler 26 | self._setup_logger() # Logger 27 | self._setup_misc() # MISC 28 | 29 | if self.is_chef: 30 | self.termlog.info('Setup done') 31 | 32 | def _setup_env(self): 33 | self.train_env = make_env(self.cfg) 34 | self.eval_env = make_env(self.cfg) 35 | self.env_params = get_env_params(self.train_env, self.cfg) 36 | 37 | def _setup_buffer(self): 38 | self.buffer_sampler = get_buffer_sampler(self.train_env, self.cfg.agent.sampler) 39 | self.buffer = HerReplayBufferWithGT(buffer_size=self.cfg.replay_buffer_capacity, env_params=self.env_params, 40 | batch_size=self.cfg.batch_size, sampler=self.buffer_sampler) 41 | self.demo_buffer = HerReplayBufferWithGT(buffer_size=self.cfg.replay_buffer_capacity, env_params=self.env_params, 42 | batch_size=self.cfg.batch_size, sampler=self.buffer_sampler) 43 | 44 | def _setup_agent(self): 45 | self.agent = make_sl_agent(self.env_params, self.buffer_sampler, self.cfg.agent) 46 | 47 | def _setup_sampler(self): 48 | self.train_sampler = Sampler(self.train_env, self.agent, self.env_params['max_timesteps']) 49 | self.eval_sampler = Sampler(self.eval_env, self.agent, self.env_params['max_timesteps']) 50 | 51 | def _setup_logger(self): 52 | update_mpi_config(self.cfg) 53 | if self.is_chef: 54 | exp_name = f"SL_{self.cfg.task}_{self.cfg.subtask}_{self.cfg.agent.name}_seed{self.cfg.seed}" 55 | if self.cfg.postfix is not None: 56 | exp_name = exp_name + '_' + self.cfg.postfix 57 | self.wb = WandBLogger(exp_name=exp_name, project_name=self.cfg.project_name, entity=self.cfg.entity_name, \ 58 | path=self.work_dir, conf=self.cfg) 59 | self.logger = Logger(self.work_dir) 60 | self.termlog = logger 61 | save_cmd(self.work_dir) 62 | else: 63 | self.wb, self.logger, self.termlog = None, None, None 64 | 65 | def _setup_misc(self): 66 | init_demo_buffer(self.cfg, self.demo_buffer, self.agent) 67 | 68 | if self.is_chef: 69 | self.model_dir = self.work_dir / 'model' 70 | 
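# create the checkpoint directory and clear stale checkpoints left over from previous runs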
self.model_dir.mkdir(exist_ok=True) 71 | for file in os.listdir(self.model_dir): 72 | os.remove(self.model_dir / file) 73 | 74 | self.device = torch.device(self.cfg.device) 75 | self.timer = Timer() 76 | self._global_step = 0 77 | self._global_episode = 0 78 | set_seed_everywhere(self.cfg.seed) 79 | 80 | def train(self): 81 | n_train_episodes = int(self.cfg.n_train_steps / self.env_params['max_timesteps']) 82 | n_eval_episodes = int(n_train_episodes / self.cfg.n_eval) * self.cfg.mpi.num_workers 83 | n_save_episodes = int(n_train_episodes / self.cfg.n_save) * self.cfg.mpi.num_workers 84 | n_log_episodes = int(n_train_episodes / self.cfg.n_log) * self.cfg.mpi.num_workers 85 | 86 | assert n_save_episodes > n_eval_episodes 87 | if n_save_episodes % n_eval_episodes != 0: 88 | n_save_episodes = int(n_save_episodes / n_eval_episodes) * n_eval_episodes 89 | 90 | train_until_episode = Until(n_train_episodes) 91 | save_every_episodes = Every(n_save_episodes) 92 | eval_every_episodes = Every(n_eval_episodes) 93 | log_every_episodes = Every(n_log_episodes) 94 | seed_until_steps = Until(self.cfg.n_seed_steps) 95 | 96 | if self.is_chef: 97 | self.termlog.info('Starting training') 98 | while train_until_episode(self.global_episode): 99 | self._train_episode(log_every_episodes, seed_until_steps) 100 | 101 | if eval_every_episodes(self.global_episode): 102 | score = self.eval() 103 | 104 | if not self.cfg.dont_save and save_every_episodes(self.global_episode) and self.is_chef: 105 | filename = CheckpointHandler.get_ckpt_name(self.global_episode) 106 | # TODO(tao): expose scoring metric 107 | CheckpointHandler.save_checkpoint({ 108 | 'episode': self.global_episode, 109 | 'global_step': self.global_step, 110 | 'state_dict': self.agent.state_dict(), 111 | 'o_norm': self.agent.o_norm, 112 | 'g_norm': self.agent.g_norm, 113 | 'score': score, 114 | }, self.model_dir, filename) 115 | if self.cfg.save_buffer: 116 | self.buffer.save(self.model_dir, self.global_episode) 117 | self.termlog.info(f'Save checkpoint to {os.path.join(self.model_dir, filename)}') 118 | 119 | def _train_episode(self, log_every_episodes, seed_until_steps): 120 | # sync network parameters across workers 121 | if self.use_multiple_workers: 122 | self.agent.sync_networks() 123 | 124 | self.timer.reset() 125 | batch_time = AverageMeter() 126 | ep_start_step = self.global_step 127 | metrics = None 128 | 129 | # collect experience 130 | rollout_storage = RolloutStorage() 131 | episode, rollouts, env_steps = self.train_sampler.sample_episode(is_train=True, render=False) 132 | if self.use_multiple_workers: 133 | rollouts = mpi_gather_experience_episode(rollouts) 134 | 135 | # update status 136 | rollout_storage.append(episode) 137 | rollout_status = rollout_storage.rollout_stats() 138 | self._global_step += int(mpi_sum(env_steps)) 139 | self._global_episode += int(mpi_sum(1)) 140 | 141 | # save to buffer 142 | self.buffer.store_episode(rollouts) 143 | self.agent.update_normalizer(rollouts) 144 | 145 | # update policy 146 | if not seed_until_steps(ep_start_step): 147 | if self.is_chef: 148 | metrics = self.agent.update(self.buffer, self.demo_buffer) 149 | if self.use_multiple_workers: 150 | self.agent.sync_networks() 151 | 152 | # log results 153 | if metrics is not None and log_every_episodes(self.global_episode) and self.is_chef: 154 | elapsed_time, total_time = self.timer.reset() 155 | batch_time.update(elapsed_time) 156 | togo_train_time = batch_time.avg * (self.cfg.n_train_steps - ep_start_step) / env_steps / self.cfg.mpi.num_workers 157 | 158 | 
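# write training metrics to the local CSV/terminal logger and mirror them to WandB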
self.logger.log_metrics(metrics, self.global_step, ty='train') 159 | with self.logger.log_and_dump_ctx(self.global_step, ty='train') as log: 160 | log('fps', env_steps / elapsed_time) 161 | log('total_time', total_time) 162 | log('episode_reward', rollout_status.avg_reward) 163 | log('episode_length', env_steps) 164 | log('episode_sr', rollout_status.avg_success_rate) 165 | log('episode', self.global_episode) 166 | log('step', self.global_step) 167 | log('ETA', togo_train_time) 168 | self.wb.log_outputs(metrics, None, log_images=False, step=self.global_step, is_train=True) 169 | 170 | def eval(self): 171 | '''Eval agent.''' 172 | eval_rollout_storage = RolloutStorage() 173 | for _ in range(self.cfg.n_eval_episodes): 174 | episode, _, env_steps = self.eval_sampler.sample_episode(is_train=False, render=True) 175 | eval_rollout_storage.append(episode) 176 | rollout_status = eval_rollout_storage.rollout_stats() 177 | if self.use_multiple_workers: 178 | rollout_status = mpi_gather_experience_rollots(rollout_status) 179 | for key, value in rollout_status.items(): 180 | rollout_status[key] = value.mean() 181 | 182 | if self.is_chef: 183 | self.wb.log_outputs(rollout_status, eval_rollout_storage, log_images=True, step=self.global_step) 184 | with self.logger.log_and_dump_ctx(self.global_step, ty='eval') as log: 185 | log('episode_sr', rollout_status.avg_success_rate) 186 | log('episode_reward', rollout_status.avg_reward) 187 | log('episode_length', env_steps) 188 | log('episode', self.global_episode) 189 | log('step', self.global_step) 190 | 191 | del eval_rollout_storage 192 | return rollout_status.avg_success_rate 193 | 194 | @property 195 | def global_step(self): 196 | return self._global_step 197 | 198 | @property 199 | def global_episode(self): 200 | return self._global_episode 201 | 202 | @property 203 | def is_chef(self): 204 | return self.cfg.mpi.is_chef 205 | 206 | @property 207 | def use_multiple_workers(self): 208 | return self.cfg.mpi.num_workers > 1 -------------------------------------------------------------------------------- /viskill/components/logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | import inspect 4 | import logging 5 | import os 6 | from collections import defaultdict 7 | 8 | import colorlog 9 | import numpy as np 10 | import torch 11 | import wandb 12 | from termcolor import colored 13 | 14 | from ..utils.general_utils import flatten_dict, np2obj, prefix_dict 15 | from ..utils.vis_utils import add_captions_to_seq 16 | 17 | #----------------------Termnial Logger---------------------- 18 | formatter = colorlog.ColoredFormatter( 19 | "%(asctime_log_color)s[%(asctime)s]%(name_log_color)s[%(name)s]%(levelname_log_color)s[%(levelname)s] - %(message_log_color)s%(message)s", 20 | datefmt=None, 21 | reset=True, 22 | secondary_log_colors={ 23 | 'asctime': { 24 | 'INFO': 'cyan', 25 | 'ERROR': 'cyan' 26 | }, 27 | 'name': { 28 | 'INFO': 'blue', 29 | 'ERROR': 'blue' 30 | }, 31 | 'levelname': { 32 | 'INFO': 'green', 33 | 'ERROR': 'red' 34 | }, 35 | 'message': { 36 | 'INFO': 'white', 37 | 'ERROR': 'red' 38 | } 39 | }, 40 | style="%", 41 | ) 42 | 43 | logger = colorlog.getLogger("skill-chaining") 44 | logger.setLevel(logging.INFO) 45 | logger.propagate = False 46 | 47 | if not logger.handlers: 48 | ch = colorlog.StreamHandler() 49 | ch.setLevel(logging.INFO) 50 | ch.setFormatter(formatter) 51 | logger.addHandler(ch) 52 | 53 | 54 | #----------------------WandB Logger---------------------- 55 | class 
WandBLogger: 56 | """Logs to WandB.""" 57 | N_LOGGED_SAMPLES = 50 # how many examples should be logged in each logging step 58 | 59 | def __init__(self, exp_name, project_name, entity, path, conf, exclude=None): 60 | """ 61 | :param exp_name: full name of experiment in WandB 62 | :param project_name: name of overall project 63 | :param entity: name of head entity in WandB that hosts the project 64 | :param path: path to which WandB log-files will be written 65 | :param conf: hyperparam config that will get logged to WandB 66 | :param exclude: (optional) list of (flattened) hyperparam names that should not get logged 67 | """ 68 | if exclude is None: exclude = [] 69 | flat_config = flatten_dict(conf) 70 | filtered_config = {k: v for k, v in flat_config.items() if (k not in exclude and not inspect.isclass(v))} 71 | 72 | # clear dir 73 | # save_dir = path / 'wandb' 74 | # save_dir.mkdir(exist_ok=True) 75 | # shutil.rmtree(f"{save_dir}/") 76 | 77 | logger.info("Init wandb") 78 | wandb.init( 79 | resume='allow', 80 | project=project_name, 81 | config=filtered_config, 82 | dir=path, 83 | entity=entity, 84 | notes=conf.notes if 'notes' in conf else '', 85 | ) 86 | 87 | def log_scalar_dict(self, d, prefix='', step=None): 88 | """Logs all entries from a dict of scalars. Optionally can prefix all keys in dict before logging.""" 89 | if prefix: d = prefix_dict(d, prefix + '_') 90 | wandb.log(d) if step is None else wandb.log(d, step=step) 91 | 92 | def log_videos(self, vids, name, step=None): 93 | """Logs videos to WandB in mp4 format. 94 | Assumes list of numpy arrays as input with [time, channels, height, width].""" 95 | assert len(vids[0].shape) == 4 and vids[0].shape[1] == 3 96 | assert isinstance(vids[0], np.ndarray) 97 | if vids[0].max() <= 1.0: vids = [np.asarray(vid * 255.0, dtype=np.uint8) for vid in vids] 98 | # TODO(karl) expose the FPS as a parameter 99 | log_dict = {name: [wandb.Video(vid, fps=10, format="mp4") for vid in vids]} 100 | wandb.log(log_dict) if step is None else wandb.log(log_dict, step=step) 101 | 102 | def log_plot(self, fig, name, step=None): 103 | """Logs matplotlib graph to WandB. 
104 | fig is a matplotlib figure handle.""" 105 | img = wandb.Image(fig) 106 | wandb.log({name: img}) if step is None else wandb.log({name: img}, step=step) 107 | 108 | def log_outputs(self, logging_stats, rollout_storage, log_images, step, is_train=False, log_videos=True, log_video_caption=False): 109 | """Visualizes/logs all training outputs.""" 110 | self.log_scalar_dict(logging_stats, prefix='train' if is_train else 'eval', step=step) 111 | 112 | if log_images: 113 | assert rollout_storage is not None # need rollout data for image logging 114 | # log rollout videos with info captions 115 | if 'image' in rollout_storage and log_videos: 116 | if log_video_caption: 117 | vids = [np.stack(add_captions_to_seq(rollout.image, np2obj(rollout.info))).transpose(0, 3, 1, 2) 118 | for rollout in rollout_storage.get()[-self.n_logged_samples:]] 119 | else: 120 | vids = [np.stack(rollout.image).transpose(0, 3, 1, 2) 121 | for rollout in rollout_storage.get()[-self.n_logged_samples:]] 122 | self.log_videos(vids, name="rollouts", step=step) 123 | 124 | @property 125 | def n_logged_samples(self): 126 | # TODO(karl) put this functionality in a base logger class + give it default parameters and config 127 | return self.N_LOGGED_SAMPLES 128 | 129 | 130 | #----------------------CSV Logger---------------------- 131 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 132 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 133 | ('episode_reward', 'R', 'float'), ('episode_sr', 'SR', 'float'), 134 | ('buffer_size', 'BS', 'int'), ('fps', 'FPS', 'float'), 135 | ('total_time', 'T', 'time'), ('ETA', 'ETA', 'time')] 136 | 137 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 138 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 139 | ('episode_reward', 'R', 'float'), 140 | ('total_time', 'T', 'time')] 141 | 142 | 143 | class AverageMeter(object): 144 | def __init__(self): 145 | self._sum = 0 146 | self._count = 0 147 | 148 | def update(self, value, n=1): 149 | self._sum += value 150 | self._count += n 151 | 152 | def value(self): 153 | return self._sum / max(1, self._count) 154 | 155 | 156 | class MetersGroup(object): 157 | def __init__(self, csv_file_name, formating): 158 | self._csv_file_name = csv_file_name 159 | if(os.path.exists(csv_file_name) and os.path.isfile(csv_file_name)): 160 | os.remove(csv_file_name) 161 | 162 | self._formating = formating 163 | self._meters = defaultdict(AverageMeter) 164 | self._csv_file = None 165 | self._csv_writer = None 166 | 167 | def log(self, key, value, n=1): 168 | self._meters[key].update(value, n) 169 | 170 | def _prime_meters(self): 171 | data = dict() 172 | for key, meter in self._meters.items(): 173 | if key.startswith('train'): 174 | key = key[len('train') + 1:] 175 | else: 176 | key = key[len('eval') + 1:] 177 | key = key.replace('/', '_') 178 | data[key] = meter.value() 179 | return data 180 | 181 | def _remove_old_entries(self, data): 182 | rows = [] 183 | with self._csv_file_name.open('r') as f: 184 | reader = csv.DictReader(f) 185 | for row in reader: 186 | # print(row) 187 | # if float(row['episode']) >= data['episode']: 188 | # break 189 | rows.append(row) 190 | with self._csv_file_name.open('w') as f: 191 | writer = csv.DictWriter(f, 192 | fieldnames=sorted(data.keys()), 193 | restval=0.0) 194 | writer.writeheader() 195 | for row in rows: 196 | writer.writerow(row) 197 | 198 | def _dump_to_csv(self, data): 199 | if self._csv_writer is None: 200 | should_write_header = True 201 | if 
self._csv_file_name.exists(): 202 | self._remove_old_entries(data) 203 | should_write_header = False 204 | 205 | self._csv_file = self._csv_file_name.open('a') 206 | self._csv_writer = csv.DictWriter(self._csv_file, 207 | fieldnames=sorted(data.keys()), 208 | restval=0.0) 209 | if should_write_header: 210 | self._csv_writer.writeheader() 211 | 212 | self._csv_writer.writerow(data) 213 | self._csv_file.flush() 214 | 215 | def _format(self, key, value, ty): 216 | if ty == 'int': 217 | value = int(value) 218 | return f'{key}: {value}' 219 | elif ty == 'float': 220 | return f'{key}: {value:.04f}' 221 | elif ty == 'time': 222 | value = str(datetime.timedelta(seconds=int(value))) 223 | return f'{key}: {value}' 224 | else: 225 | raise ValueError(f'invalid format type: {ty}') 226 | 227 | def _dump_to_console(self, data, prefix): 228 | prefix = colored(prefix, 'blue' if prefix == 'train' else 'green') 229 | pieces = [f'| {prefix: <14}'] 230 | for key, disp_key, ty in self._formating: 231 | value = data.get(key, 0) 232 | pieces.append(self._format(disp_key, value, ty)) 233 | logger.info(' | '.join(pieces)) 234 | 235 | def dump(self, step, prefix): 236 | if len(self._meters) == 0: 237 | return 238 | data = self._prime_meters() 239 | data['frame'] = step 240 | self._dump_to_csv(data) 241 | self._dump_to_console(data, prefix) 242 | self._meters.clear() 243 | 244 | 245 | class Logger(object): 246 | def __init__(self, log_dir): 247 | self._log_dir = log_dir 248 | self._train_mg = MetersGroup(log_dir / 'train.csv', 249 | formating=COMMON_TRAIN_FORMAT) 250 | self._eval_mg = MetersGroup(log_dir / 'eval.csv', 251 | formating=COMMON_EVAL_FORMAT) 252 | self._sw = None 253 | 254 | def _try_sw_log(self, key, value, step): 255 | if self._sw is not None: 256 | self._sw.add_scalar(key, value, step) 257 | 258 | def log(self, key, value, step): 259 | assert key.startswith('train') or key.startswith('eval') 260 | if type(value) == torch.Tensor: 261 | value = value.item() 262 | self._try_sw_log(key, value, step) 263 | mg = self._train_mg if key.startswith('train') else self._eval_mg 264 | mg.log(key, value) 265 | 266 | def log_metrics(self, metrics, step, ty): 267 | for key, value in metrics.items(): 268 | self.log(f'{ty}/{key}', value, step) 269 | 270 | def dump(self, step, ty=None): 271 | if ty is None or ty == 'eval': 272 | self._eval_mg.dump(step, 'eval') 273 | if ty is None or ty == 'train': 274 | self._train_mg.dump(step, 'train') 275 | 276 | def log_and_dump_ctx(self, step, ty): 277 | return LogAndDumpCtx(self, step, ty) 278 | 279 | 280 | class LogAndDumpCtx: 281 | def __init__(self, logger, step, ty): 282 | self._logger = logger 283 | self._step = step 284 | self._ty = ty 285 | 286 | def __enter__(self): 287 | return self 288 | 289 | def __call__(self, key, value): 290 | self._logger.log(f'{self._ty}/{key}', value, self._step) 291 | 292 | def __exit__(self, *args): 293 | self._logger.dump(self._step, self._ty) -------------------------------------------------------------------------------- /viskill/trainers/sc_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from ..agents import make_hier_agent 6 | from ..components.checkpointer import CheckpointHandler 7 | from ..components.logger import Logger, WandBLogger, logger 8 | from ..modules.replay_buffer import (HerReplayBuffer, ReplayBuffer, 9 | get_hier_buffer_samplers) 10 | from ..modules.sampler import HierarchicalSampler 11 | from ..utils.general_utils import (AttrDict, AverageMeter,
Every, Timer, Until, 12 | set_seed_everywhere) 13 | from ..utils.mpi import (mpi_gather_experience_successful_transitions, 14 | mpi_gather_experience_transitions, mpi_sum, 15 | mpi_gather_experience_rollots, 16 | update_mpi_config) 17 | from ..utils.rl_utils import RolloutStorage, init_demo_buffer, init_sc_buffer 18 | from .sl_trainer import SkillLearningTrainer 19 | 20 | 21 | class SkillChainingTrainer(SkillLearningTrainer): 22 | def _setup(self): 23 | self._setup_env() # Environment 24 | self._setup_buffer() # Relay buffer 25 | self._setup_agent() # Agent 26 | self._setup_sampler() # Sampler 27 | self._setup_logger() # Logger 28 | self._setup_misc() # MISC 29 | 30 | if self.is_chef: 31 | self.termlog.info('Setup done') 32 | 33 | def _setup_buffer(self): 34 | # Skill learning buffer -> HER dict replay buffer 35 | self.sl_buffer_samplers = get_hier_buffer_samplers(self.train_env, self.cfg.sl_agent.sampler) 36 | self.sl_buffer, self.sl_demo_buffer = AttrDict(), AttrDict() 37 | if self.cfg.agent.update_sl_agent: 38 | self.sl_buffer.update({subtask: HerReplayBuffer( 39 | buffer_size=self.cfg.replay_buffer_capacity, env_params=self.env_params, batch_size=self.cfg.batch_size, 40 | sampler=self.sl_buffer_samplers[subtask], T=self.env_params.subtask_steps[subtask]) for subtask in self.env_params.subtasks} 41 | ) 42 | self.sl_demo_buffer.update({subtask: HerReplayBuffer( 43 | buffer_size=self.cfg.replay_buffer_capacity, env_params=self.env_params, batch_size=self.cfg.batch_size, 44 | sampler=self.sl_buffer_samplers[subtask], T=self.env_params.subtask_steps[subtask]) for subtask in self.env_params.subtasks} 45 | ) 46 | 47 | # Skill chaining buffer -> Rollout state replay buffer 48 | self.sc_buffer = AttrDict({subtask : ReplayBuffer( 49 | obs_shape=self.env_params['obs'], action_shape=self.env_params['act_sc'], 50 | capacity=self.cfg.replay_buffer_capacity, batch_size=self.cfg.batch_size, len_cond=self.env_params['len_cond']) for subtask in self.env_params.subtasks} 51 | ) 52 | self.sc_demo_buffer = AttrDict({subtask : ReplayBuffer( 53 | obs_shape=self.env_params['obs'], action_shape=self.env_params['act_sc'], 54 | capacity=self.cfg.replay_buffer_capacity, batch_size=self.cfg.batch_size, len_cond=self.env_params['len_cond']) for subtask in self.env_params.subtasks} 55 | ) 56 | 57 | def _setup_agent(self): 58 | update_mpi_config(self.cfg) 59 | self.agent = make_hier_agent(self.env_params, self.sl_buffer_samplers, self.cfg) 60 | 61 | def _setup_sampler(self): 62 | self.train_sampler = HierarchicalSampler(self.train_env, self.agent, self.env_params) 63 | self.eval_sampler = HierarchicalSampler(self.eval_env, self.agent, self.env_params) 64 | 65 | def _setup_logger(self): 66 | if self.is_chef: 67 | exp_name = f"SC_{self.cfg.task}_{self.cfg.agent.sc_agent.name}_{self.cfg.agent.sl_agent.name}_seed{self.cfg.seed}" 68 | if self.cfg.postfix is not None: 69 | exp_name = exp_name + '_' + self.cfg.postfix 70 | self.wb = WandBLogger(exp_name=exp_name, project_name=self.cfg.project_name, entity=self.cfg.entity_name, \ 71 | path=self.work_dir, conf=self.cfg) 72 | self.logger = Logger(self.work_dir) 73 | self.termlog = logger 74 | else: 75 | self.wb, self.logger, self.termlog = None, None, None 76 | 77 | def _setup_misc(self): 78 | init_sc_buffer(self.cfg, self.sc_buffer, self.agent, self.env_params) 79 | init_sc_buffer(self.cfg, self.sc_demo_buffer, self.agent, self.env_params) 80 | 81 | if self.cfg.agent.update_sl_agent: 82 | for subtask in self.env_params.middle_subtasks: 83 | init_demo_buffer(self.cfg, 
self.sl_buffer[subtask], self.agent.sl_agent[subtask], subtask, False) 84 | init_demo_buffer(self.cfg, self.sl_demo_buffer[subtask], self.agent.sl_agent[subtask], subtask, False) 85 | 86 | if self.is_chef: 87 | self.model_dir = self.work_dir / 'model' 88 | self.model_dir.mkdir(exist_ok=True) 89 | for file in os.listdir(self.model_dir): 90 | os.remove(self.model_dir / file) 91 | 92 | self.device = torch.device(self.cfg.device) 93 | self.timer = Timer() 94 | self._global_step = 0 95 | self._global_episode = 0 96 | set_seed_everywhere(self.cfg.seed) 97 | 98 | def train(self): 99 | n_train_episodes = int(self.cfg.n_train_steps / self.env_params['max_timesteps']) 100 | n_eval_episodes = int(n_train_episodes / self.cfg.n_eval) * self.cfg.mpi.num_workers 101 | n_save_episodes = int(n_train_episodes / self.cfg.n_save) * self.cfg.mpi.num_workers 102 | n_log_episodes = int(n_train_episodes / self.cfg.n_log) * self.cfg.mpi.num_workers 103 | 104 | assert n_save_episodes >= n_eval_episodes 105 | if n_save_episodes % n_eval_episodes != 0: 106 | n_save_episodes = int(n_save_episodes / n_eval_episodes) * n_eval_episodes 107 | 108 | train_until_episode = Until(n_train_episodes) 109 | save_every_episodes = Every(n_save_episodes) 110 | eval_every_episodes = Every(n_eval_episodes) 111 | log_every_episodes = Every(n_log_episodes) 112 | seed_until_steps = Until(self.cfg.n_seed_steps) 113 | 114 | if self.is_chef: 115 | self.termlog.info('Starting training') 116 | while train_until_episode(self.global_episode): 117 | self._train_episode(log_every_episodes, seed_until_steps) 118 | 119 | if eval_every_episodes(self.global_episode): 120 | score = self.eval() 121 | 122 | if not self.cfg.dont_save and save_every_episodes(self.global_episode) and self.is_chef: 123 | filename = CheckpointHandler.get_ckpt_name(self.global_episode) 124 | # TODO(tao): expose scoring metric 125 | CheckpointHandler.save_checkpoint({ 126 | 'episode': self.global_episode, 127 | 'global_step': self.global_step, 128 | 'state_dict': self.agent.state_dict(), 129 | 'score': score, 130 | }, self.model_dir, filename) 131 | self.termlog.info(f'Save checkpoint to {os.path.join(self.model_dir, filename)}') 132 | 133 | def _train_episode(self, log_every_episodes, seed_until_steps): 134 | # sync network parameters across workers 135 | if self.use_multiple_workers: 136 | self.agent.sync_networks() 137 | 138 | self.timer.reset() 139 | batch_time = AverageMeter() 140 | ep_start_step = self.global_step 141 | metrics = None 142 | 143 | # collect experience and save to buffer 144 | rollout_storage = RolloutStorage() 145 | sc_episode, sl_episode, env_steps = self.train_sampler.sample_episode(is_train=True, render=False) 146 | if self.use_multiple_workers: 147 | for subtask in sc_episode.sc_transitions.keys(): 148 | transitions_batch = mpi_gather_experience_transitions(sc_episode.sc_transitions[subtask]) 149 | # save to buffer 150 | self.sc_buffer[subtask].add_rollouts(transitions_batch) 151 | if self.cfg.agent.sc_agent.normalize: 152 | self.agent.sc_agent.update_normalizer(transitions_batch, subtask) 153 | if self.cfg.agent.update_sl_agent: 154 | for subtask in self.env_params.subtasks: 155 | self.sl_buffer[subtask].store_episode(sl_episode[subtask]) 156 | 157 | if self.use_multiple_workers: 158 | for subtask in sc_episode.sc_succ_transitions.keys(): 159 | demo_batch = mpi_gather_experience_successful_transitions(sc_episode.sc_succ_transitions[subtask]) 160 | # save to buffer 161 | self.sc_demo_buffer[subtask].add_rollouts(demo_batch) 162 | else: 163 | raise 
NotImplementedError 164 | #transitions_batch = sc_episode.sc_transitions 165 | 166 | # update status 167 | rollout_storage.append(sc_episode) 168 | rollout_status = rollout_storage.rollout_stats() 169 | self._global_step += int(mpi_sum(env_steps)) 170 | self._global_episode += int(mpi_sum(1)) 171 | 172 | # update policy 173 | if not seed_until_steps(ep_start_step) and self.is_chef: 174 | if not self.cfg.use_demo_buffer: 175 | metrics = self.agent.update(self.sc_buffer, self.sl_buffer) 176 | else: 177 | metrics = self.agent.update(self.sc_buffer, self.sl_buffer, self.sc_demo_buffer, self.sl_demo_buffer) 178 | if self.use_multiple_workers: 179 | self.agent.sync_networks() 180 | 181 | # log results 182 | if metrics is not None and log_every_episodes(self.global_episode) and self.is_chef: 183 | elapsed_time, total_time = self.timer.reset() 184 | batch_time.update(elapsed_time) 185 | togo_train_time = batch_time.avg * (self.cfg.n_train_steps - ep_start_step) / env_steps / self.cfg.mpi.num_workers 186 | 187 | self.logger.log_metrics(metrics, self.global_step, ty='train') 188 | with self.logger.log_and_dump_ctx(self.global_step, ty='train') as log: 189 | log('fps', env_steps / elapsed_time) 190 | log('total_time', total_time) 191 | log('episode_reward', rollout_status.avg_reward) 192 | log('episode_length', env_steps) 193 | log('episode_sr', rollout_status.avg_success_rate) 194 | log('episode', self.global_episode) 195 | log('step', self.global_step) 196 | log('ETA', togo_train_time) 197 | self.wb.log_outputs(metrics, None, log_images=False, step=self.global_step, is_train=True) 198 | 199 | def eval_ckpt(self): 200 | '''Eval checkpoint.''' 201 | CheckpointHandler.load_checkpoint( 202 | self.cfg.sc_ckpt_dir, self.agent, self.device, self.cfg.sc_ckpt_episode 203 | ) 204 | 205 | eval_rollout_storage = RolloutStorage() 206 | for _ in range(self.cfg.n_eval_episodes): 207 | episode, _, env_steps = self.eval_sampler.sample_episode(is_train=False, render=True) 208 | eval_rollout_storage.append(episode) 209 | rollout_status = eval_rollout_storage.rollout_stats() 210 | if self.use_multiple_workers: 211 | rollout_status = mpi_gather_experience_rollots(rollout_status) 212 | for key, value in rollout_status.items(): 213 | rollout_status[key] = value.mean() 214 | 215 | if self.is_chef: 216 | self.wb.log_outputs(rollout_status, eval_rollout_storage, log_images=True, step=self.global_step) 217 | with self.logger.log_and_dump_ctx(self.global_step, ty='eval') as log: 218 | log('episode_sr', rollout_status.avg_success_rate) 219 | log('episode_reward', rollout_status.avg_reward) 220 | log('episode_length', env_steps) 221 | log('episode', self.global_episode) 222 | log('step', self.global_step) 223 | 224 | del eval_rollout_storage 225 | -------------------------------------------------------------------------------- /viskill/modules/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | 8 | from ..utils.general_utils import AttrDict 9 | 10 | 11 | #-------------------------Hindsight Experience Replay------------------------- 12 | class HerReplayBuffer: 13 | def __init__(self, env_params, buffer_size, batch_size, sampler, T=None): 14 | # TODO(tao): unwrap env_params 15 | self.env_params = env_params 16 | self.T = T if T is not None else env_params['max_timesteps'] 17 | self.size = buffer_size // self.T 18 | self.batch_size = batch_size 19 | 20 | # memory 
management 21 | self.current_size = 0 22 | self.n_transitions_stored = 0 23 | self.sample_func = sampler 24 | 25 | # create the buffer to store info 26 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]), 27 | 'ag': np.empty([self.size, self.T + 1, self.env_params['achieved_goal']]), 28 | 'g': np.empty([self.size, self.T, self.env_params['goal']]), 29 | 'actions': np.empty([self.size, self.T, self.env_params['act']]), 30 | 'dones': np.empty([self.size, self.T, 1]), 31 | } 32 | 33 | # store the episode 34 | def store_episode(self, episode_batch): 35 | mb_obs, mb_ag, mb_g, mb_actions, dones = episode_batch.obs, episode_batch.ag, episode_batch.g, \ 36 | episode_batch.actions, episode_batch.dones 37 | batch_size = mb_obs.shape[0] 38 | idxs = self._get_storage_idx(inc=batch_size) 39 | 40 | # store the informations 41 | self.buffers['obs'][idxs] = mb_obs 42 | self.buffers['ag'][idxs] = mb_ag 43 | self.buffers['g'][idxs] = mb_g 44 | self.buffers['actions'][idxs] = mb_actions 45 | self.buffers['dones'][idxs] = dones 46 | self.n_transitions_stored += self.T * batch_size 47 | 48 | # sample the data from the replay buffer 49 | def sample(self): 50 | temp_buffers = {} 51 | for key in self.buffers.keys(): 52 | temp_buffers[key] = self.buffers[key][:self.current_size] 53 | temp_buffers['obs_next'] = temp_buffers['obs'][:, 1:, :] 54 | temp_buffers['ag_next'] = temp_buffers['ag'][:, 1:, :] 55 | # sample transitions 56 | transitions = self.sample_func.sample_her_transitions(temp_buffers, self.batch_size) 57 | return transitions 58 | 59 | def _get_storage_idx(self, inc=None): 60 | inc = inc or 1 61 | if self.current_size+inc <= self.size: 62 | idx = np.arange(self.current_size, self.current_size+inc) 63 | elif self.current_size < self.size: 64 | overflow = inc - (self.size - self.current_size) 65 | idx_a = np.arange(self.current_size, self.size) 66 | idx_b = np.random.randint(0, self.current_size, overflow) 67 | idx = np.concatenate([idx_a, idx_b]) 68 | else: 69 | idx = np.random.randint(0, self.size, inc) 70 | self.current_size = min(self.size, self.current_size+inc) 71 | if inc == 1: 72 | idx = idx[0] 73 | return idx 74 | 75 | def save(self, save_dir, episode): 76 | with gzip.open(os.path.join(save_dir, f"replay_buffer_ep{episode}.zip"), 'wb') as f: 77 | pickle.dump(self.buffers, f) 78 | np.save(os.path.join(save_dir, f'idx_size_ep{episode}'), np.array([self.current_size, self.n_transitions_stored])) 79 | 80 | def load(self, save_dir, episode): 81 | with gzip.open(os.path.join(save_dir, f"replay_buffer_ep{episode}.zip"), 'rb') as f: 82 | self.buffers = pickle.load(f) 83 | idx_size = np.load(os.path.join(save_dir, f"idx_size_ep{episode}.npy")) 84 | self.current_size, self.n_transitions_stored = int(idx_size[0]), int(idx_size[1]) 85 | 86 | 87 | class HerReplayBufferWithGT(HerReplayBuffer): 88 | def __init__(self, env_params, buffer_size, batch_size, sampler, T=None): 89 | # TODO(tao): unwrap env_params 90 | self.env_params = env_params 91 | self.T = T if T is not None else env_params['max_timesteps'] 92 | self.size = buffer_size // self.T 93 | self.batch_size = batch_size 94 | # memory management 95 | self.current_size = 0 96 | self.n_transitions_stored = 0 97 | self.sample_func = sampler 98 | # create the buffer to store info 99 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]), 100 | 'ag': np.empty([self.size, self.T + 1, self.env_params['achieved_goal']]), 101 | 'g': np.empty([self.size, self.T, self.env_params['goal']]), 102 | 'actions': 
np.empty([self.size, self.T, self.env_params['act']]), 103 | 'dones': np.empty([self.size, self.T, 1]), 104 | 'gt_g': np.empty([self.size, self.T, self.env_params['goal']]), 105 | } 106 | self.sample_keys = ['obs', 'ag', 'g', 'actions', 'dones'] 107 | 108 | # store the episode 109 | def store_episode(self, episode_batch): 110 | mb_obs, mb_ag, mb_g, mb_actions, dones, mb_gt_g = episode_batch.obs, episode_batch.ag, episode_batch.g, \ 111 | episode_batch.actions, episode_batch.dones, episode_batch.gt_g 112 | batch_size = mb_obs.shape[0] 113 | idxs = self._get_storage_idx(inc=batch_size) 114 | 115 | # store the informations 116 | self.buffers['obs'][idxs] = mb_obs 117 | self.buffers['ag'][idxs] = mb_ag 118 | self.buffers['g'][idxs] = mb_g 119 | self.buffers['actions'][idxs] = mb_actions 120 | self.buffers['dones'][idxs] = dones 121 | self.buffers['gt_g'][idxs] = mb_gt_g 122 | self.n_transitions_stored += self.T * batch_size 123 | 124 | # sample the data from the replay buffer 125 | def sample(self): 126 | temp_buffers = {} 127 | for key in self.sample_keys: 128 | temp_buffers[key] = self.buffers[key][:self.current_size] 129 | temp_buffers['obs_next'] = temp_buffers['obs'][:, 1:, :] 130 | temp_buffers['ag_next'] = temp_buffers['ag'][:, 1:, :] 131 | # sample transitions 132 | transitions = self.sample_func.sample_her_transitions(temp_buffers, self.batch_size) 133 | return transitions 134 | 135 | 136 | class HER_sampler: 137 | def __init__(self, replay_strategy, replay_k, reward_func=None): 138 | self.replay_strategy = replay_strategy 139 | self.replay_k = replay_k 140 | if self.replay_strategy == 'future': 141 | self.future_p = 1 - (1. / (1 + replay_k)) 142 | else: 143 | self.future_p = 0 144 | self.reward_func = reward_func 145 | 146 | def sample_her_transitions(self, episode_batch, batch_size_in_transitions): 147 | T = episode_batch['actions'].shape[1] 148 | rollout_batch_size = episode_batch['actions'].shape[0] 149 | batch_size = batch_size_in_transitions 150 | 151 | # select which rollouts and which timesteps to be used 152 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) 153 | t_samples = np.random.randint(T, size=batch_size) 154 | transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys()} 155 | 156 | # her idx 157 | her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p) 158 | future_offset = np.random.uniform(size=batch_size) * (T - t_samples) 159 | future_offset = future_offset.astype(int) 160 | future_t = (t_samples + 1 + future_offset)[her_indexes] 161 | 162 | # replace go with achieved goal 163 | future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t] 164 | transitions['g'][her_indexes] = future_ag 165 | 166 | # to get the params to re-compute reward 167 | transitions['r'] = self.reward_func(transitions['ag_next'], transitions['g'], None) 168 | if len(transitions['r'].shape) == 1: 169 | transitions['r'] = np.expand_dims(transitions['r'], 1) 170 | transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) for k in transitions.keys()} 171 | 172 | return transitions 173 | 174 | 175 | class HER_sampler_seq(HER_sampler): 176 | def sample_her_transitions(self, episode_batch, batch_size_in_transitions): 177 | T = episode_batch['actions'].shape[1] 178 | rollout_batch_size = episode_batch['actions'].shape[0] 179 | batch_size = batch_size_in_transitions 180 | # select which rollouts and which timesteps to be used 181 | episode_idxs = np.random.randint(0, rollout_batch_size, 
batch_size) 182 | t_samples = np.random.randint(T-1, size=batch_size) # from T to T-1 183 | transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys()} 184 | 185 | next_actions = episode_batch['actions'][episode_idxs, t_samples + 1].copy() 186 | transitions['next_actions'] = next_actions 187 | # her idx 188 | her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p) 189 | future_offset = np.random.uniform(size=batch_size) * (T - t_samples) 190 | future_offset = future_offset.astype(int) 191 | future_t = (t_samples + 1 + future_offset)[her_indexes] 192 | # replace go with achieved goal 193 | future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t] 194 | transitions['g'][her_indexes] = future_ag 195 | # to get the params to re-compute reward 196 | transitions['r'] = self.reward_func(transitions['ag_next'], transitions['g'], None) 197 | if len(transitions['r'].shape) == 1: 198 | transitions['r'] = np.expand_dims(transitions['r'], 1) 199 | transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) for k in transitions.keys()} 200 | 201 | return transitions 202 | 203 | 204 | def get_buffer_sampler(env, cfg): 205 | if cfg.type == 'her': 206 | sampler = HER_sampler( 207 | replay_strategy=cfg.strategy, 208 | replay_k=cfg.k, 209 | reward_func=env.compute_reward, 210 | ) 211 | elif cfg.type == 'her_seq': 212 | sampler = HER_sampler_seq( 213 | replay_strategy=cfg.strategy, 214 | replay_k=cfg.k, 215 | reward_func=env.compute_reward, 216 | ) 217 | else: 218 | raise NotImplementedError 219 | return sampler 220 | 221 | 222 | def get_hier_buffer_samplers(env, cfg): 223 | reward_funcs = env.get_reward_functions() 224 | samplers = AttrDict() 225 | if cfg.type == 'her': 226 | for subtask in env.subtasks: 227 | samplers.update({subtask: HER_sampler( 228 | replay_strategy=cfg.strategy, 229 | replay_k=cfg.k, 230 | reward_func=reward_funcs[subtask], 231 | )}) 232 | elif cfg.type == 'her_seq': 233 | for subtask in env.subtasks: 234 | samplers.update({subtask: HER_sampler_seq( 235 | replay_strategy=cfg.strategy, 236 | replay_k=cfg.k, 237 | reward_func=reward_funcs[subtask], 238 | )}) 239 | else: 240 | raise NotImplementedError 241 | return samplers 242 | 243 | 244 | #-------------------------Rollout Replay Buffer------------------------- 245 | class ReplayBuffer(Dataset): 246 | """Buffer to store environment transitions.""" 247 | def __init__(self, obs_shape, action_shape, capacity, batch_size, len_cond): 248 | self.capacity = capacity 249 | self.batch_size = batch_size 250 | # the proprioceptive obs is stored as float32, pixels obs as uint8 251 | obs_dtype = np.float32 252 | 253 | self.obses = np.empty((capacity, obs_shape), dtype=obs_dtype) 254 | self.next_obses = np.empty((capacity, obs_shape), dtype=obs_dtype) 255 | self.actions = np.empty((capacity, action_shape), dtype=np.float32) 256 | self.rewards = np.empty((capacity, 1), dtype=np.float32) 257 | self.dones = np.empty((capacity, 1), dtype=np.float32) 258 | # gt_action should not be accessed unless entering terminal subtask 259 | # TODO: expose arm number 260 | self.gt_actions = np.empty((capacity, action_shape+len_cond), dtype=np.float32) 261 | 262 | self.idx = 0 263 | self.last_save = 0 264 | self.full = False 265 | 266 | def add(self, obs, action, reward, next_obs, done, gt_action): 267 | np.copyto(self.obses[self.idx], obs) 268 | np.copyto(self.actions[self.idx], action) 269 | np.copyto(self.rewards[self.idx], reward) 270 | np.copyto(self.next_obses[self.idx], 
next_obs) 271 | np.copyto(self.dones[self.idx], done) 272 | np.copyto(self.gt_actions[self.idx], gt_action) 273 | 274 | self.idx = (self.idx + 1) % self.capacity 275 | self.full = self.full or self.idx == 0 276 | 277 | def add_rollouts(self, rollouts): 278 | for transition in rollouts: 279 | self.add(*transition) 280 | 281 | def sample(self, idxs=None, return_idxs=False): 282 | if idxs is None: 283 | idxs = np.random.randint( 284 | 0, self.capacity if self.full else self.idx, size=self.batch_size 285 | ) 286 | 287 | obses = self.obses[idxs] 288 | next_obses = self.next_obses[idxs] 289 | actions = self.actions[idxs] 290 | rewards = self.rewards[idxs] 291 | dones = self.dones[idxs] 292 | gt_actions = self.gt_actions[idxs] 293 | if return_idxs: 294 | return obses, actions, rewards, next_obses, dones, gt_actions, idxs 295 | else: 296 | return obses, actions, rewards, next_obses, dones, gt_actions 297 | 298 | def __getitem__(self, idx): 299 | idx = np.random.randint( 300 | 0, self.capacity if self.full else self.idx, size=1 301 | ) 302 | idx = idx[0] 303 | obs = self.obses[idx] 304 | action = self.actions[idx] 305 | reward = self.rewards[idx] 306 | next_obs = self.next_obses[idx] 307 | done = self.dones[idx] 308 | gt_action = self.gt_actions[idx] 309 | 310 | return obs, action, reward, next_obs, done, gt_action 311 | 312 | def __len__(self): 313 | return self.capacity -------------------------------------------------------------------------------- /viskill/components/envrionment.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from contextlib import contextmanager 3 | 4 | import gym 5 | import numpy as np 6 | import torch 7 | from surrol.utils.pybullet_utils import (pairwise_collision, 8 | pairwise_link_collision) 9 | 10 | 11 | def approx_collision(goal_a, goal_b, th=0.025): 12 | assert goal_a.shape == goal_b.shape 13 | return np.linalg.norm(goal_a - goal_b, axis=-1) < th 14 | 15 | 16 | class SkillLearningWrapper(gym.Wrapper): 17 | def __init__(self, env, subtask, output_raw_obs): 18 | super().__init__(env) 19 | self.subtask = subtask 20 | self._start_subtask = subtask 21 | self._elapsed_steps = None 22 | self._output_raw_obs = output_raw_obs 23 | 24 | @abc.abstractmethod 25 | def _replace_goal_with_subgoal(self, obs): 26 | """Replace achieved goal and desired goal.""" 27 | raise NotImplementedError 28 | 29 | @abc.abstractmethod 30 | def _subgoal(self): 31 | """Output goal of subtask.""" 32 | raise NotImplementedError 33 | 34 | @contextmanager 35 | def switch_subtask(self, subtask=None): 36 | '''Temporally switch subtask, default: next subtask''' 37 | if subtask is not None: 38 | curr_subtask = self.subtask 39 | self.subtask = subtask 40 | yield 41 | self.subtask = curr_subtask 42 | else: 43 | self.subtask = self.SUBTASK_PREV_SUBTASK[self.subtask] 44 | yield 45 | self.subtask = self.SUBTASK_NEXT_SUBTASK[self.subtask] 46 | 47 | 48 | #-----------------------------BiPegTransfer-v0----------------------------- 49 | class BiPegTransferSLWrapper(SkillLearningWrapper): 50 | '''Wrapper for skill learning''' 51 | SUBTASK_ORDER = { 52 | 'grasp': 0, 53 | 'handover': 1, 54 | 'release': 2 55 | } 56 | SUBTASK_STEPS = { 57 | 'grasp': 45, 58 | 'handover': 35, 59 | 'release': 20 60 | } 61 | SUBTASK_RESET_INDEX = { 62 | 'handover': 4, 63 | 'release': 10 64 | } 65 | SUBTASK_RESET_MAX_STEPS = { 66 | 'handover': 45, 67 | 'release': 70 68 | } 69 | SUBTASK_PREV_SUBTASK = { 70 | 'handover': 'grasp', 71 | 'release': 'handover' 72 | } 73 | SUBTASK_NEXT_SUBTASK = { 
74 | 'grasp': 'handover', 75 | 'handover': 'release' 76 | } 77 | SUBTASK_CONTACT_CONDITION = { 78 | 'grasp': [0, 1], 79 | 'handover': [1, 0], 80 | 'release': [0, 0] 81 | } 82 | LAST_SUBTASK = 'release' 83 | def __init__(self, env, subtask='grasp', output_raw_obs=False): 84 | super().__init__(env, subtask, output_raw_obs) 85 | self.done_subtasks = {key: False for key in self.SUBTASK_STEPS.keys()} 86 | 87 | @property 88 | def max_episode_steps(self): 89 | assert np.sum([x for x in self.SUBTASK_STEPS.values()]) == self.env._max_episode_steps 90 | return self.SUBTASK_STEPS[self.subtask] 91 | 92 | def step(self, action): 93 | next_obs, reward, done, info = self.env.step(action) 94 | self._elapsed_steps += 1 95 | next_obs_ = self._replace_goal_with_subgoal(next_obs.copy()) 96 | reward = self.compute_reward(next_obs_['achieved_goal'], next_obs_['desired_goal']) 97 | info['is_success'] = reward + 1 98 | done = self._elapsed_steps == self.max_episode_steps 99 | # save groud truth goal 100 | with self.switch_subtask(self.LAST_SUBTASK): 101 | info['gt_goal'] = self._replace_goal_with_subgoal(next_obs.copy())['desired_goal'] 102 | 103 | if self._output_raw_obs: return next_obs_, reward, done, info, next_obs 104 | else: return next_obs_, reward, done, info, 105 | 106 | def reset(self): 107 | self.subtask = self._start_subtask 108 | if self.subtask not in self.SUBTASK_RESET_INDEX.keys(): 109 | obs = self.env.reset() 110 | self._elapsed_steps = 0 111 | else: 112 | success = False 113 | while not success: 114 | obs = self.env.reset() 115 | self.subtask = self._start_subtask 116 | self._elapsed_steps = 0 117 | 118 | action, skill_index = self.env.get_oracle_action(obs) 119 | count, max_steps = 0, self.SUBTASK_RESET_MAX_STEPS[self.subtask] 120 | while skill_index < self.SUBTASK_RESET_INDEX[self.subtask] and count < max_steps: 121 | obs, reward, done, info = self.env.step(action) 122 | action, skill_index = self.env.get_oracle_action(obs) 123 | count += 1 124 | 125 | # Reset again if failed 126 | with self.switch_subtask(): 127 | obs_ = self._replace_goal_with_subgoal(obs.copy()) # in case repeatedly replace goal 128 | success = self.compute_reward(obs_['achieved_goal'], obs_['desired_goal']) + 1 129 | 130 | if self._output_raw_obs: return self._replace_goal_with_subgoal(obs), obs 131 | else: return self._replace_goal_with_subgoal(obs) 132 | 133 | def _replace_goal_with_subgoal(self, obs): 134 | """Replace ag and g""" 135 | subgoal = self._subgoal() 136 | psm1col = pairwise_collision(self.env.obj_id, self.psm1.body) 137 | psm2col = pairwise_collision(self.env.obj_id, self.psm2.body) 138 | 139 | if self.subtask == 'grasp': 140 | obs['achieved_goal'] = np.concatenate([obs['observation'][0: 3], obs['observation'][7: 10], [psm1col, psm2col]]) 141 | elif self.subtask == 'handover': 142 | obs['achieved_goal'] = np.concatenate([obs['observation'][7: 10], obs['observation'][0: 3], [psm1col, psm2col]]) 143 | elif self.subtask == 'release': 144 | obs['achieved_goal'] = np.concatenate([obs['observation'][7: 10], obs['achieved_goal'], [psm1col, psm2col]]) 145 | obs['desired_goal'] = np.append(subgoal, self.SUBTASK_CONTACT_CONDITION[self.subtask]) 146 | return obs 147 | 148 | def _subgoal(self): 149 | """Output goal of subtask""" 150 | goal = self.env.subgoals[self.SUBTASK_ORDER[self.subtask]] 151 | return goal 152 | 153 | def compute_reward(self, ag, g, info=None): 154 | """Compute reward that indicates the success of subtask""" 155 | if len(ag.shape) == 1: 156 | if self.subtask == 'release': 157 | goal_reach = 
self.env.compute_reward(ag[-5:-2], g[-5:-2], None) + 1 158 | else: 159 | goal_reach = self.env.compute_reward(ag[:-2], g[:-2], None) + 1 160 | contact_cond = np.all(ag[-2:]==g[-2:]) 161 | reward = (goal_reach and contact_cond) - 1 162 | else: 163 | if self.subtask == 'release': 164 | goal_reach = self.env.compute_reward(ag[:,-5:-2], g[:,-5:-2], None).reshape(-1, 1) + 1 165 | else: 166 | goal_reach = self.env.compute_reward(ag[:,:-2], g[:,:-2], None).reshape(-1, 1) + 1 167 | contact_cond = np.all(ag[:, -2:]==g[:, -2:], axis=1).reshape(-1, 1) 168 | reward = np.all(np.hstack([goal_reach, contact_cond]), axis=1) - 1. 169 | return reward 170 | 171 | 172 | class BiPegTransferSCWrapper(BiPegTransferSLWrapper): 173 | '''Wrapper for skill chaining.''' 174 | MAX_ACTION_RANGE = 4. 175 | REWARD_SCALE = 30. 176 | def step(self, action): 177 | next_obs, reward, done, info = self.env.step(action) 178 | self._elapsed_steps += 1 179 | next_obs_ = self._replace_goal_with_subgoal(next_obs.copy()) 180 | reward = self.compute_reward(next_obs_['achieved_goal'], next_obs_['desired_goal']) 181 | info['step'] = 1 - reward 182 | done = self._elapsed_steps == self.SUBTASK_STEPS[self.subtask] 183 | reward = done * reward 184 | info['subtask'] = self.subtask 185 | info['subtask_done'] = False 186 | info['subtask_is_success'] = reward 187 | 188 | if done: 189 | info['subtask_done'] = True 190 | # Transit to next subtask (if current subtask is not terminal) and reset elapsed steps 191 | if self.subtask in self.SUBTASK_NEXT_SUBTASK.keys(): 192 | done = False 193 | self._elapsed_steps = 0 194 | self.subtask = self.SUBTASK_NEXT_SUBTASK[self.subtask] 195 | info['is_success'] = False 196 | reward = 0 197 | else: 198 | info['is_success'] = reward 199 | next_obs_ = self._replace_goal_with_subgoal(next_obs) 200 | 201 | if self._output_raw_obs: return next_obs_, reward, done, info, next_obs 202 | else: return next_obs_, reward, done, info 203 | 204 | def reset(self, subtask=None): 205 | self.subtask = self._start_subtask if subtask is None else subtask 206 | if self.subtask not in self.SUBTASK_RESET_INDEX.keys(): 207 | obs = self.env.reset() 208 | self._elapsed_steps = 0 209 | else: 210 | success = False 211 | while not success: 212 | obs = self.env.reset() 213 | self.subtask = self._start_subtask if subtask is None else subtask 214 | self._elapsed_steps = 0 215 | 216 | action, skill_index = self.env.get_oracle_action(obs) 217 | count, max_steps = 0, self.SUBTASK_RESET_MAX_STEPS[self.subtask] 218 | while skill_index < self.SUBTASK_RESET_INDEX[self.subtask] and count < max_steps: 219 | obs, reward, done, info = self.env.step(action) 220 | action, skill_index = self.env.get_oracle_action(obs) 221 | count += 1 222 | 223 | # Reset again if failed 224 | with self.switch_subtask(): 225 | obs_ = self._replace_goal_with_subgoal(obs.copy()) # in case repeatedly replace goal 226 | success = self.compute_reward(obs_['achieved_goal'], obs_['desired_goal']) + 1 227 | 228 | if self._output_raw_obs: return self._replace_goal_with_subgoal(obs), obs 229 | else: return self._replace_goal_with_subgoal(obs) 230 | 231 | #---------------------------Reward--------------------------- 232 | def compute_reward(self, ag, g, info=None): 233 | """Compute reward that indicates the success of subtask""" 234 | if len(ag.shape) == 1: 235 | if self.subtask == 'release': 236 | goal_reach = self.env.compute_reward(ag[-5:-2], g[-5:-2], None) + 1 237 | else: 238 | goal_reach = self.env.compute_reward(ag[:-2], g[:-2], None) + 1 239 | contact_cond = 
np.all(ag[-2:]==g[-2:]) 240 | reward = (goal_reach and contact_cond) - 1 241 | else: 242 | if self.subtask == 'release': 243 | goal_reach = self.env.compute_reward(ag[:,-5:-2], g[:,-5:-2], None).reshape(-1, 1) + 1 244 | else: 245 | goal_reach = self.env.compute_reward(ag[:,:-2], g[:,:-2], None).reshape(-1, 1) + 1 246 | if self.subtask == 'grasp': 247 | raise NotImplementedError 248 | contact_cond = np.all(ag[:, -2:]==g[:, -2:], axis=1).reshape(-1, 1) 249 | reward = np.all(np.hstack([goal_reach, contact_cond]), axis=1) - 1. 250 | return reward + 1 251 | 252 | def goal_adapator(self, goal, subtask, device=None): 253 | '''Make predicted goal compatible with wrapper''' 254 | if isinstance(goal, np.ndarray): 255 | return np.append(goal, self.SUBTASK_CONTACT_CONDITION[subtask]) 256 | elif isinstance(goal, torch.Tensor): 257 | assert device is not None 258 | ct_cond = torch.tensor(self.SUBTASK_CONTACT_CONDITION[subtask], dtype=torch.float32) 259 | ct_cond = ct_cond.repeat(goal.shape[0], 1).to(device) 260 | adp_goal = torch.cat([goal, ct_cond], 1) 261 | return adp_goal 262 | 263 | def get_reward_functions(self): 264 | reward_funcs = {} 265 | for subtask in self.subtask_order.keys(): 266 | with self.switch_subtask(subtask): 267 | reward_funcs[subtask] = self.compute_reward 268 | return reward_funcs 269 | 270 | @property 271 | def start_subtask(self): 272 | return self._start_subtask 273 | 274 | @property 275 | def max_episode_steps(self): 276 | assert np.sum([x for x in self.SUBTASK_STEPS.values()]) == self.env._max_episode_steps 277 | return self.env._max_episode_steps 278 | 279 | @property 280 | def max_action_range(self): 281 | return self.MAX_ACTION_RANGE 282 | 283 | @property 284 | def subtask_order(self): 285 | return self.SUBTASK_ORDER 286 | 287 | @property 288 | def subtask_steps(self): 289 | return self.SUBTASK_STEPS 290 | 291 | @property 292 | def subtasks(self): 293 | subtasks = [] 294 | for subtask, order in self.subtask_order.items(): 295 | if order >= self.subtask_order[self.start_subtask]: 296 | subtasks.append(subtask) 297 | return subtasks 298 | 299 | @property 300 | def prev_subtasks(self): 301 | return self.SUBTASK_PREV_SUBTASK 302 | 303 | @property 304 | def next_subtasks(self): 305 | return self.SUBTASK_NEXT_SUBTASK 306 | 307 | @property 308 | def last_subtask(self): 309 | return self.LAST_SUBTASK 310 | 311 | @property 312 | def len_cond(self): 313 | return len(self.SUBTASK_CONTACT_CONDITION[self.last_subtask]) 314 | 315 | 316 | class BiPegBoardSLWrapper(BiPegTransferSLWrapper): 317 | '''Wrapper for skill learning''' 318 | SUBTASK_STEPS = { 319 | 'grasp': 30, 320 | 'handover': 35, 321 | 'release': 35 322 | } 323 | SUBTASK_RESET_INDEX = { 324 | 'handover': 5, 325 | 'release': 9 326 | } 327 | SUBTASK_RESET_MAX_STEPS = { 328 | 'handover': 30, 329 | 'release': 60 330 | } 331 | 332 | 333 | class BiPegBoardSCWrapper(BiPegTransferSCWrapper, BiPegBoardSLWrapper): 334 | '''Wrapper for skill chaining''' 335 | SUBTASK_STEPS = { 336 | 'grasp': 30, 337 | 'handover': 35, 338 | 'release': 35 339 | } 340 | SUBTASK_RESET_INDEX = { 341 | 'handover': 5, 342 | 'release': 9 343 | } 344 | SUBTASK_RESET_MAX_STEPS = { 345 | 'handover': 30, 346 | 'release': 60 347 | } 348 | 349 | 350 | class MatchBoardSLWrapper(BiPegTransferSLWrapper, SkillLearningWrapper): 351 | '''Wrapper for skill learning''' 352 | SUBTASK_ORDER = { 353 | 'pull': 0, 354 | 'grasp': 1, 355 | 'release': 2, 356 | 'push': 3 357 | } 358 | SUBTASK_STEPS = { 359 | 'pull': 50, 360 | 'grasp': 30, 361 | 'release': 20, 362 | 'push': 50 363 | } 
364 | SUBTASK_RESET_INDEX = { 365 | 'grasp': 5, 366 | 'release': 9, 367 | 'push': 11, 368 | } 369 | SUBTASK_RESET_MAX_STEPS = { 370 | 'grasp': 60, 371 | 'release': 90, 372 | 'push': 110 373 | } 374 | SUBTASK_PREV_SUBTASK = { 375 | 'grasp': 'pull', 376 | 'release': 'grasp', 377 | 'push': 'release' 378 | } 379 | SUBTASK_NEXT_SUBTASK = { 380 | 'pull': 'grasp', 381 | 'grasp': 'release', 382 | 'release': 'push' 383 | } 384 | SUBTASK_CONTACT_CONDITION = { 385 | 'pull': [0], 386 | 'grasp': [0], 387 | 'release': [0], 388 | 'push': [0] 389 | } 390 | LAST_SUBTASK = 'push' 391 | def __init__(self, env, subtask='grasp', output_raw_obs=False): 392 | super().__init__(env, subtask, output_raw_obs) 393 | self.col_with_lid = False 394 | 395 | def reset(self): 396 | self.col_with_lid = False 397 | return super().reset() 398 | 399 | def _replace_goal_with_subgoal(self, obs): 400 | """Replace ag and g""" 401 | subgoal = self._subgoal() 402 | 403 | # collision condition 404 | if not self.col_with_lid: 405 | self.col_with_lid = pairwise_link_collision(self.env.psm1.body, 4, self.env.obj_ids['fixed'][-1], self.env.target_row * 6 + 3) != () 406 | 407 | if self.subtask in ['pull', 'push']: 408 | obs['achieved_goal'] = np.concatenate([obs['observation'][0: 3], obs['achieved_goal'][3:6], [int(self.col_with_lid)]]) 409 | obs['desired_goal'] = np.concatenate([subgoal[0:3], subgoal[6:9], [0]]) 410 | elif self.subtask in ['grasp', 'release']: 411 | obs['achieved_goal'] = np.concatenate([obs['observation'][0: 3], obs['achieved_goal'][:3], [int(self.col_with_lid)]]) 412 | obs['desired_goal'] = np.concatenate([subgoal[0:3], subgoal[3:6], [0]]) 413 | return obs 414 | 415 | #---------------------------Reward--------------------------- 416 | def compute_reward(self, ag, g, info=None): 417 | """Compute reward that indicates the success of subtask""" 418 | if len(ag.shape) == 1: 419 | goal_reach = self.env.compute_reward(ag[:-1], g[:-1], None) + 1 420 | reward = (goal_reach and (1 - ag[-1])) - 1 421 | else: 422 | goal_reach = self.env.compute_reward(ag[:, :-1], g[:, :-1], None).reshape(-1, 1) + 1 423 | reward = np.all(np.hstack([goal_reach, 1 - ag[:, -1].reshape(-1, 1)]), axis=1) - 1. 424 | return reward 425 | 426 | 427 | class MatchBoardSCWrapper(MatchBoardSLWrapper, BiPegTransferSCWrapper): 428 | MAX_ACTION_RANGE = 4. 
429 | def step(self, action): 430 | next_obs, reward, done, info = self.env.step(action) 431 | self._elapsed_steps += 1 432 | next_obs_ = self._replace_goal_with_subgoal(next_obs.copy()) 433 | reward = self.compute_reward(next_obs_['achieved_goal'], next_obs_['desired_goal']) 434 | info['step'] = 1 - reward 435 | done = self._elapsed_steps == self.SUBTASK_STEPS[self.subtask] 436 | reward = done * reward 437 | info['subtask'] = self.subtask 438 | info['subtask_done'] = False 439 | info['subtask_is_success'] = reward 440 | 441 | if done: 442 | info['subtask_done'] = True 443 | # Transit to next subtask (if current subtask is not terminal) and reset elapsed steps 444 | if self.subtask in self.SUBTASK_NEXT_SUBTASK.keys(): 445 | done = False 446 | self._elapsed_steps = 0 447 | self.subtask = self.SUBTASK_NEXT_SUBTASK[self.subtask] 448 | info['is_success'] = False 449 | reward = 0 450 | else: 451 | info['is_success'] = reward 452 | next_obs_ = self._replace_goal_with_subgoal(next_obs) 453 | 454 | if self._output_raw_obs: return next_obs_, reward, done, info, next_obs 455 | else: return next_obs_, reward, done, info 456 | 457 | def reset(self, subtask=None): 458 | self.col_with_lid = False 459 | return BiPegTransferSCWrapper.reset(self, subtask) 460 | 461 | def compute_reward(self, ag, g, info=None): 462 | """Compute reward that indicates the success of subtask""" 463 | if len(ag.shape) == 1: 464 | goal_reach = self.env.compute_reward(ag[:-1], g[:-1], None) + 1 465 | reward = (goal_reach and (1 - ag[-1])) - 1 466 | else: 467 | goal_reach = self.env.compute_reward(ag[:, :-1], g[:, :-1], None).reshape(-1, 1) + 1 468 | reward = np.all(np.hstack([goal_reach, 1 - ag[:, -1].reshape(-1, 1)]), axis=1) - 1. 469 | return reward + 1 470 | 471 | def goal_adapator(self, goal, subtask, device=None): 472 | '''Make predicted goal compatible with wrapper''' 473 | if isinstance(goal, np.ndarray): 474 | return np.append(goal, self.SUBTASK_CONTACT_CONDITION[subtask]) 475 | elif isinstance(goal, torch.Tensor): 476 | assert device is not None 477 | ct_cond = torch.tensor(self.SUBTASK_CONTACT_CONDITION[subtask], dtype=torch.float32) 478 | ct_cond = ct_cond.repeat(goal.shape[0], 1).to(device) 479 | adp_goal = torch.cat([goal, ct_cond], 1) 480 | return adp_goal 481 | 482 | 483 | #-----------------------------Make envrionment----------------------------- 484 | def make_env(cfg): 485 | env = gym.make(cfg.task) 486 | if cfg.task == 'BiPegTransfer-v0': 487 | if cfg.skill_chaining: 488 | env = BiPegTransferSCWrapper(env, cfg.init_subtask, output_raw_obs=False) 489 | else: 490 | env = BiPegTransferSLWrapper(env, cfg.subtask, output_raw_obs=False) 491 | elif cfg.task == 'BiPegBoard-v0': 492 | if cfg.skill_chaining: 493 | env = BiPegBoardSCWrapper(env, cfg.init_subtask, output_raw_obs=False) 494 | else: 495 | env = BiPegBoardSLWrapper(env, cfg.subtask, output_raw_obs=False) 496 | elif cfg.task == 'MatchBoard-v0': 497 | if cfg.skill_chaining: 498 | env = MatchBoardSCWrapper(env, cfg.init_subtask, output_raw_obs=False) 499 | else: 500 | env = MatchBoardSLWrapper(env, cfg.subtask, output_raw_obs=False) 501 | else: 502 | raise NotImplementedError 503 | return env --------------------------------------------------------------------------------
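The skill-chaining wrappers above advance through subtasks purely via the class-level SUBTASK_* tables: each subtask runs for a fixed step budget, step() moves on to SUBTASK_NEXT_SUBTASK when that budget elapses, and the episode ends only after the terminal subtask. The following minimal standalone sketch (not repository code) walks that schedule for the BiPegTransfer tables to show how the per-subtask budgets compose into the full-episode horizon that the max_episode_steps property asserts equals env._max_episode_steps.

SUBTASK_STEPS = {'grasp': 45, 'handover': 35, 'release': 20}
SUBTASK_NEXT_SUBTASK = {'grasp': 'handover', 'handover': 'release'}

subtask, schedule, total = 'grasp', [], 0
while True:
    schedule.append((subtask, SUBTASK_STEPS[subtask]))
    total += SUBTASK_STEPS[subtask]
    if subtask not in SUBTASK_NEXT_SUBTASK:       # terminal subtask: episode ends here
        break
    subtask = SUBTASK_NEXT_SUBTASK[subtask]       # same transition step() performs on subtask_done

print(schedule)   # [('grasp', 45), ('handover', 35), ('release', 20)]
print(total)      # 100, the full-episode horizon the wrapper checks against env._max_episode_steps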
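HER_sampler.sample_her_transitions relabels a future_p fraction of sampled transitions with an achieved goal drawn from a later timestep of the same episode, then recomputes the reward against the relabeled goal. Below is a self-contained numpy sketch of that 'future' strategy under toy assumptions: random episodes and a made-up distance-threshold reward named toy_reward. Only the indexing mirrors the sampler above; all dimensions and thresholds are illustrative.

import numpy as np

def toy_reward(ag_next, g, info=None):
    # made-up sparse reward: 0 when the relabeled goal is (nearly) achieved, else -1
    return (np.linalg.norm(ag_next - g, axis=-1) < 0.05) - 1.0

rng = np.random.default_rng(0)
n_episodes, T, goal_dim, batch_size, replay_k = 4, 10, 3, 8, 4
future_p = 1 - 1.0 / (1 + replay_k)                 # same formula as HER_sampler.__init__

ag = rng.random((n_episodes, T + 1, goal_dim))      # achieved goals, T + 1 per episode as in the buffer
g = rng.random((n_episodes, T, goal_dim))           # desired goals, T per episode

ep_idx = rng.integers(0, n_episodes, batch_size)    # which rollouts to sample
t = rng.integers(0, T, batch_size)                  # which timesteps to sample
goals = g[ep_idx, t].copy()

her_idx = np.where(rng.random(batch_size) < future_p)[0]
future_t = t + 1 + (rng.random(batch_size) * (T - t)).astype(int)
goals[her_idx] = ag[ep_idx[her_idx], future_t[her_idx]]   # relabel with a future achieved goal

rewards = toy_reward(ag[ep_idx, t + 1], goals)      # recompute reward against the relabeled goals
print(rewards.shape, rewards.mean())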
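The rollout ReplayBuffer stores flat (obs, action, reward, next_obs, done, gt_action) transitions, with gt_action widened by len_cond extra entries, and add_rollouts simply unpacks one tuple per transition. A toy round trip is sketched below, assuming the viskill package (and its torch dependency) is importable and using made-up dimensions rather than the real env_params.

import numpy as np
from viskill.modules.replay_buffer import ReplayBuffer

obs_dim, act_dim, len_cond = 7, 3, 2                  # toy sizes, not the real env_params
buf = ReplayBuffer(obs_shape=obs_dim, action_shape=act_dim,
                   capacity=100, batch_size=4, len_cond=len_cond)

rng = np.random.default_rng(0)
rollout = [(rng.random(obs_dim),                      # obs
            rng.random(act_dim),                      # action
            0.0,                                      # reward
            rng.random(obs_dim),                      # next_obs
            False,                                    # done
            rng.random(act_dim + len_cond))           # gt_action (len_cond extra entries)
           for _ in range(10)]
buf.add_rollouts(rollout)                             # each tuple goes through add(*transition)

obs, act, rew, next_obs, done, gt_act = buf.sample()
print(obs.shape, act.shape, rew.shape)                # (4, 7) (4, 3) (4, 1)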