├── media ├── pr_example.gif ├── dp_manifold.gif └── ss_dynamics.gif ├── .gitignore ├── itps ├── common │ ├── utils │ │ ├── io_utils.py │ │ ├── import_utils.py │ │ ├── benchmark.py │ │ └── utils.py │ ├── policies │ │ ├── utils.py │ │ ├── policy_protocol.py │ │ ├── factory.py │ │ ├── act │ │ │ ├── configuration_act.py │ │ │ └── modeling_act.py │ │ ├── normalize.py │ │ ├── diffusion │ │ │ └── configuration_diffusion.py │ │ └── rollout_wrapper.py │ ├── envs │ │ ├── factory.py │ │ └── utils.py │ ├── datasets │ │ ├── sampler.py │ │ ├── factory.py │ │ ├── transforms.py │ │ ├── compute_stats.py │ │ ├── video_utils.py │ │ ├── utils.py │ │ └── lerobot_dataset.py │ └── logger.py └── interact_maze2d.py ├── LICENSE ├── pyproject.toml └── README.md /media/pr_example.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanweiw/itps/HEAD/media/pr_example.gif -------------------------------------------------------------------------------- /media/dp_manifold.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanweiw/itps/HEAD/media/dp_manifold.gif -------------------------------------------------------------------------------- /media/ss_dynamics.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanweiw/itps/HEAD/media/ss_dynamics.gif -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | itps/common/*/__pycache__/ 2 | itps/common/policies/*/__pycache__/ 3 | itps/weights_*/ 4 | itps/weights_*.zip 5 | .DS_Store 6 | .vscode 7 | -------------------------------------------------------------------------------- /itps/common/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
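# Minimal video I/O helper used when saving rollouts. Illustrative usage (the `frames`
# name is hypothetical; it stands for a sequence of HxWxC uint8 arrays):
#     write_video("rollout.mp4", frames, fps=30)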
16 | import warnings 17 | 18 | import imageio 19 | 20 | 21 | def write_video(video_path, stacked_frames, fps): 22 | # Filter out DeprecationWarnings raised from pkg_resources 23 | with warnings.catch_warnings(): 24 | warnings.filterwarnings( 25 | "ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning 26 | ) 27 | imageio.mimsave(video_path, stacked_frames, fps=fps) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yanwei Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | Some of this software is derived from LeRobot, which is subject to the following copyright notice: 24 | 25 | Copyright 2024 Columbia Artificial Intelligence, Robotics Lab, 26 | Tony Z. Zhao, 27 | and The HuggingFace Inc. team. All rights reserved. 28 | Licensed under the Apache License, Version 2.0 (the "License"); 29 | you may not use this file except in compliance with the License. 30 | You may obtain a copy of the License at 31 | 32 | http://www.apache.org/licenses/LICENSE-2.0 33 | 34 | Unless required by applicable law or agreed to in writing, software 35 | distributed under the License is distributed on an "AS IS" BASIS, 36 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 37 | See the License for the specific language governing permissions and 38 | limitations under the License. 39 | 40 | -------------------------------------------------------------------------------- /itps/common/policies/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
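# Shared helpers for policy implementations: populate_queues() keeps fixed-length observation
# deques filled (repeating the first observation until a queue reaches its maxlen, then simply
# appending the latest one), while get_device_from_parameters() and get_dtype_from_parameters()
# read a module's device and dtype from its first parameter.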
16 | import torch 17 | from torch import nn 18 | 19 | 20 | def populate_queues(queues, batch): 21 | for key in batch: 22 | # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the 23 | # queues have the keys they want). 24 | if key not in queues: 25 | continue 26 | if len(queues[key]) != queues[key].maxlen: 27 | # initialize by copying the first observation several times until the queue is full 28 | while len(queues[key]) != queues[key].maxlen: 29 | queues[key].append(batch[key]) 30 | else: 31 | # add latest observation to the queue 32 | queues[key].append(batch[key]) 33 | return queues 34 | 35 | 36 | def get_device_from_parameters(module: nn.Module) -> torch.device: 37 | """Get a module's device by checking one of its parameters. 38 | 39 | Note: assumes that all parameters have the same device 40 | """ 41 | return next(iter(module.parameters())).device 42 | 43 | 44 | def get_dtype_from_parameters(module: nn.Module) -> torch.dtype: 45 | """Get a module's parameter dtype by checking one of its parameters. 46 | 47 | Note: assumes that all parameters have the same dtype. 48 | """ 49 | return next(iter(module.parameters())).dtype 50 | -------------------------------------------------------------------------------- /itps/common/envs/factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import importlib 17 | 18 | import gymnasium as gym 19 | from omegaconf import DictConfig 20 | 21 | 22 | def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv: 23 | """Makes a gym vector environment according to the evaluation config. 24 | 25 | n_envs can be used to override eval.batch_size in the configuration. Must be at least 1. 26 | """ 27 | if n_envs is not None and n_envs < 1: 28 | raise ValueError("`n_envs must be at least 1") 29 | 30 | package_name = f"gym_{cfg.env.name}" 31 | 32 | try: 33 | importlib.import_module(package_name) 34 | except ModuleNotFoundError as e: 35 | print( 36 | f"{package_name} is not installed. 
Please install it with `pip install 'lerobot[{cfg.env.name}]'`" 37 | ) 38 | raise e 39 | 40 | gym_handle = f"{package_name}/{cfg.env.task}" 41 | gym_kwgs = dict(cfg.env.get("gym", {})) 42 | 43 | if cfg.env.get("episode_length"): 44 | gym_kwgs["max_episode_steps"] = cfg.env.episode_length 45 | 46 | # batched version of the env that returns an observation of shape (b, c) 47 | env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv 48 | env = env_cls( 49 | [ 50 | lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs) 51 | for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size) 52 | ] 53 | ) 54 | 55 | return env 56 | -------------------------------------------------------------------------------- /itps/common/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | from typing import Iterator, Union 17 | 18 | import torch 19 | 20 | 21 | class EpisodeAwareSampler: 22 | def __init__( 23 | self, 24 | episode_data_index: dict, 25 | episode_indices_to_use: Union[list, None] = None, 26 | drop_n_first_frames: int = 0, 27 | drop_n_last_frames: int = 0, 28 | shuffle: bool = False, 29 | ): 30 | """Sampler that optionally incorporates episode boundary information. 31 | 32 | Args: 33 | episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode. 34 | episode_indices_to_use: List of episode indices to use. If None, all episodes are used. 35 | Assumes that episodes are indexed from 0 to N-1. 36 | drop_n_first_frames: Number of frames to drop from the start of each episode. 37 | drop_n_last_frames: Number of frames to drop from the end of each episode. 38 | shuffle: Whether to shuffle the indices. 39 | """ 40 | indices = [] 41 | for episode_idx, (start_index, end_index) in enumerate( 42 | zip(episode_data_index["from"], episode_data_index["to"], strict=True) 43 | ): 44 | if episode_indices_to_use is None or episode_idx in episode_indices_to_use: 45 | indices.extend( 46 | range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames) 47 | ) 48 | 49 | self.indices = indices 50 | self.shuffle = shuffle 51 | 52 | def __iter__(self) -> Iterator[int]: 53 | if self.shuffle: 54 | for i in torch.randperm(len(self.indices)): 55 | yield self.indices[i] 56 | else: 57 | for i in self.indices: 58 | yield i 59 | 60 | def __len__(self) -> int: 61 | return len(self.indices) 62 | -------------------------------------------------------------------------------- /itps/common/envs/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import einops 17 | import numpy as np 18 | import torch 19 | from torch import Tensor 20 | 21 | 22 | def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]: 23 | """Convert environment observation to LeRobot format observation. 24 | Args: 25 | observation: Dictionary of observation batches from a Gym vector environment. 26 | Returns: 27 | Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. 28 | """ 29 | # map to expected inputs for the policy 30 | return_observations = {} 31 | if "pixels" in observations: 32 | if isinstance(observations["pixels"], dict): 33 | imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} 34 | else: 35 | imgs = {"observation.image": observations["pixels"]} 36 | 37 | for imgkey, img in imgs.items(): 38 | img = torch.from_numpy(img) 39 | 40 | # sanity check that images are channel last 41 | _, h, w, c = img.shape 42 | assert c < h and c < w, f"expect channel first images, but instead {img.shape}" 43 | 44 | # sanity check that images are uint8 45 | assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" 46 | 47 | # convert to channel first of type float32 in range [0,1] 48 | img = einops.rearrange(img, "b h w c -> b c h w").contiguous() 49 | img = img.type(torch.float32) 50 | img /= 255 51 | 52 | return_observations[imgkey] = img 53 | 54 | if "environment_state" in observations: 55 | return_observations["observation.environment_state"] = torch.from_numpy( 56 | observations["environment_state"] 57 | ).float() 58 | 59 | # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing 60 | # requirement for "agent_pos" 61 | return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float() 62 | return return_observations 63 | -------------------------------------------------------------------------------- /itps/common/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
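# Optional-dependency detection adapted from transformers. Illustrative usage (variable names
# are hypothetical):
#     torch_ok, torch_version = is_package_available("torch", return_version=True)
# The module-level flags at the bottom of this file (e.g. _gym_pusht_available) are computed
# once at import time so callers can branch on them cheaply.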
16 | import importlib 17 | import logging 18 | 19 | 20 | def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: 21 | """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py 22 | Check if the package spec exists and grab its version to avoid importing a local directory. 23 | **Note:** this doesn't work for all packages. 24 | """ 25 | package_exists = importlib.util.find_spec(pkg_name) is not None 26 | package_version = "N/A" 27 | if package_exists: 28 | try: 29 | # Primary method to get the package version 30 | package_version = importlib.metadata.version(pkg_name) 31 | except importlib.metadata.PackageNotFoundError: 32 | # Fallback method: Only for "torch" and versions containing "dev" 33 | if pkg_name == "torch": 34 | try: 35 | package = importlib.import_module(pkg_name) 36 | temp_version = getattr(package, "__version__", "N/A") 37 | # Check if the version contains "dev" 38 | if "dev" in temp_version: 39 | package_version = temp_version 40 | package_exists = True 41 | else: 42 | package_exists = False 43 | except ImportError: 44 | # If the package can't be imported, it's not available 45 | package_exists = False 46 | else: 47 | # For packages other than "torch", don't attempt the fallback and set as not available 48 | package_exists = False 49 | logging.debug(f"Detected {pkg_name} version: {package_version}") 50 | if return_version: 51 | return package_exists, package_version 52 | else: 53 | return package_exists 54 | 55 | 56 | _torch_available, _torch_version = is_package_available("torch", return_version=True) 57 | _gym_xarm_available = is_package_available("gym_xarm") 58 | _gym_aloha_available = is_package_available("gym_aloha") 59 | _gym_pusht_available = is_package_available("gym_pusht") 60 | -------------------------------------------------------------------------------- /itps/common/policies/policy_protocol.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """A protocol that all policies should follow. 17 | 18 | This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes 19 | subclass a base class. 20 | 21 | The protocol structure, method signatures, and docstrings should be used by developers as a reference for 22 | how to implement new policies. 23 | """ 24 | 25 | from typing import Protocol, runtime_checkable 26 | 27 | from torch import Tensor 28 | 29 | 30 | @runtime_checkable 31 | class Policy(Protocol): 32 | """The required interface for implementing a policy. 33 | 34 | We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin. 
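
    Because the protocol is decorated with @runtime_checkable, conformance can be verified at
    runtime with a plain isinstance check, e.g. (illustrative) `isinstance(policy, Policy)`.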
35 | """ 36 | 37 | name: str 38 | 39 | def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None): 40 | """ 41 | Args: 42 | cfg: Policy configuration class instance or None, in which case the default instantiation of the 43 | configuration class is used. 44 | dataset_stats: Dataset statistics to be used for normalization. 45 | """ 46 | 47 | @property 48 | def n_obs_steps(self) -> int: 49 | """TODO(now)""" 50 | 51 | @property 52 | def input_keys(self) -> int: 53 | """TODO(now)""" 54 | 55 | def forward(self, batch: dict[str, Tensor]) -> dict: 56 | """Run the batch through the model and compute the loss for training or validation. 57 | 58 | Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all 59 | other items should be logging-friendly, native Python types. 60 | """ 61 | 62 | def run_inference(self, observation_batch: dict[str, Tensor]) -> Tensor: 63 | """Return a sequence of actions to run in the environment (potentially in batch mode).""" 64 | 65 | 66 | @runtime_checkable 67 | class PolicyWithUpdate(Policy, Protocol): 68 | def update(self): 69 | """An update method that is to be called after a training optimization step. 70 | 71 | Implements an additional updates the model parameters may need (for example, doing an EMA step for a 72 | target model, or incrementing an internal buffer). 73 | """ 74 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "itps" 3 | version = "0.1.0" 4 | description = "Inference-Time Policy Steering" 5 | authors = [ 6 | "Yanwei Wang (Felix) " 7 | ] 8 | repository = "https://github.com/yanweiw/itps" 9 | readme = "README.md" 10 | license = "MIT License" 11 | classifiers=[ 12 | "Intended Audience :: Developers", 13 | "Intended Audience :: Education", 14 | "Intended Audience :: Science/Research", 15 | "Topic :: Software Development :: Build Tools", 16 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 17 | "Programming Language :: Python :: 3.10", 18 | ] 19 | packages = [{include = "itps"}] 20 | 21 | 22 | [tool.poetry.dependencies] 23 | python = ">=3.10,<3.13" 24 | termcolor = ">=2.4.0" 25 | omegaconf = ">=2.3.0" 26 | wandb = ">=0.16.3" 27 | imageio = {extras = ["ffmpeg"], version = ">=2.34.0"} 28 | gdown = ">=5.1.0" 29 | hydra-core = ">=1.3.2" 30 | einops = ">=0.8.0" 31 | pymunk = ">=6.6.0" 32 | zarr = ">=2.17.0" 33 | numba = ">=0.59.0" 34 | torch = "^2.2.1" 35 | opencv-python = ">=4.9.0" 36 | diffusers = "^0.27.2" 37 | torchvision = ">=0.17.1" 38 | h5py = ">=3.10.0" 39 | huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"} 40 | gymnasium = ">=0.29.1" 41 | cmake = ">=3.29.0.1" 42 | 43 | pre-commit = {version = ">=3.7.0", optional = true} 44 | debugpy = {version = ">=1.8.1", optional = true} 45 | pytest = {version = ">=8.1.0", optional = true} 46 | pytest-cov = {version = ">=5.0.0", optional = true} 47 | datasets = "^2.19.0" 48 | imagecodecs = { version = ">=2024.1.1", optional = true } 49 | moviepy = ">=1.0.3" 50 | rerun-sdk = ">=0.15.1" 51 | deepdiff = ">=7.0.1" 52 | scikit-image = {version = "^0.23.2", optional = true} 53 | pandas = {version = "^2.2.2", optional = true} 54 | pytest-mock = {version = "^3.14.0", optional = true} 55 | pygame = ">=2.6.1" 56 | scipy = ">=1.14.1" 57 | contourpy = ">=1.2.1" 58 | cycler = ">=0.12.1" 59 | fonttools = ">=4.53.1" 60 | kiwisolver = ">=1.4.5" 61 | matplotlib = ">=3.9.2" 
62 | pyparsing = ">=3.1.4" 63 | 64 | 65 | [tool.poetry.extras] 66 | dev = ["pre-commit", "debugpy"] 67 | test = ["pytest", "pytest-cov", "pytest-mock"] 68 | video_benchmark = ["scikit-image", "pandas"] 69 | 70 | [tool.ruff] 71 | line-length = 110 72 | target-version = "py310" 73 | exclude = [ 74 | "tests/data", 75 | ".bzr", 76 | ".direnv", 77 | ".eggs", 78 | ".git", 79 | ".git-rewrite", 80 | ".hg", 81 | ".mypy_cache", 82 | ".nox", 83 | ".pants.d", 84 | ".pytype", 85 | ".ruff_cache", 86 | ".svn", 87 | ".tox", 88 | ".venv", 89 | "__pypackages__", 90 | "_build", 91 | "buck-out", 92 | "build", 93 | "dist", 94 | "node_modules", 95 | "venv", 96 | ] 97 | 98 | 99 | [tool.ruff.lint] 100 | select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] 101 | ignore-init-module-imports = true 102 | 103 | 104 | [build-system] 105 | requires = ["poetry-core"] 106 | build-backend = "poetry.core.masonry.api" 107 | -------------------------------------------------------------------------------- /itps/common/utils/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import threading 17 | import time 18 | from contextlib import ContextDecorator 19 | 20 | 21 | class TimeBenchmark(ContextDecorator): 22 | """ 23 | Measures execution time using a context manager or decorator. 24 | 25 | This class supports both context manager and decorator usage, and is thread-safe for multithreaded 26 | environments. 27 | 28 | Args: 29 | print: If True, prints the elapsed time upon exiting the context or completing the function. Defaults 30 | to False. 31 | 32 | Examples: 33 | 34 | Using as a context manager: 35 | 36 | >>> benchmark = TimeBenchmark() 37 | >>> with benchmark: 38 | ... 
time.sleep(1) 39 | >>> print(f"Block took {benchmark.result:.4f} seconds") 40 | Block took approximately 1.0000 seconds 41 | 42 | Using with multithreading: 43 | 44 | ```python 45 | import threading 46 | 47 | benchmark = TimeBenchmark() 48 | 49 | def context_manager_example(): 50 | with benchmark: 51 | time.sleep(0.01) 52 | print(f"Block took {benchmark.result_ms:.2f} milliseconds") 53 | 54 | threads = [] 55 | for _ in range(3): 56 | t1 = threading.Thread(target=context_manager_example) 57 | threads.append(t1) 58 | 59 | for t in threads: 60 | t.start() 61 | 62 | for t in threads: 63 | t.join() 64 | ``` 65 | Expected output: 66 | Block took approximately 10.00 milliseconds 67 | Block took approximately 10.00 milliseconds 68 | Block took approximately 10.00 milliseconds 69 | """ 70 | 71 | def __init__(self, print=False): 72 | self.local = threading.local() 73 | self.print_time = print 74 | 75 | def __enter__(self): 76 | self.local.start_time = time.perf_counter() 77 | return self 78 | 79 | def __exit__(self, *exc): 80 | self.local.end_time = time.perf_counter() 81 | self.local.elapsed_time = self.local.end_time - self.local.start_time 82 | if self.print_time: 83 | print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds") 84 | return False 85 | 86 | @property 87 | def result(self): 88 | return getattr(self.local, "elapsed_time", None) 89 | 90 | @property 91 | def result_ms(self): 92 | return self.result * 1e3 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Inference-Time Policy Steering (ITPS) 2 | 3 | Maze2D benchmark of various sampling methods with sketch input from the paper [Inference-Time Policy Steering through Human Interactions](https://yanweiw.github.io/itps/). 4 | 5 | ## Installation 6 | Clone this repo 7 | ``` 8 | git clone git@github.com:yanweiw/itps.git 9 | cd itps 10 | ``` 11 | Create a virtual environment with Python 3.10 12 | ``` 13 | conda create -y -n itps python=3.10 14 | conda activate itps 15 | ``` 16 | Install ITPS 17 | ``` 18 | pip install -e . 19 | ``` 20 | Download the pre-trained weights for [Action Chunking Transformers](https://drive.google.com/file/d/1kKt__yQpXOzgAGFvfGpBWdtWX_QxWsVK/view?usp=sharing) and [Diffusion Policy](https://drive.google.com/file/d/1efez47zfkXl7HgGDSzW-tagdcPj1p8z2/view?usp=sharing) and put them in the `itps/itps` folder (Be sure to unzip the downloaded zip file). 21 | 22 | ## Visualize pre-trained policies. 23 | 24 | Run ACT or DP unconditionally to explore motion manifolds learned by these pre-trained policies. 25 | ``` 26 | python interact_maze2d.py -p [act, dp] -u 27 | ``` 28 | |Multimodal predictions of DP| 29 | |---------------------------| 30 | |![](media/dp_manifold.gif)| 31 | 32 | 33 | ## Bias sampling with sketch interaction. 34 | 35 | `-ph` - Post-Hoc Ranking 36 | `-op` - Output Perturbation 37 | `-bi` - Biased Initialization 38 | `-gd` - Guided Diffusion 39 | `-ss` - Stochastic Sampling 40 | ``` 41 | python interact_maze2d.py -p [act, dp] [-ph, -bi, -gd, -ss] 42 | ``` 43 | |Post-Hoc Ranking Example| 44 | |---------------------------| 45 | |![](media/pr_example.gif)| 46 | Draw by clicking and dragging the mouse. Re-initialize the agent (red) position by moving the mouse close to it without clicking. 47 | 48 | ## Visualize sampling dynamics. 49 | 50 | Run DP with BI, GD or SS with `-v` option. 
51 | ``` 52 | python interact_maze2d.py -p [act, dp] [-bi, -gd, -ss] -v 53 | ``` 54 | | Stochastic Sampling Example| 55 | |---------------------------| 56 | |![](media/ss_dynamics.gif)| 57 | 58 | ## Benchmark methods. 59 | Save sketches into a file `exp00.json` and use them across methods. 60 | ``` 61 | python interact_maze2d.py -p [act, dp] -s exp00.json 62 | ``` 63 | Visualize saved sketches by loading the saved file, press the key `n` for next. 64 | ``` 65 | python interact_maze2d.py -p [act, dp] [-ph, -op, -bi, -gd, -ss] -l exp00.json 66 | ``` 67 | Save experiments into `exp00_dp_gd.json` 68 | ``` 69 | python interact_maze2d.py -p dp -gd -l exp00.json -s .json 70 | ``` 71 | Replay experiments. 72 | ``` 73 | python interact_maze2d.py -l exp00_dp_gd.json 74 | ``` 75 | 76 | ## How to get the pre-trained policy? 77 | While the ITPS framework assumes the pre-trained policy is given, I have received many requests to open source my training data [(D4RL Maze2D)](https://github.com/Farama-Foundation/D4RL/blob/89141a689b0353b0dac3da5cba60da4b1b16254d/d4rl/infos.py#L11) and training code [(my LeRobot fork)](https://github.com/yanweiw/lerobot/blob/custom_dataset/lerobot/scripts/train.py) (use it at your own risk as it is not as well-maintained as the inference code in this repo). So here you are: 78 | 79 | Make sure you are on the `custom_dataset` branch of the training codebase and use the [dataset here](https://drive.google.com/file/d/1UPdjg48e9WFs6j_GTmF2xUJPV_XNMiUk/view?usp=sharing). 80 | ``` 81 | python lerobot/scripts/train.py policy=maze2d_act env=maze2d 82 | ``` 83 | You can set `policy=maze2d_dp` to train a diffusion policy. If the `itps` conda environment does not support training, create a `lerobot` environment [following this](https://github.com/yanweiw/lerobot/tree/custom_dataset). Hopefully, this will work. But I cannot guarantee it, as this is not the paper contribution and I am not maintaining it. 84 | 85 | ## Acknowledgement 86 | 87 | Part of the codebase is modified from [LeRobot](https://github.com/huggingface/lerobot). 88 | -------------------------------------------------------------------------------- /itps/common/datasets/factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import logging 17 | 18 | import torch 19 | from omegaconf import ListConfig, OmegaConf 20 | 21 | from common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset 22 | from common.datasets.transforms import get_image_transforms 23 | 24 | 25 | def resolve_delta_timestamps(cfg): 26 | """Resolves delta_timestamps config key (in-place) by using `eval`. 27 | 28 | Doesn't do anything if delta_timestamps is not specified or has already been resolve (as evidenced by 29 | the data type of its values). 
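
    For example (illustrative config values), an entry such as
    `delta_timestamps: {"action": "[i / 10 for i in range(8)]"}` would be eval'd in place into
    the list [0.0, 0.1, ..., 0.7].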
30 | """ 31 | delta_timestamps = cfg.training.get("delta_timestamps") 32 | if delta_timestamps is not None: 33 | for key in delta_timestamps: 34 | if isinstance(delta_timestamps[key], str): 35 | # TODO(rcadene, alexander-soare): remove `eval` to avoid exploit 36 | cfg.training.delta_timestamps[key] = eval(delta_timestamps[key]) 37 | 38 | 39 | def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset: 40 | """ 41 | Args: 42 | cfg: A Hydra config as per the LeRobot config scheme. 43 | split: Select the data subset used to create an instance of LeRobotDataset. 44 | All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train". 45 | Thus, by default, `split="train"` selects all the available data. `split` aims to work like the 46 | slicer in the hugging face datasets: 47 | https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits 48 | As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or 49 | `split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`. 50 | Returns: 51 | The LeRobotDataset. 52 | """ 53 | if not isinstance(cfg.dataset_repo_id, (str, ListConfig)): 54 | raise ValueError( 55 | "Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of " 56 | "strings to load multiple datasets." 57 | ) 58 | 59 | # A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora). 60 | if cfg.env.name != "dora": 61 | if isinstance(cfg.dataset_repo_id, str): 62 | dataset_repo_ids = [cfg.dataset_repo_id] # single dataset 63 | else: 64 | dataset_repo_ids = cfg.dataset_repo_id # multiple datasets 65 | 66 | for dataset_repo_id in dataset_repo_ids: 67 | if cfg.env.name not in dataset_repo_id: 68 | logging.warning( 69 | f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your " 70 | f"environment ({cfg.env.name=})." 
71 | ) 72 | 73 | resolve_delta_timestamps(cfg) 74 | 75 | image_transforms = None 76 | if cfg.training.image_transforms.enable: 77 | cfg_tf = cfg.training.image_transforms 78 | image_transforms = get_image_transforms( 79 | brightness_weight=cfg_tf.brightness.weight, 80 | brightness_min_max=cfg_tf.brightness.min_max, 81 | contrast_weight=cfg_tf.contrast.weight, 82 | contrast_min_max=cfg_tf.contrast.min_max, 83 | saturation_weight=cfg_tf.saturation.weight, 84 | saturation_min_max=cfg_tf.saturation.min_max, 85 | hue_weight=cfg_tf.hue.weight, 86 | hue_min_max=cfg_tf.hue.min_max, 87 | sharpness_weight=cfg_tf.sharpness.weight, 88 | sharpness_min_max=cfg_tf.sharpness.min_max, 89 | max_num_transforms=cfg_tf.max_num_transforms, 90 | random_order=cfg_tf.random_order, 91 | ) 92 | 93 | if isinstance(cfg.dataset_repo_id, str): 94 | dataset = LeRobotDataset( 95 | cfg.dataset_repo_id, 96 | split=split, 97 | delta_timestamps=cfg.training.get("delta_timestamps"), 98 | image_transforms=image_transforms, 99 | video_backend=cfg.video_backend, 100 | ) 101 | else: 102 | dataset = MultiLeRobotDataset( 103 | cfg.dataset_repo_id, 104 | split=split, 105 | delta_timestamps=cfg.training.get("delta_timestamps"), 106 | image_transforms=image_transforms, 107 | video_backend=cfg.video_backend, 108 | ) 109 | 110 | if cfg.get("override_dataset_stats"): 111 | for key, stats_dict in cfg.override_dataset_stats.items(): 112 | for stats_type, listconfig in stats_dict.items(): 113 | # example of stats_type: min, max, mean, std 114 | stats = OmegaConf.to_container(listconfig, resolve=True) 115 | dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) 116 | 117 | return dataset 118 | -------------------------------------------------------------------------------- /itps/common/policies/factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import inspect 17 | import logging 18 | 19 | from omegaconf import DictConfig, OmegaConf 20 | 21 | from common.policies.policy_protocol import Policy 22 | from common.utils.utils import get_safe_torch_device 23 | 24 | 25 | def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): 26 | expected_kwargs = set(inspect.signature(policy_cfg_class).parameters) 27 | if not set(hydra_cfg.policy).issuperset(expected_kwargs): 28 | logging.warning( 29 | f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" 30 | ) 31 | 32 | # OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to avoid 33 | # issues with mutable defaults. This filter changes all lists to tuples. 
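    # For example, an `input_shapes` entry parsed from YAML as [3, 96, 96] becomes (3, 96, 96).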
34 | def list_to_tuple(item): 35 | return tuple(item) if isinstance(item, list) else item 36 | 37 | policy_cfg = policy_cfg_class( 38 | **{ 39 | k: list_to_tuple(v) 40 | for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items() 41 | if k in expected_kwargs 42 | } 43 | ) 44 | return policy_cfg 45 | 46 | 47 | def get_policy_and_config_classes(name: str) -> tuple[Policy, object]: 48 | """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" 49 | if name == "tdmpc": 50 | from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig 51 | from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy 52 | 53 | return TDMPCPolicy, TDMPCConfig 54 | elif name == "diffusion": 55 | from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig 56 | from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy 57 | 58 | return DiffusionPolicy, DiffusionConfig 59 | elif name == "act": 60 | from lerobot.common.policies.act.configuration_act import ACTConfig 61 | from lerobot.common.policies.act.modeling_act import ACTPolicy 62 | 63 | return ACTPolicy, ACTConfig 64 | elif name == "vqbet": 65 | from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig 66 | from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy 67 | 68 | return VQBeTPolicy, VQBeTConfig 69 | else: 70 | raise NotImplementedError(f"Policy with name {name} is not implemented.") 71 | 72 | 73 | def make_policy( 74 | hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None 75 | ) -> Policy: 76 | """Make an instance of a policy class. 77 | 78 | Args: 79 | hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is 80 | provided, only `hydra_cfg.policy.name` is used while everything else is ignored. 81 | pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a 82 | directory containing weights saved using `Policy.save_pretrained`. Note that providing this 83 | argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`. 84 | dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must 85 | be provided when initializing a new policy, and must not be provided when loading a pretrained 86 | policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`. 87 | """ 88 | if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None): 89 | raise ValueError( 90 | "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided." 91 | ) 92 | 93 | policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name) 94 | 95 | policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) 96 | if pretrained_policy_name_or_path is None: 97 | # Make a fresh policy. 98 | policy = policy_cls(policy_cfg, dataset_stats) 99 | else: 100 | # Load a pretrained policy and override the config if needed (for example, if there are inference-time 101 | # hyperparameters that we want to vary). 102 | # TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with, 103 | # pretrained weights which are then loaded into a fresh policy with the desired config. This PR in 104 | # huggingface_hub should make it possible to avoid the hack: 105 | # https://github.com/huggingface/huggingface_hub/pull/2274. 
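        # Workaround: instantiate a fresh policy from the (possibly overridden) config, then copy
        # the pretrained weights in via load_state_dict; from_pretrained is only used to fetch them.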
106 | policy = policy_cls(policy_cfg) 107 | policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()) 108 | 109 | policy.to(get_safe_torch_device(hydra_cfg.device)) 110 | 111 | return policy 112 | -------------------------------------------------------------------------------- /itps/common/utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import logging 17 | import os.path as osp 18 | import random 19 | from contextlib import contextmanager 20 | from datetime import datetime 21 | from pathlib import Path 22 | from typing import Any, Generator 23 | 24 | import hydra 25 | import numpy as np 26 | import torch 27 | from omegaconf import DictConfig 28 | 29 | 30 | def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device: 31 | """Given a string, return a torch.device with checks on whether the device is available.""" 32 | match cfg_device: 33 | case "cuda": 34 | assert torch.cuda.is_available() 35 | device = torch.device("cuda") 36 | case "mps": 37 | assert torch.backends.mps.is_available() 38 | device = torch.device("mps") 39 | case "cpu": 40 | device = torch.device("cpu") 41 | if log: 42 | logging.warning("Using CPU, this will be slow.") 43 | case _: 44 | device = torch.device(cfg_device) 45 | if log: 46 | logging.warning(f"Using custom {cfg_device} device.") 47 | 48 | return device 49 | 50 | 51 | def get_global_random_state() -> dict[str, Any]: 52 | """Get the random state for `random`, `numpy`, and `torch`.""" 53 | random_state_dict = { 54 | "random_state": random.getstate(), 55 | "numpy_random_state": np.random.get_state(), 56 | "torch_random_state": torch.random.get_rng_state(), 57 | } 58 | if torch.cuda.is_available(): 59 | random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state() 60 | return random_state_dict 61 | 62 | 63 | def set_global_random_state(random_state_dict: dict[str, Any]): 64 | """Set the random state for `random`, `numpy`, and `torch`. 65 | 66 | Args: 67 | random_state_dict: A dictionary of the form returned by `get_global_random_state`. 
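
    Illustrative round-trip (restores whatever RNG state was captured):
        state = get_global_random_state()
        _ = torch.rand(1)  # perturbs the torch RNG
        set_global_random_state(state)  # random, numpy and torch (and CUDA, if available) restored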
68 | """ 69 | random.setstate(random_state_dict["random_state"]) 70 | np.random.set_state(random_state_dict["numpy_random_state"]) 71 | torch.random.set_rng_state(random_state_dict["torch_random_state"]) 72 | if torch.cuda.is_available(): 73 | torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) 74 | 75 | 76 | def set_global_seed(seed): 77 | """Set seed for reproducibility.""" 78 | random.seed(seed) 79 | np.random.seed(seed) 80 | torch.manual_seed(seed) 81 | if torch.cuda.is_available(): 82 | torch.cuda.manual_seed_all(seed) 83 | 84 | 85 | @contextmanager 86 | def seeded_context(seed: int) -> Generator[None, None, None]: 87 | """Set the seed when entering a context, and restore the prior random state at exit. 88 | 89 | Example usage: 90 | 91 | ``` 92 | a = random.random() # produces some random number 93 | with seeded_context(1337): 94 | b = random.random() # produces some other random number 95 | c = random.random() # produces yet another random number, but the same it would have if we never made `b` 96 | ``` 97 | """ 98 | random_state_dict = get_global_random_state() 99 | set_global_seed(seed) 100 | yield None 101 | set_global_random_state(random_state_dict) 102 | 103 | 104 | def init_logging(): 105 | def custom_format(record): 106 | dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 107 | fnameline = f"{record.pathname}:{record.lineno}" 108 | message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" 109 | return message 110 | 111 | logging.basicConfig(level=logging.INFO) 112 | 113 | for handler in logging.root.handlers[:]: 114 | logging.root.removeHandler(handler) 115 | 116 | formatter = logging.Formatter() 117 | formatter.format = custom_format 118 | console_handler = logging.StreamHandler() 119 | console_handler.setFormatter(formatter) 120 | logging.getLogger().addHandler(console_handler) 121 | 122 | 123 | def format_big_number(num, precision=0): 124 | suffixes = ["", "K", "M", "B", "T", "Q"] 125 | divisor = 1000.0 126 | 127 | for suffix in suffixes: 128 | if abs(num) < divisor: 129 | return f"{num:.{precision}f}{suffix}" 130 | num /= divisor 131 | 132 | return num 133 | 134 | 135 | def _relative_path_between(path1: Path, path2: Path) -> Path: 136 | """Returns path1 relative to path2.""" 137 | path1 = path1.absolute() 138 | path2 = path2.absolute() 139 | try: 140 | return path1.relative_to(path2) 141 | except ValueError: # most likely because path1 is not a subpath of path2 142 | common_parts = Path(osp.commonpath([path1, path2])).parts 143 | return Path( 144 | "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) 145 | ) 146 | 147 | 148 | def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig: 149 | """Initialize a Hydra config given only the path to the relevant config file. 150 | 151 | For config resolution, it is assumed that the config file's parent is the Hydra config dir. 152 | """ 153 | # TODO(alexander-soare): Resolve configs without Hydra initialization. 154 | hydra.core.global_hydra.GlobalHydra.instance().clear() 155 | # Hydra needs a path relative to this file. 
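    # _relative_path_between() expresses the config file's parent directory relative to this
    # module's directory, inserting ".." segments when it is not a subpath.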
156 | hydra.initialize( 157 | str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)), 158 | version_base="1.2", 159 | ) 160 | cfg = hydra.compose(Path(config_path).stem, overrides) 161 | return cfg 162 | 163 | 164 | def print_cuda_memory_usage(): 165 | """Use this function to locate and debug memory leak.""" 166 | import gc 167 | 168 | gc.collect() 169 | # Also clear the cache if you want to fully release the memory 170 | torch.cuda.empty_cache() 171 | print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2)) 172 | print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2)) 173 | print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2)) 174 | print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2)) 175 | -------------------------------------------------------------------------------- /itps/common/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import collections 17 | from typing import Any, Callable, Dict, Sequence 18 | 19 | import torch 20 | from torchvision.transforms import v2 21 | from torchvision.transforms.v2 import Transform 22 | from torchvision.transforms.v2 import functional as F # noqa: N812 23 | 24 | 25 | class RandomSubsetApply(Transform): 26 | """Apply a random subset of N transformations from a list of transformations. 27 | 28 | Args: 29 | transforms: list of transformations. 30 | p: represents the multinomial probabilities (with no replacement) used for sampling the transform. 31 | If the sum of the weights is not 1, they will be normalized. If ``None`` (default), all transforms 32 | have the same probability. 33 | n_subset: number of transformations to apply. If ``None``, all transforms are applied. 34 | Must be in [1, len(transforms)]. 35 | random_order: apply transformations in a random order. 
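
    Example (illustrative sketch; `img` stands for any image tensor accepted by torchvision v2):
        jitters = [v2.ColorJitter(brightness=(0.5, 1.5)), v2.ColorJitter(contrast=(0.5, 1.5))]
        augment = RandomSubsetApply(jitters, n_subset=1)
        out = augment(img)  # applies exactly one of the two jitters, sampled according to p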
36 | """ 37 | 38 | def __init__( 39 | self, 40 | transforms: Sequence[Callable], 41 | p: list[float] | None = None, 42 | n_subset: int | None = None, 43 | random_order: bool = False, 44 | ) -> None: 45 | super().__init__() 46 | if not isinstance(transforms, Sequence): 47 | raise TypeError("Argument transforms should be a sequence of callables") 48 | if p is None: 49 | p = [1] * len(transforms) 50 | elif len(p) != len(transforms): 51 | raise ValueError( 52 | f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}" 53 | ) 54 | 55 | if n_subset is None: 56 | n_subset = len(transforms) 57 | elif not isinstance(n_subset, int): 58 | raise TypeError("n_subset should be an int or None") 59 | elif not (1 <= n_subset <= len(transforms)): 60 | raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]") 61 | 62 | self.transforms = transforms 63 | total = sum(p) 64 | self.p = [prob / total for prob in p] 65 | self.n_subset = n_subset 66 | self.random_order = random_order 67 | 68 | def forward(self, *inputs: Any) -> Any: 69 | needs_unpacking = len(inputs) > 1 70 | 71 | selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset) 72 | if not self.random_order: 73 | selected_indices = selected_indices.sort().values 74 | 75 | selected_transforms = [self.transforms[i] for i in selected_indices] 76 | 77 | for transform in selected_transforms: 78 | outputs = transform(*inputs) 79 | inputs = outputs if needs_unpacking else (outputs,) 80 | 81 | return outputs 82 | 83 | def extra_repr(self) -> str: 84 | return ( 85 | f"transforms={self.transforms}, " 86 | f"p={self.p}, " 87 | f"n_subset={self.n_subset}, " 88 | f"random_order={self.random_order}" 89 | ) 90 | 91 | 92 | class SharpnessJitter(Transform): 93 | """Randomly change the sharpness of an image or video. 94 | 95 | Similar to a v2.RandomAdjustSharpness with p=1 and a sharpness_factor sampled randomly. 96 | While v2.RandomAdjustSharpness applies — with a given probability — a fixed sharpness_factor to an image, 97 | SharpnessJitter applies a random sharpness_factor each time. This is to have a more diverse set of 98 | augmentations as a result. 99 | 100 | A sharpness_factor of 0 gives a blurred image, 1 gives the original image while 2 increases the sharpness 101 | by a factor of 2. 102 | 103 | If the input is a :class:`torch.Tensor`, 104 | it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. 105 | 106 | Args: 107 | sharpness: How much to jitter sharpness. sharpness_factor is chosen uniformly from 108 | [max(0, 1 - sharpness), 1 + sharpness] or the given 109 | [min, max]. Should be non negative numbers. 
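
    For example (illustrative), SharpnessJitter(sharpness=0.5) samples a sharpness_factor
    uniformly from [0.5, 1.5] on every call, while SharpnessJitter(sharpness=(1.0, 2.0))
    never blurs.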
110 | """ 111 | 112 | def __init__(self, sharpness: float | Sequence[float]) -> None: 113 | super().__init__() 114 | self.sharpness = self._check_input(sharpness) 115 | 116 | def _check_input(self, sharpness): 117 | if isinstance(sharpness, (int, float)): 118 | if sharpness < 0: 119 | raise ValueError("If sharpness is a single number, it must be non negative.") 120 | sharpness = [1.0 - sharpness, 1.0 + sharpness] 121 | sharpness[0] = max(sharpness[0], 0.0) 122 | elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2: 123 | sharpness = [float(v) for v in sharpness] 124 | else: 125 | raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.") 126 | 127 | if not 0.0 <= sharpness[0] <= sharpness[1]: 128 | raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.") 129 | 130 | return float(sharpness[0]), float(sharpness[1]) 131 | 132 | def _generate_value(self, left: float, right: float) -> float: 133 | return torch.empty(1).uniform_(left, right).item() 134 | 135 | def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: 136 | sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1]) 137 | return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) 138 | 139 | 140 | def get_image_transforms( 141 | brightness_weight: float = 1.0, 142 | brightness_min_max: tuple[float, float] | None = None, 143 | contrast_weight: float = 1.0, 144 | contrast_min_max: tuple[float, float] | None = None, 145 | saturation_weight: float = 1.0, 146 | saturation_min_max: tuple[float, float] | None = None, 147 | hue_weight: float = 1.0, 148 | hue_min_max: tuple[float, float] | None = None, 149 | sharpness_weight: float = 1.0, 150 | sharpness_min_max: tuple[float, float] | None = None, 151 | max_num_transforms: int | None = None, 152 | random_order: bool = False, 153 | ): 154 | def check_value(name, weight, min_max): 155 | if min_max is not None: 156 | if len(min_max) != 2: 157 | raise ValueError( 158 | f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided." 159 | ) 160 | if weight < 0.0: 161 | raise ValueError( 162 | f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})." 
163 | ) 164 | 165 | check_value("brightness", brightness_weight, brightness_min_max) 166 | check_value("contrast", contrast_weight, contrast_min_max) 167 | check_value("saturation", saturation_weight, saturation_min_max) 168 | check_value("hue", hue_weight, hue_min_max) 169 | check_value("sharpness", sharpness_weight, sharpness_min_max) 170 | 171 | weights = [] 172 | transforms = [] 173 | if brightness_min_max is not None and brightness_weight > 0.0: 174 | weights.append(brightness_weight) 175 | transforms.append(v2.ColorJitter(brightness=brightness_min_max)) 176 | if contrast_min_max is not None and contrast_weight > 0.0: 177 | weights.append(contrast_weight) 178 | transforms.append(v2.ColorJitter(contrast=contrast_min_max)) 179 | if saturation_min_max is not None and saturation_weight > 0.0: 180 | weights.append(saturation_weight) 181 | transforms.append(v2.ColorJitter(saturation=saturation_min_max)) 182 | if hue_min_max is not None and hue_weight > 0.0: 183 | weights.append(hue_weight) 184 | transforms.append(v2.ColorJitter(hue=hue_min_max)) 185 | if sharpness_min_max is not None and sharpness_weight > 0.0: 186 | weights.append(sharpness_weight) 187 | transforms.append(SharpnessJitter(sharpness=sharpness_min_max)) 188 | 189 | n_subset = len(transforms) 190 | if max_num_transforms is not None: 191 | n_subset = min(n_subset, max_num_transforms) 192 | 193 | if n_subset == 0: 194 | return v2.Identity() 195 | else: 196 | # TODO(rcadene, aliberts): add v2.ToDtype float16? 197 | return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order) 198 | -------------------------------------------------------------------------------- /itps/common/policies/act/configuration_act.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | from dataclasses import dataclass, field 17 | 18 | 19 | @dataclass 20 | class ACTConfig: 21 | """Configuration class for the Action Chunking Transformers policy. 22 | 23 | Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". 24 | 25 | The parameters you will most likely need to change are the ones which depend on the environment / sensors. 26 | Those are: `input_shapes` and 'output_shapes`. 27 | 28 | Notes on the inputs and outputs: 29 | - Either: 30 | - At least one key starting with "observation.image is required as an input. 31 | AND/OR 32 | - The key "observation.environment_state" is required as input. 33 | - If there are multiple keys beginning with "observation.images." they are treated as multiple camera 34 | views. Right now we only support all images having the same shape. 35 | - May optionally work without an "observation.state" key for the proprioceptive robot state. 36 | - "action" is required as an output key. 
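
    For example (illustrative only; the actual shapes depend on your environment), a state-only
    setup could use input_shapes={"observation.environment_state": [4], "observation.state": [2]}
    and output_shapes={"action": [2]} with no image keys at all.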
37 | 
38 |     Args:
39 |         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
40 |             current step and additional steps going back).
41 |         chunk_size: The size of the action prediction "chunks" in units of environment steps.
42 |         n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
43 |             This should be no greater than the chunk size. For example, if the chunk size is 100, you may
44 |             set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
45 |             environment, and throws the other 50 out.
46 |         input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
47 |             the input data name, and the value is a list indicating the dimensions of the corresponding data.
48 |             For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
49 |             indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
50 |             include batch dimension or temporal dimension.
51 |         output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
52 |             the output data name, and the value is a list indicating the dimensions of the corresponding data.
53 |             For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
54 |             Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
55 |         input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
56 |             and the value specifies the normalization mode to apply. The two available modes are "mean_std"
57 |             which subtracts the mean and divides by the standard deviation and "min_max" which rescales to a
58 |             [-1, 1] range.
59 |         output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to the
60 |             original scale. Note that this is also used for normalizing the training targets.
61 |         vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
62 |         pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
63 |             `None` means no pretrained weights.
64 |         replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
65 |             convolution.
66 |         pre_norm: Whether to use "pre-norm" in the transformer blocks.
67 |         dim_model: The transformer blocks' main hidden dimension.
68 |         n_heads: The number of heads to use in the transformer blocks' multi-head attention.
69 |         dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
70 |             layers.
71 |         feedforward_activation: The activation to use in the transformer block's feed-forward layers.
72 |         n_encoder_layers: The number of transformer layers to use for the transformer encoder.
73 |         n_decoder_layers: The number of transformer layers to use for the transformer decoder.
74 |         use_vae: Whether to use a variational objective during training. This introduces another transformer
75 |             which is used as the VAE's encoder (not to be confused with the transformer encoder - see
76 |             documentation in the policy class).
77 |         latent_dim: The VAE's latent dimension.
78 |         n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
79 |         temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
80 |             actions for a given time step over multiple policy invocations.
Updates are calculated as: 81 | x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different 82 | parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our 83 | formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we 84 | require `n_action_steps == 1` (since we need to query the policy every step anyway). 85 | dropout: Dropout to use in the transformer layers (see code for details). 86 | kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective 87 | is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`. 88 | """ 89 | 90 | # Input / output structure. 91 | n_obs_steps: int = 1 92 | chunk_size: int = 100 93 | n_action_steps: int = 100 94 | 95 | input_shapes: dict[str, list[int]] = field( 96 | default_factory=lambda: { 97 | "observation.images.top": [3, 480, 640], 98 | "observation.state": [14], 99 | } 100 | ) 101 | output_shapes: dict[str, list[int]] = field( 102 | default_factory=lambda: { 103 | "action": [14], 104 | } 105 | ) 106 | 107 | # Normalization / Unnormalization 108 | input_normalization_modes: dict[str, str] = field( 109 | default_factory=lambda: { 110 | "observation.images.top": "mean_std", 111 | "observation.state": "mean_std", 112 | } 113 | ) 114 | output_normalization_modes: dict[str, str] = field( 115 | default_factory=lambda: { 116 | "action": "mean_std", 117 | } 118 | ) 119 | 120 | # Architecture. 121 | # Vision backbone. 122 | vision_backbone: str = "resnet18" 123 | pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1" 124 | replace_final_stride_with_dilation: int = False 125 | # Transformer layers. 126 | pre_norm: bool = False 127 | dim_model: int = 512 128 | n_heads: int = 8 129 | dim_feedforward: int = 3200 130 | feedforward_activation: str = "relu" 131 | n_encoder_layers: int = 4 132 | # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code 133 | # that means only the first layer is used. Here we match the original implementation by setting this to 1. 134 | # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521. 135 | n_decoder_layers: int = 1 136 | # VAE. 137 | use_vae: bool = True 138 | latent_dim: int = 32 139 | n_vae_encoder_layers: int = 4 140 | 141 | # Inference. 142 | temporal_ensemble_momentum: float | None = None 143 | 144 | # Training and loss computation. 145 | dropout: float = 0.1 146 | kl_weight: float = 10.0 147 | 148 | def __post_init__(self): 149 | """Input validation (not exhaustive).""" 150 | if not self.vision_backbone.startswith("resnet"): 151 | raise ValueError( 152 | f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." 153 | ) 154 | if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1: 155 | raise NotImplementedError( 156 | "`n_action_steps` must be 1 when using temporal ensembling. This is " 157 | "because the policy needs to be queried every step to compute the ensembled action." 158 | ) 159 | if self.n_action_steps > self.chunk_size: 160 | raise ValueError( 161 | f"The chunk size is the upper bound for the number of action steps per model invocation. Got " 162 | f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." 163 | ) 164 | if self.n_obs_steps != 1: 165 | raise ValueError( 166 | f"Multiple observation steps not handled yet. 
Got `nobs_steps={self.n_obs_steps}`" 167 | ) 168 | if ( 169 | not any(k.startswith("observation.image") for k in self.input_shapes) 170 | and "observation.environment_state" not in self.input_shapes 171 | ): 172 | raise ValueError("You must provide at least one image or the environment state among the inputs.") 173 | -------------------------------------------------------------------------------- /itps/common/datasets/compute_stats.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | from copy import deepcopy 17 | from math import ceil 18 | 19 | import einops 20 | import torch 21 | import tqdm 22 | from datasets import Image 23 | 24 | from common.datasets.video_utils import VideoFrame 25 | 26 | 27 | def get_stats_einops_patterns(dataset, num_workers=0): 28 | """These einops patterns will be used to aggregate batches and compute statistics. 29 | 30 | Note: We assume the images are in channel first format 31 | """ 32 | 33 | dataloader = torch.utils.data.DataLoader( 34 | dataset, 35 | num_workers=num_workers, 36 | batch_size=2, 37 | shuffle=False, 38 | ) 39 | batch = next(iter(dataloader)) 40 | 41 | stats_patterns = {} 42 | for key, feats_type in dataset.features.items(): 43 | # sanity check that tensors are not float64 44 | assert batch[key].dtype != torch.float64 45 | 46 | if isinstance(feats_type, (VideoFrame, Image)): 47 | # sanity check that images are channel first 48 | _, c, h, w = batch[key].shape 49 | assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}" 50 | 51 | # sanity check that images are float32 in range [0,1] 52 | assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}" 53 | assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}" 54 | assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}" 55 | 56 | stats_patterns[key] = "b c h w -> c 1 1" 57 | elif batch[key].ndim == 2: 58 | stats_patterns[key] = "b c -> c " 59 | elif batch[key].ndim == 1: 60 | stats_patterns[key] = "b -> 1" 61 | else: 62 | raise ValueError(f"{key}, {feats_type}, {batch[key].shape}") 63 | 64 | return stats_patterns 65 | 66 | 67 | def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None): 68 | """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset.""" 69 | if max_num_samples is None: 70 | max_num_samples = len(dataset) 71 | 72 | # for more info on why we need to set the same number of workers, see `load_from_videos` 73 | stats_patterns = get_stats_einops_patterns(dataset, num_workers) 74 | 75 | # mean and std will be computed incrementally while max and min will track the running value. 
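    The incremental update used in the loop below can be sanity-checked in isolation. A minimal, standalone sketch with toy values (not part of `compute_stats` itself):

    ```python
    import torch

    # Running (batch-weighted) mean over a stream of batches:
    # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
    running_mean, running_count = torch.tensor(0.0), 0
    for batch in (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([10.0, 20.0])):
        running_count += batch.numel()
        running_mean = running_mean + batch.numel() * (batch.mean() - running_mean) / running_count
    print(running_mean)  # tensor(7.2000) == (1 + 2 + 3 + 10 + 20) / 5
    ```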
76 | mean, std, max, min = {}, {}, {}, {} 77 | for key in stats_patterns: 78 | mean[key] = torch.tensor(0.0).float() 79 | std[key] = torch.tensor(0.0).float() 80 | max[key] = torch.tensor(-float("inf")).float() 81 | min[key] = torch.tensor(float("inf")).float() 82 | 83 | def create_seeded_dataloader(dataset, batch_size, seed): 84 | generator = torch.Generator() 85 | generator.manual_seed(seed) 86 | dataloader = torch.utils.data.DataLoader( 87 | dataset, 88 | num_workers=num_workers, 89 | batch_size=batch_size, 90 | shuffle=True, 91 | drop_last=False, 92 | generator=generator, 93 | ) 94 | return dataloader 95 | 96 | # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get 97 | # surprises when rerunning the sampler. 98 | first_batch = None 99 | running_item_count = 0 # for online mean computation 100 | dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337) 101 | for i, batch in enumerate( 102 | tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max") 103 | ): 104 | this_batch_size = len(batch["index"]) 105 | running_item_count += this_batch_size 106 | if first_batch is None: 107 | first_batch = deepcopy(batch) 108 | for key, pattern in stats_patterns.items(): 109 | batch[key] = batch[key].float() 110 | # Numerically stable update step for mean computation. 111 | batch_mean = einops.reduce(batch[key], pattern, "mean") 112 | # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents 113 | # the update step, N is the running item count, B is this batch size, x̄ is the running mean, 114 | # and x is the current batch mean. Some rearrangement is then required to avoid risking 115 | # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields 116 | # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ 117 | mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count 118 | max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) 119 | min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) 120 | 121 | if i == ceil(max_num_samples / batch_size) - 1: 122 | break 123 | 124 | first_batch_ = None 125 | running_item_count = 0 # for online std computation 126 | dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337) 127 | for i, batch in enumerate( 128 | tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std") 129 | ): 130 | this_batch_size = len(batch["index"]) 131 | running_item_count += this_batch_size 132 | # Sanity check to make sure the batches are still in the same order as before. 133 | if first_batch_ is None: 134 | first_batch_ = deepcopy(batch) 135 | for key in stats_patterns: 136 | assert torch.equal(first_batch_[key], first_batch[key]) 137 | for key, pattern in stats_patterns.items(): 138 | batch[key] = batch[key].float() 139 | # Numerically stable update step for mean computation (where the mean is over squared 140 | # residuals).See notes in the mean computation loop above. 
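            # In the same notation as the mean update above: with sₙ the batch mean of the squared
            # residuals and σ²ₙ the running estimate, σ²ₙ = σ²ₙ₋₁ + Bₙ * (sₙ - σ²ₙ₋₁) / Nₙ. The square
            # root is taken only after all batches have been processed, which is why the mean from the
            # first pass is needed before this second pass can run.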
141 | batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") 142 | std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count 143 | 144 | if i == ceil(max_num_samples / batch_size) - 1: 145 | break 146 | 147 | for key in stats_patterns: 148 | std[key] = torch.sqrt(std[key]) 149 | 150 | stats = {} 151 | for key in stats_patterns: 152 | stats[key] = { 153 | "mean": mean[key], 154 | "std": std[key], 155 | "max": max[key], 156 | "min": min[key], 157 | } 158 | return stats 159 | 160 | 161 | def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]: 162 | """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch. 163 | 164 | The final stats will have the union of all data keys from each of the datasets. 165 | 166 | The final stats will have the union of all data keys from each of the datasets. For instance: 167 | - new_max = max(max_dataset_0, max_dataset_1, ...) 168 | - new_min = min(min_dataset_0, min_dataset_1, ...) 169 | - new_mean = (mean of all data) 170 | - new_std = (std of all data) 171 | """ 172 | data_keys = set() 173 | for dataset in ls_datasets: 174 | data_keys.update(dataset.stats.keys()) 175 | stats = {k: {} for k in data_keys} 176 | for data_key in data_keys: 177 | for stat_key in ["min", "max"]: 178 | # compute `max(dataset_0["max"], dataset_1["max"], ...)` 179 | stats[data_key][stat_key] = einops.reduce( 180 | torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0), 181 | "n ... -> ...", 182 | stat_key, 183 | ) 184 | total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats) 185 | # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective 186 | # dataset, then divide by total_samples to get the overall "mean". 187 | # NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of 188 | # numerical overflow! 189 | stats[data_key]["mean"] = sum( 190 | d.stats[data_key]["mean"] * (d.num_samples / total_samples) 191 | for d in ls_datasets 192 | if data_key in d.stats 193 | ) 194 | # The derivation for standard deviation is a little more involved but is much in the same spirit as 195 | # the computation of the mean. 196 | # Given two sets of data where the statistics are known: 197 | # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ] 198 | # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined 199 | # NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of 200 | # numerical overflow! 201 | stats[data_key]["std"] = torch.sqrt( 202 | sum( 203 | (d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2) 204 | * (d.num_samples / total_samples) 205 | for d in ls_datasets 206 | if data_key in d.stats 207 | ) 208 | ) 209 | return stats 210 | -------------------------------------------------------------------------------- /itps/common/datasets/video_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import logging 17 | import subprocess 18 | import warnings 19 | from collections import OrderedDict 20 | from dataclasses import dataclass, field 21 | from pathlib import Path 22 | from typing import Any, ClassVar 23 | 24 | import pyarrow as pa 25 | import torch 26 | import torchvision 27 | from datasets.features.features import register_feature 28 | 29 | 30 | def load_from_videos( 31 | item: dict[str, torch.Tensor], 32 | video_frame_keys: list[str], 33 | videos_dir: Path, 34 | tolerance_s: float, 35 | backend: str = "pyav", 36 | ): 37 | """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function 38 | in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault. 39 | This probably happens because a memory reference to the video loader is created in the main process and a 40 | subprocess fails to access it. 41 | """ 42 | # since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4") 43 | data_dir = videos_dir.parent 44 | 45 | for key in video_frame_keys: 46 | if isinstance(item[key], list): 47 | # load multiple frames at once (expected when delta_timestamps is not None) 48 | timestamps = [frame["timestamp"] for frame in item[key]] 49 | paths = [frame["path"] for frame in item[key]] 50 | if len(set(paths)) > 1: 51 | raise NotImplementedError("All video paths are expected to be the same for now.") 52 | video_path = data_dir / paths[0] 53 | 54 | frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) 55 | item[key] = frames 56 | else: 57 | # load one frame 58 | timestamps = [item[key]["timestamp"]] 59 | video_path = data_dir / item[key]["path"] 60 | 61 | frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) 62 | item[key] = frames[0] 63 | 64 | return item 65 | 66 | 67 | def decode_video_frames_torchvision( 68 | video_path: str, 69 | timestamps: list[float], 70 | tolerance_s: float, 71 | backend: str = "pyav", 72 | log_loaded_timestamps: bool = False, 73 | ) -> torch.Tensor: 74 | """Loads frames associated to the requested timestamps of a video 75 | 76 | The backend can be either "pyav" (default) or "video_reader". 77 | "video_reader" requires installing torchvision from source, see: 78 | https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst 79 | (note that you need to compile against ffmpeg<4.3) 80 | 81 | While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup. 82 | For more info on video decoding, see `benchmark/video/README.md` 83 | 84 | See torchvision doc for more info on these two backends: 85 | https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend 86 | 87 | Note: Video benefits from inter-frame compression. Instead of storing every frame individually, 88 | the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to 89 | that key frame. 
As a consequence, to access a requested frame, we need to load the preceding key frame,
90 |     and all subsequent frames until reaching the requested frame. The number of key frames in a video
91 |     can be adjusted during encoding to take into account decoding time and video size in bytes.
92 |     """
93 |     video_path = str(video_path)
94 | 
95 |     # set backend
96 |     keyframes_only = False
97 |     torchvision.set_video_backend(backend)
98 |     if backend == "pyav":
99 |         keyframes_only = True  # pyav doesn't support accurate seek
100 | 
101 |     # set a video stream reader
102 |     # TODO(rcadene): also load audio stream at the same time
103 |     reader = torchvision.io.VideoReader(video_path, "video")
104 | 
105 |     # set the first and last requested timestamps
106 |     # Note: previous timestamps are usually loaded, since we need to access the previous key frame
107 |     first_ts = timestamps[0]
108 |     last_ts = timestamps[-1]
109 | 
110 |     # access closest key frame of the first requested frame
111 |     # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
112 |     # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
113 |     reader.seek(first_ts, keyframes_only=keyframes_only)
114 | 
115 |     # load all frames until last requested frame
116 |     loaded_frames = []
117 |     loaded_ts = []
118 |     for frame in reader:
119 |         current_ts = frame["pts"]
120 |         if log_loaded_timestamps:
121 |             logging.info(f"frame loaded at timestamp={current_ts:.4f}")
122 |         loaded_frames.append(frame["data"])
123 |         loaded_ts.append(current_ts)
124 |         if current_ts >= last_ts:
125 |             break
126 | 
127 |     if backend == "pyav":
128 |         reader.container.close()
129 | 
130 |     reader = None
131 | 
132 |     query_ts = torch.tensor(timestamps)
133 |     loaded_ts = torch.tensor(loaded_ts)
134 | 
135 |     # compute distances between each query timestamp and timestamps of all loaded frames
136 |     dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
137 |     min_, argmin_ = dist.min(1)
138 | 
139 |     is_within_tol = min_ < tolerance_s
140 |     assert is_within_tol.all(), (
141 |         f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
142 |         " It means that the closest frame that can be loaded from the video is too far away in time."
143 |         " This might be due to synchronization issues with timestamps during data collection."
144 |         " To be safe, we advise ignoring this item during training."
145 | f"\nqueried timestamps: {query_ts}" 146 | f"\nloaded timestamps: {loaded_ts}" 147 | f"\nvideo: {video_path}" 148 | f"\nbackend: {backend}" 149 | ) 150 | 151 | # get closest frames to the query timestamps 152 | closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) 153 | closest_ts = loaded_ts[argmin_] 154 | 155 | if log_loaded_timestamps: 156 | logging.info(f"{closest_ts=}") 157 | 158 | # convert to the pytorch format which is float32 in [0,1] range (and channel first) 159 | closest_frames = closest_frames.type(torch.float32) / 255 160 | 161 | assert len(timestamps) == len(closest_frames) 162 | return closest_frames 163 | 164 | 165 | def encode_video_frames( 166 | imgs_dir: Path, 167 | video_path: Path, 168 | fps: int, 169 | video_codec: str = "libsvtav1", 170 | pixel_format: str = "yuv420p", 171 | group_of_pictures_size: int | None = 2, 172 | constant_rate_factor: int | None = 30, 173 | fast_decode: int = 0, 174 | log_level: str | None = "error", 175 | overwrite: bool = False, 176 | ) -> None: 177 | """More info on ffmpeg arguments tuning on `benchmark/video/README.md`""" 178 | video_path = Path(video_path) 179 | video_path.parent.mkdir(parents=True, exist_ok=True) 180 | 181 | ffmpeg_args = OrderedDict( 182 | [ 183 | ("-f", "image2"), 184 | ("-r", str(fps)), 185 | ("-i", str(imgs_dir / "frame_%06d.png")), 186 | ("-vcodec", video_codec), 187 | ("-pix_fmt", pixel_format), 188 | ] 189 | ) 190 | 191 | if group_of_pictures_size is not None: 192 | ffmpeg_args["-g"] = str(group_of_pictures_size) 193 | 194 | if constant_rate_factor is not None: 195 | ffmpeg_args["-crf"] = str(constant_rate_factor) 196 | 197 | if fast_decode: 198 | key = "-svtav1-params" if video_codec == "libsvtav1" else "-tune" 199 | value = f"fast-decode={fast_decode}" if video_codec == "libsvtav1" else "fastdecode" 200 | ffmpeg_args[key] = value 201 | 202 | if log_level is not None: 203 | ffmpeg_args["-loglevel"] = str(log_level) 204 | 205 | ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair] 206 | if overwrite: 207 | ffmpeg_args.append("-y") 208 | 209 | ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)] 210 | subprocess.run(ffmpeg_cmd, check=True) 211 | 212 | 213 | @dataclass 214 | class VideoFrame: 215 | # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo 216 | """ 217 | Provides a type for a dataset containing video frames. 218 | 219 | Example: 220 | 221 | ```python 222 | data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}] 223 | features = {"image": VideoFrame()} 224 | Dataset.from_dict(data_dict, features=Features(features)) 225 | ``` 226 | """ 227 | 228 | pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()}) 229 | _type: str = field(default="VideoFrame", init=False, repr=False) 230 | 231 | def __call__(self): 232 | return self.pa_type 233 | 234 | 235 | with warnings.catch_warnings(): 236 | warnings.filterwarnings( 237 | "ignore", 238 | "'register_feature' is experimental and might be subject to breaking changes in the future.", 239 | category=UserWarning, 240 | ) 241 | # to make VideoFrame available in HuggingFace `datasets` 242 | register_feature(VideoFrame, "VideoFrame") 243 | -------------------------------------------------------------------------------- /itps/common/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py 17 | 18 | # TODO(rcadene, alexander-soare): clean this file 19 | """ 20 | 21 | import logging 22 | import os 23 | import re 24 | from glob import glob 25 | from pathlib import Path 26 | 27 | import torch 28 | from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE 29 | from omegaconf import DictConfig, OmegaConf 30 | from termcolor import colored 31 | from torch.optim import Optimizer 32 | from torch.optim.lr_scheduler import LRScheduler 33 | 34 | from common.policies.policy_protocol import Policy 35 | from common.utils.utils import get_global_random_state, set_global_random_state 36 | 37 | 38 | def log_output_dir(out_dir): 39 | logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") 40 | 41 | 42 | def cfg_to_group(cfg: DictConfig, return_list: bool = False) -> list[str] | str: 43 | """Return a group name for logging. Optionally returns group name as list.""" 44 | lst = [ 45 | f"policy:{cfg.policy.name}", 46 | f"dataset:{cfg.dataset_repo_id}", 47 | f"env:{cfg.env.name}", 48 | f"seed:{cfg.seed}", 49 | ] 50 | return lst if return_list else "-".join(lst) 51 | 52 | 53 | def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str: 54 | # Get the WandB run ID. 55 | paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*")) 56 | if len(paths) != 1: 57 | raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.") 58 | match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1]) 59 | if match is None: 60 | raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.") 61 | wandb_run_id = match.groups(0)[0] 62 | return wandb_run_id 63 | 64 | 65 | class Logger: 66 | """Primary logger object. Logs either locally or using wandb. 67 | 68 | The logger creates the following directory structure: 69 | 70 | provided_log_dir 71 | ├── .hydra # hydra's configuration cache 72 | ├── checkpoints 73 | │ ├── specific_checkpoint_name 74 | │ │ ├── pretrained_model # Hugging Face pretrained model directory 75 | │ │ │ ├── ... 76 | │ │ └── training_state.pth # optimizer, scheduler, and random states + training step 77 | | ├── another_specific_checkpoint_name 78 | │ │ ├── ... 79 | | ├── ... 80 | │ └── last # a softlink to the last logged checkpoint 81 | """ 82 | 83 | pretrained_model_dir_name = "pretrained_model" 84 | training_state_file_name = "training_state.pth" 85 | 86 | def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None): 87 | """ 88 | Args: 89 | log_dir: The directory to save all logs and training outputs to. 90 | job_name: The WandB job name. 
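        For reference, a minimal offline-usage sketch (illustrative only: the real Hydra config schema lives in the training configuration files, which are not part of this module; the keys below are simply the ones `Logger` and `cfg_to_group` read, and the `common.` import root mirrors the imports used elsewhere in this repository):

        ```python
        from omegaconf import OmegaConf
        from common.logger import Logger

        cfg = OmegaConf.create(
            {
                "policy": {"name": "act"},
                "dataset_repo_id": "lerobot/aloha_sim_insertion_human",
                "env": {"name": "aloha"},
                "seed": 1000,
                "resume": False,
                "wandb": {"enable": False},  # no project/entity -> logs stay local
            }
        )
        logger = Logger(cfg, log_dir="outputs/train/example", wandb_job_name=None)
        logger.log_dict({"loss": 0.1}, step=0)  # a no-op for wandb when running offline
        ```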
91 | """ 92 | self._cfg = cfg 93 | self.log_dir = Path(log_dir) 94 | self.log_dir.mkdir(parents=True, exist_ok=True) 95 | self.checkpoints_dir = self.get_checkpoints_dir(log_dir) 96 | self.last_checkpoint_dir = self.get_last_checkpoint_dir(log_dir) 97 | self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(log_dir) 98 | 99 | # Set up WandB. 100 | self._group = cfg_to_group(cfg) 101 | project = cfg.get("wandb", {}).get("project") 102 | entity = cfg.get("wandb", {}).get("entity") 103 | enable_wandb = cfg.get("wandb", {}).get("enable", False) 104 | run_offline = not enable_wandb or not project 105 | if run_offline: 106 | logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) 107 | self._wandb = None 108 | else: 109 | os.environ["WANDB_SILENT"] = "true" 110 | import wandb 111 | 112 | wandb_run_id = None 113 | if cfg.resume: 114 | wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir) 115 | 116 | wandb.init( 117 | id=wandb_run_id, 118 | project=project, 119 | entity=entity, 120 | name=wandb_job_name, 121 | notes=cfg.get("wandb", {}).get("notes"), 122 | tags=cfg_to_group(cfg, return_list=True), 123 | dir=log_dir, 124 | config=OmegaConf.to_container(cfg, resolve=True), 125 | # TODO(rcadene): try set to True 126 | save_code=False, 127 | # TODO(rcadene): split train and eval, and run async eval with job_type="eval" 128 | job_type="train_eval", 129 | resume="must" if cfg.resume else None, 130 | ) 131 | print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) 132 | logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") 133 | self._wandb = wandb 134 | 135 | @classmethod 136 | def get_checkpoints_dir(cls, log_dir: str | Path) -> Path: 137 | """Given the log directory, get the sub-directory in which checkpoints will be saved.""" 138 | return Path(log_dir) / "checkpoints" 139 | 140 | @classmethod 141 | def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path: 142 | """Given the log directory, get the sub-directory in which the last checkpoint will be saved.""" 143 | return cls.get_checkpoints_dir(log_dir) / "last" 144 | 145 | @classmethod 146 | def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path: 147 | """ 148 | Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will 149 | be saved. 150 | """ 151 | return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name 152 | 153 | def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None): 154 | """Save the weights of the Policy model using PyTorchModelHubMixin. 155 | 156 | The weights are saved in a folder called "pretrained_model" under the checkpoint directory. 157 | 158 | Optionally also upload the model to WandB. 159 | """ 160 | self.checkpoints_dir.mkdir(parents=True, exist_ok=True) 161 | policy.save_pretrained(save_dir) 162 | # Also save the full Hydra config for the env configuration. 
163 | OmegaConf.save(self._cfg, save_dir / "config.yaml") 164 | if self._wandb and not self._cfg.wandb.disable_artifact: 165 | # note wandb artifact does not accept ":" or "/" in its name 166 | artifact = self._wandb.Artifact(wandb_artifact_name, type="model") 167 | artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE) 168 | self._wandb.log_artifact(artifact) 169 | if self.last_checkpoint_dir.exists(): 170 | os.remove(self.last_checkpoint_dir) 171 | 172 | def save_training_state( 173 | self, 174 | save_dir: Path, 175 | train_step: int, 176 | optimizer: Optimizer, 177 | scheduler: LRScheduler | None, 178 | ): 179 | """Checkpoint the global training_step, optimizer state, scheduler state, and random state. 180 | 181 | All of these are saved as "training_state.pth" under the checkpoint directory. 182 | """ 183 | training_state = { 184 | "step": train_step, 185 | "optimizer": optimizer.state_dict(), 186 | **get_global_random_state(), 187 | } 188 | if scheduler is not None: 189 | training_state["scheduler"] = scheduler.state_dict() 190 | torch.save(training_state, save_dir / self.training_state_file_name) 191 | 192 | def save_checkpont( 193 | self, 194 | train_step: int, 195 | policy: Policy, 196 | optimizer: Optimizer, 197 | scheduler: LRScheduler | None, 198 | identifier: str, 199 | ): 200 | """Checkpoint the model weights and the training state.""" 201 | checkpoint_dir = self.checkpoints_dir / str(identifier) 202 | wandb_artifact_name = ( 203 | None 204 | if self._wandb is None 205 | else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}" 206 | ) 207 | self.save_model( 208 | checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name 209 | ) 210 | self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler) 211 | os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir) 212 | 213 | def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int: 214 | """ 215 | Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and 216 | random state, and return the global training step. 217 | """ 218 | training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name) 219 | optimizer.load_state_dict(training_state["optimizer"]) 220 | if scheduler is not None: 221 | scheduler.load_state_dict(training_state["scheduler"]) 222 | elif "scheduler" in training_state: 223 | raise ValueError( 224 | "The checkpoint contains a scheduler state_dict, but no LRScheduler was provided." 225 | ) 226 | # Small hack to get the expected keys: use `get_global_random_state`. 227 | set_global_random_state({k: training_state[k] for k in get_global_random_state()}) 228 | return training_state["step"] 229 | 230 | def log_dict(self, d, step, mode="train"): 231 | assert mode in {"train", "eval"} 232 | # TODO(alexander-soare): Add local text log. 233 | if self._wandb is not None: 234 | for k, v in d.items(): 235 | if not isinstance(v, (int, float, str)): 236 | logging.warning( 237 | f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' 
238 | ) 239 | continue 240 | self._wandb.log({f"{mode}/{k}": v}, step=step) 241 | 242 | def log_video(self, video_path: str, step: int, mode: str = "train"): 243 | assert mode in {"train", "eval"} 244 | assert self._wandb is not None 245 | wandb_video = self._wandb.Video(video_path, fps=self._cfg.fps, format="mp4") 246 | self._wandb.log({f"{mode}/video": wandb_video}, step=step) 247 | -------------------------------------------------------------------------------- /itps/common/policies/normalize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import torch 17 | from torch import Tensor, nn 18 | 19 | 20 | def create_stats_buffers( 21 | shapes: dict[str, list[int]], 22 | modes: dict[str, str], 23 | stats: dict[str, dict[str, Tensor]] | None = None, 24 | ) -> dict[str, dict[str, nn.ParameterDict]]: 25 | """ 26 | Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max 27 | statistics. 28 | 29 | Args: (see Normalize and Unnormalize) 30 | 31 | Returns: 32 | dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing 33 | `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation. 34 | """ 35 | stats_buffers = {} 36 | 37 | for key, mode in modes.items(): 38 | assert mode in ["mean_std", "min_max"] 39 | 40 | shape = tuple(shapes[key]) 41 | 42 | if "image" in key: 43 | # sanity checks 44 | assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}" 45 | c, h, w = shape 46 | assert c < h and c < w, f"{key} is not channel first ({shape=})" 47 | # override image shape to be invariant to height and width 48 | shape = (c, 1, 1) 49 | 50 | # Note: we initialize mean, std, min, max to infinity. They should be overwritten 51 | # downstream by `stats` or `policy.load_state_dict`, as expected. During forward, 52 | # we assert they are not infinity anymore. 53 | 54 | buffer = {} 55 | if mode == "mean_std": 56 | mean = torch.ones(shape, dtype=torch.float32) * torch.inf 57 | std = torch.ones(shape, dtype=torch.float32) * torch.inf 58 | buffer = nn.ParameterDict( 59 | { 60 | "mean": nn.Parameter(mean, requires_grad=False), 61 | "std": nn.Parameter(std, requires_grad=False), 62 | } 63 | ) 64 | elif mode == "min_max": 65 | min = torch.ones(shape, dtype=torch.float32) * torch.inf 66 | max = torch.ones(shape, dtype=torch.float32) * torch.inf 67 | buffer = nn.ParameterDict( 68 | { 69 | "min": nn.Parameter(min, requires_grad=False), 70 | "max": nn.Parameter(max, requires_grad=False), 71 | } 72 | ) 73 | 74 | if stats is not None: 75 | # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated 76 | # tensors anywhere (for example, when we use the same stats for normalization and 77 | # unnormalization). 
See the logic here 78 | # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. 79 | if mode == "mean_std": 80 | buffer["mean"].data = stats[key]["mean"].clone() 81 | buffer["std"].data = stats[key]["std"].clone() 82 | elif mode == "min_max": 83 | buffer["min"].data = stats[key]["min"].clone() 84 | buffer["max"].data = stats[key]["max"].clone() 85 | 86 | stats_buffers[key] = buffer 87 | return stats_buffers 88 | 89 | 90 | def _no_stats_error_str(name: str) -> str: 91 | return ( 92 | f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a " 93 | "pretrained model." 94 | ) 95 | 96 | 97 | class Normalize(nn.Module): 98 | """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training.""" 99 | 100 | def __init__( 101 | self, 102 | shapes: dict[str, list[int]], 103 | modes: dict[str, str], 104 | stats: dict[str, dict[str, Tensor]] | None = None, 105 | ): 106 | """ 107 | Args: 108 | shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values 109 | are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing 110 | mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape 111 | is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. 112 | modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values 113 | are their normalization modes among: 114 | - "mean_std": subtract the mean and divide by standard deviation. 115 | - "min_max": map to [-1, 1] range. 116 | stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") 117 | and values are dictionaries of statistic types and their values (e.g. 118 | `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for 119 | training the model for the first time, these statistics will overwrite the default buffers. If 120 | not provided, as expected for finetuning or evaluation, the default buffers should to be 121 | overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the 122 | dataset is not needed to get the stats, since they are already in the policy state_dict. 123 | """ 124 | super().__init__() 125 | self.shapes = shapes 126 | self.modes = modes 127 | self.stats = stats 128 | stats_buffers = create_stats_buffers(shapes, modes, stats) 129 | for key, buffer in stats_buffers.items(): 130 | setattr(self, "buffer_" + key.replace(".", "_"), buffer) 131 | 132 | # TODO(rcadene): should we remove torch.no_grad? 
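    For reference, a minimal sketch of how `Normalize` might be instantiated and applied (illustrative numbers; real statistics typically come from the dataset stats or a pretrained checkpoint's `state_dict`, and the `common.` import root mirrors the imports used elsewhere in this repository):

    ```python
    import torch
    from common.policies.normalize import Normalize

    normalize = Normalize(
        shapes={"observation.state": [2]},
        modes={"observation.state": "mean_std"},
        stats={"observation.state": {"mean": torch.tensor([0.5, -0.2]), "std": torch.tensor([2.0, 1.0])}},
    )
    batch = normalize({"observation.state": torch.tensor([[2.5, 0.8]])})
    # batch["observation.state"] ≈ [[1.0, 1.0]], i.e. (x - mean) / (std + 1e-8)
    ```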
133 | @torch.no_grad 134 | def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: 135 | for key, mode in self.modes.items(): 136 | buffer = getattr(self, "buffer_" + key.replace(".", "_")) 137 | 138 | if mode == "mean_std": 139 | mean = buffer["mean"] 140 | std = buffer["std"] 141 | assert not torch.isinf(mean).any(), _no_stats_error_str("mean") 142 | assert not torch.isinf(std).any(), _no_stats_error_str("std") 143 | batch[key] = (batch[key] - mean) / (std + 1e-8) 144 | elif mode == "min_max": 145 | min = buffer["min"] 146 | max = buffer["max"] 147 | assert not torch.isinf(min).any(), _no_stats_error_str("min") 148 | assert not torch.isinf(max).any(), _no_stats_error_str("max") 149 | # normalize to [0,1] 150 | batch[key] = (batch[key] - min) / (max - min + 1e-8) 151 | # normalize to [-1, 1] 152 | batch[key] = batch[key] * 2 - 1 153 | else: 154 | raise ValueError(mode) 155 | return batch 156 | 157 | 158 | class Unnormalize(nn.Module): 159 | """ 160 | Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their 161 | original range used by the environment. 162 | """ 163 | 164 | def __init__( 165 | self, 166 | shapes: dict[str, list[int]], 167 | modes: dict[str, str], 168 | stats: dict[str, dict[str, Tensor]] | None = None, 169 | ): 170 | """ 171 | Args: 172 | shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values 173 | are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing 174 | mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape 175 | is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. 176 | modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values 177 | are their normalization modes among: 178 | - "mean_std": subtract the mean and divide by standard deviation. 179 | - "min_max": map to [-1, 1] range. 180 | stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") 181 | and values are dictionaries of statistic types and their values (e.g. 182 | `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for 183 | training the model for the first time, these statistics will overwrite the default buffers. If 184 | not provided, as expected for finetuning or evaluation, the default buffers should to be 185 | overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the 186 | dataset is not needed to get the stats, since they are already in the policy state_dict. 187 | """ 188 | super().__init__() 189 | self.shapes = shapes 190 | self.modes = modes 191 | self.stats = stats 192 | # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` 193 | stats_buffers = create_stats_buffers(shapes, modes, stats) 194 | for key, buffer in stats_buffers.items(): 195 | setattr(self, "buffer_" + key.replace(".", "_"), buffer) 196 | 197 | # TODO(rcadene): should we remove torch.no_grad? 
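    And a matching sketch of the "min_max" round trip, where `Unnormalize` maps policy outputs back to the environment's original action range (again with illustrative numbers):

    ```python
    import torch
    from common.policies.normalize import Normalize, Unnormalize

    shapes = {"action": [2]}
    modes = {"action": "min_max"}
    stats = {"action": {"min": torch.tensor([-1.0, 0.0]), "max": torch.tensor([1.0, 4.0])}}

    normalize = Normalize(shapes, modes, stats)
    unnormalize = Unnormalize(shapes, modes, stats)

    batch = normalize({"action": torch.tensor([[0.0, 1.0]])})   # -> [[0.0, -0.5]] in [-1, 1]
    batch = unnormalize(batch)                                  # -> ~[[0.0, 1.0]] (up to the 1e-8 epsilon)
    ```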
198 | @torch.no_grad 199 | def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: 200 | for key, mode in self.modes.items(): 201 | buffer = getattr(self, "buffer_" + key.replace(".", "_")) 202 | 203 | if mode == "mean_std": 204 | mean = buffer["mean"] 205 | std = buffer["std"] 206 | assert not torch.isinf(mean).any(), _no_stats_error_str("mean") 207 | assert not torch.isinf(std).any(), _no_stats_error_str("std") 208 | batch[key] = batch[key] * std + mean 209 | elif mode == "min_max": 210 | min = buffer["min"] 211 | max = buffer["max"] 212 | assert not torch.isinf(min).any(), _no_stats_error_str("min") 213 | assert not torch.isinf(max).any(), _no_stats_error_str("max") 214 | batch[key] = (batch[key] + 1) / 2 215 | batch[key] = batch[key] * (max - min) + min 216 | else: 217 | raise ValueError(mode) 218 | return batch 219 | -------------------------------------------------------------------------------- /itps/common/policies/diffusion/configuration_diffusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 Columbia Artificial Intelligence, Robotics Lab, 4 | # and The HuggingFace Inc. team. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | from dataclasses import dataclass, field 18 | 19 | 20 | @dataclass 21 | class DiffusionConfig: 22 | """Configuration class for DiffusionPolicy. 23 | 24 | Defaults are configured for training with PushT providing proprioceptive and single camera observations. 25 | 26 | The parameters you will most likely need to change are the ones which depend on the environment / sensors. 27 | Those are: `input_shapes` and `output_shapes`. 28 | 29 | Notes on the inputs and outputs: 30 | - "observation.state" is required as an input key. 31 | - Either: 32 | - At least one key starting with "observation.image is required as an input. 33 | AND/OR 34 | - The key "observation.environment_state" is required as input. 35 | - If there are multiple keys beginning with "observation.image" they are treated as multiple camera 36 | views. Right now we only support all images having the same shape. 37 | - "action" is required as an output key. 38 | 39 | Args: 40 | n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the 41 | current step and additional steps going back). 42 | horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. 43 | n_action_steps: The number of action steps to run in the environment for one invocation of the policy. 44 | See `DiffusionPolicy.select_action` for more details. 45 | input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents 46 | the input data name, and the value is a list indicating the dimensions of the corresponding data. 
47 | For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], 48 | indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't 49 | include batch dimension or temporal dimension. 50 | output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents 51 | the output data name, and the value is a list indicating the dimensions of the corresponding data. 52 | For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. 53 | Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. 54 | input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), 55 | and the value specifies the normalization mode to apply. The two available modes are "mean_std" 56 | which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a 57 | [-1, 1] range. 58 | output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the 59 | original scale. Note that this is also used for normalizing the training targets. 60 | vision_backbone: Name of the torchvision resnet backbone to use for encoding images. 61 | crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit 62 | within the image size. If None, no cropping is done. 63 | crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval 64 | mode). 65 | pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone. 66 | `None` means no pretrained weights. 67 | use_group_norm: Whether to replace batch normalization with group normalization in the backbone. 68 | The group sizes are set to be about 16 (to be precise, feature_dim // 16). 69 | spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax. 70 | down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet. 71 | You may provide a variable number of dimensions, therefore also controlling the degree of 72 | downsampling. 73 | kernel_size: The convolutional kernel size of the diffusion modeling Unet. 74 | n_groups: Number of groups used in the group norm of the Unet's convolutional blocks. 75 | diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear 76 | network. This is the output dimension of that network, i.e., the embedding dimension. 77 | use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning. 78 | Bias modulation is used be default, while this parameter indicates whether to also use scale 79 | modulation. 80 | noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"]. 81 | num_train_timesteps: Number of diffusion steps for the forward diffusion schedule. 82 | beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers. 83 | beta_start: Beta value for the first forward-diffusion step. 84 | beta_end: Beta value for the last forward-diffusion step. 85 | prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon" 86 | or "sample". These have equivalent outcomes from a latent variable modeling perspective, but 87 | "epsilon" has been shown to work better in many deep neural network settings. 
88 | clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each 89 | denoising step at inference time. WARNING: you will need to make sure your action-space is 90 | normalized to fit within this range. 91 | clip_sample_range: The magnitude of the clipping range as described above. 92 | num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly 93 | spaced). If not provided, this defaults to be the same as `num_train_timesteps`. 94 | do_mask_loss_for_padding: Whether to mask the loss when there are copy-padded actions. See 95 | `LeRobotDataset` and `load_previous_and_future_frames` for mor information. Note, this defaults 96 | to False as the original Diffusion Policy implementation does the same. 97 | """ 98 | 99 | # Inputs / output structure. 100 | n_obs_steps: int = 2 101 | horizon: int = 16 102 | n_action_steps: int = 8 103 | 104 | input_shapes: dict[str, list[int]] = field( 105 | default_factory=lambda: { 106 | "observation.image": [3, 96, 96], 107 | "observation.state": [2], 108 | } 109 | ) 110 | output_shapes: dict[str, list[int]] = field( 111 | default_factory=lambda: { 112 | "action": [2], 113 | } 114 | ) 115 | 116 | # Normalization / Unnormalization 117 | input_normalization_modes: dict[str, str] = field( 118 | default_factory=lambda: { 119 | "observation.image": "mean_std", 120 | "observation.state": "min_max", 121 | } 122 | ) 123 | output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"}) 124 | 125 | # Architecture / modeling. 126 | # Vision backbone. 127 | vision_backbone: str = "resnet18" 128 | crop_shape: tuple[int, int] | None = (84, 84) 129 | crop_is_random: bool = True 130 | pretrained_backbone_weights: str | None = None 131 | use_group_norm: bool = True 132 | spatial_softmax_num_keypoints: int = 32 133 | # Unet. 134 | down_dims: tuple[int, ...] = (512, 1024, 2048) 135 | kernel_size: int = 5 136 | n_groups: int = 8 137 | diffusion_step_embed_dim: int = 128 138 | use_film_scale_modulation: bool = True 139 | # Noise scheduler. 140 | noise_scheduler_type: str = "DDPM" 141 | num_train_timesteps: int = 100 142 | beta_schedule: str = "squaredcos_cap_v2" 143 | beta_start: float = 0.0001 144 | beta_end: float = 0.02 145 | prediction_type: str = "epsilon" 146 | clip_sample: bool = True 147 | clip_sample_range: float = 1.0 148 | 149 | # Inference 150 | num_inference_steps: int | None = None 151 | 152 | # Loss computation 153 | do_mask_loss_for_padding: bool = False 154 | 155 | def __post_init__(self): 156 | """Input validation (not exhaustive).""" 157 | if not self.vision_backbone.startswith("resnet"): 158 | raise ValueError( 159 | f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." 160 | ) 161 | 162 | image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} 163 | 164 | if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes: 165 | raise ValueError("You must provide at least one image or the environment state among the inputs.") 166 | 167 | if len(image_keys) > 0: 168 | if self.crop_shape is not None: 169 | for image_key in image_keys: 170 | if ( 171 | self.crop_shape[0] > self.input_shapes[image_key][1] 172 | or self.crop_shape[1] > self.input_shapes[image_key][2] 173 | ): 174 | raise ValueError( 175 | f"`crop_shape` should fit within `input_shapes[{image_key}]`. 
Got {self.crop_shape} " 176 | f"for `crop_shape` and {self.input_shapes[image_key]} for " 177 | "`input_shapes[{image_key}]`." 178 | ) 179 | # Check that all input images have the same shape. 180 | first_image_key = next(iter(image_keys)) 181 | for image_key in image_keys: 182 | if self.input_shapes[image_key] != self.input_shapes[first_image_key]: 183 | raise ValueError( 184 | f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we " 185 | "expect all image shapes to match." 186 | ) 187 | 188 | supported_prediction_types = ["epsilon", "sample"] 189 | if self.prediction_type not in supported_prediction_types: 190 | raise ValueError( 191 | f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}." 192 | ) 193 | supported_noise_schedulers = ["DDPM", "DDIM"] 194 | if self.noise_scheduler_type not in supported_noise_schedulers: 195 | raise ValueError( 196 | f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. " 197 | f"Got {self.noise_scheduler_type}." 198 | ) 199 | -------------------------------------------------------------------------------- /itps/common/policies/rollout_wrapper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import threading 3 | import time 4 | from concurrent.futures import ThreadPoolExecutor, TimeoutError 5 | from copy import deepcopy 6 | 7 | import torch 8 | from torch import Tensor 9 | 10 | from common.policies.policy_protocol import Policy 11 | from common.policies.utils import get_device_from_parameters 12 | 13 | 14 | class PolicyRolloutWrapper: 15 | """Use this wrapper around a policy that you plan to roll out in an environment. 16 | 17 | This wrapper bridges the gap between the world of policies where we use the RL formalism, and real world 18 | environments where time may be a factor. It may still be used on synchronous simulation environments. 19 | 20 | The policy is designed to take in a sequence of s observations ending at some time-step n and return a 21 | sequence of h actions starting from that timestep n. Formally: 22 | 23 | a_{n:n+h-1} = π(o_{n-s+1:n}) 24 | 25 | This wrapper manages observations in a cache according to their timestamps in order to give the policy the 26 | correct inputs for the desired action timestamps. 27 | 28 | When running in synchronous simulation environments, the logic reduces to simply: 29 | 30 | 1. Run inference synchronously to get a sequence of actions. Return the first action. 31 | 2. Keep returning actions from the previously generated sequence. 32 | 3. When the sequence is depleted, go back to 1. 33 | 34 | When simulating real-time or running in the real world, the logic does something more like: 35 | 36 | 1. Run inference synchronously to get a sequence of actions. Return the first action. 37 | 2. Keep returning actions from the previously generated sequence. 38 | 3. When the sequence has <= n_action_buffer actions in it, run asynchronous inference. In the meantime go 39 | back to 2. 40 | 41 | The implementation details are more involved and are documented inline. The main point though is that we 42 | can set n_action_buffer to be just large enough such that before we run out of actions, we already have a 43 | new sequence ready. 44 | """ 45 | 46 | def __init__(self, policy: Policy, fps: float, n_action_buffer: int = 0): 47 | """ 48 | Args: 49 | policy: The policy to wrap. 50 | fps: The observation/action clock frequency. 
It is assumed that these are the same, but it is 51 | possible for the phases of the clocks to be offset from one another. 52 | n_action_buffer: As soon as the action buffer has <= n_action_buffer actions left, start an 53 | inference run. 54 | """ 55 | self.policy = policy 56 | self.period_ms = int(round(1000 * (1 / fps))) 57 | # We'll allow half a clock cycle of tolerance on timestamp retrieval. 58 | self.timestamp_tolerance_ms = int(round(1000 * (1 / fps / 2))) 59 | self.n_action_buffer = n_action_buffer 60 | 61 | # Set up async related logic. 62 | self._threadpool_executor = ThreadPoolExecutor(max_workers=1) 63 | self._thread_lock = threading.Lock() 64 | 65 | self.reset() 66 | 67 | def __del__(self): 68 | """TODO(now): This isn't really working. Runtime exits with an exception.""" 69 | self._threadpool_executor.shutdown(wait=True, cancel_futures=True) 70 | 71 | def reset(self): 72 | """Reset observation and action caches. 73 | 74 | NOTE: Ensure that any access to these caches is within a thread-locked context as their state is 75 | managed in different threads. 76 | """ 77 | with self._thread_lock: 78 | # Store a mapping from observation timestamp (the moment the observation was captured) to 79 | # observation batch. 80 | self._observation_cache: dict[int, dict[str, Tensor]] = {} 81 | # Store a mapping from action timestamp (the moment the policy intends for the action to be) 82 | # executed. 83 | self._action_cache: dict[int, Tensor] = {} 84 | 85 | def invalidate_action_cache(self): 86 | with self._thread_lock: 87 | self._action_cache: dict[int, Tensor] = {} 88 | 89 | def _invalidate_obsolete_observations(self): 90 | """TODO(now)""" 91 | 92 | def run_inference( 93 | self, 94 | observation_timestamp_ms: int, 95 | action_timestamp_ms: int, 96 | strict_observation_timestamps: bool = False, 97 | guide: Tensor | None = None, 98 | visualizer=None, 99 | ): 100 | """ 101 | Construct an observation sequence from the observation cache, use that as input for running inference, 102 | and update the action cache with the result. 103 | """ 104 | # Stack relevant observations into a sequence. 105 | observation_timestamps_ms = torch.tensor( 106 | [action_timestamp_ms + i * self.period_ms for i in range(1 - self.policy.n_obs_steps, 1)] 107 | ) 108 | with self._thread_lock: 109 | observation_cache_timestamps_ms = torch.tensor(sorted(self._observation_cache.keys())) 110 | dist = torch.cdist( 111 | observation_timestamps_ms.unsqueeze(-1).float(), 112 | observation_cache_timestamps_ms.unsqueeze(-1).float(), 113 | p=1, 114 | ).int() 115 | min_, argmin_ = dist.min(dim=1) 116 | if torch.any(min_ > self.timestamp_tolerance_ms): 117 | msg = "Couldn't find observations within the required timestamp tolerance." 118 | if strict_observation_timestamps: 119 | raise RuntimeError(msg) 120 | else: 121 | # logging.warning(msg) 122 | pass 123 | with self._thread_lock: 124 | observation_sequence_batch = { 125 | k: torch.stack( 126 | [ 127 | self._observation_cache[ts.item()][k] 128 | for ts in observation_cache_timestamps_ms[argmin_] 129 | ], 130 | dim=1, 131 | ) 132 | for k in next(iter(self._observation_cache.values())) 133 | if k.startswith("observation") 134 | } 135 | 136 | # Forget any observations we won't be needing any more. 137 | self._invalidate_obsolete_observations() # TODO(now) 138 | 139 | # Run inference. 
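# (Worked example of the timestamp matching above, with illustrative numbers only: at fps=10 the
# period is 100 ms and the tolerance 50 ms; with n_obs_steps=2 and action_timestamp_ms=1000 the
# requested observation slots are [900, 1000] ms. A cache holding keys {905, 998} satisfies both
# slots (L1 distances of 5 and 2 ms), whereas a cache whose newest key is 940 misses the 1000 ms
# slot by 60 ms > 50 ms and falls into the warning / strict RuntimeError branch, depending on
# `strict_observation_timestamps`.)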
140 | device = get_device_from_parameters(self.policy) 141 | observation_sequence_batch = { 142 | key: observation_sequence_batch[key].to(device, non_blocking=True) 143 | for key in observation_sequence_batch 144 | } 145 | actions = self.policy.run_inference(observation_sequence_batch, guide=guide, visualizer=visualizer).cpu() # (batch, seq, action_dim) 146 | 147 | # Update action cache. 148 | with self._thread_lock: 149 | self._action_cache.update( 150 | { 151 | observation_timestamp_ms + i * self.period_ms: action 152 | for i, action in enumerate(actions.transpose(1, 0)) 153 | } 154 | ) 155 | 156 | def _get_contiguous_action_sequence_from_cache(self, first_action_timestamp_ms: float) -> Tensor | None: 157 | with self._thread_lock: 158 | action_cache = deepcopy(self._action_cache) 159 | if len(action_cache) == 0: 160 | return None 161 | action_cache_timestamps_ms = torch.tensor(sorted(action_cache)) 162 | action_timestamps_ms = ( 163 | torch.arange(0, action_cache_timestamps_ms.max() + self.period_ms, self.period_ms) 164 | + first_action_timestamp_ms 165 | ) 166 | dist = torch.cdist( 167 | action_timestamps_ms.unsqueeze(-1).float(), 168 | action_cache_timestamps_ms.unsqueeze(-1).float(), 169 | p=1, 170 | ).int() 171 | min_, argmin_ = dist.min(dim=1) 172 | if min_[0] > self.timestamp_tolerance_ms: 173 | return None 174 | # Get contiguous sequence of argmins_ starting from 0. 175 | where_jump = torch.where(argmin_.diff() != 1)[0] 176 | if len(where_jump) > 0: 177 | argmin_ = argmin_[: where_jump[0] + 1] 178 | # Retrieve and stack the actions. 179 | action_sequence = torch.stack( 180 | [action_cache[ts.item()] for ts in action_cache_timestamps_ms[argmin_]], 181 | dim=0, 182 | ) 183 | return action_sequence 184 | 185 | def provide_observation_get_actions( 186 | self, 187 | observation_batch: dict[str, Tensor], 188 | observation_timestamp: float, 189 | first_action_timestamp: float, 190 | strict_observation_timestamps: bool = False, 191 | timeout: float | None = None, 192 | guide: Tensor | None = None, 193 | visualizer=None, 194 | ) -> Tensor | None: 195 | """Provide an observation and get an action sequence back. 196 | 197 | This method does several things: 198 | 1. Accepts an observation with a timestamp. This is added to the observation cache. 199 | 2. Runs inference either synchronously or asynchronously. 200 | 3. Returns a sequence of actions starting from the requested timestamp (from a cache which is 201 | populated by inference outputs). 202 | 203 | If `timeout` is not provided, inference is run synchronously. If `timeout` is provided, inference runs 204 | asynchronously, and the function aims to return either an action sequence or None within the timeout 205 | period. It is guaranteed that if the timeout is not honored, a RuntimeError is raised (TODO(now)). 206 | 207 | All time related arguments are to be provided in units of seconds, relative to an arbitrary reference 208 | point which is fixed throughout the duration of the rollout. 209 | 210 | Args: 211 | observation_batch: Mapping of observation type key to observation batch tensor for a single time 212 | step. 213 | observation_timestamp: The timestamp associated with observation_batch. It should be as faithful 214 | as possible to the time at which the observation was captured. 215 | first_action_timestamp: The timestamp of the first action in the requested action sequence. 
216 | strict_observation_timestamps: Whether to raise a RuntimeError if there are no observations in the
217 | cache with the timestamps needed to construct the inference inputs (i.e. there are no
218 | observations within `self.timestamp_tolerance_ms`).
219 | Returns:
220 | A (sequence, batch, action_dim) tensor for a sequence of actions starting from the requested
221 | `first_action_timestamp` and spaced by `1/fps`, or None if the `timeout` is reached and there is no
222 | first action available.
223 | """
224 | start = time.perf_counter()
225 | # Immediately convert timestamps to integer milliseconds (so that hashing them for the cache keys
226 | # isn't susceptible to floating point issues).
227 | observation_timestamp_ms = int(round(observation_timestamp * 1000))
228 | del observation_timestamp # defensive against accidentally using the seconds version
229 | first_action_timestamp_ms = int(round(first_action_timestamp * 1000))
230 | del first_action_timestamp # defensive against accidentally using the seconds version
231 |
232 | # Update observation cache.
233 | # if not set(observation_batch).issubset(self.policy.input_keys):
234 | # raise ValueError(
235 | # f"Missing observation_keys: {set(self.policy.input_keys).difference(set(observation_batch))}"
236 | # )
237 | with self._thread_lock:
238 | self._observation_cache[observation_timestamp_ms] = observation_batch
239 |
240 | ret = None # placeholder for this function's return value
241 |
242 | # Try retrieving an action sequence from the cache starting from `first_action_timestamp` and spaced
243 | # by `1 / fps`. While doing so remove stale actions (those which are older and outside tolerance).
244 | with self._thread_lock:
245 | action_cache_timestamps_ms = torch.tensor(sorted(self._action_cache))
246 | if len(action_cache_timestamps_ms) > 0:
247 | diff = action_cache_timestamps_ms - first_action_timestamp_ms
248 | to_delete = torch.where(torch.bitwise_and(diff < 0, diff.abs() > self.timestamp_tolerance_ms))[0]
249 | for ix in to_delete:
250 | with self._thread_lock:
251 | del self._action_cache[action_cache_timestamps_ms[ix.item()].item()]
252 | # If the first action is in the cache, construct the action sequence.
253 | if diff.abs().min() <= self.timestamp_tolerance_ms:
254 | ret = self._get_contiguous_action_sequence_from_cache(first_action_timestamp_ms)
255 |
256 | if first_action_timestamp_ms < observation_timestamp_ms:
257 | raise RuntimeError("No action could be found in the cache, and we can't generate a past action.")
258 |
259 | # We would like to run inference if we don't have many actions left in the cache.
260 | want_to_run_inference = ret is None or (ret is not None and ret.shape[0] - 1 <= self.n_action_buffer)
261 | # Return the cached actions right away if we know we don't want to run inference.
262 | if not want_to_run_inference:
263 | return ret
264 |
265 | # We can't run inference if a previous inference is already running.
266 | if hasattr(self, "_future") and self._future.running():
267 | # Try to give the previous inference a chance to finish (within the allowable time limit).
268 | try:
269 | # TODO(now): the 1e-3 needs explaining
270 | timeout_ = None if timeout is None else timeout - (time.perf_counter() - start) - 1e-3
271 | self._future.result(timeout=timeout_)
272 | except TimeoutError:
273 | if ret is None:
274 | logging.warning("Your inference is beginning to fall behind your rollout loop!")
275 | return ret
276 |
277 | # Start the inference job.
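# The submission below goes to the single-worker executor created in __init__, so at most one
# inference runs at a time. Note the timing implied by `want_to_run_inference` above: a new job is
# submitted once the contiguous cached sequence is down to `n_action_buffer + 1` actions or fewer,
# so at fps=10 roughly (n_action_buffer + 1) * 100 ms remain for it to land before the cache runs
# dry. A hypothetical caller (names and values here are illustrative only) might look like:
#
#     wrapper = PolicyRolloutWrapper(policy, fps=10, n_action_buffer=2)
#     actions = wrapper.provide_observation_get_actions(
#         obs_batch,
#         observation_timestamp=t,
#         first_action_timestamp=t,
#         timeout=1 / 10,  # stay real-time; may return None while inference lags
#     )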
278 | self._future = self._threadpool_executor.submit( 279 | self.run_inference, 280 | observation_timestamp_ms, 281 | first_action_timestamp_ms, 282 | strict_observation_timestamps, 283 | guide, 284 | visualizer, 285 | ) 286 | 287 | # Attempt to wait for inference to complete, within the bounds of the `timeout` parameter. 288 | try: 289 | if timeout is None: 290 | self._future.result() 291 | else: 292 | elapsed = time.perf_counter() - start 293 | buffer = 1e-3 # ample buffer for the rest of the function to complete 294 | timeout_ = timeout - elapsed - buffer 295 | if timeout_ > 0: 296 | self._future.result(timeout=timeout_) 297 | except TimeoutError: 298 | pass 299 | 300 | # Return the actions we had extracted from the cache before starting inference (only if inference 301 | # is still running, otherwise we can get actions from the fresher cache). 302 | if self._future.running() and ret is not None: 303 | return ret 304 | 305 | return self._get_contiguous_action_sequence_from_cache(first_action_timestamp_ms) 306 | -------------------------------------------------------------------------------- /itps/common/datasets/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import json 17 | import re 18 | from pathlib import Path 19 | from typing import Dict 20 | 21 | import datasets 22 | import torch 23 | from datasets import load_dataset, load_from_disk 24 | from huggingface_hub import hf_hub_download, snapshot_download 25 | from PIL import Image as PILImage 26 | from safetensors.torch import load_file 27 | from torchvision import transforms 28 | 29 | 30 | def flatten_dict(d, parent_key="", sep="/"): 31 | """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. 32 | 33 | For example: 34 | ``` 35 | >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}` 36 | >>> print(flatten_dict(dct)) 37 | {"a/b": 1, "a/c/d": 2, "e": 3} 38 | """ 39 | items = [] 40 | for k, v in d.items(): 41 | new_key = f"{parent_key}{sep}{k}" if parent_key else k 42 | if isinstance(v, dict): 43 | items.extend(flatten_dict(v, new_key, sep=sep).items()) 44 | else: 45 | items.append((new_key, v)) 46 | return dict(items) 47 | 48 | 49 | def unflatten_dict(d, sep="/"): 50 | outdict = {} 51 | for key, value in d.items(): 52 | parts = key.split(sep) 53 | d = outdict 54 | for part in parts[:-1]: 55 | if part not in d: 56 | d[part] = {} 57 | d = d[part] 58 | d[parts[-1]] = value 59 | return outdict 60 | 61 | 62 | def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): 63 | """Get a transform function that convert items from Hugging Face dataset (pyarrow) 64 | to torch tensors. 
Importantly, images are converted from PIL, which corresponds to 65 | a channel last representation (h w c) of uint8 type, to a torch image representation 66 | with channel first (c h w) of float32 type in range [0,1]. 67 | """ 68 | for key in items_dict: 69 | first_item = items_dict[key][0] 70 | if isinstance(first_item, PILImage.Image): 71 | to_tensor = transforms.ToTensor() 72 | items_dict[key] = [to_tensor(img) for img in items_dict[key]] 73 | elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item: 74 | # video frame will be processed downstream 75 | pass 76 | elif first_item is None: 77 | pass 78 | else: 79 | items_dict[key] = [torch.tensor(x) for x in items_dict[key]] 80 | return items_dict 81 | 82 | 83 | def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset: 84 | """hf_dataset contains all the observations, states, actions, rewards, etc.""" 85 | if root is not None: 86 | hf_dataset = load_from_disk(str(Path(root) / repo_id / "train")) 87 | # TODO(rcadene): clean this which enables getting a subset of dataset 88 | if split != "train": 89 | if "%" in split: 90 | raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).") 91 | match_from = re.search(r"train\[(\d+):\]", split) 92 | match_to = re.search(r"train\[:(\d+)\]", split) 93 | if match_from: 94 | from_frame_index = int(match_from.group(1)) 95 | hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset))) 96 | elif match_to: 97 | to_frame_index = int(match_to.group(1)) 98 | hf_dataset = hf_dataset.select(range(to_frame_index)) 99 | else: 100 | raise ValueError( 101 | f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"' 102 | ) 103 | else: 104 | hf_dataset = load_dataset(repo_id, revision=version, split=split) 105 | hf_dataset.set_transform(hf_transform_to_torch) 106 | return hf_dataset 107 | 108 | 109 | def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]: 110 | """episode_data_index contains the range of indices for each episode 111 | 112 | Example: 113 | ```python 114 | from_id = episode_data_index["from"][episode_id].item() 115 | to_id = episode_data_index["to"][episode_id].item() 116 | episode_frames = [dataset[i] for i in range(from_id, to_id)] 117 | ``` 118 | """ 119 | if root is not None: 120 | path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors" 121 | else: 122 | path = hf_hub_download( 123 | repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version 124 | ) 125 | 126 | return load_file(path) 127 | 128 | 129 | def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]: 130 | """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std 131 | 132 | Example: 133 | ```python 134 | normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"] 135 | ``` 136 | """ 137 | if root is not None: 138 | path = Path(root) / repo_id / "meta_data" / "stats.safetensors" 139 | else: 140 | path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version) 141 | 142 | stats = load_file(path) 143 | return unflatten_dict(stats) 144 | 145 | 146 | def load_info(repo_id, version, root) -> dict: 147 | """info contains useful information regarding the dataset that are not stored elsewhere 148 | 149 | Example: 150 | ```python 151 | print("frame per second used to collect the video", info["fps"]) 152 | ``` 153 | """ 154 | if root is not 
None: 155 | path = Path(root) / repo_id / "meta_data" / "info.json" 156 | else: 157 | path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version) 158 | 159 | with open(path) as f: 160 | info = json.load(f) 161 | return info 162 | 163 | 164 | def load_videos(repo_id, version, root) -> Path: 165 | if root is not None: 166 | path = Path(root) / repo_id / "videos" 167 | else: 168 | # TODO(rcadene): we download the whole repo here. see if we can avoid this 169 | repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=version) 170 | path = Path(repo_dir) / "videos" 171 | 172 | return path 173 | 174 | 175 | def load_previous_and_future_frames( 176 | item: dict[str, torch.Tensor], 177 | hf_dataset: datasets.Dataset, 178 | episode_data_index: dict[str, torch.Tensor], 179 | delta_timestamps: dict[str, list[float]], 180 | tolerance_s: float, 181 | ) -> dict[torch.Tensor]: 182 | """ 183 | Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of 184 | some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each 185 | given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest 186 | frames in the dataset. 187 | 188 | Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function 189 | raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after 190 | the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function 191 | populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array 192 | is useful during batched training to not supervise actions associated to timestamps coming after the end of the 193 | episode, or to pad the observations in a specific way. Note that by default the observation frames before the start 194 | of the episode are the same as the first frame of the episode. 195 | 196 | Parameters: 197 | - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key 198 | corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). 199 | - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different 200 | modality (e.g., "timestamp", "observation.image", "action"). 201 | - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices. 202 | They indicate the start index and end index of each episode in the dataset. 203 | - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be 204 | retrieved. These deltas are added to the item timestamp to form the query timestamps. 205 | - tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query 206 | timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the 207 | smallest expected inter-frame period, but large enough to account for jitter. 208 | 209 | Returns: 210 | - The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for 211 | each modality (e.g. "observation.image_is_pad"). 
212 | 213 | Raises: 214 | - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization 215 | issues with timestamps during data collection. 216 | """ 217 | # get indices of the frames associated to the episode, and their timestamps 218 | ep_id = item["episode_index"].item() 219 | ep_data_id_from = episode_data_index["from"][ep_id].item() 220 | ep_data_id_to = episode_data_index["to"][ep_id].item() 221 | ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1) 222 | 223 | # load timestamps 224 | ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"] 225 | ep_timestamps = torch.stack(ep_timestamps) 226 | 227 | # we make the assumption that the timestamps are sorted 228 | ep_first_ts = ep_timestamps[0] 229 | ep_last_ts = ep_timestamps[-1] 230 | current_ts = item["timestamp"].item() 231 | 232 | for key in delta_timestamps: 233 | # get timestamps used as query to retrieve data of previous/future frames 234 | delta_ts = delta_timestamps[key] 235 | query_ts = current_ts + torch.tensor(delta_ts) 236 | 237 | # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode 238 | dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1) 239 | min_, argmin_ = dist.min(1) 240 | 241 | # TODO(rcadene): synchronize timestamps + interpolation if needed 242 | 243 | is_pad = min_ > tolerance_s 244 | 245 | # check violated query timestamps are all outside the episode range 246 | assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), ( 247 | f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range." 248 | "This might be due to synchronization issues with timestamps during data collection." 249 | ) 250 | 251 | # get dataset indices corresponding to frames to be loaded 252 | data_ids = ep_data_ids[argmin_] 253 | 254 | # load frames modality 255 | item[key] = hf_dataset.select_columns(key)[data_ids][key] 256 | 257 | if isinstance(item[key][0], dict) and "path" in item[key][0]: 258 | # video mode where frame are expressed as dict of path and timestamp 259 | item[key] = item[key] 260 | else: 261 | item[key] = torch.stack(item[key]) 262 | 263 | item[f"{key}_is_pad"] = is_pad 264 | 265 | return item 266 | 267 | 268 | def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]: 269 | """ 270 | Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. 271 | 272 | Parameters: 273 | - hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index. 274 | 275 | Returns: 276 | - episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys: 277 | - "from": A tensor containing the starting index of each episode. 278 | - "to": A tensor containing the ending index of each episode. 279 | """ 280 | episode_data_index = {"from": [], "to": []} 281 | 282 | current_episode = None 283 | """ 284 | The episode_index is a list of integers, each representing the episode index of the corresponding example. 285 | For instance, the following is a valid episode_index: 286 | [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2] 287 | 288 | Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and 289 | ending index of each episode. 
For the episode_index above, the episode_data_index dictionary will look like this: 290 | { 291 | "from": [0, 3, 7], 292 | "to": [3, 7, 12] 293 | } 294 | """ 295 | if len(hf_dataset) == 0: 296 | episode_data_index = { 297 | "from": torch.tensor([]), 298 | "to": torch.tensor([]), 299 | } 300 | return episode_data_index 301 | for idx, episode_idx in enumerate(hf_dataset["episode_index"]): 302 | if episode_idx != current_episode: 303 | # We encountered a new episode, so we append its starting location to the "from" list 304 | episode_data_index["from"].append(idx) 305 | # If this is not the first episode, we append the ending location of the previous episode to the "to" list 306 | if current_episode is not None: 307 | episode_data_index["to"].append(idx) 308 | # Let's keep track of the current episode index 309 | current_episode = episode_idx 310 | else: 311 | # We are still in the same episode, so there is nothing for us to do here 312 | pass 313 | # We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list 314 | episode_data_index["to"].append(idx + 1) 315 | 316 | for k in ["from", "to"]: 317 | episode_data_index[k] = torch.tensor(episode_data_index[k]) 318 | 319 | return episode_data_index 320 | 321 | 322 | def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset: 323 | """Reset the `episode_index` of the provided HuggingFace Dataset. 324 | 325 | `episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the 326 | `episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0. 327 | 328 | This brings the `episode_index` to the required format. 329 | """ 330 | if len(hf_dataset) == 0: 331 | return hf_dataset 332 | unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist() 333 | episode_idx_to_reset_idx_mapping = { 334 | ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs) 335 | } 336 | 337 | def modify_ep_idx_func(example): 338 | example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()] 339 | return example 340 | 341 | hf_dataset = hf_dataset.map(modify_ep_idx_func) 342 | 343 | return hf_dataset 344 | 345 | 346 | def cycle(iterable): 347 | """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. 348 | 349 | See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe. 350 | """ 351 | iterator = iter(iterable) 352 | while True: 353 | try: 354 | yield next(iterator) 355 | except StopIteration: 356 | iterator = iter(iterable) 357 | -------------------------------------------------------------------------------- /itps/common/datasets/lerobot_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
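# Usage sketch (illustrative only; the repo id, fps and keys below are assumptions, not values
# shipped with this repository). `delta_timestamps` is forwarded to
# `load_previous_and_future_frames` (see common/datasets/utils.py above) to assemble observation
# histories and action horizons around each sampled frame:
#
#     from common.datasets.lerobot_dataset import LeRobotDataset
#
#     dataset = LeRobotDataset(
#         "lerobot/pusht",  # hypothetical repo id
#         delta_timestamps={
#             "observation.state": [-0.1, 0.0],        # 2 observation steps at 10 fps
#             "action": [i / 10 for i in range(16)],   # 16-step action horizon
#         },
#     )
#     item = dataset[0]
#     # item["action"] has shape (16, action_dim); item["action_is_pad"] marks steps that fall
#     # outside the episode and were copy-padded (cf. `do_mask_loss_for_padding`).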
16 | import logging 17 | import os 18 | from pathlib import Path 19 | from typing import Callable 20 | 21 | import datasets 22 | import torch 23 | import torch.utils 24 | 25 | from common.datasets.compute_stats import aggregate_stats 26 | from common.datasets.utils import ( 27 | calculate_episode_data_index, 28 | load_episode_data_index, 29 | load_hf_dataset, 30 | load_info, 31 | load_previous_and_future_frames, 32 | load_stats, 33 | load_videos, 34 | reset_episode_index, 35 | ) 36 | from common.datasets.video_utils import VideoFrame, load_from_videos 37 | 38 | DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None 39 | CODEBASE_VERSION = "v1.5" 40 | 41 | 42 | class LeRobotDataset(torch.utils.data.Dataset): 43 | def __init__( 44 | self, 45 | repo_id: str, 46 | version: str | None = CODEBASE_VERSION, 47 | root: Path | None = DATA_DIR, 48 | split: str = "train", 49 | image_transforms: Callable | None = None, 50 | delta_timestamps: dict[list[float]] | None = None, 51 | video_backend: str | None = None, 52 | ): 53 | super().__init__() 54 | self.repo_id = repo_id 55 | self.version = version 56 | self.root = root 57 | self.split = split 58 | self.image_transforms = image_transforms 59 | self.delta_timestamps = delta_timestamps 60 | # load data from hub or locally when root is provided 61 | # TODO(rcadene, aliberts): implement faster transfer 62 | # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads 63 | self.hf_dataset = load_hf_dataset(repo_id, version, root, split) 64 | if split == "train": 65 | self.episode_data_index = load_episode_data_index(repo_id, version, root) 66 | else: 67 | self.episode_data_index = calculate_episode_data_index(self.hf_dataset) 68 | self.hf_dataset = reset_episode_index(self.hf_dataset) 69 | self.stats = load_stats(repo_id, version, root) 70 | self.info = load_info(repo_id, version, root) 71 | if self.video: 72 | self.videos_dir = load_videos(repo_id, version, root) 73 | self.video_backend = video_backend if video_backend is not None else "pyav" 74 | 75 | @property 76 | def fps(self) -> int: 77 | """Frames per second used during data collection.""" 78 | return self.info["fps"] 79 | 80 | @property 81 | def video(self) -> bool: 82 | """Returns True if this dataset loads video frames from mp4 files. 83 | Returns False if it only loads images from png files. 84 | """ 85 | return self.info.get("video", False) 86 | 87 | @property 88 | def features(self) -> datasets.Features: 89 | return self.hf_dataset.features 90 | 91 | @property 92 | def camera_keys(self) -> list[str]: 93 | """Keys to access image and video stream from cameras.""" 94 | keys = [] 95 | for key, feats in self.hf_dataset.features.items(): 96 | if isinstance(feats, (datasets.Image, VideoFrame)): 97 | keys.append(key) 98 | return keys 99 | 100 | @property 101 | def video_frame_keys(self) -> list[str]: 102 | """Keys to access video frames that requires to be decoded into images. 103 | 104 | Note: It is empty if the dataset contains images only, 105 | or equal to `self.cameras` if the dataset contains videos only, 106 | or can even be a subset of `self.cameras` in a case of a mixed image/video dataset. 
107 | """ 108 | video_frame_keys = [] 109 | for key, feats in self.hf_dataset.features.items(): 110 | if isinstance(feats, VideoFrame): 111 | video_frame_keys.append(key) 112 | return video_frame_keys 113 | 114 | @property 115 | def num_samples(self) -> int: 116 | """Number of samples/frames.""" 117 | return len(self.hf_dataset) 118 | 119 | @property 120 | def num_episodes(self) -> int: 121 | """Number of episodes.""" 122 | return len(self.hf_dataset.unique("episode_index")) 123 | 124 | @property 125 | def tolerance_s(self) -> float: 126 | """Tolerance in seconds used to discard loaded frames when their timestamps 127 | are not close enough from the requested frames. It is only used when `delta_timestamps` 128 | is provided or when loading video frames from mp4 files. 129 | """ 130 | # 1e-4 to account for possible numerical error 131 | return 1 / self.fps - 1e-4 132 | 133 | def __len__(self): 134 | return self.num_samples 135 | 136 | def __getitem__(self, idx): 137 | item = self.hf_dataset[idx] 138 | 139 | if self.delta_timestamps is not None: 140 | item = load_previous_and_future_frames( 141 | item, 142 | self.hf_dataset, 143 | self.episode_data_index, 144 | self.delta_timestamps, 145 | self.tolerance_s, 146 | ) 147 | 148 | if self.video: 149 | item = load_from_videos( 150 | item, 151 | self.video_frame_keys, 152 | self.videos_dir, 153 | self.tolerance_s, 154 | self.video_backend, 155 | ) 156 | 157 | if self.image_transforms is not None: 158 | for cam in self.camera_keys: 159 | item[cam] = self.image_transforms(item[cam]) 160 | 161 | return item 162 | 163 | def __repr__(self): 164 | return ( 165 | f"{self.__class__.__name__}(\n" 166 | f" Repository ID: '{self.repo_id}',\n" 167 | f" Version: '{self.version}',\n" 168 | f" Split: '{self.split}',\n" 169 | f" Number of Samples: {self.num_samples},\n" 170 | f" Number of Episodes: {self.num_episodes},\n" 171 | f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" 172 | f" Recorded Frames per Second: {self.fps},\n" 173 | f" Camera Keys: {self.camera_keys},\n" 174 | f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" 175 | f" Transformations: {self.image_transforms},\n" 176 | f")" 177 | ) 178 | 179 | @classmethod 180 | def from_preloaded( 181 | cls, 182 | repo_id: str = "from_preloaded", 183 | version: str | None = CODEBASE_VERSION, 184 | root: Path | None = None, 185 | split: str = "train", 186 | transform: callable = None, 187 | delta_timestamps: dict[list[float]] | None = None, 188 | # additional preloaded attributes 189 | hf_dataset=None, 190 | episode_data_index=None, 191 | stats=None, 192 | info=None, 193 | videos_dir=None, 194 | video_backend=None, 195 | ) -> "LeRobotDataset": 196 | """Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem. 197 | 198 | It is especially useful when converting raw data into LeRobotDataset before saving the dataset 199 | on the filesystem or uploading to the hub. 200 | 201 | Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially 202 | meaningless depending on the downstream usage of the return dataset. 
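Example (illustrative sketch; the values passed here are placeholders, not defaults of this repo):

    hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, "train")
    dataset = LeRobotDataset.from_preloaded(
        hf_dataset=hf_dataset,
        episode_data_index=calculate_episode_data_index(hf_dataset),
        info={"fps": 10, "video": False},
    )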
203 | """ 204 | # create an empty object of type LeRobotDataset 205 | obj = cls.__new__(cls) 206 | obj.repo_id = repo_id 207 | obj.version = version 208 | obj.root = root 209 | obj.split = split 210 | obj.image_transforms = transform 211 | obj.delta_timestamps = delta_timestamps 212 | obj.hf_dataset = hf_dataset 213 | obj.episode_data_index = episode_data_index 214 | obj.stats = stats 215 | obj.info = info if info is not None else {} 216 | obj.videos_dir = videos_dir 217 | obj.video_backend = video_backend if video_backend is not None else "pyav" 218 | return obj 219 | 220 | 221 | class MultiLeRobotDataset(torch.utils.data.Dataset): 222 | """A dataset consisting of multiple underlying `LeRobotDataset`s. 223 | 224 | The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API 225 | structure of `LeRobotDataset`. 226 | """ 227 | 228 | def __init__( 229 | self, 230 | repo_ids: list[str], 231 | version: str | None = CODEBASE_VERSION, 232 | root: Path | None = DATA_DIR, 233 | split: str = "train", 234 | image_transforms: Callable | None = None, 235 | delta_timestamps: dict[list[float]] | None = None, 236 | video_backend: str | None = None, 237 | ): 238 | super().__init__() 239 | self.repo_ids = repo_ids 240 | # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which 241 | # are handled by this class. 242 | self._datasets = [ 243 | LeRobotDataset( 244 | repo_id, 245 | version=version, 246 | root=root, 247 | split=split, 248 | delta_timestamps=delta_timestamps, 249 | image_transforms=image_transforms, 250 | video_backend=video_backend, 251 | ) 252 | for repo_id in repo_ids 253 | ] 254 | # Check that some properties are consistent across datasets. Note: We may relax some of these 255 | # consistency requirements in future iterations of this class. 256 | for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True): 257 | if dataset.info != self._datasets[0].info: 258 | raise ValueError( 259 | f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is " 260 | "not yet supported." 261 | ) 262 | # Disable any data keys that are not common across all of the datasets. Note: we may relax this 263 | # restriction in future iterations of this class. For now, this is necessary at least for being able 264 | # to use PyTorch's default DataLoader collate function. 265 | self.disabled_data_keys = set() 266 | intersection_data_keys = set(self._datasets[0].hf_dataset.features) 267 | for dataset in self._datasets: 268 | intersection_data_keys.intersection_update(dataset.hf_dataset.features) 269 | if len(intersection_data_keys) == 0: 270 | raise RuntimeError( 271 | "Multiple datasets were provided but they had no keys common to all of them. The " 272 | "multi-dataset functionality currently only keeps common keys." 273 | ) 274 | for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True): 275 | extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys) 276 | logging.warning( 277 | f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " 278 | "other datasets." 
279 | )
280 | self.disabled_data_keys.update(extra_keys)
281 |
282 | self.version = version
283 | self.root = root
284 | self.split = split
285 | self.image_transforms = image_transforms
286 | self.delta_timestamps = delta_timestamps
287 | self.stats = aggregate_stats(self._datasets)
288 |
289 | @property
290 | def repo_id_to_index(self):
291 | """Return a mapping from dataset repo_id to a dataset index automatically created by this class.
292 |
293 | This index is incorporated as a data key in the dictionary returned by `__getitem__`.
294 | """
295 | return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
296 |
297 | @property
298 | def repo_index_to_id(self):
299 | """Return the inverse mapping of repo_id_to_index."""
300 | return {v: k for k, v in self.repo_id_to_index.items()}
301 |
302 | @property
303 | def fps(self) -> int:
304 | """Frames per second used during data collection.
305 |
306 | NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
307 | """
308 | return self._datasets[0].info["fps"]
309 |
310 | @property
311 | def video(self) -> bool:
312 | """Returns True if this dataset loads video frames from mp4 files.
313 |
314 | Returns False if it only loads images from png files.
315 |
316 | NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
317 | """
318 | return self._datasets[0].info.get("video", False)
319 |
320 | @property
321 | def features(self) -> datasets.Features:
322 | features = {}
323 | for dataset in self._datasets:
324 | features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
325 | return features
326 |
327 | @property
328 | def camera_keys(self) -> list[str]:
329 | """Keys to access image and video streams from cameras."""
330 | keys = []
331 | for key, feats in self.features.items():
332 | if isinstance(feats, (datasets.Image, VideoFrame)):
333 | keys.append(key)
334 | return keys
335 |
336 | @property
337 | def video_frame_keys(self) -> list[str]:
338 | """Keys to access video frames that need to be decoded into images.
339 |
340 | Note: It is empty if the dataset contains images only,
341 | or equal to `self.cameras` if the dataset contains videos only,
342 | or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
343 | """
344 | video_frame_keys = []
345 | for key, feats in self.features.items():
346 | if isinstance(feats, VideoFrame):
347 | video_frame_keys.append(key)
348 | return video_frame_keys
349 |
350 | @property
351 | def num_samples(self) -> int:
352 | """Number of samples/frames."""
353 | return sum(d.num_samples for d in self._datasets)
354 |
355 | @property
356 | def num_episodes(self) -> int:
357 | """Number of episodes."""
358 | return sum(d.num_episodes for d in self._datasets)
359 |
360 | @property
361 | def tolerance_s(self) -> float:
362 | """Tolerance in seconds used to discard loaded frames when their timestamps
363 | are not close enough to the requested frames. It is only used when `delta_timestamps`
364 | is provided or when loading video frames from mp4 files.
365 | """
366 | # 1e-4 to account for possible numerical error
367 | return 1 / self.fps - 1e-4
368 |
369 | def __len__(self):
370 | return self.num_samples
371 |
372 | def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
373 | if idx >= len(self):
374 | raise IndexError(f"Index {idx} out of bounds.")
375 | # Determine which dataset to get an item from based on the index.
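# For example, with two sub-datasets of 100 and 50 samples, idx=120 skips the first dataset
# (start_idx becomes 100, dataset_idx becomes 1) and resolves to local index 120 - 100 = 20 in the
# second one. The for/else raises only if the loop never breaks, which cannot happen for a valid
# idx thanks to the bounds check above.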
376 | start_idx = 0 377 | dataset_idx = 0 378 | for dataset in self._datasets: 379 | if idx >= start_idx + dataset.num_samples: 380 | start_idx += dataset.num_samples 381 | dataset_idx += 1 382 | continue 383 | break 384 | else: 385 | raise AssertionError("We expect the loop to break out as long as the index is within bounds.") 386 | item = self._datasets[dataset_idx][idx - start_idx] 387 | item["dataset_index"] = torch.tensor(dataset_idx) 388 | for data_key in self.disabled_data_keys: 389 | if data_key in item: 390 | del item[data_key] 391 | 392 | return item 393 | 394 | def __repr__(self): 395 | return ( 396 | f"{self.__class__.__name__}(\n" 397 | f" Repository IDs: '{self.repo_ids}',\n" 398 | f" Version: '{self.version}',\n" 399 | f" Split: '{self.split}',\n" 400 | f" Number of Samples: {self.num_samples},\n" 401 | f" Number of Episodes: {self.num_episodes},\n" 402 | f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" 403 | f" Recorded Frames per Second: {self.fps},\n" 404 | f" Camera Keys: {self.camera_keys},\n" 405 | f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" 406 | f" Transformations: {self.image_transforms},\n" 407 | f")" 408 | ) 409 | -------------------------------------------------------------------------------- /itps/interact_maze2d.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | # Copyright (c) 2024 Yanwei Wang 4 | 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # Some of this software is derived from LeRobot, which is subject to the following copyright notice: 24 | 25 | # Copyright 2024 Columbia Artificial Intelligence, Robotics Lab, 26 | # Tony Z. Zhao 27 | # and The HuggingFace Inc. team. All rights reserved. 28 | # Licensed under the Apache License, Version 2.0 (the "License"); 29 | # you may not use this file except in compliance with the License. 30 | # You may obtain a copy of the License at 31 | 32 | # http://www.apache.org/licenses/LICENSE-2.0 33 | 34 | # Unless required by applicable law or agreed to in writing, software 35 | # distributed under the License is distributed on an "AS IS" BASIS, 36 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 37 | # See the License for the specific language governing permissions and 38 | # limitations under the License. 
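# Coordinate conventions used throughout this script (see MazeEnv.gui2xy / xy2gui below): the GUI
# works in pixels on a 1200x900 window, while the policy and the maze array work in (row, col)
# cell coordinates of the 9x12 occupancy grid, offset by 0.5 so that integer coordinates sit at
# cell centers. A quick round-trip check (numbers derived from those formulas, shown purely for
# illustration):
#
#     env = MazeEnv()
#     env.gui2xy(np.array([150, 150]))   # -> array([1., 1.]), the center of maze cell [1][1]
#     env.xy2gui(np.array([1.0, 1.0]))   # -> array([150., 150.])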
39 | 40 | 41 | import sys, os 42 | import numpy as np 43 | import pygame 44 | import torch 45 | import argparse 46 | import matplotlib.pyplot as plt 47 | import einops 48 | from pathlib import Path 49 | from huggingface_hub import snapshot_download 50 | from common.policies.diffusion.modeling_diffusion import DiffusionPolicy 51 | from common.policies.act.modeling_act import ACTPolicy 52 | from common.policies.rollout_wrapper import PolicyRolloutWrapper 53 | from common.utils.utils import seeded_context, init_hydra_config 54 | from common.policies.factory import make_policy 55 | from common.datasets.factory import make_dataset 56 | from scipy.special import softmax 57 | import time 58 | import json 59 | 60 | class MazeEnv: 61 | def __init__(self): 62 | # GUI x coord 0 -> gui_size[0] #1200 63 | # GUI y coord 0 64 | # | 65 | # v 66 | # gui_size[1] #900 67 | # xy is in the same coordinate system as the background 68 | # bkg y coord 0 -> maze_shape[1] #12 69 | # bkg x coord 0 70 | # | 71 | # v 72 | # maze_shape[0] #9 73 | 74 | self.maze = np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 75 | [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1], 76 | [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1], 77 | [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], 78 | [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1], 79 | [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1], 80 | [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1], 81 | [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1], 82 | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]).astype(bool) 83 | self.gui_size = (1200, 900) 84 | self.fps = 10 85 | self.batch_size = 32 86 | self.offset = 0.5 # Offset to put object in the center of the cell 87 | 88 | self.WHITE = (255, 255, 255) 89 | self.RED = (255, 0, 0) 90 | self.GRAY = (128, 128, 128) 91 | self.agent_color = self.RED 92 | 93 | # Initialize Pygame 94 | pygame.init() 95 | self.screen = pygame.display.set_mode(self.gui_size) 96 | pygame.display.set_caption("Maze") 97 | self.clock = pygame.time.Clock() 98 | self.agent_gui_pos = np.array([0, 0]) # Initialize the position of the red dot 99 | self.running = True 100 | 101 | def check_collision(self, xy_traj): 102 | assert xy_traj.shape[2] == 2, "Input must be a 2D array of (x, y) coordinates." 
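# Shape note: `xy_traj` is a (batch, num_steps, 2) array of trajectories in maze coordinates.
# Each (x, y) point is clipped into the grid, rounded to the nearest cell, and looked up in
# `self.maze` (True = wall); a trajectory counts as colliding if any of its steps lands in a wall.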
103 | batch_size, num_steps, _ = xy_traj.shape 104 | xy_traj = xy_traj.reshape(-1, 2) 105 | xy_traj = np.clip(xy_traj, [0, 0], [self.maze.shape[0] - 1, self.maze.shape[1] - 1]) 106 | maze_x = np.round(xy_traj[:, 0]).astype(int) 107 | maze_y = np.round(xy_traj[:, 1]).astype(int) 108 | collisions = self.maze[maze_x, maze_y] 109 | collisions = collisions.reshape(batch_size, num_steps) 110 | return np.any(collisions, axis=1) 111 | 112 | def find_first_collision_from_GUI(self, gui_traj): 113 | assert gui_traj.shape[1] == 2, "Input must be a 2D array" 114 | xy_traj = np.array([self.gui2xy(point) for point in gui_traj]) 115 | xy_traj = np.clip(xy_traj, [0, 0], [self.maze.shape[0] - 1, self.maze.shape[1] - 1]) 116 | maze_x = np.round(xy_traj[:, 0]).astype(int) 117 | maze_y = np.round(xy_traj[:, 1]).astype(int) 118 | collisions = self.maze[maze_x, maze_y] 119 | first_collision_idx = np.argmax(collisions) # find the first index of many possible collisions 120 | return first_collision_idx 121 | 122 | def blend_with_white(self, color, factor=0.5): 123 | white = np.array([255, 255, 255]) 124 | blended_color = (1 - factor) * np.array(color) + factor * white 125 | return blended_color.astype(int) 126 | 127 | def report_collision_percentage(self, collisions): 128 | num_trajectories = collisions.shape[0] 129 | num_collisions = np.sum(collisions) 130 | collision_percentage = (num_collisions / num_trajectories) * 100 131 | print(f"{num_collisions}/{num_trajectories} trajectories are in collision ({collision_percentage:.2f}%).") 132 | return collision_percentage 133 | 134 | def xy2gui(self, xy): 135 | xy = xy + self.offset # Adjust normalization as necessary 136 | x = xy[0] * self.gui_size[1] / (self.maze.shape[0]) 137 | y = xy[1] * self.gui_size[0] / (self.maze.shape[1]) 138 | return np.array([y, x], dtype=float) 139 | 140 | def gui2xy(self, gui): 141 | x = gui[1] / self.gui_size[1] * self.maze.shape[0] - self.offset 142 | y = gui[0] / self.gui_size[0] * self.maze.shape[1] - self.offset 143 | return np.array([x, y], dtype=float) 144 | 145 | def generate_time_color_map(self, num_steps): 146 | cmap = plt.get_cmap('rainbow') 147 | values = np.linspace(0, 1, num_steps) 148 | colors = cmap(values) 149 | return colors 150 | 151 | def draw_maze_background(self): 152 | surface = pygame.surfarray.make_surface(255 - np.swapaxes(np.repeat(self.maze[:, :, np.newaxis] * 255, 3, axis=2).astype(np.uint8), 0, 1)) 153 | surface = pygame.transform.scale(surface, self.gui_size) 154 | self.screen.blit(surface, (0, 0)) 155 | 156 | def update_screen(self, xy_pred=None, collisions=None, scores=None, keep_drawing=False, traj_in_gui_space=False): 157 | self.draw_maze_background() 158 | if xy_pred is not None: 159 | time_colors = self.generate_time_color_map(xy_pred.shape[1]) 160 | if collisions is None: 161 | collisions = self.check_collision(xy_pred) 162 | # self.report_collision_percentage(collisions) 163 | for idx, pred in enumerate(xy_pred): 164 | for step_idx in range(len(pred) - 1): 165 | color = (time_colors[step_idx, :3] * 255).astype(int) 166 | 167 | # visualize constraint violations (collisions) by tinting trajectories white 168 | whiteness_factor = 0.8 if collisions[idx] else 0.0 169 | color = self.blend_with_white(color, whiteness_factor) 170 | if scores is None: 171 | circle_size = 5 if collisions[idx] else 5 172 | else: # when similarity scores are provided, visualizing them by changing the trajectory size 173 | circle_size = int(3 + 20 * scores[idx]) 174 | if traj_in_gui_space: 175 | start_pos = pred[step_idx] 176 | 
end_pos = pred[step_idx + 1] 177 | else: 178 | start_pos = self.xy2gui(pred[step_idx]) 179 | end_pos = self.xy2gui(pred[step_idx + 1]) 180 | pygame.draw.circle(self.screen, color, start_pos, circle_size) 181 | 182 | pygame.draw.circle(self.screen, self.agent_color, (int(self.agent_gui_pos[0]), int(self.agent_gui_pos[1])), 20) 183 | if keep_drawing: # visualize the human drawing input 184 | for i in range(len(self.draw_traj) - 1): 185 | pygame.draw.line(self.screen, self.GRAY, self.draw_traj[i], self.draw_traj[i + 1], 10) 186 | 187 | 188 | pygame.display.flip() 189 | 190 | def similarity_score(self, samples, guide=None): 191 | # samples: (B, pred_horizon, action_dim) 192 | # guide: (guide_horizon, action_dim) 193 | if guide is None: 194 | return samples, None 195 | assert samples.shape[2] == 2 and guide.shape[1] == 2 196 | indices = np.linspace(0, guide.shape[0]-1, samples.shape[1], dtype=int) 197 | guide = np.expand_dims(guide[indices], axis=0) # (1, pred_horizon, action_dim) 198 | guide = np.tile(guide, (samples.shape[0], 1, 1)) # (B, pred_horizon, action_dim) 199 | scores = np.linalg.norm(samples[:, :] - guide[:, :], axis=2, ord=2).mean(axis=1) # (B,) 200 | scores = 1 - scores / (scores.max() + 1e-6) # normalize 201 | temperature = 20 202 | scores = softmax(scores*temperature) 203 | # normalize the score to be between 0 and 1 204 | scores = (scores - scores.min()) / (scores.max() - scores.min()) 205 | # sort the predictions based on scores, from smallest to largest, so that larger scores will be drawn on top 206 | sort_idx = np.argsort(scores) 207 | samples = samples[sort_idx] 208 | scores = scores[sort_idx] 209 | return samples, scores 210 | 211 | class UnconditionalMaze(MazeEnv): 212 | # for dragging the agent around to explore motion manifold 213 | def __init__(self, policy, policy_tag=None): 214 | super().__init__() 215 | self.mouse_pos = None 216 | self.agent_in_collision = False 217 | self.agent_history_xy = [] 218 | self.policy = policy 219 | self.policy_tag = policy_tag 220 | 221 | def infer_target(self, guide=None, visualizer=None): 222 | agent_hist_xy = self.agent_history_xy[-1] 223 | agent_hist_xy = np.array(agent_hist_xy).reshape(1, 2) 224 | if self.policy_tag == 'dp': 225 | agent_hist_xy = agent_hist_xy.repeat(2, axis=0) 226 | 227 | obs_batch = { 228 | "observation.state": einops.repeat( 229 | torch.from_numpy(agent_hist_xy).float().cuda(), "t d -> b t d", b=self.batch_size 230 | ) 231 | } 232 | obs_batch["observation.environment_state"] = einops.repeat( 233 | torch.from_numpy(agent_hist_xy).float().cuda(), "t d -> b t d", b=self.batch_size 234 | ) 235 | 236 | if guide is not None: 237 | guide = torch.from_numpy(guide).float().cuda() 238 | 239 | with torch.autocast(device_type="cuda"), seeded_context(0): 240 | if self.policy_tag == 'act': 241 | actions = self.policy.run_inference(obs_batch).cpu().numpy() 242 | else: 243 | actions = self.policy.run_inference(obs_batch, guide=guide, visualizer=visualizer).cpu().numpy() # directly call the policy in order to visualize the intermediate steps 244 | return actions 245 | 246 | def update_mouse_pos(self): 247 | self.mouse_pos = np.array(pygame.mouse.get_pos()) 248 | 249 | def update_agent_pos(self, new_agent_pos, history_len=1): 250 | self.agent_gui_pos = np.array(new_agent_pos) 251 | agent_xy_pos = self.gui2xy(self.agent_gui_pos) 252 | self.agent_in_collision = self.check_collision(agent_xy_pos.reshape(1, 1, 2))[0] 253 | if self.agent_in_collision: 254 | self.agent_color = self.blend_with_white(self.RED, 0.8) 255 | else: 256 | 
self.agent_color = self.RED 257 | self.agent_history_xy.append(agent_xy_pos) 258 | self.agent_history_xy = self.agent_history_xy[-history_len:] 259 | 260 | def run(self): 261 | while self.running: 262 | self.update_mouse_pos() 263 | 264 | # Handle events 265 | for event in pygame.event.get(): 266 | if event.type == pygame.QUIT: 267 | self.running = False 268 | break 269 | 270 | self.update_agent_pos(self.mouse_pos.copy()) 271 | xy_pred = self.infer_target() 272 | self.update_screen(xy_pred) 273 | self.clock.tick(30) 274 | 275 | pygame.quit() 276 | 277 | 278 | class ConditionalMaze(UnconditionalMaze): 279 | # for interactive guidance dataset collection 280 | def __init__(self, policy, vis_dp_dynamics=False, savepath=None, alignment_strategy=None, policy_tag=None): 281 | super().__init__(policy, policy_tag=policy_tag) 282 | self.drawing = False 283 | self.keep_drawing = False 284 | self.vis_dp_dynamics = vis_dp_dynamics 285 | self.savefile = None 286 | self.savepath = savepath 287 | self.draw_traj = [] # gui coordinates 288 | self.xy_pred = None # numpy array 289 | self.collisions = None # boolean array 290 | self.scores = None # numpy array 291 | self.alignment_strategy = alignment_strategy 292 | 293 | def run(self): 294 | if self.savepath is not None: 295 | self.savefile = open(self.savepath, "a+", buffering=1) 296 | self.trial_idx = 0 297 | 298 | while self.running: 299 | self.update_mouse_pos() 300 | 301 | # Handle events 302 | for event in pygame.event.get(): 303 | if event.type == pygame.QUIT: 304 | self.running = False 305 | break 306 | if any(pygame.mouse.get_pressed()): # Check if mouse button is pressed 307 | if not self.drawing: 308 | self.drawing = True 309 | self.draw_traj = [] 310 | self.draw_traj.append(self.mouse_pos) 311 | else: # mouse released 312 | if self.drawing: 313 | self.drawing = False # finish drawing action 314 | self.keep_drawing = True # keep visualizing the drawing 315 | if event.type == pygame.KEYDOWN: 316 | # press s to save the trial 317 | if event.key == pygame.K_s and self.savefile is not None: 318 | self.save_trials() 319 | 320 | if self.keep_drawing: # visualize the human drawing input 321 | # Check if mouse returns to the agent's location 322 | if np.linalg.norm(self.mouse_pos - self.agent_gui_pos) < 20: # Threshold distance to reactivate the agent 323 | self.keep_drawing = False # delete the drawing 324 | self.draw_traj = [] 325 | 326 | if not self.drawing: # inference mode 327 | if not self.keep_drawing: 328 | self.update_agent_pos(self.mouse_pos.copy()) 329 | if len(self.draw_traj) > 0: 330 | guide = np.array([self.gui2xy(point) for point in self.draw_traj]) 331 | else: 332 | guide = None 333 | self.xy_pred = self.infer_target(guide, visualizer=(self if self.vis_dp_dynamics and self.keep_drawing else None)) 334 | self.scores = None 335 | if self.alignment_strategy == 'post-hoc' and guide is not None: 336 | xy_pred, scores = self.similarity_score(self.xy_pred, guide) 337 | self.xy_pred = xy_pred 338 | self.scores = scores 339 | self.collisions = self.check_collision(self.xy_pred) 340 | 341 | self.update_screen(self.xy_pred, self.collisions, self.scores, (self.keep_drawing or self.drawing)) 342 | if self.vis_dp_dynamics and not self.drawing and self.keep_drawing: 343 | time.sleep(1) 344 | self.clock.tick(30) 345 | 346 | pygame.quit() 347 | 348 | def save_trials(self): 349 | b, t, _ = self.xy_pred.shape 350 | xy_pred = self.xy_pred.reshape(b*t, 2) 351 | pred_gui_traj = [self.xy2gui(xy) for xy in xy_pred] 352 | pred_gui_traj = 
np.array(pred_gui_traj).reshape(b, t, 2) 353 | entry = { 354 | "trial_idx": self.trial_idx, 355 | "agent_pos": self.agent_gui_pos.tolist(), 356 | "guide": np.array(self.draw_traj).tolist(), 357 | "pred_traj": pred_gui_traj.astype(int).tolist(), 358 | "collisions": self.collisions.tolist() 359 | } 360 | self.savefile.write(json.dumps(entry) + "\n") 361 | print(f"Trial {self.trial_idx} saved to {self.savepath}.") 362 | self.trial_idx += 1 363 | 364 | class MazeExp(ConditionalMaze): 365 | # for replaying the trials and benchmarking the alignment strategies 366 | def __init__(self, policy, vis_dp_dynamics=False, savepath=None, alignment_strategy=None, policy_tag=None, loadpath=None): 367 | super().__init__(policy, vis_dp_dynamics, savepath, policy_tag=policy_tag) 368 | # Load saved trials 369 | assert loadpath is not None 370 | with open(loadpath, "r", buffering=1) as file: 371 | file.seek(0) 372 | trials = [json.loads(line) for line in file] 373 | # set random seed and shuffle the trials 374 | np.random.seed(0) 375 | np.random.shuffle(trials) 376 | 377 | self.trials = trials 378 | self.trial_idx = 0 379 | # if savepath is not None: 380 | # # append loadpath to the savepath as prefix 381 | # self.savepath = loadpath[:-5] + '_' + policy_tag + '_' + savepath 382 | # self.savefile = open(self.savepath, "a+", buffering=1) 383 | # self.trial_idx = 0 384 | self.alignment_strategy = alignment_strategy 385 | print(f"Alignment strategy: {alignment_strategy}") 386 | 387 | def run(self): 388 | if self.savepath is not None: 389 | self.savefile = open(self.savepath, "w+", buffering=1) 390 | self.trial_idx = 0 391 | 392 | while self.trial_idx < len(self.trials): 393 | # Load the trial 394 | self.draw_traj = self.trials[self.trial_idx]["guide"] 395 | 396 | # skip empty trials 397 | if len(self.draw_traj) == 0: 398 | print(f"Skipping trial {self.trial_idx} which has no guide.") 399 | self.trial_idx += 1 400 | continue 401 | 402 | # skip trials with all collisions 403 | first_collision_idx = self.find_first_collision_from_GUI(np.array(self.draw_traj)) 404 | if first_collision_idx <= 0: # no collision or all collisions 405 | if np.array(self.trials[self.trial_idx]["collisions"]).all(): 406 | print(f"Skipping trial {self.trial_idx} which has all collisions.") 407 | self.trial_idx += 1 408 | continue 409 | 410 | # initialize the agent position 411 | if self.alignment_strategy == 'output-perturb': 412 | # find the location before the first collision to initialize the agent 413 | if first_collision_idx <= 0: # no collision or all collisions 414 | perturbed_pos = self.draw_traj[20] 415 | else: 416 | first_collision_idx = min(first_collision_idx, 20) 417 | perturbed_pos = self.draw_traj[first_collision_idx - 1] 418 | self.update_agent_pos(perturbed_pos) 419 | else: 420 | self.update_agent_pos(self.trials[self.trial_idx]["agent_pos"]) 421 | 422 | # infer the target based on the guide 423 | if self.policy is not None: 424 | guide = np.array([self.gui2xy(point) for point in self.draw_traj]) 425 | self.xy_pred = self.infer_target(guide, visualizer=(self if self.vis_dp_dynamics else None)) 426 | if self.alignment_strategy in ['output-perturb', 'post-hoc']: 427 | self.xy_pred, scores = self.similarity_score(self.xy_pred, guide) 428 | else: 429 | scores = None 430 | self.collisions = self.check_collision(self.xy_pred) 431 | self.update_screen(self.xy_pred, self.collisions, scores=scores, keep_drawing=True, traj_in_gui_space=False) 432 | if self.vis_dp_dynamics: 433 | time.sleep(1) 434 | 435 | # save the experiment trial 436
| if self.savepath is not None: 437 | self.save_trials() 438 | 439 | # just replay the trials without inference 440 | else: 441 | collisions = self.trials[self.trial_idx]["collisions"] 442 | pred_traj = np.array(self.trials[self.trial_idx]["pred_traj"]) 443 | if self.alignment_strategy in ['output-perturb', 'post-hoc']: 444 | _, scores = self.similarity_score(pred_traj, np.array(self.trials[self.trial_idx]["guide"])) # this is a hack as both pred_traj and guide are in gui space, don't use this score for absolute statistics calculation 445 | else: 446 | scores = None 447 | self.update_screen(pred_traj, collisions, scores=scores, keep_drawing=True, traj_in_gui_space=True) 448 | 449 | # Handle events 450 | for event in pygame.event.get(): 451 | if event.type == pygame.KEYDOWN: 452 | assert self.savefile is None 453 | if event.key == pygame.K_n and self.savefile is None: # visualization mode rather than saving mode 454 | print("manual skip to the next trial") 455 | self.trial_idx += 1 456 | 457 | self.clock.tick(10) 458 | 459 | pygame.quit() 460 | 461 | 462 | if __name__ == "__main__": 463 | parser = argparse.ArgumentParser() 464 | parser.add_argument('-c', "--checkpoint", type=str, help="Path to the checkpoint") 465 | parser.add_argument('-p', '--policy', required=True, type=str, help="Policy name") 466 | parser.add_argument('-u', '--unconditional', action='store_true', help="Unconditional Maze") 467 | parser.add_argument('-op', '--output-perturb', action='store_true', help="Output perturbation") 468 | parser.add_argument('-ph', '--post-hoc', action='store_true', help="Post-hoc ranking") 469 | parser.add_argument('-bi', '--biased-initialization', action='store_true', help="Biased initialization") 470 | parser.add_argument('-gd', '--guided-diffusion', action='store_true', help="Guided diffusion") 471 | parser.add_argument('-ss', '--stochastic-sampling', action='store_true', help="Stochastic sampling") 472 | parser.add_argument('-v', '--vis_dp_dynamics', action='store_true', help="Visualize dynamics in DP") 473 | parser.add_argument('-s', '--savepath', type=str, default=None, help="Filename to save the drawing") 474 | parser.add_argument('-l', '--loadpath', type=str, default=None, help="Filename to load the drawing") 475 | 476 | args = parser.parse_args() 477 | 478 | # Create and load the policy 479 | device = torch.device("cuda") 480 | 481 | alignment_strategy = 'post-hoc' 482 | if args.post_hoc: 483 | alignment_strategy = 'post-hoc' 484 | elif args.output_perturb: 485 | alignment_strategy = 'output-perturb' 486 | elif args.biased_initialization: 487 | alignment_strategy = 'biased-initialization' 488 | elif args.guided_diffusion: 489 | alignment_strategy = 'guided-diffusion' 490 | elif args.stochastic_sampling: 491 | alignment_strategy = 'stochastic-sampling' 492 | 493 | if args.policy in ["diffusion", "dp"]: 494 | checkpoint_path = 'weights_dp' 495 | elif args.policy in ["act"]: 496 | checkpoint_path = 'weights_act' 497 | else: 498 | raise NotImplementedError(f"Policy with name {args.policy} is not implemented.") 499 | 500 | if args.policy is not None: 501 | # Load policy 502 | pretrained_policy_path = Path(os.path.join(checkpoint_path, "pretrained_model")) 503 | 504 | if args.policy in ["diffusion", "dp"]: 505 | policy = DiffusionPolicy.from_pretrained(pretrained_policy_path, alignment_strategy=alignment_strategy) 506 | policy.config.noise_scheduler_type = "DDIM" 507 | policy.diffusion.num_inference_steps = 10 508 | policy.config.n_action_steps = policy.config.horizon - 
policy.config.n_obs_steps + 1 509 | policy_tag = 'dp' 510 | policy.cuda() 511 | policy.eval() 512 | elif args.policy in ["act"]: 513 | policy = ACTPolicy.from_pretrained(pretrained_policy_path) 514 | policy_tag = 'act' 515 | policy.cuda() 516 | policy.eval() 517 | else: 518 | policy = None 519 | policy_tag = None 520 | 521 | if args.unconditional: 522 | interactiveMaze = UnconditionalMaze(policy, policy_tag=policy_tag) 523 | elif args.loadpath is not None: 524 | if args.savepath is None: 525 | savepath = None 526 | else: 527 | alignment_tag = 'ph' 528 | if alignment_strategy == 'output-perturb': 529 | alignment_tag = 'op' 530 | elif alignment_strategy == 'biased-initialization': 531 | alignment_tag = 'bi' 532 | elif alignment_strategy == 'guided-diffusion': 533 | alignment_tag = 'gd' 534 | elif alignment_strategy == 'stochastic-sampling': 535 | alignment_tag = 'ss' 536 | savepath = f"{args.loadpath[:-5]}_{policy_tag}_{alignment_tag}{args.savepath}" 537 | interactiveMaze = MazeExp(policy, args.vis_dp_dynamics, savepath, alignment_strategy, policy_tag=policy_tag, loadpath=args.loadpath) 538 | else: 539 | interactiveMaze = ConditionalMaze(policy, args.vis_dp_dynamics, args.savepath, alignment_strategy, policy_tag=policy_tag) 540 | interactiveMaze.run() 541 | -------------------------------------------------------------------------------- /itps/common/policies/act/modeling_act.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Action Chunking Transformer Policy 17 | 18 | As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). 19 | The majority of changes here involve removing unused code, unifying naming, and adding helpful comments. 
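This copy also exposes a batched `run_inference` method that returns a full chunk of unnormalized actions for a batch of observations; `interact_maze2d.py` calls it to visualize the sampled trajectories.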
20 | """ 21 | 22 | import math 23 | from collections import deque 24 | from itertools import chain 25 | from typing import Callable 26 | 27 | import einops 28 | import numpy as np 29 | import torch 30 | import torch.nn.functional as F # noqa: N812 31 | import torchvision 32 | from huggingface_hub import PyTorchModelHubMixin 33 | from torch import Tensor, nn 34 | from torchvision.models._utils import IntermediateLayerGetter 35 | from torchvision.ops.misc import FrozenBatchNorm2d 36 | 37 | from common.policies.act.configuration_act import ACTConfig 38 | from common.policies.normalize import Normalize, Unnormalize 39 | 40 | 41 | class ACTPolicy(nn.Module, PyTorchModelHubMixin): 42 | """ 43 | Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost 44 | Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) 45 | """ 46 | 47 | name = "act" 48 | 49 | def __init__( 50 | self, 51 | config: ACTConfig | None = None, 52 | dataset_stats: dict[str, dict[str, Tensor]] | None = None, 53 | ): 54 | """ 55 | Args: 56 | config: Policy configuration class instance or None, in which case the default instantiation of 57 | the configuration class is used. 58 | dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected 59 | that they will be passed with a call to `load_state_dict` before the policy is used. 60 | """ 61 | super().__init__() 62 | if config is None: 63 | config = ACTConfig() 64 | self.config: ACTConfig = config 65 | 66 | self.normalize_inputs = Normalize( 67 | config.input_shapes, config.input_normalization_modes, dataset_stats 68 | ) 69 | self.normalize_targets = Normalize( 70 | config.output_shapes, config.output_normalization_modes, dataset_stats 71 | ) 72 | self.unnormalize_outputs = Unnormalize( 73 | config.output_shapes, config.output_normalization_modes, dataset_stats 74 | ) 75 | 76 | self.model = ACT(config) 77 | 78 | self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] 79 | 80 | self.reset() 81 | 82 | def reset(self): 83 | """This should be called whenever the environment is reset.""" 84 | if self.config.temporal_ensemble_momentum is not None: 85 | self._ensembled_actions = None 86 | else: 87 | self._action_queue = deque([], maxlen=self.config.n_action_steps) 88 | 89 | @property 90 | def n_obs_steps(self) -> int: 91 | return self.config.n_obs_steps 92 | 93 | @property 94 | def input_keys(self) -> set[str]: 95 | return set(self.config.input_shapes) 96 | 97 | @torch.no_grad 98 | def run_inference(self, observation_batch: dict[str, Tensor]) -> Tensor: 99 | observation_batch = self.normalize_inputs(observation_batch) 100 | if len(self.expected_image_keys) > 0: 101 | observation_batch["observation.images"] = torch.stack( 102 | [observation_batch[k] for k in self.expected_image_keys], dim=-4 103 | ) 104 | for k in observation_batch: 105 | if not k.startswith("observation"): 106 | continue 107 | observation_batch[k] = observation_batch[k].squeeze(1) 108 | actions, _ = self.model(observation_batch) 109 | actions = self.unnormalize_outputs({"action": actions})["action"] 110 | return actions 111 | 112 | def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: 113 | """Run the batch through the model and compute the loss for training or validation.""" 114 | batch = self.normalize_inputs(batch) 115 | if len(self.expected_image_keys) > 0: 116 | batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], 
dim=-4) 117 | batch = self.normalize_targets(batch) 118 | actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) 119 | 120 | l1_loss = ( 121 | F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) 122 | ).mean() 123 | 124 | loss_dict = {"l1_loss": l1_loss.item()} 125 | if self.config.use_vae: 126 | # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for 127 | # each dimension independently, we sum over the latent dimension to get the total 128 | # KL-divergence per batch element, then take the mean over the batch. 129 | # (See App. B of https://arxiv.org/abs/1312.6114 for more details). 130 | mean_kld = ( 131 | (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() 132 | ) 133 | loss_dict["kld_loss"] = mean_kld.item() 134 | loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight 135 | else: 136 | loss_dict["loss"] = l1_loss 137 | 138 | return loss_dict 139 | 140 | 141 | class ACT(nn.Module): 142 | """Action Chunking Transformer: The underlying neural network for ACTPolicy. 143 | 144 | Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. 145 | - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the 146 | model that encodes the target data (a sequence of actions), and the condition (the robot 147 | joint-space). 148 | - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with 149 | cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we 150 | have an option to train this model without the variational objective (in which case we drop the 151 | `vae_encoder` altogether, and nothing about this model has anything to do with a VAE). 152 | 153 | Transformer 154 | Used alone for inference 155 | (acts as VAE decoder 156 | during training) 157 | ┌───────────────────────┐ 158 | │ Outputs │ 159 | │ ▲ │ 160 | │ ┌─────►┌───────┐ │ 161 | ┌──────┐ │ │ │Transf.│ │ 162 | │ │ │ ├─────►│decoder│ │ 163 | ┌────┴────┐ │ │ │ │ │ │ 164 | │ │ │ │ ┌───┴───┬─►│ │ │ 165 | │ VAE │ │ │ │ │ └───────┘ │ 166 | │ encoder │ │ │ │Transf.│ │ 167 | │ │ │ │ │encoder│ │ 168 | └───▲─────┘ │ │ │ │ │ 169 | │ │ │ └▲──▲─▲─┘ │ 170 | │ │ │ │ │ │ │ 171 | inputs └─────┼──┘ │ image emb. │ 172 | │ state emb. │ 173 | └───────────────────────┘ 174 | """ 175 | 176 | def __init__(self, config: ACTConfig): 177 | super().__init__() 178 | self.config = config 179 | # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence]. 180 | # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). 181 | self.use_robot_state = "observation.state" in config.input_shapes 182 | self.use_images = any(k.startswith("observation.image") for k in config.input_shapes) 183 | self.use_env_state = "observation.environment_state" in config.input_shapes 184 | if self.config.use_vae: 185 | self.vae_encoder = ACTEncoder(config) 186 | self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) 187 | # Projection layer for joint-space configuration to hidden dimension. 188 | if self.use_robot_state: 189 | self.vae_encoder_robot_state_input_proj = nn.Linear( 190 | config.input_shapes["observation.state"][0], config.dim_model 191 | ) 192 | # Projection layer for action (joint-space target) to hidden dimension. 
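# Each action in the chunk (e.g. a 2-D (x, y) waypoint in the maze task) is mapped to its own dim_model-sized
# token, so with, say, chunk_size=100 and dim_model=512 the projected action sequence has shape (B, 100, 512)
# before the cls and robot-state tokens are prepended.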
193 | self.vae_encoder_action_input_proj = nn.Linear( 194 | config.output_shapes["action"][0], config.dim_model 195 | ) 196 | # Projection layer from the VAE encoder's output to the latent distribution's parameter space. 197 | self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) 198 | # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch 199 | # dimension. 200 | num_input_token_encoder = 1 + config.chunk_size 201 | if self.use_robot_state: 202 | num_input_token_encoder += 1 203 | self.register_buffer( 204 | "vae_encoder_pos_enc", 205 | create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), 206 | ) 207 | 208 | # Backbone for image feature extraction. 209 | if self.use_images: 210 | backbone_model = getattr(torchvision.models, config.vision_backbone)( 211 | replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], 212 | weights=config.pretrained_backbone_weights, 213 | norm_layer=FrozenBatchNorm2d, 214 | ) 215 | # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final 216 | # feature map). 217 | # Note: The forward method of this returns a dict: {"feature_map": output}. 218 | self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) 219 | 220 | # Transformer (acts as VAE decoder when training with the variational objective). 221 | self.encoder = ACTEncoder(config) 222 | self.decoder = ACTDecoder(config) 223 | 224 | # Transformer encoder input projections. The tokens will be structured like 225 | # [latent, (robot_state), (env_state), (image_feature_map_pixels)]. 226 | if self.use_robot_state: 227 | self.encoder_robot_state_input_proj = nn.Linear( 228 | config.input_shapes["observation.state"][0], config.dim_model 229 | ) 230 | if self.use_env_state: 231 | self.encoder_env_state_input_proj = nn.Linear( 232 | config.input_shapes["observation.environment_state"][0], config.dim_model 233 | ) 234 | self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model) 235 | if self.use_images: 236 | self.encoder_img_feat_input_proj = nn.Conv2d( 237 | backbone_model.fc.in_features, config.dim_model, kernel_size=1 238 | ) 239 | # Transformer encoder positional embeddings. 240 | n_1d_tokens = 1 # for the latent 241 | if self.use_robot_state: 242 | n_1d_tokens += 1 243 | if self.use_env_state: 244 | n_1d_tokens += 1 245 | self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) 246 | if self.use_images: 247 | self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) 248 | 249 | # Transformer decoder. 250 | # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). 251 | self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) 252 | 253 | # Final action regression head on the output of the transformer's decoder. 
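# The head is applied independently to every one of the chunk_size decoder output tokens, so after the
# transpose in forward() the policy emits a (B, chunk_size, action_dim) action sequence in one shot rather
# than a single action.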
254 | self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0]) 255 | 256 | self._reset_parameters() 257 | 258 | def _reset_parameters(self): 259 | """Xavier-uniform initialization of the transformer parameters as in the original code.""" 260 | for p in chain(self.encoder.parameters(), self.decoder.parameters()): 261 | if p.dim() > 1: 262 | nn.init.xavier_uniform_(p) 263 | 264 | def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: 265 | """A forward pass through the Action Chunking Transformer (with optional VAE encoder). 266 | 267 | `batch` should have the following structure: 268 | { 269 | "observation.state" (optional): (B, state_dim) batch of robot states. 270 | 271 | "observation.images": (B, n_cameras, C, H, W) batch of images. 272 | AND/OR 273 | "observation.environment_state": (B, env_dim) batch of environment states. 274 | 275 | "action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions. 276 | } 277 | 278 | Returns: 279 | (B, chunk_size, action_dim) batch of action sequences 280 | Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the 281 | latent dimension. 282 | """ 283 | if self.config.use_vae and self.training: 284 | assert ( 285 | "action" in batch 286 | ), "actions must be provided when using the variational objective in training mode." 287 | 288 | batch_size = ( 289 | batch["observation.images"] 290 | if "observation.images" in batch 291 | else batch["observation.environment_state"] 292 | ).shape[0] 293 | 294 | # Prepare the latent for input to the transformer encoder. 295 | if self.config.use_vae and "action" in batch: 296 | # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. 297 | cls_embed = einops.repeat( 298 | self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size 299 | ) # (B, 1, D) 300 | if self.use_robot_state: 301 | robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) 302 | robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) 303 | action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) 304 | 305 | if self.use_robot_state: 306 | vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) 307 | else: 308 | vae_encoder_input = [cls_embed, action_embed] 309 | vae_encoder_input = torch.cat(vae_encoder_input, axis=1) 310 | 311 | # Prepare fixed positional embedding. 312 | # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. 313 | pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) 314 | 315 | # Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the 316 | # sequence depending whether we use the input states or not (cls and robot state) 317 | # False means not a padding token. 318 | cls_joint_is_pad = torch.full( 319 | (batch_size, 2 if self.use_robot_state else 1), 320 | False, 321 | device=batch["observation.state"].device, 322 | ) 323 | key_padding_mask = torch.cat( 324 | [cls_joint_is_pad, batch["action_is_pad"]], axis=1 325 | ) # (bs, seq+1 or 2) 326 | 327 | # Forward pass through VAE encoder to get the latent PDF parameters. 
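# Only the first output token (the cls token) is kept; it is projected to 2 * latent_dim values which are
# split below into the latent mean and 2*log(sigma), matching the reparameterization a few lines further down.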
328 | cls_token_out = self.vae_encoder( 329 | vae_encoder_input.permute(1, 0, 2), 330 | pos_embed=pos_embed.permute(1, 0, 2), 331 | key_padding_mask=key_padding_mask, 332 | )[0] # select the class token, with shape (B, D) 333 | latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) 334 | mu = latent_pdf_params[:, : self.config.latent_dim] 335 | # This is 2log(sigma). Done this way to match the original implementation. 336 | log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :] 337 | 338 | # Sample the latent with the reparameterization trick. 339 | latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) 340 | elif self.config.use_vae: 341 | # When not using the VAE encoder, we sample the latent from the standard normal distribution. 342 | mu = log_sigma_x2 = None 343 | # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer 344 | latent_sample = torch.randn([batch_size, self.config.latent_dim], dtype=torch.float32).to( 345 | batch["observation.state"].device 346 | ) 347 | else: 348 | # When not using the VAE encoder, we set the latent to be all zeros. 349 | mu = log_sigma_x2 = None 350 | # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer 351 | latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( 352 | batch["observation.state"].device 353 | ) 354 | 355 | # Prepare transformer encoder inputs. 356 | encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)] 357 | encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)) 358 | # Robot state token. 359 | if self.use_robot_state: 360 | encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) 361 | # Environment state token. 362 | if self.use_env_state: 363 | encoder_in_tokens.append( 364 | self.encoder_env_state_input_proj(batch["observation.environment_state"]) 365 | ) 366 | 367 | # Camera observation features and positional embeddings. 368 | if self.use_images: 369 | all_cam_features = [] 370 | all_cam_pos_embeds = [] 371 | images = batch["observation.images"] 372 | 373 | for cam_index in range(images.shape[-4]): 374 | cam_features = self.backbone(images[:, cam_index])["feature_map"] 375 | # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use 376 | # buffer 377 | cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) 378 | cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) 379 | all_cam_features.append(cam_features) 380 | all_cam_pos_embeds.append(cam_pos_embed) 381 | # Concatenate camera observation feature maps and positional embeddings along the width dimension, 382 | # and move to (sequence, batch, dim). 383 | all_cam_features = torch.cat(all_cam_features, axis=-1) 384 | encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c")) 385 | all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1) 386 | encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c")) 387 | 388 | # Stack all tokens along the sequence dimension. 389 | encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0) 390 | encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0) 391 | 392 | # Forward pass through the transformer modules. 
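# The encoder self-attends over [latent, (robot_state), (env_state), (image tokens)]; the decoder starts from
# an all-zero (chunk_size, B, dim_model) sequence and relies on its learned DETR-style query embeddings plus
# cross-attention into encoder_out to produce one feature per future action. Rough shape sketch, assuming the
# maze setup with state and environment-state tokens but no cameras:
#   encoder_in_tokens: (3, B, dim_model) -> encoder_out: (3, B, dim_model)
#   decoder_in:        (chunk_size, B, dim_model) zeros -> decoder_out: (chunk_size, B, dim_model)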
393 | encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed) 394 | # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer 395 | decoder_in = torch.zeros( 396 | (self.config.chunk_size, batch_size, self.config.dim_model), 397 | dtype=encoder_in_pos_embed.dtype, 398 | device=encoder_in_pos_embed.device, 399 | ) 400 | decoder_out = self.decoder( 401 | decoder_in, 402 | encoder_out, 403 | encoder_pos_embed=encoder_in_pos_embed, 404 | decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1), 405 | ) 406 | 407 | # Move back to (B, S, C). 408 | decoder_out = decoder_out.transpose(0, 1) 409 | 410 | actions = self.action_head(decoder_out) 411 | 412 | return actions, (mu, log_sigma_x2) 413 | 414 | 415 | class ACTEncoder(nn.Module): 416 | """Convenience module for running multiple encoder layers, maybe followed by normalization.""" 417 | 418 | def __init__(self, config: ACTConfig): 419 | super().__init__() 420 | self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)]) 421 | self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() 422 | 423 | def forward( 424 | self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None 425 | ) -> Tensor: 426 | for layer in self.layers: 427 | x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask) 428 | x = self.norm(x) 429 | return x 430 | 431 | 432 | class ACTEncoderLayer(nn.Module): 433 | def __init__(self, config: ACTConfig): 434 | super().__init__() 435 | self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) 436 | 437 | # Feed forward layers. 438 | self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) 439 | self.dropout = nn.Dropout(config.dropout) 440 | self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) 441 | 442 | self.norm1 = nn.LayerNorm(config.dim_model) 443 | self.norm2 = nn.LayerNorm(config.dim_model) 444 | self.dropout1 = nn.Dropout(config.dropout) 445 | self.dropout2 = nn.Dropout(config.dropout) 446 | 447 | self.activation = get_activation_fn(config.feedforward_activation) 448 | self.pre_norm = config.pre_norm 449 | 450 | def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor: 451 | skip = x 452 | if self.pre_norm: 453 | x = self.norm1(x) 454 | q = k = x if pos_embed is None else x + pos_embed 455 | x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask) 456 | x = x[0] # note: [0] to select just the output, not the attention weights 457 | x = skip + self.dropout1(x) 458 | if self.pre_norm: 459 | skip = x 460 | x = self.norm2(x) 461 | else: 462 | x = self.norm1(x) 463 | skip = x 464 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 465 | x = skip + self.dropout2(x) 466 | if not self.pre_norm: 467 | x = self.norm2(x) 468 | return x 469 | 470 | 471 | class ACTDecoder(nn.Module): 472 | def __init__(self, config: ACTConfig): 473 | """Convenience module for running multiple decoder layers followed by normalization.""" 474 | super().__init__() 475 | self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]) 476 | self.norm = nn.LayerNorm(config.dim_model) 477 | 478 | def forward( 479 | self, 480 | x: Tensor, 481 | encoder_out: Tensor, 482 | decoder_pos_embed: Tensor | None = None, 483 | encoder_pos_embed: Tensor | None = None, 484 | ) -> Tensor: 485 | for layer in self.layers: 486 | x = layer( 487 | x, encoder_out, 
decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed 488 | ) 489 | if self.norm is not None: 490 | x = self.norm(x) 491 | return x 492 | 493 | 494 | class ACTDecoderLayer(nn.Module): 495 | def __init__(self, config: ACTConfig): 496 | super().__init__() 497 | self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) 498 | self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) 499 | 500 | # Feed forward layers. 501 | self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) 502 | self.dropout = nn.Dropout(config.dropout) 503 | self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) 504 | 505 | self.norm1 = nn.LayerNorm(config.dim_model) 506 | self.norm2 = nn.LayerNorm(config.dim_model) 507 | self.norm3 = nn.LayerNorm(config.dim_model) 508 | self.dropout1 = nn.Dropout(config.dropout) 509 | self.dropout2 = nn.Dropout(config.dropout) 510 | self.dropout3 = nn.Dropout(config.dropout) 511 | 512 | self.activation = get_activation_fn(config.feedforward_activation) 513 | self.pre_norm = config.pre_norm 514 | 515 | def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: 516 | return tensor if pos_embed is None else tensor + pos_embed 517 | 518 | def forward( 519 | self, 520 | x: Tensor, 521 | encoder_out: Tensor, 522 | decoder_pos_embed: Tensor | None = None, 523 | encoder_pos_embed: Tensor | None = None, 524 | ) -> Tensor: 525 | """ 526 | Args: 527 | x: (Decoder Sequence, Batch, Channel) tensor of input tokens. 528 | encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are 529 | cross-attending with. 530 | decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder). 531 | encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder). 532 | Returns: 533 | (DS, B, C) tensor of decoder output features. 534 | """ 535 | skip = x 536 | if self.pre_norm: 537 | x = self.norm1(x) 538 | q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) 539 | x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights 540 | x = skip + self.dropout1(x) 541 | if self.pre_norm: 542 | skip = x 543 | x = self.norm2(x) 544 | else: 545 | x = self.norm1(x) 546 | skip = x 547 | x = self.multihead_attn( 548 | query=self.maybe_add_pos_embed(x, decoder_pos_embed), 549 | key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed), 550 | value=encoder_out, 551 | )[0] # select just the output, not the attention weights 552 | x = skip + self.dropout2(x) 553 | if self.pre_norm: 554 | skip = x 555 | x = self.norm3(x) 556 | else: 557 | x = self.norm2(x) 558 | skip = x 559 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 560 | x = skip + self.dropout3(x) 561 | if not self.pre_norm: 562 | x = self.norm3(x) 563 | return x 564 | 565 | 566 | def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor: 567 | """1D sinusoidal positional embeddings as in Attention is All You Need. 568 | 569 | Args: 570 | num_positions: Number of token positions required. 571 | Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension). 
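For example, num_positions=4 and dimension=8 yields a (4, 8) float32 tensor; the caller adds the leading batch dimension itself (see the unsqueeze(0) when `vae_encoder_pos_enc` is registered).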
572 | 573 | """ 574 | 575 | def get_position_angle_vec(position): 576 | return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)] 577 | 578 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)]) 579 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 580 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 581 | return torch.from_numpy(sinusoid_table).float() 582 | 583 | 584 | class ACTSinusoidalPositionEmbedding2d(nn.Module): 585 | """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. 586 | 587 | The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H 588 | for the vertical direction, and 1/W for the horizontal direction. 589 | """ 590 | 591 | def __init__(self, dimension: int): 592 | """ 593 | Args: 594 | dimension: The desired dimension of the embeddings. 595 | """ 596 | super().__init__() 597 | self.dimension = dimension 598 | self._two_pi = 2 * math.pi 599 | self._eps = 1e-6 600 | # Inverse "common ratio" for the geometric progression in sinusoid frequencies. 601 | self._temperature = 10000 602 | 603 | def forward(self, x: Tensor) -> Tensor: 604 | """ 605 | Args: 606 | x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for. 607 | Returns: 608 | A (1, C, H, W) batch of corresponding sinusoidal positional embeddings. 609 | """ 610 | not_mask = torch.ones_like(x[0, :1]) # (1, H, W) 611 | # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations 612 | # they would be range(0, H) and range(0, W). Keeping it at as is to match the original code. 613 | y_range = not_mask.cumsum(1, dtype=torch.float32) 614 | x_range = not_mask.cumsum(2, dtype=torch.float32) 615 | 616 | # "Normalize" the position index such that it ranges in [0, 2π]. 617 | # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range 618 | # are non-zero by construction. This is an artifact of the original code. 619 | y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi 620 | x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi 621 | 622 | inverse_frequency = self._temperature ** ( 623 | 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension 624 | ) 625 | 626 | x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) 627 | y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) 628 | 629 | # Note: this stack then flatten operation results in interleaved sine and cosine terms. 630 | # pos_embed_x and pos_embed_y are (1, H, W, C // 2). 631 | pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) 632 | pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) 633 | pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) 634 | 635 | return pos_embed 636 | 637 | 638 | def get_activation_fn(activation: str) -> Callable: 639 | """Return an activation function given a string.""" 640 | if activation == "relu": 641 | return F.relu 642 | if activation == "gelu": 643 | return F.gelu 644 | if activation == "glu": 645 | return F.glu 646 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") 647 | --------------------------------------------------------------------------------