├── .gitignore ├── LICENSE ├── README.md ├── config ├── resnet_delta_force_mod_single_view_force_Linear_encode.yaml ├── resnet_delta_force_mod_single_view_force_MLP_encode.yaml ├── resnet_delta_force_mod_single_view_no_encode.yaml ├── resnet_delta_no_force_single_view.yaml └── resnet_delta_with_force_single_view_force_Linear_crossattn_hybrid_crop.yaml ├── data_util.py ├── env_util.py ├── imgs └── overview_system.png ├── inference.py ├── inference_real.py ├── kuka_execute.py ├── module_attr_mixin.py ├── network.py ├── real_robot_network copy.py ├── real_robot_network.py ├── requirements.txt ├── robot_data_collection_joint_state.py ├── rotation_transformer.py ├── rotation_utils.py ├── test_rotation.py ├── train.py ├── train_real.py ├── train_utils.py └── transformer_obs_encoder.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.zip 3 | *.mp4 4 | *checkpoints* 5 | *.json 6 | *stats* 7 | *.pth 8 | *.ckpt 9 | *outputs* 10 | *r3m* 11 | *deprecated* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jeon Ho Kang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Compliant Object Manipulation for High Precision Prying Task Using Diffusion Policy with Force Modality 2 | 3 | [[project page](https://rros-lab.github.io/diffusion-with-force.github.io/)] [[data](https://drive.google.com/drive/folders/1Mgbf2isA3XL6OeCrQGP3ahebH5lbbQgB?usp=drive_link)] [[Paper](https://arxiv.org/abs/2503.03998)] 4 | 5 | Jeon Ho Kang, Sagar Joshi, Ruopeng Huang, and Satyandra K. 
Gupta 6 | 7 | University of Southern California 8 | 9 | ![System Architecture](imgs/overview_system.png) 10 | 11 | The baseline code for the diffusion policy was derived from [Diffusion Policy](https://github.com/real-stanford/diffusion_policy) 12 | 13 | All files tagged **real** are for the real-robot implementation 14 | 15 | However, [data_util.py](data_util.py) is shared between the real-robot and test pipelines 16 | 17 | 18 | ## Dependencies 19 | 20 | Create a Conda environment (recommended) and run: 21 | 22 | 23 | ```bash 24 | $ pip install -r requirements.txt 25 | ``` 26 | 27 | ## Real Robot 28 | 29 | For all demonstrations, we used the [KUKA IIWA 14 robot](https://www.kuka.com/en-de/products/robot-systems/industrial-robots/lbr-iiwa) 30 | 31 | 32 | ## Real Robot Data for Prying Task (Zarr File) 33 | Data collected on the KUKA IIWA 14 robot, containing robot states, images, forces, and actions, will be published [here](https://drive.google.com/drive/folders/1Mgbf2isA3XL6OeCrQGP3ahebH5lbbQgB?usp=drive_link) 34 | 35 | 36 | To collect your own data, 37 | 38 | after obtaining joint states from hand-guiding or any other method, 39 | 40 | run: 41 | 42 | ```bash 43 | 44 | $ python robot_data_collection_joint_state.py 45 | 46 | ``` 47 | 48 | 49 | ## Training Your Own Policy 50 | 51 | 52 | After loading your own Zarr file (or ours) in [real_robot_network.py](real_robot_network.py), run: 53 | 54 | ```bash 55 | $ python train_real.py 56 | ``` 57 | 58 | You can select or create your own [config](config) file to set the training configuration 59 | 60 | 61 | ## Inference 62 | 63 | ```bash 64 | $ python inference_real.py 65 | ``` 66 | 67 | ## Acknowledgement 68 | 69 | + The diffusion policy implementation was adapted from [Diffusion Policy](https://github.com/real-stanford/diffusion_policy) 70 | 71 | 72 | 73 | ## 📚 Citation 74 | 75 | If you find our work useful or interesting, please consider citing: 76 | 77 | ```bibtex 78 | @article{kang2025robotic, 79 | author = {Jeon Ho Kang and Sagar Joshi and Ruopeng Huang and Satyandra K. Gupta}, 80 | title = {Robotic Compliant Object Prying Using Diffusion Policy Guided by Vision and Force Observations}, 81 | journal = {IEEE Robotics and Automation Letters}, 82 | year = {2025}, 83 | note = {Accepted for publication.
Available on arXiv: arXiv:2503.03998}, 84 | } 85 | -------------------------------------------------------------------------------- /config/resnet_delta_force_mod_single_view_force_Linear_encode.yaml: -------------------------------------------------------------------------------- 1 | # resnet+delta_force_mod_single_veiw_force_encode.yaml 2 | defaults: 3 | - _self_ 4 | 5 | name: resnet_delta_no_force_single_view_force_no_encode 6 | 7 | model_config: 8 | continue_training: False 9 | start_epoch: 0 10 | end_epoch: 3000 11 | encoder: "resnet" 12 | action_def: "delta" 13 | force_mod: True 14 | single_view: True 15 | force_encode: True 16 | force_encoder: "Linear" 17 | cross_attn: False 18 | hybrid: False 19 | duplicate_view: False 20 | crop: 1000 # 128 -------------------------------------------------------------------------------- /config/resnet_delta_force_mod_single_view_force_MLP_encode.yaml: -------------------------------------------------------------------------------- 1 | # resnet+delta_force_mod_single_veiw_force_encode.yaml 2 | defaults: 3 | - _self_ 4 | 5 | name: resnet_delta_no_force_single_view_force_MLP_encode 6 | 7 | model_config: 8 | continue_training: False 9 | start_epoch: 0 10 | end_epoch: 3000 11 | encoder: "resnet" 12 | action_def: "delta" 13 | force_mod: True 14 | single_view: True 15 | force_encode: True 16 | force_encoder: "MLP" 17 | cross_attn: False 18 | hybrid: False 19 | duplicate_view: False 20 | crop: 1000 # 128 -------------------------------------------------------------------------------- /config/resnet_delta_force_mod_single_view_no_encode.yaml: -------------------------------------------------------------------------------- 1 | # resnet+delta_force_mod_single_veiw_force_encode.yaml 2 | defaults: 3 | - _self_ 4 | 5 | name: resnet_delta_no_force_single_view_force_no_encode 6 | 7 | model_config: 8 | continue_training: False 9 | start_epoch: 0 10 | end_epoch: 3000 11 | encoder: "resnet" 12 | action_def: "delta" 13 | force_mod: True 14 | single_view: True 15 | force_encode: False 16 | force_encoder: "Linear" 17 | cross_attn: False 18 | hybrid: False 19 | duplicate_view: False 20 | crop: 1000 # 128 -------------------------------------------------------------------------------- /config/resnet_delta_no_force_single_view.yaml: -------------------------------------------------------------------------------- 1 | # resnet+delta_force_mod_single_veiw_force_encode.yaml 2 | defaults: 3 | - _self_ 4 | 5 | name: resnet_delta_no_force_single_view_baseline 6 | 7 | model_config: 8 | continue_training: False 9 | start_epoch: 0 10 | end_epoch: 3000 11 | encoder: "resnet" 12 | action_def: "delta" 13 | force_mod: False 14 | single_view: True 15 | force_encode: False 16 | force_encoder: "Linear" 17 | cross_attn: False 18 | hybrid: False 19 | duplicate_view: False 20 | crop: 1000 # 128 -------------------------------------------------------------------------------- /config/resnet_delta_with_force_single_view_force_Linear_crossattn_hybrid_crop.yaml: -------------------------------------------------------------------------------- 1 | # resnet+delta_force_mod_single_veiw_force_encode.yaml 2 | defaults: 3 | - _self_ 4 | 5 | name: resnet_delta_with_force_single_view_force_Linear_crossattn_hybrid_crop98 6 | 7 | model_config: 8 | continue_training: False 9 | start_epoch: 0 10 | end_epoch: 3000 11 | encoder: "resnet" 12 | action_def: "delta" 13 | force_mod: True 14 | single_view: True 15 | force_encode: False 16 | force_encoder: "Linear" 17 | cross_attn: True 18 | hybrid: True 19 | 
duplicate_view: True 20 | crop: 98 # 128 -------------------------------------------------------------------------------- /data_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Tuple, Sequence, Dict, Union, Optional, Callable 3 | import numpy as np 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torchvision 8 | import collections 9 | import zarr 10 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 11 | from diffusers.training_utils import EMAModel 12 | from diffusers.optimization import get_scheduler 13 | from tqdm.auto import tqdm 14 | import os 15 | from torchvision import transforms 16 | from copy import deepcopy 17 | 18 | #@markdown ### **Dataset** 19 | #@markdown 20 | #@markdown Defines `PushTImageDataset` and helper functions 21 | #@markdown 22 | #@markdown The dataset class 23 | #@markdown - Load data ((image, agent_pos), action) from a zarr storage 24 | #@markdown - Normalizes each dimension of agent_pos and action to [-1,1] 25 | #@markdown - Returns 26 | #@markdown - All possible segments with length `pred_horizon` 27 | #@markdown - Pads the beginning and the end of each episode with repetition 28 | #@markdown - key `image`: shape (obs_hoirzon, 3, 96, 96) 29 | #@markdown - key `agent_pos`: shape (obs_hoirzon, 2) 30 | #@markdown - key `action`: shape (pred_horizon, 2) 31 | 32 | class data_utils: 33 | def __init__(self): 34 | pass 35 | def create_sample_indices( 36 | episode_ends:np.ndarray, sequence_length:int, 37 | pad_before: int=0, pad_after: int=0): 38 | indices = list() 39 | for i in range(len(episode_ends)): 40 | start_idx = 0 41 | if i > 0: 42 | start_idx = episode_ends[i-1] 43 | end_idx = episode_ends[i] 44 | episode_length = end_idx - start_idx 45 | 46 | min_start = -pad_before 47 | max_start = episode_length - sequence_length + pad_after 48 | 49 | # range stops one idx before end 50 | for idx in range(min_start, max_start+1): 51 | buffer_start_idx = max(idx, 0) + start_idx 52 | buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx 53 | start_offset = buffer_start_idx - (idx+start_idx) 54 | end_offset = (idx+sequence_length+start_idx) - buffer_end_idx 55 | sample_start_idx = 0 + start_offset 56 | sample_end_idx = sequence_length - end_offset 57 | indices.append([ 58 | buffer_start_idx, buffer_end_idx, 59 | sample_start_idx, sample_end_idx]) 60 | indices = np.array(indices) 61 | return indices 62 | 63 | 64 | def sample_sequence(train_data, sequence_length, 65 | buffer_start_idx, buffer_end_idx, 66 | sample_start_idx, sample_end_idx): 67 | result = dict() 68 | for key, input_arr in train_data.items(): 69 | sample = input_arr[buffer_start_idx:buffer_end_idx] 70 | data = sample 71 | if (sample_start_idx > 0) or (sample_end_idx < sequence_length): 72 | data = np.zeros( 73 | shape=(sequence_length,) + input_arr.shape[1:], 74 | dtype=input_arr.dtype) 75 | if sample_start_idx > 0: 76 | data[:sample_start_idx] = sample[0] 77 | if sample_end_idx < sequence_length: 78 | data[sample_end_idx:] = sample[-1] 79 | data[sample_start_idx:sample_end_idx] = sample 80 | result[key] = data 81 | return result 82 | 83 | # normalize data 84 | def get_data_stats(data): 85 | data = data.reshape(-1,data.shape[-1]) 86 | stats = { 87 | 'min': np.min(data, axis=0), 88 | 'max': np.max(data, axis=0) 89 | } 90 | return stats 91 | 92 | def normalize_data(data, stats): 93 | # nomalize to [0,1] 94 | ndata = (data - stats['min']) / (stats['max'] - stats['min']) 95 | # 
normalize to [-1, 1] 96 | ndata = ndata * 2 - 1 97 | return ndata 98 | 99 | def unnormalize_data(ndata, stats): 100 | ndata = (ndata + 1) / 2 101 | data = ndata * (stats['max'] - stats['min']) + stats['min'] 102 | return data 103 | 104 | def process_quaternion(quaternion_array): 105 | negative_real_indices = quaternion_array[:, 3] < 0 106 | # Negate the entire quaternion for rows where the real component is negative 107 | quaternion_array[negative_real_indices] *= -1 108 | 109 | return quaternion_array 110 | 111 | def normalize_force_vector(force_vector): 112 | 113 | # Calculate the magnitude of each force vector along axis 1 114 | magnitudes = np.linalg.norm(force_vector, axis=1, keepdims=True) 115 | 116 | # Avoid division by zero by using a small epsilon value 117 | magnitudes[magnitudes == 0] = 1e-8 118 | 119 | # Normalize each force vector by dividing by its magnitude 120 | normalized_forces = force_vector / magnitudes 121 | return magnitudes, normalized_forces 122 | 123 | def normalize_force_magnitude(data, stats): 124 | # nomalize to [0,1] 125 | ndata = (data - stats['min']) / (stats['max'] - stats['min']) 126 | return ndata 127 | 128 | def center_crop(images, crop_height, crop_width): 129 | # Get original dimensions 130 | N, C, H, W = images.shape 131 | assert crop_height <= H and crop_width <= W, "Crop size should be smaller than the original size" 132 | 133 | # Calculate the center + 20 only when using 98 and 128 is -20 for start_x only 134 | start_y = (H - crop_height + 20) // 2 135 | start_x = (W - crop_width - 20) // 2 136 | # start_y = (H - crop_height) // 2 137 | # start_x = (W - crop_width - 20) // 2 138 | # Perform cropping 139 | cropped_images = images[:, :, start_y:start_y + crop_height, start_x:start_x + crop_width] 140 | 141 | return cropped_images 142 | 143 | def regular_center_crop(images, crop_height, crop_width): 144 | # Get original dimensions 145 | N, C, H, W = images.shape 146 | assert crop_height <= H and crop_width <= W, "Crop size should be smaller than the original size" 147 | 148 | # Calculate the center + 20 only when using 98 and 128 is -20 for start_x only 149 | start_y = (H - crop_height) // 2 150 | start_x = (W - crop_width) // 2 151 | # start_y = (H - crop_height) // 2 152 | # start_x = (W - crop_width - 20) // 2 153 | # Perform cropping 154 | cropped_images = images[:, :, start_y:start_y + crop_height, start_x:start_x + crop_width] 155 | 156 | return cropped_images 157 | 158 | # dataset 159 | 160 | class PushTImageDataset(torch.utils.data.Dataset): 161 | def __init__(self, 162 | dataset_path: str, 163 | pred_horizon: int, 164 | obs_horizon: int, 165 | action_horizon: int): 166 | 167 | # read from zarr dataset 168 | dataset_root = zarr.open(dataset_path, 'r') 169 | 170 | # float32, [0,1], (N,96,96,3) 171 | train_image_data = dataset_root['data']['img'][:] 172 | train_image_data = np.moveaxis(train_image_data, -1,1) 173 | # Perform center cropping to 224x224 174 | # (N,3,96,96) 175 | # (N, D) 176 | train_data = { 177 | # first two dims of state vector are agent (i.e. 
gripper) locations 178 | 'agent_pos': dataset_root['data']['state'][:,:2], 179 | 'action': dataset_root['data']['action'][:] 180 | } 181 | episode_ends = dataset_root['meta']['episode_ends'][:] 182 | 183 | # compute start and end of each state-action sequence 184 | # also handles padding 185 | indices = data_utils.create_sample_indices( 186 | episode_ends=episode_ends, 187 | sequence_length=pred_horizon, 188 | pad_before=obs_horizon-1, 189 | pad_after=action_horizon-1) 190 | 191 | # compute statistics and normalized data to [-1,1] 192 | stats = dict() 193 | normalized_train_data = dict() 194 | for key, data in train_data.items(): 195 | stats[key] = data_utils.get_data_stats(data) 196 | normalized_train_data[key] = data_utils.normalize_data(data, stats[key]) 197 | 198 | # images are already normalized 199 | normalized_train_data['image'] = train_image_data 200 | 201 | self.indices = indices 202 | self.stats = stats 203 | self.normalized_train_data = normalized_train_data 204 | self.pred_horizon = pred_horizon 205 | self.action_horizon = action_horizon 206 | self.obs_horizon = obs_horizon 207 | 208 | def __len__(self): 209 | return len(self.indices) 210 | 211 | def __getitem__(self, idx): 212 | # get the start/end indices for this datapoint 213 | buffer_start_idx, buffer_end_idx, \ 214 | sample_start_idx, sample_end_idx = self.indices[idx] 215 | 216 | # get nomralized data using these indices 217 | nsample = data_utils.sample_sequence( 218 | train_data=self.normalized_train_data, 219 | sequence_length=self.pred_horizon, 220 | buffer_start_idx=buffer_start_idx, 221 | buffer_end_idx=buffer_end_idx, 222 | sample_start_idx=sample_start_idx, 223 | sample_end_idx=sample_end_idx 224 | ) 225 | 226 | # discard unused observations 227 | nsample['image'] = nsample['image'][:self.obs_horizon,:] 228 | nsample['agent_pos'] = nsample['agent_pos'][:self.obs_horizon,:] 229 | return nsample 230 | 231 | 232 | 233 | 234 | #@markdown ### **Dataset Demo** 235 | class RealRobotDataSet(torch.utils.data.Dataset): 236 | def __init__(self, 237 | dataset_path: str, 238 | pred_horizon: int, 239 | obs_horizon: int, 240 | action_horizon: int, 241 | Transformer: bool = False, 242 | force_mod: bool = False, 243 | single_view: bool = False, 244 | augment: bool = False, 245 | duplicate_view = False, 246 | crop: int = 1000): 247 | 248 | # read from zarr dataset 249 | dataset_root = zarr.open(dataset_path, 'r') 250 | if single_view: 251 | train_image_data = dataset_root['data']['images_B'][:] 252 | train_image_data = np.moveaxis(train_image_data, -1,1) 253 | 254 | else: 255 | # float32, [0,1], (N,96,96,3) 256 | train_image_data = dataset_root['data']['images_B'][:] 257 | train_image_data = np.moveaxis(train_image_data, -1,1) 258 | train_image_data_second_view = dataset_root['data']['images_A'][:] 259 | train_image_data_second_view = np.moveaxis(train_image_data_second_view, -1,1) 260 | 261 | train_image_data = regular_center_crop(train_image_data, 224, 224) 262 | if duplicate_view: 263 | duplicate_image_view = deepcopy(train_image_data) 264 | if crop == 98: 265 | # If crop parameter 64 266 | train_image_data_copy = center_crop(duplicate_image_view, crop, crop) 267 | else: 268 | ("No Cropping") 269 | else: 270 | if crop == 98: 271 | train_image_data = center_crop(train_image_data, crop, crop) 272 | else: 273 | print("No image change") 274 | 275 | # (N,3,96,96) 276 | # (N, D) 277 | train_data = { 278 | # first seven dims of state vector are agent (i.e. 
gripper) locations 279 | # Seven because we will use quaternion 280 | 'agent_pos': dataset_root['data']['state'][:,:9], 281 | 'action': dataset_root['data']['action'] 282 | } 283 | episode_ends = dataset_root['meta']['episode_ends'][:] 284 | 285 | # compute start and end of each state-action sequence 286 | # also handles padding 287 | indices = data_utils.create_sample_indices( 288 | episode_ends=episode_ends, 289 | sequence_length=pred_horizon, 290 | pad_before=obs_horizon-1, 291 | pad_after=action_horizon-1) 292 | 293 | # compute statistics and normalized data to [-1,1] 294 | stats = dict() 295 | normalized_train_data = dict() 296 | for key, data in train_data.items(): 297 | stats[key] = data_utils.get_data_stats(data[:,:3]) 298 | normalized_position = data_utils.normalize_data(data[:,:3], stats[key]) 299 | normalized_orientation = data[:,3:9] 300 | # normalized_orientation = data_utils.process_quaternion(data[:,3:7]) 301 | normalized_train_data[key] = np.hstack((normalized_position, normalized_orientation)) 302 | ## TODO: Add code that will handle - and + sign for quaternion 303 | if force_mod: 304 | train_force_data = dataset_root['data']['state'][:,9:12] 305 | magnitudes, normalized_force_direction = data_utils.normalize_force_vector(train_force_data) 306 | stats['force_mag'] = data_utils.get_data_stats(magnitudes) 307 | normalized_force_mag = data_utils.normalize_force_magnitude(magnitudes, stats['force_mag']) 308 | normalized_force_data = np.hstack((normalized_force_mag, normalized_force_direction)) 309 | 310 | # Start adding normalized training data 311 | normalized_train_data['image'] = train_image_data 312 | if duplicate_view: 313 | normalized_train_data['duplicate_image'] = train_image_data_copy 314 | 315 | # images are already normalized 316 | if force_mod: 317 | normalized_train_data['force'] = normalized_force_data 318 | if not single_view: 319 | normalized_train_data['image2'] = train_image_data_second_view 320 | 321 | self.indices = indices 322 | self.stats = stats 323 | self.normalized_train_data = normalized_train_data 324 | self.pred_horizon = pred_horizon 325 | self.action_horizon = action_horizon 326 | self.obs_horizon = obs_horizon 327 | self.force_mod = force_mod 328 | self.single_view = single_view 329 | self.augment = augment 330 | self.crop = crop 331 | self.duplicate_view = duplicate_view 332 | if self.augment: 333 | if self.crop == 98: 334 | self.augmentation_transform = transforms.Compose([ 335 | transforms.RandomResizedCrop(size=(crop, crop), scale=(0.5, 1.5)), 336 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1), 337 | ]) 338 | else: 339 | self.augmentation_transform = transforms.Compose([ 340 | transforms.RandomResizedCrop(size=(224, 224), scale=(0.5, 1.5)), 341 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.2), 342 | transforms.Resize((224, 224)), # Ensures final output is exactly 224x224 343 | 344 | ]) 345 | 346 | def __len__(self): 347 | return len(self.indices) 348 | 349 | def __getitem__(self, idx): 350 | # get the start/end indices for this datapoint 351 | buffer_start_idx, buffer_end_idx, \ 352 | sample_start_idx, sample_end_idx = self.indices[idx] 353 | 354 | # get nomralized data using these indices 355 | nsample = data_utils.sample_sequence( 356 | train_data=self.normalized_train_data, 357 | sequence_length=self.pred_horizon, 358 | buffer_start_idx=buffer_start_idx, 359 | buffer_end_idx=buffer_end_idx, 360 | sample_start_idx=sample_start_idx, 361 | sample_end_idx=sample_end_idx 362 | ) 
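        # `nsample` now maps every key in `normalized_train_data` (image,
        # agent_pos, action, plus force / image2 / duplicate_image when those
        # options are enabled) to a window of length `pred_horizon`, padded at
        # episode boundaries by repeating the first or last frame. The code
        # below keeps only the first `obs_horizon` steps of the observation
        # keys and, when augmentation is enabled, perturbs the images and
        # force readings.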
363 | 364 | # Apply augmentation to images only 365 | if self.augment: 366 | # Convert image to PIL format for augmentation if needed 367 | if self.duplicate_view: 368 | img_tensor = torch.tensor(nsample['duplicate_image'][:self.obs_horizon, :]) 369 | img_augmented = [self.augmentation_transform(img) for img in img_tensor] 370 | nsample['duplicate_image'] = torch.stack(img_augmented) 371 | nsample['duplicate_image'] = np.array(nsample['duplicate_image']) 372 | nsample['image'] = nsample['image'][:self.obs_horizon,:] 373 | 374 | else: 375 | img_tensor = torch.tensor(nsample['image'][:self.obs_horizon, :]) 376 | img_augmented = [self.augmentation_transform(img) for img in img_tensor] 377 | nsample['image'] = torch.stack(img_augmented) 378 | nsample['image'] = np.array(nsample['image']) 379 | nsample['image'] = nsample['image'][:self.obs_horizon,:] 380 | 381 | if not self.single_view: 382 | img_tensor2 = torch.tensor(nsample['image2'][:self.obs_horizon, :]) 383 | img_augmented2 = [self.augmentation_transform(img) for img in img_tensor2] 384 | nsample['image2'] = torch.stack(img_augmented2) 385 | nsample['image2'] = np.array(nsample['image2']) 386 | else: 387 | # Convert images to tensor without augmentation 388 | nsample['image'] = nsample['image'][:self.obs_horizon, :] 389 | if not self.single_view: 390 | nsample['image2'] = nsample['image2'][:self.obs_horizon, :] 391 | 392 | nsample['agent_pos'] = nsample['agent_pos'][:self.obs_horizon,:] 393 | if self.duplicate_view: 394 | nsample['duplicate_image'] = nsample['duplicate_image'][:self.obs_horizon, :] 395 | if self.force_mod: 396 | # discard unused observations 397 | if self.augment: 398 | noise_std = 0.00005 399 | force_arr = nsample['force'][:self.obs_horizon, :] 400 | scaling_factors = np.random.uniform(0.9, 1.1) 401 | force_augmented = force_arr * scaling_factors + np.random.normal(0, noise_std, size=force_arr.shape) 402 | nsample['force'] = force_augmented.astype(np.float32) 403 | else: 404 | nsample['force'] = nsample['force'][:self.obs_horizon,:] 405 | # if not self.single_view: 406 | # # discard unused observations 407 | # nsample['image2'] = nsample['image2'][:self.obs_horizon,:] 408 | 409 | return nsample 410 | -------------------------------------------------------------------------------- /env_util.py: -------------------------------------------------------------------------------- 1 | # env import 2 | import gym 3 | from gym import spaces 4 | import pygame 5 | import pymunk 6 | import pymunk.pygame_util 7 | from pymunk.space_debug_draw_options import SpaceDebugColor 8 | from pymunk.vec2d import Vec2d 9 | import shapely.geometry as sg 10 | import cv2 11 | import skimage.transform as st 12 | from skvideo.io import vwrite 13 | from IPython.display import Video 14 | import gdown 15 | import os 16 | import numpy as np 17 | from typing import Tuple, Sequence, Dict, Union, Optional, Callable 18 | import collections 19 | 20 | 21 | #@markdown ### **Environment** 22 | #@markdown Defines a PyMunk-based Push-T environment `PushTEnv`. 23 | #@markdown And it's subclass `PushTImageEnv`. 24 | #@markdown 25 | #@markdown **Goal**: push the gray T-block into the green area. 26 | #@markdown 27 | #@markdown Adapted from [Implicit Behavior Cloning](https://implicitbc.github.io/) 28 | 29 | 30 | positive_y_is_up: bool = False 31 | """Make increasing values of y point upwards. 32 | 33 | When True:: 34 | 35 | y 36 | ^ 37 | | . (3, 3) 38 | | 39 | | . (2, 2) 40 | | 41 | +------ > x 42 | 43 | When False:: 44 | 45 | +------ > x 46 | | 47 | | . 
(2, 2) 48 | | 49 | | . (3, 3) 50 | v 51 | y 52 | 53 | """ 54 | 55 | def to_pygame(p: Tuple[float, float], surface: pygame.Surface) -> Tuple[int, int]: 56 | """Convenience method to convert pymunk coordinates to pygame surface 57 | local coordinates. 58 | 59 | Note that in case positive_y_is_up is False, this function wont actually do 60 | anything except converting the point to integers. 61 | """ 62 | if positive_y_is_up: 63 | return round(p[0]), surface.get_height() - round(p[1]) 64 | else: 65 | return round(p[0]), round(p[1]) 66 | 67 | 68 | def light_color(color: SpaceDebugColor): 69 | color = np.minimum(1.2 * np.float32([color.r, color.g, color.b, color.a]), np.float32([255])) 70 | color = SpaceDebugColor(r=color[0], g=color[1], b=color[2], a=color[3]) 71 | return color 72 | 73 | class DrawOptions(pymunk.SpaceDebugDrawOptions): 74 | def __init__(self, surface: pygame.Surface) -> None: 75 | """Draw a pymunk.Space on a pygame.Surface object. 76 | 77 | Typical usage:: 78 | 79 | >>> import pymunk 80 | >>> surface = pygame.Surface((10,10)) 81 | >>> space = pymunk.Space() 82 | >>> options = pymunk.pygame_util.DrawOptions(surface) 83 | >>> space.debug_draw(options) 84 | 85 | You can control the color of a shape by setting shape.color to the color 86 | you want it drawn in:: 87 | 88 | >>> c = pymunk.Circle(None, 10) 89 | >>> c.color = pygame.Color("pink") 90 | 91 | See pygame_util.demo.py for a full example 92 | 93 | Since pygame uses a coordiante system where y points down (in contrast 94 | to many other cases), you either have to make the physics simulation 95 | with Pymunk also behave in that way, or flip everything when you draw. 96 | 97 | The easiest is probably to just make the simulation behave the same 98 | way as Pygame does. In that way all coordinates used are in the same 99 | orientation and easy to reason about:: 100 | 101 | >>> space = pymunk.Space() 102 | >>> space.gravity = (0, -1000) 103 | >>> body = pymunk.Body() 104 | >>> body.position = (0, 0) # will be positioned in the top left corner 105 | >>> space.debug_draw(options) 106 | 107 | To flip the drawing its possible to set the module property 108 | :py:data:`positive_y_is_up` to True. 
Then the pygame drawing will flip 109 | the simulation upside down before drawing:: 110 | 111 | >>> positive_y_is_up = True 112 | >>> body = pymunk.Body() 113 | >>> body.position = (0, 0) 114 | >>> # Body will be position in bottom left corner 115 | 116 | :Parameters: 117 | surface : pygame.Surface 118 | Surface that the objects will be drawn on 119 | """ 120 | self.surface = surface 121 | super(DrawOptions, self).__init__() 122 | 123 | def draw_circle( 124 | self, 125 | pos: Vec2d, 126 | angle: float, 127 | radius: float, 128 | outline_color: SpaceDebugColor, 129 | fill_color: SpaceDebugColor, 130 | ) -> None: 131 | p = to_pygame(pos, self.surface) 132 | 133 | pygame.draw.circle(self.surface, fill_color.as_int(), p, round(radius), 0) 134 | pygame.draw.circle(self.surface, light_color(fill_color).as_int(), p, round(radius-4), 0) 135 | 136 | circle_edge = pos + Vec2d(radius, 0).rotated(angle) 137 | p2 = to_pygame(circle_edge, self.surface) 138 | line_r = 2 if radius > 20 else 1 139 | # pygame.draw.lines(self.surface, outline_color.as_int(), False, [p, p2], line_r) 140 | 141 | def draw_segment(self, a: Vec2d, b: Vec2d, color: SpaceDebugColor) -> None: 142 | p1 = to_pygame(a, self.surface) 143 | p2 = to_pygame(b, self.surface) 144 | 145 | pygame.draw.aalines(self.surface, color.as_int(), False, [p1, p2]) 146 | 147 | def draw_fat_segment( 148 | self, 149 | a: Tuple[float, float], 150 | b: Tuple[float, float], 151 | radius: float, 152 | outline_color: SpaceDebugColor, 153 | fill_color: SpaceDebugColor, 154 | ) -> None: 155 | p1 = to_pygame(a, self.surface) 156 | p2 = to_pygame(b, self.surface) 157 | 158 | r = round(max(1, radius * 2)) 159 | pygame.draw.lines(self.surface, fill_color.as_int(), False, [p1, p2], r) 160 | if r > 2: 161 | orthog = [abs(p2[1] - p1[1]), abs(p2[0] - p1[0])] 162 | if orthog[0] == 0 and orthog[1] == 0: 163 | return 164 | scale = radius / (orthog[0] * orthog[0] + orthog[1] * orthog[1]) ** 0.5 165 | orthog[0] = round(orthog[0] * scale) 166 | orthog[1] = round(orthog[1] * scale) 167 | points = [ 168 | (p1[0] - orthog[0], p1[1] - orthog[1]), 169 | (p1[0] + orthog[0], p1[1] + orthog[1]), 170 | (p2[0] + orthog[0], p2[1] + orthog[1]), 171 | (p2[0] - orthog[0], p2[1] - orthog[1]), 172 | ] 173 | pygame.draw.polygon(self.surface, fill_color.as_int(), points) 174 | pygame.draw.circle( 175 | self.surface, 176 | fill_color.as_int(), 177 | (round(p1[0]), round(p1[1])), 178 | round(radius), 179 | ) 180 | pygame.draw.circle( 181 | self.surface, 182 | fill_color.as_int(), 183 | (round(p2[0]), round(p2[1])), 184 | round(radius), 185 | ) 186 | 187 | def draw_polygon( 188 | self, 189 | verts: Sequence[Tuple[float, float]], 190 | radius: float, 191 | outline_color: SpaceDebugColor, 192 | fill_color: SpaceDebugColor, 193 | ) -> None: 194 | ps = [to_pygame(v, self.surface) for v in verts] 195 | ps += [ps[0]] 196 | 197 | radius = 2 198 | pygame.draw.polygon(self.surface, light_color(fill_color).as_int(), ps) 199 | 200 | if radius > 0: 201 | for i in range(len(verts)): 202 | a = verts[i] 203 | b = verts[(i + 1) % len(verts)] 204 | self.draw_fat_segment(a, b, radius, fill_color, fill_color) 205 | 206 | def draw_dot( 207 | self, size: float, pos: Tuple[float, float], color: SpaceDebugColor 208 | ) -> None: 209 | p = to_pygame(pos, self.surface) 210 | pygame.draw.circle(self.surface, color.as_int(), p, round(size), 0) 211 | 212 | 213 | def pymunk_to_shapely(body, shapes): 214 | geoms = list() 215 | for shape in shapes: 216 | if isinstance(shape, pymunk.shapes.Poly): 217 | verts = 
[body.local_to_world(v) for v in shape.get_vertices()] 218 | verts += [verts[0]] 219 | geoms.append(sg.Polygon(verts)) 220 | else: 221 | raise RuntimeError(f'Unsupported shape type {type(shape)}') 222 | geom = sg.MultiPolygon(geoms) 223 | return geom 224 | 225 | # env 226 | class PushTEnv(gym.Env): 227 | metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 10} 228 | reward_range = (0., 1.) 229 | 230 | def __init__(self, 231 | legacy=False, 232 | block_cog=None, damping=None, 233 | render_action=True, 234 | render_size=96, 235 | reset_to_state=None 236 | ): 237 | self._seed = None 238 | self.seed() 239 | self.window_size = ws = 512 # The size of the PyGame window 240 | self.render_size = render_size 241 | self.sim_hz = 100 242 | # Local controller params. 243 | self.k_p, self.k_v = 100, 20 # PD control.z 244 | self.control_hz = self.metadata['video.frames_per_second'] 245 | # legcay set_state for data compatiblity 246 | self.legacy = legacy 247 | 248 | # agent_pos, block_pos, block_angle 249 | self.observation_space = spaces.Box( 250 | low=np.array([0,0,0,0,0], dtype=np.float64), 251 | high=np.array([ws,ws,ws,ws,np.pi*2], dtype=np.float64), 252 | shape=(5,), 253 | dtype=np.float64 254 | ) 255 | 256 | # positional goal for agent 257 | self.action_space = spaces.Box( 258 | low=np.array([0,0], dtype=np.float64), 259 | high=np.array([ws,ws], dtype=np.float64), 260 | shape=(2,), 261 | dtype=np.float64 262 | ) 263 | 264 | self.block_cog = block_cog 265 | self.damping = damping 266 | self.render_action = render_action 267 | 268 | """ 269 | If human-rendering is used, `self.window` will be a reference 270 | to the window that we draw to. `self.clock` will be a clock that is used 271 | to ensure that the environment is rendered at the correct framerate in 272 | human-mode. They will remain `None` until human-mode is used for the 273 | first time. 274 | """ 275 | self.window = None 276 | self.clock = None 277 | self.screen = None 278 | 279 | self.space = None 280 | self.teleop = None 281 | self.render_buffer = None 282 | self.latest_action = None 283 | self.reset_to_state = reset_to_state 284 | 285 | def reset(self): 286 | seed = self._seed 287 | self._setup() 288 | if self.block_cog is not None: 289 | self.block.center_of_gravity = self.block_cog 290 | if self.damping is not None: 291 | self.space.damping = self.damping 292 | 293 | # use legacy RandomState for compatiblity 294 | state = self.reset_to_state 295 | if state is None: 296 | rs = np.random.RandomState(seed=seed) 297 | state = np.array([ 298 | rs.randint(50, 450), rs.randint(50, 450), 299 | rs.randint(100, 400), rs.randint(100, 400), 300 | rs.randn() * 2 * np.pi - np.pi 301 | ]) 302 | self._set_state(state) 303 | 304 | obs = self._get_obs() 305 | info = self._get_info() 306 | return obs, info 307 | 308 | def step(self, action): 309 | dt = 1.0 / self.sim_hz 310 | self.n_contact_points = 0 311 | n_steps = self.sim_hz // self.control_hz 312 | if action is not None: 313 | self.latest_action = action 314 | for i in range(n_steps): 315 | # Step PD control. 316 | # self.agent.velocity = self.k_p * (act - self.agent.position) # P control works too. 317 | acceleration = self.k_p * (action - self.agent.position) + self.k_v * (Vec2d(0, 0) - self.agent.velocity) 318 | self.agent.velocity += acceleration * dt 319 | 320 | # Step physics. 
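            # Each call to step() runs sim_hz // control_hz (= 100 // 10 = 10)
            # physics sub-steps of dt = 1/100 s, applying the PD law above
            # (k_p = 100, k_v = 20) before every sub-step, so the agent tracks
            # the commanded 2-D position target within one control period.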
321 | self.space.step(dt) 322 | 323 | # compute reward 324 | goal_body = self._get_goal_pose_body(self.goal_pose) 325 | goal_geom = pymunk_to_shapely(goal_body, self.block.shapes) 326 | block_geom = pymunk_to_shapely(self.block, self.block.shapes) 327 | 328 | intersection_area = goal_geom.intersection(block_geom).area 329 | goal_area = goal_geom.area 330 | coverage = intersection_area / goal_area 331 | reward = np.clip(coverage / self.success_threshold, 0, 1) 332 | done = coverage > self.success_threshold 333 | terminated = done 334 | truncated = done 335 | 336 | observation = self._get_obs() 337 | info = self._get_info() 338 | 339 | return observation, reward, terminated, truncated, info 340 | 341 | def render(self, mode): 342 | return self._render_frame(mode) 343 | 344 | def teleop_agent(self): 345 | TeleopAgent = collections.namedtuple('TeleopAgent', ['act']) 346 | def act(obs): 347 | act = None 348 | mouse_position = pymunk.pygame_util.from_pygame(Vec2d(*pygame.mouse.get_pos()), self.screen) 349 | if self.teleop or (mouse_position - self.agent.position).length < 30: 350 | self.teleop = True 351 | act = mouse_position 352 | return act 353 | return TeleopAgent(act) 354 | 355 | def _get_obs(self): 356 | obs = np.array( 357 | tuple(self.agent.position) \ 358 | + tuple(self.block.position) \ 359 | + (self.block.angle % (2 * np.pi),)) 360 | return obs 361 | 362 | def _get_goal_pose_body(self, pose): 363 | mass = 1 364 | inertia = pymunk.moment_for_box(mass, (50, 100)) 365 | body = pymunk.Body(mass, inertia) 366 | # preserving the legacy assignment order for compatibility 367 | # the order here dosn't matter somehow, maybe because CoM is aligned with body origin 368 | body.position = pose[:2].tolist() 369 | body.angle = pose[2] 370 | return body 371 | 372 | def _get_info(self): 373 | n_steps = self.sim_hz // self.control_hz 374 | n_contact_points_per_step = int(np.ceil(self.n_contact_points / n_steps)) 375 | info = { 376 | 'pos_agent': np.array(self.agent.position), 377 | 'vel_agent': np.array(self.agent.velocity), 378 | 'block_pose': np.array(list(self.block.position) + [self.block.angle]), 379 | 'goal_pose': self.goal_pose, 380 | 'n_contacts': n_contact_points_per_step} 381 | return info 382 | 383 | def _render_frame(self, mode): 384 | 385 | if self.window is None and mode == "human": 386 | pygame.init() 387 | pygame.display.init() 388 | self.window = pygame.display.set_mode((self.window_size, self.window_size)) 389 | if self.clock is None and mode == "human": 390 | self.clock = pygame.time.Clock() 391 | 392 | canvas = pygame.Surface((self.window_size, self.window_size)) 393 | canvas.fill((255, 255, 255)) 394 | self.screen = canvas 395 | 396 | draw_options = DrawOptions(canvas) 397 | 398 | # Draw goal pose. 399 | goal_body = self._get_goal_pose_body(self.goal_pose) 400 | for shape in self.block.shapes: 401 | goal_points = [pymunk.pygame_util.to_pygame(goal_body.local_to_world(v), draw_options.surface) for v in shape.get_vertices()] 402 | goal_points += [goal_points[0]] 403 | pygame.draw.polygon(canvas, self.goal_color, goal_points) 404 | 405 | # Draw agent and block. 
406 | self.space.debug_draw(draw_options) 407 | 408 | if mode == "human": 409 | # The following line copies our drawings from `canvas` to the visible window 410 | self.window.blit(canvas, canvas.get_rect()) 411 | pygame.event.pump() 412 | pygame.display.update() 413 | 414 | # the clock is aleady ticked during in step for "human" 415 | 416 | 417 | img = np.transpose( 418 | np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2) 419 | ) 420 | img = cv2.resize(img, (self.render_size, self.render_size)) 421 | if self.render_action: 422 | if self.render_action and (self.latest_action is not None): 423 | action = np.array(self.latest_action) 424 | coord = (action / 512 * 96).astype(np.int32) 425 | marker_size = int(8/96*self.render_size) 426 | thickness = int(1/96*self.render_size) 427 | cv2.drawMarker(img, coord, 428 | color=(255,0,0), markerType=cv2.MARKER_CROSS, 429 | markerSize=marker_size, thickness=thickness) 430 | return img 431 | 432 | 433 | def close(self): 434 | if self.window is not None: 435 | pygame.display.quit() 436 | pygame.quit() 437 | 438 | def seed(self, seed=None): 439 | if seed is None: 440 | seed = np.random.randint(0,25536) 441 | self._seed = seed 442 | self.np_random = np.random.default_rng(seed) 443 | 444 | def _handle_collision(self, arbiter, space, data): 445 | self.n_contact_points += len(arbiter.contact_point_set.points) 446 | 447 | def _set_state(self, state): 448 | if isinstance(state, np.ndarray): 449 | state = state.tolist() 450 | pos_agent = state[:2] 451 | pos_block = state[2:4] 452 | rot_block = state[4] 453 | self.agent.position = pos_agent 454 | # setting angle rotates with respect to center of mass 455 | # therefore will modify the geometric position 456 | # if not the same as CoM 457 | # therefore should be modified first. 458 | if self.legacy: 459 | # for compatiblity with legacy data 460 | self.block.position = pos_block 461 | self.block.angle = rot_block 462 | else: 463 | self.block.angle = rot_block 464 | self.block.position = pos_block 465 | 466 | # Run physics to take effect 467 | self.space.step(1.0 / self.sim_hz) 468 | 469 | def _set_state_local(self, state_local): 470 | agent_pos_local = state_local[:2] 471 | block_pose_local = state_local[2:] 472 | tf_img_obj = st.AffineTransform( 473 | translation=self.goal_pose[:2], 474 | rotation=self.goal_pose[2]) 475 | tf_obj_new = st.AffineTransform( 476 | translation=block_pose_local[:2], 477 | rotation=block_pose_local[2] 478 | ) 479 | tf_img_new = st.AffineTransform( 480 | matrix=tf_img_obj.params @ tf_obj_new.params 481 | ) 482 | agent_pos_new = tf_img_new(agent_pos_local) 483 | new_state = np.array( 484 | list(agent_pos_new[0]) + list(tf_img_new.translation) \ 485 | + [tf_img_new.rotation]) 486 | self._set_state(new_state) 487 | return new_state 488 | 489 | def _setup(self): 490 | self.space = pymunk.Space() 491 | self.space.gravity = 0, 0 492 | self.space.damping = 0 493 | self.teleop = False 494 | self.render_buffer = list() 495 | 496 | # Add walls. 497 | walls = [ 498 | self._add_segment((5, 506), (5, 5), 2), 499 | self._add_segment((5, 5), (506, 5), 2), 500 | self._add_segment((506, 5), (506, 506), 2), 501 | self._add_segment((5, 506), (506, 506), 2) 502 | ] 503 | self.space.add(*walls) 504 | 505 | # Add agent, block, and goal zone. 
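        # The agent is a kinematic circle, the block is the T-shape built in
        # add_tee(), and the goal pose is fixed at (256, 256, pi/4); an episode
        # succeeds when the block covers at least success_threshold (95%) of
        # the goal region (see the reward computation in step()).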
506 | self.agent = self.add_circle((256, 400), 15) 507 | self.block = self.add_tee((256, 300), 0) 508 | self.goal_color = pygame.Color('LightGreen') 509 | self.goal_pose = np.array([256,256,np.pi/4]) # x, y, theta (in radians) 510 | 511 | # Add collision handeling 512 | self.collision_handeler = self.space.add_collision_handler(0, 0) 513 | self.collision_handeler.post_solve = self._handle_collision 514 | self.n_contact_points = 0 515 | 516 | self.max_score = 50 * 100 517 | self.success_threshold = 0.95 # 95% coverage. 518 | 519 | def _add_segment(self, a, b, radius): 520 | shape = pymunk.Segment(self.space.static_body, a, b, radius) 521 | shape.color = pygame.Color('LightGray') # https://htmlcolorcodes.com/color-names 522 | return shape 523 | 524 | def add_circle(self, position, radius): 525 | body = pymunk.Body(body_type=pymunk.Body.KINEMATIC) 526 | body.position = position 527 | body.friction = 1 528 | shape = pymunk.Circle(body, radius) 529 | shape.color = pygame.Color('RoyalBlue') 530 | self.space.add(body, shape) 531 | return body 532 | 533 | def add_box(self, position, height, width): 534 | mass = 1 535 | inertia = pymunk.moment_for_box(mass, (height, width)) 536 | body = pymunk.Body(mass, inertia) 537 | body.position = position 538 | shape = pymunk.Poly.create_box(body, (height, width)) 539 | shape.color = pygame.Color('LightSlateGray') 540 | self.space.add(body, shape) 541 | return body 542 | 543 | def add_tee(self, position, angle, scale=30, color='LightSlateGray', mask=pymunk.ShapeFilter.ALL_MASKS()): 544 | mass = 1 545 | length = 4 546 | vertices1 = [(-length*scale/2, scale), 547 | ( length*scale/2, scale), 548 | ( length*scale/2, 0), 549 | (-length*scale/2, 0)] 550 | inertia1 = pymunk.moment_for_poly(mass, vertices=vertices1) 551 | vertices2 = [(-scale/2, scale), 552 | (-scale/2, length*scale), 553 | ( scale/2, length*scale), 554 | ( scale/2, scale)] 555 | inertia2 = pymunk.moment_for_poly(mass, vertices=vertices1) 556 | body = pymunk.Body(mass, inertia1 + inertia2) 557 | shape1 = pymunk.Poly(body, vertices1) 558 | shape2 = pymunk.Poly(body, vertices2) 559 | shape1.color = pygame.Color(color) 560 | shape2.color = pygame.Color(color) 561 | shape1.filter = pymunk.ShapeFilter(mask=mask) 562 | shape2.filter = pymunk.ShapeFilter(mask=mask) 563 | body.center_of_gravity = (shape1.center_of_gravity + shape2.center_of_gravity) / 2 564 | body.position = position 565 | body.angle = angle 566 | body.friction = 1 567 | self.space.add(body, shape1, shape2) 568 | return body 569 | 570 | 571 | class PushTImageEnv(PushTEnv): 572 | metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10} 573 | 574 | def __init__(self, 575 | legacy=False, 576 | block_cog=None, 577 | damping=None, 578 | render_size=96): 579 | super().__init__( 580 | legacy=legacy, 581 | block_cog=block_cog, 582 | damping=damping, 583 | render_size=render_size, 584 | render_action=False) 585 | ws = self.window_size 586 | self.observation_space = spaces.Dict({ 587 | 'image': spaces.Box( 588 | low=0, 589 | high=1, 590 | shape=(3,render_size,render_size), 591 | dtype=np.float32 592 | ), 593 | 'agent_pos': spaces.Box( 594 | low=0, 595 | high=ws, 596 | shape=(2,), 597 | dtype=np.float32 598 | ) 599 | }) 600 | self.render_cache = None 601 | 602 | def _get_obs(self): 603 | img = super()._render_frame(mode='rgb_array') 604 | 605 | agent_pos = np.array(self.agent.position) 606 | img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0) 607 | obs = { 608 | 'image': img_obs, 609 | 'agent_pos': agent_pos 610 | } 611 | 612 | # 
draw action 613 | if self.latest_action is not None: 614 | action = np.array(self.latest_action) 615 | coord = (action / 512 * 96).astype(np.int32) 616 | marker_size = int(8/96*self.render_size) 617 | thickness = int(1/96*self.render_size) 618 | cv2.drawMarker(img, coord, 619 | color=(255,0,0), markerType=cv2.MARKER_CROSS, 620 | markerSize=marker_size, thickness=thickness) 621 | self.render_cache = img 622 | 623 | return obs 624 | 625 | def render(self, mode): 626 | assert mode == 'rgb_array' 627 | 628 | if self.render_cache is None: 629 | self._get_obs() 630 | 631 | return self.render_cache 632 | 633 | #@markdown ### **Env Demo** 634 | #@markdown Standard Gym Env (0.21.0 API) 635 | 636 | # 0. create env object 637 | env = PushTImageEnv() 638 | 639 | # 1. seed env for initial state. 640 | # Seed 0-200 are used for the demonstration dataset. 641 | env.seed(1000) 642 | 643 | # 2. must reset before use 644 | obs, info = env.reset() 645 | 646 | # 3. 2D positional action space [0,512] 647 | action = env.action_space.sample() 648 | 649 | # 4. Standard gym step method 650 | obs, reward, terminated, truncated, info = env.step(action) 651 | 652 | # prints and explains each dimension of the observation and action vectors 653 | with np.printoptions(precision=4, suppress=True, threshold=5): 654 | print("obs['image'].shape:", obs['image'].shape, "float32, [0,1]") 655 | print("obs['agent_pos'].shape:", obs['agent_pos'].shape, "float32, [0,512]") 656 | print("action.shape: ", action.shape, "float32, [0,512]") 657 | 658 | -------------------------------------------------------------------------------- /imgs/overview_system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeonHoKang/Diffusion_Policy_with_Visual_Force_CrossAttn/96996263602872234644face1fca167248007a84/imgs/overview_system.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Sequence, Dict, Union, Optional, Callable 2 | import numpy as np 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | import collections 8 | import zarr 9 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 10 | from diffusers.training_utils import EMAModel 11 | from diffusers.optimization import get_scheduler 12 | from tqdm.auto import tqdm 13 | import gdown 14 | import os 15 | from skvideo.io import vwrite 16 | from env_util import PushTImageEnv 17 | from network import ConditionalUnet1D, DiffusionPolicy 18 | from data_util import data_utils 19 | import cv2 20 | from train_utils import train_utils 21 | 22 | #@markdown ### **Loading Pretrained Checkpoint** 23 | #@markdown Set `load_pretrained = True` to load pretrained weights. 24 | 25 | #@markdown ### **Network Demo** 26 | 27 | 28 | class EvaluatePushT: 29 | # construct ResNet18 encoder 30 | # if you have multiple camera views, use seperate encoder weights for each view. 
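    # Typical usage (mirrors main() at the bottom of this file):
    #   eval_pusht = EvaluatePushT(max_steps=300)
    #   imgs = eval_pusht.inference()  # rendered frames, written out as a video in main()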
31 | def __init__(self, max_steps): 32 | 33 | diffusion = DiffusionPolicy() 34 | # num_epochs = 100 35 | ema_nets = self.load_pretrained(diffusion) 36 | # ResNet18 has output dim of 512 37 | vision_feature_dim = 512 38 | # agent_pos is 2 dimensional 39 | # lowdim_obs_dim = 2 40 | # # observation feature has 514 dims in total per step 41 | # obs_dim = vision_feature_dim + lowdim_obs_dim 42 | # action_dim = 2 43 | #@markdown ### **Inference** 44 | 45 | # limit enviornment interaction to 200 steps before termination 46 | env = PushTImageEnv() 47 | # use a seed >200 to avoid initial states seen in the training dataset 48 | env.seed(100000) 49 | 50 | # get first observation 51 | obs, info = env.reset() 52 | 53 | # keep a queue of last 2 steps of observations 54 | obs_deque = collections.deque( 55 | [obs] * diffusion.obs_horizon, maxlen=diffusion.obs_horizon) 56 | # save visualization and rewards 57 | imgs = [env.render(mode='rgb_array')] 58 | 59 | rewards = list() 60 | step_idx = 0 61 | # device transfer 62 | device = torch.device('cuda') 63 | _ = diffusion.nets.to(device) 64 | 65 | 66 | self.diffusion = diffusion 67 | self.vision_feature_dim = vision_feature_dim 68 | self.env = env 69 | self.obs = obs 70 | self.info = info 71 | self.rewards = rewards 72 | self.device = device 73 | self.obs_deque = obs_deque 74 | self.imgs = imgs 75 | self.max_steps = max_steps 76 | self.ema_nets = ema_nets 77 | self.step_idx = step_idx 78 | 79 | # the final arch has 2 parts 80 | ###### Load Pretrained 81 | def load_pretrained(self, diffusion): 82 | 83 | load_pretrained = True 84 | if load_pretrained: 85 | # ckpt_path = "/home/jeon/jeon_ws/diffusion_policy/src/diffusion_cam/checkpoints/checkpoint_100.pth" 86 | ckpt_path = "pusht_vision_100ep.ckpt" 87 | if not os.path.isfile(ckpt_path): 88 | id = "1XKpfNSlwYMGaF5CncoFaLKCDTWoLAHf1&confirm=t" 89 | gdown.download(id=id, output=ckpt_path, quiet=False) 90 | 91 | state_dict = torch.load(ckpt_path, map_location='cuda') 92 | # noise_pred_net.load_state_dict(checkpoint['model_state_dict']) 93 | # start_epoch = checkpoint['epoch'] + 1 94 | # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 95 | # lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict']) 96 | # start_epoch = checkpoint['epoch'] + 1 97 | ema_nets = diffusion.nets 98 | ema_nets.load_state_dict(state_dict) 99 | print('Pretrained weights loaded.') 100 | else: 101 | print("Skipped pretrained weight loading.") 102 | return ema_nets 103 | 104 | def inference(self): 105 | diffusion = self.diffusion 106 | max_steps = self.max_steps 107 | device = self.device 108 | obs_deque = self.obs_deque 109 | imgs = self.imgs 110 | ema_nets = self.ema_nets 111 | env = self.env 112 | rewards = self.rewards 113 | step_idx = self.step_idx 114 | done = False 115 | with tqdm(total=max_steps, desc="Eval PushTImageEnv") as pbar: 116 | while not done: 117 | B = 1 118 | # stack the last obs_horizon number of observations 119 | images = np.stack([x['image'] for x in obs_deque]) 120 | agent_poses = np.stack([x['agent_pos'] for x in obs_deque]) 121 | 122 | # normalize observation 123 | nagent_poses = data_utils.normalize_data(agent_poses, stats=diffusion.stats['agent_pos']) 124 | # images are already normalized to [0,1] 125 | nimages = images 126 | 127 | # device transfer 128 | nimages = torch.from_numpy(nimages).to(device, dtype=torch.float32) 129 | # (2,3,96,96) 130 | nagent_poses = torch.from_numpy(nagent_poses).to(device, dtype=torch.float32) 131 | # (2,2) 132 | 133 | # infer action 134 | with torch.no_grad(): 
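                    # The block below implements the diffusion inference loop:
                    # encode the obs_horizon images once, concatenate the low-dim
                    # agent poses, flatten to a single conditioning vector, then
                    # start from Gaussian noise of shape (B, pred_horizon, action_dim)
                    # and run the DDPM reverse process, calling noise_pred_net at
                    # every scheduler timestep and stepping the noise scheduler to
                    # remove the predicted noise.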
135 | # get image features 136 | image_features = ema_nets['vision_encoder'](nimages) 137 | # (2,512) 138 | 139 | # concat with low-dim observations 140 | obs_features = torch.cat([image_features, nagent_poses], dim=-1) 141 | 142 | # reshape observation to (B,obs_horizon*obs_dim) 143 | obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1) 144 | 145 | # initialize action from Guassian noise 146 | noisy_action = torch.randn( 147 | (B, diffusion.pred_horizon, diffusion.action_dim), device=device) 148 | naction = noisy_action 149 | 150 | # init scheduler 151 | diffusion.noise_scheduler.set_timesteps(diffusion.num_diffusion_iters) 152 | 153 | for k in diffusion.noise_scheduler.timesteps: 154 | # predict noise 155 | noise_pred = ema_nets['noise_pred_net']( 156 | sample=naction, 157 | timestep=k, 158 | global_cond=obs_cond 159 | ) 160 | 161 | # inverse diffusion step (remove noise) 162 | naction = diffusion.noise_scheduler.step( 163 | model_output=noise_pred, 164 | timestep=k, 165 | sample=naction 166 | ).prev_sample 167 | 168 | # unnormalize action 169 | naction = naction.detach().to('cpu').numpy() 170 | # (B, pred_horizon, action_dim) 171 | naction = naction[0] 172 | action_pred = data_utils.unnormalize_data(naction, stats=diffusion.stats['action']) 173 | 174 | # only take action_horizon number of actions5 175 | start = diffusion.obs_horizon - 1 176 | end = start + diffusion.action_horizon 177 | action = action_pred[start:end,:] 178 | # (action_horizon, action_dim) 179 | 180 | # execute action_horizon number of steps 181 | # without replanning 182 | for i in range(len(action)): 183 | # stepping env 184 | obs, reward, done, _, info = env.step(action[i]) 185 | # save observations 186 | obs_deque.append(obs) 187 | # and reward/vis 188 | rewards.append(reward) 189 | imgs.append(env.render(mode='rgb_array')) 190 | 191 | # update progress bar 192 | step_idx += 1 193 | pbar.update(1) 194 | pbar.set_postfix(reward=reward) 195 | if step_idx > max_steps: 196 | done = True 197 | if done: 198 | break 199 | # print out the maximum target coverage 200 | 201 | print('Score: ', max(rewards)) 202 | return imgs 203 | 204 | 205 | def main(): 206 | max_steps = 300 207 | 208 | eval_pusht = EvaluatePushT(max_steps) 209 | imgs = eval_pusht.inference() 210 | height, width, layers = imgs[0].shape 211 | video = cv2.VideoWriter('/home/jeon/jeon_ws/diffusion_policy/src/diffusion_cam/checkpoints/vis_PUSHT.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 15, (width, height)) 212 | 213 | for img in imgs: 214 | video.write(np.uint8(img)) 215 | 216 | video.release() 217 | if __name__ == "__main__": 218 | main() -------------------------------------------------------------------------------- /kuka_execute.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import csv 3 | import math 4 | import rclpy 5 | from rclpy.node import Node 6 | from rclpy.action import ActionClient 7 | from trajectory_msgs.msg import JointTrajectory, JointTrajectoryPoint 8 | from control_msgs.action import FollowJointTrajectory 9 | import time 10 | class KukaMotionPlanning(Node): 11 | def __init__(self, current_step): 12 | super().__init__('kuka_motion_planning') 13 | self._action_client = ActionClient(self, FollowJointTrajectory, '/lbr/joint_trajectory_controller/follow_joint_trajectory') 14 | self.current_step = current_step 15 | self.joint_names = ["A1", "A2", "A3", "A4", "A5", "A6", "A7"] 16 | 17 | 18 | def send_goal(self, joint_trajectories): 19 | goal_msg = FollowJointTrajectory.Goal() 20 | 
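        # Build a single-point trajectory for joints A1-A7 that the controller
        # should reach 0.3 s after execution starts, send it to the
        # FollowJointTrajectory action server, and block (spin_until_future_complete)
        # until the action result is returned before send_goal() exits.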
trajectory_msg = JointTrajectory() 21 | trajectory_msg.joint_names = self.joint_names 22 | point = JointTrajectoryPoint() 23 | point.positions = list(joint_trajectories.position) 24 | point.time_from_start.sec = 0 # Set the seconds part to 0 25 | point.time_from_start.nanosec = int(0.3 * 1e9) # Set the nanoseconds part to 300,000,000 (0.3 s) 26 | 27 | trajectory_msg.points.append(point) 28 | goal_msg.trajectory = trajectory_msg 29 | 30 | self._action_client.wait_for_server() 31 | self._send_goal_future = self._action_client.send_goal_async(goal_msg, feedback_callback=self.feedback_callback) 32 | rclpy.spin_until_future_complete(self, self._send_goal_future) 33 | goal_handle = self._send_goal_future.result() 34 | 35 | # if not goal_handle.accepted: 36 | # self.get_logger().info(f"Goal for point {i} was rejected") 37 | # return 38 | 39 | # self.get_logger().info(f"Goal for point {i} was accepted") 40 | # Wait for the result to complete before moving to the next trajectory point 41 | get_result_future = goal_handle.get_result_async() 42 | rclpy.spin_until_future_complete(self, get_result_future) 43 | result = get_result_future.result().result 44 | 45 | def feedback_callback(self, feedback_msg): 46 | feedback = feedback_msg.feedback 47 | # self.get_logger().info(f'Feedback: {feedback}') 48 | 49 | def get_result_callback(self, future): 50 | result = future.result().result 51 | self.get_logger().info(f'Result: {result}') 52 | # rclpy.shutdown() -------------------------------------------------------------------------------- /module_attr_mixin.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ModuleAttrMixin(nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | self._dummy_variable = nn.Parameter(requires_grad=False) 7 | 8 | @property 9 | def device(self): 10 | return next(iter(self.parameters())).device 11 | 12 | @property 13 | def dtype(self): 14 | return next(iter(self.parameters())).dtype 15 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | #@markdown ### **Imports** 2 | # diffusion policy import 3 | from typing import Tuple, Sequence, Dict, Union, Optional, Callable 4 | import numpy as np 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 9 | import gdown 10 | import os 11 | from data_util import PushTImageDataset 12 | from train_utils import train_utils 13 | #@markdown ### **Network** 14 | #@markdown 15 | #@markdown Defines a 1D UNet architecture `ConditionalUnet1D` 16 | #@markdown as the noise prediction network 17 | #@markdown 18 | #@markdown Components 19 | #@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k 20 | #@markdown - `Downsample1d` Strided convolution to reduce temporal resolution 21 | #@markdown - `Upsample1d` Transposed convolution to increase temporal resolution 22 | #@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish 23 | #@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \ 24 | #@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection. 25 | #@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning.
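# In ConditionalResidualBlock1D below, FiLM conditioning works by projecting
# the conditioning vector `cond` (diffusion-step embedding concatenated with
# the flattened observation features) to 2 * out_channels values per sample,
# which are split into a per-channel scale and bias applied to the
# convolutional features:
#
#     embed = self.cond_encoder(cond).reshape(B, 2, out_channels, 1)
#     out   = embed[:, 0, ...] * conv_features + embed[:, 1, ...]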
26 | 27 | class SinusoidalPosEmb(nn.Module): 28 | def __init__(self, dim): 29 | super().__init__() 30 | self.dim = dim 31 | 32 | def forward(self, x): 33 | device = x.device 34 | half_dim = self.dim // 2 35 | emb = math.log(10000) / (half_dim - 1) 36 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 37 | emb = x[:, None] * emb[None, :] 38 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 39 | return emb 40 | 41 | 42 | class Downsample1d(nn.Module): 43 | def __init__(self, dim): 44 | super().__init__() 45 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 46 | 47 | def forward(self, x): 48 | return self.conv(x) 49 | 50 | class Upsample1d(nn.Module): 51 | def __init__(self, dim): 52 | super().__init__() 53 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 54 | 55 | def forward(self, x): 56 | return self.conv(x) 57 | 58 | 59 | class Conv1dBlock(nn.Module): 60 | ''' 61 | Conv1d --> GroupNorm --> Mish 62 | ''' 63 | 64 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 65 | super().__init__() 66 | 67 | self.block = nn.Sequential( 68 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 69 | nn.GroupNorm(n_groups, out_channels), 70 | nn.Mish(), 71 | ) 72 | 73 | def forward(self, x): 74 | return self.block(x) 75 | 76 | 77 | class ConditionalResidualBlock1D(nn.Module): 78 | def __init__(self, 79 | in_channels, 80 | out_channels, 81 | cond_dim, 82 | kernel_size=3, 83 | n_groups=8): 84 | super().__init__() 85 | 86 | self.blocks = nn.ModuleList([ 87 | Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), 88 | Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), 89 | ]) 90 | 91 | # FiLM modulation https://arxiv.org/abs/1709.07871 92 | # predicts per-channel scale and bias 93 | cond_channels = out_channels * 2 94 | self.out_channels = out_channels 95 | self.cond_encoder = nn.Sequential( 96 | nn.Mish(), 97 | nn.Linear(cond_dim, cond_channels), 98 | nn.Unflatten(-1, (-1, 1)) 99 | ) 100 | 101 | # make sure dimensions compatible 102 | self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ 103 | if in_channels != out_channels else nn.Identity() 104 | 105 | def forward(self, x, cond): 106 | ''' 107 | x : [ batch_size x in_channels x horizon ] 108 | cond : [ batch_size x cond_dim] 109 | 110 | returns: 111 | out : [ batch_size x out_channels x horizon ] 112 | ''' 113 | out = self.blocks[0](x) 114 | embed = self.cond_encoder(cond) 115 | 116 | embed = embed.reshape( 117 | embed.shape[0], 2, self.out_channels, 1) 118 | scale = embed[:,0,...] 119 | bias = embed[:,1,...] 120 | out = scale * out + bias 121 | 122 | out = self.blocks[1](out) 123 | out = out + self.residual_conv(x) 124 | return out 125 | 126 | 127 | class ConditionalUnet1D(nn.Module): 128 | def __init__(self, 129 | input_dim, 130 | global_cond_dim, 131 | diffusion_step_embed_dim=256, 132 | down_dims=[256,512,1024], 133 | kernel_size=5, 134 | n_groups=8 135 | ): 136 | """ 137 | input_dim: Dim of actions. 138 | global_cond_dim: Dim of global conditioning applied with FiLM 139 | in addition to diffusion step embedding. This is usually obs_horizon * obs_dim 140 | diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k 141 | down_dims: Channel size for each UNet level. 142 | The length of this array determines numebr of levels. 
143 | kernel_size: Conv kernel size 144 | n_groups: Number of groups for GroupNorm 145 | """ 146 | 147 | super().__init__() 148 | all_dims = [input_dim] + list(down_dims) 149 | start_dim = down_dims[0] 150 | 151 | dsed = diffusion_step_embed_dim 152 | diffusion_step_encoder = nn.Sequential( 153 | SinusoidalPosEmb(dsed), 154 | nn.Linear(dsed, dsed * 4), 155 | nn.Mish(), 156 | nn.Linear(dsed * 4, dsed), 157 | ) 158 | cond_dim = dsed + global_cond_dim 159 | 160 | in_out = list(zip(all_dims[:-1], all_dims[1:])) 161 | mid_dim = all_dims[-1] 162 | self.mid_modules = nn.ModuleList([ 163 | ConditionalResidualBlock1D( 164 | mid_dim, mid_dim, cond_dim=cond_dim, 165 | kernel_size=kernel_size, n_groups=n_groups 166 | ), 167 | ConditionalResidualBlock1D( 168 | mid_dim, mid_dim, cond_dim=cond_dim, 169 | kernel_size=kernel_size, n_groups=n_groups 170 | ), 171 | ]) 172 | 173 | down_modules = nn.ModuleList([]) 174 | for ind, (dim_in, dim_out) in enumerate(in_out): 175 | is_last = ind >= (len(in_out) - 1) 176 | down_modules.append(nn.ModuleList([ 177 | ConditionalResidualBlock1D( 178 | dim_in, dim_out, cond_dim=cond_dim, 179 | kernel_size=kernel_size, n_groups=n_groups), 180 | ConditionalResidualBlock1D( 181 | dim_out, dim_out, cond_dim=cond_dim, 182 | kernel_size=kernel_size, n_groups=n_groups), 183 | Downsample1d(dim_out) if not is_last else nn.Identity() 184 | ])) 185 | 186 | up_modules = nn.ModuleList([]) 187 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 188 | is_last = ind >= (len(in_out) - 1) 189 | up_modules.append(nn.ModuleList([ 190 | ConditionalResidualBlock1D( 191 | dim_out*2, dim_in, cond_dim=cond_dim, 192 | kernel_size=kernel_size, n_groups=n_groups), 193 | ConditionalResidualBlock1D( 194 | dim_in, dim_in, cond_dim=cond_dim, 195 | kernel_size=kernel_size, n_groups=n_groups), 196 | Upsample1d(dim_in) if not is_last else nn.Identity() 197 | ])) 198 | 199 | final_conv = nn.Sequential( 200 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), 201 | nn.Conv1d(start_dim, input_dim, 1), 202 | ) 203 | 204 | self.diffusion_step_encoder = diffusion_step_encoder 205 | self.up_modules = up_modules 206 | self.down_modules = down_modules 207 | self.final_conv = final_conv 208 | 209 | print("number of parameters: {:e}".format( 210 | sum(p.numel() for p in self.parameters())) 211 | ) 212 | 213 | def forward(self, 214 | sample: torch.Tensor, 215 | timestep: Union[torch.Tensor, float, int], 216 | global_cond=None): 217 | """ 218 | x: (B,T,input_dim) 219 | timestep: (B,) or int, diffusion step 220 | global_cond: (B,global_cond_dim) 221 | output: (B,T,input_dim) 222 | """ 223 | # (B,T,C) 224 | sample = sample.moveaxis(-1,-2) 225 | # (B,C,T) 226 | 227 | # 1. time 228 | timesteps = timestep 229 | if not torch.is_tensor(timesteps): 230 | # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can 231 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) 232 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: 233 | timesteps = timesteps[None].to(sample.device) 234 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 235 | timesteps = timesteps.expand(sample.shape[0]) 236 | 237 | global_feature = self.diffusion_step_encoder(timesteps) 238 | 239 | if global_cond is not None: 240 | global_feature = torch.cat([ 241 | global_feature, global_cond 242 | ], axis=-1) 243 | 244 | x = sample 245 | h = [] 246 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): 247 | x = resnet(x, global_feature) 248 | x = resnet2(x, global_feature) 249 | h.append(x) 250 | x = downsample(x) 251 | 252 | for mid_module in self.mid_modules: 253 | x = mid_module(x, global_feature) 254 | 255 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): 256 | x = torch.cat((x, h.pop()), dim=1) 257 | x = resnet(x, global_feature) 258 | x = resnet2(x, global_feature) 259 | x = upsample(x) 260 | 261 | x = self.final_conv(x) 262 | 263 | # (B,C,T) 264 | x = x.moveaxis(-1,-2) 265 | # (B,T,C) 266 | return x 267 | 268 | 269 | 270 | # download demonstration data from Google Drive 271 | dataset_path = "pusht_cchi_v7_replay.zarr.zip" 272 | if not os.path.isfile(dataset_path): 273 | id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t" 274 | gdown.download(id=id, output=dataset_path, quiet=False) 275 | 276 | 277 | 278 | import timm 279 | 280 | #@markdown ### **Network Demo** 281 | class DiffusionPolicy: 282 | def __init__(self): 283 | 284 | # construct ResNet18 encoder 285 | # if you have multiple camera views, use seperate encoder weights for each view. 286 | vision_encoder = train_utils().get_resnet("resnet18") 287 | # Define Second vision encoder 288 | 289 | 290 | # IMPORTANT! 291 | # replace all BatchNorm with GroupNorm to work with EMA 292 | # performance will tank if you forget to do this! 
293 | vision_encoder = train_utils().replace_bn_with_gn(vision_encoder) 294 | # ResNet18 has output dim of 512 295 | vision_feature_dim = 512 296 | # agent_pos is 2 dimensional 297 | lowdim_obs_dim = 2 298 | # observation feature has 514 dims in total per step 299 | obs_dim = vision_feature_dim + lowdim_obs_dim 300 | action_dim = 2 301 | # parameters 302 | pred_horizon = 16 303 | obs_horizon = 2 304 | action_horizon = 8 305 | #|o|o| observations: 2 306 | #| |a|a|a|a|a|a|a|a| actions executed: 8 307 | #|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16 308 | 309 | # create dataset from file 310 | dataset = PushTImageDataset( 311 | dataset_path=dataset_path, 312 | pred_horizon=pred_horizon, 313 | obs_horizon=obs_horizon, 314 | action_horizon=action_horizon 315 | ) 316 | # save training data statistics (min, max) for each dim 317 | stats = dataset.stats 318 | 319 | # create dataloader 320 | dataloader = torch.utils.data.DataLoader( 321 | dataset, 322 | batch_size=64, 323 | num_workers=4, 324 | shuffle=True, 325 | # accelerate cpu-gpu transfer 326 | pin_memory=True, 327 | # don't kill worker process afte each epoch 328 | persistent_workers=True 329 | ) 330 | #### For debugging purposes uncomment 331 | # import matplotlib.pyplot as plt 332 | # imdata = dataset[100]['image'] 333 | # if imdata.dtype == np.float32 or imdata.dtype == np.float64: 334 | # imdata = imdata / 255.0 335 | # img1 = imdata[0] 336 | # img2 = imdata[1] 337 | # # Loop through the two different "channels" 338 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 339 | # for i in range(2): 340 | # # Convert the 3x96x96 tensor to a 96x96x3 image (for display purposes) 341 | # img = np.transpose(imdata[i], (1, 2, 0)) 342 | 343 | # # Display the image in the i-th subplot 344 | # axes[i].imshow(img) 345 | # axes[i].set_title(f'Channel {i + 1}') 346 | # axes[i].axis('off') 347 | 348 | # # Show the plot 349 | # plt.show() 350 | 351 | # # Check if both images are exactly the same 352 | # are_equal = np.array_equal(img1, img2) 353 | 354 | # if are_equal: 355 | # print("The images are the same.") 356 | # else: 357 | # print("The images are different.") 358 | ######### End ######## 359 | 360 | 361 | # visualize data in batch 362 | batch = next(iter(dataloader)) 363 | print("batch['image'].shape:", batch['image'].shape) 364 | print("batch['agent_pos'].shape:", batch['agent_pos'].shape) 365 | print("batch['action'].shape", batch['action'].shape) 366 | 367 | # create network object 368 | noise_pred_net = ConditionalUnet1D( 369 | input_dim=action_dim, 370 | global_cond_dim=obs_dim*obs_horizon 371 | ) 372 | 373 | # the final arch has 2 parts 374 | nets = nn.ModuleDict({ 375 | 'vision_encoder': vision_encoder, 376 | 'noise_pred_net': noise_pred_net 377 | }) 378 | num_diffusion_iters = 100 379 | 380 | noise_scheduler = DDPMScheduler( 381 | num_train_timesteps=num_diffusion_iters, 382 | # the choise of beta schedule has big impact on performance 383 | # we found squared cosine works the best 384 | beta_schedule='squaredcos_cap_v2', 385 | # clip output to [-1,1] to improve stability 386 | clip_sample=True, 387 | # our network predicts noise (instead of denoised action) 388 | prediction_type='epsilon' 389 | ) 390 | self.nets = nets 391 | self.noise_scheduler = noise_scheduler 392 | self.num_diffusion_iters = num_diffusion_iters 393 | self.batch = batch 394 | self.dataloader = dataloader 395 | self.stats = stats 396 | self.obs_horizon = obs_horizon 397 | self.obs_dim = obs_dim 398 | self.vision_encoder = vision_encoder 399 | self.noise_pred_net 
= noise_pred_net 400 | self.action_horizon = action_horizon 401 | self.pred_horizon = pred_horizon 402 | self.lowdim_obs_dim = lowdim_obs_dim 403 | self.action_dim = action_dim -------------------------------------------------------------------------------- /real_robot_network copy.py: -------------------------------------------------------------------------------- 1 | #@markdown ### **Imports** 2 | # diffusion policy import 3 | from typing import Tuple, Sequence, Dict, Union, Optional, Callable 4 | import numpy as np 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 9 | import os 10 | from data_util import RealRobotDataSet 11 | from train_utils import train_utils 12 | from data_util import center_crop 13 | import json 14 | from transformer_obs_encoder import SimpleRGBObsEncoder 15 | import timm 16 | from torchvision import transforms 17 | from torchvision.datasets import ImageFolder 18 | from torch.utils.data import ConcatDataset 19 | 20 | #@markdown ### **Network** 21 | #@markdown 22 | #@markdown Defines a 1D UNet architecture `ConditionalUnet1D` 23 | #@markdown as the noies prediction network 24 | #@markdown 25 | #@markdown Components 26 | #@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k 27 | #@markdown - `Downsample1d` Strided convolution to reduce temporal resolution 28 | #@markdown - `Upsample1d` Transposed convolution to increase temporal resolution 29 | #@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish 30 | #@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. \ 31 | #@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection. 32 | #@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning. 
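# Hedged aside (not part of the original file): a minimal sketch of the reverse-diffusion
# loop these modules are used with at inference time, mirroring the loop in inference.py /
# inference_real.py. `noise_pred_net` and `noise_scheduler` are assumed to be the
# ConditionalUnet1D and DDPMScheduler constructed further below; nothing in the rest of
# this module calls this function.
def _denoising_sketch(noise_pred_net, noise_scheduler, obs_cond,
                      pred_horizon, action_dim, num_diffusion_iters, device='cpu'):
    # start from pure Gaussian noise, shape (1, pred_horizon, action_dim)
    naction = torch.randn((1, pred_horizon, action_dim), device=device)
    noise_scheduler.set_timesteps(num_diffusion_iters)
    for k in noise_scheduler.timesteps:
        # predict the noise component, conditioned on the encoded observations
        noise_pred = noise_pred_net(sample=naction, timestep=k, global_cond=obs_cond)
        # one reverse step: remove the predicted noise
        naction = noise_scheduler.step(model_output=noise_pred, timestep=k,
                                       sample=naction).prev_sample
    return naction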
33 | 34 | 35 | 36 | def cross_center_crop(images, crop_height, crop_width): 37 | # Get original dimensions: B (batch size), T (sequence length), C (channels), H (height), W (width) 38 | B, T, C, H, W = images.shape 39 | assert crop_height <= H and crop_width <= W, "Crop size should be smaller than the original size" 40 | 41 | # Calculate the center for height and width 42 | start_y = (H - crop_height) // 2 43 | start_x = (W - crop_width) // 2 44 | 45 | # Perform cropping for each image in the sequence 46 | cropped_images = images[:, :, :, start_y:start_y + crop_height, start_x:start_x + crop_width] 47 | 48 | return cropped_images 49 | 50 | class ForceEncoder(nn.Module): 51 | def __init__(self, force_dim, hidden_dim, batch_size, obs_horizon, force_encoder = "CNN", cross_attn = False, im_encoder = "resnet", train = True): 52 | super(ForceEncoder, self).__init__() 53 | self.cross_attn = cross_attn 54 | self.batch_size = batch_size 55 | self.obs_horizon = obs_horizon 56 | self.force_encoder = force_encoder 57 | self.train = train 58 | if im_encoder == "viT": 59 | force_hidden_dim = 768 60 | else: 61 | force_hidden_dim = 512 62 | print(f"force_encoder: {force_encoder}") 63 | # Force feature extraction with Group Normalization 64 | # Convolutional layers to encode force data with Group Normalization 65 | if force_encoder == "CNN": 66 | self.conv_encoder = nn.Sequential( 67 | nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1), 68 | nn.GroupNorm(num_groups=16, num_channels=32), 69 | nn.ReLU(), 70 | nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), 71 | nn.GroupNorm(num_groups=16, num_channels=64), 72 | nn.ReLU(), 73 | nn.Flatten() 74 | ) 75 | elif force_encoder == "Transformer": 76 | self.force_embedding = nn.Linear(4, force_hidden_dim) # Project 3D force to 512-dimensional embedding 77 | # Define a single Transformer Encoder Layer 78 | transformer_layer = nn.TransformerEncoderLayer(d_model=force_hidden_dim, nhead=8, batch_first= True) 79 | 80 | # Stack 6 layers of the Transformer Encoder Layer 81 | self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=6) 82 | self.fc = nn.Linear(force_hidden_dim, force_hidden_dim) # Optional final projection layer 83 | elif force_encoder == "MLP": 84 | if im_encoder == "viT": 85 | self.fc_encoder = nn.Sequential( 86 | nn.Linear(4, 64), # 4 force components -> 64 87 | nn.ReLU(), 88 | nn.Linear(64, 128), 89 | nn.ReLU(), 90 | nn.Linear(128, 256), # Output 512-dimensional feature 91 | nn.ReLU(), 92 | nn.Linear(256, 768) # Output 512-dimensional feature 93 | ) 94 | else: 95 | self.fc_encoder = nn.Sequential( 96 | nn.Linear(4, 64), # 4 force components -> 64 97 | nn.ReLU(), 98 | nn.Linear(64, 128), 99 | nn.ReLU(), 100 | nn.Linear(128, 512) # Output 512-dimensional feature 101 | ) 102 | elif force_encoder == "Linear": 103 | self.force_projection = nn.Linear(4, 512) 104 | 105 | self.projection_layer = nn.Linear(64 * force_dim, hidden_dim) 106 | 107 | def forward(self, x): 108 | if self.train: 109 | B,T,D = x.shape 110 | force_input = x.reshape(B*T, D) 111 | force_input = force_input.unsqueeze(1) 112 | else: 113 | force_input = x.unsqueeze(1) # Reshape to [batch_size, 1, input_dim] => [64, 1, 4] 114 | if self.force_encoder == "CNN": 115 | latent_vector = self.conv_encoder(force_input) 116 | latent_vector = self.projection_layer(latent_vector) # Shape: [batch_size, 512] 117 | elif self.force_encoder == "Transformer": 118 | embedded_force = self.force_embedding(force_input) # Shape: [seq_len, 
batch_size, 512] 119 | latent_vector = self.transformer_encoder(embedded_force) # Shape: [batch_size, embed_dim] 120 | # latent_vector = self.fc(encoded_force.mean(dim=0)) # Get the final 512-dimensional output 121 | elif self.force_encoder == "MLP": 122 | latent_vector = self.fc_encoder(force_input) 123 | elif self.force_encoder == "Linear": 124 | projected_force = self.force_projection(force_input) 125 | latent_vector = projected_force 126 | if self.train: 127 | latent_vector = latent_vector.reshape(int(B), self.obs_horizon, -1) 128 | else: 129 | latent_vector = latent_vector.squeeze(1) 130 | return latent_vector 131 | 132 | class CrossAttentionFusion(nn.Module): 133 | def __init__(self, image_dim, force_dim, hidden_dim= None, batch_size = 48, obs_horizon = 2, force_encoder = "CNN", im_encoder = "resnet", train=True): 134 | super(CrossAttentionFusion, self).__init__() 135 | self.obs_horizon = obs_horizon 136 | self.batch_size = batch_size 137 | self.im_encoder = im_encoder 138 | C,H,W = image_dim 139 | self.train = train 140 | # Image feature extraction 141 | # Image feature extraction layers 142 | if im_encoder == "CNN": 143 | self.image_encoder = nn.Sequential( 144 | nn.Conv2d(in_channels=C, out_channels=64, kernel_size=3, stride=1, padding=1), 145 | nn.GroupNorm(num_groups=16, num_channels=64), # Applying GroupNorm instead of BatchNorm 146 | nn.ReLU(), 147 | nn.MaxPool2d(kernel_size=2), 148 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 149 | nn.GroupNorm(num_groups=16, num_channels=128), # Applying GroupNorm to the second layer 150 | nn.ReLU(), 151 | nn.MaxPool2d(kernel_size=2), 152 | nn.Flatten() 153 | ) 154 | elif im_encoder == "resnet": 155 | self.image_encoder = train_utils().get_resnet("resnet18", weights= None) 156 | self.image_encoder = train_utils().replace_bn_with_gn(self.image_encoder) 157 | elif im_encoder == "viT": 158 | self.image_encoder = SimpleRGBObsEncoder() 159 | 160 | # train_utils().replace_bn_with_gn(self.image_encoder) 161 | # Dynamically calculate the image_dim after convolution and pooling 162 | with torch.no_grad(): 163 | sample_input = None 164 | if im_encoder == 'viT': 165 | sample_input = torch.zeros(1, 2, C, H, W) # Batch size of 1 166 | else: 167 | sample_input = torch.zeros(1, C, H, W) 168 | sample_output = self.image_encoder(sample_input) 169 | image_dim = sample_output.shape[1] # Get the flattened image dimension 170 | 171 | # Fully connected layer to map the image features to hidden_dim 172 | self.image_fc = nn.Linear(image_dim, hidden_dim) 173 | 174 | # Force feature extraction 175 | self.force_encoder = ForceEncoder(force_dim=force_dim, hidden_dim=hidden_dim, batch_size = batch_size, obs_horizon = obs_horizon, force_encoder=force_encoder, cross_attn=True, im_encoder = im_encoder, train = train) 176 | # Cross-attention layers 177 | self.attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4) 178 | 179 | # Fusion layers to create joint embedding 180 | self.fusion_layer = nn.Sequential( 181 | nn.Linear(hidden_dim, hidden_dim), 182 | nn.ReLU(), 183 | nn.Linear(hidden_dim, hidden_dim) 184 | ) 185 | 186 | def forward(self, image_input, force_input): 187 | # Encode image and force data 188 | current_batch_size = image_input.size(0) 189 | if self.im_encoder == "viT": 190 | image_input = cross_center_crop(image_input, 224, 224) 191 | image_features = self.image_encoder(image_input) 192 | 193 | # image_features = self.image_fc(image_features) 194 | if self.im_encoder != "viT" and self.train: 195 | 
image_features = image_features.view(int(current_batch_size/2), self.obs_horizon, -1) 196 | if self.train: 197 | image_features = image_features.permute(1, 0, 2) # Correct shape: (num_images, batch_size, hidden_dim) 198 | 199 | # Reshape for attention: (sequence_length, batch_size, hidden_dim) 200 | 201 | force_features = self.force_encoder(force_input) 202 | if self.train: 203 | # force_features = force_features.view(batch_size, obs_horizon, -1) 204 | force_features = force_features.permute(1, 0, 2) # Correct shape: (num_forces, batch_size, hidden_dim) 205 | 206 | # Cross-attention operation 207 | attn_output, _ = self.attention(query=force_features, key=image_features, value=image_features) 208 | if self.train: 209 | attn_output = attn_output.permute(1, 0, 2) # Shape: (batch_size, num_forces, hidden_dim) 210 | 211 | # Generate the fused embedding 212 | joint_embedding = self.fusion_layer(attn_output) 213 | return joint_embedding 214 | 215 | class NumpyEncoder(json.JSONEncoder): 216 | def default(self, obj): 217 | if isinstance(obj, np.ndarray): 218 | return obj.tolist() 219 | return super(NumpyEncoder, self).default(obj) 220 | 221 | class SinusoidalPosEmb(nn.Module): 222 | def __init__(self, dim): 223 | super().__init__() 224 | self.dim = dim 225 | 226 | def forward(self, x): 227 | device = x.device 228 | half_dim = self.dim // 2 229 | emb = math.log(10000) / (half_dim - 1) 230 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 231 | emb = x[:, None] * emb[None, :] 232 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 233 | return emb 234 | 235 | 236 | class Downsample1d(nn.Module): 237 | def __init__(self, dim): 238 | super().__init__() 239 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 240 | 241 | def forward(self, x): 242 | return self.conv(x) 243 | 244 | class Upsample1d(nn.Module): 245 | def __init__(self, dim): 246 | super().__init__() 247 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 248 | 249 | def forward(self, x): 250 | return self.conv(x) 251 | 252 | 253 | class Conv1dBlock(nn.Module): 254 | ''' 255 | Conv1d --> GroupNorm --> Mish 256 | ''' 257 | 258 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 259 | super().__init__() 260 | 261 | self.block = nn.Sequential( 262 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 263 | nn.GroupNorm(n_groups, out_channels), 264 | nn.Mish(), 265 | ) 266 | 267 | def forward(self, x): 268 | return self.block(x) 269 | 270 | 271 | class ConditionalResidualBlock1D(nn.Module): 272 | def __init__(self, 273 | in_channels, 274 | out_channels, 275 | cond_dim, 276 | kernel_size=3, 277 | n_groups=8): 278 | super().__init__() 279 | 280 | self.blocks = nn.ModuleList([ 281 | Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), 282 | Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), 283 | ]) 284 | 285 | # FiLM modulation https://arxiv.org/abs/1709.07871 286 | # predicts per-channel scale and bias 287 | cond_channels = out_channels * 2 288 | self.out_channels = out_channels 289 | self.cond_encoder = nn.Sequential( 290 | nn.Mish(), 291 | nn.Linear(cond_dim, cond_channels), 292 | nn.Unflatten(-1, (-1, 1)) 293 | ) 294 | 295 | # make sure dimensions compatible 296 | self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ 297 | if in_channels != out_channels else nn.Identity() 298 | 299 | def forward(self, x, cond): 300 | ''' 301 | x : [ batch_size x in_channels x horizon ] 302 | cond : [ batch_size x cond_dim] 303 | 304 | returns: 305 | out : [ 
batch_size x out_channels x horizon ] 306 | ''' 307 | out = self.blocks[0](x) 308 | embed = self.cond_encoder(cond) 309 | 310 | embed = embed.reshape( 311 | embed.shape[0], 2, self.out_channels, 1) 312 | scale = embed[:,0,...] 313 | bias = embed[:,1,...] 314 | out = scale * out + bias 315 | 316 | out = self.blocks[1](out) 317 | out = out + self.residual_conv(x) 318 | return out 319 | 320 | 321 | class ConditionalUnet1D(nn.Module): 322 | def __init__(self, 323 | input_dim, 324 | global_cond_dim, 325 | diffusion_step_embed_dim=256, 326 | down_dims=[256,512,1024], 327 | kernel_size=5, 328 | n_groups=8 329 | ): 330 | """ 331 | input_dim: Dim of actions. 332 | global_cond_dim: Dim of global conditioning applied with FiLM 333 | in addition to diffusion step embedding. This is usually obs_horizon * obs_dim 334 | diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k 335 | down_dims: Channel size for each UNet level. 336 | The length of this array determines numebr of levels. 337 | kernel_size: Conv kernel size 338 | n_groups: Number of groups for GroupNorm 339 | """ 340 | 341 | super().__init__() 342 | all_dims = [input_dim] + list(down_dims) 343 | start_dim = down_dims[0] 344 | 345 | dsed = diffusion_step_embed_dim 346 | diffusion_step_encoder = nn.Sequential( 347 | SinusoidalPosEmb(dsed), 348 | nn.Linear(dsed, dsed * 4), 349 | nn.Mish(), 350 | nn.Linear(dsed * 4, dsed), 351 | ) 352 | cond_dim = dsed + global_cond_dim 353 | 354 | in_out = list(zip(all_dims[:-1], all_dims[1:])) 355 | mid_dim = all_dims[-1] 356 | self.mid_modules = nn.ModuleList([ 357 | ConditionalResidualBlock1D( 358 | mid_dim, mid_dim, cond_dim=cond_dim, 359 | kernel_size=kernel_size, n_groups=n_groups 360 | ), 361 | ConditionalResidualBlock1D( 362 | mid_dim, mid_dim, cond_dim=cond_dim, 363 | kernel_size=kernel_size, n_groups=n_groups 364 | ), 365 | ]) 366 | 367 | down_modules = nn.ModuleList([]) 368 | for ind, (dim_in, dim_out) in enumerate(in_out): 369 | is_last = ind >= (len(in_out) - 1) 370 | down_modules.append(nn.ModuleList([ 371 | ConditionalResidualBlock1D( 372 | dim_in, dim_out, cond_dim=cond_dim, 373 | kernel_size=kernel_size, n_groups=n_groups), 374 | ConditionalResidualBlock1D( 375 | dim_out, dim_out, cond_dim=cond_dim, 376 | kernel_size=kernel_size, n_groups=n_groups), 377 | Downsample1d(dim_out) if not is_last else nn.Identity() 378 | ])) 379 | 380 | up_modules = nn.ModuleList([]) 381 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 382 | is_last = ind >= (len(in_out) - 1) 383 | up_modules.append(nn.ModuleList([ 384 | ConditionalResidualBlock1D( 385 | dim_out*2, dim_in, cond_dim=cond_dim, 386 | kernel_size=kernel_size, n_groups=n_groups), 387 | ConditionalResidualBlock1D( 388 | dim_in, dim_in, cond_dim=cond_dim, 389 | kernel_size=kernel_size, n_groups=n_groups), 390 | Upsample1d(dim_in) if not is_last else nn.Identity() 391 | ])) 392 | 393 | final_conv = nn.Sequential( 394 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), 395 | nn.Conv1d(start_dim, input_dim, 1), 396 | ) 397 | 398 | self.diffusion_step_encoder = diffusion_step_encoder 399 | self.up_modules = up_modules 400 | self.down_modules = down_modules 401 | self.final_conv = final_conv 402 | 403 | print("number of parameters: {:e}".format( 404 | sum(p.numel() for p in self.parameters())) 405 | ) 406 | 407 | def forward(self, 408 | sample: torch.Tensor, 409 | timestep: Union[torch.Tensor, float, int], 410 | global_cond=None): 411 | """ 412 | x: (B,T,input_dim) 413 | timestep: (B,) or int, diffusion step 414 
| global_cond: (B,global_cond_dim) 415 | output: (B,T,input_dim) 416 | """ 417 | # (B,T,C) 418 | sample = sample.moveaxis(-1,-2) 419 | # (B,C,T) 420 | 421 | # 1. time 422 | timesteps = timestep 423 | if not torch.is_tensor(timesteps): 424 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 425 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) 426 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: 427 | timesteps = timesteps[None].to(sample.device) 428 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 429 | timesteps = timesteps.expand(sample.shape[0]) 430 | 431 | global_feature = self.diffusion_step_encoder(timesteps) 432 | 433 | if global_cond is not None: 434 | global_feature = torch.cat([ 435 | global_feature, global_cond 436 | ], axis=-1) 437 | 438 | x = sample 439 | h = [] 440 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): 441 | x = resnet(x, global_feature) 442 | x = resnet2(x, global_feature) 443 | h.append(x) 444 | x = downsample(x) 445 | 446 | for mid_module in self.mid_modules: 447 | x = mid_module(x, global_feature) 448 | 449 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): 450 | x = torch.cat((x, h.pop()), dim=1) 451 | x = resnet(x, global_feature) 452 | x = resnet2(x, global_feature) 453 | x = upsample(x) 454 | 455 | x = self.final_conv(x) 456 | 457 | # (B,C,T) 458 | x = x.moveaxis(-1,-2) 459 | # (B,T,C) 460 | return x 461 | 462 | 463 | 464 | # download demonstration data from Google Drive 465 | # dataset_path = "pusht_cchi_v7_replay.zarr.zip" 466 | # if not os.path.isfile(dataset_path): 467 | # id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t" 468 | # gdown.download(id=id, output=dataset_path, quiet=False) 469 | 470 | def get_filename(input_string): 471 | # Find the last instance of '/' 472 | last_slash_index = input_string.rfind('/') 473 | 474 | # Get the substring after the last '/' 475 | if last_slash_index != -1: 476 | result = input_string[last_slash_index + 1:] 477 | # Return the substring without the last 4 characters 478 | return result[:-7] if len(result) > 7 else "" 479 | else: 480 | return "" 481 | 482 | 483 | dataset_path = "/home/jeon/jeon_ws/diffusion_policy/src/diffusion_cam/RAL_AAA+D_419.zarr.zip" 484 | 485 | #@markdown ### **Network Demo** 486 | class DiffusionPolicy_Real: 487 | def __init__(self, 488 | train=True, 489 | encoder = "resnet", 490 | action_def = "delta", 491 | force_mod:bool = False, 492 | single_view:bool = False, 493 | force_encode = False, 494 | force_encoder = "CNN", 495 | cross_attn: bool = False, 496 | hybrid: bool = False, 497 | duplicate_view = False, 498 | crop: int = 1000, 499 | augment = True): 500 | # action dimension should also correspond with the state dimension (x,y,z, x, y, z, w) 501 | action_dim = 9 502 | # parameters 503 | pred_horizon = 16 504 | obs_horizon = 2 505 | action_horizon = 8 506 | #|o|o| observations: 2 507 | #| |a|a|a|a|a|a|a|a| actions executed: 8 508 | #|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16 509 | batch_size = 100 510 | Transformer_bool = None 511 | modality = "without_force" 512 | view = "dual_view" 513 | if force_mod: 514 | modality = "with_force" 515 | if single_view: 516 | view = "single_view" 517 | # construct ResNet18 encoder 518 | # if you have multiple camera views, use seperate encoder weights for each view. 
519 | # Resnet18 and resnet34 both have same dimension for the output 520 | # Define Second vision encoder 521 | 522 | if not single_view: 523 | vision_encoder2 = train_utils().get_resnet('resnet18') 524 | vision_encoder2 = train_utils().replace_bn_with_gn(vision_encoder2) 525 | if duplicate_view: 526 | vision_encoder2 = train_utils().get_resnet('resnet18') 527 | vision_encoder2 = train_utils().replace_bn_with_gn(vision_encoder2) 528 | 529 | if force_mod and force_encode: 530 | if encoder == "viT": 531 | hidden_dim_force = 768 532 | else: 533 | hidden_dim_force = 512 534 | force_encoder = ForceEncoder(4, hidden_dim_force, batch_size = batch_size, 535 | obs_horizon = obs_horizon, 536 | force_encoder= force_encoder, 537 | cross_attn=cross_attn, 538 | train=train) 539 | 540 | if cross_attn and force_mod: 541 | if encoder == "viT": 542 | cross_hidden_dim = 768 543 | image_dim = (3,224,224) 544 | else: 545 | cross_hidden_dim = 512 546 | if crop == 98: 547 | image_dim = (3,98,98) 548 | else: 549 | image_dim = (3,320,240) 550 | joint_encoder = CrossAttentionFusion(image_dim, 4, cross_hidden_dim, batch_size = batch_size, 551 | obs_horizon=obs_horizon, 552 | force_encoder = force_encoder, 553 | im_encoder = encoder, 554 | train = train) 555 | else: 556 | if encoder == "resnet": 557 | print("resnet") 558 | vision_encoder = train_utils().get_resnet('resnet18') 559 | vision_encoder = train_utils().replace_bn_with_gn(vision_encoder) 560 | 561 | elif encoder == "Transformer": 562 | Transformer_bool = True 563 | print("Imported Transformer clip model") 564 | vision_encoder = SimpleRGBObsEncoder() 565 | # IMPORTANT! 566 | # replace all BatchNorm with GroupNorm to work with EMA 567 | # performance will tank if you forget to do this! 568 | # ResNet18 has output dim of 512 X 2 because two views 569 | if single_view: 570 | if encoder == "viT": 571 | vision_feature_dim = 768 572 | else: 573 | vision_feature_dim = 512 574 | else: 575 | if encoder == "viT": 576 | vision_feature_dim = 768 + 512 577 | else: 578 | vision_feature_dim = 512 + 512 579 | 580 | if force_encode: 581 | force_feature_dim = 512 582 | else: 583 | force_feature_dim = 4 584 | # agent_pos is seven (x,y,z, w, y, z, w ) dimensional 585 | lowdim_obs_dim = 9 586 | # observation feature has 514 dims in total per step 587 | if force_mod and not cross_attn: 588 | obs_dim = vision_feature_dim + force_feature_dim + lowdim_obs_dim 589 | elif force_mod and cross_attn and not duplicate_view: 590 | obs_dim = vision_feature_dim + lowdim_obs_dim 591 | elif force_mod and cross_attn and duplicate_view: 592 | obs_dim = vision_feature_dim * 2 + lowdim_obs_dim 593 | else: 594 | obs_dim = vision_feature_dim + lowdim_obs_dim 595 | if hybrid: 596 | obs_dim += 4 597 | 598 | data_name = get_filename(dataset_path) 599 | 600 | if train: 601 | # create dataset from file 602 | dataset = RealRobotDataSet( 603 | dataset_path=dataset_path, 604 | pred_horizon=pred_horizon, 605 | obs_horizon=obs_horizon, 606 | action_horizon=action_horizon, 607 | Transformer= Transformer_bool, 608 | force_mod = force_mod, 609 | single_view=single_view, 610 | augment = False, 611 | duplicate_view = duplicate_view, 612 | crop = crop 613 | ) 614 | # save training data statistics (min, max) for each dim 615 | 616 | 617 | 618 | 619 | if augment: 620 | dataset_augmented = RealRobotDataSet( 621 | dataset_path=dataset_path, 622 | pred_horizon=pred_horizon, 623 | obs_horizon=obs_horizon, 624 | action_horizon=action_horizon, 625 | Transformer= Transformer_bool, 626 | force_mod = force_mod, 627 | 
single_view=single_view, 628 | augment = True, 629 | duplicate_view = duplicate_view, 630 | crop = crop 631 | ) 632 | 633 | 634 | combined_dataset = ConcatDataset([dataset, dataset_augmented]) 635 | 636 | # DataLoader for combined dataset 637 | data_loader_combined = torch.utils.data.DataLoader( 638 | combined_dataset, 639 | batch_size=batch_size, 640 | num_workers=4, 641 | shuffle=True, # Shuffle to mix normal and augmented data 642 | pin_memory=True, 643 | persistent_workers=True 644 | ) 645 | self.dataloader = data_loader_combined 646 | batch = next(iter(data_loader_combined)) 647 | stats = dataset_augmented.stats 648 | 649 | else: 650 | # create dataloader 651 | dataloader = torch.utils.data.DataLoader( 652 | dataset, 653 | batch_size=batch_size, 654 | num_workers=4, 655 | shuffle=True, 656 | # accelerate cpu-gpu transfer 657 | pin_memory=True, 658 | # don't kill worker process afte each epoch 659 | persistent_workers=True, 660 | ) 661 | self.dataloader = dataloader 662 | batch = next(iter(dataloader)) 663 | stats = dataset.stats 664 | 665 | # Save the stats to a file 666 | with open(f'stats_{data_name}_{encoder}_{action_def}_{modality}_vn.json', 'w') as f: 667 | json.dump(stats, f, cls=NumpyEncoder) 668 | print("stats saved") 669 | 670 | # self.dataloader = data_loader_augmented 671 | # self.data_loader_augmented = data_loader_augmented 672 | self.stats = stats 673 | 674 | #### For debugging purposes uncomment 675 | # import matplotlib.pyplot as plt 676 | # imdata = dataset[100]['image'] 677 | # if imdata.dtype == np.float32 or imdata.dtype == np.float64: 678 | # imdata = imdata / 255.0 679 | # img1 = imdata[0] 680 | # img2 = imdata[1] 681 | # # Loop through the two different "channels" 682 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 683 | # for i in range(2): 684 | # # Convert the 3x96x96 tensor to a 96x96x3 image (for display purposes) 685 | # img = np.transpose(imdata[i], (1, 2, 0)) 686 | 687 | # # Display the image in the i-th subplot 688 | # axes[i].imshow(img) 689 | # axes[i].set_title(f'Channel {i + 1}') 690 | # axes[i].axis('off') 691 | 692 | # # Show the plot 693 | # plt.show() 694 | 695 | # # Check if both images are exactly the same 696 | # are_equal = np.array_equal(img1, img2) 697 | 698 | # if are_equal: 699 | # print("The images are the same.") 700 | # else: 701 | # print("The images are different.") 702 | ######### End ######## 703 | # visualize data in batch 704 | print("batch['image'].shape:", batch['image'].shape) 705 | if not single_view: 706 | print("batch[image2].shape", batch["image2"].shape) 707 | if duplicate_view: 708 | print("batch[image_duplicate].shape", batch["duplicate_image"].shape) 709 | 710 | print("batch['agent_pos'].shape:", batch['agent_pos'].shape) 711 | 712 | if force_mod: 713 | print("batch['force'].shape:", batch['force'].shape) 714 | 715 | print("batch['action'].shape", batch['action'].shape) 716 | self.batch = batch 717 | 718 | # create network object 719 | noise_pred_net = ConditionalUnet1D( 720 | input_dim=action_dim, 721 | global_cond_dim=obs_dim*obs_horizon 722 | ) 723 | if single_view and not force_mod and not force_encode and not cross_attn: 724 | # the final arch has 2 parts 725 | nets = nn.ModuleDict({ 726 | 'vision_encoder': vision_encoder, 727 | 'noise_pred_net': noise_pred_net 728 | }) 729 | elif single_view and force_mod and not force_encode and not cross_attn: 730 | # the final arch has 2 parts 731 | nets = nn.ModuleDict({ 732 | 'vision_encoder': vision_encoder, 733 | 'noise_pred_net': noise_pred_net 734 | }) 735 | elif 
single_view and force_encode: 736 | # the final arch has 2 parts 737 | nets = nn.ModuleDict({ 738 | 'vision_encoder': vision_encoder, 739 | 'force_encoder': force_encoder, 740 | 'noise_pred_net': noise_pred_net 741 | }) 742 | elif not single_view and force_encode: 743 | nets = nn.ModuleDict({ 744 | 'vision_encoder': vision_encoder, 745 | 'vision_encoder2': vision_encoder2, 746 | 'force_encoder': force_encoder, 747 | 'noise_pred_net': noise_pred_net 748 | }) 749 | elif not single_view and not force_encode and not cross_attn: 750 | nets = nn.ModuleDict({ 751 | 'vision_encoder': vision_encoder, 752 | 'vision_encoder2': vision_encoder2, 753 | 'noise_pred_net': noise_pred_net 754 | }) 755 | elif single_view and cross_attn and not duplicate_view: 756 | nets = nn.ModuleDict({ 757 | 'cross_attn_encoder': joint_encoder, 758 | 'noise_pred_net': noise_pred_net 759 | }) 760 | elif not single_view and cross_attn and not force_encode: 761 | nets = nn.ModuleDict({ 762 | 'cross_attn_encoder': joint_encoder, 763 | 'vision_encoder2': vision_encoder2, 764 | 'noise_pred_net': noise_pred_net 765 | }) 766 | elif single_view and duplicate_view and cross_attn and not force_encode: 767 | nets = nn.ModuleDict({ 768 | 'cross_attn_encoder': joint_encoder, 769 | 'vision_encoder2': vision_encoder2, 770 | 'noise_pred_net': noise_pred_net 771 | }) 772 | elif cross_attn and force_encode: 773 | print("Cross attn and force encode cannot be True at the same time") 774 | 775 | 776 | # diffusion iteration 777 | num_diffusion_iters = 100 778 | 779 | noise_scheduler = DDPMScheduler( 780 | num_train_timesteps=num_diffusion_iters, 781 | # the choise of beta schedule has big impact on performance 782 | # we found squared cosine works the best 783 | beta_schedule='squaredcos_cap_v2', 784 | # clip output to [-1,1] to improve stability 785 | clip_sample=True, 786 | # our network predicts noise (instead of denoised action) 787 | prediction_type='epsilon' 788 | ) 789 | 790 | 791 | self.nets = nets 792 | self.noise_scheduler = noise_scheduler 793 | self.num_diffusion_iters = num_diffusion_iters 794 | self.obs_horizon = obs_horizon 795 | self.obs_dim = obs_dim 796 | if not single_view or duplicate_view: 797 | self.vision_encoder2 = vision_encoder2 798 | if not cross_attn: 799 | self.vision_encoder = vision_encoder 800 | if force_encode: 801 | self.force_encoder = force_encoder 802 | self.noise_pred_net = noise_pred_net 803 | self.action_horizon = action_horizon 804 | self.pred_horizon = pred_horizon 805 | self.lowdim_obs_dim = lowdim_obs_dim 806 | self.action_dim = action_dim 807 | self.data_name = data_name 808 | 809 | 810 | 811 | def test(): 812 | # create dataset from file 813 | obs_horizon = 2 814 | dataset = RealRobotDataSet( 815 | dataset_path=dataset_path, 816 | pred_horizon=16, 817 | obs_horizon=obs_horizon, 818 | action_horizon=8, 819 | Transformer= False, 820 | force_mod = True, 821 | single_view= True 822 | ) 823 | # save training data statistics (min, max) for each dim 824 | stats = dataset.stats 825 | 826 | batch_size = 10 827 | # create dataloader 828 | dataloader = torch.utils.data.DataLoader( 829 | dataset, 830 | batch_size=batch_size, 831 | num_workers=4, 832 | shuffle=True, 833 | # accelerate cpu-gpu transfer 834 | pin_memory=True, 835 | # don't kill worker process afte each epoch 836 | persistent_workers=True 837 | ) 838 | 839 | batch = next(iter(dataloader)) 840 | print("batch['image'].shape:", batch['image'].shape) 841 | print("batch['image'].shape:", batch['image'].shape) 842 | 843 | # ### For debugging purposes 
uncomment 844 | # import matplotlib.pyplot as plt 845 | # imdata = dataset[100]['image'] 846 | # if imdata.dtype == np.float32 or imdata.dtype == np.float64: 847 | # imdata = imdata 848 | # img1 = imdata[0] 849 | # img2 = imdata[1] 850 | # # Loop through the two different "channels" 851 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 852 | # for i in range(2): 853 | # # Convert the 3x96x96 tensor to a 96x96x3 image (for display purposes) 854 | # img = np.transpose(imdata[i], (1, 2, 0)) 855 | 856 | # # Display the image in the i-th subplot 857 | # axes[i].imshow(img) 858 | # axes[i].set_title(f'Channel {i + 1}') 859 | # axes[i].axis('off') 860 | 861 | # # Show the plot 862 | # plt.show() 863 | 864 | print("batch['agent_pos'].shape:", batch['agent_pos'].shape) 865 | print("batch['force'].shape:", batch['force'].shape) 866 | print("batch['action'].shape", batch['action'].shape) 867 | image_input_shape = (3, 224, 224) 868 | force_dim = 4 869 | hidden_dim = 768 870 | 871 | import torch.optim as optim 872 | device = torch.device('cuda') 873 | # Standard ADAM optimizer 874 | # Note that EMA parametesr are not optimized 875 | model = CrossAttentionFusion(image_input_shape, force_dim, hidden_dim, batch_size = batch_size, obs_horizon=obs_horizon, force_encoder = "MLP", im_encoder= "viT") 876 | model = model.to(device) 877 | num_epochs = 10 # Set the number of epochs 878 | nimage = batch['image'][:,:2].to(device) 879 | nforce = batch['force'][:,:2].to(device) 880 | for epoch in range(num_epochs): 881 | # Example random input data for demonstration 882 | # image_input = nimage.flatten(end_dim=1).to(device) # Batch of 8 images 883 | # force_input = nforce.flatten(end_dim=1).to(device) 884 | # Batch of 8 force vectors 885 | 886 | # Forward pass 887 | latent_embedding = model(nimage, nforce) 888 | 889 | 890 | print(f'Epoch [{epoch+1}/{num_epochs}. {latent_embedding.shape}') 891 | ##TODO: Make sure that new CNN can work with the new architecture for CrossAttention 892 | # test() 893 | 894 | -------------------------------------------------------------------------------- /real_robot_network.py: -------------------------------------------------------------------------------- 1 | #@markdown ### **Imports** 2 | # diffusion policy import 3 | from typing import Tuple, Sequence, Dict, Union, Optional, Callable 4 | import numpy as np 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 9 | import os 10 | from data_util import RealRobotDataSet 11 | from train_utils import train_utils 12 | from data_util import center_crop 13 | import json 14 | from transformer_obs_encoder import SimpleRGBObsEncoder 15 | import timm 16 | from torchvision import transforms 17 | from torchvision.datasets import ImageFolder 18 | from torch.utils.data import ConcatDataset 19 | 20 | #@markdown ### **Network** 21 | #@markdown 22 | #@markdown Defines a 1D UNet architecture `ConditionalUnet1D` 23 | #@markdown as the noies prediction network 24 | #@markdown 25 | #@markdown Components 26 | #@markdown - `SinusoidalPosEmb` Positional encoding for the diffusion iteration k 27 | #@markdown - `Downsample1d` Strided convolution to reduce temporal resolution 28 | #@markdown - `Upsample1d` Transposed convolution to increase temporal resolution 29 | #@markdown - `Conv1dBlock` Conv1d --> GroupNorm --> Mish 30 | #@markdown - `ConditionalResidualBlock1D` Takes two inputs `x` and `cond`. 
\ 31 | #@markdown `x` is passed through 2 `Conv1dBlock` stacked together with residual connection. 32 | #@markdown `cond` is applied to `x` with [FiLM](https://arxiv.org/abs/1709.07871) conditioning. 33 | 34 | 35 | 36 | def cross_center_crop(images, crop_height, crop_width): 37 | # Get original dimensions: B (batch size), T (sequence length), C (channels), H (height), W (width) 38 | B, T, C, H, W = images.shape 39 | assert crop_height <= H and crop_width <= W, "Crop size should be smaller than the original size" 40 | 41 | # Calculate the center for height and width 42 | start_y = (H - crop_height) // 2 43 | start_x = (W - crop_width) // 2 44 | 45 | # Perform cropping for each image in the sequence 46 | cropped_images = images[:, :, :, start_y:start_y + crop_height, start_x:start_x + crop_width] 47 | 48 | return cropped_images 49 | 50 | class ForceEncoder(nn.Module): 51 | def __init__(self, force_dim, hidden_dim, batch_size, obs_horizon, force_encoder = "CNN", cross_attn = False, im_encoder = "resnet", train = True): 52 | super(ForceEncoder, self).__init__() 53 | self.cross_attn = cross_attn 54 | self.batch_size = batch_size 55 | self.obs_horizon = obs_horizon 56 | self.force_encoder = force_encoder 57 | self.train = train 58 | if im_encoder == "viT": 59 | force_hidden_dim = 768 60 | else: 61 | force_hidden_dim = 512 62 | print(f"force_encoder: {force_encoder}") 63 | # Force feature extraction with Group Normalization 64 | # Convolutional layers to encode force data with Group Normalization 65 | if force_encoder == "CNN": 66 | self.conv_encoder = nn.Sequential( 67 | nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1), 68 | nn.GroupNorm(num_groups=16, num_channels=32), 69 | nn.ReLU(), 70 | nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), 71 | nn.GroupNorm(num_groups=16, num_channels=64), 72 | nn.ReLU(), 73 | nn.Flatten() 74 | ) 75 | elif force_encoder == "Transformer": 76 | self.force_embedding = nn.Linear(4, force_hidden_dim) # Project 3D force to 512-dimensional embedding 77 | # Define a single Transformer Encoder Layer 78 | transformer_layer = nn.TransformerEncoderLayer(d_model=force_hidden_dim, nhead=8, batch_first= True) 79 | 80 | # Stack 6 layers of the Transformer Encoder Layer 81 | self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=6) 82 | self.fc = nn.Linear(force_hidden_dim, force_hidden_dim) # Optional final projection layer 83 | elif force_encoder == "MLP": 84 | if im_encoder == "viT": 85 | self.fc_encoder = nn.Sequential( 86 | nn.Linear(4, 64), # 4 force components -> 64 87 | nn.ReLU(), 88 | nn.Linear(64, 128), 89 | nn.ReLU(), 90 | nn.Linear(128, 256), # Output 512-dimensional feature 91 | nn.ReLU(), 92 | nn.Linear(256, 768) # Output 512-dimensional feature 93 | ) 94 | else: 95 | self.fc_encoder = nn.Sequential( 96 | nn.Linear(4, 64), # 4 force components -> 64 97 | nn.ReLU(), 98 | nn.Linear(64, 128), 99 | nn.ReLU(), 100 | nn.Linear(128, 512) # Output 512-dimensional feature 101 | ) 102 | elif force_encoder == "Linear": 103 | self.force_projection = nn.Linear(4, 512) 104 | 105 | self.projection_layer = nn.Linear(64 * force_dim, hidden_dim) 106 | 107 | def forward(self, x): 108 | if self.train: 109 | B,T,D = x.shape 110 | force_input = x.reshape(B*T, D) 111 | force_input = force_input.unsqueeze(1) 112 | else: 113 | force_input = x.unsqueeze(1) # Reshape to [batch_size, 1, input_dim] => [64, 1, 4] 114 | if self.force_encoder == "CNN": 115 | latent_vector = self.conv_encoder(force_input) 116 | 
latent_vector = self.projection_layer(latent_vector) # Shape: [batch_size, 512] 117 | elif self.force_encoder == "Transformer": 118 | embedded_force = self.force_embedding(force_input) # Shape: [seq_len, batch_size, 512] 119 | latent_vector = self.transformer_encoder(embedded_force) # Shape: [batch_size, embed_dim] 120 | # latent_vector = self.fc(encoded_force.mean(dim=0)) # Get the final 512-dimensional output 121 | elif self.force_encoder == "MLP": 122 | latent_vector = self.fc_encoder(force_input) 123 | elif self.force_encoder == "Linear": 124 | projected_force = self.force_projection(force_input) 125 | latent_vector = projected_force 126 | if self.train: 127 | latent_vector = latent_vector.reshape(int(B), self.obs_horizon, -1) 128 | else: 129 | latent_vector = latent_vector.squeeze(1) 130 | return latent_vector 131 | 132 | class CrossAttentionFusion(nn.Module): 133 | def __init__(self, image_dim, force_dim, hidden_dim= None, batch_size = 48, obs_horizon = 2, force_encoder = "CNN", im_encoder = "resnet", train=True): 134 | super(CrossAttentionFusion, self).__init__() 135 | self.obs_horizon = obs_horizon 136 | self.batch_size = batch_size 137 | self.im_encoder = im_encoder 138 | C,H,W = image_dim 139 | self.train = train 140 | # Image feature extraction 141 | # Image feature extraction layers 142 | if im_encoder == "CNN": 143 | self.image_encoder = nn.Sequential( 144 | nn.Conv2d(in_channels=C, out_channels=64, kernel_size=3, stride=1, padding=1), 145 | nn.GroupNorm(num_groups=16, num_channels=64), # Applying GroupNorm instead of BatchNorm 146 | nn.ReLU(), 147 | nn.MaxPool2d(kernel_size=2), 148 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 149 | nn.GroupNorm(num_groups=16, num_channels=128), # Applying GroupNorm to the second layer 150 | nn.ReLU(), 151 | nn.MaxPool2d(kernel_size=2), 152 | nn.Flatten() 153 | ) 154 | elif im_encoder == "resnet": 155 | self.image_encoder = train_utils().get_resnet("resnet18", weights= None) 156 | self.image_encoder = train_utils().replace_bn_with_gn(self.image_encoder) 157 | elif im_encoder == "viT": 158 | self.image_encoder = SimpleRGBObsEncoder() 159 | 160 | # train_utils().replace_bn_with_gn(self.image_encoder) 161 | # Dynamically calculate the image_dim after convolution and pooling 162 | with torch.no_grad(): 163 | sample_input = None 164 | if im_encoder == 'viT': 165 | sample_input = torch.zeros(1, 2, C, H, W) # Batch size of 1 166 | else: 167 | sample_input = torch.zeros(1, C, H, W) 168 | sample_output = self.image_encoder(sample_input) 169 | image_dim = sample_output.shape[1] # Get the flattened image dimension 170 | 171 | # Fully connected layer to map the image features to hidden_dim 172 | self.image_fc = nn.Linear(image_dim, hidden_dim) 173 | 174 | # Force feature extraction 175 | self.force_encoder = ForceEncoder(force_dim=force_dim, hidden_dim=hidden_dim, batch_size = batch_size, obs_horizon = obs_horizon, force_encoder=force_encoder, cross_attn=True, im_encoder = im_encoder, train = train) 176 | # Cross-attention layers 177 | self.attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=4) 178 | 179 | # Fusion layers to create joint embedding 180 | self.fusion_layer = nn.Sequential( 181 | nn.Linear(hidden_dim, hidden_dim), 182 | nn.ReLU(), 183 | nn.Linear(hidden_dim, hidden_dim) 184 | ) 185 | 186 | def forward(self, image_input, force_input): 187 | # Encode image and force data 188 | current_batch_size = image_input.size(0) 189 | if self.im_encoder == "viT": 190 | image_input = 
cross_center_crop(image_input, 224, 224) 191 | image_features = self.image_encoder(image_input) 192 | 193 | # image_features = self.image_fc(image_features) 194 | if self.im_encoder != "viT" and self.train: 195 | image_features = image_features.view(int(current_batch_size/2), self.obs_horizon, -1) 196 | if self.train: 197 | image_features = image_features.permute(1, 0, 2) # Correct shape: (num_images, batch_size, hidden_dim) 198 | 199 | # Reshape for attention: (sequence_length, batch_size, hidden_dim) 200 | 201 | force_features = self.force_encoder(force_input) 202 | if self.train: 203 | # force_features = force_features.view(batch_size, obs_horizon, -1) 204 | force_features = force_features.permute(1, 0, 2) # Correct shape: (num_forces, batch_size, hidden_dim) 205 | 206 | # Cross-attention operation 207 | attn_output, _ = self.attention(query=force_features, key=image_features, value=image_features) 208 | if self.train: 209 | attn_output = attn_output.permute(1, 0, 2) # Shape: (batch_size, num_forces, hidden_dim) 210 | 211 | # Generate the fused embedding 212 | joint_embedding = self.fusion_layer(attn_output) 213 | return joint_embedding 214 | 215 | class NumpyEncoder(json.JSONEncoder): 216 | def default(self, obj): 217 | if isinstance(obj, np.ndarray): 218 | return obj.tolist() 219 | return super(NumpyEncoder, self).default(obj) 220 | 221 | class SinusoidalPosEmb(nn.Module): 222 | def __init__(self, dim): 223 | super().__init__() 224 | self.dim = dim 225 | 226 | def forward(self, x): 227 | device = x.device 228 | half_dim = self.dim // 2 229 | emb = math.log(10000) / (half_dim - 1) 230 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 231 | emb = x[:, None] * emb[None, :] 232 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 233 | return emb 234 | 235 | 236 | class Downsample1d(nn.Module): 237 | def __init__(self, dim): 238 | super().__init__() 239 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 240 | 241 | def forward(self, x): 242 | return self.conv(x) 243 | 244 | class Upsample1d(nn.Module): 245 | def __init__(self, dim): 246 | super().__init__() 247 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 248 | 249 | def forward(self, x): 250 | return self.conv(x) 251 | 252 | 253 | class Conv1dBlock(nn.Module): 254 | ''' 255 | Conv1d --> GroupNorm --> Mish 256 | ''' 257 | 258 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 259 | super().__init__() 260 | 261 | self.block = nn.Sequential( 262 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 263 | nn.GroupNorm(n_groups, out_channels), 264 | nn.Mish(), 265 | ) 266 | 267 | def forward(self, x): 268 | return self.block(x) 269 | 270 | 271 | class ConditionalResidualBlock1D(nn.Module): 272 | def __init__(self, 273 | in_channels, 274 | out_channels, 275 | cond_dim, 276 | kernel_size=3, 277 | n_groups=8): 278 | super().__init__() 279 | 280 | self.blocks = nn.ModuleList([ 281 | Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), 282 | Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), 283 | ]) 284 | 285 | # FiLM modulation https://arxiv.org/abs/1709.07871 286 | # predicts per-channel scale and bias 287 | cond_channels = out_channels * 2 288 | self.out_channels = out_channels 289 | self.cond_encoder = nn.Sequential( 290 | nn.Mish(), 291 | nn.Linear(cond_dim, cond_channels), 292 | nn.Unflatten(-1, (-1, 1)) 293 | ) 294 | 295 | # make sure dimensions compatible 296 | self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \ 297 | if 
in_channels != out_channels else nn.Identity() 298 | 299 | def forward(self, x, cond): 300 | ''' 301 | x : [ batch_size x in_channels x horizon ] 302 | cond : [ batch_size x cond_dim] 303 | 304 | returns: 305 | out : [ batch_size x out_channels x horizon ] 306 | ''' 307 | out = self.blocks[0](x) 308 | embed = self.cond_encoder(cond) 309 | 310 | embed = embed.reshape( 311 | embed.shape[0], 2, self.out_channels, 1) 312 | scale = embed[:,0,...] 313 | bias = embed[:,1,...] 314 | out = scale * out + bias 315 | 316 | out = self.blocks[1](out) 317 | out = out + self.residual_conv(x) 318 | return out 319 | 320 | 321 | class ConditionalUnet1D(nn.Module): 322 | def __init__(self, 323 | input_dim, 324 | global_cond_dim, 325 | diffusion_step_embed_dim=256, 326 | down_dims=[256,512,1024], 327 | kernel_size=5, 328 | n_groups=8 329 | ): 330 | """ 331 | input_dim: Dim of actions. 332 | global_cond_dim: Dim of global conditioning applied with FiLM 333 | in addition to diffusion step embedding. This is usually obs_horizon * obs_dim 334 | diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k 335 | down_dims: Channel size for each UNet level. 336 | The length of this array determines numebr of levels. 337 | kernel_size: Conv kernel size 338 | n_groups: Number of groups for GroupNorm 339 | """ 340 | 341 | super().__init__() 342 | all_dims = [input_dim] + list(down_dims) 343 | start_dim = down_dims[0] 344 | 345 | dsed = diffusion_step_embed_dim 346 | diffusion_step_encoder = nn.Sequential( 347 | SinusoidalPosEmb(dsed), 348 | nn.Linear(dsed, dsed * 4), 349 | nn.Mish(), 350 | nn.Linear(dsed * 4, dsed), 351 | ) 352 | cond_dim = dsed + global_cond_dim 353 | 354 | in_out = list(zip(all_dims[:-1], all_dims[1:])) 355 | mid_dim = all_dims[-1] 356 | self.mid_modules = nn.ModuleList([ 357 | ConditionalResidualBlock1D( 358 | mid_dim, mid_dim, cond_dim=cond_dim, 359 | kernel_size=kernel_size, n_groups=n_groups 360 | ), 361 | ConditionalResidualBlock1D( 362 | mid_dim, mid_dim, cond_dim=cond_dim, 363 | kernel_size=kernel_size, n_groups=n_groups 364 | ), 365 | ]) 366 | 367 | down_modules = nn.ModuleList([]) 368 | for ind, (dim_in, dim_out) in enumerate(in_out): 369 | is_last = ind >= (len(in_out) - 1) 370 | down_modules.append(nn.ModuleList([ 371 | ConditionalResidualBlock1D( 372 | dim_in, dim_out, cond_dim=cond_dim, 373 | kernel_size=kernel_size, n_groups=n_groups), 374 | ConditionalResidualBlock1D( 375 | dim_out, dim_out, cond_dim=cond_dim, 376 | kernel_size=kernel_size, n_groups=n_groups), 377 | Downsample1d(dim_out) if not is_last else nn.Identity() 378 | ])) 379 | 380 | up_modules = nn.ModuleList([]) 381 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 382 | is_last = ind >= (len(in_out) - 1) 383 | up_modules.append(nn.ModuleList([ 384 | ConditionalResidualBlock1D( 385 | dim_out*2, dim_in, cond_dim=cond_dim, 386 | kernel_size=kernel_size, n_groups=n_groups), 387 | ConditionalResidualBlock1D( 388 | dim_in, dim_in, cond_dim=cond_dim, 389 | kernel_size=kernel_size, n_groups=n_groups), 390 | Upsample1d(dim_in) if not is_last else nn.Identity() 391 | ])) 392 | 393 | final_conv = nn.Sequential( 394 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), 395 | nn.Conv1d(start_dim, input_dim, 1), 396 | ) 397 | 398 | self.diffusion_step_encoder = diffusion_step_encoder 399 | self.up_modules = up_modules 400 | self.down_modules = down_modules 401 | self.final_conv = final_conv 402 | 403 | print("number of parameters: {:e}".format( 404 | sum(p.numel() for p in self.parameters())) 405 
| ) 406 | 407 | def forward(self, 408 | sample: torch.Tensor, 409 | timestep: Union[torch.Tensor, float, int], 410 | global_cond=None): 411 | """ 412 | x: (B,T,input_dim) 413 | timestep: (B,) or int, diffusion step 414 | global_cond: (B,global_cond_dim) 415 | output: (B,T,input_dim) 416 | """ 417 | # (B,T,C) 418 | sample = sample.moveaxis(-1,-2) 419 | # (B,C,T) 420 | 421 | # 1. time 422 | timesteps = timestep 423 | if not torch.is_tensor(timesteps): 424 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 425 | timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) 426 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: 427 | timesteps = timesteps[None].to(sample.device) 428 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 429 | timesteps = timesteps.expand(sample.shape[0]) 430 | 431 | global_feature = self.diffusion_step_encoder(timesteps) 432 | 433 | if global_cond is not None: 434 | global_feature = torch.cat([ 435 | global_feature, global_cond 436 | ], axis=-1) 437 | 438 | x = sample 439 | h = [] 440 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): 441 | x = resnet(x, global_feature) 442 | x = resnet2(x, global_feature) 443 | h.append(x) 444 | x = downsample(x) 445 | 446 | for mid_module in self.mid_modules: 447 | x = mid_module(x, global_feature) 448 | 449 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): 450 | x = torch.cat((x, h.pop()), dim=1) 451 | x = resnet(x, global_feature) 452 | x = resnet2(x, global_feature) 453 | x = upsample(x) 454 | 455 | x = self.final_conv(x) 456 | 457 | # (B,C,T) 458 | x = x.moveaxis(-1,-2) 459 | # (B,T,C) 460 | return x 461 | 462 | 463 | 464 | # download demonstration data from Google Drive 465 | # dataset_path = "pusht_cchi_v7_replay.zarr.zip" 466 | # if not os.path.isfile(dataset_path): 467 | # id = "1KY1InLurpMvJDRb14L9NlXT_fEsCvVUq&confirm=t" 468 | # gdown.download(id=id, output=dataset_path, quiet=False) 469 | 470 | def get_filename(input_string): 471 | # Find the last instance of '/' 472 | last_slash_index = input_string.rfind('/') 473 | 474 | # Get the substring after the last '/' 475 | if last_slash_index != -1: 476 | result = input_string[last_slash_index + 1:] 477 | # Return the substring without the last 4 characters 478 | return result[:-7] if len(result) > 7 else "" 479 | else: 480 | return "" 481 | 482 | 483 | dataset_path = "/home/jeon/jeon_ws/diffusion_policy/src/diffusion_cam/RAL_AAA+D_419.zarr.zip" 484 | 485 | #@markdown ### **Network Demo** 486 | class DiffusionPolicy_Real: 487 | def __init__(self, 488 | train=True, 489 | encoder = "resnet", 490 | action_def = "delta", 491 | force_mod:bool = False, 492 | single_view:bool = False, 493 | force_encode = False, 494 | force_encoder = "CNN", 495 | cross_attn: bool = False, 496 | hybrid: bool = False, 497 | duplicate_view = False, 498 | crop: int = 1000, 499 | augment = True): 500 | # action dimension should also correspond with the state dimension (x,y,z, x, y, z, w) 501 | action_dim = 9 502 | # parameters 503 | pred_horizon = 16 504 | obs_horizon = 2 505 | action_horizon = 8 506 | #|o|o| observations: 2 507 | #| |a|a|a|a|a|a|a|a| actions executed: 8 508 | #|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16 509 | batch_size = 100 510 | Transformer_bool = None 511 | modality = "without_force" 512 | view = "dual_view" 513 | if force_mod: 514 | modality = "with_force" 515 | if single_view: 516 | view = "single_view" 517 | # construct 
ResNet18 encoder 518 | # if you have multiple camera views, use seperate encoder weights for each view. 519 | # Resnet18 and resnet34 both have same dimension for the output 520 | # Define Second vision encoder 521 | 522 | if not single_view: 523 | vision_encoder2 = train_utils().get_resnet('resnet18') 524 | vision_encoder2 = train_utils().replace_bn_with_gn(vision_encoder2) 525 | if duplicate_view: 526 | vision_encoder2 = train_utils().get_resnet('resnet18') 527 | vision_encoder2 = train_utils().replace_bn_with_gn(vision_encoder2) 528 | 529 | if force_mod and force_encode: 530 | if encoder == "viT": 531 | hidden_dim_force = 768 532 | else: 533 | hidden_dim_force = 512 534 | force_encoder = ForceEncoder(4, hidden_dim_force, batch_size = batch_size, 535 | obs_horizon = obs_horizon, 536 | force_encoder= force_encoder, 537 | cross_attn=cross_attn, 538 | train=train) 539 | 540 | if cross_attn and force_mod: 541 | if encoder == "viT": 542 | cross_hidden_dim = 768 543 | image_dim = (3,224,224) 544 | else: 545 | cross_hidden_dim = 512 546 | if crop == 98: 547 | image_dim = (3,98,98) 548 | else: 549 | image_dim = (3,320,240) 550 | joint_encoder = CrossAttentionFusion(image_dim, 4, cross_hidden_dim, batch_size = batch_size, 551 | obs_horizon=obs_horizon, 552 | force_encoder = force_encoder, 553 | im_encoder = encoder, 554 | train = train) 555 | else: 556 | if encoder == "resnet": 557 | print("resnet") 558 | vision_encoder = train_utils().get_resnet('resnet18') 559 | vision_encoder = train_utils().replace_bn_with_gn(vision_encoder) 560 | 561 | elif encoder == "Transformer": 562 | Transformer_bool = True 563 | print("Imported Transformer clip model") 564 | vision_encoder = SimpleRGBObsEncoder() 565 | # IMPORTANT! 566 | # replace all BatchNorm with GroupNorm to work with EMA 567 | # performance will tank if you forget to do this! 
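# (Added note, assumption) A likely reason for the GroupNorm swap described above: EMAModel only tracks
# parameters(), not BatchNorm's running-mean/variance buffers, so an EMA copy of a BatchNorm network can
# end up paired with mismatched statistics at eval time. GroupNorm keeps no running statistics, which is
# why replace_bn_with_gn() is applied to every ResNet encoder created here.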
568 | # ResNet18 has output dim of 512 X 2 because two views 569 | if single_view: 570 | if encoder == "viT": 571 | vision_feature_dim = 768 572 | else: 573 | vision_feature_dim = 512 574 | else: 575 | if encoder == "viT": 576 | vision_feature_dim = 768 + 512 577 | else: 578 | vision_feature_dim = 512 + 512 579 | 580 | if force_encode: 581 | force_feature_dim = 512 582 | else: 583 | force_feature_dim = 4 584 | # agent_pos is seven (x,y,z, w, y, z, w ) dimensional 585 | lowdim_obs_dim = 9 586 | # observation feature has 514 dims in total per step 587 | if force_mod and not cross_attn: 588 | obs_dim = vision_feature_dim + force_feature_dim + lowdim_obs_dim 589 | elif force_mod and cross_attn and not duplicate_view: 590 | obs_dim = vision_feature_dim + lowdim_obs_dim 591 | elif force_mod and cross_attn and duplicate_view: 592 | obs_dim = vision_feature_dim * 2 + lowdim_obs_dim 593 | else: 594 | obs_dim = vision_feature_dim + lowdim_obs_dim 595 | if hybrid: 596 | obs_dim += 4 597 | 598 | data_name = get_filename(dataset_path) 599 | 600 | if train: 601 | # create dataset from file 602 | dataset = RealRobotDataSet( 603 | dataset_path=dataset_path, 604 | pred_horizon=pred_horizon, 605 | obs_horizon=obs_horizon, 606 | action_horizon=action_horizon, 607 | Transformer= Transformer_bool, 608 | force_mod = force_mod, 609 | single_view=single_view, 610 | augment = False, 611 | duplicate_view = duplicate_view, 612 | crop = crop 613 | ) 614 | # save training data statistics (min, max) for each dim 615 | 616 | 617 | 618 | 619 | if augment: 620 | dataset_augmented = RealRobotDataSet( 621 | dataset_path=dataset_path, 622 | pred_horizon=pred_horizon, 623 | obs_horizon=obs_horizon, 624 | action_horizon=action_horizon, 625 | Transformer= Transformer_bool, 626 | force_mod = force_mod, 627 | single_view=single_view, 628 | augment = True, 629 | duplicate_view = duplicate_view, 630 | crop = crop 631 | ) 632 | 633 | 634 | combined_dataset = ConcatDataset([dataset, dataset_augmented]) 635 | 636 | # DataLoader for combined dataset 637 | data_loader_combined = torch.utils.data.DataLoader( 638 | combined_dataset, 639 | batch_size=batch_size, 640 | num_workers=4, 641 | shuffle=True, # Shuffle to mix normal and augmented data 642 | pin_memory=True, 643 | persistent_workers=True 644 | ) 645 | self.dataloader = data_loader_combined 646 | batch = next(iter(data_loader_combined)) 647 | stats = dataset_augmented.stats 648 | 649 | else: 650 | # create dataloader 651 | dataloader = torch.utils.data.DataLoader( 652 | dataset, 653 | batch_size=batch_size, 654 | num_workers=4, 655 | shuffle=True, 656 | # accelerate cpu-gpu transfer 657 | pin_memory=True, 658 | # don't kill worker process afte each epoch 659 | persistent_workers=True, 660 | ) 661 | self.dataloader = dataloader 662 | batch = next(iter(dataloader)) 663 | stats = dataset.stats 664 | 665 | # Save the stats to a file 666 | with open(f'stats_{data_name}_{encoder}_{action_def}_{modality}_vn.json', 'w') as f: 667 | json.dump(stats, f, cls=NumpyEncoder) 668 | print("stats saved") 669 | 670 | # self.dataloader = data_loader_augmented 671 | # self.data_loader_augmented = data_loader_augmented 672 | self.stats = stats 673 | 674 | #### For debugging purposes uncomment 675 | # import matplotlib.pyplot as plt 676 | # imdata = dataset[100]['image'] 677 | # if imdata.dtype == np.float32 or imdata.dtype == np.float64: 678 | # imdata = imdata / 255.0 679 | # img1 = imdata[0] 680 | # img2 = imdata[1] 681 | # # Loop through the two different "channels" 682 | # fig, axes = 
plt.subplots(1, 2, figsize=(10, 5)) 683 | # for i in range(2): 684 | # # Convert the 3x96x96 tensor to a 96x96x3 image (for display purposes) 685 | # img = np.transpose(imdata[i], (1, 2, 0)) 686 | 687 | # # Display the image in the i-th subplot 688 | # axes[i].imshow(img) 689 | # axes[i].set_title(f'Channel {i + 1}') 690 | # axes[i].axis('off') 691 | 692 | # # Show the plot 693 | # plt.show() 694 | 695 | # # Check if both images are exactly the same 696 | # are_equal = np.array_equal(img1, img2) 697 | 698 | # if are_equal: 699 | # print("The images are the same.") 700 | # else: 701 | # print("The images are different.") 702 | ######### End ######## 703 | # visualize data in batch 704 | print("batch['image'].shape:", batch['image'].shape) 705 | if not single_view: 706 | print("batch[image2].shape", batch["image2"].shape) 707 | if duplicate_view: 708 | print("batch[image_duplicate].shape", batch["duplicate_image"].shape) 709 | 710 | print("batch['agent_pos'].shape:", batch['agent_pos'].shape) 711 | 712 | if force_mod: 713 | print("batch['force'].shape:", batch['force'].shape) 714 | 715 | print("batch['action'].shape", batch['action'].shape) 716 | self.batch = batch 717 | 718 | # create network object 719 | noise_pred_net = ConditionalUnet1D( 720 | input_dim=action_dim, 721 | global_cond_dim=obs_dim*obs_horizon 722 | ) 723 | if single_view and not force_mod and not force_encode and not cross_attn: 724 | # the final arch has 2 parts 725 | nets = nn.ModuleDict({ 726 | 'vision_encoder': vision_encoder, 727 | 'noise_pred_net': noise_pred_net 728 | }) 729 | elif single_view and force_mod and not force_encode and not cross_attn: 730 | # the final arch has 2 parts 731 | nets = nn.ModuleDict({ 732 | 'vision_encoder': vision_encoder, 733 | 'noise_pred_net': noise_pred_net 734 | }) 735 | elif single_view and force_encode: 736 | # the final arch has 2 parts 737 | nets = nn.ModuleDict({ 738 | 'vision_encoder': vision_encoder, 739 | 'force_encoder': force_encoder, 740 | 'noise_pred_net': noise_pred_net 741 | }) 742 | elif not single_view and force_encode: 743 | nets = nn.ModuleDict({ 744 | 'vision_encoder': vision_encoder, 745 | 'vision_encoder2': vision_encoder2, 746 | 'force_encoder': force_encoder, 747 | 'noise_pred_net': noise_pred_net 748 | }) 749 | elif not single_view and not force_encode and not cross_attn: 750 | nets = nn.ModuleDict({ 751 | 'vision_encoder': vision_encoder, 752 | 'vision_encoder2': vision_encoder2, 753 | 'noise_pred_net': noise_pred_net 754 | }) 755 | elif single_view and cross_attn and not duplicate_view: 756 | nets = nn.ModuleDict({ 757 | 'cross_attn_encoder': joint_encoder, 758 | 'noise_pred_net': noise_pred_net 759 | }) 760 | elif not single_view and cross_attn and not force_encode: 761 | nets = nn.ModuleDict({ 762 | 'cross_attn_encoder': joint_encoder, 763 | 'vision_encoder2': vision_encoder2, 764 | 'noise_pred_net': noise_pred_net 765 | }) 766 | elif single_view and duplicate_view and cross_attn and not force_encode: 767 | nets = nn.ModuleDict({ 768 | 'cross_attn_encoder': joint_encoder, 769 | 'vision_encoder2': vision_encoder2, 770 | 'noise_pred_net': noise_pred_net 771 | }) 772 | elif cross_attn and force_encode: 773 | print("Cross attn and force encode cannot be True at the same time") 774 | 775 | 776 | # diffusion iteration 777 | num_diffusion_iters = 100 778 | 779 | noise_scheduler = DDPMScheduler( 780 | num_train_timesteps=num_diffusion_iters, 781 | # the choise of beta schedule has big impact on performance 782 | # we found squared cosine works the best 783 | 
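# (Added note) 'squaredcos_cap_v2' is the diffusers name for the squared-cosine beta schedule
# introduced in Nichol & Dhariwal, "Improved Denoising Diffusion Probabilistic Models" (2021).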
beta_schedule='squaredcos_cap_v2', 784 | # clip output to [-1,1] to improve stability 785 | clip_sample=True, 786 | # our network predicts noise (instead of denoised action) 787 | prediction_type='epsilon' 788 | ) 789 | 790 | 791 | self.nets = nets 792 | self.noise_scheduler = noise_scheduler 793 | self.num_diffusion_iters = num_diffusion_iters 794 | self.obs_horizon = obs_horizon 795 | self.obs_dim = obs_dim 796 | if not single_view or duplicate_view: 797 | self.vision_encoder2 = vision_encoder2 798 | if not cross_attn: 799 | self.vision_encoder = vision_encoder 800 | if force_encode: 801 | self.force_encoder = force_encoder 802 | self.noise_pred_net = noise_pred_net 803 | self.action_horizon = action_horizon 804 | self.pred_horizon = pred_horizon 805 | self.lowdim_obs_dim = lowdim_obs_dim 806 | self.action_dim = action_dim 807 | self.data_name = data_name 808 | 809 | 810 | 811 | def test(): 812 | # create dataset from file 813 | obs_horizon = 2 814 | dataset = RealRobotDataSet( 815 | dataset_path=dataset_path, 816 | pred_horizon=16, 817 | obs_horizon=obs_horizon, 818 | action_horizon=8, 819 | Transformer= False, 820 | force_mod = True, 821 | single_view= True 822 | ) 823 | # save training data statistics (min, max) for each dim 824 | stats = dataset.stats 825 | 826 | batch_size = 10 827 | # create dataloader 828 | dataloader = torch.utils.data.DataLoader( 829 | dataset, 830 | batch_size=batch_size, 831 | num_workers=4, 832 | shuffle=True, 833 | # accelerate cpu-gpu transfer 834 | pin_memory=True, 835 | # don't kill worker process afte each epoch 836 | persistent_workers=True 837 | ) 838 | 839 | batch = next(iter(dataloader)) 840 | print("batch['image'].shape:", batch['image'].shape) 841 | print("batch['image'].shape:", batch['image'].shape) 842 | 843 | # ### For debugging purposes uncomment 844 | # import matplotlib.pyplot as plt 845 | # imdata = dataset[100]['image'] 846 | # if imdata.dtype == np.float32 or imdata.dtype == np.float64: 847 | # imdata = imdata 848 | # img1 = imdata[0] 849 | # img2 = imdata[1] 850 | # # Loop through the two different "channels" 851 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 852 | # for i in range(2): 853 | # # Convert the 3x96x96 tensor to a 96x96x3 image (for display purposes) 854 | # img = np.transpose(imdata[i], (1, 2, 0)) 855 | 856 | # # Display the image in the i-th subplot 857 | # axes[i].imshow(img) 858 | # axes[i].set_title(f'Channel {i + 1}') 859 | # axes[i].axis('off') 860 | 861 | # # Show the plot 862 | # plt.show() 863 | 864 | print("batch['agent_pos'].shape:", batch['agent_pos'].shape) 865 | print("batch['force'].shape:", batch['force'].shape) 866 | print("batch['action'].shape", batch['action'].shape) 867 | image_input_shape = (3, 224, 224) 868 | force_dim = 4 869 | hidden_dim = 768 870 | 871 | import torch.optim as optim 872 | device = torch.device('cuda') 873 | # Standard ADAM optimizer 874 | # Note that EMA parametesr are not optimized 875 | model = CrossAttentionFusion(image_input_shape, force_dim, hidden_dim, batch_size = batch_size, obs_horizon=obs_horizon, force_encoder = "MLP", im_encoder= "viT") 876 | model = model.to(device) 877 | num_epochs = 10 # Set the number of epochs 878 | nimage = batch['image'][:,:2].to(device) 879 | nforce = batch['force'][:,:2].to(device) 880 | for epoch in range(num_epochs): 881 | # Example random input data for demonstration 882 | # image_input = nimage.flatten(end_dim=1).to(device) # Batch of 8 images 883 | # force_input = nforce.flatten(end_dim=1).to(device) 884 | # Batch of 8 force vectors 
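# (Added note) Unlike the ResNet path in train_real.py, the viT cross-attention model below is fed the
# image and force tensors for the first obs_horizon frames directly, which is why the flatten(end_dim=1)
# calls above are left commented out.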
885 | 886 | # Forward pass 887 | latent_embedding = model(nimage, nforce) 888 | 889 | 890 | print(f'Epoch [{epoch+1}/{num_epochs}. {latent_embedding.shape}') 891 | ##TODO: Make sure that new CNN can work with the new architecture for CrossAttention 892 | # test() 893 | 894 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | torchvision==0.14.1 3 | diffusers==0.18.2 4 | scikit-image==0.19.3 5 | scikit-video==1.1.11 6 | zarr==2.12.0 7 | numcodecs==0.10.2 8 | pygame==2.1.2 9 | pymunk==6.2.1 10 | gym==0.26.2 11 | shapely==1.8.4 12 | opencv-python 13 | gdown -------------------------------------------------------------------------------- /robot_data_collection_joint_state.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import csv 3 | import math 4 | import os 5 | import time 6 | import numpy as np 7 | import pandas as pd 8 | import rclpy 9 | from rclpy.node import Node 10 | from rclpy.action import ActionClient 11 | from trajectory_msgs.msg import JointTrajectory, JointTrajectoryPoint 12 | from control_msgs.action import FollowJointTrajectory 13 | from geometry_msgs.msg import Pose 14 | from moveit_msgs.srv import GetPositionFK 15 | from moveit_msgs.msg import MoveItErrorCodes 16 | from sensor_msgs.msg import JointState 17 | from geometry_msgs.msg import WrenchStamped 18 | from submodules.wait_for_message import wait_for_message 19 | from moveit_msgs.srv import GetPositionIK, GetMotionPlan, GetPlanningScene, ApplyPlanningScene 20 | from moveit_msgs.msg import ( 21 | RobotState, 22 | RobotTrajectory, 23 | MoveItErrorCodes, 24 | Constraints, 25 | JointConstraint, 26 | PlanningScene, 27 | CollisionObject 28 | ) 29 | import time 30 | import pyrealsense2 as rs 31 | import cv2 32 | 33 | def create_directories(): 34 | os.makedirs("/home/lm-2023/jeon_team_ws/playback_pose/src/data_collection/data_collection/images_A", exist_ok=True) 35 | os.makedirs("/home/lm-2023/jeon_team_ws/playback_pose/src/data_collection/data_collection/images_B", exist_ok=True) 36 | 37 | def read_joint_states_from_csv(file_path): 38 | joint_states = [] 39 | with open(file_path, mode='r') as file: 40 | csv_reader = csv.DictReader(file) 41 | for row in csv_reader: 42 | joint_states.append([math.radians(float(row['A1'])), math.radians(float(row['A2'])), math.radians(float(row['A3'])), 43 | math.radians(float(row['A4'])), math.radians(float(row['A5'])), math.radians(float(row['A6'])), math.radians(float(row['A7']))]) 44 | return joint_states 45 | 46 | 47 | class KukaMotionPlanning(Node): 48 | timeout_sec_ = 5.0 49 | move_group_name_ = "arm" 50 | namespace_ = "lbr" 51 | joint_state_topic_ = "/lbr/joint_states" 52 | force_torque_topic_ = "/lbr/force_torque_broadcaster/wrench" 53 | plan_srv_name_ = "plan_kinematic_path" 54 | fk_srv_name_ = "lbr/compute_fk" 55 | execute_action_name_ = "execute_trajectory" 56 | get_planning_scene_srv_name = "get_planning_scene" 57 | 58 | apply_planning_scene_srv_name = "apply_planning_scene" 59 | base_ = "link_0" 60 | end_effector_ = "link_ee" 61 | 62 | def __init__(self, Transformer = False): 63 | super().__init__('kuka_motion_planning') 64 | self.joint_states = None 65 | self.force_torque_data = None # Initialize force/torque data container 66 | self._action_client = ActionClient(self, FollowJointTrajectory, '/lbr/joint_trajectory_controller/follow_joint_trajectory') 67 | self.joint_names = ["A1", 
"A2", "A3", "A4", "A5", "A6", "A7"] 68 | self.joint_trajectories = read_joint_states_from_csv('/home/lm-2023/Downloads/kuka_replay_force_test/2024-10-29_17-53-12.csv') 69 | self.initialize_robot_pose() 70 | create_directories() 71 | 72 | self.fk_client_ = self.create_client(GetPositionFK, self.fk_srv_name_) 73 | self.force_torque_subscriber = self.create_subscription(WrenchStamped, self.force_torque_topic_, self.force_torque_callback, 10) 74 | self.send_goal() 75 | self.feedback = None 76 | self.Transformer = Transformer 77 | 78 | def force_torque_callback(self, msg): 79 | self.force_torque_data = msg.wrench 80 | 81 | def joint_state_callback(self, msg): 82 | with self.lock: 83 | self.joint_states = msg 84 | self.joint_positions = msg.position 85 | 86 | def initialize_fk_service(self): 87 | self.fk_client_ = self.create_client(GetPositionFK, 'compute_fk') 88 | if not self.fk_client_.wait_for_service(timeout_sec=5.0): 89 | self.get_logger().error("FK service not available.") 90 | exit(1) 91 | 92 | def get_fk(self) -> Pose | None: 93 | current_joint_state = self.get_joint_state() 94 | if current_joint_state is None: 95 | self.get_logger().error("Failed to get current joint state!!!") 96 | return None 97 | current_robot_state = RobotState() 98 | current_robot_state.joint_state = current_joint_state 99 | request = GetPositionFK.Request() 100 | request.header.frame_id = f"{self.namespace_}/{self.base_}" 101 | request.header.stamp = self.get_clock().now().to_msg() 102 | request.fk_link_names.append(self.end_effector_) 103 | request.robot_state = current_robot_state 104 | future = self.fk_client_.call_async(request) 105 | rclpy.spin_until_future_complete(self, future) 106 | if future.result() is None: 107 | self.get_logger().error("Failed to get FK solution") 108 | return None 109 | response = future.result() 110 | if response.error_code.val != MoveItErrorCodes.SUCCESS: 111 | self.get_logger().error(f"Failed to get FK solution: {response.error_code.val}") 112 | return None 113 | return response.pose_stamped[0].pose 114 | 115 | def convert_joints_to_actions(self, joint_values: list[float]) -> Pose | None: 116 | # Create a RobotState message and manually assign joint values 117 | current_robot_state = RobotState() 118 | current_robot_state.joint_state.name = self.joint_names # Set joint names appropriately 119 | current_robot_state.joint_state.position = joint_values # Set the provided joint values 120 | 121 | # Create the FK request 122 | request = GetPositionFK.Request() 123 | request.header.frame_id = f"{self.namespace_}/{self.base_}" 124 | request.header.stamp = self.get_clock().now().to_msg() 125 | request.fk_link_names.append(self.end_effector_) # The end-effector link name 126 | request.robot_state = current_robot_state # Assign the robot state with joint values 127 | 128 | # Call the FK service asynchronously 129 | future = self.fk_client_.call_async(request) 130 | rclpy.spin_until_future_complete(self, future) 131 | 132 | # Check if the FK service call was successful 133 | if future.result() is None: 134 | self.get_logger().error("Failed to get action solution") 135 | return None 136 | 137 | response = future.result() 138 | if response.error_code.val != MoveItErrorCodes.SUCCESS: 139 | self.get_logger().error(f"Failed to get action solution: {response.error_code.val}") 140 | return None 141 | 142 | # Return the pose of the end-effector 143 | return response.pose_stamped[0].pose 144 | 145 | def get_joint_state(self) -> JointState: 146 | current_joint_state_set, current_joint_state = 
wait_for_message(JointState, self, self.joint_state_topic_) 147 | if not current_joint_state_set: 148 | self.get_logger().error("Failed to get current joint state") 149 | return None 150 | return current_joint_state 151 | 152 | def record_images(self, image_counter): 153 | global pipeline_A, pipeline_B, align_A, align_B 154 | 155 | frames_A = pipeline_A.wait_for_frames() 156 | aligned_frames_A = align_A.process(frames_A) 157 | color_frame_A = aligned_frames_A.get_color_frame() 158 | 159 | frames_B = pipeline_B.wait_for_frames() 160 | aligned_frames_B = align_B.process(frames_B) 161 | color_frame_B = aligned_frames_B.get_color_frame() 162 | 163 | if not color_frame_A or not color_frame_B: 164 | print("Could not acquire frames from RealSense cameras") 165 | return 166 | 167 | color_image_A = np.asanyarray(color_frame_A.get_data()) 168 | color_image_B = np.asanyarray(color_frame_B.get_data()) 169 | image_filename_A = f"camera_A_{image_counter}.png" 170 | color_image_A = cv2.resize(color_image_A, (320, 240), interpolation=cv2.INTER_AREA) 171 | cv2.imwrite(f"/home/lm-2023/jeon_team_ws/playback_pose/src/data_collection/data_collection/images_A/{image_filename_A}", color_image_A) 172 | 173 | 174 | # Get the image dimensions 175 | height_B, width_B, _ = color_image_B.shape 176 | 177 | # Define the center point 178 | center_x, center_y = width_B // 2, height_B // 2 179 | 180 | # Define the crop size 181 | crop_width, crop_height = 320, 240 182 | 183 | # Calculate the top-left corner of the crop box 184 | x1 = max(center_x - crop_width // 2, 0) 185 | y1 = max(center_y - crop_height // 2, 0) 186 | 187 | # Calculate the bottom-right corner of the crop box 188 | x2 = min(center_x + crop_width // 2, width_B) 189 | y2 = min(center_y + crop_height // 2, height_B) 190 | cropped_image_B = color_image_B[y1:y2, x1:x2] 191 | # resized_image_B = cv2.resize(cropped_image_B, (224, 224), interpolation=cv2.INTER_AREA) 192 | 193 | image_filename_B = f"camera_B_{image_counter}.png" 194 | # color_image_B = cv2.resize(color_image_B, (320, 240), interpolation=cv2.INTER_AREA) 195 | cv2.imwrite(f"/home/lm-2023/jeon_team_ws/playback_pose/src/data_collection/data_collection/images_B/{image_filename_B}", cropped_image_B) 196 | complete = os.path.exists(f"/home/lm-2023/jeon_team_ws/playback_pose/src/data_collection/data_collection/images_B/{image_filename_B}") 197 | 198 | while complete == False: 199 | complete = os.path.exists(f"/home/lm-2023/jeon_team_ws/playback_pose/src/data_collection/data_collection/images_B/{image_filename_B}") 200 | print(f"Image Saved, {complete}") 201 | print("Image A and B saved") 202 | 203 | def initialize_robot_pose(self): 204 | # Create a FollowJointTrajectory goal message 205 | goal_msg = FollowJointTrajectory.Goal() 206 | trajectory_msg = JointTrajectory() 207 | trajectory_msg.joint_names = self.joint_names 208 | 209 | self.get_logger().info(f"Processing trajectory point init") 210 | 211 | # Create a JointTrajectoryPoint and assign positions 212 | point = JointTrajectoryPoint() 213 | point.positions = self.joint_trajectories[0] 214 | # point.time_from_start.sec = 1 # Set the time to reach the point (you can modify this) 215 | point.time_from_start.sec = 8 # Set the seconds part to 0 216 | # point.time_from_start.nanosec = int(0.5 * 1e9) # Set the nanoseconds part to 750,000,000 217 | 218 | # Add the point to the trajectory 219 | trajectory_msg.points.append(point) 220 | goal_msg.trajectory = trajectory_msg 221 | self._action_client.wait_for_server() 222 | 223 | # Send the goal asynchronously 
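# (Added note) Although send_goal_async() returns a future immediately, the spin_until_future_complete()
# calls below block until the controller accepts and finishes this goal, and input() then waits for
# operator confirmation, so the arm is at the first CSV waypoint before send_goal() starts logging
# state-action pairs.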
224 | send_goal_future = self._action_client.send_goal_async(goal_msg, feedback_callback=self.feedback_callback) 225 | 226 | rclpy.spin_until_future_complete(self, send_goal_future) 227 | 228 | goal_handle = send_goal_future.result() 229 | 230 | if not goal_handle.accepted: 231 | self.get_logger().info(f"Goal for init point was rejected") 232 | return 233 | 234 | self.get_logger().info(f"Goal for init point was accepted") 235 | # Wait for the result to complete before moving to the next trajectory point 236 | get_result_future = goal_handle.get_result_async() 237 | rclpy.spin_until_future_complete(self, get_result_future) 238 | result = get_result_future.result().result 239 | self.get_logger().info(f"Result : {result}, Initial position reached") 240 | input() 241 | 242 | 243 | def send_goal(self): 244 | # Iterate through each joint trajectory point one by one 245 | for i, joint_values in enumerate(self.joint_trajectories): 246 | # Create a FollowJointTrajectory goal message 247 | goal_msg = FollowJointTrajectory.Goal() 248 | trajectory_msg = JointTrajectory() 249 | trajectory_msg.joint_names = self.joint_names 250 | 251 | self.get_logger().info(f"Processing trajectory point {i}") 252 | 253 | # Create a JointTrajectoryPoint and assign positions 254 | point = JointTrajectoryPoint() 255 | point.positions = joint_values 256 | # point.time_from_start.sec = 1 # Set the time to reach the point (you can modify this) 257 | point.time_from_start.sec = 0 # Set the seconds part to 0 258 | point.time_from_start.nanosec = int(0.5 * 1e9) # Set the nanoseconds part to 500,000,000 259 | 260 | # Add the point to the trajectory 261 | trajectory_msg.points.append(point) 262 | goal_msg.trajectory = trajectory_msg 263 | self._action_client.wait_for_server() 264 | 265 | # Prerecord the states and action pair 266 | fk_time = time.time() 267 | robot_pose = self.get_fk() 268 | fk_end_time = time.time() 269 | fk_duration = fk_end_time-fk_time 270 | if robot_pose is not None: 271 | robot_state = [robot_pose.position.x, robot_pose.position.y, robot_pose.position.z, 272 | robot_pose.orientation.x, robot_pose.orientation.y, robot_pose.orientation.z, robot_pose.orientation.w] 273 | else: 274 | robot_state = [None] * 7 # No valid robot pose available 275 | print(f"fk, {robot_state}") 276 | action_pose = self.convert_joints_to_actions(joint_values) 277 | print(action_pose) 278 | if action_pose is not None: 279 | robot_action = [action_pose.position.x, action_pose.position.y, action_pose.position.z, 280 | action_pose.orientation.x, action_pose.orientation.y, action_pose.orientation.z, action_pose.orientation.w] 281 | else: 282 | robot_action = [None] * 7 # No valid action pose available 283 | # Get force/torque data 284 | if self.force_torque_data is not None: 285 | force_torque_data = [self.force_torque_data.force.x, self.force_torque_data.force.y, self.force_torque_data.force.z,self.force_torque_data.torque.x, self.force_torque_data.torque.y, self.force_torque_data.torque.z] 286 | else: 287 | force_torque_data = [None] * 6 # No valid force/torque data available 288 | print(f"fk_duration, {fk_duration}") 289 | print(self.force_torque_data) 290 | self.record_images(i) 291 | self.write_csv(robot_state, robot_action, force_torque_data) 292 | 293 | # Record the time just before sending the goal 294 | curr_time = time.time() 295 | 296 | # Send the goal asynchronously 297 | send_goal_future = self._action_client.send_goal_async(goal_msg, feedback_callback=self.feedback_callback) 298 | after_exec = time.time() 299 |
print(f"execution time, {after_exec - curr_time}") 300 | 301 | rclpy.spin_until_future_complete(self, send_goal_future) 302 | 303 | goal_handle = send_goal_future.result() 304 | 305 | if not goal_handle.accepted: 306 | self.get_logger().info(f"Goal for point {i} was rejected") 307 | return 308 | 309 | self.get_logger().info(f"Goal for point {i} was accepted") 310 | # Wait for the result to complete before moving to the next trajectory point 311 | get_result_future = goal_handle.get_result_async() 312 | rclpy.spin_until_future_complete(self, get_result_future) 313 | result = get_result_future.result().result 314 | 315 | self.get_logger().info(f"Result for point {i}: {result}") 316 | 317 | # Introduce a delay if necessary before moving to the next point 318 | # time.sleep(0.5) # Optional: Modify the sleep time if required 319 | 320 | self.get_logger().info('All joint trajectories have been processed.') 321 | 322 | def write_csv(self, robot_state_data, robot_action_data, force_torque_data): 323 | timestamp = time.time() 324 | # Append the new data to the CSV file 325 | csv_filename = '/home/lm-2023/jeon_team_ws/playback_pose/src/data_collection/data_collection/robot_poses1.csv' 326 | new_data = [robot_state_data + robot_action_data + force_torque_data + [timestamp]] 327 | 328 | # Check if the file exists to determine if headers are needed 329 | file_exists = os.path.isfile(csv_filename) 330 | 331 | with open(csv_filename, mode='a', newline='') as file: 332 | writer = csv.writer(file) 333 | if not file_exists: 334 | # Write the header if the file doesn't exist 335 | writer.writerow(['st_robot_x', 'st_robot_y', 'st_robot_z', 'st_robot_qx', 'st_robot_qy', 'st_robot_qz', 'st_robot_qw','ac_robot_x', 'ac_robot_y', 'ac_robot_z', 'ac_robot_qx', 'ac_robot_qy', 'ac_robot_qz', 'ac_robot_qw', 336 | 'force_x', 'force_y', 'force_z','torque_x', 'torque_y', 'torque_z', 'timestamp']) 337 | # Write the data 338 | writer.writerows(new_data) 339 | 340 | 341 | # def goal_response_callback(self, future): 342 | # goal_handle = future.result() 343 | # if not goal_handle.accepted: 344 | # self.get_logger().info('Goal rejected :(') 345 | # return 346 | # self.get_logger().info('Goal accepted :)') 347 | # self._get_result_future = goal_handle.get_result_async() 348 | # self._get_result_future.add_done_callback(self.get_result_callback) 349 | 350 | def feedback_callback(self, feedback_msg): 351 | self.feedback = feedback_msg.feedback 352 | 353 | # def get_result_callback(self, future): 354 | # result = future.result().result 355 | # self.get_logger().info(f'Result: {result}') 356 | 357 | 358 | pipeline_A = None 359 | pipeline_B = None 360 | align_A = None 361 | align_B = None 362 | 363 | 364 | 365 | def initialize_cameras(): 366 | global pipeline_A, pipeline_B, align_A, align_B 367 | 368 | pipeline_A = rs.pipeline() 369 | pipeline_B = rs.pipeline() 370 | 371 | context = rs.context() 372 | devices = context.query_devices() 373 | 374 | if len(devices) < 2: 375 | raise RuntimeError("Two cameras are required, but fewer were detected.") 376 | 377 | serial_A = devices[1].get_info(rs.camera_info.serial_number) 378 | serial_B = devices[0].get_info(rs.camera_info.serial_number) 379 | 380 | config_A = rs.config() 381 | config_A.enable_device(serial_A) 382 | config_A.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30) 383 | 384 | config_B = rs.config() 385 | config_B.enable_device(serial_B) 386 | config_B.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30) 387 | 388 | pipeline_A.start(config_A) 389 | 
pipeline_B.start(config_B) 390 | 391 | align_A = rs.align(rs.stream.color) 392 | align_B = rs.align(rs.stream.color) 393 | # Initialize once 394 | 395 | 396 | def main(args=None): 397 | try: 398 | rclpy.init(args=args) 399 | except Exception as e: 400 | print(f"Failed to initialize rclpy: {e}") 401 | return 402 | 403 | try: 404 | initialize_cameras() 405 | node = KukaMotionPlanning(Transformer = True) 406 | rclpy.spin(node) 407 | except Exception as e: 408 | print(f"Exception during ROS2 node operation: {e}") 409 | 410 | if __name__ == '__main__': 411 | main() 412 | -------------------------------------------------------------------------------- /rotation_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.spatial.transform as st 3 | from rotation_utils import rot6d_to_mat, mat_to_rot6d 4 | 5 | # converting full name to scipy Rotation name 6 | scipy_rep_map = { 7 | 'axis_angle': 'rotvec', 8 | 'quaternion': 'quat', 9 | 'matrix': 'matrix', 10 | 'rotation_6d': None 11 | } 12 | 13 | def transform_rotation(x, from_rep, to_rep): 14 | from_rep = scipy_rep_map[from_rep] 15 | to_rep = scipy_rep_map[to_rep] 16 | if from_rep is not None and to_rep is not None: 17 | # scipy rotation transform 18 | rot = getattr(st.Rotation, f'from_{from_rep}')(x) 19 | out = getattr(rot, f'as_{to_rep}')() 20 | return out 21 | else: 22 | mat = None 23 | if from_rep is None: 24 | mat = rot6d_to_mat(x) 25 | else: 26 | mat = getattr(st.Rotation, f'from_{from_rep}')(x).as_matrix() 27 | 28 | if to_rep is None: 29 | out = mat_to_rot6d(mat) 30 | else: 31 | out = getattr(st.Rotation.from_matrix(mat), f'as_{to_rep}')() 32 | return out 33 | 34 | 35 | class RotationTransformer: 36 | # for legacy compatibility 37 | def __init__(self, 38 | from_rep='axis_angle', 39 | to_rep='rotation_6d'): 40 | """ 41 | Valid representations 42 | 43 | Always use matrix as intermediate representation. 
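Supported values for from_rep / to_rep are the keys of scipy_rep_map above:
'axis_angle', 'quaternion', 'matrix', and 'rotation_6d'.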
44 | """ 45 | self.from_rep = from_rep 46 | self.to_rep = to_rep 47 | 48 | def forward(self, x: np.ndarray) -> np.ndarray: 49 | return transform_rotation(x, from_rep=self.from_rep, to_rep=self.to_rep) 50 | 51 | def inverse(self, x: np.ndarray) -> np.ndarray: 52 | return transform_rotation(x, from_rep=self.to_rep, to_rep=self.from_rep) 53 | 54 | 55 | def test(): 56 | tf = RotationTransformer() 57 | 58 | rotvec = np.random.uniform(-2*np.pi,2*np.pi,size=(1000,3)) 59 | rot6d = tf.forward(rotvec) 60 | new_rotvec = tf.inverse(rot6d) 61 | 62 | from scipy.spatial.transform import Rotation 63 | diff = Rotation.from_rotvec(rotvec) * Rotation.from_rotvec(new_rotvec).inv() 64 | dist = diff.magnitude() 65 | assert dist.max() < 1e-7 66 | 67 | tf = RotationTransformer('rotation_6d', 'matrix') 68 | rot6d_wrong = rot6d + np.random.normal(scale=0.1, size=rot6d.shape) 69 | mat = tf.forward(rot6d_wrong) 70 | mat_det = np.linalg.det(mat) 71 | assert np.allclose(mat_det, 1) 72 | # rotaiton_6d will be normalized to rotation matrix 73 | -------------------------------------------------------------------------------- /rotation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.spatial.transform as st 3 | 4 | def pos_rot_to_mat(pos, rot): 5 | shape = pos.shape[:-1] 6 | mat = np.zeros(shape + (4,4), dtype=pos.dtype) 7 | mat[...,:3,3] = pos 8 | mat[...,:3,:3] = rot.as_matrix() 9 | mat[...,3,3] = 1 10 | return mat 11 | 12 | def mat_to_pos_rot(mat): 13 | pos = (mat[...,:3,3].T / mat[...,3,3].T).T 14 | rot = st.Rotation.from_matrix(mat[...,:3,:3]) 15 | return pos, rot 16 | 17 | def pos_rot_to_pose(pos, rot): 18 | shape = pos.shape[:-1] 19 | pose = np.zeros(shape+(6,), dtype=pos.dtype) 20 | pose[...,:3] = pos 21 | pose[...,3:] = rot.as_rotvec() 22 | return pose 23 | 24 | def pose_to_pos_rot(pose): 25 | pos = pose[...,:3] 26 | rot = st.Rotation.from_rotvec(pose[...,3:]) 27 | return pos, rot 28 | 29 | def pose_to_mat(pose): 30 | return pos_rot_to_mat(*pose_to_pos_rot(pose)) 31 | 32 | def mat_to_pose(mat): 33 | return pos_rot_to_pose(*mat_to_pos_rot(mat)) 34 | 35 | def transform_pose(tx, pose): 36 | """ 37 | tx: tx_new_old 38 | pose: tx_old_obj 39 | result: tx_new_obj 40 | """ 41 | pose_mat = pose_to_mat(pose) 42 | tf_pose_mat = tx @ pose_mat 43 | tf_pose = mat_to_pose(tf_pose_mat) 44 | return tf_pose 45 | 46 | def transform_point(tx, point): 47 | return point @ tx[:3,:3].T + tx[:3,3] 48 | 49 | def project_point(k, point): 50 | x = point @ k.T 51 | uv = x[...,:2] / x[...,[2]] 52 | return uv 53 | 54 | def apply_delta_pose(pose, delta_pose): 55 | new_pose = np.zeros_like(pose) 56 | 57 | # simple add for position 58 | new_pose[:3] = pose[:3] + delta_pose[:3] 59 | 60 | # matrix multiplication for rotation 61 | rot = st.Rotation.from_rotvec(pose[3:]) 62 | drot = st.Rotation.from_rotvec(delta_pose[3:]) 63 | new_pose[3:] = (drot * rot).as_rotvec() 64 | 65 | return new_pose 66 | 67 | def normalize(vec, tol=1e-7): 68 | return vec / np.maximum(np.linalg.norm(vec), tol) 69 | 70 | def rot_from_directions(from_vec, to_vec): 71 | from_vec = normalize(from_vec) 72 | to_vec = normalize(to_vec) 73 | axis = np.cross(from_vec, to_vec) 74 | axis = normalize(axis) 75 | angle = np.arccos(np.dot(from_vec, to_vec)) 76 | rotvec = axis * angle 77 | rot = st.Rotation.from_rotvec(rotvec) 78 | return rot 79 | 80 | def normalize(vec, eps=1e-12): 81 | norm = np.linalg.norm(vec, axis=-1) 82 | norm = np.maximum(norm, eps) 83 | out = (vec.T / norm).T 84 | return out 85 | 86 | 
def rot6d_to_mat(d6): 87 | a1, a2 = d6[..., :3], d6[..., 3:] 88 | b1 = normalize(a1) 89 | b2 = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1 90 | b2 = normalize(b2) 91 | b3 = np.cross(b1, b2, axis=-1) 92 | out = np.stack((b1, b2, b3), axis=-2) 93 | return out 94 | 95 | def mat_to_rot6d(mat): 96 | batch_dim = mat.shape[:-2] 97 | out = mat[..., :2, :].copy().reshape(batch_dim + (6,)) 98 | return out 99 | 100 | def mat_to_pose10d(mat): 101 | pos = mat[...,:3,3] 102 | rotmat = mat[...,:3,:3] 103 | d6 = mat_to_rot6d(rotmat) 104 | d10 = np.concatenate([pos, d6], axis=-1) 105 | return d10 106 | 107 | def pose10d_to_mat(d10): 108 | pos = d10[...,:3] 109 | d6 = d10[...,3:] 110 | rotmat = rot6d_to_mat(d6) 111 | out = np.zeros(d10.shape[:-1]+(4,4), dtype=d10.dtype) 112 | out[...,:3,:3] = rotmat 113 | out[...,:3,3] = pos 114 | out[...,3,3] = 1 115 | return out 116 | 117 | def quat_from_rot_m(mat): 118 | rotation = st.Rotation.from_matrix(mat) # Create a Rotation object from the matrix 119 | quaternion = rotation.as_quat() # Convert to quaternion (x, y, z, w) 120 | return quaternion 121 | 122 | def quat_to_rot_m(quat): 123 | rotation = st.Rotation.from_quat(quat) # Create a Rotation object from the quaternion 124 | rot_matrix = rotation.as_matrix() # Convert to rotation matrix 125 | return rot_matrix 126 | -------------------------------------------------------------------------------- /test_rotation.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import sys 3 | import os 4 | 5 | ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) 6 | sys.path.append(ROOT_DIR) 7 | os.chdir(ROOT_DIR) 8 | 9 | # %% 10 | import numpy as np 11 | from rotation_transformer import RotationTransformer 12 | from rotation_utils import rot6d_to_mat, mat_to_rot6d, normalize, quat_from_rot_m 13 | 14 | # %% 15 | def test(): 16 | N = 100 17 | d6 = np.random.normal(size=(N,6)) 18 | rt = RotationTransformer(from_rep='rotation_6d', to_rep='matrix') 19 | gt_mat = rt.forward(d6) 20 | mat = rot6d_to_mat(d6) 21 | assert np.allclose(gt_mat, mat) 22 | 23 | to_d6 = mat_to_rot6d(mat) 24 | to_d6_gt = rt.inverse(mat) 25 | assert np.allclose(to_d6, to_d6_gt) 26 | gt_mat = rt.forward(d6[1]) 27 | mat = rot6d_to_mat(d6[1]) 28 | assert np.allclose(gt_mat, mat) 29 | print(mat) 30 | norm_mat = normalize(mat) 31 | print(norm_mat) 32 | to_d6 = mat_to_rot6d(norm_mat) 33 | to_d6_gt = rt.inverse(norm_mat) 34 | print(np.sqrt(to_d6[0]**2+to_d6[1]**2+to_d6[2]**2)) 35 | print(np.sqrt(to_d6[3]**2+to_d6[4]**2+to_d6[5]**2)) 36 | print(quat_from_rot_m(norm_mat)) 37 | assert np.allclose(to_d6, to_d6_gt) 38 | 39 | 40 | if __name__ == "__main__": 41 | test() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from network import DiffusionPolicy 2 | from real_robot_network import DiffusionPolicy_Real 3 | from diffusers.optimization import get_scheduler 4 | from diffusers.training_utils import EMAModel 5 | import numpy as np 6 | from tqdm.auto import tqdm 7 | import torch 8 | import torch.nn as nn 9 | import os 10 | import time 11 | 12 | 13 | def train(continue_training=False, start_epoch = 0): 14 | # # for this demo, we use DDPMScheduler with 100 diffusion iterations 15 | diffusion = DiffusionPolicy() 16 | # device transfer 17 | device = torch.device('cuda') 18 | _ = diffusion.nets.to(device) 19 | 20 | #@markdown ### **Training** 21 | #@markdown 22 | #@markdown Takes about 2.5 hours. 
If you don't want to wait, skip to the next cell 23 | #@markdown to load pre-trained weights 24 | 25 | num_epochs = 100 26 | 27 | checkpoint_dir = "/home/lm-2023/jeon_team_ws/playback_pose/src/Diffusion_Policy_ICRA/checkpoints/" 28 | if continue_training: 29 | start_epoch = 59 30 | checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{start_epoch}.pth') # Replace with the correct path 31 | # Load the saved state_dict into the model 32 | checkpoint = torch.load(checkpoint_path) 33 | diffusion.nets.load_state_dict(checkpoint) # Load model state 34 | start_epoch = 60 35 | print("Successfully loaded Checkpoint") 36 | # Exponential Moving Average 37 | # accelerates training and improves stability 38 | # holds a copy of the model weights 39 | ema = EMAModel( 40 | parameters=diffusion.nets.parameters(), 41 | power=0.75) 42 | 43 | # Standard ADAM optimizer 44 | # Note that EMA parametesr are not optimized 45 | optimizer = torch.optim.AdamW( 46 | params=diffusion.nets.parameters(), 47 | lr=1e-4, weight_decay=1e-6) 48 | 49 | # Cosine LR schedule with linear warmup 50 | lr_scheduler = get_scheduler( 51 | name='cosine', 52 | optimizer=optimizer, 53 | num_warmup_steps=500, 54 | num_training_steps=len(diffusion.dataloader) * num_epochs 55 | ) 56 | 57 | 58 | with tqdm(range(start_epoch, num_epochs), desc='Epoch') as tglobal: 59 | # epoch loop 60 | for epoch_idx in tglobal: 61 | epoch_loss = list() 62 | # batch loop 63 | with tqdm(diffusion.dataloader, desc='Batch', leave=False) as tepoch: 64 | for nbatch in tepoch: 65 | # data normalized in dataset 66 | # device transfer 67 | nimage = nbatch['image'][:,:diffusion.obs_horizon].to(device) 68 | nagent_pos = nbatch['agent_pos'][:,:diffusion.obs_horizon].to(device) 69 | naction = nbatch['action'].to(device) 70 | B = nagent_pos.shape[0] 71 | 72 | # encoder vision features 73 | image_features = diffusion.nets['vision_encoder']( 74 | nimage.flatten(end_dim=1)) 75 | image_features = image_features.reshape( 76 | *nimage.shape[:2],-1) 77 | # (B,obs_horizon,D) 78 | 79 | # concatenate vision feature and low-dim obs 80 | obs_features = torch.cat([image_features, nagent_pos], dim=-1) 81 | obs_cond = obs_features.flatten(start_dim=1) 82 | # (B, obs_horizon * obs_dim) 83 | 84 | # sample noise to add to actions 85 | noise = torch.randn(naction.shape, device=device) 86 | 87 | # sample a diffusion iteration for each data point 88 | timesteps = torch.randint( 89 | 0, diffusion.noise_scheduler.config.num_train_timesteps, 90 | (B,), device=device 91 | ).long() 92 | 93 | # add noise to the clean images according to the noise magnitude at each diffusion iteration 94 | # (this is the forward diffusion process) 95 | noisy_actions = diffusion.noise_scheduler.add_noise( 96 | naction, noise, timesteps) 97 | 98 | # predict the noise residual 99 | noise_pred = diffusion.noise_pred_net( 100 | noisy_actions, timesteps, global_cond=obs_cond) 101 | 102 | # L2 loss 103 | loss = nn.functional.mse_loss(noise_pred, noise) 104 | 105 | # optimize 106 | loss.backward() 107 | optimizer.step() 108 | optimizer.zero_grad() 109 | # step lr scheduler every batch 110 | # this is different from standard pytorch behavior 111 | lr_scheduler.step() 112 | 113 | # update Exponential Moving Average of the model weights 114 | ema.step(diffusion.nets.parameters()) 115 | 116 | # logging 117 | loss_cpu = loss.item() 118 | epoch_loss.append(loss_cpu) 119 | tepoch.set_postfix(loss=loss_cpu) 120 | 121 | tglobal.set_postfix(loss=np.mean(epoch_loss)) 122 | avg_loss = np.mean(epoch_loss) 123 | 
tglobal.set_postfix(loss=avg_loss) 124 | 125 | # Save checkpoint every 10 epochs or at the end of training 126 | if (epoch_idx + 1) % 30 == 0 or (epoch_idx + 1) == num_epochs: 127 | # Save only the state_dict of the model, including relevant submodules 128 | torch.save(diffusion.nets.state_dict(), os.path.join(checkpoint_dir, f'checkpoint_res50_{epoch_idx+1}.pth')) 129 | 130 | # Weights of the EMA model 131 | # is used for inference 132 | ema_nets = diffusion.nets 133 | ema.copy_to(ema_nets.parameters()) 134 | 135 | 136 | if __name__ == "__main__": 137 | train(continue_training=False) -------------------------------------------------------------------------------- /train_real.py: -------------------------------------------------------------------------------- 1 | from real_robot_network import DiffusionPolicy_Real 2 | from diffusers.optimization import get_scheduler 3 | from diffusers.training_utils import EMAModel 4 | import numpy as np 5 | from tqdm.auto import tqdm 6 | import torch 7 | import torch.nn as nn 8 | import os 9 | import matplotlib.pyplot as plt 10 | 11 | torch.cuda.empty_cache() 12 | 13 | import hydra 14 | from omegaconf import DictConfig 15 | 16 | # Make sure Crop is all there 17 | @hydra.main(version_base=None, config_path="config", config_name="resnet_delta_with_force_single_view_force_Linear_crossattn_hybrid_crop") 18 | def train_Real_Robot(cfg: DictConfig): 19 | continue_training= cfg.model_config.continue_training 20 | start_epoch = cfg.model_config.start_epoch 21 | end_epoch= cfg.model_config.end_epoch 22 | encoder:str = cfg.model_config.encoder 23 | action_def: str = cfg.model_config.action_def 24 | force_mod: bool = cfg.model_config.force_mod 25 | single_view: bool = cfg.model_config.single_view 26 | force_encode = cfg.model_config.force_encode 27 | force_encoder = cfg.model_config.force_encoder 28 | cross_attn = cfg.model_config.cross_attn 29 | hybrid = cfg.model_config.hybrid 30 | duplicate_view = cfg.model_config.duplicate_view 31 | crop = cfg.model_config.crop 32 | 33 | if force_encode: 34 | cross_attn = False 35 | if cross_attn: 36 | force_encode = False 37 | 38 | print(f"Training model with vision {cfg.name}") 39 | # # for this demo, we use DDPMScheduler with 100 diffusion iterations 40 | diffusion = DiffusionPolicy_Real(encoder= encoder, 41 | action_def = action_def, 42 | force_mod = force_mod, 43 | single_view= single_view, 44 | force_encode=force_encode, 45 | force_encoder=force_encoder, 46 | cross_attn=cross_attn, 47 | hybrid = hybrid, 48 | duplicate_view = duplicate_view, 49 | crop = crop) 50 | data_name = diffusion.data_name 51 | 52 | device = torch.device('cuda') 53 | _ = diffusion.nets.to(device) 54 | 55 | #@markdown ### **Training** 56 | #@markdown 57 | #@markdown Takes about 2.5 hours. 
If you don't want to wait, skip to the next cell 58 | #@markdown to load pre-trained weights 59 | 60 | 61 | # Exponential Moving Average 62 | # accelerates training and improves stability 63 | # holds a copy of the model weights 64 | ema = EMAModel( 65 | parameters=diffusion.nets.parameters(), 66 | power=0.75) 67 | checkpoint_dir = "/home/jeon/jeon_ws/diffusion_policy/src/diffusion_cam/checkpoints" 68 | # To continue t raining load and set the start epoch 69 | if continue_training: 70 | start_epoch = 1500 71 | checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{start_epoch}.pth') # Replace with the correct path 72 | # Load the saved state_dict into the model 73 | checkpoint = torch.load(checkpoint_path) 74 | diffusion.nets.load_state_dict(checkpoint) # Load model state 75 | print("Successfully loaded Checkpoint") 76 | 77 | # Standard ADAM optimizer 78 | # Note that EMA parametesr are not optimized 79 | optimizer = torch.optim.AdamW( 80 | params=diffusion.nets.parameters(), 81 | lr=2e-4, weight_decay=1e-6) 82 | 83 | # Cosine LR schedule with linear warmup 84 | lr_scheduler = get_scheduler( 85 | name='cosine', 86 | optimizer=optimizer, 87 | num_warmup_steps=500, 88 | num_training_steps=len(diffusion.dataloader) * end_epoch 89 | ) 90 | # Log loss for epochs 91 | 92 | epoch_losses = [] 93 | 94 | with tqdm(range(start_epoch, end_epoch), desc='Epoch') as tglobal: 95 | # epoch 96 | for epoch_idx in tglobal: 97 | epoch_loss = list() 98 | ### THis is for seperately training augmented and non augmented data 99 | # # Decide which data loader to use for this epoch 100 | # if epoch_idx % 2 == 0: 101 | # # Use normal data loader every other epoch 102 | # current_loader = diffusion.data_loader 103 | # tglobal.set_postfix({'Data': 'Normal'}) 104 | # else: 105 | # # Use augmented data loader on alternate epochs 106 | # current_loader = diffusion.data_loader_augmented 107 | # tglobal.set_postfix({'Data': 'Augmented'}) 108 | current_loader = diffusion.dataloader 109 | # batch loop 110 | with tqdm(current_loader, desc='Batch', leave=False) as tepoch: 111 | for nbatch in tepoch: 112 | # data normalized in dataset 113 | # device transfer 114 | nimage = nbatch['image'][:,:diffusion.obs_horizon].to(device) 115 | 116 | if not single_view: 117 | nimage_second_view = nbatch['image2'][:,:diffusion.obs_horizon].to(device) 118 | if duplicate_view: 119 | nimage_duplicate = nbatch['duplicate_image'][:,:diffusion.obs_horizon].to(device) 120 | if force_mod: 121 | nforce = nbatch['force'][:,:diffusion.obs_horizon].to(device) 122 | else: 123 | nforce = None 124 | nagent_pos = nbatch['agent_pos'][:,:diffusion.obs_horizon].to(device) 125 | naction = nbatch['action'].to(device) 126 | 127 | # Debug sequential data structure. 
It shoud be consecutive 128 | # import matplotlib.pyplot as plt 129 | # imdata1 = nimage[0].cpu() 130 | # imdata1 = imdata1.numpy() 131 | # print(f"shape of the image data:", imdata1.shape) 132 | # # imdata2 = nimage_duplicate[0].cpu() 133 | # # imdata2 = imdata2.numpy() 134 | # # print(f"shape of the image data:", imdata2.shape) 135 | 136 | # fig, axes = plt.subplots(1, 2, figsize=(10, 5)) 137 | # for j in range(2): 138 | # # Convert the 3x96x96 tensor to a 96x96x3 image (for display purposes) 139 | # img = imdata1[j].transpose(1, 2, 0) 140 | 141 | # # # Plot the image on the corresponding subplot 142 | # axes[j].imshow(img) 143 | # axes[j].axis('off') # Hide the axes 144 | # # Show the plot 145 | # plt.show() 146 | # ### For double realsense config only 147 | # for j in range(2): 148 | # # Convert the 3x96x96 tensor to a 96x96x3 image (for display purposes) 149 | # img2 = imdata2[j].transpose(1, 2, 0) 150 | 151 | # # Plot the image on the corresponding subplot 152 | # axes[j].imshow(img2) 153 | # axes[j].axis('off') # Hide the axes 154 | # # Show the plot 155 | # plt.show() 156 | 157 | if encoder == "resnet": 158 | if duplicate_view: 159 | image_input = nimage_duplicate.flatten(end_dim=1) 160 | else: 161 | image_input = nimage.flatten(end_dim=1) 162 | elif encoder == "viT": 163 | if duplicate_view: 164 | image_input = nimage_duplicate 165 | else: 166 | image_input = nimage 167 | B = nagent_pos.shape[0] 168 | if not cross_attn: 169 | # encoder vision features 170 | image_features = diffusion.nets['vision_encoder']( 171 | image_input) 172 | image_features = image_features.reshape( 173 | *nimage.shape[:2],-1) 174 | # (B,obs_horizon,D) 175 | 176 | if not single_view: 177 | # encoder vision features 178 | image_features_second_view = diffusion.nets['vision_encoder2']( 179 | nimage_second_view.flatten(end_dim=1)) 180 | image_features_second_view = image_features_second_view.reshape( 181 | *nimage_second_view.shape[:2],-1) 182 | if duplicate_view: 183 | # encoder vision features 184 | image_features_duplicate = diffusion.nets['vision_encoder2']( 185 | nimage.flatten(end_dim=1)) 186 | image_features_duplicate = image_features_duplicate.reshape( 187 | *nimage.shape[:2],-1) 188 | if force_mod and force_encode: 189 | force_feature = diffusion.nets['force_encoder'](nforce) 190 | # force_feature = force_feature.reshape( 191 | # *nforce.shape[:2],-1) 192 | else: 193 | force_feature = nforce 194 | 195 | if cross_attn: 196 | joint_features = diffusion.nets['cross_attn_encoder']( 197 | image_input, (nforce)) 198 | 199 | # (B,obs_horizon,D) 200 | if force_mod and single_view and not cross_attn: 201 | obs_features = torch.cat([image_features, force_feature, nagent_pos], dim=-1) 202 | elif force_mod and not single_view and not cross_attn: 203 | obs_features = torch.cat([image_features, image_features_second_view, force_feature, nagent_pos], dim=-1) 204 | elif not force_mod and single_view: 205 | obs_features = torch.cat([image_features, nagent_pos], dim=-1) 206 | elif not force_mod and not single_view: 207 | obs_features = torch.cat([image_features, image_features_second_view , nagent_pos], dim=-1) 208 | elif single_view and cross_attn: 209 | # TODO: If hybrid is true, then add force feature on top of it. 210 | if hybrid: 211 | obs_features = torch.cat([joint_features, force_feature, nagent_pos], dim=-1) 212 | else: 213 | obs_features = torch.cat([joint_features, nagent_pos], dim=-1) 214 | elif not single_view and cross_attn: 215 | # TODO: If hybrid is true, then add force feature on top of it. 
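# (Added note) In this 'hybrid' mode the raw 4-D force reading is concatenated alongside the
# cross-attention embedding (force_feature is simply nforce here, since enabling cross_attn turns
# force_encode off), which matches the obs_dim += 4 adjustment in DiffusionPolicy_Real.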
216 | if hybrid: 217 | obs_features = torch.cat([joint_features, image_features_second_view, force_feature, nagent_pos], dim=-1) 218 | else: 219 | obs_features = torch.cat([joint_features, image_features_second_view, nagent_pos], dim=-1) 220 | elif single_view and duplicate_view and cross_attn: 221 | # TODO: If hybrid is true, then add force feature on top of it. 222 | if hybrid: 223 | obs_features = torch.cat([joint_features, image_features_duplicate, force_feature, nagent_pos], dim=-1) 224 | else: 225 | obs_features = torch.cat([joint_features, image_features_duplicate, nagent_pos], dim=-1) 226 | else: 227 | print("Check your configuration for training") 228 | 229 | obs_cond = obs_features.flatten(start_dim=1) 230 | # (B, obs_horizon * obs_dim) 231 | 232 | # sample noise to add to actions 233 | noise = torch.randn(naction.shape, device=device) 234 | 235 | # sample a diffusion iteration for each data point 236 | timesteps = torch.randint( 237 | 0, diffusion.noise_scheduler.config.num_train_timesteps, 238 | (B,), device=device 239 | ).long() 240 | 241 | # add noise to the clean images according to the noise magnitude at each diffusion iteration 242 | # (this is the forward diffusion process) 243 | noisy_actions = diffusion.noise_scheduler.add_noise( 244 | naction, noise, timesteps) 245 | 246 | # predict the noise residual 247 | noise_pred = diffusion.noise_pred_net( 248 | noisy_actions, timesteps, global_cond=obs_cond) 249 | 250 | # L2 loss 251 | loss = nn.functional.mse_loss(noise_pred, noise) 252 | 253 | # optimize 254 | loss.backward() 255 | optimizer.step() 256 | optimizer.zero_grad() 257 | # step lr scheduler every batch 258 | # this is different from standard pytorch behavior 259 | lr_scheduler.step() 260 | 261 | # update Exponential Moving Average of the model weights 262 | ema.step(diffusion.nets.parameters()) 263 | 264 | # logging 265 | loss_cpu = loss.item() 266 | epoch_loss.append(loss_cpu) 267 | tepoch.set_postfix(loss=loss_cpu) 268 | 269 | tglobal.set_postfix(loss=np.mean(epoch_loss)) 270 | avg_loss = np.mean(epoch_loss) 271 | epoch_losses.append(avg_loss) 272 | tglobal.set_postfix(loss=avg_loss) 273 | 274 | # Save checkpoint every 10 epochs or at the end of training 275 | if epoch_idx > 950: 276 | if (epoch_idx + 1) % 200 == 0 or (epoch_idx + 1) == end_epoch: 277 | # Save only the state_dict of the model, including relevant submodules 278 | torch.save(diffusion.nets.state_dict(), os.path.join(checkpoint_dir, f'{cfg.name}_{data_name}_{epoch_idx+1}_bench.pth')) 279 | # Plot the loss after training is complete 280 | plt.figure(figsize=(10, 6)) 281 | plt.plot(range(1, end_epoch + 1), epoch_losses, marker='o', label='Training Loss') 282 | plt.xlabel('Epoch') 283 | plt.ylabel('Loss') 284 | plt.title('Training Loss over Epocshs') 285 | plt.legend() 286 | plt.grid(True) 287 | plt.show() 288 | # Weights of the EMA model 289 | # is used for inference 290 | ema_nets = diffusion.nets 291 | ema.copy_to(ema_nets.parameters()) 292 | 293 | 294 | if __name__ == "__main__": 295 | train_Real_Robot() -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Sequence, Dict, Union, Optional, Callable 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | 6 | 7 | class train_utils: 8 | def __init__(self): 9 | pass 10 | 11 | #@markdown ### **Vision Encoder** 12 | #@markdown 13 | #@markdown Defines helper functions: 14 | #@markdown - 
`get_resnet` to initialize standard ResNet vision encoder 15 | #@markdown - `replace_bn_with_gn` to replace all BatchNorm layers with GroupNorm 16 | 17 | 18 | def get_resnet(self, name, weights=None, **kwargs): 19 | """ 20 | name: resnet18, resnet34, resnet50 21 | weights: "IMAGENET1K_V1", "r3m" 22 | """ 23 | # load r3m weights 24 | if (weights == "r3m") or (weights == "R3M"): 25 | return self.get_r3m(name=name, **kwargs) 26 | 27 | func = getattr(torchvision.models, name) 28 | resnet = func(weights=weights, **kwargs) 29 | resnet.fc = torch.nn.Identity() 30 | return resnet 31 | 32 | def get_r3m(self, name, **kwargs): 33 | """ 34 | name: resnet18, resnet34, resnet50 35 | """ 36 | import r3m 37 | r3m.device = 'cpu' 38 | model = r3m.load_r3m(name) 39 | r3m_model = model.module 40 | resnet_model = r3m_model.convnet 41 | resnet_model = resnet_model.to('cpu') 42 | return resnet_model 43 | 44 | def replace_submodules(self, 45 | root_module: nn.Module, 46 | predicate: Callable[[nn.Module], bool], 47 | func: Callable[[nn.Module], nn.Module]) -> nn.Module: 48 | """ 49 | Replace all submodules selected by the predicate with 50 | the output of func. 51 | 52 | predicate: Return true if the module is to be replaced. 53 | func: Return new module to use. 54 | """ 55 | if predicate(root_module): 56 | return func(root_module) 57 | 58 | bn_list = [k.split('.') for k, m 59 | in root_module.named_modules(remove_duplicate=True) 60 | if predicate(m)] 61 | for *parent, k in bn_list: 62 | parent_module = root_module 63 | if len(parent) > 0: 64 | parent_module = root_module.get_submodule('.'.join(parent)) 65 | if isinstance(parent_module, nn.Sequential): 66 | src_module = parent_module[int(k)] 67 | else: 68 | src_module = getattr(parent_module, k) 69 | tgt_module = func(src_module) 70 | if isinstance(parent_module, nn.Sequential): 71 | parent_module[int(k)] = tgt_module 72 | else: 73 | setattr(parent_module, k, tgt_module) 74 | # verify that all modules are replaced 75 | bn_list = [k.split('.') for k, m 76 | in root_module.named_modules(remove_duplicate=True) 77 | if predicate(m)] 78 | assert len(bn_list) == 0 79 | return root_module 80 | 81 | def replace_bn_with_gn(self, 82 | root_module: nn.Module, 83 | features_per_group: int=16) -> nn.Module: 84 | """ 85 | Replace all BatchNorm layers with GroupNorm.
86 | """ 87 | self.replace_submodules( 88 | root_module=root_module, 89 | predicate=lambda x: isinstance(x, nn.BatchNorm2d), 90 | func=lambda x: nn.GroupNorm( 91 | num_groups=x.num_features//features_per_group, 92 | num_channels=x.num_features) 93 | ) 94 | return root_module 95 | 96 | class SimpleViTEncoder(nn.Module): 97 | def __init__(self, model_name: str = 'vit_base_patch16_224', pretrained: bool = True, frozen: bool = False): 98 | super().__init__() 99 | # Load the ViT model 100 | self.vision_encoder = timm.create_model(model_name, pretrained=pretrained, num_classes=0) # Remove classifier 101 | 102 | # Optionally freeze the model if required 103 | if frozen: 104 | for param in self.vision_encoder.parameters(): 105 | param.requires_grad = False 106 | 107 | def forward(self, x): 108 | # Pass the input through the ViT model 109 | return self.vision_encoder(x) 110 | 111 | 112 | import timm 113 | import numpy as np 114 | import copy 115 | 116 | class TransformerObsEncoder(nn.Module): 117 | def __init__(self, 118 | shape_meta: dict, 119 | model_name: str = 'vit_base_patch16_clip_224.openai', 120 | global_pool: str = '', 121 | transforms: list = None, 122 | n_emb: int = 768, 123 | pretrained: bool = True, 124 | frozen: bool = False, 125 | use_group_norm: bool = True, 126 | share_rgb_model: bool = False, 127 | feature_aggregation: str = None, 128 | downsample_ratio: int = 32): 129 | """ 130 | Assumes rgb input: B,T,C,H,W 131 | Assumes low_dim input: B,T,D 132 | """ 133 | super().__init__() 134 | 135 | rgb_keys = list() 136 | low_dim_keys = list() 137 | key_model_map = nn.ModuleDict() 138 | key_transform_map = nn.ModuleDict() 139 | key_projection_map = nn.ModuleDict() 140 | key_shape_map = dict() 141 | 142 | assert global_pool == '' 143 | model = timm.create_model( 144 | model_name=model_name, 145 | pretrained=pretrained, 146 | global_pool=global_pool, # '' means no pooling 147 | num_classes=0 # remove classification layer 148 | ) 149 | self.model_name = model_name 150 | 151 | if frozen: 152 | assert pretrained 153 | for param in model.parameters(): 154 | param.requires_grad = False 155 | 156 | feature_dim = None 157 | if model_name.startswith('resnet'): 158 | if downsample_ratio == 32: 159 | modules = list(model.children())[:-2] 160 | model = torch.nn.Sequential(*modules) 161 | feature_dim = 512 162 | elif downsample_ratio == 16: 163 | modules = list(model.children())[:-3] 164 | model = torch.nn.Sequential(*modules) 165 | feature_dim = 256 166 | else: 167 | raise NotImplementedError(f"Unsupported downsample_ratio: {downsample_ratio}") 168 | elif model_name.startswith('convnext'): 169 | if downsample_ratio == 32: 170 | modules = list(model.children())[:-2] 171 | model = torch.nn.Sequential(*modules) 172 | feature_dim = 1024 173 | else: 174 | raise NotImplementedError(f"Unsupported downsample_ratio: {downsample_ratio}") 175 | 176 | if use_group_norm and not pretrained: 177 | model = self.replace_batch_norm_with_group_norm(model) 178 | 179 | # handle feature aggregation 180 | self.feature_aggregation = feature_aggregation 181 | if model_name.startswith('vit'): 182 | if self.feature_aggregation is None: 183 | pass 184 | elif self.feature_aggregation != 'cls': 185 | print(f'vit will use the CLS token. 
feature_aggregation ({self.feature_aggregation}) is ignored!') 186 | self.feature_aggregation = 'cls' 187 | 188 | if self.feature_aggregation == 'soft_attention': 189 | self.attention = nn.Sequential( 190 | nn.Linear(feature_dim, 1, bias=False), 191 | nn.Softmax(dim=1) 192 | ) 193 | 194 | image_shape = None 195 | obs_shape_meta = shape_meta['obs'] 196 | for key, attr in obs_shape_meta.items(): 197 | shape = tuple(attr['shape']) 198 | type = attr.get('type', 'low_dim') 199 | if type == 'rgb': 200 | assert image_shape is None or image_shape == shape[1:] 201 | image_shape = shape[1:] 202 | 203 | if transforms is not None and not isinstance(transforms[0], torch.nn.Module): 204 | assert transforms[0].type == 'RandomCrop' 205 | ratio = transforms[0].ratio 206 | transforms = [ 207 | torchvision.transforms.RandomCrop(size=int(image_shape[0] * ratio)), 208 | torchvision.transforms.Resize(size=image_shape[0], antialias=True) 209 | ] + transforms[1:] 210 | transform = nn.Identity() if transforms is None else torch.nn.Sequential(*transforms) 211 | 212 | for key, attr in obs_shape_meta.items(): 213 | shape = tuple(attr['shape']) 214 | type = attr.get('type', 'low_dim') 215 | key_shape_map[key] = shape 216 | if type == 'rgb': 217 | rgb_keys.append(key) 218 | 219 | this_model = model if share_rgb_model else copy.deepcopy(model) 220 | key_model_map[key] = this_model 221 | 222 | with torch.no_grad(): 223 | example_img = torch.zeros((1,) + tuple(shape)) 224 | example_feature_map = this_model(example_img) 225 | example_features = self.aggregate_feature(example_feature_map) 226 | feature_shape = example_features.shape 227 | feature_size = feature_shape[-1] 228 | 229 | proj = nn.Identity() 230 | if feature_size != n_emb: 231 | proj = nn.Linear(in_features=feature_size, out_features=n_emb) 232 | key_projection_map[key] = proj 233 | 234 | this_transform = transform 235 | key_transform_map[key] = this_transform 236 | elif type == 'low_dim': 237 | dim = np.prod(shape) 238 | proj = nn.Identity() 239 | if dim != n_emb: 240 | proj = nn.Linear(in_features=dim, out_features=n_emb) 241 | key_projection_map[key] = proj 242 | 243 | low_dim_keys.append(key) 244 | else: 245 | raise RuntimeError(f"Unsupported obs type: {type}") 246 | 247 | rgb_keys = sorted(rgb_keys) 248 | low_dim_keys = sorted(low_dim_keys) 249 | 250 | self.n_emb = n_emb 251 | self.shape_meta = shape_meta 252 | self.key_model_map = key_model_map 253 | self.key_transform_map = key_transform_map 254 | self.key_projection_map = key_projection_map 255 | self.share_rgb_model = share_rgb_model 256 | self.rgb_keys = rgb_keys 257 | self.low_dim_keys = low_dim_keys 258 | self.key_shape_map = key_shape_map 259 | 260 | def aggregate_feature(self, feature): 261 | if self.model_name.startswith('vit'): 262 | if self.feature_aggregation == 'cls': 263 | return feature[:, [0], :] 264 | assert self.feature_aggregation is None 265 | return feature 266 | 267 | assert len(feature.shape) == 4 268 | feature = torch.flatten(feature, start_dim=-2) # B, 512, 7*7 269 | feature = torch.transpose(feature, 1, 2) # B, 7*7, 512 270 | 271 | if self.feature_aggregation == 'avg': 272 | return torch.mean(feature, dim=[1], keepdim=True) 273 | elif self.feature_aggregation == 'max': 274 | return torch.amax(feature, dim=[1], keepdim=True) 275 | elif self.feature_aggregation == 'soft_attention': 276 | weight = self.attention(feature) 277 | return torch.sum(feature * weight, dim=1, keepdim=True) 278 | else: 279 | assert self.feature_aggregation is None 280 | return feature 281 | 282 | def 
forward(self, obs_dict): 283 | embeddings = list() 284 | batch_size = next(iter(obs_dict.values())).shape[0] 285 | 286 | for key in self.rgb_keys: 287 | img = obs_dict[key] 288 | B, T = img.shape[:2] 289 | assert B == batch_size 290 | assert img.shape[2:] == self.key_shape_map[key] 291 | img = img.reshape(B * T, *img.shape[2:]) 292 | img = self.key_transform_map[key](img) 293 | raw_feature = self.key_model_map[key](img) 294 | feature = self.aggregate_feature(raw_feature) 295 | emb = self.key_projection_map[key](feature) 296 | assert len(emb.shape) == 3 and emb.shape[0] == B * T and emb.shape[-1] == self.n_emb 297 | emb = emb.reshape(B, -1, self.n_emb) 298 | embeddings.append(emb) 299 | 300 | for key in self.low_dim_keys: 301 | data = obs_dict[key] 302 | B, T = data.shape[:2] 303 | assert B == batch_size 304 | assert data.shape[2:] == self.key_shape_map[key] 305 | data = data.reshape(B, T, -1) 306 | emb = self.key_projection_map[key](data) 307 | assert emb.shape[-1] == self.n_emb 308 | embeddings.append(emb) 309 | 310 | result = torch.cat(embeddings, dim=1) 311 | return result 312 | 313 | def replace_batch_norm_with_group_norm(self, model): 314 | def replace_submodules(root_module, predicate, func): 315 | for name, module in root_module.named_children(): 316 | if predicate(module): 317 | root_module.add_module(name, func(module)) 318 | else: 319 | replace_submodules(module, predicate, func) 320 | return root_module 321 | 322 | return replace_submodules( 323 | root_module=model, 324 | predicate=lambda x: isinstance(x, nn.BatchNorm2d), 325 | func=lambda x: nn.GroupNorm( 326 | num_groups=(x.num_features // 16) if (x.num_features % 16 == 0) else (x.num_features // 8), 327 | num_channels=x.num_features 328 | ) 329 | ) 330 | 331 | def test(): 332 | shape_meta = { 333 | 'obs': { 334 | 'rgb': {'shape': (3, 224, 224), 'type': 'rgb'}, 335 | 'low_dim': {'shape': (10,), 'type': 'low_dim'} 336 | } 337 | } 338 | 339 | encoder = TransformerObsEncoder(shape_meta=shape_meta) 340 | obs_dict = { 341 | 'rgb': torch.rand(2, 5, 3, 224, 224), 342 | 'low_dim': torch.rand(2, 5, 10) 343 | } 344 | 345 | result = encoder(obs_dict) 346 | print(result.shape) 347 | 348 | # test() -------------------------------------------------------------------------------- /transformer_obs_encoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import timm 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | import logging 10 | from typing import Dict, Callable, List 11 | 12 | from module_attr_mixin import ModuleAttrMixin 13 | 14 | 15 | 16 | 17 | def replace_submodules( 18 | root_module: nn.Module, 19 | predicate: Callable[[nn.Module], bool], 20 | func: Callable[[nn.Module], nn.Module]) -> nn.Module: 21 | """ 22 | predicate: Return true if the module is to be replaced. 23 | func: Return new module to use. 
24 | """ 25 | if predicate(root_module): 26 | return func(root_module) 27 | 28 | bn_list = [k.split('.') for k, m 29 | in root_module.named_modules(remove_duplicate=True) 30 | if predicate(m)] 31 | for *parent, k in bn_list: 32 | parent_module = root_module 33 | if len(parent) > 0: 34 | parent_module = root_module.get_submodule('.'.join(parent)) 35 | if isinstance(parent_module, nn.Sequential): 36 | src_module = parent_module[int(k)] 37 | else: 38 | src_module = getattr(parent_module, k) 39 | tgt_module = func(src_module) 40 | if isinstance(parent_module, nn.Sequential): 41 | parent_module[int(k)] = tgt_module 42 | else: 43 | setattr(parent_module, k, tgt_module) 44 | # verify that all BN are replaced 45 | bn_list = [k.split('.') for k, m 46 | in root_module.named_modules(remove_duplicate=True) 47 | if predicate(m)] 48 | assert len(bn_list) == 0 49 | return root_module 50 | 51 | logger = logging.getLogger(__name__) 52 | 53 | class AttentionPool2d(nn.Module): 54 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 55 | super().__init__() 56 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 57 | self.k_proj = nn.Linear(embed_dim, embed_dim) 58 | self.q_proj = nn.Linear(embed_dim, embed_dim) 59 | self.v_proj = nn.Linear(embed_dim, embed_dim) 60 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 61 | self.num_heads = num_heads 62 | 63 | def forward(self, x): 64 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 65 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 66 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 67 | x, _ = F.multi_head_attention_forward( 68 | query=x[:1], key=x, value=x, 69 | embed_dim_to_check=x.shape[-1], 70 | num_heads=self.num_heads, 71 | q_proj_weight=self.q_proj.weight, 72 | k_proj_weight=self.k_proj.weight, 73 | v_proj_weight=self.v_proj.weight, 74 | in_proj_weight=None, 75 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 76 | bias_k=None, 77 | bias_v=None, 78 | add_zero_attn=False, 79 | dropout_p=0, 80 | out_proj_weight=self.c_proj.weight, 81 | out_proj_bias=self.c_proj.bias, 82 | use_separate_proj_weight=True, 83 | training=self.training, 84 | need_weights=False 85 | ) 86 | return x.squeeze(0) 87 | 88 | 89 | class SimpleRGBObsEncoder(nn.Module): 90 | def __init__(self, 91 | model_name: str = 'vit_base_patch16_224', 92 | n_emb: int = 768, 93 | pretrained: bool = False, 94 | frozen: bool = False, 95 | use_group_norm: bool = True, 96 | feature_aggregation: str = "cls", 97 | downsample_ratio: int = 32): 98 | """ 99 | Assumes rgb input: B, T, C, H, W 100 | For images of fixed size 224x224 101 | """ 102 | super().__init__() 103 | 104 | self.feature_aggregation = feature_aggregation 105 | 106 | # Load the model 107 | model = timm.create_model( 108 | model_name=model_name, 109 | pretrained=pretrained, 110 | num_classes=0, # remove classification layer 111 | global_pool='' # no global pooling 112 | ) 113 | 114 | if frozen: 115 | assert pretrained, "If frozen, the model should be pretrained." 
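# Note: freezing disables gradients for every backbone parameter below, so within this encoder only the optional Linear projection created later in __init__ (when feature_dim != n_emb) remains trainable.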
116 | for param in model.parameters(): 117 | param.requires_grad = False 118 | 119 | # Handling ResNet model specific downsample ratio 120 | if model_name.startswith('resnet'): 121 | if downsample_ratio == 32: 122 | modules = list(model.children())[:-2] 123 | model = torch.nn.Sequential(*modules) 124 | feature_dim = 512 125 | elif downsample_ratio == 16: 126 | modules = list(model.children())[:-3] 127 | model = torch.nn.Sequential(*modules) 128 | feature_dim = 256 129 | else: 130 | raise NotImplementedError(f"Unsupported downsample_ratio: {downsample_ratio}") 131 | else: 132 | # For models like ViT 133 | feature_dim = model.embed_dim if hasattr(model, 'embed_dim') else n_emb 134 | 135 | # Optional GroupNorm replacement if not pretrained 136 | if use_group_norm and not pretrained: 137 | model = self.replace_bn_with_gn(model) 138 | 139 | self.model = model 140 | self.n_emb = n_emb 141 | self.feature_dim = feature_dim 142 | 143 | # Optional projection if feature size does not match n_emb 144 | self.projection = nn.Identity() 145 | if feature_dim != n_emb: 146 | self.projection = nn.Linear(in_features=feature_dim, out_features=n_emb) 147 | 148 | def replace_bn_with_gn(self, model): 149 | """Replace all BatchNorm layers with GroupNorm, including BatchNorm layers nested in submodules""" 150 | return replace_submodules( 151 | root_module=model, 152 | predicate=lambda x: isinstance(x, nn.BatchNorm2d), 153 | func=lambda x: nn.GroupNorm( 154 | num_groups=(x.num_features // 16) if (x.num_features % 16 == 0) else (x.num_features // 8), 155 | num_channels=x.num_features 156 | ) 157 | ) 158 | 159 | def aggregate_feature(self, feature): 160 | """Aggregate features, handling different feature aggregation strategies""" 161 | if self.feature_aggregation == 'cls': 162 | return feature[:, 0, :] # ViT uses the CLS token for classification 163 | else: 164 | return feature # Return all tokens or raw feature map 165 | 166 | def forward(self, img): 167 | """ 168 | img: B, T, C, H, W 169 | """ 170 | B, T, C, H, W = img.shape 171 | assert H == W == 224, "Input image must be 224x224" 172 | img = img.reshape(B * T, C, H, W) 173 | 174 | # Pass through the model 175 | raw_feature = self.model(img) 176 | feature = self.aggregate_feature(raw_feature) 177 | 178 | # Apply projection if necessary 179 | emb = self.projection(feature) 180 | 181 | # Reshape to B, T, n_emb 182 | emb = emb.view(B, T, self.n_emb) 183 | return emb 184 | 185 | @torch.no_grad() 186 | def output_shape(self): 187 | example_img = torch.zeros((1, 1, 3, 224, 224)) 188 | example_output = self.forward(example_img) 189 | return example_output.shape 190 | 191 | --------------------------------------------------------------------------------
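A minimal usage sketch for `SimpleRGBObsEncoder`, assuming `timm` and `torch` are installed, the repository root is on `PYTHONPATH`, and the default 224x224 ViT-B/16 backbone:

```python
import torch

from transformer_obs_encoder import SimpleRGBObsEncoder

# Randomly initialized ViT-B/16 backbone with CLS-token aggregation (the class defaults).
encoder = SimpleRGBObsEncoder(
    model_name='vit_base_patch16_224',
    n_emb=768,
    pretrained=False,
    frozen=False,
    feature_aggregation='cls',
)

# Two trajectories (B=2), two observation steps each (T=2), 3x224x224 RGB frames.
imgs = torch.rand(2, 2, 3, 224, 224)
emb = encoder(imgs)
print(emb.shape)  # expected: torch.Size([2, 2, 768])
```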