├── .gitignore ├── LICENSE ├── README.md ├── environment_macos.yml ├── environment_ubuntu.yml ├── modify_rlds_dataset.py ├── prepare_open_x.sh ├── rlds_dataset_mod │   ├── mod_functions.py │   └── multithreaded_adhoc_tfds_builder.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | __pycache__ 3 | .idea 4 | .DS_Store 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Karl Pertsch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RLDS Dataset Modification 2 | 3 | This repo contains scripts for modifying existing RLDS datasets. 4 | By running [`modify_rlds_dataset.py`](modify_rlds_dataset.py), you will load an existing RLDS dataset, apply the specified 5 | modifications to each sample, reshard the resulting dataset and store it in a new directory. Apart from a number of simple 6 | modification functions, this repo implements a parallelized `MultiThreadedAdhocDatasetBuilder` that performs the data modifications 7 | in parallel for increased conversion speed. 8 | 9 | ## Installation 10 | 11 | First create a conda environment using the provided environment.yml file (use `environment_ubuntu.yml` or `environment_macos.yml` depending on the operating system you're using): 12 | ``` 13 | conda env create -f environment_ubuntu.yml 14 | ``` 15 | 16 | Then activate the environment using: 17 | ``` 18 | conda activate rlds_env 19 | ``` 20 | 21 | If you want to manually create an environment, the key packages to install are `tensorflow` and `tensorflow_datasets`. 22 | 23 | To download datasets from the [Open X-Embodiment Dataset](https://robotics-transformer-x.github.io/) Google cloud bucket, 24 | please install `gsutil` using the [installation instructions](https://cloud.google.com/storage/docs/gsutil_install). 25 | 26 | 27 | ## Modifying RLDS Datasets 28 | 29 | The command below resizes all RGB and depth images to a max. size of 256 and encodes RGB images as jpeg. 30 | This can e.g. be useful for reducing the storage size of datasets in the [Open X-Embodiment Dataset](https://robotics-transformer-x.github.io/).
31 | ``` 32 | python3 modify_rlds_dataset.py --dataset=<dataset_name> --mods=resize_and_jpeg_encode --target_dir=<target_directory> 33 | ``` 34 | 35 | This creates a new dataset with smaller, jpeg-encoded images in the `target_dir`. 36 | 37 | You can switch out the `resize_and_jpeg_encode` mod for other functions in [mod_functions.py](rlds_dataset_mod/mod_functions.py). 38 | 39 | 40 | ## Command Arguments 41 | 42 | The [`modify_rlds_dataset.py`](modify_rlds_dataset.py) script supports the following command line arguments: 43 | ``` 44 | modify_rlds_dataset.py: 45 | --data_dir: Directory where source data is stored. 46 | --dataset: Dataset name. 47 | --max_episodes_in_memory: Number of episodes converted & stored in memory before writing to disk. 48 | (default: '100') 49 | (an integer) 50 | --mods: List of modification functions, applied in order. 51 | (a comma separated list) 52 | --n_workers: Number of parallel workers for data conversion. 53 | (default: '10') 54 | (an integer) 55 | --target_dir: Directory where modified data is stored. 56 | ``` 57 | You can increase the `n_workers` and `max_episodes_in_memory` parameters based on the resources of your machine. 58 | The larger the respective value, the faster the dataset conversion. 59 | 60 | A list of all supported dataset modifications ("mods") can be found in [mod_functions.py](rlds_dataset_mod/mod_functions.py). 61 | 62 | 63 | ## Adding New Mods 64 | 65 | You can add your own custom modification functions in [mod_functions.py](rlds_dataset_mod/mod_functions.py) by implementing 66 | the `TfdsModFunction` interface. Your mod function needs to provide one function that modifies the dataset feature spec 67 | and one map function that modifies an input `tf.data.Dataset`. You can use the existing mod functions as examples; a minimal sketch of a custom mod is also shown at the end of this README. 68 | Make sure to register your new mod in the `TFDS_MOD_FUNCTIONS` object. 69 | 70 | 71 | ## Download Open X-Embodiment Dataset 72 | To download the Open X-Embodiment dataset and convert it for training, run `bash prepare_open_x.sh`. You can 73 | specify the download directory at the top of the script. 74 | 75 | 76 | ## FAQ / Troubleshooting 77 | 78 | - **No new tempfile could be created**: The script stores large datasets in intermediate temporary files in the 79 | `/tmp` directory. Depending on the dataset size it can store up to 1000 such temp files. The default limit on the number of 80 | simultaneously open files in Ubuntu is 1024, so this can lead to the error above. You can increase the limit by 81 | running `ulimit -n 200000`.
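## Example: Custom Mod Sketch

As an illustration of the `TfdsModFunction` interface, here is a minimal sketch of a hypothetical custom mod that casts a (hypothetical) float64 `state` observation to float32; adapt the key name and dtype to your dataset. If you add the class directly to [mod_functions.py](rlds_dataset_mod/mod_functions.py), drop the import below and instead add an entry to the `TFDS_MOD_FUNCTIONS` dict.

```python
import tensorflow as tf
import tensorflow_datasets as tfds

from rlds_dataset_mod.mod_functions import (
    TFDS_MOD_FUNCTIONS,
    TfdsModFunction,
    mod_obs_features,
)


class CastStateToFloat32(TfdsModFunction):
    """Hypothetical example mod: store a float64 'state' observation as float32."""

    @classmethod
    def mod_features(
        cls,
        features: tfds.features.FeaturesDict,
    ) -> tfds.features.FeaturesDict:
        def cast_state(key, feat):
            # assumes the observation dict contains a float64 tensor under the key "state"
            if key == "state":
                return tfds.features.Tensor(shape=feat.shape, dtype=tf.float32, doc=feat.doc)
            return feat

        return mod_obs_features(features, cast_state)

    @classmethod
    def mod_dataset(cls, ds: tf.data.Dataset) -> tf.data.Dataset:
        def cast_step(step):
            # cast the observation data itself so it matches the modified feature spec
            if "state" in step["observation"]:
                step["observation"]["state"] = tf.cast(step["observation"]["state"], tf.float32)
            return step

        def episode_map_fn(episode):
            episode["steps"] = episode["steps"].map(cast_step)
            return episode

        return ds.map(episode_map_fn)


# Register the mod so it can be selected via --mods=cast_state_to_float32.
TFDS_MOD_FUNCTIONS["cast_state_to_float32"] = CastStateToFloat32
```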
82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /environment_macos.yml: -------------------------------------------------------------------------------- 1 | name: rlds_env 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _tflow_select=2.2.0=eigen 6 | - abseil-cpp=20211102.0=he9d5cce_0 7 | - aiosignal=1.2.0=pyhd3eb1b0_0 8 | - appdirs=1.4.4=pyhd3eb1b0_0 9 | - astunparse=1.6.3=py_0 10 | - blas=1.0=mkl 11 | - bzip2=1.0.8=h1de35cc_0 12 | - c-ares=1.19.0=h6c40b1e_0 13 | - ca-certificates=2023.05.30=hecd8cb5_0 14 | - cachetools=4.2.2=pyhd3eb1b0_0 15 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 16 | - flatbuffers=2.0.0=h23ab428_0 17 | - gast=0.4.0=pyhd3eb1b0_0 18 | - giflib=5.2.1=h6c40b1e_3 19 | - google-auth=2.6.0=pyhd3eb1b0_0 20 | - google-pasta=0.2.0=pyhd3eb1b0_0 21 | - grpc-cpp=1.48.2=h3afe56f_0 22 | - hdf5=1.10.6=h10fe05b_1 23 | - icu=68.1=h23ab428_0 24 | - intel-openmp=2023.1.0=ha357a0b_43547 25 | - jpeg=9e=h6c40b1e_1 26 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 27 | - krb5=1.20.1=hdba6334_1 28 | - libcurl=8.1.1=ha585b31_1 29 | - libcxx=14.0.6=h9765a3e_0 30 | - libedit=3.1.20221030=h6c40b1e_0 31 | - libev=4.33=h9ed2024_1 32 | - libffi=3.4.4=hecd8cb5_0 33 | - libgfortran=5.0.0=11_3_0_hecd8cb5_28 34 | - libgfortran5=11.3.0=h9dfd629_28 35 | - libnghttp2=1.52.0=h1c88b7d_1 36 | - libpng=1.6.39=h6c40b1e_0 37 | - libprotobuf=3.20.3=hfff2838_0 38 | - libssh2=1.10.0=hdb2fb19_2 39 | - llvm-openmp=14.0.6=h0dcd299_0 40 | - mkl=2023.1.0=h59209a4_43558 41 | - mkl_fft=1.3.6=py311hdb55bb0_1 42 | - mkl_random=1.2.2=py311hdb55bb0_1 43 | - ncurses=6.4=hcec6c5f_0 44 | - numpy-base=1.23.5=py311h53bf9ac_1 45 | - openssl=1.1.1u=hca72f7f_0 46 | - opt_einsum=3.3.0=pyhd3eb1b0_1 47 | - pooch=1.4.0=pyhd3eb1b0_0 48 | - pyasn1=0.4.8=pyhd3eb1b0_0 49 | - pyasn1-modules=0.2.8=py_0 50 | - pycparser=2.21=pyhd3eb1b0_0 51 | - python=3.11.4=h1fd4e5f_0 52 | - python-flatbuffers=2.0=pyhd3eb1b0_0 53 | - re2=2022.04.01=he9d5cce_0 54 | - readline=8.2=hca72f7f_0 55 | - requests-oauthlib=1.3.0=py_0 56 | - rsa=4.7.2=pyhd3eb1b0_1 57 | - six=1.16.0=pyhd3eb1b0_1 58 | - snappy=1.1.9=he9d5cce_0 59 | - sqlite=3.41.2=h6c40b1e_0 60 | - tbb=2021.8.0=ha357a0b_0 61 | - tensorboard-plugin-wit=1.6.0=py_0 62 | - tensorflow-base=2.12.0=eigen_py311hbf87084_0 63 | - tk=8.6.12=h5d9f67b_0 64 | - typing_extensions=4.6.3=py311hecd8cb5_0 65 | - tzdata=2023c=h04d1e81_0 66 | - wheel=0.35.1=pyhd3eb1b0_0 67 | - xz=5.4.2=h6c40b1e_0 68 | - zlib=1.2.13=h4dc903c_0 69 | - pip: 70 | - absl-py==1.4.0 71 | - aiohttp==3.8.3 72 | - apache-beam==2.48.0 73 | - array-record==0.4.0 74 | - async-timeout==4.0.2 75 | - attrs==22.1.0 76 | - blinker==1.4 77 | - brotlipy==0.7.0 78 | - certifi==2023.5.7 79 | - cffi==1.15.1 80 | - click==8.0.4 81 | - cloudpickle==2.2.1 82 | - contourpy==1.1.0 83 | - crcmod==1.7 84 | - cryptography==39.0.1 85 | - cycler==0.11.0 86 | - dill==0.3.1.1 87 | - dm-tree==0.1.8 88 | - dnspython==2.3.0 89 | - docker-pycreds==0.4.0 90 | - docopt==0.6.2 91 | - etils==1.3.0 92 | - fastavro==1.8.0 93 | - fasteners==0.18 94 | - fonttools==4.41.0 95 | - frozenlist==1.3.3 96 | - gitdb==4.0.10 97 | - gitpython==3.1.32 98 | - google-auth-oauthlib==0.5.2 99 | - googleapis-common-protos==1.59.1 100 | - grpcio==1.48.2 101 | - h5py==3.7.0 102 | - hdfs==2.7.0 103 | - httplib2==0.22.0 104 | - idna==3.4 105 | - importlib-resources==6.0.0 106 | - keras==2.12.0 107 | - kiwisolver==1.4.4 108 | - markdown==3.4.1 109 | - markupsafe==2.1.1 110 | - matplotlib==3.7.2 111 | - mkl-fft==1.3.6 112 | - mkl-random==1.2.2 113 | - 
mkl-service==2.4.0 114 | - multidict==6.0.2 115 | - numpy==1.23.5 116 | - oauthlib==3.2.2 117 | - objsize==0.6.1 118 | - orjson==3.9.2 119 | - packaging==23.0 120 | - pathtools==0.1.2 121 | - pillow==10.0.0 122 | - pip==23.1.2 123 | - plotly==5.15.0 124 | - promise==2.3 125 | - proto-plus==1.22.3 126 | - protobuf==3.20.3 127 | - psutil==5.9.5 128 | - pyarrow==11.0.0 129 | - pydot==1.4.2 130 | - pyjwt==2.4.0 131 | - pymongo==4.4.1 132 | - pyopenssl==23.0.0 133 | - pyparsing==3.0.9 134 | - pysocks==1.7.1 135 | - python-dateutil==2.8.2 136 | - pytz==2023.3 137 | - pyyaml==6.0 138 | - regex==2023.6.3 139 | - requests==2.29.0 140 | - scipy==1.10.1 141 | - sentry-sdk==1.28.1 142 | - setproctitle==1.3.2 143 | - setuptools==67.8.0 144 | - smmap==5.0.0 145 | - tenacity==8.2.2 146 | - tensorboard==2.12.1 147 | - tensorboard-data-server==0.7.0 148 | - tensorflow==2.12.0 149 | - tensorflow-datasets==4.9.2 150 | - tensorflow-estimator==2.12.0 151 | - tensorflow-hub==0.14.0 152 | - tensorflow-metadata==1.13.1 153 | - termcolor==2.1.0 154 | - toml==0.10.2 155 | - tqdm==4.65.0 156 | - typing-extensions==4.6.3 157 | - urllib3==1.26.16 158 | - wandb==0.15.5 159 | - werkzeug==2.2.3 160 | - wrapt==1.14.1 161 | - yarl==1.8.1 162 | - zipp==3.16.1 163 | - zstandard==0.21.0 164 | - dlimp @ git+https://github.com/kvablack/dlimp@fba663b10858793d35f9a0fdbe8f0d51906c8c90 165 | prefix: /Users/karl/miniconda3/envs/rlds_env 166 | -------------------------------------------------------------------------------- /environment_ubuntu.yml: -------------------------------------------------------------------------------- 1 | name: rlds_env 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - _libgcc_mutex=0.1=conda_forge 6 | - _openmp_mutex=4.5=2_gnu 7 | - ca-certificates=2023.7.22=hbcca054_0 8 | - ld_impl_linux-64=2.40=h41732ed_0 9 | - libffi=3.3=h58526e2_2 10 | - libgcc-ng=13.1.0=he5830b7_0 11 | - libgomp=13.1.0=he5830b7_0 12 | - libsqlite=3.42.0=h2797004_0 13 | - libstdcxx-ng=13.1.0=hfd8a6a1_0 14 | - libzlib=1.2.13=hd590300_5 15 | - ncurses=6.4=hcb278e6_0 16 | - openssl=1.1.1u=hd590300_0 17 | - pip=23.2.1=pyhd8ed1ab_0 18 | - python=3.9.0=hffdb5ce_5_cpython 19 | - readline=8.2=h8228510_1 20 | - setuptools=68.0.0=pyhd8ed1ab_0 21 | - sqlite=3.42.0=h2c6b66d_0 22 | - tk=8.6.12=h27826a3_0 23 | - tzdata=2023c=h71feb2d_0 24 | - wheel=0.41.0=pyhd8ed1ab_0 25 | - xz=5.2.6=h166bdaf_0 26 | - zlib=1.2.13=hd590300_5 27 | - pip: 28 | - absl-py==1.4.0 29 | - anyio==3.7.1 30 | - apache-beam==2.49.0 31 | - appdirs==1.4.4 32 | - array-record==0.4.0 33 | - astunparse==1.6.3 34 | - cachetools==5.3.1 35 | - certifi==2023.7.22 36 | - charset-normalizer==3.2.0 37 | - click==8.1.6 38 | - cloudpickle==2.2.1 39 | - contourpy==1.1.0 40 | - crcmod==1.7 41 | - cycler==0.11.0 42 | - dill==0.3.1.1 43 | - dm-tree==0.1.8 44 | - dnspython==2.4.0 45 | - docker-pycreds==0.4.0 46 | - docopt==0.6.2 47 | - etils==1.3.0 48 | - exceptiongroup==1.1.2 49 | - fastavro==1.8.2 50 | - fasteners==0.18 51 | - flatbuffers==23.5.26 52 | - fonttools==4.41.1 53 | - gast==0.4.0 54 | - gitdb==4.0.10 55 | - gitpython==3.1.32 56 | - google-auth==2.22.0 57 | - google-auth-oauthlib==1.0.0 58 | - google-pasta==0.2.0 59 | - googleapis-common-protos==1.59.1 60 | - grpcio==1.56.2 61 | - h11==0.14.0 62 | - h5py==3.9.0 63 | - hdfs==2.7.0 64 | - httpcore==0.17.3 65 | - httplib2==0.22.0 66 | - idna==3.4 67 | - importlib-metadata==6.8.0 68 | - importlib-resources==6.0.0 69 | - keras==2.13.1 70 | - kiwisolver==1.4.4 71 | - libclang==16.0.6 72 | - markdown==3.4.3 73 | - markupsafe==2.1.3 74 | 
- matplotlib==3.7.2 75 | - numpy==1.24.3 76 | - oauthlib==3.2.2 77 | - objsize==0.6.1 78 | - opt-einsum==3.3.0 79 | - orjson==3.9.2 80 | - packaging==23.1 81 | - pathtools==0.1.2 82 | - pillow==10.0.0 83 | - plotly==5.15.0 84 | - promise==2.3 85 | - proto-plus==1.22.3 86 | - protobuf==4.23.4 87 | - psutil==5.9.5 88 | - pyarrow==11.0.0 89 | - pyasn1==0.5.0 90 | - pyasn1-modules==0.3.0 91 | - pydot==1.4.2 92 | - pymongo==4.4.1 93 | - pyparsing==3.0.9 94 | - python-dateutil==2.8.2 95 | - pytz==2023.3 96 | - pyyaml==6.0.1 97 | - regex==2023.6.3 98 | - requests==2.31.0 99 | - requests-oauthlib==1.3.1 100 | - rsa==4.9 101 | - sentry-sdk==1.28.1 102 | - setproctitle==1.3.2 103 | - six==1.16.0 104 | - smmap==5.0.0 105 | - sniffio==1.3.0 106 | - tenacity==8.2.2 107 | - tensorboard==2.13.0 108 | - tensorboard-data-server==0.7.1 109 | - tensorflow==2.13.0 110 | - tensorflow-datasets==4.9.2 111 | - tensorflow-estimator==2.13.0 112 | - tensorflow-hub==0.14.0 113 | - tensorflow-io-gcs-filesystem==0.32.0 114 | - tensorflow-metadata==1.13.1 115 | - termcolor==2.3.0 116 | - toml==0.10.2 117 | - tqdm==4.65.0 118 | - typing-extensions==4.5.0 119 | - urllib3==1.26.16 120 | - wandb==0.15.6 121 | - werkzeug==2.3.6 122 | - wrapt==1.15.0 123 | - zipp==3.16.2 124 | - zstandard==0.21.0 125 | - dlimp @ git+https://github.com/kvablack/dlimp@fba663b10858793d35f9a0fdbe8f0d51906c8c90 126 | prefix: /scr/kpertsch/miniconda3/envs/rlds_env 127 | -------------------------------------------------------------------------------- /modify_rlds_dataset.py: -------------------------------------------------------------------------------- 1 | """Modifies TFDS dataset with a map function, updates the feature definition and stores new dataset.""" 2 | from functools import partial 3 | 4 | from absl import app, flags 5 | import tensorflow as tf 6 | import tensorflow_datasets as tfds 7 | 8 | from rlds_dataset_mod.mod_functions import TFDS_MOD_FUNCTIONS 9 | from rlds_dataset_mod.multithreaded_adhoc_tfds_builder import ( 10 | MultiThreadedAdhocDatasetBuilder, 11 | ) 12 | 13 | FLAGS = flags.FLAGS 14 | 15 | flags.DEFINE_string("dataset", None, "Dataset name.") 16 | flags.DEFINE_string("data_dir", None, "Directory where source data is stored.") 17 | flags.DEFINE_string("target_dir", None, "Directory where modified data is stored.") 18 | flags.DEFINE_list("mods", None, "List of modification functions, applied in order.") 19 | flags.DEFINE_integer("n_workers", 10, "Number of parallel workers for data conversion.") 20 | flags.DEFINE_integer( 21 | "max_episodes_in_memory", 22 | 100, 23 | "Number of episodes converted & stored in memory before writing to disk.", 24 | ) 25 | 26 | 27 | def mod_features(features): 28 | """Modifies feature dict.""" 29 | for mod in FLAGS.mods: 30 | features = TFDS_MOD_FUNCTIONS[mod].mod_features(features) 31 | return features 32 | 33 | 34 | def mod_dataset_generator(builder, split, mods): 35 | """Modifies dataset features.""" 36 | ds = builder.as_dataset(split=split) 37 | for mod in mods: 38 | ds = TFDS_MOD_FUNCTIONS[mod].mod_dataset(ds) 39 | for episode in tfds.core.dataset_utils.as_numpy(ds): 40 | yield episode 41 | 42 | 43 | def main(_): 44 | builder = tfds.builder(FLAGS.dataset, data_dir=FLAGS.data_dir) 45 | 46 | features = mod_features(builder.info.features) 47 | print("############# Target features: ###############") 48 | print(features) 49 | print("##############################################") 50 | assert FLAGS.data_dir != FLAGS.target_dir # prevent overwriting original dataset 51 | 52 | mod_dataset_builder = 
MultiThreadedAdhocDatasetBuilder( 53 | name=FLAGS.dataset, 54 | version=builder.version, 55 | features=features, 56 | split_datasets={split: builder.info.splits[split] for split in builder.info.splits}, 57 | config=builder.builder_config, 58 | data_dir=FLAGS.target_dir, 59 | description=builder.info.description, 60 | generator_fcn=partial(mod_dataset_generator, builder=builder, mods=FLAGS.mods), 61 | n_workers=FLAGS.n_workers, 62 | max_episodes_in_memory=FLAGS.max_episodes_in_memory, 63 | ) 64 | mod_dataset_builder.download_and_prepare() 65 | 66 | 67 | if __name__ == "__main__": 68 | app.run(main) 69 | -------------------------------------------------------------------------------- /prepare_open_x.sh: -------------------------------------------------------------------------------- 1 | : ' 2 | Script for downloading, cleaning and resizing Open X-Embodiment Dataset (https://robotics-transformer-x.github.io/) 3 | 4 | Performs the preprocessing steps: 5 | 1. Downloads mixture of 25 Open X-Embodiment datasets 6 | 2. Runs resize function to convert all datasets to 256x256 (if image resolution is larger) and jpeg encoding 7 | 3. Fixes channel flip errors in a few datasets, filters success-only for QT-Opt ("kuka") data 8 | 9 | To reduce disk usage during conversion, we download the datasets 1-by-1, convert them 10 | and then delete the original. 11 | We specify the number of parallel workers below -- the more parallel workers, the faster the data conversion will run. 12 | Adjust the number of workers and episodes held in memory to fit the available RAM of your machine: the more workers and episodes in memory, the faster the conversion. 13 | The default values are tested on a server with ~120GB of RAM and 24 cores. 14 | ' 15 | 16 | DOWNLOAD_DIR= 17 | CONVERSION_DIR= 18 | N_WORKERS=20 # number of workers used for parallel conversion --> adjust based on available RAM 19 | MAX_EPISODES_IN_MEMORY=200 # number of episodes converted in parallel --> adjust based on available RAM 20 | 21 | # increase limit on number of files opened in parallel to 20k --> conversion opens up to 1k temporary files 22 | # in /tmp to store dataset during conversion 23 | ulimit -n 20000 24 | 25 | echo "!!! Warning: This script downloads the Bridge dataset from the Open X-Embodiment bucket, which is currently outdated !!!" 26 | echo "!!! Instead download the bridge_dataset from here: https://rail.eecs.berkeley.edu/datasets/bridge_release/data/tfds/ !!!"
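# Optional sanity check: after a dataset has been converted and moved back into ${DOWNLOAD_DIR} by the loop
# below, you can verify that it still loads as a TFDS dataset, e.g. (dataset name and directory are placeholders):
#   python3 -c "import tensorflow_datasets as tfds; print(tfds.builder('bridge', data_dir='<DOWNLOAD_DIR>').info)"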
27 | 28 | # format: [dataset_name, dataset_version, transforms] 29 | DATASET_TRANSFORMS=( 30 | "fractal20220817_data 0.1.0 resize_and_jpeg_encode" 31 | "bridge 0.1.0 resize_and_jpeg_encode" 32 | "kuka 0.1.0 resize_and_jpeg_encode,filter_success" 33 | "taco_play 0.1.0 resize_and_jpeg_encode" 34 | "jaco_play 0.1.0 resize_and_jpeg_encode" 35 | "berkeley_cable_routing 0.1.0 resize_and_jpeg_encode" 36 | "roboturk 0.1.0 resize_and_jpeg_encode" 37 | "nyu_door_opening_surprising_effectiveness 0.1.0 resize_and_jpeg_encode" 38 | "viola 0.1.0 resize_and_jpeg_encode" 39 | "berkeley_autolab_ur5 0.1.0 resize_and_jpeg_encode,flip_wrist_image_channels" 40 | "toto 0.1.0 resize_and_jpeg_encode" 41 | "language_table 0.1.0 resize_and_jpeg_encode" 42 | "stanford_hydra_dataset_converted_externally_to_rlds 0.1.0 resize_and_jpeg_encode,flip_wrist_image_channels,flip_image_channels" 43 | "austin_buds_dataset_converted_externally_to_rlds 0.1.0 resize_and_jpeg_encode" 44 | "nyu_franka_play_dataset_converted_externally_to_rlds 0.1.0 resize_and_jpeg_encode" 45 | "furniture_bench_dataset_converted_externally_to_rlds 0.1.0 resize_and_jpeg_encode" 46 | "ucsd_kitchen_dataset_converted_externally_to_rlds 0.1.0 resize_and_jpeg_encode" 47 | "austin_sailor_dataset_converted_externally_to_rlds 0.1.0 resize_and_jpeg_encode" 48 | "austin_sirius_dataset_converted_externally_to_rlds 0.1.0 resize_and_jpeg_encode" 49 | "bc_z 0.1.0 resize_and_jpeg_encode" 50 | "dlr_edan_shared_control_converted_externally_to_rlds 0.1.0 resize_and_jpeg_encode" 51 | "iamlab_cmu_pickup_insert_converted_externally_to_rlds 0.1.0 resize_and_jpeg_encode" 52 | "utaustin_mutex 0.1.0 resize_and_jpeg_encode,flip_wrist_image_channels,flip_image_channels" 53 | "berkeley_fanuc_manipulation 0.1.0 resize_and_jpeg_encode,flip_wrist_image_channels,flip_image_channels" 54 | "cmu_stretch 0.1.0 resize_and_jpeg_encode" 55 | ) 56 | 57 | for tuple in "${DATASET_TRANSFORMS[@]}"; do 58 | # Extract strings from the tuple 59 | strings=($tuple) 60 | DATASET=${strings[0]} 61 | VERSION=${strings[1]} 62 | TRANSFORM=${strings[2]} 63 | mkdir ${DOWNLOAD_DIR}/${DATASET} 64 | gsutil -m cp -r gs://gresearch/robotics/${DATASET}/${VERSION} ${DOWNLOAD_DIR}/${DATASET} 65 | python3 modify_rlds_dataset.py --dataset=$DATASET --data_dir=$DOWNLOAD_DIR --target_dir=$CONVERSION_DIR --mods=$TRANSFORM --n_workers=$N_WORKERS --max_episodes_in_memory=$MAX_EPISODES_IN_MEMORY 66 | rm -rf ${DOWNLOAD_DIR}/${DATASET} 67 | mv ${CONVERSION_DIR}/${DATASET} ${DOWNLOAD_DIR} 68 | done 69 | -------------------------------------------------------------------------------- /rlds_dataset_mod/mod_functions.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import dlimp as dl 4 | import tensorflow as tf 5 | import tensorflow_datasets as tfds 6 | 7 | 8 | class TfdsModFunction(ABC): 9 | @classmethod 10 | @abstractmethod 11 | def mod_features( 12 | cls, 13 | features: tfds.features.FeaturesDict, 14 | ) -> tfds.features.FeaturesDict: 15 | """ 16 | Modifies the data builder feature dict to reflect feature changes of ModFunction. 17 | """ 18 | ... 19 | 20 | @classmethod 21 | @abstractmethod 22 | def mod_dataset(cls, ds: tf.data.Dataset) -> tf.data.Dataset: 23 | """ 24 | Perform arbitrary modifications on the dataset that comply with the modified feature definition. 25 | """ 26 | ... 
27 | 28 | 29 | def mod_obs_features(features, obs_feature_mod_function): 30 | """Utility function to only modify keys in observation dict.""" 31 | return tfds.features.FeaturesDict( 32 | { 33 | "steps": tfds.features.Dataset( 34 | { 35 | "observation": tfds.features.FeaturesDict( 36 | { 37 | key: obs_feature_mod_function( 38 | key, features["steps"]["observation"][key] 39 | ) 40 | for key in features["steps"]["observation"].keys() 41 | } 42 | ), 43 | **{ 44 | key: features["steps"][key] 45 | for key in features["steps"].keys() 46 | if key not in ("observation",) 47 | }, 48 | } 49 | ), 50 | **{key: features[key] for key in features.keys() if key not in ("steps",)}, 51 | } 52 | ) 53 | 54 | 55 | class ResizeAndJpegEncode(TfdsModFunction): 56 | MAX_RES: int = 256 57 | 58 | @classmethod 59 | def mod_features( 60 | cls, 61 | features: tfds.features.FeaturesDict, 62 | ) -> tfds.features.FeaturesDict: 63 | def downsize_and_jpeg(key, feat): 64 | """Downsizes image features, encodes as jpeg.""" 65 | if len(feat.shape) >= 2 and feat.shape[0] >= 64 and feat.shape[1] >= 64: # is image / depth feature 66 | should_jpeg_encode = ( 67 | isinstance(feat, tfds.features.Image) and "depth" not in key 68 | ) 69 | if len(feat.shape) > 2: 70 | new_shape = (ResizeAndJpegEncode.MAX_RES, ResizeAndJpegEncode.MAX_RES, feat.shape[2]) 71 | else: 72 | new_shape = (ResizeAndJpegEncode.MAX_RES, ResizeAndJpegEncode.MAX_RES) 73 | 74 | if isinstance(feat, tfds.features.Image): 75 | return tfds.features.Image( 76 | shape=new_shape, 77 | dtype=feat.dtype, 78 | encoding_format="jpeg" if should_jpeg_encode else "png", 79 | doc=feat.doc, 80 | ) 81 | else: 82 | return tfds.features.Tensor( 83 | shape=new_shape, 84 | dtype=feat.dtype, 85 | doc=feat.doc, 86 | ) 87 | 88 | return feat 89 | 90 | return mod_obs_features(features, downsize_and_jpeg) 91 | 92 | @classmethod 93 | def mod_dataset(cls, ds: tf.data.Dataset) -> tf.data.Dataset: 94 | def resize_image_fn(step): 95 | # resize images 96 | for key in step["observation"]: 97 | if len(step["observation"][key].shape) >= 2 and ( 98 | step["observation"][key].shape[0] >= 64 99 | or step["observation"][key].shape[1] >= 64 100 | ): 101 | size = (ResizeAndJpegEncode.MAX_RES, 102 | ResizeAndJpegEncode.MAX_RES) 103 | if "depth" in key: 104 | step["observation"][key] = tf.cast( 105 | dl.utils.resize_depth_image( 106 | tf.cast(step["observation"][key], tf.float32), size 107 | ), 108 | step["observation"][key].dtype, 109 | ) 110 | else: 111 | step["observation"][key] = tf.cast( 112 | dl.utils.resize_image(step["observation"][key], size), 113 | tf.uint8, 114 | ) 115 | return step 116 | 117 | def episode_map_fn(episode): 118 | episode["steps"] = episode["steps"].map(resize_image_fn) 119 | return episode 120 | 121 | return ds.map(episode_map_fn) 122 | 123 | 124 | class FilterSuccess(TfdsModFunction): 125 | @classmethod 126 | def mod_features( 127 | cls, 128 | features: tfds.features.FeaturesDict, 129 | ) -> tfds.features.FeaturesDict: 130 | return features # no feature changes 131 | 132 | @classmethod 133 | def mod_dataset(cls, ds: tf.data.Dataset) -> tf.data.Dataset: 134 | return ds.filter(lambda e: e["success"]) 135 | 136 | 137 | class FlipImgChannels(TfdsModFunction): 138 | FLIP_KEYS = ["image"] 139 | 140 | @classmethod 141 | def mod_features( 142 | cls, 143 | features: tfds.features.FeaturesDict, 144 | ) -> tfds.features.FeaturesDict: 145 | return features # no feature changes 146 | 147 | @classmethod 148 | def mod_dataset(cls, ds: tf.data.Dataset) -> tf.data.Dataset: 149 | def flip(step): 150 | 
for key in cls.FLIP_KEYS: 151 | if key in step["observation"]: 152 | step["observation"][key] = step["observation"][key][..., ::-1] 153 | return step 154 | 155 | def episode_map_fn(episode): 156 | episode["steps"] = episode["steps"].map(flip) 157 | return episode 158 | 159 | return ds.map(episode_map_fn) 160 | 161 | 162 | class FlipWristImgChannels(FlipImgChannels): 163 | FLIP_KEYS = ["wrist_image", "hand_image"] 164 | 165 | 166 | TFDS_MOD_FUNCTIONS = { 167 | "resize_and_jpeg_encode": ResizeAndJpegEncode, 168 | "filter_success": FilterSuccess, 169 | "flip_image_channels": FlipImgChannels, 170 | "flip_wrist_image_channels": FlipWristImgChannels, 171 | } 172 | 173 | 174 | -------------------------------------------------------------------------------- /rlds_dataset_mod/multithreaded_adhoc_tfds_builder.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import itertools 3 | from multiprocessing import Pool 4 | from typing import Any, Callable, Dict, Iterable, Tuple, Union 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | import tensorflow_datasets as tfds 9 | from tensorflow_datasets.core import ( 10 | dataset_builder, 11 | download, 12 | example_serializer, 13 | file_adapters, 14 | naming, 15 | ) 16 | from tensorflow_datasets.core import split_builder as split_builder_lib 17 | from tensorflow_datasets.core import splits as splits_lib 18 | from tensorflow_datasets.core import utils 19 | from tensorflow_datasets.core import writer as writer_lib 20 | 21 | Key = Union[str, int] 22 | # The nested example dict passed to `features.encode_example` 23 | Example = Dict[str, Any] 24 | KeyExample = Tuple[Key, Example] 25 | 26 | 27 | class MultiThreadedAdhocDatasetBuilder(tfds.core.dataset_builders.AdhocBuilder): 28 | """Multithreaded adhoc dataset builder.""" 29 | 30 | def __init__( 31 | self, *args, generator_fcn, n_workers, max_episodes_in_memory, **kwargs 32 | ): 33 | super().__init__(*args, **kwargs) 34 | self._generator_fcn = generator_fcn 35 | self._n_workers = n_workers 36 | self._max_episodes_in_memory = max_episodes_in_memory 37 | 38 | def _download_and_prepare( # pytype: disable=signature-mismatch # overriding-parameter-type-checks 39 | self, 40 | dl_manager: download.DownloadManager, 41 | download_config: download.DownloadConfig, 42 | ) -> None: 43 | """Generate all splits and returns the computed split infos.""" 44 | assert ( 45 | self._max_episodes_in_memory % self._n_workers == 0 46 | ) # need to divide max_episodes by workers 47 | split_builder = ParallelSplitBuilder( 48 | split_dict=self._split_datasets, 49 | features=self.info.features, 50 | dataset_size=self.info.dataset_size, 51 | max_examples_per_split=download_config.max_examples_per_split, 52 | beam_options=download_config.beam_options, 53 | beam_runner=download_config.beam_runner, 54 | file_format=self.info.file_format, 55 | shard_config=download_config.get_shard_config(), 56 | generator_fcn=self._generator_fcn, 57 | n_workers=self._n_workers, 58 | max_episodes_in_memory=self._max_episodes_in_memory, 59 | ) 60 | split_generators = self._split_generators(dl_manager) 61 | split_generators = split_builder.normalize_legacy_split_generators( 62 | split_generators=split_generators, 63 | generator_fn=self._generate_examples, 64 | is_beam=False, 65 | ) 66 | dataset_builder._check_split_names(split_generators.keys()) 67 | 68 | # Start generating data for all splits 69 | path_suffix = file_adapters.ADAPTER_FOR_FORMAT[ 70 | self.info.file_format 71 | 
].FILE_SUFFIX 72 | 73 | split_info_futures = [] 74 | for split_name, generator in utils.tqdm( 75 | split_generators.items(), 76 | desc="Generating splits...", 77 | unit=" splits", 78 | leave=False, 79 | ): 80 | filename_template = naming.ShardedFileTemplate( 81 | split=split_name, 82 | dataset_name=self.name, 83 | data_dir=self.data_path, 84 | filetype_suffix=path_suffix, 85 | ) 86 | future = split_builder.submit_split_generation( 87 | split_name=split_name, 88 | generator=generator, 89 | filename_template=filename_template, 90 | disable_shuffling=self.info.disable_shuffling, 91 | ) 92 | split_info_futures.append(future) 93 | 94 | # Finalize the splits (after apache beam completed, if it was used) 95 | split_infos = [future.result() for future in split_info_futures] 96 | 97 | # Update the info object with the splits. 98 | split_dict = splits_lib.SplitDict(split_infos) 99 | self.info.set_splits(split_dict) 100 | 101 | def _split_generators(self, dl_manager: tfds.download.DownloadManager): 102 | """Define dummy split generators.""" 103 | 104 | def dummy_generator(): 105 | yield None 106 | 107 | return {split: dummy_generator() for split in self._split_datasets} 108 | 109 | 110 | class _SplitInfoFuture: 111 | """Future containing the `tfds.core.SplitInfo` result.""" 112 | 113 | def __init__(self, callback: Callable[[], splits_lib.SplitInfo]): 114 | self._callback = callback 115 | 116 | def result(self) -> splits_lib.SplitInfo: 117 | return self._callback() 118 | 119 | 120 | def parse_examples_from_generator( 121 | episodes, max_episodes, fcn, split_name, total_num_examples, features, serializer 122 | ): 123 | upper = episodes[-1] + 1 124 | upper_str = f'{upper}' if upper < max_episodes else '' 125 | generator = fcn(split=split_name + f"[{episodes[0]}:{upper_str}]") 126 | outputs = [] 127 | for key, sample in utils.tqdm( 128 | zip(episodes, generator), 129 | desc=f"Generating {split_name} examples...", 130 | unit=" examples", 131 | total=total_num_examples, 132 | leave=False, 133 | mininterval=1.0, 134 | ): 135 | if sample is None: 136 | continue 137 | try: 138 | sample = features.encode_example(sample) 139 | except Exception as e: # pylint: disable=broad-except 140 | utils.reraise(e, prefix=f"Failed to encode example:\n{sample}\n") 141 | outputs.append((str(key), serializer.serialize_example(sample))) 142 | return outputs 143 | 144 | 145 | class ParallelSplitBuilder(split_builder_lib.SplitBuilder): 146 | def __init__( 147 | self, *args, generator_fcn, n_workers, max_episodes_in_memory, **kwargs 148 | ): 149 | super().__init__(*args, **kwargs) 150 | self._generator_fcn = generator_fcn 151 | self._n_workers = n_workers 152 | self._max_episodes_in_memory = max_episodes_in_memory 153 | 154 | def _build_from_generator( 155 | self, 156 | split_name: str, 157 | generator: Iterable[KeyExample], 158 | filename_template: naming.ShardedFileTemplate, 159 | disable_shuffling: bool, 160 | ) -> _SplitInfoFuture: 161 | """Split generator for example generators. 162 | 163 | Args: 164 | split_name: str, 165 | generator: Iterable[KeyExample], 166 | filename_template: Template to format the filename for a shard. 167 | disable_shuffling: Specifies whether to shuffle the examples, 168 | 169 | Returns: 170 | future: The future containing the `tfds.core.SplitInfo`. 
171 | """ 172 | total_num_examples = None 173 | serialized_info = self._features.get_serialized_info() 174 | writer = writer_lib.Writer( 175 | serializer=example_serializer.ExampleSerializer(serialized_info), 176 | filename_template=filename_template, 177 | hash_salt=split_name, 178 | disable_shuffling=disable_shuffling, 179 | file_format=self._file_format, 180 | shard_config=self._shard_config, 181 | ) 182 | 183 | del generator # use parallel generators instead 184 | episode_lists = chunk_max( 185 | list(np.arange(self._split_dict[split_name].num_examples)), 186 | self._n_workers, 187 | self._max_episodes_in_memory, 188 | ) # generate N episode lists 189 | print(f"Generating with {self._n_workers} workers!") 190 | pool = Pool(processes=self._n_workers) 191 | for i, episodes in enumerate(episode_lists): 192 | print(f"Processing chunk {i + 1} of {len(episode_lists)}.") 193 | results = pool.map( 194 | partial( 195 | parse_examples_from_generator, 196 | fcn=self._generator_fcn, 197 | split_name=split_name, 198 | total_num_examples=total_num_examples, 199 | serializer=writer._serializer, 200 | features=self._features, 201 | max_episodes=self._split_dict[split_name].num_examples, 202 | ), 203 | episodes, 204 | ) 205 | # write results to shuffler --> this will automatically offload to disk if necessary 206 | print("Writing conversion results...") 207 | for result in itertools.chain(*results): 208 | key, serialized_example = result 209 | writer._shuffler.add(key, serialized_example) 210 | writer._num_examples += 1 211 | pool.close() 212 | 213 | print("Finishing split conversion...") 214 | shard_lengths, total_size = writer.finalize() 215 | 216 | split_info = splits_lib.SplitInfo( 217 | name=split_name, 218 | shard_lengths=shard_lengths, 219 | num_bytes=total_size, 220 | filename_template=filename_template, 221 | ) 222 | return _SplitInfoFuture(lambda: split_info) 223 | 224 | 225 | def dictlist2listdict(DL): 226 | "Converts a dict of lists to a list of dicts" 227 | return [dict(zip(DL, t)) for t in zip(*DL.values())] 228 | 229 | 230 | def chunks(l, n): 231 | """Yield n number of sequential chunks from l.""" 232 | d, r = divmod(len(l), n) 233 | for i in range(n): 234 | si = (d + 1) * (i if i < r else r) + d * (0 if i < r else i - r) 235 | yield l[si : si + (d + 1 if i < r else d)] 236 | 237 | 238 | def chunk_max(l, n, max_chunk_sum): 239 | out = [] 240 | for _ in range(int(np.ceil(len(l) / max_chunk_sum))): 241 | out.append([c for c in chunks(l[:max_chunk_sum], n) if c]) 242 | l = l[max_chunk_sum:] 243 | return out 244 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name="rlds_dataset_mod", packages=["rlds_dataset_mod"]) 4 | --------------------------------------------------------------------------------