├── LICENSE ├── README.md ├── assets ├── fence │ ├── obj.mtl │ ├── tinker.obj │ └── tinker.urdf ├── object_id │ ├── shapenet_id.json │ └── ycb_id.json ├── stick │ └── stick.urdf └── ur5 │ ├── collision │ ├── base.stl │ ├── forearm.stl │ ├── shoulder.stl │ ├── upperarm.stl │ ├── wrist1.stl │ ├── wrist2.stl │ └── wrist3.stl │ ├── mount.urdf │ ├── notes.txt │ ├── ur5.urdf │ └── visual │ ├── base.stl │ ├── forearm.stl │ ├── shoulder.stl │ ├── upperarm.stl │ ├── wrist1.stl │ ├── wrist2.stl │ └── wrist3.stl ├── binvox_utils.py ├── data.py ├── data └── README.md ├── data_generation.py ├── figures └── teaser.jpg ├── forward_warp.py ├── fusion.py ├── model.py ├── model_utils.py ├── object_models └── README.md ├── pretrained_models ├── 3dflow.pth ├── README.md ├── dsr.pth ├── dsr_ft.pth ├── gtwarp.pth ├── nowarp.pth └── single.pth ├── requirements.txt ├── se3 ├── __pycache__ │ ├── se3_module.cpython-36.pyc │ ├── se3_utils.cpython-36.pyc │ ├── se3aa.cpython-36.pyc │ ├── se3euler.cpython-36.pyc │ ├── se3quat.cpython-36.pyc │ └── se3spquat.cpython-36.pyc ├── se3_module.py ├── se3_utils.py ├── se3aa.py ├── se3euler.py ├── se3quat.py └── se3spquat.py ├── sim.py ├── sim_env.py ├── test.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Columbia Robovision Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning 3D Dynamic Scene Representations for Robot Manipulation 2 | 3 | [Zhenjia Xu](http://www.zhenjiaxu.com/)1\*, 4 | [Zhanpeng He](https://zhanpenghe.github.io/)1\*, 5 | [Jiajun Wu](https://jiajunwu.com/)2, 6 | [Shuran Song](https://www.cs.columbia.edu/~shurans/)1 7 |
8 | 1Columbia University, 2Stanford University 9 |
10 | [CoRL 2020](https://www.robot-learning.org/) 11 | 12 | ### [Project Page](https://dsr-net.cs.columbia.edu/) | [Video](https://youtu.be/GQjYG3nQJ80) | [arXiv](https://arxiv.org/abs/2011.01968) 13 | 14 | ## Overview 15 | This repo contains the PyTorch implementation of the paper "Learning 3D Dynamic Scene Representations for Robot Manipulation". 16 | ![teaser](figures/teaser.jpg) 17 | 18 | ## Content 19 | 20 | - [Prerequisites](#prerequisites) 21 | - [Data Preparation](#data-preparation) 22 | - [Pretrained Models](#pretrained-models) 23 | - [Training](#training) 24 | 25 | ## Prerequisites 26 | 27 | The code is built with Python 3.6. Required libraries are listed in [requirements.txt](requirements.txt). 28 | 29 | ## Data Preparation 30 | 31 | ### Download Testing Data 32 | The following two testing datasets can be downloaded: 33 | - [Sim](https://dsr-net.cs.columbia.edu/download/data/sim_test_data.zip): 400 sequences, generated in PyBullet. 34 | - [Real](https://dsr-net.cs.columbia.edu/download/data/real_test_data.zip): 150 sequences, with full annotations. 35 | 36 | ### Generate Training Data 37 | Download the object meshes: [shapenet](https://dsr-net.cs.columbia.edu/download/object_models/shapenet.zip) and [ycb](https://dsr-net.cs.columbia.edu/download/object_models/ycb.zip). 38 | 39 | To generate data in simulation, run 40 | ``` 41 | python data_generation.py --data_path [path to data] --train_num [number of training sequences] --test_num [number of testing sequences] --object_type [type of objects] 42 | ``` 43 | where `object_type` can be `cube`, `shapenet`, or `ycb`. 44 | The training data used in the paper can be generated with the following scripts: 45 | ``` 46 | # cube 47 | python data_generation.py --data_path data/cube_train --train_num 4000 --test_num 400 --object_type cube 48 | 49 | # shapenet 50 | python data_generation.py --data_path data/shapenet_train --train_num 4000 --test_num 400 --object_type shapenet 51 | ``` 52 | 53 | ## Pretrained Models 54 | The pretrained models can be downloaded from [pretrained_models](pretrained_models). 55 | To evaluate a pretrained model, run 56 | ``` 57 | python test.py --resume [path to model] --data_path [path to data] --model_type [type of model] --test_type [type of test] 58 | ``` 59 | where `model_type` can be one of the following: 60 | - `dsr`: DSR-Net introduced in the paper. 61 | - `single`: It does not use any history aggregation. 62 | - `nowarp`: It does not warp the representation before aggregation. 63 | - `gtwarp`: It warps the representation with ground truth motion (i.e., a performance oracle). 64 | - `3dflow`: It predicts per-voxel scene flow for the entire 3D volume.
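As a convenience, the snippet below is a minimal sketch that loops over the released checkpoints in [pretrained_models](pretrained_models) and calls `test.py` with the matching `model_type` (using `mask_ordered` as the `test_type`; the available `test_type` values are described next). The checkpoint filenames follow the repository tree; the data path is a placeholder for wherever the test data was unzipped.
```
# Illustrative only -- adjust data_path to your local copy of the test data.
import subprocess

data_path = 'data/sim_test_data'
checkpoints = {
    'dsr': 'pretrained_models/dsr.pth',
    'single': 'pretrained_models/single.pth',
    'nowarp': 'pretrained_models/nowarp.pth',
    'gtwarp': 'pretrained_models/gtwarp.pth',
    '3dflow': 'pretrained_models/3dflow.pth',
}
for model_type, ckpt in checkpoints.items():
    subprocess.run([
        'python', 'test.py',
        '--resume', ckpt,
        '--data_path', data_path,
        '--model_type', model_type,
        '--test_type', 'mask_ordered',
    ], check=True)
```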
65 | 66 | Both motion prediction and mask prediction can be evaluated by choosing different `test_type` values: 67 | - motion prediction: `motion_visible` or `motion_full` 68 | - mask prediction: `mask_ordered` or `mask_unordered` 69 | 70 | (Please refer to our paper for a detailed explanation of each type of evaluation.) 71 | 72 | Here are several examples: 73 | ``` 74 | # evaluate mask prediction (ordered) of DSR-Net using real data: 75 | python test.py --resume [path to dsr model] --data_path [path to real data] --model_type dsr --test_type mask_ordered 76 | 77 | # evaluate mask prediction (unordered) of DSR-Net (finetuned) using real data: 78 | python test.py --resume [path to dsr_ft model] --data_path [path to real data] --model_type dsr --test_type mask_unordered 79 | 80 | # evaluate motion prediction (visible surface) of the NoWarp model using sim data: 81 | python test.py --resume [path to nowarp model] --data_path [path to sim data] --model_type nowarp --test_type motion_visible 82 | 83 | # evaluate motion prediction (full volume) of the SingleStep model using sim data: 84 | python test.py --resume [path to single model] --data_path [path to sim data] --model_type single --test_type motion_full 85 | ``` 86 | 87 | 88 | ## Training 89 | Various training options can be modified or toggled on/off with different flags (run `python train.py -h` to see all options): 90 | ``` 91 | usage: train.py [-h] [--exp EXP] [--gpus GPUS [GPUS ...]] [--resume RESUME] 92 | [--data_path DATA_PATH] [--object_num OBJECT_NUM] 93 | [--seq_len SEQ_LEN] [--batch BATCH] [--workers WORKERS] 94 | [--model_type {dsr,single,nowarp,gtwarp,3dflow}] 95 | [--transform_type {affine,se3euler,se3aa,se3spquat,se3quat}] 96 | [--alpha_motion ALPHA_MOTION] [--alpha_mask ALPHA_MASK] 97 | [--snapshot_freq SNAPSHOT_FREQ] [--epoch EPOCH] [--finetune] 98 | [--seed SEED] [--dist_backend DIST_BACKEND] 99 | [--dist_url DIST_URL] 100 | ``` 101 | ### Training of DSR-Net 102 | Since the aggregation ability depends on the accuracy of motion prediction, we split the training process into three stages from easy to hard: (1) single-step on the cube dataset; (2) multi-step on the cube dataset; (3) multi-step on the ShapeNet dataset. 103 | ``` 104 | # Stage 1 (single-step on cube dataset) 105 | python train.py --exp dsr_stage1 --data_path [path to cube dataset] --seq_len 1 --model_type dsr --epoch 30 106 | 107 | # Stage 2 (multi-step on cube dataset) 108 | python train.py --exp dsr_stage2 --resume [path to stage1] --data_path [path to cube dataset] --seq_len 10 --model_type dsr --epoch 20 --finetune 109 | 110 | # Stage 3 (multi-step on shapenet dataset) 111 | python train.py --exp dsr_stage3 --resume [path to stage2] --data_path [path to shapenet dataset] --seq_len 10 --model_type dsr --epoch 20 --finetune 112 | ``` 113 | 114 | ### Training of Baselines 115 | - `nowarp` and `gtwarp`: use the same scripts as DSR-Net with the corresponding `model_type`. 116 | 117 | - `single` and `3dflow`: two-stage training: (1) single-step on the cube dataset; (2) single-step on the ShapeNet dataset. 118 | 119 | ## BibTeX 120 | ``` 121 | @inproceedings{xu2020learning, 122 | title={Learning 3D Dynamic Scene Representations for Robot Manipulation}, 123 | author={Xu, Zhenjia and He, Zhanpeng and Wu, Jiajun and Song, Shuran}, 124 | booktitle={Conference on Robot Learning (CoRL)}, 125 | year={2020} 126 | } 127 | ``` 128 | 129 | ## License 130 | 131 | This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details.
132 | 133 | 134 | ## Acknowledgement 135 | - THe code for [SE3](se3) is modified from [se3posenets-pytorch](https://github.com/abyravan/se3posenets-pytorch) 136 | - The code for [TSDF fusion](fusion.py) is modified from [tsdf-fusion-python](https://github.com/andyzeng/tsdf-fusion-python). 137 | - The code for [binvox processing](binvox_utils.py) is modified from [binvox-rw-py](https://github.com/dimatura/binvox-rw-py). 138 | -------------------------------------------------------------------------------- /assets/fence/obj.mtl: -------------------------------------------------------------------------------- 1 | # Color definition for Tinkercad Obj File 2015 2 | 3 | newmtl color_11593967 4 | Ka 0 0 0 5 | Kd 0.6901960784313725 0.9098039215686274 0.9372549019607843 6 | d 1.0 7 | illum 0.0 8 | 9 | newmtl color_16500122 10 | Ka 0 0 0 11 | Kd 0.984313725490196 0.7725490196078432 0.6039215686274509 12 | d 1.0 13 | illum 0.0 14 | 15 | newmtl color_16311991 16 | Ka 0 0 0 17 | Kd 0.9725490196078431 0.9019607843137255 0.7176470588235294 18 | d 1.0 19 | illum 0.0 20 | 21 | newmtl color_13165757 22 | Ka 0 0 0 23 | Kd 0.7843137254901961 0.8941176470588236 0.7411764705882353 24 | d 1.0 25 | illum 0.0 26 | 27 | -------------------------------------------------------------------------------- /assets/fence/tinker.obj: -------------------------------------------------------------------------------- 1 | # Object Export From Tinkercad Server 2015 2 | 3 | mtllib obj.mtl 4 | 5 | o obj_0 6 | v 52 -50 30 7 | v 52 -50 0 8 | v 52 50 0 9 | v 52 50 30 10 | v 50 50 30 11 | v 50 -50 30 12 | v 50 -50 0 13 | v 50 50 0 14 | v 50 52 0 15 | v 50 52 30 16 | v -52 50 30 17 | v -52 50 0 18 | v -52 -50 0 19 | v -52 -50 30 20 | v 50 -52 30 21 | v 50 -52 0 22 | v -50 50 30 23 | v -50 -50 30 24 | v -50 -50 0 25 | v -50 50 0 26 | v -50 -52 0 27 | v -50 -52 30 28 | v -50 52 30 29 | v -50 52 0 30 | # 24 vertices 31 | 32 | g group_0_11593967 33 | 34 | usemtl color_11593967 35 | s 0 36 | 37 | f 1 2 3 38 | f 1 3 4 39 | f 4 5 6 40 | f 4 6 1 41 | f 2 7 8 42 | f 2 8 3 43 | f 6 7 2 44 | f 6 2 1 45 | f 4 3 8 46 | f 4 8 5 47 | f 8 7 6 48 | f 8 6 5 49 | # 12 faces 50 | 51 | g group_0_13165757 52 | 53 | usemtl color_13165757 54 | s 0 55 | 56 | f 15 16 7 57 | f 15 7 6 58 | f 21 22 18 59 | f 21 18 19 60 | f 6 18 22 61 | f 6 22 15 62 | f 22 21 16 63 | f 22 16 15 64 | f 16 21 19 65 | f 16 19 7 66 | f 6 7 19 67 | f 6 19 18 68 | # 12 faces 69 | 70 | g group_0_16311991 71 | 72 | usemtl color_16311991 73 | s 0 74 | 75 | f 11 12 13 76 | f 11 13 14 77 | f 17 11 14 78 | f 17 14 18 79 | f 19 13 12 80 | f 19 12 20 81 | f 14 13 19 82 | f 14 19 18 83 | f 17 20 12 84 | f 17 12 11 85 | f 19 20 17 86 | f 19 17 18 87 | # 12 faces 88 | 89 | g group_0_16500122 90 | 91 | usemtl color_16500122 92 | s 0 93 | 94 | f 9 10 5 95 | f 9 5 8 96 | f 23 24 20 97 | f 23 20 17 98 | f 8 20 24 99 | f 8 24 9 100 | f 10 23 17 101 | f 10 17 5 102 | f 10 9 24 103 | f 10 24 23 104 | f 17 20 8 105 | f 17 8 5 106 | # 12 faces 107 | 108 | #end of obj_0 109 | 110 | -------------------------------------------------------------------------------- /assets/fence/tinker.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /assets/object_id/shapenet_id.json: -------------------------------------------------------------------------------- 1 | 
{ 2 | "bottle": { 3 | "category_id": "02876657", 4 | "object_id": [ 5 | "6b8b2cb01c376064c8724d5673a063a6", 6 | "547fa0085800c5c3846564a8a219239b", 7 | "91235f7d65aec958ca972daa503b3095", 8 | "9dff3d09b297cdd930612f5c0ef21eb8", 9 | "d45bf1487b41d2f630612f5c0ef21eb8", 10 | "898101350771ff942ae40d06128938a1", 11 | "3b26c9021d9e31a7ad8912880b776dcf", 12 | "ed55f39e04668bf9837048966ef3fcb9", 13 | "74690ddde9c0182696c2e7960a15c619", 14 | "e4ada697d05ac7acf9907e8bdd53291e", 15 | "fdc47f5f8dff184830eaaf40a8a562c1", 16 | "ad33ed7da4ef1b1cca18d703b8006093", 17 | "4d4fc73864844dad1ceb7b8cc3792fd", 18 | "42f85b0eb5e9fd508f9c4ecc067067e9" 19 | ], 20 | "global_scaling": [0.7, 0.8], 21 | "large_scaling": [1.1, 1.2] 22 | }, 23 | "can": { 24 | "category_id": "02946921", 25 | "object_id": [ 26 | "f4ad0b7f82c36051f51f77a6d7299806", 27 | "a70947df1f1490c2a81ec39fd9664e9b", 28 | "b6c4d78363d965617cb2a55fa21392b7", 29 | "91a524cc9c9be4872999f92861cdea7a", 30 | "96387095255f7080b7886d94372e3c76", 31 | "9b1f0ddd23357e01a81ec39fd9664e9b" 32 | ], 33 | "global_scaling": [0.7, 0.9], 34 | "large_scaling": [1.1, 1.2] 35 | }, 36 | "mug": { 37 | "category_id": "03797390", 38 | "object_id": [ 39 | "599e604a8265cc0a98765d8aa3638e70", 40 | "b46e89995f4f9cc5161e440f04bd2a2", 41 | "9c930a8a3411f069e7f67f334aa9295c", 42 | "2d10421716b16580e45ef4135c266a12", 43 | "71ca4fc9c8c29fa8d5abaf84513415a2" 44 | ], 45 | "global_scaling": [0.8, 0.9], 46 | "large_scaling": [1.1, 1.2] 47 | }, 48 | "sofa": { 49 | "category_id": "04256520", 50 | "object_id": [ 51 | "930873705bff9098e6e46d06d31ee634", 52 | "f094521e8579917eea65c47b660136e7", 53 | "adc4a9767d1c7bae8522c33a9d3f5757", 54 | "d0e419a11fd8f4bce589b08489d157d", 55 | "9aef63feacf65dd9cc3e9831f31c9164", 56 | "d55d14f87d65faa84ccf9d6d546b307f", 57 | "526c4f841f777635b5b328c62af5142", 58 | "65fce4b727c5df50e5f5c582d1bee164", 59 | "7c68894c83afb0118e8dcbd53cc631ab" 60 | ], 61 | "global_scaling": [0.4, 0.7], 62 | "large_scaling": [1.0, 1.2] 63 | }, 64 | "phone": { 65 | "category_id": "04401088", 66 | "object_id": [ 67 | "611afaaa1671ac8cc56f78d9daf213b", 68 | "b8555009f82af5da8c3645155d02fccc", 69 | "2725909a09e1a7961df58f4da76e254b", 70 | "9d021614c39c53dabee972a203aaf80", 71 | "fe34b663c44baf622ad536a59974757f", 72 | "2bb42eb0676116d41580700c4211a379", 73 | "9efabcf2ff8a4be9a59562d67b11f3d", 74 | "b9f67617cf320c20de4349e5bfa4fedb", 75 | "78855e0d8d27f00b42e82e1724e35ca", 76 | "401d5604ebfb4b43a7d4c094203303b1" 77 | ], 78 | "global_scaling": [0.7, 1.0], 79 | "large_scaling": [1.0, 1.1] 80 | 81 | } 82 | } -------------------------------------------------------------------------------- /assets/object_id/ycb_id.json: -------------------------------------------------------------------------------- 1 | { 2 | "large_list": [ 3 | "002_master_chef_can", 4 | "004_sugar_box", 5 | "006_mustard_bottle", 6 | "flipped-065-j_cups", 7 | "071_nine_hole_peg_test", 8 | "051_large_clamp" 9 | ], 10 | "normal_list": [ 11 | "005_tomato_soup_can", 12 | "007_tuna_fish_can", 13 | "008_pudding_box", 14 | "009_gelatin_box", 15 | "010_potted_meat_can", 16 | "025_mug", 17 | "061_foam_brick", 18 | "077_rubiks_cube", 19 | 20 | "flipped-065-a_cups", 21 | "flipped-065-d_cups", 22 | "flipped-065-g_cups", 23 | "filled-073-a_lego_duplo", 24 | "filled-073-b_lego_duplo", 25 | "filled-073-c_lego_duplo", 26 | "filled-073-d_lego_duplo", 27 | "filled-073-f_lego_duplo" 28 | ] 29 | } -------------------------------------------------------------------------------- /assets/stick/stick.urdf: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /assets/ur5/collision/base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/base.stl -------------------------------------------------------------------------------- /assets/ur5/collision/forearm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/forearm.stl -------------------------------------------------------------------------------- /assets/ur5/collision/shoulder.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/shoulder.stl -------------------------------------------------------------------------------- /assets/ur5/collision/upperarm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/upperarm.stl -------------------------------------------------------------------------------- /assets/ur5/collision/wrist1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/wrist1.stl -------------------------------------------------------------------------------- /assets/ur5/collision/wrist2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/wrist2.stl -------------------------------------------------------------------------------- /assets/ur5/collision/wrist3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/collision/wrist3.stl -------------------------------------------------------------------------------- /assets/ur5/mount.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /assets/ur5/notes.txt: -------------------------------------------------------------------------------- 1 | 2 | 2018-05-08 - Vijay Pradeep: 3 | 4 | These visual/ mesh files were generated from the dae files in 5 | ros-industrial universal robot repo [1]. Since the collada pyBullet 6 | parser is somewhat limited, it is unable to parse the UR collada mesh 7 | files. This, we imported these collada files into blender and 8 | converted them into STL files. We lost material definitions during 9 | the conversion, but that's ok. 
10 | 11 | The URDF was generated by running the xacro xml preprocessor on the 12 | URDF included in the ur_description repo already mentioned here. 13 | Additional manual tweaking was required to update resource paths and 14 | to remove errors caused by missing inertia elements. Varios Gazebo 15 | plugin tags were also removed. 16 | 17 | [1] - https://github.com/ros-industrial/universal_robot/tree/kinetic-devel/ur_description/meshes/ur5/visual 18 | -------------------------------------------------------------------------------- /assets/ur5/ur5.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | > 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | -------------------------------------------------------------------------------- /assets/ur5/visual/base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/base.stl -------------------------------------------------------------------------------- /assets/ur5/visual/forearm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/forearm.stl -------------------------------------------------------------------------------- /assets/ur5/visual/shoulder.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/shoulder.stl -------------------------------------------------------------------------------- /assets/ur5/visual/upperarm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/upperarm.stl -------------------------------------------------------------------------------- /assets/ur5/visual/wrist1.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/wrist1.stl -------------------------------------------------------------------------------- /assets/ur5/visual/wrist2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/wrist2.stl -------------------------------------------------------------------------------- /assets/ur5/visual/wrist3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/assets/ur5/visual/wrist3.stl -------------------------------------------------------------------------------- /binvox_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2012 Daniel Maturana 2 | # This file is part of binvox-rw-py. 3 | # 4 | # binvox-rw-py is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | # 9 | # binvox-rw-py is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | # 14 | # You should have received a copy of the GNU General Public License 15 | # along with binvox-rw-py. If not, see . 16 | # 17 | 18 | """ 19 | Binvox to Numpy and back. 20 | 21 | 22 | >>> import numpy as np 23 | >>> import binvox_rw 24 | >>> with open('chair.binvox', 'rb') as f: 25 | ... m1 = binvox_rw.read_as_3d_array(f) 26 | ... 27 | >>> m1.dims 28 | [32, 32, 32] 29 | >>> m1.scale 30 | 41.133000000000003 31 | >>> m1.translate 32 | [0.0, 0.0, 0.0] 33 | >>> with open('chair_out.binvox', 'wb') as f: 34 | ... m1.write(f) 35 | ... 36 | >>> with open('chair_out.binvox', 'rb') as f: 37 | ... m2 = binvox_rw.read_as_3d_array(f) 38 | ... 39 | >>> m1.dims==m2.dims 40 | True 41 | >>> m1.scale==m2.scale 42 | True 43 | >>> m1.translate==m2.translate 44 | True 45 | >>> np.all(m1.data==m2.data) 46 | True 47 | 48 | >>> with open('chair.binvox', 'rb') as f: 49 | ... md = binvox_rw.read_as_3d_array(f) 50 | ... 51 | >>> with open('chair.binvox', 'rb') as f: 52 | ... ms = binvox_rw.read_as_coord_array(f) 53 | ... 54 | >>> data_ds = binvox_rw.dense_to_sparse(md.data) 55 | >>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32) 56 | >>> np.all(data_sd==md.data) 57 | True 58 | >>> # the ordering of elements returned by numpy.nonzero changes with axis 59 | >>> # ordering, so to compare for equality we first lexically sort the voxels. 60 | >>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)]) 61 | True 62 | """ 63 | 64 | import numpy as np 65 | 66 | class Voxels(object): 67 | """ Holds a binvox model. 68 | data is either a three-dimensional numpy boolean array (dense representation) 69 | or a two-dimensional numpy float array (coordinate representation). 70 | 71 | dims, translate and scale are the model metadata. 72 | 73 | dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model. 74 | 75 | scale and translate relate the voxels to the original model coordinates. 
76 | 77 | To translate voxel coordinates i, j, k to original coordinates x, y, z: 78 | 79 | x_n = (i+.5)/dims[0] 80 | y_n = (j+.5)/dims[1] 81 | z_n = (k+.5)/dims[2] 82 | x = scale*x_n + translate[0] 83 | y = scale*y_n + translate[1] 84 | z = scale*z_n + translate[2] 85 | 86 | """ 87 | 88 | def __init__(self, data, dims, translate, scale, axis_order): 89 | self.data = data 90 | self.dims = dims 91 | self.translate = translate 92 | self.scale = scale 93 | assert (axis_order in ('xzy', 'xyz')) 94 | self.axis_order = axis_order 95 | 96 | def clone(self): 97 | data = self.data.copy() 98 | dims = self.dims[:] 99 | translate = self.translate[:] 100 | return Voxels(data, dims, translate, self.scale, self.axis_order) 101 | 102 | def write(self, fp): 103 | write(self, fp) 104 | 105 | def read_header(fp): 106 | """ Read binvox header. Mostly meant for internal use. 107 | """ 108 | line = fp.readline().strip() 109 | if not line.startswith(b'#binvox'): 110 | raise IOError('Not a binvox file') 111 | dims = list(map(int, fp.readline().strip().split(b' ')[1:])) 112 | translate = list(map(float, fp.readline().strip().split(b' ')[1:])) 113 | scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0] 114 | line = fp.readline() 115 | return dims, translate, scale 116 | 117 | def read_as_3d_array(fp, fix_coords=True): 118 | """ Read binary binvox format as array. 119 | 120 | Returns the model with accompanying metadata. 121 | 122 | Voxels are stored in a three-dimensional numpy array, which is simple and 123 | direct, but may use a lot of memory for large models. (Storage requirements 124 | are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy 125 | boolean arrays use a byte per element). 126 | 127 | Doesn't do any checks on input except for the '#binvox' line. 128 | """ 129 | dims, translate, scale = read_header(fp) 130 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 131 | # if just using reshape() on the raw data: 132 | # indexing the array as array[i,j,k], the indices map into the 133 | # coords as: 134 | # i -> x 135 | # j -> z 136 | # k -> y 137 | # if fix_coords is true, then data is rearranged so that 138 | # mapping is 139 | # i -> x 140 | # j -> y 141 | # k -> z 142 | values, counts = raw_data[::2], raw_data[1::2] 143 | data = np.repeat(values, counts).astype(np.bool) 144 | data = data.reshape(dims) 145 | if fix_coords: 146 | # xzy to xyz TODO the right thing 147 | data = np.transpose(data, (0, 2, 1)) 148 | axis_order = 'xyz' 149 | else: 150 | axis_order = 'xzy' 151 | return Voxels(data, dims, translate, scale, axis_order) 152 | 153 | def read_as_coord_array(fp, fix_coords=True): 154 | """ Read binary binvox format as coordinates. 155 | 156 | Returns binvox model with voxels in a "coordinate" representation, i.e. an 157 | 3 x N array where N is the number of nonzero voxels. Each column 158 | corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates 159 | of the voxel. (The odd ordering is due to the way binvox format lays out 160 | data). Note that coordinates refer to the binvox voxels, without any 161 | scaling or translation. 162 | 163 | Use this to save memory if your model is very sparse (mostly empty). 164 | 165 | Doesn't do any checks on input except for the '#binvox' line. 
166 | """ 167 | dims, translate, scale = read_header(fp) 168 | raw_data = np.frombuffer(fp.read(), dtype=np.uint8) 169 | 170 | values, counts = raw_data[::2], raw_data[1::2] 171 | 172 | sz = np.prod(dims) 173 | index, end_index = 0, 0 174 | end_indices = np.cumsum(counts) 175 | indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype) 176 | 177 | values = values.astype(np.bool) 178 | indices = indices[values] 179 | end_indices = end_indices[values] 180 | 181 | nz_voxels = [] 182 | for index, end_index in zip(indices, end_indices): 183 | nz_voxels.extend(range(index, end_index)) 184 | nz_voxels = np.array(nz_voxels) 185 | # TODO are these dims correct? 186 | # according to docs, 187 | # index = x * wxh + z * width + y; // wxh = width * height = d * d 188 | 189 | x = nz_voxels / (dims[0]*dims[1]) 190 | zwpy = nz_voxels % (dims[0]*dims[1]) # z*w + y 191 | z = zwpy / dims[0] 192 | y = zwpy % dims[0] 193 | if fix_coords: 194 | data = np.vstack((x, y, z)) 195 | axis_order = 'xyz' 196 | else: 197 | data = np.vstack((x, z, y)) 198 | axis_order = 'xzy' 199 | 200 | #return Voxels(data, dims, translate, scale, axis_order) 201 | return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order) 202 | 203 | def dense_to_sparse(voxel_data, dtype=np.int): 204 | """ From dense representation to sparse (coordinate) representation. 205 | No coordinate reordering. 206 | """ 207 | if voxel_data.ndim!=3: 208 | raise ValueError('voxel_data is wrong shape; should be 3D array.') 209 | return np.asarray(np.nonzero(voxel_data), dtype) 210 | 211 | def sparse_to_dense(voxel_data, dims, dtype=np.bool): 212 | if voxel_data.ndim!=2 or voxel_data.shape[0]!=3: 213 | raise ValueError('voxel_data is wrong shape; should be 3xN array.') 214 | if np.isscalar(dims): 215 | dims = [dims]*3 216 | dims = np.atleast_2d(dims).T 217 | # truncate to integers 218 | xyz = voxel_data.astype(np.int) 219 | # discard voxels that fall outside dims 220 | valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0) 221 | xyz = xyz[:,valid_ix] 222 | out = np.zeros(dims.flatten(), dtype=dtype) 223 | out[tuple(xyz)] = True 224 | return out 225 | 226 | #def get_linear_index(x, y, z, dims): 227 | #""" Assuming xzy order. (y increasing fastest. 228 | #TODO ensure this is right when dims are not all same 229 | #""" 230 | #return x*(dims[1]*dims[2]) + z*dims[1] + y 231 | 232 | def write(voxel_model, fp): 233 | """ Write binary binvox format. 234 | 235 | Note that when saving a model in sparse (coordinate) format, it is first 236 | converted to dense format. 237 | 238 | Doesn't check if the model is 'sane'. 
239 | 240 | """ 241 | if voxel_model.data.ndim==2: 242 | # TODO avoid conversion to dense 243 | dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims) 244 | else: 245 | dense_voxel_data = voxel_model.data 246 | 247 | fp.write('#binvox 1\n'.encode('ascii')) 248 | fp.write(('dim '+' '.join(map(str, voxel_model.dims))+'\n').encode('ascii')) 249 | fp.write(('translate '+' '.join(map(str, voxel_model.translate))+'\n').encode('ascii')) 250 | fp.write(('scale '+str(voxel_model.scale)+'\n').encode('ascii')) 251 | fp.write('data\n'.encode('ascii')) 252 | # fp.write('#binvox 1\n') 253 | # fp.write('dim ' + ' '.join(map(str, voxel_model.dims)) + '\n') 254 | # fp.write('translate ' + ' '.join(map(str, voxel_model.translate)) + '\n') 255 | # fp.write('scale ' + str(voxel_model.scale) + '\n') 256 | # fp.write('data\n') 257 | if not voxel_model.axis_order in ('xzy', 'xyz'): 258 | raise ValueError('Unsupported voxel model axis order') 259 | 260 | if voxel_model.axis_order=='xzy': 261 | voxels_flat = dense_voxel_data.flatten() 262 | elif voxel_model.axis_order=='xyz': 263 | voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten() 264 | 265 | # keep a sort of state machine for writing run length encoding 266 | state = voxels_flat[0] 267 | ctr = 0 268 | for c in voxels_flat: 269 | if c==state: 270 | ctr += 1 271 | # if ctr hits max, dump 272 | if ctr==255: 273 | # fp.write(chr(state)) 274 | # fp.write(chr(ctr)) 275 | fp.write(state.tobytes()) 276 | fp.write(ctr.to_bytes(1, byteorder='little')) 277 | ctr = 0 278 | else: 279 | # if switch state, dump 280 | # fp.write(chr(state)) 281 | # fp.write(chr(ctr)) 282 | fp.write(state.tobytes()) 283 | fp.write(ctr.to_bytes(1, byteorder='little')) 284 | state = c 285 | ctr = 1 286 | # flush out remainders 287 | if ctr > 0: 288 | # fp.write(chr(state)) 289 | # fp.write(chr(ctr)) 290 | fp.write(state.tobytes()) 291 | fp.write(ctr.to_bytes(1, byteorder='little')) 292 | 293 | if __name__ == '__main__': 294 | import doctest 295 | doctest.testmod() -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | from torch.utils.data import Dataset 4 | import h5py 5 | from utils import imretype, draw_arrow 6 | 7 | 8 | class Data(Dataset): 9 | def __init__(self, data_path, split, seq_len): 10 | self.data_path = data_path 11 | self.tot_seq_len = 10 12 | self.seq_len = seq_len 13 | self.volume_size = [128, 128, 48] 14 | self.direction_num = 8 15 | self.voxel_size = 0.004 16 | self.idx_list = open(osp.join(self.data_path, '%s.txt' % split)).read().splitlines() 17 | self.returns = ['action', 'color_heightmap', 'color_image', 'tsdf', 'mask_3d', 'scene_flow_3d'] 18 | self.data_per_seq = self.tot_seq_len // self.seq_len 19 | 20 | def __getitem__(self, index): 21 | data_dict = {} 22 | idx_seq = index // self.data_per_seq 23 | idx_step = index % self.data_per_seq * self.seq_len 24 | for step_id in range(self.seq_len): 25 | f = h5py.File(osp.join(self.data_path, "%s_%d.hdf5" % (self.idx_list[idx_seq], idx_step + step_id)), "r") 26 | 27 | # action 28 | action = f['action'] 29 | data_dict['%d-action' % step_id] = self.get_action(action) 30 | 31 | # color_image, [W, H, 3] 32 | if 'color_image' in self.returns: 33 | data_dict['%d-color_image' % step_id] = np.asarray(f['color_image_small'], dtype=np.uint8) 34 | 35 | # color_heightmap, [128, 128, 3] 36 | if 'color_heightmap' in self.returns: 37 | # draw arrow 
for visualization 38 | color_heightmap = draw_arrow( 39 | np.asarray(f['color_heightmap'], dtype=np.uint8), 40 | (int(action[2]), int(action[1]), int(action[0])) 41 | ) 42 | data_dict['%d-color_heightmap' % step_id] = color_heightmap 43 | 44 | # tsdf, [S1, S2, S3] 45 | if 'tsdf' in self.returns: 46 | data_dict['%d-tsdf' % step_id] = np.asarray(f['tsdf'], dtype=np.float32) 47 | 48 | # mask_3d, [S1, S2, S3] 49 | if 'mask_3d' in self.returns: 50 | data_dict['%d-mask_3d' % step_id] = np.asarray(f['mask_3d'], dtype=np.int) 51 | 52 | # scene_flow_3d, [3, S1, S2, S3] 53 | if 'scene_flow_3d' in self.returns: 54 | scene_flow_3d = np.asarray(f['scene_flow_3d'], dtype=np.float32).transpose([3, 0, 1, 2]) 55 | data_dict['%d-scene_flow_3d' % step_id] = scene_flow_3d 56 | 57 | return data_dict 58 | 59 | def __len__(self): 60 | return len(self.idx_list) * self.data_per_seq 61 | 62 | def get_action(self, action): 63 | direction, r, c = int(action[0]), int(action[1]), int(action[2]) 64 | if direction < 0: 65 | direction += self.direction_num 66 | action_map = np.zeros(shape=[self.direction_num, self.volume_size[0], self.volume_size[1]], dtype=np.float32) 67 | action_map[direction, r, c] = 1 68 | 69 | return action_map -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | ### Generate training data 4 | - Please refer to [link](../README.md). 5 | ### Donload testing data 6 | - The following two testing datasets can be download. 7 | - [Sim](https://dsr-net.cs.columbia.edu/download/data/sim_test_data.zip): 400 sequences, generated in pybullet. 8 | - [Real](https://dsr-net.cs.columbia.edu/download/data/real_test_data.zip): 150 sequences, with full annotations. 9 | - Unzip `real_test_data.zip` and `sim_test_data.zip` in this folder. 
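Once unzipped, a sequence can be loaded with the `Data` dataset class in [data.py](../data.py). The sketch below is illustrative only: it assumes the sim test set was extracted to `data/sim_test_data`, that the folder contains a `test.txt` split file (as produced by `data_generation.py`), and that the script is run from the repository root.
```
# Minimal loading sketch (paths are assumptions, see note above).
from torch.utils.data import DataLoader
from data import Data

dataset = Data(data_path='data/sim_test_data', split='test', seq_len=1)
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2)

batch = next(iter(loader))
print(batch['0-tsdf'].shape)           # [B, 128, 128, 48] TSDF volume
print(batch['0-action'].shape)         # [B, 8, 128, 128] one-hot action map
print(batch['0-mask_3d'].shape)        # [B, 128, 128, 48] per-voxel object ids
print(batch['0-scene_flow_3d'].shape)  # [B, 3, 128, 128, 48] ground-truth scene flow
```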
-------------------------------------------------------------------------------- /data_generation.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | from tqdm import tqdm 4 | import argparse 5 | import h5py 6 | 7 | from sim_env import SimulationEnv 8 | from utils import mkdir, project_pts_to_3d 9 | from fusion import TSDFVolume 10 | 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument('--data_path', type=str, help='path to data') 14 | parser.add_argument('--train_num', type=int, help='number of training sequences') 15 | parser.add_argument('--test_num', type=int, help='number of testing sequences') 16 | parser.add_argument('--object_type', type=str, default='ycb', choices=['cube', 'shapenet', 'ycb']) 17 | parser.add_argument('--max_path_length', type=int, default=10, help='maximum length for each sequence') 18 | parser.add_argument('--object_num', type=int, default=4, help='number of objects') 19 | 20 | def main(): 21 | 22 | args = parser.parse_args() 23 | 24 | for key in vars(args): 25 | print('[{0}] = {1}'.format(key, getattr(args, key))) 26 | mkdir(args.data_path, clean=False) 27 | 28 | env = SimulationEnv(gui_enabled=False) 29 | camera_pose = env.sim.camera_params[0]['camera_pose'] 30 | camera_intr = env.sim.camera_params[0]['camera_intr'] 31 | camera_pose_small = env.sim.camera_params[1]['camera_pose'] 32 | camera_intr_small = env.sim.camera_params[1]['camera_intr'] 33 | 34 | for rollout in tqdm(range(args.train_num + args.test_num)): 35 | env.reset(args.object_num, args.object_type) 36 | for step_num in range(args.max_path_length): 37 | f = h5py.File(osp.join(args.data_path, '%d_%d.hdf5' % (rollout, step_num)), 'w') 38 | 39 | output = env.poke() 40 | for key, val in output.items(): 41 | if key == 'action': 42 | action = val 43 | f['action'] = np.array([action['0'], action['1'], action['2']]) 44 | else: 45 | f.create_dataset(key, data=val, compression="gzip", compression_opts=4) 46 | 47 | # tsdf 48 | tsdf = get_volume( 49 | color_image=output['color_image'], 50 | depth_image=output['depth_image'], 51 | cam_intr=camera_intr, 52 | cam_pose=camera_pose 53 | ) 54 | f.create_dataset('tsdf', data=tsdf, compression="gzip", compression_opts=4) 55 | 56 | # 3d pts 57 | color_image_small = output['color_image_small'] 58 | depth_image_small = output['depth_image_small'] 59 | pts_small = project_pts_to_3d(color_image_small, depth_image_small, camera_intr_small, camera_pose_small) 60 | f.create_dataset('pts_small', data=pts_small, compression="gzip", compression_opts=4) 61 | 62 | if step_num == args.max_path_length - 1: 63 | g_next = f.create_group('next') 64 | output = env.get_scene_info(mask_info=True) 65 | for key, val in output.items(): 66 | g_next.create_dataset(key, data=val, compression="gzip", compression_opts=4) 67 | f.close() 68 | 69 | id_list = [i for i in range(args.train_num + args.test_num)] 70 | np.random.shuffle(id_list) 71 | 72 | with open(osp.join(args.data_path, 'train.txt'), 'w') as f: 73 | for k in range(args.train_num): 74 | print(id_list[k], file=f) 75 | 76 | with open(osp.join(args.data_path, 'test.txt'), 'w') as f: 77 | for k in range(args.train_num, args.train_num + args.test_num): 78 | print(id_list[k], file=f) 79 | 80 | 81 | def get_volume(color_image, depth_image, cam_intr, cam_pose, vol_bnds=None): 82 | voxel_size = 0.004 83 | if vol_bnds is None: 84 | vol_bnds = np.array([[0.244, 0.756], 85 | [-0.256, 0.256], 86 | [0.0, 0.192]]) 87 | tsdf_vol = TSDFVolume(vol_bnds, 
voxel_size=voxel_size, use_gpu=True) 88 | tsdf_vol.integrate(color_image, depth_image, cam_intr, cam_pose, obs_weight=1.) 89 | volume = np.asarray(tsdf_vol.get_volume()[0]) 90 | volume = np.transpose(volume, [1, 0, 2]) 91 | return volume 92 | 93 | 94 | if __name__ == '__main__': 95 | main() -------------------------------------------------------------------------------- /figures/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/figures/teaser.jpg -------------------------------------------------------------------------------- /forward_warp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Function 4 | from cupy.cuda import function 5 | from pynvrtc.compiler import Program 6 | from collections import namedtuple 7 | 8 | 9 | class Forward_Warp_Cupy(Function): 10 | @staticmethod 11 | def forward(ctx, feature, flow, mask): 12 | kernel = ''' 13 | extern "C" 14 | __global__ void warp_forward( 15 | const float * im0, // [B, C, W, H, D] 16 | const float * flow, // [B, 3, W, H, D] 17 | const float * mask, // [B, W, H, D] 18 | float * im1, // [B, C, W, H, D] 19 | float * cnt, // [B, W, H, D] 20 | const int vol_batch, 21 | const int vol_dim_x, 22 | const int vol_dim_y, 23 | const int vol_dim_z, 24 | const int feature_dim, 25 | const int warp_mode //0 (bilinear), 1 (nearest) 26 | ) { 27 | // Get voxel index 28 | int max_threads_per_block = blockDim.x; 29 | int block_idx = blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + blockIdx.x; 30 | int voxel_idx = block_idx * max_threads_per_block + threadIdx.x; 31 | 32 | int voxel_size_product = vol_dim_x * vol_dim_y * vol_dim_z; 33 | 34 | // IMPORTANT 35 | if (voxel_idx >= vol_batch * voxel_size_product) return; 36 | 37 | // Get voxel grid coordinates (note: be careful when casting) 38 | int tmp = voxel_idx; 39 | 40 | int voxel_z = tmp % vol_dim_z; 41 | tmp = tmp / vol_dim_z; 42 | 43 | int voxel_y = tmp % vol_dim_y; 44 | tmp = tmp / vol_dim_y; 45 | 46 | int voxel_x = tmp % vol_dim_x; 47 | int batch = tmp / vol_dim_x; 48 | 49 | int voxel_idx_BCWHD = voxel_idx + batch * (voxel_size_product * (feature_dim - 1)); 50 | int voxel_idx_flow = voxel_idx + batch * (voxel_size_product * (3 - 1)); 51 | 52 | // Main part 53 | if (warp_mode == 0) { 54 | // bilinear 55 | float x_float = voxel_x + flow[voxel_idx_flow]; 56 | float y_float = voxel_y + flow[voxel_idx_flow + voxel_size_product]; 57 | float z_float = voxel_z + flow[voxel_idx_flow + voxel_size_product + voxel_size_product]; 58 | 59 | int x_floor = x_float; 60 | int y_floor = y_float; 61 | int z_floor = z_float; 62 | 63 | for(int t = 0; t < 8; t++) { 64 | int dx = (t >= 4); 65 | int dy = (t - 4 * dx) >= 2; 66 | int dz = t - 4 * dx - dy * 2; 67 | 68 | int x = x_floor + dx; 69 | int y = y_floor + dy; 70 | int z = z_floor + dz; 71 | 72 | if (x >= 0 && x < vol_dim_x && y >= 0 && y < vol_dim_y && z >= 0 && z < vol_dim_z) { 73 | float weight = mask[voxel_idx]; 74 | weight *= (dx == 0 ? (x_floor + 1 - x_float) : (x_float - x_floor)); 75 | weight *= (dy == 0 ? (y_floor + 1 - y_float) : (y_float - y_floor)); 76 | weight *= (dz == 0 ? 
(z_floor + 1 - z_float) : (z_float - z_floor)); 77 | int idx = (((int)batch * vol_dim_x + x) * vol_dim_y + y) * vol_dim_z + z; 78 | atomicAdd(&cnt[idx], weight); 79 | 80 | int idx_BCWHD = (((int)batch * feature_dim * vol_dim_x + x) * vol_dim_y + y) * vol_dim_z + z; 81 | 82 | for(int c = 0, offset = 0; c < feature_dim; c++, offset += voxel_size_product) { 83 | atomicAdd(&im1[idx_BCWHD + offset], im0[voxel_idx_BCWHD + offset] * weight); 84 | } 85 | } 86 | 87 | } 88 | } else { 89 | // nearest 90 | int x = round(voxel_x + flow[voxel_idx_flow]); 91 | int y = round(voxel_y + flow[voxel_idx_flow + voxel_size_product]); 92 | int z = round(voxel_z + flow[voxel_idx_flow + voxel_size_product + voxel_size_product]); 93 | 94 | if (x >= 0 && x < vol_dim_x && y >= 0 && y < vol_dim_y && z >= 0 && z < vol_dim_z) { 95 | int idx = (((int)batch * vol_dim_x + x) * vol_dim_y + y) * vol_dim_z + z; 96 | float mask_weight = mask[voxel_idx]; 97 | atomicAdd(&cnt[idx], mask_weight); 98 | 99 | int idx_BCWHD = (((int)batch * feature_dim * vol_dim_x + x) * vol_dim_y + y) * vol_dim_z + z; 100 | 101 | for(int c = 0, offset = 0; c < feature_dim; c++, offset += voxel_size_product) { 102 | atomicAdd(&im1[idx_BCWHD + offset], im0[voxel_idx_BCWHD + offset] * mask_weight); 103 | } 104 | } 105 | } 106 | } 107 | ''' 108 | program = Program(kernel, 'warp_forward.cu') 109 | ptx = program.compile() 110 | m = function.Module() 111 | m.load(bytes(ptx.encode())) 112 | f = m.get_function('warp_forward') 113 | Stream = namedtuple('Stream', ['ptr']) 114 | s = Stream(ptr=torch.cuda.current_stream().cuda_stream) 115 | 116 | B, C, W, H, D = feature.size() 117 | warp_mode = 0 118 | n_blocks = np.ceil(B * W * H * D / 1024.0) 119 | grid_dim_x = int(np.cbrt(n_blocks)) 120 | grid_dim_y = int(np.sqrt(n_blocks / grid_dim_x)) 121 | grid_dim_z = int(np.ceil(n_blocks / grid_dim_x / grid_dim_y)) 122 | assert grid_dim_x * grid_dim_y * grid_dim_z * 1024 >= B * W * H * D 123 | 124 | feature_new = torch.zeros_like(feature) 125 | cnt = torch.zeros_like(mask) 126 | 127 | f(grid=(grid_dim_x, grid_dim_y, grid_dim_z), block=(1024, 1, 1), 128 | args=[feature.data_ptr(), flow.data_ptr(), mask.data_ptr(), feature_new.data_ptr(), cnt.data_ptr(), 129 | B, W, H, D, C, warp_mode], stream=s) 130 | 131 | eps=1e-3 132 | cnt = torch.max(cnt, other=torch.ones_like(cnt) * eps) 133 | feature_new = feature_new / torch.unsqueeze(cnt, 1) 134 | 135 | return feature_new 136 | 137 | @staticmethod 138 | def backward(ctx, feature_new_grad): 139 | # Not implemented 140 | return None, None, None -------------------------------------------------------------------------------- /fusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from numba import njit, prange 4 | from skimage import measure 5 | 6 | try: 7 | import pycuda.driver as cuda 8 | import pycuda.autoinit 9 | from pycuda.compiler import SourceModule 10 | FUSION_GPU_MODE = 1 11 | except Exception as err: 12 | print('Warning: {}'.format(err)) 13 | print('Failed to import PyCUDA. Running fusion in CPU mode.') 14 | FUSION_GPU_MODE = 0 15 | 16 | 17 | class TSDFVolume: 18 | """Volumetric TSDF Fusion of RGB-D Images. 19 | """ 20 | 21 | def __init__(self, vol_bnds, voxel_size, use_gpu=True): 22 | """Constructor. 23 | Args: 24 | vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the 25 | xyz bounds (min/max) in meters. 26 | voxel_size (float): The volume discretization in meters. 
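Example (an illustrative sketch, not part of the original docs; the bounds and voxel size mirror the values used in data_generation.py, and color_im, depth_im, cam_intr, cam_pose are assumed inputs):

    vol_bnds = np.array([[0.244, 0.756], [-0.256, 0.256], [0.0, 0.192]])
    tsdf_vol = TSDFVolume(vol_bnds, voxel_size=0.004, use_gpu=True)
    tsdf_vol.integrate(color_im, depth_im, cam_intr, cam_pose, obs_weight=1.)
    tsdf, color = tsdf_vol.get_volume()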
27 | """ 28 | vol_bnds = np.asarray(vol_bnds) 29 | assert vol_bnds.shape == ( 30 | 3, 2), "[!] `vol_bnds` should be of shape (3, 2)." 31 | 32 | # Define voxel volume parameters 33 | self._vol_bnds = vol_bnds 34 | self._voxel_size = float(voxel_size) 35 | self._trunc_margin = 5 * self._voxel_size # truncation on SDF 36 | self._color_const = 256 * 256 37 | 38 | # Adjust volume bounds and ensure C-order contiguous 39 | self._vol_dim = np.ceil( 40 | (self._vol_bnds[:, 1]-self._vol_bnds[:, 0])/self._voxel_size).copy(order='C').astype(int) 41 | self._vol_bnds[:, 1] = self._vol_bnds[:, 0] + \ 42 | self._vol_dim*self._voxel_size 43 | self._vol_origin = self._vol_bnds[:, 0].copy( 44 | order='C').astype(np.float32) 45 | 46 | # Initialize pointers to voxel volume in CPU memory 47 | self._tsdf_vol_cpu = np.ones(self._vol_dim).astype(np.float32) 48 | # for computing the cumulative moving average of observations per voxel 49 | self._weight_vol_cpu = np.zeros(self._vol_dim).astype(np.float32) 50 | self._color_vol_cpu = np.zeros(self._vol_dim).astype(np.float32) 51 | 52 | self.gpu_mode = use_gpu and FUSION_GPU_MODE 53 | 54 | # Copy voxel volumes to GPU 55 | if self.gpu_mode: 56 | self._tsdf_vol_gpu = cuda.mem_alloc(self._tsdf_vol_cpu.nbytes) 57 | cuda.memcpy_htod(self._tsdf_vol_gpu, self._tsdf_vol_cpu) 58 | self._weight_vol_gpu = cuda.mem_alloc(self._weight_vol_cpu.nbytes) 59 | cuda.memcpy_htod(self._weight_vol_gpu, self._weight_vol_cpu) 60 | self._color_vol_gpu = cuda.mem_alloc(self._color_vol_cpu.nbytes) 61 | cuda.memcpy_htod(self._color_vol_gpu, self._color_vol_cpu) 62 | 63 | # Cuda kernel function (C++) 64 | self._cuda_src_mod = SourceModule(""" 65 | __global__ void integrate(float * tsdf_vol, 66 | float * weight_vol, 67 | float * color_vol, 68 | float * vol_dim, 69 | float * vol_origin, 70 | float * cam_intr, 71 | float * cam_pose, 72 | float * other_params, 73 | float * color_im, 74 | float * depth_im) { 75 | // Get voxel index 76 | int gpu_loop_idx = (int) other_params[0]; 77 | int max_threads_per_block = blockDim.x; 78 | int block_idx = blockIdx.z*gridDim.y*gridDim.x+blockIdx.y*gridDim.x+blockIdx.x; 79 | int voxel_idx = gpu_loop_idx*gridDim.x*gridDim.y*gridDim.z*max_threads_per_block+block_idx*max_threads_per_block+threadIdx.x; 80 | int vol_dim_x = (int) vol_dim[0]; 81 | int vol_dim_y = (int) vol_dim[1]; 82 | int vol_dim_z = (int) vol_dim[2]; 83 | if (voxel_idx >= vol_dim_x*vol_dim_y*vol_dim_z) 84 | return; 85 | // Get voxel grid coordinates (note: be careful when casting) 86 | float voxel_x = floorf(((float)voxel_idx)/((float)(vol_dim_y*vol_dim_z))); 87 | float voxel_y = floorf(((float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z))/((float)vol_dim_z)); 88 | float voxel_z = (float)(voxel_idx-((int)voxel_x)*vol_dim_y*vol_dim_z-((int)voxel_y)*vol_dim_z); 89 | // Voxel grid coordinates to world coordinates 90 | float voxel_size = other_params[1]; 91 | float pt_x = vol_origin[0]+voxel_x*voxel_size; 92 | float pt_y = vol_origin[1]+voxel_y*voxel_size; 93 | float pt_z = vol_origin[2]+voxel_z*voxel_size; 94 | // World coordinates to camera coordinates 95 | float tmp_pt_x = pt_x-cam_pose[0*4+3]; 96 | float tmp_pt_y = pt_y-cam_pose[1*4+3]; 97 | float tmp_pt_z = pt_z-cam_pose[2*4+3]; 98 | float cam_pt_x = cam_pose[0*4+0]*tmp_pt_x+cam_pose[1*4+0]*tmp_pt_y+cam_pose[2*4+0]*tmp_pt_z; 99 | float cam_pt_y = cam_pose[0*4+1]*tmp_pt_x+cam_pose[1*4+1]*tmp_pt_y+cam_pose[2*4+1]*tmp_pt_z; 100 | float cam_pt_z = cam_pose[0*4+2]*tmp_pt_x+cam_pose[1*4+2]*tmp_pt_y+cam_pose[2*4+2]*tmp_pt_z; 101 | // Camera coordinates to 
image pixels 102 | int pixel_x = (int) roundf(cam_intr[0*3+0]*(cam_pt_x/cam_pt_z)+cam_intr[0*3+2]); 103 | int pixel_y = (int) roundf(cam_intr[1*3+1]*(cam_pt_y/cam_pt_z)+cam_intr[1*3+2]); 104 | // Skip if outside view frustum 105 | int im_h = (int) other_params[2]; 106 | int im_w = (int) other_params[3]; 107 | if (pixel_x < 0 || pixel_x >= im_w || pixel_y < 0 || pixel_y >= im_h || cam_pt_z<0) 108 | return; 109 | // Skip invalid depth 110 | float depth_value = depth_im[pixel_y*im_w+pixel_x]; 111 | if (depth_value == 0) 112 | return; 113 | // Integrate TSDF 114 | float trunc_margin = other_params[4]; 115 | float depth_diff = depth_value-cam_pt_z; 116 | if (depth_diff < -trunc_margin) 117 | return; 118 | float dist = fmin(1.0f,depth_diff/trunc_margin); 119 | float w_old = weight_vol[voxel_idx]; 120 | float obs_weight = other_params[5]; 121 | float w_new = w_old + obs_weight; 122 | weight_vol[voxel_idx] = w_new; 123 | tsdf_vol[voxel_idx] = (tsdf_vol[voxel_idx]*w_old+obs_weight*dist)/w_new; 124 | // Integrate color 125 | float old_color = color_vol[voxel_idx]; 126 | float old_b = floorf(old_color/(256*256)); 127 | float old_g = floorf((old_color-old_b*256*256)/256); 128 | float old_r = old_color-old_b*256*256-old_g*256; 129 | float new_color = color_im[pixel_y*im_w+pixel_x]; 130 | float new_b = floorf(new_color/(256*256)); 131 | float new_g = floorf((new_color-new_b*256*256)/256); 132 | float new_r = new_color-new_b*256*256-new_g*256; 133 | new_b = fmin(roundf((old_b*w_old+obs_weight*new_b)/w_new),255.0f); 134 | new_g = fmin(roundf((old_g*w_old+obs_weight*new_g)/w_new),255.0f); 135 | new_r = fmin(roundf((old_r*w_old+obs_weight*new_r)/w_new),255.0f); 136 | color_vol[voxel_idx] = new_b*256*256+new_g*256+new_r; 137 | }""") 138 | 139 | self._cuda_integrate = self._cuda_src_mod.get_function("integrate") 140 | 141 | # Determine block/grid size on GPU 142 | gpu_dev = cuda.Device(0) 143 | self._max_gpu_threads_per_block = gpu_dev.MAX_THREADS_PER_BLOCK 144 | n_blocks = int(np.ceil(float(np.prod(self._vol_dim)) / 145 | float(self._max_gpu_threads_per_block))) 146 | grid_dim_x = min(gpu_dev.MAX_GRID_DIM_X, 147 | int(np.floor(np.cbrt(n_blocks)))) 148 | grid_dim_y = min(gpu_dev.MAX_GRID_DIM_Y, int( 149 | np.floor(np.sqrt(n_blocks/grid_dim_x)))) 150 | grid_dim_z = min(gpu_dev.MAX_GRID_DIM_Z, int( 151 | np.ceil(float(n_blocks)/float(grid_dim_x*grid_dim_y)))) 152 | self._max_gpu_grid_dim = np.array( 153 | [grid_dim_x, grid_dim_y, grid_dim_z]).astype(int) 154 | self._n_gpu_loops = int(np.ceil(float(np.prod( 155 | self._vol_dim))/float(np.prod(self._max_gpu_grid_dim)*self._max_gpu_threads_per_block))) 156 | 157 | else: 158 | # Get voxel grid coordinates 159 | xv, yv, zv = np.meshgrid( 160 | range(self._vol_dim[0]), 161 | range(self._vol_dim[1]), 162 | range(self._vol_dim[2]), 163 | indexing='ij' 164 | ) 165 | self.vox_coords = np.concatenate([ 166 | xv.reshape(1, -1), 167 | yv.reshape(1, -1), 168 | zv.reshape(1, -1) 169 | ], axis=0).astype(int).T 170 | 171 | @staticmethod 172 | @njit(parallel=True) 173 | def vox2world(vol_origin, vox_coords, vox_size): 174 | """Convert voxel grid coordinates to world coordinates. 
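Each output row is computed as vol_origin + vox_size * vox_coords[i] (element-wise), with the loop over voxels parallelized by Numba's prange.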
175 | """ 176 | vol_origin = vol_origin.astype(np.float32) 177 | vox_coords = vox_coords.astype(np.float32) 178 | cam_pts = np.empty_like(vox_coords, dtype=np.float32) 179 | for i in prange(vox_coords.shape[0]): 180 | for j in range(3): 181 | cam_pts[i, j] = vol_origin[j] + (vox_size * vox_coords[i, j]) 182 | return cam_pts 183 | 184 | @staticmethod 185 | @njit(parallel=True) 186 | def cam2pix(cam_pts, intr): 187 | """Convert camera coordinates to pixel coordinates. 188 | """ 189 | intr = intr.astype(np.float32) 190 | fx, fy = intr[0, 0], intr[1, 1] 191 | cx, cy = intr[0, 2], intr[1, 2] 192 | pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64) 193 | for i in prange(cam_pts.shape[0]): 194 | pix[i, 0] = int( 195 | np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx)) 196 | pix[i, 1] = int( 197 | np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy)) 198 | return pix 199 | 200 | @staticmethod 201 | @njit(parallel=True) 202 | def integrate_tsdf(tsdf_vol, dist, w_old, obs_weight): 203 | """Integrate the TSDF volume. 204 | """ 205 | tsdf_vol_int = np.empty_like(tsdf_vol, dtype=np.float32) 206 | w_new = np.empty_like(w_old, dtype=np.float32) 207 | for i in prange(len(tsdf_vol)): 208 | w_new[i] = w_old[i] + obs_weight 209 | tsdf_vol_int[i] = (w_old[i] * tsdf_vol[i] + 210 | obs_weight * dist[i]) / w_new[i] 211 | return tsdf_vol_int, w_new 212 | 213 | def integrate(self, color_im, depth_im, cam_intr, cam_pose, obs_weight=1.): 214 | """Integrate an RGB-D frame into the TSDF volume. 215 | Args: 216 | color_im (ndarray): An RGB image of shape (H, W, 3). 217 | depth_im (ndarray): A depth image of shape (H, W). 218 | cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3). 219 | cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4). 220 | obs_weight (float): The weight to assign for the current observation. 
A higher 221 | value 222 | """ 223 | im_h, im_w = depth_im.shape 224 | 225 | # Fold RGB color image into a single channel image 226 | color_im = color_im.astype(np.float32) 227 | color_im = np.floor( 228 | color_im[..., 2]*self._color_const + color_im[..., 1]*256 + color_im[..., 0]) 229 | 230 | # GPU mode: integrate voxel volume (calls CUDA kernel) 231 | if self.gpu_mode: 232 | for gpu_loop_idx in range(self._n_gpu_loops): 233 | self._cuda_integrate(self._tsdf_vol_gpu, 234 | self._weight_vol_gpu, 235 | self._color_vol_gpu, 236 | cuda.InOut( 237 | self._vol_dim.astype(np.float32)), 238 | cuda.InOut( 239 | self._vol_origin.astype(np.float32)), 240 | cuda.InOut( 241 | cam_intr.reshape(-1).astype(np.float32)), 242 | cuda.InOut( 243 | cam_pose.reshape(-1).astype(np.float32)), 244 | cuda.InOut(np.asarray([ 245 | gpu_loop_idx, 246 | self._voxel_size, 247 | im_h, 248 | im_w, 249 | self._trunc_margin, 250 | obs_weight 251 | ], np.float32)), 252 | cuda.InOut( 253 | color_im.reshape(-1).astype(np.float32)), 254 | cuda.InOut( 255 | depth_im.reshape(-1).astype(np.float32)), 256 | block=( 257 | self._max_gpu_threads_per_block, 1, 1), 258 | grid=( 259 | int(self._max_gpu_grid_dim[0]), 260 | int(self._max_gpu_grid_dim[1]), 261 | int(self._max_gpu_grid_dim[2]), 262 | ) 263 | ) 264 | else: # CPU mode: integrate voxel volume (vectorized implementation) 265 | # Convert voxel grid coordinates to pixel coordinates 266 | cam_pts = self.vox2world( 267 | self._vol_origin, self.vox_coords, self._voxel_size) 268 | cam_pts = rigid_transform(cam_pts, np.linalg.inv(cam_pose)) 269 | pix_z = cam_pts[:, 2] 270 | pix = self.cam2pix(cam_pts, cam_intr) 271 | pix_x, pix_y = pix[:, 0], pix[:, 1] 272 | 273 | # Eliminate pixels outside view frustum 274 | valid_pix = np.logical_and(pix_x >= 0, 275 | np.logical_and(pix_x < im_w, 276 | np.logical_and(pix_y >= 0, 277 | np.logical_and(pix_y < im_h, 278 | pix_z > 0)))) 279 | depth_val = np.zeros(pix_x.shape) 280 | depth_val[valid_pix] = depth_im[pix_y[valid_pix], pix_x[valid_pix]] 281 | 282 | # Integrate TSDF 283 | depth_diff = depth_val - pix_z 284 | valid_pts = np.logical_and( 285 | depth_val > 0, depth_diff >= -self._trunc_margin) 286 | dist = np.minimum(1, depth_diff / self._trunc_margin) 287 | valid_vox_x = self.vox_coords[valid_pts, 0] 288 | valid_vox_y = self.vox_coords[valid_pts, 1] 289 | valid_vox_z = self.vox_coords[valid_pts, 2] 290 | w_old = self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 291 | tsdf_vals = self._tsdf_vol_cpu[valid_vox_x, 292 | valid_vox_y, valid_vox_z] 293 | valid_dist = dist[valid_pts] 294 | tsdf_vol_new, w_new = self.integrate_tsdf( 295 | tsdf_vals, valid_dist, w_old, obs_weight) 296 | self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = w_new 297 | self._tsdf_vol_cpu[valid_vox_x, 298 | valid_vox_y, valid_vox_z] = tsdf_vol_new 299 | 300 | # Integrate color 301 | old_color = self._color_vol_cpu[valid_vox_x, 302 | valid_vox_y, valid_vox_z] 303 | old_b = np.floor(old_color / self._color_const) 304 | old_g = np.floor((old_color-old_b*self._color_const)/256) 305 | old_r = old_color - old_b*self._color_const - old_g*256 306 | new_color = color_im[pix_y[valid_pts], pix_x[valid_pts]] 307 | new_b = np.floor(new_color / self._color_const) 308 | new_g = np.floor((new_color - new_b*self._color_const) / 256) 309 | new_r = new_color - new_b*self._color_const - new_g*256 310 | new_b = np.minimum(255., np.round( 311 | (w_old*old_b + obs_weight*new_b) / w_new)) 312 | new_g = np.minimum(255., np.round( 313 | (w_old*old_g + obs_weight*new_g) / w_new)) 
314 | new_r = np.minimum(255., np.round( 315 | (w_old*old_r + obs_weight*new_r) / w_new)) 316 | self._color_vol_cpu[valid_vox_x, valid_vox_y, 317 | valid_vox_z] = new_b*self._color_const + new_g*256 + new_r 318 | 319 | def get_volume(self): 320 | if self.gpu_mode: 321 | cuda.memcpy_dtoh(self._tsdf_vol_cpu, self._tsdf_vol_gpu) 322 | cuda.memcpy_dtoh(self._color_vol_cpu, self._color_vol_gpu) 323 | return self._tsdf_vol_cpu, self._color_vol_cpu 324 | 325 | def get_point_cloud(self): 326 | """Extract a point cloud from the voxel volume. 327 | """ 328 | tsdf_vol, color_vol = self.get_volume() 329 | 330 | # Marching cubes 331 | verts = measure.marching_cubes(tsdf_vol, level=0)[0] 332 | verts_ind = np.round(verts).astype(int) 333 | verts = verts*self._voxel_size + self._vol_origin 334 | 335 | # Get vertex colors 336 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]] 337 | colors_b = np.floor(rgb_vals / self._color_const) 338 | colors_g = np.floor((rgb_vals - colors_b*self._color_const) / 256) 339 | colors_r = rgb_vals - colors_b*self._color_const - colors_g*256 340 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T 341 | colors = colors.astype(np.uint8) 342 | 343 | pc = np.hstack([verts, colors]) 344 | return pc 345 | 346 | def get_mesh(self): 347 | """Compute a mesh from the voxel volume using marching cubes. 348 | """ 349 | tsdf_vol, color_vol = self.get_volume() 350 | 351 | # Marching cubes 352 | verts, faces, norms, vals = measure.marching_cubes( 353 | tsdf_vol, level=0) 354 | verts_ind = np.round(verts).astype(int) 355 | # voxel grid coordinates to world coordinates 356 | verts = verts*self._voxel_size+self._vol_origin 357 | 358 | # Get vertex colors 359 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]] 360 | colors_b = np.floor(rgb_vals/self._color_const) 361 | colors_g = np.floor((rgb_vals-colors_b*self._color_const)/256) 362 | colors_r = rgb_vals-colors_b*self._color_const-colors_g*256 363 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T 364 | colors = colors.astype(np.uint8) 365 | return verts, faces, norms, colors 366 | 367 | 368 | def rigid_transform(xyz, transform): 369 | """Applies a rigid transform to an (N, 3) pointcloud. 370 | """ 371 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)]) 372 | xyz_t_h = np.dot(transform, xyz_h.T).T 373 | return xyz_t_h[:, :3] 374 | 375 | 376 | def meshwrite(filename, verts, faces, norms, colors): 377 | """Save a 3D mesh to a polygon .ply file. 
378 | """ 379 | # Write header 380 | ply_file = open(filename, 'w') 381 | ply_file.write("ply\n") 382 | ply_file.write("format ascii 1.0\n") 383 | ply_file.write("element vertex %d\n" % (verts.shape[0])) 384 | ply_file.write("property float x\n") 385 | ply_file.write("property float y\n") 386 | ply_file.write("property float z\n") 387 | ply_file.write("property float nx\n") 388 | ply_file.write("property float ny\n") 389 | ply_file.write("property float nz\n") 390 | ply_file.write("property uchar red\n") 391 | ply_file.write("property uchar green\n") 392 | ply_file.write("property uchar blue\n") 393 | ply_file.write("element face %d\n" % (faces.shape[0])) 394 | ply_file.write("property list uchar int vertex_index\n") 395 | ply_file.write("end_header\n") 396 | 397 | # Write vertex list 398 | for i in range(verts.shape[0]): 399 | ply_file.write("%f %f %f %f %f %f %d %d %d\n" % ( 400 | verts[i, 0], verts[i, 1], verts[i, 2], 401 | norms[i, 0], norms[i, 1], norms[i, 2], 402 | colors[i, 0], colors[i, 1], colors[i, 2], 403 | )) 404 | 405 | # Write face list 406 | for i in range(faces.shape[0]): 407 | ply_file.write("3 %d %d %d\n" % 408 | (faces[i, 0], faces[i, 1], faces[i, 2])) 409 | 410 | ply_file.close() 411 | 412 | def pcwrite(filename, xyzrgb): 413 | """Save a point cloud to a polygon .ply file. 414 | """ 415 | xyz = xyzrgb[:, :3] 416 | rgb = xyzrgb[:, 3:].astype(np.uint8) 417 | 418 | # Write header 419 | ply_file = open(filename,'w') 420 | ply_file.write("ply\n") 421 | ply_file.write("format ascii 1.0\n") 422 | ply_file.write("element vertex %d\n"%(xyz.shape[0])) 423 | ply_file.write("property float x\n") 424 | ply_file.write("property float y\n") 425 | ply_file.write("property float z\n") 426 | ply_file.write("property uchar red\n") 427 | ply_file.write("property uchar green\n") 428 | ply_file.write("property uchar blue\n") 429 | ply_file.write("end_header\n") 430 | 431 | # Write vertex list 432 | for i in range(xyz.shape[0]): 433 | ply_file.write("%f %f %f %d %d %d\n"%( 434 | xyz[i, 0], xyz[i, 1], xyz[i, 2], 435 | rgb[i, 0], rgb[i, 1], rgb[i, 2], 436 | )) 437 | 438 | def tsdf2mesh(tsdf_vol, mesh_path): 439 | tsdf_vol = np.transpose(tsdf_vol, (1, 0, 2)) 440 | # Marching cubes 441 | verts,faces,norms,vals = measure.marching_cubes(tsdf_vol,level=0.0) 442 | 443 | # Get vertex colors 444 | colors = np.array([(186//2, 176//2, 172//2) for _ in range(verts.shape[0])], dtype=np.uint8) 445 | meshwrite(mesh_path, verts, faces, norms, colors) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | from model_utils import ConvBlock3D, ResBlock3D, ConvBlock2D, MLP 5 | from se3.se3_module import SE3 6 | from forward_warp import Forward_Warp_Cupy 7 | 8 | 9 | class VolumeEncoder(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | input_channel = 12 13 | self.conv00 = ConvBlock3D(input_channel, 16, stride=2, dilation=1, norm=True, relu=True) # 64x64x24 14 | 15 | self.conv10 = ConvBlock3D(16, 32, stride=2, dilation=1, norm=True, relu=True) # 32x32x12 16 | self.conv11 = ConvBlock3D(32, 32, stride=1, dilation=1, norm=True, relu=True) 17 | self.conv12 = ConvBlock3D(32, 32, stride=1, dilation=1, norm=True, relu=True) 18 | self.conv13 = ConvBlock3D(32, 32, stride=1, dilation=1, norm=True, relu=True) 19 | 20 | self.conv20 = ConvBlock3D(32, 64, stride=2, dilation=1, norm=True, relu=True) # 16x16x6 21 | 
self.conv21 = ConvBlock3D(64, 64, stride=1, dilation=1, norm=True, relu=True) 22 | self.conv22 = ConvBlock3D(64, 64, stride=1, dilation=1, norm=True, relu=True) 23 | self.conv23 = ConvBlock3D(64, 64, stride=1, dilation=1, norm=True, relu=True) 24 | 25 | self.conv30 = ConvBlock3D(64, 128, stride=2, dilation=1, norm=True, relu=True) # 8x8x3 26 | self.resn31 = ResBlock3D(128, 128) 27 | self.resn32 = ResBlock3D(128, 128) 28 | 29 | 30 | def forward(self, x): 31 | x0 = self.conv00(x) 32 | 33 | x1 = self.conv10(x0) 34 | x1 = self.conv11(x1) 35 | x1 = self.conv12(x1) 36 | x1 = self.conv13(x1) 37 | 38 | x2 = self.conv20(x1) 39 | x2 = self.conv21(x2) 40 | x2 = self.conv22(x2) 41 | x2 = self.conv23(x2) 42 | 43 | x3 = self.conv30(x2) 44 | x3 = self.resn31(x3) 45 | x3 = self.resn32(x3) 46 | 47 | return x3, (x2, x1, x0) 48 | 49 | 50 | class FeatureDecoder(nn.Module): 51 | def __init__(self): 52 | super().__init__() 53 | self.conv00 = ConvBlock3D(128, 64, norm=True, relu=True, upsm=True) # 16x16x6 54 | self.conv01 = ConvBlock3D(64, 64, norm=True, relu=True) 55 | 56 | self.conv10 = ConvBlock3D(64 + 64, 32, norm=True, relu=True, upsm=True) # 32x32x12 57 | self.conv11 = ConvBlock3D(32, 32, norm=True, relu=True) 58 | 59 | self.conv20 = ConvBlock3D(32 + 32, 16, norm=True, relu=True, upsm=True) # 64X64X24 60 | self.conv21 = ConvBlock3D(16, 16, norm=True, relu=True) 61 | 62 | self.conv30 = ConvBlock3D(16 + 16, 8, norm=True, relu=True, upsm=True) # 128X128X48 63 | self.conv31 = ConvBlock3D(8, 8, norm=True, relu=True) 64 | 65 | def forward(self, x, cache): 66 | m0, m1, m2 = cache 67 | 68 | x0 = self.conv00(x) 69 | x0 = self.conv01(x0) 70 | 71 | x1 = self.conv10(torch.cat([x0, m0], dim=1)) 72 | x1 = self.conv11(x1) 73 | 74 | x2 = self.conv20(torch.cat([x1, m1], dim=1)) 75 | x2 = self.conv21(x2) 76 | 77 | x3 = self.conv30(torch.cat([x2, m2], dim=1)) 78 | x3 = self.conv31(x3) 79 | 80 | return x3 81 | 82 | 83 | class MotionDecoder(nn.Module): 84 | def __init__(self): 85 | super().__init__() 86 | self.conv3d00 = ConvBlock3D(8 + 8, 8, stride=2, dilation=1, norm=True, relu=True) # 64 87 | 88 | self.conv3d10 = ConvBlock3D(8 + 8, 16, stride=2, dilation=1, norm=True, relu=True) # 32 89 | 90 | self.conv3d20 = ConvBlock3D(16 + 16, 32, stride=2, dilation=1, norm=True, relu=True) # 16 91 | 92 | self.conv3d30 = ConvBlock3D(32, 16, dilation=1, norm=True, relu=True, upsm=True) # 32 93 | self.conv3d40 = ConvBlock3D(16, 8, dilation=1, norm=True, relu=True, upsm=True) # 64 94 | self.conv3d50 = ConvBlock3D(8, 8, dilation=1, norm=True, relu=True, upsm=True) # 128 95 | self.conv3d60 = nn.Conv3d(8, 3, kernel_size=3, padding=1) 96 | 97 | 98 | self.conv2d10 = ConvBlock2D(8, 64, stride=2, norm=True, relu=True) # 64 99 | self.conv2d11 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True) 100 | self.conv2d12 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True) 101 | self.conv2d13 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True) 102 | self.conv2d14 = ConvBlock2D(64, 8, stride=1, dilation=1, norm=True, relu=True) 103 | 104 | self.conv2d20 = ConvBlock2D(64, 128, stride=2, norm=True, relu=True) # 32 105 | self.conv2d21 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True) 106 | self.conv2d22 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True) 107 | self.conv2d23 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True) 108 | self.conv2d24 = ConvBlock2D(128, 16, stride=1, dilation=1, norm=True, relu=True) 109 | 110 | def forward(self, feature, action): 111 | # 
feature: [B, 8, 128, 128, 48] 112 | # action: [B, 8, 128, 128] 113 | 114 | action_embedding0 = torch.unsqueeze(action, -1).expand([-1, -1, -1, -1, 48]) 115 | feature0 = self.conv3d00(torch.cat([feature, action_embedding0], dim=1)) 116 | 117 | action1 = self.conv2d10(action) 118 | action1 = self.conv2d11(action1) 119 | action1 = self.conv2d12(action1) 120 | action1 = self.conv2d13(action1) 121 | 122 | action_embedding1 = self.conv2d14(action1) 123 | action_embedding1 = torch.unsqueeze(action_embedding1, -1).expand([-1, -1, -1, -1, 24]) 124 | feature1 = self.conv3d10(torch.cat([feature0, action_embedding1], dim=1)) 125 | 126 | action2 = self.conv2d20(action1) 127 | action2 = self.conv2d21(action2) 128 | action2 = self.conv2d22(action2) 129 | action2 = self.conv2d23(action2) 130 | 131 | action_embedding2 = self.conv2d24(action2) 132 | action_embedding2 = torch.unsqueeze(action_embedding2, -1).expand([-1, -1, -1, -1, 12]) 133 | feature2 = self.conv3d20(torch.cat([feature1, action_embedding2], dim=1)) 134 | 135 | feature3 = self.conv3d30(feature2) 136 | feature4 = self.conv3d40(feature3 + feature1) 137 | feature5 = self.conv3d50(feature4 + feature0) 138 | 139 | motion_pred = self.conv3d60(feature5) 140 | 141 | return motion_pred 142 | 143 | class MaskDecoder(nn.Module): 144 | def __init__(self, K): 145 | super().__init__() 146 | self.decoder = nn.Conv3d(8, K, kernel_size=1) 147 | 148 | def forward(self, x): 149 | logit = self.decoder(x) 150 | mask = torch.softmax(logit, dim=1) 151 | return logit, mask 152 | 153 | 154 | class TransformDecoder(nn.Module): 155 | def __init__(self, transform_type, object_num): 156 | super().__init__() 157 | num_params_dict = { 158 | 'affine': 12, 159 | 'se3euler': 6, 160 | 'se3aa': 6, 161 | 'se3spquat': 6, 162 | 'se3quat': 7 163 | } 164 | self.num_params = num_params_dict[transform_type] 165 | self.object_num = object_num 166 | 167 | self.conv3d00 = ConvBlock3D(8 + 8, 8, stride=2, dilation=1, norm=True, relu=True) # 64 168 | 169 | self.conv3d10 = ConvBlock3D(8 + 8, 16, stride=2, dilation=1, norm=True, relu=True) # 32 170 | 171 | self.conv3d20 = ConvBlock3D(16 + 16, 32, stride=2, dilation=1, norm=True, relu=True) # 16 172 | self.conv3d21 = ConvBlock3D(32, 32, stride=1, dilation=1, norm=True, relu=True) 173 | self.conv3d22 = ConvBlock3D(32, 32, stride=1, dilation=1, norm=True, relu=True) 174 | self.conv3d23 = ConvBlock3D(32, 64, stride=1, dilation=1, norm=True, relu=True) 175 | 176 | self.conv3d30 = ConvBlock3D(64, 128, stride=2, dilation=1, norm=True, relu=True) # 8 177 | 178 | self.conv3d40 = ConvBlock3D(128, 128, stride=2, dilation=1, norm=True, relu=True) # 4 179 | 180 | self.conv3d50 = nn.Conv3d(128, 128, kernel_size=(4, 4, 2)) 181 | 182 | 183 | self.conv2d10 = ConvBlock2D(8, 64, stride=2, norm=True, relu=True) # 64 184 | self.conv2d11 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True) 185 | self.conv2d12 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True) 186 | self.conv2d13 = ConvBlock2D(64, 64, stride=1, dilation=1, norm=True, relu=True) 187 | self.conv2d14 = ConvBlock2D(64, 8, stride=1, dilation=1, norm=True, relu=True) 188 | 189 | self.conv2d20 = ConvBlock2D(64, 128, stride=2, norm=True, relu=True) # 32 190 | self.conv2d21 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True) 191 | self.conv2d22 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True) 192 | self.conv2d23 = ConvBlock2D(128, 128, stride=1, dilation=1, norm=True, relu=True) 193 | self.conv2d24 = ConvBlock2D(128, 16, stride=1, dilation=1, 
norm=True, relu=True) 194 | 195 | self.mlp = MLP( 196 | input_dim=128, 197 | output_dim=self.num_params * self.object_num, 198 | hidden_sizes=[512, 512, 512, 512], 199 | hidden_nonlinearity=F.leaky_relu 200 | ) 201 | 202 | 203 | def forward(self, feature, action): 204 | # feature: [B, 8, 128, 128, 48] 205 | # action: [B, 8, 128, 128] 206 | 207 | action_embedding0 = torch.unsqueeze(action, -1).expand([-1, -1, -1, -1, 48]) 208 | feature0 = self.conv3d00(torch.cat([feature, action_embedding0], dim=1)) 209 | 210 | action1 = self.conv2d10(action) 211 | action1 = self.conv2d11(action1) 212 | action1 = self.conv2d12(action1) 213 | action1 = self.conv2d13(action1) 214 | 215 | action_embedding1 = self.conv2d14(action1) 216 | action_embedding1 = torch.unsqueeze(action_embedding1, -1).expand([-1, -1, -1, -1, 24]) 217 | feature1 = self.conv3d10(torch.cat([feature0, action_embedding1], dim=1)) 218 | 219 | action2 = self.conv2d20(action1) 220 | action2 = self.conv2d21(action2) 221 | action2 = self.conv2d22(action2) 222 | action2 = self.conv2d23(action2) 223 | 224 | action_embedding2 = self.conv2d24(action2) 225 | action_embedding2 = torch.unsqueeze(action_embedding2, -1).expand([-1, -1, -1, -1, 12]) 226 | feature2 = self.conv3d20(torch.cat([feature1, action_embedding2], dim=1)) 227 | feature2 = self.conv3d21(feature2) 228 | feature2 = self.conv3d22(feature2) 229 | feature2 = self.conv3d23(feature2) 230 | 231 | feature3 = self.conv3d30(feature2) 232 | feature4 = self.conv3d40(feature3) 233 | feature5 = self.conv3d50(feature4) 234 | 235 | params = self.mlp(feature5.view([-1, 128])) 236 | params = params.view([-1, self.object_num, self.num_params]) 237 | 238 | return params 239 | 240 | 241 | class ModelDSR(nn.Module): 242 | def __init__(self, object_num=5, transform_type='se3euler', motion_type='se3'): 243 | # transform_type options: None, 'affine', 'se3euler', 'se3aa', 'se3quat', 'se3spquat' 244 | # motion_type options: 'se3', 'conv' 245 | # input volume size: [128, 128, 48] 246 | 247 | super().__init__() 248 | self.transform_type = transform_type 249 | self.K = object_num 250 | self.motion_type = motion_type 251 | 252 | # modules 253 | self.forward_warp = Forward_Warp_Cupy.apply 254 | self.volume_encoder = VolumeEncoder() 255 | self.feature_decoder = FeatureDecoder() 256 | if self.motion_type == 'se3': 257 | self.mask_decoder = MaskDecoder(self.K) 258 | self.transform_decoder = TransformDecoder( 259 | transform_type=self.transform_type, 260 | object_num=self.K - 1 261 | ) 262 | self.se3 = SE3(self.transform_type) 263 | elif self.motion_type == 'conv': 264 | self.motion_decoder = MotionDecoder() 265 | else: 266 | raise ValueError('motion_type doesn\'t support ', self.motion_type) 267 | 268 | # initialization 269 | for m in self.named_modules(): 270 | if isinstance(m[1], nn.Conv3d) or isinstance(m[1], nn.Conv2d): 271 | nn.init.kaiming_normal_(m[1].weight.data) 272 | elif isinstance(m[1], nn.BatchNorm3d) or isinstance(m[1], nn.BatchNorm2d): 273 | m[1].weight.data.fill_(1) 274 | m[1].bias.data.zero_() 275 | 276 | # const value 277 | self.grids = torch.stack(torch.meshgrid( 278 | torch.linspace(0, 127, 128), 279 | torch.linspace(0, 127, 128), 280 | torch.linspace(0, 47, 48) 281 | )) 282 | self.coord_feature = self.grids / torch.tensor([128, 128, 48]).view([3, 1, 1, 1]) 283 | self.grids_flat = self.grids.view(1, 1, 3, 128 * 128 * 48) 284 | self.zero_vec = torch.zeros([1, 1, 3], dtype=torch.float) 285 | self.eye_mat = torch.eye(3, dtype=torch.float) 286 | 287 | def forward(self, input_volume, last_s=None, 
input_action=None, input_motion=None, next_mask=False, no_warp=False): 288 | B, _, S1, S2, S3 = input_volume.size() 289 | K = self.K 290 | device = input_volume.device 291 | output = {} 292 | 293 | input = torch.cat((input_volume, self.coord_feature.expand(B, -1, -1, -1, -1).to(device)), dim=1) 294 | input = torch.cat((input, last_s), dim=1) # aggregate history 295 | 296 | volume_embedding, cache = self.volume_encoder(input) 297 | mask_feature = self.feature_decoder(volume_embedding, cache) 298 | 299 | if self.motion_type == 'conv': 300 | motion = self.motion_decoder(mask_feature, input_action) 301 | output['motion'] = motion 302 | 303 | return output 304 | 305 | 306 | assert(self.motion_type == 'se3') 307 | logit, mask = self.mask_decoder(mask_feature) 308 | output['init_logit'] = logit 309 | transform_param = self.transform_decoder(mask_feature, input_action) 310 | 311 | # trans, pivot: [B, K-1, 3] 312 | # rot_matrix: [B, K-1, 3, 3] 313 | trans_vec, rot_mat = self.se3(transform_param) 314 | mask_object = torch.narrow(mask, 1, 0, K - 1) 315 | sum_mask = torch.sum(mask_object, dim=(2, 3, 4)) 316 | heatmap = torch.unsqueeze(mask_object, dim=2) * self.grids.to(device) 317 | pivot_vec = torch.sum(heatmap, dim=(3, 4, 5)) / torch.unsqueeze(sum_mask, dim=2) 318 | 319 | # [Important] The last one is the background! 320 | trans_vec = torch.cat([trans_vec, self.zero_vec.expand(B, -1, -1).to(device)], dim=1).unsqueeze(-1) 321 | rot_mat = torch.cat([rot_mat, self.eye_mat.expand(B, 1, -1, -1).to(device)], dim=1) 322 | pivot_vec = torch.cat([pivot_vec, self.zero_vec.expand(B, -1, -1).to(device)], dim=1).unsqueeze(-1) 323 | 324 | grids_flat = self.grids_flat.to(device) 325 | grids_after_flat = rot_mat @ (grids_flat - pivot_vec) + pivot_vec + trans_vec 326 | motion = (grids_after_flat - grids_flat).view([B, K, 3, S1, S2, S3]) 327 | 328 | motion = torch.sum(motion * torch.unsqueeze(mask, 2), 1) 329 | 330 | output['motion'] = motion 331 | 332 | if no_warp: 333 | output['s'] = mask_feature 334 | elif input_motion is not None: 335 | mask_feature_warp = self.forward_warp( 336 | mask_feature, 337 | input_motion, 338 | torch.sum(mask[:, :-1, ], dim=1) 339 | ) 340 | output['s'] = mask_feature_warp 341 | else: 342 | mask_feature_warp = self.forward_warp( 343 | mask_feature, 344 | motion, 345 | torch.sum(mask[:, :-1, ], dim=1) 346 | ) 347 | output['s'] = mask_feature_warp 348 | 349 | if next_mask: 350 | mask_warp = self.forward_warp( 351 | mask, 352 | motion, 353 | torch.sum(mask[:, :-1, ], dim=1) 354 | ) 355 | output['next_mask'] = mask_warp 356 | 357 | return output 358 | 359 | 360 | 361 | def get_init_repr(self, batch_size): 362 | return torch.zeros([batch_size, 8, 128, 128, 48], dtype=torch.float) 363 | 364 | 365 | if __name__=='__main__': 366 | torch.cuda.set_device(4) 367 | model = ModelDSR( 368 | object_num=2, 369 | transform_type='se3euler', 370 | with_history=True, 371 | motion_type='se3' 372 | ).cuda() 373 | 374 | input_volume = torch.rand((4, 1, 128, 128, 48)).cuda() 375 | input_action = torch.rand((4, 8, 128, 128)).cuda() 376 | last_s = model.get_init_repr(4).cuda() 377 | 378 | output = model(input_volume=input_volume, last_s=last_s, input_action=input_action, next_mask=True) 379 | 380 | for k in output.keys(): 381 | print(k, output[k].size()) 382 | 383 | 384 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import 
torch.nn.functional as F 4 | 5 | __all__ = ['ConvBlock2D', 'ConvBlock3D', 'ResBlock2D', 'ResBlock3D', 'MLP'] 6 | 7 | 8 | class ConvBlock2D(nn.Module): 9 | def __init__(self, inplanes, planes, stride=1, dilation=1, norm=False, relu=False, pool=False, upsm=False): 10 | super().__init__() 11 | self.conv = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=not norm) 12 | self.norm = nn.BatchNorm2d(planes) if norm else None 13 | self.relu = nn.LeakyReLU(inplace=True) if relu else None 14 | self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if pool else None 15 | self.upsm = upsm 16 | 17 | def forward(self, x): 18 | out = self.conv(x) 19 | 20 | out = out if self.norm is None else self.norm(out) 21 | out = out if self.relu is None else self.relu(out) 22 | out = out if self.pool is None else self.pool(out) 23 | out = out if not self.upsm else F.interpolate(out, scale_factor=2, mode='bilinear', align_corners=True) 24 | 25 | return out 26 | 27 | 28 | class ConvBlock3D(nn.Module): 29 | def __init__(self, inplanes, planes, stride=1, dilation=1, norm=False, relu=False, pool=False, upsm=False): 30 | super().__init__() 31 | 32 | self.conv = nn.Conv3d(inplanes, planes, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=not norm) 33 | self.norm = nn.BatchNorm3d(planes) if norm else None 34 | self.relu = nn.LeakyReLU(inplace=True) if relu else None 35 | self.pool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) if pool else None 36 | self.upsm = upsm 37 | 38 | def forward(self, x): 39 | out = self.conv(x) 40 | 41 | out = out if self.norm is None else self.norm(out) 42 | out = out if self.relu is None else self.relu(out) 43 | out = out if self.pool is None else self.pool(out) 44 | out = out if not self.upsm else F.interpolate(out, scale_factor=2, mode='trilinear', align_corners=True) 45 | 46 | return out 47 | 48 | 49 | class ResBlock2D(nn.Module): 50 | def __init__(self, inplanes, planes, downsample=None, bias=False): 51 | super().__init__() 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1, bias=bias) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.relu = nn.LeakyReLU(inplace=True) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=bias) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.downsample = downsample 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | 69 | if self.downsample is not None: 70 | residual = self.downsample(x) 71 | 72 | out += residual 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | 78 | class ResBlock3D(nn.Module): 79 | def __init__(self, inplanes, planes, downsample=None): 80 | super().__init__() 81 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=3, stride=1, padding=1, bias=False) 82 | self.bn1 = nn.BatchNorm3d(planes) 83 | self.relu = nn.LeakyReLU(inplace=True) 84 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn2 = nn.BatchNorm3d(planes) 86 | self.downsample = downsample 87 | 88 | def forward(self, x): 89 | residual = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class 
MLP(nn.Module): 108 | """ 109 | MLP Model. 110 | Args: 111 | input_dim (int) : Dimension of the network input. 112 | output_dim (int): Dimension of the network output. 113 | hidden_sizes (list[int]): Output dimension of dense layer(s). 114 | For example, (32, 32) means this MLP consists of two 115 | hidden layers, each with 32 hidden units. 116 | hidden_nonlinearity (callable): Activation function for intermediate 117 | dense layer(s). It should return a torch.Tensor. Set it to 118 | None to maintain a linear activation. 119 | hidden_w_init (callable): Initializer function for the weight 120 | of intermediate dense layer(s). The function should return a 121 | torch.Tensor. 122 | hidden_b_init (callable): Initializer function for the bias 123 | of intermediate dense layer(s). The function should return a 124 | torch.Tensor. 125 | output_nonlinearity (callable): Activation function for output dense 126 | layer. It should return a torch.Tensor. Set it to None to 127 | maintain a linear activation. 128 | output_w_init (callable): Initializer function for the weight 129 | of output dense layer(s). The function should return a 130 | torch.Tensor. 131 | output_b_init (callable): Initializer function for the bias 132 | of output dense layer(s). The function should return a 133 | torch.Tensor. 134 | layer_normalization (bool): Bool for using layer normalization or not. 135 | Return: 136 | The output torch.Tensor of the MLP 137 | """ 138 | 139 | def __init__(self, 140 | input_dim, 141 | output_dim, 142 | hidden_sizes, 143 | hidden_nonlinearity=F.relu, 144 | hidden_w_init=nn.init.xavier_normal_, 145 | hidden_b_init=nn.init.zeros_, 146 | output_nonlinearity=None, 147 | output_w_init=nn.init.xavier_normal_, 148 | output_b_init=nn.init.zeros_, 149 | layer_normalization=False): 150 | super().__init__() 151 | 152 | self._input_dim = input_dim 153 | self._output_dim = output_dim 154 | self._hidden_nonlinearity = hidden_nonlinearity 155 | self._output_nonlinearity = output_nonlinearity 156 | self._layer_normalization = layer_normalization 157 | self._layers = nn.ModuleList() 158 | 159 | prev_size = input_dim 160 | 161 | for size in hidden_sizes: 162 | layer = nn.Linear(prev_size, size) 163 | hidden_w_init(layer.weight) 164 | hidden_b_init(layer.bias) 165 | self._layers.append(layer) 166 | prev_size = size 167 | 168 | layer = nn.Linear(prev_size, output_dim) 169 | output_w_init(layer.weight) 170 | output_b_init(layer.bias) 171 | self._layers.append(layer) 172 | 173 | def forward(self, input_val): 174 | """Forward method.""" 175 | B = input_val.size(0) 176 | x = input_val.view(B, -1) 177 | for layer in self._layers[:-1]: 178 | x = layer(x) 179 | if self._hidden_nonlinearity is not None: 180 | x = self._hidden_nonlinearity(x) 181 | if self._layer_normalization: 182 | x = nn.LayerNorm(x.shape[1])(x) 183 | 184 | x = self._layers[-1](x) 185 | if self._output_nonlinearity is not None: 186 | x = self._output_nonlinearity(x) 187 | 188 | return x -------------------------------------------------------------------------------- /object_models/README.md: -------------------------------------------------------------------------------- 1 | # Object Models 2 | 3 | ### Object meshes are used for data generation in simulation. 4 | - Download object meshes: [shapenet](https://dsr-net.cs.columbia.edu/download/object_models/shapenet.zip) and [ycb](https://dsr-net.cs.columbia.edu/download/object_models/ycb.zip). 5 | - Unzip `shapenet.zip` and `ycb.zip` in this folder. 
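After unzipping, the simulation code in `sim_env.py` looks for the meshes under `object_models/shapenet` and `object_models/ycb`, with the usable object ids listed in `assets/object_id/shapenet_id.json` and `assets/object_id/ycb_id.json`. A quick sanity check (a minimal sketch only; the per-object subfolder names come from those JSON files):

```python
# Verify that the mesh folders and id lists referenced by sim_env.py are in place.
import json
import os

for name in ('shapenet', 'ycb'):
    mesh_dir = os.path.join('object_models', name)
    with open(os.path.join('assets', 'object_id', name + '_id.json')) as f:
        ids = json.load(f)
    print(name, '| folder exists:', os.path.isdir(mesh_dir), '| ids listed:', len(ids))
```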
-------------------------------------------------------------------------------- /pretrained_models/3dflow.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/3dflow.pth -------------------------------------------------------------------------------- /pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | # Pretrained Models 2 | 3 | ### The following pretrained models are provided: 4 | - [dsr](dsr.pth): DSR-Net introduced in the paper. (without real data finetuning) 5 | - [dsr_ft](dsr_ft.pth): DSR-Net introduced in the paper. (with real data finetuning) 6 | - [single](single.pth): It does not use any history aggregation. 7 | - [nowarp](nowarp.pth): It does not warp the representation before aggregation. 8 | - [gtwarp](gtwarp.pth): It warps the representation with ground truth motion (i.e., performance oracle) 9 | - [3dflow](3dflow.pth): It predicts per-voxel scene flow for the entire 3D volume. 10 | -------------------------------------------------------------------------------- /pretrained_models/dsr.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/dsr.pth -------------------------------------------------------------------------------- /pretrained_models/dsr_ft.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/dsr_ft.pth -------------------------------------------------------------------------------- /pretrained_models/gtwarp.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/gtwarp.pth -------------------------------------------------------------------------------- /pretrained_models/nowarp.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/nowarp.pth -------------------------------------------------------------------------------- /pretrained_models/single.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/pretrained_models/single.pth -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | appdirs==1.4.4 3 | cachetools==4.1.1 4 | certifi==2020.6.20 5 | chardet==3.0.4 6 | cupy==7.1.1 7 | cycler==0.10.0 8 | decorator==4.4.2 9 | dominate==2.4.0 10 | fastrlock==0.5 11 | google-auth==1.19.2 12 | google-auth-oauthlib==0.4.1 13 | grpcio==1.30.0 14 | h5py==2.10.0 15 | idna==2.10 16 | imageio==2.6.1 17 | importlib-metadata==1.7.0 18 | kiwisolver==1.2.0 19 | llvmlite==0.33.0 20 | Mako==1.1.3 21 | Markdown==3.2.2 22 | MarkupSafe==1.1.1 23 | matplotlib==3.3.0 24 | networkx==2.4 25 | numba==0.50.0 26 | numpy==1.18.1 27 | oauthlib==3.1.0 28 | opencv-python==4.2.0.32 29 | Pillow==8.1.1 30 | pkg-resources==0.0.0 31 | protobuf==3.12.2 32 | pyasn1==0.4.8 33 | 
pyasn1-modules==0.2.8 34 | pybullet==2.6.4 35 | pycuda==2019.1.2 36 | pynvrtc==9.2 37 | pyparsing==2.4.7 38 | python-dateutil==2.8.1 39 | pytools==2020.3.1 40 | PyWavelets==1.1.1 41 | requests==2.24.0 42 | requests-oauthlib==1.3.0 43 | rsa==4.6 44 | scikit-image==0.16.2 45 | scipy==1.5.1 46 | six==1.15.0 47 | tensorboard==2.2.2 48 | tensorboard-plugin-wit==1.7.0 49 | torch==1.4.0 50 | tqdm==4.48.0 51 | urllib3==1.25.9 52 | Werkzeug==1.0.1 53 | zipp==3.1.0 54 | -------------------------------------------------------------------------------- /se3/__pycache__/se3_module.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3_module.cpython-36.pyc -------------------------------------------------------------------------------- /se3/__pycache__/se3_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3_utils.cpython-36.pyc -------------------------------------------------------------------------------- /se3/__pycache__/se3aa.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3aa.cpython-36.pyc -------------------------------------------------------------------------------- /se3/__pycache__/se3euler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3euler.cpython-36.pyc -------------------------------------------------------------------------------- /se3/__pycache__/se3quat.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3quat.cpython-36.pyc -------------------------------------------------------------------------------- /se3/__pycache__/se3spquat.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/dsr/d7560e09f67fc997546c9d735bdcc913e8b5ea79/se3/__pycache__/se3spquat.cpython-36.pyc -------------------------------------------------------------------------------- /se3/se3_module.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules import Module 2 | from se3.se3spquat import Se3spquat 3 | from se3.se3quat import Se3quat 4 | from se3.se3euler import Se3euler 5 | from se3.se3aa import Se3aa 6 | 7 | class SE3(Module): 8 | def __init__(self, transform_type='affine', has_pivot=False): 9 | super().__init__() 10 | rot_param_num_dict = { 11 | 'affine': 9, 12 | 'se3euler': 3, 13 | 'se3aa': 3, 14 | 'se3spquat': 3, 15 | 'se3quat': 4 16 | } 17 | self.transform_type = transform_type 18 | self.rot_param_num = rot_param_num_dict[transform_type] 19 | self.has_pivot = has_pivot 20 | self.num_param = rot_param_num_dict[transform_type] + 3 21 | if self.has_pivot: 22 | self.num_param += 3 23 | 24 | def forward(self, input): 25 | B, K, L = input.size() 26 | if L != self.num_param: 27 | raise ValueError('Dimension Error!') 28 | 29 | trans_vec = input.narrow(2, 0, 3) 30 | rot_params = 
input.narrow(2, 3, self.rot_param_num) 31 | if self.has_pivot: 32 | pivot_vec = input.narrow(2, 3 + self.rot_param_num, 3) 33 | 34 | 35 | if self.transform_type == 'affine': 36 | rot_mat = rot_params.view(B, K, 3, 3) 37 | elif self.transform_type == 'se3euler': 38 | rot_mat = Se3euler.apply(rot_params) 39 | elif self.transform_type == 'se3aa': 40 | rot_mat = Se3aa.apply(rot_params) 41 | elif self.transform_type == 'se3spquat': 42 | rot_mat = Se3spquat.apply(rot_params) 43 | elif self.transform_type == 'se3quat': 44 | rot_mat = Se3quat.apply(rot_params) 45 | 46 | if self.has_pivot: 47 | return trans_vec, rot_mat, pivot_vec 48 | else: 49 | return trans_vec, rot_mat -------------------------------------------------------------------------------- /se3/se3_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # Rotation about the X-axis by theta 5 | # From Barfoot's book: http://asrl.utias.utoronto.ca/~tdb/bib/barfoot_ser15.pdf (6.7) 6 | def create_rotx(theta): 7 | N = theta.size(0) 8 | rot = torch.eye(3).type_as(theta).view(1, 3, 3).repeat(N, 1, 1) 9 | rot[:, 1, 1] = torch.cos(theta) 10 | rot[:, 2, 2] = rot[:, 1, 1] 11 | rot[:, 1, 2] = torch.sin(theta) 12 | rot[:, 2, 1] = -rot[:, 1, 2] 13 | return rot 14 | 15 | 16 | # Rotation about the Y-axis by theta 17 | # From Barfoot's book: http://asrl.utias.utoronto.ca/~tdb/bib/barfoot_ser15.pdf (6.6) 18 | def create_roty(theta): 19 | N = theta.size(0) 20 | rot = torch.eye(3).type_as(theta).view(1, 3, 3).repeat(N, 1, 1) 21 | rot[:, 0, 0] = torch.cos(theta) 22 | rot[:, 2, 2] = rot[:, 0, 0] 23 | rot[:, 2, 0] = torch.sin(theta) 24 | rot[:, 0, 2] = -rot[:, 2, 0] 25 | return rot 26 | 27 | 28 | # Rotation about the Z-axis by theta 29 | # From Barfoot's book: http://asrl.utias.utoronto.ca/~tdb/bib/barfoot_ser15.pdf (6.5) 30 | def create_rotz(theta): 31 | N = theta.size(0) 32 | rot = torch.eye(3).type_as(theta).view(1, 3, 3).repeat(N, 1, 1) 33 | rot[:, 0, 0] = torch.cos(theta) 34 | rot[:, 1, 1] = rot[:, 0, 0] 35 | rot[:, 0, 1] = torch.sin(theta) 36 | rot[:, 1, 0] = -rot[:, 0, 1] 37 | return rot 38 | 39 | 40 | # Create a skew-symmetric matrix "S" of size [B x 3 x 3] (passed in) given a [B x 3] vector 41 | def create_skew_symmetric_matrix(vector): 42 | # Create the skew symmetric matrix: 43 | # [0 -z y; z 0 -x; -y x 0] 44 | N = vector.size(0) 45 | vec = vector.contiguous().view(N, 3) 46 | output = vec.new().resize_(N, 3, 3).fill_(0) 47 | output[:, 0, 1] = -vec[:, 2] 48 | output[:, 1, 0] = vec[:, 2] 49 | output[:, 0, 2] = vec[:, 1] 50 | output[:, 2, 0] = -vec[:, 1] 51 | output[:, 1, 2] = -vec[:, 0] 52 | output[:, 2, 1] = vec[:, 0] 53 | return output 54 | 55 | 56 | # Compute the rotation matrix R from a set of unit-quaternions (N x 4): 57 | # From: http://www.tech.plymouth.ac.uk/sme/springerusv/2011/publications_files/Terzakis%20et%20al%202012,%20A%20Recipe%20on%20the%20Parameterization%20of%20Rotation%20Matrices...MIDAS.SME.2012.TR.004.pdf (Eqn 9) 58 | def create_rot_from_unitquat(unitquat): 59 | # Init memory 60 | N = unitquat.size(0) 61 | rot = unitquat.new_zeros([N, 3, 3]) 62 | 63 | # Get quaternion elements. 
Quat = [qx,qy,qz,qw] with the scalar at the rear 64 | x, y, z, w = unitquat[:, 0], unitquat[:, 1], unitquat[:, 2], unitquat[:, 3] 65 | x2, y2, z2, w2 = x * x, y * y, z * z, w * w 66 | 67 | # Row 1 68 | rot[:, 0, 0] = w2 + x2 - y2 - z2 # rot(0,0) = w^2 + x^2 - y^2 - z^2 69 | rot[:, 0, 1] = 2 * (x * y - w * z) # rot(0,1) = 2*x*y - 2*w*z 70 | rot[:, 0, 2] = 2 * (x * z + w * y) # rot(0,2) = 2*x*z + 2*w*y 71 | 72 | # Row 2 73 | rot[:, 1, 0] = 2 * (x * y + w * z) # rot(1,0) = 2*x*y + 2*w*z 74 | rot[:, 1, 1] = w2 - x2 + y2 - z2 # rot(1,1) = w^2 - x^2 + y^2 - z^2 75 | rot[:, 1, 2] = 2 * (y * z - w * x) # rot(1,2) = 2*y*z - 2*w*x 76 | 77 | # Row 3 78 | rot[:, 2, 0] = 2 * (x * z - w * y) # rot(2,0) = 2*x*z - 2*w*y 79 | rot[:, 2, 1] = 2 * (y * z + w * x) # rot(2,1) = 2*y*z + 2*w*x 80 | rot[:, 2, 2] = w2 - x2 - y2 + z2 # rot(2,2) = w^2 - x^2 - y^2 + z^2 81 | 82 | return rot 83 | 84 | 85 | # Compute the derivatives of the rotation matrix w.r.t the unit quaternion 86 | # From: http://www.tech.plymouth.ac.uk/sme/springerusv/2011/publications_files/Terzakis%20et%20al%202012,%20A%20Recipe%20on%20the%20Parameterization%20of%20Rotation%20Matrices...MIDAS.SME.2012.TR.004.pdf (Eqn 33-36) 87 | def compute_grad_rot_wrt_unitquat(unitquat): 88 | # Compute dR/dq' (9x4 matrix) 89 | N = unitquat.size(0) 90 | x, y, z, w = unitquat.narrow(1, 0, 1), unitquat.narrow(1, 1, 1), unitquat.narrow(1, 2, 1), unitquat.narrow(1, 3, 1) 91 | dRdqh_w = 2 * torch.cat([w, -z, y, z, w, -x, -y, x, w], 1).view(N, 9, 1) # Eqn 33, rows first 92 | dRdqh_x = 2 * torch.cat([x, y, z, y, -x, -w, z, w, -x], 1).view(N, 9, 1) # Eqn 34, rows first 93 | dRdqh_y = 2 * torch.cat([-y, x, w, x, y, z, -w, z, -y], 1).view(N, 9, 1) # Eqn 35, rows first 94 | dRdqh_z = 2 * torch.cat([-z, -w, x, w, -z, y, x, y, z], 1).view(N, 9, 1) # Eqn 36, rows first 95 | dRdqh = torch.cat([dRdqh_x, dRdqh_y, dRdqh_z, dRdqh_w], 2) # N x 9 x 4 96 | 97 | return dRdqh 98 | 99 | 100 | # Compute the derivatives of a unit quaternion w.r.t a quaternion 101 | def compute_grad_unitquat_wrt_quat(unitquat, quat): 102 | # Compute the quaternion norms 103 | N = quat.size(0) 104 | unitquat_v = unitquat.view(-1, 4, 1) 105 | norm2 = (quat * quat).sum(1) # Norm-squared 106 | norm = torch.sqrt(norm2) # Length of the quaternion 107 | 108 | # Compute gradient dq'/dq 109 | # TODO: No check for normalization issues currently 110 | I = torch.eye(4).view(1, 4, 4).expand(N, 4, 4).type_as(quat) 111 | qQ = torch.bmm(unitquat_v, unitquat_v.transpose(1, 2)) # q'*q'^T 112 | dqhdq = (I - qQ) / (norm.view(N, 1, 1).expand_as(I)) 113 | 114 | return dqhdq 115 | 116 | 117 | # Compute the derivatives of a unit quaternion w.r.t a SP quaternion 118 | # From: http://www.tech.plymouth.ac.uk/sme/springerusv/2011/publications_files/Terzakis%20et%20al%202012,%20A%20Recipe%20on%20the%20Parameterization%20of%20Rotation%20Matrices...MIDAS.SME.2012.TR.004.pdf (Eqn 42-45) 119 | def compute_grad_unitquat_wrt_spquat(spquat): 120 | # Compute scalars 121 | N = spquat.size(0) 122 | x, y, z = spquat.narrow(1, 0, 1), spquat.narrow(1, 1, 1), spquat.narrow(1, 2, 1) 123 | x2, y2, z2 = x * x, y * y, z * z 124 | s = 1 + x2 + y2 + z2 # 1 + x^2 + y^2 + z^2 = 1 + alpha^2 125 | s2 = (s * s).expand(N, 4) # (1 + alpha^2)^2 126 | 127 | # Compute gradient dq'/dspq 128 | dqhdspq_x = (torch.cat([2 * s - 4 * x2, -4 * x * y, -4 * x * z, -4 * x], 1) / s2).view(N, 4, 1) 129 | dqhdspq_y = (torch.cat([-4 * x * y, 2 * s - 4 * y2, -4 * y * z, -4 * y], 1) / s2).view(N, 4, 1) 130 | dqhdspq_z = (torch.cat([-4 * x * z, -4 * y * z, 2 * s - 4 * z2, -4 * z], 
1) / s2).view(N, 4, 1) 131 | dqhdspq = torch.cat([dqhdspq_x, dqhdspq_y, dqhdspq_z], 2) 132 | 133 | return dqhdspq 134 | 135 | 136 | # Compute Unit Quaternion from SP-Quaternion 137 | def create_unitquat_from_spquat(spquat): 138 | N = spquat.size(0) 139 | unitquat = spquat.new_zeros([N, 4]) 140 | x, y, z = spquat[:, 0], spquat[:, 1], spquat[:, 2] 141 | alpha2 = x * x + y * y + z * z # x^2 + y^2 + z^2 142 | unitquat[:, 0] = (2 * x) / (1 + alpha2) # qx 143 | unitquat[:, 1] = (2 * y) / (1 + alpha2) # qy 144 | unitquat[:, 2] = (2 * z) / (1 + alpha2) # qz 145 | unitquat[:, 3] = (1 - alpha2) / (1 + alpha2) # qw 146 | 147 | return unitquat 148 | -------------------------------------------------------------------------------- /se3/se3aa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import torch.nn.functional as F 4 | from se3.se3_utils import create_skew_symmetric_matrix 5 | 6 | 7 | class Se3aa(Function): 8 | @staticmethod 9 | def forward(ctx, input): 10 | batch_size, num_se3, num_params = input.size() 11 | N = batch_size * num_se3 12 | eps = 1e-12 13 | 14 | rot_params = input.view(batch_size * num_se3, -1) 15 | 16 | # Get the un-normalized axis and angle 17 | axis = rot_params.view(N, 3, 1) # Un-normalized axis 18 | angle2 = (axis * axis).sum(1).view(N, 1, 1) # Norm of vector (squared angle) 19 | angle = torch.sqrt(angle2) # Angle 20 | 21 | # Compute skew-symmetric matrix "K" from the axis of rotation 22 | K = create_skew_symmetric_matrix(axis) 23 | K2 = torch.bmm(K, K) # K * K 24 | 25 | # Compute sines 26 | S = torch.sin(angle) / angle 27 | S.masked_fill_(angle2.lt(eps), 1) # sin(0)/0 ~= 1 28 | 29 | # Compute cosines 30 | C = (1 - torch.cos(angle)) / angle2 31 | C.masked_fill_(angle2.lt(eps), 0) # (1 - cos(0))/0^2 ~= 0 32 | 33 | # Compute the rotation matrix: R = I + (sin(theta)/theta)*K + ((1-cos(theta))/theta^2) * K^2 34 | rot = torch.eye(3).view(1, 3, 3).repeat(N, 1, 1).type_as(rot_params) # R = I 35 | rot += K * S.expand(N, 3, 3) # R = I + (sin(theta)/theta)*K 36 | rot += K2 * C.expand(N, 3, 3) # R = I + (sin(theta)/theta)*K + ((1-cos(theta))/theta^2)*K^2 37 | 38 | ctx.save_for_backward(input, rot) 39 | 40 | return rot.view(batch_size, num_se3, 3, 3) 41 | 42 | @staticmethod 43 | def backward(ctx, grad_output): 44 | input, rot = ctx.saved_tensors 45 | batch_size, num_se3, num_params = input.size() 46 | N = batch_size * num_se3 47 | eps = 1e-12 48 | grad_output =grad_output.contiguous().view(N, 3, 3) 49 | 50 | rot_params = input.view(batch_size * num_se3, -1) 51 | 52 | axis = rot_params.view(N, 3, 1) # Un-normalized axis 53 | angle2 = (axis * axis).sum(1) # (Bk) x 1 x 1 => Norm of the vector (squared angle) 54 | nSmall = angle2.lt(eps).sum() # Num angles less than threshold 55 | 56 | # Compute: v x (Id - R) for all the columns of (Id-R) 57 | I = torch.eye(3).type_as(input).repeat(N, 1, 1).add(-1, rot) # (Bk) x 3 x 3 => Id - R 58 | vI = torch.cross(axis.expand_as(I), I, 1) # (Bk) x 3 x 3 => v x (Id - R) 59 | 60 | # Compute [v * v' + v x (Id - R)] / ||v||^2 61 | vV = torch.bmm(axis, axis.transpose(1, 2)) # (Bk) x 3 x 3 => v * v' 62 | vV = (vV + vI) / (angle2.view(N, 1, 1).expand_as(vV)) # (Bk) x 3 x 3 => [v * v' + v x (Id - R)] / ||v||^2 63 | 64 | # Iterate over the 3-axis angle parameters to compute their gradients 65 | # ([v * v' + v x (Id - R)] / ||v||^2 _ k) x (R) .* gradOutput where "x" is the cross product 66 | grad_input_list = [] 67 | for k in range(3): 68 | # Create skew symmetric matrix 
69 | skewsym = create_skew_symmetric_matrix(vV.narrow(2, k, 1)) 70 | 71 | # For those AAs with angle^2 < threshold, gradient is different 72 | # We assume angle = 0 for these AAs and update the skew-symmetric matrix to be one w.r.t identity 73 | if (nSmall > 0): 74 | vec = torch.zeros(1, 3).type_as(skewsym) 75 | vec[0][k] = 1 # Unit vector 76 | idskewsym = create_skew_symmetric_matrix(vec) 77 | for i in range(N): 78 | if (angle2[i].squeeze()[0] < eps): 79 | skewsym[i].copy_(idskewsym.squeeze()) # Use the new skew sym matrix (around identity) 80 | 81 | # Compute the gradients now 82 | grad_input_list.append(torch.sum(torch.bmm(skewsym, rot) * grad_output, dim=(1, 2))) # [(Bk) x 1 x 1] => (vV x R) .* gradOutput 83 | grad_input = torch.stack(grad_input_list, 1).view(batch_size, num_se3, 3) 84 | 85 | return grad_input 86 | -------------------------------------------------------------------------------- /se3/se3euler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from se3.se3_utils import create_rotx, create_roty, create_rotz 4 | from se3.se3_utils import create_skew_symmetric_matrix 5 | 6 | 7 | class Se3euler(Function): 8 | @staticmethod 9 | def forward(ctx, input): 10 | batch_size, num_se3, num_params = input.size() 11 | 12 | rot_params = input.view(batch_size * num_se3, -1) 13 | 14 | # Create rotations about X,Y,Z axes 15 | # R = Rz(theta3) * Ry(theta2) * Rx(theta1) 16 | # Last 3 parameters are [theta1, theta2 ,theta3] 17 | rotx = create_rotx(rot_params[:, 0]) # Rx(theta1) 18 | roty = create_roty(rot_params[:, 1]) # Ry(theta2) 19 | rotz = create_rotz(rot_params[:, 2]) # Rz(theta3) 20 | 21 | # Compute Rz(theta3) * Ry(theta2) 22 | rotzy = torch.bmm(rotz, roty) # Rzy = R32 23 | 24 | # Compute rotation matrix R3*R2*R1 = R32*R1 25 | # R = Rz(t3) * Ry(t2) * Rx(t1) 26 | output = torch.bmm(rotzy, rotx) # R = Rzyx 27 | 28 | ctx.save_for_backward(input, output, rotx, roty, rotz, rotzy) 29 | 30 | return output.view(batch_size, num_se3, 3, 3) 31 | 32 | @staticmethod 33 | def backward(ctx, grad_output): 34 | input, output, rotx, roty, rotz, rotzy = ctx.saved_tensors 35 | batch_size, num_se3, num_params = input.size() 36 | grad_output = grad_output.contiguous().view(batch_size * num_se3, 3, 3) 37 | 38 | # Gradient w.r.t Euler angles from Barfoot's book (http://asrl.utias.utoronto.ca/~tdb/bib/barfoot_ser15.pdf) 39 | grad_input_list = [] 40 | for k in range(3): 41 | gradr = grad_output[:, k] # Gradient w.r.t angle (k) 42 | vec = torch.zeros(1, 3).type_as(gradr) 43 | vec[0][k] = 1 # Unit vector 44 | skewsym = create_skew_symmetric_matrix(vec).view(1, 3, 3).expand_as(output) # Skew symmetric matrix of unit vector 45 | if (k == 0): 46 | Rv = torch.bmm(torch.bmm(rotzy, skewsym), rotx) # Eqn 6.61c 47 | elif (k == 1): 48 | Rv = torch.bmm(torch.bmm(rotz, skewsym), torch.bmm(roty, rotx)) # Eqn 6.61b 49 | else: 50 | Rv = torch.bmm(skewsym, output) 51 | grad_input_list.append(torch.sum(-Rv * grad_output, dim=(1, 2))) 52 | grad_input = torch.stack(grad_input_list, 1).view(batch_size, num_se3, 3) 53 | 54 | return grad_input 55 | -------------------------------------------------------------------------------- /se3/se3quat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import torch.nn.functional as F 4 | from se3.se3_utils import create_rot_from_unitquat 5 | from se3.se3_utils import compute_grad_rot_wrt_unitquat 6 | from 
se3.se3_utils import compute_grad_unitquat_wrt_quat 7 | 8 | 9 | class Se3quat(Function): 10 | @staticmethod 11 | def forward(ctx, input): 12 | batch_size, num_se3, num_params = input.size() 13 | 14 | rot_params = input.view(batch_size * num_se3, -1) 15 | 16 | unitquat = F.normalize(rot_params) 17 | 18 | output = create_rot_from_unitquat(unitquat).view(batch_size, num_se3, 3, 3) 19 | 20 | ctx.save_for_backward(input) 21 | 22 | return output 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | input = ctx.saved_tensors[0] 27 | batch_size, num_se3, num_params = input.size() 28 | 29 | rot_params = input.view(batch_size * num_se3, -1) 30 | 31 | unitquat = F.normalize(rot_params) 32 | 33 | # Compute dR/dq' 34 | dRdqh = compute_grad_rot_wrt_unitquat(unitquat) 35 | 36 | # Compute dq'/dq = d(q/||q||)/dq = 1/||q|| (I - q'q'^T) 37 | dqhdq = compute_grad_unitquat_wrt_quat(unitquat, rot_params) 38 | 39 | 40 | # Compute dR/dq = dR/dq' * dq'/dq 41 | dRdq = torch.bmm(dRdqh, dqhdq).view(batch_size, num_se3, 3, 3, 4) # B x k x 3 x 3 x 4 42 | 43 | # Scale by grad w.r.t output and sum to get gradient w.r.t quaternion params 44 | grad_out = grad_output.contiguous().view(batch_size, num_se3, 3, 3, 1).expand_as(dRdq) # B x k x 3 x 3 x 4 45 | 46 | grad_input = torch.sum(dRdq * grad_out, dim=(2, 3)) # (Bk) x 3 47 | 48 | return grad_input 49 | -------------------------------------------------------------------------------- /se3/se3spquat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from se3.se3_utils import create_unitquat_from_spquat 4 | from se3.se3_utils import create_rot_from_unitquat 5 | from se3.se3_utils import compute_grad_rot_wrt_unitquat 6 | from se3.se3_utils import compute_grad_unitquat_wrt_spquat 7 | 8 | 9 | class Se3spquat(Function): 10 | @staticmethod 11 | def forward(ctx, input): 12 | batch_size, num_se3, num_params = input.size() 13 | 14 | rot_params = input.view(batch_size * num_se3, -1) 15 | 16 | unitquat = create_unitquat_from_spquat(rot_params) 17 | 18 | output = create_rot_from_unitquat(unitquat).view(batch_size, num_se3, 3, 3) 19 | 20 | ctx.save_for_backward(input) 21 | 22 | return output 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | input = ctx.saved_tensors[0] 27 | batch_size, num_se3, num_params = input.size() 28 | 29 | rot_params = input.view(batch_size * num_se3, -1) 30 | 31 | unitquat = create_unitquat_from_spquat(rot_params) 32 | 33 | # Compute dR/dq' 34 | dRdqh = compute_grad_rot_wrt_unitquat(unitquat) 35 | 36 | # Compute dq'/dq = d(q/||q||)/dq = 1/||q|| (I - q'q'^T) 37 | dqhdspq = compute_grad_unitquat_wrt_spquat(rot_params) 38 | 39 | 40 | # Compute dR/dq = dR/dq' * dq'/dq 41 | dRdq = torch.bmm(dRdqh, dqhdspq).view(batch_size, num_se3, 3, 3, 3) # B x k x 3 x 3 x 3 42 | 43 | # Scale by grad w.r.t output and sum to get gradient w.r.t quaternion params 44 | grad_out = grad_output.contiguous().view(batch_size, num_se3, 3, 3, 1).expand_as(dRdq) # B x k x 3 x 3 x 3 45 | 46 | grad_input = torch.sum(dRdq * grad_out, dim=(2, 4)) # (Bk) x 3 47 | 48 | return grad_input 49 | -------------------------------------------------------------------------------- /sim.py: -------------------------------------------------------------------------------- 1 | import math 2 | import threading 3 | import time 4 | 5 | import numpy as np 6 | import pybullet as p 7 | import pybullet_data 8 | 9 | import utils 10 | 11 | 12 | class PybulletSim: 13 | def __init__(self, gui_enabled, 
heightmap_pixel_size=0.004, tool='stick'): 14 | 15 | self._workspace_bounds = np.array([[0.244, 0.756], 16 | [-0.256, 0.256], 17 | [0.0, 0.192]]) 18 | 19 | self._view_bounds = self._workspace_bounds 20 | 21 | # Start PyBullet simulation 22 | if gui_enabled: 23 | self._physics_client = p.connect(p.GUI) # or p.DIRECT for non-graphical version 24 | else: 25 | self._physics_client = p.connect(p.DIRECT) # non-graphical version 26 | p.setAdditionalSearchPath(pybullet_data.getDataPath()) 27 | p.setGravity(0, 0, -9.8) 28 | step_sim_thread = threading.Thread(target=self.step_simulation) 29 | step_sim_thread.daemon = True 30 | step_sim_thread.start() 31 | 32 | # Add ground plane & table 33 | self._plane_id = p.loadURDF("plane.urdf") 34 | # self._table_id = p.loadURDF('assets/table/table.urdf', [0.5, 0, 0], useFixedBase=True) 35 | 36 | # Add UR5 robot 37 | self._robot_body_id = p.loadURDF("assets/ur5/ur5.urdf", [0, 0, 0], p.getQuaternionFromEuler([0, 0, 0])) 38 | # Get revolute joint indices of robot (skip fixed joints) 39 | robot_joint_info = [p.getJointInfo(self._robot_body_id, i) for i in range(p.getNumJoints(self._robot_body_id))] 40 | self._robot_joint_indices = [x[0] for x in robot_joint_info if x[2] == p.JOINT_REVOLUTE] 41 | self._joint_epsilon = 0.01 # joint position threshold in radians for blocking calls (i.e. move until joint difference < epsilon) 42 | 43 | # Move robot to home joint configuration 44 | self._robot_home_joint_config = [-3.186603833231106, -2.7046623323544323, 1.9797780717750348, 45 | -0.8458013020952369, -1.5941890970134802, -0.04501555880643846] 46 | self.move_joints(self._robot_home_joint_config, blocking=True, speed=1.0) 47 | 48 | self.tool=tool 49 | # Attach a sticker to UR5 robot 50 | self._gripper_body_id = p.loadURDF("assets/stick/stick.urdf") 51 | p.resetBasePositionAndOrientation(self._gripper_body_id, [0.5, 0.1, 0.2], 52 | p.getQuaternionFromEuler([np.pi, 0, 0])) 53 | self._robot_tool_joint_idx = 9 54 | self._robot_tool_tip_joint_idx = 10 55 | self._robot_tool_offset = [0, 0, -0.0725] 56 | 57 | p.createConstraint(self._robot_body_id, self._robot_tool_joint_idx, self._gripper_body_id, 0, 58 | jointType=p.JOINT_FIXED, jointAxis=[0, 0, 0], parentFramePosition=[0, 0, 0], 59 | childFramePosition=self._robot_tool_offset, 60 | childFrameOrientation=p.getQuaternionFromEuler([0, 0, np.pi / 2])) 61 | self._tool_tip_to_ee_joint = [0, 0, 0.17] 62 | # Define Denavit-Hartenberg parameters for UR5 63 | self._ur5_kinematics_d = np.array([0.089159, 0., 0., 0.10915, 0.09465, 0.0823]) 64 | self._ur5_kinematics_a = np.array([0., -0.42500, -0.39225, 0., 0., 0.]) 65 | 66 | # Set friction coefficients for gripper fingers 67 | for i in range(p.getNumJoints(self._gripper_body_id)): 68 | p.changeDynamics( 69 | self._gripper_body_id, i, 70 | lateralFriction=1.0, 71 | spinningFriction=1.0, 72 | rollingFriction=0.0001, 73 | frictionAnchor=True 74 | ) 75 | 76 | # Add RGB-D camera (mimic RealSense D415) 77 | self.camera_params = { 78 | # large camera, image_size = (240 * 4, 320 * 4) 79 | 0: self._get_camera_param( 80 | camera_position=[0.5, -0.7, 0.3], 81 | camera_image_size=[240 * 4, 320 * 4] 82 | ), 83 | # small camera, image_size = (240, 320) 84 | 1: self._get_camera_param( 85 | camera_position=[0.5, -0.7, 0.3], 86 | camera_image_size=[240, 320] 87 | ), 88 | # top-down camera, image_size = (480, 480) 89 | 2: self._get_camera_param( 90 | camera_position=[0.5, 0, 0.5], 91 | camera_image_size=[480, 480] 92 | ), 93 | } 94 | 95 | 96 | self._heightmap_pixel_size = heightmap_pixel_size 97 | 
self._heightmap_size = np.round( 98 | ((self._view_bounds[1][1] - self._view_bounds[1][0]) / self._heightmap_pixel_size, 99 | (self._view_bounds[0][1] - self._view_bounds[0][0]) / self._heightmap_pixel_size)).astype(int) 100 | 101 | 102 | def _get_camera_param(self, camera_position, camera_image_size): 103 | camera_lookat = [0.5, 0, 0] 104 | camera_up_direction = [0, camera_position[2], -camera_position[1]] 105 | camera_view_matrix = p.computeViewMatrix(camera_position, camera_lookat, camera_up_direction) 106 | camera_pose = np.linalg.inv(np.array(camera_view_matrix).reshape(4, 4).T) 107 | camera_pose[:, 1:3] = -camera_pose[:, 1:3] 108 | camera_z_near = 0.01 109 | camera_z_far = 10.0 110 | camera_fov_w = 69.40 111 | camera_focal_length = (float(camera_image_size[1]) / 2) / np.tan((np.pi * camera_fov_w / 180) / 2) 112 | camera_fov_h = (math.atan((float(camera_image_size[0]) / 2) / camera_focal_length) * 2 / np.pi) * 180 113 | camera_projection_matrix = p.computeProjectionMatrixFOV( 114 | fov=camera_fov_h, 115 | aspect=float(camera_image_size[1]) / float(camera_image_size[0]), 116 | nearVal=camera_z_near, 117 | farVal=camera_z_far 118 | ) # notes: 1) FOV is vertical FOV 2) aspect must be float 119 | camera_intrinsics = np.array( 120 | [[camera_focal_length, 0, float(camera_image_size[1]) / 2], 121 | [0, camera_focal_length, float(camera_image_size[0]) / 2], 122 | [0, 0, 1]]) 123 | camera_param = { 124 | 'camera_image_size': camera_image_size, 125 | 'camera_intr': camera_intrinsics, 126 | 'camera_pose': camera_pose, 127 | 'camera_view_matrix': camera_view_matrix, 128 | 'camera_projection_matrix': camera_projection_matrix, 129 | 'camera_z_near': camera_z_near, 130 | 'camera_z_far': camera_z_far 131 | } 132 | return camera_param 133 | 134 | # Step through simulation time 135 | def step_simulation(self): 136 | while True: 137 | p.stepSimulation() 138 | time.sleep(0.0001) 139 | 140 | # Get RGB-D heightmap from RGB-D image 141 | def get_heightmap(self, color_image, depth_image, cam_param): 142 | color_heightmap, depth_heightmap = utils.get_heightmap( 143 | color_img=color_image, 144 | depth_img=depth_image, 145 | cam_intrinsics=cam_param['camera_intr'], 146 | cam_pose=cam_param['camera_pose'], 147 | workspace_limits=self._view_bounds, 148 | heightmap_resolution=self._heightmap_pixel_size 149 | ) 150 | return color_heightmap, depth_heightmap 151 | 152 | # Get latest RGB-D image 153 | def get_camera_data(self, cam_param): 154 | camera_data = p.getCameraImage(cam_param['camera_image_size'][1], cam_param['camera_image_size'][0], 155 | cam_param['camera_view_matrix'], cam_param['camera_projection_matrix'], 156 | shadow=1, flags=p.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX, 157 | renderer=p.ER_BULLET_HARDWARE_OPENGL) 158 | 159 | color_image = np.asarray(camera_data[2]).reshape( 160 | [cam_param['camera_image_size'][0], cam_param['camera_image_size'][1], 4])[:, :, :3] # remove alpha channel 161 | z_buffer = np.asarray(camera_data[3]).reshape(cam_param['camera_image_size']) 162 | camera_z_near = cam_param['camera_z_near'] 163 | camera_z_far = cam_param['camera_z_far'] 164 | depth_image = (2.0 * camera_z_near * camera_z_far) / ( 165 | camera_z_far + camera_z_near - (2.0 * z_buffer - 1.0) * ( 166 | camera_z_far - camera_z_near)) 167 | return color_image, depth_image 168 | 169 | # Move robot tool to specified pose 170 | def move_tool(self, position, orientation, blocking=False, speed=0.03): 171 | 172 | # Use IK to compute target joint configuration 173 | target_joint_state = np.array( 174 | 
p.calculateInverseKinematics(self._robot_body_id, self._robot_tool_tip_joint_idx, position, orientation, 175 | maxNumIterations=10000, 176 | residualThreshold=.0001)) 177 | target_joint_state[5] = ( 178 | (target_joint_state[5] + np.pi) % (2 * np.pi) - np.pi) # keep EE joint angle between -180/+180 179 | 180 | # Move joints 181 | p.setJointMotorControlArray(self._robot_body_id, self._robot_joint_indices, p.POSITION_CONTROL, 182 | target_joint_state, 183 | positionGains=speed * np.ones(len(self._robot_joint_indices))) 184 | 185 | # Block call until joints move to target configuration 186 | if blocking: 187 | actual_joint_state = [p.getJointState(self._robot_body_id, x)[0] for x in self._robot_joint_indices] 188 | timeout_t0 = time.time() 189 | while not all([np.abs(actual_joint_state[i] - target_joint_state[i]) < self._joint_epsilon for i in 190 | range(6)]): # and (time.time()-timeout_t0) < timeout: 191 | if time.time() - timeout_t0 > 5: 192 | p.setJointMotorControlArray(self._robot_body_id, self._robot_joint_indices, p.POSITION_CONTROL, 193 | self._robot_home_joint_config, 194 | positionGains=np.ones(len(self._robot_joint_indices))) 195 | break 196 | actual_joint_state = [p.getJointState(self._robot_body_id, x)[0] for x in self._robot_joint_indices] 197 | time.sleep(0.001) 198 | 199 | # Move robot arm to specified joint configuration 200 | def move_joints(self, target_joint_state, blocking=False, speed=0.03): 201 | 202 | # Move joints 203 | p.setJointMotorControlArray(self._robot_body_id, self._robot_joint_indices, 204 | p.POSITION_CONTROL, target_joint_state, 205 | positionGains=speed * np.ones(len(self._robot_joint_indices))) 206 | 207 | # Block call until joints move to target configuration 208 | if blocking: 209 | actual_joint_state = [p.getJointState(self._robot_body_id, i)[0] for i in self._robot_joint_indices] 210 | timeout_t0 = time.time() 211 | while not all([np.abs(actual_joint_state[i] - target_joint_state[i]) < self._joint_epsilon for i in 212 | range(6)]): 213 | if time.time() - timeout_t0 > 5: 214 | p.setJointMotorControlArray(self._robot_body_id, self._robot_joint_indices, p.POSITION_CONTROL, 215 | self._robot_home_joint_config, 216 | positionGains=np.ones(len(self._robot_joint_indices))) 217 | break 218 | actual_joint_state = [p.getJointState(self._robot_body_id, i)[0] for i in self._robot_joint_indices] 219 | time.sleep(0.001) 220 | 221 | 222 | def robot_go_home(self, blocking=True, speed=0.1): 223 | self.move_joints(self._robot_home_joint_config, blocking, speed) 224 | 225 | 226 | def primitive_push(self, position, rotation_angle, speed=0.01, distance=0.1): 227 | push_orientation = [1.0, 0.0] 228 | push_direction = np.asarray( 229 | [push_orientation[0] * np.cos(rotation_angle) - push_orientation[1] * np.sin(rotation_angle), 230 | push_orientation[0] * np.sin(rotation_angle) + push_orientation[1] * np.cos(rotation_angle), 0.0]) 231 | target_x = position[0] + push_direction[0] * distance 232 | target_y = position[1] + push_direction[1] * distance 233 | position_end = np.asarray([target_x, target_y, position[2]]) 234 | self.move_tool([position[0], position[1], 0.15], orientation=[-1.0, 1.0, 0.0, 0.0], blocking=True, speed=0.05) 235 | self.move_tool(position, orientation=[-1.0, 1.0, 0.0, 0.0], blocking=True, speed=0.1) 236 | self.move_tool(position_end, orientation=[-1.0, 1.0, 0.0, 0.0], blocking=True, speed=speed) 237 | 238 | position_end[2]=0.15 239 | self.move_tool(position_end, orientation=[-1.0, 1.0, 0.0, 0.0], blocking=True, speed=0.005) 240 | 
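# ---------------------------------------------------------------------------
# Hedged usage sketch (added example; this block is not part of the original
# sim.py). A minimal smoke test of the wrapper defined above, assuming PyBullet
# and the asset paths referenced in __init__ are available. It mirrors how
# sim_env.py drives this class: render the top-down camera, build a heightmap,
# execute one push, and return home. The push position and angle are arbitrary
# illustrative values inside the workspace bounds.
if __name__ == '__main__':
    sim = PybulletSim(gui_enabled=False)              # p.DIRECT, no GUI window
    cam_param = sim.camera_params[2]                  # top-down camera (480x480)
    color_image, depth_image = sim.get_camera_data(cam_param)
    color_heightmap, depth_heightmap = sim.get_heightmap(color_image, depth_image, cam_param)
    print('heightmap shape:', depth_heightmap.shape)  # (128, 128) at 0.004 m per pixel
    # Push 10 cm along +x, starting slightly left of the workspace center.
    sim.primitive_push(position=[0.45, 0.0, 0.005], rotation_angle=0.0, speed=0.01, distance=0.1)
    sim.robot_go_home()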
-------------------------------------------------------------------------------- /sim_env.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import time 3 | import cv2 4 | import numpy as np 5 | import pybullet as p 6 | import json 7 | 8 | from sim import PybulletSim 9 | from binvox_utils import read_as_coord_array 10 | from utils import euler2rotm, project_pts_to_2d 11 | 12 | 13 | class SimulationEnv(): 14 | def __init__(self, gui_enabled): 15 | 16 | self.gui_enabled = gui_enabled 17 | self.sim = PybulletSim(gui_enabled=gui_enabled, tool='stick') 18 | self.heightmap_size = self.sim._heightmap_size 19 | self.heightmap_pixel_size = self.sim._heightmap_pixel_size 20 | self.view_bounds = self.sim._view_bounds 21 | self.direction_num = 8 22 | self.voxel_size = 0.004 23 | 24 | self.object_ids = [] 25 | 26 | self.object_type = 'cube' # choice: 'cube', 'shapenet', 'ycb' 27 | 28 | # process ycb 29 | self.ycb_path = 'object_models/ycb' 30 | self.ycb_info = json.load(open('assets/object_id/ycb_id.json', 'r')) 31 | 32 | # process shapenent 33 | self.shapenet_path = 'object_models/shapenet' 34 | self.shapenet_info = json.load(open('assets/object_id/shapenet_id.json', 'r')) 35 | 36 | self.voxel_coord = {} 37 | self.cnt_dict = {} 38 | self.init_position = {} 39 | self.last_direction = {} 40 | 41 | 42 | def _get_coord(self, obj_id, position, orientation, vol_bnds=None, voxel_size=None): 43 | # if vol_bnds is not None, return coord in voxel, else, return world coord 44 | coord = self.voxel_coord[obj_id] 45 | mat = euler2rotm(p.getEulerFromQuaternion(orientation)) 46 | coord = (mat @ (coord.T)).T + np.asarray(position) 47 | if vol_bnds is not None: 48 | coord = np.round((coord - vol_bnds[:, 0]) / voxel_size).astype(np.int) 49 | return coord 50 | 51 | def _get_scene_flow_3d(self, old_po_ors): 52 | vol_bnds = self.view_bounds 53 | scene_flow = np.zeros([int((x[1] - x[0] + 1e-7) / self.voxel_size) for x in vol_bnds] + [3]) 54 | mask = np.zeros([int((x[1] - x[0] + 1e-7) / self.voxel_size) for x in vol_bnds], dtype=np.int) 55 | 56 | cur_cnt = 0 57 | for obj_id, old_po_or in zip(self.object_ids, old_po_ors): 58 | position, orientation = p.getBasePositionAndOrientation(obj_id) 59 | new_coord = self._get_coord(obj_id, position, orientation, vol_bnds, self.voxel_size) 60 | 61 | position, orientation = old_po_or 62 | old_coord = self._get_coord(obj_id, position, orientation, vol_bnds, self.voxel_size) 63 | 64 | motion = new_coord - old_coord 65 | 66 | valid_idx = np.logical_and( 67 | np.logical_and(old_coord[:, 1] >= 0, old_coord[:, 1] < 128), 68 | np.logical_and( 69 | np.logical_and(old_coord[:, 0] >= 0, old_coord[:, 0] < 128), 70 | np.logical_and(old_coord[:, 2] >= 0, old_coord[:, 2] < 48) 71 | ) 72 | ) 73 | x = old_coord[valid_idx, 1] 74 | y = old_coord[valid_idx, 0] 75 | z = old_coord[valid_idx, 2] 76 | motion = motion[valid_idx] 77 | motion = np.stack([motion[:, 1], motion[:, 0], motion[:, 2]], axis=1) 78 | 79 | scene_flow[x, y, z] = motion 80 | 81 | # mask 82 | cur_cnt += 1 83 | mask[x, y, z] = cur_cnt 84 | 85 | return mask, scene_flow 86 | 87 | def _get_scene_flow_2d(self, old_po_or): 88 | old_coords_world = [] 89 | new_coords_world = [] 90 | point_id_list = [] 91 | cur_cnt = 0 92 | for obj_id, po_or in zip(self.object_ids, old_po_or): 93 | position, orientation = po_or 94 | old_coord = self._get_coord(obj_id, position, orientation) 95 | old_coords_world.append(old_coord) 96 | 97 | position, orientation = p.getBasePositionAndOrientation(obj_id) 
98 | new_coord = self._get_coord(obj_id, position, orientation) 99 | new_coords_world.append(new_coord) 100 | 101 | cur_cnt += 1 102 | point_id_list.append([cur_cnt for _ in range(old_coord.shape[0])]) 103 | 104 | point_id = np.concatenate(point_id_list) 105 | old_coords_world = np.concatenate(old_coords_world) 106 | new_coords_world = np.concatenate(new_coords_world) 107 | camera_view_matrix = np.array(self.sim.camera_params[1]['camera_view_matrix']).reshape(4, 4).T 108 | camera_intr = self.sim.camera_params[1]['camera_intr'] 109 | image_size = self.sim.camera_params[1]['camera_image_size'] 110 | old_coords_2d = project_pts_to_2d(old_coords_world.T, camera_view_matrix, camera_intr) 111 | y = np.round(old_coords_2d[0]).astype(np.int) 112 | x = np.round(old_coords_2d[1]).astype(np.int) 113 | depth = old_coords_2d[2] 114 | valid_idx = np.logical_and( 115 | np.logical_and(x >= 0, x < image_size[0]), 116 | np.logical_and(y >= 0, y < image_size[1]) 117 | ) 118 | x = x[valid_idx] 119 | y = y[valid_idx] 120 | depth = depth[valid_idx] 121 | point_id = point_id[valid_idx] 122 | motion = (new_coords_world - old_coords_world)[valid_idx] 123 | 124 | sort_id = np.argsort(-depth) 125 | x = x[sort_id] 126 | y = y[sort_id] 127 | point_id = point_id[sort_id] 128 | motion = motion[sort_id] 129 | motion = np.stack([motion[:, 1], motion[:, 0], motion[:, 2]], axis=1) 130 | 131 | scene_flow = np.zeros([image_size[0], image_size[1], 3]) 132 | mask = np.zeros([image_size[0], image_size[1]]) 133 | 134 | scene_flow[x, y] = motion 135 | mask[x, y] = point_id 136 | 137 | return mask, scene_flow 138 | 139 | 140 | def check_occlusion(self): 141 | coords_world = [] 142 | point_id_list = [] 143 | cur_cnt = 0 144 | for obj_id in self.object_ids: 145 | position, orientation = p.getBasePositionAndOrientation(obj_id) 146 | coord = self._get_coord(obj_id, position, orientation) 147 | coords_world.append(coord) 148 | cur_cnt += 1 149 | point_id_list.append([cur_cnt for _ in range(coord.shape[0])]) 150 | point_id = np.concatenate(point_id_list) 151 | coords_world = np.concatenate(coords_world) 152 | camera_view_matrix = np.array(self.sim.camera_params[1]['camera_view_matrix']).reshape(4, 4).T 153 | camera_intr = self.sim.camera_params[1]['camera_intr'] 154 | image_size = self.sim.camera_params[1]['camera_image_size'] 155 | coords_2d = project_pts_to_2d(coords_world.T, camera_view_matrix, camera_intr) 156 | 157 | y = np.round(coords_2d[0]).astype(np.int) 158 | x = np.round(coords_2d[1]).astype(np.int) 159 | depth = coords_2d[2] 160 | valid_idx = np.logical_and( 161 | np.logical_and(x >= 0, x < image_size[0]), 162 | np.logical_and(y >= 0, y < image_size[1]) 163 | ) 164 | x = x[valid_idx] 165 | y = y[valid_idx] 166 | depth = depth[valid_idx] 167 | point_id = point_id[valid_idx] 168 | 169 | sort_id = np.argsort(-depth) 170 | x = x[sort_id] 171 | y = y[sort_id] 172 | point_id = point_id[sort_id] 173 | 174 | mask = np.zeros([image_size[0], image_size[1]]) 175 | mask[x, y] = point_id 176 | 177 | obj_num = len(self.object_ids) 178 | mask_sep = np.zeros([obj_num + 1, image_size[0], image_size[1]]) 179 | mask_sep[point_id, x, y] = 1 180 | for i in range(obj_num): 181 | tot_pixel_num = np.sum(mask_sep[i + 1]) 182 | vis_pixel_num = np.sum((mask == (i+1)).astype(np.float)) 183 | if vis_pixel_num < 0.4 * tot_pixel_num: 184 | return False 185 | return True 186 | 187 | def _get_image_and_heightmap(self): 188 | color_image0, depth_image0 = self.sim.get_camera_data(self.sim.camera_params[0]) 189 | color_image1, depth_image1 = 
self.sim.get_camera_data(self.sim.camera_params[1]) 190 | color_image2, depth_image2 = self.sim.get_camera_data(self.sim.camera_params[2]) 191 | 192 | color_heightmap, depth_heightmap = self.sim.get_heightmap(color_image2, depth_image2, self.sim.camera_params[2]) 193 | 194 | self.current_depth_heightmap = depth_heightmap 195 | self.current_color_heightmap = color_heightmap 196 | self.current_depth_image0 = depth_image0 197 | self.current_color_image0 = color_image0 198 | self.current_depth_image1 = depth_image1 199 | self.current_color_image1 = color_image1 200 | 201 | def _random_drop(self, object_num, object_type): 202 | large_object_id = np.random.choice(object_num) 203 | if object_type == 'cube' or np.random.rand() < 0.1: 204 | large_object_id = -1 205 | self.large_object_id = large_object_id 206 | self.can_with_box = False 207 | 208 | while True: 209 | xy_pos = np.random.rand(object_num, 2) * 0.26 + np.asarray([0.5-0.13, -0.13]) 210 | flag = True 211 | for i in range(object_num - 1): 212 | for j in range(i + 1, object_num): 213 | d = np.sqrt(np.sum((xy_pos[i] - xy_pos[j])**2)) 214 | if i == large_object_id or j == large_object_id: 215 | if d < 0.13: 216 | flag = False 217 | else: 218 | if d < 0.07: 219 | flag = False 220 | if large_object_id != -1 and xy_pos[large_object_id][1] > -0.05 and np.random.rand() < 0.15: 221 | flag = False 222 | if flag: 223 | break 224 | 225 | xy_pos -= np.mean(xy_pos, 0) 226 | xy_pos += np.array([0.5, 0]) 227 | 228 | for i in range(object_num): 229 | if object_type == 'cube': 230 | md = np.ones([60, 60, 70]) 231 | coord = (np.asarray(np.nonzero(md)).T + 0.5 - np.array([30, 30, 35])) 232 | size_cube = np.random.choice([700, 750, 800, 850, 900, 1000, 1100, 1200, 1400]) 233 | collision_id = p.createCollisionShape(p.GEOM_BOX, halfExtents=np.array([30, 30, 35]) / size_cube) 234 | body_id = p.createMultiBody( 235 | 0.05, collision_id, -1, 236 | [xy_pos[i, 0], xy_pos[i, 1], 0.2], 237 | p.getQuaternionFromEuler(np.random.rand(3) * np.pi) 238 | ) 239 | p.changeDynamics(body_id, -1, spinningFriction=0.003, lateralFriction=0.25, mass=0.05) 240 | p.changeVisualShape(body_id, -1, rgbaColor=np.concatenate([1 * np.random.rand(3), [1]])) 241 | self.object_ids.append(body_id) 242 | self.voxel_coord[body_id] = coord / size_cube 243 | time.sleep(0.2) 244 | elif object_type == 'ycb': 245 | # get object 246 | if i == large_object_id: 247 | obj_name = np.random.choice(self.ycb_info['large_list']) 248 | else: 249 | obj_name = np.random.choice(self.ycb_info['normal_list']) 250 | 251 | with open(osp.join(self.ycb_path, obj_name, 'model_com.binvox'), 'rb') as f: 252 | md = read_as_coord_array(f) 253 | coord = (md.data.T + 0.5) / md.dims * md.scale + md.translate 254 | 255 | # position & quat 256 | random_euler = [0, 0, np.random.rand() * 2 * np.pi] 257 | quat = p.getQuaternionFromEuler(random_euler) 258 | obj_position = [xy_pos[i, 0], xy_pos[i, 1], np.max(-coord[:, 2]) + 0.01] 259 | 260 | urdf_path = osp.join(self.ycb_path, obj_name, 'obj.urdf') 261 | body_id = p.loadURDF( 262 | fileName=urdf_path, 263 | basePosition=obj_position, 264 | baseOrientation=quat, 265 | globalScaling=1 266 | ) 267 | p.changeDynamics(body_id, -1, spinningFriction=0.003, lateralFriction=0.25, mass=0.05) 268 | 269 | self.object_ids.append(body_id) 270 | self.voxel_coord[body_id] = coord 271 | time.sleep(2) 272 | elif (object_type=='shapenet' and i != large_object_id and np.random.rand() < 0.3) or \ 273 | (object_type=='shapenet' and i == large_object_id and np.random.rand() < 0.12): 274 | box_size = 'small' 
275 |                 if i == large_object_id:
276 |                     box_size = 'large'
277 |                 elif self.can_with_box and np.random.rand() < 0.1:
278 |                     self.can_with_box = False
279 |                     box_size = 'large'
280 | 
281 |                 if box_size == 'large':
282 |                     dim_x = np.random.choice(list(range(35, 55)))
283 |                     dim_y = np.random.choice(list(range(70, 85)))
284 |                     dim_z = np.random.choice(list(range(15, 30)))
285 |                 else:
286 |                     dim_x = np.random.choice(list(range(25, 40)))
287 |                     dim_y = np.random.choice(list(range(30, 60)))
288 |                     dim_z = np.random.choice(list(range(15, 25)) + [30, 32])
289 |                 md = np.ones([dim_x, dim_y, dim_z])
290 |                 coord = (np.asarray(np.nonzero(md)).T + 0.5 - np.array([dim_x / 2, dim_y / 2, dim_z / 2]))
291 |                 size_cube = 500
292 |                 collision_id = p.createCollisionShape(p.GEOM_BOX, halfExtents=np.array(
293 |                     [dim_x / 2, dim_y / 2, dim_z / 2]) / size_cube)
294 |                 body_id = p.createMultiBody(
295 |                     0.05, collision_id, -1,
296 |                     [xy_pos[i, 0], xy_pos[i, 1], 0.1],
298 |                     p.getQuaternionFromEuler([0, 0, np.random.rand() * np.pi])
299 |                 )
300 |                 p.changeDynamics(body_id, -1, spinningFriction=0.003, lateralFriction=0.25, mass=0.05)
301 |                 p.changeVisualShape(body_id, -1, rgbaColor=np.concatenate([1 * np.random.rand(3), [1]]))
302 |                 self.object_ids.append(body_id)
303 |                 self.voxel_coord[body_id] = coord / size_cube
304 |                 time.sleep(0.2)
305 |             else:
306 |                 object_cat_cur = np.random.choice(list(self.shapenet_info.keys()))
307 |                 if np.random.rand() < 0.3:
308 |                     object_cat_cur = 'can'
309 |                 if i == large_object_id and object_cat_cur == 'can':
310 |                     self.can_with_box = True
311 | 
312 |                 category_id = self.shapenet_info[object_cat_cur]['category_id']
313 |                 tmp = np.random.choice(len(self.shapenet_info[object_cat_cur]['object_id']))
314 |                 object_id = self.shapenet_info[object_cat_cur]['object_id'][tmp]
315 |                 urdf_path = osp.join(self.shapenet_path, '%s/%s/obj.urdf' % (category_id, object_id))
316 | 
317 |                 # load object
318 |                 if i == large_object_id:
319 |                     scaling_range = self.shapenet_info[object_cat_cur]['large_scaling']
320 |                 else:
321 |                     scaling_range = self.shapenet_info[object_cat_cur]['global_scaling']
322 | 
323 |                 globalScaling = np.random.rand() * (scaling_range[1] - scaling_range[0]) + scaling_range[0]
324 | 
325 |                 # save nonzero voxel coord
326 |                 with open(osp.join(self.shapenet_path, '%s/%s/model_com.binvox' % (category_id, object_id)),
327 |                           'rb') as f:
328 |                     md = read_as_coord_array(f)
329 |                     coord = (md.data.T + 0.5) / md.dims * md.scale + md.translate
330 |                     coord = coord * 0.15 * globalScaling  # 0.15 is the rescale value in .urdf
331 | 
332 |                 # position & quat
333 |                 random_euler = [0, 0, np.random.rand() * 2 * np.pi]
334 |                 quat = p.getQuaternionFromEuler(random_euler)
335 |                 obj_position = [xy_pos[i, 0], xy_pos[i, 1], np.max(-coord[:, 2]) + 0.01]
336 | 
337 |                 body_id = p.loadURDF(
338 |                     fileName=urdf_path,
339 |                     basePosition=obj_position,
340 |                     baseOrientation=quat,
341 |                     globalScaling=globalScaling
342 |                 )
343 | 
344 |                 p.changeDynamics(body_id, -1, spinningFriction=0.003, lateralFriction=0.25, mass=0.05)
345 |                 p.changeVisualShape(body_id, -1, rgbaColor=np.concatenate([1 * np.random.rand(3), [1]]))
346 |                 self.object_ids.append(body_id)
347 |                 self.voxel_coord[body_id] = coord
348 |                 time.sleep(0.2)
349 |         for obj_id in self.object_ids:
350 |             self.cnt_dict[obj_id] = 0
351 |             init_p = p.getBasePositionAndOrientation(obj_id)[0]
352 |             self.init_position[obj_id] = np.asarray(init_p[:2])
353 |             self.last_direction[obj_id] = None
354 | 
355 |     # for heightmap
356 |     def coord2pixel(self, x_coord, y_coord):
357 | 
x_pixel = int((x_coord - self.view_bounds[0, 0]) / self.heightmap_pixel_size) 358 | y_pixel = int((y_coord - self.view_bounds[1, 0]) / self.heightmap_pixel_size) 359 | return x_pixel, y_pixel 360 | 361 | def pixel2coord(self, x_pixel, y_pixel): 362 | x_coord = x_pixel * self.heightmap_pixel_size + self.view_bounds[0, 0] 363 | y_coord = y_pixel * self.heightmap_pixel_size + self.view_bounds[1, 0] 364 | return x_coord, y_coord 365 | 366 | def policy_generation(self): 367 | def softmax(input): 368 | value = np.exp(input) 369 | output = value / np.sum(value) 370 | return output 371 | 372 | # choose object 373 | value = [] 374 | for x in self.object_ids: 375 | t = self.cnt_dict[x] 376 | if t > 2: 377 | t = -2 378 | elif t > 1 and np.random.rand() < 0.5: 379 | t = -2 380 | self.cnt_dict[x] = t 381 | value.append(t) 382 | if self.large_object_id != -1: 383 | value[self.large_object_id] += 0.5 384 | obj_id = np.random.choice(self.object_ids, p=softmax(np.array(value))) 385 | 386 | # get position 387 | position = p.getBasePositionAndOrientation(obj_id)[0] 388 | position = np.asarray([position[0], position[1]]) 389 | 390 | # choose direction 391 | direction_value = [0 for i in range(self.direction_num)] 392 | for d in range(self.direction_num): 393 | ang = 2 * np.pi * d / self.direction_num 394 | unit_vec = np.asarray([np.cos(ang), np.sin(ang)]) 395 | 396 | off_direction = np.asarray([0.5, 0]) - position 397 | off_direction_unit = off_direction / np.sqrt(np.sum(off_direction ** 2)) 398 | weight = 5 if self.last_direction[obj_id] is None else 1 399 | direction_value[d] += weight * np.sum(off_direction_unit * unit_vec) * np.exp(np.sum(np.abs(off_direction))) 400 | if np.sqrt(np.sum(off_direction ** 2)) > 0.2 and np.sum(off_direction_unit * unit_vec) < 0: 401 | direction_value[d] -= 10 402 | 403 | for obj_id_enm in self.object_ids: 404 | if obj_id_enm != obj_id: 405 | obj_position_enm = p.getBasePositionAndOrientation(obj_id_enm)[0] 406 | off_direction = np.asarray(obj_position_enm[:2]) - position 407 | off_direction_unit = off_direction / np.sqrt(np.sum(off_direction ** 2)) 408 | if np.sqrt(np.sum(off_direction ** 2)) < 0.15 and np.sum(off_direction_unit * unit_vec) > 0.4: 409 | direction_value[d] -= 3 * np.sum(off_direction_unit * unit_vec) 410 | if np.sqrt(np.sum(off_direction ** 2)) < 0.15 and np.abs( 411 | np.sum(off_direction_unit * unit_vec)) < 0.3: 412 | direction_value[d] += 1.5 413 | 414 | off_direction = position - self.init_position[obj_id] 415 | if np.sum(np.abs(off_direction)) > 0.001: 416 | off_direction_unit = off_direction / np.sqrt(np.sum(off_direction ** 2)) 417 | if np.sqrt(np.sum(off_direction ** 2)) < 0.25 and np.sum(off_direction_unit * unit_vec) < 0: 418 | direction_value[d] += 1.5 * np.sum(off_direction_unit * unit_vec) 419 | 420 | if self.last_direction[obj_id] is not None: 421 | ang_last = 2 * np.pi * self.last_direction[obj_id] / self.direction_num 422 | unit_vec_last = np.asarray([np.cos(ang_last), np.sin(ang_last)]) 423 | direction_value[d] += 2 * np.sum(unit_vec_last * unit_vec) 424 | 425 | direction = np.random.choice(self.direction_num, p=softmax(np.array(direction_value) / 2)) 426 | 427 | direction_angle = direction / 4.0 * np.pi 428 | 429 | pos = position 430 | pos -= np.asarray([np.cos(direction_angle), np.sin(direction_angle)]) * 0.04 431 | 432 | for _ in range(5): 433 | x_coord, y_coord = pos[0], pos[1] 434 | x_pixel, y_pixel = self.coord2pixel(x_coord, y_coord) 435 | pos -= np.asarray([np.cos(direction_angle), np.sin(direction_angle)]) * 0.01 436 | if min(x_pixel, 
y_pixel) < 0 or max(x_pixel, y_pixel) >= 128: 437 | continue 438 | 439 | detection_mask = cv2.circle(np.zeros(self.heightmap_size), (x_pixel, y_pixel), 5, 1, thickness=-1) 440 | if np.max(self.current_depth_heightmap * detection_mask) > 0.005: 441 | continue 442 | 443 | d = 0.04 444 | new_pixel = self.coord2pixel(x_coord + np.cos(direction_angle) * d, y_coord + np.sin(direction_angle) * d) 445 | detection_mask = cv2.circle(np.zeros(self.heightmap_size), new_pixel, 3, 1, thickness=-1) 446 | if np.max(detection_mask * self.current_depth_heightmap) > 0.005: 447 | self.last_direction[obj_id] = direction 448 | self.cnt_dict[obj_id] += 1 449 | return x_pixel, y_pixel, x_coord, y_coord, 0.005, direction 450 | 451 | return None 452 | 453 | 454 | def poke(self): 455 | # log the current position & quat 456 | old_po_ors = [p.getBasePositionAndOrientation(object_id) for object_id in self.object_ids] 457 | output = self.get_scene_info() 458 | 459 | # generate action 460 | policy = None 461 | while policy is None: 462 | policy = self.policy_generation() 463 | 464 | x_pixel, y_pixel, x_coord, y_coord, z_coord, direction = policy 465 | 466 | # take action 467 | self.sim.primitive_push( 468 | position=[x_coord, y_coord, z_coord], 469 | rotation_angle=direction / 4.0 * np.pi, 470 | speed=0.005, 471 | distance=0.15 472 | ) 473 | self.sim.robot_go_home() 474 | action = {'0': direction, '1': y_pixel, '2': x_pixel} 475 | 476 | mask_3d, scene_flow_3d = self._get_scene_flow_3d(old_po_ors) 477 | mask_2d, scene_flow_2d = self._get_scene_flow_2d(old_po_ors) 478 | 479 | output['action'] = action 480 | output['mask_3d'] = mask_3d 481 | output['scene_flow_3d'] = scene_flow_3d 482 | output['mask_2d'] = mask_2d 483 | output['scene_flow_2d'] = scene_flow_2d 484 | 485 | return output 486 | 487 | 488 | def get_scene_info(self, mask_info=False): 489 | self._get_image_and_heightmap() 490 | 491 | positions, orientations = [], [] 492 | for i, obj_id in enumerate(self.object_ids): 493 | info = p.getBasePositionAndOrientation(obj_id) 494 | positions.append(info[0]) 495 | orientations.append(info[1]) 496 | 497 | scene_info = { 498 | 'color_heightmap': self.current_color_heightmap, 499 | 'depth_heightmap': self.current_depth_heightmap, 500 | 'color_image': self.current_color_image0, 501 | 'depth_image': self.current_depth_image0, 502 | 'color_image_small': self.current_color_image1, 503 | 'depth_image_small': self.current_depth_image1, 504 | 'positions': np.array(positions), 505 | 'orientations': np.array(orientations) 506 | } 507 | if mask_info: 508 | old_po_ors = [p.getBasePositionAndOrientation(object_id) for object_id in self.object_ids] 509 | mask_3d, scene_flow_3d = self._get_scene_flow_3d(old_po_ors) 510 | mask_2d, scene_flow_2d = self._get_scene_flow_2d(old_po_ors) 511 | scene_info['mask_3d'] = mask_3d 512 | scene_info['mask_2d'] = mask_2d 513 | 514 | return scene_info 515 | 516 | 517 | def reset(self, object_num=4, object_type=None): 518 | if object_type is None: 519 | object_type = self.object_type 520 | 521 | while True: 522 | # remove objects 523 | for obj_id in self.object_ids: 524 | p.removeBody(obj_id) 525 | self.object_ids = [] 526 | self.cnt_dict = {} 527 | 528 | # load fences 529 | self.fence_id = p.loadURDF( 530 | fileName='assets/fence/tinker.urdf', 531 | basePosition=[0.5, 0, 0.001], 532 | baseOrientation=p.getQuaternionFromEuler([0, 0, 0]), 533 | useFixedBase=True 534 | ) 535 | 536 | # load objects 537 | self._random_drop(object_num, object_type) 538 | time.sleep(1) 539 | p.removeBody(self.fence_id) 540 | 
old_ps = np.array([p.getBasePositionAndOrientation(object_id)[0] for object_id in self.object_ids])
541 |             for _ in range(10):
542 |                 time.sleep(1)
543 |                 new_ps = np.array([p.getBasePositionAndOrientation(object_id)[0] for object_id in self.object_ids])
544 |                 if np.sum((new_ps - old_ps) ** 2) < 1e-6:
545 |                     break
546 |                 old_ps = new_ps
547 |             self._get_image_and_heightmap()
548 | 
549 |             # check occlusion
550 |             if self.check_occlusion():
551 |                 return
552 | 
553 | 
554 | if __name__ == '__main__':
555 |     env = SimulationEnv(gui_enabled=False)
556 |     env.reset(4, 'ycb')
557 | 
558 |     # To get only the current scene information (images, heightmaps, object poses), use env.get_scene_info().
559 |     output = env.get_scene_info(mask_info=True)
560 |     print(output.keys())
561 | 
562 |     # env.poke() pushes one object and returns the same information, together with the executed action and the scene flow.
563 |     output = env.poke()
564 |     print(output.keys())
-------------------------------------------------------------------------------- /test.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import argparse
4 | from tqdm import tqdm
5 | import os.path as osp
6 | from data import Data
7 | from torch.utils.data import DataLoader
8 | from model import ModelDSR
9 | import itertools
10 | 
11 | 
12 | parser = argparse.ArgumentParser()
13 | 
14 | parser.add_argument('--resume', type=str, help='path to model')
15 | parser.add_argument('--data_path', type=str, help='path to data')
16 | parser.add_argument('--test_type', type=str, choices=['motion_visible', 'motion_full', 'mask_ordered', 'mask_unordered'])
17 | 
18 | parser.add_argument('--gpu', type=int, default=0, help='gpu id (single gpu)')
19 | parser.add_argument('--object_num', type=int, default=5, help='number of objects')
20 | parser.add_argument('--seq_len', type=int, default=10, help='sequence length')
21 | parser.add_argument('--batch', type=int, default=12, help='batch size')
22 | parser.add_argument('--workers', type=int, default=2, help='number of workers in data loader')
23 | 
24 | parser.add_argument('--model_type', type=str, default='dsr', choices=['dsr', 'single', 'nowarp', 'gtwarp', '3dflow'])
25 | parser.add_argument('--transform_type', type=str, default='se3euler', choices=['affine', 'se3euler', 'se3aa', 'se3spquat', 'se3quat'])
26 | 
27 | def main():
28 |     args = parser.parse_args()
29 |     torch.cuda.set_device(args.gpu)
30 | 
31 |     data, loaders = {}, {}
32 |     for split in ['test']:
33 |         data[split] = Data(data_path=args.data_path, split=split, seq_len=args.seq_len)
34 |         loaders[split] = DataLoader(dataset=data[split], batch_size=args.batch, num_workers=args.workers)
35 |     print('==> dataset loaded: [size] = {0}'.format(len(data['test'])))
36 | 
37 | 
38 |     model = ModelDSR(
39 |         object_num=args.object_num,
40 |         transform_type=args.transform_type,
41 |         motion_type='se3' if args.model_type != '3dflow' else 'conv',
42 |     )
43 |     model.cuda()
44 | 
45 |     checkpoint = torch.load(args.resume, map_location=torch.device(f'cuda:{args.gpu}'))
46 |     model.load_state_dict(checkpoint['state_dict'])
47 |     print('==> resume: ' + args.resume)
48 | 
49 |     with torch.no_grad():
50 |         if args.test_type == 'motion_visible':
51 |             evaluation_motion_visible(args, model, loaders['test'])
52 | 
53 |         if args.test_type == 'motion_full':
54 |             evaluation_motion_full(args, model, loaders['test'])
55 | 
56 |         if args.test_type == 'mask_ordered':
57 |             evaluation_mask_ordered(args, model, loaders['test'])
58 | 
59 |         if args.test_type == 'mask_unordered':
60 |             evaluation_mask_unordered(args,
model, loaders['test']) 61 | 62 | def evaluation_mask_unordered(args, model, loader): 63 | print(f'==> evaluation_mask (unordered)') 64 | iou_dict = [[] for _ in range(args.seq_len)] 65 | for batch in tqdm(loader): 66 | batch_size = batch['0-action'].size(0) 67 | last_s = model.get_init_repr(batch_size).cuda() 68 | logit_pred_list, mask_gt_list = [], [] 69 | for step_id in range(args.seq_len): 70 | output = model( 71 | input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1), 72 | last_s=last_s, 73 | input_action=batch['%d-action' % step_id].cuda(), 74 | input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None, 75 | no_warp=args.model_type=='nowarp' 76 | ) 77 | if not args.model_type == 'single': 78 | last_s = output['s'].data 79 | 80 | logit_pred = output['init_logit'] 81 | mask_gt = batch['%d-mask_3d' % step_id].cuda() 82 | iou_unordered = calc_iou_unordered(logit_pred, mask_gt) 83 | iou_dict[step_id].append(iou_unordered) 84 | print('mask_unordered (IoU) = ', np.mean([np.mean(np.concatenate(iou_dict[i])) for i in range(args.seq_len)])) 85 | 86 | 87 | def calc_iou_unordered(logit_pred, mask_gt_argmax): 88 | # logit_pred: [B, K, S1, S2, S3], softmax, the last channel is empty 89 | # mask_gt_argmax: [B, S1, S2, S3], 0 represents empty 90 | B, K, S1, S2, S3 = logit_pred.size() 91 | logit_pred_argmax = torch.argmax(logit_pred, dim=1, keepdim=True) 92 | mask_gt_argmax = torch.unsqueeze(mask_gt_argmax, 1) 93 | mask_pred_onehot = torch.zeros_like(logit_pred).scatter(1, logit_pred_argmax, 1)[:, :-1] 94 | mask_gt_onehot = torch.zeros_like(logit_pred).scatter(1, mask_gt_argmax, 1)[:, 1:] 95 | K -= 1 96 | info_dict = {'I': np.zeros([B, K, K]), 'U': np.zeros([B, K, K])} 97 | for b in range(B): 98 | for i in range(K): 99 | for j in range(K): 100 | mask_gt = mask_gt_onehot[b, i] 101 | mask_pred = mask_pred_onehot[b, j] 102 | I = torch.sum(mask_gt * mask_pred).item() 103 | U = torch.sum(mask_gt + mask_pred).item() - I 104 | info_dict['I'][b, i, j] = I 105 | info_dict['U'][b, i, j] = U 106 | batch_ious = [] 107 | for b in range(B): 108 | best_iou, best_p = 0, None 109 | for p in list(itertools.permutations(range(K))): 110 | cur_I = [info_dict['I'][b, i, p[i]] for i in range(K)] 111 | cur_U = [info_dict['U'][b, i, p[i]] for i in range(K)] 112 | cur_iou = np.mean(np.array(cur_I) / np.maximum(np.array(cur_U), 1)) 113 | if cur_iou > best_iou: 114 | best_iou = cur_iou 115 | batch_ious.append(best_iou) 116 | 117 | return np.array(batch_ious) 118 | 119 | 120 | def evaluation_mask_ordered(args, model, loader): 121 | print(f'==> evaluation_mask (ordered)') 122 | iou_dict = [] 123 | for batch in tqdm(loader): 124 | batch_size = batch['0-action'].size(0) 125 | last_s = model.get_init_repr(batch_size).cuda() 126 | logit_pred_list, mask_gt_list = [], [] 127 | for step_id in range(args.seq_len): 128 | output = model( 129 | input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1), 130 | last_s=last_s, 131 | input_action=batch['%d-action' % step_id].cuda(), 132 | input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None, 133 | no_warp=args.model_type=='nowarp' 134 | ) 135 | if not args.model_type == 'single': 136 | last_s = output['s'].data 137 | 138 | logit_pred = output['init_logit'] 139 | mask_gt = batch['%d-mask_3d' % step_id].cuda() 140 | logit_pred_list.append(logit_pred) 141 | mask_gt_list.append(mask_gt) 142 | iou_ordered = calc_iou_ordered(logit_pred_list, mask_gt_list) 143 | iou_dict.append(iou_ordered) 144 | print('mask_ordered 
(IoU) = ', np.mean(np.concatenate(iou_dict)))
145 | 
146 | 
147 | def calc_iou_ordered(logit_pred_list, mask_gt_argmax_list):
148 |     # logit_pred_list: [L, B, K, S1, S2, S3], softmax, the last channel is empty
149 |     # mask_gt_argmax_list: [L, B, S1, S2, S3], 0 represents empty
150 |     L = len(logit_pred_list)
151 |     B, K, S1, S2, S3 = logit_pred_list[0].size()
152 |     K -= 1
153 |     info_dict = {'I': np.zeros([L, B, K, K]), 'U': np.zeros([L, B, K, K])}
154 |     for l in range(L):
155 |         logit_pred = logit_pred_list[l]
156 |         mask_gt_argmax = mask_gt_argmax_list[l]
157 |         logit_pred_argmax = torch.argmax(logit_pred, dim=1, keepdim=True)
158 |         mask_gt_argmax = torch.unsqueeze(mask_gt_argmax, 1)
159 |         mask_pred_onehot = torch.zeros_like(logit_pred).scatter(1, logit_pred_argmax, 1)[:, :-1]
160 |         mask_gt_onehot = torch.zeros_like(logit_pred).scatter(1, mask_gt_argmax, 1)[:, 1:]
161 |         for b in range(B):
162 |             for i in range(K):
163 |                 for j in range(K):
164 |                     mask_gt = mask_gt_onehot[b, i]
165 |                     mask_pred = mask_pred_onehot[b, j]
166 |                     I = torch.sum(mask_gt * mask_pred).item()
167 |                     U = torch.sum(mask_gt + mask_pred).item() - I
168 |                     info_dict['I'][l, b, i, j] = I
169 |                     info_dict['U'][l, b, i, j] = U
170 |     batch_ious = []
171 |     for b in range(B):
172 |         best_iou, best_p = 0, None
173 |         for p in list(itertools.permutations(range(K))):
174 |             cur_I = [info_dict['I'][l, b, i, p[i]] for l in range(L) for i in range(K)]
175 |             cur_U = [info_dict['U'][l, b, i, p[i]] for l in range(L) for i in range(K)]
176 |             cur_iou = np.mean(np.array(cur_I) / np.maximum(np.array(cur_U), 1))
177 |             if cur_iou > best_iou:
178 |                 best_iou = cur_iou
179 |         batch_ious.append(best_iou)
180 | 
181 |     return np.array(batch_ious)
182 | 
183 | 
184 | def evaluation_motion_visible(args, model, loader):
185 |     print('==> evaluation_motion (visible surface)')
186 |     mse_dict = [0 for _ in range(args.seq_len)]
187 |     data_num = 0
188 |     for batch in tqdm(loader):
189 |         batch_size = batch['0-action'].size(0)
190 |         data_num += batch_size
191 |         last_s = model.get_init_repr(batch_size).cuda()
192 |         for step_id in range(args.seq_len):
193 |             output = model(
194 |                 input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1),
195 |                 last_s=last_s,
196 |                 input_action=batch['%d-action' % step_id].cuda(),
197 |                 input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None,
198 |                 no_warp=args.model_type=='nowarp'
199 |             )
200 |             if args.model_type not in ['single', '3dflow']:
201 |                 last_s = output['s'].data
202 | 
203 |             tsdf = batch['%d-tsdf' % step_id].cuda().unsqueeze(1)
204 |             mask = batch['%d-mask_3d' % step_id].cuda().unsqueeze(1)
205 |             surface_mask = ((tsdf > -0.99).float()) * ((tsdf < 0).float()) * ((mask > 0).float())
206 |             surface_mask[..., 0] = 0
207 | 
208 |             target = batch['%d-scene_flow_3d' % step_id].cuda()
209 |             pred = output['motion']
210 | 
211 |             mse = torch.sum((target - pred) ** 2 * surface_mask, dim=[1, 2, 3, 4]) / torch.sum(surface_mask, dim=[1, 2, 3, 4])
212 |             mse_dict[step_id] += torch.sum(mse).item() * 0.16
213 |             # 0.16 (= 0.4^2) is the scale that converts the unit from "voxel" to "cm".
214 |             # The voxel size is 0.4 cm. Here we use squared error.
215 |     print('motion_visible (MSE in cm) = ', np.mean([np.mean(mse_dict[i]) / data_num for i in range(args.seq_len)]))
216 | 
217 | 
218 | def evaluation_motion_full(args, model, loader):
219 |     print('==> evaluation_motion (full volume)')
220 |     mse_dict = [0 for _ in range(args.seq_len)]
221 |     data_num = 0
222 |     for batch in tqdm(loader):
223 |         batch_size = batch['0-action'].size(0)
224 |         data_num += batch_size
225 |         last_s = model.get_init_repr(batch_size).cuda()
226 |         for step_id in range(args.seq_len):
227 |             output = model(
228 |                 input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1),
229 |                 last_s=last_s,
230 |                 input_action=batch['%d-action' % step_id].cuda(),
231 |                 input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None,
232 |                 no_warp=args.model_type=='nowarp'
233 |             )
234 |             if args.model_type not in ['single', '3dflow']:
235 |                 last_s = output['s'].data
236 | 
237 |             target = batch['%d-scene_flow_3d' % step_id].cuda()
238 |             pred = output['motion']
239 | 
240 |             mse = torch.mean((target - pred) ** 2, dim=[1, 2, 3, 4])
241 |             mse_dict[step_id] += torch.sum(mse).item() * 0.16
242 |             # 0.16 (= 0.4^2) is the scale that converts the unit from "voxel" to "cm".
243 |             # The voxel size is 0.4 cm. Here we use squared error.
244 |     print('motion_full (MSE in cm) = ', np.mean([np.mean(mse_dict[i]) / data_num for i in range(args.seq_len)]))
245 | 
246 | 
247 | if __name__ == '__main__':
248 |     main()
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | import os.path as osp
2 | import torch
3 | from torch import nn
4 | import torch.multiprocessing as mp
5 | import torch.distributed as dist
6 | from torch.utils.data import DataLoader
7 | from torch.utils.tensorboard import SummaryWriter
8 | import argparse
9 | import itertools
10 | import shutil
11 | from tqdm import tqdm
12 | from data import Data
13 | from utils import mkdir, flow2im, html_visualize, mask_visualization, tsdf_visualization
14 | from model import ModelDSR
15 | 
16 | parser = argparse.ArgumentParser()
17 | 
18 | # exp args
19 | parser.add_argument('--exp', type=str, help='name of exp')
20 | parser.add_argument('--gpus', type=int, nargs='+', help='list of gpus to be used, separated by space')
21 | parser.add_argument('--resume', default=None, type=str, help='path to model or exp, None means training from scratch')
22 | 
23 | # data args
24 | parser.add_argument('--data_path', type=str, help='path to data')
25 | parser.add_argument('--object_num', type=int, default=5, help='number of objects')
26 | parser.add_argument('--seq_len', type=int, default=10, help='sequence length for training')
27 | parser.add_argument('--batch', type=int, default=12, help='batch size per gpu')
28 | parser.add_argument('--workers', type=int, default=4, help='number of workers per gpu')
29 | 
30 | parser.add_argument('--model_type', type=str, default='dsr', choices=['dsr', 'single', 'nowarp', 'gtwarp', '3dflow'])
31 | parser.add_argument('--transform_type', type=str, default='se3euler', choices=['affine', 'se3euler', 'se3aa', 'se3spquat', 'se3quat'])
32 | 
33 | # loss args
34 | parser.add_argument('--alpha_motion', type=float, default=1.0, help='weight of motion loss (MSE)')
35 | parser.add_argument('--alpha_mask', type=float, default=5.0, help='weight of mask loss (cross entropy)')
36 | 
37 | # training args
38 | parser.add_argument('--snapshot_freq', type=int, default=1, help='snapshot frequency')
39 | 
parser.add_argument('--epoch', type=int, default=30, help='number of training eposhes') 40 | parser.add_argument('--finetune', dest='finetune', action='store_true', 41 | help='finetuning or training from scratch ==> different learning rate strategies') 42 | 43 | # distributed training args 44 | parser.add_argument('--seed', type=int, default=23333, help='random seed') 45 | parser.add_argument('--dist_backend', type=str, default='nccl', help='distributed training backend') 46 | parser.add_argument('--dist_url', type=str, default='tcp://127.0.0.1:2333', help='distributed training url') 47 | 48 | 49 | def main(): 50 | args = parser.parse_args() 51 | 52 | # loss types & loss_idx 53 | loss_types = ['all', 'motion', 'mask'] 54 | loss_idx = {} 55 | for i, loss_type in enumerate(loss_types): 56 | loss_idx[loss_type] = i 57 | print('==> loss types: ', loss_types) 58 | args.loss_types = loss_types 59 | args.loss_idx = loss_idx 60 | 61 | # check sequence length 62 | if args.model_type == 'single': 63 | assert(args.seq_len == 1) 64 | 65 | # resume 66 | if args.resume is not None and not args.resume.endswith('.pth'): 67 | args.resume = osp.join('exp', args.resume, 'models/latest.pth') 68 | 69 | # dir & args 70 | exp_dir = osp.join('exp', args.exp) 71 | mkdir(exp_dir) 72 | 73 | print('==> arguments parsed') 74 | str_list = [] 75 | for key in vars(args): 76 | print('[{0}] = {1}'.format(key, getattr(args, key))) 77 | str_list.append('--{0}={1} \\'.format(key, getattr(args, key))) 78 | 79 | args.model_dir = osp.join(exp_dir, 'models') 80 | mkdir(args.model_dir) 81 | args.visualization_dir = osp.join(exp_dir, 'visualization') 82 | mkdir(args.visualization_dir) 83 | 84 | mp.spawn(main_worker, nprocs=len(args.gpus), args=(len(args.gpus), args)) 85 | 86 | 87 | def main_worker(rank, world_size, args): 88 | args.gpu = args.gpus[rank] 89 | if rank == 0: 90 | writer = SummaryWriter(osp.join('exp', args.exp)) 91 | print(f'==> Rank={rank}, Use GPU: {args.gpu} for training.') 92 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=world_size, rank=rank) 93 | 94 | torch.cuda.set_device(args.gpu) 95 | 96 | model = ModelDSR( 97 | object_num=args.object_num, 98 | transform_type=args.transform_type, 99 | motion_type='se3' if args.model_type != '3dflow' else 'conv', 100 | ) 101 | 102 | model.cuda() 103 | optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.95)) 104 | 105 | if args.resume is not None: 106 | checkpoint = torch.load(args.resume, map_location=torch.device(f'cuda:{args.gpu}')) 107 | model.load_state_dict(checkpoint['state_dict']) 108 | print(f'==> rank={rank}, loaded checkpoint {args.resume}') 109 | 110 | data, samplers, loaders = {}, {}, {} 111 | for split in ['train', 'test']: 112 | data[split] = Data(data_path=args.data_path, split=split, seq_len=args.seq_len) 113 | samplers[split] = torch.utils.data.distributed.DistributedSampler(data[split]) 114 | loaders[split] = DataLoader( 115 | dataset=data[split], 116 | batch_size=args.batch, 117 | num_workers=args.workers, 118 | sampler=samplers[split], 119 | pin_memory=False 120 | ) 121 | print('==> dataset loaded: [size] = {0} + {1}'.format(len(data['train']), len(data['test']))) 122 | 123 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 124 | 125 | for epoch in range(args.epoch): 126 | samplers['train'].set_epoch(epoch) 127 | lr = adjust_learning_rate(optimizer, epoch, args) 128 | if rank == 0: 129 | print(f'==> epoch = {epoch}, lr = {lr}') 130 | 131 | with torch.enable_grad(): 132 | 
loss_tensor_train = iterate(loaders['train'], model, optimizer, rank, args) 133 | with torch.no_grad(): 134 | loss_tensor_test = iterate(loaders['test'], model, None, rank, args) 135 | 136 | # tensorboard log 137 | loss_tensor = torch.stack([loss_tensor_train, loss_tensor_test]).cuda() 138 | torch.distributed.all_reduce(loss_tensor) 139 | if rank == 0: 140 | training_step = (epoch + 1) * len(data['train']) 141 | loss_tensor = loss_tensor.cpu().numpy() 142 | for i, split in enumerate(['train', 'test']): 143 | for j, loss_type in enumerate(args.loss_types): 144 | for step_id in range(args.seq_len): 145 | writer.add_scalar( 146 | '%s-loss_%s/%d' % (split, loss_type, step_id), 147 | loss_tensor[i, j, step_id] / len(data[split]), epoch+1) 148 | writer.add_scalar('learning_rate', lr, epoch + 1) 149 | 150 | if rank == 0 and (epoch + 1) % args.snapshot_freq == 0: 151 | visualize(loaders, model, epoch, args) 152 | save_state = { 153 | 'state_dict': model.module.state_dict(), 154 | } 155 | torch.save(save_state, osp.join(args.model_dir, 'latest.pth')) 156 | shutil.copyfile( 157 | osp.join(args.model_dir, 'latest.pth'), 158 | osp.join(args.model_dir, 'epoch_%d.pth' % (epoch + 1)) 159 | ) 160 | 161 | def adjust_learning_rate(optimizer, epoch, args): 162 | if args.finetune: 163 | if epoch < 5: 164 | lr = 5e-4 165 | elif epoch < 10: 166 | lr = 2e-4 167 | elif epoch < 15: 168 | lr = 5e-5 169 | else: 170 | lr = 1e-5 171 | 172 | else: 173 | if epoch < 2: 174 | lr = 1e-5 175 | elif epoch < 5: 176 | lr = 1e-3 177 | elif epoch < 10: 178 | lr = 5e-4 179 | elif epoch < 20: 180 | lr = 2e-4 181 | elif epoch < 25: 182 | lr = 5e-5 183 | else: 184 | lr = 1e-5 185 | 186 | for param_group in optimizer.param_groups: 187 | param_group['lr'] = lr 188 | return lr 189 | 190 | 191 | def iterate(loader, model, optimizer, rank, args): 192 | motion_metric = nn.MSELoss() 193 | loss_tensor = torch.zeros([len(args.loss_types), args.seq_len]) 194 | if rank == 0: 195 | loader = tqdm(loader, desc='test' if optimizer is None else 'train') 196 | for batch in loader: 197 | batch_size = batch['0-action'].size(0) 198 | last_s = model.module.get_init_repr(batch_size).cuda() 199 | batch_order = None 200 | 201 | for step_id in range(args.seq_len): 202 | output = model( 203 | input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1), 204 | last_s=last_s, 205 | input_action=batch['%d-action' % step_id].cuda(), 206 | input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None, 207 | no_warp=args.model_type=='nowarp' 208 | ) 209 | last_s = output['s'].data 210 | loss = 0 211 | 212 | if 'motion' in args.loss_types: 213 | loss_motion = motion_metric( 214 | output['motion'], 215 | batch['%d-scene_flow_3d' % step_id].cuda() 216 | ) 217 | loss_tensor[args.loss_idx['motion'], step_id] += loss_motion.item() * batch_size 218 | loss += args.alpha_motion * loss_motion 219 | 220 | if 'mask' in args.loss_types: 221 | mask_gt = batch['%d-mask_3d' % step_id].cuda() 222 | if batch_order is None: 223 | batch_order = get_batch_order(output['init_logit'], mask_gt) 224 | loss_mask = get_mask_loss(output['init_logit'], mask_gt, batch_order) 225 | loss_tensor[args.loss_idx['mask'], step_id] += loss_mask.item() * batch_size 226 | loss += args.alpha_mask * loss_mask 227 | 228 | loss_tensor[args.loss_idx['all'], step_id] += loss.item() * batch_size 229 | 230 | if optimizer is not None: 231 | optimizer.zero_grad() 232 | loss.backward() 233 | optimizer.step() 234 | 235 | if step_id != args.seq_len - 1: 236 | batch_order = 
get_batch_order(output['init_logit'], mask_gt) 237 | return loss_tensor 238 | 239 | 240 | def get_batch_order(logit_pred, mask_gt): 241 | batch_order = [] 242 | B, K, S1, S2, S3 = logit_pred.size() 243 | sum = 0 244 | for b in range(B): 245 | all_p = list(itertools.permutations(list(range(K - 1)))) 246 | best_loss, best_p = None, None 247 | for p in all_p: 248 | permute_pred = torch.stack( 249 | [logit_pred[b:b + 1, -1]] + [logit_pred[b:b + 1, i] for i in p], 250 | dim=1).contiguous() 251 | cur_loss = nn.CrossEntropyLoss()(permute_pred, mask_gt[b:b + 1]).item() 252 | if best_loss is None or cur_loss < best_loss: 253 | best_loss = cur_loss 254 | best_p = p 255 | batch_order.append(best_p) 256 | sum += best_loss 257 | return batch_order 258 | 259 | 260 | def get_mask_loss(logit_pred, mask_gt, batch_order): 261 | loss = 0 262 | B, K, S1, S2, S3 = logit_pred.size() 263 | for b in range(B): 264 | permute_pred = torch.stack( 265 | [logit_pred[b:b + 1, -1]] + [logit_pred[b:b + 1, i] for i in batch_order[b]], 266 | dim=1).contiguous() 267 | loss += nn.CrossEntropyLoss()(permute_pred, mask_gt[b:b + 1]) 268 | return loss 269 | 270 | 271 | def visualize(loaders, model, epoch, args): 272 | visualization_path = osp.join(args.visualization_dir, 'epoch_%03d' % (epoch + 1)) 273 | figures = {} 274 | ids = [split + '_' + str(itr) + '-' + str(step_id) 275 | for split in ['train', 'test'] 276 | for itr in range(args.batch) 277 | for step_id in range(args.seq_len)] 278 | cols = ['color_image', 'color_heightmap', 'motion_gt', 'motion_pred', 'mask_gt'] 279 | if args.model_type != '3dflow': 280 | cols = cols + ['mask_pred', 'next_mask_pred'] 281 | 282 | with torch.no_grad(): 283 | for split in ['train', 'test']: 284 | model.train() 285 | batch = iter(loaders[split]).next() 286 | batch_size = batch['0-action'].size(0) 287 | last_s = model.module.get_init_repr(batch_size).cuda() 288 | for step_id in range(args.seq_len): 289 | output = model( 290 | input_volume=batch['%d-tsdf' % step_id].cuda().unsqueeze(1), 291 | last_s=last_s, 292 | input_action=batch['%d-action' % step_id].cuda(), 293 | input_motion=batch['%d-scene_flow_3d' % step_id].cuda() if args.model_type=='gtwarp' else None, 294 | no_warp=args.model_type=='nowarp', 295 | next_mask=True 296 | ) 297 | last_s = output['s'].data 298 | 299 | vis_color_image = batch['%d-color_image' % step_id].numpy() 300 | vis_color_heightmap = batch['%d-color_heightmap' % step_id].numpy() 301 | motion_gt = torch.sum(batch['%d-scene_flow_3d' % step_id][:, :2, ...], dim=4).numpy() 302 | motion_pred = torch.sum(output['motion'][:, :2, ...], dim=4).cpu().numpy() 303 | 304 | vis_mask_gt = mask_visualization(batch['%d-mask_3d' % step_id].numpy()) 305 | 306 | if args.model_type != '3dflow': 307 | vis_mask_pred = mask_visualization(output['init_logit'].cpu().numpy()) 308 | vis_next_mask_pred = mask_visualization(output['next_mask'].cpu().numpy()) 309 | 310 | for k in range(args.batch): 311 | figures['%s_%d-%d_color_image' % (split, k, step_id)] = vis_color_image[k] 312 | figures['%s_%d-%d_color_heightmap' % (split, k, step_id)] = vis_color_heightmap[k] 313 | figures['%s_%d-%d_motion_gt' % (split, k, step_id)] = flow2im(motion_gt[k]) 314 | figures['%s_%d-%d_motion_pred' % (split, k, step_id)] = flow2im(motion_pred[k]) 315 | figures['%s_%d-%d_mask_gt' % (split, k, step_id)] = vis_mask_gt[k] 316 | if args.model_type != '3dflow': 317 | figures['%s_%d-%d_mask_pred' % (split, k, step_id)] = vis_mask_pred[k] 318 | figures['%s_%d-%d_next_mask_pred' % (split, k, step_id)] = 
vis_next_mask_pred[k] 319 | 320 | html_visualize(visualization_path, figures, ids, cols, title=args.exp) 321 | 322 | 323 | if __name__ == '__main__': 324 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import collections 3 | import math 4 | import os 5 | import shutil 6 | 7 | import cv2 8 | import imageio 9 | import numpy as np 10 | import dominate 11 | from dominate.tags import * 12 | import queue 13 | import threading 14 | 15 | 16 | # Get rotation matrix from euler angles 17 | def euler2rotm(theta): 18 | R_x = np.array([[1, 0, 0], 19 | [0, math.cos(theta[0]), -math.sin(theta[0])], 20 | [0, math.sin(theta[0]), math.cos(theta[0])] 21 | ]) 22 | R_y = np.array([[math.cos(theta[1]), 0, math.sin(theta[1])], 23 | [0, 1, 0], 24 | [-math.sin(theta[1]), 0, math.cos(theta[1])] 25 | ]) 26 | R_z = np.array([[math.cos(theta[2]), -math.sin(theta[2]), 0], 27 | [math.sin(theta[2]), math.cos(theta[2]), 0], 28 | [0, 0, 1] 29 | ]) 30 | R = np.dot(R_z, np.dot(R_y, R_x)) 31 | return R 32 | 33 | 34 | def transform_points(pts, transform): 35 | # pts = [3xN] array 36 | # transform: [3x4] 37 | pts_t = np.dot(transform[0:3, 0:3], pts) + np.tile(transform[0:3, 3:], (1, pts.shape[1])) 38 | return pts_t 39 | 40 | 41 | def project_pts_to_2d(pts, camera_view_matrix, camera_intrisic): 42 | # transformation from word to virtual camera 43 | # camera_intrisic for virtual camera [ [f,0,0],[0,f,0],[0,0,1]] f is focal length 44 | # RT_wrd2cam 45 | pts_c = transform_points(pts, camera_view_matrix[0:3, :]) 46 | rot_algix = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0]]) 47 | pts_c = transform_points(pts_c, rot_algix) 48 | coord_2d = np.dot(camera_intrisic, pts_c) 49 | coord_2d[0:2, :] = coord_2d[0:2, :] / np.tile(coord_2d[2, :], (2, 1)) 50 | coord_2d[2, :] = pts_c[2, :] 51 | return coord_2d 52 | 53 | 54 | def project_pts_to_3d(color_image, depth_image, camera_intr, camera_pose): 55 | W, H = depth_image.shape 56 | cam_pts, rgb_pts = get_pointcloud(color_image, depth_image, camera_intr) 57 | world_pts = np.transpose( 58 | np.dot(camera_pose[0:3, 0:3], np.transpose(cam_pts)) + np.tile(camera_pose[0:3, 3:], (1, cam_pts.shape[0]))) 59 | 60 | pts = world_pts.reshape([W, H, 3]) 61 | pts = np.transpose(pts, [2, 0, 1]) 62 | 63 | return pts 64 | 65 | 66 | def get_pointcloud(color_img, depth_img, camera_intrinsics): 67 | # Get depth image size 68 | im_h = depth_img.shape[0] 69 | im_w = depth_img.shape[1] 70 | 71 | # Project depth into 3D point cloud in camera coordinates 72 | pix_x, pix_y = np.meshgrid(np.linspace(0, im_w - 1, im_w), np.linspace(0, im_h - 1, im_h)) 73 | cam_pts_x = np.multiply(pix_x - camera_intrinsics[0, 2], depth_img / camera_intrinsics[0, 0]) 74 | cam_pts_y = np.multiply(pix_y - camera_intrinsics[1, 2], depth_img / camera_intrinsics[1, 1]) 75 | cam_pts_z = depth_img.copy() 76 | cam_pts_x.shape = (im_h * im_w, 1) 77 | cam_pts_y.shape = (im_h * im_w, 1) 78 | cam_pts_z.shape = (im_h * im_w, 1) 79 | 80 | # Reshape image into colors for 3D point cloud 81 | rgb_pts_r = color_img[:, :, 0] 82 | rgb_pts_g = color_img[:, :, 1] 83 | rgb_pts_b = color_img[:, :, 2] 84 | rgb_pts_r.shape = (im_h * im_w, 1) 85 | rgb_pts_g.shape = (im_h * im_w, 1) 86 | rgb_pts_b.shape = (im_h * im_w, 1) 87 | 88 | cam_pts = np.concatenate((cam_pts_x, cam_pts_y, cam_pts_z), axis=1) 89 | rgb_pts = np.concatenate((rgb_pts_r, rgb_pts_g, rgb_pts_b), axis=1) 90 | 91 | return cam_pts, rgb_pts 
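# ---------------------------------------------------------------------------
# Hedged worked example (added for illustration; `_pointcloud_example` is not
# part of the original utils.py and is never called by the library code).
# It spells out the pinhole back-projection that get_pointcloud applies per
# pixel: x = (u - cx) * z / fx, y = (v - cy) * z / fy, z = depth.
def _pointcloud_example():
    # A 2x2 depth image, every pixel 1 m deep, with a toy intrinsic matrix
    # (fx = fy = 100, principal point at (1, 1)).
    depth = np.ones((2, 2))
    color = np.zeros((2, 2, 3), dtype=np.uint8)
    intr = np.array([[100.0, 0.0, 1.0],
                     [0.0, 100.0, 1.0],
                     [0.0, 0.0, 1.0]])
    cam_pts, rgb_pts = get_pointcloud(color, depth, intr)
    # cam_pts has shape (4, 3); pixel (u=0, v=0) back-projects to
    # ((0 - 1) * 1 / 100, (0 - 1) * 1 / 100, 1.0) = (-0.01, -0.01, 1.0).
    return cam_pts, rgb_pts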
92 | 93 | 94 | def get_heightmap(color_img, depth_img, cam_intrinsics, cam_pose, workspace_limits, heightmap_resolution): 95 | # Compute heightmap size 96 | heightmap_size = np.round(((workspace_limits[1][1] - workspace_limits[1][0]) / heightmap_resolution, 97 | (workspace_limits[0][1] - workspace_limits[0][0]) / heightmap_resolution)).astype(int) 98 | 99 | # Get 3D point cloud from RGB-D images 100 | surface_pts, color_pts = get_pointcloud(color_img, depth_img, cam_intrinsics) 101 | 102 | # Transform 3D point cloud from camera coordinates to robot coordinates 103 | surface_pts = np.transpose( 104 | np.dot(cam_pose[0:3, 0:3], np.transpose(surface_pts)) + np.tile(cam_pose[0:3, 3:], (1, surface_pts.shape[0]))) 105 | 106 | # Sort surface points by z value 107 | sort_z_ind = np.argsort(surface_pts[:, 2]) 108 | surface_pts = surface_pts[sort_z_ind] 109 | color_pts = color_pts[sort_z_ind] 110 | 111 | # Filter out surface points outside heightmap boundaries 112 | heightmap_valid_ind = np.logical_and(np.logical_and(np.logical_and( 113 | np.logical_and(surface_pts[:, 0] >= workspace_limits[0][0], surface_pts[:, 0] < workspace_limits[0][1]), 114 | surface_pts[:, 1] >= workspace_limits[1][0]), surface_pts[:, 1] < workspace_limits[1][1]), 115 | surface_pts[:, 2] < workspace_limits[2][1]) 116 | surface_pts = surface_pts[heightmap_valid_ind] 117 | color_pts = color_pts[heightmap_valid_ind] 118 | 119 | # Create orthographic top-down-view RGB-D heightmaps 120 | color_heightmap_r = np.zeros((heightmap_size[0], heightmap_size[1], 1), dtype=np.uint8) 121 | color_heightmap_g = np.zeros((heightmap_size[0], heightmap_size[1], 1), dtype=np.uint8) 122 | color_heightmap_b = np.zeros((heightmap_size[0], heightmap_size[1], 1), dtype=np.uint8) 123 | depth_heightmap = np.zeros(heightmap_size) 124 | heightmap_pix_x = np.floor((surface_pts[:, 0] - workspace_limits[0][0]) / heightmap_resolution).astype(int) 125 | heightmap_pix_y = np.floor((surface_pts[:, 1] - workspace_limits[1][0]) / heightmap_resolution).astype(int) 126 | color_heightmap_r[heightmap_pix_y, heightmap_pix_x] = color_pts[:, [0]] 127 | color_heightmap_g[heightmap_pix_y, heightmap_pix_x] = color_pts[:, [1]] 128 | color_heightmap_b[heightmap_pix_y, heightmap_pix_x] = color_pts[:, [2]] 129 | color_heightmap = np.concatenate((color_heightmap_r, color_heightmap_g, color_heightmap_b), axis=2) 130 | depth_heightmap[heightmap_pix_y, heightmap_pix_x] = surface_pts[:, 2] 131 | z_bottom = workspace_limits[2][0] 132 | depth_heightmap = depth_heightmap - z_bottom 133 | depth_heightmap[depth_heightmap < 0] = 0 134 | # depth_heightmap[depth_heightmap == -z_bottom] = np.nan 135 | 136 | return color_heightmap, depth_heightmap 137 | 138 | 139 | def mkdir(path, clean=False): 140 | if clean and os.path.exists(path): 141 | shutil.rmtree(path) 142 | if not os.path.exists(path): 143 | os.makedirs(path) 144 | 145 | 146 | def imresize(im, dsize, cfirst=False): 147 | if cfirst: 148 | im = im.transpose(1, 2, 0) 149 | im = cv2.resize(im, dsize=dsize) 150 | if cfirst: 151 | im = im.transpose(2, 0, 1) 152 | return im 153 | 154 | 155 | def imretype(im, dtype): 156 | im = np.array(im) 157 | 158 | if im.dtype in ['float', 'float16', 'float32', 'float64']: 159 | im = im.astype(np.float) 160 | elif im.dtype == 'uint8': 161 | im = im.astype(np.float) / 255. 162 | elif im.dtype == 'uint16': 163 | im = im.astype(np.float) / 65535. 
164 | else: 165 | raise NotImplementedError('unsupported source dtype: {0}'.format(im.dtype)) 166 | 167 | assert np.min(im) >= 0 and np.max(im) <= 1 168 | 169 | if dtype in ['float', 'float16', 'float32', 'float64']: 170 | im = im.astype(dtype) 171 | elif dtype == 'uint8': 172 | im = (im * 255.).astype(dtype) 173 | elif dtype == 'uint16': 174 | im = (im * 65535.).astype(dtype) 175 | else: 176 | raise NotImplementedError('unsupported target dtype: {0}'.format(dtype)) 177 | 178 | return im 179 | 180 | 181 | def imwrite(path, obj): 182 | if not isinstance(obj, (collections.Sequence, collections.UserList)): 183 | obj = [obj] 184 | writer = imageio.get_writer(path) 185 | for im in obj: 186 | im = imretype(im, dtype='uint8').squeeze() 187 | if len(im.shape) == 3 and im.shape[0] == 3: 188 | im = np.transpose(im, (1, 2, 0)) 189 | writer.append_data(im) 190 | writer.close() 191 | 192 | 193 | def flow2im(flow, max=None, dtype='float32', cfirst=False): 194 | flow = np.array(flow) 195 | 196 | if np.ndim(flow) == 3 and flow.shape[0] == 2: 197 | x, y = flow[:, ...] 198 | elif np.ndim(flow) == 3 and flow.shape[-1] == 2: 199 | x = flow[..., 0] 200 | y = flow[..., 1] 201 | else: 202 | raise NotImplementedError( 203 | 'unsupported flow size: {0}'.format(flow.shape)) 204 | 205 | rho, theta = cv2.cartToPolar(x, y) 206 | 207 | if max is None: 208 | max = np.maximum(np.max(rho), 1e-6) 209 | 210 | hsv = np.zeros(list(rho.shape) + [3], dtype=np.uint8) 211 | hsv[..., 0] = theta * 90 / np.pi 212 | hsv[..., 1] = 255 213 | hsv[..., 2] = np.minimum(rho / max, 1) * 255 214 | 215 | im = cv2.cvtColor(hsv, code=cv2.COLOR_HSV2RGB) 216 | im = imretype(im, dtype=dtype) 217 | 218 | if cfirst: 219 | im = im.transpose(2, 0, 1) 220 | return im 221 | 222 | 223 | def draw_arrow(image, action, direction_num=8, heightmap_pixel_size=0.004): 224 | # image: [W, H, 3] (color image) or [W, H] (depth image) 225 | def put_in_bound(val, bound): 226 | # output: 0 <= val < bound 227 | val = min(max(0, val), bound - 1) 228 | return val 229 | 230 | img = image.copy() 231 | if isinstance(action, tuple): 232 | x_ini, y_ini, direction = action 233 | else: 234 | x_ini, y_ini, direction = action['2'], action['1'], action['0'] 235 | 236 | pushing_distance = 0.15 237 | 238 | angle = direction / direction_num * 2 * np.pi 239 | x_end = put_in_bound(int(x_ini + pushing_distance / heightmap_pixel_size * np.cos(angle)), image.shape[1]) 240 | y_end = put_in_bound(int(y_ini + pushing_distance / heightmap_pixel_size * np.sin(angle)), image.shape[0]) 241 | 242 | if img.shape[0] == 1: 243 | # gray img, white arrow 244 | img = imretype(img[:, :, np.newaxis], 'uint8') 245 | cv2.arrowedLine(img=img, pt1=(x_ini, y_ini), pt2=(x_end, y_end), color=255, thickness=2, tipLength=0.2) 246 | elif img.shape[2] == 3: 247 | # rgb img, red arrow 248 | cv2.arrowedLine(img=img, pt1=(x_ini, y_ini), pt2=(x_end, y_end), color=(255, 0, 0), thickness=2, tipLength=0.2) 249 | return img 250 | 251 | 252 | def multithreading_exec(num, q, fun, blocking=True): 253 | """ 254 | Multi-threading Execution 255 | 256 | :param num: number of threadings 257 | :param q: queue of args 258 | :param fun: function to be executed 259 | :param blocking: blocking or not (default True) 260 | """ 261 | 262 | class Worker(threading.Thread): 263 | def __init__(self, q, fun): 264 | super().__init__() 265 | self.q = q 266 | self.fun = fun 267 | self.start() 268 | 269 | def run(self): 270 | while True: 271 | try: 272 | args = self.q.get(block=False) 273 | self.fun(*args) 274 | self.q.task_done() 275 | except 
276 |                     break
277 | 
278 |     thread_list = [Worker(q, fun) for i in range(num)]
279 |     if blocking:
280 |         for t in thread_list:
281 |             if t.is_alive():
282 |                 t.join()
283 | 
284 | 
285 | def html_visualize(web_path, data, ids, cols, others=[], title='visualization', threading_num=10):
286 |     """
287 |     :param web_path: (str) directory to save webpage. It will clear the old data!
288 |     :param data: (dict of data).
289 |         key: {id}_{col}.
290 |         value: figure or text
291 |             - figure: ndarray --> .png or [ndarrays,] --> .gif
292 |             - text: str or [str,]
293 |     :param ids: (list of str) name of each row
294 |     :param cols: (list of str) name of each column
295 |     :param others: (list of dict) other figures
296 |         'name': str, name of the data, visualized using h2()
297 |         'data': string or ndarray(image)
298 |         'height': int, height of the image (default 256)
299 |     :param title: (str) title of the webpage
300 |     :param threading_num: number of threads for imwrite (default 10)
301 |     """
302 |     figure_path = os.path.join(web_path, 'figures')
303 |     mkdir(web_path, clean=True)
304 |     mkdir(figure_path, clean=True)
305 |     q = queue.Queue()
306 |     for key, value in data.items():
307 |         if isinstance(value, np.ndarray):
308 |             q.put((os.path.join(figure_path, key + '.png'), value))
309 |         elif isinstance(value, list) and isinstance(value[0], np.ndarray):
310 |             q.put((os.path.join(figure_path, key + '.gif'), value))
311 |     multithreading_exec(threading_num, q, imwrite)
312 | 
313 |     with dominate.document(title=title) as web:
314 |         dominate.tags.h1(title)
315 |         with dominate.tags.table(border=1, style='table-layout: fixed;'):
316 |             with dominate.tags.tr():
317 |                 with dominate.tags.td(style='word-wrap: break-word;', halign='center', align='center', width='64px'):
318 |                     dominate.tags.p('id')
319 |                 for col in cols:
320 |                     with dominate.tags.td(style='word-wrap: break-word;', halign='center', align='center'):
321 |                         dominate.tags.p(col)
322 |             for id in ids:
323 |                 with dominate.tags.tr():
324 |                     bgcolor = 'F1C073' if id.startswith('train') else 'C5F173'
325 |                     with dominate.tags.td(style='word-wrap: break-word;', halign='center', align='center',
326 |                                           bgcolor=bgcolor):
327 |                         for part in id.split('_'):
328 |                             dominate.tags.p(part)
329 |                     for col in cols:
330 |                         with dominate.tags.td(style='word-wrap: break-word;', halign='center', align='top'):
331 |                             value = data.get(f'{id}_{col}', None)
332 |                             if isinstance(value, str):
333 |                                 dominate.tags.p(value)
334 |                             elif isinstance(value, list) and isinstance(value[0], str):
335 |                                 for v in value:
336 |                                     dominate.tags.p(v)
337 |                             else:
338 |                                 dominate.tags.img(style='height:128px',
339 |                                                   src=os.path.join('figures', '{}_{}.png'.format(id, col)))
340 |         for idx, other in enumerate(others):
341 |             dominate.tags.h2(other['name'])
342 |             if isinstance(other['data'], str):
343 |                 dominate.tags.p(other['data'])
344 |             else:
345 |                 imwrite(os.path.join(figure_path, '_{}_{}.png'.format(idx, other['name'])), other['data'])
346 |                 dominate.tags.img(style='height:{}px'.format(other.get('height', 256)),
347 |                                   src=os.path.join('figures', '_{}_{}.png'.format(idx, other['name'])))
348 |     with open(os.path.join(web_path, 'index.html'), 'w') as fp:
349 |         fp.write(web.render())
350 | 
351 | 
352 | def mask_visualization(mask):
353 |     # mask: numpy array, [B, K, W, H, D] or [B, W, H, D]
354 |     # Red, Green, Blue, Yellow, Purple
355 |     colors = [(255, 87, 89), (89, 169, 79), (78, 121, 167), (237, 201, 72), (176, 122, 161)]
356 |     if len(mask.shape) == 5:
357 |         B, K, W, H, D = mask.shape
358 |         argmax_mask = np.argmax(mask, axis=1)
359 |     else:
360 |         B, W, H, D = mask.shape
361 |         K = max(np.max(mask) + 1, 2)
362 |         argmax_mask = mask - 1
363 | 
364 |     mask_list = []
365 |     for k in range(K - 1):
366 |         mask_list.append((argmax_mask == k).astype(np.float32))
367 | 
368 |     mask = np.sum(np.stack(mask_list, axis=1), axis=4)
369 |     sum_mask = np.sum(mask, 1) + 1  # [B, W, H]
370 |     color_mask = np.zeros([B, W, H, 3])
371 |     for i in range(K - 1):
372 |         color_mask += mask[:, i, ..., np.newaxis] * np.array(colors[i])
373 |     return np.clip(color_mask / sum_mask[..., np.newaxis] / 255.0, 0, 1)
374 | 
375 | 
376 | def mask_visualization_2d(mask):
377 |     # Red, Green, Blue, Yellow, Purple
378 |     colors = [(255, 87, 89), (89, 169, 79), (78, 121, 167), (237, 201, 72), (176, 122, 161)]
379 |     if len(mask.shape) == 4:
380 |         B, K, W, H = mask.shape
381 |         argmax_mask = np.argmax(mask, axis=1)
382 |     else:
383 |         B, W, H = mask.shape
384 |         K = max(np.max(mask) + 1, 2)
385 |         argmax_mask = mask
386 |     mask_list = []
387 |     for k in range(K):
388 |         mask_list.append((argmax_mask == k).astype(np.float32))
389 | 
390 |     mask = np.stack(mask_list, axis=1)
391 |     sum_mask = np.sum(mask, 1) + 1  # [B, W, H]
392 |     color_mask = np.zeros([B, W, H, 3])
393 |     for i in range(K):
394 |         color_mask += mask[:, i, ..., np.newaxis] * np.array(colors[i])
395 |     return np.clip(color_mask / sum_mask[..., np.newaxis] / 255.0, 0, 1)
396 | 
397 | 
398 | def volume_visualization(volume):
399 |     # volume: numpy array: [B, W, H, D]
400 |     tmp = np.sum(volume, axis=-1)
401 |     tmp -= np.min(tmp)
402 |     tmp /= max(np.max(tmp), 1)
403 |     return np.clip(tmp, 0, 1)
404 | 
405 | 
406 | def tsdf_visualization(tsdf):
407 |     # tsdf: numpy array: [B, W, H, D]
408 |     return volume_visualization((tsdf < 0).astype(np.float32))
409 | 
--------------------------------------------------------------------------------
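The snippet below is a minimal, hypothetical usage sketch of these visualization utilities; it is not part of the repository. It assumes `utils.py` is importable from the repository root, that the dependencies from `requirements.txt` (numpy, opencv-python, imageio, dominate) are installed, and all array shapes and values are invented purely for illustration.

```
# Sanity-check sketch for utils.py (assumed importable); shapes/values are made up.
import numpy as np

from utils import flow2im, mask_visualization, html_visualize

# Fake channel-first 2D flow [2, H, W] -> RGB image in [0, 1]
flow = np.random.randn(2, 128, 128).astype(np.float32)
flow_img = flow2im(flow)                    # [128, 128, 3], float32

# Fake per-voxel mask logits [B, K, W, H, D] -> color visualization [B, W, H, 3]
mask_logits = np.random.rand(1, 3, 128, 128, 48).astype(np.float32)
mask_img = mask_visualization(mask_logits)  # values already clipped to [0, 1]

# Dump both images into a browsable HTML table (one row, two columns).
# Note: html_visualize clears the output directory before writing.
html_visualize(
    web_path='vis_demo',
    data={'seq0000_flow': flow_img, 'seq0000_mask': mask_img[0]},
    ids=['seq0000'],
    cols=['flow', 'mask'],
    title='utils.py sanity check')
```

Opening `vis_demo/index.html` in a browser should then show one row labeled `seq0000` with the flow and mask renderings side by side.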