├── .gitignore ├── LIBERO_10 ├── CITATIONS.bib ├── LIBERO_10_dataset_builder.py ├── README.md ├── __init__.py └── conversion_utils.py ├── LIBERO_Goal ├── CITATIONS.bib ├── LIBERO_Goal_dataset_builder.py ├── README.md ├── __init__.py └── conversion_utils.py ├── LIBERO_Object ├── CITATIONS.bib ├── LIBERO_Object_dataset_builder.py ├── README.md ├── __init__.py └── conversion_utils.py ├── LIBERO_Spatial ├── CITATIONS.bib ├── LIBERO_Spatial_dataset_builder.py ├── README.md ├── __init__.py └── conversion_utils.py ├── LICENSE ├── README.md ├── aloha1_put_X_into_pot_300_demos ├── CITATIONS.bib ├── README.md ├── __init__.py ├── aloha1_put_X_into_pot_300_demos_dataset_builder.py └── conversion_utils.py ├── environment_macos.yml ├── environment_ubuntu.yml ├── example_dataset ├── CITATIONS.bib ├── README.md ├── __init__.py ├── create_example_data.py └── example_dataset_dataset_builder.py ├── example_transform └── transform.py ├── setup.py ├── test_dataset_transform.py └── visualize_dataset.py /.gitignore: -------------------------------------------------------------------------------- 1 | */data 2 | wandb 3 | __pycache__ 4 | .idea 5 | -------------------------------------------------------------------------------- /LIBERO_10/CITATIONS.bib: -------------------------------------------------------------------------------- 1 | // TODO(example_dataset): BibTeX citation 2 | -------------------------------------------------------------------------------- /LIBERO_10/LIBERO_10_dataset_builder.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Tuple, Any 2 | 3 | import os 4 | import h5py 5 | import glob 6 | import numpy as np 7 | import tensorflow as tf 8 | import tensorflow_datasets as tfds 9 | import sys 10 | from LIBERO_10.conversion_utils import MultiThreadedDatasetBuilder 11 | 12 | 13 | def _generate_examples(paths) -> Iterator[Tuple[str, Any]]: 14 | """Yields episodes for list of data paths.""" 15 | # the line 
below needs to be *inside* generate_examples so that each worker creates it's own model 16 | # creating one shared model outside this function would cause a deadlock 17 | 18 | def _parse_example(episode_path, demo_id): 19 | # load raw data 20 | with h5py.File(episode_path, "r") as F: 21 | if f"demo_{demo_id}" not in F['data'].keys(): 22 | return None # skip episode if the demo doesn't exist (e.g. due to failed demo) 23 | actions = F['data'][f"demo_{demo_id}"]["actions"][()] 24 | states = F['data'][f"demo_{demo_id}"]["obs"]["ee_states"][()] 25 | gripper_states = F['data'][f"demo_{demo_id}"]["obs"]["gripper_states"][()] 26 | joint_states = F['data'][f"demo_{demo_id}"]["obs"]["joint_states"][()] 27 | images = F['data'][f"demo_{demo_id}"]["obs"]["agentview_rgb"][()] 28 | wrist_images = F['data'][f"demo_{demo_id}"]["obs"]["eye_in_hand_rgb"][()] 29 | 30 | # compute language instruction 31 | raw_file_string = os.path.basename(episode_path).split('/')[-1] 32 | words = raw_file_string[:-10].split("_") 33 | command = '' 34 | for w in words: 35 | if "SCENE" in w: 36 | command = '' 37 | continue 38 | command = command + w + ' ' 39 | command = command[:-1] 40 | 41 | # assemble episode --> here we're assuming demos so we set reward to 1 at the end 42 | episode = [] 43 | for i in range(actions.shape[0]): 44 | episode.append({ 45 | 'observation': { 46 | 'image': images[i][::-1,::-1], 47 | 'wrist_image': wrist_images[i][::-1,::-1], 48 | 'state': np.asarray(np.concatenate((states[i], gripper_states[i]), axis=-1), np.float32), 49 | 'joint_state': np.asarray(joint_states[i], dtype=np.float32), 50 | }, 51 | 'action': np.asarray(actions[i], dtype=np.float32), 52 | 'discount': 1.0, 53 | 'reward': float(i == (actions.shape[0] - 1)), 54 | 'is_first': i == 0, 55 | 'is_last': i == (actions.shape[0] - 1), 56 | 'is_terminal': i == (actions.shape[0] - 1), 57 | 'language_instruction': command, 58 | }) 59 | 60 | # create output data sample 61 | sample = { 62 | 'steps': episode, 63 | 
'episode_metadata': { 64 | 'file_path': episode_path 65 | } 66 | } 67 | 68 | # if you want to skip an example for whatever reason, simply return None 69 | return episode_path + f"_{demo_id}", sample 70 | 71 | # for smallish datasets, use single-thread parsing 72 | for sample in paths: 73 | with h5py.File(sample, "r") as F: 74 | n_demos = len(F['data']) 75 | idx = 0 76 | cnt = 0 77 | while cnt < n_demos: 78 | ret = _parse_example(sample, idx) 79 | if ret is not None: 80 | cnt += 1 81 | idx += 1 82 | yield ret 83 | 84 | 85 | class LIBERO10(MultiThreadedDatasetBuilder): 86 | """DatasetBuilder for example dataset.""" 87 | 88 | VERSION = tfds.core.Version('1.0.0') 89 | RELEASE_NOTES = { 90 | '1.0.0': 'Initial release.', 91 | } 92 | N_WORKERS = 40 # number of parallel workers for data conversion 93 | MAX_PATHS_IN_MEMORY = 80 # number of paths converted & stored in memory before writing to disk 94 | # -> the higher the faster / more parallel conversion, adjust based on avilable RAM 95 | # note that one path may yield multiple episodes and adjust accordingly 96 | PARSE_FCN = _generate_examples # handle to parse function from file paths to RLDS episodes 97 | 98 | def _info(self) -> tfds.core.DatasetInfo: 99 | """Dataset metadata (homepage, citation,...).""" 100 | return self.dataset_info_from_configs( 101 | features=tfds.features.FeaturesDict({ 102 | 'steps': tfds.features.Dataset({ 103 | 'observation': tfds.features.FeaturesDict({ 104 | 'image': tfds.features.Image( 105 | shape=(256, 256, 3), 106 | dtype=np.uint8, 107 | encoding_format='jpeg', 108 | doc='Main camera RGB observation.', 109 | ), 110 | 'wrist_image': tfds.features.Image( 111 | shape=(256, 256, 3), 112 | dtype=np.uint8, 113 | encoding_format='jpeg', 114 | doc='Wrist camera RGB observation.', 115 | ), 116 | 'state': tfds.features.Tensor( 117 | shape=(8,), 118 | dtype=np.float32, 119 | doc='Robot EEF state (6D pose, 2D gripper).', 120 | ), 121 | 'joint_state': tfds.features.Tensor( 122 | shape=(7,), 123 | 
dtype=np.float32, 124 | doc='Robot joint angles.', 125 | ) 126 | }), 127 | 'action': tfds.features.Tensor( 128 | shape=(7,), 129 | dtype=np.float32, 130 | doc='Robot EEF action.', 131 | ), 132 | 'discount': tfds.features.Scalar( 133 | dtype=np.float32, 134 | doc='Discount if provided, default to 1.' 135 | ), 136 | 'reward': tfds.features.Scalar( 137 | dtype=np.float32, 138 | doc='Reward if provided, 1 on final step for demos.' 139 | ), 140 | 'is_first': tfds.features.Scalar( 141 | dtype=np.bool_, 142 | doc='True on first step of the episode.' 143 | ), 144 | 'is_last': tfds.features.Scalar( 145 | dtype=np.bool_, 146 | doc='True on last step of the episode.' 147 | ), 148 | 'is_terminal': tfds.features.Scalar( 149 | dtype=np.bool_, 150 | doc='True on last step of the episode if it is a terminal step, True for demos.' 151 | ), 152 | 'language_instruction': tfds.features.Text( 153 | doc='Language Instruction.' 154 | ), 155 | }), 156 | 'episode_metadata': tfds.features.FeaturesDict({ 157 | 'file_path': tfds.features.Text( 158 | doc='Path to the original data file.' 159 | ), 160 | }), 161 | })) 162 | 163 | def _split_paths(self): 164 | """Define filepaths for data splits.""" 165 | return { 166 | "train": glob.glob("/PATH/TO/LIBERO/libero/datasets/libero_10_no_noops/*.hdf5"), 167 | } 168 | -------------------------------------------------------------------------------- /LIBERO_10/README.md: -------------------------------------------------------------------------------- 1 | TODO(example_dataset): Markdown description of your dataset. 2 | Description is **formatted** as markdown. 3 | 4 | It should also contain any processing which has been applied (if any), 5 | (e.g. 
corrupted example skipped, images cropped,...): 6 | -------------------------------------------------------------------------------- /LIBERO_10/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moojink/rlds_dataset_builder/6174b0b6bb69df6361f1117944952bf14afb0cc3/LIBERO_10/__init__.py -------------------------------------------------------------------------------- /LIBERO_10/conversion_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any, Dict, Union, Callable, Iterable 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow_datasets as tfds 5 | 6 | import itertools 7 | from multiprocessing import Pool 8 | from functools import partial 9 | from tensorflow_datasets.core import download 10 | from tensorflow_datasets.core import split_builder as split_builder_lib 11 | from tensorflow_datasets.core import naming 12 | from tensorflow_datasets.core import splits as splits_lib 13 | from tensorflow_datasets.core import utils 14 | from tensorflow_datasets.core import writer as writer_lib 15 | from tensorflow_datasets.core import example_serializer 16 | from tensorflow_datasets.core import dataset_builder 17 | from tensorflow_datasets.core import file_adapters 18 | 19 | Key = Union[str, int] 20 | # The nested example dict passed to `features.encode_example` 21 | Example = Dict[str, Any] 22 | KeyExample = Tuple[Key, Example] 23 | 24 | 25 | class MultiThreadedDatasetBuilder(tfds.core.GeneratorBasedBuilder): 26 | """DatasetBuilder for example dataset.""" 27 | N_WORKERS = 10 # number of parallel workers for data conversion 28 | MAX_PATHS_IN_MEMORY = 100 # number of paths converted & stored in memory before writing to disk 29 | # -> the higher the faster / more parallel conversion, adjust based on avilable RAM 30 | # note that one path may yield multiple episodes and adjust accordingly 31 | PARSE_FCN = None # needs 
to be filled with path-to-record-episode parse function 32 | 33 | def _split_generators(self, dl_manager: tfds.download.DownloadManager): 34 | """Define data splits.""" 35 | split_paths = self._split_paths() 36 | return {split: type(self).PARSE_FCN(paths=split_paths[split]) for split in split_paths} 37 | 38 | def _generate_examples(self): 39 | pass # this is implemented in global method to enable multiprocessing 40 | 41 | def _download_and_prepare( # pytype: disable=signature-mismatch # overriding-parameter-type-checks 42 | self, 43 | dl_manager: download.DownloadManager, 44 | download_config: download.DownloadConfig, 45 | ) -> None: 46 | """Generate all splits and returns the computed split infos.""" 47 | assert self.PARSE_FCN is not None # need to overwrite parse function 48 | split_builder = ParallelSplitBuilder( 49 | split_dict=self.info.splits, 50 | features=self.info.features, 51 | dataset_size=self.info.dataset_size, 52 | max_examples_per_split=download_config.max_examples_per_split, 53 | beam_options=download_config.beam_options, 54 | beam_runner=download_config.beam_runner, 55 | file_format=self.info.file_format, 56 | shard_config=download_config.get_shard_config(), 57 | split_paths=self._split_paths(), 58 | parse_function=type(self).PARSE_FCN, 59 | n_workers=self.N_WORKERS, 60 | max_paths_in_memory=self.MAX_PATHS_IN_MEMORY, 61 | ) 62 | split_generators = self._split_generators(dl_manager) 63 | split_generators = split_builder.normalize_legacy_split_generators( 64 | split_generators=split_generators, 65 | generator_fn=self._generate_examples, 66 | is_beam=False, 67 | ) 68 | dataset_builder._check_split_names(split_generators.keys()) 69 | 70 | # Start generating data for all splits 71 | path_suffix = file_adapters.ADAPTER_FOR_FORMAT[ 72 | self.info.file_format 73 | ].FILE_SUFFIX 74 | 75 | split_info_futures = [] 76 | for split_name, generator in utils.tqdm( 77 | split_generators.items(), 78 | desc="Generating splits...", 79 | unit=" splits", 80 | 
def parse_examples_from_generator(paths, fcn, split_name, total_num_examples, features, serializer):
    """Encode and serialize every example produced by ``fcn(paths)``.

    Args:
        paths: Passed unchanged to ``fcn``.
        fcn: Callable returning an iterable of ``(key, example)`` tuples;
            ``None`` items are skipped.
        split_name: Used only in the progress-bar description.
        total_num_examples: Forwarded to the progress bar as ``total``.
        features: Object whose ``encode_example`` is applied to each example.
        serializer: Object whose ``serialize_example`` is applied to each
            encoded example.

    Returns:
        List of ``(key, serialized_example)`` tuples in generation order.
    """
    progress = utils.tqdm(
        fcn(paths),
        desc=f'Generating {split_name} examples...',
        unit=' examples',
        total=total_num_examples,
        leave=False,
        mininterval=1.0,
    )
    serialized = []
    for item in progress:
        if item is not None:
            key, example = item
            try:
                example = features.encode_example(example)
            except Exception as e:  # pylint: disable=broad-except
                # Attach the offending example to the error before it propagates.
                utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
            serialized.append((key, serializer.serialize_example(example)))
    return serialized
