├── example_cfg.yaml ├── data_preprocessing ├── utils.py ├── dataset_list.yaml ├── crop_images.py ├── customized_r3m.py ├── normalize_actions.py ├── move_h5_image_to_png.py ├── extract_image_features.py ├── extract_language_features.py ├── check_data.ipynb ├── convert_tfds_to_h5.py └── rt-x_data_cfg.yaml ├── README.md ├── example_model.py ├── train.py └── dataset.py /example_cfg.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | horizon: 16 3 | task: rt-x 4 | 5 | model: 6 | target: example_model.ExampleModel 7 | kwargs: 8 | model_kwargs: 9 | 10 | training_kwargs: 11 | 12 | 13 | trainer: 14 | target: lightning.pytorch.trainer.Trainer 15 | kwargs: 16 | devices: [0, 1] 17 | max_epochs: 10 18 | check_val_every_n_epoch: 1 19 | log_every_n_steps: 10 20 | logger: 21 | target: lightning.pytorch.loggers.wandb.WandbLogger 22 | kwargs: 23 | project: 24 | name: 25 | num_sanity_val_steps: 2 26 | 27 | dataset: 28 | target: example_dataset.MultiDataset 29 | kwargs: 30 | root_dir: 31 | dataset_names: 32 | data_cfg: 33 | horizon: ${horizon} 34 | get_language: True 35 | get_canonical_image: 36 | get_image_dict: 37 | get_low_dim: 38 | feature_type: r3m_resnet34 39 | # feature_type: clip_ViT-B32 40 | 41 | dataloader: 42 | batch_size: 512 43 | num_workers: 32 44 | pin_memory: True 45 | persistent_workers: True 46 | -------------------------------------------------------------------------------- /data_preprocessing/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | 5 | def uniform_normalization(raw_data: np.array, min_values: np.array, max_values: np.array): 6 | lower_bound = -1 7 | upper_bound = 1 8 | std = max_values - min_values 9 | 10 | normalized_data = np.zeros_like(raw_data) 11 | # normalize the raw_data to [-1, 1] uniformly 12 | for i in range(7): 13 | if std[i] == 0: # ignore 14 | continue 15 | normalized_data[:, i] = lower_bound + (upper_bound - lower_bound) * (raw_data[:, i] - min_values[i]) / std[i] 16 | 17 | return normalized_data 18 | 19 | def scale_only_normalization(raw_data, min_values, max_values): 20 | normalized_data = copy.deepcopy(raw_data) 21 | for i in range(7): 22 | if min_values[i] == 0 or max_values[i] == 0: 23 | continue 24 | larger = max(abs(min_values[i]), abs(max_values[i])) 25 | normalized_data[i] /= larger 26 | return normalized_data 27 | 28 | def scale_only_unnormalization(raw_data, min_values, max_values): 29 | unnormalized_data = copy.deepcopy(raw_data) 30 | for i in range(7): 31 | if min_values[i] == 0 or max_values[i] == 0: 32 | continue 33 | larger = max(abs(min_values[i]), abs(max_values[i])) 34 | unnormalized_data[i] *= larger 35 | return unnormalized_data 36 | 37 | def uniform_unnormalization(normalized_data: np.array, min_values: np.array, max_values: np.array): 38 | # unnormalize the data in [-1, 1] back to the raw_data 39 | raw_data = min_values + 0.5 * (normalized_data + 1) * (max_values - min_values) 40 | return raw_data 41 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data pre-processing and training code on Open-X-Embodiment for PyTorch users. 
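Configs in this repo follow a `target`/`kwargs` convention (see *example_cfg.yaml* above) that `train.py` resolves recursively with `instantiate_from_config`. The snippet below is a minimal, self-contained sketch of that pattern; the inline toy config and the `ToyModel` class are made up for illustration and are not part of the repo.

```
from importlib import import_module

from omegaconf import OmegaConf


class ToyModel:
    # stand-in for a real target such as example_model.ExampleModel
    def __init__(self, hidden_size, horizon):
        self.hidden_size = hidden_size
        self.horizon = horizon


def instantiate(cfg):
    # resolve a dotted "package.module.Class" path and call it with the
    # config's kwargs, mirroring instantiate_from_config in train.py
    module, cls = cfg['target'].rsplit('.', 1)
    return getattr(import_module(module), cls)(**cfg.get('kwargs', {}))


cfg = OmegaConf.create({
    'horizon': 16,
    'model': {
        'target': '__main__.ToyModel',
        'kwargs': {'hidden_size': 512, 'horizon': '${horizon}'},
    },
})
OmegaConf.resolve(cfg)          # fills in the ${horizon} interpolation
model = instantiate(cfg.model)  # ToyModel(hidden_size=512, horizon=16)
```

Training itself is launched through `train.py`, e.g. `python train.py --config_name example_cfg --task rt-x --devices 0 --horizon 16`, after the dataset root (`YOUR_DATASET_ROOT_DIR_HERE` in `train.py`) and the preprocessing outputs are in place.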
2 | 
3 | 
4 | ### Dataset access:
5 | Refer to the [rt-x official repo](https://github.com/google-deepmind/open_x_embodiment#dataset-access).
6 | 
7 | ### Usage:
8 | With this repo, you can:
9 | - Convert tfds datasets to h5 files with *convert_tfds_to_h5.py*. # converting large datasets takes MASSIVE disk space (up to 8 TB for kuka)
10 | 
11 | - Visualize the processed h5 files with *check_data.ipynb*.
12 | 
13 | - Extract raw images with *move_h5_image_to_png.py*.
14 | 
15 | - Extract image and language features with *extract_image_features.py* and *extract_language_features.py*. (we use R3M and CLIP, and both are easy to customize)
16 | 
17 | - Normalize actions according to the statistics with *normalize_actions.py* and *rt-x_data_cfg.yaml*.
18 | 
19 | and
20 | 
21 | - Customize your model and directly train it!
22 | 
23 | 
24 | ### Features:
25 | - For extracting features, we use multiprocessing across datasets for better efficiency.
26 | - For converting tfds to hdf5 files, we also support parallel processing by launching several runs with different --i arguments (one dataset per run). # turned off by default
27 | 
28 | 
29 | ### Environment
30 | In your Python environment:
31 | 
32 | - install tf and tfds
33 | ```
34 | pip install tensorflow tensorflow-datasets
35 | ```
36 | 
37 | - some basic libraries
38 | ```
39 | conda install h5py yaml jupyter tqdm omegaconf gdown matplotlib
40 | ```
41 | 
42 | - install pytorch (version not restricted)
43 | ```
44 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
45 | ```
46 | 
47 | - optional:
48 | ```
49 | conda install lightning transformers diffusers # for model training
50 | ```
51 | 
52 | ```
53 | pip install git+https://github.com/openai/CLIP.git
54 | ```
55 | 
56 | ```
57 | pip install git+https://github.com/facebookresearch/r3m.git
58 | ```
--------------------------------------------------------------------------------
/data_preprocessing/dataset_list.yaml:
--------------------------------------------------------------------------------
1 | small:
2 | - taco_play
3 | - jaco_play
4 | - berkeley_cable_routing
5 | - roboturk
6 | - nyu_door_opening_surprising_effectiveness
7 | - viola
8 | - berkeley_autolab_ur5
9 | - columbia_cairlab_pusht_real
10 | - stanford_kuka_multimodal_dataset_converted_externally_to_rlds
11 | - nyu_rot_dataset_converted_externally_to_rlds
12 | - stanford_hydra_dataset_converted_externally_to_rlds
13 | - austin_buds_dataset_converted_externally_to_rlds
14 | - nyu_franka_play_dataset_converted_externally_to_rlds
15 | - cmu_franka_exploration_dataset_converted_externally_to_rlds
16 | - ucsd_kitchen_dataset_converted_externally_to_rlds
17 | - ucsd_pick_and_place_dataset_converted_externally_to_rlds
18 | - austin_sailor_dataset_converted_externally_to_rlds
19 | - austin_sirius_dataset_converted_externally_to_rlds
20 | - usc_cloth_sim_converted_externally_to_rlds
21 | - utokyo_pr2_opening_fridge_converted_externally_to_rlds
22 | - utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds
23 | - utokyo_xarm_pick_and_place_converted_externally_to_rlds
24 | - berkeley_mvp_converted_externally_to_rlds
25 | - berkeley_rpt_converted_externally_to_rlds
26 | - kaist_nonprehensile_converted_externally_to_rlds
27 | - stanford_mask_vit_converted_externally_to_rlds
28 | - tokyo_u_lsmo_converted_externally_to_rlds
29 | - dlr_sara_pour_converted_externally_to_rlds
30 | - dlr_sara_grid_clamp_converted_externally_to_rlds
31 | - dlr_edan_shared_control_converted_externally_to_rlds
32 | - asu_table_top_converted_externally_to_rlds
33 | - stanford_robocook_converted_externally_to_rlds
34 | - eth_agent_affordances
35 | - imperialcollege_sawyer_wrist_cam
36 | - iamlab_cmu_pickup_insert_converted_externally_to_rlds
37 | - uiuc_d3field
38 | - utaustin_mutex
39 | - berkeley_fanuc_manipulation
40 | - cmu_play_fusion
41 | - cmu_stretch
42 | 
43 | large: # single file larger than 100G
44 | - fractal20220817_data
45 | - kuka
46 | - bridge
47 | - robo_net
48 | - toto
49 | - bc_z
50 | - maniskill_dataset_converted_externally_to_rlds
51 | - language_table
52 | 
53 | others:
54 | # wheeled
55 | - berkeley_gnm_recon
56 | - berkeley_gnm_cory_hall
57 | - berkeley_gnm_sac_son
58 | 
59 | # quadrupedal robot
60 | - utokyo_saytap_converted_externally_to_rlds
61 | 
62 | # bi-manual
63 | - utokyo_xarm_bimanual_converted_externally_to_rlds
64 | 
--------------------------------------------------------------------------------
/data_preprocessing/crop_images.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import yaml
4 | from pathlib import Path
5 | import tqdm
6 | import torch
7 | import torchvision.transforms as T
8 | from PIL import Image
9 | 
10 | 
11 | DATASET_ROOT_DIR = ''  # set this to your dataset root before running
12 | 
13 | 
14 | def get_preprocess(shorter_edge):
15 |     return T.Compose([
16 |         T.CenterCrop(shorter_edge),
17 |         T.Resize(224),
18 |         # T.ToTensor(),
19 |         # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
20 |     ])
21 | 
22 | class DatasetConverter:
23 |     def __init__(
24 |         self,
25 |         dataset_name: str,
26 |         src_h5_path: Path,
27 |         tgt_h5_path: Path,
28 |         data_cfg: dict
29 |     ):
30 |         self.dataset_name = dataset_name
31 |         self.src_h5_path = src_h5_path
32 |         self.tgt_h5_path = tgt_h5_path
33 |         self.data_cfg = data_cfg
34 |         self.preprocess = dict()
35 | 
36 |     def run(self):
37 |         print(f'processing {self.src_h5_path}')
38 | 
39 |         image_cfg = self.data_cfg['image']
40 | 
41 |         for view_name, shape in image_cfg.items():
42 |             self.preprocess[view_name] = get_preprocess(min(shape[:2]))
43 | 
44 |         with h5py.File(str(self.src_h5_path), 'r') as src_file, h5py.File(str(self.tgt_h5_path), 'w') as tgt_file:
45 |             num_episodes = src_file['episodes']['length'][()]
46 | 
47 |             for episode_index in tqdm.trange(num_episodes):
48 |                 episode = src_file['episodes'][f'episode_{episode_index}']
49 |                 episode_length = int(episode['length'][()])
50 |                 tgt_episode = tgt_file.create_group(name=f'episode_{episode_index}')
51 | 
52 |                 for key in image_cfg.keys():
53 |                     src_images = episode['observation'][key][()]
54 |                     tgt_images = np.zeros(shape=(episode_length, 224, 224, 3), dtype=np.uint8)
55 | 
56 |                     for step_idx, src_image in enumerate(src_images):
57 |                         pil_img = Image.fromarray(src_image)
58 |                         tgt_image = np.array(self.preprocess[key](pil_img))
59 |                         tgt_images[step_idx] = tgt_image
60 |                     tgt_episode.create_dataset(name=key, data=tgt_images)
61 | 
62 | 
63 | if __name__ == '__main__':
64 |     src_root_dir = Path(f'{DATASET_ROOT_DIR}/rt-x_h5')
65 |     tgt_root_dir = Path(f'{DATASET_ROOT_DIR}/our_rt-x/cropped_images')
66 |     tgt_root_dir.mkdir(exist_ok=True, parents=True)
67 | 
68 |     with open('robo_ldm/configs/data_cfg.yaml', 'r') as f:
69 |         data_cfgs = yaml.safe_load(f)
70 | 
71 |     for dataset_name, cfg in data_cfgs.items():
72 |         if dataset_name[0] == '_':
73 |             continue
74 | 
75 |         src_h5_path=src_root_dir / f'{dataset_name}.hdf5'
76 |         tgt_h5_path=tgt_root_dir / f'{dataset_name}.hdf5'
77 | 
78 |         print(f'processing {dataset_name}')
79 |         dataset_converter = DatasetConverter(
80 |             dataset_name=dataset_name,
81 |             src_h5_path=src_h5_path,
82 | 
tgt_h5_path=tgt_h5_path, 83 | data_cfg=cfg 84 | ) 85 | dataset_converter.run() 86 | print(f'{dataset_name} done') 87 | -------------------------------------------------------------------------------- /example_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from omegaconf import OmegaConf 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from diffusers.optimization import get_cosine_schedule_with_warmup 7 | 8 | import lightning.pytorch as pl 9 | 10 | from train import instantiate_from_config 11 | 12 | 13 | class ExampleModel(pl.LightningModule): 14 | def __init__( 15 | self, 16 | model_kwargs, 17 | training_kwargs, 18 | all_config=None 19 | ): 20 | super().__init__() 21 | 22 | self.all_config = all_config 23 | self.training_kwargs = training_kwargs 24 | self.model_kwargs = model_kwargs 25 | self.save_hyperparameters() 26 | 27 | self.action_dim = action_dim = model_kwargs['action_dim'] 28 | self.hidden_size = hidden_size = model_kwargs['hidden_size'] 29 | self.horizon = horizon = model_kwargs['horizon'] 30 | 31 | self.action_emb = nn.Linear(action_dim, hidden_size) 32 | if model_kwargs.get('low_dim_feature_dim') is not None: 33 | self.low_dim_emb = nn.Linear(model_kwargs['low_dim_feature_dim'], hidden_size) 34 | else: 35 | self.low_dim_emb = None 36 | self.language_emb = nn.Linear(in_features=model_kwargs['language_feature_dim'], out_features=hidden_size) 37 | self.img_emb = nn.Linear(in_features=512, out_features=hidden_size) 38 | 39 | self.action_head = nn.Linear(hidden_size, action_dim) 40 | 41 | def configure_optimizers(self): 42 | kwargs = self.training_kwargs 43 | tuned_parameters = [p for p in self.parameters() if p.requires_grad] 44 | optimizer = torch.optim.Adam( 45 | tuned_parameters, 46 | lr=kwargs.lr, 47 | ) 48 | scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=kwargs.warmup_steps, num_training_steps=kwargs.num_training_steps) 49 | 50 | self.lr_scheduler = scheduler 51 | return { 52 | 'optimizer': optimizer, 53 | 'lr_scheduler': { 54 | 'scheduler': scheduler, 55 | 'interval': 'step' 56 | } 57 | } 58 | 59 | def get_img_emb(self, raw_image_features): 60 | return self.img_emb(raw_image_features) 61 | 62 | def get_language_emb(self, raw_language_features): 63 | return self.language_emb(raw_language_features) 64 | 65 | def get_low_dim_emb(self, raw_low_dim_data=None): 66 | if raw_low_dim_data is None or self.low_dim_emb is None: 67 | return None 68 | return self.low_dim_emb(raw_low_dim_data) 69 | 70 | def forward(self, batch, batch_idx, sample_posterior=True, split='train'): 71 | action = batch['action'] 72 | language_emb = self.get_language_emb(batch['language']) 73 | img_emb = self.get_img_emb(batch['image']) 74 | low_dim_emb = self.get_low_dim_emb(batch.get('low_dim')) 75 | 76 | batch_size = action.shape[0] 77 | pred_action = action 78 | 79 | loss = F.mse_loss(action, pred_action) 80 | loss_log = { 81 | f'{split}/loss': loss 82 | } 83 | return loss, loss_log 84 | 85 | def training_step(self, batch, batch_idx): 86 | self.last_training_batch = batch 87 | total_loss, log_dict = self.forward(batch=batch, batch_idx=batch_idx) 88 | self.log_dict(log_dict, sync_dist=True) 89 | return total_loss 90 | 91 | def validation_step(self, batch, batch_idx): 92 | total_loss, log_dict = self.forward(batch=batch, batch_idx=batch_idx, split='val') 93 | self.log_dict(log_dict, sync_dist=True) 94 | return total_loss 95 | 
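Note that `ExampleModel.forward` above is intentionally a stub: `pred_action = action` makes the MSE loss zero, and the embedding layers are wired up only to show where the image, language and low-dim features enter. The sketch below is one self-contained way such embeddings could be turned into an actual action prediction; the `TinyPolicy` class, the sum-based fusion and the per-step query embedding are illustrative assumptions, not the repo's design.

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyPolicy(nn.Module):
    """Illustrative only: predict a horizon of actions from fused conditioning."""
    def __init__(self, action_dim=7, hidden_size=256, horizon=16,
                 language_feature_dim=768, image_feature_dim=512):
        super().__init__()
        self.language_emb = nn.Linear(language_feature_dim, hidden_size)
        self.img_emb = nn.Linear(image_feature_dim, hidden_size)
        self.step_emb = nn.Embedding(horizon, hidden_size)  # one learned query per future step
        self.action_head = nn.Linear(hidden_size, action_dim)

    def forward(self, language, image):
        # language: (B, 1, language_feature_dim), image: (B, 1, image_feature_dim)
        cond = self.language_emb(language) + self.img_emb(image)  # (B, 1, hidden)
        steps = self.step_emb.weight.unsqueeze(0)                 # (1, horizon, hidden)
        fused = torch.tanh(cond + steps)                          # (B, horizon, hidden)
        return self.action_head(fused)                            # (B, horizon, action_dim)


# roughly the per-sample layout produced by SingleDataset.get_data, batched by the DataLoader
batch = {
    'action': torch.randn(4, 16, 7),
    'language': torch.randn(4, 1, 768),
    'image': torch.randn(4, 1, 512),
}
policy = TinyPolicy()
pred_action = policy(batch['language'], batch['image'])
loss = F.mse_loss(pred_action, batch['action'])
```

Dropping a module like this into `forward` and returning its loss is all the Lightning scaffolding above (`training_step`, `validation_step`, `configure_optimizers`) expects.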
-------------------------------------------------------------------------------- /data_preprocessing/customized_r3m.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import gdown 4 | import omegaconf 5 | import hydra 6 | from PIL import Image 7 | 8 | import torch 9 | import torchvision.transforms as T 10 | 11 | 12 | def load_r3m(modelid): 13 | VALID_ARGS = ["_target_", "device", "lr", "hidden_dim", "size", "l2weight", "l1weight", "langweight", "tcnweight", "l2dist", "bs"] 14 | def remove_language_head(state_dict): 15 | keys = state_dict.keys() 16 | ## Hardcodes to remove the language head 17 | ## Assumes downstream use is as visual representation 18 | for key in list(keys): 19 | if ("lang_enc" in key) or ("lang_rew" in key): 20 | del state_dict[key] 21 | return state_dict 22 | 23 | def cleanup_config(cfg): 24 | config = copy.deepcopy(cfg) 25 | keys = config.agent.keys() 26 | for key in list(keys): 27 | if key not in VALID_ARGS: 28 | del config.agent[key] 29 | config.agent["_target_"] = "r3m.R3M" 30 | config["device"] = 'cpu' 31 | 32 | ## Hardcodes to remove the language head 33 | ## Assumes downstream use is as visual representation 34 | config.agent["langweight"] = 0 35 | return config.agent 36 | 37 | home = os.path.join(os.path.expanduser("~"), ".r3m") 38 | if modelid == "resnet50": 39 | foldername = "r3m_50" 40 | modelurl = 'https://drive.google.com/uc?id=1Xu0ssuG0N1zjZS54wmWzJ7-nb0-7XzbA' 41 | configurl = 'https://drive.google.com/uc?id=10jY2VxrrhfOdNPmsFdES568hjjIoBJx8' 42 | elif modelid == "resnet34": 43 | foldername = "r3m_34" 44 | modelurl = 'https://drive.google.com/uc?id=15bXD3QRhspIRacOKyWPw5y2HpoWUCEnE' 45 | configurl = 'https://drive.google.com/uc?id=1RY0NS-Tl4G7M1Ik_lOym0b5VIBxX9dqW' 46 | elif modelid == "resnet18": 47 | foldername = "r3m_18" 48 | modelurl = 'https://drive.google.com/uc?id=1A1ic-p4KtYlKXdXHcV2QV0cUzI4kn0u-' 49 | configurl = 'https://drive.google.com/uc?id=1nitbHQ-GRorxc7vMUiEHjHWP5N11Jvc6' 50 | else: 51 | raise NameError('Invalid Model ID') 52 | 53 | if not os.path.exists(os.path.join(home, foldername)): 54 | os.makedirs(os.path.join(home, foldername)) 55 | modelpath = os.path.join(home, foldername, "model.pt") 56 | configpath = os.path.join(home, foldername, "config.yaml") 57 | if not os.path.exists(modelpath): 58 | gdown.download(modelurl, modelpath, quiet=False) 59 | gdown.download(configurl, configpath, quiet=False) 60 | 61 | modelcfg = omegaconf.OmegaConf.load(configpath) 62 | cleancfg = cleanup_config(modelcfg) 63 | rep = hydra.utils.instantiate(cleancfg) 64 | r3m_state_dict = remove_language_head(torch.load(modelpath, map_location=torch.device('cpu'))['r3m']) 65 | filtered_state_dict = {} 66 | for key, value in r3m_state_dict.items(): 67 | if key.startswith("module"): 68 | new_key = key.replace("module.", "") 69 | filtered_state_dict[new_key] = value 70 | rep.load_state_dict(filtered_state_dict) 71 | return rep 72 | 73 | 74 | class R3M(torch.nn.Module): 75 | def __init__( 76 | self, 77 | model_id: str 78 | ): 79 | super().__init__() 80 | r3m = load_r3m(model_id) 81 | 82 | self.model = r3m.convnet 83 | self.normlayer = r3m.normlayer 84 | self.feature_size = r3m.outdim 85 | 86 | self.preprocess = T.Compose([ 87 | T.ToTensor(), 88 | self.normlayer 89 | ]) 90 | 91 | def raw_preprocess(self, image: Image): 92 | # depreciated 93 | shorter_edge = min(image.size) 94 | process = T.Compose([ 95 | T.CenterCrop(shorter_edge), 96 | T.Resize(224), 97 | T.ToTensor(), 98 | self.normlayer 99 | ]) 100 
| return process(image) 101 | 102 | def forward(self, images): 103 | return self.model(images) 104 | -------------------------------------------------------------------------------- /data_preprocessing/normalize_actions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 4 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 5 | import numpy as np 6 | import h5py 7 | import yaml 8 | from pathlib import Path 9 | import tqdm 10 | import pickle 11 | import torch 12 | from concurrent.futures import ProcessPoolExecutor as PPE 13 | from data_preprocessing.utils import uniform_normalization, scale_only_normalization 14 | 15 | 16 | DATASET_ROOT_DIR = '' 17 | 18 | 19 | class DatasetConverter: 20 | def __init__( 21 | self, 22 | dataset_name: str, 23 | src_h5_path: Path, 24 | normalized_pkl_path: Path, 25 | raw_pkl_path: Path, 26 | data_cfg: dict 27 | ): 28 | self.dataset_name = dataset_name 29 | self.src_h5_path = src_h5_path 30 | self.normalized_pkl_path = normalized_pkl_path 31 | self.raw_pkl_path = raw_pkl_path 32 | self.data_cfg = data_cfg 33 | 34 | def run(self): 35 | print(f'processing {self.src_h5_path}') 36 | 37 | with h5py.File(str(self.src_h5_path), 'r') as src_file: 38 | num_episodes = src_file['episodes']['length'][()] 39 | raw_episodes = [] 40 | normalized_episodes = [] 41 | 42 | action_cfg = self.data_cfg['action'] 43 | action_outer_key = action_cfg['outer_key'] 44 | action_inner_keys = action_cfg['inner_keys'] 45 | index_mapping = action_cfg['index_mapping'] 46 | mins = np.array(action_cfg['min'], dtype=np.float32) 47 | maxs = np.array(action_cfg['max'], dtype=np.float32) 48 | 49 | for episode_index in tqdm.trange(num_episodes): 50 | episode = src_file['episodes'][f'episode_{episode_index}'] 51 | 52 | length = int(episode['length'][()]) 53 | concated_data = \ 54 | [np.zeros(shape=(length, 1), dtype=np.float32)] +\ 55 | [episode[action_outer_key][a_inner_key][()] for a_inner_key in action_inner_keys] 56 | for data_idx, d in enumerate(concated_data): 57 | if len(d.shape) == 1: 58 | concated_data[data_idx] = np.expand_dims(d, axis=1) 59 | elif d.dtype == bool: 60 | concated_data[data_idx] = np.array(d, dtype=np.float32).reshape(-1, 1) 61 | concated_data = np.concatenate(concated_data, axis=-1) 62 | 63 | arranged_action = np.zeros(shape=(length, 7), dtype=np.float32) 64 | for tgt_idx, src_idx in enumerate(index_mapping): 65 | arranged_action[:, tgt_idx] = concated_data[:, src_idx] 66 | 67 | raw_episodes.append(torch.from_numpy(arranged_action).to(dtype=torch.float32)) 68 | 69 | normalized_action = scale_only_normalization(arranged_action, min_values=mins, max_values=maxs) 70 | if not action_cfg['gripper_close_is_positive']: 71 | normalized_action[:, -1] = -normalized_action[:, -1] 72 | normalized_episodes.append(torch.from_numpy(normalized_action)) 73 | 74 | with self.raw_pkl_path.open('wb') as f: 75 | pickle.dump(raw_episodes, f) 76 | with self.normalized_pkl_path.open('wb') as f: 77 | pickle.dump(normalized_episodes, f) 78 | 79 | 80 | def process_dataset(dataset_name, src_h5_path, normalized_pkl_path, raw_pkl_path, data_cfg): 81 | dataset_converter = DatasetConverter( 82 | dataset_name=dataset_name, 83 | src_h5_path=src_h5_path, 84 | normalized_pkl_path=normalized_pkl_path, 85 | raw_pkl_path=raw_pkl_path, 86 | data_cfg=data_cfg 87 | ) 88 | try: 89 | dataset_converter.run() 90 | except Exception as e: 91 | print(f'{dataset_name} error as:') 92 | print(e) 93 | 
print('----------------------------------------------') 94 | else: 95 | print(f'{dataset_name} done') 96 | 97 | 98 | if __name__ == '__main__': 99 | src_root_dir = Path(f'{DATASET_ROOT_DIR}/rt-x_h5') 100 | raw_root_dir = Path(f'{DATASET_ROOT_DIR}/rt-x_pt/raw_actions') 101 | normalized_root_dir = Path(f'{DATASET_ROOT_DIR}/rt-x_pt/normalized_actions') 102 | raw_root_dir.mkdir(exist_ok=True, parents=True) 103 | normalized_root_dir.mkdir(exist_ok=True, parents=True) 104 | 105 | with open('./data_preprocessing/rt-x_data_cfg.yaml', 'r') as f: 106 | data_cfgs = yaml.safe_load(f)['datasets'] 107 | 108 | dataset_name_list = [n for n in data_cfgs.keys() if n[0] != '_'] 109 | dataset_cfg_list = [data_cfgs[n] for n in dataset_name_list] 110 | src_h5_path_list = [src_root_dir / f'{n}.hdf5' for n in dataset_name_list] 111 | normalized_pkl_path_list = [normalized_root_dir / f'{n}.pkl' for n in dataset_name_list] 112 | raw_pkl_path_list = [raw_root_dir / f'{n}.pkl' for n in dataset_name_list] 113 | 114 | with PPE() as ppe: 115 | list(ppe.map(process_dataset, dataset_name_list, src_h5_path_list, normalized_pkl_path_list, raw_pkl_path_list, dataset_cfg_list)) 116 | -------------------------------------------------------------------------------- /data_preprocessing/move_h5_image_to_png.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | import yaml 5 | from pathlib import Path 6 | from PIL import Image 7 | import time 8 | from concurrent.futures import ProcessPoolExecutor 9 | 10 | 11 | DATASET_ROOT_DIR = '' 12 | 13 | 14 | def process_episode(dataset_dir: Path, h5_path, episode_path, episode_index, obs_keys): 15 | with h5py.File(h5_path, 'r') as f: 16 | episodes = f[episode_path] 17 | episode_dir = dataset_dir / f'episode_{episode_index}' 18 | episode_dir.mkdir(exist_ok=True) 19 | obs_data = episodes[f'episode_{episode_index}']['observation'] 20 | num_steps = episodes[f'episode_{episode_index}']['length'][()] 21 | 22 | for obs_key in obs_keys: 23 | obs_dir = episode_dir / obs_key 24 | obs_dir.mkdir(exist_ok=True) 25 | images = obs_data[obs_key][()] 26 | for step_index in range(num_steps): 27 | np_image = images[step_index] 28 | if 'depth' in obs_key: 29 | np_image = (np_image - np_image.min()) / (np_image.max() - np_image.min()) * 65535 30 | np_image = np.squeeze(np_image.astype(np.uint16)) 31 | image = Image.fromarray(np_image) 32 | image.save(str(obs_dir / f'{step_index}.png')) 33 | 34 | 35 | class DatasetConverter: 36 | def __init__( 37 | self, 38 | dataset_name: str, 39 | src_h5_path: Path, 40 | temp_h5_path: Path, 41 | tgt_png_dir: Path): 42 | 43 | self.dataset_name = dataset_name 44 | 45 | self.src_h5_path = src_h5_path 46 | self.temp_h5_path = temp_h5_path 47 | self.tgt_png_dir = tgt_png_dir 48 | self.tgt_png_dir.mkdir(parents=True, exist_ok=True) 49 | 50 | def run(self): 51 | print(f'processing {self.src_h5_path}') 52 | with h5py.File(str(self.src_h5_path), 'r') as f: 53 | obs_info = f['shape_type_info']['observation'] 54 | 55 | obs_keys = [] 56 | for key in obs_info.keys(): 57 | key_shape = obs_info[key].get('shape') 58 | if key_shape is not None and key_shape[()].shape[0] == 3 and 'flow' not in key: 59 | obs_keys.append(key) 60 | 61 | num_episodes = f['episodes']['length'][()] 62 | episode_indices = list(range(num_episodes)) 63 | print(f'num of episodes: {num_episodes}') 64 | begin_time = time.perf_counter() 65 | 66 | with ProcessPoolExecutor() as ppe: 67 | list(ppe.map(process_episode, 
[self.tgt_png_dir]*num_episodes, [self.src_h5_path]*num_episodes, ['episodes']*num_episodes, episode_indices, [obs_keys] * num_episodes)) 68 | 69 | print(f'data saved at {self.tgt_png_dir}, took {time.perf_counter() - begin_time} seconds') 70 | 71 | def copy_group(self, src_group, dst_group): 72 | for key in src_group: 73 | src_obj = src_group[key] 74 | if isinstance(src_obj, h5py.Dataset): 75 | if key not in self.obs_keys: 76 | src_group.copy(key, dst_group, key) 77 | elif isinstance(src_obj, h5py.Group): 78 | dst_sub_group = dst_group.create_group(key) 79 | self.copy_group(src_obj, dst_sub_group) 80 | 81 | def reclaiming(self): 82 | with h5py.File(str(self.src_h5_path), 'r') as f: 83 | obs_info = f['shape_type_info']['observation'] 84 | 85 | obs_keys = [] 86 | for key in obs_info.keys(): 87 | key_shape = obs_info[key].get('shape') 88 | if key_shape is not None and key_shape[()].shape[0] == 3 and 'flow' not in key: 89 | obs_keys.append(key) 90 | self.obs_keys = obs_keys 91 | 92 | with h5py.File(str(self.src_h5_path), "r") as src_file, h5py.File(str(self.temp_h5_path), "w") as dst_file: 93 | self.copy_group(src_file, dst_file) 94 | 95 | 96 | if __name__ == '__main__': 97 | src_root_dir = Path(f'{DATASET_ROOT_DIR}/rt-x_h5') 98 | temp_root_dir = Path(f'{DATASET_ROOT_DIR}/_rt-x_h5') 99 | temp_root_dir.mkdir(exist_ok=True) 100 | tgt_root_dir = Path(f'{DATASET_ROOT_DIR}/rt-x_png') 101 | tgt_root_dir.mkdir(exist_ok=True) 102 | 103 | with open('./data_preprocessing/dataset_list.yaml', 'r') as f: 104 | dataset_list = yaml.safe_load(f) 105 | dataset_list = dataset_list['large'] + dataset_list['small'] 106 | 107 | for dataset_name in dataset_list: 108 | src_h5_path=src_root_dir / f'{dataset_name}.hdf5' 109 | temp_h5_path=temp_root_dir / f'{dataset_name}.hdf5' 110 | tgt_png_dir=tgt_root_dir / dataset_name 111 | 112 | if temp_h5_path.exists(): 113 | continue 114 | dataset_converter = DatasetConverter( 115 | dataset_name=dataset_name, 116 | src_h5_path=src_h5_path, 117 | temp_h5_path=temp_h5_path, 118 | tgt_png_dir=tgt_png_dir 119 | ) 120 | dataset_converter.run() 121 | dataset_converter.reclaiming() 122 | # os.remove(str(dataset_converter.src_h5_path)) 123 | print(f'{dataset_name} done') 124 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from omegaconf import OmegaConf 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | import lightning.pytorch as pl 7 | from lightning.pytorch.trainer import Trainer 8 | 9 | import math 10 | import importlib 11 | from datetime import datetime 12 | from omegaconf.dictconfig import DictConfig 13 | 14 | 15 | def get_timestamp(): 16 | return datetime.now().strftime('%Y%m%d-%H%M%S') 17 | 18 | 19 | def get_obj_from_str(string, reload=False): 20 | module, cls = string.rsplit(".", 1) 21 | if reload: 22 | module_imp = importlib.import_module(module) 23 | importlib.reload(module_imp) 24 | return getattr(importlib.import_module(module, package=None), cls) 25 | 26 | 27 | def instantiate_from_config(config, extra_kwargs=dict()): 28 | config_dict = dict(config) 29 | if not "target" in config_dict: 30 | if config_dict == '__is_first_stage__': 31 | return None 32 | elif config_dict == "__is_unconditional__": 33 | return None 34 | raise KeyError("Expected key `target` to instantiate.") 35 | target_kwargs = dict(config_dict.get('kwargs', dict())) 36 | 37 | for k, v in target_kwargs.items(): 38 | if isinstance(v, 
DictConfig) and 'target' in v.keys(): 39 | target_kwargs[k] = instantiate_from_config(v) 40 | target_kwargs.update(extra_kwargs) 41 | return get_obj_from_str(config_dict["target"])(**target_kwargs) 42 | 43 | 44 | def get_train_val_loader(dataset, **dataloader_kwargs): 45 | train_ds, val_ds = dataset.split_train_val(train_ratio=0.95) 46 | train_loader = DataLoader(dataset=train_ds, **dataloader_kwargs, shuffle=True) 47 | val_loader = DataLoader(dataset=val_ds, **dataloader_kwargs, shuffle=False) 48 | return train_loader, val_loader 49 | 50 | 51 | def preprocess_config(config, args): 52 | # set timestamp 53 | task = args.task 54 | project_name = config.model.target.split('.')[-2] + '_logs' 55 | config.trainer.kwargs.logger.kwargs.project = project_name 56 | config.trainer.kwargs.logger.kwargs.name = f'{get_timestamp()}-{task}' 57 | 58 | # overriding horizon 59 | config.horizon = args.horizon 60 | config.model.kwargs.model_kwargs.horizon = args.horizon 61 | config.dataset.kwargs.horizon = args.horizon 62 | 63 | # devices 64 | devices = args.devices 65 | if devices is not None: 66 | devices = devices.split(',') 67 | devices = [int(rank) for rank in devices] 68 | config.trainer.kwargs.devices = devices 69 | 70 | # avoid gpu rank overflow 71 | device_count = torch.cuda.device_count() 72 | if len(config.trainer.kwargs.devices) > device_count: 73 | config.trainer.kwargs.devices = list(range(device_count)) 74 | print(f'using {device_count} devices') 75 | 76 | # batch size for ddp 77 | total_bs = config.dataloader.batch_size 78 | num_devices = len(config.trainer.kwargs.devices) 79 | bs_per_device = total_bs // num_devices 80 | real_bs = bs_per_device * num_devices 81 | if real_bs != total_bs: 82 | print(f'real batch size is {real_bs}') 83 | config.dataloader.batch_size = bs_per_device 84 | 85 | # dataset/tasks/mode 86 | data_cfg = OmegaConf.load(f'{task}_data_cfg.yaml') 87 | 88 | datasets_cfg = data_cfg.datasets 89 | config.dataset.kwargs.root_dir = f'YOUR_DATASET_ROOT_DIR_HERE/{task}_processed' 90 | config.dataset.kwargs.data_cfg = datasets_cfg 91 | config.dataset.kwargs.dataset_names = [key for key in datasets_cfg.keys() if key[0] != '_'] 92 | config.dataset.kwargs.average_step_per_episode = data_cfg.average_step_per_episode 93 | 94 | # feature dimension: 95 | if config.dataset.kwargs.feature_type[:3] == 'r3m': 96 | config.model.kwargs.model_kwargs.language_feature_dim = 768 97 | else: # clip 98 | config.model.kwargs.model_kwargs.language_feature_dim = 512 99 | 100 | return config 101 | 102 | 103 | def get_parser_args(): 104 | parser = argparse.ArgumentParser() 105 | 106 | parser.add_argument( 107 | '--config_name', 108 | default='example_cfg' 109 | ) 110 | parser.add_argument( 111 | '--task', 112 | default='rt-x' 113 | ) 114 | parser.add_argument( 115 | '--devices', 116 | type=str, 117 | default='0', 118 | ) 119 | parser.add_argument( 120 | '--horizon', 121 | type=int, 122 | default=16 123 | ) 124 | 125 | return parser.parse_args() 126 | 127 | 128 | def main(): 129 | args = get_parser_args() 130 | 131 | raw_config = OmegaConf.load(f'{args.config_name}.yaml') 132 | OmegaConf.resolve(raw_config) 133 | config = preprocess_config(raw_config, args) 134 | 135 | pl.seed_everything(config.seed) 136 | 137 | model: pl.LightningModule = instantiate_from_config(config.model, extra_kwargs={"all_config": config}) 138 | 139 | dataset = instantiate_from_config(config.dataset) 140 | train_loader, val_loader = get_train_val_loader(dataset=dataset, **config.dataloader) 141 | 142 | epoch_length = len(train_loader) 
// len(config.trainer.kwargs.devices) 143 | config.model.kwargs.training_kwargs['num_training_steps'] = epoch_length * config.trainer.kwargs.max_epochs 144 | 145 | trainer: Trainer = instantiate_from_config(config.trainer) 146 | trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader) 147 | 148 | 149 | if __name__ == '__main__': 150 | main() 151 | -------------------------------------------------------------------------------- /data_preprocessing/extract_image_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import yaml 4 | from pathlib import Path 5 | import tqdm 6 | import torch 7 | import torchvision.transforms as T 8 | from PIL import Image 9 | import pickle 10 | 11 | import sys 12 | import os 13 | 14 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 15 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 16 | 17 | from data_preprocessing.customized_r3m import R3M 18 | import clip 19 | 20 | 21 | DATASET_ROOT_DIR = '' 22 | DEVICE_A = 'cuda:0' 23 | DEVICE_B = 'cuda:1' 24 | 25 | 26 | def get_r3m_preprocess(shorter_edge): 27 | return T.Compose([ 28 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 29 | ]) 30 | 31 | 32 | def get_clip_preprocess(shorter_edge): 33 | return T.Compose([ 34 | T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 35 | ]) 36 | 37 | 38 | def get_preprocess(shorter_edge): 39 | return T.Compose([ 40 | T.CenterCrop(shorter_edge), 41 | T.Resize(224), 42 | T.ToTensor(), 43 | ]) 44 | 45 | 46 | class DatasetConverter: 47 | def __init__( 48 | self, 49 | dataset_name: str, 50 | src_h5_path: Path, 51 | r3m_file_path: Path, 52 | clip_file_path: Path, 53 | data_cfg: dict, 54 | r3m_model_id: str, 55 | clip_model_id: str, 56 | ): 57 | self.dataset_name = dataset_name 58 | self.src_h5_path = src_h5_path 59 | self.r3m_file_path = r3m_file_path 60 | self.clip_file_path = clip_file_path 61 | self.data_cfg = data_cfg 62 | self.r3m_preprocess = dict() 63 | self.clip_preprocess = dict() 64 | self.preprocess = dict() 65 | 66 | self.r3m = R3M(r3m_model_id).to(DEVICE_A).eval() 67 | self.clip, _ = clip.load(clip_model_id, device=DEVICE_B) 68 | self.clip = self.clip.eval() 69 | 70 | def run(self): 71 | print(f'processing {self.src_h5_path}') 72 | 73 | image_cfg = self.data_cfg['image'] 74 | 75 | for view_name, shape in image_cfg.items(): 76 | self.r3m_preprocess[view_name] = get_r3m_preprocess(min(shape[:2])) 77 | self.clip_preprocess[view_name] = get_clip_preprocess(min(shape[:2])) 78 | self.preprocess[view_name] = get_preprocess(min(shape[:2])) 79 | 80 | with h5py.File(str(self.src_h5_path), 'r') as src_file: 81 | num_episodes = int(src_file['episodes']['length'][()]) 82 | keys = [k for k,v in image_cfg.items() if len(v) == 3 and v[-1] == 3] # shape be like [h, w, 3] 83 | 84 | r3m_episodes = [] 85 | clip_episodes = [] 86 | for episode_index in tqdm.trange(num_episodes): 87 | episode = src_file['episodes'][f'episode_{episode_index}'] 88 | r3m_episode = dict() 89 | clip_episode = dict() 90 | 91 | for key in keys: 92 | src_images = episode['observation'][key][()] 93 | r3m_tensors = [] 94 | clip_tensors = [] 95 | 96 | for src_image in src_images: 97 | pil_img = Image.fromarray(src_image) 98 | 99 | preprocessed = self.preprocess[key](pil_img) 100 | r3m_image = self.r3m_preprocess[key](preprocessed) 101 | r3m_tensors.append(r3m_image) 102 | clip_image = self.clip_preprocess[key](preprocessed) 103 | 
clip_tensors.append(clip_image) 104 | 105 | r3m_tensors = torch.stack(r3m_tensors, dim=0).to(DEVICE_A) 106 | clip_tensors = torch.stack(clip_tensors, dim=0).to(DEVICE_B) 107 | 108 | r3m_features = self.r3m(r3m_tensors).cpu() 109 | clip_features = self.clip.encode_image(clip_tensors).cpu() 110 | 111 | r3m_episode[key] = r3m_features 112 | clip_episode[key] = clip_features 113 | 114 | r3m_episodes.append(r3m_episode) 115 | clip_episodes.append(clip_episode) 116 | 117 | with self.r3m_file_path.open('wb') as r3m_f, self.clip_file_path.open('wb') as clip_f: 118 | pickle.dump(r3m_episodes, r3m_f) 119 | pickle.dump(clip_episodes, clip_f) 120 | 121 | 122 | if __name__ == '__main__': 123 | r3m_model_id = 'resnet34' 124 | 125 | src_root_dir = Path(f'{DATASET_ROOT_DIR}/rt-x_h5') 126 | r3m_root_dir = Path(f'{DATASET_ROOT_DIR}/our_rt-x/r3m_{r3m_model_id}_image') 127 | r3m_root_dir.mkdir(exist_ok=True, parents=True) 128 | 129 | clip_model_id = 'ViT-B/32' 130 | clip_root_dir = Path(f'{DATASET_ROOT_DIR}/our_rt-x/clip_{clip_model_id.replace("/", "")}_image') 131 | clip_root_dir.mkdir(exist_ok=True, parents=True) 132 | 133 | with open('dataset_preprocessing/data_cfg.yaml', 'r') as f: 134 | data_cfgs = yaml.safe_load(f) 135 | 136 | for dataset_name, cfg in data_cfgs.items(): 137 | if dataset_name[0] == '_': 138 | continue 139 | 140 | src_h5_path=src_root_dir / f'{dataset_name}.hdf5' 141 | r3m_file_path=r3m_root_dir / f'{dataset_name}.pkl' 142 | clip_file_path=clip_root_dir / f'{dataset_name}.pkl' 143 | 144 | if r3m_file_path.exists() and clip_file_path.exists(): 145 | continue 146 | 147 | print(f'processing {dataset_name}') 148 | dataset_converter = DatasetConverter( 149 | dataset_name = dataset_name, 150 | src_h5_path = src_h5_path, 151 | r3m_file_path = r3m_file_path, 152 | clip_file_path = clip_file_path, 153 | data_cfg = cfg, 154 | r3m_model_id = r3m_model_id, 155 | clip_model_id = clip_model_id 156 | ) 157 | try: 158 | with torch.no_grad(): 159 | dataset_converter.run() 160 | except Exception as e: 161 | print(f'{dataset_name} error as:') 162 | print(e) 163 | print('----------------------------------------------') 164 | else: 165 | print(f'{dataset_name} done') 166 | torch.cuda.empty_cache() 167 | -------------------------------------------------------------------------------- /data_preprocessing/extract_language_features.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | import h5py 5 | import yaml 6 | from pathlib import Path 7 | import tqdm 8 | import pickle 9 | from transformers import AutoTokenizer, AutoModel, AutoConfig 10 | import torch 11 | 12 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 14 | 15 | import clip 16 | from data_preprocessing.customized_r3m import r3m 17 | 18 | 19 | DATASET_ROOT_DIR = '' 20 | DEVICE_A = 'cuda:0' 21 | DEVICE_B = 'cuda:1' 22 | 23 | 24 | def are_elements_same_along_first_dimension(data): 25 | all_same = np.all(np.equal(data[:-1], data[1:]), axis=0) 26 | return np.all(all_same) 27 | 28 | 29 | class DatasetConverter: 30 | def __init__( 31 | self, 32 | dataset_name: str, 33 | src_h5_path: Path, 34 | r3m_file_path: Path, 35 | clip_file_path: Path, 36 | data_cfg: dict, 37 | clip_model_id: str, 38 | ): 39 | self.dataset_name = dataset_name 40 | self.src_h5_path = src_h5_path 41 | self.r3m_file_path = r3m_file_path 42 | self.clip_file_path = clip_file_path 43 | self.data_cfg = data_cfg 44 | 45 | self.r3m_tokenizer = 
AutoTokenizer.from_pretrained("distilbert-base-uncased") 46 | self.r3m_model = AutoModel.from_pretrained("distilbert-base-uncased").to(DEVICE_A) 47 | 48 | self.clip, _ = clip.load(clip_model_id, device=DEVICE_B) 49 | self.clip = self.clip.eval() 50 | 51 | def extract_r3m_features(self, langs): 52 | with torch.no_grad(): 53 | encoded_input = self.r3m_tokenizer(langs, return_tensors='pt', padding=True) 54 | input_ids = encoded_input['input_ids'].to(DEVICE_A) 55 | attention_mask = encoded_input['attention_mask'].to(DEVICE_A) 56 | lang_embedding = self.r3m_model(input_ids, attention_mask=attention_mask).last_hidden_state 57 | lang_embedding = lang_embedding.mean(1) 58 | return lang_embedding.cpu() 59 | 60 | def extract_clip_features(self, langs): 61 | with torch.no_grad(): 62 | text = clip.tokenize(langs).to(DEVICE_B) 63 | text_features = self.clip.encode_text(text) 64 | return text_features.cpu() 65 | 66 | def all_the_same(self): 67 | outer_key = self.data_cfg['language']['outer_key'] 68 | inner_key = self.data_cfg['language']['inner_key'] 69 | 70 | with h5py.File(str(self.src_h5_path), 'r') as src_file: 71 | num_episodes = src_file['episodes']['length'][()] 72 | 73 | for episode_index in tqdm.trange(num_episodes): 74 | episode = src_file['episodes'][f'episode_{episode_index}'] 75 | if inner_key is None: 76 | episode = episode[outer_key][::5] # skip some frames to reduce computation 77 | else: 78 | episode = episode[outer_key][inner_key][::10] 79 | sameness = are_elements_same_along_first_dimension(episode) 80 | if not sameness: 81 | print(f'sentences in episode_{episode_index} are not all the same') 82 | return False 83 | return True 84 | 85 | def run(self): 86 | if self.all_the_same(): 87 | self.run_all_same() 88 | 89 | def run_step_by_step(self): 90 | print(f'processing {self.src_h5_path}') 91 | 92 | outer_key = self.data_cfg['language']['outer_key'] 93 | inner_key = self.data_cfg['language']['inner_key'] 94 | 95 | with h5py.File(str(self.src_h5_path), 'r') as src_file: 96 | num_episodes = src_file['episodes']['length'][()] 97 | episodes = [] 98 | 99 | for episode_index in tqdm.trange(num_episodes): 100 | episode = src_file['episodes'][f'episode_{episode_index}'] 101 | if inner_key is None: 102 | episode = episode[outer_key][()] 103 | else: 104 | episode = episode[outer_key][inner_key][()] 105 | episodes.append([s.decode() for s in episode]) 106 | 107 | r3m_data = [] 108 | clip_data = [] 109 | for langs in episodes: 110 | r3m_features = self.extract_r3m_features(langs) 111 | clip_fetures = self.extract_clip_features(langs) 112 | r3m_data.append(r3m_features) 113 | clip_data.append(clip_fetures) 114 | 115 | with self.r3m_file_path.open('wb') as f: 116 | pickle.dump(r3m_data, f) 117 | with self.clip_file_path.open('wb') as f: 118 | pickle.dump(clip_data, f) 119 | 120 | def run_all_same(self): 121 | print(f'processing {self.src_h5_path}') 122 | 123 | outer_key = self.data_cfg['language']['outer_key'] 124 | inner_key = self.data_cfg['language']['inner_key'] 125 | 126 | with h5py.File(str(self.src_h5_path), 'r') as src_file: 127 | num_episodes = src_file['episodes']['length'][()] 128 | sentences = [] 129 | 130 | for episode_index in tqdm.trange(num_episodes): 131 | episode = src_file['episodes'][f'episode_{episode_index}'] 132 | if inner_key is None: 133 | episode = episode[outer_key][0] 134 | else: 135 | episode = episode[outer_key][inner_key][0] 136 | sentences.append(episode.decode()) 137 | 138 | begin_idx = 0 139 | batch_size = 2048 140 | r3m_data = [] 141 | clip_data = [] 142 | while True: 
143 | end_idx = min(begin_idx + batch_size, num_episodes) 144 | 145 | r3m_features = self.extract_r3m_features(sentences[begin_idx: end_idx]) 146 | clip_fetures = self.extract_clip_features(sentences[begin_idx: end_idx]) 147 | r3m_data.append(r3m_features) 148 | clip_data.append(clip_fetures) 149 | 150 | begin_idx = end_idx 151 | if begin_idx >= num_episodes: 152 | break 153 | print(f'{end_idx}', end=' ') 154 | 155 | r3m_data = torch.cat(r3m_data, dim=0) 156 | clip_data = torch.cat(clip_data, dim=0) 157 | 158 | with self.r3m_file_path.open('wb') as f: 159 | pickle.dump(r3m_data, f) 160 | with self.clip_file_path.open('wb') as f: 161 | pickle.dump(clip_data, f) 162 | 163 | 164 | if __name__ == '__main__': 165 | src_root_dir = Path(f'{DATASET_ROOT_DIR}/rt-x_h5') 166 | 167 | clip_model_id = 'ViT-B/32' 168 | 169 | r3m_root_dir = Path(f'{DATASET_ROOT_DIR}/our_rt-x/distilbert_language') 170 | clip_root_dir = Path(f'{DATASET_ROOT_DIR}/our_rt-x/clip_{clip_model_id.replace("/", "")}_language') 171 | 172 | r3m_root_dir.mkdir(exist_ok=True, parents=True) 173 | clip_root_dir.mkdir(exist_ok=True, parents=True) 174 | 175 | with open('robo_ldm/configs/data_cfg.yaml', 'r') as f: 176 | data_cfgs = yaml.safe_load(f) 177 | 178 | for dataset_name, cfg in data_cfgs.items(): 179 | if dataset_name[0] == '_': 180 | continue 181 | 182 | src_h5_path=src_root_dir / f'{dataset_name}.hdf5' 183 | r3m_file_path=r3m_root_dir / f'{dataset_name}.pkl' 184 | clip_file_path=clip_root_dir / f'{dataset_name}.pkl' 185 | if r3m_file_path.exists() and clip_file_path.exists(): 186 | continue 187 | print(f'processing {dataset_name}') 188 | dataset_converter = DatasetConverter( 189 | dataset_name=dataset_name, 190 | src_h5_path=src_h5_path, 191 | r3m_file_path=r3m_file_path, 192 | clip_file_path=clip_file_path, 193 | data_cfg=cfg, 194 | clip_model_id=clip_model_id 195 | ) 196 | try: 197 | with torch.no_grad(): 198 | dataset_converter.run() 199 | except Exception as e: 200 | print(f'{dataset_name} error as:') 201 | print(e) 202 | print('----------------------------------------------') 203 | else: 204 | print(f'{dataset_name} done') 205 | -------------------------------------------------------------------------------- /data_preprocessing/check_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "DATASET_ROOT_RID = 'YOUR_DIR_HERE'" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import h5py\n", 19 | "from IPython.display import Image, display\n", 20 | "import imageio\n", 21 | "import numpy as np \n", 22 | "import matplotlib.pyplot as plt \n", 23 | "import yaml\n", 24 | "\n", 25 | "\n", 26 | "def display_gif_from_images(images, fps=10, filename='temp.gif'): \n", 27 | " imageio.mimsave(filename, images, format='GIF', fps=fps) \n", 28 | " display(Image(filename=filename))\n", 29 | "\n", 30 | "\n", 31 | "def print_group_structure(group, indent=''): \n", 32 | " for key in group.keys(): \n", 33 | " item = group[key]\n", 34 | " if isinstance(item, h5py.Dataset): \n", 35 | " if key == 'shape':\n", 36 | " print(f\"{indent}- {key}: {item[()]}\") \n", 37 | " else:\n", 38 | " print(f\"{indent}- {key}: {item.shape}\") \n", 39 | "\n", 40 | " elif isinstance(item, h5py.Group):\n", 41 | " print(f\"{indent}- {key}: Group\")\n", 42 | " print_group_structure(item, indent + ' ')\n", 43 
| "def accumulate_delta_actions(actions):\n", 44 | " for i in range(1, len(actions)):\n", 45 | " actions[i] += actions[i-1]\n", 46 | "\n", 47 | " \n", 48 | "def plot_3d_points_with_gradient_colors(points): \n", 49 | " x = points[:, 0] \n", 50 | " y = points[:, 1] \n", 51 | " z = points[:, 2] \n", 52 | "\n", 53 | " num_points = len(points)\n", 54 | " colors = plt.cm.viridis(np.arange(num_points) / (num_points - 1))\n", 55 | "\n", 56 | " fig = plt.figure() \n", 57 | " ax = fig.add_subplot(111, projection='3d') \n", 58 | " ax.scatter(x, y, z, c=colors) \n", 59 | "\n", 60 | " ax.set_xlabel('X-axis') \n", 61 | " ax.set_ylabel('Y-axis') \n", 62 | " ax.set_zlabel('Z-axis') \n", 63 | " plt.show() \n", 64 | "\n", 65 | "def get_statistics(all_actions: np.array):\n", 66 | " length = all_actions.shape[-1]\n", 67 | " def format_statistic(data):\n", 68 | " s = '['\n", 69 | " for i in range(length):\n", 70 | " s += str(data[i])\n", 71 | " if i != length-1:\n", 72 | " s += ','\n", 73 | " s += ']'\n", 74 | " return s\n", 75 | " print(f'min: {format_statistic(all_actions.min(axis=0))}')\n", 76 | " print(f'max: {format_statistic(all_actions.max(axis=0))}')\n", 77 | " print(f'mean: {format_statistic(all_actions.mean(axis=0))}')\n", 78 | " print(f'std: {format_statistic(all_actions.std(axis=0))}')" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "ds_cfgs = yaml.safe_load(open('./data_preprocessing/data_cfg.yaml', 'r'))\n", 88 | "\n", 89 | "def ds_cfg_generator():\n", 90 | " for k,v in ds_cfgs.items():\n", 91 | " if k[0] != '_':\n", 92 | " yield (k, v)\n", 93 | "\n", 94 | "ds_cfg_iter = ds_cfg_generator()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "dataset_name, cfg = next(ds_cfg_iter)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "\n", 113 | "f_path = f'/{DATASET_ROOT_RID}/{dataset_name}.hdf5'\n", 114 | "outer_key = cfg['action']['outer_key']\n", 115 | "inner_keys = cfg['action']['inner_keys']\n", 116 | "f = h5py.File(f_path, 'r')\n", 117 | "image_key = list(cfg['image'].keys())[0]\n", 118 | "\n", 119 | "episode = f['episodes/episode_1/']\n", 120 | "actions = []\n", 121 | "for inner_key in inner_keys:\n", 122 | " v = episode[outer_key][inner_key][()]\n", 123 | " if v.dtype == bool or len(v.shape) == 1:\n", 124 | " v = np.array(v, dtype=np.float32).reshape(-1, 1)\n", 125 | " actions.append(v)\n", 126 | "actions = np.concatenate(actions, axis=1)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "images = episode['observation'][image_key][()]\n", 136 | "display_gif_from_images(images)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "f.close()" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "dataset_name = 'fractal20220817_data'\n", 155 | "f_path = f'/{DATASET_ROOT_RID}/{dataset_name}.hdf5'\n", 156 | "\n", 157 | "f = h5py.File(f_path, 'r')\n", 158 | "\n", 159 | "info = f['shape_type_info']\n", 160 | "print_group_structure(info)\n", 161 | "episodes = f['episodes']\n", 162 | "print(f'episode count: 
{episodes[\"length\"][()]}')" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "def get_episode_action(one_episode):\n", 172 | " actions = []\n", 173 | " for inner_key in inner_keys:\n", 174 | " v = one_episode[outer_key][inner_key][()]\n", 175 | " if v.dtype == bool:\n", 176 | " v = np.array(v, dtype=np.float32).reshape(-1, 1)\n", 177 | " actions.append(v)\n", 178 | " actions = np.concatenate(actions, axis=1)\n", 179 | " return actions" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "# action:\n", 189 | "outer_key = 'action'\n", 190 | "inner_keys = ['world_vector', 'rotation_delta', 'gripper_closedness_action']\n", 191 | "one_episode = episodes['episode_1']\n", 192 | "\n", 193 | "actions = get_episode_action(one_episode)\n", 194 | "\n", 195 | "images = one_episode['observation']['image'][()]" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "# this is desired to be applied on actions as if action is delta, actions should be accumulated to represent a whole trajectory (to be displayed)\n", 205 | "accumulate_delta_actions(actions[:, :3])" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "plot_3d_points_with_gradient_colors(actions)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "display_gif_from_images(images)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "import tqdm\n", 233 | "num_episodes = episodes['length'][()]\n", 234 | "episode_lengthes = [0] * num_episodes\n", 235 | "for i in tqdm.trange(num_episodes):\n", 236 | " episode_lengthes[i] = int(episodes[f'episode_{i}']['length'][()])" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "\n", 246 | "all_actions = np.zeros(shape=(sum(episode_lengthes), actions.shape[-1], ), dtype=np.float32)\n", 247 | "\n", 248 | "begin_idx = 0\n", 249 | "for i in tqdm.trange(num_episodes):\n", 250 | " all_actions[begin_idx: begin_idx+episode_lengthes[i]] = get_episode_action(episodes[f'episode_{i}'])\n", 251 | " begin_idx += episode_lengthes[i]\n", 252 | "\n", 253 | "get_statistics(all_actions)" 254 | ] 255 | } 256 | ], 257 | "metadata": { 258 | "kernelspec": { 259 | "display_name": "Python 3", 260 | "language": "python", 261 | "name": "python3" 262 | }, 263 | "language_info": { 264 | "codemirror_mode": { 265 | "name": "ipython", 266 | "version": 3 267 | }, 268 | "file_extension": ".py", 269 | "mimetype": "text/x-python", 270 | "name": "python", 271 | "nbconvert_exporter": "python", 272 | "pygments_lexer": "ipython3", 273 | "version": "3.10.12" 274 | } 275 | }, 276 | "nbformat": 4, 277 | "nbformat_minor": 2 278 | } 279 | -------------------------------------------------------------------------------- /data_preprocessing/convert_tfds_to_h5.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow_datasets as tfds 4 | import h5py 5 | import tqdm 6 | import 
argparse 7 | from pathlib import Path 8 | import yaml 9 | 10 | 11 | DATASET_ROOT_DIR = '' 12 | 13 | 14 | def dataset_add_version(dataset_name): 15 | if dataset_name == 'robo_net': 16 | version = '1.0.0' 17 | elif dataset_name == 'language_table': 18 | version = '0.0.1' 19 | else: 20 | version = '0.1.0' 21 | return f'{dataset_name}/{version}' 22 | 23 | 24 | def get_merged_dataset(src_builder): 25 | splits = src_builder.info.splits 26 | 27 | first_split = list(splits.keys())[0] 28 | merged_dataset = src_builder.as_dataset(split=first_split) 29 | 30 | for split_name in list(splits.keys())[1:]: 31 | dataset = src_builder.as_dataset(split=split_name) 32 | merged_dataset = merged_dataset.concatenate(dataset) 33 | 34 | return merged_dataset 35 | 36 | 37 | def to_np_dtype(maybe_tf_dtype): 38 | dtype_map = { 39 | tf.bool: np.dtype('bool'), 40 | tf.string: str, 41 | tf.float16: np.float16, 42 | tf.float32: np.float32, 43 | tf.float64: np.float64, 44 | tf.int8: np.int8, 45 | tf.int16: np.int16, 46 | tf.int32: np.int32, 47 | tf.int64: np.int64, 48 | tf.uint8: np.uint8, 49 | tf.uint16: np.uint16, 50 | } 51 | 52 | # keep the unfound the same 53 | np_dtype = dtype_map.get(maybe_tf_dtype, maybe_tf_dtype) 54 | 55 | return np_dtype 56 | 57 | 58 | class DatasetConverter: 59 | def __init__( 60 | self, 61 | dataset_name: str, 62 | src_dataset_dir: Path, 63 | tgt_h5_path: Path): 64 | 65 | self.dataset_name = dataset_name 66 | 67 | self.src_dataset_dir = src_dataset_dir 68 | self.tgt_h5_path = tgt_h5_path 69 | 70 | def process_episode(self, h5_group: h5py.Group, episode, episode_index, shape_dtypes): 71 | this_episode_group = h5_group.create_group(name=f'episode_{episode_index}') 72 | 73 | steps = episode['steps'].as_numpy_iterator() 74 | steps = [step for step in steps] 75 | num_steps = len(steps) 76 | this_episode_group.create_dataset(name='length', data=num_steps) 77 | 78 | if 'language_instruction' in shape_dtypes.keys(): 79 | language_instructions = [str(step['language_instruction'], encoding='utf-8') for step in steps] 80 | elif 'natural_language_instruction' in shape_dtypes['observation'].keys(): 81 | language_instructions = [str(step['observation']['natural_language_instruction'], encoding='utf-8') for step in steps] 82 | else: 83 | language_instructions = ['push the T-shaped building block to the matching area'] 84 | 85 | this_episode_group.create_dataset('language_instruction', data=np.array(language_instructions, dtype=h5py.string_dtype(encoding='utf-8'))) 86 | 87 | for data_key, shape_dtype_or_dict in shape_dtypes.items(): 88 | if 'language' in data_key: 89 | continue 90 | shape_dtype_or_dict = shape_dtypes[data_key] 91 | group = this_episode_group.create_group(name=data_key) 92 | 93 | # shape_type should be a dict, or a value 94 | if 'shape' in shape_dtype_or_dict.keys(): 95 | if 'language' in data_key or shape_dtype_or_dict['dtype'] == tf.string: 96 | continue 97 | # if shape_dtypes['action']['shape'] is not hierarchical, can directly access to step['action'] 98 | group.create_dataset( 99 | name=data_key, 100 | data=self.get_episode_np_from_steps( 101 | steps=steps, num_steps=num_steps, shape_dtype=shape_dtype_or_dict, data_key=data_key 102 | ), 103 | ) 104 | else: 105 | # add data hierarchicaly 106 | for k, shape_dtype in shape_dtype_or_dict.items(): 107 | if 'language' in k or shape_dtype['dtype'] == tf.string: 108 | continue 109 | group.create_dataset( 110 | name=k, 111 | data=self.get_episode_np_from_steps( 112 | steps=steps, num_steps=num_steps, shape_dtype=shape_dtype, data_key=data_key, 
deeper_data_key=k 113 | ), 114 | ) 115 | 116 | @staticmethod 117 | def get_episode_np_from_steps(steps, num_steps, shape_dtype, data_key, deeper_data_key=None): 118 | shape = shape_dtype['shape'] 119 | dtype = shape_dtype['dtype'] 120 | # if deeper_data_key == 'image': 121 | # shape = (shape[0] // 2, shape[1] // 2, shape[2]) 122 | 123 | episode_np_data = np.ndarray( 124 | shape=(num_steps,) + shape, 125 | dtype=to_np_dtype(dtype) 126 | ) 127 | 128 | for step_index, step in enumerate(steps): 129 | if deeper_data_key is None: 130 | episode_np_data[step_index] = step[data_key] 131 | else: 132 | # if deeper_data_key == 'image': 133 | # episode_np_data[step_index] = cv2.resize(step[data_key][deeper_data_key], (shape[1], shape[0]), interpolation=cv2.INTER_AREA) 134 | # else: 135 | episode_np_data[step_index] = step[data_key][deeper_data_key] 136 | 137 | return episode_np_data 138 | 139 | @staticmethod 140 | def dataset_info_to_dict(dataset_info: tfds.core.DatasetInfo) -> dict: 141 | info_dict = { 142 | 'name': dataset_info.name, 143 | 'version': str(dataset_info.version), 144 | 'description': dataset_info.description, 145 | 'homepage': dataset_info.homepage, 146 | 'citation': dataset_info.citation, 147 | # 'splits': list(dataset_info.splits.keys()), 148 | 'features': str(dataset_info.features), 149 | } 150 | return info_dict 151 | 152 | def _add_dict_to_group(self, h5_group: h5py.Group, tar_dict: dict): 153 | for k, v in tar_dict.items(): 154 | if isinstance(v, dict): 155 | g = h5_group.create_group(name=k) 156 | self._add_dict_to_group(g, v) 157 | else: 158 | try: 159 | h5_group.create_dataset(name=k, data=v) 160 | except: 161 | # print(f'try to convert data of {k} to string') 162 | try: 163 | h5_group.create_dataset(name=k, data=str(v)) 164 | except: 165 | print(f'{k} can not be added into h5 file') 166 | 167 | def merge_shapes_dtypes(self, shapes, dtypes): 168 | assert shapes.keys() == dtypes.keys() 169 | 170 | res = dict() 171 | for k, v in shapes.items(): 172 | if isinstance(v, dict): 173 | res[k] = self.merge_shapes_dtypes(shapes[k], dtypes[k]) 174 | else: 175 | res[k] = { 176 | 'shape': shapes[k], 177 | 'dtype': dtypes[k] 178 | } 179 | return res 180 | 181 | def _add_builder_to_h5(self, h5_file: h5py.File, tfds_builder: tfds.core.dataset_builder.DatasetBuilder): 182 | # |__/meta_info 183 | # |__/shape_info 184 | # |__/episodes 185 | info = tfds_builder.info 186 | 187 | info_dict = self.dataset_info_to_dict(info) 188 | info_group = h5_file.create_group(name='meta_info') 189 | for k, v in info_dict.items(): 190 | info_group.create_dataset(name=k, data=v) 191 | 192 | shape_type_group = h5_file.create_group(name='shape_type_info') 193 | shapes = info.features.shape['steps'] 194 | dtypes = info.features.dtype['steps'] 195 | shape_types = self.merge_shapes_dtypes(shapes, dtypes) 196 | 197 | self._add_dict_to_group(shape_type_group, shape_types) 198 | 199 | merged_dataset = get_merged_dataset(tfds_builder) 200 | episodes_group = h5_file.create_group('episodes') 201 | 202 | num_episodes = int(merged_dataset.cardinality()) 203 | 204 | episodes_group.create_dataset(name='length', data=num_episodes) 205 | 206 | # single-processing 207 | for episode_index, episode in enumerate(tqdm.tqdm(merged_dataset, total=num_episodes)): 208 | self.process_episode( 209 | h5_group=episodes_group, episode=episode, episode_index=episode_index, shape_dtypes=shape_types 210 | ) 211 | 212 | def run(self): 213 | print(f'target h5 path: {self.tgt_h5_path}') 214 | 215 | builder = 
tfds.builder_from_directory(builder_dir=str(self.src_dataset_dir))
216 |         with h5py.File(str(self.tgt_h5_path), 'w') as f:
217 |             self._add_builder_to_h5(f, builder)
218 | 
219 |         print(f'data saved at {self.tgt_h5_path}')
220 | 
221 | 
222 | if __name__ == '__main__':
223 |     parser = argparse.ArgumentParser()
224 |     parser.add_argument('--index', '--i', type=int, default=0)
225 |     args = parser.parse_args()
226 | 
227 |     dataset_index = args.index
228 | 
229 |     src_root_dir = Path(f'{DATASET_ROOT_DIR}/rt-x')
230 |     tgt_root_dir = Path(f'{DATASET_ROOT_DIR}/rt-x_h5')
231 |     tgt_root_dir.mkdir(exist_ok=True)
232 | 
233 |     with open('./data_preprocessing/dataset_list.yaml', 'r') as f:
234 |         dataset_list = yaml.safe_load(f)
235 |     dataset_list = dataset_list['small']
236 |     dataset_name = dataset_list[dataset_index]
237 | 
238 |     dataset_converter = DatasetConverter(
239 |         dataset_name=dataset_name,
240 |         src_dataset_dir=src_root_dir / dataset_add_version(dataset_name),
241 |         tgt_h5_path=tgt_root_dir / f'{dataset_name}.hdf5'
242 |     )
243 |     dataset_converter.run()
244 | 
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | from itertools import accumulate
3 | import os
4 | import copy
5 | import random
6 | from typing import List
7 | import pickle
8 | import numpy as np
9 | 
10 | import torch
11 | from torch.utils.data import Dataset
12 | 
13 | 
14 | class SingleDataset(Dataset):
15 |     def __init__(
16 |         self,
17 |         dataset_name: str,
18 |         action_path: str,
19 |         language_path: str,
20 |         image_path: str,
21 |         low_dim_path: str,
22 |         data_cfg: dict,
23 |         load_language: bool,
24 |         load_image: bool,
25 |         load_low_dim: bool
26 |     ):
27 |         super().__init__()
28 | 
29 |         self.dataset_name = dataset_name
30 |         self.data_cfg = data_cfg
31 | 
32 |         self.action_path = action_path
33 |         self.language_path = language_path
34 |         self.image_path = image_path
35 |         self.low_dim_path = low_dim_path
36 | 
37 |         self.image_cfg = data_cfg['image']
38 |         try:
39 |             self.image_keys = list(data_cfg['image'].keys())
40 |         except:
41 |             pass
42 |         self.canonical_image_key = data_cfg['canonical_view']
43 | 
44 |         self.image_preprocess = None
45 |         self.action_data = None
46 |         self.image_data = None
47 |         self.language_data = None
48 |         self.load_image = load_image
49 |         self.load_language = load_language
50 |         self.load_low_dim = load_low_dim
51 | 
52 |     def split_train_val(self, train_ratio=0.98):
53 |         val_ds = copy.deepcopy(self)
54 | 
55 |         with open(self.action_path, 'rb') as f:
56 |             action_data = pickle.load(f)
57 |         n_episodes = len(action_data)
58 |         if n_episodes > 10:
59 |             num_train_episodes = int(train_ratio * len(action_data)) - 1  # at least one val episode
60 |         else:
61 |             num_train_episodes = len(action_data) - 1  # few shot, ignoring val
62 |         self.action_data = action_data[:num_train_episodes]
63 |         val_ds.action_data = action_data[num_train_episodes:]
64 | 
65 |         if self.load_language:
66 |             with open(self.language_path, 'rb') as f:
67 |                 language_data = pickle.load(f)
68 |             if len(language_data) == 1:
69 |                 self.language_data = language_data
70 |                 val_ds.language_data = language_data
71 |             else:
72 |                 self.language_data = language_data[:num_train_episodes]
73 |                 val_ds.language_data = language_data[num_train_episodes:]
74 | 
75 |         if self.load_image:
76 |             with open(self.image_path, 'rb') as f:
77 |                 image_data = pickle.load(f)
78 |             self.image_data = image_data[:num_train_episodes]
79 |             val_ds.image_data = image_data[num_train_episodes:]
80 | 
81 |         if 
self.load_low_dim:
82 |             with open(self.low_dim_path, 'rb') as f:
83 |                 low_dim_data = pickle.load(f)
84 |             self.low_dim_data = low_dim_data[:num_train_episodes]
85 |             val_ds.low_dim_data = low_dim_data[num_train_episodes:]
86 | 
87 |         return self, val_ds
88 | 
89 |     def __len__(self):
90 |         # returns the number of episodes
91 |         return len(self.action_data)
92 | 
93 |     def get_data(
94 |         self, horizon, episode_index, get_language, get_canonical_image, get_image_dict, get_low_dim,
95 |         recursive_depth=0,
96 |     ):
97 |         if recursive_depth > 100:
98 |             print(f'Not enough episodes longer than {horizon} steps in {self.dataset_name}')
99 |             return None
100 | 
101 |         num_steps = len(self.action_data[episode_index])
102 |         if num_steps < horizon:  # episode too short, resample a random one
103 |             episode_index = random.randint(0, len(self) - 1)
104 |             return self.get_data(
105 |                 horizon=horizon, episode_index=episode_index, get_language=get_language, get_canonical_image=get_canonical_image, get_image_dict=get_image_dict, get_low_dim=get_low_dim,
106 |                 recursive_depth=recursive_depth + 1,
107 |             )
108 | 
109 |         # actions
110 |         begin_index = random.randint(0, num_steps - horizon)
111 |         action_seq = self.action_data[episode_index][begin_index: begin_index + horizon]
112 | 
113 |         data = {'action': action_seq}
114 | 
115 |         if get_canonical_image:
116 |             key = random.choice(self.image_keys)
117 |             image_feature = self.image_data[episode_index][key][begin_index].unsqueeze(0)
118 |             data.update({'image': image_feature.to(dtype=torch.float32)})
119 | 
120 |         elif get_image_dict:
121 |             # assumes every episode in a dataset provides the same set of views
122 |             multiview_image_tensors = []
123 |             for key in sorted(self.image_cfg):  # sorted for a consistent view order
124 |                 image_feature = self.image_data[episode_index][key][begin_index]
125 |                 multiview_image_tensors.append(image_feature)
126 |             multiview_image_tensors = torch.stack(multiview_image_tensors, dim=0).to(dtype=torch.float32)
127 | 
128 |             data.update({'image': multiview_image_tensors})
129 | 
130 |         if get_language:
131 |             if len(self.language_data) == 1:
132 |                 # downstream finetuning dataset
133 |                 language_feature = self.language_data[0]
134 |             else:
135 |                 language_feature = self.language_data[episode_index]
136 |             data.update({'language': language_feature.unsqueeze(0).to(dtype=torch.float32)})
137 | 
138 |         if get_low_dim:
139 |             data.update({'low_dim': self.low_dim_data[episode_index][begin_index].unsqueeze(0)})
140 | 
141 |         return data
142 | 
143 | 
144 | class MultiDataset(Dataset):
145 |     def __init__(
146 |         self,
147 |         root_dir: str,
148 |         dataset_names: List[str],
149 |         horizon: int,
150 |         data_cfg: dict,
151 |         get_language: bool,
152 |         get_canonical_image: bool,
153 |         get_image_dict: bool,
154 |         get_low_dim: bool,
155 |         feature_type: str,
156 |         average_step_per_episode: int
157 |     ):
158 |         super().__init__()
159 |         self.horizon = horizon
160 | 
161 |         self.get_language = get_language
162 |         self.get_canonical_image = get_canonical_image
163 |         self.get_image_dict = get_image_dict
164 |         self.get_low_dim = get_low_dim
165 |         self.average_step_per_episode = average_step_per_episode
166 | 
167 |         self.dataset_names = dataset_names
168 |         self.datasets = [
169 |             SingleDataset(
170 |                 dataset_name = dataset_name,
171 |                 action_path = os.path.join(root_dir, 'normalized_actions', f'{dataset_name}.pkl'),
172 |                 low_dim_path = os.path.join(root_dir, f'low_dim', f'{dataset_name}.pkl'),
173 |                 language_path = os.path.join(root_dir, f'{feature_type}_language', f'{dataset_name}.pkl'),
174 |                 image_path = os.path.join(root_dir, f'{feature_type}_image', 
f'{dataset_name}.pkl'), 175 | data_cfg = data_cfg[dataset_name], 176 | load_language = get_language, 177 | load_image = get_canonical_image or get_image_dict, 178 | load_low_dim = get_low_dim 179 | ) for dataset_name in dataset_names 180 | ] 181 | 182 | self.dynamic_variables_loaded = False 183 | self.dataset_lengthes = None 184 | self.accumulated_lengthes = None 185 | self.image_preprocess = None 186 | 187 | def register_image_preprocess_hook(self, func): 188 | self.image_preprocess = func 189 | 190 | def split_train_val(self, train_ratio): 191 | train_datasets = [] 192 | val_datasets = [] 193 | 194 | for dataset in self.datasets: 195 | train_ds, val_ds = dataset.split_train_val(train_ratio) 196 | train_datasets.append(train_ds) 197 | val_datasets.append(val_ds) 198 | 199 | val_dataset = copy.deepcopy(self) 200 | 201 | val_dataset.datasets = val_datasets 202 | self.datasets = train_datasets 203 | 204 | self.load_dynamic_variables() 205 | val_dataset.load_dynamic_variables() 206 | print(f'{len(train_datasets)} datasets loaded') 207 | return self, val_dataset 208 | 209 | def load_dynamic_variables(self): 210 | # called after splitting 211 | self.dataset_lengthes = [len(dataset) for dataset in self.datasets] 212 | self.accumulated_lengthes = list(accumulate(self.dataset_lengthes)) 213 | self.dynamic_variables_loaded = True 214 | 215 | def __len__(self): 216 | if not self.dynamic_variables_loaded: 217 | raise ValueError('please call load_dynamic_variables() before training') 218 | # len(dataset) returns the num of episodes, scaling by an avarage number of steps per episode 219 | return self.accumulated_lengthes[-1] * self.average_step_per_episode 220 | 221 | def get_dataset_and_episode_index(self, index): 222 | if not self.dynamic_variables_loaded: 223 | raise ValueError('please call load_dynamic_variables() before training') 224 | 225 | index = index % self.accumulated_lengthes[-1] 226 | 227 | dataset_idx = bisect.bisect_right(self.accumulated_lengthes, index) 228 | 229 | if dataset_idx == 0: 230 | data_index = index 231 | else: 232 | data_index = index - self.accumulated_lengthes[dataset_idx - 1] 233 | 234 | return dataset_idx, data_index 235 | 236 | def __getitem__(self, index): 237 | dataset_index, data_index = self.get_dataset_and_episode_index(index) 238 | return self.datasets[dataset_index].get_data( 239 | horizon = self.horizon, 240 | episode_index = data_index, 241 | get_canonical_image = self.get_canonical_image, 242 | get_image_dict = self.get_image_dict, 243 | get_language = self.get_language, 244 | get_low_dim = self.get_low_dim 245 | ) 246 | -------------------------------------------------------------------------------- /data_preprocessing/rt-x_data_cfg.yaml: -------------------------------------------------------------------------------- 1 | average_step_per_episode: 100 2 | 3 | datasets: 4 | _maniskill_dataset_converted_externally_to_rlds: 5 | language: 6 | outer_key: 7 | inner_key: 8 | action: 9 | outer_key: 10 | inner_keys: 11 | gripper_is_bool: 12 | gripper_close_is_positive: 13 | index_mapping: 14 | image: 15 | image: 16 | canonical_view: 17 | 18 | _bc_z: 19 | language: 20 | outer_key: 21 | inner_key: 22 | action: 23 | outer_key: 24 | inner_keys: 25 | gripper_is_bool: 26 | gripper_close_is_positive: 27 | index_mapping: 28 | image: 29 | image: 30 | canonical_view: 31 | 32 | robo_net: 33 | language: 34 | outer_key: language_instruction 35 | inner_key: 36 | action: 37 | outer_key: action 38 | inner_keys: [action] 39 | gripper_is_bool: True 40 | gripper_signal: binary 41 | 
gripper_close_is_positive: False 42 | min: [-0.15101519,-0.1565145,-0.6257775,0,0,-0.9654438,-1.0] 43 | max: [0.15718322,0.14304419,0.57623166,0,0,0.9297816,1.0] 44 | mean: [-1.1366916e-05,1.5125085e-05,-0.008062339,0,0,-2.5888812e-05,-0.028595591] 45 | std: [0.032429904,0.032394417,0.09075778,0,0,0.17676172,0.9996361] 46 | index_mapping: [1, 2, 3, 0, 0, 4, 5] 47 | image: 48 | image: [240, 320, 3] 49 | canonical_view: image 50 | 51 | _toto: 52 | language: 53 | outer_key: observation 54 | inner_key: natural_language_instruction 55 | action: 56 | outer_key: action 57 | inner_keys: [world_vector, rotation_delta, open_gripper] 58 | abs_action: True 59 | gripper_is_bool: True 60 | gripper_close_is_positive: False 61 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 62 | image: 63 | image: [480, 640, 3] 64 | canonical_view: image 65 | 66 | bridge: 67 | language: 68 | outer_key: language_instruction 69 | inner_key: 70 | action: 71 | outer_key: action 72 | inner_keys: [world_vector, rotation_delta, open_gripper] 73 | gripper_is_bool: True 74 | 75 | gripper_signal: binary 76 | gripper_close_is_positive: False 77 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 78 | min: [-0.15559052,-0.16075383,-0.20664044,-0.78723234,-0.65459704,-1.1610136,0.0] 79 | max: [0.13607752,0.17013738,0.2128921,0.75443304,0.5700094,0.82569605,1.0] 80 | mean: [0.0005796338,0.0001665241,0.00025199732,-7.463656e-05,-0.002136332,0.00010259719,0.6684366] 81 | std: [0.010073104,0.014915115,0.013089378,0.03009547,0.031722616,0.05266533,0.47108224] 82 | image: 83 | image: [480, 640, 3] 84 | canonical_view: image 85 | 86 | _utaustin_mutex: # too long instruction? 87 | language: 88 | outer_key: language_instruction 89 | inner_key: 90 | 91 | action: 92 | outer_key: action 93 | inner_keys: [action] 94 | gripper_is_bool: False # if True, the last key will be converted to float 95 | gripper_close_is_positive: True # If true, a larger gripper action stands for a closer gripper 96 | 97 | gripper_signal: binary 98 | # then concat all the data step[outer_key][inner_keys] (or just step[outer_key]) 99 | # supppose the concated data is 5-dim ([x, y, z, yaw, gripper]) 100 | # first pad the concated data with 0 at index 0 [0, x, y, z, yaw, gripper] 101 | # the target should be [x,y,z,0,0,yaw,gripper] so create tgt=np.zeros(shape=(7,)) 102 | # then the values according to the mapping 103 | # concated_data: [0:0, 1:x, 2:y, 3:z, 4:yaw, 5:gripper] 104 | 105 | # index_mapping: [1, 2, 3, 0, 0, 4, 5] 106 | # target_data: [x, y, z, 0, 0, yaw, gripper] 107 | # code like: 108 | # target_action = np.zeros(shape=(7, ), dtype=np.float32) 109 | # for tgt_idx, src_idx in enumerate(index_mapping): 110 | # target_action[tgt_idx] = concated_data[src_idx] 111 | # normalized_tgt_action = (tgt_action - mean) / std 112 | 113 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 114 | min: [-1.0,-1.0,-1.0,-0.375,-0.375,-0.375,-1.0] 115 | max: [1.0,1.0,1.0,0.375,0.375,0.375,1.0] 116 | mean: [0.06176343,-0.0050054663,0.10216819,-0.03314115,0.013894996,-0.0113176685,-0.007795337] 117 | std: [0.18749881,0.4468454,0.37927994,0.14098226,0.064536214,0.11765014,1.00209] 118 | 119 | image: 120 | image: [128, 128, 3] 121 | wrist_image: [128, 128, 3] 122 | canonical_view: image 123 | 124 | _berkeley_fanuc_manipulation: 125 | language: 126 | outer_key: language_instruction 127 | inner_key: 128 | action: 129 | outer_key: 130 | inner_keys: 131 | gripper_is_bool: False 132 | gripper_close_is_positive: False 133 | index_mapping: [] 134 | mean: [] 135 | std: [] 136 | max: [] 137 | min: [] 138 | image: 139 | image: 
[224, 224, 3] 140 | wrist_image: [224, 224, 3] 141 | 142 | _cmu_play_fusion: 143 | language: 144 | outer_key: language_instruction 145 | inner_key: 146 | action: 147 | outer_key: action 148 | inner_keys: [action] 149 | gripper_is_bool: False 150 | gripper_close_is_positive: False 151 | index_mapping: [1, 2, 3, 8] 152 | image: 153 | image: [128, 128, 3] 154 | 155 | cmu_stretch: 156 | language: 157 | outer_key: language_instruction 158 | inner_key: 159 | action: 160 | outer_key: action 161 | inner_keys: [action] 162 | gripper_is_bool: False 163 | 164 | gripper_signal: binary 165 | gripper_close_is_positive: False 166 | min: [-0.019353798,0.0,-0.020192152,0.0,0.0,0.0,0.0] 167 | max: [0.023384072,0.0,0.023404928,0.0,0.0,0.0,1.0] 168 | mean: [0.0003630468,0.0,0.001646696,0.0,0.0,0.0,0.39870483] 169 | std: [0.00408182,1.0,0.0037743386,1.0,1.0,1.0,0.48963726] 170 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 171 | image: 172 | image: [128, 128, 3] 173 | canonical_view: image 174 | 175 | fractal20220817_data: 176 | language: 177 | outer_key: language_instruction 178 | inner_key: 179 | action: 180 | outer_key: action 181 | inner_keys: [world_vector, rotation_delta, gripper_closedness_action] 182 | gripper_is_bool: False 183 | gripper_signal: force-threshold-0 184 | gripper_close_is_positive: True 185 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 186 | min: [-2.020452,-5.4978995,-2.0316634,-1.5699179,-1.5698922,-1.5704194,-1.0] 187 | max: [2.9984593,22.090528,2.7507524,1.5706365,1.5321087,1.5691522,1.0] 188 | mean: [0.0069875014,0.0062659234,-0.012625135,0.043331914,-0.005756167,0.00091309793,0.021864852] 189 | std: [0.06921227,0.059654854,0.073531315,0.15610056,0.13164213,0.14593266,0.3603207] 190 | image: 191 | image: [256, 320, 3] 192 | canonical_view: image 193 | 194 | _kuka: # episodes very short 195 | language: 196 | outer_key: observation 197 | inner_key: natural_language_instruction 198 | action: 199 | outer_key: action 200 | inner_keys: [world_vector, rotation_delta, gripper_closedness_action] 201 | gripper_is_bool: False 202 | gripper_close_is_positive: True 203 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 204 | image: 205 | image: [256, 320, 3] 206 | canonical_view: image 207 | 208 | taco_play: 209 | language: 210 | outer_key: language_instruction 211 | inner_key: 212 | action: 213 | outer_key: action 214 | inner_keys: [rel_actions_world] 215 | gripper_is_bool: False 216 | 217 | gripper_signal: binary 218 | gripper_close_is_positive: False 219 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 220 | mean: [-0.0038459413,0.009671559,0.012780648,-0.0054038013,-0.009606571,-0.0024807807,-0.1472174] 221 | std: [0.23254019,0.36298397,0.28692976,0.26177117,0.24388969,0.52164966,0.98938197] 222 | max: [1.4915844,2.1842432,2.6836395,5.035227,2.665865,4.2507687,1.0] 223 | min: [-4.242458,-3.192805,-1.3371468,-4.2026834,-2.6722639,-3.3467135,-1.0] 224 | image: 225 | depth_gripper: [84, 84] 226 | depth_static: [150, 200] 227 | rgb_gripper: [84, 84, 3] 228 | rgb_static: [150, 200, 3] 229 | canonical_view: rgb_static 230 | 231 | jaco_play: 232 | language: 233 | outer_key: language_instruction 234 | inner_key: 235 | action: 236 | outer_key: action 237 | inner_keys: [world_vector, gripper_closedness_action] 238 | gripper_is_bool: False 239 | 240 | gripper_signal: force-threshold-0 241 | gripper_close_is_positive: True 242 | index_mapping: [1, 2, 3, 0, 0, 0, 4] 243 | mean: [0.00096585735,-0.005800745,-0.003950604,0.0,0.0,0.0,0.029295197] 244 | std: [0.122350916,0.09678775,0.11155401,0.0,0.0,0.0,0.47126555] 245 | max: 
[0.2,0.2,0.2,0.0,0.0,0.0,1.0] 246 | min: [-0.2,-0.2,-0.2,0.0,0.0,0.0,-1.0] 247 | image: 248 | image: [224, 224, 3] 249 | image_wrist: [224, 224, 3] 250 | canonical_view: image 251 | 252 | berkeley_cable_routing: 253 | language: 254 | outer_key: language_instruction 255 | inner_key: 256 | action: 257 | outer_key: action 258 | inner_keys: [world_vector, rotation_delta] 259 | gripper_is_bool: False 260 | 261 | gripper_signal: 262 | gripper_close_is_positive: False 263 | index_mapping: [1, 2, 3, 4, 5, 6, 0] 264 | mean: [-0.07139853,0.023609024,0.102419436,0.0,0.0,0.04967077,0.0] 265 | std: [0.18154977,0.18109904,0.21220727,0.0,0.0,0.3475515,0.0] 266 | max: [0.9633283,1.0,1.0,0.0,0.0,1.0,0.0] 267 | min: [-0.98090816,-0.9554349,-0.9994775,0.0,0.0,-1.0,0.0] 268 | image: 269 | image: [128, 128, 3] 270 | top_image: [128, 128, 3] 271 | wrist225_image: [128, 128, 3] 272 | wrist45_image: [128, 128, 3] 273 | canonical_view: image 274 | 275 | roboturk: # one episode contains multiple tasks, mind the instructions 276 | language: 277 | outer_key: language_instruction 278 | inner_key: 279 | action: 280 | outer_key: action 281 | inner_keys: [world_vector, rotation_delta, gripper_closedness_action] 282 | gripper_is_bool: False 283 | 284 | gripper_signal: binary 285 | gripper_close_is_positive: True 286 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 287 | mean: [0.0014448888,-0.0015945357,-0.0011753805,0.0023012396,-0.00093824376,-0.00011485874,-0.1492051] 288 | std: [0.0493537,0.06354564,0.06116491,0.09553406,0.084200144,0.065179124,0.9890353] 289 | max: [0.39124173,0.46010283,0.48708335,1.8168887,1.8240283,1.4824821,1.0] 290 | min: [-0.6546999,-0.6365841,-0.42177236,-1.6695483,-1.8023357,-1.4630828,-1.0] 291 | image: 292 | front_rgb: [480, 640, 3] 293 | canonical_view: front_rgb 294 | 295 | nyu_door_opening_surprising_effectiveness: 296 | language: 297 | outer_key: language_instruction 298 | inner_key: 299 | action: 300 | outer_key: action 301 | inner_keys: [world_vector, rotation_delta, gripper_closedness_action] 302 | gripper_is_bool: False 303 | 304 | gripper_signal: force-threshold-0.01 305 | gripper_close_is_positive: True 306 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 307 | mean: [-0.0062134075,0.0012825616,0.0012440508,-0.0001565738,-0.011653881,-0.0015855222,0.022675302] 308 | std: [0.012148533,0.019559348,0.0077468683,0.008133054,0.033606686,0.013172098,0.08920925] 309 | max: [0.035083357,0.063189335,0.043332618,0.054775182,0.1734558,0.06685609,0.96622527] 310 | min: [-0.27681807,-0.10727508,-0.3554436,-0.036789477,-0.21697818,-0.0676727,-0.28207946] 311 | image: 312 | image: [720, 960, 3] 313 | canonical_view: image 314 | 315 | viola: 316 | language: 317 | outer_key: language_instruction 318 | inner_key: 319 | action: 320 | outer_key: action 321 | inner_keys: [world_vector, rotation_delta, gripper_closedness_action] 322 | gripper_is_bool: False 323 | 324 | gripper_signal: binary 325 | gripper_close_is_positive: True 326 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 327 | mean: [0.047618777,-0.02920461,0.055867158,-0.0026184842,0.0068673426,-0.016821278,-0.4647555] 328 | std: [0.39158034,0.4076541,0.4007757,0.10023959,0.084432065,0.10375133,0.88521934] 329 | max: [1.0,1.0,1.0,0.375,0.36321428,0.375,1.0] 330 | min: [-1.0,-1.0,-1.0,-0.375,-0.375,-0.375,-1.0] 331 | image: 332 | agentview_rgb: [224, 224, 3] 333 | eye_in_hand_rgb: [224, 224, 3] 334 | canonical_view: agentview_rgb 335 | 336 | berkeley_autolab_ur5: 337 | language: 338 | outer_key: language_instruction 339 | inner_key: 340 | action: 341 | outer_key: action 
342 | inner_keys: [world_vector, rotation_delta, gripper_closedness_action] 343 | gripper_is_bool: False 344 | 345 | gripper_signal: force-threshold-0.1 346 | gripper_close_is_positive: True 347 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 348 | mean: [0.0005683611,0.0012176944,-0.0005296353,0.00021029709,6.069498e-05,0.0012049833,0.0] 349 | std: [0.011533243,0.007990539,0.009577767,0.009433038,0.01642755,0.011054285,0.12375644] 350 | max: [0.02,0.02,0.02,0.06666667,0.06666667,0.06666667,1.0] 351 | min: [-0.02,-0.02,-0.02,-0.06666667,-0.06666667,-0.06666667,-1.0] 352 | image: 353 | hand_image: [480, 640, 3] 354 | image: [480, 640, 3] 355 | image_with_depth: [480, 640, 1] 356 | canonical_view: image 357 | 358 | columbia_cairlab_pusht_real: 359 | language: 360 | outer_key: language_instruction 361 | inner_key: 362 | action: 363 | outer_key: action 364 | inner_keys: [world_vector, rotation_delta, gripper_closedness_action] 365 | gripper_is_bool: False 366 | gripper_close_is_positive: False 367 | 368 | gripper_signal: 369 | # index_mapping: [1, 2, 3, 4, 5, 6, 7] 370 | index_mapping: [1, 2, 0, 0, 0, 0, 0] 371 | mean: [-0.0013286311,0.0014844551,0.0,0.0,0.0,0.0,0.0] 372 | std: [0.013066671,0.015327031,0.0,0.0,0.0,0.0,0.0] 373 | max: [0.23432465,0.45121098,0.0,0.0,0.0,0.0,0.0] 374 | min: [-0.3853991,-0.31175086,0.0,0.0,0.0,0.0,0.0] 375 | image: 376 | image: [240, 320, 3] 377 | wrist_image: [240, 320, 3] 378 | canonical_view: image 379 | 380 | stanford_kuka_multimodal_dataset_converted_externally_to_rlds: 381 | language: 382 | outer_key: language_instruction 383 | inner_key: 384 | action: 385 | outer_key: action 386 | inner_keys: [action] 387 | gripper_is_bool: False 388 | gripper_close_is_positive: False 389 | 390 | gripper_signal: 391 | index_mapping: [1, 2, 3, 0, 0, 4, 0] # should be checked 392 | mean: [-0.0028558734,0.00039253273,-0.0029044405,0,0,-0.0019456731, 0] 393 | std: [0.02839532,0.029717535,0.027051711,1,1,0.09263818, 1] 394 | max: [0.05,0.05,0.05,0, 0, 0.15, 0] 395 | min: [-0.05,-0.05,-0.05,0,0,-0.15, 0] 396 | image: 397 | depth_image: [128, 128, 1] 398 | image: [128, 128, 3] 399 | canonical_view: image 400 | 401 | nyu_rot_dataset_converted_externally_to_rlds: 402 | language: 403 | outer_key: language_instruction 404 | inner_key: 405 | action: 406 | outer_key: action 407 | inner_keys: [action] 408 | gripper_is_bool: False 409 | 410 | gripper_signal: 411 | gripper_close_is_positive: False 412 | index_mapping: [1, 2, 3, 0, 0, 5, 0] 413 | mean: [0.1965463,-0.104999505,-0.039819203,0.0,0.0022727272,0.0,0.0] 414 | std: [0.4231604,0.46901545,0.42725626,0.0,0.047619108,0.0,0.0] 415 | max: [1.0,1.0,1.0,0.0,1.0,0.0,0.0] 416 | min: [-1.0,-1.0,-1.0,0.0,0.0,0.0,0.0] 417 | image: 418 | image: [84, 84, 3] 419 | canonical_view: image 420 | 421 | stanford_hydra_dataset_converted_externally_to_rlds: 422 | language: 423 | outer_key: language_instruction 424 | inner_key: 425 | action: 426 | outer_key: action 427 | inner_keys: [action] 428 | gripper_is_bool: False 429 | gripper_close_is_positive: True 430 | 431 | gripper_signal: binary 432 | index_mapping: [1, 2, 3, 4, 5, 6, 7] 433 | mean: [0.0007790164,0.00013707925,-0.0002548616,0.0012903379,-0.0047517866,0.0026929018,0.51144785] 434 | std: [0.008022228,0.009131469,0.009574297,0.04122218,0.038430043,0.04606715,0.4997606] 435 | max: [0.024998546,0.024999034,0.024999922,0.24974458,0.2499703,0.24999946,1.0] 436 | min: [-0.024999045,-0.0249997,-0.024999298,-0.24993226,-0.2499666,-0.24999325,0.0] 437 | image: 438 | image: [240, 320, 3] 439 | wrist_image: [240, 
320, 3] 440 | canonical_view: image 441 | 442 | austin_buds_dataset_converted_externally_to_rlds: 443 | language: 444 | outer_key: language_instruction 445 | inner_key: 446 | action: 447 | outer_key: action 448 | inner_keys: [action] 449 | gripper_is_bool: False 450 | 451 | gripper_signal: binary 452 | gripper_close_is_positive: True 453 | index_mapping: [1, 2, 3, 0, 0, 0, 7] 454 | mean: [-0.07678356,0.0036849175,0.05644921,0.0,0.0,0.0,0.29790103] 455 | std: [0.63677496,0.3788917,0.4779629,0.0,0.0,0.0,0.95442176] 456 | max: [1.0,1.0,1.0,0.0,0.0,0.0,1.0] 457 | min: [-1.0,-1.0,-1.0,0.0,0.0,0.0,-1.0] 458 | image: 459 | image: [128, 128, 3] 460 | wrist_image: [128, 128, 3] 461 | canonical_view: image 462 | 463 | _nyu_franka_play_dataset_converted_externally_to_rlds: 464 | language: 465 | outer_key: language_instruction 466 | inner_key: 467 | action: 468 | outer_key: action 469 | inner_keys: [action] 470 | gripper_is_bool: False 471 | gripper_close_is_positive: False 472 | index_mapping: [1,2,3,4,5,6,14] 473 | mean: [-0.002895747,0.003557173,0.00064357877,0.0023153278,-0.0023922238,-0.0015158538,0.0027481185,0.0010219901,-0.00012002677,0.0003289423,0.0015034276,-0.0021985276,-0.0016632306,0.4460167,0.01016156] 474 | std: [0.037964188,0.02999344,0.03088907,0.035438135,0.02886598,0.043608427,0.06070748,0.013274147,0.013215902,0.012822104,0.2732448,0.057022575,0.039172936,0.8950625,0.10027119] 475 | max: [0.14495707,0.13028586,0.11370349,0.11802268,0.21037066,0.12956262,0.19070536,0.064241886,0.07027635,0.061296612,6.281068,0.196773,0.26377416,1.0,1.0] 476 | min: [-0.107201084,-0.11304033,-0.11667186,-0.12557268,-0.12566182,-0.1590954,-0.19047071,-0.0595223,-0.072324455,-0.06730807,-6.2784348,-0.21479034,-0.36276197,-1.0,0.0] 477 | image: 478 | image: [128, 128, 3] 479 | image_additional_view: [128, 128, 3] 480 | depth: [128, 128, 1] 481 | depth_additional_view: [128, 128, 1] 482 | canonical_view: image 483 | 484 | _cmu_franka_exploration_dataset_converted_externally_to_rlds: 485 | # seems no enough steps per episode 486 | language: 487 | outer_key: language_instruction 488 | inner_key: 489 | action: 490 | outer_key: action 491 | inner_keys: [action] 492 | gripper_is_bool: True 493 | gripper_close_is_positive: True 494 | index_mapping: [1,2,3,4,5,6,7] 495 | mean: [0.02362956,0.0051615043,-0.015222261,0.04188222,0.0046763527,0.0988112,0.3] 496 | std: [0.11304981,0.116313554,0.078526765,0.19957197,0.09077263,1.0257257,0.4582545] 497 | max: [0.47253367,0.39019224,0.1048612,0.9421266,0.51699054,6.2564864,1.0] 498 | min: [-0.14635915,-0.4534304,-0.32937068,-0.290073,-0.48279625,-6.2891498,0.0] 499 | image: 500 | image: [64, 64, 3] 501 | highres_image: [480, 640, 3] 502 | canonical_view: highres_image 503 | 504 | _ucsd_kitchen_dataset_converted_externally_to_rlds: 505 | language: 506 | outer_key: language_instruction 507 | inner_key: 508 | action: 509 | outer_key: action 510 | inner_keys: [action] 511 | abs_action: true 512 | gripper_is_bool: True 513 | gripper_close_is_positive: False 514 | index_mapping: [1,2,3,4,5,6,7] 515 | mean: [410.3756,116.95189,192.35036,-121.22441,-33.848927,50.016136,0.7418136] 516 | std: [122.81497,108.80083,130.30342,116.281845,27.621872,41.020966,0.43763337] 517 | max: [678.0,400.0,507.0,180.00002,6.000014,116.999985,1.0] 518 | min: [172.0,-166.0,-99.99999,-180.00002,-89.0,-96.00011,0.0] 519 | image: 520 | image: [480, 640, 3] 521 | canonical_view: image 522 | 523 | _ucsd_pick_and_place_dataset_converted_externally_to_rlds: # low quality data 524 | language: 525 | 
outer_key: language_instruction 526 | inner_key: 527 | action: 528 | outer_key: action 529 | inner_keys: [action] 530 | gripper_is_bool: False 531 | gripper_close_is_positive: False 532 | index_mapping: [1,2,3,0,0,0,4] 533 | mean: [0.14699791,-0.12457364,0.053909536,0,0,0,-0.07569984] 534 | std: [0.48489225,0.46433133,0.540843,0,0,0,0.89286923] 535 | max: [1.0,1.0,1.0,0,0,0,1.0] 536 | min: [-1.0,-1.0,-1.0,0,0,0,-1.0] 537 | image: 538 | image: [224, 224, 3] 539 | canonical_view: image 540 | 541 | austin_sailor_dataset_converted_externally_to_rlds: 542 | language: 543 | outer_key: language_instruction 544 | inner_key: 545 | action: 546 | outer_key: action 547 | inner_keys: [action] 548 | gripper_is_bool: True 549 | 550 | gripper_signal: binary 551 | gripper_close_is_positive: True 552 | index_mapping: [1,2,3,4,5,6,7] 553 | mean: [0.011825301,0.006460939,0.06023686,0.0,0.0,0.001646604,-0.05219007] 554 | std: [0.46349025,0.4124005,0.4118623,0.0,0.0,0.057860684,0.99787027] 555 | max: [1.0,1.0,1.0,0.0,0.0,0.375,1.0] 556 | min: [-1.0,-1.0,-1.0,0.0,0.0,-0.375,-1.0] 557 | image: 558 | image: [128, 128, 3] 559 | wrist_image: [128, 128, 3] 560 | canonical_view: image 561 | 562 | austin_sirius_dataset_converted_externally_to_rlds: 563 | language: 564 | outer_key: language_instruction 565 | inner_key: 566 | action: 567 | outer_key: action 568 | inner_keys: [action] 569 | gripper_is_bool: True 570 | 571 | gripper_signal: binary 572 | gripper_close_is_positive: True 573 | index_mapping: [1,2,3,4,5,6,7] 574 | mean: [0.07747644,0.03195479,0.0424472,0.0,0.0,-0.016034609,0.13479717] 575 | std: [0.3906358,0.2998168,0.27823064,0.0,0.0,0.08120734,0.9905528] 576 | max: [1.0002285,0.9606087,1.1051798,0.0,0.0,0.34178573,1.000465] 577 | min: [-1.0183026,-0.98,-0.9774575,0.0,0.0,-0.34607142,-1.0004185] 578 | image: 579 | image: [84, 84, 3] 580 | wrist_image: [84, 84, 3] 581 | canonical_view: image 582 | 583 | usc_cloth_sim_converted_externally_to_rlds: 584 | language: 585 | outer_key: language_instruction 586 | inner_key: 587 | action: 588 | outer_key: action 589 | inner_keys: [action] 590 | gripper_is_bool: True 591 | 592 | gripper_signal: binary 593 | gripper_close_is_positive: True 594 | index_mapping: [1,2,3,0,0,0,4] 595 | mean: [0.105,0.03899963,2.3841857e-12,0,0,0,0.288093] 596 | std: [0.20360108,0.22256258,0.36332047,0,0,0,0.38395545] 597 | max: [0.5,0.5,1.0,0,0,0,0.8] 598 | min: [0.0,-0.6,-0.5,0,0,0,0.0] 599 | image: 600 | image: [32, 32, 3] 601 | canonical_view: image 602 | 603 | _utokyo_pr2_opening_fridge_converted_externally_to_rlds: 604 | # abs action? 605 | language: 606 | outer_key: language_instruction 607 | inner_key: 608 | action: 609 | outer_key: action 610 | inner_keys: [action] 611 | gripper_is_bool: True 612 | gripper_close_is_positive: True 613 | index_mapping: [1,2,3,4,5,6,7] 614 | mean: [648.7963,134.00903,1084.4326,-0.59741896,-0.19925973,0.011027544,0.25481686] 615 | std: [243.08586,257.75443,144.15361,0.49704772,0.15853582,0.14871414,0.435763] 616 | max: [992.8727,776.98096,1578.6831,0.24240845,0.4034255,0.9767319] 617 | min: [-453.4911,-1294.9354,766.853,-2.916839,-1.0906351,-0.7050959,0.0] 618 | image: 619 | image: [128, 128, 3] 620 | canonical_view: image 621 | 622 | _utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds: 623 | # is this abs action? 
624 | language: 625 | outer_key: language_instruction 626 | inner_key: 627 | action: 628 | outer_key: action 629 | inner_keys: [action] 630 | gripper_is_bool: True 631 | gripper_close_is_positive: True 632 | index_mapping: [1,2,3,4,5,6,7] 633 | mean: [476.25525,-62.400757,817.60736,1.5703903,1.5703903,-1.5703903,0.42604256] 634 | std: [66.20367,157.32535,79.65957,0.0004060108,0.0004060108,0.0004060108,0.49454728] 635 | max: [701.22705,378.27518,1037.9014,1.5707964,1.5707964,-1.5707964,1.0] 636 | min: [233.80075,-448.76117,683.6527,1.5707964,1.5707964,-1.5707964,0.0] 637 | image: 638 | image: [128, 128, 3] 639 | canonical_view: image 640 | 641 | utokyo_xarm_pick_and_place_converted_externally_to_rlds: 642 | language: 643 | outer_key: language_instruction 644 | inner_key: 645 | action: 646 | outer_key: action 647 | inner_keys: [action] 648 | gripper_is_bool: True 649 | 650 | gripper_signal: abs 651 | gripper_close_is_positive: True 652 | index_mapping: [1,2,3,4,5,6,7] 653 | mean: [0.37275583,-0.0033627972,0.302364,2.3023665,-0.028892199,0.12031945,0.43367606] 654 | std: [0.04503441,0.14538307,0.09052395,1.7120283,0.11656404,0.22716746,0.48757684] 655 | max: [0.5145256,0.28572002,0.5560113,3.1411602,0.49824417,0.9626807,1.0] 656 | min: [0.22990057,-0.3306972,0.1617671,-3.1413672,-0.35810038,-0.53911775,0.0] 657 | image: 658 | image: [224, 224, 3] 659 | image2: [224, 224, 3] 660 | hand_image: [224, 224, 3] 661 | canonical_view: image 662 | 663 | _berkeley_mvp_converted_externally_to_rlds: 664 | language: 665 | outer_key: language_instruction 666 | inner_key: 667 | action: 668 | outer_key: action 669 | inner_keys: [action] 670 | gripper_is_bool: True 671 | gripper_close_is_positive: True 672 | index_mapping: [1,2,3,4,5,6,8] 673 | mean: [-6.0901424e-05,0.0032694037,-0.00014031114,-0.00093984953,-2.0682975e-05,-0.002937962,-0.0006887745,0.48194578] 674 | std: [0.0025867769,0.012797507,0.00560216,0.018046888,0.0016723009,0.021038346,0.005722353,0.49976623] 675 | max: [0.022487685,0.13018322,0.06808573,0.10099727,0.024909932,0.092411876,0.047665834,1.0] 676 | min: [-0.02135203,-0.03705667,-0.1387428,-0.19421488,-0.017915316,-0.19084352,-0.09572661,0.0] 677 | image: 678 | hand_image: [480, 640, 3] 679 | 680 | _berkeley_rpt_converted_externally_to_rlds: 681 | language: 682 | outer_key: language_instruction 683 | inner_key: 684 | action: 685 | outer_key: action 686 | inner_keys: [action] 687 | gripper_is_bool: True 688 | gripper_close_is_positive: True 689 | index_mapping: [1,2,3,4,5,6,8] 690 | mean: [0.00013917917,-0.00027910402,-9.376837e-06,-0.00032673698,1.9516052e-05,3.068495e-05,4.756752e-05,0.4785775] 691 | std: [0.0015279724,0.0045425007,0.0007791061,0.003020025,0.0010572867,0.005123675,0.004143369,0.49978197] 692 | max: [0.066869184,0.112888634,0.0115112215,0.03328538,0.16681346,0.08019853,0.062677264,1.0] 693 | min: [-0.029659934,-0.03170164,-0.026589055,-0.042382836,-0.0830309,-0.02891469,-0.12611546,0.0] 694 | image: 695 | hand_image: [480, 640, 3] 696 | 697 | _kaist_nonprehensile_converted_externally_to_rlds: 698 | language: 699 | outer_key: language_instruction 700 | inner_key: 701 | action: 702 | outer_key: action 703 | inner_keys: [action] 704 | gripper_is_bool: False 705 | index_mapping: [1,2,3,0,0,0,0] # seems no gripper but not sure 706 | mean: [0.0019469671,0.00023607437,0.00090286764,-0.0022928442,0.00011904385,-0.0036542765,180.56076,170.3104,185.60385,152.46106,174.59824,82.46711,31.534365,1.0037495,0.6999126,0.33574754,0.48796102,0.4932014,0.73938084,1.232158] 707 | std: 
[0.0146865165,0.016561238,0.012853604,0.023209855,0.020493455,0.020038577,28.134468,34.891018,25.784292,53.76362,33.45856,50.650562,32.378338,0.5329004,0.4424238,0.12782517,0.37656787,0.3380472,0.39779198,0.5643148] 708 | max: [0.02,0.02,0.02,0.03,0.03,0.03,200.0,200.0,200.0,200.0,200.0,200.0,200.0,2.0,2.0,1.7018844,2.0,2.0,2.0,2.0] 709 | min: [-0.02,-0.02,-0.02,-0.03,-0.03,-0.03,10.0,10.0,10.0,10.0,10.0,10.0,10.0,0.29999995,0.29999995,0.29999995,0.29999995,0.29999995,0.29999995,0.29999995] 710 | image: 711 | image: [480, 640, 3] 712 | 713 | stanford_mask_vit_converted_externally_to_rlds: # seems low quality 714 | language: 715 | outer_key: language_instruction 716 | inner_key: 717 | action: 718 | outer_key: action 719 | inner_keys: [action] 720 | gripper_is_bool: True 721 | 722 | gripper_signal: binary 723 | gripper_close_is_positive: True 724 | index_mapping: [1,2,3,0,0,4,5] 725 | mean: [-5.7280602e-05,4.4461805e-05,0.00013061121,0,0,-0.0003851414,-0.17713885] 726 | std: [0.03302672,0.033040244,0.07856057,0,0,0.17140482,0.96742386] 727 | max: [0.07,0.07,0.35625914,0,0,0.6708251,1.0] 728 | min: [-0.07,-0.07,-0.38690937,0,0,-0.7456048,-1.0] 729 | image: 730 | image: [480, 480, 3] 731 | canonical_view: image 732 | 733 | tokyo_u_lsmo_converted_externally_to_rlds: 734 | language: 735 | outer_key: language_instruction 736 | inner_key: 737 | action: 738 | outer_key: action 739 | inner_keys: [action] 740 | gripper_is_bool: True 741 | gripper_signal: 742 | gripper_close_is_positive: True 743 | index_mapping: [1,2,3,4,5,6,7] 744 | mean: [0.0015662411,-0.000200291,-3.4290827e-06,1.3049686e-05,5.041861e-05,-0.0010225909,0.0] 745 | std: [0.0012936676,0.00078826846,0.0015917509,0.0037797757,0.005711879,0.11527428,0.0] 746 | max: [0.0042887772,0.0033740131,0.0041465242,0.01142326,0.014931569,6.2830777,0.0] 747 | min: [-0.003931232,-0.0026556132,-0.005801674,-0.011649237,-0.01700513,-6.282474,0.0] 748 | image: 749 | image: [120, 120, 3] 750 | canonical_view: image 751 | 752 | _dlr_sara_pour_converted_externally_to_rlds: 753 | language: 754 | outer_key: language_instruction 755 | inner_key: 756 | action: 757 | outer_key: action 758 | inner_keys: [action] 759 | gripper_is_bool: True 760 | gripper_close_is_positive: False 761 | index_mapping: [1,2,3,4,5,6,7] 762 | mean: [-1.779125e-05,-3.824358e-05,-0.00038651552,2.8173781e-05,7.860234e-05,6.662502e-05,1.0] 763 | std: [0.00042584693,0.00050736236,0.0012398473,0.0005687393,0.00075076945,0.00069576234,0.0] 764 | max: [0.0041252756,0.0034917803,0.0064525837,0.0065098777,0.006107756,0.006421665,1.0] 765 | min: [-0.004632457,-0.0057487926,-0.011109131,-0.004965867,-0.0076577803,-0.00692513,1.0] 766 | image: 767 | image: [480, 640, 3] 768 | canonical_view: image 769 | 770 | _dlr_edan_shared_control_converted_externally_to_rlds: # instructions within one episode not all the same 771 | language: 772 | outer_key: language_instruction 773 | inner_key: 774 | action: 775 | outer_key: action 776 | inner_keys: [action] 777 | gripper_is_bool: True 778 | gripper_signal: binary 779 | gripper_close_is_positive: True 780 | index_mapping: [1,2,3,4,5,6,7] 781 | mean: [0.0066478024,-0.0007657355,0.006522838,0.0011679777,-0.006395635,-0.011903042,0.3014113] 782 | std: [0.021393627,0.018142333,0.03374378,0.017435411,0.033943783,0.04641878,0.45885926] 783 | max: [0.18991442,0.07390025,0.1806482,0.08664861,0.13464981,0.1691028,1.0] 784 | min: [-0.10054297,-0.08427435,-0.13533439,-0.17556548,-0.18485673,-0.26806858,0.0] 785 | image: 786 | image: [360, 640, 3] 787 | canonical_view: 
image 788 | 789 | _asu_table_top_converted_externally_to_rlds: # instructions not all the same 790 | language: 791 | outer_key: language_instruction 792 | inner_key: 793 | action: 794 | outer_key: action 795 | inner_keys: [action] 796 | gripper_is_bool: False 797 | gripper_signal: abs 798 | gripper_close_is_positive: True 799 | index_mapping: [1,2,3,4,5,6,7] 800 | mean: [-0.04829057,0.19164251,0.09978654,2.460989,0.00046229397,1.5699155,0.28321108] 801 | std: [0.24690276,0.09880382,0.06263857,1.4282565,0.0068450016,0.011229702,0.32279274] 802 | max: [0.39995855,0.55227333,0.270501,3.1415906,0.039888673,1.631198,0.9596181] 803 | min: [-0.45636123,-0.0033472818,0.020711672,-3.1415923,-0.038345333,1.4695326,-0.2130372] 804 | image: 805 | image: [224, 224, 3] 806 | canonical_view: image 807 | 808 | _stanford_robocook_converted_externally_to_rlds: # bad image 809 | language: 810 | outer_key: language_instruction 811 | inner_key: 812 | action: 813 | outer_key: action 814 | inner_keys: [action] 815 | gripper_is_bool: False 816 | gripper_close_is_positive: False 817 | index_mapping: [1,2,3,4,5,6,7] 818 | mean: [-2.67842e-06,1.8303779e-06,-0.0013767683,-2.3712617e-05,-0.00010789655,-4.2797503e-05,-2.8979612e-05] 819 | std: [0.002275514,0.0023861625,0.010277068,0.0005761813,0.0015414119,0.0006268042,0.0027927894] 820 | max: [0.05666566,0.05632849,0.11884554,0.022299306,0.029122513,0.010286045,0.04290463] 821 | min: [-0.05887693,-0.056699876,-0.118049115,-0.012663912,-0.030290456,-0.012011405,-0.05652061] 822 | image: 823 | depth_1: [256, 256] 824 | depth_2: [256, 256] 825 | depth_3: [256, 256] 826 | depth_4: [256, 256] 827 | image_1: [256, 256, 3] 828 | image_2: [256, 256, 3] 829 | image_3: [256, 256, 3] 830 | image_4: [256, 256, 3] 831 | canonical_view: image_1 832 | 833 | imperialcollege_sawyer_wrist_cam: 834 | language: 835 | outer_key: language_instruction 836 | inner_key: 837 | action: 838 | outer_key: action 839 | inner_keys: [action] 840 | gripper_is_bool: True 841 | gripper_signal: binary 842 | gripper_close_is_positive: True 843 | index_mapping: [1,2,3,4,5,6,7] 844 | mean: [0.00023605324,-0.0009842712,0.00094666186,0.0011849315,-4.6923204e-05,1.4285401e-05,0.5726077] 845 | std: [0.0030598326,0.006771865,0.010962971,0.023719376,0.0032084314,0.003953903,0.49469307] 846 | max: [0.03886363,0.030030029,0.09125323,0.1702448,0.036239807,0.049347647,1.0] 847 | min: [-0.020447163,-0.05311591,-0.04699578,-0.13995367,-0.0328519,-0.055988327,0.0] 848 | image: 849 | image: [64, 64, 3] 850 | wrist_image: [64, 64, 3] 851 | canonical_view: image 852 | 853 | _iamlab_cmu_pickup_insert_converted_externally_to_rlds: 854 | language: 855 | outer_key: language_instruction 856 | inner_key: 857 | action: 858 | outer_key: action 859 | inner_keys: [action] 860 | abs_action: True 861 | gripper_is_bool: True 862 | gripper_close_is_positive: False 863 | index_mapping: [1,2,3,4,5,6,8] 864 | mean: [0.5274292,0.028582023,0.18712406,-0.0131298825,0.9998938,0.0036105025,0.5550632] 865 | std: [0.081083916,0.111675,0.07747591,0.016080942,0.00063183776,0.0078107985,0.49694073] 866 | max: [0.66349816,0.23428471,0.43082854,0.041561358,0.9999999,0.023352295,1.0] 867 | min: [0.3071657,-0.2975497,0.06578229,-0.06755125,0.9966192,-0.026384523,0.0] 868 | image: 869 | image: [360, 640, 3] 870 | wrist_image: [240, 320, 3] 871 | canonical_view: image 872 | 873 | _uiuc_d3field: 874 | language: 875 | outer_key: language_instruction 876 | inner_key: 877 | action: 878 | outer_key: action 879 | inner_keys: [action] 880 | abs_action: True 881 
| gripper_is_bool: None 882 | gripper_close_is_positive: None 883 | index_mapping: [1,2,3,0,0,0] 884 | min: [ -0.0151053965,-0.015266597,-0.015180364 ] 885 | max: [ 0.015263855,0.015294969,0.015344886 ] 886 | mean: [ 0.00018498691,1.10236315e-05,-0.00017510964 ] 887 | std: [ 0.0029065583,0.0025309222,0.002859444 ] 888 | image: 889 | depth_1: [360, 640, 1] 890 | depth_2: [360, 640, 1] 891 | depth_3: [360, 640, 1] 892 | depth_4: [360, 640, 3] 893 | image_1: [360, 640, 3] 894 | image_2: [360, 640, 3] 895 | image_3: [360, 640, 3] 896 | image_4: [360, 640, 3] 897 | canonical_view: image_1 898 | --------------------------------------------------------------------------------
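
For quick reference, the action-remapping and normalization convention that rt-x_data_cfg.yaml documents in the comments of the `_utaustin_mutex` entry (concatenate `step[outer_key][inner_keys]`, pad a zero at index 0, scatter into a 7-D target vector via `index_mapping`, then standardize with `mean`/`std`) can be sketched as below. This is a minimal illustration only, not the actual normalize_actions.py; the helper names and the 5-D example action are made up.

```
# A minimal sketch of the index_mapping / mean / std convention described in the
# rt-x_data_cfg.yaml comments; helper names and example values are illustrative.
import numpy as np


def remap_action(concated_action, index_mapping):
    # pad a zero at index 0, then scatter into the 7-D target layout
    padded = np.concatenate([[0.0], np.asarray(concated_action, dtype=np.float32)])
    target_action = np.zeros(shape=(7,), dtype=np.float32)
    for tgt_idx, src_idx in enumerate(index_mapping):
        target_action[tgt_idx] = padded[src_idx]
    return target_action


def zscore_normalize(target_action, mean, std):
    # (action - mean) / std, leaving dimensions with zero std untouched
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    safe_std = np.where(std == 0, 1.0, std)
    return (target_action - mean) / safe_std


# hypothetical 5-D source action [x, y, z, yaw, gripper] with the
# index_mapping used by e.g. the robo_net entry:
concated = [0.01, -0.02, 0.03, 0.1, 1.0]
canonical = remap_action(concated, index_mapping=[1, 2, 3, 0, 0, 4, 5])
# canonical -> [0.01, -0.02, 0.03, 0.0, 0.0, 0.1, 1.0]
```

The `gripper_is_bool` and `gripper_close_is_positive` flags describe how the last dimension is interpreted before this remapping, as noted in the inline comments of the `_utaustin_mutex` entry.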