├── .gitignore ├── LICENSE ├── README.md ├── brs_algo ├── __init__.py ├── learning │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── collate.py │ │ ├── data_module.py │ │ └── dataset.py │ ├── module │ │ ├── __init__.py │ │ ├── base.py │ │ └── diffusion_module.py │ ├── nn │ │ ├── __init__.py │ │ ├── common │ │ │ ├── __init__.py │ │ │ ├── misc.py │ │ │ └── mlp.py │ │ ├── diffusion │ │ │ ├── __init__.py │ │ │ ├── diffusion_head.py │ │ │ └── unet.py │ │ ├── features │ │ │ ├── __init__.py │ │ │ ├── fusion.py │ │ │ └── pointnet.py │ │ └── gpt │ │ │ ├── __init__.py │ │ │ └── gpt.py │ └── policy │ │ ├── __init__.py │ │ ├── base.py │ │ └── wbvima_policy.py ├── lightning │ ├── __init__.py │ ├── lightning.py │ └── trainer.py ├── optim │ ├── __init__.py │ ├── lr_schedule.py │ └── optimizer_group.py ├── rollout │ ├── __init__.py │ ├── asynced_rollout.py │ └── synced_rollout.py └── utils │ ├── __init__.py │ ├── array_tensor_utils.py │ ├── config_utils.py │ ├── convert_utils.py │ ├── file_utils.py │ ├── functional_utils.py │ ├── misc_utils.py │ ├── print_utils.py │ ├── random_seed.py │ ├── shape_utils.py │ ├── termcolor.py │ ├── torch_utils.py │ └── tree_utils.py ├── main ├── rollout │ ├── clean_house_after_a_wild_party │ │ ├── common.py │ │ ├── rollout_async.py │ │ └── rollout_sync.py │ ├── clean_the_toilet │ │ ├── common.py │ │ ├── rollout_async.py │ │ └── rollout_sync.py │ ├── lay_clothes_out │ │ ├── common.py │ │ ├── rollout_async.py │ │ └── rollout_sync.py │ ├── put_items_onto_shelves │ │ ├── common.py │ │ ├── rollout_async.py │ │ └── rollout_sync.py │ └── take_trash_outside │ │ ├── common.py │ │ ├── rollout_async.py │ │ └── rollout_sync.py └── train │ ├── cfg │ ├── arch │ │ └── wbvima.yaml │ ├── cfg.yaml │ └── task │ │ ├── clean_house_after_a_wild_party.yaml │ │ ├── clean_the_toilet.yaml │ │ ├── lay_clothes_out.yaml │ │ ├── put_items_onto_shelves.yaml │ │ └── take_trash_outside.yaml │ └── train.py ├── media ├── SUSig-red.png ├── pull.gif └── wbvima.gif ├── pyproject.toml ├── scripts ├── merge_data_files.py └── post_process.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 behavior-robot-suite 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BEHAVIOR Robot Suite: Streamlining Real-World Whole-Body Manipulation for Everyday Household Activities 2 |
3 | 4 | [Yunfan Jiang](https://yunfanj.com/), 5 | [Ruohan Zhang](https://ai.stanford.edu/~zharu/), 6 | [Josiah Wong](https://jdw.ong/), 7 | [Chen Wang](https://www.chenwangjeremy.net/), 8 | [Yanjie Ze](https://yanjieze.com/), 9 | [Hang Yin](https://hang-yin.github.io/), 10 | [Cem Gokmen](https://www.cemgokmen.com/), 11 | [Shuran Song](https://shurans.github.io/), 12 | [Jiajun Wu](https://jiajunwu.com/), 13 | [Li Fei-Fei](https://profiles.stanford.edu/fei-fei-li) 14 | 15 | 16 | 17 | [[Website]](https://behavior-robot-suite.github.io/) 18 | [[arXiv]](https://arxiv.org/abs/2503.05652) 19 | [[PDF]](https://behavior-robot-suite.github.io/assets/pdf/brs_paper.pdf) 20 | [[Doc]](https://behavior-robot-suite.github.io/docs/) 21 | [[Robot Code]](https://github.com/behavior-robot-suite/brs-ctrl) 22 | [[Training Data]](https://huggingface.co/datasets/behavior-robot-suite/data) 23 | 24 | [![Python Version](https://img.shields.io/badge/Python-3.11-blue.svg)](https://github.com/behavior-robot-suite/brs-algo) 25 | [](https://pytorch.org/) 26 | [](https://behavior-robot-suite.github.io/docs/) 27 | [![GitHub license](https://img.shields.io/github/license/behavior-robot-suite/brs-algo)](https://github.com/behavior-robot-suite/brs-algo/blob/main/LICENSE) 28 | 29 | ![](media/pull.gif) 30 | ______________________________________________________________________ 31 |
32 | 33 | We introduce the **BEHAVIOR Robot Suite** (BRS), a comprehensive framework for learning whole-body manipulation to tackle diverse real-world household tasks. BRS addresses both hardware and learning challenges through two key innovations: **WB-VIMA** and [JoyLo](https://github.com/behavior-robot-suite/brs-ctrl). 34 | 35 | WB-VIMA is an imitation learning algorithm designed to model whole-body actions by leveraging the robot’s inherent kinematic hierarchy. A key insight behind WB-VIMA is that robot joints exhibit strong interdependencies—small movements in upstream links (e.g., the torso) can lead to large displacements in downstream links (e.g., the end-effectors). To ensure precise coordination across all joints, WB-VIMA **conditions action predictions for downstream components on those of upstream components**, resulting in more synchronized whole-body movements. Additionally, WB-VIMA **dynamically aggregates multi-modal observations using self-attention**, allowing it to learn expressive policies while mitigating overfitting to proprioceptive inputs. 36 | 37 | ![](media/wbvima.gif) 38 | 39 | 40 | ## Getting Started 41 | 42 | > [!TIP] 43 | > 🚀 Check out the [doc](https://behavior-robot-suite.github.io/docs/sections/wbvima/overview.html) for detailed installation and usage instructions! 44 | 45 | To train a WB-VIMA policy, simply run the following command: 46 | 47 | ```bash 48 | python3 main/train/train.py data_dir= \ 49 | bs= \ 50 | arch=wbvima \ 51 | task= \ 52 | exp_root_dir= \ 53 | gpus= \ 54 | use_wandb= \ 55 | wandb_project= 56 | ``` 57 | 58 | To deploy a WB-VIMA policy on the real robot, simply run the following command: 59 | 60 | ```bash 61 | python3 main/rollout//rollout_async.py --ckpt_path --action_execute_start_idx 62 | ``` 63 | 64 | ## Check out Our Paper 65 | Our paper is posted on [arXiv](https://arxiv.org/abs/2503.05652). If you find our work useful, please consider citing us! 66 | 67 | ```bibtex 68 | @article{jiang2025brs, 69 | title = {BEHAVIOR Robot Suite: Streamlining Real-World Whole-Body Manipulation for Everyday Household Activities}, 70 | author = {Yunfan Jiang and Ruohan Zhang and Josiah Wong and Chen Wang and Yanjie Ze and Hang Yin and Cem Gokmen and Shuran Song and Jiajun Wu and Li Fei-Fei}, 71 | year = {2025}, 72 | journal = {arXiv preprint arXiv: 2503.05652} 73 | } 74 | ``` 75 | 76 | ## License 77 | This codebase is released under the [MIT License](LICENSE). 
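
## Conceptual Sketch: Hierarchical Action Conditioning

As described above, WB-VIMA conditions action predictions for downstream components on those of upstream components. The toy module below illustrates only that dependency structure; it is **not** the actual `WBVIMAPolicy` (which uses a GPT-style transformer over observation tokens and a whole-body diffusion head, see `brs_algo/learning/policy/wbvima_policy.py` and `brs_algo/learning/nn/diffusion/diffusion_head.py`). The base → torso → arms ordering, the plain linear heads, the dictionary keys, and all dimensions are illustrative assumptions.

```python
import torch
import torch.nn as nn


class HierarchicalActionSketch(nn.Module):
    """Toy illustration: each downstream head sees the upstream predictions."""

    def __init__(self, obs_dim=256, base_dim=3, torso_dim=4, arms_dim=14):
        super().__init__()
        self.base_head = nn.Linear(obs_dim, base_dim)                         # upstream
        self.torso_head = nn.Linear(obs_dim + base_dim, torso_dim)            # conditioned on base
        self.arms_head = nn.Linear(obs_dim + base_dim + torso_dim, arms_dim)  # conditioned on base + torso

    def forward(self, obs_feat):
        base = self.base_head(obs_feat)
        torso = self.torso_head(torch.cat([obs_feat, base], dim=-1))
        arms = self.arms_head(torch.cat([obs_feat, base, torso], dim=-1))
        return {"mobile_base": base, "torso": torso, "arms": arms}


actions = HierarchicalActionSketch()(torch.randn(1, 256))  # dict of per-part action tensors
```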
78 | -------------------------------------------------------------------------------- /brs_algo/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | -------------------------------------------------------------------------------- /brs_algo/learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/brs_algo/learning/__init__.py -------------------------------------------------------------------------------- /brs_algo/learning/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_module import ActionSeqChunkDataModule 2 | -------------------------------------------------------------------------------- /brs_algo/learning/data/collate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import brs_algo.utils as U 3 | 4 | 5 | def seq_chunk_collate_fn(sample_list): 6 | """ 7 | sample_list: list of (T, ...). PyTorch's native collate_fn can stack all data. 8 | But here we also add a leading singleton dimension, so it won't break the compatibility with episode data format. 9 | """ 10 | sample_list = U.any_stack(sample_list, dim=0) # (B, T, ...) 11 | sample_list = nested_np_expand_dims(sample_list, axis=0) # (1, B, T, ...) 12 | # convert to tensor 13 | return any_to_torch_tensor(sample_list) 14 | 15 | 16 | @U.make_recursive_func 17 | def nested_np_expand_dims(x, axis): 18 | if U.is_numpy(x): 19 | return np.expand_dims(x, axis=axis) 20 | else: 21 | raise ValueError(f"Input ({type(x)}) must be a numpy array.") 22 | 23 | 24 | def any_to_torch_tensor(x): 25 | if isinstance(x, dict): 26 | return {k: any_to_torch_tensor(v) for k, v in x.items()} 27 | return U.any_to_torch_tensor(x) 28 | -------------------------------------------------------------------------------- /brs_algo/learning/data/data_module.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | from torch.utils.data import DataLoader 4 | from pytorch_lightning import LightningDataModule 5 | 6 | from brs_algo import utils as U 7 | from brs_algo.learning.data.collate import seq_chunk_collate_fn 8 | from brs_algo.learning.data.dataset import ActionSeqChunkDataset 9 | 10 | 11 | class ActionSeqChunkDataModule(LightningDataModule): 12 | def __init__( 13 | self, 14 | *, 15 | data_path: str, 16 | pcd_downsample_points: int, 17 | pcd_x_range: Tuple[float, float], 18 | pcd_y_range: Tuple[float, float], 19 | pcd_z_range: Tuple[float, float], 20 | mobile_base_vel_action_min: Tuple[float, float, float], 21 | mobile_base_vel_action_max: Tuple[float, float, float], 22 | load_visual_obs_in_memory: bool = True, 23 | multi_view_cameras: Optional[List[str]] = None, 24 | load_multi_view_camera_rgb: bool = False, 25 | load_multi_view_camera_depth: bool = False, 26 | obs_window_size: int, 27 | action_prediction_horizon: int, 28 | batch_size: int, 29 | val_batch_size: Optional[int], 30 | val_split_ratio: float, 31 | dataloader_num_workers: int, 32 | seed: Optional[int] = None, 33 | ): 34 | super().__init__() 35 | self._data_path = data_path 36 | self._pcd_downsample_points = pcd_downsample_points 37 | self._pcd_x_range = pcd_x_range 38 | self._pcd_y_range = pcd_y_range 39 | self._pcd_z_range = pcd_z_range 40 | self._mobile_base_vel_action_min = 
mobile_base_vel_action_min 41 | self._mobile_base_vel_action_max = mobile_base_vel_action_max 42 | self._load_visual_obs_in_memory = load_visual_obs_in_memory 43 | self._multi_view_cameras = multi_view_cameras 44 | self._load_multi_view_camera_rgb = load_multi_view_camera_rgb 45 | self._load_multi_view_camera_depth = load_multi_view_camera_depth 46 | self._batch_size = batch_size 47 | self._val_batch_size = val_batch_size 48 | self._dataloader_num_workers = dataloader_num_workers 49 | self._seed = seed 50 | self._val_split_ratio = val_split_ratio 51 | 52 | self._train_dataset, self._val_dataset = None, None 53 | self._obs_window_size = obs_window_size 54 | self._action_prediction_horizon = action_prediction_horizon 55 | 56 | def setup(self, stage: str) -> None: 57 | if stage == "fit" or stage is None: 58 | all_dataset = ActionSeqChunkDataset( 59 | fpath=self._data_path, 60 | pcd_downsample_points=self._pcd_downsample_points, 61 | pcd_x_range=self._pcd_x_range, 62 | pcd_y_range=self._pcd_y_range, 63 | pcd_z_range=self._pcd_z_range, 64 | mobile_base_vel_action_min=self._mobile_base_vel_action_min, 65 | mobile_base_vel_action_max=self._mobile_base_vel_action_max, 66 | load_visual_obs_in_memory=self._load_visual_obs_in_memory, 67 | multi_view_cameras=self._multi_view_cameras, 68 | load_multi_view_camera_rgb=self._load_multi_view_camera_rgb, 69 | load_multi_view_camera_depth=self._load_multi_view_camera_depth, 70 | seed=self._seed, 71 | action_prediction_horizon=self._action_prediction_horizon, 72 | obs_window_size=self._obs_window_size, 73 | ) 74 | self._train_dataset, self._val_dataset = U.sequential_split_dataset( 75 | all_dataset, 76 | split_portions=[1 - self._val_split_ratio, self._val_split_ratio], 77 | ) 78 | 79 | def train_dataloader(self): 80 | return DataLoader( 81 | self._train_dataset, 82 | batch_size=self._batch_size, 83 | num_workers=min(self._batch_size, self._dataloader_num_workers), 84 | pin_memory=True, 85 | persistent_workers=True, 86 | collate_fn=seq_chunk_collate_fn, 87 | ) 88 | 89 | def val_dataloader(self): 90 | return DataLoader( 91 | self._val_dataset, 92 | batch_size=self._val_batch_size, 93 | num_workers=min(self._val_batch_size, self._dataloader_num_workers), 94 | pin_memory=True, 95 | persistent_workers=True, 96 | collate_fn=seq_chunk_collate_fn, 97 | ) 98 | -------------------------------------------------------------------------------- /brs_algo/learning/module/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion_module import DiffusionModule 2 | -------------------------------------------------------------------------------- /brs_algo/learning/module/base.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from pytorch_lightning import LightningModule 3 | 4 | 5 | class ImitationBaseModule(LightningModule): 6 | """ 7 | Base class for IL algorithms that require 1) an environment, 2) a policy, and 3) rollout evaluation. 
8 | """ 9 | 10 | def __init__(self, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | 13 | def training_step(self, *args, **kwargs): 14 | self.on_train_start() 15 | loss, log_dict, batch_size = self.imitation_training_step(*args, **kwargs) 16 | log_dict = {f"train/{k}": v for k, v in log_dict.items()} 17 | log_dict["train/loss"] = loss 18 | self.log_dict( 19 | log_dict, 20 | prog_bar=True, 21 | on_step=False, 22 | on_epoch=True, 23 | batch_size=batch_size, 24 | ) 25 | return loss 26 | 27 | def validation_step(self, *args, **kwargs): 28 | loss, log_dict, real_batch_size = self.imitation_test_step(*args, **kwargs) 29 | log_dict = {f"val/{k}": v for k, v in log_dict.items()} 30 | log_dict["val/loss"] = loss 31 | self.log_dict( 32 | log_dict, 33 | prog_bar=True, 34 | on_step=False, 35 | on_epoch=True, 36 | batch_size=real_batch_size, 37 | ) 38 | return log_dict 39 | 40 | def test_step(self, *args, **kwargs): 41 | loss, log_dict, real_batch_size = self.imitation_test_step(*args, **kwargs) 42 | log_dict = {f"test/{k}": v for k, v in log_dict.items()} 43 | log_dict["test/loss"] = loss 44 | self.log_dict( 45 | log_dict, 46 | prog_bar=True, 47 | on_step=False, 48 | on_epoch=True, 49 | batch_size=real_batch_size, 50 | ) 51 | return log_dict 52 | 53 | def configure_optimizers(self): 54 | """ 55 | Get optimizers, which are subsequently used to train. 56 | """ 57 | raise NotImplementedError 58 | 59 | def imitation_training_step(self, *args, **kwargs) -> Any: 60 | raise NotImplementedError 61 | 62 | def imitation_test_step(self, *args, **kwargs) -> Any: 63 | raise NotImplementedError 64 | -------------------------------------------------------------------------------- /brs_algo/learning/module/diffusion_module.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Any, Optional 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from hydra.utils import instantiate 7 | from omegaconf import DictConfig 8 | 9 | import brs_algo.utils as U 10 | from brs_algo.optim import CosineScheduleFunction 11 | from brs_algo.learning.policy.base import BasePolicy 12 | from brs_algo.learning.module.base import ImitationBaseModule 13 | 14 | 15 | class DiffusionModule(ImitationBaseModule): 16 | def __init__( 17 | self, 18 | *, 19 | # ====== policy ====== 20 | policy: Union[BasePolicy, DictConfig], 21 | action_prediction_horizon: int, 22 | # ====== learning ====== 23 | lr: float, 24 | use_cosine_lr: bool = False, 25 | lr_warmup_steps: Optional[int] = None, 26 | lr_cosine_steps: Optional[int] = None, 27 | lr_cosine_min: Optional[float] = None, 28 | lr_layer_decay: float = 1.0, 29 | weight_decay: float = 0.0, 30 | action_keys: List[str], 31 | loss_on_latest_obs_only: bool = False, 32 | ): 33 | super().__init__() 34 | if isinstance(policy, DictConfig): 35 | policy = instantiate(policy) 36 | self.policy = policy 37 | self._action_keys = action_keys 38 | self.action_prediction_horizon = action_prediction_horizon 39 | self.lr = lr 40 | self.use_cosine_lr = use_cosine_lr 41 | self.lr_warmup_steps = lr_warmup_steps 42 | self.lr_cosine_steps = lr_cosine_steps 43 | self.lr_cosine_min = lr_cosine_min 44 | self.lr_layer_decay = lr_layer_decay 45 | self.weight_decay = weight_decay 46 | self.loss_on_latest_obs_only = loss_on_latest_obs_only 47 | 48 | def imitation_training_step(self, *args, **kwargs) -> Any: 49 | return self.imitation_training_step_seq_policy(*args, **kwargs) 50 | 51 | def imitation_test_step(self, *args, 
**kwargs): 52 | return self.imitation_val_step_seq_policy(*args, **kwargs) 53 | 54 | def imitation_training_step_seq_policy(self, batch, batch_idx): 55 | B = U.get_batch_size( 56 | U.any_slice(batch["action_chunks"], np.s_[0]), 57 | strict=True, 58 | ) 59 | # obs data is dict of (N_chunks, B, window_size, ...) 60 | # action chunks is (N_chunks, B, window_size, action_prediction_horizon, A) 61 | # we loop over chunk dim 62 | main_data = U.unstack_sequence_fields( 63 | batch, batch_size=U.get_batch_size(batch, strict=True) 64 | ) 65 | all_loss, all_mask_sum = [], 0 66 | for i, main_data_chunk in enumerate(main_data): 67 | # get padding mask 68 | pad_mask = main_data_chunk.pop( 69 | "pad_mask" 70 | ) # (B, window_size, L_pred_horizon) 71 | action_chunks = main_data_chunk.pop( 72 | "action_chunks" 73 | ) # dict of (B, window_size, L_pred_horizon, A) 74 | gt_actions = torch.cat( 75 | [action_chunks[k] for k in self._action_keys], dim=-1 76 | ) 77 | transformer_output = self.policy( 78 | main_data_chunk 79 | ) # (B, L, E), where L is interleaved time and modality tokens 80 | loss = self.policy.compute_loss( 81 | transformer_output=transformer_output, 82 | gt_action=gt_actions, 83 | ) # (B, T_obs, T_act) 84 | if self.loss_on_latest_obs_only: 85 | mask = torch.zeros_like(pad_mask) 86 | mask[:, -1] = 1 87 | pad_mask = pad_mask * mask 88 | loss = loss * pad_mask 89 | all_loss.append(loss) 90 | all_mask_sum += pad_mask.sum() 91 | action_loss = torch.sum(torch.stack(all_loss)) / all_mask_sum 92 | # sum over action_prediction_horizon dim instead of avg 93 | action_loss = action_loss * self.action_prediction_horizon 94 | log_dict = {"diffusion_loss": action_loss} 95 | loss = action_loss 96 | return loss, log_dict, B 97 | 98 | def imitation_val_step_seq_policy(self, batch, batch_idx): 99 | """ 100 | Will denoise as if it is in rollout 101 | but no env interaction 102 | """ 103 | B = U.get_batch_size( 104 | U.any_slice(batch["action_chunks"], np.s_[0]), 105 | strict=True, 106 | ) 107 | # obs data is dict of (N_chunks, B, window_size, ...) 
108 | # action chunks is (N_chunks, B, window_size, action_prediction_horizon, A) 109 | # we loop over chunk dim 110 | main_data = U.unstack_sequence_fields( 111 | batch, batch_size=U.get_batch_size(batch, strict=True) 112 | ) 113 | all_l1, all_mask_sum = {}, 0 114 | for i, main_data_chunk in enumerate(main_data): 115 | # get padding mask 116 | pad_mask = main_data_chunk.pop( 117 | "pad_mask" 118 | ) # (B, window_size, L_pred_horizon) 119 | gt_actions = main_data_chunk.pop( 120 | "action_chunks" 121 | ) # dict of (B, window_size, L_pred_horizon, A) 122 | transformer_output = self.policy( 123 | main_data_chunk 124 | ) # (B, L, E), where L is interleaved time and modality tokens 125 | pred_actions = self.policy.inference( 126 | transformer_output=transformer_output, 127 | return_last_timestep_only=False, 128 | ) # dict of (B, window_size, L_pred_horizon, A) 129 | for action_k in pred_actions: 130 | pred = pred_actions[action_k] 131 | gt = gt_actions[action_k] 132 | l1 = F.l1_loss( 133 | pred, gt, reduction="none" 134 | ) # (B, window_size, L_pred_horizon, A) 135 | # sum over action dim 136 | l1 = l1.sum(dim=-1).reshape( 137 | pad_mask.shape 138 | ) # (B, window_size, L_pred_horizon) 139 | if self.loss_on_latest_obs_only: 140 | mask = torch.zeros_like(pad_mask) 141 | mask[:, -1] = 1 142 | pad_mask = pad_mask * mask 143 | l1 = l1 * pad_mask 144 | if action_k not in all_l1: 145 | all_l1[action_k] = [ 146 | l1, 147 | ] 148 | else: 149 | all_l1[action_k].append(l1) 150 | all_mask_sum += pad_mask.sum() 151 | # avg on chunks dim, batch dim, and obs window dim so we can compare under different training settings 152 | all_l1 = { 153 | k: torch.sum(torch.stack(v)) / all_mask_sum for k, v in all_l1.items() 154 | } 155 | all_l1 = {k: v * self.action_prediction_horizon for k, v in all_l1.items()} 156 | summed_l1 = sum(all_l1.values()) 157 | all_l1 = {f"l1_{k}": v for k, v in all_l1.items()} 158 | all_l1["l1"] = summed_l1 159 | return summed_l1, all_l1, B 160 | 161 | def configure_optimizers(self): 162 | optimizer_groups = self.policy.get_optimizer_groups( 163 | weight_decay=self.weight_decay, 164 | lr_layer_decay=self.lr_layer_decay, 165 | lr_scale=1.0, 166 | ) 167 | 168 | optimizer = torch.optim.AdamW( 169 | optimizer_groups, 170 | lr=self.lr, 171 | weight_decay=self.weight_decay, 172 | ) 173 | 174 | if self.use_cosine_lr: 175 | scheduler_kwargs = dict( 176 | base_value=1.0, # anneal from the original LR value 177 | final_value=self.lr_cosine_min / self.lr, 178 | epochs=self.lr_cosine_steps, 179 | warmup_start_value=self.lr_cosine_min / self.lr, 180 | warmup_epochs=self.lr_warmup_steps, 181 | steps_per_epoch=1, 182 | ) 183 | scheduler = torch.optim.lr_scheduler.LambdaLR( 184 | optimizer=optimizer, 185 | lr_lambda=CosineScheduleFunction(**scheduler_kwargs), 186 | ) 187 | return ( 188 | [optimizer], 189 | [{"scheduler": scheduler, "interval": "step"}], 190 | ) 191 | 192 | return optimizer 193 | -------------------------------------------------------------------------------- /brs_algo/learning/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/brs_algo/learning/nn/__init__.py -------------------------------------------------------------------------------- /brs_algo/learning/nn/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import build_mlp, MLP 2 | 
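`build_mlp` and `MLP`, re-exported by this `__init__.py`, are defined in `mlp.py` below and used by the feature encoders (e.g., `PointNet`). A minimal usage sketch follows; the dimensions and keyword choices are illustrative assumptions, not values taken from any config in this repo.

```python
import torch
from brs_algo.learning.nn.common import build_mlp

# Illustrative dimensions only.
net = build_mlp(
    128,                 # input_dim
    hidden_dim=256,
    output_dim=64,
    hidden_depth=2,      # two hidden layers; num_layers would be hidden_depth + 1
    activation="gelu",
    norm_type="layernorm",
)
out = net(torch.randn(8, 128))  # -> shape (8, 64)
```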
-------------------------------------------------------------------------------- /brs_algo/learning/nn/common/misc.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /brs_algo/learning/nn/common/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import Callable, Literal 3 | 4 | 5 | def get_activation(activation: str | Callable | None) -> Callable: 6 | if not activation: 7 | return nn.Identity 8 | elif callable(activation): 9 | return activation 10 | ACT_LAYER = { 11 | "tanh": nn.Tanh, 12 | "relu": lambda: nn.ReLU(inplace=True), 13 | "leaky_relu": lambda: nn.LeakyReLU(inplace=True), 14 | "swish": lambda: nn.SiLU(inplace=True), # SiLU is alias for Swish 15 | "sigmoid": nn.Sigmoid, 16 | "elu": lambda: nn.ELU(inplace=True), 17 | "gelu": nn.GELU, 18 | } 19 | activation = activation.lower() 20 | assert activation in ACT_LAYER, f"Supported activations: {ACT_LAYER.keys()}" 21 | return ACT_LAYER[activation] 22 | 23 | 24 | def get_initializer(method: str | Callable, activation: str) -> Callable: 25 | if isinstance(method, str): 26 | assert hasattr( 27 | nn.init, f"{method}_" 28 | ), f"Initializer nn.init.{method}_ does not exist" 29 | if method == "orthogonal": 30 | try: 31 | gain = nn.init.calculate_gain(activation) 32 | except ValueError: 33 | gain = 1.0 34 | return lambda x: nn.init.orthogonal_(x, gain=gain) 35 | else: 36 | return getattr(nn.init, f"{method}_") 37 | else: 38 | assert callable(method) 39 | return method 40 | 41 | 42 | def build_mlp( 43 | input_dim, 44 | *, 45 | hidden_dim: int, 46 | output_dim: int, 47 | hidden_depth: int = None, 48 | num_layers: int = None, 49 | activation: str | Callable = "relu", 50 | weight_init: str | Callable = "orthogonal", 51 | bias_init="zeros", 52 | norm_type: Literal["batchnorm", "layernorm"] | None = None, 53 | add_input_activation: bool | str | Callable = False, 54 | add_input_norm: bool = False, 55 | add_output_activation: bool | str | Callable = False, 56 | add_output_norm: bool = False, 57 | ) -> nn.Sequential: 58 | """ 59 | In other popular RL implementations, tanh is typically used with orthogonal 60 | initialization, which may perform better than ReLU. 61 | 62 | Args: 63 | norm_type: None, "batchnorm", "layernorm", applied to intermediate layers 64 | add_input_activation: whether to add a nonlinearity to the input _before_ 65 | the MLP computation. This is useful for processing a feature from a preceding 66 | image encoder, for example. Image encoder typically has a linear layer 67 | at the end, and we don't want the MLP to immediately stack another linear 68 | layer on the input features. 69 | - True to add the same activation as the rest of the MLP 70 | - str to add an activation of a different type. 71 | add_input_norm: see `add_input_activation`, whether to add a normalization layer 72 | to the input _before_ the MLP computation. 73 | values: True to add the `norm_type` to the input 74 | add_output_activation: whether to add a nonlinearity to the output _after_ the 75 | MLP computation. 76 | - True to add the same activation as the rest of the MLP 77 | - str to add an activation of a different type. 78 | add_output_norm: see `add_output_activation`, whether to add a normalization layer 79 | _after_ the MLP computation. 
80 | values: True to add the `norm_type` to the input 81 | """ 82 | assert (hidden_depth is None) != (num_layers is None), ( 83 | "Either hidden_depth or num_layers must be specified, but not both. " 84 | "num_layers is defined as hidden_depth+1" 85 | ) 86 | if hidden_depth is not None: 87 | assert hidden_depth >= 0 88 | if num_layers is not None: 89 | assert num_layers >= 1 90 | act_layer = get_activation(activation) 91 | 92 | weight_init = get_initializer(weight_init, activation) 93 | bias_init = get_initializer(bias_init, activation) 94 | 95 | if norm_type is not None: 96 | norm_type = norm_type.lower() 97 | 98 | if not norm_type: 99 | norm_type = nn.Identity 100 | elif norm_type == "batchnorm": 101 | norm_type = nn.BatchNorm1d 102 | elif norm_type == "layernorm": 103 | norm_type = nn.LayerNorm 104 | else: 105 | raise ValueError(f"Unsupported norm layer: {norm_type}") 106 | 107 | hidden_depth = num_layers - 1 if hidden_depth is None else hidden_depth 108 | if hidden_depth == 0: 109 | mods = [nn.Linear(input_dim, output_dim)] 110 | else: 111 | mods = [nn.Linear(input_dim, hidden_dim), norm_type(hidden_dim), act_layer()] 112 | for i in range(hidden_depth - 1): 113 | mods += [ 114 | nn.Linear(hidden_dim, hidden_dim), 115 | norm_type(hidden_dim), 116 | act_layer(), 117 | ] 118 | mods.append(nn.Linear(hidden_dim, output_dim)) 119 | 120 | if add_input_norm: 121 | mods = [norm_type(input_dim)] + mods 122 | if add_input_activation: 123 | if add_input_activation is not True: 124 | act_layer = get_activation(add_input_activation) 125 | mods = [act_layer()] + mods 126 | if add_output_norm: 127 | mods.append(norm_type(output_dim)) 128 | if add_output_activation: 129 | if add_output_activation is not True: 130 | act_layer = get_activation(add_output_activation) 131 | mods.append(act_layer()) 132 | 133 | for mod in mods: 134 | if isinstance(mod, nn.Linear): 135 | weight_init(mod.weight) 136 | bias_init(mod.bias) 137 | 138 | return nn.Sequential(*mods) 139 | 140 | 141 | class MLP(nn.Module): 142 | def __init__( 143 | self, 144 | input_dim, 145 | *, 146 | hidden_dim: int, 147 | output_dim: int, 148 | hidden_depth: int = None, 149 | num_layers: int = None, 150 | activation: str | Callable = "relu", 151 | weight_init: str | Callable = "orthogonal", 152 | bias_init="zeros", 153 | norm_type: Literal["batchnorm", "layernorm"] | None = None, 154 | add_input_activation: bool | str | Callable = False, 155 | add_input_norm: bool = False, 156 | add_output_activation: bool | str | Callable = False, 157 | add_output_norm: bool = False, 158 | ): 159 | super().__init__() 160 | # delegate to build_mlp by keywords 161 | self.layers = build_mlp( 162 | input_dim, 163 | hidden_dim=hidden_dim, 164 | output_dim=output_dim, 165 | hidden_depth=hidden_depth, 166 | num_layers=num_layers, 167 | activation=activation, 168 | weight_init=weight_init, 169 | bias_init=bias_init, 170 | norm_type=norm_type, 171 | add_input_activation=add_input_activation, 172 | add_input_norm=add_input_norm, 173 | add_output_activation=add_output_activation, 174 | add_output_norm=add_output_norm, 175 | ) 176 | # add attributes to the class 177 | self.input_dim = input_dim 178 | self.output_dim = output_dim 179 | self.hidden_depth = hidden_depth 180 | self.activation = activation 181 | self.weight_init = weight_init 182 | self.bias_init = bias_init 183 | self.norm_type = norm_type 184 | if add_input_activation is True: 185 | self.input_activation = activation 186 | else: 187 | self.input_activation = add_input_activation 188 | if add_input_norm is 
True: 189 | self.input_norm_type = norm_type 190 | else: 191 | self.input_norm_type = None 192 | # do the same for output activation and norm 193 | if add_output_activation is True: 194 | self.output_activation = activation 195 | else: 196 | self.output_activation = add_output_activation 197 | if add_output_norm is True: 198 | self.output_norm_type = norm_type 199 | else: 200 | self.output_norm_type = None 201 | 202 | def forward(self, x): 203 | return self.layers(x) 204 | -------------------------------------------------------------------------------- /brs_algo/learning/nn/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion_head import UNetDiffusionHead, WholeBodyUNetDiffusionHead 2 | from .unet import ConditionalUnet1D 3 | -------------------------------------------------------------------------------- /brs_algo/learning/nn/diffusion/unet.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import einops 6 | from einops.layers.torch import Rearrange 7 | 8 | from brs_algo.optim import default_optimizer_groups 9 | 10 | 11 | class Downsample1d(nn.Module): 12 | def __init__(self, dim): 13 | super().__init__() 14 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 15 | 16 | def forward(self, x): 17 | return self.conv(x) 18 | 19 | 20 | class Upsample1d(nn.Module): 21 | def __init__(self, dim): 22 | super().__init__() 23 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 24 | 25 | def forward(self, x): 26 | return self.conv(x) 27 | 28 | 29 | class SinusoidalPosEmb(nn.Module): 30 | def __init__(self, dim): 31 | super().__init__() 32 | self.dim = dim 33 | 34 | def forward(self, x): 35 | device = x.device 36 | half_dim = self.dim // 2 37 | emb = math.log(10000) / (half_dim - 1) 38 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 39 | emb = x[:, None] * emb[None, :] 40 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 41 | return emb 42 | 43 | 44 | class Conv1dBlock(nn.Module): 45 | """ 46 | Conv1d --> GroupNorm --> Mish 47 | """ 48 | 49 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 50 | super().__init__() 51 | 52 | self.block = nn.Sequential( 53 | nn.Conv1d( 54 | inp_channels, out_channels, kernel_size, padding=kernel_size // 2 55 | ), 56 | # Rearrange('batch channels horizon -> batch channels 1 horizon'), 57 | nn.GroupNorm(n_groups, out_channels), 58 | # Rearrange('batch channels 1 horizon -> batch channels horizon'), 59 | nn.Mish(), 60 | ) 61 | 62 | def forward(self, x): 63 | return self.block(x) 64 | 65 | 66 | class ConditionalResidualBlock1D(nn.Module): 67 | def __init__( 68 | self, 69 | in_channels, 70 | out_channels, 71 | cond_dim, 72 | kernel_size=3, 73 | n_groups=8, 74 | cond_predict_scale=False, 75 | ): 76 | super().__init__() 77 | 78 | self.blocks = nn.ModuleList( 79 | [ 80 | Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), 81 | Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), 82 | ] 83 | ) 84 | 85 | # FiLM modulation https://arxiv.org/abs/1709.07871 86 | # predicts per-channel scale and bias 87 | cond_channels = out_channels 88 | if cond_predict_scale: 89 | cond_channels = out_channels * 2 90 | self.cond_predict_scale = cond_predict_scale 91 | self.out_channels = out_channels 92 | self.cond_encoder = nn.Sequential( 93 | nn.Mish(), 94 | nn.Linear(cond_dim, cond_channels), 95 | Rearrange("batch t -> batch t 1"), 96 
| ) 97 | 98 | # make sure dimensions compatible 99 | self.residual_conv = ( 100 | nn.Conv1d(in_channels, out_channels, 1) 101 | if in_channels != out_channels 102 | else nn.Identity() 103 | ) 104 | 105 | def forward(self, x, cond): 106 | """ 107 | x : [ batch_size x in_channels x horizon ] 108 | cond : [ batch_size x cond_dim] 109 | 110 | returns: 111 | out : [ batch_size x out_channels x horizon ] 112 | """ 113 | out = self.blocks[0](x) 114 | embed = self.cond_encoder(cond) 115 | if self.cond_predict_scale: 116 | embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1) 117 | scale = embed[:, 0, ...] 118 | bias = embed[:, 1, ...] 119 | out = scale * out + bias 120 | else: 121 | out = out + embed 122 | out = self.blocks[1](out) 123 | out = out + self.residual_conv(x) 124 | return out 125 | 126 | 127 | class ConditionalUnet1D(nn.Module): 128 | def __init__( 129 | self, 130 | input_dim, 131 | *, 132 | local_cond_dim=None, 133 | global_cond_dim=None, 134 | diffusion_step_embed_dim=256, 135 | down_dims=[256, 512, 1024], 136 | kernel_size=3, 137 | n_groups=8, 138 | cond_predict_scale=False, 139 | ): 140 | super().__init__() 141 | all_dims = [input_dim] + list(down_dims) 142 | start_dim = down_dims[0] 143 | 144 | dsed = diffusion_step_embed_dim 145 | diffusion_step_encoder = nn.Sequential( 146 | SinusoidalPosEmb(dsed), 147 | nn.Linear(dsed, dsed * 4), 148 | nn.Mish(), 149 | nn.Linear(dsed * 4, dsed), 150 | ) 151 | cond_dim = dsed 152 | if global_cond_dim is not None: 153 | cond_dim += global_cond_dim 154 | 155 | in_out = list(zip(all_dims[:-1], all_dims[1:])) 156 | 157 | local_cond_encoder = None 158 | if local_cond_dim is not None: 159 | _, dim_out = in_out[0] 160 | dim_in = local_cond_dim 161 | local_cond_encoder = nn.ModuleList( 162 | [ 163 | # down encoder 164 | ConditionalResidualBlock1D( 165 | dim_in, 166 | dim_out, 167 | cond_dim=cond_dim, 168 | kernel_size=kernel_size, 169 | n_groups=n_groups, 170 | cond_predict_scale=cond_predict_scale, 171 | ), 172 | # up encoder 173 | ConditionalResidualBlock1D( 174 | dim_in, 175 | dim_out, 176 | cond_dim=cond_dim, 177 | kernel_size=kernel_size, 178 | n_groups=n_groups, 179 | cond_predict_scale=cond_predict_scale, 180 | ), 181 | ] 182 | ) 183 | 184 | mid_dim = all_dims[-1] 185 | self.mid_modules = nn.ModuleList( 186 | [ 187 | ConditionalResidualBlock1D( 188 | mid_dim, 189 | mid_dim, 190 | cond_dim=cond_dim, 191 | kernel_size=kernel_size, 192 | n_groups=n_groups, 193 | cond_predict_scale=cond_predict_scale, 194 | ), 195 | ConditionalResidualBlock1D( 196 | mid_dim, 197 | mid_dim, 198 | cond_dim=cond_dim, 199 | kernel_size=kernel_size, 200 | n_groups=n_groups, 201 | cond_predict_scale=cond_predict_scale, 202 | ), 203 | ] 204 | ) 205 | 206 | down_modules = nn.ModuleList([]) 207 | for ind, (dim_in, dim_out) in enumerate(in_out): 208 | is_last = ind >= (len(in_out) - 1) 209 | down_modules.append( 210 | nn.ModuleList( 211 | [ 212 | ConditionalResidualBlock1D( 213 | dim_in, 214 | dim_out, 215 | cond_dim=cond_dim, 216 | kernel_size=kernel_size, 217 | n_groups=n_groups, 218 | cond_predict_scale=cond_predict_scale, 219 | ), 220 | ConditionalResidualBlock1D( 221 | dim_out, 222 | dim_out, 223 | cond_dim=cond_dim, 224 | kernel_size=kernel_size, 225 | n_groups=n_groups, 226 | cond_predict_scale=cond_predict_scale, 227 | ), 228 | Downsample1d(dim_out) if not is_last else nn.Identity(), 229 | ] 230 | ) 231 | ) 232 | 233 | up_modules = nn.ModuleList([]) 234 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 235 | is_last = ind >= (len(in_out) - 1) 
236 | up_modules.append( 237 | nn.ModuleList( 238 | [ 239 | ConditionalResidualBlock1D( 240 | dim_out * 2, 241 | dim_in, 242 | cond_dim=cond_dim, 243 | kernel_size=kernel_size, 244 | n_groups=n_groups, 245 | cond_predict_scale=cond_predict_scale, 246 | ), 247 | ConditionalResidualBlock1D( 248 | dim_in, 249 | dim_in, 250 | cond_dim=cond_dim, 251 | kernel_size=kernel_size, 252 | n_groups=n_groups, 253 | cond_predict_scale=cond_predict_scale, 254 | ), 255 | Upsample1d(dim_in) if not is_last else nn.Identity(), 256 | ] 257 | ) 258 | ) 259 | 260 | final_conv = nn.Sequential( 261 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), 262 | nn.Conv1d(start_dim, input_dim, 1), 263 | ) 264 | 265 | self.diffusion_step_encoder = diffusion_step_encoder 266 | self.local_cond_encoder = local_cond_encoder 267 | self.up_modules = up_modules 268 | self.down_modules = down_modules 269 | self.final_conv = final_conv 270 | 271 | def forward( 272 | self, 273 | sample: torch.Tensor, 274 | timestep: Union[torch.Tensor, float, int], 275 | local_cond=None, 276 | global_cond=None, 277 | ): 278 | """ 279 | x: (B,T,input_dim) 280 | timestep: (B,) or int, diffusion step 281 | local_cond: (B,T,local_cond_dim) 282 | global_cond: (B,global_cond_dim) 283 | output: (B,T,input_dim) 284 | """ 285 | sample = einops.rearrange(sample, "b h t -> b t h") 286 | 287 | # 1. time 288 | timesteps = timestep 289 | if len(timesteps.shape) == 0: 290 | timesteps = timesteps[None].to(sample.device) 291 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 292 | timesteps = timesteps.expand(sample.shape[0]) 293 | 294 | global_feature = self.diffusion_step_encoder(timesteps) 295 | 296 | if global_cond is not None: 297 | global_feature = torch.cat([global_feature, global_cond], axis=-1) 298 | 299 | # encode local features 300 | h_local = list() 301 | if local_cond is not None: 302 | local_cond = einops.rearrange(local_cond, "b h t -> b t h") 303 | resnet, resnet2 = self.local_cond_encoder 304 | x = resnet(local_cond, global_feature) 305 | h_local.append(x) 306 | x = resnet2(local_cond, global_feature) 307 | h_local.append(x) 308 | 309 | x = sample 310 | h = [] 311 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): 312 | x = resnet(x, global_feature) 313 | if idx == 0 and len(h_local) > 0: 314 | x = x + h_local[0] 315 | x = resnet2(x, global_feature) 316 | h.append(x) 317 | x = downsample(x) 318 | 319 | for mid_module in self.mid_modules: 320 | x = mid_module(x, global_feature) 321 | 322 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): 323 | x = torch.cat((x, h.pop()), dim=1) 324 | x = resnet(x, global_feature) 325 | x = resnet2(x, global_feature) 326 | x = upsample(x) 327 | 328 | # apply local condition here after x being fully upsampled 329 | if len(h_local) > 0: 330 | x = x + h_local[1] 331 | x = self.final_conv(x) 332 | 333 | x = einops.rearrange(x, "b t h -> b h t") 334 | return x 335 | 336 | def get_optimizer_groups(self, weight_decay, lr_layer_decay, lr_scale=1.0): 337 | return default_optimizer_groups( 338 | self, 339 | weight_decay=weight_decay, 340 | lr_scale=lr_scale, 341 | ) 342 | -------------------------------------------------------------------------------- /brs_algo/learning/nn/features/__init__.py: -------------------------------------------------------------------------------- 1 | from .fusion import ObsTokenizer 2 | from .pointnet import PointNet 3 | -------------------------------------------------------------------------------- 
/brs_algo/learning/nn/features/fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from einops import rearrange 5 | 6 | import brs_algo.utils as U 7 | from brs_algo.optim import default_optimizer_groups 8 | 9 | 10 | class ObsTokenizer(nn.Module): 11 | def __init__( 12 | self, 13 | extractors: dict[str, nn.Module], 14 | *, 15 | use_modality_type_tokens: bool, 16 | token_dim: int, 17 | token_concat_order: list[str], 18 | strict: bool = True, 19 | ): 20 | assert set(extractors.keys()) == set(token_concat_order) 21 | super().__init__() 22 | self._extractors = nn.ModuleDict(extractors) 23 | self.output_dim = token_dim 24 | self._token_concat_order = token_concat_order 25 | self._strict = strict 26 | self._obs_groups = None 27 | self._use_modality_type_tokens = use_modality_type_tokens 28 | self._modality_type_tokens = None 29 | if use_modality_type_tokens: 30 | modality_type_tokens = {} 31 | for k in extractors: 32 | modality_type_tokens[k] = nn.Parameter(torch.zeros(token_dim)) 33 | self._modality_type_tokens = nn.ParameterDict(modality_type_tokens) 34 | 35 | def forward(self, obs: dict[str, torch.Tensor]): 36 | """ 37 | x: Dict of (B, T, ...) 38 | 39 | Each encoder should encode corresponding obs field to (B, T, E), where E = token_dim 40 | 41 | The final output interleaves encoded tokens along the time dimension 42 | """ 43 | obs = self._group_obs(obs) 44 | self._check_obs_key_match(obs) 45 | x = { 46 | k: v.forward(obs[k]) for k, v in self._extractors.items() 47 | } # dict of (B, T, E) 48 | if self._use_modality_type_tokens: 49 | for k in x: 50 | x[k] = x[k] + self._modality_type_tokens[k] 51 | x = rearrange( 52 | [x[k] for k in self._token_concat_order], 53 | "F B T E -> B (T F) E", 54 | ) 55 | self._check_output_shape(obs, x) 56 | return x 57 | 58 | def _group_obs(self, obs): 59 | obs_keys = obs.keys() 60 | if self._obs_groups is None: 61 | # group by / 62 | obs_groups = {k.split("/")[0] for k in obs_keys} 63 | self._obs_groups = sorted(list(obs_groups)) 64 | obs_rtn = {} 65 | for g in self._obs_groups: 66 | is_subgroup = any(k.startswith(f"{g}/") for k in obs_keys) 67 | if is_subgroup: 68 | obs_rtn[g] = { 69 | k.split("/", 1)[1]: v 70 | for k, v in obs.items() 71 | if k.startswith(f"{g}/") 72 | } 73 | else: 74 | obs_rtn[g] = obs[g] 75 | return obs_rtn 76 | 77 | @U.call_once 78 | def _check_obs_key_match(self, obs: dict): 79 | if self._strict: 80 | assert set(self._extractors.keys()) == set(obs.keys()) 81 | elif set(self._extractors.keys()) != set(obs.keys()): 82 | print( 83 | U.color_text( 84 | f"[warning] obs key mismatch: {set(self._extractors.keys())} != {set(obs.keys())}", 85 | "yellow", 86 | ) 87 | ) 88 | 89 | @U.call_once 90 | def _check_output_shape(self, obs, output): 91 | T = U.get_batch_size(U.any_slice(obs, np.s_[0]), strict=True) 92 | U.check_shape( 93 | output, (None, T * len(self._token_concat_order), self.output_dim) 94 | ) 95 | 96 | def get_optimizer_groups(self, weight_decay, lr_layer_decay, lr_scale=1.0): 97 | pg, pid = default_optimizer_groups( 98 | self, 99 | weight_decay=weight_decay, 100 | lr_scale=lr_scale, 101 | no_decay_filter=[ 102 | "_extractors.*", 103 | "_modality_type_tokens.*", 104 | ], 105 | ) 106 | return pg, pid 107 | 108 | @property 109 | def num_tokens_per_step(self): 110 | return len(self._token_concat_order) 111 | -------------------------------------------------------------------------------- /brs_algo/learning/nn/features/pointnet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import brs_algo.utils as U 5 | from brs_algo.learning.nn.common import build_mlp 6 | from brs_algo.optim import default_optimizer_groups 7 | 8 | 9 | class PointNetCore(nn.Module): 10 | def __init__( 11 | self, 12 | *, 13 | point_channels: int = 3, 14 | output_dim: int, 15 | hidden_dim: int, 16 | hidden_depth: int, 17 | activation: str = "gelu", 18 | ): 19 | super().__init__() 20 | self._mlp = build_mlp( 21 | input_dim=point_channels, 22 | hidden_dim=hidden_dim, 23 | output_dim=output_dim, 24 | hidden_depth=hidden_depth, 25 | activation=activation, 26 | ) 27 | self.output_dim = output_dim 28 | 29 | def forward(self, x): 30 | """ 31 | x: (..., points, point_channels) 32 | """ 33 | x = U.any_to_torch_tensor(x) 34 | x = self._mlp(x) # (..., points, output_dim) 35 | x = torch.max(x, dim=-2)[0] # (..., output_dim) 36 | return x 37 | 38 | 39 | class PointNet(nn.Module): 40 | def __init__( 41 | self, 42 | *, 43 | n_coordinates: int = 3, 44 | n_color: int = 3, 45 | output_dim: int = 512, 46 | hidden_dim: int = 512, 47 | hidden_depth: int = 2, 48 | activation: str = "gelu", 49 | subtract_mean: bool = False, 50 | ): 51 | super().__init__() 52 | pn_in_channels = n_coordinates + n_color 53 | if subtract_mean: 54 | pn_in_channels += n_coordinates 55 | self.pointnet = PointNetCore( 56 | point_channels=pn_in_channels, 57 | output_dim=output_dim, 58 | hidden_dim=hidden_dim, 59 | hidden_depth=hidden_depth, 60 | activation=activation, 61 | ) 62 | self.subtract_mean = subtract_mean 63 | self.output_dim = self.pointnet.output_dim 64 | 65 | def forward(self, x): 66 | """ 67 | x["xyz"]: (..., points, coordinates) 68 | x["rgb"]: (..., points, color) 69 | """ 70 | xyz = x["xyz"] 71 | rgb = x["rgb"] 72 | point = U.any_to_torch_tensor(xyz) 73 | if self.subtract_mean: 74 | mean = torch.mean(point, dim=-2, keepdim=True) # (..., 1, coordinates) 75 | mean = torch.broadcast_to(mean, point.shape) # (..., points, coordinates) 76 | point = point - mean 77 | point = torch.cat([point, mean], dim=-1) # (..., points, 2 * coordinates) 78 | rgb = U.any_to_torch_tensor(rgb) 79 | x = torch.cat([point, rgb], dim=-1) 80 | return self.pointnet(x) 81 | 82 | def get_optimizer_groups(self, weight_decay, lr_layer_decay, lr_scale=1.0): 83 | pg, pids = default_optimizer_groups( 84 | self, 85 | weight_decay=weight_decay, 86 | lr_scale=lr_scale, 87 | no_decay_filter=["ee_embd_layer.*"], 88 | ) 89 | return pg, pids 90 | -------------------------------------------------------------------------------- /brs_algo/learning/nn/gpt/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpt import GPT 2 | -------------------------------------------------------------------------------- /brs_algo/learning/policy/__init__.py: -------------------------------------------------------------------------------- 1 | from .wbvima_policy import WBVIMAPolicy 2 | -------------------------------------------------------------------------------- /brs_algo/learning/policy/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from pytorch_lightning import LightningModule 4 | 5 | 6 | class BasePolicy(ABC, LightningModule): 7 | is_sequence_policy: bool = ( 8 | False # is this a feedforward policy or a policy requiring history 9 | ) 10 | 11 | @abstractmethod 12 | def forward(self, *args, **kwargs): 13 | """ 14 | 
Forward the NN. 15 | """ 16 | pass 17 | 18 | @abstractmethod 19 | def act(self, *args, **kwargs): 20 | """ 21 | Given obs, return action. 22 | """ 23 | pass 24 | 25 | 26 | class BaseDiffusionPolicy(BasePolicy): 27 | @abstractmethod 28 | def get_optimizer_groups(self, *args, **kwargs): 29 | """ 30 | Return a list of optimizer groups. 31 | """ 32 | pass 33 | -------------------------------------------------------------------------------- /brs_algo/lightning/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | -------------------------------------------------------------------------------- /brs_algo/lightning/lightning.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import logging 3 | import time 4 | from copy import deepcopy 5 | 6 | import sys 7 | from omegaconf import DictConfig, OmegaConf, ListConfig 8 | import pytorch_lightning as pl 9 | import pytorch_lightning.loggers as pl_loggers 10 | from pytorch_lightning.callbacks import ( 11 | Callback, 12 | ProgressBar, 13 | TQDMProgressBar, 14 | RichProgressBar, 15 | ) 16 | from pytorch_lightning.utilities import rank_zero_only 17 | from pytorch_lightning.utilities.rank_zero import rank_zero_debug as rank_zero_debug_pl 18 | from pytorch_lightning.utilities.rank_zero import rank_zero_info as rank_zero_info_pl 19 | from pytorch_lightning.callbacks import ModelCheckpoint 20 | 21 | import brs_algo.utils as U 22 | 23 | 24 | __all__ = [ 25 | "LightingTrainer", 26 | "rank_zero_info", 27 | "rank_zero_debug", 28 | "rank_zero_warn", 29 | "rank_zero_info_pl", 30 | "rank_zero_debug_pl", 31 | ] 32 | 33 | logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.WARNING) 34 | logging.getLogger("torch.distributed.nn.jit.instantiator").setLevel(logging.WARNING) 35 | U.logging_exclude_pattern( 36 | "root", patterns="*Reducer buckets have been rebuilt in this iteration*" 37 | ) 38 | 39 | 40 | class LightingTrainer: 41 | def __init__(self, cfg: DictConfig, eval_only=False): 42 | """ 43 | Args: 44 | eval_only: if True, will not save any model dir 45 | """ 46 | cfg = deepcopy(cfg) 47 | OmegaConf.set_struct(cfg, False) 48 | self.cfg = cfg 49 | self.run_command_args = sys.argv[1:] 50 | U.register_omegaconf_resolvers() 51 | self.register_extra_resolvers(cfg) 52 | self._register_classes(cfg) 53 | # U.pprint_(OmegaConf.to_container(cfg, resolve=True)) 54 | run_name = self.generate_run_name(cfg) 55 | self.run_dir = U.f_join(cfg.exp_root_dir, run_name) 56 | # rank_zero_info(OmegaConf.to_yaml(cfg, resolve=True)) 57 | self._eval_only = eval_only 58 | self._resume_mode = None # 'full state' or 'model only' 59 | if eval_only: 60 | rank_zero_info("Eval only, will not save any model dir") 61 | else: 62 | if "resume" in cfg and "ckpt_path" in cfg.resume and cfg.resume.ckpt_path: 63 | cfg.resume.ckpt_path = U.f_expand( 64 | cfg.resume.ckpt_path.replace("_RUN_DIR_", self.run_dir).replace( 65 | "_RUN_NAME_", run_name 66 | ) 67 | ) 68 | self._resume_mode = ( 69 | "full state" 70 | if cfg.resume.get("full_state", False) 71 | else "model only" 72 | ) 73 | rank_zero_info( 74 | "=" * 80, 75 | "=" * 80 + "\n", 76 | f"Resume training from {cfg.resume.ckpt_path}", 77 | f"\t({self._resume_mode})\n", 78 | "=" * 80, 79 | "=" * 80, 80 | sep="\n", 81 | end="\n\n", 82 | ) 83 | time.sleep(3) 84 | assert U.f_exists( 85 | cfg.resume.ckpt_path 86 | ), "resume ckpt_path does not exist" 87 | 88 | rank_zero_print("Run name:", run_name, "\nExp 
dir:", self.run_dir) 89 | U.f_mkdir(self.run_dir) 90 | U.f_mkdir(U.f_join(self.run_dir, "tb")) 91 | U.f_mkdir(U.f_join(self.run_dir, "logs")) 92 | U.f_mkdir(U.f_join(self.run_dir, "ckpt")) 93 | U.omegaconf_save(cfg, self.run_dir, "conf.yaml") 94 | rank_zero_print( 95 | "Checkpoint cfg:", U.omegaconf_to_dict(cfg.trainer.checkpoint) 96 | ) 97 | self.cfg = cfg 98 | self.run_name = run_name 99 | self.ckpt_cfg = cfg.trainer.pop("checkpoint") 100 | self.data_module = self.create_data_module(cfg) 101 | self._monkey_patch_add_info(self.data_module) 102 | self.trainer = self.create_trainer(cfg) 103 | self.module = self.create_module(cfg) 104 | self.module.data_module = self.data_module 105 | self._monkey_patch_add_info(self.module) 106 | 107 | if not eval_only and self._resume_mode == "model only": 108 | ret = self.module.load_state_dict( 109 | U.torch_load(cfg.resume.ckpt_path)["state_dict"], 110 | strict=cfg.resume.strict, 111 | ) 112 | U.rank_zero_warn("state_dict load status:", ret) 113 | 114 | def create_module(self, cfg) -> pl.LightningModule: 115 | return U.instantiate(cfg.module) 116 | 117 | def create_data_module(self, cfg) -> pl.LightningDataModule: 118 | return U.instantiate(cfg.data_module) 119 | 120 | def generate_run_name(self, cfg): 121 | return cfg.run_name + "_" + time.strftime("%Y%m%d-%H%M%S") 122 | 123 | def _monkey_patch_add_info(self, obj): 124 | """ 125 | Add useful info to module and data_module so they can access directly 126 | """ 127 | # our own info 128 | obj.run_config = self.cfg 129 | obj.run_name = self.run_name 130 | obj.run_command_args = self.run_command_args 131 | # add properties from trainer 132 | for attr in [ 133 | "global_rank", 134 | "local_rank", 135 | "world_size", 136 | "num_nodes", 137 | "num_processes", 138 | "node_rank", 139 | "num_gpus", 140 | "data_parallel_device_ids", 141 | ]: 142 | if hasattr(obj, attr): 143 | continue 144 | setattr( 145 | obj.__class__, 146 | attr, 147 | # force capture 'attr' 148 | property(lambda self, attr=attr: getattr(self.trainer, attr)), 149 | ) 150 | 151 | def create_loggers(self, cfg) -> List[pl.loggers.Logger]: 152 | if self._eval_only: 153 | loggers = [] 154 | else: 155 | loggers = [ 156 | pl_loggers.TensorBoardLogger(self.run_dir, name="tb", version=""), 157 | pl_loggers.CSVLogger(self.run_dir, name="logs", version=""), 158 | ] 159 | if cfg.use_wandb: 160 | loggers.append( 161 | pl_loggers.WandbLogger( 162 | name=cfg.wandb_run_name, project=cfg.wandb_project, id=self.run_name 163 | ) 164 | ) 165 | return loggers 166 | 167 | def create_callbacks(self, cfg) -> List[Callback]: 168 | ModelCheckpoint.FILE_EXTENSION = ".pth" 169 | callbacks = [] 170 | if isinstance(self.ckpt_cfg, DictConfig): 171 | ckpt = ModelCheckpoint( 172 | dirpath=U.f_join(self.run_dir, "ckpt"), **self.ckpt_cfg 173 | ) 174 | callbacks.append(ckpt) 175 | else: 176 | assert isinstance(self.ckpt_cfg, ListConfig) 177 | for _cfg in self.ckpt_cfg: 178 | ckpt = ModelCheckpoint(dirpath=U.f_join(self.run_dir, "ckpt"), **_cfg) 179 | callbacks.append(ckpt) 180 | 181 | if "callbacks" in cfg.trainer: 182 | extra_callbacks = U.instantiate(cfg.trainer.pop("callbacks")) 183 | assert U.is_sequence(extra_callbacks), "callbacks must be a sequence" 184 | callbacks.extend(extra_callbacks) 185 | if not any(isinstance(c, ProgressBar) for c in callbacks): 186 | callbacks.append(CustomTQDMProgressBar()) 187 | rank_zero_print( 188 | "Lightning callbacks:", [c.__class__.__name__ for c in callbacks] 189 | ) 190 | return callbacks 191 | 192 | def create_trainer(self, cfg) -> 
pl.Trainer: 193 | assert "trainer" in cfg 194 | C = cfg.trainer 195 | # find_unused_parameters = C.pop("find_unused_parameters", False) 196 | # rank_zero_info("DDP Strategy", C.strategy) 197 | return U.instantiate( 198 | C, logger=self.create_loggers(cfg), callbacks=self.create_callbacks(cfg) 199 | ) 200 | 201 | @property 202 | def tb_logger(self): 203 | return self.logger[0].experiment 204 | 205 | def fit(self): 206 | return self.trainer.fit( 207 | self.module, 208 | datamodule=self.data_module, 209 | ckpt_path=( 210 | self.cfg.resume.ckpt_path if self._resume_mode == "full state" else None 211 | ), 212 | ) 213 | 214 | def validate(self): 215 | rank_zero_print("Start validation ...") 216 | assert "testing" in self.cfg, "`testing` sub-dict must be defined in config" 217 | ckpt_path = self.cfg.testing.ckpt_path 218 | if ckpt_path: 219 | ckpt_path = U.f_expand(ckpt_path) 220 | assert U.f_exists(ckpt_path), f"ckpt_path {ckpt_path} does not exist" 221 | rank_zero_info("Run validation on ckpt:", ckpt_path) 222 | ret = self.module.load_state_dict( 223 | U.torch_load(ckpt_path)["state_dict"], strict=self.cfg.testing.strict 224 | ) 225 | U.rank_zero_warn("state_dict load status:", ret) 226 | ckpt_path = None # not using pytorch lightning's load 227 | else: 228 | rank_zero_warn("WARNING: no ckpt_path specified, will NOT load any weights") 229 | return self.trainer.validate( 230 | self.module, datamodule=self.data_module, ckpt_path=ckpt_path 231 | ) 232 | 233 | def _register_classes(self, cfg): 234 | U.register_callable("DDPStrategy", pl.strategies.DDPStrategy) 235 | U.register_callable("LearningRateMonitor", pl.callbacks.LearningRateMonitor) 236 | U.register_callable("ModelSummary", pl.callbacks.ModelSummary) 237 | U.register_callable("RichModelSummary", pl.callbacks.RichModelSummary) 238 | self.register_extra_classes(cfg) 239 | 240 | def register_extra_classes(self, cfg): 241 | pass 242 | 243 | def register_extra_resolvers(self, cfg): 244 | pass 245 | 246 | 247 | @U.register_class 248 | class CustomTQDMProgressBar(TQDMProgressBar): 249 | """ 250 | https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ProgressBarBase.html#pytorch_lightning.callbacks.ProgressBarBase.get_metrics 251 | """ 252 | 253 | def get_metrics(self, trainer, model): 254 | # don't show the version number 255 | items = super().get_metrics(trainer, model) 256 | items.pop("v_num", None) 257 | return items 258 | 259 | 260 | # do the same overriding for RichProgressBar 261 | @U.register_class 262 | class CustomRichProgressBar(RichProgressBar): 263 | def get_metrics(self, trainer, model): 264 | # don't show the version number 265 | items = super().get_metrics(trainer, model) 266 | items.pop("v_num", None) 267 | return items 268 | 269 | 270 | @rank_zero_only 271 | def rank_zero_print(*msg, **kwargs): 272 | U.pprint_(*msg, **kwargs) 273 | 274 | 275 | @rank_zero_only 276 | def rank_zero_info(*msg, **kwargs): 277 | U.pprint_( 278 | U.color_text("[INFO]", color="green", styles=["reverse", "bold"]), 279 | *msg, 280 | **kwargs, 281 | ) 282 | 283 | 284 | @rank_zero_only 285 | def rank_zero_warn(*msg, **kwargs): 286 | U.pprint_( 287 | U.color_text("[WARN]", color="yellow", styles=["reverse", "bold"]), 288 | *msg, 289 | **kwargs, 290 | ) 291 | 292 | 293 | @rank_zero_only 294 | def rank_zero_debug(*msg, **kwargs): 295 | if rank_zero_debug.enabled: 296 | U.pprint_( 297 | U.color_text("[DEBUG]", color="blue", bg_color="on_grey"), *msg, **kwargs 298 | ) 299 | 300 | 301 | rank_zero_debug.enabled = True 302 | 
-------------------------------------------------------------------------------- /brs_algo/lightning/trainer.py: -------------------------------------------------------------------------------- 1 | from hydra.utils import instantiate 2 | 3 | from brs_algo.lightning.lightning import LightingTrainer 4 | 5 | 6 | class Trainer(LightingTrainer): 7 | def create_module(self, cfg): 8 | return instantiate(cfg.module, _recursive_=False) 9 | 10 | def create_data_module(self, cfg): 11 | return instantiate(cfg.data_module) 12 | 13 | def create_callbacks(self, cfg): 14 | is_rollout_eval = self.cfg.rollout_eval 15 | 16 | del_idxs = [] 17 | for i, _cfg in enumerate(self.ckpt_cfg): 18 | eval_type = getattr(_cfg, "eval_type", None) 19 | if eval_type is not None: 20 | if is_rollout_eval and eval_type == "static": 21 | del_idxs.append(i) 22 | elif not is_rollout_eval and eval_type == "rollout": 23 | del_idxs.append(i) 24 | del _cfg["eval_type"] 25 | for i in reversed(del_idxs): 26 | del self.ckpt_cfg[i] 27 | return super().create_callbacks(cfg) 28 | -------------------------------------------------------------------------------- /brs_algo/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_schedule import * 2 | from .optimizer_group import * 3 | -------------------------------------------------------------------------------- /brs_algo/optim/lr_schedule.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | from torch.optim.lr_scheduler import LambdaLR 5 | 6 | 7 | __all__ = [ 8 | "CosineScheduleFunction", 9 | "CosineLRScheduler", 10 | "LambdaLRWithScale", 11 | "generate_cosine_schedule", 12 | ] 13 | 14 | 15 | def generate_cosine_schedule( 16 | base_value, 17 | final_value, 18 | epochs, 19 | steps_per_epoch, 20 | warmup_epochs=0, 21 | warmup_start_value=0, 22 | ) -> np.ndarray: 23 | warmup_schedule = np.array([]) 24 | warmup_iters = int(warmup_epochs * steps_per_epoch) 25 | if warmup_epochs > 0: 26 | warmup_schedule = np.linspace(warmup_start_value, base_value, warmup_iters) 27 | 28 | iters = np.arange(int(epochs * steps_per_epoch) - warmup_iters) 29 | schedule = np.array( 30 | [ 31 | final_value 32 | + 0.5 33 | * (base_value - final_value) 34 | * (1 + math.cos(math.pi * i / (len(iters)))) 35 | for i in iters 36 | ] 37 | ) 38 | schedule = np.concatenate((warmup_schedule, schedule)) 39 | assert len(schedule) == int(epochs * steps_per_epoch) 40 | return schedule 41 | 42 | 43 | class CosineScheduleFunction: 44 | def __init__( 45 | self, 46 | base_value, 47 | final_value, 48 | epochs, 49 | steps_per_epoch, 50 | warmup_epochs=0, 51 | warmup_start_value=0, 52 | ): 53 | """ 54 | Usage: 55 | scheduler = torch.optim.lr_scheduler.LambdaLR( 56 | optimizer=optimizer, lr_lambda=CosineScheduleFunction(**kwargs) 57 | ) 58 | or simply use CosineScheduler(**kwargs) 59 | 60 | Args: 61 | epochs: effective epochs for the cosine schedule, *including* warmup 62 | after these epochs, scheduler will output `final_value` ever after 63 | """ 64 | assert warmup_epochs < epochs, f"{warmup_epochs=} must be < {epochs=}" 65 | self._effective_steps = int(epochs * steps_per_epoch) 66 | self.schedule = generate_cosine_schedule( 67 | base_value=base_value, 68 | final_value=final_value, 69 | epochs=epochs, 70 | steps_per_epoch=steps_per_epoch, 71 | warmup_epochs=warmup_epochs, 72 | warmup_start_value=warmup_start_value, 73 | ) 74 | assert self.schedule.shape == (self._effective_steps,) 75 | 
self._final_value = final_value 76 | self._steps_tensor = torch.tensor(0, dtype=torch.long) # for register buffer 77 | 78 | def register_buffer(self, module: torch.nn.Module, name="cosine_steps"): 79 | module.register_buffer(name, self._steps_tensor, persistent=True) 80 | 81 | def __call__(self, step): 82 | self._steps_tensor.copy_(torch.tensor(step)) 83 | if step >= self._effective_steps: 84 | val = self._final_value 85 | else: 86 | val = self.schedule[step] 87 | return val 88 | 89 | 90 | class LambdaLRWithScale(LambdaLR): 91 | """ 92 | Supports param_groups['lr_scale'], multiplies base_lr with lr_scale 93 | """ 94 | 95 | def get_lr(self): 96 | lrs = super().get_lr() 97 | param_groups = self.optimizer.param_groups 98 | assert len(lrs) == len( 99 | param_groups 100 | ), f"INTERNAL: {len(lrs)=} != {len(param_groups)=}" 101 | for i, param_group in enumerate(param_groups): 102 | if "lr_scale" in param_group: 103 | lrs[i] *= param_group["lr_scale"] 104 | # print("LambdaLRWithScale: lrs =", lrs) 105 | return lrs 106 | 107 | 108 | class CosineLRScheduler(LambdaLRWithScale): 109 | """ 110 | Supports param_groups['lr_scale'], multiplies base_lr with lr_scale 111 | """ 112 | 113 | def __init__( 114 | self, 115 | optimizer, 116 | base_value, 117 | final_value, 118 | epochs, 119 | steps_per_epoch, 120 | warmup_epochs=0, 121 | warmup_start_value=0, 122 | last_epoch=-1, 123 | verbose=False, 124 | ): 125 | lr_lambda = CosineScheduleFunction( 126 | base_value=base_value, 127 | final_value=final_value, 128 | epochs=epochs, 129 | steps_per_epoch=steps_per_epoch, 130 | warmup_epochs=warmup_epochs, 131 | warmup_start_value=warmup_start_value, 132 | ) 133 | super().__init__( 134 | optimizer, lr_lambda=lr_lambda, last_epoch=last_epoch, verbose=verbose 135 | ) 136 | -------------------------------------------------------------------------------- /brs_algo/optim/optimizer_group.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate Optimizer groups 3 | """ 4 | 5 | from typing import Callable, Union, List, Tuple 6 | import torch 7 | import torch.nn as nn 8 | 9 | from brs_algo.utils.misc_utils import match_patterns, getattr_nested 10 | from brs_algo.utils.torch_utils import freeze_params 11 | 12 | 13 | __all__ = [ 14 | "default_optimizer_groups", 15 | "transformer_lr_decay_optimizer_groups", 16 | "transformer_freeze_layers", 17 | "transformer_freeze_except_last_layers", 18 | "check_optimizer_groups", 19 | ] 20 | 21 | FilterType = Union[ 22 | Callable[[str, torch.Tensor], bool], List[str], Tuple[str], str, None 23 | ] 24 | 25 | HAS_REMOVEPREFIX = hasattr(str, "removeprefix") 26 | 27 | 28 | def default_optimizer_groups( 29 | model: nn.Module, 30 | weight_decay: float, 31 | lr_scale: float = 1.0, 32 | no_decay_filter: FilterType = None, 33 | exclude_filter: FilterType = None, 34 | ): 35 | """ 36 | lr_scale is only effective when using with enlight.learn.lr_schedule.LambdaLRWithScale 37 | 38 | Returns: 39 | [{'lr_scale': 1.0, 'weight_decay': weight_decay, 'params': decay_group}, 40 | {'lr_scale': 1.0, 'weight_decay': 0.0, 'params': no_decay_group}], 41 | list of all param_ids processed 42 | """ 43 | no_decay_filter = _transform_filter(no_decay_filter) 44 | exclude_filter = _transform_filter(exclude_filter) 45 | decay_group = [] 46 | no_decay_group = [] 47 | all_params_id = [] 48 | for n, p in model.named_parameters(): 49 | all_params_id.append(id(p)) 50 | if not p.requires_grad or exclude_filter(n, p): 51 | continue 52 | 53 | # no decay: all 1D parameters and model 
specific ones 54 | if p.ndim == 1 or no_decay_filter(n, p): 55 | no_decay_group.append(p) 56 | else: 57 | decay_group.append(p) 58 | return [ 59 | {"weight_decay": weight_decay, "params": decay_group, "lr_scale": lr_scale}, 60 | {"weight_decay": 0.0, "params": no_decay_group, "lr_scale": lr_scale}, 61 | ], all_params_id 62 | 63 | 64 | def _get_transformer_blocks(model, block_sequence_name): 65 | block_sequence_name = block_sequence_name.rstrip(".") 66 | return getattr_nested(model, block_sequence_name) 67 | 68 | 69 | def transformer_freeze_layers( 70 | model, 71 | layer_0_params: List[str], 72 | block_sequence_name, 73 | freeze_layers: List[int], 74 | extra_freeze_filter: FilterType = None, 75 | ): 76 | """ 77 | Args: 78 | model: transformer model with pos embed and other preprocessing parts as layer 0 79 | and `block_sequence_name` that is a sequence of transformer blocks 80 | layer_0_params: list of parameter names before the first transformer 81 | block, which will be assigned layer 0 82 | block_sequence_name: name of the sequence module that contains transformer layers 83 | such that "block.0", "block.1", ... will share one LR within each block 84 | freeze_layers: list of layer indices to freeze. Include 0 to freeze 85 | preprocessing layers. Use negative indices to freeze from the last layers, 86 | e.g. -1 to freeze the last layer. Note that any nn.Module after the transformer 87 | block will NOT be frozen. 88 | extra_freeze_filter: filter to apply to the rest of the parameters 89 | """ 90 | extra_freeze_filter = _transform_filter(extra_freeze_filter) 91 | freeze_layers = list(freeze_layers) 92 | layer_0_params = _transform_filter(layer_0_params) 93 | blocks = _get_transformer_blocks(model, block_sequence_name) 94 | assert max(freeze_layers) <= len(blocks), f"max({freeze_layers}) > {len(blocks)}" 95 | # convert all negative indices to last 96 | freeze_layers = [(L if L >= 0 else len(blocks) + L + 1) for L in freeze_layers] 97 | for i, block in enumerate(blocks): 98 | if i + 1 in freeze_layers: 99 | freeze_params(block) 100 | 101 | for n, p in model.named_parameters(): 102 | if layer_0_params(n, p) and 0 in freeze_layers: 103 | freeze_params(p) 104 | if extra_freeze_filter(n, p): 105 | freeze_params(p) 106 | 107 | 108 | def transformer_freeze_except_last_layers( 109 | model, 110 | layer_0_params: List[str], 111 | block_sequence_name, 112 | num_last_layers: int, 113 | extra_freeze_filter: FilterType = None, 114 | ): 115 | """ 116 | According to Kaiming's MAE paper, finetune ONLY the last layers is typically 117 | as good as finetuning all layers, while being much more compute and memory efficient. 118 | 119 | Args: 120 | model: transformer model with pos embed and other preprocessing parts as layer 0 121 | and `block_sequence_name` that is a sequence of transformer blocks 122 | layer_0_params: list of parameter names before the first transformer 123 | block, which will be assigned layer 0 124 | block_sequence_name: name of the sequence module that contains transformer layers 125 | such that "block.0", "block.1", ... 
will share one LR within each block 126 | num_last_layers: number of last N layers to unfreeze (finetune) 127 | extra_freeze_filter: filter to apply to the rest of the parameters 128 | """ 129 | num_blocks = len(_get_transformer_blocks(model, block_sequence_name)) 130 | # blocks start from 1, because 0th block is preprocessing 131 | # get a range of the first num_blocks - num_last_layers blocks 132 | return transformer_freeze_layers( 133 | model, 134 | layer_0_params=layer_0_params, 135 | block_sequence_name=block_sequence_name, 136 | freeze_layers=range(num_blocks - num_last_layers + 1), 137 | extra_freeze_filter=extra_freeze_filter, 138 | ) 139 | 140 | 141 | def transformer_lr_decay_optimizer_groups( 142 | model, 143 | layer_0_params: List[str], 144 | block_sequence_name, 145 | *, 146 | weight_decay, 147 | lr_scale=1.0, 148 | lr_layer_decay, 149 | no_decay_filter: FilterType = None, 150 | exclude_filter: FilterType = None, 151 | ): 152 | """ 153 | Parameter groups for layer-wise lr decay 154 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 155 | lr_scale is only effective when using with enlight.learn.lr_schedule.LambdaLRWithScale 156 | 157 | Args: 158 | model: transformer model with pos embed and other preprocessing parts as layer 0, 159 | then the blocks start from layer 1 to layer N. LR will be progressively smaller 160 | from the last layer to the first layer. Layer N will be lr*lr_scale, 161 | N-1 will be lr*lr_scale*lr_layer_decay, 162 | N-2 will be lr*lr_scale*lr_layer_decay^2, ... 163 | layer_0_params: list of parameter names before the first transformer 164 | block, which will be assigned layer 0 165 | block_sequence_name: name of the sequence module that contains transformer layers 166 | such that "block.0", "block.1", ... will share one LR within each block 167 | """ 168 | no_decay_filter = _transform_filter(no_decay_filter) 169 | exclude_filter = _transform_filter(exclude_filter) 170 | layer_0_params = _transform_filter(layer_0_params) 171 | block_sequence_name = block_sequence_name.rstrip(".") 172 | 173 | param_group_names = {} 174 | param_groups = {} 175 | 176 | num_layers = len(getattr_nested(model, block_sequence_name)) + 1 177 | 178 | layer_scales = [ 179 | lr_scale * lr_layer_decay ** (num_layers - i) for i in range(num_layers + 1) 180 | ] 181 | all_params_id = [] 182 | 183 | for n, p in model.named_parameters(): 184 | all_params_id.append(id(p)) 185 | if not p.requires_grad or exclude_filter(n, p): 186 | continue 187 | 188 | # no decay: all 1D parameters and model specific ones 189 | if p.ndim == 1 or no_decay_filter(n, p): 190 | g_decay = "no_decay" 191 | this_decay = 0.0 192 | else: 193 | g_decay = "decay" 194 | this_decay = weight_decay 195 | 196 | # get the layer index of the param 197 | if layer_0_params(n, p): 198 | layer_id = 0 199 | elif n.startswith(block_sequence_name + "."): 200 | # blocks.0 -> layer 1; blocks.1 -> layer 2; ... 201 | try: 202 | if HAS_REMOVEPREFIX: 203 | layer_id = ( 204 | int(n.removeprefix(block_sequence_name + ".").split(".")[0]) + 1 205 | ) 206 | else: 207 | layer_id = int(n[len(block_sequence_name) + 1 :].split(".")[0]) + 1 208 | 209 | except ValueError: 210 | raise ValueError( 211 | f"{n} must have the format {block_sequence_name}.... 
" 212 | f"where is an integer" 213 | ) 214 | else: 215 | layer_id = num_layers 216 | group_name = f"layer_{layer_id}_{g_decay}" 217 | 218 | if group_name not in param_group_names: 219 | this_scale = layer_scales[layer_id] 220 | 221 | param_group_names[group_name] = { 222 | "lr_scale": this_scale, 223 | "weight_decay": this_decay, 224 | "params": [], 225 | } 226 | param_groups[group_name] = { 227 | "lr_scale": this_scale, 228 | "weight_decay": this_decay, 229 | "params": [], 230 | } 231 | 232 | param_group_names[group_name]["params"].append(n) 233 | param_groups[group_name]["params"].append(p) 234 | 235 | return list(param_groups.values()), all_params_id 236 | 237 | 238 | def check_optimizer_groups( 239 | model, param_groups: List[dict], verbose=True, order_by="param" 240 | ): 241 | """ 242 | For debugging purpose, check which param belongs to which param group 243 | This groups by param groups first 244 | 245 | Args: 246 | order_by: 'params' or 'groups', either by optimization group order or 247 | by nn.Module parameter list order. 248 | 249 | Returns: 250 | {param_name (str): group_idx (int)}, table (str, for ASCII print) 251 | """ 252 | from tabulate import tabulate 253 | 254 | group_configs = [ 255 | ", ".join(f"{k}={v:.6g}" for k, v in sorted(group.items()) if k != "params") 256 | for group in param_groups 257 | ] 258 | display_table = [] 259 | name_to_group_idx = {} 260 | pid_to_group_id = { 261 | id(p): i for i, group in enumerate(param_groups) for p in group["params"] 262 | } 263 | for n, p in model.named_parameters(recurse=True): 264 | if id(p) in pid_to_group_id: 265 | gid = pid_to_group_id[id(p)] 266 | display_table.append((n, gid, group_configs[gid])) 267 | if n in name_to_group_idx: 268 | if verbose: 269 | print( 270 | f"WARNING: {n} is in both group " 271 | f"{name_to_group_idx[n]} and {gid}" 272 | ) 273 | name_to_group_idx[n] = gid 274 | else: 275 | display_table.append((n, "_", "excluded")) 276 | name_to_group_idx[n] = None 277 | if order_by == "group": 278 | display_table.sort(key=lambda x: 1e10 if x[1] == "_" else x[1]) 279 | table_str = tabulate( 280 | display_table, headers=["param", "i", "config"], tablefmt="presto" 281 | ) 282 | return name_to_group_idx, table_str 283 | 284 | 285 | def _transform_filter(filter: FilterType): 286 | """ 287 | Filter can be: 288 | - None: always returns False 289 | - function(name, p) -> True to activate, False to deactivate 290 | - list of strings to match, can have wildcard 291 | """ 292 | if filter is None: 293 | return lambda name, p: False 294 | elif callable(filter): 295 | return filter 296 | elif isinstance(filter, (str, list, tuple)): 297 | if isinstance(filter, str): 298 | filter = [filter] 299 | return lambda name, p: match_patterns(name, include=filter) 300 | else: 301 | raise ValueError(f"Invalid filter: {filter}") 302 | -------------------------------------------------------------------------------- /brs_algo/rollout/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/brs_algo/rollout/__init__.py -------------------------------------------------------------------------------- /brs_algo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .array_tensor_utils import * 2 | from .config_utils import * 3 | from .convert_utils import * 4 | from .file_utils import * 5 | from .functional_utils import * 6 | from .misc_utils import * 7 | 
from .print_utils import * 8 | from .random_seed import * 9 | from .shape_utils import * 10 | from .termcolor import * 11 | from .torch_utils import * 12 | from .tree_utils import * 13 | -------------------------------------------------------------------------------- /brs_algo/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import tree 5 | import hydra 6 | import importlib.resources 7 | import sys 8 | 9 | from copy import deepcopy 10 | from omegaconf import OmegaConf, DictConfig 11 | from .functional_utils import meta_decorator, is_sequence, is_mapping, call_once 12 | from .print_utils import to_scientific_str 13 | 14 | _CLASS_REGISTRY = {} # for instantiation 15 | 16 | 17 | def resource_file_path(pkg_name, fname) -> str: 18 | with importlib.resources.path(pkg_name, fname) as p: 19 | return str(p) 20 | 21 | 22 | def print_config(cfg: DictConfig): 23 | print(cfg.pretty(resolve=True)) 24 | 25 | 26 | def is_hydra_initialized(): 27 | return hydra.utils.HydraConfig.initialized() 28 | 29 | 30 | def hydra_config(): 31 | # https://github.com/facebookresearch/hydra/issues/377 32 | # HydraConfig() is a singleton 33 | if is_hydra_initialized(): 34 | return hydra.utils.HydraConfig().cfg.hydra 35 | else: 36 | return None 37 | 38 | 39 | def hydra_override_arg_list() -> list[str]: 40 | """ 41 | Returns: 42 | list ["lr=0.2", "batch=64", ...] 43 | """ 44 | if is_hydra_initialized(): 45 | return hydra_config().overrides.task 46 | else: 47 | return [] 48 | 49 | 50 | def hydra_override_name(): 51 | if is_hydra_initialized(): 52 | return hydra_config().job.override_dirname 53 | else: 54 | return "" 55 | 56 | 57 | def hydra_original_dir(*subpaths): 58 | return os.path.join(hydra.utils.get_original_cwd(), *subpaths) 59 | 60 | 61 | @call_once(on_second_call="noop") 62 | def register_omegaconf_resolvers(): 63 | import numpy as np 64 | 65 | OmegaConf.register_new_resolver( 66 | "scientific", lambda v, i=0: to_scientific_str(v, i) 67 | ) 68 | OmegaConf.register_new_resolver("_optional", lambda v: f"_{v}" if v else "") 69 | OmegaConf.register_new_resolver("optional_", lambda v: f"{v}_" if v else "") 70 | OmegaConf.register_new_resolver("_optional_", lambda v: f"_{v}_" if v else "") 71 | OmegaConf.register_new_resolver("__optional", lambda v: f"__{v}" if v else "") 72 | OmegaConf.register_new_resolver("optional__", lambda v: f"{v}__" if v else "") 73 | OmegaConf.register_new_resolver("__optional__", lambda v: f"__{v}__" if v else "") 74 | OmegaConf.register_new_resolver( 75 | "iftrue", lambda cond, v_default: cond if cond else v_default 76 | ) 77 | OmegaConf.register_new_resolver( 78 | "ifelse", lambda cond, v1, v2="": v1 if cond else v2 79 | ) 80 | OmegaConf.register_new_resolver( 81 | "ifequal", lambda query, key, v1, v2: v1 if query == key else v2 82 | ) 83 | OmegaConf.register_new_resolver("intbool", lambda cond: 1 if cond else 0) 84 | OmegaConf.register_new_resolver("mult", lambda *x: np.prod(x).tolist()) 85 | OmegaConf.register_new_resolver("add", lambda *x: sum(x)) 86 | OmegaConf.register_new_resolver("div", lambda x, y: x / y) 87 | OmegaConf.register_new_resolver("intdiv", lambda x, y: x // y) 88 | 89 | # try each key until the key exists. 
Useful for multiple classes that have different 90 | # names for the same key 91 | def _try_key(cfg, *keys): 92 | for k in keys: 93 | if k in cfg: 94 | return cfg[k] 95 | raise KeyError(f"no key in {keys} is valid") 96 | 97 | OmegaConf.register_new_resolver("trykey", _try_key) 98 | # replace `resnet.gn.ws` -> `resnet_gn_ws`, because omegaconf doesn't support 99 | # keys with dots. Useful for generating run name with dots 100 | OmegaConf.register_new_resolver("underscore_to_dots", lambda s: s.replace("_", ".")) 101 | 102 | def _no_instantiate(cfg): 103 | cfg = deepcopy(cfg) 104 | cfg[_NO_INSTANTIATE] = True 105 | return cfg 106 | 107 | OmegaConf.register_new_resolver("no_instantiate", _no_instantiate) 108 | 109 | 110 | # ======================================================== 111 | # ================== Instantiation tools ================ 112 | # ======================================================== 113 | 114 | 115 | def register_callable(name, class_type): 116 | if isinstance(class_type, str): 117 | class_type, name = name, class_type 118 | assert callable(class_type) 119 | _CLASS_REGISTRY[name] = class_type 120 | 121 | 122 | @meta_decorator 123 | def register_class(cls, alias=None): 124 | """ 125 | Decorator 126 | """ 127 | assert callable(cls) 128 | _CLASS_REGISTRY[cls.__name__] = cls 129 | if alias: 130 | assert is_sequence(alias) 131 | for a in alias: 132 | _CLASS_REGISTRY[str(a)] = cls 133 | return cls 134 | 135 | 136 | def omegaconf_to_dict(cfg, resolve: bool = True, enum_to_str: bool = False): 137 | """ 138 | Convert arbitrary nested omegaconf objects to primitive containers 139 | 140 | WARNING: cannot use tree lib because it gets confused on DictConfig and ListConfig 141 | """ 142 | kw = dict(resolve=resolve, enum_to_str=enum_to_str) 143 | if OmegaConf.is_config(cfg): 144 | return OmegaConf.to_container(cfg, **kw) 145 | elif is_sequence(cfg): 146 | return type(cfg)(omegaconf_to_dict(c, **kw) for c in cfg) 147 | elif is_mapping(cfg): 148 | return {k: omegaconf_to_dict(c, **kw) for k, c in cfg.items()} 149 | else: 150 | return cfg 151 | 152 | 153 | def omegaconf_save(cfg, *paths: str, resolve: bool = True): 154 | """ 155 | Save omegaconf to yaml 156 | """ 157 | from .file_utils import f_join 158 | 159 | OmegaConf.save(cfg, f_join(*paths), resolve=resolve) 160 | 161 | 162 | def get_class(path): 163 | """ 164 | First try to find the class in the registry first, 165 | if it doesn't exist, use importlib to locate it 166 | """ 167 | if path in _CLASS_REGISTRY: 168 | return _CLASS_REGISTRY[path] 169 | else: 170 | assert "." in path, ( 171 | f"Because {path} is not found in class registry, " 172 | f"it must be a full module path" 173 | ) 174 | try: 175 | from importlib import import_module 176 | 177 | module_path, _, class_name = path.rpartition(".") 178 | mod = import_module(module_path) 179 | try: 180 | class_type = getattr(mod, class_name) 181 | except AttributeError: 182 | raise ImportError( 183 | "Class {} is not in module {}".format(class_name, module_path) 184 | ) 185 | return class_type 186 | except ValueError as e: 187 | print("Error initializing class " + path, file=sys.stderr) 188 | raise e 189 | 190 | 191 | _DELETE_ARG = "__delete__" 192 | _NO_INSTANTIATE = "__no_instantiate__" # return config as-is 193 | _OMEGA_MISSING = "???" 
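# Example of the convention these sentinels support (sketch; `MLP` is a
# hypothetical class registered via `register_class`/`register_callable`):
#   cfg = {"cls": "MLP", "in_dim": 10, "out_dim": "???", "hidden": "__delete__"}
#   instantiate(cfg, out_dim=4)   # -> MLP(in_dim=10, out_dim=4)
# Any mapping carrying a "cls"/"class" key is instantiated recursively by
# `instantiate()` below; a "???" value must be supplied through **kwargs
# (otherwise a ValueError is raised), and "__delete__" values are dropped
# before the constructor is called.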
194 | 195 | 196 | def _get_instantiate_params(cfg, kwargs=None): 197 | params = cfg 198 | f_args, f_kwargs = (), {} 199 | for k, value in params.items(): 200 | if k in ["cls", "class"]: 201 | continue 202 | elif k == "*args": 203 | assert is_sequence(value), '"*args" value must be a sequence' 204 | f_args = list(value) 205 | continue 206 | if value == _OMEGA_MISSING: 207 | if kwargs and k in kwargs: 208 | value = kwargs[k] 209 | else: 210 | raise ValueError(f'Missing required keyword arg "{k}" in cfg: {cfg}') 211 | if value == _DELETE_ARG: 212 | continue 213 | else: 214 | f_kwargs[k] = value 215 | return f_args, f_kwargs 216 | 217 | 218 | def _instantiate_single(cfg): 219 | if is_mapping(cfg) and ("cls" in cfg or "class" in cfg): 220 | assert bool("cls" in cfg) != bool("class" in cfg), ( 221 | "to instantiate from config, " 222 | 'one and only one of "cls" or "class" key should be provided' 223 | ) 224 | if _NO_INSTANTIATE in cfg: 225 | no_instantiate = cfg.pop(_NO_INSTANTIATE) 226 | if no_instantiate: 227 | cfg = deepcopy(cfg) 228 | return cfg 229 | else: 230 | return _instantiate_single(cfg) 231 | 232 | cls = cfg.get("class", cfg.get("cls")) 233 | args, kwargs = _get_instantiate_params(cfg) 234 | try: 235 | class_type = get_class(cls) 236 | return class_type(*args, **kwargs) 237 | except Exception as e: 238 | raise RuntimeError(f"Error instantiating {cls}: {e}") 239 | else: 240 | return None 241 | 242 | 243 | def instantiate(_cfg_, **kwargs): 244 | """ 245 | Any dict with "cls" or "class" key is considered instantiable. 246 | 247 | Any key that has the special value "__delete__" 248 | will not be passed to the constructor 249 | 250 | **kwargs only apply to the top level object if it's a dict, otherwise raise error 251 | """ 252 | assert ( 253 | OmegaConf.is_config(_cfg_) 254 | or isinstance(_cfg_, (list, tuple)) 255 | or is_mapping(_cfg_) 256 | ), ( 257 | '"cfg" must be a dict, list, tuple, or OmegaConf config to be instantiated. ' 258 | f"Current its type is {type(_cfg_)}" 259 | ) 260 | 261 | _cfg_ = omegaconf_to_dict(_cfg_, resolve=True) 262 | 263 | if kwargs: 264 | if is_mapping(_cfg_): 265 | _cfg_ = _cfg_.copy() 266 | _cfg_.update(kwargs) 267 | _cfg_ = {k: v for k, v in _cfg_.items() if v != _DELETE_ARG} 268 | else: 269 | raise RuntimeError( 270 | f"**kwargs specified, but the top-level cfg is not a dict. 
" 271 | f"It has type {type(_cfg_)}" 272 | ) 273 | 274 | return tree.traverse(_instantiate_single, _cfg_, top_down=False) 275 | -------------------------------------------------------------------------------- /brs_algo/utils/convert_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | import os.path 6 | from PIL import Image 7 | 8 | 9 | def np_dtype_size(dtype: Union[str, np.dtype]) -> int: 10 | return np.dtype(dtype).itemsize 11 | 12 | 13 | _TORCH_DTYPE_TABLE = { 14 | torch.bool: 1, 15 | torch.int8: 1, 16 | torch.uint8: 1, 17 | torch.int16: 2, 18 | torch.short: 2, 19 | torch.int32: 4, 20 | torch.int: 4, 21 | torch.int64: 8, 22 | torch.long: 8, 23 | torch.float16: 2, 24 | torch.bfloat16: 2, 25 | torch.half: 2, 26 | torch.float32: 4, 27 | torch.float: 4, 28 | torch.float64: 8, 29 | torch.double: 8, 30 | } 31 | 32 | 33 | def torch_dtype(dtype: Union[str, torch.dtype, None]) -> torch.dtype: 34 | if dtype is None: 35 | return None 36 | elif isinstance(dtype, torch.dtype): 37 | return dtype 38 | elif isinstance(dtype, str): 39 | try: 40 | dtype = getattr(torch, dtype) 41 | except AttributeError: 42 | raise ValueError(f'"{dtype}" is not a valid torch dtype') 43 | assert isinstance( 44 | dtype, torch.dtype 45 | ), f"dtype {dtype} is not a valid torch tensor type" 46 | return dtype 47 | else: 48 | raise NotImplementedError(f"{dtype} not supported") 49 | 50 | 51 | def torch_device(device: Union[str, int, None]) -> torch.device: 52 | """ 53 | Args: 54 | device: 55 | - "auto": use current torch context device, same as `.to('cuda')` 56 | - int: negative for CPU, otherwise GPU index 57 | """ 58 | if device is None: 59 | return None 60 | elif device == "auto": 61 | return torch.device("cuda") 62 | elif isinstance(device, int) and device < 0: 63 | return torch.device("cpu") 64 | else: 65 | return torch.device(device) 66 | 67 | 68 | def torch_dtype_size(dtype: Union[str, torch.dtype]) -> int: 69 | return _TORCH_DTYPE_TABLE[torch_dtype(dtype)] 70 | 71 | 72 | def _convert_then_transfer(x, dtype, device, copy, non_blocking): 73 | x = x.to(dtype=dtype, copy=copy, non_blocking=non_blocking) 74 | return x.to(device=device, copy=False, non_blocking=non_blocking) 75 | 76 | 77 | def _transfer_then_convert(x, dtype, device, copy, non_blocking): 78 | x = x.to(device=device, copy=copy, non_blocking=non_blocking) 79 | return x.to(dtype=dtype, copy=False, non_blocking=non_blocking) 80 | 81 | 82 | def any_to_torch_tensor( 83 | x, 84 | dtype: Union[str, torch.dtype, None] = None, 85 | device: Union[str, int, torch.device, None] = None, 86 | copy=False, 87 | non_blocking=False, 88 | smart_optimize: bool = True, 89 | ): 90 | dtype = torch_dtype(dtype) 91 | device = torch_device(device) 92 | 93 | if not isinstance(x, (torch.Tensor, np.ndarray)): 94 | # x is a primitive python sequence 95 | x = torch.tensor(x, dtype=dtype) 96 | copy = False 97 | 98 | # This step does not create any copy. 99 | # If x is a numpy array, simply wraps it in Tensor. If it's already a Tensor, do nothing. 
100 | x = torch.as_tensor(x) 101 | # avoid passing None to .to(), PyTorch 1.4 bug 102 | dtype = dtype or x.dtype 103 | device = device or x.device 104 | 105 | if not smart_optimize: 106 | # do a single stage type conversion and transfer 107 | return x.to(dtype=dtype, device=device, copy=copy, non_blocking=non_blocking) 108 | 109 | # we have two choices: (1) convert dtype and then transfer to GPU 110 | # (2) transfer to GPU and then convert dtype 111 | # because CPU-to-GPU memory transfer is the bottleneck, we will reduce it as 112 | # much as possible by sending the smaller dtype 113 | 114 | src_dtype_size = torch_dtype_size(x.dtype) 115 | 116 | # destination dtype size 117 | if dtype is None: 118 | dest_dtype_size = src_dtype_size 119 | else: 120 | dest_dtype_size = torch_dtype_size(dtype) 121 | 122 | if x.dtype != dtype or x.device != device: 123 | # a copy will always be performed, no need to force copy again 124 | copy = False 125 | 126 | if src_dtype_size > dest_dtype_size: 127 | # better to do conversion on one device (e.g. CPU) and then transfer to another 128 | return _convert_then_transfer(x, dtype, device, copy, non_blocking) 129 | elif src_dtype_size == dest_dtype_size: 130 | # when equal, we prefer to do the conversion on whichever device that's GPU 131 | if x.device.type == "cuda": 132 | return _convert_then_transfer(x, dtype, device, copy, non_blocking) 133 | else: 134 | return _transfer_then_convert(x, dtype, device, copy, non_blocking) 135 | else: 136 | # better to transfer data across device first, and then do conversion 137 | return _transfer_then_convert(x, dtype, device, copy, non_blocking) 138 | 139 | 140 | def any_to_numpy( 141 | x, 142 | dtype: Union[str, np.dtype, None] = None, 143 | copy: bool = False, 144 | non_blocking: bool = False, 145 | smart_optimize: bool = True, 146 | ): 147 | if isinstance(x, torch.Tensor): 148 | x = any_to_torch_tensor( 149 | x, 150 | dtype=dtype, 151 | device="cpu", 152 | copy=copy, 153 | non_blocking=non_blocking, 154 | smart_optimize=smart_optimize, 155 | ) 156 | return x.detach().numpy() 157 | else: 158 | # primitive python sequence or ndarray 159 | return np.array(x, dtype=dtype, copy=copy) 160 | 161 | 162 | def img_to_tensor(file_path: str, dtype=None, device=None, add_batch_dim: bool = False): 163 | """ 164 | Args: 165 | scale_255: if True, scale to [0, 255] 166 | add_batch_dim: if 3D, add a leading batch dim 167 | 168 | Returns: 169 | tensor between [0, 255] 170 | 171 | """ 172 | # image path 173 | pic = Image.open(os.path.expanduser(file_path)).convert("RGB") 174 | # code referenced from torchvision.transforms.functional.to_tensor 175 | # handle PIL Image 176 | assert pic.mode == "RGB" 177 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 178 | img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) 179 | # put it from HWC to CHW format 180 | img = img.permute((2, 0, 1)).contiguous() 181 | 182 | img = any_to_torch_tensor(img, dtype=dtype, device=device) 183 | if add_batch_dim: 184 | img.unsqueeze_(dim=0) 185 | return img 186 | 187 | 188 | def any_to_float(x, strict: bool = False): 189 | """ 190 | Convert a singleton torch tensor or ndarray to float 191 | 192 | Args: 193 | strict: True to check if the input is a singleton and raise Exception if not. 
194 | False to return the original value if not a singleton 195 | """ 196 | 197 | if torch.is_tensor(x) and x.numel() == 1: 198 | return float(x) 199 | elif isinstance(x, np.ndarray) and x.size == 1: 200 | return float(x) 201 | else: 202 | if strict: 203 | raise ValueError(f"{x} cannot be converted to a single float.") 204 | else: 205 | return x 206 | -------------------------------------------------------------------------------- /brs_algo/utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import codecs 4 | import fnmatch 5 | from collections import Counter 6 | from typing import Optional, Dict, Any, List, Union, Callable 7 | from typing_extensions import Literal 8 | 9 | 10 | def set_os_envs(envs: Optional[Dict[str, Any]] = None): 11 | """ 12 | Special value __delete__ or None indicates that the ENV_VAR should be removed 13 | """ 14 | if envs is None: 15 | envs = {} 16 | DEL = {None, "__delete__"} 17 | for k, v in envs.items(): 18 | if v in DEL: 19 | os.environ.pop(k, None) 20 | os.environ.update({k: str(v) for k, v in envs.items() if v not in DEL}) 21 | 22 | 23 | def argmax(L): 24 | return max(zip(L, range(len(L))))[1] 25 | 26 | 27 | def _match_patterns_helper(element, patterns): 28 | for p in patterns: 29 | if callable(p) and p(element): 30 | return True 31 | if fnmatch.fnmatch(element, p): 32 | return True 33 | return False 34 | 35 | 36 | def match_patterns( 37 | item: str, 38 | include: Union[str, List[str], Callable, List[Callable], None] = None, 39 | exclude: Union[str, List[str], Callable, List[Callable], None] = None, 40 | *, 41 | precedence: Literal["include", "exclude"] = "exclude", 42 | ): 43 | """ 44 | Args: 45 | include: None to disable `include` filter and delegate to exclude 46 | precedence: "include" or "exclude" 47 | """ 48 | assert precedence in ["include", "exclude"] 49 | if exclude is None: 50 | exclude = [] 51 | if isinstance(exclude, (str, Callable)): 52 | exclude = [exclude] 53 | if isinstance(include, (str, Callable)): 54 | include = [include] 55 | if include is None: 56 | # exclude is the sole veto vote 57 | return not _match_patterns_helper(item, exclude) 58 | 59 | if precedence == "include": 60 | return _match_patterns_helper(item, include) 61 | else: 62 | if _match_patterns_helper(item, exclude): 63 | return False 64 | else: 65 | return _match_patterns_helper(item, include) 66 | 67 | 68 | def filter_patterns( 69 | items: List[str], 70 | include: Union[str, List[str], Callable, List[Callable], None] = None, 71 | exclude: Union[str, List[str], Callable, List[Callable], None] = None, 72 | *, 73 | precedence: Literal["include", "exclude"] = "exclude", 74 | ordering: Literal["original", "include"] = "original", 75 | ): 76 | """ 77 | Args: 78 | ordering: affects the order of items in the returned list. Does not affect the 79 | content of the returned list. 
80 | - "original": keep the ordering of items in the input list 81 | - "include": order items by the order of include patterns 82 | """ 83 | assert ordering in ["original", "include"] 84 | if include is None or isinstance(include, str) or ordering == "original": 85 | return [ 86 | x 87 | for x in items 88 | if match_patterns( 89 | x, include=include, exclude=exclude, precedence=precedence 90 | ) 91 | ] 92 | else: 93 | items = items.copy() 94 | ret = [] 95 | for inc in include: 96 | for i, item in enumerate(items): 97 | if item is None: 98 | continue 99 | if match_patterns( 100 | item, include=inc, exclude=exclude, precedence=precedence 101 | ): 102 | ret.append(item) 103 | items[i] = None 104 | return ret 105 | 106 | 107 | def getitem_nested(cfg, key: str): 108 | """ 109 | Recursively get key, if key has '.' in it 110 | """ 111 | keys = key.split(".") 112 | for k in keys: 113 | assert k in cfg, f'{k} in key "{key}" does not exist in config' 114 | cfg = cfg[k] 115 | return cfg 116 | 117 | 118 | def setitem_nested(cfg, key: str, value): 119 | """ 120 | Recursively get key, if key has '.' in it 121 | """ 122 | keys = key.split(".") 123 | for k in keys[:-1]: 124 | assert k in cfg, f'{k} in key "{key}" does not exist in config' 125 | cfg = cfg[k] 126 | cfg[keys[-1]] = value 127 | 128 | 129 | def getattr_nested(obj, key: str): 130 | """ 131 | Recursively get attribute 132 | """ 133 | keys = key.split(".") 134 | for k in keys: 135 | assert hasattr(obj, k), f'{k} in attribute "{key}" does not exist' 136 | obj = getattr(obj, k) 137 | return obj 138 | 139 | 140 | def setattr_nested(obj, key: str, value): 141 | """ 142 | Recursively set attribute 143 | """ 144 | keys = key.split(".") 145 | for k in keys[:-1]: 146 | assert hasattr(obj, k), f'{k} in attribute "{key}" does not exist' 147 | obj = getattr(obj, k) 148 | setattr(obj, keys[-1], value) 149 | 150 | 151 | class PeriodicEvent: 152 | """ 153 | triggers every period 154 | """ 155 | 156 | def __init__(self, period: int, initial_value=0): 157 | self._period = period 158 | assert self._period >= 1 159 | self._last_threshold = initial_value 160 | self._last_value = initial_value 161 | self._trigger_counts = 0 162 | 163 | def __call__(self, new_value=None, increment=None): 164 | assert bool(new_value is None) != bool(increment is None), ( 165 | "you must specify one and only one of new_value or increment, " 166 | "but not both" 167 | ) 168 | d = self._period 169 | if new_value is None: 170 | new_value = self._last_value + increment 171 | assert new_value >= self._last_value, ( 172 | f"value must be monotonically increasing. 
" 173 | f"Current value {new_value} < last value {self._last_value}" 174 | ) 175 | self._last_value = new_value 176 | if new_value - self._last_threshold >= d: 177 | self._last_threshold += (new_value - self._last_threshold) // d * d 178 | self._trigger_counts += 1 179 | return True 180 | else: 181 | return False 182 | 183 | @property 184 | def trigger_counts(self): 185 | return self._trigger_counts 186 | 187 | @property 188 | def current_value(self): 189 | return self._last_value 190 | 191 | 192 | class Once: 193 | def __init__(self): 194 | self._triggered = False 195 | 196 | def __call__(self): 197 | if not self._triggered: 198 | self._triggered = True 199 | return True 200 | else: 201 | return False 202 | 203 | def __bool__(self): 204 | raise RuntimeError("`Once` objects should be used by calling ()") 205 | 206 | 207 | _GLOBAL_ONCE_SET = set() 208 | _GLOBAL_NTIMES_COUNTER = Counter() 209 | 210 | 211 | def global_once(name): 212 | """ 213 | Try this to automate the name: 214 | https://gist.github.com/techtonik/2151727#gistcomment-2333747 215 | """ 216 | if name in _GLOBAL_ONCE_SET: 217 | return False 218 | else: 219 | _GLOBAL_ONCE_SET.add(name) 220 | return True 221 | 222 | 223 | def global_n_times(name, n: int): 224 | """ 225 | Triggers N times 226 | """ 227 | assert n >= 1 228 | if _GLOBAL_NTIMES_COUNTER[name] < n: 229 | _GLOBAL_NTIMES_COUNTER[name] += 1 230 | return True 231 | else: 232 | return False 233 | 234 | 235 | class Every: 236 | def __init__(self, n: int, on_first: bool = False): 237 | assert n > 0 238 | self._i = 0 if on_first else 1 239 | self._n = n 240 | 241 | def __call__(self): 242 | return self._i % self._n == 0 243 | 244 | def __bool__(self): 245 | raise RuntimeError("`Every` objects should be used by calling ()") 246 | 247 | 248 | def encode_base64(obj) -> str: 249 | return codecs.encode(pickle.dumps(obj), "base64").decode() 250 | 251 | 252 | def decode_base64(s: str): 253 | return pickle.loads(codecs.decode(s.encode(), "base64")) 254 | -------------------------------------------------------------------------------- /brs_algo/utils/print_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | import os 5 | import sys 6 | import logging 7 | import shlex 8 | import string 9 | import pprint 10 | import textwrap 11 | import time 12 | import traceback 13 | from datetime import datetime 14 | from typing import Callable, Union 15 | 16 | import numpy as np 17 | from typing_extensions import Literal 18 | 19 | from .functional_utils import meta_decorator 20 | from .misc_utils import match_patterns 21 | 22 | 23 | def to_readable_count_str(value: int, precision: int = 2) -> str: 24 | assert value >= 0 25 | labels = [" ", "K", "M", "B", "T"] 26 | num_digits = int(np.floor(np.log10(value)) + 1 if value > 0 else 1) 27 | num_groups = int(np.ceil(num_digits / 3)) 28 | num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions 29 | shift = -3 * (num_groups - 1) 30 | value = value * (10**shift) 31 | index = num_groups - 1 32 | rem = value - int(value) 33 | if precision > 0 and rem > 0.01: 34 | fmt = f"{{:.{precision}f}}" 35 | rem_str = fmt.format(rem).lstrip("0") 36 | else: 37 | rem_str = "" 38 | return f"{int(value):,d}{rem_str} {labels[index]}" 39 | 40 | 41 | def to_scientific_str(value, precision: int = 1, capitalize: bool = False) -> str: 42 | """ 43 | 0.0015 -> "1.5e-3" 44 | """ 45 | if value == 0: 46 | return "0" 47 | return f"{value:.{precision}e}".replace("e-0", 
"E-" if capitalize else "e-") 48 | 49 | 50 | def print_str(*args, **kwargs): 51 | """ 52 | Same as print() signature but returns a string 53 | """ 54 | sstream = io.StringIO() 55 | kwargs.pop("file", None) 56 | print(*args, **kwargs, file=sstream) 57 | return sstream.getvalue() 58 | 59 | 60 | def fstring(fmt_str, **kwargs): 61 | """ 62 | Simulate python f-string but without `f` 63 | """ 64 | locals().update(kwargs) 65 | return eval("f" + shlex.quote(fmt_str)) 66 | 67 | 68 | def get_format_keys(fmt_str): 69 | keys = [] 70 | for literal, field_name, fmt_spec, conversion in string.Formatter().parse(fmt_str): 71 | if field_name: 72 | keys.append(field_name) 73 | return keys 74 | 75 | 76 | def get_timestamp(milli_precision: int = 3): 77 | fmt = "%y-%m-%d %H:%M:%S" 78 | if milli_precision > 0: 79 | fmt += ".%f" 80 | stamp = datetime.now().strftime(fmt) 81 | if milli_precision > 0: 82 | stamp = stamp[:-milli_precision] 83 | return stamp 84 | 85 | 86 | def pretty_repr_str(obj, **kwargs): 87 | """ 88 | Useful to produce __repr__() 89 | """ 90 | if isinstance(obj, str): 91 | cls_name = obj 92 | else: 93 | cls_name = obj.__class__.__name__ 94 | kw_strs = [ 95 | k + "=" + pprint.pformat(v, indent=2, compact=True) for k, v in kwargs.items() 96 | ] 97 | new_line = len(cls_name) + sum(len(kw) for kw in kw_strs) > 84 98 | if new_line: 99 | kw = ",\n".join(kw_strs) 100 | return f"{cls_name}(\n{textwrap.indent(kw, ' ')}\n)" 101 | else: 102 | kw = ", ".join(kw_strs) 103 | return f"{cls_name}({kw})" 104 | 105 | 106 | def pprint_(*objs, **kwargs): 107 | """ 108 | Use pprint to format the objects 109 | """ 110 | print( 111 | *[ 112 | pprint.pformat(obj, indent=2) if not isinstance(obj, str) else obj 113 | for obj in objs 114 | ], 115 | **kwargs, 116 | ) 117 | 118 | 119 | def get_exception_info(to_str: bool = False): 120 | """ 121 | Returns: 122 | {'type': ExceptionType, 'value': ExceptionObject, 'trace': } 123 | """ 124 | typ_, value, trace = sys.exc_info() 125 | return { 126 | "type": typ_.__name__ if to_str else typ_, 127 | "value": str(value) if to_str else value, 128 | "trace": "".join(traceback.format_exception(typ_, value, trace)), 129 | } 130 | 131 | 132 | class DebugPrinter: 133 | """ 134 | Debug print, usage: dprint = DebugPrint(enabled=True) 135 | dprint(...) 
136 | """ 137 | 138 | def __init__( 139 | self, enabled, tensor_summary: Literal["shape", "shape+dtype", "none"] = "shape" 140 | ): 141 | """ 142 | Args: 143 | tensor_summary: 144 | - shape: only prints shape 145 | - shape+dtype: also prints dtype and device 146 | - none: print full tensor 147 | """ 148 | self.enabled = enabled 149 | assert tensor_summary in ["shape", "shape+dtype", "none"] 150 | self.tensor_summary = tensor_summary 151 | 152 | def __call__(self, *args, **kwargs): 153 | if not self.enabled: 154 | return 155 | args = [self._process_arg(a) for a in args] 156 | pprint_(*args, **kwargs) 157 | 158 | def _process_arg(self, arg): 159 | import torch 160 | import numpy as np 161 | 162 | if torch.is_tensor(arg): 163 | if self.tensor_summary == "shape": 164 | return str(list(arg.size())) 165 | elif self.tensor_summary == "shape+dtype": 166 | return f"{arg.dtype}{list(arg.size())}|{arg.device}" 167 | elif isinstance(arg, np.ndarray): 168 | if self.tensor_summary == "shape": 169 | return str(list(arg.shape)) 170 | elif self.tensor_summary == "shape+dtype": 171 | return f"{arg.dtype}{list(arg.shape)}" 172 | return arg 173 | 174 | 175 | @meta_decorator 176 | def watch(func, seconds: int = 5, max_times: int = 0, keep_returns: bool = False): 177 | """ 178 | Decorator: executes a function repeated with the args and 179 | emulate `watch -n` capability 180 | 181 | See `gpustat` repo: https://github.com/wookayin/gpustat/pull/41/files 182 | 183 | Args: 184 | max_times: watch for `max_times` and then exit. If 0, never exits 185 | keep_returns: if True, will keep the return value from the function 186 | and return as a list at the end 187 | """ 188 | from blessings import Terminal 189 | 190 | def _wrapped(*args, **kwargs): 191 | term = Terminal() 192 | N = 0 193 | returns = [] 194 | with term.fullscreen(): 195 | while True: 196 | try: 197 | with term.location(0, 0): 198 | ret = func(*args, **kwargs) 199 | print(term.clear_eos, end="") 200 | if keep_returns: 201 | returns.append(ret) 202 | N += 1 203 | if max_times > 0 and N >= max_times: 204 | break 205 | time.sleep(seconds) 206 | except KeyboardInterrupt: 207 | break 208 | return returns 209 | 210 | return _wrapped 211 | 212 | 213 | class PrintRedirection(object): 214 | """ 215 | Context manager: temporarily redirects stdout and stderr 216 | """ 217 | 218 | def __init__(self, stdout=None, stderr=None): 219 | """ 220 | Args: 221 | stdout: if None, defaults to sys.stdout, unchanged 222 | stderr: if None, defaults to sys.stderr, unchanged 223 | """ 224 | if stdout is None: 225 | stdout = sys.stdout 226 | if stderr is None: 227 | stderr = sys.stderr 228 | self._stdout, self._stderr = stdout, stderr 229 | 230 | def __enter__(self): 231 | self._old_out, self._old_err = sys.stdout, sys.stderr 232 | self._old_out.flush() 233 | self._old_err.flush() 234 | sys.stdout, sys.stderr = self._stdout, self._stderr 235 | return self 236 | 237 | def __exit__(self, exc_type, exc_value, traceback): 238 | self.flush() 239 | # restore the normal stdout and stderr 240 | sys.stdout, sys.stderr = self._old_out, self._old_err 241 | 242 | def flush(self): 243 | "Manually flush the replaced stdout/stderr buffers." 244 | self._stdout.flush() 245 | self._stderr.flush() 246 | 247 | 248 | class PrintToFile(PrintRedirection): 249 | """ 250 | Print to file and save/close the handle at the end. 251 | """ 252 | 253 | def __init__(self, out_file=None, err_file=None): 254 | """ 255 | Args: 256 | out_file: file path 257 | err_file: file path. 
If the same as out_file, print both stdout 258 | and stderr to one file in order. 259 | """ 260 | self.out_file, self.err_file = out_file, err_file 261 | if out_file: 262 | out_file = os.path.expanduser(out_file) 263 | self.out_file = open(out_file, "w") 264 | if err_file: 265 | err_file = os.path.expanduser(err_file) 266 | if err_file == out_file: # redirect both stdout/err to one file 267 | self.err_file = self.out_file 268 | else: 269 | self.err_file = open(err_file, "w") 270 | super().__init__(stdout=self.out_file, stderr=self.err_file) 271 | 272 | def __exit__(self, *args): 273 | super().__exit__(*args) 274 | if self.out_file: 275 | self.out_file.close() 276 | if self.err_file: 277 | self.err_file.close() 278 | 279 | 280 | def PrintSuppress(no_out=True, no_err=False): 281 | """ 282 | Args: 283 | no_out: stdout writes to sys.devnull 284 | no_err: stderr writes to sys.devnull 285 | """ 286 | out_file = os.devnull if no_out else None 287 | err_file = os.devnull if no_err else None 288 | return PrintToFile(out_file=out_file, err_file=err_file) 289 | 290 | 291 | class PrintString(PrintRedirection): 292 | """ 293 | Redirect stdout and stderr to strings. 294 | """ 295 | 296 | def __init__(self): 297 | self.out_stream = io.StringIO() 298 | self.err_stream = io.StringIO() 299 | super().__init__(stdout=self.out_stream, stderr=self.err_stream) 300 | 301 | def stdout(self): 302 | "Returns: stdout as one string." 303 | return self.out_stream.getvalue() 304 | 305 | def stderr(self): 306 | "Returns: stderr as one string." 307 | return self.err_stream.getvalue() 308 | 309 | def stdout_by_line(self): 310 | "Returns: a list of stdout line by line, ignoring trailing blanks" 311 | return self.stdout().rstrip().split("\n") 312 | 313 | def stderr_by_line(self): 314 | "Returns: a list of stderr line by line, ignoring trailing blanks" 315 | return self.stderr().rstrip().split("\n") 316 | 317 | 318 | # ==================== Logging filters ==================== 319 | class ExcludeLoggingFilter(logging.Filter): 320 | """ 321 | Usage: logging.getLogger('name').addFilter( 322 | ExcludeLoggingFilter(['info mess*age', 'Warning: *']) 323 | ) 324 | Supports wildcard. 
325 | https://relaxdiego.com/2014/07/logging-in-python.html 326 | """ 327 | 328 | def __init__(self, patterns): 329 | super().__init__() 330 | self._patterns = patterns 331 | 332 | def filter(self, record): 333 | if match_patterns(record.msg, include=self._patterns): 334 | return False 335 | else: 336 | return True 337 | 338 | 339 | class ReplaceStringLoggingFilter(logging.Filter): 340 | def __init__(self, patterns, replacer: Callable): 341 | super().__init__() 342 | self._patterns = patterns 343 | assert callable(replacer) 344 | self._replacer = replacer 345 | 346 | def filter(self, record): 347 | if match_patterns(record.msg, include=self._patterns): 348 | record.msg = self._replacer(record.msg) 349 | return True # keep the record after rewriting its message 350 | 351 | def logging_exclude_pattern( 352 | logger_name, 353 | patterns: Union[str, list[str], Callable, list[Callable], None], 354 | ): 355 | """ 356 | Args: 357 | patterns: see enlight.utils.misc_utils.match_patterns 358 | """ 359 | logging.getLogger(logger_name).addFilter(ExcludeLoggingFilter(patterns)) 360 | 361 | 362 | def logging_replace_string( 363 | logger_name, 364 | patterns: Union[str, list[str], Callable, list[Callable], None], 365 | replacer: Callable, 366 | ): 367 | """ 368 | Args: 369 | patterns: see enlight.utils.misc_utils.match_patterns 370 | """ 371 | logging.getLogger(logger_name).addFilter( 372 | ReplaceStringLoggingFilter(patterns, replacer) 373 | ) 374 | -------------------------------------------------------------------------------- /brs_algo/utils/random_seed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def set_seed(seed, torch_deterministic=False, rank=0): 8 | """set seed across modules""" 9 | if seed == -1 and torch_deterministic: 10 | seed = 42 + rank 11 | elif seed == -1: 12 | seed = np.random.randint(0, 10000) 13 | else: 14 | seed = seed + rank 15 | 16 | print("Setting seed: {}".format(seed)) 17 | 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | os.environ["PYTHONHASHSEED"] = str(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | 25 | if torch_deterministic: 26 | # refer to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility 27 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 28 | torch.backends.cudnn.benchmark = False 29 | torch.backends.cudnn.deterministic = True 30 | torch.use_deterministic_algorithms(True) 31 | else: 32 | torch.backends.cudnn.benchmark = True 33 | torch.backends.cudnn.deterministic = False 34 | 35 | return seed 36 | -------------------------------------------------------------------------------- /brs_algo/utils/shape_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shape inference methods 3 | """ 4 | 5 | import math 6 | import numpy as np 7 | import torch 8 | import warnings 9 | from functools import partial 10 | 11 | from typing import List, Tuple, Union 12 | 13 | try: 14 | import tree 15 | except ImportError: 16 | pass 17 | 18 | 19 | # fmt: off 20 | __all__ = [ 21 | "shape_convnd", 22 | "shape_conv1d", "shape_conv2d", "shape_conv3d", 23 | "shape_transpose_convnd", 24 | "shape_transpose_conv1d", "shape_transpose_conv2d", "shape_transpose_conv3d", 25 | "shape_poolnd", 26 | "shape_maxpool1d", "shape_maxpool2d", "shape_maxpool3d", 27 | "shape_avgpool1d", "shape_avgpool2d", "shape_avgpool3d", 28 | "shape_slice", 29 | "check_shape" 30 | ] 31 | # fmt: on 32 | 33 | 34 | def 
_get_shape(x): 35 | "single object" 36 | if isinstance(x, np.ndarray): 37 | return tuple(x.shape) 38 | else: 39 | return tuple(x.size()) 40 | 41 | 42 | def _expands(dim, *xs): 43 | "repeat vars like kernel and stride to match dim" 44 | 45 | def _expand(x): 46 | if isinstance(x, int): 47 | return (x,) * dim 48 | else: 49 | assert len(x) == dim 50 | return x 51 | 52 | return map(lambda x: _expand(x), xs) 53 | 54 | 55 | _HELPER_TENSOR = torch.zeros((1,)) 56 | 57 | 58 | def shape_slice(input_shape, slice): 59 | """ 60 | Credit to Adam Paszke for the trick. Shape inference without instantiating 61 | an actual tensor. 62 | The key is that `.expand()` does not actually allocate memory 63 | Still needs to allocate a one-element HELPER_TENSOR. 64 | """ 65 | shape = _HELPER_TENSOR.expand(*input_shape)[slice] 66 | if hasattr(shape, "size"): 67 | return tuple(shape.size()) 68 | return (1,) 69 | 70 | 71 | class ShapeSlice: 72 | """ 73 | shape_slice inference with easy []-operator 74 | """ 75 | 76 | def __init__(self, input_shape): 77 | self.input_shape = input_shape 78 | 79 | def __getitem__(self, slice): 80 | return shape_slice(self.input_shape, slice) 81 | 82 | 83 | def check_shape( 84 | value: Union[Tuple, List, torch.Tensor, np.ndarray], 85 | expected: Union[Tuple, List, torch.Tensor, np.ndarray], 86 | err_msg="", 87 | mode="raise", 88 | ): 89 | """ 90 | Args: 91 | value: np array or torch Tensor 92 | expected: 93 | - list[int], tuple[int]: if any value is None, will match any dim 94 | - np array or torch Tensor: must have the same dimensions 95 | mode: 96 | - "raise": raise ValueError, shape mismatch 97 | - "return": returns True if shape matches, otherwise False 98 | - "warning": warnings.warn 99 | """ 100 | assert mode in ["raise", "return", "warning"] 101 | if torch.is_tensor(value): 102 | actual_shape = value.size() 103 | elif hasattr(value, "shape"): 104 | actual_shape = value.shape 105 | else: 106 | assert isinstance(value, (list, tuple)) 107 | actual_shape = value 108 | assert all( 109 | isinstance(s, int) for s in actual_shape 110 | ), f"actual shape: {actual_shape} is not a list of ints" 111 | 112 | if torch.is_tensor(expected): 113 | expected_shape = expected.size() 114 | elif hasattr(expected, "shape"): 115 | expected_shape = expected.shape 116 | else: 117 | assert isinstance(expected, (list, tuple)) 118 | expected_shape = expected 119 | 120 | err_msg = f" for {err_msg}" if err_msg else "" 121 | 122 | if len(actual_shape) != len(expected_shape): 123 | err_msg = ( 124 | f"Dimension mismatch{err_msg}: actual shape {actual_shape} " 125 | f"!= expected shape {expected_shape}." 126 | ) 127 | if mode == "raise": 128 | raise ValueError(err_msg) 129 | elif mode == "warning": 130 | warnings.warn(err_msg) 131 | return False 132 | 133 | for s_a, s_e in zip(actual_shape, expected_shape): 134 | if s_e is not None and s_a != s_e: 135 | err_msg = ( 136 | f"Shape mismatch{err_msg}: actual shape {actual_shape} " 137 | f"!= expected shape {expected_shape}." 
138 | ) 139 | if mode == "raise": 140 | raise ValueError(err_msg) 141 | elif mode == "warning": 142 | warnings.warn(err_msg) 143 | return False 144 | return True 145 | 146 | 147 | def shape_convnd( 148 | dim, 149 | input_shape, 150 | out_channels, 151 | kernel_size, 152 | stride=1, 153 | padding=0, 154 | dilation=1, 155 | has_batch=False, 156 | ): 157 | """ 158 | http://pytorch.org/docs/nn.html#conv1d 159 | http://pytorch.org/docs/nn.html#conv2d 160 | http://pytorch.org/docs/nn.html#conv3d 161 | 162 | Args: 163 | dim: supports 1D to 3D 164 | input_shape: 165 | - 1D: [channel, length] 166 | - 2D: [channel, height, width] 167 | - 3D: [channel, depth, height, width] 168 | has_batch: whether the first dim is batch size or not 169 | """ 170 | if has_batch: 171 | assert ( 172 | len(input_shape) == dim + 2 173 | ), "input shape with batch should be {}-dimensional".format(dim + 2) 174 | else: 175 | assert ( 176 | len(input_shape) == dim + 1 177 | ), "input shape without batch should be {}-dimensional".format(dim + 1) 178 | if stride is None: 179 | # for pooling convention in PyTorch 180 | stride = kernel_size 181 | kernel_size, stride, padding, dilation = _expands( 182 | dim, kernel_size, stride, padding, dilation 183 | ) 184 | if has_batch: 185 | batch = input_shape[0] 186 | input_shape = input_shape[1:] 187 | else: 188 | batch = None 189 | _, *img = input_shape 190 | new_img_shape = [ 191 | math.floor( 192 | (img[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) 193 | // stride[i] 194 | + 1 195 | ) 196 | for i in range(dim) 197 | ] 198 | return ((batch,) if has_batch else ()) + (out_channels, *new_img_shape) 199 | 200 | 201 | def shape_poolnd( 202 | dim, input_shape, kernel_size, stride=None, padding=0, dilation=1, has_batch=False 203 | ): 204 | """ 205 | The only difference from infer_shape_convnd is that `stride` default is None 206 | """ 207 | if has_batch: 208 | out_channels = input_shape[1] 209 | else: 210 | out_channels = input_shape[0] 211 | return shape_convnd( 212 | dim, 213 | input_shape, 214 | out_channels, 215 | kernel_size, 216 | stride, 217 | padding, 218 | dilation, 219 | has_batch, 220 | ) 221 | 222 | 223 | def shape_transpose_convnd( 224 | dim, 225 | input_shape, 226 | out_channels, 227 | kernel_size, 228 | stride=1, 229 | padding=0, 230 | output_padding=0, 231 | dilation=1, 232 | has_batch=False, 233 | ): 234 | """ 235 | http://pytorch.org/docs/nn.html#convtranspose1d 236 | http://pytorch.org/docs/nn.html#convtranspose2d 237 | http://pytorch.org/docs/nn.html#convtranspose3d 238 | 239 | Args: 240 | dim: supports 1D to 3D 241 | input_shape: 242 | - 1D: [channel, length] 243 | - 2D: [channel, height, width] 244 | - 3D: [channel, depth, height, width] 245 | has_batch: whether the first dim is batch size or not 246 | """ 247 | if has_batch: 248 | assert ( 249 | len(input_shape) == dim + 2 250 | ), "input shape with batch should be {}-dimensional".format(dim + 2) 251 | else: 252 | assert ( 253 | len(input_shape) == dim + 1 254 | ), "input shape without batch should be {}-dimensional".format(dim + 1) 255 | kernel_size, stride, padding, output_padding, dilation = _expands( 256 | dim, kernel_size, stride, padding, output_padding, dilation 257 | ) 258 | if has_batch: 259 | batch = input_shape[0] 260 | input_shape = input_shape[1:] 261 | else: 262 | batch = None 263 | _, *img = input_shape 264 | new_img_shape = [ 265 | (img[i] - 1) * stride[i] - 2 * padding[i] + kernel_size[i] + output_padding[i] 266 | for i in range(dim) 267 | ] 268 | return ((batch,) if has_batch else ()) + 
(out_channels, *new_img_shape) 269 | 270 | 271 | shape_conv1d = partial(shape_convnd, 1) 272 | shape_conv2d = partial(shape_convnd, 2) 273 | shape_conv3d = partial(shape_convnd, 3) 274 | 275 | 276 | shape_transpose_conv1d = partial(shape_transpose_convnd, 1) 277 | shape_transpose_conv2d = partial(shape_transpose_convnd, 2) 278 | shape_transpose_conv3d = partial(shape_transpose_convnd, 3) 279 | 280 | 281 | shape_maxpool1d = partial(shape_poolnd, 1) 282 | shape_maxpool2d = partial(shape_poolnd, 2) 283 | shape_maxpool3d = partial(shape_poolnd, 3) 284 | 285 | 286 | """ 287 | http://pytorch.org/docs/nn.html#avgpool1d 288 | http://pytorch.org/docs/nn.html#avgpool2d 289 | http://pytorch.org/docs/nn.html#avgpool3d 290 | """ 291 | shape_avgpool1d = partial(shape_maxpool1d, dilation=1) 292 | shape_avgpool2d = partial(shape_maxpool2d, dilation=1) 293 | shape_avgpool3d = partial(shape_maxpool3d, dilation=1) 294 | -------------------------------------------------------------------------------- /brs_algo/utils/termcolor.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (c) 2008-2011 Volvox Development Team 3 | # 4 | # Permission is hereby granted, free of charge, to any person obtaining a copy 5 | # of this software and associated documentation files (the "Software"), to deal 6 | # in the Software without restriction, including without limitation the rights 7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | # copies of the Software, and to permit persons to whom the Software is 9 | # furnished to do so, subject to the following conditions: 10 | # 11 | # The above copyright notice and this permission notice shall be included in 12 | # all copies or substantial portions of the Software. 13 | # 14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | # THE SOFTWARE. 21 | # 22 | # Original Author: Konstantin Lepa 23 | # Updated by Jim Fan 24 | 25 | """ANSII Color formatting for output in terminal.""" 26 | import io 27 | import os 28 | from typing import Union, Optional, List 29 | 30 | 31 | __ALL__ = ["color_text", "cprint"] 32 | 33 | STYLES = dict( 34 | list( 35 | zip( 36 | ["bold", "dark", "", "underline", "blink", "", "reverse", "concealed"], 37 | list(range(1, 9)), 38 | ) 39 | ) 40 | ) 41 | del STYLES[""] 42 | 43 | 44 | HIGHLIGHTS = dict( 45 | list( 46 | zip( 47 | ["grey", "red", "green", "yellow", "blue", "magenta", "cyan", "white"], 48 | list(range(40, 48)), 49 | ) 50 | ) 51 | ) 52 | 53 | 54 | COLORS = dict( 55 | list( 56 | zip( 57 | ["grey", "red", "green", "yellow", "blue", "magenta", "cyan", "white"], 58 | list(range(30, 38)), 59 | ) 60 | ) 61 | ) 62 | 63 | 64 | def _strip_bg_prefix(color): 65 | "on_red -> red" 66 | if color.startswith("on_"): 67 | return color[len("on_") :] 68 | else: 69 | return color 70 | 71 | 72 | RESET = "\033[0m" 73 | 74 | 75 | def color_text( 76 | text, 77 | color: Optional[str] = None, 78 | bg_color: Optional[str] = None, 79 | styles: Optional[Union[str, List[str]]] = None, 80 | ): 81 | """Colorize text. 
82 | 83 | Available text colors: 84 | red, green, yellow, blue, magenta, cyan, white. 85 | 86 | Available text highlights: 87 | on_red, on_green, on_yellow, on_blue, on_magenta, on_cyan, on_white. 88 | 89 | Available attributes: 90 | bold, dark, underline, blink, reverse, concealed. 91 | 92 | Example: 93 | colored('Hello, World!', 'red', 'on_grey', ['blue', 'blink']) 94 | colored('Hello, World!', 'green') 95 | """ 96 | if os.getenv("ANSI_COLORS_DISABLED") is None: 97 | fmt_str = "\033[%dm%s" 98 | if color is not None: 99 | text = fmt_str % (COLORS[color], text) 100 | 101 | if bg_color is not None: 102 | bg_color = _strip_bg_prefix(bg_color) 103 | text = fmt_str % (HIGHLIGHTS[bg_color], text) 104 | 105 | if styles is not None: 106 | if isinstance(styles, str): 107 | styles = [styles] 108 | for style in styles: 109 | text = fmt_str % (STYLES[style], text) 110 | 111 | text += RESET 112 | return text 113 | 114 | 115 | def cprint( 116 | *args, 117 | color: Optional[str] = None, 118 | bg_color: Optional[str] = None, 119 | styles: Optional[Union[str, List[str]]] = None, 120 | **kwargs, 121 | ): 122 | """Print colorize text. 123 | 124 | It accepts arguments of print function. 125 | """ 126 | sstream = io.StringIO() 127 | print(*args, sep=kwargs.pop("sep", None), end="", file=sstream) 128 | text = sstream.getvalue() 129 | print((color_text(text, color, bg_color, styles)), **kwargs) 130 | 131 | 132 | if __name__ == "__main__": 133 | print("Current terminal type: %s" % os.getenv("TERM")) 134 | print("Test basic colors:") 135 | cprint("Grey color", color="grey") 136 | cprint("Red color", color="red") 137 | cprint("Green color", color="green") 138 | cprint("Yellow color", color="yellow") 139 | cprint("Blue color", color="blue") 140 | cprint("Magenta color", color="magenta") 141 | cprint("Cyan color", color="cyan") 142 | cprint("White color", color="white") 143 | print(("-" * 78)) 144 | 145 | print("Test highlights:") 146 | cprint("On grey color", bg_color="on_grey") 147 | cprint("On red color", bg_color="on_red") 148 | cprint("On green color", bg_color="on_green") 149 | cprint("On yellow color", bg_color="on_yellow") 150 | cprint("On blue color", bg_color="on_blue") 151 | cprint("On magenta color", bg_color="on_magenta") 152 | cprint("On cyan color", bg_color="on_cyan") 153 | cprint("On white color", color="grey", bg_color="on_white") 154 | print("-" * 78) 155 | 156 | print("Test attributes:") 157 | cprint("Bold grey color", color="grey", styles="bold") 158 | cprint("Dark red color", color="red", styles=["dark"]) 159 | cprint("Underline green color", color="green", styles=["underline"]) 160 | cprint("Blink yellow color", color="yellow", styles=["blink"]) 161 | cprint("Reversed blue color", color="blue", styles=["reverse"]) 162 | cprint("Concealed Magenta color", color="magenta", styles=["concealed"]) 163 | cprint( 164 | "Bold underline reverse cyan color", 165 | color="cyan", 166 | styles=["bold", "underline", "reverse"], 167 | ) 168 | cprint( 169 | "Dark blink concealed white color", 170 | color="white", 171 | styles=["dark", "blink", "concealed"], 172 | ) 173 | print(("-" * 78)) 174 | 175 | print("Test mixing:") 176 | cprint( 177 | "Underline red on grey color", 178 | color="red", 179 | bg_color="on_grey", 180 | styles="underline", 181 | ) 182 | cprint( 183 | "Reversed green on red color", 184 | color="green", 185 | bg_color="on_red", 186 | styles="reverse", 187 | ) 188 | -------------------------------------------------------------------------------- /brs_algo/utils/tree_utils.py: 
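For quick reference, here is a minimal usage sketch of the color helpers defined in termcolor.py above; it assumes the package is installed so that the module path below resolves, and it is not part of the repository itself.

```python
from brs_algo.utils.termcolor import color_text, cprint

# Build a colored string without printing it.
banner = color_text("Rollout finished", color="green", styles="bold")
print(banner)

# Print directly with a background highlight and multiple styles.
cprint("Low battery", color="yellow", bg_color="on_red", styles=["bold", "blink"])
```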
-------------------------------------------------------------------------------- 1 | """ 2 | Utils to handle nested data structures 3 | 4 | Install dm_tree first: 5 | https://tree.readthedocs.io/en/latest/api.html 6 | """ 7 | 8 | import collections 9 | import numpy as np 10 | from typing import Iterable, List, TypeVar, Any, Tuple 11 | 12 | 13 | try: 14 | import tree 15 | 16 | except ImportError: 17 | raise ImportError("Please install dm_tree first: `pip install dm_tree`") 18 | 19 | 20 | def is_sequence(obj): 21 | """ 22 | Returns: 23 | True if the sequence is a collections.Sequence and not a string. 24 | """ 25 | return isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str) 26 | 27 | 28 | def is_mapping(obj): 29 | """ 30 | Returns: 31 | True if the sequence is a collections.Mapping 32 | """ 33 | return isinstance(obj, collections.abc.Mapping) 34 | 35 | 36 | def tree_value_at_path(obj, paths: Tuple): 37 | try: 38 | for p in paths: 39 | obj = obj[p] 40 | return obj 41 | except Exception as e: 42 | raise ValueError(f"{e}\n\n-- Incorrect nested path {paths} for object: {obj}.") 43 | 44 | 45 | def tree_assign_at_path(obj, paths: Tuple, value): 46 | try: 47 | for p in paths[:-1]: 48 | obj = obj[p] 49 | if len(paths) > 0: 50 | obj[paths[-1]] = value 51 | except Exception as e: 52 | raise ValueError(f"{e}\n\n-- Incorrect nested path {paths} for object: {obj}.") 53 | 54 | 55 | def copy_non_leaf(obj): 56 | """ 57 | Deepcopy the nested structure, but does NOT copy the leaf values like Tensors 58 | """ 59 | return tree.map_structure(lambda x: x, obj) 60 | 61 | 62 | # ======================================================================= 63 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved. 64 | # 65 | # Licensed under the Apache License, Version 2.0 (the "License"); 66 | # you may not use this file except in compliance with the License. 67 | # You may obtain a copy of the License at 68 | # 69 | # http://www.apache.org/licenses/LICENSE-2.0 70 | # 71 | # Unless required by applicable law or agreed to in writing, software 72 | # distributed under the License is distributed on an "AS IS" BASIS, 73 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 74 | # See the License for the specific language governing permissions and 75 | # limitations under the License. 76 | 77 | # Tensor framework-agnostic utilities for manipulating nested structures. 78 | 79 | ElementType = TypeVar("ElementType") 80 | 81 | 82 | def fast_map_structure(func, *structure): 83 | """Faster map_structure implementation which skips some error checking.""" 84 | flat_structure = (tree.flatten(s) for s in structure) 85 | entries = zip(*flat_structure) 86 | # Arbitrarily choose one of the structures of the original sequence (the last) 87 | # to match the structure for the flattened sequence. 88 | return tree.unflatten_as(structure[-1], [func(*x) for x in entries]) 89 | 90 | 91 | def stack_sequence_fields(sequence: Iterable[ElementType]) -> ElementType: 92 | """Stacks a list of identically nested objects. 93 | 94 | This takes a sequence of identically nested objects and returns a single 95 | nested object whose ith leaf is a stacked numpy array of the corresponding 96 | ith leaf from each element of the sequence. 
97 | 98 | For example, if `sequence` is: 99 | 100 | ```python 101 | [{ 102 | 'action': np.array([1.0]), 103 | 'observation': (np.array([0.0, 1.0, 2.0]),), 104 | 'reward': 1.0 105 | }, { 106 | 'action': np.array([0.5]), 107 | 'observation': (np.array([1.0, 2.0, 3.0]),), 108 | 'reward': 0.0 109 | }, { 110 | 'action': np.array([0.3]),1 111 | 'observation': (np.array([2.0, 3.0, 4.0]),), 112 | 'reward': 0.5 113 | }] 114 | ``` 115 | 116 | Then this function will return: 117 | 118 | ```python 119 | { 120 | 'action': np.array([....]) # array shape = [3 x 1] 121 | 'observation': (np.array([...]),) # array shape = [3 x 3] 122 | 'reward': np.array([...]) # array shape = [3] 123 | } 124 | ``` 125 | 126 | Note that the 'observation' entry in the above example has two levels of 127 | nesting, i.e it is a tuple of arrays. 128 | 129 | Args: 130 | sequence: a list of identically nested objects. 131 | 132 | Returns: 133 | A nested object with numpy. 134 | 135 | Raises: 136 | ValueError: If `sequence` is an empty sequence. 137 | """ 138 | # Handle empty input sequences. 139 | if not sequence: 140 | raise ValueError("Input sequence must not be empty") 141 | 142 | # Default to asarray when arrays don't have the same shape to be compatible 143 | # with old behaviour. 144 | try: 145 | return fast_map_structure(lambda *values: np.stack(values), *sequence) 146 | except ValueError: 147 | return fast_map_structure(lambda *values: np.asarray(values), *sequence) 148 | 149 | 150 | def unstack_sequence_fields(struct: ElementType, batch_size: int) -> List[ElementType]: 151 | """Converts a struct of batched arrays to a list of structs. 152 | 153 | This is effectively the inverse of `stack_sequence_fields`. 154 | 155 | Args: 156 | struct: An (arbitrarily nested) structure of arrays. 157 | batch_size: The length of the leading dimension of each array in the struct. 158 | This is assumed to be static and known. 159 | 160 | Returns: 161 | A list of structs with the same structure as `struct`, where each leaf node 162 | is an unbatched element of the original leaf node. 163 | """ 164 | 165 | return [tree.map_structure(lambda s, i=i: s[i], struct) for i in range(batch_size)] 166 | 167 | 168 | def broadcast_structures(*args: Any) -> Any: 169 | """Returns versions of the arguments that give them the same nested structure. 170 | 171 | Any nested items in *args must have the same structure. 172 | 173 | Any non-nested item will be replaced with a nested version that shares that 174 | structure. The leaves will all be references to the same original non-nested 175 | item. 176 | 177 | If all *args are nested, or all *args are non-nested, this function will 178 | return *args unchanged. 179 | 180 | Example: 181 | ``` 182 | a = ('a', 'b') 183 | b = 'c' 184 | tree_a, tree_b = broadcast_structure(a, b) 185 | tree_a 186 | > ('a', 'b') 187 | tree_b 188 | > ('c', 'c') 189 | ``` 190 | 191 | Args: 192 | *args: A Sequence of nested or non-nested items. 193 | 194 | Returns: 195 | `*args`, except with all items sharing the same nest structure. 196 | """ 197 | if not args: 198 | return 199 | 200 | reference_tree = None 201 | for arg in args: 202 | if tree.is_nested(arg): 203 | reference_tree = arg 204 | break 205 | 206 | if reference_tree is None: 207 | reference_tree = args[0] 208 | 209 | def mirror_structure(value, reference_tree): 210 | if tree.is_nested(value): 211 | # Use check_types=True so that the types of the trees we construct aren't 212 | # dependent on our arbitrary choice of which nested arg to use as the 213 | # reference_tree. 
214 | tree.assert_same_structure(value, reference_tree, check_types=True) 215 | return value 216 | else: 217 | return tree.map_structure(lambda _: value, reference_tree) 218 | 219 | return tuple(mirror_structure(arg, reference_tree) for arg in args) 220 | -------------------------------------------------------------------------------- /main/rollout/clean_house_after_a_wild_party/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | INITIAL_QPOS = { 5 | "torso": np.array([0.01170956, -0.06480368, -0.06577059, -0.03678382]), 6 | "left_arm": np.array( 7 | [1.47320244, 2.82283323, -2.39968242, 0.04618586, -0.04400188, -0.02203692] 8 | ), 9 | "right_arm": np.array( 10 | [-1.54000939, 2.75047247, -2.27132666, -0.0316771, 0.05366865, 0.02300375] 11 | ), 12 | } 13 | 14 | 15 | GRIPPER_CLOSE_STROKE = 1.0 16 | GRIPPER_HALF_WIDTH = 50 17 | NUM_PCD_POINTS = 4096 18 | PAD_PCD_IF_LESS = False 19 | PCD_X_RANGE = (0.0, 2.0) 20 | PCD_Y_RANGE = (-1.0, 1.0) 21 | PCD_Z_RANGE = (-0.5, 1.6) 22 | MOBILE_BASE_VEL_ACTION_MIN = (-0.3, -0.3, -0.3) 23 | MOBILE_BASE_VEL_ACTION_MAX = (0.3, 0.3, 0.3) 24 | HORIZON_STEPS = 2100 * 4 25 | CONTROL_FREQ = 100 26 | ACTION_REPEAT = 12 27 | -------------------------------------------------------------------------------- /main/rollout/clean_house_after_a_wild_party/rollout_async.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from common import ( 4 | INITIAL_QPOS, 5 | GRIPPER_CLOSE_STROKE, 6 | NUM_PCD_POINTS, 7 | PAD_PCD_IF_LESS, 8 | PCD_X_RANGE, 9 | PCD_Y_RANGE, 10 | PCD_Z_RANGE, 11 | MOBILE_BASE_VEL_ACTION_MIN, 12 | MOBILE_BASE_VEL_ACTION_MAX, 13 | GRIPPER_HALF_WIDTH, 14 | HORIZON_STEPS, 15 | ACTION_REPEAT, 16 | ) 17 | import numpy as np 18 | import torch 19 | from brs_ctrl.robot_interface import R1Interface 20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper 21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 22 | import brs_algo.utils as U 23 | from brs_algo.learning.policy import WBVIMAPolicy 24 | from brs_algo.rollout.asynced_rollout import R1AsyncedRollout 25 | 26 | DEVICE = torch.device("cuda:0") 27 | NUM_LATEST_OBS = 2 28 | HORIZON = 16 29 | T_action_prediction = 8 30 | 31 | 32 | def rollout(args): 33 | robot = R1Interface( 34 | left_gripper=GalaxeaR1G1Gripper( 35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE 36 | ), 37 | right_gripper=GalaxeaR1G1Gripper( 38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE 39 | ), 40 | enable_rgbd=False, 41 | enable_pointcloud=True, 42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX), 43 | control_freq=65, 44 | wait_for_first_odom_msg=False, 45 | ) 46 | 47 | policy = WBVIMAPolicy( 48 | prop_dim=21, 49 | prop_keys=[ 50 | "odom/base_velocity", 51 | "qpos/torso", 52 | "qpos/left_arm", 53 | "qpos/left_gripper", 54 | "qpos/right_arm", 55 | "qpos/right_gripper", 56 | ], 57 | prop_mlp_hidden_depth=2, 58 | prop_mlp_hidden_dim=256, 59 | pointnet_n_coordinates=3, 60 | pointnet_n_color=3, 61 | pointnet_hidden_depth=2, 62 | pointnet_hidden_dim=256, 63 | action_keys=[ 64 | "mobile_base", 65 | "torso", 66 | "left_arm", 67 | "left_gripper", 68 | "right_arm", 69 | "right_gripper", 70 | ], 71 | action_key_dims={ 72 | "mobile_base": 3, 73 | "torso": 4, 74 | "left_arm": 6, 75 | "left_gripper": 1, 76 | "right_arm": 6, 77 | "right_gripper": 1, 78 | }, 79 | num_latest_obs=NUM_LATEST_OBS, 80 | use_modality_type_tokens=False, 81 | 
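        # The xf_* arguments below presumably size the transformer trunk, while the
        # unet_* and noise_scheduler arguments configure the diffusion action head
        # (DDIM with 100 training timesteps, 16 denoising steps at inference).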
xf_n_embd=256, 82 | xf_n_layer=2, 83 | xf_n_head=8, 84 | xf_dropout_rate=0.1, 85 | xf_use_geglu=True, 86 | learnable_action_readout_token=False, 87 | action_dim=21, 88 | action_prediction_horizon=T_action_prediction, 89 | diffusion_step_embed_dim=128, 90 | unet_down_dims=[64, 128], 91 | unet_kernel_size=5, 92 | unet_n_groups=8, 93 | unet_cond_predict_scale=True, 94 | noise_scheduler=DDIMScheduler( 95 | num_train_timesteps=100, 96 | beta_start=0.0001, 97 | beta_end=0.02, 98 | beta_schedule="squaredcos_cap_v2", 99 | clip_sample=True, 100 | set_alpha_to_one=True, 101 | steps_offset=0, 102 | prediction_type="epsilon", 103 | ), 104 | noise_scheduler_step_kwargs=None, 105 | num_denoise_steps_per_inference=16, 106 | ) 107 | U.load_state_dict( 108 | policy, 109 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"], 110 | strip_prefix="policy.", 111 | strict=True, 112 | ) 113 | policy = policy.to(DEVICE) 114 | policy.eval() 115 | 116 | rollout = R1AsyncedRollout( 117 | robot_interface=robot, 118 | num_pcd_points=NUM_PCD_POINTS, 119 | pcd_x_range=PCD_X_RANGE, 120 | pcd_y_range=PCD_Y_RANGE, 121 | pcd_z_range=PCD_Z_RANGE, 122 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN, 123 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX, 124 | gripper_half_width=GRIPPER_HALF_WIDTH, 125 | num_latest_obs=NUM_LATEST_OBS, 126 | num_deployed_actions=T_action_prediction, 127 | device=DEVICE, 128 | action_execute_start_idx=args.action_execute_start_idx, 129 | policy=policy, 130 | horizon_steps=HORIZON_STEPS, 131 | pad_pcd_if_needed=PAD_PCD_IF_LESS, 132 | action_repeat=ACTION_REPEAT, 133 | ) 134 | 135 | input("Press [ENTER] to reset robot to initial qpos") 136 | # reset robot to initial qpos 137 | robot.control( 138 | arm_cmd={ 139 | "left": INITIAL_QPOS["left_arm"], 140 | "right": INITIAL_QPOS["right_arm"], 141 | }, 142 | gripper_cmd={ 143 | "left": 0.1, 144 | "right": 0.1, 145 | }, 146 | torso_cmd=INITIAL_QPOS["torso"], 147 | ) 148 | 149 | input("Press [ENTER] to start rollout") 150 | for i in range(3): 151 | print(3 - i) 152 | time.sleep(1) 153 | rollout.rollout() 154 | 155 | 156 | if __name__ == "__main__": 157 | args = argparse.ArgumentParser() 158 | args.add_argument("--ckpt_path", type=str, required=True) 159 | args.add_argument("--action_execute_start_idx", type=int, default=1) 160 | args = args.parse_args() 161 | rollout(args) 162 | -------------------------------------------------------------------------------- /main/rollout/clean_house_after_a_wild_party/rollout_sync.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from common import ( 4 | INITIAL_QPOS, 5 | GRIPPER_CLOSE_STROKE, 6 | NUM_PCD_POINTS, 7 | PCD_X_RANGE, 8 | PCD_Y_RANGE, 9 | PCD_Z_RANGE, 10 | MOBILE_BASE_VEL_ACTION_MIN, 11 | MOBILE_BASE_VEL_ACTION_MAX, 12 | GRIPPER_HALF_WIDTH, 13 | HORIZON_STEPS, 14 | CONTROL_FREQ, 15 | ) 16 | import numpy as np 17 | import torch 18 | from brs_ctrl.robot_interface import R1Interface 19 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper 20 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 21 | import brs_algo.utils as U 22 | from brs_algo.learning.policy import WBVIMAPolicy 23 | from brs_algo.rollout.synced_rollout import R1SyncedRollout 24 | 25 | DEVICE = torch.device("cuda:0") 26 | NUM_LATEST_OBS = 2 27 | HORIZON = 16 28 | T_action_prediction = 8 29 | 30 | 31 | def rollout(args): 32 | robot = R1Interface( 33 | left_gripper=GalaxeaR1G1Gripper( 34 | left_or_right="left", 
gripper_close_stroke=GRIPPER_CLOSE_STROKE 35 | ), 36 | right_gripper=GalaxeaR1G1Gripper( 37 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE 38 | ), 39 | enable_rgbd=False, 40 | enable_pointcloud=True, 41 | mobile_base_cmd_limit=np.array([0.3, 0.3, 0.3]), 42 | ) 43 | 44 | policy = WBVIMAPolicy( 45 | prop_dim=21, 46 | prop_keys=[ 47 | "odom/base_velocity", 48 | "qpos/torso", 49 | "qpos/left_arm", 50 | "qpos/left_gripper", 51 | "qpos/right_arm", 52 | "qpos/right_gripper", 53 | ], 54 | prop_mlp_hidden_depth=2, 55 | prop_mlp_hidden_dim=256, 56 | pointnet_n_coordinates=3, 57 | pointnet_n_color=3, 58 | pointnet_hidden_depth=2, 59 | pointnet_hidden_dim=256, 60 | action_keys=[ 61 | "mobile_base", 62 | "torso", 63 | "left_arm", 64 | "left_gripper", 65 | "right_arm", 66 | "right_gripper", 67 | ], 68 | action_key_dims={ 69 | "mobile_base": 3, 70 | "torso": 4, 71 | "left_arm": 6, 72 | "left_gripper": 1, 73 | "right_arm": 6, 74 | "right_gripper": 1, 75 | }, 76 | num_latest_obs=NUM_LATEST_OBS, 77 | use_modality_type_tokens=False, 78 | xf_n_embd=256, 79 | xf_n_layer=2, 80 | xf_n_head=8, 81 | xf_dropout_rate=0.1, 82 | xf_use_geglu=True, 83 | learnable_action_readout_token=False, 84 | action_dim=21, 85 | action_prediction_horizon=T_action_prediction, 86 | diffusion_step_embed_dim=128, 87 | unet_down_dims=[64, 128], 88 | unet_kernel_size=5, 89 | unet_n_groups=8, 90 | unet_cond_predict_scale=True, 91 | noise_scheduler=DDIMScheduler( 92 | num_train_timesteps=100, 93 | beta_start=0.0001, 94 | beta_end=0.02, 95 | beta_schedule="squaredcos_cap_v2", 96 | clip_sample=True, 97 | set_alpha_to_one=True, 98 | steps_offset=0, 99 | prediction_type="epsilon", 100 | ), 101 | noise_scheduler_step_kwargs=None, 102 | num_denoise_steps_per_inference=16, 103 | ) 104 | U.load_state_dict( 105 | policy, 106 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"], 107 | strip_prefix="policy.", 108 | strict=True, 109 | ) 110 | policy = policy.to(DEVICE) 111 | policy.eval() 112 | 113 | rollout = R1SyncedRollout( 114 | robot_interface=robot, 115 | num_pcd_points=NUM_PCD_POINTS, 116 | pcd_x_range=PCD_X_RANGE, 117 | pcd_y_range=PCD_Y_RANGE, 118 | pcd_z_range=PCD_Z_RANGE, 119 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN, 120 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX, 121 | gripper_half_width=GRIPPER_HALF_WIDTH, 122 | num_latest_obs=NUM_LATEST_OBS, 123 | num_deployed_actions=T_action_prediction, 124 | device=DEVICE, 125 | policy=policy, 126 | horizon_steps=HORIZON_STEPS, 127 | pause_mode=args.pause_mode, 128 | control_freq=CONTROL_FREQ, 129 | ) 130 | 131 | input("Press [ENTER] to reset robot to initial qpos") 132 | # reset robot to initial qpos 133 | robot.control( 134 | arm_cmd={ 135 | "left": INITIAL_QPOS["left_arm"], 136 | "right": INITIAL_QPOS["right_arm"], 137 | }, 138 | gripper_cmd={ 139 | "left": 0.1, 140 | "right": 0.1, 141 | }, 142 | torso_cmd=INITIAL_QPOS["torso"], 143 | ) 144 | 145 | input("Press [ENTER] to start rollout") 146 | for i in range(3): 147 | print(3 - i) 148 | time.sleep(1) 149 | rollout.rollout() 150 | 151 | 152 | if __name__ == "__main__": 153 | args = argparse.ArgumentParser() 154 | args.add_argument("--ckpt_path", type=str, required=True) 155 | args.add_argument("--pause_mode", action="store_true") 156 | args = args.parse_args() 157 | rollout(args) 158 | -------------------------------------------------------------------------------- /main/rollout/clean_the_toilet/common.py: -------------------------------------------------------------------------------- 1 | 
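# Per-task deployment constants: each task directory ships its own common.py with
# the initial joint configuration, point-cloud crop ranges, mobile-base velocity
# limits, control frequency, and rollout horizon shared by rollout_sync.py and
# rollout_async.py.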
import numpy as np 2 | 3 | 4 | INITIAL_QPOS = { 5 | "torso": np.array([0.04507549, -0.15606569, -0.37281176, 0.02517941]), 6 | "left_arm": np.array( 7 | [1.51612224, 2.83734251, -2.39497914, 0.06206091, -0.07555695, 0.00982895] 8 | ), 9 | "right_arm": np.array( 10 | [-1.503665, 2.61041302, -2.01116604, 0.03022945, 0.11115978, 0.0534272] 11 | ), 12 | } 13 | 14 | 15 | GRIPPER_CLOSE_STROKE = 1.0 16 | GRIPPER_HALF_WIDTH = 50 17 | NUM_PCD_POINTS = 4096 18 | PAD_PCD_IF_LESS = True 19 | PCD_X_RANGE = (0.0, 1.0) 20 | PCD_Y_RANGE = (-0.5, 0.5) 21 | PCD_Z_RANGE = (0.0, 1.0) 22 | MOBILE_BASE_VEL_ACTION_MIN = (-0.17, -0.15, -0.2) 23 | MOBILE_BASE_VEL_ACTION_MAX = (0.17, 0.15, 0.2) 24 | HORIZON_STEPS = 1200 * 4 25 | CONTROL_FREQ = 100 26 | ACTION_REPEAT = 12 27 | -------------------------------------------------------------------------------- /main/rollout/clean_the_toilet/rollout_async.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from common import ( 4 | INITIAL_QPOS, 5 | GRIPPER_CLOSE_STROKE, 6 | NUM_PCD_POINTS, 7 | PAD_PCD_IF_LESS, 8 | PCD_X_RANGE, 9 | PCD_Y_RANGE, 10 | PCD_Z_RANGE, 11 | MOBILE_BASE_VEL_ACTION_MIN, 12 | MOBILE_BASE_VEL_ACTION_MAX, 13 | GRIPPER_HALF_WIDTH, 14 | HORIZON_STEPS, 15 | ACTION_REPEAT, 16 | ) 17 | import numpy as np 18 | import torch 19 | from brs_ctrl.robot_interface import R1Interface 20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper 21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 22 | import brs_algo.utils as U 23 | from brs_algo.learning.policy import WBVIMAPolicy 24 | from brs_algo.rollout.asynced_rollout import R1AsyncedRollout 25 | 26 | DEVICE = torch.device("cuda:0") 27 | NUM_LATEST_OBS = 2 28 | HORIZON = 16 29 | T_action_prediction = 8 30 | 31 | 32 | def rollout(args): 33 | robot = R1Interface( 34 | left_gripper=GalaxeaR1G1Gripper( 35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE 36 | ), 37 | right_gripper=GalaxeaR1G1Gripper( 38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE 39 | ), 40 | enable_rgbd=False, 41 | enable_pointcloud=True, 42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX), 43 | control_freq=65, 44 | ) 45 | 46 | policy = WBVIMAPolicy( 47 | prop_dim=21, 48 | prop_keys=[ 49 | "odom/base_velocity", 50 | "qpos/torso", 51 | "qpos/left_arm", 52 | "qpos/left_gripper", 53 | "qpos/right_arm", 54 | "qpos/right_gripper", 55 | ], 56 | prop_mlp_hidden_depth=2, 57 | prop_mlp_hidden_dim=256, 58 | pointnet_n_coordinates=3, 59 | pointnet_n_color=3, 60 | pointnet_hidden_depth=2, 61 | pointnet_hidden_dim=256, 62 | action_keys=[ 63 | "mobile_base", 64 | "torso", 65 | "left_arm", 66 | "left_gripper", 67 | "right_arm", 68 | "right_gripper", 69 | ], 70 | action_key_dims={ 71 | "mobile_base": 3, 72 | "torso": 4, 73 | "left_arm": 6, 74 | "left_gripper": 1, 75 | "right_arm": 6, 76 | "right_gripper": 1, 77 | }, 78 | num_latest_obs=NUM_LATEST_OBS, 79 | use_modality_type_tokens=False, 80 | xf_n_embd=256, 81 | xf_n_layer=2, 82 | xf_n_head=8, 83 | xf_dropout_rate=0.1, 84 | xf_use_geglu=True, 85 | learnable_action_readout_token=False, 86 | action_dim=21, 87 | action_prediction_horizon=T_action_prediction, 88 | diffusion_step_embed_dim=128, 89 | unet_down_dims=[64, 128], 90 | unet_kernel_size=5, 91 | unet_n_groups=8, 92 | unet_cond_predict_scale=True, 93 | noise_scheduler=DDIMScheduler( 94 | num_train_timesteps=100, 95 | beta_start=0.0001, 96 | beta_end=0.02, 97 | beta_schedule="squaredcos_cap_v2", 98 | 
clip_sample=True, 99 | set_alpha_to_one=True, 100 | steps_offset=0, 101 | prediction_type="epsilon", 102 | ), 103 | noise_scheduler_step_kwargs=None, 104 | num_denoise_steps_per_inference=16, 105 | ) 106 | U.load_state_dict( 107 | policy, 108 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"], 109 | strip_prefix="policy.", 110 | strict=True, 111 | ) 112 | policy = policy.to(DEVICE) 113 | policy.eval() 114 | 115 | rollout = R1AsyncedRollout( 116 | robot_interface=robot, 117 | num_pcd_points=NUM_PCD_POINTS, 118 | pcd_x_range=PCD_X_RANGE, 119 | pcd_y_range=PCD_Y_RANGE, 120 | pcd_z_range=PCD_Z_RANGE, 121 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN, 122 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX, 123 | gripper_half_width=GRIPPER_HALF_WIDTH, 124 | num_latest_obs=NUM_LATEST_OBS, 125 | num_deployed_actions=T_action_prediction, 126 | device=DEVICE, 127 | action_execute_start_idx=args.action_execute_start_idx, 128 | policy=policy, 129 | horizon_steps=HORIZON_STEPS, 130 | pad_pcd_if_needed=PAD_PCD_IF_LESS, 131 | action_repeat=ACTION_REPEAT, 132 | ) 133 | 134 | input("Press [ENTER] to reset robot to initial qpos") 135 | # reset robot to initial qpos 136 | robot.control( 137 | arm_cmd={ 138 | "left": INITIAL_QPOS["left_arm"], 139 | "right": INITIAL_QPOS["right_arm"], 140 | }, 141 | gripper_cmd={ 142 | "left": 0.1, 143 | "right": 0.1, 144 | }, 145 | torso_cmd=INITIAL_QPOS["torso"], 146 | ) 147 | 148 | input("Press [ENTER] to start rollout") 149 | for i in range(3): 150 | print(3 - i) 151 | time.sleep(1) 152 | rollout.rollout() 153 | 154 | 155 | if __name__ == "__main__": 156 | args = argparse.ArgumentParser() 157 | args.add_argument("--ckpt_path", type=str, required=True) 158 | args.add_argument("--action_execute_start_idx", type=int, default=1) 159 | args = args.parse_args() 160 | rollout(args) 161 | -------------------------------------------------------------------------------- /main/rollout/clean_the_toilet/rollout_sync.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from common import ( 4 | INITIAL_QPOS, 5 | GRIPPER_CLOSE_STROKE, 6 | NUM_PCD_POINTS, 7 | PAD_PCD_IF_LESS, 8 | PCD_X_RANGE, 9 | PCD_Y_RANGE, 10 | PCD_Z_RANGE, 11 | MOBILE_BASE_VEL_ACTION_MIN, 12 | MOBILE_BASE_VEL_ACTION_MAX, 13 | GRIPPER_HALF_WIDTH, 14 | HORIZON_STEPS, 15 | CONTROL_FREQ, 16 | ) 17 | import numpy as np 18 | import torch 19 | from brs_ctrl.robot_interface import R1Interface 20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper 21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 22 | import brs_algo.utils as U 23 | from brs_algo.learning.policy import WBVIMAPolicy 24 | from brs_algo.rollout.synced_rollout import R1SyncedRollout 25 | 26 | DEVICE = torch.device("cuda:0") 27 | NUM_LATEST_OBS = 2 28 | HORIZON = 16 29 | T_action_prediction = 8 30 | 31 | 32 | def rollout(args): 33 | robot = R1Interface( 34 | left_gripper=GalaxeaR1G1Gripper( 35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE 36 | ), 37 | right_gripper=GalaxeaR1G1Gripper( 38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE 39 | ), 40 | enable_rgbd=False, 41 | enable_pointcloud=True, 42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX), 43 | ) 44 | 45 | policy = WBVIMAPolicy( 46 | prop_dim=21, 47 | prop_keys=[ 48 | "odom/base_velocity", 49 | "qpos/torso", 50 | "qpos/left_arm", 51 | "qpos/left_gripper", 52 | "qpos/right_arm", 53 | "qpos/right_gripper", 54 | ], 55 | 
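        # prop_dim=21 appears to decompose as base velocity (3) + torso (4) +
        # two arms (6 each) + two grippers (1 each), matching the prop_keys above;
        # action_key_dims below sums to the same 21-dimensional action space.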
prop_mlp_hidden_depth=2, 56 | prop_mlp_hidden_dim=256, 57 | pointnet_n_coordinates=3, 58 | pointnet_n_color=3, 59 | pointnet_hidden_depth=2, 60 | pointnet_hidden_dim=256, 61 | action_keys=[ 62 | "mobile_base", 63 | "torso", 64 | "left_arm", 65 | "left_gripper", 66 | "right_arm", 67 | "right_gripper", 68 | ], 69 | action_key_dims={ 70 | "mobile_base": 3, 71 | "torso": 4, 72 | "left_arm": 6, 73 | "left_gripper": 1, 74 | "right_arm": 6, 75 | "right_gripper": 1, 76 | }, 77 | num_latest_obs=NUM_LATEST_OBS, 78 | use_modality_type_tokens=False, 79 | xf_n_embd=256, 80 | xf_n_layer=2, 81 | xf_n_head=8, 82 | xf_dropout_rate=0.1, 83 | xf_use_geglu=True, 84 | learnable_action_readout_token=False, 85 | action_dim=21, 86 | action_prediction_horizon=T_action_prediction, 87 | diffusion_step_embed_dim=128, 88 | unet_down_dims=[64, 128], 89 | unet_kernel_size=5, 90 | unet_n_groups=8, 91 | unet_cond_predict_scale=True, 92 | noise_scheduler=DDIMScheduler( 93 | num_train_timesteps=100, 94 | beta_start=0.0001, 95 | beta_end=0.02, 96 | beta_schedule="squaredcos_cap_v2", 97 | clip_sample=True, 98 | set_alpha_to_one=True, 99 | steps_offset=0, 100 | prediction_type="epsilon", 101 | ), 102 | noise_scheduler_step_kwargs=None, 103 | num_denoise_steps_per_inference=16, 104 | ) 105 | U.load_state_dict( 106 | policy, 107 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"], 108 | strip_prefix="policy.", 109 | strict=True, 110 | ) 111 | policy = policy.to(DEVICE) 112 | policy.eval() 113 | 114 | rollout = R1SyncedRollout( 115 | robot_interface=robot, 116 | num_pcd_points=NUM_PCD_POINTS, 117 | pcd_x_range=PCD_X_RANGE, 118 | pcd_y_range=PCD_Y_RANGE, 119 | pcd_z_range=PCD_Z_RANGE, 120 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN, 121 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX, 122 | gripper_half_width=GRIPPER_HALF_WIDTH, 123 | num_latest_obs=NUM_LATEST_OBS, 124 | num_deployed_actions=T_action_prediction, 125 | device=DEVICE, 126 | policy=policy, 127 | horizon_steps=HORIZON_STEPS, 128 | pause_mode=args.pause_mode, 129 | control_freq=CONTROL_FREQ, 130 | pad_pcd_if_needed=PAD_PCD_IF_LESS, 131 | ) 132 | 133 | input("Press [ENTER] to reset robot to initial qpos") 134 | # reset robot to initial qpos 135 | robot.control( 136 | arm_cmd={ 137 | "left": INITIAL_QPOS["left_arm"], 138 | "right": INITIAL_QPOS["right_arm"], 139 | }, 140 | gripper_cmd={ 141 | "left": 0.1, 142 | "right": 0.1, 143 | }, 144 | torso_cmd=INITIAL_QPOS["torso"], 145 | ) 146 | 147 | input("Press [ENTER] to start rollout") 148 | for i in range(3): 149 | print(3 - i) 150 | time.sleep(1) 151 | rollout.rollout() 152 | 153 | 154 | if __name__ == "__main__": 155 | args = argparse.ArgumentParser() 156 | args.add_argument("--ckpt_path", type=str, required=True) 157 | args.add_argument("--pause_mode", action="store_true") 158 | args = args.parse_args() 159 | rollout(args) 160 | -------------------------------------------------------------------------------- /main/rollout/lay_clothes_out/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | INITIAL_QPOS = { 5 | "torso": np.array([-0.00018469, -0.02381224, -0.16378878, -0.00182143]), 6 | "left_arm": np.array( 7 | [1.54583152, 2.78541468, -2.3352779, 0.11783977, -0.13222753, -0.0264264] 8 | ), 9 | "right_arm": np.array( 10 | [-1.56691272, 2.74406209, -2.24469822, -0.13880591, 0.03090534, 0.08950716] 11 | ), 12 | } 13 | 14 | 15 | GRIPPER_CLOSE_STROKE = 0.5 16 | GRIPPER_HALF_WIDTH = 50 17 | NUM_PCD_POINTS = 4096 18 | 
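# Point clouds are cropped to the PCD_*_RANGE boxes below and reduced to
# NUM_PCD_POINTS; PAD_PCD_IF_LESS presumably pads the cloud when the crop leaves
# fewer points (it is forwarded as pad_pcd_if_needed to the rollout classes).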
PAD_PCD_IF_LESS = True 19 | PCD_X_RANGE = (0.0, 2.3) 20 | PCD_Y_RANGE = (-0.5, 0.5) 21 | PCD_Z_RANGE = (-0.3, 2.0) 22 | MOBILE_BASE_VEL_ACTION_MIN = (-0.25, -0.25, -0.3) 23 | MOBILE_BASE_VEL_ACTION_MAX = (0.25, 0.25, 0.3) 24 | HORIZON_STEPS = 1200 * 4 25 | CONTROL_FREQ = 100 26 | ACTION_REPEAT = 12 27 | -------------------------------------------------------------------------------- /main/rollout/lay_clothes_out/rollout_async.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from common import ( 4 | INITIAL_QPOS, 5 | GRIPPER_CLOSE_STROKE, 6 | NUM_PCD_POINTS, 7 | PAD_PCD_IF_LESS, 8 | PCD_X_RANGE, 9 | PCD_Y_RANGE, 10 | PCD_Z_RANGE, 11 | MOBILE_BASE_VEL_ACTION_MIN, 12 | MOBILE_BASE_VEL_ACTION_MAX, 13 | GRIPPER_HALF_WIDTH, 14 | HORIZON_STEPS, 15 | ACTION_REPEAT, 16 | ) 17 | import numpy as np 18 | import torch 19 | from brs_ctrl.robot_interface import R1Interface 20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper 21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 22 | import brs_algo.utils as U 23 | from brs_algo.learning.policy import WBVIMAPolicy 24 | from brs_algo.rollout.asynced_rollout import R1AsyncedRollout 25 | 26 | DEVICE = torch.device("cuda:0") 27 | NUM_LATEST_OBS = 2 28 | HORIZON = 16 29 | T_action_prediction = 8 30 | 31 | 32 | def rollout(args): 33 | robot = R1Interface( 34 | left_gripper=GalaxeaR1G1Gripper( 35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE 36 | ), 37 | right_gripper=GalaxeaR1G1Gripper( 38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE 39 | ), 40 | enable_rgbd=False, 41 | enable_pointcloud=True, 42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX), 43 | control_freq=65, 44 | ) 45 | 46 | policy = WBVIMAPolicy( 47 | prop_dim=21, 48 | prop_keys=[ 49 | "odom/base_velocity", 50 | "qpos/torso", 51 | "qpos/left_arm", 52 | "qpos/left_gripper", 53 | "qpos/right_arm", 54 | "qpos/right_gripper", 55 | ], 56 | prop_mlp_hidden_depth=2, 57 | prop_mlp_hidden_dim=256, 58 | pointnet_n_coordinates=3, 59 | pointnet_n_color=3, 60 | pointnet_hidden_depth=2, 61 | pointnet_hidden_dim=256, 62 | action_keys=[ 63 | "mobile_base", 64 | "torso", 65 | "left_arm", 66 | "left_gripper", 67 | "right_arm", 68 | "right_gripper", 69 | ], 70 | action_key_dims={ 71 | "mobile_base": 3, 72 | "torso": 4, 73 | "left_arm": 6, 74 | "left_gripper": 1, 75 | "right_arm": 6, 76 | "right_gripper": 1, 77 | }, 78 | num_latest_obs=NUM_LATEST_OBS, 79 | use_modality_type_tokens=False, 80 | xf_n_embd=256, 81 | xf_n_layer=2, 82 | xf_n_head=8, 83 | xf_dropout_rate=0.1, 84 | xf_use_geglu=True, 85 | learnable_action_readout_token=False, 86 | action_dim=21, 87 | action_prediction_horizon=T_action_prediction, 88 | diffusion_step_embed_dim=128, 89 | unet_down_dims=[64, 128], 90 | unet_kernel_size=5, 91 | unet_n_groups=8, 92 | unet_cond_predict_scale=True, 93 | noise_scheduler=DDIMScheduler( 94 | num_train_timesteps=100, 95 | beta_start=0.0001, 96 | beta_end=0.02, 97 | beta_schedule="squaredcos_cap_v2", 98 | clip_sample=True, 99 | set_alpha_to_one=True, 100 | steps_offset=0, 101 | prediction_type="epsilon", 102 | ), 103 | noise_scheduler_step_kwargs=None, 104 | num_denoise_steps_per_inference=16, 105 | ) 106 | U.load_state_dict( 107 | policy, 108 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"], 109 | strip_prefix="policy.", 110 | strict=True, 111 | ) 112 | policy = policy.to(DEVICE) 113 | policy.eval() 114 | 115 | rollout = R1AsyncedRollout( 116 | 
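        # Asynchronous receding-horizon deployment (semantics inferred from the
        # argument names): each inference yields T_action_prediction=8 actions,
        # num_deployed_actions of them are executed, each repeated ACTION_REPEAT
        # times, while the next prediction is prepared.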
robot_interface=robot, 117 | num_pcd_points=NUM_PCD_POINTS, 118 | pcd_x_range=PCD_X_RANGE, 119 | pcd_y_range=PCD_Y_RANGE, 120 | pcd_z_range=PCD_Z_RANGE, 121 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN, 122 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX, 123 | gripper_half_width=GRIPPER_HALF_WIDTH, 124 | num_latest_obs=NUM_LATEST_OBS, 125 | num_deployed_actions=T_action_prediction, 126 | device=DEVICE, 127 | action_execute_start_idx=args.action_execute_start_idx, 128 | policy=policy, 129 | horizon_steps=HORIZON_STEPS, 130 | pad_pcd_if_needed=PAD_PCD_IF_LESS, 131 | action_repeat=ACTION_REPEAT, 132 | ) 133 | 134 | input("Press [ENTER] to reset robot to initial qpos") 135 | # reset robot to initial qpos 136 | robot.control( 137 | arm_cmd={ 138 | "left": INITIAL_QPOS["left_arm"], 139 | "right": INITIAL_QPOS["right_arm"], 140 | }, 141 | gripper_cmd={ 142 | "left": 0.1, 143 | "right": 0.1, 144 | }, 145 | torso_cmd=INITIAL_QPOS["torso"], 146 | ) 147 | 148 | input("Press [ENTER] to start rollout") 149 | for i in range(3): 150 | print(3 - i) 151 | time.sleep(1) 152 | rollout.rollout() 153 | 154 | 155 | if __name__ == "__main__": 156 | args = argparse.ArgumentParser() 157 | args.add_argument("--ckpt_path", type=str, required=True) 158 | args.add_argument("--action_execute_start_idx", type=int, default=1) 159 | args = args.parse_args() 160 | rollout(args) 161 | -------------------------------------------------------------------------------- /main/rollout/lay_clothes_out/rollout_sync.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from common import ( 4 | INITIAL_QPOS, 5 | GRIPPER_CLOSE_STROKE, 6 | NUM_PCD_POINTS, 7 | PAD_PCD_IF_LESS, 8 | PCD_X_RANGE, 9 | PCD_Y_RANGE, 10 | PCD_Z_RANGE, 11 | MOBILE_BASE_VEL_ACTION_MIN, 12 | MOBILE_BASE_VEL_ACTION_MAX, 13 | GRIPPER_HALF_WIDTH, 14 | HORIZON_STEPS, 15 | CONTROL_FREQ, 16 | ) 17 | import numpy as np 18 | import torch 19 | from brs_ctrl.robot_interface import R1Interface 20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper 21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 22 | import brs_algo.utils as U 23 | from brs_algo.learning.policy import WBVIMAPolicy 24 | from brs_algo.rollout.synced_rollout import R1SyncedRollout 25 | 26 | DEVICE = torch.device("cuda:0") 27 | NUM_LATEST_OBS = 2 28 | HORIZON = 16 29 | T_action_prediction = 8 30 | 31 | 32 | def rollout(args): 33 | robot = R1Interface( 34 | left_gripper=GalaxeaR1G1Gripper( 35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE 36 | ), 37 | right_gripper=GalaxeaR1G1Gripper( 38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE 39 | ), 40 | enable_rgbd=False, 41 | enable_pointcloud=True, 42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX), 43 | ) 44 | 45 | policy = WBVIMAPolicy( 46 | prop_dim=21, 47 | prop_keys=[ 48 | "odom/base_velocity", 49 | "qpos/torso", 50 | "qpos/left_arm", 51 | "qpos/left_gripper", 52 | "qpos/right_arm", 53 | "qpos/right_gripper", 54 | ], 55 | prop_mlp_hidden_depth=2, 56 | prop_mlp_hidden_dim=256, 57 | pointnet_n_coordinates=3, 58 | pointnet_n_color=3, 59 | pointnet_hidden_depth=2, 60 | pointnet_hidden_dim=256, 61 | action_keys=[ 62 | "mobile_base", 63 | "torso", 64 | "left_arm", 65 | "left_gripper", 66 | "right_arm", 67 | "right_gripper", 68 | ], 69 | action_key_dims={ 70 | "mobile_base": 3, 71 | "torso": 4, 72 | "left_arm": 6, 73 | "left_gripper": 1, 74 | "right_arm": 6, 75 | "right_gripper": 1, 76 | }, 77 | 
num_latest_obs=NUM_LATEST_OBS, 78 | use_modality_type_tokens=False, 79 | xf_n_embd=256, 80 | xf_n_layer=2, 81 | xf_n_head=8, 82 | xf_dropout_rate=0.1, 83 | xf_use_geglu=True, 84 | learnable_action_readout_token=False, 85 | action_dim=21, 86 | action_prediction_horizon=T_action_prediction, 87 | diffusion_step_embed_dim=128, 88 | unet_down_dims=[64, 128], 89 | unet_kernel_size=5, 90 | unet_n_groups=8, 91 | unet_cond_predict_scale=True, 92 | noise_scheduler=DDIMScheduler( 93 | num_train_timesteps=100, 94 | beta_start=0.0001, 95 | beta_end=0.02, 96 | beta_schedule="squaredcos_cap_v2", 97 | clip_sample=True, 98 | set_alpha_to_one=True, 99 | steps_offset=0, 100 | prediction_type="epsilon", 101 | ), 102 | noise_scheduler_step_kwargs=None, 103 | num_denoise_steps_per_inference=16, 104 | ) 105 | U.load_state_dict( 106 | policy, 107 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"], 108 | strip_prefix="policy.", 109 | strict=True, 110 | ) 111 | policy = policy.to(DEVICE) 112 | policy.eval() 113 | 114 | rollout = R1SyncedRollout( 115 | robot_interface=robot, 116 | num_pcd_points=NUM_PCD_POINTS, 117 | pcd_x_range=PCD_X_RANGE, 118 | pcd_y_range=PCD_Y_RANGE, 119 | pcd_z_range=PCD_Z_RANGE, 120 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN, 121 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX, 122 | gripper_half_width=GRIPPER_HALF_WIDTH, 123 | num_latest_obs=NUM_LATEST_OBS, 124 | num_deployed_actions=T_action_prediction, 125 | device=DEVICE, 126 | policy=policy, 127 | horizon_steps=HORIZON_STEPS, 128 | pause_mode=args.pause_mode, 129 | control_freq=CONTROL_FREQ, 130 | pad_pcd_if_needed=PAD_PCD_IF_LESS, 131 | ) 132 | 133 | input("Press [ENTER] to reset robot to initial qpos") 134 | # reset robot to initial qpos 135 | robot.control( 136 | arm_cmd={ 137 | "left": INITIAL_QPOS["left_arm"], 138 | "right": INITIAL_QPOS["right_arm"], 139 | }, 140 | gripper_cmd={ 141 | "left": 0.1, 142 | "right": 0.1, 143 | }, 144 | torso_cmd=INITIAL_QPOS["torso"], 145 | ) 146 | 147 | input("Press [ENTER] to start rollout") 148 | for i in range(3): 149 | print(3 - i) 150 | time.sleep(1) 151 | rollout.rollout() 152 | 153 | 154 | if __name__ == "__main__": 155 | args = argparse.ArgumentParser() 156 | args.add_argument("--ckpt_path", type=str, required=True) 157 | args.add_argument("--pause_mode", action="store_true") 158 | args = args.parse_args() 159 | rollout(args) 160 | -------------------------------------------------------------------------------- /main/rollout/put_items_onto_shelves/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | INITIAL_QPOS = { 5 | "torso": np.array([0.015661, -0.073367, -0.342951, -0.027991]), 6 | "left_arm": np.array( 7 | [1.29920638, 2.81973617, -2.41644043, -0.00392979, -0.12388298, -0.04762766] 8 | ), 9 | "right_arm": np.array( 10 | [-1.32845532, 2.83736596, -2.39677021, -0.00901277, 0.11162979, 0.02240851] 11 | ), 12 | } 13 | 14 | 15 | GRIPPER_CLOSE_STROKE = 0.5 16 | GRIPPER_HALF_WIDTH = 50 17 | NUM_PCD_POINTS = 4096 18 | PAD_PCD_IF_LESS = True 19 | PCD_X_RANGE = (0.0, 2.0) 20 | PCD_Y_RANGE = (-0.5, 0.5) 21 | PCD_Z_RANGE = (-0.3, 2.0) 22 | MOBILE_BASE_VEL_ACTION_MIN = (-0.1, -0.15, -0.3) 23 | MOBILE_BASE_VEL_ACTION_MAX = (0.1, 0.15, 0.3) 24 | HORIZON_STEPS = 600 * 4 25 | CONTROL_FREQ = 100 26 | ACTION_REPEAT = 12 27 | -------------------------------------------------------------------------------- /main/rollout/put_items_onto_shelves/rollout_async.py: 
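The five task rollout scripts repeat the same checkpoint-loading boilerplate verbatim. A small helper could factor it out; the sketch below is hypothetical (the helper name is not part of the repo), but every call inside mirrors the scripts themselves.

```python
import torch

import brs_algo.utils as U
from brs_algo.learning.policy import WBVIMAPolicy


def load_checkpoint(policy: WBVIMAPolicy, ckpt_path: str, device: torch.device) -> WBVIMAPolicy:
    # Checkpoints appear to be Lightning-style: weights live under "state_dict"
    # with a "policy." prefix that is stripped before loading.
    U.load_state_dict(
        policy,
        U.torch_load(ckpt_path, map_location="cpu")["state_dict"],
        strip_prefix="policy.",
        strict=True,
    )
    return policy.to(device).eval()
```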
-------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from common import ( 4 | INITIAL_QPOS, 5 | GRIPPER_CLOSE_STROKE, 6 | NUM_PCD_POINTS, 7 | PAD_PCD_IF_LESS, 8 | PCD_X_RANGE, 9 | PCD_Y_RANGE, 10 | PCD_Z_RANGE, 11 | MOBILE_BASE_VEL_ACTION_MIN, 12 | MOBILE_BASE_VEL_ACTION_MAX, 13 | GRIPPER_HALF_WIDTH, 14 | HORIZON_STEPS, 15 | ACTION_REPEAT, 16 | ) 17 | import numpy as np 18 | import torch 19 | from brs_ctrl.robot_interface import R1Interface 20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper 21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 22 | import brs_algo.utils as U 23 | from brs_algo.learning.policy import WBVIMAPolicy 24 | from brs_algo.rollout.asynced_rollout import R1AsyncedRollout 25 | 26 | DEVICE = torch.device("cuda:0") 27 | NUM_LATEST_OBS = 2 28 | HORIZON = 16 29 | T_action_prediction = 8 30 | 31 | 32 | def rollout(args): 33 | robot = R1Interface( 34 | left_gripper=GalaxeaR1G1Gripper( 35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE 36 | ), 37 | right_gripper=GalaxeaR1G1Gripper( 38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE 39 | ), 40 | enable_rgbd=False, 41 | enable_pointcloud=True, 42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX), 43 | control_freq=65, 44 | ) 45 | 46 | policy = WBVIMAPolicy( 47 | prop_dim=21, 48 | prop_keys=[ 49 | "odom/base_velocity", 50 | "qpos/torso", 51 | "qpos/left_arm", 52 | "qpos/left_gripper", 53 | "qpos/right_arm", 54 | "qpos/right_gripper", 55 | ], 56 | prop_mlp_hidden_depth=2, 57 | prop_mlp_hidden_dim=256, 58 | pointnet_n_coordinates=3, 59 | pointnet_n_color=3, 60 | pointnet_hidden_depth=2, 61 | pointnet_hidden_dim=256, 62 | action_keys=[ 63 | "mobile_base", 64 | "torso", 65 | "left_arm", 66 | "left_gripper", 67 | "right_arm", 68 | "right_gripper", 69 | ], 70 | action_key_dims={ 71 | "mobile_base": 3, 72 | "torso": 4, 73 | "left_arm": 6, 74 | "left_gripper": 1, 75 | "right_arm": 6, 76 | "right_gripper": 1, 77 | }, 78 | num_latest_obs=NUM_LATEST_OBS, 79 | use_modality_type_tokens=False, 80 | xf_n_embd=256, 81 | xf_n_layer=2, 82 | xf_n_head=8, 83 | xf_dropout_rate=0.1, 84 | xf_use_geglu=True, 85 | learnable_action_readout_token=False, 86 | action_dim=21, 87 | action_prediction_horizon=T_action_prediction, 88 | diffusion_step_embed_dim=128, 89 | unet_down_dims=[64, 128], 90 | unet_kernel_size=5, 91 | unet_n_groups=8, 92 | unet_cond_predict_scale=True, 93 | noise_scheduler=DDIMScheduler( 94 | num_train_timesteps=100, 95 | beta_start=0.0001, 96 | beta_end=0.02, 97 | beta_schedule="squaredcos_cap_v2", 98 | clip_sample=True, 99 | set_alpha_to_one=True, 100 | steps_offset=0, 101 | prediction_type="epsilon", 102 | ), 103 | noise_scheduler_step_kwargs=None, 104 | num_denoise_steps_per_inference=16, 105 | ) 106 | U.load_state_dict( 107 | policy, 108 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"], 109 | strip_prefix="policy.", 110 | strict=True, 111 | ) 112 | policy = policy.to(DEVICE) 113 | policy.eval() 114 | 115 | rollout = R1AsyncedRollout( 116 | robot_interface=robot, 117 | num_pcd_points=NUM_PCD_POINTS, 118 | pcd_x_range=PCD_X_RANGE, 119 | pcd_y_range=PCD_Y_RANGE, 120 | pcd_z_range=PCD_Z_RANGE, 121 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN, 122 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX, 123 | gripper_half_width=GRIPPER_HALF_WIDTH, 124 | num_latest_obs=NUM_LATEST_OBS, 125 | num_deployed_actions=T_action_prediction, 126 | device=DEVICE, 127 | 
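        # action_execute_start_idx (default 1) presumably skips the first of the
        # predicted actions so that execution starts on a fresher step and absorbs
        # inference latency; this reading is an assumption, not documented here.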
action_execute_start_idx=args.action_execute_start_idx, 128 | policy=policy, 129 | horizon_steps=HORIZON_STEPS, 130 | pad_pcd_if_needed=PAD_PCD_IF_LESS, 131 | action_repeat=ACTION_REPEAT, 132 | ) 133 | 134 | input("Press [ENTER] to reset robot to initial qpos") 135 | # reset robot to initial qpos 136 | robot.control( 137 | arm_cmd={ 138 | "left": INITIAL_QPOS["left_arm"], 139 | "right": INITIAL_QPOS["right_arm"], 140 | }, 141 | gripper_cmd={ 142 | "left": 0.1, 143 | "right": 0.1, 144 | }, 145 | torso_cmd=INITIAL_QPOS["torso"], 146 | ) 147 | 148 | input("Press [ENTER] to start rollout") 149 | for i in range(3): 150 | print(3 - i) 151 | time.sleep(1) 152 | rollout.rollout() 153 | 154 | 155 | if __name__ == "__main__": 156 | args = argparse.ArgumentParser() 157 | args.add_argument("--ckpt_path", type=str, required=True) 158 | args.add_argument("--action_execute_start_idx", type=int, default=1) 159 | args = args.parse_args() 160 | rollout(args) 161 | -------------------------------------------------------------------------------- /main/rollout/put_items_onto_shelves/rollout_sync.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from common import ( 4 | INITIAL_QPOS, 5 | GRIPPER_CLOSE_STROKE, 6 | NUM_PCD_POINTS, 7 | PAD_PCD_IF_LESS, 8 | PCD_X_RANGE, 9 | PCD_Y_RANGE, 10 | PCD_Z_RANGE, 11 | MOBILE_BASE_VEL_ACTION_MIN, 12 | MOBILE_BASE_VEL_ACTION_MAX, 13 | GRIPPER_HALF_WIDTH, 14 | HORIZON_STEPS, 15 | CONTROL_FREQ, 16 | ) 17 | import numpy as np 18 | import torch 19 | from brs_ctrl.robot_interface import R1Interface 20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper 21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 22 | import brs_algo.utils as U 23 | from brs_algo.learning.policy import WBVIMAPolicy 24 | from brs_algo.rollout.synced_rollout import R1SyncedRollout 25 | 26 | DEVICE = torch.device("cuda:0") 27 | NUM_LATEST_OBS = 2 28 | HORIZON = 16 29 | T_action_prediction = 8 30 | 31 | 32 | def rollout(args): 33 | robot = R1Interface( 34 | left_gripper=GalaxeaR1G1Gripper( 35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE 36 | ), 37 | right_gripper=GalaxeaR1G1Gripper( 38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE 39 | ), 40 | enable_rgbd=False, 41 | enable_pointcloud=True, 42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX), 43 | ) 44 | 45 | policy = WBVIMAPolicy( 46 | prop_dim=21, 47 | prop_keys=[ 48 | "odom/base_velocity", 49 | "qpos/torso", 50 | "qpos/left_arm", 51 | "qpos/left_gripper", 52 | "qpos/right_arm", 53 | "qpos/right_gripper", 54 | ], 55 | prop_mlp_hidden_depth=2, 56 | prop_mlp_hidden_dim=256, 57 | pointnet_n_coordinates=3, 58 | pointnet_n_color=3, 59 | pointnet_hidden_depth=2, 60 | pointnet_hidden_dim=256, 61 | action_keys=[ 62 | "mobile_base", 63 | "torso", 64 | "left_arm", 65 | "left_gripper", 66 | "right_arm", 67 | "right_gripper", 68 | ], 69 | action_key_dims={ 70 | "mobile_base": 3, 71 | "torso": 4, 72 | "left_arm": 6, 73 | "left_gripper": 1, 74 | "right_arm": 6, 75 | "right_gripper": 1, 76 | }, 77 | num_latest_obs=NUM_LATEST_OBS, 78 | use_modality_type_tokens=False, 79 | xf_n_embd=256, 80 | xf_n_layer=2, 81 | xf_n_head=8, 82 | xf_dropout_rate=0.1, 83 | xf_use_geglu=True, 84 | learnable_action_readout_token=False, 85 | action_dim=21, 86 | action_prediction_horizon=T_action_prediction, 87 | diffusion_step_embed_dim=128, 88 | unet_down_dims=[64, 128], 89 | unet_kernel_size=5, 90 | unet_n_groups=8, 91 | 
unet_cond_predict_scale=True, 92 | noise_scheduler=DDIMScheduler( 93 | num_train_timesteps=100, 94 | beta_start=0.0001, 95 | beta_end=0.02, 96 | beta_schedule="squaredcos_cap_v2", 97 | clip_sample=True, 98 | set_alpha_to_one=True, 99 | steps_offset=0, 100 | prediction_type="epsilon", 101 | ), 102 | noise_scheduler_step_kwargs=None, 103 | num_denoise_steps_per_inference=16, 104 | ) 105 | U.load_state_dict( 106 | policy, 107 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"], 108 | strip_prefix="policy.", 109 | strict=True, 110 | ) 111 | policy = policy.to(DEVICE) 112 | policy.eval() 113 | 114 | rollout = R1SyncedRollout( 115 | robot_interface=robot, 116 | num_pcd_points=NUM_PCD_POINTS, 117 | pcd_x_range=PCD_X_RANGE, 118 | pcd_y_range=PCD_Y_RANGE, 119 | pcd_z_range=PCD_Z_RANGE, 120 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN, 121 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX, 122 | gripper_half_width=GRIPPER_HALF_WIDTH, 123 | num_latest_obs=NUM_LATEST_OBS, 124 | num_deployed_actions=T_action_prediction, 125 | device=DEVICE, 126 | policy=policy, 127 | horizon_steps=HORIZON_STEPS, 128 | pause_mode=args.pause_mode, 129 | control_freq=CONTROL_FREQ, 130 | pad_pcd_if_needed=PAD_PCD_IF_LESS, 131 | ) 132 | 133 | input("Press [ENTER] to reset robot to initial qpos") 134 | # reset robot to initial qpos 135 | robot.control( 136 | arm_cmd={ 137 | "left": INITIAL_QPOS["left_arm"], 138 | "right": INITIAL_QPOS["right_arm"], 139 | }, 140 | gripper_cmd={ 141 | "left": 0.1, 142 | "right": 0.1, 143 | }, 144 | torso_cmd=INITIAL_QPOS["torso"], 145 | ) 146 | 147 | input("Press [ENTER] to start rollout") 148 | for i in range(3): 149 | print(3 - i) 150 | time.sleep(1) 151 | rollout.rollout() 152 | 153 | 154 | if __name__ == "__main__": 155 | args = argparse.ArgumentParser() 156 | args.add_argument("--ckpt_path", type=str, required=True) 157 | args.add_argument("--pause_mode", action="store_true") 158 | args = args.parse_args() 159 | rollout(args) 160 | -------------------------------------------------------------------------------- /main/rollout/take_trash_outside/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | INITIAL_QPOS = { 5 | "torso": np.array([0.028215, -0.07880333, -0.12412, 0.02912167]), 6 | "left_arm": np.array( 7 | [ 8 | 1.48734574e00, 9 | 2.94314716e00, 10 | -2.58810284e00, 11 | -1.85106383e-03, 12 | -8.12056738e-04, 13 | 2.91258865e-02, 14 | ] 15 | ), 16 | "right_arm": np.array( 17 | [-1.56693262, 2.84392553, -2.37629965, 0.0035922, 0.03647163, 0.01210816] 18 | ), 19 | } 20 | 21 | 22 | GRIPPER_CLOSE_STROKE = 0.5 23 | GRIPPER_HALF_WIDTH = 50 24 | NUM_PCD_POINTS = 4096 25 | PAD_PCD_IF_LESS = True 26 | PCD_X_RANGE = (0.0, 4.0) 27 | PCD_Y_RANGE = (-1.0, 1.0) 28 | PCD_Z_RANGE = (-0.5, 1.6) 29 | MOBILE_BASE_VEL_ACTION_MIN = (-0.35, -0.35, -0.3) 30 | MOBILE_BASE_VEL_ACTION_MAX = (0.35, 0.35, 0.3) 31 | HORIZON_STEPS = 1300 * 4 32 | CONTROL_FREQ = 100 33 | ACTION_REPEAT = 12 34 | -------------------------------------------------------------------------------- /main/rollout/take_trash_outside/rollout_async.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from common import ( 4 | INITIAL_QPOS, 5 | GRIPPER_CLOSE_STROKE, 6 | NUM_PCD_POINTS, 7 | PAD_PCD_IF_LESS, 8 | PCD_X_RANGE, 9 | PCD_Y_RANGE, 10 | PCD_Z_RANGE, 11 | MOBILE_BASE_VEL_ACTION_MIN, 12 | MOBILE_BASE_VEL_ACTION_MAX, 13 | GRIPPER_HALF_WIDTH, 14 | HORIZON_STEPS, 15 | 
ACTION_REPEAT, 16 | ) 17 | import numpy as np 18 | import torch 19 | from brs_ctrl.robot_interface import R1Interface 20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper 21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 22 | import brs_algo.utils as U 23 | from brs_algo.learning.policy import WBVIMAPolicy 24 | from brs_algo.rollout.asynced_rollout import R1AsyncedRollout 25 | 26 | DEVICE = torch.device("cuda:0") 27 | NUM_LATEST_OBS = 2 28 | HORIZON = 16 29 | T_action_prediction = 8 30 | 31 | 32 | def rollout(args): 33 | robot = R1Interface( 34 | left_gripper=GalaxeaR1G1Gripper( 35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE 36 | ), 37 | right_gripper=GalaxeaR1G1Gripper( 38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE 39 | ), 40 | enable_rgbd=False, 41 | enable_pointcloud=True, 42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX), 43 | control_freq=65, 44 | ) 45 | 46 | policy = WBVIMAPolicy( 47 | prop_dim=21, 48 | prop_keys=[ 49 | "odom/base_velocity", 50 | "qpos/torso", 51 | "qpos/left_arm", 52 | "qpos/left_gripper", 53 | "qpos/right_arm", 54 | "qpos/right_gripper", 55 | ], 56 | prop_mlp_hidden_depth=2, 57 | prop_mlp_hidden_dim=256, 58 | pointnet_n_coordinates=3, 59 | pointnet_n_color=3, 60 | pointnet_hidden_depth=2, 61 | pointnet_hidden_dim=256, 62 | action_keys=[ 63 | "mobile_base", 64 | "torso", 65 | "left_arm", 66 | "left_gripper", 67 | "right_arm", 68 | "right_gripper", 69 | ], 70 | action_key_dims={ 71 | "mobile_base": 3, 72 | "torso": 4, 73 | "left_arm": 6, 74 | "left_gripper": 1, 75 | "right_arm": 6, 76 | "right_gripper": 1, 77 | }, 78 | num_latest_obs=NUM_LATEST_OBS, 79 | use_modality_type_tokens=False, 80 | xf_n_embd=256, 81 | xf_n_layer=2, 82 | xf_n_head=8, 83 | xf_dropout_rate=0.1, 84 | xf_use_geglu=True, 85 | learnable_action_readout_token=False, 86 | action_dim=21, 87 | action_prediction_horizon=T_action_prediction, 88 | diffusion_step_embed_dim=128, 89 | unet_down_dims=[64, 128], 90 | unet_kernel_size=5, 91 | unet_n_groups=8, 92 | unet_cond_predict_scale=True, 93 | noise_scheduler=DDIMScheduler( 94 | num_train_timesteps=100, 95 | beta_start=0.0001, 96 | beta_end=0.02, 97 | beta_schedule="squaredcos_cap_v2", 98 | clip_sample=True, 99 | set_alpha_to_one=True, 100 | steps_offset=0, 101 | prediction_type="epsilon", 102 | ), 103 | noise_scheduler_step_kwargs=None, 104 | num_denoise_steps_per_inference=16, 105 | ) 106 | U.load_state_dict( 107 | policy, 108 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"], 109 | strip_prefix="policy.", 110 | strict=True, 111 | ) 112 | policy = policy.to(DEVICE) 113 | policy.eval() 114 | 115 | rollout = R1AsyncedRollout( 116 | robot_interface=robot, 117 | num_pcd_points=NUM_PCD_POINTS, 118 | pcd_x_range=PCD_X_RANGE, 119 | pcd_y_range=PCD_Y_RANGE, 120 | pcd_z_range=PCD_Z_RANGE, 121 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN, 122 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX, 123 | gripper_half_width=GRIPPER_HALF_WIDTH, 124 | num_latest_obs=NUM_LATEST_OBS, 125 | num_deployed_actions=T_action_prediction, 126 | device=DEVICE, 127 | action_execute_start_idx=args.action_execute_start_idx, 128 | policy=policy, 129 | horizon_steps=HORIZON_STEPS, 130 | pad_pcd_if_needed=PAD_PCD_IF_LESS, 131 | action_repeat=ACTION_REPEAT, 132 | ) 133 | 134 | input("Press [ENTER] to reset robot to initial qpos") 135 | # reset robot to initial qpos 136 | robot.control( 137 | arm_cmd={ 138 | "left": INITIAL_QPOS["left_arm"], 139 | "right": 
INITIAL_QPOS["right_arm"], 140 | }, 141 | gripper_cmd={ 142 | "left": 0.1, 143 | "right": 0.1, 144 | }, 145 | torso_cmd=INITIAL_QPOS["torso"], 146 | ) 147 | 148 | input("Press [ENTER] to start rollout") 149 | for i in range(3): 150 | print(3 - i) 151 | time.sleep(1) 152 | rollout.rollout() 153 | 154 | 155 | if __name__ == "__main__": 156 | args = argparse.ArgumentParser() 157 | args.add_argument("--ckpt_path", type=str, required=True) 158 | args.add_argument("--action_execute_start_idx", type=int, default=1) 159 | args = args.parse_args() 160 | rollout(args) 161 | -------------------------------------------------------------------------------- /main/rollout/take_trash_outside/rollout_sync.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from common import ( 4 | INITIAL_QPOS, 5 | GRIPPER_CLOSE_STROKE, 6 | NUM_PCD_POINTS, 7 | PAD_PCD_IF_LESS, 8 | PCD_X_RANGE, 9 | PCD_Y_RANGE, 10 | PCD_Z_RANGE, 11 | MOBILE_BASE_VEL_ACTION_MIN, 12 | MOBILE_BASE_VEL_ACTION_MAX, 13 | GRIPPER_HALF_WIDTH, 14 | HORIZON_STEPS, 15 | CONTROL_FREQ, 16 | ) 17 | import numpy as np 18 | import torch 19 | from brs_ctrl.robot_interface import R1Interface 20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper 21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 22 | import brs_algo.utils as U 23 | from brs_algo.learning.policy import WBVIMAPolicy 24 | from brs_algo.rollout.synced_rollout import R1SyncedRollout 25 | 26 | DEVICE = torch.device("cuda:0") 27 | NUM_LATEST_OBS = 2 28 | HORIZON = 16 29 | T_action_prediction = 8 30 | 31 | 32 | def rollout(args): 33 | robot = R1Interface( 34 | left_gripper=GalaxeaR1G1Gripper( 35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE 36 | ), 37 | right_gripper=GalaxeaR1G1Gripper( 38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE 39 | ), 40 | enable_rgbd=False, 41 | enable_pointcloud=True, 42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX), 43 | ) 44 | 45 | policy = WBVIMAPolicy( 46 | prop_dim=21, 47 | prop_keys=[ 48 | "odom/base_velocity", 49 | "qpos/torso", 50 | "qpos/left_arm", 51 | "qpos/left_gripper", 52 | "qpos/right_arm", 53 | "qpos/right_gripper", 54 | ], 55 | prop_mlp_hidden_depth=2, 56 | prop_mlp_hidden_dim=256, 57 | pointnet_n_coordinates=3, 58 | pointnet_n_color=3, 59 | pointnet_hidden_depth=2, 60 | pointnet_hidden_dim=256, 61 | action_keys=[ 62 | "mobile_base", 63 | "torso", 64 | "left_arm", 65 | "left_gripper", 66 | "right_arm", 67 | "right_gripper", 68 | ], 69 | action_key_dims={ 70 | "mobile_base": 3, 71 | "torso": 4, 72 | "left_arm": 6, 73 | "left_gripper": 1, 74 | "right_arm": 6, 75 | "right_gripper": 1, 76 | }, 77 | num_latest_obs=NUM_LATEST_OBS, 78 | use_modality_type_tokens=False, 79 | xf_n_embd=256, 80 | xf_n_layer=2, 81 | xf_n_head=8, 82 | xf_dropout_rate=0.1, 83 | xf_use_geglu=True, 84 | learnable_action_readout_token=False, 85 | action_dim=21, 86 | action_prediction_horizon=T_action_prediction, 87 | diffusion_step_embed_dim=128, 88 | unet_down_dims=[64, 128], 89 | unet_kernel_size=5, 90 | unet_n_groups=8, 91 | unet_cond_predict_scale=True, 92 | noise_scheduler=DDIMScheduler( 93 | num_train_timesteps=100, 94 | beta_start=0.0001, 95 | beta_end=0.02, 96 | beta_schedule="squaredcos_cap_v2", 97 | clip_sample=True, 98 | set_alpha_to_one=True, 99 | steps_offset=0, 100 | prediction_type="epsilon", 101 | ), 102 | noise_scheduler_step_kwargs=None, 103 | num_denoise_steps_per_inference=16, 104 | ) 105 | U.load_state_dict( 
106 | policy, 107 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"], 108 | strip_prefix="policy.", 109 | strict=True, 110 | ) 111 | policy = policy.to(DEVICE) 112 | policy.eval() 113 | 114 | rollout = R1SyncedRollout( 115 | robot_interface=robot, 116 | num_pcd_points=NUM_PCD_POINTS, 117 | pcd_x_range=PCD_X_RANGE, 118 | pcd_y_range=PCD_Y_RANGE, 119 | pcd_z_range=PCD_Z_RANGE, 120 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN, 121 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX, 122 | gripper_half_width=GRIPPER_HALF_WIDTH, 123 | num_latest_obs=NUM_LATEST_OBS, 124 | num_deployed_actions=T_action_prediction, 125 | device=DEVICE, 126 | policy=policy, 127 | horizon_steps=HORIZON_STEPS, 128 | pause_mode=args.pause_mode, 129 | control_freq=CONTROL_FREQ, 130 | pad_pcd_if_needed=PAD_PCD_IF_LESS, 131 | ) 132 | 133 | input("Press [ENTER] to reset robot to initial qpos") 134 | # reset robot to initial qpos 135 | robot.control( 136 | arm_cmd={ 137 | "left": INITIAL_QPOS["left_arm"], 138 | "right": INITIAL_QPOS["right_arm"], 139 | }, 140 | gripper_cmd={ 141 | "left": 0.1, 142 | "right": 0.1, 143 | }, 144 | torso_cmd=INITIAL_QPOS["torso"], 145 | ) 146 | 147 | input("Press [ENTER] to start rollout") 148 | for i in range(3): 149 | print(3 - i) 150 | time.sleep(1) 151 | rollout.rollout() 152 | 153 | 154 | if __name__ == "__main__": 155 | args = argparse.ArgumentParser() 156 | args.add_argument("--ckpt_path", type=str, required=True) 157 | args.add_argument("--pause_mode", action="store_true") 158 | args = args.parse_args() 159 | rollout(args) 160 | -------------------------------------------------------------------------------- /main/train/cfg/arch/wbvima.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | run_name: "${optional_:${prefix}}\ 4 | ${arch_name}\ 5 | _${task_name}\ 6 | _To${num_latest_obs}\ 7 | _Ta${action_prediction_horizon}\ 8 | _b${bs}\ 9 | ${__optional:${suffix}}\ 10 | ${_optional:${postsuffix}}" 11 | 12 | arch_name: wbvima 13 | 14 | # ====== DP specific ====== 15 | num_latest_obs: 2 16 | action_prediction_horizon: 8 17 | 18 | wd: 0.1 19 | 20 | # ------ module ------ 21 | module: 22 | _target_: brs_algo.learning.module.DiffusionModule 23 | policy: 24 | _target_: brs_algo.learning.policy.WBVIMAPolicy 25 | prop_dim: 21 26 | prop_keys: ["odom/base_velocity", "qpos/torso", "qpos/left_arm", "qpos/left_gripper", "qpos/right_arm", "qpos/right_gripper"] 27 | num_latest_obs: ${num_latest_obs} 28 | use_modality_type_tokens: false 29 | prop_mlp_hidden_depth: 2 30 | prop_mlp_hidden_dim: 256 31 | pointnet_n_coordinates: 3 32 | pointnet_n_color: 3 33 | pointnet_hidden_depth: 2 34 | pointnet_hidden_dim: 256 35 | action_keys: ${action_keys} 36 | action_key_dims: ${action_key_dims} 37 | # ====== Transformer ====== 38 | xf_n_embd: 256 39 | xf_n_layer: 2 40 | xf_n_head: 8 41 | xf_dropout_rate: 0.1 42 | xf_use_geglu: true 43 | # ====== Action Decoding ====== 44 | learnable_action_readout_token: false 45 | action_dim: 21 46 | action_prediction_horizon: ${action_prediction_horizon} 47 | diffusion_step_embed_dim: 128 48 | unet_down_dims: [64,128] 49 | unet_kernel_size: 5 50 | unet_n_groups: 8 51 | unet_cond_predict_scale: true 52 | # ====== diffusion ====== 53 | noise_scheduler: 54 | _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler 55 | num_train_timesteps: 100 56 | beta_start: 0.0001 57 | beta_end: 0.02 58 | # beta_schedule is important 59 | # this is the best we found 60 | beta_schedule: 
squaredcos_cap_v2 61 | clip_sample: True 62 | set_alpha_to_one: True 63 | steps_offset: 0 64 | prediction_type: epsilon # or sample 65 | noise_scheduler_step_kwargs: null 66 | num_denoise_steps_per_inference: 16 67 | action_prediction_horizon: ${action_prediction_horizon} 68 | loss_on_latest_obs_only: false 69 | 70 | data_module: 71 | _target_: brs_algo.learning.data.ActionSeqChunkDataModule 72 | obs_window_size: ${num_latest_obs} 73 | action_prediction_horizon: ${action_prediction_horizon} 74 | -------------------------------------------------------------------------------- /main/train/cfg/cfg.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ # all below configs will override this conf.yaml 3 | - arch: wbvima 4 | - task: ??? 5 | 6 | run_name: "${optional_:${prefix}}\ 7 | ${arch_name}\ 8 | _${task_name}\ 9 | _lr${scientific:${lr},1}\ 10 | _wd${scientific:${wd}}\ 11 | _b${bs}\ 12 | ${__optional:${suffix}}\ 13 | ${_optional:${postsuffix}}" 14 | exp_root_dir: ??? 15 | arch_name: ??? # filled by arch 16 | 17 | # ====== main cfg ====== 18 | seed: -1 19 | gpus: 1 20 | lr: 7e-4 21 | use_cosine_lr: true 22 | lr_warmup_steps: 1000 23 | lr_cosine_steps: 300000 24 | lr_cosine_min: 5e-6 25 | lr_layer_decay: 1.0 26 | wd: 0.0 27 | bs: 256 28 | vbs: ${bs} 29 | data_dir: ??? 30 | eval_interval: 10 31 | rollout_eval: false 32 | # ------ logging ------ 33 | use_wandb: true 34 | wandb_project: ??? 35 | wandb_run_name: ${run_name} 36 | 37 | # ------ common ------ 38 | action_keys: ["mobile_base", "torso", "left_arm", "left_gripper", "right_arm", "right_gripper"] 39 | action_key_dims: 40 | mobile_base: 3 41 | torso: 4 42 | left_arm: 6 43 | left_gripper: 1 44 | right_arm: 6 45 | right_gripper: 1 46 | 47 | # ------ module ------ 48 | module: 49 | _target_: ??? # filled by arch 50 | # ====== policy ====== 51 | policy: ??? 52 | # ====== learning ====== 53 | lr: ${lr} 54 | use_cosine_lr: ${use_cosine_lr} 55 | lr_warmup_steps: ${lr_warmup_steps} 56 | lr_cosine_steps: ${lr_cosine_steps} 57 | lr_cosine_min: ${lr_cosine_min} 58 | lr_layer_decay: ${lr_layer_decay} 59 | weight_decay: ${wd} 60 | action_keys: ${action_keys} 61 | 62 | data_module: 63 | _target_: ??? 
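  # filled by the arch config; arch/wbvima.yaml sets brs_algo.learning.data.ActionSeqChunkDataModule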
64 | data_path: ${data_dir} 65 | pcd_downsample_points: ${pcd_downsample_points} 66 | batch_size: ${bs} 67 | val_batch_size: ${vbs} 68 | val_split_ratio: 0.1 69 | seed: ${seed} 70 | dataloader_num_workers: 4 71 | 72 | trainer: 73 | cls: pytorch_lightning.Trainer 74 | accelerator: "gpu" 75 | devices: ${gpus} 76 | precision: 32 77 | benchmark: true # enables cudnn.benchmark 78 | accumulate_grad_batches: 1 79 | num_sanity_val_steps: 0 80 | max_epochs: 999999999 81 | val_check_interval: null 82 | check_val_every_n_epoch: ${eval_interval} 83 | gradient_clip_val: 1.0 84 | checkpoint: # this sub-dict will be popped to send to ModelCheckpoint as args 85 | - filename: "epoch{epoch}-train_loss{train/loss:.5f}" 86 | save_on_train_epoch_end: true # this is a training metric, so we save it at the end of training epoch 87 | save_top_k: 100 88 | save_last: true 89 | monitor: "train/loss" 90 | mode: min 91 | auto_insert_metric_name: false # prevent creating subfolder caused by the slash 92 | - filename: "epoch{epoch}-val_l1_{val/l1:.5f}" 93 | eval_type: "static" 94 | save_top_k: -1 95 | save_last: true 96 | monitor: "val/l1" 97 | mode: min 98 | auto_insert_metric_name: false # prevent creating subfolder caused by the slash 99 | callbacks: 100 | - cls: LearningRateMonitor 101 | logging_interval: step 102 | - cls: RichModelSummary 103 | 104 | # ------------- Global cfgs for enlight.LightningTrainer --------------- 105 | 106 | 107 | # ------------- Resume training --------------- 108 | resume: 109 | ckpt_path: null 110 | full_state: false # if true, resume all states including optimizer, amp, lightning callbacks 111 | strict: true 112 | 113 | # ------------- Testing --------------- 114 | test: 115 | ckpt_path: null 116 | 117 | # ---------------------------- 118 | 119 | prefix: 120 | suffix: 121 | postsuffix: 122 | 123 | hydra: 124 | job: 125 | chdir: true 126 | run: 127 | dir: "." 
128 | output_subdir: null 129 | -------------------------------------------------------------------------------- /main/train/cfg/task/clean_house_after_a_wild_party.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task_name: clean_house_after_a_wild_party 4 | pcd_downsample_points: 4096 5 | data_module: 6 | pcd_x_range: [0.0, 2.0] 7 | pcd_y_range: [-1.0, 1.0] 8 | pcd_z_range: [-0.5, 1.6] 9 | mobile_base_vel_action_min: [-0.3, -0.3, -0.4] 10 | mobile_base_vel_action_max: [0.3, 0.3, 0.4] 11 | -------------------------------------------------------------------------------- /main/train/cfg/task/clean_the_toilet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task_name: clean_the_toilet 4 | pcd_downsample_points: 4096 5 | data_module: 6 | pcd_x_range: [0.0, 1.0] 7 | pcd_y_range: [-0.5, 0.5] 8 | pcd_z_range: [0.0, 1.0] 9 | mobile_base_vel_action_min: [-0.3, -0.3, -0.4] 10 | mobile_base_vel_action_max: [0.3, 0.3, 0.4] 11 | -------------------------------------------------------------------------------- /main/train/cfg/task/lay_clothes_out.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task_name: lay_clothes_out 4 | pcd_downsample_points: 4096 5 | data_module: 6 | pcd_x_range: [0.0, 2.3] 7 | pcd_y_range: [-0.5, 0.5] 8 | pcd_z_range: [-0.3, 2.0] 9 | mobile_base_vel_action_min: [-0.3, -0.3, -0.4] 10 | mobile_base_vel_action_max: [0.3, 0.3, 0.4] 11 | -------------------------------------------------------------------------------- /main/train/cfg/task/put_items_onto_shelves.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task_name: put_items_onto_shelves 4 | pcd_downsample_points: 4096 5 | data_module: 6 | pcd_x_range: [0.0, 2.0] 7 | pcd_y_range: [-0.5, 0.5] 8 | pcd_z_range: [-0.3, 2.0] 9 | mobile_base_vel_action_min: [-0.3, -0.3, -0.4] 10 | mobile_base_vel_action_max: [0.3, 0.3, 0.4] 11 | -------------------------------------------------------------------------------- /main/train/cfg/task/take_trash_outside.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | task_name: take_trash_outside 4 | pcd_downsample_points: 4096 5 | data_module: 6 | pcd_x_range: [0.0, 4.0] 7 | pcd_y_range: [-1.0, 1.0] 8 | pcd_z_range: [-0.5, 1.6] 9 | mobile_base_vel_action_min: [-0.3, -0.3, -0.4] 10 | mobile_base_vel_action_max: [0.3, 0.3, 0.4] 11 | -------------------------------------------------------------------------------- /main/train/train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 3 | import brs_algo.utils as U 4 | from brs_algo.lightning import Trainer 5 | 6 | 7 | @hydra.main(config_name="cfg", config_path="cfg", version_base="1.1") 8 | def main(cfg): 9 | cfg.seed = U.set_seed(cfg.seed) 10 | trainer_ = Trainer(cfg) 11 | trainer_.trainer.loggers[-1].log_hyperparams(U.omegaconf_to_dict(cfg)) 12 | trainer_.fit() 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /media/SUSig-red.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/media/SUSig-red.png 
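Relating to main/train/train.py and the cfg directory above: the Hydra entry point composes cfg.yaml with one arch config and one task config before handing the result to brs_algo.lightning.Trainer. Below is a minimal composition sketch, not part of the repository; it assumes it is run from main/train/ and uses placeholder paths and a placeholder wandb project name.

from hydra import compose, initialize

with initialize(config_path="cfg", version_base="1.1"):
    cfg = compose(
        config_name="cfg",
        overrides=[
            "task=take_trash_outside",            # selects cfg/task/take_trash_outside.yaml
            "data_dir=/path/to/demos.hdf5",       # placeholder; e.g. a file from scripts/merge_data_files.py
            "exp_root_dir=/path/to/experiments",  # placeholder
            "wandb_project=my_project",           # placeholder
        ],
    )
# plain values resolve without the custom run_name resolvers
print(cfg.task_name, cfg.arch_name, cfg.bs, cfg.pcd_downsample_points)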
-------------------------------------------------------------------------------- /media/pull.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/media/pull.gif -------------------------------------------------------------------------------- /media/wbvima.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/media/wbvima.gif -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "brs_algo" 7 | dynamic = ["version"] 8 | dependencies = [ 9 | "numpy >= 1.26.4", 10 | "torch >= 2.5.1", 11 | "dm_tree >= 0.1.8", 12 | "pytorch-lightning >= 2.3.0", 13 | "pillow >= 11.0.0", 14 | "h5py >= 3.12.1", 15 | "hydra-core >= 1.3.2", 16 | "diffusers >= 0.26.3", 17 | "huggingface-hub == 0.24.6", 18 | "einops >= 0.8.0", 19 | "omegaconf >= 2.3.0", 20 | "tqdm >= 4.67.1", 21 | "transformers >= 4.42.4", 22 | "tensorboard", 23 | "rich", 24 | "tabulate", 25 | "numpy-quaternion" 26 | ] 27 | requires-python = ">=3.11" 28 | authors = [ 29 | {name = "Yunfan Jiang", email = "yunfanj@cs.stanford.edu"}, 30 | ] 31 | maintainers = [ 32 | {name = "Yunfan Jiang", email = "yunfanj@cs.stanford.edu"}, 33 | ] 34 | description = "The algorithm repository for BEHAVIOR-Robot-Suite: Streamlining Real-World Whole-Body Manipulation for Everyday Household Activities" 35 | readme = "README.md" 36 | keywords = ["Robotics", "Machine Learning", "Whole-Body Manipulation", "Mobile Manipulation"] 37 | classifiers = [ 38 | "Development Status :: 3 - Alpha", 39 | "Intended Audience :: Researchers, Developers", 40 | "Topic :: Scientific/Engineering :: Robotics", 41 | "Programming Language :: Python :: 3.11", 42 | ] 43 | 44 | [project.urls] 45 | Homepage = "https://behavior-robot-suite.github.io/" 46 | 47 | 48 | [tool.setuptools] 49 | packages = ["brs_algo"] 50 | [tool.setuptools.dynamic] 51 | version = {attr = "brs_algo.__version__"} -------------------------------------------------------------------------------- /scripts/merge_data_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | 5 | import h5py 6 | from tqdm import tqdm 7 | 8 | 9 | def merge(args): 10 | data_dir = args.data_dir 11 | output_file = args.output_file 12 | assert os.path.exists(data_dir), f"Data directory {data_dir} does not exist" 13 | data_files = [ 14 | os.path.join(data_dir, f) 15 | for f in os.listdir(data_dir) 16 | if f.endswith(".h5") or f.endswith(".hdf5") 17 | ] 18 | print(f"[INFO] Found {len(data_files)} data files in {data_dir}") 19 | 20 | # start merging 21 | merged_file = h5py.File(output_file, "w") 22 | # meta data 23 | merged_file.attrs["num_demos"] = len(data_files) 24 | merged_file.attrs["merging_time"] = time.strftime("%Y-%m-%d-%H-%M-%S") 25 | merged_file.attrs["merged_data_files"] = data_files 26 | # merge each demo 27 | for idx, fpath in tqdm(enumerate(data_files), desc="Merging data files"): 28 | f_to_merge = h5py.File(fpath, "r") 29 | demo_grp = merged_file.create_group(f"demo_{idx}") 30 | # passover meta data 31 | for key in f_to_merge.attrs: 32 | 
demo_grp.attrs[key] = f_to_merge.attrs[key] 33 | # obs group 34 | obs_grp = demo_grp.create_group("obs") 35 | # joint_state 36 | for ks in f_to_merge["obs/joint_state"]: 37 | for k, v in f_to_merge["obs/joint_state"][ks].items(): 38 | obs_grp.create_dataset(f"joint_state/{ks}/{k}", data=v[:]) 39 | # gripper_state 40 | for ks in f_to_merge["obs/gripper_state"]: 41 | for k, v in f_to_merge["obs/gripper_state"][ks].items(): 42 | obs_grp.create_dataset(f"gripper_state/{ks}/{k}", data=v[:]) 43 | # link_poses 44 | for k, v in f_to_merge["obs/link_poses"].items(): 45 | obs_grp.create_dataset(f"link_poses/{k}", data=v[:]) 46 | # fused point cloud 47 | pcd_xyz = f_to_merge["obs/point_cloud/fused/xyz"][:] # (T, N_points, 3) 48 | pcd_rgb = f_to_merge["obs/point_cloud/fused/rgb"][:] 49 | pcd_padding_mask = f_to_merge["obs/point_cloud/fused/padding_mask"][ 50 | : 51 | ] # (T, N_points) 52 | obs_grp.create_dataset("point_cloud/fused/xyz", data=pcd_xyz) 53 | obs_grp.create_dataset("point_cloud/fused/rgb", data=pcd_rgb) 54 | obs_grp.create_dataset("point_cloud/fused/padding_mask", data=pcd_padding_mask) 55 | # odom 56 | odom_base_velocity = f_to_merge["obs/odom/base_velocity"][:] 57 | obs_grp.create_dataset("odom/base_velocity", data=odom_base_velocity) 58 | # multiview cameras 59 | head_camera_rgb = f_to_merge["obs/rgb/head/img"][:] 60 | head_camera_depth = f_to_merge["obs/depth/head/depth"][:] 61 | left_wrist_camera_rgb = f_to_merge["obs/rgb/left_wrist/img"][:] 62 | left_wrist_camera_depth = f_to_merge["obs/depth/left_wrist/depth"][:] 63 | right_wrist_camera_rgb = f_to_merge["obs/rgb/right_wrist/img"][:] 64 | right_wrist_camera_depth = f_to_merge["obs/depth/right_wrist/depth"][:] 65 | obs_grp.create_dataset("rgb/head/img", data=head_camera_rgb) 66 | obs_grp.create_dataset("depth/head/depth", data=head_camera_depth) 67 | obs_grp.create_dataset("rgb/left_wrist/img", data=left_wrist_camera_rgb) 68 | obs_grp.create_dataset("depth/left_wrist/depth", data=left_wrist_camera_depth) 69 | obs_grp.create_dataset("rgb/right_wrist/img", data=right_wrist_camera_rgb) 70 | obs_grp.create_dataset("depth/right_wrist/depth", data=right_wrist_camera_depth) 71 | # action 72 | action_grp = demo_grp.create_group("action") 73 | for k, v in f_to_merge["action"].items(): 74 | action_grp.create_dataset(k, data=v[:]) 75 | f_to_merge.close() 76 | merged_file.close() 77 | 78 | 79 | if __name__ == "__main__": 80 | args = argparse.ArgumentParser() 81 | args.add_argument( 82 | "--data_dir", 83 | type=str, 84 | required=True, 85 | help="Directory containing individual data files to consolidate.", 86 | ) 87 | args.add_argument( 88 | "--output_file", 89 | type=str, 90 | required=True, 91 | help="A single output file to save the consolidated data.", 92 | ) 93 | args = args.parse_args() 94 | merge(args) 95 | -------------------------------------------------------------------------------- /scripts/post_process.py: -------------------------------------------------------------------------------- 1 | """ 2 | Post-process collected raw data. Including: 3 | - Compute poses for left/right EEFs and head. 4 | - Process raw odometry data such that it is consistent with R1Interface. 
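- Write results in place: link poses to "obs/link_poses/{left_eef,right_eef,head}" (per-frame
  position plus xyzw quaternion) and base velocity to "obs/odom/base_velocity".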
5 | """ 6 | 7 | import os 8 | import argparse 9 | from math import ceil 10 | 11 | import h5py 12 | import numpy as np 13 | from tqdm import tqdm 14 | import quaternion 15 | 16 | from brs_ctrl.kinematics import R1Kinematics 17 | 18 | 19 | JMAP = { 20 | "left_wrist": "obs/joint_state/left_arm/joint_position", 21 | "right_wrist": "obs/joint_state/right_arm/joint_position", 22 | "torso": "obs/joint_state/torso/joint_position", 23 | } 24 | PCDKMAP = { 25 | "left_wrist": "obs/point_cloud/left_wrist", 26 | "right_wrist": "obs/point_cloud/right_wrist", 27 | "torso": "obs/point_cloud/head", 28 | } 29 | CMAP = { 30 | "left_wrist": "left_wrist_camera", 31 | "right_wrist": "right_wrist_camera", 32 | "torso": "head_camera", 33 | } 34 | 35 | 36 | def main(args): 37 | data_dir = args.data_dir 38 | chunk_size = args.process_chunk_size 39 | assert os.path.exists(data_dir), f"Data directory {data_dir} does not exist" 40 | 41 | kin = R1Kinematics() 42 | T_odom2base = kin.T_odom2base 43 | 44 | data_files = [ 45 | os.path.join(data_dir, f) 46 | for f in os.listdir(data_dir) 47 | if f.endswith(".h5") or f.endswith(".hdf5") 48 | ] 49 | print(f"[INFO] Found {len(data_files)} data files in {data_dir}") 50 | for fpath in tqdm(data_files, desc="Processing data files"): 51 | f = h5py.File(fpath, "r+") 52 | all_qpos = {k: f[v][:] for k, v in JMAP.items()} # (T, N_joints) 53 | raw_odom_linear_vel = f["obs/odom/linear_velocity"][:] # (T, 3) 54 | raw_odom_angular_vel = f["obs/odom/angular_velocity"][:] # (T, 3) 55 | # set small vel values to zero, using the same threshold as in R1Interface 56 | vx_zero_idxs = np.abs(raw_odom_linear_vel[:, 0]) <= 1e-2 57 | vy_zero_idxs = np.abs(raw_odom_linear_vel[:, 1]) <= 1e-2 58 | vyaw_zero_idxs = np.abs(raw_odom_angular_vel[:, 2]) <= 5e-3 59 | raw_odom_linear_vel[vx_zero_idxs, 0] = 0 60 | raw_odom_linear_vel[vy_zero_idxs, 1] = 0 61 | raw_odom_angular_vel[vyaw_zero_idxs, 2] = 0 62 | 63 | T = all_qpos[list(JMAP.keys())[0]].shape[0] 64 | link_poses = { 65 | "left_eef": { 66 | "position": np.zeros((T, 3), dtype=np.float32), 67 | "orientation": np.zeros((T, 3, 3), dtype=np.float32), 68 | }, 69 | "right_eef": { 70 | "position": np.zeros((T, 3), dtype=np.float32), 71 | "orientation": np.zeros((T, 3, 3), dtype=np.float32), 72 | }, 73 | "head": { 74 | "position": np.zeros((T, 3), dtype=np.float32), 75 | "orientation": np.zeros((T, 3, 3), dtype=np.float32), 76 | }, 77 | } 78 | base_vel = np.zeros((T, 3), dtype=np.float32) # (v_x, v_y, v_yaw) 79 | 80 | # process in chunks along time dimension 81 | N_chunks = ceil(T / chunk_size) 82 | for chunk_idx in tqdm(range(N_chunks), desc="Processing chunks"): 83 | start_t = chunk_idx * chunk_size 84 | end_t = min((chunk_idx + 1) * chunk_size, T) 85 | for t in range(start_t, end_t): 86 | link2base = kin.get_link_poses_in_base_link( 87 | curr_left_arm_joint=all_qpos["left_wrist"][t, :6], 88 | curr_right_arm_joint=all_qpos["right_wrist"][t, :6], 89 | curr_torso_joint=all_qpos["torso"][t], 90 | ) # dict of (4, 4) 91 | for k in link_poses: 92 | transform = link2base[k] 93 | link_poses[k]["position"][t] = transform[:3, 3] 94 | link_poses[k]["orientation"][t] = transform[:3, :3] 95 | 96 | raw_odom_linear_vel_chunk = raw_odom_linear_vel[ 97 | start_t:end_t 98 | ] # (T_chunk, 3) 99 | base_linear_vel_chunk = ( 100 | T_odom2base[:3, :3] @ raw_odom_linear_vel_chunk.T 101 | ).T 102 | base_angular_vel_chunk = raw_odom_angular_vel[start_t:end_t] 103 | base_vel[start_t:end_t, :2] = base_linear_vel_chunk[:, :2] 104 | base_vel[start_t:end_t, 2] = base_angular_vel_chunk[:, 
2] 105 | 106 | # save link poses to file 107 | link_pose_grp = f.create_group("obs/link_poses") 108 | for k in link_poses: 109 | rot_mat = link_poses[k]["orientation"] 110 | rot_quat = quaternion.as_float_array( 111 | quaternion.from_rotation_matrix(rot_mat) 112 | ) # (T, 4) in wxyz order 113 | # change to xyzw order since that's what pybullet uses 114 | rot_quat = rot_quat[..., [1, 2, 3, 0]] 115 | pose = np.concatenate( 116 | [link_poses[k]["position"], rot_quat], axis=-1 117 | ) # (T, 7) 118 | link_pose_grp.create_dataset(k, data=pose) 119 | # save base velocity to file 120 | f.create_dataset("obs/odom/base_velocity", data=base_vel) 121 | f.close() 122 | 123 | 124 | if __name__ == "__main__": 125 | args = argparse.ArgumentParser() 126 | args.add_argument("--data_dir", type=str, required=True) 127 | args.add_argument("--process_chunk_size", type=int, default=500) 128 | args = args.parse_args() 129 | main(args) 130 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | --------------------------------------------------------------------------------
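After post-processing and merging, each training file follows the layout created by scripts/merge_data_files.py. A minimal inspection sketch assuming that layout is shown below; it is not part of the repository and the file path is a placeholder.

import h5py

with h5py.File("/path/to/merged_demos.hdf5", "r") as f:
    # top-level attributes written by merge_data_files.py
    print("num_demos:", f.attrs["num_demos"], "| merged at:", f.attrs["merging_time"])
    demo = f["demo_0"]
    pcd_xyz = demo["obs/point_cloud/fused/xyz"]    # (T, N_points, 3)
    base_vel = demo["obs/odom/base_velocity"]      # (T, 3), produced by scripts/post_process.py
    print("episode length:", pcd_xyz.shape[0], "| points per frame:", pcd_xyz.shape[1])
    for key in demo["action"]:                     # e.g. mobile_base, torso, arms, grippers (see action_keys in cfg.yaml)
        print("action/" + key, demo["action"][key].shape)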