= split_paths 139 | self._parse_function = parse_function 140 | self._n_workers = n_workers 141 | self._max_paths_in_memory = max_paths_in_memory 142 | 143 | def _build_from_generator( 144 | self, 145 | split_name: str, 146 | generator: Iterable[KeyExample], 147 | filename_template: naming.ShardedFileTemplate, 148 | disable_shuffling: bool, 149 | ) -> _SplitInfoFuture: 150 | """Split generator for example generators. 151 | 152 | Args: 153 | split_name: str, 154 | generator: Iterable[KeyExample], 155 | filename_template: Template to format the filename for a shard. 156 | disable_shuffling: Specifies whether to shuffle the examples, 157 | 158 | Returns: 159 | future: The future containing the `tfds.core.SplitInfo`. 160 | """ 161 | total_num_examples = None 162 | serialized_info = self._features.get_serialized_info() 163 | writer = writer_lib.Writer( 164 | serializer=example_serializer.ExampleSerializer(serialized_info), 165 | filename_template=filename_template, 166 | hash_salt=split_name, 167 | disable_shuffling=disable_shuffling, 168 | file_format=self._file_format, 169 | shard_config=self._shard_config, 170 | ) 171 | 172 | del generator # use parallel generators instead 173 | paths = self._split_paths[split_name] 174 | path_lists = chunk_max(paths, self._n_workers, self._max_paths_in_memory) # generate N file lists 175 | print(f"Generating with {self._n_workers} workers!") 176 | pool = Pool(processes=self._n_workers) 177 | for i, paths in enumerate(path_lists): 178 | print(f"Processing chunk {i + 1} of {len(path_lists)}.") 179 | results = pool.map( 180 | partial( 181 | parse_examples_from_generator, 182 | fcn=self._parse_function, 183 | split_name=split_name, 184 | total_num_examples=total_num_examples, 185 | serializer=writer._serializer, 186 | features=self._features 187 | ), 188 | paths 189 | ) 190 | # write results to shuffler --> this will automatically offload to disk if necessary 191 | print("Writing conversion results...") 192 | for result in 
def dictlist2listdict(DL):
    """Turn a dict of equally indexed lists into a list of dicts.

    The result has one dict per position; it is as long as the shortest list.
    """
    keys = list(DL)
    return [
        {key: value for key, value in zip(keys, row)}
        for row in zip(*DL.values())
    ]


def chunks(l, n):
    """Yield ``n`` consecutive slices of ``l`` whose lengths differ by at most one.

    The first ``len(l) % n`` slices get the extra element; slices may be empty
    when ``n`` exceeds ``len(l)``.
    """
    base, extra = divmod(len(l), n)
    start = 0
    for i in range(n):
        stop = start + base + (1 if i < extra else 0)
        yield l[start:stop]
        start = stop


def chunk_max(l, n, max_chunk_sum):
    """Group ``l`` into batches of at most ``max_chunk_sum`` items, each split ``n`` ways.

    Returns:
        A list with one entry per batch; every entry is the list of ``n``
        slices produced by ``chunks`` for that batch.
    """
    n_batches = int(np.ceil(len(l) / max_chunk_sum))
    return [
        list(chunks(l[b * max_chunk_sum:(b + 1) * max_chunk_sum], n))
        for b in range(n_batches)
    ]
the line below needs to be *inside* generate_examples so that each worker creates it's own model 16 | # creating one shared model outside this function would cause a deadlock 17 | 18 | def _parse_example(episode_path, demo_id): 19 | # load raw data 20 | with h5py.File(episode_path, "r") as F: 21 | if f"demo_{demo_id}" not in F['data'].keys(): 22 | return None # skip episode if the demo doesn't exist (e.g. due to failed demo) 23 | actions = F['data'][f"demo_{demo_id}"]["actions"][()] 24 | states = F['data'][f"demo_{demo_id}"]["obs"]["ee_states"][()] 25 | gripper_states = F['data'][f"demo_{demo_id}"]["obs"]["gripper_states"][()] 26 | joint_states = F['data'][f"demo_{demo_id}"]["obs"]["joint_states"][()] 27 | images = F['data'][f"demo_{demo_id}"]["obs"]["agentview_rgb"][()] 28 | wrist_images = F['data'][f"demo_{demo_id}"]["obs"]["eye_in_hand_rgb"][()] 29 | 30 | # compute language instruction 31 | raw_file_string = os.path.basename(episode_path).split('/')[-1] 32 | words = raw_file_string[:-10].split("_") 33 | command = '' 34 | for w in words: 35 | if "SCENE" in w: 36 | command = '' 37 | continue 38 | command = command + w + ' ' 39 | command = command[:-1] 40 | 41 | # assemble episode --> here we're assuming demos so we set reward to 1 at the end 42 | episode = [] 43 | for i in range(actions.shape[0]): 44 | episode.append({ 45 | 'observation': { 46 | 'image': images[i][::-1,::-1], 47 | 'wrist_image': wrist_images[i][::-1,::-1], 48 | 'state': np.asarray(np.concatenate((states[i], gripper_states[i]), axis=-1), np.float32), 49 | 'joint_state': np.asarray(joint_states[i], dtype=np.float32), 50 | }, 51 | 'action': np.asarray(actions[i], dtype=np.float32), 52 | 'discount': 1.0, 53 | 'reward': float(i == (actions.shape[0] - 1)), 54 | 'is_first': i == 0, 55 | 'is_last': i == (actions.shape[0] - 1), 56 | 'is_terminal': i == (actions.shape[0] - 1), 57 | 'language_instruction': command, 58 | }) 59 | 60 | # create output data sample 61 | sample = { 62 | 'steps': episode, 63 | 
class LIBEROGoal(MultiThreadedDatasetBuilder):
    """Dataset builder that converts the LIBERO-Goal HDF5 demos into RLDS episodes."""

    VERSION = tfds.core.Version('1.0.0')
    RELEASE_NOTES = {
        '1.0.0': 'Initial release.',
    }
    # Number of parallel workers used for data conversion.
    N_WORKERS = 40
    # Number of paths converted and held in memory before writing to disk.
    # Higher values mean faster, more parallel conversion; adjust to the
    # available RAM, keeping in mind that one path may yield several episodes.
    MAX_PATHS_IN_MEMORY = 80
    # Module-level function that turns file paths into RLDS episodes.
    PARSE_FCN = _generate_examples

    def _info(self) -> tfds.core.DatasetInfo:
        """Describe the features stored per step and per episode."""

        def rgb_image(doc):
            # Both cameras share the same shape, dtype and encoding.
            return tfds.features.Image(
                shape=(256, 256, 3),
                dtype=np.uint8,
                encoding_format='jpeg',
                doc=doc,
            )

        observation = tfds.features.FeaturesDict({
            'image': rgb_image('Main camera RGB observation.'),
            'wrist_image': rgb_image('Wrist camera RGB observation.'),
            'state': tfds.features.Tensor(
                shape=(8,),
                dtype=np.float32,
                doc='Robot EEF state (6D pose, 2D gripper).',
            ),
            'joint_state': tfds.features.Tensor(
                shape=(7,),
                dtype=np.float32,
                doc='Robot joint angles.',
            ),
        })

        steps = tfds.features.Dataset({
            'observation': observation,
            'action': tfds.features.Tensor(
                shape=(7,),
                dtype=np.float32,
                doc='Robot EEF action.',
            ),
            'discount': tfds.features.Scalar(
                dtype=np.float32,
                doc='Discount if provided, default to 1.'
            ),
            'reward': tfds.features.Scalar(
                dtype=np.float32,
                doc='Reward if provided, 1 on final step for demos.'
            ),
            'is_first': tfds.features.Scalar(
                dtype=np.bool_,
                doc='True on first step of the episode.'
            ),
            'is_last': tfds.features.Scalar(
                dtype=np.bool_,
                doc='True on last step of the episode.'
            ),
            'is_terminal': tfds.features.Scalar(
                dtype=np.bool_,
                doc='True on last step of the episode if it is a terminal step, True for demos.'
            ),
            'language_instruction': tfds.features.Text(
                doc='Language Instruction.'
            ),
        })

        episode_metadata = tfds.features.FeaturesDict({
            'file_path': tfds.features.Text(
                doc='Path to the original data file.'
            ),
        })

        features = tfds.features.FeaturesDict({
            'steps': steps,
            'episode_metadata': episode_metadata,
        })
        return self.dataset_info_from_configs(features=features)

    def _split_paths(self):
        """Map each split name to the list of HDF5 files that feed it."""
        train_files = glob.glob("/PATH/TO/LIBERO/libero/datasets/libero_goal_no_noops/*.hdf5")
        return {
            "train": train_files,
        }
