├── .gitattributes
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── rvt
│   ├── config.py
│   ├── configs
│   │   ├── peract_official_config.yaml
│   │   ├── rvt.yaml
│   │   └── rvt2.yaml
│   ├── eval.py
│   ├── libs
│   │   ├── peract_colab
│   │   │   └── setup.py
│   │   └── point-renderer
│   │       ├── .gitattributes
│   │       ├── .gitignore
│   │       ├── LICENSE
│   │       ├── README.md
│   │       ├── demo.png
│   │       ├── image_0_splat_2xaa.png
│   │       ├── pcd_data.tar.gz
│   │       ├── point_renderer
│   │       │   ├── cameras.py
│   │       │   ├── csrc
│   │       │   │   ├── bindings.cpp
│   │       │   │   └── render
│   │       │   │       ├── render_feature_pointcloud.cu
│   │       │   │       └── render_pointcloud.h
│   │       │   ├── ops.py
│   │       │   ├── profiler.py
│   │       │   ├── renderer.py
│   │       │   ├── rvt_ops.py
│   │       │   └── rvt_renderer.py
│   │       ├── pointcloud-notebook.ipynb
│   │       ├── requirements.txt
│   │       └── setup.py
│   ├── models
│   │   ├── peract_official.py
│   │   └── rvt_agent.py
│   ├── mvt
│   │   ├── __init__.py
│   │   ├── attn.py
│   │   ├── aug_utils.py
│   │   ├── augmentation.py
│   │   ├── config.py
│   │   ├── configs
│   │   │   ├── rvt2.yaml
│   │   │   └── rvt2_partial.yaml
│   │   ├── mvt.py
│   │   ├── mvt_single.py
│   │   ├── raft_utils.py
│   │   ├── renderer.py
│   │   └── utils.py
│   ├── train.py
│   └── utils
│       ├── __init__.py
│       ├── custom_rlbench_env.py
│       ├── dataset.py
│       ├── ddp_utils.py
│       ├── get_dataset.py
│       ├── lr_sched_utils.py
│       ├── peract_utils.py
│       ├── rlbench_planning.py
│       └── rvt_utils.py
└── setup.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | rlbench/task_design.ttt filter=lfs diff=lfs merge=lfs -text
6 | *.ttt filter=lfs diff=lfs merge=lfs -text
7 | *.ttm filter=lfs diff=lfs merge=lfs -text
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | *__pycache__/
6 | rvt/data/train/
7 | rvt/data/val/
8 | rvt/data/test/
9 | rvt/replay/
10 | coppelia_install_dir/
11 | rvt/runs/
12 | runs_ngc_mnt
13 | runs_temp
14 | *.ipynb_checkpoints/
15 | .vscode/
16 | build/
17 | dist/
18 | *.egg-info/
19 |
20 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "libs/PyRep"]
2 | path = rvt/libs/PyRep
3 | url = https://github.com/stepjam/PyRep.git
4 | [submodule "libs/RLBench"]
5 | path = rvt/libs/RLBench
6 | url = https://github.com/buttomnutstoast/RLBench.git
7 | [submodule "libs/YARR"]
8 | path = rvt/libs/YARR
9 | url = https://github.com/NVlabs/YARR.git
10 | [submodule "libs/peract"]
11 | path = rvt/libs/peract
12 | url = https://github.com/NVlabs/peract.git
13 | [submodule "libs/peract_colab/peract_colab"]
14 | path = rvt/libs/peract_colab/peract_colab
15 | url = https://github.com/peract/peract_colab.git
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | NVIDIA License
2 |
3 | 1. Definitions
4 |
5 | "Licensor" means any person or entity that distributes its Work.
6 |
7 | "Work" means (a) the original work of authorship made available under this
8 | license, which may include software, documentation, or other files, and (b) any
9 | additions to or derivative works thereof that are made available under this
10 | license.
11 |
12 | The terms "reproduce," "reproduction," "derivative works," and "distribution"
13 | have the meaning as provided under U.S. copyright law; provided, however, that
14 | for the purposes of this license, derivative works shall not include works that
15 | remain separable from, or merely link (or bind by name) to the interfaces of,
16 | the Work.
17 |
18 | Works are "made available" under this license by including in or with the Work
19 | either (a) a copyright notice referencing the applicability of this license to
20 | the Work, or (b) a copy of this license.
21 |
22 | 2. License Grant
23 |
24 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each
25 | Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free,
26 | copyright license to use, reproduce, prepare derivative works of, publicly
27 | display, publicly perform, sublicense and distribute its Work and any resulting
28 | derivative works in any form.
29 |
30 | 3. Limitations
31 |
32 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do
33 | so under this license, (b) you include a complete copy of this license with
34 | your distribution, and (c) you retain without modification any copyright,
35 | patent, trademark, or attribution notices that are present in the Work.
36 |
37 | 3.2 Derivative Works. You may specify that additional or different terms apply
38 | to the use, reproduction, and distribution of your derivative works of the Work
39 | ("Your Terms") only if (a) Your Terms provide that the use limitation in
40 | Section 3.3 applies to your derivative works, and (b) you identify the specific
41 | derivative works that are subject to Your Terms. Notwithstanding Your Terms,
42 | this license (including the redistribution requirements in Section 3.1) will
43 | continue to apply to the Work itself.
44 |
45 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used
46 | or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA
47 | Corporation and its affiliates may use the Work and any derivative works
48 | commercially. As used herein, "non-commercially" means for research or
49 | evaluation purposes only.
50 |
51 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any
52 | Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to
53 | enforce any patents that you allege are infringed by any Work, then your rights
54 | under this license from such Licensor (including the grant in Section 2.1) will
55 | terminate immediately.
56 |
57 | 3.5 Trademarks. This license does not grant any rights to use any Licensor's or
58 | its affiliates' names, logos, or trademarks, except as necessary to reproduce
59 | the notices described in this license.
60 |
61 | 3.6 Termination. If you violate any term of this license, then your rights
62 | under this license (including the grant in Section 2.1) will terminate
63 | immediately.
64 |
65 | 4. Disclaimer of Warranty.
66 |
67 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
68 | EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
69 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT.
70 | YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
71 |
72 | 5. Limitation of Liability.
73 |
74 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY,
75 | WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY
76 | LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL,
77 | INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE,
78 | THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF
79 | GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR
80 | MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN
81 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
82 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://paperswithcode.com/sota/robot-manipulation-on-rlbench?p=rvt-2-learning-precise-manipulation-from-few)
2 |
3 | [***RVT-2: Learning Precise Manipulation from Few Examples***](https://robotic-view-transformer-2.github.io/)
4 | [Ankit Goyal](http://imankgoyal.github.io), [Valts Blukis](https://www.cs.cornell.edu/~valts/), [Jie Xu](https://people.csail.mit.edu/jiex), [Yijie Guo](https://www.guoyijie.me/), [Yu-Wei Chao](https://research.nvidia.com/person/yu-wei-chao), [Dieter Fox](https://homes.cs.washington.edu/~fox/)
5 | ***RSS 2024***
6 |
7 | [***RVT: Robotic View Transformer for 3D Object Manipulation***](https://robotic-view-transformer.github.io/)
8 | [Ankit Goyal](http://imankgoyal.github.io), [Jie Xu](https://people.csail.mit.edu/jiex), [Yijie Guo](https://www.guoyijie.me/), [Valts Blukis](https://www.cs.cornell.edu/~valts/), [Yu-Wei Chao](https://research.nvidia.com/person/yu-wei-chao), [Dieter Fox](https://homes.cs.washington.edu/~fox/)
9 | ***CoRL 2023 (Oral)***
10 |
11 |
21 | *Demos: RVT-2 solving high precision tasks · a single RVT solving multiple tasks*
29 |
30 | This is the official repository that reproduces the results for [RVT-2](https://robotic-view-transformer-2.github.io/) and [RVT](https://robotic-view-transformer.github.io/). The repository is backward compatible, so you only need to pull the latest commit to switch from RVT to RVT-2!
31 |
32 |
33 | If you find our work useful, please consider citing:
34 | ```
35 | @article{goyal2024rvt2,
36 | title={RVT2: Learning Precise Manipulation from Few Demonstrations},
37 | author={Goyal, Ankit and Blukis, Valts and Xu, Jie and Guo, Yijie and Chao, Yu-Wei and Fox, Dieter},
38 | journal={RSS},
39 | year={2024},
40 | }
41 | @article{goyal2023rvt,
42 | title={RVT: Robotic View Transformer for 3D Object Manipulation},
43 | author={Goyal, Ankit and Xu, Jie and Guo, Yijie and Blukis, Valts and Chao, Yu-Wei and Fox, Dieter},
44 | journal={CoRL},
45 | year={2023}
46 | }
47 | ```
48 |
49 | ## Getting Started
50 |
51 | ### Install
52 | - Tested (Recommended) Versions: Python 3.8. We used CUDA 11.1.
53 |
54 | - **Step 1 (Optional):**
55 | We recommend using [conda](https://docs.conda.io/en/latest/miniconda.html) and creating a virtual environment.
56 | ```
57 | conda create --name rvt python=3.8
58 | conda activate rvt
59 | ```
60 |
61 | - **Step 2:** Install PyTorch. Make sure the PyTorch version is compatible with the CUDA version. One recommended version compatible with CUDA 11.1 and PyTorch3D can be installed with the following command. More instructions to install PyTorch can be found [here](https://pytorch.org/).
62 | ```
63 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
64 | ```
65 |
66 | Recently, we noticed an issue while using conda to install PyTorch. More details can be found [here](https://github.com/pytorch/pytorch/issues/123097). If you face the same issue, you can use the following command to install PyTorch using pip.
67 | ```
68 | pip install torch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 --index-url https://download.pytorch.org/whl/cu113
69 | ```
70 |
71 | - **Step 3:** Install PyTorch3D.
72 |
73 | You can skip this step if you only want to use RVT-2 as it uses our custom Point-Renderer for rendering. PyTorch3D is required for RVT.
74 |
75 | One recommended version that is compatible with the rest of the library can be installed as follows. Note that this might take some time. For more instructions visit [here](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md).
76 | ```
77 | curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz
78 | tar xzf 1.10.0.tar.gz
79 | export CUB_HOME=$(pwd)/cub-1.10.0
80 | pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
81 | ```
82 |
83 | - **Step 4:** Install CoppeliaSim. PyRep requires version **4.1** of CoppeliaSim. Download and unzip CoppeliaSim:
84 | - [Ubuntu 16.04](https://downloads.coppeliarobotics.com/V4_1_0/CoppeliaSim_Player_V4_1_0_Ubuntu16_04.tar.xz)
85 | - [Ubuntu 18.04](https://downloads.coppeliarobotics.com/V4_1_0/CoppeliaSim_Player_V4_1_0_Ubuntu18_04.tar.xz)
86 | - [Ubuntu 20.04](https://downloads.coppeliarobotics.com/V4_1_0/CoppeliaSim_Player_V4_1_0_Ubuntu20_04.tar.xz)
87 |
88 | Once you have downloaded CoppeliaSim, add the following to your *~/.bashrc* file. (__NOTE__: edit the `COPPELIASIM_ROOT` path in the first line to point to your CoppeliaSim install directory.)
89 |
90 | ```
91 | export COPPELIASIM_ROOT=/PATH/TO/COPPELIASIM/INSTALL/DIR
92 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$COPPELIASIM_ROOT
93 | export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
94 | export DISPLAY=:1.0
95 | ```
96 | Remember to source your .bashrc (`source ~/.bashrc`) or .zshrc (`source ~/.zshrc`) after this.
97 |
98 | - **Step 5:** Clone the repository with the submodules using the following command.
99 |
100 | ```
101 | git clone --recurse-submodules git@github.com:NVlabs/RVT.git && cd RVT && git submodule update --init
102 | ```
103 |
104 | Now, locally install the repository. You can either run `pip install -e '.[xformers]'` to install the library with [xformers](https://github.com/facebookresearch/xformers), or `pip install -e .` to install it without xformers. We recommend the former, as it improves speed. However, the installation sometimes fails due to the xformers dependency; in that case, install the library without it. The performance difference between the two is minimal, but training and inference can be slower without xformers.
105 | ```
106 | pip install -e '.[xformers]'
107 | ```
108 |
109 | Install the required libraries for PyRep, RLBench, YARR, PerAct Colab, and Point Renderer.
110 | ```
111 | pip install -e rvt/libs/PyRep
112 | pip install -e rvt/libs/RLBench
113 | pip install -e rvt/libs/YARR
114 | pip install -e rvt/libs/peract_colab
115 | pip install -e rvt/libs/point-renderer
116 | ```
117 |
118 | - **Step 6:** Download the dataset.
119 |     - For experiments on RLBench, we use the [pre-generated dataset](https://drive.google.com/drive/folders/0B2LlLwoO3nfZfkFqMEhXWkxBdjJNNndGYl9uUDQwS1pfNkNHSzFDNGwzd1NnTmlpZXR1bVE?resourcekey=0-jRw5RaXEYRLe2W6aNrNFEQ) provided by [PerAct](https://github.com/peract/peract#download). Please download and place it under `RVT/rvt/data/xxx` where `xxx` is either `train`, `test`, or `val`.
120 |
121 |     - Additionally, we use the same dataloader as PerAct, which is based on [YARR](https://github.com/stepjam/YARR). YARR creates a replay buffer on the fly, which can increase the startup time. We provide an option to directly load a pre-generated replay buffer (98 GB) from disk, which we recommend as it reduces the startup time. You can download the replay buffer for [individual tasks](https://huggingface.co/datasets/ankgoyal/rvt/tree/main/replay). After downloading, uncompress the replay buffer(s) (for example, using the command `tar -xf .tar.xz`) and place them under `RVT/rvt/replay/replay_xxx/` where `xxx` is either `train` or `val`. The replay buffer is only needed if you want to train RVT from scratch; it is not required for evaluating the pre-trained models. A sketch of the resulting layout is shown below.
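
The expected layout after both downloads (sketch; task and episode subfolders omitted):
```
RVT/rvt/data/train/             # pre-generated RLBench demos from PerAct
RVT/rvt/data/val/
RVT/rvt/data/test/
RVT/rvt/replay/replay_train/    # optional pre-generated replay buffers
RVT/rvt/replay/replay_val/
```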
122 |
123 |
124 | ## Using the library
125 |
126 | ### Training
127 | ##### Training RVT-2
128 |
129 | To train RVT-2 on all RLBench tasks, use the following command (from folder `RVT/rvt`):
130 | ```
131 | python train.py --exp_cfg_path configs/rvt2.yaml --mvt_cfg_path mvt/configs/rvt2.yaml --device 0,1,2,3,4,5,6,7
132 | ```
133 |
134 | ##### Training RVT
135 | To train RVT, use the following command (from folder `RVT/rvt`):
136 | ```
137 | python train.py --exp_cfg_path configs/rvt.yaml --device 0,1,2,3,4,5,6,7
138 | ```
139 | We use 8 V100 GPUs. Change the `device` flag depending on available compute.
140 |
141 | ##### More details about `train.py`
142 | - default parameters for an `experiment` are defined [here](https://github.com/NVlabs/RVT/blob/master/rvt/config.py).
143 | - default parameters for `rvt` are defined [here](https://github.com/NVlabs/RVT/blob/master/rvt/mvt/config.py).
144 | - the parameters for `experiment` and `rvt` can be overwritten in two ways:
145 |     - specifying the path of a yaml file
146 |     - manually overwriting them using an `opts` string of space-separated parameter name and value pairs
147 | - Manual overwriting takes precedence over the yaml file.
148 |
149 | ```
150 | python train.py --exp_cfg_opts <> --mvt_cfg_opts <> --exp_cfg_path <> --mvt_cfg_path <>
151 | ```
152 |
153 | The following command overwrites the parameters for the `experiment` with the `configs/rvt.yaml` file. It also overwrites the `bs` parameter through the command line.
154 | ```
155 | python train.py --exp_cfg_opts "bs 4" --exp_cfg_path configs/rvt.yaml --device 0
156 | ```
157 |
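To see how the two override mechanisms compose, here is a minimal sketch using [yacs](https://github.com/rbgirshick/yacs), the config library used in `rvt/config.py`. The exact wiring inside `train.py` may differ, and the import path below assumes the package is importable as `rvt`:
```
from rvt.config import get_cfg_defaults  # experiment defaults from rvt/config.py

cfg = get_cfg_defaults()                      # 1. start from the built-in defaults
cfg.merge_from_file("configs/rvt.yaml")       # 2. apply a yaml file (--exp_cfg_path)
cfg.merge_from_list(["bs", 4, "epochs", 10])  # 3. apply opts pairs (--exp_cfg_opts); highest precedence
cfg.freeze()
print(cfg.bs, cfg.epochs, cfg.peract.lr)
```
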
158 | ### Evaluate on RLBench
159 | ##### Evaluate RVT-2 on RLBench
160 | Download the [pretrained RVT-2 model](https://huggingface.co/ankgoyal/rvt/tree/main/rvt2). Place the model (`model_99.pth` trained for 99 epochs or ~80K steps with batch size 192) and the config files under the folder `RVT/rvt/runs/rvt2/`. Run evaluation using (from folder `RVT/rvt`):
161 | ```
162 | python eval.py --model-folder runs/rvt2 --eval-datafolder ./data/test --tasks all --eval-episodes 25 --log-name test/1 --device 0 --headless --model-name model_99.pth
163 | ```
164 | ##### Evaluate RVT on RLBench
165 | Download the [pretrained RVT model](https://huggingface.co/ankgoyal/rvt/tree/main/rvt). Place the model (`model_14.pth` trained for 15 epochs or 100K steps) and the config files under the folder `runs/rvt/`. Run evaluation using (from folder `RVT/rvt`):
166 | ```
167 | python eval.py --model-folder runs/rvt --eval-datafolder ./data/test --tasks all --eval-episodes 25 --log-name test/1 --device 0 --headless --model-name model_14.pth
168 | ```
169 |
170 | ##### Evaluate the official PerAct model on RLBench
171 | Download the [officially released PerAct model](https://github.com/peract/peract/releases/download/v1.0.0/peract_600k.zip).
172 | Put the downloaded policy under the `runs` folder with the recommended folder layout: `runs/peract_official/seed0`.
173 | Run the evaluation using:
174 | ```
175 | python eval.py --eval-episodes 25 --peract_official --peract_model_dir runs/peract_official/seed0/weights/600000 --model-name QAttentionAgent_layer0.pt --headless --tasks all --eval-datafolder ./data/test --device 0
176 | ```
177 |
178 | ## Gotchas
179 | - If you face issues installing `xformers` or PyTorch3D, the information in this issue might be useful: https://github.com/NVlabs/RVT/issues/45.
180 |
181 | - If you get a Qt plugin error like `qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in ".../cv2/qt/plugins" even though it was found`, try uninstalling opencv-python and installing opencv-python-headless:
182 |
183 | ```
184 | pip uninstall opencv-python
185 | pip install opencv-python-headless
186 | ```
187 |
188 | - If you have CUDA 11.7, an alternative installation strategy is to use the following commands for Step 2 and Step 3. Note that this is not heavily tested.
189 | ```
190 | # Step 2:
191 | pip install torch torchvision torchaudio
192 | # Step 3:
193 | pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
194 | ```
195 |
196 | - If you are having issues running evaluation on a headless server, please refer to https://github.com/NVlabs/RVT/issues/2#issuecomment-1620704943.
197 |
198 | - If you want to generate visualization videos, please refer to https://github.com/NVlabs/RVT/issues/5.
199 |
200 | ## FAQ's
201 | ###### Q. What is the advantage of RVT-2 and RVT over PerAct?
202 | Both RVT and RVT-2 are faster to train and perform better than PerAct.
203 |
204 |
205 | ###### Q. What resources are required to train RVT?
206 | For training on 18 RLBench tasks, with 100 demos per task, we use 8 V100 GPUs (16 GB memory each). The model trains in ~1 day.
207 |
208 | Note that for fair comparison with PerAct, we used the same dataset, which means [duplicate keyframes are loaded into the replay buffer](https://github.com/peract/peract#why-are-duplicate-keyframes-loaded-into-the-replay-buffer). For other datasets, one could consider not doing so, which might further speed up training.
209 |
210 | ###### Q. Why do you use `pe_fix=True` in the rvt [config](https://github.com/NVlabs/RVT/blob/master/rvt/mvt/config.py#L32)?
211 | For a fair comparison with the official PerAct model, we use this setting. More details can be found in the PerAct [code](https://github.com/peract/peract/blob/main/agents/peract_bc/perceiver_lang_io.py#L387-L398). Going forward, we recommend using `pe_fix=False` for language input.
212 |
213 | ###### Q. Why are the results for PerAct different from the PerAct paper?
214 | In the PerAct paper, for each task, the best checkpoint is chosen based on the validation set performance. Hence, the model weights can be different for different tasks. We evaluate PerAct and RVT only on the final checkpoint, so that all tasks are strictly evaluated on the same model weights. Note that only the final model for PerAct has been released officially.
215 |
216 | ###### Q. Why is there variance in performance on RLBench even when evaluating the same checkpoint?
217 | We hypothesize that it is caused by the sampling-based planner used in RLBench, which introduces randomness. Hence, we evaluate each checkpoint 5 times and report the mean and variance.
218 |
219 | ###### Q. Why did you use a cosine decay learning rate scheduler instead of a fixed learning rate schedule as done in PerAct?
220 | We found that the cosine learning rate scheduler led to faster convergence for RVT. Training PerAct with our training hyper-parameters (cosine learning rate scheduler and the same number of iterations) led to worse performance in ~4 days of training time. Hence, for Fig. 1, we used the official hyper-parameters for PerAct.
221 |
222 | ###### Q. For my use case, I want to render images at real camera locations (input camera poses) with PyTorch3D. Is it possible to do so and how can I do that?
223 | Yes, it is possible. A self-sufficient example is provided [here](https://github.com/NVlabs/RVT/issues/9). Depending on your use case, the code may need to be modified. Also note that 3D augmentation cannot be used while rendering images at real camera locations, as it would change the pose of the camera with respect to the point cloud.
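
For orientation, here is a generic, self-contained PyTorch3D sketch (not RVT's own renderer) that rasterizes a point cloud at a user-supplied camera pose, using dummy data. Converting real extrinsics (e.g. from RLBench/OpenCV conventions) into PyTorch3D's camera convention is exactly the part covered in the linked issue:
```
import torch
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    PerspectiveCameras, PointsRasterizationSettings,
    PointsRasterizer, PointsRenderer, AlphaCompositor,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
H = W = 128

# Dummy scene: random points with RGB features (replace with your point cloud).
pc = torch.rand(10000, 3, device=device) - 0.5
rgb = torch.rand(10000, 3, device=device)
pcd = Pointclouds(points=[pc], features=[rgb])

# One camera given as rotation R (1, 3, 3) and translation T (1, 3) in
# PyTorch3D's convention; the identity rotation and 2 m offset are placeholders.
R = torch.eye(3, device=device)[None]
T = torch.tensor([[0.0, 0.0, 2.0]], device=device)
cameras = PerspectiveCameras(
    R=R, T=T,
    focal_length=((110.0, 110.0),), principal_point=((W / 2, H / 2),),
    image_size=((H, W),), in_ndc=False, device=device,
)

renderer = PointsRenderer(
    rasterizer=PointsRasterizer(
        cameras=cameras,
        raster_settings=PointsRasterizationSettings(
            image_size=(H, W), radius=0.01, points_per_pixel=8,
        ),
    ),
    compositor=AlphaCompositor(),
)
images = renderer(pcd)  # (1, H, W, 3) feature image at the given camera pose
```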
224 |
225 | For questions and comments, please contact [Ankit Goyal](https://imankgoyal.github.io/).
226 |
227 | ## Acknowledgement
228 | We sincerely thank the authors of the following repositories for sharing their code.
229 |
230 | - [PerAct](https://github.com/peract/peract)
231 | - [PerAct Colab](https://github.com/peract/peract_colab/tree/master)
232 | - [PyRep](https://github.com/stepjam/PyRep)
233 | - [RLBench](https://github.com/stepjam/RLBench/tree/master)
234 | - [YARR](https://github.com/stepjam/YARR)
235 |
236 | ## License
237 | Copyright © 2023, NVIDIA Corporation & affiliates. All rights reserved.
238 |
239 | This work is made available under the [Nvidia Source Code License](https://github.com/NVlabs/RVT/blob/master/LICENSE).
240 | The pretrained RVT models are released under the CC-BY-NC-SA-4.0 license.
241 |
--------------------------------------------------------------------------------
/rvt/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | from yacs.config import CfgNode as CN
6 |
7 | _C = CN()
8 |
9 | _C.agent = "our"
10 | _C.tasks = "insert_onto_square_peg,open_drawer,place_wine_at_rack_location,light_bulb_in"
11 | _C.exp_id = "def"
12 | _C.resume = ""
13 | # bs per device, effective bs is scaled by num device
14 | _C.bs = 4
15 | _C.epochs = 20
16 | # number of dataloader workers, >= 0
17 | _C.num_workers = 0
18 | # 'transition_uniform' or 'task_uniform'
19 | _C.sample_distribution_mode = 'transition_uniform'
20 | _C.train_iter = 16 * 10000
21 |
22 | # arguments present in both peract and rvt
23 | # some of them do not support every possible combination in peract
24 | _C.peract = CN()
25 | _C.peract.lambda_weight_l2 = 1e-6
26 | # lr should be thought of on a per-sample basis
27 | # effective lr is multiplied by bs * num_devices
28 | _C.peract.lr = 2.5e-5
29 | _C.peract.optimizer_type = "lamb"
30 | _C.peract.warmup_steps = 0
31 | _C.peract.lr_cos_dec = False
32 | _C.peract.add_rgc_loss = True
33 | _C.peract.num_rotation_classes = 72
34 | _C.peract.amp = False
35 | _C.peract.bnb = False
36 | _C.peract.transform_augmentation = True
37 | _C.peract.transform_augmentation_xyz = [0.1, 0.1, 0.1]
38 | _C.peract.transform_augmentation_rpy = [0.0, 0.0, 20.0]
39 |
40 | # arguments present in only rvt and not peract
41 | _C.rvt = CN()
42 | _C.rvt.gt_hm_sigma = 1.5
43 | _C.rvt.img_aug = 0.1
44 | _C.rvt.place_with_mean = True
45 | _C.rvt.move_pc_in_bound = True
46 |
47 | # arguments present in peract official
48 | _C.peract_official = CN()
49 | _C.peract_official.cfg_path = "configs/peract_official_config.yaml"
50 |
51 |
52 | def get_cfg_defaults():
53 |     """Get a yacs CfgNode object with the default experiment values."""
54 |     return _C.clone()
55 |
--------------------------------------------------------------------------------
/rvt/configs/peract_official_config.yaml:
--------------------------------------------------------------------------------
1 | # copied from: https://github.com/peract/peract/releases/download/v1.0.0/peract_600k.zip
2 | method:
3 |   name: PERACT_BC
4 |   lr: 0.0005
5 |   lr_scheduler: false
6 |   num_warmup_steps: 3000
7 |   optimizer: lamb
8 |   activation: lrelu
9 |   norm: None
10 |   lambda_weight_l2: 1.0e-06
11 |   trans_loss_weight: 1.0
12 |   rot_loss_weight: 1.0
13 |   grip_loss_weight: 1.0
14 |   collision_loss_weight: 1.0
15 |   rotation_resolution: 5
16 |   image_crop_size: 64
17 |   bounds_offset:
18 |   - 0.15
19 |   voxel_sizes:
20 |   - 100
21 |   num_latents: 2048
22 |   latent_dim: 512
23 |   transformer_depth: 6
24 |   transformer_iterations: 1
25 |   cross_heads: 1
26 |   cross_dim_head: 64
27 |   latent_heads: 8
28 |   latent_dim_head: 64
29 |   pos_encoding_with_lang: false
30 |   lang_fusion_type: seq
31 |   voxel_patch_size: 5
32 |   voxel_patch_stride: 5
33 |   input_dropout: 0.1
34 |   attn_dropout: 0.1
35 |   decoder_dropout: 0.0
36 |   crop_augmentation: true
37 |   final_dim: 64
38 |   transform_augmentation:
39 |     apply_se3: true
40 |     aug_xyz:
41 |     - 0.125
42 |     - 0.125
43 |     - 0.125
44 |     aug_rpy:
45 |     - 0.0
46 |     - 0.0
47 |     - 0.0
48 |     aug_rot_resolution: 5
49 |   demo_augmentation: true
50 |   demo_augmentation_every_n: 10
51 |   no_skip_connection: false
52 |   no_perceiver: false
53 |   no_language: false
54 |   keypoint_method: heuristic
55 | ddp:
56 |   master_addr: "localhost"
57 |   master_port: "29500"
58 |   num_devices: 1
59 | rlbench:
60 |   task_name: multi
61 |   tasks:
62 |   - change_channel
63 |   - close_jar
64 |   - insert_onto_square_peg
65 |   - light_bulb_in
66 |   - meat_off_grill
67 |   - open_drawer
68 |   - place_cups
69 |   - place_shape_in_shape_sorter
70 |   - push_buttons
71 |   - put_groceries_in_cupboard
72 |   - put_item_in_drawer
73 |   - put_money_in_safe
74 |   - reach_and_drag
75 |   - stack_blocks
76 |   - stack_cups
77 |   - turn_tap
78 |   - set_clock_to_time
79 |   - place_wine_at_rack_location
80 |   - put_rubbish_in_color_bin
81 |   - slide_block_to_color_target
82 |   - sweep_to_dustpan_of_size
83 |   demos: 100
84 |   demo_path: /raid/dataset/
85 |   episode_length: 25
86 |   cameras:
87 |   - front
88 |   - left_shoulder
89 |   - right_shoulder
90 |   - wrist
91 |   camera_resolution:
92 |   - 128
93 |   - 128
94 |   scene_bounds:
95 |   - -0.3
96 |   - -0.5
97 |   - 0.6
98 |   - 0.7
99 |   - 0.5
100 |   - 1.6
101 |   include_lang_goal_in_obs: True
102 | replay:
103 |   batch_size: 16
104 |   timesteps: 1
105 |   prioritisation: false
106 |   task_uniform: true
107 |   use_disk: true
108 |   path: /raid/arm/replay
109 |   max_parallel_processes: 32
110 | framework:
111 |   log_freq: 100
112 |   save_freq: 10000
113 |   train_envs: 1
114 |   replay_ratio: 16
115 |   transitions_before_train: 200
116 |   tensorboard_logging: true
117 |   csv_logging: true
118 |   training_iterations: 600001
119 |   gpu: 0
120 |   env_gpu: 0
121 |   logdir: /home/user/workspace/logs_may16_n100
122 |   seeds: 1
123 |   start_seed: 0
124 |   load_existing_weights: true
125 |   num_weights_to_keep: 60
126 |   record_every_n: 5
127 |
128 |
--------------------------------------------------------------------------------
/rvt/configs/rvt.yaml:
--------------------------------------------------------------------------------
1 | exp_id: rvt
2 | tasks: all
3 | bs: 3
4 | num_workers: 3
5 | epochs: 15
6 | sample_distribution_mode: task_uniform
7 | peract:
8 |   lr: 1e-4
9 |   warmup_steps: 2000
10 |   optimizer_type: lamb
11 |   lr_cos_dec: True
12 |   transform_augmentation_xyz: [0.125, 0.125, 0.125]
13 |   transform_augmentation_rpy: [0.0, 0.0, 45.0]
14 | rvt:
15 |   place_with_mean: False
16 |
--------------------------------------------------------------------------------
/rvt/configs/rvt2.yaml:
--------------------------------------------------------------------------------
1 | exp_id: rvt2
2 | tasks: all
3 | bs: 24
4 | num_workers: 3
5 | epochs: 15
6 | sample_distribution_mode: task_uniform
7 | peract:
8 |   lr: 1.25e-5
9 |   warmup_steps: 2000
10 |   optimizer_type: lamb
11 |   lr_cos_dec: True
12 |   transform_augmentation_xyz: [0.125, 0.125, 0.125]
13 |   transform_augmentation_rpy: [0.0, 0.0, 45.0]
14 |   amp: True
15 |   bnb: True
16 |   lambda_weight_l2: 1e-4
17 | rvt:
18 |   place_with_mean: False
19 |   img_aug: 0.0
20 |
--------------------------------------------------------------------------------
/rvt/libs/peract_colab/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | """
6 | Setup of peract
7 | Author: Ankit Goyal
8 | """
9 | from setuptools import setup
10 |
11 | requirements = [
12 | ]
13 |
14 | setup(
15 | name="peract_colab",
16 | # version=__version__,
17 | long_description="",
18 | url="",
19 | keywords="robotics computer vision",
20 | classifiers=[
21 | "Programming Language :: Python",
22 | ],
23 | packages=["peract_colab"],
24 | install_requires=requirements,
25 | )
26 |
27 |
--------------------------------------------------------------------------------
/rvt/libs/point-renderer/.gitattributes:
--------------------------------------------------------------------------------
1 | pcd_data.tar.gz filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/rvt/libs/point-renderer/.gitignore:
--------------------------------------------------------------------------------
1 | build/*
2 | *.egg-info/*
3 | *.so
4 | */__pycache__/*
5 |
--------------------------------------------------------------------------------
/rvt/libs/point-renderer/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022-2023, NVIDIA Corporation & affiliates. All rights reserved.
2 |
3 |
4 | NVIDIA Source Code License for instant neural graphics primitives
5 |
6 |
7 | =======================================================================
8 |
9 | 1. Definitions
10 |
11 | "Licensor" means any person or entity that distributes its Work.
12 |
13 | "Software" means the original work of authorship made available under
14 | this License.
15 |
16 | "Work" means the Software and any additions to or derivative works of
17 | the Software that are made available under this License.
18 |
19 | The terms "reproduce," "reproduction," "derivative works," and
20 | "distribution" have the meaning as provided under U.S. copyright law;
21 | provided, however, that for the purposes of this License, derivative
22 | works shall not include works that remain separable from, or merely
23 | link (or bind by name) to the interfaces of, the Work.
24 |
25 | Works, including the Software, are "made available" under this License
26 | by including in or with the Work either (a) a copyright notice
27 | referencing the applicability of this License to the Work, or (b) a
28 | copy of this License.
29 |
30 | 2. License Grants
31 |
32 | 2.1 Copyright Grant. Subject to the terms and conditions of this
33 | License, each Licensor grants to you a perpetual, worldwide,
34 | non-exclusive, royalty-free, copyright license to reproduce,
35 | prepare derivative works of, publicly display, publicly perform,
36 | sublicense and distribute its Work and any resulting derivative
37 | works in any form.
38 |
39 | 3. Limitations
40 |
41 | 3.1 Redistribution. You may reproduce or distribute the Work only
42 | if (a) you do so under this License, (b) you include a complete
43 | copy of this License with your distribution, and (c) you retain
44 | without modification any copyright, patent, trademark, or
45 | attribution notices that are present in the Work.
46 |
47 | 3.2 Derivative Works. You may specify that additional or different
48 | terms apply to the use, reproduction, and distribution of your
49 | derivative works of the Work ("Your Terms") only if (a) Your Terms
50 | provide that the use limitation in Section 3.3 applies to your
51 | derivative works, and (b) you identify the specific derivative
52 | works that are subject to Your Terms. Notwithstanding Your Terms,
53 | this License (including the redistribution requirements in Section
54 | 3.1) will continue to apply to the Work itself.
55 |
56 | 3.3 Use Limitation. The Work and any derivative works thereof only
57 | may be used or intended for use non-commercially. Notwithstanding
58 | the foregoing, NVIDIA and its affiliates may use the Work and any
59 | derivative works commercially. As used herein, "non-commercially"
60 | means for research or evaluation purposes only.
61 |
62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim
63 | against any Licensor (including any claim, cross-claim or
64 | counterclaim in a lawsuit) to enforce any patents that you allege
65 | are infringed by any Work, then your rights under this License from
66 | such Licensor (including the grant in Section 2.1) will terminate
67 | immediately.
68 |
69 | 3.5 Trademarks. This License does not grant any rights to use any
70 | Licensor’s or its affiliates’ names, logos, or trademarks, except
71 | as necessary to reproduce the notices described in this License.
72 |
73 | 3.6 Termination. If you violate any term of this License, then your
74 | rights under this License (including the grant in Section 2.1) will
75 | terminate immediately.
76 |
77 | 4. Disclaimer of Warranty.
78 |
79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
83 | THIS LICENSE.
84 |
85 | 5. Limitation of Liability.
86 |
87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
95 | THE POSSIBILITY OF SUCH DAMAGES.
96 |
97 | =======================================================================
98 |
--------------------------------------------------------------------------------
/rvt/libs/point-renderer/README.md:
--------------------------------------------------------------------------------
1 | ## Point Renderer
2 | A minimal, lightweight CUDA-accelerated renderer of pointclouds.
3 |
4 | 
5 |
6 | ### Install
7 |
8 | ```
9 | pip install -r requirements.txt
10 | pip install -e .
11 | ```
12 |
13 | ### Run
14 |
15 | **Load Data**
16 | Extract the included `pcd_data.tar.gz`, then load the data:
17 |
18 | ```
19 | import numpy as np
20 |
21 | data = np.load("pcd_data/w1280_h720/3.npy", allow_pickle=True)
22 | data = data[None][0]
23 | pc = data["pc"]
24 | rgb = data["img_feat"]
25 | ```
26 |
27 | **Render the image**
28 |
29 | ```
30 | # Make the renderer
31 | from point_renderer.renderer import PointRenderer
32 | renderer = PointRenderer(device="cuda", perf_timer=False)
33 |
34 | # Define a batch of cameras
35 | img_size = (512, 512)
36 | K = renderer.get_camera_intrinsics(hfov=70, img_size=img_size)
37 | camera_poses = renderer.get_batch_of_camera_poses(
38 | cam_positions=[[1.5, 1.5, 1.5],[-1.5, -1.5, -1.5]],
39 | cam_lookats=[[0.0, 0.0, 0.0],[0.0, 0.0, 0.0]])
40 |
41 | # Render the pointcloud from the given cameras
42 | images, depths = renderer.render_batch(pc, rgb, camera_poses, K, img_size,
43 | default_color=1.0,
44 | splat_radius=0.005,
45 | aa_factor=2
46 | )
47 |
48 | # Show the results (assumes matplotlib has been imported as plt)
49 | plt.imshow(images[0].detach().cpu().numpy()); plt.show()
50 | plt.imshow(depths[0].detach().cpu().numpy()); plt.show()
51 | plt.imshow(images[1].detach().cpu().numpy()); plt.show()
52 | plt.imshow(depths[1].detach().cpu().numpy()); plt.show()
53 | ```
54 |
55 | Alternatively, run the Jupyter notebook (`pointcloud-notebook.ipynb`), which contains the same code as above along with all the benchmarks.
56 |
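If you want to save the rendered outputs instead of displaying them, something along these lines should work (assuming the RGB features are in `[0, 1]`; rescale first if they are 0-255):
```
import numpy as np
import matplotlib.pyplot as plt

rgb0 = np.clip(images[0].detach().cpu().numpy(), 0.0, 1.0)
plt.imsave("render_rgb_0.png", rgb0)
plt.imsave("render_depth_0.png", depths[0].detach().cpu().numpy())  # depth is colormapped by imsave
```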
--------------------------------------------------------------------------------
/rvt/libs/point-renderer/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RVT/367995a1a2169b6352bf4e8b0ed405890462a3a0/rvt/libs/point-renderer/demo.png
--------------------------------------------------------------------------------
/rvt/libs/point-renderer/image_0_splat_2xaa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RVT/367995a1a2169b6352bf4e8b0ed405890462a3a0/rvt/libs/point-renderer/image_0_splat_2xaa.png
--------------------------------------------------------------------------------
/rvt/libs/point-renderer/pcd_data.tar.gz:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:c47f7c538558e941a82b49c2dfa3f0eb95618a7beaf1cb69e697acb6692983e9
3 | size 12732634
4 |
--------------------------------------------------------------------------------
/rvt/libs/point-renderer/point_renderer/cameras.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from point_renderer import ops
4 | from functools import lru_cache
5 |
6 | @lru_cache(maxsize=32)
7 | def linalg_inv(poses):
8 |     return torch.linalg.inv(poses)
9 | 
10 | class Cameras:
11 |     def __init__(self, poses, intrinsics, img_size, inv_poses=None):
12 |         self.poses = poses
13 |         self.img_size = img_size
14 |         if inv_poses is None:
15 |             self.inv_poses = linalg_inv(poses)
16 |         else:
17 |             self.inv_poses = inv_poses
18 |         self.intrinsics = intrinsics
19 | 
20 |     def __len__(self):
21 |         return len(self.poses)
22 | 
23 |     def scale(self, constant):
24 |         self.intrinsics = self.intrinsics.clone()
25 |         self.intrinsics[:, :2, :3] *= constant
26 | 
27 |     def is_orthographic(self):
28 |         raise ValueError("is_orthographic should be called on child classes only")
29 | 
30 |     def is_perspective(self):
31 |         raise ValueError("is_perspective should be called on child classes only")
32 | 
33 | 
34 | class PerspectiveCameras(Cameras):
35 |     def __init__(self, poses, intrinsics, img_size, inv_poses=None):
36 |         super().__init__(poses, intrinsics, img_size, inv_poses)
37 | 
38 |     @classmethod
39 |     def from_lookat(cls, eyes, ats, ups, hfov, img_size, device="cpu"):
40 |         cam_poses = []
41 |         for eye, at, up in zip(eyes, ats, ups):
42 |             T = ops.lookat_to_cam_pose(eye, at, up, device=device)
43 |             cam_poses.append(T)
44 |         cam_poses = torch.stack(cam_poses, dim=0)
45 |         intrinsics = ops.fov_and_size_to_intrinsics(hfov, img_size, device=device)
46 |         intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous()
47 |         return PerspectiveCameras(cam_poses, intrinsics, img_size)
48 | 
49 |     @classmethod
50 |     def from_rotation_and_translation(cls, R, T, S, hfov, img_size):
51 |         device = R.device
52 |         assert T.device == device
53 |         cam_poses = torch.zeros((R.shape[0], 4, 4), device=device, dtype=torch.float)
54 |         cam_poses[:, :3, :3] = R * S[None, :]
55 |         cam_poses[:, :3, 3] = T
56 |         cam_poses[:, 3, 3] = 1.0
57 |         intrinsics = ops.fov_and_size_to_intrinsics(hfov, img_size, device=device)
58 |         intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous()
59 |         return PerspectiveCameras(cam_poses, intrinsics, img_size)
60 | 
61 |     def to(self, device):
62 |         return PerspectiveCameras(self.poses.to(device), self.intrinsics.to(device), self.img_size, self.inv_poses.to(device))
63 | 
64 |     def is_orthographic(self):
65 |         return False
66 | 
67 |     def is_perspective(self):
68 |         return True
69 | 
70 | class OrthographicCameras(Cameras):
71 |     def __init__(self, poses, intrinsics, img_size, inv_poses=None):
72 |         super().__init__(poses, intrinsics, img_size, inv_poses)
73 | 
74 |     @classmethod
75 |     def from_lookat(cls, eyes, ats, ups, img_sizes_w, img_size_px, device="cpu"):
76 |         """
77 |         Args:
78 |             eyes: Nx3 tensor of camera coordinates
79 |             ats: Nx3 tensor of look-at directions
80 |             ups: Nx3 tensor of up-vectors
81 |             img_sizes_w: Nx2 tensor defining image sizes in world coordinates
82 |             img_size_px: 2-dim tuple defining image size in pixels
83 |         Returns:
84 |             OrthographicCameras
85 |         """
86 |         if isinstance(img_sizes_w, list):
87 |             img_sizes_w = torch.tensor(img_sizes_w, device=device)[None, :].repeat((len(eyes), 1))
88 | 
89 |         cam_poses = []
90 |         for eye, at, up in zip(eyes, ats, ups):
91 |             T = ops.lookat_to_cam_pose(eye, at, up, device=device)
92 |             cam_poses.append(T)
93 |         cam_poses = torch.stack(cam_poses, dim=0)
94 |         intrinsics = ops.orthographic_intrinsics_from_scales(img_sizes_w, img_size_px, device=device)
95 |         return OrthographicCameras(cam_poses, intrinsics, img_size_px)
96 | 
97 |     @classmethod
98 |     def from_rotation_and_translation(cls, R, T, img_sizes_w, img_size_px, device="cpu"):
99 |         if isinstance(img_sizes_w, list):
100 |             img_sizes_w = torch.tensor(img_sizes_w, device=device)[None, :].repeat((len(R), 1))
101 | 
102 |         device = R.device
103 |         assert T.device == device
104 |         cam_poses = torch.zeros((R.shape[0], 4, 4), device=device, dtype=torch.float)
105 |         cam_poses[:, :3, :3] = R
106 |         cam_poses[:, :3, 3] = T
107 |         cam_poses[:, 3, 3] = 1.0
108 |         intrinsics = ops.orthographic_intrinsics_from_scales(img_sizes_w, img_size_px, device=device)
109 |         intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous()
110 |         return OrthographicCameras(cam_poses, intrinsics, img_size_px)
111 | 
112 |     def to(self, device):
113 |         return OrthographicCameras(self.poses.to(device), self.intrinsics.to(device), self.img_size, self.inv_poses.to(device))
114 | 
115 |     def is_orthographic(self):
116 |         return True
117 | 
118 |     def is_perspective(self):
119 |         return False
120 | 
--------------------------------------------------------------------------------
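
A minimal usage sketch for the camera classes above (illustration only; the look-at values are arbitrary, and `hfov` is assumed to be in degrees, consistent with the `hfov=70` used in the point-renderer README):
```
import torch
from point_renderer.cameras import PerspectiveCameras, OrthographicCameras

# Two cameras looking at the origin from different positions.
eyes = torch.tensor([[1.5, 1.5, 1.5], [-1.5, -1.5, 1.5]])
ats = torch.zeros(2, 3)
ups = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])

persp = PerspectiveCameras.from_lookat(eyes, ats, ups, hfov=70, img_size=(512, 512))
print(len(persp), persp.poses.shape)  # 2 cameras, one 4x4 pose each

# Orthographic cameras additionally take the image extent in world units.
ortho = OrthographicCameras.from_lookat(
    eyes, ats, ups, img_sizes_w=[1.0, 1.0], img_size_px=(220, 220))
print(ortho.is_orthographic(), ortho.intrinsics.shape)
```
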
/rvt/libs/point-renderer/point_renderer/csrc/bindings.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION and its licensors retain all intellectual property
5 | * and proprietary rights in and to this software, related documentation
6 | * and any modifications thereto. Any use, reproduction, disclosure or
7 | * distribution of this software and related documentation without an express
8 | * license agreement from NVIDIA CORPORATION is strictly prohibited.
9 | */
10 |
11 | /** @file bindings.cpp
12 | * @author Valts Blukis, NVIDIA
13 | * @brief PyTorch bindings for pointcloud renderer
14 | */
15 |
16 | #include <torch/extension.h>
17 | #include "./render/render_pointcloud.h"
18 |
19 | namespace rvt {
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | py::module render = m.def_submodule("render");
23 | render.def("render_feature_pointcloud_to_image", &render_feature_pointcloud_to_image);
24 | render.def("screen_space_splatting", &screen_space_splatting);
25 | render.def("aa_subsample", &aa_subsample);
26 | }
27 |
28 | }
29 |
30 |
--------------------------------------------------------------------------------
/rvt/libs/point-renderer/point_renderer/csrc/render/render_feature_pointcloud.cu:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION and its licensors retain all intellectual property
5 | * and proprietary rights in and to this software, related documentation
6 | * and any modifications thereto. Any use, reproduction, disclosure or
7 | * distribution of this software and related documentation without an express
8 | * license agreement from NVIDIA CORPORATION is strictly prohibited.
9 | */
10 |
11 | /** @file render_feature_pointcloud.cu
12 | * @author Valts Blukis, NVIDIA
13 | * @brief Renders a point cloud with associated feature vectors to image
14 | */
15 |
16 | #include <torch/extension.h>   // at::Tensor
17 | #include <ATen/ATen.h>
18 | #include <cuda.h>
19 | #include <cuda_runtime.h>      // cudaMalloc / cudaMemset / cudaSetDevice / cudaFree
20 | #include <cstdint>
21 | #include <stdexcept>           // std::runtime_error
22 | #include "render_pointcloud.h"
23 |
24 |
25 | namespace rvt {
26 |
27 |
28 | // By what factor to scale depth for integer representation
29 | __constant__ const float DEPTH_FACTOR = 1000;
30 | __constant__ const float DEPTH_INV_FACTOR = 1 / DEPTH_FACTOR;
31 |
32 |
33 | __global__ void render_pointcloud_to_depth_index_buffer_cuda_kernel(
34 | int64_t num_points,
35 | int64_t img_height,
36 | int64_t img_width,
37 |
38 | int64_t* point_indices,
39 | float* point_depths,
40 |
41 | uint64_t* packed_buffer
42 | ){
43 | uint32_t tidx = blockDim.x * blockIdx.x + threadIdx.x;
44 |
45 | if (tidx < num_points) {
46 | int64_t pixel_index = point_indices[tidx];
47 |
48 | if (pixel_index >= 0 && pixel_index < img_height * img_width) {
49 | float point_depth = point_depths[tidx];
50 | uint32_t point_depth_mm = (uint32_t) (point_depth * DEPTH_FACTOR);
51 | uint64_t packed_depth_and_index = ((uint64_t) point_depth_mm << 32) | ((uint64_t) tidx);
52 | atomicMin((unsigned long long*) (packed_buffer + pixel_index), (unsigned long long)packed_depth_and_index);
53 | }
54 | }
55 | }
56 |
57 |
58 | __global__ void output_render_to_feature_image_cuda_kernel(
59 | int64_t img_height,
60 | int64_t img_width,
61 | int64_t num_channels,
62 | int32_t num_points,
63 | float* point_features,
64 | uint64_t* packed_buffer,
65 | float* image_out,
66 | float* depth_out,
67 | float default_depth,
68 | float default_feature
69 | ) {
70 |
71 | uint tidx = blockDim.x * blockIdx.x + threadIdx.x;
72 | if (tidx < img_height * img_width) {
73 | uint64_t packed = packed_buffer[tidx];
74 | uint32_t packed_depth_mm = (uint32_t) (packed >> 32);
75 | // The modulo is to support batching without having to tile the features
76 | uint32_t packed_index = ((uint32_t) packed) % num_points;
77 |
78 | if (packed_depth_mm == 0xFFFFFFFF) {
79 | depth_out[tidx] = default_depth;
80 | for (int i = 0; i < num_channels; i++) {
81 | image_out[tidx * num_channels + i] = default_feature;
82 | }
83 | }
84 | else {
85 | depth_out[tidx] = (float) (packed_depth_mm) * DEPTH_INV_FACTOR;
86 | for (int i = 0; i < num_channels; i++) {
87 | image_out[tidx * num_channels + i] = point_features[packed_index * num_channels + i];
88 | }
89 | }
90 | }
91 | }
92 |
93 |
94 | __global__ void output_render_to_feature_image_cuda_2d_kernel(
95 | int64_t img_height,
96 | int64_t img_width,
97 | int64_t num_channels,
98 | int32_t num_points,
99 | float* point_features,
100 | uint64_t* packed_buffer,
101 | float* image_out,
102 | float* depth_out,
103 | float default_depth,
104 | float default_feature
105 | ) {
106 | uint pixidx = blockDim.x * blockIdx.x + threadIdx.x;
107 | uint cidx = blockDim.y * blockIdx.y + threadIdx.y;
108 |
109 | if (pixidx < img_height * img_width && cidx < num_channels) {
110 | uint64_t packed = packed_buffer[pixidx];
111 | uint32_t packed_depth_mm = (uint32_t) (packed >> 32);
112 | // The modulo is to support batching without having to tile the features
113 | uint32_t packed_index = ((uint32_t) packed) % num_points;
114 |
115 | if (packed_depth_mm == 0xFFFFFFFF) {
116 | if (cidx == 0)
117 | depth_out[pixidx] = default_depth;
118 | image_out[pixidx * num_channels + cidx] = default_feature;
119 | }
120 | else {
121 | if (cidx == 0)
122 | depth_out[pixidx] = (float) (packed_depth_mm) * DEPTH_INV_FACTOR;
123 | image_out[pixidx * num_channels + cidx] = point_features[packed_index * num_channels + cidx];
124 | }
125 | }
126 | }
127 |
128 |
129 | void render_feature_pointcloud_to_image(
130 | at::Tensor point_indices, // Index into flattened image. -1 if out of bounds.
131 | at::Tensor point_depths,
132 | at::Tensor point_features,
133 | at::Tensor image_out,
134 | at::Tensor depth_out,
135 | float default_depth,
136 | float default_color) {
137 |
138 | int64_t num_points = point_indices.size(0);
139 | int32_t num_points_per_batch = point_features.size(0);
140 | int64_t img_height = image_out.size(0);
141 | int64_t img_width = image_out.size(1);
142 | int64_t num_channels = image_out.size(2);
143 |
144 | if (num_channels != point_features.size(1)) {
145 | throw std::runtime_error("Output image and point features must have the same channel dimension");
146 | }
147 |
148 | // TODO: Play with this to see if we can speed it up
149 | uint64_t num_threads_per_block = 1024;
150 |
151 | // Make sure cudaMalloc uses the correct device
152 | int device_index = point_indices.get_device();
153 | cudaSetDevice(device_index);
154 |
155 | // Allocate memory for storing packed depths and colors
156 | uint64_t* packed_depths_and_indices;
157 | cudaMalloc((void**) &packed_depths_and_indices, img_width*img_height*sizeof(uint64_t));
158 | cudaMemset(packed_depths_and_indices, 0xFFFFFFFFFFFFFFFF, img_width*img_height*sizeof(uint64_t));
159 |
160 | render_pointcloud_to_depth_index_buffer_cuda_kernel<<<(num_points + num_threads_per_block - 1) / num_threads_per_block, num_threads_per_block>>>(
161 | num_points,
162 | img_height,
163 | img_width,
164 |
165 | point_indices.data_ptr<int64_t>(),
166 | point_depths.data_ptr<float>(),
167 |
168 | packed_depths_and_indices);
169 |
170 | // With few channels, it's faster to launch a thread per pixel, in each thread looping over the channels and copying the data
171 | if (num_channels < 10)
172 | {
173 | output_render_to_feature_image_cuda_kernel<<<(img_height * img_width + num_threads_per_block - 1) / num_threads_per_block, num_threads_per_block>>>(
174 | img_height,
175 | img_width,
176 | num_channels,
177 | num_points_per_batch,
178 | point_features.data_ptr<float>(),
179 | packed_depths_and_indices,
180 | image_out.data_ptr<float>(),
181 | depth_out.data_ptr<float>(),
182 | default_depth,
183 | default_color);
184 | }
185 | // With more channels, it's better to launch a separate thread per pixel per channel, in each thread copying only one feature scalar
186 | else
187 | {
188 | output_render_to_feature_image_cuda_2d_kernel<<<dim3((img_height * img_width + num_threads_per_block - 1) / num_threads_per_block, num_channels), dim3(num_threads_per_block, 1)>>>(  // assumed launch config: threads over pixels (x), one block row per channel (y)
189 | img_height,
190 | img_width,
191 | num_channels,
192 | num_points_per_batch,
193 | point_features.data_ptr<float>(),
194 | packed_depths_and_indices,
195 | image_out.data_ptr<float>(),
196 | depth_out.data_ptr<float>(),
197 | default_depth,
198 | default_color);
199 | }
200 |
201 | cudaFree(packed_depths_and_indices);
202 | }
203 |
204 |
205 | __global__ void screen_space_splatting_cuda_kernel(
206 | int64_t batch_size,
207 | int64_t img_height,
208 | int64_t img_width,
209 | int64_t num_channels,
210 | int k,
211 | float splat_radius,
212 | float focal_length_px,
213 | float* depth_in,
214 | float* image_in,
215 | float* depth_out,
216 | float* image_out,
217 | float default_depth,
218 | bool orthographic
219 | ) {
220 | uint x = blockDim.x * blockIdx.x + threadIdx.x;
221 | uint y = blockDim.y * blockIdx.y + threadIdx.y;
222 | uint c_index = y * img_width + x;
223 | uint batch_elem_height = img_height / batch_size;
224 |
225 | if (y < img_height && x < img_width) {
226 | float center_depth = depth_in[c_index];
227 | float min_depth = center_depth;
228 | int splat_index = c_index;
229 |
230 | int b_elem = y / batch_elem_height;
231 |
232 | // Loop over pixel's neighbourhood
233 | for (int dx = -k/2; dx <= k/2; dx++) {
234 | for (int dy = -k/2; dy <= k/2; dy++) {
235 | int nx = x + dx;
236 | int ny = y + dy;
237 | if (nx >= img_width || nx < 0 || ny >= (b_elem + 1) * batch_elem_height || ny < b_elem * batch_elem_height) {
238 | continue;
239 | }
240 | // ignore the center pixel itself
241 | /*if (dx == 0 && dy == 0) {
242 | continue;
243 | }*/
244 | int n_index = ny * img_width + nx;
245 |
246 | // Compute neighbor's splat size in pixels
247 | float neighbor_depth = depth_in[n_index];
248 |
249 | // If neighbor is further than current center value, or is unobserved, ignore it
250 | if (neighbor_depth == default_depth || (neighbor_depth > min_depth && min_depth != default_depth)) {
251 | continue;
252 | }
253 |
254 | // Otherwise neighbor is closer to camera than center. Consider it.
255 | float n_splat_size_px;
256 | if (orthographic) {
257 | n_splat_size_px = focal_length_px * splat_radius;
258 | } else {
259 | n_splat_size_px = focal_length_px * splat_radius / neighbor_depth;
260 | }
261 |
262 | float n_dst = sqrt((float)(dx * dx + dy * dy));
263 | // If the splat is big enough to cover the center pixel, remember it
264 | if (n_splat_size_px > n_dst && neighbor_depth) {
265 | splat_index = n_index;
266 | min_depth = neighbor_depth;
267 | }
268 | }
269 | }
270 |
271 | // TODO: we can consider applying some blending instead of just a harsh copy
272 | depth_out[c_index] = depth_in[splat_index];
273 | for (int i = 0; i < num_channels; i++) {
274 | image_out[c_index * num_channels + i] = image_in[splat_index * num_channels + i];
275 | }
276 | }
277 | }
278 |
279 |
280 | void screen_space_splatting(
281 | int batch_size,
282 | at::Tensor depth_in,
283 | at::Tensor image_in,
284 | at::Tensor depth_out,
285 | at::Tensor image_out,
286 | float default_depth,
287 | float focal_length_px,
288 | float splat_radius,
289 | int kernel_size,
290 | bool orthographic
291 | ){
292 | int64_t img_height = image_out.size(0);
293 | int64_t img_width = image_out.size(1);
294 | int64_t num_channels = image_out.size(2);
295 |
296 | int num_threads_per_block_x = 16;
297 | screen_space_splatting_cuda_kernel<<<  // assumed 2D launch config: 16x16 thread tiles covering the image
298 | dim3((img_width + num_threads_per_block_x - 1) / num_threads_per_block_x, (img_height + num_threads_per_block_x - 1) / num_threads_per_block_x),
299 | dim3(num_threads_per_block_x, num_threads_per_block_x)>>>(
300 | batch_size,
301 | img_height,
302 | img_width,
303 | num_channels,
304 | kernel_size,
305 | splat_radius,
306 | focal_length_px,
307 | depth_in.data_ptr<float>(),
308 | image_in.data_ptr<float>(),
309 | depth_out.data_ptr<float>(),
310 | image_out.data_ptr<float>(),
311 | default_depth,
312 | orthographic
313 | );
314 | }
315 |
316 |
317 | __global__ void aa_subsampling_kernel(
318 | int64_t batch_size,
319 | int64_t img_height,
320 | int64_t img_width,
321 | int64_t num_channels,
322 | int aa_factor,
323 | float* depth_in,
324 | float* image_in,
325 | float* depth_out,
326 | float* image_out,
327 | float default_depth
328 | ) {
329 | uint x = blockDim.x * blockIdx.x + threadIdx.x;
330 | uint y = blockDim.y * blockIdx.y + threadIdx.y;
331 | uint b = blockDim.z * blockIdx.z + threadIdx.z;
332 | uint out_index = b * img_height * img_width + y * img_width + x;
333 |
334 | uint aa_img_width = img_width * aa_factor;
335 | uint aa_img_height = img_height * aa_factor;
336 |
337 | if (y < img_height && x < img_width && b < batch_size) {
338 | // Loop over pixel's corresponding patch in the input image
339 | for (int dx = 0; dx < aa_factor; dx++) {
340 | for (int dy = 0; dy < aa_factor; dy++) {
341 | int ox = x * aa_factor + dx;
342 | int oy = y * aa_factor + dy;
343 | if (ox >= aa_img_width|| ox < 0 || oy >= aa_img_height|| oy < 0) {
344 | continue;
345 | }
346 | int in_index = b * aa_img_height * aa_img_width + oy * aa_img_width + ox;
347 |
348 | // Average color over all pixels
349 | for (int i = 0; i < num_channels; i++) {
350 | image_out[out_index * num_channels + i] += image_in[in_index * num_channels + i];
351 | }
352 | // Take min depth across all pixels (median would be better, but that needs more memory to order the values)
353 | float d_in = depth_in[in_index];
354 | float d_out = depth_out[out_index];
355 | depth_out[out_index] = d_in;
356 |
357 | // TODO: The epsilon is hard-coded here, for some applications this might be a problem
358 | if (fabsf(d_in - default_depth) > 1e-6) { //d_in != default_depth
359 | if (fabsf(d_out - default_depth) < 1e-6) { //d_out == default_depth
360 | depth_out[out_index] = d_in;
361 | }
362 | else {
363 | depth_out[out_index] = min(d_out, d_in);
364 | }
365 | }
366 | }
367 | }
368 |
369 | for (int i = 0; i < num_channels; i++) {
370 | image_out[out_index * num_channels + i] = image_out[out_index * num_channels + i] / (aa_factor * aa_factor);
371 | }
372 | }
373 | }
374 |
375 |
376 | void aa_subsample(
377 | at::Tensor depth_in,
378 | at::Tensor image_in,
379 | at::Tensor depth_out,
380 | at::Tensor image_out,
381 | int aa_factor,
382 | float default_depth
383 | ) {
384 | int64_t batch_size = image_out.size(0);
385 | int64_t img_height = image_out.size(1);
386 | int64_t img_width = image_out.size(2);
387 | int64_t num_channels = image_out.size(3);
388 |
389 | int num_threads_per_block_x = 32;
390 |
391 | aa_subsampling_kernel<<<
392 | dim3((img_width + num_threads_per_block_x - 1) / num_threads_per_block_x,
393 | (img_height + num_threads_per_block_x - 1) / num_threads_per_block_x,
394 | batch_size),
395 | dim3(num_threads_per_block_x, num_threads_per_block_x, 1)>>>(
396 | batch_size,
397 | img_height,
398 | img_width,
399 | num_channels,
400 | aa_factor,
401 | depth_in.data_ptr<float>(),
402 | image_in.data_ptr<float>(),
403 | depth_out.data_ptr<float>(),
404 | image_out.data_ptr<float>(),
405 | default_depth
406 | );
407 | }
408 |
409 |
410 | } // namespace rvt
411 |
--------------------------------------------------------------------------------
/rvt/libs/point-renderer/point_renderer/csrc/render/render_pointcloud.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * NVIDIA CORPORATION and its licensors retain all intellectual property
5 | * and proprietary rights in and to this software, related documentation
6 | * and any modifications thereto. Any use, reproduction, disclosure or
7 | * distribution of this software and related documentation without an express
8 | * license agreement from NVIDIA CORPORATION is strictly prohibited.
9 | */
10 |
11 | /** @file render_pointcloud.h
12 | * @author Valts Blukis, NVIDIA
13 | * @brief Header for pointcloud rendering
14 | */
15 |
16 | #pragma once
17 |
18 | #include <ATen/ATen.h>
19 |
20 |
21 | namespace rvt {
22 |
23 | void render_feature_pointcloud_to_image(
24 | at::Tensor point_indices, // Index into flattened image. -1 if out of bounds.
25 | at::Tensor point_depths,
26 | at::Tensor point_features,
27 | at::Tensor image_out,
28 | at::Tensor depth_out,
29 | float default_depth,
30 | float default_color);
31 |
32 | void screen_space_splatting(
33 | int batch_size,
34 | at::Tensor depth_in,
35 | at::Tensor image_in,
36 | at::Tensor depth_out,
37 | at::Tensor image_out,
38 | float default_depth,
39 | float focal_length_px,
40 | float splat_radius,
41 | int kernel_size,
42 | bool orthographic
43 | );
44 |
45 | void aa_subsample(
46 | at::Tensor depth_in,
47 | at::Tensor image_in,
48 | at::Tensor depth_out,
49 | at::Tensor image_out,
50 | int aa_factor,
51 | float default_depth
52 | );
53 |
54 | } // namespace rvt
--------------------------------------------------------------------------------
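
The three entry points declared above are what the Python side reaches through the compiled extension (imported as point_renderer._C.render in renderer.py later in this dump). Below is a rough, illustrative sketch of invoking the raw splatting op directly, assuming the extension has been built and a CUDA device is available; the buffer shapes follow the stacked (batch*H, W[, C]) layout used by PointRenderer, and the focal length and kernel size values are arbitrary:

import torch
import point_renderer._C.render as r

h, w, c = 128, 128, 3
depth = torch.full((h, w), -1.0, device="cuda")   # stacked (batch*H, W) depth buffer
image = torch.zeros((h, w, c), device="cuda")     # stacked (batch*H, W, C) feature buffer

# In-place screen-space splatting; the input buffers are cloned so reads and writes don't alias
r.screen_space_splatting(
    1,               # batch_size
    depth.clone(),   # depth_in
    image.clone(),   # image_in
    depth,           # depth_out
    image,           # image_out
    -1.0,            # default_depth
    64.0,            # focal_length_px
    0.012,           # splat_radius
    7,               # kernel_size
    True)            # orthographic
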
/rvt/libs/point-renderer/point_renderer/ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | import math
13 | from transforms3d import euler, quaternions, affines
14 | import numpy as np
15 |
16 |
17 | def transform_points_batch(pc : torch.Tensor, inv_cam_poses : torch.Tensor):
18 | pc_h = torch.cat([pc, torch.ones_like(pc[:, 0:1])], dim=1)
19 | pc_cam_h = torch.einsum("bxy,ny->bnx", inv_cam_poses, pc_h)
20 | pc_cam = pc_cam_h[:, :, :3]
21 | return pc_cam
22 |
23 | def transform_points(pc : torch.Tensor, inv_cam_pose : torch.Tensor):
24 | pc_h = torch.cat([pc, torch.ones_like(pc[:, 0:1])], dim=1)
25 | pc_cam_h = torch.einsum("xy,ny->nx", inv_cam_pose, pc_h)
26 | pc_cam = pc_cam_h[:, :3]
27 | return pc_cam
28 |
29 | def orthographic_camera_projection_batch(pc_cam : torch.Tensor, K : torch.Tensor):
30 | # For orthographic camera projection, treat all points as if they are at depth 1
31 | uvZ = torch.einsum("bxy,bny->bnx", K, torch.cat([pc_cam[:, :, :2], torch.ones_like(pc_cam[:, :, 2:3])], dim=2))
32 | return uvZ[:, :, :2]
33 |
34 | def orthographic_camera_projection(pc_cam : torch.Tensor, K : torch.Tensor):
35 | # For orthographic camera projection, treat all points as if they are at depth 1
36 | uvZ = torch.einsum("xy,ny->nx", K, torch.cat([pc_cam[:, :2], torch.ones_like(pc_cam[:, 2:3])], dim=1))
37 | return uvZ[:, :2]
38 |
39 | def perspective_camera_projection_batch(pc_cam : torch.Tensor, K : torch.Tensor):
40 | uvZ = torch.einsum("bxy,bny->bnx", K, pc_cam)
41 | uv = torch.stack([uvZ[:, :, 0] / uvZ[:, :, 2], uvZ[:, :, 1] / uvZ[:, :, 2]], dim=2)
42 | return uv
43 |
44 | def perspective_camera_projection(pc_cam : torch.Tensor, K : torch.Tensor):
45 | uvZ = torch.einsum("xy,ny->nx", K, pc_cam)
46 | uv = torch.stack([uvZ[:, 0] / uvZ[:, 2], uvZ[:, 1] / uvZ[:, 2]], dim=1)
47 | return uv
48 |
49 | def project_points_3d_to_pixels(pc : torch.Tensor, inv_cam_poses : torch.Tensor, intrinsics : torch.Tensor, orthographic : bool):
50 | """
51 | This combines the projection from 3D coordinates to camera coordinates using extrinsics,
52 | followed by projection from camera coordinates to pixel coordinates using the intrinsics.
53 | """
54 | # Project points from world to camera frame
55 | pc_cam = transform_points_batch(pc, inv_cam_poses)
56 | # Project points from camera frame to pixel space
57 | if orthographic:
58 | pc_px = orthographic_camera_projection_batch(pc_cam, intrinsics)
59 | else:
60 | pc_px = perspective_camera_projection_batch(pc_cam, intrinsics)
61 | return pc_px, pc_cam
62 |
63 |
64 | def get_batch_pixel_index(pc_px : torch.Tensor, img_height : int, img_width : int):
65 | """
66 | Convert a 2D pixel coordinate from a batch of pointclouds to an index
67 | that indexes into a corresponding flattened batch of 2D images.
68 | """
69 | # batch_idx
70 | batch_idx = torch.arange(pc_px.shape[0], device=pc_px.device, dtype=torch.long)[:, None]#.repeat((1, pc_px.shape[1])).
71 | pix_off = 0.0
72 | row_idx = (pc_px[:, :, 1] + pix_off).long()
73 | col_idx = (pc_px[:, :, 0] + pix_off).long()
74 | pixel_index = batch_idx * img_height * img_width + row_idx * img_width + col_idx
75 | return pixel_index
76 |
77 | def get_pixel_index(pc_px : torch.Tensor, img_width : int):
78 | pix_off = 0.0
79 | row_idx = (pc_px[:, 1] + pix_off).long()
80 | col_idx = (pc_px[:, 0] + pix_off).long()
81 | pixel_index = row_idx * img_width + col_idx
82 | return pixel_index.long()
83 |
84 |
85 | def batch_frustrum_mask(pc_px : torch.Tensor, img_height : int, img_width : int):
86 | imask_x = torch.logical_and(pc_px[:, :, 0] >= 0, pc_px[:, :, 0] < img_width)
87 | imask_y = torch.logical_and(pc_px[:, :, 1] >= 0, pc_px[:, :, 1] < img_height)
88 | imask = torch.logical_and(imask_x, imask_y)
89 | return imask
90 |
91 | def frustrum_mask(pc_px : torch.Tensor, img_height : int, img_width : int):
92 | imask_x = torch.logical_and(pc_px[:, 0] >= 0, pc_px[:, 0] < img_width)
93 | imask_y = torch.logical_and(pc_px[:, 1] >= 0, pc_px[:, 1] < img_height)
94 | imask = torch.logical_and(imask_x, imask_y)
95 | return imask
96 |
97 | def lookat_to_cam_pose(eye, at, up=[0, 0, 1], device="cpu"):
98 | # This runs on CPU, moving to GPU at the end (that's faster)
99 | eye = torch.tensor(eye, device="cpu", dtype=torch.float32)
100 | at = torch.tensor(at, device="cpu", dtype=torch.float32)
101 |
102 | camera_view = F.normalize(at - eye, dim=0)
103 | camera_right = F.normalize(torch.cross(camera_view, torch.tensor(up, dtype=torch.float32, device="cpu"), dim=0), dim=0)
104 | camera_up = F.normalize(torch.cross(camera_right, camera_view, dim=0), dim=0)
105 |
106 | # rotation matrix from opencv conventions
107 | R = torch.stack([camera_right, -camera_up, camera_view], dim=0).T
108 |
109 | T = torch.from_numpy(affines.compose(eye, R, [1, 1, 1]))
110 | return T.float().to(device)
111 |
112 | def fov_and_size_to_intrinsics(fov, img_size, device="cpu"):
113 | img_h, img_w = img_size
114 | fx = img_w / (2 * math.tan(math.radians(fov) / 2))
115 | fy = img_h / (2 * math.tan(math.radians(fov) / 2))
116 |
117 | intrinsics = torch.tensor([
118 | [fx, 0, img_h / 2],
119 | [0, fy, img_w / 2],
120 | [0, 0, 1]
121 | ], dtype=torch.float, device=device)
122 | return intrinsics
123 |
124 | def orthographic_intrinsics_from_scales(img_sizes_w, img_size_px, device="cpu"):
125 | img_h, img_w = img_size_px
126 | fx = img_h / (img_sizes_w[:, 0])
127 | fy = img_w / (img_sizes_w[:, 1])
128 |
129 | intrinsics = torch.zeros([len(img_sizes_w), 3, 3], dtype=torch.float, device=device)
130 | intrinsics[:, 0, 0] = fx
131 | intrinsics[:, 1, 1] = fy
132 | intrinsics[:, 0, 2] = img_h / 2
133 | intrinsics[:, 1, 2] = img_w / 2
134 | #intrinsics[:, 2, 2] = 1.0
135 | return intrinsics
136 |
137 |
138 | def unravel_index(pixel_index : torch.Tensor, img_width : int):
139 | row_idx = pixel_index // img_width
140 | col_idx = pixel_index % img_width
141 | return torch.stack([col_idx, row_idx], dim=1)
142 |
143 |
--------------------------------------------------------------------------------
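
A minimal sketch of how the projection helpers in ops.py compose, assuming point_renderer is importable; everything here is plain tensor math, so no GPU is required:

import torch
import point_renderer.ops as ops

# A toy point cloud of 1000 points in world coordinates
pc = torch.rand(1000, 3) * 2 - 1

# One camera at (1, 1, 1) looking at the origin; lookat_to_cam_pose returns a 4x4 camera-to-world pose
cam_pose = ops.lookat_to_cam_pose(eye=[1.0, 1.0, 1.0], at=[0.0, 0.0, 0.0])
inv_pose = torch.inverse(cam_pose)                                   # world-to-camera extrinsics
K = ops.fov_and_size_to_intrinsics(fov=70, img_size=(128, 128))      # pinhole intrinsics

# The batched projection expects [num_cams, 4, 4] poses and [num_cams, 3, 3] intrinsics
pc_px, pc_cam = ops.project_points_3d_to_pixels(pc, inv_pose[None], K[None], orthographic=False)

# Keep only points that land inside the image and flatten their pixel coordinates
mask = ops.batch_frustrum_mask(pc_px, img_height=128, img_width=128)
pixel_index = ops.get_batch_pixel_index(pc_px, img_height=128, img_width=128)
print(int(mask.sum()), "of", pc.shape[0], "points project inside the 128x128 image")
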
/rvt/libs/point-renderer/point_renderer/profiler.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 | #
3 | # Copyright (c) 2021, NVIDIA CORPORATION.
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | # this software and associated documentation files (the "Software"), to deal in
7 | # the Software without restriction, including without limitation the rights to
8 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | # the Software, and to permit persons to whom the Software is furnished to do so,
10 | # subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
17 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
18 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
19 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
20 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21 |
22 |
23 | import time
24 | import torch
25 | import glob
26 | import os
27 |
28 | class bcolors:
29 | HEADER = '\033[95m'
30 | OKBLUE = '\033[94m'
31 | OKGREEN = '\033[92m'
32 | WARNING = '\033[93m'
33 | FAIL = '\033[91m'
34 | ENDC = '\033[0m'
35 | BOLD = '\033[1m'
36 | UNDERLINE = '\033[4m'
37 |
38 | def colorize_time(elapsed):
39 | if elapsed > 1e-3:
40 | return bcolors.FAIL + "{:.3e}".format(elapsed) + bcolors.ENDC
41 | elif elapsed > 1e-4:
42 | return bcolors.WARNING + "{:.3e}".format(elapsed) + bcolors.ENDC
43 | elif elapsed > 1e-5:
44 | return bcolors.OKBLUE + "{:.3e}".format(elapsed) + bcolors.ENDC
45 | else:
46 | return "{:.3e}".format(elapsed)
47 |
48 | def print_gpu_memory():
49 | torch.cuda.empty_cache()
50 | print(f"{torch.cuda.memory_allocated()//(1024*1024)} mb")
51 |
52 |
53 | class PerfTimer():
54 | def __init__(self, activate=False, show_memory=False, print_mode=True):
55 | self.activate = activate
56 | if activate:
57 | self.show_memory = show_memory
58 | self.print_mode = print_mode
59 | self.init()
60 |
61 | def init(self):
62 | self.reset()
63 | self.loop_totals_cpu = {}
64 | self.loop_totals_gpu = {}
65 | self.loop_counts = {}
66 |
67 |
68 | def reset(self):
69 | if self.activate:
70 | self.counter = 0
71 | self.prev_time = time.perf_counter()
72 | self.start = torch.cuda.Event(enable_timing=True)
73 | self.end = torch.cuda.Event(enable_timing=True)
74 | self.prev_time_gpu = self.start.record()
75 |
76 | def check(self, name=None):
77 | if self.activate:
78 | cpu_time = time.perf_counter() - self.prev_time
79 |
80 | self.end.record()
81 | torch.cuda.synchronize()
82 |
83 | gpu_time = self.start.elapsed_time(self.end) / 1e3
84 |
85 | # Keep track of averages. For this to work, keys need to be unique in a global scope
86 | if name not in self.loop_counts:
87 | self.loop_totals_cpu[name] = 0
88 | self.loop_totals_gpu[name] = 0
89 | self.loop_counts[name] = 0
90 | self.loop_totals_gpu[name] += gpu_time
91 | self.loop_totals_cpu[name] += cpu_time
92 | self.loop_counts[name] += 1
93 |
94 | if self.print_mode and name:
95 | cpu_time_disp = colorize_time(cpu_time)
96 | gpu_time_disp = colorize_time(gpu_time)
97 | cpu_time_disp_avg = colorize_time(self.loop_totals_cpu[name] / self.loop_counts[name])
98 | gpu_time_disp_avg = colorize_time(self.loop_totals_gpu[name] / self.loop_counts[name])
99 |
100 | if name:
101 | print(f"CPU Checkpoint {name}: {cpu_time_disp}s (Avg: {cpu_time_disp_avg}s)")
102 | print(f"GPU Checkpoint {name}: {gpu_time_disp}s (Avg: {gpu_time_disp_avg}s)")
103 | else:
104 | print("CPU Checkpoint {}: {} s".format(self.counter, cpu_time_disp))
105 | print("GPU Checkpoint {}: {} s".format(self.counter, gpu_time_disp))
106 | if self.show_memory:
107 | #torch.cuda.empty_cache()
108 | print(f"{torch.cuda.memory_allocated()//1048576}MB")
109 |
110 |
111 | self.prev_time = time.perf_counter()
112 | self.prev_time_gpu = self.start.record()
113 | self.counter += 1
114 | return cpu_time, gpu_time
115 |
116 | def get_avg_cpu_time(self, name):
117 | return self.loop_totals_cpu[name] / self.loop_counts[name]
118 |
119 | def get_avg_gpu_time(self, name):
120 | return self.loop_totals_gpu[name] / self.loop_counts[name]
121 |
--------------------------------------------------------------------------------
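
A short usage sketch for PerfTimer (it records CUDA events, so a GPU is assumed):

import torch
from point_renderer.profiler import PerfTimer

timer = PerfTimer(activate=True, show_memory=False, print_mode=True)

x = torch.randn(4096, 4096, device="cuda")
timer.reset()
y = x @ x                   # some GPU work to time
timer.check("matmul")       # prints CPU and GPU time plus running averages for this key
y = torch.relu(y)
timer.check("relu")

print("avg matmul GPU time:", timer.get_avg_gpu_time("matmul"))
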
/rvt/libs/point-renderer/point_renderer/renderer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 |
10 | import torch
11 | import math
12 | import point_renderer.ops as ops
13 | from point_renderer.cameras import OrthographicCameras, PerspectiveCameras
14 | import point_renderer._C.render as r
15 |
16 | from point_renderer.profiler import PerfTimer
17 |
18 |
19 | @torch.jit.script
20 | def _prep_render_batch_inputs(points, features, inv_poses, intrinsics, img_h : int, img_w: int, orthographic : bool):
21 | batch_size = len(inv_poses)
22 | num_points = points.shape[0]
23 |
24 | # Project points from 3D world coordinates to pixels and camera coordinates
25 | pc_px, pc_cam = ops.project_points_3d_to_pixels(points, inv_poses, intrinsics, orthographic)
26 |
27 | # Convert pixel coordinates to flattened pixel coordinates, marking out-of-image points with index -1
28 | point_pixel_index = ops.get_batch_pixel_index(pc_px, img_h, img_w)
29 | imask = ops.batch_frustrum_mask(pc_px, img_h, img_w)
30 | point_pixel_index[~imask] = -1
31 |
32 | point_depths = pc_cam[:, :, 2]
33 | return point_pixel_index.reshape([batch_size * num_points]).contiguous(), point_depths.reshape([batch_size * num_points]).contiguous(), features
34 |
35 |
36 |
37 | class PointRenderer:
38 |
39 | def __init__(self, device="cuda", perf_timer=False):
40 | self.device = device
41 | assert "cuda" in self.device, "Currently only a CUDA implementation is available"
42 | self.timer = PerfTimer(activate=perf_timer, show_memory=False, print_mode=perf_timer)
43 |
44 | @torch.no_grad()
45 | def splat_filter(self, batch_size, splat_radius, cameras, point_depths, depth_buf, image_buf, default_depth, splat_max_k):
46 | # This assumes same focal length for all cameras.
47 | if cameras.is_perspective():
48 | # We are assuming x and y focal lengths to be about the same.
49 | focal_length = cameras.intrinsics[0, 0, 0].item()
50 | closest_point = point_depths.min()
51 | biggest_splat = math.ceil((splat_radius * focal_length / closest_point))
52 | elif cameras.is_orthographic():
53 | # In orthographic cameras all points are the same size
54 | focal_length = cameras.intrinsics[0, 0, 0].item()
55 | biggest_splat = math.ceil(focal_length * splat_radius)
56 | else:
57 | raise ValueError(f"Unknown camera type: {type(cameras)}")
58 |
59 | if splat_max_k is None:
60 | splat_max_k = 7
61 | kernel_size = min(biggest_splat * 2 + 1, splat_max_k)
62 |
63 | #print(f"Splatting filter with k={kernel_size}, b={batch_size}")
64 | r.screen_space_splatting(
65 | batch_size,
66 | depth_buf.clone(),
67 | image_buf.clone(),
68 | depth_buf,
69 | image_buf,
70 | default_depth,
71 | focal_length,
72 | splat_radius,
73 | kernel_size,
74 | cameras.is_orthographic()
75 | )
76 |
77 |
78 | @torch.no_grad()
79 | def render_batch(self, points, features, cameras, img_size, default_depth=0.0, default_color=0.0, splat_radius=None, splat_max_k=51, aa_factor=1):
80 | # Figure out dimensions of the problem
81 | img_h, img_w = img_size
82 | aa_factor = int(aa_factor)
83 | assert aa_factor >= 1, "Antialiasing factor must be at least 1"
84 | img_h = img_h * aa_factor
85 | img_w = img_w * aa_factor
86 | batch_size = len(cameras)
87 | num_channels = features.shape[1]
88 |
89 | self.timer.reset()
90 |
91 | # Make sure inputs are on the rendering device
92 | cameras = cameras.to(self.device)
93 | features = features.to(self.device)
94 | points = points.to(self.device)
95 |
96 | # Scale the camera if we want to internally render a bigger image for fake antialiasing
97 | if aa_factor > 1:
98 | cameras.scale(aa_factor)
99 |
100 | # Project points to image space depending on the camera type
101 | # (these would be more elegant as class methods, but TorchScript doesn't support it)
102 | point_pixel_index, point_depths, features = _prep_render_batch_inputs(
103 | points, features, cameras.inv_poses, cameras.intrinsics, img_h, img_w, type(cameras) == OrthographicCameras)
104 |
105 | # Allocate depth and image render buffers
106 | # Batch size is rolled in with the height dimension
107 | # a.k.a. we render a single image that contains all images vertically stacked
108 | depth_buf = torch.zeros((batch_size * img_h, img_w), device=self.device, dtype=torch.float32)
109 | assert features.dtype == torch.float32, "For now only torch.float32 features are supported"
110 | image_buf = torch.zeros((batch_size * img_h, img_w, num_channels), device=self.device, dtype=torch.float32)
111 |
112 | self.timer.check("render_setup")
113 |
114 | # Render points to pixel buffers
115 | r.render_feature_pointcloud_to_image(
116 | point_pixel_index,
117 | point_depths,
118 | features,
119 | image_buf,
120 | depth_buf,
121 | default_depth,
122 | default_color
123 | )
124 |
125 | self.timer.check("render")
126 |
127 | # Apply screen-space splatting filter
128 | if splat_radius is not None:
129 | self.splat_filter(batch_size, splat_radius, cameras, point_depths, depth_buf, image_buf, default_depth, splat_max_k)
130 | self.timer.check("splatting")
131 |
132 | # Separate the batch dimension, so that we have a batch of images
133 | image_buf = image_buf.reshape((batch_size, img_h, img_w, num_channels))
134 | depth_buf = depth_buf.reshape((batch_size, img_h, img_w))
135 |
136 | if aa_factor != 1:
137 | # Subsample the larger render buffers to produce the output image
138 | img_h_out = img_h // aa_factor
139 | img_w_out = img_w // aa_factor
140 | depth_out = torch.zeros((batch_size, img_h_out, img_w_out), device=self.device, dtype=torch.float32).fill_(default_depth)
141 | image_out = torch.zeros((batch_size, img_h_out, img_w_out, num_channels), device=self.device, dtype=torch.float32)
142 | print(depth_buf.min(), depth_buf.max())
143 |
144 | r.aa_subsample(
145 | depth_buf,
146 | image_buf,
147 | depth_out,
148 | image_out,
149 | aa_factor,
150 | default_depth
151 | )
152 | self.timer.check("antialiasing")
153 | return image_out.reshape((batch_size, img_h_out, img_w_out, num_channels)), depth_out.reshape((batch_size, img_h_out, img_w_out))
154 | else:
155 | # Output render buffers as-is if we're rendering at the same resolution as the output
156 | self.timer.check("antialiasing")
157 | return image_buf, depth_buf
158 |
159 |
160 |
--------------------------------------------------------------------------------
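
An end-to-end sketch of PointRenderer.render_batch. The camera construction mirrors the call made by RVTBoxRenderer further down; cameras.py itself is not included in this dump, so the from_lookat arguments are taken on faith from that usage. A CUDA device and the compiled extension are required:

import torch
from point_renderer.renderer import PointRenderer
from point_renderer.cameras import OrthographicCameras

device = "cuda"
points = torch.rand(20000, 3, device=device) * 2 - 1   # points inside the [-1, 1] cube
features = torch.rand(20000, 3, device=device)         # per-point RGB features (float32)

# One orthographic camera looking down at the cube, imaging [-1, 1] x [-1, 1] onto 128x128 pixels
cameras = OrthographicCameras.from_lookat(
    [[0, 0, 1]], [[0, 0, 0]], [[0, 1, 0]],
    img_sizes_w=[2, 2], img_size_px=(128, 128)).to(device)

renderer = PointRenderer(device=device)
images, depths = renderer.render_batch(
    points, features, cameras, img_size=(128, 128),
    default_depth=-1.0, splat_radius=0.012, aa_factor=2)
print(images.shape, depths.shape)   # (1, 128, 128, 3) and (1, 128, 128)
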
/rvt/libs/point-renderer/point_renderer/rvt_ops.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import torch
3 | import math
4 |
5 |
6 | # source: https://discuss.pytorch.org/t/batched-index-select/9115/6
7 | def batched_index_select(inp, dim, index):
8 | """
9 | input: B x * x ... x *
10 | dim: 0 < scalar
11 | index: B x M
12 | """
13 | views = [inp.shape[0]] + [1 if i != dim else -1 for i in range(1, len(inp.shape))]
14 | expanse = list(inp.shape)
15 | expanse[0] = -1
16 | expanse[dim] = -1
17 | index = index.view(views).expand(expanse)
18 | return torch.gather(inp, dim, index)
19 |
20 |
21 | # TODO: break into two functions
22 | def select_feat_from_hm(
23 | pt_cam: torch.Tensor, hm: torch.Tensor, pt_cam_wei: Optional[torch.Tensor] = None
24 | ) -> Tuple[torch.Tensor]:
25 | """
26 | :param pt_cam:
27 | continuous location of point coordinates from where value needs to be
28 | selected. it is of size [nc, npt, 2], locations in pytorch3d screen
29 | notations
30 | :param hm: size [nc, nw, h, w]
31 | :param pt_cam_wei:
32 | some predefined weight of size [nc, npt], it is used along with the
33 | distance weights
34 | :return:
35 | tuple with the first element being the weighted average for each point
36 | according to the hm values. the size is [nc, npt, nw]. the second and
37 | third elements are intermediate values to be used while caching
38 | """
39 | nc, nw, h, w = hm.shape
40 | npt = pt_cam.shape[1]
41 | if pt_cam_wei is None:
42 | pt_cam_wei = torch.ones([nc, npt]).to(hm.device)
43 |
44 | # giving points outside the image zero weight
45 | pt_cam_wei[pt_cam[:, :, 0] < 0] = 0
46 | pt_cam_wei[pt_cam[:, :, 1] < 0] = 0
47 | pt_cam_wei[pt_cam[:, :, 0] > (w - 1)] = 0
48 | pt_cam_wei[pt_cam[:, :, 1] > (h - 1)] = 0
49 |
50 | pt_cam = pt_cam.unsqueeze(2).repeat([1, 1, 4, 1])
51 | # later used for calculating weight
52 | pt_cam_con = pt_cam.detach().clone()
53 |
54 | # getting discrete grid location of pts in the camera image space
55 | pt_cam[:, :, 0, 0] = torch.floor(pt_cam[:, :, 0, 0])
56 | pt_cam[:, :, 0, 1] = torch.floor(pt_cam[:, :, 0, 1])
57 | pt_cam[:, :, 1, 0] = torch.floor(pt_cam[:, :, 1, 0])
58 | pt_cam[:, :, 1, 1] = torch.ceil(pt_cam[:, :, 1, 1])
59 | pt_cam[:, :, 2, 0] = torch.ceil(pt_cam[:, :, 2, 0])
60 | pt_cam[:, :, 2, 1] = torch.floor(pt_cam[:, :, 2, 1])
61 | pt_cam[:, :, 3, 0] = torch.ceil(pt_cam[:, :, 3, 0])
62 | pt_cam[:, :, 3, 1] = torch.ceil(pt_cam[:, :, 3, 1])
63 | pt_cam = pt_cam.long() # [nc, npt, 4, 2]
64 | # since we are taking modulo, points at the edge, i.e. at h or w, will be
65 | # mapped to 0. this will make their distance from the continuous location
66 | # large and hence they won't matter. therefore we don't need an explicit
67 | # step to remove such points
68 | pt_cam[:, :, :, 0] = torch.fmod(pt_cam[:, :, :, 0], int(w))
69 | pt_cam[:, :, :, 1] = torch.fmod(pt_cam[:, :, :, 1], int(h))
70 | pt_cam[pt_cam < 0] = 0
71 |
72 | # getting normalized weight for each discrete location for pt
73 | # weight based on distance of point from the discrete location
74 | # [nc, npt, 4]
75 | pt_cam_dis = 1 / (torch.sqrt(torch.sum((pt_cam_con - pt_cam) ** 2, dim=-1)) + 1e-10)
76 | pt_cam_wei = pt_cam_wei.unsqueeze(-1) * pt_cam_dis
77 | _pt_cam_wei = torch.sum(pt_cam_wei, dim=-1, keepdim=True)
78 | _pt_cam_wei[_pt_cam_wei == 0.0] = 1
79 | # cached pt_cam_wei in select_feat_from_hm_cache
80 | pt_cam_wei = pt_cam_wei / _pt_cam_wei # [nc, npt, 4]
81 |
82 | # transforming indices from 2D to 1D to use pytorch gather
83 | hm = hm.permute(0, 2, 3, 1).view(nc, h * w, nw) # [nc, h * w, nw]
84 | pt_cam = pt_cam.view(nc, 4 * npt, 2) # [nc, 4 * npt, 2]
85 | # cached pt_cam in select_feat_from_hm_cache
86 | pt_cam = (pt_cam[:, :, 1] * w) + pt_cam[:, :, 0] # [nc, 4 * npt]
87 | # [nc, 4 * npt, nw]
88 | pt_cam_val = batched_index_select(hm, dim=1, index=pt_cam)
89 | # transforming back each discrete location of point
90 | pt_cam_val = pt_cam_val.view(nc, npt, 4, nw)
91 | # summing weighted contribution of each discrete location of a point
92 | # [nc, npt, nw]
93 | pt_cam_val = torch.sum(pt_cam_val * pt_cam_wei.unsqueeze(-1), dim=2)
94 | return pt_cam_val, pt_cam, pt_cam_wei
95 |
96 |
97 | def select_feat_from_hm_cache(
98 | pt_cam: torch.Tensor,
99 | hm: torch.Tensor,
100 | pt_cam_wei: torch.Tensor,
101 | ) -> torch.Tensor:
102 | """
103 | Cached version of select_feat_from_hm where we directly feed in the
104 | intermediate values of pt_cam and pt_cam_wei. Look into the original
105 | function for the meaning of these values and the return type. It can be
106 | used during inference if the locations of the points remain the same.
107 | """
108 |
109 | nc, nw, h, w = hm.shape
110 | # transforming indices from 2D to 1D to use pytorch gather
111 | hm = hm.permute(0, 2, 3, 1).view(nc, h * w, nw) # [nc, h * w, nw]
112 | # [nc, 4 * npt, nw]
113 | pt_cam_val = batched_index_select(hm, dim=1, index=pt_cam)
114 | # transforming back each discrete location of point
115 | pt_cam_val = pt_cam_val.view(nc, -1, 4, nw)
116 | # summing weighted contribution of each discrete location of a point
117 | # [nc, npt, nw]
118 | pt_cam_val = torch.sum(pt_cam_val * pt_cam_wei.unsqueeze(-1), dim=2)
119 | return pt_cam_val
120 |
121 |
122 | # unit tests to verify select_feat_from_hm
123 | def test_select_feat_from_hm():
124 | def get_out(pt_cam, hm):
125 | nc, nw, d = pt_cam.shape
126 | nc2, c, h, w = hm.shape
127 | assert nc == nc2
128 | assert d == 2
129 | out = torch.zeros((nc, nw, c))
130 | for i in range(nc):
131 | for j in range(nw):
132 | wx, hx = pt_cam[i, j]
133 | if (wx < 0) or (hx < 0) or (wx > (w - 1)) or (hx > (h - 1)):
134 | out[i, j, :] = 0
135 | else:
136 | coords = (
137 | (math.floor(wx), math.floor(hx)),
138 | (math.floor(wx), math.ceil(hx)),
139 | (math.ceil(wx), math.floor(hx)),
140 | (math.ceil(wx), math.ceil(hx)),
141 | )
142 | vals = []
143 | total = 0
144 | for x, y in coords:
145 | val = 1 / (math.sqrt(((wx - x) ** 2) + ((hx - y) ** 2)) + 1e-10)
146 | vals.append(val)
147 | total += val
148 |
149 | vals = [x / total for x in vals]
150 |
151 | for (x, y), val in zip(coords, vals):
152 | out[i, j] += val * hm[i, :, y, x]
153 | return out
154 |
155 | pt_cam_1 = torch.tensor([[[11.11, 120.1], [37.8, 0.0], [99, 76.5]]])
156 | hm_1_1 = torch.ones((1, 1, 100, 120))
157 | hm_1_2 = torch.ones((1, 1, 120, 100))
158 | out_1 = torch.ones((1, 3, 1))
159 | out_1[0, 0, 0] = 0
160 |
161 | pt_cam_2 = torch.tensor(
162 | [
163 | [[11.11, 12.11], [37.8, 0.0]],
164 | [[61.00, 12.00], [123.99, 123.0]],
165 | ]
166 | )
167 | hm_2_1 = torch.rand((2, 1, 200, 100))
168 | hm_2_2 = torch.rand((2, 1, 100, 200))
169 |
170 | test_sets = [
171 | (pt_cam_1, hm_1_1, out_1),
172 | (pt_cam_1, hm_1_2, out_1),
173 | (pt_cam_2, hm_2_1, get_out(pt_cam_2, hm_2_1)),
174 | (pt_cam_2, hm_2_2, get_out(pt_cam_2, hm_2_2)),
175 | ]
176 |
177 | for i, test in enumerate(test_sets):
178 | pt_cam, hm, out = test
179 | _out, _, _ = select_feat_from_hm(pt_cam, hm)
180 | out = out.float()
181 | if torch.all(torch.abs(_out - out) < 1e-5):
182 | print(f"Passed test {i}, {out}, {_out}")
183 | else:
184 | print(f"Failed test {i}, {out}, {_out}")
185 |
--------------------------------------------------------------------------------
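
A sketch of the caching pattern that select_feat_from_hm and select_feat_from_hm_cache are designed for: compute the discretized indices and weights once for a fixed set of point locations, then reuse them for every new heatmap (shapes are illustrative):

import torch
from point_renderer.rvt_ops import select_feat_from_hm, select_feat_from_hm_cache

nc, nw, h, w, npt = 5, 1, 220, 220, 1000
pt_px = torch.rand(nc, npt, 2) * torch.tensor([w - 1.0, h - 1.0])   # fixed 2D locations per camera
hm = torch.rand(nc, nw, h, w)

# First call also returns the discretized grid indices and the distance-based weights
vals, pt_cam_cached, pt_cam_wei_cached = select_feat_from_hm(pt_px.clone(), hm)

# Later heatmaps with the same point locations can skip the index/weight computation
hm2 = torch.rand(nc, nw, h, w)
vals2 = select_feat_from_hm_cache(pt_cam_cached, hm2, pt_cam_wei_cached)
print(vals.shape, vals2.shape)   # both (nc, npt, nw)

Note that select_feat_from_hm modifies its pt_cam argument in place, hence the clone() above.
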
/rvt/libs/point-renderer/point_renderer/rvt_renderer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import point_renderer.ops as ops
3 | from point_renderer.cameras import OrthographicCameras, PerspectiveCameras
4 | from point_renderer.renderer import PointRenderer
5 | from mvt.utils import ForkedPdb
6 |
7 | import point_renderer.rvt_ops as rvt_ops
8 |
9 |
10 | class RVTBoxRenderer():
11 | """
12 | Wrapper around PointRenderer that fixes the cameras to be orthographic cameras
13 | on the faces of a 2x2x2 cube placed at the origin
14 | """
15 |
16 | def __init__(
17 | self,
18 | img_size,
19 | radius=0.012,
20 | default_color=0.0,
21 | default_depth=-1.0,
22 | antialiasing_factor=1,
23 | pers=False,
24 | normalize_output=True,
25 | with_depth=True,
26 | device="cuda",
27 | perf_timer=False,
28 | strict_input_device=True,
29 | no_down=True,
30 | no_top=False,
31 | three_views=False,
32 | two_views=False,
33 | one_view=False,
34 | add_3p=False,
35 | **kwargs):
36 |
37 | self.renderer = PointRenderer(device=device, perf_timer=perf_timer)
38 |
39 | self.img_size = img_size
40 | self.splat_radius = radius
41 | self.default_color = default_color
42 | self.default_depth = default_depth
43 | self.aa_factor = antialiasing_factor
44 | self.normalize_output = normalize_output
45 | self.with_depth = with_depth
46 |
47 | self.strict_input_device = strict_input_device
48 |
49 | # Pre-compute fixed cameras ahead of time
50 | self.cameras = self._get_cube_cameras(
51 | img_size=self.img_size,
52 | orthographic=not pers,
53 | no_down=no_down,
54 | no_top=no_top,
55 | three_views=three_views,
56 | two_views=two_views,
57 | one_view=one_view,
58 | add_3p=add_3p,
59 | )
60 | self.cameras = self.cameras.to(device)
61 |
62 | # TODO(Valts): add support for dynamic cameras
63 |
64 | # Cache
65 | self._fix_pts_cam = None
66 | self._fix_pts_cam_wei = None
67 | self._pts = None
68 |
69 | # RVT API (that we might want to refactor)
70 | self.num_img = len(self.cameras)
71 | self.only_dyn_cam = False
72 |
73 | def _check_device(self, input, input_name):
74 | if self.strict_input_device:
75 | assert str(input.device) == str(self.renderer.device), (
76 | f"Input {input_name} (device {input.device}) should be on the same device as the renderer ({self.renderer.device})")
77 |
78 | def _get_cube_cameras(
79 | self,
80 | img_size,
81 | orthographic,
82 | no_down,
83 | no_top,
84 | three_views,
85 | two_views,
86 | one_view,
87 | add_3p,
88 | ):
89 | cam_dict = {
90 | "top": {"eye": [0, 0, 1], "at": [0, 0, 0], "up": [0, 1, 0]},
91 | "front": {"eye": [1, 0, 0], "at": [0, 0, 0], "up": [0, 0, 1]},
92 | "down": {"eye": [0, 0, -1], "at": [0, 0, 0], "up": [0, 1, 0]},
93 | "back": {"eye": [-1, 0, 0], "at": [0, 0, 0], "up": [0, 0, 1]},
94 | "left": {"eye": [0, -1, 0], "at": [0, 0, 0], "up": [0, 0, 1]},
95 | "right": {"eye": [0, 0.5, 0], "at": [0, 0, 0], "up": [0, 0, 1]},
96 | }
97 |
98 | assert not (two_views and three_views)
99 | assert not (one_view and three_views)
100 | assert not (one_view and two_views)
101 | assert not add_3p, "Not supported with point renderer yet,"
102 | if two_views or three_views or one_view:
103 | if no_down or no_top or add_3p:
104 | print(
105 | f"WARNING: when three_views={three_views} or two_views={two_views} -- "
106 | f"no_down={no_down} no_top={no_top} add_3p={add_3p} does not matter."
107 | )
108 |
109 | if three_views:
110 | cam_names = ["top", "front", "right"]
111 | elif two_views:
112 | cam_names = ["top", "front"]
113 | elif one_view:
114 | cam_names = ["front"]
115 | else:
116 | cam_names = ["top", "front", "down", "back", "left", "right"]
117 | if no_down:
118 | # select index of "down" camera and remove it from the list
119 | del cam_names[cam_names.index("down")]
120 | if no_top:
121 | del cam_names[cam_names.index("top")]
122 |
123 |
124 | cam_list = [cam_dict[n] for n in cam_names]
125 | eyes = [c["eye"] for c in cam_list]
126 | ats = [c["at"] for c in cam_list]
127 | ups = [c["up"] for c in cam_list]
128 |
129 | if orthographic:
130 | # img_sizes_w specifies height and width dimensions of the image in world coordinates
131 | # [2, 2] means it will image coordinates from -1 to 1 in the camera frame
132 | cameras = OrthographicCameras.from_lookat(eyes, ats, ups, img_sizes_w=[2, 2], img_size_px=img_size)
133 | else:
134 | cameras = PerspectiveCameras.from_lookat(eyes, ats, ups, hfov=70, img_size=img_size)
135 | return cameras
136 |
137 | @torch.no_grad()
138 | def get_pt_loc_on_img(self, pt, fix_cam=False, dyn_cam_info=None):
139 | """
140 | returns the location of a point on the image of the cameras
141 | :param pt: torch.Tensor of shape (bs, np, 3)
142 | :returns: the location of the pt on the image. this is different from the
143 | camera screen coordinate system in pytorch3d. the difference is that
144 | pytorch3d camera screen projects the point to [0, 0] to [H, W]; while the
145 | index on the img is from [0, 0] to [H-1, W-1]. We verified that
146 | to transform from a pytorch3d camera screen point to an img index we have to
147 | subtract (1/H, 1/W) from the pytorch3d camera screen coordinate.
148 | :return type: torch.Tensor of shape (bs, np, self.num_img, 2)
149 | """
150 | assert len(pt.shape) == 3
151 | assert pt.shape[-1] == 3
152 | assert fix_cam, "Not supported with point renderer"
153 | assert dyn_cam_info is None, "Not supported with point renderer"
154 |
155 | bs, np, _ = pt.shape
156 |
157 | self._check_device(pt, "pt")
158 |
159 | # TODO(Valts): Ask Ankit what the bs dimension is here, and treat it correctly
160 |
161 | pcs_px = []
162 | for i in range(bs):
163 | pc_px, pc_cam = ops.project_points_3d_to_pixels(
164 | pt[i], self.cameras.inv_poses, self.cameras.intrinsics, self.cameras.is_orthographic())
165 | pcs_px.append(pc_px)
166 | pcs_px = torch.stack(pcs_px, dim=0)
167 | pcs_px = torch.permute(pcs_px, (0, 2, 1, 3))
168 |
169 | # TODO(Valts): Double-check with Ankit that these projections are truly pixel-aligned
170 | return pcs_px
171 |
172 | @torch.no_grad()
173 | def get_feat_frm_hm_cube(self, hm, fix_cam=False, dyn_cam_info=None):
174 | """
175 | :param hm: torch.Tensor of (1, num_img, h, w)
176 | :return: tuple of ((num_img, h^3, 1), (h^3, 3))
177 | """
178 | x, nc, h, w = hm.shape
179 | assert x == 1
180 | assert nc == self.num_img
181 | assert self.img_size == (h, w)
182 | assert fix_cam, "Not supported with point renderer"
183 | assert dyn_cam_info is None, "Not supported with point renderer"
184 |
185 | self._check_device(hm, "hm")
186 |
187 | if self._pts is None:
188 | res = self.img_size[0]
189 | pts = torch.linspace(-1 + (1 / res), 1 - (1 / res), res).to(hm.device)
190 | pts = torch.cartesian_prod(pts, pts, pts)
191 | self._pts = pts
192 |
193 | pts_hm = []
194 |
195 | # if self._fix_cam
196 | if self._fix_pts_cam is None:
197 | # (np, nc, 2)
198 | pts_img = self.get_pt_loc_on_img(self._pts.unsqueeze(0),
199 | fix_cam=True).squeeze(0)
200 | # pts_img = pts_img.permute((1, 0, 2))
201 | # (nc, np, bs)
202 | fix_pts_hm, pts_cam, pts_cam_wei = rvt_ops.select_feat_from_hm(
203 | pts_img.transpose(0, 1), hm.transpose(0, 1)[0 : len(self.cameras)]
204 | )
205 | self._fix_pts_img = pts_img
206 | self._fix_pts_cam = pts_cam
207 | self._fix_pts_cam_wei = pts_cam_wei
208 | else:
209 | pts_cam = self._fix_pts_cam
210 | pts_cam_wei = self._fix_pts_cam_wei
211 | fix_pts_hm = rvt_ops.select_feat_from_hm_cache(
212 | pts_cam, hm.transpose(0, 1)[0 : len(self.cameras)], pts_cam_wei
213 | )
214 | pts_hm.append(fix_pts_hm)
215 |
216 | #if not dyn_cam_info is None:
217 | # TODO(Valts): implement
218 | pts_hm = torch.cat(pts_hm, 0)
219 | return pts_hm, self._pts
220 |
221 | @torch.no_grad()
222 | def get_max_3d_frm_hm_cube(self, hm, fix_cam=False, dyn_cam_info=None,
223 | topk=1, non_max_sup=False,
224 | non_max_sup_dist=0.02):
225 | """
226 | given set of heat maps, return the 3d location of the point with the
227 | largest score, assumes the points are in a cube [-1, 1]. This function
228 | should be used along with the renderer. For a standalone version look for
229 | the other function with the same name in the file.
230 | :param hm: (1, nc, h, w)
231 | :return: (1, topk, 3)
232 | """
233 | assert fix_cam, "Not supported with point renderer"
234 | assert dyn_cam_info is None, "Not supported with point renderer"
235 |
236 | self._check_device(hm, "hm")
237 |
238 | x, nc, h, w = hm.shape
239 | assert x == 1
240 | assert nc == len(self.cameras)
241 | assert self.img_size == (h, w)
242 |
243 | pts_hm, pts = self.get_feat_frm_hm_cube(hm, fix_cam, dyn_cam_info)
244 | # (bs, np, nc)
245 | pts_hm = pts_hm.permute(2, 1, 0)
246 | # (bs, np)
247 | pts_hm = torch.mean(pts_hm, -1)
248 | if non_max_sup and topk > 1:
249 | _pts = pts.clone()
250 | pts = []
251 | pts_hm = torch.squeeze(pts_hm, 0)
252 | for i in range(topk):
253 | ind_max_pts = torch.argmax(pts_hm, -1)
254 | sel_pts = _pts[ind_max_pts]
255 | pts.append(sel_pts)
256 | dist = torch.sqrt(torch.sum((_pts - sel_pts) ** 2, -1))
257 | pts_hm[dist < non_max_sup_dist] = -1
258 | pts = torch.stack(pts, 0).unsqueeze(0)
259 | else:
260 | # (bs, topk)
261 | ind_max_pts = torch.topk(pts_hm, topk)[1]
262 | # (bs, topk, 3)
263 | pts = pts[ind_max_pts]
264 | return pts
265 |
266 | def __call__(self, pc, feat, fix_cam=False, dyn_cam_info=None):
267 |
268 | self._check_device(pc, "pc")
269 | self._check_device(pc, "feat")
270 |
271 | pc_images, pc_depths = self.renderer.render_batch(pc, feat,
272 | cameras=self.cameras,
273 | img_size=self.img_size,
274 | splat_radius=self.splat_radius,
275 | default_color=self.default_color,
276 | default_depth=self.default_depth,
277 | aa_factor=self.aa_factor
278 | )
279 |
280 | if self.normalize_output:
281 | _, h, w = pc_depths.shape
282 | depth_0 = pc_depths == -1
283 | depth_sum = torch.sum(pc_depths, (1, 2)) + torch.sum(depth_0, (1, 2))
284 | depth_mean = depth_sum / ((h * w) - torch.sum(depth_0, (1, 2)))
285 | pc_depths -= depth_mean.unsqueeze(-1).unsqueeze(-1)
286 | pc_depths[depth_0] = -1
287 |
288 | if self.with_depth:
289 | img_out = torch.cat([pc_images, pc_depths[:, :, :, None]], dim=-1)
290 | else:
291 | img_out = pc_images
292 |
293 | return img_out
294 |
--------------------------------------------------------------------------------
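
A sketch of driving RVTBoxRenderer directly (a CUDA device is assumed, and the mvt package must be importable since this module imports from it; the image size and splat radius mirror the defaults above but are illustrative):

import torch
from point_renderer.rvt_renderer import RVTBoxRenderer

device = "cuda"
renderer = RVTBoxRenderer(img_size=(220, 220), radius=0.012, default_depth=-1.0, device=device)

pc = torch.rand(30000, 3, device=device) * 2 - 1   # point cloud normalized to the [-1, 1] cube
feat = torch.rand(30000, 3, device=device)         # per-point RGB features

# Render the fixed cube views; with with_depth=True (the default) depth is appended as the last channel
imgs = renderer(pc, feat, fix_cam=True)
print(imgs.shape)   # (num_views, 220, 220, 4): RGB + depth for each virtual camera
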
/rvt/libs/point-renderer/requirements.txt:
--------------------------------------------------------------------------------
1 | transforms3d
2 | jupyter
3 | numpy
4 | torch
5 | matplotlib
6 | imageio
7 | trimesh
8 | meshcat
--------------------------------------------------------------------------------
/rvt/libs/point-renderer/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 |
10 | import os
11 | import sys
12 | from setuptools import setup, find_packages, dist
13 | import glob
14 | import logging
15 |
16 | import torch
17 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
18 |
19 | PACKAGE_NAME = 'point_renderer'
20 | DESCRIPTION = 'Fast Point Cloud Renderer'
21 | URL = 'https://gitlab-master.nvidia.com/vblukis/point-renderer'
22 | AUTHOR = 'Valts Blukis'
23 | LICENSE = 'NVIDIA'
24 | version = '0.2.0'
25 |
26 |
27 | def get_extensions():
28 | extra_compile_args = {'cxx': ['-O3']}
29 | define_macros = []
30 | include_dirs = []
31 | extensions = []
32 | sources = glob.glob('point_renderer/csrc/**/*.cpp', recursive=True)
33 |
34 | if len(sources) == 0:
35 | print("No source files found for extension, skipping extension compilation")
36 | return None
37 |
38 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1' or True:
39 | define_macros += [("WITH_CUDA", None), ("THRUST_IGNORE_CUB_VERSION_CHECK", None)]
40 | sources += glob.glob('point_renderer/csrc/**/*.cu', recursive=True)
41 | extension = CUDAExtension
42 | extra_compile_args.update({'nvcc': ['-O3']})
43 | #include_dirs = get_include_dirs()
44 | else:
45 | assert False, "CUDA is not available. Set FORCE_CUDA=1 for Docker builds"
46 |
47 | extensions.append(
48 | extension(
49 | name='point_renderer._C',
50 | sources=sources,
51 | define_macros=define_macros,
52 | extra_compile_args=extra_compile_args,
53 | #include_dirs=include_dirs
54 | )
55 | )
56 |
57 | for ext in extensions:
58 | ext.libraries = ['cudart_static' if x == 'cudart' else x
59 | for x in ext.libraries]
60 |
61 | return extensions
62 |
63 |
64 | if __name__ == '__main__':
65 | setup(
66 | # Metadata
67 | name=PACKAGE_NAME,
68 | version=version,
69 | author=AUTHOR,
70 | description=DESCRIPTION,
71 | url=URL,
72 | license=LICENSE,
73 | python_requires='>=3.7',
74 |
75 | # Package info
76 | packages=['point_renderer'],
77 | include_package_data=True,
78 | zip_safe=True,
79 | ext_modules=get_extensions(),
80 | cmdclass={
81 | 'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
82 | }
83 |
84 | )
85 |
--------------------------------------------------------------------------------
/rvt/models/peract_official.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | from rvt.libs.peract.helpers.preprocess_agent import PreprocessAgent
6 | from rvt.libs.peract.agents.peract_bc.launch_utils import create_agent
7 |
8 |
9 | class PreprocessAgent2(PreprocessAgent):
10 | def eval(self):
11 | self._pose_agent._qattention_agents[0]._q.eval()
12 |
13 | def train(self):
14 | self._pose_agent._qattention_agents[0]._q.train()
15 |
16 | def build(self, *args, **kwargs):
17 | super().build(*args, **kwargs)
18 |
19 | self._device = self._pose_agent._qattention_agents[0]._device
20 |
21 |
22 | def create_agent_our(cfg):
23 | """
24 | Reuses the official peract agent, but replaces PreprocessAgent with PreprocessAgent2
25 | """
26 | agent = create_agent(cfg)
27 | agent = agent._pose_agent
28 | agent = PreprocessAgent2(agent)
29 | return agent
30 |
--------------------------------------------------------------------------------
/rvt/mvt/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 |
--------------------------------------------------------------------------------
/rvt/mvt/attn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | # Sources:
6 | # https://github.com/lucidrains/perceiver-pytorch/blob/main/perceiver_pytorch/perceiver_io.py
7 | # https://github.com/peract/peract/blob/main/helpers/network_utils.py
8 |
9 | import math
10 | from math import log
11 | from functools import wraps
12 | from packaging import version
13 |
14 | import torch
15 | import torch.nn.functional as F
16 |
17 | from tqdm import tqdm
18 | from torch import nn, einsum
19 | from einops import rearrange, repeat
20 |
21 | try:
22 | import xformers.ops as xops
23 | except ImportError as e:
24 | xops = None
25 |
26 | LRELU_SLOPE = 0.02
27 |
28 |
29 | def exists(val):
30 | return val is not None
31 |
32 |
33 | def default(val, d):
34 | return val if exists(val) else d
35 |
36 |
37 | def cache_fn(f):
38 | cache = None
39 |
40 | @wraps(f)
41 | def cached_fn(*args, _cache=True, **kwargs):
42 | if not _cache:
43 | return f(*args, **kwargs)
44 | nonlocal cache
45 | if cache is not None:
46 | return cache
47 | cache = f(*args, **kwargs)
48 | return cache
49 |
50 | return cached_fn
51 |
52 |
53 | class PreNorm(nn.Module):
54 | def __init__(self, dim, fn, context_dim=None):
55 | super().__init__()
56 | self.fn = fn
57 | self.norm = nn.LayerNorm(dim)
58 | self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
59 |
60 | def forward(self, x, **kwargs):
61 | x = self.norm(x)
62 |
63 | if exists(self.norm_context):
64 | context = kwargs["context"]
65 | normed_context = self.norm_context(context)
66 | kwargs.update(context=normed_context)
67 |
68 | return self.fn(x, **kwargs)
69 |
70 |
71 | class GEGLU(nn.Module):
72 | def forward(self, x):
73 | x, gates = x.chunk(2, dim=-1)
74 | return x * F.gelu(gates)
75 |
76 |
77 | class FeedForward(nn.Module):
78 | def __init__(self, dim, mult=4):
79 | super().__init__()
80 | self.net = nn.Sequential(
81 | nn.Linear(dim, dim * mult * 2), GEGLU(), nn.Linear(dim * mult, dim)
82 | )
83 |
84 | def forward(self, x):
85 | return self.net(x)
86 |
87 |
88 | class Attention(nn.Module): # is all you need. Living up to its name.
89 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64,
90 | dropout=0.0, use_fast=False):
91 |
92 | super().__init__()
93 | self.use_fast = use_fast
94 | inner_dim = dim_head * heads
95 | context_dim = default(context_dim, query_dim)
96 | self.scale = dim_head**-0.5
97 | self.heads = heads
98 |
99 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
100 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
101 | self.to_out = nn.Linear(inner_dim, query_dim)
102 |
103 | self.dropout_p = dropout
104 | # dropout left in use_fast for backward compatibility
105 | self.dropout = nn.Dropout(self.dropout_p)
106 |
107 | self.avail_xf = False
108 | if self.use_fast:
109 | if not xops is None:
110 | self.avail_xf = True
111 | else:
112 | self.use_fast = False
113 |
114 | def forward(self, x, context=None, mask=None):
115 | h = self.heads
116 |
117 | q = self.to_q(x)
118 | context = default(context, x)
119 | k, v = self.to_kv(context).chunk(2, dim=-1)
120 |
121 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
122 | if self.use_fast:
123 | # only apply dropout during training
124 | dropout_p = self.dropout_p if self.training else 0.0
125 | # use xformers memory-efficient attention if available
126 | if self.avail_xf:
127 | out = xops.memory_efficient_attention(
128 | query=q, key=k, value=v, p=dropout_p
129 | )
130 | else:
131 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
132 | if exists(mask):
133 | mask = rearrange(mask, "b ... -> b (...)")
134 | max_neg_value = -torch.finfo(sim.dtype).max
135 | mask = repeat(mask, "b j -> (b h) () j", h=h)
136 | sim.masked_fill_(~mask, max_neg_value)
137 | # attention
138 | attn = sim.softmax(dim=-1)
139 | # dropout
140 | attn = self.dropout(attn)
141 | out = einsum("b i j, b j d -> b i d", attn, v)
142 |
143 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
144 | out = self.to_out(out)
145 | return out
146 |
147 |
148 | def act_layer(act):
149 | if act == "relu":
150 | return nn.ReLU()
151 | elif act == "lrelu":
152 | return nn.LeakyReLU(LRELU_SLOPE)
153 | elif act == "elu":
154 | return nn.ELU()
155 | elif act == "tanh":
156 | return nn.Tanh()
157 | elif act == "prelu":
158 | return nn.PReLU()
159 | else:
160 | raise ValueError("%s not recognized." % act)
161 |
162 |
163 | def norm_layer2d(norm, channels):
164 | if norm == "batch":
165 | return nn.BatchNorm2d(channels)
166 | elif norm == "instance":
167 | return nn.InstanceNorm2d(channels, affine=True)
168 | elif norm == "layer":
169 | return nn.GroupNorm(1, channels, affine=True)
170 | elif norm == "group":
171 | return nn.GroupNorm(4, channels, affine=True)
172 | else:
173 | raise ValueError("%s not recognized." % norm)
174 |
175 |
176 | def norm_layer1d(norm, num_channels):
177 | if norm == "batch":
178 | return nn.BatchNorm1d(num_channels)
179 | elif norm == "instance":
180 | return nn.InstanceNorm1d(num_channels, affine=True)
181 | elif norm == "layer":
182 | return nn.LayerNorm(num_channels)
183 | elif norm == "group":
184 | return nn.GroupNorm(4, num_channels, affine=True)
185 | else:
186 | raise ValueError("%s not recognized." % norm)
187 |
188 |
189 | class Conv2DBlock(nn.Module):
190 | def __init__(
191 | self,
192 | in_channels,
193 | out_channels,
194 | kernel_sizes=3,
195 | strides=1,
196 | norm=None,
197 | activation=None,
198 | padding_mode="replicate",
199 | padding=None,
200 | ):
201 | super().__init__()
202 | padding = kernel_sizes // 2 if padding is None else padding
203 | self.conv2d = nn.Conv2d(
204 | in_channels,
205 | out_channels,
206 | kernel_sizes,
207 | strides,
208 | padding=padding,
209 | padding_mode=padding_mode,
210 | )
211 |
212 | if activation is None:
213 | nn.init.xavier_uniform_(
214 | self.conv2d.weight, gain=nn.init.calculate_gain("linear")
215 | )
216 | nn.init.zeros_(self.conv2d.bias)
217 | elif activation == "tanh":
218 | nn.init.xavier_uniform_(
219 | self.conv2d.weight, gain=nn.init.calculate_gain("tanh")
220 | )
221 | nn.init.zeros_(self.conv2d.bias)
222 | elif activation == "lrelu":
223 | nn.init.kaiming_uniform_(
224 | self.conv2d.weight, a=LRELU_SLOPE, nonlinearity="leaky_relu"
225 | )
226 | nn.init.zeros_(self.conv2d.bias)
227 | elif activation == "relu":
228 | nn.init.kaiming_uniform_(self.conv2d.weight, nonlinearity="relu")
229 | nn.init.zeros_(self.conv2d.bias)
230 | else:
231 | raise ValueError()
232 |
233 | self.activation = None
234 | if norm is not None:
235 | self.norm = norm_layer2d(norm, out_channels)
236 | else:
237 | self.norm = None
238 | if activation is not None:
239 | self.activation = act_layer(activation)
240 | self.out_channels = out_channels
241 |
242 | def forward(self, x):
243 | x = self.conv2d(x)
244 | x = self.norm(x) if self.norm is not None else x
245 | x = self.activation(x) if self.activation is not None else x
246 | return x
247 |
248 |
249 | class Conv2DUpsampleBlock(nn.Module):
250 | def __init__(
251 | self,
252 | in_channels,
253 | out_channels,
254 | strides,
255 | kernel_sizes=3,
256 | norm=None,
257 | activation=None,
258 | out_size=None,
259 | ):
260 | super().__init__()
261 | layer = [
262 | Conv2DBlock(in_channels, out_channels, kernel_sizes, 1, norm, activation)
263 | ]
264 | if strides > 1:
265 | if out_size is None:
266 | layer.append(
267 | nn.Upsample(scale_factor=strides, mode="bilinear", align_corners=False)
268 | )
269 | else:
270 | layer.append(
271 | nn.Upsample(size=out_size, mode="bilinear", align_corners=False)
272 | )
273 |
274 | if out_size is not None:
275 | if kernel_sizes % 2 == 0:
276 | kernel_sizes += 1
277 |
278 | convt_block = Conv2DBlock(
279 | out_channels, out_channels, kernel_sizes, 1, norm, activation
280 | )
281 | layer.append(convt_block)
282 | self.conv_up = nn.Sequential(*layer)
283 |
284 | def forward(self, x):
285 | return self.conv_up(x)
286 |
287 |
288 | class DenseBlock(nn.Module):
289 | def __init__(self, in_features, out_features, norm=None, activation=None):
290 | super(DenseBlock, self).__init__()
291 | self.linear = nn.Linear(in_features, out_features)
292 |
293 | if activation is None:
294 | nn.init.xavier_uniform_(
295 | self.linear.weight, gain=nn.init.calculate_gain("linear")
296 | )
297 | nn.init.zeros_(self.linear.bias)
298 | elif activation == "tanh":
299 | nn.init.xavier_uniform_(
300 | self.linear.weight, gain=nn.init.calculate_gain("tanh")
301 | )
302 | nn.init.zeros_(self.linear.bias)
303 | elif activation == "lrelu":
304 | nn.init.kaiming_uniform_(
305 | self.linear.weight, a=LRELU_SLOPE, nonlinearity="leaky_relu"
306 | )
307 | nn.init.zeros_(self.linear.bias)
308 | elif activation == "relu":
309 | nn.init.kaiming_uniform_(self.linear.weight, nonlinearity="relu")
310 | nn.init.zeros_(self.linear.bias)
311 | else:
312 | raise ValueError()
313 |
314 | self.activation = None
315 | self.norm = None
316 | if norm is not None:
317 | self.norm = norm_layer1d(norm, out_features)
318 | if activation is not None:
319 | self.activation = act_layer(activation)
320 |
321 | def forward(self, x):
322 | x = self.linear(x)
323 | x = self.norm(x) if self.norm is not None else x
324 | x = self.activation(x) if self.activation is not None else x
325 | return x
326 |
327 |
328 | # based on https://pytorch.org/tutorials/beginner/transformer_tutorial.html
329 | class FixedPositionalEncoding(nn.Module):
330 | def __init__(self, feat_per_dim: int, feat_scale_factor: int):
331 | super().__init__()
332 | self.feat_scale_factor = feat_scale_factor
333 | # shape [1, feat_per_dim // 2]
334 | div_term = torch.exp(
335 | torch.arange(0, feat_per_dim, 2) * (-math.log(10000.0) /
336 | feat_per_dim)
337 | ).unsqueeze(0)
338 | self.register_buffer("div_term", div_term)
339 |
340 | def forward(self, x):
341 | """
342 | :param x: Tensor, shape [batch_size, input_dim]
343 | :return: Tensor, shape [batch_size, input_dim * feat_per_dim]
344 | """
345 | assert len(x.shape) == 2
346 | batch_size, input_dim = x.shape
347 | x = x.view(-1, 1)
348 | x = torch.cat((
349 | torch.sin(self.feat_scale_factor * x * self.div_term),
350 | torch.cos(self.feat_scale_factor * x * self.div_term)), dim=1)
351 | x = x.view(batch_size, -1)
352 | return x
353 |
--------------------------------------------------------------------------------
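
A minimal sketch of how the building blocks above compose into a perceiver-style cross-attention layer (assuming the package is installed so that rvt.mvt is importable; shapes are illustrative):

import torch
from rvt.mvt.attn import Attention, FeedForward, PreNorm, FixedPositionalEncoding

x = torch.randn(2, 128, 256)        # (batch, num_query_tokens, query_dim)
context = torch.randn(2, 64, 192)   # (batch, num_context_tokens, context_dim)

# Pre-normed cross-attention followed by a pre-normed feed-forward, with residual connections
cross_attn = PreNorm(256, Attention(query_dim=256, context_dim=192, heads=8, dim_head=64), context_dim=192)
ff = PreNorm(256, FeedForward(256))

x = x + cross_attn(x, context=context)
x = x + ff(x)

# Fixed sinusoidal features for low-dimensional inputs such as proprioception
pe = FixedPositionalEncoding(feat_per_dim=16, feat_scale_factor=1)
feats = pe(torch.rand(2, 4))        # (2, 4 * 16)
print(x.shape, feats.shape)
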
/rvt/mvt/aug_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | # Adapted from: https://github.com/stepjam/ARM/blob/main/arm/utils.py
6 | # utils functions for rotation augmentation
7 | import torch
8 | import numpy as np
9 | from scipy.spatial.transform import Rotation
10 |
11 |
12 | def rand_dist(size, min=-1.0, max=1.0):
13 | return (max - min) * torch.rand(size) + min
14 |
15 |
16 | def rand_discrete(size, min=0, max=1):
17 | if min == max:
18 | return torch.zeros(size)
19 | return torch.randint(min, max + 1, size)
20 |
21 |
22 | def normalize_quaternion(quat):
23 | return np.array(quat) / np.linalg.norm(quat, axis=-1, keepdims=True)
24 |
25 |
26 | def sensitive_gimble_fix(euler):
27 | """
28 | :param euler: euler angles in degree as np.ndarray in shape either [3] or
29 | [b, 3]
30 | """
31 | # selecting sensitive angle
32 | select1 = (89 < euler[..., 1]) & (euler[..., 1] < 91)
33 | euler[select1, 1] = 90
34 | # selecting sensitive angle
35 | select2 = (-91 < euler[..., 1]) & (euler[..., 1] < -89)
36 | euler[select2, 1] = -90
37 |
38 | # recalculating the euler angles, see assert
39 | r = Rotation.from_euler("xyz", euler, degrees=True)
40 | euler = r.as_euler("xyz", degrees=True)
41 |
42 | select = select1 | select2
43 | assert (euler[select][..., 2] == 0).all(), euler
44 |
45 | return euler
46 |
47 |
48 | def quaternion_to_discrete_euler(quaternion, resolution, gimble_fix=True):
49 | """
50 | :param gimble_fix: the euler values for x and y can be very sensitive
51 | around y=90 degrees. this leads to a multimodal distribution of x and y
52 | which could be hard for a network to learn. When gimble_fix is true, around
53 | y=90, we change the mode towards x=0, potentially making it easy for the
54 | network to learn.
55 | """
56 | r = Rotation.from_quat(quaternion)
57 |
58 | euler = r.as_euler("xyz", degrees=True)
59 | if gimble_fix:
60 | euler = sensitive_gimble_fix(euler)
61 |
62 | euler += 180
63 | assert np.min(euler) >= 0 and np.max(euler) <= 360
64 | disc = np.around((euler / resolution)).astype(int)
65 | disc[disc == int(360 / resolution)] = 0
66 | return disc
67 |
68 |
69 | def quaternion_to_euler(quaternion, gimble_fix=True):
70 | """
71 | :param gimble_fix: the euler values for x and y can be very sensitive
72 | around y=90 degrees. this leads to a multimodal distribution of x and y
73 | which could be hard for a network to learn. When gimble_fix is true, around
74 | y=90, we change the mode towards x=0, potentially making it easy for the
75 | network to learn.
76 | """
77 | r = Rotation.from_quat(quaternion)
78 |
79 | euler = r.as_euler("xyz", degrees=True)
80 | if gimble_fix:
81 | euler = sensitive_gimble_fix(euler)
82 |
83 | euler += 180
84 | return euler
85 |
86 |
87 | def discrete_euler_to_quaternion(discrete_euler, resolution):
88 | euler = (discrete_euler * resolution) - 180
89 | return Rotation.from_euler("xyz", euler, degrees=True).as_quat()
90 |
91 |
92 | def point_to_voxel_index(
93 | point: np.ndarray, voxel_size: np.ndarray, coord_bounds: np.ndarray
94 | ):
95 | bb_mins = np.array(coord_bounds[0:3])
96 | bb_maxs = np.array(coord_bounds[3:])
97 | dims_m_one = np.array([voxel_size] * 3) - 1
98 | bb_ranges = bb_maxs - bb_mins
99 | res = bb_ranges / (np.array([voxel_size] * 3) + 1e-12)
100 | voxel_index = np.minimum(
101 | np.floor((point - bb_mins) / (res + 1e-12)).astype(np.int32), dims_m_one
102 | )
103 | return voxel_index
104 |
--------------------------------------------------------------------------------
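
A quick round-trip check of the rotation discretization helpers above (quaternions follow scipy's (x, y, z, w) convention; the 5-degree resolution is illustrative):

import numpy as np
from scipy.spatial.transform import Rotation
from rvt.mvt.aug_utils import (
    quaternion_to_discrete_euler, discrete_euler_to_quaternion, normalize_quaternion)

resolution = 5  # degrees per rotation bin, i.e. 360 / 5 = 72 bins per axis

# Batched quaternion of shape [1, 4]
quat = normalize_quaternion(Rotation.from_euler("xyz", [[30, 45, -60]], degrees=True).as_quat())
disc = quaternion_to_discrete_euler(quat, resolution)     # integer bins in [0, 360 / resolution)
quat_back = discrete_euler_to_quaternion(disc, resolution)

# The round trip agrees up to the bin resolution (exactly here, since the angles are multiples of 5)
err_deg = np.degrees((Rotation.from_quat(quat).inv() * Rotation.from_quat(quat_back)).magnitude())
print(disc, err_deg)
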
/rvt/mvt/augmentation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | import numpy as np
6 | import torch
7 | import rvt.mvt.aug_utils as aug_utils
8 | from pytorch3d import transforms as torch3d_tf
9 | from scipy.spatial.transform import Rotation
10 |
11 |
12 | def perturb_se3(pcd, trans_shift_4x4, rot_shift_4x4, action_gripper_4x4, bounds):
13 | """Perturb point clouds with given transformation.
14 | :param pcd:
15 | Either:
16 | - list of point clouds [[bs, 3, H, W], ...] for N cameras
17 | - point cloud [bs, 3, H, W]
18 | - point cloud [bs, 3, num_point]
19 | - point cloud [bs, num_point, 3]
20 | :param trans_shift_4x4: translation matrix [bs, 4, 4]
21 | :param rot_shift_4x4: rotation matrix [bs, 4, 4]
22 | :param action_gripper_4x4: original keyframe action gripper pose [bs, 4, 4]
23 | :param bounds: metric scene bounds [bs, 6]
24 |     :return: perturbed point clouds in the same format as input
25 | """
26 | # batch bounds if necessary
27 |
28 | # for easier compatibility
29 | single_pc = False
30 | if not isinstance(pcd, list):
31 | single_pc = True
32 | pcd = [pcd]
33 |
34 | bs = pcd[0].shape[0]
35 | if bounds.shape[0] != bs:
36 | bounds = bounds.repeat(bs, 1)
37 |
38 | perturbed_pcd = []
39 | for p in pcd:
40 | p_shape = p.shape
41 | permute_p = False
42 | if len(p.shape) == 3:
43 | if p_shape[-1] == 3:
44 | num_points = p_shape[-2]
45 | p = p.permute(0, 2, 1)
46 | permute_p = True
47 | elif p_shape[-2] == 3:
48 | num_points = p_shape[-1]
49 | else:
50 | assert False, p_shape
51 |
52 | elif len(p.shape) == 4:
53 | assert p_shape[-1] != 3, p_shape[-1]
54 | assert p_shape[-2] != 3, p_shape[-2]
55 | num_points = p_shape[-1] * p_shape[-2]
56 |
57 | else:
58 | assert False, len(p.shape)
59 |
60 | action_trans_3x1 = (
61 | action_gripper_4x4[:, 0:3, 3].unsqueeze(-1).repeat(1, 1, num_points)
62 | )
63 | trans_shift_3x1 = (
64 | trans_shift_4x4[:, 0:3, 3].unsqueeze(-1).repeat(1, 1, num_points)
65 | )
66 |
67 | # flatten point cloud
68 | p_flat = p.reshape(bs, 3, -1)
69 | p_flat_4x1_action_origin = torch.ones(bs, 4, p_flat.shape[-1]).to(p_flat.device)
70 |
71 | # shift points to have action_gripper pose as the origin
72 | p_flat_4x1_action_origin[:, :3, :] = p_flat - action_trans_3x1
73 |
74 | # apply rotation
75 | perturbed_p_flat_4x1_action_origin = torch.bmm(
76 | p_flat_4x1_action_origin.transpose(2, 1), rot_shift_4x4
77 | ).transpose(2, 1)
78 |
79 | # apply bounded translations
80 | bounds_x_min, bounds_x_max = bounds[:, 0].min(), bounds[:, 3].max()
81 | bounds_y_min, bounds_y_max = bounds[:, 1].min(), bounds[:, 4].max()
82 | bounds_z_min, bounds_z_max = bounds[:, 2].min(), bounds[:, 5].max()
83 |
84 | action_then_trans_3x1 = action_trans_3x1 + trans_shift_3x1
85 | action_then_trans_3x1_x = torch.clamp(
86 | action_then_trans_3x1[:, 0], min=bounds_x_min, max=bounds_x_max
87 | )
88 | action_then_trans_3x1_y = torch.clamp(
89 | action_then_trans_3x1[:, 1], min=bounds_y_min, max=bounds_y_max
90 | )
91 | action_then_trans_3x1_z = torch.clamp(
92 | action_then_trans_3x1[:, 2], min=bounds_z_min, max=bounds_z_max
93 | )
94 | action_then_trans_3x1 = torch.stack(
95 | [action_then_trans_3x1_x, action_then_trans_3x1_y, action_then_trans_3x1_z],
96 | dim=1,
97 | )
98 |
99 | # shift back the origin
100 | perturbed_p_flat_3x1 = (
101 | perturbed_p_flat_4x1_action_origin[:, :3, :] + action_then_trans_3x1
102 | )
103 | if permute_p:
104 | perturbed_p_flat_3x1 = torch.permute(perturbed_p_flat_3x1, (0, 2, 1))
105 | perturbed_p = perturbed_p_flat_3x1.reshape(p_shape)
106 | perturbed_pcd.append(perturbed_p)
107 |
108 | if single_pc:
109 | perturbed_pcd = perturbed_pcd[0]
110 |
111 | return perturbed_pcd
112 |
113 |
114 | # version copied from peract:
115 | # https://github.com/peract/peract/blob/a3b0bd855d7e749119e4fcbe3ed7168ba0f283fd/voxel/augmentation.py#L68
116 | def apply_se3_augmentation(
117 | pcd,
118 | action_gripper_pose,
119 | action_trans,
120 | action_rot_grip,
121 | bounds,
122 | layer,
123 | trans_aug_range,
124 | rot_aug_range,
125 | rot_aug_resolution,
126 | voxel_size,
127 | rot_resolution,
128 | device,
129 | ):
130 | """Apply SE3 augmentation to a point clouds and actions.
131 | :param pcd: list of point clouds [[bs, 3, H, W], ...] for N cameras
132 | :param action_gripper_pose: 6-DoF pose of keyframe action [bs, 7]
133 | :param action_trans: discretized translation action [bs, 3]
134 | :param action_rot_grip: discretized rotation and gripper action [bs, 4]
135 | :param bounds: metric scene bounds of voxel grid [bs, 6]
136 | :param layer: voxelization layer (always 1 for PerAct)
137 | :param trans_aug_range: range of translation augmentation [x_range, y_range, z_range]
138 | :param rot_aug_range: range of rotation augmentation [x_range, y_range, z_range]
139 | :param rot_aug_resolution: degree increments for discretized augmentation rotations
140 |     :param voxel_size: voxelization resolution
141 | :param rot_resolution: degree increments for discretized rotations
142 | :param device: torch device
143 | :return: perturbed action_trans, action_rot_grip, pcd
144 | """
145 |
146 | # batch size
147 | bs = pcd[0].shape[0]
148 |
149 | # identity matrix
150 | identity_4x4 = torch.eye(4).unsqueeze(0).repeat(bs, 1, 1).to(device=device)
151 |
152 | # 4x4 matrix of keyframe action gripper pose
153 | action_gripper_trans = action_gripper_pose[:, :3]
154 | action_gripper_quat_wxyz = torch.cat(
155 | (action_gripper_pose[:, 6].unsqueeze(1), action_gripper_pose[:, 3:6]), dim=1
156 | )
157 | action_gripper_rot = torch3d_tf.quaternion_to_matrix(action_gripper_quat_wxyz)
158 | action_gripper_4x4 = identity_4x4.detach().clone()
159 | action_gripper_4x4[:, :3, :3] = action_gripper_rot
160 | action_gripper_4x4[:, 0:3, 3] = action_gripper_trans
161 |
162 | perturbed_trans = torch.full_like(action_trans, -1.0)
163 | perturbed_rot_grip = torch.full_like(action_rot_grip, -1.0)
164 |
165 | # perturb the action, check if it is within bounds, if not, try another perturbation
166 | perturb_attempts = 0
167 | while torch.any(perturbed_trans < 0):
168 | # might take some repeated attempts to find a perturbation that doesn't go out of bounds
169 | perturb_attempts += 1
170 | if perturb_attempts > 100:
171 | raise Exception("Failing to perturb action and keep it within bounds.")
172 |
173 | # sample translation perturbation with specified range
174 | trans_range = (bounds[:, 3:] - bounds[:, :3]) * trans_aug_range.to(
175 | device=device
176 | )
177 | trans_shift = trans_range * aug_utils.rand_dist((bs, 3)).to(device=device)
178 | trans_shift_4x4 = identity_4x4.detach().clone()
179 | trans_shift_4x4[:, 0:3, 3] = trans_shift
180 |
181 | # sample rotation perturbation at specified resolution and range
182 | roll_aug_steps = int(rot_aug_range[0] // rot_aug_resolution)
183 | pitch_aug_steps = int(rot_aug_range[1] // rot_aug_resolution)
184 | yaw_aug_steps = int(rot_aug_range[2] // rot_aug_resolution)
185 |
186 | roll = aug_utils.rand_discrete(
187 | (bs, 1), min=-roll_aug_steps, max=roll_aug_steps
188 | ) * np.deg2rad(rot_aug_resolution)
189 | pitch = aug_utils.rand_discrete(
190 | (bs, 1), min=-pitch_aug_steps, max=pitch_aug_steps
191 | ) * np.deg2rad(rot_aug_resolution)
192 | yaw = aug_utils.rand_discrete(
193 | (bs, 1), min=-yaw_aug_steps, max=yaw_aug_steps
194 | ) * np.deg2rad(rot_aug_resolution)
195 | rot_shift_3x3 = torch3d_tf.euler_angles_to_matrix(
196 | torch.cat((roll, pitch, yaw), dim=1), "XYZ"
197 | )
198 | rot_shift_4x4 = identity_4x4.detach().clone()
199 | rot_shift_4x4[:, :3, :3] = rot_shift_3x3
200 |
201 | # rotate then translate the 4x4 keyframe action
202 | perturbed_action_gripper_4x4 = torch.bmm(action_gripper_4x4, rot_shift_4x4)
203 | perturbed_action_gripper_4x4[:, 0:3, 3] += trans_shift
204 |
205 | # convert transformation matrix to translation + quaternion
206 | perturbed_action_trans = perturbed_action_gripper_4x4[:, 0:3, 3].cpu().numpy()
207 | perturbed_action_quat_wxyz = torch3d_tf.matrix_to_quaternion(
208 | perturbed_action_gripper_4x4[:, :3, :3]
209 | )
210 | perturbed_action_quat_xyzw = (
211 | torch.cat(
212 | [
213 | perturbed_action_quat_wxyz[:, 1:],
214 | perturbed_action_quat_wxyz[:, 0].unsqueeze(1),
215 | ],
216 | dim=1,
217 | )
218 | .cpu()
219 | .numpy()
220 | )
221 |
222 | # discretize perturbed translation and rotation
223 | # TODO(mohit): do this in torch without any numpy.
224 | trans_indicies, rot_grip_indicies = [], []
225 | for b in range(bs):
226 | bounds_idx = b if layer > 0 else 0
227 | bounds_np = bounds[bounds_idx].cpu().numpy()
228 |
229 | trans_idx = aug_utils.point_to_voxel_index(
230 | perturbed_action_trans[b], voxel_size, bounds_np
231 | )
232 | trans_indicies.append(trans_idx.tolist())
233 |
234 |             quat = aug_utils.normalize_quaternion(perturbed_action_quat_xyzw[b])
235 |             # keep a canonical quaternion sign (w >= 0) before discretizing
236 | if quat[-1] < 0:
237 | quat = -quat
238 | disc_rot = aug_utils.quaternion_to_discrete_euler(quat, rot_resolution)
239 | rot_grip_indicies.append(
240 | disc_rot.tolist() + [int(action_rot_grip[b, 3].cpu().numpy())]
241 | )
242 |
243 | # if the perturbed action is out of bounds,
244 | # the discretized perturb_trans should have invalid indicies
245 | perturbed_trans = torch.from_numpy(np.array(trans_indicies)).to(device=device)
246 | perturbed_rot_grip = torch.from_numpy(np.array(rot_grip_indicies)).to(
247 | device=device
248 | )
249 |
250 | action_trans = perturbed_trans
251 | action_rot_grip = perturbed_rot_grip
252 |
253 | # apply perturbation to pointclouds
254 | pcd = perturb_se3(pcd, trans_shift_4x4, rot_shift_4x4, action_gripper_4x4, bounds)
255 |
256 | return action_trans, action_rot_grip, pcd
257 |
258 |
259 | def apply_se3_aug_con(
260 | pcd,
261 | action_gripper_pose,
262 | bounds,
263 | trans_aug_range,
264 | rot_aug_range,
265 | scale_aug_range=False,
266 | single_scale=True,
267 | ver=2,
268 | ):
269 | """Apply SE3 augmentation to a point clouds and actions.
270 | :param pcd: [bs, num_points, 3]
271 | :param action_gripper_pose: 6-DoF pose of keyframe action [bs, 7]
272 | :param bounds: metric scene bounds
273 | Either:
274 | - [bs, 6]
275 | - [6]
276 | :param trans_aug_range: range of translation augmentation
277 | [x_range, y_range, z_range]; this is expressed as the percentage of the scene bound
278 | :param rot_aug_range: range of rotation augmentation [x_range, y_range, z_range]
279 | :param scale_aug_range: range of scale augmentation [x_range, y_range, z_range]
280 | :param single_scale: whether we preserve the relative dimensions
281 | :return: perturbed action_gripper_pose, pcd
282 | """
283 |
284 | # batch size
285 | bs = pcd.shape[0]
286 | device = pcd.device
287 |
288 | if len(bounds.shape) == 1:
289 | bounds = bounds.unsqueeze(0).repeat(bs, 1).to(device)
290 | if len(trans_aug_range.shape) == 1:
291 | trans_aug_range = trans_aug_range.unsqueeze(0).repeat(bs, 1).to(device)
292 | if len(rot_aug_range.shape) == 1:
293 | rot_aug_range = rot_aug_range.unsqueeze(0).repeat(bs, 1)
294 |
295 | # identity matrix
296 | identity_4x4 = torch.eye(4).unsqueeze(0).repeat(bs, 1, 1).to(device=device)
297 |
298 | # 4x4 matrix of keyframe action gripper pose
299 | action_gripper_trans = action_gripper_pose[:, :3]
300 |
301 | if ver == 1:
302 | action_gripper_quat_wxyz = torch.cat(
303 | (action_gripper_pose[:, 6].unsqueeze(1), action_gripper_pose[:, 3:6]), dim=1
304 | )
305 | action_gripper_rot = torch3d_tf.quaternion_to_matrix(action_gripper_quat_wxyz)
306 |
307 | elif ver == 2:
308 | # applying gimble fix to calculate a new action_gripper_rot
309 | r = Rotation.from_quat(action_gripper_pose[:, 3:7].cpu().numpy())
310 | euler = r.as_euler("xyz", degrees=True)
311 | euler = aug_utils.sensitive_gimble_fix(euler)
312 | action_gripper_rot = torch.tensor(
313 | Rotation.from_euler("xyz", euler, degrees=True).as_matrix(),
314 | device=action_gripper_pose.device,
315 | )
316 | else:
317 | assert False
318 |
319 | action_gripper_4x4 = identity_4x4.detach().clone()
320 | action_gripper_4x4[:, :3, :3] = action_gripper_rot
321 | action_gripper_4x4[:, 0:3, 3] = action_gripper_trans
322 |
323 | # sample translation perturbation with specified range
324 | # augmentation range is a percentage of the scene bound
325 | trans_range = (bounds[:, 3:] - bounds[:, :3]) * trans_aug_range.to(device=device)
326 | # rand_dist samples value from -1 to 1
327 | trans_shift = trans_range * aug_utils.rand_dist((bs, 3)).to(device=device)
328 |
329 | # apply bounded translations
330 | bounds_x_min, bounds_x_max = bounds[:, 0], bounds[:, 3]
331 | bounds_y_min, bounds_y_max = bounds[:, 1], bounds[:, 4]
332 | bounds_z_min, bounds_z_max = bounds[:, 2], bounds[:, 5]
333 |
334 | trans_shift[:, 0] = torch.clamp(
335 | trans_shift[:, 0],
336 | min=bounds_x_min - action_gripper_trans[:, 0],
337 | max=bounds_x_max - action_gripper_trans[:, 0],
338 | )
339 | trans_shift[:, 1] = torch.clamp(
340 | trans_shift[:, 1],
341 | min=bounds_y_min - action_gripper_trans[:, 1],
342 | max=bounds_y_max - action_gripper_trans[:, 1],
343 | )
344 | trans_shift[:, 2] = torch.clamp(
345 | trans_shift[:, 2],
346 | min=bounds_z_min - action_gripper_trans[:, 2],
347 | max=bounds_z_max - action_gripper_trans[:, 2],
348 | )
349 |
350 | trans_shift_4x4 = identity_4x4.detach().clone()
351 | trans_shift_4x4[:, 0:3, 3] = trans_shift
352 |
353 | roll = np.deg2rad(rot_aug_range[:, 0:1] * aug_utils.rand_dist((bs, 1)))
354 | pitch = np.deg2rad(rot_aug_range[:, 1:2] * aug_utils.rand_dist((bs, 1)))
355 | yaw = np.deg2rad(rot_aug_range[:, 2:3] * aug_utils.rand_dist((bs, 1)))
356 | rot_shift_3x3 = torch3d_tf.euler_angles_to_matrix(
357 | torch.cat((roll, pitch, yaw), dim=1), "XYZ"
358 | )
359 | rot_shift_4x4 = identity_4x4.detach().clone()
360 | rot_shift_4x4[:, :3, :3] = rot_shift_3x3
361 |
362 | if ver == 1:
363 | # rotate then translate the 4x4 keyframe action
364 | perturbed_action_gripper_4x4 = torch.bmm(action_gripper_4x4, rot_shift_4x4)
365 | elif ver == 2:
366 | perturbed_action_gripper_4x4 = identity_4x4.detach().clone()
367 | perturbed_action_gripper_4x4[:, 0:3, 3] = action_gripper_4x4[:, 0:3, 3]
368 | perturbed_action_gripper_4x4[:, :3, :3] = torch.bmm(
369 | rot_shift_4x4.transpose(1, 2)[:, :3, :3], action_gripper_4x4[:, :3, :3]
370 | )
371 | else:
372 | assert False
373 |
374 | perturbed_action_gripper_4x4[:, 0:3, 3] += trans_shift
375 |
376 | # convert transformation matrix to translation + quaternion
377 | perturbed_action_trans = perturbed_action_gripper_4x4[:, 0:3, 3].cpu().numpy()
378 | perturbed_action_quat_wxyz = torch3d_tf.matrix_to_quaternion(
379 | perturbed_action_gripper_4x4[:, :3, :3]
380 | )
381 | perturbed_action_quat_xyzw = (
382 | torch.cat(
383 | [
384 | perturbed_action_quat_wxyz[:, 1:],
385 | perturbed_action_quat_wxyz[:, 0].unsqueeze(1),
386 | ],
387 | dim=1,
388 | )
389 | .cpu()
390 | .numpy()
391 | )
392 |
393 | # TODO: add scale augmentation
394 |
395 | # apply perturbation to pointclouds
396 | # takes care for not moving the point out of the image
397 | pcd = perturb_se3(pcd, trans_shift_4x4, rot_shift_4x4, action_gripper_4x4, bounds)
398 |
399 | return perturbed_action_trans, perturbed_action_quat_xyzw, pcd
400 |
--------------------------------------------------------------------------------
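
A minimal usage sketch of apply_se3_aug_con above (not part of the repository; the batch size, bounds, and augmentation ranges are illustrative):

    import torch
    from rvt.mvt.augmentation import apply_se3_aug_con

    bs, num_points = 2, 1024
    pcd = torch.rand(bs, num_points, 3)                      # [bs, num_points, 3]

    # gripper pose: position (x, y, z) + quaternion (x, y, z, w)
    action_gripper_pose = torch.zeros(bs, 7)
    action_gripper_pose[:, :3] = torch.tensor([0.2, 0.0, 1.0])  # inside the bounds
    action_gripper_pose[:, 6] = 1.0                             # identity rotation

    bounds = torch.tensor([-0.3, -0.5, 0.6, 0.7, 0.5, 1.6])  # x/y/z min, x/y/z max
    trans_aug_range = torch.tensor([0.125, 0.125, 0.125])    # fraction of the bounds
    rot_aug_range = torch.tensor([0.0, 0.0, 45.0])           # degrees about x, y, z

    new_trans, new_quat_xyzw, new_pcd = apply_se3_aug_con(
        pcd, action_gripper_pose, bounds, trans_aug_range, rot_aug_range
    )
    print(new_trans.shape, new_quat_xyzw.shape, new_pcd.shape)  # (2, 3) (2, 4) [2, 1024, 3]
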
/rvt/mvt/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | from yacs.config import CfgNode as CN
6 |
7 | _C = CN()
8 |
9 | _C.depth = 8
10 | _C.img_size = 220
11 | _C.add_proprio = True
12 | _C.proprio_dim = 4
13 | _C.add_lang = True
14 | _C.lang_dim = 512
15 | _C.lang_len = 77
16 | _C.img_feat_dim = 3
17 | _C.feat_dim = (72 * 3) + 2 + 2
18 | _C.im_channels = 64
19 | _C.attn_dim = 512
20 | _C.attn_heads = 8
21 | _C.attn_dim_head = 64
22 | _C.activation = "lrelu"
23 | _C.weight_tie_layers = False
24 | _C.attn_dropout = 0.1
25 | _C.decoder_dropout = 0.0
26 | _C.img_patch_size = 11
27 | _C.final_dim = 64
28 | _C.self_cross_ver = 1
29 | _C.add_corr = True
30 | _C.norm_corr = False
31 | _C.add_pixel_loc = True
32 | _C.add_depth = True
33 | _C.rend_three_views = False
34 | _C.use_point_renderer = False
35 | _C.pe_fix = True
36 | _C.feat_ver = 0
37 | _C.wpt_img_aug = 0.01
38 | _C.inp_pre_pro = True
39 | _C.inp_pre_con = True
40 | _C.cvx_up = False
41 | _C.xops = False
42 | _C.rot_ver = 0
43 | _C.num_rot = 72
44 | _C.stage_two = False
45 | _C.st_sca = 4
46 | _C.st_wpt_loc_aug = 0.05
47 | _C.st_wpt_loc_inp_no_noise = False
48 | _C.img_aug_2 = 0.0
49 |
50 |
51 | def get_cfg_defaults():
52 | """Get a yacs CfgNode object with default values for my_project."""
53 | return _C.clone()
54 |
--------------------------------------------------------------------------------
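
The defaults above are meant to be overridden with yacs merging, in the same way train.py does (a minimal sketch, not part of the repository; the yaml path is assumed relative to the repo root):

    from rvt.mvt.config import get_cfg_defaults

    mvt_cfg = get_cfg_defaults()

    # overlay the RVT-2 settings from the yaml shown below
    mvt_cfg.merge_from_file("rvt/mvt/configs/rvt2.yaml")

    # command-line style overrides, as train.py does with --mvt_cfg_opts
    mvt_cfg.merge_from_list(["stage_two", "False", "attn_dropout", "0.0"])

    mvt_cfg.freeze()
    print(mvt_cfg.img_size, mvt_cfg.stage_two)  # 224 False
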
/rvt/mvt/configs/rvt2.yaml:
--------------------------------------------------------------------------------
1 | stage_two: True
2 | norm_corr: True
3 | rend_three_views: True
4 | use_point_renderer: True
5 | inp_pre_pro: False
6 | inp_pre_con: False
7 | feat_ver: 1
8 | img_size: 224
9 | img_patch_size: 14
10 | cvx_up: True
11 | xops: True
12 | rot_ver: 1
13 | st_wpt_loc_inp_no_noise: True
14 | wpt_img_aug: 0.0
15 | img_aug_2: 0.05
16 |
--------------------------------------------------------------------------------
/rvt/mvt/configs/rvt2_partial.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/RVT/367995a1a2169b6352bf4e8b0ed405890462a3a0/rvt/mvt/configs/rvt2_partial.yaml
--------------------------------------------------------------------------------
/rvt/mvt/mvt.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | import copy
6 | import torch
7 |
8 | from torch import nn
9 | from torch.cuda.amp import autocast
10 |
11 | import rvt.mvt.utils as mvt_utils
12 |
13 | from rvt.mvt.mvt_single import MVT as MVTSingle
14 | from rvt.mvt.config import get_cfg_defaults
15 | from rvt.mvt.renderer import BoxRenderer
16 |
17 |
18 | class MVT(nn.Module):
19 | def __init__(
20 | self,
21 | depth,
22 | img_size,
23 | add_proprio,
24 | proprio_dim,
25 | add_lang,
26 | lang_dim,
27 | lang_len,
28 | img_feat_dim,
29 | feat_dim,
30 | im_channels,
31 | attn_dim,
32 | attn_heads,
33 | attn_dim_head,
34 | activation,
35 | weight_tie_layers,
36 | attn_dropout,
37 | decoder_dropout,
38 | img_patch_size,
39 | final_dim,
40 | self_cross_ver,
41 | add_corr,
42 | norm_corr,
43 | add_pixel_loc,
44 | add_depth,
45 | rend_three_views,
46 | use_point_renderer,
47 | pe_fix,
48 | feat_ver,
49 | wpt_img_aug,
50 | inp_pre_pro,
51 | inp_pre_con,
52 | cvx_up,
53 | xops,
54 | rot_ver,
55 | num_rot,
56 | stage_two,
57 | st_sca,
58 | st_wpt_loc_aug,
59 | st_wpt_loc_inp_no_noise,
60 | img_aug_2,
61 | renderer_device="cuda:0",
62 | ):
63 | """MultiView Transfomer
64 | :param stage_two: whether or not there are two stages
65 | :param st_sca: scaling of the pc in the second stage
66 | :param st_wpt_loc_aug: how much noise is to be added to wpt_local when
67 | transforming the pc in the second stage while training. This is
68 |             expressed as a fraction of the total pc size, which is 2.
69 | :param st_wpt_loc_inp_no_noise: whether or not to add any noise to the
70 | wpt_local location which is fed to stage_two. This wpt_local
71 | location is used to extract features for rotation prediction
72 | currently. Other use cases might also arise later on. Even if
73 |             st_wpt_loc_aug is nonzero, setting this to True compensates for
74 |             that noise.
75 | :param img_aug_2: similar to img_aug in rvt repo but applied only to
76 | point feat and not the whole point cloud
77 | """
78 | super().__init__()
79 |
80 | self.use_point_renderer = use_point_renderer
81 |         global BoxRenderer  # must precede the local imports that rebind the name
82 |         if self.use_point_renderer:
83 |             from point_renderer.rvt_renderer import RVTBoxRenderer as BoxRenderer
84 |         else:
85 |             from rvt.mvt.renderer import BoxRenderer
86 |
87 |         # creating a dictionary of all the input parameters
88 | args = copy.deepcopy(locals())
89 | del args["self"]
90 | del args["__class__"]
91 | del args["stage_two"]
92 | del args["st_sca"]
93 | del args["st_wpt_loc_aug"]
94 | del args["st_wpt_loc_inp_no_noise"]
95 | del args["img_aug_2"]
96 |
97 | self.rot_ver = rot_ver
98 | self.num_rot = num_rot
99 | self.stage_two = stage_two
100 | self.st_sca = st_sca
101 | self.st_wpt_loc_aug = st_wpt_loc_aug
102 | self.st_wpt_loc_inp_no_noise = st_wpt_loc_inp_no_noise
103 | self.img_aug_2 = img_aug_2
104 |
105 | # for verifying the input
106 | self.feat_ver = feat_ver
107 | self.img_feat_dim = img_feat_dim
108 | self.add_proprio = add_proprio
109 | self.proprio_dim = proprio_dim
110 | self.add_lang = add_lang
111 | if add_lang:
112 | lang_emb_dim, lang_max_seq_len = lang_dim, lang_len
113 | else:
114 | lang_emb_dim, lang_max_seq_len = 0, 0
115 | self.lang_emb_dim = lang_emb_dim
116 | self.lang_max_seq_len = lang_max_seq_len
117 |
118 | self.renderer = BoxRenderer(
119 | device=renderer_device,
120 | img_size=(img_size, img_size),
121 | three_views=rend_three_views,
122 | with_depth=add_depth,
123 | )
124 | self.num_img = self.renderer.num_img
125 | self.proprio_dim = proprio_dim
126 | self.img_size = img_size
127 |
128 | self.mvt1 = MVTSingle(
129 | **args,
130 | renderer=self.renderer,
131 | no_feat=self.stage_two,
132 | )
133 | if self.stage_two:
134 | self.mvt2 = MVTSingle(**args, renderer=self.renderer)
135 |
136 | def get_pt_loc_on_img(self, pt, mvt1_or_mvt2, dyn_cam_info, out=None):
137 | """
138 | :param pt: point for which location on image is to be found. the point
139 |             should be in the same reference frame as wpt_local (see forward()),
140 | even for mvt2
141 |         :param out: output from mvt; when using mvt2, we also need it to provide
142 |             the origin location to which the point cloud needs to be shifted
143 | before estimating the location in the image
144 | """
145 | assert len(pt.shape) == 3
146 | bs, _np, x = pt.shape
147 | assert x == 3
148 |
149 | assert isinstance(mvt1_or_mvt2, bool)
150 | if mvt1_or_mvt2:
151 | assert out is None
152 | out = self.mvt1.get_pt_loc_on_img(pt, dyn_cam_info)
153 | else:
154 | assert self.stage_two
155 | assert out is not None
156 | assert out['wpt_local1'].shape == (bs, 3)
157 | pt, _ = mvt_utils.trans_pc(pt, loc=out["wpt_local1"], sca=self.st_sca)
158 | pt = pt.view(bs, _np, 3)
159 | out = self.mvt2.get_pt_loc_on_img(pt, dyn_cam_info)
160 |
161 | return out
162 |
163 | def get_wpt(self, out, mvt1_or_mvt2, dyn_cam_info, y_q=None):
164 | """
165 |         Estimate the waypoint (wpt) given output from mvt
166 | :param out: output from mvt
167 | :param y_q: refer to the definition in mvt_single.get_wpt
168 | """
169 | assert isinstance(mvt1_or_mvt2, bool)
170 | if mvt1_or_mvt2:
171 | wpt = self.mvt1.get_wpt(
172 | out, dyn_cam_info, y_q,
173 | )
174 | else:
175 | assert self.stage_two
176 | wpt = self.mvt2.get_wpt(
177 | out["mvt2"], dyn_cam_info, y_q
178 | )
179 | wpt = out["rev_trans"](wpt)
180 |
181 | return wpt
182 |
183 | def render(self, pc, img_feat, img_aug, mvt1_or_mvt2, dyn_cam_info):
184 | assert isinstance(mvt1_or_mvt2, bool)
185 | if mvt1_or_mvt2:
186 | mvt = self.mvt1
187 | else:
188 | mvt = self.mvt2
189 |
190 | with torch.no_grad():
191 | with autocast(enabled=False):
192 | if dyn_cam_info is None:
193 | dyn_cam_info_itr = (None,) * len(pc)
194 | else:
195 | dyn_cam_info_itr = dyn_cam_info
196 |
197 | if mvt.add_corr:
198 | if mvt.norm_corr:
199 | img = []
200 | for _pc, _img_feat, _dyn_cam_info in zip(
201 | pc, img_feat, dyn_cam_info_itr
202 | ):
203 | # fix when the pc is empty
204 | max_pc = 1.0 if len(_pc) == 0 else torch.max(torch.abs(_pc))
205 | img.append(
206 | self.renderer(
207 | _pc,
208 | torch.cat((_pc / max_pc, _img_feat), dim=-1),
209 | fix_cam=True,
210 | dyn_cam_info=(_dyn_cam_info,)
211 | if not (_dyn_cam_info is None)
212 | else None,
213 | ).unsqueeze(0)
214 | )
215 | else:
216 | img = [
217 | self.renderer(
218 | _pc,
219 | torch.cat((_pc, _img_feat), dim=-1),
220 | fix_cam=True,
221 | dyn_cam_info=(_dyn_cam_info,)
222 | if not (_dyn_cam_info is None)
223 | else None,
224 | ).unsqueeze(0)
225 | for (_pc, _img_feat, _dyn_cam_info) in zip(
226 | pc, img_feat, dyn_cam_info_itr
227 | )
228 | ]
229 | else:
230 | img = [
231 | self.renderer(
232 | _pc,
233 | _img_feat,
234 | fix_cam=True,
235 | dyn_cam_info=(_dyn_cam_info,)
236 | if not (_dyn_cam_info is None)
237 | else None,
238 | ).unsqueeze(0)
239 | for (_pc, _img_feat, _dyn_cam_info) in zip(
240 | pc, img_feat, dyn_cam_info_itr
241 | )
242 | ]
243 |
244 | img = torch.cat(img, 0)
245 | img = img.permute(0, 1, 4, 2, 3)
246 |
247 | # for visualization purposes
248 | if mvt.add_corr:
249 | mvt.img = img[:, :, 3:].clone().detach()
250 | else:
251 | mvt.img = img.clone().detach()
252 |
253 | # image augmentation
254 | if img_aug != 0:
255 | stdv = img_aug * torch.rand(1, device=img.device)
256 | # values in [-stdv, stdv]
257 | noise = stdv * ((2 * torch.rand(*img.shape, device=img.device)) - 1)
258 | img = torch.clamp(img + noise, -1, 1)
259 |
260 | if mvt.add_pixel_loc:
261 | bs = img.shape[0]
262 | pixel_loc = mvt.pixel_loc.to(img.device)
263 | img = torch.cat(
264 | (img, pixel_loc.unsqueeze(0).repeat(bs, 1, 1, 1, 1)), dim=2
265 | )
266 |
267 | return img
268 |
269 | def verify_inp(
270 | self,
271 | pc,
272 | img_feat,
273 | proprio,
274 | lang_emb,
275 | img_aug,
276 | wpt_local,
277 | rot_x_y,
278 | ):
279 | bs = len(pc)
280 | assert bs == len(img_feat)
281 |
282 | if not self.training:
283 | # no img_aug when not training
284 | assert img_aug == 0
285 | assert rot_x_y is None, f"rot_x_y={rot_x_y}"
286 |
287 | if self.training:
288 | assert (
289 | (not self.feat_ver == 1)
290 | or (not wpt_local is None)
291 | )
292 |
293 | if self.rot_ver == 0:
294 | assert rot_x_y is None, f"rot_x_y={rot_x_y}"
295 | elif self.rot_ver == 1:
296 | assert rot_x_y.shape == (bs, 2), f"rot_x_y.shape={rot_x_y.shape}"
297 | assert (rot_x_y >= 0).all() and (
298 | rot_x_y < self.num_rot
299 | ).all(), f"rot_x_y={rot_x_y}"
300 | else:
301 | assert False
302 |
303 | for _pc, _img_feat in zip(pc, img_feat):
304 | np, x1 = _pc.shape
305 | np2, x2 = _img_feat.shape
306 |
307 | assert np == np2
308 | assert x1 == 3
309 | assert x2 == self.img_feat_dim
310 |
311 | if self.add_proprio:
312 | bs3, x3 = proprio.shape
313 | assert bs == bs3
314 | assert (
315 | x3 == self.proprio_dim
316 | ), "Does not support proprio of shape {proprio.shape}"
317 | else:
318 |             assert proprio is None, f"Invalid input for proprio={proprio}"
319 |
320 | if self.add_lang:
321 | bs4, x4, x5 = lang_emb.shape
322 | assert bs == bs4
323 | assert (
324 | x4 == self.lang_max_seq_len
325 | ), "Does not support lang_emb of shape {lang_emb.shape}"
326 | assert (
327 | x5 == self.lang_emb_dim
328 | ), "Does not support lang_emb of shape {lang_emb.shape}"
329 | else:
330 | assert (lang_emb is None) or (
331 | torch.all(lang_emb == 0)
332 | ), f"Invalid input for lang={lang}"
333 |
334 | if not (wpt_local is None):
335 | bs5, x6 = wpt_local.shape
336 | assert bs == bs5
337 |             assert x6 == 3, f"Does not support wpt_local of shape {wpt_local.shape}"
338 |
339 | if self.training:
340 | assert (not self.stage_two) or (not wpt_local is None)
341 |
342 | def forward(
343 | self,
344 | pc,
345 | img_feat,
346 | proprio=None,
347 | lang_emb=None,
348 | img_aug=0,
349 | wpt_local=None,
350 | rot_x_y=None,
351 | **kwargs,
352 | ):
353 | """
354 | :param pc: list of tensors, each tensor of shape (num_points, 3)
355 |         :param img_feat: list of tensors, each tensor of shape
356 |             (num_points, img_feat_dim)
357 |         :param proprio: tensor of shape (bs, proprio_dim)
358 | :param lang_emb: tensor of shape (bs, lang_len, lang_dim)
359 | :param img_aug: (float) magnitude of augmentation in rgb image
360 | :param wpt_local: gt location of the wpt in 3D, tensor of shape
361 | (bs, 3)
362 | :param rot_x_y: (bs, 2) rotation in x and y direction
363 | """
364 | self.verify_inp(
365 | pc=pc,
366 | img_feat=img_feat,
367 | proprio=proprio,
368 | lang_emb=lang_emb,
369 | img_aug=img_aug,
370 | wpt_local=wpt_local,
371 | rot_x_y=rot_x_y,
372 | )
373 | with torch.no_grad():
374 | if self.training and (self.img_aug_2 != 0):
375 | for x in img_feat:
376 | stdv = self.img_aug_2 * torch.rand(1, device=x.device)
377 | # values in [-stdv, stdv]
378 | noise = stdv * ((2 * torch.rand(*x.shape, device=x.device)) - 1)
379 |                     x += noise  # in-place so the entry in img_feat is actually augmented
380 | img = self.render(
381 | pc=pc,
382 | img_feat=img_feat,
383 | img_aug=img_aug,
384 | mvt1_or_mvt2=True,
385 | dyn_cam_info=None,
386 | )
387 |
388 | if self.training:
389 |             # detach so that gradients do not flow through the gt waypoint
390 |             wpt_local_stage_one = wpt_local.clone().detach()
391 | else:
392 | wpt_local_stage_one = wpt_local
393 |
394 | out = self.mvt1(
395 | img=img,
396 | proprio=proprio,
397 | lang_emb=lang_emb,
398 | wpt_local=wpt_local_stage_one,
399 | rot_x_y=rot_x_y,
400 | **kwargs,
401 | )
402 |
403 | if self.stage_two:
404 | with torch.no_grad():
405 |                 # adding the noisy location for training
406 | if self.training:
407 | # noise is added so that the wpt_local2 is not exactly at
408 | # the center of the pc
409 | wpt_local_stage_one_noisy = mvt_utils.add_uni_noi(
410 | wpt_local_stage_one.clone().detach(), 2 * self.st_wpt_loc_aug
411 | )
412 | pc, rev_trans = mvt_utils.trans_pc(
413 | pc, loc=wpt_local_stage_one_noisy, sca=self.st_sca
414 | )
415 |
416 | if self.st_wpt_loc_inp_no_noise:
417 | wpt_local2, _ = mvt_utils.trans_pc(
418 | wpt_local, loc=wpt_local_stage_one_noisy, sca=self.st_sca
419 | )
420 | else:
421 | wpt_local2, _ = mvt_utils.trans_pc(
422 | wpt_local, loc=wpt_local_stage_one, sca=self.st_sca
423 | )
424 |
425 | else:
426 | # bs, 3
427 | wpt_local = self.get_wpt(
428 | out, y_q=None, mvt1_or_mvt2=True,
429 | dyn_cam_info=None,
430 | )
431 | pc, rev_trans = mvt_utils.trans_pc(
432 | pc, loc=wpt_local, sca=self.st_sca
433 | )
434 | # bad name!
435 | wpt_local_stage_one_noisy = wpt_local
436 |
437 | # must pass None to mvt2 while in eval
438 | wpt_local2 = None
439 |
440 | img = self.render(
441 | pc=pc,
442 | img_feat=img_feat,
443 | img_aug=img_aug,
444 | mvt1_or_mvt2=False,
445 | dyn_cam_info=None,
446 | )
447 |
448 | out_mvt2 = self.mvt2(
449 | img=img,
450 | proprio=proprio,
451 | lang_emb=lang_emb,
452 | wpt_local=wpt_local2,
453 | rot_x_y=rot_x_y,
454 | **kwargs,
455 | )
456 |
457 | out["wpt_local1"] = wpt_local_stage_one_noisy
458 | out["rev_trans"] = rev_trans
459 | out["mvt2"] = out_mvt2
460 |
461 | return out
462 |
463 | def free_mem(self):
464 | """
465 | Could be used for freeing up the memory once a batch of testing is done
466 | """
467 | if not self.use_point_renderer:
468 | print("Freeing up some memory")
469 | self.renderer.free_mem()
470 |
471 |
472 | if __name__ == "__main__":
473 | cfg = get_cfg_defaults()
474 | mvt = MVT(**cfg)
475 | breakpoint()
476 |
--------------------------------------------------------------------------------
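
The second stage above zooms into the region around the stage-one waypoint: the point cloud is re-centered at the (possibly noisy) waypoint and scaled by st_sca, and the stage-two prediction is mapped back through rev_trans. A standalone sketch of that transform (not part of the repository; it only exercises mvt_utils.trans_pc, defined further below):

    import torch
    import rvt.mvt.utils as mvt_utils

    bs, num_points, st_sca = 2, 512, 4.0
    pc = (2 * torch.rand(bs, num_points, 3)) - 1   # points inside the (2, 2, 2) cube
    wpt_local = (2 * torch.rand(bs, 3)) - 1        # stage-one waypoint estimate

    # zoom in: re-center around the waypoint and scale, as done before calling mvt2
    pc_zoomed, rev_trans = mvt_utils.trans_pc(pc, loc=wpt_local, sca=st_sca)

    # a stage-two prediction at the origin of the zoomed frame maps back to the
    # stage-one waypoint in the original frame
    wpt_zoomed = torch.zeros(bs, 3)
    print(torch.allclose(rev_trans(wpt_zoomed), wpt_local))  # True
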
/rvt/mvt/raft_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from mvt.utils import ForkedPdb
9 |
10 |
11 | class ConvexUpSample(nn.Module):
12 | """
13 | Learned convex upsampling
14 | """
15 |
16 | def __init__(
17 | self, in_dim, out_dim, up_ratio, up_kernel=3, mask_scale=0.1, with_bn=False
18 | ):
19 | """
20 |         :param in_dim: number of input channels
21 |         :param out_dim: number of output channels
22 |         :param up_ratio: spatial upsampling factor
23 |         :param up_kernel: size of the (odd) neighborhood used for the convex combination
24 |         :param mask_scale: scale applied to the predicted mask logits before the softmax
25 |         :param with_bn: whether to use BatchNorm2d in the low-resolution output branch
26 | """
27 | super().__init__()
28 | self.in_dim = in_dim
29 | self.out_dim = out_dim
30 | self.up_ratio = up_ratio
31 | self.up_kernel = up_kernel
32 | self.mask_scale = mask_scale
33 | self.with_bn = with_bn
34 |
35 | assert (self.up_kernel % 2) == 1
36 |
37 | if with_bn:
38 | self.net_out_bn1 = nn.BatchNorm2d(2 * in_dim)
39 | self.net_out_bn2 = nn.BatchNorm2d(2 * in_dim)
40 |
41 | self.net_out = nn.Sequential(
42 | nn.Conv2d(in_dim, 2 * in_dim, 3, padding=1),
43 | nn.ReLU(inplace=True),
44 | nn.Conv2d(2 * in_dim, 2 * in_dim, 3, padding=1),
45 | nn.ReLU(inplace=True),
46 | nn.Conv2d(2 * in_dim, out_dim, 3, padding=1),
47 | )
48 |
49 | mask_dim = (self.up_ratio**2) * (self.up_kernel**2)
50 | self.net_mask = nn.Sequential(
51 | nn.Conv2d(in_dim, 2 * in_dim, 3, padding=1),
52 | nn.ReLU(inplace=True),
53 | nn.Conv2d(2 * in_dim, mask_dim, 1, padding=0),
54 | )
55 |
56 | def forward(self, x):
57 | """
58 |
59 | :param x: (bs, in_dim, h, w)
60 | :return: (bs, out_dim, h*up_ratio, w*up_ratio)
61 | """
62 |
63 | bs, c, h, w = x.shape
64 | assert c == self.in_dim, c
65 |
66 | # low resolution output
67 | if self.with_bn:
68 | out_low = self.net_out[0](x)
69 | out_low = self.net_out_bn1(out_low)
70 | out_low = self.net_out[1](out_low)
71 | out_low = self.net_out[2](out_low)
72 | out_low = self.net_out_bn2(out_low)
73 | out_low = self.net_out[3](out_low)
74 | out_low = self.net_out[4](out_low)
75 | else:
76 | out_low = self.net_out(x)
77 |
78 | mask = self.mask_scale * self.net_mask(x)
79 | mask = mask.view(bs, 1, self.up_kernel**2, self.up_ratio, self.up_ratio, h, w)
80 | mask = torch.softmax(mask, dim=2)
81 |
82 | out = F.unfold(
83 | out_low,
84 | kernel_size=[self.up_kernel, self.up_kernel],
85 | padding=self.up_kernel // 2,
86 | )
87 | out = out.view(bs, self.out_dim, self.up_kernel**2, 1, 1, h, w)
88 |
89 | out = torch.sum(out * mask, dim=2)
90 | out = out.permute(0, 1, 4, 2, 5, 3)
91 | out = out.reshape(bs, self.out_dim, h * self.up_ratio, w * self.up_ratio)
92 |
93 | return out
94 |
95 |
96 | if __name__ == "__main__":
97 | net = ConvexUpSample(2, 5, 20).cuda()
98 |     x = torch.rand(4, 2, 10, 10).cuda()  # channel dim must match in_dim
99 | y = net(x)
100 | breakpoint()
101 |
--------------------------------------------------------------------------------
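
A shape-level sketch of ConvexUpSample above (not part of the repository; the channel count is illustrative, while the 14x ratio roughly matches the rvt2 patch size and 224-pixel images):

    import torch
    from rvt.mvt.raft_utils import ConvexUpSample

    net = ConvexUpSample(in_dim=64, out_dim=1, up_ratio=14)
    x = torch.rand(2, 64, 16, 16)  # (bs, in_dim, h, w)
    y = net(x)
    print(y.shape)                 # torch.Size([2, 1, 224, 224])

    # the predicted weights go through a softmax, so every upsampled value is a
    # convex combination of a 3x3 neighborhood of the low-resolution output
    mask = torch.softmax(
        net.mask_scale * net.net_mask(x).view(2, 1, 9, 14, 14, 16, 16), dim=2
    )
    s = mask.sum(dim=2)
    print(torch.allclose(s, torch.ones_like(s)))  # True
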
/rvt/mvt/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | """
6 | Utility function for MVT
7 | """
8 | import pdb
9 | import sys
10 |
11 | import torch
12 | import numpy as np
13 | import matplotlib.pyplot as plt
14 |
15 |
16 | def place_pc_in_cube(
17 | pc, app_pc=None, with_mean_or_bounds=True, scene_bounds=None, no_op=False
18 | ):
19 | """
20 | calculate the transformation that would place the point cloud (pc) inside a
21 | cube of size (2, 2, 2). The pc is centered at mean if with_mean_or_bounds
22 | is True. If with_mean_or_bounds is False, pc is centered around the mid
23 | point of the bounds. The transformation is applied to point cloud app_pc if
24 | it is not None. If app_pc is None, the transformation is applied on pc.
25 | :param pc: pc of shape (num_points_1, 3)
26 | :param app_pc:
27 | Either
28 | - pc of shape (num_points_2, 3)
29 | - None
30 | :param with_mean_or_bounds:
31 | Either:
32 | True: pc is centered around its mean
33 | False: pc is centered around the center of the scene bounds
34 | :param scene_bounds: [x_min, y_min, z_min, x_max, y_max, z_max]
35 | :param no_op: if no_op, then this function does not do any operation
36 | """
37 | if no_op:
38 | if app_pc is None:
39 | app_pc = torch.clone(pc)
40 |
41 | return app_pc, lambda x: x
42 |
43 | if with_mean_or_bounds:
44 | assert scene_bounds is None
45 | else:
46 | assert not (scene_bounds is None)
47 | if with_mean_or_bounds:
48 | pc_mid = (torch.max(pc, 0)[0] + torch.min(pc, 0)[0]) / 2
49 | x_len, y_len, z_len = torch.max(pc, 0)[0] - torch.min(pc, 0)[0]
50 | else:
51 | x_min, y_min, z_min, x_max, y_max, z_max = scene_bounds
52 | pc_mid = torch.tensor(
53 | [
54 | (x_min + x_max) / 2,
55 | (y_min + y_max) / 2,
56 | (z_min + z_max) / 2,
57 | ]
58 | ).to(pc.device)
59 | x_len, y_len, z_len = x_max - x_min, y_max - y_min, z_max - z_min
60 |
61 | scale = 2 / max(x_len, y_len, z_len)
62 | if app_pc is None:
63 | app_pc = torch.clone(pc)
64 | app_pc = (app_pc - pc_mid) * scale
65 |
66 | # reverse transformation to obtain app_pc in original frame
67 | def rev_trans(x):
68 | return (x / scale) + pc_mid
69 |
70 | return app_pc, rev_trans
71 |
72 |
73 | def trans_pc(pc, loc, sca):
74 | """
75 | change location of the center of the pc and scale it
76 | :param pc:
77 | either:
78 | - tensor of shape(b, num_points, 3)
79 | - tensor of shape(b, 3)
80 | - list of pc each with size (num_points, 3)
81 | :param loc: (b, 3 )
82 | :param sca: 1 or (3)
83 | """
84 | assert len(loc.shape) == 2
85 | assert loc.shape[-1] == 3
86 | if isinstance(pc, list):
87 | assert all([(len(x.shape) == 2) and (x.shape[1] == 3) for x in pc])
88 | pc = [sca * (x - y) for x, y in zip(pc, loc)]
89 | elif isinstance(pc, torch.Tensor):
90 | assert len(pc.shape) in [2, 3]
91 | assert pc.shape[-1] == 3
92 | if len(pc.shape) == 2:
93 | pc = sca * (pc - loc)
94 | else:
95 | pc = sca * (pc - loc.unsqueeze(1))
96 | else:
97 | assert False
98 |
99 |     # reverse transformation to map points back to the original frame
100 | def rev_trans(x):
101 | assert isinstance(x, torch.Tensor)
102 | return (x / sca) + loc
103 |
104 | return pc, rev_trans
105 |
106 |
107 | def add_uni_noi(x, u):
108 | """
109 | adds uniform noise to a tensor x. output is tensor where each element is
110 | in [x-u, x+u]
111 | :param x: tensor
112 | :param u: float
113 | """
114 | assert isinstance(u, float)
115 | # move noise in -1 to 1
116 | noise = (2 * torch.rand(*x.shape, device=x.device)) - 1
117 | x = x + (u * noise)
118 | return x
119 |
120 |
121 | def generate_hm_from_pt(pt, res, sigma, thres_sigma_times=3):
122 | """
123 |     Pytorch code to generate heatmaps from points. Heatmap values smaller than
124 |     the threshold (determined by thres_sigma_times) are set to 0.
125 |     :type pt: torch.FloatTensor of size (num_pt, 2)
126 |     :type res: int or (int, int)
127 |     :param sigma: the std of the gaussian distribution. if it is -1, we
128 |     generate a hm with a one hot vector
129 |     :type sigma: float
130 |     :type thres_sigma_times: int
131 | """
132 | num_pt, x = pt.shape
133 | assert x == 2
134 |
135 | if isinstance(res, int):
136 | resx = resy = res
137 | else:
138 | resx, resy = res
139 |
140 | _hmx = torch.arange(0, resy).to(pt.device)
141 | _hmx = _hmx.view([1, resy]).repeat(resx, 1).view([resx, resy, 1])
142 | _hmy = torch.arange(0, resx).to(pt.device)
143 | _hmy = _hmy.view([resx, 1]).repeat(1, resy).view([resx, resy, 1])
144 | hm = torch.cat([_hmx, _hmy], dim=-1)
145 | hm = hm.view([1, resx, resy, 2]).repeat(num_pt, 1, 1, 1)
146 |
147 | pt = pt.view([num_pt, 1, 1, 2])
148 | hm = torch.exp(-1 * torch.sum((hm - pt) ** 2, -1) / (2 * (sigma**2)))
149 | thres = np.exp(-1 * (thres_sigma_times**2) / 2)
150 | hm[hm < thres] = 0.0
151 |
152 | hm /= torch.sum(hm, (1, 2), keepdim=True) + 1e-6
153 |
154 | # TODO: make a more efficient version
155 | if sigma == -1:
156 | _hm = hm.view(num_pt, resx * resy)
157 | hm = torch.zeros((num_pt, resx * resy), device=hm.device)
158 | temp = torch.arange(num_pt).to(hm.device)
159 | hm[temp, _hm.argmax(-1)] = 1
160 |
161 | return hm
162 |
163 |
164 | class ForkedPdb(pdb.Pdb):
165 | """A Pdb subclass that may be used
166 | from a forked multiprocessing child
167 | """
168 |
169 | def interaction(self, *args, **kwargs):
170 | _stdin = sys.stdin
171 | try:
172 | sys.stdin = open("/dev/stdin")
173 | pdb.Pdb.interaction(self, *args, **kwargs)
174 | finally:
175 | sys.stdin = _stdin
176 |
--------------------------------------------------------------------------------
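
A small sketch of generate_hm_from_pt above (not part of the repository; the resolution and sigma are illustrative):

    import torch
    from rvt.mvt.utils import generate_hm_from_pt

    # two 2D points in pixel coordinates on a 224 x 224 grid
    pt = torch.tensor([[10.0, 20.0], [100.0, 50.0]])

    hm = generate_hm_from_pt(pt, res=224, sigma=1.5)
    print(hm.shape)                # torch.Size([2, 224, 224])
    print(hm.view(2, -1).sum(-1))  # each heatmap sums to ~1

    # sigma=-1 places a single one-hot spike at the nearest pixel instead
    hm_onehot = generate_hm_from_pt(pt, res=224, sigma=-1)
    print(hm_onehot.sum(), hm_onehot.max())  # tensor(2.) tensor(1.)
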
/rvt/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | import os
6 | import time
7 | import tqdm
8 | import random
9 | import yaml
10 | import argparse
11 |
12 | from collections import defaultdict
13 | from contextlib import redirect_stdout
14 |
15 | import torch
16 | import torch.multiprocessing as mp
17 | import torch.distributed as dist
18 | from torch.nn.parallel import DistributedDataParallel as DDP
19 |
20 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
21 | os.environ["BITSANDBYTES_NOWELCOME"] = "1"
22 |
23 | import config as exp_cfg_mod
24 | import rvt.models.rvt_agent as rvt_agent
25 | import rvt.utils.ddp_utils as ddp_utils
26 | import rvt.mvt.config as mvt_cfg_mod
27 |
28 | from rvt.mvt.mvt import MVT
29 | from rvt.models.rvt_agent import print_eval_log, print_loss_log
30 | from rvt.utils.get_dataset import get_dataset
31 | from rvt.utils.rvt_utils import (
32 | TensorboardManager,
33 | short_name,
34 | get_num_feat,
35 | load_agent,
36 | RLBENCH_TASKS,
37 | )
38 | from rvt.utils.peract_utils import (
39 | CAMERAS,
40 | SCENE_BOUNDS,
41 | IMAGE_SIZE,
42 | DATA_FOLDER,
43 | )
44 |
45 |
46 | # new train takes the dataset as input
47 | def train(agent, dataset, training_iterations, rank=0):
48 | agent.train()
49 | log = defaultdict(list)
50 |
51 | data_iter = iter(dataset)
52 | iter_command = range(training_iterations)
53 |
54 | for iteration in tqdm.tqdm(
55 | iter_command, disable=(rank != 0), position=0, leave=True
56 | ):
57 |
58 | raw_batch = next(data_iter)
59 | batch = {
60 | k: v.to(agent._device)
61 | for k, v in raw_batch.items()
62 | if type(v) == torch.Tensor
63 | }
64 | batch["tasks"] = raw_batch["tasks"]
65 | batch["lang_goal"] = raw_batch["lang_goal"]
66 | update_args = {
67 | "step": iteration,
68 | }
69 | update_args.update(
70 | {
71 | "replay_sample": batch,
72 | "backprop": True,
73 | "reset_log": (iteration == 0),
74 | "eval_log": False,
75 | }
76 | )
77 | agent.update(**update_args)
78 |
79 | if rank == 0:
80 | log = print_loss_log(agent)
81 |
82 | return log
83 |
84 |
85 | def save_agent(agent, path, epoch):
86 | model = agent._network
87 | optimizer = agent._optimizer
88 | lr_sched = agent._lr_sched
89 |
90 | if isinstance(model, DDP):
91 | model_state = model.module.state_dict()
92 | else:
93 | model_state = model.state_dict()
94 |
95 | torch.save(
96 | {
97 | "epoch": epoch,
98 | "model_state": model_state,
99 | "optimizer_state": optimizer.state_dict(),
100 | "lr_sched_state": lr_sched.state_dict(),
101 | },
102 | path,
103 | )
104 |
105 |
106 | def get_tasks(exp_cfg):
107 | parsed_tasks = exp_cfg.tasks.split(",")
108 | if parsed_tasks[0] == "all":
109 | tasks = RLBENCH_TASKS
110 | else:
111 | tasks = parsed_tasks
112 | return tasks
113 |
114 |
115 | def get_logdir(cmd_args, exp_cfg):
116 | log_dir = os.path.join(cmd_args.log_dir, exp_cfg.exp_id)
117 | os.makedirs(log_dir, exist_ok=True)
118 | return log_dir
119 |
120 |
121 | def dump_log(exp_cfg, mvt_cfg, cmd_args, log_dir):
122 | with open(f"{log_dir}/exp_cfg.yaml", "w") as yaml_file:
123 | with redirect_stdout(yaml_file):
124 | print(exp_cfg.dump())
125 |
126 | with open(f"{log_dir}/mvt_cfg.yaml", "w") as yaml_file:
127 | with redirect_stdout(yaml_file):
128 | print(mvt_cfg.dump())
129 |
130 | args = cmd_args.__dict__
131 | with open(f"{log_dir}/args.yaml", "w") as yaml_file:
132 | yaml.dump(args, yaml_file)
133 |
134 |
135 | def experiment(rank, cmd_args, devices, port):
136 | """experiment.
137 |
138 |     :param rank: rank of the current process (0 is the logging process)
139 |     :param cmd_args: parsed command line arguments
140 |     :param devices: list of GPU ids; DDP is used when more than one is given
141 | """
142 | device = devices[rank]
143 | device = f"cuda:{device}"
144 | ddp = len(devices) > 1
145 | ddp_utils.setup(rank, world_size=len(devices), port=port)
146 |
147 | exp_cfg = exp_cfg_mod.get_cfg_defaults()
148 | if cmd_args.exp_cfg_path != "":
149 | exp_cfg.merge_from_file(cmd_args.exp_cfg_path)
150 | if cmd_args.exp_cfg_opts != "":
151 | exp_cfg.merge_from_list(cmd_args.exp_cfg_opts.split(" "))
152 |
153 | if ddp:
154 | print(f"Running DDP on rank {rank}.")
155 |
156 | old_exp_cfg_peract_lr = exp_cfg.peract.lr
157 | old_exp_cfg_exp_id = exp_cfg.exp_id
158 |
159 | exp_cfg.peract.lr *= len(devices) * exp_cfg.bs
160 | if cmd_args.exp_cfg_opts != "":
161 | exp_cfg.exp_id += f"_{short_name(cmd_args.exp_cfg_opts)}"
162 | if cmd_args.mvt_cfg_opts != "":
163 | exp_cfg.exp_id += f"_{short_name(cmd_args.mvt_cfg_opts)}"
164 |
165 | if rank == 0:
166 | print(f"dict(exp_cfg)={dict(exp_cfg)}")
167 | exp_cfg.freeze()
168 |
169 | # Things to change
170 | BATCH_SIZE_TRAIN = exp_cfg.bs
171 | NUM_TRAIN = 100
172 | # to match peract, iterations per epoch
173 | TRAINING_ITERATIONS = int(exp_cfg.train_iter // (exp_cfg.bs * len(devices)))
174 | EPOCHS = exp_cfg.epochs
175 | TRAIN_REPLAY_STORAGE_DIR = "replay/replay_train"
176 | TEST_REPLAY_STORAGE_DIR = "replay/replay_val"
177 | log_dir = get_logdir(cmd_args, exp_cfg)
178 | tasks = get_tasks(exp_cfg)
179 | print("Training on {} tasks: {}".format(len(tasks), tasks))
180 |
181 | t_start = time.time()
182 | get_dataset_func = lambda: get_dataset(
183 | tasks,
184 | BATCH_SIZE_TRAIN,
185 | None,
186 | TRAIN_REPLAY_STORAGE_DIR,
187 | None,
188 | DATA_FOLDER,
189 | NUM_TRAIN,
190 | None,
191 | cmd_args.refresh_replay,
192 | device,
193 | num_workers=exp_cfg.num_workers,
194 | only_train=True,
195 | sample_distribution_mode=exp_cfg.sample_distribution_mode,
196 | )
197 | train_dataset, _ = get_dataset_func()
198 | t_end = time.time()
199 | print("Created Dataset. Time Cost: {} minutes".format((t_end - t_start) / 60.0))
200 |
201 | if exp_cfg.agent == "our":
202 | mvt_cfg = mvt_cfg_mod.get_cfg_defaults()
203 | if cmd_args.mvt_cfg_path != "":
204 | mvt_cfg.merge_from_file(cmd_args.mvt_cfg_path)
205 | if cmd_args.mvt_cfg_opts != "":
206 | mvt_cfg.merge_from_list(cmd_args.mvt_cfg_opts.split(" "))
207 |
208 | mvt_cfg.feat_dim = get_num_feat(exp_cfg.peract)
209 | mvt_cfg.freeze()
210 |
211 | # for maintaining backward compatibility
212 | assert mvt_cfg.num_rot == exp_cfg.peract.num_rotation_classes, print(
213 | mvt_cfg.num_rot, exp_cfg.peract.num_rotation_classes
214 | )
215 |
216 | torch.cuda.set_device(device)
217 | torch.cuda.empty_cache()
218 | rvt = MVT(
219 | renderer_device=device,
220 | **mvt_cfg,
221 | ).to(device)
222 | if ddp:
223 | rvt = DDP(rvt, device_ids=[device])
224 |
225 | agent = rvt_agent.RVTAgent(
226 | network=rvt,
227 | image_resolution=[IMAGE_SIZE, IMAGE_SIZE],
228 | add_lang=mvt_cfg.add_lang,
229 | stage_two=mvt_cfg.stage_two,
230 | rot_ver=mvt_cfg.rot_ver,
231 | scene_bounds=SCENE_BOUNDS,
232 | cameras=CAMERAS,
233 | log_dir=f"{log_dir}/test_run/",
234 | cos_dec_max_step=EPOCHS * TRAINING_ITERATIONS,
235 | **exp_cfg.peract,
236 | **exp_cfg.rvt,
237 | )
238 | agent.build(training=True, device=device)
239 | else:
240 | assert False, "Incorrect agent"
241 |
242 | start_epoch = 0
243 | end_epoch = EPOCHS
244 | if exp_cfg.resume != "":
245 | agent_path = exp_cfg.resume
246 | print(f"Recovering model and checkpoint from {exp_cfg.resume}")
247 | epoch = load_agent(agent_path, agent, only_epoch=False)
248 | start_epoch = epoch + 1
249 | dist.barrier()
250 |
251 | if rank == 0:
252 | ## logging unchanged values to reproduce the same setting
253 | temp1 = exp_cfg.peract.lr
254 | temp2 = exp_cfg.exp_id
255 | exp_cfg.defrost()
256 | exp_cfg.peract.lr = old_exp_cfg_peract_lr
257 | exp_cfg.exp_id = old_exp_cfg_exp_id
258 | dump_log(exp_cfg, mvt_cfg, cmd_args, log_dir)
259 | exp_cfg.peract.lr = temp1
260 | exp_cfg.exp_id = temp2
261 | exp_cfg.freeze()
262 | tb = TensorboardManager(log_dir)
263 |
264 | print("Start training ...", flush=True)
265 | i = start_epoch
266 | while True:
267 | if i == end_epoch:
268 | break
269 |
270 | print(f"Rank [{rank}], Epoch [{i}]: Training on train dataset")
271 | out = train(agent, train_dataset, TRAINING_ITERATIONS, rank)
272 |
273 | if rank == 0:
274 | tb.update("train", i, out)
275 |
276 | if rank == 0:
277 | # TODO: add logic to only save some models
278 | save_agent(agent, f"{log_dir}/model_{i}.pth", i)
279 | save_agent(agent, f"{log_dir}/model_last.pth", i)
280 | i += 1
281 |
282 | if rank == 0:
283 | tb.close()
284 | print("[Finish]")
285 |
286 |
287 | if __name__ == "__main__":
288 | parser = argparse.ArgumentParser()
289 | parser.set_defaults(entry=lambda cmd_args: parser.print_help())
290 |
291 | parser.add_argument("--refresh_replay", action="store_true", default=False)
292 | parser.add_argument("--device", type=str, default="0")
293 | parser.add_argument("--mvt_cfg_path", type=str, default="")
294 | parser.add_argument("--exp_cfg_path", type=str, default="")
295 |
296 | parser.add_argument("--mvt_cfg_opts", type=str, default="")
297 | parser.add_argument("--exp_cfg_opts", type=str, default="")
298 |
299 | parser.add_argument("--log-dir", type=str, default="runs")
300 | parser.add_argument("--with-eval", action="store_true", default=False)
301 |
302 | cmd_args = parser.parse_args()
303 | del (
304 | cmd_args.entry
305 | ) # hack for multi processing -- removes an argument called entry which is not picklable
306 |
307 | devices = cmd_args.device.split(",")
308 | devices = [int(x) for x in devices]
309 |
310 | port = (random.randint(0, 3000) % 3000) + 27000
311 | mp.spawn(experiment, args=(cmd_args, devices, port), nprocs=len(devices), join=True)
312 |
--------------------------------------------------------------------------------
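
For reference, a small numeric sketch of how experiment() above scales the learning rate and derives the per-epoch iteration count (not part of the repository; the base lr, batch size, and iteration budget are assumed example values, not the shipped defaults):

    # assumed example values
    base_lr = 1.25e-5   # exp_cfg.peract.lr
    bs = 3              # exp_cfg.bs (per-device batch size)
    train_iter = 100_000
    num_devices = 2     # e.g. --device 0,1

    # lr is scaled by the total effective batch size (devices * per-device batch)
    effective_lr = base_lr * num_devices * bs

    # iterations per epoch shrink accordingly, keeping the sample budget fixed
    training_iterations = train_iter // (bs * num_devices)

    print(effective_lr, training_iterations)  # 7.5e-05 16666
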
/rvt/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
--------------------------------------------------------------------------------
/rvt/utils/custom_rlbench_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | from rvt.libs.peract.helpers.custom_rlbench_env import CustomMultiTaskRLBenchEnv
6 |
7 |
8 | class CustomMultiTaskRLBenchEnv2(CustomMultiTaskRLBenchEnv):
9 | def __init__(self, *args, **kwargs):
10 | super(CustomMultiTaskRLBenchEnv2, self).__init__(*args, **kwargs)
11 |
12 | def reset(self) -> dict:
13 | super().reset()
14 | self._record_current_episode = (
15 | self.eval
16 | and self._record_every_n > 0
17 | and self._episode_index % self._record_every_n == 0
18 | )
19 | return self._previous_obs_dict
20 |
21 | def reset_to_demo(self, i, variation_number=-1):
22 | if self._episodes_this_task == self._swap_task_every:
23 | self._set_new_task()
24 | self._episodes_this_task = 0
25 | self._episodes_this_task += 1
26 |
27 | self._i = 0
28 | self._task.set_variation(-1)
29 | d = self._task.get_demos(
30 | 1, live_demos=False, random_selection=False, from_episode_number=i
31 | )[0]
32 |
33 | self._task.set_variation(d.variation_number)
34 | desc, obs = self._task.reset_to_demo(d)
35 | self._lang_goal = desc[0]
36 |
37 | self._previous_obs_dict = self.extract_obs(obs)
38 | self._record_current_episode = (
39 | self.eval
40 | and self._record_every_n > 0
41 | and self._episode_index % self._record_every_n == 0
42 | )
43 | self._episode_index += 1
44 | self._recorded_images.clear()
45 |
46 | return self._previous_obs_dict
47 |
--------------------------------------------------------------------------------
/rvt/utils/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | # initial source: https://colab.research.google.com/drive/1HAqemP4cE81SQ6QO1-N85j5bF4C0qLs0?usp=sharing
6 | # adapted to support loading from disk for faster initialization time
7 |
8 | # Adapted from: https://github.com/stepjam/ARM/blob/main/arm/c2farm/launch_utils.py
9 | import os
10 | import torch
11 | import pickle
12 | import logging
13 | import numpy as np
14 | from typing import List
15 |
16 | import clip
17 | import peract_colab.arm.utils as utils
18 |
19 | from peract_colab.rlbench.utils import get_stored_demo
20 | from yarr.utils.observation_type import ObservationElement
21 | from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
22 | from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
23 | from rlbench.backend.observation import Observation
24 | from rlbench.demo import Demo
25 |
26 | from rvt.utils.peract_utils import LOW_DIM_SIZE, IMAGE_SIZE, CAMERAS
27 | from rvt.libs.peract.helpers.demo_loading_utils import keypoint_discovery
28 | from rvt.libs.peract.helpers.utils import extract_obs
29 |
30 |
31 | def create_replay(
32 | batch_size: int,
33 | timesteps: int,
34 | disk_saving: bool,
35 | cameras: list,
36 | voxel_sizes,
37 | replay_size=3e5,
38 | ):
39 |
40 | trans_indicies_size = 3 * len(voxel_sizes)
41 | rot_and_grip_indicies_size = 3 + 1
42 | gripper_pose_size = 7
43 | ignore_collisions_size = 1
44 | max_token_seq_len = 77
45 | lang_feat_dim = 1024
46 | lang_emb_dim = 512
47 |
48 | # low_dim_state
49 | observation_elements = []
50 | observation_elements.append(
51 | ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
52 | )
53 |
54 | # rgb, depth, point cloud, intrinsics, extrinsics
55 | for cname in cameras:
56 | observation_elements.append(
57 | ObservationElement(
58 | "%s_rgb" % cname,
59 | (
60 | 3,
61 | IMAGE_SIZE,
62 | IMAGE_SIZE,
63 | ),
64 | np.float32,
65 | )
66 | )
67 | observation_elements.append(
68 | ObservationElement(
69 | "%s_depth" % cname,
70 | (
71 | 1,
72 | IMAGE_SIZE,
73 | IMAGE_SIZE,
74 | ),
75 | np.float32,
76 | )
77 | )
78 | observation_elements.append(
79 | ObservationElement(
80 | "%s_point_cloud" % cname,
81 | (
82 | 3,
83 | IMAGE_SIZE,
84 | IMAGE_SIZE,
85 | ),
86 | np.float32,
87 | )
88 | ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
89 | observation_elements.append(
90 | ObservationElement(
91 | "%s_camera_extrinsics" % cname,
92 | (
93 | 4,
94 | 4,
95 | ),
96 | np.float32,
97 | )
98 | )
99 | observation_elements.append(
100 | ObservationElement(
101 | "%s_camera_intrinsics" % cname,
102 | (
103 | 3,
104 | 3,
105 | ),
106 | np.float32,
107 | )
108 | )
109 |
110 | # discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings
111 | observation_elements.extend(
112 | [
113 | ReplayElement("trans_action_indicies", (trans_indicies_size,), np.int32),
114 | ReplayElement(
115 | "rot_grip_action_indicies", (rot_and_grip_indicies_size,), np.int32
116 | ),
117 | ReplayElement("ignore_collisions", (ignore_collisions_size,), np.int32),
118 | ReplayElement("gripper_pose", (gripper_pose_size,), np.float32),
119 | ReplayElement(
120 | "lang_goal_embs",
121 | (
122 | max_token_seq_len,
123 | lang_emb_dim,
124 | ), # extracted from CLIP's language encoder
125 | np.float32,
126 | ),
127 | ReplayElement(
128 | "lang_goal", (1,), object
129 | ), # language goal string for debugging and visualization
130 | ]
131 | )
132 |
133 | extra_replay_elements = [
134 | ReplayElement("demo", (), bool),
135 | ReplayElement("keypoint_idx", (), int),
136 | ReplayElement("episode_idx", (), int),
137 | ReplayElement("keypoint_frame", (), int),
138 | ReplayElement("next_keypoint_frame", (), int),
139 | ReplayElement("sample_frame", (), int),
140 | ]
141 |
142 | replay_buffer = (
143 | UniformReplayBuffer( # all tuples in the buffer have equal sample weighting
144 | disk_saving=disk_saving,
145 | batch_size=batch_size,
146 | timesteps=timesteps,
147 | replay_capacity=int(replay_size),
148 | action_shape=(8,), # 3 translation + 4 rotation quaternion + 1 gripper open
149 | action_dtype=np.float32,
150 | reward_shape=(),
151 | reward_dtype=np.float32,
152 | update_horizon=1,
153 | observation_elements=observation_elements,
154 | extra_replay_elements=extra_replay_elements,
155 | )
156 | )
157 | return replay_buffer
158 |
159 |
160 | # discretize translation, rotation, gripper open, and ignore collision actions
161 | def _get_action(
162 | obs_tp1: Observation,
163 | obs_tm1: Observation,
164 | rlbench_scene_bounds: List[float], # metric 3D bounds of the scene
165 | voxel_sizes: List[int],
166 | rotation_resolution: int,
167 | crop_augmentation: bool,
168 | ):
169 | quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
170 | if quat[-1] < 0:
171 | quat = -quat
172 | disc_rot = utils.quaternion_to_discrete_euler(quat, rotation_resolution)
173 | attention_coordinate = obs_tp1.gripper_pose[:3]
174 | trans_indicies, attention_coordinates = [], []
175 | bounds = np.array(rlbench_scene_bounds)
176 | ignore_collisions = int(obs_tm1.ignore_collisions)
177 | for depth, vox_size in enumerate(
178 | voxel_sizes
179 | ): # only a single voxelization level is used in PerAct
180 | index = utils.point_to_voxel_index(obs_tp1.gripper_pose[:3], vox_size, bounds)
181 | trans_indicies.extend(index.tolist())
182 | res = (bounds[3:] - bounds[:3]) / vox_size
183 | attention_coordinate = bounds[:3] + res * index
184 | attention_coordinates.append(attention_coordinate)
185 |
186 | rot_and_grip_indicies = disc_rot.tolist()
187 | grip = float(obs_tp1.gripper_open)
188 | rot_and_grip_indicies.extend([int(obs_tp1.gripper_open)])
189 | return (
190 | trans_indicies,
191 | rot_and_grip_indicies,
192 | ignore_collisions,
193 | np.concatenate([obs_tp1.gripper_pose, np.array([grip])]),
194 | attention_coordinates,
195 | )
196 |
197 |
198 | # extract CLIP language features for goal string
199 | def _clip_encode_text(clip_model, text):
200 | x = clip_model.token_embedding(text).type(
201 | clip_model.dtype
202 | ) # [batch_size, n_ctx, d_model]
203 |
204 | x = x + clip_model.positional_embedding.type(clip_model.dtype)
205 | x = x.permute(1, 0, 2) # NLD -> LND
206 | x = clip_model.transformer(x)
207 | x = x.permute(1, 0, 2) # LND -> NLD
208 | x = clip_model.ln_final(x).type(clip_model.dtype)
209 |
210 | emb = x.clone()
211 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ clip_model.text_projection
212 |
213 | return x, emb
214 |
215 |
216 | # add individual data points to a replay
217 | def _add_keypoints_to_replay(
218 | replay: ReplayBuffer,
219 | task: str,
220 | task_replay_storage_folder: str,
221 | episode_idx: int,
222 | sample_frame: int,
223 | inital_obs: Observation,
224 | demo: Demo,
225 | episode_keypoints: List[int],
226 | cameras: List[str],
227 | rlbench_scene_bounds: List[float],
228 | voxel_sizes: List[int],
229 | rotation_resolution: int,
230 | crop_augmentation: bool,
231 | next_keypoint_idx: int,
232 | description: str = "",
233 | clip_model=None,
234 | device="cpu",
235 | ):
236 | prev_action = None
237 | obs = inital_obs
238 | for k in range(
239 | next_keypoint_idx, len(episode_keypoints)
240 | ): # note: each sampled start frame adds one transition per remaining keypoint, so nearby start frames produce many similar samples in the replay
241 | keypoint = episode_keypoints[k]
242 | obs_tp1 = demo[keypoint]
243 | obs_tm1 = demo[max(0, keypoint - 1)]
244 | (
245 | trans_indicies,
246 | rot_grip_indicies,
247 | ignore_collisions,
248 | action,
249 | attention_coordinates,
250 | ) = _get_action(
251 | obs_tp1,
252 | obs_tm1,
253 | rlbench_scene_bounds,
254 | voxel_sizes,
255 | rotation_resolution,
256 | crop_augmentation,
257 | )
258 |
259 | terminal = k == len(episode_keypoints) - 1
260 | reward = 1.0 if terminal else 0.0
261 |
262 | obs_dict = extract_obs(
263 | obs,
264 | CAMERAS,
265 | t=k - next_keypoint_idx,
266 | prev_action=prev_action,
267 | episode_length=25,
268 | )
269 | tokens = clip.tokenize([description]).numpy()
270 | token_tensor = torch.from_numpy(tokens).to(device)
271 | with torch.no_grad():
272 | lang_feats, lang_embs = _clip_encode_text(clip_model, token_tensor)
273 | obs_dict["lang_goal_embs"] = lang_embs[0].float().detach().cpu().numpy()
274 |
275 | prev_action = np.copy(action)
276 |
277 | if k == 0:
278 | keypoint_frame = -1
279 | else:
280 | keypoint_frame = episode_keypoints[k - 1]
281 | others = {
282 | "demo": True,
283 | "keypoint_idx": k,
284 | "episode_idx": episode_idx,
285 | "keypoint_frame": keypoint_frame,
286 | "next_keypoint_frame": keypoint,
287 | "sample_frame": sample_frame,
288 | }
289 | final_obs = {
290 | "trans_action_indicies": trans_indicies,
291 | "rot_grip_action_indicies": rot_grip_indicies,
292 | "gripper_pose": obs_tp1.gripper_pose,
293 | "lang_goal": np.array([description], dtype=object),
294 | }
295 |
296 | others.update(final_obs)
297 | others.update(obs_dict)
298 |
299 | timeout = False
300 | replay.add(
301 | task,
302 | task_replay_storage_folder,
303 | action,
304 | reward,
305 | terminal,
306 | timeout,
307 | **others
308 | )
309 | obs = obs_tp1
310 | sample_frame = keypoint
311 |
312 | # final step
313 | obs_dict_tp1 = extract_obs(
314 | obs_tp1,
315 | CAMERAS,
316 | t=k + 1 - next_keypoint_idx,
317 | prev_action=prev_action,
318 | episode_length=25,
319 | )
320 | obs_dict_tp1["lang_goal_embs"] = lang_embs[0].float().detach().cpu().numpy()
321 |
322 | obs_dict_tp1.pop("wrist_world_to_cam", None)
323 | obs_dict_tp1.update(final_obs)
324 | replay.add_final(task, task_replay_storage_folder, **obs_dict_tp1)
325 |
326 |
327 | def fill_replay(
328 | replay: ReplayBuffer,
329 | task: str,
330 | task_replay_storage_folder: str,
331 | start_idx: int,
332 | num_demos: int,
333 | demo_augmentation: bool,
334 | demo_augmentation_every_n: int,
335 | cameras: List[str],
336 | rlbench_scene_bounds: List[float], # AKA: DEPTH0_BOUNDS
337 | voxel_sizes: List[int],
338 | rotation_resolution: int,
339 | crop_augmentation: bool,
340 | data_path: str,
341 | episode_folder: str,
342 | variation_desriptions_pkl: str,
343 | clip_model=None,
344 | device="cpu",
345 | ):
346 |
347 | disk_exist = False
348 | if replay._disk_saving:
349 | if os.path.exists(task_replay_storage_folder):
350 | print(
351 | "[Info] Replay dataset already exists in the disk: {}".format(
352 | task_replay_storage_folder
353 | ),
354 | flush=True,
355 | )
356 | disk_exist = True
357 | else:
358 | logging.info("\t saving to disk: %s", task_replay_storage_folder)
359 | os.makedirs(task_replay_storage_folder, exist_ok=True)
360 |
361 | if disk_exist:
362 | replay.recover_from_disk(task, task_replay_storage_folder)
363 | else:
364 | print("Filling replay ...")
365 | for d_idx in range(start_idx, start_idx + num_demos):
366 | print("Filling demo %d" % d_idx)
367 | demo = get_stored_demo(data_path=data_path, index=d_idx)
368 |
369 | # get language goal from disk
370 | variation_descs_pkl_file = os.path.join(
371 | data_path, episode_folder % d_idx, variation_desriptions_pkl
372 | )
373 | with open(variation_descs_pkl_file, "rb") as f:
374 | descs = pickle.load(f)
375 |
376 | # extract keypoints
377 | episode_keypoints = keypoint_discovery(demo)
378 | next_keypoint_idx = 0
379 | for i in range(len(demo) - 1):
380 | if not demo_augmentation and i > 0:
381 | break
382 | if i % demo_augmentation_every_n != 0: # choose only every n-th frame
383 | continue
384 |
385 | obs = demo[i]
386 | desc = descs[0]
387 | # if our starting point is past one of the keypoints, then remove it
388 | while (
389 | next_keypoint_idx < len(episode_keypoints)
390 | and i >= episode_keypoints[next_keypoint_idx]
391 | ):
392 | next_keypoint_idx += 1
393 | if next_keypoint_idx == len(episode_keypoints):
394 | break
395 | _add_keypoints_to_replay(
396 | replay,
397 | task,
398 | task_replay_storage_folder,
399 | d_idx,
400 | i,
401 | obs,
402 | demo,
403 | episode_keypoints,
404 | cameras,
405 | rlbench_scene_bounds,
406 | voxel_sizes,
407 | rotation_resolution,
408 | crop_augmentation,
409 | next_keypoint_idx=next_keypoint_idx,
410 | description=desc,
411 | clip_model=clip_model,
412 | device=device,
413 | )
414 |
415 | # save TERMINAL info in replay_info.npy
416 | task_idx = replay._task_index[task]
417 | with open(
418 | os.path.join(task_replay_storage_folder, "replay_info.npy"), "wb"
419 | ) as fp:
420 | np.save(
421 | fp,
422 | replay._store["terminal"][
423 | replay._task_replay_start_index[
424 | task_idx
425 | ] : replay._task_replay_start_index[task_idx]
426 | + replay._task_add_count[task_idx].value
427 | ],
428 | )
429 |
430 | print("Replay filled with demos.")
431 |
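For reference, here is a tiny self-contained numeric sketch of the translation discretization performed in _get_action above. It assumes that utils.point_to_voxel_index maps a metric point into a vox_size^3 grid over the scene bounds by flooring; the exact helper lives in the PerAct utilities and may clamp or round slightly differently, and the example point is hypothetical.

import numpy as np

bounds = np.array([-0.3, -0.5, 0.6, 0.7, 0.5, 1.6])  # SCENE_BOUNDS from peract_utils
vox_size = 100                                       # VOXEL_SIZES[0]
point = np.array([0.253, -0.104, 1.007])             # hypothetical gripper position

res = (bounds[3:] - bounds[:3]) / vox_size           # metric edge length of one voxel (0.01 m here)
index = np.clip(np.floor((point - bounds[:3]) / res), 0, vox_size - 1).astype(int)
attention_coordinate = bounds[:3] + res * index      # coordinate recovered from the discrete index
print(index)                 # -> [55 39 40]
print(attention_coordinate)  # -> approximately [0.25, -0.11, 1.0]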
--------------------------------------------------------------------------------
/rvt/utils/ddp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | import os
6 | import torch.distributed as dist
7 | import random
8 |
9 |
10 | def setup(rank, world_size, port):
11 | os.environ["MASTER_ADDR"] = "localhost"
12 | os.environ["MASTER_PORT"] = str(port)
13 |
14 | # initialize the process group
15 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
16 |
17 |
18 | def cleanup():
19 | dist.destroy_process_group()
20 |
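A minimal usage sketch for setup and cleanup, assuming a single-node run launched with torch.multiprocessing.spawn; the actual wiring lives in train.py and may differ, and the worker body and port below are placeholders.

import torch.multiprocessing as mp
from rvt.utils.ddp_utils import setup, cleanup

def worker(rank, world_size, port):
    setup(rank, world_size, port)  # sets MASTER_ADDR/MASTER_PORT and initializes the NCCL group
    # ... build the model, wrap it in DistributedDataParallel, run the training loop ...
    cleanup()                      # destroys the process group

if __name__ == "__main__":
    world_size = 2                 # hypothetical: one process per GPU
    mp.spawn(worker, args=(world_size, 12345), nprocs=world_size)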
--------------------------------------------------------------------------------
/rvt/utils/get_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | import os
6 | import sys
7 | import shutil
8 | import torch
9 | import clip
10 |
11 | from rvt.libs.peract.helpers.utils import extract_obs
12 | from rvt.utils.rvt_utils import ForkedPdb
13 | from rvt.utils.dataset import create_replay, fill_replay
14 | from rvt.utils.peract_utils import (
15 | CAMERAS,
16 | SCENE_BOUNDS,
17 | EPISODE_FOLDER,
18 | VARIATION_DESCRIPTIONS_PKL,
19 | DEMO_AUGMENTATION_EVERY_N,
20 | ROTATION_RESOLUTION,
21 | VOXEL_SIZES,
22 | )
23 | from yarr.replay_buffer.wrappers.pytorch_replay_buffer import PyTorchReplayBuffer
24 |
25 |
26 | def get_dataset(
27 | tasks,
28 | BATCH_SIZE_TRAIN,
29 | BATCH_SIZE_TEST,
30 | TRAIN_REPLAY_STORAGE_DIR,
31 | TEST_REPLAY_STORAGE_DIR,
32 | DATA_FOLDER,
33 | NUM_TRAIN,
34 | NUM_VAL,
35 | refresh_replay,
36 | device,
37 | num_workers,
38 | only_train,
39 | sample_distribution_mode="transition_uniform",
40 | ):
41 |
42 | train_replay_buffer = create_replay(
43 | batch_size=BATCH_SIZE_TRAIN,
44 | timesteps=1,
45 | disk_saving=True,
46 | cameras=CAMERAS,
47 | voxel_sizes=VOXEL_SIZES,
48 | )
49 | if not only_train:
50 | test_replay_buffer = create_replay(
51 | batch_size=BATCH_SIZE_TEST,
52 | timesteps=1,
53 | disk_saving=True,
54 | cameras=CAMERAS,
55 | voxel_sizes=VOXEL_SIZES,
56 | )
57 |
58 | # load pre-trained language model
59 | try:
60 | clip_model, _ = clip.load("RN50", device="cpu") # CLIP-ResNet50
61 | clip_model = clip_model.to(device)
62 | clip_model.eval()
63 | except RuntimeError:
64 | print("WARNING: Setting Clip to None. Will not work if replay not on disk.")
65 | clip_model = None
66 |
67 | for task in tasks: # for each task
68 | # print("---- Preparing the data for {} task ----".format(task), flush=True)
69 | EPISODES_FOLDER_TRAIN = f"train/{task}/all_variations/episodes"
70 | EPISODES_FOLDER_VAL = f"val/{task}/all_variations/episodes"
71 | data_path_train = os.path.join(DATA_FOLDER, EPISODES_FOLDER_TRAIN)
72 | data_path_val = os.path.join(DATA_FOLDER, EPISODES_FOLDER_VAL)
73 | train_replay_storage_folder = f"{TRAIN_REPLAY_STORAGE_DIR}/{task}"
74 | test_replay_storage_folder = f"{TEST_REPLAY_STORAGE_DIR}/{task}"
75 |
76 | # if refresh_replay, then remove the existing replay data folder
77 | if refresh_replay:
78 | print("[Info] Remove exisitng replay dataset as requested.", flush=True)
79 | if os.path.exists(train_replay_storage_folder) and os.path.isdir(
80 | train_replay_storage_folder
81 | ):
82 | shutil.rmtree(train_replay_storage_folder)
83 | print(f"remove {train_replay_storage_folder}")
84 | if os.path.exists(test_replay_storage_folder) and os.path.isdir(
85 | test_replay_storage_folder
86 | ):
87 | shutil.rmtree(test_replay_storage_folder)
88 | print(f"remove {test_replay_storage_folder}")
89 |
90 | # print("----- Train Buffer -----")
91 | fill_replay(
92 | replay=train_replay_buffer,
93 | task=task,
94 | task_replay_storage_folder=train_replay_storage_folder,
95 | start_idx=0,
96 | num_demos=NUM_TRAIN,
97 | demo_augmentation=True,
98 | demo_augmentation_every_n=DEMO_AUGMENTATION_EVERY_N,
99 | cameras=CAMERAS,
100 | rlbench_scene_bounds=SCENE_BOUNDS,
101 | voxel_sizes=VOXEL_SIZES,
102 | rotation_resolution=ROTATION_RESOLUTION,
103 | crop_augmentation=False,
104 | data_path=data_path_train,
105 | episode_folder=EPISODE_FOLDER,
106 | variation_desriptions_pkl=VARIATION_DESCRIPTIONS_PKL,
107 | clip_model=clip_model,
108 | device=device,
109 | )
110 |
111 | if not only_train:
112 | # print("----- Test Buffer -----")
113 | fill_replay(
114 | replay=test_replay_buffer,
115 | task=task,
116 | task_replay_storage_folder=test_replay_storage_folder,
117 | start_idx=0,
118 | num_demos=NUM_VAL,
119 | demo_augmentation=True,
120 | demo_augmentation_every_n=DEMO_AUGMENTATION_EVERY_N,
121 | cameras=CAMERAS,
122 | rlbench_scene_bounds=SCENE_BOUNDS,
123 | voxel_sizes=VOXEL_SIZES,
124 | rotation_resolution=ROTATION_RESOLUTION,
125 | crop_augmentation=False,
126 | data_path=data_path_val,
127 | episode_folder=EPISODE_FOLDER,
128 | variation_desriptions_pkl=VARIATION_DESCRIPTIONS_PKL,
129 | clip_model=clip_model,
130 | device=device,
131 | )
132 |
133 | # delete the CLIP model since we have already extracted language features
134 | del clip_model
135 | with torch.cuda.device(device):
136 | torch.cuda.empty_cache()
137 |
138 | # wrap buffer with PyTorch dataset and make iterator
139 | train_wrapped_replay = PyTorchReplayBuffer(
140 | train_replay_buffer,
141 | sample_mode="random",
142 | num_workers=num_workers,
143 | sample_distribution_mode=sample_distribution_mode,
144 | )
145 | train_dataset = train_wrapped_replay.dataset()
146 |
147 | if only_train:
148 | test_dataset = None
149 | else:
150 | test_wrapped_replay = PyTorchReplayBuffer(
151 | test_replay_buffer,
152 | sample_mode="enumerate",
153 | num_workers=num_workers,
154 | )
155 | test_dataset = test_wrapped_replay.dataset()
156 | return train_dataset, test_dataset
157 |
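A hedged usage sketch of get_dataset; the literal values below are hypothetical placeholders and would normally come from the experiment config used by train.py rather than being hard-coded.

from rvt.utils.get_dataset import get_dataset

train_dataset, test_dataset = get_dataset(
    tasks=["open_drawer"],                    # any subset of RLBENCH_TASKS
    BATCH_SIZE_TRAIN=4,
    BATCH_SIZE_TEST=1,
    TRAIN_REPLAY_STORAGE_DIR="replay/train",  # hypothetical storage paths
    TEST_REPLAY_STORAGE_DIR="replay/val",
    DATA_FOLDER="data",
    NUM_TRAIN=100,
    NUM_VAL=25,
    refresh_replay=False,
    device="cuda:0",
    num_workers=4,
    only_train=False,
)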
--------------------------------------------------------------------------------
/rvt/utils/lr_sched_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | from torch.optim.lr_scheduler import _LRScheduler
6 | from torch.optim.lr_scheduler import ReduceLROnPlateau
7 |
8 | # source: https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py
9 | # updated so that it is suitable for cases where the epoch number starts from 0
10 | # lr increases steadily from "epoch 0" to "epoch (total_epoch - 1)" such that the
11 | # lr at epoch (total_epoch - 1) equals the base_lr of the after_scheduler
12 | # Only tested for the case when multiplier is 1.0 and the after_scheduler is a
13 | # MultiStepLR (see the usage sketch at the end of this file)
14 |
15 |
16 | class GradualWarmupScheduler(_LRScheduler):
17 | """Gradually warm-up(increasing) learning rate in optimizer.
18 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
19 | Args:
20 | optimizer (Optimizer): Wrapped optimizer.
21 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
22 | total_epoch: target learning rate is reached at total_epoch, gradually
23 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
24 | """
25 |
26 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
27 | self.multiplier = multiplier
28 | if self.multiplier < 1.0:
29 | raise ValueError("multiplier should be greater thant or equal to 1.")
30 | self.total_epoch = total_epoch
31 | self.after_scheduler = after_scheduler
32 | self.finished = False
33 | super(GradualWarmupScheduler, self).__init__(optimizer)
34 |
35 | def get_lr(self):
36 | if (self.last_epoch + 1) > self.total_epoch:
37 | if self.after_scheduler:
38 | if not self.finished:
39 | self.after_scheduler.base_lrs = [
40 | base_lr * self.multiplier for base_lr in self.base_lrs
41 | ]
42 | self.finished = True
43 | return self.after_scheduler.get_last_lr()
44 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
45 |
46 | if self.multiplier == 1.0:
47 | return [
48 | base_lr * ((float(self.last_epoch) + 1) / self.total_epoch)
49 | for base_lr in self.base_lrs
50 | ]
51 | else:
52 | return [
53 | base_lr
54 | * (
55 | (self.multiplier - 1.0) * (self.last_epoch + 1) / self.total_epoch
56 | + 1.0
57 | )
58 | for base_lr in self.base_lrs
59 | ]
60 |
61 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
62 | if epoch is None:
63 | epoch = self.last_epoch + 1
64 | self.last_epoch = (
65 | epoch if epoch != 0 else 1
66 | ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
67 | if self.last_epoch <= self.total_epoch:
68 | warmup_lr = [
69 | base_lr
70 | * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
71 | for base_lr in self.base_lrs
72 | ]
73 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
74 | param_group["lr"] = lr
75 | else:
76 | if epoch is None:
77 | self.after_scheduler.step(metrics, None)
78 | else:
79 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
80 |
81 | def step(self, epoch=None, metrics=None):
82 | if type(self.after_scheduler) != ReduceLROnPlateau:
83 | if self.finished and self.after_scheduler:
84 | if epoch is None:
85 | self.after_scheduler.step(None)
86 | else:
87 | self.after_scheduler.step(epoch)
88 | self._last_lr = self.after_scheduler.get_last_lr()
89 | else:
90 | return super(GradualWarmupScheduler, self).step(epoch)
91 | else:
92 | self.step_ReduceLROnPlateau(metrics, epoch)
93 |
94 | def state_dict(self):
95 | state_dict = {
96 | key: value
97 | for key, value in self.__dict__.items()
98 | if key not in ["optimizer", "after_scheduler"]
99 | }
100 |
101 | if not (self.after_scheduler is None):
102 | state_dict["after_scheduler_state_dict"] = self.after_scheduler.state_dict()
103 |
104 | return state_dict
105 |
106 | def load_state_dict(self, state_dict):
107 | if self.after_scheduler is None:
108 | assert not ("after_scheduler_state_dict" in state_dict)
109 | else:
110 | self.after_scheduler.load_state_dict(
111 | state_dict["after_scheduler_state_dict"]
112 | )
113 | del state_dict["after_scheduler_state_dict"]
114 |
115 | self.__dict__.update(state_dict)
116 |
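The sketch below illustrates the combination mentioned in the header comment (multiplier = 1.0 with a MultiStepLR after_scheduler); the model, learning rate, milestones, and epoch counts are placeholders.

import torch
from torch.optim.lr_scheduler import MultiStepLR
from rvt.utils.lr_sched_utils import GradualWarmupScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
after = MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)
sched = GradualWarmupScheduler(
    optimizer, multiplier=1.0, total_epoch=5, after_scheduler=after
)

for epoch in range(80):
    # ... run one training epoch ...
    sched.step()  # lr ramps from base_lr / 5 up to base_lr over 5 epochs, then follows MultiStepLR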
--------------------------------------------------------------------------------
/rvt/utils/peract_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | import torch
6 | from omegaconf import OmegaConf
7 |
8 | from rvt.models.peract_official import create_agent_our
9 | from peract_colab.arm.utils import stack_on_channel
10 | from torch.optim.lr_scheduler import CosineAnnealingLR
11 | from rvt.utils.lr_sched_utils import GradualWarmupScheduler
12 |
13 | # Constants
14 | # TODO: Unclear about the best way to handle them
15 | CAMERAS = ["front", "left_shoulder", "right_shoulder", "wrist"]
16 | SCENE_BOUNDS = [
17 | -0.3,
18 | -0.5,
19 | 0.6,
20 | 0.7,
21 | 0.5,
22 | 1.6,
23 | ] # [x_min, y_min, z_min, x_max, y_max, z_max] - the metric volume to be voxelized
24 | IMAGE_SIZE = 128
25 | VOXEL_SIZES = [100] # 100x100x100 voxels
26 | LOW_DIM_SIZE = 4 # {left_finger_joint, right_finger_joint, gripper_open, timestep}
27 |
28 | DATA_FOLDER = "data"
29 | EPISODE_FOLDER = "episode%d"
30 | VARIATION_DESCRIPTIONS_PKL = "variation_descriptions.pkl" # the pkl file that contains language goals for each demonstration
31 | DEMO_AUGMENTATION_EVERY_N = 10 # sample every n-th frame of each demo
32 | ROTATION_RESOLUTION = 5 # degree increments per axis
33 | # settings
34 | NUM_LATENTS = 512 # PerceiverIO latents
35 |
36 |
37 | def _norm_rgb(x):
38 | return (x.float() / 255.0) * 2.0 - 1.0
39 |
40 |
41 | def _preprocess_inputs(replay_sample, cameras):
42 | obs, pcds = [], []
43 | for n in cameras:
44 | rgb = stack_on_channel(replay_sample["%s_rgb" % n])
45 | pcd = stack_on_channel(replay_sample["%s_point_cloud" % n])
46 |
47 | rgb = _norm_rgb(rgb)
48 |
49 | obs.append(
50 | [rgb, pcd]
51 | ) # obs contains both rgb and pointcloud (used in ARM for other baselines)
52 | pcds.append(pcd) # only pointcloud
53 | return obs, pcds
54 |
55 |
56 | def get_official_peract(
57 | cfg_path,
58 | training,
59 | device,
60 | bs,
61 | ):
62 | """
63 | Creates an official PerAct agent
64 | :param cfg_path: path to the config file
65 | :param training: whether to build the agent in training mode
66 | :param device: device to build the agent on
67 | :param bs: batch size per GPU; the value does not matter when the model is only used for inference.
68 | """
69 | with open(cfg_path, "r") as f:
70 | cfg = OmegaConf.load(f)
71 |
72 | # we need to override the batch size because, in our case, the batch size is
73 | # specified per GPU
74 | cfg.replay.batch_size = bs
75 | agent = create_agent_our(cfg)
76 | agent.build(training=training, device=device)
77 |
78 | return agent
79 |
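A hedged sketch of loading the official PerAct agent for inference via get_official_peract; the config path mirrors rvt/configs/peract_official_config.yaml from this repository, but adjust it to your checkout, and the device choice is a placeholder.

import torch
from rvt.utils.peract_utils import get_official_peract

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
agent = get_official_peract(
    cfg_path="rvt/configs/peract_official_config.yaml",  # relative to the repo root
    training=False,
    device=device,
    bs=1,  # the batch size is irrelevant at inference time
)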
--------------------------------------------------------------------------------
/rvt/utils/rlbench_planning.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | import numpy as np
6 | from rlbench.action_modes.arm_action_modes import (
7 | EndEffectorPoseViaPlanning,
8 | Scene,
9 | )
10 |
11 |
12 | class EndEffectorPoseViaPlanning2(EndEffectorPoseViaPlanning):
13 | def __init__(self, *args, **kwargs):
14 | super().__init__(*args, **kwargs)
15 |
16 | def action(self, scene: Scene, action: np.ndarray, ignore_collisions: bool = True):
17 | action[:3] = np.clip(
18 | action[:3],
19 | np.array(
20 | [scene._workspace_minx, scene._workspace_miny, scene._workspace_minz]
21 | )
22 | + 1e-7,
23 | np.array(
24 | [scene._workspace_maxx, scene._workspace_maxy, scene._workspace_maxz]
25 | )
26 | - 1e-7,
27 | )
28 |
29 | super().action(scene, action, ignore_collisions)
30 |
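A hedged sketch of plugging EndEffectorPoseViaPlanning2 into an RLBench action mode. It assumes the standard RLBench action-mode API (MoveArmThenGripper with a Discrete gripper mode); the environment construction used by this repository lives in custom_rlbench_env.py and may be wired differently.

from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.gripper_action_modes import Discrete
from rvt.utils.rlbench_planning import EndEffectorPoseViaPlanning2

# the wrapped arm action mode clips the commanded position into the workspace
# bounds before handing the pose to the motion planner
action_mode = MoveArmThenGripper(
    arm_action_mode=EndEffectorPoseViaPlanning2(),
    gripper_action_mode=Discrete(),
)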
--------------------------------------------------------------------------------
/rvt/utils/rvt_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | """
6 | Utility functions for our agent
7 | """
8 | import pdb
9 | import argparse
10 | import sys
11 | import signal
12 | from datetime import datetime
13 |
14 | import torch
15 | from torch.utils.tensorboard import SummaryWriter
16 |
17 | from torch.nn.parallel import DistributedDataParallel as DDP
18 |
19 | import rvt.utils.peract_utils as peract_utils
20 | from rvt.models.peract_official import PreprocessAgent2
21 |
22 |
23 | def get_pc_img_feat(obs, pcd, bounds=None):
24 | """
25 | preprocess data from the PerAct format into the format used by our framework
26 | """
27 | # obs, pcd = peract_utils._preprocess_inputs(batch)
28 | bs = obs[0][0].shape[0]
29 | # concatenating the points from all the cameras
30 | # (bs, num_points, 3)
31 | pc = torch.cat([p.permute(0, 2, 3, 1).reshape(bs, -1, 3) for p in pcd], 1)
32 | _img_feat = [o[0] for o in obs]
33 | img_dim = _img_feat[0].shape[1]
34 | # (bs, num_points, img_dim)
35 | img_feat = torch.cat(
36 | [p.permute(0, 2, 3, 1).reshape(bs, -1, img_dim) for p in _img_feat], 1
37 | )
38 |
39 | img_feat = (img_feat + 1) / 2
40 |
41 | # x_min, y_min, z_min, x_max, y_max, z_max = bounds
42 | # inv_pnt = (
43 | # (pc[:, :, 0] < x_min)
44 | # | (pc[:, :, 0] > x_max)
45 | # | (pc[:, :, 1] < y_min)
46 | # | (pc[:, :, 1] > y_max)
47 | # | (pc[:, :, 2] < z_min)
48 | # | (pc[:, :, 2] > z_max)
49 | # )
50 |
51 | # # TODO: move from a list to a better batched version
52 | # pc = [pc[i, ~_inv_pnt] for i, _inv_pnt in enumerate(inv_pnt)]
53 | # img_feat = [img_feat[i, ~_inv_pnt] for i, _inv_pnt in enumerate(inv_pnt)]
54 |
55 | return pc, img_feat
56 |
57 |
58 | def move_pc_in_bound(pc, img_feat, bounds, no_op=False):
59 | """
60 | :param no_op: if True, return the inputs unchanged (no operation)
61 | """
62 | if no_op:
63 | return pc, img_feat
64 |
65 | x_min, y_min, z_min, x_max, y_max, z_max = bounds
66 | inv_pnt = (
67 | (pc[:, :, 0] < x_min)
68 | | (pc[:, :, 0] > x_max)
69 | | (pc[:, :, 1] < y_min)
70 | | (pc[:, :, 1] > y_max)
71 | | (pc[:, :, 2] < z_min)
72 | | (pc[:, :, 2] > z_max)
73 | | torch.isnan(pc[:, :, 0])
74 | | torch.isnan(pc[:, :, 1])
75 | | torch.isnan(pc[:, :, 2])
76 | )
77 |
78 | # TODO: move from a list to a better batched version
79 | pc = [pc[i, ~_inv_pnt] for i, _inv_pnt in enumerate(inv_pnt)]
80 | img_feat = [img_feat[i, ~_inv_pnt] for i, _inv_pnt in enumerate(inv_pnt)]
81 | return pc, img_feat
82 |
83 |
84 | def count_parameters(model):
85 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
86 |
87 |
88 | class TensorboardManager:
89 | def __init__(self, path):
90 | self.writer = SummaryWriter(path)
91 |
92 | def update(self, split, step, vals):
93 | for k, v in vals.items():
94 | if "image" in k:
95 | for i, x in enumerate(v):
96 | self.writer.add_image(f"{split}_{step}", x, i)
97 | elif "hist" in k:
98 | if isinstance(v, list):
99 | self.writer.add_histogram(k, v, step)
100 | elif isinstance(v, dict):
101 | hist_id = {}
102 | for i, idx in enumerate(sorted(v.keys())):
103 | self.writer.add_histogram(f"{split}_{k}_{step}", v[idx], i)
104 | hist_id[i] = idx
105 | self.writer.add_text(f"{split}_{k}_{step}_id", f"{hist_id}")
106 | else:
107 | assert False
108 | else:
109 | self.writer.add_scalar("%s_%s" % (split, k), v, step)
110 |
111 | def close(self):
112 | self.writer.flush()
113 | self.writer.close()
114 |
115 |
116 | class ForkedPdb(pdb.Pdb):
117 | """A Pdb subclass that may be used
118 | from a forked multiprocessing child
119 |
120 | """
121 |
122 | def interaction(self, *args, **kwargs):
123 | _stdin = sys.stdin
124 | try:
125 | sys.stdin = open("/dev/stdin")
126 | pdb.Pdb.interaction(self, *args, **kwargs)
127 | finally:
128 | sys.stdin = _stdin
129 |
130 |
131 | def short_name(cfg_opts):
132 | SHORT_FORMS = {
133 | "peract": "PA",
134 | "sample_distribution_mode": "SDM",
135 | "optimizer_type": "OPT",
136 | "lr_cos_dec": "LCD",
137 | "num_workers": "NW",
138 | "True": "T",
139 | "False": "F",
140 | "pe_fix": "pf",
141 | "transform_augmentation_rpy": "tar",
142 | "lambda_weight_l2": "l2",
143 | "resume": "RES",
144 | "inp_pre_pro": "IPP",
145 | "inp_pre_con": "IPC",
146 | "cvx_up": "CU",
147 | "stage_two": "ST",
148 | "feat_ver": "FV",
149 | "lamb": "L",
150 | "img_size": "IS",
151 | "img_patch_size": "IPS",
152 | "rlbench": "RLB",
153 | "move_pc_in_bound": "MPIB",
154 | "rend": "R",
155 | "xops": "X",
156 | "warmup_steps": "WS",
157 | "epochs": "E",
158 | "amp": "A",
159 | }
160 |
161 | if "resume" in cfg_opts:
162 | cfg_opts = cfg_opts.split(" ")
163 | res_idx = cfg_opts.index("resume")
164 | cfg_opts.pop(res_idx + 1)
165 | cfg_opts = " ".join(cfg_opts)
166 |
167 | cfg_opts = cfg_opts.replace(" ", "_")
168 | cfg_opts = cfg_opts.replace("/", "_")
169 | cfg_opts = cfg_opts.replace("[", "")
170 | cfg_opts = cfg_opts.replace("]", "")
171 | cfg_opts = cfg_opts.replace("..", "")
172 | for a, b in SHORT_FORMS.items():
173 | cfg_opts = cfg_opts.replace(a, b)
174 |
175 | return cfg_opts
176 |
177 |
178 | def get_num_feat(cfg):
179 | num_feat = cfg.num_rotation_classes * 3
180 | # 2 for grip, 2 for collision
181 | num_feat += 4
182 | return num_feat
183 |
184 |
185 | def get_eval_parser():
186 | parser = argparse.ArgumentParser()
187 |
188 | parser.add_argument(
189 | "--tasks", type=str, nargs="+", default=["insert_onto_square_peg"]
190 | )
191 | parser.add_argument("--model-folder", type=str, default=None)
192 | parser.add_argument("--eval-datafolder", type=str, default="./data/val/")
193 | parser.add_argument(
194 | "--start-episode",
195 | type=int,
196 | default=0,
197 | help="start to evaluate from which episode",
198 | )
199 | parser.add_argument(
200 | "--eval-episodes",
201 | type=int,
202 | default=10,
203 | help="how many episodes to be evaluated for each task",
204 | )
205 | parser.add_argument(
206 | "--episode-length",
207 | type=int,
208 | default=25,
209 | help="maximum control steps allowed for each episode",
210 | )
211 | parser.add_argument("--headless", action="store_true", default=False)
212 | parser.add_argument("--ground-truth", action="store_true", default=False)
213 | parser.add_argument("--exp_cfg_path", type=str, default=None)
214 | parser.add_argument("--mvt_cfg_path", type=str, default=None)
215 | parser.add_argument("--peract_official", action="store_true")
216 | parser.add_argument(
217 | "--peract_model_dir",
218 | type=str,
219 | default="runs/peract_official/seed0/weights/600000",
220 | )
221 | parser.add_argument("--device", type=int, default=0)
222 | parser.add_argument("--log-name", type=str, default=None)
223 | parser.add_argument("--model-name", type=str, default=None)
224 | parser.add_argument("--use-input-place-with-mean", action="store_true")
225 | parser.add_argument("--save-video", action="store_true")
226 | parser.add_argument("--skip", action="store_true")
227 |
228 | return parser
229 |
230 |
231 | RLBENCH_TASKS = [
232 | "put_item_in_drawer",
233 | "reach_and_drag",
234 | "turn_tap",
235 | "slide_block_to_color_target",
236 | "open_drawer",
237 | "put_groceries_in_cupboard",
238 | "place_shape_in_shape_sorter",
239 | "put_money_in_safe",
240 | "push_buttons",
241 | "close_jar",
242 | "stack_blocks",
243 | "place_cups",
244 | "place_wine_at_rack_location",
245 | "light_bulb_in",
246 | "sweep_to_dustpan_of_size",
247 | "insert_onto_square_peg",
248 | "meat_off_grill",
249 | "stack_cups",
250 | ]
251 |
252 |
253 | def load_agent(agent_path, agent=None, only_epoch=False):
254 | if isinstance(agent, PreprocessAgent2):
255 | assert not only_epoch
256 | agent._pose_agent.load_weights(agent_path)
257 | return 0
258 |
259 | checkpoint = torch.load(agent_path, map_location="cpu")
260 | epoch = checkpoint["epoch"]
261 |
262 | if not only_epoch:
263 | if hasattr(agent, "_q"):
264 | model = agent._q
265 | elif hasattr(agent, "_network"):
266 | model = agent._network
267 | optimizer = agent._optimizer
268 | lr_sched = agent._lr_sched
269 |
270 | if isinstance(model, DDP):
271 | model = model.module
272 |
273 | try:
274 | model.load_state_dict(checkpoint["model_state"])
275 | except RuntimeError:
276 | try:
277 | print(
278 | "WARNING: loading states in mvt1. "
279 | "Be cautious if you are using a two stage network."
280 | )
281 | model.mvt1.load_state_dict(checkpoint["model_state"])
282 | except RuntimeError:
283 | print(
284 | "WARNING: loading states with strick=False! "
285 | "KNOW WHAT YOU ARE DOING!!"
286 | )
287 | model.load_state_dict(checkpoint["model_state"], strict=False)
288 |
289 | if "optimizer_state" in checkpoint:
290 | optimizer.load_state_dict(checkpoint["optimizer_state"])
291 | else:
292 | print(
293 | "WARNING: No optimizer_state in checkpoint" "KNOW WHAT YOU ARE DOING!!"
294 | )
295 |
296 | if "lr_sched_state" in checkpoint:
297 | lr_sched.load_state_dict(checkpoint["lr_sched_state"])
298 | else:
299 | print(
300 | "WARNING: No lr_sched_state in checkpoint" "KNOW WHAT YOU ARE DOING!!"
301 | )
302 |
303 | return epoch
304 |
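A small self-contained sketch of move_pc_in_bound on random data; the shapes follow the (bs, num_points, 3) layout produced by get_pc_img_feat, and the tensors themselves are arbitrary.

import torch
from rvt.utils.rvt_utils import move_pc_in_bound
from rvt.utils.peract_utils import SCENE_BOUNDS

pc = torch.rand(2, 1024, 3) * 2.0 - 0.5  # hypothetical points, some of which fall outside the bounds
img_feat = torch.rand(2, 1024, 3)        # per-point RGB features in [0, 1]
pc_in, feat_in = move_pc_in_bound(pc, img_feat, SCENE_BOUNDS)
# pc_in and feat_in are per-sample lists of variable-length tensors that keep
# only the points (and their features) lying inside SCENE_BOUNDS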
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | """
6 | Setup of RVT
7 | Author: Ankit Goyal
8 | """
9 | from setuptools import setup, find_packages
10 |
11 | requirements = [
12 | "numpy",
13 | "scipy",
14 | "einops",
15 | "pyrender",
16 | "transformers",
17 | "omegaconf",
18 | "natsort",
19 | "cffi",
20 | "pandas",
21 | "tensorflow",
22 | "pyquaternion",
23 | "matplotlib",
24 | "bitsandbytes==0.38.1",
25 | "transforms3d",
26 | "clip @ git+https://github.com/openai/CLIP.git",
27 | ]
28 |
29 | __version__ = "0.0.1"
30 | setup(
31 | name="rvt",
32 | version=__version__,
33 | description="RVT",
34 | long_description="",
35 | author="Ankit Goyal",
36 | author_email="angoyal@nvidia.com",
37 | url="",
38 | keywords="robotics,computer vision",
39 | classifiers=[
40 | "Programming Language :: Python",
41 | "Programming Language :: Python :: 3.8",
42 | "Natural Language :: English",
43 | "Topic :: Scientific/Engineering",
44 | ],
45 | packages=['rvt'],
46 | install_requires=requirements,
47 | extras_require={
48 | "xformers": [
49 | "xformers @ git+https://github.com/facebookresearch/xformers.git@main#egg=xformers",
50 | ]
51 | },
52 | )
53 |
--------------------------------------------------------------------------------