├── .gitignore
├── LICENSE
├── README.md
├── brs_algo
│   ├── __init__.py
│   ├── learning
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── collate.py
│   │   │   ├── data_module.py
│   │   │   └── dataset.py
│   │   ├── module
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── diffusion_module.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── common
│   │   │   │   ├── __init__.py
│   │   │   │   ├── misc.py
│   │   │   │   └── mlp.py
│   │   │   ├── diffusion
│   │   │   │   ├── __init__.py
│   │   │   │   ├── diffusion_head.py
│   │   │   │   └── unet.py
│   │   │   ├── features
│   │   │   │   ├── __init__.py
│   │   │   │   ├── fusion.py
│   │   │   │   └── pointnet.py
│   │   │   └── gpt
│   │   │       ├── __init__.py
│   │   │       └── gpt.py
│   │   └── policy
│   │       ├── __init__.py
│   │       ├── base.py
│   │       └── wbvima_policy.py
│   ├── lightning
│   │   ├── __init__.py
│   │   ├── lightning.py
│   │   └── trainer.py
│   ├── optim
│   │   ├── __init__.py
│   │   ├── lr_schedule.py
│   │   └── optimizer_group.py
│   ├── rollout
│   │   ├── __init__.py
│   │   ├── asynced_rollout.py
│   │   └── synced_rollout.py
│   └── utils
│       ├── __init__.py
│       ├── array_tensor_utils.py
│       ├── config_utils.py
│       ├── convert_utils.py
│       ├── file_utils.py
│       ├── functional_utils.py
│       ├── misc_utils.py
│       ├── print_utils.py
│       ├── random_seed.py
│       ├── shape_utils.py
│       ├── termcolor.py
│       ├── torch_utils.py
│       └── tree_utils.py
├── main
│   ├── rollout
│   │   ├── clean_house_after_a_wild_party
│   │   │   ├── common.py
│   │   │   ├── rollout_async.py
│   │   │   └── rollout_sync.py
│   │   ├── clean_the_toilet
│   │   │   ├── common.py
│   │   │   ├── rollout_async.py
│   │   │   └── rollout_sync.py
│   │   ├── lay_clothes_out
│   │   │   ├── common.py
│   │   │   ├── rollout_async.py
│   │   │   └── rollout_sync.py
│   │   ├── put_items_onto_shelves
│   │   │   ├── common.py
│   │   │   ├── rollout_async.py
│   │   │   └── rollout_sync.py
│   │   └── take_trash_outside
│   │       ├── common.py
│   │       ├── rollout_async.py
│   │       └── rollout_sync.py
│   └── train
│       ├── cfg
│       │   ├── arch
│       │   │   └── wbvima.yaml
│       │   ├── cfg.yaml
│       │   └── task
│       │       ├── clean_house_after_a_wild_party.yaml
│       │       ├── clean_the_toilet.yaml
│       │       ├── lay_clothes_out.yaml
│       │       ├── put_items_onto_shelves.yaml
│       │       └── take_trash_outside.yaml
│       └── train.py
├── media
│   ├── SUSig-red.png
│   ├── pull.gif
│   └── wbvima.gif
├── pyproject.toml
├── scripts
│   ├── merge_data_files.py
│   └── post_process.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 behavior-robot-suite
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BEHAVIOR Robot Suite: Streamlining Real-World Whole-Body Manipulation for Everyday Household Activities
2 |
3 |
4 | [Yunfan Jiang](https://yunfanj.com/),
5 | [Ruohan Zhang](https://ai.stanford.edu/~zharu/),
6 | [Josiah Wong](https://jdw.ong/),
7 | [Chen Wang](https://www.chenwangjeremy.net/),
8 | [Yanjie Ze](https://yanjieze.com/),
9 | [Hang Yin](https://hang-yin.github.io/),
10 | [Cem Gokmen](https://www.cemgokmen.com/),
11 | [Shuran Song](https://shurans.github.io/),
12 | [Jiajun Wu](https://jiajunwu.com/),
13 | [Li Fei-Fei](https://profiles.stanford.edu/fei-fei-li)
14 |
15 |
16 |
17 | [[Website]](https://behavior-robot-suite.github.io/)
18 | [[arXiv]](https://arxiv.org/abs/2503.05652)
19 | [[PDF]](https://behavior-robot-suite.github.io/assets/pdf/brs_paper.pdf)
20 | [[Doc]](https://behavior-robot-suite.github.io/docs/)
21 | [[Robot Code]](https://github.com/behavior-robot-suite/brs-ctrl)
22 | [[Training Data]](https://huggingface.co/datasets/behavior-robot-suite/data)
23 |
24 | [](https://github.com/behavior-robot-suite/brs-algo)
25 | [](https://pytorch.org/)
26 | [](https://behavior-robot-suite.github.io/docs/)
27 | [](https://github.com/behavior-robot-suite/brs-algo/blob/main/LICENSE)
28 |
29 | 
30 | ______________________________________________________________________
31 |
32 |
33 | We introduce the **BEHAVIOR Robot Suite** (BRS), a comprehensive framework for learning whole-body manipulation to tackle diverse real-world household tasks. BRS addresses both hardware and learning challenges through two key innovations: **WB-VIMA** and [JoyLo](https://github.com/behavior-robot-suite/brs-ctrl).
34 |
35 | WB-VIMA is an imitation learning algorithm designed to model whole-body actions by leveraging the robot’s inherent kinematic hierarchy. A key insight behind WB-VIMA is that robot joints exhibit strong interdependencies—small movements in upstream links (e.g., the torso) can lead to large displacements in downstream links (e.g., the end-effectors). To ensure precise coordination across all joints, WB-VIMA **conditions action predictions for downstream components on those of upstream components**, resulting in more synchronized whole-body movements. Additionally, WB-VIMA **dynamically aggregates multi-modal observations using self-attention**, allowing it to learn expressive policies while mitigating overfitting to proprioceptive inputs.
36 |
37 | 
38 |
39 |
40 | ## Getting Started
41 |
42 | > [!TIP]
43 | > 🚀 Check out the [doc](https://behavior-robot-suite.github.io/docs/sections/wbvima/overview.html) for detailed installation and usage instructions!
44 |
45 | To train a WB-VIMA policy, simply run the following command:
46 |
47 | ```bash
48 | python3 main/train/train.py data_dir=<DATA_DIR> \
49 |     bs=<BATCH_SIZE> \
50 |     arch=wbvima \
51 |     task=<TASK_NAME> \
52 |     exp_root_dir=<EXP_ROOT_DIR> \
53 |     gpus=<NUM_GPUS> \
54 |     use_wandb=<TRUE_OR_FALSE> \
55 |     wandb_project=<WANDB_PROJECT_NAME>
56 | ```
57 |
58 | To deploy a WB-VIMA policy on the real robot, simply run the following command:
59 |
60 | ```bash
61 | python3 main/rollout/<TASK_NAME>/rollout_async.py --ckpt_path <CKPT_PATH> --action_execute_start_idx <ACTION_EXECUTE_START_IDX>
62 | ```
63 |
64 | ## Check out Our Paper
65 | Our paper is posted on [arXiv](https://arxiv.org/abs/2503.05652). If you find our work useful, please consider citing us!
66 |
67 | ```bibtex
68 | @article{jiang2025brs,
69 | title = {BEHAVIOR Robot Suite: Streamlining Real-World Whole-Body Manipulation for Everyday Household Activities},
70 | author = {Yunfan Jiang and Ruohan Zhang and Josiah Wong and Chen Wang and Yanjie Ze and Hang Yin and Cem Gokmen and Shuran Song and Jiajun Wu and Li Fei-Fei},
71 | year = {2025},
72 | journal = {arXiv preprint arXiv: 2503.05652}
73 | }
74 | ```
75 |
76 | ## License
77 | This codebase is released under the [MIT License](LICENSE).
78 |
--------------------------------------------------------------------------------
/brs_algo/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 |
--------------------------------------------------------------------------------
/brs_algo/learning/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/brs_algo/learning/__init__.py
--------------------------------------------------------------------------------
/brs_algo/learning/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_module import ActionSeqChunkDataModule
2 |
--------------------------------------------------------------------------------
/brs_algo/learning/data/collate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import brs_algo.utils as U
3 |
4 |
5 | def seq_chunk_collate_fn(sample_list):
6 | """
7 | sample_list: list of samples, each of shape (T, ...). PyTorch's native collate_fn could stack these directly,
8 | but we also add a leading singleton dimension so the result stays compatible with the episode data format.
9 | """
10 | sample_list = U.any_stack(sample_list, dim=0) # (B, T, ...)
11 | sample_list = nested_np_expand_dims(sample_list, axis=0) # (1, B, T, ...)
12 | # convert to tensor
13 | return any_to_torch_tensor(sample_list)
14 |
15 |
16 | @U.make_recursive_func
17 | def nested_np_expand_dims(x, axis):
18 | if U.is_numpy(x):
19 | return np.expand_dims(x, axis=axis)
20 | else:
21 | raise ValueError(f"Input ({type(x)}) must be a numpy array.")
22 |
23 |
24 | def any_to_torch_tensor(x):
25 | if isinstance(x, dict):
26 | return {k: any_to_torch_tensor(v) for k, v in x.items()}
27 | return U.any_to_torch_tensor(x)
28 |
--------------------------------------------------------------------------------
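A quick shape sketch of the collate contract above (the `"proprio"` key and array sizes are made up; `any_stack` is assumed to recurse into dict-structured samples, as it is used for dict samples in the data module):

```python
# Hypothetical illustration of seq_chunk_collate_fn: B dict samples of (T, ...)
# arrays become a dict of (1, B, T, ...) torch tensors.
import numpy as np
from brs_algo.learning.data.collate import seq_chunk_collate_fn

batch = [
    {"proprio": np.zeros((4, 7), dtype=np.float32)},  # one sample, T=4
    {"proprio": np.ones((4, 7), dtype=np.float32)},
]
out = seq_chunk_collate_fn(batch)
print(out["proprio"].shape)  # torch.Size([1, 2, 4, 7]) -> (1, B, T, ...)
```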
/brs_algo/learning/data/data_module.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | from torch.utils.data import DataLoader
4 | from pytorch_lightning import LightningDataModule
5 |
6 | from brs_algo import utils as U
7 | from brs_algo.learning.data.collate import seq_chunk_collate_fn
8 | from brs_algo.learning.data.dataset import ActionSeqChunkDataset
9 |
10 |
11 | class ActionSeqChunkDataModule(LightningDataModule):
12 | def __init__(
13 | self,
14 | *,
15 | data_path: str,
16 | pcd_downsample_points: int,
17 | pcd_x_range: Tuple[float, float],
18 | pcd_y_range: Tuple[float, float],
19 | pcd_z_range: Tuple[float, float],
20 | mobile_base_vel_action_min: Tuple[float, float, float],
21 | mobile_base_vel_action_max: Tuple[float, float, float],
22 | load_visual_obs_in_memory: bool = True,
23 | multi_view_cameras: Optional[List[str]] = None,
24 | load_multi_view_camera_rgb: bool = False,
25 | load_multi_view_camera_depth: bool = False,
26 | obs_window_size: int,
27 | action_prediction_horizon: int,
28 | batch_size: int,
29 | val_batch_size: Optional[int],
30 | val_split_ratio: float,
31 | dataloader_num_workers: int,
32 | seed: Optional[int] = None,
33 | ):
34 | super().__init__()
35 | self._data_path = data_path
36 | self._pcd_downsample_points = pcd_downsample_points
37 | self._pcd_x_range = pcd_x_range
38 | self._pcd_y_range = pcd_y_range
39 | self._pcd_z_range = pcd_z_range
40 | self._mobile_base_vel_action_min = mobile_base_vel_action_min
41 | self._mobile_base_vel_action_max = mobile_base_vel_action_max
42 | self._load_visual_obs_in_memory = load_visual_obs_in_memory
43 | self._multi_view_cameras = multi_view_cameras
44 | self._load_multi_view_camera_rgb = load_multi_view_camera_rgb
45 | self._load_multi_view_camera_depth = load_multi_view_camera_depth
46 | self._batch_size = batch_size
47 | self._val_batch_size = val_batch_size
48 | self._dataloader_num_workers = dataloader_num_workers
49 | self._seed = seed
50 | self._val_split_ratio = val_split_ratio
51 |
52 | self._train_dataset, self._val_dataset = None, None
53 | self._obs_window_size = obs_window_size
54 | self._action_prediction_horizon = action_prediction_horizon
55 |
56 | def setup(self, stage: str) -> None:
57 | if stage == "fit" or stage is None:
58 | all_dataset = ActionSeqChunkDataset(
59 | fpath=self._data_path,
60 | pcd_downsample_points=self._pcd_downsample_points,
61 | pcd_x_range=self._pcd_x_range,
62 | pcd_y_range=self._pcd_y_range,
63 | pcd_z_range=self._pcd_z_range,
64 | mobile_base_vel_action_min=self._mobile_base_vel_action_min,
65 | mobile_base_vel_action_max=self._mobile_base_vel_action_max,
66 | load_visual_obs_in_memory=self._load_visual_obs_in_memory,
67 | multi_view_cameras=self._multi_view_cameras,
68 | load_multi_view_camera_rgb=self._load_multi_view_camera_rgb,
69 | load_multi_view_camera_depth=self._load_multi_view_camera_depth,
70 | seed=self._seed,
71 | action_prediction_horizon=self._action_prediction_horizon,
72 | obs_window_size=self._obs_window_size,
73 | )
74 | self._train_dataset, self._val_dataset = U.sequential_split_dataset(
75 | all_dataset,
76 | split_portions=[1 - self._val_split_ratio, self._val_split_ratio],
77 | )
78 |
79 | def train_dataloader(self):
80 | return DataLoader(
81 | self._train_dataset,
82 | batch_size=self._batch_size,
83 | num_workers=min(self._batch_size, self._dataloader_num_workers),
84 | pin_memory=True,
85 | persistent_workers=True,
86 | collate_fn=seq_chunk_collate_fn,
87 | )
88 |
89 | def val_dataloader(self):
90 | return DataLoader(
91 | self._val_dataset,
92 | batch_size=self._val_batch_size,
93 | num_workers=min(self._val_batch_size, self._dataloader_num_workers),
94 | pin_memory=True,
95 | persistent_workers=True,
96 | collate_fn=seq_chunk_collate_fn,
97 | )
98 |
--------------------------------------------------------------------------------
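For orientation, a hedged instantiation sketch of `ActionSeqChunkDataModule`; every value below is a placeholder (the real settings live in the Hydra configs under `main/train/cfg`), and the data path is not a real file:

```python
from brs_algo.learning.data import ActionSeqChunkDataModule

dm = ActionSeqChunkDataModule(
    data_path="/path/to/demos",            # placeholder
    pcd_downsample_points=4096,            # placeholder
    pcd_x_range=(-1.0, 2.0),               # placeholder workspace bounds
    pcd_y_range=(-1.0, 1.0),
    pcd_z_range=(0.0, 2.0),
    mobile_base_vel_action_min=(-0.5, -0.5, -0.5),
    mobile_base_vel_action_max=(0.5, 0.5, 0.5),
    obs_window_size=2,
    action_prediction_horizon=8,
    batch_size=32,
    val_batch_size=32,
    val_split_ratio=0.1,
    dataloader_num_workers=8,
    seed=42,
)
# dm.setup("fit") would build ActionSeqChunkDataset and split it into train/val.
```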
/brs_algo/learning/module/__init__.py:
--------------------------------------------------------------------------------
1 | from .diffusion_module import DiffusionModule
2 |
--------------------------------------------------------------------------------
/brs_algo/learning/module/base.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from pytorch_lightning import LightningModule
3 |
4 |
5 | class ImitationBaseModule(LightningModule):
6 | """
7 | Base class for IL algorithms that require 1) an environment, 2) a policy, and 3) rollout evaluation.
8 | """
9 |
10 | def __init__(self, *args, **kwargs):
11 | super().__init__(*args, **kwargs)
12 |
13 | def training_step(self, *args, **kwargs):
14 | self.on_train_start()
15 | loss, log_dict, batch_size = self.imitation_training_step(*args, **kwargs)
16 | log_dict = {f"train/{k}": v for k, v in log_dict.items()}
17 | log_dict["train/loss"] = loss
18 | self.log_dict(
19 | log_dict,
20 | prog_bar=True,
21 | on_step=False,
22 | on_epoch=True,
23 | batch_size=batch_size,
24 | )
25 | return loss
26 |
27 | def validation_step(self, *args, **kwargs):
28 | loss, log_dict, real_batch_size = self.imitation_test_step(*args, **kwargs)
29 | log_dict = {f"val/{k}": v for k, v in log_dict.items()}
30 | log_dict["val/loss"] = loss
31 | self.log_dict(
32 | log_dict,
33 | prog_bar=True,
34 | on_step=False,
35 | on_epoch=True,
36 | batch_size=real_batch_size,
37 | )
38 | return log_dict
39 |
40 | def test_step(self, *args, **kwargs):
41 | loss, log_dict, real_batch_size = self.imitation_test_step(*args, **kwargs)
42 | log_dict = {f"test/{k}": v for k, v in log_dict.items()}
43 | log_dict["test/loss"] = loss
44 | self.log_dict(
45 | log_dict,
46 | prog_bar=True,
47 | on_step=False,
48 | on_epoch=True,
49 | batch_size=real_batch_size,
50 | )
51 | return log_dict
52 |
53 | def configure_optimizers(self):
54 | """
55 | Get optimizers, which are subsequently used to train.
56 | """
57 | raise NotImplementedError
58 |
59 | def imitation_training_step(self, *args, **kwargs) -> Any:
60 | raise NotImplementedError
61 |
62 | def imitation_test_step(self, *args, **kwargs) -> Any:
63 | raise NotImplementedError
64 |
--------------------------------------------------------------------------------
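A minimal, hypothetical subclass showing the contract `ImitationBaseModule` expects: `imitation_training_step` / `imitation_test_step` return a `(loss, log_dict, batch_size)` tuple, and the base class prefixes and logs the metrics:

```python
import torch
import torch.nn as nn
from brs_algo.learning.module.base import ImitationBaseModule


class ToyBCModule(ImitationBaseModule):
    """Toy behavior-cloning module; illustration only, not part of the repo."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 2)

    def imitation_training_step(self, batch, batch_idx):
        obs, gt_action = batch
        loss = nn.functional.mse_loss(self.net(obs), gt_action)
        return loss, {"bc_mse": loss.detach()}, obs.shape[0]

    def imitation_test_step(self, batch, batch_idx):
        return self.imitation_training_step(batch, batch_idx)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```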
/brs_algo/learning/module/diffusion_module.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Any, Optional
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from hydra.utils import instantiate
7 | from omegaconf import DictConfig
8 |
9 | import brs_algo.utils as U
10 | from brs_algo.optim import CosineScheduleFunction
11 | from brs_algo.learning.policy.base import BasePolicy
12 | from brs_algo.learning.module.base import ImitationBaseModule
13 |
14 |
15 | class DiffusionModule(ImitationBaseModule):
16 | def __init__(
17 | self,
18 | *,
19 | # ====== policy ======
20 | policy: Union[BasePolicy, DictConfig],
21 | action_prediction_horizon: int,
22 | # ====== learning ======
23 | lr: float,
24 | use_cosine_lr: bool = False,
25 | lr_warmup_steps: Optional[int] = None,
26 | lr_cosine_steps: Optional[int] = None,
27 | lr_cosine_min: Optional[float] = None,
28 | lr_layer_decay: float = 1.0,
29 | weight_decay: float = 0.0,
30 | action_keys: List[str],
31 | loss_on_latest_obs_only: bool = False,
32 | ):
33 | super().__init__()
34 | if isinstance(policy, DictConfig):
35 | policy = instantiate(policy)
36 | self.policy = policy
37 | self._action_keys = action_keys
38 | self.action_prediction_horizon = action_prediction_horizon
39 | self.lr = lr
40 | self.use_cosine_lr = use_cosine_lr
41 | self.lr_warmup_steps = lr_warmup_steps
42 | self.lr_cosine_steps = lr_cosine_steps
43 | self.lr_cosine_min = lr_cosine_min
44 | self.lr_layer_decay = lr_layer_decay
45 | self.weight_decay = weight_decay
46 | self.loss_on_latest_obs_only = loss_on_latest_obs_only
47 |
48 | def imitation_training_step(self, *args, **kwargs) -> Any:
49 | return self.imitation_training_step_seq_policy(*args, **kwargs)
50 |
51 | def imitation_test_step(self, *args, **kwargs):
52 | return self.imitation_val_step_seq_policy(*args, **kwargs)
53 |
54 | def imitation_training_step_seq_policy(self, batch, batch_idx):
55 | B = U.get_batch_size(
56 | U.any_slice(batch["action_chunks"], np.s_[0]),
57 | strict=True,
58 | )
59 | # obs data is dict of (N_chunks, B, window_size, ...)
60 | # action chunks is (N_chunks, B, window_size, action_prediction_horizon, A)
61 | # we loop over chunk dim
62 | main_data = U.unstack_sequence_fields(
63 | batch, batch_size=U.get_batch_size(batch, strict=True)
64 | )
65 | all_loss, all_mask_sum = [], 0
66 | for i, main_data_chunk in enumerate(main_data):
67 | # get padding mask
68 | pad_mask = main_data_chunk.pop(
69 | "pad_mask"
70 | ) # (B, window_size, L_pred_horizon)
71 | action_chunks = main_data_chunk.pop(
72 | "action_chunks"
73 | ) # dict of (B, window_size, L_pred_horizon, A)
74 | gt_actions = torch.cat(
75 | [action_chunks[k] for k in self._action_keys], dim=-1
76 | )
77 | transformer_output = self.policy(
78 | main_data_chunk
79 | ) # (B, L, E), where L is interleaved time and modality tokens
80 | loss = self.policy.compute_loss(
81 | transformer_output=transformer_output,
82 | gt_action=gt_actions,
83 | ) # (B, T_obs, T_act)
84 | if self.loss_on_latest_obs_only:
85 | mask = torch.zeros_like(pad_mask)
86 | mask[:, -1] = 1
87 | pad_mask = pad_mask * mask
88 | loss = loss * pad_mask
89 | all_loss.append(loss)
90 | all_mask_sum += pad_mask.sum()
91 | action_loss = torch.sum(torch.stack(all_loss)) / all_mask_sum
92 | # sum over action_prediction_horizon dim instead of avg
93 | action_loss = action_loss * self.action_prediction_horizon
94 | log_dict = {"diffusion_loss": action_loss}
95 | loss = action_loss
96 | return loss, log_dict, B
97 |
98 | def imitation_val_step_seq_policy(self, batch, batch_idx):
99 | """
100 | Denoise as if in a rollout,
101 | but without environment interaction.
102 | """
103 | B = U.get_batch_size(
104 | U.any_slice(batch["action_chunks"], np.s_[0]),
105 | strict=True,
106 | )
107 | # obs data is dict of (N_chunks, B, window_size, ...)
108 | # action chunks is (N_chunks, B, window_size, action_prediction_horizon, A)
109 | # we loop over chunk dim
110 | main_data = U.unstack_sequence_fields(
111 | batch, batch_size=U.get_batch_size(batch, strict=True)
112 | )
113 | all_l1, all_mask_sum = {}, 0
114 | for i, main_data_chunk in enumerate(main_data):
115 | # get padding mask
116 | pad_mask = main_data_chunk.pop(
117 | "pad_mask"
118 | ) # (B, window_size, L_pred_horizon)
119 | gt_actions = main_data_chunk.pop(
120 | "action_chunks"
121 | ) # dict of (B, window_size, L_pred_horizon, A)
122 | transformer_output = self.policy(
123 | main_data_chunk
124 | ) # (B, L, E), where L is interleaved time and modality tokens
125 | pred_actions = self.policy.inference(
126 | transformer_output=transformer_output,
127 | return_last_timestep_only=False,
128 | ) # dict of (B, window_size, L_pred_horizon, A)
129 | for action_k in pred_actions:
130 | pred = pred_actions[action_k]
131 | gt = gt_actions[action_k]
132 | l1 = F.l1_loss(
133 | pred, gt, reduction="none"
134 | ) # (B, window_size, L_pred_horizon, A)
135 | # sum over action dim
136 | l1 = l1.sum(dim=-1).reshape(
137 | pad_mask.shape
138 | ) # (B, window_size, L_pred_horizon)
139 | if self.loss_on_latest_obs_only:
140 | mask = torch.zeros_like(pad_mask)
141 | mask[:, -1] = 1
142 | pad_mask = pad_mask * mask
143 | l1 = l1 * pad_mask
144 | if action_k not in all_l1:
145 | all_l1[action_k] = [
146 | l1,
147 | ]
148 | else:
149 | all_l1[action_k].append(l1)
150 | all_mask_sum += pad_mask.sum()
151 | # avg on chunks dim, batch dim, and obs window dim so we can compare under different training settings
152 | all_l1 = {
153 | k: torch.sum(torch.stack(v)) / all_mask_sum for k, v in all_l1.items()
154 | }
155 | all_l1 = {k: v * self.action_prediction_horizon for k, v in all_l1.items()}
156 | summed_l1 = sum(all_l1.values())
157 | all_l1 = {f"l1_{k}": v for k, v in all_l1.items()}
158 | all_l1["l1"] = summed_l1
159 | return summed_l1, all_l1, B
160 |
161 | def configure_optimizers(self):
162 | optimizer_groups = self.policy.get_optimizer_groups(
163 | weight_decay=self.weight_decay,
164 | lr_layer_decay=self.lr_layer_decay,
165 | lr_scale=1.0,
166 | )
167 |
168 | optimizer = torch.optim.AdamW(
169 | optimizer_groups,
170 | lr=self.lr,
171 | weight_decay=self.weight_decay,
172 | )
173 |
174 | if self.use_cosine_lr:
175 | scheduler_kwargs = dict(
176 | base_value=1.0, # anneal from the original LR value
177 | final_value=self.lr_cosine_min / self.lr,
178 | epochs=self.lr_cosine_steps,
179 | warmup_start_value=self.lr_cosine_min / self.lr,
180 | warmup_epochs=self.lr_warmup_steps,
181 | steps_per_epoch=1,
182 | )
183 | scheduler = torch.optim.lr_scheduler.LambdaLR(
184 | optimizer=optimizer,
185 | lr_lambda=CosineScheduleFunction(**scheduler_kwargs),
186 | )
187 | return (
188 | [optimizer],
189 | [{"scheduler": scheduler, "interval": "step"}],
190 | )
191 |
192 | return optimizer
193 |
--------------------------------------------------------------------------------
/brs_algo/learning/nn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/brs_algo/learning/nn/__init__.py
--------------------------------------------------------------------------------
/brs_algo/learning/nn/common/__init__.py:
--------------------------------------------------------------------------------
1 | from .mlp import build_mlp, MLP
2 |
--------------------------------------------------------------------------------
/brs_algo/learning/nn/common/misc.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/brs_algo/learning/nn/common/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import Callable, Literal
3 |
4 |
5 | def get_activation(activation: str | Callable | None) -> Callable:
6 | if not activation:
7 | return nn.Identity
8 | elif callable(activation):
9 | return activation
10 | ACT_LAYER = {
11 | "tanh": nn.Tanh,
12 | "relu": lambda: nn.ReLU(inplace=True),
13 | "leaky_relu": lambda: nn.LeakyReLU(inplace=True),
14 | "swish": lambda: nn.SiLU(inplace=True), # SiLU is alias for Swish
15 | "sigmoid": nn.Sigmoid,
16 | "elu": lambda: nn.ELU(inplace=True),
17 | "gelu": nn.GELU,
18 | }
19 | activation = activation.lower()
20 | assert activation in ACT_LAYER, f"Supported activations: {ACT_LAYER.keys()}"
21 | return ACT_LAYER[activation]
22 |
23 |
24 | def get_initializer(method: str | Callable, activation: str) -> Callable:
25 | if isinstance(method, str):
26 | assert hasattr(
27 | nn.init, f"{method}_"
28 | ), f"Initializer nn.init.{method}_ does not exist"
29 | if method == "orthogonal":
30 | try:
31 | gain = nn.init.calculate_gain(activation)
32 | except ValueError:
33 | gain = 1.0
34 | return lambda x: nn.init.orthogonal_(x, gain=gain)
35 | else:
36 | return getattr(nn.init, f"{method}_")
37 | else:
38 | assert callable(method)
39 | return method
40 |
41 |
42 | def build_mlp(
43 | input_dim,
44 | *,
45 | hidden_dim: int,
46 | output_dim: int,
47 | hidden_depth: int = None,
48 | num_layers: int = None,
49 | activation: str | Callable = "relu",
50 | weight_init: str | Callable = "orthogonal",
51 | bias_init="zeros",
52 | norm_type: Literal["batchnorm", "layernorm"] | None = None,
53 | add_input_activation: bool | str | Callable = False,
54 | add_input_norm: bool = False,
55 | add_output_activation: bool | str | Callable = False,
56 | add_output_norm: bool = False,
57 | ) -> nn.Sequential:
58 | """
59 | In other popular RL implementations, tanh is typically used with orthogonal
60 | initialization, which may perform better than ReLU.
61 |
62 | Args:
63 | norm_type: None, "batchnorm", "layernorm", applied to intermediate layers
64 | add_input_activation: whether to add a nonlinearity to the input _before_
65 | the MLP computation. This is useful for processing a feature from a preceding
66 | image encoder, for example. An image encoder typically has a linear layer
67 | at the end, and we don't want the MLP to immediately stack another linear
68 | layer on the input features.
69 | - True to add the same activation as the rest of the MLP
70 | - str to add an activation of a different type.
71 | add_input_norm: see `add_input_activation`, whether to add a normalization layer
72 | to the input _before_ the MLP computation.
73 | values: True to add the `norm_type` to the input
74 | add_output_activation: whether to add a nonlinearity to the output _after_ the
75 | MLP computation.
76 | - True to add the same activation as the rest of the MLP
77 | - str to add an activation of a different type.
78 | add_output_norm: see `add_output_activation`, whether to add a normalization layer
79 | _after_ the MLP computation.
80 | values: True to add the `norm_type` to the output
81 | """
82 | assert (hidden_depth is None) != (num_layers is None), (
83 | "Either hidden_depth or num_layers must be specified, but not both. "
84 | "num_layers is defined as hidden_depth+1"
85 | )
86 | if hidden_depth is not None:
87 | assert hidden_depth >= 0
88 | if num_layers is not None:
89 | assert num_layers >= 1
90 | act_layer = get_activation(activation)
91 |
92 | weight_init = get_initializer(weight_init, activation)
93 | bias_init = get_initializer(bias_init, activation)
94 |
95 | if norm_type is not None:
96 | norm_type = norm_type.lower()
97 |
98 | if not norm_type:
99 | norm_type = nn.Identity
100 | elif norm_type == "batchnorm":
101 | norm_type = nn.BatchNorm1d
102 | elif norm_type == "layernorm":
103 | norm_type = nn.LayerNorm
104 | else:
105 | raise ValueError(f"Unsupported norm layer: {norm_type}")
106 |
107 | hidden_depth = num_layers - 1 if hidden_depth is None else hidden_depth
108 | if hidden_depth == 0:
109 | mods = [nn.Linear(input_dim, output_dim)]
110 | else:
111 | mods = [nn.Linear(input_dim, hidden_dim), norm_type(hidden_dim), act_layer()]
112 | for i in range(hidden_depth - 1):
113 | mods += [
114 | nn.Linear(hidden_dim, hidden_dim),
115 | norm_type(hidden_dim),
116 | act_layer(),
117 | ]
118 | mods.append(nn.Linear(hidden_dim, output_dim))
119 |
120 | if add_input_norm:
121 | mods = [norm_type(input_dim)] + mods
122 | if add_input_activation:
123 | if add_input_activation is not True:
124 | act_layer = get_activation(add_input_activation)
125 | mods = [act_layer()] + mods
126 | if add_output_norm:
127 | mods.append(norm_type(output_dim))
128 | if add_output_activation:
129 | if add_output_activation is not True:
130 | act_layer = get_activation(add_output_activation)
131 | mods.append(act_layer())
132 |
133 | for mod in mods:
134 | if isinstance(mod, nn.Linear):
135 | weight_init(mod.weight)
136 | bias_init(mod.bias)
137 |
138 | return nn.Sequential(*mods)
139 |
140 |
141 | class MLP(nn.Module):
142 | def __init__(
143 | self,
144 | input_dim,
145 | *,
146 | hidden_dim: int,
147 | output_dim: int,
148 | hidden_depth: int = None,
149 | num_layers: int = None,
150 | activation: str | Callable = "relu",
151 | weight_init: str | Callable = "orthogonal",
152 | bias_init="zeros",
153 | norm_type: Literal["batchnorm", "layernorm"] | None = None,
154 | add_input_activation: bool | str | Callable = False,
155 | add_input_norm: bool = False,
156 | add_output_activation: bool | str | Callable = False,
157 | add_output_norm: bool = False,
158 | ):
159 | super().__init__()
160 | # delegate to build_mlp by keywords
161 | self.layers = build_mlp(
162 | input_dim,
163 | hidden_dim=hidden_dim,
164 | output_dim=output_dim,
165 | hidden_depth=hidden_depth,
166 | num_layers=num_layers,
167 | activation=activation,
168 | weight_init=weight_init,
169 | bias_init=bias_init,
170 | norm_type=norm_type,
171 | add_input_activation=add_input_activation,
172 | add_input_norm=add_input_norm,
173 | add_output_activation=add_output_activation,
174 | add_output_norm=add_output_norm,
175 | )
176 | # add attributes to the class
177 | self.input_dim = input_dim
178 | self.output_dim = output_dim
179 | self.hidden_depth = hidden_depth
180 | self.activation = activation
181 | self.weight_init = weight_init
182 | self.bias_init = bias_init
183 | self.norm_type = norm_type
184 | if add_input_activation is True:
185 | self.input_activation = activation
186 | else:
187 | self.input_activation = add_input_activation
188 | if add_input_norm is True:
189 | self.input_norm_type = norm_type
190 | else:
191 | self.input_norm_type = None
192 | # do the same for output activation and norm
193 | if add_output_activation is True:
194 | self.output_activation = activation
195 | else:
196 | self.output_activation = add_output_activation
197 | if add_output_norm is True:
198 | self.output_norm_type = norm_type
199 | else:
200 | self.output_norm_type = None
201 |
202 | def forward(self, x):
203 | return self.layers(x)
204 |
--------------------------------------------------------------------------------
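A short usage sketch for `build_mlp` (dimensions are illustrative): exactly one of `hidden_depth` and `num_layers` may be given, with `num_layers = hidden_depth + 1`:

```python
import torch
from brs_algo.learning.nn.common import build_mlp

# 2 hidden layers (3 Linear layers total) with LayerNorm + GELU in between.
mlp = build_mlp(
    64,
    hidden_dim=256,
    output_dim=10,
    hidden_depth=2,        # equivalently: num_layers=3
    activation="gelu",
    norm_type="layernorm",
)
y = mlp(torch.randn(32, 64))  # -> shape (32, 10)
```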
/brs_algo/learning/nn/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .diffusion_head import UNetDiffusionHead, WholeBodyUNetDiffusionHead
2 | from .unet import ConditionalUnet1D
3 |
--------------------------------------------------------------------------------
/brs_algo/learning/nn/diffusion/unet.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import einops
6 | from einops.layers.torch import Rearrange
7 |
8 | from brs_algo.optim import default_optimizer_groups
9 |
10 |
11 | class Downsample1d(nn.Module):
12 | def __init__(self, dim):
13 | super().__init__()
14 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
15 |
16 | def forward(self, x):
17 | return self.conv(x)
18 |
19 |
20 | class Upsample1d(nn.Module):
21 | def __init__(self, dim):
22 | super().__init__()
23 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
24 |
25 | def forward(self, x):
26 | return self.conv(x)
27 |
28 |
29 | class SinusoidalPosEmb(nn.Module):
30 | def __init__(self, dim):
31 | super().__init__()
32 | self.dim = dim
33 |
34 | def forward(self, x):
35 | device = x.device
36 | half_dim = self.dim // 2
37 | emb = math.log(10000) / (half_dim - 1)
38 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
39 | emb = x[:, None] * emb[None, :]
40 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
41 | return emb
42 |
43 |
44 | class Conv1dBlock(nn.Module):
45 | """
46 | Conv1d --> GroupNorm --> Mish
47 | """
48 |
49 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
50 | super().__init__()
51 |
52 | self.block = nn.Sequential(
53 | nn.Conv1d(
54 | inp_channels, out_channels, kernel_size, padding=kernel_size // 2
55 | ),
56 | # Rearrange('batch channels horizon -> batch channels 1 horizon'),
57 | nn.GroupNorm(n_groups, out_channels),
58 | # Rearrange('batch channels 1 horizon -> batch channels horizon'),
59 | nn.Mish(),
60 | )
61 |
62 | def forward(self, x):
63 | return self.block(x)
64 |
65 |
66 | class ConditionalResidualBlock1D(nn.Module):
67 | def __init__(
68 | self,
69 | in_channels,
70 | out_channels,
71 | cond_dim,
72 | kernel_size=3,
73 | n_groups=8,
74 | cond_predict_scale=False,
75 | ):
76 | super().__init__()
77 |
78 | self.blocks = nn.ModuleList(
79 | [
80 | Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups),
81 | Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups),
82 | ]
83 | )
84 |
85 | # FiLM modulation https://arxiv.org/abs/1709.07871
86 | # predicts per-channel scale and bias
87 | cond_channels = out_channels
88 | if cond_predict_scale:
89 | cond_channels = out_channels * 2
90 | self.cond_predict_scale = cond_predict_scale
91 | self.out_channels = out_channels
92 | self.cond_encoder = nn.Sequential(
93 | nn.Mish(),
94 | nn.Linear(cond_dim, cond_channels),
95 | Rearrange("batch t -> batch t 1"),
96 | )
97 |
98 | # make sure dimensions compatible
99 | self.residual_conv = (
100 | nn.Conv1d(in_channels, out_channels, 1)
101 | if in_channels != out_channels
102 | else nn.Identity()
103 | )
104 |
105 | def forward(self, x, cond):
106 | """
107 | x : [ batch_size x in_channels x horizon ]
108 | cond : [ batch_size x cond_dim]
109 |
110 | returns:
111 | out : [ batch_size x out_channels x horizon ]
112 | """
113 | out = self.blocks[0](x)
114 | embed = self.cond_encoder(cond)
115 | if self.cond_predict_scale:
116 | embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
117 | scale = embed[:, 0, ...]
118 | bias = embed[:, 1, ...]
119 | out = scale * out + bias
120 | else:
121 | out = out + embed
122 | out = self.blocks[1](out)
123 | out = out + self.residual_conv(x)
124 | return out
125 |
126 |
127 | class ConditionalUnet1D(nn.Module):
128 | def __init__(
129 | self,
130 | input_dim,
131 | *,
132 | local_cond_dim=None,
133 | global_cond_dim=None,
134 | diffusion_step_embed_dim=256,
135 | down_dims=[256, 512, 1024],
136 | kernel_size=3,
137 | n_groups=8,
138 | cond_predict_scale=False,
139 | ):
140 | super().__init__()
141 | all_dims = [input_dim] + list(down_dims)
142 | start_dim = down_dims[0]
143 |
144 | dsed = diffusion_step_embed_dim
145 | diffusion_step_encoder = nn.Sequential(
146 | SinusoidalPosEmb(dsed),
147 | nn.Linear(dsed, dsed * 4),
148 | nn.Mish(),
149 | nn.Linear(dsed * 4, dsed),
150 | )
151 | cond_dim = dsed
152 | if global_cond_dim is not None:
153 | cond_dim += global_cond_dim
154 |
155 | in_out = list(zip(all_dims[:-1], all_dims[1:]))
156 |
157 | local_cond_encoder = None
158 | if local_cond_dim is not None:
159 | _, dim_out = in_out[0]
160 | dim_in = local_cond_dim
161 | local_cond_encoder = nn.ModuleList(
162 | [
163 | # down encoder
164 | ConditionalResidualBlock1D(
165 | dim_in,
166 | dim_out,
167 | cond_dim=cond_dim,
168 | kernel_size=kernel_size,
169 | n_groups=n_groups,
170 | cond_predict_scale=cond_predict_scale,
171 | ),
172 | # up encoder
173 | ConditionalResidualBlock1D(
174 | dim_in,
175 | dim_out,
176 | cond_dim=cond_dim,
177 | kernel_size=kernel_size,
178 | n_groups=n_groups,
179 | cond_predict_scale=cond_predict_scale,
180 | ),
181 | ]
182 | )
183 |
184 | mid_dim = all_dims[-1]
185 | self.mid_modules = nn.ModuleList(
186 | [
187 | ConditionalResidualBlock1D(
188 | mid_dim,
189 | mid_dim,
190 | cond_dim=cond_dim,
191 | kernel_size=kernel_size,
192 | n_groups=n_groups,
193 | cond_predict_scale=cond_predict_scale,
194 | ),
195 | ConditionalResidualBlock1D(
196 | mid_dim,
197 | mid_dim,
198 | cond_dim=cond_dim,
199 | kernel_size=kernel_size,
200 | n_groups=n_groups,
201 | cond_predict_scale=cond_predict_scale,
202 | ),
203 | ]
204 | )
205 |
206 | down_modules = nn.ModuleList([])
207 | for ind, (dim_in, dim_out) in enumerate(in_out):
208 | is_last = ind >= (len(in_out) - 1)
209 | down_modules.append(
210 | nn.ModuleList(
211 | [
212 | ConditionalResidualBlock1D(
213 | dim_in,
214 | dim_out,
215 | cond_dim=cond_dim,
216 | kernel_size=kernel_size,
217 | n_groups=n_groups,
218 | cond_predict_scale=cond_predict_scale,
219 | ),
220 | ConditionalResidualBlock1D(
221 | dim_out,
222 | dim_out,
223 | cond_dim=cond_dim,
224 | kernel_size=kernel_size,
225 | n_groups=n_groups,
226 | cond_predict_scale=cond_predict_scale,
227 | ),
228 | Downsample1d(dim_out) if not is_last else nn.Identity(),
229 | ]
230 | )
231 | )
232 |
233 | up_modules = nn.ModuleList([])
234 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
235 | is_last = ind >= (len(in_out) - 1)
236 | up_modules.append(
237 | nn.ModuleList(
238 | [
239 | ConditionalResidualBlock1D(
240 | dim_out * 2,
241 | dim_in,
242 | cond_dim=cond_dim,
243 | kernel_size=kernel_size,
244 | n_groups=n_groups,
245 | cond_predict_scale=cond_predict_scale,
246 | ),
247 | ConditionalResidualBlock1D(
248 | dim_in,
249 | dim_in,
250 | cond_dim=cond_dim,
251 | kernel_size=kernel_size,
252 | n_groups=n_groups,
253 | cond_predict_scale=cond_predict_scale,
254 | ),
255 | Upsample1d(dim_in) if not is_last else nn.Identity(),
256 | ]
257 | )
258 | )
259 |
260 | final_conv = nn.Sequential(
261 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
262 | nn.Conv1d(start_dim, input_dim, 1),
263 | )
264 |
265 | self.diffusion_step_encoder = diffusion_step_encoder
266 | self.local_cond_encoder = local_cond_encoder
267 | self.up_modules = up_modules
268 | self.down_modules = down_modules
269 | self.final_conv = final_conv
270 |
271 | def forward(
272 | self,
273 | sample: torch.Tensor,
274 | timestep: Union[torch.Tensor, float, int],
275 | local_cond=None,
276 | global_cond=None,
277 | ):
278 | """
279 | sample: (B,T,input_dim)
280 | timestep: (B,) or int, diffusion step
281 | local_cond: (B,T,local_cond_dim)
282 | global_cond: (B,global_cond_dim)
283 | output: (B,T,input_dim)
284 | """
285 | sample = einops.rearrange(sample, "b h t -> b t h")
286 |
287 | # 1. time
288 | timesteps = timestep
289 | if len(timesteps.shape) == 0:
290 | timesteps = timesteps[None].to(sample.device)
291 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
292 | timesteps = timesteps.expand(sample.shape[0])
293 |
294 | global_feature = self.diffusion_step_encoder(timesteps)
295 |
296 | if global_cond is not None:
297 | global_feature = torch.cat([global_feature, global_cond], axis=-1)
298 |
299 | # encode local features
300 | h_local = list()
301 | if local_cond is not None:
302 | local_cond = einops.rearrange(local_cond, "b h t -> b t h")
303 | resnet, resnet2 = self.local_cond_encoder
304 | x = resnet(local_cond, global_feature)
305 | h_local.append(x)
306 | x = resnet2(local_cond, global_feature)
307 | h_local.append(x)
308 |
309 | x = sample
310 | h = []
311 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
312 | x = resnet(x, global_feature)
313 | if idx == 0 and len(h_local) > 0:
314 | x = x + h_local[0]
315 | x = resnet2(x, global_feature)
316 | h.append(x)
317 | x = downsample(x)
318 |
319 | for mid_module in self.mid_modules:
320 | x = mid_module(x, global_feature)
321 |
322 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
323 | x = torch.cat((x, h.pop()), dim=1)
324 | x = resnet(x, global_feature)
325 | x = resnet2(x, global_feature)
326 | x = upsample(x)
327 |
328 | # apply local condition here after x being fully upsampled
329 | if len(h_local) > 0:
330 | x = x + h_local[1]
331 | x = self.final_conv(x)
332 |
333 | x = einops.rearrange(x, "b t h -> b h t")
334 | return x
335 |
336 | def get_optimizer_groups(self, weight_decay, lr_layer_decay, lr_scale=1.0):
337 | return default_optimizer_groups(
338 | self,
339 | weight_decay=weight_decay,
340 | lr_scale=lr_scale,
341 | )
342 |
--------------------------------------------------------------------------------
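A shape-level sketch of one denoising call to `ConditionalUnet1D` (all dimensions below are hypothetical): `sample` is (B, T, input_dim), `timestep` is an int or a (B,) tensor, and `global_cond` is a per-batch conditioning vector; with the default `down_dims` the horizon T passes through two stride-2 downsamples, so it should be divisible by 4:

```python
import torch
from brs_algo.learning.nn.diffusion import ConditionalUnet1D

unet = ConditionalUnet1D(
    input_dim=23,          # hypothetical whole-body action dimension
    global_cond_dim=256,   # hypothetical observation feature dimension
)
noisy_actions = torch.randn(4, 16, 23)   # (B, T, input_dim)
timesteps = torch.randint(0, 100, (4,))  # (B,) diffusion steps
obs_feature = torch.randn(4, 256)        # (B, global_cond_dim)
pred = unet(noisy_actions, timesteps, global_cond=obs_feature)  # (4, 16, 23)
```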
/brs_algo/learning/nn/features/__init__.py:
--------------------------------------------------------------------------------
1 | from .fusion import ObsTokenizer
2 | from .pointnet import PointNet
3 |
--------------------------------------------------------------------------------
/brs_algo/learning/nn/features/fusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from einops import rearrange
5 |
6 | import brs_algo.utils as U
7 | from brs_algo.optim import default_optimizer_groups
8 |
9 |
10 | class ObsTokenizer(nn.Module):
11 | def __init__(
12 | self,
13 | extractors: dict[str, nn.Module],
14 | *,
15 | use_modality_type_tokens: bool,
16 | token_dim: int,
17 | token_concat_order: list[str],
18 | strict: bool = True,
19 | ):
20 | assert set(extractors.keys()) == set(token_concat_order)
21 | super().__init__()
22 | self._extractors = nn.ModuleDict(extractors)
23 | self.output_dim = token_dim
24 | self._token_concat_order = token_concat_order
25 | self._strict = strict
26 | self._obs_groups = None
27 | self._use_modality_type_tokens = use_modality_type_tokens
28 | self._modality_type_tokens = None
29 | if use_modality_type_tokens:
30 | modality_type_tokens = {}
31 | for k in extractors:
32 | modality_type_tokens[k] = nn.Parameter(torch.zeros(token_dim))
33 | self._modality_type_tokens = nn.ParameterDict(modality_type_tokens)
34 |
35 | def forward(self, obs: dict[str, torch.Tensor]):
36 | """
37 | obs: Dict of (B, T, ...)
38 |
39 | Each encoder should encode corresponding obs field to (B, T, E), where E = token_dim
40 |
41 | The final output interleaves encoded tokens along the time dimension
42 | """
43 | obs = self._group_obs(obs)
44 | self._check_obs_key_match(obs)
45 | x = {
46 | k: v.forward(obs[k]) for k, v in self._extractors.items()
47 | } # dict of (B, T, E)
48 | if self._use_modality_type_tokens:
49 | for k in x:
50 | x[k] = x[k] + self._modality_type_tokens[k]
51 | x = rearrange(
52 | [x[k] for k in self._token_concat_order],
53 | "F B T E -> B (T F) E",
54 | )
55 | self._check_output_shape(obs, x)
56 | return x
57 |
58 | def _group_obs(self, obs):
59 | obs_keys = obs.keys()
60 | if self._obs_groups is None:
61 | # group by /
62 | obs_groups = {k.split("/")[0] for k in obs_keys}
63 | self._obs_groups = sorted(list(obs_groups))
64 | obs_rtn = {}
65 | for g in self._obs_groups:
66 | is_subgroup = any(k.startswith(f"{g}/") for k in obs_keys)
67 | if is_subgroup:
68 | obs_rtn[g] = {
69 | k.split("/", 1)[1]: v
70 | for k, v in obs.items()
71 | if k.startswith(f"{g}/")
72 | }
73 | else:
74 | obs_rtn[g] = obs[g]
75 | return obs_rtn
76 |
77 | @U.call_once
78 | def _check_obs_key_match(self, obs: dict):
79 | if self._strict:
80 | assert set(self._extractors.keys()) == set(obs.keys())
81 | elif set(self._extractors.keys()) != set(obs.keys()):
82 | print(
83 | U.color_text(
84 | f"[warning] obs key mismatch: {set(self._extractors.keys())} != {set(obs.keys())}",
85 | "yellow",
86 | )
87 | )
88 |
89 | @U.call_once
90 | def _check_output_shape(self, obs, output):
91 | T = U.get_batch_size(U.any_slice(obs, np.s_[0]), strict=True)
92 | U.check_shape(
93 | output, (None, T * len(self._token_concat_order), self.output_dim)
94 | )
95 |
96 | def get_optimizer_groups(self, weight_decay, lr_layer_decay, lr_scale=1.0):
97 | pg, pid = default_optimizer_groups(
98 | self,
99 | weight_decay=weight_decay,
100 | lr_scale=lr_scale,
101 | no_decay_filter=[
102 | "_extractors.*",
103 | "_modality_type_tokens.*",
104 | ],
105 | )
106 | return pg, pid
107 |
108 | @property
109 | def num_tokens_per_step(self):
110 | return len(self._token_concat_order)
111 |
--------------------------------------------------------------------------------
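A hedged sketch of `ObsTokenizer` (modality names and dimensions are made up): with F modalities and T timesteps the output is (B, T * F, E), tokens interleaved per timestep in `token_concat_order`:

```python
import torch
import torch.nn as nn
from brs_algo.learning.nn.features import ObsTokenizer

E = 128
tokenizer = ObsTokenizer(
    {
        "proprio": nn.Linear(13, E),  # placeholder per-modality encoders,
        "odom": nn.Linear(3, E),      # each mapping (B, T, D) -> (B, T, E)
    },
    use_modality_type_tokens=True,
    token_dim=E,
    token_concat_order=["proprio", "odom"],
)
obs = {"proprio": torch.randn(2, 4, 13), "odom": torch.randn(2, 4, 3)}
tokens = tokenizer(obs)  # (2, 8, 128): [proprio_t0, odom_t0, proprio_t1, ...]
```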
/brs_algo/learning/nn/features/pointnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import brs_algo.utils as U
5 | from brs_algo.learning.nn.common import build_mlp
6 | from brs_algo.optim import default_optimizer_groups
7 |
8 |
9 | class PointNetCore(nn.Module):
10 | def __init__(
11 | self,
12 | *,
13 | point_channels: int = 3,
14 | output_dim: int,
15 | hidden_dim: int,
16 | hidden_depth: int,
17 | activation: str = "gelu",
18 | ):
19 | super().__init__()
20 | self._mlp = build_mlp(
21 | input_dim=point_channels,
22 | hidden_dim=hidden_dim,
23 | output_dim=output_dim,
24 | hidden_depth=hidden_depth,
25 | activation=activation,
26 | )
27 | self.output_dim = output_dim
28 |
29 | def forward(self, x):
30 | """
31 | x: (..., points, point_channels)
32 | """
33 | x = U.any_to_torch_tensor(x)
34 | x = self._mlp(x) # (..., points, output_dim)
35 | x = torch.max(x, dim=-2)[0] # (..., output_dim)
36 | return x
37 |
38 |
39 | class PointNet(nn.Module):
40 | def __init__(
41 | self,
42 | *,
43 | n_coordinates: int = 3,
44 | n_color: int = 3,
45 | output_dim: int = 512,
46 | hidden_dim: int = 512,
47 | hidden_depth: int = 2,
48 | activation: str = "gelu",
49 | subtract_mean: bool = False,
50 | ):
51 | super().__init__()
52 | pn_in_channels = n_coordinates + n_color
53 | if subtract_mean:
54 | pn_in_channels += n_coordinates
55 | self.pointnet = PointNetCore(
56 | point_channels=pn_in_channels,
57 | output_dim=output_dim,
58 | hidden_dim=hidden_dim,
59 | hidden_depth=hidden_depth,
60 | activation=activation,
61 | )
62 | self.subtract_mean = subtract_mean
63 | self.output_dim = self.pointnet.output_dim
64 |
65 | def forward(self, x):
66 | """
67 | x["xyz"]: (..., points, coordinates)
68 | x["rgb"]: (..., points, color)
69 | """
70 | xyz = x["xyz"]
71 | rgb = x["rgb"]
72 | point = U.any_to_torch_tensor(xyz)
73 | if self.subtract_mean:
74 | mean = torch.mean(point, dim=-2, keepdim=True) # (..., 1, coordinates)
75 | mean = torch.broadcast_to(mean, point.shape) # (..., points, coordinates)
76 | point = point - mean
77 | point = torch.cat([point, mean], dim=-1) # (..., points, 2 * coordinates)
78 | rgb = U.any_to_torch_tensor(rgb)
79 | x = torch.cat([point, rgb], dim=-1)
80 | return self.pointnet(x)
81 |
82 | def get_optimizer_groups(self, weight_decay, lr_layer_decay, lr_scale=1.0):
83 | pg, pids = default_optimizer_groups(
84 | self,
85 | weight_decay=weight_decay,
86 | lr_scale=lr_scale,
87 | no_decay_filter=["ee_embd_layer.*"],
88 | )
89 | return pg, pids
90 |
--------------------------------------------------------------------------------
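A shape sketch for `PointNet` (hypothetical point count and leading dims): it takes a dict with "xyz" and "rgb" point features and max-pools over the points dimension:

```python
import torch
from brs_algo.learning.nn.features import PointNet

encoder = PointNet(n_coordinates=3, n_color=3, output_dim=512, subtract_mean=True)
pcd = {
    "xyz": torch.rand(2, 4, 4096, 3),  # (..., points, coordinates)
    "rgb": torch.rand(2, 4, 4096, 3),  # (..., points, color)
}
feat = encoder(pcd)  # max-pooled over points -> (2, 4, 512)
```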
/brs_algo/learning/nn/gpt/__init__.py:
--------------------------------------------------------------------------------
1 | from .gpt import GPT
2 |
--------------------------------------------------------------------------------
/brs_algo/learning/policy/__init__.py:
--------------------------------------------------------------------------------
1 | from .wbvima_policy import WBVIMAPolicy
2 |
--------------------------------------------------------------------------------
/brs_algo/learning/policy/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from pytorch_lightning import LightningModule
4 |
5 |
6 | class BasePolicy(ABC, LightningModule):
7 | is_sequence_policy: bool = (
8 | False # is this a feedforward policy or a policy requiring history
9 | )
10 |
11 | @abstractmethod
12 | def forward(self, *args, **kwargs):
13 | """
14 | Forward the NN.
15 | """
16 | pass
17 |
18 | @abstractmethod
19 | def act(self, *args, **kwargs):
20 | """
21 | Given obs, return action.
22 | """
23 | pass
24 |
25 |
26 | class BaseDiffusionPolicy(BasePolicy):
27 | @abstractmethod
28 | def get_optimizer_groups(self, *args, **kwargs):
29 | """
30 | Return a list of optimizer groups.
31 | """
32 | pass
33 |
--------------------------------------------------------------------------------
/brs_algo/lightning/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer
2 |
--------------------------------------------------------------------------------
/brs_algo/lightning/lightning.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import logging
3 | import time
4 | from copy import deepcopy
5 |
6 | import sys
7 | from omegaconf import DictConfig, OmegaConf, ListConfig
8 | import pytorch_lightning as pl
9 | import pytorch_lightning.loggers as pl_loggers
10 | from pytorch_lightning.callbacks import (
11 | Callback,
12 | ProgressBar,
13 | TQDMProgressBar,
14 | RichProgressBar,
15 | )
16 | from pytorch_lightning.utilities import rank_zero_only
17 | from pytorch_lightning.utilities.rank_zero import rank_zero_debug as rank_zero_debug_pl
18 | from pytorch_lightning.utilities.rank_zero import rank_zero_info as rank_zero_info_pl
19 | from pytorch_lightning.callbacks import ModelCheckpoint
20 |
21 | import brs_algo.utils as U
22 |
23 |
24 | __all__ = [
25 | "LightingTrainer",
26 | "rank_zero_info",
27 | "rank_zero_debug",
28 | "rank_zero_warn",
29 | "rank_zero_info_pl",
30 | "rank_zero_debug_pl",
31 | ]
32 |
33 | logging.getLogger("torch.distributed.distributed_c10d").setLevel(logging.WARNING)
34 | logging.getLogger("torch.distributed.nn.jit.instantiator").setLevel(logging.WARNING)
35 | U.logging_exclude_pattern(
36 | "root", patterns="*Reducer buckets have been rebuilt in this iteration*"
37 | )
38 |
39 |
40 | class LightingTrainer:
41 | def __init__(self, cfg: DictConfig, eval_only=False):
42 | """
43 | Args:
44 | eval_only: if True, will not save any model dir
45 | """
46 | cfg = deepcopy(cfg)
47 | OmegaConf.set_struct(cfg, False)
48 | self.cfg = cfg
49 | self.run_command_args = sys.argv[1:]
50 | U.register_omegaconf_resolvers()
51 | self.register_extra_resolvers(cfg)
52 | self._register_classes(cfg)
53 | # U.pprint_(OmegaConf.to_container(cfg, resolve=True))
54 | run_name = self.generate_run_name(cfg)
55 | self.run_dir = U.f_join(cfg.exp_root_dir, run_name)
56 | # rank_zero_info(OmegaConf.to_yaml(cfg, resolve=True))
57 | self._eval_only = eval_only
58 | self._resume_mode = None # 'full state' or 'model only'
59 | if eval_only:
60 | rank_zero_info("Eval only, will not save any model dir")
61 | else:
62 | if "resume" in cfg and "ckpt_path" in cfg.resume and cfg.resume.ckpt_path:
63 | cfg.resume.ckpt_path = U.f_expand(
64 | cfg.resume.ckpt_path.replace("_RUN_DIR_", self.run_dir).replace(
65 | "_RUN_NAME_", run_name
66 | )
67 | )
68 | self._resume_mode = (
69 | "full state"
70 | if cfg.resume.get("full_state", False)
71 | else "model only"
72 | )
73 | rank_zero_info(
74 | "=" * 80,
75 | "=" * 80 + "\n",
76 | f"Resume training from {cfg.resume.ckpt_path}",
77 | f"\t({self._resume_mode})\n",
78 | "=" * 80,
79 | "=" * 80,
80 | sep="\n",
81 | end="\n\n",
82 | )
83 | time.sleep(3)
84 | assert U.f_exists(
85 | cfg.resume.ckpt_path
86 | ), "resume ckpt_path does not exist"
87 |
88 | rank_zero_print("Run name:", run_name, "\nExp dir:", self.run_dir)
89 | U.f_mkdir(self.run_dir)
90 | U.f_mkdir(U.f_join(self.run_dir, "tb"))
91 | U.f_mkdir(U.f_join(self.run_dir, "logs"))
92 | U.f_mkdir(U.f_join(self.run_dir, "ckpt"))
93 | U.omegaconf_save(cfg, self.run_dir, "conf.yaml")
94 | rank_zero_print(
95 | "Checkpoint cfg:", U.omegaconf_to_dict(cfg.trainer.checkpoint)
96 | )
97 | self.cfg = cfg
98 | self.run_name = run_name
99 | self.ckpt_cfg = cfg.trainer.pop("checkpoint")
100 | self.data_module = self.create_data_module(cfg)
101 | self._monkey_patch_add_info(self.data_module)
102 | self.trainer = self.create_trainer(cfg)
103 | self.module = self.create_module(cfg)
104 | self.module.data_module = self.data_module
105 | self._monkey_patch_add_info(self.module)
106 |
107 | if not eval_only and self._resume_mode == "model only":
108 | ret = self.module.load_state_dict(
109 | U.torch_load(cfg.resume.ckpt_path)["state_dict"],
110 | strict=cfg.resume.strict,
111 | )
112 | U.rank_zero_warn("state_dict load status:", ret)
113 |
114 | def create_module(self, cfg) -> pl.LightningModule:
115 | return U.instantiate(cfg.module)
116 |
117 | def create_data_module(self, cfg) -> pl.LightningDataModule:
118 | return U.instantiate(cfg.data_module)
119 |
120 | def generate_run_name(self, cfg):
121 | return cfg.run_name + "_" + time.strftime("%Y%m%d-%H%M%S")
122 |
123 | def _monkey_patch_add_info(self, obj):
124 | """
125 |         Add useful info to the module and data_module so they can access it directly
126 | """
127 | # our own info
128 | obj.run_config = self.cfg
129 | obj.run_name = self.run_name
130 | obj.run_command_args = self.run_command_args
131 | # add properties from trainer
132 | for attr in [
133 | "global_rank",
134 | "local_rank",
135 | "world_size",
136 | "num_nodes",
137 | "num_processes",
138 | "node_rank",
139 | "num_gpus",
140 | "data_parallel_device_ids",
141 | ]:
142 | if hasattr(obj, attr):
143 | continue
144 | setattr(
145 | obj.__class__,
146 | attr,
147 | # force capture 'attr'
148 | property(lambda self, attr=attr: getattr(self.trainer, attr)),
149 | )
150 |
151 | def create_loggers(self, cfg) -> List[pl.loggers.Logger]:
152 | if self._eval_only:
153 | loggers = []
154 | else:
155 | loggers = [
156 | pl_loggers.TensorBoardLogger(self.run_dir, name="tb", version=""),
157 | pl_loggers.CSVLogger(self.run_dir, name="logs", version=""),
158 | ]
159 | if cfg.use_wandb:
160 | loggers.append(
161 | pl_loggers.WandbLogger(
162 | name=cfg.wandb_run_name, project=cfg.wandb_project, id=self.run_name
163 | )
164 | )
165 | return loggers
166 |
167 | def create_callbacks(self, cfg) -> List[Callback]:
168 | ModelCheckpoint.FILE_EXTENSION = ".pth"
169 | callbacks = []
170 | if isinstance(self.ckpt_cfg, DictConfig):
171 | ckpt = ModelCheckpoint(
172 | dirpath=U.f_join(self.run_dir, "ckpt"), **self.ckpt_cfg
173 | )
174 | callbacks.append(ckpt)
175 | else:
176 | assert isinstance(self.ckpt_cfg, ListConfig)
177 | for _cfg in self.ckpt_cfg:
178 | ckpt = ModelCheckpoint(dirpath=U.f_join(self.run_dir, "ckpt"), **_cfg)
179 | callbacks.append(ckpt)
180 |
181 | if "callbacks" in cfg.trainer:
182 | extra_callbacks = U.instantiate(cfg.trainer.pop("callbacks"))
183 | assert U.is_sequence(extra_callbacks), "callbacks must be a sequence"
184 | callbacks.extend(extra_callbacks)
185 | if not any(isinstance(c, ProgressBar) for c in callbacks):
186 | callbacks.append(CustomTQDMProgressBar())
187 | rank_zero_print(
188 | "Lightning callbacks:", [c.__class__.__name__ for c in callbacks]
189 | )
190 | return callbacks
191 |
192 | def create_trainer(self, cfg) -> pl.Trainer:
193 | assert "trainer" in cfg
194 | C = cfg.trainer
195 | # find_unused_parameters = C.pop("find_unused_parameters", False)
196 | # rank_zero_info("DDP Strategy", C.strategy)
197 | return U.instantiate(
198 | C, logger=self.create_loggers(cfg), callbacks=self.create_callbacks(cfg)
199 | )
200 |
201 | @property
202 | def tb_logger(self):
203 |         return self.trainer.loggers[0].experiment
204 |
205 | def fit(self):
206 | return self.trainer.fit(
207 | self.module,
208 | datamodule=self.data_module,
209 | ckpt_path=(
210 | self.cfg.resume.ckpt_path if self._resume_mode == "full state" else None
211 | ),
212 | )
213 |
214 | def validate(self):
215 | rank_zero_print("Start validation ...")
216 | assert "testing" in self.cfg, "`testing` sub-dict must be defined in config"
217 | ckpt_path = self.cfg.testing.ckpt_path
218 | if ckpt_path:
219 | ckpt_path = U.f_expand(ckpt_path)
220 | assert U.f_exists(ckpt_path), f"ckpt_path {ckpt_path} does not exist"
221 | rank_zero_info("Run validation on ckpt:", ckpt_path)
222 | ret = self.module.load_state_dict(
223 | U.torch_load(ckpt_path)["state_dict"], strict=self.cfg.testing.strict
224 | )
225 | U.rank_zero_warn("state_dict load status:", ret)
226 | ckpt_path = None # not using pytorch lightning's load
227 | else:
228 | rank_zero_warn("WARNING: no ckpt_path specified, will NOT load any weights")
229 | return self.trainer.validate(
230 | self.module, datamodule=self.data_module, ckpt_path=ckpt_path
231 | )
232 |
233 | def _register_classes(self, cfg):
234 | U.register_callable("DDPStrategy", pl.strategies.DDPStrategy)
235 | U.register_callable("LearningRateMonitor", pl.callbacks.LearningRateMonitor)
236 | U.register_callable("ModelSummary", pl.callbacks.ModelSummary)
237 | U.register_callable("RichModelSummary", pl.callbacks.RichModelSummary)
238 | self.register_extra_classes(cfg)
239 |
240 | def register_extra_classes(self, cfg):
241 | pass
242 |
243 | def register_extra_resolvers(self, cfg):
244 | pass
245 |
246 |
247 | @U.register_class
248 | class CustomTQDMProgressBar(TQDMProgressBar):
249 | """
250 | https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.callbacks.ProgressBarBase.html#pytorch_lightning.callbacks.ProgressBarBase.get_metrics
251 | """
252 |
253 | def get_metrics(self, trainer, model):
254 | # don't show the version number
255 | items = super().get_metrics(trainer, model)
256 | items.pop("v_num", None)
257 | return items
258 |
259 |
260 | # do the same overriding for RichProgressBar
261 | @U.register_class
262 | class CustomRichProgressBar(RichProgressBar):
263 | def get_metrics(self, trainer, model):
264 | # don't show the version number
265 | items = super().get_metrics(trainer, model)
266 | items.pop("v_num", None)
267 | return items
268 |
269 |
270 | @rank_zero_only
271 | def rank_zero_print(*msg, **kwargs):
272 | U.pprint_(*msg, **kwargs)
273 |
274 |
275 | @rank_zero_only
276 | def rank_zero_info(*msg, **kwargs):
277 | U.pprint_(
278 | U.color_text("[INFO]", color="green", styles=["reverse", "bold"]),
279 | *msg,
280 | **kwargs,
281 | )
282 |
283 |
284 | @rank_zero_only
285 | def rank_zero_warn(*msg, **kwargs):
286 | U.pprint_(
287 | U.color_text("[WARN]", color="yellow", styles=["reverse", "bold"]),
288 | *msg,
289 | **kwargs,
290 | )
291 |
292 |
293 | @rank_zero_only
294 | def rank_zero_debug(*msg, **kwargs):
295 | if rank_zero_debug.enabled:
296 | U.pprint_(
297 | U.color_text("[DEBUG]", color="blue", bg_color="on_grey"), *msg, **kwargs
298 | )
299 |
300 |
301 | rank_zero_debug.enabled = True
302 |
--------------------------------------------------------------------------------
/brs_algo/lightning/trainer.py:
--------------------------------------------------------------------------------
1 | from hydra.utils import instantiate
2 |
3 | from brs_algo.lightning.lightning import LightingTrainer
4 |
5 |
6 | class Trainer(LightingTrainer):
7 | def create_module(self, cfg):
8 | return instantiate(cfg.module, _recursive_=False)
9 |
10 | def create_data_module(self, cfg):
11 | return instantiate(cfg.data_module)
12 |
13 | def create_callbacks(self, cfg):
14 | is_rollout_eval = self.cfg.rollout_eval
15 |
16 | del_idxs = []
17 | for i, _cfg in enumerate(self.ckpt_cfg):
18 | eval_type = getattr(_cfg, "eval_type", None)
19 | if eval_type is not None:
20 | if is_rollout_eval and eval_type == "static":
21 | del_idxs.append(i)
22 | elif not is_rollout_eval and eval_type == "rollout":
23 | del_idxs.append(i)
24 | del _cfg["eval_type"]
25 | for i in reversed(del_idxs):
26 | del self.ckpt_cfg[i]
27 | return super().create_callbacks(cfg)
28 |
--------------------------------------------------------------------------------
/brs_algo/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .lr_schedule import *
2 | from .optimizer_group import *
3 |
--------------------------------------------------------------------------------
/brs_algo/optim/lr_schedule.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | from torch.optim.lr_scheduler import LambdaLR
5 |
6 |
7 | __all__ = [
8 | "CosineScheduleFunction",
9 | "CosineLRScheduler",
10 | "LambdaLRWithScale",
11 | "generate_cosine_schedule",
12 | ]
13 |
14 |
15 | def generate_cosine_schedule(
16 | base_value,
17 | final_value,
18 | epochs,
19 | steps_per_epoch,
20 | warmup_epochs=0,
21 | warmup_start_value=0,
22 | ) -> np.ndarray:
23 | warmup_schedule = np.array([])
24 | warmup_iters = int(warmup_epochs * steps_per_epoch)
25 | if warmup_epochs > 0:
26 | warmup_schedule = np.linspace(warmup_start_value, base_value, warmup_iters)
27 |
28 | iters = np.arange(int(epochs * steps_per_epoch) - warmup_iters)
29 | schedule = np.array(
30 | [
31 | final_value
32 | + 0.5
33 | * (base_value - final_value)
34 | * (1 + math.cos(math.pi * i / (len(iters))))
35 | for i in iters
36 | ]
37 | )
38 | schedule = np.concatenate((warmup_schedule, schedule))
39 | assert len(schedule) == int(epochs * steps_per_epoch)
40 | return schedule
41 |
42 |
43 | class CosineScheduleFunction:
44 | def __init__(
45 | self,
46 | base_value,
47 | final_value,
48 | epochs,
49 | steps_per_epoch,
50 | warmup_epochs=0,
51 | warmup_start_value=0,
52 | ):
53 | """
54 | Usage:
55 | scheduler = torch.optim.lr_scheduler.LambdaLR(
56 | optimizer=optimizer, lr_lambda=CosineScheduleFunction(**kwargs)
57 | )
 58 |         or simply use CosineLRScheduler(optimizer, **kwargs)
59 |
60 | Args:
61 | epochs: effective epochs for the cosine schedule, *including* warmup
 62 |                 after these epochs, the scheduler outputs `final_value` from then on
63 | """
64 | assert warmup_epochs < epochs, f"{warmup_epochs=} must be < {epochs=}"
65 | self._effective_steps = int(epochs * steps_per_epoch)
66 | self.schedule = generate_cosine_schedule(
67 | base_value=base_value,
68 | final_value=final_value,
69 | epochs=epochs,
70 | steps_per_epoch=steps_per_epoch,
71 | warmup_epochs=warmup_epochs,
72 | warmup_start_value=warmup_start_value,
73 | )
74 | assert self.schedule.shape == (self._effective_steps,)
75 | self._final_value = final_value
76 | self._steps_tensor = torch.tensor(0, dtype=torch.long) # for register buffer
77 |
78 | def register_buffer(self, module: torch.nn.Module, name="cosine_steps"):
79 | module.register_buffer(name, self._steps_tensor, persistent=True)
80 |
81 | def __call__(self, step):
82 | self._steps_tensor.copy_(torch.tensor(step))
83 | if step >= self._effective_steps:
84 | val = self._final_value
85 | else:
86 | val = self.schedule[step]
87 | return val
88 |
89 |
90 | class LambdaLRWithScale(LambdaLR):
91 | """
92 | Supports param_groups['lr_scale'], multiplies base_lr with lr_scale
93 | """
94 |
95 | def get_lr(self):
96 | lrs = super().get_lr()
97 | param_groups = self.optimizer.param_groups
98 | assert len(lrs) == len(
99 | param_groups
100 | ), f"INTERNAL: {len(lrs)=} != {len(param_groups)=}"
101 | for i, param_group in enumerate(param_groups):
102 | if "lr_scale" in param_group:
103 | lrs[i] *= param_group["lr_scale"]
104 | # print("LambdaLRWithScale: lrs =", lrs)
105 | return lrs
106 |
107 |
108 | class CosineLRScheduler(LambdaLRWithScale):
109 | """
110 | Supports param_groups['lr_scale'], multiplies base_lr with lr_scale
111 | """
112 |
113 | def __init__(
114 | self,
115 | optimizer,
116 | base_value,
117 | final_value,
118 | epochs,
119 | steps_per_epoch,
120 | warmup_epochs=0,
121 | warmup_start_value=0,
122 | last_epoch=-1,
123 | verbose=False,
124 | ):
125 | lr_lambda = CosineScheduleFunction(
126 | base_value=base_value,
127 | final_value=final_value,
128 | epochs=epochs,
129 | steps_per_epoch=steps_per_epoch,
130 | warmup_epochs=warmup_epochs,
131 | warmup_start_value=warmup_start_value,
132 | )
133 | super().__init__(
134 | optimizer, lr_lambda=lr_lambda, last_epoch=last_epoch, verbose=verbose
135 | )
136 |
--------------------------------------------------------------------------------
/brs_algo/optim/optimizer_group.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate Optimizer groups
3 | """
4 |
5 | from typing import Callable, Union, List, Tuple
6 | import torch
7 | import torch.nn as nn
8 |
9 | from brs_algo.utils.misc_utils import match_patterns, getattr_nested
10 | from brs_algo.utils.torch_utils import freeze_params
11 |
12 |
13 | __all__ = [
14 | "default_optimizer_groups",
15 | "transformer_lr_decay_optimizer_groups",
16 | "transformer_freeze_layers",
17 | "transformer_freeze_except_last_layers",
18 | "check_optimizer_groups",
19 | ]
20 |
21 | FilterType = Union[
22 | Callable[[str, torch.Tensor], bool], List[str], Tuple[str], str, None
23 | ]
24 |
25 | HAS_REMOVEPREFIX = hasattr(str, "removeprefix")
26 |
27 |
28 | def default_optimizer_groups(
29 | model: nn.Module,
30 | weight_decay: float,
31 | lr_scale: float = 1.0,
32 | no_decay_filter: FilterType = None,
33 | exclude_filter: FilterType = None,
34 | ):
35 | """
36 | lr_scale is only effective when using with enlight.learn.lr_schedule.LambdaLRWithScale
37 |
38 | Returns:
39 | [{'lr_scale': 1.0, 'weight_decay': weight_decay, 'params': decay_group},
40 | {'lr_scale': 1.0, 'weight_decay': 0.0, 'params': no_decay_group}],
41 | list of all param_ids processed
42 | """
43 | no_decay_filter = _transform_filter(no_decay_filter)
44 | exclude_filter = _transform_filter(exclude_filter)
45 | decay_group = []
46 | no_decay_group = []
47 | all_params_id = []
48 | for n, p in model.named_parameters():
49 | all_params_id.append(id(p))
50 | if not p.requires_grad or exclude_filter(n, p):
51 | continue
52 |
53 | # no decay: all 1D parameters and model specific ones
54 | if p.ndim == 1 or no_decay_filter(n, p):
55 | no_decay_group.append(p)
56 | else:
57 | decay_group.append(p)
58 | return [
59 | {"weight_decay": weight_decay, "params": decay_group, "lr_scale": lr_scale},
60 | {"weight_decay": 0.0, "params": no_decay_group, "lr_scale": lr_scale},
61 | ], all_params_id
62 |
63 |
64 | def _get_transformer_blocks(model, block_sequence_name):
65 | block_sequence_name = block_sequence_name.rstrip(".")
66 | return getattr_nested(model, block_sequence_name)
67 |
68 |
69 | def transformer_freeze_layers(
70 | model,
71 | layer_0_params: List[str],
72 | block_sequence_name,
73 | freeze_layers: List[int],
74 | extra_freeze_filter: FilterType = None,
75 | ):
76 | """
77 | Args:
78 | model: transformer model with pos embed and other preprocessing parts as layer 0
79 | and `block_sequence_name` that is a sequence of transformer blocks
80 | layer_0_params: list of parameter names before the first transformer
81 | block, which will be assigned layer 0
82 | block_sequence_name: name of the sequence module that contains transformer layers
83 | such that "block.0", "block.1", ... will share one LR within each block
84 | freeze_layers: list of layer indices to freeze. Include 0 to freeze
85 | preprocessing layers. Use negative indices to freeze from the last layers,
86 | e.g. -1 to freeze the last layer. Note that any nn.Module after the transformer
87 | block will NOT be frozen.
88 | extra_freeze_filter: filter to apply to the rest of the parameters
89 | """
90 | extra_freeze_filter = _transform_filter(extra_freeze_filter)
91 | freeze_layers = list(freeze_layers)
92 | layer_0_params = _transform_filter(layer_0_params)
93 | blocks = _get_transformer_blocks(model, block_sequence_name)
94 | assert max(freeze_layers) <= len(blocks), f"max({freeze_layers}) > {len(blocks)}"
95 | # convert all negative indices to last
96 | freeze_layers = [(L if L >= 0 else len(blocks) + L + 1) for L in freeze_layers]
97 | for i, block in enumerate(blocks):
98 | if i + 1 in freeze_layers:
99 | freeze_params(block)
100 |
101 | for n, p in model.named_parameters():
102 | if layer_0_params(n, p) and 0 in freeze_layers:
103 | freeze_params(p)
104 | if extra_freeze_filter(n, p):
105 | freeze_params(p)
106 |
107 |
108 | def transformer_freeze_except_last_layers(
109 | model,
110 | layer_0_params: List[str],
111 | block_sequence_name,
112 | num_last_layers: int,
113 | extra_freeze_filter: FilterType = None,
114 | ):
115 | """
116 |     According to Kaiming He's MAE paper, finetuning ONLY the last layers is typically
117 |     as good as finetuning all layers, while being much more compute- and memory-efficient.
118 |
119 | Args:
120 | model: transformer model with pos embed and other preprocessing parts as layer 0
121 | and `block_sequence_name` that is a sequence of transformer blocks
122 | layer_0_params: list of parameter names before the first transformer
123 | block, which will be assigned layer 0
124 | block_sequence_name: name of the sequence module that contains transformer layers
125 | such that "block.0", "block.1", ... will share one LR within each block
126 | num_last_layers: number of last N layers to unfreeze (finetune)
127 | extra_freeze_filter: filter to apply to the rest of the parameters
128 | """
129 | num_blocks = len(_get_transformer_blocks(model, block_sequence_name))
130 | # blocks start from 1, because 0th block is preprocessing
131 | # get a range of the first num_blocks - num_last_layers blocks
132 | return transformer_freeze_layers(
133 | model,
134 | layer_0_params=layer_0_params,
135 | block_sequence_name=block_sequence_name,
136 | freeze_layers=range(num_blocks - num_last_layers + 1),
137 | extra_freeze_filter=extra_freeze_filter,
138 | )
139 |
140 |
141 | def transformer_lr_decay_optimizer_groups(
142 | model,
143 | layer_0_params: List[str],
144 | block_sequence_name,
145 | *,
146 | weight_decay,
147 | lr_scale=1.0,
148 | lr_layer_decay,
149 | no_decay_filter: FilterType = None,
150 | exclude_filter: FilterType = None,
151 | ):
152 | """
153 | Parameter groups for layer-wise lr decay
154 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
155 |     lr_scale is only effective when used with brs_algo.optim.lr_schedule.LambdaLRWithScale
156 |
157 | Args:
158 | model: transformer model with pos embed and other preprocessing parts as layer 0,
159 | then the blocks start from layer 1 to layer N. LR will be progressively smaller
160 | from the last layer to the first layer. Layer N will be lr*lr_scale,
161 | N-1 will be lr*lr_scale*lr_layer_decay,
162 | N-2 will be lr*lr_scale*lr_layer_decay^2, ...
163 | layer_0_params: list of parameter names before the first transformer
164 | block, which will be assigned layer 0
165 | block_sequence_name: name of the sequence module that contains transformer layers
166 | such that "block.0", "block.1", ... will share one LR within each block
167 | """
168 | no_decay_filter = _transform_filter(no_decay_filter)
169 | exclude_filter = _transform_filter(exclude_filter)
170 | layer_0_params = _transform_filter(layer_0_params)
171 | block_sequence_name = block_sequence_name.rstrip(".")
172 |
173 | param_group_names = {}
174 | param_groups = {}
175 |
176 | num_layers = len(getattr_nested(model, block_sequence_name)) + 1
177 |
178 | layer_scales = [
179 | lr_scale * lr_layer_decay ** (num_layers - i) for i in range(num_layers + 1)
180 | ]
181 | all_params_id = []
182 |
183 | for n, p in model.named_parameters():
184 | all_params_id.append(id(p))
185 | if not p.requires_grad or exclude_filter(n, p):
186 | continue
187 |
188 | # no decay: all 1D parameters and model specific ones
189 | if p.ndim == 1 or no_decay_filter(n, p):
190 | g_decay = "no_decay"
191 | this_decay = 0.0
192 | else:
193 | g_decay = "decay"
194 | this_decay = weight_decay
195 |
196 | # get the layer index of the param
197 | if layer_0_params(n, p):
198 | layer_id = 0
199 | elif n.startswith(block_sequence_name + "."):
200 | # blocks.0 -> layer 1; blocks.1 -> layer 2; ...
201 | try:
202 | if HAS_REMOVEPREFIX:
203 | layer_id = (
204 | int(n.removeprefix(block_sequence_name + ".").split(".")[0]) + 1
205 | )
206 | else:
207 | layer_id = int(n[len(block_sequence_name) + 1 :].split(".")[0]) + 1
208 |
209 | except ValueError:
210 | raise ValueError(
211 |                     f"{n} must have the format {block_sequence_name}.<idx>.... "
212 |                     f"where <idx> is an integer"
213 | )
214 | else:
215 | layer_id = num_layers
216 | group_name = f"layer_{layer_id}_{g_decay}"
217 |
218 | if group_name not in param_group_names:
219 | this_scale = layer_scales[layer_id]
220 |
221 | param_group_names[group_name] = {
222 | "lr_scale": this_scale,
223 | "weight_decay": this_decay,
224 | "params": [],
225 | }
226 | param_groups[group_name] = {
227 | "lr_scale": this_scale,
228 | "weight_decay": this_decay,
229 | "params": [],
230 | }
231 |
232 | param_group_names[group_name]["params"].append(n)
233 | param_groups[group_name]["params"].append(p)
234 |
235 | return list(param_groups.values()), all_params_id
236 |
237 |
238 | def check_optimizer_groups(
239 | model, param_groups: List[dict], verbose=True, order_by="param"
240 | ):
241 | """
242 |     For debugging purposes: check which param belongs to which param group
243 |     and render the assignment as an ASCII table.
244 |
245 | Args:
246 |         order_by: 'param' or 'group'. 'param' keeps the nn.Module parameter list
247 |             order; 'group' orders rows by optimizer group index.
248 |
249 | Returns:
250 | {param_name (str): group_idx (int)}, table (str, for ASCII print)
251 | """
252 | from tabulate import tabulate
253 |
254 | group_configs = [
255 | ", ".join(f"{k}={v:.6g}" for k, v in sorted(group.items()) if k != "params")
256 | for group in param_groups
257 | ]
258 | display_table = []
259 | name_to_group_idx = {}
260 | pid_to_group_id = {
261 | id(p): i for i, group in enumerate(param_groups) for p in group["params"]
262 | }
263 | for n, p in model.named_parameters(recurse=True):
264 | if id(p) in pid_to_group_id:
265 | gid = pid_to_group_id[id(p)]
266 | display_table.append((n, gid, group_configs[gid]))
267 | if n in name_to_group_idx:
268 | if verbose:
269 | print(
270 | f"WARNING: {n} is in both group "
271 | f"{name_to_group_idx[n]} and {gid}"
272 | )
273 | name_to_group_idx[n] = gid
274 | else:
275 | display_table.append((n, "_", "excluded"))
276 | name_to_group_idx[n] = None
277 | if order_by == "group":
278 | display_table.sort(key=lambda x: 1e10 if x[1] == "_" else x[1])
279 | table_str = tabulate(
280 | display_table, headers=["param", "i", "config"], tablefmt="presto"
281 | )
282 | return name_to_group_idx, table_str
283 |
284 |
285 | def _transform_filter(filter: FilterType):
286 | """
287 | Filter can be:
288 | - None: always returns False
289 | - function(name, p) -> True to activate, False to deactivate
290 | - list of strings to match, can have wildcard
291 | """
292 | if filter is None:
293 | return lambda name, p: False
294 | elif callable(filter):
295 | return filter
296 | elif isinstance(filter, (str, list, tuple)):
297 | if isinstance(filter, str):
298 | filter = [filter]
299 | return lambda name, p: match_patterns(name, include=filter)
300 | else:
301 | raise ValueError(f"Invalid filter: {filter}")
302 |
--------------------------------------------------------------------------------
/brs_algo/rollout/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/brs_algo/rollout/__init__.py
--------------------------------------------------------------------------------
/brs_algo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .array_tensor_utils import *
2 | from .config_utils import *
3 | from .convert_utils import *
4 | from .file_utils import *
5 | from .functional_utils import *
6 | from .misc_utils import *
7 | from .print_utils import *
8 | from .random_seed import *
9 | from .shape_utils import *
10 | from .termcolor import *
11 | from .torch_utils import *
12 | from .tree_utils import *
13 |
--------------------------------------------------------------------------------
/brs_algo/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import tree
5 | import hydra
6 | import importlib.resources
7 | import sys
8 |
9 | from copy import deepcopy
10 | from omegaconf import OmegaConf, DictConfig
11 | from .functional_utils import meta_decorator, is_sequence, is_mapping, call_once
12 | from .print_utils import to_scientific_str
13 |
14 | _CLASS_REGISTRY = {} # for instantiation
15 |
16 |
17 | def resource_file_path(pkg_name, fname) -> str:
18 | with importlib.resources.path(pkg_name, fname) as p:
19 | return str(p)
20 |
21 |
22 | def print_config(cfg: DictConfig):
 23 |     print(OmegaConf.to_yaml(cfg, resolve=True))
24 |
25 |
26 | def is_hydra_initialized():
27 | return hydra.utils.HydraConfig.initialized()
28 |
29 |
30 | def hydra_config():
31 | # https://github.com/facebookresearch/hydra/issues/377
32 | # HydraConfig() is a singleton
33 | if is_hydra_initialized():
34 | return hydra.utils.HydraConfig().cfg.hydra
35 | else:
36 | return None
37 |
38 |
39 | def hydra_override_arg_list() -> list[str]:
40 | """
41 | Returns:
42 | list ["lr=0.2", "batch=64", ...]
43 | """
44 | if is_hydra_initialized():
45 | return hydra_config().overrides.task
46 | else:
47 | return []
48 |
49 |
50 | def hydra_override_name():
51 | if is_hydra_initialized():
52 | return hydra_config().job.override_dirname
53 | else:
54 | return ""
55 |
56 |
57 | def hydra_original_dir(*subpaths):
58 | return os.path.join(hydra.utils.get_original_cwd(), *subpaths)
59 |
60 |
61 | @call_once(on_second_call="noop")
62 | def register_omegaconf_resolvers():
63 | import numpy as np
64 |
65 | OmegaConf.register_new_resolver(
66 | "scientific", lambda v, i=0: to_scientific_str(v, i)
67 | )
68 | OmegaConf.register_new_resolver("_optional", lambda v: f"_{v}" if v else "")
69 | OmegaConf.register_new_resolver("optional_", lambda v: f"{v}_" if v else "")
70 | OmegaConf.register_new_resolver("_optional_", lambda v: f"_{v}_" if v else "")
71 | OmegaConf.register_new_resolver("__optional", lambda v: f"__{v}" if v else "")
72 | OmegaConf.register_new_resolver("optional__", lambda v: f"{v}__" if v else "")
73 | OmegaConf.register_new_resolver("__optional__", lambda v: f"__{v}__" if v else "")
74 | OmegaConf.register_new_resolver(
75 | "iftrue", lambda cond, v_default: cond if cond else v_default
76 | )
77 | OmegaConf.register_new_resolver(
78 | "ifelse", lambda cond, v1, v2="": v1 if cond else v2
79 | )
80 | OmegaConf.register_new_resolver(
81 | "ifequal", lambda query, key, v1, v2: v1 if query == key else v2
82 | )
83 | OmegaConf.register_new_resolver("intbool", lambda cond: 1 if cond else 0)
84 | OmegaConf.register_new_resolver("mult", lambda *x: np.prod(x).tolist())
85 | OmegaConf.register_new_resolver("add", lambda *x: sum(x))
86 | OmegaConf.register_new_resolver("div", lambda x, y: x / y)
87 | OmegaConf.register_new_resolver("intdiv", lambda x, y: x // y)
88 |
89 | # try each key until the key exists. Useful for multiple classes that have different
90 | # names for the same key
91 | def _try_key(cfg, *keys):
92 | for k in keys:
93 | if k in cfg:
94 | return cfg[k]
95 | raise KeyError(f"no key in {keys} is valid")
96 |
97 | OmegaConf.register_new_resolver("trykey", _try_key)
 98 |     # replace `resnet_gn_ws` -> `resnet.gn.ws`: omegaconf doesn't support keys
 99 |     # with dots, so configs use underscores. Useful for generating run names with dots
100 | OmegaConf.register_new_resolver("underscore_to_dots", lambda s: s.replace("_", "."))
101 |
102 | def _no_instantiate(cfg):
103 | cfg = deepcopy(cfg)
104 | cfg[_NO_INSTANTIATE] = True
105 | return cfg
106 |
107 | OmegaConf.register_new_resolver("no_instantiate", _no_instantiate)
108 |
109 |
110 | # ========================================================
111 | # ================== Instantiation tools ================
112 | # ========================================================
113 |
114 |
115 | def register_callable(name, class_type):
116 | if isinstance(class_type, str):
117 | class_type, name = name, class_type
118 | assert callable(class_type)
119 | _CLASS_REGISTRY[name] = class_type
120 |
121 |
122 | @meta_decorator
123 | def register_class(cls, alias=None):
124 | """
125 | Decorator
126 | """
127 | assert callable(cls)
128 | _CLASS_REGISTRY[cls.__name__] = cls
129 | if alias:
130 | assert is_sequence(alias)
131 | for a in alias:
132 | _CLASS_REGISTRY[str(a)] = cls
133 | return cls
134 |
135 |
136 | def omegaconf_to_dict(cfg, resolve: bool = True, enum_to_str: bool = False):
137 | """
138 | Convert arbitrary nested omegaconf objects to primitive containers
139 |
140 | WARNING: cannot use tree lib because it gets confused on DictConfig and ListConfig
141 | """
142 | kw = dict(resolve=resolve, enum_to_str=enum_to_str)
143 | if OmegaConf.is_config(cfg):
144 | return OmegaConf.to_container(cfg, **kw)
145 | elif is_sequence(cfg):
146 | return type(cfg)(omegaconf_to_dict(c, **kw) for c in cfg)
147 | elif is_mapping(cfg):
148 | return {k: omegaconf_to_dict(c, **kw) for k, c in cfg.items()}
149 | else:
150 | return cfg
151 |
152 |
153 | def omegaconf_save(cfg, *paths: str, resolve: bool = True):
154 | """
155 | Save omegaconf to yaml
156 | """
157 | from .file_utils import f_join
158 |
159 | OmegaConf.save(cfg, f_join(*paths), resolve=resolve)
160 |
161 |
162 | def get_class(path):
163 | """
164 |     Try to find the class in the registry first;
165 | if it doesn't exist, use importlib to locate it
166 | """
167 | if path in _CLASS_REGISTRY:
168 | return _CLASS_REGISTRY[path]
169 | else:
170 | assert "." in path, (
171 | f"Because {path} is not found in class registry, "
172 | f"it must be a full module path"
173 | )
174 | try:
175 | from importlib import import_module
176 |
177 | module_path, _, class_name = path.rpartition(".")
178 | mod = import_module(module_path)
179 | try:
180 | class_type = getattr(mod, class_name)
181 | except AttributeError:
182 | raise ImportError(
183 | "Class {} is not in module {}".format(class_name, module_path)
184 | )
185 | return class_type
186 | except ValueError as e:
187 | print("Error initializing class " + path, file=sys.stderr)
188 | raise e
189 |
190 |
191 | _DELETE_ARG = "__delete__"
192 | _NO_INSTANTIATE = "__no_instantiate__" # return config as-is
193 | _OMEGA_MISSING = "???"
194 |
195 |
196 | def _get_instantiate_params(cfg, kwargs=None):
197 | params = cfg
198 | f_args, f_kwargs = (), {}
199 | for k, value in params.items():
200 | if k in ["cls", "class"]:
201 | continue
202 | elif k == "*args":
203 | assert is_sequence(value), '"*args" value must be a sequence'
204 | f_args = list(value)
205 | continue
206 | if value == _OMEGA_MISSING:
207 | if kwargs and k in kwargs:
208 | value = kwargs[k]
209 | else:
210 | raise ValueError(f'Missing required keyword arg "{k}" in cfg: {cfg}')
211 | if value == _DELETE_ARG:
212 | continue
213 | else:
214 | f_kwargs[k] = value
215 | return f_args, f_kwargs
216 |
217 |
218 | def _instantiate_single(cfg):
219 | if is_mapping(cfg) and ("cls" in cfg or "class" in cfg):
220 | assert bool("cls" in cfg) != bool("class" in cfg), (
221 | "to instantiate from config, "
222 | 'one and only one of "cls" or "class" key should be provided'
223 | )
224 | if _NO_INSTANTIATE in cfg:
225 | no_instantiate = cfg.pop(_NO_INSTANTIATE)
226 | if no_instantiate:
227 | cfg = deepcopy(cfg)
228 | return cfg
229 | else:
230 | return _instantiate_single(cfg)
231 |
232 | cls = cfg.get("class", cfg.get("cls"))
233 | args, kwargs = _get_instantiate_params(cfg)
234 | try:
235 | class_type = get_class(cls)
236 | return class_type(*args, **kwargs)
237 | except Exception as e:
238 | raise RuntimeError(f"Error instantiating {cls}: {e}")
239 | else:
240 | return None
241 |
242 |
243 | def instantiate(_cfg_, **kwargs):
244 | """
245 | Any dict with "cls" or "class" key is considered instantiable.
246 |
247 | Any key that has the special value "__delete__"
248 | will not be passed to the constructor
249 |
250 | **kwargs only apply to the top level object if it's a dict, otherwise raise error
251 | """
252 | assert (
253 | OmegaConf.is_config(_cfg_)
254 | or isinstance(_cfg_, (list, tuple))
255 | or is_mapping(_cfg_)
256 | ), (
257 | '"cfg" must be a dict, list, tuple, or OmegaConf config to be instantiated. '
258 |         f"Currently its type is {type(_cfg_)}"
259 | )
260 |
261 | _cfg_ = omegaconf_to_dict(_cfg_, resolve=True)
262 |
263 | if kwargs:
264 | if is_mapping(_cfg_):
265 | _cfg_ = _cfg_.copy()
266 | _cfg_.update(kwargs)
267 | _cfg_ = {k: v for k, v in _cfg_.items() if v != _DELETE_ARG}
268 | else:
269 | raise RuntimeError(
270 | f"**kwargs specified, but the top-level cfg is not a dict. "
271 | f"It has type {type(_cfg_)}"
272 | )
273 |
274 | return tree.traverse(_instantiate_single, _cfg_, top_down=False)
275 |
--------------------------------------------------------------------------------
/brs_algo/utils/convert_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import numpy as np
4 | import torch
5 | import os.path
6 | from PIL import Image
7 |
8 |
9 | def np_dtype_size(dtype: Union[str, np.dtype]) -> int:
10 | return np.dtype(dtype).itemsize
11 |
12 |
13 | _TORCH_DTYPE_TABLE = {
14 | torch.bool: 1,
15 | torch.int8: 1,
16 | torch.uint8: 1,
17 | torch.int16: 2,
18 | torch.short: 2,
19 | torch.int32: 4,
20 | torch.int: 4,
21 | torch.int64: 8,
22 | torch.long: 8,
23 | torch.float16: 2,
24 | torch.bfloat16: 2,
25 | torch.half: 2,
26 | torch.float32: 4,
27 | torch.float: 4,
28 | torch.float64: 8,
29 | torch.double: 8,
30 | }
31 |
32 |
33 | def torch_dtype(dtype: Union[str, torch.dtype, None]) -> torch.dtype:
34 | if dtype is None:
35 | return None
36 | elif isinstance(dtype, torch.dtype):
37 | return dtype
38 | elif isinstance(dtype, str):
39 | try:
40 | dtype = getattr(torch, dtype)
41 | except AttributeError:
42 | raise ValueError(f'"{dtype}" is not a valid torch dtype')
43 | assert isinstance(
44 | dtype, torch.dtype
45 | ), f"dtype {dtype} is not a valid torch tensor type"
46 | return dtype
47 | else:
48 | raise NotImplementedError(f"{dtype} not supported")
49 |
50 |
51 | def torch_device(device: Union[str, int, None]) -> torch.device:
52 | """
53 | Args:
54 | device:
55 | - "auto": use current torch context device, same as `.to('cuda')`
56 | - int: negative for CPU, otherwise GPU index
57 | """
58 | if device is None:
59 | return None
60 | elif device == "auto":
61 | return torch.device("cuda")
62 | elif isinstance(device, int) and device < 0:
63 | return torch.device("cpu")
64 | else:
65 | return torch.device(device)
66 |
67 |
68 | def torch_dtype_size(dtype: Union[str, torch.dtype]) -> int:
69 | return _TORCH_DTYPE_TABLE[torch_dtype(dtype)]
70 |
71 |
72 | def _convert_then_transfer(x, dtype, device, copy, non_blocking):
73 | x = x.to(dtype=dtype, copy=copy, non_blocking=non_blocking)
74 | return x.to(device=device, copy=False, non_blocking=non_blocking)
75 |
76 |
77 | def _transfer_then_convert(x, dtype, device, copy, non_blocking):
78 | x = x.to(device=device, copy=copy, non_blocking=non_blocking)
79 | return x.to(dtype=dtype, copy=False, non_blocking=non_blocking)
80 |
81 |
82 | def any_to_torch_tensor(
83 | x,
84 | dtype: Union[str, torch.dtype, None] = None,
85 | device: Union[str, int, torch.device, None] = None,
86 | copy=False,
87 | non_blocking=False,
88 | smart_optimize: bool = True,
89 | ):
90 | dtype = torch_dtype(dtype)
91 | device = torch_device(device)
92 |
93 | if not isinstance(x, (torch.Tensor, np.ndarray)):
94 | # x is a primitive python sequence
95 | x = torch.tensor(x, dtype=dtype)
96 | copy = False
97 |
98 | # This step does not create any copy.
99 | # If x is a numpy array, simply wraps it in Tensor. If it's already a Tensor, do nothing.
100 | x = torch.as_tensor(x)
101 | # avoid passing None to .to(), PyTorch 1.4 bug
102 | dtype = dtype or x.dtype
103 | device = device or x.device
104 |
105 | if not smart_optimize:
106 | # do a single stage type conversion and transfer
107 | return x.to(dtype=dtype, device=device, copy=copy, non_blocking=non_blocking)
108 |
109 | # we have two choices: (1) convert dtype and then transfer to GPU
110 | # (2) transfer to GPU and then convert dtype
111 | # because CPU-to-GPU memory transfer is the bottleneck, we will reduce it as
112 | # much as possible by sending the smaller dtype
113 |
114 | src_dtype_size = torch_dtype_size(x.dtype)
115 |
116 | # destination dtype size
117 | if dtype is None:
118 | dest_dtype_size = src_dtype_size
119 | else:
120 | dest_dtype_size = torch_dtype_size(dtype)
121 |
122 | if x.dtype != dtype or x.device != device:
123 | # a copy will always be performed, no need to force copy again
124 | copy = False
125 |
126 | if src_dtype_size > dest_dtype_size:
127 | # better to do conversion on one device (e.g. CPU) and then transfer to another
128 | return _convert_then_transfer(x, dtype, device, copy, non_blocking)
129 | elif src_dtype_size == dest_dtype_size:
130 | # when equal, we prefer to do the conversion on whichever device that's GPU
131 | if x.device.type == "cuda":
132 | return _convert_then_transfer(x, dtype, device, copy, non_blocking)
133 | else:
134 | return _transfer_then_convert(x, dtype, device, copy, non_blocking)
135 | else:
136 | # better to transfer data across device first, and then do conversion
137 | return _transfer_then_convert(x, dtype, device, copy, non_blocking)
138 |
139 |
140 | def any_to_numpy(
141 | x,
142 | dtype: Union[str, np.dtype, None] = None,
143 | copy: bool = False,
144 | non_blocking: bool = False,
145 | smart_optimize: bool = True,
146 | ):
147 | if isinstance(x, torch.Tensor):
148 | x = any_to_torch_tensor(
149 | x,
150 | dtype=dtype,
151 | device="cpu",
152 | copy=copy,
153 | non_blocking=non_blocking,
154 | smart_optimize=smart_optimize,
155 | )
156 | return x.detach().numpy()
157 | else:
158 | # primitive python sequence or ndarray
159 | return np.array(x, dtype=dtype, copy=copy)
160 |
161 |
162 | def img_to_tensor(file_path: str, dtype=None, device=None, add_batch_dim: bool = False):
163 | """
164 | Args:
165 |         file_path: path to an RGB image file
166 | add_batch_dim: if 3D, add a leading batch dim
167 |
168 | Returns:
169 | tensor between [0, 255]
170 |
171 | """
172 | # image path
173 | pic = Image.open(os.path.expanduser(file_path)).convert("RGB")
174 | # code referenced from torchvision.transforms.functional.to_tensor
175 | # handle PIL Image
176 | assert pic.mode == "RGB"
177 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
178 | img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
179 | # put it from HWC to CHW format
180 | img = img.permute((2, 0, 1)).contiguous()
181 |
182 | img = any_to_torch_tensor(img, dtype=dtype, device=device)
183 | if add_batch_dim:
184 | img.unsqueeze_(dim=0)
185 | return img
186 |
187 |
188 | def any_to_float(x, strict: bool = False):
189 | """
190 | Convert a singleton torch tensor or ndarray to float
191 |
192 | Args:
193 | strict: True to check if the input is a singleton and raise Exception if not.
194 | False to return the original value if not a singleton
195 | """
196 |
197 | if torch.is_tensor(x) and x.numel() == 1:
198 | return float(x)
199 | elif isinstance(x, np.ndarray) and x.size == 1:
200 | return float(x)
201 | else:
202 | if strict:
203 | raise ValueError(f"{x} cannot be converted to a single float.")
204 | else:
205 | return x
206 |
--------------------------------------------------------------------------------
/brs_algo/utils/misc_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import codecs
4 | import fnmatch
5 | from collections import Counter
6 | from typing import Optional, Dict, Any, List, Union, Callable
7 | from typing_extensions import Literal
8 |
9 |
10 | def set_os_envs(envs: Optional[Dict[str, Any]] = None):
11 | """
12 | Special value __delete__ or None indicates that the ENV_VAR should be removed
13 | """
14 | if envs is None:
15 | envs = {}
16 | DEL = {None, "__delete__"}
17 | for k, v in envs.items():
18 | if v in DEL:
19 | os.environ.pop(k, None)
20 | os.environ.update({k: str(v) for k, v in envs.items() if v not in DEL})
21 |
22 |
23 | def argmax(L):
24 | return max(zip(L, range(len(L))))[1]
25 |
26 |
27 | def _match_patterns_helper(element, patterns):
28 | for p in patterns:
29 | if callable(p) and p(element):
30 | return True
31 | if fnmatch.fnmatch(element, p):
32 | return True
33 | return False
34 |
35 |
36 | def match_patterns(
37 | item: str,
38 | include: Union[str, List[str], Callable, List[Callable], None] = None,
39 | exclude: Union[str, List[str], Callable, List[Callable], None] = None,
40 | *,
41 | precedence: Literal["include", "exclude"] = "exclude",
42 | ):
43 | """
44 | Args:
45 | include: None to disable `include` filter and delegate to exclude
46 | precedence: "include" or "exclude"
47 | """
48 | assert precedence in ["include", "exclude"]
49 | if exclude is None:
50 | exclude = []
51 | if isinstance(exclude, (str, Callable)):
52 | exclude = [exclude]
53 | if isinstance(include, (str, Callable)):
54 | include = [include]
55 | if include is None:
56 | # exclude is the sole veto vote
57 | return not _match_patterns_helper(item, exclude)
58 |
59 | if precedence == "include":
60 | return _match_patterns_helper(item, include)
61 | else:
62 | if _match_patterns_helper(item, exclude):
63 | return False
64 | else:
65 | return _match_patterns_helper(item, include)
66 |
67 |
68 | def filter_patterns(
69 | items: List[str],
70 | include: Union[str, List[str], Callable, List[Callable], None] = None,
71 | exclude: Union[str, List[str], Callable, List[Callable], None] = None,
72 | *,
73 | precedence: Literal["include", "exclude"] = "exclude",
74 | ordering: Literal["original", "include"] = "original",
75 | ):
76 | """
77 | Args:
78 | ordering: affects the order of items in the returned list. Does not affect the
79 | content of the returned list.
80 | - "original": keep the ordering of items in the input list
81 | - "include": order items by the order of include patterns
82 | """
83 | assert ordering in ["original", "include"]
84 | if include is None or isinstance(include, str) or ordering == "original":
85 | return [
86 | x
87 | for x in items
88 | if match_patterns(
89 | x, include=include, exclude=exclude, precedence=precedence
90 | )
91 | ]
92 | else:
93 | items = items.copy()
94 | ret = []
95 | for inc in include:
96 | for i, item in enumerate(items):
97 | if item is None:
98 | continue
99 | if match_patterns(
100 | item, include=inc, exclude=exclude, precedence=precedence
101 | ):
102 | ret.append(item)
103 | items[i] = None
104 | return ret
105 |
106 |
107 | def getitem_nested(cfg, key: str):
108 | """
109 | Recursively get key, if key has '.' in it
110 | """
111 | keys = key.split(".")
112 | for k in keys:
113 | assert k in cfg, f'{k} in key "{key}" does not exist in config'
114 | cfg = cfg[k]
115 | return cfg
116 |
117 |
118 | def setitem_nested(cfg, key: str, value):
119 | """
120 | Recursively get key, if key has '.' in it
121 | """
122 | keys = key.split(".")
123 | for k in keys[:-1]:
124 | assert k in cfg, f'{k} in key "{key}" does not exist in config'
125 | cfg = cfg[k]
126 | cfg[keys[-1]] = value
127 |
128 |
129 | def getattr_nested(obj, key: str):
130 | """
131 | Recursively get attribute
132 | """
133 | keys = key.split(".")
134 | for k in keys:
135 | assert hasattr(obj, k), f'{k} in attribute "{key}" does not exist'
136 | obj = getattr(obj, k)
137 | return obj
138 |
139 |
140 | def setattr_nested(obj, key: str, value):
141 | """
142 | Recursively set attribute
143 | """
144 | keys = key.split(".")
145 | for k in keys[:-1]:
146 | assert hasattr(obj, k), f'{k} in attribute "{key}" does not exist'
147 | obj = getattr(obj, k)
148 | setattr(obj, keys[-1], value)
149 |
150 |
151 | class PeriodicEvent:
152 | """
153 | triggers every period
154 | """
155 |
156 | def __init__(self, period: int, initial_value=0):
157 | self._period = period
158 | assert self._period >= 1
159 | self._last_threshold = initial_value
160 | self._last_value = initial_value
161 | self._trigger_counts = 0
162 |
163 | def __call__(self, new_value=None, increment=None):
164 | assert bool(new_value is None) != bool(increment is None), (
165 | "you must specify one and only one of new_value or increment, "
166 | "but not both"
167 | )
168 | d = self._period
169 | if new_value is None:
170 | new_value = self._last_value + increment
171 | assert new_value >= self._last_value, (
172 | f"value must be monotonically increasing. "
173 | f"Current value {new_value} < last value {self._last_value}"
174 | )
175 | self._last_value = new_value
176 | if new_value - self._last_threshold >= d:
177 | self._last_threshold += (new_value - self._last_threshold) // d * d
178 | self._trigger_counts += 1
179 | return True
180 | else:
181 | return False
182 |
183 | @property
184 | def trigger_counts(self):
185 | return self._trigger_counts
186 |
187 | @property
188 | def current_value(self):
189 | return self._last_value
190 |
191 |
192 | class Once:
193 | def __init__(self):
194 | self._triggered = False
195 |
196 | def __call__(self):
197 | if not self._triggered:
198 | self._triggered = True
199 | return True
200 | else:
201 | return False
202 |
203 | def __bool__(self):
204 | raise RuntimeError("`Once` objects should be used by calling ()")
205 |
206 |
207 | _GLOBAL_ONCE_SET = set()
208 | _GLOBAL_NTIMES_COUNTER = Counter()
209 |
210 |
211 | def global_once(name):
212 | """
213 | Try this to automate the name:
214 | https://gist.github.com/techtonik/2151727#gistcomment-2333747
215 | """
216 | if name in _GLOBAL_ONCE_SET:
217 | return False
218 | else:
219 | _GLOBAL_ONCE_SET.add(name)
220 | return True
221 |
222 |
223 | def global_n_times(name, n: int):
224 | """
225 | Triggers N times
226 | """
227 | assert n >= 1
228 | if _GLOBAL_NTIMES_COUNTER[name] < n:
229 | _GLOBAL_NTIMES_COUNTER[name] += 1
230 | return True
231 | else:
232 | return False
233 |
234 |
235 | class Every:
236 | def __init__(self, n: int, on_first: bool = False):
237 | assert n > 0
238 | self._i = 0 if on_first else 1
239 | self._n = n
240 |
241 | def __call__(self):
242 |         self._i += 1
243 |         return (self._i - 1) % self._n == 0
244 | def __bool__(self):
245 | raise RuntimeError("`Every` objects should be used by calling ()")
246 |
247 |
248 | def encode_base64(obj) -> str:
249 | return codecs.encode(pickle.dumps(obj), "base64").decode()
250 |
251 |
252 | def decode_base64(s: str):
253 | return pickle.loads(codecs.decode(s.encode(), "base64"))
254 |
--------------------------------------------------------------------------------
/brs_algo/utils/print_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import io
4 | import os
5 | import sys
6 | import logging
7 | import shlex
8 | import string
9 | import pprint
10 | import textwrap
11 | import time
12 | import traceback
13 | from datetime import datetime
14 | from typing import Callable, Union
15 |
16 | import numpy as np
17 | from typing_extensions import Literal
18 |
19 | from .functional_utils import meta_decorator
20 | from .misc_utils import match_patterns
21 |
22 |
23 | def to_readable_count_str(value: int, precision: int = 2) -> str:
24 | assert value >= 0
25 | labels = [" ", "K", "M", "B", "T"]
26 | num_digits = int(np.floor(np.log10(value)) + 1 if value > 0 else 1)
27 | num_groups = int(np.ceil(num_digits / 3))
28 | num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
29 | shift = -3 * (num_groups - 1)
30 | value = value * (10**shift)
31 | index = num_groups - 1
32 | rem = value - int(value)
33 | if precision > 0 and rem > 0.01:
34 | fmt = f"{{:.{precision}f}}"
35 | rem_str = fmt.format(rem).lstrip("0")
36 | else:
37 | rem_str = ""
38 | return f"{int(value):,d}{rem_str} {labels[index]}"
39 |
40 |
41 | def to_scientific_str(value, precision: int = 1, capitalize: bool = False) -> str:
42 | """
43 | 0.0015 -> "1.5e-3"
44 | """
45 | if value == 0:
46 | return "0"
47 | return f"{value:.{precision}e}".replace("e-0", "E-" if capitalize else "e-")
48 |
49 |
50 | def print_str(*args, **kwargs):
51 | """
52 | Same as print() signature but returns a string
53 | """
54 | sstream = io.StringIO()
55 | kwargs.pop("file", None)
56 | print(*args, **kwargs, file=sstream)
57 | return sstream.getvalue()
58 |
59 |
60 | def fstring(fmt_str, **kwargs):
61 | """
62 | Simulate python f-string but without `f`
63 | """
64 | locals().update(kwargs)
65 | return eval("f" + shlex.quote(fmt_str))
66 |
67 |
68 | def get_format_keys(fmt_str):
69 | keys = []
70 | for literal, field_name, fmt_spec, conversion in string.Formatter().parse(fmt_str):
71 | if field_name:
72 | keys.append(field_name)
73 | return keys
74 |
75 |
76 | def get_timestamp(milli_precision: int = 3):
77 | fmt = "%y-%m-%d %H:%M:%S"
78 | if milli_precision > 0:
79 | fmt += ".%f"
80 | stamp = datetime.now().strftime(fmt)
81 | if milli_precision > 0:
82 | stamp = stamp[:-milli_precision]
83 | return stamp
84 |
85 |
86 | def pretty_repr_str(obj, **kwargs):
87 | """
88 | Useful to produce __repr__()
89 | """
90 | if isinstance(obj, str):
91 | cls_name = obj
92 | else:
93 | cls_name = obj.__class__.__name__
94 | kw_strs = [
95 | k + "=" + pprint.pformat(v, indent=2, compact=True) for k, v in kwargs.items()
96 | ]
97 | new_line = len(cls_name) + sum(len(kw) for kw in kw_strs) > 84
98 | if new_line:
99 | kw = ",\n".join(kw_strs)
100 | return f"{cls_name}(\n{textwrap.indent(kw, ' ')}\n)"
101 | else:
102 | kw = ", ".join(kw_strs)
103 | return f"{cls_name}({kw})"
104 |
105 |
106 | def pprint_(*objs, **kwargs):
107 | """
108 | Use pprint to format the objects
109 | """
110 | print(
111 | *[
112 | pprint.pformat(obj, indent=2) if not isinstance(obj, str) else obj
113 | for obj in objs
114 | ],
115 | **kwargs,
116 | )
117 |
118 |
119 | def get_exception_info(to_str: bool = False):
120 | """
121 | Returns:
122 | {'type': ExceptionType, 'value': ExceptionObject, 'trace': }
123 | """
124 | typ_, value, trace = sys.exc_info()
125 | return {
126 | "type": typ_.__name__ if to_str else typ_,
127 | "value": str(value) if to_str else value,
128 | "trace": "".join(traceback.format_exception(typ_, value, trace)),
129 | }
130 |
131 |
132 | class DebugPrinter:
133 | """
134 | Debug print, usage: dprint = DebugPrint(enabled=True)
135 | dprint(...)
136 | """
137 |
138 | def __init__(
139 | self, enabled, tensor_summary: Literal["shape", "shape+dtype", "none"] = "shape"
140 | ):
141 | """
142 | Args:
143 | tensor_summary:
144 | - shape: only prints shape
145 | - shape+dtype: also prints dtype and device
146 | - none: print full tensor
147 | """
148 | self.enabled = enabled
149 | assert tensor_summary in ["shape", "shape+dtype", "none"]
150 | self.tensor_summary = tensor_summary
151 |
152 | def __call__(self, *args, **kwargs):
153 | if not self.enabled:
154 | return
155 | args = [self._process_arg(a) for a in args]
156 | pprint_(*args, **kwargs)
157 |
158 | def _process_arg(self, arg):
159 | import torch
160 | import numpy as np
161 |
162 | if torch.is_tensor(arg):
163 | if self.tensor_summary == "shape":
164 | return str(list(arg.size()))
165 | elif self.tensor_summary == "shape+dtype":
166 | return f"{arg.dtype}{list(arg.size())}|{arg.device}"
167 | elif isinstance(arg, np.ndarray):
168 | if self.tensor_summary == "shape":
169 | return str(list(arg.shape))
170 | elif self.tensor_summary == "shape+dtype":
171 | return f"{arg.dtype}{list(arg.shape)}"
172 | return arg
173 |
174 |
175 | @meta_decorator
176 | def watch(func, seconds: int = 5, max_times: int = 0, keep_returns: bool = False):
177 | """
178 | Decorator: executes a function repeated with the args and
179 | emulate `watch -n` capability
180 |
181 | See `gpustat` repo: https://github.com/wookayin/gpustat/pull/41/files
182 |
183 | Args:
184 | max_times: watch for `max_times` and then exit. If 0, never exits
185 | keep_returns: if True, will keep the return value from the function
186 | and return as a list at the end
187 | """
188 | from blessings import Terminal
189 |
190 | def _wrapped(*args, **kwargs):
191 | term = Terminal()
192 | N = 0
193 | returns = []
194 | with term.fullscreen():
195 | while True:
196 | try:
197 | with term.location(0, 0):
198 | ret = func(*args, **kwargs)
199 | print(term.clear_eos, end="")
200 | if keep_returns:
201 | returns.append(ret)
202 | N += 1
203 | if max_times > 0 and N >= max_times:
204 | break
205 | time.sleep(seconds)
206 | except KeyboardInterrupt:
207 | break
208 | return returns
209 |
210 | return _wrapped
211 |
212 |
213 | class PrintRedirection(object):
214 | """
215 | Context manager: temporarily redirects stdout and stderr
216 | """
217 |
218 | def __init__(self, stdout=None, stderr=None):
219 | """
220 | Args:
221 | stdout: if None, defaults to sys.stdout, unchanged
222 | stderr: if None, defaults to sys.stderr, unchanged
223 | """
224 | if stdout is None:
225 | stdout = sys.stdout
226 | if stderr is None:
227 | stderr = sys.stderr
228 | self._stdout, self._stderr = stdout, stderr
229 |
230 | def __enter__(self):
231 | self._old_out, self._old_err = sys.stdout, sys.stderr
232 | self._old_out.flush()
233 | self._old_err.flush()
234 | sys.stdout, sys.stderr = self._stdout, self._stderr
235 | return self
236 |
237 | def __exit__(self, exc_type, exc_value, traceback):
238 | self.flush()
239 | # restore the normal stdout and stderr
240 | sys.stdout, sys.stderr = self._old_out, self._old_err
241 |
242 | def flush(self):
243 | "Manually flush the replaced stdout/stderr buffers."
244 | self._stdout.flush()
245 | self._stderr.flush()
246 |
247 |
248 | class PrintToFile(PrintRedirection):
249 | """
250 | Print to file and save/close the handle at the end.
251 | """
252 |
253 | def __init__(self, out_file=None, err_file=None):
254 | """
255 | Args:
256 | out_file: file path
257 | err_file: file path. If the same as out_file, print both stdout
258 | and stderr to one file in order.
259 | """
260 | self.out_file, self.err_file = out_file, err_file
261 | if out_file:
262 | out_file = os.path.expanduser(out_file)
263 | self.out_file = open(out_file, "w")
264 | if err_file:
265 |             err_file = os.path.expanduser(err_file)
266 | if err_file == out_file: # redirect both stdout/err to one file
267 | self.err_file = self.out_file
268 | else:
269 |                 self.err_file = open(err_file, "w")
270 | super().__init__(stdout=self.out_file, stderr=self.err_file)
271 |
272 | def __exit__(self, *args):
273 | super().__exit__(*args)
274 | if self.out_file:
275 | self.out_file.close()
276 | if self.err_file:
277 | self.err_file.close()
278 |
279 |
280 | def PrintSuppress(no_out=True, no_err=False):
281 | """
282 | Args:
283 |         no_out: stdout writes to os.devnull
284 |         no_err: stderr writes to os.devnull
285 | """
286 | out_file = os.devnull if no_out else None
287 | err_file = os.devnull if no_err else None
288 | return PrintToFile(out_file=out_file, err_file=err_file)
289 |
290 |
291 | class PrintString(PrintRedirection):
292 | """
293 | Redirect stdout and stderr to strings.
294 | """
295 |
296 | def __init__(self):
297 | self.out_stream = io.StringIO()
298 | self.err_stream = io.StringIO()
299 | super().__init__(stdout=self.out_stream, stderr=self.err_stream)
300 |
301 | def stdout(self):
302 | "Returns: stdout as one string."
303 | return self.out_stream.getvalue()
304 |
305 | def stderr(self):
306 | "Returns: stderr as one string."
307 | return self.err_stream.getvalue()
308 |
309 | def stdout_by_line(self):
310 | "Returns: a list of stdout line by line, ignore trailing blanks"
311 | return self.stdout().rstrip().split("\n")
312 |
313 | def stderr_by_line(self):
314 | "Returns: a list of stderr line by line, ignore trailing blanks"
315 | return self.stderr().rstrip().split("\n")
316 |
317 |
318 | # ==================== Logging filters ====================
319 | class ExcludeLoggingFilter(logging.Filter):
320 | """
321 | Usage: logging.getLogger('name').addFilter(
322 | ExcludeLoggingFilter(['info mess*age', 'Warning: *'])
323 | )
324 | Supports wildcard.
325 | https://relaxdiego.com/2014/07/logging-in-python.html
326 | """
327 |
328 | def __init__(self, patterns):
329 | super().__init__()
330 | self._patterns = patterns
331 |
332 | def filter(self, record):
333 | if match_patterns(record.msg, include=self._patterns):
334 | return False
335 | else:
336 | return True
337 |
338 |
339 | class ReplaceStringLoggingFilter(logging.Filter):
340 | def __init__(self, patterns, replacer: Callable):
341 | super().__init__()
342 | self._patterns = patterns
343 | assert callable(replacer)
344 | self._replacer = replacer
345 |
346 | def filter(self, record):
347 | if match_patterns(record.msg, include=self._patterns):
348 |             record.msg = self._replacer(record.msg)
349 |         return True
350 |
351 | def logging_exclude_pattern(
352 | logger_name,
353 | patterns: Union[str, list[str], Callable, list[Callable], None],
354 | ):
355 | """
356 | Args:
357 |         patterns: see brs_algo.utils.misc_utils.match_patterns
358 | """
359 | logging.getLogger(logger_name).addFilter(ExcludeLoggingFilter(patterns))
360 |
361 |
362 | def logging_replace_string(
363 | logger_name,
364 | patterns: Union[str, list[str], Callable, list[Callable], None],
365 | replacer: Callable,
366 | ):
367 | """
368 | Args:
369 |         patterns: see brs_algo.utils.misc_utils.match_patterns
370 | """
371 | logging.getLogger(logger_name).addFilter(
372 | ReplaceStringLoggingFilter(patterns, replacer)
373 | )
374 |
--------------------------------------------------------------------------------
/brs_algo/utils/random_seed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def set_seed(seed, torch_deterministic=False, rank=0):
8 | """set seed across modules"""
9 | if seed == -1 and torch_deterministic:
10 | seed = 42 + rank
11 | elif seed == -1:
12 | seed = np.random.randint(0, 10000)
13 | else:
14 | seed = seed + rank
15 |
16 | print("Setting seed: {}".format(seed))
17 |
18 | random.seed(seed)
19 | np.random.seed(seed)
20 | torch.manual_seed(seed)
21 | os.environ["PYTHONHASHSEED"] = str(seed)
22 | torch.cuda.manual_seed(seed)
23 | torch.cuda.manual_seed_all(seed)
24 |
25 | if torch_deterministic:
26 | # refer to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
27 | os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
28 | torch.backends.cudnn.benchmark = False
29 | torch.backends.cudnn.deterministic = True
30 | torch.use_deterministic_algorithms(True)
31 | else:
32 | torch.backends.cudnn.benchmark = True
33 | torch.backends.cudnn.deterministic = False
34 |
35 | return seed
36 |
--------------------------------------------------------------------------------
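A short sketch of `set_seed` above; the values are illustrative:

```python
from brs_algo.utils.random_seed import set_seed  # assumed import path

set_seed(42)                                   # seed python, numpy, and torch with 42
set_seed(-1)                                   # -1 draws a random seed in [0, 10000)
set_seed(0, torch_deterministic=True, rank=1)  # per-rank offset plus deterministic CUDA/cuDNN settings
```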
/brs_algo/utils/shape_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Shape inference methods
3 | """
4 |
5 | import math
6 | import numpy as np
7 | import torch
8 | import warnings
9 | from functools import partial
10 |
11 | from typing import List, Tuple, Union
12 |
13 | try:
14 | import tree
15 | except ImportError:
16 | pass
17 |
18 |
19 | # fmt: off
20 | __all__ = [
21 | "shape_convnd",
22 | "shape_conv1d", "shape_conv2d", "shape_conv3d",
23 | "shape_transpose_convnd",
24 | "shape_transpose_conv1d", "shape_transpose_conv2d", "shape_transpose_conv3d",
25 | "shape_poolnd",
26 | "shape_maxpool1d", "shape_maxpool2d", "shape_maxpool3d",
27 | "shape_avgpool1d", "shape_avgpool2d", "shape_avgpool3d",
28 | "shape_slice",
29 | "check_shape"
30 | ]
31 | # fmt: on
32 |
33 |
34 | def _get_shape(x):
35 | "single object"
36 | if isinstance(x, np.ndarray):
37 | return tuple(x.shape)
38 | else:
39 | return tuple(x.size())
40 |
41 |
42 | def _expands(dim, *xs):
43 | "repeat vars like kernel and stride to match dim"
44 |
45 | def _expand(x):
46 | if isinstance(x, int):
47 | return (x,) * dim
48 | else:
49 | assert len(x) == dim
50 | return x
51 |
52 | return map(lambda x: _expand(x), xs)
53 |
54 |
55 | _HELPER_TENSOR = torch.zeros((1,))
56 |
57 |
58 | def shape_slice(input_shape, slice):
59 | """
60 | Credit to Adam Paszke for the trick. Shape inference without instantiating
61 | an actual tensor.
62 |     The key is that `.expand()` does not actually allocate memory.
63 |     It still needs to allocate a one-element _HELPER_TENSOR.
64 | """
65 | shape = _HELPER_TENSOR.expand(*input_shape)[slice]
66 | if hasattr(shape, "size"):
67 | return tuple(shape.size())
68 | return (1,)
69 |
70 |
71 | class ShapeSlice:
72 | """
73 | shape_slice inference with easy []-operator
74 | """
75 |
76 | def __init__(self, input_shape):
77 | self.input_shape = input_shape
78 |
79 | def __getitem__(self, slice):
80 | return shape_slice(self.input_shape, slice)
81 |
82 |
83 | def check_shape(
84 | value: Union[Tuple, List, torch.Tensor, np.ndarray],
85 | expected: Union[Tuple, List, torch.Tensor, np.ndarray],
86 | err_msg="",
87 | mode="raise",
88 | ):
89 | """
90 | Args:
91 | value: np array or torch Tensor
92 | expected:
93 | - list[int], tuple[int]: if any value is None, will match any dim
94 | - np array or torch Tensor: must have the same dimensions
95 | mode:
96 |         - "raise": raise ValueError on shape mismatch
97 | - "return": returns True if shape matches, otherwise False
98 | - "warning": warnings.warn
99 | """
100 | assert mode in ["raise", "return", "warning"]
101 | if torch.is_tensor(value):
102 | actual_shape = value.size()
103 | elif hasattr(value, "shape"):
104 | actual_shape = value.shape
105 | else:
106 | assert isinstance(value, (list, tuple))
107 | actual_shape = value
108 | assert all(
109 | isinstance(s, int) for s in actual_shape
110 | ), f"actual shape: {actual_shape} is not a list of ints"
111 |
112 | if torch.is_tensor(expected):
113 | expected_shape = expected.size()
114 | elif hasattr(expected, "shape"):
115 | expected_shape = expected.shape
116 | else:
117 | assert isinstance(expected, (list, tuple))
118 | expected_shape = expected
119 |
120 | err_msg = f" for {err_msg}" if err_msg else ""
121 |
122 | if len(actual_shape) != len(expected_shape):
123 | err_msg = (
124 | f"Dimension mismatch{err_msg}: actual shape {actual_shape} "
125 | f"!= expected shape {expected_shape}."
126 | )
127 | if mode == "raise":
128 | raise ValueError(err_msg)
129 | elif mode == "warning":
130 | warnings.warn(err_msg)
131 | return False
132 |
133 | for s_a, s_e in zip(actual_shape, expected_shape):
134 | if s_e is not None and s_a != s_e:
135 | err_msg = (
136 | f"Shape mismatch{err_msg}: actual shape {actual_shape} "
137 | f"!= expected shape {expected_shape}."
138 | )
139 | if mode == "raise":
140 | raise ValueError(err_msg)
141 | elif mode == "warning":
142 | warnings.warn(err_msg)
143 | return False
144 | return True
145 |
146 |
147 | def shape_convnd(
148 | dim,
149 | input_shape,
150 | out_channels,
151 | kernel_size,
152 | stride=1,
153 | padding=0,
154 | dilation=1,
155 | has_batch=False,
156 | ):
157 | """
158 | http://pytorch.org/docs/nn.html#conv1d
159 | http://pytorch.org/docs/nn.html#conv2d
160 | http://pytorch.org/docs/nn.html#conv3d
161 |
162 | Args:
163 | dim: supports 1D to 3D
164 | input_shape:
165 | - 1D: [channel, length]
166 | - 2D: [channel, height, width]
167 | - 3D: [channel, depth, height, width]
168 | has_batch: whether the first dim is batch size or not
169 | """
170 | if has_batch:
171 | assert (
172 | len(input_shape) == dim + 2
173 | ), "input shape with batch should be {}-dimensional".format(dim + 2)
174 | else:
175 | assert (
176 | len(input_shape) == dim + 1
177 | ), "input shape without batch should be {}-dimensional".format(dim + 1)
178 | if stride is None:
179 | # for pooling convention in PyTorch
180 | stride = kernel_size
181 | kernel_size, stride, padding, dilation = _expands(
182 | dim, kernel_size, stride, padding, dilation
183 | )
184 | if has_batch:
185 | batch = input_shape[0]
186 | input_shape = input_shape[1:]
187 | else:
188 | batch = None
189 | _, *img = input_shape
190 | new_img_shape = [
191 | math.floor(
192 | (img[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1)
193 | // stride[i]
194 | + 1
195 | )
196 | for i in range(dim)
197 | ]
198 | return ((batch,) if has_batch else ()) + (out_channels, *new_img_shape)
199 |
200 |
201 | def shape_poolnd(
202 | dim, input_shape, kernel_size, stride=None, padding=0, dilation=1, has_batch=False
203 | ):
204 | """
205 |     The only difference from `shape_convnd` is that the `stride` default is None
206 | """
207 | if has_batch:
208 | out_channels = input_shape[1]
209 | else:
210 | out_channels = input_shape[0]
211 | return shape_convnd(
212 | dim,
213 | input_shape,
214 | out_channels,
215 | kernel_size,
216 | stride,
217 | padding,
218 | dilation,
219 | has_batch,
220 | )
221 |
222 |
223 | def shape_transpose_convnd(
224 | dim,
225 | input_shape,
226 | out_channels,
227 | kernel_size,
228 | stride=1,
229 | padding=0,
230 | output_padding=0,
231 | dilation=1,
232 | has_batch=False,
233 | ):
234 | """
235 | http://pytorch.org/docs/nn.html#convtranspose1d
236 | http://pytorch.org/docs/nn.html#convtranspose2d
237 | http://pytorch.org/docs/nn.html#convtranspose3d
238 |
239 | Args:
240 | dim: supports 1D to 3D
241 | input_shape:
242 | - 1D: [channel, length]
243 | - 2D: [channel, height, width]
244 | - 3D: [channel, depth, height, width]
245 | has_batch: whether the first dim is batch size or not
246 | """
247 | if has_batch:
248 | assert (
249 | len(input_shape) == dim + 2
250 | ), "input shape with batch should be {}-dimensional".format(dim + 2)
251 | else:
252 | assert (
253 | len(input_shape) == dim + 1
254 | ), "input shape without batch should be {}-dimensional".format(dim + 1)
255 | kernel_size, stride, padding, output_padding, dilation = _expands(
256 | dim, kernel_size, stride, padding, output_padding, dilation
257 | )
258 | if has_batch:
259 | batch = input_shape[0]
260 | input_shape = input_shape[1:]
261 | else:
262 | batch = None
263 | _, *img = input_shape
264 | new_img_shape = [
265 | (img[i] - 1) * stride[i] - 2 * padding[i] + kernel_size[i] + output_padding[i]
266 | for i in range(dim)
267 | ]
268 | return ((batch,) if has_batch else ()) + (out_channels, *new_img_shape)
269 |
270 |
271 | shape_conv1d = partial(shape_convnd, 1)
272 | shape_conv2d = partial(shape_convnd, 2)
273 | shape_conv3d = partial(shape_convnd, 3)
274 |
275 |
276 | shape_transpose_conv1d = partial(shape_transpose_convnd, 1)
277 | shape_transpose_conv2d = partial(shape_transpose_convnd, 2)
278 | shape_transpose_conv3d = partial(shape_transpose_convnd, 3)
279 |
280 |
281 | shape_maxpool1d = partial(shape_poolnd, 1)
282 | shape_maxpool2d = partial(shape_poolnd, 2)
283 | shape_maxpool3d = partial(shape_poolnd, 3)
284 |
285 |
286 | """
287 | http://pytorch.org/docs/nn.html#avgpool1d
288 | http://pytorch.org/docs/nn.html#avgpool2d
289 | http://pytorch.org/docs/nn.html#avgpool3d
290 | """
291 | shape_avgpool1d = partial(shape_maxpool1d, dilation=1)
292 | shape_avgpool2d = partial(shape_maxpool2d, dilation=1)
293 | shape_avgpool3d = partial(shape_maxpool3d, dilation=1)
294 |
--------------------------------------------------------------------------------
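A quick sketch of the shape-inference helpers above; the expected output follows directly from the formula in `shape_convnd`:

```python
import numpy as np
from brs_algo.utils.shape_utils import check_shape, shape_conv2d  # assumed import path

x = np.zeros((3, 224, 224))
check_shape(x, (3, None, None))  # None matches any size along that dim -> True

# Infer the output shape of a 7x7, stride-2, padding-3 conv without building the layer.
out = shape_conv2d((3, 224, 224), out_channels=64, kernel_size=7, stride=2, padding=3)
assert out == (64, 112, 112)
```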
/brs_algo/utils/termcolor.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # Copyright (c) 2008-2011 Volvox Development Team
3 | #
4 | # Permission is hereby granted, free of charge, to any person obtaining a copy
5 | # of this software and associated documentation files (the "Software"), to deal
6 | # in the Software without restriction, including without limitation the rights
7 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | # copies of the Software, and to permit persons to whom the Software is
9 | # furnished to do so, subject to the following conditions:
10 | #
11 | # The above copyright notice and this permission notice shall be included in
12 | # all copies or substantial portions of the Software.
13 | #
14 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 | # THE SOFTWARE.
21 | #
22 | # Original Author: Konstantin Lepa
23 | # Updated by Jim Fan
24 |
25 | """ANSI color formatting for output in terminal."""
26 | import io
27 | import os
28 | from typing import Union, Optional, List
29 |
30 |
31 | __all__ = ["color_text", "cprint"]
32 |
33 | STYLES = dict(
34 | list(
35 | zip(
36 | ["bold", "dark", "", "underline", "blink", "", "reverse", "concealed"],
37 | list(range(1, 9)),
38 | )
39 | )
40 | )
41 | del STYLES[""]
42 |
43 |
44 | HIGHLIGHTS = dict(
45 | list(
46 | zip(
47 | ["grey", "red", "green", "yellow", "blue", "magenta", "cyan", "white"],
48 | list(range(40, 48)),
49 | )
50 | )
51 | )
52 |
53 |
54 | COLORS = dict(
55 | list(
56 | zip(
57 | ["grey", "red", "green", "yellow", "blue", "magenta", "cyan", "white"],
58 | list(range(30, 38)),
59 | )
60 | )
61 | )
62 |
63 |
64 | def _strip_bg_prefix(color):
65 | "on_red -> red"
66 | if color.startswith("on_"):
67 | return color[len("on_") :]
68 | else:
69 | return color
70 |
71 |
72 | RESET = "\033[0m"
73 |
74 |
75 | def color_text(
76 | text,
77 | color: Optional[str] = None,
78 | bg_color: Optional[str] = None,
79 | styles: Optional[Union[str, List[str]]] = None,
80 | ):
81 | """Colorize text.
82 |
83 | Available text colors:
84 | red, green, yellow, blue, magenta, cyan, white.
85 |
86 | Available text highlights:
87 | on_red, on_green, on_yellow, on_blue, on_magenta, on_cyan, on_white.
88 |
89 | Available attributes:
90 | bold, dark, underline, blink, reverse, concealed.
91 |
92 | Example:
93 |         color_text('Hello, World!', color='red', bg_color='on_grey', styles=['bold', 'blink'])
94 |         color_text('Hello, World!', color='green')
95 | """
96 | if os.getenv("ANSI_COLORS_DISABLED") is None:
97 | fmt_str = "\033[%dm%s"
98 | if color is not None:
99 | text = fmt_str % (COLORS[color], text)
100 |
101 | if bg_color is not None:
102 | bg_color = _strip_bg_prefix(bg_color)
103 | text = fmt_str % (HIGHLIGHTS[bg_color], text)
104 |
105 | if styles is not None:
106 | if isinstance(styles, str):
107 | styles = [styles]
108 | for style in styles:
109 | text = fmt_str % (STYLES[style], text)
110 |
111 | text += RESET
112 | return text
113 |
114 |
115 | def cprint(
116 | *args,
117 | color: Optional[str] = None,
118 | bg_color: Optional[str] = None,
119 | styles: Optional[Union[str, List[str]]] = None,
120 | **kwargs,
121 | ):
122 | """Print colorize text.
123 |
124 | It accepts arguments of print function.
125 | """
126 | sstream = io.StringIO()
127 | print(*args, sep=kwargs.pop("sep", None), end="", file=sstream)
128 | text = sstream.getvalue()
129 | print((color_text(text, color, bg_color, styles)), **kwargs)
130 |
131 |
132 | if __name__ == "__main__":
133 | print("Current terminal type: %s" % os.getenv("TERM"))
134 | print("Test basic colors:")
135 | cprint("Grey color", color="grey")
136 | cprint("Red color", color="red")
137 | cprint("Green color", color="green")
138 | cprint("Yellow color", color="yellow")
139 | cprint("Blue color", color="blue")
140 | cprint("Magenta color", color="magenta")
141 | cprint("Cyan color", color="cyan")
142 | cprint("White color", color="white")
143 | print(("-" * 78))
144 |
145 | print("Test highlights:")
146 | cprint("On grey color", bg_color="on_grey")
147 | cprint("On red color", bg_color="on_red")
148 | cprint("On green color", bg_color="on_green")
149 | cprint("On yellow color", bg_color="on_yellow")
150 | cprint("On blue color", bg_color="on_blue")
151 | cprint("On magenta color", bg_color="on_magenta")
152 | cprint("On cyan color", bg_color="on_cyan")
153 | cprint("On white color", color="grey", bg_color="on_white")
154 | print("-" * 78)
155 |
156 | print("Test attributes:")
157 | cprint("Bold grey color", color="grey", styles="bold")
158 | cprint("Dark red color", color="red", styles=["dark"])
159 | cprint("Underline green color", color="green", styles=["underline"])
160 | cprint("Blink yellow color", color="yellow", styles=["blink"])
161 | cprint("Reversed blue color", color="blue", styles=["reverse"])
162 | cprint("Concealed Magenta color", color="magenta", styles=["concealed"])
163 | cprint(
164 | "Bold underline reverse cyan color",
165 | color="cyan",
166 | styles=["bold", "underline", "reverse"],
167 | )
168 | cprint(
169 | "Dark blink concealed white color",
170 | color="white",
171 | styles=["dark", "blink", "concealed"],
172 | )
173 | print(("-" * 78))
174 |
175 | print("Test mixing:")
176 | cprint(
177 | "Underline red on grey color",
178 | color="red",
179 | bg_color="on_grey",
180 | styles="underline",
181 | )
182 | cprint(
183 | "Reversed green on red color",
184 | color="green",
185 | bg_color="on_red",
186 | styles="reverse",
187 | )
188 |
--------------------------------------------------------------------------------
/brs_algo/utils/tree_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utils to handle nested data structures
3 |
4 | Install dm_tree first:
5 | https://tree.readthedocs.io/en/latest/api.html
6 | """
7 |
8 | import collections
9 | import numpy as np
10 | from typing import Iterable, List, TypeVar, Any, Tuple
11 |
12 |
13 | try:
14 | import tree
15 |
16 | except ImportError:
17 | raise ImportError("Please install dm_tree first: `pip install dm_tree`")
18 |
19 |
20 | def is_sequence(obj):
21 | """
22 | Returns:
23 |         True if the object is a collections.abc.Sequence and not a string.
24 | """
25 | return isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str)
26 |
27 |
28 | def is_mapping(obj):
29 | """
30 | Returns:
31 | True if the sequence is a collections.Mapping
32 | """
33 | return isinstance(obj, collections.abc.Mapping)
34 |
35 |
36 | def tree_value_at_path(obj, paths: Tuple):
37 | try:
38 | for p in paths:
39 | obj = obj[p]
40 | return obj
41 | except Exception as e:
42 | raise ValueError(f"{e}\n\n-- Incorrect nested path {paths} for object: {obj}.")
43 |
44 |
45 | def tree_assign_at_path(obj, paths: Tuple, value):
46 | try:
47 | for p in paths[:-1]:
48 | obj = obj[p]
49 | if len(paths) > 0:
50 | obj[paths[-1]] = value
51 | except Exception as e:
52 | raise ValueError(f"{e}\n\n-- Incorrect nested path {paths} for object: {obj}.")
53 |
54 |
55 | def copy_non_leaf(obj):
56 | """
57 | Deepcopy the nested structure, but does NOT copy the leaf values like Tensors
58 | """
59 | return tree.map_structure(lambda x: x, obj)
60 |
61 |
62 | # =======================================================================
63 | # Copyright 2018 DeepMind Technologies Limited. All rights reserved.
64 | #
65 | # Licensed under the Apache License, Version 2.0 (the "License");
66 | # you may not use this file except in compliance with the License.
67 | # You may obtain a copy of the License at
68 | #
69 | # http://www.apache.org/licenses/LICENSE-2.0
70 | #
71 | # Unless required by applicable law or agreed to in writing, software
72 | # distributed under the License is distributed on an "AS IS" BASIS,
73 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
74 | # See the License for the specific language governing permissions and
75 | # limitations under the License.
76 |
77 | # Tensor framework-agnostic utilities for manipulating nested structures.
78 |
79 | ElementType = TypeVar("ElementType")
80 |
81 |
82 | def fast_map_structure(func, *structure):
83 | """Faster map_structure implementation which skips some error checking."""
84 | flat_structure = (tree.flatten(s) for s in structure)
85 | entries = zip(*flat_structure)
86 | # Arbitrarily choose one of the structures of the original sequence (the last)
87 | # to match the structure for the flattened sequence.
88 | return tree.unflatten_as(structure[-1], [func(*x) for x in entries])
89 |
90 |
91 | def stack_sequence_fields(sequence: Iterable[ElementType]) -> ElementType:
92 | """Stacks a list of identically nested objects.
93 |
94 | This takes a sequence of identically nested objects and returns a single
95 | nested object whose ith leaf is a stacked numpy array of the corresponding
96 | ith leaf from each element of the sequence.
97 |
98 | For example, if `sequence` is:
99 |
100 | ```python
101 | [{
102 | 'action': np.array([1.0]),
103 | 'observation': (np.array([0.0, 1.0, 2.0]),),
104 | 'reward': 1.0
105 | }, {
106 | 'action': np.array([0.5]),
107 | 'observation': (np.array([1.0, 2.0, 3.0]),),
108 | 'reward': 0.0
109 | }, {
110 |         'action': np.array([0.3]),
111 | 'observation': (np.array([2.0, 3.0, 4.0]),),
112 | 'reward': 0.5
113 | }]
114 | ```
115 |
116 | Then this function will return:
117 |
118 | ```python
119 | {
120 | 'action': np.array([....]) # array shape = [3 x 1]
121 | 'observation': (np.array([...]),) # array shape = [3 x 3]
122 | 'reward': np.array([...]) # array shape = [3]
123 | }
124 | ```
125 |
126 | Note that the 'observation' entry in the above example has two levels of
127 | nesting, i.e it is a tuple of arrays.
128 |
129 | Args:
130 | sequence: a list of identically nested objects.
131 |
132 | Returns:
133 |         A nested object with numpy arrays as leaf objects.
134 |
135 | Raises:
136 | ValueError: If `sequence` is an empty sequence.
137 | """
138 | # Handle empty input sequences.
139 | if not sequence:
140 | raise ValueError("Input sequence must not be empty")
141 |
142 | # Default to asarray when arrays don't have the same shape to be compatible
143 | # with old behaviour.
144 | try:
145 | return fast_map_structure(lambda *values: np.stack(values), *sequence)
146 | except ValueError:
147 | return fast_map_structure(lambda *values: np.asarray(values), *sequence)
148 |
149 |
150 | def unstack_sequence_fields(struct: ElementType, batch_size: int) -> List[ElementType]:
151 | """Converts a struct of batched arrays to a list of structs.
152 |
153 | This is effectively the inverse of `stack_sequence_fields`.
154 |
155 | Args:
156 | struct: An (arbitrarily nested) structure of arrays.
157 | batch_size: The length of the leading dimension of each array in the struct.
158 | This is assumed to be static and known.
159 |
160 | Returns:
161 | A list of structs with the same structure as `struct`, where each leaf node
162 | is an unbatched element of the original leaf node.
163 | """
164 |
165 | return [tree.map_structure(lambda s, i=i: s[i], struct) for i in range(batch_size)]
166 |
167 |
168 | def broadcast_structures(*args: Any) -> Any:
169 | """Returns versions of the arguments that give them the same nested structure.
170 |
171 | Any nested items in *args must have the same structure.
172 |
173 | Any non-nested item will be replaced with a nested version that shares that
174 | structure. The leaves will all be references to the same original non-nested
175 | item.
176 |
177 | If all *args are nested, or all *args are non-nested, this function will
178 | return *args unchanged.
179 |
180 | Example:
181 | ```
182 | a = ('a', 'b')
183 | b = 'c'
184 |     tree_a, tree_b = broadcast_structures(a, b)
185 | tree_a
186 | > ('a', 'b')
187 | tree_b
188 | > ('c', 'c')
189 | ```
190 |
191 | Args:
192 | *args: A Sequence of nested or non-nested items.
193 |
194 | Returns:
195 | `*args`, except with all items sharing the same nest structure.
196 | """
197 | if not args:
198 | return
199 |
200 | reference_tree = None
201 | for arg in args:
202 | if tree.is_nested(arg):
203 | reference_tree = arg
204 | break
205 |
206 | if reference_tree is None:
207 | reference_tree = args[0]
208 |
209 | def mirror_structure(value, reference_tree):
210 | if tree.is_nested(value):
211 | # Use check_types=True so that the types of the trees we construct aren't
212 | # dependent on our arbitrary choice of which nested arg to use as the
213 | # reference_tree.
214 | tree.assert_same_structure(value, reference_tree, check_types=True)
215 | return value
216 | else:
217 | return tree.map_structure(lambda _: value, reference_tree)
218 |
219 | return tuple(mirror_structure(arg, reference_tree) for arg in args)
220 |
--------------------------------------------------------------------------------
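A minimal sketch of how the batching helpers above compose, mirroring the docstring example:

```python
import numpy as np
from brs_algo.utils.tree_utils import (  # assumed import path
    stack_sequence_fields,
    unstack_sequence_fields,
)

steps = [
    {"action": np.array([1.0]), "reward": 1.0},
    {"action": np.array([0.5]), "reward": 0.0},
]
batch = stack_sequence_fields(steps)
# batch["action"].shape == (2, 1), batch["reward"].shape == (2,)
per_step = unstack_sequence_fields(batch, batch_size=2)
# per_step is again a list of two per-step dicts
```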
/main/rollout/clean_house_after_a_wild_party/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | INITIAL_QPOS = {
5 | "torso": np.array([0.01170956, -0.06480368, -0.06577059, -0.03678382]),
6 | "left_arm": np.array(
7 | [1.47320244, 2.82283323, -2.39968242, 0.04618586, -0.04400188, -0.02203692]
8 | ),
9 | "right_arm": np.array(
10 | [-1.54000939, 2.75047247, -2.27132666, -0.0316771, 0.05366865, 0.02300375]
11 | ),
12 | }
13 |
14 |
15 | GRIPPER_CLOSE_STROKE = 1.0
16 | GRIPPER_HALF_WIDTH = 50
17 | NUM_PCD_POINTS = 4096
18 | PAD_PCD_IF_LESS = False
19 | PCD_X_RANGE = (0.0, 2.0)
20 | PCD_Y_RANGE = (-1.0, 1.0)
21 | PCD_Z_RANGE = (-0.5, 1.6)
22 | MOBILE_BASE_VEL_ACTION_MIN = (-0.3, -0.3, -0.3)
23 | MOBILE_BASE_VEL_ACTION_MAX = (0.3, 0.3, 0.3)
24 | HORIZON_STEPS = 2100 * 4
25 | CONTROL_FREQ = 100
26 | ACTION_REPEAT = 12
27 |
--------------------------------------------------------------------------------
/main/rollout/clean_house_after_a_wild_party/rollout_async.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from common import (
4 | INITIAL_QPOS,
5 | GRIPPER_CLOSE_STROKE,
6 | NUM_PCD_POINTS,
7 | PAD_PCD_IF_LESS,
8 | PCD_X_RANGE,
9 | PCD_Y_RANGE,
10 | PCD_Z_RANGE,
11 | MOBILE_BASE_VEL_ACTION_MIN,
12 | MOBILE_BASE_VEL_ACTION_MAX,
13 | GRIPPER_HALF_WIDTH,
14 | HORIZON_STEPS,
15 | ACTION_REPEAT,
16 | )
17 | import numpy as np
18 | import torch
19 | from brs_ctrl.robot_interface import R1Interface
20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper
21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
22 | import brs_algo.utils as U
23 | from brs_algo.learning.policy import WBVIMAPolicy
24 | from brs_algo.rollout.asynced_rollout import R1AsyncedRollout
25 |
26 | DEVICE = torch.device("cuda:0")
27 | NUM_LATEST_OBS = 2
28 | HORIZON = 16
29 | T_action_prediction = 8
30 |
31 |
32 | def rollout(args):
33 | robot = R1Interface(
34 | left_gripper=GalaxeaR1G1Gripper(
35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE
36 | ),
37 | right_gripper=GalaxeaR1G1Gripper(
38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE
39 | ),
40 | enable_rgbd=False,
41 | enable_pointcloud=True,
42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX),
43 | control_freq=65,
44 | wait_for_first_odom_msg=False,
45 | )
46 |
47 | policy = WBVIMAPolicy(
48 | prop_dim=21,
49 | prop_keys=[
50 | "odom/base_velocity",
51 | "qpos/torso",
52 | "qpos/left_arm",
53 | "qpos/left_gripper",
54 | "qpos/right_arm",
55 | "qpos/right_gripper",
56 | ],
57 | prop_mlp_hidden_depth=2,
58 | prop_mlp_hidden_dim=256,
59 | pointnet_n_coordinates=3,
60 | pointnet_n_color=3,
61 | pointnet_hidden_depth=2,
62 | pointnet_hidden_dim=256,
63 | action_keys=[
64 | "mobile_base",
65 | "torso",
66 | "left_arm",
67 | "left_gripper",
68 | "right_arm",
69 | "right_gripper",
70 | ],
71 | action_key_dims={
72 | "mobile_base": 3,
73 | "torso": 4,
74 | "left_arm": 6,
75 | "left_gripper": 1,
76 | "right_arm": 6,
77 | "right_gripper": 1,
78 | },
79 | num_latest_obs=NUM_LATEST_OBS,
80 | use_modality_type_tokens=False,
81 | xf_n_embd=256,
82 | xf_n_layer=2,
83 | xf_n_head=8,
84 | xf_dropout_rate=0.1,
85 | xf_use_geglu=True,
86 | learnable_action_readout_token=False,
87 | action_dim=21,
88 | action_prediction_horizon=T_action_prediction,
89 | diffusion_step_embed_dim=128,
90 | unet_down_dims=[64, 128],
91 | unet_kernel_size=5,
92 | unet_n_groups=8,
93 | unet_cond_predict_scale=True,
94 | noise_scheduler=DDIMScheduler(
95 | num_train_timesteps=100,
96 | beta_start=0.0001,
97 | beta_end=0.02,
98 | beta_schedule="squaredcos_cap_v2",
99 | clip_sample=True,
100 | set_alpha_to_one=True,
101 | steps_offset=0,
102 | prediction_type="epsilon",
103 | ),
104 | noise_scheduler_step_kwargs=None,
105 | num_denoise_steps_per_inference=16,
106 | )
107 | U.load_state_dict(
108 | policy,
109 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"],
110 | strip_prefix="policy.",
111 | strict=True,
112 | )
113 | policy = policy.to(DEVICE)
114 | policy.eval()
115 |
116 | rollout = R1AsyncedRollout(
117 | robot_interface=robot,
118 | num_pcd_points=NUM_PCD_POINTS,
119 | pcd_x_range=PCD_X_RANGE,
120 | pcd_y_range=PCD_Y_RANGE,
121 | pcd_z_range=PCD_Z_RANGE,
122 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN,
123 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX,
124 | gripper_half_width=GRIPPER_HALF_WIDTH,
125 | num_latest_obs=NUM_LATEST_OBS,
126 | num_deployed_actions=T_action_prediction,
127 | device=DEVICE,
128 | action_execute_start_idx=args.action_execute_start_idx,
129 | policy=policy,
130 | horizon_steps=HORIZON_STEPS,
131 | pad_pcd_if_needed=PAD_PCD_IF_LESS,
132 | action_repeat=ACTION_REPEAT,
133 | )
134 |
135 | input("Press [ENTER] to reset robot to initial qpos")
136 | # reset robot to initial qpos
137 | robot.control(
138 | arm_cmd={
139 | "left": INITIAL_QPOS["left_arm"],
140 | "right": INITIAL_QPOS["right_arm"],
141 | },
142 | gripper_cmd={
143 | "left": 0.1,
144 | "right": 0.1,
145 | },
146 | torso_cmd=INITIAL_QPOS["torso"],
147 | )
148 |
149 | input("Press [ENTER] to start rollout")
150 | for i in range(3):
151 | print(3 - i)
152 | time.sleep(1)
153 | rollout.rollout()
154 |
155 |
156 | if __name__ == "__main__":
157 | args = argparse.ArgumentParser()
158 | args.add_argument("--ckpt_path", type=str, required=True)
159 | args.add_argument("--action_execute_start_idx", type=int, default=1)
160 | args = args.parse_args()
161 | rollout(args)
162 |
--------------------------------------------------------------------------------
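For reference, based on the argument parser at the end of the script above, an async rollout is launched with something along the lines of `python main/rollout/clean_house_after_a_wild_party/rollout_async.py --ckpt_path /path/to/checkpoint.ckpt --action_execute_start_idx 1` (the checkpoint path is a placeholder). The synced variant that follows accepts `--ckpt_path` plus an optional `--pause_mode` flag instead of `--action_execute_start_idx`.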
/main/rollout/clean_house_after_a_wild_party/rollout_sync.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from common import (
4 | INITIAL_QPOS,
5 | GRIPPER_CLOSE_STROKE,
6 | NUM_PCD_POINTS,
7 | PCD_X_RANGE,
8 | PCD_Y_RANGE,
9 | PCD_Z_RANGE,
10 | MOBILE_BASE_VEL_ACTION_MIN,
11 | MOBILE_BASE_VEL_ACTION_MAX,
12 | GRIPPER_HALF_WIDTH,
13 | HORIZON_STEPS,
14 | CONTROL_FREQ,
15 | )
16 | import numpy as np
17 | import torch
18 | from brs_ctrl.robot_interface import R1Interface
19 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper
20 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
21 | import brs_algo.utils as U
22 | from brs_algo.learning.policy import WBVIMAPolicy
23 | from brs_algo.rollout.synced_rollout import R1SyncedRollout
24 |
25 | DEVICE = torch.device("cuda:0")
26 | NUM_LATEST_OBS = 2
27 | HORIZON = 16
28 | T_action_prediction = 8
29 |
30 |
31 | def rollout(args):
32 | robot = R1Interface(
33 | left_gripper=GalaxeaR1G1Gripper(
34 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE
35 | ),
36 | right_gripper=GalaxeaR1G1Gripper(
37 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE
38 | ),
39 | enable_rgbd=False,
40 | enable_pointcloud=True,
41 |         mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX),
42 | )
43 |
44 | policy = WBVIMAPolicy(
45 | prop_dim=21,
46 | prop_keys=[
47 | "odom/base_velocity",
48 | "qpos/torso",
49 | "qpos/left_arm",
50 | "qpos/left_gripper",
51 | "qpos/right_arm",
52 | "qpos/right_gripper",
53 | ],
54 | prop_mlp_hidden_depth=2,
55 | prop_mlp_hidden_dim=256,
56 | pointnet_n_coordinates=3,
57 | pointnet_n_color=3,
58 | pointnet_hidden_depth=2,
59 | pointnet_hidden_dim=256,
60 | action_keys=[
61 | "mobile_base",
62 | "torso",
63 | "left_arm",
64 | "left_gripper",
65 | "right_arm",
66 | "right_gripper",
67 | ],
68 | action_key_dims={
69 | "mobile_base": 3,
70 | "torso": 4,
71 | "left_arm": 6,
72 | "left_gripper": 1,
73 | "right_arm": 6,
74 | "right_gripper": 1,
75 | },
76 | num_latest_obs=NUM_LATEST_OBS,
77 | use_modality_type_tokens=False,
78 | xf_n_embd=256,
79 | xf_n_layer=2,
80 | xf_n_head=8,
81 | xf_dropout_rate=0.1,
82 | xf_use_geglu=True,
83 | learnable_action_readout_token=False,
84 | action_dim=21,
85 | action_prediction_horizon=T_action_prediction,
86 | diffusion_step_embed_dim=128,
87 | unet_down_dims=[64, 128],
88 | unet_kernel_size=5,
89 | unet_n_groups=8,
90 | unet_cond_predict_scale=True,
91 | noise_scheduler=DDIMScheduler(
92 | num_train_timesteps=100,
93 | beta_start=0.0001,
94 | beta_end=0.02,
95 | beta_schedule="squaredcos_cap_v2",
96 | clip_sample=True,
97 | set_alpha_to_one=True,
98 | steps_offset=0,
99 | prediction_type="epsilon",
100 | ),
101 | noise_scheduler_step_kwargs=None,
102 | num_denoise_steps_per_inference=16,
103 | )
104 | U.load_state_dict(
105 | policy,
106 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"],
107 | strip_prefix="policy.",
108 | strict=True,
109 | )
110 | policy = policy.to(DEVICE)
111 | policy.eval()
112 |
113 | rollout = R1SyncedRollout(
114 | robot_interface=robot,
115 | num_pcd_points=NUM_PCD_POINTS,
116 | pcd_x_range=PCD_X_RANGE,
117 | pcd_y_range=PCD_Y_RANGE,
118 | pcd_z_range=PCD_Z_RANGE,
119 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN,
120 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX,
121 | gripper_half_width=GRIPPER_HALF_WIDTH,
122 | num_latest_obs=NUM_LATEST_OBS,
123 | num_deployed_actions=T_action_prediction,
124 | device=DEVICE,
125 | policy=policy,
126 | horizon_steps=HORIZON_STEPS,
127 | pause_mode=args.pause_mode,
128 | control_freq=CONTROL_FREQ,
129 | )
130 |
131 | input("Press [ENTER] to reset robot to initial qpos")
132 | # reset robot to initial qpos
133 | robot.control(
134 | arm_cmd={
135 | "left": INITIAL_QPOS["left_arm"],
136 | "right": INITIAL_QPOS["right_arm"],
137 | },
138 | gripper_cmd={
139 | "left": 0.1,
140 | "right": 0.1,
141 | },
142 | torso_cmd=INITIAL_QPOS["torso"],
143 | )
144 |
145 | input("Press [ENTER] to start rollout")
146 | for i in range(3):
147 | print(3 - i)
148 | time.sleep(1)
149 | rollout.rollout()
150 |
151 |
152 | if __name__ == "__main__":
153 | args = argparse.ArgumentParser()
154 | args.add_argument("--ckpt_path", type=str, required=True)
155 | args.add_argument("--pause_mode", action="store_true")
156 | args = args.parse_args()
157 | rollout(args)
158 |
--------------------------------------------------------------------------------
/main/rollout/clean_the_toilet/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | INITIAL_QPOS = {
5 | "torso": np.array([0.04507549, -0.15606569, -0.37281176, 0.02517941]),
6 | "left_arm": np.array(
7 | [1.51612224, 2.83734251, -2.39497914, 0.06206091, -0.07555695, 0.00982895]
8 | ),
9 | "right_arm": np.array(
10 | [-1.503665, 2.61041302, -2.01116604, 0.03022945, 0.11115978, 0.0534272]
11 | ),
12 | }
13 |
14 |
15 | GRIPPER_CLOSE_STROKE = 1.0
16 | GRIPPER_HALF_WIDTH = 50
17 | NUM_PCD_POINTS = 4096
18 | PAD_PCD_IF_LESS = True
19 | PCD_X_RANGE = (0.0, 1.0)
20 | PCD_Y_RANGE = (-0.5, 0.5)
21 | PCD_Z_RANGE = (0.0, 1.0)
22 | MOBILE_BASE_VEL_ACTION_MIN = (-0.17, -0.15, -0.2)
23 | MOBILE_BASE_VEL_ACTION_MAX = (0.17, 0.15, 0.2)
24 | HORIZON_STEPS = 1200 * 4
25 | CONTROL_FREQ = 100
26 | ACTION_REPEAT = 12
27 |
--------------------------------------------------------------------------------
/main/rollout/clean_the_toilet/rollout_async.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from common import (
4 | INITIAL_QPOS,
5 | GRIPPER_CLOSE_STROKE,
6 | NUM_PCD_POINTS,
7 | PAD_PCD_IF_LESS,
8 | PCD_X_RANGE,
9 | PCD_Y_RANGE,
10 | PCD_Z_RANGE,
11 | MOBILE_BASE_VEL_ACTION_MIN,
12 | MOBILE_BASE_VEL_ACTION_MAX,
13 | GRIPPER_HALF_WIDTH,
14 | HORIZON_STEPS,
15 | ACTION_REPEAT,
16 | )
17 | import numpy as np
18 | import torch
19 | from brs_ctrl.robot_interface import R1Interface
20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper
21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
22 | import brs_algo.utils as U
23 | from brs_algo.learning.policy import WBVIMAPolicy
24 | from brs_algo.rollout.asynced_rollout import R1AsyncedRollout
25 |
26 | DEVICE = torch.device("cuda:0")
27 | NUM_LATEST_OBS = 2
28 | HORIZON = 16
29 | T_action_prediction = 8
30 |
31 |
32 | def rollout(args):
33 | robot = R1Interface(
34 | left_gripper=GalaxeaR1G1Gripper(
35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE
36 | ),
37 | right_gripper=GalaxeaR1G1Gripper(
38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE
39 | ),
40 | enable_rgbd=False,
41 | enable_pointcloud=True,
42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX),
43 | control_freq=65,
44 | )
45 |
46 | policy = WBVIMAPolicy(
47 | prop_dim=21,
48 | prop_keys=[
49 | "odom/base_velocity",
50 | "qpos/torso",
51 | "qpos/left_arm",
52 | "qpos/left_gripper",
53 | "qpos/right_arm",
54 | "qpos/right_gripper",
55 | ],
56 | prop_mlp_hidden_depth=2,
57 | prop_mlp_hidden_dim=256,
58 | pointnet_n_coordinates=3,
59 | pointnet_n_color=3,
60 | pointnet_hidden_depth=2,
61 | pointnet_hidden_dim=256,
62 | action_keys=[
63 | "mobile_base",
64 | "torso",
65 | "left_arm",
66 | "left_gripper",
67 | "right_arm",
68 | "right_gripper",
69 | ],
70 | action_key_dims={
71 | "mobile_base": 3,
72 | "torso": 4,
73 | "left_arm": 6,
74 | "left_gripper": 1,
75 | "right_arm": 6,
76 | "right_gripper": 1,
77 | },
78 | num_latest_obs=NUM_LATEST_OBS,
79 | use_modality_type_tokens=False,
80 | xf_n_embd=256,
81 | xf_n_layer=2,
82 | xf_n_head=8,
83 | xf_dropout_rate=0.1,
84 | xf_use_geglu=True,
85 | learnable_action_readout_token=False,
86 | action_dim=21,
87 | action_prediction_horizon=T_action_prediction,
88 | diffusion_step_embed_dim=128,
89 | unet_down_dims=[64, 128],
90 | unet_kernel_size=5,
91 | unet_n_groups=8,
92 | unet_cond_predict_scale=True,
93 | noise_scheduler=DDIMScheduler(
94 | num_train_timesteps=100,
95 | beta_start=0.0001,
96 | beta_end=0.02,
97 | beta_schedule="squaredcos_cap_v2",
98 | clip_sample=True,
99 | set_alpha_to_one=True,
100 | steps_offset=0,
101 | prediction_type="epsilon",
102 | ),
103 | noise_scheduler_step_kwargs=None,
104 | num_denoise_steps_per_inference=16,
105 | )
106 | U.load_state_dict(
107 | policy,
108 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"],
109 | strip_prefix="policy.",
110 | strict=True,
111 | )
112 | policy = policy.to(DEVICE)
113 | policy.eval()
114 |
115 | rollout = R1AsyncedRollout(
116 | robot_interface=robot,
117 | num_pcd_points=NUM_PCD_POINTS,
118 | pcd_x_range=PCD_X_RANGE,
119 | pcd_y_range=PCD_Y_RANGE,
120 | pcd_z_range=PCD_Z_RANGE,
121 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN,
122 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX,
123 | gripper_half_width=GRIPPER_HALF_WIDTH,
124 | num_latest_obs=NUM_LATEST_OBS,
125 | num_deployed_actions=T_action_prediction,
126 | device=DEVICE,
127 | action_execute_start_idx=args.action_execute_start_idx,
128 | policy=policy,
129 | horizon_steps=HORIZON_STEPS,
130 | pad_pcd_if_needed=PAD_PCD_IF_LESS,
131 | action_repeat=ACTION_REPEAT,
132 | )
133 |
134 | input("Press [ENTER] to reset robot to initial qpos")
135 | # reset robot to initial qpos
136 | robot.control(
137 | arm_cmd={
138 | "left": INITIAL_QPOS["left_arm"],
139 | "right": INITIAL_QPOS["right_arm"],
140 | },
141 | gripper_cmd={
142 | "left": 0.1,
143 | "right": 0.1,
144 | },
145 | torso_cmd=INITIAL_QPOS["torso"],
146 | )
147 |
148 | input("Press [ENTER] to start rollout")
149 | for i in range(3):
150 | print(3 - i)
151 | time.sleep(1)
152 | rollout.rollout()
153 |
154 |
155 | if __name__ == "__main__":
156 | args = argparse.ArgumentParser()
157 | args.add_argument("--ckpt_path", type=str, required=True)
158 | args.add_argument("--action_execute_start_idx", type=int, default=1)
159 | args = args.parse_args()
160 | rollout(args)
161 |
--------------------------------------------------------------------------------
/main/rollout/clean_the_toilet/rollout_sync.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from common import (
4 | INITIAL_QPOS,
5 | GRIPPER_CLOSE_STROKE,
6 | NUM_PCD_POINTS,
7 | PAD_PCD_IF_LESS,
8 | PCD_X_RANGE,
9 | PCD_Y_RANGE,
10 | PCD_Z_RANGE,
11 | MOBILE_BASE_VEL_ACTION_MIN,
12 | MOBILE_BASE_VEL_ACTION_MAX,
13 | GRIPPER_HALF_WIDTH,
14 | HORIZON_STEPS,
15 | CONTROL_FREQ,
16 | )
17 | import numpy as np
18 | import torch
19 | from brs_ctrl.robot_interface import R1Interface
20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper
21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
22 | import brs_algo.utils as U
23 | from brs_algo.learning.policy import WBVIMAPolicy
24 | from brs_algo.rollout.synced_rollout import R1SyncedRollout
25 |
26 | DEVICE = torch.device("cuda:0")
27 | NUM_LATEST_OBS = 2
28 | HORIZON = 16
29 | T_action_prediction = 8
30 |
31 |
32 | def rollout(args):
33 | robot = R1Interface(
34 | left_gripper=GalaxeaR1G1Gripper(
35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE
36 | ),
37 | right_gripper=GalaxeaR1G1Gripper(
38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE
39 | ),
40 | enable_rgbd=False,
41 | enable_pointcloud=True,
42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX),
43 | )
44 |
45 | policy = WBVIMAPolicy(
46 | prop_dim=21,
47 | prop_keys=[
48 | "odom/base_velocity",
49 | "qpos/torso",
50 | "qpos/left_arm",
51 | "qpos/left_gripper",
52 | "qpos/right_arm",
53 | "qpos/right_gripper",
54 | ],
55 | prop_mlp_hidden_depth=2,
56 | prop_mlp_hidden_dim=256,
57 | pointnet_n_coordinates=3,
58 | pointnet_n_color=3,
59 | pointnet_hidden_depth=2,
60 | pointnet_hidden_dim=256,
61 | action_keys=[
62 | "mobile_base",
63 | "torso",
64 | "left_arm",
65 | "left_gripper",
66 | "right_arm",
67 | "right_gripper",
68 | ],
69 | action_key_dims={
70 | "mobile_base": 3,
71 | "torso": 4,
72 | "left_arm": 6,
73 | "left_gripper": 1,
74 | "right_arm": 6,
75 | "right_gripper": 1,
76 | },
77 | num_latest_obs=NUM_LATEST_OBS,
78 | use_modality_type_tokens=False,
79 | xf_n_embd=256,
80 | xf_n_layer=2,
81 | xf_n_head=8,
82 | xf_dropout_rate=0.1,
83 | xf_use_geglu=True,
84 | learnable_action_readout_token=False,
85 | action_dim=21,
86 | action_prediction_horizon=T_action_prediction,
87 | diffusion_step_embed_dim=128,
88 | unet_down_dims=[64, 128],
89 | unet_kernel_size=5,
90 | unet_n_groups=8,
91 | unet_cond_predict_scale=True,
92 | noise_scheduler=DDIMScheduler(
93 | num_train_timesteps=100,
94 | beta_start=0.0001,
95 | beta_end=0.02,
96 | beta_schedule="squaredcos_cap_v2",
97 | clip_sample=True,
98 | set_alpha_to_one=True,
99 | steps_offset=0,
100 | prediction_type="epsilon",
101 | ),
102 | noise_scheduler_step_kwargs=None,
103 | num_denoise_steps_per_inference=16,
104 | )
105 | U.load_state_dict(
106 | policy,
107 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"],
108 | strip_prefix="policy.",
109 | strict=True,
110 | )
111 | policy = policy.to(DEVICE)
112 | policy.eval()
113 |
114 | rollout = R1SyncedRollout(
115 | robot_interface=robot,
116 | num_pcd_points=NUM_PCD_POINTS,
117 | pcd_x_range=PCD_X_RANGE,
118 | pcd_y_range=PCD_Y_RANGE,
119 | pcd_z_range=PCD_Z_RANGE,
120 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN,
121 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX,
122 | gripper_half_width=GRIPPER_HALF_WIDTH,
123 | num_latest_obs=NUM_LATEST_OBS,
124 | num_deployed_actions=T_action_prediction,
125 | device=DEVICE,
126 | policy=policy,
127 | horizon_steps=HORIZON_STEPS,
128 | pause_mode=args.pause_mode,
129 | control_freq=CONTROL_FREQ,
130 | pad_pcd_if_needed=PAD_PCD_IF_LESS,
131 | )
132 |
133 | input("Press [ENTER] to reset robot to initial qpos")
134 | # reset robot to initial qpos
135 | robot.control(
136 | arm_cmd={
137 | "left": INITIAL_QPOS["left_arm"],
138 | "right": INITIAL_QPOS["right_arm"],
139 | },
140 | gripper_cmd={
141 | "left": 0.1,
142 | "right": 0.1,
143 | },
144 | torso_cmd=INITIAL_QPOS["torso"],
145 | )
146 |
147 | input("Press [ENTER] to start rollout")
148 | for i in range(3):
149 | print(3 - i)
150 | time.sleep(1)
151 | rollout.rollout()
152 |
153 |
154 | if __name__ == "__main__":
155 | args = argparse.ArgumentParser()
156 | args.add_argument("--ckpt_path", type=str, required=True)
157 | args.add_argument("--pause_mode", action="store_true")
158 | args = args.parse_args()
159 | rollout(args)
160 |
--------------------------------------------------------------------------------
/main/rollout/lay_clothes_out/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | INITIAL_QPOS = {
5 | "torso": np.array([-0.00018469, -0.02381224, -0.16378878, -0.00182143]),
6 | "left_arm": np.array(
7 | [1.54583152, 2.78541468, -2.3352779, 0.11783977, -0.13222753, -0.0264264]
8 | ),
9 | "right_arm": np.array(
10 | [-1.56691272, 2.74406209, -2.24469822, -0.13880591, 0.03090534, 0.08950716]
11 | ),
12 | }
13 |
14 |
15 | GRIPPER_CLOSE_STROKE = 0.5
16 | GRIPPER_HALF_WIDTH = 50
17 | NUM_PCD_POINTS = 4096
18 | PAD_PCD_IF_LESS = True
19 | PCD_X_RANGE = (0.0, 2.3)
20 | PCD_Y_RANGE = (-0.5, 0.5)
21 | PCD_Z_RANGE = (-0.3, 2.0)
22 | MOBILE_BASE_VEL_ACTION_MIN = (-0.25, -0.25, -0.3)
23 | MOBILE_BASE_VEL_ACTION_MAX = (0.25, 0.25, 0.3)
24 | HORIZON_STEPS = 1200 * 4
25 | CONTROL_FREQ = 100
26 | ACTION_REPEAT = 12
27 |
--------------------------------------------------------------------------------
/main/rollout/lay_clothes_out/rollout_async.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from common import (
4 | INITIAL_QPOS,
5 | GRIPPER_CLOSE_STROKE,
6 | NUM_PCD_POINTS,
7 | PAD_PCD_IF_LESS,
8 | PCD_X_RANGE,
9 | PCD_Y_RANGE,
10 | PCD_Z_RANGE,
11 | MOBILE_BASE_VEL_ACTION_MIN,
12 | MOBILE_BASE_VEL_ACTION_MAX,
13 | GRIPPER_HALF_WIDTH,
14 | HORIZON_STEPS,
15 | ACTION_REPEAT,
16 | )
17 | import numpy as np
18 | import torch
19 | from brs_ctrl.robot_interface import R1Interface
20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper
21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
22 | import brs_algo.utils as U
23 | from brs_algo.learning.policy import WBVIMAPolicy
24 | from brs_algo.rollout.asynced_rollout import R1AsyncedRollout
25 |
26 | DEVICE = torch.device("cuda:0")
27 | NUM_LATEST_OBS = 2
28 | HORIZON = 16
29 | T_action_prediction = 8
30 |
31 |
32 | def rollout(args):
33 | robot = R1Interface(
34 | left_gripper=GalaxeaR1G1Gripper(
35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE
36 | ),
37 | right_gripper=GalaxeaR1G1Gripper(
38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE
39 | ),
40 | enable_rgbd=False,
41 | enable_pointcloud=True,
42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX),
43 | control_freq=65,
44 | )
45 |
46 | policy = WBVIMAPolicy(
47 | prop_dim=21,
48 | prop_keys=[
49 | "odom/base_velocity",
50 | "qpos/torso",
51 | "qpos/left_arm",
52 | "qpos/left_gripper",
53 | "qpos/right_arm",
54 | "qpos/right_gripper",
55 | ],
56 | prop_mlp_hidden_depth=2,
57 | prop_mlp_hidden_dim=256,
58 | pointnet_n_coordinates=3,
59 | pointnet_n_color=3,
60 | pointnet_hidden_depth=2,
61 | pointnet_hidden_dim=256,
62 | action_keys=[
63 | "mobile_base",
64 | "torso",
65 | "left_arm",
66 | "left_gripper",
67 | "right_arm",
68 | "right_gripper",
69 | ],
70 | action_key_dims={
71 | "mobile_base": 3,
72 | "torso": 4,
73 | "left_arm": 6,
74 | "left_gripper": 1,
75 | "right_arm": 6,
76 | "right_gripper": 1,
77 | },
78 | num_latest_obs=NUM_LATEST_OBS,
79 | use_modality_type_tokens=False,
80 | xf_n_embd=256,
81 | xf_n_layer=2,
82 | xf_n_head=8,
83 | xf_dropout_rate=0.1,
84 | xf_use_geglu=True,
85 | learnable_action_readout_token=False,
86 | action_dim=21,
87 | action_prediction_horizon=T_action_prediction,
88 | diffusion_step_embed_dim=128,
89 | unet_down_dims=[64, 128],
90 | unet_kernel_size=5,
91 | unet_n_groups=8,
92 | unet_cond_predict_scale=True,
93 | noise_scheduler=DDIMScheduler(
94 | num_train_timesteps=100,
95 | beta_start=0.0001,
96 | beta_end=0.02,
97 | beta_schedule="squaredcos_cap_v2",
98 | clip_sample=True,
99 | set_alpha_to_one=True,
100 | steps_offset=0,
101 | prediction_type="epsilon",
102 | ),
103 | noise_scheduler_step_kwargs=None,
104 | num_denoise_steps_per_inference=16,
105 | )
106 | U.load_state_dict(
107 | policy,
108 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"],
109 | strip_prefix="policy.",
110 | strict=True,
111 | )
112 | policy = policy.to(DEVICE)
113 | policy.eval()
114 |
115 | rollout = R1AsyncedRollout(
116 | robot_interface=robot,
117 | num_pcd_points=NUM_PCD_POINTS,
118 | pcd_x_range=PCD_X_RANGE,
119 | pcd_y_range=PCD_Y_RANGE,
120 | pcd_z_range=PCD_Z_RANGE,
121 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN,
122 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX,
123 | gripper_half_width=GRIPPER_HALF_WIDTH,
124 | num_latest_obs=NUM_LATEST_OBS,
125 | num_deployed_actions=T_action_prediction,
126 | device=DEVICE,
127 | action_execute_start_idx=args.action_execute_start_idx,
128 | policy=policy,
129 | horizon_steps=HORIZON_STEPS,
130 | pad_pcd_if_needed=PAD_PCD_IF_LESS,
131 | action_repeat=ACTION_REPEAT,
132 | )
133 |
134 | input("Press [ENTER] to reset robot to initial qpos")
135 | # reset robot to initial qpos
136 | robot.control(
137 | arm_cmd={
138 | "left": INITIAL_QPOS["left_arm"],
139 | "right": INITIAL_QPOS["right_arm"],
140 | },
141 | gripper_cmd={
142 | "left": 0.1,
143 | "right": 0.1,
144 | },
145 | torso_cmd=INITIAL_QPOS["torso"],
146 | )
147 |
148 | input("Press [ENTER] to start rollout")
149 | for i in range(3):
150 | print(3 - i)
151 | time.sleep(1)
152 | rollout.rollout()
153 |
154 |
155 | if __name__ == "__main__":
156 | args = argparse.ArgumentParser()
157 | args.add_argument("--ckpt_path", type=str, required=True)
158 | args.add_argument("--action_execute_start_idx", type=int, default=1)
159 | args = args.parse_args()
160 | rollout(args)
161 |
--------------------------------------------------------------------------------
/main/rollout/lay_clothes_out/rollout_sync.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from common import (
4 | INITIAL_QPOS,
5 | GRIPPER_CLOSE_STROKE,
6 | NUM_PCD_POINTS,
7 | PAD_PCD_IF_LESS,
8 | PCD_X_RANGE,
9 | PCD_Y_RANGE,
10 | PCD_Z_RANGE,
11 | MOBILE_BASE_VEL_ACTION_MIN,
12 | MOBILE_BASE_VEL_ACTION_MAX,
13 | GRIPPER_HALF_WIDTH,
14 | HORIZON_STEPS,
15 | CONTROL_FREQ,
16 | )
17 | import numpy as np
18 | import torch
19 | from brs_ctrl.robot_interface import R1Interface
20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper
21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
22 | import brs_algo.utils as U
23 | from brs_algo.learning.policy import WBVIMAPolicy
24 | from brs_algo.rollout.synced_rollout import R1SyncedRollout
25 |
26 | DEVICE = torch.device("cuda:0")
27 | NUM_LATEST_OBS = 2
28 | HORIZON = 16
29 | T_action_prediction = 8
30 |
31 |
32 | def rollout(args):
33 | robot = R1Interface(
34 | left_gripper=GalaxeaR1G1Gripper(
35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE
36 | ),
37 | right_gripper=GalaxeaR1G1Gripper(
38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE
39 | ),
40 | enable_rgbd=False,
41 | enable_pointcloud=True,
42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX),
43 | )
44 |
45 | policy = WBVIMAPolicy(
46 | prop_dim=21,
47 | prop_keys=[
48 | "odom/base_velocity",
49 | "qpos/torso",
50 | "qpos/left_arm",
51 | "qpos/left_gripper",
52 | "qpos/right_arm",
53 | "qpos/right_gripper",
54 | ],
55 | prop_mlp_hidden_depth=2,
56 | prop_mlp_hidden_dim=256,
57 | pointnet_n_coordinates=3,
58 | pointnet_n_color=3,
59 | pointnet_hidden_depth=2,
60 | pointnet_hidden_dim=256,
61 | action_keys=[
62 | "mobile_base",
63 | "torso",
64 | "left_arm",
65 | "left_gripper",
66 | "right_arm",
67 | "right_gripper",
68 | ],
69 | action_key_dims={
70 | "mobile_base": 3,
71 | "torso": 4,
72 | "left_arm": 6,
73 | "left_gripper": 1,
74 | "right_arm": 6,
75 | "right_gripper": 1,
76 | },
77 | num_latest_obs=NUM_LATEST_OBS,
78 | use_modality_type_tokens=False,
79 | xf_n_embd=256,
80 | xf_n_layer=2,
81 | xf_n_head=8,
82 | xf_dropout_rate=0.1,
83 | xf_use_geglu=True,
84 | learnable_action_readout_token=False,
85 | action_dim=21,
86 | action_prediction_horizon=T_action_prediction,
87 | diffusion_step_embed_dim=128,
88 | unet_down_dims=[64, 128],
89 | unet_kernel_size=5,
90 | unet_n_groups=8,
91 | unet_cond_predict_scale=True,
92 | noise_scheduler=DDIMScheduler(
93 | num_train_timesteps=100,
94 | beta_start=0.0001,
95 | beta_end=0.02,
96 | beta_schedule="squaredcos_cap_v2",
97 | clip_sample=True,
98 | set_alpha_to_one=True,
99 | steps_offset=0,
100 | prediction_type="epsilon",
101 | ),
102 | noise_scheduler_step_kwargs=None,
103 | num_denoise_steps_per_inference=16,
104 | )
105 | U.load_state_dict(
106 | policy,
107 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"],
108 | strip_prefix="policy.",
109 | strict=True,
110 | )
111 | policy = policy.to(DEVICE)
112 | policy.eval()
113 |
114 | rollout = R1SyncedRollout(
115 | robot_interface=robot,
116 | num_pcd_points=NUM_PCD_POINTS,
117 | pcd_x_range=PCD_X_RANGE,
118 | pcd_y_range=PCD_Y_RANGE,
119 | pcd_z_range=PCD_Z_RANGE,
120 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN,
121 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX,
122 | gripper_half_width=GRIPPER_HALF_WIDTH,
123 | num_latest_obs=NUM_LATEST_OBS,
124 | num_deployed_actions=T_action_prediction,
125 | device=DEVICE,
126 | policy=policy,
127 | horizon_steps=HORIZON_STEPS,
128 | pause_mode=args.pause_mode,
129 | control_freq=CONTROL_FREQ,
130 | pad_pcd_if_needed=PAD_PCD_IF_LESS,
131 | )
132 |
133 | input("Press [ENTER] to reset robot to initial qpos")
134 | # reset robot to initial qpos
135 | robot.control(
136 | arm_cmd={
137 | "left": INITIAL_QPOS["left_arm"],
138 | "right": INITIAL_QPOS["right_arm"],
139 | },
140 | gripper_cmd={
141 | "left": 0.1,
142 | "right": 0.1,
143 | },
144 | torso_cmd=INITIAL_QPOS["torso"],
145 | )
146 |
147 | input("Press [ENTER] to start rollout")
148 | for i in range(3):
149 | print(3 - i)
150 | time.sleep(1)
151 | rollout.rollout()
152 |
153 |
154 | if __name__ == "__main__":
155 | args = argparse.ArgumentParser()
156 | args.add_argument("--ckpt_path", type=str, required=True)
157 | args.add_argument("--pause_mode", action="store_true")
158 | args = args.parse_args()
159 | rollout(args)
160 |
--------------------------------------------------------------------------------
/main/rollout/put_items_onto_shelves/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | INITIAL_QPOS = {
5 | "torso": np.array([0.015661, -0.073367, -0.342951, -0.027991]),
6 | "left_arm": np.array(
7 | [1.29920638, 2.81973617, -2.41644043, -0.00392979, -0.12388298, -0.04762766]
8 | ),
9 | "right_arm": np.array(
10 | [-1.32845532, 2.83736596, -2.39677021, -0.00901277, 0.11162979, 0.02240851]
11 | ),
12 | }
13 |
14 |
15 | GRIPPER_CLOSE_STROKE = 0.5
16 | GRIPPER_HALF_WIDTH = 50
17 | NUM_PCD_POINTS = 4096
18 | PAD_PCD_IF_LESS = True
19 | PCD_X_RANGE = (0.0, 2.0)
20 | PCD_Y_RANGE = (-0.5, 0.5)
21 | PCD_Z_RANGE = (-0.3, 2.0)
22 | MOBILE_BASE_VEL_ACTION_MIN = (-0.1, -0.15, -0.3)
23 | MOBILE_BASE_VEL_ACTION_MAX = (0.1, 0.15, 0.3)
24 | HORIZON_STEPS = 600 * 4
25 | CONTROL_FREQ = 100
26 | ACTION_REPEAT = 12
27 |
--------------------------------------------------------------------------------
/main/rollout/put_items_onto_shelves/rollout_async.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from common import (
4 | INITIAL_QPOS,
5 | GRIPPER_CLOSE_STROKE,
6 | NUM_PCD_POINTS,
7 | PAD_PCD_IF_LESS,
8 | PCD_X_RANGE,
9 | PCD_Y_RANGE,
10 | PCD_Z_RANGE,
11 | MOBILE_BASE_VEL_ACTION_MIN,
12 | MOBILE_BASE_VEL_ACTION_MAX,
13 | GRIPPER_HALF_WIDTH,
14 | HORIZON_STEPS,
15 | ACTION_REPEAT,
16 | )
17 | import numpy as np
18 | import torch
19 | from brs_ctrl.robot_interface import R1Interface
20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper
21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
22 | import brs_algo.utils as U
23 | from brs_algo.learning.policy import WBVIMAPolicy
24 | from brs_algo.rollout.asynced_rollout import R1AsyncedRollout
25 |
26 | DEVICE = torch.device("cuda:0")
27 | NUM_LATEST_OBS = 2
28 | HORIZON = 16
29 | T_action_prediction = 8
30 |
31 |
32 | def rollout(args):
33 | robot = R1Interface(
34 | left_gripper=GalaxeaR1G1Gripper(
35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE
36 | ),
37 | right_gripper=GalaxeaR1G1Gripper(
38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE
39 | ),
40 | enable_rgbd=False,
41 | enable_pointcloud=True,
42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX),
43 | control_freq=65,
44 | )
45 |
46 | policy = WBVIMAPolicy(
47 | prop_dim=21,
48 | prop_keys=[
49 | "odom/base_velocity",
50 | "qpos/torso",
51 | "qpos/left_arm",
52 | "qpos/left_gripper",
53 | "qpos/right_arm",
54 | "qpos/right_gripper",
55 | ],
56 | prop_mlp_hidden_depth=2,
57 | prop_mlp_hidden_dim=256,
58 | pointnet_n_coordinates=3,
59 | pointnet_n_color=3,
60 | pointnet_hidden_depth=2,
61 | pointnet_hidden_dim=256,
62 | action_keys=[
63 | "mobile_base",
64 | "torso",
65 | "left_arm",
66 | "left_gripper",
67 | "right_arm",
68 | "right_gripper",
69 | ],
70 | action_key_dims={
71 | "mobile_base": 3,
72 | "torso": 4,
73 | "left_arm": 6,
74 | "left_gripper": 1,
75 | "right_arm": 6,
76 | "right_gripper": 1,
77 | },
78 | num_latest_obs=NUM_LATEST_OBS,
79 | use_modality_type_tokens=False,
80 | xf_n_embd=256,
81 | xf_n_layer=2,
82 | xf_n_head=8,
83 | xf_dropout_rate=0.1,
84 | xf_use_geglu=True,
85 | learnable_action_readout_token=False,
86 | action_dim=21,
87 | action_prediction_horizon=T_action_prediction,
88 | diffusion_step_embed_dim=128,
89 | unet_down_dims=[64, 128],
90 | unet_kernel_size=5,
91 | unet_n_groups=8,
92 | unet_cond_predict_scale=True,
93 | noise_scheduler=DDIMScheduler(
94 | num_train_timesteps=100,
95 | beta_start=0.0001,
96 | beta_end=0.02,
97 | beta_schedule="squaredcos_cap_v2",
98 | clip_sample=True,
99 | set_alpha_to_one=True,
100 | steps_offset=0,
101 | prediction_type="epsilon",
102 | ),
103 | noise_scheduler_step_kwargs=None,
104 | num_denoise_steps_per_inference=16,
105 | )
106 | U.load_state_dict(
107 | policy,
108 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"],
109 | strip_prefix="policy.",
110 | strict=True,
111 | )
112 | policy = policy.to(DEVICE)
113 | policy.eval()
114 |
115 | rollout = R1AsyncedRollout(
116 | robot_interface=robot,
117 | num_pcd_points=NUM_PCD_POINTS,
118 | pcd_x_range=PCD_X_RANGE,
119 | pcd_y_range=PCD_Y_RANGE,
120 | pcd_z_range=PCD_Z_RANGE,
121 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN,
122 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX,
123 | gripper_half_width=GRIPPER_HALF_WIDTH,
124 | num_latest_obs=NUM_LATEST_OBS,
125 | num_deployed_actions=T_action_prediction,
126 | device=DEVICE,
127 | action_execute_start_idx=args.action_execute_start_idx,
128 | policy=policy,
129 | horizon_steps=HORIZON_STEPS,
130 | pad_pcd_if_needed=PAD_PCD_IF_LESS,
131 | action_repeat=ACTION_REPEAT,
132 | )
133 |
134 | input("Press [ENTER] to reset robot to initial qpos")
135 | # reset robot to initial qpos
136 | robot.control(
137 | arm_cmd={
138 | "left": INITIAL_QPOS["left_arm"],
139 | "right": INITIAL_QPOS["right_arm"],
140 | },
141 | gripper_cmd={
142 | "left": 0.1,
143 | "right": 0.1,
144 | },
145 | torso_cmd=INITIAL_QPOS["torso"],
146 | )
147 |
148 | input("Press [ENTER] to start rollout")
149 | for i in range(3):
150 | print(3 - i)
151 | time.sleep(1)
152 | rollout.rollout()
153 |
154 |
155 | if __name__ == "__main__":
156 | args = argparse.ArgumentParser()
157 | args.add_argument("--ckpt_path", type=str, required=True)
158 | args.add_argument("--action_execute_start_idx", type=int, default=1)
159 | args = args.parse_args()
160 | rollout(args)
161 |
--------------------------------------------------------------------------------
/main/rollout/put_items_onto_shelves/rollout_sync.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from common import (
4 | INITIAL_QPOS,
5 | GRIPPER_CLOSE_STROKE,
6 | NUM_PCD_POINTS,
7 | PAD_PCD_IF_LESS,
8 | PCD_X_RANGE,
9 | PCD_Y_RANGE,
10 | PCD_Z_RANGE,
11 | MOBILE_BASE_VEL_ACTION_MIN,
12 | MOBILE_BASE_VEL_ACTION_MAX,
13 | GRIPPER_HALF_WIDTH,
14 | HORIZON_STEPS,
15 | CONTROL_FREQ,
16 | )
17 | import numpy as np
18 | import torch
19 | from brs_ctrl.robot_interface import R1Interface
20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper
21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
22 | import brs_algo.utils as U
23 | from brs_algo.learning.policy import WBVIMAPolicy
24 | from brs_algo.rollout.synced_rollout import R1SyncedRollout
25 |
26 | DEVICE = torch.device("cuda:0")
27 | NUM_LATEST_OBS = 2
28 | HORIZON = 16
29 | T_action_prediction = 8
30 |
31 |
32 | def rollout(args):
33 | robot = R1Interface(
34 | left_gripper=GalaxeaR1G1Gripper(
35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE
36 | ),
37 | right_gripper=GalaxeaR1G1Gripper(
38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE
39 | ),
40 | enable_rgbd=False,
41 | enable_pointcloud=True,
42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX),
43 | )
44 |
45 | policy = WBVIMAPolicy(
46 | prop_dim=21,
47 | prop_keys=[
48 | "odom/base_velocity",
49 | "qpos/torso",
50 | "qpos/left_arm",
51 | "qpos/left_gripper",
52 | "qpos/right_arm",
53 | "qpos/right_gripper",
54 | ],
55 | prop_mlp_hidden_depth=2,
56 | prop_mlp_hidden_dim=256,
57 | pointnet_n_coordinates=3,
58 | pointnet_n_color=3,
59 | pointnet_hidden_depth=2,
60 | pointnet_hidden_dim=256,
61 | action_keys=[
62 | "mobile_base",
63 | "torso",
64 | "left_arm",
65 | "left_gripper",
66 | "right_arm",
67 | "right_gripper",
68 | ],
69 | action_key_dims={
70 | "mobile_base": 3,
71 | "torso": 4,
72 | "left_arm": 6,
73 | "left_gripper": 1,
74 | "right_arm": 6,
75 | "right_gripper": 1,
76 | },
77 | num_latest_obs=NUM_LATEST_OBS,
78 | use_modality_type_tokens=False,
79 | xf_n_embd=256,
80 | xf_n_layer=2,
81 | xf_n_head=8,
82 | xf_dropout_rate=0.1,
83 | xf_use_geglu=True,
84 | learnable_action_readout_token=False,
85 | action_dim=21,
86 | action_prediction_horizon=T_action_prediction,
87 | diffusion_step_embed_dim=128,
88 | unet_down_dims=[64, 128],
89 | unet_kernel_size=5,
90 | unet_n_groups=8,
91 | unet_cond_predict_scale=True,
92 | noise_scheduler=DDIMScheduler(
93 | num_train_timesteps=100,
94 | beta_start=0.0001,
95 | beta_end=0.02,
96 | beta_schedule="squaredcos_cap_v2",
97 | clip_sample=True,
98 | set_alpha_to_one=True,
99 | steps_offset=0,
100 | prediction_type="epsilon",
101 | ),
102 | noise_scheduler_step_kwargs=None,
103 | num_denoise_steps_per_inference=16,
104 | )
105 | U.load_state_dict(
106 | policy,
107 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"],
108 | strip_prefix="policy.",
109 | strict=True,
110 | )
111 | policy = policy.to(DEVICE)
112 | policy.eval()
113 |
114 | rollout = R1SyncedRollout(
115 | robot_interface=robot,
116 | num_pcd_points=NUM_PCD_POINTS,
117 | pcd_x_range=PCD_X_RANGE,
118 | pcd_y_range=PCD_Y_RANGE,
119 | pcd_z_range=PCD_Z_RANGE,
120 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN,
121 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX,
122 | gripper_half_width=GRIPPER_HALF_WIDTH,
123 | num_latest_obs=NUM_LATEST_OBS,
124 | num_deployed_actions=T_action_prediction,
125 | device=DEVICE,
126 | policy=policy,
127 | horizon_steps=HORIZON_STEPS,
128 | pause_mode=args.pause_mode,
129 | control_freq=CONTROL_FREQ,
130 | pad_pcd_if_needed=PAD_PCD_IF_LESS,
131 | )
132 |
133 | input("Press [ENTER] to reset robot to initial qpos")
134 | # reset robot to initial qpos
135 | robot.control(
136 | arm_cmd={
137 | "left": INITIAL_QPOS["left_arm"],
138 | "right": INITIAL_QPOS["right_arm"],
139 | },
140 | gripper_cmd={
141 | "left": 0.1,
142 | "right": 0.1,
143 | },
144 | torso_cmd=INITIAL_QPOS["torso"],
145 | )
146 |
147 | input("Press [ENTER] to start rollout")
148 | for i in range(3):
149 | print(3 - i)
150 | time.sleep(1)
151 | rollout.rollout()
152 |
153 |
154 | if __name__ == "__main__":
155 | args = argparse.ArgumentParser()
156 | args.add_argument("--ckpt_path", type=str, required=True)
157 | args.add_argument("--pause_mode", action="store_true")
158 | args = args.parse_args()
159 | rollout(args)
160 |
--------------------------------------------------------------------------------
/main/rollout/take_trash_outside/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | INITIAL_QPOS = {
5 | "torso": np.array([0.028215, -0.07880333, -0.12412, 0.02912167]),
6 | "left_arm": np.array(
7 | [
8 | 1.48734574e00,
9 | 2.94314716e00,
10 | -2.58810284e00,
11 | -1.85106383e-03,
12 | -8.12056738e-04,
13 | 2.91258865e-02,
14 | ]
15 | ),
16 | "right_arm": np.array(
17 | [-1.56693262, 2.84392553, -2.37629965, 0.0035922, 0.03647163, 0.01210816]
18 | ),
19 | }
20 |
21 |
22 | GRIPPER_CLOSE_STROKE = 0.5
23 | GRIPPER_HALF_WIDTH = 50
24 | NUM_PCD_POINTS = 4096
25 | PAD_PCD_IF_LESS = True
26 | PCD_X_RANGE = (0.0, 4.0)
27 | PCD_Y_RANGE = (-1.0, 1.0)
28 | PCD_Z_RANGE = (-0.5, 1.6)
29 | MOBILE_BASE_VEL_ACTION_MIN = (-0.35, -0.35, -0.3)
30 | MOBILE_BASE_VEL_ACTION_MAX = (0.35, 0.35, 0.3)
31 | HORIZON_STEPS = 1300 * 4
32 | CONTROL_FREQ = 100
33 | ACTION_REPEAT = 12
34 |
--------------------------------------------------------------------------------
/main/rollout/take_trash_outside/rollout_async.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from common import (
4 | INITIAL_QPOS,
5 | GRIPPER_CLOSE_STROKE,
6 | NUM_PCD_POINTS,
7 | PAD_PCD_IF_LESS,
8 | PCD_X_RANGE,
9 | PCD_Y_RANGE,
10 | PCD_Z_RANGE,
11 | MOBILE_BASE_VEL_ACTION_MIN,
12 | MOBILE_BASE_VEL_ACTION_MAX,
13 | GRIPPER_HALF_WIDTH,
14 | HORIZON_STEPS,
15 | ACTION_REPEAT,
16 | )
17 | import numpy as np
18 | import torch
19 | from brs_ctrl.robot_interface import R1Interface
20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper
21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
22 | import brs_algo.utils as U
23 | from brs_algo.learning.policy import WBVIMAPolicy
24 | from brs_algo.rollout.asynced_rollout import R1AsyncedRollout
25 |
26 | DEVICE = torch.device("cuda:0")
27 | NUM_LATEST_OBS = 2
28 | HORIZON = 16
29 | T_action_prediction = 8
30 |
31 |
32 | def rollout(args):
33 | robot = R1Interface(
34 | left_gripper=GalaxeaR1G1Gripper(
35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE
36 | ),
37 | right_gripper=GalaxeaR1G1Gripper(
38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE
39 | ),
40 | enable_rgbd=False,
41 | enable_pointcloud=True,
42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX),
43 | control_freq=65,
44 | )
45 |
46 | policy = WBVIMAPolicy(
47 | prop_dim=21,
48 | prop_keys=[
49 | "odom/base_velocity",
50 | "qpos/torso",
51 | "qpos/left_arm",
52 | "qpos/left_gripper",
53 | "qpos/right_arm",
54 | "qpos/right_gripper",
55 | ],
56 | prop_mlp_hidden_depth=2,
57 | prop_mlp_hidden_dim=256,
58 | pointnet_n_coordinates=3,
59 | pointnet_n_color=3,
60 | pointnet_hidden_depth=2,
61 | pointnet_hidden_dim=256,
62 | action_keys=[
63 | "mobile_base",
64 | "torso",
65 | "left_arm",
66 | "left_gripper",
67 | "right_arm",
68 | "right_gripper",
69 | ],
70 | action_key_dims={
71 | "mobile_base": 3,
72 | "torso": 4,
73 | "left_arm": 6,
74 | "left_gripper": 1,
75 | "right_arm": 6,
76 | "right_gripper": 1,
77 | },
78 | num_latest_obs=NUM_LATEST_OBS,
79 | use_modality_type_tokens=False,
80 | xf_n_embd=256,
81 | xf_n_layer=2,
82 | xf_n_head=8,
83 | xf_dropout_rate=0.1,
84 | xf_use_geglu=True,
85 | learnable_action_readout_token=False,
86 | action_dim=21,
87 | action_prediction_horizon=T_action_prediction,
88 | diffusion_step_embed_dim=128,
89 | unet_down_dims=[64, 128],
90 | unet_kernel_size=5,
91 | unet_n_groups=8,
92 | unet_cond_predict_scale=True,
93 | noise_scheduler=DDIMScheduler(
94 | num_train_timesteps=100,
95 | beta_start=0.0001,
96 | beta_end=0.02,
97 | beta_schedule="squaredcos_cap_v2",
98 | clip_sample=True,
99 | set_alpha_to_one=True,
100 | steps_offset=0,
101 | prediction_type="epsilon",
102 | ),
103 | noise_scheduler_step_kwargs=None,
104 | num_denoise_steps_per_inference=16,
105 | )
106 | U.load_state_dict(
107 | policy,
108 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"],
109 | strip_prefix="policy.",
110 | strict=True,
111 | )
112 | policy = policy.to(DEVICE)
113 | policy.eval()
114 |
115 | rollout = R1AsyncedRollout(
116 | robot_interface=robot,
117 | num_pcd_points=NUM_PCD_POINTS,
118 | pcd_x_range=PCD_X_RANGE,
119 | pcd_y_range=PCD_Y_RANGE,
120 | pcd_z_range=PCD_Z_RANGE,
121 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN,
122 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX,
123 | gripper_half_width=GRIPPER_HALF_WIDTH,
124 | num_latest_obs=NUM_LATEST_OBS,
125 | num_deployed_actions=T_action_prediction,
126 | device=DEVICE,
127 | action_execute_start_idx=args.action_execute_start_idx,
128 | policy=policy,
129 | horizon_steps=HORIZON_STEPS,
130 | pad_pcd_if_needed=PAD_PCD_IF_LESS,
131 | action_repeat=ACTION_REPEAT,
132 | )
133 |
134 | input("Press [ENTER] to reset robot to initial qpos")
135 | # reset robot to initial qpos
136 | robot.control(
137 | arm_cmd={
138 | "left": INITIAL_QPOS["left_arm"],
139 | "right": INITIAL_QPOS["right_arm"],
140 | },
141 | gripper_cmd={
142 | "left": 0.1,
143 | "right": 0.1,
144 | },
145 | torso_cmd=INITIAL_QPOS["torso"],
146 | )
147 |
148 | input("Press [ENTER] to start rollout")
149 | for i in range(3):
150 | print(3 - i)
151 | time.sleep(1)
152 | rollout.rollout()
153 |
154 |
155 | if __name__ == "__main__":
156 | args = argparse.ArgumentParser()
157 | args.add_argument("--ckpt_path", type=str, required=True)
158 | args.add_argument("--action_execute_start_idx", type=int, default=1)
159 | args = args.parse_args()
160 | rollout(args)
161 |
--------------------------------------------------------------------------------
/main/rollout/take_trash_outside/rollout_sync.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from common import (
4 | INITIAL_QPOS,
5 | GRIPPER_CLOSE_STROKE,
6 | NUM_PCD_POINTS,
7 | PAD_PCD_IF_LESS,
8 | PCD_X_RANGE,
9 | PCD_Y_RANGE,
10 | PCD_Z_RANGE,
11 | MOBILE_BASE_VEL_ACTION_MIN,
12 | MOBILE_BASE_VEL_ACTION_MAX,
13 | GRIPPER_HALF_WIDTH,
14 | HORIZON_STEPS,
15 | CONTROL_FREQ,
16 | )
17 | import numpy as np
18 | import torch
19 | from brs_ctrl.robot_interface import R1Interface
20 | from brs_ctrl.robot_interface.grippers import GalaxeaR1G1Gripper
21 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler
22 | import brs_algo.utils as U
23 | from brs_algo.learning.policy import WBVIMAPolicy
24 | from brs_algo.rollout.synced_rollout import R1SyncedRollout
25 |
26 | DEVICE = torch.device("cuda:0")
27 | NUM_LATEST_OBS = 2
28 | HORIZON = 16
29 | T_action_prediction = 8
30 |
31 |
32 | def rollout(args):
33 | robot = R1Interface(
34 | left_gripper=GalaxeaR1G1Gripper(
35 | left_or_right="left", gripper_close_stroke=GRIPPER_CLOSE_STROKE
36 | ),
37 | right_gripper=GalaxeaR1G1Gripper(
38 | left_or_right="right", gripper_close_stroke=GRIPPER_CLOSE_STROKE
39 | ),
40 | enable_rgbd=False,
41 | enable_pointcloud=True,
42 | mobile_base_cmd_limit=np.array(MOBILE_BASE_VEL_ACTION_MAX),
43 | )
44 |
45 | policy = WBVIMAPolicy(
46 | prop_dim=21,
47 | prop_keys=[
48 | "odom/base_velocity",
49 | "qpos/torso",
50 | "qpos/left_arm",
51 | "qpos/left_gripper",
52 | "qpos/right_arm",
53 | "qpos/right_gripper",
54 | ],
55 | prop_mlp_hidden_depth=2,
56 | prop_mlp_hidden_dim=256,
57 | pointnet_n_coordinates=3,
58 | pointnet_n_color=3,
59 | pointnet_hidden_depth=2,
60 | pointnet_hidden_dim=256,
61 | action_keys=[
62 | "mobile_base",
63 | "torso",
64 | "left_arm",
65 | "left_gripper",
66 | "right_arm",
67 | "right_gripper",
68 | ],
69 | action_key_dims={
70 | "mobile_base": 3,
71 | "torso": 4,
72 | "left_arm": 6,
73 | "left_gripper": 1,
74 | "right_arm": 6,
75 | "right_gripper": 1,
76 | },
77 | num_latest_obs=NUM_LATEST_OBS,
78 | use_modality_type_tokens=False,
79 | xf_n_embd=256,
80 | xf_n_layer=2,
81 | xf_n_head=8,
82 | xf_dropout_rate=0.1,
83 | xf_use_geglu=True,
84 | learnable_action_readout_token=False,
85 | action_dim=21,
86 | action_prediction_horizon=T_action_prediction,
87 | diffusion_step_embed_dim=128,
88 | unet_down_dims=[64, 128],
89 | unet_kernel_size=5,
90 | unet_n_groups=8,
91 | unet_cond_predict_scale=True,
92 | noise_scheduler=DDIMScheduler(
93 | num_train_timesteps=100,
94 | beta_start=0.0001,
95 | beta_end=0.02,
96 | beta_schedule="squaredcos_cap_v2",
97 | clip_sample=True,
98 | set_alpha_to_one=True,
99 | steps_offset=0,
100 | prediction_type="epsilon",
101 | ),
102 | noise_scheduler_step_kwargs=None,
103 | num_denoise_steps_per_inference=16,
104 | )
105 | U.load_state_dict(
106 | policy,
107 | U.torch_load(args.ckpt_path, map_location="cpu")["state_dict"],
108 | strip_prefix="policy.",
109 | strict=True,
110 | )
111 | policy = policy.to(DEVICE)
112 | policy.eval()
113 |
114 | rollout = R1SyncedRollout(
115 | robot_interface=robot,
116 | num_pcd_points=NUM_PCD_POINTS,
117 | pcd_x_range=PCD_X_RANGE,
118 | pcd_y_range=PCD_Y_RANGE,
119 | pcd_z_range=PCD_Z_RANGE,
120 | mobile_base_vel_action_min=MOBILE_BASE_VEL_ACTION_MIN,
121 | mobile_base_vel_action_max=MOBILE_BASE_VEL_ACTION_MAX,
122 | gripper_half_width=GRIPPER_HALF_WIDTH,
123 | num_latest_obs=NUM_LATEST_OBS,
124 | num_deployed_actions=T_action_prediction,
125 | device=DEVICE,
126 | policy=policy,
127 | horizon_steps=HORIZON_STEPS,
128 | pause_mode=args.pause_mode,
129 | control_freq=CONTROL_FREQ,
130 | pad_pcd_if_needed=PAD_PCD_IF_LESS,
131 | )
132 |
133 | input("Press [ENTER] to reset robot to initial qpos")
134 | # reset robot to initial qpos
135 | robot.control(
136 | arm_cmd={
137 | "left": INITIAL_QPOS["left_arm"],
138 | "right": INITIAL_QPOS["right_arm"],
139 | },
140 | gripper_cmd={
141 | "left": 0.1,
142 | "right": 0.1,
143 | },
144 | torso_cmd=INITIAL_QPOS["torso"],
145 | )
146 |
147 | input("Press [ENTER] to start rollout")
148 | for i in range(3):
149 | print(3 - i)
150 | time.sleep(1)
151 | rollout.rollout()
152 |
153 |
154 | if __name__ == "__main__":
155 | args = argparse.ArgumentParser()
156 | args.add_argument("--ckpt_path", type=str, required=True)
157 | args.add_argument("--pause_mode", action="store_true")
158 | args = args.parse_args()
159 | rollout(args)
160 |
--------------------------------------------------------------------------------
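Note on the rollout scripts above: the WBVIMAPolicy constructor arguments are duplicated across the per-task rollout_async.py / rollout_sync.py files and mirror the training-time settings in main/train/cfg/arch/wbvima.yaml; if the architecture config changes, the deployment-side constructor must be updated to match, otherwise the strict state-dict load fails. Training produces a PyTorch Lightning checkpoint, so the policy weights live under the checkpoint's "state_dict" with a "policy." key prefix, which is why the scripts call U.load_state_dict(..., strip_prefix="policy.", strict=True).
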
/main/train/cfg/arch/wbvima.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | run_name: "${optional_:${prefix}}\
4 | ${arch_name}\
5 | _${task_name}\
6 | _To${num_latest_obs}\
7 | _Ta${action_prediction_horizon}\
8 | _b${bs}\
9 | ${__optional:${suffix}}\
10 | ${_optional:${postsuffix}}"
11 |
12 | arch_name: wbvima
13 |
14 | # ====== DP specific ======
15 | num_latest_obs: 2
16 | action_prediction_horizon: 8
17 |
18 | wd: 0.1
19 |
20 | # ------ module ------
21 | module:
22 | _target_: brs_algo.learning.module.DiffusionModule
23 | policy:
24 | _target_: brs_algo.learning.policy.WBVIMAPolicy
25 | prop_dim: 21
26 | prop_keys: ["odom/base_velocity", "qpos/torso", "qpos/left_arm", "qpos/left_gripper", "qpos/right_arm", "qpos/right_gripper"]
27 | num_latest_obs: ${num_latest_obs}
28 | use_modality_type_tokens: false
29 | prop_mlp_hidden_depth: 2
30 | prop_mlp_hidden_dim: 256
31 | pointnet_n_coordinates: 3
32 | pointnet_n_color: 3
33 | pointnet_hidden_depth: 2
34 | pointnet_hidden_dim: 256
35 | action_keys: ${action_keys}
36 | action_key_dims: ${action_key_dims}
37 | # ====== Transformer ======
38 | xf_n_embd: 256
39 | xf_n_layer: 2
40 | xf_n_head: 8
41 | xf_dropout_rate: 0.1
42 | xf_use_geglu: true
43 | # ====== Action Decoding ======
44 | learnable_action_readout_token: false
45 | action_dim: 21
46 | action_prediction_horizon: ${action_prediction_horizon}
47 | diffusion_step_embed_dim: 128
48 | unet_down_dims: [64,128]
49 | unet_kernel_size: 5
50 | unet_n_groups: 8
51 | unet_cond_predict_scale: true
52 | # ====== diffusion ======
53 | noise_scheduler:
54 | _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler
55 | num_train_timesteps: 100
56 | beta_start: 0.0001
57 | beta_end: 0.02
58 | # beta_schedule is important
59 | # this is the best we found
60 | beta_schedule: squaredcos_cap_v2
61 | clip_sample: True
62 | set_alpha_to_one: True
63 | steps_offset: 0
64 | prediction_type: epsilon # or sample
65 | noise_scheduler_step_kwargs: null
66 | num_denoise_steps_per_inference: 16
67 | action_prediction_horizon: ${action_prediction_horizon}
68 | loss_on_latest_obs_only: false
69 |
70 | data_module:
71 | _target_: brs_algo.learning.data.ActionSeqChunkDataModule
72 | obs_window_size: ${num_latest_obs}
73 | action_prediction_horizon: ${action_prediction_horizon}
74 |
--------------------------------------------------------------------------------
/main/train/cfg/cfg.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_ # all configs listed below override this cfg.yaml
3 | - arch: wbvima
4 | - task: ???
5 |
6 | run_name: "${optional_:${prefix}}\
7 | ${arch_name}\
8 | _${task_name}\
9 | _lr${scientific:${lr},1}\
10 | _wd${scientific:${wd}}\
11 | _b${bs}\
12 | ${__optional:${suffix}}\
13 | ${_optional:${postsuffix}}"
14 | exp_root_dir: ???
15 | arch_name: ??? # filled by arch
16 |
17 | # ====== main cfg ======
18 | seed: -1
19 | gpus: 1
20 | lr: 7e-4
21 | use_cosine_lr: true
22 | lr_warmup_steps: 1000
23 | lr_cosine_steps: 300000
24 | lr_cosine_min: 5e-6
25 | lr_layer_decay: 1.0
26 | wd: 0.0
27 | bs: 256
28 | vbs: ${bs}
29 | data_dir: ???
30 | eval_interval: 10
31 | rollout_eval: false
32 | # ------ logging ------
33 | use_wandb: true
34 | wandb_project: ???
35 | wandb_run_name: ${run_name}
36 |
37 | # ------ common ------
38 | action_keys: ["mobile_base", "torso", "left_arm", "left_gripper", "right_arm", "right_gripper"]
39 | action_key_dims:
40 | mobile_base: 3
41 | torso: 4
42 | left_arm: 6
43 | left_gripper: 1
44 | right_arm: 6
45 | right_gripper: 1
46 |
47 | # ------ module ------
48 | module:
49 | _target_: ??? # filled by arch
50 | # ====== policy ======
51 | policy: ???
52 | # ====== learning ======
53 | lr: ${lr}
54 | use_cosine_lr: ${use_cosine_lr}
55 | lr_warmup_steps: ${lr_warmup_steps}
56 | lr_cosine_steps: ${lr_cosine_steps}
57 | lr_cosine_min: ${lr_cosine_min}
58 | lr_layer_decay: ${lr_layer_decay}
59 | weight_decay: ${wd}
60 | action_keys: ${action_keys}
61 |
62 | data_module:
63 | _target_: ???
64 | data_path: ${data_dir}
65 | pcd_downsample_points: ${pcd_downsample_points}
66 | batch_size: ${bs}
67 | val_batch_size: ${vbs}
68 | val_split_ratio: 0.1
69 | seed: ${seed}
70 | dataloader_num_workers: 4
71 |
72 | trainer:
73 | cls: pytorch_lightning.Trainer
74 | accelerator: "gpu"
75 | devices: ${gpus}
76 | precision: 32
77 | benchmark: true # enables cudnn.benchmark
78 | accumulate_grad_batches: 1
79 | num_sanity_val_steps: 0
80 | max_epochs: 999999999
81 | val_check_interval: null
82 | check_val_every_n_epoch: ${eval_interval}
83 | gradient_clip_val: 1.0
84 | checkpoint: # this sub-dict will be popped to send to ModelCheckpoint as args
85 | - filename: "epoch{epoch}-train_loss{train/loss:.5f}"
86 | save_on_train_epoch_end: true # this is a training metric, so we save it at the end of training epoch
87 | save_top_k: 100
88 | save_last: true
89 | monitor: "train/loss"
90 | mode: min
91 | auto_insert_metric_name: false # prevent creating subfolder caused by the slash
92 | - filename: "epoch{epoch}-val_l1_{val/l1:.5f}"
93 | eval_type: "static"
94 | save_top_k: -1
95 | save_last: true
96 | monitor: "val/l1"
97 | mode: min
98 | auto_insert_metric_name: false # prevent creating subfolder caused by the slash
99 | callbacks:
100 | - cls: LearningRateMonitor
101 | logging_interval: step
102 | - cls: RichModelSummary
103 |
104 | # ------------- Global cfgs for enlight.LightningTrainer ---------------
105 |
106 |
107 | # ------------- Resume training ---------------
108 | resume:
109 | ckpt_path: null
110 | full_state: false # if true, resume all states including optimizer, amp, lightning callbacks
111 | strict: true
112 |
113 | # ------------- Testing ---------------
114 | test:
115 | ckpt_path: null
116 |
117 | # ----------------------------
118 |
119 | prefix:
120 | suffix:
121 | postsuffix:
122 |
123 | hydra:
124 | job:
125 | chdir: true
126 | run:
127 | dir: "."
128 | output_subdir: null
129 |
--------------------------------------------------------------------------------
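In cfg.yaml, the "???" entries (exp_root_dir, data_dir, wandb_project, and the task entry in the defaults list) are mandatory and are expected to be supplied as Hydra overrides, while every "_target_" node is instantiated recursively by Hydra. Below is a minimal sketch (not part of the repo) of composing this config tree and building the policy outside the training entry point, assuming brs_algo and its dependencies are installed and the snippet is run from main/train/; the task override is simply one of the provided task configs.

from hydra import compose, initialize
from hydra.utils import instantiate

# "task" is a mandatory default ("task: ???"), so it must be overridden at compose time.
with initialize(config_path="cfg", version_base="1.1"):
    cfg = compose(config_name="cfg", overrides=["task=put_items_onto_shelves"])

# Recursively builds WBVIMAPolicy, including the nested DDIMScheduler "_target_".
policy = instantiate(cfg.module.policy)

The remaining missing values (data_dir, exp_root_dir, wandb_project) only raise once they are actually accessed, e.g. when the full Trainer is constructed.
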
/main/train/cfg/task/clean_house_after_a_wild_party.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task_name: clean_house_after_a_wild_party
4 | pcd_downsample_points: 4096
5 | data_module:
6 | pcd_x_range: [0.0, 2.0]
7 | pcd_y_range: [-1.0, 1.0]
8 | pcd_z_range: [-0.5, 1.6]
9 | mobile_base_vel_action_min: [-0.3, -0.3, -0.4]
10 | mobile_base_vel_action_max: [0.3, 0.3, 0.4]
11 |
--------------------------------------------------------------------------------
/main/train/cfg/task/clean_the_toilet.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task_name: clean_the_toilet
4 | pcd_downsample_points: 4096
5 | data_module:
6 | pcd_x_range: [0.0, 1.0]
7 | pcd_y_range: [-0.5, 0.5]
8 | pcd_z_range: [0.0, 1.0]
9 | mobile_base_vel_action_min: [-0.3, -0.3, -0.4]
10 | mobile_base_vel_action_max: [0.3, 0.3, 0.4]
11 |
--------------------------------------------------------------------------------
/main/train/cfg/task/lay_clothes_out.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task_name: lay_clothes_out
4 | pcd_downsample_points: 4096
5 | data_module:
6 | pcd_x_range: [0.0, 2.3]
7 | pcd_y_range: [-0.5, 0.5]
8 | pcd_z_range: [-0.3, 2.0]
9 | mobile_base_vel_action_min: [-0.3, -0.3, -0.4]
10 | mobile_base_vel_action_max: [0.3, 0.3, 0.4]
11 |
--------------------------------------------------------------------------------
/main/train/cfg/task/put_items_onto_shelves.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task_name: put_items_onto_shelves
4 | pcd_downsample_points: 4096
5 | data_module:
6 | pcd_x_range: [0.0, 2.0]
7 | pcd_y_range: [-0.5, 0.5]
8 | pcd_z_range: [-0.3, 2.0]
9 | mobile_base_vel_action_min: [-0.3, -0.3, -0.4]
10 | mobile_base_vel_action_max: [0.3, 0.3, 0.4]
11 |
--------------------------------------------------------------------------------
/main/train/cfg/task/take_trash_outside.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | task_name: take_trash_outside
4 | pcd_downsample_points: 4096
5 | data_module:
6 | pcd_x_range: [0.0, 4.0]
7 | pcd_y_range: [-1.0, 1.0]
8 | pcd_z_range: [-0.5, 1.6]
9 | mobile_base_vel_action_min: [-0.3, -0.3, -0.4]
10 | mobile_base_vel_action_max: [0.3, 0.3, 0.4]
11 |
--------------------------------------------------------------------------------
/main/train/train.py:
--------------------------------------------------------------------------------
1 | import hydra
2 |
3 | import brs_algo.utils as U
4 | from brs_algo.lightning import Trainer
5 |
6 |
7 | @hydra.main(config_name="cfg", config_path="cfg", version_base="1.1")
8 | def main(cfg):
9 | cfg.seed = U.set_seed(cfg.seed)
10 | trainer_ = Trainer(cfg)
11 | trainer_.trainer.loggers[-1].log_hyperparams(U.omegaconf_to_dict(cfg))
12 | trainer_.fit()
13 |
14 |
15 | if __name__ == "__main__":
16 | main()
17 |
--------------------------------------------------------------------------------
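train.py is the single Hydra entry point for all tasks and architectures. Since config_path="cfg" is resolved relative to train.py itself, the script can be launched from any working directory; the mandatory fields noted above are passed as overrides, for example something along the lines of: python main/train/train.py task=take_trash_outside data_dir=<path to processed data> exp_root_dir=<output dir> wandb_project=<wandb project>. Optional fields such as gpus, bs, and lr can be overridden the same way.
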
/media/SUSig-red.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/media/SUSig-red.png
--------------------------------------------------------------------------------
/media/pull.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/media/pull.gif
--------------------------------------------------------------------------------
/media/wbvima.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/behavior-robot-suite/brs-algo/aabc71e6c64945feff4d82cb559f07665a67ecb8/media/wbvima.gif
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "brs_algo"
7 | dynamic = ["version"]
8 | dependencies = [
9 | "numpy >= 1.26.4",
10 | "torch >= 2.5.1",
11 | "dm_tree >= 0.1.8",
12 | "pytorch-lightning >= 2.3.0",
13 | "pillow >= 11.0.0",
14 | "h5py >= 3.12.1",
15 | "hydra-core >= 1.3.2",
16 | "diffusers >= 0.26.3",
17 | "huggingface-hub == 0.24.6",
18 | "einops >= 0.8.0",
19 | "omegaconf >= 2.3.0",
20 | "tqdm >= 4.67.1",
21 | "transformers >= 4.42.4",
22 | "tensorboard",
23 | "rich",
24 | "tabulate",
25 | "numpy-quaternion"
26 | ]
27 | requires-python = ">=3.11"
28 | authors = [
29 | {name = "Yunfan Jiang", email = "yunfanj@cs.stanford.edu"},
30 | ]
31 | maintainers = [
32 | {name = "Yunfan Jiang", email = "yunfanj@cs.stanford.edu"},
33 | ]
34 | description = "The algorithm repository for BEHAVIOR-Robot-Suite: Streamlining Real-World Whole-Body Manipulation for Everyday Household Activities"
35 | readme = "README.md"
36 | keywords = ["Robotics", "Machine Learning", "Whole-Body Manipulation", "Mobile Manipulation"]
37 | classifiers = [
38 | "Development Status :: 3 - Alpha",
39 | "Intended Audience :: Science/Research",
40 | "Topic :: Scientific/Engineering :: Robotics",
41 | "Programming Language :: Python :: 3.11",
42 | ]
43 |
44 | [project.urls]
45 | Homepage = "https://behavior-robot-suite.github.io/"
46 |
47 |
48 | [tool.setuptools]
49 | packages = ["brs_algo"]
50 | [tool.setuptools.dynamic]
51 | version = {attr = "brs_algo.__version__"}
--------------------------------------------------------------------------------
/scripts/merge_data_files.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 |
5 | import h5py
6 | from tqdm import tqdm
7 |
8 |
9 | def merge(args):
10 | data_dir = args.data_dir
11 | output_file = args.output_file
12 | assert os.path.exists(data_dir), f"Data directory {data_dir} does not exist"
13 | data_files = [
14 | os.path.join(data_dir, f)
15 | for f in os.listdir(data_dir)
16 | if f.endswith(".h5") or f.endswith(".hdf5")
17 | ]
18 | print(f"[INFO] Found {len(data_files)} data files in {data_dir}")
19 |
20 | # start merging
21 | merged_file = h5py.File(output_file, "w")
22 | # meta data
23 | merged_file.attrs["num_demos"] = len(data_files)
24 | merged_file.attrs["merging_time"] = time.strftime("%Y-%m-%d-%H-%M-%S")
25 | merged_file.attrs["merged_data_files"] = data_files
26 | # merge each demo
27 | for idx, fpath in tqdm(enumerate(data_files), total=len(data_files), desc="Merging data files"):
28 | f_to_merge = h5py.File(fpath, "r")
29 | demo_grp = merged_file.create_group(f"demo_{idx}")
30 | # copy over per-demo metadata
31 | for key in f_to_merge.attrs:
32 | demo_grp.attrs[key] = f_to_merge.attrs[key]
33 | # obs group
34 | obs_grp = demo_grp.create_group("obs")
35 | # joint_state
36 | for ks in f_to_merge["obs/joint_state"]:
37 | for k, v in f_to_merge["obs/joint_state"][ks].items():
38 | obs_grp.create_dataset(f"joint_state/{ks}/{k}", data=v[:])
39 | # gripper_state
40 | for ks in f_to_merge["obs/gripper_state"]:
41 | for k, v in f_to_merge["obs/gripper_state"][ks].items():
42 | obs_grp.create_dataset(f"gripper_state/{ks}/{k}", data=v[:])
43 | # link_poses
44 | for k, v in f_to_merge["obs/link_poses"].items():
45 | obs_grp.create_dataset(f"link_poses/{k}", data=v[:])
46 | # fused point cloud
47 | pcd_xyz = f_to_merge["obs/point_cloud/fused/xyz"][:] # (T, N_points, 3)
48 | pcd_rgb = f_to_merge["obs/point_cloud/fused/rgb"][:]
49 | pcd_padding_mask = f_to_merge["obs/point_cloud/fused/padding_mask"][
50 | :
51 | ] # (T, N_points)
52 | obs_grp.create_dataset("point_cloud/fused/xyz", data=pcd_xyz)
53 | obs_grp.create_dataset("point_cloud/fused/rgb", data=pcd_rgb)
54 | obs_grp.create_dataset("point_cloud/fused/padding_mask", data=pcd_padding_mask)
55 | # odom
56 | odom_base_velocity = f_to_merge["obs/odom/base_velocity"][:]
57 | obs_grp.create_dataset("odom/base_velocity", data=odom_base_velocity)
58 | # multiview cameras
59 | head_camera_rgb = f_to_merge["obs/rgb/head/img"][:]
60 | head_camera_depth = f_to_merge["obs/depth/head/depth"][:]
61 | left_wrist_camera_rgb = f_to_merge["obs/rgb/left_wrist/img"][:]
62 | left_wrist_camera_depth = f_to_merge["obs/depth/left_wrist/depth"][:]
63 | right_wrist_camera_rgb = f_to_merge["obs/rgb/right_wrist/img"][:]
64 | right_wrist_camera_depth = f_to_merge["obs/depth/right_wrist/depth"][:]
65 | obs_grp.create_dataset("rgb/head/img", data=head_camera_rgb)
66 | obs_grp.create_dataset("depth/head/depth", data=head_camera_depth)
67 | obs_grp.create_dataset("rgb/left_wrist/img", data=left_wrist_camera_rgb)
68 | obs_grp.create_dataset("depth/left_wrist/depth", data=left_wrist_camera_depth)
69 | obs_grp.create_dataset("rgb/right_wrist/img", data=right_wrist_camera_rgb)
70 | obs_grp.create_dataset("depth/right_wrist/depth", data=right_wrist_camera_depth)
71 | # action
72 | action_grp = demo_grp.create_group("action")
73 | for k, v in f_to_merge["action"].items():
74 | action_grp.create_dataset(k, data=v[:])
75 | f_to_merge.close()
76 | merged_file.close()
77 |
78 |
79 | if __name__ == "__main__":
80 | args = argparse.ArgumentParser()
81 | args.add_argument(
82 | "--data_dir",
83 | type=str,
84 | required=True,
85 | help="Directory containing individual data files to consolidate.",
86 | )
87 | args.add_argument(
88 | "--output_file",
89 | type=str,
90 | required=True,
91 | help="A single output file to save the consolidated data.",
92 | )
93 | args = args.parse_args()
94 | merge(args)
95 |
--------------------------------------------------------------------------------
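merge_data_files.py consolidates the per-demo HDF5 files in a directory into a single file with one demo_<idx> group per trajectory, each containing the obs/... and action/... datasets copied above. A quick sanity check of the merged output might look like this sketch (the file path is a placeholder):

import h5py

with h5py.File("/path/to/merged_demos.hdf5", "r") as f:
    print("num_demos:", f.attrs["num_demos"])
    demo = f["demo_0"]
    print(demo["obs/point_cloud/fused/xyz"].shape)  # (T, N_points, 3)
    print(demo["obs/odom/base_velocity"].shape)     # (T, 3)
    print(list(demo["action"].keys()))              # per-key action trajectories
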
/scripts/post_process.py:
--------------------------------------------------------------------------------
1 | """
2 | Post-process collected raw data. This includes:
3 | - Compute poses for left/right EEFs and head.
4 | - Process raw odometry data such that it is consistent with R1Interface.
5 | """
6 |
7 | import os
8 | import argparse
9 | from math import ceil
10 |
11 | import h5py
12 | import numpy as np
13 | from tqdm import tqdm
14 | import quaternion
15 |
16 | from brs_ctrl.kinematics import R1Kinematics
17 |
18 |
19 | JMAP = {
20 | "left_wrist": "obs/joint_state/left_arm/joint_position",
21 | "right_wrist": "obs/joint_state/right_arm/joint_position",
22 | "torso": "obs/joint_state/torso/joint_position",
23 | }
24 | PCDKMAP = {
25 | "left_wrist": "obs/point_cloud/left_wrist",
26 | "right_wrist": "obs/point_cloud/right_wrist",
27 | "torso": "obs/point_cloud/head",
28 | }
29 | CMAP = {
30 | "left_wrist": "left_wrist_camera",
31 | "right_wrist": "right_wrist_camera",
32 | "torso": "head_camera",
33 | }
34 |
35 |
36 | def main(args):
37 | data_dir = args.data_dir
38 | chunk_size = args.process_chunk_size
39 | assert os.path.exists(data_dir), f"Data directory {data_dir} does not exist"
40 |
41 | kin = R1Kinematics()
42 | T_odom2base = kin.T_odom2base
43 |
44 | data_files = [
45 | os.path.join(data_dir, f)
46 | for f in os.listdir(data_dir)
47 | if f.endswith(".h5") or f.endswith(".hdf5")
48 | ]
49 | print(f"[INFO] Found {len(data_files)} data files in {data_dir}")
50 | for fpath in tqdm(data_files, desc="Processing data files"):
51 | f = h5py.File(fpath, "r+")
52 | all_qpos = {k: f[v][:] for k, v in JMAP.items()} # (T, N_joints)
53 | raw_odom_linear_vel = f["obs/odom/linear_velocity"][:] # (T, 3)
54 | raw_odom_angular_vel = f["obs/odom/angular_velocity"][:] # (T, 3)
55 | # set small vel values to zero, using the same threshold as in R1Interface
56 | vx_zero_idxs = np.abs(raw_odom_linear_vel[:, 0]) <= 1e-2
57 | vy_zero_idxs = np.abs(raw_odom_linear_vel[:, 1]) <= 1e-2
58 | vyaw_zero_idxs = np.abs(raw_odom_angular_vel[:, 2]) <= 5e-3
59 | raw_odom_linear_vel[vx_zero_idxs, 0] = 0
60 | raw_odom_linear_vel[vy_zero_idxs, 1] = 0
61 | raw_odom_angular_vel[vyaw_zero_idxs, 2] = 0
62 |
63 | T = all_qpos[list(JMAP.keys())[0]].shape[0]
64 | link_poses = {
65 | "left_eef": {
66 | "position": np.zeros((T, 3), dtype=np.float32),
67 | "orientation": np.zeros((T, 3, 3), dtype=np.float32),
68 | },
69 | "right_eef": {
70 | "position": np.zeros((T, 3), dtype=np.float32),
71 | "orientation": np.zeros((T, 3, 3), dtype=np.float32),
72 | },
73 | "head": {
74 | "position": np.zeros((T, 3), dtype=np.float32),
75 | "orientation": np.zeros((T, 3, 3), dtype=np.float32),
76 | },
77 | }
78 | base_vel = np.zeros((T, 3), dtype=np.float32) # (v_x, v_y, v_yaw)
79 |
80 | # process in chunks along time dimension
81 | N_chunks = ceil(T / chunk_size)
82 | for chunk_idx in tqdm(range(N_chunks), desc="Processing chunks"):
83 | start_t = chunk_idx * chunk_size
84 | end_t = min((chunk_idx + 1) * chunk_size, T)
85 | for t in range(start_t, end_t):
86 | link2base = kin.get_link_poses_in_base_link(
87 | curr_left_arm_joint=all_qpos["left_wrist"][t, :6],
88 | curr_right_arm_joint=all_qpos["right_wrist"][t, :6],
89 | curr_torso_joint=all_qpos["torso"][t],
90 | ) # dict of (4, 4)
91 | for k in link_poses:
92 | transform = link2base[k]
93 | link_poses[k]["position"][t] = transform[:3, 3]
94 | link_poses[k]["orientation"][t] = transform[:3, :3]
95 |
96 | raw_odom_linear_vel_chunk = raw_odom_linear_vel[
97 | start_t:end_t
98 | ] # (T_chunk, 3)
99 | base_linear_vel_chunk = (
100 | T_odom2base[:3, :3] @ raw_odom_linear_vel_chunk.T
101 | ).T
102 | base_angular_vel_chunk = raw_odom_angular_vel[start_t:end_t]
103 | base_vel[start_t:end_t, :2] = base_linear_vel_chunk[:, :2]
104 | base_vel[start_t:end_t, 2] = base_angular_vel_chunk[:, 2]
105 |
106 | # save link poses to file
107 | link_pose_grp = f.create_group("obs/link_poses")
108 | for k in link_poses:
109 | rot_mat = link_poses[k]["orientation"]
110 | rot_quat = quaternion.as_float_array(
111 | quaternion.from_rotation_matrix(rot_mat)
112 | ) # (T, 4) in wxyz order
113 | # change to xyzw order since that's what pybullet uses
114 | rot_quat = rot_quat[..., [1, 2, 3, 0]]
115 | pose = np.concatenate(
116 | [link_poses[k]["position"], rot_quat], axis=-1
117 | ) # (T, 7)
118 | link_pose_grp.create_dataset(k, data=pose)
119 | # save base velocity to file
120 | f.create_dataset("obs/odom/base_velocity", data=base_vel)
121 | f.close()
122 |
123 |
124 | if __name__ == "__main__":
125 | args = argparse.ArgumentParser()
126 | args.add_argument("--data_dir", type=str, required=True)
127 | args.add_argument("--process_chunk_size", type=int, default=500)
128 | args = args.parse_args()
129 | main(args)
130 |
--------------------------------------------------------------------------------
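post_process.py is meant to run on the raw demo files before merge_data_files.py: it writes obs/link_poses/* and obs/odom/base_velocity in place, and the merge script assumes both already exist. Each link pose is stored as a (T, 7) array of xyz position followed by an xyzw quaternion, so it can be read back and split as in this sketch (the file path is a placeholder):

import h5py

with h5py.File("/path/to/raw_demo.hdf5", "r") as f:
    left_eef = f["obs/link_poses/left_eef"][:]  # (T, 7)
position, quat_xyzw = left_eef[:, :3], left_eef[:, 3:]
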
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup()
4 |
--------------------------------------------------------------------------------