├── .gitignore ├── LICENSE ├── README.md ├── generate_download_script.py ├── open_x_dataset_pytorch.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into 
this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | download.sh 163 | datasets 164 | test_* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Xiang Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch DataLoader for Open-X Embodiment Datasets 2 | 3 | An unofficial pytorch dataloader for [Open-X Embodiment Datasets](https://robotics-transformer-x.github.io/). 4 | 5 | This README will guide you to integrate the Open-X Embodiment Datasets into your PyTorch project. 
For a native TensorFlow experience, please check [the official repo](https://github.com/google-deepmind/open_x_embodiment). 6 | 7 | ## Download the datasets 8 | 9 | 1. Check available datasets and their corresponding metadata in [the dataset spreadsheet](https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit#gid=0) 10 | * **Warning** The images in `utokyo_saytap_converted_externally_to_rlds` seem to be corrupted. 11 | 2. Set your preferred download destination `download_dst` in [generate_download_script.py](generate_download_script.py) and confirm the datasets you want to download. By default, the Python script will create a shell script that downloads all 53 datasets, amounting to a total size of approximately 4.5TB. 12 | 3. Follow [this guide](https://cloud.google.com/storage/docs/gsutil_install#linux) to set up `gsutil`. 13 | 4. Generate the shell script and start the download: 14 | ``` 15 | python3 generate_download_script.py 16 | chmod +x download.sh 17 | ./download.sh 18 | ``` 19 | 20 | This section was last updated on 1/19/2024. 21 | 22 | ## Play with the dataloader 23 | 24 | 1. Install the Python dependencies: 25 | ``` 26 | pip3 install -r requirements.txt 27 | ``` 28 | 2. If your machine has enough RAM to hold the whole dataset, you can initialize the dataset with `class OpenXDataset(Dataset)` in `open_x_dataset_pytorch.py`. A quick example: 29 | 30 | ``` 31 | d = OpenXDataset( 32 | tf_dir='datasets/asu_table_top_converted_externally_to_rlds/0.1.0/', 33 | fetch_pattern=r'.*image.*', 34 | sample_length=8, 35 | ) 36 | print(d) 37 | ``` 38 | 39 | * `tf_dir`: full directory containing the downloaded dataset, including the version number. 40 | * `fetch_pattern`: regular expression that selects which keys to retrieve. Defaults to `r'steps*'`, which matches every key containing `step`. The example above only retrieves visual observations. 41 | * `sample_length`: number of transitions per sample. If set to `2`, the returned sample will be $[s_1, s_2]$.
42 | 43 | The last several lines of the output of the code above: 44 | ``` 45 | ========== 46 | Total episodes: 110 47 | Total samples: 1433503 given sample length: 8 48 | ========== 49 | Output keys: 50 | - steps/observation/image 51 | Masked keys: 52 | - steps/observation/state_vel 53 | - steps/ground_truth_states/bread 54 | - steps/is_first 55 | - steps/ground_truth_states/coke 56 | - steps/ground_truth_states/cube 57 | - steps/language_embedding 58 | - steps/is_terminal 59 | - steps/is_last 60 | - steps/discount 61 | - steps/ground_truth_states/EE 62 | - steps/language_instruction 63 | - steps/ground_truth_states/pepsi 64 | - steps/ground_truth_states/milk 65 | - steps/observation/state 66 | - steps/goal_object 67 | - steps/action 68 | - episode_metadata/file_path 69 | - steps/ground_truth_states/bottle 70 | - steps/action_delta 71 | - steps/action_inst 72 | - steps/reward 73 | ``` 74 | 75 | `__getitem__()` returns a dictionary whose keys are the feature keys matching `fetch_pattern`, plus the integer indices `transition_idx` and `episode_idx`. The associated value for each feature key will be either a tensor of size `(sample_length, *original feature shape)`[^1] or a list with `sample_length` elements. 76 | 77 | 3. If the machine does not have enough RAM, use `class IterableOpenXDataset(IterableDataset)` in `open_x_dataset_pytorch.py` instead. It takes the same parameters as `OpenXDataset`, but it does not report the total number of samples. 78 | ``` 79 | d = IterableOpenXDataset( 80 | tf_dir='datasets/asu_table_top_converted_externally_to_rlds/0.1.0/', 81 | fetch_pattern=r'.*image.*', 82 | sample_length=8, 83 | ) 84 | print(d) 85 | ``` 86 | 87 | ## TODO 88 | 1. Filter out the invalid episodes according to [the dataset format](https://github.com/google-research/rlds?tab=readme-ov-file#dataset-format) 89 | 90 | ## Acknowledgment 91 | 92 | I really appreciate the substantial open-sourcing effort contributed by the creators of this extensive dataset. 93 | Thanks to [Jinghuan Shang](https://elicassion.github.io/) for the valuable discussions.
94 | 95 | [^1]: When the feature is an image, the tensor will have a shape of `(sample_length, C, H, W)` instead. 96 | -------------------------------------------------------------------------------- /generate_download_script.py: -------------------------------------------------------------------------------- 1 | # Reference: 2 | # https://github.com/google-deepmind/open_x_embodiment/blob/main/colabs/Open_X_Embodiment_Datasets.ipynb 3 | # https://github.com/google-deepmind/open_x_embodiment/issues/5 4 | 5 | # 1. Set your preferred download destination 6 | download_dst='CHANGE_ME' 7 | 8 | # 2. Select the dataset you want to download 9 | # The list below contains all the available datasets to download 10 | # To skip a dataset, just comment out the corresponding line 11 | DATASETS = [ 12 | 'fractal20220817_data', 13 | 'kuka', 14 | 'bridge', 15 | 'taco_play', 16 | 'jaco_play', 17 | 'berkeley_cable_routing', 18 | 'roboturk', 19 | 'nyu_door_opening_surprising_effectiveness', 20 | 'viola', 21 | 'berkeley_autolab_ur5', 22 | 'toto', 23 | 'language_table', 24 | 'columbia_cairlab_pusht_real', 25 | 'stanford_kuka_multimodal_dataset_converted_externally_to_rlds', 26 | 'nyu_rot_dataset_converted_externally_to_rlds', 27 | 'stanford_hydra_dataset_converted_externally_to_rlds', 28 | 'austin_buds_dataset_converted_externally_to_rlds', 29 | 'nyu_franka_play_dataset_converted_externally_to_rlds', 30 | 'maniskill_dataset_converted_externally_to_rlds', 31 | 'cmu_franka_exploration_dataset_converted_externally_to_rlds', 32 | 'ucsd_kitchen_dataset_converted_externally_to_rlds', 33 | 'ucsd_pick_and_place_dataset_converted_externally_to_rlds', 34 | 'austin_sailor_dataset_converted_externally_to_rlds', 35 | 'austin_sirius_dataset_converted_externally_to_rlds', 36 | 'bc_z', 37 | 'usc_cloth_sim_converted_externally_to_rlds', 38 | 'utokyo_pr2_opening_fridge_converted_externally_to_rlds', 39 | 'utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds', 40 | 
'utokyo_saytap_converted_externally_to_rlds', 41 | 'utokyo_xarm_pick_and_place_converted_externally_to_rlds', 42 | 'utokyo_xarm_bimanual_converted_externally_to_rlds', 43 | 'robo_net', 44 | 'berkeley_mvp_converted_externally_to_rlds', 45 | 'berkeley_rpt_converted_externally_to_rlds', 46 | 'kaist_nonprehensile_converted_externally_to_rlds', 47 | 'stanford_mask_vit_converted_externally_to_rlds', 48 | 'tokyo_u_lsmo_converted_externally_to_rlds', 49 | 'dlr_sara_pour_converted_externally_to_rlds', 50 | 'dlr_sara_grid_clamp_converted_externally_to_rlds', 51 | 'dlr_edan_shared_control_converted_externally_to_rlds', 52 | 'asu_table_top_converted_externally_to_rlds', 53 | 'stanford_robocook_converted_externally_to_rlds', 54 | 'eth_agent_affordances', 55 | 'imperialcollege_sawyer_wrist_cam', 56 | 'iamlab_cmu_pickup_insert_converted_externally_to_rlds', 57 | 'uiuc_d3field', 58 | 'utaustin_mutex', 59 | 'berkeley_fanuc_manipulation', 60 | 'cmu_play_fusion', 61 | 'cmu_stretch', 62 | 'berkeley_gnm_recon', 63 | 'berkeley_gnm_cory_hall', 64 | 'berkeley_gnm_sac_son' 65 | ] 66 | 67 | 68 | if download_dst == 'CHANGE_ME': 69 | print('Please assign a valid download destination') 70 | exit(0) 71 | 72 | 73 | def dataset2path(dataset_name: str) -> str: 74 | if dataset_name == 'robo_net': 75 | version = '1.0.0' 76 | elif dataset_name == 'language_table': 77 | version = '0.0.1' 78 | else: 79 | version = '0.1.0' 80 | return f'gs://gresearch/robotics/{dataset_name}/{version}' 81 | 82 | 83 | with open('download.sh', 'w') as f: 84 | for dataset in DATASETS: 85 | f.write(f'gsutil -m cp -r {dataset2path(dataset)} {download_dst}\n') 86 | -------------------------------------------------------------------------------- /open_x_dataset_pytorch.py: -------------------------------------------------------------------------------- 1 | import tensorflow_datasets as tfds 2 | from PIL import Image 3 | import io 4 | import re 5 | 6 | import torch 7 | from torchdata.datapipes.iter import FileLister, FileOpener 
8 | from torch.utils.data import Dataset, IterableDataset 9 | from torchvision.transforms import ToTensor 10 | 11 | 12 | def parse_episode_using_meta(episode: dict, meta: tfds.core.features.FeaturesDict)-> dict: 13 | def fetch_key_in_meta(key: str) \ 14 | -> tfds.core.features.Tensor | tfds.core.features.Text | tfds.core.features.Image: 15 | iter_keys = key.split('/') 16 | key_meta = meta 17 | for key in iter_keys: 18 | key_meta = key_meta[key] 19 | return key_meta 20 | 21 | transform = ToTensor() # H, W, C -> C, H, W; 0-255 -> 0-1.0 22 | 23 | for key, value in episode.items(): 24 | key_meta = fetch_key_in_meta(key) 25 | # print(key, key_meta, type(value)) 26 | if isinstance(value, torch.Tensor): 27 | episode[key] = value.reshape(-1, *key_meta.shape) 28 | else: 29 | # another type should be 30 | # google.protobuf.pyext._message.RepeatedScalarContainer 31 | if isinstance(key_meta, tfds.core.features.Text): 32 | # decode text bytes (bytes -> str) 33 | episode[key] = [raw_bytes.decode('utf8') for raw_bytes in value] 34 | elif isinstance(key_meta, tfds.core.features.Image): 35 | # decode image bytes (PNG/JPG -> tensor) 36 | image_tensors = [transform(Image.open(io.BytesIO(raw_bytes))) for raw_bytes in value] 37 | image_tensors = torch.stack(image_tensors, dim=0) # B, C, H, W 38 | episode[key] = image_tensors 39 | else: 40 | raise ValueError(f'Unknown data type of {key} - {type(key_meta)}') 41 | return episode 42 | 43 | 44 | def read_from_tf(tf_dir: str): 45 | # Load metadata using tfds 46 | b = tfds.builder_from_directory(builder_dir=tf_dir) 47 | 48 | # Load data using torch 49 | datapipe1 = FileLister(tf_dir, "*.tfrecord*") 50 | datapipe2 = FileOpener(datapipe1, mode="b") 51 | tfrecord_loader_dp = datapipe2.load_from_tfrecord() 52 | 53 | return b.info, tfrecord_loader_dp 54 | 55 | 56 | def prepare_output_dict(transition_idx: int, episode_idx: int, episode: dict, sample_length: int, fetch_pattern: str): 57 | res = { 58 | key: value[transition_idx: transition_idx + 
sample_length] 59 | for key, value in episode.items() 60 | if re.search(fetch_pattern, key) 61 | } 62 | res['transition_idx'] = transition_idx 63 | res['episode_idx'] = episode_idx 64 | return res 65 | 66 | 67 | # Caution: this class will load the whole dataset to your RAM! 68 | # For a huge dataset, please check IterateDataset instead. 69 | class OpenXDataset(Dataset): 70 | def __init__(self, 71 | tf_dir: str, 72 | fetch_pattern: str=r'steps*', 73 | sample_length: int=2) -> None: 74 | """ 75 | Args: 76 | tf_dir (str): full directory containing the downloaded dataset, including the version number. 77 | fetch_pattern (str, optional): regular expression utilized to specify the data you wish to retrieve. Defaults to r'steps*'. 78 | sample_length (int, optional): number of transitions per sample. Defaults to 2. 79 | """ 80 | super().__init__() 81 | 82 | # fetch all the keys that match this regular expression 83 | self.fetch_pattern = fetch_pattern 84 | 85 | assert sample_length > 0, \ 86 | 'The number of transitions in each sample must be larger than 0' 87 | self.sample_length = sample_length 88 | 89 | self.info, tfrecord_loader_dp = read_from_tf(tf_dir) 90 | self.meta = self.info.features 91 | 92 | # container for all the data 93 | self.episodes: list[dict] = [] 94 | # sample index -> (index of the episode, index within the episode) 95 | self.episode_idx: list[tuple[int, int]] = [] 96 | 97 | episode_start = 0 98 | for i, episode in enumerate(tfrecord_loader_dp): 99 | # TODO: filter out invalid episodes using metadata 100 | self.episodes.append( 101 | parse_episode_using_meta(episode, self.meta) 102 | ) 103 | # count the total number of samples we can extract from an episode 104 | episode_length = episode['steps/is_first'].shape[0] 105 | samples_per_episode = episode_length - sample_length + 1 106 | episode_start += samples_per_episode 107 | self.episode_idx.extend([(i, j) for j in range(samples_per_episode)]) 108 | 109 | assert len(self.episodes), 'Fail to load any 
episodes' 110 | self.episode_keys = list(self.episodes[0].keys()) 111 | 112 | def __repr__(self) -> str: 113 | result = self.info.__repr__() 114 | result += '\n' + '=' * 10 115 | result += f'\nTotal episodes: {len(self.episodes)}' 116 | result += f'\nTotal samples: {self.__len__()} given sample length: {self.sample_length}' 117 | result += '\n' + '=' * 10 118 | 119 | output_keys: list[str] = [] 120 | masked_keys: list[str] = [] 121 | for key in self.episodes[0]: 122 | if re.search(self.fetch_pattern, key): 123 | output_keys.append(key) 124 | else: 125 | masked_keys.append(key) 126 | 127 | result += f'\nOutput keys:\n - ' + '\n - '.join(output_keys) 128 | result += f'\nMasked keys:\n - ' + '\n - '.join(masked_keys) 129 | return result 130 | 131 | def __len__(self) -> int: 132 | return len(self.episode_idx) 133 | 134 | def __getitem__(self, index: int) -> dict: 135 | episode_idx, transition_idx = self.episode_idx[index] 136 | return prepare_output_dict( 137 | transition_idx, 138 | episode_idx, 139 | self.episodes[episode_idx], 140 | self.sample_length, 141 | self.fetch_pattern 142 | ) 143 | 144 | class IterableOpenXDataset(IterableDataset): 145 | def __init__(self, 146 | tf_dir: str, 147 | fetch_pattern: str=r'steps*', 148 | sample_length: int=2) -> None: 149 | """ 150 | Args: 151 | tf_dir (str): full directory containing the downloaded dataset, including the version number. 152 | fetch_pattern (str, optional): regular expression utilized to specify the data you wish to retrieve. Defaults to r'steps*'. 153 | sample_length (int, optional): number of transitions per sample. Defaults to 2. 
154 | """ 155 | super().__init__() 156 | 157 | # fetch all the keys that match this regular expression 158 | self.fetch_pattern = fetch_pattern 159 | 160 | assert sample_length > 0, \ 161 | 'The number of transitions in each sample must be larger than 0' 162 | self.sample_length = sample_length 163 | 164 | self.info, self.tfrecord_loader_dp = read_from_tf(tf_dir) 165 | self.meta = self.info.features 166 | 167 | def __iter__(self): 168 | worker_info = torch.utils.data.get_worker_info() 169 | 170 | if worker_info is None: # single-process data loading, return the full iterator 171 | num_workers = 1 172 | worker_id = 0 173 | else: # in a worker process 174 | num_workers = worker_info.num_workers 175 | worker_id = worker_info.id 176 | 177 | example_id = 0 178 | for episode_idx, episode in enumerate(self.tfrecord_loader_dp): 179 | # TODO: filter out invalid episodes using metadata 180 | converted_episode = parse_episode_using_meta(episode, self.meta) 181 | # count the total number of samples we can extract from an episode 182 | episode_length = episode['steps/is_first'].shape[0] 183 | samples_per_episode = episode_length - self.sample_length + 1 184 | for transition_idx in range(samples_per_episode): 185 | if example_id % num_workers == worker_id: 186 | yield prepare_output_dict( 187 | transition_idx, 188 | episode_idx, 189 | converted_episode, 190 | self.sample_length, 191 | self.fetch_pattern 192 | ) 193 | example_id += 1 194 | 195 | def __repr__(self) -> str: 196 | result = self.info.__repr__() 197 | result += '\n' + '=' * 10 198 | 199 | output_keys: list[str] = [] 200 | masked_keys: list[str] = [] 201 | 202 | for episode in self.tfrecord_loader_dp: 203 | for key in episode: 204 | if re.search(self.fetch_pattern, key): 205 | output_keys.append(key) 206 | else: 207 | masked_keys.append(key) 208 | 209 | result += f'\nOutput keys:\n - ' + '\n - '.join(output_keys) 210 | result += f'\nMasked keys:\n - ' + '\n - '.join(masked_keys) 211 | return result 212 | 213 | 214 | if 
__name__ == '__main__': 215 | d = OpenXDataset( 216 | tf_dir='datasets/asu_table_top_converted_externally_to_rlds/0.1.0/', 217 | fetch_pattern=r'.*image.*', 218 | sample_length=8, 219 | ) 220 | print(d) 221 | from torch.utils.data import DataLoader 222 | from tqdm import tqdm 223 | dl = DataLoader(d, batch_size=12, shuffle=True) 224 | for batch in tqdm(dl): 225 | for k, v in batch.items(): 226 | print(k, v.shape if isinstance(v, torch.Tensor) else len(v)) 227 | break 228 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow_datasets 2 | pillow 3 | torch 4 | torchdata==0.9.0 5 | --------------------------------------------------------------------------------