corrupted example skipped, images cropped,...): 6 | -------------------------------------------------------------------------------- /LIBERO_Goal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moojink/rlds_dataset_builder/6174b0b6bb69df6361f1117944952bf14afb0cc3/LIBERO_Goal/__init__.py -------------------------------------------------------------------------------- /LIBERO_Goal/conversion_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any, Dict, Union, Callable, Iterable 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow_datasets as tfds 5 | 6 | import itertools 7 | from multiprocessing import Pool 8 | from functools import partial 9 | from tensorflow_datasets.core import download 10 | from tensorflow_datasets.core import split_builder as split_builder_lib 11 | from tensorflow_datasets.core import naming 12 | from tensorflow_datasets.core import splits as splits_lib 13 | from tensorflow_datasets.core import utils 14 | from tensorflow_datasets.core import writer as writer_lib 15 | from tensorflow_datasets.core import example_serializer 16 | from tensorflow_datasets.core import dataset_builder 17 | from tensorflow_datasets.core import file_adapters 18 | 19 | Key = Union[str, int] 20 | # The nested example dict passed to `features.encode_example` 21 | Example = Dict[str, Any] 22 | KeyExample = Tuple[Key, Example] 23 | 24 | 25 | class MultiThreadedDatasetBuilder(tfds.core.GeneratorBasedBuilder): 26 | """DatasetBuilder for example dataset.""" 27 | N_WORKERS = 10 # number of parallel workers for data conversion 28 | MAX_PATHS_IN_MEMORY = 100 # number of paths converted & stored in memory before writing to disk 29 | # -> the higher the faster / more parallel conversion, adjust based on avilable RAM 30 | # note that one path may yield multiple episodes and adjust accordingly 31 | PARSE_FCN = None # 
needs to be filled with path-to-record-episode parse function 32 | 33 | def _split_generators(self, dl_manager: tfds.download.DownloadManager): 34 | """Define data splits.""" 35 | split_paths = self._split_paths() 36 | return {split: type(self).PARSE_FCN(paths=split_paths[split]) for split in split_paths} 37 | 38 | def _generate_examples(self): 39 | pass # this is implemented in global method to enable multiprocessing 40 | 41 | def _download_and_prepare( # pytype: disable=signature-mismatch # overriding-parameter-type-checks 42 | self, 43 | dl_manager: download.DownloadManager, 44 | download_config: download.DownloadConfig, 45 | ) -> None: 46 | """Generate all splits and returns the computed split infos.""" 47 | assert self.PARSE_FCN is not None # need to overwrite parse function 48 | split_builder = ParallelSplitBuilder( 49 | split_dict=self.info.splits, 50 | features=self.info.features, 51 | dataset_size=self.info.dataset_size, 52 | max_examples_per_split=download_config.max_examples_per_split, 53 | beam_options=download_config.beam_options, 54 | beam_runner=download_config.beam_runner, 55 | file_format=self.info.file_format, 56 | shard_config=download_config.get_shard_config(), 57 | split_paths=self._split_paths(), 58 | parse_function=type(self).PARSE_FCN, 59 | n_workers=self.N_WORKERS, 60 | max_paths_in_memory=self.MAX_PATHS_IN_MEMORY, 61 | ) 62 | split_generators = self._split_generators(dl_manager) 63 | split_generators = split_builder.normalize_legacy_split_generators( 64 | split_generators=split_generators, 65 | generator_fn=self._generate_examples, 66 | is_beam=False, 67 | ) 68 | dataset_builder._check_split_names(split_generators.keys()) 69 | 70 | # Start generating data for all splits 71 | path_suffix = file_adapters.ADAPTER_FOR_FORMAT[ 72 | self.info.file_format 73 | ].FILE_SUFFIX 74 | 75 | split_info_futures = [] 76 | for split_name, generator in utils.tqdm( 77 | split_generators.items(), 78 | desc="Generating splits...", 79 | unit=" splits", 80 | 
class ParallelSplitBuilder(split_builder_lib.SplitBuilder):
    """Split builder that parses source paths in a multiprocessing pool.

    The generator handed to ``_build_from_generator`` is ignored; examples are
    produced by calling ``parse_function`` on chunks of the split's paths in
    worker processes instead.
    """

    def __init__(self, *args, split_paths, parse_function, n_workers, max_paths_in_memory, **kwargs):
        """Store the parallel-conversion settings.

        Args:
            *args: Forwarded to ``SplitBuilder.__init__``.
            split_paths: Mapping from split name to the list of source paths.
            parse_function: Callable that turns a list of paths into an
                iterable of ``(key, example)`` tuples.
            n_workers: Number of worker processes.
            max_paths_in_memory: Maximum number of paths handled per pool
                round.
            **kwargs: Forwarded to ``SplitBuilder.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._split_paths = split_paths
        self._parse_function = parse_function
        self._n_workers = n_workers
        self._max_paths_in_memory = max_paths_in_memory

    def _build_from_generator(
        self,
        split_name: str,
        generator: Iterable[KeyExample],
        filename_template: naming.ShardedFileTemplate,
        disable_shuffling: bool,
    ) -> _SplitInfoFuture:
        """Convert one split in parallel and write it through a TFDS writer.

        Args:
            split_name: Name of the split; also used as the writer's hash salt.
            generator: Ignored; kept for signature compatibility.
            filename_template: Template to format the filename for a shard.
            disable_shuffling: Passed through to the writer.

        Returns:
            A future that resolves to the ``SplitInfo`` of the written split.
        """
        total_num_examples = None
        serialized_info = self._features.get_serialized_info()
        writer = writer_lib.Writer(
            serializer=example_serializer.ExampleSerializer(serialized_info),
            filename_template=filename_template,
            hash_salt=split_name,
            disable_shuffling=disable_shuffling,
            file_format=self._file_format,
            shard_config=self._shard_config,
        )

        del generator  # use parallel generators instead
        split_paths = self._split_paths[split_name]
        # One entry per pool round; each entry holds one path list per worker.
        path_lists = chunk_max(split_paths, self._n_workers, self._max_paths_in_memory)
        print(f"Generating with {self._n_workers} workers!")

        # The worker call is the same for every round, so build it once.
        parse_chunk = partial(
            parse_examples_from_generator,
            fcn=self._parse_function,
            split_name=split_name,
            total_num_examples=total_num_examples,
            serializer=writer._serializer,
            features=self._features,
        )

        # The context manager shuts the workers down even when a round raises,
        # so a failed conversion no longer leaves worker processes behind.
        with Pool(processes=self._n_workers) as pool:
            for i, worker_paths in enumerate(path_lists):
                print(f"Processing chunk {i + 1} of {len(path_lists)}.")
                results = pool.map(parse_chunk, worker_paths)
                # Hand the results to the writer's shuffler (the original
                # author notes it offloads to disk when necessary).
                print("Writing conversion results...")
                for key, serialized_example in itertools.chain.from_iterable(results):
                    writer._shuffler.add(key, serialized_example)
                    writer._num_examples += 1

        print("Finishing split conversion...")
        shard_lengths, total_size = writer.finalize()

        split_info = splits_lib.SplitInfo(
            name=split_name,
            shard_lengths=shard_lengths,
            num_bytes=total_size,
            filename_template=filename_template,
        )
        return _SplitInfoFuture(lambda: split_info)
15 | # the line below needs to be *inside* generate_examples so that each worker creates it's own model 16 | # creating one shared model outside this function would cause a deadlock 17 | 18 | def _parse_example(episode_path, demo_id): 19 | # load raw data 20 | with h5py.File(episode_path, "r") as F: 21 | if f"demo_{demo_id}" not in F['data'].keys(): 22 | return None # skip episode if the demo doesn't exist (e.g. due to failed demo) 23 | actions = F['data'][f"demo_{demo_id}"]["actions"][()] 24 | states = F['data'][f"demo_{demo_id}"]["obs"]["ee_states"][()] 25 | gripper_states = F['data'][f"demo_{demo_id}"]["obs"]["gripper_states"][()] 26 | joint_states = F['data'][f"demo_{demo_id}"]["obs"]["joint_states"][()] 27 | images = F['data'][f"demo_{demo_id}"]["obs"]["agentview_rgb"][()] 28 | wrist_images = F['data'][f"demo_{demo_id}"]["obs"]["eye_in_hand_rgb"][()] 29 | 30 | # compute language instruction 31 | raw_file_string = os.path.basename(episode_path).split('/')[-1] 32 | words = raw_file_string[:-10].split("_") 33 | command = '' 34 | for w in words: 35 | if "SCENE" in w: 36 | command = '' 37 | continue 38 | command = command + w + ' ' 39 | command = command[:-1] 40 | 41 | # assemble episode --> here we're assuming demos so we set reward to 1 at the end 42 | episode = [] 43 | for i in range(actions.shape[0]): 44 | episode.append({ 45 | 'observation': { 46 | 'image': images[i][::-1,::-1], 47 | 'wrist_image': wrist_images[i][::-1,::-1], 48 | 'state': np.asarray(np.concatenate((states[i], gripper_states[i]), axis=-1), np.float32), 49 | 'joint_state': np.asarray(joint_states[i], dtype=np.float32), 50 | }, 51 | 'action': np.asarray(actions[i], dtype=np.float32), 52 | 'discount': 1.0, 53 | 'reward': float(i == (actions.shape[0] - 1)), 54 | 'is_first': i == 0, 55 | 'is_last': i == (actions.shape[0] - 1), 56 | 'is_terminal': i == (actions.shape[0] - 1), 57 | 'language_instruction': command, 58 | }) 59 | 60 | # create output data sample 61 | sample = { 62 | 'steps': episode, 63 
| 'episode_metadata': { 64 | 'file_path': episode_path 65 | } 66 | } 67 | 68 | # if you want to skip an example for whatever reason, simply return None 69 | return episode_path + f"_{demo_id}", sample 70 | 71 | # for smallish datasets, use single-thread parsing 72 | for sample in paths: 73 | with h5py.File(sample, "r") as F: 74 | n_demos = len(F['data']) 75 | idx = 0 76 | cnt = 0 77 | while cnt < n_demos: 78 | ret = _parse_example(sample, idx) 79 | if ret is not None: 80 | cnt += 1 81 | idx += 1 82 | yield ret 83 | 84 | 85 | class LIBEROObject(MultiThreadedDatasetBuilder): 86 | """DatasetBuilder for example dataset.""" 87 | 88 | VERSION = tfds.core.Version('1.0.0') 89 | RELEASE_NOTES = { 90 | '1.0.0': 'Initial release.', 91 | } 92 | N_WORKERS = 40 # number of parallel workers for data conversion 93 | MAX_PATHS_IN_MEMORY = 80 # number of paths converted & stored in memory before writing to disk 94 | # -> the higher the faster / more parallel conversion, adjust based on avilable RAM 95 | # note that one path may yield multiple episodes and adjust accordingly 96 | PARSE_FCN = _generate_examples # handle to parse function from file paths to RLDS episodes 97 | 98 | def _info(self) -> tfds.core.DatasetInfo: 99 | """Dataset metadata (homepage, citation,...).""" 100 | return self.dataset_info_from_configs( 101 | features=tfds.features.FeaturesDict({ 102 | 'steps': tfds.features.Dataset({ 103 | 'observation': tfds.features.FeaturesDict({ 104 | 'image': tfds.features.Image( 105 | shape=(256, 256, 3), 106 | dtype=np.uint8, 107 | encoding_format='jpeg', 108 | doc='Main camera RGB observation.', 109 | ), 110 | 'wrist_image': tfds.features.Image( 111 | shape=(256, 256, 3), 112 | dtype=np.uint8, 113 | encoding_format='jpeg', 114 | doc='Wrist camera RGB observation.', 115 | ), 116 | 'state': tfds.features.Tensor( 117 | shape=(8,), 118 | dtype=np.float32, 119 | doc='Robot EEF state (6D pose, 2D gripper).', 120 | ), 121 | 'joint_state': tfds.features.Tensor( 122 | shape=(7,), 123 | 
dtype=np.float32, 124 | doc='Robot joint angles.', 125 | ) 126 | }), 127 | 'action': tfds.features.Tensor( 128 | shape=(7,), 129 | dtype=np.float32, 130 | doc='Robot EEF action.', 131 | ), 132 | 'discount': tfds.features.Scalar( 133 | dtype=np.float32, 134 | doc='Discount if provided, default to 1.' 135 | ), 136 | 'reward': tfds.features.Scalar( 137 | dtype=np.float32, 138 | doc='Reward if provided, 1 on final step for demos.' 139 | ), 140 | 'is_first': tfds.features.Scalar( 141 | dtype=np.bool_, 142 | doc='True on first step of the episode.' 143 | ), 144 | 'is_last': tfds.features.Scalar( 145 | dtype=np.bool_, 146 | doc='True on last step of the episode.' 147 | ), 148 | 'is_terminal': tfds.features.Scalar( 149 | dtype=np.bool_, 150 | doc='True on last step of the episode if it is a terminal step, True for demos.' 151 | ), 152 | 'language_instruction': tfds.features.Text( 153 | doc='Language Instruction.' 154 | ), 155 | }), 156 | 'episode_metadata': tfds.features.FeaturesDict({ 157 | 'file_path': tfds.features.Text( 158 | doc='Path to the original data file.' 159 | ), 160 | }), 161 | })) 162 | 163 | def _split_paths(self): 164 | """Define filepaths for data splits.""" 165 | return { 166 | "train": glob.glob("/PATH/TO/LIBERO/libero/datasets/libero_object_no_noops/*.hdf5"), 167 | } 168 | -------------------------------------------------------------------------------- /LIBERO_Object/README.md: -------------------------------------------------------------------------------- 1 | TODO(example_dataset): Markdown description of your dataset. 2 | Description is **formatted** as markdown. 3 | 4 | It should also contain any processing which has been applied (if any), 5 | (e.g. 
corrupted example skipped, images cropped,...): 6 | -------------------------------------------------------------------------------- /LIBERO_Object/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moojink/rlds_dataset_builder/6174b0b6bb69df6361f1117944952bf14afb0cc3/LIBERO_Object/__init__.py -------------------------------------------------------------------------------- /LIBERO_Object/conversion_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any, Dict, Union, Callable, Iterable 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow_datasets as tfds 5 | 6 | import itertools 7 | from multiprocessing import Pool 8 | from functools import partial 9 | from tensorflow_datasets.core import download 10 | from tensorflow_datasets.core import split_builder as split_builder_lib 11 | from tensorflow_datasets.core import naming 12 | from tensorflow_datasets.core import splits as splits_lib 13 | from tensorflow_datasets.core import utils 14 | from tensorflow_datasets.core import writer as writer_lib 15 | from tensorflow_datasets.core import example_serializer 16 | from tensorflow_datasets.core import dataset_builder 17 | from tensorflow_datasets.core import file_adapters 18 | 19 | Key = Union[str, int] 20 | # The nested example dict passed to `features.encode_example` 21 | Example = Dict[str, Any] 22 | KeyExample = Tuple[Key, Example] 23 | 24 | 25 | class MultiThreadedDatasetBuilder(tfds.core.GeneratorBasedBuilder): 26 | """DatasetBuilder for example dataset.""" 27 | N_WORKERS = 10 # number of parallel workers for data conversion 28 | MAX_PATHS_IN_MEMORY = 100 # number of paths converted & stored in memory before writing to disk 29 | # -> the higher the faster / more parallel conversion, adjust based on avilable RAM 30 | # note that one path may yield multiple episodes and adjust accordingly 31 | PARSE_FCN = 
None # needs to be filled with path-to-record-episode parse function 32 | 33 | def _split_generators(self, dl_manager: tfds.download.DownloadManager): 34 | """Define data splits.""" 35 | split_paths = self._split_paths() 36 | return {split: type(self).PARSE_FCN(paths=split_paths[split]) for split in split_paths} 37 | 38 | def _generate_examples(self): 39 | pass # this is implemented in global method to enable multiprocessing 40 | 41 | def _download_and_prepare( # pytype: disable=signature-mismatch # overriding-parameter-type-checks 42 | self, 43 | dl_manager: download.DownloadManager, 44 | download_config: download.DownloadConfig, 45 | ) -> None: 46 | """Generate all splits and returns the computed split infos.""" 47 | assert self.PARSE_FCN is not None # need to overwrite parse function 48 | split_builder = ParallelSplitBuilder( 49 | split_dict=self.info.splits, 50 | features=self.info.features, 51 | dataset_size=self.info.dataset_size, 52 | max_examples_per_split=download_config.max_examples_per_split, 53 | beam_options=download_config.beam_options, 54 | beam_runner=download_config.beam_runner, 55 | file_format=self.info.file_format, 56 | shard_config=download_config.get_shard_config(), 57 | split_paths=self._split_paths(), 58 | parse_function=type(self).PARSE_FCN, 59 | n_workers=self.N_WORKERS, 60 | max_paths_in_memory=self.MAX_PATHS_IN_MEMORY, 61 | ) 62 | split_generators = self._split_generators(dl_manager) 63 | split_generators = split_builder.normalize_legacy_split_generators( 64 | split_generators=split_generators, 65 | generator_fn=self._generate_examples, 66 | is_beam=False, 67 | ) 68 | dataset_builder._check_split_names(split_generators.keys()) 69 | 70 | # Start generating data for all splits 71 | path_suffix = file_adapters.ADAPTER_FOR_FORMAT[ 72 | self.info.file_format 73 | ].FILE_SUFFIX 74 | 75 | split_info_futures = [] 76 | for split_name, generator in utils.tqdm( 77 | split_generators.items(), 78 | desc="Generating splits...", 79 | unit=" splits", 80 
| leave=False, 81 | ): 82 | filename_template = naming.ShardedFileTemplate( 83 | split=split_name, 84 | dataset_name=self.name, 85 | data_dir=self.data_path, 86 | filetype_suffix=path_suffix, 87 | ) 88 | future = split_builder.submit_split_generation( 89 | split_name=split_name, 90 | generator=generator, 91 | filename_template=filename_template, 92 | disable_shuffling=self.info.disable_shuffling, 93 | ) 94 | split_info_futures.append(future) 95 | 96 | # Finalize the splits (after apache beam completed, if it was used) 97 | split_infos = [future.result() for future in split_info_futures] 98 | 99 | # Update the info object with the splits. 100 | split_dict = splits_lib.SplitDict(split_infos) 101 | self.info.set_splits(split_dict) 102 | 103 | 104 | class _SplitInfoFuture: 105 | """Future containing the `tfds.core.SplitInfo` result.""" 106 | 107 | def __init__(self, callback: Callable[[], splits_lib.SplitInfo]): 108 | self._callback = callback 109 | 110 | def result(self) -> splits_lib.SplitInfo: 111 | return self._callback() 112 | 113 | 114 | def parse_examples_from_generator(paths, fcn, split_name, total_num_examples, features, serializer): 115 | generator = fcn(paths) 116 | outputs = [] 117 | for sample in utils.tqdm( 118 | generator, 119 | desc=f'Generating {split_name} examples...', 120 | unit=' examples', 121 | total=total_num_examples, 122 | leave=False, 123 | mininterval=1.0, 124 | ): 125 | if sample is None: continue 126 | key, example = sample 127 | try: 128 | example = features.encode_example(example) 129 | except Exception as e: # pylint: disable=broad-except 130 | utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n') 131 | outputs.append((key, serializer.serialize_example(example))) 132 | return outputs 133 | 134 | 135 | class ParallelSplitBuilder(split_builder_lib.SplitBuilder): 136 | def __init__(self, *args, split_paths, parse_function, n_workers, max_paths_in_memory, **kwargs): 137 | super().__init__(*args, **kwargs) 138 | 
self._split_paths = split_paths 139 | self._parse_function = parse_function 140 | self._n_workers = n_workers 141 | self._max_paths_in_memory = max_paths_in_memory 142 | 143 | def _build_from_generator( 144 | self, 145 | split_name: str, 146 | generator: Iterable[KeyExample], 147 | filename_template: naming.ShardedFileTemplate, 148 | disable_shuffling: bool, 149 | ) -> _SplitInfoFuture: 150 | """Split generator for example generators. 151 | 152 | Args: 153 | split_name: str, 154 | generator: Iterable[KeyExample], 155 | filename_template: Template to format the filename for a shard. 156 | disable_shuffling: Specifies whether to shuffle the examples, 157 | 158 | Returns: 159 | future: The future containing the `tfds.core.SplitInfo`. 160 | """ 161 | total_num_examples = None 162 | serialized_info = self._features.get_serialized_info() 163 | writer = writer_lib.Writer( 164 | serializer=example_serializer.ExampleSerializer(serialized_info), 165 | filename_template=filename_template, 166 | hash_salt=split_name, 167 | disable_shuffling=disable_shuffling, 168 | file_format=self._file_format, 169 | shard_config=self._shard_config, 170 | ) 171 | 172 | del generator # use parallel generators instead 173 | paths = self._split_paths[split_name] 174 | path_lists = chunk_max(paths, self._n_workers, self._max_paths_in_memory) # generate N file lists 175 | print(f"Generating with {self._n_workers} workers!") 176 | pool = Pool(processes=self._n_workers) 177 | for i, paths in enumerate(path_lists): 178 | print(f"Processing chunk {i + 1} of {len(path_lists)}.") 179 | results = pool.map( 180 | partial( 181 | parse_examples_from_generator, 182 | fcn=self._parse_function, 183 | split_name=split_name, 184 | total_num_examples=total_num_examples, 185 | serializer=writer._serializer, 186 | features=self._features 187 | ), 188 | paths 189 | ) 190 | # write results to shuffler --> this will automatically offload to disk if necessary 191 | print("Writing conversion results...") 192 | for 
result in itertools.chain(*results): 193 | key, serialized_example = result 194 | writer._shuffler.add(key, serialized_example) 195 | writer._num_examples += 1 196 | pool.close() 197 | 198 | print("Finishing split conversion...") 199 | shard_lengths, total_size = writer.finalize() 200 | 201 | split_info = splits_lib.SplitInfo( 202 | name=split_name, 203 | shard_lengths=shard_lengths, 204 | num_bytes=total_size, 205 | filename_template=filename_template, 206 | ) 207 | return _SplitInfoFuture(lambda: split_info) 208 | 209 | 210 | def dictlist2listdict(DL): 211 | " Converts a dict of lists to a list of dicts " 212 | return [dict(zip(DL, t)) for t in zip(*DL.values())] 213 | 214 | def chunks(l, n): 215 | """Yield n number of sequential chunks from l.""" 216 | d, r = divmod(len(l), n) 217 | for i in range(n): 218 | si = (d + 1) * (i if i < r else r) + d * (0 if i < r else i - r) 219 | yield l[si:si + (d + 1 if i < r else d)] 220 | 221 | def chunk_max(l, n, max_chunk_sum): 222 | out = [] 223 | for _ in range(int(np.ceil(len(l) / max_chunk_sum))): 224 | out.append(list(chunks(l[:max_chunk_sum], n))) 225 | l = l[max_chunk_sum:] 226 | return out -------------------------------------------------------------------------------- /LIBERO_Spatial/CITATIONS.bib: -------------------------------------------------------------------------------- 1 | // TODO(example_dataset): BibTeX citation 2 | -------------------------------------------------------------------------------- /LIBERO_Spatial/LIBERO_Spatial_dataset_builder.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Tuple, Any 2 | 3 | import os 4 | import h5py 5 | import glob 6 | import numpy as np 7 | import tensorflow as tf 8 | import tensorflow_datasets as tfds 9 | import sys 10 | from LIBERO_Spatial.conversion_utils import MultiThreadedDatasetBuilder 11 | 12 | 13 | def _generate_examples(paths) -> Iterator[Tuple[str, Any]]: 14 | """Yields episodes for list of 
data paths.""" 15 | # the line below needs to be *inside* generate_examples so that each worker creates it's own model 16 | # creating one shared model outside this function would cause a deadlock 17 | 18 | def _parse_example(episode_path, demo_id): 19 | # load raw data 20 | with h5py.File(episode_path, "r") as F: 21 | if f"demo_{demo_id}" not in F['data'].keys(): 22 | return None # skip episode if the demo doesn't exist (e.g. due to failed demo) 23 | actions = F['data'][f"demo_{demo_id}"]["actions"][()] 24 | states = F['data'][f"demo_{demo_id}"]["obs"]["ee_states"][()] 25 | gripper_states = F['data'][f"demo_{demo_id}"]["obs"]["gripper_states"][()] 26 | joint_states = F['data'][f"demo_{demo_id}"]["obs"]["joint_states"][()] 27 | images = F['data'][f"demo_{demo_id}"]["obs"]["agentview_rgb"][()] 28 | wrist_images = F['data'][f"demo_{demo_id}"]["obs"]["eye_in_hand_rgb"][()] 29 | 30 | # compute language instruction 31 | raw_file_string = os.path.basename(episode_path).split('/')[-1] 32 | words = raw_file_string[:-10].split("_") 33 | command = '' 34 | for w in words: 35 | if "SCENE" in w: 36 | command = '' 37 | continue 38 | command = command + w + ' ' 39 | command = command[:-1] 40 | 41 | # assemble episode --> here we're assuming demos so we set reward to 1 at the end 42 | episode = [] 43 | for i in range(actions.shape[0]): 44 | episode.append({ 45 | 'observation': { 46 | 'image': images[i][::-1,::-1], 47 | 'wrist_image': wrist_images[i][::-1,::-1], 48 | 'state': np.asarray(np.concatenate((states[i], gripper_states[i]), axis=-1), np.float32), 49 | 'joint_state': np.asarray(joint_states[i], dtype=np.float32), 50 | }, 51 | 'action': np.asarray(actions[i], dtype=np.float32), 52 | 'discount': 1.0, 53 | 'reward': float(i == (actions.shape[0] - 1)), 54 | 'is_first': i == 0, 55 | 'is_last': i == (actions.shape[0] - 1), 56 | 'is_terminal': i == (actions.shape[0] - 1), 57 | 'language_instruction': command, 58 | }) 59 | 60 | # create output data sample 61 | sample = { 62 | 
'steps': episode, 63 | 'episode_metadata': { 64 | 'file_path': episode_path 65 | } 66 | } 67 | 68 | # if you want to skip an example for whatever reason, simply return None 69 | return episode_path + f"_{demo_id}", sample 70 | 71 | # for smallish datasets, use single-thread parsing 72 | for sample in paths: 73 | with h5py.File(sample, "r") as F: 74 | n_demos = len(F['data']) 75 | idx = 0 76 | cnt = 0 77 | while cnt < n_demos: 78 | ret = _parse_example(sample, idx) 79 | if ret is not None: 80 | cnt += 1 81 | idx += 1 82 | yield ret 83 | 84 | 85 | class LIBEROSpatial(MultiThreadedDatasetBuilder): 86 | """DatasetBuilder for example dataset.""" 87 | 88 | VERSION = tfds.core.Version('1.0.0') 89 | RELEASE_NOTES = { 90 | '1.0.0': 'Initial release.', 91 | } 92 | N_WORKERS = 40 # number of parallel workers for data conversion 93 | MAX_PATHS_IN_MEMORY = 80 # number of paths converted & stored in memory before writing to disk 94 | # -> the higher the faster / more parallel conversion, adjust based on avilable RAM 95 | # note that one path may yield multiple episodes and adjust accordingly 96 | PARSE_FCN = _generate_examples # handle to parse function from file paths to RLDS episodes 97 | 98 | def _info(self) -> tfds.core.DatasetInfo: 99 | """Dataset metadata (homepage, citation,...).""" 100 | return self.dataset_info_from_configs( 101 | features=tfds.features.FeaturesDict({ 102 | 'steps': tfds.features.Dataset({ 103 | 'observation': tfds.features.FeaturesDict({ 104 | 'image': tfds.features.Image( 105 | shape=(256, 256, 3), 106 | dtype=np.uint8, 107 | encoding_format='jpeg', 108 | doc='Main camera RGB observation.', 109 | ), 110 | 'wrist_image': tfds.features.Image( 111 | shape=(256, 256, 3), 112 | dtype=np.uint8, 113 | encoding_format='jpeg', 114 | doc='Wrist camera RGB observation.', 115 | ), 116 | 'state': tfds.features.Tensor( 117 | shape=(8,), 118 | dtype=np.float32, 119 | doc='Robot EEF state (6D pose, 2D gripper).', 120 | ), 121 | 'joint_state': tfds.features.Tensor( 
122 | shape=(7,), 123 | dtype=np.float32, 124 | doc='Robot joint angles.', 125 | ) 126 | }), 127 | 'action': tfds.features.Tensor( 128 | shape=(7,), 129 | dtype=np.float32, 130 | doc='Robot EEF action.', 131 | ), 132 | 'discount': tfds.features.Scalar( 133 | dtype=np.float32, 134 | doc='Discount if provided, default to 1.' 135 | ), 136 | 'reward': tfds.features.Scalar( 137 | dtype=np.float32, 138 | doc='Reward if provided, 1 on final step for demos.' 139 | ), 140 | 'is_first': tfds.features.Scalar( 141 | dtype=np.bool_, 142 | doc='True on first step of the episode.' 143 | ), 144 | 'is_last': tfds.features.Scalar( 145 | dtype=np.bool_, 146 | doc='True on last step of the episode.' 147 | ), 148 | 'is_terminal': tfds.features.Scalar( 149 | dtype=np.bool_, 150 | doc='True on last step of the episode if it is a terminal step, True for demos.' 151 | ), 152 | 'language_instruction': tfds.features.Text( 153 | doc='Language Instruction.' 154 | ), 155 | }), 156 | 'episode_metadata': tfds.features.FeaturesDict({ 157 | 'file_path': tfds.features.Text( 158 | doc='Path to the original data file.' 159 | ), 160 | }), 161 | })) 162 | 163 | def _split_paths(self): 164 | """Define filepaths for data splits.""" 165 | return { 166 | "train": glob.glob("/PATH/TO/LIBERO/libero/datasets/libero_spatial_no_noops/*.hdf5"), 167 | } 168 | -------------------------------------------------------------------------------- /LIBERO_Spatial/README.md: -------------------------------------------------------------------------------- 1 | TODO(example_dataset): Markdown description of your dataset. 2 | Description is **formatted** as markdown. 3 | 4 | It should also contain any processing which has been applied (if any), 5 | (e.g. 
corrupted example skipped, images cropped,...): 6 | -------------------------------------------------------------------------------- /LIBERO_Spatial/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moojink/rlds_dataset_builder/6174b0b6bb69df6361f1117944952bf14afb0cc3/LIBERO_Spatial/__init__.py -------------------------------------------------------------------------------- /LIBERO_Spatial/conversion_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any, Dict, Union, Callable, Iterable 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow_datasets as tfds 5 | 6 | import itertools 7 | from multiprocessing import Pool 8 | from functools import partial 9 | from tensorflow_datasets.core import download 10 | from tensorflow_datasets.core import split_builder as split_builder_lib 11 | from tensorflow_datasets.core import naming 12 | from tensorflow_datasets.core import splits as splits_lib 13 | from tensorflow_datasets.core import utils 14 | from tensorflow_datasets.core import writer as writer_lib 15 | from tensorflow_datasets.core import example_serializer 16 | from tensorflow_datasets.core import dataset_builder 17 | from tensorflow_datasets.core import file_adapters 18 | 19 | Key = Union[str, int] 20 | # The nested example dict passed to `features.encode_example` 21 | Example = Dict[str, Any] 22 | KeyExample = Tuple[Key, Example] 23 | 24 | 25 | class MultiThreadedDatasetBuilder(tfds.core.GeneratorBasedBuilder): 26 | """DatasetBuilder for example dataset.""" 27 | N_WORKERS = 10 # number of parallel workers for data conversion 28 | MAX_PATHS_IN_MEMORY = 100 # number of paths converted & stored in memory before writing to disk 29 | # -> the higher the faster / more parallel conversion, adjust based on avilable RAM 30 | # note that one path may yield multiple episodes and adjust accordingly 31 | PARSE_FCN 
= None # needs to be filled with path-to-record-episode parse function 32 | 33 | def _split_generators(self, dl_manager: tfds.download.DownloadManager): 34 | """Define data splits.""" 35 | split_paths = self._split_paths() 36 | return {split: type(self).PARSE_FCN(paths=split_paths[split]) for split in split_paths} 37 | 38 | def _generate_examples(self): 39 | pass # this is implemented in global method to enable multiprocessing 40 | 41 | def _download_and_prepare( # pytype: disable=signature-mismatch # overriding-parameter-type-checks 42 | self, 43 | dl_manager: download.DownloadManager, 44 | download_config: download.DownloadConfig, 45 | ) -> None: 46 | """Generate all splits and returns the computed split infos.""" 47 | assert self.PARSE_FCN is not None # need to overwrite parse function 48 | split_builder = ParallelSplitBuilder( 49 | split_dict=self.info.splits, 50 | features=self.info.features, 51 | dataset_size=self.info.dataset_size, 52 | max_examples_per_split=download_config.max_examples_per_split, 53 | beam_options=download_config.beam_options, 54 | beam_runner=download_config.beam_runner, 55 | file_format=self.info.file_format, 56 | shard_config=download_config.get_shard_config(), 57 | split_paths=self._split_paths(), 58 | parse_function=type(self).PARSE_FCN, 59 | n_workers=self.N_WORKERS, 60 | max_paths_in_memory=self.MAX_PATHS_IN_MEMORY, 61 | ) 62 | split_generators = self._split_generators(dl_manager) 63 | split_generators = split_builder.normalize_legacy_split_generators( 64 | split_generators=split_generators, 65 | generator_fn=self._generate_examples, 66 | is_beam=False, 67 | ) 68 | dataset_builder._check_split_names(split_generators.keys()) 69 | 70 | # Start generating data for all splits 71 | path_suffix = file_adapters.ADAPTER_FOR_FORMAT[ 72 | self.info.file_format 73 | ].FILE_SUFFIX 74 | 75 | split_info_futures = [] 76 | for split_name, generator in utils.tqdm( 77 | split_generators.items(), 78 | desc="Generating splits...", 79 | unit=" splits", 
80 | leave=False, 81 | ): 82 | filename_template = naming.ShardedFileTemplate( 83 | split=split_name, 84 | dataset_name=self.name, 85 | data_dir=self.data_path, 86 | filetype_suffix=path_suffix, 87 | ) 88 | future = split_builder.submit_split_generation( 89 | split_name=split_name, 90 | generator=generator, 91 | filename_template=filename_template, 92 | disable_shuffling=self.info.disable_shuffling, 93 | ) 94 | split_info_futures.append(future) 95 | 96 | # Finalize the splits (after apache beam completed, if it was used) 97 | split_infos = [future.result() for future in split_info_futures] 98 | 99 | # Update the info object with the splits. 100 | split_dict = splits_lib.SplitDict(split_infos) 101 | self.info.set_splits(split_dict) 102 | 103 | 104 | class _SplitInfoFuture: 105 | """Future containing the `tfds.core.SplitInfo` result.""" 106 | 107 | def __init__(self, callback: Callable[[], splits_lib.SplitInfo]): 108 | self._callback = callback 109 | 110 | def result(self) -> splits_lib.SplitInfo: 111 | return self._callback() 112 | 113 | 114 | def parse_examples_from_generator(paths, fcn, split_name, total_num_examples, features, serializer): 115 | generator = fcn(paths) 116 | outputs = [] 117 | for sample in utils.tqdm( 118 | generator, 119 | desc=f'Generating {split_name} examples...', 120 | unit=' examples', 121 | total=total_num_examples, 122 | leave=False, 123 | mininterval=1.0, 124 | ): 125 | if sample is None: continue 126 | key, example = sample 127 | try: 128 | example = features.encode_example(example) 129 | except Exception as e: # pylint: disable=broad-except 130 | utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n') 131 | outputs.append((key, serializer.serialize_example(example))) 132 | return outputs 133 | 134 | 135 | class ParallelSplitBuilder(split_builder_lib.SplitBuilder): 136 | def __init__(self, *args, split_paths, parse_function, n_workers, max_paths_in_memory, **kwargs): 137 | super().__init__(*args, **kwargs) 138 | 
self._split_paths = split_paths 139 | self._parse_function = parse_function 140 | self._n_workers = n_workers 141 | self._max_paths_in_memory = max_paths_in_memory 142 | 143 | def _build_from_generator( 144 | self, 145 | split_name: str, 146 | generator: Iterable[KeyExample], 147 | filename_template: naming.ShardedFileTemplate, 148 | disable_shuffling: bool, 149 | ) -> _SplitInfoFuture: 150 | """Split generator for example generators. 151 | 152 | Args: 153 | split_name: str, 154 | generator: Iterable[KeyExample], 155 | filename_template: Template to format the filename for a shard. 156 | disable_shuffling: Specifies whether to shuffle the examples, 157 | 158 | Returns: 159 | future: The future containing the `tfds.core.SplitInfo`. 160 | """ 161 | total_num_examples = None 162 | serialized_info = self._features.get_serialized_info() 163 | writer = writer_lib.Writer( 164 | serializer=example_serializer.ExampleSerializer(serialized_info), 165 | filename_template=filename_template, 166 | hash_salt=split_name, 167 | disable_shuffling=disable_shuffling, 168 | file_format=self._file_format, 169 | shard_config=self._shard_config, 170 | ) 171 | 172 | del generator # use parallel generators instead 173 | paths = self._split_paths[split_name] 174 | path_lists = chunk_max(paths, self._n_workers, self._max_paths_in_memory) # generate N file lists 175 | print(f"Generating with {self._n_workers} workers!") 176 | pool = Pool(processes=self._n_workers) 177 | for i, paths in enumerate(path_lists): 178 | print(f"Processing chunk {i + 1} of {len(path_lists)}.") 179 | results = pool.map( 180 | partial( 181 | parse_examples_from_generator, 182 | fcn=self._parse_function, 183 | split_name=split_name, 184 | total_num_examples=total_num_examples, 185 | serializer=writer._serializer, 186 | features=self._features 187 | ), 188 | paths 189 | ) 190 | # write results to shuffler --> this will automatically offload to disk if necessary 191 | print("Writing conversion results...") 192 | for 
result in itertools.chain(*results): 193 | key, serialized_example = result 194 | writer._shuffler.add(key, serialized_example) 195 | writer._num_examples += 1 196 | pool.close() 197 | 198 | print("Finishing split conversion...") 199 | shard_lengths, total_size = writer.finalize() 200 | 201 | split_info = splits_lib.SplitInfo( 202 | name=split_name, 203 | shard_lengths=shard_lengths, 204 | num_bytes=total_size, 205 | filename_template=filename_template, 206 | ) 207 | return _SplitInfoFuture(lambda: split_info) 208 | 209 | 210 | def dictlist2listdict(DL): 211 | " Converts a dict of lists to a list of dicts " 212 | return [dict(zip(DL, t)) for t in zip(*DL.values())] 213 | 214 | def chunks(l, n): 215 | """Yield n number of sequential chunks from l.""" 216 | d, r = divmod(len(l), n) 217 | for i in range(n): 218 | si = (d + 1) * (i if i < r else r) + d * (0 if i < r else i - r) 219 | yield l[si:si + (d + 1 if i < r else d)] 220 | 221 | def chunk_max(l, n, max_chunk_sum): 222 | out = [] 223 | for _ in range(int(np.ceil(len(l) / max_chunk_sum))): 224 | out.append(list(chunks(l[:max_chunk_sum], n))) 225 | l = l[max_chunk_sum:] 226 | return out -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Karl Pertsch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RLDS Dataset Conversion 2 | 3 | This repo demonstrates how to convert an existing dataset into RLDS format for X-embodiment experiment integration. 4 | It provides an example for converting a dummy dataset to RLDS. To convert your own dataset, **fork** this repo and 5 | modify the example code for your dataset following the steps below. 6 | 7 | ## Installation 8 | 9 | First create a conda environment using the provided environment.yml file (use `environment_ubuntu.yml` or `environment_macos.yml` depending on the operating system you're using): 10 | ``` 11 | conda env create -f environment_ubuntu.yml 12 | ``` 13 | 14 | Then activate the environment using: 15 | ``` 16 | conda activate rlds_env 17 | ``` 18 | 19 | If you want to manually create an environment, the key packages to install are `tensorflow`, 20 | `tensorflow_datasets`, `tensorflow_hub`, `apache_beam`, `matplotlib`, `plotly` and `wandb`. 21 | 22 | 23 | ## Run Example RLDS Dataset Creation 24 | 25 | Before modifying the code to convert your own dataset, run the provided example dataset creation script to ensure 26 | everything is installed correctly. Run the following lines to create some dummy data and convert it to RLDS. 
27 | ``` 28 | cd example_dataset 29 | python3 create_example_data.py 30 | tfds build 31 | ``` 32 | 33 | This should create a new dataset in `~/tensorflow_datasets/example_dataset`. Please verify that the example 34 | conversion worked before moving on. 35 | 36 | 37 | ## Converting your Own Dataset to RLDS 38 | 39 | Now we can modify the provided example to convert your own data. Follow the steps below: 40 | 41 | 1. **Rename Dataset**: Change the name of the dataset folder from `example_dataset` to the name of your dataset (e.g. robo_net_v2), 42 | also change the name of `example_dataset_dataset_builder.py` by replacing `example_dataset` with your dataset's name (e.g. robo_net_v2_dataset_builder.py) 43 | and change the class name `ExampleDataset` in the same file to match your dataset's name, using camel case instead of underlines (e.g. RoboNetV2). 44 | 45 | 2. **Modify Features**: Modify the data fields you plan to store in the dataset. You can find them in the `_info()` method 46 | of the `ExampleDataset` class. Please add **all** data fields your raw data contains, i.e. please add additional features for 47 | additional cameras, audio, tactile features etc. If your type of feature is not demonstrated in the example (e.g. audio), 48 | you can find a list of all supported feature types [here](https://www.tensorflow.org/datasets/api_docs/python/tfds/features?hl=en#classes). 49 | You can store step-wise info like camera images, actions etc in `'steps'` and episode-wise info like `collector_id` in `episode_metadata`. 50 | Please don't remove any of the existing features in the example (except for `wrist_image` and `state`), since they are required for RLDS compliance. 51 | Please add detailed documentation what each feature consists of (e.g. what are the dimensions of the action space etc.). 
52 | Note that we store `language_instruction` in every step even though it is episode-wide information for easier downstream usage (if your dataset 53 | does not define language instructions, you can fill in a dummy string like `pick up something`). 54 | 55 | 3. **Modify Dataset Splits**: The function `_split_generator()` determines the splits of the generated dataset (e.g. training, validation etc.). 56 | If your dataset defines a train vs validation split, please provide the corresponding information to `_generate_examples()`, e.g. 57 | by pointing to the corresponding folders (like in the example) or file IDs etc. If your dataset does not define splits, 58 | remove the `val` split and only include the `train` split. You can then remove all arguments to `_generate_examples()`. 59 | 60 | 4. **Modify Dataset Conversion Code**: Next, modify the function `_generate_examples()`. Here, your own raw data should be 61 | loaded, filled into the episode steps and then yielded as a packaged example. Note that the value of the first return argument, 62 | `episode_path` in the example, is only used as a sample ID in the dataset and can be set to any value that is connected to the 63 | particular stored episode, or any other random value. Just ensure to avoid using the same ID twice. 64 | 65 | 5. **Provide Dataset Description**: Next, add a bibtex citation for your dataset in `CITATIONS.bib` and add a short description 66 | of your dataset in `README.md` inside the dataset folder. You can also provide a link to the dataset website and please add a 67 | few example trajectory images from the dataset for visualization. 68 | 69 | 6. **Add Appropriate License**: Please add an appropriate license to the repository. 70 | Most common is the [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) license -- 71 | you can copy it from [here](https://github.com/teamdigitale/licenses/blob/master/CC-BY-4.0). 72 | 73 | That's it! You're all set to run dataset conversion. 
Inside the dataset directory, run: 74 | ``` 75 | tfds build --overwrite 76 | ``` 77 | The command line output should finish with a summary of the generated dataset (including size and number of samples). 78 | Please verify that this output looks as expected and that you can find the generated `tfrecord` files in `~/tensorflow_datasets/`. 79 | 80 | 81 | ### Parallelizing Data Processing 82 | By default, dataset conversion is single-threaded. If you are parsing a large dataset, you can use parallel processing. 83 | For this, replace the last two lines of `_generate_examples()` with the commented-out `beam` commands. This will use 84 | Apache Beam to parallelize data processing. Before starting the processing, you need to install your dataset package 85 | by filling in the name of your dataset into `setup.py` and running `pip install -e .` 86 | 87 | Then, make sure that no GPUs are used during data processing (`export CUDA_VISIBLE_DEVICES=`) and run: 88 | ``` 89 | tfds build --overwrite --beam_pipeline_options="direct_running_mode=multi_processing,direct_num_workers=10" 90 | ``` 91 | You can specify the desired number of workers with the `direct_num_workers` argument. 92 | 93 | ## Visualize Converted Dataset 94 | To verify that the data is converted correctly, please run the data visualization script from the base directory: 95 | ``` 96 | python3 visualize_dataset.py 97 | ``` 98 | This will display a few random episodes from the dataset with language commands and visualize action and state histograms per dimension. 99 | Note, if you are running on a headless server you can modify `WANDB_ENTITY` at the top of `visualize_dataset.py` and 100 | add your own WandB entity -- then the script will log all visualizations to WandB. 
101 | 102 | ## Add Transform for Target Spec 103 | 104 | For X-embodiment training we are using specific inputs / outputs for the model: input is a single RGB camera, output 105 | is an 8-dimensional action, consisting of end-effector position and orientation, gripper open/close and an episode termination 106 | action. 107 | 108 | The final step in adding your dataset to the training mix is to provide a transform function that transforms a step 109 | from your original dataset above to the required training spec. Please follow the two simple steps below: 110 | 111 | 1. **Modify Step Transform**: Modify the function `transform_step()` in `example_transform/transform.py`. The function 112 | takes in a step from your dataset above and is supposed to map it to the desired output spec. The file contains a detailed 113 | description of the desired output spec. 114 | 115 | 2. **Test Transform**: We provide a script to verify that the resulting __transformed__ dataset outputs match the desired 116 | output spec. Please run the following command: `python3 test_dataset_transform.py ` 117 | 118 | If the test passes successfully, you are ready to upload your dataset! 119 | 120 | ## Upload Your Data 121 | 122 | We provide a Google Cloud bucket that you can upload your data to. First, install `gsutil`, the Google Cloud command 123 | line tool. You can follow the installation instructions [here](https://cloud.google.com/storage/docs/gsutil_install). 124 | 125 | Next, authenticate your Google account with: 126 | ``` 127 | gcloud auth login 128 | ``` 129 | This will open a browser window that allows you to log into your Google account (if you're on a headless server, 130 | you can add the `--no-launch-browser` flag). Ideally, use the email address that 131 | you used to communicate with Karl, since he will automatically grant permission to the bucket for this email address.
132 | If you want to upload data with a different email address / google account, please shoot Karl a quick email to ask 133 | to grant permissions to that Google account! 134 | 135 | After logging in with a Google account that has access permissions, you can upload your data with the following 136 | command: 137 | ``` 138 | gsutil -m cp -r ~/tensorflow_datasets/ gs://xembodiment_data 139 | ``` 140 | This will upload all data using multiple threads. If your internet connection gets interrupted anytime during the upload 141 | you can just rerun the command and it will resume the upload where it was interrupted. You can verify that the upload 142 | was successful by inspecting the bucket [here](https://console.cloud.google.com/storage/browser/xembodiment_data). 143 | 144 | The last step is to commit all changes to this repo and send Karl the link to the repo. 145 | 146 | **Thanks a lot for contributing your data! :)** 147 | -------------------------------------------------------------------------------- /aloha1_put_X_into_pot_300_demos/CITATIONS.bib: -------------------------------------------------------------------------------- 1 | // TODO(example_dataset): BibTeX citation 2 | -------------------------------------------------------------------------------- /aloha1_put_X_into_pot_300_demos/README.md: -------------------------------------------------------------------------------- 1 | TODO(example_dataset): Markdown description of your dataset. 2 | Description is **formatted** as markdown. 3 | 4 | It should also contain any processing which has been applied (if any), 5 | (e.g. 
corrupted example skipped, images cropped,...): 6 | -------------------------------------------------------------------------------- /aloha1_put_X_into_pot_300_demos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moojink/rlds_dataset_builder/6174b0b6bb69df6361f1117944952bf14afb0cc3/aloha1_put_X_into_pot_300_demos/__init__.py -------------------------------------------------------------------------------- /aloha1_put_X_into_pot_300_demos/aloha1_put_X_into_pot_300_demos_dataset_builder.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Tuple, Any 2 | 3 | import os 4 | import h5py 5 | import glob 6 | import numpy as np 7 | import tensorflow as tf 8 | import tensorflow_datasets as tfds 9 | import sys 10 | import sys 11 | sys.path.append('.') 12 | from aloha1_put_X_into_pot_300_demos.conversion_utils import MultiThreadedDatasetBuilder 13 | 14 | 15 | def _generate_examples(paths) -> Iterator[Tuple[str, Any]]: 16 | """Yields episodes for list of data paths.""" 17 | # the line below needs to be *inside* generate_examples so that each worker creates it's own model 18 | # creating one shared model outside this function would cause a deadlock 19 | 20 | def _parse_example(episode_path): 21 | # Load raw data 22 | with h5py.File(episode_path, "r") as F: 23 | actions = F["/action"][()] 24 | states = F["/observations/qpos"][()] 25 | images = F["/observations/images/cam_high"][()] # Primary camera (top-down view) 26 | left_wrist_images = F["/observations/images/cam_left_wrist"][()] # Left wrist camera 27 | right_wrist_images = F["/observations/images/cam_right_wrist"][()] # Right wrist camera 28 | low_cam_images = F["/observations/images/cam_low"][()] # Low third-person camera 29 | 30 | # Get language instruction 31 | # Assumes filepaths look like: "/PATH/TO/ALOHA/PREPROCESSED/DATASETS//train/episode_0.hdf5" 32 | raw_file_string = 
episode_path.split('/')[-3] # E.g., '/scr/moojink/data/aloha1_preprocessed/put_green_pepper_into_pot/train/episode_0.hdf5' -> put_green_pepper_into_pot 33 | command = " ".join(raw_file_string.split("_")) 34 | 35 | # Assemble episode: here we're assuming demos so we set reward to 1 at the end 36 | episode = [] 37 | for i in range(actions.shape[0]): 38 | episode.append({ 39 | 'observation': { 40 | 'image': images[i], 41 | 'left_wrist_image': left_wrist_images[i], 42 | 'right_wrist_image': right_wrist_images[i], 43 | 'low_cam_image': low_cam_images[i], 44 | 'state': np.asarray(states[i], np.float32), 45 | }, 46 | 'action': np.asarray(actions[i], dtype=np.float32), 47 | 'discount': 1.0, 48 | 'reward': float(i == (actions.shape[0] - 1)), 49 | 'is_first': i == 0, 50 | 'is_last': i == (actions.shape[0] - 1), 51 | 'is_terminal': i == (actions.shape[0] - 1), 52 | 'language_instruction': command, 53 | }) 54 | 55 | # Create output data sample 56 | sample = { 57 | 'steps': episode, 58 | 'episode_metadata': { 59 | 'file_path': episode_path 60 | } 61 | } 62 | 63 | # If you want to skip an example for whatever reason, simply return None 64 | return episode_path, sample 65 | 66 | # For smallish datasets, use single-thread parsing 67 | for sample in paths: 68 | ret = _parse_example(sample) 69 | yield ret 70 | 71 | 72 | class aloha1_put_X_into_pot_300_demos(MultiThreadedDatasetBuilder): 73 | """DatasetBuilder for example dataset.""" 74 | 75 | VERSION = tfds.core.Version('1.0.0') 76 | RELEASE_NOTES = { 77 | '1.0.0': 'Initial release.', 78 | } 79 | N_WORKERS = 40 # number of parallel workers for data conversion 80 | MAX_PATHS_IN_MEMORY = 80 # number of paths converted & stored in memory before writing to disk 81 | # -> the higher the faster / more parallel conversion, adjust based on avilable RAM 82 | # note that one path may yield multiple episodes and adjust accordingly 83 | PARSE_FCN = _generate_examples # handle to parse function from file paths to RLDS episodes 84 | 85 | def 
_info(self) -> tfds.core.DatasetInfo: 86 | """Dataset metadata (homepage, citation,...).""" 87 | return self.dataset_info_from_configs( 88 | features=tfds.features.FeaturesDict({ 89 | 'steps': tfds.features.Dataset({ 90 | 'observation': tfds.features.FeaturesDict({ 91 | 'image': tfds.features.Image( 92 | shape=(256, 256, 3), 93 | dtype=np.uint8, 94 | encoding_format='jpeg', 95 | doc='Main camera RGB observation.', 96 | ), 97 | 'left_wrist_image': tfds.features.Image( 98 | shape=(256, 256, 3), 99 | dtype=np.uint8, 100 | encoding_format='jpeg', 101 | doc='Left wrist camera RGB observation.', 102 | ), 103 | 'right_wrist_image': tfds.features.Image( 104 | shape=(256, 256, 3), 105 | dtype=np.uint8, 106 | encoding_format='jpeg', 107 | doc='Right wrist camera RGB observation.', 108 | ), 109 | 'low_cam_image': tfds.features.Image( 110 | shape=(256, 256, 3), 111 | dtype=np.uint8, 112 | encoding_format='jpeg', 113 | doc='Lower camera RGB observation.', 114 | ), 115 | 'state': tfds.features.Tensor( 116 | shape=(14,), 117 | dtype=np.float32, 118 | doc='Robot joint state (7D left arm + 7D right arm).', 119 | ), 120 | }), 121 | 'action': tfds.features.Tensor( 122 | shape=(14,), 123 | dtype=np.float32, 124 | doc='Robot arm action.', 125 | ), 126 | 'discount': tfds.features.Scalar( 127 | dtype=np.float32, 128 | doc='Discount if provided, default to 1.' 129 | ), 130 | 'reward': tfds.features.Scalar( 131 | dtype=np.float32, 132 | doc='Reward if provided, 1 on final step for demos.' 133 | ), 134 | 'is_first': tfds.features.Scalar( 135 | dtype=np.bool_, 136 | doc='True on first step of the episode.' 137 | ), 138 | 'is_last': tfds.features.Scalar( 139 | dtype=np.bool_, 140 | doc='True on last step of the episode.' 141 | ), 142 | 'is_terminal': tfds.features.Scalar( 143 | dtype=np.bool_, 144 | doc='True on last step of the episode if it is a terminal step, True for demos.' 145 | ), 146 | 'language_instruction': tfds.features.Text( 147 | doc='Language Instruction.' 
148 | ), 149 | }), 150 | 'episode_metadata': tfds.features.FeaturesDict({ 151 | 'file_path': tfds.features.Text( 152 | doc='Path to the original data file.' 153 | ), 154 | }), 155 | })) 156 | 157 | def _split_paths(self): 158 | """Define filepaths for data splits.""" 159 | return { 160 | "train": glob.glob("/scr/moojink/data/aloha1_preprocessed/put_green_pepper_into_pot/train/*.hdf5") + glob.glob("/scr/moojink/data/aloha1_preprocessed/put_red_pepper_into_pot/train/*.hdf5") + glob.glob("/scr/moojink/data/aloha1_preprocessed/put_yellow_corn_into_pot/train/*.hdf5"), 161 | "val": glob.glob("/scr/moojink/data/aloha1_preprocessed/put_green_pepper_into_pot/val/*.hdf5") + glob.glob("/scr/moojink/data/aloha1_preprocessed/put_red_pepper_into_pot/val/*.hdf5") + glob.glob("/scr/moojink/data/aloha1_preprocessed/put_yellow_corn_into_pot/val/*.hdf5"), 162 | } 163 | -------------------------------------------------------------------------------- /aloha1_put_X_into_pot_300_demos/conversion_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any, Dict, Union, Callable, Iterable 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow_datasets as tfds 5 | 6 | import itertools 7 | from multiprocessing import Pool 8 | from functools import partial 9 | from tensorflow_datasets.core import download 10 | from tensorflow_datasets.core import split_builder as split_builder_lib 11 | from tensorflow_datasets.core import naming 12 | from tensorflow_datasets.core import splits as splits_lib 13 | from tensorflow_datasets.core import utils 14 | from tensorflow_datasets.core import writer as writer_lib 15 | from tensorflow_datasets.core import example_serializer 16 | from tensorflow_datasets.core import dataset_builder 17 | from tensorflow_datasets.core import file_adapters 18 | 19 | Key = Union[str, int] 20 | # The nested example dict passed to `features.encode_example` 21 | Example = Dict[str, Any] 22 | 
KeyExample = Tuple[Key, Example]


class MultiThreadedDatasetBuilder(tfds.core.GeneratorBasedBuilder):
    """Base builder that converts raw episode files with a pool of worker processes.

    Subclasses must set ``PARSE_FCN`` to a module-level function taking a list
    of paths and yielding ``(key, example)`` pairs (or ``None`` to skip), and
    must implement ``_split_paths()`` returning ``{split_name: [paths]}``.
    """
    N_WORKERS = 10  # number of parallel workers for data conversion
    MAX_PATHS_IN_MEMORY = 100  # number of paths converted & stored in memory before writing to disk
    # -> the higher the faster / more parallel conversion, adjust based on available RAM
    # note that one path may yield multiple episodes and adjust accordingly
    PARSE_FCN = None  # needs to be filled with path-to-record-episode parse function

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Define data splits."""
        split_paths = self._split_paths()
        return {split: type(self).PARSE_FCN(paths=split_paths[split]) for split in split_paths}

    def _generate_examples(self):
        pass  # this is implemented in global method to enable multiprocessing

    def _download_and_prepare(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
            self,
            dl_manager: download.DownloadManager,
            download_config: download.DownloadConfig,
    ) -> None:
        """Generate all splits and store the computed split infos on ``self.info``.

        Raises:
            ValueError: If the subclass did not set ``PARSE_FCN``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if type(self).PARSE_FCN is None:
            raise ValueError(
                f"{type(self).__name__}.PARSE_FCN must be set to a parse function."
            )
        split_builder = ParallelSplitBuilder(
            split_dict=self.info.splits,
            features=self.info.features,
            dataset_size=self.info.dataset_size,
            max_examples_per_split=download_config.max_examples_per_split,
            beam_options=download_config.beam_options,
            beam_runner=download_config.beam_runner,
            file_format=self.info.file_format,
            shard_config=download_config.get_shard_config(),
            split_paths=self._split_paths(),
            parse_function=type(self).PARSE_FCN,
            n_workers=self.N_WORKERS,
            max_paths_in_memory=self.MAX_PATHS_IN_MEMORY,
        )
        split_generators = self._split_generators(dl_manager)
        split_generators = split_builder.normalize_legacy_split_generators(
            split_generators=split_generators,
            generator_fn=self._generate_examples,
            is_beam=False,
        )
        dataset_builder._check_split_names(split_generators.keys())

        # Start generating data for all splits
        path_suffix = file_adapters.ADAPTER_FOR_FORMAT[
            self.info.file_format
        ].FILE_SUFFIX

        split_info_futures = []
        for split_name, generator in utils.tqdm(
                split_generators.items(),
                desc="Generating splits...",
                unit=" splits",
                leave=False,
        ):
            filename_template = naming.ShardedFileTemplate(
                split=split_name,
                dataset_name=self.name,
                data_dir=self.data_path,
                filetype_suffix=path_suffix,
            )
            future = split_builder.submit_split_generation(
                split_name=split_name,
                generator=generator,
                filename_template=filename_template,
                disable_shuffling=self.info.disable_shuffling,
            )
            split_info_futures.append(future)

        # Finalize the splits (after apache beam completed, if it was used)
        split_infos = [future.result() for future in split_info_futures]

        # Update the info object with the splits.
        split_dict = splits_lib.SplitDict(split_infos)
        self.info.set_splits(split_dict)


class _SplitInfoFuture:
    """Future containing the `tfds.core.SplitInfo` result."""

    def __init__(self, callback: Callable[[], splits_lib.SplitInfo]):
        self._callback = callback

    def result(self) -> splits_lib.SplitInfo:
        return self._callback()


def parse_examples_from_generator(paths, fcn, split_name, total_num_examples, features, serializer):
    """Parse, encode and serialize every example produced by ``fcn(paths)``.

    Runs inside a worker process; everything is collected in memory and
    returned to the parent in one list.

    Args:
        paths: Paths handed to the parse function.
        fcn: Parse function returning an iterable of ``(key, example)`` pairs;
            ``None`` entries are skipped.
        split_name: Split name, used only for the progress-bar label.
        total_num_examples: Progress-bar total (may be ``None``).
        features: Object whose ``encode_example`` encodes a raw example.
        serializer: Object whose ``serialize_example`` serializes an encoded example.

    Returns:
        List of ``(key, serialized_example)`` tuples.
    """
    generator = fcn(paths)
    outputs = []
    for sample in utils.tqdm(
            generator,
            desc=f'Generating {split_name} examples...',
            unit=' examples',
            total=total_num_examples,
            leave=False,
            mininterval=1.0,
    ):
        if sample is None:
            continue
        key, example = sample
        try:
            example = features.encode_example(example)
        except Exception as e:  # pylint: disable=broad-except
            utils.reraise(e, prefix=f'Failed to encode example:\n{example}\n')
        outputs.append((key, serializer.serialize_example(example)))
    return outputs


class ParallelSplitBuilder(split_builder_lib.SplitBuilder):
    """Split builder that parses the split's file paths in a process pool."""

    def __init__(self, *args, split_paths, parse_function, n_workers, max_paths_in_memory, **kwargs):
        super().__init__(*args, **kwargs)
        self._split_paths = split_paths
        self._parse_function = parse_function
        self._n_workers = n_workers
        self._max_paths_in_memory = max_paths_in_memory

    def _build_from_generator(
            self,
            split_name: str,
            generator: Iterable[KeyExample],
            filename_template: naming.ShardedFileTemplate,
            disable_shuffling: bool,
    ) -> _SplitInfoFuture:
        """Split generator for example generators.

        Args:
          split_name: str,
          generator: Iterable[KeyExample]; ignored, the split's paths are
            parsed by the worker pool instead.
          filename_template: Template to format the filename for a shard.
          disable_shuffling: Specifies whether to shuffle the examples,

        Returns:
          future: The future containing the `tfds.core.SplitInfo`.
        """
        total_num_examples = None
        serialized_info = self._features.get_serialized_info()
        writer = writer_lib.Writer(
            serializer=example_serializer.ExampleSerializer(serialized_info),
            filename_template=filename_template,
            hash_salt=split_name,
            disable_shuffling=disable_shuffling,
            file_format=self._file_format,
            shard_config=self._shard_config,
        )

        del generator  # use parallel generators instead
        paths = self._split_paths[split_name]
        path_lists = chunk_max(paths, self._n_workers, self._max_paths_in_memory)  # generate N file lists
        print(f"Generating with {self._n_workers} workers!")
        parse_chunk = partial(
            parse_examples_from_generator,
            fcn=self._parse_function,
            split_name=split_name,
            total_num_examples=total_num_examples,
            serializer=writer._serializer,
            features=self._features
        )
        # The context manager tears the workers down even when parsing or
        # writing raises, so a failed conversion does not leave processes behind.
        with Pool(processes=self._n_workers) as pool:
            for i, worker_paths in enumerate(path_lists):
                print(f"Processing chunk {i + 1} of {len(path_lists)}.")
                results = pool.map(parse_chunk, worker_paths)
                # write results to shuffler --> this will automatically offload to disk if necessary
                print("Writing conversion results...")
                for key, serialized_example in itertools.chain.from_iterable(results):
                    writer._shuffler.add(key, serialized_example)
                    writer._num_examples += 1

        print("Finishing split conversion...")
        shard_lengths, total_size = writer.finalize()

        split_info = splits_lib.SplitInfo(
            name=split_name,
            shard_lengths=shard_lengths,
            num_bytes=total_size,
            filename_template=filename_template,
        )
        return _SplitInfoFuture(lambda: split_info)


