├── .gitattributes ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── rvt ├── config.py ├── configs │ ├── peract_official_config.yaml │ ├── rvt.yaml │ └── rvt2.yaml ├── eval.py ├── libs │ ├── peract_colab │ │ └── setup.py │ └── point-renderer │ │ ├── .gitattributes │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── README.md │ │ ├── demo.png │ │ ├── image_0_splat_2xaa.png │ │ ├── pcd_data.tar.gz │ │ ├── point_renderer │ │ ├── cameras.py │ │ ├── csrc │ │ │ ├── bindings.cpp │ │ │ └── render │ │ │ │ ├── render_feature_pointcloud.cu │ │ │ │ └── render_pointcloud.h │ │ ├── ops.py │ │ ├── profiler.py │ │ ├── renderer.py │ │ ├── rvt_ops.py │ │ └── rvt_renderer.py │ │ ├── pointcloud-notebook.ipynb │ │ ├── requirements.txt │ │ └── setup.py ├── models │ ├── peract_official.py │ └── rvt_agent.py ├── mvt │ ├── __init__.py │ ├── attn.py │ ├── aug_utils.py │ ├── augmentation.py │ ├── config.py │ ├── configs │ │ ├── rvt2.yaml │ │ └── rvt2_partial.yaml │ ├── mvt.py │ ├── mvt_single.py │ ├── raft_utils.py │ ├── renderer.py │ └── utils.py ├── train.py └── utils │ ├── __init__.py │ ├── custom_rlbench_env.py │ ├── dataset.py │ ├── ddp_utils.py │ ├── get_dataset.py │ ├── lr_sched_utils.py │ ├── peract_utils.py │ ├── rlbench_planning.py │ └── rvt_utils.py └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | rlbench/task_design.ttt filter=lfs diff=lfs merge=lfs -text 6 | *.ttt filter=lfs diff=lfs merge=lfs -text 7 | *.ttm filter=lfs diff=lfs merge=lfs -text 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | *__pycache__/ 6 | rvt/data/train/ 7 | rvt/data/val/ 8 | rvt/data/test/ 9 | rvt/replay/ 10 | coppelia_install_dir/ 11 | rvt/runs/ 12 | runs_ngc_mnt 13 | runs_temp 14 | *.ipynb_checkpoints/ 15 | .vscode/ 16 | build/ 17 | dist/ 18 | *.egg-info/ 19 | 20 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "libs/PyRep"] 2 | path = rvt/libs/PyRep 3 | url = https://github.com/stepjam/PyRep.git 4 | [submodule "libs/RLBench"] 5 | path = rvt/libs/RLBench 6 | url = https://github.com/buttomnutstoast/RLBench.git 7 | [submodule "libs/YARR"] 8 | path = rvt/libs/YARR 9 | url = https://github.com/NVlabs/YARR.git 10 | [submodule "libs/peract"] 11 | path = rvt/libs/peract 12 | url = https://github.com/NVlabs/peract.git 13 | [submodule "libs/peract_colab/peract_colab"] 14 | path = rvt/libs/peract_colab/peract_colab 15 | url = https://github.com/peract/peract_colab.git 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | NVIDIA License 2 | 3 | 1. Definitions 4 | 5 | "Licensor" means any person or entity that distributes its Work. 
6 | 7 | "Work" means (a) the original work of authorship made available under this 8 | license, which may include software, documentation, or other files, and (b) any 9 | additions to or derivative works thereof that are made available under this 10 | license. 11 | 12 | The terms "reproduce," "reproduction," "derivative works," and "distribution" 13 | have the meaning as provided under U.S. copyright law; provided, however, that 14 | for the purposes of this license, derivative works shall not include works that 15 | remain separable from, or merely link (or bind by name) to the interfaces of, 16 | the Work. 17 | 18 | Works are "made available" under this license by including in or with the Work 19 | either (a) a copyright notice referencing the applicability of this license to 20 | the Work, or (b) a copy of this license. 21 | 22 | 2. License Grant 23 | 24 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each 25 | Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, 26 | copyright license to use, reproduce, prepare derivative works of, publicly 27 | display, publicly perform, sublicense and distribute its Work and any resulting 28 | derivative works in any form. 29 | 30 | 3. Limitations 31 | 32 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do 33 | so under this license, (b) you include a complete copy of this license with 34 | your distribution, and (c) you retain without modification any copyright, 35 | patent, trademark, or attribution notices that are present in the Work. 36 | 37 | 3.2 Derivative Works. You may specify that additional or different terms apply 38 | to the use, reproduction, and distribution of your derivative works of the Work 39 | ("Your Terms") only if (a) Your Terms provide that the use limitation in 40 | Section 3.3 applies to your derivative works, and (b) you identify the specific 41 | derivative works that are subject to Your Terms. Notwithstanding Your Terms, 42 | this license (including the redistribution requirements in Section 3.1) will 43 | continue to apply to the Work itself. 44 | 45 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used 46 | or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA 47 | Corporation and its affiliates may use the Work and any derivative works 48 | commercially. As used herein, "non-commercially" means for research or 49 | evaluation purposes only. 50 | 51 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any 52 | Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to 53 | enforce any patents that you allege are infringed by any Work, then your rights 54 | under this license from such Licensor (including the grant in Section 2.1) will 55 | terminate immediately. 56 | 57 | 3.5 Trademarks. This license does not grant any rights to use any Licensor's or 58 | its affiliates' names, logos, or trademarks, except as necessary to reproduce 59 | the notices described in this license. 60 | 61 | 3.6 Termination. If you violate any term of this license, then your rights 62 | under this license (including the grant in Section 2.1) will terminate 63 | immediately. 64 | 65 | 4. Disclaimer of Warranty. 66 | 67 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 68 | EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 69 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. 
70 | YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 71 | 72 | 5. Limitation of Liability. 73 | 74 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, 75 | WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY 76 | LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, 77 | INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, 78 | THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF 79 | GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR 80 | MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN 81 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rvt-2-learning-precise-manipulation-from-few/robot-manipulation-on-rlbench)](https://paperswithcode.com/sota/robot-manipulation-on-rlbench?p=rvt-2-learning-precise-manipulation-from-few) 2 | 3 | [***RVT-2: Learning Precise Manipulation from Few Examples***](https://robotic-view-transformer-2.github.io/)
4 | [Ankit Goyal](http://imankgoyal.github.io), [Valts Blukis](https://www.cs.cornell.edu/~valts/), [Jie Xu](https://people.csail.mit.edu/jiex), [Yijie Guo](https://www.guoyijie.me/), [Yu-Wei Chao](https://research.nvidia.com/person/yu-wei-chao), [Dieter Fox](https://homes.cs.washington.edu/~fox/)
5 | ***RSS 2024*** 6 | 7 | [***RVT: Robotic View Transformer for 3D Object Manipulation***](https://robotic-view-transformer.github.io/)
8 | [Ankit Goyal](http://imankgoyal.github.io), [Jie Xu](https://people.csail.mit.edu/jiex), [Yijie Guo](https://www.guoyijie.me/), [Valts Blukis](https://www.cs.cornell.edu/~valts/), [Yu-Wei Chao](https://research.nvidia.com/person/yu-wei-chao), [Dieter Fox](https://homes.cs.washington.edu/~fox/)
9 | ***CoRL 2023 (Oral)*** 10 | 11 | 21 | 22 | 23 | *[Demo GIFs: RVT-2 solving high precision tasks; Single RVT solving multiple tasks]* 28 |
29 | 30 | This is the official repository that reproduces the results for [RVT-2](https://robotic-view-transformer-2.github.io/) and [RVT](https://robotic-view-transformer.github.io/). The repository is backward compatible. So you just need to pull the latest commit and can switch from RVT to RVT-2! 31 | 32 | 33 | If you find our work useful, please consider citing our: 34 | ``` 35 | @article{goyal2024rvt2, 36 | title={RVT2: Learning Precise Manipulation from Few Demonstrations}, 37 | author={Goyal, Ankit and Blukis, Valts and Xu, Jie and Guo, Yijie and Chao, Yu-Wei and Fox, Dieter}, 38 | journal={RSS}, 39 | year={2024}, 40 | } 41 | @article{goyal2023rvt, 42 | title={RVT: Robotic View Transformer for 3D Object Manipulation}, 43 | author={Goyal, Ankit and Xu, Jie and Guo, Yijie and Blukis, Valts and Chao, Yu-Wei and Fox, Dieter}, 44 | journal={CoRL}, 45 | year={2023} 46 | } 47 | ``` 48 | 49 | ## Getting Started 50 | 51 | ### Install 52 | - Tested (Recommended) Versions: Python 3.8. We used CUDA 11.1. 53 | 54 | - **Step 1 (Optional):** 55 | We recommend using [conda](https://docs.conda.io/en/latest/miniconda.html) and creating a virtual environment. 56 | ``` 57 | conda create --name rvt python=3.8 58 | conda activate rvt 59 | ``` 60 | 61 | - **Step 2:** Install PyTorch. Make sure the PyTorch version is compatible with the CUDA version. One recommended version compatible with CUDA 11.1 and PyTorch3D can be installed with the following command. More instructions to install PyTorch can be found [here](https://pytorch.org/). 62 | ``` 63 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch 64 | ``` 65 | 66 | Recently, we noticed an issue while using conda to install PyTorch. More details can be found [here](https://github.com/pytorch/pytorch/issues/123097). If you face the same issue, you can use the following command to install PyTorch using pip. 67 | ``` 68 | pip install torch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 --index-url https://download.pytorch.org/whl/cu113 69 | ``` 70 | 71 | - **Step 3:** Install PyTorch3D. 72 | 73 | You can skip this step if you only want to use RVT-2 as it uses our custom Point-Renderer for rendering. PyTorch3D is required for RVT. 74 | 75 | One recommended version that is compatible with the rest of the library can be installed as follows. Note that this might take some time. For more instructions visit [here](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md). 76 | ``` 77 | curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz 78 | tar xzf 1.10.0.tar.gz 79 | export CUB_HOME=$(pwd)/cub-1.10.0 80 | pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable' 81 | ``` 82 | 83 | - **Step 4:** Install CoppeliaSim. PyRep requires version **4.1** of CoppeliaSim. Download and unzip CoppeliaSim: 84 | - [Ubuntu 16.04](https://downloads.coppeliarobotics.com/V4_1_0/CoppeliaSim_Player_V4_1_0_Ubuntu16_04.tar.xz) 85 | - [Ubuntu 18.04](https://downloads.coppeliarobotics.com/V4_1_0/CoppeliaSim_Player_V4_1_0_Ubuntu18_04.tar.xz) 86 | - [Ubuntu 20.04](https://downloads.coppeliarobotics.com/V4_1_0/CoppeliaSim_Player_V4_1_0_Ubuntu20_04.tar.xz) 87 | 88 | Once you have downloaded CoppeliaSim, add the following to your *~/.bashrc* file. 
(__NOTE__: remember to edit the path in the first line.) 89 | 90 | ``` 91 | export COPPELIASIM_ROOT=/PATH/TO/COPPELIASIM/INSTALL/DIR 92 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT 93 | export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT 94 | export DISPLAY=:1.0 95 | ``` 96 | Remember to source your .bashrc (`source ~/.bashrc`) or .zshrc (`source ~/.zshrc`) after this. 97 | 98 | - **Step 5:** Clone the repository with the submodules using the following command. 99 | 100 | ``` 101 | git clone --recurse-submodules git@github.com:NVlabs/RVT.git && cd RVT && git submodule update --init 102 | ``` 103 | 104 | Now, locally install the repository. You can either run `pip install -e '.[xformers]'` to install the library with [xformers](https://github.com/facebookresearch/xformers) or `pip install -e .` to install it without xformers. We recommend the former as it improves speed. However, sometimes the installation might fail due to the xformers dependency. In that case, you can install the library without xformers. The performance difference between the two is minimal, but speed could be slower without xformers. 105 | ``` 106 | pip install -e '.[xformers]' 107 | ``` 108 | 109 | Install the required libraries for PyRep, RLBench, YARR, PerAct Colab, and Point Renderer. 110 | ``` 111 | pip install -e rvt/libs/PyRep 112 | pip install -e rvt/libs/RLBench 113 | pip install -e rvt/libs/YARR 114 | pip install -e rvt/libs/peract_colab 115 | pip install -e rvt/libs/point-renderer 116 | ``` 117 | 118 | - **Step 6:** Download the dataset. 119 | - For experiments on RLBench, we use the [pre-generated dataset](https://drive.google.com/drive/folders/0B2LlLwoO3nfZfkFqMEhXWkxBdjJNNndGYl9uUDQwS1pfNkNHSzFDNGwzd1NnTmlpZXR1bVE?resourcekey=0-jRw5RaXEYRLe2W6aNrNFEQ) provided by [PerAct](https://github.com/peract/peract#download). Please download and place them under `RVT/rvt/data/xxx` where `xxx` is either `train`, `test`, or `val`. 120 | 121 | - Additionally, we use the same dataloader as PerAct, which is based on [YARR](https://github.com/stepjam/YARR). YARR creates a replay buffer on the fly, which can increase the startup time. We provide an option to directly load the replay buffer from disk. We recommend using the pre-generated replay buffer (98 GB) as it reduces the startup time. You can download the replay buffer for [individual tasks](https://huggingface.co/datasets/ankgoyal/rvt/tree/main/replay). After downloading, uncompress the replay buffer(s) (for example using the command `tar -xf .tar.xz`) and place them under `RVT/rvt/replay/replay_xxx/` where `xxx` is either `train` or `val`. This is useful only if you want to train RVT from scratch and is not needed if you only want to evaluate the pre-trained model. 122 | 123 | 124 | ## Using the library 125 | 126 | ### Training 127 | ##### Training RVT-2 128 | 129 | To train RVT-2 on all RLBench tasks, use the following command (from folder `RVT/rvt`): 130 | ``` 131 | python train.py --exp_cfg_path configs/rvt2.yaml --mvt_cfg_path mvt/configs/rvt2.yaml --device 0,1,2,3,4,5,6,7 132 | ``` 133 | 134 | ##### Training RVT 135 | To train RVT, use the following command (from folder `RVT/rvt`): 136 | ``` 137 | python train.py --exp_cfg_path configs/rvt.yaml --device 0,1,2,3,4,5,6,7 138 | ``` 139 | We use 8 V100 GPUs. Change the `device` flag depending on available compute. 140 | 141 | ##### More details about `train.py` 142 | - default parameters for an `experiment` are defined [here](https://github.com/NVlabs/RVT/blob/master/rvt/config.py).
143 | - default parameters for `rvt` are defined [here](https://github.com/NVlabs/RVT/blob/master/rvt/mvt/config.py). 144 | - the parameters for `experiment` and `rvt` can be overwritten in two ways: 145 | - specifying the path of a yaml file 146 | - manually overwriting them using an `opts` string of format ` ..` 147 | - Manual overwriting has higher precedence than the yaml file; a minimal sketch of how this precedence works is given after the Gotchas section below. 148 | 149 | ``` 150 | python train.py --exp_cfg_opts <> --mvt_cfg_opts <> --exp_cfg_path <> --mvt_cfg_path <> 151 | ``` 152 | 153 | The following command overwrites the parameters for the `experiment` with the `configs/rvt.yaml` file. It also overwrites the `bs` parameter through the command line. 154 | ``` 155 | python train.py --exp_cfg_opts "bs 4" --exp_cfg_path configs/rvt.yaml --device 0 156 | ``` 157 | 158 | ### Evaluate on RLBench 159 | ##### Evaluate RVT-2 on RLBench 160 | Download the [pretrained RVT-2 model](https://huggingface.co/ankgoyal/rvt/tree/main/rvt2). Place the model (`model_99.pth`, trained for 99 epochs or ~80K steps with batch size 192) and the config files under the folder `RVT/rvt/runs/rvt2/`. Run evaluation using (from folder `RVT/rvt`): 161 | ``` 162 | python eval.py --model-folder runs/rvt2 --eval-datafolder ./data/test --tasks all --eval-episodes 25 --log-name test/1 --device 0 --headless --model-name model_99.pth 163 | ``` 164 | ##### Evaluate RVT on RLBench 165 | Download the [pretrained RVT model](https://huggingface.co/ankgoyal/rvt/tree/main/rvt). Place the model (`model_14.pth`, trained for 15 epochs or 100K steps) and the config files under the folder `runs/rvt/`. Run evaluation using (from folder `RVT/rvt`): 166 | ``` 167 | python eval.py --model-folder runs/rvt --eval-datafolder ./data/test --tasks all --eval-episodes 25 --log-name test/1 --device 0 --headless --model-name model_14.pth 168 | ``` 169 | 170 | ##### Evaluate the official PerAct model on RLBench 171 | Download the [officially released PerAct model](https://github.com/peract/peract/releases/download/v1.0.0/peract_600k.zip). 172 | Put the downloaded policy under the `runs` folder with the recommended folder layout: `runs/peract_official/seed0`. 173 | Run the evaluation using: 174 | ``` 175 | python eval.py --eval-episodes 25 --peract_official --peract_model_dir runs/peract_official/seed0/weights/600000 --model-name QAttentionAgent_layer0.pt --headless --task all --eval-datafolder ./data/test --device 0 176 | ``` 177 | 178 | ## Gotchas 179 | - If you face issues installing `xformers` and PyTorch3D, the information in this issue might be useful: https://github.com/NVlabs/RVT/issues/45. 180 | 181 | - If you get a qt plugin error like `qt.qpa.plugin: Could not load the Qt platform plugin "xcb" /cv2/qt/plugins" even though it was found`, try uninstalling opencv-python and installing opencv-python-headless. 182 | 183 | ``` 184 | pip uninstall opencv-python 185 | pip install opencv-python-headless 186 | ``` 187 | 188 | - If you have CUDA 11.7, an alternate installation strategy could be to use the following commands for Step 2 and Step 3. Note that this is not heavily tested. 189 | ``` 190 | # Step 2: 191 | pip install torch torchvision torchaudio 192 | # Step 3: 193 | pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable' 194 | ``` 195 | 196 | - If you are having issues running evaluation on a headless server, please refer to https://github.com/NVlabs/RVT/issues/2#issuecomment-1620704943. 197 | 198 | - If you want to generate visualization videos, please refer to https://github.com/NVlabs/RVT/issues/5.
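For reference, the precedence between the yaml file and the `opts` string described in the `train.py` section above follows the standard yacs pattern (yacs is the config library imported in `rvt/config.py`). Below is a minimal sketch of that precedence, not the actual `train.py` implementation; it assumes the repository has been installed with `pip install -e .` and is run from the `RVT/rvt` folder, and the yaml path and opts string are only examples.

```
# Minimal sketch of the config precedence (not the actual train.py implementation).
from rvt.config import get_cfg_defaults

cfg = get_cfg_defaults()                  # 1. defaults defined in rvt/config.py
cfg.merge_from_file("configs/rvt.yaml")   # 2. values from the yaml file overwrite the defaults
cfg.merge_from_list("bs 4".split(" "))    # 3. the opts string overwrites the yaml (highest precedence)
cfg.freeze()
print(cfg.bs)  # prints 4, even though configs/rvt.yaml sets bs: 3
```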
199 | 200 | ## FAQs 201 | ###### Q. What is the advantage of RVT-2 and RVT over PerAct? 202 | RVT and RVT-2 are both faster to train and perform better than PerAct.
203 | 204 | 205 | ###### Q. What resources are required to train RVT? 206 | For training on 18 RLBench tasks with 100 demos per task, we use 8 V100 GPUs (16 GB memory each). The model trains in ~1 day. 207 | 208 | Note that for a fair comparison with PerAct, we used the same dataset, which means [duplicate keyframes are loaded into the replay buffer](https://github.com/peract/peract#why-are-duplicate-keyframes-loaded-into-the-replay-buffer). For other datasets, one could consider not doing so, which might further speed up training. 209 | 210 | ###### Q. Why do you use `pe_fix=True` in the rvt [config](https://github.com/NVlabs/RVT/blob/master/rvt/mvt/config.py#L32)? 211 | For a fair comparison with the official PerAct model, we use this setting. More details about this can be found in the PerAct [code](https://github.com/peract/peract/blob/main/agents/peract_bc/perceiver_lang_io.py#L387-L398). In the future, we recommend using `pe_fix=False` for language input. 212 | 213 | ###### Q. Why are the results for PerAct different from the PerAct paper? 214 | In the PerAct paper, for each task, the best checkpoint is chosen based on the validation set performance. Hence, the model weights can be different for different tasks. We evaluate PerAct and RVT only on the final checkpoint, so that all tasks are strictly evaluated on the same model weights. Note that only the final model for PerAct has been released officially. 215 | 216 | ###### Q. Why is there a variance in performance on RLBench even when evaluating the same checkpoint? 217 | We hypothesize that it is because of the sampling-based planner used in RLBench, which could be the source of the randomness. Hence, we evaluate each checkpoint 5 times and report the mean and variance. 218 | 219 | ###### Q. Why did you use a cosine decay learning rate scheduler instead of a fixed learning rate schedule as done in PerAct? 220 | We found that the cosine learning rate scheduler led to faster convergence for RVT. Training PerAct with our training hyper-parameters (cosine learning rate scheduler and the same number of iterations) led to worse performance (in ~4 days of training time). Hence, for Fig. 1, we used the official hyper-parameters for PerAct. 221 | 222 | ###### Q. For my use case, I want to render images at real camera locations (input camera poses) with PyTorch3D. Is it possible to do so and how can I do that? 223 | Yes, it is possible to do so. A self-contained example is available [here](https://github.com/NVlabs/RVT/issues/9). Depending on your use case, the code may need to be modified. Also note that 3D augmentation cannot be used while rendering images at real camera locations, as it would change the pose of the camera with respect to the point cloud. 224 | 225 | For questions and comments, please contact [Ankit Goyal](https://imankgoyal.github.io/). 226 | 227 | ## Acknowledgement 228 | We sincerely thank the authors of the following repositories for sharing their code. 229 | 230 | - [PerAct](https://github.com/peract/peract) 231 | - [PerAct Colab](https://github.com/peract/peract_colab/tree/master) 232 | - [PyRep](https://github.com/stepjam/PyRep) 233 | - [RLBench](https://github.com/stepjam/RLBench/tree/master) 234 | - [YARR](https://github.com/stepjam/YARR) 235 | 236 | ## License 237 | Copyright © 2023, NVIDIA Corporation & affiliates. All rights reserved. 238 | 239 | This work is made available under the [Nvidia Source Code License](https://github.com/NVlabs/RVT/blob/master/LICENSE).
240 | The pretrained RVT models are released under the CC-BY-NC-SA-4.0 license. 241 | -------------------------------------------------------------------------------- /rvt/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | from yacs.config import CfgNode as CN 6 | 7 | _C = CN() 8 | 9 | _C.agent = "our" 10 | _C.tasks = "insert_onto_square_peg,open_drawer,place_wine_at_rack_location,light_bulb_in" 11 | _C.exp_id = "def" 12 | _C.resume = "" 13 | # bs per device, effective bs is scaled by num device 14 | _C.bs = 4 15 | _C.epochs = 20 16 | # number of dataloader workers, >= 0 17 | _C.num_workers = 0 18 | # 'transition_uniform' or 'task_uniform' 19 | _C.sample_distribution_mode = 'transition_uniform' 20 | _C.train_iter = 16 * 10000 21 | 22 | # arguments present in both peract and rvt 23 | # some of them donot support every possible combination in peract 24 | _C.peract = CN() 25 | _C.peract.lambda_weight_l2 = 1e-6 26 | # lr should be thought on per sample basis 27 | # effective lr is multiplied by bs * num_devices 28 | _C.peract.lr = 2.5e-5 29 | _C.peract.optimizer_type = "lamb" 30 | _C.peract.warmup_steps = 0 31 | _C.peract.lr_cos_dec = False 32 | _C.peract.add_rgc_loss = True 33 | _C.peract.num_rotation_classes = 72 34 | _C.peract.amp = False 35 | _C.peract.bnb = False 36 | _C.peract.transform_augmentation = True 37 | _C.peract.transform_augmentation_xyz = [0.1, 0.1, 0.1] 38 | _C.peract.transform_augmentation_rpy = [0.0, 0.0, 20.0] 39 | 40 | # arguments present in only rvt and not peract 41 | _C.rvt = CN() 42 | _C.rvt.gt_hm_sigma = 1.5 43 | _C.rvt.img_aug = 0.1 44 | _C.rvt.place_with_mean = True 45 | _C.rvt.move_pc_in_bound = True 46 | 47 | # arguments present in peract official 48 | _C.peract_official = CN() 49 | _C.peract_official.cfg_path = "configs/peract_official_config.yaml" 50 | 51 | 52 | def get_cfg_defaults(): 53 | """Get a yacs CfgNode object with default values for my_project.""" 54 | return _C.clone() 55 | -------------------------------------------------------------------------------- /rvt/configs/peract_official_config.yaml: -------------------------------------------------------------------------------- 1 | # copied from: https://github.com/peract/peract/releases/download/v1.0.0/peract_600k.zip 2 | method: 3 | name: PERACT_BC 4 | lr: 0.0005 5 | lr_scheduler: false 6 | num_warmup_steps: 3000 7 | optimizer: lamb 8 | activation: lrelu 9 | norm: None 10 | lambda_weight_l2: 1.0e-06 11 | trans_loss_weight: 1.0 12 | rot_loss_weight: 1.0 13 | grip_loss_weight: 1.0 14 | collision_loss_weight: 1.0 15 | rotation_resolution: 5 16 | image_crop_size: 64 17 | bounds_offset: 18 | - 0.15 19 | voxel_sizes: 20 | - 100 21 | num_latents: 2048 22 | latent_dim: 512 23 | transformer_depth: 6 24 | transformer_iterations: 1 25 | cross_heads: 1 26 | cross_dim_head: 64 27 | latent_heads: 8 28 | latent_dim_head: 64 29 | pos_encoding_with_lang: false 30 | lang_fusion_type: seq 31 | voxel_patch_size: 5 32 | voxel_patch_stride: 5 33 | input_dropout: 0.1 34 | attn_dropout: 0.1 35 | decoder_dropout: 0.0 36 | crop_augmentation: true 37 | final_dim: 64 38 | transform_augmentation: 39 | apply_se3: true 40 | aug_xyz: 41 | - 0.125 42 | - 0.125 43 | - 0.125 44 | aug_rpy: 45 | - 0.0 46 | - 0.0 47 | - 0.0 48 | aug_rot_resolution: 5 49 | demo_augmentation: true 50 | demo_augmentation_every_n: 10 51 | 
no_skip_connection: false 52 | no_perceiver: false 53 | no_language: false 54 | keypoint_method: heuristic 55 | ddp: 56 | master_addr: "localhost" 57 | master_port: "29500" 58 | num_devices: 1 59 | rlbench: 60 | task_name: multi 61 | tasks: 62 | - change_channel 63 | - close_jar 64 | - insert_onto_square_peg 65 | - light_bulb_in 66 | - meat_off_grill 67 | - open_drawer 68 | - place_cups 69 | - place_shape_in_shape_sorter 70 | - push_buttons 71 | - put_groceries_in_cupboard 72 | - put_item_in_drawer 73 | - put_money_in_safe 74 | - reach_and_drag 75 | - stack_blocks 76 | - stack_cups 77 | - turn_tap 78 | - set_clock_to_time 79 | - place_wine_at_rack_location 80 | - put_rubbish_in_color_bin 81 | - slide_block_to_color_target 82 | - sweep_to_dustpan_of_size 83 | demos: 100 84 | demo_path: /raid/dataset/ 85 | episode_length: 25 86 | cameras: 87 | - front 88 | - left_shoulder 89 | - right_shoulder 90 | - wrist 91 | camera_resolution: 92 | - 128 93 | - 128 94 | scene_bounds: 95 | - -0.3 96 | - -0.5 97 | - 0.6 98 | - 0.7 99 | - 0.5 100 | - 1.6 101 | include_lang_goal_in_obs: True 102 | replay: 103 | batch_size: 16 104 | timesteps: 1 105 | prioritisation: false 106 | task_uniform: true 107 | use_disk: true 108 | path: /raid/arm/replay 109 | max_parallel_processes: 32 110 | framework: 111 | log_freq: 100 112 | save_freq: 10000 113 | train_envs: 1 114 | replay_ratio: 16 115 | transitions_before_train: 200 116 | tensorboard_logging: true 117 | csv_logging: true 118 | training_iterations: 600001 119 | gpu: 0 120 | env_gpu: 0 121 | logdir: /home/user/workspace/logs_may16_n100 122 | seeds: 1 123 | start_seed: 0 124 | load_existing_weights: true 125 | num_weights_to_keep: 60 126 | record_every_n: 5 127 | 128 | -------------------------------------------------------------------------------- /rvt/configs/rvt.yaml: -------------------------------------------------------------------------------- 1 | exp_id: rvt 2 | tasks: all 3 | bs: 3 4 | num_workers: 3 5 | epochs: 15 6 | sample_distribution_mode: task_uniform 7 | peract: 8 | lr: 1e-4 9 | warmup_steps: 2000 10 | optimizer_type: lamb 11 | lr_cos_dec: True 12 | transform_augmentation_xyz: [0.125, 0.125, 0.125] 13 | transform_augmentation_rpy: [0.0, 0.0, 45.0] 14 | rvt: 15 | place_with_mean: False 16 | -------------------------------------------------------------------------------- /rvt/configs/rvt2.yaml: -------------------------------------------------------------------------------- 1 | exp_id: rvt2 2 | tasks: all 3 | bs: 24 4 | num_workers: 3 5 | epochs: 15 6 | sample_distribution_mode: task_uniform 7 | peract: 8 | lr: 1.25e-5 9 | warmup_steps: 2000 10 | optimizer_type: lamb 11 | lr_cos_dec: True 12 | transform_augmentation_xyz: [0.125, 0.125, 0.125] 13 | transform_augmentation_rpy: [0.0, 0.0, 45.0] 14 | amp: True 15 | bnb: True 16 | lambda_weight_l2: 1e-4 17 | rvt: 18 | place_with_mean: False 19 | img_aug: 0.0 20 | -------------------------------------------------------------------------------- /rvt/libs/peract_colab/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
4 | 5 | """ 6 | Setup of peract 7 | Author: Ankit Goyal 8 | """ 9 | from setuptools import setup 10 | 11 | requirements = [ 12 | ] 13 | 14 | setup( 15 | name="peract_colab", 16 | # version=__version__, 17 | long_description="", 18 | url="", 19 | keywords="robotics computer vision", 20 | classifiers=[ 21 | "Programming Language :: Python", 22 | ], 23 | packages=["peract_colab"], 24 | install_requires=requirements, 25 | ) 26 | 27 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/.gitattributes: -------------------------------------------------------------------------------- 1 | pcd_data.tar.gz filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/.gitignore: -------------------------------------------------------------------------------- 1 | build/* 2 | *.egg-info/* 3 | *.so 4 | */__pycache__/* 5 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022-2023, NVIDIA Corporation & affiliates. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for instant neural graphics primitives 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. 
Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. 96 | 97 | ======================================================================= 98 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/README.md: -------------------------------------------------------------------------------- 1 | ## Point Renderer 2 | A minimal, lightweight CUDA-accelerated renderer of pointclouds. 3 | 4 |
5 | 6 | ### Install 7 | 8 | ``` 9 | pip install -r requirements.txt 10 | pip install -e . 11 | ``` 12 | 13 | ### Run 14 | 15 | **Load Data** 16 | Extract included pcd_data.tar.gz 17 | 18 | ``` 19 | import numpy as np 20 | 21 | data = np.load("pcd_data/w1280_h720/3.npy", allow_pickle=True) 22 | data = data[None][0] 23 | pc = data["pc"] 24 | rgb = data["img_feat"] 25 | ``` 26 | 27 | **Render the image** 28 | 29 | ``` 30 | # Make the renderer 31 | from point_renderer.renderer import PointRenderer 32 | renderer = PointRenderer(device="cuda", perf_timer=False) 33 | 34 | # Define a batch of cameras 35 | img_size = (512, 512) 36 | K = renderer.get_camera_intrinsics(hfov=70, img_size=img_size) 37 | camera_poses = renderer.get_batch_of_camera_poses( 38 | cam_positions=[[1.5, 1.5, 1.5],[-1.5, -1.5, -1.5]], 39 | cam_lookats=[[0.0, 0.0, 0.0],[0.0, 0.0, 0.0]]) 40 | 41 | # Render the pointcloud from the given cameras 42 | images, depths = renderer.render_batch(pc, rgb, camera_poses, K, img_size, 43 | default_color=1.0, 44 | splat_radius=0.005, 45 | aa_factor=2 46 | ) 47 | 48 | # Show the results 49 | plt.imshow(images[0].detach().cpu().numpy()); plt.show() 50 | plt.imshow(depths[0].detach().cpu().numpy()); plt.show() 51 | plt.imshow(images[1].detach().cpu().numpy()); plt.show() 52 | plt.imshow(depths[1].detach().cpu().numpy()); plt.show() 53 | ``` 54 | 55 | .. Or run the jupyter notebook that has this same code above, and also all the benchmarks. 56 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RVT/367995a1a2169b6352bf4e8b0ed405890462a3a0/rvt/libs/point-renderer/demo.png -------------------------------------------------------------------------------- /rvt/libs/point-renderer/image_0_splat_2xaa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RVT/367995a1a2169b6352bf4e8b0ed405890462a3a0/rvt/libs/point-renderer/image_0_splat_2xaa.png -------------------------------------------------------------------------------- /rvt/libs/point-renderer/pcd_data.tar.gz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c47f7c538558e941a82b49c2dfa3f0eb95618a7beaf1cb69e697acb6692983e9 3 | size 12732634 4 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/point_renderer/cameras.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from point_renderer import ops 4 | from functools import lru_cache 5 | 6 | @lru_cache(maxsize=32) 7 | def linalg_inv(poses): 8 | return torch.linalg.inv(poses) 9 | 10 | class Cameras: 11 | def __init__(self, poses, intrinsics, img_size, inv_poses=None): 12 | self.poses = poses 13 | self.img_size = img_size 14 | if inv_poses is None: 15 | self.inv_poses = linalg_inv(poses) 16 | else: 17 | self.inv_poses = inv_poses 18 | self.intrinsics = intrinsics 19 | 20 | def __len__(self): 21 | return len(self.poses) 22 | 23 | def scale(self, constant): 24 | self.intrinsics = self.intrinsics.clone() 25 | self.intrinsics[:, :2, :3] *= constant 26 | 27 | def is_orthographic(self): 28 | raise ValueError("is_orthographic should be called on child classes only") 29 | 30 | def is_perspective(self): 31 | raise ValueError("is_perspective should be 
called on child classes only") 32 | 33 | 34 | class PerspectiveCameras(Cameras): 35 | def __init__(self, poses, intrinsics, img_size, inv_poses=None): 36 | super().__init__(poses, intrinsics, img_size, inv_poses) 37 | 38 | @classmethod 39 | def from_lookat(cls, eyes, ats, ups, hfov, img_size, device="cpu"): 40 | cam_poses = [] 41 | for eye, at, up in zip(eyes, ats, ups): 42 | T = ops.lookat_to_cam_pose(eye, at, up, device=device) 43 | cam_poses.append(T) 44 | cam_poses = torch.stack(cam_poses, dim=0) 45 | intrinsics = ops.fov_and_size_to_intrinsics(hfov, img_size, device=device) 46 | intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous() 47 | return PerspectiveCameras(cam_poses, intrinsics, img_size) 48 | 49 | @classmethod 50 | def from_rotation_and_translation(cls, R, T, S, hfov, img_size): 51 | device = R.device 52 | assert T.device == device 53 | cam_poses = torch.zeros((R.shape[0], 4, 4), device=device, dtype=torch.float) 54 | cam_poses[:, :3, :3] = R * S[None, :] 55 | cam_poses[:, :3, 3] = T 56 | cam_poses[:, 3, 3] = 1.0 57 | intrinsics = ops.fov_and_size_to_intrinsics(hfov, img_size, device=device) 58 | intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous() 59 | return PerspectiveCameras(cam_poses, intrinsics, img_size) 60 | 61 | def to(self, device): 62 | return PerspectiveCameras(self.poses.to(device), self.intrinsics.to(device), self.inv_poses.to(device)) 63 | 64 | def is_orthographic(self): 65 | return False 66 | 67 | def is_perspective(self): 68 | return True 69 | 70 | class OrthographicCameras(Cameras): 71 | def __init__(self, poses, intrinsics, img_size, inv_poses=None): 72 | super().__init__(poses, intrinsics, img_size, inv_poses) 73 | 74 | @classmethod 75 | def from_lookat(cls, eyes, ats, ups, img_sizes_w, img_size_px, device="cpu"): 76 | """ 77 | Args: 78 | eyes: Nx3 tensor of camera coordinates 79 | ats: Nx3 tensor of look-at directions 80 | ups: Nx3 tensor of up-vectors 81 | scale: Nx2 tensor defining image sizes in world coordinates 82 | img_size: 2-dim tuple defining image size in pixels 83 | Returns: 84 | OrthographicCamera 85 | """ 86 | if isinstance(img_sizes_w, list): 87 | img_sizes_w = torch.tensor(img_sizes_w, device=device)[None, :].repeat((len(eyes), 1)) 88 | 89 | cam_poses = [] 90 | for eye, at, up in zip(eyes, ats, ups): 91 | T = ops.lookat_to_cam_pose(eye, at, up, device=device) 92 | cam_poses.append(T) 93 | cam_poses = torch.stack(cam_poses, dim=0) 94 | intrinsics = ops.orthographic_intrinsics_from_scales(img_sizes_w, img_size_px, device=device) 95 | return OrthographicCameras(cam_poses, intrinsics, img_size_px) 96 | 97 | @classmethod 98 | def from_rotation_and_translation(cls, R, T, img_sizes_w, img_size_px, device="cpu"): 99 | if isinstance(img_sizes_w, list): 100 | img_sizes_w = torch.tensor(img_sizes_w, device=device)[None, :].repeat((len(R), 1)) 101 | 102 | device = R.device 103 | assert T.device == device 104 | cam_poses = torch.zeros((R.shape[0], 4, 4), device=device, dtype=torch.float) 105 | cam_poses[:, :3, :3] = R 106 | cam_poses[:, :3, 3] = T 107 | cam_poses[:, 3, 3] = 1.0 108 | intrinsics = ops.orthographic_intrinsics_from_scales(img_sizes_w, img_size_px, device=device) 109 | intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous() 110 | return OrthographicCameras(cam_poses, intrinsics, img_size_px) 111 | 112 | def to(self, device): 113 | return OrthographicCameras(self.poses.to(device), self.intrinsics.to(device), self.inv_poses.to(device)) 114 | 115 | def 
is_orthographic(self): 116 | return True 117 | 118 | def is_perspective(self): 119 | return False 120 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/point_renderer/csrc/bindings.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION and its licensors retain all intellectual property 5 | * and proprietary rights in and to this software, related documentation 6 | * and any modifications thereto. Any use, reproduction, disclosure or 7 | * distribution of this software and related documentation without an express 8 | * license agreement from NVIDIA CORPORATION is strictly prohibited. 9 | */ 10 | 11 | /** @file bindings.cpp 12 | * @author Valts Blukis, NVIDIA 13 | * @brief PyTorch bindings for pointcloud renderer 14 | */ 15 | 16 | #include 17 | #include "./render/render_pointcloud.h" 18 | 19 | namespace rvt { 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | py::module render = m.def_submodule("render"); 23 | render.def("render_feature_pointcloud_to_image", &render_feature_pointcloud_to_image); 24 | render.def("screen_space_splatting", &screen_space_splatting); 25 | render.def("aa_subsample", &aa_subsample); 26 | } 27 | 28 | } 29 | 30 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/point_renderer/csrc/render/render_feature_pointcloud.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION and its licensors retain all intellectual property 5 | * and proprietary rights in and to this software, related documentation 6 | * and any modifications thereto. Any use, reproduction, disclosure or 7 | * distribution of this software and related documentation without an express 8 | * license agreement from NVIDIA CORPORATION is strictly prohibited. 
9 | */ 10 | 11 | /** @file render_feature_pointcloud.cu 12 | * @author Valts Blukis, NVIDIA 13 | * @brief Renders a point cloud with associated feature vectors to image 14 | */ 15 | 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | namespace rvt { 26 | 27 | 28 | // By what factor to scale depth for integer representation 29 | __constant__ const float DEPTH_FACTOR = 1000; 30 | __constant__ const float DEPTH_INV_FACTOR = 1 / DEPTH_FACTOR; 31 | 32 | 33 | __global__ void render_pointcloud_to_depth_index_buffer_cuda_kernel( 34 | int64_t num_points, 35 | int64_t img_height, 36 | int64_t img_width, 37 | 38 | int64_t* point_indices, 39 | float* point_depths, 40 | 41 | uint64_t* packed_buffer 42 | ){ 43 | uint32_t tidx = blockDim.x * blockIdx.x + threadIdx.x; 44 | 45 | if (tidx < num_points) { 46 | int64_t pixel_index = point_indices[tidx]; 47 | 48 | if (pixel_index >= 0 && pixel_index < img_height * img_width) { 49 | float point_depth = point_depths[tidx]; 50 | uint32_t point_depth_mm = (uint32_t) (point_depth * DEPTH_FACTOR); 51 | uint64_t packed_depth_and_index = ((uint64_t) point_depth_mm << 32) | ((uint64_t) tidx); 52 | atomicMin((unsigned long long*) (packed_buffer + pixel_index), (unsigned long long)packed_depth_and_index); 53 | } 54 | } 55 | } 56 | 57 | 58 | __global__ void output_render_to_feature_image_cuda_kernel( 59 | int64_t img_height, 60 | int64_t img_width, 61 | int64_t num_channels, 62 | int32_t num_points, 63 | float* point_features, 64 | uint64_t* packed_buffer, 65 | float* image_out, 66 | float* depth_out, 67 | float default_depth, 68 | float default_feature 69 | ) { 70 | 71 | uint tidx = blockDim.x * blockIdx.x + threadIdx.x; 72 | if (tidx < img_height * img_width) { 73 | uint64_t packed = packed_buffer[tidx]; 74 | uint32_t packed_depth_mm = (uint32_t) (packed >> 32); 75 | // The modulo is to support batching without having to tile the features 76 | uint32_t packed_index = ((uint32_t) packed) % num_points; 77 | 78 | if (packed_depth_mm == 0xFFFFFFFF) { 79 | depth_out[tidx] = default_depth; 80 | for (int i = 0; i < num_channels; i++) { 81 | image_out[tidx * num_channels + i] = default_feature; 82 | } 83 | } 84 | else { 85 | depth_out[tidx] = (float) (packed_depth_mm) * DEPTH_INV_FACTOR; 86 | for (int i = 0; i < num_channels; i++) { 87 | image_out[tidx * num_channels + i] = point_features[packed_index * num_channels + i]; 88 | } 89 | } 90 | } 91 | } 92 | 93 | 94 | __global__ void output_render_to_feature_image_cuda_2d_kernel( 95 | int64_t img_height, 96 | int64_t img_width, 97 | int64_t num_channels, 98 | int32_t num_points, 99 | float* point_features, 100 | uint64_t* packed_buffer, 101 | float* image_out, 102 | float* depth_out, 103 | float default_depth, 104 | float default_feature 105 | ) { 106 | uint pixidx = blockDim.x * blockIdx.x + threadIdx.x; 107 | uint cidx = blockDim.y * blockIdx.y + threadIdx.y; 108 | 109 | if (pixidx < img_height * img_width && cidx < num_channels) { 110 | uint64_t packed = packed_buffer[pixidx]; 111 | uint32_t packed_depth_mm = (uint32_t) (packed >> 32); 112 | // The modulo is to support batching without having to tile the features 113 | uint32_t packed_index = ((uint32_t) packed) % num_points; 114 | 115 | if (packed_depth_mm == 0xFFFFFFFF) { 116 | if (cidx == 0) 117 | depth_out[pixidx] = default_depth; 118 | image_out[pixidx * num_channels + cidx] = default_feature; 119 | } 120 | else { 121 | if (cidx == 0) 122 | depth_out[pixidx] = (float) (packed_depth_mm) * DEPTH_INV_FACTOR; 123 | 
image_out[pixidx * num_channels + cidx] = point_features[packed_index * num_channels + cidx]; 124 | } 125 | } 126 | } 127 | 128 | 129 | void render_feature_pointcloud_to_image( 130 | at::Tensor point_indices, // Index into flattened image. -1 if out of bounds. 131 | at::Tensor point_depths, 132 | at::Tensor point_features, 133 | at::Tensor image_out, 134 | at::Tensor depth_out, 135 | float default_depth, 136 | float default_color) { 137 | 138 | int64_t num_points = point_indices.size(0); 139 | int32_t num_points_per_batch = point_features.size(0); 140 | int64_t img_height = image_out.size(0); 141 | int64_t img_width = image_out.size(1); 142 | int64_t num_channels = image_out.size(2); 143 | 144 | if (num_channels != point_features.size(1)) { 145 | throw std::runtime_error("Output image and point features must have the same channel dimension"); 146 | } 147 | 148 | // TODO: Play with this to see if we can speed it up 149 | uint64_t num_threads_per_block = 1024; 150 | 151 | // Make sure cudaMalloc uses the correct device 152 | int device_index = point_indices.get_device(); 153 | cudaSetDevice(device_index); 154 | 155 | // Allocate memory for storing packed depths and colors 156 | uint64_t* packed_depths_and_indices; 157 | cudaMalloc((void**) &packed_depths_and_indices, img_width*img_height*sizeof(uint64_t)); 158 | cudaMemset(packed_depths_and_indices, 0xFFFFFFFFFFFFFFFF, img_width*img_height*sizeof(uint64_t)); 159 | 160 | render_pointcloud_to_depth_index_buffer_cuda_kernel<<<(num_points + num_threads_per_block - 1) / num_threads_per_block, num_threads_per_block>>>( 161 | num_points, 162 | img_height, 163 | img_width, 164 | 165 | point_indices.data_ptr(), 166 | point_depths.data_ptr(), 167 | 168 | packed_depths_and_indices); 169 | 170 | // With few channels, it's faster to launch a thread per pixel, in each thread looping over the channels and copying the data 171 | if (num_channels < 10) 172 | { 173 | output_render_to_feature_image_cuda_kernel<<<(img_height * img_width + num_threads_per_block - 1) / num_threads_per_block, num_threads_per_block>>>( 174 | img_height, 175 | img_width, 176 | num_channels, 177 | num_points_per_batch, 178 | point_features.data_ptr(), 179 | packed_depths_and_indices, 180 | image_out.data_ptr(), 181 | depth_out.data_ptr(), 182 | default_depth, 183 | default_color); 184 | } 185 | // With more channels, it's better to launch a separate thread per pixel per channel, in each thread copying only one feature scalar 186 | else 187 | { 188 | output_render_to_feature_image_cuda_2d_kernel<<>>( 189 | img_height, 190 | img_width, 191 | num_channels, 192 | num_points_per_batch, 193 | point_features.data_ptr(), 194 | packed_depths_and_indices, 195 | image_out.data_ptr(), 196 | depth_out.data_ptr(), 197 | default_depth, 198 | default_color); 199 | } 200 | 201 | cudaFree(packed_depths_and_indices); 202 | } 203 | 204 | 205 | __global__ void screen_space_splatting_cuda_kernel( 206 | int64_t batch_size, 207 | int64_t img_height, 208 | int64_t img_width, 209 | int64_t num_channels, 210 | int k, 211 | float splat_radius, 212 | float focal_length_px, 213 | float* depth_in, 214 | float* image_in, 215 | float* depth_out, 216 | float* image_out, 217 | float default_depth, 218 | bool orthographic 219 | ) { 220 | uint x = blockDim.x * blockIdx.x + threadIdx.x; 221 | uint y = blockDim.y * blockIdx.y + threadIdx.y; 222 | uint c_index = y * img_width + x; 223 | uint batch_elem_height = img_height / batch_size; 224 | 225 | if (y < img_height && x < img_width) { 226 | float center_depth = 
depth_in[c_index]; 227 | float min_depth = center_depth; 228 | int splat_index = c_index; 229 | 230 | int b_elem = y / batch_elem_height; 231 | 232 | // Loop over pixel's neighbourhood 233 | for (int dx = -k/2; dx <= k/2; dx++) { 234 | for (int dy = -k/2; dy <= k/2; dy++) { 235 | int nx = x + dx; 236 | int ny = y + dy; 237 | if (nx >= img_width || nx < 0 || ny >= (b_elem + 1) * batch_elem_height || ny < b_elem * batch_elem_height) { 238 | continue; 239 | } 240 | // ignore the center pixel itself 241 | /*if (dx == 0 && dy == 0) { 242 | continue; 243 | }*/ 244 | int n_index = ny * img_width + nx; 245 | 246 | // Compute neighbor's splat size in pixels 247 | float neighbor_depth = depth_in[n_index]; 248 | 249 | // If neighbor is further than current center value, or is unobserved, ignore it 250 | if (neighbor_depth == default_depth || (neighbor_depth > min_depth && min_depth != default_depth)) { 251 | continue; 252 | } 253 | 254 | // Otherwise neighbor is closer to camera than center. Consider it. 255 | float n_splat_size_px; 256 | if (orthographic) { 257 | n_splat_size_px = focal_length_px * splat_radius; 258 | } else { 259 | n_splat_size_px = focal_length_px * splat_radius / neighbor_depth; 260 | } 261 | 262 | float n_dst = sqrt((float)(dx * dx + dy * dy)); 263 | // If the splat is big enough to cover the center pixel, remember it 264 | if (n_splat_size_px > n_dst && neighbor_depth) { 265 | splat_index = n_index; 266 | min_depth = neighbor_depth; 267 | } 268 | } 269 | } 270 | 271 | // TODO: we can consider applying some blending instead of just a harsh copy 272 | depth_out[c_index] = depth_in[splat_index]; 273 | for (int i = 0; i < num_channels; i++) { 274 | image_out[c_index * num_channels + i] = image_in[splat_index * num_channels + i]; 275 | } 276 | } 277 | } 278 | 279 | 280 | void screen_space_splatting( 281 | int batch_size, 282 | at::Tensor depth_in, 283 | at::Tensor image_in, 284 | at::Tensor depth_out, 285 | at::Tensor image_out, 286 | float default_depth, 287 | float focal_length_px, 288 | float splat_radius, 289 | int kernel_size, 290 | bool orthographic 291 | ){ 292 | int64_t img_height = image_out.size(0); 293 | int64_t img_width = image_out.size(1); 294 | int64_t num_channels = image_out.size(2); 295 | 296 | int num_threads_per_block_x = 16; 297 | screen_space_splatting_cuda_kernel<<>>( 300 | batch_size, 301 | img_height, 302 | img_width, 303 | num_channels, 304 | kernel_size, 305 | splat_radius, 306 | focal_length_px, 307 | depth_in.data_ptr(), 308 | image_in.data_ptr(), 309 | depth_out.data_ptr(), 310 | image_out.data_ptr(), 311 | default_depth, 312 | orthographic 313 | ); 314 | } 315 | 316 | 317 | __global__ void aa_subsampling_kernel( 318 | int64_t batch_size, 319 | int64_t img_height, 320 | int64_t img_width, 321 | int64_t num_channels, 322 | int aa_factor, 323 | float* depth_in, 324 | float* image_in, 325 | float* depth_out, 326 | float* image_out, 327 | float default_depth 328 | ) { 329 | uint x = blockDim.x * blockIdx.x + threadIdx.x; 330 | uint y = blockDim.y * blockIdx.y + threadIdx.y; 331 | uint b = blockDim.z * blockIdx.z + threadIdx.z; 332 | uint out_index = b * img_height * img_width + y * img_width + x; 333 | 334 | uint aa_img_width = img_width * aa_factor; 335 | uint aa_img_height = img_height * aa_factor; 336 | 337 | if (y < img_height && x < img_width && b < batch_size) { 338 | // Loop over pixel's corresponding patch in the input image 339 | for (int dx = 0; dx < aa_factor; dx++) { 340 | for (int dy = 0; dy < aa_factor; dy++) { 341 | int ox = x * aa_factor + 
dx; 342 | int oy = y * aa_factor + dy; 343 | if (ox >= aa_img_width|| ox < 0 || oy >= aa_img_height|| oy < 0) { 344 | continue; 345 | } 346 | int in_index = b * aa_img_height * aa_img_width + oy * aa_img_width + ox; 347 | 348 | // Average color over all pixels 349 | for (int i = 0; i < num_channels; i++) { 350 | image_out[out_index * num_channels + i] += image_in[in_index * num_channels + i]; 351 | } 352 | // Take min depth across all pixels (median would be better, but that needs more memory to order the values) 353 | float d_in = depth_in[in_index]; 354 | float d_out = depth_out[out_index]; 355 | depth_out[out_index] = d_in; 356 | 357 | // TODO: The epsilon is hard-coded here, for some applications this might be a problem 358 | if (fabsf(d_in - default_depth) > 1e-6) { //d_in != default_depth 359 | if (fabsf(d_out - default_depth) < 1e-6) { //d_out == default_depth 360 | depth_out[out_index] = d_in; 361 | } 362 | else { 363 | depth_out[out_index] = min(d_out, d_in); 364 | } 365 | } 366 | } 367 | } 368 | 369 | for (int i = 0; i < num_channels; i++) { 370 | image_out[out_index * num_channels + i] = image_out[out_index * num_channels + i] / (aa_factor * aa_factor); 371 | } 372 | } 373 | } 374 | 375 | 376 | void aa_subsample( 377 | at::Tensor depth_in, 378 | at::Tensor image_in, 379 | at::Tensor depth_out, 380 | at::Tensor image_out, 381 | int aa_factor, 382 | float default_depth 383 | ) { 384 | int64_t batch_size = image_out.size(0); 385 | int64_t img_height = image_out.size(1); 386 | int64_t img_width = image_out.size(2); 387 | int64_t num_channels = image_out.size(3); 388 | 389 | int num_threads_per_block_x = 32; 390 | 391 | aa_subsampling_kernel<<>>( 396 | batch_size, 397 | img_height, 398 | img_width, 399 | num_channels, 400 | aa_factor, 401 | depth_in.data_ptr(), 402 | image_in.data_ptr(), 403 | depth_out.data_ptr(), 404 | image_out.data_ptr(), 405 | default_depth 406 | ); 407 | } 408 | 409 | 410 | } // namespace rvt 411 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/point_renderer/csrc/render/render_pointcloud.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * NVIDIA CORPORATION and its licensors retain all intellectual property 5 | * and proprietary rights in and to this software, related documentation 6 | * and any modifications thereto. Any use, reproduction, disclosure or 7 | * distribution of this software and related documentation without an express 8 | * license agreement from NVIDIA CORPORATION is strictly prohibited. 9 | */ 10 | 11 | /** @file render_pointcloud.h 12 | * @author Valts Blukis, NVIDIA 13 | * @brief Header for pointcloud rendering 14 | */ 15 | 16 | #pragma once 17 | 18 | #include 19 | 20 | 21 | namespace rvt { 22 | 23 | void render_feature_pointcloud_to_image( 24 | at::Tensor point_indices, // Index into flattened image. -1 if out of bounds. 
25 | at::Tensor point_depths, 26 | at::Tensor point_features, 27 | at::Tensor image_out, 28 | at::Tensor depth_out, 29 | float default_depth, 30 | float default_color); 31 | 32 | void screen_space_splatting( 33 | int batch_size, 34 | at::Tensor depth_in, 35 | at::Tensor image_in, 36 | at::Tensor depth_out, 37 | at::Tensor image_out, 38 | float default_depth, 39 | float focal_length_px, 40 | float splat_radius, 41 | int kernel_size, 42 | bool orthographic 43 | ); 44 | 45 | void aa_subsample( 46 | at::Tensor depth_in, 47 | at::Tensor image_in, 48 | at::Tensor depth_out, 49 | at::Tensor image_out, 50 | int aa_factor, 51 | float default_depth 52 | ); 53 | 54 | } // namespace rvt -------------------------------------------------------------------------------- /rvt/libs/point-renderer/point_renderer/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import math 13 | from transforms3d import euler, quaternions, affines 14 | import numpy as np 15 | 16 | 17 | def transform_points_batch(pc : torch.Tensor, inv_cam_poses : torch.Tensor): 18 | pc_h = torch.cat([pc, torch.ones_like(pc[:, 0:1])], dim=1) 19 | pc_cam_h = torch.einsum("bxy,ny->bnx", inv_cam_poses, pc_h) 20 | pc_cam = pc_cam_h[:, :, :3] 21 | return pc_cam 22 | 23 | def transform_points(pc : torch.Tensor, inv_cam_pose : torch.Tensor): 24 | pc_h = torch.cat([pc, torch.ones_like(pc[:, 0:1])], dim=1) 25 | pc_cam_h = torch.einsum("xy,ny->nx", inv_cam_pose, pc_h) 26 | pc_cam = pc_cam_h[:, :3] 27 | return pc_cam 28 | 29 | def orthographic_camera_projection_batch(pc_cam : torch.Tensor, K : torch.Tensor): 30 | # For orthographic camera projection, treat all points as if they are at depth 1 31 | uvZ = torch.einsum("bxy,bny->bnx", K, torch.cat([pc_cam[:, :, :2], torch.ones_like(pc_cam[:, :, 2:3])], dim=2)) 32 | return uvZ[:, :, :2] 33 | 34 | def orthographic_camera_projection(pc_cam : torch.Tensor, K : torch.Tensor): 35 | # For orthographic camera projection, treat all points as if they are at depth 1 36 | uvZ = torch.einsum("xy,ny->nx", K, torch.cat([pc_cam[:, :2], torch.ones_like(pc_cam[:, 2:3])], dim=1)) 37 | return uvZ[:, :2] 38 | 39 | def perspective_camera_projection_batch(pc_cam : torch.Tensor, K : torch.Tensor): 40 | uvZ = torch.einsum("bxy,bny->bnx", K, pc_cam) 41 | uv = torch.stack([uvZ[:, :, 0] / uvZ[:, :, 2], uvZ[:, :, 1] / uvZ[:, :, 2]], dim=2) 42 | return uv 43 | 44 | def perspective_camera_projection(pc_cam : torch.Tensor, K : torch.Tensor): 45 | uvZ = torch.einsum("xy,ny->nx", K, pc_cam) 46 | uv = torch.stack([uvZ[:, 0] / uvZ[:, 2], uvZ[:, 1] / uvZ[:, 2]], dim=1) 47 | return uv 48 | 49 | def project_points_3d_to_pixels(pc : torch.Tensor, inv_cam_poses : torch.Tensor, intrinsics : torch.Tensor, orthographic : bool): 50 | """ 51 | This combines the projection from 3D coordinates to camera coordinates using extrinsics, 52 | followed by projection from camera coordinates to pixel coordinates using the intrinsics. 
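    Example (an editor's sketch, not part of the original file; the shapes and
    identity poses/intrinsics are chosen purely for illustration):

        import torch
        pc = torch.rand(1000, 3)                   # world-frame points, [n, 3]
        inv_poses = torch.eye(4).repeat(2, 1, 1)   # [b, 4, 4] world-to-camera transforms
        K = torch.eye(3).repeat(2, 1, 1)           # [b, 3, 3] camera intrinsics
        pc_px, pc_cam = project_points_3d_to_pixels(pc, inv_poses, K, orthographic=True)
        # pc_px: [2, 1000, 2] pixel coordinates; pc_cam: [2, 1000, 3] camera-frame points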
53 | """ 54 | # Project points from world to camera frame 55 | pc_cam = transform_points_batch(pc, inv_cam_poses) 56 | # Project points from camera frame to pixel space 57 | if orthographic: 58 | pc_px = orthographic_camera_projection_batch(pc_cam, intrinsics) 59 | else: 60 | pc_px = perspective_camera_projection_batch(pc_cam, intrinsics) 61 | return pc_px, pc_cam 62 | 63 | 64 | def get_batch_pixel_index(pc_px : torch.Tensor, img_height : int, img_width : int): 65 | """ 66 | Convert a 2D pixel coordinate from a batch of pointclouds to an index 67 | that indexes into a corresponding flattened batch of 2D images. 68 | """ 69 | # batch_idx 70 | batch_idx = torch.arange(pc_px.shape[0], device=pc_px.device, dtype=torch.long)[:, None]#.repeat((1, pc_px.shape[1])). 71 | pix_off = 0.0 72 | row_idx = (pc_px[:, :, 1] + pix_off).long() 73 | col_idx = (pc_px[:, :, 0] + pix_off).long() 74 | pixel_index = batch_idx * img_height * img_width + row_idx * img_width + col_idx 75 | return pixel_index 76 | 77 | def get_pixel_index(pc_px : torch.Tensor, img_width : int): 78 | pix_off = 0.0 79 | row_idx = (pc_px[:, 1] + pix_off).long() 80 | col_idx = (pc_px[:, 0] + pix_off).long() 81 | pixel_index = row_idx * img_width + col_idx 82 | return pixel_index.long() 83 | 84 | 85 | def batch_frustrum_mask(pc_px : torch.Tensor, img_height : int, img_width : int): 86 | imask_x = torch.logical_and(pc_px[:, :, 0] >= 0, pc_px[:, :, 0] < img_width) 87 | imask_y = torch.logical_and(pc_px[:, :, 1] >= 0, pc_px[:, :, 1] < img_height) 88 | imask = torch.logical_and(imask_x, imask_y) 89 | return imask 90 | 91 | def frustrum_mask(pc_px : torch.Tensor, img_height : int, img_width : int): 92 | imask_x = torch.logical_and(pc_px[:, 0] >= 0, pc_px[:, 0] < img_width) 93 | imask_y = torch.logical_and(pc_px[:, 1] >= 0, pc_px[:, 1] < img_height) 94 | imask = torch.logical_and(imask_x, imask_y) 95 | return imask 96 | 97 | def lookat_to_cam_pose(eye, at, up=[0, 0, 1], device="cpu"): 98 | # This runs on CPU, moving to GPU at the end (that's faster) 99 | eye = torch.tensor(eye, device="cpu", dtype=torch.float32) 100 | at = torch.tensor(at, device="cpu", dtype=torch.float32) 101 | 102 | camera_view = F.normalize(at - eye, dim=0) 103 | camera_right = F.normalize(torch.cross(camera_view, torch.tensor(up, dtype=torch.float32, device="cpu"), dim=0), dim=0) 104 | camera_up = F.normalize(torch.cross(camera_right, camera_view, dim=0), dim=0) 105 | 106 | # rotation matrix from opencv conventions 107 | R = torch.stack([camera_right, -camera_up, camera_view], dim=0).T 108 | 109 | T = torch.from_numpy(affines.compose(eye, R, [1, 1, 1])) 110 | return T.float().to(device) 111 | 112 | def fov_and_size_to_intrinsics(fov, img_size, device="cpu"): 113 | img_h, img_w = img_size 114 | fx = img_w / (2 * math.tan(math.radians(fov) / 2)) 115 | fy = img_h / (2 * math.tan(math.radians(fov) / 2)) 116 | 117 | intrinsics = torch.tensor([ 118 | [fx, 0, img_h / 2], 119 | [0, fy, img_w / 2], 120 | [0, 0, 1] 121 | ], dtype=torch.float, device=device) 122 | return intrinsics 123 | 124 | def orthographic_intrinsics_from_scales(img_sizes_w, img_size_px, device="cpu"): 125 | img_h, img_w = img_size_px 126 | fx = img_h / (img_sizes_w[:, 0]) 127 | fy = img_w / (img_sizes_w[:, 1]) 128 | 129 | intrinsics = torch.zeros([len(img_sizes_w), 3, 3], dtype=torch.float, device=device) 130 | intrinsics[:, 0, 0] = fx 131 | intrinsics[:, 1, 1] = fy 132 | intrinsics[:, 0, 2] = img_h / 2 133 | intrinsics[:, 1, 2] = img_w / 2 134 | #intrinsics[:, 2, 2] = 1.0 135 | return intrinsics 136 | 137 | 138 
| def unravel_index(pixel_index : torch.Tensor, img_width : int): 139 | row_idx = pixel_index // img_width 140 | col_idx = pixel_index % img_width 141 | return torch.stack([col_idx, row_idx], dim=1) 142 | 143 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/point_renderer/profiler.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 2021, NVIDIA CORPORATION. 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | # this software and associated documentation files (the "Software"), to deal in 7 | # the Software without restriction, including without limitation the rights to 8 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | # the Software, and to permit persons to whom the Software is furnished to do so, 10 | # subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | 22 | 23 | import time 24 | import torch 25 | import glob 26 | import os 27 | 28 | class bcolors: 29 | HEADER = '\033[95m' 30 | OKBLUE = '\033[94m' 31 | OKGREEN = '\033[92m' 32 | WARNING = '\033[93m' 33 | FAIL = '\033[91m' 34 | ENDC = '\033[0m' 35 | BOLD = '\033[1m' 36 | UNDERLINE = '\033[4m' 37 | 38 | def colorize_time(elapsed): 39 | if elapsed > 1e-3: 40 | return bcolors.FAIL + "{:.3e}".format(elapsed) + bcolors.ENDC 41 | elif elapsed > 1e-4: 42 | return bcolors.WARNING + "{:.3e}".format(elapsed) + bcolors.ENDC 43 | elif elapsed > 1e-5: 44 | return bcolors.OKBLUE + "{:.3e}".format(elapsed) + bcolors.ENDC 45 | else: 46 | return "{:.3e}".format(elapsed) 47 | 48 | def print_gpu_memory(): 49 | torch.cuda.empty_cache() 50 | print(f"{torch.cuda.memory_allocated()//(1024*1024)} mb") 51 | 52 | 53 | class PerfTimer(): 54 | def __init__(self, activate=False, show_memory=False, print_mode=True): 55 | self.activate = activate 56 | if activate: 57 | self.show_memory = show_memory 58 | self.print_mode = print_mode 59 | self.init() 60 | 61 | def init(self): 62 | self.reset() 63 | self.loop_totals_cpu = {} 64 | self.loop_totals_gpu = {} 65 | self.loop_counts = {} 66 | 67 | 68 | def reset(self): 69 | if self.activate: 70 | self.counter = 0 71 | self.prev_time = time.perf_counter() 72 | self.start = torch.cuda.Event(enable_timing=True) 73 | self.end = torch.cuda.Event(enable_timing=True) 74 | self.prev_time_gpu = self.start.record() 75 | 76 | def check(self, name=None): 77 | if self.activate: 78 | cpu_time = time.perf_counter() - self.prev_time 79 | 80 | self.end.record() 81 | torch.cuda.synchronize() 82 | 83 | gpu_time = self.start.elapsed_time(self.end) / 1e3 84 | 85 | # Keep track of averages. 
For this to work, keys need to be unique in a global scope 86 | if name not in self.loop_counts: 87 | self.loop_totals_cpu[name] = 0 88 | self.loop_totals_gpu[name] = 0 89 | self.loop_counts[name] = 0 90 | self.loop_totals_gpu[name] += gpu_time 91 | self.loop_totals_cpu[name] += cpu_time 92 | self.loop_counts[name] += 1 93 | 94 | if self.print_mode and name: 95 | cpu_time_disp = colorize_time(cpu_time) 96 | gpu_time_disp = colorize_time(gpu_time) 97 | cpu_time_disp_avg = colorize_time(self.loop_totals_cpu[name] / self.loop_counts[name]) 98 | gpu_time_disp_avg = colorize_time(self.loop_totals_gpu[name] / self.loop_counts[name]) 99 | 100 | if name: 101 | print(f"CPU Checkpoint {name}: {cpu_time_disp}s (Avg: {cpu_time_disp_avg}s)") 102 | print(f"GPU Checkpoint {name}: {gpu_time_disp}s (Avg: {gpu_time_disp_avg}s)") 103 | else: 104 | print("CPU Checkpoint {}: {} s".format(self.counter, cpu_time_disp)) 105 | print("GPU Checkpoint {}: {} s".format(self.counter, gpu_time_disp)) 106 | if self.show_memory: 107 | #torch.cuda.empty_cache() 108 | print(f"{torch.cuda.memory_allocated()//1048576}MB") 109 | 110 | 111 | self.prev_time = time.perf_counter() 112 | self.prev_time_gpu = self.start.record() 113 | self.counter += 1 114 | return cpu_time, gpu_time 115 | 116 | def get_avg_cpu_time(self, name): 117 | return self.loop_totals_cpu[name] / self.loop_counts[name] 118 | 119 | def get_avg_gpu_time(self, name): 120 | return self.loop_totals_gpu[name] / self.loop_counts[name] 121 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/point_renderer/renderer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
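A minimal usage sketch for the PerfTimer class defined in profiler.py above (an editor's
addition, not part of either source file; it assumes a CUDA device is present, since the
timer records CUDA events):

    import torch
    from point_renderer.profiler import PerfTimer

    timer = PerfTimer(activate=True, show_memory=False, print_mode=True)
    timer.reset()                       # start a fresh measurement window
    x = torch.rand(1024, 1024, device="cuda")
    y = x @ x                           # some GPU work to time
    timer.check("matmul")               # prints CPU/GPU time since the last checkpoint
    print(timer.get_avg_gpu_time("matmul"))

In renderer.py below, the same class is switched on per renderer instance via the
perf_timer constructor flag.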
8 | 9 | 10 | import torch 11 | import math 12 | import point_renderer.ops as ops 13 | from point_renderer.cameras import OrthographicCameras, PerspectiveCameras 14 | import point_renderer._C.render as r 15 | 16 | from point_renderer.profiler import PerfTimer 17 | 18 | 19 | @torch.jit.script 20 | def _prep_render_batch_inputs(points, features, inv_poses, intrinsics, img_h : int, img_w: int, orthographic : bool): 21 | batch_size = len(inv_poses) 22 | num_points = points.shape[0] 23 | 24 | # Project points from 3D world coordinates to pixels and camera coordinates 25 | pc_px, pc_cam = ops.project_points_3d_to_pixels(points, inv_poses, intrinsics, orthographic) 26 | 27 | # Convert pixel coordinates to flattened pixel coordinates, marking out-of-image points with index -1 28 | point_pixel_index = ops.get_batch_pixel_index(pc_px, img_h, img_w) 29 | imask = ops.batch_frustrum_mask(pc_px, img_h, img_w) 30 | point_pixel_index[~imask] = -1 31 | 32 | point_depths = pc_cam[:, :, 2] 33 | return point_pixel_index.reshape([batch_size * num_points]).contiguous(), point_depths.reshape([batch_size * num_points]).contiguous(), features 34 | 35 | 36 | 37 | class PointRenderer: 38 | 39 | def __init__(self, device="cuda", perf_timer=False): 40 | self.device = device 41 | assert "cuda" in self.device, "Currently only a CUDA implementation is available" 42 | self.timer = PerfTimer(activate=perf_timer, show_memory=False, print_mode=perf_timer) 43 | 44 | @torch.no_grad() 45 | def splat_filter(self, batch_size, splat_radius, cameras, point_depths, depth_buf, image_buf, default_depth, splat_max_k): 46 | # This assumes same focal length for all cameras. 47 | if cameras.is_perspective(): 48 | # We are assuming x and y focal lengths to be about the same. 49 | focal_length = cameras.intrinsics[0, 0, 0].item() 50 | closest_point = point_depths.min() 51 | biggest_splat = math.ceil((splat_radius * focal_length / closest_point)) 52 | elif cameras.is_orthographic(): 53 | # In orthographic cameras all points are the same size 54 | focal_length = cameras.intrinsics[0, 0, 0].item() 55 | biggest_splat = math.ceil(focal_length * splat_radius) 56 | else: 57 | raise ValueError(f"Unknown camera type: {type(cameras)}") 58 | 59 | if splat_max_k is None: 60 | splat_max_k = 7 61 | kernel_size = min(biggest_splat * 2 + 1, splat_max_k) 62 | 63 | #print(f"Splatting filter with k={kernel_size}, b={batch_size}") 64 | r.screen_space_splatting( 65 | batch_size, 66 | depth_buf.clone(), 67 | image_buf.clone(), 68 | depth_buf, 69 | image_buf, 70 | default_depth, 71 | focal_length, 72 | splat_radius, 73 | kernel_size, 74 | cameras.is_orthographic() 75 | ) 76 | 77 | 78 | @torch.no_grad() 79 | def render_batch(self, points, features, cameras, img_size, default_depth=0.0, default_color=0.0, splat_radius=None, splat_max_k=51, aa_factor=1): 80 | # Figure out dimensions of the problem 81 | img_h, img_w = img_size 82 | aa_factor = int(aa_factor) 83 | assert aa_factor >= 1, "Antialiasing factor must be greater than 1" 84 | img_h = img_h * aa_factor 85 | img_w = img_w * aa_factor 86 | batch_size = len(cameras) 87 | num_channels = features.shape[1] 88 | 89 | self.timer.reset() 90 | 91 | # Make sure inputs are on the rendering device 92 | cameras = cameras.to(self.device) 93 | features = features.to(self.device) 94 | points = points.to(self.device) 95 | 96 | # Scale the camera if we want to internally render a bigger image for fake antialiasing 97 | if aa_factor > 1: 98 | cameras.scale(aa_factor) 99 | 100 | # Project points to image space depending on the 
camera type 101 | # (these would be more elegant as class methods, but TorchScript doesn't support it) 102 | point_pixel_index, point_depths, features = _prep_render_batch_inputs( 103 | points, features, cameras.inv_poses, cameras.intrinsics, img_h, img_w, type(cameras) == OrthographicCameras) 104 | 105 | # Allocate depth and image render buffers 106 | # Batch size is rolled in with the height dimension 107 | # a.k.a. we render a single image that contains all images vertically stacked 108 | depth_buf = torch.zeros((batch_size * img_h, img_w), device=self.device, dtype=torch.float32) 109 | assert features.dtype == torch.float32, "For now only torch.uint8 and torch.float32 colors are supported" 110 | image_buf = torch.zeros((batch_size * img_h, img_w, num_channels), device=self.device, dtype=torch.float32) 111 | 112 | self.timer.check("render_setup") 113 | 114 | # Render points to pixel buffers 115 | r.render_feature_pointcloud_to_image( 116 | point_pixel_index, 117 | point_depths, 118 | features, 119 | image_buf, 120 | depth_buf, 121 | default_depth, 122 | default_color 123 | ) 124 | 125 | self.timer.check("render") 126 | 127 | # Apply screen-space splatting filter 128 | if splat_radius is not None: 129 | self.splat_filter(batch_size, splat_radius, cameras, point_depths, depth_buf, image_buf, default_depth, splat_max_k) 130 | self.timer.check("splatting") 131 | 132 | # Separate the batch dimension, so that we have a batch of images 133 | image_buf = image_buf.reshape((batch_size, img_h, img_w, num_channels)) 134 | depth_buf = depth_buf.reshape((batch_size, img_h, img_w)) 135 | 136 | if aa_factor != 1: 137 | # Subsample the larger render buffers to produce the output image 138 | img_h_out = img_h // aa_factor 139 | img_w_out = img_w // aa_factor 140 | depth_out = torch.zeros((batch_size, img_h_out, img_w_out), device=self.device, dtype=torch.float32).fill_(default_depth) 141 | image_out = torch.zeros((batch_size, img_h_out, img_w_out, num_channels), device=self.device, dtype=torch.float32) 142 | print(depth_buf.min(), depth_buf.max()) 143 | 144 | r.aa_subsample( 145 | depth_buf, 146 | image_buf, 147 | depth_out, 148 | image_out, 149 | aa_factor, 150 | default_depth 151 | ) 152 | self.timer.check("antialiasing") 153 | return image_out.reshape((batch_size, img_h_out, img_w_out, num_channels)), depth_out.reshape((batch_size, img_h_out, img_w_out)) 154 | else: 155 | # Output render buffers as-is if we're rendering at the same resolution as the output 156 | self.timer.check("antialiasing") 157 | return image_buf, depth_buf 158 | 159 | 160 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/point_renderer/rvt_ops.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import torch 3 | import math 4 | 5 | 6 | # source: https://discuss.pytorch.org/t/batched-index-select/9115/6 7 | def batched_index_select(inp, dim, index): 8 | """ 9 | input: B x * x ... 
x * 10 | dim: 0 < scalar 11 | index: B x M 12 | """ 13 | views = [inp.shape[0]] + [1 if i != dim else -1 for i in range(1, len(inp.shape))] 14 | expanse = list(inp.shape) 15 | expanse[0] = -1 16 | expanse[dim] = -1 17 | index = index.view(views).expand(expanse) 18 | return torch.gather(inp, dim, index) 19 | 20 | 21 | # TODO: break into two functions 22 | def select_feat_from_hm( 23 | pt_cam: torch.Tensor, hm: torch.Tensor, pt_cam_wei: Optional[torch.Tensor] = None 24 | ) -> Tuple[torch.Tensor]: 25 | """ 26 | :param pt_cam: 27 | continuous location of point coordinates from where value needs to be 28 | selected. it is of size [nc, npt, 2], locations in pytorch3d screen 29 | notations 30 | :param hm: size [nc, nw, h, w] 31 | :param pt_cam_wei: 32 | some predifined weight of size [nc, npt], it is used along with the 33 | distance weights 34 | :return: 35 | tuple with the first element being the wighted average for each point 36 | according to the hm values. the size is [nc, npt, nw]. the second and 37 | third elements are intermediate values to be used while chaching 38 | """ 39 | nc, nw, h, w = hm.shape 40 | npt = pt_cam.shape[1] 41 | if pt_cam_wei is None: 42 | pt_cam_wei = torch.ones([nc, npt]).to(hm.device) 43 | 44 | # giving points outside the image zero weight 45 | pt_cam_wei[pt_cam[:, :, 0] < 0] = 0 46 | pt_cam_wei[pt_cam[:, :, 1] < 0] = 0 47 | pt_cam_wei[pt_cam[:, :, 0] > (w - 1)] = 0 48 | pt_cam_wei[pt_cam[:, :, 1] > (h - 1)] = 0 49 | 50 | pt_cam = pt_cam.unsqueeze(2).repeat([1, 1, 4, 1]) 51 | # later used for calculating weight 52 | pt_cam_con = pt_cam.detach().clone() 53 | 54 | # getting discrete grid location of pts in the camera image space 55 | pt_cam[:, :, 0, 0] = torch.floor(pt_cam[:, :, 0, 0]) 56 | pt_cam[:, :, 0, 1] = torch.floor(pt_cam[:, :, 0, 1]) 57 | pt_cam[:, :, 1, 0] = torch.floor(pt_cam[:, :, 1, 0]) 58 | pt_cam[:, :, 1, 1] = torch.ceil(pt_cam[:, :, 1, 1]) 59 | pt_cam[:, :, 2, 0] = torch.ceil(pt_cam[:, :, 2, 0]) 60 | pt_cam[:, :, 2, 1] = torch.floor(pt_cam[:, :, 2, 1]) 61 | pt_cam[:, :, 3, 0] = torch.ceil(pt_cam[:, :, 3, 0]) 62 | pt_cam[:, :, 3, 1] = torch.ceil(pt_cam[:, :, 3, 1]) 63 | pt_cam = pt_cam.long() # [nc, npt, 4, 2] 64 | # since we are taking modulo, points at the edge, i,e at h or w will be 65 | # mapped to 0. this will make their distance from the continous location 66 | # large and hence they won't matter. 
therefore we don't need an explicit 67 | # step to remove such points 68 | pt_cam[:, :, :, 0] = torch.fmod(pt_cam[:, :, :, 0], int(w)) 69 | pt_cam[:, :, :, 1] = torch.fmod(pt_cam[:, :, :, 1], int(h)) 70 | pt_cam[pt_cam < 0] = 0 71 | 72 | # getting normalized weight for each discrete location for pt 73 | # weight based on distance of point from the discrete location 74 | # [nc, npt, 4] 75 | pt_cam_dis = 1 / (torch.sqrt(torch.sum((pt_cam_con - pt_cam) ** 2, dim=-1)) + 1e-10) 76 | pt_cam_wei = pt_cam_wei.unsqueeze(-1) * pt_cam_dis 77 | _pt_cam_wei = torch.sum(pt_cam_wei, dim=-1, keepdim=True) 78 | _pt_cam_wei[_pt_cam_wei == 0.0] = 1 79 | # cached pt_cam_wei in select_feat_from_hm_cache 80 | pt_cam_wei = pt_cam_wei / _pt_cam_wei # [nc, npt, 4] 81 | 82 | # transforming indices from 2D to 1D to use pytorch gather 83 | hm = hm.permute(0, 2, 3, 1).view(nc, h * w, nw) # [nc, h * w, nw] 84 | pt_cam = pt_cam.view(nc, 4 * npt, 2) # [nc, 4 * npt, 2] 85 | # cached pt_cam in select_feat_from_hm_cache 86 | pt_cam = (pt_cam[:, :, 1] * w) + pt_cam[:, :, 0] # [nc, 4 * npt] 87 | # [nc, 4 * npt, nw] 88 | pt_cam_val = batched_index_select(hm, dim=1, index=pt_cam) 89 | # tranforming back each discrete location of point 90 | pt_cam_val = pt_cam_val.view(nc, npt, 4, nw) 91 | # summing weighted contribution of each discrete location of a point 92 | # [nc, npt, nw] 93 | pt_cam_val = torch.sum(pt_cam_val * pt_cam_wei.unsqueeze(-1), dim=2) 94 | return pt_cam_val, pt_cam, pt_cam_wei 95 | 96 | 97 | def select_feat_from_hm_cache( 98 | pt_cam: torch.Tensor, 99 | hm: torch.Tensor, 100 | pt_cam_wei: torch.Tensor, 101 | ) -> torch.Tensor: 102 | """ 103 | Cached version of select_feat_from_hm where we feed in directly the 104 | intermediate value of pt_cam and pt_cam_wei. Look into the original 105 | function to get the meaning of these values and return type. It could be 106 | used while inference if the location of the points remain the same. 
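    Example (an editor's sketch; tensor shapes are illustrative only and follow the
    conventions of select_feat_from_hm above):

        hm = torch.rand(3, 8, 64, 64)            # nc=3 cameras, nw=8 channels
        pt_cam = torch.rand(3, 100, 2) * 63      # npt=100 query points per camera
        val, pt_idx, pt_wei = select_feat_from_hm(pt_cam, hm)
        # if the point locations stay fixed, reuse the cached indices and weights:
        val_cached = select_feat_from_hm_cache(pt_idx, hm, pt_wei)   # [nc, npt, nw]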
107 | """ 108 | 109 | nc, nw, h, w = hm.shape 110 | # transforming indices from 2D to 1D to use pytorch gather 111 | hm = hm.permute(0, 2, 3, 1).view(nc, h * w, nw) # [nc, h * w, nw] 112 | # [nc, 4 * npt, nw] 113 | pt_cam_val = batched_index_select(hm, dim=1, index=pt_cam) 114 | # tranforming back each discrete location of point 115 | pt_cam_val = pt_cam_val.view(nc, -1, 4, nw) 116 | # summing weighted contribution of each discrete location of a point 117 | # [nc, npt, nw] 118 | pt_cam_val = torch.sum(pt_cam_val * pt_cam_wei.unsqueeze(-1), dim=2) 119 | return pt_cam_val 120 | 121 | 122 | # unit tests to verify select_feat_from_hm 123 | def test_select_feat_from_hm(): 124 | def get_out(pt_cam, hm): 125 | nc, nw, d = pt_cam.shape 126 | nc2, c, h, w = hm.shape 127 | assert nc == nc2 128 | assert d == 2 129 | out = torch.zeros((nc, nw, c)) 130 | for i in range(nc): 131 | for j in range(nw): 132 | wx, hx = pt_cam[i, j] 133 | if (wx < 0) or (hx < 0) or (wx > (w - 1)) or (hx > (h - 1)): 134 | out[i, j, :] = 0 135 | else: 136 | coords = ( 137 | (math.floor(wx), math.floor(hx)), 138 | (math.floor(wx), math.ceil(hx)), 139 | (math.ceil(wx), math.floor(hx)), 140 | (math.ceil(wx), math.ceil(hx)), 141 | ) 142 | vals = [] 143 | total = 0 144 | for x, y in coords: 145 | val = 1 / (math.sqrt(((wx - x) ** 2) + ((hx - y) ** 2)) + 1e-10) 146 | vals.append(val) 147 | total += val 148 | 149 | vals = [x / total for x in vals] 150 | 151 | for (x, y), val in zip(coords, vals): 152 | out[i, j] += val * hm[i, :, y, x] 153 | return out 154 | 155 | pt_cam_1 = torch.tensor([[[11.11, 120.1], [37.8, 0.0], [99, 76.5]]]) 156 | hm_1_1 = torch.ones((1, 1, 100, 120)) 157 | hm_1_2 = torch.ones((1, 1, 120, 100)) 158 | out_1 = torch.ones((1, 3, 1)) 159 | out_1[0, 0, 0] = 0 160 | 161 | pt_cam_2 = torch.tensor( 162 | [ 163 | [[11.11, 12.11], [37.8, 0.0]], 164 | [[61.00, 12.00], [123.99, 123.0]], 165 | ] 166 | ) 167 | hm_2_1 = torch.rand((2, 1, 200, 100)) 168 | hm_2_2 = torch.rand((2, 1, 100, 200)) 169 | 170 | test_sets = [ 171 | (pt_cam_1, hm_1_1, out_1), 172 | (pt_cam_1, hm_1_2, out_1), 173 | (pt_cam_2, hm_2_1, get_out(pt_cam_2, hm_2_1)), 174 | (pt_cam_2, hm_2_2, get_out(pt_cam_2, hm_2_2)), 175 | ] 176 | 177 | for i, test in enumerate(test_sets): 178 | pt_cam, hm, out = test 179 | _out, _, _ = select_feat_from_hm(pt_cam, hm) 180 | out = out.float() 181 | if torch.all(torch.abs(_out - out) < 1e-5): 182 | print(f"Passed test {i}, {out}, {_out}") 183 | else: 184 | print(f"Failed test {i}, {out}, {_out}") 185 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/point_renderer/rvt_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import point_renderer.ops as ops 3 | from point_renderer.cameras import OrthographicCameras, PerspectiveCameras 4 | from point_renderer.renderer import PointRenderer 5 | from mvt.utils import ForkedPdb 6 | 7 | import point_renderer.rvt_ops as rvt_ops 8 | 9 | 10 | class RVTBoxRenderer(): 11 | """ 12 | Wrapper around PointRenderer that fixes the cameras to be orthographic cameras 13 | on the faces of a 2x2x2 cube placed at the origin 14 | """ 15 | 16 | def __init__( 17 | self, 18 | img_size, 19 | radius=0.012, 20 | default_color=0.0, 21 | default_depth=-1.0, 22 | antialiasing_factor=1, 23 | pers=False, 24 | normalize_output=True, 25 | with_depth=True, 26 | device="cuda", 27 | perf_timer=False, 28 | strict_input_device=True, 29 | no_down=True, 30 | no_top=False, 31 | three_views=False, 32 | 
two_views=False, 33 | one_view=False, 34 | add_3p=False, 35 | **kwargs): 36 | 37 | self.renderer = PointRenderer(device=device, perf_timer=perf_timer) 38 | 39 | self.img_size = img_size 40 | self.splat_radius = radius 41 | self.default_color = default_color 42 | self.default_depth = default_depth 43 | self.aa_factor = antialiasing_factor 44 | self.normalize_output = normalize_output 45 | self.with_depth = with_depth 46 | 47 | self.strict_input_device = strict_input_device 48 | 49 | # Pre-compute fixed cameras ahead of time 50 | self.cameras = self._get_cube_cameras( 51 | img_size=self.img_size, 52 | orthographic=not pers, 53 | no_down=no_down, 54 | no_top=no_top, 55 | three_views=three_views, 56 | two_views=two_views, 57 | one_view=one_view, 58 | add_3p=add_3p, 59 | ) 60 | self.cameras = self.cameras.to(device) 61 | 62 | # TODO(Valts): add support for dynamic cameras 63 | 64 | # Cache 65 | self._fix_pts_cam = None 66 | self._fix_pts_cam_wei = None 67 | self._pts = None 68 | 69 | # RVT API (that we might want to refactor) 70 | self.num_img = len(self.cameras) 71 | self.only_dyn_cam = False 72 | 73 | def _check_device(self, input, input_name): 74 | if self.strict_input_device: 75 | assert str(input.device) == str(self.renderer.device), ( 76 | f"Input {input_name} (device {input.device}) should be on the same device as the renderer ({self.renderer.device})") 77 | 78 | def _get_cube_cameras( 79 | self, 80 | img_size, 81 | orthographic, 82 | no_down, 83 | no_top, 84 | three_views, 85 | two_views, 86 | one_view, 87 | add_3p, 88 | ): 89 | cam_dict = { 90 | "top": {"eye": [0, 0, 1], "at": [0, 0, 0], "up": [0, 1, 0]}, 91 | "front": {"eye": [1, 0, 0], "at": [0, 0, 0], "up": [0, 0, 1]}, 92 | "down": {"eye": [0, 0, -1], "at": [0, 0, 0], "up": [0, 1, 0]}, 93 | "back": {"eye": [-1, 0, 0], "at": [0, 0, 0], "up": [0, 0, 1]}, 94 | "left": {"eye": [0, -1, 0], "at": [0, 0, 0], "up": [0, 0, 1]}, 95 | "right": {"eye": [0, 0.5, 0], "at": [0, 0, 0], "up": [0, 0, 1]}, 96 | } 97 | 98 | assert not (two_views and three_views) 99 | assert not (one_view and three_views) 100 | assert not (one_view and two_views) 101 | assert not add_3p, "Not supported with point renderer yet," 102 | if two_views or three_views or one_view: 103 | if no_down or no_top or add_3p: 104 | print( 105 | f"WARNING: when three_views={three_views} or two_views={two_views} -- " 106 | f"no_down={no_down} no_top={no_top} add_3p={add_3p} does not matter." 
107 | ) 108 | 109 | if three_views: 110 | cam_names = ["top", "front", "right"] 111 | elif two_views: 112 | cam_names = ["top", "front"] 113 | elif one_view: 114 | cam_names = ["front"] 115 | else: 116 | cam_names = ["top", "front", "down", "back", "left", "right"] 117 | if no_down: 118 | # select index of "down" camera and remove it from the list 119 | del cam_names[cam_names.index("down")] 120 | if no_top: 121 | del cam_names[cam_names.index("top")] 122 | 123 | 124 | cam_list = [cam_dict[n] for n in cam_names] 125 | eyes = [c["eye"] for c in cam_list] 126 | ats = [c["at"] for c in cam_list] 127 | ups = [c["up"] for c in cam_list] 128 | 129 | if orthographic: 130 | # img_sizes_w specifies height and width dimensions of the image in world coordinates 131 | # [2, 2] means it will image coordinates from -1 to 1 in the camera frame 132 | cameras = OrthographicCameras.from_lookat(eyes, ats, ups, img_sizes_w=[2, 2], img_size_px=img_size) 133 | else: 134 | cameras = PerspectiveCameras.from_lookat(eyes, ats, ups, hfov=70, img_size=img_size) 135 | return cameras 136 | 137 | @torch.no_grad() 138 | def get_pt_loc_on_img(self, pt, fix_cam=False, dyn_cam_info=None): 139 | """ 140 | returns the location of a point on the image of the cameras 141 | :param pt: torch.Tensor of shape (bs, np, 3) 142 | :returns: the location of the pt on the image. this is different from the 143 | camera screen coordinate system in pytorch3d. the difference is that 144 | pytorch3d camera screen projects the point to [0, 0] to [H, W]; while the 145 | index on the img is from [0, 0] to [H-1, W-1]. We verified that 146 | the to transform from pytorch3d camera screen point to img we have to 147 | subtract (1/H, 1/W) from the pytorch3d camera screen coordinate. 148 | :return type: torch.Tensor of shape (bs, np, self.num_img, 2) 149 | """ 150 | assert len(pt.shape) == 3 151 | assert pt.shape[-1] == 3 152 | assert fix_cam, "Not supported with point renderer" 153 | assert dyn_cam_info is None, "Not supported with point renderer" 154 | 155 | bs, np, _ = pt.shape 156 | 157 | self._check_device(pt, "pt") 158 | 159 | # TODO(Valts): Ask Ankit what what is the bs dimension here, and treat it correctly here 160 | 161 | pcs_px = [] 162 | for i in range(bs): 163 | pc_px, pc_cam = ops.project_points_3d_to_pixels( 164 | pt[i], self.cameras.inv_poses, self.cameras.intrinsics, self.cameras.is_orthographic()) 165 | pcs_px.append(pc_px) 166 | pcs_px = torch.stack(pcs_px, dim=0) 167 | pcs_px = torch.permute(pcs_px, (0, 2, 1, 3)) 168 | 169 | # TODO(Valts): Double-check with Ankit that these projections are truly pixel-aligned 170 | return pcs_px 171 | 172 | @torch.no_grad() 173 | def get_feat_frm_hm_cube(self, hm, fix_cam=False, dyn_cam_info=None): 174 | """ 175 | :param hm: torch.Tensor of (1, num_img, h, w) 176 | :return: tupe of ((num_img, h^3, 1), (h^3, 3)) 177 | """ 178 | x, nc, h, w = hm.shape 179 | assert x == 1 180 | assert nc == self.num_img 181 | assert self.img_size == (h, w) 182 | assert fix_cam, "Not supported with point renderer" 183 | assert dyn_cam_info is None, "Not supported with point renderer" 184 | 185 | self._check_device(hm, "hm") 186 | 187 | if self._pts is None: 188 | res = self.img_size[0] 189 | pts = torch.linspace(-1 + (1 / res), 1 - (1 / res), res).to(hm.device) 190 | pts = torch.cartesian_prod(pts, pts, pts) 191 | self._pts = pts 192 | 193 | pts_hm = [] 194 | 195 | # if self._fix_cam 196 | if self._fix_pts_cam is None: 197 | # (np, nc, 2) 198 | pts_img = self.get_pt_loc_on_img(self._pts.unsqueeze(0), 199 | 
fix_cam=True).squeeze(0) 200 | # pts_img = pts_img.permute((1, 0, 2)) 201 | # (nc, np, bs) 202 | fix_pts_hm, pts_cam, pts_cam_wei = rvt_ops.select_feat_from_hm( 203 | pts_img.transpose(0, 1), hm.transpose(0, 1)[0 : len(self.cameras)] 204 | ) 205 | self._fix_pts_img = pts_img 206 | self._fix_pts_cam = pts_cam 207 | self._fix_pts_cam_wei = pts_cam_wei 208 | else: 209 | pts_cam = self._fix_pts_cam 210 | pts_cam_wei = self._fix_pts_cam_wei 211 | fix_pts_hm = rvt_ops.select_feat_from_hm_cache( 212 | pts_cam, hm.transpose(0, 1)[0 : len(self.cameras)], pts_cam_wei 213 | ) 214 | pts_hm.append(fix_pts_hm) 215 | 216 | #if not dyn_cam_info is None: 217 | # TODO(Valts): implement 218 | pts_hm = torch.cat(pts_hm, 0) 219 | return pts_hm, self._pts 220 | 221 | @torch.no_grad() 222 | def get_max_3d_frm_hm_cube(self, hm, fix_cam=False, dyn_cam_info=None, 223 | topk=1, non_max_sup=False, 224 | non_max_sup_dist=0.02): 225 | """ 226 | given set of heat maps, return the 3d location of the point with the 227 | largest score, assumes the points are in a cube [-1, 1]. This function 228 | should be used along with the render. For standalone version look for 229 | the other function with same name in the file. 230 | :param hm: (1, nc, h, w) 231 | :return: (1, topk, 3) 232 | """ 233 | assert fix_cam, "Not supported with point renderer" 234 | assert dyn_cam_info is None, "Not supported with point renderer" 235 | 236 | self._check_device(hm, "hm") 237 | 238 | x, nc, h, w = hm.shape 239 | assert x == 1 240 | assert nc == len(self.cameras) 241 | assert self.img_size == (h, w) 242 | 243 | pts_hm, pts = self.get_feat_frm_hm_cube(hm, fix_cam, dyn_cam_info) 244 | # (bs, np, nc) 245 | pts_hm = pts_hm.permute(2, 1, 0) 246 | # (bs, np) 247 | pts_hm = torch.mean(pts_hm, -1) 248 | if non_max_sup and topk > 1: 249 | _pts = pts.clone() 250 | pts = [] 251 | pts_hm = torch.squeeze(pts_hm, 0) 252 | for i in range(topk): 253 | ind_max_pts = torch.argmax(pts_hm, -1) 254 | sel_pts = _pts[ind_max_pts] 255 | pts.append(sel_pts) 256 | dist = torch.sqrt(torch.sum((_pts - sel_pts) ** 2, -1)) 257 | pts_hm[dist < non_max_sup_dist] = -1 258 | pts = torch.stack(pts, 0).unsqueeze(0) 259 | else: 260 | # (bs, topk) 261 | ind_max_pts = torch.topk(pts_hm, topk)[1] 262 | # (bs, topk, 3) 263 | pts = pts[ind_max_pts] 264 | return pts 265 | 266 | def __call__(self, pc, feat, fix_cam=False, dyn_cam_info=None): 267 | 268 | self._check_device(pc, "pc") 269 | self._check_device(pc, "feat") 270 | 271 | pc_images, pc_depths = self.renderer.render_batch(pc, feat, 272 | cameras=self.cameras, 273 | img_size=self.img_size, 274 | splat_radius=self.splat_radius, 275 | default_color=self.default_color, 276 | default_depth=self.default_depth, 277 | aa_factor=self.aa_factor 278 | ) 279 | 280 | if self.normalize_output: 281 | _, h, w = pc_depths.shape 282 | depth_0 = pc_depths == -1 283 | depth_sum = torch.sum(pc_depths, (1, 2)) + torch.sum(depth_0, (1, 2)) 284 | depth_mean = depth_sum / ((h * w) - torch.sum(depth_0, (1, 2))) 285 | pc_depths -= depth_mean.unsqueeze(-1).unsqueeze(-1) 286 | pc_depths[depth_0] = -1 287 | 288 | if self.with_depth: 289 | img_out = torch.cat([pc_images, pc_depths[:, :, :, None]], dim=-1) 290 | else: 291 | img_out = pc_images 292 | 293 | return img_out 294 | -------------------------------------------------------------------------------- /rvt/libs/point-renderer/requirements.txt: -------------------------------------------------------------------------------- 1 | transforms3d 2 | jupyter 3 | numpy 4 | torch 5 | matplotlib 6 | imageio 7 | 
trimesh 8 | meshcat -------------------------------------------------------------------------------- /rvt/libs/point-renderer/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | import os 11 | import sys 12 | from setuptools import setup, find_packages, dist 13 | import glob 14 | import logging 15 | 16 | import torch 17 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 18 | 19 | PACKAGE_NAME = 'point_renderer' 20 | DESCRIPTION = 'Fast Point Cloud Renderer' 21 | URL = 'https://gitlab-master.nvidia.com/vblukis/point-renderer' 22 | AUTHOR = 'Valts Blukis' 23 | LICENSE = 'NVIDIA' 24 | version = '0.2.0' 25 | 26 | 27 | def get_extensions(): 28 | extra_compile_args = {'cxx': ['-O3']} 29 | define_macros = [] 30 | include_dirs = [] 31 | extensions = [] 32 | sources = glob.glob('point_renderer/csrc/**/*.cpp', recursive=True) 33 | 34 | if len(sources) == 0: 35 | print("No source files found for extension, skipping extension compilation") 36 | return None 37 | 38 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1' or True: 39 | define_macros += [("WITH_CUDA", None), ("THRUST_IGNORE_CUB_VERSION_CHECK", None)] 40 | sources += glob.glob('point_renderer/csrc/**/*.cu', recursive=True) 41 | extension = CUDAExtension 42 | extra_compile_args.update({'nvcc': ['-O3']}) 43 | #include_dirs = get_include_dirs() 44 | else: 45 | assert(False, "CUDA is not available. Set FORCE_CUDA=1 for Docker builds") 46 | 47 | extensions.append( 48 | extension( 49 | name='point_renderer._C', 50 | sources=sources, 51 | define_macros=define_macros, 52 | extra_compile_args=extra_compile_args, 53 | #include_dirs=include_dirs 54 | ) 55 | ) 56 | 57 | for ext in extensions: 58 | ext.libraries = ['cudart_static' if x == 'cudart' else x 59 | for x in ext.libraries] 60 | 61 | return extensions 62 | 63 | 64 | if __name__ == '__main__': 65 | setup( 66 | # Metadata 67 | name=PACKAGE_NAME, 68 | version=version, 69 | author=AUTHOR, 70 | description=DESCRIPTION, 71 | url=URL, 72 | license=LICENSE, 73 | python_requires='>=3.7', 74 | 75 | # Package info 76 | packages=['point_renderer'], 77 | include_package_data=True, 78 | zip_safe=True, 79 | ext_modules=get_extensions(), 80 | cmdclass={ 81 | 'build_ext': BuildExtension.with_options(no_python_abi_suffix=True) 82 | } 83 | 84 | ) 85 | -------------------------------------------------------------------------------- /rvt/models/peract_official.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
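An editor's aside on the point-renderer setup.py above (not part of this file): the CUDA
branch in get_extensions() is forced on by the trailing "or True", and the fallback
"assert(False, ...)" asserts a non-empty tuple, so it can never fire. A hedged sketch of
the intended guard:

    # Sketch only: an explicit check in place of the tuple assert in get_extensions()
    if not (torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1"):
        raise RuntimeError("CUDA is not available. Set FORCE_CUDA=1 for Docker builds")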
4 | 5 | from rvt.libs.peract.helpers.preprocess_agent import PreprocessAgent 6 | from rvt.libs.peract.agents.peract_bc.launch_utils import create_agent 7 | 8 | 9 | class PreprocessAgent2(PreprocessAgent): 10 | def eval(self): 11 | self._pose_agent._qattention_agents[0]._q.eval() 12 | 13 | def train(self): 14 | self._pose_agent._qattention_agents[0]._q.train() 15 | 16 | def build(self, *args, **kwargs): 17 | super().build(*args, **kwargs) 18 | 19 | self._device = self._pose_agent._qattention_agents[0]._device 20 | 21 | 22 | def create_agent_our(cfg): 23 | """ 24 | Reuses the official peract agent, but replaces PreprocessAgent2 with PreprocessAgent 25 | """ 26 | agent = create_agent(cfg) 27 | agent = agent._pose_agent 28 | agent = PreprocessAgent2(agent) 29 | return agent 30 | -------------------------------------------------------------------------------- /rvt/mvt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | -------------------------------------------------------------------------------- /rvt/mvt/attn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | # Sources: 6 | # https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_io.py 7 | # https://github.com/peract/peract/blob/main/helpers/network_utils.py 8 | 9 | import math 10 | from math import log 11 | from functools import wraps 12 | from packaging import version 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | from tqdm import tqdm 18 | from torch import nn, einsum 19 | from einops import rearrange, repeat 20 | 21 | try: 22 | import xformers.ops as xops 23 | except ImportError as e: 24 | xops = None 25 | 26 | LRELU_SLOPE = 0.02 27 | 28 | 29 | def exists(val): 30 | return val is not None 31 | 32 | 33 | def default(val, d): 34 | return val if exists(val) else d 35 | 36 | 37 | def cache_fn(f): 38 | cache = None 39 | 40 | @wraps(f) 41 | def cached_fn(*args, _cache=True, **kwargs): 42 | if not _cache: 43 | return f(*args, **kwargs) 44 | nonlocal cache 45 | if cache is not None: 46 | return cache 47 | cache = f(*args, **kwargs) 48 | return cache 49 | 50 | return cached_fn 51 | 52 | 53 | class PreNorm(nn.Module): 54 | def __init__(self, dim, fn, context_dim=None): 55 | super().__init__() 56 | self.fn = fn 57 | self.norm = nn.LayerNorm(dim) 58 | self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None 59 | 60 | def forward(self, x, **kwargs): 61 | x = self.norm(x) 62 | 63 | if exists(self.norm_context): 64 | context = kwargs["context"] 65 | normed_context = self.norm_context(context) 66 | kwargs.update(context=normed_context) 67 | 68 | return self.fn(x, **kwargs) 69 | 70 | 71 | class GEGLU(nn.Module): 72 | def forward(self, x): 73 | x, gates = x.chunk(2, dim=-1) 74 | return x * F.gelu(gates) 75 | 76 | 77 | class FeedForward(nn.Module): 78 | def __init__(self, dim, mult=4): 79 | super().__init__() 80 | self.net = nn.Sequential( 81 | nn.Linear(dim, dim * mult * 2), GEGLU(), nn.Linear(dim * mult, dim) 82 | ) 83 | 84 | def forward(self, x): 85 | return self.net(x) 86 | 87 | 88 | class Attention(nn.Module): # is all you need. Living up to its name. 
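    # Editor's note (not part of the original file): a minimal usage sketch of this
    # block; with the default use_fast=False the einsum attention path below is used.
    #
    #   attn = Attention(query_dim=512, context_dim=256, heads=8, dim_head=64)
    #   x = torch.rand(2, 100, 512)     # queries: [batch, n_query, query_dim]
    #   ctx = torch.rand(2, 20, 256)    # context: [batch, n_ctx, context_dim]
    #   out = attn(x, context=ctx)      # -> [2, 100, 512]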
89 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, 90 | dropout=0.0, use_fast=False): 91 | 92 | super().__init__() 93 | self.use_fast = use_fast 94 | inner_dim = dim_head * heads 95 | context_dim = default(context_dim, query_dim) 96 | self.scale = dim_head**-0.5 97 | self.heads = heads 98 | 99 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 100 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) 101 | self.to_out = nn.Linear(inner_dim, query_dim) 102 | 103 | self.dropout_p = dropout 104 | # dropout left in use_fast for backward compatibility 105 | self.dropout = nn.Dropout(self.dropout_p) 106 | 107 | self.avail_xf = False 108 | if self.use_fast: 109 | if not xops is None: 110 | self.avail_xf = True 111 | else: 112 | self.use_fast = False 113 | 114 | def forward(self, x, context=None, mask=None): 115 | h = self.heads 116 | 117 | q = self.to_q(x) 118 | context = default(context, x) 119 | k, v = self.to_kv(context).chunk(2, dim=-1) 120 | 121 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 122 | if self.use_fast: 123 | # using py2 if available 124 | dropout_p = self.dropout_p if self.training else 0.0 125 | # using xf if available 126 | if self.avail_xf: 127 | out = xops.memory_efficient_attention( 128 | query=q, key=k, value=v, p=dropout_p 129 | ) 130 | else: 131 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 132 | if exists(mask): 133 | mask = rearrange(mask, "b ... -> b (...)") 134 | max_neg_value = -torch.finfo(sim.dtype).max 135 | mask = repeat(mask, "b j -> (b h) () j", h=h) 136 | sim.masked_fill_(~mask, max_neg_value) 137 | # attention 138 | attn = sim.softmax(dim=-1) 139 | # dropout 140 | attn = self.dropout(attn) 141 | out = einsum("b i j, b j d -> b i d", attn, v) 142 | 143 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 144 | out = self.to_out(out) 145 | return out 146 | 147 | 148 | def act_layer(act): 149 | if act == "relu": 150 | return nn.ReLU() 151 | elif act == "lrelu": 152 | return nn.LeakyReLU(LRELU_SLOPE) 153 | elif act == "elu": 154 | return nn.ELU() 155 | elif act == "tanh": 156 | return nn.Tanh() 157 | elif act == "prelu": 158 | return nn.PReLU() 159 | else: 160 | raise ValueError("%s not recognized." % act) 161 | 162 | 163 | def norm_layer2d(norm, channels): 164 | if norm == "batch": 165 | return nn.BatchNorm2d(channels) 166 | elif norm == "instance": 167 | return nn.InstanceNorm2d(channels, affine=True) 168 | elif norm == "layer": 169 | return nn.GroupNorm(1, channels, affine=True) 170 | elif norm == "group": 171 | return nn.GroupNorm(4, channels, affine=True) 172 | else: 173 | raise ValueError("%s not recognized." % norm) 174 | 175 | 176 | def norm_layer1d(norm, num_channels): 177 | if norm == "batch": 178 | return nn.BatchNorm1d(num_channels) 179 | elif norm == "instance": 180 | return nn.InstanceNorm1d(num_channels, affine=True) 181 | elif norm == "layer": 182 | return nn.LayerNorm(num_channels) 183 | elif norm == "group": 184 | return nn.GroupNorm(4, num_channels, affine=True) 185 | else: 186 | raise ValueError("%s not recognized." 
% norm) 187 | 188 | 189 | class Conv2DBlock(nn.Module): 190 | def __init__( 191 | self, 192 | in_channels, 193 | out_channels, 194 | kernel_sizes=3, 195 | strides=1, 196 | norm=None, 197 | activation=None, 198 | padding_mode="replicate", 199 | padding=None, 200 | ): 201 | super().__init__() 202 | padding = kernel_sizes // 2 if padding is None else padding 203 | self.conv2d = nn.Conv2d( 204 | in_channels, 205 | out_channels, 206 | kernel_sizes, 207 | strides, 208 | padding=padding, 209 | padding_mode=padding_mode, 210 | ) 211 | 212 | if activation is None: 213 | nn.init.xavier_uniform_( 214 | self.conv2d.weight, gain=nn.init.calculate_gain("linear") 215 | ) 216 | nn.init.zeros_(self.conv2d.bias) 217 | elif activation == "tanh": 218 | nn.init.xavier_uniform_( 219 | self.conv2d.weight, gain=nn.init.calculate_gain("tanh") 220 | ) 221 | nn.init.zeros_(self.conv2d.bias) 222 | elif activation == "lrelu": 223 | nn.init.kaiming_uniform_( 224 | self.conv2d.weight, a=LRELU_SLOPE, nonlinearity="leaky_relu" 225 | ) 226 | nn.init.zeros_(self.conv2d.bias) 227 | elif activation == "relu": 228 | nn.init.kaiming_uniform_(self.conv2d.weight, nonlinearity="relu") 229 | nn.init.zeros_(self.conv2d.bias) 230 | else: 231 | raise ValueError() 232 | 233 | self.activation = None 234 | if norm is not None: 235 | self.norm = norm_layer2d(norm, out_channels) 236 | else: 237 | self.norm = None 238 | if activation is not None: 239 | self.activation = act_layer(activation) 240 | self.out_channels = out_channels 241 | 242 | def forward(self, x): 243 | x = self.conv2d(x) 244 | x = self.norm(x) if self.norm is not None else x 245 | x = self.activation(x) if self.activation is not None else x 246 | return x 247 | 248 | 249 | class Conv2DUpsampleBlock(nn.Module): 250 | def __init__( 251 | self, 252 | in_channels, 253 | out_channels, 254 | strides, 255 | kernel_sizes=3, 256 | norm=None, 257 | activation=None, 258 | out_size=None, 259 | ): 260 | super().__init__() 261 | layer = [ 262 | Conv2DBlock(in_channels, out_channels, kernel_sizes, 1, norm, activation) 263 | ] 264 | if strides > 1: 265 | if out_size is None: 266 | layer.append( 267 | nn.Upsample(scale_factor=strides, mode="bilinear", align_corners=False) 268 | ) 269 | else: 270 | layer.append( 271 | nn.Upsample(size=out_size, mode="bilinear", align_corners=False) 272 | ) 273 | 274 | if out_size is not None: 275 | if kernel_sizes % 2 == 0: 276 | kernel_sizes += 1 277 | 278 | convt_block = Conv2DBlock( 279 | out_channels, out_channels, kernel_sizes, 1, norm, activation 280 | ) 281 | layer.append(convt_block) 282 | self.conv_up = nn.Sequential(*layer) 283 | 284 | def forward(self, x): 285 | return self.conv_up(x) 286 | 287 | 288 | class DenseBlock(nn.Module): 289 | def __init__(self, in_features, out_features, norm=None, activation=None): 290 | super(DenseBlock, self).__init__() 291 | self.linear = nn.Linear(in_features, out_features) 292 | 293 | if activation is None: 294 | nn.init.xavier_uniform_( 295 | self.linear.weight, gain=nn.init.calculate_gain("linear") 296 | ) 297 | nn.init.zeros_(self.linear.bias) 298 | elif activation == "tanh": 299 | nn.init.xavier_uniform_( 300 | self.linear.weight, gain=nn.init.calculate_gain("tanh") 301 | ) 302 | nn.init.zeros_(self.linear.bias) 303 | elif activation == "lrelu": 304 | nn.init.kaiming_uniform_( 305 | self.linear.weight, a=LRELU_SLOPE, nonlinearity="leaky_relu" 306 | ) 307 | nn.init.zeros_(self.linear.bias) 308 | elif activation == "relu": 309 | nn.init.kaiming_uniform_(self.linear.weight, nonlinearity="relu") 310 | 
nn.init.zeros_(self.linear.bias) 311 | else: 312 | raise ValueError() 313 | 314 | self.activation = None 315 | self.norm = None 316 | if norm is not None: 317 | self.norm = norm_layer1d(norm, out_features) 318 | if activation is not None: 319 | self.activation = act_layer(activation) 320 | 321 | def forward(self, x): 322 | x = self.linear(x) 323 | x = self.norm(x) if self.norm is not None else x 324 | x = self.activation(x) if self.activation is not None else x 325 | return x 326 | 327 | 328 | # based on https://pytorch.org/tutorials/beginner/transformer_tutorial.html 329 | class FixedPositionalEncoding(nn.Module): 330 | def __init__(self, feat_per_dim: int, feat_scale_factor: int): 331 | super().__init__() 332 | self.feat_scale_factor = feat_scale_factor 333 | # shape [1, feat_per_dim // 2] 334 | div_term = torch.exp( 335 | torch.arange(0, feat_per_dim, 2) * (-math.log(10000.0) / 336 | feat_per_dim) 337 | ).unsqueeze(0) 338 | self.register_buffer("div_term", div_term) 339 | 340 | def forward(self, x): 341 | """ 342 | :param x: Tensor, shape [batch_size, input_dim] 343 | :return: Tensor, shape [batch_size, input_dim * feat_per_dim] 344 | """ 345 | assert len(x.shape) == 2 346 | batch_size, input_dim = x.shape 347 | x = x.view(-1, 1) 348 | x = torch.cat(( 349 | torch.sin(self.feat_scale_factor * x * self.div_term), 350 | torch.cos(self.feat_scale_factor * x * self.div_term)), dim=1) 351 | x = x.view(batch_size, -1) 352 | return x 353 | -------------------------------------------------------------------------------- /rvt/mvt/aug_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | # Adapted from: https://github.com/stepjam/ARM/blob/main/arm/utils.py 6 | # utils functions for rotation augmentation 7 | import torch 8 | import numpy as np 9 | from scipy.spatial.transform import Rotation 10 | 11 | 12 | def rand_dist(size, min=-1.0, max=1.0): 13 | return (max - min) * torch.rand(size) + min 14 | 15 | 16 | def rand_discrete(size, min=0, max=1): 17 | if min == max: 18 | return torch.zeros(size) 19 | return torch.randint(min, max + 1, size) 20 | 21 | 22 | def normalize_quaternion(quat): 23 | return np.array(quat) / np.linalg.norm(quat, axis=-1, keepdims=True) 24 | 25 | 26 | def sensitive_gimble_fix(euler): 27 | """ 28 | :param euler: euler angles in degree as np.ndarray in shape either [3] or 29 | [b, 3] 30 | """ 31 | # selecting sensitive angle 32 | select1 = (89 < euler[..., 1]) & (euler[..., 1] < 91) 33 | euler[select1, 1] = 90 34 | # selecting sensitive angle 35 | select2 = (-91 < euler[..., 1]) & (euler[..., 1] < -89) 36 | euler[select2, 1] = -90 37 | 38 | # recalulating the euler angles, see assert 39 | r = Rotation.from_euler("xyz", euler, degrees=True) 40 | euler = r.as_euler("xyz", degrees=True) 41 | 42 | select = select1 | select2 43 | assert (euler[select][..., 2] == 0).all(), euler 44 | 45 | return euler 46 | 47 | 48 | def quaternion_to_discrete_euler(quaternion, resolution, gimble_fix=True): 49 | """ 50 | :param gimble_fix: the euler values for x and y can be very sensitive 51 | around y=90 degrees. this leads to a multimodal distribution of x and y 52 | which could be hard for a network to learn. When gimble_fix is true, around 53 | y=90, we change the mode towards x=0, potentially making it easy for the 54 | network to learn. 
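    Example (an editor's sketch): with a 5 degree resolution, the identity quaternion
    (xyzw order, as expected by scipy) maps every axis to bin 180 / 5 = 36:

        disc = quaternion_to_discrete_euler(np.array([[0.0, 0.0, 0.0, 1.0]]), resolution=5)
        # disc -> array([[36, 36, 36]])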
55 | """ 56 | r = Rotation.from_quat(quaternion) 57 | 58 | euler = r.as_euler("xyz", degrees=True) 59 | if gimble_fix: 60 | euler = sensitive_gimble_fix(euler) 61 | 62 | euler += 180 63 | assert np.min(euler) >= 0 and np.max(euler) <= 360 64 | disc = np.around((euler / resolution)).astype(int) 65 | disc[disc == int(360 / resolution)] = 0 66 | return disc 67 | 68 | 69 | def quaternion_to_euler(quaternion, gimble_fix=True): 70 | """ 71 | :param gimble_fix: the euler values for x and y can be very sensitive 72 | around y=90 degrees. this leads to a multimodal distribution of x and y 73 | which could be hard for a network to learn. When gimble_fix is true, around 74 | y=90, we change the mode towards x=0, potentially making it easy for the 75 | network to learn. 76 | """ 77 | r = Rotation.from_quat(quaternion) 78 | 79 | euler = r.as_euler("xyz", degrees=True) 80 | if gimble_fix: 81 | euler = sensitive_gimble_fix(euler) 82 | 83 | euler += 180 84 | return euler 85 | 86 | 87 | def discrete_euler_to_quaternion(discrete_euler, resolution): 88 | euluer = (discrete_euler * resolution) - 180 89 | return Rotation.from_euler("xyz", euluer, degrees=True).as_quat() 90 | 91 | 92 | def point_to_voxel_index( 93 | point: np.ndarray, voxel_size: np.ndarray, coord_bounds: np.ndarray 94 | ): 95 | bb_mins = np.array(coord_bounds[0:3]) 96 | bb_maxs = np.array(coord_bounds[3:]) 97 | dims_m_one = np.array([voxel_size] * 3) - 1 98 | bb_ranges = bb_maxs - bb_mins 99 | res = bb_ranges / (np.array([voxel_size] * 3) + 1e-12) 100 | voxel_indicy = np.minimum( 101 | np.floor((point - bb_mins) / (res + 1e-12)).astype(np.int32), dims_m_one 102 | ) 103 | return voxel_indicy 104 | -------------------------------------------------------------------------------- /rvt/mvt/augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | import numpy as np 6 | import torch 7 | import rvt.mvt.aug_utils as aug_utils 8 | from pytorch3d import transforms as torch3d_tf 9 | from scipy.spatial.transform import Rotation 10 | 11 | 12 | def perturb_se3(pcd, trans_shift_4x4, rot_shift_4x4, action_gripper_4x4, bounds): 13 | """Perturb point clouds with given transformation. 14 | :param pcd: 15 | Either: 16 | - list of point clouds [[bs, 3, H, W], ...] 
for N cameras 17 | - point cloud [bs, 3, H, W] 18 | - point cloud [bs, 3, num_point] 19 | - point cloud [bs, num_point, 3] 20 | :param trans_shift_4x4: translation matrix [bs, 4, 4] 21 | :param rot_shift_4x4: rotation matrix [bs, 4, 4] 22 | :param action_gripper_4x4: original keyframe action gripper pose [bs, 4, 4] 23 | :param bounds: metric scene bounds [bs, 6] 24 | :return: peturbed point clouds in the same format as input 25 | """ 26 | # batch bounds if necessary 27 | 28 | # for easier compatibility 29 | single_pc = False 30 | if not isinstance(pcd, list): 31 | single_pc = True 32 | pcd = [pcd] 33 | 34 | bs = pcd[0].shape[0] 35 | if bounds.shape[0] != bs: 36 | bounds = bounds.repeat(bs, 1) 37 | 38 | perturbed_pcd = [] 39 | for p in pcd: 40 | p_shape = p.shape 41 | permute_p = False 42 | if len(p.shape) == 3: 43 | if p_shape[-1] == 3: 44 | num_points = p_shape[-2] 45 | p = p.permute(0, 2, 1) 46 | permute_p = True 47 | elif p_shape[-2] == 3: 48 | num_points = p_shape[-1] 49 | else: 50 | assert False, p_shape 51 | 52 | elif len(p.shape) == 4: 53 | assert p_shape[-1] != 3, p_shape[-1] 54 | assert p_shape[-2] != 3, p_shape[-2] 55 | num_points = p_shape[-1] * p_shape[-2] 56 | 57 | else: 58 | assert False, len(p.shape) 59 | 60 | action_trans_3x1 = ( 61 | action_gripper_4x4[:, 0:3, 3].unsqueeze(-1).repeat(1, 1, num_points) 62 | ) 63 | trans_shift_3x1 = ( 64 | trans_shift_4x4[:, 0:3, 3].unsqueeze(-1).repeat(1, 1, num_points) 65 | ) 66 | 67 | # flatten point cloud 68 | p_flat = p.reshape(bs, 3, -1) 69 | p_flat_4x1_action_origin = torch.ones(bs, 4, p_flat.shape[-1]).to(p_flat.device) 70 | 71 | # shift points to have action_gripper pose as the origin 72 | p_flat_4x1_action_origin[:, :3, :] = p_flat - action_trans_3x1 73 | 74 | # apply rotation 75 | perturbed_p_flat_4x1_action_origin = torch.bmm( 76 | p_flat_4x1_action_origin.transpose(2, 1), rot_shift_4x4 77 | ).transpose(2, 1) 78 | 79 | # apply bounded translations 80 | bounds_x_min, bounds_x_max = bounds[:, 0].min(), bounds[:, 3].max() 81 | bounds_y_min, bounds_y_max = bounds[:, 1].min(), bounds[:, 4].max() 82 | bounds_z_min, bounds_z_max = bounds[:, 2].min(), bounds[:, 5].max() 83 | 84 | action_then_trans_3x1 = action_trans_3x1 + trans_shift_3x1 85 | action_then_trans_3x1_x = torch.clamp( 86 | action_then_trans_3x1[:, 0], min=bounds_x_min, max=bounds_x_max 87 | ) 88 | action_then_trans_3x1_y = torch.clamp( 89 | action_then_trans_3x1[:, 1], min=bounds_y_min, max=bounds_y_max 90 | ) 91 | action_then_trans_3x1_z = torch.clamp( 92 | action_then_trans_3x1[:, 2], min=bounds_z_min, max=bounds_z_max 93 | ) 94 | action_then_trans_3x1 = torch.stack( 95 | [action_then_trans_3x1_x, action_then_trans_3x1_y, action_then_trans_3x1_z], 96 | dim=1, 97 | ) 98 | 99 | # shift back the origin 100 | perturbed_p_flat_3x1 = ( 101 | perturbed_p_flat_4x1_action_origin[:, :3, :] + action_then_trans_3x1 102 | ) 103 | if permute_p: 104 | perturbed_p_flat_3x1 = torch.permute(perturbed_p_flat_3x1, (0, 2, 1)) 105 | perturbed_p = perturbed_p_flat_3x1.reshape(p_shape) 106 | perturbed_pcd.append(perturbed_p) 107 | 108 | if single_pc: 109 | perturbed_pcd = perturbed_pcd[0] 110 | 111 | return perturbed_pcd 112 | 113 | 114 | # version copied from peract: 115 | # https://github.com/peract/peract/blob/a3b0bd855d7e749119e4fcbe3ed7168ba0f283fd/voxel/augmentation.py#L68 116 | def apply_se3_augmentation( 117 | pcd, 118 | action_gripper_pose, 119 | action_trans, 120 | action_rot_grip, 121 | bounds, 122 | layer, 123 | trans_aug_range, 124 | rot_aug_range, 125 | rot_aug_resolution, 126 | 
voxel_size, 127 | rot_resolution, 128 | device, 129 | ): 130 | """Apply SE3 augmentation to a point clouds and actions. 131 | :param pcd: list of point clouds [[bs, 3, H, W], ...] for N cameras 132 | :param action_gripper_pose: 6-DoF pose of keyframe action [bs, 7] 133 | :param action_trans: discretized translation action [bs, 3] 134 | :param action_rot_grip: discretized rotation and gripper action [bs, 4] 135 | :param bounds: metric scene bounds of voxel grid [bs, 6] 136 | :param layer: voxelization layer (always 1 for PerAct) 137 | :param trans_aug_range: range of translation augmentation [x_range, y_range, z_range] 138 | :param rot_aug_range: range of rotation augmentation [x_range, y_range, z_range] 139 | :param rot_aug_resolution: degree increments for discretized augmentation rotations 140 | :param voxel_size: voxelization resoltion 141 | :param rot_resolution: degree increments for discretized rotations 142 | :param device: torch device 143 | :return: perturbed action_trans, action_rot_grip, pcd 144 | """ 145 | 146 | # batch size 147 | bs = pcd[0].shape[0] 148 | 149 | # identity matrix 150 | identity_4x4 = torch.eye(4).unsqueeze(0).repeat(bs, 1, 1).to(device=device) 151 | 152 | # 4x4 matrix of keyframe action gripper pose 153 | action_gripper_trans = action_gripper_pose[:, :3] 154 | action_gripper_quat_wxyz = torch.cat( 155 | (action_gripper_pose[:, 6].unsqueeze(1), action_gripper_pose[:, 3:6]), dim=1 156 | ) 157 | action_gripper_rot = torch3d_tf.quaternion_to_matrix(action_gripper_quat_wxyz) 158 | action_gripper_4x4 = identity_4x4.detach().clone() 159 | action_gripper_4x4[:, :3, :3] = action_gripper_rot 160 | action_gripper_4x4[:, 0:3, 3] = action_gripper_trans 161 | 162 | perturbed_trans = torch.full_like(action_trans, -1.0) 163 | perturbed_rot_grip = torch.full_like(action_rot_grip, -1.0) 164 | 165 | # perturb the action, check if it is within bounds, if not, try another perturbation 166 | perturb_attempts = 0 167 | while torch.any(perturbed_trans < 0): 168 | # might take some repeated attempts to find a perturbation that doesn't go out of bounds 169 | perturb_attempts += 1 170 | if perturb_attempts > 100: 171 | raise Exception("Failing to perturb action and keep it within bounds.") 172 | 173 | # sample translation perturbation with specified range 174 | trans_range = (bounds[:, 3:] - bounds[:, :3]) * trans_aug_range.to( 175 | device=device 176 | ) 177 | trans_shift = trans_range * aug_utils.rand_dist((bs, 3)).to(device=device) 178 | trans_shift_4x4 = identity_4x4.detach().clone() 179 | trans_shift_4x4[:, 0:3, 3] = trans_shift 180 | 181 | # sample rotation perturbation at specified resolution and range 182 | roll_aug_steps = int(rot_aug_range[0] // rot_aug_resolution) 183 | pitch_aug_steps = int(rot_aug_range[1] // rot_aug_resolution) 184 | yaw_aug_steps = int(rot_aug_range[2] // rot_aug_resolution) 185 | 186 | roll = aug_utils.rand_discrete( 187 | (bs, 1), min=-roll_aug_steps, max=roll_aug_steps 188 | ) * np.deg2rad(rot_aug_resolution) 189 | pitch = aug_utils.rand_discrete( 190 | (bs, 1), min=-pitch_aug_steps, max=pitch_aug_steps 191 | ) * np.deg2rad(rot_aug_resolution) 192 | yaw = aug_utils.rand_discrete( 193 | (bs, 1), min=-yaw_aug_steps, max=yaw_aug_steps 194 | ) * np.deg2rad(rot_aug_resolution) 195 | rot_shift_3x3 = torch3d_tf.euler_angles_to_matrix( 196 | torch.cat((roll, pitch, yaw), dim=1), "XYZ" 197 | ) 198 | rot_shift_4x4 = identity_4x4.detach().clone() 199 | rot_shift_4x4[:, :3, :3] = rot_shift_3x3 200 | 201 | # rotate then translate the 4x4 keyframe action 202 | 
perturbed_action_gripper_4x4 = torch.bmm(action_gripper_4x4, rot_shift_4x4) 203 | perturbed_action_gripper_4x4[:, 0:3, 3] += trans_shift 204 | 205 | # convert transformation matrix to translation + quaternion 206 | perturbed_action_trans = perturbed_action_gripper_4x4[:, 0:3, 3].cpu().numpy() 207 | perturbed_action_quat_wxyz = torch3d_tf.matrix_to_quaternion( 208 | perturbed_action_gripper_4x4[:, :3, :3] 209 | ) 210 | perturbed_action_quat_xyzw = ( 211 | torch.cat( 212 | [ 213 | perturbed_action_quat_wxyz[:, 1:], 214 | perturbed_action_quat_wxyz[:, 0].unsqueeze(1), 215 | ], 216 | dim=1, 217 | ) 218 | .cpu() 219 | .numpy() 220 | ) 221 | 222 | # discretize perturbed translation and rotation 223 | # TODO(mohit): do this in torch without any numpy. 224 | trans_indicies, rot_grip_indicies = [], [] 225 | for b in range(bs): 226 | bounds_idx = b if layer > 0 else 0 227 | bounds_np = bounds[bounds_idx].cpu().numpy() 228 | 229 | trans_idx = aug_utils.point_to_voxel_index( 230 | perturbed_action_trans[b], voxel_size, bounds_np 231 | ) 232 | trans_indicies.append(trans_idx.tolist()) 233 | 234 | quat = perturbed_action_quat_xyzw[b] 235 | quat = aug_utils.normalize_quaternion(perturbed_action_quat_xyzw[b]) 236 | if quat[-1] < 0: 237 | quat = -quat 238 | disc_rot = aug_utils.quaternion_to_discrete_euler(quat, rot_resolution) 239 | rot_grip_indicies.append( 240 | disc_rot.tolist() + [int(action_rot_grip[b, 3].cpu().numpy())] 241 | ) 242 | 243 | # if the perturbed action is out of bounds, 244 | # the discretized perturb_trans should have invalid indicies 245 | perturbed_trans = torch.from_numpy(np.array(trans_indicies)).to(device=device) 246 | perturbed_rot_grip = torch.from_numpy(np.array(rot_grip_indicies)).to( 247 | device=device 248 | ) 249 | 250 | action_trans = perturbed_trans 251 | action_rot_grip = perturbed_rot_grip 252 | 253 | # apply perturbation to pointclouds 254 | pcd = perturb_se3(pcd, trans_shift_4x4, rot_shift_4x4, action_gripper_4x4, bounds) 255 | 256 | return action_trans, action_rot_grip, pcd 257 | 258 | 259 | def apply_se3_aug_con( 260 | pcd, 261 | action_gripper_pose, 262 | bounds, 263 | trans_aug_range, 264 | rot_aug_range, 265 | scale_aug_range=False, 266 | single_scale=True, 267 | ver=2, 268 | ): 269 | """Apply SE3 augmentation to a point clouds and actions. 
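    Unlike apply_se3_augmentation above, this variant samples a continuous
    perturbation and returns the perturbed gripper pose directly (translation
    and xyzw quaternion) rather than discretized action indices.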
270 | :param pcd: [bs, num_points, 3] 271 | :param action_gripper_pose: 6-DoF pose of keyframe action [bs, 7] 272 | :param bounds: metric scene bounds 273 | Either: 274 | - [bs, 6] 275 | - [6] 276 | :param trans_aug_range: range of translation augmentation 277 | [x_range, y_range, z_range]; this is expressed as the percentage of the scene bound 278 | :param rot_aug_range: range of rotation augmentation [x_range, y_range, z_range] 279 | :param scale_aug_range: range of scale augmentation [x_range, y_range, z_range] 280 | :param single_scale: whether we preserve the relative dimensions 281 | :return: perturbed action_gripper_pose, pcd 282 | """ 283 | 284 | # batch size 285 | bs = pcd.shape[0] 286 | device = pcd.device 287 | 288 | if len(bounds.shape) == 1: 289 | bounds = bounds.unsqueeze(0).repeat(bs, 1).to(device) 290 | if len(trans_aug_range.shape) == 1: 291 | trans_aug_range = trans_aug_range.unsqueeze(0).repeat(bs, 1).to(device) 292 | if len(rot_aug_range.shape) == 1: 293 | rot_aug_range = rot_aug_range.unsqueeze(0).repeat(bs, 1) 294 | 295 | # identity matrix 296 | identity_4x4 = torch.eye(4).unsqueeze(0).repeat(bs, 1, 1).to(device=device) 297 | 298 | # 4x4 matrix of keyframe action gripper pose 299 | action_gripper_trans = action_gripper_pose[:, :3] 300 | 301 | if ver == 1: 302 | action_gripper_quat_wxyz = torch.cat( 303 | (action_gripper_pose[:, 6].unsqueeze(1), action_gripper_pose[:, 3:6]), dim=1 304 | ) 305 | action_gripper_rot = torch3d_tf.quaternion_to_matrix(action_gripper_quat_wxyz) 306 | 307 | elif ver == 2: 308 | # applying gimble fix to calculate a new action_gripper_rot 309 | r = Rotation.from_quat(action_gripper_pose[:, 3:7].cpu().numpy()) 310 | euler = r.as_euler("xyz", degrees=True) 311 | euler = aug_utils.sensitive_gimble_fix(euler) 312 | action_gripper_rot = torch.tensor( 313 | Rotation.from_euler("xyz", euler, degrees=True).as_matrix(), 314 | device=action_gripper_pose.device, 315 | ) 316 | else: 317 | assert False 318 | 319 | action_gripper_4x4 = identity_4x4.detach().clone() 320 | action_gripper_4x4[:, :3, :3] = action_gripper_rot 321 | action_gripper_4x4[:, 0:3, 3] = action_gripper_trans 322 | 323 | # sample translation perturbation with specified range 324 | # augmentation range is a percentage of the scene bound 325 | trans_range = (bounds[:, 3:] - bounds[:, :3]) * trans_aug_range.to(device=device) 326 | # rand_dist samples value from -1 to 1 327 | trans_shift = trans_range * aug_utils.rand_dist((bs, 3)).to(device=device) 328 | 329 | # apply bounded translations 330 | bounds_x_min, bounds_x_max = bounds[:, 0], bounds[:, 3] 331 | bounds_y_min, bounds_y_max = bounds[:, 1], bounds[:, 4] 332 | bounds_z_min, bounds_z_max = bounds[:, 2], bounds[:, 5] 333 | 334 | trans_shift[:, 0] = torch.clamp( 335 | trans_shift[:, 0], 336 | min=bounds_x_min - action_gripper_trans[:, 0], 337 | max=bounds_x_max - action_gripper_trans[:, 0], 338 | ) 339 | trans_shift[:, 1] = torch.clamp( 340 | trans_shift[:, 1], 341 | min=bounds_y_min - action_gripper_trans[:, 1], 342 | max=bounds_y_max - action_gripper_trans[:, 1], 343 | ) 344 | trans_shift[:, 2] = torch.clamp( 345 | trans_shift[:, 2], 346 | min=bounds_z_min - action_gripper_trans[:, 2], 347 | max=bounds_z_max - action_gripper_trans[:, 2], 348 | ) 349 | 350 | trans_shift_4x4 = identity_4x4.detach().clone() 351 | trans_shift_4x4[:, 0:3, 3] = trans_shift 352 | 353 | roll = np.deg2rad(rot_aug_range[:, 0:1] * aug_utils.rand_dist((bs, 1))) 354 | pitch = np.deg2rad(rot_aug_range[:, 1:2] * aug_utils.rand_dist((bs, 1))) 355 | yaw = 
np.deg2rad(rot_aug_range[:, 2:3] * aug_utils.rand_dist((bs, 1))) 356 | rot_shift_3x3 = torch3d_tf.euler_angles_to_matrix( 357 | torch.cat((roll, pitch, yaw), dim=1), "XYZ" 358 | ) 359 | rot_shift_4x4 = identity_4x4.detach().clone() 360 | rot_shift_4x4[:, :3, :3] = rot_shift_3x3 361 | 362 | if ver == 1: 363 | # rotate then translate the 4x4 keyframe action 364 | perturbed_action_gripper_4x4 = torch.bmm(action_gripper_4x4, rot_shift_4x4) 365 | elif ver == 2: 366 | perturbed_action_gripper_4x4 = identity_4x4.detach().clone() 367 | perturbed_action_gripper_4x4[:, 0:3, 3] = action_gripper_4x4[:, 0:3, 3] 368 | perturbed_action_gripper_4x4[:, :3, :3] = torch.bmm( 369 | rot_shift_4x4.transpose(1, 2)[:, :3, :3], action_gripper_4x4[:, :3, :3] 370 | ) 371 | else: 372 | assert False 373 | 374 | perturbed_action_gripper_4x4[:, 0:3, 3] += trans_shift 375 | 376 | # convert transformation matrix to translation + quaternion 377 | perturbed_action_trans = perturbed_action_gripper_4x4[:, 0:3, 3].cpu().numpy() 378 | perturbed_action_quat_wxyz = torch3d_tf.matrix_to_quaternion( 379 | perturbed_action_gripper_4x4[:, :3, :3] 380 | ) 381 | perturbed_action_quat_xyzw = ( 382 | torch.cat( 383 | [ 384 | perturbed_action_quat_wxyz[:, 1:], 385 | perturbed_action_quat_wxyz[:, 0].unsqueeze(1), 386 | ], 387 | dim=1, 388 | ) 389 | .cpu() 390 | .numpy() 391 | ) 392 | 393 | # TODO: add scale augmentation 394 | 395 | # apply perturbation to pointclouds 396 | # takes care for not moving the point out of the image 397 | pcd = perturb_se3(pcd, trans_shift_4x4, rot_shift_4x4, action_gripper_4x4, bounds) 398 | 399 | return perturbed_action_trans, perturbed_action_quat_xyzw, pcd 400 | -------------------------------------------------------------------------------- /rvt/mvt/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
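# Usage sketch: the defaults below are typically cloned via get_cfg_defaults()
# and then overridden from a YAML file or an option list before being frozen,
# mirroring how rvt/train.py consumes this config (the YAML path and override
# value are illustrative):
#
#   cfg = get_cfg_defaults()
#   cfg.merge_from_file("rvt/mvt/configs/rvt2.yaml")
#   cfg.merge_from_list(["img_aug_2", "0.1"])
#   cfg.freeze()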
4 | 5 | from yacs.config import CfgNode as CN 6 | 7 | _C = CN() 8 | 9 | _C.depth = 8 10 | _C.img_size = 220 11 | _C.add_proprio = True 12 | _C.proprio_dim = 4 13 | _C.add_lang = True 14 | _C.lang_dim = 512 15 | _C.lang_len = 77 16 | _C.img_feat_dim = 3 17 | _C.feat_dim = (72 * 3) + 2 + 2 18 | _C.im_channels = 64 19 | _C.attn_dim = 512 20 | _C.attn_heads = 8 21 | _C.attn_dim_head = 64 22 | _C.activation = "lrelu" 23 | _C.weight_tie_layers = False 24 | _C.attn_dropout = 0.1 25 | _C.decoder_dropout = 0.0 26 | _C.img_patch_size = 11 27 | _C.final_dim = 64 28 | _C.self_cross_ver = 1 29 | _C.add_corr = True 30 | _C.norm_corr = False 31 | _C.add_pixel_loc = True 32 | _C.add_depth = True 33 | _C.rend_three_views = False 34 | _C.use_point_renderer = False 35 | _C.pe_fix = True 36 | _C.feat_ver = 0 37 | _C.wpt_img_aug = 0.01 38 | _C.inp_pre_pro = True 39 | _C.inp_pre_con = True 40 | _C.cvx_up = False 41 | _C.xops = False 42 | _C.rot_ver = 0 43 | _C.num_rot = 72 44 | _C.stage_two = False 45 | _C.st_sca = 4 46 | _C.st_wpt_loc_aug = 0.05 47 | _C.st_wpt_loc_inp_no_noise = False 48 | _C.img_aug_2 = 0.0 49 | 50 | 51 | def get_cfg_defaults(): 52 | """Get a yacs CfgNode object with default values for my_project.""" 53 | return _C.clone() 54 | -------------------------------------------------------------------------------- /rvt/mvt/configs/rvt2.yaml: -------------------------------------------------------------------------------- 1 | stage_two: True 2 | norm_corr: True 3 | rend_three_views: True 4 | use_point_renderer: True 5 | inp_pre_pro: False 6 | inp_pre_con: False 7 | feat_ver: 1 8 | img_size: 224 9 | img_patch_size: 14 10 | cvx_up: True 11 | xops: True 12 | rot_ver: 1 13 | st_wpt_loc_inp_no_noise: True 14 | wpt_img_aug: 0.0 15 | img_aug_2: 0.05 16 | -------------------------------------------------------------------------------- /rvt/mvt/configs/rvt2_partial.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/RVT/367995a1a2169b6352bf4e8b0ed405890462a3a0/rvt/mvt/configs/rvt2_partial.yaml -------------------------------------------------------------------------------- /rvt/mvt/mvt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
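# Construction sketch (mirrors the __main__ block at the bottom of this file
# and rvt/train.py; the device string is illustrative):
#
#   from rvt.mvt.config import get_cfg_defaults
#   cfg = get_cfg_defaults()
#   model = MVT(renderer_device="cuda:0", **cfg)
#
# forward() then consumes a per-sample list of point clouds `pc` and point
# features `img_feat`, renders them into virtual views, and runs the
# stage-one transformer, plus a second-stage transformer on a point cloud
# re-centred and scaled around the predicted waypoint when stage_two is True.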
4 | 5 | import copy 6 | import torch 7 | 8 | from torch import nn 9 | from torch.cuda.amp import autocast 10 | 11 | import rvt.mvt.utils as mvt_utils 12 | 13 | from rvt.mvt.mvt_single import MVT as MVTSingle 14 | from rvt.mvt.config import get_cfg_defaults 15 | from rvt.mvt.renderer import BoxRenderer 16 | 17 | 18 | class MVT(nn.Module): 19 | def __init__( 20 | self, 21 | depth, 22 | img_size, 23 | add_proprio, 24 | proprio_dim, 25 | add_lang, 26 | lang_dim, 27 | lang_len, 28 | img_feat_dim, 29 | feat_dim, 30 | im_channels, 31 | attn_dim, 32 | attn_heads, 33 | attn_dim_head, 34 | activation, 35 | weight_tie_layers, 36 | attn_dropout, 37 | decoder_dropout, 38 | img_patch_size, 39 | final_dim, 40 | self_cross_ver, 41 | add_corr, 42 | norm_corr, 43 | add_pixel_loc, 44 | add_depth, 45 | rend_three_views, 46 | use_point_renderer, 47 | pe_fix, 48 | feat_ver, 49 | wpt_img_aug, 50 | inp_pre_pro, 51 | inp_pre_con, 52 | cvx_up, 53 | xops, 54 | rot_ver, 55 | num_rot, 56 | stage_two, 57 | st_sca, 58 | st_wpt_loc_aug, 59 | st_wpt_loc_inp_no_noise, 60 | img_aug_2, 61 | renderer_device="cuda:0", 62 | ): 63 | """MultiView Transfomer 64 | :param stage_two: whether or not there are two stages 65 | :param st_sca: scaling of the pc in the second stage 66 | :param st_wpt_loc_aug: how much noise is to be added to wpt_local when 67 | transforming the pc in the second stage while training. This is 68 | expressed as a percentage of total pc size which is 2. 69 | :param st_wpt_loc_inp_no_noise: whether or not to add any noise to the 70 | wpt_local location which is fed to stage_two. This wpt_local 71 | location is used to extract features for rotation prediction 72 | currently. Other use cases might also arise later on. Even if 73 | st_wpt_loc_aug is True, this will compensate for that if set to 74 | True. 
75 | :param img_aug_2: similar to img_aug in rvt repo but applied only to 76 | point feat and not the whole point cloud 77 | """ 78 | super().__init__() 79 | 80 | self.use_point_renderer = use_point_renderer 81 | if self.use_point_renderer: 82 | from point_renderer.rvt_renderer import RVTBoxRenderer as BoxRenderer 83 | else: 84 | from mvt.renderer import BoxRenderer 85 | global BoxRenderer 86 | 87 | # creating a dictonary of all the input parameters 88 | args = copy.deepcopy(locals()) 89 | del args["self"] 90 | del args["__class__"] 91 | del args["stage_two"] 92 | del args["st_sca"] 93 | del args["st_wpt_loc_aug"] 94 | del args["st_wpt_loc_inp_no_noise"] 95 | del args["img_aug_2"] 96 | 97 | self.rot_ver = rot_ver 98 | self.num_rot = num_rot 99 | self.stage_two = stage_two 100 | self.st_sca = st_sca 101 | self.st_wpt_loc_aug = st_wpt_loc_aug 102 | self.st_wpt_loc_inp_no_noise = st_wpt_loc_inp_no_noise 103 | self.img_aug_2 = img_aug_2 104 | 105 | # for verifying the input 106 | self.feat_ver = feat_ver 107 | self.img_feat_dim = img_feat_dim 108 | self.add_proprio = add_proprio 109 | self.proprio_dim = proprio_dim 110 | self.add_lang = add_lang 111 | if add_lang: 112 | lang_emb_dim, lang_max_seq_len = lang_dim, lang_len 113 | else: 114 | lang_emb_dim, lang_max_seq_len = 0, 0 115 | self.lang_emb_dim = lang_emb_dim 116 | self.lang_max_seq_len = lang_max_seq_len 117 | 118 | self.renderer = BoxRenderer( 119 | device=renderer_device, 120 | img_size=(img_size, img_size), 121 | three_views=rend_three_views, 122 | with_depth=add_depth, 123 | ) 124 | self.num_img = self.renderer.num_img 125 | self.proprio_dim = proprio_dim 126 | self.img_size = img_size 127 | 128 | self.mvt1 = MVTSingle( 129 | **args, 130 | renderer=self.renderer, 131 | no_feat=self.stage_two, 132 | ) 133 | if self.stage_two: 134 | self.mvt2 = MVTSingle(**args, renderer=self.renderer) 135 | 136 | def get_pt_loc_on_img(self, pt, mvt1_or_mvt2, dyn_cam_info, out=None): 137 | """ 138 | :param pt: point for which location on image is to be found. 
the point 139 | shoud be in the same reference frame as wpt_local (see forward()), 140 | even for mvt2 141 | :param out: output from mvt, when using mvt2, we also need to provide the 142 | origin location where where the point cloud needs to be shifted 143 | before estimating the location in the image 144 | """ 145 | assert len(pt.shape) == 3 146 | bs, _np, x = pt.shape 147 | assert x == 3 148 | 149 | assert isinstance(mvt1_or_mvt2, bool) 150 | if mvt1_or_mvt2: 151 | assert out is None 152 | out = self.mvt1.get_pt_loc_on_img(pt, dyn_cam_info) 153 | else: 154 | assert self.stage_two 155 | assert out is not None 156 | assert out['wpt_local1'].shape == (bs, 3) 157 | pt, _ = mvt_utils.trans_pc(pt, loc=out["wpt_local1"], sca=self.st_sca) 158 | pt = pt.view(bs, _np, 3) 159 | out = self.mvt2.get_pt_loc_on_img(pt, dyn_cam_info) 160 | 161 | return out 162 | 163 | def get_wpt(self, out, mvt1_or_mvt2, dyn_cam_info, y_q=None): 164 | """ 165 | Estimate the q-values given output from mvt 166 | :param out: output from mvt 167 | :param y_q: refer to the definition in mvt_single.get_wpt 168 | """ 169 | assert isinstance(mvt1_or_mvt2, bool) 170 | if mvt1_or_mvt2: 171 | wpt = self.mvt1.get_wpt( 172 | out, dyn_cam_info, y_q, 173 | ) 174 | else: 175 | assert self.stage_two 176 | wpt = self.mvt2.get_wpt( 177 | out["mvt2"], dyn_cam_info, y_q 178 | ) 179 | wpt = out["rev_trans"](wpt) 180 | 181 | return wpt 182 | 183 | def render(self, pc, img_feat, img_aug, mvt1_or_mvt2, dyn_cam_info): 184 | assert isinstance(mvt1_or_mvt2, bool) 185 | if mvt1_or_mvt2: 186 | mvt = self.mvt1 187 | else: 188 | mvt = self.mvt2 189 | 190 | with torch.no_grad(): 191 | with autocast(enabled=False): 192 | if dyn_cam_info is None: 193 | dyn_cam_info_itr = (None,) * len(pc) 194 | else: 195 | dyn_cam_info_itr = dyn_cam_info 196 | 197 | if mvt.add_corr: 198 | if mvt.norm_corr: 199 | img = [] 200 | for _pc, _img_feat, _dyn_cam_info in zip( 201 | pc, img_feat, dyn_cam_info_itr 202 | ): 203 | # fix when the pc is empty 204 | max_pc = 1.0 if len(_pc) == 0 else torch.max(torch.abs(_pc)) 205 | img.append( 206 | self.renderer( 207 | _pc, 208 | torch.cat((_pc / max_pc, _img_feat), dim=-1), 209 | fix_cam=True, 210 | dyn_cam_info=(_dyn_cam_info,) 211 | if not (_dyn_cam_info is None) 212 | else None, 213 | ).unsqueeze(0) 214 | ) 215 | else: 216 | img = [ 217 | self.renderer( 218 | _pc, 219 | torch.cat((_pc, _img_feat), dim=-1), 220 | fix_cam=True, 221 | dyn_cam_info=(_dyn_cam_info,) 222 | if not (_dyn_cam_info is None) 223 | else None, 224 | ).unsqueeze(0) 225 | for (_pc, _img_feat, _dyn_cam_info) in zip( 226 | pc, img_feat, dyn_cam_info_itr 227 | ) 228 | ] 229 | else: 230 | img = [ 231 | self.renderer( 232 | _pc, 233 | _img_feat, 234 | fix_cam=True, 235 | dyn_cam_info=(_dyn_cam_info,) 236 | if not (_dyn_cam_info is None) 237 | else None, 238 | ).unsqueeze(0) 239 | for (_pc, _img_feat, _dyn_cam_info) in zip( 240 | pc, img_feat, dyn_cam_info_itr 241 | ) 242 | ] 243 | 244 | img = torch.cat(img, 0) 245 | img = img.permute(0, 1, 4, 2, 3) 246 | 247 | # for visualization purposes 248 | if mvt.add_corr: 249 | mvt.img = img[:, :, 3:].clone().detach() 250 | else: 251 | mvt.img = img.clone().detach() 252 | 253 | # image augmentation 254 | if img_aug != 0: 255 | stdv = img_aug * torch.rand(1, device=img.device) 256 | # values in [-stdv, stdv] 257 | noise = stdv * ((2 * torch.rand(*img.shape, device=img.device)) - 1) 258 | img = torch.clamp(img + noise, -1, 1) 259 | 260 | if mvt.add_pixel_loc: 261 | bs = img.shape[0] 262 | pixel_loc = mvt.pixel_loc.to(img.device) 
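# append the fixed per-pixel location channels to every rendered view:
# pixel_loc holds one (C', H, W) map per virtual camera and is broadcast
# across the batch before being concatenated along the channel dimension
# (dim=2 of the (bs, num_img, C, H, W) image tensor)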
263 | img = torch.cat( 264 | (img, pixel_loc.unsqueeze(0).repeat(bs, 1, 1, 1, 1)), dim=2 265 | ) 266 | 267 | return img 268 | 269 | def verify_inp( 270 | self, 271 | pc, 272 | img_feat, 273 | proprio, 274 | lang_emb, 275 | img_aug, 276 | wpt_local, 277 | rot_x_y, 278 | ): 279 | bs = len(pc) 280 | assert bs == len(img_feat) 281 | 282 | if not self.training: 283 | # no img_aug when not training 284 | assert img_aug == 0 285 | assert rot_x_y is None, f"rot_x_y={rot_x_y}" 286 | 287 | if self.training: 288 | assert ( 289 | (not self.feat_ver == 1) 290 | or (not wpt_local is None) 291 | ) 292 | 293 | if self.rot_ver == 0: 294 | assert rot_x_y is None, f"rot_x_y={rot_x_y}" 295 | elif self.rot_ver == 1: 296 | assert rot_x_y.shape == (bs, 2), f"rot_x_y.shape={rot_x_y.shape}" 297 | assert (rot_x_y >= 0).all() and ( 298 | rot_x_y < self.num_rot 299 | ).all(), f"rot_x_y={rot_x_y}" 300 | else: 301 | assert False 302 | 303 | for _pc, _img_feat in zip(pc, img_feat): 304 | np, x1 = _pc.shape 305 | np2, x2 = _img_feat.shape 306 | 307 | assert np == np2 308 | assert x1 == 3 309 | assert x2 == self.img_feat_dim 310 | 311 | if self.add_proprio: 312 | bs3, x3 = proprio.shape 313 | assert bs == bs3 314 | assert ( 315 | x3 == self.proprio_dim 316 | ), "Does not support proprio of shape {proprio.shape}" 317 | else: 318 | assert proprio is None, "Invalid input for proprio={proprio}" 319 | 320 | if self.add_lang: 321 | bs4, x4, x5 = lang_emb.shape 322 | assert bs == bs4 323 | assert ( 324 | x4 == self.lang_max_seq_len 325 | ), "Does not support lang_emb of shape {lang_emb.shape}" 326 | assert ( 327 | x5 == self.lang_emb_dim 328 | ), "Does not support lang_emb of shape {lang_emb.shape}" 329 | else: 330 | assert (lang_emb is None) or ( 331 | torch.all(lang_emb == 0) 332 | ), f"Invalid input for lang={lang}" 333 | 334 | if not (wpt_local is None): 335 | bs5, x6 = wpt_local.shape 336 | assert bs == bs5 337 | assert x6 == 3, "Does not support wpt_local of shape {wpt_local.shape}" 338 | 339 | if self.training: 340 | assert (not self.stage_two) or (not wpt_local is None) 341 | 342 | def forward( 343 | self, 344 | pc, 345 | img_feat, 346 | proprio=None, 347 | lang_emb=None, 348 | img_aug=0, 349 | wpt_local=None, 350 | rot_x_y=None, 351 | **kwargs, 352 | ): 353 | """ 354 | :param pc: list of tensors, each tensor of shape (num_points, 3) 355 | :param img_feat: list tensors, each tensor of shape 356 | (bs, num_points, img_feat_dim) 357 | :param proprio: tensor of shape (bs, priprio_dim) 358 | :param lang_emb: tensor of shape (bs, lang_len, lang_dim) 359 | :param img_aug: (float) magnitude of augmentation in rgb image 360 | :param wpt_local: gt location of the wpt in 3D, tensor of shape 361 | (bs, 3) 362 | :param rot_x_y: (bs, 2) rotation in x and y direction 363 | """ 364 | self.verify_inp( 365 | pc=pc, 366 | img_feat=img_feat, 367 | proprio=proprio, 368 | lang_emb=lang_emb, 369 | img_aug=img_aug, 370 | wpt_local=wpt_local, 371 | rot_x_y=rot_x_y, 372 | ) 373 | with torch.no_grad(): 374 | if self.training and (self.img_aug_2 != 0): 375 | for x in img_feat: 376 | stdv = self.img_aug_2 * torch.rand(1, device=x.device) 377 | # values in [-stdv, stdv] 378 | noise = stdv * ((2 * torch.rand(*x.shape, device=x.device)) - 1) 379 | x = x + noise 380 | img = self.render( 381 | pc=pc, 382 | img_feat=img_feat, 383 | img_aug=img_aug, 384 | mvt1_or_mvt2=True, 385 | dyn_cam_info=None, 386 | ) 387 | 388 | if self.training: 389 | wpt_local_stage_one = wpt_local 390 | wpt_local_stage_one = wpt_local_stage_one.clone().detach() 391 | else: 392 | 
wpt_local_stage_one = wpt_local 393 | 394 | out = self.mvt1( 395 | img=img, 396 | proprio=proprio, 397 | lang_emb=lang_emb, 398 | wpt_local=wpt_local_stage_one, 399 | rot_x_y=rot_x_y, 400 | **kwargs, 401 | ) 402 | 403 | if self.stage_two: 404 | with torch.no_grad(): 405 | # adding then noisy location for training 406 | if self.training: 407 | # noise is added so that the wpt_local2 is not exactly at 408 | # the center of the pc 409 | wpt_local_stage_one_noisy = mvt_utils.add_uni_noi( 410 | wpt_local_stage_one.clone().detach(), 2 * self.st_wpt_loc_aug 411 | ) 412 | pc, rev_trans = mvt_utils.trans_pc( 413 | pc, loc=wpt_local_stage_one_noisy, sca=self.st_sca 414 | ) 415 | 416 | if self.st_wpt_loc_inp_no_noise: 417 | wpt_local2, _ = mvt_utils.trans_pc( 418 | wpt_local, loc=wpt_local_stage_one_noisy, sca=self.st_sca 419 | ) 420 | else: 421 | wpt_local2, _ = mvt_utils.trans_pc( 422 | wpt_local, loc=wpt_local_stage_one, sca=self.st_sca 423 | ) 424 | 425 | else: 426 | # bs, 3 427 | wpt_local = self.get_wpt( 428 | out, y_q=None, mvt1_or_mvt2=True, 429 | dyn_cam_info=None, 430 | ) 431 | pc, rev_trans = mvt_utils.trans_pc( 432 | pc, loc=wpt_local, sca=self.st_sca 433 | ) 434 | # bad name! 435 | wpt_local_stage_one_noisy = wpt_local 436 | 437 | # must pass None to mvt2 while in eval 438 | wpt_local2 = None 439 | 440 | img = self.render( 441 | pc=pc, 442 | img_feat=img_feat, 443 | img_aug=img_aug, 444 | mvt1_or_mvt2=False, 445 | dyn_cam_info=None, 446 | ) 447 | 448 | out_mvt2 = self.mvt2( 449 | img=img, 450 | proprio=proprio, 451 | lang_emb=lang_emb, 452 | wpt_local=wpt_local2, 453 | rot_x_y=rot_x_y, 454 | **kwargs, 455 | ) 456 | 457 | out["wpt_local1"] = wpt_local_stage_one_noisy 458 | out["rev_trans"] = rev_trans 459 | out["mvt2"] = out_mvt2 460 | 461 | return out 462 | 463 | def free_mem(self): 464 | """ 465 | Could be used for freeing up the memory once a batch of testing is done 466 | """ 467 | if not self.use_point_renderer: 468 | print("Freeing up some memory") 469 | self.renderer.free_mem() 470 | 471 | 472 | if __name__ == "__main__": 473 | cfg = get_cfg_defaults() 474 | mvt = MVT(**cfg) 475 | breakpoint() 476 | -------------------------------------------------------------------------------- /rvt/mvt/raft_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
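# ConvexUpSample below implements RAFT-style learned convex upsampling: the
# low-resolution output of net_out is upsampled by `up_ratio`, with every
# high-resolution pixel formed as a convex combination (softmax-normalised
# weights from net_mask) of an `up_kernel` x `up_kernel` neighbourhood of the
# low-resolution map.
#
# Shape sketch (the sizes are illustrative only):
#
#   up = ConvexUpSample(in_dim=64, out_dim=1, up_ratio=14)
#   y = up(torch.rand(2, 64, 16, 16))  # -> (2, 1, 224, 224)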
4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from mvt.utils import ForkedPdb 9 | 10 | 11 | class ConvexUpSample(nn.Module): 12 | """ 13 | Learned convex upsampling 14 | """ 15 | 16 | def __init__( 17 | self, in_dim, out_dim, up_ratio, up_kernel=3, mask_scale=0.1, with_bn=False 18 | ): 19 | """ 20 | 21 | :param in_dim: 22 | :param out_dim: 23 | :param up_ratio: 24 | :param up_kernel: 25 | :param mask_scale: 26 | """ 27 | super().__init__() 28 | self.in_dim = in_dim 29 | self.out_dim = out_dim 30 | self.up_ratio = up_ratio 31 | self.up_kernel = up_kernel 32 | self.mask_scale = mask_scale 33 | self.with_bn = with_bn 34 | 35 | assert (self.up_kernel % 2) == 1 36 | 37 | if with_bn: 38 | self.net_out_bn1 = nn.BatchNorm2d(2 * in_dim) 39 | self.net_out_bn2 = nn.BatchNorm2d(2 * in_dim) 40 | 41 | self.net_out = nn.Sequential( 42 | nn.Conv2d(in_dim, 2 * in_dim, 3, padding=1), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(2 * in_dim, 2 * in_dim, 3, padding=1), 45 | nn.ReLU(inplace=True), 46 | nn.Conv2d(2 * in_dim, out_dim, 3, padding=1), 47 | ) 48 | 49 | mask_dim = (self.up_ratio**2) * (self.up_kernel**2) 50 | self.net_mask = nn.Sequential( 51 | nn.Conv2d(in_dim, 2 * in_dim, 3, padding=1), 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(2 * in_dim, mask_dim, 1, padding=0), 54 | ) 55 | 56 | def forward(self, x): 57 | """ 58 | 59 | :param x: (bs, in_dim, h, w) 60 | :return: (bs, out_dim, h*up_ratio, w*up_ratio) 61 | """ 62 | 63 | bs, c, h, w = x.shape 64 | assert c == self.in_dim, c 65 | 66 | # low resolution output 67 | if self.with_bn: 68 | out_low = self.net_out[0](x) 69 | out_low = self.net_out_bn1(out_low) 70 | out_low = self.net_out[1](out_low) 71 | out_low = self.net_out[2](out_low) 72 | out_low = self.net_out_bn2(out_low) 73 | out_low = self.net_out[3](out_low) 74 | out_low = self.net_out[4](out_low) 75 | else: 76 | out_low = self.net_out(x) 77 | 78 | mask = self.mask_scale * self.net_mask(x) 79 | mask = mask.view(bs, 1, self.up_kernel**2, self.up_ratio, self.up_ratio, h, w) 80 | mask = torch.softmax(mask, dim=2) 81 | 82 | out = F.unfold( 83 | out_low, 84 | kernel_size=[self.up_kernel, self.up_kernel], 85 | padding=self.up_kernel // 2, 86 | ) 87 | out = out.view(bs, self.out_dim, self.up_kernel**2, 1, 1, h, w) 88 | 89 | out = torch.sum(out * mask, dim=2) 90 | out = out.permute(0, 1, 4, 2, 5, 3) 91 | out = out.reshape(bs, self.out_dim, h * self.up_ratio, w * self.up_ratio) 92 | 93 | return out 94 | 95 | 96 | if __name__ == "__main__": 97 | net = ConvexUpSample(2, 5, 20).cuda() 98 | x = torch.rand(4, 5, 10, 10).cuda() 99 | y = net(x) 100 | breakpoint() 101 | -------------------------------------------------------------------------------- /rvt/mvt/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | """ 6 | Utility function for MVT 7 | """ 8 | import pdb 9 | import sys 10 | 11 | import torch 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | def place_pc_in_cube( 17 | pc, app_pc=None, with_mean_or_bounds=True, scene_bounds=None, no_op=False 18 | ): 19 | """ 20 | calculate the transformation that would place the point cloud (pc) inside a 21 | cube of size (2, 2, 2). The pc is centered at mean if with_mean_or_bounds 22 | is True. If with_mean_or_bounds is False, pc is centered around the mid 23 | point of the bounds. 
The transformation is applied to point cloud app_pc if 24 | it is not None. If app_pc is None, the transformation is applied on pc. 25 | :param pc: pc of shape (num_points_1, 3) 26 | :param app_pc: 27 | Either 28 | - pc of shape (num_points_2, 3) 29 | - None 30 | :param with_mean_or_bounds: 31 | Either: 32 | True: pc is centered around its mean 33 | False: pc is centered around the center of the scene bounds 34 | :param scene_bounds: [x_min, y_min, z_min, x_max, y_max, z_max] 35 | :param no_op: if no_op, then this function does not do any operation 36 | """ 37 | if no_op: 38 | if app_pc is None: 39 | app_pc = torch.clone(pc) 40 | 41 | return app_pc, lambda x: x 42 | 43 | if with_mean_or_bounds: 44 | assert scene_bounds is None 45 | else: 46 | assert not (scene_bounds is None) 47 | if with_mean_or_bounds: 48 | pc_mid = (torch.max(pc, 0)[0] + torch.min(pc, 0)[0]) / 2 49 | x_len, y_len, z_len = torch.max(pc, 0)[0] - torch.min(pc, 0)[0] 50 | else: 51 | x_min, y_min, z_min, x_max, y_max, z_max = scene_bounds 52 | pc_mid = torch.tensor( 53 | [ 54 | (x_min + x_max) / 2, 55 | (y_min + y_max) / 2, 56 | (z_min + z_max) / 2, 57 | ] 58 | ).to(pc.device) 59 | x_len, y_len, z_len = x_max - x_min, y_max - y_min, z_max - z_min 60 | 61 | scale = 2 / max(x_len, y_len, z_len) 62 | if app_pc is None: 63 | app_pc = torch.clone(pc) 64 | app_pc = (app_pc - pc_mid) * scale 65 | 66 | # reverse transformation to obtain app_pc in original frame 67 | def rev_trans(x): 68 | return (x / scale) + pc_mid 69 | 70 | return app_pc, rev_trans 71 | 72 | 73 | def trans_pc(pc, loc, sca): 74 | """ 75 | change location of the center of the pc and scale it 76 | :param pc: 77 | either: 78 | - tensor of shape(b, num_points, 3) 79 | - tensor of shape(b, 3) 80 | - list of pc each with size (num_points, 3) 81 | :param loc: (b, 3 ) 82 | :param sca: 1 or (3) 83 | """ 84 | assert len(loc.shape) == 2 85 | assert loc.shape[-1] == 3 86 | if isinstance(pc, list): 87 | assert all([(len(x.shape) == 2) and (x.shape[1] == 3) for x in pc]) 88 | pc = [sca * (x - y) for x, y in zip(pc, loc)] 89 | elif isinstance(pc, torch.Tensor): 90 | assert len(pc.shape) in [2, 3] 91 | assert pc.shape[-1] == 3 92 | if len(pc.shape) == 2: 93 | pc = sca * (pc - loc) 94 | else: 95 | pc = sca * (pc - loc.unsqueeze(1)) 96 | else: 97 | assert False 98 | 99 | # reverse transformation to obtain app_pc in original frame 100 | def rev_trans(x): 101 | assert isinstance(x, torch.Tensor) 102 | return (x / sca) + loc 103 | 104 | return pc, rev_trans 105 | 106 | 107 | def add_uni_noi(x, u): 108 | """ 109 | adds uniform noise to a tensor x. output is tensor where each element is 110 | in [x-u, x+u] 111 | :param x: tensor 112 | :param u: float 113 | """ 114 | assert isinstance(u, float) 115 | # move noise in -1 to 1 116 | noise = (2 * torch.rand(*x.shape, device=x.device)) - 1 117 | x = x + (u * noise) 118 | return x 119 | 120 | 121 | def generate_hm_from_pt(pt, res, sigma, thres_sigma_times=3): 122 | """ 123 | Pytorch code to generate heatmaps from point. Points with values less than 124 | thres are made 0 125 | :type pt: torch.FloatTensor of size (num_pt, 2) 126 | :type res: int or (int, int) 127 | :param sigma: the std of the gaussian distribition. 
if it is -1, we 128 | generate a hm with one hot vector 129 | :type sigma: float 130 | :type thres: float 131 | """ 132 | num_pt, x = pt.shape 133 | assert x == 2 134 | 135 | if isinstance(res, int): 136 | resx = resy = res 137 | else: 138 | resx, resy = res 139 | 140 | _hmx = torch.arange(0, resy).to(pt.device) 141 | _hmx = _hmx.view([1, resy]).repeat(resx, 1).view([resx, resy, 1]) 142 | _hmy = torch.arange(0, resx).to(pt.device) 143 | _hmy = _hmy.view([resx, 1]).repeat(1, resy).view([resx, resy, 1]) 144 | hm = torch.cat([_hmx, _hmy], dim=-1) 145 | hm = hm.view([1, resx, resy, 2]).repeat(num_pt, 1, 1, 1) 146 | 147 | pt = pt.view([num_pt, 1, 1, 2]) 148 | hm = torch.exp(-1 * torch.sum((hm - pt) ** 2, -1) / (2 * (sigma**2))) 149 | thres = np.exp(-1 * (thres_sigma_times**2) / 2) 150 | hm[hm < thres] = 0.0 151 | 152 | hm /= torch.sum(hm, (1, 2), keepdim=True) + 1e-6 153 | 154 | # TODO: make a more efficient version 155 | if sigma == -1: 156 | _hm = hm.view(num_pt, resx * resy) 157 | hm = torch.zeros((num_pt, resx * resy), device=hm.device) 158 | temp = torch.arange(num_pt).to(hm.device) 159 | hm[temp, _hm.argmax(-1)] = 1 160 | 161 | return hm 162 | 163 | 164 | class ForkedPdb(pdb.Pdb): 165 | """A Pdb subclass that may be used 166 | from a forked multiprocessing child 167 | """ 168 | 169 | def interaction(self, *args, **kwargs): 170 | _stdin = sys.stdin 171 | try: 172 | sys.stdin = open("/dev/stdin") 173 | pdb.Pdb.interaction(self, *args, **kwargs) 174 | finally: 175 | sys.stdin = _stdin 176 | -------------------------------------------------------------------------------- /rvt/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
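# Launch sketch (the flags correspond to the argparse at the bottom of this
# file; the config paths are illustrative and assume the rvt/ directory as
# the working directory):
#
#   python train.py --exp_cfg_path configs/rvt2.yaml \
#                   --mvt_cfg_path mvt/configs/rvt2.yaml \
#                   --device 0,1
#
# --device takes a comma-separated GPU list; experiment() is spawned once per
# GPU via mp.spawn(), and the learning rate is scaled by
# (number of devices * batch size).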
4 | 5 | import os 6 | import time 7 | import tqdm 8 | import random 9 | import yaml 10 | import argparse 11 | 12 | from collections import defaultdict 13 | from contextlib import redirect_stdout 14 | 15 | import torch 16 | import torch.multiprocessing as mp 17 | import torch.distributed as dist 18 | from torch.nn.parallel import DistributedDataParallel as DDP 19 | 20 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 21 | os.environ["BITSANDBYTES_NOWELCOME"] = "1" 22 | 23 | import config as exp_cfg_mod 24 | import rvt.models.rvt_agent as rvt_agent 25 | import rvt.utils.ddp_utils as ddp_utils 26 | import rvt.mvt.config as mvt_cfg_mod 27 | 28 | from rvt.mvt.mvt import MVT 29 | from rvt.models.rvt_agent import print_eval_log, print_loss_log 30 | from rvt.utils.get_dataset import get_dataset 31 | from rvt.utils.rvt_utils import ( 32 | TensorboardManager, 33 | short_name, 34 | get_num_feat, 35 | load_agent, 36 | RLBENCH_TASKS, 37 | ) 38 | from rvt.utils.peract_utils import ( 39 | CAMERAS, 40 | SCENE_BOUNDS, 41 | IMAGE_SIZE, 42 | DATA_FOLDER, 43 | ) 44 | 45 | 46 | # new train takes the dataset as input 47 | def train(agent, dataset, training_iterations, rank=0): 48 | agent.train() 49 | log = defaultdict(list) 50 | 51 | data_iter = iter(dataset) 52 | iter_command = range(training_iterations) 53 | 54 | for iteration in tqdm.tqdm( 55 | iter_command, disable=(rank != 0), position=0, leave=True 56 | ): 57 | 58 | raw_batch = next(data_iter) 59 | batch = { 60 | k: v.to(agent._device) 61 | for k, v in raw_batch.items() 62 | if type(v) == torch.Tensor 63 | } 64 | batch["tasks"] = raw_batch["tasks"] 65 | batch["lang_goal"] = raw_batch["lang_goal"] 66 | update_args = { 67 | "step": iteration, 68 | } 69 | update_args.update( 70 | { 71 | "replay_sample": batch, 72 | "backprop": True, 73 | "reset_log": (iteration == 0), 74 | "eval_log": False, 75 | } 76 | ) 77 | agent.update(**update_args) 78 | 79 | if rank == 0: 80 | log = print_loss_log(agent) 81 | 82 | return log 83 | 84 | 85 | def save_agent(agent, path, epoch): 86 | model = agent._network 87 | optimizer = agent._optimizer 88 | lr_sched = agent._lr_sched 89 | 90 | if isinstance(model, DDP): 91 | model_state = model.module.state_dict() 92 | else: 93 | model_state = model.state_dict() 94 | 95 | torch.save( 96 | { 97 | "epoch": epoch, 98 | "model_state": model_state, 99 | "optimizer_state": optimizer.state_dict(), 100 | "lr_sched_state": lr_sched.state_dict(), 101 | }, 102 | path, 103 | ) 104 | 105 | 106 | def get_tasks(exp_cfg): 107 | parsed_tasks = exp_cfg.tasks.split(",") 108 | if parsed_tasks[0] == "all": 109 | tasks = RLBENCH_TASKS 110 | else: 111 | tasks = parsed_tasks 112 | return tasks 113 | 114 | 115 | def get_logdir(cmd_args, exp_cfg): 116 | log_dir = os.path.join(cmd_args.log_dir, exp_cfg.exp_id) 117 | os.makedirs(log_dir, exist_ok=True) 118 | return log_dir 119 | 120 | 121 | def dump_log(exp_cfg, mvt_cfg, cmd_args, log_dir): 122 | with open(f"{log_dir}/exp_cfg.yaml", "w") as yaml_file: 123 | with redirect_stdout(yaml_file): 124 | print(exp_cfg.dump()) 125 | 126 | with open(f"{log_dir}/mvt_cfg.yaml", "w") as yaml_file: 127 | with redirect_stdout(yaml_file): 128 | print(mvt_cfg.dump()) 129 | 130 | args = cmd_args.__dict__ 131 | with open(f"{log_dir}/args.yaml", "w") as yaml_file: 132 | yaml.dump(args, yaml_file) 133 | 134 | 135 | def experiment(rank, cmd_args, devices, port): 136 | """experiment. 137 | 138 | :param rank: 139 | :param cmd_args: 140 | :param devices: list or int. 
if list, we use ddp else not 141 | """ 142 | device = devices[rank] 143 | device = f"cuda:{device}" 144 | ddp = len(devices) > 1 145 | ddp_utils.setup(rank, world_size=len(devices), port=port) 146 | 147 | exp_cfg = exp_cfg_mod.get_cfg_defaults() 148 | if cmd_args.exp_cfg_path != "": 149 | exp_cfg.merge_from_file(cmd_args.exp_cfg_path) 150 | if cmd_args.exp_cfg_opts != "": 151 | exp_cfg.merge_from_list(cmd_args.exp_cfg_opts.split(" ")) 152 | 153 | if ddp: 154 | print(f"Running DDP on rank {rank}.") 155 | 156 | old_exp_cfg_peract_lr = exp_cfg.peract.lr 157 | old_exp_cfg_exp_id = exp_cfg.exp_id 158 | 159 | exp_cfg.peract.lr *= len(devices) * exp_cfg.bs 160 | if cmd_args.exp_cfg_opts != "": 161 | exp_cfg.exp_id += f"_{short_name(cmd_args.exp_cfg_opts)}" 162 | if cmd_args.mvt_cfg_opts != "": 163 | exp_cfg.exp_id += f"_{short_name(cmd_args.mvt_cfg_opts)}" 164 | 165 | if rank == 0: 166 | print(f"dict(exp_cfg)={dict(exp_cfg)}") 167 | exp_cfg.freeze() 168 | 169 | # Things to change 170 | BATCH_SIZE_TRAIN = exp_cfg.bs 171 | NUM_TRAIN = 100 172 | # to match peract, iterations per epoch 173 | TRAINING_ITERATIONS = int(exp_cfg.train_iter // (exp_cfg.bs * len(devices))) 174 | EPOCHS = exp_cfg.epochs 175 | TRAIN_REPLAY_STORAGE_DIR = "replay/replay_train" 176 | TEST_REPLAY_STORAGE_DIR = "replay/replay_val" 177 | log_dir = get_logdir(cmd_args, exp_cfg) 178 | tasks = get_tasks(exp_cfg) 179 | print("Training on {} tasks: {}".format(len(tasks), tasks)) 180 | 181 | t_start = time.time() 182 | get_dataset_func = lambda: get_dataset( 183 | tasks, 184 | BATCH_SIZE_TRAIN, 185 | None, 186 | TRAIN_REPLAY_STORAGE_DIR, 187 | None, 188 | DATA_FOLDER, 189 | NUM_TRAIN, 190 | None, 191 | cmd_args.refresh_replay, 192 | device, 193 | num_workers=exp_cfg.num_workers, 194 | only_train=True, 195 | sample_distribution_mode=exp_cfg.sample_distribution_mode, 196 | ) 197 | train_dataset, _ = get_dataset_func() 198 | t_end = time.time() 199 | print("Created Dataset. 
Time Cost: {} minutes".format((t_end - t_start) / 60.0)) 200 | 201 | if exp_cfg.agent == "our": 202 | mvt_cfg = mvt_cfg_mod.get_cfg_defaults() 203 | if cmd_args.mvt_cfg_path != "": 204 | mvt_cfg.merge_from_file(cmd_args.mvt_cfg_path) 205 | if cmd_args.mvt_cfg_opts != "": 206 | mvt_cfg.merge_from_list(cmd_args.mvt_cfg_opts.split(" ")) 207 | 208 | mvt_cfg.feat_dim = get_num_feat(exp_cfg.peract) 209 | mvt_cfg.freeze() 210 | 211 | # for maintaining backward compatibility 212 | assert mvt_cfg.num_rot == exp_cfg.peract.num_rotation_classes, print( 213 | mvt_cfg.num_rot, exp_cfg.peract.num_rotation_classes 214 | ) 215 | 216 | torch.cuda.set_device(device) 217 | torch.cuda.empty_cache() 218 | rvt = MVT( 219 | renderer_device=device, 220 | **mvt_cfg, 221 | ).to(device) 222 | if ddp: 223 | rvt = DDP(rvt, device_ids=[device]) 224 | 225 | agent = rvt_agent.RVTAgent( 226 | network=rvt, 227 | image_resolution=[IMAGE_SIZE, IMAGE_SIZE], 228 | add_lang=mvt_cfg.add_lang, 229 | stage_two=mvt_cfg.stage_two, 230 | rot_ver=mvt_cfg.rot_ver, 231 | scene_bounds=SCENE_BOUNDS, 232 | cameras=CAMERAS, 233 | log_dir=f"{log_dir}/test_run/", 234 | cos_dec_max_step=EPOCHS * TRAINING_ITERATIONS, 235 | **exp_cfg.peract, 236 | **exp_cfg.rvt, 237 | ) 238 | agent.build(training=True, device=device) 239 | else: 240 | assert False, "Incorrect agent" 241 | 242 | start_epoch = 0 243 | end_epoch = EPOCHS 244 | if exp_cfg.resume != "": 245 | agent_path = exp_cfg.resume 246 | print(f"Recovering model and checkpoint from {exp_cfg.resume}") 247 | epoch = load_agent(agent_path, agent, only_epoch=False) 248 | start_epoch = epoch + 1 249 | dist.barrier() 250 | 251 | if rank == 0: 252 | ## logging unchanged values to reproduce the same setting 253 | temp1 = exp_cfg.peract.lr 254 | temp2 = exp_cfg.exp_id 255 | exp_cfg.defrost() 256 | exp_cfg.peract.lr = old_exp_cfg_peract_lr 257 | exp_cfg.exp_id = old_exp_cfg_exp_id 258 | dump_log(exp_cfg, mvt_cfg, cmd_args, log_dir) 259 | exp_cfg.peract.lr = temp1 260 | exp_cfg.exp_id = temp2 261 | exp_cfg.freeze() 262 | tb = TensorboardManager(log_dir) 263 | 264 | print("Start training ...", flush=True) 265 | i = start_epoch 266 | while True: 267 | if i == end_epoch: 268 | break 269 | 270 | print(f"Rank [{rank}], Epoch [{i}]: Training on train dataset") 271 | out = train(agent, train_dataset, TRAINING_ITERATIONS, rank) 272 | 273 | if rank == 0: 274 | tb.update("train", i, out) 275 | 276 | if rank == 0: 277 | # TODO: add logic to only save some models 278 | save_agent(agent, f"{log_dir}/model_{i}.pth", i) 279 | save_agent(agent, f"{log_dir}/model_last.pth", i) 280 | i += 1 281 | 282 | if rank == 0: 283 | tb.close() 284 | print("[Finish]") 285 | 286 | 287 | if __name__ == "__main__": 288 | parser = argparse.ArgumentParser() 289 | parser.set_defaults(entry=lambda cmd_args: parser.print_help()) 290 | 291 | parser.add_argument("--refresh_replay", action="store_true", default=False) 292 | parser.add_argument("--device", type=str, default="0") 293 | parser.add_argument("--mvt_cfg_path", type=str, default="") 294 | parser.add_argument("--exp_cfg_path", type=str, default="") 295 | 296 | parser.add_argument("--mvt_cfg_opts", type=str, default="") 297 | parser.add_argument("--exp_cfg_opts", type=str, default="") 298 | 299 | parser.add_argument("--log-dir", type=str, default="runs") 300 | parser.add_argument("--with-eval", action="store_true", default=False) 301 | 302 | cmd_args = parser.parse_args() 303 | del ( 304 | cmd_args.entry 305 | ) # hack for multi processing -- removes an argument called entry which is not 
picklable 306 | 307 | devices = cmd_args.device.split(",") 308 | devices = [int(x) for x in devices] 309 | 310 | port = (random.randint(0, 3000) % 3000) + 27000 311 | mp.spawn(experiment, args=(cmd_args, devices, port), nprocs=len(devices), join=True) 312 | -------------------------------------------------------------------------------- /rvt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | -------------------------------------------------------------------------------- /rvt/utils/custom_rlbench_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | from rvt.libs.peract.helpers.custom_rlbench_env import CustomMultiTaskRLBenchEnv 6 | 7 | 8 | class CustomMultiTaskRLBenchEnv2(CustomMultiTaskRLBenchEnv): 9 | def __init__(self, *args, **kwargs): 10 | super(CustomMultiTaskRLBenchEnv2, self).__init__(*args, **kwargs) 11 | 12 | def reset(self) -> dict: 13 | super().reset() 14 | self._record_current_episode = ( 15 | self.eval 16 | and self._record_every_n > 0 17 | and self._episode_index % self._record_every_n == 0 18 | ) 19 | return self._previous_obs_dict 20 | 21 | def reset_to_demo(self, i, variation_number=-1): 22 | if self._episodes_this_task == self._swap_task_every: 23 | self._set_new_task() 24 | self._episodes_this_task = 0 25 | self._episodes_this_task += 1 26 | 27 | self._i = 0 28 | self._task.set_variation(-1) 29 | d = self._task.get_demos( 30 | 1, live_demos=False, random_selection=False, from_episode_number=i 31 | )[0] 32 | 33 | self._task.set_variation(d.variation_number) 34 | desc, obs = self._task.reset_to_demo(d) 35 | self._lang_goal = desc[0] 36 | 37 | self._previous_obs_dict = self.extract_obs(obs) 38 | self._record_current_episode = ( 39 | self.eval 40 | and self._record_every_n > 0 41 | and self._episode_index % self._record_every_n == 0 42 | ) 43 | self._episode_index += 1 44 | self._recorded_images.clear() 45 | 46 | return self._previous_obs_dict 47 | -------------------------------------------------------------------------------- /rvt/utils/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
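# Overview: create_replay() declares the replay/observation elements
# (per-camera RGB, depth, point cloud, camera intrinsics/extrinsics, CLIP
# language embeddings, and discretized action indices) and fill_replay()
# populates the buffer from stored RLBench demos via keypoint discovery.
# Minimal sketch of how rvt/utils/get_dataset.py drives the two (remaining
# fill_replay arguments omitted; the task name is illustrative):
#
#   replay = create_replay(batch_size=8, timesteps=1, disk_saving=True,
#                          cameras=CAMERAS, voxel_sizes=VOXEL_SIZES)
#   fill_replay(replay=replay, task="open_drawer", ...,
#               clip_model=clip_model, device=device)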
4 | 5 | # initial source: https://colab.research.google.com/drive/1HAqemP4cE81SQ6QO1-N85j5bF4C0qLs0?usp=sharing 6 | # adapted to support loading from disk for faster initialization time 7 | 8 | # Adapted from: https://github.com/stepjam/ARM/blob/main/arm/c2farm/launch_utils.py 9 | import os 10 | import torch 11 | import pickle 12 | import logging 13 | import numpy as np 14 | from typing import List 15 | 16 | import clip 17 | import peract_colab.arm.utils as utils 18 | 19 | from peract_colab.rlbench.utils import get_stored_demo 20 | from yarr.utils.observation_type import ObservationElement 21 | from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer 22 | from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer 23 | from rlbench.backend.observation import Observation 24 | from rlbench.demo import Demo 25 | 26 | from rvt.utils.peract_utils import LOW_DIM_SIZE, IMAGE_SIZE, CAMERAS 27 | from rvt.libs.peract.helpers.demo_loading_utils import keypoint_discovery 28 | from rvt.libs.peract.helpers.utils import extract_obs 29 | 30 | 31 | def create_replay( 32 | batch_size: int, 33 | timesteps: int, 34 | disk_saving: bool, 35 | cameras: list, 36 | voxel_sizes, 37 | replay_size=3e5, 38 | ): 39 | 40 | trans_indicies_size = 3 * len(voxel_sizes) 41 | rot_and_grip_indicies_size = 3 + 1 42 | gripper_pose_size = 7 43 | ignore_collisions_size = 1 44 | max_token_seq_len = 77 45 | lang_feat_dim = 1024 46 | lang_emb_dim = 512 47 | 48 | # low_dim_state 49 | observation_elements = [] 50 | observation_elements.append( 51 | ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32) 52 | ) 53 | 54 | # rgb, depth, point cloud, intrinsics, extrinsics 55 | for cname in cameras: 56 | observation_elements.append( 57 | ObservationElement( 58 | "%s_rgb" % cname, 59 | ( 60 | 3, 61 | IMAGE_SIZE, 62 | IMAGE_SIZE, 63 | ), 64 | np.float32, 65 | ) 66 | ) 67 | observation_elements.append( 68 | ObservationElement( 69 | "%s_depth" % cname, 70 | ( 71 | 1, 72 | IMAGE_SIZE, 73 | IMAGE_SIZE, 74 | ), 75 | np.float32, 76 | ) 77 | ) 78 | observation_elements.append( 79 | ObservationElement( 80 | "%s_point_cloud" % cname, 81 | ( 82 | 3, 83 | IMAGE_SIZE, 84 | IMAGE_SIZE, 85 | ), 86 | np.float32, 87 | ) 88 | ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames 89 | observation_elements.append( 90 | ObservationElement( 91 | "%s_camera_extrinsics" % cname, 92 | ( 93 | 4, 94 | 4, 95 | ), 96 | np.float32, 97 | ) 98 | ) 99 | observation_elements.append( 100 | ObservationElement( 101 | "%s_camera_intrinsics" % cname, 102 | ( 103 | 3, 104 | 3, 105 | ), 106 | np.float32, 107 | ) 108 | ) 109 | 110 | # discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings 111 | observation_elements.extend( 112 | [ 113 | ReplayElement("trans_action_indicies", (trans_indicies_size,), np.int32), 114 | ReplayElement( 115 | "rot_grip_action_indicies", (rot_and_grip_indicies_size,), np.int32 116 | ), 117 | ReplayElement("ignore_collisions", (ignore_collisions_size,), np.int32), 118 | ReplayElement("gripper_pose", (gripper_pose_size,), np.float32), 119 | ReplayElement( 120 | "lang_goal_embs", 121 | ( 122 | max_token_seq_len, 123 | lang_emb_dim, 124 | ), # extracted from CLIP's language encoder 125 | np.float32, 126 | ), 127 | ReplayElement( 128 | "lang_goal", (1,), object 129 | ), # language goal string for debugging and visualization 130 | ] 131 | ) 132 | 133 | extra_replay_elements = [ 134 | ReplayElement("demo", (), bool), 
135 | ReplayElement("keypoint_idx", (), int), 136 | ReplayElement("episode_idx", (), int), 137 | ReplayElement("keypoint_frame", (), int), 138 | ReplayElement("next_keypoint_frame", (), int), 139 | ReplayElement("sample_frame", (), int), 140 | ] 141 | 142 | replay_buffer = ( 143 | UniformReplayBuffer( # all tuples in the buffer have equal sample weighting 144 | disk_saving=disk_saving, 145 | batch_size=batch_size, 146 | timesteps=timesteps, 147 | replay_capacity=int(replay_size), 148 | action_shape=(8,), # 3 translation + 4 rotation quaternion + 1 gripper open 149 | action_dtype=np.float32, 150 | reward_shape=(), 151 | reward_dtype=np.float32, 152 | update_horizon=1, 153 | observation_elements=observation_elements, 154 | extra_replay_elements=extra_replay_elements, 155 | ) 156 | ) 157 | return replay_buffer 158 | 159 | 160 | # discretize translation, rotation, gripper open, and ignore collision actions 161 | def _get_action( 162 | obs_tp1: Observation, 163 | obs_tm1: Observation, 164 | rlbench_scene_bounds: List[float], # metric 3D bounds of the scene 165 | voxel_sizes: List[int], 166 | rotation_resolution: int, 167 | crop_augmentation: bool, 168 | ): 169 | quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:]) 170 | if quat[-1] < 0: 171 | quat = -quat 172 | disc_rot = utils.quaternion_to_discrete_euler(quat, rotation_resolution) 173 | attention_coordinate = obs_tp1.gripper_pose[:3] 174 | trans_indicies, attention_coordinates = [], [] 175 | bounds = np.array(rlbench_scene_bounds) 176 | ignore_collisions = int(obs_tm1.ignore_collisions) 177 | for depth, vox_size in enumerate( 178 | voxel_sizes 179 | ): # only single voxelization-level is used in PerAct 180 | index = utils.point_to_voxel_index(obs_tp1.gripper_pose[:3], vox_size, bounds) 181 | trans_indicies.extend(index.tolist()) 182 | res = (bounds[3:] - bounds[:3]) / vox_size 183 | attention_coordinate = bounds[:3] + res * index 184 | attention_coordinates.append(attention_coordinate) 185 | 186 | rot_and_grip_indicies = disc_rot.tolist() 187 | grip = float(obs_tp1.gripper_open) 188 | rot_and_grip_indicies.extend([int(obs_tp1.gripper_open)]) 189 | return ( 190 | trans_indicies, 191 | rot_and_grip_indicies, 192 | ignore_collisions, 193 | np.concatenate([obs_tp1.gripper_pose, np.array([grip])]), 194 | attention_coordinates, 195 | ) 196 | 197 | 198 | # extract CLIP language features for goal string 199 | def _clip_encode_text(clip_model, text): 200 | x = clip_model.token_embedding(text).type( 201 | clip_model.dtype 202 | ) # [batch_size, n_ctx, d_model] 203 | 204 | x = x + clip_model.positional_embedding.type(clip_model.dtype) 205 | x = x.permute(1, 0, 2) # NLD -> LND 206 | x = clip_model.transformer(x) 207 | x = x.permute(1, 0, 2) # LND -> NLD 208 | x = clip_model.ln_final(x).type(clip_model.dtype) 209 | 210 | emb = x.clone() 211 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ clip_model.text_projection 212 | 213 | return x, emb 214 | 215 | 216 | # add individual data points to a replay 217 | def _add_keypoints_to_replay( 218 | replay: ReplayBuffer, 219 | task: str, 220 | task_replay_storage_folder: str, 221 | episode_idx: int, 222 | sample_frame: int, 223 | inital_obs: Observation, 224 | demo: Demo, 225 | episode_keypoints: List[int], 226 | cameras: List[str], 227 | rlbench_scene_bounds: List[float], 228 | voxel_sizes: List[int], 229 | rotation_resolution: int, 230 | crop_augmentation: bool, 231 | next_keypoint_idx: int, 232 | description: str = "", 233 | clip_model=None, 234 | device="cpu", 235 | ): 236 | prev_action = None 
237 | obs = inital_obs 238 | for k in range( 239 | next_keypoint_idx, len(episode_keypoints) 240 | ): # confused here, it seems that there are many similar samples in the replay 241 | keypoint = episode_keypoints[k] 242 | obs_tp1 = demo[keypoint] 243 | obs_tm1 = demo[max(0, keypoint - 1)] 244 | ( 245 | trans_indicies, 246 | rot_grip_indicies, 247 | ignore_collisions, 248 | action, 249 | attention_coordinates, 250 | ) = _get_action( 251 | obs_tp1, 252 | obs_tm1, 253 | rlbench_scene_bounds, 254 | voxel_sizes, 255 | rotation_resolution, 256 | crop_augmentation, 257 | ) 258 | 259 | terminal = k == len(episode_keypoints) - 1 260 | reward = float(terminal) * 1.0 if terminal else 0 261 | 262 | obs_dict = extract_obs( 263 | obs, 264 | CAMERAS, 265 | t=k - next_keypoint_idx, 266 | prev_action=prev_action, 267 | episode_length=25, 268 | ) 269 | tokens = clip.tokenize([description]).numpy() 270 | token_tensor = torch.from_numpy(tokens).to(device) 271 | with torch.no_grad(): 272 | lang_feats, lang_embs = _clip_encode_text(clip_model, token_tensor) 273 | obs_dict["lang_goal_embs"] = lang_embs[0].float().detach().cpu().numpy() 274 | 275 | prev_action = np.copy(action) 276 | 277 | if k == 0: 278 | keypoint_frame = -1 279 | else: 280 | keypoint_frame = episode_keypoints[k - 1] 281 | others = { 282 | "demo": True, 283 | "keypoint_idx": k, 284 | "episode_idx": episode_idx, 285 | "keypoint_frame": keypoint_frame, 286 | "next_keypoint_frame": keypoint, 287 | "sample_frame": sample_frame, 288 | } 289 | final_obs = { 290 | "trans_action_indicies": trans_indicies, 291 | "rot_grip_action_indicies": rot_grip_indicies, 292 | "gripper_pose": obs_tp1.gripper_pose, 293 | "lang_goal": np.array([description], dtype=object), 294 | } 295 | 296 | others.update(final_obs) 297 | others.update(obs_dict) 298 | 299 | timeout = False 300 | replay.add( 301 | task, 302 | task_replay_storage_folder, 303 | action, 304 | reward, 305 | terminal, 306 | timeout, 307 | **others 308 | ) 309 | obs = obs_tp1 310 | sample_frame = keypoint 311 | 312 | # final step 313 | obs_dict_tp1 = extract_obs( 314 | obs_tp1, 315 | CAMERAS, 316 | t=k + 1 - next_keypoint_idx, 317 | prev_action=prev_action, 318 | episode_length=25, 319 | ) 320 | obs_dict_tp1["lang_goal_embs"] = lang_embs[0].float().detach().cpu().numpy() 321 | 322 | obs_dict_tp1.pop("wrist_world_to_cam", None) 323 | obs_dict_tp1.update(final_obs) 324 | replay.add_final(task, task_replay_storage_folder, **obs_dict_tp1) 325 | 326 | 327 | def fill_replay( 328 | replay: ReplayBuffer, 329 | task: str, 330 | task_replay_storage_folder: str, 331 | start_idx: int, 332 | num_demos: int, 333 | demo_augmentation: bool, 334 | demo_augmentation_every_n: int, 335 | cameras: List[str], 336 | rlbench_scene_bounds: List[float], # AKA: DEPTH0_BOUNDS 337 | voxel_sizes: List[int], 338 | rotation_resolution: int, 339 | crop_augmentation: bool, 340 | data_path: str, 341 | episode_folder: str, 342 | variation_desriptions_pkl: str, 343 | clip_model=None, 344 | device="cpu", 345 | ): 346 | 347 | disk_exist = False 348 | if replay._disk_saving: 349 | if os.path.exists(task_replay_storage_folder): 350 | print( 351 | "[Info] Replay dataset already exists in the disk: {}".format( 352 | task_replay_storage_folder 353 | ), 354 | flush=True, 355 | ) 356 | disk_exist = True 357 | else: 358 | logging.info("\t saving to disk: %s", task_replay_storage_folder) 359 | os.makedirs(task_replay_storage_folder, exist_ok=True) 360 | 361 | if disk_exist: 362 | replay.recover_from_disk(task, task_replay_storage_folder) 363 | else: 364 | 
print("Filling replay ...") 365 | for d_idx in range(start_idx, start_idx + num_demos): 366 | print("Filling demo %d" % d_idx) 367 | demo = get_stored_demo(data_path=data_path, index=d_idx) 368 | 369 | # get language goal from disk 370 | varation_descs_pkl_file = os.path.join( 371 | data_path, episode_folder % d_idx, variation_desriptions_pkl 372 | ) 373 | with open(varation_descs_pkl_file, "rb") as f: 374 | descs = pickle.load(f) 375 | 376 | # extract keypoints 377 | episode_keypoints = keypoint_discovery(demo) 378 | next_keypoint_idx = 0 379 | for i in range(len(demo) - 1): 380 | if not demo_augmentation and i > 0: 381 | break 382 | if i % demo_augmentation_every_n != 0: # choose only every n-th frame 383 | continue 384 | 385 | obs = demo[i] 386 | desc = descs[0] 387 | # if our starting point is past one of the keypoints, then remove it 388 | while ( 389 | next_keypoint_idx < len(episode_keypoints) 390 | and i >= episode_keypoints[next_keypoint_idx] 391 | ): 392 | next_keypoint_idx += 1 393 | if next_keypoint_idx == len(episode_keypoints): 394 | break 395 | _add_keypoints_to_replay( 396 | replay, 397 | task, 398 | task_replay_storage_folder, 399 | d_idx, 400 | i, 401 | obs, 402 | demo, 403 | episode_keypoints, 404 | cameras, 405 | rlbench_scene_bounds, 406 | voxel_sizes, 407 | rotation_resolution, 408 | crop_augmentation, 409 | next_keypoint_idx=next_keypoint_idx, 410 | description=desc, 411 | clip_model=clip_model, 412 | device=device, 413 | ) 414 | 415 | # save TERMINAL info in replay_info.npy 416 | task_idx = replay._task_index[task] 417 | with open( 418 | os.path.join(task_replay_storage_folder, "replay_info.npy"), "wb" 419 | ) as fp: 420 | np.save( 421 | fp, 422 | replay._store["terminal"][ 423 | replay._task_replay_start_index[ 424 | task_idx 425 | ] : replay._task_replay_start_index[task_idx] 426 | + replay._task_add_count[task_idx].value 427 | ], 428 | ) 429 | 430 | print("Replay filled with demos.") 431 | -------------------------------------------------------------------------------- /rvt/utils/ddp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | import os 6 | import torch.distributed as dist 7 | import random 8 | 9 | 10 | def setup(rank, world_size, port): 11 | os.environ["MASTER_ADDR"] = "localhost" 12 | os.environ["MASTER_PORT"] = str(port) 13 | 14 | # initialize the process group 15 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 16 | 17 | 18 | def cleanup(): 19 | dist.destroy_process_group() 20 | -------------------------------------------------------------------------------- /rvt/utils/get_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
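# Usage sketch (not part of the original sources; port, world size, and worker body are illustrative assumptions):
# the setup/cleanup helpers in rvt/utils/ddp_utils.py above are typically driven from torch.multiprocessing,
# with each spawned worker joining the NCCL process group on localhost before doing any distributed work.
import torch.multiprocessing as mp
from rvt.utils.ddp_utils import setup, cleanup

def _ddp_worker(rank, world_size, port):
    setup(rank, world_size, port)  # dist.init_process_group("nccl", ...)
    # ... build the model, wrap it in DistributedDataParallel, train ...
    cleanup()  # dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2  # hypothetical number of GPUs
    mp.spawn(_ddp_worker, args=(world_size, 29500), nprocs=world_size)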
4 | 5 | import os 6 | import sys 7 | import shutil 8 | import torch 9 | import clip 10 | 11 | from rvt.libs.peract.helpers.utils import extract_obs 12 | from rvt.utils.rvt_utils import ForkedPdb 13 | from rvt.utils.dataset import create_replay, fill_replay 14 | from rvt.utils.peract_utils import ( 15 | CAMERAS, 16 | SCENE_BOUNDS, 17 | EPISODE_FOLDER, 18 | VARIATION_DESCRIPTIONS_PKL, 19 | DEMO_AUGMENTATION_EVERY_N, 20 | ROTATION_RESOLUTION, 21 | VOXEL_SIZES, 22 | ) 23 | from yarr.replay_buffer.wrappers.pytorch_replay_buffer import PyTorchReplayBuffer 24 | 25 | 26 | def get_dataset( 27 | tasks, 28 | BATCH_SIZE_TRAIN, 29 | BATCH_SIZE_TEST, 30 | TRAIN_REPLAY_STORAGE_DIR, 31 | TEST_REPLAY_STORAGE_DIR, 32 | DATA_FOLDER, 33 | NUM_TRAIN, 34 | NUM_VAL, 35 | refresh_replay, 36 | device, 37 | num_workers, 38 | only_train, 39 | sample_distribution_mode="transition_uniform", 40 | ): 41 | 42 | train_replay_buffer = create_replay( 43 | batch_size=BATCH_SIZE_TRAIN, 44 | timesteps=1, 45 | disk_saving=True, 46 | cameras=CAMERAS, 47 | voxel_sizes=VOXEL_SIZES, 48 | ) 49 | if not only_train: 50 | test_replay_buffer = create_replay( 51 | batch_size=BATCH_SIZE_TEST, 52 | timesteps=1, 53 | disk_saving=True, 54 | cameras=CAMERAS, 55 | voxel_sizes=VOXEL_SIZES, 56 | ) 57 | 58 | # load pre-trained language model 59 | try: 60 | clip_model, _ = clip.load("RN50", device="cpu") # CLIP-ResNet50 61 | clip_model = clip_model.to(device) 62 | clip_model.eval() 63 | except RuntimeError: 64 | print("WARNING: Setting Clip to None. Will not work if replay not on disk.") 65 | clip_model = None 66 | 67 | for task in tasks: # for each task 68 | # print("---- Preparing the data for {} task ----".format(task), flush=True) 69 | EPISODES_FOLDER_TRAIN = f"train/{task}/all_variations/episodes" 70 | EPISODES_FOLDER_VAL = f"val/{task}/all_variations/episodes" 71 | data_path_train = os.path.join(DATA_FOLDER, EPISODES_FOLDER_TRAIN) 72 | data_path_val = os.path.join(DATA_FOLDER, EPISODES_FOLDER_VAL) 73 | train_replay_storage_folder = f"{TRAIN_REPLAY_STORAGE_DIR}/{task}" 74 | test_replay_storage_folder = f"{TEST_REPLAY_STORAGE_DIR}/{task}" 75 | 76 | # if refresh_replay, then remove the existing replay data folder 77 | if refresh_replay: 78 | print("[Info] Remove exisitng replay dataset as requested.", flush=True) 79 | if os.path.exists(train_replay_storage_folder) and os.path.isdir( 80 | train_replay_storage_folder 81 | ): 82 | shutil.rmtree(train_replay_storage_folder) 83 | print(f"remove {train_replay_storage_folder}") 84 | if os.path.exists(test_replay_storage_folder) and os.path.isdir( 85 | test_replay_storage_folder 86 | ): 87 | shutil.rmtree(test_replay_storage_folder) 88 | print(f"remove {test_replay_storage_folder}") 89 | 90 | # print("----- Train Buffer -----") 91 | fill_replay( 92 | replay=train_replay_buffer, 93 | task=task, 94 | task_replay_storage_folder=train_replay_storage_folder, 95 | start_idx=0, 96 | num_demos=NUM_TRAIN, 97 | demo_augmentation=True, 98 | demo_augmentation_every_n=DEMO_AUGMENTATION_EVERY_N, 99 | cameras=CAMERAS, 100 | rlbench_scene_bounds=SCENE_BOUNDS, 101 | voxel_sizes=VOXEL_SIZES, 102 | rotation_resolution=ROTATION_RESOLUTION, 103 | crop_augmentation=False, 104 | data_path=data_path_train, 105 | episode_folder=EPISODE_FOLDER, 106 | variation_desriptions_pkl=VARIATION_DESCRIPTIONS_PKL, 107 | clip_model=clip_model, 108 | device=device, 109 | ) 110 | 111 | if not only_train: 112 | # print("----- Test Buffer -----") 113 | fill_replay( 114 | replay=test_replay_buffer, 115 | task=task, 116 | 
task_replay_storage_folder=test_replay_storage_folder, 117 | start_idx=0, 118 | num_demos=NUM_VAL, 119 | demo_augmentation=True, 120 | demo_augmentation_every_n=DEMO_AUGMENTATION_EVERY_N, 121 | cameras=CAMERAS, 122 | rlbench_scene_bounds=SCENE_BOUNDS, 123 | voxel_sizes=VOXEL_SIZES, 124 | rotation_resolution=ROTATION_RESOLUTION, 125 | crop_augmentation=False, 126 | data_path=data_path_val, 127 | episode_folder=EPISODE_FOLDER, 128 | variation_desriptions_pkl=VARIATION_DESCRIPTIONS_PKL, 129 | clip_model=clip_model, 130 | device=device, 131 | ) 132 | 133 | # delete the CLIP model since we have already extracted language features 134 | del clip_model 135 | with torch.cuda.device(device): 136 | torch.cuda.empty_cache() 137 | 138 | # wrap buffer with PyTorch dataset and make iterator 139 | train_wrapped_replay = PyTorchReplayBuffer( 140 | train_replay_buffer, 141 | sample_mode="random", 142 | num_workers=num_workers, 143 | sample_distribution_mode=sample_distribution_mode, 144 | ) 145 | train_dataset = train_wrapped_replay.dataset() 146 | 147 | if only_train: 148 | test_dataset = None 149 | else: 150 | test_wrapped_replay = PyTorchReplayBuffer( 151 | test_replay_buffer, 152 | sample_mode="enumerate", 153 | num_workers=num_workers, 154 | ) 155 | test_dataset = test_wrapped_replay.dataset() 156 | return train_dataset, test_dataset 157 | -------------------------------------------------------------------------------- /rvt/utils/lr_sched_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | from torch.optim.lr_scheduler import ReduceLROnPlateau 7 | 8 | # source: https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py 9 | # updated so that it is suitable for cases where the epoch number starts from 0 10 | # lr increases linearly from "epoch 0" to "epoch (total_epoch - 1)" such that 11 | # the lr at epoch (total_epoch - 1) equals the base_lr of the after_scheduler 12 | # Only tested for the case when multiplier is 1.0 and the after_scheduler is a 13 | # MultiStepLR 14 | 15 | 16 | class GradualWarmupScheduler(_LRScheduler): 17 | """Gradually warm up (increase) the learning rate in the optimizer. 18 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 19 | Args: 20 | optimizer (Optimizer): Wrapped optimizer. 21 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. If multiplier = 1.0, the lr starts from 0 and ends at the base_lr. 22 | total_epoch: the target learning rate is reached gradually at total_epoch 23 | after_scheduler: after total_epoch, use this scheduler (e.g. 
ReduceLROnPlateau) 24 | """ 25 | 26 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 27 | self.multiplier = multiplier 28 | if self.multiplier < 1.0: 29 | raise ValueError("multiplier should be greater than or equal to 1.") 30 | self.total_epoch = total_epoch 31 | self.after_scheduler = after_scheduler 32 | self.finished = False 33 | super(GradualWarmupScheduler, self).__init__(optimizer) 34 | 35 | def get_lr(self): 36 | if (self.last_epoch + 1) > self.total_epoch: 37 | if self.after_scheduler: 38 | if not self.finished: 39 | self.after_scheduler.base_lrs = [ 40 | base_lr * self.multiplier for base_lr in self.base_lrs 41 | ] 42 | self.finished = True 43 | return self.after_scheduler.get_last_lr() 44 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 45 | 46 | if self.multiplier == 1.0: 47 | return [ 48 | base_lr * ((float(self.last_epoch) + 1) / self.total_epoch) 49 | for base_lr in self.base_lrs 50 | ] 51 | else: 52 | return [ 53 | base_lr 54 | * ( 55 | (self.multiplier - 1.0) * (self.last_epoch + 1) / self.total_epoch 56 | + 1.0 57 | ) 58 | for base_lr in self.base_lrs 59 | ] 60 | 61 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 62 | if epoch is None: 63 | epoch = self.last_epoch + 1 64 | self.last_epoch = ( 65 | epoch if epoch != 0 else 1 66 | ) # ReduceLROnPlateau is called at the end of an epoch, whereas the others are called at the beginning 67 | if self.last_epoch <= self.total_epoch: 68 | warmup_lr = [ 69 | base_lr 70 | * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) 71 | for base_lr in self.base_lrs 72 | ] 73 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 74 | param_group["lr"] = lr 75 | else: 76 | if epoch is None: 77 | self.after_scheduler.step(metrics, None) 78 | else: 79 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 80 | 81 | def step(self, epoch=None, metrics=None): 82 | if type(self.after_scheduler) != ReduceLROnPlateau: 83 | if self.finished and self.after_scheduler: 84 | if epoch is None: 85 | self.after_scheduler.step(None) 86 | else: 87 | self.after_scheduler.step(epoch) 88 | self._last_lr = self.after_scheduler.get_last_lr() 89 | else: 90 | return super(GradualWarmupScheduler, self).step(epoch) 91 | else: 92 | self.step_ReduceLROnPlateau(metrics, epoch) 93 | 94 | def state_dict(self): 95 | state_dict = { 96 | key: value 97 | for key, value in self.__dict__.items() 98 | if key not in ["optimizer", "after_scheduler"] 99 | } 100 | 101 | if not (self.after_scheduler is None): 102 | state_dict["after_scheduler_state_dict"] = self.after_scheduler.state_dict() 103 | 104 | return state_dict 105 | 106 | def load_state_dict(self, state_dict): 107 | if self.after_scheduler is None: 108 | assert not ("after_scheduler_state_dict" in state_dict) 109 | else: 110 | self.after_scheduler.load_state_dict( 111 | state_dict["after_scheduler_state_dict"] 112 | ) 113 | del state_dict["after_scheduler_state_dict"] 114 | 115 | self.__dict__.update(state_dict) 116 | -------------------------------------------------------------------------------- /rvt/utils/peract_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
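# Usage sketch (illustrative values, not part of the original sources): the
# GradualWarmupScheduler above wraps another scheduler, and its header comment notes
# it was only tested with multiplier=1.0 and a MultiStepLR. The lr then ramps up
# linearly over the first total_epoch steps before the MultiStepLR milestones take over.
import torch
from torch.optim.lr_scheduler import MultiStepLR
from rvt.utils.lr_sched_utils import GradualWarmupScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
after_sched = MultiStepLR(optimizer, milestones=[30, 40], gamma=0.1)
scheduler = GradualWarmupScheduler(
    optimizer, multiplier=1.0, total_epoch=5, after_scheduler=after_sched
)

for epoch in range(50):
    # ... forward/backward pass and optimizer.step() for one epoch ...
    scheduler.step()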
4 | 5 | import torch 6 | from omegaconf import OmegaConf 7 | 8 | from rvt.models.peract_official import create_agent_our 9 | from peract_colab.arm.utils import stack_on_channel 10 | from torch.optim.lr_scheduler import CosineAnnealingLR 11 | from rvt.utils.lr_sched_utils import GradualWarmupScheduler 12 | 13 | # Constants 14 | # TODO: Unclear about the best way to handle them 15 | CAMERAS = ["front", "left_shoulder", "right_shoulder", "wrist"] 16 | SCENE_BOUNDS = [ 17 | -0.3, 18 | -0.5, 19 | 0.6, 20 | 0.7, 21 | 0.5, 22 | 1.6, 23 | ] # [x_min, y_min, z_min, x_max, y_max, z_max] - the metric volume to be voxelized 24 | IMAGE_SIZE = 128 25 | VOXEL_SIZES = [100] # 100x100x100 voxels 26 | LOW_DIM_SIZE = 4 # {left_finger_joint, right_finger_joint, gripper_open, timestep} 27 | 28 | DATA_FOLDER = "data" 29 | EPISODE_FOLDER = "episode%d" 30 | VARIATION_DESCRIPTIONS_PKL = "variation_descriptions.pkl" # the pkl file that contains language goals for each demonstration 31 | DEMO_AUGMENTATION_EVERY_N = 10 # sample every n-th frame in a demo 32 | ROTATION_RESOLUTION = 5 # degree increments per axis 33 | # settings 34 | NUM_LATENTS = 512 # PerceiverIO latents 35 | 36 | 37 | def _norm_rgb(x): 38 | return (x.float() / 255.0) * 2.0 - 1.0 39 | 40 | 41 | def _preprocess_inputs(replay_sample, cameras): 42 | obs, pcds = [], [] 43 | for n in cameras: 44 | rgb = stack_on_channel(replay_sample["%s_rgb" % n]) 45 | pcd = stack_on_channel(replay_sample["%s_point_cloud" % n]) 46 | 47 | rgb = _norm_rgb(rgb) 48 | 49 | obs.append( 50 | [rgb, pcd] 51 | ) # obs contains both rgb and pointcloud (used in ARM for other baselines) 52 | pcds.append(pcd) # only pointcloud 53 | return obs, pcds 54 | 55 | 56 | def get_official_peract( 57 | cfg_path, 58 | training, 59 | device, 60 | bs, 61 | ): 62 | """ 63 | Creates an official PerAct agent. 64 | :param cfg_path: path to the config file 65 | :param training: whether to build the agent in training mode 66 | :param device: device to build the agent on 67 | :param bs: batch size; it does not matter when the model is only used for inference. 68 | """ 69 | with open(cfg_path, "r") as f: 70 | cfg = OmegaConf.load(f) 71 | 72 | # we need to override the batch size since, in our case, we specify the batch 73 | # size per GPU 74 | cfg.replay.batch_size = bs 75 | agent = create_agent_our(cfg) 76 | agent.build(training=training, device=device) 77 | 78 | return agent 79 | -------------------------------------------------------------------------------- /rvt/utils/rlbench_planning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
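# Usage sketch (the config path, device string, and batch size are illustrative
# assumptions): loading the official PerAct baseline through the get_official_peract
# helper above, e.g. for evaluation; per its docstring, the batch size is irrelevant
# at inference time.
from rvt.utils.peract_utils import get_official_peract

peract_agent = get_official_peract(
    cfg_path="runs/peract_official/seed0/config.yaml",  # hypothetical config location
    training=False,
    device="cuda:0",
    bs=1,
)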
4 | 5 | import numpy as np 6 | from rlbench.action_modes.arm_action_modes import ( 7 | EndEffectorPoseViaPlanning, 8 | Scene, 9 | ) 10 | 11 | 12 | class EndEffectorPoseViaPlanning2(EndEffectorPoseViaPlanning): 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | 16 | def action(self, scene: Scene, action: np.ndarray, ignore_collisions: bool = True): 17 | action[:3] = np.clip( 18 | action[:3], 19 | np.array( 20 | [scene._workspace_minx, scene._workspace_miny, scene._workspace_minz] 21 | ) 22 | + 1e-7, 23 | np.array( 24 | [scene._workspace_maxx, scene._workspace_maxy, scene._workspace_maxz] 25 | ) 26 | - 1e-7, 27 | ) 28 | 29 | super().action(scene, action, ignore_collisions) 30 | -------------------------------------------------------------------------------- /rvt/utils/rvt_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | """ 6 | Utility function for Our Agent 7 | """ 8 | import pdb 9 | import argparse 10 | import sys 11 | import signal 12 | from datetime import datetime 13 | 14 | import torch 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | 19 | import rvt.utils.peract_utils as peract_utils 20 | from rvt.models.peract_official import PreprocessAgent2 21 | 22 | 23 | def get_pc_img_feat(obs, pcd, bounds=None): 24 | """ 25 | preprocess the data in the peract to our framework 26 | """ 27 | # obs, pcd = peract_utils._preprocess_inputs(batch) 28 | bs = obs[0][0].shape[0] 29 | # concatenating the points from all the cameras 30 | # (bs, num_points, 3) 31 | pc = torch.cat([p.permute(0, 2, 3, 1).reshape(bs, -1, 3) for p in pcd], 1) 32 | _img_feat = [o[0] for o in obs] 33 | img_dim = _img_feat[0].shape[1] 34 | # (bs, num_points, 3) 35 | img_feat = torch.cat( 36 | [p.permute(0, 2, 3, 1).reshape(bs, -1, img_dim) for p in _img_feat], 1 37 | ) 38 | 39 | img_feat = (img_feat + 1) / 2 40 | 41 | # x_min, y_min, z_min, x_max, y_max, z_max = bounds 42 | # inv_pnt = ( 43 | # (pc[:, :, 0] < x_min) 44 | # | (pc[:, :, 0] > x_max) 45 | # | (pc[:, :, 1] < y_min) 46 | # | (pc[:, :, 1] > y_max) 47 | # | (pc[:, :, 2] < z_min) 48 | # | (pc[:, :, 2] > z_max) 49 | # ) 50 | 51 | # # TODO: move from a list to a better batched version 52 | # pc = [pc[i, ~_inv_pnt] for i, _inv_pnt in enumerate(inv_pnt)] 53 | # img_feat = [img_feat[i, ~_inv_pnt] for i, _inv_pnt in enumerate(inv_pnt)] 54 | 55 | return pc, img_feat 56 | 57 | 58 | def move_pc_in_bound(pc, img_feat, bounds, no_op=False): 59 | """ 60 | :param no_op: no operation 61 | """ 62 | if no_op: 63 | return pc, img_feat 64 | 65 | x_min, y_min, z_min, x_max, y_max, z_max = bounds 66 | inv_pnt = ( 67 | (pc[:, :, 0] < x_min) 68 | | (pc[:, :, 0] > x_max) 69 | | (pc[:, :, 1] < y_min) 70 | | (pc[:, :, 1] > y_max) 71 | | (pc[:, :, 2] < z_min) 72 | | (pc[:, :, 2] > z_max) 73 | | torch.isnan(pc[:, :, 0]) 74 | | torch.isnan(pc[:, :, 1]) 75 | | torch.isnan(pc[:, :, 2]) 76 | ) 77 | 78 | # TODO: move from a list to a better batched version 79 | pc = [pc[i, ~_inv_pnt] for i, _inv_pnt in enumerate(inv_pnt)] 80 | img_feat = [img_feat[i, ~_inv_pnt] for i, _inv_pnt in enumerate(inv_pnt)] 81 | return pc, img_feat 82 | 83 | 84 | def count_parameters(model): 85 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 86 | 87 | 88 | class 
TensorboardManager: 89 | def __init__(self, path): 90 | self.writer = SummaryWriter(path) 91 | 92 | def update(self, split, step, vals): 93 | for k, v in vals.items(): 94 | if "image" in k: 95 | for i, x in enumerate(v): 96 | self.writer.add_image(f"{split}_{step}", x, i) 97 | elif "hist" in k: 98 | if isinstance(v, list): 99 | self.writer.add_histogram(k, v, step) 100 | elif isinstance(v, dict): 101 | hist_id = {} 102 | for i, idx in enumerate(sorted(v.keys())): 103 | self.writer.add_histogram(f"{split}_{k}_{step}", v[idx], i) 104 | hist_id[i] = idx 105 | self.writer.add_text(f"{split}_{k}_{step}_id", f"{hist_id}") 106 | else: 107 | assert False 108 | else: 109 | self.writer.add_scalar("%s_%s" % (split, k), v, step) 110 | 111 | def close(self): 112 | self.writer.flush() 113 | self.writer.close() 114 | 115 | 116 | class ForkedPdb(pdb.Pdb): 117 | """A Pdb subclass that may be used 118 | from a forked multiprocessing child 119 | 120 | """ 121 | 122 | def interaction(self, *args, **kwargs): 123 | _stdin = sys.stdin 124 | try: 125 | sys.stdin = open("/dev/stdin") 126 | pdb.Pdb.interaction(self, *args, **kwargs) 127 | finally: 128 | sys.stdin = _stdin 129 | 130 | 131 | def short_name(cfg_opts): 132 | SHORT_FORMS = { 133 | "peract": "PA", 134 | "sample_distribution_mode": "SDM", 135 | "optimizer_type": "OPT", 136 | "lr_cos_dec": "LCD", 137 | "num_workers": "NW", 138 | "True": "T", 139 | "False": "F", 140 | "pe_fix": "pf", 141 | "transform_augmentation_rpy": "tar", 142 | "lambda_weight_l2": "l2", 143 | "resume": "RES", 144 | "inp_pre_pro": "IPP", 145 | "inp_pre_con": "IPC", 146 | "cvx_up": "CU", 147 | "stage_two": "ST", 148 | "feat_ver": "FV", 149 | "lamb": "L", 150 | "img_size": "IS", 151 | "img_patch_size": "IPS", 152 | "rlbench": "RLB", 153 | "move_pc_in_bound": "MPIB", 154 | "rend": "R", 155 | "xops": "X", 156 | "warmup_steps": "WS", 157 | "epochs": "E", 158 | "amp": "A", 159 | } 160 | 161 | if "resume" in cfg_opts: 162 | cfg_opts = cfg_opts.split(" ") 163 | res_idx = cfg_opts.index("resume") 164 | cfg_opts.pop(res_idx + 1) 165 | cfg_opts = " ".join(cfg_opts) 166 | 167 | cfg_opts = cfg_opts.replace(" ", "_") 168 | cfg_opts = cfg_opts.replace("/", "_") 169 | cfg_opts = cfg_opts.replace("[", "") 170 | cfg_opts = cfg_opts.replace("]", "") 171 | cfg_opts = cfg_opts.replace("..", "") 172 | for a, b in SHORT_FORMS.items(): 173 | cfg_opts = cfg_opts.replace(a, b) 174 | 175 | return cfg_opts 176 | 177 | 178 | def get_num_feat(cfg): 179 | num_feat = cfg.num_rotation_classes * 3 180 | # 2 for grip, 2 for collision 181 | num_feat += 4 182 | return num_feat 183 | 184 | 185 | def get_eval_parser(): 186 | parser = argparse.ArgumentParser() 187 | 188 | parser.add_argument( 189 | "--tasks", type=str, nargs="+", default=["insert_onto_square_peg"] 190 | ) 191 | parser.add_argument("--model-folder", type=str, default=None) 192 | parser.add_argument("--eval-datafolder", type=str, default="./data/val/") 193 | parser.add_argument( 194 | "--start-episode", 195 | type=int, 196 | default=0, 197 | help="start to evaluate from which episode", 198 | ) 199 | parser.add_argument( 200 | "--eval-episodes", 201 | type=int, 202 | default=10, 203 | help="how many episodes to be evaluated for each task", 204 | ) 205 | parser.add_argument( 206 | "--episode-length", 207 | type=int, 208 | default=25, 209 | help="maximum control steps allowed for each episode", 210 | ) 211 | parser.add_argument("--headless", action="store_true", default=False) 212 | parser.add_argument("--ground-truth", action="store_true", default=False) 213 | 
parser.add_argument("--exp_cfg_path", type=str, default=None) 214 | parser.add_argument("--mvt_cfg_path", type=str, default=None) 215 | parser.add_argument("--peract_official", action="store_true") 216 | parser.add_argument( 217 | "--peract_model_dir", 218 | type=str, 219 | default="runs/peract_official/seed0/weights/600000", 220 | ) 221 | parser.add_argument("--device", type=int, default=0) 222 | parser.add_argument("--log-name", type=str, default=None) 223 | parser.add_argument("--model-name", type=str, default=None) 224 | parser.add_argument("--use-input-place-with-mean", action="store_true") 225 | parser.add_argument("--save-video", action="store_true") 226 | parser.add_argument("--skip", action="store_true") 227 | 228 | return parser 229 | 230 | 231 | RLBENCH_TASKS = [ 232 | "put_item_in_drawer", 233 | "reach_and_drag", 234 | "turn_tap", 235 | "slide_block_to_color_target", 236 | "open_drawer", 237 | "put_groceries_in_cupboard", 238 | "place_shape_in_shape_sorter", 239 | "put_money_in_safe", 240 | "push_buttons", 241 | "close_jar", 242 | "stack_blocks", 243 | "place_cups", 244 | "place_wine_at_rack_location", 245 | "light_bulb_in", 246 | "sweep_to_dustpan_of_size", 247 | "insert_onto_square_peg", 248 | "meat_off_grill", 249 | "stack_cups", 250 | ] 251 | 252 | 253 | def load_agent(agent_path, agent=None, only_epoch=False): 254 | if isinstance(agent, PreprocessAgent2): 255 | assert not only_epoch 256 | agent._pose_agent.load_weights(agent_path) 257 | return 0 258 | 259 | checkpoint = torch.load(agent_path, map_location="cpu") 260 | epoch = checkpoint["epoch"] 261 | 262 | if not only_epoch: 263 | if hasattr(agent, "_q"): 264 | model = agent._q 265 | elif hasattr(agent, "_network"): 266 | model = agent._network 267 | optimizer = agent._optimizer 268 | lr_sched = agent._lr_sched 269 | 270 | if isinstance(model, DDP): 271 | model = model.module 272 | 273 | try: 274 | model.load_state_dict(checkpoint["model_state"]) 275 | except RuntimeError: 276 | try: 277 | print( 278 | "WARNING: loading states in mvt1. " 279 | "Be cautious if you are using a two stage network." 280 | ) 281 | model.mvt1.load_state_dict(checkpoint["model_state"]) 282 | except RuntimeError: 283 | print( 284 | "WARNING: loading states with strict=False! " 285 | "KNOW WHAT YOU ARE DOING!!" 286 | ) 287 | model.load_state_dict(checkpoint["model_state"], strict=False) 288 | 289 | if "optimizer_state" in checkpoint: 290 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 291 | else: 292 | print( 293 | "WARNING: No optimizer_state in checkpoint. " "KNOW WHAT YOU ARE DOING!!" 294 | ) 295 | 296 | if "lr_sched_state" in checkpoint: 297 | lr_sched.load_state_dict(checkpoint["lr_sched_state"]) 298 | else: 299 | print( 300 | "WARNING: No lr_sched_state in checkpoint. " "KNOW WHAT YOU ARE DOING!!" 301 | ) 302 | 303 | return epoch 304 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 
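# Usage sketch (the random tensors are illustrative assumptions): move_pc_in_bound from
# rvt/utils/rvt_utils.py above drops points (and their image features) that fall outside
# the scene bounds or contain NaNs. It expects (bs, num_points, 3) points and
# (bs, num_points, feat_dim) features, and returns per-sample lists of the surviving points.
import torch
from rvt.utils.peract_utils import SCENE_BOUNDS
from rvt.utils.rvt_utils import move_pc_in_bound

pc = torch.rand(2, 4096, 3) * 2.0 - 0.5   # some points fall outside SCENE_BOUNDS
img_feat = torch.rand(2, 4096, 3)         # e.g. RGB in [0, 1], as produced by get_pc_img_feat
pc_in, feat_in = move_pc_in_bound(pc, img_feat, SCENE_BOUNDS)
print([p.shape for p in pc_in])           # one (n_i, 3) tensor of in-bound points per sample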
4 | 5 | """ 6 | Setup of RVT 7 | Author: Ankit Goyal 8 | """ 9 | from setuptools import setup, find_packages 10 | 11 | requirements = [ 12 | "numpy", 13 | "scipy", 14 | "einops", 15 | "pyrender", 16 | "transformers", 17 | "omegaconf", 18 | "natsort", 19 | "cffi", 20 | "pandas", 21 | "tensorflow", 22 | "pyquaternion", 23 | "matplotlib", 24 | "bitsandbytes==0.38.1", 25 | "transforms3d", 26 | "clip @ git+https://github.com/openai/CLIP.git", 27 | ] 28 | 29 | __version__ = "0.0.1" 30 | setup( 31 | name="rvt", 32 | version=__version__, 33 | description="RVT", 34 | long_description="", 35 | author="Ankit Goyal", 36 | author_email="angoyal@nvidia.com", 37 | url="", 38 | keywords="robotics,computer vision", 39 | classifiers=[ 40 | "Programming Language :: Python", 41 | "Programming Language :: Python :: 3.8", 42 | "Natural Language :: English", 43 | "Topic :: Scientific/Engineering", 44 | ], 45 | packages=['rvt'], 46 | install_requires=requirements, 47 | extras_require={ 48 | "xformers": [ 49 | "xformers @ git+https://github.com/facebookresearch/xformers.git@main#egg=xformers", 50 | ] 51 | }, 52 | ) 53 | --------------------------------------------------------------------------------