├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── assets ├── bottle_out_fridge.png ├── cup_out_open_cabinet.png ├── meat_off_grill.png ├── reach_elbow_pose.png ├── serve_coffee.png └── slide_cup.png ├── docker_run.sh ├── exp_hydra_logs └── .gitkeep ├── libvvcl.so ├── requirements.txt ├── rl ├── .gitignore ├── agent │ ├── mf_utils.py │ ├── sac.py │ └── sac.yaml ├── configs │ ├── default.yaml │ ├── rlbench_both.yaml │ ├── rlbench_pixels.yaml │ └── rlbench_states.yaml ├── env │ ├── __init__.py │ ├── custom_arm_action_modes.py │ ├── custom_rlbench_tasks │ │ ├── elbow_angle_task_design.ttt │ │ ├── task_ttms │ │ │ ├── barista.ttm │ │ │ ├── bottle_out_moving_fridge.ttm │ │ │ ├── cup_out_open_cabinet.ttm │ │ │ ├── reach_gripper_and_elbow.ttm │ │ │ └── slide_cup.ttm │ │ └── tasks │ │ │ ├── __init__.py │ │ │ ├── barista.py │ │ │ ├── bottle_out_moving_fridge.py │ │ │ ├── cup_out_open_cabinet.py │ │ │ ├── meat_off_grill.py │ │ │ ├── phone_on_base.py │ │ │ ├── pick_and_lift.py │ │ │ ├── pick_up_cup.py │ │ │ ├── put_rubbish_in_bin.py │ │ │ ├── reach_gripper_and_elbow.py │ │ │ ├── reach_target.py │ │ │ ├── slide_cup.py │ │ │ ├── stack_wine.py │ │ │ ├── take_lid_off_saucepan.py │ │ │ └── take_umbrella_out_of_umbrella_stand.py │ └── rlbench_envs.py ├── envs.py ├── hydra │ └── hydra_logging │ │ └── custom.yaml ├── logger.py ├── np_replay.py ├── train.py ├── train.yaml └── utils.py └── xvfb_run.sh /.gitignore: -------------------------------------------------------------------------------- 1 | exp_sweep 2 | exp_local 3 | __pycache__ 4 | *.mp4 5 | *.jpeg 6 | datasets/* 7 | pt_models 8 | nohup* 9 | data 10 | output.png 11 | exp_hydra_logs/2* 12 | .ipynb_checkpoints/ 13 | .git/* 14 | rw_demos/* 15 | rw_models/* -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.2.2-base-ubuntu20.04 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | 5 | # controls which driver libraries/binaries will be mounted inside the container 6 | # docs: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html 7 | # This can be specifed at runtime or at image build time. Here we do it at build time. 8 | # Graphics capability is required to be specified for rendering. 
9 | ENV NVIDIA_DRIVER_CAPABILITIES=compute,graphics,utility 10 | 11 | ARG uid 12 | ARG user 13 | 14 | RUN \ 15 | apt-get update && \ 16 | apt-get install -y \ 17 | sudo \ 18 | python3-pip \ 19 | git \ 20 | zsh \ 21 | curl \ 22 | wget \ 23 | unzip \ 24 | tmux \ 25 | vim \ 26 | mesa-utils \ 27 | xvfb \ 28 | qtbase5-dev \ 29 | qtdeclarative5-dev \ 30 | libqt5webkit5-dev \ 31 | libsqlite3-dev \ 32 | qt5-default \ 33 | qttools5-dev-tools 34 | 35 | RUN \ 36 | useradd -u ${uid} ${user} && \ 37 | echo "${user} ALL=(root) NOPASSWD:ALL" > /etc/sudoers.d/${user} && \ 38 | chmod 0440 /etc/sudoers.d/${user} && \ 39 | mkdir -p /home/${user} && \ 40 | chown -R ${user}:${user} /home/${user} && \ 41 | chown ${user}:${user} /usr/local/bin && \ 42 | mkdir /tmp/.X11-unix && \ 43 | chmod 1777 /tmp/.X11-unix && \ 44 | chown root /tmp/.X11-unix 45 | 46 | USER ${user} 47 | 48 | WORKDIR /home/${user} 49 | 50 | WORKDIR /home/${user} 51 | 52 | 53 | RUN \ 54 | cur=`pwd` && \ 55 | wget http://www.coppeliarobotics.com/files/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz && \ 56 | tar -xf CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz && \ 57 | export COPPELIASIM_ROOT="$cur/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04" && \ 58 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT:$COPPELIASIM_ROOT/platforms && \ 59 | export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT && \ 60 | git clone https://github.com/stepjam/PyRep.git && \ 61 | cd PyRep && \ 62 | pip3 install -r requirements.txt && \ 63 | pip3 install setuptools && \ 64 | pip3 install . 65 | 66 | RUN \ 67 | git clone https://github.com/stepjam/RLBench.git && cd RLBench && \ 68 | pip install -r requirements.txt && \ 69 | pip install . 70 | 71 | RUN \ 72 | mkdir -p ~/.config/fish 73 | 74 | RUN \ 75 | export cur=`pwd` && echo "set -x COPPELIASIM_ROOT $cur/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04" >> ~/.config/fish/config.fish && \ 76 | echo 'set -x LD_LIBRARY_PATH $LD_LIBRARY_PATH $COPPELIASIM_ROOT $COPPELIASIM_ROOT/platforms' >> ~/.config/fish/config.fish && \ 77 | echo 'set -x QT_QPA_PLATFORM_PLUGIN_PATH $COPPELIASIM_ROOT' >> ~/.config/fish/config.fish 78 | 79 | RUN sudo apt-get install -y fish 80 | 81 | RUN pip3 install torch==2.0.1 hydra-core scipy shapely trimesh pyrender wandb==0.15.4 timm 82 | 83 | # install VS Code (code-server) 84 | RUN curl -fsSL https://code-server.dev/install.sh | sh 85 | RUN code-server --install-extension ms-python.python ms-toolsai.jupyter 86 | 87 | # install VS Code extensions 88 | RUN sudo apt-get install wget gpg 89 | RUN wget -qO- https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor > packages.microsoft.gpg 90 | RUN sudo install -D -o root -g root -m 644 packages.microsoft.gpg /etc/apt/keyrings/packages.microsoft.gpg 91 | RUN sudo sh -c 'echo "deb [arch=amd64,arm64,armhf signed-by=/etc/apt/keyrings/packages.microsoft.gpg] https://packages.microsoft.com/repos/code stable main" > /etc/apt/sources.list.d/vscode.list' 92 | RUN rm -f packages.microsoft.gpg 93 | RUN sudo apt-get update 94 | RUN sudo apt-get install -y code 95 | # RUN code --install-extension ms-python.python ms-toolsai.jupyter 96 | 97 | # Additional packages 98 | RUN sudo apt-get install libglew2.1 libgl1-mesa-glx libosmesa6 99 | RUN pip3 install gym termcolor hydra-submitit-launcher PyOpenGL==3.1.4 PyOpenGL_accelerate notebook matplotlib 100 | RUN pip3 install --upgrade requests 101 | 102 | COPY libvvcl.so CoppeliaSim_Edu_V4_1_0_Ubuntu20_04 103 | RUN pip3 install tensorboard imageio[ffmpeg] hydra-joblib-launcher moviepy 104 | 105 | RUN echo 'set -x PATH $PATH 
$HOME/.local/bin' >> ~/.config/fish/config.fish 106 | 107 | RUN \ 108 | sudo apt-get update && sudo apt-get install -y \ 109 | ffmpeg git python3-pip vim libglew-dev \ 110 | x11-xserver-utils xvfb \ 111 | && sudo apt-get clean 112 | 113 | RUN pip3 install einops dm_env 114 | 115 | ENTRYPOINT ["/usr/bin/fish"] 116 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Pietro Mazzaglia 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Code for Redundancy-aware Action Spaces for Robot Learning 2 | 3 | 4 |

5 | 6 | 7 | 8 |

9 |

10 | 11 | 12 | 13 |

14 | 
15 | ## RL experiments
16 | 
17 | ### How to run
18 | 
19 | For running the code, using Docker is recommended.
20 | 
21 | To start the container, run:
22 | 
23 | `bash docker_run.sh`
24 | 
25 | To start an experiment, launch a command inside the Docker container, such as:
26 | 
27 | `sh xvfb_run.sh python3 rl/train.py use_wandb=False project_name=er agent=sac configs=rlbench_both task=rlbench_stack_wine env.action_mode=erjoint seed=1,2,3 experiment=test --multirun &`
28 | 
29 | The command above launches 3 seeds in parallel. To run a single experiment, set `seed=SEED` to a single value and remove `--multirun`.
30 | 
31 | Note that, by default, all logs are redirected to `out.log` and `err.log` in the experiment folder, so no output will appear in your console.
32 | 
33 | ### Options
34 | 
35 | The possible action modes are: `env.action_mode=joint,ee,erjoint,erangle`
36 | 
37 | You can change the ERJ joint to the wrist one by setting: `+env.erj_joint=6`
38 | 
39 | ### Notes
40 | 
41 | * The implementation of SAC (with a distributional critic) can be found in `rl/agent/sac.py`; the hyperparameters used are in the `sac.yaml` file.
42 | * The new RLBench tasks can be found in `rl/env/custom_rlbench_tasks/`
43 | * The new action modes can be found in `rl/env/custom_arm_action_modes.py`
44 | 
45 | ## Imitation learning experiments
46 | 
47 | For the real-world experiments, we used ACT: https://github.com/tonyzhaozh/act
48 | 
49 | Some additional notes about the setup:
50 | * Each action space model was trained for 3000 epochs using the hyper-parameters presented in the ACT paper (apart from using a batch size of 16 and a chunk size of 20).
51 | * Additionally, the EE pose was added to the state information.
52 | * All quaternions in the demonstrations and at inference were forced to have a positive `w` (see the sketch below). Lastly, only a wrist camera (with a resolution of 224x224) was used.
53 | 
54 | For the IK, we used [pick\_ik](https://github.com/PickNikRobotics/pick_ik) with default parameters to solve the IK for both EE and ERJ.
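As a concrete illustration of the positive-`w` convention mentioned in the notes above, here is a minimal sketch of that normalization. It is illustrative only: the function name and the `xyzw` ordering (the convention used by `scipy.spatial.transform.Rotation`) are our assumptions, not code from this repository.

```python
import numpy as np

def enforce_positive_w(quat_xyzw):
    """Return the equivalent unit quaternion (x, y, z, w) with w >= 0.

    q and -q encode the same rotation, so flipping the sign whenever w < 0
    removes the double-cover ambiguity before training and at inference.
    """
    q = np.asarray(quat_xyzw, dtype=np.float64)
    q = q / np.linalg.norm(q)  # make sure the quaternion has unit norm
    return -q if q[3] < 0.0 else q

# Example: a negative-w quaternion is flipped to its positive-w equivalent.
print(enforce_positive_w([0.0, 0.0, 0.7071, -0.7071]))
```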
55 | 56 | Some additional notes: 57 | * We followed the standard [pick\_ik Kinematics Solver" tutorial](https://moveit.picknik.ai/main/doc/how_to_guides/pick_ik/pick_ik_tutorial.html\#pick-ik-kinematics-solver) from the pick\_ik documentation 58 | * We only altered two parameters for the IK, from the standard parameters: `approximate_solution_position_threshold=0.01` and `approximate_solution_orientation_threshold=0.01` (both from the original value 0.05) to increase accuracy 59 | 60 | 61 | -------------------------------------------------------------------------------- /assets/bottle_out_fridge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/bottle_out_fridge.png -------------------------------------------------------------------------------- /assets/cup_out_open_cabinet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/cup_out_open_cabinet.png -------------------------------------------------------------------------------- /assets/meat_off_grill.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/meat_off_grill.png -------------------------------------------------------------------------------- /assets/reach_elbow_pose.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/reach_elbow_pose.png -------------------------------------------------------------------------------- /assets/serve_coffee.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/serve_coffee.png -------------------------------------------------------------------------------- /assets/slide_cup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/slide_cup.png -------------------------------------------------------------------------------- /docker_run.sh: -------------------------------------------------------------------------------- 1 | docker run --rm --gpus=all --net=host --privileged -e=DISPLAY -e=XDG_RUNTIME_DIR --shm-size=2gb \ 2 | -v `pwd`:${HOME}/repos/ERAS \ 3 | -v ${HOME}/.ssh:${HOME}/.ssh \ 4 | -v $XDG_RUNTIME_DIR:$XDG_RUNTIME_DIR \ 5 | -v ${XAUTHORITY}:${XAUTHORITY} \ 6 | -v /usr/local/share/ca-certificates:/usr/local/share/ca-certificates \ 7 | -v /etc/ssl/certs/ca-certificates.crt:/etc/ssl/certs/ca-certificates.crt \ 8 | -w ${HOME}/repos/ERAS \ 9 | -it $(docker build -q --build-arg uid=$(id -u ${USER}) --build-arg user=${USER} -t "local/robot_learning/er:latest" .) 
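# Note (added comment): the image is built inline via `docker build -q`, passing your
# uid and username as build args (see the Dockerfile), the repository is mounted at
# ~/repos/ERAS inside the container, and the container starts in the fish shell
# defined as the Dockerfile's entrypoint.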
-------------------------------------------------------------------------------- /exp_hydra_logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/exp_hydra_logs/.gitkeep -------------------------------------------------------------------------------- /libvvcl.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/libvvcl.so -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core 2 | scipy 3 | shapely 4 | trimesh 5 | pyrender 6 | notebook -------------------------------------------------------------------------------- /rl/.gitignore: -------------------------------------------------------------------------------- 1 | ManiSkill2-Learn 2 | 3 | __pycache__ 4 | exp_local 5 | exp_sweep 6 | pretrained_models/* 7 | pt_models 8 | nohup* 9 | data 10 | output.png 11 | exp_hydra_logs/* 12 | datasets/* -------------------------------------------------------------------------------- /rl/agent/mf_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import utils 6 | 7 | class DrQEncoder(nn.Module): 8 | def __init__(self, obs_shape, key='observation', is_rgb=True): 9 | super().__init__() 10 | 11 | self._key = key 12 | self.is_rgb = is_rgb 13 | assert len(obs_shape) == 3 14 | self.repr_dim = 32 * 35 * 35 15 | 16 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2), 17 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 18 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 19 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1), 20 | nn.ReLU()) 21 | 22 | self.apply(utils.weight_init) 23 | 24 | def forward(self, obs): 25 | h = self.convnet(obs[self._key]) 26 | h = h.view(h.shape[0], -1) 27 | return h 28 | 29 | def preprocess(self, obs): 30 | if self.is_rgb: 31 | obs[self._key] = obs[self._key] / 255.0 - 0.5 32 | return obs 33 | 34 | class IdentityEncoder(nn.Identity): 35 | def __init__(self, key): 36 | super().__init__() 37 | self._key = key 38 | self.fake_param = nn.parameter.Parameter() 39 | 40 | def forward(self, obs): 41 | return obs[self._key] 42 | 43 | def preprocess(self, obs): 44 | return obs 45 | 46 | class Actor(nn.Module): 47 | def __init__(self, enc_dim, action_dim, hidden_dim, feature_dim): 48 | super().__init__() 49 | 50 | self.trunk = nn.Sequential(nn.Linear(enc_dim, feature_dim), 51 | nn.LayerNorm(feature_dim), nn.Tanh(), 52 | ) 53 | 54 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim), 55 | nn.ReLU(inplace=True), 56 | nn.Linear(hidden_dim, hidden_dim), 57 | nn.ReLU(inplace=True), 58 | nn.Linear(hidden_dim, action_dim * 2)) 59 | 60 | self.log_std_bounds = [-10, 2] 61 | self.apply(utils.weight_init) 62 | 63 | def forward(self, enc,): 64 | enc = self.trunk(enc) 65 | mu, log_std = self.policy(enc).chunk(2, dim=-1) 66 | self._mu_std = mu.std().item() 67 | 68 | # constrain log_std inside [log_std_min, log_std_max] 69 | log_std = torch.tanh(log_std) 70 | log_std_min, log_std_max = self.log_std_bounds 71 | log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1) 72 | std = log_std.exp() 73 | 74 | dist = utils.SquashedNormal(mu, std) 75 | return 
dist 76 | 77 | class ActorFixedStd(nn.Module): 78 | def __init__(self, enc_dim, action_dim, hidden_dim, feature_dim): 79 | super().__init__() 80 | self.trunk = nn.Sequential(nn.Linear(enc_dim, feature_dim), 81 | nn.LayerNorm(feature_dim), nn.Tanh(), 82 | ) 83 | 84 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim), 85 | nn.ReLU(inplace=True), 86 | nn.Linear(hidden_dim, hidden_dim), 87 | nn.ReLU(inplace=True), 88 | nn.Linear(hidden_dim, action_dim)) 89 | 90 | self.apply(utils.weight_init) 91 | 92 | def forward(self, enc, std,): 93 | enc = self.trunk(enc) 94 | mu = self.policy(enc) 95 | self._mu_std = mu.std().item() 96 | mu = torch.tanh(mu) 97 | std = torch.ones_like(mu) * std 98 | 99 | dist = utils.TruncatedNormal(mu, std) 100 | return dist 101 | 102 | def signed_hyperbolic(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: 103 | """Signed hyperbolic transform, inverse of signed_parabolic.""" 104 | return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x 105 | 106 | def signed_parabolic(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: 107 | """Signed parabolic transform, inverse of signed_hyperbolic.""" 108 | z = torch.sqrt(1 + 4 * eps * (eps + 1 + torch.abs(x))) / 2 / eps - 1 / 2 / eps 109 | return torch.sign(x) * (torch.square(z) - 1) 110 | 111 | def from_categorical(distribution, limit=20, offset=0., logits=True): 112 | distribution = distribution.float().squeeze(-1) # Avoid any fp16 shenanigans 113 | if logits: 114 | distribution = torch.softmax(distribution, -1) 115 | num_atoms = distribution.shape[-1] 116 | shift = limit * 2 / (num_atoms - 1) 117 | weights = torch.linspace(-(num_atoms//2), num_atoms//2, num_atoms, device=distribution.device).float().unsqueeze(-1) 118 | return signed_parabolic((distribution @ weights) * shift) - offset 119 | 120 | def to_categorical(value, limit=20, offset=0., num_atoms=251): 121 | value = value.float() + offset # Avoid any fp16 shenanigans 122 | shift = limit * 2 / (num_atoms - 1) 123 | value = signed_hyperbolic(value) / shift 124 | value = value.clamp(-(num_atoms//2), num_atoms//2) 125 | distribution = torch.zeros(value.shape[0], num_atoms, 1, device=value.device) 126 | lower = value.floor().long() + num_atoms // 2 127 | upper = value.ceil().long() + num_atoms // 2 128 | upper_weight = value % 1 129 | lower_weight = 1 - upper_weight 130 | distribution.scatter_add_(-2, lower.unsqueeze(-1), lower_weight.unsqueeze(-1)) 131 | distribution.scatter_add_(-2, upper.unsqueeze(-1), upper_weight.unsqueeze(-1)) 132 | return distribution 133 | 134 | class Critic(nn.Module): 135 | def __init__(self, enc_dim, action_dim, hidden_dim, feature_dim): 136 | super().__init__() 137 | 138 | self.trunk = nn.Sequential(nn.Linear(enc_dim, feature_dim), 139 | nn.LayerNorm(feature_dim), nn.Tanh(), 140 | ) 141 | 142 | self.q1_net = nn.Sequential( 143 | nn.Linear(feature_dim + action_dim, hidden_dim), nn.ReLU(inplace=True), 144 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True), 145 | nn.Linear(hidden_dim, 1)) 146 | 147 | self.q2_net = nn.Sequential( 148 | nn.Linear(feature_dim + action_dim, hidden_dim), nn.ReLU(inplace=True), 149 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True), 150 | nn.Linear(hidden_dim, 1)) 151 | 152 | self.apply(utils.weight_init) 153 | 154 | def forward(self, enc, action): 155 | enc = self.trunk(enc) 156 | obs_action = torch.cat([enc, action], dim=-1) 157 | q1 = self.q1_net(obs_action) 158 | q2 = self.q2_net(obs_action) 159 | return q1, q2 160 | 161 | class DistributionalCritic(nn.Module): 162 | def 
__init__(self, enc_dim, action_dim, hidden_dim, feature_dim, num_atoms=251): 163 | super().__init__() 164 | self.distributional = True 165 | 166 | self.trunk = nn.Sequential(nn.Linear(enc_dim, feature_dim), 167 | nn.LayerNorm(feature_dim), nn.Tanh(), 168 | ) 169 | 170 | self.q1_net = nn.Sequential( 171 | nn.Linear(feature_dim + action_dim, hidden_dim), nn.ReLU(inplace=True), 172 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True), 173 | nn.Linear(hidden_dim, num_atoms)) 174 | 175 | self.q2_net = nn.Sequential( 176 | nn.Linear(feature_dim + action_dim, hidden_dim), nn.ReLU(inplace=True), 177 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True), 178 | nn.Linear(hidden_dim, num_atoms)) 179 | 180 | self.apply(utils.weight_init) 181 | 182 | def forward(self, enc, action): 183 | enc = self.trunk(enc) 184 | obs_action = torch.cat([enc, action], dim=-1) 185 | q1 = self.q1_net(obs_action) 186 | q2 = self.q2_net(obs_action) 187 | return q1, q2 -------------------------------------------------------------------------------- /rl/agent/sac.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import OrderedDict, defaultdict 7 | 8 | import utils 9 | from utils import StreamNorm 10 | from agent.mf_utils import * 11 | 12 | class SACAgent: 13 | def __init__(self, 14 | name, 15 | cfg, 16 | obs_space, 17 | act_space, 18 | device, 19 | lr, 20 | hidden_dim, 21 | feature_dim, 22 | critic_target_tau, 23 | action_target_entropy, 24 | init_temperature, 25 | policy_delay, 26 | frame_stack, 27 | distributional, 28 | normalize_reward, 29 | normalize_returns, 30 | obs_keys, 31 | drq_encoder, 32 | drq_aug): 33 | self.cfg = cfg 34 | self.act_space = act_space 35 | self.obs_space = obs_space 36 | self.action_dim = np.sum((np.prod(v.shape) for v in act_space.values())) # prev: action_shape[0] 37 | self.hidden_dim = hidden_dim 38 | self.lr = lr 39 | self.device = device 40 | self.critic_target_tau = critic_target_tau 41 | self.policy_delay = policy_delay 42 | self.obs_keys = obs_keys.split('|') 43 | shapes = {} 44 | for k,v in obs_space.items(): 45 | shapes[k] = list(v.shape) 46 | if len(v.shape) == 3: 47 | shapes[k][0] = shapes[k][0] * frame_stack 48 | 49 | self.frame_stack = frame_stack 50 | self.obs_buffer = defaultdict(list) 51 | self._batch_reward = 0 52 | 53 | # models 54 | self.encoders = {} 55 | self.augs = {} 56 | for k in self.obs_keys: 57 | if len(shapes[k]) == 3: 58 | img_size = shapes[k][-1] 59 | pad = img_size // 21 # pad=4 for 84 60 | self.augs[k] = utils.RandomShiftsAug(pad=pad) if drq_aug else nn.Identity() 61 | if drq_encoder: 62 | self.encoders[k] = DrQEncoder(shapes[k], key=k, is_rgb=obs_space[k].shape[0] == 3).to(self.device) 63 | else: 64 | raise NotImplementedError("") 65 | else: 66 | self.augs[k] = nn.Identity() 67 | self.encoders[k] = IdentityEncoder(k) 68 | self.encoders[k].repr_dim = shapes[k][0] 69 | self.encoders = nn.ModuleDict(self.encoders) 70 | self.enc_repr_dim = sum(e.repr_dim for e in self.encoders.values()) 71 | 72 | self.actor = Actor(self.enc_repr_dim, self.action_dim, 73 | hidden_dim, feature_dim).to(device) 74 | 75 | if distributional: 76 | self.critic = DistributionalCritic(self.enc_repr_dim, self.action_dim, 77 | hidden_dim, feature_dim).to(device) 78 | self.critic_target = DistributionalCritic(self.enc_repr_dim, self.action_dim, 79 | hidden_dim, feature_dim).to(device) 80 | else: 81 | self.critic = 
Critic(self.enc_repr_dim, self.action_dim, 82 | hidden_dim, feature_dim).to(device) 83 | self.critic_target = Critic(self.enc_repr_dim, self.action_dim, 84 | hidden_dim, feature_dim).to(device) 85 | self.critic_target.load_state_dict(self.critic.state_dict()) 86 | 87 | self.log_alpha = torch.tensor(np.log(init_temperature)).to(device) 88 | self.log_alpha.requires_grad = True 89 | if action_target_entropy == 'neg': 90 | # set target entropy to -|A| 91 | self.target_entropy = -self.action_dim 92 | elif action_target_entropy == 'neg_double': 93 | self.target_entropy = -self.action_dim * 2 94 | elif action_target_entropy == 'neglog': 95 | self.target_entropy = -torch.Tensor([self.action_dim]).to(self.device) 96 | elif action_target_entropy == 'zero': 97 | self.target_entropy = 0 98 | else: 99 | self.target_entropy = action_target_entropy 100 | 101 | # optimizers 102 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr) 103 | self.encoder_opt = torch.optim.Adam(self.encoders.parameters(), lr=lr) 104 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr) 105 | self.log_alpha_opt = torch.optim.Adam([self.log_alpha], lr=lr) 106 | 107 | self.normalize_reward = normalize_reward 108 | self.normalize_returns = normalize_returns 109 | self.rewnorm = StreamNorm(**{"momentum": 0.99, "scale": 1.0, "eps": 1e-8}, device=self.device) 110 | self.retnorm = StreamNorm(**{"momentum": 0.99, "scale": 1.0, "eps": 1e-8}, device=self.device) 111 | 112 | self.train() 113 | self.critic_target.train() 114 | 115 | def init_meta(self): 116 | return {} 117 | 118 | def get_meta_specs(self): 119 | return {} 120 | 121 | def update_meta(self, meta, global_step, time_step): 122 | return self.init_meta() 123 | 124 | def train(self, training=True): 125 | self.training = training 126 | self.encoders.train(training) 127 | self.actor.train(training) 128 | self.critic.train(training) 129 | 130 | @property 131 | def alpha(self): 132 | return self.log_alpha.exp() 133 | 134 | @torch.no_grad() 135 | def act(self, obs, meta, step, eval_mode, state,): 136 | is_first = obs['is_first'].all() or len(self.obs_buffer[self.obs_keys[0]]) == 0 137 | 138 | for k in self.obs_keys: 139 | obs[k] = torch.as_tensor(np.copy(obs[k]), device=self.device) 140 | if is_first: 141 | self.obs_buffer[k] = [obs[k]] * self.frame_stack 142 | else: 143 | self.obs_buffer[k].pop(0) 144 | self.obs_buffer[k].append(obs[k]) 145 | obs_ch = obs[k].shape[1] 146 | obs_size = obs[k].shape[2:] 147 | obs[k] = torch.stack(self.obs_buffer[k], dim=1).reshape(-1, obs_ch * self.frame_stack, *obs_size) 148 | 149 | obs = torch.cat([ e(e.preprocess(obs)) for e in self.encoders.values()], dim=-1) 150 | 151 | policy = self.actor(obs,) 152 | if eval_mode: 153 | action = policy.mean 154 | else: 155 | action = policy.sample() 156 | if step < (self.cfg.num_seed_frames // self.cfg.action_repeat): 157 | action.uniform_(-1.0, 1.0) 158 | action = action.clamp(-1.0, 1.0) 159 | # @returns: action, state 160 | return action.cpu().numpy(), None 161 | 162 | def update_critic(self, obs, action, reward, discount, next_obs, step): 163 | metrics = dict() 164 | 165 | with torch.no_grad(): 166 | dist = self.actor(next_obs) 167 | next_action = dist.rsample() 168 | log_prob = dist.log_prob(next_action).sum(-1, keepdim=True) 169 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 170 | if getattr(self.critic, 'distributional', False): 171 | target_Q1, target_Q2 = from_categorical(target_Q1), from_categorical(target_Q2) 172 | target_V = torch.min(target_Q1, 
target_Q2) - self.alpha.detach() * log_prob 173 | mag_norm_return, batch_return_metrics = self.retnorm((discount * target_V).clone()) 174 | if self.normalize_returns: 175 | ret_mean, ret_var, ret_std = self.retnorm.corrected_mean_var_std() 176 | target_V = target_V / ret_std 177 | target_Q = reward + (discount * target_V) 178 | if getattr(self.critic, 'distributional', False): 179 | target_Q_dist = to_categorical(target_Q,) 180 | 181 | Q1, Q2 = self.critic(obs, action) 182 | if getattr(self.critic, 'distributional', False): 183 | critic_loss = - torch.mean(torch.sum(torch.log_softmax(Q1, -1) * target_Q_dist.squeeze(-1).detach(), -1)) - torch.mean(torch.sum(torch.log_softmax(Q2, -1) * target_Q_dist.squeeze(-1).detach(), -1)) 184 | Q1, Q2 = from_categorical(Q1), from_categorical(Q2) 185 | else: 186 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q) 187 | 188 | metrics['critic_target_q'] = target_Q.mean().item() 189 | metrics['critic_target_q_max'] = target_Q.max().item() 190 | metrics['critic_target_q_min'] = target_Q.min().item() 191 | if self.normalize_returns: 192 | metrics['critic_target_v_running_std'] = ret_std.item() 193 | metrics['critic_q1'] = Q1.mean().item() 194 | metrics['critic_q2'] = Q2.mean().item() 195 | metrics['critic_loss'] = critic_loss.item() 196 | 197 | # optimize encoder and critic 198 | self.critic_opt.zero_grad(set_to_none=True) 199 | self.encoder_opt.zero_grad(set_to_none=True) 200 | 201 | critic_loss.backward() 202 | 203 | self.critic_opt.step() 204 | self.encoder_opt.step() 205 | 206 | return metrics 207 | 208 | def update_actor_and_alpha(self, obs, action, step): 209 | metrics = dict() 210 | 211 | policy = self.actor(obs) 212 | action = policy.rsample() 213 | log_prob = policy.log_prob(action).sum(-1, keepdim=True) 214 | 215 | Q1, Q2 = self.critic(obs, action) 216 | if getattr(self.critic, 'distributional', False): 217 | Q1, Q2 = from_categorical(Q1), from_categorical(Q2) 218 | Q = torch.min(Q1, Q2) 219 | 220 | actor_loss = (self.alpha.detach() * log_prob -Q).mean() 221 | 222 | # optimize actor 223 | self.actor_opt.zero_grad(set_to_none=True) 224 | actor_loss.backward() 225 | self.actor_opt.step() 226 | 227 | alpha_loss = (self.alpha * 228 | (-log_prob - self.target_entropy).detach()).mean() 229 | 230 | self.log_alpha_opt.zero_grad(set_to_none=True) 231 | alpha_loss.backward() 232 | self.log_alpha_opt.step() 233 | 234 | metrics['actor_loss'] = actor_loss.item() 235 | metrics['alpha_loss'] = alpha_loss.item() 236 | metrics['actor_mean_stddev'] = self.actor._mu_std 237 | metrics['alpha'] = self.alpha.item() 238 | metrics['actor_ent'] = -log_prob.mean().item() 239 | 240 | policy_ent_per_dim = policy.base_dist.entropy().mean(dim=0) 241 | for ai in range(action.shape[-1]): 242 | metrics[f'policy_dist/dim_{ai}'] = policy_ent_per_dim[ai] 243 | 244 | return metrics 245 | 246 | def update(self, batch, step): 247 | metrics = dict() 248 | 249 | obs, next_obs = {}, {} 250 | for k in self.obs_keys: 251 | b, t = batch[k].shape[:2] 252 | # assert t == (self.frame_stack + 1) 253 | obs_ch = batch[k].shape[2] 254 | obs_size = batch[k].shape[3:] 255 | if len(obs_size) == 2: 256 | obs[k] = batch[k][:, 0:self.frame_stack].reshape(b, obs_ch * self.frame_stack, *obs_size) 257 | next_obs[k] = batch[k][:, 1:self.frame_stack+1].reshape(b, obs_ch * self.frame_stack, *obs_size) 258 | else: 259 | obs[k] = batch[k][:, self.frame_stack-1].reshape(b, obs_ch, *obs_size) 260 | next_obs[k] = batch[k][:, self.frame_stack].reshape(b, obs_ch, *obs_size) 261 | 262 | obs[k] = 
self.augs[k](obs[k].float()).to(self.device) 263 | next_obs[k] = self.augs[k](next_obs[k].float()).to(self.device) 264 | 265 | action = batch['action'][:, self.frame_stack].to(self.device) 266 | reward = batch['reward'][:, self.frame_stack].to(self.device) 267 | discount = (batch['discount'][:, self.frame_stack] * self.cfg.discount).to(self.device) 268 | 269 | mag_norm_reward, batch_rew_metrics = self.rewnorm(reward.clone()) 270 | if self.normalize_reward: 271 | rw_mean, rw_var, rw_std = self.rewnorm.corrected_mean_var_std() 272 | reward = reward / rw_std 273 | 274 | obs = torch.cat([e(e.preprocess(obs)) for e in self.encoders.values()], dim=-1) 275 | with torch.no_grad(): 276 | next_obs = torch.cat([e(e.preprocess(next_obs)) for e in self.encoders.values()], dim=-1) 277 | 278 | if self.normalize_reward: 279 | metrics['reward_running_mean'] = rw_mean.item() 280 | metrics['reward_running_std'] = rw_std.item() 281 | 282 | # update critic 283 | metrics.update( 284 | self.update_critic(obs, action, reward, discount, next_obs, step)) 285 | 286 | if step % self.policy_delay == 0: 287 | # update actor 288 | metrics.update(self.update_actor_and_alpha(obs.detach(), action, step)) 289 | 290 | # update critic target 291 | utils.soft_update_params(self.critic, self.critic_target, 292 | self.critic_target_tau) 293 | # @returns: state, metrics 294 | return None, metrics -------------------------------------------------------------------------------- /rl/agent/sac.yaml: -------------------------------------------------------------------------------- 1 | # @package agent 2 | _target_: agent.sac.SACAgent 3 | name: sac 4 | obs_space: ??? # to be specified later 5 | act_space: ??? # to be specified later 6 | device: ${device} 7 | lr: 3e-4 # 1e-4 in ExORL 8 | critic_target_tau: 0.01 # 0.005 in SpinningUp # 0.01 in EXORL 9 | hidden_dim: 1024 10 | feature_dim: 50 11 | # entropy 12 | init_temperature: 0.1 # 0.1 was default 13 | action_target_entropy: neg # neg was default 14 | # 15 | policy_delay: 1 16 | frame_stack: 1 # 3 for DMC pixels 17 | obs_keys: front_rgb|wrist_rgb|state # default for pixels 18 | drq_encoder: true 19 | drq_aug: true 20 | # normalization 21 | distributional: true 22 | normalize_reward: true 23 | normalize_returns: true 24 | -------------------------------------------------------------------------------- /rl/configs/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: False} 3 | discount: 0.99 4 | 5 | # RLBench defaults 6 | train_every_actions: 8 7 | parallel_envs: 12 8 | async_mode: 'FULL' 9 | batch_size: 128 10 | batch_length: 2 -------------------------------------------------------------------------------- /rl/configs/rlbench_both.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | obs_type: both 3 | action_repeat: 1 4 | encoder: {mlp_keys: '$^', cnn_keys: 'front_rgb', norm: none, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]} 5 | decoder: {mlp_keys: '$^', cnn_keys: 'front_rgb', norm: none, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]} 6 | replay.capacity: 2e6 7 | 8 | env: 9 | action_mode: ee 10 | cameras: front|wrist 11 | state_info: gripper_open|gripper_pose|joint_cart_pos|joint_positions -------------------------------------------------------------------------------- /rl/configs/rlbench_pixels.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | obs_type: pixels 3 | action_repeat: 1 4 | encoder: {mlp_keys: '$^', cnn_keys: 'front_rgb', norm: none, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]} 5 | decoder: {mlp_keys: '$^', cnn_keys: 'front_rgb', norm: none, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]} 6 | replay.capacity: 2e6 7 | 8 | env: 9 | action_mode: ee -------------------------------------------------------------------------------- /rl/configs/rlbench_states.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | obs_type: states 3 | action_repeat: 1 4 | encoder: {mlp_keys: 'state', cnn_keys: '$^', norm: layer, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]} 5 | decoder: {mlp_keys: 'state', cnn_keys: '$^', norm: layer, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]} 6 | replay.capacity: 2e6 7 | 8 | env: 9 | action_mode: ee -------------------------------------------------------------------------------- /rl/env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/__init__.py -------------------------------------------------------------------------------- /rl/env/custom_arm_action_modes.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import List, Union 3 | from pyrep.objects import Object 4 | 5 | import numpy as np 6 | from enum import Enum 7 | from pyquaternion import Quaternion 8 | from scipy.spatial.transform import Rotation 9 | from pyrep.const import ConfigurationPathAlgorithms as Algos, ObjectType 10 | from pyrep.errors import ConfigurationPathError, IKError, ConfigurationError 11 | from pyrep.robots.configuration_paths.arm_configuration_path import ArmConfigurationPath 12 | from pyrep.const import PYREP_SCRIPT_TYPE 13 | 14 | from rlbench.backend.exceptions import InvalidActionError 15 | from rlbench.backend.robot import Robot 16 | from rlbench.backend.scene import Scene 17 | from rlbench.const import SUPPORTED_ROBOTS 18 | 19 | from rlbench.action_modes.arm_action_modes import ArmActionMode, EndEffectorPoseViaIK 20 | from pyrep.backend import sim, utils 21 | from pyrep.objects.dummy import Dummy 22 | 23 | def assert_action_shape(action: np.ndarray, expected_shape: tuple): 24 | if np.shape(action) != expected_shape: 25 | raise InvalidActionError( 26 | 'Expected the action shape to be: %s, but was shape: %s' % ( 27 | str(expected_shape), str(np.shape(action)))) 28 | 29 | 30 | def assert_unit_quaternion(quat): 31 | if not np.isclose(np.linalg.norm(quat), 1.0): 32 | raise InvalidActionError('Action contained non unit quaternion!') 33 | 34 | 35 | def calculate_delta_pose(robot: Robot, action: np.ndarray): 36 | a_x, a_y, a_z, a_qx, a_qy, a_qz, a_qw = action 37 | x, y, z, qx, qy, qz, qw = robot.arm.get_tip().get_pose() 38 | new_rot = Quaternion( 39 | a_qw, a_qx, a_qy, a_qz) * Quaternion(qw, qx, qy, qz) 40 | qw, qx, qy, qz = list(new_rot) 41 | pose = [a_x + x, a_y + y, a_z + z] + [qx, qy, qz, qw] 42 | return pose 43 | 44 | def calculate_delta_pose_elbow(robot: Robot, action: np.ndarray): 45 | a_x, a_y, a_z, ae_x, ae_y, ae_z, a_qx, a_qy, a_qz, a_qw = action 46 | x, y, z, qx, qy, qz, qw = 
robot.arm.get_tip().get_pose() 47 | ex, ey, ez = robot.arm._ik_elbow.get_position() 48 | new_rot = Quaternion( 49 | a_qw, a_qx, a_qy, a_qz) * Quaternion(qw, qx, qy, qz) 50 | qw, qx, qy, qz = list(new_rot) 51 | new_action = [a_x + x, a_y + y, a_z + z] + [ae_x + ex, ae_y + ey, ae_z + ez] + [qx, qy, qz, qw] 52 | return new_action 53 | 54 | class RelativeFrame(Enum): 55 | WORLD = 0 56 | EE = 1 57 | 58 | class EEOrientationState(Enum): 59 | FREE = 0 60 | FIXED = 1 61 | KEEP = 2 62 | 63 | class ERAngleViaIK(ArmActionMode): 64 | """High-level action where target EE pose + Elbow angle is given in ER 65 | space (End-Effector and Elbow) and reached via IK. 66 | 67 | Given a target pose, IK via inverse Jacobian is performed. This requires 68 | the target pose to be close to the current pose, otherwise the action 69 | will fail. It is up to the user to constrain the action to 70 | meaningful values. 71 | 72 | The decision to apply collision checking is a crucial trade off! 73 | With collision checking enabled, you are guaranteed collision free paths, 74 | but this may not be applicable for task that do require some collision. 75 | E.g. using this mode on pushing object will mean that the generated 76 | path will actively avoid not pushing the object. 77 | """ 78 | 79 | def __init__(self, 80 | absolute_mode: bool = True, 81 | frame: RelativeFrame = RelativeFrame.WORLD, 82 | collision_checking: bool = False, 83 | orientation_state: EEOrientationState = None): 84 | self._absolute_mode = absolute_mode 85 | self._frame = frame 86 | self._collision_checking = collision_checking 87 | self._orientation_state = orientation_state 88 | self._action_shape = (8,) 89 | 90 | def setup_arm(self, arm, name='Panda', suffix=''): 91 | # @comment: params are hardcoded for now, but could be useful to extend for all robots 92 | arm._ik_elbow_target = Dummy('%s_elbow_target%s' % (name, suffix)) 93 | arm._ik_elbow = Dummy('%s_elbow%s' % (name, suffix)) 94 | arm._elbow_ik_group = sim.simGetIkGroupHandle('%s_e3_ik%s' % (name, suffix)) 95 | joint2 = sim.simGetObjectHandle('%s_joint2%s' % (name, suffix)) 96 | joint6 = sim.simGetObjectHandle('%s_joint6%s' % (name, suffix)) 97 | arm._joint2_obj = Object.get_object(joint2) 98 | arm._joint6_obj = Object.get_object(joint6) 99 | 100 | sim.simSetIkGroupProperties( 101 | arm._elbow_ik_group, 102 | sim.sim_ik_damped_least_squares_method, 103 | maxIterations=400, 104 | damping=1e-3 105 | ) 106 | 107 | ee_constraints = \ 108 | sim.sim_ik_x_constraint | sim.sim_ik_y_constraint | sim.sim_ik_z_constraint | \ 109 | sim.sim_ik_alpha_beta_constraint | sim.sim_ik_gamma_constraint 110 | if self._orientation_state == EEOrientationState.FREE: 111 | ee_constraints = sim.sim_ik_x_constraint | sim.sim_ik_y_constraint |sim.sim_ik_z_constraint 112 | sim.simSetIkElementProperties( 113 | arm._elbow_ik_group, 114 | arm._ik_tip.get_handle(), 115 | ee_constraints, 116 | precision=[5e-4, 5/180*np.pi], 117 | weight=[1, 1] 118 | ) 119 | 120 | elbow_constraints = sim.sim_ik_alpha_beta_constraint | sim.sim_ik_gamma_constraint 121 | sim.simSetIkElementProperties( 122 | arm._elbow_ik_group, 123 | arm._ik_elbow.get_handle(), 124 | elbow_constraints, 125 | precision=[5e-4, 3/180*np.pi], 126 | weight=[0, 1] 127 | ) 128 | 129 | def solve_ik_via_jacobian(self, 130 | arm, 131 | position: Union[List[float], np.ndarray], 132 | euler: Union[List[float], np.ndarray] = None, 133 | quaternion: Union[List[float], np.ndarray] = None, 134 | angle: float = None, 135 | relative_to: Object = None) -> List[float]: 136 | """Solves an IK 
group and returns the calculated joint values. 137 | 138 | This IK method performs a linearisation around the current robot 139 | configuration via the Jacobian. The linearisation is valid when the 140 | start and goal pose are not too far away, but after a certain point, 141 | linearisation will no longer be valid. In that case, the user is better 142 | off using 'solve_ik_via_sampling'. 143 | 144 | Must specify either rotation in euler or quaternions, but not both! 145 | 146 | :param arm: The CoppelliaSim arm to move. 147 | :param position: The x, y, z position of the ee target. 148 | :param euler: The x, y, z orientation of the ee target (in radians). 149 | :param quaternion: A list containing the quaternion (x,y,z,w). 150 | :param angle: The target angle for the elbow to rotate around the 151 | vector normal to the circle of possible elbow positions (in radians). 152 | :param relative_to: Indicates relative to which reference frame we want 153 | the target pose. Specify None to retrieve the absolute pose, 154 | or an Object relative to whose reference frame we want the pose. 155 | :return: A list containing the calculated joint values. 156 | """ 157 | assert len(position) == 3 158 | arm._ik_target.set_position(position, relative_to) 159 | 160 | if euler is not None: 161 | arm._ik_target.set_orientation(euler, relative_to) 162 | elif quaternion is not None: 163 | arm._ik_target.set_quaternion(quaternion, relative_to) 164 | 165 | target_elbow_quat = arm._ik_elbow.get_quaternion() 166 | 167 | if angle is not None: 168 | w = arm._joint6_obj.get_position() - arm._joint2_obj.get_position() 169 | 170 | w_norm = w / np.linalg.norm(w) 171 | q = np.concatenate((np.sin(angle)*w_norm, np.array([np.cos(angle)]))) 172 | r = Rotation.from_quat(q) 173 | 174 | elbow_rot = Rotation.from_quat(arm._ik_elbow.get_quaternion()) 175 | target_elbow_quat = (r * elbow_rot).as_quat() 176 | 177 | arm._ik_elbow_target.set_quaternion(target_elbow_quat, relative_to) 178 | 179 | ik_result, joint_values = sim.simCheckIkGroup( 180 | arm._elbow_ik_group, [j.get_handle() for j in arm.joints]) 181 | if ik_result == sim.sim_ikresult_fail: 182 | raise IKError('IK failed. Perhaps the distance was between the tip ' 183 | ' and target was too large.') 184 | elif ik_result == sim.sim_ikresult_not_performed: 185 | raise IKError('IK not performed.') 186 | return joint_values 187 | 188 | def action(self, scene: Scene, action: np.ndarray): 189 | """Performs action using IK. 190 | 191 | :param scene: CoppeliaSim scene. 192 | :param action: Must be in the form [ee_pose, elbow_angle] with a len of 8. 
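        Concretely, the layout is [x, y, z, qx, qy, qz, qw, elbow_angle]: the quaternion must have unit norm, and elbow_angle (in radians) rotates the elbow about the axis running from joint 2 to joint 6.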
193 | """ 194 | arm = scene.robot.arm 195 | if not hasattr(arm, '_ik_elbow_target'): 196 | self.setup_arm(arm) 197 | assert_action_shape(action, self._action_shape) 198 | assert_unit_quaternion(action[3:-1]) 199 | angle = action[-1] 200 | ee_action = action[:-1] 201 | if not self._absolute_mode and self._frame != RelativeFrame.EE: 202 | ee_action = calculate_delta_pose(scene.robot, ee_action) 203 | relative_to = None if self._frame == RelativeFrame.WORLD else arm.get_tip() 204 | 205 | assert relative_to is None # NOT IMPLEMENTED 206 | 207 | if self._orientation_state == EEOrientationState.FIXED: 208 | ee_action[3:] = np.array([0, 1, 0, 0]) 209 | if self._orientation_state == EEOrientationState.KEEP: 210 | ee_action[3:] = arm._ik_tip.get_quaternion() 211 | 212 | try: 213 | joint_positions = self.solve_ik_via_jacobian( 214 | arm, 215 | position=ee_action[:3], quaternion=ee_action[3:], angle=angle, 216 | relative_to=relative_to 217 | ) 218 | arm.set_joint_target_positions(joint_positions) 219 | except IKError as e: 220 | raise InvalidActionError( 221 | 'Could not perform IK via Jacobian; most likely due to current ' 222 | 'end-effector pose being too far from the given target pose. ' 223 | 'Try limiting/bounding your action space.') from e 224 | 225 | done = False 226 | prev_values = None 227 | max_steps = 10 228 | steps = 0 229 | 230 | while not done and steps < max_steps: 231 | scene.step() 232 | cur_positions = arm.get_joint_positions() 233 | reached = np.allclose(cur_positions, joint_positions, atol=0.01) 234 | not_moving = False 235 | if prev_values is not None: 236 | not_moving = np.allclose( 237 | cur_positions, prev_values, atol=0.001) 238 | prev_values = cur_positions 239 | done = reached or not_moving 240 | steps += 1 241 | 242 | def action_shape(self, _: Scene) -> tuple: 243 | return self._action_shape 244 | 245 | class ERJointViaIK(ArmActionMode): 246 | """High-level action where target EE pose + Elbow angle is given in ER 247 | space (End-Effector and Elbow) and reached via IK. 248 | 249 | Given a target pose, IK via inverse Jacobian is performed. This requires 250 | the target pose to be close to the current pose, otherwise the action 251 | will fail. It is up to the user to constrain the action to 252 | meaningful values. 253 | 254 | The decision to apply collision checking is a crucial trade off! 255 | With collision checking enabled, you are guaranteed collision free paths, 256 | but this may not be applicable for task that do require some collision. 257 | E.g. using this mode on pushing object will mean that the generated 258 | path will actively avoid not pushing the object. 
259 | """ 260 | 261 | def __init__(self, 262 | absolute_mode: bool = True, 263 | frame: RelativeFrame = RelativeFrame.WORLD, 264 | collision_checking: bool = False, 265 | orientation_state: EEOrientationState = None, 266 | commanded_joint : int = 0, 267 | eps : float = 1e-3, 268 | delta_angle : bool = False): 269 | self._absolute_mode = absolute_mode 270 | self._frame = frame 271 | self._collision_checking = collision_checking 272 | self._orientation_state = orientation_state 273 | self._action_shape = (8,) 274 | self._excl_j_idx = commanded_joint 275 | self.EPS = eps 276 | self.delta_angle = delta_angle 277 | 278 | def solve_ik_via_jacobian(self, 279 | arm, 280 | position: Union[List[float], np.ndarray], 281 | euler: Union[List[float], np.ndarray] = None, 282 | quaternion: Union[List[float], np.ndarray] = None, 283 | relative_to: Object = None) -> List[float]: 284 | """Solves an IK group and returns the calculated joint values. 285 | 286 | This IK method performs a linearisation around the current robot 287 | configuration via the Jacobian. The linearisation is valid when the 288 | start and goal pose are not too far away, but after a certain point, 289 | linearisation will no longer be valid. In that case, the user is better 290 | off using 'solve_ik_via_sampling'. 291 | 292 | Must specify either rotation in euler or quaternions, but not both! 293 | 294 | :param arm: The CoppelliaSim arm to move. 295 | :param position: The x, y, z position of the ee target. 296 | :param euler: The x, y, z orientation of the ee target (in radians). 297 | :param quaternion: A list containing the quaternion (x,y,z,w). 298 | :param angle: The target angle for the elbow to rotate around the 299 | vector normal to the circle of possible elbow positions (in radians). 300 | :param relative_to: Indicates relative to which reference frame we want 301 | the target pose. Specify None to retrieve the absolute pose, 302 | or an Object relative to whose reference frame we want the pose. 303 | :return: A list containing the calculated joint values. 304 | """ 305 | assert len(position) == 3 306 | arm._ik_target.set_position(position, relative_to) 307 | 308 | if euler is not None: 309 | arm._ik_target.set_orientation(euler, relative_to) 310 | elif quaternion is not None: 311 | arm._ik_target.set_quaternion(quaternion, relative_to) 312 | 313 | # Removing the joint controlling the elbow position from IK chain 314 | joint_excl_elbow = arm.joints[:self._excl_j_idx] + arm.joints[self._excl_j_idx+1:] 315 | 316 | ik_result, joint_values = sim.simCheckIkGroup( 317 | arm._ik_group, [j.get_handle() for j in joint_excl_elbow]) 318 | if ik_result == sim.sim_ikresult_fail: 319 | raise IKError('IK failed. Perhaps the distance was between the tip ' 320 | ' and target was too large.') 321 | elif ik_result == sim.sim_ikresult_not_performed: 322 | raise IKError('IK not performed.') 323 | return joint_values 324 | 325 | def action(self, scene: Scene, action: np.ndarray): 326 | """Performs action using IK. 327 | 328 | :param scene: CoppeliaSim scene. 329 | :param action: Must be in the form [ee_pose, elbow_angle] with a len of 8. 
330 | """ 331 | arm = scene.robot.arm 332 | assert_action_shape(action, self._action_shape) 333 | assert_unit_quaternion(action[3:-1]) 334 | angle = action[-1] 335 | ee_action = action[:-1] 336 | if self.delta_angle: 337 | assert not self._absolute_mode, 'Cannot use delta_angle_mode' 338 | 339 | if self._excl_j_idx == 0: 340 | c = arm.joints[0].get_position()[:2] 341 | p = arm.get_tip().get_position()[:2] 342 | a = angle 343 | 344 | new_p = [c[0] + (p[0] - c[0]) * np.cos(a) - (p[1] - c[1]) * np.sin(a), 345 | c[1] + (p[0] - c[0]) * np.sin(a) + (p[1] - c[1]) * np.cos(a)] 346 | 347 | angle_delta_ee = new_p - p 348 | ee_action[:2] += angle_delta_ee 349 | 350 | rot = Rotation.from_quat(ee_action[-4:]).as_euler('xyz') 351 | rot[-1] += angle 352 | ee_action[-4:] = Rotation.from_euler('xyz', rot).as_quat() 353 | elif self._excl_j_idx == 6: 354 | # Get axis for z axis wrt the gripper 355 | v = np.array([0.,0.,1.]) 356 | axis = Rotation.from_quat(arm.get_tip().get_quaternion()).apply(v) # Rotation.from_quat(arm.get_tip().get_quaternion()).as_euler('xyz') 357 | # Get rotation given by joint 358 | quat_new = Quaternion(axis=axis, angle=angle) 359 | w_new, x_new, y_new, z_new = quat_new 360 | quat_new = Rotation.from_quat([x_new, y_new, z_new, w_new]) 361 | # Add rotation to original rotation 362 | rot = Rotation.from_quat(ee_action[-4:]) 363 | rot_new = rot * quat_new 364 | ee_action[-4:] = rot_new.as_quat() 365 | else: 366 | raise NotImplementedError(f'Cannot use delta_angle_mode with joint {self._excl_j_idx}') 367 | 368 | if not self._absolute_mode and self._frame != RelativeFrame.EE: 369 | ee_action = calculate_delta_pose(scene.robot, ee_action) 370 | relative_to = None if self._frame == RelativeFrame.WORLD else arm.get_tip() 371 | 372 | if self._orientation_state == EEOrientationState.FIXED: 373 | ee_action[3:] = np.array([0, 1, 0, 0]) 374 | if self._orientation_state == EEOrientationState.KEEP: 375 | ee_action[3:] = arm._ik_tip.get_quaternion() 376 | 377 | try: 378 | # Constrain joint to final position 379 | prev_joint_pos = arm.get_joint_positions()[self._excl_j_idx] 380 | new_joint_pos = angle if self._absolute_mode else prev_joint_pos + angle 381 | 382 | eps = self.EPS 383 | orig_cyclic, orig_interval = arm.joints[self._excl_j_idx].get_joint_interval() 384 | # Making the joint angle valid 385 | if new_joint_pos - eps < orig_interval[0]: 386 | new_joint_pos = orig_interval[0] + eps 387 | if new_joint_pos + eps > orig_interval[0] + orig_interval[1]: 388 | new_joint_pos = orig_interval[0] + orig_interval[1] - eps 389 | # Set target joint interval 390 | arm.joints[self._excl_j_idx].set_joint_interval(orig_cyclic, [new_joint_pos-eps, 2 * eps]) 391 | 392 | joint_positions = self.solve_ik_via_jacobian( 393 | arm, 394 | position=ee_action[:3], quaternion=ee_action[3:], 395 | relative_to=relative_to 396 | ) 397 | # Restore joint constraints 398 | arm.joints[self._excl_j_idx].set_joint_interval(orig_cyclic, orig_interval) 399 | joint_positions = joint_positions[:self._excl_j_idx] + [new_joint_pos] + joint_positions[self._excl_j_idx:] 400 | arm.set_joint_target_positions(joint_positions) 401 | except IKError as e: 402 | # Restoring joint constraints if there was an error (also restoring to prev_pos first to avoid internal accumulating error) 403 | arm.joints[self._excl_j_idx].set_joint_interval(orig_cyclic, [prev_joint_pos, 2 * eps]) 404 | arm.joints[self._excl_j_idx].set_joint_interval(orig_cyclic, orig_interval) 405 | raise InvalidActionError( 406 | 'Could not perform IK via Jacobian; most likely due 
to current ' 407 | 'end-effector pose being too far from the given target pose. ' 408 | 'Try limiting/bounding your action space.') from e 409 | 410 | done = False 411 | prev_values = None 412 | max_steps = 50 413 | steps = 0 414 | 415 | # Move until reached target joint positions or until we stop moving 416 | # (e.g. when we collide wth something) 417 | 418 | while not done and steps < max_steps: 419 | scene.step() 420 | cur_positions = arm.get_joint_positions() 421 | reached = np.allclose(cur_positions, joint_positions, atol=1e-3) 422 | not_moving = False 423 | if prev_values is not None: 424 | not_moving = np.allclose( 425 | cur_positions, prev_values, atol=1e-3) 426 | prev_values = cur_positions 427 | done = reached or not_moving 428 | steps += 1 429 | 430 | def action_shape(self, _: Scene) -> tuple: 431 | return self._action_shape 432 | 433 | class TimeoutEndEffectorPoseViaIK(EndEffectorPoseViaIK): 434 | """ 435 | The exact same EE action mode of RLBench, but with a timeout (max steps), to prevent infinite loops 436 | """ 437 | 438 | def action(self, scene: Scene, action: np.ndarray): 439 | assert_action_shape(action, (7,)) 440 | assert_unit_quaternion(action[3:]) 441 | if not self._absolute_mode and self._frame != 'end effector': 442 | action = calculate_delta_pose(scene.robot, action) 443 | relative_to = None if self._frame == 'world' else scene.robot.arm.get_tip() 444 | 445 | try: 446 | joint_positions = scene.robot.arm.solve_ik_via_jacobian( 447 | action[:3], quaternion=action[3:], relative_to=relative_to) 448 | scene.robot.arm.set_joint_target_positions(joint_positions) 449 | except IKError as e: 450 | raise InvalidActionError( 451 | 'Could not perform IK via Jacobian; most likely due to current ' 452 | 'end-effector pose being too far from the given target pose. ' 453 | 'Try limiting/bounding your action space.') from e 454 | done = False 455 | prev_values = None 456 | steps = 0 457 | max_steps = 50 458 | # Move until reached target joint positions or until we stop moving 459 | # (e.g. 
when we collide wth something) 460 | while not done and steps < max_steps: 461 | scene.step() 462 | cur_positions = scene.robot.arm.get_joint_positions() 463 | reached = np.allclose(cur_positions, joint_positions, atol=0.01) 464 | not_moving = False 465 | if prev_values is not None: 466 | not_moving = np.allclose( 467 | cur_positions, prev_values, atol=0.001) 468 | prev_values = cur_positions 469 | done = reached or not_moving 470 | steps += 1 -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/elbow_angle_task_design.ttt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/elbow_angle_task_design.ttt -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/task_ttms/barista.ttm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/task_ttms/barista.ttm -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/task_ttms/bottle_out_moving_fridge.ttm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/task_ttms/bottle_out_moving_fridge.ttm -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/task_ttms/cup_out_open_cabinet.ttm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/task_ttms/cup_out_open_cabinet.ttm -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/task_ttms/reach_gripper_and_elbow.ttm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/task_ttms/reach_gripper_and_elbow.ttm -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/task_ttms/slide_cup.ttm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/task_ttms/slide_cup.ttm -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .phone_on_base import PhoneOnBase 2 | from .pick_up_cup import PickUpCup 3 | from .put_rubbish_in_bin import PutRubbishInBin 4 | from .stack_wine import StackWine 5 | from .take_umbrella_out_of_umbrella_stand import TakeUmbrellaOutOfUmbrellaStand 6 | from .take_lid_off_saucepan import TakeLidOffSaucepan 7 | from .reach_target import ReachTarget 8 | from .pick_and_lift import PickAndLift 9 | from .meat_off_grill import MeatOffGrill 10 | from .bottle_out_moving_fridge import BottleOutMovingFridge 11 | from .barista import Barista 12 | from 
.reach_gripper_and_elbow import ReachGripperAndElbow 13 | from .slide_cup import SlideCup 14 | from .cup_out_open_cabinet import CupOutOpenCabinet 15 | 16 | 17 | CUSTOM_TASKS = \ 18 | { 19 | # Standard 8 tasks 20 | 'reach_target' : ReachTarget, 21 | 'pick_up_cup' : PickUpCup, 22 | 'take_umbrella_out_of_umbrella_stand' : TakeUmbrellaOutOfUmbrellaStand, 23 | 'take_lid_off_saucepan' : TakeLidOffSaucepan, 24 | 'pick_and_lift': PickAndLift, 25 | 'phone_on_base' : PhoneOnBase, 26 | 'stack_wine' : StackWine, 27 | 'put_rubbish_in_bin' : PutRubbishInBin, 28 | # Abbreviations 29 | 'umbrella_out' : TakeUmbrellaOutOfUmbrellaStand, 30 | 'saucepan' : TakeLidOffSaucepan, 31 | # New tasks 32 | 'reach_elbow_pose' : ReachGripperAndElbow, 33 | 'take_bottle_out_fridge': BottleOutMovingFridge, 34 | 'serve_coffee_obstacles' : Barista, 35 | 'slide_cup_obstacles' : SlideCup, 36 | 'meat_off_grill' : MeatOffGrill, 37 | 'take_cup_out_cabinet' : CupOutOpenCabinet, 38 | 39 | } -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/barista.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import numpy as np 3 | from pyrep.objects.dummy import Dummy 4 | from pyrep.objects.shape import Shape 5 | from pyrep.objects.object import Object 6 | from pyrep.objects.proximity_sensor import ProximitySensor 7 | from pyrep.const import PrimitiveShape 8 | from rlbench.const import colors 9 | from rlbench.backend.task import Task 10 | from rlbench.backend.conditions import DetectedCondition, ConditionSet, GraspedCondition, Condition 11 | from rlbench.backend.spawn_boundary import SpawnBoundary 12 | from pyquaternion import Quaternion 13 | from pyrep.const import ObjectType 14 | 15 | import os 16 | from os.path import dirname, abspath, join 17 | 18 | def get_z_distance(obj, target=None): 19 | if target is None: 20 | target_dir = - np.pi/2 21 | else: 22 | target_dir = target.get_orientation()[2] 23 | dist = np.abs(obj.get_orientation()[2] - target_dir) 24 | while dist > 2 * np.pi: 25 | dist -= 2 * np.pi 26 | return min(dist, 2*np.pi - dist) 27 | 28 | class UpCondition(Condition): 29 | def __init__(self, obj: Object, threshold = 0.24): # 6 degress 30 | """in radians if revoloute, or meters if prismatic""" 31 | self._obj = obj 32 | self._threshold = threshold 33 | 34 | def condition_met(self): 35 | dist = get_z_distance(self._obj) 36 | met = dist <= self._threshold 37 | return met, False 38 | 39 | class Barista(Task): 40 | 41 | def init_task(self) -> None: 42 | self.cup_source = Shape('cup_source') 43 | self.plate_target = Shape('plate') 44 | self.plant = Shape('plant') 45 | self.waypoint = Dummy('waypoint1') 46 | self.bottles = [Shape('bottle'), Shape('bottle0'), Shape('bottle1')] 47 | self.collidables = [s for s in self.pyrep.get_objects_in_tree( object_type=ObjectType.SHAPE) if ('bottle' in Object.get_object_name(s._handle) or 'plant' in Object.get_object_name(s._handle)) and s.is_respondable()] 48 | self.check_collisions = True 49 | 50 | self.success_detector = ProximitySensor('success') 51 | 52 | self.grasped_cond = GraspedCondition(self.robot.gripper, self.cup_source) 53 | self.drops_detector = ProximitySensor('detector') 54 | 55 | self.orientation_cond = UpCondition(self.cup_source, threshold=0.5) # ~30 degrees 56 | self.cup_condition = DetectedCondition(self.cup_source, self.success_detector) 57 | self.register_success_conditions([self.orientation_cond, self.cup_condition]) 58 | 59 | 
self.register_graspable_objects([self.cup_source]) 60 | 61 | def init_episode(self, index: int) -> List[str]: 62 | self.init_orientation = self.cup_source.get_orientation() 63 | 64 | return ['coffee mug on plate'] 65 | 66 | def variation_count(self) -> int: 67 | return 1 68 | 69 | def is_static_workspace(self) -> bool: 70 | return True 71 | 72 | def reward(self,): 73 | grasped = self.grasped_cond.condition_met()[0] 74 | cup_placed = self.cup_condition.condition_met()[0] 75 | well_oriented = self.orientation_cond.condition_met()[0] 76 | 77 | grasp_reward = orientation_reward = move_reward = 0.0 78 | 79 | if self.check_collisions: 80 | if np.any([self.cup_source.check_collision(c) for c in self.collidables]) or \ 81 | np.any([self.robot.arm.check_collision(c) for c in self.collidables]) or \ 82 | np.any([self.robot.gripper.check_collision(c) for c in self.collidables]): 83 | self.terminate_episode = True 84 | else: 85 | self.terminate_episode = False 86 | 87 | if grasped: 88 | grasp_reward = 1.0 89 | 90 | if well_oriented: 91 | orientation_reward = 1.0 92 | else: 93 | # Orientation around vertical axis is locked (very high moment of inertia) -> just other two rotations 94 | orientation_reward = np.exp(-get_z_distance(self.cup_source)) 95 | 96 | if cup_placed: 97 | move_reward = 2.0 98 | else: 99 | cup_forward_and_high = (self.cup_source.get_position()[0] >= (self.waypoint.get_position()[0] - 0.01)) and (self.cup_source.get_position()[2] >= (self.waypoint.get_position()[2] - 0.01)) 100 | 101 | if cup_forward_and_high: 102 | move_reward = 1.0 + np.exp(-np.linalg.norm(self.cup_source.get_position() - self.success_detector.get_position())) 103 | else: 104 | move_reward = np.exp(-np.linalg.norm(self.cup_source.get_position() - self.waypoint.get_position())) 105 | else: 106 | grasp_reward = np.exp(-np.linalg.norm(self.cup_source.get_position() - self.robot.arm.get_tip().get_position())) 107 | 108 | reward = orientation_reward + grasp_reward + move_reward 109 | 110 | return reward 111 | 112 | def load(self): 113 | ttm_file = join( 114 | dirname(abspath(__file__)), 115 | '../task_ttms/%s.ttm' % self.name) 116 | if not os.path.isfile(ttm_file): 117 | raise FileNotFoundError( 118 | 'The following is not a valid task .ttm file: %s' % ttm_file) 119 | self._base_object = self.pyrep.import_model(ttm_file) 120 | return self._base_object 121 | -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/bottle_out_moving_fridge.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import numpy as np 3 | from pyrep.objects.object import Object 4 | from pyrep.objects.proximity_sensor import ProximitySensor 5 | from pyrep.objects.shape import Shape 6 | from rlbench.backend.conditions import DetectedCondition, Condition, GraspedCondition, JointCondition 7 | from rlbench.backend.task import Task 8 | from pyrep.objects.joint import Joint 9 | from pyrep.const import ObjectType 10 | 11 | import os 12 | from os.path import dirname, abspath, join 13 | 14 | class SteadyCondition(Condition): 15 | def __init__(self, obj, correct_position, threshold = 0.025): 16 | self._obj = obj 17 | self._correct_position = correct_position 18 | self._threshold = threshold 19 | 20 | def condition_met(self): 21 | met = np.linalg.norm( 22 | self._obj.get_position() - self._correct_position) <= self._threshold 23 | return met, False 24 | 25 | class NoCollisions: 26 | def __init__(self, pyrep): 27 | 
self.colliding_shapes = [s for s in pyrep.get_objects_in_tree( 28 | object_type=ObjectType.SHAPE) if s.is_collidable()] 29 | 30 | def __enter__(self): 31 | for s in self.colliding_shapes: 32 | s.set_collidable(False) 33 | 34 | def __exit__(self, *args): 35 | for s in self.colliding_shapes: 36 | s.set_collidable(True) # restore collidability when leaving the context 37 | 38 | class ChangingPointCondition(Condition): 39 | def __init__(self, val=False) -> None: 40 | self.val = val 41 | 42 | def set(self, value): 43 | self.val = value 44 | 45 | def condition_met(self): 46 | return self.val, False 47 | 48 | FRIDGE_OPEN_JOINT_ANGLE = 45 / 180 * np.pi 49 | FRIDGE_INIT_JOINT_ANGLE = 30 / 180 * np.pi # not setting this, but it's around this value 50 | 51 | class BottleOutMovingFridge(Task): 52 | 53 | def init_task(self) -> None: 54 | self.bottle = Shape('bottle') 55 | self._success_sensor = ProximitySensor('success') 56 | self._fridge_door = Joint("top_joint") 57 | self._fridge_target_velocity = self._fridge_door.get_joint_target_velocity() 58 | 59 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.bottle) 60 | self._detected_cond = DetectedCondition(self.bottle, self._success_sensor) 61 | self.register_graspable_objects([self.bottle]) 62 | self.grasp_init = False 63 | 64 | def init_episode(self, index: int) -> List[str]: 65 | with NoCollisions(self.pyrep): 66 | fridge_joint_pos = np.pi / 2 67 | self._fridge_door.set_joint_position(fridge_joint_pos, disable_dynamics=False) 68 | arm_joint_pos = np.array([+1.983e+01, +2.183e+01, -1.814e+01, -9.465e+01, +9.332e+01, +8.073e+01, -6.668e+01]) / 180 * np.pi 69 | self.robot.arm.set_joint_positions(arm_joint_pos, disable_dynamics=False) 70 | 71 | arm_joint_pos = np.array([+4.067e+01, +2.979e+01, -2.909e+01, -8.598e+01, +1.045e+02, +9.017e+01, -6.300e+01]) / 180 * np.pi 72 | self.robot.arm.set_joint_positions(arm_joint_pos, disable_dynamics=True) 73 | 74 | self._fridge_door.set_joint_target_position(0.) 
75 | self._fridge_door.set_joint_target_velocity(self._fridge_target_velocity) 76 | 77 | if self.grasp_init: 78 | # Grasp object 79 | self.robot.gripper.actuate(0, 0.2) # Close gripper 80 | self.robot.gripper.grasp(self.bottle) 81 | assert len(self.robot.gripper.get_grasped_objects()) > 0, "Object not grasped" 82 | 83 | self._steady_cond = SteadyCondition(self.bottle, np.copy(self.bottle.get_position())) # stay within 2.5cm 84 | self._changing_point_cond = ChangingPointCondition() 85 | 86 | self.register_success_conditions([self._grasped_cond, self._detected_cond, self._changing_point_cond]) 87 | 88 | return ['put bottle in fridge', 89 | 'place the bottle inside the fridge', 90 | 'open the fridge and put the bottle in there', 91 | 'open the fridge door, pick up the bottle, and leave it in the ' 92 | 'fridge'] 93 | 94 | def variation_count(self) -> int: 95 | return 1 96 | 97 | def boundary_root(self) -> Object: 98 | return Shape('fridge_root') 99 | 100 | def base_rotation_bounds(self) -> Tuple[List[float], List[float]]: 101 | return [0.0, 0.0, 0.0], [0.0, 0.0, 0.0] 102 | 103 | def is_static_workspace(self) -> bool: 104 | return True 105 | 106 | def reward(self) -> float: 107 | grasped = self._grasped_cond.condition_met()[0] 108 | detected = self._detected_cond.condition_met()[0] 109 | bottle_steady = self._steady_cond.condition_met()[0] 110 | past_door_point = self._changing_point_cond.condition_met()[0] 111 | fridge_open = self._fridge_door.get_joint_position() >= FRIDGE_OPEN_JOINT_ANGLE 112 | 113 | grasp_bottle_reward = fridge_open_reward = reach_target_reward = 0.0 114 | 115 | if grasped: 116 | if past_door_point: 117 | grasp_bottle_reward = 1.0 118 | fridge_open_reward = 1.0 119 | 120 | if detected: 121 | reach_target_reward = 1.0 122 | else: 123 | reach_target_reward = np.exp( 124 | -np.linalg.norm( 125 | self.bottle.get_position() 126 | - self._success_sensor.get_position() 127 | ) 128 | ) 129 | else: 130 | if bottle_steady: 131 | grasp_bottle_reward = 1.0 132 | if fridge_open: 133 | self._changing_point_cond.set(True) # = grasped + bottle steady + fridge open 134 | fridge_open_reward = 1.0 135 | # Blocking the fridge helps mantaining Markovianity of the state (the agent can see the fridge is locked in position) 136 | self._fridge_door.set_joint_position(FRIDGE_OPEN_JOINT_ANGLE) 137 | self._fridge_door.set_joint_target_position(FRIDGE_OPEN_JOINT_ANGLE) 138 | self._fridge_door.set_joint_target_velocity(0.) 
139 | else: 140 | fridge_open_dist = np.clip(FRIDGE_OPEN_JOINT_ANGLE - self._fridge_door.get_joint_position(), 0, np.inf) 141 | fridge_open_reward = np.exp(-fridge_open_dist) 142 | 143 | reward = grasp_bottle_reward + fridge_open_reward + reach_target_reward 144 | 145 | return reward 146 | 147 | def load(self) -> Object: 148 | ttm_file = join( 149 | dirname(abspath(__file__)), 150 | '../task_ttms/%s.ttm' % self.name) 151 | if not os.path.isfile(ttm_file): 152 | raise FileNotFoundError( 153 | 'The following is not a valid task .ttm file: %s' % ttm_file) 154 | self._base_object = self.pyrep.import_model(ttm_file) 155 | return self._base_object 156 | -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/cup_out_open_cabinet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | from pyrep.objects.dummy import Dummy 3 | from pyrep.objects.joint import Joint 4 | from pyrep.objects.shape import Shape 5 | from pyrep.objects.proximity_sensor import ProximitySensor 6 | from rlbench.backend.task import Task 7 | from rlbench.backend.conditions import DetectedCondition, NothingGrasped, GraspedCondition 8 | from .reach_gripper_and_elbow import DistanceCondition 9 | 10 | import numpy as np 11 | 12 | import os 13 | from os.path import dirname, abspath, join 14 | 15 | OPTIONS = ['left', 'right'] 16 | 17 | class CupOutOpenCabinet(Task): 18 | 19 | def init_task(self) -> None: 20 | self.cup = Shape('cup') 21 | self.left_placeholder = Dummy('left_cup_placeholder') 22 | self.waypoint1 = Dummy('waypoint1') 23 | self.waypoint2 = Dummy('waypoint2') 24 | 25 | self.target = self.waypoint3 = Dummy('waypoint3') 26 | 27 | self.left_way_placeholder1 = Dummy('left_way_placeholder1') 28 | self.left_way_placeholder2 = Dummy('left_way_placeholder2') 29 | 30 | self.grasped_cond = GraspedCondition(self.robot.gripper, self.cup) 31 | 32 | self.cup_target_cond = DistanceCondition(self.cup, self.target, 0.05) 33 | self.register_graspable_objects([self.cup]) 34 | 35 | 36 | self.register_success_conditions( 37 | [self.grasped_cond, self.cup_target_cond,]) 38 | 39 | def init_episode(self, index: int) -> List[str]: 40 | option = 'right' # OPTIONS[index] 41 | 42 | self.joint_target = Joint(f'{option}_joint') 43 | self.handle_wp = self.waypoint1 44 | 45 | 46 | return ['take out a cup from the %s half of the cabinet' % option, 47 | 'open the %s side of the cabinet and get the cup' 48 | % option, 49 | 'grasping the %s handle, open the cabinet, then retrieve the ' 50 | 'cup' % option, 51 | 'slide open the %s door on the cabinet and put take the cup out' 52 | % option, 53 | 'remove the cup from the %s part of the cabinet' % option] 54 | 55 | def reward(self,) -> float: 56 | cup_grasped = self.grasped_cond.condition_met()[0] 57 | cup_is_out = self.cup_target_cond.condition_met()[0] 58 | 59 | reach_cup_reward = cup_out_reward = 0 60 | 61 | if cup_grasped: 62 | reach_cup_reward = 1.0 63 | 64 | if cup_is_out: 65 | cup_out_reward = 1.0 66 | else: 67 | cup_out_reward = np.exp(-np.linalg.norm(self.cup.get_position() - self.target.get_position())) 68 | else: 69 | reach_cup_reward = np.exp(-np.linalg.norm(self.cup.get_position() - self.robot.arm.get_tip().get_position())) 70 | 71 | reward = reach_cup_reward + cup_out_reward 72 | 73 | return reward 74 | 75 | 76 | def variation_count(self) -> int: 77 | return 2 78 | 79 | def base_rotation_bounds(self) -> Tuple[List[float], List[float]]: 80 | return [0.0, 0.0, -3.14/2], 
[0.0, 0.0, 3.14/2] 81 | 82 | 83 | def load(self): 84 | ttm_file = join( 85 | dirname(abspath(__file__)), 86 | '../task_ttms/%s.ttm' % self.name) 87 | if not os.path.isfile(ttm_file): 88 | raise FileNotFoundError( 89 | 'The following is not a valid task .ttm file: %s' % ttm_file) 90 | self._base_object = self.pyrep.import_model(ttm_file) 91 | return self._base_object -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/meat_off_grill.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pyrep.objects.dummy import Dummy 3 | from pyrep.objects.proximity_sensor import ProximitySensor 4 | from pyrep.objects.shape import Shape 5 | from rlbench.backend.conditions import NothingGrasped, DetectedCondition, GraspedCondition 6 | from rlbench.backend.task import Task 7 | import numpy as np 8 | 9 | MEAT = ['chicken', 'steak'] 10 | 11 | 12 | class MeatOffGrill(Task): 13 | 14 | def init_task(self) -> None: 15 | self._steak = Shape('steak') 16 | self._chicken = Shape('chicken') 17 | self._success_sensor = ProximitySensor('success') 18 | self.register_graspable_objects([self._chicken, self._steak]) 19 | self._w1 = Dummy('waypoint1') 20 | self._w1z = self._w1.get_position()[2] 21 | 22 | self._nothing_grasped_condition = NothingGrasped(self.robot.gripper) 23 | 24 | def init_episode(self, index: int) -> List[str]: 25 | if index == 0: 26 | self._target = self._chicken 27 | else: 28 | self._target = self._steak 29 | x, y, _ = self._target.get_position() 30 | self._w1.set_position([x, y, self._w1z]) 31 | self._detected_condition = DetectedCondition(self._target, self._success_sensor) 32 | self._grasped_condition = GraspedCondition(self.robot.gripper, self._target) 33 | 34 | conditions = [self._nothing_grasped_condition, self._detected_condition] 35 | self.register_success_conditions(conditions) 36 | return ['take the %s off the grill' % MEAT[index], 37 | 'pick up the %s and place it next to the grill' % MEAT[index], 38 | 'remove the %s from the grill and set it down to the side' 39 | % MEAT[index]] 40 | 41 | def reward(self,) -> float: 42 | nothing_grasped = self._nothing_grasped_condition.condition_met()[0] 43 | detected = self._detected_condition.condition_met()[0] 44 | grasped = self._grasped_condition.condition_met()[0] 45 | 46 | reach_reward = move_reward = release_reward = 0. 
47 | self.reward_open_gripper = False 48 | 49 | if grasped: 50 | reach_reward = 1.0 51 | 52 | if detected: 53 | move_reward = 1.0 54 | self.reward_open_gripper = True 55 | else: 56 | move_reward = np.exp( 57 | -np.linalg.norm(self._target.get_position() - self._success_sensor.get_position()) 58 | ) 59 | else: 60 | if detected: 61 | if nothing_grasped: 62 | reach_reward = move_reward = release_reward = 1.0 63 | else: 64 | reach_reward = move_reward = 1.0 65 | self.reward_open_gripper = True 66 | else: 67 | reach_reward = np.exp( 68 | -np.linalg.norm(self._target.get_position() - self.robot.arm.get_tip().get_position()) 69 | ) 70 | 71 | reward = reach_reward + move_reward + release_reward 72 | 73 | return reward 74 | 75 | def variation_count(self) -> int: 76 | return 2 77 | -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/phone_on_base.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | from pyrep.objects.proximity_sensor import ProximitySensor 5 | from pyrep.objects.shape import Shape 6 | 7 | from rlbench.backend.conditions import ( 8 | DetectedCondition, 9 | GraspedCondition, 10 | NothingGrasped, 11 | ) 12 | from rlbench.backend.task import Task 13 | 14 | 15 | class PhoneOnBase(Task): 16 | def init_task(self) -> None: 17 | self.phone = Shape("phone") 18 | self.success_detector = ProximitySensor("success") 19 | self.register_graspable_objects([self.phone]) 20 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.phone) 21 | self._nothing_grapsed_cond = NothingGrasped(self.robot.gripper) 22 | self._phone_cond = DetectedCondition(self.phone, self.success_detector) 23 | 24 | self.register_success_conditions( 25 | [self._phone_cond, NothingGrasped(self.robot.gripper)] 26 | ) 27 | 28 | def init_episode(self, index: int) -> List[str]: 29 | return [ 30 | "put the phone on the base", 31 | "put the phone on the stand", 32 | "put the hone on the hub", 33 | "grasp the phone and put it on the base", 34 | "place the phone on the base", 35 | "put the phone back on the base", 36 | ] 37 | 38 | def variation_count(self) -> int: 39 | return 1 40 | 41 | def reward(self) -> float: 42 | grasped = self._grasped_cond.condition_met()[0] 43 | phone_on_base = self._phone_cond.condition_met()[0] 44 | nothing_grasped = self._nothing_grapsed_cond.condition_met()[0] 45 | 46 | grasp_phone_reward = move_phone_reward = release_reward = 0 47 | 48 | self.reward_open_gripper = False 49 | 50 | if phone_on_base: 51 | if nothing_grasped: 52 | # phone is not grasped anymore 53 | grasp_phone_reward = move_phone_reward = release_reward = 1.0 54 | else: 55 | # phone is in base, but gripper still holds the phone (or something else) 56 | move_phone_reward = 1.0 57 | grasp_phone_reward = 1.0 58 | self.reward_open_gripper = True 59 | else: 60 | if not grasped: 61 | # reaching the phone 62 | grasp_phone_reward = np.exp( 63 | -np.linalg.norm( 64 | self.phone.get_position() 65 | - self.robot.arm.get_tip().get_position() 66 | ) 67 | ) 68 | else: 69 | grasp_phone_reward = 1.0 70 | # moving the phone toward base 71 | move_phone_reward = np.exp( 72 | -np.linalg.norm( 73 | self.phone.get_position() - self.success_detector.get_position() 74 | ) 75 | ) 76 | 77 | reward = grasp_phone_reward + move_phone_reward + release_reward 78 | 79 | return reward 80 | 81 | def get_low_dim_state(self) -> np.ndarray: 82 | # For ad-hoc reward computation, attach reward 83 | state = 
super().get_low_dim_state() 84 | return state 85 | -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/pick_and_lift.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | from pyrep.objects.shape import Shape 4 | from pyrep.objects.proximity_sensor import ProximitySensor 5 | from rlbench.backend.task import Task 6 | from rlbench.backend.conditions import DetectedCondition, ConditionSet, \ 7 | GraspedCondition 8 | from rlbench.backend.spawn_boundary import SpawnBoundary 9 | from rlbench.const import colors 10 | 11 | 12 | class PickAndLift(Task): 13 | 14 | def init_task(self) -> None: 15 | self.target_block = Shape('pick_and_lift_target') 16 | self.distractors = [ 17 | Shape('stack_blocks_distractor%d' % i) 18 | for i in range(2)] 19 | self.register_graspable_objects([self.target_block]) 20 | self.boundary = SpawnBoundary([Shape('pick_and_lift_boundary')]) 21 | self.success_detector = ProximitySensor('pick_and_lift_success') 22 | 23 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.target_block) 24 | self._detected_cond = DetectedCondition(self.target_block, self.success_detector) 25 | cond_set = ConditionSet([ 26 | self._grasped_cond, 27 | self._detected_cond 28 | ]) 29 | self.register_success_conditions([cond_set]) 30 | 31 | def init_episode(self, index: int) -> List[str]: 32 | 33 | block_color_name, block_rgb = colors[index] 34 | self.target_block.set_color(block_rgb) 35 | 36 | color_choices = np.random.choice( 37 | list(range(index)) + list(range(index + 1, len(colors))), 38 | size=2, replace=False) 39 | for i, ob in enumerate(self.distractors): 40 | name, rgb = colors[color_choices[int(i)]] 41 | ob.set_color(rgb) 42 | 43 | self.boundary.clear() 44 | self.boundary.sample( 45 | self.success_detector, min_rotation=(0.0, 0.0, 0.0), 46 | max_rotation=(0.0, 0.0, 0.0)) 47 | for block in [self.target_block] + self.distractors: 48 | self.boundary.sample(block, min_distance=0.1) 49 | 50 | return ['pick up the %s block and lift it up to the target' % 51 | block_color_name, 52 | 'grasp the %s block to the target' % block_color_name, 53 | 'lift the %s block up to the target' % block_color_name] 54 | 55 | def variation_count(self) -> int: 56 | return len(colors) 57 | 58 | def reward(self) -> float: 59 | grasped = self._grasped_cond.condition_met()[0] 60 | detected = self._detected_cond.condition_met()[0] 61 | 62 | pick_reward = lift_reward = 0. 63 | 64 | if not grasped: 65 | pick_reward = np.exp( 66 | -np.linalg.norm( 67 | self.target_block.get_position() - self.robot.arm.get_tip().get_position() 68 | ) 69 | ) 70 | else: 71 | pick_reward = 1.0 72 | 73 | if detected: 74 | lift_reward = 1.0 75 | else: 76 | lift_reward = np.exp( 77 | -np.linalg.norm( 78 | self.target_block.get_position() - self.success_detector.get_position() 79 | ) 80 | ) 81 | 82 | 83 | reward = pick_reward + lift_reward 84 | 85 | return reward 86 | 87 | def get_low_dim_state(self) -> np.ndarray: 88 | # One of the few tasks that have a custom low_dim_state function. 
89 | return np.concatenate([self.target_block.get_position(), self.success_detector.get_position()], 0) -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/pick_up_cup.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | from pyrep.objects.shape import Shape 4 | from pyrep.objects.proximity_sensor import ProximitySensor 5 | from rlbench.const import colors 6 | from rlbench.backend.task import Task 7 | from rlbench.backend.conditions import ( 8 | DetectedCondition, 9 | NothingGrasped, 10 | GraspedCondition, 11 | ) 12 | from rlbench.backend.spawn_boundary import SpawnBoundary 13 | 14 | 15 | class PickUpCup(Task): 16 | def init_task(self) -> None: 17 | self.cup1 = Shape("cup1") 18 | self.cup2 = Shape("cup2") 19 | self.cup1_visual = Shape("cup1_visual") 20 | self.cup2_visual = Shape("cup2_visual") 21 | self.boundary = SpawnBoundary([Shape("boundary")]) 22 | self.success_sensor = ProximitySensor("success") 23 | self.register_graspable_objects([self.cup1, self.cup2]) 24 | 25 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.cup1) 26 | self._not_detected_cond = DetectedCondition( 27 | self.cup1, self.success_sensor, negated=True 28 | ) 29 | 30 | self.register_success_conditions([self._not_detected_cond, self._grasped_cond]) 31 | 32 | def init_episode(self, index: int) -> List[str]: 33 | self.variation_index = index 34 | target_color_name, target_rgb = colors[index] 35 | 36 | random_idx = np.random.choice(len(colors)) 37 | while random_idx == index: 38 | random_idx = np.random.choice(len(colors)) 39 | _, other1_rgb = colors[random_idx] 40 | 41 | self.cup1_visual.set_color(target_rgb) 42 | self.cup2_visual.set_color(other1_rgb) 43 | 44 | self.boundary.clear() 45 | self.boundary.sample(self.cup2, min_distance=0.1) 46 | self.boundary.sample(self.success_sensor, min_distance=0.1) 47 | 48 | return [ 49 | "pick up the %s cup" % target_color_name, 50 | "grasp the %s cup and lift it" % target_color_name, 51 | "lift the %s cup" % target_color_name, 52 | ] 53 | 54 | def variation_count(self) -> int: 55 | return len(colors) 56 | 57 | def reward(self) -> float: 58 | grasped = self._grasped_cond.condition_met()[0] 59 | not_detected = self._not_detected_cond.condition_met()[0] 60 | 61 | pick_reward = up_reward = 0.0 62 | 63 | if not grasped: 64 | pick_reward = np.exp( 65 | -np.linalg.norm( 66 | self.cup1.get_position() - self.robot.arm.get_tip().get_position() 67 | ) 68 | ) 69 | else: 70 | pick_reward = 1.0 71 | 72 | if not_detected: 73 | up_reward = 1.0 74 | else: 75 | distance_from_orig_pose = np.linalg.norm( 76 | self.cup1.get_position() - self.success_sensor.get_position() 77 | ) 78 | up_reward = np.tanh(distance_from_orig_pose) 79 | 80 | reward = pick_reward + up_reward 81 | 82 | return reward 83 | 84 | def get_low_dim_state(self) -> np.ndarray: 85 | # For ad-hoc reward computation, attach reward 86 | state = super().get_low_dim_state() 87 | return state -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/put_rubbish_in_bin.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | import copy 4 | from pyrep.objects.proximity_sensor import ProximitySensor 5 | from pyrep.objects.shape import Shape 6 | from rlbench.backend.task import Task 7 | from rlbench.backend.conditions import DetectedCondition, 
GraspedCondition 8 | 9 | 10 | class PutRubbishInBin(Task): 11 | def init_task(self): 12 | self.success_sensor = ProximitySensor("success") 13 | self.rubbish = Shape("rubbish") 14 | self.register_graspable_objects([self.rubbish]) 15 | self.register_success_conditions( 16 | [DetectedCondition(self.rubbish, self.success_sensor)] 17 | ) 18 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.rubbish) 19 | self._detected_cond = DetectedCondition(self.rubbish, self.success_sensor) 20 | self.HIGH_Z_TARGET = 1.05 21 | self.LOW_Z_TARGET = 0.9 22 | 23 | def init_episode(self, index: int) -> List[str]: 24 | tomato1 = Shape("tomato1") 25 | tomato2 = Shape("tomato2") 26 | x1, y1, z1 = tomato2.get_position() 27 | x2, y2, z2 = self.rubbish.get_position() 28 | x3, y3, z3 = tomato1.get_position() 29 | pos = np.random.randint(3) 30 | if pos == 0: 31 | self.rubbish.set_position([x1, y1, z2]) 32 | tomato2.set_position([x2, y2, z1]) 33 | elif pos == 2: 34 | self.rubbish.set_position([x3, y3, z2]) 35 | tomato1.set_position([x2, y2, z3]) 36 | 37 | self.lifted = False 38 | # self.reward_open_gripper = False 39 | 40 | return [ 41 | "put rubbish in bin", 42 | "drop the rubbish into the bin", 43 | "pick up the rubbish and leave it in the trash can", 44 | "throw away the trash, leaving any other objects alone", 45 | "chuck way any rubbish on the table rubbish", 46 | ] 47 | 48 | def variation_count(self) -> int: 49 | return 1 50 | 51 | def reward(self) -> float: 52 | grasped = self._grasped_cond.condition_met()[0] 53 | detected = self._detected_cond.condition_met()[0] 54 | 55 | 56 | target1_pos = copy.deepcopy(self.success_sensor.get_position()) 57 | target1_pos[-1] = self.HIGH_Z_TARGET 58 | 59 | target2_pos = copy.deepcopy(self.success_sensor.get_position()) 60 | target2_pos[-1] = self.LOW_Z_TARGET 61 | 62 | grasp_rubbish_reward = move_rubbish_reward = release_reward = 0 63 | self.reward_open_gripper = False 64 | 65 | if not grasped: 66 | if detected: 67 | grasp_rubbish_reward = move_rubbish_reward = release_reward = 1.0 68 | else: 69 | grasp_rubbish_reward = np.exp( 70 | -np.linalg.norm( 71 | self.rubbish.get_position() 72 | - self.robot.arm.get_tip().get_position() 73 | ) 74 | ) 75 | else: 76 | grasp_rubbish_reward = 1.0 77 | 78 | rubbish_in_bin_area_dist = np.linalg.norm(self.rubbish.get_position()[:2] - target2_pos[:2]) 79 | rubbish_in_bin_area = rubbish_in_bin_area_dist < 0.03 # if within 3cm 80 | 81 | rubbish_height = self.rubbish.get_position()[2] 82 | 83 | if rubbish_in_bin_area: 84 | above_bin_dist = np.abs(rubbish_height - self.LOW_Z_TARGET) 85 | rubbish_above_bin = above_bin_dist < 0.06 # if within 6cm 86 | 87 | if rubbish_above_bin: 88 | move_rubbish_reward = 1.0 89 | 90 | self.reward_open_gripper = True 91 | else: 92 | move_rubbish_reward = 0.5 + 0.5 * np.exp(-above_bin_dist) # 0.5 for getting in the area + dist 93 | else: 94 | move_rubbish_reward = 0.5 * np.exp(-np.linalg.norm(self.rubbish.get_position() - target1_pos)) # up to 0.5 -> needs to get in area 95 | 96 | reward = grasp_rubbish_reward + move_rubbish_reward + release_reward 97 | 98 | return reward 99 | 100 | def get_low_dim_state(self) -> np.ndarray: 101 | # For ad-hoc reward computation, attach reward 102 | state = super().get_low_dim_state() 103 | return state 104 | -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/reach_gripper_and_elbow.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import 
dirname, abspath, join 3 | from typing import List, Tuple 4 | import numpy as np 5 | from pyrep.objects import Object 6 | from pyrep.objects.dummy import Dummy 7 | from pyrep.objects.object import Object 8 | from pyrep.objects.proximity_sensor import ProximitySensor 9 | from pyrep.objects.shape import Shape 10 | from pyrep.objects.joint import Joint 11 | # from rlbench.backend import sim 12 | from rlbench.backend.task import Task 13 | from rlbench.backend.spawn_boundary import SpawnBoundary 14 | from rlbench.backend.conditions import DetectedCondition, Condition 15 | from pyquaternion import Quaternion 16 | 17 | 18 | def get_quaternion_distance(obj, target): 19 | pyq_quaternion = lambda x: Quaternion([x[3],x[0],x[1],x[2]]) 20 | q1, q2 = pyq_quaternion(obj.get_quaternion()), pyq_quaternion(target.get_quaternion()) 21 | return np.abs((q1.inverse * q2).angle) 22 | 23 | def get_orientation_distance(obj, target): 24 | ori1, ori2 = obj.get_orientation(), target.get_orientation() 25 | diff = np.abs(ori2 - ori1) 26 | for i in range(3): 27 | while diff[i] > 2 * np.pi: 28 | diff[i] -= 2 * np.pi 29 | diff[i] = min(diff[i], 2*np.pi - diff[i]) 30 | return diff.mean() 31 | 32 | class DistanceCondition(Condition): 33 | def __init__(self, obj: Object, target: Object, threshold = 0.075): 34 | """in radians if revoloute, or meters if prismatic""" 35 | self._obj = obj 36 | self._target = target 37 | self._threshold = threshold 38 | 39 | def condition_met(self): 40 | met = np.linalg.norm( 41 | self._obj.get_position() - self._target.get_position()) <= self._threshold 42 | return met, False 43 | 44 | class AngleCondition(Condition): 45 | def __init__(self, obj: Object, target: Object, threshold = 0.18): # 10 degress 46 | """in radians if revoloute, or meters if prismatic""" 47 | self._obj = obj 48 | self._target = target 49 | self._threshold = threshold 50 | 51 | def condition_met(self): 52 | met = get_orientation_distance(self._obj, self._target) <= self._threshold 53 | return met, False 54 | 55 | class ReachGripperAndElbow(Task): 56 | 57 | def init_task(self) -> None: 58 | self.ee_target = Shape("ee_target") 59 | self.elbow_target = Shape("elbow_target") 60 | self.boundaries = Shape("boundary") 61 | self.ee_dummy = Dummy('waypoint1') 62 | 63 | self.ee_success_sensor = ProximitySensor("ee_success") 64 | elbow_success_sensor = ProximitySensor("elbow_success") # not being used, cause the joint is not detected 65 | self.elbow = Joint("Panda_joint4") 66 | 67 | self.ee_condition = DistanceCondition( 68 | self.robot.arm.get_tip(), self.ee_success_sensor, threshold=0.075 69 | ) 70 | self.elbow_condition = DistanceCondition( 71 | self.elbow, self.elbow_target, threshold=0.075 72 | ) 73 | self.orientation_condition = AngleCondition( 74 | self.robot.arm.get_tip(), self.ee_dummy, threshold = 0.09 # 5 degress per direction 75 | ) 76 | 77 | self.randomize_orientation = False 78 | 79 | self.register_success_conditions([ 80 | self.ee_condition, 81 | self.elbow_condition, 82 | self.orientation_condition 83 | ]) 84 | 85 | def init_episode(self, index: int) -> List[str]: 86 | b = SpawnBoundary([self.boundaries]) 87 | b.sample(self.ee_target, min_distance=0.2, 88 | min_rotation=(0, 0, 0), max_rotation=(0, 0, 0)) 89 | 90 | if np.random.randint(2) == 0: 91 | theta = np.random.uniform(np.pi * .8, np.pi * .95) 92 | else: 93 | theta = np.random.uniform(np.pi * .05, np.pi * .2) 94 | 95 | # https://frankaemika.github.io/docs/control_parameters.html 96 | # 
https://download.franka.de/documents/220010_Product%20Manual_Franka%20Hand_1.2_EN.pdf 97 | r_h = 0.1070 + 0.1270 # wrist + gripper 98 | r_w = 0.0880 99 | r_1, r_2 = 0.3160, 0.3840 100 | 101 | c_1 = Joint("Panda_joint2").get_position() 102 | c_2 = self.ee_success_sensor.get_position() 103 | 104 | d = np.linalg.norm(c_1 - c_2) 105 | n_i = (c_2 - c_1) / d 106 | 107 | # t_i, b_i are the tangent, bitangent 108 | t_i = np.array([0, -1, 0]) 109 | if not np.array_equal(n_i, np.array([0, 0, 1])): 110 | t_i = np.cross(n_i, np.array([0, 0, 1])) 111 | t_i = t_i / np.linalg.norm(t_i) 112 | b_i = np.array([-1, 0, 0]) 113 | if not np.array_equal(n_i, np.array([0, 1, 0])): 114 | b_i = np.cross(n_i, np.array([0, 1, 0])) 115 | b_i = b_i / np.linalg.norm(b_i) 116 | 117 | if d >= r_1 + r_2: 118 | p_i = c_1 + (c_2 - c_1) * r_1 / d 119 | else: 120 | h = 1/2 + (r_1 ** 2 - r_2 ** 2)/(2 * d ** 2) 121 | c_i = c_1 + h * (c_2 - c_1) 122 | r_i = np.sqrt(r_1 ** 2 - (h * d) ** 2) 123 | p_i = c_i + r_i * (t_i * np.cos(theta) + b_i * np.sin(theta)) 124 | 125 | self.elbow_target.set_position(p_i) 126 | # ee_pos = c_2 + r_w * n_i - r_h * b_i 127 | ee_pos = self.ee_target.get_position() + r_w * np.array([1, 0, 0]) - r_h * np.array([0, 0, 1]) 128 | self.ee_target.set_position(ee_pos) 129 | 130 | if self.randomize_orientation: 131 | ee_ori = self.ee_target.get_orientation() 132 | ee_ori += np.random.uniform(-np.pi/4, +np.pi/4, size=(3,)) 133 | self.ee_target.set_orientation(ee_ori) 134 | return [''] 135 | 136 | def base_rotation_bounds(self) -> Tuple[List[float], List[float]]: 137 | return [0.0, 0.0, 0.0], [0.0, 0.0, 0.0] 138 | 139 | def is_static_workspace(self) -> bool: 140 | return True 141 | 142 | def variation_count(self) -> int: 143 | return 1 144 | 145 | def get_low_dim_state(self) -> np.ndarray: 146 | return np.array([ 147 | self.ee_target.get_position(), self.elbow_target.get_position() 148 | ]).flatten() 149 | 150 | def reward(self) -> float: 151 | ee_detected = self.ee_condition.condition_met()[0] 152 | elbow_detected = self.elbow_condition.condition_met()[0] 153 | orientation_detected = self.orientation_condition.condition_met()[0] 154 | 155 | orient_reward = elbow_reward = ee_reward = 0 156 | 157 | if ee_detected: 158 | ee_reward = 1.0 159 | else: 160 | ee_distance = np.linalg.norm(self.ee_target.get_position() - self.robot.arm.get_tip().get_position()) 161 | ee_reward = np.exp(-ee_distance) 162 | 163 | if orientation_detected: 164 | orient_reward = 1.0 165 | else: 166 | orient_distance = get_orientation_distance(self.robot.arm.get_tip(), self.ee_dummy) 167 | orient_reward = np.exp(-orient_distance) 168 | 169 | if elbow_detected: 170 | elbow_reward = 1.0 171 | else: 172 | elbow_distance = np.linalg.norm(self.elbow_target.get_position() - self.elbow.get_position()) 173 | elbow_reward = np.exp(-elbow_distance) 174 | 175 | reward = elbow_reward + ee_reward + orient_reward 176 | 177 | return reward 178 | 179 | def load(self) -> Object: 180 | ttm_file = join( 181 | dirname(abspath(__file__)), 182 | '../task_ttms/%s.ttm' % self.name) 183 | if not os.path.isfile(ttm_file): 184 | raise FileNotFoundError( 185 | 'The following is not a valid task .ttm file: %s' % ttm_file) 186 | self._base_object = self.pyrep.import_model(ttm_file) 187 | return self._base_object 188 | -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/reach_target.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import 
numpy as np 3 | from pyrep.objects.shape import Shape 4 | from pyrep.objects.proximity_sensor import ProximitySensor 5 | from rlbench.const import colors 6 | from rlbench.backend.task import Task 7 | from rlbench.backend.spawn_boundary import SpawnBoundary 8 | from rlbench.backend.conditions import DetectedCondition 9 | 10 | 11 | class ReachTarget(Task): 12 | 13 | def init_task(self) -> None: 14 | self.target = Shape('target') 15 | self.distractor0 = Shape('distractor0') 16 | self.distractor1 = Shape('distractor1') 17 | self.boundaries = Shape('boundary') 18 | success_sensor = ProximitySensor('success') 19 | 20 | self._detected_condition = DetectedCondition(self.robot.arm.get_tip(), success_sensor) 21 | self.register_success_conditions([self._detected_condition]) 22 | 23 | def init_episode(self, index: int) -> List[str]: 24 | color_name, color_rgb = colors[index] 25 | self.target.set_color(color_rgb) 26 | color_choices = np.random.choice( 27 | list(range(index)) + list(range(index + 1, len(colors))), 28 | size=2, replace=False) 29 | for ob, i in zip([self.distractor0, self.distractor1], color_choices): 30 | name, rgb = colors[i] 31 | ob.set_color(rgb) 32 | b = SpawnBoundary([self.boundaries]) 33 | for ob in [self.target, self.distractor0, self.distractor1]: 34 | b.sample(ob, min_distance=0.2, 35 | min_rotation=(0, 0, 0), max_rotation=(0, 0, 0)) 36 | 37 | return ['reach the %s target' % color_name, 38 | 'touch the %s ball with the panda gripper' % color_name, 39 | 'reach the %s sphere' %color_name] 40 | 41 | def variation_count(self) -> int: 42 | return len(colors) 43 | 44 | def base_rotation_bounds(self) -> Tuple[List[float], List[float]]: 45 | return [0.0, 0.0, 0.0], [0.0, 0.0, 0.0] 46 | 47 | def get_low_dim_state(self) -> np.ndarray: 48 | # One of the few tasks that have a custom low_dim_state function. 
49 | return np.array(self.target.get_position()) 50 | 51 | def is_static_workspace(self) -> bool: 52 | return True 53 | 54 | def reward(self) -> float: 55 | success = self._detected_condition.condition_met()[0] 56 | 57 | if success: 58 | reward = 1.0 59 | else: 60 | reward = np.exp(-np.linalg.norm(self.target.get_position() - 61 | self.robot.arm.get_tip().get_position())) 62 | return reward 63 | 64 | def validate(self): 65 | pass -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/slide_cup.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import numpy as np 3 | from pyrep.objects.dummy import Dummy 4 | from pyrep.objects.shape import Shape 5 | from pyrep.objects.object import Object 6 | from pyrep.objects.proximity_sensor import ProximitySensor 7 | from pyrep.const import PrimitiveShape 8 | from rlbench.const import colors 9 | from rlbench.backend.task import Task 10 | from rlbench.backend.conditions import DetectedCondition, ConditionSet, GraspedCondition, Condition 11 | from rlbench.backend.spawn_boundary import SpawnBoundary 12 | from pyquaternion import Quaternion 13 | from pyrep.const import ObjectType 14 | 15 | import os 16 | from os.path import dirname, abspath, join 17 | 18 | class SlideCup(Task): 19 | 20 | def init_task(self) -> None: 21 | self.cup_source = Shape('cup_source') 22 | self.collidables = [s for s in self.pyrep.get_objects_in_tree( object_type=ObjectType.SHAPE) if ('bottle' in Object.get_object_name(s._handle)) and s.is_respondable()] 23 | 24 | self.success_detector = ProximitySensor('success') 25 | self.cup_condition = DetectedCondition(self.cup_source, self.success_detector) 26 | self.register_success_conditions([self.cup_condition]) 27 | 28 | 29 | def init_episode(self, index: int) -> List[str]: 30 | self.initial_z = self.cup_source.get_position()[2] 31 | 32 | return ['slide coffee'] 33 | 34 | def variation_count(self) -> int: 35 | return 1 36 | 37 | def is_static_workspace(self) -> bool: 38 | return True 39 | 40 | def reward(self,): 41 | cup_placed = self.cup_condition.condition_met()[0] 42 | cup_fallen = self.cup_source.get_position()[2] < (self.initial_z - 0.075) 43 | 44 | close_reward = move_reward = 0.0 45 | 46 | if cup_fallen: 47 | self.terminate_episode = True 48 | else: 49 | self.terminate_episode = False 50 | 51 | if cup_placed: 52 | move_reward = close_reward = 1.0 53 | else: 54 | left_cup_position = (self.cup_source.get_position() - np.array([0,0.05,0])) 55 | 56 | close_distance = np.linalg.norm(left_cup_position - self.robot.arm.get_tip().get_position()) 57 | 58 | if close_distance <= 0.025: 59 | close_reward = 1.0 60 | 61 | move_reward = np.exp(-np.linalg.norm(self.cup_source.get_position() - self.success_detector.get_position())) 62 | else: 63 | # position is offset by 5cm to the left, from the human view 64 | close_reward = np.exp(-np.linalg.norm(left_cup_position - self.robot.arm.get_tip().get_position())) 65 | 66 | reward = close_reward + move_reward 67 | 68 | return reward 69 | 70 | def validate(self): 71 | pass 72 | 73 | def load(self): 74 | ttm_file = join( 75 | dirname(abspath(__file__)), 76 | '../task_ttms/%s.ttm' % self.name) 77 | if not os.path.isfile(ttm_file): 78 | raise FileNotFoundError( 79 | 'The following is not a valid task .ttm file: %s' % ttm_file) 80 | self._base_object = self.pyrep.import_model(ttm_file) 81 | return self._base_object 82 | 
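The custom task files above all follow the same recipe: register graspable objects and success conditions in init_task, expose a shaped reward() that mixes staged bonuses (1.0 per achieved sub-goal) with np.exp(-distance) terms for partial progress, and override load() so the model is imported from ../task_ttms/<task_name>.ttm. The sketch below illustrates that pattern for a hypothetical task; it is not part of this repository, and the ReachBlock class name, the 'block'/'success' scene object names, and the placeholder .ttm path are assumptions made purely for illustration.

from typing import List
import os
from os.path import dirname, abspath, join

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor
from rlbench.backend.task import Task
from rlbench.backend.conditions import DetectedCondition, GraspedCondition


class ReachBlock(Task):
    """Hypothetical task showing the structure shared by the tasks above."""

    def init_task(self) -> None:
        self.block = Shape('block')                          # placeholder scene object
        self.success_detector = ProximitySensor('success')   # placeholder success sensor
        self.register_graspable_objects([self.block])
        self._grasped_cond = GraspedCondition(self.robot.gripper, self.block)
        self._detected_cond = DetectedCondition(self.block, self.success_detector)
        self.register_success_conditions([self._detected_cond])

    def init_episode(self, index: int) -> List[str]:
        return ['move the block to the target']

    def variation_count(self) -> int:
        return 1

    def reward(self) -> float:
        grasped = self._grasped_cond.condition_met()[0]
        detected = self._detected_cond.condition_met()[0]
        # Staged shaping: full bonus once a sub-goal is met, exp(-distance) otherwise.
        if not grasped:
            reach_reward = np.exp(-np.linalg.norm(
                self.block.get_position() - self.robot.arm.get_tip().get_position()))
            move_reward = 0.0
        else:
            reach_reward = 1.0
            move_reward = 1.0 if detected else np.exp(-np.linalg.norm(
                self.block.get_position() - self.success_detector.get_position()))
        return reach_reward + move_reward

    def load(self):
        # Custom tasks load their scene model from ../task_ttms/<task_name>.ttm
        ttm_file = join(dirname(abspath(__file__)), '../task_ttms/%s.ttm' % self.name)
        if not os.path.isfile(ttm_file):
            raise FileNotFoundError('The following is not a valid task .ttm file: %s' % ttm_file)
        self._base_object = self.pyrep.import_model(ttm_file)
        return self._base_object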
-------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/stack_wine.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | from pyrep.objects.shape import Shape 5 | from pyrep.objects.proximity_sensor import ProximitySensor 6 | from rlbench.backend.task import Task 7 | from rlbench.backend.conditions import ( 8 | DetectedCondition, 9 | GraspedCondition, 10 | NothingGrasped, 11 | ) 12 | 13 | 14 | class StackWine(Task): 15 | def init_task(self): 16 | self.wine_bottle = Shape("wine_bottle") 17 | self.register_graspable_objects([self.wine_bottle]) 18 | 19 | self._success_sensor = ProximitySensor("success") 20 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.wine_bottle) 21 | self._detected_cond = DetectedCondition(self.wine_bottle, self._success_sensor) 22 | 23 | self.register_success_conditions([self._detected_cond]) 24 | 25 | def init_episode(self, index: int) -> List[str]: 26 | return [ 27 | "stack wine bottle", 28 | "slide the bottle onto the wine rack", 29 | "put the wine away", 30 | "leave the wine on the shelf", 31 | "grasp the bottle and put it away", 32 | "place the wine bottle on the wine rack", 33 | ] 34 | 35 | def variation_count(self) -> int: 36 | return 1 37 | 38 | def base_rotation_bounds(self) -> Tuple[List[float], List[float]]: 39 | return [0, 0, -np.pi / 4.0], [0, 0, np.pi / 4.0] 40 | 41 | def reward(self) -> float: 42 | grasped = self._grasped_cond.condition_met()[0] 43 | detected = self._detected_cond.condition_met()[0] 44 | 45 | grasp_wine_reward = reach_target_reward = 0.0 46 | 47 | if not grasped: 48 | grasp_wine_reward = np.exp( 49 | -np.linalg.norm( 50 | self.wine_bottle.get_position() 51 | - self.robot.arm.get_tip().get_position() 52 | ) 53 | ) 54 | else: 55 | grasp_wine_reward = 1.0 56 | 57 | if detected: 58 | reach_target_reward = 1.0 59 | else: 60 | reach_target_reward = np.exp( 61 | -np.linalg.norm( 62 | self.wine_bottle.get_position() 63 | - self._success_sensor.get_position() 64 | ) 65 | ) 66 | 67 | reward = grasp_wine_reward + reach_target_reward 68 | 69 | return reward 70 | 71 | def get_low_dim_state(self) -> np.ndarray: 72 | # For ad-hoc reward computation, attach reward 73 | state = super().get_low_dim_state() 74 | return state 75 | -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/take_lid_off_saucepan.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | from pyrep.objects.proximity_sensor import ProximitySensor 5 | from pyrep.objects.shape import Shape 6 | from rlbench.backend.conditions import DetectedCondition, ConditionSet, \ 7 | GraspedCondition 8 | from rlbench.backend.task import Task 9 | 10 | 11 | class TakeLidOffSaucepan(Task): 12 | 13 | def init_task(self) -> None: 14 | self.lid = Shape('saucepan_lid_grasp_point') 15 | self.success_detector = ProximitySensor('success') 16 | self.register_graspable_objects([self.lid]) 17 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.lid) 18 | self._detected_cond = DetectedCondition(self.lid, self.success_detector) 19 | cond_set = ConditionSet([ 20 | self._grasped_cond, 21 | self._detected_cond, 22 | ]) 23 | self.register_success_conditions([cond_set]) 24 | 25 | def init_episode(self, index: int) -> List[str]: 26 | return ['take lid off the saucepan', 27 | 'using the handle, 
lift the lid off of the pan', 28 | 'remove the lid from the pan', 29 | 'grip the saucepan\'s lid and remove it from the pan', 30 | 'leave the pan open', 31 | 'uncover the saucepan'] 32 | 33 | def variation_count(self) -> int: 34 | return 1 35 | 36 | def reward(self) -> float: 37 | grasped = self._grasped_cond.condition_met()[0] 38 | detected = self._detected_cond.condition_met()[0] 39 | 40 | grasp_lid_reward = lift_lid_reward = 0.0 41 | 42 | if grasped: 43 | grasp_lid_reward = 1.0 44 | 45 | if detected: 46 | lift_lid_reward = 1.0 47 | else: 48 | lift_lid_reward = np.exp(-np.linalg.norm( 49 | self.lid.get_position() - self.success_detector.get_position())) 50 | else: 51 | grasp_lid_reward = np.exp(-np.linalg.norm( 52 | self.lid.get_position() - self.robot.arm.get_tip().get_position())) 53 | 54 | reward = grasp_lid_reward + lift_lid_reward 55 | 56 | return reward -------------------------------------------------------------------------------- /rl/env/custom_rlbench_tasks/tasks/take_umbrella_out_of_umbrella_stand.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | import copy 4 | from scipy.spatial.transform import Rotation 5 | from pyrep.objects.shape import Shape 6 | from pyrep.objects.proximity_sensor import ProximitySensor 7 | from rlbench.backend.task import Task 8 | from rlbench.backend.conditions import DetectedCondition, GraspedCondition 9 | 10 | 11 | class TakeUmbrellaOutOfUmbrellaStand(Task): 12 | def init_task(self): 13 | self.success_sensor = ProximitySensor("success") 14 | self.umbrella = Shape("umbrella") 15 | self.register_graspable_objects([self.umbrella]) 16 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.umbrella) 17 | self._detected_cond = DetectedCondition( 18 | self.umbrella, self.success_sensor, negated=True 19 | ) 20 | self.register_success_conditions([self._detected_cond]) 21 | self.Z_TARGET = 1.12 22 | 23 | def init_episode(self, index: int) -> List[str]: 24 | return [ 25 | "take umbrella out of umbrella stand", 26 | "grasping the umbrella by its handle, lift it up and out of the" " stand", 27 | "remove the umbrella from the stand", 28 | "retrieve the umbrella from the stand", 29 | "get the umbrella", 30 | "lift the umbrella out of the stand", 31 | ] 32 | 33 | def variation_count(self) -> int: 34 | return 1 35 | 36 | def reward(self) -> float: 37 | target_pos = copy.deepcopy(self.success_sensor.get_position()) 38 | target_pos[-1] = self.Z_TARGET 39 | dist_from_target = np.linalg.norm(self.umbrella.get_position() - target_pos) 40 | 41 | grasped = self._grasped_cond.condition_met()[0] 42 | lifted = dist_from_target < 0.03 # less than 3cm 43 | 44 | grasp_umbrella_reward = lift_umbrella_reward = 0.0 45 | 46 | 47 | if not grasped: 48 | grasp_umbrella_reward = np.exp( 49 | -np.linalg.norm( 50 | self.umbrella.get_position() 51 | - self.robot.arm.get_tip().get_position() 52 | ) 53 | ) 54 | else: 55 | grasp_umbrella_reward = 1.0 56 | 57 | if lifted: 58 | lift_umbrella_reward = 1.0 59 | else: 60 | lift_umbrella_reward = np.exp(-dist_from_target) 61 | 62 | reward = grasp_umbrella_reward + lift_umbrella_reward 63 | 64 | return reward 65 | 66 | def get_low_dim_state(self) -> np.ndarray: 67 | # For ad-hoc reward computation, attach reward 68 | state = super().get_low_dim_state() 69 | return state -------------------------------------------------------------------------------- /rl/env/rlbench_envs.py: 
-------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | import numpy as np 4 | from typing import Union, Dict, Tuple 5 | from pathlib import Path 6 | import shutil 7 | import inspect 8 | 9 | from pyrep.const import RenderMode 10 | from pyrep.objects.vision_sensor import VisionSensor 11 | from pyrep.objects.dummy import Dummy 12 | from pyrep.backend import sim 13 | from pyrep.objects import Object 14 | 15 | import rlbench 16 | from rlbench.action_modes.action_mode import MoveArmThenGripper 17 | from rlbench.action_modes.gripper_action_modes import Discrete 18 | from rlbench.backend.task import Task 19 | from rlbench.environment import Environment 20 | from rlbench.observation_config import ObservationConfig, CameraConfig 21 | from rlbench.utils import name_to_task_class 22 | from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaIK, EndEffectorPoseViaPlanning, JointPosition 23 | from rlbench.backend.exceptions import InvalidActionError 24 | 25 | from pyquaternion import Quaternion 26 | from scipy.spatial.transform import Rotation 27 | 28 | from .custom_arm_action_modes import ERAngleViaIK, EEOrientationState, ERJointViaIK, TimeoutEndEffectorPoseViaIK 29 | from .custom_rlbench_tasks.tasks import CUSTOM_TASKS 30 | 31 | 32 | REACH_TASKS = ['reach_target', 'reach_target_and_elbow', 'reach_target_wall', 'reach_target_shelf'] 33 | SHAPED_TASKS = list(CUSTOM_TASKS.keys()) 34 | CUSTOM_CAM_TASKS = { 35 | 'bottle_in_open_fridge' : 'overhead', 36 | 'take_bottle_out_fridge': 'overhead', 37 | 'bottle_out_moving_fridge': 'overhead', 38 | 'plate_out_open_dishwasher' : 'right_shoulder', 39 | } 40 | reorder_wxyz_to_xyzw = lambda q: [q[1], q[2], q[3], q[0]] 41 | 42 | """ 43 | List of possible low-dim states 44 | -------------------------------- 45 | joint_velocities 46 | joint_velocities_noise 47 | joint_positions 48 | joint_positions_noise 49 | joint_forces 50 | joint_forces_noise 51 | gripper_open 52 | gripper_pose 53 | gripper_matrix 54 | gripper_joint_positions 55 | gripper_touch_forces 56 | wrist_camera_matrix 57 | record_gripper_closing 58 | task_low_dim_state 59 | 60 | List of possible cameras 61 | ------------------------------ 62 | left_shoulder 63 | right_shoulder 64 | overhead 65 | wrist 66 | front 67 | 68 | """ 69 | 70 | DEBUG_RESOLUTION = (128,128) 71 | 72 | def decorate_observation(self, observation): 73 | observation.elbow_pos = _get_elbow_pos() 74 | observation.joint_cart_pos = _get_joint_cart_pos() 75 | observation.elbow_angle = _get_elbow_angle() 76 | if hasattr(self, 'lid'): 77 | observation.lid_pos = self.lid.get_position() 78 | return observation 79 | 80 | class RLBench: 81 | def __init__(self, name, observation_mode='state', img_size=84, action_repeat=1, use_angle : bool = True, terminal_invalid : bool = False, correct_invalid : bool = False, 82 | action_mode='ee', render_mode: Union[None, str] = None, goal_centric : bool = False, quaternion_angle : bool = False, 83 | cameras='front|wrist', state_info='gripper_open|gripper_pose|joint_cart_pos|joint_positions|task_low_dim_state', 84 | use_depth : bool = False, debug_viz : bool = False, robot_setup='panda', reward_scale = 1.0, success_scale = 0.0, 85 | pos_step_size=0.05, rot_step_size=0.05, elbow_step_size=0.075, joint_step_size=0.075, opengl3 = False, action_filter : str = 'none', 86 | erj_joint :int = 0, erj_eps : float = 2e-2, erj_delta_angle : bool = True 87 | ): 88 | if name in CUSTOM_TASKS: 89 | task_class = CUSTOM_TASKS[name] 90 | else: 91 | 
task_class = name_to_task_class(name) 92 | task_class.decorate_observation = decorate_observation 93 | 94 | self.name = name 95 | if action_mode in ['erangle']: 96 | self.replace_scene() 97 | 98 | # Setup observation configs 99 | self._terminal_invalid = terminal_invalid 100 | self._correct_invalid = correct_invalid 101 | self._goal_centric = goal_centric 102 | self._observation_mode = observation_mode 103 | self._cameras = sorted(cameras.split('|')) 104 | self._state_info = sorted(state_info.split('|')) 105 | # if self.name == 'take_lid_off_saucepan': 106 | # self._state_info.append('lid_pos') 107 | # self._state_info.remove('task_low_dim_state') 108 | self._reward_scale = reward_scale 109 | self._success_scale = success_scale 110 | 111 | obs_config = ObservationConfig() 112 | obs_config.set_all_high_dim(False) 113 | obs_config.set_all_low_dim(False) 114 | 115 | if observation_mode in ['state', 'states', 'both']: 116 | for st in self._state_info: 117 | setattr(obs_config, st, True) 118 | if debug_viz and observation_mode != 'both': 119 | if name in CUSTOM_CAM_TASKS: 120 | custom_camera = getattr(obs_config, f'{CUSTOM_CAM_TASKS[name]}_camera') 121 | custom_camera.rgb = True 122 | custom_camera.image_size = DEBUG_RESOLUTION 123 | custom_camera.render_mode = RenderMode.OPENGL3 if name == 'pick_and_lift' or opengl3 else RenderMode.OPENGL 124 | else: 125 | obs_config.front_camera.rgb = True 126 | obs_config.front_camera.image_size = DEBUG_RESOLUTION 127 | obs_config.front_camera.render_mode = RenderMode.OPENGL3 if name == 'pick_and_lift' or opengl3 else RenderMode.OPENGL 128 | if observation_mode in ['vision', 'pixels', 'both']: 129 | for cam in self._cameras: 130 | if name in CUSTOM_CAM_TASKS and cam == 'front': 131 | cam = CUSTOM_CAM_TASKS[self.name] 132 | camera_config = getattr(obs_config, f'{cam}_camera') 133 | camera_config.rgb = True 134 | camera_config.depth = use_depth 135 | camera_config.image_size = (img_size, img_size) 136 | camera_config.render_mode = RenderMode.OPENGL3 if name == 'pick_and_lift' or opengl3 else RenderMode.OPENGL 137 | 138 | # Setup action mode 139 | self._action_repeat = action_repeat 140 | if action_mode == 'erangle': 141 | arm_action_mode = ERAngleViaIK(absolute_mode=False,) 142 | elif action_mode == 'erjoint': 143 | arm_action_mode = ERJointViaIK(absolute_mode=False, commanded_joint=erj_joint, eps=erj_eps, delta_angle=erj_delta_angle) 144 | elif action_mode == 'ee': 145 | arm_action_mode = TimeoutEndEffectorPoseViaIK(absolute_mode=False,) 146 | elif action_mode == 'ee_plan': 147 | arm_action_mode = EndEffectorPoseViaPlanning(absolute_mode=False,) 148 | elif action_mode == 'joint': 149 | arm_action_mode = JointPosition(absolute_mode=False,) 150 | self._action_mode = action_mode.replace('_plan', '') 151 | 152 | 153 | self.POS_STEP_SIZE = pos_step_size 154 | self.ELBOW_STEP_SIZE = elbow_step_size 155 | self.ROT_STEP_SIZE = rot_step_size 156 | self.JOINT_STEP_SIZE = joint_step_size 157 | self.action_filter = action_filter 158 | 159 | action_modality = MoveArmThenGripper( 160 | arm_action_mode=arm_action_mode, 161 | gripper_action_mode=Discrete() 162 | ) 163 | 164 | # Launch environment and setup spaces 165 | self._env = Environment(action_modality, obs_config=obs_config, headless=True, robot_setup=robot_setup, 166 | shaped_rewards=True if name in SHAPED_TASKS else False,) 167 | self._env.launch() 168 | 169 | self.task = self._env.get_task(task_class) 170 | _, obs = self.task.reset() 171 | 172 | self._use_angle = use_angle 173 | self._quaternion_angle = 
quaternion_angle 174 | if use_angle: 175 | act_shape = self._env.action_shape if quaternion_angle or action_mode in ['joint'] else (self._env.action_shape[0]-1, ) 176 | else: 177 | act_shape = (self._env.action_shape[0]-4, ) if action_mode != 'joint' else self._env.action_shape 178 | self.act_space = spaces.Dict({'action' : spaces.Box(low=-1.0, high=1.0, shape=act_shape)}) 179 | 180 | state_space = list(obs.get_low_dim_data().shape) 181 | if 'lid_pos' in self._state_info: 182 | state_space[0] += 3 183 | if 'elbow_angle' in self._state_info: 184 | # Size of two as representing angle as a unit vector 185 | state_space[0] += 2 186 | if 'elbow_pos' in self._state_info: 187 | state_space[0] += 3 188 | if 'joint_cart_pos' in self._state_info: 189 | state_space[0] += 3 * 7 190 | state_space = tuple(state_space) 191 | 192 | self._env_obs_space = {} 193 | if observation_mode in ['state', 'states', 'both']: 194 | self._env_obs_space['state'] = spaces.Box(low=-np.inf, high=np.inf, shape=state_space) 195 | if debug_viz: 196 | self._env_obs_space["front_rgb"] = spaces.Box(low=0, high=255, shape=(3,) + DEBUG_RESOLUTION, dtype=np.uint8) 197 | if observation_mode in ['vision', 'pixels', 'both']: 198 | for cam in self._cameras: 199 | self._env_obs_space[f"{cam}_rgb"] = spaces.Box(low=0, high=255, shape=(3, img_size, img_size), dtype=np.uint8) 200 | if use_depth: 201 | self._env_obs_space[f"{cam}_depth"] = spaces.Box(low=-np.inf, high=+np.inf, shape=(1, img_size, img_size), dtype=np.float32) 202 | 203 | # Render more for extra viz 204 | self._render_mode = render_mode 205 | if render_mode is not None: 206 | # Add the camera to the scene 207 | cam_placeholder = Dummy('cam_cinematic_placeholder') 208 | self._gym_cam = VisionSensor.create([640, 360]) 209 | self._gym_cam.set_pose(cam_placeholder.get_pose()) 210 | if render_mode == 'human': 211 | self._gym_cam.set_render_mode(RenderMode.OPENGL3_WINDOWED) 212 | else: 213 | self._gym_cam.set_render_mode(RenderMode.OPENGL3) 214 | 215 | def _get_state_vec(self, obs): 216 | vec = [] 217 | for k in self._state_info: 218 | data = getattr(obs, k) 219 | if type(data) == float: 220 | data = np.array([data]) 221 | if len(data.shape) == 0: 222 | data = data.reshape(1,) 223 | 224 | if self._goal_centric and self.name in REACH_TASKS: 225 | if k == 'gripper_pose': 226 | data[:3] = data[:3] - obs.task_low_dim_state.copy() 227 | if k == 'task_low_dim_state': 228 | data = data * 0 229 | vec.append(data) 230 | return vec 231 | 232 | def _extract_obs(self, obs) -> Dict[str, np.ndarray]: 233 | val = {} 234 | if 'state' in self._env_obs_space: 235 | state = np.concatenate(self._get_state_vec(obs)).astype(self.obs_space['state'].dtype) 236 | val['state'] = state 237 | for k in self._env_obs_space: 238 | if k == 'state': continue 239 | # Assuming all other observations are vision-based 240 | if self.name in CUSTOM_CAM_TASKS and 'front' in k: 241 | data = getattr(obs, k.replace('front', CUSTOM_CAM_TASKS[self.name])) 242 | else: 243 | data = getattr(obs, k) 244 | if 'depth' in k: 245 | data = np.expand_dims(data, 0) 246 | if 'rgb' in k: 247 | data = data.transpose(2,0,1) 248 | val[k] = data.astype(self.obs_space[k].dtype) 249 | return val 250 | 251 | @property 252 | def obs_space(self): 253 | spaces = { 254 | **self._env_obs_space, 255 | "is_first": gym.spaces.Box(0, 1, (), dtype=bool), 256 | "is_last": gym.spaces.Box(0, 1, (), dtype=bool), 257 | "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool), 258 | "success": gym.spaces.Box(0, 1, (), dtype=bool), 259 | "invalid_action" : 
gym.spaces.Box(0, np.iinfo(np.uint8).max, (), dtype=np.uint8) 260 | } 261 | return spaces 262 | 263 | def act2env(self, action): 264 | # Gripper 265 | gripper = (action[-1:].copy() + 1) / 2 266 | 267 | if self.action_filter == 'passband': 268 | pos_filter = lambda x : (np.abs(x) >= 1e-3) * x # Smoothen moves less than 1mm 269 | euler_filter = lambda x : (np.abs(x) >= 1e-2) * x # Smoothen angles less than 0.57295 degrees 270 | joint_filter = lambda x : (np.abs(x) >= 1e-2) * x # Smoothen angles less than 0.57295 degrees 271 | elif self.action_filter == 'round': 272 | pos_filter = lambda x : np.round(x, 3) # Smoothen moves less than 1mm 273 | euler_filter = lambda x : np.round(x, 2) # # Smoothen angles less than 0.57295 degrees 274 | joint_filter = lambda x : np.round(x, 2) # Smoothen angles less than 0.57295 degrees 275 | elif self.action_filter == 'none': 276 | pos_filter = euler_filter = joint_filter = lambda x : x # no filter 277 | else: 278 | raise NotImplementedError(f'Filter not implemented') 279 | 280 | 281 | if self._action_mode in ['ee', 'erangle', 'erjoint']: 282 | T_IDX = 3 283 | 284 | # Translation 285 | translation = pos_filter(self.POS_STEP_SIZE * action[:T_IDX].copy()) 286 | 287 | # Rotation 288 | if self._use_angle: 289 | if self._quaternion_angle: 290 | wxyz_rotation = Quaternion(action[T_IDX:T_IDX+4].copy() / np.linalg.norm(action[T_IDX:T_IDX+4].copy())) 291 | wxyz_rotation = Quaternion(axis=wxyz_rotation.axis, radians=np.clip(wxyz_rotation.radians, 0, self.ROT_STEP_SIZE)) 292 | xyzw_rotation = reorder_wxyz_to_xyzw(wxyz_rotation.elements) 293 | else: 294 | euler_rotation = euler_filter(action[T_IDX:T_IDX+3].copy() * self.ROT_STEP_SIZE) 295 | xyzw_rotation = Rotation.from_euler('xyz', euler_rotation).as_quat() 296 | else: 297 | wxyz_to = Quaternion([ 1.20908514e-01, 3.09618457e-07, 9.92663622e-01, -1.02228194e-06,]) 298 | x,y,z,w = self.task._scene.robot.arm.get_tip().get_quaternion() 299 | wxyz_from = Quaternion([w,x,y,z]) 300 | wxyz_rotation = wxyz_to * wxyz_from.inverse 301 | xyzw_rotation = reorder_wxyz_to_xyzw(wxyz_rotation.elements) 302 | 303 | if self._action_mode in ['erangle', 'erjoint']: 304 | elbow_filter = {'erangle' : euler_filter, 'erjoint' : joint_filter }[self._action_mode] 305 | angle = [elbow_filter(action[-2].copy() * self.ELBOW_STEP_SIZE)] 306 | env_action = np.concatenate([translation, xyzw_rotation, angle, gripper]) 307 | else: 308 | env_action = np.concatenate([translation, xyzw_rotation, gripper]) 309 | elif self._action_mode == 'joint': 310 | joint_action = joint_filter(action[:-1].copy() * self.JOINT_STEP_SIZE) 311 | env_action = np.concatenate([joint_action, gripper]) 312 | return env_action 313 | 314 | def step(self, action): 315 | assert (max(action) <= 1) and (min(action) >= -1) 316 | 317 | env_action = self.act2env(action) 318 | 319 | orig_action = action.copy() 320 | reward = 0.0 321 | invalid_action = 0 322 | must_terminate = False 323 | for _ in range(self._action_repeat): 324 | if self._correct_invalid: 325 | raise NotImplementedError('The implementation is outdated') 326 | else: 327 | try: 328 | env_obs, rew, _ = self.task.step(env_action) 329 | self.consecutive_invalid = 0 330 | except InvalidActionError as e: 331 | # Penalty for unsuccesful IK 332 | rew = 0 333 | env_obs = self._prev_env_obs 334 | invalid_action += 1 335 | self.consecutive_invalid += 1 336 | 337 | if getattr(self.task._task, 'reward_open_gripper', False): 338 | if self.consecutive_invalid == 0: # to avoid rewarding invalid states 339 | rew += (orig_action[-1] + 1) / 2 # 
to be in [0,1] 340 | 341 | reward += rew 342 | self._prev_reward = rew 343 | self._prev_env_obs = env_obs 344 | 345 | if getattr(self.task._task, 'terminate_episode', False): 346 | must_terminate = True 347 | reward = 0. 348 | break 349 | 350 | 351 | is_terminal = (int(self._terminal_invalid) if invalid_action > 0 else 0) or int(must_terminate) 352 | discount = 1 - is_terminal 353 | 354 | if not invalid_action: 355 | success, _ = self.task._task.success() 356 | else: 357 | success = 0 358 | 359 | if success: 360 | self.consecutive_success += 1 361 | else: 362 | self.consecutive_success = 0 363 | 364 | 365 | obs = { 366 | "reward": reward * self._reward_scale + float(success) * self._success_scale, 367 | "is_first": False, 368 | "is_last": True if (self.consecutive_invalid >= 5) or (self.consecutive_success >= 10) or must_terminate else False, # will be handled by timelimit wrapper 369 | "is_terminal": is_terminal, # if not set will be handled by per_episode function 370 | 'action' : orig_action, 371 | 'discount' : discount, 372 | 'success' : success, 373 | "invalid_action" : invalid_action 374 | } 375 | obs.update(self._extract_obs(env_obs)) 376 | return obs 377 | 378 | def reset(self, **kwargs): 379 | _, env_obs = self.task.reset(**kwargs) 380 | self.consecutive_invalid = 0 381 | self.consecutive_success = 0 382 | 383 | self._prev_env_obs = env_obs 384 | obs = { 385 | "reward": self.task._task.reward() * self._reward_scale, 386 | "is_first": True, 387 | "is_last": False, 388 | "is_terminal": False, 389 | 'action' : np.zeros_like(self.act_space['action'].sample()), 390 | 'discount' : 1, 391 | "success": False, 392 | "invalid_action" : 0 393 | } 394 | obs.update(self._extract_obs(env_obs)) 395 | self._prev_reward = obs['reward'] 396 | return obs 397 | 398 | def render(self, mode='human') -> Union[None, np.ndarray]: 399 | if mode != self._render_mode: 400 | raise ValueError( 401 | 'The render mode must match the render mode selected in the ' 402 | 'constructor. \nI.e. if you want "human" render mode, then ' 403 | 'create the env by calling: ' 404 | 'gym.make("reach_target-state-v0", render_mode="human").\n' 405 | 'You passed in mode %s, but expected %s.' 
% ( 406 | mode, self._render_mode)) 407 | if mode == 'rgb_array': 408 | frame = self._gym_cam.capture_rgb() 409 | frame = np.clip((frame * 255.).astype(np.uint8), 0, 255) 410 | return frame 411 | 412 | def get_demos(self, *args, **kwargs,): 413 | return self.task.get_demos(*args, **kwargs) 414 | 415 | def close(self) -> None: 416 | self._env.shutdown() 417 | 418 | def __del__(self,) -> None: 419 | self.close() 420 | 421 | def replace_scene(self,): 422 | task_src = Path(inspect.getfile(self.__class__)).parent / 'custom_rlbench_tasks' / 'elbow_angle_task_design.ttt' 423 | task_dst = Path(rlbench.__path__[0]) / 'task_design.ttt' 424 | 425 | shutil.copyfile(task_src, task_dst) 426 | 427 | def __getattr__(self, name): 428 | if name in ['obs_space', 'act_space']: 429 | return self.__getattribute__(name) 430 | else: 431 | return getattr(self._env, name) 432 | 433 | def _get_elbow_angle() -> np.ndarray: 434 | try: 435 | joint2_obj = Object.get_object(sim.simGetObjectHandle('Panda_joint2')) 436 | joint7_obj = Object.get_object(sim.simGetObjectHandle('Panda_joint7')) 437 | elbow_obj = Object.get_object(sim.simGetObjectHandle('Panda_joint4')) 438 | w = joint7_obj.get_position() - joint2_obj.get_position() 439 | w = w / np.linalg.norm(w) 440 | a = joint2_obj.get_position() 441 | p = elbow_obj.get_position() 442 | 443 | # find vector on plan that is orthogonal to y axis. 444 | angle_origin = np.array([-1, 0, 0]) 445 | if not np.array_equal(w, np.array([0, 1, 0])): 446 | angle_origin = np.cross(w, np.array([0, 1, 0])) 447 | 448 | # find center of "elbow circle" 449 | alpha = sum([w_i * (p_i - a_i) for w_i, p_i, a_i in zip(w, p, a)]) / sum([w_i ** 2 for w_i in w]) 450 | center = [a_i + alpha * w_i for a_i, w_i in zip(a, w)] 451 | 452 | elbow_vector = p - center 453 | 454 | # normalise the vectors 455 | angle_origin = angle_origin / np.linalg.norm(angle_origin) 456 | elbow_vector = elbow_vector / np.linalg.norm(elbow_vector) 457 | 458 | x = np.dot(w, np.cross(angle_origin, elbow_vector)) 459 | y = np.cos(np.arcsin(x)) 460 | 461 | return np.array([x, y]) 462 | except: 463 | return np.array([0,0]) 464 | 465 | def _get_joint_cart_pos(): 466 | try: 467 | return np.concatenate([Object.get_object(sim.simGetObjectHandle(f'Panda_joint{i}')).get_position() for i in range(1,8)]) 468 | except: 469 | return np.zeros([21]) 470 | 471 | def _get_elbow_pos() -> np.ndarray: 472 | try: 473 | return Object.get_object(sim.simGetObjectHandle('Panda_joint4')).get_position() 474 | except: 475 | return np.array([0,0,0]) 476 | 477 | 478 | if __name__ == '__main__': 479 | env = RLBench('reach_target') 480 | obs = env.reset() 481 | print(obs) -------------------------------------------------------------------------------- /rl/envs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import astuple, dataclass 2 | from enum import Enum 3 | from multiprocessing import Pipe, Process 4 | from multiprocessing import set_start_method as mp_set_start_method 5 | from multiprocessing.connection import Connection 6 | from typing import Any, Callable, Iterator, List, Optional, Tuple, Dict 7 | 8 | import numpy as np 9 | from collections import defaultdict, deque 10 | import gym 11 | from pathlib import Path 12 | 13 | import sys 14 | import os 15 | from contextlib import redirect_stderr, redirect_stdout 16 | import time 17 | 18 | import traceback 19 | 20 | ERROR = -1 21 | WAITING = 0 22 | 23 | class MessageType(Enum): 24 | EXCEPTION = -3 25 | RESET = 0 26 | STEP = 2 27 | STEP_RETURN = 3 28 | 
CLOSE = 4 29 | OBS_SPACE = 5 30 | OBS_SPACE_RETURN = 6 31 | ACT_SPACE = 7 32 | ACT_SPACE_RETURN = 8 33 | 34 | @dataclass 35 | class Message: 36 | type: MessageType 37 | content: Optional[Any] = None 38 | 39 | def __iter__(self) -> Iterator: 40 | return iter(astuple(self)) 41 | 42 | 43 | def child_fn(child_id: int, env_fn: Callable, child_conn: Connection, redirect_output_to: str = None) -> None: 44 | np.random.seed(child_id + np.random.randint(0, 2 ** 31 - 1)) 45 | if redirect_output_to is not None: 46 | redirect_output_to = Path(redirect_output_to) 47 | os.makedirs(str(redirect_output_to / str(child_id)), exist_ok=True) 48 | with open(str(redirect_output_to / str(child_id) / "out.log"), 'a') as stdout, redirect_stdout(stdout), open(str(redirect_output_to / str(child_id) / "err.log"), 'a') as stderr, redirect_stderr(stderr): 49 | child_env(child_id, env_fn, child_conn) 50 | else: 51 | child_env(child_id, env_fn, child_conn) 52 | 53 | def child_env(child_id, env_fn: Callable, child_conn: Connection,) -> None: 54 | try: 55 | env = env_fn() 56 | while True: 57 | message_type, content = child_conn.recv() 58 | if message_type == MessageType.RESET: 59 | obs = env.reset() 60 | obs['env_idx'] = child_id 61 | obs['env_error'] = 0 62 | child_conn.send(Message(MessageType.STEP_RETURN, obs)) 63 | elif message_type == MessageType.STEP: 64 | obs = env.step(content) 65 | # if obs['is_last']: 66 | # obs = env.reset() 67 | obs['env_idx'] = child_id 68 | obs['env_error'] = 0 69 | child_conn.send(Message(MessageType.STEP_RETURN, obs)) 70 | elif message_type == MessageType.CLOSE: 71 | child_conn.close() 72 | return 73 | elif message_type == MessageType.OBS_SPACE: 74 | obs_space = env.obs_space 75 | child_conn.send(Message(MessageType.OBS_SPACE_RETURN, obs_space)) 76 | elif message_type == MessageType.ACT_SPACE: 77 | act_space = env.act_space 78 | child_conn.send(Message(MessageType.ACT_SPACE_RETURN, act_space)) 79 | else: 80 | raise NotImplementedError 81 | sys.stdout.flush(), sys.stderr.flush() 82 | except Exception as e: 83 | child_conn.send(Message(MessageType.EXCEPTION, traceback.format_exc())) 84 | 85 | class MultiProcessEnv(gym.Env): 86 | def __init__(self, env_fn: Callable, num_envs: int, redirect_output_to: str = None) -> None: 87 | super().__init__() 88 | self.env_fn = env_fn 89 | self.num_envs = num_envs 90 | self.processes, self.parent_conns, self.child_conns = [], [], [] 91 | mp_set_start_method('fork', force=True) 92 | 93 | self.idx_queue = deque() 94 | self.FPS_TARGET = 100 # FPS 95 | self.TIMEOUT_LIMIT = 20 96 | self.timeouts = np.array([np.inf for _ in range(num_envs)], dtype=np.float64) 97 | 98 | for child_id in range(num_envs): 99 | parent_conn, child_conn = Pipe() 100 | self.parent_conns.append(parent_conn) 101 | self.child_conns.append(child_conn) 102 | p = Process(target=child_fn, args=(child_id, env_fn, child_conn, redirect_output_to), daemon=True) 103 | self.processes.append(p) 104 | for idx, p in enumerate(self.processes): 105 | p.start() 106 | # Waiting for reset to work, to avoid concurrency issues in starting the simulator 107 | self.reset_idx(idx) 108 | 109 | self._obs_space, self._act_space = None, None 110 | print("Observation space:") 111 | for k, v in self.obs_space.items(): 112 | print(f" {k} : {v}") 113 | print("Action space:") 114 | for k, v in self.act_space.items(): 115 | print(f" {k} : {v}") 116 | 117 | def _clear(self,): 118 | self.timeouts = np.array([np.inf for _ in range(self.num_envs)], dtype=np.float64) 119 | for p in self.parent_conns: 120 | if p.poll(): 121 | 
p.recv() 122 | 123 | def _restore_process_idx(self, idx : float): 124 | print("Restoring process", idx) 125 | self.child_conns[idx].close() 126 | self.parent_conns[idx].close() 127 | self.processes[idx].kill() 128 | 129 | parent_conn, child_conn = Pipe() 130 | self.parent_conns[idx] = parent_conn 131 | self.child_conns[idx] = child_conn 132 | p = Process(target=child_fn, args=(idx, self.env_fn, child_conn), daemon=True) 133 | self.processes[idx] = p 134 | self.timeouts[idx] = np.inf 135 | p.start() 136 | 137 | def _receive_idx(self, idx : int, check_type : Optional[MessageType], timeout = 1e-8): 138 | timeout = max(1 / (self.FPS_TARGET * self.num_envs), timeout) 139 | t = time.time() 140 | if self.parent_conns[idx].poll(timeout=timeout): 141 | # do stuff 142 | elapsed = time.time() - t 143 | self.timeouts -= elapsed # Time passes for all processes 144 | message = self.parent_conns[idx].recv() 145 | if check_type is not None: 146 | if message.type == MessageType.EXCEPTION: 147 | print(f"Received exception from process {idx} : {message.content}") 148 | return {'env_idx' : idx, 'env_error' : 1} 149 | else: 150 | assert message.type == check_type, f"Process: {idx}, received type: {message.type}, request type: {check_type}" 151 | self.timeouts[idx] = self.TIMEOUT_LIMIT 152 | content = message.content 153 | return content 154 | self.timeouts -= timeout # Time passes for all processes 155 | if self.timeouts[idx] > 0: 156 | return WAITING 157 | else: 158 | self._restore_process_idx(idx) 159 | return {'env_idx' : idx, 'env_error' : 1} 160 | 161 | def _receive_all(self, check_type: Optional[MessageType] = None, timeout = 1e-8) -> List[Any]: 162 | contents = [self._receive_idx(idx, check_type=check_type, timeout=timeout) for idx in range(self.num_envs)] 163 | return contents 164 | 165 | def _send_reset(self, idx): 166 | self.parent_conns[idx].send(Message(MessageType.RESET)) 167 | self.timeouts[idx] = self.TIMEOUT_LIMIT 168 | self.idx_queue.append(idx) 169 | 170 | def reset_idx(self, idx) -> np.ndarray: 171 | content = ERROR # starting with ERROR to trigger the loop 172 | self.parent_conns[idx].send(Message(MessageType.RESET)) 173 | self.timeouts[idx] = self.TIMEOUT_LIMIT 174 | attempts = 600 / self.TIMEOUT_LIMIT # 300secs = 5mins of attempts 175 | while content in [ERROR, WAITING]: 176 | content = self._receive_idx(idx, check_type=MessageType.STEP_RETURN, timeout=self.TIMEOUT_LIMIT) 177 | if content['env_error']: 178 | content = ERROR 179 | attempts -= 1 180 | print(f"Remaining attempts for process {idx}: {attempts}") 181 | self.parent_conns[idx].send(Message(MessageType.RESET)) 182 | self.timeouts[idx] = self.TIMEOUT_LIMIT 183 | if attempts <= 0: 184 | raise ChildProcessError("Could not reset environment") 185 | return content 186 | 187 | def reset_all(self) -> np.ndarray: 188 | # Sending messages and setting timeouts 189 | for parent_conn in self.parent_conns: 190 | parent_conn.send(Message(MessageType.RESET)) 191 | self.timeouts = np.array([self.TIMEOUT_LIMIT for _ in range(self.num_envs)], dtype=np.float64) 192 | 193 | content = self._receive_all(check_type=MessageType.STEP_RETURN, timeout=self.TIMEOUT_LIMIT) 194 | ret_obs = defaultdict(list) 195 | for c in content: 196 | if c['env_error']: 197 | for k,v in {**self.obs_space, **self.act_space,}.items(): 198 | ret_obs[k].append(np.zeros(v.shape, dtype=v.dtype)) 199 | ret_obs['reward'].append(0.) 200 | ret_obs['discount'].append(0.) 
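A short usage sketch of the batching contract (illustrative only, not part of this file): every key returned by `reset_all`/`step_all` is stacked along axis 0 with one row per worker, failed workers are flagged through `env_error`, and `reset_idx` is the blocking per-worker retry. The factory below reuses the RLBench wrapper defined earlier in this repo and assumes the `rl/` directory is on the Python path and the RLBench/PyRep stack is installed.

    import numpy as np
    from functools import partial
    from env.rlbench_envs import RLBench      # wrapper defined earlier in this repo
    from envs import MultiProcessEnv          # this module

    make_env = partial(RLBench, 'reach_target', observation_mode='state')
    venv = MultiProcessEnv(env_fn=make_env, num_envs=2)
    batch = venv.reset_all()                  # dict: key -> np.ndarray, one row per worker
    actions = np.stack([venv.act_space['action'].sample() for _ in range(venv.num_envs)])
    batch = venv.step_all(actions)            # same stacked layout as reset_all
    for idx in np.where(batch['env_error'] == 1)[0]:
        venv.reset_idx(int(idx))              # blocking retry of a worker that errored out
    venv.close()
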
201 | else: 202 | for k,v in c.items(): 203 | ret_obs[k].append(v) 204 | ret_obs = { k: np.stack(v, axis=0) for k,v in ret_obs.items()} 205 | return ret_obs 206 | 207 | def step_all(self, actions: np.ndarray) -> Dict: 208 | # Sending messages and setting timeouts 209 | for parent_conn, action in zip(self.parent_conns, actions): 210 | parent_conn.send(Message(MessageType.STEP, action)) 211 | self.timeouts = np.array([self.TIMEOUT_LIMIT for _ in range(self.num_envs)], dtype=np.float64) 212 | 213 | content = self._receive_all(check_type=MessageType.STEP_RETURN, timeout=self.TIMEOUT_LIMIT) 214 | ret_obs = defaultdict(list) 215 | for c in content: 216 | if c['env_error']: 217 | for k,v in {**self.obs_space, **self.act_space,}.items(): 218 | ret_obs[k].append(np.zeros(v.shape, dtype=v.dtype)) 219 | ret_obs['reward'].append(0.) 220 | ret_obs['discount'].append(0.) 221 | # For all cases (also in case of error) 222 | for k,v in c.items(): 223 | ret_obs[k].append(v) 224 | ret_obs = { k: np.stack(v, axis=0) for k,v in ret_obs.items()} 225 | return ret_obs 226 | 227 | def step_by_idx(self, actions: np.ndarray, idxs : List, requested_steps, ignore_idxs : List = []) -> Dict: 228 | for idx, action in zip(idxs, actions): 229 | if idx in ignore_idxs: 230 | continue 231 | self.parent_conns[idx].send(Message(MessageType.STEP, action)) 232 | self.idx_queue.append(idx) 233 | self.timeouts[idx] = self.TIMEOUT_LIMIT 234 | ret_obs = defaultdict(list) 235 | while len(ret_obs['env_idx']) < requested_steps: 236 | idx = self.idx_queue.popleft() 237 | c = self._receive_idx(idx, check_type=MessageType.STEP_RETURN,) 238 | if c == WAITING: 239 | self.idx_queue.append(idx) 240 | continue 241 | if c['env_error']: 242 | for k,v in {**self.obs_space, **self.act_space,}.items(): 243 | ret_obs[k].append(np.zeros(v.shape, dtype=v.dtype)) 244 | ret_obs['reward'].append(0.) 245 | ret_obs['discount'].append(0.) 246 | # For all cases (also in case of error) 247 | for k,v in c.items(): 248 | ret_obs[k].append(v) 249 | ret_obs = { k: np.stack(v, axis=0) for k,v in ret_obs.items()} 250 | return ret_obs 251 | 252 | def step_receive_by_idx(self, actions: np.ndarray, send_idxs : List, recv_idxs : List) -> Dict: 253 | # Send 254 | for idx, action in zip(send_idxs, actions): 255 | self.parent_conns[idx].send(Message(MessageType.STEP, action)) 256 | self.timeouts[idx] = self.TIMEOUT_LIMIT 257 | ret_obs = defaultdict(list) 258 | if len(recv_idxs) == 0: 259 | return 260 | # Receive 261 | content = [self._receive_idx(idx, check_type=MessageType.STEP_RETURN, timeout=self.TIMEOUT_LIMIT) for idx in recv_idxs] 262 | for c in content: 263 | if c['env_error']: 264 | for k,v in {**self.obs_space, **self.act_space,}.items(): 265 | ret_obs[k].append(np.zeros(v.shape, dtype=v.dtype)) 266 | ret_obs['reward'].append(0.) 267 | ret_obs['discount'].append(0.) 
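The trainer's "HALF" async mode builds on `step_receive_by_idx` as a ping-pong between the two halves of the worker pool. A rough sketch under stated assumptions: `venv` is a running `MultiProcessEnv`, and `a_first` / `a_second` are already-batched action arrays with one row per worker in the corresponding half (all three are placeholders, not defined here).

    import numpy as np

    def half_async_step(venv, a_first, a_second):
        """Send actions to one half of the workers while collecting results from the other."""
        n = venv.num_envs
        first, second = np.arange(0, n // 2), np.arange(n // 2, n)
        venv.step_receive_by_idx(a_first, send_idxs=first, recv_idxs=[])               # prime half 1, wait on nothing
        batch = venv.step_receive_by_idx(a_second, send_idxs=second, recv_idxs=first)  # send half 2, collect half 1
        return batch                           # on the next step the two halves swap roles
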
268 | # For all cases (also in case of error) 269 | for k,v in c.items(): 270 | ret_obs[k].append(v) 271 | ret_obs = { k: np.stack(v, axis=0) for k,v in ret_obs.items()} 272 | return ret_obs 273 | 274 | @property 275 | def obs_space(self,): 276 | while self._obs_space in [None, WAITING]: 277 | self.parent_conns[0].send(Message(MessageType.OBS_SPACE, None)) 278 | self.timeouts[0] = self.TIMEOUT_LIMIT 279 | content = self._receive_idx(0, check_type=MessageType.OBS_SPACE_RETURN, timeout=self.TIMEOUT_LIMIT) 280 | self._obs_space = content 281 | if 'env_error' in self._obs_space: 282 | raise ChildProcessError("Problem instantiating the environments") 283 | return self._obs_space 284 | 285 | @property 286 | def act_space(self,): 287 | while self._act_space in [None, WAITING]: 288 | self.parent_conns[0].send(Message(MessageType.ACT_SPACE, None)) 289 | self.timeouts[0] = self.TIMEOUT_LIMIT 290 | content = self._receive_idx(0, check_type=MessageType.ACT_SPACE_RETURN, timeout=self.TIMEOUT_LIMIT) 291 | self._act_space = content 292 | if 'env_error' in self._act_space: 293 | raise ChildProcessError("Problem instantiating the environments") 294 | return self._act_space 295 | 296 | def close(self) -> None: 297 | for parent_conn in self.parent_conns: 298 | parent_conn.send(Message(MessageType.CLOSE)) 299 | for parent_conn in self.parent_conns: 300 | parent_conn.close() 301 | for p in self.processes: 302 | if p.is_alive(): 303 | p.join(5) 304 | 305 | def __del__(self): 306 | self.close() 307 | 308 | class TimeLimit: 309 | def __init__(self, env, duration): 310 | self._env = env 311 | self._duration = duration 312 | self._step = None 313 | 314 | def __getattr__(self, name): 315 | if name.startswith('__'): 316 | raise AttributeError(name) 317 | try: 318 | return getattr(self._env, name) 319 | except AttributeError: 320 | raise ValueError(name) 321 | 322 | def step(self, action): 323 | assert self._step is not None, 'Must reset environment.' 
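`TimeLimit` has a simple contract: after `duration` calls to `step` it forces `is_last` to True and clears its internal counter, so stepping again without a reset trips the assertion above. A minimal sketch with a stand-in environment (`_FakeEnv` is hypothetical; only the returned dict keys matter):

    from envs import TimeLimit                # this module, assuming rl/ is on the path

    class _FakeEnv:
        def reset(self):         return {'is_last': False}
        def step(self, action):  return {'is_last': False}

    env = TimeLimit(_FakeEnv(), duration=3)
    obs = env.reset()
    for _ in range(3):
        obs = env.step(action=None)
    assert obs['is_last']                     # forced by the wrapper once `duration` steps have elapsed
    # calling env.step(...) again without env.reset() would now fail the wrapper's assertion
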
324 | obs = self._env.step(action) 325 | self._step += 1 326 | if self._duration and self._step >= self._duration: 327 | obs['is_last'] = True 328 | self._step = None 329 | return obs 330 | 331 | def reset(self, **kwargs): 332 | self._step = 0 333 | return self._env.reset(**kwargs) 334 | 335 | def reset_with_task_id(self, task_id): 336 | self._step = 0 337 | return self._env.reset_with_task_id(task_id) 338 | 339 | def make(name, obs_type, frame_stack, action_repeat, seed, cfg=None, img_size=84, exorl=False, is_eval=False): 340 | assert obs_type in ['states', 'pixels', 'both'] 341 | domain, task = name.split('_', 1) 342 | if domain == 'rlbench': 343 | import env.rlbench_envs as rlbench_envs 344 | return TimeLimit(rlbench_envs.RLBench(task, observation_mode=obs_type, action_repeat=action_repeat, **cfg.env), 200 // action_repeat) 345 | else: 346 | raise NotImplementedError("") -------------------------------------------------------------------------------- /rl/hydra/hydra_logging/custom.yaml: -------------------------------------------------------------------------------- 1 | # A logger config that directs hydra verbose logging to a file 2 | version: 1 3 | formatters: 4 | simple: 5 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 6 | handlers: 7 | file: 8 | class: logging.FileHandler 9 | mode: w 10 | formatter: simple 11 | # relative to the job log directory 12 | filename: exp_hydra_logs/${now:%Y.%m.%d}_${now:%H%M%S}_${experiment}_${agent.name}_${obs_type}.log 13 | delay: true 14 | root: 15 | level: DEBUG 16 | handlers: [file] 17 | 18 | loggers: 19 | hydra: 20 | level: DEBUG 21 | 22 | disable_existing_loggers: false -------------------------------------------------------------------------------- /rl/logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import datetime 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | import wandb 9 | from termcolor import colored 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 13 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 14 | ('episode_reward', 'R', 'float'), 15 | ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')] 16 | 17 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'), 18 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'), 19 | ('episode_reward', 'R', 'float'), 20 | ('total_time', 'T', 'time')] 21 | 22 | 23 | class AverageMeter(object): 24 | def __init__(self): 25 | self._sum = 0 26 | self._count = 0 27 | 28 | def update(self, value, n=1): 29 | self._sum += value 30 | self._count += n 31 | 32 | def value(self): 33 | return self._sum / max(1, self._count) 34 | 35 | 36 | class MetersGroup(object): 37 | def __init__(self, csv_file_name, formating, use_wandb): 38 | self._csv_file_name = csv_file_name 39 | self._formating = formating 40 | self._meters = defaultdict(AverageMeter) 41 | self._csv_file = None 42 | self._csv_writer = None 43 | self.use_wandb = use_wandb 44 | 45 | def log(self, key, value, n=1): 46 | self._meters[key].update(value, n) 47 | 48 | def _prime_meters(self): 49 | data = dict() 50 | for key, meter in self._meters.items(): 51 | if key.startswith('train'): 52 | key = key[len('train') + 1:] 53 | else: 54 | key = key[len('eval') + 1:] 55 | key = key.replace('/', '_') 56 | data[key] = meter.value() 57 | return data 58 | 59 | def _remove_old_entries(self, data): 60 | rows = [] 61 | with 
self._csv_file_name.open('r') as f: 62 | reader = csv.DictReader(f) 63 | for row in reader: 64 | if 'episode' in row: 65 | # BUGFIX: covers weird cases where CSV are badly written 66 | if row['episode'] == '': 67 | rows.append(row) 68 | continue 69 | if type(row['episode']) == type(None): 70 | continue 71 | if float(row['episode']) >= data['episode']: 72 | break 73 | rows.append(row) 74 | with self._csv_file_name.open('w') as f: 75 | # To handle CSV that have more keys than new data 76 | keys = set(data.keys()) 77 | if len(rows) > 0: keys = keys | set(row.keys()) 78 | keys = sorted(list(keys)) 79 | # 80 | writer = csv.DictWriter(f, 81 | fieldnames=keys, 82 | restval=0.0) 83 | writer.writeheader() 84 | for row in rows: 85 | writer.writerow(row) 86 | 87 | def _dump_to_csv(self, data): 88 | if self._csv_writer is None: 89 | should_write_header = True 90 | if self._csv_file_name.exists(): 91 | self._remove_old_entries(data) 92 | should_write_header = False 93 | 94 | self._csv_file = self._csv_file_name.open('a') 95 | self._csv_writer = csv.DictWriter(self._csv_file, 96 | fieldnames=sorted(data.keys()), 97 | restval=0.0) 98 | if should_write_header: 99 | self._csv_writer.writeheader() 100 | 101 | # To handle components that start training later 102 | # (restval covers only when data has less keys than the CSV) 103 | if self._csv_writer.fieldnames != sorted(data.keys()) and \ 104 | len(self._csv_writer.fieldnames) < len(data.keys()): 105 | self._csv_file.close() 106 | self._csv_file = self._csv_file_name.open('r') 107 | dict_reader = csv.DictReader(self._csv_file) 108 | rows = [row for row in dict_reader] 109 | self._csv_file.close() 110 | self._csv_file = self._csv_file_name.open('w') 111 | self._csv_writer = csv.DictWriter(self._csv_file, 112 | fieldnames=sorted(data.keys()), 113 | restval=0.0) 114 | self._csv_writer.writeheader() 115 | for row in rows: 116 | self._csv_writer.writerow(row) 117 | 118 | self._csv_writer.writerow(data) 119 | self._csv_file.flush() 120 | 121 | def _format(self, key, value, ty): 122 | if ty == 'int': 123 | value = int(value) 124 | return f'{key}: {value}' 125 | elif ty == 'float': 126 | return f'{key}: {value:.04f}' 127 | elif ty == 'time': 128 | value = str(datetime.timedelta(seconds=int(value))) 129 | return f'{key}: {value}' 130 | else: 131 | raise f'invalid format type: {ty}' 132 | 133 | def _dump_to_console(self, data, prefix): 134 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 135 | pieces = [f'| {prefix: <14}'] 136 | for key, disp_key, ty in self._formating: 137 | value = data.get(key, 0) 138 | pieces.append(self._format(disp_key, value, ty)) 139 | print(' | '.join(pieces)) 140 | 141 | def _dump_to_wandb(self, data): 142 | wandb.log(data) 143 | 144 | def dump(self, step, prefix): 145 | if len(self._meters) == 0: 146 | return 147 | data = self._prime_meters() 148 | data['frame'] = step 149 | if self.use_wandb: 150 | wandb_data = {prefix + '/' + key: val for key, val in data.items()} 151 | self._dump_to_wandb(data=wandb_data) 152 | self._dump_to_console(data, prefix) 153 | self._meters.clear() 154 | 155 | 156 | class Logger(object): 157 | def __init__(self, log_dir, use_tb, use_wandb): 158 | self._log_dir = log_dir 159 | self._train_mg = MetersGroup(log_dir / 'train.csv', 160 | formating=COMMON_TRAIN_FORMAT, 161 | use_wandb=use_wandb) 162 | self._eval_mg = MetersGroup(log_dir / 'eval.csv', 163 | formating=COMMON_EVAL_FORMAT, 164 | use_wandb=use_wandb) 165 | if use_tb: 166 | self._sw = SummaryWriter(str(log_dir / 'tb')) 167 | else: 168 | 
self._sw = None 169 | self.use_wandb = use_wandb 170 | 171 | def _try_sw_log(self, key, value, step): 172 | if self._sw is not None: 173 | self._sw.add_scalar(key, value, step) 174 | 175 | def log(self, key, value, step): 176 | assert key.startswith('train') or key.startswith('eval') 177 | if type(value) == torch.Tensor: 178 | value = value.item() 179 | self._try_sw_log(key, value, step) 180 | mg = self._train_mg if key.startswith('train') else self._eval_mg 181 | mg.log(key, value) 182 | 183 | def log_metrics(self, metrics, step, ty): 184 | for key, value in metrics.items(): 185 | self.log(f'{ty}/{key}', value, step) 186 | 187 | def dump(self, step, ty=None): 188 | if ty is None or ty == 'eval': 189 | self._eval_mg.dump(step, 'eval') 190 | if ty is None or ty == 'train': 191 | self._train_mg.dump(step, 'train') 192 | 193 | def log_and_dump_ctx(self, step, ty): 194 | return LogAndDumpCtx(self, step, ty) 195 | 196 | def log_video(self, data, step): 197 | if self._sw is not None: 198 | for k, v in data.items(): 199 | self._sw.add_video(k, v, global_step=step, fps=15) 200 | if self.use_wandb: 201 | for k, v in data.items(): 202 | if type(v) == torch.Tensor: 203 | v = v.cpu() 204 | v = np.uint8(v) 205 | wandb.log({k: wandb.Video(v, fps=15, format="gif")}) 206 | 207 | 208 | class LogAndDumpCtx: 209 | def __init__(self, logger, step, ty): 210 | self._logger = logger 211 | self._step = step 212 | self._ty = ty 213 | 214 | def __enter__(self): 215 | return self 216 | 217 | def __call__(self, key, value): 218 | self._logger.log(f'{self._ty}/{key}', value, self._step) 219 | 220 | def __exit__(self, *args): 221 | self._logger.dump(self._step, self._ty) 222 | -------------------------------------------------------------------------------- /rl/np_replay.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import datetime 3 | import io 4 | import pathlib 5 | import uuid 6 | import os 7 | 8 | import numpy as np 9 | from gym.spaces import Dict 10 | import random 11 | from torch.utils.data import IterableDataset, DataLoader 12 | import torch 13 | import utils 14 | import traceback 15 | 16 | SIG_FAILURE = -1 17 | 18 | def get_length(filename): 19 | if "-" in str(filename): 20 | length = int(str(filename).split('-')[-1]) 21 | else: 22 | length = int(str(filename).split('_')[-1]) 23 | return length 24 | 25 | def get_idx(filename): 26 | if "-" in str(filename): 27 | length = int(str(filename).split('-')[0]) 28 | else: 29 | length = int(str(filename).split('_')[0]) 30 | return length 31 | 32 | def on_fn(): return collections.defaultdict(list) # this function is to avoid lambdas 33 | 34 | class ReplayBuffer(IterableDataset): 35 | 36 | def __init__( 37 | self, data_specs, meta_specs, directory, length=20, capacity=0, ongoing=False, minlen=1, maxlen=0, 38 | prioritize_ends=False, device='cuda', load_first=False, delete_old_storage=False, save_episodes=True, **kwargs): 39 | self._directory = pathlib.Path(directory).expanduser() 40 | self._directory.mkdir(parents=True, exist_ok=True) 41 | self._capacity = capacity 42 | self._ongoing = ongoing 43 | self._minlen = minlen 44 | self._maxlen = maxlen 45 | self._prioritize_ends = prioritize_ends 46 | # self._random = np.random.RandomState() 47 | # filename -> key -> value_sequence 48 | 49 | self._save_episodes = save_episodes 50 | self._last_added_idx = 0 51 | 52 | if delete_old_storage: 53 | raise NotImplementedError("May want this feature") 54 | 55 | self._episode_lens = np.array([]) 56 | self._complete_eps = {} 
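A construction sketch for the buffer defined here (names, shapes, and paths are illustrative, not taken from the repo): observation and action specs may be plain dicts (or gym `Dict` spaces) of gym spaces, while scalar streams such as reward and discount use dm_env `Array` specs carrying a name; `make_replay_loader`, defined further down in this file, then wraps the buffer in an infinite chunk sampler.

    import numpy as np
    from pathlib import Path
    from dm_env import specs
    from gym import spaces
    from np_replay import ReplayBuffer, make_replay_loader   # this module, assuming rl/ is on the path

    data_specs = (
        {'state': spaces.Box(-np.inf, np.inf, (10,), dtype=np.float32)},   # observation keys
        {'action': spaces.Box(-1.0, 1.0, (8,), dtype=np.float32)},
        specs.Array((1,), np.float32, 'reward'),
        specs.Array((1,), np.float32, 'discount'),
    )
    buffer = ReplayBuffer(data_specs, meta_specs=(), directory=Path('./buffer'),
                          length=20, capacity=100_000, save_episodes=False)
    loader = make_replay_loader(buffer, batch_size=32)   # yields dicts of (32, 20, ...) arrays once episodes are stored
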
57 | for spec_group in [data_specs, meta_specs]: 58 | for spec in spec_group: 59 | if type(spec) in [dict, Dict]: 60 | for k,v in spec.items(): 61 | self._complete_eps[k] = [] 62 | else: 63 | self._complete_eps[spec.name] = [] 64 | 65 | # load episodes 66 | self._total_episodes, self._total_steps = count_episodes(directory) 67 | self._loaded_episodes = 0 68 | self._loaded_steps = 0 69 | for f in load_filenames(self._directory, capacity, minlen, load_first=load_first): 70 | self.store_episode(filename=f) 71 | 72 | # worker -> key -> value_sequence 73 | self._length = length 74 | self._ongoing_eps = collections.defaultdict(on_fn) 75 | self._data_specs = data_specs 76 | self._meta_specs = meta_specs 77 | self.device = device 78 | try: 79 | assert self._minlen <= self._length <= self._maxlen 80 | except: 81 | print("Sampling sequences with fixed length ", length) 82 | self._minlen = self._maxlen = self._length = length 83 | 84 | def __len__(self): 85 | return self._total_steps 86 | 87 | def preallocate_memory(self, max_size): 88 | self._preallocated_mem = collections.defaultdict(list) 89 | for spec in self._data_specs: 90 | if type(spec) in [dict, Dict]: 91 | for k,v in spec.items(): 92 | for _ in range(max_size): 93 | self._preallocated_mem[k].append(np.empty(list(v.shape), v.dtype)) 94 | self._preallocated_mem[k][-1].fill(0.) 95 | else: 96 | for _ in range(max_size): 97 | self._preallocated_mem[spec.name].append(np.empty(list(v.shape), v.dtype)) 98 | self._preallocated_mem[spec.name][-1].fill(0.) 99 | 100 | @property 101 | def stats(self): 102 | return { 103 | 'total_steps': self._total_steps, 104 | 'total_episodes': self._total_episodes, 105 | 'loaded_steps': self._loaded_steps, 106 | 'loaded_episodes': self._loaded_episodes, 107 | } 108 | 109 | def add(self, time_step, meta, idx=0): 110 | ### Useful if there was any failure in the environment 111 | if time_step == SIG_FAILURE: 112 | episode = self._ongoing_eps[idx] 113 | episode.clear() 114 | print("Discarding episode from process", idx) 115 | return 116 | #### 117 | 118 | episode = self._ongoing_eps[idx] 119 | 120 | def add_to_episode(name, data, spec): 121 | value = data[name] 122 | if np.isscalar(value): 123 | value = np.full(spec.shape, value, spec.dtype) 124 | assert spec.shape == value.shape and spec.dtype == value.dtype, f"Expected {spec.dtype, spec.shape, }), received ({value.dtype, value.shape, })" 125 | ### Deallocate preallocated memory 126 | if getattr(self, '_preallocated_mem', False): 127 | if len(self._preallocated_mem[name]) > 0: 128 | tmp = self._preallocated_mem[name].pop() 129 | del tmp 130 | else: 131 | # Out of pre-allocated memory 132 | del self._preallocated_mem 133 | ### 134 | episode[name].append(value) 135 | 136 | for spec in self._data_specs: 137 | if type(spec) in [dict, Dict]: 138 | for k,v in spec.items(): 139 | add_to_episode(k, time_step, v) 140 | else: 141 | add_to_episode(spec.name, time_step, spec) 142 | for spec in self._meta_specs: 143 | if type(spec) in [dict, Dict]: 144 | for k,v in spec.items(): 145 | add_to_episode(k, meta, v) 146 | else: 147 | add_to_episode(spec.name, meta, spec) 148 | if type(time_step) in [dict, Dict]: 149 | if time_step['is_last']: 150 | self.add_episode(episode) 151 | episode.clear() 152 | else: 153 | if time_step.last(): 154 | self.add_episode(episode) 155 | episode.clear() 156 | 157 | def add_episode(self, episode): 158 | length = eplen(episode) 159 | if length < self._minlen: 160 | print(f'Skipping short episode of length {length}.') 161 | return 162 | self._total_steps += 
length 163 | self._total_episodes += 1 164 | episode = {key: convert(value) for key, value in episode.items()} 165 | if self._save_episodes: 166 | filename = self.save_episode(self._directory, episode) 167 | self.store_episode(episode=episode) 168 | 169 | def store_episode(self, filename=None, episode=None): 170 | if filename is not None: 171 | episode = load_episode(filename) 172 | if not episode: 173 | return False 174 | length = eplen(episode) 175 | 176 | # Enforce limit 177 | while self._loaded_steps + length > self._capacity: 178 | for k in self._complete_eps: 179 | self._complete_eps[k].pop(0) 180 | removed_len, self._episode_lens = self._episode_lens[0], self._episode_lens[1:] 181 | self._loaded_steps -= removed_len 182 | self._loaded_episodes -= 1 183 | 184 | # add episode 185 | for k,v in episode.items(): 186 | self._complete_eps[k].append(v) 187 | self._episode_lens = np.append(self._episode_lens, length) 188 | self._loaded_steps += length 189 | self._loaded_episodes += 1 190 | 191 | return True 192 | 193 | def __iter__(self): 194 | while True: 195 | sequences, batch_size, batch_length = self._loaded_episodes, self.batch_size, self._length 196 | 197 | b_indices = np.random.randint(0, sequences, size=batch_size) 198 | t_indices = np.random.randint(np.zeros(batch_size), self._episode_lens[b_indices]-batch_length+1, size=batch_size) 199 | t_ranges = np.repeat( np.expand_dims(np.arange(0, batch_length,), 0), batch_size, axis=0) + np.expand_dims(t_indices, 1) 200 | 201 | chunk = {} 202 | for k in self._complete_eps: 203 | chunk[k] = np.stack([self._complete_eps[k][b][t] for b,t in zip(b_indices, t_ranges)]) 204 | yield chunk 205 | 206 | @utils.retry 207 | def save_episode(self, directory, episode): 208 | idx = self._total_episodes 209 | timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 210 | identifier = str(uuid.uuid4().hex) 211 | length = eplen(episode) 212 | filename = directory / f'{idx}-{timestamp}-{identifier}-{length}.npz' 213 | with io.BytesIO() as f1: 214 | np.savez_compressed(f1, **episode) 215 | f1.seek(0) 216 | with filename.open('wb') as f2: 217 | f2.write(f1.read()) 218 | return filename 219 | 220 | def load_episode(filename): 221 | try: 222 | with filename.open('rb') as f: 223 | episode = np.load(f) 224 | episode = {k: episode[k] for k in episode.keys()} 225 | except Exception as e: 226 | print(f'Could not load episode {str(filename)}: {e}') 227 | return False 228 | return episode 229 | 230 | def count_episodes(directory): 231 | filenames = list(directory.glob('*.npz')) 232 | num_episodes = len(filenames) 233 | if num_episodes == 0 : return 0, 0 234 | if len(filenames) > 0 and "-" in str(filenames[0]): 235 | num_steps = sum(int(str(n).split('-')[-1][:-4]) - 1 for n in filenames) 236 | last_episode = sorted(list(int(n.stem.split('-')[0]) for n in filenames))[-1] 237 | else: 238 | num_steps = sum(int(str(n).split('_')[-1][:-4]) - 1 for n in filenames) 239 | last_episode = sorted(list(int(n.stem.split('_')[0]) for n in filenames))[-1] 240 | return last_episode, num_steps 241 | 242 | def load_filenames(directory, capacity=None, minlen=1, load_first=False): 243 | # The returned directory from filenames to episodes is guaranteed to be in 244 | # temporally sorted order. 
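For reference, `save_episode` above writes files named `<episode_idx>-<YYYYmmddTHHMMSS>-<uuid>-<length>.npz` (an older underscore-separated layout is also accepted by the parsers). A small, illustrative loading loop over an existing buffer directory:

    from pathlib import Path
    from np_replay import load_filenames, load_episode, eplen   # this module, assuming rl/ is on the path

    for f in load_filenames(Path('./buffer'), capacity=50_000, minlen=1):
        episode = load_episode(f)             # dict of arrays, or False if the file could not be read
        if episode:
            print(f.name, 'steps:', eplen(episode))
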
245 | filenames = sorted(directory.glob('*.npz')) 246 | if capacity: 247 | num_steps = 0 248 | num_episodes = 0 249 | ordered_filenames = filenames if load_first else reversed(filenames) 250 | for filename in ordered_filenames: 251 | if "-" in str(filename): 252 | length = int(str(filename).split('-')[-1][:-4]) 253 | else: 254 | length = int(str(filename).split('_')[-1][:-4]) 255 | num_steps += length 256 | num_episodes += 1 257 | if num_steps >= capacity: 258 | break 259 | if load_first: 260 | filenames = filenames[:num_episodes] 261 | else: 262 | filenames = filenames[-num_episodes:] 263 | return filenames 264 | 265 | def convert(value): 266 | value = np.array(value) 267 | if np.issubdtype(value.dtype, np.floating): 268 | return value.astype(np.float32) 269 | elif np.issubdtype(value.dtype, np.signedinteger): 270 | return value.astype(np.int32) 271 | elif np.issubdtype(value.dtype, np.uint8): 272 | return value.astype(np.uint8) 273 | return value 274 | 275 | 276 | def eplen(episode): 277 | return len(episode['action']) 278 | 279 | def make_replay_loader(buffer, batch_size, num_workers=0): 280 | buffer.batch_size = batch_size 281 | return DataLoader(buffer, 282 | batch_size=None, 283 | # NOTE: do not use any workers, 284 | # as they don't get copies of the replay buffer 285 | # (takes more time to refetch the data than sampling in main) 286 | ) 287 | -------------------------------------------------------------------------------- /rl/train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import traceback 3 | 4 | warnings.warn_explicit = warnings.warn = lambda *_, **__: None 5 | warnings.filterwarnings('ignore', category=DeprecationWarning) 6 | 7 | 8 | import os 9 | import sys 10 | from contextlib import redirect_stderr, redirect_stdout 11 | 12 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1' 13 | 14 | os.environ['MUJOCO_GL'] = 'egl' 15 | os.environ['PYDEVD_UNBLOCK_THREADS_TIMEOUT'] = '900000' 16 | 17 | from pathlib import Path 18 | 19 | import hydra 20 | import numpy as np 21 | import torch 22 | import wandb 23 | from dm_env import specs 24 | 25 | import envs 26 | import utils 27 | from logger import Logger 28 | from np_replay import ReplayBuffer, make_replay_loader, SIG_FAILURE 29 | from collections import defaultdict 30 | 31 | from functools import partial 32 | 33 | torch.backends.cudnn.benchmark = True 34 | 35 | def get_gpu_memory(): 36 | import subprocess as sp 37 | command = "nvidia-smi --query-gpu=memory.free --format=csv" 38 | memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:] 39 | memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)] 40 | return memory_free_values 41 | 42 | def make_agent(obs_space, act_space, cur_config, cfg): 43 | from copy import deepcopy 44 | cur_config = deepcopy(cur_config) 45 | del cur_config.agent 46 | return hydra.utils.instantiate(cfg, cfg=cur_config, obs_space=obs_space, act_space=act_space) 47 | 48 | class Workspace: 49 | def __init__(self, cfg, savedir=None, workdir=None): 50 | self.workdir = Path.cwd() if workdir is None else workdir 51 | print(f'workspace: {self.workdir}') 52 | 53 | device = None 54 | 55 | self.cfg = cfg 56 | utils.set_seed_everywhere(cfg.seed) 57 | 58 | # create logger 59 | self.logger = Logger(self.workdir, 60 | use_tb=cfg.use_tb, 61 | use_wandb=cfg.use_wandb) 62 | # create envs 63 | task = cfg.task 64 | frame_stack = 1 65 | img_size = getattr(getattr(cfg, 'env', None), 'img_size', 84) # 84 is the DrQ default 
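`get_gpu_memory` above shells out to `nvidia-smi` and returns the free memory of every visible GPU in MiB; both the DISPLAY spreading and the `flexible_gpu` device selection below simply take the argmax of that list. Illustrative use (numbers are made up, and `nvidia-smi` must be on the PATH):

    free_mib = get_gpu_memory()                           # e.g. [10240, 2048] on a two-GPU machine
    device = 'cuda:' + str(np.argmax(free_mib).item())    # -> 'cuda:0' in that example
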
66 | 67 | self.parallel_envs = cfg.parallel_envs 68 | 69 | if cfg.spread_envs and os.environ['DISPLAY']: 70 | if torch.cuda.device_count() > 1: 71 | proposed_display = np.argmax(get_gpu_memory()).item() 72 | if proposed_display < torch.cuda.device_count(): 73 | os.environ['DISPLAY'] = os.environ['DISPLAY'] + '.' + str(proposed_display) 74 | 75 | self.train_env_fn = partial(envs.make, task, cfg.obs_type, frame_stack, 76 | cfg.action_repeat, cfg.seed, img_size=img_size, cfg=cfg, is_eval=False) 77 | self.train_env = envs.MultiProcessEnv(self.train_env_fn, self.parallel_envs) 78 | 79 | if cfg.flexible_gpu: 80 | import time 81 | from hydra.core.hydra_config import HydraConfig 82 | 83 | try: 84 | job_num = getattr(HydraConfig.get().job, 'num', None) 85 | if job_num is not None: 86 | print("Job number:", HydraConfig.get().job.num) 87 | time.sleep(HydraConfig.get().job.num) 88 | except: 89 | pass 90 | 91 | while device is None: 92 | try: 93 | cfg.device = device = 'cuda:' + str(np.argmax(get_gpu_memory()).item()) 94 | print("Using device:", device) 95 | except: 96 | pass 97 | 98 | self.device = torch.device(cfg.device) 99 | 100 | # # create agent 101 | self.agent = make_agent(self.train_env.obs_space, 102 | self.train_env.act_space, cfg, cfg.agent) 103 | # get meta specs 104 | meta_specs = self.agent.get_meta_specs() 105 | # create replay buffer 106 | data_specs = (self.train_env.obs_space, 107 | self.train_env.act_space, 108 | specs.Array((1,), np.float32, 'reward'), 109 | specs.Array((1,), np.float32, 'discount')) 110 | self.act_space = self.train_env.act_space 111 | 112 | 113 | # create replay storage 114 | self.replay_storage = ReplayBuffer(data_specs, meta_specs, 115 | self.workdir / 'buffer', 116 | length=cfg.batch_length, **cfg.replay, 117 | device=cfg.device, 118 | fetch_every=cfg.batch_size, 119 | save_episodes=cfg.save_episodes) 120 | 121 | if cfg.preallocate_memory: 122 | self.replay_storage.preallocate_memory(cfg.num_train_frames // cfg.action_repeat) 123 | 124 | # create replay buffer 125 | self.replay_loader = make_replay_loader(self.replay_storage, 126 | cfg.batch_size, # 127 | cfg.num_workers 128 | ) 129 | self._replay_iter = None 130 | 131 | self.timer = utils.Timer() 132 | self._global_step = 0 133 | self._global_episode = 0 134 | self._reward_stats = {'max' : -1e10, 'min' : 1e10, 'ep_avg_max' : -1e10 } 135 | 136 | @property 137 | def global_step(self): 138 | return self._global_step 139 | 140 | @property 141 | def global_episode(self): 142 | return self._global_episode 143 | 144 | @property 145 | def global_frame(self): 146 | return self.global_step * self.cfg.action_repeat 147 | 148 | @property 149 | def replay_iter(self): 150 | if self._replay_iter is None: 151 | self._replay_iter = iter(self.replay_loader) 152 | return self._replay_iter 153 | 154 | def train(self): 155 | # predicates 156 | train_until_step = utils.Until(self.cfg.num_train_frames, 157 | self.cfg.action_repeat) 158 | seed_until_step = utils.Until(self.cfg.num_seed_frames, 159 | self.cfg.action_repeat) 160 | 161 | # To preserve speed never request more than 3/4 of the num of envs 162 | train_every_n_steps = max(min(self.cfg.train_every_actions, (self.parallel_envs * 3) // 4 ), 1) 163 | print(f"Training every {train_every_n_steps} steps from the environment (num envs: {self.cfg.parallel_envs})") 164 | 165 | next_log_point = (self.global_frame // self.cfg.log_every_frames) + 1 166 | 167 | episode_step, episode_reward, episode_success, episode_invalid, episode_max_reward = np.zeros(self.parallel_envs), 
np.zeros(self.parallel_envs), np.zeros(self.parallel_envs), np.zeros(self.parallel_envs), np.zeros(self.parallel_envs) 168 | average_steps = [] 169 | last_episodes = [] 170 | complete_idxs = [] 171 | time_step = self.train_env.reset_all() 172 | agent_state = None 173 | meta = self.agent.init_meta() 174 | for n, idx in enumerate(time_step['env_idx']): 175 | env_obs = { k: time_step[k][n] for k in time_step} 176 | if env_obs['env_error']: 177 | env_obs = self.train_env.reset_idx(idx) 178 | del env_obs['env_error'] 179 | del env_obs['env_idx'] 180 | self.replay_storage.add(env_obs, meta, idx=idx) 181 | 182 | metrics = None 183 | elapsed_time, total_time = self.timer.reset() 184 | 185 | while train_until_step(self.global_step): 186 | for n in np.where(time_step['is_last'])[0]: 187 | self._global_episode += 1 188 | idx = time_step['env_idx'][n] 189 | complete_idxs.append(idx) 190 | 191 | if not time_step['env_error'][n]: 192 | last_episodes.append([episode_step[idx], episode_reward[idx], episode_success[idx], episode_invalid[idx], episode_max_reward[idx]]) 193 | self._reward_stats['ep_avg_max'] = max(self._reward_stats['ep_avg_max'], episode_reward[idx] / episode_step[idx]) 194 | episode_step[idx], episode_reward[idx], episode_success[idx], episode_invalid[idx], episode_max_reward[idx] = [0,0,0,0,0] 195 | 196 | if self.cfg.async_mode == 'FULL': 197 | self.train_env._send_reset(idx) 198 | else: 199 | reset_obs = self.train_env.reset_idx(idx) 200 | for k,v in reset_obs.items(): 201 | time_step[k][n] = v 202 | del reset_obs['env_error'] 203 | assert idx == reset_obs.pop('env_idx') 204 | self.replay_storage.add(reset_obs, meta, idx=idx) 205 | 206 | # wait until all the metrics schema is populated 207 | if (self.global_step >= next_log_point * self.cfg.log_every_frames): 208 | next_log_point += 1 209 | # Episodes logging 210 | if len(last_episodes) > 0: 211 | last_episodes = np.stack(last_episodes, axis=0) 212 | last_step, last_reward, last_success, last_invalid, last_max_reward = np.mean(last_episodes, axis=0) 213 | 214 | # log stats 215 | elapsed_time, total_time = self.timer.reset() 216 | last_frame = last_step * self.cfg.action_repeat 217 | with self.logger.log_and_dump_ctx(self.global_frame, 218 | ty='train') as log: 219 | log('fps', self.cfg.log_every_frames / elapsed_time) 220 | log('total_time', total_time) 221 | log('buffer_size', len(self.replay_storage)) 222 | log('episode_reward', last_reward ) 223 | log('episode_avg_valid_reward', last_reward / (last_step - last_invalid) ) 224 | log('episode_max_reward', last_max_reward ) 225 | log('episode_length', last_frame ) 226 | log('episode', self.global_episode) 227 | log('step', self.global_step) 228 | log('average_steps', sum(average_steps) / len(average_steps)) 229 | if 'invalid_action' in time_step: 230 | log('episode_invalid', last_invalid ) 231 | if 'success' in time_step: 232 | # episode_success = np.stack(episode_success) 233 | # ep_success = (episode_success[-10:].mean(axis=0) > 0.5).mean() 234 | # log('success', ep_success) 235 | # anytime_success = (episode_success.sum(axis=0) > 0.).mean() 236 | log('anytime_success', last_success) 237 | if getattr(self.agent, '_stats', False): 238 | for k,v in self.agent._stats.items(): 239 | log(k, v) 240 | 241 | last_episodes = [] 242 | average_steps = [] 243 | 244 | # Agent logging 245 | if metrics is not None: 246 | # add rew metrics 247 | rew_metrics = {f'reward_stats/{k}' : v for k,v in self._reward_stats.items()} 248 | metrics.update(rew_metrics) 249 | self.logger.log_metrics(metrics, 
self.global_frame, ty='train') 250 | 251 | meta = self.agent.update_meta(meta, self.global_step, time_step) 252 | 253 | # sample action 254 | with torch.no_grad(), utils.eval_mode(self.agent): 255 | action, agent_state = self.agent.act(time_step, # time_step.observation 256 | meta, 257 | self.global_step, 258 | eval_mode=False, 259 | state=agent_state) 260 | 261 | # try to update the agent 262 | if not seed_until_step(self.global_step) and not (self.replay_storage.stats['total_episodes'] < self.cfg.num_seed_episodes): 263 | metrics = self.agent.update(next(self.replay_iter), self.global_step)[1] 264 | 265 | # take env step 266 | if self.cfg.async_mode == 'FULL': 267 | time_step = self.train_env.step_by_idx(action, idxs=time_step['env_idx'], requested_steps=train_every_n_steps, ignore_idxs=complete_idxs) 268 | complete_idxs = [] 269 | elif self.cfg.async_mode == 'HALF': 270 | if time_step['env_idx'].shape[0] > self.cfg.parallel_envs//2: 271 | assert (time_step['env_idx'] == np.arange(0,self.cfg.parallel_envs)).all() 272 | self.train_env.step_receive_by_idx(action[:self.cfg.parallel_envs//2], send_idxs=np.arange(0,self.cfg.parallel_envs//2), recv_idxs=[]) # receive next 273 | send_idxs, recv_idxs = np.arange(self.cfg.parallel_envs//2, self.cfg.parallel_envs), np.arange(0,self.cfg.parallel_envs//2) 274 | time_step = self.train_env.step_receive_by_idx(action[self.cfg.parallel_envs//2:], send_idxs=send_idxs, recv_idxs=recv_idxs) 275 | else: 276 | send_idxs, recv_idxs = recv_idxs, send_idxs 277 | assert (time_step['env_idx'] == send_idxs).all() 278 | time_step = self.train_env.step_receive_by_idx(action, send_idxs=send_idxs, recv_idxs=recv_idxs) 279 | elif self.cfg.async_mode == 'OFF': 280 | time_step = self.train_env.step_all(action) 281 | else: 282 | raise NotImplementedError(f"Odd async modality : {self.cfg.async_mode}") 283 | 284 | # process env data 285 | for n, idx in enumerate(time_step['env_idx']): 286 | env_obs = { k: time_step[k][n] for k in time_step} 287 | if env_obs['env_error']: 288 | env_obs = SIG_FAILURE 289 | # Forcing reset 290 | time_step['is_last'][n] = 1.0 291 | # Fixing global stats (steps were invalid) 292 | self._global_step -= episode_step[idx] 293 | else: 294 | # Remove extra keys 295 | del env_obs['env_error'] 296 | del env_obs['env_idx'] 297 | 298 | # update episode stats 299 | episode_reward[idx] += env_obs['reward'] 300 | episode_max_reward[idx] = max(env_obs['reward'], episode_max_reward[idx]) 301 | if 'invalid_action' in env_obs: 302 | episode_invalid[idx] += env_obs['invalid_action'] 303 | if 'success' in env_obs: 304 | episode_success[idx] += env_obs['success'] 305 | episode_success[idx] = np.clip(episode_success[idx], 0, 1) 306 | episode_step[idx] += 1 307 | 308 | if not seed_until_step(self.global_step) and self.cfg.log_best_episodes : 309 | if env_obs['is_last'] and episode_reward[idx] / episode_step[idx] > self._reward_stats['ep_avg_max'] and len(self.replay_storage._ongoing_eps[idx]['action']) > 0: 310 | self._reward_stats['ep_avg_max'] = episode_reward[idx] / episode_step[idx] 311 | # Log video of best episode 312 | videos = {} 313 | if 'front_rgb' in env_obs: 314 | videos['ep_avg_max/rgb'] = np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['front_rgb'], axis=0), axis=0) 315 | if 'wrist_rgb' in env_obs: 316 | if 'ep_avg_max/rgb' in videos: 317 | videos['ep_avg_max/rgb'] = np.concatenate([videos['ep_avg_max/rgb'], np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['wrist_rgb'], axis=0), axis=0)], axis=0) 318 | else: 319 | 
videos['ep_avg_max/rgb'] = np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['wrist_rgb'], axis=0), axis=0) 320 | self.logger.log_video(videos, self.global_frame) 321 | if env_obs['reward'] > self._reward_stats['max'] and len(self.replay_storage._ongoing_eps[idx]['action']) > 0: 322 | self._reward_stats['max'] = env_obs['reward'] 323 | # Log video of best reward 324 | videos = {} 325 | if 'front_rgb' in env_obs: 326 | videos['rew_max/rgb'] = np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['front_rgb'], axis=0), axis=0) 327 | if 'wrist_rgb' in env_obs: 328 | if 'rew_max/rgb' in videos: 329 | videos['rew_max/rgb'] = np.concatenate([videos['rew_max/rgb'], np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['wrist_rgb'], axis=0), axis=0)], axis=0) 330 | else: 331 | videos['rew_max/rgb'] = np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['wrist_rgb'], axis=0), axis=0) 332 | self.logger.log_video(videos, self.global_frame) 333 | 334 | self.replay_storage.add(env_obs, meta, idx=idx) 335 | 336 | # update global stats 337 | self._reward_stats['max'] = max(self._reward_stats['max'], max(time_step['reward'])) 338 | self._reward_stats['min'] = min(self._reward_stats['min'], min(time_step['reward'])) 339 | self._global_step += time_step['env_idx'].shape[0] 340 | average_steps.append(time_step['env_idx'].shape[0]) 341 | 342 | # save last model 343 | if self.global_frame % 5000 == 0 and self.cfg.save_episodes: 344 | self.save_last_model() 345 | sys.stdout.flush(), sys.stderr.flush() 346 | 347 | @utils.retry 348 | def save_snapshot(self): 349 | snapshot = self.get_snapshot_dir() / f'snapshot_{self.global_frame}.pt' 350 | keys_to_save = ['agent', '_global_step', '_global_episode'] 351 | payload = {k: self.__dict__[k] for k in keys_to_save} 352 | with snapshot.open('wb') as f: 353 | torch.save(payload, f) 354 | 355 | def setup_wandb(self): 356 | cfg = self.cfg 357 | exp_name = '_'.join([ 358 | getattr(getattr(cfg,'env', {}), 'action_mode', '_'), cfg.experiment, cfg.agent.name, cfg.task, cfg.obs_type, str(cfg.seed) 359 | ]) 360 | wandb.init(project=cfg.project_name, group=cfg.agent.name, name=exp_name, notes=f'workspace: {self.workdir}') 361 | wandb.config.update(cfg) 362 | 363 | # define our custom x axis metric 364 | wandb.define_metric("train/frame") 365 | # set all other train/ metrics to use this step 366 | wandb.define_metric("train/*", step_metric="train/frame") 367 | 368 | self.wandb_run_id = wandb.run.id 369 | 370 | @utils.retry 371 | def save_last_model(self): 372 | snapshot = self.root_dir / 'last_snapshot.pt' 373 | if snapshot.is_file(): 374 | temp = Path(str(snapshot).replace("last_snapshot.pt", "second_last_snapshot.pt")) 375 | os.replace(snapshot, temp) 376 | keys_to_save = ['agent', '_global_step', '_global_episode', '_reward_stats'] 377 | if self.cfg.use_wandb: 378 | keys_to_save.append('wandb_run_id') 379 | payload = {k: self.__dict__[k] for k in keys_to_save} 380 | with snapshot.open('wb') as f: 381 | torch.save(payload, f) 382 | 383 | def load_snapshot(self): 384 | try: 385 | snapshot = self.root_dir / 'last_snapshot.pt' 386 | with snapshot.open('rb') as f: 387 | payload = torch.load(f) 388 | except: 389 | snapshot = self.root_dir / 'second_last_snapshot.pt' 390 | with snapshot.open('rb') as f: 391 | payload = torch.load(f) 392 | for k,v in payload.items(): 393 | setattr(self, k, v) 394 | if k == 'wandb_run_id': 395 | assert wandb.run is None 396 | cfg = self.cfg 397 | exp_name = '_'.join([ 398 | getattr(getattr(cfg,'env', {}), 
'action_mode', '_'), cfg.experiment, cfg.agent.name, cfg.task, cfg.obs_type, str(cfg.seed) 399 | ]) 400 | wandb.init(project=cfg.project_name, group=cfg.agent.name, name=exp_name, id=v, resume="must", notes=f'workspace: {self.workdir}') 401 | # define our custom x axis metric 402 | wandb.define_metric("train/frame") 403 | # set all other train/ metrics to use this step 404 | wandb.define_metric("train/*", step_metric="train/frame") 405 | 406 | 407 | def get_snapshot_dir(self): 408 | snap_dir = self.cfg.snapshot_dir 409 | snapshot_dir = self.workdir / Path(snap_dir) 410 | snapshot_dir.mkdir(exist_ok=True, parents=True) 411 | return snapshot_dir 412 | 413 | @hydra.main(config_path='.', config_name='train') 414 | def main(cfg): 415 | try: 416 | root_dir = Path.cwd() 417 | with open(str(root_dir / "out.log"), 'a') as stdout, redirect_stdout(stdout), open(str(root_dir / "err.log"), 'a') as stderr, redirect_stderr(stderr): 418 | workspace = Workspace(cfg) 419 | workspace.root_dir = root_dir 420 | # for resuming, config env.run.dir to the snapshot path 421 | snapshot = workspace.root_dir / 'last_snapshot.pt' 422 | if snapshot.exists(): 423 | print(f'resuming: {snapshot}') 424 | workspace.load_snapshot() 425 | if cfg.use_wandb and wandb.run is None: 426 | # otherwise it was resumed 427 | workspace.setup_wandb() 428 | workspace.train() 429 | except Exception as e: 430 | print(traceback.format_exc()) 431 | finally: 432 | if hasattr(workspace, 'train_env'): 433 | del workspace.train_env 434 | 435 | if __name__ == '__main__': 436 | main() 437 | -------------------------------------------------------------------------------- /rl/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - configs/default 4 | - agent: sac 5 | - configs: ${configs} 6 | # - override hydra/launcher: submitit_local 7 | - override hydra/launcher: joblib 8 | - override hydra/hydra_logging: custom 9 | - override hydra/job_logging: stdout 10 | 11 | # task settings 12 | task: none 13 | domain: walker # primal task will be infered in runtime 14 | # train settings 15 | num_train_frames: 500010 16 | num_seed_frames: 4000 17 | num_seed_episodes: ${num_workers} 18 | # eval 19 | eval_every_frames: 1000000000 # not necessary during pretrain 20 | num_eval_episodes: 10 21 | # snapshot 22 | snapshots: [100000, 500000, 1000000, 2000000] 23 | snapshot_dir: ../../../pretrained_models/${obs_type}/${task}/${agent.name}/${seed} 24 | 25 | # replay buffer 26 | replay_buffer_size: 1000000 27 | num_workers: 4 28 | save_episodes: False 29 | preallocate_memory: False 30 | 31 | # misc 32 | seed: 1 33 | device: cuda 34 | use_tb: true 35 | use_wandb: true 36 | 37 | # experiment 38 | experiment: default 39 | project_name: ??? 
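# [editorial note, not part of train.yaml] A typical launch that overrides these defaults
# might look like the following; the config group and task name are placeholders, not
# guaranteed values:
#   python3 rl/train.py configs=<configs_file> task=<task_name> seed=1 use_wandb=false
# If hydra.run.dir points at an existing experiment directory, train.py finds last_snapshot.pt
# there and resumes that run (see main() in rl/train.py above).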
40 | flexible_gpu: true 41 | spread_envs: true 42 | 43 | # log settings 44 | log_every_frames: 2500 45 | log_best_episodes: True 46 | 47 | hydra: 48 | run: 49 | dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${experiment}_${agent.name}_${obs_type}_${task}_${env.action_mode}_${seed} 50 | sweep: 51 | dir: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M%S}_${experiment}_${agent.name}_${obs_type} 52 | subdir: ${task}_${env.action_mode}_${seed}_${hydra.job.num} 53 | 54 | # launcher: 55 | # timeout_min: 4300 56 | # cpus_per_task: 2 57 | # gpus_per_node: 4 58 | # tasks_per_node: 1 59 | # mem_gb: 160 60 | # nodes: 1 61 | # submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M%S}_${agent.name}_${experiment}_${seed}/.slurm 62 | -------------------------------------------------------------------------------- /rl/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import re 4 | import time 5 | from functools import wraps 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from omegaconf import OmegaConf 12 | from torch import distributions as pyd 13 | from torch.distributions.utils import _standard_normal 14 | import numbers 15 | 16 | class eval_mode: 17 | def __init__(self, *models): 18 | self.models = models 19 | 20 | def __enter__(self): 21 | self.prev_states = [] 22 | for model in self.models: 23 | self.prev_states.append(model.training) 24 | model.train(False) 25 | 26 | def __exit__(self, *args): 27 | for model, state in zip(self.models, self.prev_states): 28 | model.train(state) 29 | return False 30 | 31 | 32 | def set_seed_everywhere(seed): 33 | torch.manual_seed(seed) 34 | if torch.cuda.is_available(): 35 | torch.cuda.manual_seed_all(seed) 36 | np.random.seed(seed) 37 | random.seed(seed) 38 | 39 | 40 | def chain(*iterables): 41 | for it in iterables: 42 | yield from it 43 | 44 | 45 | def soft_update_params(net, target_net, tau): 46 | for param, target_param in zip(net.parameters(), target_net.parameters()): 47 | target_param.data.copy_(tau * param.data + 48 | (1 - tau) * target_param.data) 49 | 50 | 51 | def hard_update_params(net, target_net): 52 | for param, target_param in zip(net.parameters(), target_net.parameters()): 53 | target_param.data.copy_(param.data) 54 | 55 | 56 | def to_torch(xs, device): 57 | return tuple(torch.as_tensor(x, device=device) for x in xs) 58 | 59 | 60 | def weight_init(m): 61 | """Custom weight init for Conv2D and Linear layers.""" 62 | if isinstance(m, nn.Linear): 63 | nn.init.orthogonal_(m.weight.data) 64 | if hasattr(m.bias, 'data'): 65 | m.bias.data.fill_(0.0) 66 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 67 | gain = nn.init.calculate_gain('relu') 68 | nn.init.orthogonal_(m.weight.data, gain) 69 | if hasattr(m.bias, 'data'): 70 | m.bias.data.fill_(0.0) 71 | 72 | 73 | def grad_norm(params, norm_type=2.0): 74 | params = [p for p in params if p.grad is not None] 75 | total_norm = torch.norm( 76 | torch.stack([torch.norm(p.grad.detach(), norm_type) for p in params]), 77 | norm_type) 78 | return total_norm.item() 79 | 80 | 81 | def param_norm(params, norm_type=2.0): 82 | total_norm = torch.norm( 83 | torch.stack([torch.norm(p.detach(), norm_type) for p in params]), 84 | norm_type) 85 | return total_norm.item() 86 | 87 | 88 | class Until: 89 | def __init__(self, until, action_repeat=1): 90 | self._until = until 91 | self._action_repeat = action_repeat 92 | 93 | def __call__(self, step): 94 | if self._until is 
None: 95 | return True 96 | until = self._until // self._action_repeat 97 | return step < until 98 | 99 | 100 | class Every: 101 | def __init__(self, every, action_repeat=1, label='train'): 102 | self._every = every 103 | self._action_repeat = action_repeat 104 | if self._every // self._action_repeat == 0: 105 | print(f"WARNING: asking to {label} every 0 steps. Defaulting to 1.") 106 | 107 | def __call__(self, step): 108 | if self._every is None: 109 | return False 110 | every = max(1, self._every // self._action_repeat) 111 | if step % every == 0: 112 | return True 113 | return False 114 | 115 | 116 | class Timer: 117 | def __init__(self): 118 | self._start_time = time.time() 119 | self._last_time = time.time() 120 | 121 | def reset(self): 122 | elapsed_time = time.time() - self._last_time 123 | self._last_time = time.time() 124 | total_time = time.time() - self._start_time 125 | return elapsed_time, total_time 126 | 127 | def total_time(self): 128 | return time.time() - self._start_time 129 | 130 | 131 | class TruncatedNormal(pyd.Normal): 132 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 133 | super().__init__(loc, scale, validate_args=False) 134 | self.low = low 135 | self.high = high 136 | self.eps = eps 137 | 138 | def _clamp(self, x): 139 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 140 | x = x - x.detach() + clamped_x.detach() 141 | return x 142 | 143 | def sample(self, sample_shape=torch.Size(), clip=None): 144 | shape = self._extended_shape(sample_shape) 145 | eps = _standard_normal(shape, 146 | dtype=self.loc.dtype, 147 | device=self.loc.device) 148 | eps *= self.scale 149 | if clip is not None: 150 | eps = torch.clamp(eps, -clip, clip) 151 | x = self.loc + eps 152 | return self._clamp(x) 153 | 154 | 155 | class TanhTransform(pyd.transforms.Transform): 156 | domain = pyd.constraints.real 157 | codomain = pyd.constraints.interval(-1.0, 1.0) 158 | bijective = True 159 | sign = +1 160 | 161 | def __init__(self, cache_size=1): 162 | super().__init__(cache_size=cache_size) 163 | 164 | @staticmethod 165 | def atanh(x): 166 | return 0.5 * (x.log1p() - (-x).log1p()) 167 | 168 | def __eq__(self, other): 169 | return isinstance(other, TanhTransform) 170 | 171 | def _call(self, x): 172 | return x.tanh() 173 | 174 | def _inverse(self, y): 175 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. 176 | # one should use `cache_size=1` instead 177 | return self.atanh(y) 178 | 179 | def log_abs_det_jacobian(self, x, y): 180 | # We use a formula that is more numerically stable, see details in the following link 181 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 182 | return 2. * (math.log(2.) - x - F.softplus(-2. 
* x)) 183 | 184 | 185 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution): 186 | def __init__(self, loc, scale): 187 | self.loc = loc 188 | self.scale = scale 189 | 190 | self.base_dist = pyd.Normal(loc, scale) 191 | transforms = [TanhTransform()] 192 | super().__init__(self.base_dist, transforms) 193 | 194 | @property 195 | def mean(self): 196 | mu = self.loc 197 | for tr in self.transforms: 198 | mu = tr(mu) 199 | return mu 200 | 201 | 202 | def schedule(schdl, step): 203 | try: 204 | return float(schdl) 205 | except ValueError: 206 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl) 207 | if match: 208 | init, final, duration = [float(g) for g in match.groups()] 209 | mix = np.clip(step / duration, 0.0, 1.0) 210 | return (1.0 - mix) * init + mix * final 211 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl) 212 | if match: 213 | init, final1, duration1, final2, duration2 = [ 214 | float(g) for g in match.groups() 215 | ] 216 | if step <= duration1: 217 | mix = np.clip(step / duration1, 0.0, 1.0) 218 | return (1.0 - mix) * init + mix * final1 219 | else: 220 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0) 221 | return (1.0 - mix) * final1 + mix * final2 222 | raise NotImplementedError(schdl) 223 | 224 | def retry(func): 225 | """ 226 | A Decorator to retry a function for a certain amount of attempts 227 | """ 228 | 229 | @wraps(func) 230 | def wrapper(*args, **kwargs): 231 | attempts = 0 232 | max_attempts = 1000 233 | while attempts < max_attempts: 234 | try: 235 | return func(*args, **kwargs) 236 | except (OSError, PermissionError): 237 | attempts += 1 238 | time.sleep(0.1) 239 | raise OSError("Retry failed") 240 | 241 | return wrapper 242 | 243 | class RandomShiftsAug(nn.Module): 244 | def __init__(self, pad): 245 | super().__init__() 246 | self.pad = pad 247 | 248 | def forward(self, x): 249 | x = x.float() 250 | n, c, h, w = x.size() 251 | assert h == w 252 | padding = tuple([self.pad] * 4) 253 | x = F.pad(x, padding, 'replicate') 254 | eps = 1.0 / (h + 2 * self.pad) 255 | arange = torch.linspace(-1.0 + eps, 256 | 1.0 - eps, 257 | h + 2 * self.pad, 258 | device=x.device, 259 | dtype=x.dtype)[:h] 260 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 261 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 262 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 263 | 264 | shift = torch.randint(0, 265 | 2 * self.pad + 1, 266 | size=(n, 1, 1, 2), 267 | device=x.device, 268 | dtype=x.dtype) 269 | shift *= 2.0 / (h + 2 * self.pad) 270 | 271 | grid = base_grid + shift 272 | return F.grid_sample(x, 273 | grid, 274 | padding_mode='zeros', 275 | align_corners=False) 276 | 277 | 278 | class RMS(object): 279 | """running mean and std """ 280 | def __init__(self, device, epsilon=1e-4, shape=(1,)): 281 | self.M = torch.zeros(shape).to(device) 282 | self.S = torch.ones(shape).to(device) 283 | self.n = epsilon 284 | 285 | def __call__(self, x): 286 | bs = x.size(0) 287 | delta = torch.mean(x, dim=0) - self.M 288 | new_M = self.M + delta * bs / (self.n + bs) 289 | new_S = (self.S * self.n + torch.var(x, dim=0) * bs + 290 | torch.square(delta) * self.n * bs / 291 | (self.n + bs)) / (self.n + bs) 292 | 293 | self.M = new_M 294 | self.S = new_S 295 | self.n += bs 296 | 297 | return self.M, self.S 298 | 299 | 300 | class PBE(object): 301 | """particle-based entropy based on knn normalized by running mean """ 302 | def __init__(self, rms, knn_clip, knn_k, knn_avg, knn_rms, device): 303 | self.rms = rms 304 | self.knn_rms = 
knn_rms 305 | self.knn_k = knn_k 306 | self.knn_avg = knn_avg 307 | self.knn_clip = knn_clip 308 | self.device = device 309 | 310 | def __call__(self, rep, cdist=False, apply_log=True): 311 | source = target = rep 312 | b1, b2 = source.size(0), target.size(0) 313 | # (b1, 1, c) - (1, b2, c) -> (b1, 1, c) - (1, b2, c) -> (b1, b2, c) -> (b1, b2) 314 | if cdist: 315 | sim_matrix = torch.cdist(source, target.detach(), p=2) 316 | else: 317 | sim_matrix = torch.norm(source[:, None, :].view(b1, 1, -1) - 318 | target[None, :, :].view(1, b2, -1), 319 | dim=-1, 320 | p=2) 321 | reward, _ = sim_matrix.topk(self.knn_k, 322 | dim=1, 323 | largest=False, 324 | sorted=True) # (b1, k) 325 | if not self.knn_avg: # only keep k-th nearest neighbor 326 | reward = reward[:, -1] 327 | reward = reward.reshape(-1, 1) # (b1, 1) 328 | reward /= self.rms(reward)[0] if self.knn_rms else 1.0 329 | reward = torch.maximum( 330 | reward - self.knn_clip, 331 | torch.zeros_like(reward).to(self.device) 332 | ) if self.knn_clip >= 0.0 else reward # (b1, 1) 333 | else: # average over all k nearest neighbors 334 | reward = reward.reshape(-1, 1) # (b1 * k, 1) 335 | reward /= self.rms(reward)[0] if self.knn_rms else 1.0 336 | reward = torch.maximum( 337 | reward - self.knn_clip, 338 | torch.zeros_like(reward).to( 339 | self.device)) if self.knn_clip >= 0.0 else reward 340 | reward = reward.reshape((b1, self.knn_k)) # (b1, k) 341 | reward = reward.mean(dim=1, keepdim=True) # (b1, 1) 342 | if apply_log: 343 | reward = torch.log(reward + 1.0) 344 | return reward 345 | 346 | def rgb2hsv(rgb, eps=1e-8): 347 | # Reference: https://www.rapidtables.com/convert/color/rgb-to-hsv.html 348 | # Reference: https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287 349 | 350 | _device = rgb.device 351 | r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :] 352 | 353 | Cmax = rgb.max(1)[0] 354 | Cmin = rgb.min(1)[0] 355 | delta = Cmax - Cmin 356 | 357 | hue = torch.zeros((rgb.shape[0], rgb.shape[2], rgb.shape[3])).to(_device) 358 | hue[Cmax== r] = (((g - b)/(delta + eps)) % 6)[Cmax == r] 359 | hue[Cmax == g] = ((b - r)/(delta + eps) + 2)[Cmax == g] 360 | hue[Cmax == b] = ((r - g)/(delta + eps) + 4)[Cmax == b] 361 | hue[Cmax == 0] = 0.0 362 | hue = hue / 6. # making hue range as [0, 1.0) 363 | hue = hue.unsqueeze(dim=1) 364 | 365 | saturation = (delta) / (Cmax + eps) 366 | saturation[Cmax == 0.] = 0. 367 | saturation = saturation.to(_device) 368 | saturation = saturation.unsqueeze(dim=1) 369 | 370 | value = Cmax 371 | value = value.to(_device) 372 | value = value.unsqueeze(dim=1) 373 | 374 | return torch.cat((hue, saturation, value), dim=1)#.type(torch.FloatTensor).to(_device) 375 | # return hue, saturation, value 376 | 377 | def hsv2rgb(hsv): 378 | # Reference: https://www.rapidtables.com/convert/color/hsv-to-rgb.html 379 | # Reference: https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287 380 | 381 | _device = hsv.device 382 | 383 | hsv = torch.clamp(hsv, 0, 1) 384 | hue = hsv[:, 0, :, :] * 360. 385 | saturation = hsv[:, 1, :, :] 386 | value = hsv[:, 2, :, :] 387 | 388 | c = value * saturation 389 | x = - c * (torch.abs((hue / 60.) 
% 2 - 1) - 1) 390 | m = (value - c).unsqueeze(dim=1) 391 | 392 | rgb_prime = torch.zeros_like(hsv).to(_device) 393 | 394 | inds = (hue < 60) * (hue >= 0) 395 | rgb_prime[:, 0, :, :][inds] = c[inds] 396 | rgb_prime[:, 1, :, :][inds] = x[inds] 397 | 398 | inds = (hue < 120) * (hue >= 60) 399 | rgb_prime[:, 0, :, :][inds] = x[inds] 400 | rgb_prime[:, 1, :, :][inds] = c[inds] 401 | 402 | inds = (hue < 180) * (hue >= 120) 403 | rgb_prime[:, 1, :, :][inds] = c[inds] 404 | rgb_prime[:, 2, :, :][inds] = x[inds] 405 | 406 | inds = (hue < 240) * (hue >= 180) 407 | rgb_prime[:, 1, :, :][inds] = x[inds] 408 | rgb_prime[:, 2, :, :][inds] = c[inds] 409 | 410 | inds = (hue < 300) * (hue >= 240) 411 | rgb_prime[:, 2, :, :][inds] = c[inds] 412 | rgb_prime[:, 0, :, :][inds] = x[inds] 413 | 414 | inds = (hue < 360) * (hue >= 300) 415 | rgb_prime[:, 2, :, :][inds] = x[inds] 416 | rgb_prime[:, 0, :, :][inds] = c[inds] 417 | 418 | rgb = rgb_prime + torch.cat((m, m, m), dim=1) 419 | rgb = rgb.to(_device) 420 | 421 | return torch.clamp(rgb, 0, 1) 422 | 423 | class ColorJitterLayer(nn.Module): 424 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0, batch_size=128, stack_size=3): 425 | super(ColorJitterLayer, self).__init__() 426 | self.brightness = self._check_input(brightness, 'brightness') 427 | self.contrast = self._check_input(contrast, 'contrast') 428 | self.saturation = self._check_input(saturation, 'saturation') 429 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), 430 | clip_first_on_zero=False) 431 | self.prob = p 432 | self.batch_size = batch_size 433 | self.stack_size = stack_size 434 | 435 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 436 | if isinstance(value, numbers.Number): 437 | if value < 0: 438 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 439 | value = [center - value, center + value] 440 | if clip_first_on_zero: 441 | value[0] = max(value[0], 0) 442 | elif isinstance(value, (tuple, list)) and len(value) == 2: 443 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 444 | raise ValueError("{} values should be between {}".format(name, bound)) 445 | else: 446 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 447 | # if value is 0 or (1., 1.) for brightness/contrast/saturation 448 | # or (0., 0.) for hue, do nothing 449 | if value[0] == value[1] == center: 450 | value = None 451 | return value 452 | 453 | def adjust_contrast(self, x): 454 | """ 455 | Args: 456 | x: torch tensor img (rgb type) 457 | Factor: torch tensor with same length as x 458 | 0 gives gray solid image, 1 gives original image, 459 | Returns: 460 | torch tensor image: Brightness adjusted 461 | """ 462 | _device = x.device 463 | factor = torch.empty(x.shape[0], device=_device).uniform_(*self.contrast) 464 | factor = factor.reshape(-1,1).repeat(1, x.shape[1]).reshape(-1) 465 | means = torch.mean(x, dim=(2, 3), keepdim=True) 466 | return torch.clamp((x - means) 467 | * factor.view(len(x), 1, 1, 1) + means, 0, 1) 468 | 469 | def adjust_hue(self, x): 470 | _device = x.device 471 | factor = torch.empty(x.shape[0], device=_device).uniform_(*self.hue) 472 | factor = factor.reshape(-1,1).repeat(1, x.shape[1]).reshape(-1) 473 | h = x[:, 0, :, :] 474 | h += (factor.view(len(x), 1, 1) * 255. / 360.) 
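# [editorial note, not part of utils.py] The hue channel produced by rgb2hsv() above is
# normalized to [0, 1), so the sampled shift is added here and wrapped by the modulo on the
# next line: e.g. a hue of 0.95 shifted by +0.10 wraps around to 0.05 instead of clamping at 1.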
475 | h = (h % 1) 476 | x[:, 0, :, :] = h 477 | return x 478 | 479 | def adjust_brightness(self, x): 480 | """ 481 | Args: 482 | x: torch tensor img (hsv type) 483 | Factor: 484 | torch tensor with same length as x 485 | 0 gives black image, 1 gives original image, 486 | 2 gives the brightness factor of 2. 487 | Returns: 488 | torch tensor image: Brightness adjusted 489 | """ 490 | _device = x.device 491 | factor = torch.empty(x.shape[0], device=_device).uniform_(*self.brightness) 492 | factor = factor.reshape(-1,1).repeat(1, x.shape[1]).reshape(-1) 493 | x[:, 2, :, :] = torch.clamp(x[:, 2, :, :] 494 | * factor.view(len(x), 1, 1), 0, 1) 495 | return torch.clamp(x, 0, 1) 496 | 497 | def adjust_saturate(self, x): 498 | """ 499 | Args: 500 | x: torch tensor img (hsv type) 501 | Factor: 502 | torch tensor with same length as x 503 | 0 gives black image and white, 1 gives original image, 504 | 2 gives the brightness factor of 2. 505 | Returns: 506 | torch tensor image: Brightness adjusted 507 | """ 508 | _device = x.device 509 | factor = torch.empty(self.batch_size, device=_device).uniform_(*self.saturation) 510 | factor = factor.reshape(-1,1).repeat(1, x.shape[1]).reshape(-1) 511 | x[:, 1, :, :] = torch.clamp(x[:, 1, :, :] 512 | * factor.view(len(x), 1, 1), 0, 1) 513 | return torch.clamp(x, 0, 1) 514 | 515 | def transform(self, inputs): 516 | hsv_transform_list = [rgb2hsv, self.adjust_brightness, 517 | self.adjust_hue, self.adjust_saturate, 518 | hsv2rgb] 519 | rgb_transform_list = [self.adjust_contrast] 520 | # Shuffle transform 521 | if random.uniform(0,1) >= 0.5: 522 | transform_list = rgb_transform_list + hsv_transform_list 523 | else: 524 | transform_list = hsv_transform_list + rgb_transform_list 525 | for t in transform_list: 526 | inputs = t(inputs) 527 | return inputs 528 | 529 | def forward(self, inputs): 530 | _device = inputs.device 531 | random_inds = np.random.choice( 532 | [True, False], len(inputs), p=[self.prob, 1 - self.prob]) 533 | inds = torch.tensor(random_inds).to(_device) 534 | if random_inds.sum() > 0: 535 | inputs[inds] = self.transform(inputs[inds]) 536 | return inputs 537 | 538 | class StreamNorm: 539 | def __init__(self, shape=(), momentum=0.99, scale=1.0, eps=1e-8, device='cuda'): 540 | # Momentum of 0 normalizes only based on the current batch. 541 | # Momentum of 1 disables normalization. 
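# [editorial note, not part of utils.py] StreamNorm keeps an exponential moving average of the
# per-dimension batch magnitude and rescales inputs by it; with momentum m, the update() and
# transform() methods below amount to roughly:
#     mag <- m * mag + (1 - m) * |batch|.mean(0)
#     out  = scale * batch / (mag + eps)
# so a momentum such as 0.99 gives a normalizer that adapts slowly to the incoming stream.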
542 | self.device = device 543 | self._shape = tuple(shape) 544 | self._momentum = momentum 545 | self._scale = scale 546 | self._eps = eps 547 | self.mag = None # torch.ones(shape).to(self.device) 548 | 549 | self.step = 0 550 | self.mean = None # torch.zeros(shape).to(self.device) 551 | self.square_mean = None # torch.zeros(shape).to(self.device) 552 | 553 | def __call__(self, inputs): 554 | metrics = {} 555 | self.update(inputs) 556 | metrics['mean'] = inputs.mean() 557 | metrics['std'] = inputs.std() 558 | outputs = self.transform(inputs) 559 | metrics['normed_mean'] = outputs.mean() 560 | metrics['normed_std'] = outputs.std() 561 | return outputs, metrics 562 | 563 | def reset(self): 564 | self.mag = None # torch.ones_like(self.mag).to(self.device) 565 | 566 | self.step = 0 567 | self.mean = None # torch.zeros_like(self.mean).to(self.device) 568 | self.square_mean = None # torch.zeros_like(self.square_mean).to(self.device) 569 | 570 | def update(self, inputs): 571 | batch = inputs.reshape((-1,) + self._shape) 572 | 573 | mag = torch.abs(batch).mean(0) 574 | if self.mag is not None: 575 | self.mag.data = self._momentum * self.mag.data + (1 - self._momentum) * mag 576 | else: 577 | self.mag = mag.clone() 578 | 579 | self.step += 1 580 | 581 | mean = torch.mean(batch) 582 | if self.mean is not None: 583 | self.mean.data = self._momentum * self.mean.data + (1 - self._momentum) * mean 584 | else: 585 | self.mean = mean.clone() 586 | 587 | square_mean = torch.mean(batch * batch) 588 | if self.square_mean is not None: 589 | self.square_mean.data = self._momentum * self.square_mean.data + (1 - self._momentum) * square_mean 590 | else: 591 | self.square_mean = square_mean.clone() 592 | 593 | def transform(self, inputs): 594 | values = inputs.reshape((-1,) + self._shape) 595 | values /= self.mag[None] + self._eps 596 | values *= self._scale 597 | return values.reshape(inputs.shape) 598 | 599 | def corrected_mean_var_std(self,): 600 | corr = 1 601 | corr_mean = self.mean / corr 602 | corr_var = (self.square_mean / corr) - self.mean ** 2 603 | corr_std = torch.sqrt(torch.maximum(corr_var, torch.zeros_like(corr_var, device=self.device)) + self._eps) 604 | return corr_mean, corr_var, corr_std 605 | -------------------------------------------------------------------------------- /xvfb_run.sh: -------------------------------------------------------------------------------- 1 | xvfb-run -a -s '-screen 0 1024x768x24 -ac +extension GLX +render -noreset' "$@" --------------------------------------------------------------------------------
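[editorial note] xvfb_run.sh wraps whatever command follows it in a virtual framebuffer so that
CoppeliaSim/RLBench can render without a physical display. An illustrative invocation from the
repository root (the Hydra overrides are placeholders, not verified defaults):

    ./xvfb_run.sh python3 rl/train.py configs=<configs_file> task=<task_name> seed=1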