def dictlist2listdict(DL):
    """Convert a dict of equally long lists into a list of dicts."""
    return [dict(zip(DL, t)) for t in zip(*DL.values())]


def chunks(l, n):
    """Yield exactly ``n`` sequential chunks of ``l`` with near-equal sizes.

    The first ``len(l) % n`` chunks carry one extra element; trailing chunks
    are empty when ``l`` has fewer than ``n`` elements.

    Raises:
        ValueError: If ``n`` is not positive (raised on first iteration).
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")
    base, extra = divmod(len(l), n)
    start = 0
    for i in range(n):
        stop = start + base + (1 if i < extra else 0)
        yield l[start:stop]
        start = stop


def chunk_max(l, n, max_chunk_sum):
    """Split ``l`` into batches of at most ``max_chunk_sum`` items, each cut into ``n`` chunks.

    Raises:
        ValueError: If ``n`` or ``max_chunk_sum`` is not positive.
    """
    if max_chunk_sum <= 0:
        raise ValueError(
            f"max_chunk_sum must be a positive integer, got {max_chunk_sum}"
        )
    # Integer stepping replaces float np.ceil and repeated re-slicing of `l`.
    return [
        list(chunks(l[start:start + max_chunk_sum], n))
        for start in range(0, len(l), max_chunk_sum)
    ]
llvm-openmp=14.0.6=h0dcd299_0 40 | - mkl=2023.1.0=h59209a4_43558 41 | - mkl_fft=1.3.6=py311hdb55bb0_1 42 | - mkl_random=1.2.2=py311hdb55bb0_1 43 | - ncurses=6.4=hcec6c5f_0 44 | - numpy-base=1.23.5=py311h53bf9ac_1 45 | - openssl=1.1.1u=hca72f7f_0 46 | - opt_einsum=3.3.0=pyhd3eb1b0_1 47 | - pooch=1.4.0=pyhd3eb1b0_0 48 | - pyasn1=0.4.8=pyhd3eb1b0_0 49 | - pyasn1-modules=0.2.8=py_0 50 | - pycparser=2.21=pyhd3eb1b0_0 51 | - python=3.11.4=h1fd4e5f_0 52 | - python-flatbuffers=2.0=pyhd3eb1b0_0 53 | - re2=2022.04.01=he9d5cce_0 54 | - readline=8.2=hca72f7f_0 55 | - requests-oauthlib=1.3.0=py_0 56 | - rsa=4.7.2=pyhd3eb1b0_1 57 | - six=1.16.0=pyhd3eb1b0_1 58 | - snappy=1.1.9=he9d5cce_0 59 | - sqlite=3.41.2=h6c40b1e_0 60 | - tbb=2021.8.0=ha357a0b_0 61 | - tensorboard-plugin-wit=1.6.0=py_0 62 | - tensorflow-base=2.12.0=eigen_py311hbf87084_0 63 | - tk=8.6.12=h5d9f67b_0 64 | - typing_extensions=4.6.3=py311hecd8cb5_0 65 | - tzdata=2023c=h04d1e81_0 66 | - wheel=0.35.1=pyhd3eb1b0_0 67 | - xz=5.4.2=h6c40b1e_0 68 | - zlib=1.2.13=h4dc903c_0 69 | - pip: 70 | - absl-py==1.4.0 71 | - aiohttp==3.8.3 72 | - apache-beam==2.48.0 73 | - array-record==0.4.0 74 | - async-timeout==4.0.2 75 | - attrs==22.1.0 76 | - blinker==1.4 77 | - brotlipy==0.7.0 78 | - certifi==2023.5.7 79 | - cffi==1.15.1 80 | - click==8.0.4 81 | - cloudpickle==2.2.1 82 | - contourpy==1.1.0 83 | - crcmod==1.7 84 | - cryptography==39.0.1 85 | - cycler==0.11.0 86 | - dill==0.3.1.1 87 | - dm-tree==0.1.8 88 | - dnspython==2.3.0 89 | - docker-pycreds==0.4.0 90 | - docopt==0.6.2 91 | - etils==1.3.0 92 | - fastavro==1.8.0 93 | - fasteners==0.18 94 | - fonttools==4.41.0 95 | - frozenlist==1.3.3 96 | - gitdb==4.0.10 97 | - gitpython==3.1.32 98 | - google-auth-oauthlib==0.5.2 99 | - googleapis-common-protos==1.59.1 100 | - grpcio==1.48.2 101 | - h5py==3.7.0 102 | - hdfs==2.7.0 103 | - httplib2==0.22.0 104 | - idna==3.4 105 | - importlib-resources==6.0.0 106 | - keras==2.12.0 107 | - kiwisolver==1.4.4 108 | - markdown==3.4.1 109 | - 
markupsafe==2.1.1 110 | - matplotlib==3.7.2 111 | - mkl-fft==1.3.6 112 | - mkl-random==1.2.2 113 | - mkl-service==2.4.0 114 | - multidict==6.0.2 115 | - numpy==1.23.5 116 | - oauthlib==3.2.2 117 | - objsize==0.6.1 118 | - orjson==3.9.2 119 | - packaging==23.0 120 | - pathtools==0.1.2 121 | - pillow==10.0.0 122 | - pip==23.1.2 123 | - plotly==5.15.0 124 | - promise==2.3 125 | - proto-plus==1.22.3 126 | - protobuf==3.20.3 127 | - psutil==5.9.5 128 | - pyarrow==11.0.0 129 | - pydot==1.4.2 130 | - pyjwt==2.4.0 131 | - pymongo==4.4.1 132 | - pyopenssl==23.0.0 133 | - pyparsing==3.0.9 134 | - pysocks==1.7.1 135 | - python-dateutil==2.8.2 136 | - pytz==2023.3 137 | - pyyaml==6.0 138 | - regex==2023.6.3 139 | - requests==2.29.0 140 | - scipy==1.10.1 141 | - sentry-sdk==1.28.1 142 | - setproctitle==1.3.2 143 | - setuptools==67.8.0 144 | - smmap==5.0.0 145 | - tenacity==8.2.2 146 | - tensorboard==2.12.1 147 | - tensorboard-data-server==0.7.0 148 | - tensorflow==2.12.0 149 | - tensorflow-datasets==4.9.2 150 | - tensorflow-estimator==2.12.0 151 | - tensorflow-hub==0.14.0 152 | - tensorflow-metadata==1.13.1 153 | - termcolor==2.1.0 154 | - toml==0.10.2 155 | - tqdm==4.65.0 156 | - typing-extensions==4.6.3 157 | - urllib3==1.26.16 158 | - wandb==0.15.5 159 | - werkzeug==2.2.3 160 | - wrapt==1.14.1 161 | - yarl==1.8.1 162 | - zipp==3.16.1 163 | - zstandard==0.21.0 164 | prefix: /Users/karl/miniconda3/envs/rlds_env 165 | -------------------------------------------------------------------------------- /environment_ubuntu.yml: -------------------------------------------------------------------------------- 1 | name: rlds_env 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - _libgcc_mutex=0.1=conda_forge 6 | - _openmp_mutex=4.5=2_gnu 7 | - ca-certificates=2023.7.22=hbcca054_0 8 | - ld_impl_linux-64=2.40=h41732ed_0 9 | - libffi=3.3=h58526e2_2 10 | - libgcc-ng=13.1.0=he5830b7_0 11 | - libgomp=13.1.0=he5830b7_0 12 | - libsqlite=3.42.0=h2797004_0 13 | - 
libstdcxx-ng=13.1.0=hfd8a6a1_0 14 | - libzlib=1.2.13=hd590300_5 15 | - ncurses=6.4=hcb278e6_0 16 | - openssl=1.1.1u=hd590300_0 17 | - pip=23.2.1=pyhd8ed1ab_0 18 | - python=3.9.0=hffdb5ce_5_cpython 19 | - readline=8.2=h8228510_1 20 | - setuptools=68.0.0=pyhd8ed1ab_0 21 | - sqlite=3.42.0=h2c6b66d_0 22 | - tk=8.6.12=h27826a3_0 23 | - tzdata=2023c=h71feb2d_0 24 | - wheel=0.41.0=pyhd8ed1ab_0 25 | - xz=5.2.6=h166bdaf_0 26 | - zlib=1.2.13=hd590300_5 27 | - pip: 28 | - absl-py==1.4.0 29 | - anyio==3.7.1 30 | - apache-beam==2.49.0 31 | - appdirs==1.4.4 32 | - array-record==0.4.0 33 | - astunparse==1.6.3 34 | - cachetools==5.3.1 35 | - certifi==2023.7.22 36 | - charset-normalizer==3.2.0 37 | - click==8.1.6 38 | - cloudpickle==2.2.1 39 | - contourpy==1.1.0 40 | - crcmod==1.7 41 | - cycler==0.11.0 42 | - dill==0.3.1.1 43 | - dm-tree==0.1.8 44 | - dnspython==2.4.0 45 | - docker-pycreds==0.4.0 46 | - docopt==0.6.2 47 | - etils==1.3.0 48 | - exceptiongroup==1.1.2 49 | - fastavro==1.8.2 50 | - fasteners==0.18 51 | - flatbuffers==23.5.26 52 | - fonttools==4.41.1 53 | - gast==0.4.0 54 | - gitdb==4.0.10 55 | - gitpython==3.1.32 56 | - google-auth==2.22.0 57 | - google-auth-oauthlib==1.0.0 58 | - google-pasta==0.2.0 59 | - googleapis-common-protos==1.59.1 60 | - grpcio==1.56.2 61 | - h11==0.14.0 62 | - h5py==3.9.0 63 | - hdfs==2.7.0 64 | - httpcore==0.17.3 65 | - httplib2==0.22.0 66 | - idna==3.4 67 | - importlib-metadata==6.8.0 68 | - importlib-resources==6.0.0 69 | - keras==2.13.1 70 | - kiwisolver==1.4.4 71 | - libclang==16.0.6 72 | - markdown==3.4.3 73 | - markupsafe==2.1.3 74 | - matplotlib==3.7.2 75 | - numpy==1.24.3 76 | - oauthlib==3.2.2 77 | - objsize==0.6.1 78 | - opt-einsum==3.3.0 79 | - orjson==3.9.2 80 | - packaging==23.1 81 | - pathtools==0.1.2 82 | - pillow==10.0.0 83 | - plotly==5.15.0 84 | - promise==2.3 85 | - proto-plus==1.22.3 86 | - protobuf==4.23.4 87 | - psutil==5.9.5 88 | - pyarrow==11.0.0 89 | - pyasn1==0.5.0 90 | - pyasn1-modules==0.3.0 91 | - pydot==1.4.2 92 
| - pymongo==4.4.1 93 | - pyparsing==3.0.9 94 | - python-dateutil==2.8.2 95 | - pytz==2023.3 96 | - pyyaml==6.0.1 97 | - regex==2023.6.3 98 | - requests==2.31.0 99 | - requests-oauthlib==1.3.1 100 | - rsa==4.9 101 | - sentry-sdk==1.28.1 102 | - setproctitle==1.3.2 103 | - six==1.16.0 104 | - smmap==5.0.0 105 | - sniffio==1.3.0 106 | - tenacity==8.2.2 107 | - tensorboard==2.13.0 108 | - tensorboard-data-server==0.7.1 109 | - tensorflow==2.13.0 110 | - tensorflow-datasets==4.9.2 111 | - tensorflow-estimator==2.13.0 112 | - tensorflow-hub==0.14.0 113 | - tensorflow-io-gcs-filesystem==0.32.0 114 | - tensorflow-metadata==1.13.1 115 | - termcolor==2.3.0 116 | - toml==0.10.2 117 | - tqdm==4.65.0 118 | - typing-extensions==4.5.0 119 | - urllib3==1.26.16 120 | - wandb==0.15.6 121 | - werkzeug==2.3.6 122 | - wrapt==1.15.0 123 | - zipp==3.16.2 124 | - zstandard==0.21.0 125 | prefix: /scr/kpertsch/miniconda3/envs/rlds_env 126 | -------------------------------------------------------------------------------- /example_dataset/CITATIONS.bib: -------------------------------------------------------------------------------- 1 | // TODO(example_dataset): BibTeX citation 2 | -------------------------------------------------------------------------------- /example_dataset/README.md: -------------------------------------------------------------------------------- 1 | TODO(example_dataset): Markdown description of your dataset. 2 | Description is **formatted** as markdown. 3 | 4 | It should also contain any processing which has been applied (if any), 5 | (e.g. 
def create_fake_episode(path, episode_length=None):
    """Generate one random episode and write it to ``path`` with ``np.save``.

    Every step is a dict with two random 64x64x3 uint8 images ('image' and
    'wrist_image'), random 10-dimensional float32 'state' and 'action'
    vectors, and the fixed 'language_instruction' string. The list of step
    dicts is handed to ``np.save`` as-is, so it is stored as a pickled object
    array and has to be read back with ``np.load(..., allow_pickle=True)``.

    Args:
        path: Destination file, forwarded unchanged to ``np.save``.
        episode_length: Number of steps to generate. When omitted (``None``)
            the module-level ``EPISODE_LENGTH`` is used, which matches the
            previous behaviour.
    """
    if episode_length is None:
        episode_length = EPISODE_LENGTH

    episode = []
    for _ in range(episode_length):
        episode.append({
            'image': np.asarray(np.random.rand(64, 64, 3) * 255, dtype=np.uint8),
            'wrist_image': np.asarray(np.random.rand(64, 64, 3) * 255, dtype=np.uint8),
            'state': np.asarray(np.random.rand(10), dtype=np.float32),
            'action': np.asarray(np.random.rand(10), dtype=np.float32),
            'language_instruction': 'dummy instruction',
        })
    np.save(path, episode)
class ExampleDataset(tfds.core.GeneratorBasedBuilder):
    """DatasetBuilder for example dataset.

    Converts ``.npy`` episode files (each a pickled list of step dicts) into
    episodes with per-step observations, actions, reward/termination flags,
    the language instruction and its sentence embedding.
    """

    VERSION = tfds.core.Version('1.0.0')
    RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
    }

    def __init__(self, *args, **kwargs):
        """Initialise the builder and load the sentence encoder used for embeddings."""
        super().__init__(*args, **kwargs)
        self._embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")

    def _info(self) -> tfds.core.DatasetInfo:
        """Dataset metadata (homepage, citation,...)."""
        return self.dataset_info_from_configs(
            features=tfds.features.FeaturesDict({
                'steps': tfds.features.Dataset({
                    'observation': tfds.features.FeaturesDict({
                        'image': tfds.features.Image(
                            shape=(64, 64, 3),
                            dtype=np.uint8,
                            encoding_format='png',
                            doc='Main camera RGB observation.',
                        ),
                        'wrist_image': tfds.features.Image(
                            shape=(64, 64, 3),
                            dtype=np.uint8,
                            encoding_format='png',
                            doc='Wrist camera RGB observation.',
                        ),
                        'state': tfds.features.Tensor(
                            shape=(10,),
                            dtype=np.float32,
                            doc='Robot state, consists of [7x robot joint angles, '
                                '2x gripper position, 1x door opening angle].',
                        )
                    }),
                    'action': tfds.features.Tensor(
                        shape=(10,),
                        dtype=np.float32,
                        doc='Robot action, consists of [7x joint velocities, '
                            '2x gripper velocities, 1x terminate episode].',
                    ),
                    'discount': tfds.features.Scalar(
                        dtype=np.float32,
                        doc='Discount if provided, default to 1.'
                    ),
                    'reward': tfds.features.Scalar(
                        dtype=np.float32,
                        doc='Reward if provided, 1 on final step for demos.'
                    ),
                    'is_first': tfds.features.Scalar(
                        dtype=np.bool_,
                        doc='True on first step of the episode.'
                    ),
                    'is_last': tfds.features.Scalar(
                        dtype=np.bool_,
                        doc='True on last step of the episode.'
                    ),
                    'is_terminal': tfds.features.Scalar(
                        dtype=np.bool_,
                        doc='True on last step of the episode if it is a terminal step, True for demos.'
                    ),
                    'language_instruction': tfds.features.Text(
                        doc='Language Instruction.'
                    ),
                    'language_embedding': tfds.features.Tensor(
                        shape=(512,),
                        dtype=np.float32,
                        doc='Kona language embedding. '
                            'See https://tfhub.dev/google/universal-sentence-encoder-large/5'
                    ),
                }),
                'episode_metadata': tfds.features.FeaturesDict({
                    'file_path': tfds.features.Text(
                        doc='Path to the original data file.'
                    ),
                }),
            }))

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Define data splits.

        Returns a dict with a 'train' and a 'val' generator, each built from a
        glob pattern that is resolved relative to the current working directory.
        """
        return {
            'train': self._generate_examples(path='data/train/episode_*.npy'),
            'val': self._generate_examples(path='data/val/episode_*.npy'),
        }

    def _generate_examples(self, path) -> Iterator[Tuple[str, Any]]:
        """Generator of examples for each split.

        Args:
            path: Glob pattern matching the episode files of one split.

        Yields:
            ``(episode_path, sample)`` tuples, where ``sample`` holds the list
            of steps and the episode metadata. Episodes for which
            ``_parse_example`` returns ``None`` are skipped.
        """
        # The instruction is usually identical for every step of an episode, so
        # embed each distinct instruction once instead of running the model per step.
        embedding_cache = {}

        def _embed_instruction(instruction):
            if instruction not in embedding_cache:
                embedding_cache[instruction] = self._embed([instruction])[0].numpy()
            return embedding_cache[instruction]

        def _parse_example(episode_path):
            # load raw data --> this should change for your dataset
            # NOTE: allow_pickle=True unpickles the file contents, which can execute
            # arbitrary code -- only point this builder at episode files you trust.
            data = np.load(episode_path, allow_pickle=True)    # this is a list of dicts in our case

            # assemble episode --> here we're assuming demos so we set reward to 1 at the end
            last_index = len(data) - 1
            episode = []
            for i, step in enumerate(data):
                # compute Kona language embedding
                language_embedding = _embed_instruction(step['language_instruction'])

                episode.append({
                    'observation': {
                        'image': step['image'],
                        'wrist_image': step['wrist_image'],
                        'state': step['state'],
                    },
                    'action': step['action'],
                    'discount': 1.0,
                    'reward': float(i == last_index),
                    'is_first': i == 0,
                    'is_last': i == last_index,
                    'is_terminal': i == last_index,
                    'language_instruction': step['language_instruction'],
                    'language_embedding': language_embedding,
                })

            # create output data sample
            sample = {
                'steps': episode,
                'episode_metadata': {
                    'file_path': episode_path
                }
            }

            # if you want to skip an example for whatever reason, simply return None
            return episode_path, sample

        # create list of all examples; sorted because glob order is OS-dependent
        episode_paths = sorted(glob.glob(path))

        # for smallish datasets, use single-thread parsing
        for episode_path in episode_paths:
            parsed = _parse_example(episode_path)
            # honour the "return None to skip" contract instead of yielding None
            if parsed is not None:
                yield parsed

        # for large datasets use beam to parallelize data parsing (this will have initialization overhead)
        # beam = tfds.core.lazy_imports.apache_beam
        # return (
        #         beam.Create(episode_paths)
        #         | beam.Map(_parse_example)
        # )
