├── LICENSE
├── README.md
├── envs
│   └── __init__.py
├── models
│   ├── ppo_mae.py
│   └── pretrain_models.py
├── requirements.txt
├── teaser.png
├── train.py
└── utils
    ├── add_tactile.py
    ├── callbacks.py
    ├── frame_stack.py
    ├── pretrain_utils.py
    ├── resize_dict.py
    └── wandb_logger.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Robot Learning Lab (RLL) lab @ UC Berkeley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | ================================================================================ 24 | License for contents from vit-pytorch 25 | ================================================================================ 26 | MIT License 27 | 28 | Copyright (c) 2020 Phil Wang 29 | 30 | Permission is hereby granted, free of charge, to any person obtaining a copy 31 | of this software and associated documentation files (the "Software"), to deal 32 | in the Software without restriction, including without limitation the rights 33 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 34 | copies of the Software, and to permit persons to whom the Software is 35 | furnished to do so, subject to the following conditions: 36 | 37 | The above copyright notice and this permission notice shall be included in all 38 | copies or substantial portions of the Software. 39 | 40 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 41 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 42 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 43 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 44 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 45 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 46 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # The Power of the Senses: Generalizable Manipulation from Vision and Touch through Masked Multimodal Learning 2 | 3 | [Paper](https://arxiv.org/abs/2311.00924) [Website](https://sferrazza.cc/m3l_site/) 4 | 5 | Masked Multimodal Learning (**M3L**) is a representation learning technique for reinforcement learning that targets robotic manipulation systems equipped with vision and high-resolution touch.
6 | 7 | ![image](teaser.png) 8 | 9 | ## Installation 10 | Please install [`tactile_envs`](https://github.com/carlosferrazza/tactile_envs.git) first. Then, install the remaining dependencies: 11 | ``` 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Training M3L 16 | ``` 17 | MUJOCO_GL='egl' python train.py --env tactile_envs/Insertion-v0 18 | ``` 19 | 20 | ## Training M3L (vision policy) 21 | ``` 22 | MUJOCO_GL='egl' python train.py --env tactile_envs/Insertion-v0 --vision_only_control True 23 | ``` 24 | ## Citation 25 | If you find M3L useful for your research, please cite this work: 26 | ``` 27 | @article{sferrazza2023power, 28 | title={The power of the senses: Generalizable manipulation from vision and touch through masked multimodal learning}, 29 | author={Sferrazza, Carmelo and Seo, Younggyo and Liu, Hao and Lee, Youngwoon and Abbeel, Pieter}, 30 | journal={arXiv preprint arXiv:2311.00924}, 31 | year={2023} 32 | } 33 | ``` 34 | 35 | ## References 36 | This codebase contains some files adapted from other sources: 37 | * vit-pytorch: https://github.com/lucidrains/vit-pytorch 38 | -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.common.monitor import Monitor 2 | 3 | from utils.frame_stack import FrameStack 4 | 5 | import gymnasium as gym 6 | from gymnasium.wrappers.pixel_observation import PixelObservationWrapper 7 | from tactile_envs.utils.resize_dict import ResizeDict 8 | from tactile_envs.utils.add_tactile import AddTactile 9 | 10 | import numpy as np 11 | 12 | def make_env( 13 | env_name, 14 | rank, 15 | seed=0, 16 | state_type="vision_and_touch", 17 | camera_idx=0, 18 | objects=["square"], 19 | holders=["holder2"], 20 | frame_stack=1, 21 | no_rotation=True, 22 | skip_frame=2, 23 | **kwargs, 24 | ): 25 | """ 26 | Utility function for multiprocessed env. 
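:param env_name: (str) id of the environment to create
:param state_type: (str) observation modalities to use, e.g. "vision", "touch", or "vision_and_touch"
:param frame_stack: (int) number of consecutive observations stacked by the FrameStack wrapper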
27 | 28 | :param rank: (int) index of the subprocess 29 | :param seed: (int) the initial seed for RNG 30 | """ 31 | 32 | def _init(): 33 | if env_name in ["Door"]: 34 | import robosuite as suite 35 | from robosuite.wrappers.tactile_wrapper import TactileWrapper 36 | from robosuite import load_controller_config 37 | 38 | config = load_controller_config(default_controller="OSC_POSE") 39 | 40 | # Notice how the environment is wrapped by the wrapper 41 | if env_name == "Door": 42 | robots = ["PandaTactile"] 43 | placement_initializer = None 44 | init_qpos = [-0.073, 0.016, -0.392, -2.502, 0.240, 2.676, 0.189] 45 | env_config = kwargs.copy() 46 | env_config["robot_configs"] = [{"initial_qpos": init_qpos}] 47 | env_config["initialization_noise"] = None 48 | 49 | env = TactileWrapper( 50 | suite.make( 51 | env_name, 52 | robots=robots, # use PandaTactile robot 53 | use_camera_obs=True, # use pixel observations 54 | use_object_obs=False, 55 | has_offscreen_renderer=True, # needed for pixel obs 56 | has_renderer=False, # not needed due to offscreen rendering 57 | reward_shaping=True, # use dense rewards 58 | camera_names="agentview", 59 | horizon=300, 60 | controller_configs=config, 61 | placement_initializer=placement_initializer, 62 | camera_heights=64, 63 | camera_widths=64, 64 | **env_config, 65 | ), 66 | env_id=rank, 67 | state_type=state_type, 68 | ) 69 | env = FrameStack(env, frame_stack) 70 | elif env_name in ["HandManipulateBlockRotateZFixed-v1", "HandManipulateEggRotateFixed-v1", "HandManipulatePenRotateFixed-v1"]: 71 | env = gym.make(env_name, render_mode="rgb_array", reward_type='dense') 72 | env = PixelObservationWrapper(env, pixel_keys=('image',)) 73 | env = ResizeDict(env, 64, pixel_key='image') 74 | if state_type == "vision_and_touch": 75 | env = AddTactile(env) 76 | env = FrameStack(env, frame_stack) 77 | else: 78 | 79 | env = gym.make( 80 | env_name, 81 | state_type=state_type, 82 | camera_idx=camera_idx, 83 | symlog_tactile=True, 84 | env_id=rank, 85 | holders=holders, 86 | objects=objects, 87 | no_rotation=no_rotation, 88 | skip_frame=skip_frame, 89 | ) 90 | env = FrameStack(env, frame_stack) 91 | 92 | env = Monitor(env) 93 | np.random.seed(seed + rank) 94 | return env 95 | 96 | return _init 97 | -------------------------------------------------------------------------------- /models/ppo_mae.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Dict, Optional, Type, TypeVar, Union 3 | 4 | import sys 5 | 6 | import numpy as np 7 | import torch as th 8 | from gymnasium import spaces 9 | from torch.nn import functional as F 10 | 11 | from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm 12 | from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy 13 | from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule 14 | from stable_baselines3.common.utils import explained_variance, get_schedule_fn 15 | 16 | from utils.pretrain_utils import vt_load 17 | 18 | import copy 19 | 20 | 21 | SelfPPO = TypeVar("SelfPPO", bound="PPO") 22 | 23 | 24 | class PPO_MAE(OnPolicyAlgorithm): 25 | """ 26 | Proximal Policy Optimization algorithm (PPO) (clip version) 27 | 28 | Paper: https://arxiv.org/abs/1707.06347 29 | Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/) 30 | https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and 31 | Stable Baselines (PPO2 from
https://github.com/hill-a/stable-baselines) 32 | 33 | Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html 34 | 35 | :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) 36 | :param env: The environment to learn from (if registered in Gym, can be str) 37 | :param learning_rate: The learning rate, it can be a function 38 | of the current progress remaining (from 1 to 0) 39 | :param n_steps: The number of steps to run for each environment per update 40 | (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel) 41 | NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization) 42 | See https://github.com/pytorch/pytorch/issues/29372 43 | :param batch_size: Minibatch size 44 | :param n_epochs: Number of epochs when optimizing the surrogate loss 45 | :param gamma: Discount factor 46 | :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator 47 | :param clip_range: Clipping parameter, it can be a function of the current progress 48 | remaining (from 1 to 0). 49 | :param clip_range_vf: Clipping parameter for the value function, 50 | it can be a function of the current progress remaining (from 1 to 0). 51 | This is a parameter specific to the OpenAI implementation. If None is passed (default), 52 | no clipping will be done on the value function. 53 | IMPORTANT: this clipping depends on the reward scaling. 54 | :param normalize_advantage: Whether or not to normalize the advantage 55 | :param ent_coef: Entropy coefficient for the loss calculation 56 | :param vf_coef: Value function coefficient for the loss calculation 57 | :param max_grad_norm: The maximum value for the gradient clipping 58 | :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) 59 | instead of action noise exploration (default: False) 60 | :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE 61 | Default: -1 (only sample at the beginning of the rollout) 62 | :param target_kl: Limit the KL divergence between updates, 63 | because the clipping is not enough to prevent large updates 64 | see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) 65 | By default, there is no limit on the kl div. 66 | :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average 67 | the reported success rate, mean episode length, and mean reward over 68 | :param tensorboard_log: the log location for tensorboard (if None, no logging) 69 | :param policy_kwargs: additional arguments to be passed to the policy on creation 70 | :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for 71 | debug messages 72 | :param seed: Seed for the pseudo random generators 73 | :param device: Device (cpu, cuda, ...) on which the code should be run. 74 | Setting it to auto, the code will be run on the GPU if possible.
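:param mae: The masked multimodal autoencoder (VTMAE) whose reconstruction loss is optimized alongside PPO
:param mae_batch_size: Minibatch size for the MAE reconstruction updates when `separate_optimizer` is used
:param separate_optimizer: Whether to update the MAE with its own Adam optimizer instead of jointly through the policy optimizer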
75 | :param _init_setup_model: Whether or not to build the network at the creation of the instance 76 | """ 77 | 78 | policy_aliases: Dict[str, Type[BasePolicy]] = { 79 | "MlpPolicy": ActorCriticPolicy, 80 | "CnnPolicy": ActorCriticCnnPolicy, 81 | "MultiInputPolicy": MultiInputActorCriticPolicy, 82 | } 83 | 84 | def __init__( 85 | self, 86 | policy: Union[str, Type[ActorCriticPolicy]], 87 | env: Union[GymEnv, str], 88 | learning_rate: Union[float, Schedule] = 3e-4, 89 | n_steps: int = 2048, 90 | batch_size: int = 64, 91 | n_epochs: int = 10, 92 | gamma: float = 0.99, 93 | gae_lambda: float = 0.95, 94 | clip_range: Union[float, Schedule] = 0.2, 95 | clip_range_vf: Union[None, float, Schedule] = None, 96 | normalize_advantage: bool = True, 97 | ent_coef: float = 0.0, 98 | vf_coef: float = 0.5, 99 | max_grad_norm: float = 0.5, 100 | use_sde: bool = False, 101 | sde_sample_freq: int = -1, 102 | target_kl: Optional[float] = None, 103 | stats_window_size: int = 100, 104 | tensorboard_log: Optional[str] = None, 105 | policy_kwargs: Optional[Dict[str, Any]] = None, 106 | verbose: int = 0, 107 | seed: Optional[int] = None, 108 | device: Union[th.device, str] = "auto", 109 | mae = None, 110 | mae_batch_size = 32, 111 | separate_optimizer = False, 112 | _init_setup_model: bool = True, 113 | ): 114 | super().__init__( 115 | policy, 116 | env, 117 | learning_rate=learning_rate, 118 | n_steps=n_steps, 119 | gamma=gamma, 120 | gae_lambda=gae_lambda, 121 | ent_coef=ent_coef, 122 | vf_coef=vf_coef, 123 | max_grad_norm=max_grad_norm, 124 | use_sde=use_sde, 125 | sde_sample_freq=sde_sample_freq, 126 | stats_window_size=stats_window_size, 127 | tensorboard_log=tensorboard_log, 128 | policy_kwargs=policy_kwargs, 129 | verbose=verbose, 130 | device=device, 131 | seed=seed, 132 | _init_setup_model=False, 133 | supported_action_spaces=( 134 | spaces.Box, 135 | spaces.Discrete, 136 | spaces.MultiDiscrete, 137 | spaces.MultiBinary, 138 | ), 139 | ) 140 | 141 | 142 | 143 | # Sanity check, otherwise it will lead to noisy gradient and NaN 144 | # because of the advantage normalization 145 | if normalize_advantage: 146 | assert ( 147 | batch_size > 1 148 | ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440" 149 | 150 | if self.env is not None: 151 | # Check that `n_steps * n_envs > 1` to avoid NaN 152 | # when doing advantage normalization 153 | buffer_size = self.env.num_envs * self.n_steps 154 | assert buffer_size > 1 or ( 155 | not normalize_advantage 156 | ), f"`n_steps * n_envs` must be greater than 1. 
Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" 157 | # Check that the rollout buffer size is a multiple of the mini-batch size 158 | untruncated_batches = buffer_size // batch_size 159 | if buffer_size % batch_size > 0: 160 | warnings.warn( 161 | f"You have specified a mini-batch size of {batch_size}," 162 | f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`," 163 | f" after every {untruncated_batches} untruncated mini-batches," 164 | f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n" 165 | f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n" 166 | f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" 167 | ) 168 | self.batch_size = batch_size 169 | self.n_epochs = n_epochs 170 | self.clip_range = clip_range 171 | self.clip_range_vf = clip_range_vf 172 | self.normalize_advantage = normalize_advantage 173 | self.target_kl = target_kl 174 | 175 | self.mae_batch_size = mae_batch_size 176 | self.separate_optimizer = separate_optimizer 177 | 178 | if _init_setup_model: 179 | self._setup_model() 180 | self.mae = mae 181 | self.mae_optimizer = th.optim.Adam(self.mae.parameters(), lr=1e-4) 182 | 183 | def _setup_model(self) -> None: 184 | super()._setup_model() 185 | 186 | # Initialize schedules for policy/value clipping 187 | self.clip_range = get_schedule_fn(self.clip_range) 188 | if self.clip_range_vf is not None: 189 | if isinstance(self.clip_range_vf, (float, int)): 190 | assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" 191 | 192 | self.clip_range_vf = get_schedule_fn(self.clip_range_vf) 193 | 194 | def load_mae(self, mae): 195 | self.mae = mae 196 | self.mae_optimizer = th.optim.Adam(self.mae.parameters(), lr=1e-4) 197 | 198 | def train(self) -> None: 199 | """ 200 | Update policy using the currently gathered rollout buffer. 
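In addition to the PPO surrogate, value, and entropy losses, each minibatch pass also backpropagates the MAE reconstruction loss, either through the shared policy optimizer (default) or through the dedicated MAE optimizer when `separate_optimizer` is set.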
201 | """ 202 | # Switch to train mode (this affects batch norm / dropout) 203 | self.policy.set_training_mode(True) 204 | # Update optimizer learning rate 205 | self._update_learning_rate(self.policy.optimizer) 206 | # Compute current clip range 207 | clip_range = self.clip_range(self._current_progress_remaining) 208 | # Optional: clip range for the value function 209 | if self.clip_range_vf is not None: 210 | clip_range_vf = self.clip_range_vf(self._current_progress_remaining) 211 | 212 | entropy_losses = [] 213 | pg_losses, value_losses = [], [] 214 | clip_fractions = [] 215 | 216 | continue_training = True 217 | # train for n_epochs epochs 218 | for epoch in range(self.n_epochs): 219 | approx_kl_divs = [] 220 | # Do a complete pass on the rollout buffer 221 | for rollout_data in self.rollout_buffer.get(self.batch_size): 222 | 223 | try: 224 | n_iter = rollout_data.observations['image'].shape[0] // self.mae_batch_size 225 | except: 226 | n_iter = rollout_data.observations['tactile'].shape[0] // self.mae_batch_size 227 | 228 | # self.policy.optimizer.zero_grad() # NEW 229 | 230 | observations = rollout_data.observations 231 | # print("image shape: ", observations['image'].shape) 232 | # print("tactile shape: ", observations['tactile'].shape) 233 | frame_stack = 1 234 | if 'image' in observations and len(observations['image'].shape) == 5: 235 | frame_stack = observations['image'].shape[1] 236 | observations['image'] = observations['image'].permute(0, 2, 3, 1, 4) 237 | observations['image'] = observations['image'].reshape((observations['image'].shape[0], observations['image'].shape[1], observations['image'].shape[2], -1)) 238 | if 'tactile' in observations and len(observations['tactile'].shape) == 5: 239 | frame_stack = observations['tactile'].shape[1] 240 | observations['tactile'] = observations['tactile'].reshape((observations['tactile'].shape[0], -1, observations['tactile'].shape[3], observations['tactile'].shape[4])) 241 | 242 | # x = vt_load(copy.deepcopy(observations), frame_stack=frame_stack) 243 | # mae_loss = self.mae(x) 244 | # mae_loss.backward() 245 | 246 | if not self.separate_optimizer: 247 | self.policy.optimizer.zero_grad() 248 | n_iter = 1 249 | 250 | for i in range(n_iter): 251 | # Optimization step 252 | 253 | if self.separate_optimizer: 254 | self.mae_optimizer.zero_grad() 255 | 256 | x = vt_load(copy.deepcopy({k: v[i*self.mae_batch_size:(i+1)*self.mae_batch_size] for k, v in observations.items()}), frame_stack=frame_stack) 257 | else: 258 | x = vt_load(copy.deepcopy(observations), frame_stack=frame_stack) 259 | 260 | mae_loss = self.mae(x) 261 | mae_loss.backward() 262 | 263 | if self.separate_optimizer: 264 | self.mae_optimizer.step() 265 | 266 | if self.separate_optimizer: 267 | self.policy.optimizer.zero_grad() 268 | 269 | actions = rollout_data.actions 270 | if isinstance(self.action_space, spaces.Discrete): 271 | # Convert discrete action from float to long 272 | actions = rollout_data.actions.long().flatten() 273 | 274 | # Re-sample the noise matrix because the log_std has changed 275 | if self.use_sde: 276 | self.policy.reset_noise(self.batch_size) 277 | 278 | values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) 279 | values = values.flatten() 280 | # Normalize advantage 281 | advantages = rollout_data.advantages 282 | # Normalization does not make sense if mini batchsize == 1, see GH issue #325 283 | if self.normalize_advantage and len(advantages) > 1: 284 | advantages = (advantages - advantages.mean()) / 
(advantages.std() + 1e-8) 285 | 286 | # ratio between old and new policy, should be one at the first iteration 287 | ratio = th.exp(log_prob - rollout_data.old_log_prob) 288 | 289 | # clipped surrogate loss 290 | policy_loss_1 = advantages * ratio 291 | policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range) 292 | policy_loss = -th.min(policy_loss_1, policy_loss_2).mean() 293 | 294 | # Logging 295 | pg_losses.append(policy_loss.item()) 296 | clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item() 297 | clip_fractions.append(clip_fraction) 298 | 299 | if self.clip_range_vf is None: 300 | # No clipping 301 | values_pred = values 302 | else: 303 | # Clip the difference between old and new value 304 | # NOTE: this depends on the reward scaling 305 | values_pred = rollout_data.old_values + th.clamp( 306 | values - rollout_data.old_values, -clip_range_vf, clip_range_vf 307 | ) 308 | # Value loss using the TD(gae_lambda) target 309 | value_loss = F.mse_loss(rollout_data.returns, values_pred) 310 | value_losses.append(value_loss.item()) 311 | 312 | # Entropy loss favor exploration 313 | if entropy is None: 314 | # Approximate entropy when no analytical form 315 | entropy_loss = -th.mean(-log_prob) 316 | else: 317 | entropy_loss = -th.mean(entropy) 318 | 319 | entropy_losses.append(entropy_loss.item()) 320 | 321 | loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss 322 | 323 | # Calculate approximate form of reverse KL Divergence for early stopping 324 | # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417 325 | # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419 326 | # and Schulman blog: http://joschu.net/blog/kl-approx.html 327 | with th.no_grad(): 328 | log_ratio = log_prob - rollout_data.old_log_prob 329 | approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy() 330 | approx_kl_divs.append(approx_kl_div) 331 | 332 | if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl: 333 | continue_training = False 334 | if self.verbose >= 1: 335 | print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}") 336 | break 337 | 338 | loss.backward() 339 | # Clip grad norm 340 | th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) 341 | self.policy.optimizer.step() 342 | 343 | self._n_updates += 1 344 | if not continue_training: 345 | break 346 | 347 | explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) 348 | 349 | # Logs 350 | self.logger.record("train/entropy_loss", np.mean(entropy_losses)) 351 | self.logger.record("train/policy_gradient_loss", np.mean(pg_losses)) 352 | self.logger.record("train/value_loss", np.mean(value_losses)) 353 | self.logger.record("train/approx_kl", np.mean(approx_kl_divs)) 354 | self.logger.record("train/clip_fraction", np.mean(clip_fractions)) 355 | self.logger.record("train/loss", loss.item()) 356 | self.logger.record("train/explained_variance", explained_var) 357 | 358 | self.logger.record("train/mae_loss", mae_loss.item()) 359 | if hasattr(self.policy, "log_std"): 360 | self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) 361 | 362 | self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") 363 | self.logger.record("train/clip_range", clip_range) 364 | if self.clip_range_vf is not None: 365 | self.logger.record("train/clip_range_vf", clip_range_vf) 366 | 367 | def learn( 368 | self: 
SelfPPO, 369 | total_timesteps: int, 370 | callback: MaybeCallback = None, 371 | log_interval: int = 1, 372 | tb_log_name: str = "PPO", 373 | reset_num_timesteps: bool = True, 374 | progress_bar: bool = False, 375 | ) -> SelfPPO: 376 | return super().learn( 377 | total_timesteps=total_timesteps, 378 | callback=callback, 379 | log_interval=log_interval, 380 | tb_log_name=tb_log_name, 381 | reset_num_timesteps=reset_num_timesteps, 382 | progress_bar=progress_bar, 383 | ) 384 | -------------------------------------------------------------------------------- /models/pretrain_models.py: -------------------------------------------------------------------------------- 1 | from vit_pytorch.vit import pair, Transformer 2 | import torch 3 | from torch import nn 4 | from einops.layers.torch import Rearrange 5 | 8 | import torch.nn.functional as F 9 | from einops import repeat 10 | import gymnasium as gym 11 | 12 | from stable_baselines3.common.torch_layers import ( 13 | BaseFeaturesExtractor, 14 | FlattenExtractor, 15 | ) 16 | from positional_encodings.torch_encodings import PositionalEncoding2D 17 | 18 | from typing import Any, Dict, List, Optional, Tuple, Type, Union 19 | 20 | from gymnasium import spaces 21 | from stable_baselines3.common.type_aliases import Schedule 22 | 23 | import random 24 | 25 | from stable_baselines3.common.policies import ActorCriticPolicy 26 | 28 | 29 | from utils.pretrain_utils import vt_load 30 | 31 | from tqdm import tqdm 32 | 33 | import torch.optim as optim 34 | 35 | import numpy as np 36 | 37 | class EarlyCNN(nn.Module): 38 | def __init__(self, in_channels, encoder_dim, key='image'): 39 | super().__init__() 40 | 41 | self.conv1 = nn.Conv2d(in_channels, encoder_dim//8, 4, stride=2, padding=1) 42 | self.conv2 = nn.Conv2d(encoder_dim//8, encoder_dim//4, 4, stride=2, padding=1) 43 | if key == 'image': 44 | self.conv3 = nn.Conv2d(encoder_dim//4, encoder_dim//2, 4, stride=2, padding=1) 45 | else: 46 | self.conv3 = nn.Conv2d(encoder_dim//4, encoder_dim//2, 3, stride=1, padding=1) 47 | 48 | self.conv4 = nn.Conv2d(encoder_dim//2, encoder_dim, 1) 49 | 50 | def forward(self, x): 51 | 52 | x = F.relu(self.conv1(x)) 53 | x = F.relu(self.conv2(x)) 54 | x = F.relu(self.conv3(x)) 55 | 56 | return self.conv4(x).flatten(2).transpose(1, 2) 57 | 58 | 59 | class VTMAE(nn.Module): 60 | def __init__( 61 | self, 62 | *, 63 | encoder, 64 | decoder_dim, 65 | masking_ratio = 0.75, 66 | decoder_depth = 1, 67 | decoder_heads = 8, 68 | decoder_dim_head = 64, 69 | num_tactiles = 2, 70 | early_conv_masking = False, 71 | use_sincosmod_encodings = True, 72 | frame_stack = 1, 73 | ): 74 | super().__init__() 75 | assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1' 76 | self.masking_ratio = masking_ratio 77 | 78 | self.num_tactiles = num_tactiles 79 | 80 | self.frame_stack = frame_stack 81 | 82 | # extract some hyperparameters and functions from encoder (vision transformer to be trained) 83 | 84 | self.encoder = encoder 85 | num_patches, encoder_dim = encoder.pos_embedding.shape[-2:] 86 | 87 | num_decoder_patches = num_patches - 1 88 | 89 | self.use_sincosmod_encodings = use_sincosmod_encodings 90 | 91 | print("num_patches: ", num_patches) 92 | print("num_decoder_patches: ", num_decoder_patches) 93 | 94 | self.early_conv_masking = early_conv_masking 95 | if self.early_conv_masking: 96 | self.early_conv_vision = EarlyCNN(self.encoder.image_channels, encoder_dim, key='image') 97
| self.early_conv_tactile = EarlyCNN(self.encoder.tactile_channels, encoder_dim, key='tactile') 98 | 99 | self.image_to_patch = encoder.image_to_patch_embedding[0] 100 | self.image_patch_to_emb = nn.Sequential(*encoder.image_to_patch_embedding[1:]) 101 | pixel_values_per_patch = encoder.image_to_patch_embedding[2].weight.shape[-1] 102 | 103 | self.tactile_to_patch = encoder.tactile_to_patch_embedding[0] 104 | self.tactile_patch_to_emb = nn.Sequential(*encoder.tactile_to_patch_embedding[1:]) 105 | tactile_values_per_patch = encoder.tactile_to_patch_embedding[2].weight.shape[-1] 106 | 107 | self.encoder_dim = encoder_dim 108 | 109 | # decoder parameters 110 | self.decoder_dim = decoder_dim 111 | self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity() 112 | self.mask_token = nn.Parameter(torch.randn(decoder_dim)) 113 | self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4) 114 | self.decoder_pos_emb = nn.Embedding(num_decoder_patches, decoder_dim) 115 | self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch) 116 | self.to_tactiles = nn.Linear(decoder_dim, tactile_values_per_patch) 117 | 118 | self.num_tactiles = num_tactiles 119 | 120 | enc_pos_embedding = PositionalEncoding2D(encoder_dim) 121 | 122 | sample_image = torch.zeros((1, self.encoder.image_height//self.encoder.image_patch_height, self.encoder.image_width//self.encoder.image_patch_width, encoder_dim)) 123 | 124 | image_pos_embedding = enc_pos_embedding(sample_image).flatten(1,2) # 1 x 1image_patches x encoder_dim 125 | print("image_pos_embedding.shape: ", image_pos_embedding.shape) 126 | self.register_buffer('image_enc_pos_embedding', image_pos_embedding) # 1 x image_patches x encoder_dim 127 | 128 | sample_tactile = torch.zeros((1, self.encoder.tactile_height//self.encoder.tactile_patch_height, self.encoder.tactile_width//self.encoder.tactile_patch_width, encoder_dim)) 129 | 130 | tactile_pos_embedding = enc_pos_embedding(sample_tactile).flatten(1,2) # 1 x 1tactile_patches x encoder_dim 131 | 132 | self.register_buffer('tactile_enc_pos_embedding', repeat(tactile_pos_embedding, 'b n d -> b (v n) d', v = self.num_tactiles)) # 1 x tactile_patches x encoder_dim 133 | 134 | sample_image = torch.zeros((1, self.encoder.image_height//self.encoder.image_patch_height, self.encoder.image_width//self.encoder.image_patch_width, decoder_dim)) 135 | image_pos_embedding = enc_pos_embedding(sample_image).flatten(1,2) # 1 x 1image_patches x decoder_dim 136 | self.register_buffer('image_dec_pos_embedding', image_pos_embedding) # 1 x image_patches x decoder_dim 137 | 138 | sample_tactile = torch.zeros((1, self.encoder.tactile_height//self.encoder.tactile_patch_height, self.encoder.tactile_width//self.encoder.tactile_patch_width, decoder_dim)) 139 | tactile_pos_embedding = enc_pos_embedding(sample_tactile).flatten(1,2) # 1 x 1tactile_patches x decoder_dim 140 | self.register_buffer('tactile_dec_pos_embedding', repeat(tactile_pos_embedding, 'b n d -> b (v n) d', v = self.num_tactiles)) # 1 x tactile_patches x decoder_dim 141 | 142 | self.encoder_modality_embedding = nn.Embedding((1 + self.num_tactiles), encoder_dim) 143 | self.decoder_modality_embedding = nn.Embedding((1 + self.num_tactiles), decoder_dim) 144 | 145 | 146 | def forward(self, x, use_vision=True, use_tactile=True): 147 | 148 | if 'image' in x.keys(): 149 | device = x['image'].device 150 | else: 151 | device = x['tactile1'].device 152 | 
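# no 'image' key in the input batch: take the device from the tactile tensors and disable the vision branch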
use_vision = False 153 | 154 | # get patches 155 | 156 | if use_vision: 157 | image_patches = self.image_to_patch(x['image']) 158 | batch, num_image_patches, *_ = image_patches.shape 159 | else: 160 | image_patches = torch.zeros((x['tactile1'].shape[0], 0, 3)).to(device) 161 | num_image_patches = 0 162 | 163 | if self.num_tactiles > 0 and use_tactile: 164 | tactile_patches_list = [] 165 | for i in range(1,self.num_tactiles+1): 166 | tactile_patches_list.append(self.tactile_to_patch(x['tactile'+str(i)])) 167 | 168 | tactile_patches = torch.cat(tactile_patches_list, dim=1) 169 | batch, num_tactile_patches, *_ = tactile_patches.shape 170 | else: 171 | tactile_patches = torch.zeros((x['image'].shape[0], 0, 3)).to(device) 172 | num_tactile_patches = 0 173 | 174 | num_patches = num_image_patches + num_tactile_patches 175 | 176 | num_decoder_patches = num_patches 177 | 178 | # patch to encoder tokens and add positions 179 | 180 | if self.early_conv_masking: 181 | if use_vision: 182 | image_tokens = self.early_conv_vision(x['image']) 183 | else: 184 | image_tokens = torch.zeros((batch, 0, self.encoder_dim)).to(device) 185 | if self.num_tactiles > 0 and use_tactile: 186 | tactile_tokens_list = [] 187 | for i in range(1,self.num_tactiles+1): 188 | tactile_tokens_list.append(self.early_conv_tactile(x['tactile'+str(i)])) 189 | tactile_tokens = torch.cat(tactile_tokens_list, dim=1) 190 | else: 191 | tactile_tokens = torch.zeros((batch, 0, image_tokens.shape[-1])).to(device) 192 | else: 193 | if use_vision: 194 | image_tokens = self.image_patch_to_emb(image_patches) 195 | else: 196 | image_tokens = torch.zeros((batch, 0, self.encoder_dim)).to(device) 197 | if self.num_tactiles > 0 and use_tactile: 198 | tactile_tokens = self.tactile_patch_to_emb(tactile_patches) 199 | else: 200 | tactile_tokens = torch.zeros((batch, 0, image_tokens.shape[-1])).to(device) 201 | 202 | if use_vision: 203 | 204 | if self.use_sincosmod_encodings: 205 | image_tokens += self.encoder_modality_embedding(torch.tensor(0, device = device)) 206 | image_tokens = image_tokens + self.image_enc_pos_embedding 207 | 208 | if self.num_tactiles > 0 and use_tactile: 209 | num_single_tactile_patches = num_tactile_patches//(self.num_tactiles) 210 | for i in range(self.num_tactiles): 211 | if self.use_sincosmod_encodings: 212 | tactile_tokens[:, i*num_single_tactile_patches:(i+1)*num_single_tactile_patches] += self.encoder_modality_embedding(torch.tensor(1+i, device = device)) 213 | if self.use_sincosmod_encodings: 214 | tactile_tokens = tactile_tokens + self.tactile_enc_pos_embedding 215 | 216 | tokens = torch.cat((image_tokens, tactile_tokens), dim=1) 217 | 218 | if not self.use_sincosmod_encodings: 219 | tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)] 220 | 221 | # calculate the number of patches to be masked, and get random indices, dividing them into masked vs. unmasked 222 | 223 | num_masked = int(self.masking_ratio * num_patches) 224 | image_perc = num_image_patches/num_patches 225 | num_masked_image = int(num_masked * image_perc) 226 | if self.num_tactiles > 0 and use_tactile: 227 | num_masked_tactile = (num_masked - num_masked_image)//self.num_tactiles 228 | 229 | rand_indices_image = torch.rand(batch, num_image_patches, device = device).argsort(dim = -1) 230 | masked_indices_image, unmasked_indices_image = rand_indices_image[:, :num_masked_image], rand_indices_image[:, num_masked_image:] 231 | 232 | if self.num_tactiles > 0 and use_tactile: 233 | masked_indices_tactile = [] 234 | unmasked_indices_tactile = [] 235 | count
= num_image_patches 236 | for i in range(self.num_tactiles): 237 | rand_indices_tactile = torch.rand(batch, num_tactile_patches//self.num_tactiles, device = device).argsort(dim = -1)+count 238 | masked_indices_tactile.append(rand_indices_tactile[:, :num_masked_tactile]) 239 | unmasked_indices_tactile.append(rand_indices_tactile[:, num_masked_tactile:]) 240 | count += int(num_tactile_patches/self.num_tactiles) 241 | masked_indices_tactile = torch.cat(masked_indices_tactile, dim=1) 242 | unmasked_indices_tactile = torch.cat(unmasked_indices_tactile, dim=1) 243 | else: 244 | masked_indices_tactile = torch.zeros((batch, 0),dtype=torch.long).to(device) 245 | unmasked_indices_tactile = torch.zeros((batch, 0),dtype=torch.long).to(device) 246 | 247 | masked_indices = torch.cat((masked_indices_image, masked_indices_tactile), dim=1) 248 | unmasked_indices = torch.cat((unmasked_indices_image, unmasked_indices_tactile), dim=1) 249 | 250 | 251 | num_masked = masked_indices.shape[-1] 252 | 253 | # get the unmasked tokens to be encoded 254 | 255 | batch_range = torch.arange(batch, device = device)[:, None] 256 | tokens = tokens[batch_range, unmasked_indices] 257 | 258 | # get the patches to be masked for the final reconstruction loss 259 | 260 | if not self.early_conv_masking: # This is a hack to deal with the fact that masked_indices_image have different lengths per sample in the batch 261 | masked_image_patches = image_patches[batch_range, masked_indices_image] 262 | masked_tactile_patches = tactile_patches[batch_range, masked_indices_tactile-image_patches.shape[1]] 263 | 264 | # attend with vision transformer 265 | 266 | encoded_tokens = self.encoder.transformer(tokens) 267 | 268 | # project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder 269 | 270 | decoder_tokens = self.enc_to_dec(encoded_tokens) 271 | 272 | # reapply decoder position embedding to unmasked tokens 273 | if self.use_sincosmod_encodings: 274 | unmasked_decoder_tokens = decoder_tokens 275 | else: 276 | unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices) 277 | 278 | # repeat mask tokens for number of masked, and add the positions using the masked indices derived above 279 | mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked) 280 | if not self.use_sincosmod_encodings: 281 | mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices) 282 | 283 | # concat the masked tokens to the decoder tokens and attend with decoder 284 | 285 | decoder_tokens = torch.zeros(batch, num_decoder_patches, self.decoder_dim, device=device) 286 | decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens 287 | decoder_tokens[batch_range, masked_indices] = mask_tokens 288 | start_index = 0 289 | end_index = None 290 | 291 | decoder_image_tokens = decoder_tokens[:, start_index:num_image_patches+start_index] 292 | decoder_tactile_tokens = decoder_tokens[:, num_image_patches+start_index:end_index] 293 | 294 | if use_vision: 295 | if self.use_sincosmod_encodings: 296 | decoder_image_tokens += self.decoder_modality_embedding(torch.tensor(0, device=device)) 297 | decoder_image_tokens = decoder_image_tokens + self.image_dec_pos_embedding 298 | 299 | if self.num_tactiles > 0 and use_tactile: 300 | num_single_tactile_patches = num_tactile_patches//(self.num_tactiles) 301 | for i in range(self.num_tactiles): 302 | if self.use_sincosmod_encodings: 303 | decoder_tactile_tokens[:,
i*num_single_tactile_patches:(i+1)*num_single_tactile_patches] += self.decoder_modality_embedding(torch.tensor(1+i, device=device)) 304 | if self.use_sincosmod_encodings: 305 | decoder_tactile_tokens = decoder_tactile_tokens + self.tactile_dec_pos_embedding 306 | 307 | decoder_tokens[:,start_index:end_index] = torch.cat((decoder_image_tokens, decoder_tactile_tokens), dim=1) 308 | 309 | decoded_tokens = self.decoder(decoder_tokens) 310 | 311 | if self.early_conv_masking: 312 | image_tokens = decoded_tokens[:, :num_image_patches] 313 | tactile_tokens = decoded_tokens[:, num_image_patches:] 314 | 315 | pred_pixel_values = self.to_pixels(image_tokens) 316 | pred_tactile_values = self.to_tactiles(tactile_tokens) 317 | 318 | recon_loss = 0 319 | if self.num_tactiles > 0 and use_tactile: 320 | recon_loss += 10*F.mse_loss(pred_tactile_values, tactile_patches) 321 | if use_vision: 322 | recon_loss += F.mse_loss(pred_pixel_values, image_patches) 323 | 324 | else: 325 | # splice out the mask tokens and project to pixel values 326 | 327 | mask_image_tokens = decoded_tokens[batch_range, masked_indices_image] 328 | pred_pixel_values = self.to_pixels(mask_image_tokens) 329 | 330 | # splice out the mask tokens and project to tactile values 331 | 332 | mask_tactile_tokens = decoded_tokens[batch_range, masked_indices_tactile] 333 | pred_tactile_values = self.to_tactiles(mask_tactile_tokens) 334 | 335 | # calculate reconstruction loss 336 | recon_loss = 0 337 | if self.num_tactiles > 0 and use_tactile: 338 | recon_loss += 10*F.mse_loss(pred_tactile_values, masked_tactile_patches) 339 | if use_vision: 340 | recon_loss += F.mse_loss(pred_pixel_values, masked_image_patches) 341 | 342 | return recon_loss 343 | 344 | def reconstruct(self, x, mask_ratio=None, use_vision=True, use_tactile=True): 345 | 346 | if mask_ratio is None: 347 | mask_ratio = self.masking_ratio 348 | 349 | if 'image' in x.keys(): 350 | device = x['image'].device 351 | else: 352 | device = x['tactile1'].device 353 | use_vision = False 354 | 355 | # get patches 356 | if use_vision: 357 | image_patches = self.image_to_patch(x['image']) 358 | batch, num_image_patches, *_ = image_patches.shape 359 | else: 360 | image_patches = torch.zeros((x['tactile1'].shape[0], 0, 3)).to(device) 361 | num_image_patches = 0 362 | 363 | if self.num_tactiles > 0 and use_tactile: 364 | tactile_patches_list = [] 365 | for i in range(1,self.num_tactiles+1): 366 | tactile_patches_list.append(self.tactile_to_patch(x['tactile'+str(i)])) 367 | 368 | tactile_patches = torch.cat(tactile_patches_list, dim=1) 369 | batch, num_tactile_patches, *_ = tactile_patches.shape 370 | else: 371 | tactile_patches = torch.zeros((x['image'].shape[0], 0, 3)).to(device) 372 | num_tactile_patches = 0 373 | 374 | num_patches = num_image_patches + num_tactile_patches 375 | 376 | num_decoder_patches = num_patches 377 | 378 | 379 | # patch to encoder tokens and add positions 380 | 381 | if self.early_conv_masking: 382 | if use_vision: 383 | image_tokens = self.early_conv_vision(x['image']) 384 | else: 385 | image_tokens = torch.zeros((batch, 0, self.encoder_dim)).to(device) 386 | if self.num_tactiles > 0 and use_tactile: 387 | tactile_tokens_list = [] 388 | for i in range(1,self.num_tactiles+1): 389 | tactile_tokens_list.append(self.early_conv_tactile(x['tactile'+str(i)])) 390 | tactile_tokens = torch.cat(tactile_tokens_list, dim=1) 391 | else: 392 | tactile_tokens = torch.zeros((batch, 0, image_tokens.shape[-1])).to(device) 393 | else: 394 | if use_vision: 395 | image_tokens = 
self.image_patch_to_emb(image_patches) 396 | else: 397 | image_tokens = torch.zeros((batch, 0, self.encoder_dim)).to(device) 398 | if self.num_tactiles > 0 and use_tactile: 399 | tactile_tokens = self.tactile_patch_to_emb(tactile_patches) 400 | else: 401 | tactile_tokens = torch.zeros((batch, 0, image_tokens.shape[-1])).to(device) 402 | 403 | if use_vision: 404 | if self.use_sincosmod_encodings: 405 | image_tokens += self.encoder_modality_embedding(torch.tensor(0, device=device)) 406 | image_tokens = image_tokens + self.image_enc_pos_embedding 407 | 408 | if self.num_tactiles > 0 and use_tactile: 409 | num_single_tactile_patches = num_tactile_patches//(self.num_tactiles) 410 | for i in range(self.num_tactiles): 411 | if self.use_sincosmod_encodings: 412 | tactile_tokens[:, i*num_single_tactile_patches:(i+1)*num_single_tactile_patches] += self.encoder_modality_embedding(torch.tensor(1+i, device=device)) 413 | if self.use_sincosmod_encodings: 414 | tactile_tokens = tactile_tokens + self.tactile_enc_pos_embedding 415 | 416 | tokens = torch.cat((image_tokens, tactile_tokens), dim=1) 417 | if not self.use_sincosmod_encodings: 418 | tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)] 419 | 420 | # calculate the number of patches to be masked, and get random indices, dividing them into masked vs. unmasked 421 | image_patches_vis = image_patches.clone() 422 | tactile_patches_vis = tactile_patches.clone() 423 | 424 | if use_vision: 425 | num_masked_image = int(mask_ratio * num_image_patches) 426 | rand_indices_image = torch.rand(batch, num_image_patches, device = device).argsort(dim = -1) 427 | masked_indices_image, unmasked_indices_image = rand_indices_image[:, :num_masked_image], rand_indices_image[:, num_masked_image:] 428 | else: 429 | masked_indices_image = torch.zeros((batch, 0),dtype=torch.long).to(device) 430 | unmasked_indices_image = torch.zeros((batch, 0),dtype=torch.long).to(device) 431 | 432 | if self.num_tactiles > 0 and use_tactile: 433 | num_masked_tactile = int(mask_ratio * num_tactile_patches / self.num_tactiles) 434 | masked_indices_tactile = [] 435 | unmasked_indices_tactile = [] 436 | count = num_image_patches 437 | for i in range(self.num_tactiles): 438 | rand_indices_tactile = torch.rand(batch, num_tactile_patches//self.num_tactiles, device = device).argsort(dim = -1)+count 439 | masked_indices_tactile.append(rand_indices_tactile[:, :num_masked_tactile]) 440 | unmasked_indices_tactile.append(rand_indices_tactile[:, num_masked_tactile:]) 441 | count += int(num_tactile_patches/self.num_tactiles) 442 | masked_indices_tactile = torch.cat(masked_indices_tactile, dim=1) 443 | unmasked_indices_tactile = torch.cat(unmasked_indices_tactile, dim=1) 444 | else: 445 | masked_indices_tactile = torch.zeros((batch, 0),dtype=torch.long).to(device) 446 | unmasked_indices_tactile = torch.zeros((batch, 0),dtype=torch.long).to(device) 447 | 448 | masked_indices = torch.cat((masked_indices_image, masked_indices_tactile), dim=1) 449 | unmasked_indices = torch.cat((unmasked_indices_image, unmasked_indices_tactile), dim=1) 450 | 451 | num_masked = masked_indices.shape[-1] 452 | 453 | # get the unmasked tokens to be encoded 454 | 455 | batch_range = torch.arange(batch, device = device)[:, None] 456 | tokens = tokens[batch_range, unmasked_indices] 457 | 458 | if not self.early_conv_masking: # Hack: see forward() for explanation 459 | masked_image_patches = image_patches[batch_range, masked_indices_image].clone() 460 | masked_tactile_patches = tactile_patches[batch_range,
masked_indices_tactile-image_patches.shape[1]].clone() 461 | 462 | # set the masked patches to 0.5, for visualization (h: num_patches_in_height, w: num_patches_in_width, c: num_channels) 463 | if use_vision: 464 | image_transform = Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', 465 | h = self.encoder.image_height//self.encoder.image_patch_height, w = self.encoder.image_width//self.encoder.image_patch_width, 466 | p1 = self.encoder.image_patch_height, p2 = self.encoder.image_patch_width) 467 | if not self.early_conv_masking: 468 | image_patches[batch_range, masked_indices_image] = 0.5 469 | image_masked = image_transform(image_patches) 470 | else: 471 | image_patches_vis[batch_range, masked_indices_image] = 0.5 472 | image_masked = image_transform(image_patches_vis) 473 | 474 | if self.num_tactiles > 0 and use_tactile: 475 | 476 | tactile_transform = Rearrange('b (n h w) (p1 p2 c) -> b (n c) (h p1) (w p2)', 477 | h = self.encoder.tactile_height//self.encoder.tactile_patch_height, w = self.encoder.tactile_width//self.encoder.tactile_patch_width, 478 | p1 = self.encoder.tactile_patch_height, p2 = self.encoder.tactile_patch_width) 479 | if not self.early_conv_masking: 480 | tactile_patches[batch_range, masked_indices_tactile-image_patches.shape[1]] = np.inf 481 | tactile_masked = tactile_transform(tactile_patches) 482 | else: 483 | tactile_patches_vis[batch_range, masked_indices_tactile-image_patches.shape[1]] = np.inf 484 | tactile_masked = tactile_transform(tactile_patches_vis) 485 | 486 | # attend with vision transformer 487 | 488 | encoded_tokens = self.encoder.transformer(tokens) 489 | 490 | # project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder 491 | 492 | decoder_tokens = self.enc_to_dec(encoded_tokens) 493 | 494 | # reapply decoder position embedding to unmasked tokens 495 | 496 | if self.use_sincosmod_encodings: 497 | unmasked_decoder_tokens = decoder_tokens 498 | else: 499 | unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices) 500 | 501 | # repeat mask tokens for number of masked, and add the positions using the masked indices derived above 502 | 503 | mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked) 504 | if not self.use_sincosmod_encodings: 505 | mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices) 506 | 507 | # concat the masked tokens to the decoder tokens and attend with decoder 508 | 509 | decoder_tokens = torch.zeros(batch, num_decoder_patches, self.decoder_dim, device=device) 510 | decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens 511 | decoder_tokens[batch_range, masked_indices] = mask_tokens 512 | start_index = 0 513 | end_index = None 514 | 515 | decoder_image_tokens = decoder_tokens[:, start_index:num_image_patches+start_index] 516 | decoder_tactile_tokens = decoder_tokens[:, num_image_patches+start_index:end_index] 517 | 518 | if use_vision: 519 | if self.use_sincosmod_encodings: 520 | decoder_image_tokens += self.decoder_modality_embedding(torch.tensor(0, device = device)) 521 | decoder_image_tokens = decoder_image_tokens + self.image_dec_pos_embedding 522 | 523 | 524 | if self.num_tactiles > 0 and use_tactile: 525 | num_single_tactile_patches = num_tactile_patches//self.num_tactiles 526 | for i in range(self.num_tactiles): 527 | if self.use_sincosmod_encodings: 528 | decoder_tactile_tokens[:, i*num_single_tactile_patches:(i+1)*num_single_tactile_patches] += 
self.decoder_modality_embedding(torch.tensor(1+i, device = device)) 529 | if self.use_sincosmod_encodings: 530 | decoder_tactile_tokens = decoder_tactile_tokens + self.tactile_dec_pos_embedding 531 | 532 | 533 | decoder_tokens[:, start_index:end_index] = torch.cat((decoder_image_tokens, decoder_tactile_tokens), dim=1) 534 | 535 | decoded_tokens = self.decoder(decoder_tokens) 536 | 537 | if self.early_conv_masking: 538 | if use_vision: 539 | image_tokens = decoded_tokens[:, :num_image_patches] 540 | pred_pixel_values = self.to_pixels(image_tokens) 541 | 542 | recon_loss_image = F.mse_loss(pred_pixel_values, image_patches) 543 | image_patches = pred_pixel_values 544 | image_rec = image_transform(image_patches) 545 | 546 | if self.num_tactiles > 0 and use_tactile: 547 | tactile_tokens = decoded_tokens[:, num_image_patches:] 548 | pred_tactile_values = self.to_tactiles(tactile_tokens) 549 | recon_loss_tactile = F.mse_loss(pred_tactile_values, tactile_patches) 550 | tactile_patches = pred_tactile_values 551 | tactile_rec = tactile_transform(tactile_patches) 552 | 553 | else: 554 | # splice out the mask tokens and project to pixel values 555 | if use_vision: 556 | mask_image_tokens = decoded_tokens[batch_range, masked_indices_image] 557 | pred_pixel_values = self.to_pixels(mask_image_tokens) 558 | image_patches[batch_range, masked_indices_image] = pred_pixel_values 559 | image_rec = image_transform(image_patches) 560 | 561 | # splice out the mask tokens and project to tactile values 562 | 563 | if self.num_tactiles > 0 and use_tactile: 564 | mask_tactile_tokens = decoded_tokens[batch_range, masked_indices_tactile] 565 | pred_tactile_values = self.to_tactiles(mask_tactile_tokens) 566 | tactile_patches[batch_range, masked_indices_tactile-image_patches.shape[1]] = pred_tactile_values 567 | tactile_rec = tactile_transform(tactile_patches) 568 | 569 | recon_loss_image = torch.tensor(0, device=device) 570 | recon_loss_tactile = torch.tensor(0, device=device) 571 | if use_vision: 572 | recon_loss_image = F.mse_loss(pred_pixel_values, masked_image_patches) 573 | if self.num_tactiles > 0 and use_tactile: 574 | recon_loss_tactile = F.mse_loss(pred_tactile_values, masked_tactile_patches) 575 | 576 | return_dict = {} 577 | if use_vision: 578 | return_dict['image_rec'] = image_rec 579 | return_dict['image_masked'] = image_masked 580 | return_dict['recon_loss_image'] = recon_loss_image 581 | if self.num_tactiles > 0 and use_tactile: 582 | return_dict['tactile_rec'] = tactile_rec 583 | return_dict['tactile_masked'] = tactile_masked 584 | return_dict['recon_loss_tactile'] = recon_loss_tactile 585 | 586 | return return_dict 587 | 588 | def get_embeddings(self, x, eval=True, use_vision=True, use_tactile=True): 589 | 590 | if eval: 591 | self.eval() 592 | else: 593 | self.train() 594 | 595 | if 'image' in x.keys(): 596 | device = x['image'].device 597 | else: 598 | device = x['tactile1'].device 599 | use_vision = False 600 | 601 | # get patches 602 | if use_vision: 603 | image_patches = self.image_to_patch(x['image']) 604 | batch, num_image_patches, *_ = image_patches.shape 605 | else: 606 | image_patches = torch.zeros((x['tactile1'].shape[0], 0, 3)).to(device) 607 | num_image_patches = 0 608 | 609 | if self.num_tactiles > 0 and use_tactile: 610 | tactile_patches_list = [] 611 | for i in range(1,self.num_tactiles+1): 612 | tactile_patches_list.append(self.tactile_to_patch(x['tactile'+str(i)])) 613 | tactile_patches = torch.cat(tactile_patches_list, dim=1) 614 | batch, num_tactile_patches, *_ = 
tactile_patches.shape 615 | else: 616 | tactile_patches = torch.zeros((x['image'].shape[0], 0, 3)).to(device) 617 | num_tactile_patches = 0 618 | 619 | num_patches = num_image_patches + num_tactile_patches 620 | 621 | # patch to encoder tokens and add positions 622 | 623 | if self.early_conv_masking: 624 | if use_vision: 625 | image_tokens = self.early_conv_vision(x['image']) 626 | else: 627 | image_tokens = torch.zeros((batch, 0, self.encoder_dim)).to(device) 628 | 629 | if self.num_tactiles > 0 and use_tactile: 630 | tactile_tokens_list = [] 631 | for i in range(1,self.num_tactiles+1): 632 | tactile_tokens_list.append(self.early_conv_tactile(x['tactile'+str(i)])) 633 | tactile_tokens = torch.cat(tactile_tokens_list, dim=1) 634 | else: 635 | tactile_tokens = torch.zeros((batch, 0, image_tokens.shape[-1])).to(device) 636 | else: 637 | if use_vision: 638 | image_tokens = self.image_patch_to_emb(image_patches) 639 | else: 640 | image_tokens = torch.zeros((batch, 0, self.encoder_dim)).to(device) 641 | 642 | if self.num_tactiles > 0 and use_tactile: 643 | tactile_tokens = self.tactile_patch_to_emb(tactile_patches) 644 | else: 645 | tactile_tokens = torch.zeros((batch, 0, image_tokens.shape[-1])).to(device) 646 | 647 | if use_vision: 648 | if self.use_sincosmod_encodings: 649 | image_tokens += self.encoder_modality_embedding(torch.tensor(0, device = device)) 650 | image_tokens = image_tokens + self.image_enc_pos_embedding 651 | 652 | if self.num_tactiles > 0 and use_tactile: 653 | num_single_tactile_patches = num_tactile_patches//(self.num_tactiles) 654 | for i in range(self.num_tactiles): 655 | if self.use_sincosmod_encodings: 656 | tactile_tokens[:, i*num_single_tactile_patches:(i+1)*num_single_tactile_patches] += self.encoder_modality_embedding(torch.tensor(1+i, device = device)) 657 | if self.use_sincosmod_encodings: 658 | tactile_tokens = tactile_tokens + self.tactile_enc_pos_embedding 659 | 660 | tokens = torch.cat((image_tokens, tactile_tokens), dim=1) 661 | if not self.use_sincosmod_encodings: 662 | tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)] 663 | 664 | # attend with vision transformer 665 | 666 | encoded_tokens = self.encoder.transformer(tokens) 667 | 668 | return encoded_tokens 669 | 670 | def initialize_training(self, train_args): 671 | 672 | # training parameters 673 | lr = train_args['lr'] 674 | 675 | self.optimizer = optim.AdamW(self.parameters(), lr=lr) 676 | self.batch_size = train_args['batch_size'] 677 | 678 | 679 | def train_iterations(self, iterations, replay_buffer, no_tactile=False): 680 | 681 | if len(replay_buffer) < self.batch_size: 682 | print("Not enough samples in replay buffer") 683 | return 684 | 685 | self.train() 686 | 687 | t = tqdm(range(iterations), desc='Iteration', ncols=80) 688 | 689 | for i in t: 690 | 691 | x = random.choices(replay_buffer,k=self.batch_size) 692 | new_x = {} 693 | keys = ['image'] if no_tactile else ['image', 'tactile'] 694 | for key in keys: 695 | new_x[key] = np.stack([x[j][key] for j in range(self.batch_size)]) 696 | 697 | if 'image' in new_x: 698 | new_x['image'] = new_x['image'].transpose((0, 2, 3, 1, 4)) 699 | new_x['image'] = new_x['image'].reshape((new_x['image'].shape[0], new_x['image'].shape[1], new_x['image'].shape[2], -1)) 700 | if 'tactile' in new_x: 701 | new_x['tactile'] = new_x['tactile'].reshape((new_x['tactile'].shape[0], -1, new_x['tactile'].shape[3], new_x['tactile'].shape[4])) 702 | x = vt_load(new_x, frame_stack=self.frame_stack) 703 | 704 | if torch.cuda.is_available(): 705 | for key
in x: 706 | x[key] = x[key].to('cuda') 707 | self.optimizer.zero_grad() 708 | r_loss = self(x) 709 | r_loss.backward() 710 | torch.nn.utils.clip_grad_norm_(self.parameters(), 0.5) 711 | self.optimizer.step() 712 | 713 | t.set_description("rloss: {}, lr: {}, Progress: ".format(r_loss, self.optimizer.param_groups[0]['lr'])) 714 | 715 | self.eval() 716 | 717 | class VTT(nn.Module): 718 | 719 | def __init__(self, *, image_size, tactile_size, image_patch_size, tactile_patch_size, dim, depth, heads, mlp_dim, image_channels = 3, tactile_channels=3, dim_head = 64, dropout = 0., emb_dropout = 0, num_tactiles=2, frame_stack=1): 720 | super().__init__() 721 | image_height, image_width = pair(image_size) 722 | tactile_height, tactile_width = pair(tactile_size) 723 | image_patch_height, image_patch_width = pair(image_patch_size) 724 | tactile_patch_height, tactile_patch_width = pair(tactile_patch_size) 725 | 726 | self.image_height = image_height 727 | self.image_width = image_width 728 | self.tactile_height = tactile_height 729 | self.tactile_width = tactile_width 730 | self.image_patch_height = image_patch_height 731 | self.image_patch_width = image_patch_width 732 | self.tactile_patch_height = tactile_patch_height 733 | self.tactile_patch_width = tactile_patch_width 734 | 735 | self.image_channels = image_channels 736 | self.tactile_channels = tactile_channels 737 | 738 | self.frame_stack = frame_stack 739 | 740 | assert image_height % image_patch_height == 0 and image_width % image_patch_width == 0, 'Image dimensions must be divisible by the patch size.' 741 | assert tactile_height % tactile_patch_height == 0 and tactile_width % tactile_patch_width == 0, 'Tactile dimensions must be divisible by the patch size.' 742 | 743 | num_patches_image = (image_height // image_patch_height) * (image_width // image_patch_width) 744 | num_patches_tactile = (tactile_height // tactile_patch_height) * (tactile_width // tactile_patch_width) * num_tactiles 745 | 746 | num_patches = num_patches_image + num_patches_tactile 747 | 748 | image_patch_dim = image_channels * image_patch_height * image_patch_width 749 | tactile_patch_dim = tactile_channels * tactile_patch_height * tactile_patch_width 750 | 751 | self.image_to_patch_embedding = nn.Sequential( 752 | # Rearrange('b (n c) h w -> b c (n h) w', n = self.frame_stack, c = image_channels), 753 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = image_patch_height, p2 = image_patch_width), 754 | nn.LayerNorm(image_patch_dim), 755 | nn.Linear(image_patch_dim, dim), 756 | nn.LayerNorm(dim), 757 | ) 758 | self.tactile_to_patch_embedding = nn.Sequential( 759 | # Rearrange('b (n c) h w -> b c (n h) w', n = self.frame_stack, c = tactile_channels), 760 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = tactile_patch_height, p2 = tactile_patch_width), 761 | nn.LayerNorm(tactile_patch_dim), 762 | nn.Linear(tactile_patch_dim, dim), 763 | nn.LayerNorm(dim) 764 | ) 765 | 766 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 767 | self.dropout = nn.Dropout(emb_dropout) 768 | 769 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 770 | 771 | self.to_latent = nn.Identity() 772 | 773 | class MAEExtractor(BaseFeaturesExtractor): 774 | """ 775 | Feature extractor built around the pretrained multimodal MAE (VTMAE).
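Observations are embedded with `mae_model.get_embeddings`, refined by a depth-1 VTT transformer, mean-pooled over tokens, and flattened into a feature vector of size `dim_embeddings`.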
777 | 778 | :param observation_space: 779 | """ 780 | 781 | def __init__(self, observation_space: gym.Space, mae_model, dim_embeddings, vision_only_control, frame_stack) -> None: 782 | super().__init__(observation_space, dim_embeddings) 783 | self.flatten = nn.Flatten() 784 | self.mae_model = mae_model 785 | 786 | self.running_buffer = {} 787 | 788 | self.vision_only_control = vision_only_control 789 | 790 | self.frame_stack = frame_stack 791 | 792 | self.vit_layer = VTT( 793 | image_size = (64, 64), # not used 794 | tactile_size = (32, 32), # not used 795 | image_patch_size = 8, # not used 796 | tactile_patch_size = 4, # not used 797 | dim = dim_embeddings, 798 | depth = 1, 799 | heads = 4, 800 | mlp_dim = dim_embeddings*2, 801 | num_tactiles = 2, # not used 802 | ) 803 | 804 | def forward(self, observations: torch.Tensor) -> torch.Tensor: 805 | 806 | # print("image shape: ", observations['image'].shape) 807 | # print("tactile shape: ", observations['tactile'].shape) 808 | if 'image' in observations and len(observations['image'].shape) == 5: 809 | observations['image'] = observations['image'].permute(0, 2, 3, 1, 4) 810 | observations['image'] = observations['image'].reshape((observations['image'].shape[0], observations['image'].shape[1], observations['image'].shape[2], -1)) 811 | if 'tactile' in observations and len(observations['tactile'].shape) == 5: 812 | observations['tactile'] = observations['tactile'].reshape((observations['tactile'].shape[0], -1, observations['tactile'].shape[3], observations['tactile'].shape[4])) 813 | 814 | # Get embeddings 815 | vt_torch = vt_load(observations, frame_stack=self.frame_stack) 816 | if torch.cuda.is_available(): 817 | for key in vt_torch: 818 | vt_torch[key] = vt_torch[key].to('cuda') 819 | observations = self.mae_model.get_embeddings(vt_torch, eval=False, use_tactile=not self.vision_only_control) 820 | 821 | observations = self.vit_layer.transformer(observations) 822 | observations = torch.mean(observations, dim=1) 823 | 824 | flattened = self.flatten(observations) 825 | 826 | return flattened 827 | 828 | class MAEPolicy(ActorCriticPolicy): 829 | 830 | def __init__( 831 | self, 832 | observation_space: spaces.Space, 833 | action_space: spaces.Space, 834 | lr_schedule: Schedule, 835 | net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, 836 | activation_fn: Type[nn.Module] = nn.Tanh, 837 | ortho_init: bool = True, 838 | use_sde: bool = False, 839 | log_std_init: float = 0.0, 840 | full_std: bool = True, 841 | use_expln: bool = False, 842 | squash_output: bool = False, 843 | features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, 844 | features_extractor_kwargs: Optional[Dict[str, Any]] = None, 845 | share_features_extractor: bool = True, 846 | normalize_images: bool = True, 847 | mae_model = None, 848 | dim_embeddings = 256, 849 | frame_stack = 1, 850 | vision_only_control = False, 851 | optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam, 852 | optimizer_kwargs: Optional[Dict[str, Any]] = None, 853 | ): 854 | 855 | 856 | features_extractor_class = MAEExtractor 857 | features_extractor_kwargs = {'mae_model': mae_model, 'dim_embeddings': dim_embeddings, 'vision_only_control': vision_only_control, 'frame_stack': frame_stack} 858 | ortho_init = False 859 | 860 | super().__init__( 861 | observation_space, 862 | action_space, 863 | lr_schedule, 864 | net_arch, 865 | activation_fn, 866 | ortho_init, 867 | use_sde, 868 | log_std_init, 869 | full_std, 870 | use_expln, 871 | squash_output, 872 | 
features_extractor_class, 873 | features_extractor_kwargs, 874 | share_features_extractor, 875 | normalize_images, 876 | optimizer_class, 877 | optimizer_kwargs 878 | ) 879 | 880 | def forward(self, obs: torch.Tensor, deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 881 | """ 882 | Forward pass in all the networks (actor and critic) 883 | 884 | :param obs: Observation 885 | :param deterministic: Whether to sample or use deterministic actions 886 | :return: action, value and log probability of the action 887 | """ 888 | # Preprocess the observation if needed 889 | features = self.extract_features(obs) 890 | 891 | if self.share_features_extractor: 892 | latent_pi, latent_vf = self.mlp_extractor(features) 893 | else: 894 | pi_features, vf_features = features 895 | latent_pi = self.mlp_extractor.forward_actor(pi_features) 896 | latent_vf = self.mlp_extractor.forward_critic(vf_features) 897 | # Evaluate the values for the given observations 898 | values = self.value_net(latent_vf) 899 | distribution = self._get_action_dist_from_latent(latent_pi) 900 | actions = distribution.get_actions(deterministic=deterministic) 901 | log_prob = distribution.log_prob(actions) 902 | actions = actions.reshape((-1, *self.action_space.shape)) 903 | return actions, values, log_prob 904 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | appdirs==1.4.4 3 | cachetools==5.3.2 4 | certifi==2023.7.22 5 | charset-normalizer==3.3.2 6 | click==7.1.2 7 | cloudpickle==3.0.0 8 | contourpy==1.2.0 9 | cycler==0.12.1 10 | decorator==4.4.2 11 | docker-pycreds==0.4.0 12 | docstring_parser==0.16 13 | einops==0.7.0 14 | etils==1.7.0 15 | evdev==1.7.1 16 | exceptiongroup==1.1.3 17 | Farama-Notifications==0.0.4 18 | filelock==3.13.1 19 | fonttools==4.44.3 20 | fsspec==2023.10.0 21 | ftfy==6.1.1 22 | gitdb==4.0.11 23 | GitPython==3.1.40 24 | glfw==2.6.2 25 | google-auth==2.23.4 26 | google-auth-oauthlib==1.1.0 27 | grpcio==1.59.2 28 | gymnasium==0.29.1 29 | h5py==3.11.0 30 | hidapi==0.14.0 31 | huggingface-hub==0.19.3 32 | idna==3.4 33 | imageio==2.29.0 34 | imageio-ffmpeg==0.4.9 35 | importlib-metadata==4.13.0 36 | importlib_resources==6.4.0 37 | iniconfig==2.0.0 38 | iopath==0.1.10 39 | Jinja2==3.1.2 40 | kiwisolver==1.4.5 41 | llvmlite==0.43.0 42 | lxml==5.2.2 43 | Markdown==3.5.1 44 | markdown-it-py==3.0.0 45 | MarkupSafe==2.1.3 46 | matplotlib==3.8.1 47 | mdurl==0.1.2 48 | mediapy==1.2.2 49 | moviepy==1.0.3 50 | mpmath==1.3.0 51 | mujoco==3.1.6 52 | networkx==3.2.1 53 | numba==0.60.0 54 | numpy==1.26.2 55 | nvidia-cublas-cu12==12.1.3.1 56 | nvidia-cuda-cupti-cu12==12.1.105 57 | nvidia-cuda-nvrtc-cu12==12.1.105 58 | nvidia-cuda-runtime-cu12==12.1.105 59 | nvidia-cudnn-cu12==8.9.2.26 60 | nvidia-cufft-cu12==11.0.2.54 61 | nvidia-curand-cu12==10.3.2.106 62 | nvidia-cusolver-cu12==11.4.5.107 63 | nvidia-cusparse-cu12==12.1.0.106 64 | nvidia-nccl-cu12==2.18.1 65 | nvidia-nvjitlink-cu12==12.3.101 66 | nvidia-nvtx-cu12==12.1.105 67 | oauthlib==3.2.2 68 | opencv-python==4.8.1.78 69 | pandas==2.1.3 70 | pettingzoo==1.24.1 71 | Pillow==10.0.1 72 | pip-date==1.0.5 73 | pluggy==1.3.0 74 | portalocker==2.8.2 75 | positional-encodings==6.0.1 76 | proglog==0.1.10 77 | protobuf==4.23.4 78 | pyasn1==0.5.0 79 | pyasn1-modules==0.3.0 80 | pynput==1.7.7 81 | PyOpenGL==3.1.7 82 | pyparsing==3.1.1 83 | pytest==7.4.3 84 | python-xlib==0.33 85 | pytz==2023.3.post1 86 | 
PyYAML==6.0.1 87 | regex==2023.10.3 88 | req2toml==1.2.0 89 | requests==2.31.0 90 | requests-oauthlib==1.3.1 91 | rich==13.7.1 92 | rsa==4.9 93 | safetensors==0.4.0 94 | scipy==1.14.0 95 | sentry-sdk==1.35.0 96 | setproctitle==1.3.3 97 | shtab==1.7.1 98 | smmap==5.0.1 99 | stable_baselines3==2.3.2 100 | sympy==1.12 101 | tensorboard==2.15.1 102 | tensorboard-data-server==0.7.2 103 | termcolor==2.4.0 104 | timm==0.9.10 105 | tomli==2.0.1 106 | torch==2.1.1 107 | torchaudio==2.1.1 108 | torchvision==0.16.1 109 | tqdm==4.66.1 110 | triton==2.1.0 111 | tyro==0.8.4 112 | tzdata==2023.3 113 | urllib3==2.1.0 114 | vit-pytorch==1.6.4 115 | wandb==0.16.0 116 | wcwidth==0.2.10 117 | Werkzeug==3.0.1 118 | zipp==3.17.0 119 | -------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carlosferrazza/M3L/ec9786b5f1f28d834e5835b0ca33ec4bed2e7788/teaser.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from stable_baselines3 import PPO 5 | from models.ppo_mae import PPO_MAE 6 | 7 | from stable_baselines3.common.vec_env import SubprocVecEnv, VecNormalize, DummyVecEnv 8 | from stable_baselines3.common.utils import set_random_seed 9 | 10 | import tactile_envs 11 | import envs 12 | from utils.callbacks import create_callbacks 13 | from models.pretrain_models import VTT, VTMAE, MAEPolicy 14 | 15 | def str2bool(v): 16 | if v.lower() == "true": 17 | return True 18 | if v.lower() == "false": 19 | return False 20 | raise ValueError(f"boolean argument should be either True or False (got {v})") 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser("M3L") 24 | 25 | parser.add_argument("--seed", type=int, default=0) 26 | parser.add_argument("--save_freq", type=int, default=int(1e5)) 27 | parser.add_argument("--eval_every", type=int, default=int(2e5)) 28 | parser.add_argument("--total_timesteps", type=int, default=int(3e6)) 29 | 30 | parser.add_argument("--wandb_dir", type=str, default="./wandb/") 31 | parser.add_argument("--wandb_id", type=str, default=None) 32 | parser.add_argument("--wandb_entity", type=str, default=None) 33 | 34 | # Environment. 35 | parser.add_argument( 36 | "--env", 37 | type=str, 38 | default="tactile_envs/Insertion-v0", 39 | choices=[ 40 | "tactile_envs/Insertion-v0", 41 | "Door", 42 | "HandManipulateBlockRotateZFixed-v1", 43 | "HandManipulateEggRotateFixed-v1", 44 | "HandManipulatePenRotateFixed-v1" 45 | ], 46 | ) 47 | parser.add_argument("--n_envs", type=int, default=8) 48 | parser.add_argument( 49 | "--state_type", 50 | type=str, 51 | default="vision_and_touch", 52 | choices=["vision", "touch", "vision_and_touch"] 53 | ) 54 | parser.add_argument("--norm_reward", type=str2bool, default=True) 55 | parser.add_argument("--use_latch", type=str2bool, default=True) 56 | 57 | parser.add_argument("--camera_idx", type=int, default=0, choices=[0, 1, 2, 3]) 58 | parser.add_argument("--frame_stack", type=int, default=4) 59 | parser.add_argument("--no_rotation", type=str2bool, default=True) 60 | 61 | # MAE. 
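    # The boolean flags in this parser use the str2bool helper above rather than
    # argparse's type=bool, which is a classic pitfall: bool("False") is True,
    # since any non-empty string is truthy, so `--flag False` would silently
    # enable the flag. For illustration:
    #
    #     >>> bool("False")
    #     True
    #     >>> str2bool("False")
    #     False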
62 |     parser.add_argument("--representation", type=str2bool, default=True)
63 |     parser.add_argument("--early_conv_masking", type=str2bool, default=True)
64 | 
65 |     parser.add_argument("--dim_embedding", type=int, default=256)
66 |     parser.add_argument("--use_sincosmod_encodings", type=str2bool, default=True)
67 |     parser.add_argument("--masking_ratio", type=float, default=0.95)
68 | 
69 |     parser.add_argument("--mae_batch_size", type=int, default=32)
70 |     parser.add_argument("--train_mae_every", type=int, default=1)
71 | 
72 |     # PPO.
73 |     parser.add_argument("--rollout_length", type=int, default=32768)
74 |     parser.add_argument("--ppo_epochs", type=int, default=10)
75 |     parser.add_argument("--lr_ppo", type=float, default=1e-4)
76 |     parser.add_argument("--vision_only_control", type=str2bool, default=False)
77 |     parser.add_argument("--batch_size", type=int, default=512)
78 | 
79 |     # PPO-MAE.
80 |     parser.add_argument("--separate_optimizer", type=str2bool, default=False)
81 | 
82 |     config = parser.parse_args()
83 | 
84 |     set_random_seed(config.seed)
85 | 
86 |     num_tactiles = 0
87 |     if config.state_type == "vision_and_touch" or config.state_type == "touch":
88 |         num_tactiles = 2
89 |         if config.env == "HandManipulateBlockRotateZFixed-v1" or config.env == "HandManipulateEggRotateFixed-v1" or config.env == "HandManipulatePenRotateFixed-v1":
90 |             num_tactiles = 1
91 | 
92 |     env_config = {
93 |         "use_latch": config.use_latch,
94 |     }
95 | 
96 |     objects = [
97 |         "square",
98 |         "triangle",
99 |         "horizontal",
100 |         "vertical",
101 |         "trapezoidal",
102 |         "rhombus",
103 |     ]
104 |     holders = ["holder1", "holder2", "holder3"]
105 | 
106 |     env_list = [
107 |         envs.make_env(
108 |             config.env,
109 |             i,
110 |             config.seed,
111 |             config.state_type,
112 |             objects=objects,
113 |             holders=holders,
114 |             camera_idx=config.camera_idx,
115 |             frame_stack=config.frame_stack,
116 |             no_rotation=config.no_rotation,
117 |             **env_config,
118 |         )
119 |         for i in range(config.n_envs)
120 |     ]
121 | 
122 |     if config.n_envs < 100:
123 |         env = SubprocVecEnv(env_list)
124 |     else:
125 |         env = DummyVecEnv(env_list)
126 |     env = VecNormalize(env, norm_obs=False, norm_reward=config.norm_reward)
127 | 
128 |     v = VTT(
129 |         image_size=(64, 64),
130 |         tactile_size=(32, 32),
131 |         image_patch_size=8,
132 |         tactile_patch_size=4,
133 |         dim=config.dim_embedding,
134 |         depth=4,
135 |         heads=4,
136 |         mlp_dim=config.dim_embedding * 2,
137 |         num_tactiles=num_tactiles,
138 |         image_channels=3*config.frame_stack,
139 |         tactile_channels=3*config.frame_stack,
140 |         frame_stack=config.frame_stack,
141 |     )
142 | 
143 |     mae = VTMAE(
144 |         encoder=v,
145 |         masking_ratio=config.masking_ratio,
146 |         decoder_dim=config.dim_embedding,
147 |         decoder_depth=3,
148 |         decoder_heads=4,
149 |         num_tactiles=num_tactiles,
150 |         early_conv_masking=config.early_conv_masking,
151 |         use_sincosmod_encodings=config.use_sincosmod_encodings,
152 |         frame_stack=config.frame_stack
153 |     )
154 |     if torch.cuda.is_available():
155 |         mae.cuda()
156 |     mae.eval()
157 | 
158 |     if config.representation:
159 |         mae.initialize_training({"lr": 1e-4, "batch_size": config.mae_batch_size})
160 | 
161 |         policy = MAEPolicy
162 |         policy_kwargs = {
163 |             "mae_model": mae,
164 |             "dim_embeddings": config.dim_embedding,
165 |             "vision_only_control": config.vision_only_control,
166 |             "net_arch": dict(pi=[256, 256], vf=[256, 256]),
167 |             "frame_stack": config.frame_stack,
168 |         }
169 | 
170 |         model = PPO_MAE(
171 |             policy,
172 |             env,
173 |             verbose=1,
174 |             learning_rate=config.lr_ppo,
175 | 
tensorboard_log=config.wandb_dir+"tensorboard/", 176 | batch_size=config.batch_size, 177 | n_steps=config.rollout_length // config.n_envs, 178 | n_epochs=config.ppo_epochs, 179 | mae_batch_size=config.mae_batch_size, 180 | separate_optimizer=config.separate_optimizer, 181 | policy_kwargs=policy_kwargs, 182 | mae=mae, 183 | ) 184 | 185 | 186 | callbacks = create_callbacks( 187 | config, model, num_tactiles, objects, holders 188 | ) 189 | model.learn(total_timesteps=config.total_timesteps, callback=callbacks) 190 | else: 191 | 192 | model = PPO( 193 | MAEPolicy, 194 | env, 195 | verbose=1, 196 | learning_rate=config.lr_ppo, 197 | tensorboard_log=config.wandb_dir+"ppo_privileged_tensorboard/", 198 | batch_size=config.batch_size, 199 | n_steps=config.rollout_length // config.n_envs, 200 | n_epochs=config.ppo_epochs, 201 | policy_kwargs={ 202 | "mae_model": mae, 203 | "dim_embeddings": config.dim_embedding, 204 | "net_arch": dict(pi=[256, 256], vf=[256, 256]), 205 | "frame_stack": config.frame_stack, 206 | }, 207 | ) 208 | callbacks = create_callbacks( 209 | config, model, num_tactiles, objects, holders 210 | ) 211 | model.learn(total_timesteps=config.total_timesteps, callback=callbacks) 212 | 213 | 214 | if __name__ == "__main__": 215 | main() 216 | -------------------------------------------------------------------------------- /utils/add_tactile.py: -------------------------------------------------------------------------------- 1 | """Wrapper for resizing observations.""" 2 | from __future__ import annotations 3 | 4 | import numpy as np 5 | 6 | import gymnasium as gym 7 | from gymnasium.error import DependencyNotInstalled 8 | from gymnasium.spaces import Box, Dict 9 | 10 | 11 | class AddTactile(gym.ObservationWrapper): 12 | 13 | def __init__(self, env: gym.Env) -> None: 14 | 15 | gym.ObservationWrapper.__init__(self, env) 16 | 17 | self.obs_shape = (3, 32, 32) 18 | 19 | self.observation_space['tactile'] = Box( 20 | low=-np.inf, 21 | high=np.inf, 22 | shape=self.obs_shape, 23 | dtype=np.float32, 24 | ) 25 | 26 | self.mj_data = self.env.unwrapped.data 27 | 28 | def observation(self, observation): 29 | 30 | rh_palm_touch = self.mj_data.sensor('rh_palm_touch').data.reshape((3, 3, 3)) # 3 x nx x ny 31 | rh_palm_touch = rh_palm_touch[[1, 2, 0]] # zxy -> xyz 32 | 33 | rh_ffproximal_touch = self.mj_data.sensor('rh_ffproximal_touch').data.reshape((3, 3, 3)) 34 | rh_ffproximal_touch = rh_ffproximal_touch[[1, 2, 0]] 35 | 36 | rh_ffmiddle_touch = self.mj_data.sensor('rh_ffmiddle_touch').data.reshape((3, 3, 3)) 37 | rh_ffmiddle_touch = rh_ffmiddle_touch[[1, 2, 0]] 38 | 39 | rh_ffdistal_touch = self.mj_data.sensor('rh_ffdistal_touch').data.reshape((3, 3, 3)) 40 | rh_ffdistal_touch = rh_ffdistal_touch[[1, 2, 0]] 41 | 42 | rh_mfproximal_touch = self.mj_data.sensor('rh_mfproximal_touch').data.reshape((3, 3, 3)) 43 | rh_mfproximal_touch = rh_mfproximal_touch[[1, 2, 0]] 44 | 45 | rh_mfmiddle_touch = self.mj_data.sensor('rh_mfmiddle_touch').data.reshape((3, 3, 3)) 46 | rh_mfmiddle_touch = rh_mfmiddle_touch[[1, 2, 0]] 47 | 48 | rh_mfdistal_touch = self.mj_data.sensor('rh_mfdistal_touch').data.reshape((3, 3, 3)) 49 | rh_mfdistal_touch = rh_mfdistal_touch[[1, 2, 0]] 50 | 51 | rh_rfproximal_touch = self.mj_data.sensor('rh_rfproximal_touch').data.reshape((3, 3, 3)) 52 | rh_rfproximal_touch = rh_rfproximal_touch[[1, 2, 0]] 53 | 54 | rh_rfmiddle_touch = self.mj_data.sensor('rh_rfmiddle_touch').data.reshape((3, 3, 3)) 55 | rh_rfmiddle_touch = rh_rfmiddle_touch[[1, 2, 0]] 56 | 57 | rh_rfdistal_touch = 
self.mj_data.sensor('rh_rfdistal_touch').data.reshape((3, 3, 3)) 58 | rh_rfdistal_touch = rh_rfdistal_touch[[1, 2, 0]] 59 | 60 | rh_lfmetacarpal_touch = self.mj_data.sensor('rh_lfmetacarpal_touch').data.reshape((3, 3, 3)) 61 | rh_lfmetacarpal_touch = rh_lfmetacarpal_touch[[1, 2, 0]] 62 | 63 | rh_lfproximal_touch = self.mj_data.sensor('rh_lfproximal_touch').data.reshape((3, 3, 3)) 64 | rh_lfproximal_touch = rh_lfproximal_touch[[1, 2, 0]] 65 | 66 | rh_lfmiddle_touch = self.mj_data.sensor('rh_lfmiddle_touch').data.reshape((3, 3, 3)) 67 | rh_lfmiddle_touch = rh_lfmiddle_touch[[1, 2, 0]] 68 | 69 | rh_lfdistal_touch = self.mj_data.sensor('rh_lfdistal_touch').data.reshape((3, 3, 3)) 70 | rh_lfdistal_touch = rh_lfdistal_touch[[1, 2, 0]] 71 | 72 | rh_thproximal_touch = self.mj_data.sensor('rh_thproximal_touch').data.reshape((3, 3, 3)) 73 | rh_thproximal_touch = rh_thproximal_touch[[1, 2, 0]] 74 | 75 | rh_thmiddle_touch = self.mj_data.sensor('rh_thmiddle_touch').data.reshape((3, 3, 3)) 76 | rh_thmiddle_touch = rh_thmiddle_touch[[1, 2, 0]] 77 | 78 | rh_thdistal_touch = self.mj_data.sensor('rh_thdistal_touch').data.reshape((3, 3, 3)) 79 | rh_thdistal_touch = rh_thdistal_touch[[1, 2, 0]] 80 | 81 | block_1 = np.zeros((3, 8, 32)) 82 | 83 | block_2 = np.concatenate((np.zeros((3, 3, 2)), rh_lfdistal_touch, np.zeros((3, 3, 3)), rh_rfdistal_touch, np.zeros((3, 3, 3)), rh_mfdistal_touch, np.zeros((3, 3, 3)), rh_ffdistal_touch, np.zeros((3, 3, 4)), rh_thdistal_touch, np.zeros((3, 3, 2))), axis=2) 84 | 85 | block_3 = np.concatenate((np.zeros((3, 3, 2)), rh_lfmiddle_touch, np.zeros((3, 3, 3)), rh_rfmiddle_touch, np.zeros((3, 3, 3)), rh_mfmiddle_touch, np.zeros((3, 3, 3)), rh_ffmiddle_touch, np.zeros((3, 3, 4)), rh_thmiddle_touch, np.zeros((3, 3, 2))), axis=2) 86 | 87 | block_4 = np.concatenate((np.zeros((3, 3, 2)), rh_lfproximal_touch, np.zeros((3, 3, 3)), rh_rfproximal_touch, np.zeros((3, 3, 3)), rh_mfproximal_touch, np.zeros((3, 3, 3)), rh_ffproximal_touch, np.zeros((3, 3, 4)), rh_thproximal_touch, np.zeros((3, 3, 2))), axis=2) 88 | 89 | block_5 = np.concatenate((np.zeros((3, 3, 2)), rh_lfmetacarpal_touch, np.zeros((3, 3, 27))), axis=2) 90 | 91 | block_6 = np.concatenate((np.zeros((3, 3, 14)), rh_palm_touch, np.zeros((3, 3, 15))), axis=2) 92 | 93 | block_7 = np.zeros((3, 9, 32)) 94 | 95 | tactiles = np.concatenate((block_1, block_2, block_3, block_4, block_5, block_6, block_7), axis=1) 96 | 97 | tactiles = np.sign(tactiles) * np.log(1 + np.abs(tactiles)) 98 | observation['tactile'] = tactiles 99 | 100 | return observation 101 | -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import numpy as np 3 | 4 | from stable_baselines3.common.callbacks import CheckpointCallback, BaseCallback 5 | from stable_baselines3.common.utils import configure_logger 6 | from stable_baselines3.common.vec_env import DummyVecEnv 7 | 8 | import envs 9 | from utils.wandb_logger import WandbLogger 10 | from utils.pretrain_utils import log_videos 11 | 12 | class TensorboardCallback(BaseCallback): 13 | """ 14 | Custom callback for plotting additional values in tensorboard. 
15 | """ 16 | 17 | def __init__(self, verbose=0): 18 | super().__init__(verbose) 19 | 20 | def _on_step(self) -> bool: 21 | self.logger.record("rollout/avg_success", np.mean(self.model.ep_success_buffer)) 22 | return True 23 | 24 | 25 | class EvalCallback(BaseCallback): 26 | def __init__( 27 | self, 28 | env, 29 | state_type, 30 | no_tactile=False, 31 | representation=True, 32 | eval_every=1, 33 | verbose=0, 34 | config=None, 35 | objects=["square"], 36 | holders=["holder2"], 37 | camera_idx=0, 38 | frame_stack=1, 39 | ): 40 | super(EvalCallback, self).__init__(verbose) 41 | self.n_samples = 4 42 | self.eval_seed = 100 43 | self.no_tactile = no_tactile 44 | self.representation = representation 45 | 46 | env_config = {"use_latch": config.use_latch} 47 | 48 | self.test_env = DummyVecEnv( 49 | [ 50 | envs.make_env( 51 | env, 52 | 0, 53 | self.eval_seed, 54 | state_type=state_type, 55 | objects=objects, 56 | holders=holders, 57 | camera_idx=camera_idx, 58 | frame_stack=frame_stack, 59 | no_rotation=config.no_rotation, 60 | **env_config 61 | ) 62 | ] 63 | ) 64 | self.count = 0 65 | self.eval_every = eval_every 66 | 67 | def _on_step(self) -> bool: 68 | return True 69 | 70 | def _on_rollout_start(self) -> None: 71 | self.count += 1 72 | if self.count >= self.eval_every: 73 | 74 | ret, obses, rewards_per_step = self.eval_model() 75 | frame_stack = obses[0]["image"].shape[1] 76 | self.logger.record("eval/return", ret) 77 | 78 | log_videos( 79 | obses, 80 | rewards_per_step, 81 | self.logger, 82 | self.model.num_timesteps, 83 | frame_stack=frame_stack, 84 | ) 85 | self.count = 0 86 | 87 | def eval_model(self): 88 | print("Collect eval rollout") 89 | obs = self.test_env.reset() 90 | dones = [False] 91 | reward = 0 92 | obses = [] 93 | rewards_per_step = [] 94 | while not dones[0]: 95 | action, _ = self.model.predict(obs, deterministic=False) 96 | obs, rewards, dones, info = self.test_env.step(action) 97 | reward += rewards[0] 98 | rewards_per_step.append(rewards[0]) 99 | obses.append(obs) 100 | 101 | return reward, obses, rewards_per_step 102 | 103 | 104 | def create_callbacks(config, model, num_tactiles, objects, holders): 105 | no_tactile = num_tactiles == 0 106 | project_name = "MultimodalLearning" 107 | if config.env in ["Door"]: 108 | project_name += "_robosuite" 109 | 110 | callbacks = [] 111 | 112 | eval_callback = EvalCallback( 113 | config.env, 114 | config.state_type, 115 | no_tactile=no_tactile, 116 | representation=config.representation, 117 | eval_every=config.eval_every // config.rollout_length, 118 | config=config, 119 | objects=objects, 120 | holders=holders, 121 | camera_idx=config.camera_idx, 122 | frame_stack=config.frame_stack, 123 | ) 124 | callbacks.append(eval_callback) 125 | 126 | checkpoint_callback = CheckpointCallback( 127 | save_freq=max(config.save_freq // config.n_envs, 1), 128 | save_path="./logs/", 129 | name_prefix="rl_model", 130 | save_replay_buffer=False, 131 | save_vecnormalize=True, 132 | ) 133 | callbacks.append(checkpoint_callback) 134 | callbacks.append(TensorboardCallback()) 135 | 136 | default_logger = configure_logger( 137 | verbose=1, tensorboard_log=model.tensorboard_log, tb_log_name="PPO" 138 | ) 139 | wandb.init( 140 | project=project_name, 141 | config=config, 142 | save_code=True, 143 | name=default_logger.dir.split("/")[-1], 144 | dir=config.wandb_dir, 145 | id=config.wandb_id, 146 | entity=config.wandb_entity, 147 | ) 148 | logger = WandbLogger( 149 | default_logger.dir, default_logger.output_formats, log_interval=1000 150 | ) 151 | 
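    # Note: EvalCallback above is constructed with eval_every measured in
    # rollouts, not environment steps -- see
    # `eval_every=config.eval_every // config.rollout_length`. With the defaults
    # (--eval_every 200000, --rollout_length 32768) this gives 200000 // 32768 = 6,
    # i.e. one evaluation episode every 6 rollouts = 196608 environment steps.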
model.set_logger(logger)
152 |     checkpoint_callback.save_path = wandb.run.dir
153 | 
154 |     return callbacks
155 | 
--------------------------------------------------------------------------------
/utils/frame_stack.py:
--------------------------------------------------------------------------------
1 | """Wrapper that stacks frames."""
2 | from collections import deque
3 | 
4 | import numpy as np
5 | 
6 | import gymnasium as gym
7 | from gymnasium.spaces import Box
8 | 
9 | class FrameStack(gym.ObservationWrapper):
10 |     """Observation wrapper that stacks the observations in a rolling manner.
11 | 
12 |     For example, if the number of stacks is 4, then the returned observation contains
13 |     the most recent 4 observations. If an entry of the observation dict has shape [3],
14 |     stacking 4 observations turns that entry into an array of shape [4, 3].
15 | 
16 |     Note:
17 |         - Unlike the stock gymnasium wrapper, this variant operates on ``Dict``
18 |           observation spaces and stacks every entry independently; each entry
19 |           must be a ``Box``.
20 |         - The stacked observations are returned as plain numpy arrays rather
21 |           than memory-efficient ``LazyFrames``.
22 |         - After :meth:`reset` is called, the frame buffer is filled with the
23 |           initial observation, i.e. the observation returned by :meth:`reset`
24 |           consists of ``num_stack``-many identical frames.
25 | 
26 |     Example:
27 |         >>> env = FrameStack(env, 4)  # env observes a Dict with an 'image' Box(64, 64, 3)
28 |         >>> obs, info = env.reset()
29 |         >>> obs["image"].shape
30 |         (4, 64, 64, 3)
31 |     """
32 | 
33 |     def __init__(
34 |         self,
35 |         env: gym.Env,
36 |         num_stack: int
37 |     ):
38 |         """Observation wrapper that stacks the observations in a rolling manner.
39 | 
40 |         Args:
41 |             env (Env): The environment to apply the wrapper
42 |             num_stack (int): The number of frames to stack
43 |         """
44 |         super().__init__(env)
45 |         self.num_stack = num_stack
46 | 
47 |         self.frames = dict(zip(self.observation_space.spaces.keys(), [deque([], maxlen=num_stack) for _ in range(len(self.observation_space.spaces))]))
48 | 
49 |         self.observation_space = gym.spaces.Dict()
50 |         for key in self.env.observation_space.spaces.keys():
51 |             low = np.repeat(self.env.observation_space[key].low[np.newaxis, ...], num_stack, axis=0)
52 |             high = np.repeat(
53 |                 self.env.observation_space[key].high[np.newaxis, ...], num_stack, axis=0
54 |             )
55 |             self.observation_space[key] = Box(
56 |                 low=low, high=high, dtype=self.env.observation_space[key].dtype
57 |             )
58 | 
59 |     def observation(self, observation):
60 |         """Stacks the frames currently held in the wrapper's buffers.
61 | 
62 |         Args:
63 |             observation: Ignored
64 | 
65 |         Returns:
66 |             A dict mapping every observation key to its ``num_stack`` most
67 |             recent frames, stacked along a new leading axis
68 |         """
69 |         for key in self.frames.keys():
70 |             assert len(self.frames[key]) == self.num_stack, (len(self.frames[key]), self.num_stack)
71 |         new_frames = dict(zip(self.observation_space.spaces.keys(), [np.stack(self.frames[key], axis=0) for key in self.frames.keys()]))
72 |         return new_frames
73 | 
74 |     def step(self, action):
75 |         """Steps through the environment, appending the observation to the frame buffer.
76 | 
77 |         Args:
78 |             action: The action to step through the environment with
79 | 
80 |         Returns:
81 |             Stacked observations, reward, terminated, truncated, and information from the environment
82 |         """
83 |         observation, reward, terminated, truncated, info = self.env.step(action)
84 |         for key in self.frames.keys():
85 |             self.frames[key].append(observation[key])
86 |         return self.observation(None), reward, terminated, truncated, info
87 | 
88 |     def reset(self, **kwargs):
89 |         """Reset the environment with kwargs.
90 | 
91 |         Args:
92 |             **kwargs: The kwargs for the environment reset
93 | 
94 |         Returns:
95 |             The stacked observations and an empty info dict
96 |         """
97 |         obs, _ = self.env.reset(**kwargs)
98 | 
99 |         for key in self.frames.keys():
100 |             [self.frames[key].append(obs[key]) for _ in range(self.num_stack)]
101 | 
102 |         return self.observation(None), {}
103 | 
104 |     def render(self, highres=False):
105 |         if highres:
106 |             img = self.env.env.env.env.observation(None)['image']
107 |             return img.astype(np.float32)/255
108 |         else:
109 |             return self.env.render()
110 | 
--------------------------------------------------------------------------------
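A quick usage sketch of the `FrameStack` wrapper above; the toy environment is purely illustrative, standing in for any `gymnasium` env whose observation space is a `Dict` of `Box` entries:
```
import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Dict

class ToyDictEnv(gym.Env):
    """Illustrative stand-in for a tactile_envs environment."""
    observation_space = Dict({"image": Box(0, 1, (64, 64, 3), np.float32)})
    action_space = Box(-1, 1, (2,), np.float32)

    def reset(self, **kwargs):
        return {"image": self.observation_space["image"].sample()}, {}

    def step(self, action):
        obs = {"image": self.observation_space["image"].sample()}
        return obs, 0.0, False, False, {}

env = FrameStack(ToyDictEnv(), num_stack=4)
obs, info = env.reset()
print(obs["image"].shape)                     # (4, 64, 64, 3): reset frame repeated 4x
obs, *_ = env.step(env.action_space.sample())
print(env.observation_space["image"].shape)   # (4, 64, 64, 3), as advertised
```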
/utils/pretrain_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from tqdm import tqdm
4 | import torch
5 | from stable_baselines3.common.logger import Video
6 | 
7 | def vt_load(
8 |     x, image_normalization=[0, 1], tactile_normalization=[-1, 1], squeeze=False, frame_stack=1
9 | ):
10 |     ### Load and normalize to [0,1] ###
11 | 
12 |     if isinstance(x, str):
13 |         path = x
14 |         x = np.load(path, allow_pickle=True).item()
15 | 
16 |     if "image" in x:
17 |         if len(x["image"].shape) == 3:
18 |             x["image"] = x["image"][None, :, :, :]  # Add batch dimension
19 | 
20 |     if "tactile" in x:
21 |         if len(x["tactile"].shape) == 3:
22 |             x["tactile"] = x["tactile"][None, :, :, :]  # Add batch dimension
23 | 
24 |     # Preprocess the image
25 |     if "image" in x:
26 |         assert x["image"].shape[-1] == 3*frame_stack
27 |         x["image"] = torch.Tensor(x["image"]).permute(0, 3, 1, 2)
28 |         x["image"] = (x["image"] - image_normalization[0]) / (
29 |             image_normalization[1] - image_normalization[0]
30 |         )
31 | 
32 |     # Preprocess the tactile
33 |     if "tactile" in x:
34 |         assert x["tactile"].shape[1] == 3*frame_stack or x["tactile"].shape[1] == 6*frame_stack or x["tactile"].shape[1] == 12*frame_stack
35 | 
36 |         idx = []
37 |         n_tactiles = x["tactile"].shape[1] // frame_stack
38 |         for i in range(frame_stack):
39 |             idx.append(i*n_tactiles+0)
40 |             idx.append(i*n_tactiles+1)
41 |             idx.append(i*n_tactiles+2)
42 |         idx = np.array(idx).flatten()
43 | 
44 |         n_sensors = n_tactiles//3
45 |         for tactile_idx in range(n_sensors):
46 |             x["tactile"+str(tactile_idx+1)] = torch.Tensor(x["tactile"][:, idx+3*tactile_idx])
47 |             x["tactile"+str(tactile_idx+1)] = (x["tactile"+str(tactile_idx+1)] - tactile_normalization[0]) / (
48 |                 tactile_normalization[1] - tactile_normalization[0]
49 |             )
50 | 
51 |         del x["tactile"]
52 | 
53 |     if squeeze:
54 |         for key in x:
55 |             x[key] = x[key].squeeze()
56 | 
57 |     return x
58 | 
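To make the packing convention above concrete, here is a small shape sketch for a two-sensor, frame-stacked batch (the shapes follow the asserts in `vt_load`; the random values are placeholders):
```
import numpy as np

frame_stack, B = 2, 4
batch = {
    # image: (B, H, W, 3*frame_stack) -- stacked frames concatenated channel-last
    "image": np.random.rand(B, 64, 64, 3 * frame_stack).astype(np.float32),
    # tactile: (B, n_sensors*3*frame_stack, h, w) -- here n_sensors = 2
    "tactile": np.random.rand(B, 6 * frame_stack, 32, 32).astype(np.float32) * 2 - 1,
}

out = vt_load(batch, frame_stack=frame_stack)
print(out["image"].shape)     # torch.Size([4, 6, 64, 64]), channels-first
print(out["tactile1"].shape)  # torch.Size([4, 6, 32, 32]), first sensor, all frames
print(out["tactile2"].shape)  # torch.Size([4, 6, 32, 32]), second sensor
print("tactile" in out)       # False: the combined key is split per sensor
```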
59 | 
60 | def train(model, train_loader, optimizer, epoch, writer, normalize_image=False):
61 |     model.train()
62 | 
63 |     t = tqdm(train_loader, desc="Iteration", ncols=80)
64 | 
65 |     for iter, data in enumerate(t):
66 |         x = data[0]
67 | 
68 |         if normalize_image:
69 |             x = x.float() / 255
70 |         if torch.cuda.is_available():
71 |             if isinstance(x, dict):
72 |                 for key in x:
73 |                     x[key] = x[key].cuda()
74 |             else:
75 |                 x = x.cuda()
76 |         optimizer.zero_grad()
77 |         r_loss = model(x)
78 |         r_loss.backward()
79 |         optimizer.step()
80 | 
81 |         writer.add_scalar("Loss/train", r_loss, epoch * len(train_loader) + iter)
82 | 
83 |         t.set_description(
84 |             "Epoch {}, rloss: {}, lr: {}, Progress: ".format(
85 |                 epoch, r_loss.item(), optimizer.param_groups[0]["lr"]
86 |             )
87 |         )
88 | 
89 | 
90 | def eval_loss(model, loader, normalize_image=False):
91 |     model.eval()
92 | 
93 |     r_loss = 0
94 |     with torch.no_grad():
95 |         for data in loader:
96 |             x = data[0]
97 |             if normalize_image:
98 |                 x = x.float() / 255
99 |             if torch.cuda.is_available():
100 |                 if isinstance(x, dict):
101 |                     for key in x:
102 |                         x[key] = x[key].cuda()
103 |                 else:
104 |                     x = x.cuda()
105 |             r_loss += model(x).item() * len(x)
106 | 
107 |     return r_loss / len(loader.dataset)
108 | 
109 | def annotate_frame(step, frame, rew, info={}):
110 | 
111 |     """Renders a video frame and adds caption."""
112 |     if np.max(frame) <= 1.0:
113 |         frame *= 255.0
114 |     frame = frame.astype(np.uint8)
115 | 
116 |     # Set the minimum size of frames to (`S`, `S`) for caption readability.
117 |     # S = 512
118 |     S = 128
119 |     if frame.shape[0] < S:
120 |         frame = cv2.resize(frame, (int(S * frame.shape[1] / frame.shape[0]), S))
121 |     h, w = frame.shape[:2]
122 | 
123 |     # Add caption.
124 |     frame = np.concatenate([frame, np.zeros((64, w, 3), np.uint8)], 0)
125 |     scale = h / S
126 |     font_size = 0.4 * scale
127 |     font_face = cv2.FONT_HERSHEY_SIMPLEX
128 |     x, y = int(5 * scale), h + int(10 * scale)
129 |     add_text = lambda x, y, c, t: cv2.putText(
130 |         frame, t, (x, y), font_face, font_size, c, 1, cv2.LINE_AA
131 |     )
132 | 
133 |     add_text(x, y, (255, 255, 0), f"{step:5} {rew:.3f}")
134 |     for i, k in enumerate(info.keys()):
135 |         key_text = f"{k}: "
136 |         key_width = cv2.getTextSize(key_text, font_face, font_size, 1)[0][0]
137 |         offset = int(12 * scale) * (i + 2)
138 |         add_text(x, y + offset, (66, 133, 244), key_text)
139 |         value_text = str(info[k])
140 |         if isinstance(info[k], np.ndarray):
141 |             value_text = np.array2string(
142 |                 info[k], precision=2, separator=", ", floatmode="fixed"
143 |             )
144 |         add_text(x + key_width, y + offset, (255, 255, 255), value_text)
145 | 
146 |     return frame
147 | 
148 | def log_videos(
149 |     obses,
150 |     rewards_per_step,
151 |     logger,
152 |     num_timesteps,
153 |     frame_stack=1
154 | ):
155 | 
156 |     image_video = []
157 |     reward_video = []
158 | 
159 |     episode_return = 0.0
160 |     for (x, reward) in zip(obses, rewards_per_step):  # For each timestep
161 | 
162 |         if 'image' in x:
163 |             x['image'] = x['image'].transpose((0, 2, 3, 1, 4))
164 |             x['image'] = x['image'].reshape((x['image'].shape[0], x['image'].shape[1], x['image'].shape[2], -1))
165 |         if 'tactile' in x:
166 |             x['tactile'] = x['tactile'].reshape((x['tactile'].shape[0], -1, x['tactile'].shape[3], x['tactile'].shape[4]))
167 | 
168 |         x = vt_load(x, frame_stack=frame_stack)
169 | 
170 |         if torch.cuda.is_available():
171 |             for key in ['image', 'tactile1', 'tactile2']:
172 |                 if key in x:
173 |                     x[key] = x[key].cuda()
174 | 
175 |         # if frame_stack > 1:  # image is (1, frame_stack*channels, height, width)
176 |         image = x["image"].reshape((frame_stack, -1, x["image"].shape[2], x["image"].shape[3]))
177 |         image = image[-1].detach().cpu()
178 | 
179 |         image_video.append(image)
180 | 
181 |         episode_return += reward
182 |         reward_video.append(episode_return)
183 | 
184 |     image_video = [
185 |         annotate_frame(i, frame.numpy().transpose(1, 2, 0), reward_video[i])
186 |         for i, frame in enumerate(image_video)
187 |     ]
188 | 
189 |     logger.record(
190 |         "eval/image_video",
191 | 
Video(np.stack(image_video).transpose(0, 3, 1, 2)[None], fps=int(40)), 192 | exclude=("stdout", "log", "json", "csv"), 193 | ) 194 | 195 | logger.dump(step=num_timesteps) 196 | 197 | return True 198 | 199 | -------------------------------------------------------------------------------- /utils/resize_dict.py: -------------------------------------------------------------------------------- 1 | """Wrapper for resizing observations.""" 2 | from __future__ import annotations 3 | 4 | import numpy as np 5 | 6 | import gymnasium as gym 7 | from gymnasium.error import DependencyNotInstalled 8 | from gymnasium.spaces import Box, Dict 9 | 10 | 11 | class ResizeDict(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): 12 | """Resize the image observation. 13 | 14 | This wrapper works on environments with image observations. More generally, 15 | the input can either be two-dimensional (AxB, e.g. grayscale images) or 16 | three-dimensional (AxBxC, e.g. color images). This resizes the observation 17 | to the shape given by the 2-tuple :attr:`shape`. 18 | The argument :attr:`shape` may also be an integer, in which case, the 19 | observation is scaled to a square of side-length :attr:`shape`. 20 | 21 | Example: 22 | >>> import gymnasium as gym 23 | >>> from gymnasium.wrappers import ResizeObservation 24 | >>> env = gym.make("CarRacing-v2") 25 | >>> env.observation_space.shape 26 | (96, 96, 3) 27 | >>> env = ResizeObservation(env, 64) 28 | >>> env.observation_space.shape 29 | (64, 64, 3) 30 | """ 31 | 32 | def __init__(self, env: gym.Env, shape: tuple[int, int] | int, pixel_key='pixels') -> None: 33 | """Resizes image observations to shape given by :attr:`shape`. 34 | 35 | Args: 36 | env: The environment to apply the wrapper 37 | shape: The shape of the resized observations 38 | """ 39 | gym.utils.RecordConstructorArgs.__init__(self, shape=shape) 40 | gym.ObservationWrapper.__init__(self, env) 41 | 42 | self.pixel_key = pixel_key 43 | 44 | if isinstance(shape, int): 45 | shape = (shape, shape) 46 | assert len(shape) == 2 and all( 47 | x > 0 for x in shape 48 | ), f"Expected shape to be a 2-tuple of positive integers, got: {shape}" 49 | 50 | self.shape = tuple(shape) 51 | 52 | assert isinstance( 53 | env.observation_space[self.pixel_key], Box 54 | ), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}" 55 | dims = len(env.observation_space[self.pixel_key].shape) 56 | assert ( 57 | dims == 2 or dims == 3 58 | ), f"Expected the observation space to have 2 or 3 dimensions, got: {dims}" 59 | 60 | obs_shape = self.shape + env.observation_space[self.pixel_key].shape[2:] 61 | self.observation_space = Dict({self.pixel_key: Box(low=0, high=1, shape=obs_shape, dtype=np.uint8)}) 62 | 63 | def observation(self, observation): 64 | """Updates the observations by resizing the observation to shape given by :attr:`shape`. 
65 | 66 | Args: 67 | observation: The observation to reshape 68 | 69 | Returns: 70 | The reshaped observations 71 | 72 | Raises: 73 | DependencyNotInstalled: opencv-python is not installed 74 | """ 75 | try: 76 | import cv2 77 | except ImportError as e: 78 | raise DependencyNotInstalled( 79 | "opencv (cv2) is not installed, run `pip install gymnasium[other]`" 80 | ) from e 81 | 82 | observation[self.pixel_key] = cv2.resize( 83 | observation[self.pixel_key], self.shape[::-1], interpolation=cv2.INTER_AREA 84 | ) 85 | observation[self.pixel_key] = observation[self.pixel_key].reshape(self.observation_space[self.pixel_key].shape[-3:])/255 86 | return observation 87 | -------------------------------------------------------------------------------- /utils/wandb_logger.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import wandb 5 | import numpy as np 6 | from stable_baselines3.common.logger import Logger, KVWriter 7 | from stable_baselines3.common.logger import Image, Video 8 | 9 | 10 | class WandbLogger(Logger): 11 | def __init__(self, folder: Optional[str], output_formats: List[KVWriter], log_interval: int = 1000): 12 | super().__init__(folder, output_formats) 13 | self.log_interval = log_interval 14 | self.i_log = 0 15 | 16 | def dump(self, step: int = 0) -> None: 17 | if step - self.i_log >= self.log_interval: 18 | convert_name = {k: k for k in self.name_to_value.keys()} 19 | convert_name.update({ # SB3 <-> wandb 20 | 'rollout/ep_len_mean': 'rollout/ep_len_mean', 21 | 'rollout/ep_rew_mean': 'rollout/ep_rew_mean', 22 | 'time/fps': 'time/fps', 23 | 'train/approx_kl': 'train/approx_kl', 24 | 'train/clip_fraction': 'train/clip_fraction', 25 | 'train/clip_range': 'train/clip_range', 26 | 'train/entropy_loss': 'train/entropy_loss', 27 | 'train/explained_variance': 'train/explained_variance', 28 | 'train/learning_rate': 'train/learning_rate', 29 | 'train/loss': 'train/loss', 30 | 'train/policy_gradient_loss': 'train/policy_gradient_loss', 31 | 'train/std': 'train/std', 32 | 'train/value_loss': 'train/value_loss' 33 | }) 34 | 35 | log_result = {convert_name[k]: v for k, v in self.name_to_value.items()} 36 | 37 | for k, v in self.name_to_value.items(): 38 | if isinstance(v, Video): 39 | if isinstance(v.frames, torch.Tensor): 40 | v.frames = v.frames.numpy() 41 | if isinstance(v.frames, np.ndarray) and v.frames.dtype != np.uint8: 42 | v.frames = (255 * np.clip(v.frames, 0, 1)).astype(np.uint8) 43 | log_result[k] = wandb.Video(v.frames, fps=v.fps, format="gif") 44 | 45 | wandb.log(log_result, step) 46 | self.i_log = step 47 | 48 | super().dump(step) 49 | --------------------------------------------------------------------------------
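Finally, a minimal sketch of how `WandbLogger` throttles wandb uploads; `mode="offline"` is used here only so the snippet runs without credentials, and the metric values are illustrative:
```
import wandb
from stable_baselines3.common.logger import make_output_format

wandb.init(project="MultimodalLearning", mode="offline")

logger = WandbLogger("./logs/", [make_output_format("stdout", "./logs/")], log_interval=1000)

for step in range(0, 5000, 500):
    logger.record("rollout/ep_rew_mean", step / 100.0)
    # Forwarded to wandb only when at least log_interval steps have elapsed
    # since the last forwarded dump; every dump still reaches stdout.
    logger.dump(step)
```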