├── LICENSE
├── README.md
├── assets
│   ├── fence
│   │   ├── obj.mtl
│   │   ├── tinker.obj
│   │   └── tinker.urdf
│   ├── object_id
│   │   ├── shapenet_id.json
│   │   └── ycb_id.json
│   ├── stick
│   │   └── stick.urdf
│   └── ur5
│       ├── collision
│       │   ├── base.stl
│       │   ├── forearm.stl
│       │   ├── shoulder.stl
│       │   ├── upperarm.stl
│       │   ├── wrist1.stl
│       │   ├── wrist2.stl
│       │   └── wrist3.stl
│       ├── mount.urdf
│       ├── notes.txt
│       ├── ur5.urdf
│       └── visual
│           ├── base.stl
│           ├── forearm.stl
│           ├── shoulder.stl
│           ├── upperarm.stl
│           ├── wrist1.stl
│           ├── wrist2.stl
│           └── wrist3.stl
├── binvox_utils.py
├── data.py
├── data
│   └── README.md
├── data_generation.py
├── figures
│   └── teaser.jpg
├── forward_warp.py
├── fusion.py
├── model.py
├── model_utils.py
├── object_models
│   └── README.md
├── pretrained_models
│   ├── 3dflow.pth
│   ├── README.md
│   ├── dsr.pth
│   ├── dsr_ft.pth
│   ├── gtwarp.pth
│   ├── nowarp.pth
│   └── single.pth
├── requirements.txt
├── se3
│   ├── __pycache__
│   │   ├── se3_module.cpython-36.pyc
│   │   ├── se3_utils.cpython-36.pyc
│   │   ├── se3aa.cpython-36.pyc
│   │   ├── se3euler.cpython-36.pyc
│   │   ├── se3quat.cpython-36.pyc
│   │   └── se3spquat.cpython-36.pyc
│   ├── se3_module.py
│   ├── se3_utils.py
│   ├── se3aa.py
│   ├── se3euler.py
│   ├── se3quat.py
│   └── se3spquat.py
├── sim.py
├── sim_env.py
├── test.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Columbia Robovision Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning 3D Dynamic Scene Representations for Robot Manipulation
2 |
3 | [Zhenjia Xu](http://www.zhenjiaxu.com/)<sup>1</sup>\*,
4 | [Zhanpeng He](https://zhanpenghe.github.io/)<sup>1</sup>\*,
5 | [Jiajun Wu](https://jiajunwu.com/)<sup>2</sup>,
6 | [Shuran Song](https://www.cs.columbia.edu/~shurans/)<sup>1</sup>
7 |
8 | <sup>1</sup>Columbia University, <sup>2</sup>Stanford University
9 |
10 | [CoRL 2020](https://www.robot-learning.org/)
11 |
12 | ### [Project Page](https://dsr-net.cs.columbia.edu/) | [Video](https://youtu.be/GQjYG3nQJ80) | [arXiv](https://arxiv.org/abs/2011.01968)
13 |
14 | ## Overview
15 | This repo contains the PyTorch implementation for the paper "Learning 3D Dynamic Scene Representations for Robot Manipulation".
16 | ![teaser](figures/teaser.jpg)
17 |
18 | ## Content
19 |
20 | - [Prerequisites](#prerequisites)
21 | - [Data Preparation](#data-preparation)
22 | - [Pretrained Models](#pretrained-models)
23 | - [Training](#training)
24 |
25 | ## Prerequisites
26 |
27 | The code is built with Python 3.6. The required libraries are listed in [requirements.txt](requirements.txt):
28 |
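For example, a typical setup with pip (a minimal sketch; use whichever Python 3.6 environment manager you prefer):
```
# from the repository root, inside a Python 3.6 environment
pip install -r requirements.txt
```
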
29 | ## Data Preparation
30 |
31 | ### Download Testing Data
32 | The following two testing datasets can be downloaded:
33 | - [Sim](https://dsr-net.cs.columbia.edu/download/data/sim_test_data.zip): 400 sequences, generated in pybullet.
34 | - [Real](https://dsr-net.cs.columbia.edu/download/data/real_test_data.zip): 150 sequences, with full annotations.
35 |
36 | ### Generate Training Data
37 | Download object mesh: [shapenet](https://dsr-net.cs.columbia.edu/download/object_models/shapenet.zip) and [ycb](https://dsr-net.cs.columbia.edu/download/object_models/ycb.zip).
38 |
39 | To generate data in simulation, one can run
40 | ```
41 | python data_generation.py --data_path [path to data] --train_num [number of training sequences] --test_num [number of testing sequences] --object_type [type of objects]
42 | ```
43 | where `object_type` can be `cube`, `shapenet`, or `ycb`.
44 | The training data in the paper can be generated with the following scripts:
45 | ```
46 | # cube
47 | python data_generation.py --data_path data/cube_train --train_num 4000 --test_num 400 --object_type cube
48 |
49 | # shapenet
50 | python data_generation.py --data_path data/shapenet_train --train_num 4000 --test_num 400 --object_type shapenet
51 | ```
52 |
53 | ## Pretrained Models
54 | Some of the pretrained models can be downloaded from [pretrained_models](pretrained_models).
55 | To evaluate the pretrained models, one can run
56 | ```
57 | python test.py --resume [path to model] --data_path [path to data] --model_type [type of model] --test_type [type of test]
58 | ```
59 | where `model_type` can be one of the following:
60 | - `dsr`: DSR-Net introduced in the paper.
61 | - `single`: It does not use any history aggregation.
62 | - `nowarp`: It does not warp the representation before aggregation.
63 | - `gtwarp`: It warps the representation with ground truth motion (i.e., a performance oracle).
64 | - `3dflow`: It predicts per-voxel scene flow for the entire 3D volume.
65 |
66 | Both motion prediction and mask prediction can be evaluated by choosing different `test_type`:
67 | - motion prediction: `motion_visible` or `motion_full`
68 | - mask prediction: `mask_ordered` or `mask_unordered`
69 |
70 | (Please refer to our paper for a detailed explanation of each type of evaluation.)
71 |
72 | Here are several examples:
73 | ```
74 | # evaluate mask prediction (ordered) of DSR-Net using real data:
75 | python test.py --resume [path to dsr model] --data_path [path to real data] --model_type dsr --test_type mask_ordered
76 |
77 | # evaluate mask prediction (unordered) of DSR-Net(finetuned) using real data:
78 | python test.py --resume [path to dsr_ft model] --data_path [path to real data] --model_type dsr --test_type mask_unordered
79 |
80 | # evaluate motion prediction (visible surface) of NoWarp model using sim data:
81 | python test.py --resume [path to nowarp model] --data_path [path to sim data] --model_type nowarp --test_type motion_visible
82 |
83 | # evaluate motion prediction (full volume) of SingleStep model using sim data:
84 | python test.py --resume [path to single model] --data_path [path to sim data] --model_type single --test_type motion_full
85 | ```
86 |
87 |
88 | ## Training
89 | Various training options can be modified or toggled on/off with different flags (run `python train.py -h` to see all options):
90 | ```
91 | usage: train.py [-h] [--exp EXP] [--gpus GPUS [GPUS ...]] [--resume RESUME]
92 | [--data_path DATA_PATH] [--object_num OBJECT_NUM]
93 | [--seq_len SEQ_LEN] [--batch BATCH] [--workers WORKERS]
94 | [--model_type {dsr,single,nowarp,gtwarp,3dflow}]
95 | [--transform_type {affine,se3euler,se3aa,se3spquat,se3quat}]
96 | [--alpha_motion ALPHA_MOTION] [--alpha_mask ALPHA_MASK]
97 | [--snapshot_freq SNAPSHOT_FREQ] [--epoch EPOCH] [--finetune]
98 | [--seed SEED] [--dist_backend DIST_BACKEND]
99 | [--dist_url DIST_URL]
100 | ```
101 | ### Training of DSR-Net
102 | Since the aggregation ability depends on the accuracy of motion prediction, we split the training process into three stages from easy to hard: (1) single-step on cube dataset; (2) multi-step on cube dataset; (3) multi-step on ShapeNet dataset.
103 | ```
104 | # Stage 1 (single-step on cube dataset)
105 | python train.py --exp dsr_stage1 --data_path [path to cube dataset] --seq_len 1 --model_type dsr --epoch 30
106 |
107 | # Stage 2 (multi-step on cube dataset)
108 | python train.py --exp dsr_stage2 --resume [path to stage1] --data_path [path to cube dataset] --seq_len 10 --model_type dsr --epoch 20 --finetune
109 |
110 | # Stage 3 (multi-step on shapenet dataset)
111 | python train.py --exp dsr_stage3 --resume [path to stage2] --data_path [path to shapenet dataset] --seq_len 10 --model_type dsr --epoch 20 --finetune
112 | ```
113 |
114 | ### Training of Baselines
115 | - `nowarp` and `gtwarp`. Use the same scripts as DSR-Net with the corresponding `model_type`.
116 |
117 | - `single` and `3dflow`. Two-stage training: (1) single-step on the cube dataset; (2) single-step on the ShapeNet dataset. See the sketch below.
118 |
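A minimal sketch of the corresponding two-stage commands for the `single` baseline (the experiment names, epoch counts, and paths below are placeholders rather than the exact settings from the paper; swap in `--model_type 3dflow` for the 3D flow baseline):
```
# Stage 1 (single-step on cube dataset)
python train.py --exp single_stage1 --data_path [path to cube dataset] --seq_len 1 --model_type single --epoch 30

# Stage 2 (single-step on shapenet dataset)
python train.py --exp single_stage2 --resume [path to stage1] --data_path [path to shapenet dataset] --seq_len 1 --model_type single --epoch 20 --finetune
```
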
119 | ## BibTeX
120 | ```
121 | @inproceedings{xu2020learning,
122 | title={Learning 3D Dynamic Scene Representations for Robot Manipulation},
123 | author={Xu, Zhenjia and He, Zhanpeng and Wu, Jiajun and Song, Shuran},
124 | booktitle={Conference on Robot Learning (CoRL)},
125 | year={2020}
126 | }
127 | ```
128 |
129 | ## License
130 |
131 | This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details.
132 |
133 |
134 | ## Acknowledgement
135 | - The code for [SE3](se3) is modified from [se3posenets-pytorch](https://github.com/abyravan/se3posenets-pytorch).
136 | - The code for [TSDF fusion](fusion.py) is modified from [tsdf-fusion-python](https://github.com/andyzeng/tsdf-fusion-python).
137 | - The code for [binvox processing](binvox_utils.py) is modified from [binvox-rw-py](https://github.com/dimatura/binvox-rw-py).
138 |
--------------------------------------------------------------------------------
/assets/fence/obj.mtl:
--------------------------------------------------------------------------------
1 | # Color definition for Tinkercad Obj File 2015
2 |
3 | newmtl color_11593967
4 | Ka 0 0 0
5 | Kd 0.6901960784313725 0.9098039215686274 0.9372549019607843
6 | d 1.0
7 | illum 0.0
8 |
9 | newmtl color_16500122
10 | Ka 0 0 0
11 | Kd 0.984313725490196 0.7725490196078432 0.6039215686274509
12 | d 1.0
13 | illum 0.0
14 |
15 | newmtl color_16311991
16 | Ka 0 0 0
17 | Kd 0.9725490196078431 0.9019607843137255 0.7176470588235294
18 | d 1.0
19 | illum 0.0
20 |
21 | newmtl color_13165757
22 | Ka 0 0 0
23 | Kd 0.7843137254901961 0.8941176470588236 0.7411764705882353
24 | d 1.0
25 | illum 0.0
26 |
27 |
--------------------------------------------------------------------------------
/assets/fence/tinker.obj:
--------------------------------------------------------------------------------
1 | # Object Export From Tinkercad Server 2015
2 |
3 | mtllib obj.mtl
4 |
5 | o obj_0
6 | v 52 -50 30
7 | v 52 -50 0
8 | v 52 50 0
9 | v 52 50 30
10 | v 50 50 30
11 | v 50 -50 30
12 | v 50 -50 0
13 | v 50 50 0
14 | v 50 52 0
15 | v 50 52 30
16 | v -52 50 30
17 | v -52 50 0
18 | v -52 -50 0
19 | v -52 -50 30
20 | v 50 -52 30
21 | v 50 -52 0
22 | v -50 50 30
23 | v -50 -50 30
24 | v -50 -50 0
25 | v -50 50 0
26 | v -50 -52 0
27 | v -50 -52 30
28 | v -50 52 30
29 | v -50 52 0
30 | # 24 vertices
31 |
32 | g group_0_11593967
33 |
34 | usemtl color_11593967
35 | s 0
36 |
37 | f 1 2 3
38 | f 1 3 4
39 | f 4 5 6
40 | f 4 6 1
41 | f 2 7 8
42 | f 2 8 3
43 | f 6 7 2
44 | f 6 2 1
45 | f 4 3 8
46 | f 4 8 5
47 | f 8 7 6
48 | f 8 6 5
49 | # 12 faces
50 |
51 | g group_0_13165757
52 |
53 | usemtl color_13165757
54 | s 0
55 |
56 | f 15 16 7
57 | f 15 7 6
58 | f 21 22 18
59 | f 21 18 19
60 | f 6 18 22
61 | f 6 22 15
62 | f 22 21 16
63 | f 22 16 15
64 | f 16 21 19
65 | f 16 19 7
66 | f 6 7 19
67 | f 6 19 18
68 | # 12 faces
69 |
70 | g group_0_16311991
71 |
72 | usemtl color_16311991
73 | s 0
74 |
75 | f 11 12 13
76 | f 11 13 14
77 | f 17 11 14
78 | f 17 14 18
79 | f 19 13 12
80 | f 19 12 20
81 | f 14 13 19
82 | f 14 19 18
83 | f 17 20 12
84 | f 17 12 11
85 | f 19 20 17
86 | f 19 17 18
87 | # 12 faces
88 |
89 | g group_0_16500122
90 |
91 | usemtl color_16500122
92 | s 0
93 |
94 | f 9 10 5
95 | f 9 5 8
96 | f 23 24 20
97 | f 23 20 17
98 | f 8 20 24
99 | f 8 24 9
100 | f 10 23 17
101 | f 10 17 5
102 | f 10 9 24
103 | f 10 24 23
104 | f 17 20 8
105 | f 17 8 5
106 | # 12 faces
107 |
108 | #end of obj_0
109 |
110 |
--------------------------------------------------------------------------------
/assets/fence/tinker.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/assets/object_id/shapenet_id.json:
--------------------------------------------------------------------------------
1 | {
2 | "bottle": {
3 | "category_id": "02876657",
4 | "object_id": [
5 | "6b8b2cb01c376064c8724d5673a063a6",
6 | "547fa0085800c5c3846564a8a219239b",
7 | "91235f7d65aec958ca972daa503b3095",
8 | "9dff3d09b297cdd930612f5c0ef21eb8",
9 | "d45bf1487b41d2f630612f5c0ef21eb8",
10 | "898101350771ff942ae40d06128938a1",
11 | "3b26c9021d9e31a7ad8912880b776dcf",
12 | "ed55f39e04668bf9837048966ef3fcb9",
13 | "74690ddde9c0182696c2e7960a15c619",
14 | "e4ada697d05ac7acf9907e8bdd53291e",
15 | "fdc47f5f8dff184830eaaf40a8a562c1",
16 | "ad33ed7da4ef1b1cca18d703b8006093",
17 | "4d4fc73864844dad1ceb7b8cc3792fd",
18 | "42f85b0eb5e9fd508f9c4ecc067067e9"
19 | ],
20 | "global_scaling": [0.7, 0.8],
21 | "large_scaling": [1.1, 1.2]
22 | },
23 | "can": {
24 | "category_id": "02946921",
25 | "object_id": [
26 | "f4ad0b7f82c36051f51f77a6d7299806",
27 | "a70947df1f1490c2a81ec39fd9664e9b",
28 | "b6c4d78363d965617cb2a55fa21392b7",
29 | "91a524cc9c9be4872999f92861cdea7a",
30 | "96387095255f7080b7886d94372e3c76",
31 | "9b1f0ddd23357e01a81ec39fd9664e9b"
32 | ],
33 | "global_scaling": [0.7, 0.9],
34 | "large_scaling": [1.1, 1.2]
35 | },
36 | "mug": {
37 | "category_id": "03797390",
38 | "object_id": [
39 | "599e604a8265cc0a98765d8aa3638e70",
40 | "b46e89995f4f9cc5161e440f04bd2a2",
41 | "9c930a8a3411f069e7f67f334aa9295c",
42 | "2d10421716b16580e45ef4135c266a12",
43 | "71ca4fc9c8c29fa8d5abaf84513415a2"
44 | ],
45 | "global_scaling": [0.8, 0.9],
46 | "large_scaling": [1.1, 1.2]
47 | },
48 | "sofa": {
49 | "category_id": "04256520",
50 | "object_id": [
51 | "930873705bff9098e6e46d06d31ee634",
52 | "f094521e8579917eea65c47b660136e7",
53 | "adc4a9767d1c7bae8522c33a9d3f5757",
54 | "d0e419a11fd8f4bce589b08489d157d",
55 | "9aef63feacf65dd9cc3e9831f31c9164",
56 | "d55d14f87d65faa84ccf9d6d546b307f",
57 | "526c4f841f777635b5b328c62af5142",
58 | "65fce4b727c5df50e5f5c582d1bee164",
59 | "7c68894c83afb0118e8dcbd53cc631ab"
60 | ],
61 | "global_scaling": [0.4, 0.7],
62 | "large_scaling": [1.0, 1.2]
63 | },
64 | "phone": {
65 | "category_id": "04401088",
66 | "object_id": [
67 | "611afaaa1671ac8cc56f78d9daf213b",
68 | "b8555009f82af5da8c3645155d02fccc",
69 | "2725909a09e1a7961df58f4da76e254b",
70 | "9d021614c39c53dabee972a203aaf80",
71 | "fe34b663c44baf622ad536a59974757f",
72 | "2bb42eb0676116d41580700c4211a379",
73 | "9efabcf2ff8a4be9a59562d67b11f3d",
74 | "b9f67617cf320c20de4349e5bfa4fedb",
75 | "78855e0d8d27f00b42e82e1724e35ca",
76 | "401d5604ebfb4b43a7d4c094203303b1"
77 | ],
78 | "global_scaling": [0.7, 1.0],
79 | "large_scaling": [1.0, 1.1]
80 |
81 | }
82 | }
--------------------------------------------------------------------------------
/assets/object_id/ycb_id.json:
--------------------------------------------------------------------------------
1 | {
2 | "large_list": [
3 | "002_master_chef_can",
4 | "004_sugar_box",
5 | "006_mustard_bottle",
6 | "flipped-065-j_cups",
7 | "071_nine_hole_peg_test",
8 | "051_large_clamp"
9 | ],
10 | "normal_list": [
11 | "005_tomato_soup_can",
12 | "007_tuna_fish_can",
13 | "008_pudding_box",
14 | "009_gelatin_box",
15 | "010_potted_meat_can",
16 | "025_mug",
17 | "061_foam_brick",
18 | "077_rubiks_cube",
19 |
20 | "flipped-065-a_cups",
21 | "flipped-065-d_cups",
22 | "flipped-065-g_cups",
23 | "filled-073-a_lego_duplo",
24 | "filled-073-b_lego_duplo",
25 | "filled-073-c_lego_duplo",
26 | "filled-073-d_lego_duplo",
27 | "filled-073-f_lego_duplo"
28 | ]
29 | }
--------------------------------------------------------------------------------
/assets/stick/stick.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/assets/ur5/collision/base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/base.stl
--------------------------------------------------------------------------------
/assets/ur5/collision/forearm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/forearm.stl
--------------------------------------------------------------------------------
/assets/ur5/collision/shoulder.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/shoulder.stl
--------------------------------------------------------------------------------
/assets/ur5/collision/upperarm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/upperarm.stl
--------------------------------------------------------------------------------
/assets/ur5/collision/wrist1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/wrist1.stl
--------------------------------------------------------------------------------
/assets/ur5/collision/wrist2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/wrist2.stl
--------------------------------------------------------------------------------
/assets/ur5/collision/wrist3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/wrist3.stl
--------------------------------------------------------------------------------
/assets/ur5/mount.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/assets/ur5/notes.txt:
--------------------------------------------------------------------------------
1 |
2 | 2018-05-08 - Vijay Pradeep:
3 |
4 | These visual/ mesh files were generated from the dae files in
5 | ros-industrial universal robot repo [1]. Since the collada pyBullet
6 | parser is somewhat limited, it is unable to parse the UR collada mesh
7 | files. This, we imported these collada files into blender and
8 | converted them into STL files. We lost material definitions during
9 | the conversion, but that's ok.
10 |
11 | The URDF was generated by running the xacro xml preprocessor on the
12 | URDF included in the ur_description repo already mentioned here.
13 | Additional manual tweaking was required to update resource paths and
14 | to remove errors caused by missing inertia elements. Varios Gazebo
15 | plugin tags were also removed.
16 |
17 | [1] - https://github.com/ros-industrial/universal_robot/tree/kinetic-devel/ur_description/meshes/ur5/visual
18 |
--------------------------------------------------------------------------------
/assets/ur5/ur5.urdf:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/assets/ur5/visual/base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/base.stl
--------------------------------------------------------------------------------
/assets/ur5/visual/forearm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/forearm.stl
--------------------------------------------------------------------------------
/assets/ur5/visual/shoulder.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/shoulder.stl
--------------------------------------------------------------------------------
/assets/ur5/visual/upperarm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/upperarm.stl
--------------------------------------------------------------------------------
/assets/ur5/visual/wrist1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/wrist1.stl
--------------------------------------------------------------------------------
/assets/ur5/visual/wrist2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/wrist2.stl
--------------------------------------------------------------------------------
/assets/ur5/visual/wrist3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/wrist3.stl
--------------------------------------------------------------------------------
/binvox_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2012 Daniel Maturana
2 | # This file is part of binvox-rw-py.
3 | #
4 | # binvox-rw-py is free software: you can redistribute it and/or modify
5 | # it under the terms of the GNU General Public License as published by
6 | # the Free Software Foundation, either version 3 of the License, or
7 | # (at your option) any later version.
8 | #
9 | # binvox-rw-py is distributed in the hope that it will be useful,
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | # GNU General Public License for more details.
13 | #
14 | # You should have received a copy of the GNU General Public License
15 | # along with binvox-rw-py. If not, see <http://www.gnu.org/licenses/>.
16 | #
17 |
18 | """
19 | Binvox to Numpy and back.
20 |
21 |
22 | >>> import numpy as np
23 | >>> import binvox_rw
24 | >>> with open('chair.binvox', 'rb') as f:
25 | ... m1 = binvox_rw.read_as_3d_array(f)
26 | ...
27 | >>> m1.dims
28 | [32, 32, 32]
29 | >>> m1.scale
30 | 41.133000000000003
31 | >>> m1.translate
32 | [0.0, 0.0, 0.0]
33 | >>> with open('chair_out.binvox', 'wb') as f:
34 | ... m1.write(f)
35 | ...
36 | >>> with open('chair_out.binvox', 'rb') as f:
37 | ... m2 = binvox_rw.read_as_3d_array(f)
38 | ...
39 | >>> m1.dims==m2.dims
40 | True
41 | >>> m1.scale==m2.scale
42 | True
43 | >>> m1.translate==m2.translate
44 | True
45 | >>> np.all(m1.data==m2.data)
46 | True
47 |
48 | >>> with open('chair.binvox', 'rb') as f:
49 | ... md = binvox_rw.read_as_3d_array(f)
50 | ...
51 | >>> with open('chair.binvox', 'rb') as f:
52 | ... ms = binvox_rw.read_as_coord_array(f)
53 | ...
54 | >>> data_ds = binvox_rw.dense_to_sparse(md.data)
55 | >>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32)
56 | >>> np.all(data_sd==md.data)
57 | True
58 | >>> # the ordering of elements returned by numpy.nonzero changes with axis
59 | >>> # ordering, so to compare for equality we first lexically sort the voxels.
60 | >>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)])
61 | True
62 | """
63 |
64 | import numpy as np
65 |
66 | class Voxels(object):
67 | """ Holds a binvox model.
68 | data is either a three-dimensional numpy boolean array (dense representation)
69 | or a two-dimensional numpy float array (coordinate representation).
70 |
71 | dims, translate and scale are the model metadata.
72 |
73 | dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model.
74 |
75 | scale and translate relate the voxels to the original model coordinates.
76 |
77 | To translate voxel coordinates i, j, k to original coordinates x, y, z:
78 |
79 | x_n = (i+.5)/dims[0]
80 | y_n = (j+.5)/dims[1]
81 | z_n = (k+.5)/dims[2]
82 | x = scale*x_n + translate[0]
83 | y = scale*y_n + translate[1]
84 | z = scale*z_n + translate[2]
85 |
86 | """
87 |
88 | def __init__(self, data, dims, translate, scale, axis_order):
89 | self.data = data
90 | self.dims = dims
91 | self.translate = translate
92 | self.scale = scale
93 | assert (axis_order in ('xzy', 'xyz'))
94 | self.axis_order = axis_order
95 |
96 | def clone(self):
97 | data = self.data.copy()
98 | dims = self.dims[:]
99 | translate = self.translate[:]
100 | return Voxels(data, dims, translate, self.scale, self.axis_order)
101 |
102 | def write(self, fp):
103 | write(self, fp)
104 |
105 | def read_header(fp):
106 | """ Read binvox header. Mostly meant for internal use.
107 | """
108 | line = fp.readline().strip()
109 | if not line.startswith(b'#binvox'):
110 | raise IOError('Not a binvox file')
111 | dims = list(map(int, fp.readline().strip().split(b' ')[1:]))
112 | translate = list(map(float, fp.readline().strip().split(b' ')[1:]))
113 | scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0]
114 | line = fp.readline()
115 | return dims, translate, scale
116 |
117 | def read_as_3d_array(fp, fix_coords=True):
118 | """ Read binary binvox format as array.
119 |
120 | Returns the model with accompanying metadata.
121 |
122 | Voxels are stored in a three-dimensional numpy array, which is simple and
123 | direct, but may use a lot of memory for large models. (Storage requirements
124 | are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy
125 | boolean arrays use a byte per element).
126 |
127 | Doesn't do any checks on input except for the '#binvox' line.
128 | """
129 | dims, translate, scale = read_header(fp)
130 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
131 | # if just using reshape() on the raw data:
132 | # indexing the array as array[i,j,k], the indices map into the
133 | # coords as:
134 | # i -> x
135 | # j -> z
136 | # k -> y
137 | # if fix_coords is true, then data is rearranged so that
138 | # mapping is
139 | # i -> x
140 | # j -> y
141 | # k -> z
142 | values, counts = raw_data[::2], raw_data[1::2]
143 | data = np.repeat(values, counts).astype(np.bool)
144 | data = data.reshape(dims)
145 | if fix_coords:
146 | # xzy to xyz TODO the right thing
147 | data = np.transpose(data, (0, 2, 1))
148 | axis_order = 'xyz'
149 | else:
150 | axis_order = 'xzy'
151 | return Voxels(data, dims, translate, scale, axis_order)
152 |
153 | def read_as_coord_array(fp, fix_coords=True):
154 | """ Read binary binvox format as coordinates.
155 |
156 | Returns binvox model with voxels in a "coordinate" representation, i.e. an
157 | 3 x N array where N is the number of nonzero voxels. Each column
158 | corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates
159 | of the voxel. (The odd ordering is due to the way binvox format lays out
160 | data). Note that coordinates refer to the binvox voxels, without any
161 | scaling or translation.
162 |
163 | Use this to save memory if your model is very sparse (mostly empty).
164 |
165 | Doesn't do any checks on input except for the '#binvox' line.
166 | """
167 | dims, translate, scale = read_header(fp)
168 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8)
169 |
170 | values, counts = raw_data[::2], raw_data[1::2]
171 |
172 | sz = np.prod(dims)
173 | index, end_index = 0, 0
174 | end_indices = np.cumsum(counts)
175 | indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype)
176 |
177 | values = values.astype(np.bool)
178 | indices = indices[values]
179 | end_indices = end_indices[values]
180 |
181 | nz_voxels = []
182 | for index, end_index in zip(indices, end_indices):
183 | nz_voxels.extend(range(index, end_index))
184 | nz_voxels = np.array(nz_voxels)
185 | # TODO are these dims correct?
186 | # according to docs,
187 | # index = x * wxh + z * width + y; // wxh = width * height = d * d
188 |
189 | x = nz_voxels / (dims[0]*dims[1])
190 | zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y
191 | z = zwpy / dims[0]
192 | y = zwpy % dims[0]
193 | if fix_coords:
194 | data = np.vstack((x, y, z))
195 | axis_order = 'xyz'
196 | else:
197 | data = np.vstack((x, z, y))
198 | axis_order = 'xzy'
199 |
200 | #return Voxels(data, dims, translate, scale, axis_order)
201 | return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)
202 |
203 | def dense_to_sparse(voxel_data, dtype=np.int):
204 | """ From dense representation to sparse (coordinate) representation.
205 | No coordinate reordering.
206 | """
207 | if voxel_data.ndim!=3:
208 | raise ValueError('voxel_data is wrong shape; should be 3D array.')
209 | return np.asarray(np.nonzero(voxel_data), dtype)
210 |
211 | def sparse_to_dense(voxel_data, dims, dtype=np.bool):
212 | if voxel_data.ndim!=2 or voxel_data.shape[0]!=3:
213 | raise ValueError('voxel_data is wrong shape; should be 3xN array.')
214 | if np.isscalar(dims):
215 | dims = [dims]*3
216 | dims = np.atleast_2d(dims).T
217 | # truncate to integers
218 | xyz = voxel_data.astype(np.int)
219 | # discard voxels that fall outside dims
220 | valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)
221 | xyz = xyz[:,valid_ix]
222 | out = np.zeros(dims.flatten(), dtype=dtype)
223 | out[tuple(xyz)] = True
224 | return out
225 |
226 | #def get_linear_index(x, y, z, dims):
227 | #""" Assuming xzy order. (y increasing fastest.
228 | #TODO ensure this is right when dims are not all same
229 | #"""
230 | #return x*(dims[1]*dims[2]) + z*dims[1] + y
231 |
232 | def write(voxel_model, fp):
233 | """ Write binary binvox format.
234 |
235 | Note that when saving a model in sparse (coordinate) format, it is first
236 | converted to dense format.
237 |
238 | Doesn't check if the model is 'sane'.
239 |
240 | """
241 | if voxel_model.data.ndim==2:
242 | # TODO avoid conversion to dense
243 | dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims)
244 | else:
245 | dense_voxel_data = voxel_model.data
246 |
247 | fp.write('#binvox 1\n'.encode('ascii'))
248 | fp.write(('dim '+' '.join(map(str, voxel_model.dims))+'\n').encode('ascii'))
249 | fp.write(('translate '+' '.join(map(str, voxel_model.translate))+'\n').encode('ascii'))
250 | fp.write(('scale '+str(voxel_model.scale)+'\n').encode('ascii'))
251 | fp.write('data\n'.encode('ascii'))
252 | # fp.write('#binvox 1\n')
253 | # fp.write('dim ' + ' '.join(map(str, voxel_model.dims)) + '\n')
254 | # fp.write('translate ' + ' '.join(map(str, voxel_model.translate)) + '\n')
255 | # fp.write('scale ' + str(voxel_model.scale) + '\n')
256 | # fp.write('data\n')
257 | if not voxel_model.axis_order in ('xzy', 'xyz'):
258 | raise ValueError('Unsupported voxel model axis order')
259 |
260 | if voxel_model.axis_order=='xzy':
261 | voxels_flat = dense_voxel_data.flatten()
262 | elif voxel_model.axis_order=='xyz':
263 | voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()
264 |
265 | # keep a sort of state machine for writing run length encoding
266 | state = voxels_flat[0]
267 | ctr = 0
268 | for c in voxels_flat:
269 | if c==state:
270 | ctr += 1
271 | # if ctr hits max, dump
272 | if ctr==255:
273 | # fp.write(chr(state))
274 | # fp.write(chr(ctr))
275 | fp.write(state.tobytes())
276 | fp.write(ctr.to_bytes(1, byteorder='little'))
277 | ctr = 0
278 | else:
279 | # if switch state, dump
280 | # fp.write(chr(state))
281 | # fp.write(chr(ctr))
282 | fp.write(state.tobytes())
283 | fp.write(ctr.to_bytes(1, byteorder='little'))
284 | state = c
285 | ctr = 1
286 | # flush out remainders
287 | if ctr > 0:
288 | # fp.write(chr(state))
289 | # fp.write(chr(ctr))
290 | fp.write(state.tobytes())
291 | fp.write(ctr.to_bytes(1, byteorder='little'))
292 |
293 | if __name__ == '__main__':
294 | import doctest
295 | doctest.testmod()
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os.path as osp
3 | from torch.utils.data import Dataset
4 | import h5py
5 | from utils import imretype, draw_arrow
6 |
7 |
8 | class Data(Dataset):
9 | def __init__(self, data_path, split, seq_len):
10 | self.data_path = data_path
11 | self.tot_seq_len = 10
12 | self.seq_len = seq_len
13 | self.volume_size = [128, 128, 48]
14 | self.direction_num = 8
15 | self.voxel_size = 0.004
16 | self.idx_list = open(osp.join(self.data_path, '%s.txt' % split)).read().splitlines()
17 | self.returns = ['action', 'color_heightmap', 'color_image', 'tsdf', 'mask_3d', 'scene_flow_3d']
18 | self.data_per_seq = self.tot_seq_len // self.seq_len
19 |
20 | def __getitem__(self, index):
21 | data_dict = {}
22 | idx_seq = index // self.data_per_seq
23 | idx_step = index % self.data_per_seq * self.seq_len
24 | for step_id in range(self.seq_len):
25 | f = h5py.File(osp.join(self.data_path, "%s_%d.hdf5" % (self.idx_list[idx_seq], idx_step + step_id)), "r")
26 |
27 | # action
28 | action = f['action']
29 | data_dict['%d-action' % step_id] = self.get_action(action)
30 |
31 | # color_image, [W, H, 3]
32 | if 'color_image' in self.returns:
33 | data_dict['%d-color_image' % step_id] = np.asarray(f['color_image_small'], dtype=np.uint8)
34 |
35 | # color_heightmap, [128, 128, 3]
36 | if 'color_heightmap' in self.returns:
37 | # draw arrow for visualization
38 | color_heightmap = draw_arrow(
39 | np.asarray(f['color_heightmap'], dtype=np.uint8),
40 | (int(action[2]), int(action[1]), int(action[0]))
41 | )
42 | data_dict['%d-color_heightmap' % step_id] = color_heightmap
43 |
44 | # tsdf, [S1, S2, S3]
45 | if 'tsdf' in self.returns:
46 | data_dict['%d-tsdf' % step_id] = np.asarray(f['tsdf'], dtype=np.float32)
47 |
48 | # mask_3d, [S1, S2, S3]
49 | if 'mask_3d' in self.returns:
50 | data_dict['%d-mask_3d' % step_id] = np.asarray(f['mask_3d'], dtype=np.int)
51 |
52 | # scene_flow_3d, [3, S1, S2, S3]
53 | if 'scene_flow_3d' in self.returns:
54 | scene_flow_3d = np.asarray(f['scene_flow_3d'], dtype=np.float32).transpose([3, 0, 1, 2])
55 | data_dict['%d-scene_flow_3d' % step_id] = scene_flow_3d
56 |
57 | return data_dict
58 |
59 | def __len__(self):
60 | return len(self.idx_list) * self.data_per_seq
61 |
62 | def get_action(self, action):
63 | direction, r, c = int(action[0]), int(action[1]), int(action[2])
64 | if direction < 0:
65 | direction += self.direction_num
66 | action_map = np.zeros(shape=[self.direction_num, self.volume_size[0], self.volume_size[1]], dtype=np.float32)
67 | action_map[direction, r, c] = 1
68 |
69 | return action_map
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data
2 |
3 | ### Generate training data
4 | - Please refer to [link](../README.md).
5 | ### Download testing data
6 | - The following two testing datasets can be downloaded:
7 | - [Sim](https://dsr-net.cs.columbia.edu/download/data/sim_test_data.zip): 400 sequences, generated in pybullet.
8 | - [Real](https://dsr-net.cs.columbia.edu/download/data/real_test_data.zip): 150 sequences, with full annotations.
9 | - Unzip `real_test_data.zip` and `sim_test_data.zip` in this folder (see the example below).
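A minimal sketch, assuming both zip files have already been downloaded into this folder:
```
# from the repository root
cd data
unzip sim_test_data.zip
unzip real_test_data.zip
```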
--------------------------------------------------------------------------------
/data_generation.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import numpy as np
3 | from tqdm import tqdm
4 | import argparse
5 | import h5py
6 |
7 | from sim_env import SimulationEnv
8 | from utils import mkdir, project_pts_to_3d
9 | from fusion import TSDFVolume
10 |
11 | parser = argparse.ArgumentParser()
12 |
13 | parser.add_argument('--data_path', type=str, help='path to data')
14 | parser.add_argument('--train_num', type=int, help='number of training sequences')
15 | parser.add_argument('--test_num', type=int, help='number of testing sequences')
16 | parser.add_argument('--object_type', type=str, default='ycb', choices=['cube', 'shapenet', 'ycb'])
17 | parser.add_argument('--max_path_length', type=int, default=10, help='maximum length for each sequence')
18 | parser.add_argument('--object_num', type=int, default=4, help='number of objects')
19 |
20 | def main():
21 |
22 | args = parser.parse_args()
23 |
24 | for key in vars(args):
25 | print('[{0}] = {1}'.format(key, getattr(args, key)))
26 | mkdir(args.data_path, clean=False)
27 |
28 | env = SimulationEnv(gui_enabled=False)
29 | camera_pose = env.sim.camera_params[0]['camera_pose']
30 | camera_intr = env.sim.camera_params[0]['camera_intr']
31 | camera_pose_small = env.sim.camera_params[1]['camera_pose']
32 | camera_intr_small = env.sim.camera_params[1]['camera_intr']
33 |
34 | for rollout in tqdm(range(args.train_num + args.test_num)):
35 | env.reset(args.object_num, args.object_type)
36 | for step_num in range(args.max_path_length):
37 | f = h5py.File(osp.join(args.data_path, '%d_%d.hdf5' % (rollout, step_num)), 'w')
38 |
39 | output = env.poke()
40 | for key, val in output.items():
41 | if key == 'action':
42 | action = val
43 | f['action'] = np.array([action['0'], action['1'], action['2']])
44 | else:
45 | f.create_dataset(key, data=val, compression="gzip", compression_opts=4)
46 |
47 | # tsdf
48 | tsdf = get_volume(
49 | color_image=output['color_image'],
50 | depth_image=output['depth_image'],
51 | cam_intr=camera_intr,
52 | cam_pose=camera_pose
53 | )
54 | f.create_dataset('tsdf', data=tsdf, compression="gzip", compression_opts=4)
55 |
56 | # 3d pts
57 | color_image_small = output['color_image_small']
58 | depth_image_small = output['depth_image_small']
59 | pts_small = project_pts_to_3d(color_image_small, depth_image_small, camera_intr_small, camera_pose_small)
60 | f.create_dataset('pts_small', data=pts_small, compression="gzip", compression_opts=4)
61 |
62 | if step_num == args.max_path_length - 1:
63 | g_next = f.create_group('next')
64 | output = env.get_scene_info(mask_info=True)
65 | for key, val in output.items():
66 | g_next.create_dataset(key, data=val, compression="gzip", compression_opts=4)
67 | f.close()
68 |
69 | id_list = [i for i in range(args.train_num + args.test_num)]
70 | np.random.shuffle(id_list)
71 |
72 | with open(osp.join(args.data_path, 'train.txt'), 'w') as f:
73 | for k in range(args.train_num):
74 | print(id_list[k], file=f)
75 |
76 | with open(osp.join(args.data_path, 'test.txt'), 'w') as f:
77 | for k in range(args.train_num, args.train_num + args.test_num):
78 | print(id_list[k], file=f)
79 |
80 |
81 | def get_volume(color_image, depth_image, cam_intr, cam_pose, vol_bnds=None):
82 | voxel_size = 0.004
83 | if vol_bnds is None:
84 | vol_bnds = np.array([[0.244, 0.756],
85 | [-0.256, 0.256],
86 | [0.0, 0.192]])
87 | tsdf_vol = TSDFVolume(vol_bnds, voxel_size=voxel_size, use_gpu=True)
88 | tsdf_vol.integrate(color_image, depth_image, cam_intr, cam_pose, obs_weight=1.)
89 | volume = np.asarray(tsdf_vol.get_volume()[0])
90 | volume = np.transpose(volume, [1, 0, 2])
91 | return volume
92 |
93 |
94 | if __name__ == '__main__':
95 | main()
--------------------------------------------------------------------------------
/figures/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/figures/teaser.jpg
--------------------------------------------------------------------------------
/forward_warp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Function
4 | from cupy.cuda import function
5 | from pynvrtc.compiler import Program
6 | from collections import namedtuple
7 |
8 |
9 | class Forward_Warp_Cupy(Function):
10 | @staticmethod
11 | def forward(ctx, feature, flow, mask):
12 | kernel = '''
13 | extern "C"
14 | __global__ void warp_forward(
15 | const float * im0, // [B, C, W, H, D]
16 | const float * flow, // [B, 3, W, H, D]
17 | const float * mask, // [B, W, H, D]
18 | float * im1, // [B, C, W, H, D]
19 | float * cnt, // [B, W, H, D]
20 | const int vol_batch,
21 | const int vol_dim_x,
22 | const int vol_dim_y,
23 | const int vol_dim_z,
24 | const int feature_dim,
25 | const int warp_mode //0 (bilinear), 1 (nearest)
26 | ) {
27 | // Get voxel index
28 | int max_threads_per_block = blockDim.x;
29 | int block_idx = blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + blockIdx.x;
30 | int voxel_idx = block_idx * max_threads_per_block + threadIdx.x;
31 |
32 | int voxel_size_product = vol_dim_x * vol_dim_y * vol_dim_z;
33 |
34 | // IMPORTANT
35 | if (voxel_idx >= vol_batch * voxel_size_product) return;
36 |
37 | // Get voxel grid coordinates (note: be careful when casting)
38 | int tmp = voxel_idx;
39 |
40 | int voxel_z = tmp % vol_dim_z;
41 | tmp = tmp / vol_dim_z;
42 |
43 | int voxel_y = tmp % vol_dim_y;
44 | tmp = tmp / vol_dim_y;
45 |
46 | int voxel_x = tmp % vol_dim_x;
47 | int batch = tmp / vol_dim_x;
48 |
49 | int voxel_idx_BCWHD = voxel_idx + batch * (voxel_size_product * (feature_dim - 1));
50 | int voxel_idx_flow = voxel_idx + batch * (voxel_size_product * (3 - 1));
51 |
52 | // Main part
53 | if (warp_mode == 0) {
54 | // bilinear
55 | float x_float = voxel_x + flow[voxel_idx_flow];
56 | float y_float = voxel_y + flow[voxel_idx_flow + voxel_size_product];
57 | float z_float = voxel_z + flow[voxel_idx_flow + voxel_size_product + voxel_size_product];
58 |
59 | int x_floor = x_float;
60 | int y_floor = y_float;
61 | int z_floor = z_float;
62 |
63 | for(int t = 0; t < 8; t++) {
64 | int dx = (t >= 4);
65 | int dy = (t - 4 * dx) >= 2;
66 | int dz = t - 4 * dx - dy * 2;
67 |
68 | int x = x_floor + dx;
69 | int y = y_floor + dy;
70 | int z = z_floor + dz;
71 |
72 | if (x >= 0 && x < vol_dim_x && y >= 0 && y < vol_dim_y && z >= 0 && z < vol_dim_z) {
73 | float weight = mask[voxel_idx];
74 | weight *= (dx == 0 ? (x_floor + 1 - x_float) : (x_float - x_floor));
75 | weight *= (dy == 0 ? (y_floor + 1 - y_float) : (y_float - y_floor));
76 | weight *= (dz == 0 ? (z_floor + 1 - z_float) : (z_float - z_floor));
77 | int idx = (((int)batch * vol_dim_x + x) * vol_dim_y + y) * vol_dim_z + z;
78 | atomicAdd(&cnt[idx], weight);
79 |
80 | int idx_BCWHD = (((int)batch * feature_dim * vol_dim_x + x) * vol_dim_y + y) * vol_dim_z + z;
81 |
82 | for(int c = 0, offset = 0; c < feature_dim; c++, offset += voxel_size_product) {
83 | atomicAdd(&im1[idx_BCWHD + offset], im0[voxel_idx_BCWHD + offset] * weight);
84 | }
85 | }
86 |
87 | }
88 | } else {
89 | // nearest
90 | int x = round(voxel_x + flow[voxel_idx_flow]);
91 | int y = round(voxel_y + flow[voxel_idx_flow + voxel_size_product]);
92 | int z = round(voxel_z + flow[voxel_idx_flow + voxel_size_product + voxel_size_product]);
93 |
94 | if (x >= 0 && x < vol_dim_x && y >= 0 && y < vol_dim_y && z >= 0 && z < vol_dim_z) {
95 | int idx = (((int)batch * vol_dim_x + x) * vol_dim_y + y) * vol_dim_z + z;
96 | float mask_weight = mask[voxel_idx];
97 | atomicAdd(&cnt[idx], mask_weight);
98 |
99 | int idx_BCWHD = (((int)batch * feature_dim * vol_dim_x + x) * vol_dim_y + y) * vol_dim_z + z;
100 |
101 | for(int c = 0, offset = 0; c < feature_dim; c++, offset += voxel_size_product) {
102 | atomicAdd(&im1[idx_BCWHD + offset], im0[voxel_idx_BCWHD + offset] * mask_weight);
103 | }
104 | }
105 | }
106 | }
107 | '''
108 | program = Program(kernel, 'warp_forward.cu')
109 | ptx = program.compile()
110 | m = function.Module()
111 | m.load(bytes(ptx.encode()))
112 | f = m.get_function('warp_forward')
113 | Stream = namedtuple('Stream', ['ptr'])
114 | s = Stream(ptr=torch.cuda.current_stream().cuda_stream)
115 |
116 | B, C, W, H, D = feature.size()
117 | warp_mode = 0
118 | n_blocks = np.ceil(B * W * H * D / 1024.0)
119 | grid_dim_x = int(np.cbrt(n_blocks))
120 | grid_dim_y = int(np.sqrt(n_blocks / grid_dim_x))
121 | grid_dim_z = int(np.ceil(n_blocks / grid_dim_x / grid_dim_y))
122 | assert grid_dim_x * grid_dim_y * grid_dim_z * 1024 >= B * W * H * D
123 |
124 | feature_new = torch.zeros_like(feature)
125 | cnt = torch.zeros_like(mask)
126 |
127 | f(grid=(grid_dim_x, grid_dim_y, grid_dim_z), block=(1024, 1, 1),
128 | args=[feature.data_ptr(), flow.data_ptr(), mask.data_ptr(), feature_new.data_ptr(), cnt.data_ptr(),
129 | B, W, H, D, C, warp_mode], stream=s)
130 |
131 | eps=1e-3
132 | cnt = torch.max(cnt, other=torch.ones_like(cnt) * eps)
133 | feature_new = feature_new / torch.unsqueeze(cnt, 1)
134 |
135 | return feature_new
136 |
137 | @staticmethod
138 | def backward(ctx, feature_new_grad):
139 | # Not implemented
140 | return None, None, None
--------------------------------------------------------------------------------
/fusion.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from numba import njit, prange
4 | from skimage import measure
5 |
6 | try:
7 | import pycuda.driver as cuda
8 | import pycuda.autoinit
9 | from pycuda.compiler import SourceModule
10 | FUSION_GPU_MODE = 1
11 | except Exception as err:
12 | print('Warning: {}'.format(err))
13 | print('Failed to import PyCUDA. Running fusion in CPU mode.')
14 | FUSION_GPU_MODE = 0
15 |
16 |
17 | class TSDFVolume:
18 | """Volumetric TSDF Fusion of RGB-D Images.
19 | """
20 |
21 | def __init__(self, vol_bnds, voxel_size, use_gpu=True):
22 | """Constructor.
23 | Args:
24 | vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the
25 | xyz bounds (min/max) in meters.
26 | voxel_size (float): The volume discretization in meters.
27 | """
28 | vol_bnds = np.asarray(vol_bnds)
29 | assert vol_bnds.shape == (
30 | 3, 2), "[!] `vol_bnds` should be of shape (3, 2)."
31 |
32 | # Define voxel volume parameters
33 | self._vol_bnds = vol_bnds
34 | self._voxel_size = float(voxel_size)
35 | self._trunc_margin = 5 * self._voxel_size # truncation on SDF
36 | self._color_const = 256 * 256
37 |
38 | # Adjust volume bounds and ensure C-order contiguous
39 | self._vol_dim = np.ceil(
40 | (self._vol_bnds[:, 1]-self._vol_bnds[:, 0])/self._voxel_size).copy(order='C').astype(int)
41 | self._vol_bnds[:, 1] = self._vol_bnds[:, 0] + \
42 | self._vol_dim*self._voxel_size
43 | self._vol_origin = self._vol_bnds[:, 0].copy(
44 | order='C').astype(np.float32)
45 |
46 | # Initialize pointers to voxel volume in CPU memory
47 | self._tsdf_vol_cpu = np.ones(self._vol_dim).astype(np.float32)
48 | # for computing the cumulative moving average of observations per voxel
49 | self._weight_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
50 | self._color_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
51 |
52 | self.gpu_mode = use_gpu and FUSION_GPU_MODE
53 |
54 | # Copy voxel volumes to GPU
55 | if self.gpu_mode:
56 | self._tsdf_vol_gpu = cuda.mem_alloc(self._tsdf_vol_cpu.nbytes)
57 | cuda.memcpy_htod(self._tsdf_vol_gpu, self._tsdf_vol_cpu)
58 | self._weight_vol_gpu = cuda.mem_alloc(self._weight_vol_cpu.nbytes)
59 | cuda.memcpy_htod(self._weight_vol_gpu, self._weight_vol_cpu)
60 | self._color_vol_gpu = cuda.mem_alloc(self._color_vol_cpu.nbytes)
61 | cuda.memcpy_htod(self._color_vol_gpu, self._color_vol_cpu)
62 |
63 | # Cuda kernel function (C++)
64 | self._cuda_src_mod = SourceModule("""
65 | __global__ void integrate(float * tsdf_vol,
66 | float * weight_vol,
67 | float * color_vol,
68 | float * vol_dim,
69 | float * vol_origin,
70 | float * cam_intr,
71 | float * cam_pose,
72 | float * other_params,
73 | float * color_im,
74 | float * depth_im) {
75 | // Get voxel index
76 | int gpu_loop_idx = (int) other_params[0];
77 | int max_threads_per_block = blockDim.x;
78 | int block_idx = blockIdx.z*gridDim.y*gridDim.x+blockIdx.y*gridDim.x+blockIdx.x;
79 | int voxel_idx = gpu_loop_idx*gridDim.x*gridDim.y*gridDim.z*max_threads_per_block+block_idx*max_threads_per_block+threadIdx.x;
80 | int vol_dim_x = (int) vol_dim[0];
81 | int vol_dim_y = (int) vol_dim[1];
82 | int vol_dim_z = (int) vol_dim[2];
83 | if (voxel_idx >= vol_dim_x*vol_dim_y*vol_dim_z)
84 | return;
85 | // Get voxel grid coordinates (note: be careful when casting)
86 | float voxel_x = floorf(((float)voxel_idx)/((float)(vol_dim_y*vol_dim_z)));
87 | float voxel_y = floorf(((float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z))/((float)vol_dim_z));
88 | float voxel_z = (float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z-((int)voxel_y)*vol_dim_z);
89 | // Voxel grid coordinates to world coordinates
90 | float voxel_size = other_params[1];
91 | float pt_x = vol_origin[0]+voxel_x*voxel_size;
92 | float pt_y = vol_origin[1]+voxel_y*voxel_size;
93 | float pt_z = vol_origin[2]+voxel_z*voxel_size;
94 | // World coordinates to camera coordinates
95 | float tmp_pt_x = pt_x-cam_pose[0*4+3];
96 | float tmp_pt_y = pt_y-cam_pose[1*4+3];
97 | float tmp_pt_z = pt_z-cam_pose[2*4+3];
98 | float cam_pt_x = cam_pose[0*4+0]*tmp_pt_x+cam_pose[1*4+0]*tmp_pt_y+cam_pose[2*4+0]*tmp_pt_z;
99 | float cam_pt_y = cam_pose[0*4+1]*tmp_pt_x+cam_pose[1*4+1]*tmp_pt_y+cam_pose[2*4+1]*tmp_pt_z;
100 | float cam_pt_z = cam_pose[0*4+2]*tmp_pt_x+cam_pose[1*4+2]*tmp_pt_y+cam_pose[2*4+2]*tmp_pt_z;
101 | // Camera coordinates to image pixels
102 | int pixel_x = (int) roundf(cam_intr[0*3+0]*(cam_pt_x/cam_pt_z)+cam_intr[0*3+2]);
103 | int pixel_y = (int) roundf(cam_intr[1*3+1]*(cam_pt_y/cam_pt_z)+cam_intr[1*3+2]);
104 | // Skip if outside view frustum
105 | int im_h = (int) other_params[2];
106 | int im_w = (int) other_params[3];
107 | if (pixel_x < 0 || pixel_x >= im_w || pixel_y < 0 || pixel_y >= im_h || cam_pt_z<0)
108 | return;
109 | // Skip invalid depth
110 | float depth_value = depth_im[pixel_y*im_w+pixel_x];
111 | if (depth_value == 0)
112 | return;
113 | // Integrate TSDF
114 | float trunc_margin = other_params[4];
115 | float depth_diff = depth_value-cam_pt_z;
116 | if (depth_diff < -trunc_margin)
117 | return;
118 | float dist = fmin(1.0f,depth_diff/trunc_margin);
119 | float w_old = weight_vol[voxel_idx];
120 | float obs_weight = other_params[5];
121 | float w_new = w_old + obs_weight;
122 | weight_vol[voxel_idx] = w_new;
123 | tsdf_vol[voxel_idx] = (tsdf_vol[voxel_idx]*w_old+obs_weight*dist)/w_new;
124 | // Integrate color
125 | float old_color = color_vol[voxel_idx];
126 | float old_b = floorf(old_color/(256*256));
127 | float old_g = floorf((old_color-old_b*256*256)/256);
128 | float old_r = old_color-old_b*256*256-old_g*256;
129 | float new_color = color_im[pixel_y*im_w+pixel_x];
130 | float new_b = floorf(new_color/(256*256));
131 | float new_g = floorf((new_color-new_b*256*256)/256);
132 | float new_r = new_color-new_b*256*256-new_g*256;
133 | new_b = fmin(roundf((old_b*w_old+obs_weight*new_b)/w_new),255.0f);
134 | new_g = fmin(roundf((old_g*w_old+obs_weight*new_g)/w_new),255.0f);
135 | new_r = fmin(roundf((old_r*w_old+obs_weight*new_r)/w_new),255.0f);
136 | color_vol[voxel_idx] = new_b*256*256+new_g*256+new_r;
137 | }""")
138 |
139 | self._cuda_integrate = self._cuda_src_mod.get_function("integrate")
140 |
141 | # Determine block/grid size on GPU
142 | gpu_dev = cuda.Device(0)
143 | self._max_gpu_threads_per_block = gpu_dev.MAX_THREADS_PER_BLOCK
144 | n_blocks = int(np.ceil(float(np.prod(self._vol_dim)) /
145 | float(self._max_gpu_threads_per_block)))
146 | grid_dim_x = min(gpu_dev.MAX_GRID_DIM_X,
147 | int(np.floor(np.cbrt(n_blocks))))
148 | grid_dim_y = min(gpu_dev.MAX_GRID_DIM_Y, int(
149 | np.floor(np.sqrt(n_blocks/grid_dim_x))))
150 | grid_dim_z = min(gpu_dev.MAX_GRID_DIM_Z, int(
151 | np.ceil(float(n_blocks)/float(grid_dim_x*grid_dim_y))))
152 | self._max_gpu_grid_dim = np.array(
153 | [grid_dim_x, grid_dim_y, grid_dim_z]).astype(int)
154 | self._n_gpu_loops = int(np.ceil(float(np.prod(
155 | self._vol_dim))/float(np.prod(self._max_gpu_grid_dim)*self._max_gpu_threads_per_block)))
156 |
157 | else:
158 | # Get voxel grid coordinates
159 | xv, yv, zv = np.meshgrid(
160 | range(self._vol_dim[0]),
161 | range(self._vol_dim[1]),
162 | range(self._vol_dim[2]),
163 | indexing='ij'
164 | )
165 | self.vox_coords = np.concatenate([
166 | xv.reshape(1, -1),
167 | yv.reshape(1, -1),
168 | zv.reshape(1, -1)
169 | ], axis=0).astype(int).T
170 |
171 | @staticmethod
172 | @njit(parallel=True)
173 | def vox2world(vol_origin, vox_coords, vox_size):
174 | """Convert voxel grid coordinates to world coordinates.
175 | """
176 | vol_origin = vol_origin.astype(np.float32)
177 | vox_coords = vox_coords.astype(np.float32)
178 | cam_pts = np.empty_like(vox_coords, dtype=np.float32)
179 | for i in prange(vox_coords.shape[0]):
180 | for j in range(3):
181 | cam_pts[i, j] = vol_origin[j] + (vox_size * vox_coords[i, j])
182 | return cam_pts
183 |
184 | @staticmethod
185 | @njit(parallel=True)
186 | def cam2pix(cam_pts, intr):
187 | """Convert camera coordinates to pixel coordinates.
188 | """
189 | intr = intr.astype(np.float32)
190 | fx, fy = intr[0, 0], intr[1, 1]
191 | cx, cy = intr[0, 2], intr[1, 2]
192 | pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64)
193 | for i in prange(cam_pts.shape[0]):
194 | pix[i, 0] = int(
195 | np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx))
196 | pix[i, 1] = int(
197 | np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy))
198 | return pix
199 |
200 | @staticmethod
201 | @njit(parallel=True)
202 | def integrate_tsdf(tsdf_vol, dist, w_old, obs_weight):
203 | """Integrate the TSDF volume.
204 | """
205 | tsdf_vol_int = np.empty_like(tsdf_vol, dtype=np.float32)
206 | w_new = np.empty_like(w_old, dtype=np.float32)
207 | for i in prange(len(tsdf_vol)):
208 | w_new[i] = w_old[i] + obs_weight
209 | tsdf_vol_int[i] = (w_old[i] * tsdf_vol[i] +
210 | obs_weight * dist[i]) / w_new[i]
211 | return tsdf_vol_int, w_new
212 |
213 | def integrate(self, color_im, depth_im, cam_intr, cam_pose, obs_weight=1.):
214 | """Integrate an RGB-D frame into the TSDF volume.
215 | Args:
216 | color_im (ndarray): An RGB image of shape (H, W, 3).
217 | depth_im (ndarray): A depth image of shape (H, W).
218 | cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3).
219 | cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4).
220 | obs_weight (float): The weight to assign to the current observation. A higher
221 | value gives the current observation more influence in the running average.
222 | """
223 | im_h, im_w = depth_im.shape
224 |
225 | # Fold RGB color image into a single channel image
226 | color_im = color_im.astype(np.float32)
227 | color_im = np.floor(
228 | color_im[..., 2]*self._color_const + color_im[..., 1]*256 + color_im[..., 0])
229 |
230 | # GPU mode: integrate voxel volume (calls CUDA kernel)
231 | if self.gpu_mode:
232 | for gpu_loop_idx in range(self._n_gpu_loops):
233 | self._cuda_integrate(self._tsdf_vol_gpu,
234 | self._weight_vol_gpu,
235 | self._color_vol_gpu,
236 | cuda.InOut(
237 | self._vol_dim.astype(np.float32)),
238 | cuda.InOut(
239 | self._vol_origin.astype(np.float32)),
240 | cuda.InOut(
241 | cam_intr.reshape(-1).astype(np.float32)),
242 | cuda.InOut(
243 | cam_pose.reshape(-1).astype(np.float32)),
244 | cuda.InOut(np.asarray([
245 | gpu_loop_idx,
246 | self._voxel_size,
247 | im_h,
248 | im_w,
249 | self._trunc_margin,
250 | obs_weight
251 | ], np.float32)),
252 | cuda.InOut(
253 | color_im.reshape(-1).astype(np.float32)),
254 | cuda.InOut(
255 | depth_im.reshape(-1).astype(np.float32)),
256 | block=(
257 | self._max_gpu_threads_per_block, 1, 1),
258 | grid=(
259 | int(self._max_gpu_grid_dim[0]),
260 | int(self._max_gpu_grid_dim[1]),
261 | int(self._max_gpu_grid_dim[2]),
262 | )
263 | )
264 | else: # CPU mode: integrate voxel volume (vectorized implementation)
265 | # Convert voxel grid coordinates to pixel coordinates
266 | cam_pts = self.vox2world(
267 | self._vol_origin, self.vox_coords, self._voxel_size)
268 | cam_pts = rigid_transform(cam_pts, np.linalg.inv(cam_pose))
269 | pix_z = cam_pts[:, 2]
270 | pix = self.cam2pix(cam_pts, cam_intr)
271 | pix_x, pix_y = pix[:, 0], pix[:, 1]
272 |
273 | # Eliminate pixels outside view frustum
274 | valid_pix = np.logical_and(pix_x >= 0,
275 | np.logical_and(pix_x < im_w,
276 | np.logical_and(pix_y >= 0,
277 | np.logical_and(pix_y < im_h,
278 | pix_z > 0))))
279 | depth_val = np.zeros(pix_x.shape)
280 | depth_val[valid_pix] = depth_im[pix_y[valid_pix], pix_x[valid_pix]]
281 |
282 | # Integrate TSDF
283 | depth_diff = depth_val - pix_z
284 | valid_pts = np.logical_and(
285 | depth_val > 0, depth_diff >= -self._trunc_margin)
286 | dist = np.minimum(1, depth_diff / self._trunc_margin)
287 | valid_vox_x = self.vox_coords[valid_pts, 0]
288 | valid_vox_y = self.vox_coords[valid_pts, 1]
289 | valid_vox_z = self.vox_coords[valid_pts, 2]
290 | w_old = self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
291 | tsdf_vals = self._tsdf_vol_cpu[valid_vox_x,
292 | valid_vox_y, valid_vox_z]
293 | valid_dist = dist[valid_pts]
294 | tsdf_vol_new, w_new = self.integrate_tsdf(
295 | tsdf_vals, valid_dist, w_old, obs_weight)
296 | self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = w_new
297 | self._tsdf_vol_cpu[valid_vox_x,
298 | valid_vox_y, valid_vox_z] = tsdf_vol_new
299 |
300 | # Integrate color
301 | old_color = self._color_vol_cpu[valid_vox_x,
302 | valid_vox_y, valid_vox_z]
303 | old_b = np.floor(old_color / self._color_const)
304 | old_g = np.floor((old_color-old_b*self._color_const)/256)
305 | old_r = old_color - old_b*self._color_const - old_g*256
306 | new_color = color_im[pix_y[valid_pts], pix_x[valid_pts]]
307 | new_b = np.floor(new_color / self._color_const)
308 | new_g = np.floor((new_color - new_b*self._color_const) / 256)
309 | new_r = new_color - new_b*self._color_const - new_g*256
310 | new_b = np.minimum(255., np.round(
311 | (w_old*old_b + obs_weight*new_b) / w_new))
312 | new_g = np.minimum(255., np.round(
313 | (w_old*old_g + obs_weight*new_g) / w_new))
314 | new_r = np.minimum(255., np.round(
315 | (w_old*old_r + obs_weight*new_r) / w_new))
316 | self._color_vol_cpu[valid_vox_x, valid_vox_y,
317 | valid_vox_z] = new_b*self._color_const + new_g*256 + new_r
318 |
319 | def get_volume(self):
320 | if self.gpu_mode:
321 | cuda.memcpy_dtoh(self._tsdf_vol_cpu, self._tsdf_vol_gpu)
322 | cuda.memcpy_dtoh(self._color_vol_cpu, self._color_vol_gpu)
323 | return self._tsdf_vol_cpu, self._color_vol_cpu
324 |
325 | def get_point_cloud(self):
326 | """Extract a point cloud from the voxel volume.
327 | """
328 | tsdf_vol, color_vol = self.get_volume()
329 |
330 | # Marching cubes
331 | verts = measure.marching_cubes(tsdf_vol, level=0)[0]
332 | verts_ind = np.round(verts).astype(int)
333 | verts = verts*self._voxel_size + self._vol_origin
334 |
335 | # Get vertex colors
336 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]]
337 | colors_b = np.floor(rgb_vals / self._color_const)
338 | colors_g = np.floor((rgb_vals - colors_b*self._color_const) / 256)
339 | colors_r = rgb_vals - colors_b*self._color_const - colors_g*256
340 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T
341 | colors = colors.astype(np.uint8)
342 |
343 | pc = np.hstack([verts, colors])
344 | return pc
345 |
346 | def get_mesh(self):
347 | """Compute a mesh from the voxel volume using marching cubes.
348 | """
349 | tsdf_vol, color_vol = self.get_volume()
350 |
351 | # Marching cubes
352 | verts, faces, norms, vals = measure.marching_cubes(
353 | tsdf_vol, level=0)
354 | verts_ind = np.round(verts).astype(int)
355 | # voxel grid coordinates to world coordinates
356 | verts = verts*self._voxel_size+self._vol_origin
357 |
358 | # Get vertex colors
359 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]]
360 | colors_b = np.floor(rgb_vals/self._color_const)
361 | colors_g = np.floor((rgb_vals-colors_b*self._color_const)/256)
362 | colors_r = rgb_vals-colors_b*self._color_const-colors_g*256
363 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T
364 | colors = colors.astype(np.uint8)
365 | return verts, faces, norms, colors
366 |
367 |
368 | def rigid_transform(xyz, transform):
369 | """Applies a rigid transform to an (N, 3) pointcloud.
370 | """
371 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)])
372 | xyz_t_h = np.dot(transform, xyz_h.T).T
373 | return xyz_t_h[:, :3]
374 |
375 |
376 | def meshwrite(filename, verts, faces, norms, colors):
377 | """Save a 3D mesh to a polygon .ply file.
378 | """
379 | # Write header
380 | ply_file = open(filename, 'w')
381 | ply_file.write("ply\n")
382 | ply_file.write("format ascii 1.0\n")
383 | ply_file.write("element vertex %d\n" % (verts.shape[0]))
384 | ply_file.write("property float x\n")
385 | ply_file.write("property float y\n")
386 | ply_file.write("property float z\n")
387 | ply_file.write("property float nx\n")
388 | ply_file.write("property float ny\n")
389 | ply_file.write("property float nz\n")
390 | ply_file.write("property uchar red\n")
391 | ply_file.write("property uchar green\n")
392 | ply_file.write("property uchar blue\n")
393 | ply_file.write("element face %d\n" % (faces.shape[0]))
394 | ply_file.write("property list uchar int vertex_index\n")
395 | ply_file.write("end_header\n")
396 |
397 | # Write vertex list
398 | for i in range(verts.shape[0]):
399 | ply_file.write("%f %f %f %f %f %f %d %d %d\n" % (
400 | verts[i, 0], verts[i, 1], verts[i, 2],
401 | norms[i, 0], norms[i, 1], norms[i, 2],
402 | colors[i, 0], colors[i, 1], colors[i, 2],
403 | ))
404 |
405 | # Write face list
406 | for i in range(faces.shape[0]):
407 | ply_file.write("3 %d %d %d\n" %
408 | (faces[i, 0], faces[i, 1], faces[i, 2]))
409 |
410 | ply_file.close()
411 |
412 | def pcwrite(filename, xyzrgb):
413 | """Save a point cloud to a polygon .ply file.
414 | """
415 | xyz = xyzrgb[:, :3]
416 | rgb = xyzrgb[:, 3:].astype(np.uint8)
417 |
418 | # Write header
419 | ply_file = open(filename,'w')
420 | ply_file.write("ply\n")
421 | ply_file.write("format ascii 1.0\n")
422 | ply_file.write("element vertex %d\n"%(xyz.shape[0]))
423 | ply_file.write("property float x\n")
424 | ply_file.write("property float y\n")
425 | ply_file.write("property float z\n")
426 | ply_file.write("property uchar red\n")
427 | ply_file.write("property uchar green\n")
428 | ply_file.write("property uchar blue\n")
429 | ply_file.write("end_header\n")
430 |
431 | # Write vertex list
432 | for i in range(xyz.shape[0]):
433 | ply_file.write("%f %f %f %d %d %d\n"%(
434 | xyz[i, 0], xyz[i, 1], xyz[i, 2],
435 | rgb[i, 0], rgb[i, 1], rgb[i, 2],
436 | ))
437 |     ply_file.close()
438 | def tsdf2mesh(tsdf_vol, mesh_path):
439 | tsdf_vol = np.transpose(tsdf_vol, (1, 0, 2))
440 | # Marching cubes
441 | verts,faces,norms,vals = measure.marching_cubes(tsdf_vol,level=0.0)
442 |
443 | # Get vertex colors
444 | colors = np.array([(186//2, 176//2, 172//2) for _ in range(verts.shape[0])], dtype=np.uint8)
445 | meshwrite(mesh_path, verts, faces, norms, colors)
--------------------------------------------------------------------------------
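The export helpers above (`meshwrite`, `pcwrite`, `tsdf2mesh`) and the volume's `integrate`/`get_mesh` methods fit together roughly as follows. This is a minimal usage sketch: `tsdf` stands in for an instance of the fusion volume class defined earlier in this file, and the frame variables are purely illustrative.

```python
import numpy as np

# Hypothetical inputs: color_im (H, W, 3) uint8, depth_im (H, W) float meters,
# cam_intr (3, 3), cam_pose (4, 4); `tsdf` is an already-constructed volume.
tsdf.integrate(color_im, depth_im, cam_intr, cam_pose, obs_weight=1.0)

# Export the fused volume as a colored mesh and a point cloud.
verts, faces, norms, colors = tsdf.get_mesh()
meshwrite('scene_mesh.ply', verts, faces, norms, colors)
pcwrite('scene_cloud.ply', tsdf.get_point_cloud())

# A raw TSDF array (e.g. a network prediction) can also be meshed directly.
tsdf_arr = np.load('tsdf_volume.npy')  # hypothetical (S1, S2, S3) float array
tsdf2mesh(tsdf_arr, 'pred_mesh.ply')
```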
/model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import torch.nn.functional as F
4 | from model_utils import ConvBlock3D, ResBlock3D, ConvBlock2D, MLP
5 | from se3.se3_module import SE3
6 | from forward_warp import Forward_Warp_Cupy
7 |
8 |
9 | class VolumeEncoder(nn.Module):
10 | def __init__(self):
11 | super().__init__()
12 | input_channel = 12
13 | self.conv00 = ConvBlock3D(input_channel, 16, stride=2, dilation=1, norm=True, relu=True) # 64x64x24
14 |
15 | self.conv10 = ConvBlock3D(16, 32, stride=2, dilation=1, norm=True, relu=True) # 32x32x12
16 | self.conv11 = ConvBlock3D(32, 32, stride=1, dilation=1, norm=True, relu=True)
17 | self.conv12 = ConvBlock3D(32, 32, stride=1, dilation=1, norm=True, relu=True)
18 | self.conv13 = ConvBlock3D(32, 32, stride=1, dilation=1, norm=True, relu=True)
19 |
20 | self.conv20 = ConvBlock3D(32, 64, stride=2, dilation=1, norm=True, relu=True) # 16x16x6
21 | self.conv21 = ConvBlock3D(64, 64, stride=1, dilation=1, norm=True, relu=True)
22 | self.conv22 = ConvBlock3D(64, 64, stride=1, dilation=1, norm=True, relu=True)
23 | self.conv23 = ConvBlock3D(64, 64, stride=1, dilation=1, norm=True, relu=True)
24 |
25 | self.conv30 = ConvBlock3D(64, 128, stride=2, dilation=1, norm=True, relu=True) # 8x8x3
26 | self.resn31 = ResBlock3D(128, 128)
27 | self.resn32 = ResBlock3D(128, 128)
28 |
29 |
30 | def forward(self, x):
31 | x0 = self.conv00(x)
32 |
33 | x1 = self.conv10(x0)
34 | x1 = self.conv11(x1)
35 | x1 = self.conv12(x1)
36 | x1 = self.conv13(x1)
37 |
38 | x2 = self.conv20(x1)
39 | x2 = self.conv21(x2)
40 | x2 = self.conv22(x2)
41 | x2 = self.conv23(x2)
42 |
43 | x3 = self.conv30(x2)
44 | x3 = self.resn31(x3)
45 | x3 = self.resn32(x3)
46 |
47 | return x3, (x2, x1, x0)
48 |
49 |
50 | class FeatureDecoder(nn.Module):
51 | def __init__(self):
52 | super().__init__()
53 | self.conv00 = ConvBlock3D(128, 64, norm=True, relu=True, upsm=True) # 16x16x6
54 | self.conv01 = ConvBlock3D(64, 64, norm=True, relu=True)
55 |
56 | self.conv10 = ConvBlock3D(64 + 64, 32, norm=True, relu=True, upsm=True) # 32x32x12
57 | self.conv11 = ConvBlock3D(32, 32, norm=True, relu=True)
58 |
59 | self.conv20 = ConvBlock3D(32 + 32, 16, norm=True, relu=True, upsm=True) # 64X64X24
60 | self.conv21 = ConvBlock3D(16, 16, norm=True, relu=True)
61 |
62 | self.conv30 = ConvBlock3D(16 + 16, 8, norm=True, relu=True, upsm=True) # 128X128X48
63 | self.conv31 = ConvBlock3D(8, 8, norm=True, relu=True)
64 |
65 | def forward(self, x, cache):
66 | m0, m1, m2 = cache
67 |
68 | x0 = self.conv00(x)
69 | x0 = self.conv01(x0)
70 |
71 | x1 = self.conv10(torch.cat([x0, m0], dim=1))
72 | x1 = self.conv11(x1)
73 |
74 | x2 = self.conv20(torch.cat([x1, m1], dim=1))
75 | x2 = self.conv21(x2)
76 |
77 | x3 = self.conv30(torch.cat([x2, m2], dim=1))
78 | x3 = self.conv31(x3)
79 |
80 | return x3
81 |
82 |
83 | class MotionDecoder(nn.Module):
84 | def __init__(self):
85 | super().__init__()
86 | self.conv3d00 = ConvBlock3D(8 + 8, 8, stride=2, dilation=1, norm=True, relu=True) # 64
87 |
88 | self.conv3d10 = ConvBlock3D(8 + 8, 16, stride=2, dilation=1, norm=True, relu=True) # 32
89 |
90 | self.conv3d20 = ConvBlock3D(16 + 16, 32, stride=2, dilation=1, norm=True, relu=True) # 16
91 |
92 | self.conv3d30 = ConvBlock3D(32, 16, dilation=1, norm=True, relu=True, upsm=True) # 32
93 | self.conv3d40 = ConvBlock3D(16, 8, dilation=1, norm=True, relu=True, upsm=True) # 64
94 | self.conv3d50 = ConvBlock3D(8, 8, dilation=1, norm=True, relu=True, upsm=True) # 128
95 | self.conv3d60 = nn.Conv3d(8, 3, kernel_size=3, padding=1)
96 |
97 |
98 | self.conv2d10 = ConvBlock2D(8, 64, stride=2, norm=True, relu=True) # 64
99 | self.conv2d11 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True)
100 | self.conv2d12 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True)
101 | self.conv2d13 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True)
102 | self.conv2d14 = ConvBlock2D(64, 8, stride=1, dilation=1, norm=True, relu=True)
103 |
104 | self.conv2d20 = ConvBlock2D(64, 128, stride=2, norm=True, relu=True) # 32
105 | self.conv2d21 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True)
106 | self.conv2d22 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True)
107 | self.conv2d23 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True)
108 | self.conv2d24 = ConvBlock2D(128, 16, stride=1, dilation=1, norm=True, relu=True)
109 |
110 | def forward(self, feature, action):
111 | # feature: [B, 8, 128, 128, 48]
112 | # action: [B, 8, 128, 128]
113 |
114 | action_embedding0 = torch.unsqueeze(action, -1).expand([-1, -1, -1, -1, 48])
115 | feature0 = self.conv3d00(torch.cat([feature, action_embedding0], dim=1))
116 |
117 | action1 = self.conv2d10(action)
118 | action1 = self.conv2d11(action1)
119 | action1 = self.conv2d12(action1)
120 | action1 = self.conv2d13(action1)
121 |
122 | action_embedding1 = self.conv2d14(action1)
123 | action_embedding1 = torch.unsqueeze(action_embedding1, -1).expand([-1, -1, -1, -1, 24])
124 | feature1 = self.conv3d10(torch.cat([feature0, action_embedding1], dim=1))
125 |
126 | action2 = self.conv2d20(action1)
127 | action2 = self.conv2d21(action2)
128 | action2 = self.conv2d22(action2)
129 | action2 = self.conv2d23(action2)
130 |
131 | action_embedding2 = self.conv2d24(action2)
132 | action_embedding2 = torch.unsqueeze(action_embedding2, -1).expand([-1, -1, -1, -1, 12])
133 | feature2 = self.conv3d20(torch.cat([feature1, action_embedding2], dim=1))
134 |
135 | feature3 = self.conv3d30(feature2)
136 | feature4 = self.conv3d40(feature3 + feature1)
137 | feature5 = self.conv3d50(feature4 + feature0)
138 |
139 | motion_pred = self.conv3d60(feature5)
140 |
141 | return motion_pred
142 |
143 | class MaskDecoder(nn.Module):
144 | def __init__(self, K):
145 | super().__init__()
146 | self.decoder = nn.Conv3d(8, K, kernel_size=1)
147 |
148 | def forward(self, x):
149 | logit = self.decoder(x)
150 | mask = torch.softmax(logit, dim=1)
151 | return logit, mask
152 |
153 |
154 | class TransformDecoder(nn.Module):
155 | def __init__(self, transform_type, object_num):
156 | super().__init__()
157 | num_params_dict = {
158 | 'affine': 12,
159 | 'se3euler': 6,
160 | 'se3aa': 6,
161 | 'se3spquat': 6,
162 | 'se3quat': 7
163 | }
164 | self.num_params = num_params_dict[transform_type]
165 | self.object_num = object_num
166 |
167 | self.conv3d00 = ConvBlock3D(8 + 8, 8, stride=2, dilation=1, norm=True, relu=True) # 64
168 |
169 | self.conv3d10 = ConvBlock3D(8 + 8, 16, stride=2, dilation=1, norm=True, relu=True) # 32
170 |
171 | self.conv3d20 = ConvBlock3D(16 + 16, 32, stride=2, dilation=1, norm=True, relu=True) # 16
172 | self.conv3d21 = ConvBlock3D(32, 32, stride=1, dilation=1, norm=True, relu=True)
173 | self.conv3d22 = ConvBlock3D(32, 32, stride=1, dilation=1, norm=True, relu=True)
174 | self.conv3d23 = ConvBlock3D(32, 64, stride=1, dilation=1, norm=True, relu=True)
175 |
176 | self.conv3d30 = ConvBlock3D(64, 128, stride=2, dilation=1, norm=True, relu=True) # 8
177 |
178 | self.conv3d40 = ConvBlock3D(128, 128, stride=2, dilation=1, norm=True, relu=True) # 4
179 |
180 | self.conv3d50 = nn.Conv3d(128, 128, kernel_size=(4, 4, 2))
181 |
182 |
183 | self.conv2d10 = ConvBlock2D(8, 64, stride=2, norm=True, relu=True) # 64
184 | self.conv2d11 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True)
185 | self.conv2d12 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True)
186 | self.conv2d13 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True)
187 | self.conv2d14 = ConvBlock2D(64, 8, stride=1, dilation=1, norm=True, relu=True)
188 |
189 | self.conv2d20 = ConvBlock2D(64, 128, stride=2, norm=True, relu=True) # 32
190 | self.conv2d21 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True)
191 | self.conv2d22 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True)
192 | self.conv2d23 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True)
193 | self.conv2d24 = ConvBlock2D(128, 16, stride=1, dilation=1, norm=True, relu=True)
194 |
195 | self.mlp = MLP(
196 | input_dim=128,
197 | output_dim=self.num_params * self.object_num,
198 | hidden_sizes=[512, 512, 512, 512],
199 | hidden_nonlinearity=F.leaky_relu
200 | )
201 |
202 |
203 | def forward(self, feature, action):
204 | # feature: [B, 8, 128, 128, 48]
205 | # action: [B, 8, 128, 128]
206 |
207 | action_embedding0 = torch.unsqueeze(action, -1).expand([-1, -1, -1, -1, 48])
208 | feature0 = self.conv3d00(torch.cat([feature, action_embedding0], dim=1))
209 |
210 | action1 = self.conv2d10(action)
211 | action1 = self.conv2d11(action1)
212 | action1 = self.conv2d12(action1)
213 | action1 = self.conv2d13(action1)
214 |
215 | action_embedding1 = self.conv2d14(action1)
216 | action_embedding1 = torch.unsqueeze(action_embedding1, -1).expand([-1, -1, -1, -1, 24])
217 | feature1 = self.conv3d10(torch.cat([feature0, action_embedding1], dim=1))
218 |
219 | action2 = self.conv2d20(action1)
220 | action2 = self.conv2d21(action2)
221 | action2 = self.conv2d22(action2)
222 | action2 = self.conv2d23(action2)
223 |
224 | action_embedding2 = self.conv2d24(action2)
225 | action_embedding2 = torch.unsqueeze(action_embedding2, -1).expand([-1, -1, -1, -1, 12])
226 | feature2 = self.conv3d20(torch.cat([feature1, action_embedding2], dim=1))
227 | feature2 = self.conv3d21(feature2)
228 | feature2 = self.conv3d22(feature2)
229 | feature2 = self.conv3d23(feature2)
230 |
231 | feature3 = self.conv3d30(feature2)
232 | feature4 = self.conv3d40(feature3)
233 | feature5 = self.conv3d50(feature4)
234 |
235 | params = self.mlp(feature5.view([-1, 128]))
236 | params = params.view([-1, self.object_num, self.num_params])
237 |
238 | return params
239 |
240 |
241 | class ModelDSR(nn.Module):
242 | def __init__(self, object_num=5, transform_type='se3euler', motion_type='se3'):
243 | # transform_type options: None, 'affine', 'se3euler', 'se3aa', 'se3quat', 'se3spquat'
244 | # motion_type options: 'se3', 'conv'
245 | # input volume size: [128, 128, 48]
246 |
247 | super().__init__()
248 | self.transform_type = transform_type
249 | self.K = object_num
250 | self.motion_type = motion_type
251 |
252 | # modules
253 | self.forward_warp = Forward_Warp_Cupy.apply
254 | self.volume_encoder = VolumeEncoder()
255 | self.feature_decoder = FeatureDecoder()
256 | if self.motion_type == 'se3':
257 | self.mask_decoder = MaskDecoder(self.K)
258 | self.transform_decoder = TransformDecoder(
259 | transform_type=self.transform_type,
260 | object_num=self.K - 1
261 | )
262 | self.se3 = SE3(self.transform_type)
263 | elif self.motion_type == 'conv':
264 | self.motion_decoder = MotionDecoder()
265 | else:
266 |             raise ValueError('Unsupported motion_type: %s' % self.motion_type)
267 |
268 | # initialization
269 | for m in self.named_modules():
270 | if isinstance(m[1], nn.Conv3d) or isinstance(m[1], nn.Conv2d):
271 | nn.init.kaiming_normal_(m[1].weight.data)
272 | elif isinstance(m[1], nn.BatchNorm3d) or isinstance(m[1], nn.BatchNorm2d):
273 | m[1].weight.data.fill_(1)
274 | m[1].bias.data.zero_()
275 |
276 | # const value
277 | self.grids = torch.stack(torch.meshgrid(
278 | torch.linspace(0, 127, 128),
279 | torch.linspace(0, 127, 128),
280 | torch.linspace(0, 47, 48)
281 | ))
282 | self.coord_feature = self.grids / torch.tensor([128, 128, 48]).view([3, 1, 1, 1])
283 | self.grids_flat = self.grids.view(1, 1, 3, 128 * 128 * 48)
284 | self.zero_vec = torch.zeros([1, 1, 3], dtype=torch.float)
285 | self.eye_mat = torch.eye(3, dtype=torch.float)
286 |
287 | def forward(self, input_volume, last_s=None, input_action=None, input_motion=None, next_mask=False, no_warp=False):
288 | B, _, S1, S2, S3 = input_volume.size()
289 | K = self.K
290 | device = input_volume.device
291 | output = {}
292 |
293 | input = torch.cat((input_volume, self.coord_feature.expand(B, -1, -1, -1, -1).to(device)), dim=1)
294 | input = torch.cat((input, last_s), dim=1) # aggregate history
295 |
296 | volume_embedding, cache = self.volume_encoder(input)
297 | mask_feature = self.feature_decoder(volume_embedding, cache)
298 |
299 | if self.motion_type == 'conv':
300 | motion = self.motion_decoder(mask_feature, input_action)
301 | output['motion'] = motion
302 |
303 | return output
304 |
305 |
306 | assert(self.motion_type == 'se3')
307 | logit, mask = self.mask_decoder(mask_feature)
308 | output['init_logit'] = logit
309 | transform_param = self.transform_decoder(mask_feature, input_action)
310 |
311 | # trans, pivot: [B, K-1, 3]
312 | # rot_matrix: [B, K-1, 3, 3]
313 | trans_vec, rot_mat = self.se3(transform_param)
314 | mask_object = torch.narrow(mask, 1, 0, K - 1)
315 | sum_mask = torch.sum(mask_object, dim=(2, 3, 4))
316 | heatmap = torch.unsqueeze(mask_object, dim=2) * self.grids.to(device)
317 | pivot_vec = torch.sum(heatmap, dim=(3, 4, 5)) / torch.unsqueeze(sum_mask, dim=2)
318 |
319 | # [Important] The last one is the background!
320 | trans_vec = torch.cat([trans_vec, self.zero_vec.expand(B, -1, -1).to(device)], dim=1).unsqueeze(-1)
321 | rot_mat = torch.cat([rot_mat, self.eye_mat.expand(B, 1, -1, -1).to(device)], dim=1)
322 | pivot_vec = torch.cat([pivot_vec, self.zero_vec.expand(B, -1, -1).to(device)], dim=1).unsqueeze(-1)
323 |
324 | grids_flat = self.grids_flat.to(device)
325 | grids_after_flat = rot_mat @ (grids_flat - pivot_vec) + pivot_vec + trans_vec
326 | motion = (grids_after_flat - grids_flat).view([B, K, 3, S1, S2, S3])
327 |
328 | motion = torch.sum(motion * torch.unsqueeze(mask, 2), 1)
329 |
330 | output['motion'] = motion
331 |
332 | if no_warp:
333 | output['s'] = mask_feature
334 | elif input_motion is not None:
335 | mask_feature_warp = self.forward_warp(
336 | mask_feature,
337 | input_motion,
338 | torch.sum(mask[:, :-1, ], dim=1)
339 | )
340 | output['s'] = mask_feature_warp
341 | else:
342 | mask_feature_warp = self.forward_warp(
343 | mask_feature,
344 | motion,
345 | torch.sum(mask[:, :-1, ], dim=1)
346 | )
347 | output['s'] = mask_feature_warp
348 |
349 | if next_mask:
350 | mask_warp = self.forward_warp(
351 | mask,
352 | motion,
353 | torch.sum(mask[:, :-1, ], dim=1)
354 | )
355 | output['next_mask'] = mask_warp
356 |
357 | return output
358 |
359 |
360 |
361 | def get_init_repr(self, batch_size):
362 | return torch.zeros([batch_size, 8, 128, 128, 48], dtype=torch.float)
363 |
364 |
365 | if __name__=='__main__':
366 | torch.cuda.set_device(4)
367 | model = ModelDSR(
368 | object_num=2,
369 | transform_type='se3euler',
370 |         motion_type='se3'
372 | ).cuda()
373 |
374 | input_volume = torch.rand((4, 1, 128, 128, 48)).cuda()
375 | input_action = torch.rand((4, 8, 128, 128)).cuda()
376 | last_s = model.get_init_repr(4).cuda()
377 |
378 | output = model(input_volume=input_volume, last_s=last_s, input_action=input_action, next_mask=True)
379 |
380 | for k in output.keys():
381 | print(k, output[k].size())
382 |
383 |
384 |
--------------------------------------------------------------------------------
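The SE3 branch in `ModelDSR.forward` moves every voxel coordinate by a per-object rigid transform about that object's mask centroid (`grids_after = R @ (g - pivot) + pivot + t`) and reads the motion off as the displacement. A small standalone sketch of that computation, with illustrative shapes and identity transforms, assuming nothing beyond PyTorch:

```python
import torch

B, K, S1, S2, S3 = 1, 3, 8, 8, 4             # batch, objects (incl. background), grid size
grids = torch.stack(torch.meshgrid(
    torch.arange(S1, dtype=torch.float),
    torch.arange(S2, dtype=torch.float),
    torch.arange(S3, dtype=torch.float)))     # [3, S1, S2, S3]
grids_flat = grids.view(1, 1, 3, -1)          # [1, 1, 3, N]

rot_mat = torch.eye(3).expand(B, K, 3, 3)     # [B, K, 3, 3], identity rotations
trans_vec = torch.zeros(B, K, 3, 1)           # [B, K, 3, 1], zero translations
pivot_vec = torch.zeros(B, K, 3, 1)           # [B, K, 3, 1], pivots at the origin

# Rotate about the pivot, then translate; the motion is the displacement field.
grids_after = rot_mat @ (grids_flat - pivot_vec) + pivot_vec + trans_vec
motion = (grids_after - grids_flat).view(B, K, 3, S1, S2, S3)
print(motion.abs().max())                     # tensor(0.) for the identity transform
```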
/model_utils.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | __all__ = ['ConvBlock2D', 'ConvBlock3D', 'ResBlock2D', 'ResBlock3D', 'MLP']
6 |
7 |
8 | class ConvBlock2D(nn.Module):
9 | def __init__(self, inplanes, planes, stride=1, dilation=1, norm=False, relu=False, pool=False, upsm=False):
10 | super().__init__()
11 | self.conv = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=not norm)
12 | self.norm = nn.BatchNorm2d(planes) if norm else None
13 | self.relu = nn.LeakyReLU(inplace=True) if relu else None
14 | self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if pool else None
15 | self.upsm = upsm
16 |
17 | def forward(self, x):
18 | out = self.conv(x)
19 |
20 | out = out if self.norm is None else self.norm(out)
21 | out = out if self.relu is None else self.relu(out)
22 | out = out if self.pool is None else self.pool(out)
23 | out = out if not self.upsm else F.interpolate(out, scale_factor=2, mode='bilinear', align_corners=True)
24 |
25 | return out
26 |
27 |
28 | class ConvBlock3D(nn.Module):
29 | def __init__(self, inplanes, planes, stride=1, dilation=1, norm=False, relu=False, pool=False, upsm=False):
30 | super().__init__()
31 |
32 | self.conv = nn.Conv3d(inplanes, planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=not norm)
33 | self.norm = nn.BatchNorm3d(planes) if norm else None
34 | self.relu = nn.LeakyReLU(inplace=True) if relu else None
35 | self.pool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) if pool else None
36 | self.upsm = upsm
37 |
38 | def forward(self, x):
39 | out = self.conv(x)
40 |
41 | out = out if self.norm is None else self.norm(out)
42 | out = out if self.relu is None else self.relu(out)
43 | out = out if self.pool is None else self.pool(out)
44 | out = out if not self.upsm else F.interpolate(out, scale_factor=2, mode='trilinear', align_corners=True)
45 |
46 | return out
47 |
48 |
49 | class ResBlock2D(nn.Module):
50 | def __init__(self, inplanes, planes, downsample=None, bias=False):
51 | super().__init__()
52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1, bias=bias)
53 | self.bn1 = nn.BatchNorm2d(planes)
54 | self.relu = nn.LeakyReLU(inplace=True)
55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=bias)
56 | self.bn2 = nn.BatchNorm2d(planes)
57 | self.downsample = downsample
58 |
59 | def forward(self, x):
60 | residual = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 |
69 | if self.downsample is not None:
70 | residual = self.downsample(x)
71 |
72 | out += residual
73 | out = self.relu(out)
74 |
75 | return out
76 |
77 |
78 | class ResBlock3D(nn.Module):
79 | def __init__(self, inplanes, planes, downsample=None):
80 | super().__init__()
81 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=3, stride=1, padding=1, bias=False)
82 | self.bn1 = nn.BatchNorm3d(planes)
83 | self.relu = nn.LeakyReLU(inplace=True)
84 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
85 | self.bn2 = nn.BatchNorm3d(planes)
86 | self.downsample = downsample
87 |
88 | def forward(self, x):
89 | residual = x
90 |
91 | out = self.conv1(x)
92 | out = self.bn1(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv2(out)
96 | out = self.bn2(out)
97 |
98 | if self.downsample is not None:
99 | residual = self.downsample(x)
100 |
101 | out += residual
102 | out = self.relu(out)
103 |
104 | return out
105 |
106 |
107 | class MLP(nn.Module):
108 | """
109 | MLP Model.
110 | Args:
111 | input_dim (int) : Dimension of the network input.
112 | output_dim (int): Dimension of the network output.
113 | hidden_sizes (list[int]): Output dimension of dense layer(s).
114 | For example, (32, 32) means this MLP consists of two
115 | hidden layers, each with 32 hidden units.
116 | hidden_nonlinearity (callable): Activation function for intermediate
117 | dense layer(s). It should return a torch.Tensor. Set it to
118 | None to maintain a linear activation.
119 | hidden_w_init (callable): Initializer function for the weight
120 | of intermediate dense layer(s). The function should return a
121 | torch.Tensor.
122 | hidden_b_init (callable): Initializer function for the bias
123 | of intermediate dense layer(s). The function should return a
124 | torch.Tensor.
125 | output_nonlinearity (callable): Activation function for output dense
126 | layer. It should return a torch.Tensor. Set it to None to
127 | maintain a linear activation.
128 | output_w_init (callable): Initializer function for the weight
129 | of output dense layer(s). The function should return a
130 | torch.Tensor.
131 | output_b_init (callable): Initializer function for the bias
132 | of output dense layer(s). The function should return a
133 | torch.Tensor.
134 | layer_normalization (bool): Bool for using layer normalization or not.
135 | Return:
136 | The output torch.Tensor of the MLP
137 | """
138 |
139 | def __init__(self,
140 | input_dim,
141 | output_dim,
142 | hidden_sizes,
143 | hidden_nonlinearity=F.relu,
144 | hidden_w_init=nn.init.xavier_normal_,
145 | hidden_b_init=nn.init.zeros_,
146 | output_nonlinearity=None,
147 | output_w_init=nn.init.xavier_normal_,
148 | output_b_init=nn.init.zeros_,
149 | layer_normalization=False):
150 | super().__init__()
151 |
152 | self._input_dim = input_dim
153 | self._output_dim = output_dim
154 | self._hidden_nonlinearity = hidden_nonlinearity
155 | self._output_nonlinearity = output_nonlinearity
156 | self._layer_normalization = layer_normalization
157 | self._layers = nn.ModuleList()
158 |
159 | prev_size = input_dim
160 |
161 | for size in hidden_sizes:
162 | layer = nn.Linear(prev_size, size)
163 | hidden_w_init(layer.weight)
164 | hidden_b_init(layer.bias)
165 | self._layers.append(layer)
166 | prev_size = size
167 |
168 | layer = nn.Linear(prev_size, output_dim)
169 | output_w_init(layer.weight)
170 | output_b_init(layer.bias)
171 | self._layers.append(layer)
172 |
173 | def forward(self, input_val):
174 | """Forward method."""
175 | B = input_val.size(0)
176 | x = input_val.view(B, -1)
177 | for layer in self._layers[:-1]:
178 | x = layer(x)
179 | if self._hidden_nonlinearity is not None:
180 | x = self._hidden_nonlinearity(x)
181 | if self._layer_normalization:
182 | x = nn.LayerNorm(x.shape[1])(x)
183 |
184 | x = self._layers[-1](x)
185 | if self._output_nonlinearity is not None:
186 | x = self._output_nonlinearity(x)
187 |
188 | return x
--------------------------------------------------------------------------------
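A quick shape check of the building blocks above, as a minimal sketch (shapes chosen arbitrarily):

```python
import torch
import torch.nn.functional as F
from model_utils import ConvBlock3D, MLP

# ConvBlock3D with stride=2 halves each spatial dimension (kernel 3, padding 1).
block = ConvBlock3D(1, 16, stride=2, norm=True, relu=True)
x = torch.rand(2, 1, 32, 32, 16)
print(block(x).shape)                  # torch.Size([2, 16, 16, 16, 8])

# MLP flattens its input and applies the configured hidden layers.
mlp = MLP(input_dim=128, output_dim=10, hidden_sizes=[64, 64],
          hidden_nonlinearity=F.leaky_relu)
print(mlp(torch.rand(4, 128)).shape)   # torch.Size([4, 10])
```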
/object_models/README.md:
--------------------------------------------------------------------------------
1 | # Object Models
2 |
3 | ### Object meshes are used for data generation in simulation.
4 | - Download object meshes: [shapenet](https://dsr-net.cs.columbia.edu/download/object_models/shapenet.zip) and [ycb](https://dsr-net.cs.columbia.edu/download/object_models/ycb.zip).
5 | - Unzip `shapenet.zip` and `ycb.zip` in this folder.
--------------------------------------------------------------------------------
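The same download-and-unzip step can be scripted; a small equivalent sketch using only the standard library (run from inside `object_models/`):

```python
import urllib.request
import zipfile

# Fetch and extract the two mesh archives listed in the README above.
for name in ['shapenet', 'ycb']:
    url = 'https://dsr-net.cs.columbia.edu/download/object_models/%s.zip' % name
    archive = '%s.zip' % name
    urllib.request.urlretrieve(url, archive)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall('.')
```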
/pretrained_models/3dflow.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/3dflow.pth
--------------------------------------------------------------------------------
/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | # Pretrained Models
2 |
3 | ### The following pretrained models are provided:
4 | - [dsr](dsr.pth): DSR-Net introduced in the paper (without real-data finetuning).
5 | - [dsr_ft](dsr_ft.pth): DSR-Net introduced in the paper (with real-data finetuning).
6 | - [single](single.pth): Does not use any history aggregation.
7 | - [nowarp](nowarp.pth): Does not warp the representation before aggregation.
8 | - [gtwarp](gtwarp.pth): Warps the representation with ground-truth motion (i.e., a performance oracle).
9 | - [3dflow](3dflow.pth): Predicts per-voxel scene flow for the entire 3D volume.
10 |
--------------------------------------------------------------------------------
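A minimal loading sketch for these checkpoints, assuming each `.pth` file stores a plain `state_dict` compatible with `ModelDSR` from `model.py` (see `train.py`/`test.py` for the exact constructor options each model was trained with; the defaults below are illustrative):

```python
import torch
from model import ModelDSR

model = ModelDSR(object_num=5, transform_type='se3euler', motion_type='se3')
state_dict = torch.load('pretrained_models/dsr.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```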
/pretrained_models/dsr.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/dsr.pth
--------------------------------------------------------------------------------
/pretrained_models/dsr_ft.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/dsr_ft.pth
--------------------------------------------------------------------------------
/pretrained_models/gtwarp.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/gtwarp.pth
--------------------------------------------------------------------------------
/pretrained_models/nowarp.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/nowarp.pth
--------------------------------------------------------------------------------
/pretrained_models/single.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/single.pth
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.9.0
2 | appdirs==1.4.4
3 | cachetools==4.1.1
4 | certifi==2020.6.20
5 | chardet==3.0.4
6 | cupy==7.1.1
7 | cycler==0.10.0
8 | decorator==4.4.2
9 | dominate==2.4.0
10 | fastrlock==0.5
11 | google-auth==1.19.2
12 | google-auth-oauthlib==0.4.1
13 | grpcio==1.30.0
14 | h5py==2.10.0
15 | idna==2.10
16 | imageio==2.6.1
17 | importlib-metadata==1.7.0
18 | kiwisolver==1.2.0
19 | llvmlite==0.33.0
20 | Mako==1.1.3
21 | Markdown==3.2.2
22 | MarkupSafe==1.1.1
23 | matplotlib==3.3.0
24 | networkx==2.4
25 | numba==0.50.0
26 | numpy==1.18.1
27 | oauthlib==3.1.0
28 | opencv-python==4.2.0.32
29 | Pillow==8.1.1
31 | protobuf==3.12.2
32 | pyasn1==0.4.8
33 | pyasn1-modules==0.2.8
34 | pybullet==2.6.4
35 | pycuda==2019.1.2
36 | pynvrtc==9.2
37 | pyparsing==2.4.7
38 | python-dateutil==2.8.1
39 | pytools==2020.3.1
40 | PyWavelets==1.1.1
41 | requests==2.24.0
42 | requests-oauthlib==1.3.0
43 | rsa==4.6
44 | scikit-image==0.16.2
45 | scipy==1.5.1
46 | six==1.15.0
47 | tensorboard==2.2.2
48 | tensorboard-plugin-wit==1.7.0
49 | torch==1.4.0
50 | tqdm==4.48.0
51 | urllib3==1.25.9
52 | Werkzeug==1.0.1
53 | zipp==3.1.0
54 |
--------------------------------------------------------------------------------
/se3/__pycache__/se3_module.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3_module.cpython-36.pyc
--------------------------------------------------------------------------------
/se3/__pycache__/se3_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/se3/__pycache__/se3aa.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3aa.cpython-36.pyc
--------------------------------------------------------------------------------
/se3/__pycache__/se3euler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3euler.cpython-36.pyc
--------------------------------------------------------------------------------
/se3/__pycache__/se3quat.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3quat.cpython-36.pyc
--------------------------------------------------------------------------------
/se3/__pycache__/se3spquat.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3spquat.cpython-36.pyc
--------------------------------------------------------------------------------
/se3/se3_module.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules import Module
2 | from se3.se3spquat import Se3spquat
3 | from se3.se3quat import Se3quat
4 | from se3.se3euler import Se3euler
5 | from se3.se3aa import Se3aa
6 |
7 | class SE3(Module):
8 | def __init__(self, transform_type='affine', has_pivot=False):
9 | super().__init__()
10 | rot_param_num_dict = {
11 | 'affine': 9,
12 | 'se3euler': 3,
13 | 'se3aa': 3,
14 | 'se3spquat': 3,
15 | 'se3quat': 4
16 | }
17 | self.transform_type = transform_type
18 | self.rot_param_num = rot_param_num_dict[transform_type]
19 | self.has_pivot = has_pivot
20 | self.num_param = rot_param_num_dict[transform_type] + 3
21 | if self.has_pivot:
22 | self.num_param += 3
23 |
24 | def forward(self, input):
25 | B, K, L = input.size()
26 | if L != self.num_param:
27 | raise ValueError('Dimension Error!')
28 |
29 | trans_vec = input.narrow(2, 0, 3)
30 | rot_params = input.narrow(2, 3, self.rot_param_num)
31 | if self.has_pivot:
32 | pivot_vec = input.narrow(2, 3 + self.rot_param_num, 3)
33 |
34 |
35 | if self.transform_type == 'affine':
36 | rot_mat = rot_params.view(B, K, 3, 3)
37 | elif self.transform_type == 'se3euler':
38 | rot_mat = Se3euler.apply(rot_params)
39 | elif self.transform_type == 'se3aa':
40 | rot_mat = Se3aa.apply(rot_params)
41 | elif self.transform_type == 'se3spquat':
42 | rot_mat = Se3spquat.apply(rot_params)
43 | elif self.transform_type == 'se3quat':
44 | rot_mat = Se3quat.apply(rot_params)
45 |
46 | if self.has_pivot:
47 | return trans_vec, rot_mat, pivot_vec
48 | else:
49 | return trans_vec, rot_mat
--------------------------------------------------------------------------------
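For example, `SE3` maps a `[B, K, 3 + rot_params]` tensor to a per-object translation and rotation matrix; a minimal sketch with the quaternion parameterization (3 translation + 4 quaternion numbers per object):

```python
import torch
from se3.se3_module import SE3

se3 = SE3(transform_type='se3quat')
params = torch.rand(2, 5, 7)            # [B=2, K=5, 3 + 4]
trans_vec, rot_mat = se3(params)
print(trans_vec.shape)                  # torch.Size([2, 5, 3])
print(rot_mat.shape)                    # torch.Size([2, 5, 3, 3])
```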
/se3/se3_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | # Rotation about the X-axis by theta
5 | # From Barfoot's book: http://asrl.utias.utoronto.ca/~tdb/bib/barfoot_ser15.pdf (6.7)
6 | def create_rotx(theta):
7 | N = theta.size(0)
8 | rot = torch.eye(3).type_as(theta).view(1, 3, 3).repeat(N, 1, 1)
9 | rot[:, 1, 1] = torch.cos(theta)
10 | rot[:, 2, 2] = rot[:, 1, 1]
11 | rot[:, 1, 2] = torch.sin(theta)
12 | rot[:, 2, 1] = -rot[:, 1, 2]
13 | return rot
14 |
15 |
16 | # Rotation about the Y-axis by theta
17 | # From Barfoot's book: http://asrl.utias.utoronto.ca/~tdb/bib/barfoot_ser15.pdf (6.6)
18 | def create_roty(theta):
19 | N = theta.size(0)
20 | rot = torch.eye(3).type_as(theta).view(1, 3, 3).repeat(N, 1, 1)
21 | rot[:, 0, 0] = torch.cos(theta)
22 | rot[:, 2, 2] = rot[:, 0, 0]
23 | rot[:, 2, 0] = torch.sin(theta)
24 | rot[:, 0, 2] = -rot[:, 2, 0]
25 | return rot
26 |
27 |
28 | # Rotation about the Z-axis by theta
29 | # From Barfoot's book: http://asrl.utias.utoronto.ca/~tdb/bib/barfoot_ser15.pdf (6.5)
30 | def create_rotz(theta):
31 | N = theta.size(0)
32 | rot = torch.eye(3).type_as(theta).view(1, 3, 3).repeat(N, 1, 1)
33 | rot[:, 0, 0] = torch.cos(theta)
34 | rot[:, 1, 1] = rot[:, 0, 0]
35 | rot[:, 0, 1] = torch.sin(theta)
36 | rot[:, 1, 0] = -rot[:, 0, 1]
37 | return rot
38 |
39 |
40 | # Create a skew-symmetric matrix "S" of size [B x 3 x 3] (passed in) given a [B x 3] vector
41 | def create_skew_symmetric_matrix(vector):
42 | # Create the skew symmetric matrix:
43 | # [0 -z y; z 0 -x; -y x 0]
44 | N = vector.size(0)
45 | vec = vector.contiguous().view(N, 3)
46 | output = vec.new().resize_(N, 3, 3).fill_(0)
47 | output[:, 0, 1] = -vec[:, 2]
48 | output[:, 1, 0] = vec[:, 2]
49 | output[:, 0, 2] = vec[:, 1]
50 | output[:, 2, 0] = -vec[:, 1]
51 | output[:, 1, 2] = -vec[:, 0]
52 | output[:, 2, 1] = vec[:, 0]
53 | return output
54 |
55 |
56 | # Compute the rotation matrix R from a set of unit-quaternions (N x 4):
57 | # From: http://www.tech.plymouth.ac.uk/sme/springerusv/2011/publications_files/Terzakis%20et%20al%202012,%20A%20Recipe%20on%20the%20Parameterization%20of%20Rotation%20Matrices...MIDAS.SME.2012.TR.004.pdf (Eqn 9)
58 | def create_rot_from_unitquat(unitquat):
59 | # Init memory
60 | N = unitquat.size(0)
61 | rot = unitquat.new_zeros([N, 3, 3])
62 |
63 | # Get quaternion elements. Quat = [qx,qy,qz,qw] with the scalar at the rear
64 | x, y, z, w = unitquat[:, 0], unitquat[:, 1], unitquat[:, 2], unitquat[:, 3]
65 | x2, y2, z2, w2 = x * x, y * y, z * z, w * w
66 |
67 | # Row 1
68 | rot[:, 0, 0] = w2 + x2 - y2 - z2 # rot(0,0) = w^2 + x^2 - y^2 - z^2
69 | rot[:, 0, 1] = 2 * (x * y - w * z) # rot(0,1) = 2*x*y - 2*w*z
70 | rot[:, 0, 2] = 2 * (x * z + w * y) # rot(0,2) = 2*x*z + 2*w*y
71 |
72 | # Row 2
73 | rot[:, 1, 0] = 2 * (x * y + w * z) # rot(1,0) = 2*x*y + 2*w*z
74 | rot[:, 1, 1] = w2 - x2 + y2 - z2 # rot(1,1) = w^2 - x^2 + y^2 - z^2
75 | rot[:, 1, 2] = 2 * (y * z - w * x) # rot(1,2) = 2*y*z - 2*w*x
76 |
77 | # Row 3
78 | rot[:, 2, 0] = 2 * (x * z - w * y) # rot(2,0) = 2*x*z - 2*w*y
79 | rot[:, 2, 1] = 2 * (y * z + w * x) # rot(2,1) = 2*y*z + 2*w*x
80 | rot[:, 2, 2] = w2 - x2 - y2 + z2 # rot(2,2) = w^2 - x^2 - y^2 + z^2
81 |
82 | return rot
83 |
84 |
85 | # Compute the derivatives of the rotation matrix w.r.t the unit quaternion
86 | # From: http://www.tech.plymouth.ac.uk/sme/springerusv/2011/publications_files/Terzakis%20et%20al%202012,%20A%20Recipe%20on%20the%20Parameterization%20of%20Rotation%20Matrices...MIDAS.SME.2012.TR.004.pdf (Eqn 33-36)
87 | def compute_grad_rot_wrt_unitquat(unitquat):
88 | # Compute dR/dq' (9x4 matrix)
89 | N = unitquat.size(0)
90 | x, y, z, w = unitquat.narrow(1, 0, 1), unitquat.narrow(1, 1, 1), unitquat.narrow(1, 2, 1), unitquat.narrow(1, 3, 1)
91 | dRdqh_w = 2 * torch.cat([w, -z, y, z, w, -x, -y, x, w], 1).view(N, 9, 1) # Eqn 33, rows first
92 | dRdqh_x = 2 * torch.cat([x, y, z, y, -x, -w, z, w, -x], 1).view(N, 9, 1) # Eqn 34, rows first
93 | dRdqh_y = 2 * torch.cat([-y, x, w, x, y, z, -w, z, -y], 1).view(N, 9, 1) # Eqn 35, rows first
94 | dRdqh_z = 2 * torch.cat([-z, -w, x, w, -z, y, x, y, z], 1).view(N, 9, 1) # Eqn 36, rows first
95 | dRdqh = torch.cat([dRdqh_x, dRdqh_y, dRdqh_z, dRdqh_w], 2) # N x 9 x 4
96 |
97 | return dRdqh
98 |
99 |
100 | # Compute the derivatives of a unit quaternion w.r.t a quaternion
101 | def compute_grad_unitquat_wrt_quat(unitquat, quat):
102 | # Compute the quaternion norms
103 | N = quat.size(0)
104 | unitquat_v = unitquat.view(-1, 4, 1)
105 | norm2 = (quat * quat).sum(1) # Norm-squared
106 | norm = torch.sqrt(norm2) # Length of the quaternion
107 |
108 | # Compute gradient dq'/dq
109 | # TODO: No check for normalization issues currently
110 | I = torch.eye(4).view(1, 4, 4).expand(N, 4, 4).type_as(quat)
111 | qQ = torch.bmm(unitquat_v, unitquat_v.transpose(1, 2)) # q'*q'^T
112 | dqhdq = (I - qQ) / (norm.view(N, 1, 1).expand_as(I))
113 |
114 | return dqhdq
115 |
116 |
117 | # Compute the derivatives of a unit quaternion w.r.t a SP quaternion
118 | # From: http://www.tech.plymouth.ac.uk/sme/springerusv/2011/publications_files/Terzakis%20et%20al%202012,%20A%20Recipe%20on%20the%20Parameterization%20of%20Rotation%20Matrices...MIDAS.SME.2012.TR.004.pdf (Eqn 42-45)
119 | def compute_grad_unitquat_wrt_spquat(spquat):
120 | # Compute scalars
121 | N = spquat.size(0)
122 | x, y, z = spquat.narrow(1, 0, 1), spquat.narrow(1, 1, 1), spquat.narrow(1, 2, 1)
123 | x2, y2, z2 = x * x, y * y, z * z
124 | s = 1 + x2 + y2 + z2 # 1 + x^2 + y^2 + z^2 = 1 + alpha^2
125 | s2 = (s * s).expand(N, 4) # (1 + alpha^2)^2
126 |
127 | # Compute gradient dq'/dspq
128 | dqhdspq_x = (torch.cat([2 * s - 4 * x2, -4 * x * y, -4 * x * z, -4 * x], 1) / s2).view(N, 4, 1)
129 | dqhdspq_y = (torch.cat([-4 * x * y, 2 * s - 4 * y2, -4 * y * z, -4 * y], 1) / s2).view(N, 4, 1)
130 | dqhdspq_z = (torch.cat([-4 * x * z, -4 * y * z, 2 * s - 4 * z2, -4 * z], 1) / s2).view(N, 4, 1)
131 | dqhdspq = torch.cat([dqhdspq_x, dqhdspq_y, dqhdspq_z], 2)
132 |
133 | return dqhdspq
134 |
135 |
136 | # Compute Unit Quaternion from SP-Quaternion
137 | def create_unitquat_from_spquat(spquat):
138 | N = spquat.size(0)
139 | unitquat = spquat.new_zeros([N, 4])
140 | x, y, z = spquat[:, 0], spquat[:, 1], spquat[:, 2]
141 | alpha2 = x * x + y * y + z * z # x^2 + y^2 + z^2
142 | unitquat[:, 0] = (2 * x) / (1 + alpha2) # qx
143 | unitquat[:, 1] = (2 * y) / (1 + alpha2) # qy
144 | unitquat[:, 2] = (2 * z) / (1 + alpha2) # qz
145 | unitquat[:, 3] = (1 - alpha2) / (1 + alpha2) # qw
146 |
147 | return unitquat
148 |
--------------------------------------------------------------------------------
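As a quick check of the quaternion convention used above (`[qx, qy, qz, qw]`, scalar last), the identity quaternion maps to the identity rotation and the zero SP-quaternion maps back to the identity quaternion; a minimal sketch:

```python
import torch
from se3.se3_utils import create_rot_from_unitquat, create_unitquat_from_spquat

q = torch.tensor([[0., 0., 0., 1.]])                      # identity quaternion
print(create_rot_from_unitquat(q))                        # 3x3 identity matrix
print(create_unitquat_from_spquat(torch.zeros(1, 3)))     # [[0., 0., 0., 1.]]
```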
/se3/se3aa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | import torch.nn.functional as F
4 | from se3.se3_utils import create_skew_symmetric_matrix
5 |
6 |
7 | class Se3aa(Function):
8 | @staticmethod
9 | def forward(ctx, input):
10 | batch_size, num_se3, num_params = input.size()
11 | N = batch_size * num_se3
12 | eps = 1e-12
13 |
14 | rot_params = input.view(batch_size * num_se3, -1)
15 |
16 | # Get the un-normalized axis and angle
17 | axis = rot_params.view(N, 3, 1) # Un-normalized axis
18 | angle2 = (axis * axis).sum(1).view(N, 1, 1) # Norm of vector (squared angle)
19 | angle = torch.sqrt(angle2) # Angle
20 |
21 | # Compute skew-symmetric matrix "K" from the axis of rotation
22 | K = create_skew_symmetric_matrix(axis)
23 | K2 = torch.bmm(K, K) # K * K
24 |
25 | # Compute sines
26 | S = torch.sin(angle) / angle
27 | S.masked_fill_(angle2.lt(eps), 1) # sin(0)/0 ~= 1
28 |
29 | # Compute cosines
30 | C = (1 - torch.cos(angle)) / angle2
31 | C.masked_fill_(angle2.lt(eps), 0) # (1 - cos(0))/0^2 ~= 0
32 |
33 | # Compute the rotation matrix: R = I + (sin(theta)/theta)*K + ((1-cos(theta))/theta^2) * K^2
34 | rot = torch.eye(3).view(1, 3, 3).repeat(N, 1, 1).type_as(rot_params) # R = I
35 | rot += K * S.expand(N, 3, 3) # R = I + (sin(theta)/theta)*K
36 | rot += K2 * C.expand(N, 3, 3) # R = I + (sin(theta)/theta)*K + ((1-cos(theta))/theta^2)*K^2
37 |
38 | ctx.save_for_backward(input, rot)
39 |
40 | return rot.view(batch_size, num_se3, 3, 3)
41 |
42 | @staticmethod
43 | def backward(ctx, grad_output):
44 | input, rot = ctx.saved_tensors
45 | batch_size, num_se3, num_params = input.size()
46 | N = batch_size * num_se3
47 | eps = 1e-12
48 |         grad_output = grad_output.contiguous().view(N, 3, 3)
49 |
50 | rot_params = input.view(batch_size * num_se3, -1)
51 |
52 | axis = rot_params.view(N, 3, 1) # Un-normalized axis
53 | angle2 = (axis * axis).sum(1) # (Bk) x 1 x 1 => Norm of the vector (squared angle)
54 | nSmall = angle2.lt(eps).sum() # Num angles less than threshold
55 |
56 | # Compute: v x (Id - R) for all the columns of (Id-R)
57 | I = torch.eye(3).type_as(input).repeat(N, 1, 1).add(-1, rot) # (Bk) x 3 x 3 => Id - R
58 | vI = torch.cross(axis.expand_as(I), I, 1) # (Bk) x 3 x 3 => v x (Id - R)
59 |
60 | # Compute [v * v' + v x (Id - R)] / ||v||^2
61 | vV = torch.bmm(axis, axis.transpose(1, 2)) # (Bk) x 3 x 3 => v * v'
62 | vV = (vV + vI) / (angle2.view(N, 1, 1).expand_as(vV)) # (Bk) x 3 x 3 => [v * v' + v x (Id - R)] / ||v||^2
63 |
64 | # Iterate over the 3-axis angle parameters to compute their gradients
65 | # ([v * v' + v x (Id - R)] / ||v||^2 _ k) x (R) .* gradOutput where "x" is the cross product
66 | grad_input_list = []
67 | for k in range(3):
68 | # Create skew symmetric matrix
69 | skewsym = create_skew_symmetric_matrix(vV.narrow(2, k, 1))
70 |
71 | # For those AAs with angle^2 < threshold, gradient is different
72 | # We assume angle = 0 for these AAs and update the skew-symmetric matrix to be one w.r.t identity
73 | if (nSmall > 0):
74 | vec = torch.zeros(1, 3).type_as(skewsym)
75 | vec[0][k] = 1 # Unit vector
76 | idskewsym = create_skew_symmetric_matrix(vec)
77 | for i in range(N):
78 | if (angle2[i].squeeze()[0] < eps):
79 | skewsym[i].copy_(idskewsym.squeeze()) # Use the new skew sym matrix (around identity)
80 |
81 | # Compute the gradients now
82 | grad_input_list.append(torch.sum(torch.bmm(skewsym, rot) * grad_output, dim=(1, 2))) # [(Bk) x 1 x 1] => (vV x R) .* gradOutput
83 | grad_input = torch.stack(grad_input_list, 1).view(batch_size, num_se3, 3)
84 |
85 | return grad_input
86 |
--------------------------------------------------------------------------------
/se3/se3euler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from se3.se3_utils import create_rotx, create_roty, create_rotz
4 | from se3.se3_utils import create_skew_symmetric_matrix
5 |
6 |
7 | class Se3euler(Function):
8 | @staticmethod
9 | def forward(ctx, input):
10 | batch_size, num_se3, num_params = input.size()
11 |
12 | rot_params = input.view(batch_size * num_se3, -1)
13 |
14 | # Create rotations about X,Y,Z axes
15 | # R = Rz(theta3) * Ry(theta2) * Rx(theta1)
16 | # Last 3 parameters are [theta1, theta2 ,theta3]
17 | rotx = create_rotx(rot_params[:, 0]) # Rx(theta1)
18 | roty = create_roty(rot_params[:, 1]) # Ry(theta2)
19 | rotz = create_rotz(rot_params[:, 2]) # Rz(theta3)
20 |
21 | # Compute Rz(theta3) * Ry(theta2)
22 | rotzy = torch.bmm(rotz, roty) # Rzy = R32
23 |
24 | # Compute rotation matrix R3*R2*R1 = R32*R1
25 | # R = Rz(t3) * Ry(t2) * Rx(t1)
26 | output = torch.bmm(rotzy, rotx) # R = Rzyx
27 |
28 | ctx.save_for_backward(input, output, rotx, roty, rotz, rotzy)
29 |
30 | return output.view(batch_size, num_se3, 3, 3)
31 |
32 | @staticmethod
33 | def backward(ctx, grad_output):
34 | input, output, rotx, roty, rotz, rotzy = ctx.saved_tensors
35 | batch_size, num_se3, num_params = input.size()
36 | grad_output = grad_output.contiguous().view(batch_size * num_se3, 3, 3)
37 |
38 | # Gradient w.r.t Euler angles from Barfoot's book (http://asrl.utias.utoronto.ca/~tdb/bib/barfoot_ser15.pdf)
39 | grad_input_list = []
40 | for k in range(3):
41 | gradr = grad_output[:, k] # Gradient w.r.t angle (k)
42 | vec = torch.zeros(1, 3).type_as(gradr)
43 | vec[0][k] = 1 # Unit vector
44 | skewsym = create_skew_symmetric_matrix(vec).view(1, 3, 3).expand_as(output) # Skew symmetric matrix of unit vector
45 | if (k == 0):
46 | Rv = torch.bmm(torch.bmm(rotzy, skewsym), rotx) # Eqn 6.61c
47 | elif (k == 1):
48 | Rv = torch.bmm(torch.bmm(rotz, skewsym), torch.bmm(roty, rotx)) # Eqn 6.61b
49 | else:
50 | Rv = torch.bmm(skewsym, output)
51 | grad_input_list.append(torch.sum(-Rv * grad_output, dim=(1, 2)))
52 | grad_input = torch.stack(grad_input_list, 1).view(batch_size, num_se3, 3)
53 |
54 | return grad_input
55 |
--------------------------------------------------------------------------------
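Since `Se3euler` defines its own backward pass, it is easy to sanity-check: zero Euler angles must give identity rotations, and the analytic gradient can be compared against finite differences with `torch.autograd.gradcheck`; a minimal sketch:

```python
import torch
from torch.autograd import gradcheck
from se3.se3euler import Se3euler

# Zero angles -> identity rotation; output shape is [B, K, 3, 3].
angles = torch.zeros(2, 4, 3)
print(Se3euler.apply(angles)[0, 0])    # 3x3 identity matrix

# Compare the hand-written backward against numerical gradients (double precision).
params = torch.randn(1, 2, 3, dtype=torch.double, requires_grad=True)
gradcheck(Se3euler.apply, (params,), eps=1e-6, atol=1e-4)
```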
/se3/se3quat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | import torch.nn.functional as F
4 | from se3.se3_utils import create_rot_from_unitquat
5 | from se3.se3_utils import compute_grad_rot_wrt_unitquat
6 | from se3.se3_utils import compute_grad_unitquat_wrt_quat
7 |
8 |
9 | class Se3quat(Function):
10 | @staticmethod
11 | def forward(ctx, input):
12 | batch_size, num_se3, num_params = input.size()
13 |
14 | rot_params = input.view(batch_size * num_se3, -1)
15 |
16 | unitquat = F.normalize(rot_params)
17 |
18 | output = create_rot_from_unitquat(unitquat).view(batch_size, num_se3, 3, 3)
19 |
20 | ctx.save_for_backward(input)
21 |
22 | return output
23 |
24 | @staticmethod
25 | def backward(ctx, grad_output):
26 | input = ctx.saved_tensors[0]
27 | batch_size, num_se3, num_params = input.size()
28 |
29 | rot_params = input.view(batch_size * num_se3, -1)
30 |
31 | unitquat = F.normalize(rot_params)
32 |
33 | # Compute dR/dq'
34 | dRdqh = compute_grad_rot_wrt_unitquat(unitquat)
35 |
36 | # Compute dq'/dq = d(q/||q||)/dq = 1/||q|| (I - q'q'^T)
37 | dqhdq = compute_grad_unitquat_wrt_quat(unitquat, rot_params)
38 |
39 |
40 | # Compute dR/dq = dR/dq' * dq'/dq
41 | dRdq = torch.bmm(dRdqh, dqhdq).view(batch_size, num_se3, 3, 3, 4) # B x k x 3 x 3 x 4
42 |
43 | # Scale by grad w.r.t output and sum to get gradient w.r.t quaternion params
44 | grad_out = grad_output.contiguous().view(batch_size, num_se3, 3, 3, 1).expand_as(dRdq) # B x k x 3 x 3 x 4
45 |
46 |         grad_input = torch.sum(dRdq * grad_out, dim=(2, 3))  # B x K x 4
47 |
48 | return grad_input
49 |
--------------------------------------------------------------------------------
/se3/se3spquat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from se3.se3_utils import create_unitquat_from_spquat
4 | from se3.se3_utils import create_rot_from_unitquat
5 | from se3.se3_utils import compute_grad_rot_wrt_unitquat
6 | from se3.se3_utils import compute_grad_unitquat_wrt_spquat
7 |
8 |
9 | class Se3spquat(Function):
10 | @staticmethod
11 | def forward(ctx, input):
12 | batch_size, num_se3, num_params = input.size()
13 |
14 | rot_params = input.view(batch_size * num_se3, -1)
15 |
16 | unitquat = create_unitquat_from_spquat(rot_params)
17 |
18 | output = create_rot_from_unitquat(unitquat).view(batch_size, num_se3, 3, 3)
19 |
20 | ctx.save_for_backward(input)
21 |
22 | return output
23 |
24 | @staticmethod
25 | def backward(ctx, grad_output):
26 | input = ctx.saved_tensors[0]
27 | batch_size, num_se3, num_params = input.size()
28 |
29 | rot_params = input.view(batch_size * num_se3, -1)
30 |
31 | unitquat = create_unitquat_from_spquat(rot_params)
32 |
33 | # Compute dR/dq'
34 | dRdqh = compute_grad_rot_wrt_unitquat(unitquat)
35 |
36 |         # Compute dq'/dspq (Terzakis et al., Eqn 42-45)
37 | dqhdspq = compute_grad_unitquat_wrt_spquat(rot_params)
38 |
39 |
40 | # Compute dR/dq = dR/dq' * dq'/dq
41 | dRdq = torch.bmm(dRdqh, dqhdspq).view(batch_size, num_se3, 3, 3, 3) # B x k x 3 x 3 x 3
42 |
43 | # Scale by grad w.r.t output and sum to get gradient w.r.t quaternion params
44 | grad_out = grad_output.contiguous().view(batch_size, num_se3, 3, 3, 1).expand_as(dRdq) # B x k x 3 x 3 x 3
45 |
46 |         grad_input = torch.sum(dRdq * grad_out, dim=(2, 3))  # B x K x 3
47 |
48 | return grad_input
49 |
--------------------------------------------------------------------------------
/sim.py:
--------------------------------------------------------------------------------
1 | import math
2 | import threading
3 | import time
4 |
5 | import numpy as np
6 | import pybullet as p
7 | import pybullet_data
8 |
9 | import utils
10 |
11 |
12 | class PybulletSim:
13 | def __init__(self, gui_enabled, heightmap_pixel_size=0.004, tool='stick'):
14 |
15 | self._workspace_bounds = np.array([[0.244, 0.756],
16 | [-0.256, 0.256],
17 | [0.0, 0.192]])
18 |
19 | self._view_bounds = self._workspace_bounds
20 |
21 | # Start PyBullet simulation
22 | if gui_enabled:
23 | self._physics_client = p.connect(p.GUI) # or p.DIRECT for non-graphical version
24 | else:
25 | self._physics_client = p.connect(p.DIRECT) # non-graphical version
26 | p.setAdditionalSearchPath(pybullet_data.getDataPath())
27 | p.setGravity(0, 0, -9.8)
28 | step_sim_thread = threading.Thread(target=self.step_simulation)
29 | step_sim_thread.daemon = True
30 | step_sim_thread.start()
31 |
32 | # Add ground plane & table
33 | self._plane_id = p.loadURDF("plane.urdf")
34 | # self._table_id = p.loadURDF('assets/table/table.urdf', [0.5, 0, 0], useFixedBase=True)
35 |
36 | # Add UR5 robot
37 | self._robot_body_id = p.loadURDF("assets/ur5/ur5.urdf", [0, 0, 0], p.getQuaternionFromEuler([0, 0, 0]))
38 | # Get revolute joint indices of robot (skip fixed joints)
39 | robot_joint_info = [p.getJointInfo(self._robot_body_id, i) for i in range(p.getNumJoints(self._robot_body_id))]
40 | self._robot_joint_indices = [x[0] for x in robot_joint_info if x[2] == p.JOINT_REVOLUTE]
41 | self._joint_epsilon = 0.01 # joint position threshold in radians for blocking calls (i.e. move until joint difference < epsilon)
42 |
43 | # Move robot to home joint configuration
44 | self._robot_home_joint_config = [-3.186603833231106, -2.7046623323544323, 1.9797780717750348,
45 | -0.8458013020952369, -1.5941890970134802, -0.04501555880643846]
46 | self.move_joints(self._robot_home_joint_config, blocking=True, speed=1.0)
47 |
48 |         self.tool = tool
49 |         # Attach a stick tool to the UR5 robot
50 | self._gripper_body_id = p.loadURDF("assets/stick/stick.urdf")
51 | p.resetBasePositionAndOrientation(self._gripper_body_id, [0.5, 0.1, 0.2],
52 | p.getQuaternionFromEuler([np.pi, 0, 0]))
53 | self._robot_tool_joint_idx = 9
54 | self._robot_tool_tip_joint_idx = 10
55 | self._robot_tool_offset = [0, 0, -0.0725]
56 |
57 | p.createConstraint(self._robot_body_id, self._robot_tool_joint_idx, self._gripper_body_id, 0,
58 | jointType=p.JOINT_FIXED, jointAxis=[0, 0, 0], parentFramePosition=[0, 0, 0],
59 | childFramePosition=self._robot_tool_offset,
60 | childFrameOrientation=p.getQuaternionFromEuler([0, 0, np.pi / 2]))
61 | self._tool_tip_to_ee_joint = [0, 0, 0.17]
62 | # Define Denavit-Hartenberg parameters for UR5
63 | self._ur5_kinematics_d = np.array([0.089159, 0., 0., 0.10915, 0.09465, 0.0823])
64 | self._ur5_kinematics_a = np.array([0., -0.42500, -0.39225, 0., 0., 0.])
65 |
66 | # Set friction coefficients for gripper fingers
67 | for i in range(p.getNumJoints(self._gripper_body_id)):
68 | p.changeDynamics(
69 | self._gripper_body_id, i,
70 | lateralFriction=1.0,
71 | spinningFriction=1.0,
72 | rollingFriction=0.0001,
73 | frictionAnchor=True
74 | )
75 |
76 | # Add RGB-D camera (mimic RealSense D415)
77 | self.camera_params = {
78 | # large camera, image_size = (240 * 4, 320 * 4)
79 | 0: self._get_camera_param(
80 | camera_position=[0.5, -0.7, 0.3],
81 | camera_image_size=[240 * 4, 320 * 4]
82 | ),
83 | # small camera, image_size = (240, 320)
84 | 1: self._get_camera_param(
85 | camera_position=[0.5, -0.7, 0.3],
86 | camera_image_size=[240, 320]
87 | ),
88 | # top-down camera, image_size = (480, 480)
89 | 2: self._get_camera_param(
90 | camera_position=[0.5, 0, 0.5],
91 | camera_image_size=[480, 480]
92 | ),
93 | }
94 |
95 |
96 | self._heightmap_pixel_size = heightmap_pixel_size
97 | self._heightmap_size = np.round(
98 | ((self._view_bounds[1][1] - self._view_bounds[1][0]) / self._heightmap_pixel_size,
99 | (self._view_bounds[0][1] - self._view_bounds[0][0]) / self._heightmap_pixel_size)).astype(int)
100 |
101 |
102 | def _get_camera_param(self, camera_position, camera_image_size):
103 | camera_lookat = [0.5, 0, 0]
104 | camera_up_direction = [0, camera_position[2], -camera_position[1]]
105 | camera_view_matrix = p.computeViewMatrix(camera_position, camera_lookat, camera_up_direction)
106 | camera_pose = np.linalg.inv(np.array(camera_view_matrix).reshape(4, 4).T)
107 | camera_pose[:, 1:3] = -camera_pose[:, 1:3]
108 | camera_z_near = 0.01
109 | camera_z_far = 10.0
110 | camera_fov_w = 69.40
111 | camera_focal_length = (float(camera_image_size[1]) / 2) / np.tan((np.pi * camera_fov_w / 180) / 2)
112 | camera_fov_h = (math.atan((float(camera_image_size[0]) / 2) / camera_focal_length) * 2 / np.pi) * 180
113 | camera_projection_matrix = p.computeProjectionMatrixFOV(
114 | fov=camera_fov_h,
115 | aspect=float(camera_image_size[1]) / float(camera_image_size[0]),
116 | nearVal=camera_z_near,
117 | farVal=camera_z_far
118 | ) # notes: 1) FOV is vertical FOV 2) aspect must be float
119 | camera_intrinsics = np.array(
120 | [[camera_focal_length, 0, float(camera_image_size[1]) / 2],
121 | [0, camera_focal_length, float(camera_image_size[0]) / 2],
122 | [0, 0, 1]])
123 | camera_param = {
124 | 'camera_image_size': camera_image_size,
125 | 'camera_intr': camera_intrinsics,
126 | 'camera_pose': camera_pose,
127 | 'camera_view_matrix': camera_view_matrix,
128 | 'camera_projection_matrix': camera_projection_matrix,
129 | 'camera_z_near': camera_z_near,
130 | 'camera_z_far': camera_z_far
131 | }
132 | return camera_param
133 |
134 | # Step through simulation time
135 | def step_simulation(self):
136 | while True:
137 | p.stepSimulation()
138 | time.sleep(0.0001)
139 |
140 | # Get RGB-D heightmap from RGB-D image
141 | def get_heightmap(self, color_image, depth_image, cam_param):
142 | color_heightmap, depth_heightmap = utils.get_heightmap(
143 | color_img=color_image,
144 | depth_img=depth_image,
145 | cam_intrinsics=cam_param['camera_intr'],
146 | cam_pose=cam_param['camera_pose'],
147 | workspace_limits=self._view_bounds,
148 | heightmap_resolution=self._heightmap_pixel_size
149 | )
150 | return color_heightmap, depth_heightmap
151 |
152 | # Get latest RGB-D image
153 | def get_camera_data(self, cam_param):
154 | camera_data = p.getCameraImage(cam_param['camera_image_size'][1], cam_param['camera_image_size'][0],
155 | cam_param['camera_view_matrix'], cam_param['camera_projection_matrix'],
156 | shadow=1, flags=p.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX,
157 | renderer=p.ER_BULLET_HARDWARE_OPENGL)
158 |
159 | color_image = np.asarray(camera_data[2]).reshape(
160 | [cam_param['camera_image_size'][0], cam_param['camera_image_size'][1], 4])[:, :, :3] # remove alpha channel
161 | z_buffer = np.asarray(camera_data[3]).reshape(cam_param['camera_image_size'])
162 | camera_z_near = cam_param['camera_z_near']
163 | camera_z_far = cam_param['camera_z_far']
164 | depth_image = (2.0 * camera_z_near * camera_z_far) / (
165 | camera_z_far + camera_z_near - (2.0 * z_buffer - 1.0) * (
166 | camera_z_far - camera_z_near))
167 | return color_image, depth_image
168 |
169 | # Move robot tool to specified pose
170 | def move_tool(self, position, orientation, blocking=False, speed=0.03):
171 |
172 | # Use IK to compute target joint configuration
173 | target_joint_state = np.array(
174 | p.calculateInverseKinematics(self._robot_body_id, self._robot_tool_tip_joint_idx, position, orientation,
175 | maxNumIterations=10000,
176 | residualThreshold=.0001))
177 | target_joint_state[5] = (
178 | (target_joint_state[5] + np.pi) % (2 * np.pi) - np.pi) # keep EE joint angle between -180/+180
179 |
180 | # Move joints
181 | p.setJointMotorControlArray(self._robot_body_id, self._robot_joint_indices, p.POSITION_CONTROL,
182 | target_joint_state,
183 | positionGains=speed * np.ones(len(self._robot_joint_indices)))
184 |
185 | # Block call until joints move to target configuration
186 | if blocking:
187 | actual_joint_state = [p.getJointState(self._robot_body_id, x)[0] for x in self._robot_joint_indices]
188 | timeout_t0 = time.time()
189 | while not all([np.abs(actual_joint_state[i] - target_joint_state[i]) < self._joint_epsilon for i in
190 | range(6)]): # and (time.time()-timeout_t0) < timeout:
191 | if time.time() - timeout_t0 > 5:
192 | p.setJointMotorControlArray(self._robot_body_id, self._robot_joint_indices, p.POSITION_CONTROL,
193 | self._robot_home_joint_config,
194 | positionGains=np.ones(len(self._robot_joint_indices)))
195 | break
196 | actual_joint_state = [p.getJointState(self._robot_body_id, x)[0] for x in self._robot_joint_indices]
197 | time.sleep(0.001)
198 |
199 | # Move robot arm to specified joint configuration
200 | def move_joints(self, target_joint_state, blocking=False, speed=0.03):
201 |
202 | # Move joints
203 | p.setJointMotorControlArray(self._robot_body_id, self._robot_joint_indices,
204 | p.POSITION_CONTROL, target_joint_state,
205 | positionGains=speed * np.ones(len(self._robot_joint_indices)))
206 |
207 | # Block call until joints move to target configuration
208 | if blocking:
209 | actual_joint_state = [p.getJointState(self._robot_body_id, i)[0] for i in self._robot_joint_indices]
210 | timeout_t0 = time.time()
211 | while not all([np.abs(actual_joint_state[i] - target_joint_state[i]) < self._joint_epsilon for i in
212 | range(6)]):
213 | if time.time() - timeout_t0 > 5:
214 | p.setJointMotorControlArray(self._robot_body_id, self._robot_joint_indices, p.POSITION_CONTROL,
215 | self._robot_home_joint_config,
216 | positionGains=np.ones(len(self._robot_joint_indices)))
217 | break
218 | actual_joint_state = [p.getJointState(self._robot_body_id, i)[0] for i in self._robot_joint_indices]
219 | time.sleep(0.001)
220 |
221 |
222 | def robot_go_home(self, blocking=True, speed=0.1):
223 | self.move_joints(self._robot_home_joint_config, blocking, speed)
224 |
225 |
226 | def primitive_push(self, position, rotation_angle, speed=0.01, distance=0.1):
227 | push_orientation = [1.0, 0.0]
228 | push_direction = np.asarray(
229 | [push_orientation[0] * np.cos(rotation_angle) - push_orientation[1] * np.sin(rotation_angle),
230 | push_orientation[0] * np.sin(rotation_angle) + push_orientation[1] * np.cos(rotation_angle), 0.0])
231 | target_x = position[0] + push_direction[0] * distance
232 | target_y = position[1] + push_direction[1] * distance
233 | position_end = np.asarray([target_x, target_y, position[2]])
234 | self.move_tool([position[0], position[1], 0.15], orientation=[-1.0, 1.0, 0.0, 0.0], blocking=True, speed=0.05)
235 | self.move_tool(position, orientation=[-1.0, 1.0, 0.0, 0.0], blocking=True, speed=0.1)
236 | self.move_tool(position_end, orientation=[-1.0, 1.0, 0.0, 0.0], blocking=True, speed=speed)
237 |
238 |         position_end[2] = 0.15
239 | self.move_tool(position_end, orientation=[-1.0, 1.0, 0.0, 0.0], blocking=True, speed=0.005)
240 |
--------------------------------------------------------------------------------
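
`sim.py` wraps a PyBullet session with a UR5 arm, a stick tool, RGB-D cameras, and a push primitive. A minimal usage sketch, assuming `assets/` is available in the working directory and PyBullet is installed (camera index 2 is the top-down 480x480 view defined above):

```python
# Hedged sketch of driving PybulletSim directly; parameter values are illustrative.
from sim import PybulletSim

sim = PybulletSim(gui_enabled=False)                          # headless simulation, stick tool attached
color, depth = sim.get_camera_data(sim.camera_params[2])      # top-down RGB image and metric depth
cmap, hmap = sim.get_heightmap(color, depth, sim.camera_params[2])
print(color.shape, depth.shape, hmap.shape)

sim.primitive_push(position=[0.5, 0.0, 0.005],                # push start point (x, y, z) in meters
                   rotation_angle=0.0,                        # push direction along +x
                   speed=0.005, distance=0.1)
sim.robot_go_home()
```
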
/sim_env.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import time
3 | import cv2
4 | import numpy as np
5 | import pybullet as p
6 | import json
7 |
8 | from sim import PybulletSim
9 | from binvox_utils import read_as_coord_array
10 | from utils import euler2rotm, project_pts_to_2d
11 |
12 |
13 | class SimulationEnv():
14 | def __init__(self, gui_enabled):
15 |
16 | self.gui_enabled = gui_enabled
17 | self.sim = PybulletSim(gui_enabled=gui_enabled, tool='stick')
18 | self.heightmap_size = self.sim._heightmap_size
19 | self.heightmap_pixel_size = self.sim._heightmap_pixel_size
20 | self.view_bounds = self.sim._view_bounds
21 | self.direction_num = 8
22 | self.voxel_size = 0.004
23 |
24 | self.object_ids = []
25 |
26 | self.object_type = 'cube' # choice: 'cube', 'shapenet', 'ycb'
27 |
28 | # process ycb
29 | self.ycb_path = 'object_models/ycb'
30 | self.ycb_info = json.load(open('assets/object_id/ycb_id.json', 'r'))
31 |
32 |         # process shapenet
33 | self.shapenet_path = 'object_models/shapenet'
34 | self.shapenet_info = json.load(open('assets/object_id/shapenet_id.json', 'r'))
35 |
36 | self.voxel_coord = {}
37 | self.cnt_dict = {}
38 | self.init_position = {}
39 | self.last_direction = {}
40 |
41 |
42 | def _get_coord(self, obj_id, position, orientation, vol_bnds=None, voxel_size=None):
43 |         # if vol_bnds is not None, return voxel-grid coordinates; otherwise, return world coordinates
44 | coord = self.voxel_coord[obj_id]
45 | mat = euler2rotm(p.getEulerFromQuaternion(orientation))
46 | coord = (mat @ (coord.T)).T + np.asarray(position)
47 | if vol_bnds is not None:
48 | coord = np.round((coord - vol_bnds[:, 0]) / voxel_size).astype(np.int)
49 | return coord
50 |
51 | def _get_scene_flow_3d(self, old_po_ors):
52 | vol_bnds = self.view_bounds
53 | scene_flow = np.zeros([int((x[1] - x[0] + 1e-7) / self.voxel_size) for x in vol_bnds] + [3])
54 | mask = np.zeros([int((x[1] - x[0] + 1e-7) / self.voxel_size) for x in vol_bnds], dtype=np.int)
55 |
56 | cur_cnt = 0
57 | for obj_id, old_po_or in zip(self.object_ids, old_po_ors):
58 | position, orientation = p.getBasePositionAndOrientation(obj_id)
59 | new_coord = self._get_coord(obj_id, position, orientation, vol_bnds, self.voxel_size)
60 |
61 | position, orientation = old_po_or
62 | old_coord = self._get_coord(obj_id, position, orientation, vol_bnds, self.voxel_size)
63 |
64 | motion = new_coord - old_coord
65 |
66 | valid_idx = np.logical_and(
67 | np.logical_and(old_coord[:, 1] >= 0, old_coord[:, 1] < 128),
68 | np.logical_and(
69 | np.logical_and(old_coord[:, 0] >= 0, old_coord[:, 0] < 128),
70 | np.logical_and(old_coord[:, 2] >= 0, old_coord[:, 2] < 48)
71 | )
72 | )
73 | x = old_coord[valid_idx, 1]
74 | y = old_coord[valid_idx, 0]
75 | z = old_coord[valid_idx, 2]
76 | motion = motion[valid_idx]
77 | motion = np.stack([motion[:, 1], motion[:, 0], motion[:, 2]], axis=1)
78 |
79 | scene_flow[x, y, z] = motion
80 |
81 | # mask
82 | cur_cnt += 1
83 | mask[x, y, z] = cur_cnt
84 |
85 | return mask, scene_flow
86 |
87 | def _get_scene_flow_2d(self, old_po_or):
88 | old_coords_world = []
89 | new_coords_world = []
90 | point_id_list = []
91 | cur_cnt = 0
92 | for obj_id, po_or in zip(self.object_ids, old_po_or):
93 | position, orientation = po_or
94 | old_coord = self._get_coord(obj_id, position, orientation)
95 | old_coords_world.append(old_coord)
96 |
97 | position, orientation = p.getBasePositionAndOrientation(obj_id)
98 | new_coord = self._get_coord(obj_id, position, orientation)
99 | new_coords_world.append(new_coord)
100 |
101 | cur_cnt += 1
102 | point_id_list.append([cur_cnt for _ in range(old_coord.shape[0])])
103 |
104 | point_id = np.concatenate(point_id_list)
105 | old_coords_world = np.concatenate(old_coords_world)
106 | new_coords_world = np.concatenate(new_coords_world)
107 | camera_view_matrix = np.array(self.sim.camera_params[1]['camera_view_matrix']).reshape(4, 4).T
108 | camera_intr = self.sim.camera_params[1]['camera_intr']
109 | image_size = self.sim.camera_params[1]['camera_image_size']
110 | old_coords_2d = project_pts_to_2d(old_coords_world.T, camera_view_matrix, camera_intr)
111 | y = np.round(old_coords_2d[0]).astype(np.int)
112 | x = np.round(old_coords_2d[1]).astype(np.int)
113 | depth = old_coords_2d[2]
114 | valid_idx = np.logical_and(
115 | np.logical_and(x >= 0, x < image_size[0]),
116 | np.logical_and(y >= 0, y < image_size[1])
117 | )
118 | x = x[valid_idx]
119 | y = y[valid_idx]
120 | depth = depth[valid_idx]
121 | point_id = point_id[valid_idx]
122 | motion = (new_coords_world - old_coords_world)[valid_idx]
123 |
124 | sort_id = np.argsort(-depth)
125 | x = x[sort_id]
126 | y = y[sort_id]
127 | point_id = point_id[sort_id]
128 | motion = motion[sort_id]
129 | motion = np.stack([motion[:, 1], motion[:, 0], motion[:, 2]], axis=1)
130 |
131 | scene_flow = np.zeros([image_size[0], image_size[1], 3])
132 | mask = np.zeros([image_size[0], image_size[1]])
133 |
134 | scene_flow[x, y] = motion
135 | mask[x, y] = point_id
136 |
137 | return mask, scene_flow
138 |
139 |
140 | def check_occlusion(self):
141 | coords_world = []
142 | point_id_list = []
143 | cur_cnt = 0
144 | for obj_id in self.object_ids:
145 | position, orientation = p.getBasePositionAndOrientation(obj_id)
146 | coord = self._get_coord(obj_id, position, orientation)
147 | coords_world.append(coord)
148 | cur_cnt += 1
149 | point_id_list.append([cur_cnt for _ in range(coord.shape[0])])
150 | point_id = np.concatenate(point_id_list)
151 | coords_world = np.concatenate(coords_world)
152 | camera_view_matrix = np.array(self.sim.camera_params[1]['camera_view_matrix']).reshape(4, 4).T
153 | camera_intr = self.sim.camera_params[1]['camera_intr']
154 | image_size = self.sim.camera_params[1]['camera_image_size']
155 | coords_2d = project_pts_to_2d(coords_world.T, camera_view_matrix, camera_intr)
156 |
157 | y = np.round(coords_2d[0]).astype(np.int)
158 | x = np.round(coords_2d[1]).astype(np.int)
159 | depth = coords_2d[2]
160 | valid_idx = np.logical_and(
161 | np.logical_and(x >= 0, x < image_size[0]),
162 | np.logical_and(y >= 0, y < image_size[1])
163 | )
164 | x = x[valid_idx]
165 | y = y[valid_idx]
166 | depth = depth[valid_idx]
167 | point_id = point_id[valid_idx]
168 |
169 | sort_id = np.argsort(-depth)
170 | x = x[sort_id]
171 | y = y[sort_id]
172 | point_id = point_id[sort_id]
173 |
174 | mask = np.zeros([image_size[0], image_size[1]])
175 | mask[x, y] = point_id
176 |
177 | obj_num = len(self.object_ids)
178 | mask_sep = np.zeros([obj_num + 1, image_size[0], image_size[1]])
179 | mask_sep[point_id, x, y] = 1
180 | for i in range(obj_num):
181 | tot_pixel_num = np.sum(mask_sep[i + 1])
182 | vis_pixel_num = np.sum((mask == (i+1)).astype(np.float))
183 | if vis_pixel_num < 0.4 * tot_pixel_num:
184 | return False
185 | return True
186 |
187 | def _get_image_and_heightmap(self):
188 | color_image0, depth_image0 = self.sim.get_camera_data(self.sim.camera_params[0])
189 | color_image1, depth_image1 = self.sim.get_camera_data(self.sim.camera_params[1])
190 | color_image2, depth_image2 = self.sim.get_camera_data(self.sim.camera_params[2])
191 |
192 | color_heightmap, depth_heightmap = self.sim.get_heightmap(color_image2, depth_image2, self.sim.camera_params[2])
193 |
194 | self.current_depth_heightmap = depth_heightmap
195 | self.current_color_heightmap = color_heightmap
196 | self.current_depth_image0 = depth_image0
197 | self.current_color_image0 = color_image0
198 | self.current_depth_image1 = depth_image1
199 | self.current_color_image1 = color_image1
200 |
201 | def _random_drop(self, object_num, object_type):
202 | large_object_id = np.random.choice(object_num)
203 | if object_type == 'cube' or np.random.rand() < 0.1:
204 | large_object_id = -1
205 | self.large_object_id = large_object_id
206 | self.can_with_box = False
207 |
208 | while True:
209 | xy_pos = np.random.rand(object_num, 2) * 0.26 + np.asarray([0.5-0.13, -0.13])
210 | flag = True
211 | for i in range(object_num - 1):
212 | for j in range(i + 1, object_num):
213 | d = np.sqrt(np.sum((xy_pos[i] - xy_pos[j])**2))
214 | if i == large_object_id or j == large_object_id:
215 | if d < 0.13:
216 | flag = False
217 | else:
218 | if d < 0.07:
219 | flag = False
220 | if large_object_id != -1 and xy_pos[large_object_id][1] > -0.05 and np.random.rand() < 0.15:
221 | flag = False
222 | if flag:
223 | break
224 |
225 | xy_pos -= np.mean(xy_pos, 0)
226 | xy_pos += np.array([0.5, 0])
227 |
228 | for i in range(object_num):
229 | if object_type == 'cube':
230 | md = np.ones([60, 60, 70])
231 | coord = (np.asarray(np.nonzero(md)).T + 0.5 - np.array([30, 30, 35]))
232 | size_cube = np.random.choice([700, 750, 800, 850, 900, 1000, 1100, 1200, 1400])
233 | collision_id = p.createCollisionShape(p.GEOM_BOX, halfExtents=np.array([30, 30, 35]) / size_cube)
234 | body_id = p.createMultiBody(
235 | 0.05, collision_id, -1,
236 | [xy_pos[i, 0], xy_pos[i, 1], 0.2],
237 | p.getQuaternionFromEuler(np.random.rand(3) * np.pi)
238 | )
239 | p.changeDynamics(body_id, -1, spinningFriction=0.003, lateralFriction=0.25, mass=0.05)
240 | p.changeVisualShape(body_id, -1, rgbaColor=np.concatenate([1 * np.random.rand(3), [1]]))
241 | self.object_ids.append(body_id)
242 | self.voxel_coord[body_id] = coord / size_cube
243 | time.sleep(0.2)
244 | elif object_type == 'ycb':
245 | # get object
246 | if i == large_object_id:
247 | obj_name = np.random.choice(self.ycb_info['large_list'])
248 | else:
249 | obj_name = np.random.choice(self.ycb_info['normal_list'])
250 |
251 | with open(osp.join(self.ycb_path, obj_name, 'model_com.binvox'), 'rb') as f:
252 | md = read_as_coord_array(f)
253 | coord = (md.data.T + 0.5) / md.dims * md.scale + md.translate
254 |
255 | # position & quat
256 | random_euler = [0, 0, np.random.rand() * 2 * np.pi]
257 | quat = p.getQuaternionFromEuler(random_euler)
258 | obj_position = [xy_pos[i, 0], xy_pos[i, 1], np.max(-coord[:, 2]) + 0.01]
259 |
260 | urdf_path = osp.join(self.ycb_path, obj_name, 'obj.urdf')
261 | body_id = p.loadURDF(
262 | fileName=urdf_path,
263 | basePosition=obj_position,
264 | baseOrientation=quat,
265 | globalScaling=1
266 | )
267 | p.changeDynamics(body_id, -1, spinningFriction=0.003, lateralFriction=0.25, mass=0.05)
268 |
269 | self.object_ids.append(body_id)
270 | self.voxel_coord[body_id] = coord
271 | time.sleep(2)
272 | elif (object_type=='shapenet' and i != large_object_id and np.random.rand() < 0.3) or \
273 | (object_type=='shapenet' and i == large_object_id and np.random.rand() < 0.12):
274 | box_size = 'small'
275 | if i == large_object_id:
276 | box_size='large'
277 | elif self.can_with_box and np.random.rand() < 0.1:
278 | self.can_with_box=False
279 | box_size='large'
280 |
281 | if box_size == 'large':
282 | dim_x = np.random.choice(list(range(35, 55)))
283 | dim_y = np.random.choice(list(range(70, 85)))
284 | dim_z = np.random.choice(list(range(15, 30)))
285 | else:
286 | dim_x = np.random.choice(list(range(25, 40)))
287 | dim_y = np.random.choice(list(range(30, 60)))
288 | dim_z = np.random.choice(list(range(15, 25)) + [30, 32])
289 | md = np.ones([dim_x, dim_y, dim_z])
290 | coord = (np.asarray(np.nonzero(md)).T + 0.5 - np.array([dim_x / 2, dim_y / 2, dim_z / 2]))
291 | size_cube = 500
292 | collision_id = p.createCollisionShape(p.GEOM_BOX, halfExtents=np.array(
293 | [dim_x / 2, dim_y / 2, dim_z / 2]) / size_cube)
294 | body_id = p.createMultiBody(
295 | 0.05, collision_id, -1,
296 | [xy_pos[i, 0], xy_pos[i, 1], 0.1],
298 | p.getQuaternionFromEuler([0, 0, np.random.rand() * np.pi])
299 | )
300 | p.changeDynamics(body_id, -1, spinningFriction=0.003, lateralFriction=0.25, mass=0.05)
301 | p.changeVisualShape(body_id, -1, rgbaColor=np.concatenate([1 * np.random.rand(3), [1]]))
302 | self.object_ids.append(body_id)
303 | self.voxel_coord[body_id] = coord / size_cube
304 | time.sleep(0.2)
305 | else:
306 | object_cat_cur = np.random.choice(list(self.shapenet_info.keys()))
307 | if np.random.rand() < 0.3:
308 | object_cat_cur = 'can'
309 | if i == large_object_id and object_cat_cur == 'can':
310 | self.can_with_box=True
311 |
312 | category_id = self.shapenet_info[object_cat_cur]['category_id']
313 | tmp = np.random.choice(len(self.shapenet_info[object_cat_cur]['object_id']))
314 | object_id = self.shapenet_info[object_cat_cur]['object_id'][tmp]
315 | urdf_path = osp.join(self.shapenet_path, '%s/%s/obj.urdf' % (category_id, object_id))
316 |
317 | # load object
318 | if i == large_object_id:
319 | scaling_range = self.shapenet_info[object_cat_cur]['large_scaling']
320 | else:
321 | scaling_range = self.shapenet_info[object_cat_cur]['global_scaling']
322 |
323 | globalScaling = np.random.rand() * (scaling_range[1] - scaling_range[0]) + scaling_range[0]
324 |
325 | # save nonzero voxel coord
326 | with open(osp.join(self.shapenet_path, '%s/%s/model_com.binvox' % (category_id, object_id)),
327 | 'rb') as f:
328 | md = read_as_coord_array(f)
329 | coord = (md.data.T + 0.5) / md.dims * md.scale + md.translate
330 | coord = coord * 0.15 * globalScaling # 0.15 is the rescale value in .urdf
331 |
332 | # position & quat
333 | random_euler = [0, 0, np.random.rand() * 2 * np.pi]
334 | quat = p.getQuaternionFromEuler(random_euler)
335 | obj_position = [xy_pos[i, 0], xy_pos[i, 1], np.max(-coord[:, 2]) + 0.01]
336 |
337 | body_id = p.loadURDF(
338 | fileName=urdf_path,
339 | basePosition=obj_position,
340 | baseOrientation=quat,
341 | globalScaling=globalScaling
342 | )
343 |
344 | p.changeDynamics(body_id, -1, spinningFriction=0.003, lateralFriction=0.25, mass=0.05)
345 | p.changeVisualShape(body_id, -1, rgbaColor=np.concatenate([1 * np.random.rand(3), [1]]))
346 | self.object_ids.append(body_id)
347 | self.voxel_coord[body_id] = coord
348 | time.sleep(0.2)
349 | for obj_id in self.object_ids:
350 | self.cnt_dict[obj_id] = 0
351 | init_p = p.getBasePositionAndOrientation(obj_id)[0]
352 | self.init_position[obj_id] = np.asarray(init_p[:2])
353 | self.last_direction[obj_id] = None
354 |
355 | # for heightmap
356 | def coord2pixel(self, x_coord, y_coord):
357 | x_pixel = int((x_coord - self.view_bounds[0, 0]) / self.heightmap_pixel_size)
358 | y_pixel = int((y_coord - self.view_bounds[1, 0]) / self.heightmap_pixel_size)
359 | return x_pixel, y_pixel
360 |
361 | def pixel2coord(self, x_pixel, y_pixel):
362 | x_coord = x_pixel * self.heightmap_pixel_size + self.view_bounds[0, 0]
363 | y_coord = y_pixel * self.heightmap_pixel_size + self.view_bounds[1, 0]
364 | return x_coord, y_coord
365 |
366 | def policy_generation(self):
367 | def softmax(input):
368 | value = np.exp(input)
369 | output = value / np.sum(value)
370 | return output
371 |
372 | # choose object
373 | value = []
374 | for x in self.object_ids:
375 | t = self.cnt_dict[x]
376 | if t > 2:
377 | t = -2
378 | elif t > 1 and np.random.rand() < 0.5:
379 | t = -2
380 | self.cnt_dict[x] = t
381 | value.append(t)
382 | if self.large_object_id != -1:
383 | value[self.large_object_id] += 0.5
384 | obj_id = np.random.choice(self.object_ids, p=softmax(np.array(value)))
385 |
386 | # get position
387 | position = p.getBasePositionAndOrientation(obj_id)[0]
388 | position = np.asarray([position[0], position[1]])
389 |
390 | # choose direction
391 | direction_value = [0 for i in range(self.direction_num)]
392 | for d in range(self.direction_num):
393 | ang = 2 * np.pi * d / self.direction_num
394 | unit_vec = np.asarray([np.cos(ang), np.sin(ang)])
395 |
396 | off_direction = np.asarray([0.5, 0]) - position
397 | off_direction_unit = off_direction / np.sqrt(np.sum(off_direction ** 2))
398 | weight = 5 if self.last_direction[obj_id] is None else 1
399 | direction_value[d] += weight * np.sum(off_direction_unit * unit_vec) * np.exp(np.sum(np.abs(off_direction)))
400 | if np.sqrt(np.sum(off_direction ** 2)) > 0.2 and np.sum(off_direction_unit * unit_vec) < 0:
401 | direction_value[d] -= 10
402 |
403 | for obj_id_enm in self.object_ids:
404 | if obj_id_enm != obj_id:
405 | obj_position_enm = p.getBasePositionAndOrientation(obj_id_enm)[0]
406 | off_direction = np.asarray(obj_position_enm[:2]) - position
407 | off_direction_unit = off_direction / np.sqrt(np.sum(off_direction ** 2))
408 | if np.sqrt(np.sum(off_direction ** 2)) < 0.15 and np.sum(off_direction_unit * unit_vec) > 0.4:
409 | direction_value[d] -= 3 * np.sum(off_direction_unit * unit_vec)
410 | if np.sqrt(np.sum(off_direction ** 2)) < 0.15 and np.abs(
411 | np.sum(off_direction_unit * unit_vec)) < 0.3:
412 | direction_value[d] += 1.5
413 |
414 | off_direction = position - self.init_position[obj_id]
415 | if np.sum(np.abs(off_direction)) > 0.001:
416 | off_direction_unit = off_direction / np.sqrt(np.sum(off_direction ** 2))
417 | if np.sqrt(np.sum(off_direction ** 2)) < 0.25 and np.sum(off_direction_unit * unit_vec) < 0:
418 | direction_value[d] += 1.5 * np.sum(off_direction_unit * unit_vec)
419 |
420 | if self.last_direction[obj_id] is not None:
421 | ang_last = 2 * np.pi * self.last_direction[obj_id] / self.direction_num
422 | unit_vec_last = np.asarray([np.cos(ang_last), np.sin(ang_last)])
423 | direction_value[d] += 2 * np.sum(unit_vec_last * unit_vec)
424 |
425 | direction = np.random.choice(self.direction_num, p=softmax(np.array(direction_value) / 2))
426 |
427 | direction_angle = direction / 4.0 * np.pi
428 |
429 | pos = position
430 | pos -= np.asarray([np.cos(direction_angle), np.sin(direction_angle)]) * 0.04
431 |
432 | for _ in range(5):
433 | x_coord, y_coord = pos[0], pos[1]
434 | x_pixel, y_pixel = self.coord2pixel(x_coord, y_coord)
435 | pos -= np.asarray([np.cos(direction_angle), np.sin(direction_angle)]) * 0.01
436 | if min(x_pixel, y_pixel) < 0 or max(x_pixel, y_pixel) >= 128:
437 | continue
438 |
439 | detection_mask = cv2.circle(np.zeros(self.heightmap_size), (x_pixel, y_pixel), 5, 1, thickness=-1)
440 | if np.max(self.current_depth_heightmap * detection_mask) > 0.005:
441 | continue
442 |
443 | d = 0.04
444 | new_pixel = self.coord2pixel(x_coord + np.cos(direction_angle) * d, y_coord + np.sin(direction_angle) * d)
445 | detection_mask = cv2.circle(np.zeros(self.heightmap_size), new_pixel, 3, 1, thickness=-1)
446 | if np.max(detection_mask * self.current_depth_heightmap) > 0.005:
447 | self.last_direction[obj_id] = direction
448 | self.cnt_dict[obj_id] += 1
449 | return x_pixel, y_pixel, x_coord, y_coord, 0.005, direction
450 |
451 | return None
452 |
453 |
454 | def poke(self):
455 | # log the current position & quat
456 | old_po_ors = [p.getBasePositionAndOrientation(object_id) for object_id in self.object_ids]
457 | output = self.get_scene_info()
458 |
459 | # generate action
460 | policy = None
461 | while policy is None:
462 | policy = self.policy_generation()
463 |
464 | x_pixel, y_pixel, x_coord, y_coord, z_coord, direction = policy
465 |
466 | # take action
467 | self.sim.primitive_push(
468 | position=[x_coord, y_coord, z_coord],
469 | rotation_angle=direction / 4.0 * np.pi,
470 | speed=0.005,
471 | distance=0.15
472 | )
473 | self.sim.robot_go_home()
474 | action = {'0': direction, '1': y_pixel, '2': x_pixel}
475 |
476 | mask_3d, scene_flow_3d = self._get_scene_flow_3d(old_po_ors)
477 | mask_2d, scene_flow_2d = self._get_scene_flow_2d(old_po_ors)
478 |
479 | output['action'] = action
480 | output['mask_3d'] = mask_3d
481 | output['scene_flow_3d'] = scene_flow_3d
482 | output['mask_2d'] = mask_2d
483 | output['scene_flow_2d'] = scene_flow_2d
484 |
485 | return output
486 |
487 |
488 | def get_scene_info(self, mask_info=False):
489 | self._get_image_and_heightmap()
490 |
491 | positions, orientations = [], []
492 | for i, obj_id in enumerate(self.object_ids):
493 | info = p.getBasePositionAndOrientation(obj_id)
494 | positions.append(info[0])
495 | orientations.append(info[1])
496 |
497 | scene_info = {
498 | 'color_heightmap': self.current_color_heightmap,
499 | 'depth_heightmap': self.current_depth_heightmap,
500 | 'color_image': self.current_color_image0,
501 | 'depth_image': self.current_depth_image0,
502 | 'color_image_small': self.current_color_image1,
503 | 'depth_image_small': self.current_depth_image1,
504 | 'positions': np.array(positions),
505 | 'orientations': np.array(orientations)
506 | }
507 | if mask_info:
508 | old_po_ors = [p.getBasePositionAndOrientation(object_id) for object_id in self.object_ids]
509 | mask_3d, scene_flow_3d = self._get_scene_flow_3d(old_po_ors)
510 | mask_2d, scene_flow_2d = self._get_scene_flow_2d(old_po_ors)
511 | scene_info['mask_3d'] = mask_3d
512 | scene_info['mask_2d'] = mask_2d
513 |
514 | return scene_info
515 |
516 |
517 | def reset(self, object_num=4, object_type=None):
518 | if object_type is None:
519 | object_type = self.object_type
520 |
521 | while True:
522 | # remove objects
523 | for obj_id in self.object_ids:
524 | p.removeBody(obj_id)
525 | self.object_ids = []
526 | self.cnt_dict = {}
527 |
528 | # load fences
529 | self.fence_id = p.loadURDF(
530 | fileName='assets/fence/tinker.urdf',
531 | basePosition=[0.5, 0, 0.001],
532 | baseOrientation=p.getQuaternionFromEuler([0, 0, 0]),
533 | useFixedBase=True
534 | )
535 |
536 | # load objects
537 | self._random_drop(object_num, object_type)
538 | time.sleep(1)
539 | p.removeBody(self.fence_id)
540 | old_ps = np.array([p.getBasePositionAndOrientation(object_id)[0] for object_id in self.object_ids])
541 | for _ in range(10):
542 | time.sleep(1)
543 | new_ps = np.array([p.getBasePositionAndOrientation(object_id)[0] for object_id in self.object_ids])
544 | if np.sum((new_ps - old_ps) ** 2) < 1e-6:
545 | break
546 | old_ps = new_ps
547 | self._get_image_and_heightmap()
548 |
549 | # check occlusion
550 | if self.check_occlusion():
551 | return
552 |
553 |
554 | if __name__ == '__main__':
555 | env = SimulationEnv(gui_enabled=False)
556 | env.reset(4, 'ycb')
557 |
558 |     # if you just want the current scene information, use env.get_scene_info
559 | output = env.get_scene_info(mask_info=True)
560 | print(output.keys())
561 |
562 |     # if you use pushing, env.poke() returns the same information, together with the action and scene flow
563 | output = env.poke()
564 | print(output.keys())
--------------------------------------------------------------------------------
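
`sim_env.py` builds on `PybulletSim` to drop random objects, sample push actions, and record ground-truth masks and scene flow. A hedged sketch of a small data-collection loop; the `.npz` layout and file names are illustrative, and the project's actual dataset format is produced by `data_generation.py`:

```python
import numpy as np
from sim_env import SimulationEnv

env = SimulationEnv(gui_enabled=False)
env.reset(object_num=4, object_type='shapenet')       # choices: 'cube', 'shapenet', 'ycb'

for step in range(5):
    out = env.poke()                                  # one push + observations, masks, scene flow
    np.savez_compressed(
        'step_%02d.npz' % step,                       # illustrative output path
        color_heightmap=out['color_heightmap'],
        depth_heightmap=out['depth_heightmap'],
        mask_3d=out['mask_3d'],
        scene_flow_3d=out['scene_flow_3d'],
        action=np.array([out['action']['0'], out['action']['1'], out['action']['2']]),
    )
```
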
/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import argparse
4 | from tqdm import tqdm
5 | import os.path as osp
6 | from data import Data
7 | from torch.utils.data import DataLoader
8 | from model import ModelDSR
9 | import itertools
10 |
11 |
12 | parser = argparse.ArgumentParser()
13 |
14 | parser.add_argument('--resume', type=str, help='path to model')
15 | parser.add_argument('--data_path', type=str, help='path to data')
16 | parser.add_argument('--test_type', type=str, choices=['motion_visible', 'motion_full', 'mask_ordered', 'mask_unordered'])
17 |
18 | parser.add_argument('--gpu', type=int, default=0, help='gpu id (single gpu)')
19 | parser.add_argument('--object_num', type=int, default=5, help='number of objects')
20 | parser.add_argument('--seq_len', type=int, default=10, help='sequence length')
21 | parser.add_argument('--batch', type=int, default=12, help='batch size')
22 | parser.add_argument('--workers', type=int, default=2, help='number of workers in data loader')
23 |
24 | parser.add_argument('--model_type', type=str, default='dsr', choices=['dsr', 'single', 'nowarp', 'gtwarp', '3dflow'])
25 | parser.add_argument('--transform_type', type=str, default='se3euler', choices=['affine', 'se3euler', 'se3aa', 'se3spquat', 'se3quat'])
26 |
27 | def main():
28 | args = parser.parse_args()
29 | torch.cuda.set_device(args.gpu)
30 |
31 | data, loaders = {}, {}
32 | for split in ['test']:
33 | data[split] = Data(data_path=args.data_path, split=split, seq_len=args.seq_len)
34 | loaders[split] = DataLoader(dataset=data[split], batch_size=args.batch, num_workers=args.workers)
35 | print('==> dataset loaded: [size] = {0}'.format(len(data['test'])))
36 |
37 |
38 | model = ModelDSR(
39 | object_num=args.object_num,
40 | transform_type=args.transform_type,
41 | motion_type='se3' if args.model_type != '3dflow' else 'conv',
42 | )
43 | model.cuda()
44 |
45 | checkpoint = torch.load(args.resume, map_location=torch.device(f'cuda:{args.gpu}'))
46 | model.load_state_dict(checkpoint['state_dict'])
47 | print('==> resume: ' + args.resume)
48 |
49 | with torch.no_grad():
50 | if args.test_type == 'motion_visible':
51 | evaluation_motion_visible(args, model, loaders['test'])
52 |
53 | if args.test_type == 'motion_full':
54 | evaluation_motion_full(args, model, loaders['test'])
55 |
56 | if args.test_type == 'mask_ordered':
57 | evaluation_mask_ordered(args, model, loaders['test'])
58 |
59 | if args.test_type == 'mask_unordered':
60 | evaluation_mask_unordered(args, model, loaders['test'])
61 |
62 | def evaluation_mask_unordered(args, model, loader):
63 | print(f'==> evaluation_mask (unordered)')
64 | iou_dict = [[] for _ in range(args.seq_len)]
65 | for batch in tqdm(loader):
66 | batch_size = batch['0-action'].size(0)
67 | last_s = model.get_init_repr(batch_size).cuda()
68 | logit_pred_list, mask_gt_list = [], []
69 | for step_id in range(args.seq_len):
70 | output = model(
71 | input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1),
72 | last_s=last_s,
73 | input_action=batch['%d-action' % step_id].cuda(),
74 | input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None,
75 | no_warp=args.model_type=='nowarp'
76 | )
77 |             if args.model_type != 'single':
78 | last_s = output['s'].data
79 |
80 | logit_pred = output['init_logit']
81 | mask_gt = batch['%d-mask_3d' % step_id].cuda()
82 | iou_unordered = calc_iou_unordered(logit_pred, mask_gt)
83 | iou_dict[step_id].append(iou_unordered)
84 | print('mask_unordered (IoU) = ', np.mean([np.mean(np.concatenate(iou_dict[i])) for i in range(args.seq_len)]))
85 |
86 |
87 | def calc_iou_unordered(logit_pred, mask_gt_argmax):
88 | # logit_pred: [B, K, S1, S2, S3], softmax, the last channel is empty
89 | # mask_gt_argmax: [B, S1, S2, S3], 0 represents empty
90 | B, K, S1, S2, S3 = logit_pred.size()
91 | logit_pred_argmax = torch.argmax(logit_pred, dim=1, keepdim=True)
92 | mask_gt_argmax = torch.unsqueeze(mask_gt_argmax, 1)
93 | mask_pred_onehot = torch.zeros_like(logit_pred).scatter(1, logit_pred_argmax, 1)[:, :-1]
94 | mask_gt_onehot = torch.zeros_like(logit_pred).scatter(1, mask_gt_argmax, 1)[:, 1:]
95 | K -= 1
96 | info_dict = {'I': np.zeros([B, K, K]), 'U': np.zeros([B, K, K])}
97 | for b in range(B):
98 | for i in range(K):
99 | for j in range(K):
100 | mask_gt = mask_gt_onehot[b, i]
101 | mask_pred = mask_pred_onehot[b, j]
102 | I = torch.sum(mask_gt * mask_pred).item()
103 | U = torch.sum(mask_gt + mask_pred).item() - I
104 | info_dict['I'][b, i, j] = I
105 | info_dict['U'][b, i, j] = U
106 | batch_ious = []
107 | for b in range(B):
108 | best_iou, best_p = 0, None
109 | for p in list(itertools.permutations(range(K))):
110 | cur_I = [info_dict['I'][b, i, p[i]] for i in range(K)]
111 | cur_U = [info_dict['U'][b, i, p[i]] for i in range(K)]
112 | cur_iou = np.mean(np.array(cur_I) / np.maximum(np.array(cur_U), 1))
113 | if cur_iou > best_iou:
114 | best_iou = cur_iou
115 | batch_ious.append(best_iou)
116 |
117 | return np.array(batch_ious)
118 |
119 |
120 | def evaluation_mask_ordered(args, model, loader):
121 | print(f'==> evaluation_mask (ordered)')
122 | iou_dict = []
123 | for batch in tqdm(loader):
124 | batch_size = batch['0-action'].size(0)
125 | last_s = model.get_init_repr(batch_size).cuda()
126 | logit_pred_list, mask_gt_list = [], []
127 | for step_id in range(args.seq_len):
128 | output = model(
129 | input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1),
130 | last_s=last_s,
131 | input_action=batch['%d-action' % step_id].cuda(),
132 | input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None,
133 | no_warp=args.model_type=='nowarp'
134 | )
135 |             if args.model_type != 'single':
136 | last_s = output['s'].data
137 |
138 | logit_pred = output['init_logit']
139 | mask_gt = batch['%d-mask_3d' % step_id].cuda()
140 | logit_pred_list.append(logit_pred)
141 | mask_gt_list.append(mask_gt)
142 | iou_ordered = calc_iou_ordered(logit_pred_list, mask_gt_list)
143 | iou_dict.append(iou_ordered)
144 | print('mask_ordered (IoU) = ', np.mean(np.concatenate(iou_dict)))
145 |
146 |
147 | def calc_iou_ordered(logit_pred_list, mask_gt_argmax_list):
148 | # logit_pred_list: [L, B, K, S1, S2, S3], softmax, the last channel is empty
149 | # mask_gt_argmax_list: [L, B, S1, S2, S3], 0 represents empty
150 | L = len(logit_pred_list)
151 | B, K, S1, S2, S3 = logit_pred_list[0].size()
152 | K -= 1
153 | info_dict = {'I': np.zeros([L, B, K, K]), 'U': np.zeros([L, B, K, K])}
154 | for l in range(L):
155 | logit_pred = logit_pred_list[l]
156 | mask_gt_argmax = mask_gt_argmax_list[l]
157 | logit_pred_argmax = torch.argmax(logit_pred, dim=1, keepdim=True)
158 | mask_gt_argmax = torch.unsqueeze(mask_gt_argmax, 1)
159 | mask_pred_onehot = torch.zeros_like(logit_pred).scatter(1, logit_pred_argmax, 1)[:, :-1]
160 | mask_gt_onehot = torch.zeros_like(logit_pred).scatter(1, mask_gt_argmax, 1)[:, 1:]
161 | for b in range(B):
162 | for i in range(K):
163 | for j in range(K):
164 | mask_gt = mask_gt_onehot[b, i]
165 | mask_pred = mask_pred_onehot[b, j]
166 | I = torch.sum(mask_gt * mask_pred).item()
167 | U = torch.sum(mask_gt + mask_pred).item() - I
168 | info_dict['I'][l, b, i, j] = I
169 | info_dict['U'][l, b, i, j] = U
170 | batch_ious = []
171 | for b in range(B):
172 | best_iou, best_p = 0, None
173 | for p in list(itertools.permutations(range(K))):
174 | cur_I = [info_dict['I'][l, b, i, p[i]] for l in range(L) for i in range(K)]
175 | cur_U = [info_dict['U'][l, b, i, p[i]] for l in range(L) for i in range(K)]
176 | cur_iou = np.mean(np.array(cur_I) / np.maximum(np.array(cur_U), 1))
177 | if cur_iou > best_iou:
178 | best_iou = cur_iou
179 | batch_ious.append(best_iou)
180 |
181 | return np.array(batch_ious)
182 |
183 |
184 | def evaluation_motion_visible(args, model, loader):
185 | print('==> evaluation_motion (visible surface)')
186 | mse_dict = [0 for _ in range(args.seq_len)]
187 | data_num = 0
188 | for batch in tqdm(loader):
189 | batch_size = batch['0-action'].size(0)
190 | data_num += batch_size
191 | last_s = model.get_init_repr(batch_size).cuda()
192 | for step_id in range(args.seq_len):
193 | output = model(
194 | input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1),
195 | last_s=last_s,
196 | input_action=batch['%d-action' % step_id].cuda(),
197 | input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None,
198 | no_warp=args.model_type=='nowarp'
199 | )
200 |             if args.model_type not in ['single', '3dflow']:
201 | last_s = output['s'].data
202 |
203 | tsdf = batch['%d-tsdf' % step_id].cuda().unsqueeze(1)
204 | mask = batch['%d-mask_3d' % step_id].cuda().unsqueeze(1)
205 | surface_mask = ((tsdf > -0.99).float()) * ((tsdf < 0).float()) * ((mask > 0).float())
206 | surface_mask[..., 0] = 0
207 |
208 | target = batch['%d-scene_flow_3d' % step_id].cuda()
209 | pred = output['motion']
210 |
211 | mse = torch.sum((target - pred) ** 2 * surface_mask, dim=[1, 2, 3, 4]) / torch.sum(surface_mask, dim=[1, 2, 3, 4])
212 | mse_dict[step_id] += torch.sum(mse).item() * 0.16
213 |             # 0.16 (= 0.4^2) converts the squared error from voxel^2 to cm^2,
214 |             # since the voxel size is 0.4 cm.
215 | print('motion_visible (MSE in cm) = ', np.mean([np.mean(mse_dict[i]) / data_num for i in range(args.seq_len)]))
216 |
217 |
218 | def evaluation_motion_full(args, model, loader):
219 | print('==> evaluation_motion (full volume)')
220 | mse_dict = [0 for _ in range(args.seq_len)]
221 | data_num = 0
222 | for batch in tqdm(loader):
223 | batch_size = batch['0-action'].size(0)
224 | data_num += batch_size
225 | last_s = model.get_init_repr(batch_size).cuda()
226 | for step_id in range(args.seq_len):
227 | output = model(
228 | input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1),
229 | last_s=last_s,
230 | input_action=batch['%d-action' % step_id].cuda(),
231 | input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None,
232 | no_warp=args.model_type=='nowarp'
233 | )
234 |             if args.model_type not in ['single', '3dflow']:
235 | last_s = output['s'].data
236 |
237 | target = batch['%d-scene_flow_3d' % step_id].cuda()
238 | pred = output['motion']
239 |
240 | mse = torch.mean((target - pred) ** 2, dim=[1, 2, 3, 4])
241 | mse_dict[step_id] += torch.sum(mse).item() * 0.16
242 |             # 0.16 (= 0.4^2) converts the squared error from voxel^2 to cm^2,
243 |             # since the voxel size is 0.4 cm.
244 | print('motion_full (MSE in cm) = ', np.mean([np.mean(mse_dict[i]) / data_num for i in range(args.seq_len)]))
245 |
246 |
247 | if __name__ == '__main__':
248 | main()
--------------------------------------------------------------------------------
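
The permutation-matched IoU used by `test.py` can be exercised on toy tensors without a trained model. A minimal sketch, assuming it is run from the repository root (so `test.py` is the module being imported); the volume size and object count are illustrative:

```python
import torch
from test import calc_iou_unordered

B, K, S = 1, 3, 4                                 # 2 object slots + 1 trailing "empty" channel
logit = torch.zeros(B, K, S, S, S)
logit[:, -1] = 1.0                                # by default the empty channel wins everywhere
logit[0, 0, :2] = 5.0                             # predicted slot 0 covers the first half of the volume
logit[0, 1, 2:] = 5.0                             # predicted slot 1 covers the second half

mask_gt = torch.zeros(B, S, S, S, dtype=torch.long)
mask_gt[0, :2] = 2                                # ground-truth object 2 occupies the first half
mask_gt[0, 2:] = 1                                # ground-truth object 1 occupies the second half

# The best slot permutation pairs GT object 2 with slot 0 and GT object 1 with slot 1 -> IoU 1.0
print(calc_iou_unordered(logit, mask_gt))         # [1.]
```
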
/train.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import torch
3 | from torch import nn
4 | import torch.multiprocessing as mp
5 | import torch.distributed as dist
6 | from torch.utils.data import DataLoader
7 | from torch.utils.tensorboard import SummaryWriter
8 | import argparse
9 | import itertools
10 | import shutil
11 | from tqdm import tqdm
12 | from data import Data
13 | from utils import mkdir, flow2im, html_visualize, mask_visualization, tsdf_visualization
14 | from model import ModelDSR
15 |
16 | parser = argparse.ArgumentParser()
17 |
18 | # exp args
19 | parser.add_argument('--exp', type=str, help='name of exp')
20 | parser.add_argument('--gpus', type=int, nargs='+', help='list of gpus to be used, separated by space')
21 | parser.add_argument('--resume', default=None, type=str, help='path to model or exp, None means training from scratch')
22 |
23 | # data args
24 | parser.add_argument('--data_path', type=str, help='path to data')
25 | parser.add_argument('--object_num', type=int, default=5, help='number of objects')
26 | parser.add_argument('--seq_len', type=int, default=10, help='sequence length for training')
27 | parser.add_argument('--batch', type=int, default=12, help='batch size per gpu')
28 | parser.add_argument('--workers', type=int, default=4, help='number of workers per gpu')
29 |
30 | parser.add_argument('--model_type', type=str, default='dsr', choices=['dsr', 'single', 'nowarp', 'gtwarp', '3dflow'])
31 | parser.add_argument('--transform_type', type=str, default='se3euler', choices=['affine', 'se3euler', 'se3aa', 'se3spquat', 'se3quat'])
32 |
33 | # loss args
34 | parser.add_argument('--alpha_motion', type=float, default=1.0, help='weight of motion loss (MSE)')
35 | parser.add_argument('--alpha_mask', type=float, default=5.0, help='weight of mask loss (BCE)')
36 |
37 | # training args
38 | parser.add_argument('--snapshot_freq', type=int, default=1, help='snapshot frequency')
39 | parser.add_argument('--epoch', type=int, default=30, help='number of training epochs')
40 | parser.add_argument('--finetune', dest='finetune', action='store_true',
41 | help='finetuning or training from scratch ==> different learning rate strategies')
42 |
43 | # distributed training args
44 | parser.add_argument('--seed', type=int, default=23333, help='random seed')
45 | parser.add_argument('--dist_backend', type=str, default='nccl', help='distributed training backend')
46 | parser.add_argument('--dist_url', type=str, default='tcp://127.0.0.1:2333', help='distributed training url')
47 |
48 |
49 | def main():
50 | args = parser.parse_args()
51 |
52 | # loss types & loss_idx
53 | loss_types = ['all', 'motion', 'mask']
54 | loss_idx = {}
55 | for i, loss_type in enumerate(loss_types):
56 | loss_idx[loss_type] = i
57 | print('==> loss types: ', loss_types)
58 | args.loss_types = loss_types
59 | args.loss_idx = loss_idx
60 |
61 | # check sequence length
62 | if args.model_type == 'single':
63 | assert(args.seq_len == 1)
64 |
65 | # resume
66 | if args.resume is not None and not args.resume.endswith('.pth'):
67 | args.resume = osp.join('exp', args.resume, 'models/latest.pth')
68 |
69 | # dir & args
70 | exp_dir = osp.join('exp', args.exp)
71 | mkdir(exp_dir)
72 |
73 | print('==> arguments parsed')
74 | str_list = []
75 | for key in vars(args):
76 | print('[{0}] = {1}'.format(key, getattr(args, key)))
77 | str_list.append('--{0}={1} \\'.format(key, getattr(args, key)))
78 |
79 | args.model_dir = osp.join(exp_dir, 'models')
80 | mkdir(args.model_dir)
81 | args.visualization_dir = osp.join(exp_dir, 'visualization')
82 | mkdir(args.visualization_dir)
83 |
84 | mp.spawn(main_worker, nprocs=len(args.gpus), args=(len(args.gpus), args))
85 |
86 |
87 | def main_worker(rank, world_size, args):
88 | args.gpu = args.gpus[rank]
89 | if rank == 0:
90 | writer = SummaryWriter(osp.join('exp', args.exp))
91 | print(f'==> Rank={rank}, Use GPU: {args.gpu} for training.')
92 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=world_size, rank=rank)
93 |
94 | torch.cuda.set_device(args.gpu)
95 |
96 | model = ModelDSR(
97 | object_num=args.object_num,
98 | transform_type=args.transform_type,
99 | motion_type='se3' if args.model_type != '3dflow' else 'conv',
100 | )
101 |
102 | model.cuda()
103 | optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.95))
104 |
105 | if args.resume is not None:
106 | checkpoint = torch.load(args.resume, map_location=torch.device(f'cuda:{args.gpu}'))
107 | model.load_state_dict(checkpoint['state_dict'])
108 | print(f'==> rank={rank}, loaded checkpoint {args.resume}')
109 |
110 | data, samplers, loaders = {}, {}, {}
111 | for split in ['train', 'test']:
112 | data[split] = Data(data_path=args.data_path, split=split, seq_len=args.seq_len)
113 | samplers[split] = torch.utils.data.distributed.DistributedSampler(data[split])
114 | loaders[split] = DataLoader(
115 | dataset=data[split],
116 | batch_size=args.batch,
117 | num_workers=args.workers,
118 | sampler=samplers[split],
119 | pin_memory=False
120 | )
121 | print('==> dataset loaded: [size] = {0} + {1}'.format(len(data['train']), len(data['test'])))
122 |
123 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
124 |
125 | for epoch in range(args.epoch):
126 | samplers['train'].set_epoch(epoch)
127 | lr = adjust_learning_rate(optimizer, epoch, args)
128 | if rank == 0:
129 | print(f'==> epoch = {epoch}, lr = {lr}')
130 |
131 | with torch.enable_grad():
132 | loss_tensor_train = iterate(loaders['train'], model, optimizer, rank, args)
133 | with torch.no_grad():
134 | loss_tensor_test = iterate(loaders['test'], model, None, rank, args)
135 |
136 | # tensorboard log
137 | loss_tensor = torch.stack([loss_tensor_train, loss_tensor_test]).cuda()
138 | torch.distributed.all_reduce(loss_tensor)
139 | if rank == 0:
140 | training_step = (epoch + 1) * len(data['train'])
141 | loss_tensor = loss_tensor.cpu().numpy()
142 | for i, split in enumerate(['train', 'test']):
143 | for j, loss_type in enumerate(args.loss_types):
144 | for step_id in range(args.seq_len):
145 | writer.add_scalar(
146 | '%s-loss_%s/%d' % (split, loss_type, step_id),
147 | loss_tensor[i, j, step_id] / len(data[split]), epoch+1)
148 | writer.add_scalar('learning_rate', lr, epoch + 1)
149 |
150 | if rank == 0 and (epoch + 1) % args.snapshot_freq == 0:
151 | visualize(loaders, model, epoch, args)
152 | save_state = {
153 | 'state_dict': model.module.state_dict(),
154 | }
155 | torch.save(save_state, osp.join(args.model_dir, 'latest.pth'))
156 | shutil.copyfile(
157 | osp.join(args.model_dir, 'latest.pth'),
158 | osp.join(args.model_dir, 'epoch_%d.pth' % (epoch + 1))
159 | )
160 |
161 | def adjust_learning_rate(optimizer, epoch, args):
162 | if args.finetune:
163 | if epoch < 5:
164 | lr = 5e-4
165 | elif epoch < 10:
166 | lr = 2e-4
167 | elif epoch < 15:
168 | lr = 5e-5
169 | else:
170 | lr = 1e-5
171 |
172 | else:
173 | if epoch < 2:
174 | lr = 1e-5
175 | elif epoch < 5:
176 | lr = 1e-3
177 | elif epoch < 10:
178 | lr = 5e-4
179 | elif epoch < 20:
180 | lr = 2e-4
181 | elif epoch < 25:
182 | lr = 5e-5
183 | else:
184 | lr = 1e-5
185 |
186 | for param_group in optimizer.param_groups:
187 | param_group['lr'] = lr
188 | return lr
189 |
190 |
191 | def iterate(loader, model, optimizer, rank, args):
192 | motion_metric = nn.MSELoss()
193 | loss_tensor = torch.zeros([len(args.loss_types), args.seq_len])
194 | if rank == 0:
195 | loader = tqdm(loader, desc='test' if optimizer is None else 'train')
196 | for batch in loader:
197 | batch_size = batch['0-action'].size(0)
198 | last_s = model.module.get_init_repr(batch_size).cuda()
199 | batch_order = None
200 |
201 | for step_id in range(args.seq_len):
202 | output = model(
203 | input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1),
204 | last_s=last_s,
205 | input_action=batch['%d-action' % step_id].cuda(),
206 | input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None,
207 | no_warp=args.model_type=='nowarp'
208 | )
209 | last_s = output['s'].data
210 | loss = 0
211 |
212 | if 'motion' in args.loss_types:
213 | loss_motion = motion_metric(
214 | output['motion'],
215 | batch['%d-scene_flow_3d' % step_id].cuda()
216 | )
217 | loss_tensor[args.loss_idx['motion'], step_id] += loss_motion.item() * batch_size
218 | loss += args.alpha_motion * loss_motion
219 |
220 | if 'mask' in args.loss_types:
221 | mask_gt = batch['%d-mask_3d' % step_id].cuda()
222 | if batch_order is None:
223 | batch_order = get_batch_order(output['init_logit'], mask_gt)
224 | loss_mask = get_mask_loss(output['init_logit'], mask_gt, batch_order)
225 | loss_tensor[args.loss_idx['mask'], step_id] += loss_mask.item() * batch_size
226 | loss += args.alpha_mask * loss_mask
227 |
228 | loss_tensor[args.loss_idx['all'], step_id] += loss.item() * batch_size
229 |
230 | if optimizer is not None:
231 | optimizer.zero_grad()
232 | loss.backward()
233 | optimizer.step()
234 |
235 | if step_id != args.seq_len - 1:
236 | batch_order = get_batch_order(output['init_logit'], mask_gt)
237 | return loss_tensor
238 |
239 |
240 | def get_batch_order(logit_pred, mask_gt):
241 | batch_order = []
242 | B, K, S1, S2, S3 = logit_pred.size()
243 | sum = 0
244 | for b in range(B):
245 | all_p = list(itertools.permutations(list(range(K - 1))))
246 | best_loss, best_p = None, None
247 | for p in all_p:
248 | permute_pred = torch.stack(
249 | [logit_pred[b:b + 1, -1]] + [logit_pred[b:b + 1, i] for i in p],
250 | dim=1).contiguous()
251 | cur_loss = nn.CrossEntropyLoss()(permute_pred, mask_gt[b:b + 1]).item()
252 | if best_loss is None or cur_loss < best_loss:
253 | best_loss = cur_loss
254 | best_p = p
255 | batch_order.append(best_p)
256 | sum += best_loss
257 | return batch_order
258 |
259 |
260 | def get_mask_loss(logit_pred, mask_gt, batch_order):
261 | loss = 0
262 | B, K, S1, S2, S3 = logit_pred.size()
263 | for b in range(B):
264 | permute_pred = torch.stack(
265 | [logit_pred[b:b + 1, -1]] + [logit_pred[b:b + 1, i] for i in batch_order[b]],
266 | dim=1).contiguous()
267 | loss += nn.CrossEntropyLoss()(permute_pred, mask_gt[b:b + 1])
268 | return loss
269 |
270 |
271 | def visualize(loaders, model, epoch, args):
272 | visualization_path = osp.join(args.visualization_dir, 'epoch_%03d' % (epoch + 1))
273 | figures = {}
274 | ids = [split + '_' + str(itr) + '-' + str(step_id)
275 | for split in ['train', 'test']
276 | for itr in range(args.batch)
277 | for step_id in range(args.seq_len)]
278 | cols = ['color_image', 'color_heightmap', 'motion_gt', 'motion_pred', 'mask_gt']
279 | if args.model_type != '3dflow':
280 | cols = cols + ['mask_pred', 'next_mask_pred']
281 |
282 | with torch.no_grad():
283 | for split in ['train', 'test']:
284 | model.train()
285 |             batch = next(iter(loaders[split]))
286 | batch_size = batch['0-action'].size(0)
287 | last_s = model.module.get_init_repr(batch_size).cuda()
288 | for step_id in range(args.seq_len):
289 | output = model(
290 | input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1),
291 | last_s=last_s,
292 | input_action=batch['%d-action' % step_id].cuda(),
293 | input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None,
294 | no_warp=args.model_type=='nowarp',
295 | next_mask=True
296 | )
297 | last_s = output['s'].data
298 |
299 | vis_color_image = batch['%d-color_image' % step_id].numpy()
300 | vis_color_heightmap = batch['%d-color_heightmap' % step_id].numpy()
301 | motion_gt = torch.sum(batch['%d-scene_flow_3d' % step_id][:, :2, ...], dim=4).numpy()
302 | motion_pred = torch.sum(output['motion'][:, :2, ...], dim=4).cpu().numpy()
303 |
304 | vis_mask_gt = mask_visualization(batch['%d-mask_3d' % step_id].numpy())
305 |
306 | if args.model_type != '3dflow':
307 | vis_mask_pred = mask_visualization(output['init_logit'].cpu().numpy())
308 | vis_next_mask_pred = mask_visualization(output['next_mask'].cpu().numpy())
309 |
310 | for k in range(args.batch):
311 | figures['%s_%d-%d_color_image' % (split, k, step_id)] = vis_color_image[k]
312 | figures['%s_%d-%d_color_heightmap' % (split, k, step_id)] = vis_color_heightmap[k]
313 | figures['%s_%d-%d_motion_gt' % (split, k, step_id)] = flow2im(motion_gt[k])
314 | figures['%s_%d-%d_motion_pred' % (split, k, step_id)] = flow2im(motion_pred[k])
315 | figures['%s_%d-%d_mask_gt' % (split, k, step_id)] = vis_mask_gt[k]
316 | if args.model_type != '3dflow':
317 | figures['%s_%d-%d_mask_pred' % (split, k, step_id)] = vis_mask_pred[k]
318 | figures['%s_%d-%d_next_mask_pred' % (split, k, step_id)] = vis_next_mask_pred[k]
319 |
320 | html_visualize(visualization_path, figures, ids, cols, title=args.exp)
321 |
322 |
323 | if __name__ == '__main__':
324 | main()
--------------------------------------------------------------------------------
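
`train.py` resolves the object-slot ambiguity with a per-sample permutation search before computing the mask loss. A small sketch of that matching step on random tensors, assuming the repository root is on the path; the shapes are illustrative:

```python
import torch
from train import get_batch_order, get_mask_loss

B, K, S = 2, 4, 8                                  # 3 object slots + 1 trailing "empty" channel
logit = torch.randn(B, K, S, S, S)                 # stand-in for output['init_logit']
mask_gt = torch.randint(0, K, (B, S, S, S))        # 0 = empty, 1..K-1 = object ids

order = get_batch_order(logit, mask_gt)            # best slot permutation per sample
loss = get_mask_loss(logit, mask_gt, order)        # permutation-aligned cross entropy
print(order, loss.item())
```
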
/utils.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import collections
3 | import math
4 | import os
5 | import shutil
6 |
7 | import cv2
8 | import imageio
9 | import numpy as np
10 | import dominate
11 | from dominate.tags import *
12 | import queue
13 | import threading
14 |
15 |
16 | # Get rotation matrix from euler angles
17 | def euler2rotm(theta):
18 | R_x = np.array([[1, 0, 0],
19 | [0, math.cos(theta[0]), -math.sin(theta[0])],
20 | [0, math.sin(theta[0]), math.cos(theta[0])]
21 | ])
22 | R_y = np.array([[math.cos(theta[1]), 0, math.sin(theta[1])],
23 | [0, 1, 0],
24 | [-math.sin(theta[1]), 0, math.cos(theta[1])]
25 | ])
26 | R_z = np.array([[math.cos(theta[2]), -math.sin(theta[2]), 0],
27 | [math.sin(theta[2]), math.cos(theta[2]), 0],
28 | [0, 0, 1]
29 | ])
30 | R = np.dot(R_z, np.dot(R_y, R_x))
31 | return R
32 |
33 |
34 | def transform_points(pts, transform):
35 | # pts = [3xN] array
36 | # transform: [3x4]
37 | pts_t = np.dot(transform[0:3, 0:3], pts) + np.tile(transform[0:3, 3:], (1, pts.shape[1]))
38 | return pts_t
39 |
40 |
41 | def project_pts_to_2d(pts, camera_view_matrix, camera_intrisic):
42 |     # transform points from world coordinates to the virtual camera frame
43 |     # camera_intrisic: 3x3 intrinsics of the virtual camera, [[f, 0, cx], [0, f, cy], [0, 0, 1]] with focal length f
44 | # RT_wrd2cam
45 | pts_c = transform_points(pts, camera_view_matrix[0:3, :])
46 | rot_algix = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0]])
47 | pts_c = transform_points(pts_c, rot_algix)
48 | coord_2d = np.dot(camera_intrisic, pts_c)
49 | coord_2d[0:2, :] = coord_2d[0:2, :] / np.tile(coord_2d[2, :], (2, 1))
50 | coord_2d[2, :] = pts_c[2, :]
51 | return coord_2d
52 |
53 |
54 | def project_pts_to_3d(color_image, depth_image, camera_intr, camera_pose):
55 | W, H = depth_image.shape
56 | cam_pts, rgb_pts = get_pointcloud(color_image, depth_image, camera_intr)
57 | world_pts = np.transpose(
58 | np.dot(camera_pose[0:3, 0:3], np.transpose(cam_pts)) + np.tile(camera_pose[0:3, 3:], (1, cam_pts.shape[0])))
59 |
60 | pts = world_pts.reshape([W, H, 3])
61 | pts = np.transpose(pts, [2, 0, 1])
62 |
63 | return pts
64 |
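# Note on conventions: depth_image.shape is read as (rows, cols) and the result is
# channel-first, i.e. a [3, rows, cols] array of world-frame XYZ coordinates.
#   pts_xyz = project_pts_to_3d(color, depth, intrinsics, cam_pose)  # placeholder inputs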
65 |
66 | def get_pointcloud(color_img, depth_img, camera_intrinsics):
67 | # Get depth image size
68 | im_h = depth_img.shape[0]
69 | im_w = depth_img.shape[1]
70 |
71 | # Project depth into 3D point cloud in camera coordinates
72 | pix_x, pix_y = np.meshgrid(np.linspace(0, im_w - 1, im_w), np.linspace(0, im_h - 1, im_h))
73 | cam_pts_x = np.multiply(pix_x - camera_intrinsics[0, 2], depth_img / camera_intrinsics[0, 0])
74 | cam_pts_y = np.multiply(pix_y - camera_intrinsics[1, 2], depth_img / camera_intrinsics[1, 1])
75 | cam_pts_z = depth_img.copy()
76 | cam_pts_x.shape = (im_h * im_w, 1)
77 | cam_pts_y.shape = (im_h * im_w, 1)
78 | cam_pts_z.shape = (im_h * im_w, 1)
79 |
80 | # Reshape image into colors for 3D point cloud
81 | rgb_pts_r = color_img[:, :, 0]
82 | rgb_pts_g = color_img[:, :, 1]
83 | rgb_pts_b = color_img[:, :, 2]
84 | rgb_pts_r.shape = (im_h * im_w, 1)
85 | rgb_pts_g.shape = (im_h * im_w, 1)
86 | rgb_pts_b.shape = (im_h * im_w, 1)
87 |
88 | cam_pts = np.concatenate((cam_pts_x, cam_pts_y, cam_pts_z), axis=1)
89 | rgb_pts = np.concatenate((rgb_pts_r, rgb_pts_g, rgb_pts_b), axis=1)
90 |
91 | return cam_pts, rgb_pts
92 |
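# The back-projection above follows the pinhole model: x = (u - cx) * z / fx, y = (v - cy) * z / fy.
# Example (placeholder inputs): both returned arrays are flat, one row per pixel.
#   cam_pts, rgb_pts = get_pointcloud(color, depth, intrinsics)  # each [H*W, 3]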
93 |
94 | def get_heightmap(color_img, depth_img, cam_intrinsics, cam_pose, workspace_limits, heightmap_resolution):
95 | # Compute heightmap size
96 | heightmap_size = np.round(((workspace_limits[1][1] - workspace_limits[1][0]) / heightmap_resolution,
97 | (workspace_limits[0][1] - workspace_limits[0][0]) / heightmap_resolution)).astype(int)
98 |
99 | # Get 3D point cloud from RGB-D images
100 | surface_pts, color_pts = get_pointcloud(color_img, depth_img, cam_intrinsics)
101 |
102 | # Transform 3D point cloud from camera coordinates to robot coordinates
103 | surface_pts = np.transpose(
104 | np.dot(cam_pose[0:3, 0:3], np.transpose(surface_pts)) + np.tile(cam_pose[0:3, 3:], (1, surface_pts.shape[0])))
105 |
106 | # Sort surface points by z value
107 | sort_z_ind = np.argsort(surface_pts[:, 2])
108 | surface_pts = surface_pts[sort_z_ind]
109 | color_pts = color_pts[sort_z_ind]
110 |
111 | # Filter out surface points outside heightmap boundaries
112 | heightmap_valid_ind = np.logical_and(np.logical_and(np.logical_and(
113 | np.logical_and(surface_pts[:, 0] >= workspace_limits[0][0], surface_pts[:, 0] < workspace_limits[0][1]),
114 | surface_pts[:, 1] >= workspace_limits[1][0]), surface_pts[:, 1] < workspace_limits[1][1]),
115 | surface_pts[:, 2] < workspace_limits[2][1])
116 | surface_pts = surface_pts[heightmap_valid_ind]
117 | color_pts = color_pts[heightmap_valid_ind]
118 |
119 | # Create orthographic top-down-view RGB-D heightmaps
120 | color_heightmap_r = np.zeros((heightmap_size[0], heightmap_size[1], 1), dtype=np.uint8)
121 | color_heightmap_g = np.zeros((heightmap_size[0], heightmap_size[1], 1), dtype=np.uint8)
122 | color_heightmap_b = np.zeros((heightmap_size[0], heightmap_size[1], 1), dtype=np.uint8)
123 | depth_heightmap = np.zeros(heightmap_size)
124 | heightmap_pix_x = np.floor((surface_pts[:, 0] - workspace_limits[0][0]) / heightmap_resolution).astype(int)
125 | heightmap_pix_y = np.floor((surface_pts[:, 1] - workspace_limits[1][0]) / heightmap_resolution).astype(int)
126 | color_heightmap_r[heightmap_pix_y, heightmap_pix_x] = color_pts[:, [0]]
127 | color_heightmap_g[heightmap_pix_y, heightmap_pix_x] = color_pts[:, [1]]
128 | color_heightmap_b[heightmap_pix_y, heightmap_pix_x] = color_pts[:, [2]]
129 | color_heightmap = np.concatenate((color_heightmap_r, color_heightmap_g, color_heightmap_b), axis=2)
130 | depth_heightmap[heightmap_pix_y, heightmap_pix_x] = surface_pts[:, 2]
131 | z_bottom = workspace_limits[2][0]
132 | depth_heightmap = depth_heightmap - z_bottom
133 | depth_heightmap[depth_heightmap < 0] = 0
134 | # depth_heightmap[depth_heightmap == -z_bottom] = np.nan
135 |
136 | return color_heightmap, depth_heightmap
137 |
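# Usage sketch (illustrative numbers, not the project's actual configuration):
# workspace_limits is a [3x2] array of (min, max) per axis in robot coordinates and
# heightmap_resolution is the metric size of one heightmap pixel.
#   limits = np.array([[-0.3, 0.3], [-0.3, 0.3], [0.0, 0.4]])
#   color_hm, depth_hm = get_heightmap(color, depth, intrinsics, cam_pose, limits, 0.004)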
138 |
139 | def mkdir(path, clean=False):
140 | if clean and os.path.exists(path):
141 | shutil.rmtree(path)
142 | if not os.path.exists(path):
143 | os.makedirs(path)
144 |
145 |
146 | def imresize(im, dsize, cfirst=False):
147 | if cfirst:
148 | im = im.transpose(1, 2, 0)
149 | im = cv2.resize(im, dsize=dsize)
150 | if cfirst:
151 | im = im.transpose(2, 0, 1)
152 | return im
153 |
154 |
155 | def imretype(im, dtype):
156 | im = np.array(im)
157 |
158 | if im.dtype in ['float', 'float16', 'float32', 'float64']:
159 | im = im.astype(np.float64)
160 | elif im.dtype == 'uint8':
161 | im = im.astype(np.float64) / 255.
162 | elif im.dtype == 'uint16':
163 | im = im.astype(np.float64) / 65535.
164 | else:
165 | raise NotImplementedError('unsupported source dtype: {0}'.format(im.dtype))
166 |
167 | assert np.min(im) >= 0 and np.max(im) <= 1
168 |
169 | if dtype in ['float', 'float16', 'float32', 'float64']:
170 | im = im.astype(dtype)
171 | elif dtype == 'uint8':
172 | im = (im * 255.).astype(dtype)
173 | elif dtype == 'uint16':
174 | im = (im * 65535.).astype(dtype)
175 | else:
176 | raise NotImplementedError('unsupported target dtype: {0}'.format(dtype))
177 |
178 | return im
179 |
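# Example: imretype normalizes the input to [0, 1] and rescales it to the target dtype.
#   imretype(np.full((2, 2), 0.5, dtype=np.float32), 'uint8')  # -> 2x2 array of 127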
180 |
181 | def imwrite(path, obj):
182 | if not isinstance(obj, (collections.abc.Sequence, collections.UserList)):
183 | obj = [obj]
184 | writer = imageio.get_writer(path)
185 | for im in obj:
186 | im = imretype(im, dtype='uint8').squeeze()
187 | if len(im.shape) == 3 and im.shape[0] == 3:
188 | im = np.transpose(im, (1, 2, 0))
189 | writer.append_data(im)
190 | writer.close()
191 |
192 |
193 | def flow2im(flow, max=None, dtype='float32', cfirst=False):
194 | flow = np.array(flow)
195 |
196 | if np.ndim(flow) == 3 and flow.shape[0] == 2:
197 | x, y = flow[:, ...]
198 | elif np.ndim(flow) == 3 and flow.shape[-1] == 2:
199 | x = flow[..., 0]
200 | y = flow[..., 1]
201 | else:
202 | raise NotImplementedError(
203 | 'unsupported flow size: {0}'.format(flow.shape))
204 |
205 | rho, theta = cv2.cartToPolar(x, y)
206 |
207 | if max is None:
208 | max = np.maximum(np.max(rho), 1e-6)
209 |
210 | hsv = np.zeros(list(rho.shape) + [3], dtype=np.uint8)
211 | hsv[..., 0] = theta * 90 / np.pi
212 | hsv[..., 1] = 255
213 | hsv[..., 2] = np.minimum(rho / max, 1) * 255
214 |
215 | im = cv2.cvtColor(hsv, code=cv2.COLOR_HSV2RGB)
216 | im = imretype(im, dtype=dtype)
217 |
218 | if cfirst:
219 | im = im.transpose(2, 0, 1)
220 | return im
221 |
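# Flow is rendered in HSV: hue encodes direction, value encodes magnitude (saturation fixed at 255).
# Example (illustrative): a constant flow of (1, 0) produces a uniformly colored [8, 8, 3] image.
#   flow = np.zeros((2, 8, 8)); flow[0] = 1.0
#   im = flow2im(flow)  # float32 RGB in [0, 1]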
222 |
223 | def draw_arrow(image, action, direction_num=8, heightmap_pixel_size=0.004):
224 | # image: [W, H, 3] (color image) or [W, H] (depth image)
225 | def put_in_bound(val, bound):
226 | # output: 0 <= val < bound
227 | val = min(max(0, val), bound - 1)
228 | return val
229 |
230 | img = image.copy()
231 | if isinstance(action, tuple):
232 | x_ini, y_ini, direction = action
233 | else:
234 | x_ini, y_ini, direction = action['2'], action['1'], action['0']
235 |
236 | pushing_distance = 0.15
237 |
238 | angle = direction / direction_num * 2 * np.pi
239 | x_end = put_in_bound(int(x_ini + pushing_distance / heightmap_pixel_size * np.cos(angle)), image.shape[1])
240 | y_end = put_in_bound(int(y_ini + pushing_distance / heightmap_pixel_size * np.sin(angle)), image.shape[0])
241 |
242 | if len(img.shape) == 2:
243 | # gray img ([W, H]), white arrow
244 | img = imretype(img[:, :, np.newaxis], 'uint8')
245 | cv2.arrowedLine(img=img, pt1=(x_ini, y_ini), pt2=(x_end, y_end), color=255, thickness=2, tipLength=0.2)
246 | elif img.shape[2] == 3:
247 | # rgb img, red arrow
248 | cv2.arrowedLine(img=img, pt1=(x_ini, y_ini), pt2=(x_end, y_end), color=(255, 0, 0), thickness=2, tipLength=0.2)
249 | return img
250 |
251 |
252 | def multithreading_exec(num, q, fun, blocking=True):
253 | """
254 | Multi-threading Execution
255 |
256 | :param num: number of threads
257 | :param q: queue of args
258 | :param fun: function to be executed
259 | :param blocking: blocking or not (default True)
260 | """
261 |
262 | class Worker(threading.Thread):
263 | def __init__(self, q, fun):
264 | super().__init__()
265 | self.q = q
266 | self.fun = fun
267 | self.start()
268 |
269 | def run(self):
270 | while True:
271 | try:
272 | args = self.q.get(block=False)
273 | self.fun(*args)
274 | self.q.task_done()
275 | except queue.Empty:
276 | break
277 |
278 | thread_list = [Worker(q, fun) for i in range(num)]
279 | if blocking:
280 | for t in thread_list:
281 | if t.is_alive():
282 | t.join()
283 |
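# Usage sketch: fill a queue with argument tuples, then fan them out over worker threads.
#   q = queue.Queue()
#   for path, im in jobs:  # 'jobs' is a placeholder list of (path, image) pairs
#       q.put((path, im))
#   multithreading_exec(10, q, imwrite)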
284 |
285 | def html_visualize(web_path, data, ids, cols, others=[], title='visualization', threading_num=10):
286 | """
287 | :param web_path: (str) directory to save webpage. It will clear the old data!
288 | :param data: (dict of data).
289 | key: {id}_{col}.
290 | value: figure or text
291 | - figure: ndarray --> .png or [ndarrays,] --> .gif
292 | - text: str or [str,]
293 | :param ids: (list of str) name of each row
294 | :param cols: (list of str) name of each column
295 | :param others: (list of dict) other figures
296 | 'name': str, name of the data, visualize using h2()
297 | 'data': string or ndarray(image)
298 | 'height': int, height of the image (default 256)
299 | :param title: (str) title of the webpage
300 | :param threading_num: number of threads for imwrite (default 10)
301 | """
302 | figure_path = os.path.join(web_path, 'figures')
303 | mkdir(web_path, clean=True)
304 | mkdir(figure_path, clean=True)
305 | q = queue.Queue()
306 | for key, value in data.items():
307 | if isinstance(value, np.ndarray):
308 | q.put((os.path.join(figure_path, key + '.png'), value))
309 | elif isinstance(value, list) and isinstance(value[0], np.ndarray):
310 | q.put((os.path.join(figure_path, key + '.gif'), value))
311 | multithreading_exec(threading_num, q, imwrite)
312 |
313 | with dominate.document(title=title) as web:
314 | dominate.tags.h1(title)
315 | with dominate.tags.table(border=1, style='table-layout: fixed;'):
316 | with dominate.tags.tr():
317 | with dominate.tags.td(style='word-wrap: break-word;', halign='center', align='center', width='64px'):
318 | dominate.tags.p('id')
319 | for col in cols:
320 | with dominate.tags.td(style='word-wrap: break-word;', halign='center', align='center', ):
321 | dominate.tags.p(col)
322 | for id in ids:
323 | with dominate.tags.tr():
324 | bgcolor = 'F1C073' if id.startswith('train') else 'C5F173'
325 | with dominate.tags.td(style='word-wrap: break-word;', halign='center', align='center',
326 | bgcolor=bgcolor):
327 | for part in id.split('_'):
328 | dominate.tags.p(part)
329 | for col in cols:
330 | with dominate.tags.td(style='word-wrap: break-word;', halign='center', align='top'):
331 | value = data.get(f'{id}_{col}', None)
332 | if isinstance(value, str):
333 | dominate.tags.p(value)
334 | elif isinstance(value, list) and isinstance(value[0], str):
335 | for v in value:
336 | dominate.tags.p(v)
337 | else:
338 | dominate.tags.img(style='height:128px',
339 | src=os.path.join('figures', '{}_{}.png'.format(id, col)))
340 | for idx, other in enumerate(others):
341 | dominate.tags.h2(other['name'])
342 | if isinstance(other['data'], str):
343 | dominate.tags.p(other['data'])
344 | else:
345 | imwrite(os.path.join(figure_path, '_{}_{}.png'.format(idx, other['name'])), other['data'])
346 | dominate.tags.img(style='height:{}px'.format(other.get('height', 256)),
347 | src=os.path.join('figures', '_{}_{}.png'.format(idx, other['name'])))
348 | with open(os.path.join(web_path, 'index.html'), 'w') as fp:
349 | fp.write(web.render())
350 |
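# Usage sketch (placeholder images): one row 'train_0' with columns 'color' and 'mask';
# figures are written to <web_path>/figures and referenced from <web_path>/index.html.
#   data = {'train_0_color': color_im, 'train_0_mask': mask_im}
#   html_visualize('vis_out', data, ids=['train_0'], cols=['color', 'mask'], title='demo')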
351 |
352 | def mask_visualization(mask):
353 | # mask: numpy array, [B, K, W, H, D] or [B, W, H, D]
354 | # Red, Green, Blue, Yellow, Purple
355 | colors = [(255, 87, 89), (89, 169, 79), (78, 121, 167), (237, 201, 72), (176, 122, 161)]
356 | if len(mask.shape) == 5:
357 | B, K, W, H, D = mask.shape
358 | argmax_mask = np.argmax(mask, axis=1)
359 | else:
360 | B, W, H, D = mask.shape
361 | K = max(np.max(mask) + 1, 2)
362 | argmax_mask = mask - 1
363 |
364 | mask_list = []
365 | for k in range(K - 1):
366 | mask_list.append((argmax_mask == k).astype(np.float32))
367 |
368 | mask = np.sum(np.stack(mask_list, axis=1), axis=4)
369 | sum_mask = np.sum(mask, 1) + 1  # [B, W, H]
370 | color_mask = np.zeros([B, W, H, 3])
371 | for i in range(K - 1):
372 | color_mask += mask[:, i, ..., np.newaxis] * np.array(colors[i])
373 | return np.clip(color_mask / sum_mask[..., np.newaxis] / 255.0, 0, 1)
374 |
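# Example (placeholder input): given per-voxel object logits of shape [B, K, W, H, D],
# the output is a [B, W, H, 3] top-down color rendering in [0, 1], with one color per
# object channel (up to five).
#   vis = mask_visualization(logits)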
375 |
376 | def mask_visualization_2d(mask):
377 | # Red, Green, Blue, Yellow, Purple
378 | colors = [(255, 87, 89), (89, 169, 79), (78, 121, 167), (237, 201, 72), (176, 122, 161)]
379 | if len(mask.shape) == 4:
380 | B, K, W, H = mask.shape
381 | argmax_mask = np.argmax(mask, axis=1)
382 | else:
383 | B, W, H = mask.shape
384 | K = max(np.max(mask) + 1, 2)
385 | argmax_mask = mask
386 | mask_list = []
387 | for k in range(K):
388 | mask_list.append((argmax_mask == k).astype(np.float32))
389 |
390 | mask = np.stack(mask_list, axis=1)
391 | sum_mask = np.sum(mask, 1) + 1  # [B, W, H]
392 | color_mask = np.zeros([B, W, H, 3])
393 | for i in range(K):
394 | color_mask += mask[:, i, ..., np.newaxis] * np.array(colors[i])
395 | return np.clip(color_mask / sum_mask[..., np.newaxis] / 255.0, 0, 1)
396 |
397 |
398 | def volume_visualization(volume):
399 | # volume: numpy array: [B, W, H, D]
400 | tmp = np.sum(volume, axis=-1)
401 | tmp -= np.min(tmp)
402 | tmp /= max(np.max(tmp), 1)
403 | return np.clip(tmp, 0, 1)
404 |
405 |
406 | def tsdf_visualization(tsdf):
407 | # tsdf: numpy array: [B, W, H, D]
408 | return volume_visualization((tsdf < 0).astype(np.float32))
409 |
--------------------------------------------------------------------------------