def transform_step(step: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one step of the source dataset into the target config layout.

    The main camera image is resized to 128x128 with Lanczos resampling and
    the action is rebuilt from three slices of the source action
    (``[:3]``, ``[5:8]`` and ``[-2:]``). All remaining step fields are passed
    through untouched. The input is a dict of numpy arrays.
    """
    source_image = step['observation']['image']
    resized = Image.fromarray(source_image).resize(
        (128, 128), Image.Resampling.LANCZOS)

    source_action = step['action']
    target_action = np.concatenate(
        [source_action[:3], source_action[5:8], source_action[-2:]])

    # fields that exist in both configs and are copied over as-is
    passthrough_keys = (
        'discount', 'reward', 'is_first', 'is_last', 'is_terminal',
        'language_instruction', 'language_embedding',
    )

    return {
        'observation': {
            'image': np.array(resized),
        },
        'action': target_action,
        **{key: step[key] for key in passthrough_keys},
    }
def check_elements(target, values):
    """Validate every entry of `values` against the matching spec in `target`.

    Nested dicts in `values` are checked recursively against the nested spec.
    For leaf entries the spec's 'shape' (only when it is non-empty), 'dtype'
    (skipped for ``bytes`` values) and 'range' (skipped when ``None``) are
    enforced. A list-valued range holds per-element (mins, maxs); any other
    range is a single (min, max) pair applied to the whole value.

    Raises:
        ValueError: on the first shape, dtype or range mismatch.
    """
    for name, spec in target.items():
        value = values[name]

        # nested group of features: descend and move on
        if isinstance(value, dict):
            check_elements(spec, value)
            continue

        expected_shape = spec['shape']
        if expected_shape and tuple(value.shape) != expected_shape:
            raise ValueError(
                f"Shape of {name} should be {expected_shape} but is {tuple(value.shape)}")

        # raw byte strings carry no numpy dtype, so there is nothing to compare
        is_raw_bytes = isinstance(value, bytes)
        if not is_raw_bytes and value.dtype != spec['dtype']:
            raise ValueError(f"Dtype of {name} should be {spec['dtype']} but is {value.dtype}")

        bounds = spec['range']
        if bounds is None:
            continue

        if isinstance(bounds, list):
            # per-element limits: bounds[0] are the minima, bounds[1] the maxima
            lower_limits, upper_limits = bounds[0], bounds[1]
            for vmin, vmax, val in zip(lower_limits, upper_limits, value):
                within = val >= vmin and val <= vmax
                if not within:
                    raise ValueError(
                        f"{name} is out of range. Should be in {bounds} but is {value}.")
        else:
            above_min = np.all(value >= bounds[0])
            within = above_min and np.all(value <= bounds[1])
            if not within:
                raise ValueError(
                    f"{name} is out of range. Should be in {bounds} but is {value}.")
def vis_stats(vector, vector_mean, tag):
    """Plot one histogram per column of `vector`, titled with that column's mean.

    `vector` must be 2-D and `vector_mean` 1-D with one entry per column of
    `vector`. The histograms are drawn side by side in a figure identified by
    `tag`; when `render_wandb` is set the figure is also logged to wandb
    under `tag`.
    """
    assert len(vector.shape) == 2
    assert len(vector_mean.shape) == 1
    assert vector.shape[1] == vector_mean.shape[0]

    n_elems = vector.shape[1]
    fig = plt.figure(tag, figsize=(5*n_elems, 5))

    # subplot positions are 1-based, column indices 0-based
    for position in range(1, n_elems + 1):
        column = position - 1
        plt.subplot(1, n_elems, position)
        plt.hist(vector[:, column], bins=20)
        plt.title(vector_mean[column])

    if render_wandb:
        wandb.log({tag: wandb.Image(fig)})