├── .gitattributes ├── __init__.py ├── assets └── optimus_gallery_gif.gif ├── pyproject.toml ├── optimus ├── config │ ├── __init__.py │ ├── base_config.py │ └── bc_config.py ├── algo │ └── __init__.py ├── envs │ ├── box.py │ ├── wrappers.py │ └── env_robosuite.py ├── scripts │ ├── download_datasets.py │ ├── combine_hdf5.py │ ├── filter_trajectories.py │ ├── playback_dataset.py │ └── run_trained_agent_pl.py ├── utils │ ├── file_utils.py │ ├── env_utils.py │ ├── dataset.py │ └── train_utils.py └── exps │ └── local │ └── robosuite │ ├── stack │ └── bc_transformer.json │ ├── stackfive │ └── bc_transformer.json │ ├── stackfour │ └── bc_transformer.json │ └── stackthree │ └── bc_transformer.json ├── requirements.txt ├── setup.py ├── .gitignore ├── docker └── mujoco │ └── Dockerfile ├── LICENSE └── README.md /.gitattributes: -------------------------------------------------------------------------------- 1 | *.gif filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. -------------------------------------------------------------------------------- /assets/optimus_gallery_gif.gif: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:33539d2962f145f0da70232080cd2bd143d086ed6630ef9d9046cfaf1ca847fd 3 | size 53889018 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | multi_line_output = 3 3 | include_trailing_comma = true 4 | force_grid_wrap = 0 5 | use_parentheses = true 6 | ensure_newline_before_comments = true 7 | line_length = 100 8 | 9 | [tool.black] 10 | line-length = 100 11 | -------------------------------------------------------------------------------- /optimus/config/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
4 | 5 | from robomimic.config.config import Config 6 | 7 | from optimus.config.base_config import config_factory, get_all_registered_configs 8 | 9 | # note: these imports are needed to register these classes in the global config registry 10 | from optimus.config.bc_config import BCConfig 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py 2 | psutil 3 | tqdm 4 | termcolor 5 | tensorboard 6 | tensorboardX 7 | imageio 8 | imageio-ffmpeg 9 | egl_probe>=1.0.1 10 | ipdb 11 | wandb 12 | 13 | ipython 14 | patchelf 15 | robosuite==1.4.0 16 | robomimic==0.3.0 17 | jupyterlab 18 | notebook 19 | black 20 | flake8 21 | isort 22 | pytest 23 | protobuf==3.20.0 24 | pytorch_lightning==1.9.5 25 | gdown 26 | 27 | seaborn 28 | mujoco 29 | pygame 30 | 31 | numpy==1.23.5 32 | pyopengl==3.1.6 -------------------------------------------------------------------------------- /optimus/algo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | from robomimic.algo.algo import ( 6 | Algo, 7 | HierarchicalAlgo, 8 | PlannerAlgo, 9 | PolicyAlgo, 10 | RolloutPolicy, 11 | ValueAlgo, 12 | algo_factory, 13 | algo_name_to_factory_func, 14 | register_algo_factory_func, 15 | ) 16 | 17 | # note: these imports are needed to register these classes in the global algo registry 18 | from robomimic.algo.bc import BC, BC_GMM, BC_VAE, BC_Gaussian 19 | from robomimic.algo.bcq import BCQ, BCQ_GMM, BCQ_Distributional 20 | from robomimic.algo.cql import CQL 21 | from robomimic.algo.gl import GL, GL_VAE, ValuePlanner 22 | from robomimic.algo.hbc import HBC 23 | from robomimic.algo.iris import IRIS 24 | from robomimic.algo.td3_bc import TD3_BC 25 | 26 | from optimus.algo.bc import BC_RNN, BC_RNN_GMM, BC_Transformer, BC_Transformer_GMM 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # setup.py 2 | #!/usr/bin/env python 3 | from os import path 4 | 5 | from setuptools import find_packages, setup 6 | 7 | this_directory = path.abspath(path.dirname(__file__)) 8 | with open(path.join(this_directory, "README.md"), encoding="utf-8") as f: 9 | lines = f.readlines() 10 | 11 | # remove images from README 12 | lines = [x for x in lines if ".png" not in x] 13 | long_description = "".join(lines) 14 | 15 | setup( 16 | name="optimus", 17 | packages=[package for package in find_packages() if package.startswith("optimus")], 18 | install_requires=[], 19 | eager_resources=["*"], 20 | include_package_data=True, 21 | python_requires=">=3", 22 | description="Official code release for Optimus: Imitating Task and Motion Planning with Visuomotor Transformers", 23 | author="Murtaza Dalal", 24 | url="https://github.com/NVlabs/Optimus.git", 25 | author_email="mdalal@andrew.cmu.edu", 26 | version="0.1.0", 27 | long_description=long_description, 28 | long_description_content_type="text/markdown", 29 | ) 30 | -------------------------------------------------------------------------------- /optimus/config/base_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | """ 5 | The base config class that is used for all algorithm configs in this repository. 6 | Subclasses get registered into a global dictionary, making it easy to instantiate 7 | the correct config class given the algorithm name. 8 | """ 9 | 10 | from copy import deepcopy 11 | 12 | import robomimic 13 | import six # preserve metaclass compatibility between python 2 and 3 14 | from robomimic.config import base_config 15 | from robomimic.config.config import Config 16 | 17 | # global dictionary for remembering name - class mappings 18 | REGISTERED_CONFIGS = base_config.REGISTERED_CONFIGS 19 | 20 | 21 | def get_all_registered_configs(): 22 | """ 23 | Give access to dictionary of all registered configs for external use. 24 | """ 25 | return deepcopy(REGISTERED_CONFIGS) 26 | 27 | 28 | def config_factory(algo_name, dic=None): 29 | """ 30 | Creates an instance of a config from the algo name. Optionally pass 31 | a dictionary to instantiate the config from the dictionary. 32 | """ 33 | if algo_name not in REGISTERED_CONFIGS: 34 | raise Exception( 35 | "Config for algo name {} not found. Make sure it is a registered config among: {}".format( 36 | algo_name, ", ".join(REGISTERED_CONFIGS) 37 | ) 38 | ) 39 | return REGISTERED_CONFIGS[algo_name](dict_to_load=dic) 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Mac OSX 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | datasets/ 135 | trained_models/ -------------------------------------------------------------------------------- /optimus/envs/box.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | import numpy as np 6 | from robosuite.models.objects import PrimitiveObject 7 | from robosuite.utils.mjcf_utils import get_size, new_site 8 | 9 | 10 | class BoxObject(PrimitiveObject): 11 | """ 12 | A box object. 13 | 14 | Args: 15 | size (3-tuple of float): (half-x, half-y, half-z) size parameters for this box object 16 | """ 17 | 18 | def __init__( 19 | self, 20 | name, 21 | size=None, 22 | size_max=None, 23 | size_min=None, 24 | density=None, 25 | friction=None, 26 | rgba=None, 27 | solref=None, 28 | solimp=None, 29 | material=None, 30 | joints="default", 31 | obj_type="all", 32 | duplicate_collision_geoms=True, 33 | ): 34 | size = get_size(size, size_max, size_min, [0.07, 0.07, 0.07], [0.03, 0.03, 0.03]) 35 | super().__init__( 36 | name=name, 37 | size=size, 38 | rgba=rgba, 39 | density=density, 40 | friction=friction, 41 | solref=solref, 42 | solimp=solimp, 43 | material=material, 44 | joints=joints, 45 | obj_type=obj_type, 46 | duplicate_collision_geoms=duplicate_collision_geoms, 47 | ) 48 | 49 | def sanity_check(self): 50 | """ 51 | Checks to make sure inputted size is of correct length 52 | 53 | Raises: 54 | AssertionError: [Invalid size length] 55 | """ 56 | assert len(self.size) == 3, "box size should have length 3" 57 | 58 | def _get_object_subtree(self): 59 | return self._get_object_subtree_(ob_type="box") 60 | 61 | @property 62 | def bottom_offset(self): 63 | return np.array([0, 0, -1 * self.size[2]]) 64 | 65 | @property 66 | def top_offset(self): 67 | return np.array([0, 0, self.size[2]]) 68 | 69 | @property 70 | def horizontal_radius(self): 71 | return np.linalg.norm(self.size[0:2], 2) 72 | 73 | def get_bounding_box_size(self): 74 | return np.array([self.size[0], self.size[1], self.size[2]]) 75 | 76 | 77 | class BoxObjectWithSites(BoxObject): 78 | """ 79 | A box object with sites on the x and y axes. 
80 | """ 81 | 82 | def _get_object_subtree(self): 83 | tree = self._get_object_subtree_(ob_type="box") 84 | site_element_attr = self.get_site_attrib_template() 85 | 86 | delta = self.size[0] / 2 87 | site_element_attr["pos"] = f"{delta} 0 0" 88 | site_element_attr["name"] = "x-site" 89 | tree.append(new_site(**site_element_attr)) 90 | 91 | delta = self.size[0] / 2 92 | site_element_attr["pos"] = f"0 {delta} 0" 93 | site_element_attr["name"] = "y-site" 94 | tree.append(new_site(**site_element_attr)) 95 | return tree 96 | -------------------------------------------------------------------------------- /optimus/scripts/download_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to download datasets packaged with the repository. 3 | """ 4 | import os 5 | import argparse 6 | 7 | import optimus 8 | import optimus.utils.file_utils as FileUtils 9 | dataset_links = dict( 10 | Stack="https://drive.google.com/file/d/1ciP5yP0D06gT7QXq7OvxlR-mcepIoHCK/view?usp=drive_link", 11 | StackThree="https://drive.google.com/file/d/1_Qo0jYPfepe4rUpyDrtbmI9WHPipRwmL/view?usp=drive_link", 12 | StackFour="https://drive.google.com/file/d/142XBL1Hy2Ru1qGNbv4JlBjHm-Y0q_is2/view?usp=drive_link", 13 | StackFive="https://drive.google.com/file/d/1z7CFsBEXkj7DxLSpfUteZjPuN2quwBtA/view?usp=drive_link" 14 | ) 15 | 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | 20 | # directory to download datasets to 21 | parser.add_argument( 22 | "--download_dir", 23 | type=str, 24 | default=None, 25 | help="Base download directory. Created if it doesn't exist. Defaults to datasets folder in repository.", 26 | ) 27 | 28 | # tasks to download datasets for 29 | parser.add_argument( 30 | "--tasks", 31 | type=str, 32 | nargs='+', 33 | default=["Stack"], 34 | help="Tasks to download datasets for. Defaults to square_d0 task. Pass 'all' to download all tasks\ 35 | for the provided dataset type or directly specify the list of tasks.", 36 | ) 37 | 38 | # dry run - don't actually download datasets, but print which datasets would be downloaded 39 | parser.add_argument( 40 | "--dry_run", 41 | action='store_true', 42 | help="set this flag to do a dry run to only print which datasets would be downloaded" 43 | ) 44 | 45 | args = parser.parse_args() 46 | 47 | # set default base directory for downloads 48 | default_base_dir = args.download_dir 49 | if default_base_dir is None: 50 | default_base_dir = os.path.join(optimus.__path__[0], "../datasets") 51 | 52 | # load args 53 | download_tasks = args.tasks 54 | if "all" in download_tasks: 55 | assert len(download_tasks) == 1, "all should be only tasks argument but got: {}".format(args.tasks) 56 | download_tasks = list(dataset_links.keys()) 57 | else: 58 | for task in download_tasks: 59 | assert task in dataset_links, "got unknown task {}. 
Choose one of {}".format(task, list(dataset_links.keys())) 60 | 61 | # download requested datasets 62 | for task in download_tasks: 63 | download_dir = os.path.abspath(os.path.join(default_base_dir)) 64 | download_path = os.path.join(download_dir, "{}.hdf5".format(task)) 65 | print("\nDownloading dataset:\n task: {}\n download path: {}" 66 | .format(task, download_path)) 67 | url = dataset_links[task] 68 | if args.dry_run: 69 | print("\ndry run: skip download") 70 | else: 71 | # Make sure path exists and create if it doesn't 72 | os.makedirs(download_dir, exist_ok=True) 73 | print("") 74 | FileUtils.download_url_from_gdrive( 75 | url=url, 76 | download_dir=download_dir, 77 | check_overwrite=True, 78 | ) 79 | print("") -------------------------------------------------------------------------------- /docker/mujoco/Dockerfile: -------------------------------------------------------------------------------- 1 | # FROM nvcr.io/nvidia/pytorch:21.08-py3 2 | FROM nvcr.io/nvidia/cudagl:11.3.0-devel-ubuntu20.04 3 | ENV NVIDIA_DRIVER_CAPABILITIES compute,graphics,utility 4 | # env variables for tzdata install 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | ENV TZ=America 7 | ENV LD_LIBRARY_PATH "$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin" 8 | ENV LD_LIBRARY_PATH "$LD_LIBRARY_PATH:/usr/lib/nvidia" 9 | ENV PYTHONPATH ${PYTHONPATH}:/workspace 10 | ENV MUJOCO_GL 'egl' 11 | ENV PATH "/usr/local/cuda-new/bin:$PATH" 12 | ENV PIP_CONFIG_FILE pip.conf 13 | ENV PYTHONPATH ${PYTHONPATH}:/home/robosuite 14 | ENV NVIDIA_DRIVER_CAPABILITIES compute,graphics,utility 15 | 16 | # NOTE: each RUN command creates a new image layer, so minimize the number of run commands if possible 17 | # installing other nice functionalities and system packages required by e.g. robosuite 18 | RUN apt-get update -y && \ 19 | apt-get install -y \ 20 | htop screen tmux \ 21 | sshfs libosmesa6-dev wget curl git \ 22 | libeigen3-dev \ 23 | liborocos-kdl-dev \ 24 | libkdl-parser-dev \ 25 | liburdfdom-dev \ 26 | libnlopt-dev \ 27 | libnlopt-cxx-dev \ 28 | swig \ 29 | python3 \ 30 | python3-pip \ 31 | python3-dev \ 32 | vim \ 33 | git-lfs \ 34 | cmake \ 35 | software-properties-common \ 36 | libxcursor-dev \ 37 | libxrandr-dev \ 38 | libxinerama-dev \ 39 | libxi-dev \ 40 | mesa-common-dev \ 41 | zip \ 42 | unzip \ 43 | make \ 44 | g++ \ 45 | python2.7 \ 46 | wget \ 47 | vulkan-utils \ 48 | mesa-vulkan-drivers \ 49 | apt nano rsync \ 50 | libgl1-mesa-dev libgl1-mesa-glx libglew-dev libosmesa6-dev \ 51 | software-properties-common net-tools unzip virtualenv \ 52 | xpra xserver-xorg-dev libglfw3-dev patchelf python3-pip -y \ 53 | && add-apt-repository -y ppa:openscad/releases && apt-get update && apt-get install -y openscad 54 | 55 | # install mujoco 56 | RUN mkdir /root/.mujoco/ \ 57 | && wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz \ 58 | && tar -xvf mujoco210-linux-x86_64.tar.gz \ 59 | && mv mujoco210 /root/.mujoco/ \ 60 | && rm mujoco210-linux-x86_64.tar.gz 61 | 62 | # robomimic dependencies and installing other packages 63 | RUN python3 -m pip install h5py psutil tqdm termcolor tensorboard tensorboardX imageio imageio-ffmpeg egl_probe>=1.0.1 ipdb wandb \ 64 | && python3 -m pip install ipython patchelf robosuite jupyterlab notebook black flake8 isort pytest protobuf==3.20.1 pytorch_lightning \ 65 | && python3 -m pip install seaborn mujoco pygame signatory==1.2.6.1.9.0 pyopengl==3.1.6 vit-pytorch stable_baselines3 \ 66 | && pip uninstall torch torchvision torchaudio -y \ 67 | && python3 -m pip 
install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 68 | 69 | RUN wget https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda_11.3.0_465.19.01_linux.run --no-check-certificate 70 | RUN sh cuda_11.3.0_465.19.01_linux.run --toolkit --silent --toolkitpath=/usr/local/cuda-new 71 | ENV TORCH_CUDA_ARCH_LIST "7.0 7.5 8.0" 72 | RUN git clone https://github.com/NVIDIA/apex /home/apex && cd /home/apex \ 73 | && python3 -m pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 74 | 75 | RUN python3 -m pip install gym \ 76 | && python3 -m pip uninstall opencv-python -y && python3 -m pip install opencv-python-headless 77 | 78 | RUN python3 -m pip install trimesh==3.12.6 pyopengl==3.1.6 79 | 80 | RUN python3 -m pip install numpy==1.23.5 81 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | NVIDIA License 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license. 7 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 8 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license. 9 | 10 | 2. License Grant 11 | 12 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 13 | 14 | 3. Limitations 15 | 16 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 17 | 18 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 19 | 20 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 21 | 22 | 3.4 Patent Claims. 
If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately. 23 | 24 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license. 25 | 26 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately. 27 | 28 | 4. Disclaimer of Warranty. 29 | 30 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 31 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 32 | 33 | 5. Limitation of Liability. 34 | 35 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. -------------------------------------------------------------------------------- /optimus/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | """ 6 | A collection of utility functions for working with files, such as reading metadata from 7 | demonstration datasets, loading model checkpoints, or downloading dataset files. 8 | """ 9 | import os 10 | from collections import OrderedDict 11 | 12 | import h5py 13 | import robomimic.utils.obs_utils as ObsUtils 14 | from robomimic.utils.file_utils import * 15 | 16 | from optimus.algo import algo_factory 17 | 18 | import os 19 | import shutil 20 | import tempfile 21 | import gdown 22 | 23 | def download_url_from_gdrive(url, download_dir, check_overwrite=True): 24 | """ 25 | Downloads a file at a URL from Google Drive. 26 | 27 | Example usage: 28 | url = https://drive.google.com/file/d/1DABdqnBri6-l9UitjQV53uOq_84Dx7Xt/view?usp=drive_link 29 | download_dir = "/tmp" 30 | download_url_from_gdrive(url, download_dir, check_overwrite=True) 31 | 32 | Args: 33 | url (str): url string 34 | download_dir (str): path to directory where file should be downloaded 35 | check_overwrite (bool): if True, will sanity check the download fpath to make sure a file of that name 36 | doesn't already exist there 37 | """ 38 | assert url_is_alive(url), "@download_url_from_gdrive got unreachable url: {}".format(url) 39 | 40 | with tempfile.TemporaryDirectory() as td: 41 | # HACK: Change directory to temp dir, download file there, and then move the file to desired directory. 42 | # We do this because we do not know the name of the file beforehand. 
43 | cur_dir = os.getcwd() 44 | os.chdir(td) 45 | fpath = gdown.download(url, quiet=False, fuzzy=True) 46 | fname = os.path.basename(fpath) 47 | file_to_write = os.path.join(download_dir, fname) 48 | if check_overwrite and os.path.exists(file_to_write): 49 | user_response = input(f"Warning: file {file_to_write} already exists. Overwrite? y/n\n") 50 | assert user_response.lower() in {"yes", "y"}, f"Did not receive confirmation. Aborting download." 51 | shutil.move(fpath, file_to_write) 52 | os.chdir(cur_dir) 53 | 54 | 55 | def policy_from_checkpoint(device=None, ckpt_path=None, ckpt_dict=None, verbose=False): 56 | """ 57 | This function restores a trained policy from a checkpoint file or 58 | loaded model dictionary. 59 | 60 | Args: 61 | device (torch.device): if provided, put model on this device 62 | 63 | ckpt_path (str): Path to checkpoint file. Only needed if not providing @ckpt_dict. 64 | 65 | ckpt_dict(dict): Loaded model checkpoint dictionary. Only needed if not providing @ckpt_path. 66 | 67 | verbose (bool): if True, include print statements 68 | 69 | Returns: 70 | model (RolloutPolicy): instance of Algo that has the saved weights from 71 | the checkpoint file, and also acts as a policy that can easily 72 | interact with an environment in a training loop 73 | 74 | ckpt_dict (dict): loaded checkpoint dictionary (convenient to avoid 75 | re-loading checkpoint from disk multiple times) 76 | """ 77 | ckpt_dict = maybe_dict_from_checkpoint(ckpt_path=ckpt_path, ckpt_dict=ckpt_dict) 78 | 79 | # algo name and config from model dict 80 | algo_name, _ = algo_name_from_checkpoint(ckpt_dict=ckpt_dict) 81 | config, _ = config_from_checkpoint(algo_name=algo_name, ckpt_dict=ckpt_dict, verbose=verbose) 82 | 83 | # read config to set up metadata for observation modalities (e.g. detecting rgb observations) 84 | ObsUtils.initialize_obs_utils_with_config(config) 85 | 86 | # env meta from model dict to get info needed to create model 87 | env_meta = ckpt_dict["env_metadata"] 88 | shape_meta = ckpt_dict["shape_metadata"] 89 | 90 | # maybe restore observation normalization stats 91 | obs_normalization_stats = ckpt_dict.get("obs_normalization_stats", None) 92 | if obs_normalization_stats is not None: 93 | assert config.train.hdf5_normalize_obs 94 | for m in obs_normalization_stats: 95 | for k in obs_normalization_stats[m]: 96 | obs_normalization_stats[m][k] = np.array(obs_normalization_stats[m][k]) 97 | 98 | if device is None: 99 | # get torch device 100 | device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda) 101 | 102 | # create model and load weights 103 | model = algo_factory( 104 | algo_name, 105 | config, 106 | obs_key_shapes=shape_meta["all_shapes"], 107 | ac_dim=shape_meta["ac_dim"], 108 | device=device, 109 | ) 110 | model.deserialize(ckpt_dict["model"]) 111 | model.set_eval() 112 | model = RolloutPolicy(model, obs_normalization_stats=obs_normalization_stats) 113 | if verbose: 114 | print("============= Loaded Policy =============") 115 | print(model) 116 | return model, ckpt_dict 117 | 118 | 119 | def get_shape_metadata_from_dataset(dataset_path, all_obs_keys=None, verbose=False): 120 | """ 121 | Retrieves shape metadata from dataset. 122 | 123 | Args: 124 | dataset_path (str): path to dataset 125 | all_obs_keys (list): list of all modalities used by the model. If not provided, all modalities 126 | present in the file are used. 127 | verbose (bool): if True, include print statements 128 | 129 | Returns: 130 | shape_meta (dict): shape metadata. 
Contains the following keys: 131 | 132 | :`'ac_dim'`: action space dimension 133 | :`'all_shapes'`: dictionary that maps observation key string to shape 134 | :`'all_obs_keys'`: list of all observation modalities used 135 | :`'use_images'`: bool, whether or not image modalities are present 136 | """ 137 | 138 | shape_meta = {} 139 | 140 | # read demo file for some metadata 141 | dataset_path = os.path.expanduser(dataset_path) 142 | f = h5py.File(dataset_path, "r") 143 | demo_id = list(f["data"].keys())[0] 144 | demo = f["data/{}".format(demo_id)] 145 | 146 | # action dimension 147 | shape_meta["ac_dim"] = f["data/{}/actions".format(demo_id)].shape[1] 148 | 149 | # observation dimensions 150 | all_shapes = OrderedDict() 151 | 152 | if all_obs_keys is None: 153 | # use all modalities present in the file 154 | all_obs_keys = [k for k in demo["obs"]] 155 | 156 | for k in sorted(all_obs_keys): 157 | if k == "timesteps": 158 | initial_shape = (1,) 159 | else: 160 | initial_shape = demo["obs/{}".format(k)].shape[1:] 161 | if verbose: 162 | print("obs key {} with shape {}".format(k, initial_shape)) 163 | # Store processed shape for each obs key 164 | all_shapes[k] = ObsUtils.get_processed_shape( 165 | obs_modality=ObsUtils.OBS_KEYS_TO_MODALITIES[k], 166 | input_shape=initial_shape, 167 | ) 168 | 169 | f.close() 170 | 171 | shape_meta["all_shapes"] = all_shapes 172 | shape_meta["all_obs_keys"] = all_obs_keys 173 | shape_meta["use_images"] = ObsUtils.has_modality("rgb", all_obs_keys) 174 | 175 | return shape_meta 176 | -------------------------------------------------------------------------------- /optimus/scripts/combine_hdf5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | import traceback 5 | 6 | import h5py 7 | import numpy as np 8 | import robomimic.utils.file_utils as FileUtils 9 | from tqdm import tqdm 10 | 11 | """ 12 | Script for combining multiple hdf5 files into a single hdf5 file. 13 | """ 14 | 15 | def write_trajectory_to_dataset( 16 | env, traj, data_grp, demo_name, save_next_obs=False, env_type="mujoco" 17 | ): 18 | """ 19 | Write the collected trajectory to hdf5 compatible with robomimic. 
20 | """ 21 | 22 | # create group for this trajectory 23 | ep_data_grp = data_grp.create_group(demo_name) 24 | ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]), compression="gzip") 25 | 26 | if env_type == "mujoco": 27 | data = np.array(traj["states"]) 28 | ep_data_grp.create_dataset("states", data=data) 29 | if "obs" in traj: 30 | for k in traj["obs"]: 31 | ep_data_grp.create_dataset( 32 | "obs/{}".format(k), data=np.array(traj["obs"][k]), compression="gzip" 33 | ) 34 | if save_next_obs: 35 | ep_data_grp.create_dataset( 36 | "next_obs/{}".format(k), 37 | data=np.array(traj["next_obs"][k]), 38 | compression="gzip", 39 | ) 40 | 41 | # episode metadata 42 | ep_data_grp.attrs["num_samples"] = traj["attrs"][ 43 | "num_samples" 44 | ] # number of transitions in this episode 45 | if "model_file" in traj: 46 | ep_data_grp.attrs["model_file"] = traj["model_file"] 47 | if "init_string" in traj: 48 | ep_data_grp.attrs["init_string"] = traj["init_string"] 49 | if "goal_parts_string" in traj: 50 | ep_data_grp.attrs["goal_parts_string"] = traj["goal_parts_string"] 51 | return traj["actions"].shape[0] 52 | 53 | 54 | def load_demo_info(hdf5_file, filter_key=None): 55 | """ 56 | Args: 57 | filter_by_attribute (str): if provided, use the provided filter key 58 | to select a subset of demonstration trajectories to load 59 | 60 | demos (list): list of demonstration keys to load from the hdf5 file. If 61 | omitted, all demos in the file (or under the @filter_by_attribute 62 | filter key) are used. 63 | """ 64 | if filter_key is not None: 65 | print("using filter key: {}".format(args.filter_key)) 66 | demos = [elem.decode("utf-8") for elem in np.array(hdf5_file["mask/{}".format(filter_key)])] 67 | else: 68 | demos = list(hdf5_file["data"].keys()) 69 | 70 | # sort demo keys 71 | inds = np.argsort([int(elem[5:]) for elem in demos]) 72 | demos = [demos[i] for i in inds] 73 | 74 | return demos 75 | 76 | 77 | def load_dataset_in_memory( 78 | demo_list, 79 | hdf5_file, 80 | dataset_keys, 81 | data_grp, 82 | demo_count=0, 83 | total_samples=0, 84 | env_type="mujoco", 85 | ): 86 | """ 87 | Loads the hdf5 dataset into memory, preserving the structure of the file. Note that this 88 | differs from `self.getitem_cache`, which, if active, actually caches the outputs of the 89 | `getitem` operation. 90 | 91 | Args: 92 | demo_list (list): list of demo keys, e.g., 'demo_0' 93 | hdf5_file (h5py.File): file handle to the hdf5 dataset. 94 | obs_keys (list, tuple): observation keys to fetch, e.g., 'images' 95 | dataset_keys (list, tuple): dataset keys to fetch, e.g., 'actions' 96 | load_next_obs (bool): whether to load next_obs from the dataset 97 | 98 | Returns: 99 | all_data (dict): dictionary of loaded data. 100 | """ 101 | for ep in tqdm(demo_list): 102 | demo_name = f"demo_{demo_count}" 103 | traj = {} 104 | traj["attrs"] = {} 105 | traj["attrs"]["num_samples"] = hdf5_file["data/{}".format(ep)].attrs["num_samples"] 106 | # get obs 107 | traj["obs"] = { 108 | k: hdf5_file["data/{}/obs/{}".format(ep, k)][()] 109 | for k in hdf5_file["data/{}/obs".format(ep)] 110 | } 111 | # get other dataset keys 112 | for k in dataset_keys: 113 | if env_type == "mujoco": 114 | traj[k] = hdf5_file["data/{}/{}".format(ep, k)][ 115 | () 116 | ] # NOTE: do not cast to float this breaks action playback! 
117 | try: 118 | traj["model_file"] = hdf5_file["data/{}".format(ep)].attrs["model_file"] 119 | except: 120 | pass 121 | 122 | try: 123 | traj["init_string"] = hdf5_file["data/{}".format(ep)].attrs["init_string"] 124 | traj["goal_parts_string"] = hdf5_file["data/{}".format(ep)].attrs["goal_parts_string"] 125 | except: 126 | pass 127 | write_trajectory_to_dataset(None, traj, data_grp, demo_name=demo_name, env_type=env_type) 128 | demo_count += 1 129 | total_samples += traj["attrs"]["num_samples"] 130 | env_args = hdf5_file["data/"].attrs["env_args"] 131 | return env_args, demo_count, total_samples 132 | 133 | 134 | def global_dataset_updates(data_grp, total_samples, env_args): 135 | """ 136 | Update the global dataset attributes. 137 | """ 138 | data_grp.attrs["total_samples"] = total_samples 139 | data_grp.attrs["env_args"] = env_args 140 | return data_grp 141 | 142 | 143 | def combine_hdf5(hdf5_paths, hdf5_use_swmr, dataset_path, filter_key): 144 | data_writer = h5py.File(dataset_path, "w") 145 | data_grp = data_writer.create_group("data") 146 | dataset_keys = ["actions", "states"] 147 | demo_count = 0 148 | total_samples = 0 149 | for hdf5_path in tqdm(hdf5_paths): 150 | try: 151 | hdf5_file = h5py.File(hdf5_path, "r", swmr=hdf5_use_swmr, libver="latest") 152 | demo_list = load_demo_info(hdf5_file, filter_key=filter_key) 153 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=hdf5_path) 154 | env_type = "mujoco" 155 | env_args, demo_count, total_samples = load_dataset_in_memory( 156 | demo_list, 157 | hdf5_file, 158 | dataset_keys, 159 | data_grp=data_grp, 160 | total_samples=total_samples, 161 | demo_count=demo_count, 162 | env_type=env_type, 163 | ) 164 | # this won't work when combining many files (just re-generate filter keys) 165 | # if "mask" in hdf5_file: 166 | # hdf5_file.copy("mask", data_writer) 167 | print("loaded: ", hdf5_path) 168 | except: 169 | print("failed to load: ", hdf5_path) 170 | print(traceback.format_exc()) 171 | pass 172 | global_dataset_updates(data_grp, total_samples, env_args) 173 | data_writer.close() 174 | 175 | 176 | if __name__ == "__main__": 177 | import argparse 178 | 179 | parser = argparse.ArgumentParser() 180 | parser.add_argument("--hdf5_paths", nargs="+") 181 | parser.add_argument("--output_path") 182 | parser.add_argument("--filter_key", type=str, default=None) 183 | 184 | args = parser.parse_args() 185 | combine_hdf5(args.hdf5_paths, True, args.output_path, args.filter_key) 186 | -------------------------------------------------------------------------------- /optimus/scripts/filter_trajectories.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | import traceback 5 | from collections import Counter 6 | 7 | import h5py 8 | import numpy as np 9 | from robomimic.utils.file_utils import create_hdf5_filter_key 10 | from tqdm import tqdm 11 | 12 | 13 | def write_trajectory_to_dataset( 14 | env, traj, data_grp, demo_name, save_next_obs=False, env_type="mujoco" 15 | ): 16 | """ 17 | Write the collected trajectory to hdf5 compatible with robomimic. 
18 | """ 19 | 20 | # create group for this trajectory 21 | ep_data_grp = data_grp.create_group(demo_name) 22 | ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]), compression="gzip") 23 | 24 | data = np.array(traj["states"]) 25 | ep_data_grp.create_dataset("states", data=data) 26 | if "obs" in traj: 27 | for k in traj["obs"]: 28 | ep_data_grp.create_dataset( 29 | "obs/{}".format(k), data=np.array(traj["obs"][k]), compression="gzip" 30 | ) 31 | if save_next_obs: 32 | ep_data_grp.create_dataset( 33 | "next_obs/{}".format(k), 34 | data=np.array(traj["next_obs"][k]), 35 | compression="gzip", 36 | ) 37 | 38 | # episode metadata 39 | ep_data_grp.attrs["num_samples"] = traj["attrs"][ 40 | "num_samples" 41 | ] # number of transitions in this episode 42 | if "model_file" in traj: 43 | ep_data_grp.attrs["model_file"] = traj["model_file"] 44 | if "init_string" in traj: 45 | ep_data_grp.attrs["init_string"] = traj["init_string"] 46 | if "goal_parts_string" in traj: 47 | ep_data_grp.attrs["goal_parts_string"] = traj["goal_parts_string"] 48 | return traj["actions"].shape[0] 49 | 50 | 51 | def load_demo_info(hdf5_file): 52 | """ 53 | Args: 54 | filter_by_attribute (str): if provided, use the provided filter key 55 | to select a subset of demonstval_ration trajectories to load 56 | 57 | demos (list): list of demonstval_ration keys to load from the hdf5 file. If 58 | omitted, all demos in the file (or under the @filter_by_attribute 59 | filter key) are used. 60 | """ 61 | demos = list(hdf5_file["data"].keys()) 62 | 63 | # sort demo keys 64 | inds = np.argsort([int(elem[5:]) for elem in demos]) 65 | demos = [demos[i] for i in inds] 66 | 67 | return demos 68 | 69 | 70 | def compute_traj_length_stats(hdf5_file, demo_list): 71 | traj_lengths = [] 72 | for ep in demo_list: 73 | traj_lengths.append(hdf5_file["data"][ep].attrs["num_samples"]) 74 | return np.mean(traj_lengths), np.std(traj_lengths) 75 | 76 | 77 | def combine_hdf5( 78 | hdf5_paths, 79 | hdf5_use_swmr, 80 | outlier_traj_length_sd, 81 | x_bounds, 82 | y_bounds, 83 | z_bounds, 84 | val_ratio, 85 | filter_key_prefix, 86 | eef_pos_key, 87 | ): 88 | traj_lengths = [] 89 | outlier_files = Counter() 90 | out_of_bb_files = Counter() 91 | num_outlier_trajectories = 0 92 | num_out_of_bb_trajectories = 0 93 | num_trajectories = 0 94 | for hdf5_path in tqdm(hdf5_paths): 95 | try: 96 | hdf5_file = h5py.File(hdf5_path, "r+", swmr=hdf5_use_swmr, libver="latest") 97 | demo_list = load_demo_info(hdf5_file) 98 | traj_length_mean, traj_length_std = compute_traj_length_stats(hdf5_file, demo_list) 99 | filtered_demos = [] 100 | for ep in demo_list: 101 | eef_pos = hdf5_file["data/{}/obs/{}".format(ep, eef_pos_key)][()] 102 | traj_length = hdf5_file["data"][ep].attrs["num_samples"] 103 | if traj_length > traj_length_mean + outlier_traj_length_sd * traj_length_std: 104 | too_long_traj = True 105 | else: 106 | too_long_traj = False 107 | if x_bounds is not None and y_bounds is not None and z_bounds is not None: 108 | out_of_bb = ( 109 | min(eef_pos[:, 0]) < x_bounds[0] 110 | or max(eef_pos[:, 0]) > x_bounds[1] 111 | or min(eef_pos[:, 1]) < y_bounds[0] 112 | or max(eef_pos[:, 1]) > y_bounds[1] 113 | or min(eef_pos[:, 2]) < z_bounds[0] 114 | or max(eef_pos[:, 2]) > z_bounds[1] 115 | ) 116 | else: 117 | out_of_bb = False 118 | if out_of_bb: 119 | out_of_bb_files[hdf5_path] += 1 120 | num_out_of_bb_trajectories += 1 121 | elif too_long_traj: 122 | outlier_files[hdf5_path] += 1 123 | num_outlier_trajectories += 1 124 | else: 125 | filtered_demos.append(ep) 126 
|                     traj_lengths.append(hdf5_file["data"][ep].attrs["num_samples"])
127 |             num_trajectories += len(demo_list)
128 |             hdf5_file.close()
129 | 
130 |             num_demos = len(filtered_demos)
131 |             # fraction of demos to hold out for validation
132 |             num_val = int(val_ratio * num_demos)
133 |             mask = np.zeros(num_demos)
134 |             mask[:num_val] = 1.0
135 |             np.random.shuffle(mask)
136 |             mask = mask.astype(int)
137 |             train_inds = (1 - mask).nonzero()[0]
138 |             valid_inds = mask.nonzero()[0]
139 |             train_keys = [filtered_demos[i] for i in train_inds]
140 |             valid_keys = [filtered_demos[i] for i in valid_inds]
141 |             for key in train_keys:
142 |                 assert key not in valid_keys
143 |             print(
144 |                 "{} validation demonstrations out of {} total demonstrations.".format(
145 |                     num_val, num_demos
146 |                 )
147 |             )
148 | 
149 |             # pass mask to generate split
150 |             name_1 = f"train{filter_key_prefix}"
151 |             name_2 = f"valid{filter_key_prefix}"
152 | 
153 |             train_lengths = create_hdf5_filter_key(
154 |                 hdf5_path=hdf5_path, demo_keys=train_keys, key_name=name_1
155 |             )
156 |             valid_lengths = create_hdf5_filter_key(
157 |                 hdf5_path=hdf5_path, demo_keys=valid_keys, key_name=name_2
158 |             )
159 |             all_valid_lengths = create_hdf5_filter_key(
160 |                 hdf5_path=hdf5_path, demo_keys=filtered_demos, key_name="all_valid"
161 |             )
162 | 
163 |             print("Total number of train samples: {}".format(np.sum(train_lengths)))
164 |             print("Average number of train samples {}".format(np.mean(train_lengths)))
165 | 
166 |             print("Total number of valid samples: {}".format(np.sum(valid_lengths)))
167 |             print("Average number of valid samples {}".format(np.mean(valid_lengths)))
168 | 
169 |             print("Total number of all valid samples: {}".format(np.sum(all_valid_lengths)))
170 |             print("Average number of all valid samples {}".format(np.mean(all_valid_lengths)))
171 |         except:
172 |             print("failed to load: ", hdf5_path)
173 |             print(traceback.format_exc())
174 |             pass
175 |     print("Num files with > mean + {}*sd samples: ".format(outlier_traj_length_sd), len(outlier_files))
176 |     print(
177 |         "Num trajectories with > mean + {}*sd samples: ".format(outlier_traj_length_sd),
178 |         num_outlier_trajectories,
179 |     )
180 |     print(
181 |         "percentage of outlier trajectories: ",
182 |         num_outlier_trajectories / num_trajectories * 100,
183 |     )
184 |     print()
185 |     print("Num files with out of bb trajectories: ", len(out_of_bb_files))
186 |     print("Num trajectories that go out of bb: ", num_out_of_bb_trajectories)
187 |     print(
188 |         "percentage of out of bb trajectories: ",
189 |         num_out_of_bb_trajectories / num_trajectories * 100,
190 |     )
191 | 
192 |     # print trajectory stats:
193 |     print("min traj length: ", np.min(traj_lengths))
194 |     print("max traj length: ", np.max(traj_lengths))
195 |     print("mean traj length: ", np.mean(traj_lengths))
196 |     print("median traj length: ", np.median(traj_lengths))
197 |     print("std traj length: ", np.std(traj_lengths))
198 | 
199 | 
200 | if __name__ == "__main__":
201 |     import argparse
202 | 
203 |     parser = argparse.ArgumentParser()
204 |     parser.add_argument("--hdf5_paths", nargs="+")
205 |     parser.add_argument("--outlier_traj_length_sd", type=int, default=None)
206 |     parser.add_argument("--x_bounds", nargs=2, type=float, default=None)
207 |     parser.add_argument("--y_bounds", nargs=2, type=float, default=None)
208 |     parser.add_argument("--z_bounds", nargs=2, type=float, default=None)
209 |     parser.add_argument("--val_ratio", type=float, default=0.1)
210 |     parser.add_argument("--filter_key_prefix", type=str, default="")
211 |     parser.add_argument("--eef_pos_key", type=str, default="robot0_eef_pos")
212 | 
213 |     args = parser.parse_args()
214
combine_hdf5( 215 | args.hdf5_paths, 216 | True, 217 | args.outlier_traj_length_sd, 218 | args.x_bounds, 219 | args.y_bounds, 220 | args.z_bounds, 221 | val_ratio=args.val_ratio, 222 | filter_key_prefix=args.filter_key_prefix, 223 | eef_pos_key=args.eef_pos_key, 224 | ) 225 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OPTIMUS: Imitating Task and Motion Planning with Visuomotor Transformers 2 |
3 | 
4 | This repository is the official implementation of *Imitating Task and Motion Planning with Visuomotor Transformers*.
5 | 
6 | [Murtaza Dalal](https://mihdalal.github.io/)$^1$, [Ajay Mandlekar](https://ai.stanford.edu/~amandlek/)$^2$, [Caelan Garrett](http://web.mit.edu/caelan/www/)$^2$, [Ankur Handa](https://ankurhanda.github.io/)$^2$, [Ruslan Salakhutdinov](https://www.cs.cmu.edu/~rsalakhu/)$^1$, [Dieter Fox](https://homes.cs.washington.edu/~fox/)$^2$
7 | 
8 | $^1$ CMU, $^2$ NVIDIA
9 | 
10 | [Project Page](https://mihdalal.github.io/optimus/) | [arXiv](https://arxiv.org/abs/2305.16309) | [Video](https://www.youtube.com/watch?v=2ItlsuNWi6Y)
11 | 
12 | 
13 | 
14 | 
15 | Optimus is a framework for training large scale imitation policies for robotic manipulation by distilling Task and Motion Planning into visuomotor Transformers. In this release we include datasets for replicating our results on Robosuite as well as code for performing TAMP data filtration and training/evaluating visuomotor Transformers on TAMP data. 16 | 17 | If you find this codebase useful in your research, please cite: 18 | ```bibtex 19 | @inproceedings{dalal2023optimus, 20 | title={Imitating Task and Motion Planning with Visuomotor Transformers}, 21 | author={Dalal, Murtaza and Mandlekar, Ajay and Garrett, Caelan and Handa, Ankur and Salakhutdinov, Ruslan and Fox, Dieter}, 22 | journal={Conference on Robot Learning}, 23 | year={2023} 24 | } 25 | ``` 26 | 27 | # Table of Contents 28 | 29 | - [Installation](#installation) 30 | - [Dataset Download](#dataset-download) 31 | - [TAMP Data Cleaning](#tamp-data-cleaning) 32 | - [Model Training](#model-training) 33 | - [Model Inference](#model-inference) 34 | - [Task Visualizations](#task-visualizations) 35 | - [Troubleshooting and Known Issues](#troubleshooting-and-known-issues) 36 | - [Citation](#citation) 37 | 38 | # Installation 39 | To install dependencies, please run the following commands: 40 | ``` 41 | sudo apt-get update 42 | sudo apt-get install -y \ 43 | htop screen tmux \ 44 | sshfs libosmesa6-dev wget curl git \ 45 | libeigen3-dev \ 46 | liborocos-kdl-dev \ 47 | libkdl-parser-dev \ 48 | liburdfdom-dev \ 49 | libnlopt-dev \ 50 | libnlopt-cxx-dev \ 51 | swig \ 52 | python3 \ 53 | python3-pip \ 54 | python3-dev \ 55 | vim \ 56 | git-lfs \ 57 | cmake \ 58 | software-properties-common \ 59 | libxcursor-dev \ 60 | libxrandr-dev \ 61 | libxinerama-dev \ 62 | libxi-dev \ 63 | mesa-common-dev \ 64 | zip \ 65 | unzip \ 66 | make \ 67 | g++ \ 68 | python2.7 \ 69 | wget \ 70 | vulkan-utils \ 71 | mesa-vulkan-drivers \ 72 | apt nano rsync \ 73 | libgl1-mesa-dev libgl1-mesa-glx libglew-dev libosmesa6-dev \ 74 | software-properties-common net-tools unzip virtualenv \ 75 | xpra xserver-xorg-dev libglfw3-dev patchelf python3-pip -y \ 76 | && add-apt-repository -y ppa:openscad/releases && apt-get update && apt-get install -y openscad 77 | ``` 78 | 79 | Please add the following to your bashrc/zshrc: 80 | ``` 81 | export MUJOCO_GL='egl' 82 | WANDB_API_KEY=... 83 | ``` 84 | 85 | To install python requirements: 86 | 87 | ``` 88 | conda create -n optimus python=3.8 89 | conda activate optimus 90 | pip install -r requirements.txt 91 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 92 | pip install -e . 93 | ``` 94 | 95 | # Dataset Download 96 | 97 | #### Method 1: Using `download_datasets.py` (Recommended) 98 | 99 | `download_datasets.py` (located at `optimus/scripts`) is a python script that provides a programmatic way of downloading the datasets. This is the preferred method, because this script also sets up a directory structure for the datasets that works out of the box with the code for reproducing policy learning results. 
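If you prefer to script the download yourself, the helper that `download_datasets.py` calls internally can also be used directly. A minimal sketch (assuming you run it from the repository root; the URL is the `Stack` entry from `dataset_links` in `optimus/scripts/download_datasets.py`):

```
import os

import optimus.utils.file_utils as FileUtils

# directory the training configs expect the datasets to live in
download_dir = "datasets"
os.makedirs(download_dir, exist_ok=True)

# Stack dataset link, copied from dataset_links in optimus/scripts/download_datasets.py
FileUtils.download_url_from_gdrive(
    url="https://drive.google.com/file/d/1ciP5yP0D06gT7QXq7OvxlR-mcepIoHCK/view?usp=drive_link",
    download_dir=download_dir,
    check_overwrite=True,  # prompts before overwriting an existing file
)
```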
100 | 
101 | A few examples of using this script are provided below:
102 | 
103 | ```
104 | # default behavior - just download Stack dataset
105 | python download_datasets.py
106 | 
107 | # download datasets for Stack and Stack Three
108 | python download_datasets.py --tasks Stack StackThree
109 | 
110 | # download all datasets, but do a dry run first to see what will be downloaded and where
111 | python download_datasets.py --tasks all --dry_run
112 | 
113 | # download all datasets for all tasks
114 | python download_datasets.py --tasks all  # this downloads Stack, StackThree, StackFour and StackFive
115 | ```
116 | 
117 | #### Method 2: Using Direct Download Links
118 | 
119 | You can download the datasets manually through Google Drive.
120 | 
121 | **Google Drive folder with all datasets:** [link](https://drive.google.com/drive/folders/1Dfi313igOuvc5JUCMTzQrSixUKndhW__?usp=drive_link)
122 | 
123 | # TAMP Data Cleaning
124 | As described in Section 3.2 of the Optimus paper, we develop two TAMP demonstration filtering strategies to curb variance in the expert supervision: 1) prune out demonstrations with out-of-distribution trajectory lengths, and 2) remove demonstrations that exit the visible workspace. In practice, we remove trajectories whose length is greater than 2 standard deviations above the mean, as well as trajectories that exit a pre-defined workspace covering all regions visible from the fixed camera viewpoint.
125 | 
126 | We include the data filtration code in `filter_trajectories.py` and give example usage below.
127 | ```
128 | # usage:
129 | python optimus/scripts/filter_trajectories.py --hdf5_paths datasets/<>.hdf5 --x_bounds <> <> --y_bounds <> <> --z_bounds <> <> --val_ratio <> --filter_key_prefix <> --outlier_traj_length_sd <>
130 | 
131 | # example
132 | python optimus/scripts/filter_trajectories.py --hdf5_paths datasets/robosuite_stack.hdf5 --outlier_traj_length_sd 2 --x_bounds -0.2 0.2 --y_bounds -0.2 0.2 --z_bounds 0 1.1 --val_ratio 0.1
133 | ```
134 | 
135 | For the datasets that we have released, we have already performed these filtration operations, so you do not need to run the filtering script again. Please do not run the commands above on the released datasets.
136 | 
137 | # Model Training
138 | After downloading the datasets you're interested in with the `download_datasets.py` script, you can train policies using the `pl_train.py` script in `optimus/scripts`. Our training code wraps around [robomimic](https://robomimic.github.io/) (the key difference is that we use PyTorch Lightning); please see the robomimic docs for a detailed overview of the imitation learning code. Following the robomimic format, our training configs are located in `optimus/exps/local/robosuite/`, with a different folder for each environment (stack, stackthree, stackfour, stackfive). We demonstrate example usage below:
139 | ```
140 | # usage:
141 | python optimus/scripts/pl_train.py --config optimus/exps/local/robosuite/<>/bc_transformer.json
142 | 
143 | # example:
144 | python optimus/scripts/pl_train.py --config optimus/exps/local/robosuite/stack/bc_transformer.json
145 | ```
146 | 
147 | # Model Inference
148 | Given a checkpoint from a training run, you can perform inference and evaluate the model with `run_trained_agent_pl.py`. This script is based on `run_trained_agent.py` in [robomimic](https://robomimic.github.io/) but adds support for our PyTorch Lightning checkpoints. Concretely, you need to specify the path to a specific ckpt file (for `--agent`) and the directory which contains `config.json` (for `--resume_dir`).
We demonstrate example usage below:
149 | ```
150 | # usage:
151 | python optimus/scripts/run_trained_agent_pl.py --agent /path/to/trained_model.ckpt --resume_dir /path/to/training_dir --n <> --video_path <>.mp4
152 | 
153 | # example:
154 | python optimus/scripts/run_trained_agent_pl.py --agent optimus/trained_models/robosuite_trained_models/bc_transformer_stack_image/20231026101652/models/model_epoch_50_Stack_success_1.0.ckpt --resume_dir optimus/trained_models/robosuite_trained_models/bc_transformer_stack_image/20231026101652/ --n 10 --video_path stack.mp4
155 | ```
156 | 
157 | # Troubleshooting and Known Issues
158 | 
159 | - If your training seems to be proceeding slowly (especially for image-based agents), it might be a problem with robomimic and more modern versions of PyTorch. We recommend PyTorch 1.12.1 (on Ubuntu, we used `pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113`). It is also a good idea to verify that the GPU is being utilized during training (see the quick check at the end of this section).
160 | - If you run into trouble with installing [egl_probe](https://github.com/StanfordVL/egl_probe) during robomimic installation (e.g. `ERROR: Failed building wheel for egl_probe`), you may need to make sure `cmake` is installed. A simple `pip install cmake` should work.
161 | - If you run into other strange installation issues, one potential fix is to launch a new terminal, activate your conda environment, and try the failing install commands once again. One clue that the current terminal state is corrupt (and that this fix will help) is if you see installations going into a different conda environment than the one you have active.
162 | 
163 | If you run into an error not documented above, please search through the [GitHub issues](https://github.com/NVlabs/optimus/issues), and create a new one if you cannot find a fix.
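As the quick GPU check referenced above (a minimal sketch, assuming the `optimus` conda environment from the installation section is active), you can confirm the recommended PyTorch build and that CUDA is visible before launching training:

```
import torch

# verify the PyTorch build recommended in the troubleshooting note above
print(torch.__version__)          # expect 1.12.1+cu113 with the suggested install
print(torch.cuda.is_available())  # should print True if training can use the GPU
```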
164 | 165 | ## Citation 166 | 167 | Please cite [the Optimus paper](https://arxiv.org/abs/2305.16309) if you use this code in your work: 168 | 169 | ```bibtex 170 | @inproceedings{dalal2023optimus, 171 | title={Imitating Task and Motion Planning with Visuomotor Transformers}, 172 | author={Dalal, Murtaza and Mandlekar, Ajay and Garrett, Caelan and Handa, Ankur and Salakhutdinov, Ruslan and Fox, Dieter}, 173 | journal={Conference on Robot Learning}, 174 | year={2023} 175 | } 176 | ``` 177 | -------------------------------------------------------------------------------- /optimus/exps/local/robosuite/stack/bc_transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "algo_name": "bc", 3 | "experiment": { 4 | "name": "bc_transformer_stack_image", 5 | "validate": true, 6 | "logging": { 7 | "terminal_output_to_txt": false, 8 | "log_tb": true 9 | }, 10 | "save": { 11 | "enabled": true, 12 | "every_n_seconds": null, 13 | "every_n_epochs": 25, 14 | "epochs": [], 15 | "on_best_validation": false, 16 | "on_best_rollout_return": false, 17 | "on_best_rollout_success_rate": true 18 | }, 19 | "epoch_every_n_steps": 1000, 20 | "validation_epoch_every_n_steps": 10000, 21 | "env": null, 22 | "additional_envs": null, 23 | "render": false, 24 | "render_video": true, 25 | "keep_all_videos": false, 26 | "video_skip": 5, 27 | "rollout": { 28 | "enabled": true, 29 | "n": 10, 30 | "horizon": 300, 31 | "rate": 25, 32 | "warmstart": 0, 33 | "terminate_on_success": true, 34 | "parallel_rollouts": false 35 | } 36 | }, 37 | "train": { 38 | "data": "datasets/robosuite_stack.hdf5", 39 | "output_dir": "trained_models/robosuite_trained_models", 40 | "num_data_workers": 4, 41 | "hdf5_cache_mode": "low_dim", 42 | "hdf5_use_swmr": true, 43 | "hdf5_normalize_obs": false, 44 | "hdf5_filter_key": "train", 45 | "seq_length": 1, 46 | "frame_stack": 8, 47 | "pad_frame_stack": true, 48 | "pad_seq_length": false, 49 | "dataset_keys": [ 50 | "actions", 51 | "rewards", 52 | "dones" 53 | ], 54 | "goal_mode": null, 55 | "cuda": true, 56 | "batch_size": 16, 57 | "num_epochs": 1000, 58 | "seed": 1, 59 | "amp_enabled": true, 60 | "max_grad_norm": 1 61 | }, 62 | "algo": { 63 | "optim_params": { 64 | "policy": { 65 | "learning_rate": { 66 | "initial": 0.0001, 67 | "decay_factor": 0.01, 68 | "epoch_schedule": [] 69 | }, 70 | "regularization": { 71 | "L2": 0.0 72 | } 73 | } 74 | }, 75 | "loss": { 76 | "l2_weight": 1.0, 77 | "l1_weight": 0.0, 78 | "cos_weight": 0.0 79 | }, 80 | "actor_layer_dims": [], 81 | "gaussian": { 82 | "enabled": false, 83 | "fixed_std": false, 84 | "init_std": 0.1, 85 | "min_std": 0.01, 86 | "std_activation": "softplus", 87 | "low_noise_eval": true 88 | }, 89 | "gmm": { 90 | "enabled": true, 91 | "num_modes": 5, 92 | "min_std": 0.0001, 93 | "std_activation": "softplus", 94 | "low_noise_eval": true 95 | }, 96 | "vae": { 97 | "enabled": false, 98 | "latent_dim": 14, 99 | "latent_clip": null, 100 | "kl_weight": 1.0, 101 | "decoder": { 102 | "is_conditioned": true, 103 | "reconstruction_sum_across_elements": false 104 | }, 105 | "prior": { 106 | "learn": false, 107 | "is_conditioned": false, 108 | "use_gmm": false, 109 | "gmm_num_modes": 10, 110 | "gmm_learn_weights": false, 111 | "use_categorical": false, 112 | "categorical_dim": 10, 113 | "categorical_gumbel_softmax_hard": false, 114 | "categorical_init_temp": 1.0, 115 | "categorical_temp_anneal_step": 0.001, 116 | "categorical_min_temp": 0.3 117 | }, 118 | "encoder_layer_dims": [ 119 | 300, 120 | 400 121 | ], 122 | 
"decoder_layer_dims": [ 123 | 300, 124 | 400 125 | ], 126 | "prior_layer_dims": [ 127 | 300, 128 | 400 129 | ] 130 | }, 131 | "rnn": { 132 | "enabled": false, 133 | "horizon": 10, 134 | "hidden_dim": 400, 135 | "rnn_type": "LSTM", 136 | "num_layers": 2, 137 | "open_loop": false, 138 | "kwargs": { 139 | "bidirectional": false 140 | } 141 | }, 142 | "transformer": { 143 | "enabled": true, 144 | "condition_on_actions": false, 145 | "context_length": 8, 146 | "sinusoidal_embedding": false, 147 | "relative_timestep": true, 148 | "supervise_all_steps": false, 149 | "nn_parameter_for_timesteps": true, 150 | "latent_dim":0, 151 | "kl_loss_weight": 0.0 152 | } 153 | }, 154 | "observation": { 155 | "modalities": { 156 | "obs": { 157 | "low_dim": [ 158 | "robot0_eef_pos", 159 | "robot0_eef_quat", 160 | "robot0_gripper_qpos", 161 | "timesteps" 162 | ], 163 | "rgb": [ 164 | "agentview_image", 165 | "robot0_eye_in_hand_image" 166 | ], 167 | "depth": [], 168 | "scan": [] 169 | }, 170 | "goal": { 171 | "low_dim": [], 172 | "rgb": [], 173 | "depth": [], 174 | "scan": [] 175 | } 176 | }, 177 | "encoder": { 178 | "low_dim": { 179 | "core_class": null, 180 | "core_kwargs": {}, 181 | "obs_randomizer_class": null, 182 | "obs_randomizer_kwargs": {} 183 | }, 184 | "rgb": { 185 | "core_class": "VisualCore", 186 | "core_kwargs": { 187 | "feature_dimension": 64, 188 | "flatten": true, 189 | "backbone_class": "ResNet18Conv", 190 | "backbone_kwargs": { 191 | "pretrained": false, 192 | "input_coord_conv": false 193 | }, 194 | "pool_class": "SpatialSoftmax", 195 | "pool_kwargs": { 196 | "num_kp": 32, 197 | "learnable_temperature": false, 198 | "temperature": 1.0, 199 | "noise_std": 0.0, 200 | "output_variance": false 201 | } 202 | }, 203 | "obs_randomizer_class": "CropRandomizer", 204 | "obs_randomizer_kwargs": { 205 | "crop_height": 76, 206 | "crop_width": 76, 207 | "num_crops": 1, 208 | "pos_enc": false 209 | } 210 | }, 211 | "depth": { 212 | "core_class": "VisualCore", 213 | "core_kwargs": { 214 | "feature_dimension": 64, 215 | "flatten": true, 216 | "backbone_class": "ResNet18Conv", 217 | "backbone_kwargs": { 218 | "pretrained": false, 219 | "input_coord_conv": false 220 | }, 221 | "pool_class": "SpatialSoftmax", 222 | "pool_kwargs": { 223 | "num_kp": 32, 224 | "learnable_temperature": false, 225 | "temperature": 1.0, 226 | "noise_std": 0.0, 227 | "output_variance": false 228 | } 229 | }, 230 | "obs_randomizer_class": null, 231 | "obs_randomizer_kwargs": { 232 | "crop_height": 76, 233 | "crop_width": 76, 234 | "num_crops": 1, 235 | "pos_enc": false 236 | } 237 | }, 238 | "scan": { 239 | "core_class": "ScanCore", 240 | "core_kwargs": { 241 | "feature_dimension": 64, 242 | "flatten": true, 243 | "pool_class": "SpatialSoftmax", 244 | "pool_kwargs": { 245 | "num_kp": 32, 246 | "learnable_temperature": false, 247 | "temperature": 1.0, 248 | "noise_std": 0.0, 249 | "output_variance": false 250 | }, 251 | "conv_activation": "relu", 252 | "conv_kwargs": { 253 | "out_channels": [ 254 | 32, 255 | 64, 256 | 64 257 | ], 258 | "kernel_size": [ 259 | 8, 260 | 4, 261 | 2 262 | ], 263 | "stride": [ 264 | 4, 265 | 2, 266 | 1 267 | ] 268 | } 269 | }, 270 | "obs_randomizer_class": null, 271 | "obs_randomizer_kwargs": { 272 | "crop_height": 76, 273 | "crop_width": 76, 274 | "num_crops": 1, 275 | "pos_enc": false 276 | } 277 | } 278 | } 279 | } 280 | } 281 | -------------------------------------------------------------------------------- /optimus/exps/local/robosuite/stackfive/bc_transformer.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "algo_name": "bc", 3 | "experiment": { 4 | "name": "bc_transformer_stack_five_image", 5 | "validate": true, 6 | "logging": { 7 | "terminal_output_to_txt": false, 8 | "log_tb": true 9 | }, 10 | "save": { 11 | "enabled": true, 12 | "every_n_seconds": null, 13 | "every_n_epochs": 25, 14 | "epochs": [], 15 | "on_best_validation": false, 16 | "on_best_rollout_return": false, 17 | "on_best_rollout_success_rate": true 18 | }, 19 | "epoch_every_n_steps": 1000, 20 | "validation_epoch_every_n_steps": 10000, 21 | "env": null, 22 | "additional_envs": null, 23 | "render": false, 24 | "render_video": true, 25 | "keep_all_videos": false, 26 | "video_skip": 5, 27 | "rollout": { 28 | "enabled": true, 29 | "n": 50, 30 | "horizon": 1800, 31 | "rate": 25, 32 | "warmstart": 99, 33 | "terminate_on_success": true, 34 | "parallel_rollouts": false 35 | } 36 | }, 37 | "train": { 38 | "data": "datasets/robosuite_stack_five.hdf5", 39 | "output_dir": "trained_models/robosuite_trained_models", 40 | "num_data_workers": 4, 41 | "hdf5_cache_mode": "low_dim", 42 | "hdf5_use_swmr": true, 43 | "hdf5_normalize_obs": false, 44 | "hdf5_filter_key": "train", 45 | "seq_length": 1, 46 | "frame_stack": 8, 47 | "pad_frame_stack": true, 48 | "pad_seq_length": false, 49 | "dataset_keys": [ 50 | "actions", 51 | "rewards", 52 | "dones" 53 | ], 54 | "goal_mode": null, 55 | "cuda": true, 56 | "batch_size": 16, 57 | "num_epochs": 10000, 58 | "seed": 1, 59 | "amp_enabled": true, 60 | "max_grad_norm": 20 61 | }, 62 | "algo": { 63 | "optim_params": { 64 | "policy": { 65 | "learning_rate": { 66 | "initial": 0.0001, 67 | "decay_factor": 0.01, 68 | "epoch_schedule": [] 69 | }, 70 | "regularization": { 71 | "L2": 0.0 72 | } 73 | } 74 | }, 75 | "loss": { 76 | "l2_weight": 1.0, 77 | "l1_weight": 0.0, 78 | "cos_weight": 0.0 79 | }, 80 | "actor_layer_dims": [], 81 | "gaussian": { 82 | "enabled": false, 83 | "fixed_std": false, 84 | "init_std": 0.1, 85 | "min_std": 0.01, 86 | "std_activation": "softplus", 87 | "low_noise_eval": true 88 | }, 89 | "gmm": { 90 | "enabled": true, 91 | "num_modes": 5, 92 | "min_std": 0.0001, 93 | "std_activation": "softplus", 94 | "low_noise_eval": true 95 | }, 96 | "vae": { 97 | "enabled": false, 98 | "latent_dim": 14, 99 | "latent_clip": null, 100 | "kl_weight": 1.0, 101 | "decoder": { 102 | "is_conditioned": true, 103 | "reconstruction_sum_across_elements": false 104 | }, 105 | "prior": { 106 | "learn": false, 107 | "is_conditioned": false, 108 | "use_gmm": false, 109 | "gmm_num_modes": 10, 110 | "gmm_learn_weights": false, 111 | "use_categorical": false, 112 | "categorical_dim": 10, 113 | "categorical_gumbel_softmax_hard": false, 114 | "categorical_init_temp": 1.0, 115 | "categorical_temp_anneal_step": 0.001, 116 | "categorical_min_temp": 0.3 117 | }, 118 | "encoder_layer_dims": [ 119 | 300, 120 | 400 121 | ], 122 | "decoder_layer_dims": [ 123 | 300, 124 | 400 125 | ], 126 | "prior_layer_dims": [ 127 | 300, 128 | 400 129 | ] 130 | }, 131 | "rnn": { 132 | "enabled": false, 133 | "horizon": 10, 134 | "hidden_dim": 400, 135 | "rnn_type": "LSTM", 136 | "num_layers": 2, 137 | "open_loop": false, 138 | "kwargs": { 139 | "bidirectional": false 140 | } 141 | }, 142 | "transformer": { 143 | "enabled": true, 144 | "condition_on_actions": false, 145 | "context_length": 8, 146 | "sinusoidal_embedding": false, 147 | "relative_timestep": true, 148 | "supervise_all_steps": false, 149 | "nn_parameter_for_timesteps": true, 150 | "latent_dim":0, 
151 | "kl_loss_weight": 0.0 152 | } 153 | }, 154 | "observation": { 155 | "modalities": { 156 | "obs": { 157 | "low_dim": [ 158 | "robot0_eef_pos", 159 | "robot0_eef_quat", 160 | "robot0_gripper_qpos", 161 | "timesteps" 162 | ], 163 | "rgb": [ 164 | "agentview_image", 165 | "robot0_eye_in_hand_image" 166 | ], 167 | "depth": [], 168 | "scan": [] 169 | }, 170 | "goal": { 171 | "low_dim": [], 172 | "rgb": [], 173 | "depth": [], 174 | "scan": [] 175 | } 176 | }, 177 | "encoder": { 178 | "low_dim": { 179 | "core_class": null, 180 | "core_kwargs": {}, 181 | "obs_randomizer_class": null, 182 | "obs_randomizer_kwargs": {} 183 | }, 184 | "rgb": { 185 | "core_class": "VisualCore", 186 | "core_kwargs": { 187 | "feature_dimension": 64, 188 | "flatten": true, 189 | "backbone_class": "ResNet18Conv", 190 | "backbone_kwargs": { 191 | "pretrained": false, 192 | "input_coord_conv": false 193 | }, 194 | "pool_class": "SpatialSoftmax", 195 | "pool_kwargs": { 196 | "num_kp": 32, 197 | "learnable_temperature": false, 198 | "temperature": 1.0, 199 | "noise_std": 0.0, 200 | "output_variance": false 201 | } 202 | }, 203 | "obs_randomizer_class": "CropRandomizer", 204 | "obs_randomizer_kwargs": { 205 | "crop_height": 76, 206 | "crop_width": 76, 207 | "num_crops": 1, 208 | "pos_enc": false 209 | } 210 | }, 211 | "depth": { 212 | "core_class": "VisualCore", 213 | "core_kwargs": { 214 | "feature_dimension": 64, 215 | "flatten": true, 216 | "backbone_class": "ResNet18Conv", 217 | "backbone_kwargs": { 218 | "pretrained": false, 219 | "input_coord_conv": false 220 | }, 221 | "pool_class": "SpatialSoftmax", 222 | "pool_kwargs": { 223 | "num_kp": 32, 224 | "learnable_temperature": false, 225 | "temperature": 1.0, 226 | "noise_std": 0.0, 227 | "output_variance": false 228 | } 229 | }, 230 | "obs_randomizer_class": null, 231 | "obs_randomizer_kwargs": { 232 | "crop_height": 76, 233 | "crop_width": 76, 234 | "num_crops": 1, 235 | "pos_enc": false 236 | } 237 | }, 238 | "scan": { 239 | "core_class": "ScanCore", 240 | "core_kwargs": { 241 | "feature_dimension": 64, 242 | "flatten": true, 243 | "pool_class": "SpatialSoftmax", 244 | "pool_kwargs": { 245 | "num_kp": 32, 246 | "learnable_temperature": false, 247 | "temperature": 1.0, 248 | "noise_std": 0.0, 249 | "output_variance": false 250 | }, 251 | "conv_activation": "relu", 252 | "conv_kwargs": { 253 | "out_channels": [ 254 | 32, 255 | 64, 256 | 64 257 | ], 258 | "kernel_size": [ 259 | 8, 260 | 4, 261 | 2 262 | ], 263 | "stride": [ 264 | 4, 265 | 2, 266 | 1 267 | ] 268 | } 269 | }, 270 | "obs_randomizer_class": null, 271 | "obs_randomizer_kwargs": { 272 | "crop_height": 76, 273 | "crop_width": 76, 274 | "num_crops": 1, 275 | "pos_enc": false 276 | } 277 | } 278 | } 279 | } 280 | } 281 | -------------------------------------------------------------------------------- /optimus/exps/local/robosuite/stackfour/bc_transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "algo_name": "bc", 3 | "experiment": { 4 | "name": "bc_transformer_stack_four_image", 5 | "validate": true, 6 | "logging": { 7 | "terminal_output_to_txt": false, 8 | "log_tb": true 9 | }, 10 | "save": { 11 | "enabled": true, 12 | "every_n_seconds": null, 13 | "every_n_epochs": 25, 14 | "epochs": [], 15 | "on_best_validation": false, 16 | "on_best_rollout_return": false, 17 | "on_best_rollout_success_rate": true 18 | }, 19 | "epoch_every_n_steps": 1000, 20 | "validation_epoch_every_n_steps": 10000, 21 | "env": null, 22 | "additional_envs": null, 23 | "render": false, 
24 | "render_video": true, 25 | "keep_all_videos": false, 26 | "video_skip": 5, 27 | "rollout": { 28 | "enabled": true, 29 | "n": 50, 30 | "horizon": 1350, 31 | "rate": 25, 32 | "warmstart": 99, 33 | "terminate_on_success": true, 34 | "parallel_rollouts": false 35 | } 36 | }, 37 | "train": { 38 | "data": "datasets/robosuite_stack_four.hdf5", 39 | "output_dir": "trained_models/robosuite_trained_models", 40 | "num_data_workers": 4, 41 | "hdf5_cache_mode": "low_dim", 42 | "hdf5_use_swmr": true, 43 | "hdf5_normalize_obs": false, 44 | "hdf5_filter_key": "train", 45 | "seq_length": 1, 46 | "frame_stack": 8, 47 | "pad_frame_stack": true, 48 | "pad_seq_length": false, 49 | "dataset_keys": [ 50 | "actions", 51 | "rewards", 52 | "dones" 53 | ], 54 | "goal_mode": null, 55 | "cuda": true, 56 | "batch_size": 16, 57 | "num_epochs": 10000, 58 | "seed": 1, 59 | "amp_enabled": true, 60 | "max_grad_norm": 20 61 | }, 62 | "algo": { 63 | "optim_params": { 64 | "policy": { 65 | "learning_rate": { 66 | "initial": 0.0001, 67 | "decay_factor": 0.01, 68 | "epoch_schedule": [] 69 | }, 70 | "regularization": { 71 | "L2": 0.0 72 | } 73 | } 74 | }, 75 | "loss": { 76 | "l2_weight": 1.0, 77 | "l1_weight": 0.0, 78 | "cos_weight": 0.0 79 | }, 80 | "actor_layer_dims": [], 81 | "gaussian": { 82 | "enabled": false, 83 | "fixed_std": false, 84 | "init_std": 0.1, 85 | "min_std": 0.01, 86 | "std_activation": "softplus", 87 | "low_noise_eval": true 88 | }, 89 | "gmm": { 90 | "enabled": true, 91 | "num_modes": 5, 92 | "min_std": 0.0001, 93 | "std_activation": "softplus", 94 | "low_noise_eval": true 95 | }, 96 | "vae": { 97 | "enabled": false, 98 | "latent_dim": 14, 99 | "latent_clip": null, 100 | "kl_weight": 1.0, 101 | "decoder": { 102 | "is_conditioned": true, 103 | "reconstruction_sum_across_elements": false 104 | }, 105 | "prior": { 106 | "learn": false, 107 | "is_conditioned": false, 108 | "use_gmm": false, 109 | "gmm_num_modes": 10, 110 | "gmm_learn_weights": false, 111 | "use_categorical": false, 112 | "categorical_dim": 10, 113 | "categorical_gumbel_softmax_hard": false, 114 | "categorical_init_temp": 1.0, 115 | "categorical_temp_anneal_step": 0.001, 116 | "categorical_min_temp": 0.3 117 | }, 118 | "encoder_layer_dims": [ 119 | 300, 120 | 400 121 | ], 122 | "decoder_layer_dims": [ 123 | 300, 124 | 400 125 | ], 126 | "prior_layer_dims": [ 127 | 300, 128 | 400 129 | ] 130 | }, 131 | "rnn": { 132 | "enabled": false, 133 | "horizon": 10, 134 | "hidden_dim": 400, 135 | "rnn_type": "LSTM", 136 | "num_layers": 2, 137 | "open_loop": false, 138 | "kwargs": { 139 | "bidirectional": false 140 | } 141 | }, 142 | "transformer": { 143 | "enabled": true, 144 | "condition_on_actions": false, 145 | "context_length": 8, 146 | "sinusoidal_embedding": false, 147 | "relative_timestep": true, 148 | "supervise_all_steps": false, 149 | "nn_parameter_for_timesteps": true, 150 | "latent_dim":0, 151 | "kl_loss_weight": 0.0 152 | } 153 | }, 154 | "observation": { 155 | "modalities": { 156 | "obs": { 157 | "low_dim": [ 158 | "robot0_eef_pos", 159 | "robot0_eef_quat", 160 | "robot0_gripper_qpos", 161 | "timesteps" 162 | ], 163 | "rgb": [ 164 | "agentview_image", 165 | "robot0_eye_in_hand_image" 166 | ], 167 | "depth": [], 168 | "scan": [] 169 | }, 170 | "goal": { 171 | "low_dim": [], 172 | "rgb": [], 173 | "depth": [], 174 | "scan": [] 175 | } 176 | }, 177 | "encoder": { 178 | "low_dim": { 179 | "core_class": null, 180 | "core_kwargs": {}, 181 | "obs_randomizer_class": null, 182 | "obs_randomizer_kwargs": {} 183 | }, 184 | "rgb": { 185 | "core_class": 
"VisualCore", 186 | "core_kwargs": { 187 | "feature_dimension": 64, 188 | "flatten": true, 189 | "backbone_class": "ResNet18Conv", 190 | "backbone_kwargs": { 191 | "pretrained": false, 192 | "input_coord_conv": false 193 | }, 194 | "pool_class": "SpatialSoftmax", 195 | "pool_kwargs": { 196 | "num_kp": 32, 197 | "learnable_temperature": false, 198 | "temperature": 1.0, 199 | "noise_std": 0.0, 200 | "output_variance": false 201 | } 202 | }, 203 | "obs_randomizer_class": "CropRandomizer", 204 | "obs_randomizer_kwargs": { 205 | "crop_height": 76, 206 | "crop_width": 76, 207 | "num_crops": 1, 208 | "pos_enc": false 209 | } 210 | }, 211 | "depth": { 212 | "core_class": "VisualCore", 213 | "core_kwargs": { 214 | "feature_dimension": 64, 215 | "flatten": true, 216 | "backbone_class": "ResNet18Conv", 217 | "backbone_kwargs": { 218 | "pretrained": false, 219 | "input_coord_conv": false 220 | }, 221 | "pool_class": "SpatialSoftmax", 222 | "pool_kwargs": { 223 | "num_kp": 32, 224 | "learnable_temperature": false, 225 | "temperature": 1.0, 226 | "noise_std": 0.0, 227 | "output_variance": false 228 | } 229 | }, 230 | "obs_randomizer_class": null, 231 | "obs_randomizer_kwargs": { 232 | "crop_height": 76, 233 | "crop_width": 76, 234 | "num_crops": 1, 235 | "pos_enc": false 236 | } 237 | }, 238 | "scan": { 239 | "core_class": "ScanCore", 240 | "core_kwargs": { 241 | "feature_dimension": 64, 242 | "flatten": true, 243 | "pool_class": "SpatialSoftmax", 244 | "pool_kwargs": { 245 | "num_kp": 32, 246 | "learnable_temperature": false, 247 | "temperature": 1.0, 248 | "noise_std": 0.0, 249 | "output_variance": false 250 | }, 251 | "conv_activation": "relu", 252 | "conv_kwargs": { 253 | "out_channels": [ 254 | 32, 255 | 64, 256 | 64 257 | ], 258 | "kernel_size": [ 259 | 8, 260 | 4, 261 | 2 262 | ], 263 | "stride": [ 264 | 4, 265 | 2, 266 | 1 267 | ] 268 | } 269 | }, 270 | "obs_randomizer_class": null, 271 | "obs_randomizer_kwargs": { 272 | "crop_height": 76, 273 | "crop_width": 76, 274 | "num_crops": 1, 275 | "pos_enc": false 276 | } 277 | } 278 | } 279 | } 280 | } 281 | -------------------------------------------------------------------------------- /optimus/exps/local/robosuite/stackthree/bc_transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "algo_name": "bc", 3 | "experiment": { 4 | "name": "bc_transformer_stack_three_image", 5 | "validate": true, 6 | "logging": { 7 | "terminal_output_to_txt": false, 8 | "log_tb": true 9 | }, 10 | "save": { 11 | "enabled": true, 12 | "every_n_seconds": null, 13 | "every_n_epochs": 25, 14 | "epochs": [], 15 | "on_best_validation": false, 16 | "on_best_rollout_return": false, 17 | "on_best_rollout_success_rate": true 18 | }, 19 | "epoch_every_n_steps": 1000, 20 | "validation_epoch_every_n_steps": 10000, 21 | "env": null, 22 | "additional_envs": null, 23 | "render": false, 24 | "render_video": true, 25 | "keep_all_videos": false, 26 | "video_skip": 5, 27 | "rollout": { 28 | "enabled": true, 29 | "n": 50, 30 | "horizon": 900, 31 | "rate": 25, 32 | "warmstart": 99, 33 | "terminate_on_success": true, 34 | "parallel_rollouts": false 35 | } 36 | }, 37 | "train": { 38 | "data": "datasets/robosuite_stack_three.hdf5", 39 | "output_dir": "trained_models/robosuite_trained_models", 40 | "num_data_workers": 4, 41 | "hdf5_cache_mode": "low_dim", 42 | "hdf5_use_swmr": true, 43 | "hdf5_normalize_obs": false, 44 | "hdf5_filter_key": "train", 45 | "seq_length": 1, 46 | "frame_stack": 8, 47 | "pad_frame_stack": true, 48 | "pad_seq_length": 
false, 49 | "dataset_keys": [ 50 | "actions", 51 | "rewards", 52 | "dones" 53 | ], 54 | "goal_mode": null, 55 | "cuda": true, 56 | "batch_size": 16, 57 | "num_epochs": 10000, 58 | "seed": 1, 59 | "amp_enabled": true, 60 | "max_grad_norm": 20 61 | }, 62 | "algo": { 63 | "optim_params": { 64 | "policy": { 65 | "learning_rate": { 66 | "initial": 0.0001, 67 | "decay_factor": 0.01, 68 | "epoch_schedule": [] 69 | }, 70 | "regularization": { 71 | "L2": 0.0 72 | } 73 | } 74 | }, 75 | "loss": { 76 | "l2_weight": 1.0, 77 | "l1_weight": 0.0, 78 | "cos_weight": 0.0 79 | }, 80 | "actor_layer_dims": [], 81 | "gaussian": { 82 | "enabled": false, 83 | "fixed_std": false, 84 | "init_std": 0.1, 85 | "min_std": 0.01, 86 | "std_activation": "softplus", 87 | "low_noise_eval": true 88 | }, 89 | "gmm": { 90 | "enabled": true, 91 | "num_modes": 5, 92 | "min_std": 0.0001, 93 | "std_activation": "softplus", 94 | "low_noise_eval": true 95 | }, 96 | "vae": { 97 | "enabled": false, 98 | "latent_dim": 14, 99 | "latent_clip": null, 100 | "kl_weight": 1.0, 101 | "decoder": { 102 | "is_conditioned": true, 103 | "reconstruction_sum_across_elements": false 104 | }, 105 | "prior": { 106 | "learn": false, 107 | "is_conditioned": false, 108 | "use_gmm": false, 109 | "gmm_num_modes": 10, 110 | "gmm_learn_weights": false, 111 | "use_categorical": false, 112 | "categorical_dim": 10, 113 | "categorical_gumbel_softmax_hard": false, 114 | "categorical_init_temp": 1.0, 115 | "categorical_temp_anneal_step": 0.001, 116 | "categorical_min_temp": 0.3 117 | }, 118 | "encoder_layer_dims": [ 119 | 300, 120 | 400 121 | ], 122 | "decoder_layer_dims": [ 123 | 300, 124 | 400 125 | ], 126 | "prior_layer_dims": [ 127 | 300, 128 | 400 129 | ] 130 | }, 131 | "rnn": { 132 | "enabled": false, 133 | "horizon": 10, 134 | "hidden_dim": 400, 135 | "rnn_type": "LSTM", 136 | "num_layers": 2, 137 | "open_loop": false, 138 | "kwargs": { 139 | "bidirectional": false 140 | } 141 | }, 142 | "transformer": { 143 | "enabled": true, 144 | "condition_on_actions": false, 145 | "context_length": 8, 146 | "sinusoidal_embedding": false, 147 | "relative_timestep": true, 148 | "supervise_all_steps": false, 149 | "nn_parameter_for_timesteps": true, 150 | "latent_dim":0, 151 | "kl_loss_weight": 0.0 152 | } 153 | }, 154 | "observation": { 155 | "modalities": { 156 | "obs": { 157 | "low_dim": [ 158 | "robot0_eef_pos", 159 | "robot0_eef_quat", 160 | "robot0_gripper_qpos", 161 | "timesteps" 162 | ], 163 | "rgb": [ 164 | "agentview_image", 165 | "robot0_eye_in_hand_image" 166 | ], 167 | "depth": [], 168 | "scan": [] 169 | }, 170 | "goal": { 171 | "low_dim": [], 172 | "rgb": [], 173 | "depth": [], 174 | "scan": [] 175 | } 176 | }, 177 | "encoder": { 178 | "low_dim": { 179 | "core_class": null, 180 | "core_kwargs": {}, 181 | "obs_randomizer_class": null, 182 | "obs_randomizer_kwargs": {} 183 | }, 184 | "rgb": { 185 | "core_class": "VisualCore", 186 | "core_kwargs": { 187 | "feature_dimension": 64, 188 | "flatten": true, 189 | "backbone_class": "ResNet18Conv", 190 | "backbone_kwargs": { 191 | "pretrained": false, 192 | "input_coord_conv": false 193 | }, 194 | "pool_class": "SpatialSoftmax", 195 | "pool_kwargs": { 196 | "num_kp": 32, 197 | "learnable_temperature": false, 198 | "temperature": 1.0, 199 | "noise_std": 0.0, 200 | "output_variance": false 201 | } 202 | }, 203 | "obs_randomizer_class": "CropRandomizer", 204 | "obs_randomizer_kwargs": { 205 | "crop_height": 76, 206 | "crop_width": 76, 207 | "num_crops": 1, 208 | "pos_enc": false 209 | } 210 | }, 211 | "depth": { 212 | 
"core_class": "VisualCore", 213 | "core_kwargs": { 214 | "feature_dimension": 64, 215 | "flatten": true, 216 | "backbone_class": "ResNet18Conv", 217 | "backbone_kwargs": { 218 | "pretrained": false, 219 | "input_coord_conv": false 220 | }, 221 | "pool_class": "SpatialSoftmax", 222 | "pool_kwargs": { 223 | "num_kp": 32, 224 | "learnable_temperature": false, 225 | "temperature": 1.0, 226 | "noise_std": 0.0, 227 | "output_variance": false 228 | } 229 | }, 230 | "obs_randomizer_class": null, 231 | "obs_randomizer_kwargs": { 232 | "crop_height": 76, 233 | "crop_width": 76, 234 | "num_crops": 1, 235 | "pos_enc": false 236 | } 237 | }, 238 | "scan": { 239 | "core_class": "ScanCore", 240 | "core_kwargs": { 241 | "feature_dimension": 64, 242 | "flatten": true, 243 | "pool_class": "SpatialSoftmax", 244 | "pool_kwargs": { 245 | "num_kp": 32, 246 | "learnable_temperature": false, 247 | "temperature": 1.0, 248 | "noise_std": 0.0, 249 | "output_variance": false 250 | }, 251 | "conv_activation": "relu", 252 | "conv_kwargs": { 253 | "out_channels": [ 254 | 32, 255 | 64, 256 | 64 257 | ], 258 | "kernel_size": [ 259 | 8, 260 | 4, 261 | 2 262 | ], 263 | "stride": [ 264 | 4, 265 | 2, 266 | 1 267 | ] 268 | } 269 | }, 270 | "obs_randomizer_class": null, 271 | "obs_randomizer_kwargs": { 272 | "crop_height": 76, 273 | "crop_width": 76, 274 | "num_crops": 1, 275 | "pos_enc": false 276 | } 277 | } 278 | } 279 | } 280 | } 281 | -------------------------------------------------------------------------------- /optimus/utils/env_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | """ 5 | This file contains several utility functions for working with environment 6 | wrappers provided by the repository, and with environment metadata saved 7 | in dataset files. 8 | """ 9 | from copy import deepcopy 10 | 11 | import robomimic.envs.env_base as EB 12 | 13 | 14 | def get_env_class(env_meta=None, env_type=None, env=None): 15 | """ 16 | Return env class from either env_meta, env_type, or env. 17 | Note the use of lazy imports - this ensures that modules are only 18 | imported when the corresponding env type is requested. This can 19 | be useful in practice. For example, a training run that only 20 | requires access to gym environments should not need to import 21 | robosuite. 22 | 23 | Args: 24 | env_meta (dict): environment metadata, which should be loaded from demonstration 25 | hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see 26 | @FileUtils.env_from_checkpoint). Contains 3 keys: 27 | 28 | :`'env_name'`: name of environment 29 | :`'type'`: type of environment, should be a value in EB.EnvType 30 | :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor 31 | 32 | env_type (int): the type of environment, which determines the env class that will 33 | be instantiated. Should be a value in EB.EnvType. 
34 | 35 | env (instance of EB.EnvBase): environment instance 36 | """ 37 | env_type = get_env_type(env_meta=env_meta, env_type=env_type, env=env) 38 | if env_type == EB.EnvType.ROBOSUITE_TYPE: 39 | from optimus.envs.env_robosuite import EnvRobosuite 40 | 41 | return EnvRobosuite 42 | elif env_type == EB.EnvType.GYM_TYPE: 43 | from robomimic.envs.env_gym import EnvGym 44 | 45 | return EnvGym 46 | elif env_type == EB.EnvType.IG_MOMART_TYPE: 47 | from robomimic.envs.env_ig_momart import EnvGibsonMOMART 48 | 49 | return EnvGibsonMOMART 50 | raise Exception("code should never reach this point") 51 | 52 | 53 | def get_env_type(env_meta=None, env_type=None, env=None): 54 | """ 55 | Helper function to get env_type from a variety of inputs. 56 | 57 | Args: 58 | env_meta (dict): environment metadata, which should be loaded from demonstration 59 | hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see 60 | @FileUtils.env_from_checkpoint). Contains 3 keys: 61 | 62 | :`'env_name'`: name of environment 63 | :`'type'`: type of environment, should be a value in EB.EnvType 64 | :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor 65 | 66 | env_type (int): the type of environment, which determines the env class that will 67 | be instantiated. Should be a value in EB.EnvType. 68 | 69 | env (instance of EB.EnvBase): environment instance 70 | """ 71 | checks = [(env_meta is not None), (env_type is not None), (env is not None)] 72 | assert sum(checks) == 1, "should provide only one of env_meta, env_type, env" 73 | if env_meta is not None: 74 | env_type = env_meta["type"] 75 | elif env is not None: 76 | env_type = env.type 77 | return env_type 78 | 79 | 80 | def check_env_type(type_to_check, env_meta=None, env_type=None, env=None): 81 | """ 82 | Checks whether the passed env_meta, env_type, or env is of type @type_to_check. 83 | Type corresponds to EB.EnvType. 84 | 85 | Args: 86 | type_to_check (int): type to check equality against 87 | 88 | env_meta (dict): environment metadata, which should be loaded from demonstration 89 | hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see 90 | @FileUtils.env_from_checkpoint). Contains 3 keys: 91 | 92 | :`'env_name'`: name of environment 93 | :`'type'`: type of environment, should be a value in EB.EnvType 94 | :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor 95 | 96 | env_type (int): the type of environment, which determines the env class that will 97 | be instantiated. Should be a value in EB.EnvType. 98 | 99 | env (instance of EB.EnvBase): environment instance 100 | """ 101 | env_type = get_env_type(env_meta=env_meta, env_type=env_type, env=env) 102 | return env_type == type_to_check 103 | 104 | 105 | def is_robosuite_env(env_meta=None, env_type=None, env=None): 106 | """ 107 | Determines whether the environment is a robosuite environment. Accepts 108 | either env_meta, env_type, or env. 109 | """ 110 | return check_env_type( 111 | type_to_check=EB.EnvType.ROBOSUITE_TYPE, 112 | env_meta=env_meta, 113 | env_type=env_type, 114 | env=env, 115 | ) 116 | 117 | 118 | def create_env( 119 | env_type, 120 | env_name, 121 | render=False, 122 | render_offscreen=False, 123 | use_image_obs=False, 124 | env_meta=None, 125 | **kwargs, 126 | ): 127 | """ 128 | Create environment. 129 | 130 | Args: 131 | env_type (int): the type of environment, which determines the env class that will 132 | be instantiated. Should be a value in EB.EnvType. 
133 | 134 | env_name (str): name of environment 135 | 136 | render (bool): if True, environment supports on-screen rendering 137 | 138 | render_offscreen (bool): if True, environment supports off-screen rendering. This 139 | is forced to be True if @use_image_obs is True. 140 | 141 | use_image_obs (bool): if True, environment is expected to render rgb image observations 142 | on every env.step call. Set this to False for efficiency reasons, if image 143 | observations are not required. 144 | """ 145 | 146 | # note: pass @postprocess_visual_obs True, to make sure images are processed for network inputs 147 | env_class = get_env_class(env_meta=env_meta) 148 | env = env_class( 149 | env_name=env_name, 150 | render=render, 151 | render_offscreen=render_offscreen, 152 | use_image_obs=use_image_obs, 153 | postprocess_visual_obs=True, 154 | **kwargs, 155 | ) 156 | print("Created environment with name {}".format(env_name)) 157 | print("Action size is {}".format(env.action_dimension)) 158 | return env 159 | 160 | 161 | def create_env_from_metadata( 162 | env_meta, 163 | env_name=None, 164 | render=False, 165 | render_offscreen=False, 166 | use_image_obs=False, 167 | ): 168 | """ 169 | Create environment. 170 | 171 | Args: 172 | env_meta (dict): environment metadata, which should be loaded from demonstration 173 | hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see 174 | @FileUtils.env_from_checkpoint). Contains 3 keys: 175 | 176 | :`'env_name'`: name of environment 177 | :`'type'`: type of environment, should be a value in EB.EnvType 178 | :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor 179 | 180 | env_name (str): name of environment. Only needs to be provided if making a different 181 | environment from the one in @env_meta. 182 | 183 | render (bool): if True, environment supports on-screen rendering 184 | 185 | render_offscreen (bool): if True, environment supports off-screen rendering. This 186 | is forced to be True if @use_image_obs is True. 187 | 188 | use_image_obs (bool): if True, environment is expected to render rgb image observations 189 | on every env.step call. Set this to False for efficiency reasons, if image 190 | observations are not required. 191 | """ 192 | if env_name is None: 193 | env_name = env_meta["env_name"] 194 | env_type = get_env_type(env_meta=env_meta) 195 | env_kwargs = env_meta["env_kwargs"] 196 | 197 | env = create_env( 198 | env_type=env_type, 199 | env_name=env_name, 200 | render=render, 201 | render_offscreen=render_offscreen, 202 | use_image_obs=use_image_obs, 203 | env_meta=env_meta, 204 | **env_kwargs, 205 | ) 206 | return env 207 | 208 | 209 | def create_env_for_data_processing( 210 | env_meta, 211 | camera_names, 212 | camera_height, 213 | camera_width, 214 | reward_shaping, 215 | ): 216 | """ 217 | Creates environment for processing dataset observations and rewards. 218 | 219 | Args: 220 | env_meta (dict): environment metadata, which should be loaded from demonstration 221 | hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see 222 | @FileUtils.env_from_checkpoint). 
Contains 3 keys: 223 | 224 | :`'env_name'`: name of environment 225 | :`'type'`: type of environment, should be a value in EB.EnvType 226 | :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor 227 | 228 | camera_names (list of str): list of camera names that correspond to image observations 229 | 230 | camera_height (int): camera height for all cameras 231 | 232 | camera_width (int): camera width for all cameras 233 | 234 | reward_shaping (bool): if True, use shaped environment rewards, else use sparse task completion rewards 235 | """ 236 | env_name = env_meta["env_name"] 237 | env_type = get_env_type(env_meta=env_meta) 238 | env_kwargs = env_meta["env_kwargs"] 239 | env_class = get_env_class(env_meta=env_meta) 240 | 241 | # remove possibly redundant values in kwargs 242 | env_kwargs = deepcopy(env_kwargs) 243 | env_kwargs.pop("env_name", None) 244 | env_kwargs.pop("camera_names", None) 245 | env_kwargs.pop("camera_height", None) 246 | env_kwargs.pop("camera_width", None) 247 | env_kwargs.pop("reward_shaping", None) 248 | 249 | return env_class.create_for_data_processing( 250 | env_name=env_name, 251 | camera_names=camera_names, 252 | camera_height=camera_height, 253 | camera_width=camera_width, 254 | reward_shaping=reward_shaping, 255 | **env_kwargs, 256 | ) 257 | -------------------------------------------------------------------------------- /optimus/config/bc_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | """ 5 | Config for BC algorithm. Taken from Ajay Mandlekar's private version of robomimic. 6 | """ 7 | 8 | from robomimic.config.base_config import BaseConfig 9 | 10 | 11 | class BCConfig(BaseConfig): 12 | ALGO_NAME = "bc" 13 | 14 | def algo_config(self): 15 | """ 16 | This function populates the `config.algo` attribute of the config, and is given to the 17 | `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config` 18 | argument to the constructor. Any parameter that an algorithm needs to determine its 19 | training and test-time behavior should be populated here.
20 | """ 21 | 22 | # optimization parameters 23 | self.algo.optim_params.policy.learning_rate.initial = 1e-4 # policy learning rate 24 | self.algo.optim_params.policy.learning_rate.decay_factor = ( 25 | 0.1 # factor to decay LR by (if epoch schedule non-empty) 26 | ) 27 | self.algo.optim_params.policy.learning_rate.epoch_schedule = ( 28 | [] 29 | ) # epochs where LR decay occurs 30 | self.algo.optim_params.policy.regularization.L2 = 0.00 # L2 regularization strength 31 | self.algo.optim_params.policy.learning_rate.betas = (0.9, 0.99) 32 | self.algo.optim_params.policy.learning_rate.lr_scheduler_interval = "epoch" 33 | 34 | # loss weights 35 | self.algo.loss.l2_weight = 1.0 # L2 loss weight 36 | self.algo.loss.l1_weight = 0.0 # L1 loss weight 37 | self.algo.loss.cos_weight = 0.0 # cosine loss weight 38 | 39 | # MLP network architecture (layers after observation encoder and RNN, if present) 40 | self.algo.actor_layer_dims = (1024, 1024) 41 | 42 | # stochastic Gaussian policy settings 43 | self.algo.gaussian.enabled = False # whether to train a Gaussian policy 44 | self.algo.gaussian.fixed_std = False # whether to train std output or keep it constant 45 | self.algo.gaussian.init_std = 0.1 # initial standard deviation (or constant) 46 | self.algo.gaussian.min_std = 0.01 # minimum std output from network 47 | self.algo.gaussian.std_activation = ( 48 | "softplus" # activation to use for std output from policy net 49 | ) 50 | self.algo.gaussian.low_noise_eval = True # low-std at test-time 51 | 52 | # stochastic GMM policy settings 53 | self.algo.gmm.enabled = False # whether to train a GMM policy 54 | self.algo.gmm.num_modes = 5 # number of GMM modes 55 | self.algo.gmm.min_std = 0.0001 # minimum std output from network 56 | self.algo.gmm.std_activation = ( 57 | "softplus" # activation to use for std output from policy net 58 | ) 59 | self.algo.gmm.low_noise_eval = True # low-std at test-time 60 | 61 | # stochastic VAE policy settings 62 | self.algo.vae.enabled = False # whether to train a VAE policy 63 | self.algo.vae.latent_dim = ( 64 | 14 # VAE latent dimension - set to twice the dimensionality of action space 65 | ) 66 | self.algo.vae.latent_clip = None # clip latent space when decoding (set to None to disable) 67 | self.algo.vae.kl_weight = ( 68 | 1.0 # beta-VAE weight to scale KL loss relative to reconstruction loss in ELBO 69 | ) 70 | 71 | # VAE decoder settings 72 | self.algo.vae.decoder.is_conditioned = ( 73 | True # whether decoder should condition on observation 74 | ) 75 | self.algo.vae.decoder.reconstruction_sum_across_elements = ( 76 | False # sum instead of mean for reconstruction loss 77 | ) 78 | 79 | # VAE prior settings 80 | self.algo.vae.prior.learn = False # learn Gaussian / GMM prior instead of N(0, 1) 81 | self.algo.vae.prior.is_conditioned = False # whether to condition prior on observations 82 | self.algo.vae.prior.use_gmm = False # whether to use GMM prior 83 | self.algo.vae.prior.gmm_num_modes = 10 # number of GMM modes 84 | self.algo.vae.prior.gmm_learn_weights = False # whether to learn GMM weights 85 | self.algo.vae.prior.use_categorical = False # whether to use categorical prior 86 | self.algo.vae.prior.categorical_dim = ( 87 | 10 # the number of categorical classes for each latent dimension 88 | ) 89 | self.algo.vae.prior.categorical_gumbel_softmax_hard = ( 90 | False # use hard selection in forward pass 91 | ) 92 | self.algo.vae.prior.categorical_init_temp = 1.0 # initial gumbel-softmax temp 93 | self.algo.vae.prior.categorical_temp_anneal_step = 0.001 # linear temp
annealing rate 94 | self.algo.vae.prior.categorical_min_temp = 0.3 # lowest gumbel-softmax temp 95 | 96 | self.algo.vae.encoder_layer_dims = (300, 400) # encoder MLP layer dimensions 97 | self.algo.vae.decoder_layer_dims = (300, 400) # decoder MLP layer dimensions 98 | self.algo.vae.prior_layer_dims = ( 99 | 300, 100 | 400, 101 | ) # prior MLP layer dimensions (if learning conditioned prior) 102 | 103 | # RNN policy settings 104 | self.algo.rnn.enabled = False # whether to train RNN policy 105 | self.algo.rnn.horizon = 10 # unroll length for RNN - should usually match train.seq_length 106 | self.algo.rnn.hidden_dim = 400 # hidden dimension size 107 | self.algo.rnn.rnn_type = "LSTM" # rnn type - one of "LSTM" or "GRU" 108 | self.algo.rnn.num_layers = 2 # number of RNN layers that are stacked 109 | self.algo.rnn.open_loop = False # if True, action predictions are only based on a single observation (not sequence) 110 | self.algo.rnn.kwargs.bidirectional = False # rnn kwargs 111 | self.algo.rnn.kwargs.do_not_lock_keys() 112 | 113 | # Transformer policy settings 114 | self.algo.transformer.enabled = False # whether to train transformer policy 115 | self.algo.transformer.context_length = 64 # length of (s, a) sequences to feed to transformer - should usually match train.frame_stack 116 | self.algo.transformer.embed_dim = 256 # dimension for embeddings used by transformer 117 | self.algo.transformer.num_layers = 6 # number of transformer blocks to stack 118 | self.algo.transformer.num_heads = 8 # number of attention heads for each transformer block (should divide embed_dim evenly) 119 | self.algo.transformer.embedding_dropout = ( 120 | 0.1 # dropout probability for embedding inputs in transformer 121 | ) 122 | self.algo.transformer.block_attention_dropout = ( 123 | 0.1 # dropout probability for attention outputs for each transformer block 124 | ) 125 | self.algo.transformer.block_output_dropout = ( 126 | 0.1 # dropout probability for final outputs for each transformer block 127 | ) 128 | self.algo.transformer.condition_on_actions = False # whether to condition on the sequence of past actions in addition to the observation sequence 129 | self.algo.transformer.predict_obs = ( 130 | False # whether to predict observations in the output sequences as well 131 | ) 132 | self.algo.transformer.mask_inputs = False # whether to use BERT-style input masking 133 | self.algo.transformer.relative_timestep = True # if true timesteps range from 0 to context length-1, if false use absolute position in trajectory 134 | self.algo.transformer.euclidean_distance_timestep = False # if true timesteps are based on the cumulative distance traveled by the end effector.
otherwise integer timesteps 135 | self.algo.transformer.max_timestep = ( 136 | 1250 # for the nn.embedding layer, must know the maximal timestep value 137 | ) 138 | self.algo.transformer.open_loop_predictions = ( 139 | False # if true don't run transformer at every step, execute a set of predicted actions 140 | ) 141 | self.algo.transformer.sinusoidal_embedding = ( 142 | False # if True, use standard positional encodings (sin/cos) 143 | ) 144 | self.algo.transformer.obs_noise_scale = ( 145 | 0.05 # amount of noise to add to the observations during training 146 | ) 147 | self.algo.transformer.add_noise_to_train_obs = ( 148 | False # if true add noise to the observations during training 149 | ) 150 | self.algo.transformer.use_custom_transformer_block = ( 151 | True # if True, use custom transformer block 152 | ) 153 | self.algo.transformer.optimizer_type = "adamw" # toggle the type of optimizer to use 154 | self.algo.transformer.lr_scheduler_type = "linear" # toggle the type of lr_scheduler to use 155 | self.algo.transformer.lr_warmup = ( 156 | False # if True, warmup the learning rate from some small value 157 | ) 158 | self.algo.transformer.activation = ( 159 | "gelu" # activation function for MLP in Transformer Block 160 | ) 161 | self.algo.transformer.num_open_loop_actions_to_execute = ( 162 | 10 # number of actions to execute in open loop 163 | ) 164 | self.algo.transformer.supervise_all_steps = ( 165 | False # if true, supervise all intermediate actions, otherwise only final one 166 | ) 167 | self.algo.transformer.nn_parameter_for_timesteps = ( 168 | True # if true, use nn.Parameter otherwise use nn.Embedding 169 | ) 170 | self.algo.transformer.num_task_ids = 1 # number of tasks we are training with 171 | self.algo.transformer.task_id_embed_dim = 0 # dimension of the task id embedding 172 | self.algo.transformer.language_enabled = False # if true, condition on language embeddings 173 | self.algo.transformer.language_embedding = ( 174 | "raw" # string denoting the language embedding to use 175 | ) 176 | self.algo.transformer.finetune_language_embedding = ( 177 | False # if true, finetune the language embedding 178 | ) 179 | self.algo.transformer.kl_loss_weight = 0 # 5e-6 # weight of the KL loss 180 | self.algo.transformer.use_cvae = True 181 | self.algo.transformer.predict_signature = False 182 | self.algo.transformer.layer_dims = (1024, 1024) 183 | self.algo.transformer.latent_dim = 0 184 | self.algo.transformer.prior_use_gmm = True 185 | self.algo.transformer.prior_gmm_num_modes = 10 186 | self.algo.transformer.prior_gmm_learn_weights = True 187 | self.algo.transformer.replan_every_step = False 188 | self.algo.transformer.decoder = False 189 | self.algo.transformer.prior_use_categorical = False 190 | self.algo.transformer.prior_categorical_gumbel_softmax_hard = False 191 | self.algo.transformer.prior_categorical_dim = 10 192 | self.algo.transformer.primitive_type = "none" 193 | self.algo.transformer.reset_context_after_primitive_exec = True 194 | self.algo.transformer.block_drop_path = 0.0 195 | self.algo.transformer.use_cross_attention_conditioning = False 196 | self.algo.transformer.use_alternating_cross_attention_conditioning = False 197 | self.algo.transformer.key_value_from_condition = False 198 | self.algo.transformer.add_primitive_id = False 199 | self.algo.transformer.tokenize_primitive_id = False 200 | self.algo.transformer.channel_condition = False 201 | self.algo.transformer.tokenize_obs_components = False 202 | self.algo.transformer.num_patches_per_image_dim = 1 203 | 
self.algo.transformer.nets_to_freeze = () 204 | self.algo.transformer.use_ndp_decoder = False 205 | self.algo.transformer.ndp_decoder_kwargs = None 206 | self.algo.transformer.transformer_type = "gpt" 207 | self.algo.transformer.mega_kwargs = {} 208 | 209 | self.algo.corner_loss.enabled = False 210 | 211 | self.algo.dml.enabled = ( 212 | False # if True, use Discretized Mixture of Logistics output distribution 213 | ) 214 | self.algo.dml.num_modes = 2 # number of modes in the DML distribution 215 | self.algo.dml.num_classes = ( 216 | 256 # number of classes in each dimension of the discretized action space 217 | ) 218 | self.algo.dml.log_scale_min = -7.0 # minimum value of the log scale 219 | self.algo.dml.constant_variance = True # if True, use a constant variance 220 | 221 | self.algo.cat.enabled = False # if True, use categorical output distribution 222 | self.algo.cat.num_classes = ( 223 | 256 # number of classes in each dimension of the categorical action space 224 | ) 225 | 226 | self.env_seed = 0 # seed for environment 227 | self.wandb_project_name = "test" 228 | self.train.fast_dev_run = False # whether to run training in debug mode 229 | self.train.max_grad_norm = None # max gradient norm 230 | self.train.amp_enabled = False # use automatic mixed precision 231 | self.train.num_gpus = 1 # number of gpus to use 232 | self.experiment.rollout.goal_conditioning_enabled = ( 233 | False # if True, use goal conditioning wrapper to sample goals for inference 234 | ) 235 | self.train.load_next_obs = False # whether or not to load s' 236 | self.experiment.rollout.goal_success_threshold = 0.01 # if the distance from the goal is less than goal_success_threshold, this counts as a success 237 | self.train.pad_frame_stack = True 238 | self.train.pad_seq_length = True 239 | self.train.frame_stack = 1 240 | self.experiment.rollout.is_mujoco = False 241 | self.experiment.rollout.valid_key = "valid" 242 | self.train.ckpt_path = None 243 | self.train.use_swa = False 244 | self.experiment.rollout.parallel_rollouts = True 245 | self.experiment.rollout.select_random_subset = False 246 | self.train.save_ckpt_on_epoch_end = True 247 | 248 | self.algo.dagger.enabled = False # if true, enable dagger support 249 | self.algo.dagger.online_epoch_rate = 50 # how often to collect online data 250 | self.algo.dagger.num_rollouts = 1 # number of rollouts to collect per online epoch 251 | self.algo.dagger.rollout_type = ( 252 | "state_error_mixed" # toggle rollout type, can range from policy to closed loop tamp 253 | ) 254 | self.algo.dagger.state_error_threshold = ( 255 | 0.001 # threshold for state error - triggers TAMP solver action 256 | ) 257 | self.algo.dagger.action_error_threshold = ( 258 | 0.001 # threshold for action error - triggers TAMP solver action 259 | ) 260 | self.algo.dagger.mpc_horizon = ( 261 | 100 # MPC horizon for closed loop tamp MPC - useful for running faster 262 | ) 263 | -------------------------------------------------------------------------------- /optimus/scripts/playback_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | """ 5 | A script to visualize dataset trajectories by loading the simulation states 6 | one by one or loading the first state and playing actions back open-loop. 7 | The script can generate videos as well, by rendering simulation frames 8 | during playback. 
The videos can also be generated using the image observations 9 | in the dataset (this is useful for real-robot datasets) by using the 10 | --use-obs argument. 11 | 12 | Args: 13 | dataset (str): path to hdf5 dataset 14 | 15 | filter_key (str): if provided, use the subset of trajectories 16 | in the file that correspond to this filter key 17 | 18 | n (int): if provided, stop after n trajectories are processed 19 | 20 | use-obs (bool): if flag is provided, visualize trajectories with dataset 21 | image observations instead of simulator 22 | 23 | use-actions (bool): if flag is provided, use open-loop action playback 24 | instead of loading sim states 25 | 26 | render (bool): if flag is provided, use on-screen rendering during playback 27 | 28 | video_path (str): if provided, render trajectories to this video file path 29 | 30 | video_skip (int): render frames to a video every @video_skip steps 31 | 32 | render_image_names (str or [str]): camera name(s) / image observation(s) to 33 | use for rendering on-screen or to video 34 | 35 | first (bool): if flag is provided, use first frame of each episode for playback 36 | instead of the entire episode. Useful for visualizing task initializations. 37 | 38 | Example usage below: 39 | 40 | # force simulation states one by one, and render agentview and wrist view cameras to video 41 | python playback_dataset.py --dataset /path/to/dataset.hdf5 \ 42 | --render_image_names agentview robot0_eye_in_hand \ 43 | --video_path /tmp/playback_dataset.mp4 44 | 45 | # playback the actions in the dataset, and render agentview camera during playback to video 46 | python playback_dataset.py --dataset /path/to/dataset.hdf5 \ 47 | --use-actions --render_image_names agentview \ 48 | --video_path /tmp/playback_dataset_with_actions.mp4 49 | 50 | # use the observations stored in the dataset to render videos of the dataset trajectories 51 | python playback_dataset.py --dataset /path/to/dataset.hdf5 \ 52 | --use-obs --render_image_names agentview_image \ 53 | --video_path /tmp/obs_trajectory.mp4 54 | 55 | # visualize initial states in the demonstration data 56 | python playback_dataset.py --dataset /path/to/dataset.hdf5 \ 57 | --first --render_image_names agentview \ 58 | --video_path /tmp/dataset_task_inits.mp4 59 | """ 60 | 61 | import argparse 62 | 63 | import h5py 64 | import imageio 65 | import numpy as np 66 | import robomimic.utils.obs_utils as ObsUtils 67 | from robomimic.envs.env_base import EnvBase, EnvType 68 | from tqdm import tqdm 69 | 70 | import optimus.utils.env_utils as EnvUtils 71 | import optimus.utils.file_utils as FileUtils 72 | 73 | # Define default cameras to use for each env type 74 | DEFAULT_CAMERAS = { 75 | EnvType.ROBOSUITE_TYPE: ["agentview_image"], 76 | EnvType.IG_MOMART_TYPE: ["rgb"], 77 | EnvType.GYM_TYPE: ValueError("No camera names supported for gym type env!"), 78 | } 79 | 80 | 81 | def playback_trajectory_with_env( 82 | env, 83 | initial_state, 84 | states, 85 | actions=None, 86 | render=False, 87 | video_writer=None, 88 | video_skip=5, 89 | camera_names=None, 90 | first=False, 91 | ): 92 | """ 93 | Helper function to playback a single trajectory using the simulator environment. 94 | If @actions are not None, it will play them open-loop after loading the initial state. 95 | Otherwise, @states are loaded one by one. 
96 | 97 | Args: 98 | env (instance of EnvBase): environment 99 | initial_state (dict): initial simulation state to load 100 | states (list of dict or np.array): array of simulation states to load 101 | actions (np.array): if provided, play actions back open-loop instead of using @states 102 | render (bool): if True, render on-screen 103 | video_writer (imageio writer): video writer 104 | video_skip (int): determines rate at which environment frames are written to video 105 | camera_names (list): determines which camera(s) are used for rendering. Pass more than 106 | one to output a video with multiple camera views concatenated horizontally. 107 | first (bool): if True, only use the first frame of each episode. 108 | """ 109 | assert isinstance(env, EnvBase) or isinstance(env.env, EnvBase) 110 | 111 | write_video = video_writer is not None 112 | video_count = 0 113 | assert not (render and write_video) 114 | # load the initial state 115 | env.reset() 116 | env.reset_to(initial_state) 117 | action_playback = actions is not None 118 | traj_len = len(states) if not action_playback else len(actions) 119 | # if action_playback: 120 | # assert len(states) == (actions.shape[0] + 1) 121 | 122 | success = {k: False for k in env.is_success()} 123 | for i in range(traj_len): 124 | if action_playback: 125 | info = env.step( 126 | actions[i], 127 | # render_intermediate_obs=render 128 | # and (env.execute_controller_actions or env.execute_simplified_actions), 129 | # render_intermediate_video_obs=write_video 130 | # and (env.execute_controller_actions or env.execute_simplified_actions), 131 | )[-1] 132 | if i < traj_len - 1: 133 | # check whether the actions deterministically lead to the same recorded states 134 | state_playback = env.get_state()["states"] 135 | # else: 136 | # if not np.all(np.equal(states[i + 1], state_playback)): 137 | # err = np.linalg.norm(states[i + 1] - state_playback) 138 | # print("warning: playback diverged by {} at step {}".format(err, i)) 139 | else: 140 | env.reset_to({"states": states[i]}) 141 | 142 | # on-screen render 143 | if render: 144 | env.render(mode="human", camera_name=camera_names[0]) 145 | 146 | # video render 147 | if write_video: 148 | if video_count % video_skip == 0: 149 | video_img = [] 150 | for cam_name in camera_names: 151 | video_img.append( 152 | env.render( 153 | mode="rgb_array", 154 | height=512, 155 | width=512, 156 | camera_name=cam_name, 157 | ) 158 | ) 159 | video_img = np.concatenate(video_img, axis=1) # concatenate horizontally 160 | video_writer.append_data(video_img) 161 | video_count += 1 162 | exec_success = env.is_success() 163 | success = {k: success[k] or bool(exec_success[k]) for k in exec_success} 164 | if first: 165 | break 166 | return success 167 | 168 | 169 | def playback_trajectory_with_obs( 170 | traj_grp, 171 | video_writer, 172 | video_skip=5, 173 | image_names=None, 174 | first=False, 175 | ): 176 | """ 177 | This function reads all "rgb" observations in the dataset trajectory and 178 | writes them into a video. 179 | 180 | Args: 181 | traj_grp (hdf5 file group): hdf5 group which corresponds to the dataset trajectory to playback 182 | video_writer (imageio writer): video writer 183 | video_skip (int): determines rate at which environment frames are written to video 184 | image_names (list): determines which image observations are used for rendering. Pass more than 185 | one to output a video with multiple image observations concatenated horizontally. 186 | first (bool): if True, only use the first frame of each episode. 
187 | """ 188 | assert ( 189 | image_names is not None 190 | ), "error: must specify at least one image observation to use in @image_names" 191 | video_count = 0 192 | 193 | traj_len = traj_grp["actions"].shape[0] 194 | for i in range(traj_len): 195 | if video_count % video_skip == 0: 196 | # concatenate image obs together 197 | im = [traj_grp["obs/{}".format(k)][i] for k in image_names] 198 | frame = np.concatenate(im, axis=1) 199 | video_writer.append_data(frame) 200 | video_count += 1 201 | 202 | if first: 203 | break 204 | 205 | 206 | def playback_dataset(args): 207 | # some arg checking 208 | write_video = args.video_path is not None 209 | assert not (args.render and write_video) # either on-screen or video but not both 210 | 211 | # Auto-fill camera rendering info if not specified 212 | if args.render_image_names is None: 213 | # We fill in the automatic values 214 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset) 215 | env_type = EnvUtils.get_env_type(env_meta=env_meta) 216 | args.render_image_names = DEFAULT_CAMERAS[env_type] 217 | 218 | if args.render: 219 | # on-screen rendering can only support one camera 220 | assert len(args.render_image_names) == 1 221 | 222 | if args.use_obs: 223 | assert write_video, "playback with observations can only write to video" 224 | assert ( 225 | not args.use_actions 226 | ), "playback with observations is offline and does not support action playback" 227 | 228 | # create environment only if not playing back with observations 229 | if not args.use_obs: 230 | # need to make sure ObsUtils knows which observations are images, but it doesn't matter 231 | # for playback since observations are unused. Pass a dummy spec here. 232 | dummy_spec = dict( 233 | obs=dict( 234 | low_dim=["robot0_eef_pos"], 235 | rgb=[], 236 | ), 237 | ) 238 | ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec) 239 | 240 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset) 241 | env = EnvUtils.create_env_from_metadata( 242 | env_meta=env_meta, render=args.render, render_offscreen=write_video 243 | ) 244 | 245 | # some operations for playback are env-type-specific 246 | is_robosuite_env = EnvUtils.is_robosuite_env(env_meta) 247 | 248 | f = h5py.File(args.dataset, "r") 249 | 250 | # list of all demonstration episodes (sorted in increasing number order) 251 | if args.filter_key is not None: 252 | print("using filter key: {}".format(args.filter_key)) 253 | demos = [elem.decode("utf-8") for elem in np.array(f["mask/{}".format(args.filter_key)])] 254 | else: 255 | demos = list(f["data"].keys()) 256 | inds = np.argsort([int(elem[5:]) for elem in demos]) 257 | demos = [demos[i] for i in inds] 258 | 259 | # maybe reduce the number of demonstrations to playback 260 | if args.n is not None: 261 | demos = demos[: args.n] 262 | 263 | # maybe dump video 264 | video_writer = None 265 | if write_video: 266 | video_writer = imageio.get_writer(args.video_path, fps=20) 267 | 268 | failed_episodes = [] 269 | for ind in tqdm(range(len(demos))): 270 | ep = demos[ind] 271 | print("Playing back episode: {}".format(ep)) 272 | 273 | if args.use_obs: 274 | playback_trajectory_with_obs( 275 | traj_grp=f["data/{}".format(ep)], 276 | video_writer=video_writer, 277 | video_skip=args.video_skip, 278 | image_names=args.render_image_names, 279 | first=args.first, 280 | ) 281 | continue 282 | # prepare states to reload from 283 | states = f["data/{}/states".format(ep)][()] 284 | initial_state = dict(states=states[0]) 285 | if 
is_robosuite_env:
286 |             initial_state["model"] = f["data/{}".format(ep)].attrs["model_file"]
287 |             # init_string / goal_parts_string are only stored by some environments
288 |             try:
289 |                 initial_state["init_string"] = f["data/{}".format(ep)].attrs["init_string"]
290 |                 initial_state["goal_parts_string"] = f["data/{}".format(ep)].attrs[
291 |                     "goal_parts_string"
292 |                 ]
293 |                 print(initial_state["goal_parts_string"])
294 |             except KeyError:
295 |                 print("No init_string or goal_parts_string in file")
296 | 
297 |         # supply actions if using open-loop action playback
298 |         actions = None
299 |         if args.use_actions:
300 |             actions = f["data/{}/actions".format(ep)][()]
301 |         success = playback_trajectory_with_env(
302 |             env=env,
303 |             initial_state=initial_state,
304 |             states=states,
305 |             actions=actions,
306 |             render=args.render,
307 |             video_writer=video_writer,
308 |             video_skip=args.video_skip,
309 |             camera_names=args.render_image_names,
310 |             first=args.first,
311 |         )
312 |         print("trajectory_length: {}".format(len(states)))
313 |         if not success["task"]:
314 |             print("Episode {} failed".format(ep))
315 |             failed_episodes.append(ep)
316 |         print("Failed episodes: {}".format(failed_episodes))
317 | 
318 |     f.close()
319 |     if write_video:
320 |         video_writer.close()
321 | 
322 | 
323 | if __name__ == "__main__":
324 |     parser = argparse.ArgumentParser()
325 |     parser.add_argument(
326 |         "--dataset",
327 |         type=str,
328 |         help="path to hdf5 dataset",
329 |     )
330 |     parser.add_argument(
331 |         "--filter_key",
332 |         type=str,
333 |         default=None,
334 |         help="(optional) filter key, to select a subset of trajectories in the file",
335 |     )
336 | 
337 |     # number of trajectories to playback. If omitted, playback all of them.
338 |     parser.add_argument(
339 |         "--n",
340 |         type=int,
341 |         default=None,
342 |         help="(optional) stop after n trajectories are played",
343 |     )
344 | 
345 |     # Use image observations instead of doing playback using the simulator env.
346 |     parser.add_argument(
347 |         "--use-obs",
348 |         action="store_true",
349 |         help="visualize trajectories with dataset image observations instead of simulator",
350 |     )
351 | 
352 |     # Playback stored dataset actions open-loop instead of loading from simulation states.
353 |     parser.add_argument(
354 |         "--use-actions",
355 |         action="store_true",
356 |         help="use open-loop action playback instead of loading sim states",
357 |     )
358 | 
359 |     # Whether to render playback to screen
360 |     parser.add_argument(
361 |         "--render",
362 |         action="store_true",
363 |         help="on-screen rendering",
364 |     )
365 | 
366 |     # Dump a video of the dataset playback to the specified path
367 |     parser.add_argument(
368 |         "--video_path",
369 |         type=str,
370 |         default=None,
371 |         help="(optional) render trajectories to this video file path",
372 |     )
373 | 
374 |     # How often to write video frames during the playback
375 |     parser.add_argument(
376 |         "--video_skip",
377 |         type=int,
378 |         default=5,
379 |         help="render frames to video every n steps",
380 |     )
381 | 
382 |     # camera names to render, or image observations to use for writing to video
383 |     parser.add_argument(
384 |         "--render_image_names",
385 |         type=str,
386 |         nargs="+",
387 |         default=None,
388 |         help="(optional) camera name(s) / image observation(s) to use for rendering on-screen or to video.
Default is" 389 | "None, which corresponds to a predefined camera for each env type", 390 | ) 391 | 392 | # Only use the first frame of each episode 393 | parser.add_argument( 394 | "--first", 395 | action="store_true", 396 | help="use first frame of each episode", 397 | ) 398 | 399 | args = parser.parse_args() 400 | playback_dataset(args) 401 | -------------------------------------------------------------------------------- /optimus/envs/wrappers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | """ 5 | A collection of useful environment wrappers. Taken from Ajay Mandlekar's private version of robomimic. 6 | """ 7 | import textwrap 8 | from collections import deque 9 | 10 | import h5py 11 | import numpy as np 12 | 13 | 14 | class Wrapper(object): 15 | """ 16 | Base class for all environment wrappers in optimus. 17 | """ 18 | 19 | def __init__(self, env): 20 | """ 21 | Args: 22 | env (EnvBase instance): The environment to wrap. 23 | """ 24 | self.env = env 25 | 26 | @classmethod 27 | def class_name(cls): 28 | return cls.__name__ 29 | 30 | def _warn_double_wrap(self): 31 | """ 32 | Utility function that checks if we're accidentally trying to double wrap an env 33 | Raises: 34 | Exception: [Double wrapping env] 35 | """ 36 | env = self.env 37 | while True: 38 | if isinstance(env, Wrapper): 39 | if env.class_name() == self.class_name(): 40 | raise Exception( 41 | "Attempted to double wrap with Wrapper: {}".format(self.__class__.__name__) 42 | ) 43 | env = env.env 44 | else: 45 | break 46 | 47 | @property 48 | def unwrapped(self): 49 | """ 50 | Grabs unwrapped environment 51 | 52 | Returns: 53 | env (EnvBase instance): Unwrapped environment 54 | """ 55 | if hasattr(self.env, "unwrapped"): 56 | return self.env.unwrapped 57 | else: 58 | return self.env 59 | 60 | def _to_string(self): 61 | """ 62 | Subclasses should override this method to print out info about the 63 | wrapper (such as arguments passed to it). 
64 | """ 65 | return "" 66 | 67 | def __repr__(self): 68 | """Pretty print environment.""" 69 | header = "{}".format(str(self.__class__.__name__)) 70 | msg = "" 71 | indent = " " * 4 72 | if self._to_string() != "": 73 | msg += textwrap.indent("\n" + self._to_string(), indent) 74 | msg += textwrap.indent("\nenv={}".format(self.env), indent) 75 | msg = header + "(" + msg + "\n)" 76 | return msg 77 | 78 | # this method is a fallback option on any methods the original env might support 79 | def __getattr__(self, attr): 80 | # using getattr ensures that both __getattribute__ and __getattr__ (fallback) get called 81 | # (see https://stackoverflow.com/questions/3278077/difference-between-getattr-vs-getattribute) 82 | orig_attr = getattr(self.env, attr) 83 | if callable(orig_attr): 84 | 85 | def hooked(*args, **kwargs): 86 | result = orig_attr(*args, **kwargs) 87 | # prevent wrapped_class from becoming unwrapped 88 | if id(result) == id(self.env): 89 | return self 90 | return result 91 | 92 | return hooked 93 | else: 94 | return orig_attr 95 | 96 | 97 | class EvaluateOnDatasetWrapper(Wrapper): 98 | def __init__( 99 | self, 100 | env, 101 | dataset_path=None, 102 | valid_key="valid", 103 | ): 104 | super(EvaluateOnDatasetWrapper, self).__init__(env=env) 105 | self.dataset_path = dataset_path 106 | self.valid_key = valid_key 107 | if dataset_path is not None: 108 | self.load_evaluation_data(dataset_path) 109 | 110 | def sample_eval_episodes(self, num_episodes): 111 | """ 112 | Sample a random set of episodes from the set of all episodes. 113 | """ 114 | self.eval_indices = np.random.choice( 115 | range(len(self.initial_states)), size=num_episodes, replace=False 116 | ) 117 | self.eval_current_index = 0 118 | 119 | def get_num_val_states(self): 120 | return len(self.demos) 121 | 122 | def set_eval_episode(self, eval_index): 123 | self.eval_indices = [eval_index] 124 | self.eval_current_index = 0 125 | 126 | def load_evaluation_data(self, hdf5_path): 127 | # NOTE: for the hierarchical primitive setting, this code will only 128 | # work for resetting the initial state, the signature/command_indices are wrong 129 | # DO NOT use with command_index or signature conditioning 130 | self.hdf5_file = h5py.File(hdf5_path, "r", swmr=True, libver="latest") 131 | filter_key = self.valid_key 132 | self.demos = [ 133 | elem.decode("utf-8") 134 | for elem in np.array(self.hdf5_file["mask/{}".format(filter_key)][:]) 135 | ] 136 | try: 137 | self.initial_states = [ 138 | dict( 139 | states=self.hdf5_file["data/{}/states".format(ep)][()][0], 140 | model=self.hdf5_file["data/{}".format(ep)].attrs["model_file"], 141 | ) 142 | for ep in self.demos 143 | ] 144 | self.actions = [self.hdf5_file["data/{}/actions".format(ep)][()] for ep in self.demos] 145 | self.obs = [ 146 | { 147 | k: self.hdf5_file["data/{}/obs/{}".format(ep, k)][()] 148 | for k in ["robot0_eef_pos", "robot0_eef_quat"] 149 | } 150 | for ep in self.demos 151 | ] 152 | except: 153 | self.initial_states = [ 154 | dict( 155 | states={ 156 | k_: self.hdf5_file["data/{}/{}/{}".format(ep, "states", k_)][()].astype( 157 | "float32" 158 | )[0] 159 | for k_ in self.hdf5_file["data/{}/{}".format(ep, "states")].keys() 160 | } 161 | ) 162 | for ep in self.demos 163 | ] 164 | try: 165 | self.command_indices = [ 166 | self.hdf5_file["data/{}/obs/command_index".format(ep)][()] for ep in self.demos 167 | ] 168 | except: 169 | self.command_indices = [np.zeros(1).reshape(-1, 1) for ep in self.demos] 170 | try: 171 | self.init_strings = [ 172 | 
self.hdf5_file["data/{}".format(ep)].attrs["init_string"] for ep in self.demos 173 | ] 174 | self.goal_parts_strings = [ 175 | self.hdf5_file["data/{}".format(ep)].attrs["goal_parts_string"] for ep in self.demos 176 | ] 177 | except: 178 | print(traceback.format_exc()) 179 | self.init_strings = [None for ep in self.demos] 180 | self.goal_parts_strings = [None for ep in self.demos] 181 | 182 | def reset(self): 183 | """ 184 | Modify to return frame stacked observation which is @self.num_frames copies of 185 | the initial observation. 186 | 187 | Returns: 188 | obs_stacked (dict): each observation key in original observation now has 189 | leading shape @self.num_frames and consists of the previous @self.num_frames 190 | observations 191 | """ 192 | if self.dataset_path is not None: 193 | print("resetting to a valid state") 194 | self.env.reset() 195 | states = self.initial_states[self.eval_indices[self.eval_current_index]] 196 | if ( 197 | self.init_strings[self.eval_indices[self.eval_current_index]] is not None 198 | and self.goal_parts_strings[self.eval_indices[self.eval_current_index]] is not None 199 | ): 200 | states["init_string"] = self.init_strings[ 201 | self.eval_indices[self.eval_current_index] 202 | ] 203 | states["goal_parts_string"] = self.goal_parts_strings[ 204 | self.eval_indices[self.eval_current_index] 205 | ] 206 | self.eval_current_index += 1 207 | obs = self.reset_to(states) 208 | return obs 209 | else: 210 | obs = self.env.reset() 211 | self.timestep = 0 # always zero regardless of timestep type 212 | return obs 213 | 214 | def step(self, action, **kwargs): 215 | return self.env.step(action, **kwargs) 216 | 217 | 218 | class FrameStackWrapper(Wrapper): 219 | """ 220 | Wrapper for frame stacking observations during rollouts. The agent 221 | receives a sequence of past observations instead of a single observation 222 | when it calls @env.reset, @env.reset_to, or @env.step in the rollout loop. 223 | """ 224 | 225 | def __init__( 226 | self, 227 | env, 228 | num_frames, 229 | horizon=None, 230 | dataset_path=None, 231 | valid_key="valid", 232 | ): 233 | """ 234 | Args: 235 | env (EnvBase instance): The environment to wrap. 236 | num_frames (int): number of past observations (including current observation) 237 | to stack together. Must be greater than 1 (otherwise this wrapper would 238 | be a no-op). 
239 | """ 240 | assert ( 241 | num_frames > 1 242 | ), "error: FrameStackWrapper must have num_frames > 1 but got num_frames of {}".format( 243 | num_frames 244 | ) 245 | 246 | super(FrameStackWrapper, self).__init__(env=env) 247 | self.num_frames = num_frames 248 | 249 | # keep track of last @num_frames observations for each obs key 250 | self.obs_history = None 251 | self.horizon = horizon 252 | self.dataset_path = dataset_path 253 | self.valid_key = valid_key 254 | if dataset_path is not None: 255 | self.hdf5_file = h5py.File(dataset_path, "r", swmr=True, libver="latest") 256 | filter_key = self.valid_key 257 | self.demos = [ 258 | elem.decode("utf-8") 259 | for elem in np.array(self.hdf5_file["mask/{}".format(filter_key)][:]) 260 | ] 261 | 262 | def load_evaluation_data(self, idx): 263 | ep = self.demos[idx] 264 | initial_states = dict( 265 | states=self.hdf5_file["data/{}/states".format(ep)][()][0], 266 | model=self.hdf5_file["data/{}".format(ep)].attrs["model_file"], 267 | ) 268 | 269 | try: 270 | init_strings = self.hdf5_file["data/{}".format(ep)].attrs["init_string"] 271 | goal_parts_strings = self.hdf5_file["data/{}".format(ep)].attrs["goal_parts_string"] 272 | except: 273 | init_strings = None 274 | goal_parts_strings = None 275 | return ( 276 | initial_states, 277 | init_strings, 278 | goal_parts_strings, 279 | ) 280 | 281 | def _get_initial_obs_history(self, init_obs): 282 | """ 283 | Helper method to get observation history from the initial observation, by 284 | repeating it. 285 | 286 | Returns: 287 | obs_history (dict): a deque for each observation key, with an extra 288 | leading dimension of 1 for each key (for easy concatenation later) 289 | """ 290 | obs_history = {} 291 | for k in init_obs: 292 | obs_history[k] = deque( 293 | [init_obs[k][None] for _ in range(self.num_frames)], 294 | maxlen=self.num_frames, 295 | ) 296 | return obs_history 297 | 298 | def _get_stacked_obs_from_history(self): 299 | """ 300 | Helper method to convert internal variable @self.obs_history to a 301 | stacked observation where each key is a numpy array with leading dimension 302 | @self.num_frames. 303 | """ 304 | # concatenate all frames per key so we return a numpy array per key 305 | return {k: np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history} 306 | 307 | def update_obs(self, obs, action=None, reset=False): 308 | obs["timesteps"] = np.array([self.timestep]) 309 | if reset: 310 | obs["actions"] = np.zeros(self.env.action_dimension) 311 | else: 312 | self.timestep += 1 313 | obs["actions"] = action[: self.env.action_dimension] 314 | 315 | def sample_eval_episodes(self, num_episodes): 316 | """ 317 | Sample a random set of episodes from the set of all episodes. 318 | """ 319 | self.eval_indices = np.random.choice( 320 | range(len(self.demos)), size=num_episodes, replace=False 321 | ) 322 | self.eval_current_index = 0 323 | 324 | def get_num_val_states(self): 325 | return len(self.demos) 326 | 327 | def set_eval_episode(self, eval_index): 328 | self.eval_indices = [eval_index] 329 | self.eval_current_index = 0 330 | 331 | def reset(self, use_eval_indices=True): 332 | """ 333 | Modify to return frame stacked observation which is @self.num_frames copies of 334 | the initial observation. 
335 | 336 | Returns: 337 | obs_stacked (dict): each observation key in original observation now has 338 | leading shape @self.num_frames and consists of the previous @self.num_frames 339 | observations 340 | """ 341 | if self.dataset_path is not None and use_eval_indices: 342 | print("resetting to a valid state") 343 | self.env.reset() 344 | ( 345 | states, 346 | init_string, 347 | goal_parts_string, 348 | ) = self.load_evaluation_data(self.eval_indices[self.eval_current_index]) 349 | if init_string is not None and goal_parts_string is not None: 350 | states["init_string"] = init_string 351 | states["goal_parts_string"] = goal_parts_string 352 | self.eval_current_index += 1 353 | obs = self.reset_to(states) 354 | return obs 355 | else: 356 | obs = self.env.reset() 357 | self.timestep = 0 # always zero regardless of timestep type 358 | self.update_obs(obs, reset=True) 359 | self.obs_history = self._get_initial_obs_history(init_obs=obs) 360 | return self._get_stacked_obs_from_history() 361 | 362 | def reset_to(self, state): 363 | """ 364 | Modify to return frame stacked observation which is @self.num_frames copies of 365 | the initial observation. 366 | 367 | Returns: 368 | obs_stacked (dict): each observation key in original observation now has 369 | leading shape @self.num_frames and consists of the previous @self.num_frames 370 | observations 371 | """ 372 | obs = self.env.reset_to(state) 373 | self.timestep = 0 # always zero regardless of timestep type 374 | self.update_obs(obs, reset=True) 375 | self.obs_history = self._get_initial_obs_history(init_obs=obs) 376 | return self._get_stacked_obs_from_history() 377 | 378 | def step(self, action, **kwargs): 379 | """ 380 | Modify to update the internal frame history and return frame stacked observation, 381 | which will have leading dimension @self.num_frames for each key. 382 | 383 | Args: 384 | action (np.array): action to take 385 | 386 | Returns: 387 | obs_stacked (dict): each observation key in original observation now has 388 | leading shape @self.num_frames and consists of the previous @self.num_frames 389 | observations 390 | reward (float): reward for this step 391 | done (bool): whether the task is done 392 | info (dict): extra information 393 | """ 394 | obs, r, done, info = self.env.step(action, **kwargs) 395 | self.update_obs(obs, action=action, reset=False) 396 | # update frame history 397 | for k in obs: 398 | # make sure to have leading dim of 1 for easy concatenation 399 | self.obs_history[k].append(obs[k][None]) 400 | obs_ret = self._get_stacked_obs_from_history() 401 | return obs_ret, r, done, info 402 | 403 | def _to_string(self): 404 | """Info to pretty print.""" 405 | return "num_frames={}".format(self.num_frames) 406 | -------------------------------------------------------------------------------- /optimus/envs/env_robosuite.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | """ 5 | This file contains the robosuite environment wrapper that is used 6 | to provide a standardized environment API for training policies and interacting 7 | with metadata present in datasets. 
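
Example (an illustrative sketch; the dataset path is hypothetical):
    import robomimic.utils.env_utils as EnvUtils
    import robomimic.utils.file_utils as FileUtils

    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path="/path/to/demo.hdf5")
    env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=False, render_offscreen=True)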
8 | """ 9 | import json 10 | import os 11 | import xml.etree.ElementTree as ET 12 | from copy import deepcopy 13 | 14 | import numpy as np 15 | import robomimic.envs.env_base as EB 16 | import robomimic.utils.obs_utils as ObsUtils 17 | import robosuite 18 | 19 | from optimus.envs.stack import * 20 | 21 | 22 | def postprocess_model_xml(xml_str): 23 | """ 24 | This function postprocesses the model.xml collected from a MuJoCo demonstration 25 | in order to make sure that the STL files can be found. 26 | 27 | Args: 28 | xml_str (str): Mujoco sim demonstration XML file as string 29 | 30 | Returns: 31 | str: Post-processed xml file as string 32 | """ 33 | 34 | path = os.path.split(robosuite.__file__)[0] 35 | path_split = path.split("/") 36 | 37 | # replace mesh and texture file paths 38 | tree = ET.fromstring(xml_str) 39 | root = tree 40 | asset = root.find("asset") 41 | meshes = asset.findall("mesh") 42 | textures = asset.findall("texture") 43 | all_elements = meshes + textures 44 | 45 | for elem in all_elements: 46 | old_path = elem.get("file") 47 | if old_path is None: 48 | continue 49 | old_path_split = old_path.split("/") 50 | ind = max( 51 | loc for loc, val in enumerate(old_path_split) if val == "robosuite" 52 | ) # last occurrence index 53 | new_path_split = path_split + old_path_split[ind + 1 :] 54 | new_path = "/".join(new_path_split) 55 | elem.set("file", new_path) 56 | 57 | return ET.tostring(root, encoding="utf8").decode("utf8") 58 | 59 | 60 | class EnvRobosuite(EB.EnvBase): 61 | """Wrapper class for robosuite environments (https://github.com/ARISE-Initiative/robosuite)""" 62 | 63 | def __init__( 64 | self, 65 | env_name, 66 | render=False, 67 | render_offscreen=False, 68 | use_image_obs=False, 69 | postprocess_visual_obs=True, 70 | **kwargs, 71 | ): 72 | """ 73 | Args: 74 | env_name (str): name of environment. Only needs to be provided if making a different 75 | environment from the one in @env_meta. 76 | 77 | render (bool): if True, environment supports on-screen rendering 78 | 79 | render_offscreen (bool): if True, environment supports off-screen rendering. This 80 | is forced to be True if @env_meta["use_images"] is True. 81 | 82 | use_image_obs (bool): if True, environment is expected to render rgb image observations 83 | on every env.step call. Set this to False for efficiency reasons, if image 84 | observations are not required. 85 | 86 | postprocess_visual_obs (bool): if True, postprocess image observations 87 | to prepare for learning. This should only be False when extracting observations 88 | for saving to a dataset (to save space on RGB images for example). 
89 | """ 90 | self.postprocess_visual_obs = postprocess_visual_obs 91 | 92 | # robosuite version check 93 | self._is_v1 = robosuite.__version__.split(".")[0] == "1" 94 | if self._is_v1: 95 | assert ( 96 | int(robosuite.__version__.split(".")[1]) >= 2 97 | ), "only support robosuite v0.3 and v1.2+" 98 | 99 | kwargs = deepcopy(kwargs) 100 | 101 | # update kwargs based on passed arguments 102 | update_kwargs = dict( 103 | has_renderer=render, 104 | has_offscreen_renderer=(render_offscreen or use_image_obs), 105 | ignore_done=True, 106 | use_object_obs=True, 107 | use_camera_obs=use_image_obs, 108 | camera_depths=False, 109 | ) 110 | kwargs.update(update_kwargs) 111 | 112 | if self._is_v1: 113 | if kwargs["has_offscreen_renderer"]: 114 | # ensure that we select the correct GPU device for rendering by testing for EGL rendering 115 | # NOTE: this package should be installed from this link (https://github.com/StanfordVL/egl_probe) 116 | import egl_probe 117 | 118 | valid_gpu_devices = egl_probe.get_available_devices() 119 | if len(valid_gpu_devices) > 0: 120 | kwargs["render_gpu_device_id"] = valid_gpu_devices[0] 121 | else: 122 | # make sure gripper visualization is turned off (we almost always want this for learning) 123 | kwargs["gripper_visualization"] = False 124 | del kwargs["camera_depths"] 125 | kwargs["camera_depth"] = False # rename kwarg 126 | 127 | self._env_name = env_name 128 | self._init_kwargs = deepcopy(kwargs) 129 | del kwargs["deterministic"] 130 | del kwargs["max_height"] 131 | del kwargs["max_towers"] 132 | del kwargs["include_language"] 133 | self.env = robosuite.make(self._env_name, **kwargs) 134 | if self._is_v1: 135 | # Make sure joint position observations and eef vel observations are active 136 | for ob_name in self.env.observation_names: 137 | if ("joint_pos" in ob_name) or ("eef_vel" in ob_name): 138 | self.env.modify_observable( 139 | observable_name=ob_name, attribute="active", modifier=True 140 | ) 141 | 142 | def step(self, action): 143 | """ 144 | Step in the environment with an action. 145 | 146 | Args: 147 | action (np.array): action to take 148 | 149 | Returns: 150 | observation (dict): new observation dictionary 151 | reward (float): reward for this step 152 | done (bool): whether the task is done 153 | info (dict): extra information 154 | """ 155 | obs, r, done, info = self.env.step(action) 156 | obs = self.get_observation(obs) 157 | return obs, r, self.is_done(), info 158 | 159 | def reset(self): 160 | """ 161 | Reset environment. 162 | 163 | Returns: 164 | observation (dict): initial observation dictionary. 165 | """ 166 | di = self.env.reset() 167 | return self.get_observation(di) 168 | 169 | def reset_to(self, state): 170 | """ 171 | Reset to a specific simulator state. 
172 | 173 | Args: 174 | state (dict): current simulator state that contains one or more of: 175 | - states (np.ndarray): initial state of the mujoco environment 176 | - model (str): mujoco scene xml 177 | 178 | Returns: 179 | observation (dict): observation dictionary after setting the simulator state (only 180 | if "states" is in @state) 181 | """ 182 | should_ret = False 183 | if "model" in state: 184 | self.reset() 185 | xml = postprocess_model_xml(state["model"]) 186 | self.env.reset_from_xml_string(xml) 187 | self.env.sim.reset() 188 | if not self._is_v1: 189 | # hide teleop visualization after restoring from model 190 | self.env.sim.model.site_rgba[self.env.eef_site_id] = np.array([0.0, 0.0, 0.0, 0.0]) 191 | self.env.sim.model.site_rgba[self.env.eef_cylinder_id] = np.array( 192 | [0.0, 0.0, 0.0, 0.0] 193 | ) 194 | if "states" in state: 195 | self.env.sim.set_state_from_flattened(state["states"]) 196 | self.env.sim.forward() 197 | should_ret = True 198 | 199 | if "goal" in state: 200 | self.set_goal(**state["goal"]) 201 | if should_ret: 202 | # only return obs if we've done a forward call - otherwise the observations will be garbage 203 | return self.get_observation() 204 | return None 205 | 206 | def render(self, mode="human", height=None, width=None, camera_name="agentview"): 207 | """ 208 | Render from simulation to either an on-screen window or off-screen to RGB array. 209 | 210 | Args: 211 | mode (str): pass "human" for on-screen rendering or "rgb_array" for off-screen rendering 212 | height (int): height of image to render - only used if mode is "rgb_array" 213 | width (int): width of image to render - only used if mode is "rgb_array" 214 | camera_name (str): camera name to use for rendering 215 | """ 216 | if mode == "human": 217 | cam_id = self.env.sim.camera_name2id(camera_name) 218 | self.env.viewer.set_camera(cam_id) 219 | return self.env.render() 220 | elif mode == "rgb_array": 221 | return self.env.sim.render(height=height, width=width, camera_name=camera_name)[::-1] 222 | else: 223 | raise NotImplementedError("mode={} is not implemented".format(mode)) 224 | 225 | def get_observation(self, di=None): 226 | """ 227 | Get current environment observation dictionary. 228 | 229 | Args: 230 | di (dict): current raw observation dictionary from robosuite to wrap and provide 231 | as a dictionary. If not provided, will be queried from robosuite. 232 | """ 233 | if di is None: 234 | di = ( 235 | self.env._get_observations(force_update=True) 236 | if self._is_v1 237 | else self.env._get_observation() 238 | ) 239 | ret = {} 240 | for k in di: 241 | if (k in ObsUtils.OBS_KEYS_TO_MODALITIES) and ObsUtils.key_is_obs_modality( 242 | key=k, obs_modality="rgb" 243 | ): 244 | ret[k] = di[k][::-1] 245 | if self.postprocess_visual_obs: 246 | ret[k] = ObsUtils.process_obs(obs=ret[k], obs_key=k) 247 | 248 | # "object" key contains object information 249 | ret["object"] = np.array(di["object-state"]) 250 | 251 | if self._is_v1: 252 | for robot in self.env.robots: 253 | # add all robot-arm-specific observations. 
Note the (k not in ret) check
254 |                 # ensures that we don't accidentally add robot wrist images a second time
255 |                 pf = robot.robot_model.naming_prefix
256 |                 for k in di:
257 |                     if k.startswith(pf) and (k not in ret) and (not k.endswith("proprio-state")):
258 |                         ret[k] = np.array(di[k])
259 |         else:
260 |             # minimal proprioception for older versions of robosuite
261 |             ret["proprio"] = np.array(di["robot-state"])
262 |             ret["eef_pos"] = np.array(di["eef_pos"])
263 |             ret["eef_quat"] = np.array(di["eef_quat"])
264 |             ret["gripper_qpos"] = np.array(di["gripper_qpos"])
265 |         return ret
266 | 
267 |     def get_state(self):
268 |         """
269 |         Get current environment simulator state as a dictionary. Should be compatible with @reset_to.
270 |         """
271 |         xml = self.env.sim.model.get_xml()  # model xml file
272 |         state = np.array(self.env.sim.get_state().flatten())  # simulator state
273 |         return dict(model=xml, states=state)
274 | 
275 |     def get_reward(self):
276 |         """
277 |         Get current reward.
278 |         """
279 |         return self.env.reward(None)
280 | 
281 |     def get_goal(self):
282 |         """
283 |         Get goal observation. Not all environments support this.
284 |         """
285 |         return self.get_observation(self.env._get_goal())
286 | 
287 |     def set_goal(self, **kwargs):
288 |         """
289 |         Set goal observation with external specification. Not all environments support this.
290 |         """
291 |         return self.env.set_goal(**kwargs)
292 | 
293 |     def is_done(self):
294 |         """
295 |         Check if the task is done (not necessarily successful).
296 |         """
297 | 
298 |         # Robosuite envs always rollout to fixed horizon.
299 |         return False
300 | 
301 |     def is_success(self):
302 |         """
303 |         Check if the task condition(s) is reached. Should return a dictionary
304 |         { str: bool } with at least a "task" key for the overall task success,
305 |         and additional optional keys corresponding to other task criteria.
306 |         """
307 |         succ = self.env._check_success()
308 |         if isinstance(succ, dict):
309 |             assert "task" in succ
310 |             return succ
311 |         return {"task": succ}
312 | 
313 |     @property
314 |     def action_dimension(self):
315 |         """
316 |         Returns dimension of actions (int).
317 |         """
318 |         return self.env.action_spec[0].shape[0]
319 | 
320 |     @property
321 |     def name(self):
322 |         """
323 |         Returns name of the environment (str).
324 |         """
325 |         return self._env_name
326 | 
327 |     @property
328 |     def type(self):
329 |         """
330 |         Returns environment type (int) for this kind of environment.
331 |         This helps identify this env class.
332 |         """
333 |         return EB.EnvType.ROBOSUITE_TYPE
334 | 
335 |     def serialize(self):
336 |         """
337 |         Save all information needed to re-instantiate this environment in a dictionary.
338 |         This is the same as @env_meta - environment metadata stored in hdf5 datasets,
339 |         and used in utils/env_utils.py.
340 |         """
341 |         return dict(env_name=self.name, type=self.type, env_kwargs=deepcopy(self._init_kwargs))
342 | 
343 |     @classmethod
344 |     def create_for_data_processing(
345 |         cls,
346 |         env_name,
347 |         camera_names,
348 |         camera_height,
349 |         camera_width,
350 |         reward_shaping,
351 |         **kwargs,
352 |     ):
353 |         """
354 |         Create environment for processing datasets, which includes extracting
355 |         observations, labeling dense / sparse rewards, and annotating dones in
356 |         transitions.
357 | 358 | Args: 359 | env_name (str): name of environment 360 | camera_names (list of str): list of camera names that correspond to image observations 361 | camera_height (int): camera height for all cameras 362 | camera_width (int): camera width for all cameras 363 | reward_shaping (bool): if True, use shaped environment rewards, else use sparse task completion rewards 364 | """ 365 | is_v1 = robosuite.__version__.split(".")[0] == "1" 366 | has_camera = len(camera_names) > 0 367 | 368 | new_kwargs = { 369 | "reward_shaping": reward_shaping, 370 | } 371 | 372 | if has_camera: 373 | if is_v1: 374 | new_kwargs["camera_names"] = list(camera_names) 375 | new_kwargs["camera_heights"] = camera_height 376 | new_kwargs["camera_widths"] = camera_width 377 | else: 378 | assert len(camera_names) == 1 379 | if has_camera: 380 | new_kwargs["camera_name"] = camera_names[0] 381 | new_kwargs["camera_height"] = camera_height 382 | new_kwargs["camera_width"] = camera_width 383 | 384 | kwargs.update(new_kwargs) 385 | 386 | # also initialize obs utils so it knows which modalities are image modalities 387 | image_modalities = list(camera_names) 388 | if is_v1: 389 | image_modalities = ["{}_image".format(cn) for cn in camera_names] 390 | elif has_camera: 391 | # v0.3 only had support for one image, and it was named "rgb" 392 | assert len(image_modalities) == 1 393 | image_modalities = ["rgb"] 394 | obs_modality_specs = { 395 | "obs": { 396 | "low_dim": [], # technically unused, so we don't have to specify all of them 397 | "rgb": image_modalities, 398 | } 399 | } 400 | ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs) 401 | 402 | # note that @postprocess_visual_obs is False since this env's images will be written to a dataset 403 | return cls( 404 | env_name=env_name, 405 | render=False, 406 | render_offscreen=has_camera, 407 | use_image_obs=has_camera, 408 | postprocess_visual_obs=False, 409 | **kwargs, 410 | ) 411 | 412 | @property 413 | def rollout_exceptions(self): 414 | """ 415 | Return tuple of exceptions to except when doing rollouts. This is useful to ensure 416 | that the entire training run doesn't crash because of a bad policy that causes unstable 417 | simulation computations. 418 | """ 419 | return tuple() 420 | 421 | def __repr__(self): 422 | """ 423 | Pretty-print env description. 424 | """ 425 | return self.name + "\n" + json.dumps(self._init_kwargs, sort_keys=True, indent=4) 426 | -------------------------------------------------------------------------------- /optimus/utils/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | """ 6 | This file contains Dataset classes that are used by torch dataloaders 7 | to fetch batches from hdf5 files. 
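
Example (an illustrative sketch; mirrors the kwargs assembled by
optimus.utils.train_utils.dataset_factory, with hypothetical paths and keys):
    from optimus.utils.dataset import SequenceDataset

    dataset = SequenceDataset(
        hdf5_path="/path/to/demo.hdf5",
        obs_keys=("object", "robot0_eef_pos", "robot0_eef_quat"),
        dataset_keys=("actions",),
        load_next_obs=True,
        frame_stack=1,
        seq_length=10,
        pad_frame_stack=True,
        pad_seq_length=True,
        get_pad_mask=False,
        goal_mode=None,
        hdf5_cache_mode="low_dim",
        hdf5_use_swmr=True,
        hdf5_normalize_obs=False,
        filter_by_attribute=None,
        transformer_enabled=True,
    )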
8 | """ 9 | import numpy as np 10 | import robomimic.utils.log_utils as LogUtils 11 | import robomimic.utils.obs_utils as ObsUtils 12 | import robomimic.utils.tensor_utils as TensorUtils 13 | from robomimic.utils.dataset import SequenceDataset 14 | 15 | 16 | class SequenceDataset(SequenceDataset): 17 | def __init__( 18 | self, 19 | *args, 20 | transformer_enabled=False, 21 | **kwargs, 22 | ): 23 | self.transformer_enabled = transformer_enabled 24 | self.vis_data = dict() 25 | self.ep_to_hdf5_file = None 26 | super().__init__(*args, **kwargs) 27 | 28 | def get_dataset_for_ep(self, ep, key): 29 | """ 30 | Helper utility to get a dataset for a specific demonstration. 31 | Takes into account whether the dataset has been loaded into memory. 32 | """ 33 | if self.ep_to_hdf5_file is None: 34 | self.ep_to_hdf5_file = {ep: self.hdf5_file for ep in self.demos} 35 | # check if this key should be in memory 36 | key_should_be_in_memory = self.hdf5_cache_mode in ["all", "low_dim"] 37 | if key_should_be_in_memory: 38 | # if key is an observation, it may not be in memory 39 | if "/" in key: 40 | key1, key2 = key.split("/") 41 | assert key1 in ["obs", "next_obs"] 42 | if key2 not in self.obs_keys_in_memory: 43 | key_should_be_in_memory = False 44 | 45 | if key_should_be_in_memory: 46 | # read cache 47 | if "/" in key: 48 | key1, key2 = key.split("/") 49 | assert key1 in ["obs", "next_obs"] 50 | ret = self.hdf5_cache[ep][key1][key2] 51 | else: 52 | ret = self.hdf5_cache[ep][key] 53 | else: 54 | # read from file 55 | hd5key = "data/{}/{}".format(ep, key) 56 | ret = self.ep_to_hdf5_file[ep][hd5key] 57 | return ret 58 | 59 | def get_sequence_from_demo( 60 | self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1 61 | ): 62 | """ 63 | Extract a (sub)sequence of data items from a demo given the @keys of the items. 64 | 65 | Args: 66 | demo_id (str): id of the demo, e.g., demo_0 67 | index_in_demo (int): beginning index of the sequence wrt the demo 68 | keys (tuple): list of keys to extract 69 | num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range 70 | seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range 71 | 72 | Returns: 73 | a dictionary of extracted items. 74 | """ 75 | assert num_frames_to_stack >= 0 76 | assert seq_length >= 1 77 | 78 | demo_length = self._demo_id_to_demo_length[demo_id] 79 | assert index_in_demo < demo_length 80 | 81 | # determine begin and end of sequence 82 | seq_begin_index = max(0, index_in_demo - num_frames_to_stack) 83 | seq_end_index = min(demo_length, index_in_demo + seq_length) 84 | 85 | # determine sequence padding 86 | seq_begin_pad = max(0, num_frames_to_stack - index_in_demo) # pad for frame stacking 87 | seq_end_pad = max(0, index_in_demo + seq_length - demo_length) # pad for sequence length 88 | 89 | # make sure we are not padding if specified. 
90 |         if not self.pad_frame_stack:
91 |             assert seq_begin_pad == 0
92 |         if not self.pad_seq_length:
93 |             assert seq_end_pad == 0
94 | 
95 |         # fetch observation from the dataset file
96 |         seq = dict()
97 |         for k in keys:
98 |             data = self.get_dataset_for_ep(demo_id, k)
99 |             seq[k] = data[seq_begin_index:seq_end_index]
100 |         seq = TensorUtils.pad_sequence(seq, padding=(seq_begin_pad, seq_end_pad), pad_same=True)
101 |         pad_mask = np.array(
102 |             [0] * seq_begin_pad + [1] * (seq_end_index - seq_begin_index) + [0] * seq_end_pad
103 |         )
104 |         pad_mask = pad_mask[:, None].astype(bool)
105 | 
106 |         return seq, pad_mask
107 | 
108 |     def get_obs_sequence_from_demo(
109 |         self,
110 |         demo_id,
111 |         index_in_demo,
112 |         keys,
113 |         num_frames_to_stack=0,
114 |         seq_length=1,
115 |         prefix="obs",
116 |     ):
117 |         """
118 |         Extract a (sub)sequence of observation items from a demo given the @keys of the items.
119 | 
120 |         Args:
121 |             demo_id (str): id of the demo, e.g., demo_0
122 |             index_in_demo (int): beginning index of the sequence wrt the demo
123 |             keys (tuple): list of keys to extract
124 |             num_frames_to_stack (int): number of frames to stack. Seq gets prepended with repeated items if out of range
125 |             seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range
126 |             prefix (str): one of "obs", "next_obs"
127 | 
128 |         Returns:
129 |             a dictionary of extracted items.
130 |         """
131 |         obs, pad_mask = self.get_sequence_from_demo(
132 |             demo_id,
133 |             index_in_demo=index_in_demo,
134 |             keys=tuple("{}/{}".format(prefix, k) for k in keys),
135 |             num_frames_to_stack=num_frames_to_stack,
136 |             seq_length=seq_length,
137 |         )
138 |         obs = {k.split("/")[1]: obs[k] for k in obs}  # strip the prefix
139 |         if self.get_pad_mask:
140 |             obs["pad_mask"] = pad_mask
141 | 
142 |         # prepare image observations from dataset
143 |         return obs
144 | 
145 |     def load_dataset_in_memory(self, demo_list, hdf5_file, obs_keys, dataset_keys, load_next_obs):
146 |         """
147 |         Loads the hdf5 dataset into memory, preserving the structure of the file. Note that this
148 |         differs from `self.getitem_cache`, which, if active, actually caches the outputs of the
149 |         `getitem` operation.
150 | 
151 |         Args:
152 |             demo_list (list): list of demo keys, e.g., 'demo_0'
153 |             hdf5_file (h5py.File): file handle to the hdf5 dataset.
154 |             obs_keys (list, tuple): observation keys to fetch, e.g., 'images'
155 |             dataset_keys (list, tuple): dataset keys to fetch, e.g., 'actions'
156 |             load_next_obs (bool): whether to load next_obs from the dataset
157 | 
158 |         Returns:
159 |             all_data (dict): dictionary of loaded data.
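
        Example (an illustrative sketch; assumes next_obs stores an "object" key so
        a 3-dim "goal" can be derived from its final entry):
            data = dataset.load_dataset_in_memory(dataset.demos, dataset.hdf5_file,
                                                  obs_keys=("object",),
                                                  dataset_keys=("actions",),
                                                  load_next_obs=True)
            data["demo_0"]["obs"]["goal"].shape  # (num_samples, 3)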
160 | """ 161 | all_data = dict() 162 | 163 | print("SequenceDataset: loading dataset into memory...") 164 | obs_keys = [o for o in obs_keys if o != "timesteps" and o != "goal"] 165 | 166 | for ep in LogUtils.custom_tqdm(demo_list): 167 | all_data[ep] = {} 168 | all_data[ep]["attrs"] = {} 169 | all_data[ep]["attrs"]["num_samples"] = hdf5_file["data/{}".format(ep)].attrs[ 170 | "num_samples" 171 | ] 172 | 173 | # get other dataset keys 174 | for k in dataset_keys: 175 | if k in hdf5_file["data/{}".format(ep)]: 176 | all_data[ep][k] = hdf5_file["data/{}/{}".format(ep, k)][()].astype("float32") 177 | else: 178 | all_data[ep][k] = np.zeros( 179 | (all_data[ep]["attrs"]["num_samples"], 1), dtype=np.float32 180 | ) 181 | # get obs 182 | all_data[ep]["obs"] = { 183 | k: hdf5_file["data/{}/obs/{}".format(ep, k)][()].astype("float32") for k in obs_keys 184 | } 185 | 186 | if self.load_next_obs: 187 | # last block position is given by last elem of next_obs 188 | goal = hdf5_file["data/{}/next_obs/{}".format(ep, "object")][()].astype("float32")[ 189 | -1, 7:10 190 | ] 191 | all_data[ep]["obs"]["goal"] = np.repeat( 192 | goal.reshape(1, -1), all_data[ep]["attrs"]["num_samples"], axis=0 193 | ) 194 | 195 | if self.transformer_enabled: 196 | all_data[ep]["obs"]["timesteps"] = np.arange( 197 | 0, all_data[ep]["obs"][obs_keys[0]].shape[0] 198 | ).reshape(-1, 1) 199 | if load_next_obs: 200 | all_data[ep]["next_obs"] = { 201 | k: hdf5_file["data/{}/next_obs/{}".format(ep, k)][()].astype("float32") 202 | for k in obs_keys 203 | } 204 | if self.transformer_enabled: 205 | # Doesn't actually matter, won't be used 206 | all_data[ep]["next_obs"]["timesteps"] = np.zeros_like( 207 | all_data[ep]["obs"]["timesteps"] 208 | ) 209 | all_data[ep]["next_obs"]["goal"] = np.repeat( 210 | goal.reshape(1, -1), all_data[ep]["attrs"]["num_samples"], axis=0 211 | ) 212 | return all_data 213 | 214 | def get_dataset_sequence_from_demo( 215 | self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1 216 | ): 217 | """ 218 | Extract a (sub)sequence of dataset items from a demo given the @keys of the items (e.g., states, actions). 219 | 220 | Args: 221 | demo_id (str): id of the demo, e.g., demo_0 222 | index_in_demo (int): beginning index of the sequence wrt the demo 223 | keys (tuple): list of keys to extract 224 | num_frames_to_stack (int): numbers of frame to stack. Seq gets prepended with repeated items if out of range 225 | seq_length (int): sequence length to extract. Seq gets post-pended with repeated items if out of range 226 | 227 | Returns: 228 | a dictionary of extracted items. 229 | """ 230 | data, pad_mask = self.get_sequence_from_demo( 231 | demo_id, 232 | index_in_demo=index_in_demo, 233 | keys=keys, 234 | num_frames_to_stack=num_frames_to_stack, # don't frame stack for meta keys 235 | seq_length=seq_length, 236 | ) 237 | if self.get_pad_mask: 238 | data["pad_mask"] = pad_mask 239 | return data 240 | 241 | def get_item(self, index): 242 | """ 243 | Main implementation of getitem when not using cache. 
244 | """ 245 | 246 | demo_id = self._index_to_demo_id[index] 247 | demo_start_index = self._demo_id_to_start_indices[demo_id] 248 | demo_length = self._demo_id_to_demo_length[demo_id] 249 | 250 | # start at offset index if not padding for frame stacking 251 | demo_index_offset = 0 if self.pad_frame_stack else (self.n_frame_stack - 1) 252 | index_in_demo = index - demo_start_index + demo_index_offset 253 | 254 | # end at offset index if not padding for seq length 255 | demo_length_offset = 0 if self.pad_seq_length else (self.seq_length - 1) 256 | end_index_in_demo = demo_length - demo_length_offset 257 | 258 | keys = [*self.dataset_keys] 259 | meta = self.get_dataset_sequence_from_demo( 260 | demo_id, 261 | index_in_demo=index_in_demo, 262 | keys=keys, 263 | num_frames_to_stack=self.n_frame_stack - 1, 264 | seq_length=self.seq_length, 265 | ) 266 | 267 | # determine goal index 268 | goal_index = None 269 | if self.goal_mode == "last": 270 | goal_index = end_index_in_demo - 1 271 | 272 | meta["obs"] = self.get_obs_sequence_from_demo( 273 | demo_id, 274 | index_in_demo=index_in_demo, 275 | keys=self.obs_keys, 276 | num_frames_to_stack=self.n_frame_stack - 1, 277 | seq_length=self.seq_length, 278 | prefix="obs", 279 | ) 280 | if self.hdf5_normalize_obs: 281 | meta["obs"] = ObsUtils.normalize_obs( 282 | meta["obs"], obs_normalization_stats=self.obs_normalization_stats 283 | ) 284 | 285 | if self.load_next_obs: 286 | meta["next_obs"] = self.get_obs_sequence_from_demo( 287 | demo_id, 288 | index_in_demo=index_in_demo, 289 | keys=self.obs_keys, 290 | num_frames_to_stack=self.n_frame_stack - 1, 291 | seq_length=self.seq_length, 292 | prefix="next_obs", 293 | ) 294 | if self.hdf5_normalize_obs: 295 | meta["next_obs"] = ObsUtils.normalize_obs( 296 | meta["next_obs"], 297 | obs_normalization_stats=self.obs_normalization_stats, 298 | ) 299 | 300 | if goal_index is not None: 301 | goal = self.get_obs_sequence_from_demo( 302 | demo_id, 303 | index_in_demo=goal_index, 304 | keys=self.obs_keys, 305 | num_frames_to_stack=0, 306 | seq_length=1, 307 | prefix="next_obs", 308 | ) 309 | if self.hdf5_normalize_obs: 310 | goal = ObsUtils.normalize_obs( 311 | goal, obs_normalization_stats=self.obs_normalization_stats 312 | ) 313 | meta["goal_obs"] = {k: goal[k][0] for k in goal} # remove sequence dimension for goal 314 | 315 | return meta 316 | 317 | def update_demo_info(self, demos, online_epoch, data, hdf5_file=None): 318 | """ 319 | This function is called during online epochs to update the demo information based 320 | on newly collected demos. 321 | Args: 322 | demos (list): list of demonstration keys to load data. 
323 | online_epoch (int): value of the current online epoch 324 | data (dict): dictionary containing newly collected demos 325 | """ 326 | # sort demo keys 327 | inds = np.argsort( 328 | [int(elem[5:]) for elem in demos if not (elem in ["env_args", "model_file"])] 329 | ) 330 | new_demos = [demos[i] for i in inds] 331 | self.demos.extend(new_demos) 332 | 333 | self.n_demos = len(self.demos) 334 | 335 | self.prev_total_num_sequences = self.total_num_sequences 336 | for new_ep in new_demos: 337 | self.ep_to_hdf5_file[new_ep] = hdf5_file 338 | demo_length = data[new_ep]["num_samples"] 339 | self._demo_id_to_start_indices[new_ep] = self.total_num_sequences 340 | self._demo_id_to_demo_length[new_ep] = demo_length 341 | 342 | num_sequences = demo_length 343 | # determine actual number of sequences taking into account whether to pad for frame_stack and seq_length 344 | if not self.pad_frame_stack: 345 | num_sequences -= self.n_frame_stack - 1 346 | if not self.pad_seq_length: 347 | num_sequences -= self.seq_length - 1 348 | 349 | if self.pad_seq_length: 350 | assert demo_length >= 1 # sequence needs to have at least one sample 351 | num_sequences = max(num_sequences, 1) 352 | else: 353 | assert ( 354 | num_sequences >= 1 355 | ) # assume demo_length >= (self.n_frame_stack - 1 + self.seq_length) 356 | 357 | for _ in range(num_sequences): 358 | self._index_to_demo_id[self.total_num_sequences] = new_ep 359 | self.total_num_sequences += 1 360 | return new_demos 361 | 362 | def update_dataset_in_memory( 363 | self, 364 | demo_list, 365 | data, 366 | obs_keys, 367 | dataset_keys, 368 | load_next_obs=False, 369 | online_epoch=0, 370 | ): 371 | """ 372 | Loads the newly collected dataset into memory, preserving the structure of the data. Note that this 373 | differs from `self.getitem_cache`, which, if active, actually caches the outputs of the 374 | `getitem` operation. 375 | 376 | Args: 377 | demo_list (list): list of demo keys, e.g., 'demo_0' 378 | data (dict): dictionary containing newly collected demos 379 | obs_keys (list, tuple): observation keys to fetch, e.g., 'images' 380 | dataset_keys (list, tuple): dataset keys to fetch, e.g., 'actions' 381 | load_next_obs (bool): whether to load next_obs from the dataset 382 | 383 | Returns: 384 | all_data (dict): dictionary of loaded data. 
385 | """ 386 | all_data = dict() 387 | print("SequenceDataset: loading dataset into memory...") 388 | obs_keys = [o for o in obs_keys if o != "timesteps"] 389 | for new_ep in LogUtils.custom_tqdm(demo_list): 390 | all_data[new_ep] = {} 391 | all_data[new_ep]["attrs"] = {} 392 | all_data[new_ep]["attrs"]["num_samples"] = data[new_ep]["num_samples"] 393 | 394 | # get other dataset keys 395 | for k in dataset_keys: 396 | if k in data[new_ep]: 397 | all_data[new_ep][k] = data[new_ep][k].astype("float32") 398 | else: 399 | all_data[new_ep][k] = np.zeros( 400 | (all_data[new_ep]["attrs"]["num_samples"], 1), dtype=np.float32 401 | ) 402 | # get obs 403 | all_data[new_ep]["obs"] = { 404 | k: data[new_ep]["obs"][k] for k in obs_keys if k != "timesteps" 405 | } 406 | 407 | for k in all_data[new_ep]["obs"]: 408 | all_data[new_ep]["obs"][k] = all_data[new_ep]["obs"][k].astype("float32") 409 | 410 | if self.transformer_enabled: 411 | all_data[new_ep]["obs"]["timesteps"] = np.arange( 412 | 0, all_data[new_ep]["obs"][obs_keys[0]].shape[0] 413 | ).reshape(-1, 1) 414 | 415 | self.hdf5_cache.update(all_data) 416 | -------------------------------------------------------------------------------- /optimus/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | import json 5 | import os 6 | import time 7 | from collections import OrderedDict 8 | 9 | import imageio 10 | import numpy as np 11 | import robomimic.utils.log_utils as LogUtils 12 | from robomimic.algo import RolloutPolicy 13 | from robomimic.utils.train_utils import * 14 | 15 | from optimus.config.base_config import config_factory 16 | from optimus.envs.wrappers import FrameStackWrapper 17 | from optimus.scripts.combine_hdf5 import global_dataset_updates, write_trajectory_to_dataset 18 | from optimus.utils.dataset import SequenceDataset 19 | 20 | import optimus 21 | 22 | 23 | def get_exp_dir(config, auto_remove_exp_dir=False): 24 | """ 25 | Create experiment directory from config. If an identical experiment directory 26 | exists and @auto_remove_exp_dir is False (default), the function will prompt 27 | the user on whether to remove and replace it, or keep the existing one and 28 | add a new subdirectory with the new timestamp for the current run. 29 | 30 | Args: 31 | auto_remove_exp_dir (bool): if True, automatically remove the existing experiment 32 | folder if it exists at the same path. 
33 | 34 | Returns: 35 | log_dir (str): path to created log directory (sub-folder in experiment directory) 36 | output_dir (str): path to created models directory (sub-folder in experiment directory) 37 | to store model checkpoints 38 | video_dir (str): path to video directory (sub-folder in experiment directory) 39 | to store rollout videos 40 | """ 41 | # timestamp for directory names 42 | t_now = time.time() 43 | time_str = datetime.datetime.fromtimestamp(t_now).strftime("%Y%m%d%H%M%S") 44 | 45 | # create directory for where to dump model parameters, tensorboard logs, and videos 46 | base_output_dir = config.train.output_dir 47 | if not os.path.isabs(base_output_dir): 48 | # relative paths are specified relative to optimus module location 49 | base_output_dir = os.path.join(optimus.__path__[0], '../'+base_output_dir) 50 | base_output_dir = os.path.join(base_output_dir, config.experiment.name) 51 | if os.path.exists(base_output_dir): 52 | if not auto_remove_exp_dir: 53 | ans = input( 54 | "WARNING: model directory ({}) already exists! \noverwrite? (y/n)\n".format( 55 | base_output_dir 56 | ) 57 | ) 58 | else: 59 | ans = "y" 60 | if ans == "y": 61 | print("REMOVING") 62 | shutil.rmtree(base_output_dir) 63 | 64 | # only make model directory if model saving is enabled 65 | output_dir = None 66 | if config.experiment.save.enabled: 67 | output_dir = os.path.join(base_output_dir, time_str, "models") 68 | os.makedirs(output_dir) 69 | 70 | # tensorboard directory 71 | log_dir = os.path.join(base_output_dir, time_str, "logs") 72 | os.makedirs(log_dir) 73 | 74 | # video directory 75 | video_dir = os.path.join(base_output_dir, time_str, "videos") 76 | os.makedirs(video_dir) 77 | return log_dir, output_dir, video_dir, time_str 78 | 79 | 80 | def load_data_for_training(config, obs_keys): 81 | """ 82 | Data loading at the start of an algorithm. 83 | 84 | Args: 85 | config (BaseConfig instance): config object 86 | obs_keys (list): list of observation modalities that are required for 87 | training (this will inform the dataloader on what modalities to load) 88 | 89 | Returns: 90 | train_dataset (SequenceDataset instance): train dataset object 91 | valid_dataset (SequenceDataset instance): valid dataset object (only if using validation) 92 | """ 93 | 94 | # config can contain an attribute to filter on 95 | filter_by_attribute = config.train.hdf5_filter_key 96 | 97 | # load the dataset into memory 98 | if config.experiment.validate: 99 | train_dataset = dataset_factory(config, obs_keys, filter_by_attribute=filter_by_attribute) 100 | valid_dataset = dataset_factory(config, obs_keys, filter_by_attribute="valid") 101 | else: 102 | train_dataset = dataset_factory(config, obs_keys, filter_by_attribute=filter_by_attribute) 103 | valid_dataset = None 104 | 105 | return train_dataset, valid_dataset 106 | 107 | 108 | def dataset_factory(config, obs_keys, filter_by_attribute=None, dataset_path=None): 109 | """ 110 | Create a SequenceDataset instance to pass to a torch DataLoader. 111 | 112 | Args: 113 | config (BaseConfig instance): config object 114 | 115 | obs_keys (list): list of observation modalities that are required for 116 | training (this will inform the dataloader on what modalities to load) 117 | 118 | filter_by_attribute (str): if provided, use the provided filter key 119 | to select a subset of demonstration trajectories to load 120 | 121 | dataset_path (str): if provided, the SequenceDataset instance should load 122 | data from this dataset path. Defaults to config.train.data. 
123 | 124 | Returns: 125 | dataset (SequenceDataset instance): dataset object 126 | """ 127 | if dataset_path is None: 128 | dataset_path = config.train.data 129 | 130 | ds_kwargs = dict( 131 | hdf5_path=dataset_path, 132 | obs_keys=obs_keys, 133 | dataset_keys=config.train.dataset_keys, 134 | load_next_obs=config.train.load_next_obs, 135 | frame_stack=config.train.frame_stack, 136 | seq_length=config.train.seq_length, 137 | pad_frame_stack=config.train.pad_frame_stack, 138 | pad_seq_length=config.train.pad_seq_length, 139 | get_pad_mask=False, 140 | goal_mode=config.train.goal_mode, 141 | hdf5_cache_mode=config.train.hdf5_cache_mode, 142 | hdf5_use_swmr=config.train.hdf5_use_swmr, 143 | hdf5_normalize_obs=config.train.hdf5_normalize_obs, 144 | filter_by_attribute=filter_by_attribute, 145 | transformer_enabled=config.algo.transformer.enabled, 146 | ) 147 | dataset = SequenceDataset(**ds_kwargs) 148 | 149 | return dataset 150 | 151 | 152 | def run_rollout( 153 | policy, 154 | env, 155 | horizon, 156 | use_goals=False, 157 | render=False, 158 | video_writer=None, 159 | video_skip=5, 160 | terminate_on_success=False, 161 | ): 162 | """ 163 | Runs a rollout in an environment with the current network parameters. 164 | 165 | Args: 166 | policy (RolloutPolicy instance): policy to use for rollouts. 167 | 168 | env (EnvBase instance): environment to use for rollouts. 169 | 170 | horizon (int): maximum number of steps to roll the agent out for 171 | 172 | use_goals (bool): if True, agent is goal-conditioned, so provide goal observations from env 173 | 174 | render (bool): if True, render the rollout to the screen 175 | 176 | video_writer (imageio Writer instance): if not None, use video writer object to append frames at 177 | rate given by @video_skip 178 | 179 | video_skip (int): how often to write video frame 180 | 181 | terminate_on_success (bool): if True, terminate episode early as soon as a success is encountered 182 | 183 | Returns: 184 | results (dict): dictionary containing return, success rate, etc. 
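
    Example (an illustrative sketch; assumes `policy` is a RolloutPolicy and `env`
    is an EnvBase instance or a FrameStackWrapper around one):
        results, traj = run_rollout(policy, env, horizon=400, terminate_on_success=True)
        print(results["Success_Rate"], results["Horizon"])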
185 | """ 186 | assert isinstance(policy, RolloutPolicy) 187 | assert ( 188 | isinstance(env, EnvBase) 189 | or isinstance(env.env, EnvBase) 190 | or isinstance(env, FrameStackWrapper) 191 | ) 192 | 193 | policy.start_episode() 194 | 195 | ob_dict = env.reset() 196 | goal_dict = None 197 | if use_goals: 198 | # retrieve goal from the environment 199 | goal_dict = env.get_goal() 200 | 201 | results = {} 202 | video_count = 0 # video frame counter 203 | 204 | total_reward = 0.0 205 | success = {k: False for k in env.is_success()} # success metrics 206 | obs_log = {k: [v.reshape(1, -1)] for k, v in ob_dict.items() if not (k.endswith("image"))} 207 | traj = dict(actions=[], states=[], initial_state_dict=env.get_state()) 208 | try: 209 | for step_i in range(horizon): 210 | state_dict = env.get_state() 211 | traj["states"].append(state_dict["states"]) 212 | # get action from policy 213 | ac = policy(ob=ob_dict, goal=goal_dict) 214 | # play action 215 | ob_dict, r, done, info = env.step(ac) 216 | 217 | for k, v in ob_dict.items(): 218 | if not (k.endswith("image")): 219 | obs_log[k].append(v.reshape(1, -1)) 220 | 221 | # render to screen 222 | if render: 223 | env.render(mode="human") 224 | 225 | # compute reward 226 | total_reward += r 227 | 228 | cur_success_metrics = env.is_success() 229 | for k in success: 230 | success[k] = success[k] or cur_success_metrics[k] 231 | 232 | # visualization 233 | if video_writer is not None: 234 | if video_count % video_skip == 0: 235 | video_img = [] 236 | video_img.append(env.render(mode="rgb_array", height=512, width=512)) 237 | video_img = np.concatenate(video_img, axis=1) # concatenate horizontally 238 | video_writer.append_data(video_img) 239 | video_count += 1 240 | 241 | # break if done 242 | if done or (terminate_on_success and success["task"]): 243 | break 244 | state_dict = env.get_state() 245 | traj["states"].append(state_dict["states"]) 246 | traj["actions"] = np.array([0]) # just a dummy value 247 | traj["attrs"] = dict(num_samples=len(traj["states"])) 248 | except env.rollout_exceptions as e: 249 | print("WARNING: got rollout exception {}".format(e)) 250 | 251 | for k, v in obs_log.items(): 252 | obs_log[k] = np.concatenate(v, axis=0) 253 | 254 | results["Return"] = total_reward 255 | results["Horizon"] = step_i + 1 256 | results["Success_Rate"] = float(success["task"]) 257 | results["Observations"] = obs_log 258 | for k, v in info.items(): 259 | if not (k.endswith("actions")) and not (k.endswith("obs")): 260 | results[k] = v 261 | 262 | # log additional success metrics 263 | for k in success: 264 | if k != "task": 265 | results["{}_Success_Rate".format(k)] = float(success[k]) 266 | return results, traj 267 | 268 | 269 | @torch.no_grad() 270 | def rollout_with_stats( 271 | policy, 272 | envs, 273 | horizon, 274 | use_goals=False, 275 | num_episodes=None, 276 | render=False, 277 | video_dir=None, 278 | video_path=None, 279 | epoch=None, 280 | video_skip=5, 281 | terminate_on_success=False, 282 | verbose=False, 283 | rollout_dir=None, 284 | config=None, 285 | ): 286 | """ 287 | A helper function used in the train loop to conduct evaluation rollouts per environment 288 | and summarize the results. 289 | 290 | Can specify @video_dir (to dump a video per environment) or @video_path (to dump a single video 291 | for all environments). 292 | 293 | Args: 294 | policy (RolloutPolicy instance): policy to use for rollouts. 295 | 296 | envs (dict): dictionary that maps env_name (str) to EnvBase instance. The policy will 297 | be rolled out in each env. 
298 | 299 | horizon (int): maximum number of steps to roll the agent out for 300 | 301 | use_goals (bool): if True, agent is goal-conditioned, so provide goal observations from env 302 | 303 | num_episodes (int): number of rollout episodes per environment 304 | 305 | render (bool): if True, render the rollout to the screen 306 | 307 | video_dir (str): if not None, dump rollout videos to this directory (one per environment) 308 | 309 | video_path (str): if not None, dump a single rollout video for all environments 310 | 311 | epoch (int): epoch number (used for video naming) 312 | 313 | video_skip (int): how often to write video frame 314 | 315 | terminate_on_success (bool): if True, terminate episode early as soon as a success is encountered 316 | 317 | verbose (bool): if True, print results of each rollout 318 | 319 | Returns: 320 | all_rollout_logs (dict): dictionary of rollout statistics (e.g. return, success rate, ...) 321 | averaged across all rollouts 322 | 323 | video_paths (dict): path to rollout videos for each environment 324 | """ 325 | assert isinstance(policy, RolloutPolicy) 326 | 327 | all_rollout_logs = OrderedDict() 328 | 329 | # handle paths and create writers for video writing 330 | assert (video_path is None) or ( 331 | video_dir is None 332 | ), "rollout_with_stats: can't specify both video path and dir" 333 | write_video = (video_path is not None) or (video_dir is not None) 334 | video_paths = OrderedDict() 335 | video_writers = OrderedDict() 336 | if video_path is not None: 337 | # a single video is written for all envs 338 | video_paths = {k: video_path for k in envs} 339 | video_writer = imageio.get_writer(video_path, fps=20) 340 | video_writers = {k: video_writer for k in envs} 341 | if video_dir is not None: 342 | # video is written per env 343 | video_str = "_epoch_{}.mp4".format(epoch) if epoch is not None else ".mp4" 344 | video_paths = {k: os.path.join(video_dir, "{}{}".format(k, video_str)) for k in envs} 345 | video_writers = {k: imageio.get_writer(video_paths[k], fps=20) for k in envs} 346 | 347 | for env_name, env in envs.items(): 348 | env_video_writer = None 349 | if write_video: 350 | print("video writes to " + video_paths[env_name]) 351 | env_video_writer = video_writers[env_name] 352 | 353 | print( 354 | "rollout: env={}, horizon={}, use_goals={}, num_episodes={}".format( 355 | env.name, 356 | horizon, 357 | use_goals, 358 | num_episodes, 359 | ) 360 | ) 361 | rollout_logs = [] 362 | num_valid_demos = get_num_valid_demos(config, env_name) 363 | num_episodes = min(num_valid_demos, num_episodes) 364 | iterator = range(num_episodes) 365 | if not verbose: 366 | iterator = LogUtils.custom_tqdm(iterator, total=num_episodes) 367 | 368 | num_success = 0 369 | obs_logs = {} 370 | env.sample_eval_episodes(num_episodes) 371 | data_writer = h5py.File(os.path.join(rollout_dir, f"rollout_{epoch}.hdf5"), "w") 372 | data_grp = data_writer.create_group("data") 373 | for ep_i in iterator: 374 | rollout_timestamp = time.time() 375 | rollout_info, traj = run_rollout( 376 | policy=policy, 377 | env=env, 378 | horizon=horizon, 379 | render=render, 380 | use_goals=use_goals, 381 | video_writer=env_video_writer, 382 | video_skip=video_skip, 383 | terminate_on_success=terminate_on_success, 384 | ) 385 | rollout_info["time"] = time.time() - rollout_timestamp 386 | obs_logs[ep_i] = {"obs": rollout_info["Observations"]} 387 | del rollout_info["Observations"] 388 | rollout_logs.append(rollout_info) 389 | num_success += rollout_info["Success_Rate"] 390 | if verbose: 391 | print( 
392 |                     "Episode {}, horizon={}, num_success={}".format(ep_i + 1, horizon, num_success)
393 |                 )
394 |                 print(json.dumps(rollout_info, sort_keys=True, indent=4))
395 |             write_trajectory_to_dataset(
396 |                 None,
397 |                 traj,
398 |                 data_grp,
399 |                 demo_name=f"demo_{ep_i}",
400 |                 env_type="mujoco",
401 |             )
402 |         global_dataset_updates(data_grp, 0, json.dumps(env.serialize(), indent=4))
403 |         data_writer.close()
404 |         successes = [rollout_log["Success_Rate"] for rollout_log in rollout_logs]
405 |         if video_dir is not None:
406 |             # close this env's video writer (next env has its own)
407 |             env_video_writer.close()
408 | 
409 |         # average metric across all episodes
410 |         rollout_logs = dict(
411 |             (
412 |                 k,
413 |                 [
414 |                     rollout_logs[i][k]
415 |                     for i in range(len(rollout_logs))
416 |                     if rollout_logs[i][k] is not None
417 |                 ],
418 |             )
419 |             for k in rollout_logs[0]
420 |         )
421 |         rollout_logs_mean = dict((k, np.mean(v)) for k, v in rollout_logs.items())
422 |         rollout_logs_mean["Time_Episode"] = (
423 |             np.sum(rollout_logs["time"]) / 60.0
424 |         )  # total time taken for rollouts in minutes
425 |         all_rollout_logs[env_name] = rollout_logs_mean
426 | 
427 |     if video_path is not None:
428 |         # close video writer that was used for all envs
429 |         video_writer.close()
430 | 
431 |     return all_rollout_logs, video_paths
432 | 
433 | 
434 | def get_num_valid_demos(config, env_name):
435 |     valid_key = config.experiment.rollout.valid_key
436 |     dataset_path = config.train.data
437 |     if dataset_path is not None:
438 |         hdf5_file = h5py.File(dataset_path, "r", swmr=True, libver="latest")
439 |         filter_key = valid_key
440 |         demos = [
441 |             elem.decode("utf-8") for elem in np.array(hdf5_file["mask/{}".format(filter_key)][:])
442 |         ]
443 |         return len(demos)
444 |     else:
445 |         return 0
446 | 
447 | 
448 | def get_config_from_path(config_path):
449 |     ext_cfg = json.load(open(config_path, "r"))
450 |     config = config_factory(ext_cfg["algo_name"])
451 |     # update config with external json - this will throw errors if
452 |     # the external config has keys not present in the base algo config
453 |     with config.values_unlocked():
454 |         config.update(ext_cfg)
455 |     return config
456 | 
--------------------------------------------------------------------------------
/optimus/scripts/run_trained_agent_pl.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 | """
5 | The main script for evaluating a policy in an environment, adapted to use PyTorch Lightning and the Optimus codebase.
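The checkpoint is expected to be a PyTorch Lightning .ckpt file, and the run configuration is read
from config.json in the run directory (see the --resume_dir argument below).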
6 | 
7 | Args:
8 |     agent (str): path to saved checkpoint ckpt file
9 | 
10 |     horizon (int): if provided, override maximum horizon of rollout from the one
11 |         in the checkpoint
12 | 
13 |     env (str): if provided, override name of env from the one in the checkpoint,
14 |         and use it for rollouts
15 | 
16 |     render (bool): if flag is provided, use on-screen rendering during rollouts
17 | 
18 |     video_path (str): if provided, render trajectories to this video file path
19 | 
20 |     video_skip (int): render frames to a video every @video_skip steps
21 | 
22 |     camera_names (str or [str]): camera name(s) to use for rendering on-screen or to video
23 | 
24 |     dataset_path (str): if provided, an hdf5 file will be written at this path with the
25 |         rollout data
26 | 
27 |     dataset_obs (bool): if flag is provided, and @dataset_path is provided, include
28 |         possibly high-dimensional observations in the output dataset hdf5 file (by default,
29 |         observations are excluded and only simulator states are saved).
30 | 
31 |     seed (int): if provided, set seed for rollouts
32 | 
33 | Example usage:
34 | 
35 |     # Evaluate a policy with 50 rollouts of maximum horizon 400 and save the rollouts to a video.
36 |     # Visualize the agentview and wrist cameras during the rollout.
37 | 
38 |     python run_trained_agent_pl.py --agent /path/to/model.ckpt \
39 |         --n_rollouts 50 --horizon 400 --seed 0 \
40 |         --video_path /path/to/output.mp4 \
41 |         --camera_names agentview robot0_eye_in_hand
42 | 
43 |     # Write the 50 agent rollouts to a new dataset hdf5.
44 | 
45 |     python run_trained_agent_pl.py --agent /path/to/model.ckpt \
46 |         --n_rollouts 50 --horizon 400 --seed 0 \
47 |         --dataset_path /path/to/output.hdf5 --dataset_obs
48 | 
49 |     # Write the 50 agent rollouts to a new dataset hdf5, but exclude the dataset observations
50 |     # since they might be high-dimensional (they can be extracted again using the
51 |     # dataset_states_to_obs.py script).
52 | 
53 |     python run_trained_agent_pl.py --agent /path/to/model.ckpt \
54 |         --n_rollouts 50 --horizon 400 --seed 0 \
55 |         --dataset_path /path/to/output.hdf5
56 | """
57 | import argparse
58 | import json
59 | import os
60 | import random
61 | from collections import OrderedDict
62 | from copy import deepcopy
63 | 
64 | import h5py
65 | import imageio
66 | import numpy as np
67 | import robomimic.utils.obs_utils as ObsUtils
68 | import robomimic.utils.tensor_utils as TensorUtils
69 | import torch
70 | from pytorch_lightning import seed_everything
71 | from robomimic.algo import RolloutPolicy
72 | from robomimic.envs.env_base import EnvType
73 | from tqdm import tqdm
74 | 
75 | import optimus.utils.env_utils as EnvUtils
76 | import optimus.utils.file_utils as FileUtils
77 | from optimus.algo import algo_factory
78 | from optimus.config.base_config import config_factory
79 | from optimus.envs.wrappers import EvaluateOnDatasetWrapper, FrameStackWrapper
80 | from optimus.scripts.pl_train import ModelWrapper
81 | 
82 | DEFAULT_CAMERAS = {
83 |     EnvType.ROBOSUITE_TYPE: ["agentview"],
84 |     EnvType.IG_MOMART_TYPE: ["rgb"],
85 |     EnvType.GYM_TYPE: ValueError("No camera names supported for gym type env!"),
86 | }
87 | 
88 | 
89 | def rollout(
90 |     policy,
91 |     env,
92 |     horizon,
93 |     render=False,
94 |     video_writer=None,
95 |     video_skip=5,
96 |     return_obs=False,
97 |     camera_names=None,
98 | ):
99 |     """
100 |     Helper function to carry out rollouts. Supports on-screen rendering, off-screen rendering to a video,
101 |     and returns the rollout trajectory.
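    Note: after env.reset(), the code below additionally calls env.reset_to(...) on the initial state,
    which it flags as a hack needed for deterministic action playback in robosuite tasks.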
102 | 
103 |     Args:
104 |         policy (instance of RolloutPolicy): policy loaded from a checkpoint
105 |         env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata
106 |         horizon (int): maximum horizon for the rollout
107 |         render (bool): whether to render rollout on-screen
108 |         video_writer (imageio writer): if provided, use to write rollout to video
109 |         video_skip (int): how often to write video frames
110 |         return_obs (bool): if True, return possibly high-dimensional observations along the trajectory.
111 |             They are excluded by default because the low-dimensional simulation states should be a minimal
112 |             representation of the environment.
113 |         camera_names (list): determines which camera(s) are used for rendering. Pass more than
114 |             one to output a video with multiple camera views concatenated horizontally.
115 | 
116 |     Returns:
117 |         stats (dict): some statistics for the rollout - such as return, horizon, and task success
118 |         traj (dict): dictionary that corresponds to the rollout trajectory
119 |     """
120 |     assert isinstance(policy, RolloutPolicy)
121 |     assert not (render and (video_writer is not None))
122 | 
123 |     policy.start_episode()
124 |     obs = env.reset()
125 |     state_dict = env.get_state()
126 | 
127 |     # hack that is necessary for robosuite tasks for deterministic action playback
128 |     obs = env.reset_to(state_dict)
129 | 
130 |     video_count = 0  # video frame counter
131 |     total_reward = 0.0
132 |     traj = dict(actions=[], rewards=[], dones=[], states=[], initial_state_dict=state_dict)
133 |     if return_obs:
134 |         # store observations too
135 |         traj.update(dict(obs=[], next_obs=[]))
136 |     try:
137 |         for step_i in range(horizon):
138 |             # get action from policy
139 |             act = policy(ob=obs)
140 | 
141 |             # play action
142 |             next_obs, r, done, _ = env.step(act)
143 | 
144 |             # compute reward
145 |             total_reward += r
146 |             success = env.is_success()["task"]
147 | 
148 |             # visualization
149 |             if render:
150 |                 env.render(mode="human", camera_name=camera_names[0])
151 |             if video_writer is not None:
152 |                 if video_count % video_skip == 0:
153 |                     video_img = []
154 |                     for cam_name in camera_names:
155 |                         video_img.append(
156 |                             env.render(
157 |                                 camera_name=cam_name,
158 |                                 mode="rgb_array",
159 |                                 height=512,
160 |                                 width=512,
161 |                             )
162 |                         )
163 |                     video_img = np.concatenate(video_img, axis=1)  # concatenate horizontally
164 |                     video_writer.append_data(video_img)
165 |                 video_count += 1
166 | 
167 |             # collect transition
168 |             traj["actions"].append(act)
169 |             traj["rewards"].append(r)
170 |             traj["dones"].append(done)
171 |             traj["states"].append(state_dict["states"])
172 |             if return_obs:
173 |                 # Note: We need to "unprocess" the observations to prepare to write them to dataset.
174 |                 # This includes operations like channel swapping and float to uint8 conversion
175 |                 # for saving disk space.
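                # (for rgb keys this typically restores uint8 HWC images from the processed float CHW arrays)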
176 | traj["obs"].append(ObsUtils.unprocess_obs_dict(obs)) 177 | traj["next_obs"].append(ObsUtils.unprocess_obs_dict(next_obs)) 178 | 179 | # break if done or if success 180 | if done or success: 181 | break 182 | 183 | # update for next iter 184 | obs = deepcopy(next_obs) 185 | state_dict = env.get_state() 186 | 187 | except env.rollout_exceptions as e: 188 | print("WARNING: got rollout exception {}".format(e)) 189 | 190 | stats = dict(Return=total_reward, Horizon=(step_i + 1), Success_Rate=float(success)) 191 | 192 | if return_obs: 193 | # convert list of dict to dict of list for obs dictionaries (for convenient writes to hdf5 dataset) 194 | traj["obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["obs"]) 195 | traj["next_obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["next_obs"]) 196 | 197 | # list to numpy array 198 | for k in traj: 199 | if k == "initial_state_dict": 200 | continue 201 | if isinstance(traj[k], dict): 202 | for kp in traj[k]: 203 | traj[k][kp] = np.array(traj[k][kp]) 204 | else: 205 | traj[k] = np.array(traj[k]) 206 | return stats, traj 207 | 208 | 209 | def run_trained_agent(args): 210 | # some arg checking 211 | write_video = args.video_path is not None 212 | assert not (args.render and write_video) # either on-screen or video but not both 213 | 214 | if args.render: 215 | # on-screen rendering can only support one camera 216 | assert len(args.camera_names) == 1 217 | 218 | # relative path to agent 219 | ckpt_path = args.agent 220 | 221 | if args.resume_dir: 222 | resume_dir = args.resume_dir 223 | else: 224 | resume_dir = os.path.dirname(ckpt_path)[: -len("models")] 225 | config_path = os.path.join(resume_dir, "config.json") 226 | ext_cfg = json.load(open(config_path, "r")) 227 | config = config_factory(ext_cfg["algo_name"]) 228 | # update config with external json - this will throw errors if 229 | # the external config has keys not present in the base algo config 230 | with config.values_unlocked(): 231 | config.update(ext_cfg) 232 | # read rollout settings 233 | rollout_num_episodes = args.n_rollouts 234 | rollout_horizon = args.horizon 235 | if rollout_horizon is None: 236 | # read horizon from config 237 | rollout_horizon = config.experiment.rollout.horizon 238 | 239 | # read config to set up metadata for observation modalities (e.g. 
detecting rgb observations)
240 |     ObsUtils.initialize_obs_utils_with_config(config)
241 | 
242 |     # make sure the dataset exists
243 |     dataset_path = args.data if args.data else config.train.data
244 |     if args.data:
245 |         config.train.data = args.data
246 |         json.dump(config, open(config_path, "w"), indent=4)
247 |     if not os.path.exists(dataset_path):
248 |         raise Exception("Dataset at provided path {} not found!".format(dataset_path))
249 | 
250 |     # load basic metadata from training file
251 |     print("\n============= Loaded Environment Metadata =============")
252 |     env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path)
253 |     shape_meta = FileUtils.get_shape_metadata_from_dataset(
254 |         dataset_path=dataset_path, all_obs_keys=config.all_obs_keys, verbose=True
255 |     )
256 | 
257 |     model = algo_factory(
258 |         algo_name=config.algo_name,
259 |         config=config,
260 |         obs_key_shapes=shape_meta["all_shapes"],
261 |         ac_dim=shape_meta["ac_dim"],
262 |         device=torch.device("cpu"),  # default to cpu, pl will move to gpu
263 |     )
264 |     model.nets["policy"].kl_loss_weight = config.algo.transformer.kl_loss_weight
265 | 
266 |     if config.experiment.env is not None:
267 |         env_meta["env_name"] = config.experiment.env
268 |         print("=" * 30 + "\n" + "Replacing Env to {}\n".format(env_meta["env_name"]) + "=" * 30)
269 |     # create environment
270 |     envs = OrderedDict()
271 |     if config.experiment.rollout.enabled:
272 |         # create environments for validation runs
273 |         env_names = [env_meta["env_name"]]
274 | 
275 |         if config.experiment.additional_envs is not None:
276 |             for name in config.experiment.additional_envs:
277 |                 env_names.append(name)
278 |         for env_name in env_names:
279 |             env = EnvUtils.create_env_from_metadata(
280 |                 env_meta=env_meta,
281 |                 env_name=env_name,
282 |                 render=False,
283 |                 render_offscreen=config.experiment.render_video,
284 |                 use_image_obs=shape_meta["use_images"],
285 |             )
286 |             if config.train.frame_stack > 1:
287 |                 env = FrameStackWrapper(
288 |                     env,
289 |                     config.train.frame_stack,
290 |                     config.experiment.rollout.horizon,
291 |                     dataset_path=config.train.data,
292 |                     valid_key=config.experiment.rollout.valid_key,
293 |                 )
294 |             else:
295 |                 env = EvaluateOnDatasetWrapper(
296 |                     env,
297 |                     dataset_path=config.train.data,
298 |                     valid_key=config.experiment.rollout.valid_key,
299 |                 )
300 | 
301 |             envs[env.name] = env
302 |             print(envs[env.name])
303 | 
304 |     print("")
305 | 
306 |     # maybe retrieve statistics for normalizing observations
307 |     obs_normalization_stats = None
308 |     seed_everything(config.train.seed, workers=True)
309 |     model = ModelWrapper.load_from_checkpoint(ckpt_path, model=model).cuda()
310 |     model.model.nets = model.nets.cuda()
311 |     model.model.device = torch.device("cuda")
312 |     policy = RolloutPolicy(model.model, obs_normalization_stats=obs_normalization_stats)
313 |     # maybe set seed
314 |     if args.seed is not None:
315 |         np.random.seed(args.seed)
316 |         torch.manual_seed(args.seed)
317 | 
318 |     # maybe create video writer
319 |     video_writer = None
320 |     if write_video:
321 |         video_writer = imageio.get_writer(args.video_path, fps=20)
322 | 
323 |     # maybe open hdf5 to write rollouts
324 |     write_dataset = args.dataset_path is not None
325 |     if write_dataset:
326 |         data_writer = h5py.File(args.dataset_path, "w")
327 |         data_grp = data_writer.create_group("data")
328 |         total_samples = 0
329 | 
330 |     rollout_stats = []
331 |     env.sample_eval_episodes(rollout_num_episodes)
332 |     for i in tqdm(range(rollout_num_episodes)):
333 |         stats, traj = rollout(
334 |             policy=policy,
335 |             env=env,
336 |             horizon=rollout_horizon,
337 |             render=args.render,
338 |             video_writer=video_writer,
339 |             video_skip=args.video_skip,
340 |             return_obs=(write_dataset and args.dataset_obs),
341 |             camera_names=args.camera_names,
342 |         )
343 |         rollout_stats.append(stats)
344 | 
345 |         if write_dataset:
346 |             # store transitions
347 |             ep_data_grp = data_grp.create_group("demo_{}".format(i))
348 |             ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]))
349 |             ep_data_grp.create_dataset("states", data=np.array(traj["states"]))
350 |             ep_data_grp.create_dataset("rewards", data=np.array(traj["rewards"]))
351 |             ep_data_grp.create_dataset("dones", data=np.array(traj["dones"]))
352 |             if args.dataset_obs:
353 |                 for k in traj["obs"]:
354 |                     ep_data_grp.create_dataset("obs/{}".format(k), data=np.array(traj["obs"][k]))
355 |                     ep_data_grp.create_dataset(
356 |                         "next_obs/{}".format(k), data=np.array(traj["next_obs"][k])
357 |                     )
358 | 
359 |             # episode metadata
360 |             if "model" in traj["initial_state_dict"]:
361 |                 ep_data_grp.attrs["model_file"] = traj["initial_state_dict"][
362 |                     "model"
363 |                 ]  # model xml for this episode
364 |             ep_data_grp.attrs["num_samples"] = traj["actions"].shape[
365 |                 0
366 |             ]  # number of transitions in this episode
367 |             total_samples += traj["actions"].shape[0]
368 | 
369 |     rollout_stats = TensorUtils.list_of_flat_dict_to_dict_of_list(rollout_stats)
370 |     avg_rollout_stats = {k: np.mean(rollout_stats[k]) for k in rollout_stats}
371 |     avg_rollout_stats["Num_Success"] = np.sum(rollout_stats["Success_Rate"])
372 |     print("Average Rollout Stats")
373 |     print(json.dumps(avg_rollout_stats, indent=4))
374 | 
375 |     if write_video:
376 |         video_writer.close()
377 | 
378 |     if write_dataset:
379 |         # global metadata
380 |         data_grp.attrs["total"] = total_samples
381 |         data_grp.attrs["env_args"] = json.dumps(env.serialize(), indent=4)  # environment info
382 |         data_writer.close()
383 |         print("Wrote dataset trajectories to {}".format(args.dataset_path))
384 | 
385 | 
386 | if __name__ == "__main__":
387 |     parser = argparse.ArgumentParser()
388 | 
389 |     # Path to trained model
390 |     parser.add_argument(
391 |         "--agent",
392 |         type=str,
393 |         required=True,
394 |         help="path to saved checkpoint ckpt file",
395 |     )
396 | 
397 |     parser.add_argument(
398 |         "--resume_dir",
399 |         type=str,
400 |         default=None,
401 |         help="(optional) path to run directory containing config.json; defaults to the checkpoint's parent run directory",
402 |     )
403 | 
404 |     # number of rollouts
405 |     parser.add_argument(
406 |         "--n_rollouts",
407 |         type=int,
408 |         default=27,
409 |         help="number of rollouts",
410 |     )
411 | 
412 |     # maximum horizon of rollout, to override the one stored in the model checkpoint
413 |     parser.add_argument(
414 |         "--horizon",
415 |         type=int,
416 |         default=None,
417 |         help="(optional) override maximum horizon of rollout from the one in the checkpoint",
418 |     )
419 | 
420 |     # Env Name (to override the one stored in model checkpoint)
421 |     parser.add_argument(
422 |         "--env",
423 |         type=str,
424 |         default=None,
425 |         help="(optional) override name of env from the one in the checkpoint, and use\
426 |             it for rollouts",
427 |     )
428 | 
429 |     # Whether to render rollouts to screen
430 |     parser.add_argument(
431 |         "--render",
432 |         action="store_true",
433 |         help="on-screen rendering",
434 |     )
435 | 
436 |     # Dump a video of the rollouts to the specified path
437 |     parser.add_argument(
438 |         "--video_path",
439 |         type=str,
440 |         default=None,
441 |         help="(optional) render rollouts to this video file path",
442 |     )
443 | 
444 |     # How often to write video frames during the rollout
445 |     parser.add_argument(
446 |         "--video_skip",
447 |         type=int,
448 |         default=5,
449 |         help="render frames to video every n steps",
450 |     )
451 | 
452 |     # camera names to render
453 |     parser.add_argument(
454 |         "--camera_names",
455 |         type=str,
456 |         nargs="+",
457 |         default=["agentview"],
458 |         help="(optional) camera name(s) to use for rendering on-screen or to video",
459 |     )
460 | 
461 |     # If provided, an hdf5 file will be written with the rollout data
462 |     parser.add_argument(
463 |         "--dataset_path",
464 |         type=str,
465 |         default=None,
466 |         help="(optional) if provided, an hdf5 file will be written at this path with the rollout data",
467 |     )
468 | 
469 |     # If True and @dataset_path is supplied, will write possibly high-dimensional observations to dataset.
470 |     parser.add_argument(
471 |         "--dataset_obs",
472 |         action="store_true",
473 |         help="include possibly high-dimensional observations in output dataset hdf5 file (by default,\
474 |             observations are excluded and only simulator states are saved)",
475 |     )
476 | 
477 |     # for seeding before starting rollouts
478 |     parser.add_argument(
479 |         "--seed",
480 |         type=int,
481 |         default=None,
482 |         help="(optional) set seed for rollouts",
483 |     )
484 | 
485 |     parser.add_argument("--data", type=str, default=None, help="(optional) override dataset path from the training config")
486 | 
487 |     args = parser.parse_args()
488 |     run_trained_agent(args)
489 | 
--------------------------------------------------------------------------------