├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── assets
│   ├── bottle_out_fridge.png
│   ├── cup_out_open_cabinet.png
│   ├── meat_off_grill.png
│   ├── reach_elbow_pose.png
│   ├── serve_coffee.png
│   └── slide_cup.png
├── docker_run.sh
├── exp_hydra_logs
│   └── .gitkeep
├── libvvcl.so
├── requirements.txt
├── rl
│   ├── .gitignore
│   ├── agent
│   │   ├── mf_utils.py
│   │   ├── sac.py
│   │   └── sac.yaml
│   ├── configs
│   │   ├── default.yaml
│   │   ├── rlbench_both.yaml
│   │   ├── rlbench_pixels.yaml
│   │   └── rlbench_states.yaml
│   ├── env
│   │   ├── __init__.py
│   │   ├── custom_arm_action_modes.py
│   │   ├── custom_rlbench_tasks
│   │   │   ├── elbow_angle_task_design.ttt
│   │   │   ├── task_ttms
│   │   │   │   ├── barista.ttm
│   │   │   │   ├── bottle_out_moving_fridge.ttm
│   │   │   │   ├── cup_out_open_cabinet.ttm
│   │   │   │   ├── reach_gripper_and_elbow.ttm
│   │   │   │   └── slide_cup.ttm
│   │   │   └── tasks
│   │   │       ├── __init__.py
│   │   │       ├── barista.py
│   │   │       ├── bottle_out_moving_fridge.py
│   │   │       ├── cup_out_open_cabinet.py
│   │   │       ├── meat_off_grill.py
│   │   │       ├── phone_on_base.py
│   │   │       ├── pick_and_lift.py
│   │   │       ├── pick_up_cup.py
│   │   │       ├── put_rubbish_in_bin.py
│   │   │       ├── reach_gripper_and_elbow.py
│   │   │       ├── reach_target.py
│   │   │       ├── slide_cup.py
│   │   │       ├── stack_wine.py
│   │   │       ├── take_lid_off_saucepan.py
│   │   │       └── take_umbrella_out_of_umbrella_stand.py
│   │   └── rlbench_envs.py
│   ├── envs.py
│   ├── hydra
│   │   └── hydra_logging
│   │       └── custom.yaml
│   ├── logger.py
│   ├── np_replay.py
│   ├── train.py
│   ├── train.yaml
│   └── utils.py
└── xvfb_run.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | exp_sweep
2 | exp_local
3 | __pycache__
4 | *.mp4
5 | *.jpeg
6 | datasets/*
7 | pt_models
8 | nohup*
9 | data
10 | output.png
11 | exp_hydra_logs/2*
12 | .ipynb_checkpoints/
13 | .git/*
14 | rw_demos/*
15 | rw_models/*
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.2.2-base-ubuntu20.04
2 |
3 | ENV DEBIAN_FRONTEND=noninteractive
4 |
5 | # controls which driver libraries/binaries will be mounted inside the container
6 | # docs: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html
7 | # This can be specified at runtime or at image build time. Here we do it at build time.
8 | # Graphics capability is required to be specified for rendering.
9 | ENV NVIDIA_DRIVER_CAPABILITIES=compute,graphics,utility
10 |
11 | ARG uid
12 | ARG user
13 |
14 | RUN \
15 | apt-get update && \
16 | apt-get install -y \
17 | sudo \
18 | python3-pip \
19 | git \
20 | zsh \
21 | curl \
22 | wget \
23 | unzip \
24 | tmux \
25 | vim \
26 | mesa-utils \
27 | xvfb \
28 | qtbase5-dev \
29 | qtdeclarative5-dev \
30 | libqt5webkit5-dev \
31 | libsqlite3-dev \
32 | qt5-default \
33 | qttools5-dev-tools
34 |
35 | RUN \
36 | useradd -u ${uid} ${user} && \
37 | echo "${user} ALL=(root) NOPASSWD:ALL" > /etc/sudoers.d/${user} && \
38 | chmod 0440 /etc/sudoers.d/${user} && \
39 | mkdir -p /home/${user} && \
40 | chown -R ${user}:${user} /home/${user} && \
41 | chown ${user}:${user} /usr/local/bin && \
42 | mkdir /tmp/.X11-unix && \
43 | chmod 1777 /tmp/.X11-unix && \
44 | chown root /tmp/.X11-unix
45 |
46 | USER ${user}
47 |
48 | WORKDIR /home/${user}
49 |
50 | WORKDIR /home/${user}
51 |
52 |
53 | RUN \
54 | cur=`pwd` && \
55 | wget http://www.coppeliarobotics.com/files/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz && \
56 | tar -xf CoppeliaSim_Edu_V4_1_0_Ubuntu20_04.tar.xz && \
57 | export COPPELIASIM_ROOT="$cur/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04" && \
58 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT:$COPPELIASIM_ROOT/platforms && \
59 | export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT && \
60 | git clone https://github.com/stepjam/PyRep.git && \
61 | cd PyRep && \
62 | pip3 install -r requirements.txt && \
63 | pip3 install setuptools && \
64 | pip3 install .
65 |
66 | RUN \
67 | git clone https://github.com/stepjam/RLBench.git && cd RLBench && \
68 | pip install -r requirements.txt && \
69 | pip install .
70 |
71 | RUN \
72 | mkdir -p ~/.config/fish
73 |
74 | RUN \
75 | export cur=`pwd` && echo "set -x COPPELIASIM_ROOT $cur/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04" >> ~/.config/fish/config.fish && \
76 | echo 'set -x LD_LIBRARY_PATH $LD_LIBRARY_PATH $COPPELIASIM_ROOT $COPPELIASIM_ROOT/platforms' >> ~/.config/fish/config.fish && \
77 | echo 'set -x QT_QPA_PLATFORM_PLUGIN_PATH $COPPELIASIM_ROOT' >> ~/.config/fish/config.fish
78 |
79 | RUN sudo apt-get install -y fish
80 |
81 | RUN pip3 install torch==2.0.1 hydra-core scipy shapely trimesh pyrender wandb==0.15.4 timm
82 |
83 | # install VS Code (code-server)
84 | RUN curl -fsSL https://code-server.dev/install.sh | sh
85 | RUN code-server --install-extension ms-python.python ms-toolsai.jupyter
86 |
87 | # install VS Code extensions
88 | RUN sudo apt-get install -y wget gpg
89 | RUN wget -qO- https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor > packages.microsoft.gpg
90 | RUN sudo install -D -o root -g root -m 644 packages.microsoft.gpg /etc/apt/keyrings/packages.microsoft.gpg
91 | RUN sudo sh -c 'echo "deb [arch=amd64,arm64,armhf signed-by=/etc/apt/keyrings/packages.microsoft.gpg] https://packages.microsoft.com/repos/code stable main" > /etc/apt/sources.list.d/vscode.list'
92 | RUN rm -f packages.microsoft.gpg
93 | RUN sudo apt-get update
94 | RUN sudo apt-get install -y code
95 | # RUN code --install-extension ms-python.python ms-toolsai.jupyter
96 |
97 | # Additional packages
98 | RUN sudo apt-get install -y libglew2.1 libgl1-mesa-glx libosmesa6
99 | RUN pip3 install gym termcolor hydra-submitit-launcher PyOpenGL==3.1.4 PyOpenGL_accelerate notebook matplotlib
100 | RUN pip3 install --upgrade requests
101 |
102 | COPY libvvcl.so CoppeliaSim_Edu_V4_1_0_Ubuntu20_04
103 | RUN pip3 install tensorboard imageio[ffmpeg] hydra-joblib-launcher moviepy
104 |
105 | RUN echo 'set -x PATH $PATH $HOME/.local/bin' >> ~/.config/fish/config.fish
106 |
107 | RUN \
108 | sudo apt-get update && sudo apt-get install -y \
109 | ffmpeg git python3-pip vim libglew-dev \
110 | x11-xserver-utils xvfb \
111 | && sudo apt-get clean
112 |
113 | RUN pip3 install einops dm_env
114 |
115 | ENTRYPOINT ["/usr/bin/fish"]
116 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Pietro Mazzaglia
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code for Redundancy-aware Action Spaces for Robot Learning
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | ## RL experiments
16 |
17 | ### How to run
18 |
19 | For running the code, using Docker is recommended.
20 |
21 | To start the container, run:
22 |
23 | `bash docker_run.sh`
24 |
25 | To start an experiment, launch a command inside the Docker container, such as:
26 |
27 | `sh xvfb_run.sh python3 rl/train.py use_wandb=False project_name=er agent=sac configs=rlbench_both task=rlbench_stack_wine env.action_mode=erjoint seed=1,2,3 experiment=test --multirun &`
28 |
29 | The command above launches 3 seeds in parallel. To run a single experiment, set `seed` to a single value and remove `--multirun`.
30 |
31 | Note that, by default, all logs are redirected to `out.log` and `err.log` in the experiment folder, so no output will appear in your console.
32 |
33 | ### Options
34 |
35 | The possible action modes are: `env.action_mode=joint,ee,erjoint,erangle`
36 |
37 | You can change the ERJ joint to the wrist one by setting: `+env.erj_joint=6`
38 |
39 | ### Notes
40 |
41 | * The implementation of SAC (with distributional critic) can be found at `rl/agent/sac.py` along with the hyperparameters used in the `sac.yaml` file.
42 | * The new RLBench tasks can be found in `rl/env/custom_rlbench_tasks/`
43 | * The new action modes can be found at `rl/env/custom_arm_action_modes.py`
44 |
45 | ## Imitation learning experiments
46 |
47 | For the real-world experiments we used ACT: https://github.com/tonyzhaozh/act
48 |
49 | Some additional notes about the setup are:
50 | * Each action space model was trained for 3000 epochs using the hyper-parameters presented in the ACT paper (apart from using a batch size of 16 and a chunk size of 20).
51 | * Additionally, the EE pose was added to the state information.
52 | * All quaternions in demonstrations and inference were forced to have a positive `w`. Lastly, only a wrist camera (with resolution 224x224) was used.
53 |
54 | For the IK we used [pick\_ik](https://github.com/PickNikRobotics/pick_ik) with default parameters for solving the IK for EE and ERJ.
55 |
56 | Some additional notes:
57 | * We followed the standard ["pick\_ik Kinematics Solver" tutorial](https://moveit.picknik.ai/main/doc/how_to_guides/pick_ik/pick_ik_tutorial.html\#pick-ik-kinematics-solver) from the pick\_ik documentation
58 | * We altered only two IK parameters from their defaults: `approximate_solution_position_threshold=0.01` and `approximate_solution_orientation_threshold=0.01` (both reduced from the original value of 0.05) to increase accuracy
59 |
60 |
61 |
--------------------------------------------------------------------------------
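The README notes that all quaternions in demonstrations and at inference were forced to have a positive `w`. A minimal sketch of that convention (the helper below is illustrative and not part of the repo):

```python
import numpy as np

def canonicalize_quaternion(q_xyzw: np.ndarray) -> np.ndarray:
    """Return the equivalent quaternion with non-negative w (x, y, z, w order).

    q and -q encode the same rotation, so flipping the sign is lossless.
    Hypothetical helper illustrating the convention described in the README.
    """
    return -q_xyzw if q_xyzw[3] < 0 else q_xyzw

# Both q and -q map to the same canonical representation with w >= 0.
q = np.array([0.0, 0.0, 0.7071068, -0.7071068])
print(canonicalize_quaternion(q))  # [ 0.  0. -0.7071068  0.7071068]
```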
/assets/bottle_out_fridge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/bottle_out_fridge.png
--------------------------------------------------------------------------------
/assets/cup_out_open_cabinet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/cup_out_open_cabinet.png
--------------------------------------------------------------------------------
/assets/meat_off_grill.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/meat_off_grill.png
--------------------------------------------------------------------------------
/assets/reach_elbow_pose.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/reach_elbow_pose.png
--------------------------------------------------------------------------------
/assets/serve_coffee.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/serve_coffee.png
--------------------------------------------------------------------------------
/assets/slide_cup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/assets/slide_cup.png
--------------------------------------------------------------------------------
/docker_run.sh:
--------------------------------------------------------------------------------
1 | docker run --rm --gpus=all --net=host --privileged -e=DISPLAY -e=XDG_RUNTIME_DIR --shm-size=2gb \
2 | -v `pwd`:${HOME}/repos/ERAS \
3 | -v ${HOME}/.ssh:${HOME}/.ssh \
4 | -v $XDG_RUNTIME_DIR:$XDG_RUNTIME_DIR \
5 | -v ${XAUTHORITY}:${XAUTHORITY} \
6 | -v /usr/local/share/ca-certificates:/usr/local/share/ca-certificates \
7 | -v /etc/ssl/certs/ca-certificates.crt:/etc/ssl/certs/ca-certificates.crt \
8 | -w ${HOME}/repos/ERAS \
9 | -it $(docker build -q --build-arg uid=$(id -u ${USER}) --build-arg user=${USER} -t "local/robot_learning/er:latest" .)
--------------------------------------------------------------------------------
/exp_hydra_logs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/exp_hydra_logs/.gitkeep
--------------------------------------------------------------------------------
/libvvcl.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/libvvcl.so
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | hydra-core
2 | scipy
3 | shapely
4 | trimesh
5 | pyrender
6 | notebook
--------------------------------------------------------------------------------
/rl/.gitignore:
--------------------------------------------------------------------------------
1 | ManiSkill2-Learn
2 |
3 | __pycache__
4 | exp_local
5 | exp_sweep
6 | pretrained_models/*
7 | pt_models
8 | nohup*
9 | data
10 | output.png
11 | exp_hydra_logs/*
12 | datasets/*
--------------------------------------------------------------------------------
/rl/agent/mf_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | import utils
6 |
7 | class DrQEncoder(nn.Module):
8 | def __init__(self, obs_shape, key='observation', is_rgb=True):
9 | super().__init__()
10 |
11 | self._key = key
12 | self.is_rgb = is_rgb
13 | assert len(obs_shape) == 3
14 | self.repr_dim = 32 * 35 * 35
15 |
16 | self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], 32, 3, stride=2),
17 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
18 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
19 | nn.ReLU(), nn.Conv2d(32, 32, 3, stride=1),
20 | nn.ReLU())
21 |
22 | self.apply(utils.weight_init)
23 |
24 | def forward(self, obs):
25 | h = self.convnet(obs[self._key])
26 | h = h.view(h.shape[0], -1)
27 | return h
28 |
29 | def preprocess(self, obs):
30 | if self.is_rgb:
31 | obs[self._key] = obs[self._key] / 255.0 - 0.5
32 | return obs
33 |
34 | class IdentityEncoder(nn.Identity):
35 | def __init__(self, key):
36 | super().__init__()
37 | self._key = key
38 | self.fake_param = nn.parameter.Parameter()
39 |
40 | def forward(self, obs):
41 | return obs[self._key]
42 |
43 | def preprocess(self, obs):
44 | return obs
45 |
46 | class Actor(nn.Module):
47 | def __init__(self, enc_dim, action_dim, hidden_dim, feature_dim):
48 | super().__init__()
49 |
50 | self.trunk = nn.Sequential(nn.Linear(enc_dim, feature_dim),
51 | nn.LayerNorm(feature_dim), nn.Tanh(),
52 | )
53 |
54 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
55 | nn.ReLU(inplace=True),
56 | nn.Linear(hidden_dim, hidden_dim),
57 | nn.ReLU(inplace=True),
58 | nn.Linear(hidden_dim, action_dim * 2))
59 |
60 | self.log_std_bounds = [-10, 2]
61 | self.apply(utils.weight_init)
62 |
63 | def forward(self, enc,):
64 | enc = self.trunk(enc)
65 | mu, log_std = self.policy(enc).chunk(2, dim=-1)
66 | self._mu_std = mu.std().item()
67 |
68 | # constrain log_std inside [log_std_min, log_std_max]
69 | log_std = torch.tanh(log_std)
70 | log_std_min, log_std_max = self.log_std_bounds
71 | log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)
72 | std = log_std.exp()
73 |
74 | dist = utils.SquashedNormal(mu, std)
75 | return dist
76 |
77 | class ActorFixedStd(nn.Module):
78 | def __init__(self, enc_dim, action_dim, hidden_dim, feature_dim):
79 | super().__init__()
80 | self.trunk = nn.Sequential(nn.Linear(enc_dim, feature_dim),
81 | nn.LayerNorm(feature_dim), nn.Tanh(),
82 | )
83 |
84 | self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
85 | nn.ReLU(inplace=True),
86 | nn.Linear(hidden_dim, hidden_dim),
87 | nn.ReLU(inplace=True),
88 | nn.Linear(hidden_dim, action_dim))
89 |
90 | self.apply(utils.weight_init)
91 |
92 | def forward(self, enc, std,):
93 | enc = self.trunk(enc)
94 | mu = self.policy(enc)
95 | self._mu_std = mu.std().item()
96 | mu = torch.tanh(mu)
97 | std = torch.ones_like(mu) * std
98 |
99 | dist = utils.TruncatedNormal(mu, std)
100 | return dist
101 |
102 | def signed_hyperbolic(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
103 | """Signed hyperbolic transform, inverse of signed_parabolic."""
104 | return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x
105 |
106 | def signed_parabolic(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
107 | """Signed parabolic transform, inverse of signed_hyperbolic."""
108 | z = torch.sqrt(1 + 4 * eps * (eps + 1 + torch.abs(x))) / 2 / eps - 1 / 2 / eps
109 | return torch.sign(x) * (torch.square(z) - 1)
110 |
111 | def from_categorical(distribution, limit=20, offset=0., logits=True):
112 | distribution = distribution.float().squeeze(-1) # Avoid any fp16 shenanigans
113 | if logits:
114 | distribution = torch.softmax(distribution, -1)
115 | num_atoms = distribution.shape[-1]
116 | shift = limit * 2 / (num_atoms - 1)
117 | weights = torch.linspace(-(num_atoms//2), num_atoms//2, num_atoms, device=distribution.device).float().unsqueeze(-1)
118 | return signed_parabolic((distribution @ weights) * shift) - offset
119 |
120 | def to_categorical(value, limit=20, offset=0., num_atoms=251):
121 | value = value.float() + offset # Avoid any fp16 shenanigans
122 | shift = limit * 2 / (num_atoms - 1)
123 | value = signed_hyperbolic(value) / shift
124 | value = value.clamp(-(num_atoms//2), num_atoms//2)
125 | distribution = torch.zeros(value.shape[0], num_atoms, 1, device=value.device)
126 | lower = value.floor().long() + num_atoms // 2
127 | upper = value.ceil().long() + num_atoms // 2
128 | upper_weight = value % 1
129 | lower_weight = 1 - upper_weight
130 | distribution.scatter_add_(-2, lower.unsqueeze(-1), lower_weight.unsqueeze(-1))
131 | distribution.scatter_add_(-2, upper.unsqueeze(-1), upper_weight.unsqueeze(-1))
132 | return distribution
133 |
134 | class Critic(nn.Module):
135 | def __init__(self, enc_dim, action_dim, hidden_dim, feature_dim):
136 | super().__init__()
137 |
138 | self.trunk = nn.Sequential(nn.Linear(enc_dim, feature_dim),
139 | nn.LayerNorm(feature_dim), nn.Tanh(),
140 | )
141 |
142 | self.q1_net = nn.Sequential(
143 | nn.Linear(feature_dim + action_dim, hidden_dim), nn.ReLU(inplace=True),
144 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
145 | nn.Linear(hidden_dim, 1))
146 |
147 | self.q2_net = nn.Sequential(
148 | nn.Linear(feature_dim + action_dim, hidden_dim), nn.ReLU(inplace=True),
149 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
150 | nn.Linear(hidden_dim, 1))
151 |
152 | self.apply(utils.weight_init)
153 |
154 | def forward(self, enc, action):
155 | enc = self.trunk(enc)
156 | obs_action = torch.cat([enc, action], dim=-1)
157 | q1 = self.q1_net(obs_action)
158 | q2 = self.q2_net(obs_action)
159 | return q1, q2
160 |
161 | class DistributionalCritic(nn.Module):
162 | def __init__(self, enc_dim, action_dim, hidden_dim, feature_dim, num_atoms=251):
163 | super().__init__()
164 | self.distributional = True
165 |
166 | self.trunk = nn.Sequential(nn.Linear(enc_dim, feature_dim),
167 | nn.LayerNorm(feature_dim), nn.Tanh(),
168 | )
169 |
170 | self.q1_net = nn.Sequential(
171 | nn.Linear(feature_dim + action_dim, hidden_dim), nn.ReLU(inplace=True),
172 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
173 | nn.Linear(hidden_dim, num_atoms))
174 |
175 | self.q2_net = nn.Sequential(
176 | nn.Linear(feature_dim + action_dim, hidden_dim), nn.ReLU(inplace=True),
177 | nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
178 | nn.Linear(hidden_dim, num_atoms))
179 |
180 | self.apply(utils.weight_init)
181 |
182 | def forward(self, enc, action):
183 | enc = self.trunk(enc)
184 | obs_action = torch.cat([enc, action], dim=-1)
185 | q1 = self.q1_net(obs_action)
186 | q2 = self.q2_net(obs_action)
187 | return q1, q2
--------------------------------------------------------------------------------
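The categorical transforms above implement the two-hot projection used by the distributional critic: scalar targets are squashed with `signed_hyperbolic`, spread over 251 atoms, and recovered with the inverse transform. A quick round-trip sketch, assuming `rl/` is on `PYTHONPATH` so `agent.mf_utils` imports (values are illustrative):

```python
import torch
from agent.mf_utils import to_categorical, from_categorical

values = torch.tensor([[-5.0], [0.0], [12.3]])              # (batch, 1) scalar targets
dist = to_categorical(values, limit=20, num_atoms=251)       # (batch, 251, 1) two-hot distribution
recovered = from_categorical(dist, limit=20, logits=False)   # (batch, 1), approximately the inputs
print(torch.allclose(recovered, values, atol=0.05))          # True (up to numerical error)
```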
/rl/agent/sac.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from collections import OrderedDict, defaultdict
7 |
8 | import utils
9 | from utils import StreamNorm
10 | from agent.mf_utils import *
11 |
12 | class SACAgent:
13 | def __init__(self,
14 | name,
15 | cfg,
16 | obs_space,
17 | act_space,
18 | device,
19 | lr,
20 | hidden_dim,
21 | feature_dim,
22 | critic_target_tau,
23 | action_target_entropy,
24 | init_temperature,
25 | policy_delay,
26 | frame_stack,
27 | distributional,
28 | normalize_reward,
29 | normalize_returns,
30 | obs_keys,
31 | drq_encoder,
32 | drq_aug):
33 | self.cfg = cfg
34 | self.act_space = act_space
35 | self.obs_space = obs_space
36 | self.action_dim = np.sum((np.prod(v.shape) for v in act_space.values())) # prev: action_shape[0]
37 | self.hidden_dim = hidden_dim
38 | self.lr = lr
39 | self.device = device
40 | self.critic_target_tau = critic_target_tau
41 | self.policy_delay = policy_delay
42 | self.obs_keys = obs_keys.split('|')
43 | shapes = {}
44 | for k,v in obs_space.items():
45 | shapes[k] = list(v.shape)
46 | if len(v.shape) == 3:
47 | shapes[k][0] = shapes[k][0] * frame_stack
48 |
49 | self.frame_stack = frame_stack
50 | self.obs_buffer = defaultdict(list)
51 | self._batch_reward = 0
52 |
53 | # models
54 | self.encoders = {}
55 | self.augs = {}
56 | for k in self.obs_keys:
57 | if len(shapes[k]) == 3:
58 | img_size = shapes[k][-1]
59 | pad = img_size // 21 # pad=4 for 84
60 | self.augs[k] = utils.RandomShiftsAug(pad=pad) if drq_aug else nn.Identity()
61 | if drq_encoder:
62 | self.encoders[k] = DrQEncoder(shapes[k], key=k, is_rgb=obs_space[k].shape[0] == 3).to(self.device)
63 | else:
64 | raise NotImplementedError("")
65 | else:
66 | self.augs[k] = nn.Identity()
67 | self.encoders[k] = IdentityEncoder(k)
68 | self.encoders[k].repr_dim = shapes[k][0]
69 | self.encoders = nn.ModuleDict(self.encoders)
70 | self.enc_repr_dim = sum(e.repr_dim for e in self.encoders.values())
71 |
72 | self.actor = Actor(self.enc_repr_dim, self.action_dim,
73 | hidden_dim, feature_dim).to(device)
74 |
75 | if distributional:
76 | self.critic = DistributionalCritic(self.enc_repr_dim, self.action_dim,
77 | hidden_dim, feature_dim).to(device)
78 | self.critic_target = DistributionalCritic(self.enc_repr_dim, self.action_dim,
79 | hidden_dim, feature_dim).to(device)
80 | else:
81 | self.critic = Critic(self.enc_repr_dim, self.action_dim,
82 | hidden_dim, feature_dim).to(device)
83 | self.critic_target = Critic(self.enc_repr_dim, self.action_dim,
84 | hidden_dim, feature_dim).to(device)
85 | self.critic_target.load_state_dict(self.critic.state_dict())
86 |
87 | self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
88 | self.log_alpha.requires_grad = True
89 | if action_target_entropy == 'neg':
90 | # set target entropy to -|A|
91 | self.target_entropy = -self.action_dim
92 | elif action_target_entropy == 'neg_double':
93 | self.target_entropy = -self.action_dim * 2
94 | elif action_target_entropy == 'neglog':
95 | self.target_entropy = -torch.Tensor([self.action_dim]).to(self.device)
96 | elif action_target_entropy == 'zero':
97 | self.target_entropy = 0
98 | else:
99 | self.target_entropy = action_target_entropy
100 |
101 | # optimizers
102 | self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
103 | self.encoder_opt = torch.optim.Adam(self.encoders.parameters(), lr=lr)
104 | self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
105 | self.log_alpha_opt = torch.optim.Adam([self.log_alpha], lr=lr)
106 |
107 | self.normalize_reward = normalize_reward
108 | self.normalize_returns = normalize_returns
109 | self.rewnorm = StreamNorm(**{"momentum": 0.99, "scale": 1.0, "eps": 1e-8}, device=self.device)
110 | self.retnorm = StreamNorm(**{"momentum": 0.99, "scale": 1.0, "eps": 1e-8}, device=self.device)
111 |
112 | self.train()
113 | self.critic_target.train()
114 |
115 | def init_meta(self):
116 | return {}
117 |
118 | def get_meta_specs(self):
119 | return {}
120 |
121 | def update_meta(self, meta, global_step, time_step):
122 | return self.init_meta()
123 |
124 | def train(self, training=True):
125 | self.training = training
126 | self.encoders.train(training)
127 | self.actor.train(training)
128 | self.critic.train(training)
129 |
130 | @property
131 | def alpha(self):
132 | return self.log_alpha.exp()
133 |
134 | @torch.no_grad()
135 | def act(self, obs, meta, step, eval_mode, state,):
136 | is_first = obs['is_first'].all() or len(self.obs_buffer[self.obs_keys[0]]) == 0
137 |
138 | for k in self.obs_keys:
139 | obs[k] = torch.as_tensor(np.copy(obs[k]), device=self.device)
140 | if is_first:
141 | self.obs_buffer[k] = [obs[k]] * self.frame_stack
142 | else:
143 | self.obs_buffer[k].pop(0)
144 | self.obs_buffer[k].append(obs[k])
145 | obs_ch = obs[k].shape[1]
146 | obs_size = obs[k].shape[2:]
147 | obs[k] = torch.stack(self.obs_buffer[k], dim=1).reshape(-1, obs_ch * self.frame_stack, *obs_size)
148 |
149 | obs = torch.cat([ e(e.preprocess(obs)) for e in self.encoders.values()], dim=-1)
150 |
151 | policy = self.actor(obs,)
152 | if eval_mode:
153 | action = policy.mean
154 | else:
155 | action = policy.sample()
156 | if step < (self.cfg.num_seed_frames // self.cfg.action_repeat):
157 | action.uniform_(-1.0, 1.0)
158 | action = action.clamp(-1.0, 1.0)
159 | # @returns: action, state
160 | return action.cpu().numpy(), None
161 |
162 | def update_critic(self, obs, action, reward, discount, next_obs, step):
163 | metrics = dict()
164 |
165 | with torch.no_grad():
166 | dist = self.actor(next_obs)
167 | next_action = dist.rsample()
168 | log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
169 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
170 | if getattr(self.critic, 'distributional', False):
171 | target_Q1, target_Q2 = from_categorical(target_Q1), from_categorical(target_Q2)
172 | target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_prob
173 | mag_norm_return, batch_return_metrics = self.retnorm((discount * target_V).clone())
174 | if self.normalize_returns:
175 | ret_mean, ret_var, ret_std = self.retnorm.corrected_mean_var_std()
176 | target_V = target_V / ret_std
177 | target_Q = reward + (discount * target_V)
178 | if getattr(self.critic, 'distributional', False):
179 | target_Q_dist = to_categorical(target_Q,)
180 |
181 | Q1, Q2 = self.critic(obs, action)
182 | if getattr(self.critic, 'distributional', False):
183 | critic_loss = - torch.mean(torch.sum(torch.log_softmax(Q1, -1) * target_Q_dist.squeeze(-1).detach(), -1)) - torch.mean(torch.sum(torch.log_softmax(Q2, -1) * target_Q_dist.squeeze(-1).detach(), -1))
184 | Q1, Q2 = from_categorical(Q1), from_categorical(Q2)
185 | else:
186 | critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
187 |
188 | metrics['critic_target_q'] = target_Q.mean().item()
189 | metrics['critic_target_q_max'] = target_Q.max().item()
190 | metrics['critic_target_q_min'] = target_Q.min().item()
191 | if self.normalize_returns:
192 | metrics['critic_target_v_running_std'] = ret_std.item()
193 | metrics['critic_q1'] = Q1.mean().item()
194 | metrics['critic_q2'] = Q2.mean().item()
195 | metrics['critic_loss'] = critic_loss.item()
196 |
197 | # optimize encoder and critic
198 | self.critic_opt.zero_grad(set_to_none=True)
199 | self.encoder_opt.zero_grad(set_to_none=True)
200 |
201 | critic_loss.backward()
202 |
203 | self.critic_opt.step()
204 | self.encoder_opt.step()
205 |
206 | return metrics
207 |
208 | def update_actor_and_alpha(self, obs, action, step):
209 | metrics = dict()
210 |
211 | policy = self.actor(obs)
212 | action = policy.rsample()
213 | log_prob = policy.log_prob(action).sum(-1, keepdim=True)
214 |
215 | Q1, Q2 = self.critic(obs, action)
216 | if getattr(self.critic, 'distributional', False):
217 | Q1, Q2 = from_categorical(Q1), from_categorical(Q2)
218 | Q = torch.min(Q1, Q2)
219 |
220 | actor_loss = (self.alpha.detach() * log_prob -Q).mean()
221 |
222 | # optimize actor
223 | self.actor_opt.zero_grad(set_to_none=True)
224 | actor_loss.backward()
225 | self.actor_opt.step()
226 |
227 | alpha_loss = (self.alpha *
228 | (-log_prob - self.target_entropy).detach()).mean()
229 |
230 | self.log_alpha_opt.zero_grad(set_to_none=True)
231 | alpha_loss.backward()
232 | self.log_alpha_opt.step()
233 |
234 | metrics['actor_loss'] = actor_loss.item()
235 | metrics['alpha_loss'] = alpha_loss.item()
236 | metrics['actor_mean_stddev'] = self.actor._mu_std
237 | metrics['alpha'] = self.alpha.item()
238 | metrics['actor_ent'] = -log_prob.mean().item()
239 |
240 | policy_ent_per_dim = policy.base_dist.entropy().mean(dim=0)
241 | for ai in range(action.shape[-1]):
242 | metrics[f'policy_dist/dim_{ai}'] = policy_ent_per_dim[ai]
243 |
244 | return metrics
245 |
246 | def update(self, batch, step):
247 | metrics = dict()
248 |
249 | obs, next_obs = {}, {}
250 | for k in self.obs_keys:
251 | b, t = batch[k].shape[:2]
252 | # assert t == (self.frame_stack + 1)
253 | obs_ch = batch[k].shape[2]
254 | obs_size = batch[k].shape[3:]
255 | if len(obs_size) == 2:
256 | obs[k] = batch[k][:, 0:self.frame_stack].reshape(b, obs_ch * self.frame_stack, *obs_size)
257 | next_obs[k] = batch[k][:, 1:self.frame_stack+1].reshape(b, obs_ch * self.frame_stack, *obs_size)
258 | else:
259 | obs[k] = batch[k][:, self.frame_stack-1].reshape(b, obs_ch, *obs_size)
260 | next_obs[k] = batch[k][:, self.frame_stack].reshape(b, obs_ch, *obs_size)
261 |
262 | obs[k] = self.augs[k](obs[k].float()).to(self.device)
263 | next_obs[k] = self.augs[k](next_obs[k].float()).to(self.device)
264 |
265 | action = batch['action'][:, self.frame_stack].to(self.device)
266 | reward = batch['reward'][:, self.frame_stack].to(self.device)
267 | discount = (batch['discount'][:, self.frame_stack] * self.cfg.discount).to(self.device)
268 |
269 | mag_norm_reward, batch_rew_metrics = self.rewnorm(reward.clone())
270 | if self.normalize_reward:
271 | rw_mean, rw_var, rw_std = self.rewnorm.corrected_mean_var_std()
272 | reward = reward / rw_std
273 |
274 | obs = torch.cat([e(e.preprocess(obs)) for e in self.encoders.values()], dim=-1)
275 | with torch.no_grad():
276 | next_obs = torch.cat([e(e.preprocess(next_obs)) for e in self.encoders.values()], dim=-1)
277 |
278 | if self.normalize_reward:
279 | metrics['reward_running_mean'] = rw_mean.item()
280 | metrics['reward_running_std'] = rw_std.item()
281 |
282 | # update critic
283 | metrics.update(
284 | self.update_critic(obs, action, reward, discount, next_obs, step))
285 |
286 | if step % self.policy_delay == 0:
287 | # update actor
288 | metrics.update(self.update_actor_and_alpha(obs.detach(), action, step))
289 |
290 | # update critic target
291 | utils.soft_update_params(self.critic, self.critic_target,
292 | self.critic_target_tau)
293 | # @returns: state, metrics
294 | return None, metrics
--------------------------------------------------------------------------------
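`sac.py` calls `utils.soft_update_params(self.critic, self.critic_target, tau)` for the critic target update; `utils.py` is not included in this section, so here is a minimal sketch of the standard Polyak update it presumably performs:

```python
import torch.nn as nn

def soft_update_params(net: nn.Module, target_net: nn.Module, tau: float) -> None:
    """Polyak averaging: target <- tau * online + (1 - tau) * target.

    Sketch of what `utils.soft_update_params` presumably does (utils.py is not
    shown here), matching the usual SAC target-network update.
    """
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
```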
/rl/agent/sac.yaml:
--------------------------------------------------------------------------------
1 | # @package agent
2 | _target_: agent.sac.SACAgent
3 | name: sac
4 | obs_space: ??? # to be specified later
5 | act_space: ??? # to be specified later
6 | device: ${device}
7 | lr: 3e-4 # 1e-4 in ExORL
8 | critic_target_tau: 0.01 # 0.005 in SpinningUp # 0.01 in EXORL
9 | hidden_dim: 1024
10 | feature_dim: 50
11 | # entropy
12 | init_temperature: 0.1 # 0.1 was default
13 | action_target_entropy: neg # neg was default
14 | #
15 | policy_delay: 1
16 | frame_stack: 1 # 3 for DMC pixels
17 | obs_keys: front_rgb|wrist_rgb|state # default for pixels
18 | drq_encoder: true
19 | drq_aug: true
20 | # normalization
21 | distributional: true
22 | normalize_reward: true
23 | normalize_returns: true
24 |
--------------------------------------------------------------------------------
/rl/configs/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | replay: {capacity: 2e6, ongoing: False, minlen: 50, maxlen: 50, prioritize_ends: False}
3 | discount: 0.99
4 |
5 | # RLBench defaults
6 | train_every_actions: 8
7 | parallel_envs: 12
8 | async_mode: 'FULL'
9 | batch_size: 128
10 | batch_length: 2
--------------------------------------------------------------------------------
/rl/configs/rlbench_both.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | obs_type: both
3 | action_repeat: 1
4 | encoder: {mlp_keys: '$^', cnn_keys: 'front_rgb', norm: none, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]}
5 | decoder: {mlp_keys: '$^', cnn_keys: 'front_rgb', norm: none, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]}
6 | replay.capacity: 2e6
7 |
8 | env:
9 | action_mode: ee
10 | cameras: front|wrist
11 | state_info: gripper_open|gripper_pose|joint_cart_pos|joint_positions
--------------------------------------------------------------------------------
/rl/configs/rlbench_pixels.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | obs_type: pixels
3 | action_repeat: 1
4 | encoder: {mlp_keys: '$^', cnn_keys: 'front_rgb', norm: none, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]}
5 | decoder: {mlp_keys: '$^', cnn_keys: 'front_rgb', norm: none, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]}
6 | replay.capacity: 2e6
7 |
8 | env:
9 | action_mode: ee
--------------------------------------------------------------------------------
/rl/configs/rlbench_states.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | obs_type: states
3 | action_repeat: 1
4 | encoder: {mlp_keys: 'state', cnn_keys: '$^', norm: layer, cnn_depth: 48, cnn_kernels: [4, 4, 4, 4], mlp_layers: [400, 400, 400, 400]}
5 | decoder: {mlp_keys: 'state', cnn_keys: '$^', norm: layer, cnn_depth: 48, cnn_kernels: [5, 5, 6, 6], mlp_layers: [400, 400, 400, 400]}
6 | replay.capacity: 2e6
7 |
8 | env:
9 | action_mode: ee
--------------------------------------------------------------------------------
/rl/env/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/__init__.py
--------------------------------------------------------------------------------
/rl/env/custom_arm_action_modes.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import List, Union
3 | from pyrep.objects import Object
4 |
5 | import numpy as np
6 | from enum import Enum
7 | from pyquaternion import Quaternion
8 | from scipy.spatial.transform import Rotation
9 | from pyrep.const import ConfigurationPathAlgorithms as Algos, ObjectType
10 | from pyrep.errors import ConfigurationPathError, IKError, ConfigurationError
11 | from pyrep.robots.configuration_paths.arm_configuration_path import ArmConfigurationPath
12 | from pyrep.const import PYREP_SCRIPT_TYPE
13 |
14 | from rlbench.backend.exceptions import InvalidActionError
15 | from rlbench.backend.robot import Robot
16 | from rlbench.backend.scene import Scene
17 | from rlbench.const import SUPPORTED_ROBOTS
18 |
19 | from rlbench.action_modes.arm_action_modes import ArmActionMode, EndEffectorPoseViaIK
20 | from pyrep.backend import sim, utils
21 | from pyrep.objects.dummy import Dummy
22 |
23 | def assert_action_shape(action: np.ndarray, expected_shape: tuple):
24 | if np.shape(action) != expected_shape:
25 | raise InvalidActionError(
26 | 'Expected the action shape to be: %s, but was shape: %s' % (
27 | str(expected_shape), str(np.shape(action))))
28 |
29 |
30 | def assert_unit_quaternion(quat):
31 | if not np.isclose(np.linalg.norm(quat), 1.0):
32 | raise InvalidActionError('Action contained non unit quaternion!')
33 |
34 |
35 | def calculate_delta_pose(robot: Robot, action: np.ndarray):
36 | a_x, a_y, a_z, a_qx, a_qy, a_qz, a_qw = action
37 | x, y, z, qx, qy, qz, qw = robot.arm.get_tip().get_pose()
38 | new_rot = Quaternion(
39 | a_qw, a_qx, a_qy, a_qz) * Quaternion(qw, qx, qy, qz)
40 | qw, qx, qy, qz = list(new_rot)
41 | pose = [a_x + x, a_y + y, a_z + z] + [qx, qy, qz, qw]
42 | return pose
43 |
44 | def calculate_delta_pose_elbow(robot: Robot, action: np.ndarray):
45 | a_x, a_y, a_z, ae_x, ae_y, ae_z, a_qx, a_qy, a_qz, a_qw = action
46 | x, y, z, qx, qy, qz, qw = robot.arm.get_tip().get_pose()
47 | ex, ey, ez = robot.arm._ik_elbow.get_position()
48 | new_rot = Quaternion(
49 | a_qw, a_qx, a_qy, a_qz) * Quaternion(qw, qx, qy, qz)
50 | qw, qx, qy, qz = list(new_rot)
51 | new_action = [a_x + x, a_y + y, a_z + z] + [ae_x + ex, ae_y + ey, ae_z + ez] + [qx, qy, qz, qw]
52 | return new_action
53 |
54 | class RelativeFrame(Enum):
55 | WORLD = 0
56 | EE = 1
57 |
58 | class EEOrientationState(Enum):
59 | FREE = 0
60 | FIXED = 1
61 | KEEP = 2
62 |
63 | class ERAngleViaIK(ArmActionMode):
64 | """High-level action where target EE pose + Elbow angle is given in ER
65 | space (End-Effector and Elbow) and reached via IK.
66 |
67 | Given a target pose, IK via inverse Jacobian is performed. This requires
68 | the target pose to be close to the current pose, otherwise the action
69 | will fail. It is up to the user to constrain the action to
70 | meaningful values.
71 |
72 | The decision to apply collision checking is a crucial trade-off!
73 | With collision checking enabled, you are guaranteed collision-free paths,
74 | but this may not be applicable for tasks that do require some collision.
75 | E.g. using this mode on a pushing task will mean that the generated
76 | path will actively avoid pushing the object.
77 | """
78 |
79 | def __init__(self,
80 | absolute_mode: bool = True,
81 | frame: RelativeFrame = RelativeFrame.WORLD,
82 | collision_checking: bool = False,
83 | orientation_state: EEOrientationState = None):
84 | self._absolute_mode = absolute_mode
85 | self._frame = frame
86 | self._collision_checking = collision_checking
87 | self._orientation_state = orientation_state
88 | self._action_shape = (8,)
89 |
90 | def setup_arm(self, arm, name='Panda', suffix=''):
91 | # @comment: params are hardcoded for now, but could be useful to extend for all robots
92 | arm._ik_elbow_target = Dummy('%s_elbow_target%s' % (name, suffix))
93 | arm._ik_elbow = Dummy('%s_elbow%s' % (name, suffix))
94 | arm._elbow_ik_group = sim.simGetIkGroupHandle('%s_e3_ik%s' % (name, suffix))
95 | joint2 = sim.simGetObjectHandle('%s_joint2%s' % (name, suffix))
96 | joint6 = sim.simGetObjectHandle('%s_joint6%s' % (name, suffix))
97 | arm._joint2_obj = Object.get_object(joint2)
98 | arm._joint6_obj = Object.get_object(joint6)
99 |
100 | sim.simSetIkGroupProperties(
101 | arm._elbow_ik_group,
102 | sim.sim_ik_damped_least_squares_method,
103 | maxIterations=400,
104 | damping=1e-3
105 | )
106 |
107 | ee_constraints = \
108 | sim.sim_ik_x_constraint | sim.sim_ik_y_constraint | sim.sim_ik_z_constraint | \
109 | sim.sim_ik_alpha_beta_constraint | sim.sim_ik_gamma_constraint
110 | if self._orientation_state == EEOrientationState.FREE:
111 | ee_constraints = sim.sim_ik_x_constraint | sim.sim_ik_y_constraint |sim.sim_ik_z_constraint
112 | sim.simSetIkElementProperties(
113 | arm._elbow_ik_group,
114 | arm._ik_tip.get_handle(),
115 | ee_constraints,
116 | precision=[5e-4, 5/180*np.pi],
117 | weight=[1, 1]
118 | )
119 |
120 | elbow_constraints = sim.sim_ik_alpha_beta_constraint | sim.sim_ik_gamma_constraint
121 | sim.simSetIkElementProperties(
122 | arm._elbow_ik_group,
123 | arm._ik_elbow.get_handle(),
124 | elbow_constraints,
125 | precision=[5e-4, 3/180*np.pi],
126 | weight=[0, 1]
127 | )
128 |
129 | def solve_ik_via_jacobian(self,
130 | arm,
131 | position: Union[List[float], np.ndarray],
132 | euler: Union[List[float], np.ndarray] = None,
133 | quaternion: Union[List[float], np.ndarray] = None,
134 | angle: float = None,
135 | relative_to: Object = None) -> List[float]:
136 | """Solves an IK group and returns the calculated joint values.
137 |
138 | This IK method performs a linearisation around the current robot
139 | configuration via the Jacobian. The linearisation is valid when the
140 | start and goal pose are not too far away, but after a certain point,
141 | linearisation will no longer be valid. In that case, the user is better
142 | off using 'solve_ik_via_sampling'.
143 |
144 | Must specify either rotation in euler or quaternions, but not both!
145 |
146 | :param arm: The CoppeliaSim arm to move.
147 | :param position: The x, y, z position of the ee target.
148 | :param euler: The x, y, z orientation of the ee target (in radians).
149 | :param quaternion: A list containing the quaternion (x,y,z,w).
150 | :param angle: The target angle for the elbow to rotate around the
151 | vector normal to the circle of possible elbow positions (in radians).
152 | :param relative_to: Indicates relative to which reference frame we want
153 | the target pose. Specify None to retrieve the absolute pose,
154 | or an Object relative to whose reference frame we want the pose.
155 | :return: A list containing the calculated joint values.
156 | """
157 | assert len(position) == 3
158 | arm._ik_target.set_position(position, relative_to)
159 |
160 | if euler is not None:
161 | arm._ik_target.set_orientation(euler, relative_to)
162 | elif quaternion is not None:
163 | arm._ik_target.set_quaternion(quaternion, relative_to)
164 |
165 | target_elbow_quat = arm._ik_elbow.get_quaternion()
166 |
167 | if angle is not None:
168 | w = arm._joint6_obj.get_position() - arm._joint2_obj.get_position()
169 |
170 | w_norm = w / np.linalg.norm(w)
171 | q = np.concatenate((np.sin(angle)*w_norm, np.array([np.cos(angle)])))
172 | r = Rotation.from_quat(q)
173 |
174 | elbow_rot = Rotation.from_quat(arm._ik_elbow.get_quaternion())
175 | target_elbow_quat = (r * elbow_rot).as_quat()
176 |
177 | arm._ik_elbow_target.set_quaternion(target_elbow_quat, relative_to)
178 |
179 | ik_result, joint_values = sim.simCheckIkGroup(
180 | arm._elbow_ik_group, [j.get_handle() for j in arm.joints])
181 | if ik_result == sim.sim_ikresult_fail:
182 | raise IKError('IK failed. Perhaps the distance between the tip '
183 | 'and target was too large.')
184 | elif ik_result == sim.sim_ikresult_not_performed:
185 | raise IKError('IK not performed.')
186 | return joint_values
187 |
188 | def action(self, scene: Scene, action: np.ndarray):
189 | """Performs action using IK.
190 |
191 | :param scene: CoppeliaSim scene.
192 | :param action: Must be in the form [ee_pose, elbow_angle] with a len of 8.
193 | """
194 | arm = scene.robot.arm
195 | if not hasattr(arm, '_ik_elbow_target'):
196 | self.setup_arm(arm)
197 | assert_action_shape(action, self._action_shape)
198 | assert_unit_quaternion(action[3:-1])
199 | angle = action[-1]
200 | ee_action = action[:-1]
201 | if not self._absolute_mode and self._frame != RelativeFrame.EE:
202 | ee_action = calculate_delta_pose(scene.robot, ee_action)
203 | relative_to = None if self._frame == RelativeFrame.WORLD else arm.get_tip()
204 |
205 | assert relative_to is None # NOT IMPLEMENTED
206 |
207 | if self._orientation_state == EEOrientationState.FIXED:
208 | ee_action[3:] = np.array([0, 1, 0, 0])
209 | if self._orientation_state == EEOrientationState.KEEP:
210 | ee_action[3:] = arm._ik_tip.get_quaternion()
211 |
212 | try:
213 | joint_positions = self.solve_ik_via_jacobian(
214 | arm,
215 | position=ee_action[:3], quaternion=ee_action[3:], angle=angle,
216 | relative_to=relative_to
217 | )
218 | arm.set_joint_target_positions(joint_positions)
219 | except IKError as e:
220 | raise InvalidActionError(
221 | 'Could not perform IK via Jacobian; most likely due to current '
222 | 'end-effector pose being too far from the given target pose. '
223 | 'Try limiting/bounding your action space.') from e
224 |
225 | done = False
226 | prev_values = None
227 | max_steps = 10
228 | steps = 0
229 |
230 | while not done and steps < max_steps:
231 | scene.step()
232 | cur_positions = arm.get_joint_positions()
233 | reached = np.allclose(cur_positions, joint_positions, atol=0.01)
234 | not_moving = False
235 | if prev_values is not None:
236 | not_moving = np.allclose(
237 | cur_positions, prev_values, atol=0.001)
238 | prev_values = cur_positions
239 | done = reached or not_moving
240 | steps += 1
241 |
242 | def action_shape(self, _: Scene) -> tuple:
243 | return self._action_shape
244 |
245 | class ERJointViaIK(ArmActionMode):
246 | """High-level action where target EE pose + Elbow angle is given in ER
247 | space (End-Effector and Elbow) and reached via IK.
248 |
249 | Given a target pose, IK via inverse Jacobian is performed. This requires
250 | the target pose to be close to the current pose, otherwise the action
251 | will fail. It is up to the user to constrain the action to
252 | meaningful values.
253 |
254 | The decision to apply collision checking is a crucial trade-off!
255 | With collision checking enabled, you are guaranteed collision-free paths,
256 | but this may not be applicable for tasks that do require some collision.
257 | E.g. using this mode on a pushing task will mean that the generated
258 | path will actively avoid pushing the object.
259 | """
260 |
261 | def __init__(self,
262 | absolute_mode: bool = True,
263 | frame: RelativeFrame = RelativeFrame.WORLD,
264 | collision_checking: bool = False,
265 | orientation_state: EEOrientationState = None,
266 | commanded_joint : int = 0,
267 | eps : float = 1e-3,
268 | delta_angle : bool = False):
269 | self._absolute_mode = absolute_mode
270 | self._frame = frame
271 | self._collision_checking = collision_checking
272 | self._orientation_state = orientation_state
273 | self._action_shape = (8,)
274 | self._excl_j_idx = commanded_joint
275 | self.EPS = eps
276 | self.delta_angle = delta_angle
277 |
278 | def solve_ik_via_jacobian(self,
279 | arm,
280 | position: Union[List[float], np.ndarray],
281 | euler: Union[List[float], np.ndarray] = None,
282 | quaternion: Union[List[float], np.ndarray] = None,
283 | relative_to: Object = None) -> List[float]:
284 | """Solves an IK group and returns the calculated joint values.
285 |
286 | This IK method performs a linearisation around the current robot
287 | configuration via the Jacobian. The linearisation is valid when the
288 | start and goal pose are not too far away, but after a certain point,
289 | linearisation will no longer be valid. In that case, the user is better
290 | off using 'solve_ik_via_sampling'.
291 |
292 | Must specify either rotation in euler or quaternions, but not both!
293 |
294 | :param arm: The CoppeliaSim arm to move.
295 | :param position: The x, y, z position of the ee target.
296 | :param euler: The x, y, z orientation of the ee target (in radians).
297 | :param quaternion: A list containing the quaternion (x,y,z,w).
298 | :param angle: The target angle for the elbow to rotate around the
299 | vector normal to the circle of possible elbow positions (in radians).
300 | :param relative_to: Indicates relative to which reference frame we want
301 | the target pose. Specify None to retrieve the absolute pose,
302 | or an Object relative to whose reference frame we want the pose.
303 | :return: A list containing the calculated joint values.
304 | """
305 | assert len(position) == 3
306 | arm._ik_target.set_position(position, relative_to)
307 |
308 | if euler is not None:
309 | arm._ik_target.set_orientation(euler, relative_to)
310 | elif quaternion is not None:
311 | arm._ik_target.set_quaternion(quaternion, relative_to)
312 |
313 | # Removing the joint controlling the elbow position from IK chain
314 | joint_excl_elbow = arm.joints[:self._excl_j_idx] + arm.joints[self._excl_j_idx+1:]
315 |
316 | ik_result, joint_values = sim.simCheckIkGroup(
317 | arm._ik_group, [j.get_handle() for j in joint_excl_elbow])
318 | if ik_result == sim.sim_ikresult_fail:
319 | raise IKError('IK failed. Perhaps the distance between the tip '
320 | 'and target was too large.')
321 | elif ik_result == sim.sim_ikresult_not_performed:
322 | raise IKError('IK not performed.')
323 | return joint_values
324 |
325 | def action(self, scene: Scene, action: np.ndarray):
326 | """Performs action using IK.
327 |
328 | :param scene: CoppeliaSim scene.
329 | :param action: Must be in the form [ee_pose, elbow_angle] with a len of 8.
330 | """
331 | arm = scene.robot.arm
332 | assert_action_shape(action, self._action_shape)
333 | assert_unit_quaternion(action[3:-1])
334 | angle = action[-1]
335 | ee_action = action[:-1]
336 | if self.delta_angle:
337 | assert not self._absolute_mode, 'Cannot use delta_angle_mode'
338 |
339 | if self._excl_j_idx == 0:
340 | c = arm.joints[0].get_position()[:2]
341 | p = arm.get_tip().get_position()[:2]
342 | a = angle
343 |
344 | new_p = [c[0] + (p[0] - c[0]) * np.cos(a) - (p[1] - c[1]) * np.sin(a),
345 | c[1] + (p[0] - c[0]) * np.sin(a) + (p[1] - c[1]) * np.cos(a)]
346 |
347 | angle_delta_ee = new_p - p
348 | ee_action[:2] += angle_delta_ee
349 |
350 | rot = Rotation.from_quat(ee_action[-4:]).as_euler('xyz')
351 | rot[-1] += angle
352 | ee_action[-4:] = Rotation.from_euler('xyz', rot).as_quat()
353 | elif self._excl_j_idx == 6:
354 | # Get axis for z axis wrt the gripper
355 | v = np.array([0.,0.,1.])
356 | axis = Rotation.from_quat(arm.get_tip().get_quaternion()).apply(v) # Rotation.from_quat(arm.get_tip().get_quaternion()).as_euler('xyz')
357 | # Get rotation given by joint
358 | quat_new = Quaternion(axis=axis, angle=angle)
359 | w_new, x_new, y_new, z_new = quat_new
360 | quat_new = Rotation.from_quat([x_new, y_new, z_new, w_new])
361 | # Add rotation to original rotation
362 | rot = Rotation.from_quat(ee_action[-4:])
363 | rot_new = rot * quat_new
364 | ee_action[-4:] = rot_new.as_quat()
365 | else:
366 | raise NotImplementedError(f'Cannot use delta_angle_mode with joint {self._excl_j_idx}')
367 |
368 | if not self._absolute_mode and self._frame != RelativeFrame.EE:
369 | ee_action = calculate_delta_pose(scene.robot, ee_action)
370 | relative_to = None if self._frame == RelativeFrame.WORLD else arm.get_tip()
371 |
372 | if self._orientation_state == EEOrientationState.FIXED:
373 | ee_action[3:] = np.array([0, 1, 0, 0])
374 | if self._orientation_state == EEOrientationState.KEEP:
375 | ee_action[3:] = arm._ik_tip.get_quaternion()
376 |
377 | try:
378 | # Constrain joint to final position
379 | prev_joint_pos = arm.get_joint_positions()[self._excl_j_idx]
380 | new_joint_pos = angle if self._absolute_mode else prev_joint_pos + angle
381 |
382 | eps = self.EPS
383 | orig_cyclic, orig_interval = arm.joints[self._excl_j_idx].get_joint_interval()
384 | # Making the joint angle valid
385 | if new_joint_pos - eps < orig_interval[0]:
386 | new_joint_pos = orig_interval[0] + eps
387 | if new_joint_pos + eps > orig_interval[0] + orig_interval[1]:
388 | new_joint_pos = orig_interval[0] + orig_interval[1] - eps
389 | # Set target joint interval
390 | arm.joints[self._excl_j_idx].set_joint_interval(orig_cyclic, [new_joint_pos-eps, 2 * eps])
391 |
392 | joint_positions = self.solve_ik_via_jacobian(
393 | arm,
394 | position=ee_action[:3], quaternion=ee_action[3:],
395 | relative_to=relative_to
396 | )
397 | # Restore joint constraints
398 | arm.joints[self._excl_j_idx].set_joint_interval(orig_cyclic, orig_interval)
399 | joint_positions = joint_positions[:self._excl_j_idx] + [new_joint_pos] + joint_positions[self._excl_j_idx:]
400 | arm.set_joint_target_positions(joint_positions)
401 | except IKError as e:
402 | # Restoring joint constraints if there was an error (also restoring to prev_pos first to avoid internal accumulating error)
403 | arm.joints[self._excl_j_idx].set_joint_interval(orig_cyclic, [prev_joint_pos, 2 * eps])
404 | arm.joints[self._excl_j_idx].set_joint_interval(orig_cyclic, orig_interval)
405 | raise InvalidActionError(
406 | 'Could not perform IK via Jacobian; most likely due to current '
407 | 'end-effector pose being too far from the given target pose. '
408 | 'Try limiting/bounding your action space.') from e
409 |
410 | done = False
411 | prev_values = None
412 | max_steps = 50
413 | steps = 0
414 |
415 | # Move until the target joint positions are reached or until we stop moving
416 | # (e.g. when we collide with something)
417 |
418 | while not done and steps < max_steps:
419 | scene.step()
420 | cur_positions = arm.get_joint_positions()
421 | reached = np.allclose(cur_positions, joint_positions, atol=1e-3)
422 | not_moving = False
423 | if prev_values is not None:
424 | not_moving = np.allclose(
425 | cur_positions, prev_values, atol=1e-3)
426 | prev_values = cur_positions
427 | done = reached or not_moving
428 | steps += 1
429 |
430 | def action_shape(self, _: Scene) -> tuple:
431 | return self._action_shape
432 |
433 | class TimeoutEndEffectorPoseViaIK(EndEffectorPoseViaIK):
434 | """
435 | The exact same EE action mode as in RLBench, but with a timeout (max steps) to prevent infinite loops.
436 | """
437 |
438 | def action(self, scene: Scene, action: np.ndarray):
439 | assert_action_shape(action, (7,))
440 | assert_unit_quaternion(action[3:])
441 | if not self._absolute_mode and self._frame != 'end effector':
442 | action = calculate_delta_pose(scene.robot, action)
443 | relative_to = None if self._frame == 'world' else scene.robot.arm.get_tip()
444 |
445 | try:
446 | joint_positions = scene.robot.arm.solve_ik_via_jacobian(
447 | action[:3], quaternion=action[3:], relative_to=relative_to)
448 | scene.robot.arm.set_joint_target_positions(joint_positions)
449 | except IKError as e:
450 | raise InvalidActionError(
451 | 'Could not perform IK via Jacobian; most likely due to current '
452 | 'end-effector pose being too far from the given target pose. '
453 | 'Try limiting/bounding your action space.') from e
454 | done = False
455 | prev_values = None
456 | steps = 0
457 | max_steps = 50
458 | # Move until the target joint positions are reached or until we stop moving
459 | # (e.g. when we collide with something)
460 | while not done and steps < max_steps:
461 | scene.step()
462 | cur_positions = scene.robot.arm.get_joint_positions()
463 | reached = np.allclose(cur_positions, joint_positions, atol=0.01)
464 | not_moving = False
465 | if prev_values is not None:
466 | not_moving = np.allclose(
467 | cur_positions, prev_values, atol=0.001)
468 | prev_values = cur_positions
469 | done = reached or not_moving
470 | steps += 1
--------------------------------------------------------------------------------
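The `action` docstrings above expect an 8-D vector `[x, y, z, qx, qy, qz, qw, angle]`, where the quaternion in positions 3:7 must have unit norm (checked by `assert_unit_quaternion`). A minimal sketch of building such an action for `ERAngleViaIK`/`ERJointViaIK` (the numeric values are illustrative, not from the repo):

```python
import numpy as np
from scipy.spatial.transform import Rotation

# Target EE position (m) and orientation; illustrative values only.
position = np.array([0.30, 0.05, 0.90])
quat_xyzw = Rotation.from_euler('xyz', [np.pi, 0.0, 0.0]).as_quat()  # unit quaternion by construction
elbow = 0.1  # elbow angle / commanded joint value (rad), depending on the action mode

action = np.concatenate([position, quat_xyzw, [elbow]])  # shape (8,)
assert action.shape == (8,)
assert np.isclose(np.linalg.norm(action[3:-1]), 1.0)      # passes assert_unit_quaternion
```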
/rl/env/custom_rlbench_tasks/elbow_angle_task_design.ttt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/elbow_angle_task_design.ttt
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/task_ttms/barista.ttm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/task_ttms/barista.ttm
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/task_ttms/bottle_out_moving_fridge.ttm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/task_ttms/bottle_out_moving_fridge.ttm
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/task_ttms/cup_out_open_cabinet.ttm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/task_ttms/cup_out_open_cabinet.ttm
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/task_ttms/reach_gripper_and_elbow.ttm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/task_ttms/reach_gripper_and_elbow.ttm
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/task_ttms/slide_cup.ttm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazpie/redundancy-action-spaces/7e7d717819f1217ad1cfe9841b08ae104174904b/rl/env/custom_rlbench_tasks/task_ttms/slide_cup.ttm
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .phone_on_base import PhoneOnBase
2 | from .pick_up_cup import PickUpCup
3 | from .put_rubbish_in_bin import PutRubbishInBin
4 | from .stack_wine import StackWine
5 | from .take_umbrella_out_of_umbrella_stand import TakeUmbrellaOutOfUmbrellaStand
6 | from .take_lid_off_saucepan import TakeLidOffSaucepan
7 | from .reach_target import ReachTarget
8 | from .pick_and_lift import PickAndLift
9 | from .meat_off_grill import MeatOffGrill
10 | from .bottle_out_moving_fridge import BottleOutMovingFridge
11 | from .barista import Barista
12 | from .reach_gripper_and_elbow import ReachGripperAndElbow
13 | from .slide_cup import SlideCup
14 | from .cup_out_open_cabinet import CupOutOpenCabinet
15 |
16 |
17 | CUSTOM_TASKS = \
18 | {
19 | # Standard 8 tasks
20 | 'reach_target' : ReachTarget,
21 | 'pick_up_cup' : PickUpCup,
22 | 'take_umbrella_out_of_umbrella_stand' : TakeUmbrellaOutOfUmbrellaStand,
23 | 'take_lid_off_saucepan' : TakeLidOffSaucepan,
24 | 'pick_and_lift': PickAndLift,
25 | 'phone_on_base' : PhoneOnBase,
26 | 'stack_wine' : StackWine,
27 | 'put_rubbish_in_bin' : PutRubbishInBin,
28 | # Abbreviations
29 | 'umbrella_out' : TakeUmbrellaOutOfUmbrellaStand,
30 | 'saucepan' : TakeLidOffSaucepan,
31 | # New tasks
32 | 'reach_elbow_pose' : ReachGripperAndElbow,
33 | 'take_bottle_out_fridge': BottleOutMovingFridge,
34 | 'serve_coffee_obstacles' : Barista,
35 | 'slide_cup_obstacles' : SlideCup,
36 | 'meat_off_grill' : MeatOffGrill,
37 | 'take_cup_out_cabinet' : CupOutOpenCabinet,
38 |
39 | }
--------------------------------------------------------------------------------
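
CUSTOM_TASKS above is a plain name-to-class registry. The environment wrapper (rl/env/rlbench_envs.py, later in this dump) looks a task name up here first and only falls back to RLBench's built-in name_to_task_class otherwise. A short sketch of that resolution, assuming RLBench/PyRep are installed and the snippet is run from the rl/ directory:

from rlbench.utils import name_to_task_class
from env.custom_rlbench_tasks.tasks import CUSTOM_TASKS

def resolve_task(name: str):
    # Prefer the customised (reward-shaped) task if one is registered under this name.
    if name in CUSTOM_TASKS:
        return CUSTOM_TASKS[name]
    return name_to_task_class(name)

print(resolve_task('reach_elbow_pose'))  # -> ReachGripperAndElbow (from the registry)
print(resolve_task('open_drawer'))       # -> falls back to the stock RLBench task class
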
/rl/env/custom_rlbench_tasks/tasks/barista.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | import numpy as np
3 | from pyrep.objects.dummy import Dummy
4 | from pyrep.objects.shape import Shape
5 | from pyrep.objects.object import Object
6 | from pyrep.objects.proximity_sensor import ProximitySensor
7 | from pyrep.const import PrimitiveShape
8 | from rlbench.const import colors
9 | from rlbench.backend.task import Task
10 | from rlbench.backend.conditions import DetectedCondition, ConditionSet, GraspedCondition, Condition
11 | from rlbench.backend.spawn_boundary import SpawnBoundary
12 | from pyquaternion import Quaternion
13 | from pyrep.const import ObjectType
14 |
15 | import os
16 | from os.path import dirname, abspath, join
17 |
18 | def get_z_distance(obj, target=None):
19 | if target is None:
20 | target_dir = - np.pi/2
21 | else:
22 | target_dir = target.get_orientation()[2]
23 | dist = np.abs(obj.get_orientation()[2] - target_dir)
24 | while dist > 2 * np.pi:
25 | dist -= 2 * np.pi
26 | return min(dist, 2*np.pi - dist)
27 |
28 | class UpCondition(Condition):
29 | def __init__(self, obj: Object, threshold = 0.24): # ~14 degrees
30 | """in radians if revoloute, or meters if prismatic"""
31 | self._obj = obj
32 | self._threshold = threshold
33 |
34 | def condition_met(self):
35 | dist = get_z_distance(self._obj)
36 | met = dist <= self._threshold
37 | return met, False
38 |
39 | class Barista(Task):
40 |
41 | def init_task(self) -> None:
42 | self.cup_source = Shape('cup_source')
43 | self.plate_target = Shape('plate')
44 | self.plant = Shape('plant')
45 | self.waypoint = Dummy('waypoint1')
46 | self.bottles = [Shape('bottle'), Shape('bottle0'), Shape('bottle1')]
47 | self.collidables = [s for s in self.pyrep.get_objects_in_tree( object_type=ObjectType.SHAPE) if ('bottle' in Object.get_object_name(s._handle) or 'plant' in Object.get_object_name(s._handle)) and s.is_respondable()]
48 | self.check_collisions = True
49 |
50 | self.success_detector = ProximitySensor('success')
51 |
52 | self.grasped_cond = GraspedCondition(self.robot.gripper, self.cup_source)
53 | self.drops_detector = ProximitySensor('detector')
54 |
55 | self.orientation_cond = UpCondition(self.cup_source, threshold=0.5) # ~30 degrees
56 | self.cup_condition = DetectedCondition(self.cup_source, self.success_detector)
57 | self.register_success_conditions([self.orientation_cond, self.cup_condition])
58 |
59 | self.register_graspable_objects([self.cup_source])
60 |
61 | def init_episode(self, index: int) -> List[str]:
62 | self.init_orientation = self.cup_source.get_orientation()
63 |
64 | return ['coffee mug on plate']
65 |
66 | def variation_count(self) -> int:
67 | return 1
68 |
69 | def is_static_workspace(self) -> bool:
70 | return True
71 |
72 | def reward(self,):
73 | grasped = self.grasped_cond.condition_met()[0]
74 | cup_placed = self.cup_condition.condition_met()[0]
75 | well_oriented = self.orientation_cond.condition_met()[0]
76 |
77 | grasp_reward = orientation_reward = move_reward = 0.0
78 |
79 | if self.check_collisions:
80 | if np.any([self.cup_source.check_collision(c) for c in self.collidables]) or \
81 | np.any([self.robot.arm.check_collision(c) for c in self.collidables]) or \
82 | np.any([self.robot.gripper.check_collision(c) for c in self.collidables]):
83 | self.terminate_episode = True
84 | else:
85 | self.terminate_episode = False
86 |
87 | if grasped:
88 | grasp_reward = 1.0
89 |
90 | if well_oriented:
91 | orientation_reward = 1.0
92 | else:
93 | # Orientation around the vertical axis is locked (very high moment of inertia) -> only the other two rotations matter
94 | orientation_reward = np.exp(-get_z_distance(self.cup_source))
95 |
96 | if cup_placed:
97 | move_reward = 2.0
98 | else:
99 | cup_forward_and_high = (self.cup_source.get_position()[0] >= (self.waypoint.get_position()[0] - 0.01)) and (self.cup_source.get_position()[2] >= (self.waypoint.get_position()[2] - 0.01))
100 |
101 | if cup_forward_and_high:
102 | move_reward = 1.0 + np.exp(-np.linalg.norm(self.cup_source.get_position() - self.success_detector.get_position()))
103 | else:
104 | move_reward = np.exp(-np.linalg.norm(self.cup_source.get_position() - self.waypoint.get_position()))
105 | else:
106 | grasp_reward = np.exp(-np.linalg.norm(self.cup_source.get_position() - self.robot.arm.get_tip().get_position()))
107 |
108 | reward = orientation_reward + grasp_reward + move_reward
109 |
110 | return reward
111 |
112 | def load(self):
113 | ttm_file = join(
114 | dirname(abspath(__file__)),
115 | '../task_ttms/%s.ttm' % self.name)
116 | if not os.path.isfile(ttm_file):
117 | raise FileNotFoundError(
118 | 'The following is not a valid task .ttm file: %s' % ttm_file)
119 | self._base_object = self.pyrep.import_model(ttm_file)
120 | return self._base_object
121 |
--------------------------------------------------------------------------------
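
get_z_distance above folds the difference between two yaw angles into [0, pi], so that values on either side of the +/-pi wrap-around still count as close. A standalone restatement of the wrapping with two worked values:

import numpy as np

def angular_distance(a, b):
    # Same wrapping as get_z_distance: fold the raw difference into [0, pi].
    d = abs(a - b)
    while d > 2 * np.pi:
        d -= 2 * np.pi
    return min(d, 2 * np.pi - d)

print(angular_distance(0.1, -0.1))             # 0.2
print(angular_distance(np.pi - 0.05, -np.pi))  # ~0.05: wraps around instead of ~6.23
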
/rl/env/custom_rlbench_tasks/tasks/bottle_out_moving_fridge.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | import numpy as np
3 | from pyrep.objects.object import Object
4 | from pyrep.objects.proximity_sensor import ProximitySensor
5 | from pyrep.objects.shape import Shape
6 | from rlbench.backend.conditions import DetectedCondition, Condition, GraspedCondition, JointCondition
7 | from rlbench.backend.task import Task
8 | from pyrep.objects.joint import Joint
9 | from pyrep.const import ObjectType
10 |
11 | import os
12 | from os.path import dirname, abspath, join
13 |
14 | class SteadyCondition(Condition):
15 | def __init__(self, obj, correct_position, threshold = 0.025):
16 | self._obj = obj
17 | self._correct_position = correct_position
18 | self._threshold = threshold
19 |
20 | def condition_met(self):
21 | met = np.linalg.norm(
22 | self._obj.get_position() - self._correct_position) <= self._threshold
23 | return met, False
24 |
25 | class NoCollisions:
26 | def __init__(self, pyrep):
27 | self.colliding_shapes = [s for s in pyrep.get_objects_in_tree(
28 | object_type=ObjectType.SHAPE) if s.is_collidable()]
29 |
30 | def __enter__(self):
31 | for s in self.colliding_shapes:
32 | s.set_collidable(False)
33 |
34 | def __exit__(self, *args):
35 | for s in self.colliding_shapes:
36 | s.set_collidable(True)
37 |
38 | class ChangingPointCondition(Condition):
39 | def __init__(self, val=False) -> None:
40 | self.val = val
41 |
42 | def set(self, value):
43 | self.val = value
44 |
45 | def condition_met(self):
46 | return self.val, False
47 |
48 | FRIDGE_OPEN_JOINT_ANGLE = 45 / 180 * np.pi
49 | FRIDGE_INIT_JOINT_ANGLE = 30 / 180 * np.pi # not set explicitly; the initial door angle is roughly this value
50 |
51 | class BottleOutMovingFridge(Task):
52 |
53 | def init_task(self) -> None:
54 | self.bottle = Shape('bottle')
55 | self._success_sensor = ProximitySensor('success')
56 | self._fridge_door = Joint("top_joint")
57 | self._fridge_target_velocity = self._fridge_door.get_joint_target_velocity()
58 |
59 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.bottle)
60 | self._detected_cond = DetectedCondition(self.bottle, self._success_sensor)
61 | self.register_graspable_objects([self.bottle])
62 | self.grasp_init = False
63 |
64 | def init_episode(self, index: int) -> List[str]:
65 | with NoCollisions(self.pyrep):
66 | fridge_joint_pos = np.pi / 2
67 | self._fridge_door.set_joint_position(fridge_joint_pos, disable_dynamics=False)
68 | arm_joint_pos = np.array([+1.983e+01, +2.183e+01, -1.814e+01, -9.465e+01, +9.332e+01, +8.073e+01, -6.668e+01]) / 180 * np.pi
69 | self.robot.arm.set_joint_positions(arm_joint_pos, disable_dynamics=False)
70 |
71 | arm_joint_pos = np.array([+4.067e+01, +2.979e+01, -2.909e+01, -8.598e+01, +1.045e+02, +9.017e+01, -6.300e+01]) / 180 * np.pi
72 | self.robot.arm.set_joint_positions(arm_joint_pos, disable_dynamics=True)
73 |
74 | self._fridge_door.set_joint_target_position(0.)
75 | self._fridge_door.set_joint_target_velocity(self._fridge_target_velocity)
76 |
77 | if self.grasp_init:
78 | # Grasp object
79 | self.robot.gripper.actuate(0, 0.2) # Close gripper
80 | self.robot.gripper.grasp(self.bottle)
81 | assert len(self.robot.gripper.get_grasped_objects()) > 0, "Object not grasped"
82 |
83 | self._steady_cond = SteadyCondition(self.bottle, np.copy(self.bottle.get_position())) # stay within 2.5cm
84 | self._changing_point_cond = ChangingPointCondition()
85 |
86 | self.register_success_conditions([self._grasped_cond, self._detected_cond, self._changing_point_cond])
87 |
88 | return ['put bottle in fridge',
89 | 'place the bottle inside the fridge',
90 | 'open the fridge and put the bottle in there',
91 | 'open the fridge door, pick up the bottle, and leave it in the '
92 | 'fridge']
93 |
94 | def variation_count(self) -> int:
95 | return 1
96 |
97 | def boundary_root(self) -> Object:
98 | return Shape('fridge_root')
99 |
100 | def base_rotation_bounds(self) -> Tuple[List[float], List[float]]:
101 | return [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]
102 |
103 | def is_static_workspace(self) -> bool:
104 | return True
105 |
106 | def reward(self) -> float:
107 | grasped = self._grasped_cond.condition_met()[0]
108 | detected = self._detected_cond.condition_met()[0]
109 | bottle_steady = self._steady_cond.condition_met()[0]
110 | past_door_point = self._changing_point_cond.condition_met()[0]
111 | fridge_open = self._fridge_door.get_joint_position() >= FRIDGE_OPEN_JOINT_ANGLE
112 |
113 | grasp_bottle_reward = fridge_open_reward = reach_target_reward = 0.0
114 |
115 | if grasped:
116 | if past_door_point:
117 | grasp_bottle_reward = 1.0
118 | fridge_open_reward = 1.0
119 |
120 | if detected:
121 | reach_target_reward = 1.0
122 | else:
123 | reach_target_reward = np.exp(
124 | -np.linalg.norm(
125 | self.bottle.get_position()
126 | - self._success_sensor.get_position()
127 | )
128 | )
129 | else:
130 | if bottle_steady:
131 | grasp_bottle_reward = 1.0
132 | if fridge_open:
133 | self._changing_point_cond.set(True) # = grasped + bottle steady + fridge open
134 | fridge_open_reward = 1.0
135 | # Blocking the fridge door helps maintain the Markovianity of the state (the agent can see the fridge is locked in position)
136 | self._fridge_door.set_joint_position(FRIDGE_OPEN_JOINT_ANGLE)
137 | self._fridge_door.set_joint_target_position(FRIDGE_OPEN_JOINT_ANGLE)
138 | self._fridge_door.set_joint_target_velocity(0.)
139 | else:
140 | fridge_open_dist = np.clip(FRIDGE_OPEN_JOINT_ANGLE - self._fridge_door.get_joint_position(), 0, np.inf)
141 | fridge_open_reward = np.exp(-fridge_open_dist)
142 |
143 | reward = grasp_bottle_reward + fridge_open_reward + reach_target_reward
144 |
145 | return reward
146 |
147 | def load(self) -> Object:
148 | ttm_file = join(
149 | dirname(abspath(__file__)),
150 | '../task_ttms/%s.ttm' % self.name)
151 | if not os.path.isfile(ttm_file):
152 | raise FileNotFoundError(
153 | 'The following is not a valid task .ttm file: %s' % ttm_file)
154 | self._base_object = self.pyrep.import_model(ttm_file)
155 | return self._base_object
156 |
--------------------------------------------------------------------------------
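
NoCollisions above is a save/disable/restore context manager: collidable shapes are made non-collidable while init_episode teleports the arm and door joints, and their collidability is restored on exit. A minimal sketch of the same pattern; _ToyShape is a purely illustrative stand-in for a PyRep Shape:

class _ToyShape:
    def __init__(self):
        self._collidable = True
    def is_collidable(self):
        return self._collidable
    def set_collidable(self, value):
        self._collidable = value

class NoCollisionsSketch:
    """Disable collisions for the duration of the with-block, then restore them."""
    def __init__(self, shapes):
        self._shapes = [s for s in shapes if s.is_collidable()]
    def __enter__(self):
        for s in self._shapes:
            s.set_collidable(False)
    def __exit__(self, *exc):
        for s in self._shapes:
            s.set_collidable(True)

shape = _ToyShape()
with NoCollisionsSketch([shape]):
    assert not shape.is_collidable()  # safe to teleport joints here
assert shape.is_collidable()          # collidability restored afterwards
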
/rl/env/custom_rlbench_tasks/tasks/cup_out_open_cabinet.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | from pyrep.objects.dummy import Dummy
3 | from pyrep.objects.joint import Joint
4 | from pyrep.objects.shape import Shape
5 | from pyrep.objects.proximity_sensor import ProximitySensor
6 | from rlbench.backend.task import Task
7 | from rlbench.backend.conditions import DetectedCondition, NothingGrasped, GraspedCondition
8 | from .reach_gripper_and_elbow import DistanceCondition
9 |
10 | import numpy as np
11 |
12 | import os
13 | from os.path import dirname, abspath, join
14 |
15 | OPTIONS = ['left', 'right']
16 |
17 | class CupOutOpenCabinet(Task):
18 |
19 | def init_task(self) -> None:
20 | self.cup = Shape('cup')
21 | self.left_placeholder = Dummy('left_cup_placeholder')
22 | self.waypoint1 = Dummy('waypoint1')
23 | self.waypoint2 = Dummy('waypoint2')
24 |
25 | self.target = self.waypoint3 = Dummy('waypoint3')
26 |
27 | self.left_way_placeholder1 = Dummy('left_way_placeholder1')
28 | self.left_way_placeholder2 = Dummy('left_way_placeholder2')
29 |
30 | self.grasped_cond = GraspedCondition(self.robot.gripper, self.cup)
31 |
32 | self.cup_target_cond = DistanceCondition(self.cup, self.target, 0.05)
33 | self.register_graspable_objects([self.cup])
34 |
35 |
36 | self.register_success_conditions(
37 | [self.grasped_cond, self.cup_target_cond,])
38 |
39 | def init_episode(self, index: int) -> List[str]:
40 | option = 'right' # OPTIONS[index]
41 |
42 | self.joint_target = Joint(f'{option}_joint')
43 | self.handle_wp = self.waypoint1
44 |
45 |
46 | return ['take out a cup from the %s half of the cabinet' % option,
47 | 'open the %s side of the cabinet and get the cup'
48 | % option,
49 | 'grasping the %s handle, open the cabinet, then retrieve the '
50 | 'cup' % option,
51 | 'slide open the %s door on the cabinet and take the cup out'
52 | % option,
53 | 'remove the cup from the %s part of the cabinet' % option]
54 |
55 | def reward(self,) -> float:
56 | cup_grasped = self.grasped_cond.condition_met()[0]
57 | cup_is_out = self.cup_target_cond.condition_met()[0]
58 |
59 | reach_cup_reward = cup_out_reward = 0
60 |
61 | if cup_grasped:
62 | reach_cup_reward = 1.0
63 |
64 | if cup_is_out:
65 | cup_out_reward = 1.0
66 | else:
67 | cup_out_reward = np.exp(-np.linalg.norm(self.cup.get_position() - self.target.get_position()))
68 | else:
69 | reach_cup_reward = np.exp(-np.linalg.norm(self.cup.get_position() - self.robot.arm.get_tip().get_position()))
70 |
71 | reward = reach_cup_reward + cup_out_reward
72 |
73 | return reward
74 |
75 |
76 | def variation_count(self) -> int:
77 | return 2
78 |
79 | def base_rotation_bounds(self) -> Tuple[List[float], List[float]]:
80 | return [0.0, 0.0, -3.14/2], [0.0, 0.0, 3.14/2]
81 |
82 |
83 | def load(self):
84 | ttm_file = join(
85 | dirname(abspath(__file__)),
86 | '../task_ttms/%s.ttm' % self.name)
87 | if not os.path.isfile(ttm_file):
88 | raise FileNotFoundError(
89 | 'The following is not a valid task .ttm file: %s' % ttm_file)
90 | self._base_object = self.pyrep.import_model(ttm_file)
91 | return self._base_object
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/tasks/meat_off_grill.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from pyrep.objects.dummy import Dummy
3 | from pyrep.objects.proximity_sensor import ProximitySensor
4 | from pyrep.objects.shape import Shape
5 | from rlbench.backend.conditions import NothingGrasped, DetectedCondition, GraspedCondition
6 | from rlbench.backend.task import Task
7 | import numpy as np
8 |
9 | MEAT = ['chicken', 'steak']
10 |
11 |
12 | class MeatOffGrill(Task):
13 |
14 | def init_task(self) -> None:
15 | self._steak = Shape('steak')
16 | self._chicken = Shape('chicken')
17 | self._success_sensor = ProximitySensor('success')
18 | self.register_graspable_objects([self._chicken, self._steak])
19 | self._w1 = Dummy('waypoint1')
20 | self._w1z= self._w1.get_position()[2]
21 |
22 | self._nothing_grasped_condition = NothingGrasped(self.robot.gripper)
23 |
24 | def init_episode(self, index: int) -> List[str]:
25 | if index == 0:
26 | self._target = self._chicken
27 | else:
28 | self._target = self._steak
29 | x, y, _ = self._target.get_position()
30 | self._w1.set_position([x, y, self._w1z])
31 | self._detected_condition = DetectedCondition(self._target, self._success_sensor)
32 | self._grasped_condition = GraspedCondition(self.robot.gripper, self._target)
33 |
34 | conditions = [self._nothing_grasped_condition, self._detected_condition]
35 | self.register_success_conditions(conditions)
36 | return ['take the %s off the grill' % MEAT[index],
37 | 'pick up the %s and place it next to the grill' % MEAT[index],
38 | 'remove the %s from the grill and set it down to the side'
39 | % MEAT[index]]
40 |
41 | def reward(self,) -> float:
42 | nothing_grasped = self._nothing_grasped_condition.condition_met()[0]
43 | detected = self._detected_condition.condition_met()[0]
44 | grasped = self._grasped_condition.condition_met()[0]
45 |
46 | reach_reward = move_reward = release_reward = 0.
47 | self.reward_open_gripper = False
48 |
49 | if grasped:
50 | reach_reward = 1.0
51 |
52 | if detected:
53 | move_reward = 1.0
54 | self.reward_open_gripper = True
55 | else:
56 | move_reward = np.exp(
57 | -np.linalg.norm(self._target.get_position() - self._success_sensor.get_position())
58 | )
59 | else:
60 | if detected:
61 | if nothing_grasped:
62 | reach_reward = move_reward = release_reward = 1.0
63 | else:
64 | reach_reward = move_reward = 1.0
65 | self.reward_open_gripper = True
66 | else:
67 | reach_reward = np.exp(
68 | -np.linalg.norm(self._target.get_position() - self.robot.arm.get_tip().get_position())
69 | )
70 |
71 | reward = reach_reward + move_reward + release_reward
72 |
73 | return reward
74 |
75 | def variation_count(self) -> int:
76 | return 2
77 |
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/tasks/phone_on_base.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | from pyrep.objects.proximity_sensor import ProximitySensor
5 | from pyrep.objects.shape import Shape
6 |
7 | from rlbench.backend.conditions import (
8 | DetectedCondition,
9 | GraspedCondition,
10 | NothingGrasped,
11 | )
12 | from rlbench.backend.task import Task
13 |
14 |
15 | class PhoneOnBase(Task):
16 | def init_task(self) -> None:
17 | self.phone = Shape("phone")
18 | self.success_detector = ProximitySensor("success")
19 | self.register_graspable_objects([self.phone])
20 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.phone)
21 | self._nothing_grasped_cond = NothingGrasped(self.robot.gripper)
22 | self._phone_cond = DetectedCondition(self.phone, self.success_detector)
23 |
24 | self.register_success_conditions(
25 | [self._phone_cond, NothingGrasped(self.robot.gripper)]
26 | )
27 |
28 | def init_episode(self, index: int) -> List[str]:
29 | return [
30 | "put the phone on the base",
31 | "put the phone on the stand",
32 | "put the hone on the hub",
33 | "grasp the phone and put it on the base",
34 | "place the phone on the base",
35 | "put the phone back on the base",
36 | ]
37 |
38 | def variation_count(self) -> int:
39 | return 1
40 |
41 | def reward(self) -> float:
42 | grasped = self._grasped_cond.condition_met()[0]
43 | phone_on_base = self._phone_cond.condition_met()[0]
44 | nothing_grasped = self._nothing_grasped_cond.condition_met()[0]
45 |
46 | grasp_phone_reward = move_phone_reward = release_reward = 0
47 |
48 | self.reward_open_gripper = False
49 |
50 | if phone_on_base:
51 | if nothing_grasped:
52 | # phone is not grasped anymore
53 | grasp_phone_reward = move_phone_reward = release_reward = 1.0
54 | else:
55 | # phone is in base, but gripper still holds the phone (or something else)
56 | move_phone_reward = 1.0
57 | grasp_phone_reward = 1.0
58 | self.reward_open_gripper = True
59 | else:
60 | if not grasped:
61 | # reaching the phone
62 | grasp_phone_reward = np.exp(
63 | -np.linalg.norm(
64 | self.phone.get_position()
65 | - self.robot.arm.get_tip().get_position()
66 | )
67 | )
68 | else:
69 | grasp_phone_reward = 1.0
70 | # moving the phone toward base
71 | move_phone_reward = np.exp(
72 | -np.linalg.norm(
73 | self.phone.get_position() - self.success_detector.get_position()
74 | )
75 | )
76 |
77 | reward = grasp_phone_reward + move_phone_reward + release_reward
78 |
79 | return reward
80 |
81 | def get_low_dim_state(self) -> np.ndarray:
82 | # For ad-hoc reward computation, attach reward
83 | state = super().get_low_dim_state()
84 | return state
85 |
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/tasks/pick_and_lift.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import numpy as np
3 | from pyrep.objects.shape import Shape
4 | from pyrep.objects.proximity_sensor import ProximitySensor
5 | from rlbench.backend.task import Task
6 | from rlbench.backend.conditions import DetectedCondition, ConditionSet, \
7 | GraspedCondition
8 | from rlbench.backend.spawn_boundary import SpawnBoundary
9 | from rlbench.const import colors
10 |
11 |
12 | class PickAndLift(Task):
13 |
14 | def init_task(self) -> None:
15 | self.target_block = Shape('pick_and_lift_target')
16 | self.distractors = [
17 | Shape('stack_blocks_distractor%d' % i)
18 | for i in range(2)]
19 | self.register_graspable_objects([self.target_block])
20 | self.boundary = SpawnBoundary([Shape('pick_and_lift_boundary')])
21 | self.success_detector = ProximitySensor('pick_and_lift_success')
22 |
23 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.target_block)
24 | self._detected_cond = DetectedCondition(self.target_block, self.success_detector)
25 | cond_set = ConditionSet([
26 | self._grasped_cond,
27 | self._detected_cond
28 | ])
29 | self.register_success_conditions([cond_set])
30 |
31 | def init_episode(self, index: int) -> List[str]:
32 |
33 | block_color_name, block_rgb = colors[index]
34 | self.target_block.set_color(block_rgb)
35 |
36 | color_choices = np.random.choice(
37 | list(range(index)) + list(range(index + 1, len(colors))),
38 | size=2, replace=False)
39 | for i, ob in enumerate(self.distractors):
40 | name, rgb = colors[color_choices[int(i)]]
41 | ob.set_color(rgb)
42 |
43 | self.boundary.clear()
44 | self.boundary.sample(
45 | self.success_detector, min_rotation=(0.0, 0.0, 0.0),
46 | max_rotation=(0.0, 0.0, 0.0))
47 | for block in [self.target_block] + self.distractors:
48 | self.boundary.sample(block, min_distance=0.1)
49 |
50 | return ['pick up the %s block and lift it up to the target' %
51 | block_color_name,
52 | 'grasp the %s block to the target' % block_color_name,
53 | 'lift the %s block up to the target' % block_color_name]
54 |
55 | def variation_count(self) -> int:
56 | return len(colors)
57 |
58 | def reward(self) -> float:
59 | grasped = self._grasped_cond.condition_met()[0]
60 | detected = self._detected_cond.condition_met()[0]
61 |
62 | pick_reward = lift_reward = 0.
63 |
64 | if not grasped:
65 | pick_reward = np.exp(
66 | -np.linalg.norm(
67 | self.target_block.get_position() - self.robot.arm.get_tip().get_position()
68 | )
69 | )
70 | else:
71 | pick_reward = 1.0
72 |
73 | if detected:
74 | lift_reward = 1.0
75 | else:
76 | lift_reward = np.exp(
77 | -np.linalg.norm(
78 | self.target_block.get_position() - self.success_detector.get_position()
79 | )
80 | )
81 |
82 |
83 | reward = pick_reward + lift_reward
84 |
85 | return reward
86 |
87 | def get_low_dim_state(self) -> np.ndarray:
88 | # One of the few tasks that have a custom low_dim_state function.
89 | return np.concatenate([self.target_block.get_position(), self.success_detector.get_position()], 0)
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/tasks/pick_up_cup.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import numpy as np
3 | from pyrep.objects.shape import Shape
4 | from pyrep.objects.proximity_sensor import ProximitySensor
5 | from rlbench.const import colors
6 | from rlbench.backend.task import Task
7 | from rlbench.backend.conditions import (
8 | DetectedCondition,
9 | NothingGrasped,
10 | GraspedCondition,
11 | )
12 | from rlbench.backend.spawn_boundary import SpawnBoundary
13 |
14 |
15 | class PickUpCup(Task):
16 | def init_task(self) -> None:
17 | self.cup1 = Shape("cup1")
18 | self.cup2 = Shape("cup2")
19 | self.cup1_visual = Shape("cup1_visual")
20 | self.cup2_visual = Shape("cup2_visual")
21 | self.boundary = SpawnBoundary([Shape("boundary")])
22 | self.success_sensor = ProximitySensor("success")
23 | self.register_graspable_objects([self.cup1, self.cup2])
24 |
25 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.cup1)
26 | self._not_detected_cond = DetectedCondition(
27 | self.cup1, self.success_sensor, negated=True
28 | )
29 |
30 | self.register_success_conditions([self._not_detected_cond, self._grasped_cond])
31 |
32 | def init_episode(self, index: int) -> List[str]:
33 | self.variation_index = index
34 | target_color_name, target_rgb = colors[index]
35 |
36 | random_idx = np.random.choice(len(colors))
37 | while random_idx == index:
38 | random_idx = np.random.choice(len(colors))
39 | _, other1_rgb = colors[random_idx]
40 |
41 | self.cup1_visual.set_color(target_rgb)
42 | self.cup2_visual.set_color(other1_rgb)
43 |
44 | self.boundary.clear()
45 | self.boundary.sample(self.cup2, min_distance=0.1)
46 | self.boundary.sample(self.success_sensor, min_distance=0.1)
47 |
48 | return [
49 | "pick up the %s cup" % target_color_name,
50 | "grasp the %s cup and lift it" % target_color_name,
51 | "lift the %s cup" % target_color_name,
52 | ]
53 |
54 | def variation_count(self) -> int:
55 | return len(colors)
56 |
57 | def reward(self) -> float:
58 | grasped = self._grasped_cond.condition_met()[0]
59 | not_detected = self._not_detected_cond.condition_met()[0]
60 |
61 | pick_reward = up_reward = 0.0
62 |
63 | if not grasped:
64 | pick_reward = np.exp(
65 | -np.linalg.norm(
66 | self.cup1.get_position() - self.robot.arm.get_tip().get_position()
67 | )
68 | )
69 | else:
70 | pick_reward = 1.0
71 |
72 | if not_detected:
73 | up_reward = 1.0
74 | else:
75 | distance_from_orig_pose = np.linalg.norm(
76 | self.cup1.get_position() - self.success_sensor.get_position()
77 | )
78 | up_reward = np.tanh(distance_from_orig_pose)
79 |
80 | reward = pick_reward + up_reward
81 |
82 | return reward
83 |
84 | def get_low_dim_state(self) -> np.ndarray:
85 | # For ad-hoc reward computation, attach reward
86 | state = super().get_low_dim_state()
87 | return state
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/tasks/put_rubbish_in_bin.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import numpy as np
3 | import copy
4 | from pyrep.objects.proximity_sensor import ProximitySensor
5 | from pyrep.objects.shape import Shape
6 | from rlbench.backend.task import Task
7 | from rlbench.backend.conditions import DetectedCondition, GraspedCondition
8 |
9 |
10 | class PutRubbishInBin(Task):
11 | def init_task(self):
12 | self.success_sensor = ProximitySensor("success")
13 | self.rubbish = Shape("rubbish")
14 | self.register_graspable_objects([self.rubbish])
15 | self.register_success_conditions(
16 | [DetectedCondition(self.rubbish, self.success_sensor)]
17 | )
18 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.rubbish)
19 | self._detected_cond = DetectedCondition(self.rubbish, self.success_sensor)
20 | self.HIGH_Z_TARGET = 1.05
21 | self.LOW_Z_TARGET = 0.9
22 |
23 | def init_episode(self, index: int) -> List[str]:
24 | tomato1 = Shape("tomato1")
25 | tomato2 = Shape("tomato2")
26 | x1, y1, z1 = tomato2.get_position()
27 | x2, y2, z2 = self.rubbish.get_position()
28 | x3, y3, z3 = tomato1.get_position()
29 | pos = np.random.randint(3)
30 | if pos == 0:
31 | self.rubbish.set_position([x1, y1, z2])
32 | tomato2.set_position([x2, y2, z1])
33 | elif pos == 2:
34 | self.rubbish.set_position([x3, y3, z2])
35 | tomato1.set_position([x2, y2, z3])
36 |
37 | self.lifted = False
38 | # self.reward_open_gripper = False
39 |
40 | return [
41 | "put rubbish in bin",
42 | "drop the rubbish into the bin",
43 | "pick up the rubbish and leave it in the trash can",
44 | "throw away the trash, leaving any other objects alone",
45 | "chuck way any rubbish on the table rubbish",
46 | ]
47 |
48 | def variation_count(self) -> int:
49 | return 1
50 |
51 | def reward(self) -> float:
52 | grasped = self._grasped_cond.condition_met()[0]
53 | detected = self._detected_cond.condition_met()[0]
54 |
55 |
56 | target1_pos = copy.deepcopy(self.success_sensor.get_position())
57 | target1_pos[-1] = self.HIGH_Z_TARGET
58 |
59 | target2_pos = copy.deepcopy(self.success_sensor.get_position())
60 | target2_pos[-1] = self.LOW_Z_TARGET
61 |
62 | grasp_rubbish_reward = move_rubbish_reward = release_reward = 0
63 | self.reward_open_gripper = False
64 |
65 | if not grasped:
66 | if detected:
67 | grasp_rubbish_reward = move_rubbish_reward = release_reward = 1.0
68 | else:
69 | grasp_rubbish_reward = np.exp(
70 | -np.linalg.norm(
71 | self.rubbish.get_position()
72 | - self.robot.arm.get_tip().get_position()
73 | )
74 | )
75 | else:
76 | grasp_rubbish_reward = 1.0
77 |
78 | rubbish_in_bin_area_dist = np.linalg.norm(self.rubbish.get_position()[:2] - target2_pos[:2])
79 | rubbish_in_bin_area = rubbish_in_bin_area_dist < 0.03 # if within 3cm
80 |
81 | rubbish_height = self.rubbish.get_position()[2]
82 |
83 | if rubbish_in_bin_area:
84 | above_bin_dist = np.abs(rubbish_height - self.LOW_Z_TARGET)
85 | rubbish_above_bin = above_bin_dist < 0.06 # if within 6cm
86 |
87 | if rubbish_above_bin:
88 | move_rubbish_reward = 1.0
89 |
90 | self.reward_open_gripper = True
91 | else:
92 | move_rubbish_reward = 0.5 + 0.5 * np.exp(-above_bin_dist) # 0.5 for getting in the area + dist
93 | else:
94 | move_rubbish_reward = 0.5 * np.exp(-np.linalg.norm(self.rubbish.get_position() - target1_pos)) # up to 0.5 -> needs to get in area
95 |
96 | reward = grasp_rubbish_reward + move_rubbish_reward + release_reward
97 |
98 | return reward
99 |
100 | def get_low_dim_state(self) -> np.ndarray:
101 | # For ad-hoc reward computation, attach reward
102 | state = super().get_low_dim_state()
103 | return state
104 |
--------------------------------------------------------------------------------
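
The move term above routes the rubbish through two virtual targets derived from the success sensor: one at HIGH_Z_TARGET above the bin opening (worth at most 0.5), then, once the rubbish is horizontally within 3 cm of the bin, one at LOW_Z_TARGET inside it. A compact restatement of just that term, with made-up positions:

import numpy as np

HIGH_Z, LOW_Z = 1.05, 0.9

def move_rubbish_reward(rubbish_pos, sensor_pos):
    rubbish_pos, sensor_pos = np.asarray(rubbish_pos, float), np.asarray(sensor_pos, float)
    high_target = np.array([sensor_pos[0], sensor_pos[1], HIGH_Z])  # above the bin opening
    low_target = np.array([sensor_pos[0], sensor_pos[1], LOW_Z])    # inside the bin
    if np.linalg.norm(rubbish_pos[:2] - low_target[:2]) < 0.03:     # horizontally over the bin
        drop_dist = abs(rubbish_pos[2] - LOW_Z)
        return 1.0 if drop_dist < 0.06 else 0.5 + 0.5 * np.exp(-drop_dist)
    return 0.5 * np.exp(-np.linalg.norm(rubbish_pos - high_target))  # approach from above first

print(move_rubbish_reward([0.2, 0.2, 1.2], [0.1, 0.3, 0.75]))   # still approaching: < 0.5
print(move_rubbish_reward([0.1, 0.3, 0.92], [0.1, 0.3, 0.75]))  # over the bin, nearly dropped: 1.0
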
/rl/env/custom_rlbench_tasks/tasks/reach_gripper_and_elbow.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import dirname, abspath, join
3 | from typing import List, Tuple
4 | import numpy as np
5 | from pyrep.objects import Object
6 | from pyrep.objects.dummy import Dummy
7 | from pyrep.objects.object import Object
8 | from pyrep.objects.proximity_sensor import ProximitySensor
9 | from pyrep.objects.shape import Shape
10 | from pyrep.objects.joint import Joint
11 | # from rlbench.backend import sim
12 | from rlbench.backend.task import Task
13 | from rlbench.backend.spawn_boundary import SpawnBoundary
14 | from rlbench.backend.conditions import DetectedCondition, Condition
15 | from pyquaternion import Quaternion
16 |
17 |
18 | def get_quaternion_distance(obj, target):
19 | pyq_quaternion = lambda x: Quaternion([x[3],x[0],x[1],x[2]])
20 | q1, q2 = pyq_quaternion(obj.get_quaternion()), pyq_quaternion(target.get_quaternion())
21 | return np.abs((q1.inverse * q2).angle)
22 |
23 | def get_orientation_distance(obj, target):
24 | ori1, ori2 = obj.get_orientation(), target.get_orientation()
25 | diff = np.abs(ori2 - ori1)
26 | for i in range(3):
27 | while diff[i] > 2 * np.pi:
28 | diff[i] -= 2 * np.pi
29 | diff[i] = min(diff[i], 2*np.pi - diff[i])
30 | return diff.mean()
31 |
32 | class DistanceCondition(Condition):
33 | def __init__(self, obj: Object, target: Object, threshold = 0.075):
34 | """in radians if revoloute, or meters if prismatic"""
35 | self._obj = obj
36 | self._target = target
37 | self._threshold = threshold
38 |
39 | def condition_met(self):
40 | met = np.linalg.norm(
41 | self._obj.get_position() - self._target.get_position()) <= self._threshold
42 | return met, False
43 |
44 | class AngleCondition(Condition):
45 | def __init__(self, obj: Object, target: Object, threshold = 0.18): # 10 degrees
46 | """in radians if revoloute, or meters if prismatic"""
47 | self._obj = obj
48 | self._target = target
49 | self._threshold = threshold
50 |
51 | def condition_met(self):
52 | met = get_orientation_distance(self._obj, self._target) <= self._threshold
53 | return met, False
54 |
55 | class ReachGripperAndElbow(Task):
56 |
57 | def init_task(self) -> None:
58 | self.ee_target = Shape("ee_target")
59 | self.elbow_target = Shape("elbow_target")
60 | self.boundaries = Shape("boundary")
61 | self.ee_dummy = Dummy('waypoint1')
62 |
63 | self.ee_success_sensor = ProximitySensor("ee_success")
64 | elbow_success_sensor = ProximitySensor("elbow_success") # not used, because the joint is not detected by the sensor
65 | self.elbow = Joint("Panda_joint4")
66 |
67 | self.ee_condition = DistanceCondition(
68 | self.robot.arm.get_tip(), self.ee_success_sensor, threshold=0.075
69 | )
70 | self.elbow_condition = DistanceCondition(
71 | self.elbow, self.elbow_target, threshold=0.075
72 | )
73 | self.orientation_condition = AngleCondition(
74 | self.robot.arm.get_tip(), self.ee_dummy, threshold = 0.09 # 5 degrees per direction
75 | )
76 |
77 | self.randomize_orientation = False
78 |
79 | self.register_success_conditions([
80 | self.ee_condition,
81 | self.elbow_condition,
82 | self.orientation_condition
83 | ])
84 |
85 | def init_episode(self, index: int) -> List[str]:
86 | b = SpawnBoundary([self.boundaries])
87 | b.sample(self.ee_target, min_distance=0.2,
88 | min_rotation=(0, 0, 0), max_rotation=(0, 0, 0))
89 |
90 | if np.random.randint(2) == 0:
91 | theta = np.random.uniform(np.pi * .8, np.pi * .95)
92 | else:
93 | theta = np.random.uniform(np.pi * .05, np.pi * .2)
94 |
95 | # https://frankaemika.github.io/docs/control_parameters.html
96 | # https://download.franka.de/documents/220010_Product%20Manual_Franka%20Hand_1.2_EN.pdf
97 | r_h = 0.1070 + 0.1270 # wrist + gripper
98 | r_w = 0.0880
99 | r_1, r_2 = 0.3160, 0.3840
100 |
101 | c_1 = Joint("Panda_joint2").get_position()
102 | c_2 = self.ee_success_sensor.get_position()
103 |
104 | d = np.linalg.norm(c_1 - c_2)
105 | n_i = (c_2 - c_1) / d
106 |
107 | # t_i, b_i are the tangent, bitangent
108 | t_i = np.array([0, -1, 0])
109 | if not np.array_equal(n_i, np.array([0, 0, 1])):
110 | t_i = np.cross(n_i, np.array([0, 0, 1]))
111 | t_i = t_i / np.linalg.norm(t_i)
112 | b_i = np.array([-1, 0, 0])
113 | if not np.array_equal(n_i, np.array([0, 1, 0])):
114 | b_i = np.cross(n_i, np.array([0, 1, 0]))
115 | b_i = b_i / np.linalg.norm(b_i)
116 |
117 | if d >= r_1 + r_2:
118 | p_i = c_1 + (c_2 - c_1) * r_1 / d
119 | else:
120 | h = 1/2 + (r_1 ** 2 - r_2 ** 2)/(2 * d ** 2)
121 | c_i = c_1 + h * (c_2 - c_1)
122 | r_i = np.sqrt(r_1 ** 2 - (h * d) ** 2)
123 | p_i = c_i + r_i * (t_i * np.cos(theta) + b_i * np.sin(theta))
124 |
125 | self.elbow_target.set_position(p_i)
126 | # ee_pos = c_2 + r_w * n_i - r_h * b_i
127 | ee_pos = self.ee_target.get_position() + r_w * np.array([1, 0, 0]) - r_h * np.array([0, 0, 1])
128 | self.ee_target.set_position(ee_pos)
129 |
130 | if self.randomize_orientation:
131 | ee_ori = self.ee_target.get_orientation()
132 | ee_ori += np.random.uniform(-np.pi/4, +np.pi/4, size=(3,))
133 | self.ee_target.set_orientation(ee_ori)
134 | return ['']
135 |
136 | def base_rotation_bounds(self) -> Tuple[List[float], List[float]]:
137 | return [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]
138 |
139 | def is_static_workspace(self) -> bool:
140 | return True
141 |
142 | def variation_count(self) -> int:
143 | return 1
144 |
145 | def get_low_dim_state(self) -> np.ndarray:
146 | return np.array([
147 | self.ee_target.get_position(), self.elbow_target.get_position()
148 | ]).flatten()
149 |
150 | def reward(self) -> float:
151 | ee_detected = self.ee_condition.condition_met()[0]
152 | elbow_detected = self.elbow_condition.condition_met()[0]
153 | orientation_detected = self.orientation_condition.condition_met()[0]
154 |
155 | orient_reward = elbow_reward = ee_reward = 0
156 |
157 | if ee_detected:
158 | ee_reward = 1.0
159 | else:
160 | ee_distance = np.linalg.norm(self.ee_target.get_position() - self.robot.arm.get_tip().get_position())
161 | ee_reward = np.exp(-ee_distance)
162 |
163 | if orientation_detected:
164 | orient_reward = 1.0
165 | else:
166 | orient_distance = get_orientation_distance(self.robot.arm.get_tip(), self.ee_dummy)
167 | orient_reward = np.exp(-orient_distance)
168 |
169 | if elbow_detected:
170 | elbow_reward = 1.0
171 | else:
172 | elbow_distance = np.linalg.norm(self.elbow_target.get_position() - self.elbow.get_position())
173 | elbow_reward = np.exp(-elbow_distance)
174 |
175 | reward = elbow_reward + ee_reward + orient_reward
176 |
177 | return reward
178 |
179 | def load(self) -> Object:
180 | ttm_file = join(
181 | dirname(abspath(__file__)),
182 | '../task_ttms/%s.ttm' % self.name)
183 | if not os.path.isfile(ttm_file):
184 | raise FileNotFoundError(
185 | 'The following is not a valid task .ttm file: %s' % ttm_file)
186 | self._base_object = self.pyrep.import_model(ttm_file)
187 | return self._base_object
188 |
--------------------------------------------------------------------------------
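
init_episode above places the elbow target on the circle where two spheres intersect: one of radius r_1 (the upper-arm link) around the shoulder joint, and one of radius r_2 (the forearm link) around the sampled end-effector target; h is the fraction along the centre line to the plane of that circle and r_i its radius. A self-contained numerical check of that construction, using an orthonormal tangent/bitangent pair (the task builds t_i and b_i slightly differently) and made-up centre positions:

import numpy as np

def sphere_intersection_point(c1, r1, c2, r2, theta):
    """A point on the circle where the spheres |x-c1|=r1 and |x-c2|=r2 intersect."""
    c1, c2 = np.asarray(c1, dtype=float), np.asarray(c2, dtype=float)
    d = np.linalg.norm(c2 - c1)
    n = (c2 - c1) / d
    h = 0.5 + (r1 ** 2 - r2 ** 2) / (2 * d ** 2)  # fraction along c1->c2 to the circle's plane
    centre = c1 + h * (c2 - c1)
    radius = np.sqrt(r1 ** 2 - (h * d) ** 2)
    t = np.cross(n, [0.0, 0.0, 1.0])              # tangent
    t = t / np.linalg.norm(t)
    b = np.cross(n, t)                            # bitangent, orthogonal to both n and t
    return centre + radius * (np.cos(theta) * t + np.sin(theta) * b)

# Made-up shoulder/EE-target positions; link lengths r_1, r_2 as in the task above.
c1, r1 = [0.0, 0.0, 0.60], 0.3160
c2, r2 = [0.40, 0.10, 0.50], 0.3840
p = sphere_intersection_point(c1, r1, c2, r2, theta=np.pi * 0.9)
print(np.linalg.norm(p - np.asarray(c1)), np.linalg.norm(p - np.asarray(c2)))  # ~0.3160, ~0.3840
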
/rl/env/custom_rlbench_tasks/tasks/reach_target.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | import numpy as np
3 | from pyrep.objects.shape import Shape
4 | from pyrep.objects.proximity_sensor import ProximitySensor
5 | from rlbench.const import colors
6 | from rlbench.backend.task import Task
7 | from rlbench.backend.spawn_boundary import SpawnBoundary
8 | from rlbench.backend.conditions import DetectedCondition
9 |
10 |
11 | class ReachTarget(Task):
12 |
13 | def init_task(self) -> None:
14 | self.target = Shape('target')
15 | self.distractor0 = Shape('distractor0')
16 | self.distractor1 = Shape('distractor1')
17 | self.boundaries = Shape('boundary')
18 | success_sensor = ProximitySensor('success')
19 |
20 | self._detected_condition = DetectedCondition(self.robot.arm.get_tip(), success_sensor)
21 | self.register_success_conditions([self._detected_condition])
22 |
23 | def init_episode(self, index: int) -> List[str]:
24 | color_name, color_rgb = colors[index]
25 | self.target.set_color(color_rgb)
26 | color_choices = np.random.choice(
27 | list(range(index)) + list(range(index + 1, len(colors))),
28 | size=2, replace=False)
29 | for ob, i in zip([self.distractor0, self.distractor1], color_choices):
30 | name, rgb = colors[i]
31 | ob.set_color(rgb)
32 | b = SpawnBoundary([self.boundaries])
33 | for ob in [self.target, self.distractor0, self.distractor1]:
34 | b.sample(ob, min_distance=0.2,
35 | min_rotation=(0, 0, 0), max_rotation=(0, 0, 0))
36 |
37 | return ['reach the %s target' % color_name,
38 | 'touch the %s ball with the panda gripper' % color_name,
39 | 'reach the %s sphere' %color_name]
40 |
41 | def variation_count(self) -> int:
42 | return len(colors)
43 |
44 | def base_rotation_bounds(self) -> Tuple[List[float], List[float]]:
45 | return [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]
46 |
47 | def get_low_dim_state(self) -> np.ndarray:
48 | # One of the few tasks that have a custom low_dim_state function.
49 | return np.array(self.target.get_position())
50 |
51 | def is_static_workspace(self) -> bool:
52 | return True
53 |
54 | def reward(self) -> float:
55 | success = self._detected_condition.condition_met()[0]
56 |
57 | if success:
58 | reward = 1.0
59 | else:
60 | reward = np.exp(-np.linalg.norm(self.target.get_position() -
61 | self.robot.arm.get_tip().get_position()))
62 | return reward
63 |
64 | def validate(self):
65 | pass
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/tasks/slide_cup.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 | import numpy as np
3 | from pyrep.objects.dummy import Dummy
4 | from pyrep.objects.shape import Shape
5 | from pyrep.objects.object import Object
6 | from pyrep.objects.proximity_sensor import ProximitySensor
7 | from pyrep.const import PrimitiveShape
8 | from rlbench.const import colors
9 | from rlbench.backend.task import Task
10 | from rlbench.backend.conditions import DetectedCondition, ConditionSet, GraspedCondition, Condition
11 | from rlbench.backend.spawn_boundary import SpawnBoundary
12 | from pyquaternion import Quaternion
13 | from pyrep.const import ObjectType
14 |
15 | import os
16 | from os.path import dirname, abspath, join
17 |
18 | class SlideCup(Task):
19 |
20 | def init_task(self) -> None:
21 | self.cup_source = Shape('cup_source')
22 | self.collidables = [s for s in self.pyrep.get_objects_in_tree( object_type=ObjectType.SHAPE) if ('bottle' in Object.get_object_name(s._handle)) and s.is_respondable()]
23 |
24 | self.success_detector = ProximitySensor('success')
25 | self.cup_condition = DetectedCondition(self.cup_source, self.success_detector)
26 | self.register_success_conditions([self.cup_condition])
27 |
28 |
29 | def init_episode(self, index: int) -> List[str]:
30 | self.initial_z = self.cup_source.get_position()[2]
31 |
32 | return ['slide coffee']
33 |
34 | def variation_count(self) -> int:
35 | return 1
36 |
37 | def is_static_workspace(self) -> bool:
38 | return True
39 |
40 | def reward(self,):
41 | cup_placed = self.cup_condition.condition_met()[0]
42 | cup_fallen = self.cup_source.get_position()[2] < (self.initial_z - 0.075)
43 |
44 | close_reward = move_reward = 0.0
45 |
46 | if cup_fallen:
47 | self.terminate_episode = True
48 | else:
49 | self.terminate_episode = False
50 |
51 | if cup_placed:
52 | move_reward = close_reward = 1.0
53 | else:
54 | left_cup_position = (self.cup_source.get_position() - np.array([0,0.05,0]))
55 |
56 | close_distance = np.linalg.norm(left_cup_position - self.robot.arm.get_tip().get_position())
57 |
58 | if close_distance <= 0.025:
59 | close_reward = 1.0
60 |
61 | move_reward = np.exp(-np.linalg.norm(self.cup_source.get_position() - self.success_detector.get_position()))
62 | else:
63 | # position is offset by 5cm to the left, from the human view
64 | close_reward = np.exp(-np.linalg.norm(left_cup_position - self.robot.arm.get_tip().get_position()))
65 |
66 | reward = close_reward + move_reward
67 |
68 | return reward
69 |
70 | def validate(self):
71 | pass
72 |
73 | def load(self):
74 | ttm_file = join(
75 | dirname(abspath(__file__)),
76 | '../task_ttms/%s.ttm' % self.name)
77 | if not os.path.isfile(ttm_file):
78 | raise FileNotFoundError(
79 | 'The following is not a valid task .ttm file: %s' % ttm_file)
80 | self._base_object = self.pyrep.import_model(ttm_file)
81 | return self._base_object
82 |
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/tasks/stack_wine.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import numpy as np
4 | from pyrep.objects.shape import Shape
5 | from pyrep.objects.proximity_sensor import ProximitySensor
6 | from rlbench.backend.task import Task
7 | from rlbench.backend.conditions import (
8 | DetectedCondition,
9 | GraspedCondition,
10 | NothingGrasped,
11 | )
12 |
13 |
14 | class StackWine(Task):
15 | def init_task(self):
16 | self.wine_bottle = Shape("wine_bottle")
17 | self.register_graspable_objects([self.wine_bottle])
18 |
19 | self._success_sensor = ProximitySensor("success")
20 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.wine_bottle)
21 | self._detected_cond = DetectedCondition(self.wine_bottle, self._success_sensor)
22 |
23 | self.register_success_conditions([self._detected_cond])
24 |
25 | def init_episode(self, index: int) -> List[str]:
26 | return [
27 | "stack wine bottle",
28 | "slide the bottle onto the wine rack",
29 | "put the wine away",
30 | "leave the wine on the shelf",
31 | "grasp the bottle and put it away",
32 | "place the wine bottle on the wine rack",
33 | ]
34 |
35 | def variation_count(self) -> int:
36 | return 1
37 |
38 | def base_rotation_bounds(self) -> Tuple[List[float], List[float]]:
39 | return [0, 0, -np.pi / 4.0], [0, 0, np.pi / 4.0]
40 |
41 | def reward(self) -> float:
42 | grasped = self._grasped_cond.condition_met()[0]
43 | detected = self._detected_cond.condition_met()[0]
44 |
45 | grasp_wine_reward = reach_target_reward = 0.0
46 |
47 | if not grasped:
48 | grasp_wine_reward = np.exp(
49 | -np.linalg.norm(
50 | self.wine_bottle.get_position()
51 | - self.robot.arm.get_tip().get_position()
52 | )
53 | )
54 | else:
55 | grasp_wine_reward = 1.0
56 |
57 | if detected:
58 | reach_target_reward = 1.0
59 | else:
60 | reach_target_reward = np.exp(
61 | -np.linalg.norm(
62 | self.wine_bottle.get_position()
63 | - self._success_sensor.get_position()
64 | )
65 | )
66 |
67 | reward = grasp_wine_reward + reach_target_reward
68 |
69 | return reward
70 |
71 | def get_low_dim_state(self) -> np.ndarray:
72 | # For ad-hoc reward computation, attach reward
73 | state = super().get_low_dim_state()
74 | return state
75 |
--------------------------------------------------------------------------------
/rl/env/custom_rlbench_tasks/tasks/take_lid_off_saucepan.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | from pyrep.objects.proximity_sensor import ProximitySensor
5 | from pyrep.objects.shape import Shape
6 | from rlbench.backend.conditions import DetectedCondition, ConditionSet, \
7 | GraspedCondition
8 | from rlbench.backend.task import Task
9 |
10 |
11 | class TakeLidOffSaucepan(Task):
12 |
13 | def init_task(self) -> None:
14 | self.lid = Shape('saucepan_lid_grasp_point')
15 | self.success_detector = ProximitySensor('success')
16 | self.register_graspable_objects([self.lid])
17 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.lid)
18 | self._detected_cond = DetectedCondition(self.lid, self.success_detector)
19 | cond_set = ConditionSet([
20 | self._grasped_cond,
21 | self._detected_cond,
22 | ])
23 | self.register_success_conditions([cond_set])
24 |
25 | def init_episode(self, index: int) -> List[str]:
26 | return ['take lid off the saucepan',
27 | 'using the handle, lift the lid off of the pan',
28 | 'remove the lid from the pan',
29 | 'grip the saucepan\'s lid and remove it from the pan',
30 | 'leave the pan open',
31 | 'uncover the saucepan']
32 |
33 | def variation_count(self) -> int:
34 | return 1
35 |
36 | def reward(self) -> float:
37 | grasped = self._grasped_cond.condition_met()[0]
38 | detected = self._detected_cond.condition_met()[0]
39 |
40 | grasp_lid_reward = lift_lid_reward = 0.0
41 |
42 | if grasped:
43 | grasp_lid_reward = 1.0
44 |
45 | if detected:
46 | lift_lid_reward = 1.0
47 | else:
48 | lift_lid_reward = np.exp(-np.linalg.norm(
49 | self.lid.get_position() - self.success_detector.get_position()))
50 | else:
51 | grasp_lid_reward = np.exp(-np.linalg.norm(
52 | self.lid.get_position() - self.robot.arm.get_tip().get_position()))
53 |
54 | reward = grasp_lid_reward + lift_lid_reward
55 |
56 | return reward
--------------------------------------------------------------------------------
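
The reward above is the template most of these shaped tasks follow: each stage contributes 1.0 once its success condition holds, and an exp(-distance) term towards the next target otherwise, with later stages only paid out once the earlier ones are active. A compact restatement with hypothetical positions, not tied to any particular task:

import numpy as np

def staged_reward(grasped, detected, obj_pos, tip_pos, goal_pos):
    """1.0 per completed stage; exp(-distance) shaping towards the next one."""
    obj_pos, tip_pos, goal_pos = (np.asarray(p, dtype=float) for p in (obj_pos, tip_pos, goal_pos))
    if grasped:
        grasp = 1.0
        move = 1.0 if detected else np.exp(-np.linalg.norm(obj_pos - goal_pos))
    else:
        grasp = np.exp(-np.linalg.norm(obj_pos - tip_pos))
        move = 0.0
    return grasp + move

# Gripper 10 cm from the lid, lid 30 cm from the success detector (made-up numbers).
print(staged_reward(False, False, [0.3, 0.0, 0.8], [0.3, 0.1, 0.8], [0.3, 0.0, 1.1]))  # ~0.905
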
/rl/env/custom_rlbench_tasks/tasks/take_umbrella_out_of_umbrella_stand.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import numpy as np
3 | import copy
4 | from scipy.spatial.transform import Rotation
5 | from pyrep.objects.shape import Shape
6 | from pyrep.objects.proximity_sensor import ProximitySensor
7 | from rlbench.backend.task import Task
8 | from rlbench.backend.conditions import DetectedCondition, GraspedCondition
9 |
10 |
11 | class TakeUmbrellaOutOfUmbrellaStand(Task):
12 | def init_task(self):
13 | self.success_sensor = ProximitySensor("success")
14 | self.umbrella = Shape("umbrella")
15 | self.register_graspable_objects([self.umbrella])
16 | self._grasped_cond = GraspedCondition(self.robot.gripper, self.umbrella)
17 | self._detected_cond = DetectedCondition(
18 | self.umbrella, self.success_sensor, negated=True
19 | )
20 | self.register_success_conditions([self._detected_cond])
21 | self.Z_TARGET = 1.12
22 |
23 | def init_episode(self, index: int) -> List[str]:
24 | return [
25 | "take umbrella out of umbrella stand",
26 | "grasping the umbrella by its handle, lift it up and out of the" " stand",
27 | "remove the umbrella from the stand",
28 | "retrieve the umbrella from the stand",
29 | "get the umbrella",
30 | "lift the umbrella out of the stand",
31 | ]
32 |
33 | def variation_count(self) -> int:
34 | return 1
35 |
36 | def reward(self) -> float:
37 | target_pos = copy.deepcopy(self.success_sensor.get_position())
38 | target_pos[-1] = self.Z_TARGET
39 | dist_from_target = np.linalg.norm(self.umbrella.get_position() - target_pos)
40 |
41 | grasped = self._grasped_cond.condition_met()[0]
42 | lifted = dist_from_target < 0.03 # less than 3cm
43 |
44 | grasp_umbrella_reward = lift_umbrella_reward = 0.0
45 |
46 |
47 | if not grasped:
48 | grasp_umbrella_reward = np.exp(
49 | -np.linalg.norm(
50 | self.umbrella.get_position()
51 | - self.robot.arm.get_tip().get_position()
52 | )
53 | )
54 | else:
55 | grasp_umbrella_reward = 1.0
56 |
57 | if lifted:
58 | lift_umbrella_reward = 1.0
59 | else:
60 | lift_umbrella_reward = np.exp(-dist_from_target)
61 |
62 | reward = grasp_umbrella_reward + lift_umbrella_reward
63 |
64 | return reward
65 |
66 | def get_low_dim_state(self) -> np.ndarray:
67 | # For ad-hoc reward computation, attach reward
68 | state = super().get_low_dim_state()
69 | return state
--------------------------------------------------------------------------------
/rl/env/rlbench_envs.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym import spaces
3 | import numpy as np
4 | from typing import Union, Dict, Tuple
5 | from pathlib import Path
6 | import shutil
7 | import inspect
8 |
9 | from pyrep.const import RenderMode
10 | from pyrep.objects.vision_sensor import VisionSensor
11 | from pyrep.objects.dummy import Dummy
12 | from pyrep.backend import sim
13 | from pyrep.objects import Object
14 |
15 | import rlbench
16 | from rlbench.action_modes.action_mode import MoveArmThenGripper
17 | from rlbench.action_modes.gripper_action_modes import Discrete
18 | from rlbench.backend.task import Task
19 | from rlbench.environment import Environment
20 | from rlbench.observation_config import ObservationConfig, CameraConfig
21 | from rlbench.utils import name_to_task_class
22 | from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaIK, EndEffectorPoseViaPlanning, JointPosition
23 | from rlbench.backend.exceptions import InvalidActionError
24 |
25 | from pyquaternion import Quaternion
26 | from scipy.spatial.transform import Rotation
27 |
28 | from .custom_arm_action_modes import ERAngleViaIK, EEOrientationState, ERJointViaIK, TimeoutEndEffectorPoseViaIK
29 | from .custom_rlbench_tasks.tasks import CUSTOM_TASKS
30 |
31 |
32 | REACH_TASKS = ['reach_target', 'reach_target_and_elbow', 'reach_target_wall', 'reach_target_shelf']
33 | SHAPED_TASKS = list(CUSTOM_TASKS.keys())
34 | CUSTOM_CAM_TASKS = {
35 | 'bottle_in_open_fridge' : 'overhead',
36 | 'take_bottle_out_fridge': 'overhead',
37 | 'bottle_out_moving_fridge': 'overhead',
38 | 'plate_out_open_dishwasher' : 'right_shoulder',
39 | }
40 | reorder_wxyz_to_xyzw = lambda q: [q[1], q[2], q[3], q[0]]
41 |
42 | """
43 | List of possible low-dim states
44 | --------------------------------
45 | joint_velocities
46 | joint_velocities_noise
47 | joint_positions
48 | joint_positions_noise
49 | joint_forces
50 | joint_forces_noise
51 | gripper_open
52 | gripper_pose
53 | gripper_matrix
54 | gripper_joint_positions
55 | gripper_touch_forces
56 | wrist_camera_matrix
57 | record_gripper_closing
58 | task_low_dim_state
59 |
60 | List of possible cameras
61 | ------------------------------
62 | left_shoulder
63 | right_shoulder
64 | overhead
65 | wrist
66 | front
67 |
68 | """
69 |
70 | DEBUG_RESOLUTION = (128,128)
71 |
72 | def decorate_observation(self, observation):
73 | observation.elbow_pos = _get_elbow_pos()
74 | observation.joint_cart_pos = _get_joint_cart_pos()
75 | observation.elbow_angle = _get_elbow_angle()
76 | if hasattr(self, 'lid'):
77 | observation.lid_pos = self.lid.get_position()
78 | return observation
79 |
80 | class RLBench:
81 | def __init__(self, name, observation_mode='state', img_size=84, action_repeat=1, use_angle : bool = True, terminal_invalid : bool = False, correct_invalid : bool = False,
82 | action_mode='ee', render_mode: Union[None, str] = None, goal_centric : bool = False, quaternion_angle : bool = False,
83 | cameras='front|wrist', state_info='gripper_open|gripper_pose|joint_cart_pos|joint_positions|task_low_dim_state',
84 | use_depth : bool = False, debug_viz : bool = False, robot_setup='panda', reward_scale = 1.0, success_scale = 0.0,
85 | pos_step_size=0.05, rot_step_size=0.05, elbow_step_size=0.075, joint_step_size=0.075, opengl3 = False, action_filter : str = 'none',
86 | erj_joint :int = 0, erj_eps : float = 2e-2, erj_delta_angle : bool = True
87 | ):
88 | if name in CUSTOM_TASKS:
89 | task_class = CUSTOM_TASKS[name]
90 | else:
91 | task_class = name_to_task_class(name)
92 | task_class.decorate_observation = decorate_observation
93 |
94 | self.name = name
95 | if action_mode in ['erangle']:
96 | self.replace_scene()
97 |
98 | # Setup observation configs
99 | self._terminal_invalid = terminal_invalid
100 | self._correct_invalid = correct_invalid
101 | self._goal_centric = goal_centric
102 | self._observation_mode = observation_mode
103 | self._cameras = sorted(cameras.split('|'))
104 | self._state_info = sorted(state_info.split('|'))
105 | # if self.name == 'take_lid_off_saucepan':
106 | # self._state_info.append('lid_pos')
107 | # self._state_info.remove('task_low_dim_state')
108 | self._reward_scale = reward_scale
109 | self._success_scale = success_scale
110 |
111 | obs_config = ObservationConfig()
112 | obs_config.set_all_high_dim(False)
113 | obs_config.set_all_low_dim(False)
114 |
115 | if observation_mode in ['state', 'states', 'both']:
116 | for st in self._state_info:
117 | setattr(obs_config, st, True)
118 | if debug_viz and observation_mode != 'both':
119 | if name in CUSTOM_CAM_TASKS:
120 | custom_camera = getattr(obs_config, f'{CUSTOM_CAM_TASKS[name]}_camera')
121 | custom_camera.rgb = True
122 | custom_camera.image_size = DEBUG_RESOLUTION
123 | custom_camera.render_mode = RenderMode.OPENGL3 if name == 'pick_and_lift' or opengl3 else RenderMode.OPENGL
124 | else:
125 | obs_config.front_camera.rgb = True
126 | obs_config.front_camera.image_size = DEBUG_RESOLUTION
127 | obs_config.front_camera.render_mode = RenderMode.OPENGL3 if name == 'pick_and_lift' or opengl3 else RenderMode.OPENGL
128 | if observation_mode in ['vision', 'pixels', 'both']:
129 | for cam in self._cameras:
130 | if name in CUSTOM_CAM_TASKS and cam == 'front':
131 | cam = CUSTOM_CAM_TASKS[self.name]
132 | camera_config = getattr(obs_config, f'{cam}_camera')
133 | camera_config.rgb = True
134 | camera_config.depth = use_depth
135 | camera_config.image_size = (img_size, img_size)
136 | camera_config.render_mode = RenderMode.OPENGL3 if name == 'pick_and_lift' or opengl3 else RenderMode.OPENGL
137 |
138 | # Setup action mode
139 | self._action_repeat = action_repeat
140 | if action_mode == 'erangle':
141 | arm_action_mode = ERAngleViaIK(absolute_mode=False,)
142 | elif action_mode == 'erjoint':
143 | arm_action_mode = ERJointViaIK(absolute_mode=False, commanded_joint=erj_joint, eps=erj_eps, delta_angle=erj_delta_angle)
144 | elif action_mode == 'ee':
145 | arm_action_mode = TimeoutEndEffectorPoseViaIK(absolute_mode=False,)
146 | elif action_mode == 'ee_plan':
147 | arm_action_mode = EndEffectorPoseViaPlanning(absolute_mode=False,)
148 | elif action_mode == 'joint':
149 | arm_action_mode = JointPosition(absolute_mode=False,)
150 | self._action_mode = action_mode.replace('_plan', '')
151 |
152 |
153 | self.POS_STEP_SIZE = pos_step_size
154 | self.ELBOW_STEP_SIZE = elbow_step_size
155 | self.ROT_STEP_SIZE = rot_step_size
156 | self.JOINT_STEP_SIZE = joint_step_size
157 | self.action_filter = action_filter
158 |
159 | action_modality = MoveArmThenGripper(
160 | arm_action_mode=arm_action_mode,
161 | gripper_action_mode=Discrete()
162 | )
163 |
164 | # Launch environment and setup spaces
165 | self._env = Environment(action_modality, obs_config=obs_config, headless=True, robot_setup=robot_setup,
166 | shaped_rewards=True if name in SHAPED_TASKS else False,)
167 | self._env.launch()
168 |
169 | self.task = self._env.get_task(task_class)
170 | _, obs = self.task.reset()
171 |
172 | self._use_angle = use_angle
173 | self._quaternion_angle = quaternion_angle
174 | if use_angle:
175 | act_shape = self._env.action_shape if quaternion_angle or action_mode in ['joint'] else (self._env.action_shape[0]-1, )
176 | else:
177 | act_shape = (self._env.action_shape[0]-4, ) if action_mode != 'joint' else self._env.action_shape
178 | self.act_space = spaces.Dict({'action' : spaces.Box(low=-1.0, high=1.0, shape=act_shape)})
179 |
180 | state_space = list(obs.get_low_dim_data().shape)
181 | if 'lid_pos' in self._state_info:
182 | state_space[0] += 3
183 | if 'elbow_angle' in self._state_info:
184 |             # Size two: the elbow angle is represented as a unit vector (sin, cos)
185 | state_space[0] += 2
186 | if 'elbow_pos' in self._state_info:
187 | state_space[0] += 3
188 | if 'joint_cart_pos' in self._state_info:
189 | state_space[0] += 3 * 7
190 | state_space = tuple(state_space)
191 |
192 | self._env_obs_space = {}
193 | if observation_mode in ['state', 'states', 'both']:
194 | self._env_obs_space['state'] = spaces.Box(low=-np.inf, high=np.inf, shape=state_space)
195 | if debug_viz:
196 | self._env_obs_space["front_rgb"] = spaces.Box(low=0, high=255, shape=(3,) + DEBUG_RESOLUTION, dtype=np.uint8)
197 | if observation_mode in ['vision', 'pixels', 'both']:
198 | for cam in self._cameras:
199 | self._env_obs_space[f"{cam}_rgb"] = spaces.Box(low=0, high=255, shape=(3, img_size, img_size), dtype=np.uint8)
200 | if use_depth:
201 | self._env_obs_space[f"{cam}_depth"] = spaces.Box(low=-np.inf, high=+np.inf, shape=(1, img_size, img_size), dtype=np.float32)
202 |
203 |         # Additional camera used only for rendering / extra visualization
204 | self._render_mode = render_mode
205 | if render_mode is not None:
206 | # Add the camera to the scene
207 | cam_placeholder = Dummy('cam_cinematic_placeholder')
208 | self._gym_cam = VisionSensor.create([640, 360])
209 | self._gym_cam.set_pose(cam_placeholder.get_pose())
210 | if render_mode == 'human':
211 | self._gym_cam.set_render_mode(RenderMode.OPENGL3_WINDOWED)
212 | else:
213 | self._gym_cam.set_render_mode(RenderMode.OPENGL3)
214 |
215 | def _get_state_vec(self, obs):
216 | vec = []
217 | for k in self._state_info:
218 | data = getattr(obs, k)
219 | if type(data) == float:
220 | data = np.array([data])
221 | if len(data.shape) == 0:
222 | data = data.reshape(1,)
223 |
224 | if self._goal_centric and self.name in REACH_TASKS:
225 | if k == 'gripper_pose':
226 | data[:3] = data[:3] - obs.task_low_dim_state.copy()
227 | if k == 'task_low_dim_state':
228 | data = data * 0
229 | vec.append(data)
230 | return vec
231 |
232 | def _extract_obs(self, obs) -> Dict[str, np.ndarray]:
233 | val = {}
234 | if 'state' in self._env_obs_space:
235 | state = np.concatenate(self._get_state_vec(obs)).astype(self.obs_space['state'].dtype)
236 | val['state'] = state
237 | for k in self._env_obs_space:
238 | if k == 'state': continue
239 | # Assuming all other observations are vision-based
240 | if self.name in CUSTOM_CAM_TASKS and 'front' in k:
241 | data = getattr(obs, k.replace('front', CUSTOM_CAM_TASKS[self.name]))
242 | else:
243 | data = getattr(obs, k)
244 | if 'depth' in k:
245 | data = np.expand_dims(data, 0)
246 | if 'rgb' in k:
247 | data = data.transpose(2,0,1)
248 | val[k] = data.astype(self.obs_space[k].dtype)
249 | return val
250 |
251 | @property
252 | def obs_space(self):
253 | spaces = {
254 | **self._env_obs_space,
255 | "is_first": gym.spaces.Box(0, 1, (), dtype=bool),
256 | "is_last": gym.spaces.Box(0, 1, (), dtype=bool),
257 | "is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
258 | "success": gym.spaces.Box(0, 1, (), dtype=bool),
259 | "invalid_action" : gym.spaces.Box(0, np.iinfo(np.uint8).max, (), dtype=np.uint8)
260 | }
261 | return spaces
262 |
263 | def act2env(self, action):
264 | # Gripper
265 | gripper = (action[-1:].copy() + 1) / 2
266 |
267 |         if self.action_filter == 'passband':
268 |             pos_filter = lambda x : (np.abs(x) >= 1e-3) * x # Zero out moves smaller than 1mm
269 |             euler_filter = lambda x : (np.abs(x) >= 1e-2) * x # Zero out angles smaller than 1e-2 rad (~0.573 degrees)
270 |             joint_filter = lambda x : (np.abs(x) >= 1e-2) * x # Zero out angles smaller than 1e-2 rad (~0.573 degrees)
271 |         elif self.action_filter == 'round':
272 |             pos_filter = lambda x : np.round(x, 3) # Round moves to the nearest 1mm
273 |             euler_filter = lambda x : np.round(x, 2) # Round angles to the nearest 1e-2 rad (~0.573 degrees)
274 |             joint_filter = lambda x : np.round(x, 2) # Round angles to the nearest 1e-2 rad (~0.573 degrees)
275 |         elif self.action_filter == 'none':
276 |             pos_filter = euler_filter = joint_filter = lambda x : x # no filter
277 |         else:
278 |             raise NotImplementedError(f'Action filter {self.action_filter!r} is not implemented')
279 |
280 |
281 | if self._action_mode in ['ee', 'erangle', 'erjoint']:
282 | T_IDX = 3
283 |
284 | # Translation
285 | translation = pos_filter(self.POS_STEP_SIZE * action[:T_IDX].copy())
286 |
287 | # Rotation
288 | if self._use_angle:
289 | if self._quaternion_angle:
290 | wxyz_rotation = Quaternion(action[T_IDX:T_IDX+4].copy() / np.linalg.norm(action[T_IDX:T_IDX+4].copy()))
291 | wxyz_rotation = Quaternion(axis=wxyz_rotation.axis, radians=np.clip(wxyz_rotation.radians, 0, self.ROT_STEP_SIZE))
292 | xyzw_rotation = reorder_wxyz_to_xyzw(wxyz_rotation.elements)
293 | else:
294 | euler_rotation = euler_filter(action[T_IDX:T_IDX+3].copy() * self.ROT_STEP_SIZE)
295 | xyzw_rotation = Rotation.from_euler('xyz', euler_rotation).as_quat()
296 | else:
297 |                 wxyz_to = Quaternion([ 1.20908514e-01, 3.09618457e-07, 9.92663622e-01, -1.02228194e-06,])  # fixed reference orientation (wxyz)
298 |                 x,y,z,w = self.task._scene.robot.arm.get_tip().get_quaternion()  # current tip orientation (xyzw)
299 |                 wxyz_from = Quaternion([w,x,y,z])
300 |                 wxyz_rotation = wxyz_to * wxyz_from.inverse  # rotation bringing the tip back to the reference orientation
301 |                 xyzw_rotation = reorder_wxyz_to_xyzw(wxyz_rotation.elements)
302 |
303 | if self._action_mode in ['erangle', 'erjoint']:
304 | elbow_filter = {'erangle' : euler_filter, 'erjoint' : joint_filter }[self._action_mode]
305 | angle = [elbow_filter(action[-2].copy() * self.ELBOW_STEP_SIZE)]
306 | env_action = np.concatenate([translation, xyzw_rotation, angle, gripper])
307 | else:
308 | env_action = np.concatenate([translation, xyzw_rotation, gripper])
309 | elif self._action_mode == 'joint':
310 | joint_action = joint_filter(action[:-1].copy() * self.JOINT_STEP_SIZE)
311 | env_action = np.concatenate([joint_action, gripper])
312 | return env_action
313 |
314 | def step(self, action):
315 | assert (max(action) <= 1) and (min(action) >= -1)
316 |
317 | env_action = self.act2env(action)
318 |
319 | orig_action = action.copy()
320 | reward = 0.0
321 | invalid_action = 0
322 | must_terminate = False
323 | for _ in range(self._action_repeat):
324 | if self._correct_invalid:
325 | raise NotImplementedError('The implementation is outdated')
326 | else:
327 | try:
328 | env_obs, rew, _ = self.task.step(env_action)
329 | self.consecutive_invalid = 0
330 | except InvalidActionError as e:
331 |                         # No reward for an unsuccessful IK step (invalid action)
332 | rew = 0
333 | env_obs = self._prev_env_obs
334 | invalid_action += 1
335 | self.consecutive_invalid += 1
336 |
337 | if getattr(self.task._task, 'reward_open_gripper', False):
338 | if self.consecutive_invalid == 0: # to avoid rewarding invalid states
339 | rew += (orig_action[-1] + 1) / 2 # to be in [0,1]
340 |
341 | reward += rew
342 | self._prev_reward = rew
343 | self._prev_env_obs = env_obs
344 |
345 | if getattr(self.task._task, 'terminate_episode', False):
346 | must_terminate = True
347 | reward = 0.
348 | break
349 |
350 |
351 | is_terminal = (int(self._terminal_invalid) if invalid_action > 0 else 0) or int(must_terminate)
352 | discount = 1 - is_terminal
353 |
354 | if not invalid_action:
355 | success, _ = self.task._task.success()
356 | else:
357 | success = 0
358 |
359 | if success:
360 | self.consecutive_success += 1
361 | else:
362 | self.consecutive_success = 0
363 |
364 |
365 | obs = {
366 | "reward": reward * self._reward_scale + float(success) * self._success_scale,
367 | "is_first": False,
368 | "is_last": True if (self.consecutive_invalid >= 5) or (self.consecutive_success >= 10) or must_terminate else False, # will be handled by timelimit wrapper
369 | "is_terminal": is_terminal, # if not set will be handled by per_episode function
370 | 'action' : orig_action,
371 | 'discount' : discount,
372 | 'success' : success,
373 | "invalid_action" : invalid_action
374 | }
375 | obs.update(self._extract_obs(env_obs))
376 | return obs
377 |
378 | def reset(self, **kwargs):
379 | _, env_obs = self.task.reset(**kwargs)
380 | self.consecutive_invalid = 0
381 | self.consecutive_success = 0
382 |
383 | self._prev_env_obs = env_obs
384 | obs = {
385 | "reward": self.task._task.reward() * self._reward_scale,
386 | "is_first": True,
387 | "is_last": False,
388 | "is_terminal": False,
389 | 'action' : np.zeros_like(self.act_space['action'].sample()),
390 | 'discount' : 1,
391 | "success": False,
392 | "invalid_action" : 0
393 | }
394 | obs.update(self._extract_obs(env_obs))
395 | self._prev_reward = obs['reward']
396 | return obs
397 |
398 | def render(self, mode='human') -> Union[None, np.ndarray]:
399 | if mode != self._render_mode:
400 | raise ValueError(
401 |                 'The render mode must match the render mode selected in the '
402 |                 'constructor. \nI.e. if you want "human" render mode, then '
403 |                 'create the env by calling: '
404 |                 'RLBench(..., render_mode="human").\n'
405 |                 'You passed in mode %s, but expected %s.' % (
406 |                     mode, self._render_mode))
407 | if mode == 'rgb_array':
408 | frame = self._gym_cam.capture_rgb()
409 | frame = np.clip((frame * 255.).astype(np.uint8), 0, 255)
410 | return frame
411 |
412 | def get_demos(self, *args, **kwargs,):
413 | return self.task.get_demos(*args, **kwargs)
414 |
415 | def close(self) -> None:
416 | self._env.shutdown()
417 |
418 | def __del__(self,) -> None:
419 | self.close()
420 |
421 | def replace_scene(self,):
422 | task_src = Path(inspect.getfile(self.__class__)).parent / 'custom_rlbench_tasks' / 'elbow_angle_task_design.ttt'
423 | task_dst = Path(rlbench.__path__[0]) / 'task_design.ttt'
424 |
425 | shutil.copyfile(task_src, task_dst)
426 |
427 | def __getattr__(self, name):
428 | if name in ['obs_space', 'act_space']:
429 | return self.__getattribute__(name)
430 | else:
431 | return getattr(self._env, name)
432 |
433 | def _get_elbow_angle() -> np.ndarray:
434 | try:
435 | joint2_obj = Object.get_object(sim.simGetObjectHandle('Panda_joint2'))
436 | joint7_obj = Object.get_object(sim.simGetObjectHandle('Panda_joint7'))
437 | elbow_obj = Object.get_object(sim.simGetObjectHandle('Panda_joint4'))
438 | w = joint7_obj.get_position() - joint2_obj.get_position()
439 | w = w / np.linalg.norm(w)
440 | a = joint2_obj.get_position()
441 | p = elbow_obj.get_position()
442 |
443 |         # find a vector in the plane that is orthogonal to the y axis.
444 | angle_origin = np.array([-1, 0, 0])
445 | if not np.array_equal(w, np.array([0, 1, 0])):
446 | angle_origin = np.cross(w, np.array([0, 1, 0]))
447 |
448 | # find center of "elbow circle"
449 | alpha = sum([w_i * (p_i - a_i) for w_i, p_i, a_i in zip(w, p, a)]) / sum([w_i ** 2 for w_i in w])
450 | center = [a_i + alpha * w_i for a_i, w_i in zip(a, w)]
451 |
452 | elbow_vector = p - center
453 |
454 | # normalise the vectors
455 | angle_origin = angle_origin / np.linalg.norm(angle_origin)
456 | elbow_vector = elbow_vector / np.linalg.norm(elbow_vector)
457 |
458 |         x = np.dot(w, np.cross(angle_origin, elbow_vector))  # sine of the elbow angle around the joint2-joint7 axis
459 |         y = np.cos(np.arcsin(x))  # matching cosine (its sign is not recovered)
460 |
461 | return np.array([x, y])
462 |     except Exception:  # scene objects may not be available (e.g., before the simulator is launched)
463 | return np.array([0,0])
464 |
465 | def _get_joint_cart_pos():
466 | try:
467 | return np.concatenate([Object.get_object(sim.simGetObjectHandle(f'Panda_joint{i}')).get_position() for i in range(1,8)])
468 |     except Exception:  # scene objects may not be available
469 | return np.zeros([21])
470 |
471 | def _get_elbow_pos() -> np.ndarray:
472 | try:
473 | return Object.get_object(sim.simGetObjectHandle('Panda_joint4')).get_position()
474 |     except Exception:  # scene objects may not be available
475 | return np.array([0,0,0])
476 |
477 |
478 | if __name__ == '__main__':
479 | env = RLBench('reach_target')
480 | obs = env.reset()
481 | print(obs)
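482 | 
483 |     # Minimal smoke-test sketch (assumes a working CoppeliaSim / RLBench install):
484 |     # step the environment with a few random actions sampled from its own action space.
485 |     for _ in range(5):
486 |         action = env.act_space['action'].sample()
487 |         obs = env.step(action)
488 |         print('reward:', obs['reward'], 'success:', bool(obs['success']))
489 |     env.close()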
--------------------------------------------------------------------------------
/rl/envs.py:
--------------------------------------------------------------------------------
1 | from dataclasses import astuple, dataclass
2 | from enum import Enum
3 | from multiprocessing import Pipe, Process
4 | from multiprocessing import set_start_method as mp_set_start_method
5 | from multiprocessing.connection import Connection
6 | from typing import Any, Callable, Iterator, List, Optional, Tuple, Dict
7 |
8 | import numpy as np
9 | from collections import defaultdict, deque
10 | import gym
11 | from pathlib import Path
12 |
13 | import sys
14 | import os
15 | from contextlib import redirect_stderr, redirect_stdout
16 | import time
17 |
18 | import traceback
19 |
20 | ERROR = -1
21 | WAITING = 0
22 |
23 | class MessageType(Enum):
24 | EXCEPTION = -3
25 | RESET = 0
26 | STEP = 2
27 | STEP_RETURN = 3
28 | CLOSE = 4
29 | OBS_SPACE = 5
30 | OBS_SPACE_RETURN = 6
31 | ACT_SPACE = 7
32 | ACT_SPACE_RETURN = 8
33 |
34 | @dataclass
35 | class Message:
36 | type: MessageType
37 | content: Optional[Any] = None
38 |
39 | def __iter__(self) -> Iterator:
40 | return iter(astuple(self))
41 |
42 |
43 | def child_fn(child_id: int, env_fn: Callable, child_conn: Connection, redirect_output_to: str = None) -> None:
44 | np.random.seed(child_id + np.random.randint(0, 2 ** 31 - 1))
45 | if redirect_output_to is not None:
46 | redirect_output_to = Path(redirect_output_to)
47 | os.makedirs(str(redirect_output_to / str(child_id)), exist_ok=True)
48 | with open(str(redirect_output_to / str(child_id) / "out.log"), 'a') as stdout, redirect_stdout(stdout), open(str(redirect_output_to / str(child_id) / "err.log"), 'a') as stderr, redirect_stderr(stderr):
49 | child_env(child_id, env_fn, child_conn)
50 | else:
51 | child_env(child_id, env_fn, child_conn)
52 |
53 | def child_env(child_id, env_fn: Callable, child_conn: Connection,) -> None:
54 | try:
55 | env = env_fn()
56 | while True:
57 | message_type, content = child_conn.recv()
58 | if message_type == MessageType.RESET:
59 | obs = env.reset()
60 | obs['env_idx'] = child_id
61 | obs['env_error'] = 0
62 | child_conn.send(Message(MessageType.STEP_RETURN, obs))
63 | elif message_type == MessageType.STEP:
64 | obs = env.step(content)
65 | # if obs['is_last']:
66 | # obs = env.reset()
67 | obs['env_idx'] = child_id
68 | obs['env_error'] = 0
69 | child_conn.send(Message(MessageType.STEP_RETURN, obs))
70 | elif message_type == MessageType.CLOSE:
71 | child_conn.close()
72 | return
73 | elif message_type == MessageType.OBS_SPACE:
74 | obs_space = env.obs_space
75 | child_conn.send(Message(MessageType.OBS_SPACE_RETURN, obs_space))
76 | elif message_type == MessageType.ACT_SPACE:
77 | act_space = env.act_space
78 | child_conn.send(Message(MessageType.ACT_SPACE_RETURN, act_space))
79 | else:
80 | raise NotImplementedError
81 | sys.stdout.flush(), sys.stderr.flush()
82 | except Exception as e:
83 | child_conn.send(Message(MessageType.EXCEPTION, traceback.format_exc()))
84 |
85 | class MultiProcessEnv(gym.Env):
86 | def __init__(self, env_fn: Callable, num_envs: int, redirect_output_to: str = None) -> None:
87 | super().__init__()
88 | self.env_fn = env_fn
89 | self.num_envs = num_envs
90 | self.processes, self.parent_conns, self.child_conns = [], [], []
91 | mp_set_start_method('fork', force=True)
92 |
93 | self.idx_queue = deque()
94 | self.FPS_TARGET = 100 # FPS
95 | self.TIMEOUT_LIMIT = 20
96 | self.timeouts = np.array([np.inf for _ in range(num_envs)], dtype=np.float64)
97 |
98 | for child_id in range(num_envs):
99 | parent_conn, child_conn = Pipe()
100 | self.parent_conns.append(parent_conn)
101 | self.child_conns.append(child_conn)
102 | p = Process(target=child_fn, args=(child_id, env_fn, child_conn, redirect_output_to), daemon=True)
103 | self.processes.append(p)
104 | for idx, p in enumerate(self.processes):
105 | p.start()
106 | # Waiting for reset to work, to avoid concurrency issues in starting the simulator
107 | self.reset_idx(idx)
108 |
109 | self._obs_space, self._act_space = None, None
110 | print("Observation space:")
111 | for k, v in self.obs_space.items():
112 | print(f" {k} : {v}")
113 | print("Action space:")
114 | for k, v in self.act_space.items():
115 | print(f" {k} : {v}")
116 |
117 | def _clear(self,):
118 | self.timeouts = np.array([np.inf for _ in range(self.num_envs)], dtype=np.float64)
119 | for p in self.parent_conns:
120 | if p.poll():
121 | p.recv()
122 |
123 |     def _restore_process_idx(self, idx : int):
124 | print("Restoring process", idx)
125 | self.child_conns[idx].close()
126 | self.parent_conns[idx].close()
127 | self.processes[idx].kill()
128 |
129 | parent_conn, child_conn = Pipe()
130 | self.parent_conns[idx] = parent_conn
131 | self.child_conns[idx] = child_conn
132 | p = Process(target=child_fn, args=(idx, self.env_fn, child_conn), daemon=True)
133 | self.processes[idx] = p
134 | self.timeouts[idx] = np.inf
135 | p.start()
136 |
137 | def _receive_idx(self, idx : int, check_type : Optional[MessageType], timeout = 1e-8):
138 | timeout = max(1 / (self.FPS_TARGET * self.num_envs), timeout)
139 | t = time.time()
140 | if self.parent_conns[idx].poll(timeout=timeout):
141 |             # a message is available: account for the elapsed time, then read and validate it
142 | elapsed = time.time() - t
143 | self.timeouts -= elapsed # Time passes for all processes
144 | message = self.parent_conns[idx].recv()
145 | if check_type is not None:
146 | if message.type == MessageType.EXCEPTION:
147 | print(f"Received exception from process {idx} : {message.content}")
148 | return {'env_idx' : idx, 'env_error' : 1}
149 | else:
150 | assert message.type == check_type, f"Process: {idx}, received type: {message.type}, request type: {check_type}"
151 | self.timeouts[idx] = self.TIMEOUT_LIMIT
152 | content = message.content
153 | return content
154 | self.timeouts -= timeout # Time passes for all processes
155 | if self.timeouts[idx] > 0:
156 | return WAITING
157 | else:
158 | self._restore_process_idx(idx)
159 | return {'env_idx' : idx, 'env_error' : 1}
160 |
161 | def _receive_all(self, check_type: Optional[MessageType] = None, timeout = 1e-8) -> List[Any]:
162 | contents = [self._receive_idx(idx, check_type=check_type, timeout=timeout) for idx in range(self.num_envs)]
163 | return contents
164 |
165 | def _send_reset(self, idx):
166 | self.parent_conns[idx].send(Message(MessageType.RESET))
167 | self.timeouts[idx] = self.TIMEOUT_LIMIT
168 | self.idx_queue.append(idx)
169 |
170 | def reset_idx(self, idx) -> np.ndarray:
171 | content = ERROR # starting with ERROR to trigger the loop
172 | self.parent_conns[idx].send(Message(MessageType.RESET))
173 | self.timeouts[idx] = self.TIMEOUT_LIMIT
174 |         attempts = 600 / self.TIMEOUT_LIMIT # 600 secs = 10 mins of attempts
175 | while content in [ERROR, WAITING]:
176 | content = self._receive_idx(idx, check_type=MessageType.STEP_RETURN, timeout=self.TIMEOUT_LIMIT)
177 | if content['env_error']:
178 | content = ERROR
179 | attempts -= 1
180 | print(f"Remaining attempts for process {idx}: {attempts}")
181 | self.parent_conns[idx].send(Message(MessageType.RESET))
182 | self.timeouts[idx] = self.TIMEOUT_LIMIT
183 | if attempts <= 0:
184 | raise ChildProcessError("Could not reset environment")
185 | return content
186 |
187 | def reset_all(self) -> np.ndarray:
188 | # Sending messages and setting timeouts
189 | for parent_conn in self.parent_conns:
190 | parent_conn.send(Message(MessageType.RESET))
191 | self.timeouts = np.array([self.TIMEOUT_LIMIT for _ in range(self.num_envs)], dtype=np.float64)
192 |
193 | content = self._receive_all(check_type=MessageType.STEP_RETURN, timeout=self.TIMEOUT_LIMIT)
194 | ret_obs = defaultdict(list)
195 | for c in content:
196 | if c['env_error']:
197 | for k,v in {**self.obs_space, **self.act_space,}.items():
198 | ret_obs[k].append(np.zeros(v.shape, dtype=v.dtype))
199 | ret_obs['reward'].append(0.)
200 | ret_obs['discount'].append(0.)
201 |             # For all cases (also in case of error), so env_idx/env_error stay aligned across envs
202 |             for k,v in c.items():
203 |                 ret_obs[k].append(v)
204 | ret_obs = { k: np.stack(v, axis=0) for k,v in ret_obs.items()}
205 | return ret_obs
206 |
207 | def step_all(self, actions: np.ndarray) -> Dict:
208 | # Sending messages and setting timeouts
209 | for parent_conn, action in zip(self.parent_conns, actions):
210 | parent_conn.send(Message(MessageType.STEP, action))
211 | self.timeouts = np.array([self.TIMEOUT_LIMIT for _ in range(self.num_envs)], dtype=np.float64)
212 |
213 | content = self._receive_all(check_type=MessageType.STEP_RETURN, timeout=self.TIMEOUT_LIMIT)
214 | ret_obs = defaultdict(list)
215 | for c in content:
216 | if c['env_error']:
217 | for k,v in {**self.obs_space, **self.act_space,}.items():
218 | ret_obs[k].append(np.zeros(v.shape, dtype=v.dtype))
219 | ret_obs['reward'].append(0.)
220 | ret_obs['discount'].append(0.)
221 | # For all cases (also in case of error)
222 | for k,v in c.items():
223 | ret_obs[k].append(v)
224 | ret_obs = { k: np.stack(v, axis=0) for k,v in ret_obs.items()}
225 | return ret_obs
226 |
227 | def step_by_idx(self, actions: np.ndarray, idxs : List, requested_steps, ignore_idxs : List = []) -> Dict:
228 | for idx, action in zip(idxs, actions):
229 | if idx in ignore_idxs:
230 | continue
231 | self.parent_conns[idx].send(Message(MessageType.STEP, action))
232 | self.idx_queue.append(idx)
233 | self.timeouts[idx] = self.TIMEOUT_LIMIT
234 | ret_obs = defaultdict(list)
235 | while len(ret_obs['env_idx']) < requested_steps:
236 | idx = self.idx_queue.popleft()
237 | c = self._receive_idx(idx, check_type=MessageType.STEP_RETURN,)
238 | if c == WAITING:
239 | self.idx_queue.append(idx)
240 | continue
241 | if c['env_error']:
242 | for k,v in {**self.obs_space, **self.act_space,}.items():
243 | ret_obs[k].append(np.zeros(v.shape, dtype=v.dtype))
244 | ret_obs['reward'].append(0.)
245 | ret_obs['discount'].append(0.)
246 | # For all cases (also in case of error)
247 | for k,v in c.items():
248 | ret_obs[k].append(v)
249 | ret_obs = { k: np.stack(v, axis=0) for k,v in ret_obs.items()}
250 | return ret_obs
251 |
252 | def step_receive_by_idx(self, actions: np.ndarray, send_idxs : List, recv_idxs : List) -> Dict:
253 | # Send
254 | for idx, action in zip(send_idxs, actions):
255 | self.parent_conns[idx].send(Message(MessageType.STEP, action))
256 | self.timeouts[idx] = self.TIMEOUT_LIMIT
257 | ret_obs = defaultdict(list)
258 | if len(recv_idxs) == 0:
259 | return
260 | # Receive
261 | content = [self._receive_idx(idx, check_type=MessageType.STEP_RETURN, timeout=self.TIMEOUT_LIMIT) for idx in recv_idxs]
262 | for c in content:
263 | if c['env_error']:
264 | for k,v in {**self.obs_space, **self.act_space,}.items():
265 | ret_obs[k].append(np.zeros(v.shape, dtype=v.dtype))
266 | ret_obs['reward'].append(0.)
267 | ret_obs['discount'].append(0.)
268 | # For all cases (also in case of error)
269 | for k,v in c.items():
270 | ret_obs[k].append(v)
271 | ret_obs = { k: np.stack(v, axis=0) for k,v in ret_obs.items()}
272 | return ret_obs
273 |
274 | @property
275 | def obs_space(self,):
276 | while self._obs_space in [None, WAITING]:
277 | self.parent_conns[0].send(Message(MessageType.OBS_SPACE, None))
278 | self.timeouts[0] = self.TIMEOUT_LIMIT
279 | content = self._receive_idx(0, check_type=MessageType.OBS_SPACE_RETURN, timeout=self.TIMEOUT_LIMIT)
280 | self._obs_space = content
281 | if 'env_error' in self._obs_space:
282 | raise ChildProcessError("Problem instantiating the environments")
283 | return self._obs_space
284 |
285 | @property
286 | def act_space(self,):
287 | while self._act_space in [None, WAITING]:
288 | self.parent_conns[0].send(Message(MessageType.ACT_SPACE, None))
289 | self.timeouts[0] = self.TIMEOUT_LIMIT
290 | content = self._receive_idx(0, check_type=MessageType.ACT_SPACE_RETURN, timeout=self.TIMEOUT_LIMIT)
291 | self._act_space = content
292 | if 'env_error' in self._act_space:
293 | raise ChildProcessError("Problem instantiating the environments")
294 | return self._act_space
295 |
296 | def close(self) -> None:
297 | for parent_conn in self.parent_conns:
298 | parent_conn.send(Message(MessageType.CLOSE))
299 | for parent_conn in self.parent_conns:
300 | parent_conn.close()
301 | for p in self.processes:
302 | if p.is_alive():
303 | p.join(5)
304 |
305 | def __del__(self):
306 | self.close()
307 |
308 | class TimeLimit:
309 | def __init__(self, env, duration):
310 | self._env = env
311 | self._duration = duration
312 | self._step = None
313 |
314 | def __getattr__(self, name):
315 | if name.startswith('__'):
316 | raise AttributeError(name)
317 | try:
318 | return getattr(self._env, name)
319 | except AttributeError:
320 | raise ValueError(name)
321 |
322 | def step(self, action):
323 | assert self._step is not None, 'Must reset environment.'
324 | obs = self._env.step(action)
325 | self._step += 1
326 | if self._duration and self._step >= self._duration:
327 | obs['is_last'] = True
328 | self._step = None
329 | return obs
330 |
331 | def reset(self, **kwargs):
332 | self._step = 0
333 | return self._env.reset(**kwargs)
334 |
335 | def reset_with_task_id(self, task_id):
336 | self._step = 0
337 | return self._env.reset_with_task_id(task_id)
338 |
339 | def make(name, obs_type, frame_stack, action_repeat, seed, cfg=None, img_size=84, exorl=False, is_eval=False):
340 | assert obs_type in ['states', 'pixels', 'both']
341 | domain, task = name.split('_', 1)
342 | if domain == 'rlbench':
343 | import env.rlbench_envs as rlbench_envs
344 | return TimeLimit(rlbench_envs.RLBench(task, observation_mode=obs_type, action_repeat=action_repeat, **cfg.env), 200 // action_repeat)
345 | else:
346 |         raise NotImplementedError(f"Unknown domain: {domain}")
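347 | 
348 | 
349 | if __name__ == '__main__':
350 |     # Minimal usage sketch of make() + MultiProcessEnv (assumes RLBench and a virtual display
351 |     # are available); the 'rlbench_reach_target' task and the empty cfg.env are example choices.
352 |     from functools import partial
353 |     from types import SimpleNamespace
354 | 
355 |     cfg = SimpleNamespace(env={})
356 |     env_fn = partial(make, 'rlbench_reach_target', 'states', frame_stack=1, action_repeat=1, seed=0, cfg=cfg)
357 |     menv = MultiProcessEnv(env_fn, num_envs=2)
358 |     first_obs = menv.reset_all()
359 |     print({k: np.asarray(v).shape for k, v in first_obs.items()})
360 |     menv.close()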
--------------------------------------------------------------------------------
/rl/hydra/hydra_logging/custom.yaml:
--------------------------------------------------------------------------------
1 | # A logger config that directs hydra verbose logging to a file
2 | version: 1
3 | formatters:
4 | simple:
5 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
6 | handlers:
7 | file:
8 | class: logging.FileHandler
9 | mode: w
10 | formatter: simple
11 | # relative to the job log directory
12 | filename: exp_hydra_logs/${now:%Y.%m.%d}_${now:%H%M%S}_${experiment}_${agent.name}_${obs_type}.log
13 | delay: true
14 | root:
15 | level: DEBUG
16 | handlers: [file]
17 |
18 | loggers:
19 | hydra:
20 | level: DEBUG
21 |
22 | disable_existing_loggers: false
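23 | 
24 | # (assumed usage) this config can be selected through Hydra's config group override,
25 | # e.g. `python train.py hydra/hydra_logging=custom`, provided this directory is on the config search path.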
--------------------------------------------------------------------------------
/rl/logger.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import datetime
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | import wandb
9 | from termcolor import colored
10 | from torch.utils.tensorboard import SummaryWriter
11 |
12 | COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
13 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
14 | ('episode_reward', 'R', 'float'),
15 | ('fps', 'FPS', 'float'), ('total_time', 'T', 'time')]
16 |
17 | COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
18 | ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
19 | ('episode_reward', 'R', 'float'),
20 | ('total_time', 'T', 'time')]
21 |
22 |
23 | class AverageMeter(object):
24 | def __init__(self):
25 | self._sum = 0
26 | self._count = 0
27 |
28 | def update(self, value, n=1):
29 | self._sum += value
30 | self._count += n
31 |
32 | def value(self):
33 | return self._sum / max(1, self._count)
34 |
35 |
36 | class MetersGroup(object):
37 | def __init__(self, csv_file_name, formating, use_wandb):
38 | self._csv_file_name = csv_file_name
39 | self._formating = formating
40 | self._meters = defaultdict(AverageMeter)
41 | self._csv_file = None
42 | self._csv_writer = None
43 | self.use_wandb = use_wandb
44 |
45 | def log(self, key, value, n=1):
46 | self._meters[key].update(value, n)
47 |
48 | def _prime_meters(self):
49 | data = dict()
50 | for key, meter in self._meters.items():
51 | if key.startswith('train'):
52 | key = key[len('train') + 1:]
53 | else:
54 | key = key[len('eval') + 1:]
55 | key = key.replace('/', '_')
56 | data[key] = meter.value()
57 | return data
58 |
59 | def _remove_old_entries(self, data):
60 | rows = []
61 | with self._csv_file_name.open('r') as f:
62 | reader = csv.DictReader(f)
63 | for row in reader:
64 | if 'episode' in row:
65 |                     # BUGFIX: covers weird cases where CSVs are badly written
66 |                     if row['episode'] == '':
67 |                         rows.append(row)
68 |                         continue
69 |                     if row['episode'] is None:
70 |                         continue
71 | if float(row['episode']) >= data['episode']:
72 | break
73 | rows.append(row)
74 | with self._csv_file_name.open('w') as f:
75 | # To handle CSV that have more keys than new data
76 | keys = set(data.keys())
77 |             if len(rows) > 0: keys = keys | set(rows[0].keys())
78 | keys = sorted(list(keys))
79 | #
80 | writer = csv.DictWriter(f,
81 | fieldnames=keys,
82 | restval=0.0)
83 | writer.writeheader()
84 | for row in rows:
85 | writer.writerow(row)
86 |
87 | def _dump_to_csv(self, data):
88 | if self._csv_writer is None:
89 | should_write_header = True
90 | if self._csv_file_name.exists():
91 | self._remove_old_entries(data)
92 | should_write_header = False
93 |
94 | self._csv_file = self._csv_file_name.open('a')
95 | self._csv_writer = csv.DictWriter(self._csv_file,
96 | fieldnames=sorted(data.keys()),
97 | restval=0.0)
98 | if should_write_header:
99 | self._csv_writer.writeheader()
100 |
101 | # To handle components that start training later
102 |         # (restval only covers the case where data has fewer keys than the CSV)
103 | if self._csv_writer.fieldnames != sorted(data.keys()) and \
104 | len(self._csv_writer.fieldnames) < len(data.keys()):
105 | self._csv_file.close()
106 | self._csv_file = self._csv_file_name.open('r')
107 | dict_reader = csv.DictReader(self._csv_file)
108 | rows = [row for row in dict_reader]
109 | self._csv_file.close()
110 | self._csv_file = self._csv_file_name.open('w')
111 | self._csv_writer = csv.DictWriter(self._csv_file,
112 | fieldnames=sorted(data.keys()),
113 | restval=0.0)
114 | self._csv_writer.writeheader()
115 | for row in rows:
116 | self._csv_writer.writerow(row)
117 |
118 | self._csv_writer.writerow(data)
119 | self._csv_file.flush()
120 |
121 | def _format(self, key, value, ty):
122 | if ty == 'int':
123 | value = int(value)
124 | return f'{key}: {value}'
125 | elif ty == 'float':
126 | return f'{key}: {value:.04f}'
127 | elif ty == 'time':
128 | value = str(datetime.timedelta(seconds=int(value)))
129 | return f'{key}: {value}'
130 | else:
131 |             raise ValueError(f'invalid format type: {ty}')
132 |
133 | def _dump_to_console(self, data, prefix):
134 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
135 | pieces = [f'| {prefix: <14}']
136 | for key, disp_key, ty in self._formating:
137 | value = data.get(key, 0)
138 | pieces.append(self._format(disp_key, value, ty))
139 | print(' | '.join(pieces))
140 |
141 | def _dump_to_wandb(self, data):
142 | wandb.log(data)
143 |
144 | def dump(self, step, prefix):
145 | if len(self._meters) == 0:
146 | return
147 | data = self._prime_meters()
148 | data['frame'] = step
149 | if self.use_wandb:
150 | wandb_data = {prefix + '/' + key: val for key, val in data.items()}
151 | self._dump_to_wandb(data=wandb_data)
152 | self._dump_to_console(data, prefix)
153 | self._meters.clear()
154 |
155 |
156 | class Logger(object):
157 | def __init__(self, log_dir, use_tb, use_wandb):
158 | self._log_dir = log_dir
159 | self._train_mg = MetersGroup(log_dir / 'train.csv',
160 | formating=COMMON_TRAIN_FORMAT,
161 | use_wandb=use_wandb)
162 | self._eval_mg = MetersGroup(log_dir / 'eval.csv',
163 | formating=COMMON_EVAL_FORMAT,
164 | use_wandb=use_wandb)
165 | if use_tb:
166 | self._sw = SummaryWriter(str(log_dir / 'tb'))
167 | else:
168 | self._sw = None
169 | self.use_wandb = use_wandb
170 |
171 | def _try_sw_log(self, key, value, step):
172 | if self._sw is not None:
173 | self._sw.add_scalar(key, value, step)
174 |
175 | def log(self, key, value, step):
176 | assert key.startswith('train') or key.startswith('eval')
177 | if type(value) == torch.Tensor:
178 | value = value.item()
179 | self._try_sw_log(key, value, step)
180 | mg = self._train_mg if key.startswith('train') else self._eval_mg
181 | mg.log(key, value)
182 |
183 | def log_metrics(self, metrics, step, ty):
184 | for key, value in metrics.items():
185 | self.log(f'{ty}/{key}', value, step)
186 |
187 | def dump(self, step, ty=None):
188 | if ty is None or ty == 'eval':
189 | self._eval_mg.dump(step, 'eval')
190 | if ty is None or ty == 'train':
191 | self._train_mg.dump(step, 'train')
192 |
193 | def log_and_dump_ctx(self, step, ty):
194 | return LogAndDumpCtx(self, step, ty)
195 |
196 | def log_video(self, data, step):
197 | if self._sw is not None:
198 | for k, v in data.items():
199 | self._sw.add_video(k, v, global_step=step, fps=15)
200 | if self.use_wandb:
201 | for k, v in data.items():
202 | if type(v) == torch.Tensor:
203 | v = v.cpu()
204 | v = np.uint8(v)
205 | wandb.log({k: wandb.Video(v, fps=15, format="gif")})
206 |
207 |
208 | class LogAndDumpCtx:
209 | def __init__(self, logger, step, ty):
210 | self._logger = logger
211 | self._step = step
212 | self._ty = ty
213 |
214 | def __enter__(self):
215 | return self
216 |
217 | def __call__(self, key, value):
218 | self._logger.log(f'{self._ty}/{key}', value, self._step)
219 |
220 | def __exit__(self, *args):
221 | self._logger.dump(self._step, self._ty)
222 |
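223 | 
224 | if __name__ == '__main__':
225 |     # Minimal usage sketch (the log directory below is an arbitrary example path):
226 |     # metrics logged under a 'train/' prefix are averaged per key and dumped in the
227 |     # COMMON_TRAIN_FORMAT console/CSV layout when the context exits.
228 |     from pathlib import Path
229 | 
230 |     log_dir = Path('/tmp/logger_demo')
231 |     log_dir.mkdir(parents=True, exist_ok=True)
232 |     logger = Logger(log_dir, use_tb=False, use_wandb=False)
233 |     with logger.log_and_dump_ctx(step=1000, ty='train') as log:
234 |         log('episode_reward', 1.23)
235 |         log('episode_length', 200)
236 |         log('fps', 30.0)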
--------------------------------------------------------------------------------
/rl/np_replay.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import datetime
3 | import io
4 | import pathlib
5 | import uuid
6 | import os
7 |
8 | import numpy as np
9 | from gym.spaces import Dict
10 | import random
11 | from torch.utils.data import IterableDataset, DataLoader
12 | import torch
13 | import utils
14 | import traceback
15 |
16 | SIG_FAILURE = -1
17 |
18 | def get_length(filename):
19 | if "-" in str(filename):
20 | length = int(str(filename).split('-')[-1])
21 | else:
22 | length = int(str(filename).split('_')[-1])
23 | return length
24 |
25 | def get_idx(filename):
26 | if "-" in str(filename):
27 |         idx = int(str(filename).split('-')[0])
28 |     else:
29 |         idx = int(str(filename).split('_')[0])
30 |     return idx
31 |
32 | def on_fn(): return collections.defaultdict(list) # module-level function instead of a lambda (lambdas are not picklable)
33 |
34 | class ReplayBuffer(IterableDataset):
35 |
36 | def __init__(
37 | self, data_specs, meta_specs, directory, length=20, capacity=0, ongoing=False, minlen=1, maxlen=0,
38 | prioritize_ends=False, device='cuda', load_first=False, delete_old_storage=False, save_episodes=True, **kwargs):
39 | self._directory = pathlib.Path(directory).expanduser()
40 | self._directory.mkdir(parents=True, exist_ok=True)
41 | self._capacity = capacity
42 | self._ongoing = ongoing
43 | self._minlen = minlen
44 | self._maxlen = maxlen
45 | self._prioritize_ends = prioritize_ends
46 | # self._random = np.random.RandomState()
47 | # filename -> key -> value_sequence
48 |
49 | self._save_episodes = save_episodes
50 | self._last_added_idx = 0
51 |
52 | if delete_old_storage:
53 | raise NotImplementedError("May want this feature")
54 |
55 | self._episode_lens = np.array([])
56 | self._complete_eps = {}
57 | for spec_group in [data_specs, meta_specs]:
58 | for spec in spec_group:
59 | if type(spec) in [dict, Dict]:
60 | for k,v in spec.items():
61 | self._complete_eps[k] = []
62 | else:
63 | self._complete_eps[spec.name] = []
64 |
65 | # load episodes
66 | self._total_episodes, self._total_steps = count_episodes(directory)
67 | self._loaded_episodes = 0
68 | self._loaded_steps = 0
69 | for f in load_filenames(self._directory, capacity, minlen, load_first=load_first):
70 | self.store_episode(filename=f)
71 |
72 | # worker -> key -> value_sequence
73 | self._length = length
74 | self._ongoing_eps = collections.defaultdict(on_fn)
75 | self._data_specs = data_specs
76 | self._meta_specs = meta_specs
77 | self.device = device
78 | try:
79 | assert self._minlen <= self._length <= self._maxlen
80 | except:
81 | print("Sampling sequences with fixed length ", length)
82 | self._minlen = self._maxlen = self._length = length
83 |
84 | def __len__(self):
85 | return self._total_steps
86 |
87 | def preallocate_memory(self, max_size):
88 | self._preallocated_mem = collections.defaultdict(list)
89 | for spec in self._data_specs:
90 | if type(spec) in [dict, Dict]:
91 | for k,v in spec.items():
92 | for _ in range(max_size):
93 | self._preallocated_mem[k].append(np.empty(list(v.shape), v.dtype))
94 | self._preallocated_mem[k][-1].fill(0.)
95 | else:
96 | for _ in range(max_size):
97 |                     self._preallocated_mem[spec.name].append(np.empty(list(spec.shape), spec.dtype))
98 | self._preallocated_mem[spec.name][-1].fill(0.)
99 |
100 | @property
101 | def stats(self):
102 | return {
103 | 'total_steps': self._total_steps,
104 | 'total_episodes': self._total_episodes,
105 | 'loaded_steps': self._loaded_steps,
106 | 'loaded_episodes': self._loaded_episodes,
107 | }
108 |
109 | def add(self, time_step, meta, idx=0):
110 | ### Useful if there was any failure in the environment
111 | if time_step == SIG_FAILURE:
112 | episode = self._ongoing_eps[idx]
113 | episode.clear()
114 | print("Discarding episode from process", idx)
115 | return
116 | ####
117 |
118 | episode = self._ongoing_eps[idx]
119 |
120 | def add_to_episode(name, data, spec):
121 | value = data[name]
122 | if np.isscalar(value):
123 | value = np.full(spec.shape, value, spec.dtype)
124 |             assert spec.shape == value.shape and spec.dtype == value.dtype, f"Expected ({spec.dtype}, {spec.shape}), received ({value.dtype}, {value.shape})"
125 | ### Deallocate preallocated memory
126 | if getattr(self, '_preallocated_mem', False):
127 | if len(self._preallocated_mem[name]) > 0:
128 | tmp = self._preallocated_mem[name].pop()
129 | del tmp
130 | else:
131 | # Out of pre-allocated memory
132 | del self._preallocated_mem
133 | ###
134 | episode[name].append(value)
135 |
136 | for spec in self._data_specs:
137 | if type(spec) in [dict, Dict]:
138 | for k,v in spec.items():
139 | add_to_episode(k, time_step, v)
140 | else:
141 | add_to_episode(spec.name, time_step, spec)
142 | for spec in self._meta_specs:
143 | if type(spec) in [dict, Dict]:
144 | for k,v in spec.items():
145 | add_to_episode(k, meta, v)
146 | else:
147 | add_to_episode(spec.name, meta, spec)
148 | if type(time_step) in [dict, Dict]:
149 | if time_step['is_last']:
150 | self.add_episode(episode)
151 | episode.clear()
152 | else:
153 | if time_step.last():
154 | self.add_episode(episode)
155 | episode.clear()
156 |
157 | def add_episode(self, episode):
158 | length = eplen(episode)
159 | if length < self._minlen:
160 | print(f'Skipping short episode of length {length}.')
161 | return
162 | self._total_steps += length
163 | self._total_episodes += 1
164 | episode = {key: convert(value) for key, value in episode.items()}
165 | if self._save_episodes:
166 | filename = self.save_episode(self._directory, episode)
167 | self.store_episode(episode=episode)
168 |
169 | def store_episode(self, filename=None, episode=None):
170 | if filename is not None:
171 | episode = load_episode(filename)
172 | if not episode:
173 | return False
174 | length = eplen(episode)
175 |
176 | # Enforce limit
177 | while self._loaded_steps + length > self._capacity:
178 | for k in self._complete_eps:
179 | self._complete_eps[k].pop(0)
180 | removed_len, self._episode_lens = self._episode_lens[0], self._episode_lens[1:]
181 | self._loaded_steps -= removed_len
182 | self._loaded_episodes -= 1
183 |
184 | # add episode
185 | for k,v in episode.items():
186 | self._complete_eps[k].append(v)
187 | self._episode_lens = np.append(self._episode_lens, length)
188 | self._loaded_steps += length
189 | self._loaded_episodes += 1
190 |
191 | return True
192 |
193 | def __iter__(self):
194 | while True:
195 | sequences, batch_size, batch_length = self._loaded_episodes, self.batch_size, self._length
196 |
197 | b_indices = np.random.randint(0, sequences, size=batch_size)
198 | t_indices = np.random.randint(np.zeros(batch_size), self._episode_lens[b_indices]-batch_length+1, size=batch_size)
199 | t_ranges = np.repeat( np.expand_dims(np.arange(0, batch_length,), 0), batch_size, axis=0) + np.expand_dims(t_indices, 1)
200 |
201 | chunk = {}
202 | for k in self._complete_eps:
203 | chunk[k] = np.stack([self._complete_eps[k][b][t] for b,t in zip(b_indices, t_ranges)])
204 | yield chunk
205 |
206 | @utils.retry
207 | def save_episode(self, directory, episode):
208 | idx = self._total_episodes
209 | timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
210 | identifier = str(uuid.uuid4().hex)
211 | length = eplen(episode)
212 | filename = directory / f'{idx}-{timestamp}-{identifier}-{length}.npz'
213 | with io.BytesIO() as f1:
214 | np.savez_compressed(f1, **episode)
215 | f1.seek(0)
216 | with filename.open('wb') as f2:
217 | f2.write(f1.read())
218 | return filename
219 |
220 | def load_episode(filename):
221 | try:
222 | with filename.open('rb') as f:
223 | episode = np.load(f)
224 | episode = {k: episode[k] for k in episode.keys()}
225 | except Exception as e:
226 | print(f'Could not load episode {str(filename)}: {e}')
227 | return False
228 | return episode
229 |
230 | def count_episodes(directory):
231 | filenames = list(directory.glob('*.npz'))
232 | num_episodes = len(filenames)
233 | if num_episodes == 0 : return 0, 0
234 | if len(filenames) > 0 and "-" in str(filenames[0]):
235 | num_steps = sum(int(str(n).split('-')[-1][:-4]) - 1 for n in filenames)
236 | last_episode = sorted(list(int(n.stem.split('-')[0]) for n in filenames))[-1]
237 | else:
238 | num_steps = sum(int(str(n).split('_')[-1][:-4]) - 1 for n in filenames)
239 | last_episode = sorted(list(int(n.stem.split('_')[0]) for n in filenames))[-1]
240 | return last_episode, num_steps
241 |
242 | def load_filenames(directory, capacity=None, minlen=1, load_first=False):
243 |     # The returned list of filenames is guaranteed to be in
244 |     # temporally sorted order.
245 | filenames = sorted(directory.glob('*.npz'))
246 | if capacity:
247 | num_steps = 0
248 | num_episodes = 0
249 | ordered_filenames = filenames if load_first else reversed(filenames)
250 | for filename in ordered_filenames:
251 | if "-" in str(filename):
252 | length = int(str(filename).split('-')[-1][:-4])
253 | else:
254 | length = int(str(filename).split('_')[-1][:-4])
255 | num_steps += length
256 | num_episodes += 1
257 | if num_steps >= capacity:
258 | break
259 | if load_first:
260 | filenames = filenames[:num_episodes]
261 | else:
262 | filenames = filenames[-num_episodes:]
263 | return filenames
264 |
265 | def convert(value):
266 | value = np.array(value)
267 | if np.issubdtype(value.dtype, np.floating):
268 | return value.astype(np.float32)
269 | elif np.issubdtype(value.dtype, np.signedinteger):
270 | return value.astype(np.int32)
271 | elif np.issubdtype(value.dtype, np.uint8):
272 | return value.astype(np.uint8)
273 | return value
274 |
275 |
276 | def eplen(episode):
277 | return len(episode['action'])
278 |
279 | def make_replay_loader(buffer, batch_size, num_workers=0):
280 | buffer.batch_size = batch_size
281 | return DataLoader(buffer,
282 | batch_size=None,
283 |                       # NOTE: do not use any workers:
284 |                       # they do not share the live replay buffer, and
285 |                       # refetching the data takes more time than sampling in the main process
286 | )
287 |
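288 | 
289 | if __name__ == '__main__':
290 |     # Minimal sketch of the on-disk episode format (the directory below is an arbitrary example):
291 |     # episodes are saved as '{idx}-{timestamp}-{uuid}-{length}.npz', so count_episodes and
292 |     # load_filenames can inspect a buffer directory without instantiating a ReplayBuffer.
293 |     directory = pathlib.Path('/tmp/replay_demo')
294 |     directory.mkdir(parents=True, exist_ok=True)
295 |     demo_episode = {'action': np.zeros((10, 8), np.float32), 'reward': np.zeros((10, 1), np.float32)}
296 |     timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
297 |     filename = directory / f'0-{timestamp}-{uuid.uuid4().hex}-{eplen(demo_episode)}.npz'
298 |     np.savez_compressed(filename, **demo_episode)
299 |     print(count_episodes(directory))          # (last_episode_idx, total_steps)
300 |     print(load_filenames(directory, capacity=100))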
--------------------------------------------------------------------------------
/rl/train.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import traceback
3 |
4 | warnings.warn_explicit = warnings.warn = lambda *_, **__: None  # silence all warnings globally
5 | warnings.filterwarnings('ignore', category=DeprecationWarning)
6 |
7 |
8 | import os
9 | import sys
10 | from contextlib import redirect_stderr, redirect_stdout
11 |
12 | os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
13 |
14 | os.environ['MUJOCO_GL'] = 'egl'
15 | os.environ['PYDEVD_UNBLOCK_THREADS_TIMEOUT'] = '900000'
16 |
17 | from pathlib import Path
18 |
19 | import hydra
20 | import numpy as np
21 | import torch
22 | import wandb
23 | from dm_env import specs
24 |
25 | import envs
26 | import utils
27 | from logger import Logger
28 | from np_replay import ReplayBuffer, make_replay_loader, SIG_FAILURE
29 | from collections import defaultdict
30 |
31 | from functools import partial
32 |
33 | torch.backends.cudnn.benchmark = True
34 |
35 | def get_gpu_memory():
36 | import subprocess as sp
37 | command = "nvidia-smi --query-gpu=memory.free --format=csv"
38 | memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
39 |     memory_free_values = [int(x.split()[0]) for x in memory_free_info]
40 | return memory_free_values
41 |
42 | def make_agent(obs_space, act_space, cur_config, cfg):
43 | from copy import deepcopy
44 | cur_config = deepcopy(cur_config)
45 | del cur_config.agent
46 | return hydra.utils.instantiate(cfg, cfg=cur_config, obs_space=obs_space, act_space=act_space)
47 |
48 | class Workspace:
49 | def __init__(self, cfg, savedir=None, workdir=None):
50 | self.workdir = Path.cwd() if workdir is None else workdir
51 | print(f'workspace: {self.workdir}')
52 |
53 | device = None
54 |
55 | self.cfg = cfg
56 | utils.set_seed_everywhere(cfg.seed)
57 |
58 | # create logger
59 | self.logger = Logger(self.workdir,
60 | use_tb=cfg.use_tb,
61 | use_wandb=cfg.use_wandb)
62 | # create envs
63 | task = cfg.task
64 | frame_stack = 1
65 | img_size = getattr(getattr(cfg, 'env', None), 'img_size', 84) # 84 is the DrQ default
66 |
67 | self.parallel_envs = cfg.parallel_envs
68 |
69 |         if cfg.spread_envs and os.environ.get('DISPLAY'):
70 | if torch.cuda.device_count() > 1:
71 | proposed_display = np.argmax(get_gpu_memory()).item()
72 | if proposed_display < torch.cuda.device_count():
73 | os.environ['DISPLAY'] = os.environ['DISPLAY'] + '.' + str(proposed_display)
74 |
75 | self.train_env_fn = partial(envs.make, task, cfg.obs_type, frame_stack,
76 | cfg.action_repeat, cfg.seed, img_size=img_size, cfg=cfg, is_eval=False)
77 | self.train_env = envs.MultiProcessEnv(self.train_env_fn, self.parallel_envs)
78 |
79 | if cfg.flexible_gpu:
80 | import time
81 | from hydra.core.hydra_config import HydraConfig
82 |
83 | try:
84 | job_num = getattr(HydraConfig.get().job, 'num', None)
85 | if job_num is not None:
86 | print("Job number:", HydraConfig.get().job.num)
87 | time.sleep(HydraConfig.get().job.num)
88 | except:
89 | pass
90 |
91 | while device is None:
92 | try:
93 | cfg.device = device = 'cuda:' + str(np.argmax(get_gpu_memory()).item())
94 | print("Using device:", device)
95 | except:
96 | pass
97 |
98 | self.device = torch.device(cfg.device)
99 |
100 |         # create agent
101 | self.agent = make_agent(self.train_env.obs_space,
102 | self.train_env.act_space, cfg, cfg.agent)
103 | # get meta specs
104 | meta_specs = self.agent.get_meta_specs()
105 |         # define the data specs for the replay buffer
106 | data_specs = (self.train_env.obs_space,
107 | self.train_env.act_space,
108 | specs.Array((1,), np.float32, 'reward'),
109 | specs.Array((1,), np.float32, 'discount'))
110 | self.act_space = self.train_env.act_space
111 |
112 |
113 | # create replay storage
114 | self.replay_storage = ReplayBuffer(data_specs, meta_specs,
115 | self.workdir / 'buffer',
116 | length=cfg.batch_length, **cfg.replay,
117 | device=cfg.device,
118 | fetch_every=cfg.batch_size,
119 | save_episodes=cfg.save_episodes)
120 |
121 | if cfg.preallocate_memory:
122 | self.replay_storage.preallocate_memory(cfg.num_train_frames // cfg.action_repeat)
123 |
124 |         # create replay loader
125 | self.replay_loader = make_replay_loader(self.replay_storage,
126 | cfg.batch_size, #
127 | cfg.num_workers
128 | )
129 | self._replay_iter = None
130 |
131 | self.timer = utils.Timer()
132 | self._global_step = 0
133 | self._global_episode = 0
134 | self._reward_stats = {'max' : -1e10, 'min' : 1e10, 'ep_avg_max' : -1e10 }
135 |
136 | @property
137 | def global_step(self):
138 | return self._global_step
139 |
140 | @property
141 | def global_episode(self):
142 | return self._global_episode
143 |
144 | @property
145 | def global_frame(self):
146 | return self.global_step * self.cfg.action_repeat
147 |
148 | @property
149 | def replay_iter(self):
150 | if self._replay_iter is None:
151 | self._replay_iter = iter(self.replay_loader)
152 | return self._replay_iter
153 |
154 | def train(self):
155 | # predicates
156 | train_until_step = utils.Until(self.cfg.num_train_frames,
157 | self.cfg.action_repeat)
158 | seed_until_step = utils.Until(self.cfg.num_seed_frames,
159 | self.cfg.action_repeat)
160 |
161 |         # To preserve throughput, never request more than 3/4 of the number of envs per training step
162 | train_every_n_steps = max(min(self.cfg.train_every_actions, (self.parallel_envs * 3) // 4 ), 1)
163 | print(f"Training every {train_every_n_steps} steps from the environment (num envs: {self.cfg.parallel_envs})")
164 |
165 | next_log_point = (self.global_frame // self.cfg.log_every_frames) + 1
166 |
167 | episode_step, episode_reward, episode_success, episode_invalid, episode_max_reward = np.zeros(self.parallel_envs), np.zeros(self.parallel_envs), np.zeros(self.parallel_envs), np.zeros(self.parallel_envs), np.zeros(self.parallel_envs)
168 | average_steps = []
169 | last_episodes = []
170 | complete_idxs = []
171 | time_step = self.train_env.reset_all()
172 | agent_state = None
173 | meta = self.agent.init_meta()
174 | for n, idx in enumerate(time_step['env_idx']):
175 | env_obs = { k: time_step[k][n] for k in time_step}
176 | if env_obs['env_error']:
177 | env_obs = self.train_env.reset_idx(idx)
178 | del env_obs['env_error']
179 | del env_obs['env_idx']
180 | self.replay_storage.add(env_obs, meta, idx=idx)
181 |
182 | metrics = None
183 | elapsed_time, total_time = self.timer.reset()
184 |
185 | while train_until_step(self.global_step):
186 | for n in np.where(time_step['is_last'])[0]:
187 | self._global_episode += 1
188 | idx = time_step['env_idx'][n]
189 | complete_idxs.append(idx)
190 |
191 | if not time_step['env_error'][n]:
192 | last_episodes.append([episode_step[idx], episode_reward[idx], episode_success[idx], episode_invalid[idx], episode_max_reward[idx]])
193 | self._reward_stats['ep_avg_max'] = max(self._reward_stats['ep_avg_max'], episode_reward[idx] / episode_step[idx])
194 | episode_step[idx], episode_reward[idx], episode_success[idx], episode_invalid[idx], episode_max_reward[idx] = [0,0,0,0,0]
195 |
196 | if self.cfg.async_mode == 'FULL':
197 | self.train_env._send_reset(idx)
198 | else:
199 | reset_obs = self.train_env.reset_idx(idx)
200 | for k,v in reset_obs.items():
201 | time_step[k][n] = v
202 | del reset_obs['env_error']
203 | assert idx == reset_obs.pop('env_idx')
204 | self.replay_storage.add(reset_obs, meta, idx=idx)
205 |
206 |             # log periodically; waiting also lets the full metrics schema get populated before the first dump
207 | if (self.global_step >= next_log_point * self.cfg.log_every_frames):
208 | next_log_point += 1
209 | # Episodes logging
210 | if len(last_episodes) > 0:
211 | last_episodes = np.stack(last_episodes, axis=0)
212 | last_step, last_reward, last_success, last_invalid, last_max_reward = np.mean(last_episodes, axis=0)
213 |
214 | # log stats
215 | elapsed_time, total_time = self.timer.reset()
216 | last_frame = last_step * self.cfg.action_repeat
217 | with self.logger.log_and_dump_ctx(self.global_frame,
218 | ty='train') as log:
219 | log('fps', self.cfg.log_every_frames / elapsed_time)
220 | log('total_time', total_time)
221 | log('buffer_size', len(self.replay_storage))
222 | log('episode_reward', last_reward )
223 | log('episode_avg_valid_reward', last_reward / (last_step - last_invalid) )
224 | log('episode_max_reward', last_max_reward )
225 | log('episode_length', last_frame )
226 | log('episode', self.global_episode)
227 | log('step', self.global_step)
228 | log('average_steps', sum(average_steps) / len(average_steps))
229 | if 'invalid_action' in time_step:
230 | log('episode_invalid', last_invalid )
231 | if 'success' in time_step:
232 | # episode_success = np.stack(episode_success)
233 | # ep_success = (episode_success[-10:].mean(axis=0) > 0.5).mean()
234 | # log('success', ep_success)
235 | # anytime_success = (episode_success.sum(axis=0) > 0.).mean()
236 | log('anytime_success', last_success)
237 | if getattr(self.agent, '_stats', False):
238 | for k,v in self.agent._stats.items():
239 | log(k, v)
240 |
241 | last_episodes = []
242 | average_steps = []
243 |
244 | # Agent logging
245 | if metrics is not None:
246 | # add rew metrics
247 | rew_metrics = {f'reward_stats/{k}' : v for k,v in self._reward_stats.items()}
248 | metrics.update(rew_metrics)
249 | self.logger.log_metrics(metrics, self.global_frame, ty='train')
250 |
251 | meta = self.agent.update_meta(meta, self.global_step, time_step)
252 |
253 | # sample action
254 | with torch.no_grad(), utils.eval_mode(self.agent):
255 | action, agent_state = self.agent.act(time_step, # time_step.observation
256 | meta,
257 | self.global_step,
258 | eval_mode=False,
259 | state=agent_state)
260 |
261 | # try to update the agent
262 | if not seed_until_step(self.global_step) and not (self.replay_storage.stats['total_episodes'] < self.cfg.num_seed_episodes):
263 | metrics = self.agent.update(next(self.replay_iter), self.global_step)[1]
264 |
265 | # take env step
266 | if self.cfg.async_mode == 'FULL':
267 | time_step = self.train_env.step_by_idx(action, idxs=time_step['env_idx'], requested_steps=train_every_n_steps, ignore_idxs=complete_idxs)
268 | complete_idxs = []
269 | elif self.cfg.async_mode == 'HALF':
270 | if time_step['env_idx'].shape[0] > self.cfg.parallel_envs//2:
271 | assert (time_step['env_idx'] == np.arange(0,self.cfg.parallel_envs)).all()
272 | self.train_env.step_receive_by_idx(action[:self.cfg.parallel_envs//2], send_idxs=np.arange(0,self.cfg.parallel_envs//2), recv_idxs=[]) # receive next
273 | send_idxs, recv_idxs = np.arange(self.cfg.parallel_envs//2, self.cfg.parallel_envs), np.arange(0,self.cfg.parallel_envs//2)
274 | time_step = self.train_env.step_receive_by_idx(action[self.cfg.parallel_envs//2:], send_idxs=send_idxs, recv_idxs=recv_idxs)
275 | else:
276 | send_idxs, recv_idxs = recv_idxs, send_idxs
277 | assert (time_step['env_idx'] == send_idxs).all()
278 | time_step = self.train_env.step_receive_by_idx(action, send_idxs=send_idxs, recv_idxs=recv_idxs)
279 | elif self.cfg.async_mode == 'OFF':
280 | time_step = self.train_env.step_all(action)
281 | else:
282 | raise NotImplementedError(f"Unknown async mode: {self.cfg.async_mode}")
283 |
284 | # process env data
285 | for n, idx in enumerate(time_step['env_idx']):
286 | env_obs = { k: time_step[k][n] for k in time_step}
287 | if env_obs['env_error']:
288 | env_obs = SIG_FAILURE
289 | # Forcing reset
290 | time_step['is_last'][n] = 1.0
291 | # Fixing global stats (steps were invalid)
292 | self._global_step -= episode_step[idx]
293 | else:
294 | # Remove extra keys
295 | del env_obs['env_error']
296 | del env_obs['env_idx']
297 |
298 | # update episode stats
299 | episode_reward[idx] += env_obs['reward']
300 | episode_max_reward[idx] = max(env_obs['reward'], episode_max_reward[idx])
301 | if 'invalid_action' in env_obs:
302 | episode_invalid[idx] += env_obs['invalid_action']
303 | if 'success' in env_obs:
304 | episode_success[idx] += env_obs['success']
305 | episode_success[idx] = np.clip(episode_success[idx], 0, 1)
306 | episode_step[idx] += 1
307 |
308 | if not seed_until_step(self.global_step) and self.cfg.log_best_episodes :
309 | if env_obs['is_last'] and episode_reward[idx] / episode_step[idx] > self._reward_stats['ep_avg_max'] and len(self.replay_storage._ongoing_eps[idx]['action']) > 0:
310 | self._reward_stats['ep_avg_max'] = episode_reward[idx] / episode_step[idx]
311 | # Log video of best episode
312 | videos = {}
313 | if 'front_rgb' in env_obs:
314 | videos['ep_avg_max/rgb'] = np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['front_rgb'], axis=0), axis=0)
315 | if 'wrist_rgb' in env_obs:
316 | if 'ep_avg_max/rgb' in videos:
317 | videos['ep_avg_max/rgb'] = np.concatenate([videos['ep_avg_max/rgb'], np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['wrist_rgb'], axis=0), axis=0)], axis=0)
318 | else:
319 | videos['ep_avg_max/rgb'] = np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['wrist_rgb'], axis=0), axis=0)
320 | self.logger.log_video(videos, self.global_frame)
321 | if env_obs['reward'] > self._reward_stats['max'] and len(self.replay_storage._ongoing_eps[idx]['action']) > 0:
322 | self._reward_stats['max'] = env_obs['reward']
323 | # Log video of best reward
324 | videos = {}
325 | if 'front_rgb' in env_obs:
326 | videos['rew_max/rgb'] = np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['front_rgb'], axis=0), axis=0)
327 | if 'wrist_rgb' in env_obs:
328 | if 'rew_max/rgb' in videos:
329 | videos['rew_max/rgb'] = np.concatenate([videos['rew_max/rgb'], np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['wrist_rgb'], axis=0), axis=0)], axis=0)
330 | else:
331 | videos['rew_max/rgb'] = np.expand_dims(np.stack(self.replay_storage._ongoing_eps[idx]['wrist_rgb'], axis=0), axis=0)
332 | self.logger.log_video(videos, self.global_frame)
333 |
334 | self.replay_storage.add(env_obs, meta, idx=idx)
335 |
336 | # update global stats
337 | self._reward_stats['max'] = max(self._reward_stats['max'], max(time_step['reward']))
338 | self._reward_stats['min'] = min(self._reward_stats['min'], min(time_step['reward']))
339 | self._global_step += time_step['env_idx'].shape[0]
340 | average_steps.append(time_step['env_idx'].shape[0])
341 |
342 | # save last model
343 | if self.global_frame % 5000 == 0 and self.cfg.save_episodes:
344 | self.save_last_model()
345 | sys.stdout.flush(); sys.stderr.flush()
346 |
347 | @utils.retry
348 | def save_snapshot(self):
349 | snapshot = self.get_snapshot_dir() / f'snapshot_{self.global_frame}.pt'
350 | keys_to_save = ['agent', '_global_step', '_global_episode']
351 | payload = {k: self.__dict__[k] for k in keys_to_save}
352 | with snapshot.open('wb') as f:
353 | torch.save(payload, f)
354 |
355 | def setup_wandb(self):
356 | cfg = self.cfg
357 | exp_name = '_'.join([
358 | getattr(getattr(cfg,'env', {}), 'action_mode', '_'), cfg.experiment, cfg.agent.name, cfg.task, cfg.obs_type, str(cfg.seed)
359 | ])
360 | wandb.init(project=cfg.project_name, group=cfg.agent.name, name=exp_name, notes=f'workspace: {self.workdir}')
361 | wandb.config.update(cfg)
362 |
363 | # define our custom x axis metric
364 | wandb.define_metric("train/frame")
365 | # set all other train/ metrics to use this step
366 | wandb.define_metric("train/*", step_metric="train/frame")
367 |
368 | self.wandb_run_id = wandb.run.id
369 |
370 | @utils.retry
371 | def save_last_model(self):
372 | snapshot = self.root_dir / 'last_snapshot.pt'
373 | if snapshot.is_file():
374 | temp = Path(str(snapshot).replace("last_snapshot.pt", "second_last_snapshot.pt"))
375 | os.replace(snapshot, temp)
376 | keys_to_save = ['agent', '_global_step', '_global_episode', '_reward_stats']
377 | if self.cfg.use_wandb:
378 | keys_to_save.append('wandb_run_id')
379 | payload = {k: self.__dict__[k] for k in keys_to_save}
380 | with snapshot.open('wb') as f:
381 | torch.save(payload, f)
382 |
383 | def load_snapshot(self):
384 | try:
385 | snapshot = self.root_dir / 'last_snapshot.pt'
386 | with snapshot.open('rb') as f:
387 | payload = torch.load(f)
388 | except Exception:  # fall back to the second-to-last snapshot
389 | snapshot = self.root_dir / 'second_last_snapshot.pt'
390 | with snapshot.open('rb') as f:
391 | payload = torch.load(f)
392 | for k,v in payload.items():
393 | setattr(self, k, v)
394 | if k == 'wandb_run_id':
395 | assert wandb.run is None
396 | cfg = self.cfg
397 | exp_name = '_'.join([
398 | getattr(getattr(cfg,'env', {}), 'action_mode', '_'), cfg.experiment, cfg.agent.name, cfg.task, cfg.obs_type, str(cfg.seed)
399 | ])
400 | wandb.init(project=cfg.project_name, group=cfg.agent.name, name=exp_name, id=v, resume="must", notes=f'workspace: {self.workdir}')
401 | # define our custom x axis metric
402 | wandb.define_metric("train/frame")
403 | # set all other train/ metrics to use this step
404 | wandb.define_metric("train/*", step_metric="train/frame")
405 |
406 |
407 | def get_snapshot_dir(self):
408 | snap_dir = self.cfg.snapshot_dir
409 | snapshot_dir = self.workdir / Path(snap_dir)
410 | snapshot_dir.mkdir(exist_ok=True, parents=True)
411 | return snapshot_dir
412 |
413 | @hydra.main(config_path='.', config_name='train')
414 | def main(cfg):
415 | try:
416 | root_dir = Path.cwd()
417 | with open(str(root_dir / "out.log"), 'a') as stdout, redirect_stdout(stdout), open(str(root_dir / "err.log"), 'a') as stderr, redirect_stderr(stderr):
418 | workspace = Workspace(cfg)
419 | workspace.root_dir = root_dir
420 | # for resuming, set hydra.run.dir to the directory containing the snapshot
421 | snapshot = workspace.root_dir / 'last_snapshot.pt'
422 | if snapshot.exists():
423 | print(f'resuming: {snapshot}')
424 | workspace.load_snapshot()
425 | if cfg.use_wandb and wandb.run is None:
426 | # otherwise it was resumed
427 | workspace.setup_wandb()
428 | workspace.train()
429 | except Exception:
430 | print(traceback.format_exc())
431 | finally:
432 | if 'workspace' in locals() and hasattr(workspace, 'train_env'):
433 | del workspace.train_env
434 |
435 | if __name__ == '__main__':
436 | main()
437 |
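# Resuming (editor's note, illustrative): main() looks for `last_snapshot.pt` in the run
# directory, so a crashed run can be continued by pointing hydra at the same directory:
#   python train.py configs=<config_group> task=<task> seed=<seed> hydra.run.dir=<previous_run_dir>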
--------------------------------------------------------------------------------
/rl/train.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - configs/default
4 | - agent: sac
5 | - configs: ${configs}
6 | # - override hydra/launcher: submitit_local
7 | - override hydra/launcher: joblib
8 | - override hydra/hydra_logging: custom
9 | - override hydra/job_logging: stdout
10 |
11 | # task settings
12 | task: none
13 | domain: walker # primal task will be inferred at runtime
14 | # train settings
15 | num_train_frames: 500010
16 | num_seed_frames: 4000
17 | num_seed_episodes: ${num_workers}
18 | # eval
19 | eval_every_frames: 1000000000 # not necessary during pretrain
20 | num_eval_episodes: 10
21 | # snapshot
22 | snapshots: [100000, 500000, 1000000, 2000000]
23 | snapshot_dir: ../../../pretrained_models/${obs_type}/${task}/${agent.name}/${seed}
24 |
25 | # replay buffer
26 | replay_buffer_size: 1000000
27 | num_workers: 4
28 | save_episodes: False
29 | preallocate_memory: False
30 |
31 | # misc
32 | seed: 1
33 | device: cuda
34 | use_tb: true
35 | use_wandb: true
36 |
37 | # experiment
38 | experiment: default
39 | project_name: ???
40 | flexible_gpu: true
41 | spread_envs: true
42 |
43 | # log settings
44 | log_every_frames: 2500
45 | log_best_episodes: True
46 |
47 | hydra:
48 | run:
49 | dir: ./exp_local/${now:%Y.%m.%d}/${now:%H%M%S}_${experiment}_${agent.name}_${obs_type}_${task}_${env.action_mode}_${seed}
50 | sweep:
51 | dir: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M%S}_${experiment}_${agent.name}_${obs_type}
52 | subdir: ${task}_${env.action_mode}_${seed}_${hydra.job.num}
53 |
54 | # launcher:
55 | # timeout_min: 4300
56 | # cpus_per_task: 2
57 | # gpus_per_node: 4
58 | # tasks_per_node: 1
59 | # mem_gb: 160
60 | # nodes: 1
61 | # submitit_folder: ./exp_sweep/${now:%Y.%m.%d}/${now:%H%M%S}_${agent.name}_${experiment}_${seed}/.slurm
62 |
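# Illustrative launch (editor's sketch; the exact overrides depend on your setup):
# the `configs` group selects one of rl/configs/{rlbench_states,rlbench_pixels,rlbench_both}.yaml, e.g.
#   python train.py configs=rlbench_pixels task=slide_cup seed=1 project_name=my_project use_wandb=false
# on headless machines the command can be wrapped with ../xvfb_run.sh to provide a virtual display.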
--------------------------------------------------------------------------------
/rl/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import re
4 | import time
5 | from functools import wraps
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from omegaconf import OmegaConf
12 | from torch import distributions as pyd
13 | from torch.distributions.utils import _standard_normal
14 | import numbers
15 |
16 | class eval_mode:
17 | def __init__(self, *models):
18 | self.models = models
19 |
20 | def __enter__(self):
21 | self.prev_states = []
22 | for model in self.models:
23 | self.prev_states.append(model.training)
24 | model.train(False)
25 |
26 | def __exit__(self, *args):
27 | for model, state in zip(self.models, self.prev_states):
28 | model.train(state)
29 | return False
30 |
31 |
32 | def set_seed_everywhere(seed):
33 | torch.manual_seed(seed)
34 | if torch.cuda.is_available():
35 | torch.cuda.manual_seed_all(seed)
36 | np.random.seed(seed)
37 | random.seed(seed)
38 |
39 |
40 | def chain(*iterables):
41 | for it in iterables:
42 | yield from it
43 |
44 |
45 | def soft_update_params(net, target_net, tau):
46 | for param, target_param in zip(net.parameters(), target_net.parameters()):
47 | target_param.data.copy_(tau * param.data +
48 | (1 - tau) * target_param.data)
49 |
50 |
51 | def hard_update_params(net, target_net):
52 | for param, target_param in zip(net.parameters(), target_net.parameters()):
53 | target_param.data.copy_(param.data)
54 |
55 |
56 | def to_torch(xs, device):
57 | return tuple(torch.as_tensor(x, device=device) for x in xs)
58 |
59 |
60 | def weight_init(m):
61 | """Custom weight init for Conv2D and Linear layers."""
62 | if isinstance(m, nn.Linear):
63 | nn.init.orthogonal_(m.weight.data)
64 | if hasattr(m.bias, 'data'):
65 | m.bias.data.fill_(0.0)
66 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
67 | gain = nn.init.calculate_gain('relu')
68 | nn.init.orthogonal_(m.weight.data, gain)
69 | if hasattr(m.bias, 'data'):
70 | m.bias.data.fill_(0.0)
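# Illustrative usage (editor's note): `weight_init` is meant to be passed to
# `nn.Module.apply`, which invokes it on every submodule, e.g.
#   encoder = nn.Sequential(nn.Conv2d(3, 32, 3, stride=2), nn.ReLU())
#   encoder.apply(weight_init)  # orthogonal weights, zero biases for Conv2d/Linear layers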
71 |
72 |
73 | def grad_norm(params, norm_type=2.0):
74 | params = [p for p in params if p.grad is not None]
75 | total_norm = torch.norm(
76 | torch.stack([torch.norm(p.grad.detach(), norm_type) for p in params]),
77 | norm_type)
78 | return total_norm.item()
79 |
80 |
81 | def param_norm(params, norm_type=2.0):
82 | total_norm = torch.norm(
83 | torch.stack([torch.norm(p.detach(), norm_type) for p in params]),
84 | norm_type)
85 | return total_norm.item()
86 |
87 |
88 | class Until:
89 | def __init__(self, until, action_repeat=1):
90 | self._until = until
91 | self._action_repeat = action_repeat
92 |
93 | def __call__(self, step):
94 | if self._until is None:
95 | return True
96 | until = self._until // self._action_repeat
97 | return step < until
98 |
99 |
100 | class Every:
101 | def __init__(self, every, action_repeat=1, label='train'):
102 | self._every = every
103 | self._action_repeat = action_repeat
104 | if self._every is not None and self._every // self._action_repeat == 0:
105 | print(f"WARNING: asking to {label} every 0 steps. Defaulting to 1.")
106 |
107 | def __call__(self, step):
108 | if self._every is None:
109 | return False
110 | every = max(1, self._every // self._action_repeat)
111 | if step % every == 0:
112 | return True
113 | return False
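# Illustrative usage (editor's sketch; the config keys below are assumptions mirroring train.yaml):
#   train_until_step = Until(cfg.num_train_frames, cfg.action_repeat)
#   seed_until_step = Until(cfg.num_seed_frames, cfg.action_repeat)
#   while train_until_step(self.global_step):
#       if not seed_until_step(self.global_step):
#           ...  # start agent updates once the seeding phase is over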
114 |
115 |
116 | class Timer:
117 | def __init__(self):
118 | self._start_time = time.time()
119 | self._last_time = time.time()
120 |
121 | def reset(self):
122 | elapsed_time = time.time() - self._last_time
123 | self._last_time = time.time()
124 | total_time = time.time() - self._start_time
125 | return elapsed_time, total_time
126 |
127 | def total_time(self):
128 | return time.time() - self._start_time
129 |
130 |
131 | class TruncatedNormal(pyd.Normal):
132 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
133 | super().__init__(loc, scale, validate_args=False)
134 | self.low = low
135 | self.high = high
136 | self.eps = eps
137 |
138 | def _clamp(self, x):
139 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
140 | x = x - x.detach() + clamped_x.detach()
141 | return x
142 |
143 | def sample(self, sample_shape=torch.Size(), clip=None):
144 | shape = self._extended_shape(sample_shape)
145 | eps = _standard_normal(shape,
146 | dtype=self.loc.dtype,
147 | device=self.loc.device)
148 | eps *= self.scale
149 | if clip is not None:
150 | eps = torch.clamp(eps, -clip, clip)
151 | x = self.loc + eps
152 | return self._clamp(x)
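# Illustrative usage (editor's sketch): a deterministic-policy actor can add truncated
# exploration noise while keeping actions in [-1, 1], e.g.
#   dist = TruncatedNormal(mu, std)   # mu, std produced by the policy network
#   action = dist.sample(clip=0.3)    # optionally clip the noise before clamping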
153 |
154 |
155 | class TanhTransform(pyd.transforms.Transform):
156 | domain = pyd.constraints.real
157 | codomain = pyd.constraints.interval(-1.0, 1.0)
158 | bijective = True
159 | sign = +1
160 |
161 | def __init__(self, cache_size=1):
162 | super().__init__(cache_size=cache_size)
163 |
164 | @staticmethod
165 | def atanh(x):
166 | return 0.5 * (x.log1p() - (-x).log1p())
167 |
168 | def __eq__(self, other):
169 | return isinstance(other, TanhTransform)
170 |
171 | def _call(self, x):
172 | return x.tanh()
173 |
174 | def _inverse(self, y):
175 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
176 | # one should use `cache_size=1` instead
177 | return self.atanh(y)
178 |
179 | def log_abs_det_jacobian(self, x, y):
180 | # We use a formula that is more numerically stable, see details in the following link
181 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
182 | return 2. * (math.log(2.) - x - F.softplus(-2. * x))
183 |
184 |
185 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
186 | def __init__(self, loc, scale):
187 | self.loc = loc
188 | self.scale = scale
189 |
190 | self.base_dist = pyd.Normal(loc, scale)
191 | transforms = [TanhTransform()]
192 | super().__init__(self.base_dist, transforms)
193 |
194 | @property
195 | def mean(self):
196 | mu = self.loc
197 | for tr in self.transforms:
198 | mu = tr(mu)
199 | return mu
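# Illustrative usage (editor's sketch): a SAC-style stochastic actor squashes a Normal
# through tanh, e.g.
#   dist = SquashedNormal(mu, std)
#   action = dist.rsample()                               # reparameterized, in (-1, 1)
#   log_prob = dist.log_prob(action).sum(-1, keepdim=True)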
200 |
201 |
202 | def schedule(schdl, step):
203 | try:
204 | return float(schdl)
205 | except ValueError:
206 | match = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
207 | if match:
208 | init, final, duration = [float(g) for g in match.groups()]
209 | mix = np.clip(step / duration, 0.0, 1.0)
210 | return (1.0 - mix) * init + mix * final
211 | match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl)
212 | if match:
213 | init, final1, duration1, final2, duration2 = [
214 | float(g) for g in match.groups()
215 | ]
216 | if step <= duration1:
217 | mix = np.clip(step / duration1, 0.0, 1.0)
218 | return (1.0 - mix) * init + mix * final1
219 | else:
220 | mix = np.clip((step - duration1) / duration2, 0.0, 1.0)
221 | return (1.0 - mix) * final1 + mix * final2
222 | raise NotImplementedError(schdl)
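# Illustrative usage (editor's note): `schedule` accepts a constant or one of the spec
# strings parsed above, e.g.
#   schedule('0.1', step)                                 # constant
#   schedule('linear(1.0,0.1,100000)', step)              # anneal 1.0 -> 0.1 over 100k steps
#   schedule('step_linear(1.0,0.5,50000,0.1,50000)', step)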
223 |
224 | def retry(func):
225 | """
226 | A decorator that retries a function on OSError/PermissionError, up to 1000 attempts.
227 | """
228 |
229 | @wraps(func)
230 | def wrapper(*args, **kwargs):
231 | attempts = 0
232 | max_attempts = 1000
233 | while attempts < max_attempts:
234 | try:
235 | return func(*args, **kwargs)
236 | except (OSError, PermissionError):
237 | attempts += 1
238 | time.sleep(0.1)
239 | raise OSError("Retry failed")
240 |
241 | return wrapper
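# Illustrative usage (editor's note): in this repo `retry` wraps I/O-heavy methods such as
# Workspace.save_snapshot / save_last_model in train.py, e.g.
#   @retry
#   def save_last_model(self):
#       ...  # transient OSError/PermissionError is retried with a 0.1 s back-off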
242 |
243 | class RandomShiftsAug(nn.Module):
244 | def __init__(self, pad):
245 | super().__init__()
246 | self.pad = pad
247 |
248 | def forward(self, x):
249 | x = x.float()
250 | n, c, h, w = x.size()
251 | assert h == w
252 | padding = tuple([self.pad] * 4)
253 | x = F.pad(x, padding, 'replicate')
254 | eps = 1.0 / (h + 2 * self.pad)
255 | arange = torch.linspace(-1.0 + eps,
256 | 1.0 - eps,
257 | h + 2 * self.pad,
258 | device=x.device,
259 | dtype=x.dtype)[:h]
260 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
261 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
262 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
263 |
264 | shift = torch.randint(0,
265 | 2 * self.pad + 1,
266 | size=(n, 1, 1, 2),
267 | device=x.device,
268 | dtype=x.dtype)
269 | shift *= 2.0 / (h + 2 * self.pad)
270 |
271 | grid = base_grid + shift
272 | return F.grid_sample(x,
273 | grid,
274 | padding_mode='zeros',
275 | align_corners=False)
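# Illustrative usage (editor's sketch; pad=4 is an assumed value): apply random shift
# augmentation to a batch of square images before encoding, e.g.
#   aug = RandomShiftsAug(pad=4)
#   obs = aug(obs)    # obs: (N, C, H, W) float tensor with H == W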
276 |
277 |
278 | class RMS(object):
279 | """running mean and std """
280 | def __init__(self, device, epsilon=1e-4, shape=(1,)):
281 | self.M = torch.zeros(shape).to(device)
282 | self.S = torch.ones(shape).to(device)
283 | self.n = epsilon
284 |
285 | def __call__(self, x):
286 | bs = x.size(0)
287 | delta = torch.mean(x, dim=0) - self.M
288 | new_M = self.M + delta * bs / (self.n + bs)
289 | new_S = (self.S * self.n + torch.var(x, dim=0) * bs +
290 | torch.square(delta) * self.n * bs /
291 | (self.n + bs)) / (self.n + bs)
292 |
293 | self.M = new_M
294 | self.S = new_S
295 | self.n += bs
296 |
297 | return self.M, self.S
298 |
299 |
300 | class PBE(object):
301 | """particle-based entropy based on knn normalized by running mean """
302 | def __init__(self, rms, knn_clip, knn_k, knn_avg, knn_rms, device):
303 | self.rms = rms
304 | self.knn_rms = knn_rms
305 | self.knn_k = knn_k
306 | self.knn_avg = knn_avg
307 | self.knn_clip = knn_clip
308 | self.device = device
309 |
310 | def __call__(self, rep, cdist=False, apply_log=True):
311 | source = target = rep
312 | b1, b2 = source.size(0), target.size(0)
313 | # (b1, 1, c) - (1, b2, c) -> (b1, b2, c) -> (b1, b2)
314 | if cdist:
315 | sim_matrix = torch.cdist(source, target.detach(), p=2)
316 | else:
317 | sim_matrix = torch.norm(source[:, None, :].view(b1, 1, -1) -
318 | target[None, :, :].view(1, b2, -1),
319 | dim=-1,
320 | p=2)
321 | reward, _ = sim_matrix.topk(self.knn_k,
322 | dim=1,
323 | largest=False,
324 | sorted=True) # (b1, k)
325 | if not self.knn_avg: # only keep k-th nearest neighbor
326 | reward = reward[:, -1]
327 | reward = reward.reshape(-1, 1) # (b1, 1)
328 | reward /= self.rms(reward)[0] if self.knn_rms else 1.0
329 | reward = torch.maximum(
330 | reward - self.knn_clip,
331 | torch.zeros_like(reward).to(self.device)
332 | ) if self.knn_clip >= 0.0 else reward # (b1, 1)
333 | else: # average over all k nearest neighbors
334 | reward = reward.reshape(-1, 1) # (b1 * k, 1)
335 | reward /= self.rms(reward)[0] if self.knn_rms else 1.0
336 | reward = torch.maximum(
337 | reward - self.knn_clip,
338 | torch.zeros_like(reward).to(
339 | self.device)) if self.knn_clip >= 0.0 else reward
340 | reward = reward.reshape((b1, self.knn_k)) # (b1, k)
341 | reward = reward.mean(dim=1, keepdim=True) # (b1, 1)
342 | if apply_log:
343 | reward = torch.log(reward + 1.0)
344 | return reward
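# Illustrative usage (editor's sketch; all hyperparameters below are assumptions):
#   rms = RMS(device='cuda')
#   pbe = PBE(rms, knn_clip=0.0, knn_k=12, knn_avg=True, knn_rms=True, device='cuda')
#   intr_reward = pbe(representations)   # (batch, 1) entropy-style intrinsic reward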
345 |
346 | def rgb2hsv(rgb, eps=1e-8):
347 | # Reference: https://www.rapidtables.com/convert/color/rgb-to-hsv.html
348 | # Reference: https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287
349 |
350 | _device = rgb.device
351 | r, g, b = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]
352 |
353 | Cmax = rgb.max(1)[0]
354 | Cmin = rgb.min(1)[0]
355 | delta = Cmax - Cmin
356 |
357 | hue = torch.zeros((rgb.shape[0], rgb.shape[2], rgb.shape[3])).to(_device)
358 | hue[Cmax== r] = (((g - b)/(delta + eps)) % 6)[Cmax == r]
359 | hue[Cmax == g] = ((b - r)/(delta + eps) + 2)[Cmax == g]
360 | hue[Cmax == b] = ((r - g)/(delta + eps) + 4)[Cmax == b]
361 | hue[Cmax == 0] = 0.0
362 | hue = hue / 6. # making hue range as [0, 1.0)
363 | hue = hue.unsqueeze(dim=1)
364 |
365 | saturation = (delta) / (Cmax + eps)
366 | saturation[Cmax == 0.] = 0.
367 | saturation = saturation.to(_device)
368 | saturation = saturation.unsqueeze(dim=1)
369 |
370 | value = Cmax
371 | value = value.to(_device)
372 | value = value.unsqueeze(dim=1)
373 |
374 | return torch.cat((hue, saturation, value), dim=1)#.type(torch.FloatTensor).to(_device)
375 | # return hue, saturation, value
376 |
377 | def hsv2rgb(hsv):
378 | # Reference: https://www.rapidtables.com/convert/color/hsv-to-rgb.html
379 | # Reference: https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287
380 |
381 | _device = hsv.device
382 |
383 | hsv = torch.clamp(hsv, 0, 1)
384 | hue = hsv[:, 0, :, :] * 360.
385 | saturation = hsv[:, 1, :, :]
386 | value = hsv[:, 2, :, :]
387 |
388 | c = value * saturation
389 | x = - c * (torch.abs((hue / 60.) % 2 - 1) - 1)
390 | m = (value - c).unsqueeze(dim=1)
391 |
392 | rgb_prime = torch.zeros_like(hsv).to(_device)
393 |
394 | inds = (hue < 60) * (hue >= 0)
395 | rgb_prime[:, 0, :, :][inds] = c[inds]
396 | rgb_prime[:, 1, :, :][inds] = x[inds]
397 |
398 | inds = (hue < 120) * (hue >= 60)
399 | rgb_prime[:, 0, :, :][inds] = x[inds]
400 | rgb_prime[:, 1, :, :][inds] = c[inds]
401 |
402 | inds = (hue < 180) * (hue >= 120)
403 | rgb_prime[:, 1, :, :][inds] = c[inds]
404 | rgb_prime[:, 2, :, :][inds] = x[inds]
405 |
406 | inds = (hue < 240) * (hue >= 180)
407 | rgb_prime[:, 1, :, :][inds] = x[inds]
408 | rgb_prime[:, 2, :, :][inds] = c[inds]
409 |
410 | inds = (hue < 300) * (hue >= 240)
411 | rgb_prime[:, 2, :, :][inds] = c[inds]
412 | rgb_prime[:, 0, :, :][inds] = x[inds]
413 |
414 | inds = (hue < 360) * (hue >= 300)
415 | rgb_prime[:, 2, :, :][inds] = x[inds]
416 | rgb_prime[:, 0, :, :][inds] = c[inds]
417 |
418 | rgb = rgb_prime + torch.cat((m, m, m), dim=1)
419 | rgb = rgb.to(_device)
420 |
421 | return torch.clamp(rgb, 0, 1)
422 |
423 | class ColorJitterLayer(nn.Module):
424 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0, batch_size=128, stack_size=3):
425 | super(ColorJitterLayer, self).__init__()
426 | self.brightness = self._check_input(brightness, 'brightness')
427 | self.contrast = self._check_input(contrast, 'contrast')
428 | self.saturation = self._check_input(saturation, 'saturation')
429 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
430 | clip_first_on_zero=False)
431 | self.prob = p
432 | self.batch_size = batch_size
433 | self.stack_size = stack_size
434 |
435 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
436 | if isinstance(value, numbers.Number):
437 | if value < 0:
438 | raise ValueError("If {} is a single number, it must be non negative.".format(name))
439 | value = [center - value, center + value]
440 | if clip_first_on_zero:
441 | value[0] = max(value[0], 0)
442 | elif isinstance(value, (tuple, list)) and len(value) == 2:
443 | if not bound[0] <= value[0] <= value[1] <= bound[1]:
444 | raise ValueError("{} values should be between {}".format(name, bound))
445 | else:
446 | raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
447 | # if value is 0 or (1., 1.) for brightness/contrast/saturation
448 | # or (0., 0.) for hue, do nothing
449 | if value[0] == value[1] == center:
450 | value = None
451 | return value
452 |
453 | def adjust_contrast(self, x):
454 | """
455 | Args:
456 | x: torch tensor img (rgb type)
457 | Factor: torch tensor with same length as x
458 | 0 gives a solid gray image, 1 gives the original image.
459 | Returns:
460 | torch tensor image: Contrast adjusted
461 | """
462 | _device = x.device
463 | factor = torch.empty(x.shape[0], device=_device).uniform_(*self.contrast)
464 | factor = factor.reshape(-1,1).repeat(1, x.shape[1]).reshape(-1)
465 | means = torch.mean(x, dim=(2, 3), keepdim=True)
466 | return torch.clamp((x - means)
467 | * factor.view(len(x), 1, 1, 1) + means, 0, 1)
468 |
469 | def adjust_hue(self, x):
470 | _device = x.device
471 | factor = torch.empty(x.shape[0], device=_device).uniform_(*self.hue)
472 | factor = factor.reshape(-1,1).repeat(1, x.shape[1]).reshape(-1)
473 | h = x[:, 0, :, :]
474 | h += (factor.view(len(x), 1, 1) * 255. / 360.)
475 | h = (h % 1)
476 | x[:, 0, :, :] = h
477 | return x
478 |
479 | def adjust_brightness(self, x):
480 | """
481 | Args:
482 | x: torch tensor img (hsv type)
483 | Factor:
484 | torch tensor with same length as x
485 | 0 gives a black image, 1 gives the original image,
486 | 2 doubles the brightness.
487 | Returns:
488 | torch tensor image: Brightness adjusted
489 | """
490 | _device = x.device
491 | factor = torch.empty(x.shape[0], device=_device).uniform_(*self.brightness)
492 | factor = factor.reshape(-1,1).repeat(1, x.shape[1]).reshape(-1)
493 | x[:, 2, :, :] = torch.clamp(x[:, 2, :, :]
494 | * factor.view(len(x), 1, 1), 0, 1)
495 | return torch.clamp(x, 0, 1)
496 |
497 | def adjust_saturate(self, x):
498 | """
499 | Args:
500 | x: torch tensor img (hsv type)
501 | Factor:
502 | torch tensor with same length as x
503 | 0 gives a grayscale image, 1 gives the original image,
504 | 2 doubles the saturation.
505 | Returns:
506 | torch tensor image: Saturation adjusted
507 | """
508 | _device = x.device
509 | factor = torch.empty(x.shape[0], device=_device).uniform_(*self.saturation)  # match the batch actually being transformed
510 | factor = factor.reshape(-1,1).repeat(1, x.shape[1]).reshape(-1)
511 | x[:, 1, :, :] = torch.clamp(x[:, 1, :, :]
512 | * factor.view(len(x), 1, 1), 0, 1)
513 | return torch.clamp(x, 0, 1)
514 |
515 | def transform(self, inputs):
516 | hsv_transform_list = [rgb2hsv, self.adjust_brightness,
517 | self.adjust_hue, self.adjust_saturate,
518 | hsv2rgb]
519 | rgb_transform_list = [self.adjust_contrast]
520 | # Shuffle transform
521 | if random.uniform(0,1) >= 0.5:
522 | transform_list = rgb_transform_list + hsv_transform_list
523 | else:
524 | transform_list = hsv_transform_list + rgb_transform_list
525 | for t in transform_list:
526 | inputs = t(inputs)
527 | return inputs
528 |
529 | def forward(self, inputs):
530 | _device = inputs.device
531 | random_inds = np.random.choice(
532 | [True, False], len(inputs), p=[self.prob, 1 - self.prob])
533 | inds = torch.tensor(random_inds).to(_device)
534 | if random_inds.sum() > 0:
535 | inputs[inds] = self.transform(inputs[inds])
536 | return inputs
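# Illustrative usage (editor's sketch; the jitter strengths below are assumed values):
#   jitter = ColorJitterLayer(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5,
#                             p=1.0, batch_size=obs.shape[0])
#   obs = jitter(obs)   # obs: (N, 3, H, W) float RGB in [0, 1]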
537 |
538 | class StreamNorm:
539 | def __init__(self, shape=(), momentum=0.99, scale=1.0, eps=1e-8, device='cuda'):
540 | # Momentum of 0 normalizes only based on the current batch.
541 | # Momentum of 1 disables normalization.
542 | self.device = device
543 | self._shape = tuple(shape)
544 | self._momentum = momentum
545 | self._scale = scale
546 | self._eps = eps
547 | self.mag = None # torch.ones(shape).to(self.device)
548 |
549 | self.step = 0
550 | self.mean = None # torch.zeros(shape).to(self.device)
551 | self.square_mean = None # torch.zeros(shape).to(self.device)
552 |
553 | def __call__(self, inputs):
554 | metrics = {}
555 | self.update(inputs)
556 | metrics['mean'] = inputs.mean()
557 | metrics['std'] = inputs.std()
558 | outputs = self.transform(inputs)
559 | metrics['normed_mean'] = outputs.mean()
560 | metrics['normed_std'] = outputs.std()
561 | return outputs, metrics
562 |
563 | def reset(self):
564 | self.mag = None # torch.ones_like(self.mag).to(self.device)
565 |
566 | self.step = 0
567 | self.mean = None # torch.zeros_like(self.mean).to(self.device)
568 | self.square_mean = None # torch.zeros_like(self.square_mean).to(self.device)
569 |
570 | def update(self, inputs):
571 | batch = inputs.reshape((-1,) + self._shape)
572 |
573 | mag = torch.abs(batch).mean(0)
574 | if self.mag is not None:
575 | self.mag.data = self._momentum * self.mag.data + (1 - self._momentum) * mag
576 | else:
577 | self.mag = mag.clone()
578 |
579 | self.step += 1
580 |
581 | mean = torch.mean(batch)
582 | if self.mean is not None:
583 | self.mean.data = self._momentum * self.mean.data + (1 - self._momentum) * mean
584 | else:
585 | self.mean = mean.clone()
586 |
587 | square_mean = torch.mean(batch * batch)
588 | if self.square_mean is not None:
589 | self.square_mean.data = self._momentum * self.square_mean.data + (1 - self._momentum) * square_mean
590 | else:
591 | self.square_mean = square_mean.clone()
592 |
593 | def transform(self, inputs):
594 | values = inputs.reshape((-1,) + self._shape)
595 | values /= self.mag[None] + self._eps
596 | values *= self._scale
597 | return values.reshape(inputs.shape)
598 |
599 | def corrected_mean_var_std(self):
600 | corr = 1
601 | corr_mean = self.mean / corr
602 | corr_var = (self.square_mean / corr) - self.mean ** 2
603 | corr_std = torch.sqrt(torch.maximum(corr_var, torch.zeros_like(corr_var, device=self.device)) + self._eps)
604 | return corr_mean, corr_var, corr_std
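# Illustrative usage (editor's sketch): normalize a stream of (e.g. intrinsic) rewards by a
# running magnitude estimate:
#   norm = StreamNorm(shape=(1,), momentum=0.99, device='cuda')
#   normed_reward, metrics = norm(reward)   # metrics holds mean/std before and after normalization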
605 |
--------------------------------------------------------------------------------
/xvfb_run.sh:
--------------------------------------------------------------------------------
1 | xvfb-run -a -s '-screen 0 1024x768x24 -ac +extension GLX +render -noreset' "$@"
--------------------------------------------------------------------------------