├── t3 ├── __init__.py ├── models │ ├── __init__.py │ ├── other_losses.py │ ├── mae_recon_loss.py │ ├── trunk.py │ ├── t3.py │ ├── decoder.py │ ├── nn_utils.py │ └── encoder.py ├── utils.py ├── _calandra17_label_dict.py ├── task_utils.py ├── data_loader.py └── pretrain.py ├── .gitignore ├── scripts └── train_nn.py ├── configs ├── datasets │ ├── visgel_variance_estimation.yaml │ ├── eval_ds_pose_mini.yaml │ ├── eval_ds_pose_svelte.yaml │ ├── eval_ds_pose_wedge.yaml │ ├── eval_ds_pose_densetact.yaml │ ├── eval_ds_cls_mini.yaml │ ├── eval_ds_cls_wedge.yaml │ ├── eval_ds_cls_svelte.yaml │ ├── eval_ds_cls_densetact.yaml │ ├── calandra17_classification.yaml │ ├── objectfolder_real_classification.yaml │ ├── touch_and_go_classification.yaml │ ├── eval_ds_pose_finray.yaml │ ├── eval_ds_cls_finray.yaml │ ├── cnc_mae.yaml │ ├── cnc_pose.yaml │ ├── yuan18_classification.yaml │ └── single_tower_mae.yaml ├── config.yaml └── network │ ├── finetune_exp_cls.yaml │ ├── finetune_exp_pose_regression.yaml │ ├── pretrain1_cross_mae.yaml │ ├── pretrain1_mae.yaml │ └── pretrain2.yaml ├── pyproject.toml ├── LICENSE ├── dataset_preproc ├── calculate_normalization.py ├── objectfolder_real_metadata.csv ├── preprocess_yuan18_images.py ├── utils.py ├── preprocess_tvl_images.py ├── README.md ├── preprocess_objectfolder_real_images.py ├── preprocess_calandra17_images.py ├── preprocess_touchandgo_images.py ├── preprocess_ycbsight_images.py └── preprocess_visgel_images.py └── README.md /t3/__init__.py: -------------------------------------------------------------------------------- 1 | from .pretrain import T3Pretrain -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | outputs 3 | *.egg-info 4 | *.egg 5 | __pycache__ 6 | wandb 7 | checkpoints 8 | checkpoints_supercloud 9 | cnc_Wedge -------------------------------------------------------------------------------- /t3/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import ResNetEncoder, CNNEncoder, ViTEncoder, MAEViTEncoder, IdentityEncoder 2 | from .trunk import MLPTrunk, TransformerTrunk, IdentityTrunk 3 | from .decoder import MLPDecoder, MAEViTDecoder, CrossMAEViTDecoder, CNNFCDecoder, IdentityDecoder, MLPTwoTowerDecoder, PoolingDecoder 4 | from .t3 import T3 5 | from .mae_recon_loss import MAEReconLoss 6 | from .other_losses import VarianceScaledLoss -------------------------------------------------------------------------------- /scripts/train_nn.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 3 | from t3 import T3Pretrain 4 | 5 | @hydra.main(version_base=None, config_path="../configs", config_name="config.yaml") 6 | def train_nn(cfg): 7 | pretrainer = T3Pretrain(cfg) 8 | pretrainer.setup_model() 9 | pretrainer.setup_optimizer() 10 | pretrainer.setup_dataset() 11 | print("Dataset setup complete") 12 | # pretrainer.train() 13 | pretrainer.test(20, "", 0, False) 14 | 15 | if __name__ == "__main__": 16 | train_nn() -------------------------------------------------------------------------------- /configs/datasets/visgel_variance_estimation.yaml: -------------------------------------------------------------------------------- 1 | visgel_variance_estimation: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.SingleTowerVarianceDataset 6 | data_dir: 
"data/FoundationTactile/visgel_downsampled" 7 | encoder_domain: "gs_green" 8 | decoder_domain: "variance_estimation" 9 | random_resize_crop: false 10 | random_hv_flip_prob: 0.0 11 | color_jitter: null 12 | img_norm: 13 | mean: [0.36917, 0.51738, 0.51782] 14 | std: [0.13530, 0.11039, 0.10001] -------------------------------------------------------------------------------- /t3/models/other_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scaled MSE loss for Transferable Tactile Transformer (T3) pre-training 3 | 4 | Author: Jialiang (Alan) Zhao 5 | Email: alanzhao@csail.mit.edu 6 | MIT License 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | class VarianceScaledLoss(nn.Module): 13 | def __init__(self, scale=5.): 14 | super().__init__() 15 | self.scale = scale 16 | self.mse = nn.MSELoss() 17 | 18 | def forward(self, pred, Y): 19 | mse_loss = self.mse(pred, Y) 20 | return torch.sqrt(mse_loss) / self.scale -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "t3" 7 | version = "2024.5.30" 8 | requires-python = ">= 3.8" 9 | dependencies = [ 10 | "torch >= 2.1.0", 11 | "torchvision >= 0.16", 12 | "numpy", 13 | "hydra-core", 14 | "tqdm", 15 | "wandb", 16 | "pandas", 17 | "pyarrow", 18 | "autolab_core", 19 | "scipy", 20 | "timm", 21 | "webdataset", 22 | "natsort", 23 | "autolab_core", 24 | ] 25 | 26 | [tool.setuptools] 27 | packages = ["t3"] -------------------------------------------------------------------------------- /configs/datasets/eval_ds_pose_mini.yaml: -------------------------------------------------------------------------------- 1 | cnc_mini: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.DoubleTowerPoseEstimationDataset 6 | data_dir: data/FoundationTactile/cnc/cnc_Mini 7 | encoder_domain: mini 8 | decoder_domain: pose_estimation_3d 9 | random_resize_crop: false 10 | random_hv_flip_prob: 0.0 11 | color_jitter: null 12 | pose_dim: 3 13 | img_norm: 14 | mean: [0.02172, 0.04152, 0.04152] 15 | std: [0.10158, 0.14423, 0.12185] 16 | label_norm: 17 | mean: [0.0, 0.0, 0.0] 18 | std: [1., 1., 1.] -------------------------------------------------------------------------------- /configs/datasets/eval_ds_pose_svelte.yaml: -------------------------------------------------------------------------------- 1 | cnc_svelte: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.DoubleTowerPoseEstimationDataset 6 | data_dir: data/FoundationTactile/cnc/cnc_Svelte 7 | encoder_domain: svelte 8 | decoder_domain: pose_estimation_3d 9 | random_resize_crop: false 10 | random_hv_flip_prob: 0.0 11 | color_jitter: null 12 | pose_dim: 3 13 | img_norm: 14 | mean: [0.41175, 0.26801, 0.02041] 15 | std: [0.26641, 0.19636, 0.09754] 16 | label_norm: 17 | mean: [0.0, 0.0, 0.0] 18 | std: [1., 1., 1.] 
-------------------------------------------------------------------------------- /configs/datasets/eval_ds_pose_wedge.yaml: -------------------------------------------------------------------------------- 1 | cnc_wedge: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.DoubleTowerPoseEstimationDataset 6 | data_dir: data/FoundationTactile/cnc/cnc_Wedge 7 | encoder_domain: wedge 8 | decoder_domain: pose_estimation_3d 9 | random_resize_crop: false 10 | random_hv_flip_prob: 0.0 11 | color_jitter: null 12 | pose_dim: 3 13 | img_norm: 14 | mean: [0.24580, 0.30085, 0.35867] 15 | std: [0.18356, 0.14779, 0.18460] 16 | label_norm: 17 | mean: [0.0, 0.0, 0.0] 18 | std: [1., 1., 1.] -------------------------------------------------------------------------------- /configs/datasets/eval_ds_pose_densetact.yaml: -------------------------------------------------------------------------------- 1 | cnc_densetact: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.DoubleTowerPoseEstimationDataset 6 | data_dir: data/FoundationTactile/cnc/cnc_DenseTact 7 | encoder_domain: densetact 8 | decoder_domain: pose_estimation_3d 9 | random_resize_crop: false 10 | random_hv_flip_prob: 0.0 11 | color_jitter: null 12 | pose_dim: 3 13 | img_norm: 14 | mean: [0.20997, 0.28465, 0.26797] 15 | std: [0.29838, 0.36981, 0.33362] 16 | label_norm: 17 | mean: [0.0, 0.0, 0.0] 18 | std: [1., 1., 1.] -------------------------------------------------------------------------------- /configs/datasets/eval_ds_cls_mini.yaml: -------------------------------------------------------------------------------- 1 | cnc_mini: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.SingleTowerClassificationDataset 6 | label_process_func: t3.task_utils.process_cnc_cls_label 7 | data_dir: data/FoundationTactile/cnc/cnc_Mini 8 | encoder_domain: mini 9 | decoder_domain: cls_cnc 10 | random_resize_crop: true 11 | random_hv_flip_prob: 0.5 12 | color_jitter: 13 | brightness: 0.4 14 | contrast: 0.4 15 | saturation: 0.5 16 | hue: 0.3 17 | img_norm: 18 | mean: [0.02172, 0.04152, 0.04152] 19 | std: [0.10158, 0.14423, 0.12185] -------------------------------------------------------------------------------- /configs/datasets/eval_ds_cls_wedge.yaml: -------------------------------------------------------------------------------- 1 | cnc_wedge: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.SingleTowerClassificationDataset 6 | label_process_func: t3.task_utils.process_cnc_cls_label 7 | data_dir: data/FoundationTactile/cnc/cnc_Wedge 8 | encoder_domain: wedge 9 | decoder_domain: cls_cnc 10 | random_resize_crop: true 11 | random_hv_flip_prob: 0.5 12 | color_jitter: 13 | brightness: 0.4 14 | contrast: 0.4 15 | saturation: 0.5 16 | hue: 0.3 17 | img_norm: 18 | mean: [0.24580, 0.30085, 0.35867] 19 | std: [0.18356, 0.14779, 0.18460] -------------------------------------------------------------------------------- /configs/datasets/eval_ds_cls_svelte.yaml: -------------------------------------------------------------------------------- 1 | cnc_svelte: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.SingleTowerClassificationDataset 6 | label_process_func: t3.task_utils.process_cnc_cls_label 7 | data_dir: data/FoundationTactile/cnc/cnc_Svelte 8 | encoder_domain: svelte 9 | decoder_domain: cls_cnc 10 | random_resize_crop: true 11 | random_hv_flip_prob: 0.5 12 | color_jitter: 13 | brightness: 0.4 14 | 
contrast: 0.4 15 | saturation: 0.5 16 | hue: 0.3 17 | img_norm: 18 | mean: [0.41175, 0.26801, 0.02041] 19 | std: [0.26641, 0.19636, 0.09754] -------------------------------------------------------------------------------- /configs/datasets/eval_ds_cls_densetact.yaml: -------------------------------------------------------------------------------- 1 | cnc_densetact: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.SingleTowerClassificationDataset 6 | label_process_func: t3.task_utils.process_cnc_cls_label 7 | data_dir: data/FoundationTactile/cnc/cnc_DenseTact 8 | encoder_domain: densetact 9 | decoder_domain: cls_cnc 10 | random_resize_crop: true 11 | random_hv_flip_prob: 0.5 12 | color_jitter: 13 | brightness: 0.4 14 | contrast: 0.4 15 | saturation: 0.5 16 | hue: 0.3 17 | img_norm: 18 | mean: [0.20997, 0.28465, 0.26797] 19 | std: [0.29838, 0.36981, 0.33362] -------------------------------------------------------------------------------- /configs/datasets/calandra17_classification.yaml: -------------------------------------------------------------------------------- 1 | calandra17_class: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.SingleTowerClassificationDataset 6 | label_process_func: t3.task_utils.process_calandra17_obj_label 7 | data_dir: "data/FoundationTactile/calandra17" 8 | encoder_domain: "gs_green" 9 | decoder_domain: "cls_calandra17" 10 | random_resize_crop: true 11 | random_hv_flip_prob: 0.5 12 | color_jitter: 13 | brightness: 0.4 14 | contrast: 0.4 15 | saturation: 0.5 16 | hue: 0.3 17 | img_norm: 18 | mean: [0.27307, 0.27307, 0.27307] 19 | std: [0.26252, 0.28064, 0.30760] 20 | -------------------------------------------------------------------------------- /configs/datasets/objectfolder_real_classification.yaml: -------------------------------------------------------------------------------- 1 | objectfolder_material_class: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.SingleTowerClassificationDataset 6 | label_process_func: t3.task_utils.process_objectfolder_real_label 7 | data_dir: "data/FoundationTactile/objectfolder_real" 8 | encoder_domain: "gs_black" 9 | decoder_domain: "cls_objectfolder" 10 | random_resize_crop: true 11 | random_hv_flip_prob: 0.5 12 | color_jitter: 13 | brightness: 0.4 14 | contrast: 0.4 15 | saturation: 0.5 16 | hue: 0.3 17 | img_norm: 18 | mean: [0.46676, 0.45028, 0.45292] 19 | std: [0.08171, 0.06973, 0.08618] 20 | -------------------------------------------------------------------------------- /configs/datasets/touch_and_go_classification.yaml: -------------------------------------------------------------------------------- 1 | touch_and_go_classification: 2 | activate: true # if false, this dataset will not be used 3 | eval_only: false # if true, this dataset will not be used for training 4 | data_loader: 5 | _target_: t3.data_loader.SingleTowerClassificationDataset 6 | label_process_func: t3.task_utils.process_touch_and_go_label 7 | data_dir: "data/FoundationTactile/touch_and_go" 8 | encoder_domain: "gs_tag" 9 | decoder_domain: "cls_touch_and_go" 10 | random_resize_crop: true 11 | random_hv_flip_prob: 0.5 12 | color_jitter: 13 | brightness: 0.4 14 | contrast: 0.4 15 | saturation: 0.5 16 | hue: 0.3 17 | img_norm: 18 | mean: [0.51808, 0.50300, 0.51457] 19 | std: [0.13893, 0.11343, 0.13497] -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 alanz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/datasets/eval_ds_pose_finray.yaml: -------------------------------------------------------------------------------- 1 | cnc_finray: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.DoubleTowerPoseEstimationDataset 6 | data_dir: data/FoundationTactile/cnc/cnc_Finray 7 | encoder_domain: finray 8 | decoder_domain: pose_estimation_3d 9 | random_resize_crop: false 10 | random_hv_flip_prob: 0.0 11 | color_jitter: null 12 | pose_dim: 3 13 | img_norm: 14 | mean: [0.01000, 0.04490, 0.07897] 15 | std: [0.06394, 0.14283, 0.20100] 16 | label_norm: 17 | mean: [0.0, 0.0, 0.0] 18 | std: [1., 1., 1.] 19 | 20 | cnc_finray2: 21 | activate: true 22 | eval_only: false 23 | data_loader: 24 | _target_: t3.data_loader.DoubleTowerPoseEstimationDataset 25 | data_dir: data/FoundationTactile/cnc/cnc_Finray2 26 | encoder_domain: finray 27 | decoder_domain: pose_estimation_3d 28 | random_resize_crop: false 29 | random_hv_flip_prob: 0.0 30 | color_jitter: null 31 | pose_dim: 3 32 | img_norm: 33 | mean: [0.00139, 0.02078, 0.10659] 34 | std: [0.01877, 0.10335, 0.23180] 35 | label_norm: 36 | mean: [0.0, 0.0, 0.0] 37 | std: [1., 1., 1.] 
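The dataset YAML blocks in this repo (such as `eval_ds_pose_finray.yaml` above) rely on Hydra's `_target_` convention: the `data_loader` node names a class and the remaining keys become its constructor arguments. Below is a hedged sketch of instantiating one such block standalone; it assumes the keyword arguments accepted by `t3.data_loader.DoubleTowerPoseEstimationDataset` match the YAML keys (the class itself lives in `t3/data_loader.py`, not shown in this listing) and that the `data_dir` exists on disk, so treat it purely as an illustration of the mechanism rather than a verified API call:

```python
# Illustrative only: build a dataset object from a config block via Hydra's
# `_target_` mechanism. Values are copied from eval_ds_pose_finray.yaml.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "_target_": "t3.data_loader.DoubleTowerPoseEstimationDataset",
    "data_dir": "data/FoundationTactile/cnc/cnc_Finray",
    "encoder_domain": "finray",
    "decoder_domain": "pose_estimation_3d",
    "random_resize_crop": False,
    "random_hv_flip_prob": 0.0,
    "color_jitter": None,
    "pose_dim": 3,
    "img_norm": {"mean": [0.01000, 0.04490, 0.07897], "std": [0.06394, 0.14283, 0.20100]},
    "label_norm": {"mean": [0.0, 0.0, 0.0], "std": [1.0, 1.0, 1.0]},
})

# instantiate() imports the `_target_` class and calls it with the remaining keys.
dataset = instantiate(cfg)
```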
-------------------------------------------------------------------------------- /configs/datasets/eval_ds_cls_finray.yaml: -------------------------------------------------------------------------------- 1 | cnc_finray: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.SingleTowerClassificationDataset 6 | label_process_func: t3.task_utils.process_cnc_cls_label 7 | data_dir: data/FoundationTactile/cnc/cnc_Finray 8 | encoder_domain: finray 9 | decoder_domain: cls_cnc 10 | random_resize_crop: true 11 | random_hv_flip_prob: 0.5 12 | color_jitter: 13 | brightness: 0.4 14 | contrast: 0.4 15 | saturation: 0.5 16 | hue: 0.3 17 | img_norm: 18 | mean: [0.01000, 0.04490, 0.07897] 19 | std: [0.06394, 0.14283, 0.20100] 20 | 21 | cnc_finray2: 22 | activate: true 23 | eval_only: false 24 | data_loader: 25 | _target_: t3.data_loader.SingleTowerClassificationDataset 26 | label_process_func: t3.task_utils.process_cnc_cls_label 27 | data_dir: data/FoundationTactile/cnc/cnc_Finray2 28 | encoder_domain: finray 29 | decoder_domain: cls_cnc 30 | random_resize_crop: true 31 | random_hv_flip_prob: 0.5 32 | color_jitter: 33 | brightness: 0.4 34 | contrast: 0.4 35 | saturation: 0.5 36 | hue: 0.3 37 | img_norm: 38 | mean: [0.00139, 0.02078, 0.10659] 39 | std: [0.01877, 0.10335, 0.23180] -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | comment: "" # a comment that will be appended to the wandb run name 2 | 3 | train: 4 | # img_size: 224 # TODO: not being used 5 | batch_size: 64 6 | dl_weight_type: "root" # how each dataloader is weighted according to number of batches. "equal", "invlinear", "root" 7 | num_data_workers: 0 8 | wandb: false 9 | wandb_entity: "" # your wandb username 10 | log_freq: 10 # how often to log to wandb 11 | save_model: true 12 | finetune_from: "" # path to a model to finetune / load from 13 | # Will train for total_train_steps, during which will run eval for test_steps every test_every steps 14 | total_train_steps: 100000 15 | test_every: 750 16 | test_steps: 50 17 | generate_mae_visualizations: true 18 | 19 | # whether to freeze the encoder and trunk 20 | freeze_encoder: false 21 | freeze_trunk: false 22 | # whether to unfreeze the encoder and trunk at a given step. only effective when both freeze_encoder and freeze_trunk are true 23 | scheduled_unfreeze: false 24 | scheduled_unfreeze_step: 20000 25 | 26 | optimizer: 27 | _target_: torch.optim.AdamW 28 | lr: 1.0e-4 29 | eps: 1.0e-6 30 | weight_decay: 0.1 31 | # the head and stem are updated at different frequencies. they can be trained with less learning rates. 
32 | nontrunk_lr_scale: 1.0 # 0.5 33 | 34 | scheduler: 35 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR 36 | T_max: ${train.total_train_steps} 37 | eta_min: 1e-8 38 | 39 | defaults: 40 | - _self_ 41 | - network: finetune_exp_cls 42 | - datasets: 43 | - eval_ds_cls_wedge -------------------------------------------------------------------------------- /configs/datasets/cnc_mae.yaml: -------------------------------------------------------------------------------- 1 | cnc_finray: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.SingleTowerMAEDataset 6 | data_dir: "data/FoundationTactile/cnc/cnc_Finray" 7 | encoder_domain: "finray" 8 | decoder_domain: "mae_recon_single" 9 | random_resize_crop: ${datasets.VAR_random_resize_crop} 10 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 11 | color_jitter: ${datasets.VAR_color_jitter} 12 | img_norm: 13 | mean: [0.01000, 0.04490, 0.07897] 14 | std: [0.06394, 0.14283, 0.20100] 15 | 16 | cnc_finray2: 17 | activate: true 18 | eval_only: false 19 | data_loader: 20 | _target_: t3.data_loader.SingleTowerMAEDataset 21 | data_dir: "data/FoundationTactile/cnc/cnc_Finray2" 22 | encoder_domain: "finray" 23 | decoder_domain: "mae_recon_single" 24 | random_resize_crop: ${datasets.VAR_random_resize_crop} 25 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 26 | color_jitter: ${datasets.VAR_color_jitter} 27 | img_norm: 28 | mean: [0.00139, 0.02078, 0.10659] 29 | std: [0.01877, 0.10335, 0.23180] 30 | 31 | cnc_wedge: 32 | activate: true 33 | eval_only: false 34 | data_loader: 35 | _target_: t3.data_loader.SingleTowerMAEDataset 36 | data_dir: "data/FoundationTactile/cnc/cnc_Wedge" 37 | encoder_domain: "wedge" 38 | decoder_domain: "mae_recon_single" 39 | random_resize_crop: ${datasets.VAR_random_resize_crop} 40 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 41 | color_jitter: ${datasets.VAR_color_jitter} 42 | img_norm: 43 | mean: [0.24580, 0.30085, 0.35867] 44 | std: [0.18356, 0.14779, 0.18460] -------------------------------------------------------------------------------- /configs/datasets/cnc_pose.yaml: -------------------------------------------------------------------------------- 1 | cnc_finray: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.DoubleTowerPoseEstimationDataset 6 | data_dir: data/FoundationTactile/cnc/cnc_Finray 7 | encoder_domain: finray 8 | decoder_domain: pose_estimation_3d 9 | random_resize_crop: false 10 | random_hv_flip_prob: 0.0 11 | color_jitter: null 12 | pose_dim: 3 13 | img_norm: 14 | mean: [0.01000, 0.04490, 0.07897] 15 | std: [0.06394, 0.14283, 0.20100] 16 | label_norm: 17 | mean: [0.0, 0.0, 0.0] 18 | std: [1., 1., 1.] 19 | 20 | cnc_finray2: 21 | activate: true 22 | eval_only: false 23 | data_loader: 24 | _target_: t3.data_loader.DoubleTowerPoseEstimationDataset 25 | data_dir: data/FoundationTactile/cnc/cnc_Finray2 26 | encoder_domain: finray 27 | decoder_domain: pose_estimation_3d 28 | random_resize_crop: false 29 | random_hv_flip_prob: 0.0 30 | color_jitter: null 31 | pose_dim: 3 32 | img_norm: 33 | mean: [0.00139, 0.02078, 0.10659] 34 | std: [0.01877, 0.10335, 0.23180] 35 | label_norm: 36 | mean: [0.0, 0.0, 0.0] 37 | std: [1., 1., 1.]
38 | 39 | cnc_wedge: 40 | activate: true 41 | eval_only: false 42 | data_loader: 43 | _target_: t3.data_loader.DoubleTowerPoseEstimationDataset 44 | data_dir: data/FoundationTactile/cnc/cnc_Wedge 45 | encoder_domain: wedge 46 | decoder_domain: pose_estimation_3d 47 | random_resize_crop: false 48 | random_hv_flip_prob: 0.0 49 | color_jitter: null 50 | pose_dim: 3 51 | img_norm: 52 | mean: [0.24580, 0.30085, 0.35867] 53 | std: [0.18356, 0.14779, 0.18460] 54 | label_norm: 55 | mean: [0.0, 0.0, 0.0] 56 | std: [1., 1., 1.] -------------------------------------------------------------------------------- /configs/datasets/yuan18_classification.yaml: -------------------------------------------------------------------------------- 1 | yuan18_textile_type_class: 2 | activate: true 3 | eval_only: false 4 | data_loader: 5 | _target_: t3.data_loader.SingleTowerClassificationDataset 6 | label_process_func: t3.task_utils.process_yuan18_textile_type_label 7 | data_dir: "data/FoundationTactile/yuan18" 8 | encoder_domain: "gs_green" 9 | decoder_domain: "cls_yuan18_textile_type" 10 | random_resize_crop: true 11 | random_hv_flip_prob: 0.5 12 | color_jitter: 13 | brightness: 0.4 14 | contrast: 0.4 15 | saturation: 0.5 16 | hue: 0.3 17 | img_norm: 18 | mean: [0.41745, 0.42082, 0.40049] 19 | std: [0.11456, 0.11639, 0.10868] 20 | 21 | yuan18_smoothness_class: 22 | activate: true 23 | eval_only: false 24 | data_loader: 25 | _target_: t3.data_loader.SingleTowerClassificationDataset 26 | label_process_func: t3.task_utils.process_yuan18_smoothness_label 27 | data_dir: "data/FoundationTactile/yuan18" 28 | encoder_domain: "gs_green" 29 | decoder_domain: "cls_yuan18_smoothness" 30 | random_resize_crop: true 31 | random_hv_flip_prob: 0.5 32 | color_jitter: 33 | brightness: 0.4 34 | contrast: 0.4 35 | saturation: 0.5 36 | hue: 0.3 37 | img_norm: 38 | mean: [0.41745, 0.42082, 0.40049] 39 | std: [0.11456, 0.11639, 0.10868] 40 | 41 | yuan18_fuzziness_class: 42 | activate: true 43 | eval_only: false 44 | data_loader: 45 | _target_: t3.data_loader.SingleTowerClassificationDataset 46 | label_process_func: t3.task_utils.process_yuan18_fuzziness_label 47 | data_dir: "data/FoundationTactile/yuan18" 48 | encoder_domain: "gs_green" 49 | decoder_domain: "cls_yuan18_fuzziness" 50 | random_resize_crop: true 51 | random_hv_flip_prob: 0.5 52 | color_jitter: 53 | brightness: 0.4 54 | contrast: 0.4 55 | saturation: 0.5 56 | hue: 0.3 57 | img_norm: 58 | mean: [0.41745, 0.42082, 0.40049] 59 | std: [0.11456, 0.11639, 0.10868] -------------------------------------------------------------------------------- /t3/models/mae_recon_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reconstruction loss for MAE for Transferable Tactile Transformer (T3) pre-training 3 | 4 | Author: Jialiang (Alan) Zhao 5 | Email: alanzhao@csail.mit.edu 6 | MIT License 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | 11 | class MAEReconLoss(nn.Module): 12 | """ 13 | Reconstruction loss for MAE 14 | """ 15 | def __init__(self, patch_size, norm_pix_loss=False): 16 | super().__init__() 17 | self.patch_size = patch_size 18 | self.norm_pix_loss = norm_pix_loss 19 | 20 | def patchify(self, imgs): 21 | """ 22 | imgs: (N, 3, H, W) 23 | x: (N, L, patch_size**2 *3) 24 | """ 25 | p = self.patch_size 26 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 27 | 28 | h = w = imgs.shape[2] // p 29 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 30 | x = torch.einsum('nchpwq->nhwpqc', x) 31 | x = 
x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 32 | return x 33 | 34 | def forward(self, pred, imgs): 35 | """ 36 | imgs: [N, 3, H, W] 37 | pred: [N, L, p*p*3] (original MAE) or [N, l, p*p*3] (cross MAE) 38 | mask: [N, L], 0 is keep, 1 is remove, 39 | """ 40 | (x, mask, ids_restore) = pred 41 | target = self.patchify(imgs) 42 | # select the masked portion 43 | target = target.masked_select(mask.bool().unsqueeze(-1)).reshape(target.shape[0], -1, target.shape[-1]) 44 | if x.shape[1] != target.shape[1]: 45 | # in the case of original MAE, need to only select the masked portion 46 | x = x.masked_select(mask.bool().unsqueeze(-1)).reshape(x.shape[0], -1, x.shape[-1]) 47 | 48 | if self.norm_pix_loss: 49 | mean = target.mean(dim=-1, keepdim=True) 50 | var = target.var(dim=-1, keepdim=True) 51 | target = (target - mean) / (var + 1.e-6)**.5 52 | 53 | loss = (x - target) ** 2 54 | loss = loss.mean() 55 | return loss -------------------------------------------------------------------------------- /dataset_preproc/calculate_normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculate image normalization on each dataset 3 | """ 4 | 5 | import os 6 | import numpy as np 7 | from PIL import Image 8 | 9 | import webdataset as wds 10 | from natsort import natsorted 11 | 12 | root_dir = "data/processed" 13 | 14 | def print_array(arr): 15 | return "[" + ", ".join([f"{x:.5f}" for x in arr]) + "]" 16 | 17 | def run_sharded_dataset(root_dir, num_samples=50, exlude=[], only=[]): 18 | # sample num_samples images from the train folder and take the average 19 | if only: 20 | ds_names = [os.path.join(root_dir, _) for _ in only] 21 | else: 22 | ds_names = [os.path.join(root_dir, _) for _ in os.listdir(root_dir)] 23 | for ds in ds_names: 24 | if ds in exlude: 25 | continue 26 | if ds.startswith("panda"): 27 | continue 28 | print("-" * 80) 29 | print(f"Dataset: {ds}") 30 | tars = natsorted([d for d in os.listdir(os.path.join(ds, "train")) if d.startswith("data-")]) 31 | _start_idx = os.path.splitext(os.path.basename(tars[0]))[0].strip("data-") 32 | _end_idx = os.path.splitext(os.path.basename(tars[-1]))[0].strip("data-") 33 | url = os.path.join(ds, "train", f"data-{{{_start_idx}..{_end_idx}}}.tar") 34 | 35 | dataset = ( 36 | wds.WebDataset(url, shardshuffle=True) 37 | .shuffle(1000) 38 | .decode("pil") 39 | .to_tuple("jpg", "json") 40 | ) 41 | imgs = [] 42 | for i, (img, label) in enumerate(dataset): 43 | if len(imgs) >= num_samples: 44 | break 45 | img = np.array(img, dtype=np.float32) / 255.0 46 | if len(imgs) > 0 and img.shape != imgs[0].shape: 47 | continue 48 | imgs.append(img) 49 | flat_img = np.hstack(imgs) 50 | print("IMG mean & norm: ", print_array(np.mean(flat_img, axis=(0, 1))), print_array(np.std(flat_img, axis=(0, 1)))) 51 | 52 | if __name__ == "__main__": 53 | run_sharded_dataset(root_dir, only=["cnc/cnc_Mini", "cnc/cnc_DenseTact", "cnc/cnc_Svelte"]) -------------------------------------------------------------------------------- /configs/network/finetune_exp_cls.yaml: -------------------------------------------------------------------------------- 1 | # Example network configuration for an object classification finetune task 2 | 3 | # Network size corresponds to "t3_medium" but can be overwritten in code 4 | patch_size: 16 5 | encoder_embed_dim: 768 6 | encoder_heads: 12 7 | pooling: "none" 8 | encoder_depth: 3 9 | trunk_depth: 9 10 | 11 | encoders: 12 | wedge: 13 | _target_: t3.models.ViTEncoder 14 | patch_size: ${network.patch_size} 15 | 
embed_dim: ${network.encoder_embed_dim} 16 | depth: ${network.encoder_depth} 17 | num_heads: ${network.encoder_heads} 18 | mlp_ratio: 4. 19 | finray: 20 | _target_: t3.models.ViTEncoder 21 | patch_size: ${network.patch_size} 22 | embed_dim: ${network.encoder_embed_dim} 23 | depth: ${network.encoder_depth} 24 | num_heads: ${network.encoder_heads} 25 | mlp_ratio: 4. 26 | svelte: 27 | _target_: t3.models.ViTEncoder 28 | patch_size: ${network.patch_size} 29 | embed_dim: ${network.encoder_embed_dim} 30 | depth: ${network.encoder_depth} 31 | num_heads: ${network.encoder_heads} 32 | mlp_ratio: 4. 33 | densetact: 34 | _target_: t3.models.ViTEncoder 35 | patch_size: ${network.patch_size} 36 | embed_dim: ${network.encoder_embed_dim} 37 | depth: ${network.encoder_depth} 38 | num_heads: ${network.encoder_heads} 39 | mlp_ratio: 4. 40 | mini: 41 | _target_: t3.models.ViTEncoder 42 | patch_size: ${network.patch_size} 43 | embed_dim: ${network.encoder_embed_dim} 44 | depth: ${network.encoder_depth} 45 | num_heads: ${network.encoder_heads} 46 | mlp_ratio: 4. 47 | 48 | shared_trunk: 49 | _target_: t3.models.TransformerTrunk 50 | embed_dim: ${network.encoder_embed_dim} 51 | depth: ${network.trunk_depth} 52 | num_heads: ${network.encoder_heads} 53 | mlp_ratio: 4. 54 | pooling_type: ${network.pooling} 55 | 56 | decoders: 57 | cls_cnc: 58 | _target_: t3.models.MLPDecoder 59 | input_dim: ${network.encoder_embed_dim} 60 | output_dim: 6 61 | hidden_dims: [256, 128, 64] 62 | dropout_p: 0.1 63 | transformer_upstream: true 64 | pooling_type: cls 65 | loss_func: 66 | _target_: torch.nn.CrossEntropyLoss -------------------------------------------------------------------------------- /configs/network/finetune_exp_pose_regression.yaml: -------------------------------------------------------------------------------- 1 | # Example network configuration for a pose estimation finetune task 2 | 3 | # Network size corresponds to "t3_medium" but can be overwritten in code 4 | patch_size: 16 5 | encoder_embed_dim: 768 6 | encoder_heads: 12 7 | pooling: "none" 8 | encoder_depth: 3 9 | trunk_depth: 9 10 | 11 | encoders: 12 | wedge: 13 | _target_: t3.models.ViTEncoder 14 | patch_size: ${network.patch_size} 15 | embed_dim: ${network.encoder_embed_dim} 16 | depth: ${network.encoder_depth} 17 | num_heads: ${network.encoder_heads} 18 | mlp_ratio: 4. 19 | finray: 20 | _target_: t3.models.ViTEncoder 21 | patch_size: ${network.patch_size} 22 | embed_dim: ${network.encoder_embed_dim} 23 | depth: ${network.encoder_depth} 24 | num_heads: ${network.encoder_heads} 25 | mlp_ratio: 4. 26 | svelte: 27 | _target_: t3.models.ViTEncoder 28 | patch_size: ${network.patch_size} 29 | embed_dim: ${network.encoder_embed_dim} 30 | depth: ${network.encoder_depth} 31 | num_heads: ${network.encoder_heads} 32 | mlp_ratio: 4. 33 | densetact: 34 | _target_: t3.models.ViTEncoder 35 | patch_size: ${network.patch_size} 36 | embed_dim: ${network.encoder_embed_dim} 37 | depth: ${network.encoder_depth} 38 | num_heads: ${network.encoder_heads} 39 | mlp_ratio: 4. 40 | mini: 41 | _target_: t3.models.ViTEncoder 42 | patch_size: ${network.patch_size} 43 | embed_dim: ${network.encoder_embed_dim} 44 | depth: ${network.encoder_depth} 45 | num_heads: ${network.encoder_heads} 46 | mlp_ratio: 4. 47 | 48 | shared_trunk: 49 | _target_: t3.models.TransformerTrunk 50 | embed_dim: ${network.encoder_embed_dim} 51 | depth: ${network.trunk_depth} 52 | num_heads: ${network.encoder_heads} 53 | mlp_ratio: 4. 
54 | pooling_type: ${network.pooling} 55 | 56 | decoders: 57 | pose_estimation_3d: 58 | _target_: t3.models.CNNFCDecoder 59 | inplanes: ${network.encoder_embed_dim} 60 | fc_hidden_dims: [256, 64] 61 | output_dim: 3 # matches pose_dim in the dataset configs 62 | stride: 2 63 | dropout_p: 0.1 64 | tanh_end: false 65 | transformer_upstream: true 66 | loss_func: 67 | _target_: torch.nn.MSELoss -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transferable Tactile Transformers (T3) and the Foundation Tactile (FoTa) dataset 2 | 3 | ![img](https://t3.alanz.info/imgs/archi_compact.png) 4 | 5 | [![Paper](https://badgen.net/badge/icon/arXiv?icon=awesome&label&color=red)](https://arxiv.org/abs/2406.13640) 6 | [![Repo](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/alanzjl/t3) 7 | [![Colab](https://badgen.net/badge/icon/Colab?icon=terminal&label&color=yellow)](https://colab.research.google.com/drive/1MmO9w1y59Gy6ds0iKlW04olszGko56Vf?usp=sharing) 8 | [![Huggingface](https://badgen.net/badge/icon/Dataset \& Checkpoints?label&color=cyan)](https://huggingface.co/datasets/alanz-mit/FoundationTactile) 9 | 10 | [[Project Website]](https://t3.alanz.info/) 11 | 12 | [Jialiang (Alan) Zhao](https://alanz.info/), 13 | [Yuxiang Ma](https://yuxiang-ma.github.io/), 14 | [Lirui Wang](https://liruiw.github.io/), and 15 | [Edward H. Adelson](https://persci.mit.edu/people/adelson/) 16 | 17 | MIT CSAIL 18 | 19 | ## Overview 20 | We present T3, a heterogeneous tactile representation learning framework based on transformers, and FoTa, a large tactile dataset that contains over 3 million tactile images collected from 13 sensors and 11 tasks. 21 | T3 extracts a common representation that is sharable across different camera-based tactile sensors and downstream tasks. 22 | 23 | ## Installation 24 | 25 | ```sh 26 | git clone https://github.com/alanzjl/t3 27 | cd t3 28 | pip install -e . 29 | ``` 30 | 31 | ## Get started 32 | The best way to get started with T3 or FoTa is to check out our [![Colab](https://badgen.net/badge/icon/Colab?icon=terminal&label&color=yellow)](https://colab.research.google.com/drive/1MmO9w1y59Gy6ds0iKlW04olszGko56Vf?usp=sharing) for step-by-step instructions on how to manipulate data and run T3. 33 | More details about the file structure of FoTa can be found on [![Huggingface](https://badgen.net/badge/icon/Dataset \& Checkpoints?label&color=cyan)](https://huggingface.co/datasets/alanz-mit/FoundationTactile). 34 | 35 | ## Citation 36 | ``` 37 | @article{zhao2024transferable, 38 | title={Transferable Tactile Transformers for Representation Learning Across Diverse Sensors and Tasks}, 39 | author={Jialiang Zhao and Yuxiang Ma and Lirui Wang and Edward H. Adelson}, 40 | year={2024}, 41 | eprint={2406.13640}, 42 | archivePrefix={arXiv}, 43 | } 44 | ``` 45 | 46 | MIT License.
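For a programmatic quick start outside of the `@hydra.main` entry point in `scripts/train_nn.py`, the same config tree can be composed with Hydra's compose API. The snippet below is a sketch only: it assumes it is run from the repository root (so the relative `configs/` path resolves), that the selected datasets are present under `data/`, and that `T3Pretrain` exposes a `train()` method as suggested by the commented-out call in `scripts/train_nn.py`; the chosen overrides are just examples drawn from the config groups in `configs/`.

```python
# Sketch of a programmatic training entry point using Hydra's compose API.
# Config group names ("network", "datasets") follow configs/config.yaml.
from hydra import compose, initialize

from t3 import T3Pretrain

with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name="config",
        overrides=[
            "network=finetune_exp_pose_regression",  # pick a decoder head
            "datasets=[eval_ds_pose_wedge]",         # and a matching dataset group
            "train.batch_size=32",
        ],
    )

pretrainer = T3Pretrain(cfg)
pretrainer.setup_model()
pretrainer.setup_optimizer()
pretrainer.setup_dataset()
pretrainer.train()  # assumed API; mirrors the commented-out call in train_nn.py
```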
-------------------------------------------------------------------------------- /configs/network/pretrain1_cross_mae.yaml: -------------------------------------------------------------------------------- 1 | # Network configuration for pretraining I with Cross MAE (https://crossmae.github.io/) 2 | # Not used in the paper, but can be used for comparison 3 | 4 | # Network size corresponds to "t3_medium" but can be overwritten in code 5 | patch_size: 16 6 | encoder_embed_dim: 768 7 | encoder_heads: 12 8 | mask_ratio: 0.3 9 | 10 | encoders: 11 | gs_360_v2: 12 | _target_: t3.models.MAEViTEncoder 13 | mask_ratio: ${network.mask_ratio} 14 | patch_size: ${network.patch_size} 15 | embed_dim: ${network.encoder_embed_dim} 16 | depth: 3 17 | num_heads: ${network.encoder_heads} 18 | mlp_ratio: 4. 19 | 20 | gs_green: 21 | _target_: t3.models.MAEViTEncoder 22 | mask_ratio: ${network.mask_ratio} 23 | patch_size: ${network.patch_size} 24 | embed_dim: ${network.encoder_embed_dim} 25 | depth: 3 26 | num_heads: ${network.encoder_heads} 27 | mlp_ratio: 4. 28 | 29 | digit: 30 | _target_: t3.models.MAEViTEncoder 31 | mask_ratio: ${network.mask_ratio} 32 | patch_size: ${network.patch_size} 33 | embed_dim: ${network.encoder_embed_dim} 34 | depth: 3 35 | num_heads: ${network.encoder_heads} 36 | mlp_ratio: 4. 37 | 38 | mini: 39 | _target_: t3.models.MAEViTEncoder 40 | mask_ratio: ${network.mask_ratio} 41 | patch_size: ${network.patch_size} 42 | embed_dim: ${network.encoder_embed_dim} 43 | depth: 3 44 | num_heads: ${network.encoder_heads} 45 | mlp_ratio: 4. 46 | 47 | wedge: 48 | _target_: t3.models.MAEViTEncoder 49 | mask_ratio: ${network.mask_ratio} 50 | patch_size: ${network.patch_size} 51 | embed_dim: ${network.encoder_embed_dim} 52 | depth: 3 53 | num_heads: ${network.encoder_heads} 54 | mlp_ratio: 4. 55 | 56 | shared_trunk: 57 | _target_: t3.models.TransformerTrunk 58 | embed_dim: ${network.encoder_embed_dim} 59 | depth: 9 60 | num_heads: ${network.encoder_heads} 61 | mlp_ratio: 4. 62 | 63 | decoders: 64 | mae_recon_single: 65 | _target_: t3.models.CrossMAEViTDecoder 66 | patch_size: ${network.patch_size} 67 | embed_dim: ${network.encoder_embed_dim} 68 | decoder_embed_dim: 512 69 | decoder_depth: 8 70 | decoder_num_heads: 16 71 | mlp_ratio: 4. 
72 | loss_func: 73 | _target_: t3.models.MAEReconLoss 74 | patch_size: 16 75 | norm_pix_loss: false # true for better representation learning, false for pixel-based loss for better reconstruction aka visualization -------------------------------------------------------------------------------- /dataset_preproc/objectfolder_real_metadata.csv: -------------------------------------------------------------------------------- 1 | Index,Name,Material 2 | 1,Soup_Spoon,Ceramic 3 | 2,Bowl,Ceramic 4 | 3,Salad_Plate,Ceramic 5 | 4,Dinner_Plate,Ceramic 6 | 5,Hair_Comb,Wood 7 | 6,Blue_Bowl,Glass 8 | 7,Decorative_Plate,Glass 9 | 8,Mixing_Bowl,Ceramic 10 | 9,Serving_Bowl,Ceramic 11 | 10,Soup_Bowl,Ceramic 12 | 11,Strainer_Spoon,Wood 13 | 12,Soup_Ladle,Wood 14 | 13,Serving_Spoon,Wood 15 | 14,Salad_Fork,Wood 16 | 15,Mixing_Spoon,Wood 17 | 16,Frying_Spatula,Wood 18 | 17,8Inch_Skillet,Iron 19 | 18,10_dot_25Inch_Skillet,Iron 20 | 19,10_dot_5Inch_Griddle,Iron 21 | 20,Dutch_Oven,Iron 22 | 21,Dutch_Oven_Lid,Iron 23 | 22,Rinsing_Cup,Glass 24 | 23,Hand_Scoop,Plastic 25 | 24,Shovel_Toy_Red_Large,Plastic 26 | 25,Shovel_Toy_Green_Small,Plastic 27 | 26,Handle_Spoon,Polycarbonate 28 | 27,Round_Plate,Wood 29 | 28,Square_Plate,Wood 30 | 29,Cutting_Board_Large,Wood 31 | 30,Cutting_Board_Middle,Wood 32 | 31,Cutting_Board_Small,Wood 33 | 32,Wine_Glass,Wood 34 | 33,Drinking_Cup,Wood 35 | 34,Beer_Mug,Wood 36 | 35,Portion_Cup_Brown,Polycarbonate 37 | 36,Portion_Cup_White,Polycarbonate 38 | 37,Cake_Pan,Steel 39 | 38,Loaf_Pan,Steel 40 | 39,Wrench_Small,Steel 41 | 40,Wrench_Middle,Steel 42 | 41,Wrench_Large,Steel 43 | 42,Pestle,Iron 44 | 43,Mortar,Iron 45 | 44,Sculpture,Iron 46 | 45,Ladle,Iron 47 | 46,Spatula,Iron 48 | 47,Decorative_Cast,Iron 49 | 48,Mixing_Bowl_Large,Plastic 50 | 49,Mixing_Bowl_Middle,Plastic 51 | 50,Mixing_Bowl_Small,Plastic 52 | 51,Fruit_Bowl,Glass 53 | 52,Fork_Small,Steel 54 | 53,Fork_Large,Steel 55 | 54,Spoon_Small,Steel 56 | 55,Spoon_Large,Steel 57 | 56,Knife_Large,Plastic 58 | 57,Knife_Middle,Plastic 59 | 58,Knife_Small,Plastic 60 | 59,Soap_Dish,Glass 61 | 60,Beer_Glass,Glass 62 | 61,Container_Large,Ceramic 63 | 62,Container_Middle,Ceramic 64 | 63,Container_Small,Ceramic 65 | 64,Mug,Ceramic 66 | 65,Vase,Ceramic 67 | 66,Plate_Handle,Iron 68 | 67,Plate,Iron 69 | 68,Plate_Base,Wood 70 | 69,Display_Stand,Iron 71 | 70,Drop_Funnel,Polycarbonate 72 | 71,Container_Lid,Polycarbonate 73 | 72,Food_Pan,Polycarbonate 74 | 73,Flowerpot_Large,Ceramic 75 | 74,Flowerpot_Small,Ceramic 76 | 75,Vase_Green,Ceramic 77 | 76,Vase_Blue,Ceramic 78 | 77,Vase_Orange,Ceramic 79 | 78,Swan_Large,Ceramic 80 | 79,Swan_Small,Ceramic 81 | 80,Spoon_Holder,Wood 82 | 81,Utensil_Container,Wood 83 | 82,Can,Glass 84 | 83,Potato_Masher,Steel 85 | 84,Skimmer,Steel 86 | 85,Pasta_Server,Steel 87 | 86,Slotted_Spoon,Steel 88 | 87,Solid_Turner,Steel 89 | 88,Ladle,Steel 90 | 89,Solid_Spoon,Steel 91 | 90,Slotted_Turner,Steel 92 | 91,Glass_Green,Glass 93 | 92,Glass_Red,Glass 94 | 93,Vase,Glass 95 | 94,Salad_Bowl,Glass 96 | 95,Scoop,Polycarbonate 97 | 96,Box_Lid,Polycarbonate 98 | 97,Stanford_Frisbee,Plastic 99 | 98,Kettlebell,Iron 100 | 99,Trim_Removal_Tool,Plastic 101 | 100,Trim_Removal_Tool_2,Plastic 102 | -------------------------------------------------------------------------------- /dataset_preproc/preprocess_yuan18_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils import WdsWriter 3 | import json 4 | import argparse 5 | 6 | def data_gen(args): 7 | parent_dir = args.path 8 | 
property_keys = [ 9 | "fuzziness", "thickness", "smoothness", "wool", "stretchiness", "endurance", "softness", 10 | "wind_resistance", "season", "wash_method", "textile_type"] 11 | 12 | with open(os.path.join(parent_dir, "..", "cloth_metadata.json"), "r") as f: 13 | material_metadata = json.load(f) 14 | 15 | # extract .mp4 to frames 16 | for cloth_idx in os.listdir(parent_dir): 17 | cloth_dir = os.path.join(parent_dir, cloth_idx) 18 | for trial_idx in os.listdir(cloth_dir): 19 | trial = os.path.join(cloth_dir, trial_idx) 20 | vid_p = os.path.join(trial, "GelSight_video.mp4") 21 | vid_frames_p = os.path.join(trial, "gsframes") 22 | if os.path.exists(vid_p): 23 | print(f"Unpacking {vid_p}") 24 | os.makedirs(vid_frames_p, exist_ok=True) 25 | os.system(f"ffmpeg -i {vid_p} -vf fps=30 {vid_frames_p}/frame%06d.jpg") 26 | 27 | img_paths, img_names = [], [] 28 | properties = {k: [] for k in property_keys} 29 | 30 | for cloth_idx in os.listdir(parent_dir): 31 | cloth_dir = os.path.join(parent_dir, cloth_idx) 32 | for trial_idx in os.listdir(cloth_dir): 33 | trial = os.path.join(cloth_dir, trial_idx) 34 | vid_frames_p = os.path.join(trial, "gsframes") 35 | for frame in os.listdir(vid_frames_p): 36 | if not frame.endswith(".jpg"): 37 | continue 38 | new_frame_name = f"{cloth_idx}_{trial_idx}_{frame}" 39 | img_paths.append(os.path.join(vid_frames_p, frame)) 40 | img_names.append(new_frame_name) 41 | for i, k in enumerate(property_keys): 42 | properties[k].append(material_metadata[str(cloth_idx)][i]) 43 | 44 | print(f"Writing {len(img_paths)} images to webdataset") 45 | output_dir = args.output_folder 46 | domain_dir = os.path.join(output_dir, "yuan18") 47 | os.makedirs(domain_dir, exist_ok=True) 48 | wds = WdsWriter(shard_size=args.shard_size) 49 | wds.add_imgs(img_paths=img_paths, img_names=img_names) 50 | wds.add_labels(**properties) 51 | wds.save(domain_dir) 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser(description='Dataset pre-processor for Yuan18 dataset') 55 | parser.add_argument('-O', '--output_folder', type=str, 56 | help='Output folder for the pre-processed dataset') 57 | parser.add_argument('-S', '--shard_size', type=int, default=10000, 58 | help='Maximum number of samples in a single WDS shard.') 59 | parser.add_argument('--path', type=str, 60 | help='Path to the dataset') 61 | args = parser.parse_args() 62 | 63 | data_gen(args) -------------------------------------------------------------------------------- /t3/utils.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from PIL import Image 6 | 7 | class bcolors: 8 | HEADER = '\033[95m' 9 | OKBLUE = '\033[94m' 10 | OKCYAN = '\033[96m' 11 | OKGREEN = '\033[92m' 12 | WARNING = '\033[93m' 13 | FAIL = '\033[91m' 14 | ENDC = '\033[0m' 15 | BOLD = '\033[1m' 16 | UNDERLINE = '\033[4m' 17 | 18 | color_style = { 19 | 1: "{}", 20 | 2: bcolors.OKBLUE + "{}" + bcolors.ENDC, 21 | 3: bcolors.OKGREEN + "{}" + bcolors.ENDC, 22 | 4: bcolors.OKCYAN + "{}" + bcolors.ENDC, 23 | 0: bcolors.WARNING + "{}" + bcolors.ENDC, 24 | -1: bcolors.FAIL + "{}" + bcolors.ENDC, 25 | "blue": bcolors.OKBLUE + "{}" + bcolors.ENDC, 26 | "green": bcolors.OKGREEN + "{}" + bcolors.ENDC, 27 | "cyan": bcolors.OKCYAN + "{}" + bcolors.ENDC, 28 | "warning": bcolors.WARNING + "{}" + bcolors.ENDC, 29 | "red": bcolors.FAIL + "{}" + bcolors.ENDC, 30 | "bold": bcolors.BOLD + "{}" + bcolors.ENDC, 31 | } 32 | 33 
| def logging(s, verbose=True, style=1): 34 | if not verbose: 35 | return 36 | print(color_style[style].format(s)) 37 | 38 | def is_dist_avail_and_initialized(): 39 | if not dist.is_available(): 40 | return False 41 | if not dist.is_initialized(): 42 | return False 43 | return True 44 | 45 | def get_world_size(): 46 | if not is_dist_avail_and_initialized(): 47 | return 1 48 | return dist.get_world_size() 49 | 50 | def get_rank(): 51 | if not is_dist_avail_and_initialized(): 52 | return 0 53 | return dist.get_rank() 54 | 55 | def is_main_process(): 56 | return get_rank() == 0 57 | 58 | def load_label_from_csv(path: str): 59 | """Load a csv file with pandas""" 60 | return pd.read_csv(path, index_col=0, header=0, sep=",") 61 | 62 | def get_entry_or(cfg, key, default): 63 | if key in cfg: 64 | return cfg[key] 65 | return default 66 | 67 | def make_dataset_pie_plot(d, title=None, show=False): 68 | domains = [] 69 | traj_nums = [] 70 | for k, v in d.items(): 71 | domains.append(f"{k} - {v // 1000}K") 72 | traj_nums.append(v) 73 | domains = np.array(domains) 74 | traj_nums = np.array(traj_nums) 75 | # sort by number of trajectories 76 | idx = np.argsort(traj_nums)[::-1] 77 | domains = domains[idx] 78 | traj_nums = traj_nums[idx] 79 | # draw the dataset mixture as a pie plot 80 | fig1, ax1 = plt.subplots(figsize=(28, 10)) 81 | traj_prob = np.array(traj_nums) / np.sum(traj_nums) 82 | patches, _ = ax1.pie(traj_prob, startangle=90) 83 | ax1.axis("equal") 84 | ax1.legend(patches, domains, loc="center left", bbox_to_anchor=(0.7, 0.5), prop={"size": 25}) 85 | if title is not None: 86 | ax1.set_title(title, fontsize=60) 87 | if show: 88 | plt.show() 89 | fig1.canvas.draw() 90 | return Image.frombytes("RGB", fig1.canvas.get_width_height(), fig1.canvas.tostring_rgb()) -------------------------------------------------------------------------------- /dataset_preproc/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import webdataset as wds 3 | import json 4 | import numpy as np 5 | 6 | class WdsWriter: 7 | def __init__(self, shard_size, random_order=True): 8 | self.shard_size = shard_size 9 | self.random_order = random_order 10 | self.N = 0 11 | 12 | def add_imgs(self, img_paths, img_names=None): 13 | """ 14 | Pass in img_names in case a different key is needed for the images. 15 | Otherwise, the file name will be used as the key. 16 | """ 17 | self.N = len(img_paths) 18 | if self.random_order: 19 | self.order = np.random.permutation(self.N) 20 | else: 21 | self.order = np.arange(self.N) 22 | self.img_paths = [img_paths[i] for i in self.order] 23 | if img_names is None: 24 | self.img_names = [img_path.split('/')[-1] for img_path in self.img_paths] 25 | else: 26 | assert len(img_names) == self.N, "number of img_names should be the same as number of images" 27 | self.img_names = [img_names[i] for i in self.order] 28 | 29 | def add_labels(self, **kwargs): 30 | assert self.N > 0, "add images first" 31 | for v in kwargs.values(): 32 | assert len(v) == self.N, "number of labels should be the same as number of images" 33 | self.labels = [] 34 | for idx in self.order: 35 | self.labels.append({k: kwargs[k][idx] for k in kwargs}) 36 | 37 | def samples_generator(self, st, ed): 38 | for i in range(st, ed): 39 | img_path = self.img_paths[i] 40 | img_name = self.img_names[i] 41 | with open(img_path, 'rb') as f: 42 | img = f.read() 43 | # key of a sample should not contain '.' 
or extension 44 | key = os.path.splitext(img_name)[0].replace('.', '-') 45 | sample = { 46 | "__key__": key, 47 | "jpg": img, 48 | "json": json.dumps(self.labels[i]) 49 | } 50 | yield sample 51 | 52 | def save(self, output_dir, val_ratio=0.2): 53 | os.makedirs(os.path.join(output_dir, "train"), exist_ok=True) 54 | os.makedirs(os.path.join(output_dir, "val"), exist_ok=True) 55 | N = len(self.img_paths) 56 | val_st = int(N * (1 - val_ratio)) 57 | with wds.ShardWriter(os.path.join(output_dir, "train", "data-%06d.tar"), maxcount=self.shard_size) as sink: 58 | for sample in self.samples_generator(0, val_st): 59 | sink.write(sample) 60 | with open(os.path.join(output_dir, "train", "count.txt"), 'w') as f: 61 | f.write(str(val_st)) 62 | with wds.ShardWriter(os.path.join(output_dir, "val", "data-%06d.tar"), maxcount=self.shard_size) as sink: 63 | for sample in self.samples_generator(val_st, N): 64 | sink.write(sample) 65 | with open(os.path.join(output_dir, "val", "count.txt"), 'w') as f: 66 | f.write(str(N - val_st)) -------------------------------------------------------------------------------- /configs/network/pretrain1_mae.yaml: -------------------------------------------------------------------------------- 1 | # Network configuration for pretraining I with MAE 2 | 3 | # Network size corresponds to "t3_medium" but can be overwritten in code 4 | patch_size: 16 5 | encoder_embed_dim: 768 6 | encoder_heads: 12 7 | encoder_depth: 3 8 | trunk_depth: 9 9 | mask_ratio: 0.4 10 | 11 | encoders: 12 | gs_360_v2: 13 | _target_: t3.models.MAEViTEncoder 14 | mask_ratio: ${network.mask_ratio} 15 | patch_size: ${network.patch_size} 16 | embed_dim: ${network.encoder_embed_dim} 17 | depth: ${network.encoder_depth} 18 | num_heads: ${network.encoder_heads} 19 | mlp_ratio: 4. 20 | 21 | gs_green: 22 | _target_: t3.models.MAEViTEncoder 23 | mask_ratio: ${network.mask_ratio} 24 | patch_size: ${network.patch_size} 25 | embed_dim: ${network.encoder_embed_dim} 26 | depth: ${network.encoder_depth} 27 | num_heads: ${network.encoder_heads} 28 | mlp_ratio: 4. 29 | 30 | gs_black: 31 | _target_: t3.models.MAEViTEncoder 32 | mask_ratio: ${network.mask_ratio} 33 | patch_size: ${network.patch_size} 34 | embed_dim: ${network.encoder_embed_dim} 35 | depth: ${network.encoder_depth} 36 | num_heads: ${network.encoder_heads} 37 | mlp_ratio: 4. 38 | 39 | gs_tag: # for touch-and-go 40 | _target_: t3.models.MAEViTEncoder 41 | mask_ratio: ${network.mask_ratio} 42 | patch_size: ${network.patch_size} 43 | embed_dim: ${network.encoder_embed_dim} 44 | depth: ${network.encoder_depth} 45 | num_heads: ${network.encoder_heads} 46 | mlp_ratio: 4. 47 | 48 | digit: 49 | _target_: t3.models.MAEViTEncoder 50 | mask_ratio: ${network.mask_ratio} 51 | patch_size: ${network.patch_size} 52 | embed_dim: ${network.encoder_embed_dim} 53 | depth: ${network.encoder_depth} 54 | num_heads: ${network.encoder_heads} 55 | mlp_ratio: 4. 56 | 57 | mini: 58 | _target_: t3.models.MAEViTEncoder 59 | mask_ratio: ${network.mask_ratio} 60 | patch_size: ${network.patch_size} 61 | embed_dim: ${network.encoder_embed_dim} 62 | depth: ${network.encoder_depth} 63 | num_heads: ${network.encoder_heads} 64 | mlp_ratio: 4. 65 | 66 | wedge: 67 | _target_: t3.models.MAEViTEncoder 68 | mask_ratio: ${network.mask_ratio} 69 | patch_size: ${network.patch_size} 70 | embed_dim: ${network.encoder_embed_dim} 71 | depth: ${network.encoder_depth} 72 | num_heads: ${network.encoder_heads} 73 | mlp_ratio: 4. 
74 | 75 | finray: 76 | _target_: t3.models.MAEViTEncoder 77 | mask_ratio: ${network.mask_ratio} 78 | patch_size: ${network.patch_size} 79 | embed_dim: ${network.encoder_embed_dim} 80 | depth: ${network.encoder_depth} 81 | num_heads: ${network.encoder_heads} 82 | mlp_ratio: 4. 83 | 84 | shared_trunk: 85 | _target_: t3.models.TransformerTrunk 86 | embed_dim: ${network.encoder_embed_dim} 87 | depth: ${network.trunk_depth} 88 | num_heads: ${network.encoder_heads} 89 | mlp_ratio: 4. 90 | 91 | decoders: 92 | mae_recon_single: 93 | _target_: t3.models.MAEViTDecoder 94 | patch_size: ${network.patch_size} 95 | embed_dim: ${network.encoder_embed_dim} 96 | decoder_embed_dim: 512 97 | decoder_depth: 8 98 | decoder_num_heads: 16 99 | mlp_ratio: 4. 100 | loss_func: 101 | _target_: t3.models.MAEReconLoss 102 | patch_size: 16 103 | norm_pix_loss: true # true for better representation learning, false for pixel-based loss for better reconstruction aka visualization -------------------------------------------------------------------------------- /dataset_preproc/preprocess_tvl_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from utils import WdsWriter 4 | import pandas as pd 5 | import json 6 | import argparse 7 | 8 | def data_gen_create_webdataset(args) -> None: 9 | """ 10 | Create data_gen_create_webdataset 11 | """ 12 | output_dir = args.output_folder 13 | domain = "tvl" 14 | domain_dir = os.path.join(output_dir, domain) 15 | os.makedirs(domain_dir, exist_ok=True) 16 | img_paths, img_names, has_contact, desc = [], [], [], [] 17 | 18 | # Add in the HCT dataset 19 | for path in [os.path.join(args.hct_path, f"data{i}") for i in range(1, 4)]: 20 | train_dict = pd.read_csv(os.path.join(path, "train.csv"), header=0, sep=",") 21 | test_dict = pd.read_csv(os.path.join(path, "test.csv"), header=0, sep=",") 22 | d = pd.concat([train_dict, test_dict], ignore_index=True) 23 | for i in range(len(d)): 24 | img_paths.append(os.path.join(path, d.iloc[i]["tactile"])) 25 | img_names.append(f"{path.split('/')[-1]}_{d.iloc[i]['tactile'].split('/')[-1]}") 26 | has_contact.append(True) 27 | desc.append(d.iloc[i]["caption"]) 28 | # sample non-contact images 29 | N = int(args.hct_no_contact_ratio * len(d)) 30 | with open(os.path.join(path, "not_contact.json"), "r") as f: 31 | non_contact = json.load(f)["tactile"] 32 | for fn in np.random.choice(non_contact, N, replace=False): 33 | img_paths.append(os.path.join(path, fn)) 34 | img_names.append(f"{path.split('/')[-1]}_{fn.split('/')[-1]}") 35 | has_contact.append(False) 36 | desc.append("No contact") 37 | print("Processing {} for the HCT dataset...".format(path)) 38 | print("Processed HCT dataset. Total images so far: ", len(img_names)) 39 | 40 | # Add in the SSVTP dataset 41 | d = pd.read_csv(os.path.join(args.ssvtp_path, "train.csv"), index_col=0, header=0, sep=",") 42 | for i in range(len(d)): 43 | img_paths.append(os.path.join(args.ssvtp_path, d.iloc[i]["tactile"])) 44 | img_names.append(f"ssvtp_{d.iloc[i]['tactile'].split('/')[-1]}") 45 | has_contact.append(True) 46 | desc.append(d.iloc[i]["caption"]) 47 | print("Processed SSVTP dataset. 
Total images so far: ", len(img_names)) 48 | 49 | wds = WdsWriter(shard_size=args.shard_size) 50 | wds.add_imgs(img_paths=img_paths, img_names=img_names) 51 | wds.add_labels(has_contact=has_contact, desc=desc) 52 | wds.save(domain_dir) 53 | print("Labels saved for ", domain, " Total images: ", len(img_names)) 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser(description='Dataset pre-processor for TVL dataset') 57 | parser.add_argument('-O', '--output_folder', type=str, 58 | help='Output folder for the pre-processed dataset') 59 | parser.add_argument('-S', '--shard_size', type=int, default=10000, 60 | help='Maximum number of samples in a single WDS shard.') 61 | parser.add_argument('--hct_path', type=str, 62 | help='Path to the HCT dataset') 63 | parser.add_argument('--ssvtp_path', type=str, 64 | help='Path to the SSVTP dataset') 65 | parser.add_argument('--hct_no_contact_ratio', type=float, default=1.0, 66 | help='The HCT dataset also contains no-contact images. This ratio specifies the fraction of no-contact images to include in the dataset.') 67 | args = parser.parse_args() 68 | data_gen_create_webdataset(args) 69 | -------------------------------------------------------------------------------- /t3/_calandra17_label_dict.py: -------------------------------------------------------------------------------- 1 | # Label dict for calandra17 object classification 2 | 3 | CALANDRA17_LABEL_DICT = { 4 | 'no_contact': 0, 5 | 'soft_blue_cylinder': 1, 6 | 'metal_can': 2, 7 | '3d_printed_blue_connector': 3, 8 | 'green_and_black_sphere': 4, 9 | 'lemon': 5, 10 | 'cow': 6, 11 | 'wiry_sphere': 7, 12 | 'rubics_cube': 8, 13 | 'plastic_cow': 9, 14 | 'black_beans_bag': 10, 15 | 'fake_flower_in_pot': 11, 16 | 'sheep': 12, 17 | '3d_printed_blue_vase': 13, 18 | 'cinnamon': 14, 19 | 'brown_paper_cup': 15, 20 | '3d_printed_white_ball': 16, 21 | 'candle_in_glass': 17, 22 | 'calcium_antacid': 18, 23 | 'ogx_shampoo': 19, 24 | 'fox_head': 20, 25 | 'bandaid_box': 21, 26 | 'webcam_box': 22, 27 | 'spam': 23, 28 | 'axe_body_spray': 24, 29 | 'plastic_duck': 25, 30 | 'red_turtle': 26, 31 | 'bag_pack': 27, 32 | 'small_coffe_cup': 28, 33 | 'mesh_container': 29, 34 | 'dark_blue_sphere': 30, 35 | 'set_small_plastic_men_blue_guy': 31, 36 | 'onion': 32, 37 | 'hair_dryer_spiky_nozzle': 33, 38 | 'french_dip': 34, 39 | 'muffin': 35, 40 | 'plastic_watering_can': 36, 41 | 'moroccan_mint_tea_box': 37, 42 | 'yellow_wooden_robot': 38, 43 | 'peppermint_altoids_box': 39, 44 | 'creamy_petroleum_jelly': 40, 45 | 'soft_blue_hexagon': 41, 46 | 'glass_candle_holder': 42, 47 | 'tomato_paste_in_metal_can': 43, 48 | 'green_plastic_cup': 44, 49 | 'board_eraser': 45, 50 | 'plastic_sheep': 46, 51 | 'stuffed_beachball': 47, 52 | 'red_apple': 48, 53 | 'set_small_plastic_men_green_guy': 49, 54 | 'edge_shave_gel': 50, 55 | 'egg_crate_foam': 51, 56 | 'brown_paper_cup_2_upside': 52, 57 | 'isopropyl_alcohol': 53, 58 | 'plastic_mushroom': 54, 59 | 'monster_truck': 55, 60 | 'fake_plastic_transformer_toy': 56, 61 | 'soft_red_cube': 57, 62 | 'mentos_gum_can': 58, 63 | 'aspirin': 59, 64 | 'angry_bird': 60, 65 | 'red_bull': 61, 66 | 'lime': 62, 67 | 'orange_plastic_castle': 63, 68 | '3d_printed_blue_house': 64, 69 | 'set_small_plastic_men_red_racer': 65, 70 | 'pino_silvestre': 66, 71 | 'soft_zebra': 67, 72 | 'white_mini_american_hat': 68, 73 | 'emergency_stop_button_for_sawyer': 69, 74 | 'pink_glass_glass': 70, 75 | 'soft_beer_bottle_holder': 71, 76 | 'kong_dog_toy': 72, 77 | 'happy_fall_stone': 73, 78 | 
'set_small_plastic_men_police_man': 74, 79 | 'blue_painted_glass': 75, 80 | 'soda_can': 76, 81 | 'set_small_plastic_men_yellow_construction_worker': 77, 82 | '3d_printed_black_cylinder_gear': 78, 83 | 'monofilament_line': 79, 84 | 'construction_worker': 80, 85 | 'blue_bottle_fuel_treatment': 81, 86 | 'pig': 82, 87 | 'plastic_chicken': 83, 88 | 'peanut_butter': 84, 89 | 'tuna_can': 85, 90 | 'durabuilt_measuring_tape': 86, 91 | 'metal_cylinder_with_holes': 87, 92 | 'peptobismol': 88, 93 | 'coffee_cup': 89, 94 | 'plastic_whale': 90, 95 | 'soft_toy_dragon': 91, 96 | 'ponds_dry_skin_cream': 92, 97 | 'baby_cup': 93, 98 | 'international_travel_adapter': 94, 99 | 'potato': 95, 100 | 'translucent_turquoise_cup': 96, 101 | 'black_metallic_candle_cage': 97, 102 | 'black_plastic_half_cylinder': 98, 103 | 'feathered_ball': 99, 104 | 'dog_toy_ice_cream_cone': 100, 105 | 'purple_small_plastic_fruit': 101, 106 | 'chocolate_shake': 102, 107 | 'playdoh_container': 103, 108 | 'pencil_case': 104, 109 | 'pink_blue_coke_bottle': 105, 110 | } 111 | -------------------------------------------------------------------------------- /dataset_preproc/README.md: -------------------------------------------------------------------------------- 1 | # Foundation Tactile (FoTa) Dataset 2 | 3 | Note: before running any data preprocessing scripts, make sure to check each script's arguments. 4 | 5 | To calculate image normalization of each dataset after preprocessing, use the script `calculate_normalization.py`. 6 | 7 | #### [VisGel](http://visgel.csail.mit.edu/) dataset -> 3,170,795 (downsampled to 726,740) tactile images for pretraining (GelSight green sensor) 8 | 9 | Download both the seen and unseen dataset from [here](https://github.com/YunzhuLi/VisGel). 10 | Prepare this dataset with 11 | ```sh 12 | python preprocess_visgel_images.py 13 | ``` 14 | 15 | Useful arguments: 16 | 17 | `num_processes=[a number]`: specifies how many processes to create to calculate `cv2.laplacian(diff_img).var()` in parallel. 18 | 19 | `seen_path` and `unseen_path` specify the paths to the data folder. 20 | 21 | Note that since a large portion of this dataset is flat images, we further downsample it to roughly 25% of its original size based on a threshold on the variance. 22 | 23 | #### [TVL](https://huggingface.co/datasets/mlfu7/Touch-Vision-Language-Dataset) dataset -> 82,463 images (DIGIT sensor) 24 | 25 | Download data from [here](https://huggingface.co/datasets/mlfu7/Touch-Vision-Language-Dataset/tree/main) 26 | Prepare this dataset with 27 | ```sh 28 | python preprocess_tvl_images.py 29 | ``` 30 | 31 | #### [Touch and Go](https://touch-and-go.github.io/) dataset -> 262,082 images (GelSight TAG sensor, 250,169 with contact) 32 | 33 | Download data from [here](https://drive.google.com/drive/folders/1NDasyshDCL9aaQzxjn_-Q5MBURRT360B) 34 | Prepare this dataset with 35 | ```sh 36 | python preprocess_touchandgo_images.py 37 | ``` 38 | 39 | This dataset is also used for a classification task (22 classes) with the `datasets/touch_and_go_classification` config. 40 | 41 | #### [Calandra corl17 - "More than a feeling"](https://sites.google.com/view/the-feeling-of-success) dataset -> 24,118 images (GelSight green sensor) 42 | 43 | Download data from [here](https://drive.google.com/drive/folders/1wHEg_RR8YAQjMnt9r5biUwo5z3P6bjR3) 44 | Prepare this dataset with 45 | ```sh 46 | python preprocess_calandra17_images.py 47 | ``` 48 | Note that this scripts requires the `deepdish` pip package to handle h5 file operations. 
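For reference, a full invocation looks roughly like the following (the paths are placeholders; `-O` is the output folder under which a `calandra17/` shard folder is created, and `--no_contact_ratio` defaults to 0.3 — see the script's `argparse` options for details):
```sh
python preprocess_calandra17_images.py \
    --path /path/to/calandra17_raw \
    -O /path/to/FoundationTactile \
    --no_contact_ratio 0.3
```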
49 | 50 | This dataset is also used for a classification task (106 classes) with the `datasets/calandra17_classification` config. 51 | 52 | #### [yuan18 - cloth](https://arxiv.org/abs/1711.00574) dataset -> 494,655 images (GelSight green sensor) 53 | 54 | Download data from - [here](http://data.csail.mit.edu/active_clothing/Data_ICRA18.tar) 55 | Prepare this dataset with 56 | ```sh 57 | python preprocess_yuan18_images.py 58 | ``` 59 | 60 | This dataset is also used for three classification tasks: textile_type (20 classes), fuzziness (4 classes), and smoothness (5 classes) with the `datasets/yuan18_classification` config. 61 | 62 | #### [YCB-Sight](https://github.com/Robo-Touch/YCB-Sight) dataset -> 480 real images and 1800 sim images (GelSight green sensor) 63 | 64 | Download data from - [here](https://drive.google.com/drive/folders/17BPST4biGzduVtoCUBswOmkISqNh1srI) 65 | Prepare this dataset with 66 | ```sh 67 | python preprocess_ycbsight_images.py 68 | ``` 69 | 70 | This dataset also contains object labels and poses. However, they are not used for tasks other than MAE because of the dataset's small size. 71 | 72 | Note that since YCB-Sight Real is significantly smaller, we only use it for evaluation during MAE. 73 | 74 | #### [ObjectFolder-Real](https://objectfolder.stanford.edu) dataset -> 1,417,600 images (GelSight black sensor) 75 | 76 | Download data from - [here](https://objectfolder.stanford.edu/objectfolder-real-download). 77 | Prepare this dataset with 78 | ```sh 79 | python preprocess_objectfolder_real_images.py 80 | ``` -------------------------------------------------------------------------------- /dataset_preproc/preprocess_objectfolder_real_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils import WdsWriter 3 | from PIL import Image 4 | from natsort import natsorted 5 | import pandas as pd 6 | import argparse 7 | 8 | def data_gen(args): 9 | domain = "objectfolder_real" 10 | parent_dir = args.path 11 | output_dir = args.output_folder 12 | domain_dir = os.path.join(output_dir, domain) 13 | 14 | labels = pd.read_csv(args.metadata_path, index_col=0, header=0, sep=",") 15 | material_lookup = {x: i for i, x in enumerate(sorted(list(set(labels["Material"]))))} 16 | 17 | img_paths, img_names = [], [] 18 | properties = { 19 | "obj_idx": [], "obj_name": [], 20 | "material_idx": [], "material": [], 21 | "trial_idx": [], 22 | } 23 | 24 | for obj_idx in natsorted(os.listdir(parent_dir)): 25 | if not os.path.isdir(os.path.join(parent_dir, obj_idx)): 26 | continue 27 | print(f"Processing obj {obj_idx}...", end="") 28 | obj_name = labels.loc[int(obj_idx)]["Name"] 29 | material = labels.loc[int(obj_idx)]["Material"] 30 | material_idx = material_lookup[material] 31 | for trial_idx in os.listdir(os.path.join(parent_dir, obj_idx, "tactile_data")): 32 | if not trial_idx.isdigit(): 33 | continue 34 | trial_p = os.path.join(parent_dir, obj_idx, "tactile_data", trial_idx) 35 | if not os.path.isdir(trial_p): 36 | continue 37 | gelsight_path = os.path.join(trial_p, "0", "gelsight") 38 | if not os.path.isdir(gelsight_path): 39 | print(f"ERROR: path {gelsight_path} does not exist. skipping obj {obj_idx} trial {trial_idx}") 40 | continue 41 | for img_fn in natsorted(os.listdir(gelsight_path)): 42 | if not img_fn.endswith(".png"): 43 | continue 44 | # this dataset's default image format is png.
convert to jpg if not already done 45 | jpg_fn = img_fn[:-4] + ".jpg" 46 | jpg_path = os.path.join(gelsight_path, jpg_fn) 47 | if not os.path.exists(jpg_path): 48 | img = Image.open(os.path.join(gelsight_path, img_fn)) 49 | img.save(jpg_path) 50 | 51 | img_paths.append(jpg_path) 52 | img_names.append(f"{obj_name}_{trial_idx}_{jpg_fn}") 53 | properties["obj_idx"].append(obj_idx) 54 | properties["obj_name"].append(obj_name) 55 | properties["material_idx"].append(material_idx) 56 | properties["material"].append(material) 57 | properties["trial_idx"].append(trial_idx) 58 | print("done. Size so far: ", len(img_paths)) 59 | 60 | print(f"Writing {len(img_paths)} images in {domain} to webdataset") 61 | os.makedirs(domain_dir, exist_ok=True) 62 | wds = WdsWriter(shard_size=args.shard_size) 63 | wds.add_imgs(img_paths=img_paths, img_names=img_names) 64 | wds.add_labels(**properties) 65 | wds.save(domain_dir) 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser(description='Dataset pre-processor for ObjectFolder-Real dataset') 69 | parser.add_argument('-O', '--output_folder', type=str, 70 | help='Output folder for the pre-processed dataset') 71 | parser.add_argument('-S', '--shard_size', type=int, default=10000, 72 | help='Maximum number of samples in a single WDS shard.') 73 | parser.add_argument('--path', type=str, 74 | help='Path to the dataset') 75 | parser.add_argument('--metadata_path', type=str, default="objectfolder_real_metadata.csv", 76 | help='Path to the objectfolder_real_metadata.csv metadata file') 77 | args = parser.parse_args() 78 | 79 | data_gen(args) -------------------------------------------------------------------------------- /t3/models/trunk.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trunk definition for Transferable Tactile Transformer (T3) 3 | 4 | Author: Jialiang (Alan) Zhao 5 | Email: alanzhao@csail.mit.edu 6 | MIT License 7 | """ 8 | 9 | import os 10 | import torch 11 | from torch import nn 12 | from typing import Literal 13 | from .nn_utils import makeMLP, get_device 14 | from t3.utils import logging 15 | import timm.models.vision_transformer as timm_vit 16 | 17 | class Trunk(nn.Module): 18 | def __init__(self, **kwargs): 19 | super().__init__() 20 | 21 | def freeze(self): 22 | for param in self.parameters(): 23 | param.requires_grad = False 24 | 25 | def unfreeze(self): 26 | for param in self.parameters(): 27 | param.requires_grad = True 28 | 29 | def save(self, path): 30 | torch.save(self.state_dict(), path) 31 | 32 | def load(self, path): 33 | kwargs = {} 34 | if not torch.cuda.is_available(): 35 | kwargs['map_location'] = get_device() 36 | if os.path.exists(path): 37 | logging(f"Loading trunk from weights from {path}", True, "green") 38 | self.load_state_dict(torch.load(path, **kwargs)) 39 | else: 40 | logging(f"Trunk weights not found at {path}. 
Skipping", True, "warning") 41 | 42 | class IdentityTrunk(Trunk): 43 | def __init__(self, **kwargs): 44 | super().__init__() 45 | 46 | def forward(self, x): 47 | return x 48 | 49 | class MLPTrunk(Trunk): 50 | def __init__(self, 51 | input_dim, 52 | output_dim, 53 | hidden_dims, 54 | dropout_p=0.1, 55 | tanh_end=False, 56 | ln=False, 57 | **kwargs): 58 | super().__init__() 59 | 60 | self.model = makeMLP(input_dim, output_dim, hidden_dims, dropout_p, tanh_end, ln) 61 | 62 | def forward(self, x): 63 | return self.model(x) 64 | 65 | 66 | class TransformerTrunk(Trunk): 67 | """ 68 | Transformer with only intermediate blocks and a final normalization layer 69 | """ 70 | def __init__(self, embed_dim=768, depth=9, num_heads=12, 71 | mlp_ratio=4., norm_layer=nn.LayerNorm, 72 | pooling_type: Literal['none', 'global', 'cls'] = 'none', 73 | **kwargs): 74 | super().__init__() 75 | 76 | self.blocks = nn.ModuleList([ 77 | timm_vit.Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 78 | for i in range(depth)]) 79 | self.norm = norm_layer(embed_dim) 80 | 81 | self.pooling_type = pooling_type 82 | 83 | self.apply(self._init_weights) 84 | 85 | def _init_weights(self, m): 86 | if isinstance(m, nn.Linear): 87 | # we use xavier_uniform following official JAX ViT: 88 | torch.nn.init.xavier_uniform_(m.weight) 89 | if isinstance(m, nn.Linear) and m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | elif isinstance(m, nn.LayerNorm): 92 | nn.init.constant_(m.bias, 0) 93 | nn.init.constant_(m.weight, 1.0) 94 | 95 | def forward(self, x): 96 | is_mae = False 97 | if isinstance(x, tuple): 98 | (x, mask, ids_restore) = x 99 | is_mae = True 100 | # apply Transformer blocks 101 | for blk in self.blocks: 102 | x = blk(x) 103 | 104 | if self.pooling_type == 'none': 105 | x = self.norm(x) 106 | elif self.pooling_type == 'global': 107 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 108 | #TODO: maybe add another norm layer here 109 | elif self.pooling_type == 'cls': 110 | x = self.norm(x) 111 | x = x[:, 0] 112 | 113 | if is_mae: 114 | return (x, mask, ids_restore) 115 | else: 116 | return x 117 | -------------------------------------------------------------------------------- /dataset_preproc/preprocess_calandra17_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from utils import WdsWriter 4 | import deepdish as dd 5 | import cv2 6 | import csv 7 | import pandas as pd 8 | from PIL import Image 9 | import argparse 10 | 11 | def extract_frames(dir, extracted_contact_dir, extracted_ref_dir, labels, ref_ratio): 12 | video1_dir = str(dir) + '/' + 'video.mp4' 13 | video2_dir = str(dir) + '/' + 'gelsight.mp4' 14 | 15 | cap1 = cv2.VideoCapture(video1_dir) 16 | frame_number1 = int(cap1.get(7)) 17 | 18 | cap2 = cv2.VideoCapture(video2_dir) 19 | frame_number2 = int(cap2.get(7)) 20 | 21 | frame_number1 = min(frame_number1, frame_number2) 22 | 23 | for i in range(frame_number1): 24 | # cap1.set(cv2.CAP_PROP_POS_FRAMES, i) 25 | cap2.set(cv2.CAP_PROP_POS_FRAMES, i) 26 | # _, frame1 = cap1.read() 27 | _, frame2 = cap2.read() 28 | fname = os.path.join(os.path.basename(dir), str(i).rjust(10,'0') + '.jpg') 29 | if fname in labels.index: 30 | # has contact 31 | cv2.imwrite(os.path.join(extracted_contact_dir, fname.replace("/", "-")), frame2) 32 | else: 33 | # no contact 34 | if np.random.rand() < ref_ratio: 35 | cv2.imwrite(os.path.join(extracted_ref_dir, fname.replace("/", "-")), frame2) 36 | # cv2.imwrite(str(dir) + 
'/video_frame/' + str(i).rjust(10,'0') + '.jpg', frame1) 37 | # cv2.imwrite(str(dir) + '/gelsight_frame/' + str(i).rjust(10,'0') + '.jpg', frame2) 38 | 39 | cap1.release() 40 | cap2.release() 41 | 42 | def data_gen_create_webdataset(args) -> None: 43 | """ 44 | Create data_gen_create_webdataset 45 | """ 46 | output_dir = args.output_folder 47 | domain = "calandra17" 48 | extracted_dir = os.path.join(args.path, "extracted") 49 | 50 | def _create_img(entry, extracted_dir, dataset, idx, field): 51 | img = Image.fromarray(entry[field]) 52 | k = f"{os.path.splitext(dataset)[0]}_{idx}_{field}.jpg" 53 | img.save(os.path.join(extracted_dir, k)) 54 | return [k, entry["object_name"].decode("utf-8"), entry["is_gripping"], "during" in field] 55 | 56 | # extract frames 57 | os.makedirs(extracted_dir, exist_ok=True) 58 | labels = [["path", "object_name", "is_gripping", "has_contact"]] 59 | 60 | datasets = [f for f in os.listdir(os.path.join(args.path, "dataset")) if f.endswith(".h5")] 61 | for dataset in datasets: 62 | print(f"Extracting frames for {dataset}") 63 | entries = dd.io.load(os.path.join(args.path, "dataset", dataset)) 64 | for idx, entry in enumerate(entries): 65 | l = _create_img(entry, extracted_dir, dataset, idx, "gelsightA_during") 66 | labels.append(l) 67 | 68 | l = _create_img(entry, extracted_dir, dataset, idx, "gelsightB_during") 69 | labels.append(l) 70 | 71 | if np.random.rand() < args.no_contact_ratio: 72 | l = _create_img(entry, extracted_dir, dataset, idx, "gelsightA_before") 73 | labels.append(l) 74 | 75 | l = _create_img(entry, extracted_dir, dataset, idx, "gelsightB_before") 76 | labels.append(l) 77 | with open(os.path.join(extracted_dir, "labels.csv"), "w") as f: 78 | writer = csv.writer(f) 79 | writer.writerows(labels) 80 | 81 | domain_dir = os.path.join(output_dir, domain) 82 | os.makedirs(domain_dir, exist_ok=True) 83 | 84 | img_paths, img_names, has_contact, object_name, is_gripping = [], [], [], [], [] 85 | assert os.path.exists(extracted_dir), f"Extracted dir {extracted_dir} does not exist" 86 | labels = pd.read_csv(os.path.join(extracted_dir, "labels.csv"), header=0, sep=",") 87 | for i in range(len(labels)): 88 | img_paths.append(os.path.join(extracted_dir, labels.iloc[i]["path"])) 89 | img_names.append(labels.iloc[i]["path"]) 90 | has_contact.append(bool(labels.iloc[i]["has_contact"])) 91 | object_name.append(labels.iloc[i]["object_name"]) 92 | is_gripping.append(bool(labels.iloc[i]["is_gripping"])) 93 | 94 | print(f"Total images: {len(img_paths)}") 95 | 96 | wds = WdsWriter(shard_size=args.shard_size) 97 | wds.add_imgs(img_paths=img_paths, img_names=img_names) 98 | wds.add_labels(has_contact=has_contact, object_name=object_name, is_gripping=is_gripping) 99 | wds.save(domain_dir) 100 | print("Labels saved for ", domain, " Total images: ", len(img_names)) 101 | 102 | if __name__ == "__main__": 103 | parser = argparse.ArgumentParser(description='Dataset pre-processor for Calandra17 dataset') 104 | parser.add_argument('-O', '--output_folder', type=str, 105 | help='Output folder for the pre-processed dataset') 106 | parser.add_argument('-S', '--shard_size', type=int, default=10000, 107 | help='Maximum number of samples in a single WDS shard.') 108 | parser.add_argument('--path', type=str, 109 | help='Path to the dataset') 110 | parser.add_argument('--no_contact_ratio', type=float, default=0.3, 111 | help='This dataset also contains no-contact images. 
This ratio specifies the fraction of no-contact images to include in the dataset.') 112 | args = parser.parse_args() 113 | data_gen_create_webdataset(args) 114 | -------------------------------------------------------------------------------- /t3/models/t3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transferable Tactile Transformer (T3) 3 | 4 | Author: Jialiang (Alan) Zhao 5 | Email: alanzhao@csail.mit.edu 6 | MIT License 7 | """ 8 | 9 | import hydra 10 | 11 | import os 12 | from torch import nn 13 | from t3.utils import logging 14 | 15 | class T3(nn.Module): 16 | def __init__(self, cfg, **kwargs): 17 | super().__init__() 18 | self.cfg = cfg 19 | self.encoders = {} 20 | self.decoders = {} 21 | self.loss_funcs = {} 22 | self.trunk = hydra.utils.instantiate(cfg.shared_trunk) 23 | self._is_trunk_transformer = "Transformer" in cfg.shared_trunk._target_ 24 | 25 | for name, encoder_cfg in cfg.encoders.items(): 26 | self.encoders[name] = hydra.utils.instantiate(encoder_cfg) 27 | 28 | for name, decoder_cfg in cfg.decoders.items(): 29 | self.decoders[name] = hydra.utils.instantiate(decoder_cfg) 30 | if hasattr(decoder_cfg, "loss_func"): 31 | self.loss_funcs[name] = hydra.utils.instantiate(decoder_cfg.loss_func) 32 | else: 33 | self.loss_funcs[name] = None 34 | 35 | self.encoders = nn.ModuleDict(self.encoders) 36 | self.decoders = nn.ModuleDict(self.decoders) 37 | self.loss_funcs = nn.ModuleDict(self.loss_funcs) 38 | self._encoder_domain = None 39 | self._decoder_domain = None 40 | 41 | def model_summary(self): 42 | print("==========================================") 43 | encoder_parameters = sum(p.numel() for p in self.encoders.parameters() if p.requires_grad) 44 | trunk_parameters = sum(p.numel() for p in self.trunk.parameters() if p.requires_grad) 45 | decoder_parameters = sum(p.numel() for p in self.decoders.parameters() if p.requires_grad) 46 | n_parameters = encoder_parameters + trunk_parameters + decoder_parameters 47 | logging( 48 | f"number of total trainable params (M): {n_parameters / 1.0e6:.3f} \n\ 49 | encoder: {encoder_parameters / 1.0e6:.3f} \n\ 50 | trunk: {trunk_parameters / 1.0e6:.3f} \n\ 51 | decoder: {decoder_parameters / 1.0e6:.3f}", True, "green") 52 | 53 | def set_domains(self, encoder_domain, decoder_domain, forward_mode): 54 | assert encoder_domain in self.encoders, f"encoder domain {encoder_domain} not found in encoders" 55 | assert decoder_domain in self.decoders, f"decoder domain {decoder_domain} not found in decoders" 56 | self._encoder_domain = encoder_domain 57 | self._decoder_domain = decoder_domain 58 | self._forward_mode = forward_mode 59 | 60 | def freeze_encoder(self, encoder_domain=None): 61 | if encoder_domain is None: 62 | for encoder in self.encoders.values(): 63 | encoder.freeze() 64 | else: 65 | assert encoder_domain in self.encoders, f"encoder domain {encoder_domain} not found in encoders" 66 | self.encoders[encoder_domain].freeze() 67 | 68 | def unfreeze_encoder(self, encoder_domain=None): 69 | if encoder_domain is None: 70 | for encoder in self.encoders.values(): 71 | encoder.unfreeze() 72 | else: 73 | assert encoder_domain in self.encoders, f"encoder domain {encoder_domain} not found in encoders" 74 | self.encoders[encoder_domain].unfreeze() 75 | 76 | def freeze_trunk(self): 77 | self.trunk.freeze() 78 | 79 | def unfreeze_trunk(self): 80 | self.trunk.unfreeze() 81 | 82 | def forward(self, *args, **kwargs): 83 | if self._forward_mode == "single_tower": 84 | return self.single_tower_forward(*args, 
**kwargs) 85 | elif self._forward_mode == "multi_tower": 86 | return self.multi_tower_forward(*args, **kwargs) 87 | else: 88 | raise ValueError(f"forward mode {self._forward_mode} not recognized") 89 | 90 | def single_tower_forward(self, x): 91 | x = self.encoders[self._encoder_domain](x) 92 | x = self.trunk(x) 93 | x = self.decoders[self._decoder_domain](x) 94 | return x 95 | 96 | def multi_tower_forward(self, *xs): 97 | xs = [self.encoders[self._encoder_domain](x) for x in xs] 98 | xs = [self.trunk(x) for x in xs] 99 | x = self.decoders[self._decoder_domain](*xs) 100 | return x 101 | 102 | def compute_loss(self, y_pred, y_true): 103 | return self.loss_funcs[self._decoder_domain](y_pred, y_true) 104 | 105 | def save_components(self, dir): 106 | os.makedirs(f"{dir}/encoders", exist_ok=True) 107 | os.makedirs(f"{dir}/decoders", exist_ok=True) 108 | for encoder_name, encoder in self.encoders.items(): 109 | encoder.save(f"{dir}/encoders/{encoder_name}.pth") 110 | for decoder_name, decoder in self.decoders.items(): 111 | decoder.save(f"{dir}/decoders/{decoder_name}.pth") 112 | self.trunk.save(f"{dir}/trunk.pth") 113 | 114 | def load_components(self, dir): 115 | for encoder_name, encoder in self.encoders.items(): 116 | encoder.load(f"{dir}/encoders/{encoder_name}.pth") 117 | for decoder_name, decoder in self.decoders.items(): 118 | decoder.load(f"{dir}/decoders/{decoder_name}.pth") 119 | self.trunk.load(f"{dir}/trunk.pth") 120 | 121 | def make_T3_tiny(cfg): 122 | return T3(cfg) -------------------------------------------------------------------------------- /dataset_preproc/preprocess_touchandgo_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from utils import WdsWriter 4 | import pandas as pd 5 | import cv2 6 | import argparse 7 | 8 | def extract_frames(dir, extracted_contact_dir, extracted_ref_dir, labels, ref_ratio): 9 | video1_dir = str(dir) + '/' + 'video.mp4' 10 | video2_dir = str(dir) + '/' + 'gelsight.mp4' 11 | 12 | cap1 = cv2.VideoCapture(video1_dir) 13 | frame_number1 = int(cap1.get(7)) 14 | 15 | cap2 = cv2.VideoCapture(video2_dir) 16 | frame_number2 = int(cap2.get(7)) 17 | 18 | frame_number1 = min(frame_number1, frame_number2) 19 | 20 | for i in range(frame_number1): 21 | # cap1.set(cv2.CAP_PROP_POS_FRAMES, i) 22 | cap2.set(cv2.CAP_PROP_POS_FRAMES, i) 23 | # _, frame1 = cap1.read() 24 | _, frame2 = cap2.read() 25 | fname = os.path.join(os.path.basename(dir), str(i).rjust(10,'0') + '.jpg') 26 | if fname in labels.index: 27 | # has contact 28 | cv2.imwrite(os.path.join(extracted_contact_dir, fname.replace("/", "-")), frame2) 29 | else: 30 | # no contact 31 | if np.random.rand() < ref_ratio: 32 | cv2.imwrite(os.path.join(extracted_ref_dir, fname.replace("/", "-")), frame2) 33 | # cv2.imwrite(str(dir) + '/video_frame/' + str(i).rjust(10,'0') + '.jpg', frame1) 34 | # cv2.imwrite(str(dir) + '/gelsight_frame/' + str(i).rjust(10,'0') + '.jpg', frame2) 35 | 36 | cap1.release() 37 | cap2.release() 38 | 39 | def data_gen_create_webdataset(args) -> None: 40 | """ 41 | Create data_gen_create_webdataset 42 | """ 43 | output_dir = args.output_folder 44 | domain = "touch_and_go" 45 | 46 | # load labels 47 | labels = pd.read_csv(os.path.join(args.path, "label.txt"), header=None, sep=",") 48 | labels.set_index(0, inplace=True) 49 | 50 | # load label references 51 | labels_ref = {} 52 | with open(os.path.join(args.path, "category_reference.txt"), "r") as f: 53 | for line in f: 54 | v, k = line.strip().split(":") 55 | 
if "(" in k: 56 | k = k[:k.find("(")].strip() 57 | labels_ref[int(k)] = v.lower().replace("'", "") 58 | 59 | extracted_contact_dir = os.path.join(args.path, "extracted", "contact") 60 | extracted_ref_dir = os.path.join(args.path, "extracted", "nocontact") 61 | 62 | # extract frames from video 63 | os.makedirs(extracted_contact_dir, exist_ok=True) 64 | os.makedirs(extracted_ref_dir, exist_ok=True) 65 | 66 | d = os.path.join(args.path, "dataset") 67 | folders = [f for f in os.listdir(d) if os.path.isdir(os.path.join(d, f))] 68 | for folder in folders: 69 | print(f"Extracting frames from {folder}...") 70 | extract_frames(os.path.join(d, folder), extracted_contact_dir, extracted_ref_dir, labels, ref_ratio=0.01) 71 | 72 | 73 | domain_dir = os.path.join(output_dir, domain) 74 | os.makedirs(domain_dir, exist_ok=True) 75 | 76 | img_paths, img_names, has_contact, contact_class, contact_name = [], [], [], [], [] 77 | assert os.path.exists(extracted_contact_dir), f"Extracted contact dir {extracted_contact_dir} does not exist" 78 | assert os.path.exists(extracted_ref_dir), f"Extracted ref dir {extracted_ref_dir} does not exist" 79 | for f in os.listdir(extracted_contact_dir): 80 | img_paths.append(os.path.join(extracted_contact_dir, f)) 81 | img_names.append(f) 82 | has_contact.append(True) 83 | class_id = int(labels.loc[f.replace("-", "/")][1]) 84 | if class_id in labels_ref: 85 | class_name = labels_ref[class_id] 86 | else: 87 | class_name = "unknown" 88 | contact_class.append(class_id) 89 | contact_name.append(class_name) 90 | 91 | N_contact = len(img_paths) 92 | 93 | ref_files = [f for f in os.listdir(extracted_ref_dir) if f.endswith(".jpg")] 94 | np.random.shuffle(ref_files) 95 | ref_files = ref_files[:int(len(img_paths) * args.no_contact_ratio)] 96 | for ref_f in ref_files: 97 | img_paths.append(os.path.join(extracted_ref_dir, ref_f)) 98 | img_names.append(ref_f) 99 | has_contact.append(False) 100 | contact_class.append(-2) 101 | contact_name.append("nocontact") 102 | 103 | print(f"Total images: {len(img_paths)} in which {N_contact} has contact") 104 | 105 | wds = WdsWriter(shard_size=args.shard_size) 106 | wds.add_imgs(img_paths=img_paths, img_names=img_names) 107 | wds.add_labels(has_contact=has_contact, contact_class=contact_class, contact_name=contact_name) 108 | wds.save(domain_dir) 109 | print("Labels saved for ", domain, " Total images: ", len(img_names)) 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser(description='Dataset pre-processor for Touch-and-Go dataset') 113 | parser.add_argument('-O', '--output_folder', type=str, 114 | help='Output folder for the pre-processed dataset') 115 | parser.add_argument('-S', '--shard_size', type=int, default=10000, 116 | help='Maximum number of samples in a single WDS shard.') 117 | parser.add_argument('--path', type=str, 118 | help='Path to the dataset') 119 | parser.add_argument('--no_contact_ratio', type=float, default=0.3, 120 | help='This dataset also contains no-contact images. 
This ratio specifies the fraction of no-contact images to include in the dataset.') 121 | args = parser.parse_args() 122 | data_gen_create_webdataset(args) 123 | -------------------------------------------------------------------------------- /configs/network/pretrain2.yaml: -------------------------------------------------------------------------------- 1 | # Network configuration for pretraining II 2 | 3 | # Network size corresponds to "t3_medium" but can be overwritten in code 4 | patch_size: 16 5 | encoder_embed_dim: 768 6 | encoder_heads: 12 7 | pooling: "none" 8 | encoder_depth: 3 9 | trunk_depth: 9 10 | 11 | encoders: 12 | gs_360_v2: 13 | _target_: t3.models.ViTEncoder 14 | patch_size: ${network.patch_size} 15 | embed_dim: ${network.encoder_embed_dim} 16 | depth: ${network.encoder_depth} 17 | num_heads: ${network.encoder_heads} 18 | mlp_ratio: 4. 19 | 20 | gs_green: 21 | _target_: t3.models.ViTEncoder 22 | patch_size: ${network.patch_size} 23 | embed_dim: ${network.encoder_embed_dim} 24 | depth: ${network.encoder_depth} 25 | num_heads: ${network.encoder_heads} 26 | mlp_ratio: 4. 27 | 28 | gs_black: # for objectfolder-real 29 | _target_: t3.models.ViTEncoder 30 | patch_size: ${network.patch_size} 31 | embed_dim: ${network.encoder_embed_dim} 32 | depth: ${network.encoder_depth} 33 | num_heads: ${network.encoder_heads} 34 | mlp_ratio: 4. 35 | 36 | gs_tag: # for touch-and-go 37 | _target_: t3.models.ViTEncoder 38 | patch_size: ${network.patch_size} 39 | embed_dim: ${network.encoder_embed_dim} 40 | depth: ${network.encoder_depth} 41 | num_heads: ${network.encoder_heads} 42 | mlp_ratio: 4. 43 | 44 | digit: 45 | _target_: t3.models.ViTEncoder 46 | patch_size: ${network.patch_size} 47 | embed_dim: ${network.encoder_embed_dim} 48 | depth: ${network.encoder_depth} 49 | num_heads: ${network.encoder_heads} 50 | mlp_ratio: 4. 51 | 52 | mini: 53 | _target_: t3.models.ViTEncoder 54 | patch_size: ${network.patch_size} 55 | embed_dim: ${network.encoder_embed_dim} 56 | depth: ${network.encoder_depth} 57 | num_heads: ${network.encoder_heads} 58 | mlp_ratio: 4. 59 | 60 | wedge: 61 | _target_: t3.models.ViTEncoder 62 | patch_size: ${network.patch_size} 63 | embed_dim: ${network.encoder_embed_dim} 64 | depth: ${network.encoder_depth} 65 | num_heads: ${network.encoder_heads} 66 | mlp_ratio: 4. 67 | 68 | finray: 69 | _target_: t3.models.ViTEncoder 70 | patch_size: ${network.patch_size} 71 | embed_dim: ${network.encoder_embed_dim} 72 | depth: ${network.encoder_depth} 73 | num_heads: ${network.encoder_heads} 74 | mlp_ratio: 4. 75 | 76 | shared_trunk: 77 | _target_: t3.models.TransformerTrunk 78 | embed_dim: ${network.encoder_embed_dim} 79 | depth: ${network.trunk_depth} 80 | num_heads: ${network.encoder_heads} 81 | mlp_ratio: 4. 
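    # Note on the key below (inferred from TransformerTrunk in t3/models/trunk.py): 'none'
    # returns all tokens, 'global' mean-pools the patch tokens (dropping the cls token), and
    # 'cls' returns only the cls token. This config keeps the trunk at 'none' (see
    # `pooling: "none"` at the top) so each decoder can handle the token sequence itself
    # via its own `pooling_type` / `transformer_upstream` settings.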
82 | pooling_type: ${network.pooling} 83 | 84 | decoders: 85 | pose_estimation_6d_mini: 86 | _target_: t3.models.CNNFCDecoder 87 | inplanes: ${network.encoder_embed_dim} 88 | fc_hidden_dims: [256, 64] 89 | output_dim: 9 # using d9 representation 90 | stride: 2 91 | dropout_p: 0.1 92 | tanh_end: false 93 | transformer_upstream: true 94 | loss_func: 95 | _target_: torch.nn.MSELoss 96 | 97 | pose_estimation_6d_wedge: 98 | _target_: t3.models.CNNFCDecoder 99 | inplanes: ${network.encoder_embed_dim} 100 | fc_hidden_dims: [256, 64] 101 | output_dim: 9 # using d9 representation 102 | stride: 2 103 | dropout_p: 0.1 104 | tanh_end: false 105 | transformer_upstream: true 106 | loss_func: 107 | _target_: torch.nn.MSELoss 108 | 109 | pose_estimation_6d_digit: 110 | _target_: t3.models.CNNFCDecoder 111 | inplanes: ${network.encoder_embed_dim} 112 | fc_hidden_dims: [256, 64] 113 | output_dim: 9 # using d9 representation 114 | stride: 2 115 | dropout_p: 0.1 116 | tanh_end: false 117 | transformer_upstream: true 118 | loss_func: 119 | _target_: torch.nn.MSELoss 120 | 121 | variance_estimation: 122 | _target_: t3.models.MLPDecoder 123 | input_dim: ${network.encoder_embed_dim} 124 | output_dim: 1 125 | hidden_dims: [256, 128, 64] 126 | dropout_p: 0.1 127 | transformer_upstream: true 128 | pooling_type: global 129 | loss_func: 130 | _target_: t3.models.VarianceScaledLoss 131 | 132 | cls_touch_and_go: 133 | _target_: t3.models.MLPDecoder 134 | input_dim: ${network.encoder_embed_dim} 135 | output_dim: 22 136 | hidden_dims: [256, 128, 64] 137 | dropout_p: 0.1 138 | transformer_upstream: true 139 | pooling_type: cls 140 | loss_func: 141 | _target_: torch.nn.CrossEntropyLoss 142 | 143 | cls_calandra17: 144 | _target_: t3.models.MLPDecoder 145 | input_dim: ${network.encoder_embed_dim} 146 | output_dim: 106 147 | hidden_dims: [256, 128, 64] 148 | dropout_p: 0.1 149 | transformer_upstream: true 150 | pooling_type: cls 151 | loss_func: 152 | _target_: torch.nn.CrossEntropyLoss 153 | 154 | cls_yuan18_textile_type: 155 | _target_: t3.models.MLPDecoder 156 | input_dim: ${network.encoder_embed_dim} 157 | output_dim: 20 158 | hidden_dims: [256, 128, 64] 159 | dropout_p: 0.1 160 | transformer_upstream: true 161 | pooling_type: cls 162 | loss_func: 163 | _target_: torch.nn.CrossEntropyLoss 164 | 165 | cls_yuan18_smoothness: 166 | _target_: t3.models.MLPDecoder 167 | input_dim: ${network.encoder_embed_dim} 168 | output_dim: 5 169 | hidden_dims: [256, 128, 64] 170 | dropout_p: 0.1 171 | transformer_upstream: true 172 | pooling_type: cls 173 | loss_func: 174 | _target_: torch.nn.CrossEntropyLoss 175 | 176 | cls_yuan18_fuzziness: 177 | _target_: t3.models.MLPDecoder 178 | input_dim: ${network.encoder_embed_dim} 179 | output_dim: 4 180 | hidden_dims: [256, 128, 64] 181 | dropout_p: 0.1 182 | transformer_upstream: true 183 | pooling_type: cls 184 | loss_func: 185 | _target_: torch.nn.CrossEntropyLoss 186 | 187 | cls_objectfolder: 188 | _target_: t3.models.MLPDecoder 189 | input_dim: ${network.encoder_embed_dim} 190 | output_dim: 100 191 | hidden_dims: [256, 128, 64] 192 | dropout_p: 0.1 193 | transformer_upstream: true 194 | pooling_type: cls 195 | loss_func: 196 | _target_: torch.nn.CrossEntropyLoss 197 | 198 | pose_estimation_3d: 199 | _target_: t3.models.CNNFCDecoder 200 | inplanes: ${network.encoder_embed_dim} 201 | fc_hidden_dims: [256, 64] 202 | output_dim: 3 # using d9 representation 203 | stride: 2 204 | dropout_p: 0.1 205 | tanh_end: false 206 | transformer_upstream: true 207 | loss_func: 208 | _target_: torch.nn.MSELoss 
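# The decoder entries above all follow the same Hydra pattern: a `_target_` class plus its
# constructor arguments and a `loss_func`. As a sketch only (not part of the released config;
# `cls_my_sensor_task` and `output_dim: 10` are hypothetical), a new classification head
# could be added like this:
#
# cls_my_sensor_task:
#   _target_: t3.models.MLPDecoder
#   input_dim: ${network.encoder_embed_dim}
#   output_dim: 10           # hypothetical number of classes
#   hidden_dims: [256, 128, 64]
#   dropout_p: 0.1
#   transformer_upstream: true
#   pooling_type: cls
#   loss_func:
#     _target_: torch.nn.CrossEntropyLoss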
-------------------------------------------------------------------------------- /configs/datasets/single_tower_mae.yaml: -------------------------------------------------------------------------------- 1 | # variables should start with VAR_ and be defined in the same file 2 | VAR_random_resize_crop: true 3 | VAR_random_hv_flip_prob: 0.5 4 | VAR_color_jitter: 5 | brightness: 0.4 6 | contrast: 0.4 7 | saturation: 0.5 8 | hue: 0.3 9 | 10 | # tippur23_1000Cylinder: 11 | # activate: true # if false, this dataset will not be used 12 | # eval_only: false # if true, this dataset will not be used for training 13 | # data_loader: 14 | # _target_: t3.data_loader.SingleTowerMAEDataset 15 | # data_dir: "data/FoundationTactile/tippur23_1000Cylinder" 16 | # encoder_domain: "gs_360_v2" 17 | # decoder_domain: "mae_recon_single" 18 | # random_resize_crop: true 19 | # random_hv_flip_prob: 0.5 20 | # color_jitter: 21 | # brightness: 0.2 22 | # contrast: 0.2 23 | # saturation: 0.2 24 | # hue: 0.1 25 | # img_norm: 26 | # mean: [0.00174, 0.62280, 0.11578] 27 | # std: [0.01036, 0.07555, 0.06993] 28 | 29 | tippur23_cylv0: 30 | activate: true 31 | eval_only: false 32 | data_loader: 33 | _target_: t3.data_loader.SingleTowerMAEDataset 34 | data_dir: "data/FoundationTactile/tippur23_cylv0" 35 | encoder_domain: "gs_360_v2" 36 | decoder_domain: "mae_recon_single" 37 | random_resize_crop: ${datasets.VAR_random_resize_crop} 38 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 39 | color_jitter: ${datasets.VAR_color_jitter} 40 | img_norm: 41 | mean: [0.48141, 0.24831, 0.30964] 42 | std: [0.07714, 0.07175, 0.07466] 43 | 44 | tippur23_cylv1: 45 | activate: true 46 | eval_only: false 47 | data_loader: 48 | _target_: t3.data_loader.SingleTowerMAEDataset 49 | data_dir: "data/FoundationTactile/tippur23_cylv1" 50 | encoder_domain: "gs_360_v2" 51 | decoder_domain: "mae_recon_single" 52 | random_resize_crop: ${datasets.VAR_random_resize_crop} 53 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 54 | color_jitter: ${datasets.VAR_color_jitter} 55 | img_norm: 56 | mean: [0.26948, 0.23950, 0.25199] 57 | std: [0.07708, 0.07323, 0.08599] 58 | 59 | tvl: 60 | activate: true 61 | eval_only: false 62 | data_loader: 63 | _target_: t3.data_loader.SingleTowerMAEDataset 64 | data_dir: "data/FoundationTactile/tvl" 65 | encoder_domain: "digit" 66 | decoder_domain: "mae_recon_single" 67 | random_resize_crop: ${datasets.VAR_random_resize_crop} 68 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 69 | color_jitter: ${datasets.VAR_color_jitter} 70 | img_norm: 71 | mean: [0.42094, 0.44041, 0.44151] 72 | std: [0.14661, 0.08972, 0.09041] 73 | 74 | visgel_downsampled: 75 | activate: true 76 | eval_only: false 77 | data_loader: 78 | _target_: t3.data_loader.SingleTowerMAEDataset 79 | data_dir: "data/FoundationTactile/visgel_downsampled" 80 | encoder_domain: "gs_green" 81 | decoder_domain: "mae_recon_single" 82 | random_resize_crop: false 83 | random_hv_flip_prob: 0 84 | color_jitter: ${datasets.VAR_color_jitter} 85 | img_norm: 86 | mean: [0.36917, 0.51738, 0.51782] 87 | std: [0.13530, 0.11039, 0.10001] 88 | 89 | calandra17: 90 | activate: true 91 | eval_only: false 92 | data_loader: 93 | _target_: t3.data_loader.SingleTowerMAEDataset 94 | data_dir: "data/FoundationTactile/calandra17" 95 | encoder_domain: "gs_green" 96 | decoder_domain: "mae_recon_single" 97 | random_resize_crop: ${datasets.VAR_random_resize_crop} 98 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 99 | color_jitter: ${datasets.VAR_color_jitter} 100 | 
img_norm: 101 | mean: [0.27307, 0.27307, 0.27307] 102 | std: [0.26252, 0.28064, 0.30760] 103 | 104 | touch_and_go: 105 | activate: true 106 | eval_only: false 107 | data_loader: 108 | _target_: t3.data_loader.SingleTowerMAEDataset 109 | data_dir: "data/FoundationTactile/touch_and_go" 110 | encoder_domain: "gs_tag" 111 | decoder_domain: "mae_recon_single" 112 | random_resize_crop: ${datasets.VAR_random_resize_crop} 113 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 114 | color_jitter: ${datasets.VAR_color_jitter} 115 | img_norm: 116 | mean: [0.51808, 0.50300, 0.51457] 117 | std: [0.13893, 0.11343, 0.13497] 118 | 119 | yuan_18: 120 | activate: true 121 | eval_only: false 122 | data_loader: 123 | _target_: t3.data_loader.SingleTowerMAEDataset 124 | data_dir: "data/FoundationTactile/yuan18" 125 | encoder_domain: "gs_green" 126 | decoder_domain: "mae_recon_single" 127 | random_resize_crop: ${datasets.VAR_random_resize_crop} 128 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 129 | color_jitter: ${datasets.VAR_color_jitter} 130 | img_norm: 131 | mean: [0.41745, 0.42082, 0.40049] 132 | std: [0.11456, 0.11639, 0.10868] 133 | 134 | ycbsight_real: 135 | activate: true 136 | eval_only: true 137 | data_loader: 138 | _target_: t3.data_loader.SingleTowerMAEDataset 139 | data_dir: "data/FoundationTactile/ycbsight_real" 140 | encoder_domain: "gs_green" 141 | decoder_domain: "mae_recon_single" 142 | random_resize_crop: ${datasets.VAR_random_resize_crop} 143 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 144 | color_jitter: null 145 | img_norm: 146 | mean: [0.51040, 0.51558, 0.57299] 147 | std: [0.06538, 0.09097, 0.12421] 148 | 149 | ycbsight_sim: 150 | activate: true 151 | eval_only: false 152 | data_loader: 153 | _target_: t3.data_loader.SingleTowerMAEDataset 154 | data_dir: "data/FoundationTactile/ycbsight_sim" 155 | encoder_domain: "gs_green" 156 | decoder_domain: "mae_recon_single" 157 | random_resize_crop: ${datasets.VAR_random_resize_crop} 158 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 159 | color_jitter: ${datasets.VAR_color_jitter} 160 | img_norm: 161 | mean: [0.52429, 0.51465, 0.60872] 162 | std: [0.06524, 0.08824, 0.11888] 163 | 164 | objectfolder_real: 165 | activate: true 166 | eval_only: false 167 | data_loader: 168 | _target_: t3.data_loader.SingleTowerMAEDataset 169 | data_dir: "data/FoundationTactile/objectfolder_real" 170 | encoder_domain: "gs_black" 171 | decoder_domain: "mae_recon_single" 172 | random_resize_crop: ${datasets.VAR_random_resize_crop} 173 | random_hv_flip_prob: ${datasets.VAR_random_hv_flip_prob} 174 | color_jitter: ${datasets.VAR_color_jitter} 175 | img_norm: 176 | mean: [0.46676, 0.45028, 0.45292] 177 | std: [0.08171, 0.06973, 0.08618] -------------------------------------------------------------------------------- /dataset_preproc/preprocess_ycbsight_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils import WdsWriter 3 | import json 4 | from scipy.spatial.transform import Rotation as R 5 | from autolab_core import RigidTransform 6 | from natsort import natsorted 7 | import argparse 8 | 9 | def calc_real_poses_for_obj(path): 10 | res = {} 11 | def d7_to_rotmat(d7, from_frame, to_frame): 12 | assert len(d7) == 7 13 | return RigidTransform( 14 | rotation=R.from_quat([float(x) for x in d7[3:]]).as_matrix(), 15 | translation=[float(x) for x in d7[:3]], 16 | from_frame=from_frame, 17 | to_frame=to_frame 18 | ) 19 | with open(os.path.join(path, "tf.json"), 
"r") as f: 20 | tfs = json.load(f) 21 | # not sure if this is absolutely correct but this is what the json key says 22 | robot_to_sensor = d7_to_rotmat(tfs["gripper2gelsight"], "robot", "sensor") 23 | world_to_obj = d7_to_rotmat(tfs["world2object"], "world", "obj") 24 | frame_cnt = 0 25 | with open(os.path.join(path, "robot.csv"), "r") as f: 26 | f.readline() 27 | for line in f: 28 | line = line.strip().split(",") 29 | # frame_idx = int(line[0]) 30 | robot_pose = d7_to_rotmat(line[1:], "robot", "world") # robot pose in world 31 | sensor_pose = robot_pose.dot(robot_to_sensor.inverse()) # sensor pose in world 32 | sensor_to_obj = world_to_obj.dot(sensor_pose) # sensor pose in obj 33 | assert sensor_to_obj.from_frame == "sensor" and sensor_to_obj.to_frame == "obj" 34 | res[frame_cnt] = [R.from_matrix(sensor_to_obj.rotation).as_quat(), 1000 * sensor_to_obj.translation] 35 | frame_cnt += 1 36 | return res 37 | 38 | def calc_sim_poses_for_obj(path): 39 | res = {} 40 | def purge_list(l): 41 | return [x for x in l if x != ""] 42 | with open(os.path.join(path, "pose.txt"), "r") as f: 43 | for line in f: 44 | line = line.strip().split(",") 45 | frame_idx = int(line[0].split(".")[0]) 46 | quat = [float(x) for x in purge_list(line[1].strip()[1:-1].strip().split(" "))[:4]] 47 | tra_mm = [1000*float(x) for x in purge_list(line[2].strip()[1:-1].strip().split(" "))[:3]] 48 | res[frame_idx] = [quat, tra_mm] 49 | return res 50 | 51 | def data_gen(args): 52 | """Unpack every video under parent_dir""" 53 | for domain in ["ycbsight_real", "ycbsight_sim"]: 54 | parent_dir = args.__dict__[domain] 55 | output_dir = args.output_folder 56 | domain_dir = os.path.join(output_dir, domain) 57 | is_real = "real" in domain 58 | 59 | img_paths, img_names = [], [] 60 | properties = { 61 | "x_mm": [], "y_mm": [], "z_mm": [], 62 | "quat_x": [], "quat_y": [], "quat_z": [], "quat_w": [], 63 | "obj_idx": [], "obj_name": [] 64 | } 65 | 66 | for obj_idx, obj_name in enumerate(os.listdir(parent_dir)): 67 | if is_real: 68 | poses = calc_real_poses_for_obj(os.path.join(parent_dir, obj_name)) 69 | sub_dir = os.path.join(parent_dir, obj_name, "gelsight") 70 | for frame_idx, img_p in enumerate(natsorted(os.listdir(sub_dir))): 71 | img_paths.append(os.path.join(sub_dir, img_p)) 72 | new_img_fn = f"{obj_name}_{img_p}" 73 | img_names.append(new_img_fn) 74 | # frame_idx = int(img_p.split("_")[2].split(".")[0]) 75 | try: 76 | properties["x_mm"].append(poses[frame_idx][1][0]) 77 | properties["y_mm"].append(poses[frame_idx][1][1]) 78 | properties["z_mm"].append(poses[frame_idx][1][2]) 79 | properties["quat_x"].append(poses[frame_idx][0][0]) 80 | properties["quat_y"].append(poses[frame_idx][0][1]) 81 | properties["quat_z"].append(poses[frame_idx][0][2]) 82 | properties["quat_w"].append(poses[frame_idx][0][3]) 83 | properties["obj_idx"].append(obj_idx) 84 | properties["obj_name"].append(obj_name) 85 | except KeyError: 86 | print(f"Frame {frame_idx} not found for {obj_name} in {sub_dir}") 87 | exit(1) 88 | else: 89 | poses = calc_sim_poses_for_obj(os.path.join(parent_dir, obj_name)) 90 | sub_dir = os.path.join(parent_dir, obj_name, "tactile_imgs") 91 | for img_p in os.listdir(sub_dir): 92 | img_paths.append(os.path.join(sub_dir, img_p)) 93 | new_img_fn = f"{obj_name}_{img_p}" 94 | img_names.append(new_img_fn) 95 | frame_idx = int(img_p.split(".")[0]) 96 | properties["x_mm"].append(poses[frame_idx][1][0]) 97 | properties["y_mm"].append(poses[frame_idx][1][1]) 98 | properties["z_mm"].append(poses[frame_idx][1][2]) 99 | 
properties["quat_x"].append(poses[frame_idx][0][0]) 100 | properties["quat_y"].append(poses[frame_idx][0][1]) 101 | properties["quat_z"].append(poses[frame_idx][0][2]) 102 | properties["quat_w"].append(poses[frame_idx][0][3]) 103 | properties["obj_idx"].append(obj_idx) 104 | properties["obj_name"].append(obj_name) 105 | print(f"Writing {len(img_paths)} images in {domain} to webdataset") 106 | os.makedirs(domain_dir, exist_ok=True) 107 | wds = WdsWriter(shard_size=args.shard_size) 108 | wds.add_imgs(img_paths=img_paths, img_names=img_names) 109 | wds.add_labels(**properties) 110 | wds.save(domain_dir) 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser(description='Dataset pre-processor for YCBSight dataset') 114 | parser.add_argument('-O', '--output_folder', type=str, 115 | help='Output folder for the pre-processed dataset') 116 | parser.add_argument('-S', '--shard_size', type=int, default=10000, 117 | help='Maximum number of samples in a single WDS shard.') 118 | parser.add_argument('--ycbsight_real', type=str, 119 | help='Path to the dataset for real images') 120 | parser.add_argument('--ycbsight_sim', type=str, 121 | help='Path to the dataset for simulated images') 122 | args = parser.parse_args() 123 | data_gen(args) -------------------------------------------------------------------------------- /dataset_preproc/preprocess_visgel_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from multiprocessing import Pool 4 | 5 | from natsort import natsorted 6 | from utils import WdsWriter 7 | import webdataset as wds 8 | import cv2 9 | import json 10 | import argparse 11 | 12 | 13 | def variance_of_laplacian(image): 14 | return cv2.Laplacian(image, cv2.CV_64F).var() 15 | 16 | def gen_variance(args): 17 | path, n_folder, rec_folder = args 18 | ref_img = "frame0000.jpg" 19 | if not os.path.isfile(os.path.join(path, n_folder, rec_folder, ref_img)): 20 | print("Ref image does not exist!") 21 | print(os.path.join(path, n_folder, rec_folder, ref_img)) 22 | return 23 | 24 | ref_img = cv2.imread(os.path.join(path, n_folder, rec_folder, ref_img)) 25 | gray_float_ref = cv2.cvtColor(ref_img, cv2.COLOR_RGB2GRAY).astype(float) 26 | weight_dict = {} 27 | for fn in os.listdir(os.path.join(path, n_folder, rec_folder)): 28 | if fn.endswith(".jpg"): 29 | des_img = cv2.imread(os.path.join(path, n_folder, rec_folder, fn)) 30 | gray_float_des = cv2.cvtColor(des_img, cv2.COLOR_RGB2GRAY).astype(float) 31 | weight = variance_of_laplacian(gray_float_des - gray_float_ref) 32 | weight_dict[fn] = weight 33 | np.save(os.path.join(path, n_folder, rec_folder, "variance.npy"), weight_dict) 34 | print("Variance calculated and saved for ", os.path.join(path, n_folder, rec_folder)) 35 | 36 | def data_gen_calc_variance_only(args) -> None: 37 | """ 38 | Calulate variance only without moving images 39 | """ 40 | arg_lst = [] 41 | for path in [args.seen_path, args.unseen_path]: 42 | for n_folder in os.listdir(path): 43 | for rec_folder in os.listdir(os.path.join(path, n_folder)): 44 | arg_lst.append((path, n_folder, rec_folder)) 45 | 46 | print("Calculating variance for ", len(arg_lst), " folders") 47 | with Pool(processes=args.num_processes) as pool: 48 | print("Using ", pool._processes, " processes") 49 | pool.map(gen_variance, arg_lst) 50 | 51 | def data_gen_create_webdataset(args) -> None: 52 | """ 53 | Create data_gen_create_webdataset 54 | """ 55 | output_dir = args.output_folder 56 | domain = "visgel" 57 | 58 | 
domain_dir = os.path.join(output_dir, domain) 59 | os.makedirs(domain_dir, exist_ok=True) 60 | 61 | img_paths, img_names, vars = [], [], [] 62 | for path in [args.seen_path, args.unseen_path]: 63 | for n_folder in os.listdir(path): 64 | print(f"Processing {domain} - {path} - {n_folder}") 65 | for rec_folder in os.listdir(os.path.join(path, n_folder)): 66 | var_dict = np.load(os.path.join(path, n_folder, rec_folder, "variance.npy"), allow_pickle=True).item() 67 | for img_fn in var_dict: 68 | prefix = "seen" if (path == args.seen_path) else "unseen" 69 | new_fn = f"{prefix}_{n_folder}_{rec_folder}_{img_fn}" 70 | # os.system(f"ln -s {os.path.join(path, n_folder, rec_folder, img_fn)} {os.path.join(domain_dir, new_fn)}") 71 | img_names.append(new_fn) 72 | img_paths.append(os.path.join(path, n_folder, rec_folder, img_fn)) 73 | vars.append(var_dict[img_fn]) 74 | wds = WdsWriter(shard_size=args.shard_size) 75 | wds.add_imgs(img_paths=img_paths, img_names=img_names) 76 | wds.add_labels(vars=vars) 77 | wds.save(domain_dir) 78 | print("Labels saved for ", domain, " Total images: ", len(img_names)) 79 | 80 | def downsample_visgel(args) -> None: 81 | """ 82 | Turns out that VisGel contains quite a lot flat images. 83 | This script downsamples the images to roughly 1/2 of the original size to make the training more balanced. 84 | """ 85 | 86 | def identity(x): 87 | return x 88 | 89 | output_dir = os.path.join(args.output_folder, "visgel_downsampled") 90 | 91 | for sub_dir in ["train", "val"]: 92 | ori_dir = os.path.join(args.output_folder, "visgel", sub_dir) 93 | # datasets under data_dir should have the format of data-xxxxxx.tar 94 | tars = natsorted([d for d in os.listdir(ori_dir) if d.startswith("data-")]) 95 | _start_idx = os.path.splitext(os.path.basename(tars[0]))[0].strip("data-") 96 | _end_idx = os.path.splitext(os.path.basename(tars[-1]))[0].strip("data-") 97 | with open(os.path.join(ori_dir, "count.txt"), "r") as f: 98 | ori_cnt = int(f.readline()) 99 | url = os.path.join(ori_dir, f"data-{{{_start_idx}..{_end_idx}}}.tar") 100 | ori_dataset = ( 101 | wds.WebDataset(url) 102 | .decode("pil") 103 | .to_tuple("__key__", "jpg", "json") 104 | .map_tuple(identity, identity, identity) 105 | .shuffle(1000) 106 | ) 107 | 108 | # all_vars = [] 109 | # for i, (key, image, data) in enumerate(ori_dataset): 110 | # all_vars.append(data["vars"]) 111 | # if i > 100000: 112 | # break 113 | # all_vars = np.array(all_vars) 114 | # print(f"median val: {np.median(all_vars)}, mean val: {np.mean(all_vars)}") 115 | ## median val: 15.168044748454403, mean val: 19.698067996121345 116 | 117 | os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True) 118 | data_cnt = 0 119 | with wds.ShardWriter(os.path.join(output_dir, sub_dir, "data-%06d.tar"), maxcount=args.shard_size) as sink: 120 | for key, image, data in ori_dataset: 121 | sample = { 122 | "__key__": key, 123 | "jpg": image, 124 | "json": json.dumps(data), 125 | } 126 | if data["vars"] > 18: 127 | data_cnt += 1 128 | if data_cnt % 100000 == 0: 129 | print(f"...donsampling {data_cnt} images from {ori_cnt} for {sub_dir}") 130 | sink.write(sample) 131 | continue 132 | with open(os.path.join(output_dir, sub_dir, "count.txt"), 'w') as f: 133 | f.write(str(data_cnt)) 134 | print(f"Downsampled {data_cnt} images from {ori_cnt} for {sub_dir}") 135 | 136 | if __name__ == "__main__": 137 | parser = argparse.ArgumentParser(description='Dataset pre-processor for VisGel dataset') 138 | parser.add_argument('-O', '--output_folder', type=str, 139 | help='Output folder for the 
pre-processed dataset') 140 | parser.add_argument('-S', '--shard_size', type=int, default=10000, 141 | help='Maximum number of samples in a single WDS shard.') 142 | parser.add_argument('-N', '--num_processes', type=int, default=1, 143 | help='Number of processes to use for variance calculation.') 144 | parser.add_argument('--seen_path', type=str, 145 | help='Path to the seen part of the dataset') 146 | parser.add_argument('--unseen_path', type=str, 147 | help='Path to the unseen part of the dataset') 148 | parser.add_argument('--no_downsampling', action='store_true', 149 | help='Avoid downsampling the images if specified.') 150 | args = parser.parse_args() 151 | 152 | data_gen_calc_variance_only(args) 153 | data_gen_create_webdataset(args) 154 | if not args.no_downsampling: 155 | downsample_visgel(args) 156 | -------------------------------------------------------------------------------- /t3/task_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for task-specific processing 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | from scipy.spatial.transform import Rotation as R 8 | from autolab_core import RigidTransform 9 | import unittest 10 | from ._calandra17_label_dict import CALANDRA17_LABEL_DICT 11 | 12 | def process_touch_and_go_label(label): 13 | """ 14 | Process the classification label for touch_and_go dataset. 15 | """ 16 | return int(label["contact_class"]) + 2 17 | 18 | def process_yuan18_textile_type_label(label): 19 | """ 20 | Process the textile_type classification label for yuan18 dataset. 21 | """ 22 | if "textile_type" not in label: 23 | return int(label["taxtile_type"]) # typo in the dataset 24 | return int(label["textile_type"]) 25 | 26 | def process_yuan18_smoothness_label(label): 27 | """ 28 | Process the smoothness classification label for yuan18 dataset. 29 | """ 30 | return int(label["smoothness"]) 31 | 32 | def process_yuan18_fuzziness_label(label): 33 | """ 34 | Process the fuzziness classification label for yuan18 dataset. 35 | """ 36 | return int(label["fuzziness"]) 37 | 38 | def process_calandra17_obj_label(label): 39 | """ 40 | Process the object classification label for calandra17 dataset. 41 | 106 classes in total. 42 | """ 43 | if not label["has_contact"]: 44 | return CALANDRA17_LABEL_DICT["no_contact"] 45 | return CALANDRA17_LABEL_DICT[label["object_name"]] 46 | 47 | def process_objectfolder_real_label(label): 48 | """ 49 | Process the material classification label for objectfolder_real dataset. 50 | """ 51 | return int(label["material_idx"]) 52 | 53 | def process_cnc_cls_label(label): 54 | """ 55 | Process the classification label for cnc dataset. 56 | """ 57 | return int(label["obj_idx"]) 58 | 59 | #################### 60 | # Classification processing # 61 | #################### 62 | 63 | def count_classification_topk(pred, Y, k=5): 64 | """ 65 | Count top-k success. 
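    Returns the number of samples whose true label appears among the k highest-scoring
    classes in `pred`, where `pred` is an (N, num_classes) array of scores/logits and
    `Y` is a length-N array of integer class labels.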
66 | """ 67 | if isinstance(Y, torch.Tensor): 68 | Y = Y.cpu().numpy() 69 | if isinstance(pred, torch.Tensor): 70 | pred = pred.cpu().numpy() 71 | topk = np.argsort(pred, axis=1)[:, -k:] 72 | return np.sum(np.any(topk == Y[:, None], axis=1)) 73 | 74 | 75 | #################### 76 | # Pose processing # 77 | #################### 78 | 79 | def rotmat_to_d6(r): 80 | """ 81 | convert rotation matrix to d6 rotation 82 | "On the Continuity of Rotation Representations in Neural Networks" https://arxiv.org/abs/1812.07035 83 | """ 84 | return np.concatenate((r[:,0], r[:, 1])) 85 | 86 | def quat_to_d6(q): 87 | """ 88 | convert quaternion [qx, qy, qz, qw] to d6 rotation 89 | """ 90 | r = R.from_quat(q).as_matrix() 91 | return rotmat_to_d6(r) 92 | 93 | def pose_to_d9(pose: np.ndarray): 94 | """ 95 | Convert pose to d9 representation. 96 | """ 97 | if len(pose.shape) == 1: 98 | return np.concatenate((pose[:3], quat_to_d6(pose[3:]))) 99 | else: 100 | res = np.zeros((pose.shape[0], 9)) 101 | res[:, :3] = pose[:, :3] 102 | for i, q in enumerate(pose[:, 3:]): 103 | res[i, 3:] = quat_to_d6(q) 104 | return res 105 | 106 | def d6_to_rotmat(d6): 107 | ''' 108 | Convert 6d representation to rotation matrix. 109 | ''' 110 | a1, a2 = d6[:3], d6[3:] 111 | b1 = a1 / np.linalg.norm(a1) 112 | b2 = a2 - np.dot(b1, a2) * b1 113 | b2 = b2 / np.linalg.norm(b2) 114 | b3 = np.cross(b1, b2) 115 | return np.stack((b1, b2, b3), axis=1) 116 | 117 | def d6_to_anga(d6): 118 | ''' 119 | Convert 6d representation to angle-axis representation. 120 | ''' 121 | rotmat = d6_to_rotmat(d6) 122 | return R.from_matrix(rotmat).as_rotvec() 123 | 124 | def calc_delta_between_pose(in1: np.ndarray, in2: np.ndarray): 125 | """ 126 | Calculate the delta between two [x, y, z, q_x, q_y, q_z, q_w]. Return a d9 pose. 127 | Input and output can be batched. 128 | """ 129 | def _single(a, b): 130 | mat_1 = np.eye(4) 131 | mat_1[:3, 3] = a[:3] 132 | mat_1[:3, :3] = R.from_quat(a[3:]).as_matrix() 133 | 134 | mat_2 = np.eye(4) 135 | mat_2[:3, 3] = b[:3] 136 | mat_2[:3, :3] = R.from_quat(b[3:]).as_matrix() 137 | 138 | delta = np.linalg.pinv(mat_1).dot(mat_2) 139 | delta_xyz = delta[:3, 3] 140 | delta_rot = rotmat_to_d6(delta[:3, :3]) 141 | return np.concatenate((delta_xyz, delta_rot)) 142 | 143 | if len(in1.shape) == 1: 144 | return _single(in1, in2) 145 | else: 146 | res = np.zeros((in1.shape[0], 9)) 147 | for i, (a, b) in enumerate(zip(in1, in2)): 148 | res[i] = _single(a, b) 149 | return res 150 | 151 | def d9_normalize(d9: np.ndarray, mean: np.ndarray, std: np.ndarray): 152 | """ 153 | Normalize d9 pose. 154 | if mean and std are 3d, only apply normalization to translations. 155 | """ 156 | res = d9.copy() 157 | if len(mean) == 3: 158 | res[:, :3] -= mean 159 | res[:, :3] /= std 160 | elif len(mean) == 9: 161 | res -= mean 162 | res /= std 163 | else: 164 | raise ValueError("Invalid mean and std shape") 165 | return res 166 | 167 | def d9_denormalize(d9: np.ndarray, mean: np.ndarray, std: np.ndarray): 168 | """ 169 | Denormalize d9 pose. 170 | if mean and std are 3d, only apply normalization to translations. 171 | """ 172 | res = d9.copy() 173 | if len(mean) == 3: 174 | res[:, :3] *= std 175 | res[:, :3] += mean 176 | elif len(mean) == 9: 177 | res *= std 178 | res += mean 179 | else: 180 | raise ValueError("Invalid mean and std shape") 181 | return res 182 | 183 | @torch.no_grad() 184 | def tra_rmse(pred, gt, denormalize_func=None): 185 | """ 186 | Calculate the RMSE of translations. 187 | Both inputs should either be np.ndarray or torch.Tensor. 
188 | """ 189 | if isinstance(gt, torch.Tensor): 190 | gt = gt.cpu().numpy() 191 | if isinstance(pred, torch.Tensor): 192 | pred = pred.cpu().numpy() 193 | if denormalize_func is not None: 194 | gt = denormalize_func(gt) 195 | pred = denormalize_func(pred) 196 | return np.sqrt(np.mean(np.sum((gt[:, :3] - pred[:, :3]) ** 2, axis=1))) 197 | 198 | @torch.no_grad() 199 | def rot_rmse(pred, gt, denormalize_func=None): 200 | """ 201 | Calculate the RMSE of rotations. 202 | Both inputs should either be np.ndarray or torch.Tensor. 203 | Assume no normalization done on rotation. 204 | Output is in degrees. 205 | TODO: implement 3d rotation RMSE. 206 | """ 207 | if isinstance(gt, torch.Tensor): 208 | gt = gt.cpu().numpy() 209 | if isinstance(pred, torch.Tensor): 210 | pred = pred.cpu().numpy() 211 | if denormalize_func is not None: 212 | gt = denormalize_func(gt) 213 | pred = denormalize_func(pred) 214 | square_errors = [] 215 | for i in range(gt.shape[0]): 216 | gt_rot = d6_to_rotmat(gt[i, 3:]) 217 | pred_rot = d6_to_rotmat(pred[i, 3:]) 218 | delta = np.linalg.pinv(gt_rot).dot(pred_rot) 219 | delta_rotvec = R.from_matrix(delta).as_rotvec(degrees=True) 220 | square_errors.append(np.sum(delta_rotvec ** 2)) 221 | return np.sqrt(np.mean(square_errors)) 222 | 223 | def sample_transformation(r_tra, r_rot, from_frame, to_frame): 224 | ''' 225 | Sample a transformation matrix with range_translation and range_rotation. 226 | ''' 227 | T = RigidTransform(from_frame=from_frame, to_frame=to_frame) 228 | T.translation = (np.random.random(3) - 0.5) * 2 * r_tra 229 | rot = R.from_euler('zyx', (np.random.random(3) - 0.5) * 2 * r_rot, degrees=True) 230 | T.rotation = rot.as_matrix() 231 | return T 232 | 233 | class TestPoseConversions(unittest.TestCase): 234 | def test_d6(self): 235 | rot = sample_transformation(0.1, 10, "world", "camera").rotation 236 | d6 = rotmat_to_d6(rot) 237 | self.assertTrue(np.allclose(d6_to_rotmat(d6), rot)) 238 | 239 | if __name__ == "__main__": 240 | unittest.main() -------------------------------------------------------------------------------- /t3/data_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data loader for MAE for Transferable Tactile Transformer (T3) using the FoTa dataset 3 | 4 | Author: Jialiang (Alan) Zhao 5 | Email: alanzhao@csail.mit.edu 6 | MIT License 7 | """ 8 | 9 | import torch 10 | from torch.utils.data import IterableDataset, DataLoader 11 | import hydra 12 | from functools import partial 13 | import os 14 | import numpy as np 15 | from torchvision import transforms 16 | from natsort import natsorted 17 | import webdataset as wds 18 | from .task_utils import calc_delta_between_pose, d9_normalize, d9_denormalize 19 | 20 | class TactileImageDatasetBase(IterableDataset): 21 | def __init__(self, 22 | data_dir: str, 23 | label_func, 24 | batch_size: int, 25 | encoder_domain: str, 26 | decoder_domain: str, 27 | img_norm=None, 28 | random_resize_crop=True, 29 | random_hv_flip_prob=0.5, # probability of flipping the image 30 | color_jitter=None, # dict with keys brightness, contrast, saturation, hue 31 | **kwargs): 32 | self.encoder_domain = encoder_domain 33 | self.decoder_domain = decoder_domain 34 | self.data_dir = data_dir 35 | self.batch_size = batch_size 36 | # datasets under data_dir should have the format of data-xxxxxx.tar 37 | tars = natsorted([d for d in os.listdir(data_dir) if d.startswith("data-")]) 38 | _start_idx = os.path.splitext(os.path.basename(tars[0]))[0].strip("data-") 39 | _end_idx = 
os.path.splitext(os.path.basename(tars[-1]))[0].strip("data-") 40 | with open(os.path.join(data_dir, "count.txt"), "r") as f: 41 | self.length = int(f.readline()) 42 | 43 | url = os.path.join(data_dir, f"data-{{{_start_idx}..{_end_idx}}}.tar") 44 | 45 | img_mean = img_norm["mean"] if img_norm is not None else [0.485, 0.456, 0.406] 46 | img_std = img_norm["std"] if img_norm is not None else [0.229, 0.224, 0.225] 47 | 48 | 49 | preproc = [] 50 | if random_resize_crop: 51 | preproc.append(transforms.RandomResizedCrop(224)) 52 | else: 53 | preproc.append(transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)) 54 | preproc.append(transforms.CenterCrop(224)) 55 | if random_hv_flip_prob > 1e-4: 56 | preproc.append(transforms.RandomHorizontalFlip(p=random_hv_flip_prob)) 57 | preproc.append(transforms.RandomVerticalFlip(p=random_hv_flip_prob)) 58 | if color_jitter is not None: 59 | preproc.append(transforms.ColorJitter(**color_jitter)) 60 | 61 | normalize = transforms.Normalize( 62 | mean=img_mean, 63 | std=img_std) 64 | 65 | preproc.extend([ 66 | transforms.ToTensor(), 67 | normalize, 68 | ]) 69 | 70 | # create an inv-normalize function for visualization 71 | self.inv_normalize = transforms.Normalize( 72 | mean=[-m/s for m, s in zip(img_mean, img_std)], 73 | std=[1/s for s in img_std] 74 | ) 75 | 76 | self.dataset = ( 77 | wds.WebDataset(url, shardshuffle=True) 78 | .shuffle(10000) 79 | .decode("pil") 80 | .to_tuple("jpg", "json") 81 | .map_tuple(transforms.Compose(preproc), label_func) 82 | .batched(batch_size) 83 | ) 84 | 85 | def __len__(self): 86 | return self.length // self.batch_size 87 | 88 | def __iter__(self): 89 | for img_batch, label_batch in self.dataset: 90 | # discard the last incomplete batch 91 | if len(img_batch) < self.batch_size: 92 | break 93 | yield { 94 | "X": img_batch, "Y": label_batch, 95 | "encoder_domain": self.encoder_domain, "decoder_domain": self.decoder_domain, 96 | "inv_normalize": self.inv_normalize} 97 | 98 | def get_dataloader(self, num_workers, **kwargs): 99 | return DataLoader(self, batch_size=None, shuffle=False, num_workers=num_workers, **kwargs) 100 | 101 | class SingleTowerMAEDataset(TactileImageDatasetBase): 102 | ''' 103 | This dataloader sets both X and Y to be the images. 104 | Use this for MAE where the target (Y) is also an image. 105 | ''' 106 | def __init__(self, **kwargs): 107 | label_func = lambda x: 0 # a dummy label for mae 108 | super().__init__(label_func=label_func, **kwargs) 109 | 110 | def __iter__(self): 111 | for img_batch, label_batch in self.dataset: 112 | # discard the last incomplete batch 113 | if len(img_batch) < self.batch_size: 114 | break 115 | yield { 116 | "X": img_batch, "Y": img_batch, 117 | "encoder_domain": self.encoder_domain, "decoder_domain": self.decoder_domain, 118 | "inv_normalize": self.inv_normalize} 119 | 120 | class SingleTowerClassificationDataset(TactileImageDatasetBase): 121 | ''' 122 | This dataloader sets Y based on a label_process_func which extracts a desired label. 123 | ''' 124 | def __init__(self, 125 | label_process_func, # a function to process the label 126 | **kwargs): 127 | if isinstance(label_process_func, str): 128 | label_process_func = hydra.utils.get_method(label_process_func) 129 | super().__init__(label_func=label_process_func, **kwargs) 130 | 131 | class SingleTowerVarianceDataset(TactileImageDatasetBase): 132 | ''' 133 | This dataloader sets Y to be the variance of laplacian of the images. 
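    Each sample's JSON metadata is expected to provide a "vars" entry (that image's
    variance-of-Laplacian value); the batched labels are reshaped into a (batch_size, 1)
    float tensor so they can be regressed as a single scalar per image.

    Minimal usage sketch (the path and domain names are illustrative assumptions):
        ds = SingleTowerVarianceDataset(
            data_dir="data/some_webdataset_split/train",
            batch_size=32,
            encoder_domain="gs_green",
            decoder_domain="variance_estimation",
        )
        for batch in ds.get_dataloader(num_workers=2):
            X, Y = batch["X"], batch["Y"]  # X: (32, 3, 224, 224) images, Y: (32, 1) variances
            break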
134 | ''' 135 | def __init__(self, **kwargs): 136 | label_func = lambda x: x["vars"] 137 | super().__init__(label_func=label_func, **kwargs) 138 | 139 | def __iter__(self): 140 | for img_batch, label_batch in self.dataset: 141 | # discard the last incomplete batch 142 | if len(img_batch) < self.batch_size: 143 | break 144 | yield { 145 | "X": img_batch, "Y": torch.from_numpy(label_batch.reshape(-1, 1)).type(torch.FloatTensor), 146 | "encoder_domain": self.encoder_domain, "decoder_domain": self.decoder_domain, 147 | "inv_normalize": self.inv_normalize} 148 | 149 | class DoubleTowerPoseEstimationDataset(TactileImageDatasetBase): 150 | ''' 151 | This dataloader sets X to be two images and Y to be the difference in pose between the two images. 152 | The two images in X are randomly rolled to create the pair. 153 | ''' 154 | def __init__(self, 155 | pose_dim, 156 | label_norm, 157 | **kwargs): 158 | assert pose_dim in [3, 6] 159 | if pose_dim == 3: 160 | label_func = lambda x: np.array([x["x_mm"], x["y_mm"], x["z_mm"]]) 161 | else: 162 | label_func = lambda x: np.array([ 163 | x["x_mm"], x["y_mm"], x["z_mm"], 164 | x["quat_x"], x["quat_y"], x["quat_z"], x["quat_w"]]) 165 | super().__init__(label_func=label_func, **kwargs) 166 | self.pose_dim = pose_dim 167 | self.Y_mean = label_norm["mean"] 168 | self.Y_std = label_norm["std"] 169 | # create an inv-normalize function for labels 170 | if pose_dim == 3: 171 | self.label_inv_normalize = lambda x: x * self.Y_std + self.Y_mean 172 | else: 173 | self.label_inv_normalize = partial(d9_denormalize, mean=self.Y_mean, std=self.Y_std) 174 | 175 | def __iter__(self): 176 | # For double tower, roll each batch by a random value to get the other images 177 | for img_batch, label_batch in self.dataset: 178 | # discard the last incomplete batch 179 | if len(img_batch) < self.batch_size: 180 | break 181 | amount_to_roll = np.random.randint(len(img_batch)) 182 | other_img_batch = torch.roll(img_batch, amount_to_roll, dims=0) # torch.flip(img_batch, dims=[0]) 183 | other_label_batch = np.roll(label_batch, amount_to_roll, axis=0) # np.flip(label_batch, axis=0) 184 | 185 | if self.pose_dim == 3: 186 | Y = other_label_batch - label_batch 187 | Y = (Y - self.Y_mean) / self.Y_std 188 | elif self.pose_dim == 6: 189 | Y = calc_delta_between_pose(label_batch, other_label_batch) 190 | Y = d9_normalize(Y, self.Y_mean, self.Y_std) 191 | 192 | Y = torch.from_numpy(Y).type(torch.FloatTensor) 193 | yield { 194 | "X": [img_batch, other_img_batch], "Y": Y, 195 | "encoder_domain": self.encoder_domain, "decoder_domain": self.decoder_domain, 196 | "inv_normalize": self.inv_normalize, 197 | "label_inv_normalize": self.label_inv_normalize} 198 | 199 | class WeightedDataLoader: 200 | def __init__(self, dataloaders, weight_type="root"): 201 | """ 202 | This dataloader combines multiple dataloaders into one to be used in training. 203 | :param dataloaders: list of pytorch dataloaders 204 | :param weight_type: type of weighting, e.g., "equal", "invlinear", "root" 205 | """ 206 | self.dataloaders = dataloaders 207 | datasizes = np.array([len(d) for d in dataloaders], dtype=float) 208 | if weight_type == 'equal': 209 | self.weights = np.ones(len(dataloaders)) / len(dataloaders) 210 | elif weight_type == "invlinear": 211 | self.weights = (1. / datasizes) / np.sum(1. / datasizes) 212 | elif weight_type == "root": 213 | inv_root = np.power(datasizes, 1.0 / 2) 214 | self.weights = inv_root / np.sum(inv_root) 215 | else: 216 | print(f"weight type '{weight_type}' not defined. 
Using equal weights.") 217 | self.weights = np.ones(len(dataloaders)) / len(dataloaders) 218 | 219 | self.loader_iters = [iter(dataloader) for dataloader in self.dataloaders] 220 | 221 | def __iter__(self): 222 | return self 223 | 224 | def __next__(self): 225 | # Choose a dataloader based on weights 226 | chosen_dataloader_idx = np.random.choice(len(self.dataloaders), p=self.weights) 227 | chosen_loader_iter = self.loader_iters[chosen_dataloader_idx] 228 | try: 229 | return next(chosen_loader_iter) 230 | except StopIteration: 231 | # Handle case where a dataloader is exhausted. Reinitialize the iterator. 232 | self.loader_iters[chosen_dataloader_idx] = iter(self.dataloaders[chosen_dataloader_idx]) 233 | return self.__next__() 234 | 235 | def __len__(self): 236 | return np.sum([len(dataloader) for dataloader in self.dataloaders]) -------------------------------------------------------------------------------- /t3/models/decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decoders for Transferable Tactile Transformer (T3) models 3 | 4 | Author: Jialiang (Alan) Zhao 5 | Email: alanzhao@csail.mit.edu 6 | MIT License 7 | """ 8 | import os 9 | import torch 10 | from torch import nn 11 | from typing import Literal 12 | 13 | import torch 14 | import torch.nn as nn 15 | from .nn_utils import makeMLP, get_2d_sincos_pos_embed, CrossAttentionBlock, get_device 16 | import timm.models.vision_transformer as timm_vit 17 | from torchvision.models.resnet import BasicBlock 18 | 19 | from t3.utils import logging 20 | from math import sqrt 21 | 22 | class Decoder(nn.Module): 23 | def __init__(self, **kwargs): 24 | super().__init__() 25 | 26 | def freeze(self): 27 | for param in self.parameters(): 28 | param.requires_grad = False 29 | 30 | def unfreeze(self): 31 | for param in self.parameters(): 32 | param.requires_grad = True 33 | 34 | def save(self, path): 35 | torch.save(self.state_dict(), path) 36 | 37 | def load(self, path): 38 | kwargs = {} 39 | if not torch.cuda.is_available(): 40 | kwargs['map_location'] = get_device() 41 | if os.path.exists(path): 42 | logging(f"Loading decoder from weights from {path}", True, "green") 43 | self.load_state_dict(torch.load(path, **kwargs)) 44 | else: 45 | logging(f"Decoder weights not found at {path}. 
Skipping", True, "warning") 46 | 47 | class IdentityDecoder(Decoder): 48 | def __init__(self, **kwargs): 49 | super().__init__() 50 | 51 | def forward(self, x): 52 | return x 53 | 54 | class MLPDecoder(Decoder): 55 | def __init__(self, 56 | input_dim, 57 | output_dim, 58 | hidden_dims, 59 | dropout_p=0.1, 60 | tanh_end=False, 61 | ln=False, 62 | transformer_upstream=False, # if True, the input is assumed to be a sequence of tokens 63 | pooling_type: Literal['global', 'cls'] = 'cls', # pooling type for transformer upstream 64 | **kwargs): 65 | super().__init__() 66 | self.transformer_upstream = transformer_upstream 67 | self.pooling_type = pooling_type 68 | 69 | self.model = makeMLP(input_dim, output_dim, hidden_dims, dropout_p, tanh_end, ln) 70 | 71 | def forward(self, x): 72 | if self.transformer_upstream: 73 | if self.pooling_type == 'global': 74 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 75 | elif self.pooling_type == 'cls': 76 | x = x[:, 0] 77 | return self.model(x) 78 | 79 | class PoolingDecoder(Decoder): 80 | """Only pooling the transformer output""" 81 | def __init__(self, 82 | pooling_type: Literal['global', 'cls'] = 'cls', # pooling type for transformer upstream 83 | **kwargs): 84 | super().__init__() 85 | self.pooling_type = pooling_type 86 | 87 | def forward(self, x): 88 | if self.pooling_type == 'global': 89 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 90 | elif self.pooling_type == 'cls': 91 | x = x[:, 0] 92 | return x 93 | 94 | class MLPTwoTowerDecoder(Decoder): 95 | def __init__(self, 96 | input_dim, 97 | output_dim, 98 | hidden_dims, 99 | dropout_p=0.1, 100 | tanh_end=False, 101 | ln=False, 102 | **kwargs): 103 | super().__init__() 104 | 105 | self.model = makeMLP(input_dim, output_dim, hidden_dims, dropout_p, tanh_end, ln) 106 | 107 | def forward(self, x1, x2): 108 | x = torch.cat([x1, x2], dim=1) 109 | return self.model(x) 110 | 111 | class CNNFCDecoder(Decoder): 112 | def __init__(self, 113 | inplanes, # input channels of each tower 114 | fc_hidden_dims, 115 | output_dim, 116 | stride, 117 | dropout_p=0.1, 118 | tanh_end=False, 119 | n_tower=2, # number of towers 120 | transformer_upstream=False, # if True, the input is assumed to be a sequence of tokens 121 | **kwargs): 122 | super().__init__() 123 | self.transformer_upstream = transformer_upstream 124 | self.n_tower = n_tower 125 | 126 | self.norm_layer = nn.BatchNorm2d 127 | 128 | downsample = nn.Sequential( 129 | nn.Conv2d(inplanes, inplanes, kernel_size=1, stride=stride, bias=False), 130 | self.norm_layer(inplanes), 131 | ) 132 | 133 | self.conv_layers = nn.Sequential( 134 | BasicBlock( 135 | inplanes, inplanes, stride=stride, downsample=downsample, norm_layer=self.norm_layer 136 | ), 137 | BasicBlock( 138 | inplanes, inplanes, norm_layer=self.norm_layer 139 | ) 140 | ) 141 | 142 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 143 | 144 | branch_fc_layers = [nn.Flatten(), nn.Linear(inplanes, inplanes // 2), nn.SiLU()] 145 | if dropout_p > 0: 146 | branch_fc_layers.append(nn.Dropout(dropout_p)) 147 | self.branch_fc = nn.Sequential(*branch_fc_layers) 148 | 149 | output_inplanes = (inplanes // 2) * n_tower 150 | self.output_fc = makeMLP(output_inplanes, output_dim, fc_hidden_dims, dropout_p, tanh_end, ln=False) 151 | 152 | def reshape_transformer_input(self, x): 153 | """ 154 | Reshape transformer input (B, T, C) to be a normal image-kind (B, C, H, W) for Conv layers 155 | """ 156 | x = x[:, 1:, :] # remove the cls token 157 | B, T, C = x.shape 158 | hw = int(sqrt(T)) 159 | assert hw * 
hw == T, "Input sequence length must be a perfect square" 160 | x = x.permute(0, 2, 1).reshape(B, C, hw, hw) 161 | return x 162 | 163 | def branch(self, x): 164 | x = self.conv_layers(x) 165 | x = self.avgpool(x) 166 | x = self.branch_fc(x) 167 | return x 168 | 169 | def forward(self, *xs): 170 | assert len(xs) == self.n_tower, f"Expected {self.n_tower} inputs, got {len(xs)}" 171 | if self.transformer_upstream: 172 | xs = [self.reshape_transformer_input(x) for x in xs] 173 | 174 | xs = [self.branch(x) for x in xs] 175 | x = torch.cat(xs, dim=1) 176 | x = self.output_fc(x) 177 | return x 178 | 179 | class MAEViTDecoder(Decoder): 180 | """ Masked Autoencoder with VisionTransformer backbone 181 | """ 182 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, 183 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 184 | mlp_ratio=4., norm_layer=nn.LayerNorm, **kwargs): 185 | super().__init__() 186 | 187 | # -------------------------------------------------------------------------- 188 | # MAE decoder specifics 189 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 190 | 191 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 192 | 193 | self.num_patches = (img_size // patch_size) ** 2 194 | 195 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding 196 | 197 | self.decoder_blocks = nn.ModuleList([ 198 | timm_vit.Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 199 | for i in range(decoder_depth)]) 200 | 201 | self.decoder_norm = norm_layer(decoder_embed_dim) 202 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch 203 | # -------------------------------------------------------------------------- 204 | 205 | self.initialize_weights() 206 | 207 | def initialize_weights(self): 208 | # initialization 209 | # initialize (and freeze) pos_embed by sin-cos embedding 210 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.num_patches**.5), cls_token=True) 211 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 212 | 213 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
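        # Note: the single (1, 1, decoder_embed_dim) mask_token initialized below is the only
        # learned embedding for dropped patches; in forward() it is repeated to fill every
        # position the encoder masked out, before the fixed sin-cos decoder_pos_embed is added.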
214 | torch.nn.init.normal_(self.mask_token, std=.02) 215 | 216 | # initialize nn.Linear and nn.LayerNorm 217 | self.apply(self._init_weights) 218 | 219 | def _init_weights(self, m): 220 | if isinstance(m, nn.Linear): 221 | # we use xavier_uniform following official JAX ViT: 222 | torch.nn.init.xavier_uniform_(m.weight) 223 | if isinstance(m, nn.Linear) and m.bias is not None: 224 | nn.init.constant_(m.bias, 0) 225 | elif isinstance(m, nn.LayerNorm): 226 | nn.init.constant_(m.bias, 0) 227 | nn.init.constant_(m.weight, 1.0) 228 | 229 | def forward(self, x): 230 | # embed tokens 231 | (x, mask, ids_restore) = x 232 | x = self.decoder_embed(x) 233 | 234 | # append mask tokens to sequence 235 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) 236 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token 237 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 238 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token 239 | 240 | # add pos embed 241 | x = x + self.decoder_pos_embed 242 | 243 | # apply Transformer blocks 244 | for blk in self.decoder_blocks: 245 | x = blk(x) 246 | x = self.decoder_norm(x) 247 | 248 | # predictor projection 249 | x = self.decoder_pred(x) 250 | 251 | # remove cls token 252 | x = x[:, 1:, :] 253 | 254 | return (x, mask, ids_restore) 255 | 256 | class CrossMAEViTDecoder(MAEViTDecoder): 257 | """ 258 | CrossMAE with VisionTransformer backbone 259 | https://arxiv.org/pdf/2401.14391.pdf 260 | """ 261 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, 262 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 263 | mlp_ratio=4., norm_layer=nn.LayerNorm, **kwargs): 264 | super().__init__( 265 | img_size, patch_size, in_chans, embed_dim, decoder_embed_dim, decoder_depth, 266 | decoder_num_heads, mlp_ratio, norm_layer, **kwargs) 267 | 268 | self.decoder_blocks = nn.ModuleList([ 269 | CrossAttentionBlock(embed_dim, decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) 270 | for i in range(decoder_depth)]) 271 | 272 | self.initialize_weights() 273 | 274 | def forward(self, x): 275 | # embed tokens 276 | (y, mask, ids_restore) = x 277 | 278 | N, L = ids_restore.shape 279 | 280 | # construct mask tokens 281 | x = self.decoder_pos_embed[:, 1:].masked_select(mask.bool().unsqueeze(-1)).reshape(N, -1, self.mask_token.shape[-1]) 282 | x = x + self.mask_token 283 | 284 | for i, blk in enumerate(self.decoder_blocks): 285 | x = blk(x, y) 286 | 287 | x = self.decoder_norm(x) 288 | x = self.decoder_pred(x) # N, L, patch_size**2 *3 289 | 290 | return (x, mask, ids_restore) -------------------------------------------------------------------------------- /t3/models/nn_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for Transferable Tactile Transformer (T3) 3 | 4 | Author: Jialiang (Alan) Zhao 5 | Email: alanzhao@csail.mit.edu 6 | MIT License 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.layers import Mlp, DropPath 13 | from typing import Optional, Tuple 14 | import numpy as np 15 | 16 | def get_device(): 17 | if torch.cuda.is_available(): 18 | device = "cuda:0" 19 | elif torch.backends.mps.is_available(): 20 | # Apple Silicon 21 | device = "mps" 22 | else: 23 | device = "cpu" 24 | return device 25 | 26 | def makeMLP(input_dim, 27 | output_dim, 28 | hidden_dims, 29 | dropout_p, 30 | 
tanh_end, 31 | ln): 32 | layers = [nn.Linear(input_dim, hidden_dims[0]), nn.SiLU()] 33 | for i in range(1, len(hidden_dims)): 34 | layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i])) 35 | if dropout_p > 1e-5: 36 | layers.append(nn.Dropout(dropout_p)) 37 | if ln: 38 | layers.append(nn.LayerNorm(hidden_dims[i])) 39 | layers.append(nn.SiLU()) 40 | layers.append(nn.Linear(hidden_dims[-1], output_dim)) 41 | if tanh_end: 42 | layers.append(nn.Tanh()) 43 | return nn.Sequential(*layers) 44 | 45 | def makeCNN(input_channels, 46 | filters, 47 | kernel_size, 48 | stride, 49 | padding): 50 | layers = [ 51 | nn.Conv2d(input_channels, filters[0], kernel_size=7, stride=2, padding=3, bias=False), 52 | nn.ReLU(), 53 | nn.BatchNorm2d(filters[0]), 54 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)] 55 | for i in range(1, len(filters)): 56 | layers.append(nn.Conv2d(filters[i-1], filters[i], kernel_size=kernel_size, stride=stride, padding=padding, bias=False)) 57 | layers.append(nn.ReLU()) 58 | layers.append(nn.BatchNorm2d(filters[i])) 59 | 60 | layers.append(nn.AdaptiveAvgPool2d((1, 1))) 61 | layers.append(nn.Flatten()) 62 | return nn.Sequential(*layers) 63 | 64 | def findFlattenedSize(input_channels, img_size_x, img_size_y, *nns): 65 | out = torch.zeros((1, input_channels, img_size_x, img_size_y)) 66 | with torch.no_grad(): 67 | for nn in nns: 68 | out = nn(out) 69 | flattened_size = out.shape[1] 70 | return flattened_size 71 | 72 | # -------------------------------------------------------- 73 | # Position embedding utils https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 74 | # 75 | # 2D sine-cosine position embedding 76 | # References: 77 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 78 | # MoCo v3: https://github.com/facebookresearch/moco-v3 79 | # -------------------------------------------------------- 80 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 81 | """ 82 | grid_size: int of the grid height and width 83 | return: 84 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 85 | """ 86 | grid_h = np.arange(grid_size, dtype=np.float32) 87 | grid_w = np.arange(grid_size, dtype=np.float32) 88 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 89 | grid = np.stack(grid, axis=0) 90 | 91 | grid = grid.reshape([2, 1, grid_size, grid_size]) 92 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 93 | if cls_token: 94 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 95 | return pos_embed 96 | 97 | 98 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 99 | assert embed_dim % 2 == 0 100 | 101 | # use half of dimensions to encode grid_h 102 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 103 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 104 | 105 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 106 | return emb 107 | 108 | 109 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 110 | """ 111 | embed_dim: output dimension for each position 112 | pos: a list of positions to be encoded: size (M,) 113 | out: (M, D) 114 | """ 115 | assert embed_dim % 2 == 0 116 | omega = np.arange(embed_dim // 2, dtype=float) 117 | omega /= embed_dim / 2. 118 | omega = 1. 
/ 10000**omega # (D/2,) 119 | 120 | pos = pos.reshape(-1) # (M,) 121 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 122 | 123 | emb_sin = np.sin(out) # (M, D/2) 124 | emb_cos = np.cos(out) # (M, D/2) 125 | 126 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 127 | return emb 128 | 129 | 130 | # -------------------------------------------------------- 131 | # Interpolate position embeddings for high-resolution 132 | # References: 133 | # DeiT: https://github.com/facebookresearch/deit 134 | # -------------------------------------------------------- 135 | def interpolate_pos_embed(model, checkpoint_model): 136 | if 'pos_embed' in checkpoint_model: 137 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 138 | embedding_size = pos_embed_checkpoint.shape[-1] 139 | num_patches = model.patch_embed.num_patches 140 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 141 | # height (== width) for the checkpoint position embedding 142 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 143 | # height (== width) for the new position embedding 144 | new_size = int(num_patches ** 0.5) 145 | # class_token and dist_token are kept unchanged 146 | if orig_size != new_size: 147 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 148 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 149 | # only the position tokens are interpolated 150 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 151 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 152 | pos_tokens = torch.nn.functional.interpolate( 153 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 154 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 155 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 156 | checkpoint_model['pos_embed'] = new_pos_embed 157 | 158 | class CrossAttention(nn.Module): 159 | def __init__(self, encoder_dim, decoder_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 160 | super().__init__() 161 | self.num_heads = num_heads 162 | head_dim = decoder_dim // num_heads 163 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 164 | self.scale = qk_scale or head_dim ** -0.5 165 | self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias) 166 | self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias) 167 | self.attn_drop = attn_drop 168 | self.proj = nn.Linear(decoder_dim, decoder_dim) 169 | self.proj_drop = nn.Dropout(proj_drop) 170 | 171 | def forward(self, x, y): 172 | """ 173 | query from decoder (x), key and value from encoder (y) 174 | """ 175 | B, N, C = x.shape 176 | Ny = y.shape[1] 177 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 178 | kv = self.kv(y).reshape(B, Ny, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 179 | k, v = kv[0], kv[1] 180 | 181 | attn = F.scaled_dot_product_attention( 182 | q, k, v, dropout_p=self.attn_drop, 183 | ) 184 | x = attn.transpose(1, 2).reshape(B, N, C) 185 | 186 | x = self.proj(x) 187 | x = self.proj_drop(x) 188 | return x 189 | 190 | class CrossAttentionBlock(nn.Module): 191 | 192 | def __init__(self, encoder_dim, decoder_dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 193 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 194 | super().__init__() 195 | 196 | self.norm1 = 
norm_layer(decoder_dim) 197 | self.cross_attn = CrossAttention( 198 | encoder_dim, decoder_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 199 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 200 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 201 | self.norm2 = norm_layer(decoder_dim) 202 | mlp_hidden_dim = int(decoder_dim * mlp_ratio) 203 | self.mlp = Mlp(in_features=decoder_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 204 | 205 | def forward(self, x, y): 206 | """ 207 | x: decoder feature; y: encoder feature (after layernorm) 208 | """ 209 | x = x + self.drop_path(self.cross_attn(self.norm1(x), y)) 210 | x = x + self.drop_path(self.mlp(self.norm2(x))) 211 | return x 212 | 213 | # -------------------------------------------------------- 214 | # MAE utils for visualization 215 | # -------------------------------------------------------- 216 | @torch.no_grad() 217 | def mae_patchify(imgs, patch_size): 218 | """ 219 | imgs: (N, 3, H, W) 220 | x: (N, L, patch_size**2 *3) 221 | """ 222 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % patch_size == 0 223 | 224 | h = w = imgs.shape[2] // patch_size 225 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, patch_size, w, patch_size)) 226 | x = torch.einsum('nchpwq->nhwpqc', x) 227 | x = x.reshape(shape=(imgs.shape[0], h * w, patch_size**2 * 3)) 228 | return x 229 | 230 | @torch.no_grad() 231 | def mae_unpatchify(preds, patch_size): 232 | """ 233 | preds: (N, L, patch_size**2 *3) for original MAE 234 | return: (N, 3, H, W) 235 | """ 236 | h = w = int(preds.shape[1]**.5) 237 | assert h * w == preds.shape[1] 238 | 239 | preds = preds.reshape(shape=(preds.shape[0], h, w, patch_size, patch_size, 3)) 240 | preds = torch.einsum('nhwpqc->nchpwq', preds) 241 | imgs = preds.reshape(shape=(preds.shape[0], 3, h * patch_size, h * patch_size)) 242 | return imgs 243 | 244 | @torch.no_grad() 245 | def cross_mae_unpatchify(preds, imgs, masks, patch_size): 246 | """ 247 | preds: (N, l, patch_size**2 *3) for cross MAE 248 | mask: [N, L], 0 is keep, 1 is remove 249 | imgs: (N, 3, H, W) 250 | """ 251 | seq_mask = masks.unsqueeze(2) # (N, L) -> (N, L, 1) 252 | mask_patch = seq_mask.repeat(1, 1, patch_size**2 * 3) # (N, L, 1) -> (N, L, patch_size**2 *3) 253 | ori_img_patch = mae_patchify(imgs, patch_size) 254 | ori_img_patch[mask_patch > 0.5] = preds.reshape(-1) 255 | return mae_unpatchify(ori_img_patch, patch_size) 256 | 257 | @torch.no_grad() 258 | def mae_unpatchify_pred_only(preds, imgs, masks, patch_size): 259 | """ 260 | apply predicted patches to original images according to masks. 261 | Difference with mae_unpatchify is that instead of using preds for all patches, we only use preds for patches that are removed. 
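    Concretely, patches with mask == 1 are overwritten by the network prediction while patches
    with mask == 0 keep their original pixels; the removed rows of preds are selected via
    masked_select and passed on to cross_mae_unpatchify.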
262 | 263 | preds: (N, L, patch_size**2 *3) for original MAE 264 | return: (N, 3, H, W) 265 | """ 266 | assert preds.shape[1] == masks.shape[1] 267 | preds_removed = preds.masked_select(masks.bool().unsqueeze(-1)).reshape(preds.shape[0], -1, preds.shape[-1]) 268 | return cross_mae_unpatchify(preds_removed, imgs, masks, patch_size) 269 | 270 | @torch.no_grad() 271 | def mae_apply_patchified_mask(imgs, masks, patch_size): 272 | """ 273 | imgs: (N, 3, H, W) 274 | mask: [N, L], 0 is keep, 1 is remove 275 | return: (N, 3, H, W), with 1.0 for masked patches 276 | """ 277 | seq_mask = masks.unsqueeze(2) # (N, L) -> (N, L, 1) 278 | seq_mask = seq_mask.repeat(1, 1, patch_size**2 * 3) # (N, L, 1) -> (N, L, patch_size**2 *3) 279 | img_mask = mae_unpatchify(seq_mask, patch_size) 280 | imgs_ret = imgs.clone() 281 | imgs_ret[img_mask > 0.5] = 1.0 282 | 283 | return imgs_ret -------------------------------------------------------------------------------- /t3/models/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Encoder for Transferable Tactile Transformer (T3) models 3 | 4 | Author: Jialiang (Alan) Zhao 5 | Email: alanzhao@csail.mit.edu 6 | MIT License 7 | """ 8 | import os 9 | import torch 10 | from torch import nn 11 | import torchvision 12 | import torch 13 | import torch.nn as nn 14 | 15 | import timm.models.vision_transformer as timm_vit 16 | from functools import partial 17 | from .nn_utils import makeCNN, findFlattenedSize, get_2d_sincos_pos_embed, get_device 18 | 19 | from t3.utils import logging 20 | 21 | class Encoder(nn.Module): 22 | def __init__(self): 23 | super(Encoder, self).__init__() 24 | 25 | def freeze(self): 26 | for param in self.parameters(): 27 | param.requires_grad = False 28 | 29 | def unfreeze(self): 30 | for param in self.parameters(): 31 | param.requires_grad = True 32 | 33 | def save(self, path): 34 | torch.save(self.state_dict(), path) 35 | 36 | def load(self, path): 37 | kwargs = {} 38 | if not torch.cuda.is_available(): 39 | kwargs['map_location'] = get_device() 40 | if os.path.exists(path): 41 | logging(f"Loading encoder from weights from {path}", True, "green") 42 | self.load_state_dict(torch.load(path, **kwargs)) 43 | else: 44 | # try to finetune from gs_green if it exists 45 | gs_green_path = path[:path.rfind('/')] + '/gs_green.pth' 46 | if os.path.exists(gs_green_path): 47 | logging(f"Encoder weights not found at {path}. Loading from gs_green", True, "warning") 48 | self.load_state_dict(torch.load(gs_green_path, **kwargs)) 49 | else: # if gs_green also doesn't exist, use random initialization 50 | logging(f"Encoder weights not found at {path}. 
Skipping", True, "warning") 51 | 52 | class IdentityEncoder(Encoder): 53 | def __init__(self, **kwargs): 54 | super().__init__() 55 | 56 | def forward(self, x): 57 | return x 58 | 59 | class ResNetEncoder(Encoder): 60 | def __init__(self, 61 | output_dim, 62 | model='resnet18', 63 | pretrained=True): 64 | super().__init__() 65 | if pretrained: 66 | weights = 'IMAGENET1K_V1' 67 | else: 68 | weights = None 69 | self.model = getattr(torchvision.models, model)(weights=weights) 70 | self.model.fc = nn.Linear(512, output_dim) 71 | 72 | def forward(self, x): 73 | return self.model(x) 74 | 75 | class CNNEncoder(Encoder): 76 | def __init__(self, 77 | output_dim, 78 | input_channels, 79 | img_size, 80 | filters, 81 | kernel_size, 82 | stride, 83 | padding, 84 | **kwargs): 85 | super(CNNEncoder, self).__init__() 86 | self.model = makeCNN(input_channels, filters, kernel_size, stride, padding) 87 | self.flattened_size = findFlattenedSize(input_channels, img_size, img_size, self.model) 88 | self.fc = nn.Linear(self.flattened_size, output_dim) 89 | 90 | def forward(self, x): 91 | x = self.model(x) 92 | x = self.fc(x) 93 | return x 94 | 95 | class ViTEncoder(timm_vit.VisionTransformer, Encoder): 96 | def __init__(self, 97 | embed_dim: int, 98 | num_heads: int, 99 | mlp_ratio: float, 100 | depth: int, 101 | norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), 102 | **kwargs): 103 | super(ViTEncoder, self).__init__( 104 | embed_dim=embed_dim, 105 | num_heads=num_heads, 106 | mlp_ratio=mlp_ratio, 107 | depth=depth, 108 | norm_layer=norm_layer, 109 | **kwargs) 110 | 111 | self.blocks = nn.ModuleList([ 112 | timm_vit.Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 113 | for i in range(depth)]) 114 | 115 | del self.head # remove the head 116 | del self.norm # remove the normalization at the end, which will be added in the trunk 117 | 118 | def forward(self, x): 119 | B = x.shape[0] 120 | x = self.patch_embed(x) 121 | 122 | cls_tokens = self.cls_token.expand(B, -1, -1) 123 | x = torch.cat((cls_tokens, x), dim=1) 124 | x = x + self.pos_embed 125 | x = self.pos_drop(x) 126 | 127 | for blk in self.blocks: 128 | x = blk(x) 129 | return x 130 | 131 | def load(self, path): 132 | """ 133 | Positional embedding interpolation from DeiT 134 | https://github.com/facebookresearch/deit 135 | """ 136 | if os.path.exists(path): 137 | logging(f"Loading encoder from weights from {path}. Will apply pos_embed interpolation.", True, "green") 138 | checkpoint = torch.load(path, map_location='cpu') 139 | else: 140 | gs_green_path = path[:path.rfind('/')] + '/gs_green.pth' 141 | checkpoint = torch.load(gs_green_path, map_location='cpu') 142 | logging(f"Encoder weights not found at {path}. 
Loading from gs_green", True, "warning") 143 | if 'pos_embed' in checkpoint: 144 | pos_embed_checkpoint = checkpoint['pos_embed'] 145 | embedding_size = pos_embed_checkpoint.shape[-1] 146 | num_patches = self.patch_embed.num_patches 147 | num_extra_tokens = self.pos_embed.shape[-2] - num_patches 148 | # height (== width) for the checkpoint position embedding 149 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 150 | # height (== width) for the new position embedding 151 | new_size = int(num_patches ** 0.5) 152 | # class_token and dist_token are kept unchanged 153 | if orig_size != new_size: 154 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 155 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 156 | # only the position tokens are interpolated 157 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 158 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 159 | pos_tokens = torch.nn.functional.interpolate( 160 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 161 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 162 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 163 | checkpoint['pos_embed'] = new_pos_embed 164 | self.load_state_dict(checkpoint) 165 | 166 | class MAEViTEncoder(Encoder): 167 | """ 168 | Masked Autoencoder with VisionTransformer backbone 169 | https://arxiv.org/pdf/2111.06377.pdf 170 | """ 171 | def __init__(self, mask_ratio, 172 | img_size=224, patch_size=16, in_chans=3, 173 | embed_dim=768, depth=3, num_heads=12, 174 | mlp_ratio=4., norm_layer=nn.LayerNorm): 175 | super().__init__() 176 | self.mask_ratio = mask_ratio 177 | # -------------------------------------------------------------------------- 178 | # MAE encoder specifics 179 | self.patch_embed = timm_vit.PatchEmbed(img_size, patch_size, in_chans, embed_dim) 180 | num_patches = self.patch_embed.num_patches 181 | 182 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 183 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 184 | 185 | self.blocks = nn.ModuleList([ 186 | timm_vit.Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 187 | for i in range(depth)]) 188 | 189 | self.initialize_weights() 190 | 191 | def initialize_weights(self): 192 | # initialization 193 | # initialize (and freeze) pos_embed by sin-cos embedding 194 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 195 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 196 | 197 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 198 | w = self.patch_embed.proj.weight.data 199 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 200 | 201 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
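        # Note: the learned cls_token initialized below is prepended after random_masking, so with
        # the default 224/16 patch grid the encoder processes 1 + int(196 * (1 - mask_ratio)) tokens
        # per image, e.g. 50 tokens at mask_ratio = 0.75 (an illustrative value, not a default here).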
202 | torch.nn.init.normal_(self.cls_token, std=.02) 203 | 204 | # initialize nn.Linear and nn.LayerNorm 205 | self.apply(self._init_weights) 206 | 207 | def _init_weights(self, m): 208 | if isinstance(m, nn.Linear): 209 | # we use xavier_uniform following official JAX ViT: 210 | torch.nn.init.xavier_uniform_(m.weight) 211 | if isinstance(m, nn.Linear) and m.bias is not None: 212 | nn.init.constant_(m.bias, 0) 213 | elif isinstance(m, nn.LayerNorm): 214 | nn.init.constant_(m.bias, 0) 215 | nn.init.constant_(m.weight, 1.0) 216 | 217 | def patchify(self, imgs): 218 | """ 219 | imgs: (N, 3, H, W) 220 | x: (N, L, patch_size**2 *3) 221 | """ 222 | p = self.patch_embed.patch_size[0] 223 | assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 224 | 225 | h = w = imgs.shape[2] // p 226 | x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) 227 | x = torch.einsum('nchpwq->nhwpqc', x) 228 | x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) 229 | return x 230 | 231 | def unpatchify(self, x): 232 | """ 233 | x: (N, L, patch_size**2 *3) 234 | imgs: (N, 3, H, W) 235 | """ 236 | p = self.patch_embed.patch_size[0] 237 | h = w = int(x.shape[1]**.5) 238 | assert h * w == x.shape[1] 239 | 240 | x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) 241 | x = torch.einsum('nhwpqc->nchpwq', x) 242 | imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) 243 | return imgs 244 | 245 | def random_masking(self, x): 246 | """ 247 | Perform per-sample random masking by per-sample shuffling. 248 | Per-sample shuffling is done by argsort random noise. 249 | x: [N, L, D], sequence 250 | """ 251 | N, L, D = x.shape # batch, length, dim 252 | len_keep = int(L * (1 - self.mask_ratio)) 253 | 254 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 255 | 256 | # sort noise for each sample 257 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 258 | ids_restore = torch.argsort(ids_shuffle, dim=1) 259 | 260 | # keep the first subset 261 | ids_keep = ids_shuffle[:, :len_keep] 262 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 263 | 264 | # generate the binary mask: 0 is keep, 1 is remove 265 | mask = torch.ones([N, L], device=x.device) 266 | mask[:, :len_keep] = 0 267 | # unshuffle to get the binary mask 268 | mask = torch.gather(mask, dim=1, index=ids_restore) 269 | 270 | return x_masked, mask, ids_restore 271 | 272 | def forward(self, x): 273 | # embed patches 274 | x = self.patch_embed(x) 275 | 276 | # add pos embed w/o cls token 277 | x = x + self.pos_embed[:, 1:, :] 278 | 279 | # masking: length -> length * mask_ratio 280 | x, mask, ids_restore = self.random_masking(x) 281 | 282 | # append cls token 283 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 284 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 285 | x = torch.cat((cls_tokens, x), dim=1) 286 | 287 | # apply Transformer blocks 288 | for blk in self.blocks: 289 | x = blk(x) 290 | # x = self.norm(x) 291 | 292 | return (x, mask, ids_restore) -------------------------------------------------------------------------------- /t3/pretrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from t3.models import T3 4 | from t3.utils import logging 5 | from t3.data_loader import WeightedDataLoader 6 | from t3.models.nn_utils import mae_unpatchify, cross_mae_unpatchify, mae_unpatchify_pred_only, mae_apply_patchified_mask, get_device 7 | import hydra 8 | from omegaconf import OmegaConf 9 | import numpy as np 10 | from tqdm 
import tqdm 11 | 12 | from datetime import datetime 13 | from .utils import is_main_process, get_entry_or, make_dataset_pie_plot 14 | from .task_utils import rot_rmse, tra_rmse, count_classification_topk 15 | from torchvision.transforms.v2 import ToPILImage 16 | 17 | import os 18 | 19 | try: 20 | import wandb 21 | except ImportError: 22 | wandb = None 23 | print("wandb is not installed, will not log to wandb") 24 | 25 | class T3Pretrain: 26 | def __init__(self, cfg, run_id=None): 27 | self.cfg = cfg 28 | self.model = None 29 | self.train_dataset = None 30 | self.eval_dataset = None 31 | self.img_preprocessors = None 32 | self.optimizer = None 33 | self.scheduler = None 34 | 35 | self.encoder_frozen = False 36 | self.trunk_frozen = False 37 | self.scheduled_unfreeze_step = -1 38 | 39 | self.min_avg_val_loss = np.inf 40 | 41 | self.device = get_device() 42 | 43 | if run_id is None: 44 | self.run_id = self.gen_run_id() 45 | if "comment" in self.cfg: 46 | self.run_id += "-" + self.cfg.comment 47 | else: 48 | self.run_id = run_id 49 | if self.cfg.train.wandb and wandb and is_main_process(): 50 | wandb.init( 51 | project="TransferableTactileTransformer", 52 | config=OmegaConf.to_container(self.cfg, resolve=True), 53 | name=self.run_id, 54 | entity=self.cfg.train.wandb_entity, 55 | magic=False) 56 | # define our custom x axis metric 57 | wandb.define_metric("train/step") 58 | wandb.define_metric("eval/step") 59 | # set all other train/ metrics to use this step 60 | wandb.define_metric("train/*", step_metric="train/step") 61 | wandb.define_metric("eval/*", step_metric="eval/step") 62 | 63 | def gen_run_id(self): 64 | return f"{datetime.now().strftime('%Y-%m-%d_%H_%M_%S')}" 65 | 66 | def setup_model(self): 67 | self.model = T3(self.cfg.network) 68 | self.encoder_frozen = False 69 | self.trunk_frozen = False 70 | self.scheduled_unfreeze_step = -1 71 | if get_entry_or(self.cfg.train, "freeze_encoder", False): 72 | self.model.freeze_encoder() 73 | self.encoder_frozen = True 74 | logging("Encoder will be frozen", True, "blue") 75 | if get_entry_or(self.cfg.train, "freeze_trunk", False): 76 | self.model.freeze_trunk() 77 | self.trunk_frozen = True 78 | logging("Trunk will be frozen", True, "blue") 79 | if self.encoder_frozen and self.trunk_frozen: 80 | if get_entry_or(self.cfg.train, "scheduled_unfreeze", False): 81 | self.scheduled_unfreeze_step = self.cfg.train.scheduled_unfreeze_step 82 | logging(f"Encoder and trunk will be frozen only until step {self.scheduled_unfreeze_step}", True, "blue") 83 | self.model.model_summary() 84 | 85 | def setup_optimizer(self): 86 | assert self.model is not None 87 | trunk_params = [v for k, v in self.model.named_parameters() if "trunk" in k] 88 | nontrunk_params = [v for k, v in self.model.named_parameters() if "trunk" not in k] 89 | params = [ 90 | {"params": trunk_params}, 91 | {"params": nontrunk_params, "lr": self.cfg.train.nontrunk_lr_scale * self.cfg.train.optimizer.lr}] 92 | self.optimizer = eval(self.cfg.train.optimizer["_target_"])( 93 | params=params, 94 | **{k: v for k, v in self.cfg.train.optimizer.items() if k != "_target_"}) 95 | self.scheduler = hydra.utils.instantiate(self.cfg.train.scheduler, optimizer=self.optimizer) 96 | 97 | def setup_dataset(self): 98 | self.train_dataset = {} 99 | self.eval_dataset = {} 100 | 101 | # stats 102 | dataset_sizes = {} 103 | encoder_sizes = {} 104 | decoder_sizes = {} 105 | def _add_or_create_stat(d, key, value): 106 | if key.startswith("panda"): 107 | # combine all panda entries 108 | if "panda_probe" in d: 109 | 
d["panda_probe"] += value 110 | else: 111 | d["panda_probe"] = value 112 | elif key.startswith("cnc"): 113 | # combine all cnc entries 114 | if "cnc_probe" in d: 115 | d["cnc_probe"] += value 116 | else: 117 | d["cnc_probe"] = value 118 | else: 119 | if key in d: 120 | d[key] += value 121 | else: 122 | d[key] = value 123 | 124 | num_data_workers = self.cfg.train.num_data_workers 125 | 126 | def _get_dl_config(ds_cfg, folder, for_eval): 127 | res = ds_cfg.copy() 128 | res["data_dir"] = os.path.join(ds_cfg["data_dir"], folder) 129 | if for_eval: 130 | # turn off data augmentation for eval dataset 131 | res["random_resize_crop"] = False 132 | res["random_hv_flip_prob"] = 0 133 | res["color_jitter"] = None 134 | return res 135 | 136 | # load all datasets according to the config as one WeightedDataLoader 137 | for ds_name, ds_cfg in self.cfg.datasets.items(): 138 | if ds_name.startswith("VAR_"): 139 | # skip the variables 140 | continue 141 | if not ds_cfg["activate"]: 142 | continue 143 | 144 | eval_only = ds_cfg["eval_only"] 145 | 146 | data_loader_cfg = dict(ds_cfg["data_loader"]) 147 | data_loader_cfg["batch_size"] = self.cfg.train.batch_size 148 | 149 | train_ds_cfg = _get_dl_config(data_loader_cfg, "train", for_eval=eval_only) 150 | eval_ds_cfg = _get_dl_config(data_loader_cfg, "val", for_eval=True) 151 | 152 | train_ds = hydra.utils.instantiate(train_ds_cfg) 153 | eval_ds = hydra.utils.instantiate(eval_ds_cfg) 154 | 155 | self.eval_dataset[ds_name] = eval_ds.get_dataloader(num_data_workers) 156 | if eval_only: 157 | self.eval_dataset[f"{ds_name}_train"] = train_ds.get_dataloader(num_data_workers) 158 | else: 159 | self.train_dataset[ds_name] = train_ds.get_dataloader(num_data_workers) 160 | 161 | total_count = len(train_ds) * self.cfg.train.batch_size + len(eval_ds) * self.cfg.train.batch_size 162 | _add_or_create_stat(dataset_sizes, ds_name, total_count) 163 | _add_or_create_stat(encoder_sizes, data_loader_cfg["encoder_domain"], total_count) 164 | _add_or_create_stat(decoder_sizes, data_loader_cfg["decoder_domain"], total_count) 165 | self.train_dataloader = WeightedDataLoader(list(self.train_dataset.values()), weight_type=self.cfg.train.dl_weight_type) 166 | self.eval_dataloader = WeightedDataLoader(list(self.eval_dataset.values()), weight_type=self.cfg.train.dl_weight_type) 167 | logging(f"Total train batches: {len(self.train_dataloader)}, eval batches: {len(self.eval_dataloader)}", True, "blue") 168 | 169 | if self.cfg.train.wandb and wandb and is_main_process(): 170 | # make dataset stat pie plots 171 | dataset_sizes_plot = make_dataset_pie_plot(dataset_sizes, "Dataset sizes", show=False) 172 | encoder_sizes_plot = make_dataset_pie_plot(encoder_sizes, "Encoder sizes", show=False) 173 | decoder_sizes_plot = make_dataset_pie_plot(decoder_sizes, "Decoder sizes", show=False) 174 | wandb.log({ 175 | f"stats/dataset_sizes": wandb.Image(dataset_sizes_plot), 176 | f"stats/encoder_sizes": wandb.Image(encoder_sizes_plot), 177 | f"stats/decoder_sizes": wandb.Image(decoder_sizes_plot), 178 | }) 179 | 180 | @staticmethod 181 | def compose_loss_history(loss_history, enc_domain, dec_domain, loss, pred=None, Y=None, denormalize_func=None): 182 | # Add to all losses 183 | loss_history["all_losses"].append(loss.item()) 184 | 185 | # add entry to loss_history 186 | entry_key = f"loss_{enc_domain}_{dec_domain}" 187 | if entry_key not in loss_history: 188 | loss_history[entry_key] = [loss.item()] 189 | else: 190 | loss_history[entry_key].append(loss.item()) 191 | 192 | # RMSE for pose estimation 193 | 
if "pose_estimation_6d" in dec_domain and (pred is not None) and (Y is not None): 194 | rot_rmse_key = f"rot_rmse_{enc_domain}" 195 | tra_rmse_key = f"tra_rmse_{enc_domain}" 196 | 197 | rot_rmse_val = rot_rmse(pred, Y, denormalize_func=denormalize_func) 198 | tra_rmse_val = tra_rmse(pred, Y, denormalize_func=denormalize_func) 199 | 200 | if rot_rmse_key not in loss_history: 201 | loss_history[rot_rmse_key] = [rot_rmse_val] 202 | else: 203 | loss_history[rot_rmse_key].append(rot_rmse_val) 204 | 205 | if tra_rmse_key not in loss_history: 206 | loss_history[tra_rmse_key] = [tra_rmse_val] 207 | else: 208 | loss_history[tra_rmse_key].append(tra_rmse_val) 209 | 210 | if "pose_estimation_3d" in dec_domain and (pred is not None) and (Y is not None): 211 | tra_rmse_key = f"tra_rmse_{enc_domain}" 212 | tra_rmse_val = tra_rmse(pred, Y, denormalize_func=denormalize_func) 213 | 214 | if tra_rmse_key not in loss_history: 215 | loss_history[tra_rmse_key] = [tra_rmse_val] 216 | else: 217 | loss_history[tra_rmse_key].append(tra_rmse_val) 218 | 219 | 220 | # classification accuracy 221 | if "cls" in dec_domain and (pred is not None) and (Y is not None): 222 | acc_top1_key = f"acc_{enc_domain}_top1" 223 | 224 | acc_top1_val = count_classification_topk(pred, Y, k=1) / len(Y) 225 | 226 | if acc_top1_key not in loss_history: 227 | loss_history[acc_top1_key] = [acc_top1_val] 228 | else: 229 | loss_history[acc_top1_key].append(acc_top1_val) 230 | 231 | 232 | @staticmethod 233 | def print_train_vs_test_stats(train_stat, test_stat): 234 | l = 35 235 | tl = 18 236 | logging("------- training vs eval stats -------", True, "blue") 237 | common_entries = set(train_stat.keys()).intersection(set(test_stat.keys())) 238 | for entry in sorted(common_entries): 239 | train_val = np.mean(train_stat[entry]) 240 | test_val = np.mean(test_stat[entry]) 241 | train_text = f"train: {train_val:.4f}".rjust(tl, ' ') 242 | val_text = f"test: {test_val:.4f}".rjust(tl, ' ') 243 | print(f"{entry.rjust(l, ' ')} \t {train_text} \t {val_text}") 244 | train_specific = set(train_stat.keys()).difference(common_entries) 245 | for entry in sorted(train_specific): 246 | train_val = np.mean(train_stat[entry]) 247 | train_text = f"train: {train_val:.4f}".rjust(tl, ' ') 248 | print(f"{entry.rjust(l, ' ')} \t {train_text}") 249 | test_specific = set(test_stat.keys()).difference(common_entries) 250 | for entry in sorted(test_specific): 251 | test_val = np.mean(test_stat[entry]) 252 | val_text = f"test: {test_val:.4f}".rjust(tl, ' ') 253 | print(f"{entry.rjust(l, ' ')} \t {' '*tl} \t {val_text}") 254 | 255 | def save_model(self, run_id, avg_val_loss, cur_step): 256 | if cur_step > 50 and avg_val_loss < self.min_avg_val_loss: 257 | # save as the best model 258 | self.min_avg_val_loss = avg_val_loss 259 | path = f"checkpoints/best_{run_id}" 260 | logging(f"Saving model to {path} as the best model", True, "green") 261 | else: 262 | path = f"checkpoints/{run_id}" 263 | 264 | logging(f"Current avg. test loss {avg_val_loss} v.s. best so far {self.min_avg_val_loss}. 
"\ 265 | f"Saving model to {path}", True, "green") 266 | # save the model 267 | self.model.save_components(path) 268 | 269 | # save the optimizer and scheduler 270 | opt_type = self.cfg.train.optimizer["_target_"].split(".")[-1] 271 | torch.save(self.optimizer.state_dict(), f"{path}/optimizer_{opt_type}.pt") 272 | sch_type = self.cfg.train.scheduler["_target_"].split(".")[-1] 273 | torch.save(self.scheduler.state_dict(), f"{path}/scheduler_{sch_type}.pt") 274 | 275 | # save the config file 276 | with open(f"{path}/config.yaml", "w") as f: 277 | f.write(OmegaConf.to_yaml(self.cfg)) 278 | # # save the git commit hash. Install gitpython and uncomment to use this feature 279 | # try: 280 | # repo = git.Repo(search_parent_directories=True) 281 | # with open(f"{path}/commit_hash.txt", "w") as f: 282 | # f.write(repo.head.object.hexsha) 283 | # del repo 284 | # except: 285 | # logging("Failed to save git commit hash, ignored", True, "red") 286 | 287 | def load_model(self, path, load_optimizer=False, load_scheduler=False): 288 | # load the network 289 | self.model.load_components(path) 290 | logging(f"Loaded model from {path}", True, "green") 291 | self.model.to(self.device) # need to move the model to device before loading optimizer and scheduler 292 | 293 | # load the optimizer and scheduler 294 | if load_optimizer: 295 | opt_type = self.cfg.train.optimizer["_target_"].split(".")[-1] 296 | self.optimizer.load_state_dict(torch.load(f"{path}/optimizer_{opt_type}.pt")) 297 | logging(f"Loaded optimizer from {path}", True, "green") 298 | 299 | if load_scheduler: 300 | sch_type = self.cfg.train.scheduler["_target_"].split(".")[-1] 301 | self.scheduler.load_state_dict(torch.load(f"{path}/scheduler_{sch_type}.pt")) 302 | logging(f"Loaded scheduler from {path}", True, "green") 303 | 304 | def forward_once(self, data_batch): 305 | enc_domain = data_batch["encoder_domain"] 306 | dec_domain = data_batch["decoder_domain"] 307 | batch_x = data_batch["X"] 308 | 309 | # use label denormalize function to calculate RMSE 310 | if "pose_estimation_" in dec_domain: 311 | label_inv_normalize = data_batch["label_inv_normalize"] 312 | else: 313 | label_inv_normalize = None 314 | 315 | # set the domains & forward mode for the model 316 | if "electroassem" in dec_domain or "pose_estimation" in dec_domain: 317 | forward_mode = "multi_tower" 318 | else: 319 | forward_mode = "single_tower" 320 | self.model.set_domains(enc_domain, dec_domain, forward_mode) 321 | 322 | if forward_mode == "single_tower": 323 | Xs = batch_x.to(self.device, non_blocking=True) 324 | pred = self.model(Xs) 325 | else: 326 | Xs = [x.to(self.device, non_blocking=True) for x in batch_x] 327 | pred = self.model(*Xs) 328 | return label_inv_normalize, pred 329 | 330 | def train_test(self, run_id, total_train_steps, test_every, test_steps): 331 | self.model.to(self.device) 332 | cur_step = 0 333 | 334 | train_iter = iter(self.train_dataloader) 335 | while cur_step < total_train_steps: 336 | # run training for test_every steps 337 | pbar = tqdm(range(test_every), position=0, leave=True) 338 | self.model.train() 339 | 340 | # unfreeze encoder and trunk if scheduled 341 | if self.scheduled_unfreeze_step > 0 and cur_step >= self.scheduled_unfreeze_step: 342 | if self.encoder_frozen: 343 | self.model.unfreeze_encoder() 344 | self.encoder_frozen = False 345 | logging("Encoder unfrozen", True, "green") 346 | if self.trunk_frozen: 347 | self.model.unfreeze_trunk() 348 | self.trunk_frozen = False 349 | logging("Trunk unfrozen", True, "green") 350 | 351 | 
train_loss_history = {"all_losses": []} 352 | for idx in pbar: 353 | cur_step += 1 354 | if cur_step >= total_train_steps: 355 | break 356 | # step the dataloader 357 | data = next(train_iter) 358 | enc_domain = data["encoder_domain"] 359 | dec_domain = data["decoder_domain"] 360 | batch_y = data["Y"] 361 | 362 | self.optimizer.zero_grad() 363 | label_inv_normalize, pred = self.forward_once(data) 364 | 365 | Y = batch_y.to(self.device) 366 | loss = self.model.compute_loss(pred, Y) 367 | loss.backward() 368 | self.optimizer.step() 369 | self.scheduler.step() 370 | 371 | self.compose_loss_history(train_loss_history, enc_domain, dec_domain, loss, denormalize_func=label_inv_normalize) 372 | 373 | # logging, if enabled 374 | if self.cfg.train.wandb and wandb and is_main_process() and cur_step % self.cfg.train.log_freq == 1: 375 | log_dict = { 376 | f"train/loss_{enc_domain}_{dec_domain}": loss.item(), 377 | f"train/epoch": cur_step // len(self.train_dataloader), 378 | f"train/step": cur_step, 379 | f"train/trunk_lr": self.optimizer.param_groups[0]["lr"], 380 | f"train/nontrunk_lr": self.optimizer.param_groups[1]["lr"]} 381 | if "pose_estimation_6d" in dec_domain: 382 | log_dict[f"train/6dpe_rot_rmse_{enc_domain}"] = rot_rmse(pred, Y, denormalize_func=label_inv_normalize) 383 | log_dict[f"train/6dpe_tra_rmse_{enc_domain}"] = tra_rmse(pred, Y, denormalize_func=label_inv_normalize) 384 | if "pose_estimation_3d" in dec_domain: 385 | log_dict[f"train/3dpe_tra_rmse_{enc_domain}"] = tra_rmse(pred, Y, denormalize_func=label_inv_normalize) 386 | if "cls" in dec_domain: 387 | log_dict[f"train/acc_{dec_domain}_top1"] = count_classification_topk(pred.detach(), batch_y, k=1) / len(Y) 388 | log_dict[f"train/acc_{dec_domain}_top5"] = count_classification_topk(pred.detach(), batch_y, k=5) / len(Y) 389 | wandb.log(log_dict) 390 | pbar.set_description( 391 | f"Train {cur_step}/{total_train_steps} steps | loss: {loss.item():.4f}") 392 | 393 | # run eval for test_steps 394 | test_loss_history = self.test(test_steps, run_id, cur_step, 395 | enable_wandb=(self.cfg.train.wandb and wandb and is_main_process())) 396 | 397 | self.print_train_vs_test_stats(train_loss_history, test_loss_history) 398 | 399 | # save model 400 | if self.cfg.train.save_model and is_main_process(): 401 | avg_val_loss = np.mean(test_loss_history["all_losses"]) 402 | self.save_model(run_id, avg_val_loss, cur_step) 403 | return 404 | 405 | @torch.no_grad() 406 | def test(self, test_steps, run_id, cur_step, enable_wandb): 407 | self.model.to(self.device) 408 | test_iter = iter(self.eval_dataloader) 409 | self.model.eval() 410 | losses = [] 411 | test_loss_history = {"all_losses": []} 412 | 413 | pbar = tqdm(range(test_steps), position=0, leave=True) 414 | for idx in pbar: 415 | # generate visualizations using inv_normalize_func 416 | data = next(test_iter) 417 | enc_domain = data["encoder_domain"] 418 | dec_domain = data["decoder_domain"] 419 | batch_y = data["Y"] 420 | # the denormalize function for the images 421 | inv_normalize_func = data["inv_normalize"] 422 | 423 | label_inv_normalize, pred = self.forward_once(data) 424 | Y = batch_y.to(self.device) 425 | 426 | loss = self.model.compute_loss(pred, Y) 427 | losses.append(loss.item()) 428 | 429 | self.compose_loss_history(test_loss_history, enc_domain, dec_domain, loss, pred, batch_y, denormalize_func=label_inv_normalize) 430 | 431 | # obtain loss, and (optionally) generate mae visualizations 432 | if get_entry_or(self.cfg.train, "generate_mae_visualizations", True) and idx == 0 and "mae" in 
dec_domain: 433 | # generate visualizations 434 | (pred_imgs, mask, ids_restore) = pred 435 | self.generate_mae_visualizations( 436 | data["X"].to(self.device, non_blocking=True), 437 | self.cfg.network.patch_size, pred_imgs, mask, inv_normalize_func, run_id, cur_step) 438 | 439 | pbar.set_description( 440 | f"Test {idx}/{test_steps} steps | loss: {loss.item():.4f}") 441 | if enable_wandb: 442 | log_items = { 443 | "eval/epoch": cur_step // len(self.train_dataloader), 444 | "eval/step": cur_step, 445 | f"eval/avg_test_loss": np.mean(losses)} 446 | for k, v in test_loss_history.items(): 447 | log_items[f"eval/{k}"] = np.mean(v) 448 | wandb.log(log_items) 449 | 450 | return test_loss_history 451 | 452 | @torch.no_grad() 453 | def predict(self, enc_domain, dec_domain, batch_x): 454 | self.model.to(self.device) 455 | self.model.eval() 456 | # set the domains & forward mode for the model 457 | if "electroassem" in dec_domain or "pose_estimation" in dec_domain: 458 | forward_mode = "multi_tower" 459 | else: 460 | forward_mode = "single_tower" 461 | self.model.set_domains(enc_domain, dec_domain, forward_mode) 462 | 463 | if forward_mode == "single_tower": 464 | Xs = batch_x.to(self.device, non_blocking=True) 465 | pred = self.model(Xs) 466 | else: 467 | Xs = [x.to(self.device, non_blocking=True) for x in batch_x] 468 | pred = self.model(*Xs) 469 | return pred 470 | 471 | @torch.no_grad() 472 | def generate_mae_visualizations(self, 473 | imgs, patch_size, preds, masks, 474 | inv_normalize_func, run_id, cur_step, 475 | num_to_generate=5, 476 | save=True): 477 | """ 478 | unpatchify preds (N, L, patch_size**2 *3) back to images (N, 3, H, W) 479 | """ 480 | if preds.shape[1] == masks.shape[1]: 481 | # the case for original MAE 482 | pred_imgs = inv_normalize_func(mae_unpatchify(preds, patch_size)).detach().cpu() 483 | pred_imgs_removed = inv_normalize_func(mae_unpatchify_pred_only(preds, imgs, masks, patch_size)).detach().cpu() 484 | else: 485 | # the case for cross MAE 486 | pred_imgs = inv_normalize_func(cross_mae_unpatchify(preds, imgs, masks, patch_size)).detach().cpu() 487 | pred_imgs_removed = pred_imgs 488 | ori_imgs = inv_normalize_func(imgs).detach().cpu() 489 | masked_imgs = mae_apply_patchified_mask(ori_imgs, masks, patch_size).detach().cpu() 490 | pil_converter = ToPILImage() 491 | 492 | imgs = [] 493 | for i in range(min(num_to_generate, len(pred_imgs))): 494 | img = torch.cat([ori_imgs[i], masked_imgs[i], pred_imgs[i], pred_imgs_removed[i]], dim=2) 495 | imgs.append(img) 496 | # save all 5 images as one big image 497 | imgs = torch.cat(imgs, dim=1) 498 | pil_img = pil_converter(imgs) 499 | 500 | if save: 501 | # save images 502 | p = f"checkpoints/{run_id}/mae_visualizations/{cur_step}" 503 | os.makedirs(p, exist_ok=True) 504 | pil_img.save(os.path.join(p, f"visualize.jpg")) 505 | 506 | if self.cfg.train.wandb and wandb and is_main_process(): 507 | wandb.log({ 508 | f"mae_visualizations/step": cur_step, 509 | f"mae_visualizations/visualize": wandb.Image(pil_img), 510 | }) 511 | return pil_img 512 | 513 | def train(self): 514 | 515 | if len(self.cfg.train.finetune_from) > 0: 516 | logging(f"WARNING: Loading existing model to finetune from {self.cfg.train.finetune_from}", True, "red") 517 | load_optimizer = get_entry_or(self.cfg.train, "load_optimizer", False) 518 | load_scheduler = get_entry_or(self.cfg.train, "load_scheduler", False) 519 | logging(f"Loading optimizer: {load_optimizer}, Loading scheduler: {load_scheduler}", True, "red") 520 | self.load_model(self.cfg.train.finetune_from, 
load_optimizer=load_optimizer, load_scheduler=load_scheduler) 521 | 522 | self.train_test( 523 | self.run_id, self.cfg.train.total_train_steps, self.cfg.train.test_every, self.cfg.train.test_steps) 524 | --------------------------------------------------------------------------------
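
The sketches below are illustrative only and are not part of the repository; they restate, in standalone form, a few of the patterns used in t3/pretrain.py above. First, the per-domain metric bookkeeping in compose_loss_history repeatedly does "create the list if the key is missing, otherwise append", and print_train_vs_test_stats later averages each list with np.mean. A minimal sketch of that accumulation pattern follows; the helper name append_metric and the sample values are hypothetical, while key names mirror the f"tra_rmse_{enc_domain}" style used above.

import numpy as np

def append_metric(history: dict, key: str, value: float) -> None:
    # Equivalent to the explicit "if key not in history: ... else: append" branches.
    history.setdefault(key, []).append(value)

history = {"all_losses": []}
append_metric(history, "tra_rmse_gs_green", 0.12)   # hypothetical values
append_metric(history, "tra_rmse_gs_green", 0.08)
append_metric(history, "acc_gs_green_top1", 0.75)

# Reporting then averages each list, as print_train_vs_test_stats does with np.mean.
for key, values in sorted(history.items()):
    if values:
        print(f"{key.rjust(35, ' ')} \t mean: {np.mean(values):.4f}")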
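
save_model and load_model persist optimizer and scheduler state next to the model, naming each file after the class name taken from the Hydra "_target_" string. The round trip below is a sketch under assumed choices (Adam and StepLR, a made-up checkpoint directory); only the torch.save/torch.load calls and the "optimizer_<Class>.pt" / "scheduler_<Class>.pt" naming mirror the code above.

import os
import torch

model = torch.nn.Linear(4, 2)                 # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)

path = "checkpoints/example_run"              # hypothetical run directory
os.makedirs(path, exist_ok=True)

# "_target_" strings such as "torch.optim.Adam" reduce to the class name.
opt_type = "torch.optim.Adam".split(".")[-1]                  # -> "Adam"
sch_type = "torch.optim.lr_scheduler.StepLR".split(".")[-1]   # -> "StepLR"
torch.save(optimizer.state_dict(), f"{path}/optimizer_{opt_type}.pt")
torch.save(scheduler.state_dict(), f"{path}/scheduler_{sch_type}.pt")

# Restore later; as load_model notes, move the model to its device first so
# the optimizer state is rebuilt against the right parameters.
optimizer.load_state_dict(torch.load(f"{path}/optimizer_{opt_type}.pt"))
scheduler.load_state_dict(torch.load(f"{path}/scheduler_{sch_type}.pt"))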
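
train_test keeps the encoder and trunk frozen until cur_step reaches scheduled_unfreeze_step, then calls unfreeze_encoder/unfreeze_trunk exactly once. The real unfreeze helpers live in the model classes and are not shown in this file; the sketch below only illustrates the usual PyTorch mechanism (toggling requires_grad) with placeholder sub-modules.

import torch.nn as nn

def set_frozen(module: nn.Module, frozen: bool) -> None:
    # Freezing a sub-module means excluding its parameters from gradient updates.
    for p in module.parameters():
        p.requires_grad = not frozen

encoder = nn.Conv2d(3, 16, 3)   # placeholder sub-modules
trunk = nn.Linear(16, 16)

set_frozen(encoder, True)       # frozen during the warm-up phase
set_frozen(trunk, True)

cur_step, scheduled_unfreeze_step = 1000, 800   # hypothetical schedule
if scheduled_unfreeze_step > 0 and cur_step >= scheduled_unfreeze_step:
    set_frozen(encoder, False)  # analogous to unfreeze_encoder()
    set_frozen(trunk, False)    # analogous to unfreeze_trunk()

print(all(p.requires_grad for p in encoder.parameters()))  # True after unfreezing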
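
forward_once and predict choose the forward mode from the decoder domain: any domain containing "electroassem" or "pose_estimation" is treated as a two-image ("multi_tower") task, everything else as a single-image ("single_tower") task. The dispatcher below mirrors only that branching; the stub model and tensors are hypothetical.

import torch

def forward_mode_for(dec_domain: str) -> str:
    # Pose estimation and electronics-assembly decoders consume one image per
    # tower; all other decoders take a single batch.
    if "electroassem" in dec_domain or "pose_estimation" in dec_domain:
        return "multi_tower"
    return "single_tower"

def run(model, batch_x, dec_domain, device="cpu"):
    if forward_mode_for(dec_domain) == "single_tower":
        return model(batch_x.to(device, non_blocking=True))
    xs = [x.to(device, non_blocking=True) for x in batch_x]
    return model(*xs)

# Stub "model" that just reports how many towers it received.
stub = lambda *xs: len(xs)
print(run(stub, torch.zeros(2, 3, 224, 224), "cls_touch_and_go"))          # 1
print(run(stub, [torch.zeros(2, 3, 224, 224)] * 2, "pose_estimation_3d"))  # 2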
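
generate_mae_visualizations lays out each example as [original | masked | reconstruction | reconstruction-only] by concatenating along the width (dim=2), then stacks the examples along the height (dim=1) before converting the result to a single PIL image. The sketch below reproduces just that grid composition with random tensors; the unpatchify and masking helpers from t3 are assumed and not reimplemented here.

import torch
from torchvision.transforms import ToPILImage

num_to_generate, C, H, W = 5, 3, 224, 224
# Stand-ins for ori_imgs, masked_imgs, pred_imgs, pred_imgs_removed above.
panels = [torch.rand(num_to_generate, C, H, W) for _ in range(4)]

rows = []
for i in range(num_to_generate):
    # Concatenate the four views of example i side by side (along width).
    rows.append(torch.cat([p[i] for p in panels], dim=2))
# Stack all examples vertically (along height) into one big image.
grid = torch.cat(rows, dim=1)                 # (C, num_to_generate*H, 4*W)

pil_img = ToPILImage()(grid)
pil_img.save("visualize.jpg")                 # same file name as used above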