├── .gitmodules
├── setup.py
├── docs
│   └── resources
│       └── viskill_teaser.png
├── viskill
│   ├── agents
│   │   ├── __init__.py
│   │   ├── factory.py
│   │   ├── sc_sac_sil.py
│   │   ├── base.py
│   │   ├── hier_agent.py
│   │   ├── sc_awac.py
│   │   ├── sl_ddpgbc.py
│   │   ├── sl_dex.py
│   │   ├── sc_ddpg.py
│   │   └── sc_sac.py
│   ├── configs
│   │   ├── hier_agent
│   │   │   └── hier_agent.yaml
│   │   ├── sl_agent
│   │   │   ├── sl_ddpgbc.yaml
│   │   │   └── sl_dex.yaml
│   │   ├── sc_agent
│   │   │   ├── sc_sac.yaml
│   │   │   ├── sc_awac.yaml
│   │   │   └── sc_sac_sil.yaml
│   │   ├── eval.yaml
│   │   ├── skill_learning.yaml
│   │   └── skill_chaining.yaml
│   ├── modules
│   │   ├── subnetworks.py
│   │   ├── distributions.py
│   │   ├── critics.py
│   │   ├── policies.py
│   │   ├── sampler.py
│   │   └── replay_buffer.py
│   ├── trainers
│   │   ├── base_trainer.py
│   │   ├── sl_trainer.py
│   │   └── sc_trainer.py
│   ├── components
│   │   ├── normalizer.py
│   │   ├── checkpointer.py
│   │   ├── logger.py
│   │   └── envrionment.py
│   └── utils
│       ├── vis_utils.py
│       ├── mpi.py
│       ├── rl_utils.py
│       └── general_utils.py
├── requirements.txt
├── eval.py
├── train_sc.py
├── train_sl.py
├── LICENSE
├── .gitignore
└── README.md
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "SurRoL"]
2 | path = SurRoL
3 | url = https://github.com/TaoHuang13/SurRoL.git
4 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 |
3 | setup(name='viskill', version='0.0.1', packages=['viskill'])
--------------------------------------------------------------------------------
/docs/resources/viskill_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/med-air/ViSkill/HEAD/docs/resources/viskill_teaser.png
--------------------------------------------------------------------------------
/viskill/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from .hier_agent import HierachicalAgent
2 |
3 |
4 | def make_hier_agent(env_params, samplers, cfg):
5 | return HierachicalAgent(env_params, samplers, cfg)
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | hydra-core
2 | torch
3 | torchvision
4 | moviepy
5 | imageio
6 |
7 | # RL
8 | gym==0.15.6
9 | # mpi4py
10 |
11 | # Log
12 | colorlog
13 | termcolor
14 | wandb
15 |
16 |
--------------------------------------------------------------------------------
/viskill/configs/hier_agent/hier_agent.yaml:
--------------------------------------------------------------------------------
1 | name: hier_agent
2 | checkpoint_dir: ${checkpoint_dir}
3 |
4 | # SC agent (high-level agent)
5 | sc_agent: ${sc_agent}
6 | update_sc_agent: True
7 |
8 | # SL agent (low-level agent)
9 | sl_agent: ${sl_agent}
10 | update_sl_agent: False
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from viskill.trainers.sc_trainer import SkillChainingTrainer
3 |
4 |
5 | @hydra.main(version_base=None, config_path="./viskill/configs", config_name="eval")
6 | def main(cfg):
7 | exp = SkillChainingTrainer(cfg)
8 | exp.eval_ckpt()
9 |
10 | if __name__ == "__main__":
11 | main()
--------------------------------------------------------------------------------
/train_sc.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from viskill.trainers.sc_trainer import SkillChainingTrainer
3 |
4 |
5 | @hydra.main(version_base=None, config_path="./viskill/configs", config_name="skill_chaining")
6 | def main(cfg):
7 | exp = SkillChainingTrainer(cfg)
8 | exp.train()
9 |
10 | if __name__ == "__main__":
11 | main()
--------------------------------------------------------------------------------
/train_sl.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from viskill.trainers.sl_trainer import SkillLearningTrainer
3 |
4 |
5 | @hydra.main(version_base=None, config_path="./viskill/configs", config_name="skill_learning")
6 | def main(cfg):
7 | exp = SkillLearningTrainer(cfg)
8 | exp.train()
9 |
10 | if __name__ == "__main__":
11 | main()
--------------------------------------------------------------------------------
/viskill/configs/sl_agent/sl_ddpgbc.yaml:
--------------------------------------------------------------------------------
1 | name: SL_DDPGBC
2 | device: ${device}
3 | discount: 0.99
4 | reward_scale: 1
5 | n_seed_steps: 1000
6 |
7 | actor_lr: 1e-3
8 | critic_lr: 1e-3
9 | noise_eps: 0.1
10 | aux_weight: 5
11 | p_dist: 2
12 | soft_target_tau: 0.005
13 | clip_obs: 200
14 | norm_clip: 5
15 | norm_eps: 0.01
16 | hidden_dim: 256
17 | sampler:
18 | type: her
19 | strategy: future
20 | k: 4
21 | update_epoch: 40
--------------------------------------------------------------------------------
/viskill/configs/sl_agent/sl_dex.yaml:
--------------------------------------------------------------------------------
1 | name: SL_DEX
2 | device: ${device}
3 | discount: 0.99
4 | reward_scale: 1
5 | n_seed_steps: 200
6 |
7 | actor_lr: 1e-3
8 | critic_lr: 1e-3
9 | noise_eps: 0.2
10 | aux_weight: 5
11 | p_dist: 2
12 | k: 5
13 | soft_target_tau: 0.05
14 | clip_obs: 200
15 | norm_clip: 5
16 | norm_eps: 0.01
17 | hidden_dim: 256
18 | sampler:
19 | type: her_seq
20 | strategy: future
21 | k: 4
22 | update_epoch: 40
--------------------------------------------------------------------------------
/viskill/modules/subnetworks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class MLP(nn.Module):
5 | def __init__(self, in_dim, out_dim, hidden_dim=256):
6 | super().__init__()
7 |
8 | self.mlp = nn.Sequential(
9 | nn.Linear(in_dim, hidden_dim), nn.ReLU(),
10 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
11 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
12 | nn.Linear(hidden_dim, out_dim),
13 | )
14 |
15 | def forward(self, input):
16 | return self.mlp(input)
--------------------------------------------------------------------------------
/viskill/trainers/base_trainer.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from pathlib import Path
3 |
4 |
5 | class BaseTrainer:
6 | def __init__(self, cfg) -> None:
7 | self.cfg = cfg
8 | self.work_dir = Path(cfg.cwd)
9 | self._setup()
10 |
11 | @abstractmethod
12 | def _setup(self):
13 | raise NotImplementedError
14 |
15 | @abstractmethod
16 | def train(self):
17 | '''Training agent'''
18 | raise NotImplementedError
19 |
20 | @abstractmethod
21 | def eval(self):
22 | '''Evaluating agent.'''
23 | raise NotImplementedError
24 |
--------------------------------------------------------------------------------
/viskill/configs/sc_agent/sc_sac.yaml:
--------------------------------------------------------------------------------
1 | name: SC_SAC
2 | device: ${device}
3 | discount: 0.99
4 | reward_scale: 1
5 |
6 | actor_lr: 1e-4
7 | critic_lr: 1e-4
8 | temp_lr: 1e-4
9 | random_eps: 0.3
10 | noise_eps: 0.01
11 | aux_weight: 5
12 | decay: False
13 | action_l2: 1
14 | p_dist: 2
15 | soft_target_tau: 0.005
16 | clip_obs: 200
17 | norm_clip: 5
18 | norm_eps: 0.01
19 | hidden_dim: 256
20 | sampler:
21 | type: her_seq
22 | strategy: future
23 | k: 4
24 | update_epoch: ${update_epoch}
25 |
26 | normalize: False
27 | intr_reward: True
28 | raw_env_reward: True
29 |
30 | learnable_temperature: True
31 | init_temperature: 0.1
--------------------------------------------------------------------------------
/viskill/configs/eval.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - skill_chaining
3 | - _self_
4 |
5 | sc_seed: 2
6 | sc_ckpt_dir: ./exp/viskill/${task}/${sc_agent.name}_${sl_agent.name}/d${num_demo}/rs${sc_agent.reward_scale}/s${sc_seed}/model
7 | sc_ckpt_episode: best
8 |
9 | # Working space
10 | hydra:
11 | run:
12 | dir: ./exp/eval/viskill/${task}/${sc_agent.name}_${sl_agent.name}/d${num_demo}/rs${sc_agent.reward_scale}/s${seed}
13 | sweep:
14 | dir: ./exp/eval/viskill/${task}/${sc_agent.name}_${sl_agent.name}/d${num_demo}/rs${sc_agent.reward_scale}
15 | subdir: s${seed}
16 | sweeper:
17 | params:
18 | num_demo: 200
19 | seed: 1,2,3,4,5
20 |
--------------------------------------------------------------------------------
/viskill/configs/sc_agent/sc_awac.yaml:
--------------------------------------------------------------------------------
1 | name: SC_AWAC
2 | device: ${device}
3 | discount: 0.99
4 | reward_scale: 1
5 |
6 | actor_lr: 1e-4
7 | critic_lr: 1e-4
8 | temp_lr: 1e-4
9 | random_eps: 0.3
10 | noise_eps: 0.01
11 | aux_weight: 5
12 | decay: False
13 | action_l2: 1
14 | p_dist: 2
15 | soft_target_tau: 0.005
16 | clip_obs: 200
17 | norm_clip: 5
18 | norm_eps: 0.01
19 | hidden_dim: 256
20 | sampler:
21 | type: her_seq
22 | strategy: future
23 | k: 4
24 | update_epoch: ${update_epoch}
25 |
26 | normalize: False
27 | intr_reward: True
28 | raw_env_reward: True
29 |
30 | n_action_samples: 1
31 | lam: 1
32 | learnable_temperature: False
33 | init_temperature: 0.001
--------------------------------------------------------------------------------
/viskill/configs/sc_agent/sc_sac_sil.yaml:
--------------------------------------------------------------------------------
1 | name: SC_SAC_SIL
2 | device: ${device}
3 | discount: 0.99
4 | reward_scale: 1
5 |
6 | actor_lr: 1e-4
7 | critic_lr: 1e-4
8 | temp_lr: 1e-4
9 | random_eps: 0.3
10 | noise_eps: 0.01
11 | aux_weight: 5
12 | decay: False
13 | action_l2: 1
14 | p_dist: 2
15 | soft_target_tau: 0.005
16 | clip_obs: 200
17 | norm_clip: 5
18 | norm_eps: 0.01
19 | hidden_dim: 256
20 | sampler:
21 | type: her_seq
22 | strategy: future
23 | k: 4
24 | update_epoch: ${update_epoch}
25 |
26 | normalize: False
27 | raw_env_reward: True
28 |
29 | n_action_samples: 1
30 | lam: 1
31 | learnable_temperature: False
32 | init_temperature: 0
33 | policy_delay: 30
--------------------------------------------------------------------------------
/viskill/configs/skill_learning.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - sl_agent@agent: sl_dex
3 | - _self_
4 |
5 | # File path
6 | cwd: ${hydra:runtime.output_dir}
7 |
8 | # Training params
9 | n_train_steps: 2_000_001
10 | n_eval: 400
11 | n_save: 160
12 | n_log: 4000
13 | num_demo: 200
14 | eval_frequency: 2_000
15 | n_seed_steps: 200
16 |
17 | replay_buffer_capacity: 100_000
18 | checkpoint_frequency: 20_000
19 | update_epoch: 80
20 | batch_size: 128
21 | device: cuda:0
22 | seed: 1
23 | task: BiPegTransfer-v0
24 | subtask: grasp
25 | postfix: null
26 | skill_chaining: False
27 | dont_save: False
28 | n_eval_episodes: 8
29 | save_buffer: False
30 |
31 | use_wb: True
32 | project_name: viskill
33 | entity_name: thuang22
34 |
35 | mpi: {rank: null, is_chef: null, num_workers: null}
36 | # Working space
37 | hydra:
38 | run:
39 | dir: ./exp/skill_learning/${task}/${agent.name}/d${num_demo}/s${seed}/${subtask}
40 | sweep:
41 | dir: ./exp/skill_learning/${task}/${agent.name}/d${num_demo}/s${seed}
42 | subdir: ${subtask}
43 | sweeper:
44 | params:
45 | seed: 1,2,3,4,5
46 | subtask: grasp,handover,release
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Med-AIR@CUHK
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/viskill/agents/factory.py:
--------------------------------------------------------------------------------
1 | from .sc_awac import SkillChainingAWAC
2 | from .sc_ddpg import SkillChainingDDPG
3 | from .sc_sac import SkillChainingSAC
4 | from .sc_sac_sil import SkillChainingSACSIL
5 | from .sl_ddpgbc import SkillLearningDDPGBC
6 | from .sl_dex import SkillLearningDEX
7 |
8 | AGENTS = {
9 | 'SL_DDPGBC': SkillLearningDDPGBC,
10 | 'SL_DEX': SkillLearningDEX,
11 | 'SC_DDPG': SkillChainingDDPG,
12 | 'SC_AWAC': SkillChainingAWAC,
13 | 'SC_SAC': SkillChainingSAC,
14 | 'SC_SAC_SIL': SkillChainingSACSIL,
15 | }
16 |
17 | def make_sl_agent(env_params, sampler, cfg):
18 | if cfg.name not in AGENTS.keys():
19 | raise ValueError('Agent is not supported: %s' % cfg.name)
20 | else:
21 | assert 'SL' in cfg.name
22 | return AGENTS[cfg.name](
23 | env_params=env_params,
24 | sampler=sampler,
25 | agent_cfg=cfg
26 | )
27 |
28 |
29 | def make_sc_agent(env_params, cfg, sl_agent):
30 | if cfg.name not in AGENTS.keys():
31 | raise ValueError('Agent is not supported: %s' % cfg.name)
32 | else:
33 | assert 'SC' in cfg.name
34 | return AGENTS[cfg.name](
35 | env_params=env_params,
36 | agent_cfg=cfg,
37 | sl_agent=sl_agent
38 | )
39 |
--------------------------------------------------------------------------------
/viskill/configs/skill_chaining.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - hier_agent@agent: hier_agent
3 | - sc_agent: sc_sac_sil
4 | - sl_agent: sl_dex
5 | - _self_
6 |
7 | # File path
8 | cwd: ${hydra:runtime.output_dir}
9 |
10 | # Training params
11 | task: BiPegTransfer-v0
12 | init_subtask: grasp
13 | subtask: ${init_subtask}
14 | postfix: null
15 | skill_chaining: True
16 | dont_save: False
17 | checkpoint_dir: ./exp/skill_learning/${task}/${sl_agent.name}/d${num_demo}/s${model_seed}
18 |
19 | num_demo: 200
20 | seed: 1
21 | model_seed: 1
22 | device: cuda:0
23 | update_epoch: 20
24 | replay_buffer_capacity: 100_000
25 | batch_size: 128
26 |
27 | n_train_steps: 1_000_001
28 | n_eval: 1600
29 | n_save: 800
30 | n_log: 9600
31 | n_eval_episodes: 10
32 | eval_frequency: 2_000
33 | n_seed_steps: 200
34 | use_demo_buffer: True
35 | ckpt_episode: latest
36 |
37 | use_wb: True
38 | project_name: viskill
39 | entity_name: thuang22
40 |
41 | mpi: {rank: null, is_chef: null, num_workers: null}
42 | # Working space
43 | hydra:
44 | run:
45 | dir: ./exp/viskill/${task}/${sc_agent.name}_${sl_agent.name}/d${num_demo}/rs${sc_agent.reward_scale}/s${seed}
46 | sweep:
47 | dir: ./exp/viskill/${task}/${sc_agent.name}_${sl_agent.name}/d${num_demo}/rs${sc_agent.reward_scale}
48 | subdir: s${seed}
49 | sweeper:
50 | params:
51 | num_demo: 200
52 | seed: 1,2,3,4,5
53 |
--------------------------------------------------------------------------------
/viskill/components/normalizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class Normalizer:
6 | def __init__(self, size, eps=1e-2, default_clip_range=np.inf):
7 | self.size = size
8 | self.eps = eps
9 | self.default_clip_range = default_clip_range
10 | # some local information
11 | self.total_sum = np.zeros(self.size, np.float32)
12 | self.total_sumsq = np.zeros(self.size, np.float32)
13 | self.total_count = np.zeros(1, np.float32)
14 | # get the mean and std
15 | self.mean = np.zeros(self.size, np.float32)
16 | self.std = np.ones(self.size, np.float32)
17 |
18 | # update the parameters of the normalizer
19 | def update(self, v):
20 | v = v.reshape(-1, self.size)
21 | # do the computing
22 | self.total_sum += v.sum(axis=0)
23 | self.total_sumsq += (np.square(v)).sum(axis=0)
24 | self.total_count[0] += v.shape[0]
25 |
26 | def recompute_stats(self):
27 | # calculate the new mean and std
28 | self.mean = self.total_sum / self.total_count
29 | self.std = np.sqrt(np.maximum(np.square(self.eps), (self.total_sumsq / self.total_count) - np.square(self.total_sum / self.total_count)))
30 |
31 | # normalize the observation
32 | def normalize(self, v, clip_range=None, device=None):
33 | if clip_range is None:
34 | clip_range = self.default_clip_range
35 | if isinstance(v, np.ndarray):
36 | return np.clip((v - self.mean) / (self.std), -clip_range, clip_range)
37 | elif isinstance(v, torch.Tensor):
38 | return (v - torch.tensor(self.mean, dtype=torch.float32).to(device)) / (torch.tensor(self.std, dtype=torch.float32).to(device)).clamp(-clip_range, clip_range)
--------------------------------------------------------------------------------
/viskill/modules/distributions.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch.nn.functional as F
4 | from torch import distributions as pyd
5 |
6 |
7 | class TanhTransform(pyd.transforms.Transform):
8 | domain = pyd.constraints.real
9 | codomain = pyd.constraints.interval(-1.0, 1.0)
10 | bijective = True
11 | sign = +1
12 |
13 | def __init__(self, cache_size=1):
14 | super().__init__(cache_size=cache_size)
15 |
16 | @staticmethod
17 | def atanh(x):
18 | return 0.5 * (x.log1p() - (-x).log1p())
19 |
20 | def __eq__(self, other):
21 | return isinstance(other, TanhTransform)
22 |
23 | def _call(self, x):
24 | return x.tanh()
25 |
26 | def _inverse(self, y):
27 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms;
28 | # one should use `cache_size=1` instead.
29 | return self.atanh(y)
30 |
31 | def log_abs_det_jacobian(self, x, y):
32 | # We use a formula that is more numerically stable, see details in the following link
33 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
34 | return 2. * (math.log(2.) - x - F.softplus(-2. * x))
35 |
36 |
37 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
38 | def __init__(self, loc, scale):
39 | self.loc = loc
40 | self.scale = scale
41 |
42 | self.base_dist = pyd.Normal(loc, scale)
43 | transforms = [TanhTransform()]
44 | super().__init__(self.base_dist, transforms)
45 |
46 | @property
47 | def mean(self):
48 | mu = self.loc
49 | for tr in self.transforms:
50 | mu = tr(mu)
51 | return mu
52 |
--------------------------------------------------------------------------------
/viskill/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def add_caption_to_img(img, info, name=None, flip_rgb=False):
6 | """ Adds caption to an image. info is dict with keys and text/array.
7 | :arg name: if given this will be printed as heading in the first line
8 | :arg flip_rgb: set to True for inputs with BGR color channels
9 | """
10 | offset = 12
11 |
12 | frame = img * 255.0 if img.max() <= 1.0 else img
13 | if flip_rgb:
14 | frame = frame[:, :, ::-1]
15 |
16 | # make frame larger if needed
17 | if frame.shape[0] < 300:
18 | frame = cv2.resize(frame, (400, 400), interpolation=cv2.INTER_CUBIC)
19 |
20 | fheight, fwidth = frame.shape[:2]
21 | frame = np.concatenate([frame, np.zeros((offset * (len(info.keys()) + 2), fwidth, 3))], 0)
22 |
23 | font_size = 0.4
24 | thickness = 1
25 | x, y = 5, fheight + 10
26 | if name is not None:
27 | cv2.putText(frame, '[{}]'.format(name),
28 | (x, y), cv2.FONT_HERSHEY_SIMPLEX,
29 | font_size, (100, 100, 0), thickness, cv2.LINE_AA)
30 | for i, k in enumerate(info.keys()):
31 | v = info[k]
32 | key_text = '{}: '.format(k)
33 | (key_width, _), _ = cv2.getTextSize(key_text, cv2.FONT_HERSHEY_SIMPLEX,
34 | font_size, thickness)
35 |
36 | cv2.putText(frame, key_text,
37 | (x, y + offset * (i + 2)),
38 | cv2.FONT_HERSHEY_SIMPLEX,
39 | font_size, (66, 133, 244), thickness, cv2.LINE_AA)
40 |
41 | cv2.putText(frame, str(v),
42 | (x + key_width, y + offset * (i + 2)),
43 | cv2.FONT_HERSHEY_SIMPLEX,
44 | font_size, (100, 100, 100), thickness, cv2.LINE_AA)
45 |
46 | if flip_rgb:
47 | frame = frame[:, :, ::-1]
48 |
49 | return frame
50 |
51 | def add_captions_to_seq(img_seq, info_seq, **kwargs):
52 | """Adds captions to a sequence of images. info_seq is a list of dicts with keys and text/array."""
53 | return [add_caption_to_img(img, info, name='Timestep {:03d}'.format(i), **kwargs) for i, (img, info) in enumerate(zip(img_seq, info_seq))]
54 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/viskill/modules/critics.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ..modules.subnetworks import MLP
7 |
8 |
9 | class Critic(nn.Module):
10 | def __init__(self, in_dim, hidden_dim):
11 | super().__init__()
12 |
13 | self.q = MLP(
14 | in_dim=in_dim,
15 | out_dim=1,
16 | hidden_dim=hidden_dim
17 | )
18 |
19 | def forward(self, state, action):
20 | sa = torch.cat([state, action], dim=-1)
21 | q = self.q(sa)
22 | return q
23 |
24 |
25 | class DoubleCritic(nn.Module):
26 | def __init__(self, in_dim, hidden_dim):
27 | super().__init__()
28 |
29 | self.q1 = MLP(
30 | in_dim=in_dim,
31 | out_dim=1,
32 | hidden_dim=hidden_dim
33 | )
34 |
35 | self.q2 = MLP(
36 | in_dim=in_dim,
37 | out_dim=1,
38 | hidden_dim=hidden_dim
39 | )
40 |
41 | def forward(self, state, action):
42 | sa = torch.cat([state, action], dim=-1)
43 | q1 = self.q1(sa)
44 | q2 = self.q2(sa)
45 | return q1, q2
46 |
47 | def q(self, state, action):
48 | sa = torch.cat([state, action], dim=-1)
49 | q1 = self.q1(sa)
50 | q2 = self.q2(sa)
51 | return torch.min(q1, q2)
52 |
53 |
54 | class SkillChainingCritic(nn.Module):
55 | def __init__(self, in_dim, hidden_dim, middle_subtasks, last_subtask):
56 | super().__init__()
57 |
58 | self.qs = nn.ModuleDict({
59 | subtask: Critic(
60 | in_dim=in_dim,
61 | hidden_dim=hidden_dim
62 | ) for subtask in middle_subtasks})
63 | self.qs.update({last_subtask: None})
64 |
65 | def __getitem__(self, key):
66 | return self.qs[key]
67 |
68 | def forward(self, state, action, subtask):
69 | q = self.qs[subtask](state, action)
70 | return q
71 |
72 | def init_last_subtask_q(self, last_subtask, critic):
73 | '''Initialize with pre-trained local q-function'''
74 | assert self.qs[last_subtask] is None
75 | self.qs[last_subtask] = copy.deepcopy(critic)
76 |
77 |
78 | class SkillChainingDoubleCritic(SkillChainingCritic):
79 | def __init__(self, in_dim, hidden_dim, middle_subtasks, last_subtask):
80 | super(SkillChainingCritic, self).__init__()
81 |
82 | self.qs = nn.ModuleDict({
83 | subtask: DoubleCritic(
84 | in_dim=in_dim,
85 | hidden_dim=hidden_dim
86 | ) for subtask in middle_subtasks})
87 | self.qs.update({last_subtask: None})
88 |
89 | def forward(self, state, action, subtask):
90 | q1, q2 = self.qs[subtask](state, action)
91 | return q1, q2
92 |
93 | def q(self, state, action, subtask):
94 | return self.qs[subtask].q(state, action)
95 |
--------------------------------------------------------------------------------
/viskill/components/checkpointer.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import pipes
4 | import sys
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from ..components.logger import logger
10 | from ..utils.general_utils import get_last_argmax, str2int
11 |
12 |
13 | class CheckpointHandler:
14 | @staticmethod
15 | def get_ckpt_name(episode):
16 | return 'weights_ep{}.pth'.format(episode)
17 |
18 | @staticmethod
19 | def get_episode(path):
20 | checkpoint_names = glob.glob(os.path.abspath(path) + "/*.pth")
21 | if len(checkpoint_names) == 0:
22 | logger.error("No checkpoints found at {}!".format(path))
23 | processed_names = [file.split('/')[-1].replace('weights_ep', '').replace('.pth', '')
24 | for file in checkpoint_names]
25 | episodes = list(filter(lambda x: x is not None, [str2int(name) for name in processed_names]))
26 | return episodes
27 |
28 | @staticmethod
29 | def get_resume_ckpt_file(resume, path):
30 | episodes = CheckpointHandler.get_episode(path)
31 | file_paths = [os.path.join(path, CheckpointHandler.get_ckpt_name(episode)) for episode in episodes]
32 | scores = [torch.load(file_path)['score'] for file_path in file_paths]
33 | if resume == 'latest':
34 | max_episode = np.max(episodes)
35 | resume_file = CheckpointHandler.get_ckpt_name(max_episode)
36 | logger.info(f'Loading checkpoint at max episode {max_episode} with success rate {scores[np.argmax(episodes)]}!')
37 | elif resume == 'best':
38 | max_episode = episodes[get_last_argmax(scores)]
39 | resume_file = CheckpointHandler.get_ckpt_name(max_episode)
40 | logger.info(f'Checkpoint success rates {scores}; loading the best one with success rate {max(scores)}!')
41 | return os.path.join(path, resume_file), max_episode
42 |
43 | @staticmethod
44 | def save_checkpoint(state, folder, filename='checkpoint.pth'):
45 | torch.save(state, os.path.join(folder, filename))
46 |
47 | @staticmethod
48 | def load_checkpoint(checkpt_dir, agent, device, episode='best'):
49 | """Loads weights from checkpoint."""
50 | checkpt_path, max_episode = CheckpointHandler.get_resume_ckpt_file(episode, checkpt_dir)
51 | checkpt = torch.load(checkpt_path, map_location=device)
52 |
53 | logger.info(f'Loading pre-trained model from {checkpt_path}!')
54 | agent.load_state_dict(checkpt['state_dict'])
55 | if 'g_norm' in checkpt.keys() and 'o_norm' in checkpt.keys():
56 | agent.g_norm = checkpt['g_norm']
57 | agent.o_norm = checkpt['o_norm']
58 |
59 |
60 | def save_cmd(base_dir):
61 | train_cmd = 'python ' + ' '.join([sys.argv[0]] + [pipes.quote(s) for s in sys.argv[1:]])
62 | train_cmd += '\n\n'
63 | print('\n' + '*' * 80)
64 | print('Training command:\n' + train_cmd)
65 | print('*' * 80 + '\n')
66 | with open(os.path.join(base_dir, "cmd.txt"), "a") as f:
67 | f.write(train_cmd)
68 |
--------------------------------------------------------------------------------
/viskill/utils/mpi.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from mpi4py import MPI
4 |
5 | from .general_utils import (AttrDict, joinListDict, joinListDictList,
6 | joinListList)
7 |
8 |
9 | def update_mpi_config(cfg):
10 | rank = MPI.COMM_WORLD.Get_rank()
11 | cfg.mpi.rank = rank
12 | cfg.mpi.is_chef = rank == 0
13 | cfg.mpi.num_workers = MPI.COMM_WORLD.Get_size()
14 |
15 | # update conf
16 | cfg.seed = cfg.seed + rank
17 |
18 |
19 | def mpi_sum(x):
20 | buf = np.zeros_like(np.array(x))
21 | MPI.COMM_WORLD.Allreduce(np.array(x), buf, op=MPI.SUM)
22 | return buf
23 |
24 |
25 | def mpi_gather_experience_episode(experience_episode):
26 | buf = MPI.COMM_WORLD.allgather(experience_episode)
27 | return joinListDictList(buf)
28 |
29 |
30 | def mpi_gather_experience_rollots(experience_rollouts):
31 | buf = MPI.COMM_WORLD.allgather(experience_rollouts)
32 | return joinListDict(buf)
33 |
34 |
35 | def mpi_gather_experience_transitions(experience_transitions):
36 | buf = MPI.COMM_WORLD.allgather(experience_transitions)
37 | return joinListList(buf)
38 |
39 |
40 | def mpi_gather_experience_successful_transitions(experience_transitions):
41 | buf = MPI.COMM_WORLD.allgather(experience_transitions)
42 | jointLL = []
43 | for i in buf:
44 | if i[0] is not None:
45 | jointLL.append(i)
46 | return jointLL
47 |
48 |
49 | def mpi_gather_experience(experience_episode):
50 | """Gathers data across workers, can handle hierarchical and flat experience dicts."""
51 | return mpi_gather_experience_episode(experience_episode)
52 |
53 |
54 | def mpi_gather_rollouts(rollouts):
55 | """Gathers data across workers, can handle hierarchical and flat experience dicts."""
56 | return mpi_gather_experience_rollots(rollouts)
57 |
58 |
59 | # sync_networks across the different cores
60 | def sync_networks(network):
61 | """
62 | network is the network you want to sync
63 | """
64 | comm = MPI.COMM_WORLD
65 | flat_params, params_shape = _get_flat_params(network)
66 | comm.Bcast(flat_params, root=0)
67 | # set the flat params back to the network
68 | _set_flat_params(network, params_shape, flat_params)
69 |
70 |
71 | # get the flat params from the network
72 | def _get_flat_params(network):
73 | param_shape = {}
74 | flat_params = None
75 | for key_name, value in network.state_dict().items():
76 | param_shape[key_name] = value.cpu().detach().numpy().shape
77 | if flat_params is None:
78 | flat_params = value.cpu().detach().numpy().flatten()
79 | else:
80 | flat_params = np.append(flat_params, value.cpu().detach().numpy().flatten())
81 | return flat_params, param_shape
82 |
83 |
84 | # set the params from the network
85 | def _set_flat_params(network, params_shape, params):
86 | pointer = 0
87 | device = torch.device("cuda:0")
88 |
89 | for key_name, values in network.state_dict().items():
90 | # get the length of the parameters
91 | len_param = int(np.prod(params_shape[key_name]))
92 | copy_params = params[pointer:pointer + len_param].reshape(params_shape[key_name])
93 | copy_params = torch.tensor(copy_params).to(device)
94 | # copy the params
95 | values.data.copy_(copy_params.data)
96 | # update the pointer
97 | pointer += len_param
98 |
--------------------------------------------------------------------------------
/viskill/agents/sc_sac_sil.py:
--------------------------------------------------------------------------------
1 | from ..utils.general_utils import AttrDict, prefix_dict
2 | from .sc_awac import SkillChainingAWAC
3 |
4 |
5 | class SkillChainingSACSIL(SkillChainingAWAC):
6 | def __init__(
7 | self,
8 | env_params,
9 | agent_cfg,
10 | sl_agent
11 | ):
12 | super().__init__(env_params, agent_cfg, sl_agent)
13 | self.policy_delay = agent_cfg.policy_delay
14 |
15 | def update_actor_and_alpha(self, obs, action, subtask, sil=False):
16 | if sil:
17 | #metrics = super(SkillChainingSACSIL, self).update_actor(obs, action, subtask)
18 | # compute log probability
19 | dist = self.actor(obs, subtask)
20 | log_probs = dist.log_prob(action).sum(-1, keepdim=True)
21 | # compute exponential weight
22 | weights = self._compute_weights(obs, action, subtask)
23 | actor_loss = -(log_probs * weights).sum()
24 |
25 | # optimize actor loss
26 | self.actor_optimizer[subtask].zero_grad()
27 | actor_loss.backward()
28 | self.actor_optimizer[subtask].step()
29 |
30 | metrics = AttrDict(
31 | log_probs=log_probs.mean(),
32 | actor_loss=actor_loss.item()
33 | )
34 | metrics = prefix_dict(metrics, 'sil_')
35 | else:
36 | # compute log probability
37 | #metrics = super(SkillChainingAWAC, self).update_actor_and_alpha(obs, subtask)
38 | dist = self.actor(obs, subtask)
39 | action = dist.rsample()
40 | log_probs = dist.log_prob(action).sum(-1, keepdim=True)
41 | # compute state value
42 | actor_Q = self.critic.q(obs, action, subtask)
43 | actor_loss = (- actor_Q).mean()
44 |
45 | # optimize actor loss
46 | self.actor_optimizer[subtask].zero_grad()
47 | actor_loss.backward()
48 | self.actor_optimizer[subtask].step()
49 |
50 | metrics = AttrDict(
51 | log_probs=log_probs.mean(),
52 | actor_loss=actor_loss.item()
53 | )
54 |
55 | return prefix_dict(metrics, subtask + '_')
56 |
57 | def update(self, replay_buffer, demo_buffer):
58 | metrics = AttrDict()
59 |
60 | for subtask in self.env_params['middle_subtasks']:
61 | for i in range(self.update_epoch):
62 | # sample from replay buffer
63 | obs, action, reward, next_obs, done, sl_norm_next_obs, raw_reward = self.get_samples(replay_buffer, subtask)
64 | action = (action / self.max_action)
65 |
66 | # update critic and actor
67 | metrics.update(self.update_critic(obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask))
68 | if (i + 1) % self.policy_delay == 0:
69 | metrics.update(self.update_actor_and_alpha(obs, action, subtask, sil=False))
70 |
71 | # sample from demo buffer
72 | obs, action, reward, next_obs, done, sl_norm_next_obs, raw_reward = self.get_samples(demo_buffer, subtask)
73 | action = (action / self.max_action)
74 |
75 | # update critic and actor with demonstration data (self-imitation)
76 | metrics.update(self.update_critic(obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask))
77 | metrics.update(self.update_actor_and_alpha(obs, action, subtask, sil=True))
78 |
79 | # update target critic and actor
80 | self.update_target()
81 |
82 | return metrics
83 |
--------------------------------------------------------------------------------
/viskill/modules/policies.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .distributions import SquashedNormal
9 | from .subnetworks import MLP
10 |
11 | LOG_STD_BOUNDS = (-5, 2)
12 |
13 | class DeterministicActor(nn.Module):
14 | def __init__(self, in_dim, out_dim, hidden_dim=256, max_action=1.):
15 | super().__init__()
16 |
17 | self.trunk = MLP(
18 | in_dim=in_dim,
19 | out_dim=out_dim,
20 | hidden_dim=hidden_dim
21 | )
22 | self.max_action = max_action
23 |
24 | def forward(self, state):
25 | a = self.trunk(state)
26 | return self.max_action * torch.tanh(a)
27 |
28 |
29 | class DiagGaussianActor(nn.Module):
30 | def __init__(self, in_dim, out_dim, hidden_dim=256, max_action=1.):
31 | super().__init__()
32 |
33 | self.trunk = MLP(
34 | in_dim=in_dim,
35 | out_dim=2*out_dim,
36 | hidden_dim=hidden_dim
37 | )
38 | self.max_action = max_action
39 |
40 | def forward(self, obs):
41 | mu, log_std = self.trunk(obs).chunk(2, dim=-1)
42 |
43 | # constrain log_std inside [log_std_min, log_std_max]
44 | log_std = torch.tanh(log_std)
45 | log_std_min, log_std_max = LOG_STD_BOUNDS
46 | log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)
47 | std = log_std.exp()
48 |
49 | dist = SquashedNormal(mu, std)
50 | return dist
51 |
52 | def sample_n(self, obs, n_samples):
53 | return self.forward(obs).sample_n(n_samples)
54 |
55 |
56 | class SkillChainingActor(nn.Module):
57 | def __init__(self, in_dim, out_dim, hidden_dim=256, max_action=1.,
58 | middle_subtasks=None, last_subtask=None):
59 | super().__init__()
60 |
61 | self.actors = nn.ModuleDict({
62 | subtask: DeterministicActor(
63 | in_dim=in_dim,
64 | out_dim=out_dim,
65 | hidden_dim=hidden_dim,
66 | max_action=max_action
67 | ) for subtask in middle_subtasks})
68 | self.actors.update({last_subtask: None})
69 |
70 | def __getitem__(self, key):
71 | return self.actors[key]
72 |
73 | def forward(self, state, subtask):
74 | a = self.actors[subtask](state)
75 | return a
76 |
77 | def init_last_subtask_actor(self, last_subtask, actor):
78 | '''Initialize with pre-trained local actor'''
79 | assert self.actors[last_subtask] is None
80 | self.actors[last_subtask] = copy.deepcopy(actor)
81 |
82 |
83 | class SkillChainingDiagGaussianActor(SkillChainingActor):
84 | def __init__(self, in_dim, out_dim, hidden_dim=256, max_action=1.,
85 | middle_subtasks=None, last_subtask=None):
86 | super(SkillChainingActor, self).__init__()
87 |
88 | self.actors = nn.ModuleDict({
89 | subtask: DiagGaussianActor(
90 | in_dim=in_dim,
91 | out_dim=out_dim,
92 | hidden_dim=hidden_dim,
93 | max_action=max_action
94 | ) for subtask in middle_subtasks})
95 | self.actors.update({last_subtask: None})
96 |
97 | def sample_n(self, obs, subtask, n_samples):
98 | return self.actors[subtask].sample_n(obs, n_samples)
99 |
100 | def squash_action(self, dist, raw_action):
101 | squashed_action = torch.tanh(raw_action)
102 | jacob = 2 * (math.log(2) - raw_action - F.softplus(-2 * raw_action))
103 | log_prob = (dist.log_prob(raw_action) - jacob).sum(dim=-1, keepdims=True)
104 | return squashed_action, log_prob
105 |
--------------------------------------------------------------------------------
/viskill/agents/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | from ..utils.mpi import sync_networks
6 |
7 |
8 | class BaseAgent(nn.Module):
9 | def __init__(self):
10 | super().__init__()
11 |
12 | def train(self, training=True):
13 | self.training = training
14 | self.actor.train(training)
15 | self.critic.train(training)
16 |
17 | def get_samples(self, replay_buffer):
18 | # sample replay buffer
19 | transitions = replay_buffer.sample()
20 |
21 | # preprocess
22 | o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
23 | transitions['obs'], transitions['g'] = self._preproc_og(o, g)
24 | transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
25 |
26 | obs_norm = self.o_norm.normalize(transitions['obs'])
27 | g_norm = self.g_norm.normalize(transitions['g'])
28 | inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
29 |
30 | obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
31 | g_next_norm = self.g_norm.normalize(transitions['g_next'])
32 | inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
33 |
34 | obs = self.to_torch(inputs_norm)
35 | next_obs = self.to_torch(inputs_next_norm)
36 | action = self.to_torch(transitions['actions'])
37 | reward = self.to_torch(transitions['r'])
38 | done = self.to_torch(transitions['dones'])
39 |
40 | return obs, action, reward, done, next_obs
41 |
42 | def update_target(self):
43 | # Update the frozen target models
44 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
45 | target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data)
46 |
47 | def update_normalizer(self, episode_batch):
48 | mb_obs, mb_ag, mb_g, mb_actions, dones = episode_batch.obs, episode_batch.ag, episode_batch.g, \
49 | episode_batch.actions, episode_batch.dones
50 | mb_obs_next = mb_obs[:, 1:, :]
51 | mb_ag_next = mb_ag[:, 1:, :]
52 | # get the number of normalization transitions
53 | num_transitions = mb_actions.shape[1]
54 | # create the new buffer to store them
55 | buffer_temp = {'obs': mb_obs,
56 | 'ag': mb_ag,
57 | 'g': mb_g,
58 | 'actions': mb_actions,
59 | 'obs_next': mb_obs_next,
60 | 'ag_next': mb_ag_next,
61 | }
62 | transitions = self.her_sampler.sample_her_transitions(buffer_temp, num_transitions)
63 | obs, g = transitions['obs'], transitions['g']
64 | # pre process the obs and g
65 | transitions['obs'], transitions['g'] = self._preproc_og(obs, g)
66 | # update
67 | self.o_norm.update(transitions['obs'])
68 | self.g_norm.update(transitions['g'])
69 | # recompute the stats
70 | self.o_norm.recompute_stats()
71 | self.g_norm.recompute_stats()
72 |
73 | def _preproc_og(self, o, g):
74 | o = np.clip(o, -self.clip_obs, self.clip_obs)
75 | g = np.clip(g, -self.clip_obs, self.clip_obs)
76 | return o, g
77 |
78 | def _preproc_inputs(self, o, g, dim=0, device=None):
79 | o_norm = self.o_norm.normalize(o, device=device)
80 | g_norm = self.g_norm.normalize(g, device=device)
81 |
82 | if isinstance(o_norm, np.ndarray):
83 | inputs = np.concatenate([o_norm, g_norm], dim)
84 | inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0).to(self.device)
85 | elif isinstance(o_norm, torch.Tensor):
86 | inputs = torch.cat([o_norm, g_norm], dim=1)
87 | return inputs
88 |
89 | def to_torch(self, array, copy=True):
90 | if copy:
91 | return torch.tensor(array, dtype=torch.float32).to(self.device)
92 | return torch.as_tensor(array).to(self.device)
93 |
94 | def sync_networks(self):
95 | sync_networks(self)
--------------------------------------------------------------------------------
/viskill/agents/hier_agent.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ..components.checkpointer import CheckpointHandler
4 | from ..utils.general_utils import AttrDict, prefix_dict
5 | from .base import BaseAgent
6 | from .factory import make_sc_agent, make_sl_agent
7 |
8 |
9 | class HierachicalAgent(BaseAgent):
10 | def __init__(
11 | self,
12 | env_params,
13 | samplers,
14 | cfg,
15 | ):
16 | super().__init__()
17 | self.cfg = cfg
18 | self.device = cfg.device
19 | self.subtasks = env_params.subtasks
20 | self.goal_adaptor = env_params['adaptor_sc']
21 | self.subtasks_steps = env_params['subtask_steps']
22 | self.middle_subtasks = env_params['middle_subtasks']
23 | self.last_subtask = env_params['last_subtask']
24 | self.len_cond = env_params['len_cond']
25 | self.curr_subtask = None
26 |
27 | self.sl_agent = torch.nn.ModuleDict(
28 | {subtask: make_sl_agent(env_params, samplers[subtask], cfg.sl_agent) for subtask in self.subtasks})
29 | self._init_sl_agent()
30 | self.sc_agent = make_sc_agent(env_params, cfg.sc_agent, self.sl_agent)
31 | self._init_sc_agent()
32 |
33 | def _init_sl_agent(self):
34 | checkpt_dir = self.cfg.checkpoint_dir
35 | for subtask in self.subtasks:
36 | # TODO(tao): expose loading metric and dir specification
37 | sub_checkpt_dir = checkpt_dir + f'/{subtask}/model'
38 | CheckpointHandler.load_checkpoint(
39 | sub_checkpt_dir, self.sl_agent[subtask], self.device, episode=self.cfg.ckpt_episode)
40 |
41 | def _init_sc_agent(self):
42 | self.sc_agent.init(self.last_subtask, self.sl_agent)
43 |
44 | def update(self, sc_buffer, sl_buffer, sc_demo_buffer=None, sl_demo_buffer=None):
45 | metrics = AttrDict()
46 | if sc_demo_buffer is None:
47 | sc_metrics = prefix_dict(self.sc_agent.update(sc_buffer), 'sc_')
48 | else:
49 | sc_metrics = prefix_dict(self.sc_agent.update(sc_buffer, sc_demo_buffer), 'sc_')
50 | metrics.update(sc_metrics)
51 |
52 | if self.cfg.agent.update_sl_agent:
53 | for subtask in self.subtasks:
54 | sl_metrics = prefix_dict(self.sl_agent[subtask].update(sl_buffer[subtask]), 'sl_' + subtask + '_')
55 | metrics.update(sl_metrics)
56 | return metrics
57 |
58 | def get_action(self, obs, subtask, noise=False):
59 | output = AttrDict()
60 | if self._perform_hl_step_now(subtask):
61 | # perform step with skill-chaining policy
62 | if subtask not in self.middle_subtasks:
63 | self._last_sc_action = obs['desired_goal'][:-self.len_cond]
64 | else:
65 | self._last_sc_action = self.sc_agent.get_action(obs['observation'], subtask, noise=noise)
66 | output.is_sc_step = True
67 | self.curr_subtask = subtask
68 | else:
69 | output.is_sc_step = False
70 |
71 | # perform step with skill-learning policy
72 | assert self._last_sc_action is not None
73 | self.goal_adaption(obs, subtask)
74 | sl_action = self.sl_agent[subtask].get_action(obs, noise=False)
75 |
76 | output.update(AttrDict(
77 | sc_action=self._last_sc_action,
78 | sl_action=sl_action
79 | ))
80 | return output
81 |
82 | def goal_adaption(self, obs, subtask):
83 | # Add contact condition to make goal compatible with surrol wrapper
84 | if subtask == 'release' and self.cfg.task == 'MatchBoard-v0':
85 | adpt_goal = self.goal_adaptor(self._last_sc_action, subtask)
86 | adpt_goal[3: 6] = obs['desired_goal'][3: 6].copy()
87 | obs['desired_goal'] = adpt_goal
88 | elif subtask in ['push', 'pull'] and self.cfg.task == 'MatchBoard-v0':
89 | adpt_goal = self.goal_adaptor(self._last_sc_action, subtask)
90 | adpt_goal[3] = obs['desired_goal'][3].copy()
91 | adpt_goal[5] = obs['desired_goal'][5].copy()
92 | obs['desired_goal'] = adpt_goal
93 | else:
94 | obs['desired_goal'] = self.goal_adaptor(self._last_sc_action, subtask)
95 |
96 | def _perform_hl_step_now(self, subtask):
97 | """Indicates whether the skill-chaining policy should be executed in the current time step."""
98 | return subtask != self.curr_subtask
99 |
100 | def sync_networks(self):
101 | self.sc_agent.sync_networks()
--------------------------------------------------------------------------------
/viskill/agents/sc_awac.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from ..utils.general_utils import AttrDict, prefix_dict
5 | from .sc_sac import SkillChainingSAC
6 |
7 |
8 | class SkillChainingAWAC(SkillChainingSAC):
9 | def __init__(
10 | self,
11 | env_params,
12 | agent_cfg,
13 | sl_agent
14 | ):
15 | super().__init__(env_params, agent_cfg, sl_agent)
16 |
17 | # AWAC parameters
18 | self.n_action_samples = agent_cfg.n_action_samples
19 | self.lam = agent_cfg.lam
20 |
21 | def update_critic(self, obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask):
22 | assert subtask != self.env_params['last_subtask']
23 |
24 | with torch.no_grad():
25 | next_subtask = self.env_params['next_subtasks'][subtask]
26 | if next_subtask != self.env_params['last_subtask']:
27 | dist = self.actor(next_obs, next_subtask)
28 | action_out = dist.rsample()
29 | target_V = self.critic_target.q(next_obs, action_out, next_subtask)
30 |
31 | if self.raw_env_reward and next_subtask == self.env_params['last_subtask']:
32 | target_Q = self.reward_scale * reward + (self.discount * raw_reward)
33 | else:
34 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach()
35 |
36 | current_Q1, current_Q2 = self.critic(obs, action, subtask)
37 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
38 |
39 | # Optimize critic loss
40 | self.critic_optimizer[subtask].zero_grad()
41 | critic_loss.backward()
42 | self.critic_optimizer[subtask].step()
43 |
44 | metrics = AttrDict(
45 | critic_target_q=target_Q.mean().item(),
46 | critic_q=current_Q1.mean().item(),
47 | critic_loss=critic_loss.item()
48 | )
49 | return prefix_dict(metrics, subtask + '_')
50 |
51 | def update_actor(self, obs, action, subtask):
52 | # compute log probability
53 | dist = self.actor(obs, subtask)
54 | log_probs = dist.log_prob(action).sum(-1, keepdim=True)
55 |
56 | # compute exponential weight
57 | weights = self._compute_weights(obs, action, subtask)
58 | actor_loss = -(log_probs * weights).sum()
59 |
60 | self.actor_optimizer[subtask].zero_grad()
61 | actor_loss.backward()
62 | self.actor_optimizer[subtask].step()
63 |
64 | metrics = AttrDict(
65 | log_probs=log_probs.mean(),
66 | actor_loss=actor_loss.item()
67 | )
68 | return prefix_dict(metrics, subtask + '_')
69 |
70 | def update(self, replay_buffer, demo_buffer):
71 | metrics = AttrDict()
72 |
73 | for i in range(self.update_epoch):
74 | for subtask in self.env_params['middle_subtasks']:
75 | # sample from replay buffer
76 | obs, action, reward, next_obs, done, sl_norm_next_obs, raw_reward = self.get_samples(replay_buffer, subtask)
77 | action = action / self.max_action
78 |
79 | # update critic and actor
80 | metrics.update(self.update_critic(obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask))
81 | metrics.update(self.update_actor(obs, action, subtask))
82 |
83 | # sample from demo buffer
84 | obs, action, reward, next_obs, done, sl_norm_next_obs, raw_reward = self.get_samples(demo_buffer, subtask)
85 | action = action / self.max_action
86 |
87 | # update critic and actor
88 | metrics.update(self.update_critic(obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask))
89 | metrics.update(self.update_actor(obs, action, subtask))
90 |
91 | # update target critic and actor
92 | self.update_target()
93 |
94 | return metrics
95 |
96 | def _compute_weights(self, obs, act, subtask):
97 | with torch.no_grad():
98 | batch_size = obs.shape[0]
99 |
100 | # compute action-value
101 | q_values = self.critic.q(obs, act, subtask)
102 |
103 | # sample actions
104 | policy_actions = self.actor.sample_n(obs, subtask, self.n_action_samples)
105 | flat_actions = policy_actions.reshape(-1, self.dima)
106 |
107 | # repeat observation
108 | reshaped_obs = obs.view(batch_size, 1, *obs.shape[1:])
109 | reshaped_obs = reshaped_obs.expand(batch_size, self.n_action_samples, *obs.shape[1:])
110 | flat_obs = reshaped_obs.reshape(-1, *obs.shape[1:])
111 |
112 | # compute state-value
113 | flat_v_values = self.critic.q(flat_obs, flat_actions, subtask)
114 | reshaped_v_values = flat_v_values.view(obs.shape[0], -1, 1)
115 | v_values = reshaped_v_values.mean(dim=1)
116 |
117 | # compute normalized weight
118 | adv_values = (q_values - v_values).view(-1)
119 | weights = F.softmax(adv_values / self.lam, dim=0).view(-1, 1)
120 |
121 | return weights * adv_values.numel()
122 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ViSkill: Value-Informed Skill Chaining for Policy Learning of Long-Horizon Tasks with Surgical Robot
2 | This is the official PyTorch implementation of the paper "[**Value-Informed Skill Chaining for Policy Learning of Long-Horizon Tasks with Surgical Robot**](https://arxiv.org/pdf/2307.16503.pdf)" (IROS 2023).
3 |
6 | ![ViSkill teaser](docs/resources/viskill_teaser.png)
7 |
11 | # Prerequisites
12 | * Ubuntu 18.04
13 | * Python 3.7+
14 |
15 |
16 | # Installation Instructions
17 |
18 | 1. Clone this repository.
19 | ```bash
20 | git clone --recursive https://github.com/med-air/ViSkill.git
21 | cd ViSkill
22 | git submodule update --init --recursive
23 | ```
24 |
25 | 2. Create a virtual environment
26 | ```bash
27 | conda create -n viskill python=3.8
28 | conda activate viskill
29 | ```
30 |
31 | 3. Install packages
32 |
33 | ```bash
34 | pip3 install -e SurRoL/ # install surrol environments
35 | pip3 install -r requirements.txt
36 | pip3 install -e .
37 | ```
38 |
39 | 4. Add one line of code at the top of the installed gym package's `envs/__init__.py` to register the SurRoL tasks:
40 |
41 | ```python
42 | # file: anaconda3/envs/viskill/lib/python3.8/site-packages/gym/envs/__init__.py
43 | import surrol.gym
44 | ```
45 |
46 | # Usage
47 | Commands for running ViSkill. Results will be logged to WandB. Before running the commands below, please change the WandB entity (the `entity_name` field, and optionally `project_name`) in [```skill_learning.yaml```](viskill/configs/skill_learning.yaml#L33) and [```skill_chaining.yaml```](viskill/configs/skill_chaining.yaml#L39) to match your account.
48 |
49 | We collect demonstration data for each subtask via the scripted controllers provided by SurRoL. Take the BiPegTransfer task as an example:
50 | ```bash
51 | mkdir SurRoL/surrol/data/demo
52 | python SurRoL/surrol/data/data_generation_bipegtransfer.py --env BiPegTransfer-v0 --subtask grasp
53 | python SurRoL/surrol/data/data_generation_bipegtransfer.py --env BiPegTransfer-v0 --subtask handover
54 | python SurRoL/surrol/data/data_generation_bipegtransfer.py --env BiPegTransfer-v0 --subtask release
55 | ```
56 | ## Training Commands
57 |
58 | - Train subtask policies:
59 | ```bash
60 | mpirun -np 8 python -m train_sl seed=1 subtask=grasp
61 | mpirun -np 8 python -m train_sl seed=1 subtask=handover
62 | mpirun -np 8 python -m train_sl seed=1 subtask=release
63 | ```
64 |
65 | - Train chaining policies:
66 | ```bash
67 | mpirun -np 8 python -m train_sc model_seed=1 task=BiPegTransfer-v0
68 | ```
69 |
70 |
71 | # Starting to Modify the Code
72 | ## Modifying the hyperparameters
73 | The default hyperparameters are defined in `viskill/configs`, where [```skill_learning.yaml```](viskill/configs/skill_learning.yaml) and [```skill_chaining.yaml```](viskill/configs/skill_chaining.yaml) define the experiment settings for learning subtask policies and chaining policies, respectively, and the YAML files in the [```sl_agent```](viskill/configs/sl_agent) and [```sc_agent```](viskill/configs/sc_agent) directories define the hyperparameters of each method. These parameters can be modified directly in the experiment or agent config files, or overridden on the command line via Hydra.
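For instance (the values here are only illustrative), a command such as `python -m train_sc sc_agent=sc_sac seed=3 sc_agent.actor_lr=5e-5` would switch the chaining agent to SAC, change the seed, and override the actor learning rate without editing any config file.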
74 |
75 | ## Adding a new RL algorithm
76 | The core RL algorithms are implemented in agent classes that build on the `BaseAgent` class. To add a new skill-chaining algorithm, create a new file in
77 | `viskill/agents` and subclass [```BaseAgent```](viskill/agents/base.py#L8). In particular, any required
78 | networks (actor, critic, etc.) need to be constructed, and the `update(...)` and `get_action(...)` functions need to be overridden. Once the implementation is done, register the new agent in [```factory.py```](viskill/agents/factory.py) and add a config file in [```sc_agent```](viskill/configs/sc_agent) that specifies its model parameters; a minimal sketch is given below.
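A minimal sketch of such a subclass, assuming a hypothetical `SkillChainingMyAlgo` agent (the constructor signature follows `make_sc_agent` in `factory.py`; network construction and the actual loss computation are omitted):

```python
# viskill/agents/sc_my_algo.py -- hypothetical file name
from ..utils.general_utils import AttrDict
from .base import BaseAgent


class SkillChainingMyAlgo(BaseAgent):
    """Sketch of a new skill-chaining agent; only the required interface is shown."""

    def __init__(self, env_params, agent_cfg, sl_agent):
        super().__init__()
        self.env_params = env_params
        self.cfg = agent_cfg
        self.sl_agent = sl_agent
        # build actor/critic networks here, e.g. with modules from viskill/modules

    def get_action(self, obs, subtask, noise=False):
        # return an action for the given subtask (optionally with exploration noise)
        raise NotImplementedError

    def update(self, replay_buffer, demo_buffer=None):
        # sample batches, compute losses, step the optimizers, and return a metrics dict
        return AttrDict()
```

The new class would then be added to the `AGENTS` dict in [```factory.py```](viskill/agents/factory.py) under a key with the `SC_` prefix (e.g. `'SC_MY_ALGO'`), and a matching config file such as `viskill/configs/sc_agent/sc_my_algo.yaml` with `name: SC_MY_ALGO` would select it via `sc_agent=sc_my_algo`.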
79 |
80 | # Code Navigation
81 |
82 | ```
83 | viskill
84 | |- agents # implements core algorithms in agent classes
85 | |- components # reusable infrastructure for model training
86 | | |- checkpointer.py # handles saving + loading of model checkpoints
87 | | |- environment.py # environment wrappers for SurRoL environments
88 | | |- normalizer.py # normalizer for vectorized input
89 | | |- logger.py # implements core logging functionality using wandB
90 | |
91 | |- configs # experiment configs
92 | | |- skill_learning.yaml # configs for subtask policy learning
93 | | |- skill_chaining.yaml # configs for chaining policy learning
94 | | |- sl_agent # configs for demonstration-guided RL algorithm
95 | | |- sc_agent # configs for skill chaining algorithm
96 | | |- hier_agent # configs for (synthetic) hierarchical agent
97 | |
98 | |- modules # reusable architecture components
99 | | |- critics.py # basic critic implementations (eg MLP-based critic)
100 | | |- distributions.py # pytorch distribution utils for density model
101 | | |- policies.py # basic actor implementations
102 | | |- replay_buffer.py # her replay buffer with future sampling strategy
103 | | |- sampler.py # rollout sampler for collecting experience
104 | | |- subnetworks.py # basic networks
105 | |
106 | |- trainers # main model training script, builds all components + runs training loop and logging
107 | |
108 | |- utils # general and rl utilities, pytorch / visualization utilities etc
109 | |- train_sl.py # experiment launcher for skill learning
110 | |- train_sc.py # experiment launcher for skill chaining
111 | ```
112 |
113 | # Contact
114 | For any questions, please feel free to email taou.cs13@gmail.com
115 |
--------------------------------------------------------------------------------
/viskill/agents/sl_ddpgbc.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from ..components.normalizer import Normalizer
8 | from ..modules.critics import Critic
9 | from ..modules.policies import DeterministicActor
10 | from .base import BaseAgent
11 |
12 |
13 | class SkillLearningDDPGBC(BaseAgent):
14 | def __init__(
15 | self,
16 | env_params,
17 | sampler,
18 | agent_cfg,
19 | ):
20 | super().__init__()
21 |
22 | self.discount = agent_cfg.discount
23 | self.reward_scale = agent_cfg.reward_scale
24 | self.update_epoch = agent_cfg.update_epoch
25 |         self.her_sampler = sampler # same sampler as used in the replay buffer
26 | self.device = agent_cfg.device
27 |
28 | self.noise_eps = agent_cfg.noise_eps
29 | self.aux_weight = agent_cfg.aux_weight
30 | self.p_dist = agent_cfg.p_dist
31 | self.soft_target_tau = agent_cfg.soft_target_tau
32 |
33 | self.clip_obs = agent_cfg.clip_obs
34 | self.norm_clip = agent_cfg.norm_clip
35 | self.norm_eps = agent_cfg.norm_eps
36 |
37 | self.dima = env_params['act']
38 | self.dimo, self.dimg = env_params['obs'], env_params['goal']
39 |
40 | self.max_action = env_params['max_action']
41 | self.act_sampler = env_params['act_rand_sampler']
42 |
43 |         # normalizers
44 | self.o_norm = Normalizer(
45 | size=self.dimo,
46 | default_clip_range=self.norm_clip,
47 | eps=agent_cfg.norm_eps
48 | )
49 | self.g_norm = Normalizer(
50 | size=self.dimg,
51 | default_clip_range=self.norm_clip,
52 | eps=agent_cfg.norm_eps
53 | )
54 |
55 | # build policy
56 | self.actor = DeterministicActor(
57 | self.dimo+self.dimg, self.dima, agent_cfg.hidden_dim).to(agent_cfg.device)
58 | self.actor_target = copy.deepcopy(self.actor).to(agent_cfg.device)
59 |
60 | self.critic = Critic(
61 | self.dimo+self.dimg+self.dima, agent_cfg.hidden_dim).to(agent_cfg.device)
62 | self.critic_target = copy.deepcopy(self.critic).to(agent_cfg.device)
63 |
64 | # optimizer
65 | self.actor_optimizer = torch.optim.Adam(
66 | self.actor.parameters(), lr=agent_cfg.actor_lr
67 | )
68 | self.critic_optimizer = torch.optim.Adam(
69 | self.critic.parameters(), lr=agent_cfg.critic_lr
70 | )
71 |
72 | def get_action(self, state, noise=False):
73 | # random action at initial stage
74 | with torch.no_grad():
75 | o, g = state['observation'], state['desired_goal']
76 | input_tensor = self._preproc_inputs(o, g)
77 | action = self.actor(input_tensor).cpu().data.numpy().flatten()
78 |
79 | # Gaussian noise
80 | if noise:
81 | action = (action + self.max_action * self.noise_eps * np.random.randn(action.shape[0])).clip(
82 | -self.max_action, self.max_action)
83 |
84 | return action
85 |
86 | def update_critic(self, obs, action, reward, next_obs):
87 | metrics = dict()
88 |
89 | with torch.no_grad():
90 | action_out = self.actor_target(next_obs)
91 | target_V = self.critic_target(next_obs, action_out)
92 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach()
93 |
94 | clip_return = 1 / (1 - self.discount)
95 | target_Q = torch.clamp(target_Q, -clip_return, 0).detach()
96 |
97 | Q = self.critic(obs, action)
98 | critic_loss = F.mse_loss(Q, target_Q)
99 |
100 | # optimize critic loss
101 | self.critic_optimizer.zero_grad()
102 | critic_loss.backward()
103 | self.critic_optimizer.step()
104 |
105 | metrics['critic_target_q'] = target_Q.mean().item()
106 | metrics['critic_q'] = Q.mean().item()
107 | metrics['critic_loss'] = critic_loss.item()
108 | return metrics
109 |
110 | def update_actor(self, obs, action, is_demo=False):
111 | metrics = dict()
112 |
113 | action_out = self.actor(obs)
114 | Q_out = self.critic(obs, action_out)
115 |
116 | # refer to https://arxiv.org/pdf/1709.10089.pdf
117 | if is_demo:
118 | bc_loss = self.norm_dist(action_out, action)
119 | # q-filter
120 | with torch.no_grad():
121 | q_filter = self.critic_target(obs, action) >= self.critic_target(obs, action_out)
122 | bc_loss = q_filter * bc_loss
123 | actor_loss = -(Q_out + self.aux_weight * bc_loss).mean()
124 | else:
125 | actor_loss = -(Q_out).mean()
126 |
127 | actor_loss += action_out.pow(2).mean()
128 |
129 | # optimize actor loss
130 | self.actor_optimizer.zero_grad()
131 | actor_loss.backward()
132 | self.actor_optimizer.step()
133 |
134 | metrics['actor_loss'] = actor_loss.item()
135 | return metrics
136 |
137 | def update(self, replay_buffer, demo_buffer):
138 | metrics = dict()
139 |
140 | for i in range(self.update_epoch):
141 | # sample from replay buffer
142 | obs, action, reward, done, next_obs = self.get_samples(replay_buffer)
143 |
144 |             # update critic and actor
145 | metrics.update(self.update_critic(obs, action, reward, next_obs))
146 | metrics.update(self.update_actor(obs, action))
147 |
148 | # sample from demo buffer
149 | obs, action, reward, done, next_obs = self.get_samples(demo_buffer)
150 |
151 | # update critic and actor
152 | self.update_critic(obs, action, reward, next_obs)
153 | self.update_actor(obs, action, is_demo=True)
154 |
155 | # update target critic and actor
156 | self.update_target()
157 | return metrics
158 |
159 | def update_target(self):
160 | # update the frozen target models
161 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
162 | target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data)
163 |
164 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
165 | target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data)
166 |
167 | def norm_dist(self, a1, a2):
168 | self.p_dist = np.inf if self.p_dist == -1 else self.p_dist
169 | return - torch.norm(a1 - a2, p=self.p_dist, dim=1, keepdim=True).pow(2) / self.dima
--------------------------------------------------------------------------------
/viskill/agents/sl_dex.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from ..utils.general_utils import AttrDict
6 | from .sl_ddpgbc import SkillLearningDDPGBC
7 |
8 |
9 | class SkillLearningDEX(SkillLearningDDPGBC):
10 | def __init__(
11 | self,
12 | env_params,
13 | sampler,
14 | agent_cfg,
15 | ):
16 | super().__init__(env_params, sampler, agent_cfg)
17 | self.k = 5
18 |
19 | def get_samples(self, replay_buffer):
20 |         '''Additionally sample the next action for guidance propagation'''
21 | transitions = replay_buffer.sample()
22 |
23 | # preprocess
24 | o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
25 | transitions['obs'], transitions['g'] = self._preproc_og(o, g)
26 | transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
27 |
28 | obs_norm = self.o_norm.normalize(transitions['obs'])
29 | g_norm = self.g_norm.normalize(transitions['g'])
30 | inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
31 |
32 | obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
33 | g_next_norm = self.g_norm.normalize(transitions['g_next'])
34 | inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
35 |
36 | obs = self.to_torch(inputs_norm)
37 | next_obs = self.to_torch(inputs_next_norm)
38 | action = self.to_torch(transitions['actions'])
39 | next_action = self.to_torch(transitions['next_actions'])
40 | reward = self.to_torch(transitions['r'])
41 | done = self.to_torch(transitions['dones'])
42 | return obs, action, reward, done, next_obs, next_action
43 |
44 | def update_critic(self, obs, action, reward, next_obs, next_obs_demo, next_action_demo):
45 | with torch.no_grad():
46 | next_action_out = self.actor_target(next_obs)
47 | target_V = self.critic_target(next_obs, next_action_out)
48 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach()
49 |
50 | # exploration guidance
51 | topk_actions = self.compute_propagated_actions(next_obs, next_obs_demo, next_action_demo)
52 | act_dist = self.norm_dist(topk_actions, next_action_out)
53 | target_Q += self.aux_weight * act_dist
54 |
55 | clip_return = 5 / (1 - self.discount)
56 | target_Q = torch.clamp(target_Q, -clip_return, 0).detach()
57 |
58 | Q = self.critic(obs, action)
59 | critic_loss = F.mse_loss(Q, target_Q)
60 |
61 | # optimize critic loss
62 | self.critic_optimizer.zero_grad()
63 | critic_loss.backward()
64 | self.critic_optimizer.step()
65 |
66 | metrics = AttrDict(
67 | critic_q=Q.mean().item(),
68 | critic_target_q=target_Q.mean().item(),
69 | critic_loss=critic_loss.item(),
70 |             batch_reward=reward.mean().item()
71 | )
72 | return metrics
73 |
74 | def update_actor(self, obs, obs_demo, action_demo):
75 | action_out = self.actor(obs)
76 | Q_out = self.critic(obs, action_out)
77 |
78 | topk_actions = self.compute_propagated_actions(obs, obs_demo, action_demo)
79 | act_dist = self.norm_dist(action_out, topk_actions)
80 | actor_loss = -(Q_out + self.aux_weight * act_dist).mean()
81 | actor_loss += action_out.pow(2).mean()
82 |
83 | # optimize actor loss
84 | self.actor_optimizer.zero_grad()
85 | actor_loss.backward()
86 | self.actor_optimizer.step()
87 |
88 | metrics = AttrDict(
89 | actor_loss=actor_loss.item(),
90 | act_dist=act_dist.mean().item()
91 | )
92 | return metrics
93 |
94 | def update(self, replay_buffer, demo_buffer):
95 | for i in range(self.update_epoch):
96 | # sample from replay buffer
97 | obs, action, reward, done, next_obs, next_action = self.get_samples(replay_buffer)
98 | obs_, action_, reward_, done_, next_obs_, next_action_ = self.get_samples(demo_buffer)
99 |
100 | with torch.no_grad():
101 | next_action_out = self.actor_target(next_obs)
102 | target_V = self.critic_target(next_obs, next_action_out)
103 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach()
104 |
105 | l2_pair = torch.cdist(next_obs, next_obs_)
106 | topk_value, topk_indices = l2_pair.topk(self.k, dim=1, largest=False)
107 | topk_weight = F.softmin(topk_value.sqrt(), dim=1)
108 | topk_actions = torch.ones_like(next_action_)
109 |
110 |                 for j in range(topk_actions.size(0)):
111 |                     topk_actions[j] = torch.mm(topk_weight[j].unsqueeze(0), next_action_[topk_indices[j]]).squeeze(0)
112 | intr = self.norm_dist(topk_actions, next_action_out)
113 | target_Q += self.aux_weight * intr
114 |                 next_action_out_ = self.actor_target(next_obs_)
115 | target_V_ = self.critic_target(next_obs_, next_action_out_)
116 | target_Q_ = self.reward_scale * reward_ + (self.discount * target_V_).detach()
117 | intr_ = self.norm_dist(next_action_, next_action_out_)
118 | target_Q_ += self.aux_weight * intr_
119 |
120 | clip_return = 5 / (1 - self.discount)
121 | target_Q = torch.clamp(target_Q, -clip_return, 0).detach()
122 | target_Q_ = torch.clamp(target_Q_, -clip_return, 0).detach()
123 |
124 |
125 | Q = self.critic(obs, action)
126 | Q_ = self.critic(obs_, action_)
127 | critic_loss = F.mse_loss(Q, target_Q) + F.mse_loss(Q_, target_Q_)
128 |
129 | # optimize critic loss
130 | self.critic_optimizer.zero_grad()
131 | critic_loss.backward()
132 | self.critic_optimizer.step()
133 |
134 | action_out = self.actor(obs)
135 | action_out_ = self.actor(obs_)
136 | Q_out = self.critic(obs, action_out)
137 | Q_out_ = self.critic(obs_, action_out_)
138 |
139 | with torch.no_grad():
140 | l2_pair = torch.cdist(obs, obs_)
141 | topk_value, topk_indices = l2_pair.topk(self.k, dim=1, largest=False)
142 | topk_weight = F.softmin(topk_value.sqrt(), dim=1)
143 | topk_actions = torch.ones_like(action)
144 |
145 |                 for j in range(topk_actions.size(0)):
146 |                     topk_actions[j] = torch.mm(topk_weight[j].unsqueeze(0), action_[topk_indices[j]]).squeeze(0)
147 |
148 | intr2 = self.norm_dist(action_out, topk_actions)
149 | intr3 = self.norm_dist(action_out_, action_)
150 |
151 | # Refer to https://arxiv.org/pdf/1709.10089.pdf
152 | actor_loss = - (Q_out + self.aux_weight * intr2).mean()
153 | actor_loss += -(Q_out_ + self.aux_weight * intr3).mean()
154 |
155 | actor_loss += action_out.pow(2).mean()
156 | actor_loss += action_out_.pow(2).mean()
157 |
158 | #actor_loss += self.action_l2 * action_out.pow(2).mean()
159 |
160 | # Optimize actor loss
161 | self.actor_optimizer.zero_grad()
162 | actor_loss.backward()
163 | self.actor_optimizer.step()
164 |
165 | self.update_target()
166 |
167 | metrics = AttrDict(
168 | batch_reward=reward.mean().item(),
169 | critic_q=Q.mean().item(),
170 | critic_q_=Q_.mean().item(),
171 | critic_target_q=target_Q.mean().item(),
172 | critic_loss=critic_loss.item(),
173 | actor_loss=actor_loss.item()
174 | )
175 | return metrics
--------------------------------------------------------------------------------
/viskill/modules/sampler.py:
--------------------------------------------------------------------------------
1 | from ..utils.general_utils import AttrDict, listdict2dictlist
2 | from ..utils.rl_utils import ReplayCache, ReplayCacheGT
3 |
4 |
5 | class Sampler:
6 | """Collects rollouts from the environment using the given agent."""
7 | def __init__(self, env, agent, max_episode_len):
8 | self._env = env
9 | self._agent = agent
10 | self._max_episode_len = max_episode_len
11 |
12 | self._obs = None
13 | self._episode_step = 0
14 | self._episode_cache = ReplayCacheGT(max_episode_len)
15 |
16 | def init(self):
17 |         """Starts a new rollout."""
18 | self._episode_reset()
19 |
20 | def sample_action(self, obs, is_train):
21 | return self._agent.get_action(obs, noise=is_train)
22 |
23 | def sample_episode(self, is_train, render=False):
24 | """Samples one episode from the environment."""
25 | self.init()
26 | episode, done = [], False
27 | while not done and self._episode_step < self._max_episode_len:
28 | action = self.sample_action(self._obs, is_train)
29 | if action is None:
30 | break
31 | if render:
32 | render_obs = self._env.render('rgb_array')
33 | obs, reward, done, info = self._env.step(action)
34 | episode.append(AttrDict(
35 | reward=reward,
36 | success=info['is_success'],
37 | info=info
38 | ))
39 | self._episode_cache.store_transition(obs, action, done, info['gt_goal'])
40 | if render:
41 | episode[-1].update(AttrDict(image=render_obs))
42 |
43 | # update stored observation
44 | self._obs = obs
45 | self._episode_step += 1
46 |
47 | episode[-1].done = True # make sure episode is marked as done at final time step
48 | rollouts = self._episode_cache.pop()
49 | assert self._episode_step == self._max_episode_len
50 | return listdict2dictlist(episode), rollouts, self._episode_step
51 |
52 | def _episode_reset(self, global_step=None):
53 | """Resets sampler at the end of an episode."""
54 | self._episode_step, self._episode_reward = 0, 0.
55 | self._obs = self._reset_env()
56 | self._episode_cache.store_obs(self._obs)
57 |
58 | def _reset_env(self):
59 | return self._env.reset()
60 |
61 |
62 | class HierarchicalSampler(Sampler):
63 | """Collects experience batches by rolling out a hierarchical agent. Aggregates low-level batches into HL batch."""
64 | def __init__(self, env, agent, env_params):
65 | super().__init__(env, agent, env_params['max_timesteps'])
66 |
67 | self._env_params = env_params
68 | self._episode_cache = AttrDict(
69 | {subtask: ReplayCache(steps) for subtask, steps in env_params.subtask_steps.items()})
70 |
71 | def sample_episode(self, is_train, render=False):
72 | """Samples one episode from the environment."""
73 | self.init()
74 | sc_transitions = AttrDict({subtask: [] for subtask in self._env_params.subtasks})
75 | sc_succ_transitions = AttrDict({subtask: [] for subtask in self._env_params.subtasks})
76 | sc_episode, sl_episode, done, prev_subtask_succ = [], AttrDict(), False, AttrDict()
77 | while not done and self._episode_step < self._max_episode_len:
78 | agent_output = self.sample_action(self._obs, is_train, self._env.subtask)
79 | if self.last_sc_action is None:
80 | self._episode_cache[self._env.subtask].store_obs(self._obs)
81 |
82 | if render:
83 | render_obs = self._env.render('rgb_array')
84 | if agent_output.is_sc_step:
85 | self.last_sc_action = agent_output.sc_action
86 | self.reward_since_last_sc = 0
87 |
88 | obs, reward, done, info = self._env.step(agent_output.sl_action)
89 | self.reward_since_last_sc += reward
90 | if info['subtask_done']:
91 | if not done:
92 | # store skill-chaining transition
93 | sc_transitions[info['subtask']].append(
94 | [self.last_sc_obs, self.last_sc_action, self.reward_since_last_sc, obs['observation'], done, obs['desired_goal']])
95 |
96 | if info['subtask_is_success']:
97 | sc_succ_transitions[info['subtask']].append(
98 | [self.last_sc_obs, self.last_sc_action, self.reward_since_last_sc, obs['observation'], done, obs['desired_goal']])
99 | else:
100 | sc_succ_transitions[info['subtask']].append([None])
101 |
102 | # middle subtask
103 | self._episode_cache[self._env.subtask].store_obs(obs)
104 | self._episode_cache[self._env.prev_subtasks[self._env.subtask]].\
105 | store_transition(obs, agent_output.sl_action, True)
106 | self.last_sc_obs = obs['observation']
107 | else:
108 | # terminal subtask
109 | sc_transitions[info['subtask']] = []
110 | sc_transitions[info['subtask']].append(
111 | [self.last_sc_obs, self.last_sc_action, self.reward_since_last_sc, obs['observation'], done, obs['desired_goal']])
112 | if info['subtask_is_success']:
113 | sc_succ_transitions[info['subtask']].append(
114 | [self.last_sc_obs, self.last_sc_action, self.reward_since_last_sc, obs['observation'], done, obs['desired_goal']])
115 | else:
116 | sc_succ_transitions[info['subtask']].append([None])
117 | self._episode_cache[self._env.subtask].store_transition(obs, agent_output.sl_action, True)
118 | prev_subtask_succ[self._env.subtask] = info['subtask_is_success']
119 | else:
120 | self._episode_cache[self._env.subtask].store_transition(obs, agent_output.sl_action, False)
121 | sc_episode.append(AttrDict(
122 | reward=reward,
123 | success=info['is_success'],
124 | info=info))
125 | if render:
126 | sc_episode[-1].update(AttrDict(image=render_obs))
127 |
128 | # update stored observation
129 | self._obs = obs
130 | self._episode_step += 1
131 |
132 | assert self._episode_step == self._max_episode_len
133 | for subtask in self._env_params.subtasks:
134 | if subtask not in prev_subtask_succ.keys():
135 | sl_episode[subtask] = self._episode_cache[subtask].pop()
136 | continue
137 | if prev_subtask_succ[subtask]:
138 | sl_episode[subtask] = self._episode_cache[subtask].pop()
139 | else:
140 | self._episode_cache[subtask].pop()
141 |
142 | sc_episode = listdict2dictlist(sc_episode)
143 | sc_episode.update(AttrDict(
144 | sc_transitions=sc_transitions,
145 | sc_succ_transitions=sc_succ_transitions)
146 | )
147 |
148 | return sc_episode, sl_episode, self._episode_step
149 |
150 | def _episode_reset(self, global_step=None):
151 | """Resets sampler at the end of an episode."""
152 | self._episode_step, self._episode_reward = 0, 0.
153 | self._obs = self._reset_env()
154 | self.last_sc_obs, self.last_sc_action = self._obs['observation'], None # stores observation when last hl action was taken
155 | self.reward_since_last_sc = 0 # accumulates the reward since the last HL step for HL transition
156 |
157 | def sample_action(self, obs, is_train, subtask):
158 | return self._agent.get_action(obs, subtask, noise=is_train)
--------------------------------------------------------------------------------
/viskill/utils/rl_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | from .general_utils import AttrDict, RecursiveAverageMeter
6 |
7 |
8 | def get_env_params(env, cfg):
9 | obs = env.reset()
10 | env_params = AttrDict(
11 | obs=obs['observation'].shape[0],
12 | achieved_goal=obs['achieved_goal'].shape[0],
13 | goal=obs['desired_goal'].shape[0],
14 | act=env.action_space.shape[0],
15 | act_rand_sampler=env.action_space.sample,
16 | max_timesteps=env.max_episode_steps,
17 | max_action=env.action_space.high[0],
18 | )
19 | if cfg.skill_chaining:
20 | env_params.update(AttrDict(
21 |             act_sc=obs['achieved_goal'].shape[0] - env.len_cond, # without contact condition
22 | max_action_sc=env.max_action_range,
23 | adaptor_sc=env.goal_adapator,
24 | subtask_order=env.subtask_order,
25 | num_subtasks=len(env.subtask_order),
26 | subtask_steps=env.subtask_steps,
27 | subtasks=env.subtasks,
28 | next_subtasks=env.next_subtasks,
29 | prev_subtasks=env.prev_subtasks,
30 | middle_subtasks=env.next_subtasks.keys(),
31 | last_subtask=env.last_subtask,
32 | reward_funcs=env.get_reward_functions(),
33 | len_cond=env.len_cond
34 | ))
35 | return env_params
36 |
37 |
38 | class ReplayCache:
39 | def __init__(self, T):
40 | self.T = T
41 | self.reset()
42 |
43 | def reset(self):
44 | self.t = 0
45 | self.obs, self.ag, self.g, self.actions, self.dones = [], [], [], [], []
46 |
47 | def store_transition(self, obs, action, done):
48 | self.obs.append(obs['observation'])
49 | self.ag.append(obs['achieved_goal'])
50 | self.g.append(obs['desired_goal'])
51 | self.actions.append(action)
52 | self.dones.append(done)
53 |
54 | def store_obs(self, obs):
55 | self.obs.append(obs['observation'])
56 | self.ag.append(obs['achieved_goal'])
57 |
58 | def pop(self):
59 | assert len(self.obs) == self.T + 1 and len(self.actions) == self.T
60 | obs = np.expand_dims(np.array(self.obs.copy()),axis=0)
61 | ag = np.expand_dims(np.array(self.ag.copy()), axis=0)
62 | #print(self.ag)
63 | g = np.expand_dims(np.array(self.g.copy()), axis=0)
64 | actions = np.expand_dims(np.array(self.actions.copy()), axis=0)
65 | dones = np.expand_dims(np.array(self.dones.copy()), axis=1)
66 | dones = np.expand_dims(dones, axis=0)
67 |
68 | self.reset()
69 | episode = AttrDict(obs=obs, ag=ag, g=g, actions=actions, dones=dones)
70 | return episode
71 |
72 |
73 | class ReplayCacheGT(ReplayCache):
74 | def reset(self):
75 | self.t = 0
76 | self.obs, self.ag, self.g, self.actions, self.dones, self.gt_g = [], [], [], [], [], []
77 |
78 | def store_transition(self, obs, action, done, gt_goal):
79 | self.obs.append(obs['observation'])
80 | self.ag.append(obs['achieved_goal'])
81 | self.g.append(obs['desired_goal'])
82 | self.actions.append(action)
83 | self.dones.append(done)
84 | self.gt_g.append(gt_goal)
85 |
86 | def pop(self):
87 | assert len(self.obs) == self.T + 1 and len(self.actions) == self.T
88 | obs = np.expand_dims(np.array(self.obs.copy()),axis=0)
89 | ag = np.expand_dims(np.array(self.ag.copy()), axis=0)
90 | #print(self.ag)
91 | g = np.expand_dims(np.array(self.g.copy()), axis=0)
92 | actions = np.expand_dims(np.array(self.actions.copy()), axis=0)
93 | dones = np.expand_dims(np.array(self.dones.copy()), axis=1)
94 | dones = np.expand_dims(dones, axis=0)
95 | gt_g = np.expand_dims(np.array(self.gt_g.copy()), axis=0)
96 |
97 | self.reset()
98 | episode = AttrDict(obs=obs, ag=ag, g=g, actions=actions, dones=dones, gt_g=gt_g)
99 | return episode
100 |
101 |
102 | def init_demo_buffer(cfg, buffer, agent, subtask=None, update_normalizer=True):
103 |     '''Load demonstrations into the buffer and initialize the normalizer'''
104 | demo_path = os.path.join(os.getcwd(),'SurRoL/surrol/data/demo')
105 | file_name = "data_"
106 | file_name += cfg.task
107 | file_name += "_" + 'random'
108 | if subtask is None:
109 | file_name += "_" + str(cfg.num_demo) + '_primitive_new' + cfg.subtask
110 | else:
111 | file_name += "_" + str(cfg.num_demo) + '_primitive_new' + subtask
112 | file_name += ".npz"
113 |
114 | demo_path = os.path.join(demo_path, file_name)
115 | demo = np.load(demo_path, allow_pickle=True)
116 | demo_obs, demo_acs, demo_gt = demo['observations'], demo['actions'], demo['gt_actions']
117 |
118 | episode_cache = ReplayCacheGT(buffer.T)
119 | for epsd in range(cfg.num_demo):
120 | episode_cache.store_obs(demo_obs[epsd][0])
121 | for i in range(buffer.T):
122 | episode_cache.store_transition(
123 | obs=demo_obs[epsd][i+1],
124 | action=demo_acs[epsd][i],
125 | done=i==(buffer.T-1),
126 | gt_goal=demo_gt[epsd][i]
127 | )
128 | episode = episode_cache.pop()
129 | buffer.store_episode(episode)
130 | if update_normalizer:
131 | agent.update_normalizer(episode)
132 |
133 |
134 | def init_sc_buffer(cfg, buffer, agent, env_params):
135 |     '''Load demonstrations into the buffer and initialize the normalizer'''
136 | for subtask in env_params.subtasks:
137 | demo_path = os.path.join(os.getcwd(),'SurRoL/surrol/data/demo')
138 | file_name = "data_"
139 | file_name += cfg.task
140 | file_name += "_" + 'random'
141 | file_name += "_" + str(cfg.num_demo) + '_primitive_new' + subtask
142 | file_name += ".npz"
143 |
144 | demo_path = os.path.join(demo_path, file_name)
145 | demo = np.load(demo_path, allow_pickle=True)
146 | demo_obs, demo_acs, demo_gt = demo['observations'], demo['actions'], demo['gt_actions']
147 |
148 | for epsd in range(cfg.num_demo):
149 | obs = demo_obs[epsd][0]['observation']
150 | next_obs = demo_obs[epsd][-1]['observation']
151 | action = demo_obs[epsd][0]['desired_goal'][:-env_params.len_cond]
152 | # reward = sum([env_params.reward_funcs[subtask](demo_obs[epsd][i+1]['achieved_goal'], demo_obs[epsd][i+1]['desired_goal']) \
153 | # for i in range(len(demo_acs[epsd]))])
154 | reward = env_params.reward_funcs[subtask](demo_obs[epsd][-1]['achieved_goal'], demo_obs[epsd][-1]['desired_goal'])
155 | #print(subtask, epsd, reward)
156 | done = subtask not in env_params.next_subtasks.keys()
157 | reward = done * reward
158 | gt_action = demo_gt[epsd][-1]
159 | buffer[subtask].add(obs, action, reward, next_obs, done, gt_action)
160 | if agent.sc_agent.normalize:
161 | # TODO: hide normalized
162 | agent.sc_agent.o_norm[subtask].update(obs)
163 |
164 |
165 | class RolloutStorage:
166 | """Can hold multiple rollouts, can compute statistics over these rollouts."""
167 | def __init__(self):
168 | self.rollouts = []
169 |
170 | def append(self, rollout):
171 | """Adds rollout to storage."""
172 | self.rollouts.append(rollout)
173 |
174 | def rollout_stats(self):
175 | """Returns AttrDict of average statistics over the rollouts."""
176 | assert self.rollouts # rollout storage should not be empty
177 | stats = RecursiveAverageMeter()
178 | for rollout in self.rollouts:
179 | stats.update(AttrDict(
180 | avg_reward=np.stack(rollout.reward).sum(),
181 | avg_success_rate=rollout.success[-1],
182 | ))
183 | return stats.avg
184 |
185 | def reset(self):
186 | del self.rollouts
187 | self.rollouts = []
188 |
189 | def get(self):
190 | return self.rollouts
191 |
192 | def __contains__(self, key):
193 | return self.rollouts and key in self.rollouts[0]
--------------------------------------------------------------------------------
/viskill/agents/sc_ddpg.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from ..components.normalizer import Normalizer
8 | from ..modules.critics import SkillChainingCritic
9 | from ..modules.policies import SkillChainingActor
10 | from ..utils.general_utils import AttrDict
11 | from .base import BaseAgent
12 |
13 |
14 | class SkillChainingDDPG(BaseAgent):
15 | def __init__(
16 | self,
17 | env_params,
18 | agent_cfg,
19 | sl_agent,
20 | ):
21 | super().__init__()
22 |
23 | self.sl_agent = sl_agent
24 |
25 | self.discount = agent_cfg.discount
26 | self.reward_scale = agent_cfg.reward_scale
27 | self.update_epoch = agent_cfg.update_epoch
28 | self.device = agent_cfg.device
29 | self.env_params = env_params
30 |
31 | self.random_eps = agent_cfg.random_eps
32 | self.noise_eps = agent_cfg.noise_eps
33 | self.soft_target_tau = agent_cfg.soft_target_tau
34 |
35 | self.normalize = agent_cfg.normalize
36 | self.clip_obs = agent_cfg.clip_obs
37 | self.norm_clip = agent_cfg.norm_clip
38 | self.norm_eps = agent_cfg.norm_eps
39 | self.intr_reward = agent_cfg.intr_reward
40 | self.raw_env_reward = agent_cfg.raw_env_reward
41 |
42 |         self.dima = env_params['act_sc']
43 | self.dimo = env_params['obs']
44 | self.max_action = env_params['max_action_sc']
45 | self.goal_adapator = env_params['adaptor_sc']
46 |
47 |         # TODO: normalizer
48 | self.o_norm = {subtask: Normalizer(
49 | size=self.dimo,
50 | default_clip_range=self.norm_clip,
51 | eps=agent_cfg.norm_eps) for subtask in env_params['middle_subtasks']}
52 |
53 | # build policy
54 | self.actor = SkillChainingActor(
55 | in_dim=self.dimo,
56 | out_dim=self.dima,
57 | hidden_dim=agent_cfg.hidden_dim,
58 | max_action=self.max_action,
59 | middle_subtasks=env_params['middle_subtasks'],
60 | last_subtask=env_params['last_subtask']
61 | ).to(agent_cfg.device)
62 | self.actor_target = copy.deepcopy(self.actor).to(agent_cfg.device)
63 |
64 | self.critic = SkillChainingCritic(
65 | in_dim=self.dimo+self.dima,
66 | hidden_dim=agent_cfg.hidden_dim,
67 | middle_subtasks=env_params['middle_subtasks'],
68 | last_subtask=env_params['last_subtask']
69 | ).to(agent_cfg.device)
70 | self.critic_target = copy.deepcopy(self.critic).to(agent_cfg.device)
71 |
72 | # optimizer
73 | self.actor_optimizer = {subtask: torch.optim.Adam(
74 | self.actor.parameters(), lr=agent_cfg.actor_lr) for subtask in env_params['middle_subtasks']}
75 | self.critic_optimizer = {subtask: torch.optim.Adam(
76 | self.critic.parameters(), lr=agent_cfg.critic_lr) for subtask in env_params['middle_subtasks']}
77 |
78 | def init(self, task, sl_agent):
79 |         '''Initialize the actor, critic and normalizers of the last subtask.'''
80 | self.actor.init_last_subtask_actor(task, sl_agent[task].actor)
81 | self.actor_target.init_last_subtask_actor(task, sl_agent[task].actor_target)
82 | self.critic.init_last_subtask_q(task, sl_agent[task].critic)
83 | self.critic_target.init_last_subtask_q(task, sl_agent[task].critic_target)
84 | self.sl_normarlizer = {subtask: sl_agent[subtask]._preproc_inputs
85 | for subtask in self.env_params['subtasks']}
86 |
87 | def get_samples(self, replay_buffer, subtask):
88 | next_subtask = self.env_params['next_subtasks'][subtask]
89 |
90 | obs, action, reward, next_obs, done, gt_action, idxs = replay_buffer[subtask].sample(return_idxs=True)
91 | sl_norm_next_obs = self.sl_normarlizer[next_subtask](next_obs, gt_action, dim=1) # only for terminal subtask
92 | obs = self._preproc_obs(obs, subtask)
93 | next_obs = self._preproc_obs(next_obs, next_subtask)
94 | action = self.to_torch(action)
95 | reward = self.to_torch(reward)
96 | done = self.to_torch(done)
97 | gt_action = self.to_torch(gt_action)
98 |
99 | if next_subtask == self.env_params['last_subtask'] and self.raw_env_reward:
100 | assert len(replay_buffer[subtask]) == len(replay_buffer[next_subtask])
101 | _, _, raw_reward, _, _, _ = replay_buffer[next_subtask].sample(idxs=idxs)
102 | return obs, action, reward, next_obs, done, sl_norm_next_obs, self.to_torch(raw_reward)
103 |
104 | return obs, action, reward, next_obs, done, sl_norm_next_obs, None
105 |
106 | def get_action(self, state, subtask, noise=False):
107 | # random action at initial stage
108 | with torch.no_grad():
109 | input_tensor = self._preproc_obs(state, subtask)
110 | action = self.actor[subtask](input_tensor).cpu().data.numpy().flatten()
111 | # Gaussian noise
112 | if noise:
113 | action = (action + self.max_action * self.noise_eps * np.random.randn(action.shape[0])).clip(
114 | -self.max_action, self.max_action)
115 | return action
116 |
117 | def update_critic(self, obs, action, reward, next_obs):
118 | metrics = AttrDict()
119 |
120 | with torch.no_grad():
121 | action_out = self.actor_target(next_obs)
122 | target_V = self.critic_target(next_obs, action_out)
123 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach()
124 |
125 | clip_return = 1 / (1 - self.discount)
126 | target_Q = torch.clamp(target_Q, -clip_return, 0).detach()
127 |
128 | Q = self.critic(obs, action)
129 | critic_loss = F.mse_loss(Q, target_Q)
130 |
131 | # Optimize critic loss
132 | self.critic_optimizer.zero_grad()
133 | critic_loss.backward()
134 | self.critic_optimizer.step()
135 |
136 | metrics = AttrDict(
137 | critic_target_q=target_Q.mean().item(),
138 | critic_q=Q.mean().item(),
139 | critic_loss=critic_loss.item()
140 | )
141 | return metrics
142 |
143 | def update_actor(self, obs, action, is_demo=False):
144 | action_out = self.actor(obs)
145 | Q_out = self.critic(obs, action_out)
146 | actor_loss = -(Q_out).mean()
147 |
148 | # Optimize actor loss
149 | self.actor_optimizer.zero_grad()
150 | actor_loss.backward()
151 | self.actor_optimizer.step()
152 |
153 | metrics = AttrDict(
154 | actor_loss=actor_loss.item()
155 | )
156 | return metrics
157 |
158 | def update(self, replay_buffer):
159 | metrics = AttrDict()
160 |
161 | for i in range(self.update_epoch):
162 | # Sample from replay buffer
163 | obs, action, reward, next_obs, done = self.get_samples(replay_buffer)
164 | # Update critic and actor
165 | metrics.update(self.update_critic(obs, action, reward, next_obs))
166 | metrics.update(self.update_actor(obs, action))
167 |
168 | # Update target critic and actor
169 | self.update_target()
170 | return metrics
171 |
172 | def _preproc_obs(self, o, subtask):
173 | o = np.clip(o, -self.clip_obs, self.clip_obs)
174 | if self.normalize and subtask != self.env_params.last_subtask:
175 | o = self.o_norm[subtask].normalize(o)
176 | inputs = torch.tensor(o, dtype=torch.float32).to(self.device)
177 | return inputs
178 |
179 | def update_target(self):
180 | # Update the frozen target models
181 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
182 | target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data)
183 |
184 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
185 | target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data)
186 |
187 | def update_normalizer(self, rollouts, subtask):
188 | for transition in rollouts:
189 | obs, _, _, _, _, _ = transition
190 | # update
191 | self.o_norm[subtask].update(obs)
192 | # recompute the stats
193 | self.o_norm[subtask].recompute_stats()
--------------------------------------------------------------------------------
/viskill/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | import itertools
3 | import random
4 | import time
5 | from functools import reduce
6 |
7 | import numpy as np
8 | import torch
9 |
10 |
11 | def set_seed_everywhere(seed):
12 | torch.manual_seed(seed)
13 | if torch.cuda.is_available():
14 | torch.cuda.manual_seed_all(seed)
15 | np.random.seed(seed)
16 | random.seed(seed)
17 |
18 |
19 | class Until:
20 | def __init__(self, until, action_repeat=1):
21 | self._until = until
22 | self._action_repeat = action_repeat
23 |
24 | def __call__(self, step):
25 | if self._until is None:
26 | return True
27 | until = self._until // self._action_repeat
28 | return step < until
29 |
30 |
31 | class Every:
32 | def __init__(self, every, action_repeat=1):
33 | self._every = every
34 | self._action_repeat = action_repeat
35 |
36 | def __call__(self, step):
37 | if self._every is None:
38 | return False
39 | every = self._every // self._action_repeat
40 | if step % every == 0:
41 | return True
42 | return False
43 |
44 |
45 | class Timer:
46 | def __init__(self):
47 | self._start_time = time.time()
48 | self._last_time = time.time()
49 |
50 | def reset(self):
51 | elapsed_time = time.time() - self._last_time
52 | self._last_time = time.time()
53 | total_time = time.time() - self._start_time
54 | return elapsed_time, total_time
55 |
56 | def total_time(self):
57 | return time.time() - self._start_time
58 |
59 |
60 | class AverageMeter(object):
61 | """Computes and stores the average and current value"""
62 |
63 | def __init__(self, digits=None):
64 | """
65 | :param digits: number of digits returned for average value
66 | """
67 | self._digits = digits
68 | self.reset()
69 |
70 | def reset(self):
71 | self.val = 0
72 | self._avg = 0
73 | self.sum = 0
74 | self.count = 0
75 |
76 | def update(self, val, n=1):
77 | self.val = val
78 | self.sum += val * n
79 | self.count += n
80 | self._avg = self.sum / max(1, self.count)
81 |
82 | @property
83 | def avg(self):
84 | if self._digits is not None:
85 | return np.round(self._avg, self._digits)
86 | else:
87 | return self._avg
88 |
89 |
90 | class RecursiveAverageMeter(object):
91 | """Computes and stores the average and current value"""
92 | def __init__(self):
93 | self.reset()
94 |
95 | def reset(self):
96 | self.val = None
97 | self.avg = None
98 | self.sum = None
99 | self.count = 0
100 |
101 | def update(self, val):
102 | self.val = val
103 | if self.sum is None:
104 | self.sum = val
105 | else:
106 | self.sum = map_recursive_list(lambda x, y: x + y, [self.sum, val])
107 | self.count += 1
108 | self.avg = map_recursive(lambda x: x / self.count, self.sum)
109 |
110 |
111 |
112 | class AttrDict(dict):
113 | __setattr__ = dict.__setitem__
114 |
115 | def __getattr__(self, attr):
116 | # Take care that getattr() raises AttributeError, not KeyError.
117 | # Required e.g. for hasattr(), deepcopy and OrderedDict.
118 | try:
119 | return self.__getitem__(attr)
120 | except KeyError:
121 | raise AttributeError("Attribute %r not found" % attr)
122 |
123 | def __getstate__(self):
124 | return self
125 |
126 | def __setstate__(self, d):
127 | self = d
128 |
129 |
130 | def map_dict(fn, d):
131 | """takes a dictionary and applies the function to every element"""
132 | return type(d)(map(lambda kv: (kv[0], fn(kv[1])), d.items()))
133 |
134 |
135 | def map_recursive(fn, tensors):
136 | return make_recursive(fn)(tensors)
137 |
138 |
139 | def map_recursive_list(fn, tensors):
140 | return make_recursive_list(fn)(tensors)
141 |
142 |
143 | def make_recursive(fn, *argv, **kwargs):
144 | """ Takes a fn and returns a function that can apply fn on tensor structure
145 | which can be a single tensor, tuple or a list. """
146 |
147 | def recursive_map(tensors):
148 | if tensors is None:
149 | return tensors
150 | elif isinstance(tensors, list) or isinstance(tensors, tuple):
151 | return type(tensors)(map(recursive_map, tensors))
152 | elif isinstance(tensors, dict):
153 | return type(tensors)(map_dict(recursive_map, tensors))
154 | elif isinstance(tensors, torch.Tensor) or isinstance(tensors, np.ndarray):
155 | return fn(tensors, *argv, **kwargs)
156 | else:
157 | try:
158 | return fn(tensors, *argv, **kwargs)
159 | except Exception as e:
160 | print("The following error was raised when recursively applying a function:")
161 | print(e)
162 | raise ValueError("Type {} not supported for recursive map".format(type(tensors)))
163 |
164 | return recursive_map
165 |
166 |
167 | def make_recursive_list(fn):
168 | """ Takes a fn and returns a function that can apply fn across tuples of tensor structures,
169 | each of which can be a single tensor, tuple or a list. """
170 |
171 | def recursive_map(tensors):
172 | if tensors is None:
173 | return tensors
174 | elif isinstance(tensors[0], list) or isinstance(tensors[0], tuple):
175 | return type(tensors[0])(map(recursive_map, zip(*tensors)))
176 | elif isinstance(tensors[0], dict):
177 | return map_dict(recursive_map, listdict2dictlist(tensors))
178 | elif isinstance(tensors[0], torch.Tensor):
179 | return fn(*tensors)
180 | else:
181 | try:
182 | return fn(*tensors)
183 | except Exception as e:
184 | print("The following error was raised when recursively applying a function:")
185 | print(e)
186 | raise ValueError("Type {} not supported for recursive map".format(type(tensors)))
187 |
188 | return recursive_map
189 |
190 |
191 | def listdict2dictlist(LD):
192 | """ Converts a list of dicts to a dict of lists """
193 |
194 | # Take intersection of keys
195 | keys = reduce(lambda x,y: x & y, (map(lambda d: d.keys(), LD)))
196 | return AttrDict({k: [dic[k] for dic in LD] for k in keys})
197 |
198 |
199 | # def joinListDictList(LDL):
200 | # """Joins a list of dictionaries that contain lists."""
201 | # DLL = listdict2dictlist(LDL)
202 | # return type(LDL[0])({k: list(itertools.chain.from_iterable(DLL[k])) for k in DLL})
203 |
204 |
205 | def joinListDictList(LDL):
206 | """Joins a list of dictionaries that contain lists."""
207 | DLL = listdict2dictlist(LDL)
208 | return type(LDL[0])({k: np.array(list(itertools.chain.from_iterable(DLL[k]))) for k in DLL})
209 |
210 |
211 | def joinListDict(LD):
212 |     """Converts a list of dicts into a dict of stacked np.arrays."""
213 | DL = listdict2dictlist(LD)
214 | return type(LD[0])({k: np.array(DL[k]) for k in DL})
215 |
216 |
217 | def joinListList(LL):
218 |     """Joins a list of lists into a single flat list."""
219 | return type(LL[0])(itertools.chain.from_iterable(LL))
220 |
221 |
222 | def obj2np(obj):
223 | """Wraps an object into an np.array."""
224 | ar = np.zeros((1,), dtype=np.object_)
225 | ar[0] = obj
226 | return ar
227 |
228 |
229 | def flatten_dict(d, parent_key='', sep='_'):
230 | items = []
231 | for k, v in d.items():
232 | new_key = parent_key + sep + k if parent_key else k
233 |         if isinstance(v, collections.abc.MutableMapping):
234 | items.extend(flatten_dict(v, new_key, sep=sep).items())
235 | else:
236 | items.append((new_key, v))
237 | return dict(items)
238 |
239 |
240 | def prefix_dict(d, prefix):
241 | """Adds the prefix to all keys of dict d."""
242 | return type(d)({prefix+k: v for k, v in d.items()})
243 |
244 |
245 | def np2obj(np_array):
246 | if isinstance(np_array, list) or np_array.size > 1:
247 | return [e[0] for e in np_array]
248 | else:
249 | return np_array[0]
250 |
251 |
252 | def str2int(str):
253 | try:
254 | return int(str)
255 | except ValueError:
256 | return None
257 |
258 |
259 | def get_last_argmax(array):
260 | b = array[::-1]
261 |     idx = len(b) - np.argmax(b) - 1
262 | return idx
263 |
--------------------------------------------------------------------------------
/viskill/agents/sc_sac.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from ..components.normalizer import Normalizer
8 | from ..modules.critics import SkillChainingDoubleCritic
9 | from ..modules.policies import SkillChainingDiagGaussianActor
10 | from ..utils.general_utils import AttrDict, prefix_dict
11 | from .sc_ddpg import SkillChainingDDPG
12 |
13 |
14 | class SkillChainingSAC(SkillChainingDDPG):
15 | def __init__(
16 | self,
17 | env_params,
18 | agent_cfg,
19 | sl_agent
20 | ):
21 | super(SkillChainingDDPG, self).__init__()
22 |
23 | self.sl_agent = sl_agent
24 |
25 | self.discount = agent_cfg.discount
26 | self.reward_scale = agent_cfg.reward_scale
27 | self.update_epoch = agent_cfg.update_epoch
28 | self.device = agent_cfg.device
29 | self.raw_env_reward = agent_cfg.raw_env_reward
30 | self.env_params = env_params
31 |
32 | # SAC parameters
33 | self.learnable_temperature = agent_cfg.learnable_temperature
34 | self.soft_target_tau = agent_cfg.soft_target_tau
35 |
36 | self.normalize = agent_cfg.normalize
37 | self.clip_obs = agent_cfg.clip_obs
38 | self.norm_clip = agent_cfg.norm_clip
39 | self.norm_eps = agent_cfg.norm_eps
40 |
41 | self.dima = env_params['act_sc']
42 | self.dimo = env_params['obs']
43 | self.max_action = env_params['max_action_sc']
44 | self.goal_adapator = env_params['adaptor_sc']
45 |
46 |         # normalizer
47 | self.o_norm = Normalizer(
48 | size=self.dimo,
49 | default_clip_range=self.norm_clip,
50 | eps=agent_cfg.norm_eps
51 | )
52 |
53 | # build policy
54 | self.actor = SkillChainingDiagGaussianActor(
55 | in_dim=self.dimo,
56 | out_dim=self.dima,
57 | hidden_dim=agent_cfg.hidden_dim,
58 | max_action=self.max_action,
59 | middle_subtasks=env_params['middle_subtasks'],
60 | last_subtask=env_params['last_subtask']
61 | ).to(agent_cfg.device)
62 |
63 | self.critic = SkillChainingDoubleCritic(
64 | in_dim=self.dimo+self.dima,
65 | hidden_dim=agent_cfg.hidden_dim,
66 | middle_subtasks=env_params['middle_subtasks'],
67 | last_subtask=env_params['last_subtask']
68 | ).to(agent_cfg.device)
69 | self.critic_target = copy.deepcopy(self.critic).to(agent_cfg.device)
70 |
71 | # entropy term
72 | if self.learnable_temperature:
73 | self.target_entropy = -self.dima
74 | self.log_alpha = {subtask : torch.tensor(
75 | np.log(agent_cfg.init_temperature)).to(self.device) for subtask in env_params['middle_subtasks']}
76 | for subtask in env_params['middle_subtasks']:
77 | self.log_alpha[subtask].requires_grad = True
78 | else:
79 | self.log_alpha = {subtask : torch.tensor(
80 | np.log(agent_cfg.init_temperature)).to(self.device) for subtask in env_params['middle_subtasks']}
81 |
82 | # optimizer
83 | self.actor_optimizer = {subtask: torch.optim.Adam(
84 | self.actor.parameters(), lr=agent_cfg.actor_lr) for subtask in env_params['middle_subtasks']}
85 | self.critic_optimizer = {subtask: torch.optim.Adam(
86 | self.critic.parameters(), lr=agent_cfg.critic_lr) for subtask in env_params['middle_subtasks']}
87 | self.temp_optimizer = {subtask: torch.optim.Adam(
88 | [self.log_alpha[subtask]], lr=agent_cfg.temp_lr) for subtask in env_params['middle_subtasks']}
89 |
90 | def init(self, task, sl_agent):
91 |         '''Initialize the actor, critic and normalizers of the last subtask.'''
92 | self.actor.init_last_subtask_actor(task, sl_agent[task].actor)
93 | self.critic.init_last_subtask_q(task, sl_agent[task].critic)
94 | self.critic_target.init_last_subtask_q(task, sl_agent[task].critic_target)
95 | self.sl_normarlizer = {subtask: sl_agent[subtask]._preproc_inputs
96 | for subtask in self.env_params['subtasks']}
97 |
98 | def alpha(self, subtask):
99 | return self.log_alpha[subtask].exp()
100 |
101 | def get_action(self, state, subtask, noise=False):
102 | with torch.no_grad():
103 | #state = {key: self.to_torch(state[key].reshape([1, -1])) for key in state.keys()} # unsqueeze
104 | input_tensor = self._preproc_obs(state, subtask)
105 | dist = self.actor(input_tensor, subtask)
106 | if noise:
107 | action = dist.sample()
108 | else:
109 | action = dist.mean
110 |
111 | return action.cpu().data.numpy().flatten() * self.max_action
112 |
113 | def get_q_value(self, state, action, subtask):
114 | with torch.no_grad():
115 | input_tensor = self._preproc_obs(state, subtask)
116 | action = self.to_torch(action)
117 | q_value = self.critic.q(input_tensor, action, subtask)
118 |
119 | return q_value
120 |
121 | def update_critic(self, obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask):
122 | assert subtask != self.env_params['last_subtask']
123 |
124 | with torch.no_grad():
125 | next_subtask = self.env_params['next_subtasks'][subtask]
126 | if next_subtask != self.env_params['last_subtask']:
127 | dist = self.actor(next_obs, next_subtask)
128 | action_out = dist.rsample()
129 | log_prob = dist.log_prob(action_out).sum(-1, keepdim=True)
130 | target_V = self.critic_target.q(next_obs, action_out, next_subtask)
131 | target_V = target_V - self.alpha(next_subtask).detach() * log_prob
132 | else:
133 | action_out = self.actor[next_subtask](sl_norm_next_obs)
134 | target_V = self.critic[next_subtask](sl_norm_next_obs, action_out).squeeze(0)
135 |
136 | if self.raw_env_reward and next_subtask == self.env_params['last_subtask']:
137 | target_Q = self.reward_scale * reward + (self.discount * raw_reward)
138 | else:
139 | target_Q = self.reward_scale * reward + (self.discount * target_V).detach()
140 |
141 | current_Q1, current_Q2 = self.critic(obs, action, subtask)
142 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
143 |
144 | # optimize critic loss
145 | self.critic_optimizer[subtask].zero_grad()
146 | critic_loss.backward()
147 | self.critic_optimizer[subtask].step()
148 |
149 | metrics = AttrDict(
150 | critic_target_q=target_Q.mean().item(),
151 | critic_q=current_Q1.mean().item(),
152 | critic_loss=critic_loss.item()
153 | )
154 | return prefix_dict(metrics, subtask + '_')
155 |
156 | def update_actor_and_alpha(self, obs, subtask):
157 | # compute log probability
158 | dist = self.actor(obs, subtask)
159 | action = dist.rsample()
160 | log_probs = dist.log_prob(action).sum(-1, keepdim=True)
161 |
162 | # compute state value
163 | actor_Q = self.critic.q(obs, action, subtask)
164 | actor_loss = (self.alpha(subtask).detach() * log_probs - actor_Q).mean()
165 |
166 | # optimize actor loss
167 | self.actor_optimizer[subtask].zero_grad()
168 | actor_loss.backward()
169 | self.actor_optimizer[subtask].step()
170 |
171 | metrics = AttrDict(
172 | log_probs=log_probs.mean(),
173 | actor_loss=actor_loss.item()
174 | )
175 |
176 | # compute temp loss
177 | if self.learnable_temperature:
178 | temp_loss = (self.alpha(subtask) * (-log_probs - self.target_entropy).detach()).mean()
179 | self.temp_optimizer[subtask].zero_grad()
180 | temp_loss.backward()
181 | self.temp_optimizer[subtask].step()
182 |
183 | metrics.update(AttrDict(
184 | temp_loss=temp_loss.item(),
185 | temp=self.alpha(subtask)
186 | ))
187 | return prefix_dict(metrics, subtask + '_')
188 |
189 | def update(self, replay_buffer, demo_buffer):
190 | metrics = AttrDict()
191 |
192 | for i in range(self.update_epoch):
193 | for subtask in self.env_params['middle_subtasks']:
194 | # sample from replay buffer
195 | obs, action, reward, next_obs, done, sl_norm_next_obs, raw_reward = self.get_samples(replay_buffer, subtask)
196 | action = action / self.max_action
197 |
198 | # update critic and actor
199 |                 metrics.update(self.update_critic(obs, action, reward, next_obs, sl_norm_next_obs, raw_reward, subtask))
200 | metrics.update(self.update_actor_and_alpha(obs, subtask))
201 |
202 | # update target critic and actor
203 | self.update_target()
204 |
205 | return metrics
206 |
207 | def update_target(self):
208 | # update the frozen target models
209 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
210 | target_param.data.copy_(self.soft_target_tau * param.data + (1 - self.soft_target_tau) * target_param.data)
211 |
--------------------------------------------------------------------------------
/viskill/trainers/sl_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | from ..agents.factory import make_sl_agent
6 | from ..components.checkpointer import CheckpointHandler, save_cmd
7 | from ..components.envrionment import make_env
8 | from ..components.logger import Logger, WandBLogger, logger
9 | from ..modules.replay_buffer import HerReplayBufferWithGT, get_buffer_sampler
10 | from ..modules.sampler import Sampler
11 | from ..utils.general_utils import (AverageMeter, Every, Timer, Until,
12 | set_seed_everywhere)
13 | from ..utils.mpi import (mpi_gather_experience_episode,
14 | mpi_gather_experience_rollots, mpi_sum,
15 | update_mpi_config)
16 | from ..utils.rl_utils import RolloutStorage, get_env_params, init_demo_buffer
17 | from .base_trainer import BaseTrainer
18 |
19 |
20 | class SkillLearningTrainer(BaseTrainer):
21 | def _setup(self):
22 | self._setup_env() # Environment
23 |         self._setup_buffer() # Replay buffer
24 | self._setup_agent() # Agent
25 | self._setup_sampler() # Sampler
26 | self._setup_logger() # Logger
27 | self._setup_misc() # MISC
28 |
29 | if self.is_chef:
30 | self.termlog.info('Setup done')
31 |
32 | def _setup_env(self):
33 | self.train_env = make_env(self.cfg)
34 | self.eval_env = make_env(self.cfg)
35 | self.env_params = get_env_params(self.train_env, self.cfg)
36 |
37 | def _setup_buffer(self):
38 | self.buffer_sampler = get_buffer_sampler(self.train_env, self.cfg.agent.sampler)
39 | self.buffer = HerReplayBufferWithGT(buffer_size=self.cfg.replay_buffer_capacity, env_params=self.env_params,
40 | batch_size=self.cfg.batch_size, sampler=self.buffer_sampler)
41 | self.demo_buffer = HerReplayBufferWithGT(buffer_size=self.cfg.replay_buffer_capacity, env_params=self.env_params,
42 | batch_size=self.cfg.batch_size, sampler=self.buffer_sampler)
43 |
44 | def _setup_agent(self):
45 | self.agent = make_sl_agent(self.env_params, self.buffer_sampler, self.cfg.agent)
46 |
47 | def _setup_sampler(self):
48 | self.train_sampler = Sampler(self.train_env, self.agent, self.env_params['max_timesteps'])
49 | self.eval_sampler = Sampler(self.eval_env, self.agent, self.env_params['max_timesteps'])
50 |
51 | def _setup_logger(self):
52 | update_mpi_config(self.cfg)
53 | if self.is_chef:
54 | exp_name = f"SL_{self.cfg.task}_{self.cfg.subtask}_{self.cfg.agent.name}_seed{self.cfg.seed}"
55 | if self.cfg.postfix is not None:
56 | exp_name = exp_name + '_' + self.cfg.postfix
57 | self.wb = WandBLogger(exp_name=exp_name, project_name=self.cfg.project_name, entity=self.cfg.entity_name, \
58 | path=self.work_dir, conf=self.cfg)
59 | self.logger = Logger(self.work_dir)
60 | self.termlog = logger
61 | save_cmd(self.work_dir)
62 | else:
63 | self.wb, self.logger, self.termlog = None, None, None
64 |
65 | def _setup_misc(self):
66 | init_demo_buffer(self.cfg, self.demo_buffer, self.agent)
67 |
68 | if self.is_chef:
69 | self.model_dir = self.work_dir / 'model'
70 | self.model_dir.mkdir(exist_ok=True)
71 | for file in os.listdir(self.model_dir):
72 | os.remove(self.model_dir / file)
73 |
74 | self.device = torch.device(self.cfg.device)
75 | self.timer = Timer()
76 | self._global_step = 0
77 | self._global_episode = 0
78 | set_seed_everywhere(self.cfg.seed)
79 |
80 | def train(self):
81 | n_train_episodes = int(self.cfg.n_train_steps / self.env_params['max_timesteps'])
82 | n_eval_episodes = int(n_train_episodes / self.cfg.n_eval) * self.cfg.mpi.num_workers
83 | n_save_episodes = int(n_train_episodes / self.cfg.n_save) * self.cfg.mpi.num_workers
84 | n_log_episodes = int(n_train_episodes / self.cfg.n_log) * self.cfg.mpi.num_workers
85 |
86 | assert n_save_episodes > n_eval_episodes
87 | if n_save_episodes % n_eval_episodes != 0:
88 | n_save_episodes = int(n_save_episodes / n_eval_episodes) * n_eval_episodes
89 |
90 | train_until_episode = Until(n_train_episodes)
91 | save_every_episodes = Every(n_save_episodes)
92 | eval_every_episodes = Every(n_eval_episodes)
93 | log_every_episodes = Every(n_log_episodes)
94 | seed_until_steps = Until(self.cfg.n_seed_steps)
95 |
96 | if self.is_chef:
97 | self.termlog.info('Starting training')
98 | while train_until_episode(self.global_episode):
99 | self._train_episode(log_every_episodes, seed_until_steps)
100 |
101 | if eval_every_episodes(self.global_episode):
102 | score = self.eval()
103 |
104 | if not self.cfg.dont_save and save_every_episodes(self.global_episode) and self.is_chef:
105 | filename = CheckpointHandler.get_ckpt_name(self.global_episode)
106 | # TODO(tao): expose scoring metric
107 | CheckpointHandler.save_checkpoint({
108 | 'episode': self.global_episode,
109 | 'global_step': self.global_step,
110 | 'state_dict': self.agent.state_dict(),
111 | 'o_norm': self.agent.o_norm,
112 | 'g_norm': self.agent.g_norm,
113 | 'score': score,
114 | }, self.model_dir, filename)
115 | if self.cfg.save_buffer:
116 | self.buffer.save(self.model_dir, self.global_episode)
117 | self.termlog.info(f'Save checkpoint to {os.path.join(self.model_dir, filename)}')
118 |
119 | def _train_episode(self, log_every_episodes, seed_until_steps):
120 | # sync network parameters across workers
121 | if self.use_multiple_workers:
122 | self.agent.sync_networks()
123 |
124 | self.timer.reset()
125 | batch_time = AverageMeter()
126 | ep_start_step = self.global_step
127 | metrics = None
128 |
129 | # collect experience
130 | rollout_storage = RolloutStorage()
131 | episode, rollouts, env_steps = self.train_sampler.sample_episode(is_train=True, render=False)
132 | if self.use_multiple_workers:
133 | rollouts = mpi_gather_experience_episode(rollouts)
134 |
135 | # update status
136 | rollout_storage.append(episode)
137 | rollout_status = rollout_storage.rollout_stats()
138 | self._global_step += int(mpi_sum(env_steps))
139 | self._global_episode += int(mpi_sum(1))
140 |
141 | # save to buffer
142 | self.buffer.store_episode(rollouts)
143 | self.agent.update_normalizer(rollouts)
144 |
145 | # update policy
146 | if not seed_until_steps(ep_start_step):
147 | if self.is_chef:
148 | metrics = self.agent.update(self.buffer, self.demo_buffer)
149 | if self.use_multiple_workers:
150 | self.agent.sync_networks()
151 |
152 | # log results
153 | if metrics is not None and log_every_episodes(self.global_episode) and self.is_chef:
154 | elapsed_time, total_time = self.timer.reset()
155 | batch_time.update(elapsed_time)
156 | togo_train_time = batch_time.avg * (self.cfg.n_train_steps - ep_start_step) / env_steps / self.cfg.mpi.num_workers
157 |
158 | self.logger.log_metrics(metrics, self.global_step, ty='train')
159 | with self.logger.log_and_dump_ctx(self.global_step, ty='train') as log:
160 | log('fps', env_steps / elapsed_time)
161 | log('total_time', total_time)
162 | log('episode_reward', rollout_status.avg_reward)
163 | log('episode_length', env_steps)
164 | log('episode_sr', rollout_status.avg_success_rate)
165 | log('episode', self.global_episode)
166 | log('step', self.global_step)
167 | log('ETA', togo_train_time)
168 | self.wb.log_outputs(metrics, None, log_images=False, step=self.global_step, is_train=True)
169 |
170 | def eval(self):
171 | '''Eval agent.'''
172 | eval_rollout_storage = RolloutStorage()
173 | for _ in range(self.cfg.n_eval_episodes):
174 | episode, _, env_steps = self.eval_sampler.sample_episode(is_train=False, render=True)
175 | eval_rollout_storage.append(episode)
176 | rollout_status = eval_rollout_storage.rollout_stats()
177 | if self.use_multiple_workers:
178 | rollout_status = mpi_gather_experience_rollots(rollout_status)
179 | for key, value in rollout_status.items():
180 | rollout_status[key] = value.mean()
181 |
182 | if self.is_chef:
183 | self.wb.log_outputs(rollout_status, eval_rollout_storage, log_images=True, step=self.global_step)
184 | with self.logger.log_and_dump_ctx(self.global_step, ty='eval') as log:
185 | log('episode_sr', rollout_status.avg_success_rate)
186 | log('episode_reward', rollout_status.avg_reward)
187 | log('episode_length', env_steps)
188 | log('episode', self.global_episode)
189 | log('step', self.global_step)
190 |
191 | del eval_rollout_storage
192 | return rollout_status.avg_success_rate
193 |
194 | @property
195 | def global_step(self):
196 | return self._global_step
197 |
198 | @property
199 | def global_episode(self):
200 | return self._global_episode
201 |
202 | @property
203 | def is_chef(self):
204 | return self.cfg.mpi.is_chef
205 |
206 | @property
207 | def use_multiple_workers(self):
208 | return self.cfg.mpi.num_workers > 1
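209 | 
210 | 
211 | # Illustrative note (not part of the original class; the config values below are
212 | # assumed): the episode-based cadence in train() is derived from step-based settings,
213 | # e.g. with n_train_steps=100000, max_timesteps=100, n_eval=50 and mpi.num_workers=1,
214 | # n_train_episodes = 100000 / 100 = 1000, so evaluation runs every
215 | # int(1000 / 50) * 1 = 20 episodes; the save and log intervals follow the same rule.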
--------------------------------------------------------------------------------
/viskill/components/logger.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import datetime
3 | import inspect
4 | import logging
5 | import os
6 | from collections import defaultdict
7 |
8 | import colorlog
9 | import numpy as np
10 | import torch
11 | import wandb
12 | from termcolor import colored
13 |
14 | from ..utils.general_utils import flatten_dict, np2obj, prefix_dict
15 | from ..utils.vis_utils import add_captions_to_seq
16 |
17 | #----------------------Terminal Logger----------------------
18 | formatter = colorlog.ColoredFormatter(
19 | "%(asctime_log_color)s[%(asctime)s]%(name_log_color)s[%(name)s]%(levelname_log_color)s[%(levelname)s] - %(message_log_color)s%(message)s",
20 | datefmt=None,
21 | reset=True,
22 | secondary_log_colors={
23 | 'asctime': {
24 | 'INFO': 'cyan',
25 | 'ERROR': 'cyan'
26 | },
27 | 'name': {
28 | 'INFO': 'blue',
29 | 'ERROR': 'blue'
30 | },
31 | 'levelname': {
32 | 'INFO': 'green',
33 | 'ERROR': 'red'
34 | },
35 | 'message': {
36 | 'INFO': 'white',
37 | 'ERROR': 'red'
38 | }
39 | },
40 | style="%",
41 | )
42 |
43 | logger = colorlog.getLogger("skill-chaining")
44 | logger.setLevel(logging.INFO)
45 | logger.propagate = False
46 |
47 | if not logger.handlers:
48 | ch = colorlog.StreamHandler()
49 | ch.setLevel(logging.INFO)
50 | ch.setFormatter(formatter)
51 | logger.addHandler(ch)
52 |
53 |
54 | #----------------------WandB Logger----------------------
55 | class WandBLogger:
56 | """Logs to WandB."""
57 | N_LOGGED_SAMPLES = 50 # how many examples should be logged in each logging step
58 |
59 | def __init__(self, exp_name, project_name, entity, path, conf, exclude=None):
60 | """
61 | :param exp_name: full name of experiment in WandB
62 | :param project_name: name of overall project
63 | :param entity: name of head entity in WandB that hosts the project
64 | :param path: path to which WandB log-files will be written
65 | :param conf: hyperparam config that will get logged to WandB
66 | :param exclude: (optional) list of (flattened) hyperparam names that should not get logged
67 | """
68 | if exclude is None: exclude = []
69 | flat_config = flatten_dict(conf)
70 | filtered_config = {k: v for k, v in flat_config.items() if (k not in exclude and not inspect.isclass(v))}
71 |
72 | # clear dir
73 | # save_dir = path / 'wandb'
74 | # save_dir.mkdir(exist_ok=True)
75 | # shutil.rmtree(f"{save_dir}/")
76 |
77 | logger.info("Init wandb")
78 | wandb.init(
79 | resume='allow',
80 | project=project_name,
81 | config=filtered_config,
82 | dir=path,
83 | entity=entity,
84 | notes=conf.notes if 'notes' in conf else '',
85 | )
86 |
87 | def log_scalar_dict(self, d, prefix='', step=None):
88 | """Logs all entries from a dict of scalars. Optionally can prefix all keys in dict before logging."""
89 | if prefix: d = prefix_dict(d, prefix + '_')
90 | wandb.log(d) if step is None else wandb.log(d, step=step)
91 |
92 | def log_videos(self, vids, name, step=None):
93 | """Logs videos to WandB in mp4 format.
94 | Assumes list of numpy arrays as input with [time, channels, height, width]."""
95 | assert len(vids[0].shape) == 4 and vids[0].shape[1] == 3
96 | assert isinstance(vids[0], np.ndarray)
97 | if vids[0].max() <= 1.0: vids = [np.asarray(vid * 255.0, dtype=np.uint8) for vid in vids]
98 | # TODO(karl) expose the FPS as a parameter
99 | log_dict = {name: [wandb.Video(vid, fps=10, format="mp4") for vid in vids]}
100 | wandb.log(log_dict) if step is None else wandb.log(log_dict, step=step)
101 |
102 | def log_plot(self, fig, name, step=None):
103 | """Logs matplotlib graph to WandB.
104 | fig is a matplotlib figure handle."""
105 | img = wandb.Image(fig)
106 | wandb.log({name: img}) if step is None else wandb.log({name: img}, step=step)
107 |
108 | def log_outputs(self, logging_stats, rollout_storage, log_images, step, is_train=False, log_videos=True, log_video_caption=False):
109 | """Visualizes/logs all training outputs."""
110 | self.log_scalar_dict(logging_stats, prefix='train' if is_train else 'eval', step=step)
111 |
112 | if log_images:
113 | assert rollout_storage is not None # need rollout data for image logging
114 | # log rollout videos with info captions
115 | if 'image' in rollout_storage and log_videos:
116 | if log_video_caption:
117 | vids = [np.stack(add_captions_to_seq(rollout.image, np2obj(rollout.info))).transpose(0, 3, 1, 2)
118 | for rollout in rollout_storage.get()[-self.n_logged_samples:]]
119 | else:
120 | vids = [np.stack(rollout.image).transpose(0, 3, 1, 2)
121 | for rollout in rollout_storage.get()[-self.n_logged_samples:]]
122 | self.log_videos(vids, name="rollouts", step=step)
123 |
124 | @property
125 | def n_logged_samples(self):
126 | # TODO(karl) put this functionality in a base logger class + give it default parameters and config
127 | return self.N_LOGGED_SAMPLES
128 |
129 |
130 | #----------------------CSV Logger----------------------
131 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
132 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
133 | ('episode_reward', 'R', 'float'), ('episode_sr', 'SR', 'float'),
134 | ('buffer_size', 'BS', 'int'), ('fps', 'FPS', 'float'),
135 | ('total_time', 'T', 'time'), ('ETA', 'ETA', 'time')]
136 |
137 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
138 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
139 | ('episode_reward', 'R', 'float'),
140 | ('total_time', 'T', 'time')]
141 |
142 |
143 | class AverageMeter(object):
144 | def __init__(self):
145 | self._sum = 0
146 | self._count = 0
147 |
148 | def update(self, value, n=1):
149 | self._sum += value
150 | self._count += n
151 |
152 | def value(self):
153 | return self._sum / max(1, self._count)
154 |
155 |
156 | class MetersGroup(object):
157 | def __init__(self, csv_file_name, formating):
158 | self._csv_file_name = csv_file_name
159 | if(os.path.exists(csv_file_name) and os.path.isfile(csv_file_name)):
160 | os.remove(csv_file_name)
161 |
162 | self._formating = formating
163 | self._meters = defaultdict(AverageMeter)
164 | self._csv_file = None
165 | self._csv_writer = None
166 |
167 | def log(self, key, value, n=1):
168 | self._meters[key].update(value, n)
169 |
170 | def _prime_meters(self):
171 | data = dict()
172 | for key, meter in self._meters.items():
173 | if key.startswith('train'):
174 | key = key[len('train') + 1:]
175 | else:
176 | key = key[len('eval') + 1:]
177 | key = key.replace('/', '_')
178 | data[key] = meter.value()
179 | return data
180 |
181 | def _remove_old_entries(self, data):
182 | rows = []
183 | with self._csv_file_name.open('r') as f:
184 | reader = csv.DictReader(f)
185 | for row in reader:
186 | # print(row)
187 | # if float(row['episode']) >= data['episode']:
188 | # break
189 | rows.append(row)
190 | with self._csv_file_name.open('w') as f:
191 | writer = csv.DictWriter(f,
192 | fieldnames=sorted(data.keys()),
193 | restval=0.0)
194 | writer.writeheader()
195 | for row in rows:
196 | writer.writerow(row)
197 |
198 | def _dump_to_csv(self, data):
199 | if self._csv_writer is None:
200 | should_write_header = True
201 | if self._csv_file_name.exists():
202 | self._remove_old_entries(data)
203 | should_write_header = False
204 |
205 | self._csv_file = self._csv_file_name.open('a')
206 | self._csv_writer = csv.DictWriter(self._csv_file,
207 | fieldnames=sorted(data.keys()),
208 | restval=0.0)
209 | if should_write_header:
210 | self._csv_writer.writeheader()
211 |
212 | self._csv_writer.writerow(data)
213 | self._csv_file.flush()
214 |
215 | def _format(self, key, value, ty):
216 | if ty == 'int':
217 | value = int(value)
218 | return f'{key}: {value}'
219 | elif ty == 'float':
220 | return f'{key}: {value:.04f}'
221 | elif ty == 'time':
222 | value = str(datetime.timedelta(seconds=int(value)))
223 | return f'{key}: {value}'
224 | else:
225 |             raise ValueError(f'invalid format type: {ty}')
226 |
227 | def _dump_to_console(self, data, prefix):
228 | prefix = colored(prefix, 'blue' if prefix == 'train' else 'green')
229 | pieces = [f'| {prefix: <14}']
230 | for key, disp_key, ty in self._formating:
231 | value = data.get(key, 0)
232 | pieces.append(self._format(disp_key, value, ty))
233 | logger.info(' | '.join(pieces))
234 |
235 | def dump(self, step, prefix):
236 | if len(self._meters) == 0:
237 | return
238 | data = self._prime_meters()
239 | data['frame'] = step
240 | self._dump_to_csv(data)
241 | self._dump_to_console(data, prefix)
242 | self._meters.clear()
243 |
244 |
245 | class Logger(object):
246 | def __init__(self, log_dir):
247 | self._log_dir = log_dir
248 | self._train_mg = MetersGroup(log_dir / 'train.csv',
249 | formating=COMMON_TRAIN_FORMAT)
250 | self._eval_mg = MetersGroup(log_dir / 'eval.csv',
251 | formating=COMMON_EVAL_FORMAT)
252 | self._sw = None
253 |
254 | def _try_sw_log(self, key, value, step):
255 | if self._sw is not None:
256 | self._sw.add_scalar(key, value, step)
257 |
258 | def log(self, key, value, step):
259 | assert key.startswith('train') or key.startswith('eval')
260 | if type(value) == torch.Tensor:
261 | value = value.item()
262 | self._try_sw_log(key, value, step)
263 | mg = self._train_mg if key.startswith('train') else self._eval_mg
264 | mg.log(key, value)
265 |
266 | def log_metrics(self, metrics, step, ty):
267 | for key, value in metrics.items():
268 | self.log(f'{ty}/{key}', value, step)
269 |
270 | def dump(self, step, ty=None):
271 | if ty is None or ty == 'eval':
272 | self._eval_mg.dump(step, 'eval')
273 | if ty is None or ty == 'train':
274 | self._train_mg.dump(step, 'train')
275 |
276 | def log_and_dump_ctx(self, step, ty):
277 | return LogAndDumpCtx(self, step, ty)
278 |
279 |
280 | class LogAndDumpCtx:
281 | def __init__(self, logger, step, ty):
282 | self._logger = logger
283 | self._step = step
284 | self._ty = ty
285 |
286 | def __enter__(self):
287 | return self
288 |
289 | def __call__(self, key, value):
290 | self._logger.log(f'{self._ty}/{key}', value, self._step)
291 |
292 | def __exit__(self, *args):
293 | self._logger.dump(self._step, self._ty)
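294 | 
295 | 
296 | # Illustrative usage sketch (not part of the original module; the directory and the
297 | # metric values are made up, and the log directory is assumed to exist):
298 | #
299 | #   from pathlib import Path
300 | #   logger = Logger(Path('./exp_logs'))
301 | #   logger.log('train/episode_reward', 12.5, step=100)
302 | #   with logger.log_and_dump_ctx(step=100, ty='train') as log:
303 | #       log('episode', 1)
304 | #       log('step', 100)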
--------------------------------------------------------------------------------
/viskill/trainers/sc_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | from ..agents import make_hier_agent
6 | from ..components.checkpointer import CheckpointHandler
7 | from ..components.logger import Logger, WandBLogger, logger
8 | from ..modules.replay_buffer import (HerReplayBuffer, ReplayBuffer,
9 | get_hier_buffer_samplers)
10 | from ..modules.sampler import HierarchicalSampler
11 | from ..utils.general_utils import (AttrDict, AverageMeter, Every, Timer, Until,
12 | set_seed_everywhere)
13 | from ..utils.mpi import (mpi_gather_experience_successful_transitions,
14 | mpi_gather_experience_transitions, mpi_sum,
15 | mpi_gather_experience_rollots,
16 | update_mpi_config)
17 | from ..utils.rl_utils import RolloutStorage, init_demo_buffer, init_sc_buffer
18 | from .sl_trainer import SkillLearningTrainer
19 |
20 |
21 | class SkillChainingTrainer(SkillLearningTrainer):
22 | def _setup(self):
23 | self._setup_env() # Environment
24 |         self._setup_buffer()        # Replay buffer
25 | self._setup_agent() # Agent
26 | self._setup_sampler() # Sampler
27 | self._setup_logger() # Logger
28 | self._setup_misc() # MISC
29 |
30 | if self.is_chef:
31 | self.termlog.info('Setup done')
32 |
33 | def _setup_buffer(self):
34 | # Skill learning buffer -> HER dict replay buffer
35 | self.sl_buffer_samplers = get_hier_buffer_samplers(self.train_env, self.cfg.sl_agent.sampler)
36 | self.sl_buffer, self.sl_demo_buffer = AttrDict(), AttrDict()
37 | if self.cfg.agent.update_sl_agent:
38 | self.sl_buffer.update({subtask: HerReplayBuffer(
39 | buffer_size=self.cfg.replay_buffer_capacity, env_params=self.env_params, batch_size=self.cfg.batch_size,
40 | sampler=self.sl_buffer_samplers[subtask], T=self.env_params.subtask_steps[subtask]) for subtask in self.env_params.subtasks}
41 | )
42 | self.sl_demo_buffer.update({subtask: HerReplayBuffer(
43 | buffer_size=self.cfg.replay_buffer_capacity, env_params=self.env_params, batch_size=self.cfg.batch_size,
44 | sampler=self.sl_buffer_samplers[subtask], T=self.env_params.subtask_steps[subtask]) for subtask in self.env_params.subtasks}
45 | )
46 |
47 | # Skill chaining buffer -> Rollout state replay buffer
48 | self.sc_buffer = AttrDict({subtask : ReplayBuffer(
49 | obs_shape=self.env_params['obs'], action_shape=self.env_params['act_sc'],
50 | capacity=self.cfg.replay_buffer_capacity, batch_size=self.cfg.batch_size, len_cond=self.env_params['len_cond']) for subtask in self.env_params.subtasks}
51 | )
52 | self.sc_demo_buffer = AttrDict({subtask : ReplayBuffer(
53 | obs_shape=self.env_params['obs'], action_shape=self.env_params['act_sc'],
54 | capacity=self.cfg.replay_buffer_capacity, batch_size=self.cfg.batch_size, len_cond=self.env_params['len_cond']) for subtask in self.env_params.subtasks}
55 | )
56 |
57 | def _setup_agent(self):
58 | update_mpi_config(self.cfg)
59 | self.agent = make_hier_agent(self.env_params, self.sl_buffer_samplers, self.cfg)
60 |
61 | def _setup_sampler(self):
62 | self.train_sampler = HierarchicalSampler(self.train_env, self.agent, self.env_params)
63 | self.eval_sampler = HierarchicalSampler(self.eval_env, self.agent, self.env_params)
64 |
65 | def _setup_logger(self):
66 | if self.is_chef:
67 | exp_name = f"SC_{self.cfg.task}_{self.cfg.agent.sc_agent.name}_{self.cfg.agent.sl_agent.name}_seed{self.cfg.seed}"
68 | if self.cfg.postfix is not None:
69 | exp_name = exp_name + '_' + self.cfg.postfix
70 | self.wb = WandBLogger(exp_name=exp_name, project_name=self.cfg.project_name, entity=self.cfg.entity_name, \
71 | path=self.work_dir, conf=self.cfg)
72 | self.logger = Logger(self.work_dir)
73 | self.termlog = logger
74 | else:
75 | self.wb, self.logger, self.termlog = None, None, None
76 |
77 | def _setup_misc(self):
78 | init_sc_buffer(self.cfg, self.sc_buffer, self.agent, self.env_params)
79 | init_sc_buffer(self.cfg, self.sc_demo_buffer, self.agent, self.env_params)
80 |
81 | if self.cfg.agent.update_sl_agent:
82 | for subtask in self.env_params.middle_subtasks:
83 | init_demo_buffer(self.cfg, self.sl_buffer[subtask], self.agent.sl_agent[subtask], subtask, False)
84 | init_demo_buffer(self.cfg, self.sl_demo_buffer[subtask], self.agent.sl_agent[subtask], subtask, False)
85 |
86 | if self.is_chef:
87 | self.model_dir = self.work_dir / 'model'
88 | self.model_dir.mkdir(exist_ok=True)
89 | for file in os.listdir(self.model_dir):
90 | os.remove(self.model_dir / file)
91 |
92 | self.device = torch.device(self.cfg.device)
93 | self.timer = Timer()
94 | self._global_step = 0
95 | self._global_episode = 0
96 | set_seed_everywhere(self.cfg.seed)
97 |
98 | def train(self):
99 | n_train_episodes = int(self.cfg.n_train_steps / self.env_params['max_timesteps'])
100 | n_eval_episodes = int(n_train_episodes / self.cfg.n_eval) * self.cfg.mpi.num_workers
101 | n_save_episodes = int(n_train_episodes / self.cfg.n_save) * self.cfg.mpi.num_workers
102 | n_log_episodes = int(n_train_episodes / self.cfg.n_log) * self.cfg.mpi.num_workers
103 |
104 | assert n_save_episodes >= n_eval_episodes
105 | if n_save_episodes % n_eval_episodes != 0:
106 | n_save_episodes = int(n_save_episodes / n_eval_episodes) * n_eval_episodes
107 |
108 | train_until_episode = Until(n_train_episodes)
109 | save_every_episodes = Every(n_save_episodes)
110 | eval_every_episodes = Every(n_eval_episodes)
111 | log_every_episodes = Every(n_log_episodes)
112 | seed_until_steps = Until(self.cfg.n_seed_steps)
113 |
114 | if self.is_chef:
115 | self.termlog.info('Starting training')
116 | while train_until_episode(self.global_episode):
117 | self._train_episode(log_every_episodes, seed_until_steps)
118 |
119 | if eval_every_episodes(self.global_episode):
120 | score = self.eval()
121 |
122 | if not self.cfg.dont_save and save_every_episodes(self.global_episode) and self.is_chef:
123 | filename = CheckpointHandler.get_ckpt_name(self.global_episode)
124 | # TODO(tao): expose scoring metric
125 | CheckpointHandler.save_checkpoint({
126 | 'episode': self.global_episode,
127 | 'global_step': self.global_step,
128 | 'state_dict': self.agent.state_dict(),
129 | 'score': score,
130 | }, self.model_dir, filename)
131 |                 self.termlog.info(f'Saved checkpoint to {os.path.join(self.model_dir, filename)}')
132 |
133 | def _train_episode(self, log_every_episodes, seed_until_steps):
134 | # sync network parameters across workers
135 | if self.use_multiple_workers:
136 | self.agent.sync_networks()
137 |
138 | self.timer.reset()
139 | batch_time = AverageMeter()
140 | ep_start_step = self.global_step
141 | metrics = None
142 |
143 | # collect experience and save to buffer
144 | rollout_storage = RolloutStorage()
145 | sc_episode, sl_episode, env_steps = self.train_sampler.sample_episode(is_train=True, render=False)
146 | if self.use_multiple_workers:
147 | for subtask in sc_episode.sc_transitions.keys():
148 | transitions_batch = mpi_gather_experience_transitions(sc_episode.sc_transitions[subtask])
149 | # save to buffer
150 | self.sc_buffer[subtask].add_rollouts(transitions_batch)
151 | if self.cfg.agent.sc_agent.normalize:
152 | self.agent.sc_agent.update_normalizer(transitions_batch, subtask)
153 | if self.cfg.agent.update_sl_agent:
154 | for subtask in self.env_params.subtasks:
155 | self.sl_buffer[subtask].store_episode(sl_episode[subtask])
156 |
157 | if self.use_multiple_workers:
158 | for subtask in sc_episode.sc_succ_transitions.keys():
159 | demo_batch = mpi_gather_experience_successful_transitions(sc_episode.sc_succ_transitions[subtask])
160 | # save to buffer
161 | self.sc_demo_buffer[subtask].add_rollouts(demo_batch)
162 | else:
163 | raise NotImplementedError
164 | #transitions_batch = sc_episode.sc_transitions
165 |
166 | # update status
167 | rollout_storage.append(sc_episode)
168 | rollout_status = rollout_storage.rollout_stats()
169 | self._global_step += int(mpi_sum(env_steps))
170 | self._global_episode += int(mpi_sum(1))
171 |
172 | # update policy
173 | if not seed_until_steps(ep_start_step) and self.is_chef:
174 | if not self.cfg.use_demo_buffer:
175 | metrics = self.agent.update(self.sc_buffer, self.sl_buffer)
176 | else:
177 | metrics = self.agent.update(self.sc_buffer, self.sl_buffer, self.sc_demo_buffer, self.sl_demo_buffer)
178 | if self.use_multiple_workers:
179 | self.agent.sync_networks()
180 |
181 | # log results
182 | if metrics is not None and log_every_episodes(self.global_episode) and self.is_chef:
183 | elapsed_time, total_time = self.timer.reset()
184 | batch_time.update(elapsed_time)
185 | togo_train_time = batch_time.avg * (self.cfg.n_train_steps - ep_start_step) / env_steps / self.cfg.mpi.num_workers
186 |
187 | self.logger.log_metrics(metrics, self.global_step, ty='train')
188 | with self.logger.log_and_dump_ctx(self.global_step, ty='train') as log:
189 | log('fps', env_steps / elapsed_time)
190 | log('total_time', total_time)
191 | log('episode_reward', rollout_status.avg_reward)
192 | log('episode_length', env_steps)
193 | log('episode_sr', rollout_status.avg_success_rate)
194 | log('episode', self.global_episode)
195 | log('step', self.global_step)
196 | log('ETA', togo_train_time)
197 | self.wb.log_outputs(metrics, None, log_images=False, step=self.global_step, is_train=True)
198 |
199 | def eval_ckpt(self):
200 | '''Eval checkpoint.'''
201 | CheckpointHandler.load_checkpoint(
202 | self.cfg.sc_ckpt_dir, self.agent, self.device, self.cfg.sc_ckpt_episode
203 | )
204 |
205 | eval_rollout_storage = RolloutStorage()
206 | for _ in range(self.cfg.n_eval_episodes):
207 | episode, _, env_steps = self.eval_sampler.sample_episode(is_train=False, render=True)
208 | eval_rollout_storage.append(episode)
209 | rollout_status = eval_rollout_storage.rollout_stats()
210 | if self.use_multiple_workers:
211 | rollout_status = mpi_gather_experience_rollots(rollout_status)
212 | for key, value in rollout_status.items():
213 | rollout_status[key] = value.mean()
214 |
215 | if self.is_chef:
216 | self.wb.log_outputs(rollout_status, eval_rollout_storage, log_images=True, step=self.global_step)
217 | with self.logger.log_and_dump_ctx(self.global_step, ty='eval') as log:
218 | log('episode_sr', rollout_status.avg_success_rate)
219 | log('episode_reward', rollout_status.avg_reward)
220 | log('episode_length', env_steps)
221 | log('episode', self.global_episode)
222 | log('step', self.global_step)
223 |
224 | del eval_rollout_storage
225 |
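226 | 
227 | # Illustrative note (not part of the original class): sc_buffer and sc_demo_buffer are
228 | # AttrDicts keyed by subtask name, each holding an independent ReplayBuffer, e.g.
229 | # (subtask names taken from the BiPegTransfer wrappers, shown here as an assumption):
230 | #
231 | #   self.sc_buffer['grasp'].sample()    # batch of transitions for the 'grasp' skill
232 | #   self.sc_buffer['handover'].add(obs, action, reward, next_obs, done, gt_action)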
--------------------------------------------------------------------------------
/viskill/modules/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import os
3 | import pickle
4 |
5 | import numpy as np
6 | from torch.utils.data import Dataset
7 |
8 | from ..utils.general_utils import AttrDict
9 |
10 |
11 | #-------------------------Hindsight Experience Replay-------------------------
12 | class HerReplayBuffer:
13 | def __init__(self, env_params, buffer_size, batch_size, sampler, T=None):
14 | # TODO(tao): unwrap env_params
15 | self.env_params = env_params
16 | self.T = T if T is not None else env_params['max_timesteps']
17 | self.size = buffer_size // self.T
18 | self.batch_size = batch_size
19 |
20 | # memory management
21 | self.current_size = 0
22 | self.n_transitions_stored = 0
23 | self.sample_func = sampler
24 |
25 | # create the buffer to store info
26 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]),
27 | 'ag': np.empty([self.size, self.T + 1, self.env_params['achieved_goal']]),
28 | 'g': np.empty([self.size, self.T, self.env_params['goal']]),
29 | 'actions': np.empty([self.size, self.T, self.env_params['act']]),
30 | 'dones': np.empty([self.size, self.T, 1]),
31 | }
32 |
33 | # store the episode
34 | def store_episode(self, episode_batch):
35 | mb_obs, mb_ag, mb_g, mb_actions, dones = episode_batch.obs, episode_batch.ag, episode_batch.g, \
36 | episode_batch.actions, episode_batch.dones
37 | batch_size = mb_obs.shape[0]
38 | idxs = self._get_storage_idx(inc=batch_size)
39 |
40 |         # store the information
41 | self.buffers['obs'][idxs] = mb_obs
42 | self.buffers['ag'][idxs] = mb_ag
43 | self.buffers['g'][idxs] = mb_g
44 | self.buffers['actions'][idxs] = mb_actions
45 | self.buffers['dones'][idxs] = dones
46 | self.n_transitions_stored += self.T * batch_size
47 |
48 | # sample the data from the replay buffer
49 | def sample(self):
50 | temp_buffers = {}
51 | for key in self.buffers.keys():
52 | temp_buffers[key] = self.buffers[key][:self.current_size]
53 | temp_buffers['obs_next'] = temp_buffers['obs'][:, 1:, :]
54 | temp_buffers['ag_next'] = temp_buffers['ag'][:, 1:, :]
55 | # sample transitions
56 | transitions = self.sample_func.sample_her_transitions(temp_buffers, self.batch_size)
57 | return transitions
58 |
59 | def _get_storage_idx(self, inc=None):
60 | inc = inc or 1
61 | if self.current_size+inc <= self.size:
62 | idx = np.arange(self.current_size, self.current_size+inc)
63 | elif self.current_size < self.size:
64 | overflow = inc - (self.size - self.current_size)
65 | idx_a = np.arange(self.current_size, self.size)
66 | idx_b = np.random.randint(0, self.current_size, overflow)
67 | idx = np.concatenate([idx_a, idx_b])
68 | else:
69 | idx = np.random.randint(0, self.size, inc)
70 | self.current_size = min(self.size, self.current_size+inc)
71 | if inc == 1:
72 | idx = idx[0]
73 | return idx
74 |
75 | def save(self, save_dir, episode):
76 | with gzip.open(os.path.join(save_dir, f"replay_buffer_ep{episode}.zip"), 'wb') as f:
77 | pickle.dump(self.buffers, f)
78 | np.save(os.path.join(save_dir, f'idx_size_ep{episode}'), np.array([self.current_size, self.n_transitions_stored]))
79 |
80 | def load(self, save_dir, episode):
81 | with gzip.open(os.path.join(save_dir, f"replay_buffer_ep{episode}.zip"), 'rb') as f:
82 | self.buffers = pickle.load(f)
83 | idx_size = np.load(os.path.join(save_dir, f"idx_size_ep{episode}.npy"))
84 | self.current_size, self.n_transitions_stored = int(idx_size[0]), int(idx_size[1])
85 |
86 |
87 | class HerReplayBufferWithGT(HerReplayBuffer):
88 | def __init__(self, env_params, buffer_size, batch_size, sampler, T=None):
89 | # TODO(tao): unwrap env_params
90 | self.env_params = env_params
91 | self.T = T if T is not None else env_params['max_timesteps']
92 | self.size = buffer_size // self.T
93 | self.batch_size = batch_size
94 | # memory management
95 | self.current_size = 0
96 | self.n_transitions_stored = 0
97 | self.sample_func = sampler
98 | # create the buffer to store info
99 | self.buffers = {'obs': np.empty([self.size, self.T + 1, self.env_params['obs']]),
100 | 'ag': np.empty([self.size, self.T + 1, self.env_params['achieved_goal']]),
101 | 'g': np.empty([self.size, self.T, self.env_params['goal']]),
102 | 'actions': np.empty([self.size, self.T, self.env_params['act']]),
103 | 'dones': np.empty([self.size, self.T, 1]),
104 | 'gt_g': np.empty([self.size, self.T, self.env_params['goal']]),
105 | }
106 | self.sample_keys = ['obs', 'ag', 'g', 'actions', 'dones']
107 |
108 | # store the episode
109 | def store_episode(self, episode_batch):
110 | mb_obs, mb_ag, mb_g, mb_actions, dones, mb_gt_g = episode_batch.obs, episode_batch.ag, episode_batch.g, \
111 | episode_batch.actions, episode_batch.dones, episode_batch.gt_g
112 | batch_size = mb_obs.shape[0]
113 | idxs = self._get_storage_idx(inc=batch_size)
114 |
115 |         # store the information
116 | self.buffers['obs'][idxs] = mb_obs
117 | self.buffers['ag'][idxs] = mb_ag
118 | self.buffers['g'][idxs] = mb_g
119 | self.buffers['actions'][idxs] = mb_actions
120 | self.buffers['dones'][idxs] = dones
121 | self.buffers['gt_g'][idxs] = mb_gt_g
122 | self.n_transitions_stored += self.T * batch_size
123 |
124 | # sample the data from the replay buffer
125 | def sample(self):
126 | temp_buffers = {}
127 | for key in self.sample_keys:
128 | temp_buffers[key] = self.buffers[key][:self.current_size]
129 | temp_buffers['obs_next'] = temp_buffers['obs'][:, 1:, :]
130 | temp_buffers['ag_next'] = temp_buffers['ag'][:, 1:, :]
131 | # sample transitions
132 | transitions = self.sample_func.sample_her_transitions(temp_buffers, self.batch_size)
133 | return transitions
134 |
135 |
136 | class HER_sampler:
137 | def __init__(self, replay_strategy, replay_k, reward_func=None):
138 | self.replay_strategy = replay_strategy
139 | self.replay_k = replay_k
140 | if self.replay_strategy == 'future':
141 | self.future_p = 1 - (1. / (1 + replay_k))
142 | else:
143 | self.future_p = 0
144 | self.reward_func = reward_func
145 |
146 | def sample_her_transitions(self, episode_batch, batch_size_in_transitions):
147 | T = episode_batch['actions'].shape[1]
148 | rollout_batch_size = episode_batch['actions'].shape[0]
149 | batch_size = batch_size_in_transitions
150 |
151 | # select which rollouts and which timesteps to be used
152 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
153 | t_samples = np.random.randint(T, size=batch_size)
154 | transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys()}
155 |
156 | # her idx
157 | her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p)
158 | future_offset = np.random.uniform(size=batch_size) * (T - t_samples)
159 | future_offset = future_offset.astype(int)
160 | future_t = (t_samples + 1 + future_offset)[her_indexes]
161 |
162 |         # replace goal with achieved goal
163 | future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
164 | transitions['g'][her_indexes] = future_ag
165 |
166 | # to get the params to re-compute reward
167 | transitions['r'] = self.reward_func(transitions['ag_next'], transitions['g'], None)
168 | if len(transitions['r'].shape) == 1:
169 | transitions['r'] = np.expand_dims(transitions['r'], 1)
170 | transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) for k in transitions.keys()}
171 |
172 | return transitions
173 |
174 |
175 | class HER_sampler_seq(HER_sampler):
176 | def sample_her_transitions(self, episode_batch, batch_size_in_transitions):
177 | T = episode_batch['actions'].shape[1]
178 | rollout_batch_size = episode_batch['actions'].shape[0]
179 | batch_size = batch_size_in_transitions
180 | # select which rollouts and which timesteps to be used
181 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
182 | t_samples = np.random.randint(T-1, size=batch_size) # from T to T-1
183 | transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() for key in episode_batch.keys()}
184 |
185 | next_actions = episode_batch['actions'][episode_idxs, t_samples + 1].copy()
186 | transitions['next_actions'] = next_actions
187 | # her idx
188 | her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p)
189 | future_offset = np.random.uniform(size=batch_size) * (T - t_samples)
190 | future_offset = future_offset.astype(int)
191 | future_t = (t_samples + 1 + future_offset)[her_indexes]
192 |         # replace goal with achieved goal
193 | future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
194 | transitions['g'][her_indexes] = future_ag
195 | # to get the params to re-compute reward
196 | transitions['r'] = self.reward_func(transitions['ag_next'], transitions['g'], None)
197 | if len(transitions['r'].shape) == 1:
198 | transitions['r'] = np.expand_dims(transitions['r'], 1)
199 | transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) for k in transitions.keys()}
200 |
201 | return transitions
202 |
203 |
204 | def get_buffer_sampler(env, cfg):
205 | if cfg.type == 'her':
206 | sampler = HER_sampler(
207 | replay_strategy=cfg.strategy,
208 | replay_k=cfg.k,
209 | reward_func=env.compute_reward,
210 | )
211 | elif cfg.type == 'her_seq':
212 | sampler = HER_sampler_seq(
213 | replay_strategy=cfg.strategy,
214 | replay_k=cfg.k,
215 | reward_func=env.compute_reward,
216 | )
217 | else:
218 | raise NotImplementedError
219 | return sampler
220 |
221 |
222 | def get_hier_buffer_samplers(env, cfg):
223 | reward_funcs = env.get_reward_functions()
224 | samplers = AttrDict()
225 | if cfg.type == 'her':
226 | for subtask in env.subtasks:
227 | samplers.update({subtask: HER_sampler(
228 | replay_strategy=cfg.strategy,
229 | replay_k=cfg.k,
230 | reward_func=reward_funcs[subtask],
231 | )})
232 | elif cfg.type == 'her_seq':
233 | for subtask in env.subtasks:
234 | samplers.update({subtask: HER_sampler_seq(
235 | replay_strategy=cfg.strategy,
236 | replay_k=cfg.k,
237 | reward_func=reward_funcs[subtask],
238 | )})
239 | else:
240 | raise NotImplementedError
241 | return samplers
242 |
243 |
244 | #-------------------------Rollout Replay Buffer-------------------------
245 | class ReplayBuffer(Dataset):
246 | """Buffer to store environment transitions."""
247 | def __init__(self, obs_shape, action_shape, capacity, batch_size, len_cond):
248 | self.capacity = capacity
249 | self.batch_size = batch_size
250 |         # the proprioceptive (state-based) obs is stored as float32
251 | obs_dtype = np.float32
252 |
253 | self.obses = np.empty((capacity, obs_shape), dtype=obs_dtype)
254 | self.next_obses = np.empty((capacity, obs_shape), dtype=obs_dtype)
255 | self.actions = np.empty((capacity, action_shape), dtype=np.float32)
256 | self.rewards = np.empty((capacity, 1), dtype=np.float32)
257 | self.dones = np.empty((capacity, 1), dtype=np.float32)
258 | # gt_action should not be accessed unless entering terminal subtask
259 | # TODO: expose arm number
260 | self.gt_actions = np.empty((capacity, action_shape+len_cond), dtype=np.float32)
261 |
262 | self.idx = 0
263 | self.last_save = 0
264 | self.full = False
265 |
266 | def add(self, obs, action, reward, next_obs, done, gt_action):
267 | np.copyto(self.obses[self.idx], obs)
268 | np.copyto(self.actions[self.idx], action)
269 | np.copyto(self.rewards[self.idx], reward)
270 | np.copyto(self.next_obses[self.idx], next_obs)
271 | np.copyto(self.dones[self.idx], done)
272 | np.copyto(self.gt_actions[self.idx], gt_action)
273 |
274 | self.idx = (self.idx + 1) % self.capacity
275 | self.full = self.full or self.idx == 0
276 |
277 | def add_rollouts(self, rollouts):
278 | for transition in rollouts:
279 | self.add(*transition)
280 |
281 | def sample(self, idxs=None, return_idxs=False):
282 | if idxs is None:
283 | idxs = np.random.randint(
284 | 0, self.capacity if self.full else self.idx, size=self.batch_size
285 | )
286 |
287 | obses = self.obses[idxs]
288 | next_obses = self.next_obses[idxs]
289 | actions = self.actions[idxs]
290 | rewards = self.rewards[idxs]
291 | dones = self.dones[idxs]
292 | gt_actions = self.gt_actions[idxs]
293 | if return_idxs:
294 | return obses, actions, rewards, next_obses, dones, gt_actions, idxs
295 | else:
296 | return obses, actions, rewards, next_obses, dones, gt_actions
297 |
298 | def __getitem__(self, idx):
299 | idx = np.random.randint(
300 | 0, self.capacity if self.full else self.idx, size=1
301 | )
302 | idx = idx[0]
303 | obs = self.obses[idx]
304 | action = self.actions[idx]
305 | reward = self.rewards[idx]
306 | next_obs = self.next_obses[idx]
307 | done = self.dones[idx]
308 | gt_action = self.gt_actions[idx]
309 |
310 | return obs, action, reward, next_obs, done, gt_action
311 |
312 | def __len__(self):
313 | return self.capacity
--------------------------------------------------------------------------------
/viskill/components/envrionment.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from contextlib import contextmanager
3 |
4 | import gym
5 | import numpy as np
6 | import torch
7 | from surrol.utils.pybullet_utils import (pairwise_collision,
8 | pairwise_link_collision)
9 |
10 |
11 | def approx_collision(goal_a, goal_b, th=0.025):
12 | assert goal_a.shape == goal_b.shape
13 | return np.linalg.norm(goal_a - goal_b, axis=-1) < th
14 |
15 |
16 | class SkillLearningWrapper(gym.Wrapper):
17 | def __init__(self, env, subtask, output_raw_obs):
18 | super().__init__(env)
19 | self.subtask = subtask
20 | self._start_subtask = subtask
21 | self._elapsed_steps = None
22 | self._output_raw_obs = output_raw_obs
23 |
24 | @abc.abstractmethod
25 | def _replace_goal_with_subgoal(self, obs):
26 | """Replace achieved goal and desired goal."""
27 | raise NotImplementedError
28 |
29 | @abc.abstractmethod
30 | def _subgoal(self):
31 | """Output goal of subtask."""
32 | raise NotImplementedError
33 |
34 | @contextmanager
35 | def switch_subtask(self, subtask=None):
36 |         '''Temporarily switch subtask (defaults to the previous subtask).'''
37 | if subtask is not None:
38 | curr_subtask = self.subtask
39 | self.subtask = subtask
40 | yield
41 | self.subtask = curr_subtask
42 | else:
43 | self.subtask = self.SUBTASK_PREV_SUBTASK[self.subtask]
44 | yield
45 | self.subtask = self.SUBTASK_NEXT_SUBTASK[self.subtask]
46 |
47 |
48 | #-----------------------------BiPegTransfer-v0-----------------------------
49 | class BiPegTransferSLWrapper(SkillLearningWrapper):
50 | '''Wrapper for skill learning'''
51 | SUBTASK_ORDER = {
52 | 'grasp': 0,
53 | 'handover': 1,
54 | 'release': 2
55 | }
56 | SUBTASK_STEPS = {
57 | 'grasp': 45,
58 | 'handover': 35,
59 | 'release': 20
60 | }
61 | SUBTASK_RESET_INDEX = {
62 | 'handover': 4,
63 | 'release': 10
64 | }
65 | SUBTASK_RESET_MAX_STEPS = {
66 | 'handover': 45,
67 | 'release': 70
68 | }
69 | SUBTASK_PREV_SUBTASK = {
70 | 'handover': 'grasp',
71 | 'release': 'handover'
72 | }
73 | SUBTASK_NEXT_SUBTASK = {
74 | 'grasp': 'handover',
75 | 'handover': 'release'
76 | }
77 | SUBTASK_CONTACT_CONDITION = {
78 | 'grasp': [0, 1],
79 | 'handover': [1, 0],
80 | 'release': [0, 0]
81 | }
82 | LAST_SUBTASK = 'release'
83 | def __init__(self, env, subtask='grasp', output_raw_obs=False):
84 | super().__init__(env, subtask, output_raw_obs)
85 | self.done_subtasks = {key: False for key in self.SUBTASK_STEPS.keys()}
86 |
87 | @property
88 | def max_episode_steps(self):
89 | assert np.sum([x for x in self.SUBTASK_STEPS.values()]) == self.env._max_episode_steps
90 | return self.SUBTASK_STEPS[self.subtask]
91 |
92 | def step(self, action):
93 | next_obs, reward, done, info = self.env.step(action)
94 | self._elapsed_steps += 1
95 | next_obs_ = self._replace_goal_with_subgoal(next_obs.copy())
96 | reward = self.compute_reward(next_obs_['achieved_goal'], next_obs_['desired_goal'])
97 | info['is_success'] = reward + 1
98 | done = self._elapsed_steps == self.max_episode_steps
99 |         # save ground-truth goal
100 | with self.switch_subtask(self.LAST_SUBTASK):
101 | info['gt_goal'] = self._replace_goal_with_subgoal(next_obs.copy())['desired_goal']
102 |
103 | if self._output_raw_obs: return next_obs_, reward, done, info, next_obs
104 |         else: return next_obs_, reward, done, info
105 |
106 | def reset(self):
107 | self.subtask = self._start_subtask
108 | if self.subtask not in self.SUBTASK_RESET_INDEX.keys():
109 | obs = self.env.reset()
110 | self._elapsed_steps = 0
111 | else:
112 | success = False
113 | while not success:
114 | obs = self.env.reset()
115 | self.subtask = self._start_subtask
116 | self._elapsed_steps = 0
117 |
118 | action, skill_index = self.env.get_oracle_action(obs)
119 | count, max_steps = 0, self.SUBTASK_RESET_MAX_STEPS[self.subtask]
120 | while skill_index < self.SUBTASK_RESET_INDEX[self.subtask] and count < max_steps:
121 | obs, reward, done, info = self.env.step(action)
122 | action, skill_index = self.env.get_oracle_action(obs)
123 | count += 1
124 |
125 | # Reset again if failed
126 | with self.switch_subtask():
127 |                     obs_ = self._replace_goal_with_subgoal(obs.copy()) # copy to avoid replacing the goal twice
128 | success = self.compute_reward(obs_['achieved_goal'], obs_['desired_goal']) + 1
129 |
130 | if self._output_raw_obs: return self._replace_goal_with_subgoal(obs), obs
131 | else: return self._replace_goal_with_subgoal(obs)
132 |
133 | def _replace_goal_with_subgoal(self, obs):
134 | """Replace ag and g"""
135 | subgoal = self._subgoal()
136 | psm1col = pairwise_collision(self.env.obj_id, self.psm1.body)
137 | psm2col = pairwise_collision(self.env.obj_id, self.psm2.body)
138 |
139 | if self.subtask == 'grasp':
140 | obs['achieved_goal'] = np.concatenate([obs['observation'][0: 3], obs['observation'][7: 10], [psm1col, psm2col]])
141 | elif self.subtask == 'handover':
142 | obs['achieved_goal'] = np.concatenate([obs['observation'][7: 10], obs['observation'][0: 3], [psm1col, psm2col]])
143 | elif self.subtask == 'release':
144 | obs['achieved_goal'] = np.concatenate([obs['observation'][7: 10], obs['achieved_goal'], [psm1col, psm2col]])
145 | obs['desired_goal'] = np.append(subgoal, self.SUBTASK_CONTACT_CONDITION[self.subtask])
146 | return obs
147 |
148 | def _subgoal(self):
149 | """Output goal of subtask"""
150 | goal = self.env.subgoals[self.SUBTASK_ORDER[self.subtask]]
151 | return goal
152 |
153 | def compute_reward(self, ag, g, info=None):
154 | """Compute reward that indicates the success of subtask"""
155 | if len(ag.shape) == 1:
156 | if self.subtask == 'release':
157 | goal_reach = self.env.compute_reward(ag[-5:-2], g[-5:-2], None) + 1
158 | else:
159 | goal_reach = self.env.compute_reward(ag[:-2], g[:-2], None) + 1
160 | contact_cond = np.all(ag[-2:]==g[-2:])
161 | reward = (goal_reach and contact_cond) - 1
162 | else:
163 | if self.subtask == 'release':
164 | goal_reach = self.env.compute_reward(ag[:,-5:-2], g[:,-5:-2], None).reshape(-1, 1) + 1
165 | else:
166 | goal_reach = self.env.compute_reward(ag[:,:-2], g[:,:-2], None).reshape(-1, 1) + 1
167 | contact_cond = np.all(ag[:, -2:]==g[:, -2:], axis=1).reshape(-1, 1)
168 | reward = np.all(np.hstack([goal_reach, contact_cond]), axis=1) - 1.
169 | return reward
170 |
171 |
172 | class BiPegTransferSCWrapper(BiPegTransferSLWrapper):
173 | '''Wrapper for skill chaining.'''
174 | MAX_ACTION_RANGE = 4.
175 | REWARD_SCALE = 30.
176 | def step(self, action):
177 | next_obs, reward, done, info = self.env.step(action)
178 | self._elapsed_steps += 1
179 | next_obs_ = self._replace_goal_with_subgoal(next_obs.copy())
180 | reward = self.compute_reward(next_obs_['achieved_goal'], next_obs_['desired_goal'])
181 | info['step'] = 1 - reward
182 | done = self._elapsed_steps == self.SUBTASK_STEPS[self.subtask]
183 | reward = done * reward
184 | info['subtask'] = self.subtask
185 | info['subtask_done'] = False
186 | info['subtask_is_success'] = reward
187 |
188 | if done:
189 | info['subtask_done'] = True
190 |             # Transition to the next subtask (if it is not terminal) and reset elapsed steps
191 | if self.subtask in self.SUBTASK_NEXT_SUBTASK.keys():
192 | done = False
193 | self._elapsed_steps = 0
194 | self.subtask = self.SUBTASK_NEXT_SUBTASK[self.subtask]
195 | info['is_success'] = False
196 | reward = 0
197 | else:
198 | info['is_success'] = reward
199 | next_obs_ = self._replace_goal_with_subgoal(next_obs)
200 |
201 | if self._output_raw_obs: return next_obs_, reward, done, info, next_obs
202 | else: return next_obs_, reward, done, info
203 |
204 | def reset(self, subtask=None):
205 | self.subtask = self._start_subtask if subtask is None else subtask
206 | if self.subtask not in self.SUBTASK_RESET_INDEX.keys():
207 | obs = self.env.reset()
208 | self._elapsed_steps = 0
209 | else:
210 | success = False
211 | while not success:
212 | obs = self.env.reset()
213 | self.subtask = self._start_subtask if subtask is None else subtask
214 | self._elapsed_steps = 0
215 |
216 | action, skill_index = self.env.get_oracle_action(obs)
217 | count, max_steps = 0, self.SUBTASK_RESET_MAX_STEPS[self.subtask]
218 | while skill_index < self.SUBTASK_RESET_INDEX[self.subtask] and count < max_steps:
219 | obs, reward, done, info = self.env.step(action)
220 | action, skill_index = self.env.get_oracle_action(obs)
221 | count += 1
222 |
223 | # Reset again if failed
224 | with self.switch_subtask():
225 |                     obs_ = self._replace_goal_with_subgoal(obs.copy()) # copy to avoid replacing the goal twice
226 | success = self.compute_reward(obs_['achieved_goal'], obs_['desired_goal']) + 1
227 |
228 | if self._output_raw_obs: return self._replace_goal_with_subgoal(obs), obs
229 | else: return self._replace_goal_with_subgoal(obs)
230 |
231 | #---------------------------Reward---------------------------
232 | def compute_reward(self, ag, g, info=None):
233 | """Compute reward that indicates the success of subtask"""
234 | if len(ag.shape) == 1:
235 | if self.subtask == 'release':
236 | goal_reach = self.env.compute_reward(ag[-5:-2], g[-5:-2], None) + 1
237 | else:
238 | goal_reach = self.env.compute_reward(ag[:-2], g[:-2], None) + 1
239 | contact_cond = np.all(ag[-2:]==g[-2:])
240 | reward = (goal_reach and contact_cond) - 1
241 | else:
242 | if self.subtask == 'release':
243 | goal_reach = self.env.compute_reward(ag[:,-5:-2], g[:,-5:-2], None).reshape(-1, 1) + 1
244 | else:
245 | goal_reach = self.env.compute_reward(ag[:,:-2], g[:,:-2], None).reshape(-1, 1) + 1
246 | if self.subtask == 'grasp':
247 | raise NotImplementedError
248 | contact_cond = np.all(ag[:, -2:]==g[:, -2:], axis=1).reshape(-1, 1)
249 | reward = np.all(np.hstack([goal_reach, contact_cond]), axis=1) - 1.
250 | return reward + 1
251 |
252 | def goal_adapator(self, goal, subtask, device=None):
253 | '''Make predicted goal compatible with wrapper'''
254 | if isinstance(goal, np.ndarray):
255 | return np.append(goal, self.SUBTASK_CONTACT_CONDITION[subtask])
256 | elif isinstance(goal, torch.Tensor):
257 | assert device is not None
258 | ct_cond = torch.tensor(self.SUBTASK_CONTACT_CONDITION[subtask], dtype=torch.float32)
259 | ct_cond = ct_cond.repeat(goal.shape[0], 1).to(device)
260 | adp_goal = torch.cat([goal, ct_cond], 1)
261 | return adp_goal
262 |
263 | def get_reward_functions(self):
264 | reward_funcs = {}
265 | for subtask in self.subtask_order.keys():
266 | with self.switch_subtask(subtask):
267 | reward_funcs[subtask] = self.compute_reward
268 | return reward_funcs
269 |
270 | @property
271 | def start_subtask(self):
272 | return self._start_subtask
273 |
274 | @property
275 | def max_episode_steps(self):
276 | assert np.sum([x for x in self.SUBTASK_STEPS.values()]) == self.env._max_episode_steps
277 | return self.env._max_episode_steps
278 |
279 | @property
280 | def max_action_range(self):
281 | return self.MAX_ACTION_RANGE
282 |
283 | @property
284 | def subtask_order(self):
285 | return self.SUBTASK_ORDER
286 |
287 | @property
288 | def subtask_steps(self):
289 | return self.SUBTASK_STEPS
290 |
291 | @property
292 | def subtasks(self):
293 | subtasks = []
294 | for subtask, order in self.subtask_order.items():
295 | if order >= self.subtask_order[self.start_subtask]:
296 | subtasks.append(subtask)
297 | return subtasks
298 |
299 | @property
300 | def prev_subtasks(self):
301 | return self.SUBTASK_PREV_SUBTASK
302 |
303 | @property
304 | def next_subtasks(self):
305 | return self.SUBTASK_NEXT_SUBTASK
306 |
307 | @property
308 | def last_subtask(self):
309 | return self.LAST_SUBTASK
310 |
311 | @property
312 | def len_cond(self):
313 | return len(self.SUBTASK_CONTACT_CONDITION[self.last_subtask])
314 |
315 |
316 | class BiPegBoardSLWrapper(BiPegTransferSLWrapper):
317 | '''Wrapper for skill learning'''
318 | SUBTASK_STEPS = {
319 | 'grasp': 30,
320 | 'handover': 35,
321 | 'release': 35
322 | }
323 | SUBTASK_RESET_INDEX = {
324 | 'handover': 5,
325 | 'release': 9
326 | }
327 | SUBTASK_RESET_MAX_STEPS = {
328 | 'handover': 30,
329 | 'release': 60
330 | }
331 |
332 |
333 | class BiPegBoardSCWrapper(BiPegTransferSCWrapper, BiPegBoardSLWrapper):
334 | '''Wrapper for skill chaining'''
335 | SUBTASK_STEPS = {
336 | 'grasp': 30,
337 | 'handover': 35,
338 | 'release': 35
339 | }
340 | SUBTASK_RESET_INDEX = {
341 | 'handover': 5,
342 | 'release': 9
343 | }
344 | SUBTASK_RESET_MAX_STEPS = {
345 | 'handover': 30,
346 | 'release': 60
347 | }
348 |
349 |
350 | class MatchBoardSLWrapper(BiPegTransferSLWrapper, SkillLearningWrapper):
351 | '''Wrapper for skill learning'''
352 | SUBTASK_ORDER = {
353 | 'pull': 0,
354 | 'grasp': 1,
355 | 'release': 2,
356 | 'push': 3
357 | }
358 | SUBTASK_STEPS = {
359 | 'pull': 50,
360 | 'grasp': 30,
361 | 'release': 20,
362 | 'push': 50
363 | }
364 | SUBTASK_RESET_INDEX = {
365 | 'grasp': 5,
366 | 'release': 9,
367 | 'push': 11,
368 | }
369 | SUBTASK_RESET_MAX_STEPS = {
370 | 'grasp': 60,
371 | 'release': 90,
372 | 'push': 110
373 | }
374 | SUBTASK_PREV_SUBTASK = {
375 | 'grasp': 'pull',
376 | 'release': 'grasp',
377 | 'push': 'release'
378 | }
379 | SUBTASK_NEXT_SUBTASK = {
380 | 'pull': 'grasp',
381 | 'grasp': 'release',
382 | 'release': 'push'
383 | }
384 | SUBTASK_CONTACT_CONDITION = {
385 | 'pull': [0],
386 | 'grasp': [0],
387 | 'release': [0],
388 | 'push': [0]
389 | }
390 | LAST_SUBTASK = 'push'
391 | def __init__(self, env, subtask='grasp', output_raw_obs=False):
392 | super().__init__(env, subtask, output_raw_obs)
393 | self.col_with_lid = False
394 |
395 | def reset(self):
396 | self.col_with_lid = False
397 | return super().reset()
398 |
399 | def _replace_goal_with_subgoal(self, obs):
400 | """Replace ag and g"""
401 | subgoal = self._subgoal()
402 |
403 | # collision condition
404 | if not self.col_with_lid:
405 | self.col_with_lid = pairwise_link_collision(self.env.psm1.body, 4, self.env.obj_ids['fixed'][-1], self.env.target_row * 6 + 3) != ()
406 |
407 | if self.subtask in ['pull', 'push']:
408 | obs['achieved_goal'] = np.concatenate([obs['observation'][0: 3], obs['achieved_goal'][3:6], [int(self.col_with_lid)]])
409 | obs['desired_goal'] = np.concatenate([subgoal[0:3], subgoal[6:9], [0]])
410 | elif self.subtask in ['grasp', 'release']:
411 | obs['achieved_goal'] = np.concatenate([obs['observation'][0: 3], obs['achieved_goal'][:3], [int(self.col_with_lid)]])
412 | obs['desired_goal'] = np.concatenate([subgoal[0:3], subgoal[3:6], [0]])
413 | return obs
414 |
415 | #---------------------------Reward---------------------------
416 | def compute_reward(self, ag, g, info=None):
417 | """Compute reward that indicates the success of subtask"""
418 | if len(ag.shape) == 1:
419 | goal_reach = self.env.compute_reward(ag[:-1], g[:-1], None) + 1
420 | reward = (goal_reach and (1 - ag[-1])) - 1
421 | else:
422 | goal_reach = self.env.compute_reward(ag[:, :-1], g[:, :-1], None).reshape(-1, 1) + 1
423 | reward = np.all(np.hstack([goal_reach, 1 - ag[:, -1].reshape(-1, 1)]), axis=1) - 1.
424 | return reward
425 |
426 |
427 | class MatchBoardSCWrapper(MatchBoardSLWrapper, BiPegTransferSCWrapper):
428 | MAX_ACTION_RANGE = 4.
429 | def step(self, action):
430 | next_obs, reward, done, info = self.env.step(action)
431 | self._elapsed_steps += 1
432 | next_obs_ = self._replace_goal_with_subgoal(next_obs.copy())
433 | reward = self.compute_reward(next_obs_['achieved_goal'], next_obs_['desired_goal'])
434 | info['step'] = 1 - reward
435 | done = self._elapsed_steps == self.SUBTASK_STEPS[self.subtask]
436 | reward = done * reward
437 | info['subtask'] = self.subtask
438 | info['subtask_done'] = False
439 | info['subtask_is_success'] = reward
440 |
441 | if done:
442 | info['subtask_done'] = True
443 |             # Transition to the next subtask (if it is not terminal) and reset elapsed steps
444 | if self.subtask in self.SUBTASK_NEXT_SUBTASK.keys():
445 | done = False
446 | self._elapsed_steps = 0
447 | self.subtask = self.SUBTASK_NEXT_SUBTASK[self.subtask]
448 | info['is_success'] = False
449 | reward = 0
450 | else:
451 | info['is_success'] = reward
452 | next_obs_ = self._replace_goal_with_subgoal(next_obs)
453 |
454 | if self._output_raw_obs: return next_obs_, reward, done, info, next_obs
455 | else: return next_obs_, reward, done, info
456 |
457 | def reset(self, subtask=None):
458 | self.col_with_lid = False
459 | return BiPegTransferSCWrapper.reset(self, subtask)
460 |
461 | def compute_reward(self, ag, g, info=None):
462 | """Compute reward that indicates the success of subtask"""
463 | if len(ag.shape) == 1:
464 | goal_reach = self.env.compute_reward(ag[:-1], g[:-1], None) + 1
465 | reward = (goal_reach and (1 - ag[-1])) - 1
466 | else:
467 | goal_reach = self.env.compute_reward(ag[:, :-1], g[:, :-1], None).reshape(-1, 1) + 1
468 | reward = np.all(np.hstack([goal_reach, 1 - ag[:, -1].reshape(-1, 1)]), axis=1) - 1.
469 | return reward + 1
470 |
471 | def goal_adapator(self, goal, subtask, device=None):
472 | '''Make predicted goal compatible with wrapper'''
473 | if isinstance(goal, np.ndarray):
474 | return np.append(goal, self.SUBTASK_CONTACT_CONDITION[subtask])
475 | elif isinstance(goal, torch.Tensor):
476 | assert device is not None
477 | ct_cond = torch.tensor(self.SUBTASK_CONTACT_CONDITION[subtask], dtype=torch.float32)
478 | ct_cond = ct_cond.repeat(goal.shape[0], 1).to(device)
479 | adp_goal = torch.cat([goal, ct_cond], 1)
480 | return adp_goal
481 |
482 |
483 | #-----------------------------Make environment-----------------------------
484 | def make_env(cfg):
485 | env = gym.make(cfg.task)
486 | if cfg.task == 'BiPegTransfer-v0':
487 | if cfg.skill_chaining:
488 | env = BiPegTransferSCWrapper(env, cfg.init_subtask, output_raw_obs=False)
489 | else:
490 | env = BiPegTransferSLWrapper(env, cfg.subtask, output_raw_obs=False)
491 | elif cfg.task == 'BiPegBoard-v0':
492 | if cfg.skill_chaining:
493 | env = BiPegBoardSCWrapper(env, cfg.init_subtask, output_raw_obs=False)
494 | else:
495 | env = BiPegBoardSLWrapper(env, cfg.subtask, output_raw_obs=False)
496 | elif cfg.task == 'MatchBoard-v0':
497 | if cfg.skill_chaining:
498 | env = MatchBoardSCWrapper(env, cfg.init_subtask, output_raw_obs=False)
499 | else:
500 | env = MatchBoardSLWrapper(env, cfg.subtask, output_raw_obs=False)
501 | else:
502 | raise NotImplementedError
503 | return env
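504 | 
505 | 
506 | # Illustrative usage sketch (not part of the original module): make_env only reads
507 | # cfg.task, cfg.skill_chaining and cfg.init_subtask / cfg.subtask, so a minimal config
508 | # object suffices (AttrDict construction and the field values below are assumptions):
509 | #
510 | #   from ..utils.general_utils import AttrDict
511 | #   cfg = AttrDict(task='BiPegTransfer-v0', skill_chaining=True, init_subtask='grasp')
512 | #   env = make_env(cfg)          # -> BiPegTransferSCWrapper around the SurRoL task
513 | #   obs = env.reset()
514 | #   next_obs, reward, done, info = env.step(env.action_space.sample())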
--------------------------------------------------------------------------------