├── .gitattributes
├── __init__.py
├── assets
│   └── optimus_gallery_gif.gif
├── pyproject.toml
├── optimus
│   ├── config
│   │   ├── __init__.py
│   │   ├── base_config.py
│   │   └── bc_config.py
│   ├── algo
│   │   └── __init__.py
│   ├── envs
│   │   ├── box.py
│   │   ├── wrappers.py
│   │   └── env_robosuite.py
│   ├── scripts
│   │   ├── download_datasets.py
│   │   ├── combine_hdf5.py
│   │   ├── filter_trajectories.py
│   │   ├── playback_dataset.py
│   │   └── run_trained_agent_pl.py
│   ├── utils
│   │   ├── file_utils.py
│   │   ├── env_utils.py
│   │   ├── dataset.py
│   │   └── train_utils.py
│   └── exps
│       └── local
│           └── robosuite
│               ├── stack
│               │   └── bc_transformer.json
│               ├── stackfive
│               │   └── bc_transformer.json
│               ├── stackfour
│               │   └── bc_transformer.json
│               └── stackthree
│                   └── bc_transformer.json
├── requirements.txt
├── setup.py
├── .gitignore
├── docker
│   └── mujoco
│       └── Dockerfile
├── LICENSE
└── README.md
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.gif filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
--------------------------------------------------------------------------------
/assets/optimus_gallery_gif.gif:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:33539d2962f145f0da70232080cd2bd143d086ed6630ef9d9046cfaf1ca847fd
3 | size 53889018
4 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.isort]
2 | multi_line_output = 3
3 | include_trailing_comma = true
4 | force_grid_wrap = 0
5 | use_parentheses = true
6 | ensure_newline_before_comments = true
7 | line_length = 100
8 |
9 | [tool.black]
10 | line-length = 100
11 |
--------------------------------------------------------------------------------
/optimus/config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | from robomimic.config.config import Config
6 |
7 | from optimus.config.base_config import config_factory, get_all_registered_configs
8 |
9 | # note: these imports are needed to register these classes in the global config registry
10 | from optimus.config.bc_config import BCConfig
11 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py
2 | psutil
3 | tqdm
4 | termcolor
5 | tensorboard
6 | tensorboardX
7 | imageio
8 | imageio-ffmpeg
9 | egl_probe>=1.0.1
10 | ipdb
11 | wandb
12 |
13 | ipython
14 | patchelf
15 | robosuite==1.4.0
16 | robomimic==0.3.0
17 | jupyterlab
18 | notebook
19 | black
20 | flake8
21 | isort
22 | pytest
23 | protobuf==3.20.0
24 | pytorch_lightning==1.9.5
25 | gdown
26 |
27 | seaborn
28 | mujoco
29 | pygame
30 |
31 | numpy==1.23.5
32 | pyopengl==3.1.6
--------------------------------------------------------------------------------
/optimus/algo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | from robomimic.algo.algo import (
6 | Algo,
7 | HierarchicalAlgo,
8 | PlannerAlgo,
9 | PolicyAlgo,
10 | RolloutPolicy,
11 | ValueAlgo,
12 | algo_factory,
13 | algo_name_to_factory_func,
14 | register_algo_factory_func,
15 | )
16 |
17 | # note: these imports are needed to register these classes in the global algo registry
18 | from robomimic.algo.bc import BC, BC_GMM, BC_VAE, BC_Gaussian
19 | from robomimic.algo.bcq import BCQ, BCQ_GMM, BCQ_Distributional
20 | from robomimic.algo.cql import CQL
21 | from robomimic.algo.gl import GL, GL_VAE, ValuePlanner
22 | from robomimic.algo.hbc import HBC
23 | from robomimic.algo.iris import IRIS
24 | from robomimic.algo.td3_bc import TD3_BC
25 |
26 | from optimus.algo.bc import BC_RNN, BC_RNN_GMM, BC_Transformer, BC_Transformer_GMM
27 |
--------------------------------------------------------------------------------
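As the comment in this file notes, importing the algorithm modules populates robomimic's global algo registry as a side effect. For orientation, a minimal sketch of the registration pattern itself, using a hypothetical algo name (`my_bc_variant` is illustrative, not part of this repo):

```python
from robomimic.algo.algo import register_algo_factory_func
from robomimic.algo.bc import BC


@register_algo_factory_func("my_bc_variant")  # hypothetical algo name
def algo_config_to_class(algo_config):
    """Map an algo config section to (algo class, extra constructor kwargs)."""
    return BC, {}
```
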
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # setup.py
3 | from os import path
4 |
5 | from setuptools import find_packages, setup
6 |
7 | this_directory = path.abspath(path.dirname(__file__))
8 | with open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
9 | lines = f.readlines()
10 |
11 | # remove images from README
12 | lines = [x for x in lines if ".png" not in x]
13 | long_description = "".join(lines)
14 |
15 | setup(
16 | name="optimus",
17 | packages=[package for package in find_packages() if package.startswith("optimus")],
18 | install_requires=[],
19 | eager_resources=["*"],
20 | include_package_data=True,
21 | python_requires=">=3",
22 | description="Official code release for Optimus: Imitating Task and Motion Planning with Visuomotor Transformers",
23 | author="Murtaza Dalal",
24 | url="https://github.com/NVlabs/Optimus.git",
25 | author_email="mdalal@andrew.cmu.edu",
26 | version="0.1.0",
27 | long_description=long_description,
28 | long_description_content_type="text/markdown",
29 | )
30 |
--------------------------------------------------------------------------------
/optimus/config/base_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 | """
5 | The base config class that is used for all algorithm configs in this repository.
6 | Subclasses get registered into a global dictionary, making it easy to instantiate
7 | the correct config class given the algorithm name.
8 | """
9 |
10 | from copy import deepcopy
11 |
12 | import robomimic
13 | import six # preserve metaclass compatibility between python 2 and 3
14 | from robomimic.config import base_config
15 | from robomimic.config.config import Config
16 |
17 | # global dictionary for remembering name - class mappings
18 | REGISTERED_CONFIGS = base_config.REGISTERED_CONFIGS
19 |
20 |
21 | def get_all_registered_configs():
22 | """
23 | Give access to dictionary of all registered configs for external use.
24 | """
25 | return deepcopy(REGISTERED_CONFIGS)
26 |
27 |
28 | def config_factory(algo_name, dic=None):
29 | """
30 | Creates an instance of a config from the algo name. Optionally pass
31 | a dictionary to instantiate the config from the dictionary.
32 | """
33 | if algo_name not in REGISTERED_CONFIGS:
34 | raise Exception(
35 | "Config for algo name {} not found. Make sure it is a registered config among: {}".format(
36 | algo_name, ", ".join(REGISTERED_CONFIGS)
37 | )
38 | )
39 | return REGISTERED_CONFIGS[algo_name](dict_to_load=dic)
40 |
--------------------------------------------------------------------------------
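A short usage sketch for `config_factory`, assuming the `"bc"` name registered by `BCConfig`; the overridden values below are arbitrary examples:

```python
from optimus.config import config_factory

# Build the default config for the registered "bc" algorithm; Config
# objects support attribute-style access, so values can be overridden
# before the config is locked by the training script.
config = config_factory("bc")
config.train.batch_size = 16
config.experiment.name = "bc_transformer_stack_image"
```
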
/.gitignore:
--------------------------------------------------------------------------------
1 | # Mac OSX
2 | .DS_Store
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | datasets/
135 | trained_models/
--------------------------------------------------------------------------------
/optimus/envs/box.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | import numpy as np
6 | from robosuite.models.objects import PrimitiveObject
7 | from robosuite.utils.mjcf_utils import get_size, new_site
8 |
9 |
10 | class BoxObject(PrimitiveObject):
11 | """
12 | A box object.
13 |
14 | Args:
15 | size (3-tuple of float): (half-x, half-y, half-z) size parameters for this box object
16 | """
17 |
18 | def __init__(
19 | self,
20 | name,
21 | size=None,
22 | size_max=None,
23 | size_min=None,
24 | density=None,
25 | friction=None,
26 | rgba=None,
27 | solref=None,
28 | solimp=None,
29 | material=None,
30 | joints="default",
31 | obj_type="all",
32 | duplicate_collision_geoms=True,
33 | ):
34 | size = get_size(size, size_max, size_min, [0.07, 0.07, 0.07], [0.03, 0.03, 0.03])
35 | super().__init__(
36 | name=name,
37 | size=size,
38 | rgba=rgba,
39 | density=density,
40 | friction=friction,
41 | solref=solref,
42 | solimp=solimp,
43 | material=material,
44 | joints=joints,
45 | obj_type=obj_type,
46 | duplicate_collision_geoms=duplicate_collision_geoms,
47 | )
48 |
49 | def sanity_check(self):
50 | """
51 | Checks that the input size has the correct length
52 |
53 | Raises:
54 | AssertionError: [Invalid size length]
55 | """
56 | assert len(self.size) == 3, "box size should have length 3"
57 |
58 | def _get_object_subtree(self):
59 | return self._get_object_subtree_(ob_type="box")
60 |
61 | @property
62 | def bottom_offset(self):
63 | return np.array([0, 0, -1 * self.size[2]])
64 |
65 | @property
66 | def top_offset(self):
67 | return np.array([0, 0, self.size[2]])
68 |
69 | @property
70 | def horizontal_radius(self):
71 | return np.linalg.norm(self.size[0:2], 2)
72 |
73 | def get_bounding_box_size(self):
74 | return np.array([self.size[0], self.size[1], self.size[2]])
75 |
76 |
77 | class BoxObjectWithSites(BoxObject):
78 | """
79 | A box object with sites on the x and y axes.
80 | """
81 |
82 | def _get_object_subtree(self):
83 | tree = self._get_object_subtree_(ob_type="box")
84 | site_element_attr = self.get_site_attrib_template()
85 |
86 | delta = self.size[0] / 2
87 | site_element_attr["pos"] = f"{delta} 0 0"
88 | site_element_attr["name"] = "x-site"
89 | tree.append(new_site(**site_element_attr))
90 |
91 | delta = self.size[1] / 2
92 | site_element_attr["pos"] = f"0 {delta} 0"
93 | site_element_attr["name"] = "y-site"
94 | tree.append(new_site(**site_element_attr))
95 | return tree
96 |
--------------------------------------------------------------------------------
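A small construction sketch, assuming robosuite is installed (the names, sizes, and color are illustrative):

```python
from optimus.envs.box import BoxObject, BoxObjectWithSites

# Fixed half-extents (x, y, z); alternatively pass size_min / size_max to
# sample each half-extent uniformly at random between the two bounds.
box = BoxObject(name="cube", size=[0.02, 0.02, 0.02], rgba=[1, 0, 0, 1])
print(box.horizontal_radius)  # norm of the (x, y) half-extents
print(box.top_offset)         # [0, 0, +half_z]

# Same box, plus "x-site" and "y-site" markers on its axes.
sited_box = BoxObjectWithSites(name="marked_cube", size=[0.02, 0.02, 0.02])
```
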
/optimus/scripts/download_datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to download datasets packaged with the repository.
3 | """
4 | import os
5 | import argparse
6 |
7 | import optimus
8 | import optimus.utils.file_utils as FileUtils
9 | dataset_links = dict(
10 | Stack="https://drive.google.com/file/d/1ciP5yP0D06gT7QXq7OvxlR-mcepIoHCK/view?usp=drive_link",
11 | StackThree="https://drive.google.com/file/d/1_Qo0jYPfepe4rUpyDrtbmI9WHPipRwmL/view?usp=drive_link",
12 | StackFour="https://drive.google.com/file/d/142XBL1Hy2Ru1qGNbv4JlBjHm-Y0q_is2/view?usp=drive_link",
13 | StackFive="https://drive.google.com/file/d/1z7CFsBEXkj7DxLSpfUteZjPuN2quwBtA/view?usp=drive_link"
14 | )
15 |
16 |
17 | if __name__ == "__main__":
18 | parser = argparse.ArgumentParser()
19 |
20 | # directory to download datasets to
21 | parser.add_argument(
22 | "--download_dir",
23 | type=str,
24 | default=None,
25 | help="Base download directory. Created if it doesn't exist. Defaults to datasets folder in repository.",
26 | )
27 |
28 | # tasks to download datasets for
29 | parser.add_argument(
30 | "--tasks",
31 | type=str,
32 | nargs='+',
33 | default=["Stack"],
34 | help="Tasks to download datasets for. Defaults to square_d0 task. Pass 'all' to download all tasks\
35 | for the provided dataset type or directly specify the list of tasks.",
36 | )
37 |
38 | # dry run - don't actually download datasets, but print which datasets would be downloaded
39 | parser.add_argument(
40 | "--dry_run",
41 | action='store_true',
42 | help="set this flag to do a dry run to only print which datasets would be downloaded"
43 | )
44 |
45 | args = parser.parse_args()
46 |
47 | # set default base directory for downloads
48 | default_base_dir = args.download_dir
49 | if default_base_dir is None:
50 | default_base_dir = os.path.join(optimus.__path__[0], "../datasets")
51 |
52 | # load args
53 | download_tasks = args.tasks
54 | if "all" in download_tasks:
55 | assert len(download_tasks) == 1, "'all' should be the only --tasks argument, but got: {}".format(args.tasks)
56 | download_tasks = list(dataset_links.keys())
57 | else:
58 | for task in download_tasks:
59 | assert task in dataset_links, "got unknown task {}. Choose one of {}".format(task, list(dataset_links.keys()))
60 |
61 | # download requested datasets
62 | for task in download_tasks:
63 | download_dir = os.path.abspath(default_base_dir)
64 | download_path = os.path.join(download_dir, "{}.hdf5".format(task))
65 | print("\nDownloading dataset:\n task: {}\n download path: {}"
66 | .format(task, download_path))
67 | url = dataset_links[task]
68 | if args.dry_run:
69 | print("\ndry run: skip download")
70 | else:
71 | # Make sure path exists and create if it doesn't
72 | os.makedirs(download_dir, exist_ok=True)
73 | print("")
74 | FileUtils.download_url_from_gdrive(
75 | url=url,
76 | download_dir=download_dir,
77 | check_overwrite=True,
78 | )
79 | print("")
--------------------------------------------------------------------------------
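The same download can also be done programmatically; a minimal sketch (the download directory is an arbitrary example path):

```python
import os

import optimus.utils.file_utils as FileUtils
from optimus.scripts.download_datasets import dataset_links

download_dir = os.path.expanduser("~/optimus_datasets")  # example path
os.makedirs(download_dir, exist_ok=True)
FileUtils.download_url_from_gdrive(
    url=dataset_links["Stack"],
    download_dir=download_dir,
    check_overwrite=True,
)
```
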
/docker/mujoco/Dockerfile:
--------------------------------------------------------------------------------
1 | # FROM nvcr.io/nvidia/pytorch:21.08-py3
2 | FROM nvcr.io/nvidia/cudagl:11.3.0-devel-ubuntu20.04
3 | ENV NVIDIA_DRIVER_CAPABILITIES compute,graphics,utility
4 | # env variables for tzdata install
5 | ARG DEBIAN_FRONTEND=noninteractive
6 | ENV TZ=America
7 | ENV LD_LIBRARY_PATH "$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin"
8 | ENV LD_LIBRARY_PATH "$LD_LIBRARY_PATH:/usr/lib/nvidia"
9 | ENV PYTHONPATH ${PYTHONPATH}:/workspace
10 | ENV MUJOCO_GL 'egl'
11 | ENV PATH "/usr/local/cuda-new/bin:$PATH"
12 | ENV PIP_CONFIG_FILE pip.conf
13 | ENV PYTHONPATH ${PYTHONPATH}:/home/robosuite
14 | ENV NVIDIA_DRIVER_CAPABILITIES compute,graphics,utility
15 |
16 | # NOTE: each RUN command creates a new image layer, so minimize the number of run commands if possible
17 | # installing other nice functionalities and system packages required by e.g. robosuite
18 | RUN apt-get update -y && \
19 | apt-get install -y \
20 | htop screen tmux \
21 | sshfs libosmesa6-dev wget curl git \
22 | libeigen3-dev \
23 | liborocos-kdl-dev \
24 | libkdl-parser-dev \
25 | liburdfdom-dev \
26 | libnlopt-dev \
27 | libnlopt-cxx-dev \
28 | swig \
29 | python3 \
30 | python3-pip \
31 | python3-dev \
32 | vim \
33 | git-lfs \
34 | cmake \
35 | software-properties-common \
36 | libxcursor-dev \
37 | libxrandr-dev \
38 | libxinerama-dev \
39 | libxi-dev \
40 | mesa-common-dev \
41 | zip \
42 | unzip \
43 | make \
44 | g++ \
45 | python2.7 \
46 | wget \
47 | vulkan-utils \
48 | mesa-vulkan-drivers \
49 | apt nano rsync \
50 | libgl1-mesa-dev libgl1-mesa-glx libglew-dev libosmesa6-dev \
51 | software-properties-common net-tools unzip virtualenv \
52 | xpra xserver-xorg-dev libglfw3-dev patchelf python3-pip -y \
53 | && add-apt-repository -y ppa:openscad/releases && apt-get update && apt-get install -y openscad
54 |
55 | # install mujoco
56 | RUN mkdir /root/.mujoco/ \
57 | && wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz \
58 | && tar -xvf mujoco210-linux-x86_64.tar.gz \
59 | && mv mujoco210 /root/.mujoco/ \
60 | && rm mujoco210-linux-x86_64.tar.gz
61 |
62 | # robomimic dependencies and installing other packages
63 | RUN python3 -m pip install h5py psutil tqdm termcolor tensorboard tensorboardX imageio imageio-ffmpeg "egl_probe>=1.0.1" ipdb wandb \
64 | && python3 -m pip install ipython patchelf robosuite jupyterlab notebook black flake8 isort pytest protobuf==3.20.1 pytorch_lightning \
65 | && python3 -m pip install seaborn mujoco pygame signatory==1.2.6.1.9.0 pyopengl==3.1.6 vit-pytorch stable_baselines3 \
66 | && pip uninstall torch torchvision torchaudio -y \
67 | && python3 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
68 |
69 | RUN wget https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda_11.3.0_465.19.01_linux.run --no-check-certificate
70 | RUN sh cuda_11.3.0_465.19.01_linux.run --toolkit --silent --toolkitpath=/usr/local/cuda-new
71 | ENV TORCH_CUDA_ARCH_LIST "7.0 7.5 8.0"
72 | RUN git clone https://github.com/NVIDIA/apex /home/apex && cd /home/apex \
73 | && python3 -m pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
74 |
75 | RUN python3 -m pip install gym \
76 | && python3 -m pip uninstall opencv-python -y && python3 -m pip install opencv-python-headless
77 |
78 | RUN python3 -m pip install trimesh==3.12.6 pyopengl==3.1.6
79 |
80 | RUN python3 -m pip install numpy==1.23.5
81 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | NVIDIA License
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means any person or entity that distributes its Work.
6 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
7 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
8 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
9 |
10 | 2. License Grant
11 |
12 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
13 |
14 | 3. Limitations
15 |
16 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
17 |
18 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
19 |
20 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
21 |
22 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
23 |
24 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
25 |
26 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
27 |
28 | 4. Disclaimer of Warranty.
29 |
30 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
31 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
32 |
33 | 5. Limitation of Liability.
34 |
35 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
--------------------------------------------------------------------------------
/optimus/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | """
6 | A collection of utility functions for working with files, such as reading metadata from
7 | demonstration datasets, loading model checkpoints, or downloading dataset files.
8 | """
9 | import os
10 | from collections import OrderedDict
11 |
12 | import h5py
13 | import robomimic.utils.obs_utils as ObsUtils
14 | from robomimic.utils.file_utils import *
15 |
16 | from optimus.algo import algo_factory
17 |
18 | import shutil
19 | import tempfile
20 |
21 | import gdown
22 |
23 | def download_url_from_gdrive(url, download_dir, check_overwrite=True):
24 | """
25 | Downloads a file at a URL from Google Drive.
26 |
27 | Example usage:
28 | url = "https://drive.google.com/file/d/1DABdqnBri6-l9UitjQV53uOq_84Dx7Xt/view?usp=drive_link"
29 | download_dir = "/tmp"
30 | download_url_from_gdrive(url, download_dir, check_overwrite=True)
31 |
32 | Args:
33 | url (str): url string
34 | download_dir (str): path to directory where file should be downloaded
35 | check_overwrite (bool): if True, will sanity check the download fpath to make sure a file of that name
36 | doesn't already exist there
37 | """
38 | assert url_is_alive(url), "@download_url_from_gdrive got unreachable url: {}".format(url)
39 |
40 | with tempfile.TemporaryDirectory() as td:
41 | # HACK: Change directory to temp dir, download file there, and then move the file to desired directory.
42 | # We do this because we do not know the name of the file beforehand.
43 | cur_dir = os.getcwd()
44 | os.chdir(td)
45 | fpath = gdown.download(url, quiet=False, fuzzy=True)
46 | fname = os.path.basename(fpath)
47 | file_to_write = os.path.join(download_dir, fname)
48 | if check_overwrite and os.path.exists(file_to_write):
49 | user_response = input(f"Warning: file {file_to_write} already exists. Overwrite? y/n\n")
50 | assert user_response.lower() in {"yes", "y"}, f"Did not receive confirmation. Aborting download."
51 | shutil.move(fpath, file_to_write)
52 | os.chdir(cur_dir)
53 |
54 |
55 | def policy_from_checkpoint(device=None, ckpt_path=None, ckpt_dict=None, verbose=False):
56 | """
57 | This function restores a trained policy from a checkpoint file or
58 | loaded model dictionary.
59 |
60 | Args:
61 | device (torch.device): if provided, put model on this device
62 |
63 | ckpt_path (str): Path to checkpoint file. Only needed if not providing @ckpt_dict.
64 |
65 | ckpt_dict(dict): Loaded model checkpoint dictionary. Only needed if not providing @ckpt_path.
66 |
67 | verbose (bool): if True, include print statements
68 |
69 | Returns:
70 | model (RolloutPolicy): instance of Algo that has the saved weights from
71 | the checkpoint file, and also acts as a policy that can easily
72 | interact with an environment in a training loop
73 |
74 | ckpt_dict (dict): loaded checkpoint dictionary (convenient to avoid
75 | re-loading checkpoint from disk multiple times)
76 | """
77 | ckpt_dict = maybe_dict_from_checkpoint(ckpt_path=ckpt_path, ckpt_dict=ckpt_dict)
78 |
79 | # algo name and config from model dict
80 | algo_name, _ = algo_name_from_checkpoint(ckpt_dict=ckpt_dict)
81 | config, _ = config_from_checkpoint(algo_name=algo_name, ckpt_dict=ckpt_dict, verbose=verbose)
82 |
83 | # read config to set up metadata for observation modalities (e.g. detecting rgb observations)
84 | ObsUtils.initialize_obs_utils_with_config(config)
85 |
86 | # env meta from model dict to get info needed to create model
87 | env_meta = ckpt_dict["env_metadata"]
88 | shape_meta = ckpt_dict["shape_metadata"]
89 |
90 | # maybe restore observation normalization stats
91 | obs_normalization_stats = ckpt_dict.get("obs_normalization_stats", None)
92 | if obs_normalization_stats is not None:
93 | assert config.train.hdf5_normalize_obs
94 | for m in obs_normalization_stats:
95 | for k in obs_normalization_stats[m]:
96 | obs_normalization_stats[m][k] = np.array(obs_normalization_stats[m][k])
97 |
98 | if device is None:
99 | # get torch device
100 | device = TorchUtils.get_torch_device(try_to_use_cuda=config.train.cuda)
101 |
102 | # create model and load weights
103 | model = algo_factory(
104 | algo_name,
105 | config,
106 | obs_key_shapes=shape_meta["all_shapes"],
107 | ac_dim=shape_meta["ac_dim"],
108 | device=device,
109 | )
110 | model.deserialize(ckpt_dict["model"])
111 | model.set_eval()
112 | model = RolloutPolicy(model, obs_normalization_stats=obs_normalization_stats)
113 | if verbose:
114 | print("============= Loaded Policy =============")
115 | print(model)
116 | return model, ckpt_dict
117 |
118 |
119 | def get_shape_metadata_from_dataset(dataset_path, all_obs_keys=None, verbose=False):
120 | """
121 | Retrieves shape metadata from dataset.
122 |
123 | Args:
124 | dataset_path (str): path to dataset
125 | all_obs_keys (list): list of all modalities used by the model. If not provided, all modalities
126 | present in the file are used.
127 | verbose (bool): if True, include print statements
128 |
129 | Returns:
130 | shape_meta (dict): shape metadata. Contains the following keys:
131 |
132 | :`'ac_dim'`: action space dimension
133 | :`'all_shapes'`: dictionary that maps observation key string to shape
134 | :`'all_obs_keys'`: list of all observation modalities used
135 | :`'use_images'`: bool, whether or not image modalities are present
136 | """
137 |
138 | shape_meta = {}
139 |
140 | # read demo file for some metadata
141 | dataset_path = os.path.expanduser(dataset_path)
142 | f = h5py.File(dataset_path, "r")
143 | demo_id = list(f["data"].keys())[0]
144 | demo = f["data/{}".format(demo_id)]
145 |
146 | # action dimension
147 | shape_meta["ac_dim"] = f["data/{}/actions".format(demo_id)].shape[1]
148 |
149 | # observation dimensions
150 | all_shapes = OrderedDict()
151 |
152 | if all_obs_keys is None:
153 | # use all modalities present in the file
154 | all_obs_keys = [k for k in demo["obs"]]
155 |
156 | for k in sorted(all_obs_keys):
157 | if k == "timesteps":
158 | initial_shape = (1,)
159 | else:
160 | initial_shape = demo["obs/{}".format(k)].shape[1:]
161 | if verbose:
162 | print("obs key {} with shape {}".format(k, initial_shape))
163 | # Store processed shape for each obs key
164 | all_shapes[k] = ObsUtils.get_processed_shape(
165 | obs_modality=ObsUtils.OBS_KEYS_TO_MODALITIES[k],
166 | input_shape=initial_shape,
167 | )
168 |
169 | f.close()
170 |
171 | shape_meta["all_shapes"] = all_shapes
172 | shape_meta["all_obs_keys"] = all_obs_keys
173 | shape_meta["use_images"] = ObsUtils.has_modality("rgb", all_obs_keys)
174 |
175 | return shape_meta
176 |
--------------------------------------------------------------------------------
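A rollout sketch built on `policy_from_checkpoint` (the checkpoint path is a placeholder; `env_from_checkpoint` comes from robomimic via the star import above):

```python
import optimus.utils.file_utils as FileUtils

# Restore the policy and rebuild its environment from checkpoint metadata.
policy, ckpt_dict = FileUtils.policy_from_checkpoint(ckpt_path="/path/to/model.pth")
env, _ = FileUtils.env_from_checkpoint(ckpt_dict=ckpt_dict)

policy.start_episode()
obs = env.reset()
for _ in range(400):
    action = policy(ob=obs)  # RolloutPolicy preprocesses obs and queries the model
    obs, reward, done, _ = env.step(action)
    if done:
        break
```
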
/optimus/scripts/combine_hdf5.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 | import traceback
5 |
6 | import h5py
7 | import numpy as np
8 | import robomimic.utils.file_utils as FileUtils
9 | from tqdm import tqdm
10 |
11 | """
12 | Script for combining multiple hdf5 files into a single hdf5 file.
13 | """
14 |
15 | def write_trajectory_to_dataset(
16 | env, traj, data_grp, demo_name, save_next_obs=False, env_type="mujoco"
17 | ):
18 | """
19 | Write the collected trajectory to hdf5 compatible with robomimic.
20 | """
21 |
22 | # create group for this trajectory
23 | ep_data_grp = data_grp.create_group(demo_name)
24 | ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]), compression="gzip")
25 |
26 | if env_type == "mujoco":
27 | data = np.array(traj["states"])
28 | ep_data_grp.create_dataset("states", data=data)
29 | if "obs" in traj:
30 | for k in traj["obs"]:
31 | ep_data_grp.create_dataset(
32 | "obs/{}".format(k), data=np.array(traj["obs"][k]), compression="gzip"
33 | )
34 | if save_next_obs:
35 | ep_data_grp.create_dataset(
36 | "next_obs/{}".format(k),
37 | data=np.array(traj["next_obs"][k]),
38 | compression="gzip",
39 | )
40 |
41 | # episode metadata
42 | ep_data_grp.attrs["num_samples"] = traj["attrs"][
43 | "num_samples"
44 | ] # number of transitions in this episode
45 | if "model_file" in traj:
46 | ep_data_grp.attrs["model_file"] = traj["model_file"]
47 | if "init_string" in traj:
48 | ep_data_grp.attrs["init_string"] = traj["init_string"]
49 | if "goal_parts_string" in traj:
50 | ep_data_grp.attrs["goal_parts_string"] = traj["goal_parts_string"]
51 | return traj["actions"].shape[0]
52 |
53 |
54 | def load_demo_info(hdf5_file, filter_key=None):
55 | """
56 | Args:
57 | filter_key (str): if provided, use this filter key to select
58 | a subset of demonstration trajectories to load
59 |
60 | Returns:
61 | demos (list): sorted list of demonstration keys in the file (or
62 | under the given filter key)
63 | """
64 | if filter_key is not None:
65 | print("using filter key: {}".format(filter_key))
66 | demos = [elem.decode("utf-8") for elem in np.array(hdf5_file["mask/{}".format(filter_key)])]
67 | else:
68 | demos = list(hdf5_file["data"].keys())
69 |
70 | # sort demo keys
71 | inds = np.argsort([int(elem[5:]) for elem in demos])
72 | demos = [demos[i] for i in inds]
73 |
74 | return demos
75 |
76 |
77 | def load_dataset_in_memory(
78 | demo_list,
79 | hdf5_file,
80 | dataset_keys,
81 | data_grp,
82 | demo_count=0,
83 | total_samples=0,
84 | env_type="mujoco",
85 | ):
86 | """
87 | Loads the hdf5 dataset into memory, preserving the structure of the file. Note that this
88 | differs from `self.getitem_cache`, which, if active, actually caches the outputs of the
89 | `getitem` operation.
90 |
91 | Args:
92 | demo_list (list): list of demo keys, e.g., 'demo_0'
93 | hdf5_file (h5py.File): file handle to the hdf5 dataset.
94 | obs_keys (list, tuple): observation keys to fetch, e.g., 'images'
95 | dataset_keys (list, tuple): dataset keys to fetch, e.g., 'actions'
96 | load_next_obs (bool): whether to load next_obs from the dataset
97 |
98 | Returns:
99 | all_data (dict): dictionary of loaded data.
100 | """
101 | for ep in tqdm(demo_list):
102 | demo_name = f"demo_{demo_count}"
103 | traj = {}
104 | traj["attrs"] = {}
105 | traj["attrs"]["num_samples"] = hdf5_file["data/{}".format(ep)].attrs["num_samples"]
106 | # get obs
107 | traj["obs"] = {
108 | k: hdf5_file["data/{}/obs/{}".format(ep, k)][()]
109 | for k in hdf5_file["data/{}/obs".format(ep)]
110 | }
111 | # get other dataset keys
112 | for k in dataset_keys:
113 | if env_type == "mujoco":
114 | traj[k] = hdf5_file["data/{}/{}".format(ep, k)][
115 | ()
116 | ] # NOTE: do not cast to float this breaks action playback!
117 | try:
118 | traj["model_file"] = hdf5_file["data/{}".format(ep)].attrs["model_file"]
119 | except KeyError:
120 | pass
121 |
122 | try:
123 | traj["init_string"] = hdf5_file["data/{}".format(ep)].attrs["init_string"]
124 | traj["goal_parts_string"] = hdf5_file["data/{}".format(ep)].attrs["goal_parts_string"]
125 | except KeyError:
126 | pass
127 | write_trajectory_to_dataset(None, traj, data_grp, demo_name=demo_name, env_type=env_type)
128 | demo_count += 1
129 | total_samples += traj["attrs"]["num_samples"]
130 | env_args = hdf5_file["data/"].attrs["env_args"]
131 | return env_args, demo_count, total_samples
132 |
133 |
134 | def global_dataset_updates(data_grp, total_samples, env_args):
135 | """
136 | Update the global dataset attributes.
137 | """
138 | data_grp.attrs["total_samples"] = total_samples
139 | data_grp.attrs["env_args"] = env_args
140 | return data_grp
141 |
142 |
143 | def combine_hdf5(hdf5_paths, hdf5_use_swmr, dataset_path, filter_key):
144 | data_writer = h5py.File(dataset_path, "w")
145 | data_grp = data_writer.create_group("data")
146 | dataset_keys = ["actions", "states"]
147 | demo_count = 0
148 | total_samples = 0
149 | for hdf5_path in tqdm(hdf5_paths):
150 | try:
151 | hdf5_file = h5py.File(hdf5_path, "r", swmr=hdf5_use_swmr, libver="latest")
152 | demo_list = load_demo_info(hdf5_file, filter_key=filter_key)
153 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=hdf5_path)
154 | env_type = "mujoco"
155 | env_args, demo_count, total_samples = load_dataset_in_memory(
156 | demo_list,
157 | hdf5_file,
158 | dataset_keys,
159 | data_grp=data_grp,
160 | total_samples=total_samples,
161 | demo_count=demo_count,
162 | env_type=env_type,
163 | )
164 | # this won't work when combining many files (just re-generate filter keys)
165 | # if "mask" in hdf5_file:
166 | # hdf5_file.copy("mask", data_writer)
167 | print("loaded: ", hdf5_path)
168 | except Exception:
169 | print("failed to load: ", hdf5_path)
170 | print(traceback.format_exc())
171 | pass
172 | global_dataset_updates(data_grp, total_samples, env_args)
173 | data_writer.close()
174 |
175 |
176 | if __name__ == "__main__":
177 | import argparse
178 |
179 | parser = argparse.ArgumentParser()
180 | parser.add_argument("--hdf5_paths", nargs="+")
181 | parser.add_argument("--output_path")
182 | parser.add_argument("--filter_key", type=str, default=None)
183 |
184 | args = parser.parse_args()
185 | combine_hdf5(args.hdf5_paths, True, args.output_path, args.filter_key)
186 |
--------------------------------------------------------------------------------
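Example programmatic usage (file names are placeholders); the keyword arguments match the positional call in the `__main__` block above:

```python
from optimus.scripts.combine_hdf5 import combine_hdf5

# Merge two source files into one output file; demos are renumbered
# demo_0, demo_1, ... across the inputs.
combine_hdf5(
    hdf5_paths=["run1.hdf5", "run2.hdf5"],
    hdf5_use_swmr=True,
    dataset_path="combined.hdf5",
    filter_key=None,
)
```
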
/optimus/scripts/filter_trajectories.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 | import traceback
5 | from collections import Counter
6 |
7 | import h5py
8 | import numpy as np
9 | from robomimic.utils.file_utils import create_hdf5_filter_key
10 | from tqdm import tqdm
11 |
12 |
13 | def write_trajectory_to_dataset(
14 | env, traj, data_grp, demo_name, save_next_obs=False, env_type="mujoco"
15 | ):
16 | """
17 | Write the collected trajectory to hdf5 compatible with robomimic.
18 | """
19 |
20 | # create group for this trajectory
21 | ep_data_grp = data_grp.create_group(demo_name)
22 | ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]), compression="gzip")
23 |
24 | data = np.array(traj["states"])
25 | ep_data_grp.create_dataset("states", data=data)
26 | if "obs" in traj:
27 | for k in traj["obs"]:
28 | ep_data_grp.create_dataset(
29 | "obs/{}".format(k), data=np.array(traj["obs"][k]), compression="gzip"
30 | )
31 | if save_next_obs:
32 | ep_data_grp.create_dataset(
33 | "next_obs/{}".format(k),
34 | data=np.array(traj["next_obs"][k]),
35 | compression="gzip",
36 | )
37 |
38 | # episode metadata
39 | ep_data_grp.attrs["num_samples"] = traj["attrs"][
40 | "num_samples"
41 | ] # number of transitions in this episode
42 | if "model_file" in traj:
43 | ep_data_grp.attrs["model_file"] = traj["model_file"]
44 | if "init_string" in traj:
45 | ep_data_grp.attrs["init_string"] = traj["init_string"]
46 | if "goal_parts_string" in traj:
47 | ep_data_grp.attrs["goal_parts_string"] = traj["goal_parts_string"]
48 | return traj["actions"].shape[0]
49 |
50 |
51 | def load_demo_info(hdf5_file):
52 | """
53 | Args:
54 | filter_by_attribute (str): if provided, use the provided filter key
55 | to select a subset of demonstval_ration trajectories to load
56 |
57 | demos (list): list of demonstval_ration keys to load from the hdf5 file. If
58 | omitted, all demos in the file (or under the @filter_by_attribute
59 | filter key) are used.
60 | """
61 | demos = list(hdf5_file["data"].keys())
62 |
63 | # sort demo keys
64 | inds = np.argsort([int(elem[5:]) for elem in demos])
65 | demos = [demos[i] for i in inds]
66 |
67 | return demos
68 |
69 |
70 | def compute_traj_length_stats(hdf5_file, demo_list):
71 | traj_lengths = []
72 | for ep in demo_list:
73 | traj_lengths.append(hdf5_file["data"][ep].attrs["num_samples"])
74 | return np.mean(traj_lengths), np.std(traj_lengths)
75 |
76 |
77 | def filter_hdf5(
78 | hdf5_paths,
79 | hdf5_use_swmr,
80 | outlier_traj_length_sd,
81 | x_bounds,
82 | y_bounds,
83 | z_bounds,
84 | val_ratio,
85 | filter_key_prefix,
86 | eef_pos_key,
87 | ):
88 | traj_lengths = []
89 | outlier_files = Counter()
90 | out_of_bb_files = Counter()
91 | num_outlier_trajectories = 0
92 | num_out_of_bb_trajectories = 0
93 | num_trajectories = 0
94 | for hdf5_path in tqdm(hdf5_paths):
95 | try:
96 | hdf5_file = h5py.File(hdf5_path, "r+", swmr=hdf5_use_swmr, libver="latest")
97 | demo_list = load_demo_info(hdf5_file)
98 | traj_length_mean, traj_length_std = compute_traj_length_stats(hdf5_file, demo_list)
99 | filtered_demos = []
100 | for ep in demo_list:
101 | eef_pos = hdf5_file["data/{}/obs/{}".format(ep, eef_pos_key)][()]
102 | traj_length = hdf5_file["data"][ep].attrs["num_samples"]
103 | if traj_length > traj_length_mean + outlier_traj_length_sd * traj_length_std:
104 | too_long_traj = True
105 | else:
106 | too_long_traj = False
107 | if x_bounds is not None and y_bounds is not None and z_bounds is not None:
108 | out_of_bb = (
109 | min(eef_pos[:, 0]) < x_bounds[0]
110 | or max(eef_pos[:, 0]) > x_bounds[1]
111 | or min(eef_pos[:, 1]) < y_bounds[0]
112 | or max(eef_pos[:, 1]) > y_bounds[1]
113 | or min(eef_pos[:, 2]) < z_bounds[0]
114 | or max(eef_pos[:, 2]) > z_bounds[1]
115 | )
116 | else:
117 | out_of_bb = False
118 | if out_of_bb:
119 | out_of_bb_files[hdf5_path] += 1
120 | num_out_of_bb_trajectories += 1
121 | elif too_long_traj:
122 | outlier_files[hdf5_path] += 1
123 | num_outlier_trajectories += 1
124 | else:
125 | filtered_demos.append(ep)
126 | traj_lengths.append(hdf5_file["data"][ep].attrs["num_samples"])
127 | num_trajectories += len(demo_list)
128 | hdf5_file.close()
129 |
130 | num_demos = len(filtered_demos)
131 | # split the filtered demos into train / validation sets
132 | num_val = int(val_ratio * num_demos)
133 | mask = np.zeros(num_demos)
134 | mask[:num_val] = 1.0
135 | np.random.shuffle(mask)
136 | mask = mask.astype(int)
137 | train_inds = (1 - mask).nonzero()[0]
138 | valid_inds = mask.nonzero()[0]
139 | train_keys = [filtered_demos[i] for i in train_inds]
140 | valid_keys = [filtered_demos[i] for i in valid_inds]
141 | for key in train_keys:
142 | assert key not in valid_keys
143 | print(
144 | "{} validation demonstval_rations out of {} total demonstval_rations.".format(
145 | num_val, num_demos
146 | )
147 | )
148 |
149 | # pass mask to generate split
150 | name_1 = f"train{filter_key_prefix}"
151 | name_2 = f"valid{filter_key_prefix}"
152 |
153 | train_lengths = create_hdf5_filter_key(
154 | hdf5_path=hdf5_path, demo_keys=train_keys, key_name=name_1
155 | )
156 | valid_lengths = create_hdf5_filter_key(
157 | hdf5_path=hdf5_path, demo_keys=valid_keys, key_name=name_2
158 | )
159 | all_valid_lengths = create_hdf5_filter_key(
160 | hdf5_path=hdf5_path, demo_keys=filtered_demos, key_name="all_valid"
161 | )
162 |
163 | print("Total number of train samples: {}".format(np.sum(train_lengths)))
164 | print("Average number of train samples {}".format(np.mean(train_lengths)))
165 |
166 | print("Total number of valid samples: {}".format(np.sum(valid_lengths)))
167 | print("Average number of valid samples {}".format(np.mean(valid_lengths)))
168 |
169 | print("Total number of all valid samples: {}".format(np.sum(all_valid_lengths)))
170 | print("Average number of all valid samples {}".format(np.mean(all_valid_lengths)))
171 | except Exception:
172 | print("failed to load: ", hdf5_path)
173 | print(traceback.format_exc())
174 | pass
175 | print(f"Num files with > mean + 2*sd samples: ", len(outlier_files))
176 | print(
177 | f"Num trajectories with > mean + 2*sd samples: ",
178 | num_outlier_trajectories,
179 | )
180 | print(
181 | "percentage of outlier trajectories: ",
182 | num_outlier_trajectories / num_trajectories * 100,
183 | )
184 | print()
185 | print(f"Num files with out of bb: ", len(out_of_bb_files))
186 | print(f"Num trajectories with out of bb: ", num_out_of_bb_trajectories)
187 | print(
188 | "percentage of out of bb trajectories: ",
189 | num_out_of_bb_trajectories / num_trajectories * 100,
190 | )
191 |
192 | # print trajectory stats:
193 | print("min traj length: ", np.min(traj_lengths))
194 | print("max traj length: ", np.max(traj_lengths))
195 | print("mean traj length: ", np.mean(traj_lengths))
196 | print("median traj length: ", np.median(traj_lengths))
197 | print("std traj length: ", np.std(traj_lengths))
198 |
199 |
200 | if __name__ == "__main__":
201 | import argparse
202 |
203 | parser = argparse.ArgumentParser()
204 | parser.add_argument("--hdf5_paths", nargs="+")
205 | parser.add_argument("--outlier_traj_length_sd", type=int, default=None)
206 | parser.add_argument("--x_bounds", nargs=2, type=float, default=None)
207 | parser.add_argument("--y_bounds", nargs=2, type=float, default=None)
208 | parser.add_argument("--z_bounds", nargs=2, type=float, default=None)
209 | parser.add_argument("--val_ratio", type=float, default=0.1)
210 | parser.add_argument("--filter_key_prefix", type=str, default="")
211 | parser.add_argument("--eef_pos_key", type=str, default="robot0_eef_pos")
212 |
213 | args = parser.parse_args()
214 | filter_hdf5(
215 | args.hdf5_paths,
216 | True,
217 | args.outlier_traj_length_sd,
218 | args.x_bounds,
219 | args.y_bounds,
220 | args.z_bounds,
221 | val_ratio=args.val_ratio,
222 | filter_key_prefix=args.filter_key_prefix,
223 | eef_pos_key=args.eef_pos_key,
224 | )
225 |
--------------------------------------------------------------------------------
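The length criterion above reduces to a single comparison per trajectory; a standalone sketch (`k` plays the role of `--outlier_traj_length_sd`):

```python
import numpy as np


def is_length_outlier(traj_lengths, k):
    """Flag trajectories whose length exceeds mean + k * std."""
    lengths = np.asarray(traj_lengths)
    return lengths > lengths.mean() + k * lengths.std()


# Nine nominal-length demos and one long outlier; only the last is flagged.
print(is_length_outlier([100] * 9 + [400], k=2))
# -> [False False False False False False False False False  True]
```
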
/README.md:
--------------------------------------------------------------------------------
1 | # OPTIMUS: Imitating Task and Motion Planning with Visuomotor Transformers
2 |
3 |
4 | This repository is the official implementation of Imitating Task and Motion Planning with Visuomotor Transformers.
5 |
6 | [Murtaza Dalal](https://mihdalal.github.io/)$^1$, [Ajay Mandlekar](https://ai.stanford.edu/~amandlek/)$^2$, [Caelan Garrett](http://web.mit.edu/caelan/www/)$^2$, [Ankur Handa](https://ankurhanda.github.io/)$^2$, [Ruslan Salakhutdinov](https://www.cs.cmu.edu/~rsalakhu/)$^1$, [Dieter Fox](https://homes.cs.washington.edu/~fox/)$^2$
7 |
8 | $^1$ CMU, $^2$ NVIDIA
9 |
10 | [Project Page](https://mihdalal.github.io/optimus/) | [Arxiv](https://arxiv.org/abs/2305.16309) | [Video](https://www.youtube.com/watch?v=2ItlsuNWi6Y)
11 |
12 |
13 | ![optimus gallery](assets/optimus_gallery_gif.gif)
14 |
15 | Optimus is a framework for training large-scale imitation policies for robotic manipulation by distilling Task and Motion Planning into visuomotor Transformers. In this release, we include datasets for replicating our results on Robosuite, as well as code for performing TAMP data filtration and for training and evaluating visuomotor Transformers on TAMP data.
16 |
17 | If you find this codebase useful in your research, please cite:
18 | ```bibtex
19 | @inproceedings{dalal2023optimus,
20 | title={Imitating Task and Motion Planning with Visuomotor Transformers},
21 | author={Dalal, Murtaza and Mandlekar, Ajay and Garrett, Caelan and Handa, Ankur and Salakhutdinov, Ruslan and Fox, Dieter},
22 | booktitle={Conference on Robot Learning},
23 | year={2023}
24 | }
25 | ```
26 |
27 | # Table of Contents
28 |
29 | - [Installation](#installation)
30 | - [Dataset Download](#dataset-download)
31 | - [TAMP Data Cleaning](#tamp-data-cleaning)
32 | - [Model Training](#model-training)
33 | - [Model Inference](#model-inference)
34 | - [Task Visualizations](#task-visualizations)
35 | - [Troubleshooting and Known Issues](#troubleshooting-and-known-issues)
36 | - [Citation](#citation)
37 |
38 | # Installation
39 | To install dependencies, please run the following commands:
40 | ```
41 | sudo apt-get update
42 | sudo apt-get install -y \
43 | htop screen tmux \
44 | sshfs libosmesa6-dev wget curl git \
45 | libeigen3-dev \
46 | liborocos-kdl-dev \
47 | libkdl-parser-dev \
48 | liburdfdom-dev \
49 | libnlopt-dev \
50 | libnlopt-cxx-dev \
51 | swig \
52 | python3 \
53 | python3-pip \
54 | python3-dev \
55 | vim \
56 | git-lfs \
57 | cmake \
58 | software-properties-common \
59 | libxcursor-dev \
60 | libxrandr-dev \
61 | libxinerama-dev \
62 | libxi-dev \
63 | mesa-common-dev \
64 | zip \
65 | unzip \
66 | make \
67 | g++ \
68 | python2.7 \
69 | wget \
70 | vulkan-utils \
71 | mesa-vulkan-drivers \
72 | apt nano rsync \
73 | libgl1-mesa-dev libgl1-mesa-glx libglew-dev libosmesa6-dev \
74 | software-properties-common net-tools unzip virtualenv \
75 | xpra xserver-xorg-dev libglfw3-dev patchelf python3-pip -y \
76 | && add-apt-repository -y ppa:openscad/releases && apt-get update && apt-get install -y openscad
77 | ```
78 |
79 | Please add the following to your bashrc/zshrc:
80 | ```
81 | export MUJOCO_GL='egl'
82 | WANDB_API_KEY=...
83 | ```
84 |
85 | To install python requirements:
86 |
87 | ```
88 | conda create -n optimus python=3.8
89 | conda activate optimus
90 | pip install -r requirements.txt
91 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
92 | pip install -e .
93 | ```
94 |
95 | # Dataset Download
96 |
97 | #### Method 1: Using `download_datasets.py` (Recommended)
98 |
99 | `download_datasets.py` (located at `optimus/scripts`) is a python script that provides a programmatic way of downloading the datasets. This is the preferred method, because this script also sets up a directory structure for the datasets that works out of the box with the code for reproducing policy learning results.
100 |
101 | A few examples of using this script are provided below:
102 |
103 | ```
104 | # default behavior - just download Stack dataset
105 | python download_datasets.py
106 |
107 | # download datasets for Stack and Stack Three
108 | python download_datasets.py --tasks Stack StackThree
109 |
110 | # download all datasets, but do a dry run first to see what will be downloaded and where
111 | python download_datasets.py --tasks all --dry_run
112 |
113 | # download all datasets for all tasks
114 | python download_datasets.py --tasks all # this downloads Stack, StackThree, StackFour and StackFive
115 | ```
116 |
117 | #### Method 2: Using Direct Download Links
118 |
119 | You can download the datasets manually through Google Drive.
120 |
121 | **Google Drive folder with all datasets:** [link](https://drive.google.com/drive/folders/1Dfi313igOuvc5JUCMTzQrSixUKndhW__?usp=drive_link)
122 |
123 | # TAMP Data Cleaning
124 | As described in Section 3.2 of the Optimus paper, we develop two TAMP demonstration filtering strategies to curb variance in the expert supervision: 1) pruning demonstrations that have out-of-distribution trajectory lengths and 2) removing demonstrations that exit the visible workspace. In practice, we remove trajectories whose length is more than 2 standard deviations above the mean, as well as trajectories that exit a pre-defined workspace covering all regions visible from the fixed camera viewpoint.
125 |
126 | We include the data filtration code in `filter_trajectories.py` and give example usage below.
127 | ```
128 | # usage:
129 | python optimus/scripts/filter_trajectories.py --hdf5_paths datasets/<>.hdf5 --x_bounds <> <> --y_bounds <> <> --z_bounds <> <> --val_ratio <> --filter_key_prefix <> --outlier_traj_length_sd <>
130 |
131 | # example
132 | python optimus/scripts/filter_trajectories.py --hdf5_paths datasets/robosuite_stack.hdf5 --outlier_traj_length_sd 2 --x_bounds -0.2 0.2 --y_bounds -0.2 0.2 --z_bounds 0 1.1 --val_ratio 0.1
133 | ```
134 |
135 | For the datasets that we have released, we have already performed these filtration operations, so you do not need to (and should not) run the above command on them.
136 |
137 | # Model Training
138 | After downloading the datasets you're interested in using via the `download_datasets.py` script, you can train policies using the `pl_train.py` script in `optimus/scripts`. Our training code wraps around [robomimic](https://robomimic.github.io/) (the key difference is that we use PyTorch Lightning); please see the robomimic docs for a detailed overview of the imitation learning code. Following the robomimic format, our training configs are located in `optimus/exps/local/robosuite/`, with a separate folder for each environment (stack, stackthree, stackfour, stackfive). We demonstrate example usage below:
139 | ```
140 | # usage:
141 | python optimus/scripts/pl_train.py --config optimus/exps/local/robosuite/<env_name>/bc_transformer.json
142 |
143 | # example:
144 | python optimus/scripts/pl_train.py --config optimus/exps/local/robosuite/stack/bc_transformer.json
145 | ```
146 |
147 | # Model Inference
148 | Given a checkpoint from a training run, you can perform inference and evaluate the model with `run_trained_agent_pl.py`. This script is based on `run_trained_agent.py` in [robomimic](https://robomimic.github.io/) but adds support for our PyTorch Lightning checkpoints. Concretely, you need to specify the path to a specific ckpt file (via `--agent`) and the directory that contains `config.json` (via `--resume_dir`). We demonstrate example usage below:
149 | ```
150 | # usage:
151 | python optimus/scripts/run_trained_agent_pl.py --agent /path/to/trained_model.ckpt --resume_dir /path/to/training_dir --n <> --video_path <>.mp4
152 |
153 | # example:
154 | python optimus/scripts/run_trained_agent_pl.py --agent optimus/trained_models/robosuite_trained_models/bc_transformer_stack_image/20231026101652/models/model_epoch_50_Stack_success_1.0.ckpt --resume_dir optimus/trained_models/robosuite_trained_models/bc_transformer_stack_image/20231026101652/ --n 10 --video_path stack.mp4
155 | ```
156 |
157 | # Troubleshooting and Known Issues
158 |
159 | - If your training seems to be proceeding slowly (especially for image-based agents), it might be a problem with robomimic and more modern versions of PyTorch. We recommend PyTorch 1.12.1 (on Ubuntu, we used `pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113`). It is also a good idea to verify that the GPU is being utilized during training.
160 | - If you run into trouble with installing [egl_probe](https://github.com/StanfordVL/egl_probe) during robomimic installation (e.g. `ERROR: Failed building wheel for egl_probe`) you may need to make sure `cmake` is installed. A simple `pip install cmake` should work.
161 | - If you run into other strange installation issues, one potential fix is to launch a new terminal, activate your conda environment, and try the install commands that are failing once again. One clue that the current terminal state is corrupt and this fix will help is if you see installations going into a different conda environment than the one you have active.
162 |
163 | If you run into an error not documented above, please search through the [GitHub issues](https://github.com/NVlabs/optimus/issues), and create a new one if you cannot find a fix.
164 |
165 | ## Citation
166 |
167 | Please cite [the Optimus paper](https://arxiv.org/abs/2305.16309) if you use this code in your work:
168 |
169 | ```bibtex
170 | @inproceedings{dalal2023optimus,
171 | title={Imitating Task and Motion Planning with Visuomotor Transformers},
172 | author={Dalal, Murtaza and Mandlekar, Ajay and Garrett, Caelan and Handa, Ankur and Salakhutdinov, Ruslan and Fox, Dieter},
173 | booktitle={Conference on Robot Learning},
174 | year={2023}
175 | }
176 | ```
177 |
--------------------------------------------------------------------------------
/optimus/exps/local/robosuite/stack/bc_transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "algo_name": "bc",
3 | "experiment": {
4 | "name": "bc_transformer_stack_image",
5 | "validate": true,
6 | "logging": {
7 | "terminal_output_to_txt": false,
8 | "log_tb": true
9 | },
10 | "save": {
11 | "enabled": true,
12 | "every_n_seconds": null,
13 | "every_n_epochs": 25,
14 | "epochs": [],
15 | "on_best_validation": false,
16 | "on_best_rollout_return": false,
17 | "on_best_rollout_success_rate": true
18 | },
19 | "epoch_every_n_steps": 1000,
20 | "validation_epoch_every_n_steps": 10000,
21 | "env": null,
22 | "additional_envs": null,
23 | "render": false,
24 | "render_video": true,
25 | "keep_all_videos": false,
26 | "video_skip": 5,
27 | "rollout": {
28 | "enabled": true,
29 | "n": 10,
30 | "horizon": 300,
31 | "rate": 25,
32 | "warmstart": 0,
33 | "terminate_on_success": true,
34 | "parallel_rollouts": false
35 | }
36 | },
37 | "train": {
38 | "data": "datasets/robosuite_stack.hdf5",
39 | "output_dir": "trained_models/robosuite_trained_models",
40 | "num_data_workers": 4,
41 | "hdf5_cache_mode": "low_dim",
42 | "hdf5_use_swmr": true,
43 | "hdf5_normalize_obs": false,
44 | "hdf5_filter_key": "train",
45 | "seq_length": 1,
46 | "frame_stack": 8,
47 | "pad_frame_stack": true,
48 | "pad_seq_length": false,
49 | "dataset_keys": [
50 | "actions",
51 | "rewards",
52 | "dones"
53 | ],
54 | "goal_mode": null,
55 | "cuda": true,
56 | "batch_size": 16,
57 | "num_epochs": 1000,
58 | "seed": 1,
59 | "amp_enabled": true,
60 | "max_grad_norm": 1
61 | },
62 | "algo": {
63 | "optim_params": {
64 | "policy": {
65 | "learning_rate": {
66 | "initial": 0.0001,
67 | "decay_factor": 0.01,
68 | "epoch_schedule": []
69 | },
70 | "regularization": {
71 | "L2": 0.0
72 | }
73 | }
74 | },
75 | "loss": {
76 | "l2_weight": 1.0,
77 | "l1_weight": 0.0,
78 | "cos_weight": 0.0
79 | },
80 | "actor_layer_dims": [],
81 | "gaussian": {
82 | "enabled": false,
83 | "fixed_std": false,
84 | "init_std": 0.1,
85 | "min_std": 0.01,
86 | "std_activation": "softplus",
87 | "low_noise_eval": true
88 | },
89 | "gmm": {
90 | "enabled": true,
91 | "num_modes": 5,
92 | "min_std": 0.0001,
93 | "std_activation": "softplus",
94 | "low_noise_eval": true
95 | },
96 | "vae": {
97 | "enabled": false,
98 | "latent_dim": 14,
99 | "latent_clip": null,
100 | "kl_weight": 1.0,
101 | "decoder": {
102 | "is_conditioned": true,
103 | "reconstruction_sum_across_elements": false
104 | },
105 | "prior": {
106 | "learn": false,
107 | "is_conditioned": false,
108 | "use_gmm": false,
109 | "gmm_num_modes": 10,
110 | "gmm_learn_weights": false,
111 | "use_categorical": false,
112 | "categorical_dim": 10,
113 | "categorical_gumbel_softmax_hard": false,
114 | "categorical_init_temp": 1.0,
115 | "categorical_temp_anneal_step": 0.001,
116 | "categorical_min_temp": 0.3
117 | },
118 | "encoder_layer_dims": [
119 | 300,
120 | 400
121 | ],
122 | "decoder_layer_dims": [
123 | 300,
124 | 400
125 | ],
126 | "prior_layer_dims": [
127 | 300,
128 | 400
129 | ]
130 | },
131 | "rnn": {
132 | "enabled": false,
133 | "horizon": 10,
134 | "hidden_dim": 400,
135 | "rnn_type": "LSTM",
136 | "num_layers": 2,
137 | "open_loop": false,
138 | "kwargs": {
139 | "bidirectional": false
140 | }
141 | },
142 | "transformer": {
143 | "enabled": true,
144 | "condition_on_actions": false,
145 | "context_length": 8,
146 | "sinusoidal_embedding": false,
147 | "relative_timestep": true,
148 | "supervise_all_steps": false,
149 | "nn_parameter_for_timesteps": true,
150 |       "latent_dim": 0,
151 | "kl_loss_weight": 0.0
152 | }
153 | },
154 | "observation": {
155 | "modalities": {
156 | "obs": {
157 | "low_dim": [
158 | "robot0_eef_pos",
159 | "robot0_eef_quat",
160 | "robot0_gripper_qpos",
161 | "timesteps"
162 | ],
163 | "rgb": [
164 | "agentview_image",
165 | "robot0_eye_in_hand_image"
166 | ],
167 | "depth": [],
168 | "scan": []
169 | },
170 | "goal": {
171 | "low_dim": [],
172 | "rgb": [],
173 | "depth": [],
174 | "scan": []
175 | }
176 | },
177 | "encoder": {
178 | "low_dim": {
179 | "core_class": null,
180 | "core_kwargs": {},
181 | "obs_randomizer_class": null,
182 | "obs_randomizer_kwargs": {}
183 | },
184 | "rgb": {
185 | "core_class": "VisualCore",
186 | "core_kwargs": {
187 | "feature_dimension": 64,
188 | "flatten": true,
189 | "backbone_class": "ResNet18Conv",
190 | "backbone_kwargs": {
191 | "pretrained": false,
192 | "input_coord_conv": false
193 | },
194 | "pool_class": "SpatialSoftmax",
195 | "pool_kwargs": {
196 | "num_kp": 32,
197 | "learnable_temperature": false,
198 | "temperature": 1.0,
199 | "noise_std": 0.0,
200 | "output_variance": false
201 | }
202 | },
203 | "obs_randomizer_class": "CropRandomizer",
204 | "obs_randomizer_kwargs": {
205 | "crop_height": 76,
206 | "crop_width": 76,
207 | "num_crops": 1,
208 | "pos_enc": false
209 | }
210 | },
211 | "depth": {
212 | "core_class": "VisualCore",
213 | "core_kwargs": {
214 | "feature_dimension": 64,
215 | "flatten": true,
216 | "backbone_class": "ResNet18Conv",
217 | "backbone_kwargs": {
218 | "pretrained": false,
219 | "input_coord_conv": false
220 | },
221 | "pool_class": "SpatialSoftmax",
222 | "pool_kwargs": {
223 | "num_kp": 32,
224 | "learnable_temperature": false,
225 | "temperature": 1.0,
226 | "noise_std": 0.0,
227 | "output_variance": false
228 | }
229 | },
230 | "obs_randomizer_class": null,
231 | "obs_randomizer_kwargs": {
232 | "crop_height": 76,
233 | "crop_width": 76,
234 | "num_crops": 1,
235 | "pos_enc": false
236 | }
237 | },
238 | "scan": {
239 | "core_class": "ScanCore",
240 | "core_kwargs": {
241 | "feature_dimension": 64,
242 | "flatten": true,
243 | "pool_class": "SpatialSoftmax",
244 | "pool_kwargs": {
245 | "num_kp": 32,
246 | "learnable_temperature": false,
247 | "temperature": 1.0,
248 | "noise_std": 0.0,
249 | "output_variance": false
250 | },
251 | "conv_activation": "relu",
252 | "conv_kwargs": {
253 | "out_channels": [
254 | 32,
255 | 64,
256 | 64
257 | ],
258 | "kernel_size": [
259 | 8,
260 | 4,
261 | 2
262 | ],
263 | "stride": [
264 | 4,
265 | 2,
266 | 1
267 | ]
268 | }
269 | },
270 | "obs_randomizer_class": null,
271 | "obs_randomizer_kwargs": {
272 | "crop_height": 76,
273 | "crop_width": 76,
274 | "num_crops": 1,
275 | "pos_enc": false
276 | }
277 | }
278 | }
279 | }
280 | }
281 |
--------------------------------------------------------------------------------
/optimus/exps/local/robosuite/stackfive/bc_transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "algo_name": "bc",
3 | "experiment": {
4 | "name": "bc_transformer_stack_five_image",
5 | "validate": true,
6 | "logging": {
7 | "terminal_output_to_txt": false,
8 | "log_tb": true
9 | },
10 | "save": {
11 | "enabled": true,
12 | "every_n_seconds": null,
13 | "every_n_epochs": 25,
14 | "epochs": [],
15 | "on_best_validation": false,
16 | "on_best_rollout_return": false,
17 | "on_best_rollout_success_rate": true
18 | },
19 | "epoch_every_n_steps": 1000,
20 | "validation_epoch_every_n_steps": 10000,
21 | "env": null,
22 | "additional_envs": null,
23 | "render": false,
24 | "render_video": true,
25 | "keep_all_videos": false,
26 | "video_skip": 5,
27 | "rollout": {
28 | "enabled": true,
29 | "n": 50,
30 | "horizon": 1800,
31 | "rate": 25,
32 | "warmstart": 99,
33 | "terminate_on_success": true,
34 | "parallel_rollouts": false
35 | }
36 | },
37 | "train": {
38 | "data": "datasets/robosuite_stack_five.hdf5",
39 | "output_dir": "trained_models/robosuite_trained_models",
40 | "num_data_workers": 4,
41 | "hdf5_cache_mode": "low_dim",
42 | "hdf5_use_swmr": true,
43 | "hdf5_normalize_obs": false,
44 | "hdf5_filter_key": "train",
45 | "seq_length": 1,
46 | "frame_stack": 8,
47 | "pad_frame_stack": true,
48 | "pad_seq_length": false,
49 | "dataset_keys": [
50 | "actions",
51 | "rewards",
52 | "dones"
53 | ],
54 | "goal_mode": null,
55 | "cuda": true,
56 | "batch_size": 16,
57 | "num_epochs": 10000,
58 | "seed": 1,
59 | "amp_enabled": true,
60 | "max_grad_norm": 20
61 | },
62 | "algo": {
63 | "optim_params": {
64 | "policy": {
65 | "learning_rate": {
66 | "initial": 0.0001,
67 | "decay_factor": 0.01,
68 | "epoch_schedule": []
69 | },
70 | "regularization": {
71 | "L2": 0.0
72 | }
73 | }
74 | },
75 | "loss": {
76 | "l2_weight": 1.0,
77 | "l1_weight": 0.0,
78 | "cos_weight": 0.0
79 | },
80 | "actor_layer_dims": [],
81 | "gaussian": {
82 | "enabled": false,
83 | "fixed_std": false,
84 | "init_std": 0.1,
85 | "min_std": 0.01,
86 | "std_activation": "softplus",
87 | "low_noise_eval": true
88 | },
89 | "gmm": {
90 | "enabled": true,
91 | "num_modes": 5,
92 | "min_std": 0.0001,
93 | "std_activation": "softplus",
94 | "low_noise_eval": true
95 | },
96 | "vae": {
97 | "enabled": false,
98 | "latent_dim": 14,
99 | "latent_clip": null,
100 | "kl_weight": 1.0,
101 | "decoder": {
102 | "is_conditioned": true,
103 | "reconstruction_sum_across_elements": false
104 | },
105 | "prior": {
106 | "learn": false,
107 | "is_conditioned": false,
108 | "use_gmm": false,
109 | "gmm_num_modes": 10,
110 | "gmm_learn_weights": false,
111 | "use_categorical": false,
112 | "categorical_dim": 10,
113 | "categorical_gumbel_softmax_hard": false,
114 | "categorical_init_temp": 1.0,
115 | "categorical_temp_anneal_step": 0.001,
116 | "categorical_min_temp": 0.3
117 | },
118 | "encoder_layer_dims": [
119 | 300,
120 | 400
121 | ],
122 | "decoder_layer_dims": [
123 | 300,
124 | 400
125 | ],
126 | "prior_layer_dims": [
127 | 300,
128 | 400
129 | ]
130 | },
131 | "rnn": {
132 | "enabled": false,
133 | "horizon": 10,
134 | "hidden_dim": 400,
135 | "rnn_type": "LSTM",
136 | "num_layers": 2,
137 | "open_loop": false,
138 | "kwargs": {
139 | "bidirectional": false
140 | }
141 | },
142 | "transformer": {
143 | "enabled": true,
144 | "condition_on_actions": false,
145 | "context_length": 8,
146 | "sinusoidal_embedding": false,
147 | "relative_timestep": true,
148 | "supervise_all_steps": false,
149 | "nn_parameter_for_timesteps": true,
150 |       "latent_dim": 0,
151 | "kl_loss_weight": 0.0
152 | }
153 | },
154 | "observation": {
155 | "modalities": {
156 | "obs": {
157 | "low_dim": [
158 | "robot0_eef_pos",
159 | "robot0_eef_quat",
160 | "robot0_gripper_qpos",
161 | "timesteps"
162 | ],
163 | "rgb": [
164 | "agentview_image",
165 | "robot0_eye_in_hand_image"
166 | ],
167 | "depth": [],
168 | "scan": []
169 | },
170 | "goal": {
171 | "low_dim": [],
172 | "rgb": [],
173 | "depth": [],
174 | "scan": []
175 | }
176 | },
177 | "encoder": {
178 | "low_dim": {
179 | "core_class": null,
180 | "core_kwargs": {},
181 | "obs_randomizer_class": null,
182 | "obs_randomizer_kwargs": {}
183 | },
184 | "rgb": {
185 | "core_class": "VisualCore",
186 | "core_kwargs": {
187 | "feature_dimension": 64,
188 | "flatten": true,
189 | "backbone_class": "ResNet18Conv",
190 | "backbone_kwargs": {
191 | "pretrained": false,
192 | "input_coord_conv": false
193 | },
194 | "pool_class": "SpatialSoftmax",
195 | "pool_kwargs": {
196 | "num_kp": 32,
197 | "learnable_temperature": false,
198 | "temperature": 1.0,
199 | "noise_std": 0.0,
200 | "output_variance": false
201 | }
202 | },
203 | "obs_randomizer_class": "CropRandomizer",
204 | "obs_randomizer_kwargs": {
205 | "crop_height": 76,
206 | "crop_width": 76,
207 | "num_crops": 1,
208 | "pos_enc": false
209 | }
210 | },
211 | "depth": {
212 | "core_class": "VisualCore",
213 | "core_kwargs": {
214 | "feature_dimension": 64,
215 | "flatten": true,
216 | "backbone_class": "ResNet18Conv",
217 | "backbone_kwargs": {
218 | "pretrained": false,
219 | "input_coord_conv": false
220 | },
221 | "pool_class": "SpatialSoftmax",
222 | "pool_kwargs": {
223 | "num_kp": 32,
224 | "learnable_temperature": false,
225 | "temperature": 1.0,
226 | "noise_std": 0.0,
227 | "output_variance": false
228 | }
229 | },
230 | "obs_randomizer_class": null,
231 | "obs_randomizer_kwargs": {
232 | "crop_height": 76,
233 | "crop_width": 76,
234 | "num_crops": 1,
235 | "pos_enc": false
236 | }
237 | },
238 | "scan": {
239 | "core_class": "ScanCore",
240 | "core_kwargs": {
241 | "feature_dimension": 64,
242 | "flatten": true,
243 | "pool_class": "SpatialSoftmax",
244 | "pool_kwargs": {
245 | "num_kp": 32,
246 | "learnable_temperature": false,
247 | "temperature": 1.0,
248 | "noise_std": 0.0,
249 | "output_variance": false
250 | },
251 | "conv_activation": "relu",
252 | "conv_kwargs": {
253 | "out_channels": [
254 | 32,
255 | 64,
256 | 64
257 | ],
258 | "kernel_size": [
259 | 8,
260 | 4,
261 | 2
262 | ],
263 | "stride": [
264 | 4,
265 | 2,
266 | 1
267 | ]
268 | }
269 | },
270 | "obs_randomizer_class": null,
271 | "obs_randomizer_kwargs": {
272 | "crop_height": 76,
273 | "crop_width": 76,
274 | "num_crops": 1,
275 | "pos_enc": false
276 | }
277 | }
278 | }
279 | }
280 | }
281 |
--------------------------------------------------------------------------------
/optimus/exps/local/robosuite/stackfour/bc_transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "algo_name": "bc",
3 | "experiment": {
4 | "name": "bc_transformer_stack_four_image",
5 | "validate": true,
6 | "logging": {
7 | "terminal_output_to_txt": false,
8 | "log_tb": true
9 | },
10 | "save": {
11 | "enabled": true,
12 | "every_n_seconds": null,
13 | "every_n_epochs": 25,
14 | "epochs": [],
15 | "on_best_validation": false,
16 | "on_best_rollout_return": false,
17 | "on_best_rollout_success_rate": true
18 | },
19 | "epoch_every_n_steps": 1000,
20 | "validation_epoch_every_n_steps": 10000,
21 | "env": null,
22 | "additional_envs": null,
23 | "render": false,
24 | "render_video": true,
25 | "keep_all_videos": false,
26 | "video_skip": 5,
27 | "rollout": {
28 | "enabled": true,
29 | "n": 50,
30 | "horizon": 1350,
31 | "rate": 25,
32 | "warmstart": 99,
33 | "terminate_on_success": true,
34 | "parallel_rollouts": false
35 | }
36 | },
37 | "train": {
38 | "data": "datasets/robosuite_stack_four.hdf5",
39 | "output_dir": "trained_models/robosuite_trained_models",
40 | "num_data_workers": 4,
41 | "hdf5_cache_mode": "low_dim",
42 | "hdf5_use_swmr": true,
43 | "hdf5_normalize_obs": false,
44 | "hdf5_filter_key": "train",
45 | "seq_length": 1,
46 | "frame_stack": 8,
47 | "pad_frame_stack": true,
48 | "pad_seq_length": false,
49 | "dataset_keys": [
50 | "actions",
51 | "rewards",
52 | "dones"
53 | ],
54 | "goal_mode": null,
55 | "cuda": true,
56 | "batch_size": 16,
57 | "num_epochs": 10000,
58 | "seed": 1,
59 | "amp_enabled": true,
60 | "max_grad_norm": 20
61 | },
62 | "algo": {
63 | "optim_params": {
64 | "policy": {
65 | "learning_rate": {
66 | "initial": 0.0001,
67 | "decay_factor": 0.01,
68 | "epoch_schedule": []
69 | },
70 | "regularization": {
71 | "L2": 0.0
72 | }
73 | }
74 | },
75 | "loss": {
76 | "l2_weight": 1.0,
77 | "l1_weight": 0.0,
78 | "cos_weight": 0.0
79 | },
80 | "actor_layer_dims": [],
81 | "gaussian": {
82 | "enabled": false,
83 | "fixed_std": false,
84 | "init_std": 0.1,
85 | "min_std": 0.01,
86 | "std_activation": "softplus",
87 | "low_noise_eval": true
88 | },
89 | "gmm": {
90 | "enabled": true,
91 | "num_modes": 5,
92 | "min_std": 0.0001,
93 | "std_activation": "softplus",
94 | "low_noise_eval": true
95 | },
96 | "vae": {
97 | "enabled": false,
98 | "latent_dim": 14,
99 | "latent_clip": null,
100 | "kl_weight": 1.0,
101 | "decoder": {
102 | "is_conditioned": true,
103 | "reconstruction_sum_across_elements": false
104 | },
105 | "prior": {
106 | "learn": false,
107 | "is_conditioned": false,
108 | "use_gmm": false,
109 | "gmm_num_modes": 10,
110 | "gmm_learn_weights": false,
111 | "use_categorical": false,
112 | "categorical_dim": 10,
113 | "categorical_gumbel_softmax_hard": false,
114 | "categorical_init_temp": 1.0,
115 | "categorical_temp_anneal_step": 0.001,
116 | "categorical_min_temp": 0.3
117 | },
118 | "encoder_layer_dims": [
119 | 300,
120 | 400
121 | ],
122 | "decoder_layer_dims": [
123 | 300,
124 | 400
125 | ],
126 | "prior_layer_dims": [
127 | 300,
128 | 400
129 | ]
130 | },
131 | "rnn": {
132 | "enabled": false,
133 | "horizon": 10,
134 | "hidden_dim": 400,
135 | "rnn_type": "LSTM",
136 | "num_layers": 2,
137 | "open_loop": false,
138 | "kwargs": {
139 | "bidirectional": false
140 | }
141 | },
142 | "transformer": {
143 | "enabled": true,
144 | "condition_on_actions": false,
145 | "context_length": 8,
146 | "sinusoidal_embedding": false,
147 | "relative_timestep": true,
148 | "supervise_all_steps": false,
149 | "nn_parameter_for_timesteps": true,
150 |       "latent_dim": 0,
151 | "kl_loss_weight": 0.0
152 | }
153 | },
154 | "observation": {
155 | "modalities": {
156 | "obs": {
157 | "low_dim": [
158 | "robot0_eef_pos",
159 | "robot0_eef_quat",
160 | "robot0_gripper_qpos",
161 | "timesteps"
162 | ],
163 | "rgb": [
164 | "agentview_image",
165 | "robot0_eye_in_hand_image"
166 | ],
167 | "depth": [],
168 | "scan": []
169 | },
170 | "goal": {
171 | "low_dim": [],
172 | "rgb": [],
173 | "depth": [],
174 | "scan": []
175 | }
176 | },
177 | "encoder": {
178 | "low_dim": {
179 | "core_class": null,
180 | "core_kwargs": {},
181 | "obs_randomizer_class": null,
182 | "obs_randomizer_kwargs": {}
183 | },
184 | "rgb": {
185 | "core_class": "VisualCore",
186 | "core_kwargs": {
187 | "feature_dimension": 64,
188 | "flatten": true,
189 | "backbone_class": "ResNet18Conv",
190 | "backbone_kwargs": {
191 | "pretrained": false,
192 | "input_coord_conv": false
193 | },
194 | "pool_class": "SpatialSoftmax",
195 | "pool_kwargs": {
196 | "num_kp": 32,
197 | "learnable_temperature": false,
198 | "temperature": 1.0,
199 | "noise_std": 0.0,
200 | "output_variance": false
201 | }
202 | },
203 | "obs_randomizer_class": "CropRandomizer",
204 | "obs_randomizer_kwargs": {
205 | "crop_height": 76,
206 | "crop_width": 76,
207 | "num_crops": 1,
208 | "pos_enc": false
209 | }
210 | },
211 | "depth": {
212 | "core_class": "VisualCore",
213 | "core_kwargs": {
214 | "feature_dimension": 64,
215 | "flatten": true,
216 | "backbone_class": "ResNet18Conv",
217 | "backbone_kwargs": {
218 | "pretrained": false,
219 | "input_coord_conv": false
220 | },
221 | "pool_class": "SpatialSoftmax",
222 | "pool_kwargs": {
223 | "num_kp": 32,
224 | "learnable_temperature": false,
225 | "temperature": 1.0,
226 | "noise_std": 0.0,
227 | "output_variance": false
228 | }
229 | },
230 | "obs_randomizer_class": null,
231 | "obs_randomizer_kwargs": {
232 | "crop_height": 76,
233 | "crop_width": 76,
234 | "num_crops": 1,
235 | "pos_enc": false
236 | }
237 | },
238 | "scan": {
239 | "core_class": "ScanCore",
240 | "core_kwargs": {
241 | "feature_dimension": 64,
242 | "flatten": true,
243 | "pool_class": "SpatialSoftmax",
244 | "pool_kwargs": {
245 | "num_kp": 32,
246 | "learnable_temperature": false,
247 | "temperature": 1.0,
248 | "noise_std": 0.0,
249 | "output_variance": false
250 | },
251 | "conv_activation": "relu",
252 | "conv_kwargs": {
253 | "out_channels": [
254 | 32,
255 | 64,
256 | 64
257 | ],
258 | "kernel_size": [
259 | 8,
260 | 4,
261 | 2
262 | ],
263 | "stride": [
264 | 4,
265 | 2,
266 | 1
267 | ]
268 | }
269 | },
270 | "obs_randomizer_class": null,
271 | "obs_randomizer_kwargs": {
272 | "crop_height": 76,
273 | "crop_width": 76,
274 | "num_crops": 1,
275 | "pos_enc": false
276 | }
277 | }
278 | }
279 | }
280 | }
281 |
--------------------------------------------------------------------------------
/optimus/exps/local/robosuite/stackthree/bc_transformer.json:
--------------------------------------------------------------------------------
1 | {
2 | "algo_name": "bc",
3 | "experiment": {
4 | "name": "bc_transformer_stack_three_image",
5 | "validate": true,
6 | "logging": {
7 | "terminal_output_to_txt": false,
8 | "log_tb": true
9 | },
10 | "save": {
11 | "enabled": true,
12 | "every_n_seconds": null,
13 | "every_n_epochs": 25,
14 | "epochs": [],
15 | "on_best_validation": false,
16 | "on_best_rollout_return": false,
17 | "on_best_rollout_success_rate": true
18 | },
19 | "epoch_every_n_steps": 1000,
20 | "validation_epoch_every_n_steps": 10000,
21 | "env": null,
22 | "additional_envs": null,
23 | "render": false,
24 | "render_video": true,
25 | "keep_all_videos": false,
26 | "video_skip": 5,
27 | "rollout": {
28 | "enabled": true,
29 | "n": 50,
30 | "horizon": 900,
31 | "rate": 25,
32 | "warmstart": 99,
33 | "terminate_on_success": true,
34 | "parallel_rollouts": false
35 | }
36 | },
37 | "train": {
38 | "data": "datasets/robosuite_stack_three.hdf5",
39 | "output_dir": "trained_models/robosuite_trained_models",
40 | "num_data_workers": 4,
41 | "hdf5_cache_mode": "low_dim",
42 | "hdf5_use_swmr": true,
43 | "hdf5_normalize_obs": false,
44 | "hdf5_filter_key": "train",
45 | "seq_length": 1,
46 | "frame_stack": 8,
47 | "pad_frame_stack": true,
48 | "pad_seq_length": false,
49 | "dataset_keys": [
50 | "actions",
51 | "rewards",
52 | "dones"
53 | ],
54 | "goal_mode": null,
55 | "cuda": true,
56 | "batch_size": 16,
57 | "num_epochs": 10000,
58 | "seed": 1,
59 | "amp_enabled": true,
60 | "max_grad_norm": 20
61 | },
62 | "algo": {
63 | "optim_params": {
64 | "policy": {
65 | "learning_rate": {
66 | "initial": 0.0001,
67 | "decay_factor": 0.01,
68 | "epoch_schedule": []
69 | },
70 | "regularization": {
71 | "L2": 0.0
72 | }
73 | }
74 | },
75 | "loss": {
76 | "l2_weight": 1.0,
77 | "l1_weight": 0.0,
78 | "cos_weight": 0.0
79 | },
80 | "actor_layer_dims": [],
81 | "gaussian": {
82 | "enabled": false,
83 | "fixed_std": false,
84 | "init_std": 0.1,
85 | "min_std": 0.01,
86 | "std_activation": "softplus",
87 | "low_noise_eval": true
88 | },
89 | "gmm": {
90 | "enabled": true,
91 | "num_modes": 5,
92 | "min_std": 0.0001,
93 | "std_activation": "softplus",
94 | "low_noise_eval": true
95 | },
96 | "vae": {
97 | "enabled": false,
98 | "latent_dim": 14,
99 | "latent_clip": null,
100 | "kl_weight": 1.0,
101 | "decoder": {
102 | "is_conditioned": true,
103 | "reconstruction_sum_across_elements": false
104 | },
105 | "prior": {
106 | "learn": false,
107 | "is_conditioned": false,
108 | "use_gmm": false,
109 | "gmm_num_modes": 10,
110 | "gmm_learn_weights": false,
111 | "use_categorical": false,
112 | "categorical_dim": 10,
113 | "categorical_gumbel_softmax_hard": false,
114 | "categorical_init_temp": 1.0,
115 | "categorical_temp_anneal_step": 0.001,
116 | "categorical_min_temp": 0.3
117 | },
118 | "encoder_layer_dims": [
119 | 300,
120 | 400
121 | ],
122 | "decoder_layer_dims": [
123 | 300,
124 | 400
125 | ],
126 | "prior_layer_dims": [
127 | 300,
128 | 400
129 | ]
130 | },
131 | "rnn": {
132 | "enabled": false,
133 | "horizon": 10,
134 | "hidden_dim": 400,
135 | "rnn_type": "LSTM",
136 | "num_layers": 2,
137 | "open_loop": false,
138 | "kwargs": {
139 | "bidirectional": false
140 | }
141 | },
142 | "transformer": {
143 | "enabled": true,
144 | "condition_on_actions": false,
145 | "context_length": 8,
146 | "sinusoidal_embedding": false,
147 | "relative_timestep": true,
148 | "supervise_all_steps": false,
149 | "nn_parameter_for_timesteps": true,
150 |       "latent_dim": 0,
151 | "kl_loss_weight": 0.0
152 | }
153 | },
154 | "observation": {
155 | "modalities": {
156 | "obs": {
157 | "low_dim": [
158 | "robot0_eef_pos",
159 | "robot0_eef_quat",
160 | "robot0_gripper_qpos",
161 | "timesteps"
162 | ],
163 | "rgb": [
164 | "agentview_image",
165 | "robot0_eye_in_hand_image"
166 | ],
167 | "depth": [],
168 | "scan": []
169 | },
170 | "goal": {
171 | "low_dim": [],
172 | "rgb": [],
173 | "depth": [],
174 | "scan": []
175 | }
176 | },
177 | "encoder": {
178 | "low_dim": {
179 | "core_class": null,
180 | "core_kwargs": {},
181 | "obs_randomizer_class": null,
182 | "obs_randomizer_kwargs": {}
183 | },
184 | "rgb": {
185 | "core_class": "VisualCore",
186 | "core_kwargs": {
187 | "feature_dimension": 64,
188 | "flatten": true,
189 | "backbone_class": "ResNet18Conv",
190 | "backbone_kwargs": {
191 | "pretrained": false,
192 | "input_coord_conv": false
193 | },
194 | "pool_class": "SpatialSoftmax",
195 | "pool_kwargs": {
196 | "num_kp": 32,
197 | "learnable_temperature": false,
198 | "temperature": 1.0,
199 | "noise_std": 0.0,
200 | "output_variance": false
201 | }
202 | },
203 | "obs_randomizer_class": "CropRandomizer",
204 | "obs_randomizer_kwargs": {
205 | "crop_height": 76,
206 | "crop_width": 76,
207 | "num_crops": 1,
208 | "pos_enc": false
209 | }
210 | },
211 | "depth": {
212 | "core_class": "VisualCore",
213 | "core_kwargs": {
214 | "feature_dimension": 64,
215 | "flatten": true,
216 | "backbone_class": "ResNet18Conv",
217 | "backbone_kwargs": {
218 | "pretrained": false,
219 | "input_coord_conv": false
220 | },
221 | "pool_class": "SpatialSoftmax",
222 | "pool_kwargs": {
223 | "num_kp": 32,
224 | "learnable_temperature": false,
225 | "temperature": 1.0,
226 | "noise_std": 0.0,
227 | "output_variance": false
228 | }
229 | },
230 | "obs_randomizer_class": null,
231 | "obs_randomizer_kwargs": {
232 | "crop_height": 76,
233 | "crop_width": 76,
234 | "num_crops": 1,
235 | "pos_enc": false
236 | }
237 | },
238 | "scan": {
239 | "core_class": "ScanCore",
240 | "core_kwargs": {
241 | "feature_dimension": 64,
242 | "flatten": true,
243 | "pool_class": "SpatialSoftmax",
244 | "pool_kwargs": {
245 | "num_kp": 32,
246 | "learnable_temperature": false,
247 | "temperature": 1.0,
248 | "noise_std": 0.0,
249 | "output_variance": false
250 | },
251 | "conv_activation": "relu",
252 | "conv_kwargs": {
253 | "out_channels": [
254 | 32,
255 | 64,
256 | 64
257 | ],
258 | "kernel_size": [
259 | 8,
260 | 4,
261 | 2
262 | ],
263 | "stride": [
264 | 4,
265 | 2,
266 | 1
267 | ]
268 | }
269 | },
270 | "obs_randomizer_class": null,
271 | "obs_randomizer_kwargs": {
272 | "crop_height": 76,
273 | "crop_width": 76,
274 | "num_crops": 1,
275 | "pos_enc": false
276 | }
277 | }
278 | }
279 | }
280 | }
281 |
--------------------------------------------------------------------------------
/optimus/utils/env_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 | """
5 | This file contains several utility functions for working with environment
6 | wrappers provided by the repository, and with environment metadata saved
7 | in dataset files.
8 | """
9 | from copy import deepcopy
10 |
11 | import robomimic.envs.env_base as EB
12 |
13 |
14 | def get_env_class(env_meta=None, env_type=None, env=None):
15 | """
16 | Return env class from either env_meta, env_type, or env.
17 | Note the use of lazy imports - this ensures that modules are only
18 | imported when the corresponding env type is requested. This can
19 | be useful in practice. For example, a training run that only
20 | requires access to gym environments should not need to import
21 | robosuite.
22 |
23 | Args:
24 | env_meta (dict): environment metadata, which should be loaded from demonstration
25 | hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
26 | @FileUtils.env_from_checkpoint). Contains 3 keys:
27 |
28 | :`'env_name'`: name of environment
29 | :`'type'`: type of environment, should be a value in EB.EnvType
30 | :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
31 |
32 | env_type (int): the type of environment, which determines the env class that will
33 | be instantiated. Should be a value in EB.EnvType.
34 |
35 | env (instance of EB.EnvBase): environment instance
36 | """
37 | env_type = get_env_type(env_meta=env_meta, env_type=env_type, env=env)
38 | if env_type == EB.EnvType.ROBOSUITE_TYPE:
39 | from optimus.envs.env_robosuite import EnvRobosuite
40 |
41 | return EnvRobosuite
42 | elif env_type == EB.EnvType.GYM_TYPE:
43 | from robomimic.envs.env_gym import EnvGym
44 |
45 | return EnvGym
46 | elif env_type == EB.EnvType.IG_MOMART_TYPE:
47 | from robomimic.envs.env_ig_momart import EnvGibsonMOMART
48 |
49 | return EnvGibsonMOMART
50 |     raise ValueError("unknown environment type: {}".format(env_type))
51 |
52 |
53 | def get_env_type(env_meta=None, env_type=None, env=None):
54 | """
55 | Helper function to get env_type from a variety of inputs.
56 |
57 | Args:
58 | env_meta (dict): environment metadata, which should be loaded from demonstration
59 | hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
60 | @FileUtils.env_from_checkpoint). Contains 3 keys:
61 |
62 | :`'env_name'`: name of environment
63 | :`'type'`: type of environment, should be a value in EB.EnvType
64 | :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
65 |
66 | env_type (int): the type of environment, which determines the env class that will
67 | be instantiated. Should be a value in EB.EnvType.
68 |
69 | env (instance of EB.EnvBase): environment instance
70 | """
71 | checks = [(env_meta is not None), (env_type is not None), (env is not None)]
72 | assert sum(checks) == 1, "should provide only one of env_meta, env_type, env"
73 | if env_meta is not None:
74 | env_type = env_meta["type"]
75 | elif env is not None:
76 | env_type = env.type
77 | return env_type
78 |
79 |
80 | def check_env_type(type_to_check, env_meta=None, env_type=None, env=None):
81 | """
82 | Checks whether the passed env_meta, env_type, or env is of type @type_to_check.
83 | Type corresponds to EB.EnvType.
84 |
85 | Args:
86 | type_to_check (int): type to check equality against
87 |
88 | env_meta (dict): environment metadata, which should be loaded from demonstration
89 | hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
90 | @FileUtils.env_from_checkpoint). Contains 3 keys:
91 |
92 | :`'env_name'`: name of environment
93 | :`'type'`: type of environment, should be a value in EB.EnvType
94 | :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
95 |
96 | env_type (int): the type of environment, which determines the env class that will
97 | be instantiated. Should be a value in EB.EnvType.
98 |
99 | env (instance of EB.EnvBase): environment instance
100 | """
101 | env_type = get_env_type(env_meta=env_meta, env_type=env_type, env=env)
102 | return env_type == type_to_check
103 |
104 |
105 | def is_robosuite_env(env_meta=None, env_type=None, env=None):
106 | """
107 | Determines whether the environment is a robosuite environment. Accepts
108 | either env_meta, env_type, or env.
109 | """
110 | return check_env_type(
111 | type_to_check=EB.EnvType.ROBOSUITE_TYPE,
112 | env_meta=env_meta,
113 | env_type=env_type,
114 | env=env,
115 | )
116 |
117 |
118 | def create_env(
119 | env_type,
120 | env_name,
121 | render=False,
122 | render_offscreen=False,
123 | use_image_obs=False,
124 | env_meta=None,
125 | **kwargs,
126 | ):
127 | """
128 | Create environment.
129 |
130 | Args:
131 | env_type (int): the type of environment, which determines the env class that will
132 | be instantiated. Should be a value in EB.EnvType.
133 |
134 | env_name (str): name of environment
135 |
136 | render (bool): if True, environment supports on-screen rendering
137 |
138 | render_offscreen (bool): if True, environment supports off-screen rendering. This
139 | is forced to be True if @use_image_obs is True.
140 |
141 | use_image_obs (bool): if True, environment is expected to render rgb image observations
142 | on every env.step call. Set this to False for efficiency reasons, if image
143 | observations are not required.
144 | """
145 |
146 |     # note: pass @postprocess_visual_obs=True to make sure images are processed for network inputs
147 | env_class = get_env_class(env_meta=env_meta)
148 | env = env_class(
149 | env_name=env_name,
150 | render=render,
151 | render_offscreen=render_offscreen,
152 | use_image_obs=use_image_obs,
153 | postprocess_visual_obs=True,
154 | **kwargs,
155 | )
156 | print("Created environment with name {}".format(env_name))
157 | print("Action size is {}".format(env.action_dimension))
158 | return env
159 |
160 |
161 | def create_env_from_metadata(
162 | env_meta,
163 | env_name=None,
164 | render=False,
165 | render_offscreen=False,
166 | use_image_obs=False,
167 | ):
168 | """
169 | Create environment.
170 |
171 | Args:
172 | env_meta (dict): environment metadata, which should be loaded from demonstration
173 | hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
174 | @FileUtils.env_from_checkpoint). Contains 3 keys:
175 |
176 | :`'env_name'`: name of environment
177 | :`'type'`: type of environment, should be a value in EB.EnvType
178 | :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
179 |
180 | env_name (str): name of environment. Only needs to be provided if making a different
181 | environment from the one in @env_meta.
182 |
183 | render (bool): if True, environment supports on-screen rendering
184 |
185 | render_offscreen (bool): if True, environment supports off-screen rendering. This
186 | is forced to be True if @use_image_obs is True.
187 |
188 | use_image_obs (bool): if True, environment is expected to render rgb image observations
189 | on every env.step call. Set this to False for efficiency reasons, if image
190 | observations are not required.
191 | """
192 | if env_name is None:
193 | env_name = env_meta["env_name"]
194 | env_type = get_env_type(env_meta=env_meta)
195 | env_kwargs = env_meta["env_kwargs"]
196 |
197 | env = create_env(
198 | env_type=env_type,
199 | env_name=env_name,
200 | render=render,
201 | render_offscreen=render_offscreen,
202 | use_image_obs=use_image_obs,
203 | env_meta=env_meta,
204 | **env_kwargs,
205 | )
206 | return env
207 |
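# Example usage (an illustrative sketch; assumes the dataset hdf5 stores env
# metadata that can be read with FileUtils.get_env_metadata_from_dataset):
#
#   env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path="demo.hdf5")
#   env = create_env_from_metadata(env_meta=env_meta, render=False)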
208 |
209 | def create_env_for_data_processing(
210 | env_meta,
211 | camera_names,
212 | camera_height,
213 | camera_width,
214 | reward_shaping,
215 | ):
216 | """
217 | Creates environment for processing dataset observations and rewards.
218 |
219 | Args:
220 | env_meta (dict): environment metadata, which should be loaded from demonstration
221 | hdf5 with @FileUtils.get_env_metadata_from_dataset or from checkpoint (see
222 | @FileUtils.env_from_checkpoint). Contains 3 keys:
223 |
224 | :`'env_name'`: name of environment
225 | :`'type'`: type of environment, should be a value in EB.EnvType
226 | :`'env_kwargs'`: dictionary of keyword arguments to pass to environment constructor
227 |
228 |         camera_names (list of str): list of camera names that correspond to image observations
229 |
230 | camera_height (int): camera height for all cameras
231 |
232 | camera_width (int): camera width for all cameras
233 |
234 | reward_shaping (bool): if True, use shaped environment rewards, else use sparse task completion rewards
235 | """
236 | env_name = env_meta["env_name"]
237 | env_type = get_env_type(env_meta=env_meta)
238 | env_kwargs = env_meta["env_kwargs"]
239 | env_class = get_env_class(env_meta=env_meta)
240 |
241 | # remove possibly redundant values in kwargs
242 | env_kwargs = deepcopy(env_kwargs)
243 | env_kwargs.pop("env_name", None)
244 | env_kwargs.pop("camera_names", None)
245 | env_kwargs.pop("camera_height", None)
246 | env_kwargs.pop("camera_width", None)
247 | env_kwargs.pop("reward_shaping", None)
248 |
249 | return env_class.create_for_data_processing(
250 | env_name=env_name,
251 | camera_names=camera_names,
252 | camera_height=camera_height,
253 | camera_width=camera_width,
254 | reward_shaping=reward_shaping,
255 | **env_kwargs,
256 | )
257 |
--------------------------------------------------------------------------------
/optimus/config/bc_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 | """
5 | Config for BC algorithm. Taken from Ajay Mandlekar's private version of robomimic.
6 | """
7 |
8 | from robomimic.config.base_config import BaseConfig
9 |
10 |
11 | class BCConfig(BaseConfig):
12 | ALGO_NAME = "bc"
13 |
14 | def algo_config(self):
15 | """
16 | This function populates the `config.algo` attribute of the config, and is given to the
17 | `Algo` subclass (see `algo/algo.py`) for each algorithm through the `algo_config`
18 | argument to the constructor. Any parameter that an algorithm needs to determine its
19 | training and test-time behavior should be populated here.
20 | """
21 |
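        # Note: values set below are defaults. An experiment json (e.g. the
        # bc_transformer.json files under exps/local/robosuite/) overrides them
        # key-by-key at load time, so "algo": {"gmm": {"enabled": true}} flips
        # self.algo.gmm.enabled on.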
22 | # optimization parameters
23 | self.algo.optim_params.policy.learning_rate.initial = 1e-4 # policy learning rate
24 | self.algo.optim_params.policy.learning_rate.decay_factor = (
25 | 0.1 # factor to decay LR by (if epoch schedule non-empty)
26 | )
27 | self.algo.optim_params.policy.learning_rate.epoch_schedule = (
28 | []
29 | ) # epochs where LR decay occurs
30 | self.algo.optim_params.policy.regularization.L2 = 0.00 # L2 regularization strength
31 | self.algo.optim_params.policy.learning_rate.betas = (0.9, 0.99)
32 | self.algo.optim_params.policy.learning_rate.lr_scheduler_interval = "epoch"
33 |
34 | # loss weights
35 | self.algo.loss.l2_weight = 1.0 # L2 loss weight
36 | self.algo.loss.l1_weight = 0.0 # L1 loss weight
37 | self.algo.loss.cos_weight = 0.0 # cosine loss weight
38 |
39 | # MLP network architecture (layers after observation encoder and RNN, if present)
40 | self.algo.actor_layer_dims = (1024, 1024)
41 |
42 | # stochastic Gaussian policy settings
43 | self.algo.gaussian.enabled = False # whether to train a Gaussian policy
44 | self.algo.gaussian.fixed_std = False # whether to train std output or keep it constant
45 | self.algo.gaussian.init_std = 0.1 # initial standard deviation (or constant)
46 | self.algo.gaussian.min_std = 0.01 # minimum std output from network
47 | self.algo.gaussian.std_activation = (
48 | "softplus" # activation to use for std output from policy net
49 | )
50 | self.algo.gaussian.low_noise_eval = True # low-std at test-time
51 |
52 | # stochastic GMM policy settings
53 | self.algo.gmm.enabled = False # whether to train a GMM policy
54 | self.algo.gmm.num_modes = 5 # number of GMM modes
55 | self.algo.gmm.min_std = 0.0001 # minimum std output from network
56 | self.algo.gmm.std_activation = (
57 | "softplus" # activation to use for std output from policy net
58 | )
59 | self.algo.gmm.low_noise_eval = True # low-std at test-time
60 |
61 | # stochastic VAE policy settings
62 | self.algo.vae.enabled = False # whether to train a VAE policy
63 | self.algo.vae.latent_dim = (
64 |             14  # VAE latent dimension - set to twice the dimensionality of the action space
65 | )
66 | self.algo.vae.latent_clip = None # clip latent space when decoding (set to None to disable)
67 | self.algo.vae.kl_weight = (
68 | 1.0 # beta-VAE weight to scale KL loss relative to reconstruction loss in ELBO
69 | )
70 |
71 | # VAE decoder settings
72 | self.algo.vae.decoder.is_conditioned = (
73 | True # whether decoder should condition on observation
74 | )
75 | self.algo.vae.decoder.reconstruction_sum_across_elements = (
76 | False # sum instead of mean for reconstruction loss
77 | )
78 |
79 | # VAE prior settings
80 | self.algo.vae.prior.learn = False # learn Gaussian / GMM prior instead of N(0, 1)
81 | self.algo.vae.prior.is_conditioned = False # whether to condition prior on observations
82 | self.algo.vae.prior.use_gmm = False # whether to use GMM prior
83 | self.algo.vae.prior.gmm_num_modes = 10 # number of GMM modes
84 | self.algo.vae.prior.gmm_learn_weights = False # whether to learn GMM weights
85 | self.algo.vae.prior.use_categorical = False # whether to use categorical prior
86 | self.algo.vae.prior.categorical_dim = (
87 | 10 # the number of categorical classes for each latent dimension
88 | )
89 | self.algo.vae.prior.categorical_gumbel_softmax_hard = (
90 | False # use hard selection in forward pass
91 | )
92 | self.algo.vae.prior.categorical_init_temp = 1.0 # initial gumbel-softmax temp
93 | self.algo.vae.prior.categorical_temp_anneal_step = 0.001 # linear temp annealing rate
94 | self.algo.vae.prior.categorical_min_temp = 0.3 # lowest gumbel-softmax temp
95 |
96 | self.algo.vae.encoder_layer_dims = (300, 400) # encoder MLP layer dimensions
97 | self.algo.vae.decoder_layer_dims = (300, 400) # decoder MLP layer dimensions
98 | self.algo.vae.prior_layer_dims = (
99 | 300,
100 | 400,
101 | ) # prior MLP layer dimensions (if learning conditioned prior)
102 |
103 | # RNN policy settings
104 | self.algo.rnn.enabled = False # whether to train RNN policy
105 | self.algo.rnn.horizon = 10 # unroll length for RNN - should usually match train.seq_length
106 | self.algo.rnn.hidden_dim = 400 # hidden dimension size
107 | self.algo.rnn.rnn_type = "LSTM" # rnn type - one of "LSTM" or "GRU"
108 | self.algo.rnn.num_layers = 2 # number of RNN layers that are stacked
109 | self.algo.rnn.open_loop = False # if True, action predictions are only based on a single observation (not sequence)
110 | self.algo.rnn.kwargs.bidirectional = False # rnn kwargs
111 | self.algo.rnn.kwargs.do_not_lock_keys()
112 |
113 | # Transformer policy settings
114 | self.algo.transformer.enabled = False # whether to train transformer policy
115 |         self.algo.transformer.context_length = 64  # length of (s, a) sequences to feed to transformer - should usually match train.frame_stack
116 | self.algo.transformer.embed_dim = 256 # dimension for embeddings used by transformer
117 | self.algo.transformer.num_layers = 6 # number of transformer blocks to stack
118 | self.algo.transformer.num_heads = 8 # number of attention heads for each transformer block (should divide embed_dim evenly)
119 | self.algo.transformer.embedding_dropout = (
120 | 0.1 # dropout probability for embedding inputs in transformer
121 | )
122 | self.algo.transformer.block_attention_dropout = (
123 | 0.1 # dropout probability for attention outputs for each transformer block
124 | )
125 | self.algo.transformer.block_output_dropout = (
126 | 0.1 # dropout probability for final outputs for each transformer block
127 | )
128 | self.algo.transformer.condition_on_actions = False # whether to condition on the sequence of past actions in addition to the observation sequence
129 | self.algo.transformer.predict_obs = (
130 | False # whether to predict observations in the output sequences as well
131 | )
132 | self.algo.transformer.mask_inputs = False # whether to use bert style input masking
133 |         self.algo.transformer.relative_timestep = True  # if True, timesteps range from 0 to context_length - 1; if False, use the absolute position in the trajectory
134 |         self.algo.transformer.euclidean_distance_timestep = False  # if True, timesteps are based on the cumulative distance traveled by the end effector; otherwise, integer timesteps are used
135 | self.algo.transformer.max_timestep = (
136 | 1250 # for the nn.embedding layer, must know the maximal timestep value
137 | )
138 | self.algo.transformer.open_loop_predictions = (
139 | False # if true don't run transformer at every step, execute a set of predicted actions
140 | )
141 | self.algo.transformer.sinusoidal_embedding = (
142 | False # if True, use standard positional encodings (sin/cos)
143 | )
144 | self.algo.transformer.obs_noise_scale = (
145 | 0.05 # amount of noise to add to the observations during training
146 | )
147 | self.algo.transformer.add_noise_to_train_obs = (
148 | False # if true add noise to the observations during training
149 | )
150 | self.algo.transformer.use_custom_transformer_block = (
151 | True # if True, use custom transformer block
152 | )
153 | self.algo.transformer.optimizer_type = "adamw" # toggle the type of optimizer to use
154 | self.algo.transformer.lr_scheduler_type = "linear" # toggle the type of lr_scheduler to use
155 | self.algo.transformer.lr_warmup = (
156 | False # if True, warmup the learning rate from some small value
157 | )
158 | self.algo.transformer.activation = (
159 | "gelu" # activation function for MLP in Transformer Block
160 | )
161 | self.algo.transformer.num_open_loop_actions_to_execute = (
162 | 10 # number of actions to execute in open loop
163 | )
164 | self.algo.transformer.supervise_all_steps = (
165 | False # if true, supervise all intermediate actions, otherwise only final one
166 | )
167 | self.algo.transformer.nn_parameter_for_timesteps = (
168 | True # if true, use nn.Parameter otherwise use nn.Embedding
169 | )
170 | self.algo.transformer.num_task_ids = 1 # number of tasks we are training with
171 | self.algo.transformer.task_id_embed_dim = 0 # dimension of the task id embedding
172 | self.algo.transformer.language_enabled = False # if true, condition on language embeddings
173 | self.algo.transformer.language_embedding = (
174 | "raw" # string denoting the language embedding to use
175 | )
176 | self.algo.transformer.finetune_language_embedding = (
177 | False # if true, finetune the language embedding
178 | )
179 |         self.algo.transformer.kl_loss_weight = 0  # weight of the KL loss (e.g. 5e-6 to enable)
180 | self.algo.transformer.use_cvae = True
181 | self.algo.transformer.predict_signature = False
182 | self.algo.transformer.layer_dims = (1024, 1024)
183 | self.algo.transformer.latent_dim = 0
184 | self.algo.transformer.prior_use_gmm = True
185 | self.algo.transformer.prior_gmm_num_modes = 10
186 | self.algo.transformer.prior_gmm_learn_weights = True
187 | self.algo.transformer.replan_every_step = False
188 | self.algo.transformer.decoder = False
189 | self.algo.transformer.prior_use_categorical = False
190 | self.algo.transformer.prior_categorical_gumbel_softmax_hard = False
191 | self.algo.transformer.prior_categorical_dim = 10
192 | self.algo.transformer.primitive_type = "none"
193 | self.algo.transformer.reset_context_after_primitive_exec = True
194 | self.algo.transformer.block_drop_path = 0.0
195 | self.algo.transformer.use_cross_attention_conditioning = False
196 | self.algo.transformer.use_alternating_cross_attention_conditioning = False
197 | self.algo.transformer.key_value_from_condition = False
198 | self.algo.transformer.add_primitive_id = False
199 | self.algo.transformer.tokenize_primitive_id = False
200 | self.algo.transformer.channel_condition = False
201 | self.algo.transformer.tokenize_obs_components = False
202 | self.algo.transformer.num_patches_per_image_dim = 1
203 | self.algo.transformer.nets_to_freeze = ()
204 | self.algo.transformer.use_ndp_decoder = False
205 | self.algo.transformer.ndp_decoder_kwargs = None
206 | self.algo.transformer.transformer_type = "gpt"
207 | self.algo.transformer.mega_kwargs = {}
208 |
209 | self.algo.corner_loss.enabled = False
210 |
211 | self.algo.dml.enabled = (
212 | False # if True, use Discretized Mixture of Logistics output distribution
213 | )
214 | self.algo.dml.num_modes = 2 # number of modes in the DML distribution
215 | self.algo.dml.num_classes = (
216 | 256 # number of classes in each dimension of the discretized action space
217 | )
218 | self.algo.dml.log_scale_min = -7.0 # minimum value of the log scale
219 | self.algo.dml.constant_variance = True # if True, use a constant variance
220 |
221 | self.algo.cat.enabled = False # if True, use categorical output distribution
222 | self.algo.cat.num_classes = (
223 | 256 # number of classes in each dimension of the categorical action space
224 | )
225 |
226 | self.env_seed = 0 # seed for environment
227 | self.wandb_project_name = "test"
228 | self.train.fast_dev_run = False # whether to run training in debug mode
229 | self.train.max_grad_norm = None # max gradient norm
230 | self.train.amp_enabled = False # use automatic mixed precision
231 | self.train.num_gpus = 1 # number of gpus to use
232 | self.experiment.rollout.goal_conditioning_enabled = (
233 | False # if True, use goal conditioning wrapper to sample goals for inference
234 | )
235 | self.train.load_next_obs = False # whether or not to load s'
236 | self.experiment.rollout.goal_success_threshold = 0.01 # if the distance from the goal is less than goal_success_threshold, this counts as a success
237 | self.train.pad_frame_stack = True
238 | self.train.pad_seq_length = True
239 | self.train.frame_stack = 1
240 | self.experiment.rollout.is_mujoco = False
241 | self.experiment.rollout.valid_key = "valid"
242 | self.train.ckpt_path = None
243 | self.train.use_swa = False
244 | self.experiment.rollout.parallel_rollouts = True
245 | self.experiment.rollout.select_random_subset = False
246 | self.train.save_ckpt_on_epoch_end = True
247 |
248 | self.algo.dagger.enabled = False # if true, enable dagger support
249 | self.algo.dagger.online_epoch_rate = 50 # how often to collect online data
250 | self.algo.dagger.num_rollouts = 1 # number of rollouts to collect per online epoch
251 | self.algo.dagger.rollout_type = (
252 | "state_error_mixed" # toggle rollout type, can range from policy to closed loop tamp
253 | )
254 | self.algo.dagger.state_error_threshold = (
255 | 0.001 # threshold for state error - triggers TAMP solver action
256 | )
257 | self.algo.dagger.action_error_threshold = (
258 | 0.001 # threshold for action error - triggers TAMP solver action
259 | )
260 | self.algo.dagger.mpc_horizon = (
261 | 100 # MPC horizon for closed loop tamp MPC - useful for running faster
262 | )
263 |
--------------------------------------------------------------------------------
/optimus/scripts/playback_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 | """
5 | A script to visualize dataset trajectories by loading the simulation states
6 | one by one or loading the first state and playing actions back open-loop.
7 | The script can generate videos as well, by rendering simulation frames
8 | during playback. The videos can also be generated using the image observations
9 | in the dataset (this is useful for real-robot datasets) by using the
10 | --use-obs argument.
11 |
12 | Args:
13 | dataset (str): path to hdf5 dataset
14 |
15 | filter_key (str): if provided, use the subset of trajectories
16 | in the file that correspond to this filter key
17 |
18 | n (int): if provided, stop after n trajectories are processed
19 |
20 | use-obs (bool): if flag is provided, visualize trajectories with dataset
21 | image observations instead of simulator
22 |
23 | use-actions (bool): if flag is provided, use open-loop action playback
24 | instead of loading sim states
25 |
26 | render (bool): if flag is provided, use on-screen rendering during playback
27 |
28 | video_path (str): if provided, render trajectories to this video file path
29 |
30 | video_skip (int): render frames to a video every @video_skip steps
31 |
32 | render_image_names (str or [str]): camera name(s) / image observation(s) to
33 | use for rendering on-screen or to video
34 |
35 | first (bool): if flag is provided, use first frame of each episode for playback
36 | instead of the entire episode. Useful for visualizing task initializations.
37 |
38 | Example usage below:
39 |
40 |     # load the simulation states one by one, and render the agentview and wrist cameras to video
41 | python playback_dataset.py --dataset /path/to/dataset.hdf5 \
42 | --render_image_names agentview robot0_eye_in_hand \
43 | --video_path /tmp/playback_dataset.mp4
44 |
45 | # playback the actions in the dataset, and render agentview camera during playback to video
46 | python playback_dataset.py --dataset /path/to/dataset.hdf5 \
47 | --use-actions --render_image_names agentview \
48 | --video_path /tmp/playback_dataset_with_actions.mp4
49 |
50 | # use the observations stored in the dataset to render videos of the dataset trajectories
51 | python playback_dataset.py --dataset /path/to/dataset.hdf5 \
52 | --use-obs --render_image_names agentview_image \
53 | --video_path /tmp/obs_trajectory.mp4
54 |
55 | # visualize initial states in the demonstration data
56 | python playback_dataset.py --dataset /path/to/dataset.hdf5 \
57 | --first --render_image_names agentview \
58 | --video_path /tmp/dataset_task_inits.mp4
59 | """
60 |
61 | import argparse
62 |
63 | import h5py
64 | import imageio
65 | import numpy as np
66 | import robomimic.utils.obs_utils as ObsUtils
67 | from robomimic.envs.env_base import EnvBase, EnvType
68 | from tqdm import tqdm
69 |
70 | import optimus.utils.env_utils as EnvUtils
71 | import optimus.utils.file_utils as FileUtils
72 |
73 | # Define default cameras to use for each env type
74 | DEFAULT_CAMERAS = {
75 | EnvType.ROBOSUITE_TYPE: ["agentview_image"],
76 | EnvType.IG_MOMART_TYPE: ["rgb"],
77 | EnvType.GYM_TYPE: ValueError("No camera names supported for gym type env!"),
78 | }
79 |
80 |
81 | def playback_trajectory_with_env(
82 | env,
83 | initial_state,
84 | states,
85 | actions=None,
86 | render=False,
87 | video_writer=None,
88 | video_skip=5,
89 | camera_names=None,
90 | first=False,
91 | ):
92 | """
93 | Helper function to playback a single trajectory using the simulator environment.
94 | If @actions are not None, it will play them open-loop after loading the initial state.
95 | Otherwise, @states are loaded one by one.
96 |
97 | Args:
98 | env (instance of EnvBase): environment
99 | initial_state (dict): initial simulation state to load
100 | states (list of dict or np.array): array of simulation states to load
101 | actions (np.array): if provided, play actions back open-loop instead of using @states
102 | render (bool): if True, render on-screen
103 | video_writer (imageio writer): video writer
104 | video_skip (int): determines rate at which environment frames are written to video
105 | camera_names (list): determines which camera(s) are used for rendering. Pass more than
106 | one to output a video with multiple camera views concatenated horizontally.
107 | first (bool): if True, only use the first frame of each episode.
108 | """
109 | assert isinstance(env, EnvBase) or isinstance(env.env, EnvBase)
110 |
111 | write_video = video_writer is not None
112 | video_count = 0
113 | assert not (render and write_video)
114 | # load the initial state
115 | env.reset()
116 | env.reset_to(initial_state)
117 | action_playback = actions is not None
118 | traj_len = len(states) if not action_playback else len(actions)
121 |
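    # track per-criterion success flags across the trajectory; env.is_success() returns a dict of booleans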
122 | success = {k: False for k in env.is_success()}
123 |     for i in range(traj_len):
124 |         if action_playback:
125 |             env.step(actions[i])
126 |             if i < traj_len - 1:
127 |                 # check whether the actions deterministically lead to the same recorded states
128 |                 state_playback = env.get_state()["states"]
129 |                 if not np.all(np.equal(states[i + 1], state_playback)):
130 |                     err = np.linalg.norm(states[i + 1] - state_playback)
131 |                     print("warning: playback diverged by {} at step {}".format(err, i))
132 |         else:
133 |             env.reset_to({"states": states[i]})
141 |
142 | # on-screen render
143 | if render:
144 | env.render(mode="human", camera_name=camera_names[0])
145 |
146 | # video render
147 | if write_video:
148 | if video_count % video_skip == 0:
149 | video_img = []
150 | for cam_name in camera_names:
151 | video_img.append(
152 | env.render(
153 | mode="rgb_array",
154 | height=512,
155 | width=512,
156 | camera_name=cam_name,
157 | )
158 | )
159 | video_img = np.concatenate(video_img, axis=1) # concatenate horizontally
160 | video_writer.append_data(video_img)
161 | video_count += 1
162 | exec_success = env.is_success()
163 | success = {k: success[k] or bool(exec_success[k]) for k in exec_success}
164 | if first:
165 | break
166 | return success
167 |
168 |
169 | def playback_trajectory_with_obs(
170 | traj_grp,
171 | video_writer,
172 | video_skip=5,
173 | image_names=None,
174 | first=False,
175 | ):
176 | """
177 | This function reads all "rgb" observations in the dataset trajectory and
178 | writes them into a video.
179 |
180 | Args:
181 | traj_grp (hdf5 file group): hdf5 group which corresponds to the dataset trajectory to playback
182 | video_writer (imageio writer): video writer
183 | video_skip (int): determines rate at which environment frames are written to video
184 | image_names (list): determines which image observations are used for rendering. Pass more than
185 | one to output a video with multiple image observations concatenated horizontally.
186 | first (bool): if True, only use the first frame of each episode.
187 | """
188 | assert (
189 | image_names is not None
190 | ), "error: must specify at least one image observation to use in @image_names"
191 | video_count = 0
192 |
193 | traj_len = traj_grp["actions"].shape[0]
194 | for i in range(traj_len):
195 | if video_count % video_skip == 0:
196 | # concatenate image obs together
197 | im = [traj_grp["obs/{}".format(k)][i] for k in image_names]
198 | frame = np.concatenate(im, axis=1)
199 | video_writer.append_data(frame)
200 | video_count += 1
201 |
202 | if first:
203 | break
204 |
205 |
206 | def playback_dataset(args):
207 | # some arg checking
208 | write_video = args.video_path is not None
209 | assert not (args.render and write_video) # either on-screen or video but not both
210 |
211 | # Auto-fill camera rendering info if not specified
212 | if args.render_image_names is None:
213 | # We fill in the automatic values
214 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
215 | env_type = EnvUtils.get_env_type(env_meta=env_meta)
216 | args.render_image_names = DEFAULT_CAMERAS[env_type]
217 |
218 | if args.render:
219 | # on-screen rendering can only support one camera
220 | assert len(args.render_image_names) == 1
221 |
222 | if args.use_obs:
223 | assert write_video, "playback with observations can only write to video"
224 | assert (
225 | not args.use_actions
226 | ), "playback with observations is offline and does not support action playback"
227 |
228 | # create environment only if not playing back with observations
229 | if not args.use_obs:
230 | # need to make sure ObsUtils knows which observations are images, but it doesn't matter
231 | # for playback since observations are unused. Pass a dummy spec here.
232 | dummy_spec = dict(
233 | obs=dict(
234 | low_dim=["robot0_eef_pos"],
235 | rgb=[],
236 | ),
237 | )
238 | ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)
239 |
240 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=args.dataset)
241 | env = EnvUtils.create_env_from_metadata(
242 | env_meta=env_meta, render=args.render, render_offscreen=write_video
243 | )
244 |
245 | # some operations for playback are env-type-specific
246 | is_robosuite_env = EnvUtils.is_robosuite_env(env_meta)
247 |
248 | f = h5py.File(args.dataset, "r")
249 |
250 | # list of all demonstration episodes (sorted in increasing number order)
251 | if args.filter_key is not None:
252 | print("using filter key: {}".format(args.filter_key))
253 | demos = [elem.decode("utf-8") for elem in np.array(f["mask/{}".format(args.filter_key)])]
254 | else:
255 | demos = list(f["data"].keys())
256 | inds = np.argsort([int(elem[5:]) for elem in demos])
257 | demos = [demos[i] for i in inds]
258 |
259 | # maybe reduce the number of demonstrations to playback
260 | if args.n is not None:
261 | demos = demos[: args.n]
262 |
263 | # maybe dump video
264 | video_writer = None
265 | if write_video:
266 | video_writer = imageio.get_writer(args.video_path, fps=20)
267 |
268 | failed_episodes = []
269 | for ind in tqdm(range(len(demos))):
270 | ep = demos[ind]
271 | print("Playing back episode: {}".format(ep))
272 |
273 | if args.use_obs:
274 | playback_trajectory_with_obs(
275 | traj_grp=f["data/{}".format(ep)],
276 | video_writer=video_writer,
277 | video_skip=args.video_skip,
278 | image_names=args.render_image_names,
279 | first=args.first,
280 | )
281 | continue
282 | # prepare states to reload from
283 | states = f["data/{}/states".format(ep)][()]
284 | initial_state = dict(states=states[0])
285 | if is_robosuite_env:
286 | initial_state["model"] = f["data/{}".format(ep)].attrs["model_file"]
287 | try:
288 | initial_state["init_string"] = f["data/{}".format(ep)].attrs["init_string"]
289 | initial_state["goal_parts_string"] = f["data/{}".format(ep)].attrs[
290 | "goal_parts_string"
291 | ]
292 | print(initial_state["goal_parts_string"])
293 |         except KeyError:
294 |             # these attrs are optional metadata; absence is fine
295 |             print("No init_string or goal_parts_string in file")
296 |
297 | # supply actions if using open-loop action playback
298 | actions = None
299 | if args.use_actions:
300 | actions = f["data/{}/actions".format(ep)][()]
301 | success = playback_trajectory_with_env(
302 | env=env,
303 | initial_state=initial_state,
304 | states=states,
305 | actions=actions,
306 | render=args.render,
307 | video_writer=video_writer,
308 | video_skip=args.video_skip,
309 | camera_names=args.render_image_names,
310 | first=args.first,
311 | )
312 | print("trajectory_length: {}".format(len(states)))
313 | if not success["task"]:
314 | print("Episode {} failed".format(ep))
315 | failed_episodes.append(ep)
316 | print("Failed episodes: {}".format(failed_episodes))
317 |
318 | f.close()
319 | if write_video:
320 | video_writer.close()
321 |
322 |
323 | if __name__ == "__main__":
324 | parser = argparse.ArgumentParser()
325 | parser.add_argument(
326 | "--dataset",
327 | type=str,
328 | help="path to hdf5 dataset",
329 | )
330 | parser.add_argument(
331 | "--filter_key",
332 | type=str,
333 | default=None,
334 | help="(optional) filter key, to select a subset of trajectories in the file",
335 | )
336 |
337 | # number of trajectories to playback. If omitted, playback all of them.
338 | parser.add_argument(
339 | "--n",
340 | type=int,
341 | default=None,
342 | help="(optional) stop after n trajectories are played",
343 | )
344 |
345 | # Use image observations instead of doing playback using the simulator env.
346 | parser.add_argument(
347 | "--use-obs",
348 | action="store_true",
349 | help="visualize trajectories with dataset image observations instead of simulator",
350 | )
351 |
352 | # Playback stored dataset actions open-loop instead of loading from simulation states.
353 | parser.add_argument(
354 | "--use-actions",
355 | action="store_true",
356 | help="use open-loop action playback instead of loading sim states",
357 | )
358 |
359 | # Whether to render playback to screen
360 | parser.add_argument(
361 | "--render",
362 | action="store_true",
363 | help="on-screen rendering",
364 | )
365 |
366 | # Dump a video of the dataset playback to the specified path
367 | parser.add_argument(
368 | "--video_path",
369 | type=str,
370 | default=None,
371 | help="(optional) render trajectories to this video file path",
372 | )
373 |
374 | # How often to write video frames during the playback
375 | parser.add_argument(
376 | "--video_skip",
377 | type=int,
378 | default=5,
379 | help="render frames to video every n steps",
380 | )
381 |
382 | # camera names to render, or image observations to use for writing to video
383 | parser.add_argument(
384 | "--render_image_names",
385 | type=str,
386 | nargs="+",
387 | default=None,
388 |         help="(optional) camera name(s) / image observation(s) to use for rendering on-screen or to video. Default is "
389 |         "None, which corresponds to a predefined camera for each env type",
390 | )
391 |
392 | # Only use the first frame of each episode
393 | parser.add_argument(
394 | "--first",
395 | action="store_true",
396 | help="use first frame of each episode",
397 | )
398 |
399 | args = parser.parse_args()
400 | playback_dataset(args)
401 |
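402 | # Example invocations (illustrative only -- the dataset and output paths below are
403 | # assumptions, not files shipped with this repo):
404 | #
405 | #   # replay stored sim states and dump an agentview video:
406 | #   python optimus/scripts/playback_dataset.py --dataset /path/to/demo.hdf5 \
407 | #       --render_image_names agentview --video_path /tmp/playback.mp4
408 | #
409 | #   # write the dataset's image observations straight to video (no simulator):
410 | #   python optimus/scripts/playback_dataset.py --dataset /path/to/demo.hdf5 \
411 | #       --use-obs --render_image_names agentview_image --video_path /tmp/obs.mp4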
--------------------------------------------------------------------------------
/optimus/envs/wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 | """
5 | A collection of useful environment wrappers. Taken from Ajay Mandlekar's private version of robomimic.
6 | """
7 | import textwrap
8 | from collections import deque
9 |
10 | import h5py
11 | import numpy as np
12 |
13 |
14 | class Wrapper(object):
15 | """
16 | Base class for all environment wrappers in optimus.
17 | """
18 |
19 | def __init__(self, env):
20 | """
21 | Args:
22 | env (EnvBase instance): The environment to wrap.
23 | """
24 | self.env = env
25 |
26 | @classmethod
27 | def class_name(cls):
28 | return cls.__name__
29 |
30 | def _warn_double_wrap(self):
31 | """
32 | Utility function that checks if we're accidentally trying to double wrap an env
33 | Raises:
34 | Exception: [Double wrapping env]
35 | """
36 | env = self.env
37 | while True:
38 | if isinstance(env, Wrapper):
39 | if env.class_name() == self.class_name():
40 | raise Exception(
41 | "Attempted to double wrap with Wrapper: {}".format(self.__class__.__name__)
42 | )
43 | env = env.env
44 | else:
45 | break
46 |
47 | @property
48 | def unwrapped(self):
49 | """
50 | Grabs unwrapped environment
51 |
52 | Returns:
53 | env (EnvBase instance): Unwrapped environment
54 | """
55 | if hasattr(self.env, "unwrapped"):
56 | return self.env.unwrapped
57 | else:
58 | return self.env
59 |
60 | def _to_string(self):
61 | """
62 | Subclasses should override this method to print out info about the
63 | wrapper (such as arguments passed to it).
64 | """
65 | return ""
66 |
67 | def __repr__(self):
68 | """Pretty print environment."""
69 | header = "{}".format(str(self.__class__.__name__))
70 | msg = ""
71 | indent = " " * 4
72 | if self._to_string() != "":
73 | msg += textwrap.indent("\n" + self._to_string(), indent)
74 | msg += textwrap.indent("\nenv={}".format(self.env), indent)
75 | msg = header + "(" + msg + "\n)"
76 | return msg
77 |
78 | # this method is a fallback option on any methods the original env might support
79 | def __getattr__(self, attr):
80 | # using getattr ensures that both __getattribute__ and __getattr__ (fallback) get called
81 | # (see https://stackoverflow.com/questions/3278077/difference-between-getattr-vs-getattribute)
82 | orig_attr = getattr(self.env, attr)
83 | if callable(orig_attr):
84 |
85 | def hooked(*args, **kwargs):
86 | result = orig_attr(*args, **kwargs)
87 | # prevent wrapped_class from becoming unwrapped
88 | if id(result) == id(self.env):
89 | return self
90 | return result
91 |
92 | return hooked
93 | else:
94 | return orig_attr
95 |
96 |
97 | class EvaluateOnDatasetWrapper(Wrapper):
98 | def __init__(
99 | self,
100 | env,
101 | dataset_path=None,
102 | valid_key="valid",
103 | ):
104 | super(EvaluateOnDatasetWrapper, self).__init__(env=env)
105 | self.dataset_path = dataset_path
106 | self.valid_key = valid_key
107 | if dataset_path is not None:
108 | self.load_evaluation_data(dataset_path)
109 |
110 | def sample_eval_episodes(self, num_episodes):
111 | """
112 | Sample a random set of episodes from the set of all episodes.
113 | """
114 | self.eval_indices = np.random.choice(
115 | range(len(self.initial_states)), size=num_episodes, replace=False
116 | )
117 | self.eval_current_index = 0
118 |
119 | def get_num_val_states(self):
120 | return len(self.demos)
121 |
122 | def set_eval_episode(self, eval_index):
123 | self.eval_indices = [eval_index]
124 | self.eval_current_index = 0
125 |
126 | def load_evaluation_data(self, hdf5_path):
127 | # NOTE: for the hierarchical primitive setting, this code will only
128 | # work for resetting the initial state, the signature/command_indices are wrong
129 | # DO NOT use with command_index or signature conditioning
130 | self.hdf5_file = h5py.File(hdf5_path, "r", swmr=True, libver="latest")
131 | filter_key = self.valid_key
132 | self.demos = [
133 | elem.decode("utf-8")
134 | for elem in np.array(self.hdf5_file["mask/{}".format(filter_key)][:])
135 | ]
136 | try:
137 | self.initial_states = [
138 | dict(
139 | states=self.hdf5_file["data/{}/states".format(ep)][()][0],
140 | model=self.hdf5_file["data/{}".format(ep)].attrs["model_file"],
141 | )
142 | for ep in self.demos
143 | ]
144 | self.actions = [self.hdf5_file["data/{}/actions".format(ep)][()] for ep in self.demos]
145 | self.obs = [
146 | {
147 | k: self.hdf5_file["data/{}/obs/{}".format(ep, k)][()]
148 | for k in ["robot0_eef_pos", "robot0_eef_quat"]
149 | }
150 | for ep in self.demos
151 | ]
152 |         except KeyError:
153 | self.initial_states = [
154 | dict(
155 | states={
156 | k_: self.hdf5_file["data/{}/{}/{}".format(ep, "states", k_)][()].astype(
157 | "float32"
158 | )[0]
159 | for k_ in self.hdf5_file["data/{}/{}".format(ep, "states")].keys()
160 | }
161 | )
162 | for ep in self.demos
163 | ]
164 | try:
165 | self.command_indices = [
166 | self.hdf5_file["data/{}/obs/command_index".format(ep)][()] for ep in self.demos
167 | ]
168 |         except KeyError:
169 | self.command_indices = [np.zeros(1).reshape(-1, 1) for ep in self.demos]
170 | try:
171 | self.init_strings = [
172 | self.hdf5_file["data/{}".format(ep)].attrs["init_string"] for ep in self.demos
173 | ]
174 | self.goal_parts_strings = [
175 | self.hdf5_file["data/{}".format(ep)].attrs["goal_parts_string"] for ep in self.demos
176 | ]
177 |         except KeyError:
178 |             print("No init_string or goal_parts_string in file")
179 | self.init_strings = [None for ep in self.demos]
180 | self.goal_parts_strings = [None for ep in self.demos]
181 |
182 | def reset(self):
183 |         """
184 |         Reset the environment. If a dataset was provided, reset to the initial
185 |         simulator state of the next sampled validation episode; otherwise fall
186 |         back to a standard environment reset.
187 |
188 |         Returns:
189 |             obs (dict): observation dictionary after resetting (the dataset's
190 |                 initial state when available)
191 |         """
192 | if self.dataset_path is not None:
193 | print("resetting to a valid state")
194 | self.env.reset()
195 | states = self.initial_states[self.eval_indices[self.eval_current_index]]
196 | if (
197 | self.init_strings[self.eval_indices[self.eval_current_index]] is not None
198 | and self.goal_parts_strings[self.eval_indices[self.eval_current_index]] is not None
199 | ):
200 | states["init_string"] = self.init_strings[
201 | self.eval_indices[self.eval_current_index]
202 | ]
203 | states["goal_parts_string"] = self.goal_parts_strings[
204 | self.eval_indices[self.eval_current_index]
205 | ]
206 | self.eval_current_index += 1
207 | obs = self.reset_to(states)
208 | return obs
209 | else:
210 | obs = self.env.reset()
211 | self.timestep = 0 # always zero regardless of timestep type
212 | return obs
213 |
214 | def step(self, action, **kwargs):
215 | return self.env.step(action, **kwargs)
216 |
217 |
218 | class FrameStackWrapper(Wrapper):
219 | """
220 | Wrapper for frame stacking observations during rollouts. The agent
221 | receives a sequence of past observations instead of a single observation
222 | when it calls @env.reset, @env.reset_to, or @env.step in the rollout loop.
223 | """
224 |
225 | def __init__(
226 | self,
227 | env,
228 | num_frames,
229 | horizon=None,
230 | dataset_path=None,
231 | valid_key="valid",
232 | ):
233 | """
234 | Args:
235 | env (EnvBase instance): The environment to wrap.
236 | num_frames (int): number of past observations (including current observation)
237 | to stack together. Must be greater than 1 (otherwise this wrapper would
238 | be a no-op).
239 | """
240 | assert (
241 | num_frames > 1
242 | ), "error: FrameStackWrapper must have num_frames > 1 but got num_frames of {}".format(
243 | num_frames
244 | )
245 |
246 | super(FrameStackWrapper, self).__init__(env=env)
247 | self.num_frames = num_frames
248 |
249 | # keep track of last @num_frames observations for each obs key
250 | self.obs_history = None
251 | self.horizon = horizon
252 | self.dataset_path = dataset_path
253 | self.valid_key = valid_key
254 | if dataset_path is not None:
255 | self.hdf5_file = h5py.File(dataset_path, "r", swmr=True, libver="latest")
256 | filter_key = self.valid_key
257 | self.demos = [
258 | elem.decode("utf-8")
259 | for elem in np.array(self.hdf5_file["mask/{}".format(filter_key)][:])
260 | ]
261 |
262 | def load_evaluation_data(self, idx):
263 | ep = self.demos[idx]
264 | initial_states = dict(
265 | states=self.hdf5_file["data/{}/states".format(ep)][()][0],
266 | model=self.hdf5_file["data/{}".format(ep)].attrs["model_file"],
267 | )
268 |
269 | try:
270 | init_strings = self.hdf5_file["data/{}".format(ep)].attrs["init_string"]
271 | goal_parts_strings = self.hdf5_file["data/{}".format(ep)].attrs["goal_parts_string"]
272 |         except KeyError:
273 | init_strings = None
274 | goal_parts_strings = None
275 | return (
276 | initial_states,
277 | init_strings,
278 | goal_parts_strings,
279 | )
280 |
281 | def _get_initial_obs_history(self, init_obs):
282 | """
283 | Helper method to get observation history from the initial observation, by
284 | repeating it.
285 |
286 | Returns:
287 | obs_history (dict): a deque for each observation key, with an extra
288 | leading dimension of 1 for each key (for easy concatenation later)
289 | """
290 | obs_history = {}
291 | for k in init_obs:
292 | obs_history[k] = deque(
293 | [init_obs[k][None] for _ in range(self.num_frames)],
294 | maxlen=self.num_frames,
295 | )
296 | return obs_history
297 |
298 | def _get_stacked_obs_from_history(self):
299 | """
300 | Helper method to convert internal variable @self.obs_history to a
301 | stacked observation where each key is a numpy array with leading dimension
302 | @self.num_frames.
303 | """
304 | # concatenate all frames per key so we return a numpy array per key
305 | return {k: np.concatenate(self.obs_history[k], axis=0) for k in self.obs_history}
306 |
307 | def update_obs(self, obs, action=None, reset=False):
308 | obs["timesteps"] = np.array([self.timestep])
309 | if reset:
310 | obs["actions"] = np.zeros(self.env.action_dimension)
311 | else:
312 | self.timestep += 1
313 | obs["actions"] = action[: self.env.action_dimension]
314 |
315 | def sample_eval_episodes(self, num_episodes):
316 | """
317 | Sample a random set of episodes from the set of all episodes.
318 | """
319 | self.eval_indices = np.random.choice(
320 | range(len(self.demos)), size=num_episodes, replace=False
321 | )
322 | self.eval_current_index = 0
323 |
324 | def get_num_val_states(self):
325 | return len(self.demos)
326 |
327 | def set_eval_episode(self, eval_index):
328 | self.eval_indices = [eval_index]
329 | self.eval_current_index = 0
330 |
331 | def reset(self, use_eval_indices=True):
332 | """
333 | Modify to return frame stacked observation which is @self.num_frames copies of
334 | the initial observation.
335 |
336 | Returns:
337 | obs_stacked (dict): each observation key in original observation now has
338 | leading shape @self.num_frames and consists of the previous @self.num_frames
339 | observations
340 | """
341 | if self.dataset_path is not None and use_eval_indices:
342 | print("resetting to a valid state")
343 | self.env.reset()
344 | (
345 | states,
346 | init_string,
347 | goal_parts_string,
348 | ) = self.load_evaluation_data(self.eval_indices[self.eval_current_index])
349 | if init_string is not None and goal_parts_string is not None:
350 | states["init_string"] = init_string
351 | states["goal_parts_string"] = goal_parts_string
352 | self.eval_current_index += 1
353 | obs = self.reset_to(states)
354 | return obs
355 | else:
356 | obs = self.env.reset()
357 | self.timestep = 0 # always zero regardless of timestep type
358 | self.update_obs(obs, reset=True)
359 | self.obs_history = self._get_initial_obs_history(init_obs=obs)
360 | return self._get_stacked_obs_from_history()
361 |
362 | def reset_to(self, state):
363 | """
364 | Modify to return frame stacked observation which is @self.num_frames copies of
365 | the initial observation.
366 |
367 | Returns:
368 | obs_stacked (dict): each observation key in original observation now has
369 | leading shape @self.num_frames and consists of the previous @self.num_frames
370 | observations
371 | """
372 | obs = self.env.reset_to(state)
373 | self.timestep = 0 # always zero regardless of timestep type
374 | self.update_obs(obs, reset=True)
375 | self.obs_history = self._get_initial_obs_history(init_obs=obs)
376 | return self._get_stacked_obs_from_history()
377 |
378 | def step(self, action, **kwargs):
379 | """
380 | Modify to update the internal frame history and return frame stacked observation,
381 | which will have leading dimension @self.num_frames for each key.
382 |
383 | Args:
384 | action (np.array): action to take
385 |
386 | Returns:
387 | obs_stacked (dict): each observation key in original observation now has
388 | leading shape @self.num_frames and consists of the previous @self.num_frames
389 | observations
390 | reward (float): reward for this step
391 | done (bool): whether the task is done
392 | info (dict): extra information
393 | """
394 | obs, r, done, info = self.env.step(action, **kwargs)
395 | self.update_obs(obs, action=action, reset=False)
396 | # update frame history
397 | for k in obs:
398 | # make sure to have leading dim of 1 for easy concatenation
399 | self.obs_history[k].append(obs[k][None])
400 | obs_ret = self._get_stacked_obs_from_history()
401 | return obs_ret, r, done, info
402 |
403 | def _to_string(self):
404 | """Info to pretty print."""
405 | return "num_frames={}".format(self.num_frames)
406 |
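407 | # Minimal usage sketch for FrameStackWrapper (annotation only; assumes `env` is a
408 | # robomimic EnvBase instance such as EnvRobosuite):
409 | #
410 | #   env = FrameStackWrapper(env, num_frames=10)
411 | #   obs = env.reset()  # every obs key now has leading shape (10, ...)
412 | #   for _ in range(200):
413 | #       action = np.zeros(env.action_dimension)  # placeholder for a real policy
414 | #       obs, reward, done, info = env.step(action)
415 | #       if done:
416 | #           break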
--------------------------------------------------------------------------------
/optimus/envs/env_robosuite.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 | """
5 | This file contains the robosuite environment wrapper that is used
6 | to provide a standardized environment API for training policies and interacting
7 | with metadata present in datasets.
8 | """
9 | import json
10 | import os
11 | import xml.etree.ElementTree as ET
12 | from copy import deepcopy
13 |
14 | import numpy as np
15 | import robomimic.envs.env_base as EB
16 | import robomimic.utils.obs_utils as ObsUtils
17 | import robosuite
18 |
19 | from optimus.envs.stack import *
20 |
21 |
22 | def postprocess_model_xml(xml_str):
23 | """
24 | This function postprocesses the model.xml collected from a MuJoCo demonstration
25 | in order to make sure that the STL files can be found.
26 |
27 | Args:
28 | xml_str (str): Mujoco sim demonstration XML file as string
29 |
30 | Returns:
31 | str: Post-processed xml file as string
32 | """
33 |
34 | path = os.path.split(robosuite.__file__)[0]
35 | path_split = path.split("/")
36 |
37 | # replace mesh and texture file paths
38 | tree = ET.fromstring(xml_str)
39 | root = tree
40 | asset = root.find("asset")
41 | meshes = asset.findall("mesh")
42 | textures = asset.findall("texture")
43 | all_elements = meshes + textures
44 |
45 | for elem in all_elements:
46 | old_path = elem.get("file")
47 | if old_path is None:
48 | continue
49 | old_path_split = old_path.split("/")
50 | ind = max(
51 | loc for loc, val in enumerate(old_path_split) if val == "robosuite"
52 | ) # last occurrence index
53 | new_path_split = path_split + old_path_split[ind + 1 :]
54 | new_path = "/".join(new_path_split)
55 | elem.set("file", new_path)
56 |
57 | return ET.tostring(root, encoding="utf8").decode("utf8")
58 |
59 |
60 | class EnvRobosuite(EB.EnvBase):
61 | """Wrapper class for robosuite environments (https://github.com/ARISE-Initiative/robosuite)"""
62 |
63 | def __init__(
64 | self,
65 | env_name,
66 | render=False,
67 | render_offscreen=False,
68 | use_image_obs=False,
69 | postprocess_visual_obs=True,
70 | **kwargs,
71 | ):
72 | """
73 | Args:
74 |             env_name (str): name of the robosuite environment to create (passed
75 |                 directly to robosuite.make).
76 |
77 | render (bool): if True, environment supports on-screen rendering
78 |
79 | render_offscreen (bool): if True, environment supports off-screen rendering. This
80 |                 is forced to be True if @use_image_obs is True.
81 |
82 | use_image_obs (bool): if True, environment is expected to render rgb image observations
83 | on every env.step call. Set this to False for efficiency reasons, if image
84 | observations are not required.
85 |
86 | postprocess_visual_obs (bool): if True, postprocess image observations
87 | to prepare for learning. This should only be False when extracting observations
88 | for saving to a dataset (to save space on RGB images for example).
89 | """
90 | self.postprocess_visual_obs = postprocess_visual_obs
91 |
92 | # robosuite version check
93 | self._is_v1 = robosuite.__version__.split(".")[0] == "1"
94 | if self._is_v1:
95 | assert (
96 | int(robosuite.__version__.split(".")[1]) >= 2
97 |             ), "only robosuite v0.3 and v1.2+ are supported"
98 |
99 | kwargs = deepcopy(kwargs)
100 |
101 | # update kwargs based on passed arguments
102 | update_kwargs = dict(
103 | has_renderer=render,
104 | has_offscreen_renderer=(render_offscreen or use_image_obs),
105 | ignore_done=True,
106 | use_object_obs=True,
107 | use_camera_obs=use_image_obs,
108 | camera_depths=False,
109 | )
110 | kwargs.update(update_kwargs)
111 |
112 | if self._is_v1:
113 | if kwargs["has_offscreen_renderer"]:
114 | # ensure that we select the correct GPU device for rendering by testing for EGL rendering
115 | # NOTE: this package should be installed from this link (https://github.com/StanfordVL/egl_probe)
116 | import egl_probe
117 |
118 | valid_gpu_devices = egl_probe.get_available_devices()
119 | if len(valid_gpu_devices) > 0:
120 | kwargs["render_gpu_device_id"] = valid_gpu_devices[0]
121 | else:
122 | # make sure gripper visualization is turned off (we almost always want this for learning)
123 | kwargs["gripper_visualization"] = False
124 | del kwargs["camera_depths"]
125 | kwargs["camera_depth"] = False # rename kwarg
126 |
127 | self._env_name = env_name
128 | self._init_kwargs = deepcopy(kwargs)
129 |         # these keys are metadata-only and not accepted by robosuite.make;
130 |         # pop() avoids a KeyError when a config omits them
131 |         for key in ("deterministic", "max_height", "max_towers", "include_language"):
132 |             kwargs.pop(key, None)
133 | self.env = robosuite.make(self._env_name, **kwargs)
134 | if self._is_v1:
135 | # Make sure joint position observations and eef vel observations are active
136 | for ob_name in self.env.observation_names:
137 | if ("joint_pos" in ob_name) or ("eef_vel" in ob_name):
138 | self.env.modify_observable(
139 | observable_name=ob_name, attribute="active", modifier=True
140 | )
141 |
142 | def step(self, action):
143 | """
144 | Step in the environment with an action.
145 |
146 | Args:
147 | action (np.array): action to take
148 |
149 | Returns:
150 | observation (dict): new observation dictionary
151 | reward (float): reward for this step
152 | done (bool): whether the task is done
153 | info (dict): extra information
154 | """
155 | obs, r, done, info = self.env.step(action)
156 | obs = self.get_observation(obs)
157 | return obs, r, self.is_done(), info
158 |
159 | def reset(self):
160 | """
161 | Reset environment.
162 |
163 | Returns:
164 | observation (dict): initial observation dictionary.
165 | """
166 | di = self.env.reset()
167 | return self.get_observation(di)
168 |
169 | def reset_to(self, state):
170 | """
171 | Reset to a specific simulator state.
172 |
173 | Args:
174 | state (dict): current simulator state that contains one or more of:
175 | - states (np.ndarray): initial state of the mujoco environment
176 | - model (str): mujoco scene xml
177 |
178 | Returns:
179 | observation (dict): observation dictionary after setting the simulator state (only
180 | if "states" is in @state)
181 | """
182 | should_ret = False
183 | if "model" in state:
184 | self.reset()
185 | xml = postprocess_model_xml(state["model"])
186 | self.env.reset_from_xml_string(xml)
187 | self.env.sim.reset()
188 | if not self._is_v1:
189 | # hide teleop visualization after restoring from model
190 | self.env.sim.model.site_rgba[self.env.eef_site_id] = np.array([0.0, 0.0, 0.0, 0.0])
191 | self.env.sim.model.site_rgba[self.env.eef_cylinder_id] = np.array(
192 | [0.0, 0.0, 0.0, 0.0]
193 | )
194 | if "states" in state:
195 | self.env.sim.set_state_from_flattened(state["states"])
196 | self.env.sim.forward()
197 | should_ret = True
198 |
199 | if "goal" in state:
200 | self.set_goal(**state["goal"])
201 | if should_ret:
202 | # only return obs if we've done a forward call - otherwise the observations will be garbage
203 | return self.get_observation()
204 | return None
205 |
206 | def render(self, mode="human", height=None, width=None, camera_name="agentview"):
207 | """
208 | Render from simulation to either an on-screen window or off-screen to RGB array.
209 |
210 | Args:
211 | mode (str): pass "human" for on-screen rendering or "rgb_array" for off-screen rendering
212 | height (int): height of image to render - only used if mode is "rgb_array"
213 | width (int): width of image to render - only used if mode is "rgb_array"
214 | camera_name (str): camera name to use for rendering
215 | """
216 | if mode == "human":
217 | cam_id = self.env.sim.camera_name2id(camera_name)
218 | self.env.viewer.set_camera(cam_id)
219 | return self.env.render()
220 | elif mode == "rgb_array":
221 | return self.env.sim.render(height=height, width=width, camera_name=camera_name)[::-1]
222 | else:
223 | raise NotImplementedError("mode={} is not implemented".format(mode))
224 |
225 | def get_observation(self, di=None):
226 | """
227 | Get current environment observation dictionary.
228 |
229 | Args:
230 | di (dict): current raw observation dictionary from robosuite to wrap and provide
231 | as a dictionary. If not provided, will be queried from robosuite.
232 | """
233 | if di is None:
234 | di = (
235 | self.env._get_observations(force_update=True)
236 | if self._is_v1
237 | else self.env._get_observation()
238 | )
239 | ret = {}
240 | for k in di:
241 | if (k in ObsUtils.OBS_KEYS_TO_MODALITIES) and ObsUtils.key_is_obs_modality(
242 | key=k, obs_modality="rgb"
243 | ):
244 | ret[k] = di[k][::-1]
245 | if self.postprocess_visual_obs:
246 | ret[k] = ObsUtils.process_obs(obs=ret[k], obs_key=k)
247 |
248 | # "object" key contains object information
249 | ret["object"] = np.array(di["object-state"])
250 |
251 | if self._is_v1:
252 | for robot in self.env.robots:
253 | # add all robot-arm-specific observations. Note the (k not in ret) check
254 | # ensures that we don't accidentally add robot wrist images a second time
255 | pf = robot.robot_model.naming_prefix
256 | for k in di:
257 | if k.startswith(pf) and (k not in ret) and (not k.endswith("proprio-state")):
258 | ret[k] = np.array(di[k])
259 | else:
260 | # minimal proprioception for older versions of robosuite
261 | ret["proprio"] = np.array(di["robot-state"])
262 | ret["eef_pos"] = np.array(di["eef_pos"])
263 | ret["eef_quat"] = np.array(di["eef_quat"])
264 | ret["gripper_qpos"] = np.array(di["gripper_qpos"])
265 | return ret
266 |
267 | def get_state(self):
268 | """
269 | Get current environment simulator state as a dictionary. Should be compatible with @reset_to.
270 | """
271 | xml = self.env.sim.model.get_xml() # model xml file
272 | state = np.array(self.env.sim.get_state().flatten()) # simulator state
273 | return dict(model=xml, states=state)
274 |
275 | def get_reward(self):
276 | """
277 | Get current reward.
278 | """
279 | return self.env.reward(None)
280 |
281 | def get_goal(self):
282 | """
283 | Get goal observation. Not all environments support this.
284 | """
285 | return self.get_observation(self.env._get_goal())
286 |
287 | def set_goal(self, **kwargs):
288 | """
289 | Set goal observation with external specification. Not all environments support this.
290 | """
291 | return self.env.set_goal(**kwargs)
292 |
293 | def is_done(self):
294 | """
295 | Check if the task is done (not necessarily successful).
296 | """
297 |
298 | # Robosuite envs always rollout to fixed horizon.
299 | return False
300 |
301 | def is_success(self):
302 | """
303 | Check if the task condition(s) is reached. Should return a dictionary
304 | { str: bool } with at least a "task" key for the overall task success,
305 | and additional optional keys corresponding to other task criteria.
306 | """
307 | succ = self.env._check_success()
308 | if isinstance(succ, dict):
309 | assert "task" in succ
310 | return succ
311 | return {"task": succ}
312 |
313 | @property
314 | def action_dimension(self):
315 | """
316 | Returns dimension of actions (int).
317 | """
318 | return self.env.action_spec[0].shape[0]
319 |
320 | @property
321 | def name(self):
322 | """
323 | Returns name of environment name (str).
324 | """
325 | return self._env_name
326 |
327 | @property
328 | def type(self):
329 | """
330 | Returns environment type (int) for this kind of environment.
331 | This helps identify this env class.
332 | """
333 | return EB.EnvType.ROBOSUITE_TYPE
334 |
335 | def serialize(self):
336 | """
337 | Save all information needed to re-instantiate this environment in a dictionary.
338 | This is the same as @env_meta - environment metadata stored in hdf5 datasets,
339 | and used in utils/env_utils.py.
340 | """
341 | return dict(env_name=self.name, type=self.type, env_kwargs=deepcopy(self._init_kwargs))
342 |
343 | @classmethod
344 | def create_for_data_processing(
345 | cls,
346 | env_name,
347 | camera_names,
348 | camera_height,
349 | camera_width,
350 | reward_shaping,
351 | **kwargs,
352 | ):
353 | """
354 | Create environment for processing datasets, which includes extracting
355 | observations, labeling dense / sparse rewards, and annotating dones in
356 | transitions.
357 |
358 | Args:
359 | env_name (str): name of environment
360 | camera_names (list of str): list of camera names that correspond to image observations
361 | camera_height (int): camera height for all cameras
362 | camera_width (int): camera width for all cameras
363 | reward_shaping (bool): if True, use shaped environment rewards, else use sparse task completion rewards
364 | """
365 | is_v1 = robosuite.__version__.split(".")[0] == "1"
366 | has_camera = len(camera_names) > 0
367 |
368 | new_kwargs = {
369 | "reward_shaping": reward_shaping,
370 | }
371 |
372 | if has_camera:
373 | if is_v1:
374 | new_kwargs["camera_names"] = list(camera_names)
375 | new_kwargs["camera_heights"] = camera_height
376 | new_kwargs["camera_widths"] = camera_width
377 | else:
378 |                 assert len(camera_names) == 1
379 |                 # has_camera is already True in this branch
380 |                 new_kwargs["camera_name"] = camera_names[0]
381 |                 new_kwargs["camera_height"] = camera_height
382 |                 new_kwargs["camera_width"] = camera_width
383 |
384 | kwargs.update(new_kwargs)
385 |
386 | # also initialize obs utils so it knows which modalities are image modalities
387 | image_modalities = list(camera_names)
388 | if is_v1:
389 | image_modalities = ["{}_image".format(cn) for cn in camera_names]
390 | elif has_camera:
391 | # v0.3 only had support for one image, and it was named "rgb"
392 | assert len(image_modalities) == 1
393 | image_modalities = ["rgb"]
394 | obs_modality_specs = {
395 | "obs": {
396 | "low_dim": [], # technically unused, so we don't have to specify all of them
397 | "rgb": image_modalities,
398 | }
399 | }
400 | ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs)
401 |
402 | # note that @postprocess_visual_obs is False since this env's images will be written to a dataset
403 | return cls(
404 | env_name=env_name,
405 | render=False,
406 | render_offscreen=has_camera,
407 | use_image_obs=has_camera,
408 | postprocess_visual_obs=False,
409 | **kwargs,
410 | )
411 |
412 | @property
413 | def rollout_exceptions(self):
414 | """
415 | Return tuple of exceptions to except when doing rollouts. This is useful to ensure
416 | that the entire training run doesn't crash because of a bad policy that causes unstable
417 | simulation computations.
418 | """
419 | return tuple()
420 |
421 | def __repr__(self):
422 | """
423 | Pretty-print env description.
424 | """
425 | return self.name + "\n" + json.dumps(self._init_kwargs, sort_keys=True, indent=4)
426 |
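427 | # Minimal save/restore sketch (annotation only; assumes the robosuite "Stack" task
428 | # and a Panda robot are available):
429 | #
430 | #   ObsUtils.initialize_obs_utils_with_obs_specs(
431 | #       {"obs": {"low_dim": ["robot0_eef_pos"], "rgb": []}}
432 | #   )
433 | #   env = EnvRobosuite("Stack", robots="Panda")
434 | #   obs = env.reset()
435 | #   snapshot = env.get_state()  # {"model": scene xml, "states": flattened sim state}
436 | #   env.step(np.zeros(env.action_dimension))
437 | #   obs = env.reset_to(snapshot)  # restores the exact simulator state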
--------------------------------------------------------------------------------
/optimus/utils/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | """
6 | This file contains Dataset classes that are used by torch dataloaders
7 | to fetch batches from hdf5 files.
8 | """
9 | import numpy as np
10 | import robomimic.utils.log_utils as LogUtils
11 | import robomimic.utils.obs_utils as ObsUtils
12 | import robomimic.utils.tensor_utils as TensorUtils
13 | from robomimic.utils.dataset import SequenceDataset as RobomimicSequenceDataset
14 |
15 |
16 | class SequenceDataset(RobomimicSequenceDataset):
17 | def __init__(
18 | self,
19 | *args,
20 | transformer_enabled=False,
21 | **kwargs,
22 | ):
23 | self.transformer_enabled = transformer_enabled
24 | self.vis_data = dict()
25 | self.ep_to_hdf5_file = None
26 | super().__init__(*args, **kwargs)
27 |
28 | def get_dataset_for_ep(self, ep, key):
29 | """
30 | Helper utility to get a dataset for a specific demonstration.
31 | Takes into account whether the dataset has been loaded into memory.
32 | """
33 | if self.ep_to_hdf5_file is None:
34 | self.ep_to_hdf5_file = {ep: self.hdf5_file for ep in self.demos}
35 | # check if this key should be in memory
36 | key_should_be_in_memory = self.hdf5_cache_mode in ["all", "low_dim"]
37 | if key_should_be_in_memory:
38 | # if key is an observation, it may not be in memory
39 | if "/" in key:
40 | key1, key2 = key.split("/")
41 | assert key1 in ["obs", "next_obs"]
42 | if key2 not in self.obs_keys_in_memory:
43 | key_should_be_in_memory = False
44 |
45 | if key_should_be_in_memory:
46 | # read cache
47 | if "/" in key:
48 | key1, key2 = key.split("/")
49 | assert key1 in ["obs", "next_obs"]
50 | ret = self.hdf5_cache[ep][key1][key2]
51 | else:
52 | ret = self.hdf5_cache[ep][key]
53 | else:
54 | # read from file
55 | hd5key = "data/{}/{}".format(ep, key)
56 | ret = self.ep_to_hdf5_file[ep][hd5key]
57 | return ret
58 |
59 | def get_sequence_from_demo(
60 | self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1
61 | ):
62 | """
63 | Extract a (sub)sequence of data items from a demo given the @keys of the items.
64 |
65 | Args:
66 | demo_id (str): id of the demo, e.g., demo_0
67 | index_in_demo (int): beginning index of the sequence wrt the demo
68 | keys (tuple): list of keys to extract
69 |             num_frames_to_stack (int): number of frames to stack. Seq gets prepended with repeated items if out of range
70 |             seq_length (int): sequence length to extract. Seq gets padded at the end with repeated items if out of range
71 |
72 | Returns:
73 |         seq (dict), pad_mask (np.ndarray): extracted items and a boolean padding mask.
74 | """
75 | assert num_frames_to_stack >= 0
76 | assert seq_length >= 1
77 |
78 | demo_length = self._demo_id_to_demo_length[demo_id]
79 | assert index_in_demo < demo_length
80 |
81 | # determine begin and end of sequence
82 | seq_begin_index = max(0, index_in_demo - num_frames_to_stack)
83 | seq_end_index = min(demo_length, index_in_demo + seq_length)
84 |
85 | # determine sequence padding
86 | seq_begin_pad = max(0, num_frames_to_stack - index_in_demo) # pad for frame stacking
87 | seq_end_pad = max(0, index_in_demo + seq_length - demo_length) # pad for sequence length
88 |
89 | # make sure we are not padding if specified.
90 | if not self.pad_frame_stack:
91 | assert seq_begin_pad == 0
92 | if not self.pad_seq_length:
93 | assert seq_end_pad == 0
94 |
95 | # fetch observation from the dataset file
96 | seq = dict()
97 | for k in keys:
98 | data = self.get_dataset_for_ep(demo_id, k)
99 | seq[k] = data[seq_begin_index:seq_end_index]
100 | seq = TensorUtils.pad_sequence(seq, padding=(seq_begin_pad, seq_end_pad), pad_same=True)
101 | pad_mask = np.array(
102 | [0] * seq_begin_pad + [1] * (seq_end_index - seq_begin_index) + [0] * seq_end_pad
103 | )
104 |         pad_mask = pad_mask[:, None].astype(bool)  # np.bool was removed from numpy
105 |
106 | return seq, pad_mask
107 |
108 | def get_obs_sequence_from_demo(
109 | self,
110 | demo_id,
111 | index_in_demo,
112 | keys,
113 | num_frames_to_stack=0,
114 | seq_length=1,
115 | prefix="obs",
116 | ):
117 | """
118 | Extract a (sub)sequence of observation items from a demo given the @keys of the items.
119 |
120 | Args:
121 | demo_id (str): id of the demo, e.g., demo_0
122 | index_in_demo (int): beginning index of the sequence wrt the demo
123 | keys (tuple): list of keys to extract
124 |             num_frames_to_stack (int): number of frames to stack. Seq gets prepended with repeated items if out of range
125 |             seq_length (int): sequence length to extract. Seq gets padded at the end with repeated items if out of range
126 | prefix (str): one of "obs", "next_obs"
127 |
128 | Returns:
129 | a dictionary of extracted items.
130 | """
131 | obs, pad_mask = self.get_sequence_from_demo(
132 | demo_id,
133 | index_in_demo=index_in_demo,
134 | keys=tuple("{}/{}".format(prefix, k) for k in keys),
135 | num_frames_to_stack=num_frames_to_stack,
136 | seq_length=seq_length,
137 | )
138 | obs = {k.split("/")[1]: obs[k] for k in obs} # strip the prefix
139 | if self.get_pad_mask:
140 | obs["pad_mask"] = pad_mask
141 |
142 | # prepare image observations from dataset
143 | return obs
144 |
145 | def load_dataset_in_memory(self, demo_list, hdf5_file, obs_keys, dataset_keys, load_next_obs):
146 | """
147 | Loads the hdf5 dataset into memory, preserving the structure of the file. Note that this
148 | differs from `self.getitem_cache`, which, if active, actually caches the outputs of the
149 | `getitem` operation.
150 |
151 | Args:
152 | demo_list (list): list of demo keys, e.g., 'demo_0'
153 | hdf5_file (h5py.File): file handle to the hdf5 dataset.
154 | obs_keys (list, tuple): observation keys to fetch, e.g., 'images'
155 | dataset_keys (list, tuple): dataset keys to fetch, e.g., 'actions'
156 | load_next_obs (bool): whether to load next_obs from the dataset
157 |
158 | Returns:
159 | all_data (dict): dictionary of loaded data.
160 | """
161 | all_data = dict()
162 |
163 | print("SequenceDataset: loading dataset into memory...")
164 | obs_keys = [o for o in obs_keys if o != "timesteps" and o != "goal"]
165 |
166 | for ep in LogUtils.custom_tqdm(demo_list):
167 | all_data[ep] = {}
168 | all_data[ep]["attrs"] = {}
169 | all_data[ep]["attrs"]["num_samples"] = hdf5_file["data/{}".format(ep)].attrs[
170 | "num_samples"
171 | ]
172 |
173 | # get other dataset keys
174 | for k in dataset_keys:
175 | if k in hdf5_file["data/{}".format(ep)]:
176 | all_data[ep][k] = hdf5_file["data/{}/{}".format(ep, k)][()].astype("float32")
177 | else:
178 | all_data[ep][k] = np.zeros(
179 | (all_data[ep]["attrs"]["num_samples"], 1), dtype=np.float32
180 | )
181 | # get obs
182 | all_data[ep]["obs"] = {
183 | k: hdf5_file["data/{}/obs/{}".format(ep, k)][()].astype("float32") for k in obs_keys
184 | }
185 |
186 | if self.load_next_obs:
187 | # last block position is given by last elem of next_obs
188 | goal = hdf5_file["data/{}/next_obs/{}".format(ep, "object")][()].astype("float32")[
189 | -1, 7:10
190 | ]
191 | all_data[ep]["obs"]["goal"] = np.repeat(
192 | goal.reshape(1, -1), all_data[ep]["attrs"]["num_samples"], axis=0
193 | )
194 |
195 | if self.transformer_enabled:
196 | all_data[ep]["obs"]["timesteps"] = np.arange(
197 | 0, all_data[ep]["obs"][obs_keys[0]].shape[0]
198 | ).reshape(-1, 1)
199 | if load_next_obs:
200 | all_data[ep]["next_obs"] = {
201 | k: hdf5_file["data/{}/next_obs/{}".format(ep, k)][()].astype("float32")
202 | for k in obs_keys
203 | }
204 | if self.transformer_enabled:
205 | # Doesn't actually matter, won't be used
206 | all_data[ep]["next_obs"]["timesteps"] = np.zeros_like(
207 | all_data[ep]["obs"]["timesteps"]
208 | )
209 | all_data[ep]["next_obs"]["goal"] = np.repeat(
210 | goal.reshape(1, -1), all_data[ep]["attrs"]["num_samples"], axis=0
211 | )
212 | return all_data
213 |
214 | def get_dataset_sequence_from_demo(
215 | self, demo_id, index_in_demo, keys, num_frames_to_stack=0, seq_length=1
216 | ):
217 | """
218 | Extract a (sub)sequence of dataset items from a demo given the @keys of the items (e.g., states, actions).
219 |
220 | Args:
221 | demo_id (str): id of the demo, e.g., demo_0
222 | index_in_demo (int): beginning index of the sequence wrt the demo
223 | keys (tuple): list of keys to extract
224 |             num_frames_to_stack (int): number of frames to stack. Seq gets prepended with repeated items if out of range
225 |             seq_length (int): sequence length to extract. Seq gets padded at the end with repeated items if out of range
226 |
227 | Returns:
228 | a dictionary of extracted items.
229 | """
230 | data, pad_mask = self.get_sequence_from_demo(
231 | demo_id,
232 | index_in_demo=index_in_demo,
233 | keys=keys,
234 |             num_frames_to_stack=num_frames_to_stack,
235 | seq_length=seq_length,
236 | )
237 | if self.get_pad_mask:
238 | data["pad_mask"] = pad_mask
239 | return data
240 |
241 | def get_item(self, index):
242 | """
243 | Main implementation of getitem when not using cache.
244 | """
245 |
246 | demo_id = self._index_to_demo_id[index]
247 | demo_start_index = self._demo_id_to_start_indices[demo_id]
248 | demo_length = self._demo_id_to_demo_length[demo_id]
249 |
250 | # start at offset index if not padding for frame stacking
251 | demo_index_offset = 0 if self.pad_frame_stack else (self.n_frame_stack - 1)
252 | index_in_demo = index - demo_start_index + demo_index_offset
253 |
254 | # end at offset index if not padding for seq length
255 | demo_length_offset = 0 if self.pad_seq_length else (self.seq_length - 1)
256 | end_index_in_demo = demo_length - demo_length_offset
257 |
258 | keys = [*self.dataset_keys]
259 | meta = self.get_dataset_sequence_from_demo(
260 | demo_id,
261 | index_in_demo=index_in_demo,
262 | keys=keys,
263 | num_frames_to_stack=self.n_frame_stack - 1,
264 | seq_length=self.seq_length,
265 | )
266 |
267 | # determine goal index
268 | goal_index = None
269 | if self.goal_mode == "last":
270 | goal_index = end_index_in_demo - 1
271 |
272 | meta["obs"] = self.get_obs_sequence_from_demo(
273 | demo_id,
274 | index_in_demo=index_in_demo,
275 | keys=self.obs_keys,
276 | num_frames_to_stack=self.n_frame_stack - 1,
277 | seq_length=self.seq_length,
278 | prefix="obs",
279 | )
280 | if self.hdf5_normalize_obs:
281 | meta["obs"] = ObsUtils.normalize_obs(
282 | meta["obs"], obs_normalization_stats=self.obs_normalization_stats
283 | )
284 |
285 | if self.load_next_obs:
286 | meta["next_obs"] = self.get_obs_sequence_from_demo(
287 | demo_id,
288 | index_in_demo=index_in_demo,
289 | keys=self.obs_keys,
290 | num_frames_to_stack=self.n_frame_stack - 1,
291 | seq_length=self.seq_length,
292 | prefix="next_obs",
293 | )
294 | if self.hdf5_normalize_obs:
295 | meta["next_obs"] = ObsUtils.normalize_obs(
296 | meta["next_obs"],
297 | obs_normalization_stats=self.obs_normalization_stats,
298 | )
299 |
300 | if goal_index is not None:
301 | goal = self.get_obs_sequence_from_demo(
302 | demo_id,
303 | index_in_demo=goal_index,
304 | keys=self.obs_keys,
305 | num_frames_to_stack=0,
306 | seq_length=1,
307 | prefix="next_obs",
308 | )
309 | if self.hdf5_normalize_obs:
310 | goal = ObsUtils.normalize_obs(
311 | goal, obs_normalization_stats=self.obs_normalization_stats
312 | )
313 | meta["goal_obs"] = {k: goal[k][0] for k in goal} # remove sequence dimension for goal
314 |
315 | return meta
316 |
317 | def update_demo_info(self, demos, online_epoch, data, hdf5_file=None):
318 | """
319 | This function is called during online epochs to update the demo information based
320 | on newly collected demos.
321 | Args:
322 | demos (list): list of demonstration keys to load data.
323 | online_epoch (int): value of the current online epoch
324 | data (dict): dictionary containing newly collected demos
325 | """
326 | # sort demo keys
327 | inds = np.argsort(
328 | [int(elem[5:]) for elem in demos if not (elem in ["env_args", "model_file"])]
329 | )
330 | new_demos = [demos[i] for i in inds]
331 | self.demos.extend(new_demos)
332 |
333 | self.n_demos = len(self.demos)
334 |
335 | self.prev_total_num_sequences = self.total_num_sequences
336 | for new_ep in new_demos:
337 | self.ep_to_hdf5_file[new_ep] = hdf5_file
338 | demo_length = data[new_ep]["num_samples"]
339 | self._demo_id_to_start_indices[new_ep] = self.total_num_sequences
340 | self._demo_id_to_demo_length[new_ep] = demo_length
341 |
342 | num_sequences = demo_length
343 | # determine actual number of sequences taking into account whether to pad for frame_stack and seq_length
344 | if not self.pad_frame_stack:
345 | num_sequences -= self.n_frame_stack - 1
346 | if not self.pad_seq_length:
347 | num_sequences -= self.seq_length - 1
348 |
349 | if self.pad_seq_length:
350 | assert demo_length >= 1 # sequence needs to have at least one sample
351 | num_sequences = max(num_sequences, 1)
352 | else:
353 | assert (
354 | num_sequences >= 1
355 | ) # assume demo_length >= (self.n_frame_stack - 1 + self.seq_length)
356 |
357 | for _ in range(num_sequences):
358 | self._index_to_demo_id[self.total_num_sequences] = new_ep
359 | self.total_num_sequences += 1
360 | return new_demos
361 |
362 | def update_dataset_in_memory(
363 | self,
364 | demo_list,
365 | data,
366 | obs_keys,
367 | dataset_keys,
368 | load_next_obs=False,
369 | online_epoch=0,
370 | ):
371 | """
372 | Loads the newly collected dataset into memory, preserving the structure of the data. Note that this
373 | differs from `self.getitem_cache`, which, if active, actually caches the outputs of the
374 | `getitem` operation.
375 |
376 | Args:
377 | demo_list (list): list of demo keys, e.g., 'demo_0'
378 | data (dict): dictionary containing newly collected demos
379 | obs_keys (list, tuple): observation keys to fetch, e.g., 'images'
380 | dataset_keys (list, tuple): dataset keys to fetch, e.g., 'actions'
381 | load_next_obs (bool): whether to load next_obs from the dataset
382 |
383 | Returns:
384 | all_data (dict): dictionary of loaded data.
385 | """
386 | all_data = dict()
387 | print("SequenceDataset: loading dataset into memory...")
388 | obs_keys = [o for o in obs_keys if o != "timesteps"]
389 | for new_ep in LogUtils.custom_tqdm(demo_list):
390 | all_data[new_ep] = {}
391 | all_data[new_ep]["attrs"] = {}
392 | all_data[new_ep]["attrs"]["num_samples"] = data[new_ep]["num_samples"]
393 |
394 | # get other dataset keys
395 | for k in dataset_keys:
396 | if k in data[new_ep]:
397 | all_data[new_ep][k] = data[new_ep][k].astype("float32")
398 | else:
399 | all_data[new_ep][k] = np.zeros(
400 | (all_data[new_ep]["attrs"]["num_samples"], 1), dtype=np.float32
401 | )
402 | # get obs
403 | all_data[new_ep]["obs"] = {
404 | k: data[new_ep]["obs"][k] for k in obs_keys if k != "timesteps"
405 | }
406 |
407 | for k in all_data[new_ep]["obs"]:
408 | all_data[new_ep]["obs"][k] = all_data[new_ep]["obs"][k].astype("float32")
409 |
410 | if self.transformer_enabled:
411 | all_data[new_ep]["obs"]["timesteps"] = np.arange(
412 | 0, all_data[new_ep]["obs"][obs_keys[0]].shape[0]
413 | ).reshape(-1, 1)
414 |
415 | self.hdf5_cache.update(all_data)
416 |
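417 | # Worked example of the padding arithmetic in get_sequence_from_demo (comments
418 | # only; the numbers are illustrative):
419 | #
420 | #   demo_length = 50, index_in_demo = 0, num_frames_to_stack = 4, seq_length = 10
421 | #   seq_begin_index = max(0, 0 - 4)        -> 0
422 | #   seq_end_index   = min(50, 0 + 10)      -> 10
423 | #   seq_begin_pad   = max(0, 4 - 0)        -> 4 (first frame repeated 4 extra times)
424 | #   seq_end_pad     = max(0, 0 + 10 - 50)  -> 0
425 | #   pad_mask        = [0]*4 + [1]*10, i.e. the repeated frames are masked out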
--------------------------------------------------------------------------------
/optimus/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 | import datetime, json, os, shutil, time  # explicit, though the wildcard below also provides some of these
5 | from collections import OrderedDict
6 |
7 | import imageio
8 | import numpy as np
9 | import robomimic.utils.log_utils as LogUtils
10 | import torch
11 | from robomimic.algo import RolloutPolicy
12 | from robomimic.envs.env_base import EnvBase
13 | from robomimic.utils.train_utils import *
14 |
15 | from optimus.config.base_config import config_factory
16 | from optimus.envs.wrappers import FrameStackWrapper
17 | from optimus.scripts.combine_hdf5 import global_dataset_updates, write_trajectory_to_dataset
18 | from optimus.utils.dataset import SequenceDataset
19 |
20 | import optimus
21 |
22 |
23 | def get_exp_dir(config, auto_remove_exp_dir=False):
24 | """
25 | Create experiment directory from config. If an identical experiment directory
26 | exists and @auto_remove_exp_dir is False (default), the function will prompt
27 | the user on whether to remove and replace it, or keep the existing one and
28 | add a new subdirectory with the new timestamp for the current run.
29 |
30 |     Args:
31 |         config (BaseConfig instance): config object used to determine the experiment directory
32 |         auto_remove_exp_dir (bool): if True, automatically remove an existing experiment folder at the same path
33 |
34 | Returns:
35 | log_dir (str): path to created log directory (sub-folder in experiment directory)
36 | output_dir (str): path to created models directory (sub-folder in experiment directory)
37 |             to store model checkpoints
38 |         video_dir (str): path to video directory (sub-folder) to store rollout videos
39 |         time_str (str): timestamp string used to name the run sub-directories
40 | """
41 | # timestamp for directory names
42 | t_now = time.time()
43 | time_str = datetime.datetime.fromtimestamp(t_now).strftime("%Y%m%d%H%M%S")
44 |
45 | # create directory for where to dump model parameters, tensorboard logs, and videos
46 | base_output_dir = config.train.output_dir
47 | if not os.path.isabs(base_output_dir):
48 | # relative paths are specified relative to optimus module location
49 |         base_output_dir = os.path.join(optimus.__path__[0], "..", base_output_dir)
50 | base_output_dir = os.path.join(base_output_dir, config.experiment.name)
51 | if os.path.exists(base_output_dir):
52 | if not auto_remove_exp_dir:
53 | ans = input(
54 | "WARNING: model directory ({}) already exists! \noverwrite? (y/n)\n".format(
55 | base_output_dir
56 | )
57 | )
58 | else:
59 | ans = "y"
60 | if ans == "y":
61 | print("REMOVING")
62 | shutil.rmtree(base_output_dir)
63 |
64 | # only make model directory if model saving is enabled
65 | output_dir = None
66 | if config.experiment.save.enabled:
67 | output_dir = os.path.join(base_output_dir, time_str, "models")
68 | os.makedirs(output_dir)
69 |
70 | # tensorboard directory
71 | log_dir = os.path.join(base_output_dir, time_str, "logs")
72 | os.makedirs(log_dir)
73 |
74 | # video directory
75 | video_dir = os.path.join(base_output_dir, time_str, "videos")
76 | os.makedirs(video_dir)
77 | return log_dir, output_dir, video_dir, time_str
78 |
79 |
80 | def load_data_for_training(config, obs_keys):
81 | """
82 | Data loading at the start of an algorithm.
83 |
84 | Args:
85 | config (BaseConfig instance): config object
86 | obs_keys (list): list of observation modalities that are required for
87 | training (this will inform the dataloader on what modalities to load)
88 |
89 | Returns:
90 | train_dataset (SequenceDataset instance): train dataset object
91 | valid_dataset (SequenceDataset instance): valid dataset object (only if using validation)
92 | """
93 |
94 | # config can contain an attribute to filter on
95 | filter_by_attribute = config.train.hdf5_filter_key
96 |
97 | # load the dataset into memory
98 |     train_dataset = dataset_factory(config, obs_keys, filter_by_attribute=filter_by_attribute)
99 |     valid_dataset = (
100 |         dataset_factory(config, obs_keys, filter_by_attribute="valid")
101 |         if config.experiment.validate
102 |         else None
103 |     )
104 |
105 | return train_dataset, valid_dataset
106 |
107 |
108 | def dataset_factory(config, obs_keys, filter_by_attribute=None, dataset_path=None):
109 | """
110 | Create a SequenceDataset instance to pass to a torch DataLoader.
111 |
112 | Args:
113 | config (BaseConfig instance): config object
114 |
115 | obs_keys (list): list of observation modalities that are required for
116 | training (this will inform the dataloader on what modalities to load)
117 |
118 | filter_by_attribute (str): if provided, use the provided filter key
119 | to select a subset of demonstration trajectories to load
120 |
121 | dataset_path (str): if provided, the SequenceDataset instance should load
122 | data from this dataset path. Defaults to config.train.data.
123 |
124 | Returns:
125 | dataset (SequenceDataset instance): dataset object
126 | """
127 | if dataset_path is None:
128 | dataset_path = config.train.data
129 |
130 | ds_kwargs = dict(
131 | hdf5_path=dataset_path,
132 | obs_keys=obs_keys,
133 | dataset_keys=config.train.dataset_keys,
134 | load_next_obs=config.train.load_next_obs,
135 | frame_stack=config.train.frame_stack,
136 | seq_length=config.train.seq_length,
137 | pad_frame_stack=config.train.pad_frame_stack,
138 | pad_seq_length=config.train.pad_seq_length,
139 | get_pad_mask=False,
140 | goal_mode=config.train.goal_mode,
141 | hdf5_cache_mode=config.train.hdf5_cache_mode,
142 | hdf5_use_swmr=config.train.hdf5_use_swmr,
143 | hdf5_normalize_obs=config.train.hdf5_normalize_obs,
144 | filter_by_attribute=filter_by_attribute,
145 | transformer_enabled=config.algo.transformer.enabled,
146 | )
147 | dataset = SequenceDataset(**ds_kwargs)
148 |
149 | return dataset
150 |
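# --- usage sketch (annotation, not repo code) ----------------------------------
# dataset_factory is typically driven from the training entry point; the config
# fields below follow robomimic conventions and are assumptions here:
#
#   config = config_factory(algo_name="bc")
#   trainset, validset = load_data_for_training(config, obs_keys=["robot0_eef_pos"])
#   loader = torch.utils.data.DataLoader(
#       trainset,
#       batch_size=config.train.batch_size,
#       shuffle=True,
#       num_workers=config.train.num_data_workers,
#       drop_last=True,
#   )
# --------------------------------------------------------------------------------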
151 |
152 | def run_rollout(
153 | policy,
154 | env,
155 | horizon,
156 | use_goals=False,
157 | render=False,
158 | video_writer=None,
159 | video_skip=5,
160 | terminate_on_success=False,
161 | ):
162 | """
163 | Runs a rollout in an environment with the current network parameters.
164 |
165 | Args:
166 | policy (RolloutPolicy instance): policy to use for rollouts.
167 |
168 | env (EnvBase instance): environment to use for rollouts.
169 |
170 | horizon (int): maximum number of steps to roll the agent out for
171 |
172 | use_goals (bool): if True, agent is goal-conditioned, so provide goal observations from env
173 |
174 | render (bool): if True, render the rollout to the screen
175 |
176 | video_writer (imageio Writer instance): if not None, use video writer object to append frames at
177 | rate given by @video_skip
178 |
179 | video_skip (int): how often to write video frame
180 |
181 | terminate_on_success (bool): if True, terminate episode early as soon as a success is encountered
182 |
183 | Returns:
184 | results (dict): dictionary containing return, success rate, etc.
185 | """
186 | assert isinstance(policy, RolloutPolicy)
187 | assert (
188 | isinstance(env, EnvBase)
189 | or isinstance(env.env, EnvBase)
190 | or isinstance(env, FrameStackWrapper)
191 | )
192 |
193 | policy.start_episode()
194 |
195 | ob_dict = env.reset()
196 | goal_dict = None
197 | if use_goals:
198 | # retrieve goal from the environment
199 | goal_dict = env.get_goal()
200 |
201 |     results, info, step_i = {}, {}, 0  # defaults in case a rollout exception fires before the loop body runs
202 | video_count = 0 # video frame counter
203 |
204 | total_reward = 0.0
205 | success = {k: False for k in env.is_success()} # success metrics
206 | obs_log = {k: [v.reshape(1, -1)] for k, v in ob_dict.items() if not (k.endswith("image"))}
207 | traj = dict(actions=[], states=[], initial_state_dict=env.get_state())
208 | try:
209 | for step_i in range(horizon):
210 | state_dict = env.get_state()
211 | traj["states"].append(state_dict["states"])
212 | # get action from policy
213 | ac = policy(ob=ob_dict, goal=goal_dict)
214 | # play action
215 | ob_dict, r, done, info = env.step(ac)
216 |
217 | for k, v in ob_dict.items():
218 | if not (k.endswith("image")):
219 | obs_log[k].append(v.reshape(1, -1))
220 |
221 | # render to screen
222 | if render:
223 | env.render(mode="human")
224 |
225 | # compute reward
226 | total_reward += r
227 |
228 | cur_success_metrics = env.is_success()
229 | for k in success:
230 | success[k] = success[k] or cur_success_metrics[k]
231 |
232 | # visualization
233 | if video_writer is not None:
234 | if video_count % video_skip == 0:
235 | video_img = []
236 | video_img.append(env.render(mode="rgb_array", height=512, width=512))
237 | video_img = np.concatenate(video_img, axis=1) # concatenate horizontally
238 | video_writer.append_data(video_img)
239 | video_count += 1
240 |
241 | # break if done
242 | if done or (terminate_on_success and success["task"]):
243 | break
244 | state_dict = env.get_state()
245 | traj["states"].append(state_dict["states"])
246 | traj["actions"] = np.array([0]) # just a dummy value
247 | traj["attrs"] = dict(num_samples=len(traj["states"]))
248 | except env.rollout_exceptions as e:
249 | print("WARNING: got rollout exception {}".format(e))
250 |
251 | for k, v in obs_log.items():
252 | obs_log[k] = np.concatenate(v, axis=0)
253 |
254 | results["Return"] = total_reward
255 | results["Horizon"] = step_i + 1
256 | results["Success_Rate"] = float(success["task"])
257 | results["Observations"] = obs_log
258 | for k, v in info.items():
259 | if not (k.endswith("actions")) and not (k.endswith("obs")):
260 | results[k] = v
261 |
262 | # log additional success metrics
263 | for k in success:
264 | if k != "task":
265 | results["{}_Success_Rate".format(k)] = float(success[k])
266 | return results, traj
267 |
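# A minimal sketch of invoking run_rollout directly, outside of rollout_with_stats.
# `policy` and `env` are assumed to be constructed elsewhere (e.g. a RolloutPolicy
# and an EnvUtils-created env); the video path is an illustrative placeholder.
def _example_single_rollout(policy, env):
    """Illustrative only: one rollout with video logging, returning success/return."""
    video_writer = imageio.get_writer("/tmp/rollout.mp4", fps=20)
    results, traj = run_rollout(
        policy=policy,
        env=env,
        horizon=400,
        video_writer=video_writer,
        video_skip=5,
        terminate_on_success=True,
    )
    video_writer.close()
    return results["Success_Rate"], results["Return"]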
268 |
269 | @torch.no_grad()
270 | def rollout_with_stats(
271 | policy,
272 | envs,
273 | horizon,
274 | use_goals=False,
275 | num_episodes=None,
276 | render=False,
277 | video_dir=None,
278 | video_path=None,
279 | epoch=None,
280 | video_skip=5,
281 | terminate_on_success=False,
282 | verbose=False,
283 | rollout_dir=None,
284 | config=None,
285 | ):
286 | """
287 | A helper function used in the train loop to conduct evaluation rollouts per environment
288 | and summarize the results.
289 |
290 | Can specify @video_dir (to dump a video per environment) or @video_path (to dump a single video
291 | for all environments).
292 |
293 | Args:
294 | policy (RolloutPolicy instance): policy to use for rollouts.
295 |
296 | envs (dict): dictionary that maps env_name (str) to EnvBase instance. The policy will
297 | be rolled out in each env.
298 |
299 | horizon (int): maximum number of steps to roll the agent out for
300 |
301 | use_goals (bool): if True, agent is goal-conditioned, so provide goal observations from env
302 |
303 | num_episodes (int): number of rollout episodes per environment
304 |
305 | render (bool): if True, render the rollout to the screen
306 |
307 | video_dir (str): if not None, dump rollout videos to this directory (one per environment)
308 |
309 | video_path (str): if not None, dump a single rollout video for all environments
310 |
311 | epoch (int): epoch number (used for video naming)
312 |
313 | video_skip (int): how often to write video frame
314 |
315 | terminate_on_success (bool): if True, terminate episode early as soon as a success is encountered
316 |         verbose (bool): if True, print results of each rollout
317 |         rollout_dir (str): directory to write per-epoch rollout trajectory hdf5 files to
318 |         config (Config instance): config object, used to look up the number of valid demos per env
319 | Returns:
320 | all_rollout_logs (dict): dictionary of rollout statistics (e.g. return, success rate, ...)
321 | averaged across all rollouts
322 |
323 | video_paths (dict): path to rollout videos for each environment
324 | """
325 | assert isinstance(policy, RolloutPolicy)
326 |
327 | all_rollout_logs = OrderedDict()
328 |
329 | # handle paths and create writers for video writing
330 | assert (video_path is None) or (
331 | video_dir is None
332 | ), "rollout_with_stats: can't specify both video path and dir"
333 | write_video = (video_path is not None) or (video_dir is not None)
334 | video_paths = OrderedDict()
335 | video_writers = OrderedDict()
336 | if video_path is not None:
337 | # a single video is written for all envs
338 | video_paths = {k: video_path for k in envs}
339 | video_writer = imageio.get_writer(video_path, fps=20)
340 | video_writers = {k: video_writer for k in envs}
341 | if video_dir is not None:
342 | # video is written per env
343 | video_str = "_epoch_{}.mp4".format(epoch) if epoch is not None else ".mp4"
344 | video_paths = {k: os.path.join(video_dir, "{}{}".format(k, video_str)) for k in envs}
345 | video_writers = {k: imageio.get_writer(video_paths[k], fps=20) for k in envs}
346 |
347 | for env_name, env in envs.items():
348 | env_video_writer = None
349 | if write_video:
350 |             print("writing video to " + video_paths[env_name])
351 | env_video_writer = video_writers[env_name]
352 |
353 | print(
354 | "rollout: env={}, horizon={}, use_goals={}, num_episodes={}".format(
355 | env.name,
356 | horizon,
357 | use_goals,
358 | num_episodes,
359 | )
360 | )
361 | rollout_logs = []
362 | num_valid_demos = get_num_valid_demos(config, env_name)
363 | num_episodes = min(num_valid_demos, num_episodes)
364 | iterator = range(num_episodes)
365 | if not verbose:
366 | iterator = LogUtils.custom_tqdm(iterator, total=num_episodes)
367 |
368 | num_success = 0
369 | obs_logs = {}
370 | env.sample_eval_episodes(num_episodes)
371 | data_writer = h5py.File(os.path.join(rollout_dir, f"rollout_{epoch}.hdf5"), "w")
372 | data_grp = data_writer.create_group("data")
373 | for ep_i in iterator:
374 | rollout_timestamp = time.time()
375 | rollout_info, traj = run_rollout(
376 | policy=policy,
377 | env=env,
378 | horizon=horizon,
379 | render=render,
380 | use_goals=use_goals,
381 | video_writer=env_video_writer,
382 | video_skip=video_skip,
383 | terminate_on_success=terminate_on_success,
384 | )
385 | rollout_info["time"] = time.time() - rollout_timestamp
386 | obs_logs[ep_i] = {"obs": rollout_info["Observations"]}
387 | del rollout_info["Observations"]
388 | rollout_logs.append(rollout_info)
389 | num_success += rollout_info["Success_Rate"]
390 | if verbose:
391 | print(
392 | "Episode {}, horizon={}, num_success={}".format(ep_i + 1, horizon, num_success)
393 | )
394 | print(json.dumps(rollout_info, sort_keys=True, indent=4))
395 | write_trajectory_to_dataset(
396 | None,
397 | traj,
398 | data_grp,
399 | demo_name=f"demo_{ep_i}",
400 | env_type="mujoco",
401 | )
402 | global_dataset_updates(data_grp, 0, json.dumps(env.serialize(), indent=4))
403 | data_writer.close()
404 | successes = [rollout_log["Success_Rate"] for rollout_log in rollout_logs]
405 | if video_dir is not None:
406 |         # close this env's video writer (next env has its own)
407 | env_video_writer.close()
408 |
409 | # average metric across all episodes
410 | rollout_logs = dict(
411 | (
412 | k,
413 | [
414 | rollout_logs[i][k]
415 | for i in range(len(rollout_logs))
416 | if rollout_logs[i][k] is not None
417 | ],
418 | )
419 | for k in rollout_logs[0]
420 | )
421 | rollout_logs_mean = dict((k, np.mean(v)) for k, v in rollout_logs.items())
422 | rollout_logs_mean["Time_Episode"] = (
423 | np.sum(rollout_logs["time"]) / 60.0
424 | ) # total time taken for rollouts in minutes
425 | all_rollout_logs[env_name] = rollout_logs_mean
426 |
427 | if video_path is not None:
428 | # close video writer that was used for all envs
429 | video_writer.close()
430 |
431 | return all_rollout_logs, video_paths
432 |
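# A sketch of a full evaluation pass with rollout_with_stats. It assumes a loaded
# `config` with robomimic-style experiment.rollout fields and an `envs` dict built
# from EnvUtils; the output directories are illustrative placeholders.
def _example_eval_epoch(policy, envs, config, epoch):
    """Illustrative only: evaluate the policy in every env, one video per env."""
    all_rollout_logs, video_paths = rollout_with_stats(
        policy=policy,
        envs=envs,
        horizon=config.experiment.rollout.horizon,
        num_episodes=config.experiment.rollout.n,
        video_dir="/tmp/videos",
        epoch=epoch,
        terminate_on_success=True,
        rollout_dir="/tmp/rollouts",
        config=config,
    )
    return all_rollout_logs, video_paths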
433 |
434 | def get_num_valid_demos(config, env_name):
435 | valid_key = config.experiment.rollout.valid_key
436 | dataset_path = config.train.data
437 | if dataset_path is not None:
438 | hdf5_file = h5py.File(dataset_path, "r", swmr=True, libver="latest")
439 | filter_key = valid_key
440 | demos = [
441 | elem.decode("utf-8") for elem in np.array(hdf5_file["mask/{}".format(filter_key)][:])
442 | ]
443 | return len(demos)
444 | else:
445 | return 0
446 |
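# get_num_valid_demos counts the demo names stored under "mask/<filter_key>" in the
# dataset hdf5, following the robomimic filter-key convention. A sketch of writing
# such a key (the demo names and the 'valid' key below are hypothetical):
def _example_write_filter_key(dataset_path, demo_names=("demo_0", "demo_1")):
    """Illustrative only: add a 'valid' filter key to an existing dataset."""
    with h5py.File(dataset_path, "a") as f:
        f.create_dataset(
            "mask/valid", data=np.array([name.encode("utf-8") for name in demo_names])
        )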
447 |
448 | def get_config_from_path(config_path):
449 |     ext_cfg = json.load(open(config_path, "r"))
450 | config = config_factory(ext_cfg["algo_name"])
451 | # update config with external json - this will throw errors if
452 | # the external config has keys not present in the base algo config
453 | with config.values_unlocked():
454 | config.update(ext_cfg)
455 | return config
456 |
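# A sketch of loading one of the shipped experiment configs (the relative path below
# assumes the repo's exps/ tree and is illustrative):
def _example_load_config():
    """Illustrative only: load the robosuite stack BC-transformer config."""
    return get_config_from_path("optimus/exps/local/robosuite/stack/bc_transformer.json")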
--------------------------------------------------------------------------------
/optimus/scripts/run_trained_agent_pl.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 | """
5 | The main script for evaluating a policy in an environment, adapted to use PyTorch Lightning and the Optimus codebase.
6 |
7 | Args:
8 |     agent (str): path to saved checkpoint ckpt file
9 |
10 | horizon (int): if provided, override maximum horizon of rollout from the one
11 | in the checkpoint
12 |
13 | env (str): if provided, override name of env from the one in the checkpoint,
14 | and use it for rollouts
15 |
16 | render (bool): if flag is provided, use on-screen rendering during rollouts
17 |
18 | video_path (str): if provided, render trajectories to this video file path
19 |
20 | video_skip (int): render frames to a video every @video_skip steps
21 |
22 | camera_names (str or [str]): camera name(s) to use for rendering on-screen or to video
23 |
24 | dataset_path (str): if provided, an hdf5 file will be written at this path with the
25 | rollout data
26 |
27 | dataset_obs (bool): if flag is provided, and @dataset_path is provided, include
28 | possible high-dimensional observations in output dataset hdf5 file (by default,
29 | observations are excluded and only simulator states are saved).
30 |
31 | seed (int): if provided, set seed for rollouts
32 |
33 | Example usage:
34 |
35 | # Evaluate a policy with 50 rollouts of maximum horizon 400 and save the rollouts to a video.
36 | # Visualize the agentview and wrist cameras during the rollout.
37 |
38 |     python run_trained_agent_pl.py --agent /path/to/model.ckpt \
39 | --n_rollouts 50 --horizon 400 --seed 0 \
40 | --video_path /path/to/output.mp4 \
41 | --camera_names agentview robot0_eye_in_hand
42 |
43 | # Write the 50 agent rollouts to a new dataset hdf5.
44 |
45 |     python run_trained_agent_pl.py --agent /path/to/model.ckpt \
46 | --n_rollouts 50 --horizon 400 --seed 0 \
47 | --dataset_path /path/to/output.hdf5 --dataset_obs
48 |
49 | # Write the 50 agent rollouts to a new dataset hdf5, but exclude the dataset observations
50 | # since they might be high-dimensional (they can be extracted again using the
51 | # dataset_states_to_obs.py script).
52 |
53 |     python run_trained_agent_pl.py --agent /path/to/model.ckpt \
54 | --n_rollouts 50 --horizon 400 --seed 0 \
55 | --dataset_path /path/to/output.hdf5
56 | """
57 | import argparse
58 | import json
59 | import os
60 | import random
61 | from collections import OrderedDict
62 | from copy import deepcopy
63 |
64 | import h5py
65 | import imageio
66 | import numpy as np
67 | from optimus.envs.wrappers import EvaluateOnDatasetWrapper, FrameStackWrapper
68 | import robomimic.utils.obs_utils as ObsUtils
69 | import robomimic.utils.tensor_utils as TensorUtils
70 | import torch
71 | from pytorch_lightning import seed_everything
72 | from robomimic.algo import RolloutPolicy
73 | from robomimic.envs.env_base import EnvType
74 | from tqdm import tqdm
75 |
76 | import optimus.utils.file_utils as FileUtils
77 | import optimus.utils.env_utils as EnvUtils
78 | from optimus.algo import algo_factory
79 | from optimus.config.base_config import config_factory
80 | from optimus.scripts.pl_train import ModelWrapper
81 |
82 | DEFAULT_CAMERAS = {
83 | EnvType.ROBOSUITE_TYPE: ["agentview"],
84 | EnvType.IG_MOMART_TYPE: ["rgb"],
85 | EnvType.GYM_TYPE: ValueError("No camera names supported for gym type env!"),
86 | }
87 |
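# DEFAULT_CAMERAS maps each EnvType to its default render cameras; note the GYM_TYPE
# entry stores a ValueError instance rather than a list, to be raised on lookup.
# A sketch of the intended resolution logic (the helper name is hypothetical):
def _default_cameras_for(env_type):
    """Illustrative only: resolve default cameras, raising for unsupported env types."""
    cameras = DEFAULT_CAMERAS[env_type]
    if isinstance(cameras, Exception):
        raise cameras
    return cameras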
88 |
89 | def rollout(
90 | policy,
91 | env,
92 | horizon,
93 | render=False,
94 | video_writer=None,
95 | video_skip=5,
96 | return_obs=False,
97 | camera_names=None,
98 | ):
99 | """
100 | Helper function to carry out rollouts. Supports on-screen rendering, off-screen rendering to a video,
101 | and returns the rollout trajectory.
102 |
103 | Args:
104 | policy (instance of RolloutPolicy): policy loaded from a checkpoint
105 | env (instance of EnvBase): env loaded from a checkpoint or demonstration metadata
106 | horizon (int): maximum horizon for the rollout
107 | render (bool): whether to render rollout on-screen
108 | video_writer (imageio writer): if provided, use to write rollout to video
109 | video_skip (int): how often to write video frames
110 |         return_obs (bool): if True, return possibly high-dimensional observations along the trajectory.
111 | They are excluded by default because the low-dimensional simulation states should be a minimal
112 | representation of the environment.
113 | camera_names (list): determines which camera(s) are used for rendering. Pass more than
114 | one to output a video with multiple camera views concatenated horizontally.
115 |
116 | Returns:
117 | stats (dict): some statistics for the rollout - such as return, horizon, and task success
118 | traj (dict): dictionary that corresponds to the rollout trajectory
119 | """
120 | assert isinstance(policy, RolloutPolicy)
121 | assert not (render and (video_writer is not None))
122 |
123 | policy.start_episode()
124 | obs = env.reset()
125 | state_dict = env.get_state()
126 |
127 |     # hack necessary for deterministic action playback in robosuite tasks
128 | obs = env.reset_to(state_dict)
129 |
130 | video_count = 0 # video frame counter
131 | total_reward = 0.0
132 | traj = dict(actions=[], rewards=[], dones=[], states=[], initial_state_dict=state_dict)
133 | if return_obs:
134 | # store observations too
135 | traj.update(dict(obs=[], next_obs=[]))
136 | try:
137 | for step_i in range(horizon):
138 | # get action from policy
139 | act = policy(ob=obs)
140 |
141 | # play action
142 | next_obs, r, done, _ = env.step(act)
143 |
144 | # compute reward
145 | total_reward += r
146 | success = env.is_success()["task"]
147 |
148 | # visualization
149 | if render:
150 | env.render(mode="human", camera_name=camera_names[0])
151 | if video_writer is not None:
152 | if video_count % video_skip == 0:
153 | video_img = []
154 | for cam_name in camera_names:
155 | video_img.append(
156 | env.render(
157 | camera_name=cam_name,
158 | mode="rgb_array",
159 | height=512,
160 | width=512,
161 | )
162 | )
163 | video_img = np.concatenate(video_img, axis=1) # concatenate horizontally
164 | video_writer.append_data(video_img)
165 | video_count += 1
166 |
167 | # collect transition
168 | traj["actions"].append(act)
169 | traj["rewards"].append(r)
170 | traj["dones"].append(done)
171 | traj["states"].append(state_dict["states"])
172 | if return_obs:
173 | # Note: We need to "unprocess" the observations to prepare to write them to dataset.
174 | # This includes operations like channel swapping and float to uint8 conversion
175 | # for saving disk space.
176 | traj["obs"].append(ObsUtils.unprocess_obs_dict(obs))
177 | traj["next_obs"].append(ObsUtils.unprocess_obs_dict(next_obs))
178 |
179 | # break if done or if success
180 | if done or success:
181 | break
182 |
183 | # update for next iter
184 | obs = deepcopy(next_obs)
185 | state_dict = env.get_state()
186 |
187 | except env.rollout_exceptions as e:
188 | print("WARNING: got rollout exception {}".format(e))
189 |
190 | stats = dict(Return=total_reward, Horizon=(step_i + 1), Success_Rate=float(success))
191 |
192 | if return_obs:
193 | # convert list of dict to dict of list for obs dictionaries (for convenient writes to hdf5 dataset)
194 | traj["obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["obs"])
195 | traj["next_obs"] = TensorUtils.list_of_flat_dict_to_dict_of_list(traj["next_obs"])
196 |
197 | # list to numpy array
198 | for k in traj:
199 | if k == "initial_state_dict":
200 | continue
201 | if isinstance(traj[k], dict):
202 | for kp in traj[k]:
203 | traj[k][kp] = np.array(traj[k][kp])
204 | else:
205 | traj[k] = np.array(traj[k])
206 | return stats, traj
207 |
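# A minimal sketch of calling rollout() directly, assuming `policy` and `env` were
# restored as in run_trained_agent below; the output path is an illustrative placeholder.
def _example_rollout_to_video(policy, env):
    """Illustrative only: one rollout rendered to mp4 from the agentview camera."""
    writer = imageio.get_writer("/tmp/eval.mp4", fps=20)
    stats, traj = rollout(
        policy=policy,
        env=env,
        horizon=400,
        video_writer=writer,
        video_skip=5,
        camera_names=["agentview"],
    )
    writer.close()
    return stats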
208 |
209 | def run_trained_agent(args):
210 | # some arg checking
211 | write_video = args.video_path is not None
212 | assert not (args.render and write_video) # either on-screen or video but not both
213 |
214 | if args.render:
215 | # on-screen rendering can only support one camera
216 | assert len(args.camera_names) == 1
217 |
218 | # relative path to agent
219 | ckpt_path = args.agent
220 |
221 | if args.resume_dir:
222 | resume_dir = args.resume_dir
223 | else:
224 | resume_dir = os.path.dirname(ckpt_path)[: -len("models")]
225 | config_path = os.path.join(resume_dir, "config.json")
226 | ext_cfg = json.load(open(config_path, "r"))
227 | config = config_factory(ext_cfg["algo_name"])
228 | # update config with external json - this will throw errors if
229 | # the external config has keys not present in the base algo config
230 | with config.values_unlocked():
231 | config.update(ext_cfg)
232 | # read rollout settings
233 | rollout_num_episodes = args.n_rollouts
234 | rollout_horizon = args.horizon
235 | if rollout_horizon is None:
236 | # read horizon from config
237 | rollout_horizon = config.experiment.rollout.horizon
238 |
239 | # read config to set up metadata for observation modalities (e.g. detecting rgb observations)
240 | ObsUtils.initialize_obs_utils_with_config(config)
241 |
242 | # make sure the dataset exists
243 | dataset_path = args.data if args.data else config.train.data
244 | if args.data:
245 | config.train.data = args.data
246 | json.dump(config, open(config_path, "w"), indent=4)
247 | if not os.path.exists(dataset_path):
248 | raise Exception("Dataset at provided path {} not found!".format(dataset_path))
249 |
250 | # load basic metadata from training file
251 | print("\n============= Loaded Environment Metadata =============")
252 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=dataset_path)
253 | shape_meta = FileUtils.get_shape_metadata_from_dataset(
254 | dataset_path=dataset_path, all_obs_keys=config.all_obs_keys, verbose=True
255 | )
256 |
257 | model = algo_factory(
258 | algo_name=config.algo_name,
259 | config=config,
260 | obs_key_shapes=shape_meta["all_shapes"],
261 | ac_dim=shape_meta["ac_dim"],
262 | device=torch.device("cpu"), # default to cpu, pl will move to gpu
263 | )
264 | model.nets["policy"].kl_loss_weight = config.algo.transformer.kl_loss_weight
265 |
266 | if config.experiment.env is not None:
267 | env_meta["env_name"] = config.experiment.env
268 | print("=" * 30 + "\n" + "Replacing Env to {}\n".format(env_meta["env_name"]) + "=" * 30)
269 | # create environment
270 | envs = OrderedDict()
271 | if config.experiment.rollout.enabled:
272 | # create environments for validation runs
273 | env_names = [env_meta["env_name"]]
274 |
275 | if config.experiment.additional_envs is not None:
276 | for name in config.experiment.additional_envs:
277 | env_names.append(name)
278 | for env_name in env_names:
279 | env = EnvUtils.create_env_from_metadata(
280 | env_meta=env_meta,
281 | env_name=env_name,
282 | render=False,
283 | render_offscreen=config.experiment.render_video,
284 | use_image_obs=shape_meta["use_images"],
285 | )
286 | if config.train.frame_stack > 1:
287 | env = FrameStackWrapper(
288 | env,
289 | config.train.frame_stack,
290 | config.experiment.rollout.horizon,
291 | dataset_path=config.train.data,
292 | valid_key=config.experiment.rollout.valid_key,
293 | )
294 | else:
295 | env = EvaluateOnDatasetWrapper(
296 | env,
297 | dataset_path=config.train.data,
298 | valid_key=config.experiment.rollout.valid_key,
299 | )
300 |
301 | envs[env.name] = env
302 | print(envs[env.name])
303 |
304 | print("")
305 |
306 |     # maybe retrieve statistics for normalizing observations
307 | obs_normalization_stats = None
308 | seed_everything(config.train.seed, workers=True)
309 | model = ModelWrapper.load_from_checkpoint(ckpt_path, model=model).cuda()
310 | model.model.nets = model.nets.cuda()
311 | model.model.device = torch.device("cuda")
312 | policy = RolloutPolicy(model.model, obs_normalization_stats=obs_normalization_stats)
313 | # maybe set seed
314 | if args.seed is not None:
315 | np.random.seed(args.seed)
316 | torch.manual_seed(args.seed)
317 |
318 | # maybe create video writer
319 | video_writer = None
320 | if write_video:
321 | video_writer = imageio.get_writer(args.video_path, fps=20)
322 |
323 | # maybe open hdf5 to write rollouts
324 | write_dataset = args.dataset_path is not None
325 | if write_dataset:
326 | data_writer = h5py.File(args.dataset_path, "w")
327 | data_grp = data_writer.create_group("data")
328 | total_samples = 0
329 |
330 | rollout_stats = []
331 | env.sample_eval_episodes(rollout_num_episodes)
332 | for i in tqdm(range(rollout_num_episodes)):
333 | stats, traj = rollout(
334 | policy=policy,
335 | env=env,
336 | horizon=rollout_horizon,
337 | render=args.render,
338 | video_writer=video_writer,
339 | video_skip=args.video_skip,
340 | return_obs=(write_dataset and args.dataset_obs),
341 | camera_names=args.camera_names,
342 | )
343 | rollout_stats.append(stats)
344 |
345 | if write_dataset:
346 | # store transitions
347 | ep_data_grp = data_grp.create_group("demo_{}".format(i))
348 | ep_data_grp.create_dataset("actions", data=np.array(traj["actions"]))
349 | ep_data_grp.create_dataset("states", data=np.array(traj["states"]))
350 | ep_data_grp.create_dataset("rewards", data=np.array(traj["rewards"]))
351 | ep_data_grp.create_dataset("dones", data=np.array(traj["dones"]))
352 | if args.dataset_obs:
353 | for k in traj["obs"]:
354 | ep_data_grp.create_dataset("obs/{}".format(k), data=np.array(traj["obs"][k]))
355 | ep_data_grp.create_dataset(
356 | "next_obs/{}".format(k), data=np.array(traj["next_obs"][k])
357 | )
358 |
359 | # episode metadata
360 | if "model" in traj["initial_state_dict"]:
361 | ep_data_grp.attrs["model_file"] = traj["initial_state_dict"][
362 | "model"
363 | ] # model xml for this episode
364 | ep_data_grp.attrs["num_samples"] = traj["actions"].shape[
365 | 0
366 | ] # number of transitions in this episode
367 | total_samples += traj["actions"].shape[0]
368 |
369 | rollout_stats = TensorUtils.list_of_flat_dict_to_dict_of_list(rollout_stats)
370 | avg_rollout_stats = {k: np.mean(rollout_stats[k]) for k in rollout_stats}
371 | avg_rollout_stats["Num_Success"] = np.sum(rollout_stats["Success_Rate"])
372 | print("Average Rollout Stats")
373 | print(json.dumps(avg_rollout_stats, indent=4))
374 |
375 | if write_video:
376 | video_writer.close()
377 |
378 | if write_dataset:
379 | # global metadata
380 | data_grp.attrs["total"] = total_samples
381 | data_grp.attrs["env_args"] = json.dumps(env.serialize(), indent=4) # environment info
382 | data_writer.close()
383 | print("Wrote dataset trajectories to {}".format(args.dataset_path))
384 |
385 |
386 | if __name__ == "__main__":
387 | parser = argparse.ArgumentParser()
388 |
389 | # Path to trained model
390 | parser.add_argument(
391 | "--agent",
392 | type=str,
393 | required=True,
394 | help="path to saved checkpoint ckpt file",
395 | )
396 |
397 | parser.add_argument(
398 | "--resume_dir",
399 | type=str,
400 |         default=None,
401 |         help="(optional) path to the run directory containing config.json; inferred from --agent if not provided",
402 | )
403 |
404 | # number of rollouts
405 | parser.add_argument(
406 | "--n_rollouts",
407 | type=int,
408 | default=27,
409 | help="number of rollouts",
410 | )
411 |
412 | # maximum horizon of rollout, to override the one stored in the model checkpoint
413 | parser.add_argument(
414 | "--horizon",
415 | type=int,
416 | default=None,
417 | help="(optional) override maximum horizon of rollout from the one in the checkpoint",
418 | )
419 |
420 | # Env Name (to override the one stored in model checkpoint)
421 | parser.add_argument(
422 | "--env",
423 | type=str,
424 | default=None,
425 | help="(optional) override name of env from the one in the checkpoint, and use\
426 | it for rollouts",
427 | )
428 |
429 | # Whether to render rollouts to screen
430 | parser.add_argument(
431 | "--render",
432 | action="store_true",
433 | help="on-screen rendering",
434 | )
435 |
436 | # Dump a video of the rollouts to the specified path
437 | parser.add_argument(
438 | "--video_path",
439 | type=str,
440 | default=None,
441 | help="(optional) render rollouts to this video file path",
442 | )
443 |
444 | # How often to write video frames during the rollout
445 | parser.add_argument(
446 | "--video_skip",
447 | type=int,
448 | default=5,
449 | help="render frames to video every n steps",
450 | )
451 |
452 | # camera names to render
453 | parser.add_argument(
454 | "--camera_names",
455 | type=str,
456 | nargs="+",
457 | default=["agentview"],
458 | help="(optional) camera name(s) to use for rendering on-screen or to video",
459 | )
460 |
461 | # If provided, an hdf5 file will be written with the rollout data
462 | parser.add_argument(
463 | "--dataset_path",
464 | type=str,
465 | default=None,
466 | help="(optional) if provided, an hdf5 file will be written at this path with the rollout data",
467 | )
468 |
469 | # If True and @dataset_path is supplied, will write possibly high-dimensional observations to dataset.
470 | parser.add_argument(
471 | "--dataset_obs",
472 | action="store_true",
473 | help="include possibly high-dimensional observations in output dataset hdf5 file (by default,\
474 | observations are excluded and only simulator states are saved)",
475 | )
476 |
477 | # for seeding before starting rollouts
478 | parser.add_argument(
479 | "--seed",
480 | type=int,
481 | default=None,
482 | help="(optional) set seed for rollouts",
483 | )
484 |
485 | parser.add_argument("--data", type=str, default=None)
486 |
487 | args = parser.parse_args()
488 | run_trained_agent(args)
489 |
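# Example invocation with an explicit resume directory (all paths are illustrative):
#
#     python optimus/scripts/run_trained_agent_pl.py \
#         --agent /path/to/run/models/epoch.ckpt \
#         --resume_dir /path/to/run/ \
#         --n_rollouts 50 --horizon 400 --seed 0 \
#         --video_path /tmp/eval.mp4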
--------------------------------------------------------------------------------