├── models
│   ├── utils
│   │   ├── __init__.py
│   │   ├── spmm.py
│   │   └── scatter.py
│   ├── __init__.py
│   ├── graph_resnet.py
│   ├── graph_unet.py
│   ├── octree_ounet.py
│   ├── graph_lenet.py
│   ├── graph_ae.py
│   ├── graph_slounet.py
│   ├── mpu.py
│   ├── graph_ounet.py
│   ├── modules.py
│   └── dual_octree.py
├── teaser.png
├── .gitmodules
├── .gitignore
├── requirements.txt
├── solver
│   ├── README.md
│   ├── __init__.py
│   ├── sampler.py
│   ├── dataset.py
│   ├── config.py
│   └── solver.py
├── losses
│   ├── __init__.py
│   └── loss.py
├── CODE_OF_CONDUCT.md
├── datasets
│   ├── __init__.py
│   ├── utils.py
│   ├── pointcloud_eval.py
│   ├── synthetic_room.py
│   ├── shapenet.py
│   └── pointcloud.py
├── configs
│   ├── shapes.yaml
│   ├── dfaust_eval.yaml
│   ├── shapenet_eval.yaml
│   ├── shapenet_ae_eval.yaml
│   ├── shapenet_unseen5.yaml
│   ├── synthetic_room_eval.yaml
│   ├── finetune.yaml
│   ├── cls_m40.yaml
│   ├── dfaust.yaml
│   ├── shapenet_ae.yaml
│   ├── shapenet.yaml
│   └── synthetic_room.yaml
├── LICENSE
├── SUPPORT.md
├── SECURITY.md
├── tools
│   ├── compute_metrics.py
│   ├── dfaust.py
│   ├── room.py
│   └── shapenet.py
├── dualocnn.py
├── utils.py
└── README.md

--------------------------------------------------------------------------------
/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/DualOctreeGNN/HEAD/teaser.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "solver"]
2 | path = solver
3 | url = https://github.com/wang-ps/solver-pytorch.git
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 | dist
3 | __pycache__
4 | *.egg-info/
5 | .vscode/
6 | *.pyd
7 | *.so
8 | *.octree
9 | *.points
10 | *.xyz
11 | *.json
12 | debug.log
13 | data
14 | logs
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | tqdm
4 | yacs
5 | scipy
6 | plyfile
7 | tensorboard
8 | scikit-image
9 | trimesh
10 | wget
11 | mesh2sdf
12 | setuptools==59.5.0
13 | matplotlib
--------------------------------------------------------------------------------
/solver/README.md:
--------------------------------------------------------------------------------
1 | # The Solver for PyTorch
2 | 
3 | This repository contains part of the code from the official implementation of
4 | [O-CNN](https://github.com/microsoft/O-CNN.git),
5 | and is released under the MIT license.
6 | 
7 | The code can be used for training/testing PyTorch models.
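
Below is a minimal usage sketch. `Solver`, `Dataset`, `get_config` and `parse_args` are re-exported by `solver/__init__.py`; the override points (`get_model`, `get_dataset`) and the `run()` entry shown here are assumptions for illustration — see `solver.py` for the authoritative interface.

```python
# Hypothetical sketch; only the imported names are taken from this package,
# the overridden method names and `run()` are assumptions.
import torch
from solver import Solver, Dataset, parse_args


class MySolver(Solver):

  def get_model(self, flags):
    return torch.nn.Linear(3, flags.nout)  # stand-in for a real network

  def get_dataset(self, flags):
    # Dataset(root, filelist, transform) as defined in solver/dataset.py;
    # the transform receives (sample, idx) and must return a dict.
    return Dataset(flags.location, flags.filelist,
                   transform=lambda sample, idx: {'data': sample})


if __name__ == '__main__':
  FLAGS = parse_args()   # presumably merges a yaml config with CLI overrides
  MySolver(FLAGS).run()  # presumably dispatches on SOLVER.run (train/evaluate)
```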
8 | 9 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | from .loss import shapenet_loss, dfaust_loss, synthetic_room_loss 9 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | from . import config 9 | from .config import get_config, parse_args 10 | 11 | from . import solver 12 | from .solver import Solver 13 | 14 | from . import dataset 15 | from .dataset import Dataset -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | from . import octree_ounet 9 | from . import graph_lenet 10 | from . import graph_resnet 11 | from . import graph_ounet 12 | from . import graph_slounet 13 | from . import graph_unet 14 | from . 
import graph_ae 15 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | from .shapenet import get_shapenet_dataset 9 | from .pointcloud import get_pointcloud_dataset, get_singlepointcloud_dataset 10 | from .pointcloud_eval import get_pointcloud_eval_dataset 11 | from .synthetic_room import get_synthetic_room_dataset 12 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import ocnn 9 | import torch 10 | 11 | 12 | def collate_func(batch): 13 | output = ocnn.collate_octrees(batch) 14 | 15 | if 'pos' in output: 16 | batch_idx = torch.cat([torch.ones(pos.size(0), 1) * i 17 | for i, pos in enumerate(output['pos'])], dim=0) 18 | pos = torch.cat(output['pos'], dim=0) 19 | output['pos'] = torch.cat([pos, batch_idx], dim=1) 20 | 21 | for key in ['grad', 'sdf', 'occu', 'weight']: 22 | if key in output: 23 | output[key] = torch.cat(output[key], dim=0) 24 | 25 | return output 26 | -------------------------------------------------------------------------------- /configs/shapes.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: evaluate 11 | logdir: logs/shapes_eval 12 | ckpt: logs/dfaust/dfaust/checkpoints/00600.model.pth 13 | resolution: 420 14 | sdf_scale: 0.9 15 | 16 | 17 | DATA: 18 | test: 19 | name: pointcloud 20 | point_scale: 0.9 21 | 22 | # octree building 23 | depth: 8 24 | full_depth: 3 25 | node_dis: True 26 | split_label: True 27 | offset: 0.0 28 | 29 | # data loading 30 | location: data/shapes 31 | filelist: data/shapes/filelist.txt 32 | batch_size: 1 33 | shuffle: False 34 | # num_workers: 0 35 | 36 | 37 | MODEL: 38 | name: graph_unet 39 | resblock_type: basic 40 | 41 | depth: 8 42 | full_depth: 3 43 | depth_out: 8 44 | channel: 4 45 | nout: 4 46 | -------------------------------------------------------------------------------- /configs/dfaust_eval.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: evaluate 11 | logdir: logs/dfaust_eval/dfaust 12 | ckpt: logs/dfaust/dfaust/checkpoints/00600.model.pth 13 | resolution: 300 14 | 
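# Note (assumption): `resolution` is the grid resolution used when extracting
# the mesh from the predicted SDF at evaluation time, and `sdf_scale` below
# rescales the predicted SDF values; both are consumed by the evaluation
# entry point rather than by the network itself.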
sdf_scale: 0.9 15 | 16 | 17 | DATA: 18 | test: 19 | name: pointcloud 20 | point_scale: 1.0 21 | 22 | # octree building 23 | depth: 8 24 | full_depth: 3 25 | node_dis: True 26 | split_label: True 27 | offset: 0.0 28 | 29 | # data loading 30 | location: data/dfaust/dataset 31 | filelist: data/dfaust/filelist/test.txt 32 | batch_size: 1 33 | shuffle: False 34 | # num_workers: 0 35 | 36 | 37 | MODEL: 38 | name: graph_unet 39 | resblock_type: basic 40 | 41 | depth: 8 42 | full_depth: 3 43 | depth_out: 8 44 | channel: 4 45 | nout: 4 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /configs/shapenet_eval.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: evaluate 11 | logdir: logs/shapenet_eval/test 12 | ckpt: logs/shapenet/shapenet/checkpoints/00300.model.pth 13 | sdf_scale: 0.9 14 | resolution: 128 15 | 16 | 17 | DATA: 18 | test: 19 | name: pointcloud_eval 20 | 21 | # octree building 22 | depth: 6 23 | offset: 0.0 24 | full_depth: 3 25 | node_dis: True 26 | split_label: True 27 | 28 | # data loading 29 | # location: data/ShapeNet/dataset # the original testing data 30 | location: data/ShapeNet/test.input # the generated testing data 31 | filelist: data/ShapeNet/filelist/test.txt 32 | batch_size: 1 33 | shuffle: False 34 | in_memory: False 35 | # num_workers: 0 36 | 37 | 38 | MODEL: 39 | name: graph_ounet 40 | 41 | channel: 5 42 | depth: 6 43 | nout: 4 44 | depth_out: 6 45 | full_depth: 3 46 | bottleneck: 4 47 | 48 | resblock_type: basic 49 | -------------------------------------------------------------------------------- /configs/shapenet_ae_eval.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The 
MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: evaluate 11 | logdir: logs/shapenet_eval/ae 12 | ckpt: logs/shapenet/ae/checkpoints/00300.model.pth 13 | sdf_scale: 0.9 14 | resolution: 160 15 | 16 | 17 | DATA: 18 | test: 19 | name: shapenet 20 | 21 | # octree building 22 | depth: 6 23 | offset: 0.0 24 | full_depth: 2 25 | node_dis: True 26 | split_label: True 27 | 28 | # no data augmentation 29 | distort: False 30 | 31 | # data loading 32 | location: data/ShapeNet/dataset 33 | filelist: data/ShapeNet/filelist/test_im.txt 34 | batch_size: 1 35 | shuffle: False 36 | load_sdf: False 37 | # num_workers: 0 38 | 39 | 40 | MODEL: 41 | name: graph_ae 42 | 43 | channel: 4 44 | depth: 6 45 | nout: 4 46 | depth_out: 6 47 | full_depth: 2 48 | bottleneck: 4 49 | resblock_type: basic 50 | 51 | LOSS: 52 | name: shapenet 53 | loss_type: sdf_reg_loss 54 | -------------------------------------------------------------------------------- /configs/shapenet_unseen5.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: evaluate 11 | logdir: logs/shapenet_eval/unseen5 12 | ckpt: logs/shapenet/shapenet/checkpoints/00300.model.pth 13 | sdf_scale: 0.9 14 | resolution: 128 15 | 16 | 17 | DATA: 18 | test: 19 | name: shapenet 20 | 21 | # octree building 22 | depth: 6 23 | offset: 0.0 24 | full_depth: 3 25 | node_dis: True 26 | split_label: True 27 | 28 | # data augmentation, add noise only 29 | distort: True 30 | 31 | # data loading 32 | location: data/ShapeNet/dataset.unseen5 # the original testing data 33 | filelist: data/ShapeNet/filelist/test_unseen5.txt 34 | batch_size: 1 35 | load_sdf: False 36 | shuffle: False 37 | in_memory: False 38 | # num_workers: 0 39 | 40 | 41 | MODEL: 42 | name: graph_ounet 43 | 44 | channel: 5 45 | depth: 6 46 | nout: 4 47 | depth_out: 6 48 | full_depth: 3 49 | bottleneck: 4 50 | 51 | resblock_type: basic 52 | -------------------------------------------------------------------------------- /configs/synthetic_room_eval.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: evaluate 11 | logdir: logs/room_eval/room 12 | ckpt: logs/room/room/checkpoints/00900.model.pth 13 | sdf_scale: 0.9 14 | resolution: 280 15 | 16 | DATA: 17 | test: 18 | name: pointcloud_eval 19 | point_scale: 0.6 20 | 21 | # octree building 22 | depth: 7 23 | offset: 0.0 24 | full_depth: 3 25 | node_dis: True 26 | split_label: True 27 | 28 | # data augmentation, add noise only 29 | # distort: True 30 | 31 | # data loading 32 | # location: data/room/synthetic_room_dataset 33 | location: data/room/test.input # the generated testing data 34 | filelist: data/room/filelist/test.txt 35 | batch_size: 1 36 | shuffle: False 37 | in_memory: False 38 | # num_workers: 0 39 | 40 | 41 | MODEL: 42 | name: 
graph_ounet 43 | 44 | channel: 5 45 | depth: 7 46 | nout: 4 47 | depth_out: 7 48 | full_depth: 3 49 | bottleneck: 4 50 | resblock_type: basic 51 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /configs/finetune.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: train 11 | 12 | logdir: logs/shapes/finetune 13 | max_epoch: 6000 14 | test_every_epoch: 100 15 | log_per_iter: 50 16 | ckpt_num: 200 17 | ckpt: '' 18 | 19 | # optimizer 20 | type: adamw 21 | weight_decay: 0.01 # default value of adamw 22 | lr: 0.0001 23 | 24 | # learning rate 25 | lr_type: constant 26 | 27 | 28 | DATA: 29 | train: &id001 30 | name: singlepointcloud 31 | point_scale: 0.9 32 | point_sample_num: 100000 33 | 34 | # octree building 35 | depth: 8 36 | full_depth: 3 37 | node_dis: True 38 | split_label: True 39 | offset: 0.0 40 | 41 | # data loading 42 | location: data/Shapes 43 | filelist: data/Shapes/filelist/lucy.txt 44 | batch_size: 1 45 | # num_workers: 0 46 | 47 | test: *id001 48 | 49 | 50 | MODEL: 51 | name: graph_unet 52 | resblock_type: basic 53 | find_unused_parameters: True 54 | 55 | depth: 8 56 | full_depth: 3 57 | depth_out: 8 58 | channel: 4 59 | nout: 4 60 | 61 | LOSS: 62 | name: dfaust 63 | loss_type: possion_grad_loss 64 | -------------------------------------------------------------------------------- /configs/cls_m40.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see 
LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: train 11 | type: sgd 12 | 13 | logdir: logs/m40/m40 14 | max_epoch: 300 15 | test_every_epoch: 5 16 | 17 | # lr: 0.001 # default value of adamw 18 | # weight_decay: 0.01 # default value of adamw 19 | step_size: (120,180,240) 20 | ckpt_num: 20 21 | 22 | DATA: 23 | train: 24 | # octree building 25 | depth: 5 26 | offset: 0.016 27 | 28 | # data augmentations 29 | distort: True 30 | angle: (0, 0, 5) # small rotation along z axis 31 | interval: (1, 1, 1) 32 | scale: 0.25 33 | jitter: 0.125 34 | 35 | # data loading 36 | location: data/ModelNet40/ModelNet40.points 37 | filelist: data/ModelNet40/m40_train_points_list.txt 38 | batch_size: 32 39 | shuffle: True 40 | # num_workers: 0 41 | 42 | test: 43 | # octree building 44 | depth: 5 45 | offset: 0.016 46 | 47 | # data augmentations 48 | distort: False 49 | angle: (0, 0, 5) # small rotation along z axis 50 | interval: (1, 1, 1) 51 | scale: 0.25 52 | jitter: 0.125 53 | 54 | # data loading 55 | location: data/ModelNet40/ModelNet40.points 56 | filelist: data/ModelNet40/m40_test_points_list.txt 57 | batch_size: 32 58 | shuffle: False 59 | # num_workers: 0 60 | 61 | MODEL: 62 | name: lenet 63 | channel: 3 64 | nout: 40 65 | depth: 5 66 | 67 | LOSS: 68 | num_class: 40 -------------------------------------------------------------------------------- /configs/dfaust.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: train 11 | 12 | logdir: logs/dfaust/dfaust 13 | max_epoch: 600 14 | test_every_epoch: 10 15 | log_per_iter: 50 16 | ckpt_num: 200 17 | 18 | # optimizer 19 | type: adamw 20 | weight_decay: 0.01 # default value of adamw 21 | lr: 0.001 # default value of adamw 22 | 23 | # learning rate 24 | lr_type: poly 25 | step_size: (200,300) 26 | 27 | 28 | DATA: 29 | train: 30 | name: pointcloud 31 | point_scale: 1.0 32 | 33 | # octree building 34 | depth: 8 35 | full_depth: 3 36 | node_dis: True 37 | split_label: True 38 | offset: 0.0 39 | 40 | # data loading 41 | location: data/dfaust/dataset 42 | filelist: data/dfaust/filelist/train.txt 43 | batch_size: 16 44 | # num_workers: 0 45 | 46 | test: 47 | name: pointcloud 48 | point_scale: 1.0 49 | 50 | # octree building 51 | depth: 8 52 | full_depth: 3 53 | node_dis: True 54 | split_label: True 55 | offset: 0.0 56 | 57 | # data loading 58 | location: data/dfaust/dataset 59 | filelist: data/dfaust/filelist/test.txt 60 | batch_size: 4 61 | # num_workers: 0 62 | 63 | 64 | MODEL: 65 | name: graph_unet 66 | resblock_type: basic 67 | find_unused_parameters: True 68 | 69 | depth: 8 70 | full_depth: 3 71 | depth_out: 8 72 | channel: 4 73 | nout: 4 74 | 75 | LOSS: 76 | name: dfaust 77 | loss_type: possion_grad_loss 78 | -------------------------------------------------------------------------------- /solver/sampler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by 
Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from torch.utils.data import Sampler, DistributedSampler, Dataset 10 | 11 | 12 | class InfSampler(Sampler): 13 | def __init__(self, dataset: Dataset, shuffle: bool = True) -> None: 14 | self.dataset = dataset 15 | self.shuffle = shuffle 16 | self.reset_sampler() 17 | 18 | def reset_sampler(self): 19 | num = len(self.dataset) 20 | indices = torch.randperm(num) if self.shuffle else torch.arange(num) 21 | self.indices = indices.tolist() 22 | self.iter_num = 0 23 | 24 | def __iter__(self): 25 | return self 26 | 27 | def __next__(self): 28 | value = self.indices[self.iter_num] 29 | self.iter_num = self.iter_num + 1 30 | 31 | if self.iter_num >= len(self.indices): 32 | self.reset_sampler() 33 | return value 34 | 35 | def __len__(self): 36 | return len(self.dataset) 37 | 38 | 39 | class DistributedInfSampler(DistributedSampler): 40 | def __init__(self, dataset: Dataset, shuffle: bool = True) -> None: 41 | super().__init__(dataset, shuffle=shuffle) 42 | self.reset_sampler() 43 | 44 | def reset_sampler(self): 45 | self.indices = list(super().__iter__()) 46 | self.iter_num = 0 47 | 48 | def __iter__(self): 49 | return self 50 | 51 | def __next__(self): 52 | value = self.indices[self.iter_num] 53 | self.iter_num = self.iter_num + 1 54 | 55 | if self.iter_num >= len(self.indices): 56 | self.reset_sampler() 57 | return value 58 | -------------------------------------------------------------------------------- /datasets/pointcloud_eval.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import ocnn 10 | import torch 11 | import numpy as np 12 | from plyfile import PlyData 13 | 14 | from solver import Dataset 15 | 16 | 17 | class Transform: 18 | r"""Load point clouds from ply files, rescale the points and build octree. 
19 | Used to evaluate the network trained on ShapeNet.""" 20 | 21 | def __init__(self, flags): 22 | self.flags = flags 23 | 24 | self.point_scale = flags.point_scale 25 | self.points2octree = ocnn.Points2Octree(**flags) 26 | 27 | def __call__(self, points, idx): 28 | # After normalization, the points are in [-1, 1] 29 | pts = points[:, :3] / self.point_scale 30 | 31 | # construct the points 32 | ones = torch.ones(pts.shape[0], dtype=torch.float32) 33 | points = ocnn.points_new(pts, torch.Tensor(), ones, torch.Tensor()) 34 | points, _ = ocnn.clip_points(points, [-1.0]*3, [1.0]*3) 35 | 36 | # transform points to octree 37 | octree = self.points2octree(points) 38 | 39 | return {'points_in': points, 'octree_in': octree} 40 | 41 | 42 | def read_file(filename: str): 43 | plydata = PlyData.read(filename + '.ply') 44 | vtx = plydata['vertex'] 45 | points = np.stack([vtx['x'], vtx['y'], vtx['z']], axis=1).astype(np.float32) 46 | output = torch.from_numpy(points.astype(np.float32)) 47 | return output 48 | 49 | 50 | def get_pointcloud_eval_dataset(flags): 51 | transform = Transform(flags) 52 | dataset = Dataset(flags.location, flags.filelist, transform, 53 | read_file=read_file, in_memory=flags.in_memory) 54 | return dataset, ocnn.collate_octrees 55 | -------------------------------------------------------------------------------- /configs/shapenet_ae.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: train 11 | 12 | logdir: logs/shapenet/ae 13 | max_epoch: 300 14 | test_every_epoch: 20 15 | log_per_iter: 50 16 | ckpt_num: 40 17 | 18 | # optimizer 19 | type: adamw 20 | weight_decay: 0.01 # default value of adamw 21 | lr: 0.001 # default value of adamw 22 | 23 | # learning rate 24 | lr_type: poly 25 | step_size: (160,240) 26 | 27 | DATA: 28 | train: 29 | name: shapenet 30 | 31 | # octree building 32 | depth: 6 33 | offset: 0.0 34 | full_depth: 2 35 | node_dis: True 36 | split_label: True 37 | 38 | # no data augmentation 39 | distort: False 40 | 41 | # data loading 42 | location: data/ShapeNet/dataset 43 | filelist: data/ShapeNet/filelist/train_im.txt 44 | load_sdf: True 45 | batch_size: 16 46 | shuffle: True 47 | # num_workers: 0 48 | 49 | test: 50 | name: shapenet 51 | 52 | # octree building 53 | depth: 6 54 | offset: 0.0 55 | full_depth: 2 56 | node_dis: True 57 | split_label: True 58 | 59 | # no data augmentation 60 | distort: False 61 | 62 | # data loading 63 | location: data/ShapeNet/dataset 64 | filelist: data/ShapeNet/filelist/val_im.txt 65 | batch_size: 4 66 | load_sdf: True 67 | shuffle: False 68 | # num_workers: 0 69 | 70 | 71 | MODEL: 72 | name: graph_ae 73 | 74 | channel: 4 75 | depth: 6 76 | nout: 4 77 | depth_out: 6 78 | full_depth: 2 79 | bottleneck: 4 80 | resblock_type: basic 81 | 82 | LOSS: 83 | name: shapenet 84 | loss_type: sdf_reg_loss 85 | -------------------------------------------------------------------------------- /configs/shapenet.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # 
Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: train 11 | 12 | logdir: logs/shapenet/shapenet 13 | max_epoch: 300 14 | test_every_epoch: 20 15 | log_per_iter: 50 16 | ckpt_num: 40 17 | 18 | # optimizer 19 | type: adamw 20 | weight_decay: 0.01 # default value of adamw 21 | lr: 0.001 # default value of adamw 22 | 23 | # learning rate 24 | lr_type: poly 25 | step_size: (160,240) 26 | 27 | DATA: 28 | train: 29 | name: shapenet 30 | 31 | # octree building 32 | depth: 6 33 | offset: 0.0 34 | full_depth: 3 35 | node_dis: True 36 | split_label: True 37 | 38 | # data augmentation, add noise only 39 | distort: True 40 | 41 | # data loading 42 | location: data/ShapeNet/dataset 43 | filelist: data/ShapeNet/filelist/train.txt 44 | load_sdf: True 45 | batch_size: 16 46 | shuffle: True 47 | # num_workers: 0 48 | 49 | test: 50 | name: shapenet 51 | 52 | # octree building 53 | depth: 6 54 | offset: 0.0 55 | full_depth: 3 56 | node_dis: True 57 | split_label: True 58 | 59 | # data augmentation, add noise only 60 | distort: True 61 | 62 | # data loading 63 | location: data/ShapeNet/dataset 64 | filelist: data/ShapeNet/filelist/val.txt 65 | batch_size: 8 66 | load_sdf: True 67 | shuffle: False 68 | # num_workers: 0 69 | 70 | 71 | MODEL: 72 | name: graph_ounet 73 | 74 | channel: 5 75 | depth: 6 76 | nout: 4 77 | depth_out: 6 78 | full_depth: 3 79 | bottleneck: 4 80 | resblock_type: basic 81 | 82 | LOSS: 83 | name: shapenet 84 | loss_type: sdf_reg_loss 85 | -------------------------------------------------------------------------------- /configs/synthetic_room.yaml: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | SOLVER: 9 | gpu: 0, 10 | run: train 11 | 12 | logdir: logs/room/room 13 | max_epoch: 900 14 | test_every_epoch: 10 15 | log_per_iter: 40 16 | ckpt_num: 20 17 | 18 | type: adamw 19 | lr: 0.001 # default value of adamw 20 | weight_decay: 0.01 # default value of adamw 21 | lr_type: poly 22 | step_size: (80,120) 23 | 24 | 25 | DATA: 26 | train: 27 | name: synthetic_room 28 | 29 | # octree building 30 | depth: 7 31 | offset: 0.0 32 | full_depth: 3 33 | node_dis: True 34 | split_label: True 35 | 36 | # data augmentation, add noise only 37 | distort: True 38 | 39 | # data loading 40 | location: data/room/synthetic_room_dataset 41 | filelist: data/room/filelist/train.txt 42 | load_occu: True 43 | sample_surf_points: True 44 | batch_size: 16 45 | shuffle: True 46 | # num_workers: 0 47 | 48 | test: 49 | name: synthetic_room 50 | 51 | 52 | # octree building 53 | depth: 7 54 | offset: 0.0 55 | full_depth: 3 56 | node_dis: True 57 | split_label: True 58 | 59 | # data augmentation, add noise only 60 | distort: True 61 | 62 | # data loading 63 | location: data/room/synthetic_room_dataset 64 | filelist: data/room/filelist/val.txt 65 | load_occu: True 66 | sample_surf_points: True 67 | batch_size: 8 68 | shuffle: False 69 | # num_workers: 0 70 | 71 | 72 | MODEL: 73 | name: graph_ounet 74 | 75 | channel: 5 76 | depth: 7 77 | nout: 4 78 | depth_out: 7 79 | full_depth: 3 80 | bottleneck: 4 81 | resblock_type: basic 82 | 83 | LOSS: 84 | name: synthetic_room 85 | loss_type: possion_grad_loss 86 | 
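
All of the yaml files above share the SOLVER/DATA/MODEL/LOSS layout parsed by `solver/config.py`. A hedged sketch of loading such a file with yacs (the config library pinned in requirements.txt) — the repo's real loader is `get_config`/`parse_args` in `solver/config.py`, which also defines the default values this sketch omits:

```python
# Standalone sketch using only the public yacs API; it does not reproduce
# the defaults defined in solver/config.py.
from yacs.config import CfgNode as CN

# new_allowed=True lets the yaml introduce keys this sketch did not declare.
cfg = CN(new_allowed=True)
cfg.merge_from_file('configs/synthetic_room.yaml')
cfg.freeze()

print(cfg.SOLVER.logdir)     # logs/room/room
print(cfg.DATA.train.depth)  # 7
print(cfg.MODEL.name)        # graph_ounet
```

In the actual solver, command-line overrides (e.g. `SOLVER.logdir logs/tmp`) are presumably merged on top with `merge_from_list` before the config is frozen.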
--------------------------------------------------------------------------------
/models/utils/spmm.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 | 
8 | import torch
9 | from .scatter import scatter_add
10 | 
11 | 
12 | def spmm(index, value, m, n, matrix):
13 |   """Matrix product of sparse matrix with dense matrix.
14 | 
15 |   Args:
16 |     index (:class:`LongTensor`): The index tensor of sparse matrix.
17 |     value (:class:`Tensor`): The value tensor of sparse matrix.
18 |     m (int): The first dimension of the sparse matrix.
19 |     n (int): The second dimension of the sparse matrix (= matrix.size(-2)).
20 |     matrix (:class:`Tensor`): The dense matrix.
21 | 
22 |   :rtype: :class:`Tensor`
23 |   """
24 | 
25 |   assert n == matrix.size(-2)
26 | 
27 |   row, col = index
28 |   matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
29 | 
30 |   out = matrix.index_select(-2, col)
31 |   out = out * value.unsqueeze(-1)
32 |   out = scatter_add(out, row, dim=-2, dim_size=m)
33 | 
34 |   return out
35 | 
36 | 
37 | def modulated_spmm(index, value, m, n, matrix, xyzf):
38 |   """Matrix product of sparse matrix with dense matrix, modulated by `xyzf`.
39 | 
40 |   Args:
41 |     index (:class:`LongTensor`): The index tensor of sparse matrix.
42 |     value (:class:`Tensor`): The value tensor of sparse matrix.
43 |     m (int): The first dimension of the sparse matrix.
44 |     n (int): The second dimension of the sparse matrix (= matrix.size(-2)).
45 |     matrix (:class:`Tensor`): The dense matrix with C channels.
46 |     xyzf (:class:`Tensor`): The (E, C-1) modulation tensor; each gathered row of `matrix` is contracted with `[xyzf, 1]` into a scalar before being weighted by `value`.
47 |   :rtype: :class:`Tensor`
48 |   """
49 | 
50 |   assert n == matrix.size(-2)
51 | 
52 |   row, col = index
53 |   matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
54 | 
55 |   out = matrix.index_select(-2, col)
56 |   ones = torch.ones((xyzf.shape[0], 1), device=xyzf.device)
57 |   out = torch.sum(out * torch.cat([xyzf, ones], dim=1), dim=1, keepdim=True)
58 |   out = out * value.unsqueeze(-1)
59 |   out = scatter_add(out, row, dim=-2, dim_size=m)
60 | 
61 |   return out
--------------------------------------------------------------------------------
/solver/dataset.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 | 
8 | import os
9 | import torch
10 | import torch.utils.data
11 | import numpy as np
12 | from tqdm import tqdm
13 | 
14 | 
15 | def read_file(filename):
16 |   points = np.fromfile(filename, dtype=np.uint8)
17 |   return torch.from_numpy(points)  # convert it to torch.tensor
18 | 
19 | 
20 | class Dataset(torch.utils.data.Dataset):
21 | 
22 |   def __init__(self, root, filelist, transform, read_file=read_file,
23 |                in_memory=False, take: int = -1):
24 |     super(Dataset, self).__init__()
25 |     self.root = root
26 |     self.filelist = filelist
27 |     self.transform = transform
28 |     self.in_memory = in_memory
29 |     self.read_file = read_file
30 |     self.take = take
31 | 
32 |     self.filenames, self.labels = self.load_filenames()
33 |     if self.in_memory:
34 |       print('Load files into memory from ' + self.filelist)
35 |       self.samples = 
[self.read_file(os.path.join(self.root, f)) 36 | for f in tqdm(self.filenames, ncols=80, leave=False)] 37 | 38 | def __len__(self): 39 | return len(self.filenames) 40 | 41 | def __getitem__(self, idx): 42 | sample = self.samples[idx] if self.in_memory else \ 43 | self.read_file(os.path.join(self.root, self.filenames[idx])) # noqa 44 | output = self.transform(sample, idx) # data augmentation + build octree 45 | output['label'] = self.labels[idx] 46 | output['filename'] = self.filenames[idx] 47 | return output 48 | 49 | def load_filenames(self): 50 | filenames, labels = [], [] 51 | with open(self.filelist) as fid: 52 | lines = fid.readlines() 53 | for line in lines: 54 | tokens = line.split() 55 | filename = tokens[0] 56 | label = tokens[1] if len(tokens) == 2 else 0 57 | filenames.append(filename) 58 | labels.append(int(label)) 59 | 60 | num = len(filenames) 61 | if self.take > num or self.take < 1: 62 | self.take = num 63 | 64 | return filenames[:self.take], labels[:self.take] 65 | -------------------------------------------------------------------------------- /models/utils/scatter.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | from typing import Optional 10 | 11 | 12 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): 13 | if dim < 0: 14 | dim = other.dim() + dim 15 | if src.dim() == 1: 16 | for _ in range(0, dim): 17 | src = src.unsqueeze(0) 18 | for _ in range(src.dim(), other.dim()): 19 | src = src.unsqueeze(-1) 20 | src = src.expand_as(other) 21 | return src 22 | 23 | 24 | def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 25 | out: Optional[torch.Tensor] = None, 26 | dim_size: Optional[int] = None) -> torch.Tensor: 27 | index = broadcast(index, src, dim) 28 | if out is None: 29 | size = list(src.size()) 30 | if dim_size is not None: 31 | size[dim] = dim_size 32 | elif index.numel() == 0: 33 | size[dim] = 0 34 | else: 35 | size[dim] = int(index.max()) + 1 36 | out = torch.zeros(size, dtype=src.dtype, device=src.device) 37 | return out.scatter_add_(dim, index, src) 38 | else: 39 | return out.scatter_add_(dim, index, src) 40 | 41 | 42 | def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, 43 | weights: Optional[torch.Tensor] = None, 44 | out: Optional[torch.Tensor] = None, 45 | dim_size: Optional[int] = None) -> torch.Tensor: 46 | if weights is not None: 47 | src = src * broadcast(weights, src, dim) 48 | out = scatter_add(src, index, dim, out, dim_size) 49 | dim_size = out.size(dim) 50 | 51 | index_dim = dim 52 | if index_dim < 0: 53 | index_dim = index_dim + src.dim() 54 | if index.dim() <= index_dim: 55 | index_dim = index.dim() - 1 56 | 57 | if weights is None: 58 | weights = torch.ones(index.size(), dtype=src.dtype, device=src.device) 59 | count = scatter_add(weights, index, index_dim, None, dim_size) 60 | count[count < 1] = 1 61 | count = broadcast(count, out, dim) 62 | if out.is_floating_point(): 63 | out.true_divide_(count) 64 | else: 65 | out.div_(count, rounding_mode='floor') 66 | return out 67 | -------------------------------------------------------------------------------- /models/graph_resnet.py: 
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 | 
8 | import torch
9 | import ocnn
10 | 
11 | from .modules import GraphConvBnRelu, GraphResBlocks
12 | from .dual_octree import DualOctree
13 | 
14 | 
15 | class GraphResNet(torch.nn.Module):
16 | 
17 |   def __init__(self, depth, channel_in, nout, resblk_num):
18 |     super().__init__()
19 |     self.depth, self.channel_in = depth, channel_in
20 |     channels = [2 ** max(11 - i, 2) for i in range(depth + 1)]
21 |     channels.append(channels[depth])
22 |     n_edge_type, avg_degree, bottleneck = 7, 7, 4
23 | 
24 |     self.conv1 = GraphConvBnRelu(
25 |         channel_in, channels[depth], n_edge_type, avg_degree)
26 |     self.resblocks = torch.nn.ModuleList(
27 |         [GraphResBlocks(channels[d + 1], channels[d], resblk_num, bottleneck,
28 |                         n_edge_type, avg_degree) for d in range(depth, 2, -1)])
29 |     self.pools = torch.nn.ModuleList(
30 |         [ocnn.OctreeMaxPool(d) for d in range(depth, 2, -1)])
31 |     self.header = torch.nn.Sequential(
32 |         ocnn.FullOctreeGlobalPool(depth=2),  # global pool
33 |         # torch.nn.Dropout(p=0.5),  # drop
34 |         torch.nn.Linear(channels[3], nout))  # fc
35 | 
36 |   def forward(self, octree):
37 |     # Get the initial feature
38 |     data = ocnn.octree_property(octree, 'feature', self.depth)
39 |     assert data.size(1) == self.channel_in
40 | 
41 |     # build the dual octree
42 |     doctree = DualOctree(octree)
43 |     doctree.post_processing_for_ocnn()
44 | 
45 |     # forward the network
46 |     for i, d in enumerate(range(self.depth, 2, -1)):
47 |       # perform graph conv
48 |       data = data.squeeze().t()
49 |       edge_idx = doctree.graph[d]['edge_idx']
50 |       edge_type = doctree.graph[d]['edge_type']
51 |       if d == self.depth:  # the first conv
52 |         data = self.conv1(data, edge_idx, edge_type)
53 |       data = self.resblocks[i](data, edge_idx, edge_type)
54 | 
55 |       # downsampling
56 |       data = data.t().unsqueeze(0).unsqueeze(-1)
57 |       data = self.pools[i](data, octree)
58 | 
59 |     # classification head
60 |     data = self.header(data)
61 |     return data
62 | 
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Security
4 | 
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 | 
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 | 
9 | ## Reporting Security Issues
10 | 
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 | 
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 
14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /models/graph_unet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn 10 | 11 | from . import dual_octree 12 | from . 
import graph_ounet 13 | 14 | 15 | class GraphUNet(graph_ounet.GraphOUNet): 16 | 17 | def _setup_channels_and_resblks(self): 18 | # self.resblk_num = [3] * 7 + [1] + [1] * 9 19 | self.resblk_num = [2] * 16 20 | self.channels = [4, 512, 512, 256, 128, 64, 32, 32, 32] 21 | 22 | def recons_decoder(self, convs, doctree_out): 23 | logits = dict() 24 | reg_voxs = dict() 25 | deconvs = dict() 26 | 27 | deconvs[self.full_depth] = convs[self.full_depth] 28 | for i, d in enumerate(range(self.full_depth, self.depth_out+1)): 29 | if d > self.full_depth: 30 | nnum = doctree_out.nnum[d-1] 31 | leaf_mask = doctree_out.node_child(d-1) < 0 32 | deconvd = self.upsample[i-1](deconvs[d-1], leaf_mask, nnum) 33 | deconvd = deconvd + convs[d] # skip connections 34 | 35 | edge_idx = doctree_out.graph[d]['edge_idx'] 36 | edge_type = doctree_out.graph[d]['edge_dir'] 37 | node_type = doctree_out.graph[d]['node_type'] 38 | deconvs[d] = self.decoder[i-1](deconvd, edge_idx, edge_type, node_type) 39 | 40 | # predict the splitting label 41 | logit = self.predict[i](deconvs[d]) 42 | nnum = doctree_out.nnum[d] 43 | logits[d] = logit[-nnum:] 44 | 45 | # predict the signal 46 | reg_vox = self.regress[i](deconvs[d]) 47 | 48 | # TODO: improve it 49 | # pad zeros to reg_vox to reuse the original code for ocnn 50 | node_mask = doctree_out.graph[d]['node_mask'] 51 | shape = (node_mask.shape[0], reg_vox.shape[1]) 52 | reg_vox_pad = torch.zeros(shape, device=reg_vox.device) 53 | reg_vox_pad[node_mask] = reg_vox 54 | reg_voxs[d] = reg_vox_pad 55 | 56 | return logits, reg_voxs 57 | 58 | def forward(self, octree_in, octree_out=None, pos=None): 59 | # octree_in and octree_out are the same for UNet 60 | doctree_in = dual_octree.DualOctree(octree_in) 61 | doctree_in.post_processing_for_docnn() 62 | 63 | # run encoder and decoder 64 | convs = self.octree_encoder(octree_in, doctree_in) 65 | out = self.recons_decoder(convs, doctree_in) 66 | output = {'reg_voxs': out[1], 'octree_out': octree_in} 67 | 68 | # compute function value with mpu 69 | if pos is not None: 70 | output['mpus'] = self.neural_mpu(pos, out[1], octree_in) 71 | 72 | # create the mpu wrapper 73 | def _neural_mpu(pos): 74 | pred = self.neural_mpu(pos, out[1], octree_in) 75 | return pred[self.depth_out][0] 76 | output['neural_mpu'] = _neural_mpu 77 | 78 | return output 79 | -------------------------------------------------------------------------------- /models/octree_ounet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import ocnn 9 | import torch 10 | import torch.nn 11 | 12 | from . 
import mpu 13 | 14 | 15 | class OctreeOUNet(ocnn.OUNet): 16 | 17 | def __init__(self, depth, channel_in, nout, full_depth=2, depth_out=6): 18 | super().__init__(depth, channel_in, nout, full_depth) 19 | 20 | self.header = None 21 | self.depth_out = depth_out 22 | 23 | self.neural_mpu = mpu.NeuralMPU(self.full_depth, self.depth_out) 24 | self.regress = torch.nn.ModuleList( 25 | [self._make_predict_module(self.channels[d], 4) 26 | for d in range(full_depth, depth + 1)]) 27 | 28 | def ocnn_decoder(self, convs, octree_out, octree, update_octree=False): 29 | logits = dict() 30 | reg_voxs = dict() 31 | deconvs = dict() 32 | reg_voxs_list = [] 33 | 34 | deconvs[self.full_depth] = convs[self.full_depth] 35 | for i, d in enumerate(range(self.full_depth, self.depth_out + 1)): 36 | if d > self.full_depth: 37 | deconvd = self.upsample[i - 1](deconvs[d - 1], octree_out) 38 | skip, _ = ocnn.octree_align(convs[d], octree, octree_out, d) 39 | deconvd = deconvd + skip 40 | deconvs[d] = self.decoder[i - 1](deconvd, octree_out) 41 | 42 | # predict the splitting label 43 | logit = self.predict[i](deconvs[d]) 44 | logit = logit.squeeze().t() # (1, C, H, 1) -> (H, C) 45 | logits[d] = logit 46 | 47 | # update the octree according to predicted labels 48 | if update_octree: 49 | label = logits[d].argmax(1).to(torch.int32) 50 | octree_out = ocnn.octree_update(octree_out, label, d, split=1) 51 | if d < self.depth_out: 52 | octree_out = ocnn.octree_grow(octree_out, target_depth=d + 1) 53 | 54 | # predict the signal 55 | reg_vox = self.regress[i](deconvs[d]) 56 | reg_vox = reg_vox.squeeze().t() # (1, C, H, 1) -> (H, C) 57 | reg_voxs_list.append(reg_vox) 58 | reg_voxs[d] = torch.cat(reg_voxs_list, dim=0) 59 | 60 | return logits, reg_voxs, octree_out 61 | 62 | def forward(self, octree_in, octree_out=None, pos=None): 63 | update_octree = octree_out is None 64 | if update_octree: 65 | octree_out = ocnn.create_full_octree(self.full_depth, self.nout) 66 | 67 | # run encoder and decoder 68 | convs = self.ocnn_encoder(octree_in) 69 | out = self.ocnn_decoder(convs, octree_out, octree_in, update_octree) 70 | output = {'logits': out[0], 'reg_voxs': out[1], 'octree_out': out[2]} 71 | 72 | # compute function value with mpu 73 | if pos is not None: 74 | output['mpus'] = self.neural_mpu(pos, out[1], out[2]) 75 | 76 | # create the mpu wrapper 77 | def _neural_mpu(pos): 78 | pred = self.neural_mpu(pos, out[1], out[2]) 79 | return pred[self.depth_out][0] 80 | output['neural_mpu'] = _neural_mpu 81 | 82 | return output 83 | -------------------------------------------------------------------------------- /tools/compute_metrics.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import argparse 10 | import trimesh.sample 11 | import numpy as np 12 | from tqdm import tqdm 13 | from scipy.spatial import cKDTree 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--mesh_folder', type=str, required=True) 17 | parser.add_argument('--filename_out', type=str, required=True) 18 | parser.add_argument('--num_samples', type=int, default=30000) 19 | parser.add_argument('--ref_folder', type=str, default='data/dfaust/dfaust/mesh_gt') 20 | parser.add_argument('--filelist', type=str, 
default='data/dfaust/test_all.txt')
21 | args = parser.parse_args()
22 | 
23 | 
24 | with open(args.filelist, 'r') as fid:
25 |   lines = fid.readlines()
26 | filenames = [line.strip() for line in lines]
27 | 
28 | 
29 | def compute_metrics(filename_ref, filename_pred, num_samples=30000):
30 |   mesh_ref = trimesh.load(filename_ref)
31 |   points_ref, idx_ref = trimesh.sample.sample_surface(mesh_ref, num_samples)
32 |   normals_ref = mesh_ref.face_normals[idx_ref]
33 |   # points_ref, normals_ref = read_ply(filename_ref)
34 | 
35 |   mesh_pred = trimesh.load(filename_pred)
36 |   points_pred, idx_pred = trimesh.sample.sample_surface(mesh_pred, num_samples)
37 |   normals_pred = mesh_pred.face_normals[idx_pred]
38 | 
39 |   kdtree_a = cKDTree(points_ref)
40 |   dist_a, idx_a = kdtree_a.query(points_pred)
41 |   chamfer_a = np.mean(dist_a)
42 |   dot_a = np.sum(normals_pred * normals_ref[idx_a], axis=1)
43 |   angle_a = np.mean(np.arccos(np.clip(dot_a, -1.0, 1.0)) * (180.0 / np.pi))  # clip keeps arccos NaN-free
44 |   consist_a = np.mean(np.abs(dot_a))
45 | 
46 |   kdtree_b = cKDTree(points_pred)
47 |   dist_b, idx_b = kdtree_b.query(points_ref)
48 |   chamfer_b = np.mean(dist_b)
49 |   dot_b = np.sum(normals_ref * normals_pred[idx_b], axis=1)
50 |   angle_b = np.mean(np.arccos(np.clip(dot_b, -1.0, 1.0)) * (180.0 / np.pi))  # clip keeps arccos NaN-free
51 |   consist_b = np.mean(np.abs(dot_b))
52 | 
53 |   return chamfer_a, chamfer_b, angle_a, angle_b, consist_a, consist_b
54 | 
55 | 
56 | counter = 0
57 | fid = open(args.filename_out, 'w')
58 | fid.write(('name, '
59 |            'chamfer_a, chamfer_b, chamfer, '
60 |            'angle_a, angle_b, angle, '
61 |            'consist_a, consist_b, normal consistency\n'))
62 | for filename in tqdm(filenames, ncols=80):
63 |   if filename.endswith('.npy'):
64 |     filename = filename[:-4]
65 |   filename_ref = os.path.join(args.ref_folder, filename + '.obj')
66 |   # filename_ref = os.path.join(args.ref_folder, filename + '.ply')
67 |   filename_pred = os.path.join(args.mesh_folder, filename + '.obj')
68 |   metrics = compute_metrics(filename_ref, filename_pred, args.num_samples)
69 | 
70 |   chamfer_a, chamfer_b = metrics[0], metrics[1]
71 |   angle_a, angle_b = metrics[2], metrics[3]
72 |   consist_a, consist_b = metrics[4], metrics[5]
73 | 
74 |   msg = '{}, {}, {}, {}, {}, {}, {}, {}, {}, {}\n'.format(
75 |       filename,
76 |       chamfer_a, chamfer_b, 0.5 * (chamfer_a + chamfer_b),
77 |       angle_a, angle_b, 0.5 * (angle_a + angle_b),
78 |       consist_a, consist_b, 0.5 * (consist_a + consist_b))
79 |   fid.write(msg)
80 |   tqdm.write(msg)
81 | 
82 | fid.close()
--------------------------------------------------------------------------------
/datasets/synthetic_room.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 | 
8 | import os
9 | import ocnn
10 | import torch
11 | import numpy as np
12 | 
13 | from solver import Dataset
14 | from .utils import collate_func
15 | from .shapenet import TransformShape
16 | 
17 | 
18 | class TransformScene(TransformShape):
19 | 
20 |   def __init__(self, flags):
21 |     self.flags = flags
22 | 
23 |     self.point_sample_num = 10000
24 |     self.occu_sample_num = 4096
25 |     self.surface_sample_num = 2048
26 |     self.sample_surf_points = flags.sample_surf_points
27 |     self.points_scale = 0.6  # the points are actually in [-0.55, 0.55]
28 |     self.noise_std = 0.005
29 |     self.pos_weight = 10
30 |     self.points2octree = ocnn.Points2Octree(**flags)
31 | 
32 |   def sample_occu(self, sample):
33 |     points, occus = sample['points'], sample['occupancies']
34 |     points = points / self.points_scale
35 |     points = points + 1e-6 * np.random.randn(*points.shape)  # ref ConvONet
36 |     occus = np.unpackbits(occus)[:points.shape[0]]
37 | 
38 |     rand_idx = np.random.choice(points.shape[0], size=self.occu_sample_num)
39 |     points = torch.from_numpy(points[rand_idx]).float()
40 |     occus = torch.from_numpy(occus[rand_idx]).float()
41 |     # in the raw data: 0 - outside shapes; 1 - inside shapes
42 |     occus = 1 - occus  # flip so that 1 - outside, consistent with ShapeNet
43 | 
44 |     # The number of points inside shapes is roughly 1.2% of the points outside
45 |     # shapes, so we set the weights of points inside shapes to 10.
46 |     weight = torch.ones_like(occus)
47 |     weight[occus < 0.5] = self.pos_weight
48 | 
49 |     # The points are not on the surfaces, so the gradients are 0
50 |     grad = torch.zeros_like(points)
51 |     return {'pos': points, 'occu': occus, 'weight': weight, 'grad': grad}
52 | 
53 |   def sample_surface(self, sample):
54 |     # get the input TODO: use normals
55 |     points, normals = sample['points'], sample['normals']
56 | 
57 |     # sample points
58 |     rand_idx = np.random.choice(points.shape[0], size=self.surface_sample_num)
59 |     pos = torch.from_numpy(points[rand_idx])
60 |     grad = torch.from_numpy(normals[rand_idx])
61 |     pos = pos / self.points_scale  # scale to [-1.0, 1.0]
62 |     occus = torch.ones(self.surface_sample_num) * 0.5
63 |     weight = torch.ones(self.surface_sample_num) * 2.0  # TODO: tune this scale
64 | 
65 |     return {'pos': pos, 'occu': occus, 'weight': weight, 'grad': grad}
66 | 
67 |   def __call__(self, sample, idx):
68 |     output = self.process_points_cloud(sample['point_cloud'])
69 | 
70 |     # sample ground truth sdfs
71 |     if self.flags.load_occu:
72 |       occus = self.sample_occu(sample['occus'])
73 | 
74 |       if self.sample_surf_points:
75 |         surface_occus = self.sample_surface(sample['point_cloud'])
76 |         for key in occus.keys():
77 |           occus[key] = torch.cat([occus[key], surface_occus[key]], dim=0)
78 | 
79 |       output.update(occus)
80 | 
81 |     return output
82 | 
83 | 
84 | class ReadFile:
85 |   def __init__(self, load_occu=False):
86 |     self.load_occu = load_occu
87 |     self.num_files = 10
88 | 
89 |   def __call__(self, filename):
90 |     num = np.random.randint(self.num_files)
91 |     filename_pc = os.path.join(
92 |         filename, 'pointcloud/pointcloud_%02d.npz' % num)
93 |     raw = np.load(filename_pc)
94 |     point_cloud = {'points': raw['points'], 'normals': raw['normals']}
95 |     output = {'point_cloud': point_cloud}
96 | 
97 |     if self.load_occu:
98 |       num = np.random.randint(self.num_files)
99 |       filename_occu = os.path.join(
100 |           filename, 'points_iou/points_iou_%02d.npz' % num)
101 |       raw = np.load(filename_occu)
102 |       occus = {'points': raw['points'], 'occupancies': raw['occupancies']}
103 |       output['occus'] = occus
104 | 
105 |     return output
106 | 
107 | 
108 | def get_synthetic_room_dataset(flags):
109 |   transform = TransformScene(flags)
110 |   read_file = ReadFile(flags.load_occu)
111 |   dataset = Dataset(flags.location, flags.filelist, transform,
112 |                     read_file=read_file, in_memory=flags.in_memory)
113 |   return dataset, collate_func
114 | 
--------------------------------------------------------------------------------
/models/graph_lenet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
-------------------------------------------------------------------------------- /models/graph_lenet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import ocnn 10 | 11 | from .modules import GraphConvBnRelu, GraphDownsample, GraphMaxpool 12 | from .dual_octree import DualOctree 13 | 14 | 15 | class GraphLeNet(torch.nn.Module): 16 | '''This network is for comparison with the original O-CNN''' 17 | 18 | def __init__(self, depth, channel_in, nout): 19 | super().__init__() 20 | self.depth, self.channel_in = depth, channel_in 21 | channels = [2 ** max(9 - i, 2) for i in range(depth + 1)] 22 | channels.append(channel_in) 23 | 24 | self.convs = torch.nn.ModuleList([ 25 | GraphConvBnRelu(channels[d + 1], channels[d], n_edge_type=7, 26 | avg_degree=7) for d in range(depth, 2, -1)]) 27 | self.pools = torch.nn.ModuleList([ 28 | ocnn.OctreeMaxPool(d) for d in range(depth, 2, -1)]) 29 | self.octree2voxel = ocnn.FullOctree2Voxel(2) 30 | self.header = torch.nn.Sequential( 31 | torch.nn.Dropout(p=0.5), # drop1 32 | ocnn.FcBnRelu(channels[3] * 64, channels[2]), # fc1 33 | torch.nn.Dropout(p=0.5), # drop2 34 | torch.nn.Linear(channels[2], nout)) # fc2 35 | 36 | def forward(self, octree): 37 | # Get the initial feature 38 | data = ocnn.octree_property(octree, 'feature', self.depth) 39 | assert data.size(1) == self.channel_in 40 | 41 | # build the dual octree 42 | doctree = DualOctree(octree) 43 | doctree.post_processing_for_ocnn() 44 | 45 | # forward the network 46 | for i, d in enumerate(range(self.depth, 2, -1)): 47 | # perform graph conv 48 | data = data.squeeze().t() 49 | edge_idx = doctree.graph[d]['edge_idx'] 50 | edge_type = doctree.graph[d]['edge_type'] 51 | data = self.convs[i](data, edge_idx, edge_type) 52 | 53 | # downsampling 54 | data = data.t().unsqueeze(0).unsqueeze(-1) 55 | data = self.pools[i](data, octree) 56 | 57 | # classification head 58 | data = self.octree2voxel(data) 59 | data = self.header(data) 60 | return data 61 | 62 | 63 | class DualGraphLeNet(torch.nn.Module): 64 | 65 | def __init__(self, depth, channel_in, nout): 66 | super().__init__() 67 | self.depth, self.channel_in = depth, channel_in 68 | channels = [2 ** max(9 - i, 2) for i in range(depth + 1)] 69 | channels.append(channel_in) 70 | 71 | self.convs = torch.nn.ModuleList([ 72 | GraphConvBnRelu(channels[d + 1], channels[d], n_edge_type=7, 73 | avg_degree=7, n_node_type=d-1) 74 | for d in range(depth, 2, -1)]) 75 | # self.downsample = torch.nn.ModuleList([ 76 | # GraphDownsample(channels[d]) for d in range(depth, 2, -1)]) 77 | self.downsample = torch.nn.ModuleList([ 78 | GraphMaxpool() for d in range(depth, 2, -1)]) 79 | self.octree2voxel = ocnn.FullOctree2Voxel(2) 80 | self.header = torch.nn.Sequential( 81 | torch.nn.Dropout(p=0.5), # drop1 82 | ocnn.FcBnRelu(channels[3] * 64, channels[2]), # fc1 83 | torch.nn.Dropout(p=0.5), # drop2 84 | torch.nn.Linear(channels[2], nout)) # fc2 85 | 86 | def forward(self, octree): 87 | # build the dual octree 88 | doctree = DualOctree(octree) 89 | doctree.post_processing_for_docnn() 90 | 91 | # Get the initial feature 92 | data = doctree.get_input_feature() 93 | assert data.size(1) == self.channel_in 94 | 95 | # forward the network 96 | for i, d in enumerate(range(self.depth, 2, -1)): 97 | # perform graph conv 98 | edge_idx = doctree.graph[d]['edge_idx'] 99 | edge_type = doctree.graph[d]['edge_dir'] 100 | node_type = doctree.graph[d]['node_type'] 101 | data = self.convs[i](data, edge_idx, edge_type, node_type) 102 | 103 | # downsampling 104 | nnum = doctree.nnum[d] 105 | lnum = doctree.lnum[d-1] 106 | leaf_mask = doctree.node_child(d-1) < 0 107 | data = self.downsample[i](data, leaf_mask, nnum, lnum) 108 | 109 | # classification head 110 | data = data.t().unsqueeze(0).unsqueeze(-1) 111 | data = self.octree2voxel(data) 112 | data = self.header(data) 113 | return data 114 |
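Both LeNet variants above derive their widths from the same comprehension, which doubles the channel count as the depth decreases (capped at 2**9 = 512, floored at 2**2 = 4). A standalone sketch, assuming depth = 5 and a 4-channel input feature, makes the schedule and the `channels[3] * 64` header size concrete:

depth, channel_in = 5, 4
channels = [2 ** max(9 - i, 2) for i in range(depth + 1)]
channels.append(channel_in)
print(channels)  # [512, 256, 128, 64, 32, 16, 4]
# the graph convs run at d = 5, 4, 3 and map channels[d + 1] -> channels[d]
print([(channels[d + 1], channels[d]) for d in range(depth, 2, -1)])
# [(4, 16), (16, 32), (32, 64)]
# FullOctree2Voxel(2) yields a 4x4x4 grid, so the first fully-connected
# layer of the header consumes channels[3] * 64 = 64 * 64 = 4096 features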
-------------------------------------------------------------------------------- /models/graph_ae.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import ocnn 9 | import torch 10 | import torch.nn 11 | 12 | from . import modules 13 | from . import dual_octree 14 | from . import graph_ounet 15 | 16 | 17 | class GraphAE(graph_ounet.GraphOUNet): 18 | 19 | def __init__(self, depth, channel_in, nout, full_depth=2, depth_out=6, 20 | resblk_type='bottleneck', bottleneck=4): 21 | super().__init__(depth, channel_in, nout, full_depth, depth_out, 22 | resblk_type, bottleneck) 23 | # this is to make the encoder and decoder symmetric 24 | n_edge_type, avg_degree = 7, 7 25 | self.decoder = torch.nn.ModuleList( 26 | [modules.GraphResBlocks(self.channels[d], self.channels[d], 27 | self.resblk_num[d], bottleneck, n_edge_type, avg_degree, d-1, resblk_type) 28 | for d in range(full_depth, depth + 1)]) 29 | 30 | self.code_channel = 8 31 | self.code_dim = self.code_channel * 2 ** (3 * self.full_depth) 32 | channel_in = self.channels[self.full_depth] 33 | self.project1 = modules.Conv1x1Bn(channel_in, self.code_channel) 34 | self.project2 = modules.Conv1x1BnRelu(self.code_channel, channel_in) 35 | 36 | def octree_encoder(self, octree, doctree): 37 | convs = super().octree_encoder(octree, doctree) 38 | # reduce the dimension 39 | code = self.project1(convs[self.full_depth]) 40 | # constrain the code to [-1, 1] 41 | code = torch.tanh(code) 42 | return code 43 | 44 | def octree_decoder(self, latent_code, doctree_out, doctree=None, update_octree=False): 45 | logits = dict() 46 | reg_voxs = dict() 47 | deconvs = dict() 48 | 49 | deconvs[self.full_depth] = self.project2(latent_code) 50 | for i, d in enumerate(range(self.full_depth, self.depth_out+1)): 51 | if d > self.full_depth: 52 | nnum = doctree_out.nnum[d-1] 53 | leaf_mask = doctree_out.node_child(d-1) < 0 54 | deconvs[d] = self.upsample[i-1](deconvs[d-1], leaf_mask, nnum) 55 | 56 | edge_idx = doctree_out.graph[d]['edge_idx'] 57 | edge_type = doctree_out.graph[d]['edge_dir'] 58 | node_type = doctree_out.graph[d]['node_type'] 59 | deconvs[d] = self.decoder[i](deconvs[d], edge_idx, edge_type, node_type) 60 | 61 | # predict the splitting label 62 | logit = self.predict[i](deconvs[d]) 63 | nnum = doctree_out.nnum[d] 64 | logits[d] = logit[-nnum:] 65 | 66 | # update the octree according to predicted labels 67 | if update_octree: 68 | label = logits[d].argmax(1).to(torch.int32) 69 | octree_out = doctree_out.octree 70 | octree_out = ocnn.octree_update(octree_out, label, d, split=1) 71 | if d < self.depth_out: 72 | octree_out = ocnn.octree_grow(octree_out, target_depth=d+1) 73 | doctree_out = dual_octree.DualOctree(octree_out) 74 | doctree_out.post_processing_for_docnn() 75 | 76 | # predict the signal 77 | reg_vox = self.regress[i](deconvs[d]) 78 | 79 | # TODO: improve it 80 | # pad zeros to reg_vox to reuse the original code for ocnn 81 | node_mask = doctree_out.graph[d]['node_mask']
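# (a toy illustration of the padding below, not part of the model: with six
# graph nodes and node_mask = [T, F, T, T, F, T], reg_vox holds four rows and
# reg_vox_pad holds six, with rows 1 and 4 left all-zero, so the downstream
# ocnn/MPU code can keep indexing the dual-octree nodes densely)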
82 | shape = (node_mask.shape[0], reg_vox.shape[1]) 83 | reg_vox_pad = torch.zeros(shape, device=reg_vox.device) 84 | reg_vox_pad[node_mask] = reg_vox 85 | reg_voxs[d] = reg_vox_pad 86 | 87 | return logits, reg_voxs, doctree_out.octree 88 | 89 | def extract_code(self, octree_in): 90 | doctree_in = dual_octree.DualOctree(octree_in) 91 | doctree_in.post_processing_for_docnn() 92 | 93 | code = self.octree_encoder(octree_in, doctree_in) 94 | return code 95 | 96 | def decode_code(self, code): 97 | # generate dual octrees 98 | octree_out = ocnn.create_full_octree(self.full_depth, self.nout) 99 | doctree_out = dual_octree.DualOctree(octree_out) 100 | doctree_out.post_processing_for_docnn() 101 | 102 | # run the decoder 103 | out = self.octree_decoder(code, doctree_out, update_octree=True) 104 | output = {'logits': out[0], 'reg_voxs': out[1], 'octree_out': out[2]} 105 | 106 | # create the mpu wrapper 107 | def _neural_mpu(pos): 108 | pred = self.neural_mpu(pos, out[1], out[2]) 109 | return pred[self.depth_out][0] 110 | output['neural_mpu'] = _neural_mpu 111 | 112 | return output 113 | -------------------------------------------------------------------------------- /dualocnn.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import numpy as np 11 | 12 | import builder 13 | import utils 14 | from solver import Solver, get_config 15 | 16 | 17 | class DualOcnnSolver(Solver): 18 | 19 | def get_model(self, flags): 20 | return builder.get_model(flags) 21 | 22 | def get_dataset(self, flags): 23 | return builder.get_dataset(flags) 24 | 25 | def batch_to_cuda(self, batch): 26 | keys = ['octree_in', 'octree_gt', 'pos', 'sdf', 'grad', 'weight', 'occu'] 27 | for key in keys: 28 | if key in batch: 29 | batch[key] = batch[key].cuda() 30 | batch['pos'].requires_grad_() 31 | 32 | def compute_loss(self, batch, model_out): 33 | flags = self.FLAGS.LOSS 34 | loss_func = builder.get_loss_function(flags) 35 | output = loss_func(batch, model_out, flags.loss_type) 36 | return output 37 | 38 | def model_forward(self, batch): 39 | self.batch_to_cuda(batch) 40 | model_out = self.model(batch['octree_in'], batch['octree_gt'], batch['pos']) 41 | 42 | output = self.compute_loss(batch, model_out) 43 | losses = [val for key, val in output.items() if 'loss' in key] 44 | output['loss'] = torch.sum(torch.stack(losses)) 45 | return output 46 | 47 | def train_step(self, batch): 48 | output = self.model_forward(batch) 49 | output = {'train/' + key: val for key, val in output.items()} 50 | return output 51 | 52 | def test_step(self, batch): 53 | output = self.model_forward(batch) 54 | output = {'test/' + key: val for key, val in output.items()} 55 | return output 56 | 57 | def extract_mesh(self, neural_mpu, filename, bbox=None): 58 | # bbox used for marching cubes 59 | if bbox is not None: 60 | bbmin, bbmax = bbox[:3], bbox[3:] 61 | else: 62 | sdf_scale = self.FLAGS.SOLVER.sdf_scale 63 | bbmin, bbmax = -sdf_scale, sdf_scale 64 | 65 | # create mesh 66 | utils.create_mesh(neural_mpu, filename, 67 | size=self.FLAGS.SOLVER.resolution, 68 | bbmin=bbmin, bbmax=bbmax, 69 | mesh_scale=self.FLAGS.DATA.test.point_scale, 70 | save_sdf=self.FLAGS.SOLVER.save_sdf) 71 | 72 | def 
eval_step(self, batch): 73 | # forward the model 74 | output = self.model.forward(batch['octree_in'].cuda()) 75 | 76 | # extract the mesh 77 | filename = batch['filename'][0] 78 | pos = filename.rfind('.') 79 | if pos != -1: filename = filename[:pos] # remove the suffix 80 | filename = os.path.join(self.logdir, filename + '.obj') 81 | folder = os.path.dirname(filename) 82 | if not os.path.exists(folder): os.makedirs(folder) 83 | bbox = batch['bbox'][0].numpy() if 'bbox' in batch else None 84 | self.extract_mesh(output['neural_mpu'], filename, bbox) 85 | 86 | # save the input point cloud 87 | filename = filename[:-4] + '.input.ply' 88 | utils.points2ply(filename, batch['points_in'][0].cpu(), 89 | self.FLAGS.DATA.test.point_scale) 90 | 91 | def save_tensors(self, batch, output): 92 | iter_num = batch['iter_num'] 93 | filename = os.path.join(self.logdir, '%04d.out.octree' % iter_num) 94 | output['octree_out'].cpu().numpy().tofile(filename) 95 | filename = os.path.join(self.logdir, '%04d.in.octree' % iter_num) 96 | batch['octree_in'].cpu().numpy().tofile(filename) 97 | filename = os.path.join(self.logdir, '%04d.in.points' % iter_num) 98 | batch['points_in'][0].cpu().numpy().tofile(filename) 99 | filename = os.path.join(self.logdir, '%04d.gt.octree' % iter_num) 100 | batch['octree_gt'].cpu().numpy().tofile(filename) 101 | filename = os.path.join(self.logdir, '%04d.gt.points' % iter_num) 102 | batch['points_gt'][0].cpu().numpy().tofile(filename) 103 | 104 | @classmethod 105 | def update_configs(cls): 106 | FLAGS = get_config() 107 | FLAGS.SOLVER.resolution = 128 # the resolution used for marching cubes 108 | FLAGS.SOLVER.save_sdf = False # save the sdfs in evaluation 109 | FLAGS.SOLVER.sdf_scale = 0.9 # the scale of sdfs 110 | 111 | FLAGS.DATA.train.point_scale = 0.5 # the scale of point clouds 112 | FLAGS.DATA.train.load_sdf = True # load sdf samples 113 | FLAGS.DATA.train.load_occu = False # load occupancy samples 114 | FLAGS.DATA.train.point_sample_num = 10000 115 | FLAGS.DATA.train.sample_surf_points = False 116 | 117 | # FLAGS.MODEL.skip_connections = True 118 | FLAGS.DATA.test = FLAGS.DATA.train.clone() 119 | FLAGS.LOSS.loss_type = 'sdf_reg_loss' 120 | 121 | 122 | if __name__ == '__main__': 123 | DualOcnnSolver.main() 124 | -------------------------------------------------------------------------------- /models/graph_slounet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import ocnn 10 | import torch.nn 11 | import torch.nn.functional as F 12 | 13 | from . import mpu 14 | from . import modules 15 | from . import graph_ounet 16 | from . 
import dual_octree 17 | 18 | 19 | class GraphSLDownsample(modules.GraphDownsample): 20 | 21 | def forward(self, x, leaf_mask, numd, lnumd): 22 | # downsample nodes at layer depth 23 | outd = x[-numd:] 24 | outd = self.downsample(outd) 25 | 26 | # get the nodes at layer (depth-1) 27 | out = torch.zeros(leaf_mask.shape[0], x.shape[1], device=x.device) 28 | out[leaf_mask.logical_not()] = outd 29 | 30 | if self.channels_in != self.channels_out: 31 | out = self.conv1x1(out) 32 | return out 33 | 34 | 35 | class GraphSLUpsample(modules.GraphUpsample): 36 | 37 | def forward(self, x, leaf_mask, numd): 38 | # upsample nodes at layer (depth-1) 39 | outd = x[-numd:] 40 | out = outd[leaf_mask.logical_not()] 41 | out = self.upsample(out) 42 | 43 | if self.channels_in != self.channels_out: 44 | out = self.conv1x1(out) 45 | return out 46 | 47 | 48 | class GraphSLOUNet(graph_ounet.GraphOUNet): 49 | 50 | def __init__(self, depth, channel_in, nout, full_depth=2, depth_out=6, 51 | resblk_type='bottleneck', bottleneck=4): 52 | super().__init__(depth, channel_in, nout, full_depth, depth_out, 53 | resblk_type, bottleneck) 54 | 55 | self.downsample = torch.nn.ModuleList( 56 | [GraphSLDownsample(self.channels[d], self.channels[d-1]) 57 | for d in range(depth, full_depth, -1)]) 58 | 59 | self.upsample = torch.nn.ModuleList( 60 | [GraphSLUpsample(self.channels[d-1], self.channels[d]) 61 | for d in range(full_depth+1, depth+1)]) 62 | 63 | def _get_input_feature(self, doctree): 64 | return doctree.get_input_feature(all_leaf_nodes=False) 65 | 66 | def octree_decoder(self, convs, doctree_out, doctree, update_octree=False): 67 | logits = dict() 68 | reg_voxs = dict() 69 | deconvs = dict() 70 | reg_voxs_list = [] 71 | 72 | deconvs[self.full_depth] = convs[self.full_depth] 73 | for i, d in enumerate(range(self.full_depth, self.depth_out+1)): 74 | if d > self.full_depth: 75 | nnum = doctree_out.nnum[d-1] 76 | leaf_mask = doctree_out.node_child(d-1) < 0 77 | deconvd = self.upsample[i-1](deconvs[d-1], leaf_mask, nnum) 78 | skip = modules.doctree_align( 79 | convs[d], doctree.graph[d]['keyd'], doctree_out.graph[d]['keyd']) 80 | deconvd = deconvd + skip # skip connections 81 | 82 | edge_idx = doctree_out.graph[d]['edge_idx'] 83 | edge_type = doctree_out.graph[d]['edge_dir'] 84 | node_type = doctree_out.graph[d]['node_type'] 85 | deconvs[d] = self.decoder[i-1](deconvd, edge_idx, edge_type, node_type) 86 | 87 | # predict the splitting label 88 | logit = self.predict[i](deconvs[d]) 89 | nnum = doctree_out.nnum[d] 90 | logits[d] = logit[-nnum:] 91 | 92 | # update the octree according to predicted labels 93 | if update_octree: 94 | label = logits[d].argmax(1).to(torch.int32) 95 | octree_out = doctree_out.octree 96 | octree_out = ocnn.octree_update(octree_out, label, d, split=1) 97 | if d < self.depth_out: 98 | octree_out = ocnn.octree_grow(octree_out, target_depth=d+1) 99 | doctree_out = dual_octree.DualOctree(octree_out) 100 | doctree_out.post_processing_for_ocnn() 101 | 102 | # predict the signal 103 | reg_vox = self.regress[i](deconvs[d]) 104 | reg_voxs_list.append(reg_vox) 105 | reg_voxs[d] = torch.cat(reg_voxs_list, dim=0) 106 | 107 | return logits, reg_voxs, doctree_out.octree 108 | 109 | def forward(self, octree_in, octree_out=None, pos=None): 110 | # generate dual octrees 111 | doctree_in = dual_octree.DualOctree(octree_in) 112 | doctree_in.post_processing_for_ocnn() 113 | 114 | update_octree = octree_out is None 115 | if update_octree: 116 | octree_out = ocnn.create_full_octree(self.full_depth, self.nout) 117 | doctree_out 
= dual_octree.DualOctree(octree_out) 118 | doctree_out.post_processing_for_ocnn() 119 | 120 | # run encoder and decoder 121 | convs = self.octree_encoder(octree_in, doctree_in) 122 | out = self.octree_decoder(convs, doctree_out, doctree_in, update_octree) 123 | output = {'logits': out[0], 'reg_voxs': out[1], 'octree_out': out[2]} 124 | 125 | # mpus 126 | if pos is not None: 127 | output['mpus'] = self.neural_mpu(pos, out[1], octree_out) 128 | 129 | return output 130 | -------------------------------------------------------------------------------- /tools/dfaust.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import argparse 10 | import trimesh 11 | import trimesh.sample 12 | import numpy as np 13 | import time 14 | import zipfile 15 | import wget 16 | from tqdm import tqdm 17 | 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--run', type=str, default='prepare_dataset') 21 | parser.add_argument('--filelist', type=str, default='test.txt') 22 | parser.add_argument('--mesh_folder', type=str, default='logs/dfaust/mesh') 23 | parser.add_argument('--output_folder', type=str, default='logs/dfaust/mesh') 24 | args = parser.parse_args() 25 | 26 | 27 | project_folder = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 28 | root_folder = os.path.join(project_folder, 'data/dfaust') 29 | shape_scale = 0.8 30 | 31 | 32 | def create_flag_file(filename): 33 | r''' Creates a flag file to indicate whether some time-consuming work 34 | has been done. 35 | ''' 36 | 37 | folder = os.path.dirname(filename) 38 | if not os.path.exists(folder): 39 | os.makedirs(folder) 40 | with open(filename, 'w') as fid: 41 | fid.write('succ @ ' + time.ctime()) 42 | 43 | 44 | def check_folder(filenames: list): 45 | r''' Makes sure the folders containing the given filenames exist, creating them if necessary. 46 | ''' 47 | 48 | for filename in filenames: 49 | folder = os.path.dirname(filename) 50 | if not os.path.exists(folder): 51 | os.makedirs(folder) 52 | 53 | 54 | def get_filenames(filelist, root_folder): 55 | r''' Gets filenames from a filelist. 56 | ''' 57 | 58 | filelist = os.path.join(root_folder, 'filelist', filelist) 59 | with open(filelist, 'r') as fid: 60 | lines = fid.readlines() 61 | filenames = [line.split()[0] for line in lines] 62 | return filenames 63 | 64 | 65 | def download_filelist(): 66 | r''' Downloads the filelists used for learning. 67 | ''' 68 | 69 | print('-> Download the filelist.') 70 | url = 'https://www.dropbox.com/s/vxkpaz3umzjvi66/dfaust.filelist.zip?dl=1' 71 | filename = os.path.join(root_folder, 'filelist.zip') 72 | wget.download(url, filename, bar=None) 73 | 74 | folder = os.path.join(root_folder, 'filelist') 75 | with zipfile.ZipFile(filename, 'r') as zip_ref: 76 | zip_ref.extractall(path=folder) 77 | os.remove(filename) 78 | 79 |
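# (round-trip sketch of the normalization used in `sample_points` below and
#  inverted in `rescale_mesh`; the names here are illustrative only:
#    center = np.mean(points, axis=0, keepdims=True)
#    normalized = (points - center) * shape_scale   # scan space -> training space
#    restored = normalized / shape_scale + center   # training space -> scan space
#  so meshes predicted in training space can be compared against the raw scans)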
80 | def sample_points(): 81 | r''' Samples points from raw scans for training. 82 | ''' 83 | 84 | num_samples = 100000 85 | print('-> Sample points.') 86 | scans_folder = os.path.join(root_folder, 'scans') 87 | dataset_folder = os.path.join(root_folder, 'dataset') 88 | filenames = get_filenames('all.txt', root_folder) 89 | for filename in tqdm(filenames, ncols=80): 90 | filename_ply = os.path.join(scans_folder, filename + '.ply') 91 | filename_pts = os.path.join(dataset_folder, filename + '.npy') 92 | filename_center = filename_pts[:-3] + 'center.npy' 93 | check_folder([filename_pts]) 94 | 95 | # sample points 96 | mesh = trimesh.load(filename_ply) 97 | points, idx = trimesh.sample.sample_surface(mesh, num_samples) 98 | normals = mesh.face_normals[idx] 99 | 100 | # normalize: Centralize + Scale 101 | center = np.mean(points, axis=0, keepdims=True) 102 | points = (points - center) * shape_scale 103 | point_cloud = np.concatenate((points, normals), axis=-1).astype(np.float32) 104 | 105 | # save 106 | np.save(filename_pts, point_cloud) 107 | np.save(filename_center, center) 108 | 109 | 110 | def rescale_mesh(): 111 | r''' Rescales and translates the generated meshes to align with the raw scans 112 | to compute evaluation metrics. 113 | ''' 114 | 115 | filenames = get_filenames(args.filelist, root_folder) 116 | for filename in tqdm(filenames, ncols=80): 117 | filename = filename[:-4] 118 | filename_output = os.path.join(args.output_folder, filename + '.obj') 119 | filename_mesh = os.path.join(args.mesh_folder, filename + '.obj') 120 | filename_center = os.path.join( 121 | root_folder, 'dataset', filename + '.center.npy') 122 | check_folder([filename_output]) 123 | 124 | center = np.load(filename_center) 125 | mesh = trimesh.load(filename_mesh) 126 | vertices = mesh.vertices / shape_scale + center 127 | mesh.vertices = vertices 128 | mesh.export(filename_output) 129 | 130 | 131 | def generate_dataset(): 132 | download_filelist() 133 | sample_points() 134 | 135 | 136 | def download_dataset(): 137 | download_filelist() 138 | 139 | print('-> Download the dataset.') 140 | flag_file = os.path.join(root_folder, 'flags/download_dataset_succ') 141 | if not os.path.exists(flag_file): 142 | url = 'https://www.dropbox.com/s/eb5uk8f2fqswhs3/dfaust.dataset.zip?dl=1' 143 | filename = os.path.join(root_folder, 'dfaust.dataset.zip') 144 | wget.download(url, filename, bar=None) 145 | 146 | with zipfile.ZipFile(filename, 'r') as zip_ref: 147 | zip_ref.extractall(path=root_folder) 148 | # os.remove(filename) 149 | create_flag_file(flag_file) 150 | 151 | 152 | if __name__ == '__main__': 153 | eval('%s()' % args.run) 154 | -------------------------------------------------------------------------------- /models/mpu.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import ocnn 9 | import torch 10 | import torch.nn 11 | 12 | from .utils.spmm import spmm, modulated_spmm 13 | 14 | kNN = 8 15 | 16 | 17 | class ABS(torch.autograd.Function): 18 | '''The derivative of torch.abs at `0` is `0`; in this implementation, we 19 | modify it to `1` 20 | ''' 21 | @staticmethod 22 | def forward(ctx, input): 23 | ctx.save_for_backward(input) 24 | return input.abs() 25 | 26 | @staticmethod 27 | def backward(ctx, grad_in): 28 | input, = ctx.saved_tensors 29 | sign = input < 0
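# (illustrative note: `sign` is a bool tensor, so `-2.0 * sign + 1.0` below
#  maps True (x < 0) -> -1.0 and False (x >= 0) -> +1.0, i.e. d|x|/dx with
#  the value at 0 chosen as +1; torch.where(input < 0, -1.0, 1.0) would be
#  an equivalent formulation)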
30 | grad_out = grad_in * (-2.0 * sign.to(input.dtype) + 1.0) 31 | return grad_out 32 | 33 | 34 | def linear_basis(x): 35 | return 1.0 - ABS.apply(x) 36 | 37 | 38 | def get_linear_mask(dim=3): 39 | mask = torch.tensor([0, 1], dtype=torch.float32) 40 | mask = torch.meshgrid([mask]*dim) 41 | mask = torch.stack(mask, -1).view(-1, dim) 42 | return mask 43 | 44 | 45 | def octree_linear_pts(octree, depth, pts): 46 | # get neigh coordinates 47 | scale = 2 ** depth 48 | mask = get_linear_mask(dim=3).to(pts.device) 49 | xyzf, ids = torch.split(pts, [3, 1], 1) 50 | xyzf = (xyzf + 1.0) * (scale / 2.0) # [-1, 1] -> [0, scale] 51 | xyzf = xyzf - 0.5 # the code is defined on the center 52 | xyzi = torch.floor(xyzf).detach() # the integer part (N, 3), use floor 53 | corners = xyzi.unsqueeze(1) + mask # (N, 8, 3) 54 | coordsf = xyzf.unsqueeze(1) - corners # (N, 8, 3), in [-1.0, 1.0] 55 | 56 | # corners -> key 57 | ids = ids.detach().repeat(1, kNN).unsqueeze(-1) # (N, 8, 1) 58 | key = torch.cat([corners, ids], dim=-1).view(-1, 4).short() # (N*8, 4) 59 | key = ocnn.octree_encode_key(key).long() # (N*8, ) 60 | idx = ocnn.octree_search_key(key, octree, depth, key_is_xyz=True) 61 | 62 | # corners -> flags 63 | valid = torch.logical_and(corners > -1, corners < scale) # in-bound check 64 | valid = torch.all(valid, dim=-1).view(-1) 65 | flgs = torch.logical_and(idx > -1, valid) 66 | 67 | # remove invalid pts 68 | idx = idx[flgs].long() # (N*8, ) -> (N', ) 69 | coordsf = coordsf.view(-1, 3)[flgs] # (N, 8, 3) -> (N', 3) 70 | 71 | # bspline weights 72 | weights = linear_basis(coordsf) # (N', 3) 73 | weights = torch.prod(weights, axis=-1).view(-1) # (N', ) 74 | # NOTE: the scale factor `2**(depth - 6)` is used to emphasize high-resolution 75 | # basis functions. Tune this factor further if needed!
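# (worked example of the trilinear weights above: a query point lying exactly
#  on a node center has coordsf = (0, 0, 0) for that corner, so its weight is
#  prod(1 - |0|) = 1; a point offset by 0.5 along every axis gets
#  (1 - 0.5)**3 = 0.125, and the eight corner weights always sum to 1)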
76 | # weights = weights * 2**(depth - 6) # used for shapenet 77 | weights = weights * (depth**2 / 50) # testing 78 | 79 | # rescale back to the original scale 80 | # After rescaling, coordsf is in the same scale as pts 81 | coordsf = coordsf * (2.0 / scale) # [-1.0, 1.0] -> [-2.0/scale, 2.0/scale] 82 | return {'idx': idx, 'xyzf': coordsf, 'weights': weights, 'flgs': flgs} 83 | 84 | 85 | def get_linear_pred(pts, octree, shape_code, neighs, depth_start, depth_end): 86 | npt = pts.size(0) 87 | indices, weights, xyzfs = [], [], [] 88 | nnum_cum = ocnn.octree_property(octree, 'node_num_cum') 89 | ids = torch.arange(npt, device=pts.device, dtype=torch.long) 90 | ids = ids.unsqueeze(-1).repeat(1, kNN).view(-1) 91 | for d in range(depth_start, depth_end+1): 92 | neighd = neighs[d] 93 | idxd = neighd['idx'] 94 | xyzfd = neighd['xyzf'] 95 | weightd = neighd['weights'] 96 | valid = neighd['flgs'] 97 | idsd = ids[valid] 98 | 99 | if d < depth_end: 100 | child = ocnn.octree_property(octree, 'child', d) 101 | leaf = child[idxd] < 0 # keep only leaf nodes 102 | idsd, idxd, weightd, xyzfd = idsd[leaf], idxd[leaf], weightd[leaf], xyzfd[leaf] 103 | 104 | idxd = idxd + (nnum_cum[d] - nnum_cum[depth_start]) 105 | indices.append(torch.stack([idsd, idxd], dim=1)) 106 | weights.append(weightd) 107 | xyzfs.append(xyzfd) 108 | 109 | indices = torch.cat(indices, dim=0).t() 110 | weights = torch.cat(weights, dim=0) 111 | xyzfs = torch.cat(xyzfs, dim=0) 112 | 113 | code_num = shape_code.size(0) 114 | output = modulated_spmm(indices, weights, npt, code_num, shape_code, xyzfs) 115 | norm = spmm(indices, weights, npt, code_num, torch.ones(code_num, 1).cuda()) 116 | output = torch.div(output, norm + 1e-8).squeeze() 117 | 118 | # whether the point is affected by any octree node at depth_end 119 | mask = neighs[depth_end]['flgs'].view(-1, kNN).any(axis=-1) 120 | return output, mask 121 | 122 | 123 | class NeuralMPU: 124 | def __init__(self, full_depth, depth): 125 | self.full_depth = full_depth 126 | self.depth = depth 127 | 128 | def __call__(self, pos, reg_voxs, octree_out): 129 | mpus = dict() 130 | neighs = dict() 131 | for d in range(self.full_depth, self.depth+1): 132 | neighs[d] = octree_linear_pts(octree_out, d, pos) 133 | fval, flgs = get_linear_pred( 134 | pos, octree_out, reg_voxs[d], neighs, self.full_depth, d) 135 | mpus[d] = (fval, flgs) 136 | return mpus 137 | -------------------------------------------------------------------------------- /datasets/shapenet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import ocnn 10 | import torch 11 | import numpy as np 12 | 13 | from solver import Dataset 14 | from .utils import collate_func 15 | 16 | 17 | class TransformShape: 18 | 19 | def __init__(self, flags): 20 | self.flags = flags 21 | 22 | self.point_sample_num = 3000 23 | self.sdf_sample_num = 5000 24 | self.points_scale = 0.5 # the points are in [-0.5, 0.5] 25 | self.noise_std = 0.005 26 | self.points2octree = ocnn.Points2Octree(**flags) 27 | 28 | def process_points_cloud(self, sample): 29 | # get the input 30 | points, normals = sample['points'], sample['normals'] 31 | points = points / self.points_scale # scale to [-1.0, 1.0] 32 | 33 | # transform points to 
octree 34 | points_gt = ocnn.points_new( 35 | torch.from_numpy(points).float(), torch.from_numpy(normals).float(), 36 | torch.Tensor(), torch.Tensor()) 37 | points_gt, _ = ocnn.clip_points(points_gt, [-1.0]*3, [1.0]*3) 38 | octree_gt = self.points2octree(points_gt) 39 | 40 | if self.flags.distort: 41 | # randomly sample points and add noise 42 | # Since we rescale points to [-1.0, 1.0] in Line 24, we also need to 43 | # rescale the `noise_std` here to make sure the `noise_std` is always 44 | # 0.5% of the bounding box size. 45 | noise_std = self.noise_std / self.points_scale 46 | noise = noise_std * np.random.randn(self.point_sample_num, 3) 47 | rand_idx = np.random.choice(points.shape[0], size=self.point_sample_num) 48 | points_noise = points[rand_idx] + noise 49 | 50 | points_in = ocnn.points_new( 51 | torch.from_numpy(points_noise).float(), torch.Tensor(), 52 | torch.ones(self.point_sample_num).float(), torch.Tensor()) 53 | points_in, _ = ocnn.clip_points(points_in, [-1.0]*3, [1.0]*3) 54 | octree_in = self.points2octree(points_in) 55 | else: 56 | points_in = points_gt 57 | octree_in = octree_gt 58 | 59 | # construct the output dict 60 | return {'octree_in': octree_in, 'points_in': points_in, 61 | 'octree_gt': octree_gt, 'points_gt': points_gt} 62 | 63 | def sample_sdf(self, sample): 64 | sdf = sample['sdf'] 65 | grad = sample['grad'] 66 | points = sample['points'] / self.points_scale # to [-1, 1] 67 | 68 | rand_idx = np.random.choice(points.shape[0], size=self.sdf_sample_num) 69 | points = torch.from_numpy(points[rand_idx]).float() 70 | sdf = torch.from_numpy(sdf[rand_idx]).float() 71 | grad = torch.from_numpy(grad[rand_idx]).float() 72 | return {'pos': points, 'sdf': sdf, 'grad': grad} 73 | 74 | def sample_on_surface(self, points, normals): 75 | rand_idx = np.random.choice(points.shape[0], size=self.sdf_sample_num) 76 | xyz = torch.from_numpy(points[rand_idx]).float() 77 | grad = torch.from_numpy(normals[rand_idx]).float() 78 | sdf = torch.zeros(self.sdf_sample_num) 79 | return {'pos': xyz, 'sdf': sdf, 'grad': grad} 80 | 81 | def sample_off_surface(self, xyz): 82 | xyz = xyz / self.points_scale # to [-1, 1] 83 | 84 | rand_idx = np.random.choice(xyz.shape[0], size=self.sdf_sample_num) 85 | xyz = torch.from_numpy(xyz[rand_idx]).float() 86 | # grad = torch.zeros(self.sample_number, 3) # dummy grads 87 | grad = xyz / (xyz.norm(p=2, dim=1, keepdim=True) + 1.0e-6) 88 | sdf = -1 * torch.ones(self.sdf_sample_num) # dummy sdfs 89 | return {'pos': xyz, 'sdf': sdf, 'grad': grad} 90 | 91 | def __call__(self, sample, idx): 92 | output = self.process_points_cloud(sample['point_cloud']) 93 | 94 | # sample ground truth sdfs 95 | if self.flags.load_sdf: 96 | sdf_samples = self.sample_sdf(sample['sdf']) 97 | output.update(sdf_samples) 98 | 99 | # sample on surface points and off surface points 100 | if self.flags.sample_surf_points: 101 | on_surf = self.sample_on_surface(sample['points'], sample['normals']) 102 | off_surf = self.sample_off_surface(sample['sdf']['points']) # TODO 103 | sdf_samples = { 104 | 'pos': torch.cat([on_surf['pos'], off_surf['pos']], dim=0), 105 | 'grad': torch.cat([on_surf['grad'], off_surf['grad']], dim=0), 106 | 'sdf': torch.cat([on_surf['sdf'], off_surf['sdf']], dim=0)} 107 | output.update(sdf_samples) 108 | 109 | return output 110 | 111 | 112 | class ReadFile: 113 | def __init__(self, load_sdf=False, load_occu=False): 114 | self.load_occu = load_occu 115 | self.load_sdf = load_sdf 116 | 117 | def __call__(self, filename): 118 | filename_pc = os.path.join(filename, 
'pointcloud.npz') 119 | raw = np.load(filename_pc) 120 | point_cloud = {'points': raw['points'], 'normals': raw['normals']} 121 | output = {'point_cloud': point_cloud} 122 | 123 | if self.load_occu: 124 | filename_occu = os.path.join(filename, 'points.npz') 125 | raw = np.load(filename_occu) 126 | occu = {'points': raw['points'], 'occupancies': raw['occupancies']} 127 | output['occu'] = occu 128 | 129 | if self.load_sdf: 130 | filename_sdf = os.path.join(filename, 'sdf.npz') 131 | raw = np.load(filename_sdf) 132 | sdf = {'points': raw['points'], 'grad': raw['grad'], 'sdf': raw['sdf']} 133 | output['sdf'] = sdf 134 | return output 135 | 136 | 137 | def get_shapenet_dataset(flags): 138 | transform = TransformShape(flags) 139 | read_file = ReadFile(flags.load_sdf, flags.load_occu) 140 | dataset = Dataset(flags.location, flags.filelist, transform, 141 | read_file=read_file, in_memory=flags.in_memory) 142 | return dataset, collate_func 143 | -------------------------------------------------------------------------------- /tools/room.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import wget 10 | import time 11 | import zipfile 12 | import argparse 13 | import numpy as np 14 | from tqdm import tqdm 15 | from plyfile import PlyData, PlyElement 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--run', type=str, required=True) 19 | args = parser.parse_args() 20 | 21 | project_folder = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 22 | root_folder = os.path.join(project_folder, 'data/room') 23 | 24 | 25 | def create_flag_file(filename): 26 | r''' Creates a flag file to indicate whether some time-consuming work 27 | has been done. 28 | ''' 29 | 30 | folder = os.path.dirname(filename) 31 | if not os.path.exists(folder): 32 | os.makedirs(folder) 33 | with open(filename, 'w') as fid: 34 | fid.write('succ @ ' + time.ctime()) 35 | 36 | 37 | def check_folder(filenames: list): 38 | r''' Makes sure the folders containing the given filenames exist, creating them if necessary. 39 | ''' 40 | 41 | for filename in filenames: 42 | folder = os.path.dirname(filename) 43 | if not os.path.exists(folder): 44 | os.makedirs(folder) 45 | 46 | 47 | def get_filenames(filelist, root_folder): 48 | r''' Gets filenames from a filelist. 49 | ''' 50 | 51 | filelist = os.path.join(root_folder, 'filelist', filelist) 52 | with open(filelist, 'r') as fid: 53 | lines = fid.readlines() 54 | filenames = [line.split()[0] for line in lines] 55 | return filenames 56 | 57 |
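# (usage sketch of the flag-file pattern defined above; every download step
#  below follows it so that re-running the script skips finished work:
#    flag_file = os.path.join(root_folder, 'flags/some_step_succ')
#    if not os.path.exists(flag_file):
#      do_expensive_work()              # e.g. wget.download + unzip
#      create_flag_file(flag_file)
#  `do_expensive_work` is a placeholder, not a function in this script)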
58 | def download_and_unzip(): 59 | r''' Downloads and unzips the data. 60 | ''' 61 | 62 | filename = os.path.join(root_folder, 'synthetic_room_dataset.zip') 63 | flag_file = os.path.join(root_folder, 'flags/download_room_dataset_succ') 64 | if not os.path.exists(flag_file): 65 | check_folder([filename]) 66 | print('-> Download synthetic_room_dataset.zip.') 67 | url = 'https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/data/synthetic_room_dataset.zip' 68 | wget.download(url, filename) 69 | create_flag_file(flag_file) 70 | 71 | flag_file = os.path.join(root_folder, 'flags/unzip_succ') 72 | if not os.path.exists(flag_file): 73 | print('-> Unzip synthetic_room_dataset.zip.') 74 | with zipfile.ZipFile(filename, 'r') as zip_ref: 75 | zip_ref.extractall(path=root_folder) 76 | # os.remove(filename) 77 | create_flag_file(flag_file) 78 | 79 | 80 | def download_filelist(): 81 | r''' Downloads the filelists used for learning. 82 | ''' 83 | 84 | flag_file = os.path.join(root_folder, 'flags/download_filelist_succ') 85 | if not os.path.exists(flag_file): 86 | print('-> Download the filelist.') 87 | url = 'https://www.dropbox.com/s/30v6pdek6777vkr/room.filelist.zip?dl=1' 88 | filename = os.path.join(root_folder, 'filelist.zip') 89 | wget.download(url, filename, bar=None) 90 | 91 | folder = os.path.join(root_folder, 'filelist') 92 | with zipfile.ZipFile(filename, 'r') as zip_ref: 93 | zip_ref.extractall(path=folder) 94 | # os.remove(filename) 95 | create_flag_file(flag_file) 96 | 97 | 98 | def download_ground_truth_mesh(): 99 | r''' Downloads the ground-truth meshes. 100 | ''' 101 | 102 | flag_file = os.path.join(root_folder, 'flags/download_mesh_succ') 103 | if not os.path.exists(flag_file): 104 | print('-> Download the ground-truth meshes.') 105 | url = 'https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/data/room_watertight_mesh.zip' 106 | filename = os.path.join(root_folder, 'filelist.zip') 107 | wget.download(url, filename, bar=None) 108 | create_flag_file(flag_file) 109 | 110 | 111 | def generate_test_points(): 112 | r''' Generates points in `ply` format for testing. 113 | ''' 114 | 115 | noise_std = 0.005 116 | point_sample_num = 10000 117 | print('-> Generate testing points.') 118 | # filenames = get_filenames('all.txt') 119 | filenames = get_filenames('test.txt', root_folder) 120 | for filename in tqdm(filenames, ncols=80): 121 | filename_pts = os.path.join( 122 | root_folder, 'synthetic_room_dataset', filename, 'pointcloud', 'pointcloud_00.npz') 123 | filename_ply = os.path.join( 124 | root_folder, 'test.input', filename + '.ply') 125 | if not os.path.exists(filename_pts): continue 126 | check_folder([filename_ply]) 127 | 128 | # sample points 129 | pts = np.load(filename_pts) 130 | points = pts['points'].astype(np.float32) 131 | noise = noise_std * np.random.randn(point_sample_num, 3) 132 | rand_idx = np.random.choice(points.shape[0], size=point_sample_num) 133 | points_noise = points[rand_idx] + noise 134 | 135 | # save ply 136 | vertices = [] 137 | py_types = (float, float, float) 138 | npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] 139 | for idx in range(points_noise.shape[0]): 140 | vertices.append( 141 | tuple(dtype(d) for dtype, d in zip(py_types, points_noise[idx]))) 142 | structured_array = np.array(vertices, dtype=npy_types) 143 | el = PlyElement.describe(structured_array, 'vertex') 144 | PlyData([el]).write(filename_ply) 145 | 146 | 147 | def download_test_points(): 148 | r''' Downloads the test points used in our paper.
149 | ''' 150 | print('-> Download testing points.') 151 | flag_file = os.path.join(root_folder, 'flags/download_test_points_succ') 152 | if not os.path.exists(flag_file): 153 | url = 'https://www.dropbox.com/s/q3h47042xh6sua7/scene.test.input.zip?dl=1' 154 | filename = os.path.join(root_folder, 'test.input.zip') 155 | wget.download(url, filename, bar=None) 156 | 157 | folder = os.path.join(root_folder, 'test.input') 158 | with zipfile.ZipFile(filename, 'r') as zip_ref: 159 | zip_ref.extractall(path=folder) 160 | # os.remove(filename) 161 | create_flag_file(flag_file) 162 | 163 | 164 | def generate_dataset(): 165 | download_and_unzip() 166 | download_filelist() 167 | # generate_test_points() 168 | download_test_points() 169 | 170 | 171 | if __name__ == '__main__': 172 | eval('%s()' % args.run) 173 | -------------------------------------------------------------------------------- /datasets/pointcloud.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import ocnn 10 | import torch 11 | import numpy as np 12 | 13 | from solver import Dataset 14 | from .utils import collate_func 15 | 16 | 17 | class Transform: 18 | 19 | def __init__(self, flags): 20 | self.flags = flags 21 | 22 | self.octant_sample_num = 2 23 | self.point_sample_num = flags.point_sample_num 24 | self.point_scale = flags.point_scale 25 | self.points2octree = ocnn.Points2Octree(**flags) 26 | 27 | def build_octree(self, points): 28 | pts, normals = points[:, :3], points[:, 3:] 29 | points_in = ocnn.points_new(pts, normals, torch.Tensor(), torch.Tensor()) 30 | # points_in, _ = ocnn.clip_points(points_in, [-1.0]*3, [1.0]*3) 31 | octree = self.points2octree(points_in) 32 | return {'octree_in': octree, 'points_in': points_in, 'octree_gt': octree} 33 | 34 | def sample_on_surface(self, points): 35 | '''Randomly sample points on the surface.''' 36 | 37 | rnd_idx = torch.randint(high=points.shape[0], size=(self.point_sample_num,)) 38 | pos = points[rnd_idx, :3] 39 | normal = points[rnd_idx, 3:] 40 | sdf = torch.zeros(self.point_sample_num) 41 | return {'pos': pos, 'grad': normal, 'sdf': sdf} 42 | 43 | def sample_off_surface(self, bbox): 44 | '''Randomly sample points in the 3D space.''' 45 | 46 | # uniform sampling in the whole 3D space 47 | pos = torch.rand(self.point_sample_num, 3) * 2 - 1 48 | 49 | # point gradients 50 | # grad = torch.zeros(self.point_sample_num, 3) 51 | norm = torch.sqrt(torch.sum(pos**2, dim=1, keepdim=True)) + 1e-6 52 | grad = pos / norm # fake off-surface gradients 53 | 54 | # sdf values 55 | eps = 0.04 56 | bbmin, bbmax = bbox[:3] - eps, bbox[3:] + eps 57 | mask = torch.logical_and(pos > bbmin, pos < bbmax).all(1) # inside the bbox 58 | sdf = -1.0 * torch.ones(self.point_sample_num) 59 | sdf[mask.logical_not()] = 1.0 # positive outside the bbox 60 | return {'pos': pos, 'grad': grad, 'sdf': sdf} 61 | 62 | def sample_on_octree(self, octree): 63 | '''Adaptively sample points in the 3D space according to the octree.''' 64 | 65 | xyzs = [] 66 | sample_num = 0 67 | depth = ocnn.octree_property(octree, 'depth') 68 | full_depth = ocnn.octree_property(octree, 'full_depth') 69 | for d in range(full_depth, depth+1): 70 | # get octree key 71 | xyz = ocnn.octree_property(octree, 'xyz', d) 72 | 
empty_node = ocnn.octree_property(octree, 'child', d) < 0 73 | xyz = ocnn.octree_decode_key(xyz[empty_node]) 74 | 75 | # sample k points in each octree node 76 | xyz = xyz[:, :3].float() # + 0.5 -> octree node center 77 | rnd = torch.rand(xyz.shape[0], self.octant_sample_num, 3) 78 | xyz = xyz.unsqueeze(1) + rnd 79 | xyz = xyz.view(-1, 3) # (N, 3) 80 | xyz = xyz * (2 ** (1 - d)) - 1 # normalize to [-1, 1] 81 | xyzs.append(xyz) 82 | 83 | # make sure that the max sample number is self.point_sample_num 84 | sample_num += xyz.shape[0] 85 | if sample_num > self.point_sample_num: 86 | size = self.point_sample_num - (sample_num - xyz.shape[0]) 87 | rnd_idx = torch.randint(high=xyz.shape[0], size=(size,)) 88 | xyzs[-1] = xyzs[-1][rnd_idx] 89 | break 90 | 91 | pos = torch.cat(xyzs, dim=0) 92 | grad = torch.zeros_like(pos) 93 | sdf = -1.0 * torch.ones(pos.shape[0]) 94 | return {'pos': pos, 'grad': grad, 'sdf': sdf} 95 | 96 | def scale_and_clip_points(self, points): 97 | points[:, :3] = points[:, :3] * self.point_scale # rescale points 98 | pts = points[:, :3] 99 | mask = torch.logical_and(pts > -1.0, pts < 1.0).all(1) 100 | return points[mask] 101 | 102 | def compute_bbox(self, points): 103 | pts = points[:, :3] 104 | bbmin = pts.min(0)[0] - 0.06 105 | bbmax = pts.max(0)[0] + 0.06 106 | bbmin = torch.clamp(bbmin, min=-1, max=1) 107 | bbmax = torch.clamp(bbmax, min=-1, max=1) 108 | return torch.cat([bbmin, bbmax]) 109 | 110 | def __call__(self, points, idx): 111 | points = self.scale_and_clip_points(points) # clip points to [-1, 1] 112 | output = self.build_octree(points) 113 | bbox = self.compute_bbox(points) 114 | output['bbox'] = bbox # used in marching cubes 115 | 116 | sdf_on_surf = self.sample_on_surface(points) 117 | # TODO: compare self.sample_on_octree & self.sample_off_surface 118 | # sdf_off_surf = self.sample_on_octree(output['octree_in']) 119 | sdf_off_surf = self.sample_off_surface(bbox) 120 | sdfs = {key: torch.cat([sdf_on_surf[key], sdf_off_surf[key]], dim=0) 121 | for key in sdf_on_surf.keys()} 122 | output.update(sdfs) 123 | 124 | return output 125 | 126 | 127 | class SingleTransform(Transform): 128 | 129 | def __init__(self, flags): 130 | super().__init__(flags) 131 | self.output = None 132 | self.points = None 133 | self.bbox = None 134 | 135 | def __call__(self, points, idx): 136 | if self.output is None: 137 | self.points = self.scale_and_clip_points(points) # clip points to [-1, 1] 138 | self.output = self.build_octree(self.points) 139 | self.bbox = self.compute_bbox(self.points) 140 | self.output['bbox'] = self.bbox # used in marching cubes 141 | 142 | output = self.output 143 | sdf_on_surf = self.sample_on_surface(self.points) 144 | # TODO: compare self.sample_on_octree & self.sample_off_surface 145 | # sdf_off_surf = self.sample_on_octree(output['octree_in']) 146 | sdf_off_surf = self.sample_off_surface(self.bbox) 147 | sdfs = {key: torch.cat([sdf_on_surf[key], sdf_off_surf[key]], dim=0) 148 | for key in sdf_on_surf.keys()} 149 | output.update(sdfs) 150 | 151 | return output 152 | 153 | 154 | def read_file(filename: str): 155 | if filename.endswith('.xyz'): 156 | points = np.loadtxt(filename) 157 | elif filename.endswith('.npy'): 158 | points = np.load(filename) 159 | else: 160 | raise NotImplementedError 161 | output = torch.from_numpy(points.astype(np.float32)) 162 | return output 163 | 164 | 165 | def get_pointcloud_dataset(flags): 166 | transform = Transform(flags) 167 | dataset = Dataset(flags.location, flags.filelist, transform, 168 | read_file=read_file, 
in_memory=flags.in_memory) 169 | return dataset, collate_func 170 | 171 | 172 | def get_singlepointcloud_dataset(flags): 173 | transform = SingleTransform(flags) 174 | dataset = Dataset(flags.location, flags.filelist, transform, 175 | read_file=read_file, in_memory=flags.in_memory) 176 | return dataset, collate_func 177 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | # autopep8: off 9 | import ocnn 10 | import torch 11 | import torch.autograd 12 | import numpy as np 13 | import matplotlib 14 | matplotlib.use('Agg') 15 | import matplotlib.pyplot as plt 16 | import skimage.measure 17 | import trimesh 18 | from plyfile import PlyData, PlyElement 19 | from scipy.spatial import cKDTree 20 | # autopep8: on 21 | 22 | 23 | def get_mgrid(size, dim=3): 24 | r''' 25 | Example: 26 | >>> get_mgrid(3, dim=2) 27 | array([[0.0, 0.0], 28 | [0.0, 1.0], 29 | [0.0, 2.0], 30 | [1.0, 0.0], 31 | [1.0, 1.0], 32 | [1.0, 2.0], 33 | [2.0, 0.0], 34 | [2.0, 1.0], 35 | [2.0, 2.0]], dtype=float32) 36 | ''' 37 | coord = np.arange(0, size, dtype=np.float32) 38 | coords = [coord] * dim 39 | output = np.meshgrid(*coords, indexing='ij') 40 | output = np.stack(output, -1) 41 | output = output.reshape(size**dim, dim) 42 | return output 43 | 44 | 45 | def lin2img(tensor): 46 | channels = 1 47 | num_samples = tensor.shape[0] 48 | size = int(np.sqrt(num_samples)) 49 | return tensor.view(channels, size, size) 50 | 51 | 52 | def make_contour_plot(array_2d, mode='log'): 53 | fig, ax = plt.subplots(figsize=(2.75, 2.75), dpi=300) 54 | 55 | if(mode == 'log'): 56 | nlevels = 6 57 | levels_pos = np.logspace(-2, 0, num=nlevels) # logspace 58 | levels_neg = -1. 
* levels_pos[::-1] 59 | levels = np.concatenate((levels_neg, np.zeros((0)), levels_pos), axis=0) 60 | colors = plt.get_cmap("Spectral")(np.linspace(0., 1., num=nlevels * 2 + 1)) 61 | elif(mode == 'lin'): 62 | nlevels = 10 63 | levels = np.linspace(-.5, .5, num=nlevels) 64 | colors = plt.get_cmap("Spectral")(np.linspace(0., 1., num=nlevels)) 65 | else: 66 | raise NotImplementedError 67 | 68 | sample = np.flipud(array_2d) 69 | CS = ax.contourf(sample, levels=levels, colors=colors) 70 | cbar = fig.colorbar(CS) 71 | 72 | ax.contour(sample, levels=levels, colors='k', linewidths=0.1) 73 | ax.contour(sample, levels=[0], colors='k', linewidths=0.3) 74 | ax.axis('off') 75 | return fig 76 | 77 | 78 | def write_sdf_summary(model, writer, global_step, alias=''): 79 | size = 128 80 | coords_2d = get_mgrid(size, dim=2) 81 | coords_2d = coords_2d / size - 1.0 # [0, size] -> [-1, 1] 82 | coords_2d = torch.from_numpy(coords_2d) 83 | with torch.no_grad(): 84 | zeros = torch.zeros_like(coords_2d[:, :1]) 85 | ones = torch.ones_like(coords_2d[:, :1]) 86 | names = ['train_yz_sdf_slice', 'train_xz_sdf_slice', 'train_xy_sdf_slice'] 87 | coords = [torch.cat((zeros, coords_2d), dim=-1), 88 | torch.cat((coords_2d[:, :1], zeros, coords_2d[:, -1:]), dim=-1), 89 | torch.cat((coords_2d, -0.75 * ones), dim=-1)] 90 | for name, coord in zip(names, coords): 91 | ids = torch.zeros(coord.shape[0], 1) 92 | coord = torch.cat([coord, ids], dim=1).cuda() 93 | sdf_values = model(coord) 94 | sdf_values = lin2img(sdf_values).squeeze().cpu().numpy() 95 | fig = make_contour_plot(sdf_values) 96 | writer.add_figure(alias + name, fig, global_step=global_step) 97 | 98 | 99 | def calc_sdf(model, size=256, max_batch=64**3, bbmin=-1.0, bbmax=1.0): 100 | # generate samples 101 | num_samples = size ** 3 102 | samples = get_mgrid(size, dim=3) 103 | samples = samples * ((bbmax - bbmin) / size) + bbmin # [0,sz]->[bbmin,bbmax] 104 | samples = torch.from_numpy(samples) 105 | sdfs = torch.zeros(num_samples) 106 | 107 | # forward 108 | head = 0 109 | while head < num_samples: 110 | tail = min(head + max_batch, num_samples) 111 | sample_subset = samples[head:tail, :] 112 | idx = torch.zeros(sample_subset.shape[0], 1) 113 | pts = torch.cat([sample_subset, idx], dim=1).cuda() 114 | pred = model(pts).squeeze().detach().cpu() 115 | sdfs[head:tail] = pred 116 | head += max_batch 117 | sdfs = sdfs.reshape(size, size, size).numpy() 118 | return sdfs 119 | 120 | 121 | def create_mesh(model, filename, size=256, max_batch=64**3, level=0, 122 | bbmin=-0.9, bbmax=0.9, mesh_scale=1.0, save_sdf=False, **kwargs): 123 | # marching cubes 124 | sdf_values = calc_sdf(model, size, max_batch, bbmin, bbmax) 125 | vtx, faces = np.zeros((0, 3)), np.zeros((0, 3)) 126 | try: 127 | vtx, faces, _, _ = skimage.measure.marching_cubes(sdf_values, level) 128 | except: 129 | pass 130 | if vtx.size == 0 or faces.size == 0: 131 | print('Warning from marching cubes: Empty mesh!') 132 | return 133 | 134 | # normalize vtx 135 | vtx = vtx * ((bbmax - bbmin) / size) + bbmin # [0,sz]->[bbmin,bbmax] 136 | vtx = vtx * mesh_scale # rescale 137 | 138 | # save to ply and npy 139 | mesh = trimesh.Trimesh(vtx, faces) 140 | mesh.export(filename) 141 | if save_sdf: 142 | np.save(filename[:-4] + ".sdf.npy", sdf_values) 143 | 144 | 145 | def calc_sdf_err(filename_gt, filename_pred): 146 | scale = 1.0e2 # scale the result for better display 147 | sdf_gt = np.load(filename_gt) 148 | sdf = np.load(filename_pred) 149 | err = np.abs(sdf - sdf_gt).mean() * scale 150 | return err 151 | 152 | 153 | def 
calc_chamfer(filename_gt, filename_pred, point_num): 154 | scale = 1.0e5 # scale the result for better display 155 | np.random.seed(101) 156 | 157 | mesh_a = trimesh.load(filename_gt) 158 | points_a, _ = trimesh.sample.sample_surface(mesh_a, point_num) 159 | mesh_b = trimesh.load(filename_pred) 160 | points_b, _ = trimesh.sample.sample_surface(mesh_b, point_num) 161 | 162 | kdtree_a = cKDTree(points_a) 163 | dist_a, _ = kdtree_a.query(points_b) 164 | chamfer_a = np.mean(np.square(dist_a)) * scale 165 | 166 | kdtree_b = cKDTree(points_b) 167 | dist_b, _ = kdtree_b.query(points_a) 168 | chamfer_b = np.mean(np.square(dist_b)) * scale 169 | return chamfer_a, chamfer_b 170 | 171 | 172 | def points2ply(filename, points, scale=1.0): 173 | xyz = ocnn.points_property(points, 'xyz') 174 | normal = ocnn.points_property(points, 'normal') 175 | has_normal = normal is not None 176 | xyz = xyz.numpy() * scale 177 | if has_normal: normal = normal.numpy() 178 | 179 | # data types 180 | data = xyz 181 | py_types = (float, float, float) 182 | npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] 183 | if has_normal: 184 | py_types = py_types + (float, float, float) 185 | npy_types = npy_types + [('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4')] 186 | data = np.concatenate((data, normal), axis=1) 187 | 188 | # format into NumPy structured array 189 | vertices = [] 190 | for idx in range(data.shape[0]): 191 | vertices.append(tuple(dtype(d) for dtype, d in zip(py_types, data[idx]))) 192 | structured_array = np.array(vertices, dtype=npy_types) 193 | el = PlyElement.describe(structured_array, 'vertex') 194 | 195 | # write ply 196 | PlyData([el]).write(filename) 197 | -------------------------------------------------------------------------------- /models/graph_ounet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import ocnn 9 | import torch 10 | import torch.nn 11 | 12 | from . import mpu 13 | from . import modules 14 | from . 
import dual_octree 15 | 16 | 17 | class GraphOUNet(torch.nn.Module): 18 | 19 | def __init__(self, depth, channel_in, nout, full_depth=2, depth_out=6, 20 | resblk_type='bottleneck', bottleneck=4): 21 | super().__init__() 22 | self.depth = depth 23 | self.channel_in = channel_in 24 | self.nout = nout 25 | self.full_depth = full_depth 26 | self.depth_out = depth_out 27 | self.resblk_type = resblk_type 28 | self.bottleneck = bottleneck 29 | self.neural_mpu = mpu.NeuralMPU(self.full_depth, self.depth_out) 30 | self._setup_channels_and_resblks() 31 | n_edge_type, avg_degree = 7, 7 32 | 33 | # encoder 34 | self.conv1 = modules.GraphConvBnRelu( 35 | channel_in, self.channels[depth], n_edge_type, avg_degree, depth-1) 36 | self.encoder = torch.nn.ModuleList( 37 | [modules.GraphResBlocks(self.channels[d], self.channels[d], 38 | self.resblk_num[d], bottleneck, n_edge_type, avg_degree, d-1, resblk_type) 39 | for d in range(depth, full_depth-1, -1)]) 40 | self.downsample = torch.nn.ModuleList( 41 | [modules.GraphDownsample(self.channels[d], self.channels[d-1]) 42 | for d in range(depth, full_depth, -1)]) 43 | 44 | # decoder 45 | self.upsample = torch.nn.ModuleList( 46 | [modules.GraphUpsample(self.channels[d-1], self.channels[d]) 47 | for d in range(full_depth+1, depth + 1)]) 48 | self.decoder = torch.nn.ModuleList( 49 | [modules.GraphResBlocks(self.channels[d], self.channels[d], 50 | self.resblk_num[d], bottleneck, n_edge_type, avg_degree, d-1, resblk_type) 51 | for d in range(full_depth+1, depth + 1)]) 52 | 53 | # header 54 | self.predict = torch.nn.ModuleList( 55 | [self._make_predict_module(self.channels[d], 2) 56 | for d in range(full_depth, depth + 1)]) 57 | self.regress = torch.nn.ModuleList( 58 | [self._make_predict_module(self.channels[d], 4) 59 | for d in range(full_depth, depth + 1)]) 60 | 61 | def _setup_channels_and_resblks(self): 62 | # self.resblk_num = [3] * 7 + [1] + [1] * 9 63 | self.resblk_num = [3] * 16 64 | self.channels = [4, 512, 512, 256, 128, 64, 32, 32, 24] 65 | 66 | def _make_predict_module(self, channel_in, channel_out=2, num_hidden=32): 67 | return torch.nn.Sequential( 68 | modules.Conv1x1BnRelu(channel_in, num_hidden), 69 | modules.Conv1x1(num_hidden, channel_out, use_bias=True)) 70 | 71 | def _get_input_feature(self, doctree): 72 | return doctree.get_input_feature() 73 | 74 | def octree_encoder(self, octree, doctree): 75 | depth, full_depth = self.depth, self.full_depth 76 | data = self._get_input_feature(doctree) 77 | 78 | convs = dict() 79 | convs[depth] = data 80 | for i, d in enumerate(range(depth, full_depth-1, -1)): 81 | # perform graph conv 82 | convd = convs[d] # get convd 83 | edge_idx = doctree.graph[d]['edge_idx'] 84 | edge_type = doctree.graph[d]['edge_dir'] 85 | node_type = doctree.graph[d]['node_type'] 86 | if d == self.depth: # the first conv 87 | convd = self.conv1(convd, edge_idx, edge_type, node_type) 88 | convd = self.encoder[i](convd, edge_idx, edge_type, node_type) 89 | convs[d] = convd # update convd 90 | 91 | # downsampling 92 | if d > full_depth: # init convs[d-1] 93 | nnum = doctree.nnum[d] 94 | lnum = doctree.lnum[d-1] 95 | leaf_mask = doctree.node_child(d-1) < 0 96 | convs[d-1] = self.downsample[i](convd, leaf_mask, nnum, lnum) 97 | 98 | return convs 99 | 100 | def octree_decoder(self, convs, doctree_out, doctree, update_octree=False): 101 | logits = dict() 102 | reg_voxs = dict() 103 | deconvs = dict() 104 | 105 | deconvs[self.full_depth] = convs[self.full_depth] 106 | for i, d in enumerate(range(self.full_depth, self.depth_out+1)): 107 | if d > 
self.full_depth: 108 | nnum = doctree_out.nnum[d-1] 109 | leaf_mask = doctree_out.node_child(d-1) < 0 110 | deconvd = self.upsample[i-1](deconvs[d-1], leaf_mask, nnum) 111 | skip = modules.doctree_align( 112 | convs[d], doctree.graph[d]['keyd'], doctree_out.graph[d]['keyd']) 113 | deconvd = deconvd + skip # skip connections 114 | 115 | edge_idx = doctree_out.graph[d]['edge_idx'] 116 | edge_type = doctree_out.graph[d]['edge_dir'] 117 | node_type = doctree_out.graph[d]['node_type'] 118 | deconvs[d] = self.decoder[i-1](deconvd, edge_idx, edge_type, node_type) 119 | 120 | # predict the splitting label 121 | logit = self.predict[i](deconvs[d]) 122 | nnum = doctree_out.nnum[d] 123 | logits[d] = logit[-nnum:] 124 | 125 | # update the octree according to predicted labels 126 | if update_octree: 127 | label = logits[d].argmax(1).to(torch.int32) 128 | octree_out = doctree_out.octree 129 | octree_out = ocnn.octree_update(octree_out, label, d, split=1) 130 | if d < self.depth_out: 131 | octree_out = ocnn.octree_grow(octree_out, target_depth=d+1) 132 | doctree_out = dual_octree.DualOctree(octree_out) 133 | doctree_out.post_processing_for_docnn() 134 | 135 | # predict the signal 136 | reg_vox = self.regress[i](deconvs[d]) 137 | 138 | # TODO: improve it 139 | # pad zeros to reg_vox to reuse the original code for ocnn 140 | node_mask = doctree_out.graph[d]['node_mask'] 141 | shape = (node_mask.shape[0], reg_vox.shape[1]) 142 | reg_vox_pad = torch.zeros(shape, device=reg_vox.device) 143 | reg_vox_pad[node_mask] = reg_vox 144 | reg_voxs[d] = reg_vox_pad 145 | 146 | return logits, reg_voxs, doctree_out.octree 147 | 148 | def forward(self, octree_in, octree_out=None, pos=None): 149 | # generate dual octrees 150 | doctree_in = dual_octree.DualOctree(octree_in) 151 | doctree_in.post_processing_for_docnn() 152 | 153 | update_octree = octree_out is None 154 | if update_octree: 155 | octree_out = ocnn.create_full_octree(self.full_depth, self.nout) 156 | doctree_out = dual_octree.DualOctree(octree_out) 157 | doctree_out.post_processing_for_docnn() 158 | 159 | # run encoder and decoder 160 | convs = self.octree_encoder(octree_in, doctree_in) 161 | out = self.octree_decoder(convs, doctree_out, doctree_in, update_octree) 162 | output = {'logits': out[0], 'reg_voxs': out[1], 'octree_out': out[2]} 163 | 164 | # compute function value with mpu 165 | if pos is not None: 166 | output['mpus'] = self.neural_mpu(pos, out[1], out[2]) 167 | 168 | # create the mpu wrapper 169 | def _neural_mpu(pos): 170 | pred = self.neural_mpu(pos, out[1], out[2]) 171 | return pred[self.depth_out][0] 172 | output['neural_mpu'] = _neural_mpu 173 | 174 | return output 175 | -------------------------------------------------------------------------------- /solver/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | # autopep8: off 9 | import os 10 | import sys 11 | import shutil 12 | import argparse 13 | from datetime import datetime 14 | from yacs.config import CfgNode as CN 15 | 16 | _C = CN() 17 | 18 | # SOLVER related parameters 19 | _C.SOLVER = CN() 20 | _C.SOLVER.alias = '' # The experiment alias 21 | _C.SOLVER.gpu = (0,) # The gpu ids 22 | _C.SOLVER.run = 'train' # Choose from train or test 23 | 24 | 
_C.SOLVER.logdir = 'logs' # Directory where to write event logs 25 | _C.SOLVER.ckpt = '' # Restore weights from checkpoint file 26 | _C.SOLVER.ckpt_num = 10 # The number of checkpoints kept 27 | 28 | _C.SOLVER.type = 'sgd' # Choose from sgd or adam 29 | _C.SOLVER.weight_decay = 0.0005 # The weight decay on model weights 30 | _C.SOLVER.max_epoch = 300 # Maximum training epochs 31 | _C.SOLVER.eval_epoch = 1 # Maximum evaluation epochs 32 | _C.SOLVER.eval_step = -1 # Maximum evaluation steps 33 | _C.SOLVER.test_every_epoch = 10 # Test model every n training epochs 34 | _C.SOLVER.log_per_iter = -1 # Output log every k training iterations 35 | 36 | _C.SOLVER.lr_type = 'step' # Learning rate type: step or cos 37 | _C.SOLVER.lr = 0.1 # Initial learning rate 38 | _C.SOLVER.gamma = 0.1 # Learning rate step-wise decay 39 | _C.SOLVER.step_size = (120,60,) # Learning rate step size. 40 | _C.SOLVER.lr_power = 0.9 # Used in poly learning rate 41 | 42 | _C.SOLVER.dist_url = 'tcp://localhost:10001' 43 | _C.SOLVER.progress_bar = True 44 | 45 | 46 | # DATA related parameters 47 | _C.DATA = CN() 48 | _C.DATA.train = CN() 49 | _C.DATA.train.name = '' # The name of the dataset 50 | 51 | # For octree building 52 | # If node_dis = True and there are normals, the octree feature 53 | # has 4 channels, i.e., the average normals and a 1-channel displacement. 54 | # If node_dis = True and there are no normals, the feature is also 4 channels, 55 | # i.e., a 3-channel displacement of the average points relative to the center 56 | # points, and the last channel is constant. 57 | _C.DATA.train.depth = 5 # The octree depth 58 | _C.DATA.train.full_depth = 2 # The full depth 59 | _C.DATA.train.node_dis = False # Save the node displacement 60 | _C.DATA.train.split_label = False # Save the split label 61 | _C.DATA.train.adaptive = False # Build the adaptive octree 62 | _C.DATA.train.node_feat = False # Calculate the node feature 63 | 64 | # For normalization 65 | # If radius < 0, then the method will compute a bounding sphere 66 | _C.DATA.train.bsphere = 'sphere' # The method used to calc the bounding sphere 67 | _C.DATA.train.radius = -1. # The radius and center of the bounding sphere 68 | _C.DATA.train.center = (-1., -1., -1.)
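# ------------------------------------------------------------------
# Usage sketch (an illustration, not part of the original defaults):
# every value above can be overridden from a yaml file or from
# trailing KEY VALUE pairs on the command line; both paths go through
# `_update_config` below via yacs' `merge_from_file` and
# `merge_from_list`. For example (the batch-size override here is
# purely hypothetical):
#   python dualocnn.py --config configs/shapenet.yaml \
#       SOLVER.gpu 0,1,2,3 DATA.train.batch_size 16
# ------------------------------------------------------------------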
69 | 70 | # For transformation 71 | _C.DATA.train.offset = 0.016 # Used to displace the points when building octree 72 | _C.DATA.train.normal_axis = '' # Used to re-orient normal directions 73 | 74 | # For data augmentation 75 | _C.DATA.train.disable = False # Disable this dataset or not 76 | _C.DATA.train.distort = False # Whether to apply data augmentation 77 | _C.DATA.train.scale = 0.0 # Scale the points 78 | _C.DATA.train.uniform = False # Generate uniform scales 79 | _C.DATA.train.jitter = 0.0 # Jitter the points 80 | _C.DATA.train.interval = (1, 1, 1) # Use interval & angle to generate random angles 81 | _C.DATA.train.angle = (180, 180, 180) 82 | 83 | # For data loading 84 | _C.DATA.train.location = '' # The data location 85 | _C.DATA.train.filelist = '' # The data filelist 86 | _C.DATA.train.batch_size = 32 # Training data batch size 87 | _C.DATA.train.num_workers = 8 # Number of workers to load the data 88 | _C.DATA.train.shuffle = False # Shuffle the input data 89 | _C.DATA.train.in_memory = False # Load the training data into memory 90 | 91 | 92 | _C.DATA.test = _C.DATA.train.clone() 93 | 94 | 95 | # MODEL related parameters 96 | _C.MODEL = CN() 97 | _C.MODEL.name = '' # The name of the model 98 | _C.MODEL.depth = 5 # The input octree depth 99 | _C.MODEL.full_depth = 2 # The input octree full depth layer 100 | _C.MODEL.depth_out = 5 # The output feature depth 101 | _C.MODEL.channel = 3 # The input feature channel 102 | _C.MODEL.factor = 1 # The factor used to widen the network 103 | _C.MODEL.nout = 40 # The output feature channel 104 | _C.MODEL.resblock_num = 3 # The resblock number 105 | _C.MODEL.resblock_type = 'bottleneck'# Choose from 'bottleneck' and 'basic' 106 | _C.MODEL.bottleneck = 4 # The bottleneck factor of one resblock 107 | _C.MODEL.dropout = (0.0,) # The dropout ratio 108 | 109 | _C.MODEL.upsample = 'nearest' # The method used for upsampling 110 | _C.MODEL.interp = 'linear' # The interpolation method: linear or nearest 111 | _C.MODEL.nempty = False # Perform Octree Conv on non-empty octree nodes 112 | _C.MODEL.sync_bn = False # Use sync_bn when training the network 113 | _C.MODEL.use_checkpoint = False # Use checkpoint to save memory 114 | _C.MODEL.find_unused_parameters = False # Used in DistributedDataParallel 115 | 116 | # loss related parameters 117 | _C.LOSS = CN() 118 | _C.LOSS.name = '' # The name of the loss 119 | _C.LOSS.num_class = 40 # The class number for the cross-entropy loss 120 | _C.LOSS.weights = (1.0, 1.0) # The weight factors for different losses 121 | _C.LOSS.label_smoothing = 0.0 # The factor of label smoothing 122 | 123 | 124 | # backup the commands 125 | _C.SYS = CN() 126 | _C.SYS.cmds = '' # Used to backup the commands 127 | 128 | FLAGS = _C 129 | 130 | 131 | def _update_config(FLAGS, args): 132 | FLAGS.defrost() 133 | if args.config: 134 | FLAGS.merge_from_file(args.config) 135 | if args.opts: 136 | FLAGS.merge_from_list(args.opts) 137 | FLAGS.SYS.cmds = ' '.join(sys.argv) 138 | 139 | # update logdir 140 | alias = FLAGS.SOLVER.alias.lower() 141 | if 'time' in alias: # 'time' is a special keyword 142 | alias = alias.replace('time', datetime.now().strftime('%m%d%H%M')) #%S 143 | if alias != '': # compare strings by value, not identity 144 | FLAGS.SOLVER.logdir += '_' + alias 145 | FLAGS.freeze() 146 | 147 | 148 | def _backup_config(FLAGS, args): 149 | logdir = FLAGS.SOLVER.logdir 150 | if not os.path.exists(logdir): 151 | os.makedirs(logdir) 152 | # copy the file to logdir 153 | if args.config: 154 | shutil.copy2(args.config, logdir) 155 | # dump all configs 156 | filename = 
os.path.join(logdir, 'all_configs.yaml') 157 | with open(filename, 'w') as fid: 158 | fid.write(FLAGS.dump()) 159 | 160 | 161 | def _set_env_var(FLAGS): 162 | gpus = ','.join([str(a) for a in FLAGS.SOLVER.gpu]) 163 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus 164 | 165 | 166 | def get_config(): 167 | return FLAGS 168 | 169 | def parse_args(backup=True): 170 | parser = argparse.ArgumentParser(description='The configs') 171 | parser.add_argument('--config', type=str, 172 | help='experiment configure file name') 173 | parser.add_argument('opts', nargs=argparse.REMAINDER, 174 | help="Modify config options using the command-line") 175 | 176 | args = parser.parse_args() 177 | _update_config(FLAGS, args) 178 | if backup: 179 | _backup_config(FLAGS, args) 180 | _set_env_var(FLAGS) 181 | return FLAGS 182 | 183 | 184 | if __name__ == '__main__': 185 | flags = parse_args(backup=False) 186 | print(flags) 187 | -------------------------------------------------------------------------------- /losses/loss.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import ocnn 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | 13 | def compute_gradient(y, x): 14 | grad_outputs = torch.ones_like(y) 15 | grad = torch.autograd.grad(y, [x], grad_outputs, create_graph=True)[0] 16 | return grad 17 | 18 | 19 | def sdf_reg_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''): 20 | wg, ws = 1.0, 200.0 21 | grad_loss = (grad - grad_gt).pow(2).mean() * wg 22 | sdf_loss = (sdf - sdf_gt).pow(2).mean() * ws 23 | loss_dict = {'grad_loss' + name_suffix: grad_loss, 24 | 'sdf_loss' + name_suffix: sdf_loss} 25 | return loss_dict 26 | 27 | 28 | def sdf_grad_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''): 29 | on_surf = sdf_gt != -1 30 | off_surf = on_surf.logical_not() 31 | 32 | sdf_loss = sdf[on_surf].pow(2).mean() * 200.0 33 | norm_loss = (grad[on_surf] - grad_gt[on_surf]).pow(2).mean() * 1.0 34 | intr_loss = torch.exp(-40 * torch.abs(sdf[off_surf])).mean() * 0.1 35 | grad_loss = (grad[off_surf].norm(2, dim=-1) - 1).abs().mean() * 0.1 36 | 37 | losses = [sdf_loss, intr_loss, norm_loss, grad_loss] 38 | names = ['sdf_loss', 'inter_loss', 'norm_loss', 'grad_loss'] 39 | names = [name + name_suffix for name in names] 40 | loss_dict = dict(zip(names, losses)) 41 | return loss_dict 42 | 43 | 44 | def sdf_grad_regularized_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''): 45 | on_surf = sdf_gt != -1 46 | off_surf = on_surf.logical_not() 47 | 48 | sdf_loss = sdf[on_surf].pow(2).mean() * 200.0 49 | norm_loss = (grad[on_surf] - grad_gt[on_surf]).pow(2).mean() * 1.0 50 | intr_loss = torch.exp(-40 * torch.abs(sdf[off_surf])).mean() * 0.1 51 | grad_loss = (grad[off_surf].norm(2, dim=-1) - 1).abs().mean() * 0.1 52 | grad_reg_loss = (grad[off_surf] - grad_gt[off_surf]).pow(2).mean() * 0.1 53 | 54 | losses = [sdf_loss, intr_loss, norm_loss, grad_loss, grad_reg_loss] 55 | names = ['sdf_loss', 'inter_loss', 'norm_loss', 'grad_loss', 'grad_reg_loss'] 56 | names = [name + name_suffix for name in names] 57 | loss_dict = dict(zip(names, losses)) 58 | return loss_dict 59 | 60 | 61 | def possion_grad_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''): 62 | on_surf = sdf_gt == 0 63 | out_of_bbox = sdf_gt == 1.0 64 | off_surf 
= on_surf.logical_not() 65 | 66 | sdf_loss = sdf[on_surf].pow(2).mean() * 200.0 67 | norm_loss = (grad[on_surf] - grad_gt[on_surf]).pow(2).mean() * 1.0 68 | intr_loss = torch.exp(-40 * torch.abs(sdf[off_surf])).mean() * 0.1 69 | grad_loss = grad[off_surf].pow(2).mean() * 0.1 # poisson loss 70 | bbox_loss = torch.mean(torch.relu(-sdf[out_of_bbox])) * 100.0 71 | 72 | losses = [sdf_loss, intr_loss, norm_loss, grad_loss, bbox_loss] 73 | names = ['sdf_loss', 'inter_loss', 'norm_loss', 'grad_loss', 'bbox_loss'] 74 | names = [name + name_suffix for name in names] 75 | loss_dict = dict(zip(names, losses)) 76 | return loss_dict 77 | 78 | 79 | def compute_mpu_gradients(mpus, pos, fval_transform=None): 80 | grads = dict() 81 | for d in mpus.keys(): 82 | fval, flags = mpus[d] 83 | if fval_transform is not None: 84 | fval = fval_transform(fval) 85 | grads[d] = compute_gradient(fval, pos)[:, :3] 86 | return grads 87 | 88 | 89 | def compute_octree_loss(logits, octree_out): 90 | weights = [1.0] * 16 91 | # weights = [1.0] * 4 + [0.8, 0.6, 0.4] + [0.2] * 16 92 | 93 | output = dict() 94 | for d in logits.keys(): 95 | logitd = logits[d] 96 | label_gt = ocnn.octree_property(octree_out, 'split', d).long() 97 | output['loss_%d' % d] = F.cross_entropy(logitd, label_gt) * weights[d] 98 | output['accu_%d' % d] = logitd.argmax(1).eq(label_gt).float().mean() 99 | return output 100 | 101 | 102 | def compute_sdf_loss(mpus, grads, sdf_gt, grad_gt, reg_loss_func): 103 | output = dict() 104 | for d in mpus.keys(): 105 | sdf, flgs = mpus[d] # TODO: tune the loss weights and `flgs` 106 | reg_loss = reg_loss_func(sdf, grads[d], sdf_gt, grad_gt, '_%d' % d) 107 | # if d < 3: # ignore depth 2 108 | # for key in reg_loss.keys(): 109 | # reg_loss[key] = reg_loss[key] * 0.0 110 | output.update(reg_loss) 111 | return output 112 | 113 | 114 | def compute_occu_loss_v0(mpus, grads, occu_gt, grad_gt, weight): 115 | output = dict() 116 | for d in mpus.keys(): 117 | occu, flgs, grad = mpus[d] 118 | 119 | # pos_weight = torch.ones_like(occu_gt) * 10.0 120 | loss_o = F.binary_cross_entropy_with_logits(occu, occu_gt, weight=weight) 121 | # loss_g = torch.mean((grad - grad_gt) ** 2) 122 | 123 | occu = torch.sigmoid(occu) 124 | non_surface_points = occu_gt != 0.5 125 | accu = (occu > 0.5).eq(occu_gt).float()[non_surface_points].mean() 126 | 127 | output['occu_loss_%d' % d] = loss_o 128 | # output['grad_loss_%d' % d] = loss_g 129 | output['occu_accu_%d' % d] = accu 130 | return output 131 | 132 | 133 | def compute_occu_loss_1214(mpus, occu): 134 | # tried on 2021.12.13 135 | weights = [0.2] * 4 + [0.4, 0.6, 0.8] + [1.0] * 16 # TODO: tune the weights 136 | 137 | inside = occu == 0 138 | outside = occu == 1 139 | output = dict() 140 | for d in mpus.keys(): 141 | sdf, flgs = mpus[d] 142 | 143 | inside_loss = torch.mean(torch.relu(sdf[inside])) * (1000 * weights[d]) 144 | outside_loss = torch.mean(torch.relu(-sdf[outside])) * (1000 * weights[d]) 145 | 146 | output['inside_loss_%d' % d] = inside_loss 147 | output['outside_loss_%d' % d] = outside_loss 148 | output['inside_accu_%d' % d] = (sdf[inside] < 0).float().mean() 149 | output['outside_accu_%d' % d] = (sdf[outside] > 0).float().mean() 150 | return output 151 | 152 | 153 | def compute_occu_loss(mpus, grads, occu, grad_gt): 154 | weights = [1.0] * 16 155 | # weights = [0.2] * 4 + [0.4, 0.6, 0.8] + [1.0] * 16 156 | # weights = [0.0] * 7 + [1.0] * 16 # Single level loss 157 | 158 | inside = occu == 0 159 | outside = occu == 1 160 | on_surf = occu == 0.5 161 | off_surf = on_surf.logical_not() 162 
| 163 | output = dict() 164 | for d in mpus.keys(): 165 | sdf, flgs = mpus[d] 166 | grad = grads[d] 167 | grad_diff = grad[on_surf] - grad_gt[on_surf] 168 | 169 | sdf_loss = sdf[on_surf].pow(2).mean() * (200 * weights[d]) 170 | norm_loss = grad_diff.pow(2).mean() * (1.0 * weights[d]) 171 | intr_loss = torch.exp(-40 * sdf[off_surf].abs()).mean() * (0.1 * weights[d]) 172 | grad_loss = grad[off_surf].pow(2).mean() * (0.1 * weights[d]) 173 | 174 | inside_loss = torch.mean(torch.relu(sdf[inside])) * (500 * weights[d]) 175 | outside_loss = torch.mean(torch.relu(-sdf[outside])) * (2000 * weights[d]) 176 | inside_accu = (sdf[inside] < 0).float().mean() 177 | outside_accu = (sdf[outside] > 0).float().mean() 178 | 179 | losses = [sdf_loss, norm_loss, grad_loss, intr_loss, 180 | inside_loss, inside_accu, outside_loss, outside_accu] 181 | names = ['sdf_loss', 'norm_loss', 'grad_loss', 'inter_loss', 182 | 'inside_loss', 'inside_accu', 'outside_loss', 'outside_accu'] 183 | names = [name + ('_%d' % d) for name in names] 184 | loss_dict = dict(zip(names, losses)) 185 | output.update(loss_dict) 186 | 187 | return output 188 | 189 | 190 | def compute_occu_loss_cls(mpus, grads, occu_gt, grad_gt): 191 | weights = [1.0] * 16 192 | # weights = [0.2] * 4 + [0.4, 0.6, 0.8] + [1.0] * 16 193 | # weights = [0.0] * 7 + [1.0] * 16 # Single level loss 194 | 195 | inside = occu_gt == 0 196 | outside = occu_gt == 1 197 | on_surf = occu_gt == 0.5 198 | off_surf = on_surf.logical_not() 199 | 200 | # Use soft-version occupancies 201 | occu_soft = torch.ones_like(occu_gt) * 0.5 202 | occu_soft[inside] = 0.3 203 | occu_soft[outside] = 0.7 204 | 205 | output = dict() 206 | for d in mpus.keys(): 207 | sdf, flgs = mpus[d] 208 | grad = grads[d] 209 | grad_diff = grad[on_surf] - grad_gt[on_surf] 210 | 211 | sdf_loss = sdf[on_surf].pow(2).mean() * (200 * weights[d]) 212 | norm_loss = grad_diff.pow(2).mean() * (1.0 * weights[d]) 213 | intr_loss = torch.exp(-40 * sdf[off_surf].abs()).mean() * (0.1 * weights[d]) 214 | grad_loss = grad[off_surf].pow(2).mean() * (0.1 * weights[d]) 215 | 216 | loss_o = F.binary_cross_entropy_with_logits(sdf, occu_soft) 217 | accu = (sdf.sigmoid() > 0.5).eq(occu_gt).float() 218 | inside_accu = accu[inside].mean() 219 | outside_accu = accu[outside].mean() 220 | 221 | losses = [sdf_loss, norm_loss, grad_loss, intr_loss, 222 | loss_o, inside_accu, outside_accu] 223 | names = ['sdf_loss', 'norm_loss', 'grad_loss', 'inter_loss', 224 | 'occu_loss', 'inside_accu', 'outside_accu'] 225 | names = [name + ('_%d' % d) for name in names] 226 | loss_dict = dict(zip(names, losses)) 227 | output.update(loss_dict) 228 | 229 | return output 230 | 231 | 232 | def get_sdf_loss_function(loss_type=''): 233 | if loss_type == 'sdf_reg_loss': 234 | return sdf_reg_loss 235 | elif loss_type == 'sdf_grad_loss': 236 | return sdf_grad_loss 237 | elif loss_type == 'possion_grad_loss': 238 | return possion_grad_loss 239 | elif loss_type == 'sdf_grad_reg_loss': 240 | return sdf_grad_regularized_loss 241 | else: 242 | return None 243 | 244 | 245 | def shapenet_loss(batch, model_out, reg_loss_type=''): 246 | # octree loss 247 | output = compute_octree_loss(model_out['logits'], model_out['octree_out']) 248 | 249 | # regression loss 250 | grads = compute_mpu_gradients(model_out['mpus'], batch['pos']) 251 | reg_loss_func = get_sdf_loss_function(reg_loss_type) 252 | sdf_loss = compute_sdf_loss( 253 | model_out['mpus'], grads, batch['sdf'], batch['grad'], reg_loss_func) 254 | output.update(sdf_loss) 255 | return output 256 | 257 | 258 | def 
dfaust_loss(batch, model_out, reg_loss_type=''): 259 | # there is no octree loss 260 | grads = compute_mpu_gradients(model_out['mpus'], batch['pos']) 261 | reg_loss_func = get_sdf_loss_function(reg_loss_type) 262 | output = compute_sdf_loss( 263 | model_out['mpus'], grads, batch['sdf'], batch['grad'], reg_loss_func) 264 | return output 265 | 266 | 267 | def synthetic_room_loss(batch, model_out, reg_loss_type=''): 268 | # octree loss 269 | output = compute_octree_loss(model_out['logits'], model_out['octree_out']) 270 | 271 | # grads 272 | grads = compute_mpu_gradients(model_out['mpus'], batch['pos']) 273 | 274 | # occu loss 275 | occu_loss = compute_occu_loss( 276 | model_out['mpus'], grads, batch['occu'], batch['grad']) 277 | output.update(occu_loss) 278 | 279 | return output 280 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dual Octree Graph Networks 2 | 3 | This repository contains the implementation of our paper *Dual Octree Graph Networks*. The experiments are conducted on Ubuntu 18.04 with 4 V100 GPUs (32GB memory). The code is released under the **MIT license**. 4 | 5 | **[Dual Octree Graph Networks for Learning Adaptive Volumetric Shape Representations](https://arxiv.org/abs/2205.02825)**
6 | [Peng-Shuai Wang](https://wang-ps.github.io/), [Yang Liu](https://xueyuhanlang.github.io/), and [Xin Tong](https://www.microsoft.com/en-us/research/people/xtong/)
7 | ACM Transactions on Graphics (SIGGRAPH), 41(4), 2022 8 | 9 | ![teaser](teaser.png) 10 | 11 | 12 | - [Dual Octree Graph Networks](#dual-octree-graph-networks) 13 | - [1. Installation](#1-installation) 14 | - [2. Shape Reconstruction with ShapeNet](#2-shape-reconstruction-with-shapenet) 15 | - [2.1 Data Preparation](#21-data-preparation) 16 | - [2.2 Experiment](#22-experiment) 17 | - [2.3 Generalization](#23-generalization) 18 | - [3. Synthetic Scene Reconstruction](#3-synthetic-scene-reconstruction) 19 | - [3.1 Data Preparation](#31-data-preparation) 20 | - [3.2 Experiment](#32-experiment) 21 | - [4. Unsupervised Surface Reconstruction with DFaust](#4-unsupervised-surface-reconstruction-with-dfaust) 22 | - [4.1 Data Preparation](#41-data-preparation) 23 | - [4.2 Experiment](#42-experiment) 24 | - [4.3 Generalization](#43-generalization) 25 | - [5. Autoencoder with ShapeNet](#5-autoencoder-with-shapenet) 26 | - [5.1 Data Preparation](#51-data-preparation) 27 | - [5.2 Experiment](#52-experiment) 28 | 29 | 30 | ## 1. Installation 31 | 32 | 1. Install [Conda](https://www.anaconda.com/) and create a `Conda` environment. 33 | ```bash 34 | conda create --name dualocnn python=3.7 35 | conda activate dualocnn 36 | ``` 37 | 38 | 2. Install PyTorch-1.9.1 with conda according to the official documentation. 39 | ```bash 40 | conda install pytorch==1.9.1 torchvision==0.10.1 cudatoolkit=10.2 -c pytorch 41 | ``` 42 | 43 | 3. Install `ocnn-pytorch` from [O-CNN](https://github.com/microsoft/O-CNN). 44 | ```bash 45 | git clone https://github.com/microsoft/O-CNN.git 46 | cd O-CNN/pytorch 47 | pip install -r requirements.txt 48 | python setup.py install --build_octree 49 | ``` 50 | 51 | 4. Clone this repository and install other requirements. 52 | ```bash 53 | git clone https://github.com/microsoft/DualOctreeGNN.git 54 | cd DualOctreeGNN 55 | pip install -r requirements.txt 56 | ``` 57 | 58 | ## 2. Shape Reconstruction with ShapeNet 59 | 60 | ### 2.1 Data Preparation 61 | 62 | 1. Download `ShapeNetCore.v1.zip` (31G) from [ShapeNet](https://shapenet.org/) 63 | and place it into the folder `data/ShapeNet`. 64 | 65 | 2. Convert the meshes in `ShapeNetCore.v1` to signed distance fields (SDFs). 66 | 67 | ```bash 68 | python tools/shapenet.py --run convert_mesh_to_sdf 69 | ``` 70 | 71 | Note that this process is relatively slow; it may take several days to finish 72 | converting all the meshes from ShapeNet. For simplicity, I did not use 73 | Python multiprocessing to speed it up. If speed is a concern, you can 74 | manually execute multiple Python commands simultaneously by specifying the 75 | `start` and `end` indices of the meshes to be processed (see the Python 76 | sketch at the end of this subsection). An example is shown as follows: 77 | 78 | ```bash 79 | python tools/shapenet.py --run convert_mesh_to_sdf --start 10000 --end 20000 80 | ``` 81 | 82 | `ShapeNetCore.v1` contains 57k meshes. After unzipping, the total size is 83 | about 100G, and the total sizes of the generated SDFs and the repaired meshes 84 | are 450G and 90G, respectively. Please make sure your hard disk has enough 85 | space. 86 | 87 | 3. Sample points and ground-truth SDFs for the learning process. 88 | 89 | ```bash 90 | python tools/shapenet.py --run generate_dataset 91 | ``` 92 | 93 | 4. If you just want to run inference with the pretrained network, the test point clouds 94 | (330M) can be downloaded manually from 95 | [here](https://www.dropbox.com/s/us28g6808srcop5/shapenet.test.input.zip?dl=0). 96 | After downloading the zip file, unzip it to the folder 97 | `data/ShapeNet/test.input`. 
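As a reference for step 2 above, the following is a minimal Python sketch of how to launch several `tools/shapenet.py` workers over disjoint `--start`/`--end` ranges. It is not part of the official tools: the worker count of 8 is an assumption to adjust for your machine, and the total of 45572 simply mirrors the script's default `--end` value.

```python
import subprocess

# Split the mesh indices across several worker processes; each worker
# runs `convert_mesh_to_sdf` on its own contiguous index range.
total, num_workers = 45572, 8  # assumptions: default index range, 8 workers
step = (total + num_workers - 1) // num_workers
procs = [
    subprocess.Popen([
        'python', 'tools/shapenet.py', '--run', 'convert_mesh_to_sdf',
        '--start', str(i * step), '--end', str(min((i + 1) * step, total))])
    for i in range(num_workers)
]
for p in procs:
    p.wait()  # block until every worker has finished
```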
98 | 99 | 107 | 108 | ### 2.2 Experiment 109 | 110 | 1. **Train**: Run the following command to train the network on 4 GPUs. The 111 | training takes 17 hours on 4 V100 GPUs. The trained weights and logs can be 112 | downloaded [here](https://www.dropbox.com/s/v3tnopnqxoqqbvb/shapenet_weights_log.zip?dl=1). 113 | 114 | ```bash 115 | python dualocnn.py --config configs/shapenet.yaml SOLVER.gpu 0,1,2,3 116 | ``` 117 | 118 | 2. **Test**: Run the following command to generate the extracted meshes. It is 119 | also possible to specify other trained weights by replacing the parameter 120 | after `SOLVER.ckpt`. 121 | 122 | ```bash 123 | python dualocnn.py --config configs/shapenet_eval.yaml \ 124 | SOLVER.ckpt logs/shapenet/shapenet/checkpoints/00300.model.pth 125 | ``` 126 | 127 | 3. **Evaluate**: We use the code of 128 | [ConvONet](https://github.com/autonomousvision/convolutional_occupancy_networks.git) 129 | to compute the evaluation metrics. Follow the instructions 130 | [here](https://github.com/wang-ps/ConvONet) to reproduce our results in Table 131 | 1. 132 | 133 | ### 2.3 Generalization 134 | 135 | 1. **Test**: Run the following command to test the trained network on 5 unseen 136 | categories of ShapeNet: 137 | 138 | ```bash 139 | python dualocnn.py --config configs/shapenet_unseen5.yaml \ 140 | SOLVER.ckpt logs/shapenet/shapenet/checkpoints/00300.model.pth 141 | ``` 142 | 143 | 2. **Evaluate**: Follow the instructions 144 | [here](https://github.com/wang-ps/ConvONet) to reproduce our results on the 145 | unseen dataset in Table 1. 146 | 147 | ## 3. Synthetic Scene Reconstruction 148 | 149 | ### 3.1 Data Preparation 150 | 151 | Download and unzip the synthetic scene dataset (205G in total) and the data 152 | splitting filelists provided by 153 | [ConvONet](https://github.com/autonomousvision/convolutional_occupancy_networks) 154 | via the following command. 155 | If needed, the ground-truth meshes can be downloaded from 156 | [here](https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/data/room_watertight_mesh.zip) 157 | (90G). 158 | 159 | ```bash 160 | python tools/room.py --run generate_dataset 161 | ``` 162 | 163 | ### 3.2 Experiment 164 | 165 | 1. **Train**: Run the following command to train the network on 4 GPUs. The 166 | training takes 27 hours on 4 V100 GPUs. The trained weights and logs can be 167 | downloaded 168 | [here](https://www.dropbox.com/s/t9p8e8tg9rzeaeq/room_weights_log.zip?dl=1). 169 | 170 | ```bash 171 | python dualocnn.py --config configs/synthetic_room.yaml SOLVER.gpu 0,1,2,3 172 | ``` 173 | 174 | 2. **Test**: Run the following command to generate the extracted meshes. 175 | 176 | ```bash 177 | python dualocnn.py --config configs/synthetic_room_eval.yaml \ 178 | SOLVER.ckpt logs/room/room/checkpoints/00900.model.pth 179 | ``` 180 | 181 | 3. **Evaluate**: Follow the instructions 182 | [here](https://github.com/wang-ps/ConvONet) to reproduce our results in Table 183 | 5. 184 | 185 | 186 | ## 4. Unsupervised Surface Reconstruction with DFaust 187 | 188 | 189 | ### 4.1 Data Preparation 190 | 191 | 1. Download the [DFaust](https://dfaust.is.tue.mpg.de/) dataset, unzip the raw 192 | scans into the folder `data/dfaust/scans`, and unzip the ground-truth meshes 193 | into the folder `data/dfaust/mesh_gt`. Note that the ground-truth meshes are 194 | only used for computing evaluation metrics and NOT used in training. 195 | 196 | 2. Run the following command to prepare the dataset. 
197 | 198 | ```bash 199 | python tools/dfaust.py --run genereate_dataset 200 | ``` 201 | 202 | 3. For convenience, we also provide the dataset for downloading. 203 | 204 | ```bash 205 | python tools/dfaust.py --run download_dataset 206 | ``` 207 | 208 | ### 4.2 Experiment 209 | 210 | 1. **Train**: Run the following command to train the network on 4 GPUs. The 211 | training takes 20 hours on 4 V100 GPUs. The trained weights and logs can be 212 | downloaded 213 | [here](https://www.dropbox.com/s/lyhr9n3b7uhjul8/dfaust_weights_log.zip?dl=0). 214 | 215 | ```bash 216 | python dualocnn.py --config configs/dfaust.yaml SOLVER.gpu 0,1,2,3 217 | ``` 218 | 219 | 220 | 2. **Test**: Run the following command to generate the meshes with the trained 221 | weights. 222 | 223 | ```bash 224 | python dualocnn.py --config configs/dfaust_eval.yaml \ 225 | SOLVER.ckpt logs/dfaust/dfaust/checkpoints/00600.model.pth 226 | ``` 227 | 228 | 3. **Evaluate**: To calculate the evaluation metrics, we first need to rescale the 229 | meshes back to their original sizes, since the point clouds are scaled during the 230 | data processing stage. 231 | 232 | ```bash 233 | python tools/dfaust.py \ 234 | --mesh_folder logs/dfaust_eval/dfaust \ 235 | --output_folder logs/dfaust_eval/dfaust_rescale \ 236 | --run rescale_mesh 237 | ``` 238 | 239 | Then our results in Table 6 can be reproduced in the file `metrics.csv`. 240 | 241 | ```bash 242 | python tools/compute_metrics.py \ 243 | --mesh_folder logs/dfaust_eval/dfaust_rescale \ 244 | --filelist data/dfaust/filelist/test.txt \ 245 | --ref_folder data/dfaust/mesh_gt \ 246 | --filename_out logs/dfaust_eval/dfaust_rescale/metrics.csv 247 | ``` 248 | 249 | 250 | ### 4.3 Generalization 251 | 252 | In Figures 1 and 11 of our paper, we test the generalization ability of our 253 | network on several out-of-distribution point clouds. Please download the point 254 | clouds from [here](https://www.dropbox.com/s/pzo9xajktaml4hh/shapes.zip?dl=1), 255 | and place the unzipped data in the folder `data/shapes`. Then run the following 256 | command to reproduce the results: 257 | 258 | ```bash 259 | python dualocnn.py --config configs/shapes.yaml \ 260 | SOLVER.ckpt logs/dfaust/dfaust/checkpoints/00600.model.pth 261 | ``` 262 | 263 | 264 | ## 5. Autoencoder with ShapeNet 265 | 266 | ### 5.1 Data Preparation 267 | 268 | Follow the instructions [here](#21-data-preparation) to prepare the dataset. 269 | 270 | 271 | ### 5.2 Experiment 272 | 273 | 1. **Train**: Run the following command to train the network on 4 GPUs. The 274 | training takes 24 hours on 4 V100 GPUs. The trained weights and logs can be 275 | downloaded 276 | [here](https://www.dropbox.com/s/3e4bx3zaj0b85kd/shapenet_ae_weights_log.zip?dl=1). 277 | 278 | ```bash 279 | python dualocnn.py --config configs/shapenet_ae.yaml SOLVER.gpu 0,1,2,3 280 | ``` 281 | 282 | 2. **Test**: Run the following command to generate the extracted meshes. 283 | 284 | ```bash 285 | python dualocnn.py --config configs/shapenet_ae_eval.yaml \ 286 | SOLVER.ckpt logs/shapenet/ae/checkpoints/00300.model.pth 287 | ``` 288 | 289 | 3. **Evaluate**: Run the following command to evaluate the predicted meshes. 290 | Then our results in Table 7 can be reproduced in the file `metrics.4096.csv`. 
291 | 292 | ```bash 293 | python tools/compute_metrics.py \ 294 | --mesh_folder logs/shapenet_eval/ae \ 295 | --filelist data/ShapeNet/filelist/test_im.txt \ 296 | --ref_folder data/ShapeNet/mesh \ 297 | --num_samples 4096 \ 298 | --filename_out logs/shapenet_eval/ae/metrics.4096.csv 299 | ``` 300 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import math 9 | import torch 10 | import torch.nn 11 | import torch.nn.init 12 | import torch.nn.functional as F 13 | import torch.utils.checkpoint 14 | # import torch_geometric.nn 15 | 16 | from .utils.scatter import scatter_mean 17 | 18 | bn_momentum, bn_eps = 0.01, 0.001 # the default value of Tensorflow 1.x 19 | # bn_momentum, bn_eps = 0.1, 1e-05 # the default value of pytorch 20 | 21 | 22 | def ckpt_conv_wrapper(conv_op, x, edge_index, edge_type): 23 | def conv_wrapper(x, edge_index, edge_type, dummy_tensor): 24 | return conv_op(x, edge_index, edge_type) 25 | 26 | # The dummy tensor is a workaround when the checkpoint is used for the first conv layer: 27 | # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11 28 | dummy = torch.ones(1, dtype=torch.float32, requires_grad=True) 29 | 30 | return torch.utils.checkpoint.checkpoint( 31 | conv_wrapper, x, edge_index, edge_type, dummy) 32 | 33 | 34 | # class GraphConv_v0(torch_geometric.nn.MessagePassing): 35 | class GraphConv_v0: 36 | ''' This implementation explicitly constructs the self.weights[edge_type], 37 | thus consuming a lot of computation and memory. 
38 | ''' 39 | 40 | def __init__(self, in_channels, out_channels, n_edge_type=7, avg_degree=7): 41 | super().__init__(aggr='add') 42 | self.in_channels = in_channels 43 | self.out_channels = out_channels 44 | self.n_edge_type = n_edge_type 45 | self.avg_degree = avg_degree 46 | 47 | self.weights = torch.nn.Parameter( 48 | torch.Tensor(n_edge_type, out_channels, in_channels)) 49 | self.reset_parameters() 50 | 51 | def reset_parameters(self) -> None: 52 | fan_in = self.avg_degree * self.weights.shape[2] 53 | fan_out = self.avg_degree * self.weights.shape[1] 54 | std = math.sqrt(2.0 / float(fan_in + fan_out)) 55 | a = math.sqrt(3.0) * std 56 | torch.nn.init.uniform_(self.weights, -a, a) 57 | 58 | def forward(self, x, edge_index, edge_type): 59 | # x has shape [N, in_channels] 60 | # edge_index has shape [2, E] 61 | 62 | return self.propagate(edge_index, x=x, edge_type=edge_type) 63 | 64 | def message(self, x_j, edge_type): 65 | weights = self.weights[edge_type] # (N, out_channels, in_channels) 66 | output = weights @ x_j.unsqueeze(-1) 67 | return output.squeeze(-1) 68 | 69 | 70 | class GraphConv(torch.nn.Module): 71 | 72 | def __init__(self, in_channels, out_channels, n_edge_type=7, avg_degree=7, 73 | n_node_type=0): 74 | super().__init__() 75 | self.in_channels = in_channels 76 | self.out_channels = out_channels 77 | self.n_edge_type = n_edge_type 78 | self.avg_degree = avg_degree 79 | self.n_node_type = n_node_type 80 | 81 | node_channel = n_node_type if n_node_type > 1 else 0 82 | self.weights = torch.nn.Parameter( 83 | torch.Tensor(n_edge_type * (in_channels + node_channel), out_channels)) 84 | # if n_node_type > 0: 85 | # self.node_weights = torch.nn.Parameter( 86 | # torch.tensor([0.5 ** i for i in range(n_node_type)])) 87 | self.reset_parameters() 88 | 89 | def reset_parameters(self) -> None: 90 | fan_in = self.avg_degree * self.in_channels 91 | fan_out = self.avg_degree * self.out_channels 92 | std = math.sqrt(2.0 / float(fan_in + fan_out)) 93 | a = math.sqrt(3.0) * std 94 | torch.nn.init.uniform_(self.weights, -a, a) 95 | 96 | def forward(self, x, edge_index, edge_type, node_type=None): 97 | has_node_type = node_type is not None 98 | if has_node_type and self.n_node_type > 1: 99 | # concatenate the one_hot vector 100 | one_hot = F.one_hot(node_type, num_classes=self.n_node_type) 101 | x = torch.cat([x, one_hot], dim=1) 102 | 103 | # x -> col_data 104 | row, col = edge_index[0], edge_index[1] 105 | # weights = torch.pow(0.5, node_type[col]) if has_node_type else None 106 | weights = None # TODO: ablation the weights 107 | index = row * self.n_edge_type + edge_type 108 | col_data = scatter_mean(x[col], index, dim=0, weights=weights, 109 | dim_size=x.shape[0] * self.n_edge_type) 110 | 111 | # matrix product 112 | output = col_data.view(x.shape[0], -1) @ self.weights 113 | return output 114 | 115 | def extra_repr(self) -> str: 116 | return ('channel_in={}, channel_out={}, n_edge_type={}, avg_degree={}, ' 117 | 'n_node_type={}'.format(self.in_channels, self.out_channels, 118 | self.n_edge_type, self.avg_degree, self.n_node_type)) # noqa 119 | 120 | 121 | class GraphConvBn(torch.nn.Module): 122 | 123 | def __init__(self, in_channels, out_channels, n_edge_type=7, avg_degree=7, 124 | n_node_type=0): 125 | super().__init__() 126 | self.conv = GraphConv( 127 | in_channels, out_channels, n_edge_type, avg_degree, n_node_type) 128 | self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum) 129 | 130 | def forward(self, x, edge_index, edge_type, node_type=None): 131 | out = self.conv(x, 
edge_index, edge_type, node_type) 132 | out = self.bn(out) 133 | return out 134 | 135 | 136 | class GraphConvBnRelu(torch.nn.Module): 137 | 138 | def __init__(self, in_channels, out_channels, n_edge_type=7, avg_degree=7, 139 | n_node_type=0): 140 | super().__init__() 141 | self.conv = GraphConv( 142 | in_channels, out_channels, n_edge_type, avg_degree, n_node_type) 143 | self.bn = torch.nn.BatchNorm1d(out_channels, bn_eps, bn_momentum) 144 | self.relu = torch.nn.ReLU(inplace=True) 145 | 146 | def forward(self, x, edge_index, edge_type, node_type=None): 147 | out = self.conv(x, edge_index, edge_type, node_type) 148 | # out = ckpt_conv_wrapper(self.conv, x, edge_index, edge_type) 149 | out = self.bn(out) 150 | out = self.relu(out) 151 | return out 152 | 153 | 154 | class Conv1x1(torch.nn.Module): 155 | 156 | def __init__(self, channel_in, channel_out, use_bias=False): 157 | super().__init__() 158 | self.linear = torch.nn.Linear(channel_in, channel_out, use_bias) 159 | 160 | def forward(self, x): 161 | return self.linear(x) 162 | 163 | 164 | class Conv1x1Bn(torch.nn.Module): 165 | 166 | def __init__(self, channel_in, channel_out): 167 | super().__init__() 168 | self.conv = Conv1x1(channel_in, channel_out, use_bias=False) 169 | self.bn = torch.nn.BatchNorm1d(channel_out, bn_eps, bn_momentum) 170 | 171 | def forward(self, x): 172 | out = self.conv(x) 173 | out = self.bn(out) 174 | return out 175 | 176 | 177 | class Conv1x1BnRelu(torch.nn.Module): 178 | 179 | def __init__(self, channel_in, channel_out): 180 | super().__init__() 181 | self.conv = Conv1x1(channel_in, channel_out, use_bias=False) 182 | self.bn = torch.nn.BatchNorm1d(channel_out, bn_eps, bn_momentum) 183 | self.relu = torch.nn.ReLU(inplace=True) 184 | 185 | def forward(self, x): 186 | out = self.conv(x) 187 | out = self.bn(out) 188 | out = self.relu(out) 189 | return out 190 | 191 | 192 | class Upsample(torch.nn.Module): 193 | 194 | def __init__(self, channels): 195 | super().__init__() 196 | self.channels = channels 197 | 198 | self.weights = torch.nn.Parameter( 199 | torch.Tensor(channels, channels, 8)) 200 | torch.nn.init.xavier_uniform_(self.weights) 201 | 202 | def forward(self, x): 203 | out = x @ self.weights.flatten(1) 204 | out = out.view(-1, self.channels) 205 | return out 206 | 207 | def extra_repr(self): 208 | return 'channels={}'.format(self.channels) 209 | 210 | 211 | class Downsample(torch.nn.Module): 212 | 213 | def __init__(self, channels): 214 | super().__init__() 215 | self.channels = channels 216 | 217 | self.weights = torch.nn.Parameter( 218 | torch.Tensor(channels, channels, 8)) 219 | torch.nn.init.xavier_uniform_(self.weights) 220 | 221 | def forward(self, x): 222 | weights = self.weights.flatten(1).t() 223 | out = x.view(-1, self.channels * 8) @ weights 224 | return out 225 | 226 | def extra_repr(self): 227 | return 'channels={}'.format(self.channels) 228 | 229 | 230 | class GraphDownsample(torch.nn.Module): 231 | 232 | def __init__(self, channels_in, channels_out=None): 233 | super().__init__() 234 | self.channels_in = channels_in 235 | self.channels_out = channels_out or channels_in 236 | self.downsample = Downsample(channels_in) 237 | if self.channels_in != self.channels_out: 238 | self.conv1x1 = Conv1x1BnRelu(self.channels_in, self.channels_out) 239 | 240 | def forward(self, x, leaf_mask, numd, lnumd): 241 | # downsample nodes at layer depth 242 | outd = x[-numd:] 243 | outd = self.downsample(outd) 244 | 245 | # get the nodes at layer (depth-1) 246 | out = torch.zeros(leaf_mask.shape[0], x.shape[1], 
device=x.device) 247 | out[leaf_mask] = x[-lnumd-numd:-numd] 248 | out[leaf_mask.logical_not()] = outd 249 | 250 | # construct the final output 251 | out = torch.cat([x[:-numd-lnumd], out], dim=0) 252 | 253 | if self.channels_in != self.channels_out: 254 | out = self.conv1x1(out) 255 | return out 256 | 257 | def extra_repr(self): 258 | return 'channels_in={}, channels_out={}'.format( 259 | self.channels_in, self.channels_out) 260 | 261 | 262 | class GraphMaxpool(torch.nn.Module): 263 | 264 | def __init__(self): 265 | super().__init__() 266 | 267 | def forward(self, x, leaf_mask, numd, lnumd): 268 | # downsample nodes at layer depth 269 | channel = x.shape[1] 270 | outd = x[-numd:] 271 | outd, _ = outd.view(-1, 8, channel).max(dim=1) 272 | 273 | # get the nodes at layer (depth-1) 274 | out = torch.zeros(leaf_mask.shape[0], channel, device=x.device) 275 | out[leaf_mask] = x[-lnumd-numd:-numd] 276 | out[leaf_mask.logical_not()] = outd 277 | 278 | # construct the final output 279 | out = torch.cat([x[:-numd-lnumd], out], dim=0) 280 | return out 281 | 282 | 283 | class GraphUpsample(torch.nn.Module): 284 | 285 | def __init__(self, channels_in, channels_out=None): 286 | super().__init__() 287 | self.channels_in = channels_in 288 | self.channels_out = channels_out or channels_in 289 | self.upsample = Upsample(channels_in) 290 | if self.channels_in != self.channels_out: 291 | self.conv1x1 = Conv1x1BnRelu(self.channels_in, self.channels_out) 292 | 293 | def forward(self, x, leaf_mask, numd): 294 | # upsample nodes at layer (depth-1) 295 | outd = x[-numd:] 296 | out1 = outd[leaf_mask.logical_not()] 297 | out1 = self.upsample(out1) 298 | 299 | # construct the final output 300 | out = torch.cat([x[:-numd], outd[leaf_mask], out1], dim=0) 301 | if self.channels_in != self.channels_out: 302 | out = self.conv1x1(out) 303 | return out 304 | 305 | def extra_repr(self): 306 | return 'channels_in={}, channels_out={}'.format( 307 | self.channels_in, self.channels_out) 308 | 309 | 310 | class GraphResBlock2(torch.nn.Module): 311 | 312 | def __init__(self, channel_in, channel_out, bottleneck=1, n_edge_type=7, 313 | avg_degree=7, n_node_type=0): 314 | super().__init__() 315 | self.channel_in = channel_in 316 | self.channel_out = channel_out 317 | self.bottleneck = bottleneck 318 | channel_m = int(channel_out / bottleneck) 319 | 320 | self.conva = GraphConvBnRelu( 321 | channel_in, channel_m, n_edge_type, avg_degree, n_node_type) 322 | self.convb = GraphConvBn( 323 | channel_m, channel_out, n_edge_type, avg_degree, n_node_type) 324 | if self.channel_in != self.channel_out: 325 | self.conv1x1 = Conv1x1Bn(channel_in, channel_out) 326 | self.relu = torch.nn.ReLU(inplace=True) 327 | 328 | def forward(self, x, edge_index, edge_type, node_type): 329 | x1 = self.conva(x, edge_index, edge_type, node_type) 330 | x2 = self.convb(x1, edge_index, edge_type, node_type) 331 | 332 | if self.channel_in != self.channel_out: 333 | x = self.conv1x1(x) 334 | 335 | out = self.relu(x2 + x) 336 | return out 337 | 338 | 339 | class GraphResBlock(torch.nn.Module): 340 | 341 | def __init__(self, channel_in, channel_out, bottleneck=4, n_edge_type=7, 342 | avg_degree=7, n_node_type=0): 343 | super().__init__() 344 | self.channel_in = channel_in 345 | self.channel_out = channel_out 346 | self.bottleneck = bottleneck 347 | channel_m = int(channel_out / bottleneck) 348 | 349 | self.conv1x1a = Conv1x1BnRelu(channel_in, channel_m) 350 | self.conv = GraphConvBnRelu( 351 | channel_m, channel_m, n_edge_type, avg_degree, n_node_type) 352 | 
self.conv1x1b = Conv1x1Bn(channel_m, channel_out) 353 | if self.channel_in != self.channel_out: 354 | self.conv1x1c = Conv1x1Bn(channel_in, channel_out) 355 | self.relu = torch.nn.ReLU(inplace=True) 356 | 357 | def forward(self, x, edge_index, edge_type, node_type): 358 | x1 = self.conv1x1a(x) 359 | x2 = self.conv(x1, edge_index, edge_type, node_type) 360 | x3 = self.conv1x1b(x2) 361 | 362 | if self.channel_in != self.channel_out: 363 | x = self.conv1x1c(x) 364 | 365 | out = self.relu(x3 + x) 366 | return out 367 | 368 | 369 | class GraphResBlocks(torch.nn.Module): 370 | 371 | def __init__(self, channel_in, channel_out, resblk_num, bottleneck=4, 372 | n_edge_type=7, avg_degree=7, n_node_type=0, 373 | resblk_type='bottleneck'): 374 | super().__init__() 375 | self.resblk_num = resblk_num 376 | channels = [channel_in] + [channel_out] * resblk_num 377 | ResBlk = self._get_resblock(resblk_type) 378 | self.resblks = torch.nn.ModuleList([ 379 | ResBlk(channels[i], channels[i+1], bottleneck, 380 | n_edge_type, avg_degree, n_node_type) 381 | for i in range(self.resblk_num)]) 382 | 383 | def _get_resblock(self, resblk_type): 384 | if resblk_type == 'bottleneck': 385 | return GraphResBlock 386 | elif resblk_type == 'basic': 387 | return GraphResBlock2 388 | else: 389 | raise ValueError 390 | 391 | def forward(self, data, edge_index, edge_type, node_type): 392 | for i in range(self.resblk_num): 393 | data = self.resblks[i](data, edge_index, edge_type, node_type) 394 | return data 395 | 396 | 397 | def doctree_align(value, key, query): 398 | # mark out-of-bound queries so they are never found 399 | out_of_bound = query > key[-1] 400 | query[out_of_bound] = -1 401 | 402 | # search 403 | idx = torch.searchsorted(key, query) 404 | found = key[idx] == query 405 | 406 | # assign the found value to the output 407 | out = torch.zeros(query.shape[0], value.shape[1], device=value.device) 408 | out[found] = value[idx[found]] 409 | return out 410 | -------------------------------------------------------------------------------- /tools/shapenet.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Dual Octree Graph Networks 3 | # Copyright (c) 2022 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Peng-Shuai Wang 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import time 10 | import wget 11 | import shutil 12 | import torch 13 | import ocnn 14 | import trimesh 15 | import logging 16 | import mesh2sdf 17 | import zipfile 18 | import argparse 19 | import numpy as np 20 | from tqdm import tqdm 21 | from plyfile import PlyData, PlyElement 22 | 23 | logger = logging.getLogger("trimesh") 24 | logger.setLevel(logging.ERROR) 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--run', type=str, required=True) 28 | parser.add_argument('--start', type=int, default=0) 29 | parser.add_argument('--end', type=int, default=45572) 30 | args = parser.parse_args() 31 | 32 | size = 128 # resolution of SDF 33 | level = 0.015 # approximately 2/128 = 0.015625 34 | shape_scale = 0.5 # rescale the shape into [-0.5, 0.5] 35 | project_folder = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 36 | root_folder = os.path.join(project_folder, 'data/ShapeNet') 37 | 38 | 39 | def create_flag_file(filename): 40 | r''' Creates a flag file to indicate whether some time-consuming work 41 | has been done. 
42 | ''' 43 | 44 | folder = os.path.dirname(filename) 45 | if not os.path.exists(folder): 46 | os.makedirs(folder) 47 | with open(filename, 'w') as fid: 48 | fid.write('succ @ ' + time.ctime()) 49 | 50 | 51 | def check_folder(filenames: list): 52 | r''' Creates the folders containing the given filenames if they do not exist. 53 | ''' 54 | 55 | for filename in filenames: 56 | folder = os.path.dirname(filename) 57 | if not os.path.exists(folder): 58 | os.makedirs(folder) 59 | 60 | 61 | def get_filenames(filelist): 62 | r''' Gets filenames from a filelist. 63 | ''' 64 | 65 | filelist = os.path.join(root_folder, 'filelist', filelist) 66 | with open(filelist, 'r') as fid: 67 | lines = fid.readlines() 68 | filenames = [line.split()[0] for line in lines] 69 | return filenames 70 | 71 | 72 | def unzip_shapenet(): 73 | r''' Unzips the ShapeNetCore.v1 archive. 74 | ''' 75 | 76 | filename = os.path.join(root_folder, 'ShapeNetCore.v1.zip') 77 | flag_file = os.path.join(root_folder, 'flags/unzip_shapenet_succ') 78 | if not os.path.exists(flag_file): 79 | print('-> Unzip ShapeNetCore.v1.zip.') 80 | with zipfile.ZipFile(filename, 'r') as zip_ref: 81 | zip_ref.extractall(root_folder) 82 | create_flag_file(flag_file) 83 | 84 | folder = os.path.join(root_folder, 'ShapeNetCore.v1') 85 | flag_file = os.path.join(root_folder, 'flags/unzip_shapenet_all_succ') 86 | if not os.path.exists(flag_file): 87 | print('-> Unzip all zip files in ShapeNetCore.v1.') 88 | filenames = os.listdir(folder) 89 | for filename in filenames: 90 | if filename.endswith('.zip'): 91 | print('- Unzip %s' % filename) 92 | zipname = os.path.join(folder, filename) 93 | with zipfile.ZipFile(zipname, 'r') as zip_ref: 94 | zip_ref.extractall(folder) 95 | os.remove(zipname) 96 | create_flag_file(flag_file) 97 | 98 | 99 | def download_filelist(): 100 | r''' Downloads the filelists used for learning. 101 | ''' 102 | 103 | flag_file = os.path.join(root_folder, 'flags/download_filelist_succ') 104 | if not os.path.exists(flag_file): 105 | print('-> Download the filelist.') 106 | url = 'https://www.dropbox.com/s/4jvam486l8961t7/shapenet.filelist.zip?dl=1' 107 | filename = os.path.join(root_folder, 'filelist.zip') 108 | wget.download(url, filename, bar=None) 109 | 110 | folder = os.path.join(root_folder, 'filelist') 111 | with zipfile.ZipFile(filename, 'r') as zip_ref: 112 | zip_ref.extractall(path=folder) 113 | os.remove(filename) 114 | create_flag_file(flag_file) 115 | 116 | 117 | def run_mesh2sdf(): 118 | r''' Converts the meshes from ShapeNet to SDFs and manifold meshes. 
119 | ''' 120 | 121 | print('-> Run mesh2sdf.') 122 | mesh_scale = 0.8 123 | filenames = get_filenames('all.txt') 124 | for i in tqdm(range(args.start, args.end), ncols=80): 125 | filename = filenames[i] 126 | filename_raw = os.path.join( 127 | root_folder, 'ShapeNetCore.v1', filename, 'model.obj') 128 | filename_obj = os.path.join(root_folder, 'mesh', filename + '.obj') 129 | filename_box = os.path.join(root_folder, 'bbox', filename + '.npz') 130 | filename_npy = os.path.join(root_folder, 'sdf', filename + '.npy') 131 | check_folder([filename_obj, filename_box, filename_npy]) 132 | if os.path.exists(filename_obj): continue 133 | 134 | # load the raw mesh 135 | mesh = trimesh.load(filename_raw, force='mesh') 136 | 137 | # rescale mesh to [-1, 1] for mesh2sdf, note the factor **mesh_scale** 138 | vertices = mesh.vertices 139 | bbmin, bbmax = vertices.min(0), vertices.max(0) 140 | center = (bbmin + bbmax) * 0.5 141 | scale = 2.0 * mesh_scale / (bbmax - bbmin).max() 142 | vertices = (vertices - center) * scale 143 | 144 | # run mesh2sdf 145 | sdf, mesh_new = mesh2sdf.compute(vertices, mesh.faces, size, fix=True, 146 | level=level, return_mesh=True) 147 | mesh_new.vertices = mesh_new.vertices * shape_scale 148 | 149 | # save 150 | np.savez(filename_box, bbmax=bbmax, bbmin=bbmin, mul=mesh_scale) 151 | np.save(filename_npy, sdf) 152 | mesh_new.export(filename_obj) 153 | 154 | 155 | def sample_pts_from_mesh(): 156 | r''' Samples 40k points with normals from the ground-truth meshes. 157 | ''' 158 | 159 | print('-> Run sample_pts_from_mesh.') 160 | num_samples = 40000 161 | mesh_folder = os.path.join(root_folder, 'mesh') 162 | output_folder = os.path.join(root_folder, 'dataset') 163 | filenames = get_filenames('all.txt') 164 | for i in tqdm(range(args.start, args.end), ncols=80): 165 | filename = filenames[i] 166 | filename_obj = os.path.join(mesh_folder, filename + '.obj') 167 | filename_pts = os.path.join(output_folder, filename, 'pointcloud.npz') 168 | check_folder([filename_pts]) 169 | if os.path.exists(filename_pts): continue 170 | 171 | # sample points 172 | mesh = trimesh.load(filename_obj, force='mesh') 173 | points, idx = trimesh.sample.sample_surface(mesh, num_samples) 174 | normals = mesh.face_normals[idx] 175 | 176 | # save points 177 | np.savez(filename_pts, points=points.astype(np.float16), 178 | normals=normals.astype(np.float16)) 179 | 180 | 181 | def sample_sdf(): 182 | r''' Samples ground-truth SDF values for training. 
183 | ''' 184 | 185 | # constants 186 | depth, full_depth = 6, 4 187 | sample_num = 4 # number of samples in each octree node 188 | grid = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], 189 | [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]]) 190 | 191 | print('-> Sample SDFs from the ground truth.') 192 | filenames = get_filenames('all.txt') 193 | for i in tqdm(range(args.start, args.end), ncols=80): 194 | filename = filenames[i] 195 | dataset_folder = os.path.join(root_folder, 'dataset') 196 | filename_sdf = os.path.join(root_folder, 'sdf', filename + '.npy') 197 | filename_pts = os.path.join(dataset_folder, filename, 'pointcloud.npz') 198 | filename_out = os.path.join(dataset_folder, filename, 'sdf.npz') 199 | if os.path.exists(filename_out): continue 200 | 201 | # load data 202 | pts = np.load(filename_pts) 203 | sdf = np.load(filename_sdf) 204 | sdf = torch.from_numpy(sdf) 205 | points = pts['points'].astype(np.float32) 206 | normals = pts['normals'].astype(np.float32) 207 | points = points / shape_scale # rescale points to [-1, 1] 208 | 209 | # build octree 210 | points = ocnn.points_new( 211 | torch.from_numpy(points), torch.from_numpy(normals), 212 | torch.Tensor(), torch.Tensor()) 213 | octree2points = ocnn.Points2Octree(depth=depth, full_depth=full_depth) 214 | octree = octree2points(points) 215 | 216 | # sample points and grads according to the xyz 217 | xyzs, grads, sdfs = [], [], [] 218 | for d in range(full_depth, depth + 1): 219 | xyz = ocnn.octree_property(octree, 'xyz', d) 220 | xyz = ocnn.octree_decode_key(xyz) 221 | 222 | # sample k points in each octree node 223 | xyz = xyz[:, :3].float() # + 0.5 -> octree node center 224 | xyz = xyz.unsqueeze(1) + torch.rand(xyz.shape[0], sample_num, 3) 225 | xyz = xyz.view(-1, 3) # (N, 3) 226 | xyz = xyz * (size / 2 ** d) # rescale to [0, size], the SDF resolution 227 | xyz = xyz[(xyz < 127).all(dim=1)] # remove out-of-bound points 228 | xyzs.append(xyz) 229 | 230 | # interpolate the sdf values 231 | xyzi = torch.floor(xyz) # the integer part (N, 3) 232 | corners = xyzi.unsqueeze(1) + grid # (N, 8, 3) 233 | coordsf = xyz.unsqueeze(1) - corners # (N, 8, 3), in [-1.0, 1.0] 234 | weights = (1 - coordsf.abs()).prod(dim=-1) # (N, 8) 235 | corners = corners.long().view(-1, 3) 236 | x, y, z = corners[:, 0], corners[:, 1], corners[:, 2] 237 | s = sdf[x, y, z].view(-1, 8) 238 | sw = torch.sum(s * weights, dim=1) 239 | sdfs.append(sw) 240 | 241 | # calc the gradient 242 | gx = s[:, 4] - s[:, 0] + s[:, 5] - s[:, 1] + \ 243 | s[:, 6] - s[:, 2] + s[:, 7] - s[:, 3] # noqa 244 | gy = s[:, 2] - s[:, 0] + s[:, 3] - s[:, 1] + \ 245 | s[:, 6] - s[:, 4] + s[:, 7] - s[:, 5] # noqa 246 | gz = s[:, 1] - s[:, 0] + s[:, 3] - s[:, 2] + \ 247 | s[:, 5] - s[:, 4] + s[:, 7] - s[:, 6] # noqa 248 | grad = torch.stack([gx, gy, gz], dim=-1) 249 | norm = torch.sqrt(torch.sum(grad ** 2, dim=-1, keepdims=True)) 250 | grad = grad / (norm + 1.0e-8) 251 | grads.append(grad) 252 | 253 | # concat the results 254 | xyzs = torch.cat(xyzs, dim=0).numpy() 255 | points = (xyzs / 64 - 1).astype(np.float16) * shape_scale # !shape_scale 256 | grads = torch.cat(grads, dim=0).numpy().astype(np.float16) 257 | sdfs = torch.cat(sdfs, dim=0).numpy().astype(np.float16) 258 | 259 | # save results 260 | # points = (points * args.scale).astype(np.float16) # in [-scale, scale] 261 | np.savez(filename_out, points=points, grad=grads, sdf=sdfs) 262 | 263 | 264 | def sample_occu(): 265 | r''' Samples occupancy values for evaluating the IoU following ConvONet. 
266 | ''' 267 | 268 | num_samples = 100000 269 | grid = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], 270 | [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]]) 271 | 272 | # filenames = get_filenames('all.txt') 273 | filenames = get_filenames('test.txt') + get_filenames('test_unseen5.txt') 274 | for filename in tqdm(filenames, ncols=80): 275 | filename_sdf = os.path.join(root_folder, 'sdf', filename + '.npy') 276 | filename_occu = os.path.join(root_folder, 'dataset', filename, 'points') 277 | if os.path.exists(filename_occu + '.npz') or (not os.path.exists(filename_sdf)): # np.savez appends '.npz' 278 | continue 279 | 280 | sdf = np.load(filename_sdf) 281 | factor = 127.0 / 128.0 # make sure the interpolation is well defined 282 | points_uniform = np.random.rand(num_samples, 3) * factor # in [0, 1) 283 | points = (points_uniform - 0.5) * (2 * shape_scale) # !!! rescale 284 | points = points.astype(np.float16) 285 | 286 | # interpolate the sdf values 287 | xyz = points_uniform * 128 # in [0, 127) 288 | xyzi = np.floor(xyz) # the integer part (N, 3) 289 | corners = np.expand_dims(xyzi, 1) + grid # (N, 8, 3) 290 | coordsf = np.expand_dims(xyz, 1) - corners # (N, 8, 3), in [-1.0, 1.0] 291 | weights = np.prod(1 - np.abs(coordsf), axis=-1) # (N, 8) 292 | 293 | corners = np.reshape(corners.astype(np.int64), (-1, 3)) 294 | x, y, z = corners[:, 0], corners[:, 1], corners[:, 2] 295 | values = np.reshape(sdf[x, y, z], (-1, 8)) 296 | value = np.sum(values * weights, axis=1) 297 | occu = value < 0 298 | occu = np.packbits(occu) 299 | 300 | # save 301 | np.savez(filename_occu, points=points, occupancies=occu) 302 | 303 | 304 | def generate_test_points(): 305 | r''' Generates points in `ply` format for testing. 306 | ''' 307 | 308 | noise_std = 0.005 309 | point_sample_num = 3000 310 | # filenames = get_filenames('all.txt') 311 | filenames = get_filenames('test.txt') + get_filenames('test_unseen5.txt') 312 | for filename in tqdm(filenames, ncols=80): 313 | filename_pts = os.path.join( 314 | root_folder, 'dataset', filename, 'pointcloud.npz') 315 | filename_ply = os.path.join( 316 | root_folder, 'test.input', filename + '.ply') 317 | if not os.path.exists(filename_pts): continue 318 | check_folder([filename_ply]) 319 | 320 | # sample points 321 | pts = np.load(filename_pts) 322 | points = pts['points'].astype(np.float32) 323 | noise = noise_std * np.random.randn(point_sample_num, 3) 324 | rand_idx = np.random.choice(points.shape[0], size=point_sample_num) 325 | points_noise = points[rand_idx] + noise 326 | 327 | # save ply 328 | vertices = [] 329 | py_types = (float, float, float) 330 | npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')] 331 | for idx in range(points_noise.shape[0]): 332 | vertices.append( 333 | tuple(dtype(d) for dtype, d in zip(py_types, points_noise[idx]))) 334 | structured_array = np.array(vertices, dtype=npy_types) 335 | el = PlyElement.describe(structured_array, 'vertex') 336 | PlyData([el]).write(filename_ply) 337 | 338 | 339 | def download_dataset(): 340 | r''' Directly downloads the dataset. 
342 |   '''
343 |
344 |   flag_file = os.path.join(root_folder, 'flags/download_dataset_succ')
345 |   if not os.path.exists(flag_file):
346 |     print('-> Download the dataset.')
347 |     url = 'https://www.dropbox.com/s/mc3lrwqpmnfq3j8/shapenet.dataset.zip?dl=1'
348 |     filename = os.path.join(root_folder, 'shapenet.dataset.zip')
349 |     wget.download(url, filename)
350 |
351 |     with zipfile.ZipFile(filename, 'r') as zip_ref:
352 |       zip_ref.extractall(path=root_folder)
353 |     # os.remove(filename)
354 |     create_flag_file(flag_file)
355 |
356 |
357 | def generate_dataset_unseen5():
358 |   r''' Creates the unseen5 dataset of 5 categories not seen during training.
359 |   '''
360 |
361 |   dataset_folder = os.path.join(root_folder, 'dataset')
362 |   unseen5_folder = os.path.join(root_folder, 'dataset.unseen5')
363 |   if not os.path.exists(unseen5_folder):
364 |     os.makedirs(unseen5_folder)
365 |   for folder in ['02808440', '02773838', '02818832', '02876657', '03938244']:
366 |     curr_folder = os.path.join(dataset_folder, folder)
367 |     if os.path.exists(curr_folder):
368 |       shutil.move(os.path.join(dataset_folder, folder), unseen5_folder)
369 |
370 |
371 | def copy_convonet_filelists():
372 |   r''' Copies the filelists of ConvONet to the datasets; they are needed
373 |   when calculating the evaluation metrics.
374 |   '''
375 |
376 |   with open(os.path.join(root_folder, 'filelist/lists.txt'), 'r') as fid:
377 |     lines = fid.readlines()
378 |   filenames = [line.split()[0] for line in lines]
379 |   filelist_folder = os.path.join(root_folder, 'filelist')
380 |   for filename in filenames:
381 |     src_name = os.path.join(filelist_folder, filename)
382 |     des_name = src_name.replace('filelist/convonet.filelist', 'dataset') \
383 |                        .replace('filelist/unseen5.filelist', 'dataset.unseen5')
384 |     if not os.path.exists(des_name):
385 |       shutil.copy(src_name, des_name)
386 |
387 |
388 | def convert_mesh_to_sdf():
389 |   unzip_shapenet()
390 |   download_filelist()
391 |   run_mesh2sdf()
392 |
393 |
394 | def generate_dataset():
395 |   sample_pts_from_mesh()
396 |   sample_sdf()
397 |   sample_occu()
398 |   generate_test_points()
399 |   generate_dataset_unseen5()
400 |   copy_convonet_filelists()
401 |
402 |
403 | if __name__ == '__main__':
404 |   globals()[args.run]()  # dispatch to the function named by `--run`
405 |
--------------------------------------------------------------------------------
/models/dual_octree.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # Dual Octree Graph Networks
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Peng-Shuai Wang
 6 | # --------------------------------------------------------
 7 |
 8 | import torch
 9 | import ocnn
10 | # import torch_sparse
11 | import numpy as np
12 |
13 |
14 | class DualOctree:
15 |
16 |   def __init__(self, octree):
17 |     # the primal octree
18 |     self.octree = octree
19 |     self.device = octree.device
20 |     self.depth = ocnn.octree_property(octree, 'depth').item()
21 |     self.full_depth = ocnn.octree_property(octree, 'full_depth').item()
22 |     self.batch_size = ocnn.octree_property(octree, 'batch_size').item()
23 |
24 |     # node numbers
25 |     self.nnum = ocnn.octree_property(octree, 'node_num')
26 |     self.nenum = ocnn.octree_property(octree, 'node_num_ne')
27 |     self.ncum = ocnn.octree_property(octree, 'node_num_cum')
28 |     self.lnum = self.nnum - self.nenum  # leaf node numbers
29 |
30 |     # node properties
31 |     xyzi = ocnn.octree_property(octree, 'xyz')
32 |     self.xyzi = ocnn.octree_decode_key(xyzi)
33 |     self.xyz = self.xyzi[:, :3]
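   | # column 3 of the decoded key holds the batch index of each node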
34 |     self.batch = self.xyzi[:, 3]
35 |     self.node_depth = self._node_depth()
36 |     self.child = ocnn.octree_property(octree, 'child')
37 |     self.key = ocnn.octree_property(octree, 'key')
38 |     self.keyd = self.key | (self.node_depth << 58)
39 |
40 |     # build lookup tables
41 |     self._lookup_table()
42 |
43 |     # build dual graph
44 |     self._graph = [dict() for _ in range(self.depth + 1)]  # the internal graph
45 |     self.graph = [dict() for _ in range(self.depth + 1)]   # the output graph
46 |     self.build_dual_graph()
47 |
48 |   def _lookup_table(self):
49 |     self.ngh = torch.tensor(
50 |         [[0, 0, 1], [0, 0, -1],   # up, down
51 |          [0, 1, 0], [0, -1, 0],   # right, left
52 |          [1, 0, 0], [-1, 0, 0]],  # front, back
53 |         dtype=torch.int16, device=self.device)
54 |     self.dir_table = torch.tensor(
55 |         [[1, 3, 5, 7], [0, 2, 4, 6],   # up, down
56 |          [2, 3, 6, 7], [0, 1, 4, 5],   # right, left
57 |          [4, 5, 6, 7], [0, 1, 2, 3]],  # front, back
58 |         dtype=torch.int64, device=self.device)
59 |     self.dir_type = torch.tensor(
60 |         [0, 1, 2, 3, 4, 5],
61 |         dtype=torch.int64, device=self.device)
62 |     self.remap = torch.tensor(
63 |         [1, 0, 3, 2, 5, 4],
64 |         dtype=torch.int64, device=self.device)
65 |     self.inter_row = torch.tensor(
66 |         [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3,
67 |          4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7],
68 |         dtype=torch.int64, device=self.device)
69 |     self.inter_col = torch.tensor(
70 |         [1, 2, 4, 0, 3, 5, 0, 3, 6, 1, 2, 7,
71 |          0, 5, 6, 1, 4, 7, 2, 4, 7, 3, 5, 6],
72 |         dtype=torch.int64, device=self.device)
73 |     self.inter_dir = torch.tensor(
74 |         [0, 2, 4, 1, 2, 4, 3, 0, 4, 3, 1, 4,
75 |          5, 0, 2, 5, 1, 2, 5, 3, 0, 5, 3, 1],
76 |         dtype=torch.int64, device=self.device)
77 |
78 |   def _node_depth(self):
79 |     nd = [torch.tensor([d], dtype=torch.int64, device=self.device).expand(
80 |         self.nnum[d]) for d in range(self.depth + 1)]
81 |     return torch.cat(nd)
82 |
83 |   def build_dual_graph(self):
84 |     self._graph[self.full_depth] = self.dense_graph(self.full_depth)
85 |     for d in range(self.full_depth + 1, self.depth + 1):
86 |       self._graph[d] = self.sparse_graph(d, self._graph[d - 1])
87 |
88 |   def dense_graph(self, depth=3):
89 |     K = 6  # each node has at most K neighboring nodes
90 |     bnd = 2 ** depth
91 |     num = bnd ** 3
92 |
93 |     ki = torch.arange(0, num, dtype=torch.int64, device=self.device)
94 |     xi = ocnn.octree_key2xyz(ki, depth)
95 |     xi = ocnn.octree_decode_key(xi)[:, :3]
96 |     xj = xi.unsqueeze(1) + self.ngh  # [N, K, 3]
97 |
98 |     row = ki.unsqueeze(1).repeat(1, K).view(-1)
99 |     zj = torch.zeros(num, K, 1, dtype=torch.int16, device=self.device)
100 |     kj = torch.cat([xj, zj], dim=-1).view(-1, 4)
101 |     kj = ocnn.octree_encode_key(kj)
102 |     # for a full octree, the octree key is the node index
103 |     col = ocnn.octree_xyz2key(kj, depth)
104 |
105 |     valid = torch.logical_and(xj > -1, xj < bnd)  # keep in-bound neighbors only
106 |     valid = torch.all(valid, dim=-1).view(-1)
107 |     row, col = row[valid], col[valid]
108 |
109 |     edge_dir = self.dir_type.repeat(num)
110 |     edge_dir = edge_dir[valid]
111 |
112 |     # deal with batches
113 |     dis = torch.arange(self.batch_size, dtype=torch.int64, device=self.device)
114 |     dis = dis.unsqueeze(1) * num + self.ncum[depth]  # NOTE: add self.ncum[depth]
115 |     row = row.unsqueeze(0) + dis
116 |     col = col.unsqueeze(0) + dis
117 |     edge_dir = edge_dir.unsqueeze(0) + torch.zeros_like(dis)
118 |     # rowptr = torch.ops.torch_sparse.ind2ptr(row, num)
119 |     return {'edge_idx': torch.stack([row.view(-1), col.view(-1)]),
120 |             'edge_dir': edge_dir.view(-1)}
121 |
122 |   def _internal_edges(self, nnum, dis=0):
123 |     assert nnum % 8 == 0
124 |     d = torch.arange(0, nnum // 8, dtype=torch.int64, device=self.device)
125 |     d = torch.unsqueeze(d * 8 + dis, dim=1)
126 |     row = self.inter_row.unsqueeze(0) + d
127 |     col = self.inter_col.unsqueeze(0) + d
128 |     edge_dir = self.inter_dir.unsqueeze(0) + torch.zeros_like(d)
129 |     return row.view(-1), col.view(-1), edge_dir.view(-1)
130 |
131 |   def relative_dir(self, vi, vj, depth, rescale=True):
132 |     xi = self.xyz[vi]
133 |     xj = self.xyz[vj]
134 |
135 |     # get the 6 neighbors of xi via `self.ngh`
136 |     xn = xi.unsqueeze(1) + self.ngh
137 |
138 |     # rescale the coord of xj
139 |     scale = torch.ones_like(vj)
140 |     if rescale:
141 |       dj = self.node_depth[vj]
142 |       scale = torch.pow(2.0, depth - dj)
143 |       # torch._assert((scale > 1.0).all(), 'vj is larger than vi')
144 |       xj = xj * scale.unsqueeze(-1)
145 |
146 |     # inbox testing
147 |     xj = xj.unsqueeze(1)
148 |     scale = scale.view(-1, 1, 1)
149 |     inbox = torch.logical_and(xn >= xj, xn < xj + scale)
150 |     inbox = torch.all(inbox, dim=-1)
151 |     rel_dir = torch.argmax(inbox.byte(), dim=-1)
152 |     return rel_dir
153 |
154 |   def _node_property(self, prop, depth):
155 |     return prop[self.ncum[depth]: self.ncum[depth] + self.nnum[depth]]
156 |
157 |   def node_child(self, depth):
158 |     return self._node_property(self.child, depth)
159 |
160 |   def sparse_graph(self, depth, graph):
161 |     # Add internal edges connecting sibling nodes.
162 |     ncum_d = self.ncum[depth]  # NOTE: add ncum_d, i.e., self.ncum[depth]
163 |     row_i, col_i, dir_i = self._internal_edges(self.nnum[depth], ncum_d)
164 |
165 |     # mark subdivided (non-leaf) nodes of layer (depth - 1) as invalid
166 |     edge_idx, edge_dir = graph['edge_idx'], graph['edge_dir']
167 |     row, col = edge_idx[0], edge_idx[1]
168 |     valid_row = self.child[row] < 0
169 |     valid_col = self.child[col] < 0
170 |     invalid_row = torch.logical_not(valid_row)
171 |     invalid_col = torch.logical_not(valid_col)
172 |     valid_edges = torch.logical_and(valid_row, valid_col)
173 |     invalid_row_vtx = torch.logical_and(invalid_row, valid_col)
174 |     invalid_both_vtx = torch.logical_and(invalid_row, invalid_col)
175 |
176 |     # deal with edges whose row vtx only is invalid
177 |     vi, vj = row[invalid_row_vtx], col[invalid_row_vtx]
178 |     rel_dir = self.relative_dir(vi, vj, depth - 1)
179 |     row_o1 = self.child[vi].unsqueeze(1) * 8 + self.dir_table[rel_dir, :]
180 |     row_o1 = row_o1.view(-1) + ncum_d  # NOTE: add ncum_d
181 |     col_o1 = vj.unsqueeze(1).repeat(1, 4).view(-1)
182 |     dir_o1 = rel_dir.unsqueeze(1).repeat(1, 4).view(-1)
183 |
184 |     # deal with edges whose two end nodes are both invalid
185 |     row_o2 = torch.tensor([], dtype=torch.int64, device=self.device)
186 |     col_o2 = torch.tensor([], dtype=torch.int64, device=self.device)
187 |     dir_o2 = torch.tensor([], dtype=torch.int64, device=self.device)
188 |     if invalid_both_vtx.any():
189 |       vi, vj = row[invalid_both_vtx], col[invalid_both_vtx]
190 |       rel_dir = self.relative_dir(vi, vj, depth - 1, rescale=False)
191 |       row_o2 = self.child[vi].unsqueeze(1) * 8 + self.dir_table[rel_dir, :]
192 |       row_o2 = row_o2.view(-1) + ncum_d  # NOTE: add ncum_d
193 |       dir_o2 = rel_dir.unsqueeze(1).repeat(1, 4).view(-1)
194 |       rel_dir_col = self.remap[rel_dir]
195 |       col_o2 = self.child[vj].unsqueeze(1) * 8 + self.dir_table[rel_dir_col, :]
196 |       col_o2 = col_o2.view(-1) + ncum_d  # NOTE: add ncum_d
197 |
198 |     # gather the results
199 |     edge_idx = torch.stack([
200 |         torch.cat([row[valid_edges], row_i, row_o1, col_o1, row_o2]),
201 |         torch.cat([col[valid_edges], col_i, col_o1, row_o1, col_o2])])
202 |     edge_dir = torch.cat([
203 |         edge_dir[valid_edges], dir_i, dir_o1, self.remap[dir_o1], dir_o2])
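   | # Note: `row_o1`/`col_o1` are concatenated in both orders above, so each
   | # cross-level edge is stored bidirectionally with a remapped direction.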
204 |     return {'edge_idx': edge_idx, 'edge_dir': edge_dir}
205 |
206 |   def add_self_loops(self):
207 |     for d in range(self.full_depth, self.depth + 1):
208 |       edge_idx = self._graph[d]['edge_idx']
209 |       edge_dir = self._graph[d]['edge_dir']
210 |       row, col = edge_idx[0], edge_idx[1]
211 |       unique_idx = torch.unique(row, sorted=True)
212 |       dir_idx = torch.ones_like(unique_idx) * 6  # direction 6 marks self-loops
213 |       self.graph[d] = {'edge_idx': torch.stack([torch.cat([row, unique_idx]),
214 |                                                  torch.cat([col, unique_idx])]),
215 |                        'edge_dir': torch.cat([edge_dir, dir_idx])}
216 |
217 |   def calc_edge_type(self):
218 |     dir_num = 7
219 |     for d in range(self.full_depth, self.depth + 1):
220 |       depth_num = d - self.full_depth + 1
221 |       edge_idx = self._graph[d]['edge_idx']
222 |       edge_dir = self._graph[d]['edge_dir']
223 |       row, col = edge_idx[0], edge_idx[1]
224 |
225 |       dr = self.node_depth[row] - self.full_depth
226 |       dc = self.node_depth[col] - self.full_depth
227 |       edge_type = (dr * depth_num + dc) * dir_num + edge_dir
228 |
229 |       self.graph[d]['edge_type'] = edge_type
230 |
231 |   def remap_node_idx(self):
232 |     leaf_nodes = self.child < 0
233 |     for d in range(self.full_depth, self.depth + 1):
234 |       leaf_d = torch.ones(self.nnum[d], dtype=torch.bool, device=self.device)
235 |       mask = torch.cat([leaf_nodes[:self.ncum[d]], leaf_d], dim=0)
236 |       remap = torch.cumsum(mask.long(), dim=0) - 1
237 |       self.graph[d]['edge_idx'] = remap[self.graph[d]['edge_idx']]
238 |
239 |   def filter_multiple_level_edges(self):
240 |     for d in range(self.full_depth, self.depth + 1):
241 |       edge_idx = self.graph[d]['edge_idx']
242 |       edge_dir = self.graph[d]['edge_dir']
243 |       valid_edges = (self.node_depth[edge_idx] == d).all(dim=0)
244 |
245 |       # filter edges
246 |       edge_idx = edge_idx[:, valid_edges]
247 |       edge_dir = edge_dir[valid_edges]
248 |
249 |       self.graph[d] = {'edge_idx': edge_idx, 'edge_dir': edge_dir}
250 |
251 |   def filter_coarse_to_fine_edges(self):
252 |     for d in range(self.full_depth, self.depth + 1):
253 |       edge_idx = self.graph[d]['edge_idx']
254 |       edge_dir = self.graph[d]['edge_dir']
255 |
256 |       edge_node_depth = self.node_depth[edge_idx]
257 |       # the depth of sender nodes should be no smaller than the receivers'
258 |       valid_edges = edge_node_depth[0] >= edge_node_depth[1]
259 |
260 |       # filter edges
261 |       edge_idx = edge_idx[:, valid_edges]
262 |       edge_dir = edge_dir[valid_edges]
263 |
264 |       self.graph[d] = {'edge_idx': edge_idx, 'edge_dir': edge_dir}
265 |
266 |   def filter_crosslevel_edges(self):
267 |     for d in range(self.full_depth, self.depth + 1):
268 |       edge_idx = self.graph[d]['edge_idx']
269 |       edge_dir = self.graph[d]['edge_dir']
270 |
271 |       edge_node_depth = self.node_depth[edge_idx]
272 |       valid_edges = edge_node_depth[0] == edge_node_depth[1]
273 |
274 |       # filter edges
275 |       edge_idx = edge_idx[:, valid_edges]
276 |       edge_dir = edge_dir[valid_edges]
277 |
278 |       self.graph[d] = {'edge_idx': edge_idx, 'edge_dir': edge_dir}
279 |
280 |   def displace_edge_and_add_node_type(self):
281 |     for d in range(self.full_depth, self.depth + 1):
282 |       # displace the edge index
283 |       self.graph[d]['edge_idx'] -= self.ncum[d]
284 |
285 |       # only one type of node
286 |       zeros = torch.zeros(self.nnum[d], dtype=torch.long, device=self.device)
287 |       self.graph[d]['node_type'] = zeros
288 |
289 |       # used in skip connections
290 |       self.graph[d]['keyd'] = self._node_property(self.keyd, d)
291 |
292 |   def post_processing_for_ocnn(self):
293 |     self.add_self_loops()
294 |     self.filter_multiple_level_edges()
295 |     self.displace_edge_and_add_node_type()
296 |     self.sort_edges()
297 |
298 |   def sort_edges(self):
299 |     dir_num = 7
300 |     for d in range(self.full_depth, self.depth + 1):
301 |       edge_idx = self.graph[d]['edge_idx']
302 |       edge_dir = self.graph[d]['edge_dir']
303 |
304 |       edge_key = edge_idx[0] * dir_num + edge_dir
305 |       sidx = torch.argsort(edge_key)
306 |       self.graph[d]['edge_idx'] = edge_idx[:, sidx]
307 |       self.graph[d]['edge_dir'] = edge_dir[sidx]
308 |
309 |   def get_input_feature(self, all_leaf_nodes=True):
310 |     # the initial feature of the leaf nodes in the layer self.depth
311 |     data = ocnn.octree_property(self.octree, 'feature', self.depth)
312 |     data = data.squeeze(0).squeeze(-1).t()
313 |
314 |     # the initial feature of the leaf nodes in the other layers
315 |     if all_leaf_nodes:
316 |       channel = data.shape[1]
317 |       leaf_num = self.lnum[self.full_depth:self.depth].sum()
318 |       zeros = torch.zeros(leaf_num, channel, device=self.device)
319 |
320 |       # concat zero features with the initial features in layer depth
321 |       data = torch.cat([zeros, data], dim=0)
322 |
323 |     return data
324 |
325 |   def add_node_keyd(self):
326 |     keyd1, keyd2 = [], []
327 |     for d in range(self.full_depth, self.depth + 1):
328 |       keyd = self._node_property(self.keyd, d)
329 |       leaf_mask = self._node_property(self.child, d) < 0
330 |       keyd1.append(keyd[leaf_mask])
331 |       keyd2.append(keyd)
332 |       self.graph[d]['keyd'] = torch.cat(keyd1[:-1] + keyd2[-1:], dim=0)
333 |
334 |   def add_node_xyzd(self):
335 |     xyz1, xyz2 = [], []
336 |     for d in range(self.full_depth, self.depth + 1):
337 |       xyzd = self._node_property(self.xyz, d)
338 |       xyzf = xyzd.float() / 2 ** d  # normalize to [0, 1]
339 |       leaf_mask = self._node_property(self.child, d) < 0
340 |       xyz1.append(xyzf[leaf_mask])
341 |       xyz2.append(xyzf)
342 |       self.graph[d]['xyz'] = torch.cat(xyz1[:-1] + xyz2[-1:], dim=0)
343 |
344 |   def add_node_type(self):
345 |     ntype1, ntype2 = [], []
346 |     full_depth, depth = self.full_depth, self.depth
347 |     for i, d in enumerate(range(full_depth, depth + 1)):
348 |       ntype = d - full_depth
349 |       ntype1.append(torch.ones(self.lnum[d], device=self.device) * ntype)
350 |       ntype2.append(torch.ones(self.nnum[d], device=self.device) * ntype)
351 |       node_type = torch.cat(ntype1[:i] + ntype2[i:i + 1], dim=0).long()
352 |       self.graph[d]['node_type'] = node_type
353 |
354 |   def add_node_mask(self):
355 |     leaf_masks = []
356 |     full_depth, depth = self.full_depth, self.depth
357 |     for i, d in enumerate(range(full_depth, depth + 1)):
358 |       mask1 = self._node_property(self.child, d) < 0
359 |       mask2 = torch.ones(self.nnum[d], dtype=torch.bool, device=self.device)
360 |       leaf_masks.append(mask1)
361 |       self.graph[d]['node_mask'] = torch.cat(leaf_masks[:i] + [mask2], dim=0)
362 |
363 |   def post_processing_for_docnn(self):
364 |     self.add_self_loops()
365 |     # The following two functions are only used in the ablation study.
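   | # filter_coarse_to_fine_edges() drops messages sent from coarse to fine
   | # nodes, and filter_crosslevel_edges() keeps only same-level edges.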
366 |     # self.filter_coarse_to_fine_edges()
367 |     # self.filter_crosslevel_edges()
368 |     self.remap_node_idx()
369 |     self.add_node_type()
370 |     self.add_node_keyd()  # used in skip connections
371 |     self.add_node_mask()
372 |     self.sort_edges()
373 |
374 |   def save(self, filename):
375 |     np.save(filename + 'xyz.npy', self.xyz.cpu().numpy())
376 |     np.save(filename + 'batch.npy', self.batch.cpu().numpy())
377 |     np.save(filename + 'node_depth.npy', self.node_depth.cpu().numpy())
378 |     for d in range(self.full_depth, self.depth + 1):
379 |       edge_idx = self._graph[d]['edge_idx']
380 |       np.save(filename + 'edge_%d.npy' % d, edge_idx.t().cpu().numpy())
381 |
382 |
383 | if __name__ == '__main__':
384 |   octrees = ocnn.octree_samples(['octree_1', 'octree_2'])
385 |   octree = ocnn.octree_batch(octrees).cuda()
386 |   pdoctree = DualOctree(octree)
387 |   pdoctree.save('data/batch_12_')
388 |   pdoctree.add_self_loops()
389 |   pdoctree.calc_edge_type()
390 |   pdoctree.remap_node_idx()
391 |   print('succ!')
392 |
--------------------------------------------------------------------------------
/solver/solver.py:
--------------------------------------------------------------------------------
 1 | # --------------------------------------------------------
 2 | # Dual Octree Graph Networks
 3 | # Copyright (c) 2022 Microsoft
 4 | # Licensed under The MIT License [see LICENSE for details]
 5 | # Written by Peng-Shuai Wang
 6 | # --------------------------------------------------------
 7 |
 8 | import os
 9 | import time
10 | import torch
11 | import torch.nn
12 | import torch.optim
13 | import torch.distributed
14 | import torch.multiprocessing
15 | import torch.utils.data
16 | import warnings
17 | from datetime import datetime
18 | from tqdm import tqdm
19 | from torch.utils.tensorboard import SummaryWriter
20 |
21 | from .sampler import InfSampler, DistributedInfSampler
22 | from .config import parse_args
23 |
24 | warnings.filterwarnings("ignore", module="torch.optim.lr_scheduler")
25 |
26 |
27 | class AverageTracker:
28 |
29 |   def __init__(self):
30 |     self.value = None
31 |     self.num = 0.0
32 |     self.max_len = 76
33 |     self.start_time = time.time()
34 |
35 |   def update(self, value):
36 |     if not value:
37 |       return  # empty input, return
38 |
39 |     value = {key: val.detach() for key, val in value.items()}
40 |     if self.value is None:
41 |       self.value = value
42 |     else:
43 |       for key, val in value.items():
44 |         self.value[key] += val
45 |     self.num += 1
46 |
47 |   def average(self):
48 |     return {key: val.item() / self.num for key, val in self.value.items()}
49 |
50 |   @torch.no_grad()
51 |   def average_all_gather(self):
52 |     for key, tensor in self.value.items():
53 |       tensors_gather = [torch.ones_like(tensor)
54 |                         for _ in range(torch.distributed.get_world_size())]
55 |       torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
56 |       tensors = torch.stack(tensors_gather, dim=0)
57 |       self.value[key] = torch.mean(tensors)
58 |
59 |   def log(self, epoch, summary_writer=None, log_file=None, msg_tag='->',
60 |           notes='', print_time=True):
61 |     if not self.value:
62 |       return  # empty, return
63 |
64 |     avg = self.average()
65 |     msg = 'Epoch: %d' % epoch
66 |     for key, val in avg.items():
67 |       msg += ', %s: %.3f' % (key, val)
68 |       if summary_writer:
69 |         summary_writer.add_scalar(key, val, epoch)
70 |
71 |     # if the log_file is provided, save the log
72 |     if log_file:
73 |       with open(log_file, 'a') as fid:
74 |         fid.write(msg + '\n')
75 |
76 |     # memory
77 |     memory = ''
78 |     if torch.cuda.is_available():
79 |       # size = torch.cuda.memory_allocated()
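   | # memory_reserved() also counts the cache held by PyTorch's allocator,
   | # so it is an upper bound of memory_allocated()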
80 |       size = torch.cuda.memory_reserved(device=None)
81 |       memory = ', memory: {:.3f}GB'.format(size / 2**30)
82 |
83 |     # time
84 |     time_str = ''
85 |     if print_time:
86 |       curr_time = ', time: ' + datetime.now().strftime("%Y/%m/%d %H:%M:%S")
87 |       duration = ', duration: {:.2f}s'.format(time.time() - self.start_time)
88 |       time_str = curr_time + duration
89 |
90 |     # other notes
91 |     if notes:
92 |       notes = ', ' + notes
93 |
94 |     msg += memory + time_str + notes
95 |
96 |     # split the msg for better display
97 |     chunks = [msg[i:i+self.max_len] for i in range(0, len(msg), self.max_len)]
98 |     msg = (msg_tag + ' ') + ('\n' + len(msg_tag) * ' ' + ' ').join(chunks)
99 |     tqdm.write(msg)
100 |
101 |
102 | class Solver:
103 |
104 |   def __init__(self, FLAGS, is_master=True):
105 |     self.FLAGS = FLAGS
106 |     self.is_master = is_master
107 |     self.world_size = len(FLAGS.SOLVER.gpu)
108 |     self.device = torch.cuda.current_device()
109 |     self.disable_tqdm = not (is_master and FLAGS.SOLVER.progress_bar)
110 |     self.start_epoch = 1
111 |
112 |     self.model = None           # torch.nn.Module
113 |     self.optimizer = None       # torch.optim.Optimizer
114 |     self.scheduler = None       # torch.optim.lr_scheduler._LRScheduler
115 |     self.summary_writer = None  # torch.utils.tensorboard.SummaryWriter
116 |     self.log_file = None        # str, used to save training logs
117 |
118 |   def get_model(self, flags):
119 |     raise NotImplementedError
120 |
121 |   def get_dataset(self, flags):
122 |     raise NotImplementedError
123 |
124 |   def train_step(self, batch):
125 |     raise NotImplementedError
126 |
127 |   def test_step(self, batch):
128 |     raise NotImplementedError
129 |
130 |   def eval_step(self, batch):
131 |     raise NotImplementedError
132 |
133 |   def result_callback(self, avg_tracker: AverageTracker, epoch):
134 |     pass  # additional operations based on the avg_tracker
135 |
136 |   def config_dataloader(self, disable_train_data=False):
137 |     flags_train, flags_test = self.FLAGS.DATA.train, self.FLAGS.DATA.test
138 |
139 |     if not disable_train_data and not flags_train.disable:
140 |       self.train_loader = self.get_dataloader(flags_train)
141 |       self.train_iter = iter(self.train_loader)
142 |
143 |     if not flags_test.disable:
144 |       self.test_loader = self.get_dataloader(flags_test)
145 |       self.test_iter = iter(self.test_loader)
146 |
147 |   def get_dataloader(self, flags):
148 |     dataset, collate_fn = self.get_dataset(flags)
149 |
150 |     if self.world_size > 1:
151 |       sampler = DistributedInfSampler(dataset, shuffle=flags.shuffle)
152 |     else:
153 |       sampler = InfSampler(dataset, shuffle=flags.shuffle)
154 |
155 |     data_loader = torch.utils.data.DataLoader(
156 |         dataset, batch_size=flags.batch_size, num_workers=flags.num_workers,
157 |         sampler=sampler, collate_fn=collate_fn, pin_memory=True)
158 |     return data_loader
159 |
160 |   def config_model(self):
161 |     flags = self.FLAGS.MODEL
162 |     model = self.get_model(flags)
163 |     model.cuda(device=self.device)
164 |     if self.world_size > 1:
165 |       if flags.sync_bn:
166 |         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
167 |       model = torch.nn.parallel.DistributedDataParallel(
168 |           module=model, device_ids=[self.device],
169 |           output_device=self.device, broadcast_buffers=False,
170 |           find_unused_parameters=flags.find_unused_parameters)
171 |     if self.is_master:
172 |       print(model)
173 |     self.model = model
174 |
175 |   def configure_optimizer(self):
176 |     flags = self.FLAGS.SOLVER
177 |     # The learning rate scales linearly with the world_size.
178 |     lr = flags.lr * self.world_size
179 |     if flags.type.lower() == 'sgd':
180 |       self.optimizer = torch.optim.SGD(
181 |           self.model.parameters(), lr=lr, weight_decay=flags.weight_decay,
182 |           momentum=0.9)
183 |     elif flags.type.lower() == 'adam':
184 |       self.optimizer = torch.optim.Adam(
185 |           self.model.parameters(), lr=lr, weight_decay=flags.weight_decay)
186 |     elif flags.type.lower() == 'adamw':
187 |       self.optimizer = torch.optim.AdamW(
188 |           self.model.parameters(), lr=lr, weight_decay=flags.weight_decay)
189 |     else:
190 |       raise ValueError('Unsupported optimizer type: ' + flags.type)
191 |
192 |     if flags.lr_type == 'step':
193 |       self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
194 |           self.optimizer, milestones=flags.step_size, gamma=0.1)
195 |     elif flags.lr_type == 'cos':
196 |       self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
197 |           self.optimizer, flags.max_epoch, eta_min=0)
198 |     elif flags.lr_type == 'poly':
199 |       def poly(epoch): return (1 - epoch / flags.max_epoch) ** flags.lr_power
200 |       self.scheduler = torch.optim.lr_scheduler.LambdaLR(
201 |           self.optimizer, poly)
202 |     elif flags.lr_type == 'constant':
203 |       self.scheduler = torch.optim.lr_scheduler.LambdaLR(
204 |           self.optimizer, lambda epoch: 1)
205 |     else:
206 |       raise ValueError('Unsupported lr_type: ' + flags.lr_type)
207 |
208 |   def configure_log(self, set_writer=True):
209 |     self.logdir = self.FLAGS.SOLVER.logdir
210 |     self.ckpt_dir = os.path.join(self.logdir, 'checkpoints')
211 |     self.log_file = os.path.join(self.logdir, 'log.csv')
212 |
213 |     if self.is_master:
214 |       tqdm.write('Logdir: ' + self.logdir)
215 |
216 |     if self.is_master and set_writer:
217 |       self.summary_writer = SummaryWriter(self.logdir, flush_secs=20)
218 |       if not os.path.exists(self.ckpt_dir):
219 |         os.makedirs(self.ckpt_dir)
220 |
221 |   def train_epoch(self, epoch):
222 |     self.model.train()
223 |     if self.world_size > 1:
224 |       self.train_loader.sampler.set_epoch(epoch)
225 |
226 |     train_tracker = AverageTracker()
227 |     rng = range(len(self.train_loader))
228 |     log_per_iter = self.FLAGS.SOLVER.log_per_iter
229 |     for it in tqdm(rng, ncols=80, leave=False, disable=self.disable_tqdm):
230 |       self.optimizer.zero_grad()
231 |
232 |       # forward
233 |       batch = next(self.train_iter)
234 |       batch['iter_num'] = it
235 |       batch['epoch'] = epoch
236 |       output = self.train_step(batch)
237 |
238 |       # backward
239 |       output['train/loss'].backward()
240 |       self.optimizer.step()
241 |
242 |       # track the averaged tensors
243 |       train_tracker.update(output)
244 |
245 |       # output intermediate logs
246 |       if self.is_master and log_per_iter > 0 and it % log_per_iter == 0:
247 |         notes = 'iter: %d' % it
248 |         train_tracker.log(epoch, msg_tag='- ', notes=notes, print_time=False)
249 |
250 |     # save logs
251 |     if self.world_size > 1:
252 |       train_tracker.average_all_gather()
253 |     if self.is_master:
254 |       train_tracker.log(epoch, self.summary_writer)
255 |
256 |   def test_epoch(self, epoch):
257 |     self.model.eval()
258 |     test_tracker = AverageTracker()
259 |     rng = range(len(self.test_loader))
260 |     for it in tqdm(rng, ncols=80, leave=False, disable=self.disable_tqdm):
261 |       # forward
262 |       batch = next(self.test_iter)
263 |       batch['iter_num'] = it
264 |       batch['epoch'] = epoch
265 |       # with torch.no_grad():
266 |       output = self.test_step(batch)
267 |
268 |       # track the averaged tensors
269 |       test_tracker.update(output)
270 |
271 |     if self.world_size > 1:
272 |       test_tracker.average_all_gather()
273 |     if self.is_master:
274 |       test_tracker.log(epoch, self.summary_writer, self.log_file, msg_tag='=>')
275 |       self.result_callback(test_tracker, epoch)
276 |
277 |   def eval_epoch(self, epoch):
278 |     self.model.eval()
279 |     eval_step = min(self.FLAGS.SOLVER.eval_step, len(self.test_loader))
280 |     if eval_step < 1:
281 |       eval_step = len(self.test_loader)
282 |     for it in tqdm(range(eval_step), ncols=80, leave=False):
283 |       batch = next(self.test_iter)
284 |       batch['iter_num'] = it
285 |       batch['epoch'] = epoch
286 |       with torch.no_grad():
287 |         self.eval_step(batch)
288 |
289 |   def save_checkpoint(self, epoch):
290 |     if not self.is_master:
291 |       return
292 |
293 |     # clean up
294 |     ckpts = sorted(os.listdir(self.ckpt_dir))
295 |     ckpts = [ck for ck in ckpts if ck.endswith('.pth') or ck.endswith('.tar')]
296 |     if len(ckpts) > self.FLAGS.SOLVER.ckpt_num:
297 |       for ckpt in ckpts[:-self.FLAGS.SOLVER.ckpt_num]:
298 |         os.remove(os.path.join(self.ckpt_dir, ckpt))
299 |
300 |     # save ckpt
301 |     model_dict = self.model.module.state_dict() \
302 |         if self.world_size > 1 else self.model.state_dict()
303 |     ckpt_name = os.path.join(self.ckpt_dir, '%05d' % epoch)
304 |     torch.save(model_dict, ckpt_name + '.model.pth')
305 |     torch.save({'model_dict': model_dict, 'epoch': epoch,
306 |                 'optimizer_dict': self.optimizer.state_dict(),
307 |                 'scheduler_dict': self.scheduler.state_dict()},
308 |                ckpt_name + '.solver.tar')
309 |
310 |   def load_checkpoint(self):
311 |     ckpt = self.FLAGS.SOLVER.ckpt
312 |     if not ckpt:
313 |       # If ckpt is empty, then get the latest checkpoint from ckpt_dir
314 |       if not os.path.exists(self.ckpt_dir):
315 |         return
316 |       ckpts = sorted(os.listdir(self.ckpt_dir))
317 |       ckpts = [ck for ck in ckpts if ck.endswith('solver.tar')]
318 |       if len(ckpts) > 0:
319 |         ckpt = os.path.join(self.ckpt_dir, ckpts[-1])
320 |     if not ckpt:
321 |       return  # return if ckpt is still empty
322 |
323 |     # load the trained model
324 |     # check: map_location = {'cuda:0' : 'cuda:%d' % self.rank}
325 |     trained_dict = torch.load(ckpt, map_location='cuda')
326 |     if ckpt.endswith('.solver.tar'):
327 |       model_dict = trained_dict['model_dict']
328 |       self.start_epoch = trained_dict['epoch'] + 1  # resume from the next epoch
329 |       if self.optimizer:
330 |         self.optimizer.load_state_dict(trained_dict['optimizer_dict'])
331 |       if self.scheduler:
332 |         self.scheduler.load_state_dict(trained_dict['scheduler_dict'])
333 |     else:
334 |       model_dict = trained_dict
335 |     model = self.model.module if self.world_size > 1 else self.model
336 |     model.load_state_dict(model_dict)
337 |
338 |     # print messages
339 |     if self.is_master:
340 |       tqdm.write('Load the checkpoint: %s' % ckpt)
341 |       tqdm.write('The start_epoch is %d' % self.start_epoch)
342 |
343 |   def train(self):
344 |     self.config_model()
345 |     self.config_dataloader()
346 |     self.configure_optimizer()
347 |     self.configure_log()
348 |     self.load_checkpoint()
349 |
350 |     rng = range(self.start_epoch, self.FLAGS.SOLVER.max_epoch + 1)
351 |     for epoch in tqdm(rng, ncols=80, disable=self.disable_tqdm):
352 |       # training epoch
353 |       self.train_epoch(epoch)
354 |
355 |       # update the learning rate
356 |       self.scheduler.step()
357 |       if self.is_master:
358 |         lr = self.scheduler.get_last_lr()  # lr is a list
359 |         self.summary_writer.add_scalar('train/lr', lr[0], epoch)
360 |
361 |       # testing or not
362 |       if epoch % self.FLAGS.SOLVER.test_every_epoch != 0:
363 |         continue
364 |
365 |       # testing epoch
366 |       self.test_epoch(epoch)
367 |
368 |       # checkpoint
369 |       self.save_checkpoint(epoch)
370 |
371 |     # sync and exit
372 |     if self.world_size > 1:
373 |       torch.distributed.barrier()
374 |
375 |   def test(self):
376 |     self.config_model()
377 |     self.configure_log(set_writer=False)
378 |     self.config_dataloader(disable_train_data=True)
379 |     self.load_checkpoint()
380 |     self.test_epoch(epoch=0)
381 |
382 |   def evaluate(self):
383 |     self.config_model()
384 |     self.configure_log(set_writer=False)
385 |     self.config_dataloader(disable_train_data=True)
386 |     self.load_checkpoint()
387 |     for epoch in tqdm(range(self.FLAGS.SOLVER.eval_epoch), ncols=80):
388 |       self.eval_epoch(epoch)
389 |
390 |   def profile(self):
391 |     ''' Set `DATA.train.num_workers` to 0 when using this function. '''
392 |     self.config_model()
393 |     self.config_dataloader()
394 |
395 |     # warm up
396 |     batch = next(iter(self.train_loader))
397 |     for _ in range(3):
398 |       output = self.train_step(batch)
399 |       output['train/loss'].backward()
400 |
401 |     # profile
402 |     with torch.autograd.profiler.profile(
403 |         use_cuda=True, profile_memory=True,
404 |         with_stack=True, record_shapes=True) as prof:
405 |       output = self.train_step(batch)
406 |       output['train/loss'].backward()
407 |
408 |     json = os.path.join(self.FLAGS.SOLVER.logdir, 'trace.json')
409 |     print('Save the profile into: ' + json)
410 |     prof.export_chrome_trace(json)
411 |     print(prof.key_averages(group_by_stack_n=10)
412 |               .table(sort_by="cuda_time_total", row_limit=10))
413 |     print(prof.key_averages(group_by_stack_n=10)
414 |               .table(sort_by="cuda_memory_usage", row_limit=10))
415 |
416 |   def run(self):
417 |     getattr(self, self.FLAGS.SOLVER.run)()  # dispatch to train/test/etc.
418 |
419 |   @classmethod
420 |   def update_configs(cls):
421 |     pass
422 |
423 |   @classmethod
424 |   def worker(cls, gpu, FLAGS):
425 |     world_size = len(FLAGS.SOLVER.gpu)
426 |     if world_size > 1:
427 |       # Set the GPU to use.
428 |       torch.cuda.set_device(gpu)
429 |       # Initialize the process group. Currently, the code only supports the
430 |       # `single node + multiple GPU` mode, so the rank is equal to the gpu id.
431 |       torch.distributed.init_process_group(
432 |           backend='nccl', init_method=FLAGS.SOLVER.dist_url,
433 |           world_size=world_size, rank=gpu)
434 |       # The master process is responsible for logging, writing, and loading
435 |       # checkpoints. In the multi-GPU setting, we assign the master role to the
436 |       # rank 0 process.
437 |       is_master = gpu == 0
438 |       the_solver = cls(FLAGS, is_master)
439 |     else:
440 |       the_solver = cls(FLAGS, is_master=True)
441 |     the_solver.run()
442 |
443 |   @classmethod
444 |   def main(cls):
445 |     cls.update_configs()
446 |     FLAGS = parse_args()
447 |
448 |     num_gpus = len(FLAGS.SOLVER.gpu)
449 |     if num_gpus > 1:
450 |       torch.multiprocessing.spawn(cls.worker, nprocs=num_gpus, args=(FLAGS,))
451 |     else:
452 |       cls.worker(0, FLAGS)
453 |
--------------------------------------------------------------------------------
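A usage sketch (not part of the repository): the `Solver` class above is abstract, and a project hooks into it by overriding the factory and step methods. The snippet below is a minimal, hypothetical example of such a subclass; the synthetic dataset, the linear stand-in model, and the collate function are placeholders, and only the `train/loss` key is required by `train_epoch`.

import torch

from solver import Solver


class SynthDataset(torch.utils.data.Dataset):
  ''' A tiny synthetic dataset standing in for the octree-based datasets. '''

  def __len__(self):
    return 32

  def __getitem__(self, idx):
    return {'feature': torch.randn(8), 'label': torch.rand(1)}


def collate_fn(samples):
  # stack the per-sample dicts into batched tensors
  return {key: torch.stack([s[key] for s in samples]) for key in samples[0]}


class MySolver(Solver):

  def get_model(self, flags):
    return torch.nn.Linear(8, 1)  # placeholder for a graph network

  def get_dataset(self, flags):
    return SynthDataset(), collate_fn  # the (dataset, collate_fn) pair

  def train_step(self, batch):
    pred = self.model(batch['feature'].cuda())
    loss = torch.nn.functional.mse_loss(pred, batch['label'].cuda())
    return {'train/loss': loss}  # `train/loss` is what train_epoch backprops

  def test_step(self, batch):
    with torch.no_grad():
      pred = self.model(batch['feature'].cuda())
      loss = torch.nn.functional.mse_loss(pred, batch['label'].cuda())
    return {'test/loss': loss}


if __name__ == '__main__':
  MySolver.main()  # parses the yaml config; SOLVER.run selects train/test/...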