├── models
    ├── utils
    │   ├── __init__.py
    │   ├── spmm.py
    │   └── scatter.py
    ├── __init__.py
    ├── graph_resnet.py
    ├── graph_unet.py
    ├── octree_ounet.py
    ├── graph_lenet.py
    ├── graph_ae.py
    ├── graph_slounet.py
    ├── mpu.py
    ├── graph_ounet.py
    ├── modules.py
    └── dual_octree.py
├── teaser.png
├── .gitmodules
├── .gitignore
├── requirements.txt
├── solver
    ├── README.md
    ├── __init__.py
    ├── sampler.py
    ├── dataset.py
    ├── config.py
    └── solver.py
├── losses
    ├── __init__.py
    └── loss.py
├── CODE_OF_CONDUCT.md
├── datasets
    ├── __init__.py
    ├── utils.py
    ├── pointcloud_eval.py
    ├── synthetic_room.py
    ├── shapenet.py
    └── pointcloud.py
├── configs
    ├── shapes.yaml
    ├── dfaust_eval.yaml
    ├── shapenet_eval.yaml
    ├── shapenet_ae_eval.yaml
    ├── shapenet_unseen5.yaml
    ├── synthetic_room_eval.yaml
    ├── finetune.yaml
    ├── cls_m40.yaml
    ├── dfaust.yaml
    ├── shapenet_ae.yaml
    ├── shapenet.yaml
    └── synthetic_room.yaml
├── LICENSE
├── SUPPORT.md
├── SECURITY.md
├── tools
    ├── compute_metrics.py
    ├── dfaust.py
    ├── room.py
    └── shapenet.py
├── dualocnn.py
├── utils.py
└── README.md
/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/DualOctreeGNN/HEAD/teaser.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "solver"]
2 | path = solver
3 | url = https://github.com/wang-ps/solver-pytorch.git
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 | dist
3 | __pycache__
4 | *.egg-info/
5 | .vscode/
6 | *.pyd
7 | *.so
8 | *.octree
9 | *.points
10 | *.xyz
11 | *.json
12 | debug.log
13 | data
14 | logs
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | tqdm
4 | yacs
5 | scipy
6 | plyfile
7 | tensorboard
8 | scikit-image
9 | trimesh
10 | wget
11 | mesh2sdf
12 | setuptools==59.5.0
13 | matplotlib
14 |
--------------------------------------------------------------------------------
/solver/README.md:
--------------------------------------------------------------------------------
1 | # The Solver for PyTorch
2 |
3 | This repository contains part of the code from the official implementation of
4 | [O-CNN](https://github.com/microsoft/O-CNN.git).
5 | The code is released under the MIT license.
6 |
7 | The code can be used for training/testing PyTorch models.
8 |
9 |
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from .loss import shapenet_loss, dfaust_loss, synthetic_room_loss
9 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/solver/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from . import config
9 | from .config import get_config, parse_args
10 |
11 | from . import solver
12 | from .solver import Solver
13 |
14 | from . import dataset
15 | from .dataset import Dataset
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from . import octree_ounet
9 | from . import graph_lenet
10 | from . import graph_resnet
11 | from . import graph_ounet
12 | from . import graph_slounet
13 | from . import graph_unet
14 | from . import graph_ae
15 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | from .shapenet import get_shapenet_dataset
9 | from .pointcloud import get_pointcloud_dataset, get_singlepointcloud_dataset
10 | from .pointcloud_eval import get_pointcloud_eval_dataset
11 | from .synthetic_room import get_synthetic_room_dataset
12 |
--------------------------------------------------------------------------------
/datasets/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import ocnn
9 | import torch
10 |
11 |
def collate_func(batch):
    """Collate a list of samples into a single batch dictionary.

    The samples are first merged by ``ocnn.collate_octrees``.  On top of
    that result:

    * if a ``'pos'`` entry exists, every per-sample position tensor gets one
      extra trailing column holding the sample's index within the batch, and
      all samples are concatenated along dim 0;
    * the ``'grad'``, ``'sdf'``, ``'occu'`` and ``'weight'`` entries, when
      present, are concatenated along dim 0.

    Args:
        batch: The list of samples handed over by the data loader.

    Returns:
        The dictionary produced by ``ocnn.collate_octrees`` with the entries
        above replaced by single concatenated tensors.
    """
    merged = ocnn.collate_octrees(batch)

    if 'pos' in merged:
        per_sample_pos = merged['pos']

        # One (num_points, 1) column per sample, filled with the sample index.
        index_columns = []
        for sample_id, sample_pos in enumerate(per_sample_pos):
            index_columns.append(torch.ones(sample_pos.size(0), 1) * sample_id)
        all_indices = torch.cat(index_columns, dim=0)

        all_pos = torch.cat(per_sample_pos, dim=0)
        merged['pos'] = torch.cat([all_pos, all_indices], dim=1)

    for name in ('grad', 'sdf', 'occu', 'weight'):
        if name in merged:
            merged[name] = torch.cat(merged[name], dim=0)

    return merged
26 |
--------------------------------------------------------------------------------
/configs/shapes.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: evaluate
11 | logdir: logs/shapes_eval
12 | ckpt: logs/dfaust/dfaust/checkpoints/00600.model.pth
13 | resolution: 420
14 | sdf_scale: 0.9
15 |
16 |
17 | DATA:
18 | test:
19 | name: pointcloud
20 | point_scale: 0.9
21 |
22 | # octree building
23 | depth: 8
24 | full_depth: 3
25 | node_dis: True
26 | split_label: True
27 | offset: 0.0
28 |
29 | # data loading
30 | location: data/shapes
31 | filelist: data/shapes/filelist.txt
32 | batch_size: 1
33 | shuffle: False
34 | # num_workers: 0
35 |
36 |
37 | MODEL:
38 | name: graph_unet
39 | resblock_type: basic
40 |
41 | depth: 8
42 | full_depth: 3
43 | depth_out: 8
44 | channel: 4
45 | nout: 4
46 |
--------------------------------------------------------------------------------
/configs/dfaust_eval.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: evaluate
11 | logdir: logs/dfaust_eval/dfaust
12 | ckpt: logs/dfaust/dfaust/checkpoints/00600.model.pth
13 | resolution: 300
14 | sdf_scale: 0.9
15 |
16 |
17 | DATA:
18 | test:
19 | name: pointcloud
20 | point_scale: 1.0
21 |
22 | # octree building
23 | depth: 8
24 | full_depth: 3
25 | node_dis: True
26 | split_label: True
27 | offset: 0.0
28 |
29 | # data loading
30 | location: data/dfaust/dataset
31 | filelist: data/dfaust/filelist/test.txt
32 | batch_size: 1
33 | shuffle: False
34 | # num_workers: 0
35 |
36 |
37 | MODEL:
38 | name: graph_unet
39 | resblock_type: basic
40 |
41 | depth: 8
42 | full_depth: 3
43 | depth_out: 8
44 | channel: 4
45 | nout: 4
46 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/configs/shapenet_eval.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: evaluate
11 | logdir: logs/shapenet_eval/test
12 | ckpt: logs/shapenet/shapenet/checkpoints/00300.model.pth
13 | sdf_scale: 0.9
14 | resolution: 128
15 |
16 |
17 | DATA:
18 | test:
19 | name: pointcloud_eval
20 |
21 | # octree building
22 | depth: 6
23 | offset: 0.0
24 | full_depth: 3
25 | node_dis: True
26 | split_label: True
27 |
28 | # data loading
29 | # location: data/ShapeNet/dataset # the original testing data
30 | location: data/ShapeNet/test.input # the generated testing data
31 | filelist: data/ShapeNet/filelist/test.txt
32 | batch_size: 1
33 | shuffle: False
34 | in_memory: False
35 | # num_workers: 0
36 |
37 |
38 | MODEL:
39 | name: graph_ounet
40 |
41 | channel: 5
42 | depth: 6
43 | nout: 4
44 | depth_out: 6
45 | full_depth: 3
46 | bottleneck: 4
47 |
48 | resblock_type: basic
49 |
--------------------------------------------------------------------------------
/configs/shapenet_ae_eval.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: evaluate
11 | logdir: logs/shapenet_eval/ae
12 | ckpt: logs/shapenet/ae/checkpoints/00300.model.pth
13 | sdf_scale: 0.9
14 | resolution: 160
15 |
16 |
17 | DATA:
18 | test:
19 | name: shapenet
20 |
21 | # octree building
22 | depth: 6
23 | offset: 0.0
24 | full_depth: 2
25 | node_dis: True
26 | split_label: True
27 |
28 | # no data augmentation
29 | distort: False
30 |
31 | # data loading
32 | location: data/ShapeNet/dataset
33 | filelist: data/ShapeNet/filelist/test_im.txt
34 | batch_size: 1
35 | shuffle: False
36 | load_sdf: False
37 | # num_workers: 0
38 |
39 |
40 | MODEL:
41 | name: graph_ae
42 |
43 | channel: 4
44 | depth: 6
45 | nout: 4
46 | depth_out: 6
47 | full_depth: 2
48 | bottleneck: 4
49 | resblock_type: basic
50 |
51 | LOSS:
52 | name: shapenet
53 | loss_type: sdf_reg_loss
54 |
--------------------------------------------------------------------------------
/configs/shapenet_unseen5.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: evaluate
11 | logdir: logs/shapenet_eval/unseen5
12 | ckpt: logs/shapenet/shapenet/checkpoints/00300.model.pth
13 | sdf_scale: 0.9
14 | resolution: 128
15 |
16 |
17 | DATA:
18 | test:
19 | name: shapenet
20 |
21 | # octree building
22 | depth: 6
23 | offset: 0.0
24 | full_depth: 3
25 | node_dis: True
26 | split_label: True
27 |
28 | # data augmentation, add noise only
29 | distort: True
30 |
31 | # data loading
32 | location: data/ShapeNet/dataset.unseen5 # the original testing data
33 | filelist: data/ShapeNet/filelist/test_unseen5.txt
34 | batch_size: 1
35 | load_sdf: False
36 | shuffle: False
37 | in_memory: False
38 | # num_workers: 0
39 |
40 |
41 | MODEL:
42 | name: graph_ounet
43 |
44 | channel: 5
45 | depth: 6
46 | nout: 4
47 | depth_out: 6
48 | full_depth: 3
49 | bottleneck: 4
50 |
51 | resblock_type: basic
52 |
--------------------------------------------------------------------------------
/configs/synthetic_room_eval.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: evaluate
11 | logdir: logs/room_eval/room
12 | ckpt: logs/room/room/checkpoints/00900.model.pth
13 | sdf_scale: 0.9
14 | resolution: 280
15 |
16 | DATA:
17 | test:
18 | name: pointcloud_eval
19 | point_scale: 0.6
20 |
21 | # octree building
22 | depth: 7
23 | offset: 0.0
24 | full_depth: 3
25 | node_dis: True
26 | split_label: True
27 |
28 | # data augmentation, add noise only
29 | # distort: True
30 |
31 | # data loading
32 | # location: data/room/synthetic_room_dataset
33 | location: data/room/test.input # the generated testing data
34 | filelist: data/room/filelist/test.txt
35 | batch_size: 1
36 | shuffle: False
37 | in_memory: False
38 | # num_workers: 0
39 |
40 |
41 | MODEL:
42 | name: graph_ounet
43 |
44 | channel: 5
45 | depth: 7
46 | nout: 4
47 | depth_out: 7
48 | full_depth: 3
49 | bottleneck: 4
50 | resblock_type: basic
51 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/configs/finetune.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: train
11 |
12 | logdir: logs/shapes/finetune
13 | max_epoch: 6000
14 | test_every_epoch: 100
15 | log_per_iter: 50
16 | ckpt_num: 200
17 | ckpt: ''
18 |
19 | # optimizer
20 | type: adamw
21 | weight_decay: 0.01 # default value of adamw
22 | lr: 0.0001
23 |
24 | # learning rate
25 | lr_type: constant
26 |
27 |
28 | DATA:
29 | train: &id001
30 | name: singlepointcloud
31 | point_scale: 0.9
32 | point_sample_num: 100000
33 |
34 | # octree building
35 | depth: 8
36 | full_depth: 3
37 | node_dis: True
38 | split_label: True
39 | offset: 0.0
40 |
41 | # data loading
42 | location: data/Shapes
43 | filelist: data/Shapes/filelist/lucy.txt
44 | batch_size: 1
45 | # num_workers: 0
46 |
47 | test: *id001
48 |
49 |
50 | MODEL:
51 | name: graph_unet
52 | resblock_type: basic
53 | find_unused_parameters: True
54 |
55 | depth: 8
56 | full_depth: 3
57 | depth_out: 8
58 | channel: 4
59 | nout: 4
60 |
61 | LOSS:
62 | name: dfaust
63 | loss_type: possion_grad_loss
64 |
--------------------------------------------------------------------------------
/configs/cls_m40.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: train
11 | type: sgd
12 |
13 | logdir: logs/m40/m40
14 | max_epoch: 300
15 | test_every_epoch: 5
16 |
17 | # lr: 0.001 # default value of adamw
18 | # weight_decay: 0.01 # default value of adamw
19 | step_size: (120,180,240)
20 | ckpt_num: 20
21 |
22 | DATA:
23 | train:
24 | # octree building
25 | depth: 5
26 | offset: 0.016
27 |
28 | # data augmentations
29 | distort: True
30 | angle: (0, 0, 5) # small rotation along z axis
31 | interval: (1, 1, 1)
32 | scale: 0.25
33 | jitter: 0.125
34 |
35 | # data loading
36 | location: data/ModelNet40/ModelNet40.points
37 | filelist: data/ModelNet40/m40_train_points_list.txt
38 | batch_size: 32
39 | shuffle: True
40 | # num_workers: 0
41 |
42 | test:
43 | # octree building
44 | depth: 5
45 | offset: 0.016
46 |
47 | # data augmentations
48 | distort: False
49 | angle: (0, 0, 5) # small rotation along z axis
50 | interval: (1, 1, 1)
51 | scale: 0.25
52 | jitter: 0.125
53 |
54 | # data loading
55 | location: data/ModelNet40/ModelNet40.points
56 | filelist: data/ModelNet40/m40_test_points_list.txt
57 | batch_size: 32
58 | shuffle: False
59 | # num_workers: 0
60 |
61 | MODEL:
62 | name: lenet
63 | channel: 3
64 | nout: 40
65 | depth: 5
66 |
67 | LOSS:
68 | num_class: 40
--------------------------------------------------------------------------------
/configs/dfaust.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: train
11 |
12 | logdir: logs/dfaust/dfaust
13 | max_epoch: 600
14 | test_every_epoch: 10
15 | log_per_iter: 50
16 | ckpt_num: 200
17 |
18 | # optimizer
19 | type: adamw
20 | weight_decay: 0.01 # default value of adamw
21 | lr: 0.001 # default value of adamw
22 |
23 | # learning rate
24 | lr_type: poly
25 | step_size: (200,300)
26 |
27 |
28 | DATA:
29 | train:
30 | name: pointcloud
31 | point_scale: 1.0
32 |
33 | # octree building
34 | depth: 8
35 | full_depth: 3
36 | node_dis: True
37 | split_label: True
38 | offset: 0.0
39 |
40 | # data loading
41 | location: data/dfaust/dataset
42 | filelist: data/dfaust/filelist/train.txt
43 | batch_size: 16
44 | # num_workers: 0
45 |
46 | test:
47 | name: pointcloud
48 | point_scale: 1.0
49 |
50 | # octree building
51 | depth: 8
52 | full_depth: 3
53 | node_dis: True
54 | split_label: True
55 | offset: 0.0
56 |
57 | # data loading
58 | location: data/dfaust/dataset
59 | filelist: data/dfaust/filelist/test.txt
60 | batch_size: 4
61 | # num_workers: 0
62 |
63 |
64 | MODEL:
65 | name: graph_unet
66 | resblock_type: basic
67 | find_unused_parameters: True
68 |
69 | depth: 8
70 | full_depth: 3
71 | depth_out: 8
72 | channel: 4
73 | nout: 4
74 |
75 | LOSS:
76 | name: dfaust
77 | loss_type: possion_grad_loss
78 |
--------------------------------------------------------------------------------
/solver/sampler.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from torch.utils.data import Sampler, DistributedSampler, Dataset
10 |
11 |
class InfSampler(Sampler):
    """Sampler that never runs out of indices.

    The index order is rebuilt (a fresh random permutation when ``shuffle``
    is true, otherwise ``0 .. len(dataset) - 1``) every time the current pass
    over the dataset has been consumed.
    """

    def __init__(self, dataset: Dataset, shuffle: bool = True) -> None:
        self.dataset = dataset
        self.shuffle = shuffle
        self.reset_sampler()

    def reset_sampler(self):
        """Build a new index order and rewind the read cursor."""
        size = len(self.dataset)
        if self.shuffle:
            order = torch.randperm(size)
        else:
            order = torch.arange(size)
        self.indices = order.tolist()
        self.iter_num = 0

    def __iter__(self):
        # The sampler is its own iterator; state lives on the instance.
        return self

    def __next__(self):
        current = self.indices[self.iter_num]
        self.iter_num += 1

        if self.iter_num < len(self.indices):
            return current

        # The pass is exhausted: prepare the next one right away so that the
        # following call can index into a fresh order.
        self.reset_sampler()
        return current

    def __len__(self):
        return len(self.dataset)
37 |
38 |
class DistributedInfSampler(DistributedSampler):
    """Endless variant of :class:`DistributedSampler`.

    The indices assigned to this rank are materialized once per pass; when a
    pass has been consumed the epoch counter is advanced and a new pass is
    materialized, so iteration never stops.
    """

    def __init__(self, dataset: Dataset, shuffle: bool = True) -> None:
        super().__init__(dataset, shuffle=shuffle)
        self.reset_sampler()

    def reset_sampler(self):
        """Materialize the indices for the current epoch and rewind."""
        self.indices = list(super().__iter__())
        self.iter_num = 0

    def __iter__(self):
        # The sampler is its own iterator; state lives on the instance.
        return self

    def __next__(self):
        value = self.indices[self.iter_num]
        self.iter_num = self.iter_num + 1

        if self.iter_num >= len(self.indices):
            # DistributedSampler derives its shuffle from ``seed + epoch``.
            # Without advancing the epoch here, every pass would replay the
            # exact same permutation.
            self.set_epoch(self.epoch + 1)
            self.reset_sampler()
        return value
58 |
--------------------------------------------------------------------------------
/datasets/pointcloud_eval.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import ocnn
10 | import torch
11 | import numpy as np
12 | from plyfile import PlyData
13 |
14 | from solver import Dataset
15 |
16 |
class Transform:
    r"""Rescale raw point coordinates, wrap them as ocnn points and build the
    octree. Used to evaluate the network trained on ShapeNet."""

    def __init__(self, flags):
        self.flags = flags

        self.point_scale = flags.point_scale
        # ``flags`` is unpacked as keyword arguments for the octree builder.
        self.points2octree = ocnn.Points2Octree(**flags)

    def __call__(self, points, idx):
        # Only the first three columns are used as coordinates.  After this
        # normalization the points are expected to lie in [-1, 1].
        xyz = points[:, :3] / self.point_scale

        # Assemble the ocnn points object: one constant 1.0 value per point,
        # with the two remaining slots left as empty tensors.
        # NOTE(review): presumably those slots are normals and labels --
        # confirm against the ocnn ``points_new`` signature.
        unit_values = torch.ones(xyz.shape[0], dtype=torch.float32)
        empty = torch.Tensor()
        cloud = ocnn.points_new(xyz, empty, unit_values, torch.Tensor())

        # Clip to the cube [-1, 1]^3; the second return value is not needed.
        lower = [-1.0, -1.0, -1.0]
        upper = [1.0, 1.0, 1.0]
        cloud, _ = ocnn.clip_points(cloud, lower, upper)

        # Convert the clipped points into an octree.
        tree = self.points2octree(cloud)

        return {'points_in': cloud, 'octree_in': tree}
40 |
41 |
def read_file(filename: str):
    """Read the vertex coordinates of a PLY file into a float32 tensor.

    Args:
        filename: Path of the file *without* extension; ``'.ply'`` is appended
            before reading.

    Returns:
        A ``torch.float32`` tensor whose columns are the ``x``, ``y`` and
        ``z`` properties of the ``vertex`` element (one row per vertex,
        assuming each property is a 1-D array).
    """
    plydata = PlyData.read(filename + '.ply')
    vtx = plydata['vertex']

    # A single cast is enough: ``astype`` already returns a fresh float32
    # array, so casting a second time would only add another full copy.
    points = np.stack([vtx['x'], vtx['y'], vtx['z']], axis=1).astype(np.float32)
    return torch.from_numpy(points)
48 |
49 |
def get_pointcloud_eval_dataset(flags):
    """Create the evaluation dataset and its collate function.

    Args:
        flags: Configuration node; ``location``, ``filelist`` and
            ``in_memory`` are read here and the whole node is handed to
            :class:`Transform`.

    Returns:
        A ``(dataset, collate_fn)`` pair, where ``collate_fn`` is
        ``ocnn.collate_octrees``.
    """
    sample_transform = Transform(flags)
    return Dataset(
        flags.location,
        flags.filelist,
        sample_transform,
        read_file=read_file,
        in_memory=flags.in_memory,
    ), ocnn.collate_octrees
55 |
--------------------------------------------------------------------------------
/configs/shapenet_ae.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: train
11 |
12 | logdir: logs/shapenet/ae
13 | max_epoch: 300
14 | test_every_epoch: 20
15 | log_per_iter: 50
16 | ckpt_num: 40
17 |
18 | # optimizer
19 | type: adamw
20 | weight_decay: 0.01 # default value of adamw
21 | lr: 0.001 # default value of adamw
22 |
23 | # learning rate
24 | lr_type: poly
25 | step_size: (160,240)
26 |
27 | DATA:
28 | train:
29 | name: shapenet
30 |
31 | # octree building
32 | depth: 6
33 | offset: 0.0
34 | full_depth: 2
35 | node_dis: True
36 | split_label: True
37 |
38 | # no data augmentation
39 | distort: False
40 |
41 | # data loading
42 | location: data/ShapeNet/dataset
43 | filelist: data/ShapeNet/filelist/train_im.txt
44 | load_sdf: True
45 | batch_size: 16
46 | shuffle: True
47 | # num_workers: 0
48 |
49 | test:
50 | name: shapenet
51 |
52 | # octree building
53 | depth: 6
54 | offset: 0.0
55 | full_depth: 2
56 | node_dis: True
57 | split_label: True
58 |
59 | # no data augmentation
60 | distort: False
61 |
62 | # data loading
63 | location: data/ShapeNet/dataset
64 | filelist: data/ShapeNet/filelist/val_im.txt
65 | batch_size: 4
66 | load_sdf: True
67 | shuffle: False
68 | # num_workers: 0
69 |
70 |
71 | MODEL:
72 | name: graph_ae
73 |
74 | channel: 4
75 | depth: 6
76 | nout: 4
77 | depth_out: 6
78 | full_depth: 2
79 | bottleneck: 4
80 | resblock_type: basic
81 |
82 | LOSS:
83 | name: shapenet
84 | loss_type: sdf_reg_loss
85 |
--------------------------------------------------------------------------------
/configs/shapenet.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: train
11 |
12 | logdir: logs/shapenet/shapenet
13 | max_epoch: 300
14 | test_every_epoch: 20
15 | log_per_iter: 50
16 | ckpt_num: 40
17 |
18 | # optimizer
19 | type: adamw
20 | weight_decay: 0.01 # default value of adamw
21 | lr: 0.001 # default value of adamw
22 |
23 | # learning rate
24 | lr_type: poly
25 | step_size: (160,240)
26 |
27 | DATA:
28 | train:
29 | name: shapenet
30 |
31 | # octree building
32 | depth: 6
33 | offset: 0.0
34 | full_depth: 3
35 | node_dis: True
36 | split_label: True
37 |
38 | # data augmentation, add noise only
39 | distort: True
40 |
41 | # data loading
42 | location: data/ShapeNet/dataset
43 | filelist: data/ShapeNet/filelist/train.txt
44 | load_sdf: True
45 | batch_size: 16
46 | shuffle: True
47 | # num_workers: 0
48 |
49 | test:
50 | name: shapenet
51 |
52 | # octree building
53 | depth: 6
54 | offset: 0.0
55 | full_depth: 3
56 | node_dis: True
57 | split_label: True
58 |
59 | # data augmentation, add noise only
60 | distort: True
61 |
62 | # data loading
63 | location: data/ShapeNet/dataset
64 | filelist: data/ShapeNet/filelist/val.txt
65 | batch_size: 8
66 | load_sdf: True
67 | shuffle: False
68 | # num_workers: 0
69 |
70 |
71 | MODEL:
72 | name: graph_ounet
73 |
74 | channel: 5
75 | depth: 6
76 | nout: 4
77 | depth_out: 6
78 | full_depth: 3
79 | bottleneck: 4
80 | resblock_type: basic
81 |
82 | LOSS:
83 | name: shapenet
84 | loss_type: sdf_reg_loss
85 |
--------------------------------------------------------------------------------
/configs/synthetic_room.yaml:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | SOLVER:
9 | gpu: 0,
10 | run: train
11 |
12 | logdir: logs/room/room
13 | max_epoch: 900
14 | test_every_epoch: 10
15 | log_per_iter: 40
16 | ckpt_num: 20
17 |
18 | type: adamw
19 | lr: 0.001 # default value of adamw
20 | weight_decay: 0.01 # default value of adamw
21 | lr_type: poly
22 | step_size: (80,120)
23 |
24 |
25 | DATA:
26 | train:
27 | name: synthetic_room
28 |
29 | # octree building
30 | depth: 7
31 | offset: 0.0
32 | full_depth: 3
33 | node_dis: True
34 | split_label: True
35 |
36 | # data augmentation, add noise only
37 | distort: True
38 |
39 | # data loading
40 | location: data/room/synthetic_room_dataset
41 | filelist: data/room/filelist/train.txt
42 | load_occu: True
43 | sample_surf_points: True
44 | batch_size: 16
45 | shuffle: True
46 | # num_workers: 0
47 |
48 | test:
49 | name: synthetic_room
50 |
51 |
52 | # octree building
53 | depth: 7
54 | offset: 0.0
55 | full_depth: 3
56 | node_dis: True
57 | split_label: True
58 |
59 | # data augmentation, add noise only
60 | distort: True
61 |
62 | # data loading
63 | location: data/room/synthetic_room_dataset
64 | filelist: data/room/filelist/val.txt
65 | load_occu: True
66 | sample_surf_points: True
67 | batch_size: 8
68 | shuffle: False
69 | # num_workers: 0
70 |
71 |
72 | MODEL:
73 | name: graph_ounet
74 |
75 | channel: 5
76 | depth: 7
77 | nout: 4
78 | depth_out: 7
79 | full_depth: 3
80 | bottleneck: 4
81 | resblock_type: basic
82 |
83 | LOSS:
84 | name: synthetic_room
85 | loss_type: possion_grad_loss
86 |
--------------------------------------------------------------------------------
/models/utils/spmm.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from .scatter import scatter_add
10 |
11 |
def spmm(index, value, m, n, matrix):
    """Matrix product of sparse matrix with dense matrix.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix. It is
            unpacked into the row and the column indices of the entries.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        matrix (:class:`Tensor`): The dense matrix. A 1-D tensor is treated
            as a matrix with a single column.

    :rtype: :class:`Tensor`
    """

    # Promote a vector to a column matrix *before* checking the shape:
    # `size(-2)` is not defined for a 1-D tensor.
    matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
    assert n == matrix.size(-2)

    row, col = index

    # gather the dense rows addressed by each sparse entry, scale them by the
    # entry value and accumulate them into the output rows
    out = matrix.index_select(-2, col)
    out = out * value.unsqueeze(-1)
    out = scatter_add(out, row, dim=-2, dim_size=m)

    return out
35 |
36 |
def modulated_spmm(index, value, m, n, matrix, xyzf):
    """Matrix product of sparse matrix with dense matrix, modulated per entry.

    For every sparse entry the gathered row of `matrix` is multiplied
    element-wise with the matching row of `xyzf` extended by a trailing 1 and
    summed over dim 1, so each entry contributes a single scalar.

    Args:
        index (:class:`LongTensor`): The index tensor of sparse matrix. It is
            unpacked into the row and the column indices of the entries.
        value (:class:`Tensor`): The value tensor of sparse matrix.
        m (int): The first dimension of corresponding dense matrix.
        n (int): The second dimension of corresponding dense matrix.
        matrix (:class:`Tensor`): The dense matrix. A 1-D tensor is treated
            as a matrix with a single column.
        xyzf (:class:`Tensor`): The 2-D modulation tensor with one row per
            sparse entry. NOTE(review): presumably it has one column less
            than `matrix` -- confirm against the callers.

    :rtype: :class:`Tensor`
    """

    # Promote a vector to a column matrix *before* checking the shape:
    # `size(-2)` is not defined for a 1-D tensor.
    matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
    assert n == matrix.size(-2)

    row, col = index

    out = matrix.index_select(-2, col)
    # append a constant column of ones to `xyzf`
    ones = torch.ones((xyzf.shape[0], 1), device=xyzf.device)
    out = torch.sum(out * torch.cat([xyzf, ones], dim=1), dim=1, keepdim=True)
    out = out * value.unsqueeze(-1)
    out = scatter_add(out, row, dim=-2, dim_size=m)

    return out
62 |
--------------------------------------------------------------------------------
/solver/dataset.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import torch
10 | import torch.utils.data
11 | import numpy as np
12 | from tqdm import tqdm
13 |
14 |
def read_file(filename):
    """Read the raw bytes of `filename` and return them as a uint8 tensor."""
    raw_bytes = np.fromfile(filename, dtype=np.uint8)
    byte_tensor = torch.from_numpy(raw_bytes)
    return byte_tensor
18 |
19 |
class Dataset(torch.utils.data.Dataset):
    """Dataset driven by a text filelist of `filename [label]` lines.

    Args:
        root (str): Folder that the filenames in the filelist are relative to.
        filelist (str): Path of the text file listing the samples.
        transform (callable): Called as `transform(sample, idx)`; it has to
            return a dict.
        read_file (callable): Loads one sample from its full path.
        in_memory (bool): If True, all samples are read once in the
            constructor and kept in memory.
        take (int): Keep only the first `take` samples; values below 1 or
            above the number of listed samples keep everything.
    """

    def __init__(self, root, filelist, transform, read_file=read_file,
                 in_memory=False, take: int = -1):
        super(Dataset, self).__init__()
        self.root = root
        self.filelist = filelist
        self.transform = transform
        self.in_memory = in_memory
        self.read_file = read_file
        self.take = take

        self.filenames, self.labels = self.load_filenames()
        if self.in_memory:
            print('Load files into memory from ' + self.filelist)
            self.samples = [self.read_file(os.path.join(self.root, f))
                            for f in tqdm(self.filenames, ncols=80, leave=False)]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return the transformed sample `idx` with its label and filename."""
        sample = self.samples[idx] if self.in_memory else \
            self.read_file(os.path.join(self.root, self.filenames[idx]))  # noqa
        output = self.transform(sample, idx)  # data augmentation + build octree
        output['label'] = self.labels[idx]
        output['filename'] = self.filenames[idx]
        return output

    def load_filenames(self):
        """Parse the filelist and return `(filenames, labels)`.

        A line holds a filename optionally followed by an integer label; the
        label is 0 unless the line has exactly two tokens. Blank lines are
        ignored. `self.take` is clamped to the number of listed samples.
        """
        filenames, labels = [], []
        with open(self.filelist) as fid:
            for line in fid:
                tokens = line.split()
                if not tokens:
                    # a blank (e.g. trailing) line carries no sample
                    continue
                label = tokens[1] if len(tokens) == 2 else 0
                filenames.append(tokens[0])
                labels.append(int(label))

        num = len(filenames)
        if self.take > num or self.take < 1:
            self.take = num

        return filenames[:self.take], labels[:self.take]
65 |
--------------------------------------------------------------------------------
/models/utils/scatter.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | from typing import Optional
10 |
11 |
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Expand `src` to the shape of `other`.

    A 1-D `src` is placed along dimension `dim` of `other` (negative `dim`
    counts from the end); missing trailing dimensions are then appended and
    the result is expanded to `other`'s shape.
    """
    axis = dim + other.dim() if dim < 0 else dim
    if src.dim() == 1:
        # prepend one singleton dimension per axis in front of `axis`
        for _ in range(axis):
            src = src.unsqueeze(0)
    # append singleton dimensions until the ranks agree
    while src.dim() < other.dim():
        src = src.unsqueeze(-1)
    return src.expand_as(other)
22 |
23 |
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum the entries of `src` that share the same `index` along `dim`.

    `index` is first broadcast to the shape of `src`. When `out` is given the
    sums are accumulated into it in place; otherwise a zero tensor is
    created whose size along `dim` is `dim_size`, or 0 for an empty index, or
    `index.max() + 1`.
    """
    index = broadcast(index, src, dim)
    if out is not None:
        return out.scatter_add_(dim, index, src)

    shape = list(src.size())
    if dim_size is not None:
        shape[dim] = dim_size
    elif index.numel() == 0:
        shape[dim] = 0
    else:
        shape[dim] = int(index.max()) + 1
    result = torch.zeros(shape, dtype=src.dtype, device=src.device)
    return result.scatter_add_(dim, index, src)
40 |
41 |
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 weights: Optional[torch.Tensor] = None,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    """Average the entries of `src` that share the same `index` along `dim`.

    With `weights`, `src` is multiplied by the weights and the sums are
    divided by the summed weights instead of the entry counts. Divisors
    below 1 are replaced by 1. Floating point results use true division,
    other dtypes use floor division.
    """
    if weights is not None:
        src = src * broadcast(weights, src, dim)
    out = scatter_add(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    # dimension of `index` along which the divisors are accumulated
    index_dim = dim + src.dim() if dim < 0 else dim
    index_dim = min(index_dim, index.dim() - 1)

    if weights is None:
        weights = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_add(weights, index, index_dim, None, dim_size)
    count.masked_fill_(count < 1, 1)
    count = broadcast(count, out, dim)

    if out.is_floating_point():
        out.true_divide_(count)
    else:
        out.div_(count, rounding_mode='floor')
    return out
67 |
--------------------------------------------------------------------------------
/models/graph_resnet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import ocnn
10 |
11 | from .modules import GraphConvBnRelu, GraphResBlocks
12 | from .dual_octree import DualOctree
13 |
14 |
class GraphResNet(torch.nn.Module):
    """Classification network made of graph residual blocks on an octree.

    Starting at `depth`, every level runs a stack of graph residual blocks
    followed by an octree max-pooling, down to depth 2, where a global pooling
    and a linear layer produce `nout` outputs.
    """

    def __init__(self, depth, channel_in, nout, resblk_num):
        super().__init__()
        self.depth = depth
        self.channel_in = channel_in

        channels = [2 ** max(11 - i, 2) for i in range(depth + 1)]
        channels.append(channels[depth])
        n_edge_type, avg_degree, bottleneck = 7, 7, 4
        levels = list(range(depth, 2, -1))  # depths processed by the network

        self.conv1 = GraphConvBnRelu(
            channel_in, channels[depth], n_edge_type, avg_degree)

        blocks = [
            GraphResBlocks(channels[d + 1], channels[d], resblk_num,
                           bottleneck, n_edge_type, avg_degree)
            for d in levels]
        self.resblocks = torch.nn.ModuleList(blocks)

        poolings = [ocnn.OctreeMaxPool(d) for d in levels]
        self.pools = torch.nn.ModuleList(poolings)

        self.header = torch.nn.Sequential(
            ocnn.FullOctreeGlobalPool(depth=2),  # global pool
            # torch.nn.Dropout(p=0.5),           # drop
            torch.nn.Linear(channels[3], nout))  # fc

    def forward(self, octree):
        """Run the network on `octree` and return the output of the header."""
        # fetch the input feature stored at the finest depth
        feature = ocnn.octree_property(octree, 'feature', self.depth)
        assert feature.size(1) == self.channel_in

        # build the dual octree that provides the graphs
        doctree = DualOctree(octree)
        doctree.post_processing_for_ocnn()

        for i, d in enumerate(range(self.depth, 2, -1)):
            # graph convolutions work on a 2-D tensor
            feature = feature.squeeze().t()
            edge_idx = doctree.graph[d]['edge_idx']
            edge_type = doctree.graph[d]['edge_type']
            if d == self.depth:
                # the very first convolution lifts the input channels
                feature = self.conv1(feature, edge_idx, edge_type)
            feature = self.resblocks[i](feature, edge_idx, edge_type)

            # downsampling: back to the 4-D layout used by the octree pooling
            feature = feature.t().unsqueeze(0).unsqueeze(-1)
            feature = self.pools[i](feature, octree)

        # classification head
        return self.header(feature)
62 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
--------------------------------------------------------------------------------
/models/graph_unet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn
10 |
11 | from . import dual_octree
12 | from . import graph_ounet
13 |
14 |
class GraphUNet(graph_ounet.GraphOUNet):
    """U-Net variant of :class:`graph_ounet.GraphOUNet`.

    Encoder and decoder both run on the dual octree built from the input
    octree, and the decoder adds the encoder features as skip connections.
    """

    def _setup_channels_and_resblks(self):
        """Set the number of residual blocks and the channel widths."""
        # self.resblk_num = [3] * 7 + [1] + [1] * 9
        self.resblk_num = [2] * 16
        self.channels = [4, 512, 512, 256, 128, 64, 32, 32, 32]

    def recons_decoder(self, convs, doctree_out):
        """Decode the encoder features from `full_depth` to `depth_out`.

        Args:
            convs (dict): Encoder features, indexed by depth.
            doctree_out: The dual octree the decoder runs on.

        Returns:
            tuple: `(logits, reg_voxs)`, two dicts indexed by depth that hold
            the split logits and the zero-padded regressed signal.
        """
        logits = dict()
        reg_voxs = dict()
        deconvs = dict()

        deconvs[self.full_depth] = convs[self.full_depth]
        for i, d in enumerate(range(self.full_depth, self.depth_out+1)):
            if d > self.full_depth:
                nnum = doctree_out.nnum[d-1]
                leaf_mask = doctree_out.node_child(d-1) < 0
                deconvd = self.upsample[i-1](deconvs[d-1], leaf_mask, nnum)
                deconvd = deconvd + convs[d]  # skip connections

                edge_idx = doctree_out.graph[d]['edge_idx']
                edge_type = doctree_out.graph[d]['edge_dir']
                node_type = doctree_out.graph[d]['node_type']
                deconvs[d] = self.decoder[i-1](
                    deconvd, edge_idx, edge_type, node_type)

            # predict the splitting label; only the last `nnum` rows are kept
            logit = self.predict[i](deconvs[d])
            nnum = doctree_out.nnum[d]
            logits[d] = logit[-nnum:]

            # predict the signal
            reg_vox = self.regress[i](deconvs[d])

            # TODO: improve it
            # pad zeros to reg_vox to reuse the original code for ocnn
            node_mask = doctree_out.graph[d]['node_mask']
            shape = (node_mask.shape[0], reg_vox.shape[1])
            # allocate with the dtype of `reg_vox`: the masked assignment
            # below requires matching dtypes (e.g. under mixed precision)
            reg_vox_pad = reg_vox.new_zeros(shape)
            reg_vox_pad[node_mask] = reg_vox
            reg_voxs[d] = reg_vox_pad

        return logits, reg_voxs

    def forward(self, octree_in, octree_out=None, pos=None):
        """Run the U-Net on `octree_in`.

        Args:
            octree_in: The input octree.
            octree_out: Unused; the input octree is also the output octree.
            pos: Optional query positions passed to the neural MPU.

        Returns:
            dict: `reg_voxs`, `octree_out` and the callable `neural_mpu`,
            plus `mpus` when `pos` is given.
        """
        # octree_in and octree_out are the same for UNet
        doctree_in = dual_octree.DualOctree(octree_in)
        doctree_in.post_processing_for_docnn()

        # run encoder and decoder
        convs = self.octree_encoder(octree_in, doctree_in)
        out = self.recons_decoder(convs, doctree_in)
        output = {'reg_voxs': out[1], 'octree_out': octree_in}

        # compute function value with mpu
        if pos is not None:
            output['mpus'] = self.neural_mpu(pos, out[1], octree_in)

        # create the mpu wrapper
        def _neural_mpu(pos):
            pred = self.neural_mpu(pos, out[1], octree_in)
            return pred[self.depth_out][0]
        output['neural_mpu'] = _neural_mpu

        return output
79 |
--------------------------------------------------------------------------------
/models/octree_ounet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import ocnn
9 | import torch
10 | import torch.nn
11 |
12 | from . import mpu
13 |
14 |
class OctreeOUNet(ocnn.OUNet):
    """`ocnn.OUNet` with per-depth regression heads and a neural MPU."""

    def __init__(self, depth, channel_in, nout, full_depth=2, depth_out=6):
        super().__init__(depth, channel_in, nout, full_depth)

        self.header = None
        self.depth_out = depth_out

        self.neural_mpu = mpu.NeuralMPU(self.full_depth, self.depth_out)
        # one regression head with 4 output channels per depth
        heads = [self._make_predict_module(self.channels[d], 4)
                 for d in range(full_depth, depth + 1)]
        self.regress = torch.nn.ModuleList(heads)

    def ocnn_decoder(self, convs, octree_out, octree, update_octree=False):
        """Decode `convs` from `full_depth` to `depth_out`.

        Returns the split logits and the regressed signals (both dicts
        indexed by depth) together with the output octree, which is grown
        from the predicted labels when `update_octree` is True.
        """
        logits, reg_voxs, deconvs = {}, {}, {}
        regressed = []

        deconvs[self.full_depth] = convs[self.full_depth]
        for i, d in enumerate(range(self.full_depth, self.depth_out + 1)):
            if d > self.full_depth:
                upsampled = self.upsample[i - 1](deconvs[d - 1], octree_out)
                skip, _ = ocnn.octree_align(convs[d], octree, octree_out, d)
                deconvs[d] = self.decoder[i - 1](upsampled + skip, octree_out)

            # predict the splitting label: (1, C, H, 1) -> (H, C)
            logits[d] = self.predict[i](deconvs[d]).squeeze().t()

            # update the octree according to predicted labels
            if update_octree:
                label = logits[d].argmax(1).to(torch.int32)
                octree_out = ocnn.octree_update(octree_out, label, d, split=1)
                if d < self.depth_out:
                    octree_out = ocnn.octree_grow(
                        octree_out, target_depth=d + 1)

            # predict the signal: (1, C, H, 1) -> (H, C); the signals of all
            # depths processed so far are stacked along dim 0
            regressed.append(self.regress[i](deconvs[d]).squeeze().t())
            reg_voxs[d] = torch.cat(regressed, dim=0)

        return logits, reg_voxs, octree_out

    def forward(self, octree_in, octree_out=None, pos=None):
        """Encode `octree_in` and decode it on `octree_out`.

        Without `octree_out` the decoder starts from a full octree and grows
        it from its own predictions. `pos` are optional query positions for
        the neural MPU.
        """
        update_octree = octree_out is None
        if update_octree:
            octree_out = ocnn.create_full_octree(self.full_depth, self.nout)

        # run encoder and decoder
        convs = self.ocnn_encoder(octree_in)
        logits, reg_voxs, octree_dec = self.ocnn_decoder(
            convs, octree_out, octree_in, update_octree)
        output = {'logits': logits, 'reg_voxs': reg_voxs,
                  'octree_out': octree_dec}

        # compute function value with mpu
        if pos is not None:
            output['mpus'] = self.neural_mpu(pos, reg_voxs, octree_dec)

        # create the mpu wrapper
        def _neural_mpu(pos):
            pred = self.neural_mpu(pos, reg_voxs, octree_dec)
            return pred[self.depth_out][0]
        output['neural_mpu'] = _neural_mpu

        return output
83 |
--------------------------------------------------------------------------------
/tools/compute_metrics.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import argparse
10 | import trimesh.sample
11 | import numpy as np
12 | from tqdm import tqdm
13 | from scipy.spatial import cKDTree
14 |
# command line options of the evaluation script
parser = argparse.ArgumentParser()
parser.add_argument('--mesh_folder', type=str, required=True)
parser.add_argument('--filename_out', type=str, required=True)
parser.add_argument('--num_samples', type=int, default=30000)
parser.add_argument('--ref_folder', type=str, default='data/dfaust/dfaust/mesh_gt')
parser.add_argument('--filelist', type=str, default='data/dfaust/test_all.txt')
args = parser.parse_args()


# one mesh name per line; blank lines are dropped so that they do not turn
# into an empty mesh name further below
with open(args.filelist, 'r') as fid:
    lines = fid.readlines()
filenames = [name for name in (line.strip() for line in lines) if name]
27 |
28 |
def _directed_metrics(points_src, normals_src, points_dst, normals_dst):
    """Compare every `src` sample with its nearest `dst` sample.

    Returns the mean nearest-neighbor distance, the mean angle in degrees
    between the paired normals and the mean absolute normal dot product.
    """
    kdtree = cKDTree(points_dst)
    dist, idx = kdtree.query(points_src)
    chamfer = np.mean(dist)
    dot = np.sum(normals_src * normals_dst[idx], axis=1)
    # rounding can push the dot product slightly outside [-1, 1], for which
    # arccos returns NaN and would poison the mean
    angle = np.mean(np.arccos(np.clip(dot, -1.0, 1.0)) * (180.0 / np.pi))
    consist = np.mean(np.abs(dot))
    return chamfer, angle, consist


def compute_metrics(filename_ref, filename_pred, num_samples=30000):
    """Compute surface metrics between a reference and a predicted mesh.

    Both meshes are loaded with trimesh and `num_samples` points are sampled
    on each surface; the normal of a sample is the normal of its face.

    Args:
        filename_ref (str): Path of the reference mesh.
        filename_pred (str): Path of the predicted mesh.
        num_samples (int): Number of surface samples per mesh.

    Returns:
        tuple: `(chamfer_a, chamfer_b, angle_a, angle_b, consist_a,
        consist_b)`, where `a` compares the predicted samples against the
        reference and `b` the reference samples against the prediction.
    """
    mesh_ref = trimesh.load(filename_ref)
    points_ref, idx_ref = trimesh.sample.sample_surface(mesh_ref, num_samples)
    normals_ref = mesh_ref.face_normals[idx_ref]
    # points_ref, normals_ref = read_ply(filename_ref)

    mesh_pred = trimesh.load(filename_pred)
    points_pred, idx_pred = trimesh.sample.sample_surface(mesh_pred, num_samples)
    normals_pred = mesh_pred.face_normals[idx_pred]

    chamfer_a, angle_a, consist_a = _directed_metrics(
        points_pred, normals_pred, points_ref, normals_ref)
    chamfer_b, angle_b, consist_b = _directed_metrics(
        points_ref, normals_ref, points_pred, normals_pred)

    return chamfer_a, chamfer_b, angle_a, angle_b, consist_a, consist_b
54 |
55 |
counter = 0
# The context manager closes the report even if a mesh fails to load or to
# evaluate half-way through the list.
with open(args.filename_out, 'w') as fid:
    fid.write(('name, '
               'chamfer_a, chamfer_b, chamfer, '
               'angle_a, angle_b, angle, '
               'consist_a, consist_b, normal consistency\n'))
    for filename in tqdm(filenames, ncols=80):
        if filename.endswith('.npy'):
            filename = filename[:-4]
        filename_ref = os.path.join(args.ref_folder, filename + '.obj')
        # filename_ref = os.path.join(args.ref_folder, filename + '.ply')
        filename_pred = os.path.join(args.mesh_folder, filename + '.obj')
        metrics = compute_metrics(filename_ref, filename_pred, args.num_samples)

        chamfer_a, chamfer_b = metrics[0], metrics[1]
        angle_a, angle_b = metrics[2], metrics[3]
        consist_a, consist_b = metrics[4], metrics[5]

        # every metric is reported per direction and as the mean of both
        msg = '{}, {}, {}, {}, {}, {}, {}, {}, {}, {}\n'.format(
            filename,
            chamfer_a, chamfer_b, 0.5 * (chamfer_a + chamfer_b),
            angle_a, angle_b, 0.5 * (angle_a + angle_b),
            consist_a, consist_b, 0.5 * (consist_a + consist_b))
        fid.write(msg)
        tqdm.write(msg)
83 |
--------------------------------------------------------------------------------
/datasets/synthetic_room.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import ocnn
10 | import torch
11 | import numpy as np
12 |
13 | from solver import Dataset
14 | from .utils import collate_func
15 | from .shapenet import TransformShape
16 |
17 |
class TransformScene(TransformShape):
    """Transform of a synthetic-room sample into network input and targets.

    The point cloud is processed by `process_points_cloud`; when
    `flags.load_occu` is set, occupancy samples (and optionally samples on
    the surface) are added as supervision.
    """

    def __init__(self, flags):
        self.flags = flags

        self.point_sample_num = 10000
        self.occu_sample_num = 4096
        self.surface_sample_num = 2048
        self.sample_surf_points = flags.sample_surf_points
        self.points_scale = 0.6  # the points are actually in [-0.55, 0.55]
        self.noise_std = 0.005
        self.pos_weight = 10
        self.points2octree = ocnn.Points2Octree(**flags)

    def sample_occu(self, sample):
        """Randomly pick `occu_sample_num` occupancy samples.

        Returns a dict with the float32 tensors `pos`, `occu`, `weight` and
        `grad`.
        """
        points, occus = sample['points'], sample['occupancies']
        points = points / self.points_scale
        points = points + 1e-6 * np.random.randn(*points.shape)  # ref ConvoNet
        # the occupancies are bit-packed; keep one value per point
        occus = np.unpackbits(occus)[:points.shape[0]]

        rand_idx = np.random.choice(points.shape[0], size=self.occu_sample_num)
        points = torch.from_numpy(points[rand_idx]).float()
        occus = torch.from_numpy(occus[rand_idx]).float()
        # 1 - outside shapes; 0 - inside shapes
        occus = 1 - occus  # to be consistent with ShapeNet

        # The number of points inside shapes is roughly 1.2% of points outside
        # shapes, we set weights to 10 for points inside shapes.
        weight = torch.ones_like(occus)
        weight[occus < 0.5] = self.pos_weight

        # The points are not on the surfaces, the gradients are 0
        grad = torch.zeros_like(points)
        return {'pos': points, 'occu': occus, 'weight': weight, 'grad': grad}

    def sample_surface(self, sample):
        """Randomly pick `surface_sample_num` points on the surface.

        Returns a dict with the same keys as `sample_occu`; the occupancy of
        a surface point is 0.5 and its gradient is the point normal.
        """
        # get the input TODO: use normals
        points, normals = sample['points'], sample['normals']

        # sample points; cast to float32 like `sample_occu` does, so that the
        # two dicts can be concatenated without relying on type promotion
        rand_idx = np.random.choice(points.shape[0], size=self.surface_sample_num)
        pos = torch.from_numpy(points[rand_idx]).float()
        grad = torch.from_numpy(normals[rand_idx]).float()
        pos = pos / self.points_scale  # scale to [-1.0, 1.0]
        occus = torch.ones(self.surface_sample_num) * 0.5
        weight = torch.ones(self.surface_sample_num) * 2.0  # TODO: tune this scale

        return {'pos': pos, 'occu': occus, 'weight': weight, 'grad': grad}

    def __call__(self, sample, idx):
        """Transform one sample; `idx` is accepted but not used here."""
        output = self.process_points_cloud(sample['point_cloud'])

        # sample ground truth occupancies
        if self.flags.load_occu:
            occus = self.sample_occu(sample['occus'])

            if self.sample_surf_points:
                # append the surface samples to the occupancy samples
                surface_occus = self.sample_surface(sample['point_cloud'])
                for key in occus.keys():
                    occus[key] = torch.cat([occus[key], surface_occus[key]], dim=0)

            output.update(occus)

        return output
82 |
83 |
class ReadFile:
    """Load one randomly chosen point cloud (and occupancy file) of a scene.

    Args:
        load_occu (bool): If True, an occupancy file is loaded as well.
    """

    def __init__(self, load_occu=False):
        self.load_occu = load_occu
        self.num_files = 10  # number of numbered files to choose from

    def __call__(self, filename):
        """Read the sample stored in the folder `filename`.

        Returns a dict with the key `point_cloud` (`points`, `normals`) and,
        if `load_occu` is set, the key `occus` (`points`, `occupancies`).
        """
        num = np.random.randint(self.num_files)
        filename_pc = os.path.join(
            filename, 'pointcloud/pointcloud_%02d.npz' % num)
        # the archives are closed right after the arrays have been read
        with np.load(filename_pc) as raw:
            point_cloud = {'points': raw['points'], 'normals': raw['normals']}
        output = {'point_cloud': point_cloud}

        if self.load_occu:
            num = np.random.randint(self.num_files)
            filename_occu = os.path.join(
                filename, 'points_iou/points_iou_%02d.npz' % num)
            with np.load(filename_occu) as raw:
                occus = {'points': raw['points'],
                         'occupancies': raw['occupancies']}
            output['occus'] = occus

        return output
106 |
107 |
def get_synthetic_room_dataset(flags):
    """Create the synthetic-room dataset described by `flags`.

    Returns the dataset together with the collate function to batch it.
    """
    scene_transform = TransformScene(flags)
    scene_reader = ReadFile(flags.load_occu)
    return Dataset(flags.location, flags.filelist, scene_transform,
                   read_file=scene_reader,
                   in_memory=flags.in_memory), collate_func
114 |
--------------------------------------------------------------------------------
/models/graph_lenet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import ocnn
10 |
11 | from .modules import GraphConvBnRelu, GraphDownsample, GraphMaxpool
12 | from .dual_octree import DualOctree
13 |
14 |
class GraphLeNet(torch.nn.Module):
    """LeNet built from graph convolutions and octree max-pooling.

    It exists for the comparison with the original O-CNN.
    """

    def __init__(self, depth, channel_in, nout):
        super().__init__()
        self.depth = depth
        self.channel_in = channel_in

        channels = [2 ** max(9 - i, 2) for i in range(depth + 1)]
        channels.append(channel_in)
        levels = list(range(depth, 2, -1))  # depths processed by the network

        conv_layers = [
            GraphConvBnRelu(channels[d + 1], channels[d], n_edge_type=7,
                            avg_degree=7)
            for d in levels]
        self.convs = torch.nn.ModuleList(conv_layers)

        pool_layers = [ocnn.OctreeMaxPool(d) for d in levels]
        self.pools = torch.nn.ModuleList(pool_layers)

        self.octree2voxel = ocnn.FullOctree2Voxel(2)
        self.header = torch.nn.Sequential(
            torch.nn.Dropout(p=0.5),                       # drop1
            ocnn.FcBnRelu(channels[3] * 64, channels[2]),  # fc1
            torch.nn.Dropout(p=0.5),                       # drop2
            torch.nn.Linear(channels[2], nout))            # fc2

    def forward(self, octree):
        """Run the network on `octree` and return the output of the header."""
        # fetch the input feature stored at the finest depth
        feature = ocnn.octree_property(octree, 'feature', self.depth)
        assert feature.size(1) == self.channel_in

        # build the dual octree that provides the graphs
        doctree = DualOctree(octree)
        doctree.post_processing_for_ocnn()

        for i, d in enumerate(range(self.depth, 2, -1)):
            # graph convolutions work on a 2-D tensor
            feature = feature.squeeze().t()
            edge_idx = doctree.graph[d]['edge_idx']
            edge_type = doctree.graph[d]['edge_type']
            feature = self.convs[i](feature, edge_idx, edge_type)

            # downsampling: back to the 4-D layout used by the octree pooling
            feature = feature.t().unsqueeze(0).unsqueeze(-1)
            feature = self.pools[i](feature, octree)

        # classification head
        voxels = self.octree2voxel(feature)
        return self.header(voxels)
61 |
62 |
class DualGraphLeNet(torch.nn.Module):
    """LeNet that runs graph convolutions and pooling on the dual octree."""

    def __init__(self, depth, channel_in, nout):
        super().__init__()
        self.depth = depth
        self.channel_in = channel_in

        channels = [2 ** max(9 - i, 2) for i in range(depth + 1)]
        channels.append(channel_in)
        levels = list(range(depth, 2, -1))  # depths processed by the network

        conv_layers = [
            GraphConvBnRelu(channels[d + 1], channels[d], n_edge_type=7,
                            avg_degree=7, n_node_type=d-1)
            for d in levels]
        self.convs = torch.nn.ModuleList(conv_layers)

        # self.downsample = torch.nn.ModuleList([
        #     GraphDownsample(channels[d]) for d in range(depth, 2, -1)])
        pool_layers = [GraphMaxpool() for _ in levels]
        self.downsample = torch.nn.ModuleList(pool_layers)

        self.octree2voxel = ocnn.FullOctree2Voxel(2)
        self.header = torch.nn.Sequential(
            torch.nn.Dropout(p=0.5),                       # drop1
            ocnn.FcBnRelu(channels[3] * 64, channels[2]),  # fc1
            torch.nn.Dropout(p=0.5),                       # drop2
            torch.nn.Linear(channels[2], nout))            # fc2

    def forward(self, octree):
        """Run the network on `octree` and return the output of the header."""
        # build the dual octree
        doctree = DualOctree(octree)
        doctree.post_processing_for_docnn()

        # the input feature is provided by the dual octree
        feature = doctree.get_input_feature()
        assert feature.size(1) == self.channel_in

        for i, d in enumerate(range(self.depth, 2, -1)):
            # graph convolution on the graph of depth d
            edge_idx = doctree.graph[d]['edge_idx']
            edge_type = doctree.graph[d]['edge_dir']
            node_type = doctree.graph[d]['node_type']
            feature = self.convs[i](feature, edge_idx, edge_type, node_type)

            # downsampling to depth d - 1
            nnum = doctree.nnum[d]
            lnum = doctree.lnum[d-1]
            leaf_mask = doctree.node_child(d-1) < 0
            feature = self.downsample[i](feature, leaf_mask, nnum, lnum)

        # classification head: switch to the 4-D layout first
        feature = feature.t().unsqueeze(0).unsqueeze(-1)
        voxels = self.octree2voxel(feature)
        return self.header(voxels)
114 |
--------------------------------------------------------------------------------
/models/graph_ae.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import ocnn
9 | import torch
10 | import torch.nn
11 |
12 | from . import modules
13 | from . import dual_octree
14 | from . import graph_ounet
15 |
16 |
class GraphAE(graph_ounet.GraphOUNet):
  r''' An autoencoder built on top of :class:`graph_ounet.GraphOUNet`.

  The encoder output at :attr:`full_depth` is projected to a compact code with
  :attr:`code_channel` channels and squashed into `[-1, 1]` with `tanh`; the
  decoder expands that code again and predicts, for every depth, the octree
  splitting logits and the regressed signal.
  '''

  def __init__(self, depth, channel_in, nout, full_depth=2, depth_out=6,
               resblk_type='bottleneck', bottleneck=4):
    super().__init__(depth, channel_in, nout, full_depth, depth_out,
                     resblk_type, bottleneck)
    # this is to make the encoder and decoder symmetric: one decoder stage
    # per depth in [full_depth, depth], indexed by `i` in `octree_decoder`
    n_edge_type, avg_degree = 7, 7
    self.decoder = torch.nn.ModuleList(
        [modules.GraphResBlocks(self.channels[d], self.channels[d],
         self.resblk_num[d], bottleneck, n_edge_type, avg_degree, d-1, resblk_type)
         for d in range(full_depth, depth + 1)])

    self.code_channel = 8
    self.code_dim = self.code_channel * 2 ** (3 * self.full_depth)
    channel_in = self.channels[self.full_depth]
    self.project1 = modules.Conv1x1Bn(channel_in, self.code_channel)
    self.project2 = modules.Conv1x1BnRelu(self.code_channel, channel_in)

  def octree_encoder(self, octree, doctree):
    r''' Encodes the input octree and returns the latent code.

    The features produced by the parent encoder at :attr:`full_depth` are
    projected to :attr:`code_channel` channels and passed through `tanh`.
    '''

    convs = super().octree_encoder(octree, doctree)
    # reduce the dimension
    code = self.project1(convs[self.full_depth])
    # constrain the code in [-1, 1]
    code = torch.tanh(code)
    return code

  def octree_decoder(self, latent_code, doctree_out, doctree=None, update_octree=False):
    r''' Decodes :attr:`latent_code` into per-depth predictions.

    Args:
      latent_code: The code returned by :meth:`octree_encoder`.
      doctree_out: The dual octree the decoder runs on.
      doctree: Unused; kept so the signature matches the parent class.
      update_octree: If True, the output octree is split/grown according to
          the predicted labels after every depth.

    Returns:
      A tuple `(logits, reg_voxs, octree)`: two dicts keyed by depth and the
      octree held by the final dual octree.
    '''

    logits = dict()
    reg_voxs = dict()
    deconvs = dict()

    deconvs[self.full_depth] = self.project2(latent_code)
    for i, d in enumerate(range(self.full_depth, self.depth_out+1)):
      if d > self.full_depth:
        nnum = doctree_out.nnum[d-1]
        leaf_mask = doctree_out.node_child(d-1) < 0
        deconvs[d] = self.upsample[i-1](deconvs[d-1], leaf_mask, nnum)

      edge_idx = doctree_out.graph[d]['edge_idx']
      edge_type = doctree_out.graph[d]['edge_dir']
      node_type = doctree_out.graph[d]['node_type']
      deconvs[d] = self.decoder[i](deconvs[d], edge_idx, edge_type, node_type)

      # predict the splitting label
      logit = self.predict[i](deconvs[d])
      nnum = doctree_out.nnum[d]
      logits[d] = logit[-nnum:]

      # update the octree according to predicted labels
      if update_octree:
        label = logits[d].argmax(1).to(torch.int32)
        octree_out = doctree_out.octree
        octree_out = ocnn.octree_update(octree_out, label, d, split=1)
        if d < self.depth_out:
          octree_out = ocnn.octree_grow(octree_out, target_depth=d+1)
        doctree_out = dual_octree.DualOctree(octree_out)
        doctree_out.post_processing_for_docnn()

      # predict the signal
      reg_vox = self.regress[i](deconvs[d])

      # TODO: improve it
      # pad zeros to reg_vox to reuse the original code for ocnn
      node_mask = doctree_out.graph[d]['node_mask']
      shape = (node_mask.shape[0], reg_vox.shape[1])
      # `new_zeros` inherits dtype and device from `reg_vox`, so the masked
      # assignment below also works when the network does not run in float32
      reg_vox_pad = reg_vox.new_zeros(shape)
      reg_vox_pad[node_mask] = reg_vox
      reg_voxs[d] = reg_vox_pad

    return logits, reg_voxs, doctree_out.octree

  def extract_code(self, octree_in):
    r''' Builds the dual octree of :attr:`octree_in` and returns its code.
    '''

    doctree_in = dual_octree.DualOctree(octree_in)
    doctree_in.post_processing_for_docnn()

    code = self.octree_encoder(octree_in, doctree_in)
    return code

  def decode_code(self, code):
    r''' Decodes :attr:`code` starting from a full octree at :attr:`full_depth`.

    Returns:
      A dict with the keys `logits`, `reg_voxs`, `octree_out` and
      `neural_mpu`; the last one is a callable mapping query positions to the
      prediction at :attr:`depth_out`.
    '''

    # generate dual octrees
    octree_out = ocnn.create_full_octree(self.full_depth, self.nout)
    doctree_out = dual_octree.DualOctree(octree_out)
    doctree_out.post_processing_for_docnn()

    # run the decoder only; the octree is grown from the predicted labels
    out = self.octree_decoder(code, doctree_out, update_octree=True)
    output = {'logits': out[0], 'reg_voxs': out[1], 'octree_out': out[2]}

    # create the mpu wrapper
    def _neural_mpu(pos):
      pred = self.neural_mpu(pos, out[1], out[2])
      return pred[self.depth_out][0]
    output['neural_mpu'] = _neural_mpu

    return output
113 |
--------------------------------------------------------------------------------
/dualocnn.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import torch
10 | import numpy as np
11 |
12 | import builder
13 | import utils
14 | from solver import Solver, get_config
15 |
16 |
class DualOcnnSolver(Solver):
  r''' The solver used to train and evaluate the dual octree networks.
  '''

  def get_model(self, flags):
    return builder.get_model(flags)

  def get_dataset(self, flags):
    return builder.get_dataset(flags)

  def batch_to_cuda(self, batch):
    r''' Moves the known tensor entries of :attr:`batch` to the GPU in place.
    '''

    keys = ['octree_in', 'octree_gt', 'pos', 'sdf', 'grad', 'weight', 'occu']
    for key in keys:
      if key in batch:
        batch[key] = batch[key].cuda()
    # every key above is optional, so only enable gradients w.r.t. the query
    # positions when they are actually present
    if 'pos' in batch:
      batch['pos'].requires_grad_()

  def compute_loss(self, batch, model_out):
    flags = self.FLAGS.LOSS
    loss_func = builder.get_loss_function(flags)
    output = loss_func(batch, model_out, flags.loss_type)
    return output

  def model_forward(self, batch):
    r''' Runs the model on :attr:`batch` and returns the loss dict.

    The entry `loss` is the sum of all entries whose key contains `loss`.
    '''

    self.batch_to_cuda(batch)
    model_out = self.model(batch['octree_in'], batch['octree_gt'], batch['pos'])

    output = self.compute_loss(batch, model_out)
    losses = [val for key, val in output.items() if 'loss' in key]
    output['loss'] = torch.sum(torch.stack(losses))
    return output

  def train_step(self, batch):
    output = self.model_forward(batch)
    output = {'train/' + key: val for key, val in output.items()}
    return output

  def test_step(self, batch):
    output = self.model_forward(batch)
    output = {'test/' + key: val for key, val in output.items()}
    return output

  def extract_mesh(self, neural_mpu, filename, bbox=None):
    r''' Extracts a mesh from :attr:`neural_mpu` and saves it to
    :attr:`filename`.

    Args:
      neural_mpu: The callable forwarded to `utils.create_mesh`.
      filename: The output filename.
      bbox: Optional sequence of 6 values: the first 3 are used as `bbmin`
          and the last 3 as `bbmax`. If None, `[-sdf_scale, sdf_scale]` from
          the solver flags is used.
    '''

    # bbox used for marching cubes
    if bbox is not None:
      bbmin, bbmax = bbox[:3], bbox[3:]
    else:
      sdf_scale = self.FLAGS.SOLVER.sdf_scale
      bbmin, bbmax = -sdf_scale, sdf_scale

    # create mesh
    utils.create_mesh(neural_mpu, filename,
                      size=self.FLAGS.SOLVER.resolution,
                      bbmin=bbmin, bbmax=bbmax,
                      mesh_scale=self.FLAGS.DATA.test.point_scale,
                      save_sdf=self.FLAGS.SOLVER.save_sdf)

  def eval_step(self, batch):
    r''' Reconstructs the mesh of the first sample in :attr:`batch` and saves
    it, together with the input point cloud, under the log folder.
    '''

    # forward the model
    output = self.model.forward(batch['octree_in'].cuda())

    # extract the mesh
    # `splitext` only removes a suffix of the last path component, so a dot
    # inside a folder name no longer truncates the path
    name = os.path.splitext(batch['filename'][0])[0]
    filename = os.path.join(self.logdir, name + '.obj')
    folder = os.path.dirname(filename)
    if folder:
      os.makedirs(folder, exist_ok=True)  # no exists/makedirs race
    bbox = batch['bbox'][0].numpy() if 'bbox' in batch else None
    self.extract_mesh(output['neural_mpu'], filename, bbox)

    # save the input point cloud
    filename = filename[:-4] + '.input.ply'
    utils.points2ply(filename, batch['points_in'][0].cpu(),
                     self.FLAGS.DATA.test.point_scale)

  def save_tensors(self, batch, output):
    r''' Dumps the input/output/ground-truth octrees and points as raw binary
    files named after `batch['iter_num']`.
    '''

    iter_num = batch['iter_num']
    filename = os.path.join(self.logdir, '%04d.out.octree' % iter_num)
    output['octree_out'].cpu().numpy().tofile(filename)
    filename = os.path.join(self.logdir, '%04d.in.octree' % iter_num)
    batch['octree_in'].cpu().numpy().tofile(filename)
    filename = os.path.join(self.logdir, '%04d.in.points' % iter_num)
    batch['points_in'][0].cpu().numpy().tofile(filename)
    filename = os.path.join(self.logdir, '%04d.gt.octree' % iter_num)
    batch['octree_gt'].cpu().numpy().tofile(filename)
    filename = os.path.join(self.logdir, '%04d.gt.points' % iter_num)
    batch['points_gt'][0].cpu().numpy().tofile(filename)

  @classmethod
  def update_configs(cls):
    r''' Registers the project-specific options on the global config.
    '''

    FLAGS = get_config()
    FLAGS.SOLVER.resolution = 128       # the resolution used for marching cubes
    FLAGS.SOLVER.save_sdf = False       # save the sdfs in evaluation
    FLAGS.SOLVER.sdf_scale = 0.9        # the scale of sdfs

    FLAGS.DATA.train.point_scale = 0.5  # the scale of point clouds
    FLAGS.DATA.train.load_sdf = True    # load sdf samples
    FLAGS.DATA.train.load_occu = False  # load occupancy samples
    FLAGS.DATA.train.point_sample_num = 10000
    FLAGS.DATA.train.sample_surf_points = False

    # FLAGS.MODEL.skip_connections = True
    FLAGS.DATA.test = FLAGS.DATA.train.clone()
    FLAGS.LOSS.loss_type = 'sdf_reg_loss'
120 |
121 |
# Script entry point: hand control over to the solver's `main` classmethod.
if __name__ == '__main__':
  DualOcnnSolver.main()
124 |
--------------------------------------------------------------------------------
/models/graph_slounet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import ocnn
10 | import torch.nn
11 | import torch.nn.functional as F
12 |
13 | from . import mpu
14 | from . import modules
15 | from . import graph_ounet
16 | from . import dual_octree
17 |
18 |
class GraphSLDownsample(modules.GraphDownsample):
  r''' A downsampling layer that keeps only the features of non-leaf nodes;
  the rows of leaf nodes at the coarser layer are filled with zeros.
  '''

  def forward(self, x, leaf_mask, numd, lnumd):
    r''' Downsamples the last :attr:`numd` rows of :attr:`x`.

    Args:
      x: The input features; the nodes of the current depth are the last
          :attr:`numd` rows.
      leaf_mask: A boolean mask over the nodes at layer (depth-1); True marks
          a leaf node.
      numd: The number of nodes at the current depth.
      lnumd: Unused; kept so the signature matches the parent class.
    '''

    # downsample nodes at layer depth
    outd = x[-numd:]
    outd = self.downsample(outd)

    # get the nodes at layer (depth-1)
    # allocate with the dtype/device of `outd` so that the masked assignment
    # is valid even when the features are not float32
    out = outd.new_zeros((leaf_mask.shape[0], x.shape[1]))
    out[leaf_mask.logical_not()] = outd

    if self.channels_in != self.channels_out:
      out = self.conv1x1(out)
    return out
33 |
34 |
class GraphSLUpsample(modules.GraphUpsample):
  r''' An upsampling layer that only expands the non-leaf nodes.
  '''

  def forward(self, x, leaf_mask, numd):
    r''' Upsamples the non-leaf nodes among the last :attr:`numd` rows of
    :attr:`x`; `leaf_mask` is True for the leaf nodes at layer (depth-1).
    '''

    # select the non-leaf nodes at layer (depth-1) and upsample them
    non_leaf = leaf_mask.logical_not()
    feats = self.upsample(x[-numd:][non_leaf])

    # adapt the channel number only when it actually changes
    if self.channels_in == self.channels_out:
      return feats
    return self.conv1x1(feats)
46 |
47 |
class GraphSLOUNet(graph_ounet.GraphOUNet):
  r''' A variant of :class:`graph_ounet.GraphOUNet` that uses
  :class:`GraphSLDownsample` / :class:`GraphSLUpsample` and requests input
  features with `all_leaf_nodes=False`.
  '''

  def __init__(self, depth, channel_in, nout, full_depth=2, depth_out=6,
               resblk_type='bottleneck', bottleneck=4):
    super().__init__(depth, channel_in, nout, full_depth, depth_out,
                     resblk_type, bottleneck)

    self.downsample = torch.nn.ModuleList(
        [GraphSLDownsample(self.channels[d], self.channels[d-1])
         for d in range(depth, full_depth, -1)])

    self.upsample = torch.nn.ModuleList(
        [GraphSLUpsample(self.channels[d-1], self.channels[d])
         for d in range(full_depth+1, depth+1)])

  def _get_input_feature(self, doctree):
    return doctree.get_input_feature(all_leaf_nodes=False)

  def octree_decoder(self, convs, doctree_out, doctree, update_octree=False):
    r''' Decodes the encoder features :attr:`convs` into per-depth predictions.

    Args:
      convs: A dict of encoder features keyed by depth.
      doctree_out: The dual octree the decoder runs on.
      doctree: The dual octree of the input, used to align skip connections.
      update_octree: If True, the output octree is split/grown according to
          the predicted labels after every depth.

    Returns:
      A tuple `(logits, reg_voxs, octree)`. `reg_voxs[d]` is the concatenation
      of the regressed signals of all depths up to `d`.
    '''

    logits = dict()
    reg_voxs = dict()
    deconvs = dict()
    reg_voxs_list = []

    deconvs[self.full_depth] = convs[self.full_depth]
    for i, d in enumerate(range(self.full_depth, self.depth_out+1)):
      if d > self.full_depth:
        nnum = doctree_out.nnum[d-1]
        leaf_mask = doctree_out.node_child(d-1) < 0
        deconvd = self.upsample[i-1](deconvs[d-1], leaf_mask, nnum)
        skip = modules.doctree_align(
            convs[d], doctree.graph[d]['keyd'], doctree_out.graph[d]['keyd'])
        deconvd = deconvd + skip  # skip connections

        # the graph convolutions consume `deconvd`, which only exists for
        # depths above `full_depth`, hence they live inside this branch
        edge_idx = doctree_out.graph[d]['edge_idx']
        edge_type = doctree_out.graph[d]['edge_dir']
        node_type = doctree_out.graph[d]['node_type']
        deconvs[d] = self.decoder[i-1](deconvd, edge_idx, edge_type, node_type)

      # predict the splitting label
      logit = self.predict[i](deconvs[d])
      nnum = doctree_out.nnum[d]
      logits[d] = logit[-nnum:]

      # update the octree according to predicted labels
      if update_octree:
        label = logits[d].argmax(1).to(torch.int32)
        octree_out = doctree_out.octree
        octree_out = ocnn.octree_update(octree_out, label, d, split=1)
        if d < self.depth_out:
          octree_out = ocnn.octree_grow(octree_out, target_depth=d+1)
        doctree_out = dual_octree.DualOctree(octree_out)
        doctree_out.post_processing_for_ocnn()

      # predict the signal
      reg_vox = self.regress[i](deconvs[d])
      reg_voxs_list.append(reg_vox)
      reg_voxs[d] = torch.cat(reg_voxs_list, dim=0)

    return logits, reg_voxs, doctree_out.octree

  def forward(self, octree_in, octree_out=None, pos=None):
    r''' Runs the encoder and the decoder.

    Args:
      octree_in: The input octree.
      octree_out: The ground-truth output octree. If None, the output octree
          is grown from a full octree according to the predicted labels.
      pos: Optional query positions; if given, the MPU predictions are added
          to the output under the key `mpus`.

    Returns:
      A dict with the keys `logits`, `reg_voxs`, `octree_out` and optionally
      `mpus`.
    '''

    # generate dual octrees
    doctree_in = dual_octree.DualOctree(octree_in)
    doctree_in.post_processing_for_ocnn()

    update_octree = octree_out is None
    if update_octree:
      octree_out = ocnn.create_full_octree(self.full_depth, self.nout)
    doctree_out = dual_octree.DualOctree(octree_out)
    doctree_out.post_processing_for_ocnn()

    # run encoder and decoder
    convs = self.octree_encoder(octree_in, doctree_in)
    out = self.octree_decoder(convs, doctree_out, doctree_in, update_octree)
    output = {'logits': out[0], 'reg_voxs': out[1], 'octree_out': out[2]}

    # mpus
    if pos is not None:
      # use the octree returned by the decoder: when `update_octree` is True
      # the local `octree_out` is still the initial full octree, whereas
      # `reg_voxs` refers to the nodes of the grown octree `out[2]`
      output['mpus'] = self.neural_mpu(pos, out[1], out[2])

    return output
130 |
--------------------------------------------------------------------------------
/tools/dfaust.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import argparse
10 | import trimesh
11 | import trimesh.sample
12 | import numpy as np
13 | import time
14 | import zipfile
15 | import wget
16 | from tqdm import tqdm
17 |
18 |
# Command-line options. They are parsed at module level, i.e. as soon as this
# file is executed or imported.
# NOTE(review): the default of `--run` is 'prepare_dataset', but no function
# with that name is defined in this file (see `generate_dataset` and
# `download_dataset`), so running without `--run` looks like it fails - confirm.
parser = argparse.ArgumentParser()
parser.add_argument('--run', type=str, default='prepare_dataset')
parser.add_argument('--filelist', type=str, default='test.txt')
parser.add_argument('--mesh_folder', type=str, default='logs/dfaust/mesh')
parser.add_argument('--output_folder', type=str, default='logs/dfaust/mesh')
args = parser.parse_args()


# The project folder is the parent of the folder containing this file; all
# dataset files live under `<project>/data/dfaust`.
project_folder = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
root_folder = os.path.join(project_folder, 'data/dfaust')
# Scale applied to the centered points in `sample_points` and undone in
# `rescale_mesh`.
shape_scale = 0.8
30 |
31 |
def create_flag_file(filename):
  r''' Creates a flag file to indicate that some time-consuming work has been
  done.

  The parent folder is created if needed. The file contains the text
  `succ @ ` followed by the current time.
  '''

  folder = os.path.dirname(filename)
  if folder:  # a bare filename has no folder to create
    # `exist_ok` avoids the race between checking and creating the folder
    os.makedirs(folder, exist_ok=True)
  with open(filename, 'w') as fid:
    fid.write('succ @ ' + time.ctime())
42 |
43 |
def check_folder(filenames: list):
  r''' Makes sure that the parent folder of every filename exists, creating
  it if necessary. The files themselves are not created.
  '''

  for filename in filenames:
    folder = os.path.dirname(filename)
    if folder:  # a bare filename lives in the current folder
      # `exist_ok` avoids the race between checking and creating the folder
      os.makedirs(folder, exist_ok=True)
52 |
53 |
def get_filenames(filelist, root_folder):
  r''' Gets filenames from a filelist.

  The filelist is read from `<root_folder>/filelist/<filelist>`. The first
  whitespace-separated token of every line is returned; blank lines are
  skipped instead of raising an `IndexError`.
  '''

  filelist = os.path.join(root_folder, 'filelist', filelist)
  with open(filelist, 'r') as fid:
    tokens = [line.split() for line in fid]
  filenames = [token[0] for token in tokens if token]
  return filenames
63 |
64 |
def download_filelist():
  r''' Downloads the filelists used for learning.

  The archive is saved as `<root_folder>/filelist.zip`, extracted into
  `<root_folder>/filelist` and removed afterwards.
  '''

  print('-> Download the filelist.')
  url = 'https://www.dropbox.com/s/vxkpaz3umzjvi66/dfaust.filelist.zip?dl=1'
  filename = os.path.join(root_folder, 'filelist.zip')
  # the data folder does not exist on a fresh checkout; create it before
  # the download tries to write into it
  check_folder([filename])
  wget.download(url, filename, bar=None)

  folder = os.path.join(root_folder, 'filelist')
  with zipfile.ZipFile(filename, 'r') as zip_ref:
    zip_ref.extractall(path=folder)
  os.remove(filename)
78 |
79 |
def sample_points():
  r''' Samples points with normals from the raw scans for training.

  For every entry of `all.txt`, the scan `scans/<name>.ply` is sampled,
  centered and scaled by `shape_scale`; the points and normals are saved to
  `dataset/<name>.npy` and the center to `dataset/<name>.center.npy`.
  '''

  num_samples = 100000
  print('-> Sample points.')
  src_folder = os.path.join(root_folder, 'scans')
  dst_folder = os.path.join(root_folder, 'dataset')
  for name in tqdm(get_filenames('all.txt', root_folder), ncols=80):
    path_ply = os.path.join(src_folder, name + '.ply')
    path_pts = os.path.join(dst_folder, name + '.npy')
    # `xxx.npy` -> `xxx.center.npy`
    path_center = path_pts[:-3] + 'center.npy'
    check_folder([path_pts])

    # draw surface samples and look up the normal of the face each one hit
    mesh = trimesh.load(path_ply)
    samples, face_idx = trimesh.sample.sample_surface(mesh, num_samples)
    normals = mesh.face_normals[face_idx]

    # move the centroid of the samples to the origin, then scale
    center = np.mean(samples, axis=0, keepdims=True)
    samples = (samples - center) * shape_scale
    merged = np.concatenate((samples, normals), axis=-1)

    # write the results
    np.save(path_pts, merged.astype(np.float32))
    np.save(path_center, center)
108 |
109 |
def rescale_mesh():
  r''' Maps the generated meshes back to the frame of the raw scans so that
  evaluation metrics can be computed.

  The scaling by `shape_scale` and the centering stored in
  `dataset/<name>.center.npy` are undone.
  '''

  for entry in tqdm(get_filenames(args.filelist, root_folder), ncols=80):
    stem = entry[:-4]  # drop the last 4 characters of the filelist entry
    path_out = os.path.join(args.output_folder, stem + '.obj')
    path_in = os.path.join(args.mesh_folder, stem + '.obj')
    path_center = os.path.join(root_folder, 'dataset', stem + '.center.npy')
    check_folder([path_out])

    center = np.load(path_center)
    mesh = trimesh.load(path_in)
    mesh.vertices = mesh.vertices / shape_scale + center
    mesh.export(path_out)
129 |
130 |
def generate_dataset():
  r''' Builds the dataset locally: downloads the filelists, then samples
  points from the raw scans.
  '''
  download_filelist()
  sample_points()
134 |
135 |
def download_dataset():
  r''' Downloads the filelists and the pre-built dataset.

  The dataset archive is downloaded and extracted only once; a flag file
  records that it has been done.
  '''

  download_filelist()

  print('-> Download the dataset.')
  flag_file = os.path.join(root_folder, 'flags/download_dataset_succ')
  if os.path.exists(flag_file):
    return  # already downloaded and extracted

  url = 'https://www.dropbox.com/s/eb5uk8f2fqswhs3/dfaust.dataset.zip?dl=1'
  archive = os.path.join(root_folder, 'dfaust.dataset.zip')
  wget.download(url, archive, bar=None)

  with zipfile.ZipFile(archive, 'r') as zip_ref:
    zip_ref.extractall(path=root_folder)
  # the archive is intentionally kept on disk
  create_flag_file(flag_file)
150 |
151 |
# Script entry point: call the function named by `--run`.
# NOTE(review): `eval` executes the `--run` string as arbitrary Python code;
# this is only acceptable because the value comes from the user's own command
# line. Consider an explicit name -> function lookup instead.
if __name__ == '__main__':
  eval('%s()' % args.run)
154 |
--------------------------------------------------------------------------------
/models/mpu.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import ocnn
9 | import torch
10 | import torch.nn
11 |
12 | from .utils.spmm import spmm, modulated_spmm
13 |
# Number of neighboring octree nodes gathered per query point: the 8 corners
# produced by `get_linear_mask(dim=3)`.
kNN = 8
15 |
16 |
class ABS(torch.autograd.Function):
  r''' The absolute value with a custom gradient: the derivative at `0` is
  defined as `1` (instead of the `0` used by `torch.abs`).
  '''

  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)
    return torch.abs(input)

  @staticmethod
  def backward(ctx, grad_in):
    (saved,) = ctx.saved_tensors
    # the slope is -1 for negative inputs and +1 otherwise (including zero)
    plus_one = torch.ones_like(saved)
    slope = torch.where(saved < 0, -plus_one, plus_one)
    return grad_in * slope
32 |
33 |
def linear_basis(x):
  r''' Evaluates `1 - |x|` element-wise, using :class:`ABS` for the absolute
  value so that the gradient at `0` follows its convention.
  '''

  magnitude = ABS.apply(x)
  return 1.0 - magnitude
36 |
37 |
def get_linear_mask(dim=3):
  r''' Returns the `2**dim` corner offsets of a unit cell.

  The result is a float32 tensor of shape `(2**dim, dim)` whose rows enumerate
  all combinations of `0` and `1` in lexicographic order, i.e. the last
  coordinate changes fastest.
  '''

  mask = torch.tensor([0, 1], dtype=torch.float32)
  # `cartesian_prod` yields the same row order as the former
  # `meshgrid` + `stack` combination without relying on the deprecated
  # implicit `indexing` default of `torch.meshgrid`. The final `view` keeps
  # the 2D shape for `dim=1`, where `cartesian_prod` returns a 1D tensor.
  mask = torch.cartesian_prod(*([mask] * dim)).view(-1, dim)
  return mask
43 |
44 |
def octree_linear_pts(octree, depth, pts):
  r''' Finds, for every query point, the 8 surrounding octree nodes at
  :attr:`depth` and the linear interpolation weight of each of them.

  Args:
    octree: The octree that is searched with `ocnn.octree_search_key`.
    depth: The octree depth at which the neighbors are looked up.
    pts: A tensor with 4 columns; it is split into 3 coordinate columns and
        1 id column. The coordinates are mapped from `[-1, 1]` to
        `[0, 2**depth]`. (presumably the id is the batch index - confirm)

  Returns:
    A dict with the keys `idx` (indices of the found nodes), `xyzf` (offsets
    of the points relative to the node, rescaled back to the scale of
    :attr:`pts`), `weights` (interpolation weights) and `flgs` (a boolean mask
    over all `N*8` candidates marking the ones that were kept).
  '''

  # get neigh coordinates
  scale = 2 ** depth
  mask = get_linear_mask(dim=3).to(pts.device)
  xyzf, ids = torch.split(pts, [3, 1], 1)
  xyzf = (xyzf + 1.0) * (scale / 2.0)    # [-1, 1] -> [0, scale]
  xyzf = xyzf - 0.5                      # the code is defined on the center
  xyzi = torch.floor(xyzf).detach()      # the integer part (N, 3), use floor
  corners = xyzi.unsqueeze(1) + mask     # (N, 8, 3)
  coordsf = xyzf.unsqueeze(1) - corners  # (N, 8, 3), in [-1.0, 1.0]

  # corners -> key
  ids = ids.detach().repeat(1, kNN).unsqueeze(-1)              # (N, 8, 1)
  key = torch.cat([corners, ids], dim=-1).view(-1, 4).short()  # (N*8, 4)
  key = ocnn.octree_encode_key(key).long()                     # (N*8, )
  idx = ocnn.octree_search_key(key, octree, depth, key_is_xyz=True)

  # corners -> flags
  # a candidate is kept only if its corner lies inside the volume and the
  # key search returned a non-negative index
  valid = torch.logical_and(corners > -1, corners < scale)  # out-of-bound
  valid = torch.all(valid, dim=-1).view(-1)
  flgs = torch.logical_and(idx > -1, valid)

  # remove invalid pts
  idx = idx[flgs].long()               # (N*8, ) -> (N', )
  coordsf = coordsf.view(-1, 3)[flgs]  # (N, 8, 3) -> (N', 3)

  # linear basis weights: the product of the per-axis weights
  weights = linear_basis(coordsf)                  # (N', 3)
  weights = torch.prod(weights, axis=-1).view(-1)  # (N', )
  # Here, the scale factor `2**(depth - 6)` is used to emphasize high-resolution
  # basis functions. Tune this factor further if needed! !!! NOTE !!!
  # weights = weights * 2**(depth - 6)  # used for shapenet
  weights = weights * (depth**2 / 50)  # testing

  # rescale back to the original scale
  # After rescaling, the coordsf is in the same scale as pts
  coordsf = coordsf * (2.0 / scale)   # [-1.0, 1.0] -> [-2.0/scale, 2.0/scale]
  return {'idx': idx, 'xyzf': coordsf, 'weights': weights, 'flgs': flgs}
83 |
84 |
def get_linear_pred(pts, octree, shape_code, neighs, depth_start, depth_end):
  r''' Blends the per-node predictions of the depths
  `[depth_start, depth_end]` into one value per query point.

  Args:
    pts: The query points; only the number of rows and the device are used.
    octree: The octree the neighbors in :attr:`neighs` refer to.
    shape_code: The per-node codes, one row per node, forwarded to
        `modulated_spmm`.
    neighs: A dict keyed by depth holding the outputs of
        :func:`octree_linear_pts`.
    depth_start: The first depth to blend.
    depth_end: The last depth to blend. For all shallower depths only leaf
        nodes contribute.

  Returns:
    A tuple `(output, mask)`: the weight-normalized prediction per point and a
    boolean mask telling whether the point is covered by at least one node at
    :attr:`depth_end`.
  '''

  npt = pts.size(0)
  indices, weights, xyzfs = [], [], []
  nnum_cum = ocnn.octree_property(octree, 'node_num_cum')
  # the index of the query point of each of the `npt * kNN` candidates
  ids = torch.arange(npt, device=pts.device, dtype=torch.long)
  ids = ids.unsqueeze(-1).repeat(1, kNN).view(-1)
  for d in range(depth_start, depth_end+1):
    neighd = neighs[d]
    idxd = neighd['idx']
    xyzfd = neighd['xyzf']
    weightd = neighd['weights']
    valid = neighd['flgs']
    idsd = ids[valid]

    if d < depth_end:
      child = ocnn.octree_property(octree, 'child', d)
      leaf = child[idxd] < 0  # keep only leaf nodes
      idsd, idxd, weightd, xyzfd = idsd[leaf], idxd[leaf], weightd[leaf], xyzfd[leaf]

    # shift the node index so that it addresses the rows of `shape_code`,
    # which start at `depth_start`
    idxd = idxd + (nnum_cum[d] - nnum_cum[depth_start])
    indices.append(torch.stack([idsd, idxd], dim=1))
    weights.append(weightd)
    xyzfs.append(xyzfd)

  indices = torch.cat(indices, dim=0).t()
  weights = torch.cat(weights, dim=0)
  xyzfs = torch.cat(xyzfs, dim=0)

  code_num = shape_code.size(0)
  output = modulated_spmm(indices, weights, npt, code_num, shape_code, xyzfs)
  # allocate the normalization vector on the device of the codes instead of
  # hard-coding `.cuda()`, so the function follows the device of its inputs
  ones = torch.ones(code_num, 1, device=shape_code.device)
  norm = spmm(indices, weights, npt, code_num, ones)
  output = torch.div(output, norm + 1e-8).squeeze()

  # whether the point is affected by an octree node at the layer `depth_end`
  mask = neighs[depth_end]['flgs'].view(-1, kNN).any(axis=-1)
  return output, mask
121 |
122 |
class NeuralMPU:
  r''' Evaluates the blended prediction at query positions for every depth
  between `full_depth` and `depth`.
  '''

  def __init__(self, full_depth, depth):
    self.full_depth = full_depth
    self.depth = depth

  def __call__(self, pos, reg_voxs, octree_out):
    r''' Returns a dict mapping each depth `d` to the pair returned by
    `get_linear_pred` when blending the depths `[full_depth, d]`.
    '''

    neighs, mpus = {}, {}
    start = self.full_depth
    for d in range(start, self.depth + 1):
      # the neighbors of shallower depths are reused by the deeper ones
      neighs[d] = octree_linear_pts(octree_out, d, pos)
      values, covered = get_linear_pred(
          pos, octree_out, reg_voxs[d], neighs, start, d)
      mpus[d] = (values, covered)
    return mpus
137 |
--------------------------------------------------------------------------------
/datasets/shapenet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import ocnn
10 | import torch
11 | import numpy as np
12 |
13 | from solver import Dataset
14 | from .utils import collate_func
15 |
16 |
class TransformShape:
  r''' Transforms a raw sample (as returned by `ReadFile`) into octrees and
  SDF training samples.
  '''

  def __init__(self, flags):
    self.flags = flags

    self.point_sample_num = 3000
    self.sdf_sample_num = 5000
    self.points_scale = 0.5  # the points are in [-0.5, 0.5]
    self.noise_std = 0.005
    self.points2octree = ocnn.Points2Octree(**flags)

  def process_points_cloud(self, sample):
    r''' Builds the input and ground-truth octrees from a point cloud.

    Args:
      sample: A dict with the numpy arrays `points` and `normals`.

    Returns:
      A dict with the keys `octree_in`, `points_in`, `octree_gt` and
      `points_gt`. Without `flags.distort` the input equals the ground truth.
    '''

    # get the input
    points, normals = sample['points'], sample['normals']
    points = points / self.points_scale  # scale to [-1.0, 1.0]

    # transform points to octree
    points_gt = ocnn.points_new(
        torch.from_numpy(points).float(), torch.from_numpy(normals).float(),
        torch.Tensor(), torch.Tensor())
    points_gt, _ = ocnn.clip_points(points_gt, [-1.0]*3, [1.0]*3)
    octree_gt = self.points2octree(points_gt)

    if self.flags.distort:
      # randomly sample points and add noise
      # Since the points are rescaled to [-1.0, 1.0] above, `noise_std` is
      # rescaled by the same factor here, so that it is always 0.5% of the
      # bounding box size.
      noise_std = self.noise_std / self.points_scale
      noise = noise_std * np.random.randn(self.point_sample_num, 3)
      rand_idx = np.random.choice(points.shape[0], size=self.point_sample_num)
      points_noise = points[rand_idx] + noise

      points_in = ocnn.points_new(
          torch.from_numpy(points_noise).float(), torch.Tensor(),
          torch.ones(self.point_sample_num).float(), torch.Tensor())
      points_in, _ = ocnn.clip_points(points_in, [-1.0]*3, [1.0]*3)
      octree_in = self.points2octree(points_in)
    else:
      points_in = points_gt
      octree_in = octree_gt

    # construct the output dict
    return {'octree_in': octree_in, 'points_in': points_in,
            'octree_gt': octree_gt, 'points_gt': points_gt}

  def sample_sdf(self, sample):
    r''' Randomly draws `sdf_sample_num` SDF samples (with replacement).

    Args:
      sample: A dict with the numpy arrays `points`, `sdf` and `grad`.
    '''

    sdf = sample['sdf']
    grad = sample['grad']
    points = sample['points'] / self.points_scale  # to [-1, 1]

    rand_idx = np.random.choice(points.shape[0], size=self.sdf_sample_num)
    points = torch.from_numpy(points[rand_idx]).float()
    sdf = torch.from_numpy(sdf[rand_idx]).float()
    grad = torch.from_numpy(grad[rand_idx]).float()
    return {'pos': points, 'sdf': sdf, 'grad': grad}

  def sample_on_surface(self, points, normals):
    r''' Randomly draws `sdf_sample_num` surface points (with replacement).

    The points are used as they are (no rescaling is applied here); their
    SDF is 0 and their gradient is the normal.
    '''

    rand_idx = np.random.choice(points.shape[0], size=self.sdf_sample_num)
    xyz = torch.from_numpy(points[rand_idx]).float()
    grad = torch.from_numpy(normals[rand_idx]).float()
    sdf = torch.zeros(self.sdf_sample_num)
    return {'pos': xyz, 'sdf': sdf, 'grad': grad}

  def sample_off_surface(self, xyz):
    r''' Randomly draws `sdf_sample_num` off-surface points (with
    replacement) with a dummy SDF of -1 and the normalized position as the
    gradient.
    '''

    xyz = xyz / self.points_scale  # to [-1, 1]

    rand_idx = np.random.choice(xyz.shape[0], size=self.sdf_sample_num)
    xyz = torch.from_numpy(xyz[rand_idx]).float()
    grad = xyz / (xyz.norm(p=2, dim=1, keepdim=True) + 1.0e-6)
    sdf = -1 * torch.ones(self.sdf_sample_num)  # dummy sdfs
    return {'pos': xyz, 'sdf': sdf, 'grad': grad}

  def __call__(self, sample, idx):
    r''' Transforms one raw sample.

    Args:
      sample: A dict with the key `point_cloud` and, depending on the flags,
          `sdf`; see `ReadFile`.
      idx: The index of the sample; unused.
    '''

    output = self.process_points_cloud(sample['point_cloud'])

    # sample ground truth sdfs
    if self.flags.load_sdf:
      sdf_samples = self.sample_sdf(sample['sdf'])
      output.update(sdf_samples)

    # sample on surface points and off surface points
    if self.flags.sample_surf_points:
      # The surface points live under `point_cloud` (there is no top-level
      # `points`/`normals` entry). They are rescaled to [-1, 1] like all
      # other positions, since `sample_on_surface` does not rescale.
      point_cloud = sample['point_cloud']
      on_surf = self.sample_on_surface(
          point_cloud['points'] / self.points_scale, point_cloud['normals'])
      off_surf = self.sample_off_surface(sample['sdf']['points'])  # TODO
      sdf_samples = {
          'pos': torch.cat([on_surf['pos'], off_surf['pos']], dim=0),
          'grad': torch.cat([on_surf['grad'], off_surf['grad']], dim=0),
          'sdf': torch.cat([on_surf['sdf'], off_surf['sdf']], dim=0)}
      output.update(sdf_samples)

    return output
110 |
111 |
class ReadFile:
  r''' Reads the `npz` files of one shape folder into a dict of numpy arrays.

  Args:
    load_sdf: If True, also read `sdf.npz`.
    load_occu: If True, also read `points.npz`.
  '''

  def __init__(self, load_sdf=False, load_occu=False):
    self.load_occu = load_occu
    self.load_sdf = load_sdf

  def __call__(self, filename):
    r''' Reads the folder :attr:`filename`.

    Returns:
      A dict with the key `point_cloud` (`points`, `normals`) and, if
      requested, `occu` (`points`, `occupancies`) and `sdf` (`points`, `grad`,
      `sdf`).
    '''

    # Each archive is opened in a `with` block so that its file handle is
    # closed right after the arrays have been read.
    filename_pc = os.path.join(filename, 'pointcloud.npz')
    with np.load(filename_pc) as raw:
      point_cloud = {'points': raw['points'], 'normals': raw['normals']}
    output = {'point_cloud': point_cloud}

    if self.load_occu:
      filename_occu = os.path.join(filename, 'points.npz')
      with np.load(filename_occu) as raw:
        occu = {'points': raw['points'], 'occupancies': raw['occupancies']}
      output['occu'] = occu

    if self.load_sdf:
      filename_sdf = os.path.join(filename, 'sdf.npz')
      with np.load(filename_sdf) as raw:
        sdf = {'points': raw['points'], 'grad': raw['grad'], 'sdf': raw['sdf']}
      output['sdf'] = sdf
    return output
135 |
136 |
def get_shapenet_dataset(flags):
  r''' Creates the ShapeNet dataset described by :attr:`flags` and returns it
  together with the collate function.
  '''

  shape_transform = TransformShape(flags)
  file_reader = ReadFile(load_sdf=flags.load_sdf, load_occu=flags.load_occu)
  return (
      Dataset(flags.location, flags.filelist, shape_transform,
              read_file=file_reader, in_memory=flags.in_memory),
      collate_func,
  )
143 |
--------------------------------------------------------------------------------
/tools/room.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import wget
10 | import time
11 | import zipfile
12 | import argparse
13 | import numpy as np
14 | from tqdm import tqdm
15 | from plyfile import PlyData, PlyElement
16 |
# Command-line options: `--run` is mandatory. The arguments are parsed at
# module level, i.e. as soon as this file is executed or imported.
parser = argparse.ArgumentParser()
parser.add_argument('--run', type=str, required=True)
args = parser.parse_args()

# The project folder is the parent of the folder containing this file; all
# dataset files live under `<project>/data/room`.
project_folder = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
root_folder = os.path.join(project_folder, 'data/room')
23 |
24 |
def create_flag_file(filename):
  r''' Creates a flag file to indicate that some time-consuming work has been
  done.

  The parent folder is created if needed. The file contains the text
  `succ @ ` followed by the current time.
  '''

  folder = os.path.dirname(filename)
  if folder:  # a bare filename has no folder to create
    # `exist_ok` avoids the race between checking and creating the folder
    os.makedirs(folder, exist_ok=True)
  with open(filename, 'w') as fid:
    fid.write('succ @ ' + time.ctime())
35 |
36 |
def check_folder(filenames: list):
  r''' Makes sure that the parent folder of every filename exists, creating
  it if necessary. The files themselves are not created.
  '''

  for filename in filenames:
    folder = os.path.dirname(filename)
    if folder:  # a bare filename lives in the current folder
      # `exist_ok` avoids the race between checking and creating the folder
      os.makedirs(folder, exist_ok=True)
45 |
46 |
def get_filenames(filelist, root_folder):
  r''' Gets filenames from a filelist.

  The filelist is read from `<root_folder>/filelist/<filelist>`. The first
  whitespace-separated token of every line is returned; blank lines are
  skipped instead of raising an `IndexError`.
  '''

  filelist = os.path.join(root_folder, 'filelist', filelist)
  with open(filelist, 'r') as fid:
    tokens = [line.split() for line in fid]
  filenames = [token[0] for token in tokens if token]
  return filenames
56 |
57 |
def download_and_unzip():
  r''' Downloads and unzips the synthetic room dataset.

  Both steps are guarded by their own flag file, so each of them runs only
  once.
  '''

  archive = os.path.join(root_folder, 'synthetic_room_dataset.zip')

  # step 1: download
  download_flag = os.path.join(root_folder, 'flags/download_room_dataset_succ')
  if not os.path.exists(download_flag):
    check_folder([archive])
    print('-> Download synthetic_room_dataset.zip.')
    url = 'https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/data/synthetic_room_dataset.zip'
    wget.download(url, archive)
    create_flag_file(download_flag)

  # step 2: unzip; the archive is intentionally kept on disk afterwards
  unzip_flag = os.path.join(root_folder, 'flags/unzip_succ')
  if not os.path.exists(unzip_flag):
    print('-> Unzip synthetic_room_dataset.zip.')
    with zipfile.ZipFile(archive, 'r') as zip_ref:
      zip_ref.extractall(path=root_folder)
    create_flag_file(unzip_flag)
78 |
79 |
def download_filelist():
  r''' Downloads the filelists used for learning.

  The archive is saved as `<root_folder>/filelist.zip` and extracted into
  `<root_folder>/filelist`. A flag file makes sure this is done only once.
  '''

  flag_file = os.path.join(root_folder, 'flags/download_filelist_succ')
  if not os.path.exists(flag_file):
    print('-> Download the filelist.')
    url = 'https://www.dropbox.com/s/30v6pdek6777vkr/room.filelist.zip?dl=1'
    filename = os.path.join(root_folder, 'filelist.zip')
    # this function may run before `download_and_unzip`, i.e. before the
    # data folder exists; create it before the download writes into it
    check_folder([filename])
    wget.download(url, filename, bar=None)

    folder = os.path.join(root_folder, 'filelist')
    with zipfile.ZipFile(filename, 'r') as zip_ref:
      zip_ref.extractall(path=folder)
    # os.remove(filename)
    create_flag_file(flag_file)
96 |
97 |
def download_ground_truth_mesh():
  r''' Downloads the archive of the ground-truth meshes.

  The archive is saved as `<root_folder>/room_watertight_mesh.zip`; it is not
  extracted here. A flag file makes sure the download runs only once.
  '''

  flag_file = os.path.join(root_folder, 'flags/download_mesh_succ')
  if not os.path.exists(flag_file):
    print('-> Download the ground-truth meshes.')
    url = 'https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/data/room_watertight_mesh.zip'
    # save under the archive's own name: the previous target `filelist.zip`
    # collided with the archive written by `download_filelist`
    filename = os.path.join(root_folder, 'room_watertight_mesh.zip')
    check_folder([filename])
    wget.download(url, filename, bar=None)
    create_flag_file(flag_file)
109 |
110 |
def generate_test_points():
  r''' Generates noisy testing points in `ply` format.

  For every entry of `test.txt`, points are randomly drawn (with replacement)
  from `pointcloud/pointcloud_00.npz`, perturbed with Gaussian noise, and
  written to `<root_folder>/test.input/<entry>.ply`. Entries without a point
  cloud file are skipped.
  '''

  noise_std = 0.005
  point_sample_num = 10000
  npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
  print('-> Generate testing points.')
  # filenames = get_filenames('all.txt')
  for name in tqdm(get_filenames('test.txt', root_folder), ncols=80):
    filename_pts = os.path.join(
        root_folder, 'synthetic_room_dataset', name, 'pointcloud',
        'pointcloud_00.npz')
    filename_ply = os.path.join(root_folder, 'test.input', name + '.ply')
    if os.path.exists(filename_pts):
      check_folder([filename_ply])

      # sample points; the noise is drawn before the indices
      points = np.load(filename_pts)['points'].astype(np.float32)
      noise = noise_std * np.random.randn(point_sample_num, 3)
      picked = np.random.choice(points.shape[0], size=point_sample_num)
      points_noise = points[picked] + noise

      # save ply: one (x, y, z) tuple of Python floats per vertex
      vertices = [tuple(float(v) for v in row) for row in points_noise]
      structured_array = np.array(vertices, dtype=npy_types)
      element = PlyElement.describe(structured_array, 'vertex')
      PlyData([element]).write(filename_ply)
145 |
146 |
def download_test_points():
  r''' Downloads and extracts the test points used in our paper.

  The message is always printed; the download itself is skipped when the flag
  file `<root_folder>/flags/download_test_points_succ` already exists.
  '''
  print('-> Download testing points.')
  done_flag = os.path.join(root_folder, 'flags/download_test_points_succ')
  if os.path.exists(done_flag):
    return

  zip_path = os.path.join(root_folder, 'test.input.zip')
  wget.download(
      'https://www.dropbox.com/s/q3h47042xh6sua7/scene.test.input.zip?dl=1',
      zip_path, bar=None)

  # extract into `<root_folder>/test.input`; the archive itself is kept
  target = os.path.join(root_folder, 'test.input')
  with zipfile.ZipFile(zip_path, 'r') as archive:
    archive.extractall(path=target)
  create_flag_file(done_flag)
162 |
163 |
def generate_dataset():
  r''' Runs all steps needed to prepare the dataset, in order.

  The test points are downloaded instead of being regenerated with
  `generate_test_points`.
  '''
  steps = (download_and_unzip, download_filelist, download_test_points)
  for step in steps:
    step()
169 |
170 |
if __name__ == '__main__':
  # Calls the function of this module whose name is stored in `args.run`
  # (`args` is defined earlier in the file, outside this view).
  # NOTE(review): `eval` executes any expression built from that value; a
  # lookup such as `globals()[args.run]()` would be safer -- confirm that
  # `args.run` only ever comes from a trusted command line.
  eval('%s()' % args.run)
173 |
--------------------------------------------------------------------------------
/datasets/pointcloud.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import ocnn
10 | import torch
11 | import numpy as np
12 |
13 | from solver import Dataset
14 | from .utils import collate_func
15 |
16 |
class Transform:
  r''' Builds the octree and the SDF training samples from a point cloud.

  The input `points` tensor is split as `points[:, :3]` (positions) and
  `points[:, 3:]` (normals).

  Args:
    flags: The dataset config. `point_sample_num` and `point_scale` are read
        from it, and it is unpacked as keyword arguments into
        `ocnn.Points2Octree`.
  '''

  def __init__(self, flags):
    self.flags = flags

    self.octant_sample_num = 2  # samples per empty node in `sample_on_octree`
    self.point_sample_num = flags.point_sample_num
    self.point_scale = flags.point_scale
    self.points2octree = ocnn.Points2Octree(**flags)

  def build_octree(self, points):
    '''Builds an octree that is used as both the input and the ground truth.'''

    pts, normals = points[:, :3], points[:, 3:]
    points_in = ocnn.points_new(pts, normals, torch.Tensor(), torch.Tensor())
    # points_in, _ = ocnn.clip_points(points_in, [-1.0]*3, [1.0]*3)
    octree = self.points2octree(points_in)
    return {'octree_in': octree, 'points_in': points_in, 'octree_gt': octree}

  def sample_on_surface(self, points):
    '''Randomly sample points on the surface.

    Draws `point_sample_num` rows (with replacement); their sdf is 0 and their
    gradient is the stored normal.
    '''

    rnd_idx = torch.randint(high=points.shape[0], size=(self.point_sample_num,))
    pos = points[rnd_idx, :3]
    normal = points[rnd_idx, 3:]
    sdf = torch.zeros(self.point_sample_num)
    return {'pos': pos, 'grad': normal, 'sdf': sdf}

  def sample_off_surface(self, bbox):
    '''Randomly sample points in the 3D space.

    Samples inside the (slightly enlarged) bounding box get the sdf label -1,
    samples outside of it get the label 1.
    '''

    # uniformly sampling in the whole 3D space, i.e. [-1, 1]^3
    pos = torch.rand(self.point_sample_num, 3) * 2 - 1

    # point gradients
    # grad = torch.zeros(self.point_sample_num, 3)
    norm = torch.sqrt(torch.sum(pos**2, dim=1, keepdim=True)) + 1e-6
    grad = pos / norm  # fake off-surface gradients

    # sdf values
    esp = 0.04
    bbmin, bbmax = bbox[:3] - esp, bbox[3:] + esp
    mask = torch.logical_and(pos > bbmin, pos < bbmax).all(1)  # inbox
    sdf = -1.0 * torch.ones(self.point_sample_num)
    sdf[mask.logical_not()] = 1.0  # exactly out-of-bbox
    return {'pos': pos, 'grad': grad, 'sdf': sdf}

  def sample_on_octree(self, octree):
    '''Adaptively sample points according the octree in the 3D space.

    `octant_sample_num` points are drawn in every empty node from `full_depth`
    to `depth`; at most `point_sample_num` points are returned in total.
    '''

    xyzs = []
    sample_num = 0
    depth = ocnn.octree_property(octree, 'depth')
    full_depth = ocnn.octree_property(octree, 'full_depth')
    for d in range(full_depth, depth+1):
      # get octree key
      xyz = ocnn.octree_property(octree, 'xyz', d)
      empty_node = ocnn.octree_property(octree, 'child', d) < 0
      xyz = ocnn.octree_decode_key(xyz[empty_node])

      # sample k points in each octree node
      xyz = xyz[:, :3].float()               # + 0.5 -> octree node center
      rnd = torch.rand(xyz.shape[0], self.octant_sample_num, 3)
      xyz = xyz.unsqueeze(1) + rnd
      xyz = xyz.view(-1, 3)                  # (N, 3)
      xyz = xyz * (2 ** (1 - d)) - 1         # normalize to [-1, 1]
      xyzs.append(xyz)

      # make sure that the max sample number is self.point_sample_num
      sample_num += xyz.shape[0]
      if sample_num > self.point_sample_num:
        size = self.point_sample_num - (sample_num - xyz.shape[0])
        rnd_idx = torch.randint(high=xyz.shape[0], size=(size,))
        xyzs[-1] = xyzs[-1][rnd_idx]
        break

    pos = torch.cat(xyzs, dim=0)
    grad = torch.zeros_like(pos)
    sdf = -1.0 * torch.ones(pos.shape[0])
    return {'pos': pos, 'grad': grad, 'sdf': sdf}

  def scale_and_clip_points(self, points):
    '''Rescales the positions and drops the points outside of (-1, 1)^3.

    The input tensor is left untouched; a new tensor is returned.
    '''

    # Work on a copy: the caller may hand over a cached tensor (e.g. when the
    # data is kept in memory), and rescaling it in place would compound the
    # scale every time the same sample is fetched.
    points = points.clone()
    points[:, :3] = points[:, :3] * self.point_scale  # rescale points
    pts = points[:, :3]
    mask = torch.logical_and(pts > -1.0, pts < 1.0).all(1)
    return points[mask]

  def compute_bbox(self, points):
    '''Returns `[bbmin, bbmax]` (6 values) padded by 0.06, clamped to [-1, 1].'''

    pts = points[:, :3]
    bbmin = pts.min(0)[0] - 0.06
    bbmax = pts.max(0)[0] + 0.06
    bbmin = torch.clamp(bbmin, min=-1, max=1)
    bbmax = torch.clamp(bbmax, min=-1, max=1)
    return torch.cat([bbmin, bbmax])

  def __call__(self, points, idx):
    '''Builds the training sample of one point cloud (`idx` is not used).'''

    points = self.scale_and_clip_points(points)  # clip points to [-1, 1]
    output = self.build_octree(points)
    bbox = self.compute_bbox(points)
    output['bbox'] = bbox  # used in marching cubes

    sdf_on_surf = self.sample_on_surface(points)
    # TODO: compare self.sample_on_octree & self.sample_off_surface
    # sdf_off_surf = self.sample_on_octree(output['octree_in'])
    sdf_off_surf = self.sample_off_surface(bbox)
    # concatenate the on-surface and off-surface samples key by key
    sdfs = {key: torch.cat([sdf_on_surf[key], sdf_off_surf[key]], dim=0)
            for key in sdf_on_surf.keys()}
    output.update(sdfs)

    return output
125 |
126 |
class SingleTransform(Transform):
  r''' A transform for a single shape: the octree is built once and cached.

  The points passed to the first call are scaled, clipped and converted to an
  octree; later calls reuse these cached results and only draw new SDF samples.
  '''

  def __init__(self, flags):
    super().__init__(flags)
    self.output = None  # cached octree dict (plus `bbox`)
    self.points = None  # cached scaled and clipped points
    self.bbox = None    # cached bounding box

  def __call__(self, points, idx):
    if self.output is None:
      self.points = self.scale_and_clip_points(points)  # clip points to [-1, 1]
      self.output = self.build_octree(self.points)
      self.bbox = self.compute_bbox(self.points)
      self.output['bbox'] = self.bbox  # used in marching cubes

    # Return a fresh (shallow) dict on every call: handing out the cached dict
    # itself makes all samples gathered into one batch alias the same object,
    # so that each call overwrites the SDF samples of the previous ones.
    output = dict(self.output)
    sdf_on_surf = self.sample_on_surface(self.points)
    # TODO: compare self.sample_on_octree & self.sample_off_surface
    # sdf_off_surf = self.sample_on_octree(output['octree_in'])
    sdf_off_surf = self.sample_off_surface(self.bbox)
    sdfs = {key: torch.cat([sdf_on_surf[key], sdf_off_surf[key]], dim=0)
            for key in sdf_on_surf.keys()}
    output.update(sdfs)

    return output
152 |
153 |
def read_file(filename: str):
  '''Loads a point cloud from a `.xyz` (text) or `.npy` file.

  Returns the data as a float32 tensor; any other file extension raises a
  NotImplementedError.
  '''

  loaders = (('.xyz', np.loadtxt), ('.npy', np.load))
  for suffix, loader in loaders:
    if filename.endswith(suffix):
      array = loader(filename)
      return torch.from_numpy(array.astype(np.float32))
  raise NotImplementedError
163 |
164 |
def get_pointcloud_dataset(flags):
  '''Creates the point cloud dataset; returns it with its collate function.'''

  return (
      Dataset(flags.location, flags.filelist, Transform(flags),
              read_file=read_file, in_memory=flags.in_memory),
      collate_func)
170 |
171 |
def get_singlepointcloud_dataset(flags):
  '''Creates the single-shape dataset; returns it with its collate function.'''

  return (
      Dataset(flags.location, flags.filelist, SingleTransform(flags),
              read_file=read_file, in_memory=flags.in_memory),
      collate_func)
177 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | # autopep8: off
9 | import ocnn
10 | import torch
11 | import torch.autograd
12 | import numpy as np
13 | import matplotlib
14 | matplotlib.use('Agg')
15 | import matplotlib.pyplot as plt
16 | import skimage.measure
17 | import trimesh
18 | from plyfile import PlyData, PlyElement
19 | from scipy.spatial import cKDTree
20 | # autopep8: on
21 |
22 |
def get_mgrid(size, dim=3):
  r''' Enumerates the integer coordinates of a `dim`-dimensional grid.

  The result has `size**dim` rows and `dim` columns; the last coordinate
  varies fastest.

  Example:
    >>> get_mgrid(3, dim=2)
    array([[0.0, 0.0],
           [0.0, 1.0],
           [0.0, 2.0],
           [1.0, 0.0],
           [1.0, 1.0],
           [1.0, 2.0],
           [2.0, 0.0],
           [2.0, 1.0],
           [2.0, 2.0]], dtype=float32)
  '''
  axis = np.arange(0, size, dtype=np.float32)
  grids = np.meshgrid(*([axis] * dim), indexing='ij')
  return np.stack(grids, -1).reshape(size**dim, dim)
43 |
44 |
def lin2img(tensor):
  r''' Reshapes a flat list of values into a square, single-channel image.

  Args:
    tensor: A tensor with `size * size` elements in total, e.g. of shape
        `(size*size,)` or `(size*size, 1)`.

  Returns:
    A tensor of shape `(1, size, size)`.
  '''
  channels = 1
  # Use the element count instead of the shape tuple, so that inputs with a
  # trailing singleton dimension are handled as well.
  num_samples = tensor.numel()
  size = int(np.sqrt(num_samples))
  return tensor.reshape(channels, size, size)
50 |
51 |
def make_contour_plot(array_2d, mode='log'):
  r''' Draws a filled contour plot of a 2D array and returns the figure.

  Args:
    array_2d: The 2D array to plot; it is flipped upside down before plotting.
    mode: `'log'` uses log-spaced levels in [-1, -0.01] and [0.01, 1], and
        `'lin'` uses 10 linearly spaced levels in [-0.5, 0.5]. Other values
        raise a NotImplementedError.
  '''
  fig, ax = plt.subplots(figsize=(2.75, 2.75), dpi=300)
  cmap = plt.get_cmap("Spectral")

  if mode == 'lin':
    nlevels = 10
    levels = np.linspace(-.5, .5, num=nlevels)
    colors = cmap(np.linspace(0., 1., num=nlevels))
  elif mode == 'log':
    nlevels = 6
    positive = np.logspace(-2, 0, num=nlevels)  # logspace
    negative = -1. * positive[::-1]
    # the levels are mirrored around zero; zero itself is not a level
    levels = np.concatenate((negative, positive), axis=0)
    colors = cmap(np.linspace(0., 1., num=nlevels * 2 + 1))
  else:
    raise NotImplementedError

  image = np.flipud(array_2d)
  filled = ax.contourf(image, levels=levels, colors=colors)
  fig.colorbar(filled)

  # thin iso-lines on every level, and a thicker one on the zero level
  ax.contour(image, levels=levels, colors='k', linewidths=0.1)
  ax.contour(image, levels=[0], colors='k', linewidths=0.3)
  ax.axis('off')
  return fig
76 |
77 |
def write_sdf_summary(model, writer, global_step, alias=''):
  r''' Writes contour plots of three axis-aligned SDF slices to `writer`.

  The slices are the planes x = 0, y = 0 and z = -0.75, each sampled on a
  128 x 128 grid covering [-1, 1).

  Args:
    model: A callable mapping a CUDA tensor with 4 columns (x, y, z and a
        zero index column) to one SDF value per row.
    writer: The summary writer; its `add_figure` method is called.
    global_step: The step recorded with the figures.
    alias: A prefix added to the figure names.
  '''
  size = 128
  coords_2d = get_mgrid(size, dim=2)
  # [0, size] -> [-1, 1]. The factor 2 is required: `coords_2d / size - 1.0`
  # only covers [-1, 0), i.e. a quarter of each slice.
  coords_2d = coords_2d * (2.0 / size) - 1.0
  coords_2d = torch.from_numpy(coords_2d)
  with torch.no_grad():
    zeros = torch.zeros_like(coords_2d[:, :1])
    ones = torch.ones_like(coords_2d[:, :1])
    names = ['train_yz_sdf_slice', 'train_xz_sdf_slice', 'train_xy_sdf_slice']
    coords = [torch.cat((zeros, coords_2d), dim=-1),
              torch.cat((coords_2d[:, :1], zeros, coords_2d[:, -1:]), dim=-1),
              torch.cat((coords_2d, -0.75 * ones), dim=-1)]
    for name, coord in zip(names, coords):
      # append a zero index column to the coordinates
      ids = torch.zeros(coord.shape[0], 1)
      coord = torch.cat([coord, ids], dim=1).cuda()
      sdf_values = model(coord)
      sdf_values = lin2img(sdf_values).squeeze().cpu().numpy()
      fig = make_contour_plot(sdf_values)
      writer.add_figure(alias + name, fig, global_step=global_step)
97 |
98 |
def calc_sdf(model, size=256, max_batch=64**3, bbmin=-1.0, bbmax=1.0):
  r''' Evaluates `model` on a regular `size`^3 grid.

  Args:
    model: A callable mapping a CUDA tensor with 4 columns (x, y, z and a
        zero index column) to one value per row.
    size: The grid resolution along each axis.
    max_batch: The maximum number of grid points evaluated per forward pass.
    bbmin, bbmax: The grid covers [bbmin, bbmax) along each axis.

  Returns:
    A numpy array of shape `(size, size, size)`.
  '''
  # generate samples
  total = size ** 3
  step = (bbmax - bbmin) / size
  grid = torch.from_numpy(get_mgrid(size, dim=3) * step + bbmin)
  values = torch.zeros(total)

  # forward, one chunk of at most `max_batch` points at a time
  for start in range(0, total, max_batch):
    stop = min(start + max_batch, total)
    chunk = grid[start:stop, :]
    index = torch.zeros(chunk.shape[0], 1)
    query = torch.cat([chunk, index], dim=1).cuda()
    values[start:stop] = model(query).squeeze().detach().cpu()
  return values.reshape(size, size, size).numpy()
119 |
120 |
def create_mesh(model, filename, size=256, max_batch=64**3, level=0,
                bbmin=-0.9, bbmax=0.9, mesh_scale=1.0, save_sdf=False, **kwargs):
  r''' Extracts a mesh from `model` with marching cubes and saves it.

  Args:
    model: The function evaluated by `calc_sdf`.
    filename: The output mesh file; with `save_sdf` the grid values are also
        saved to `filename[:-4] + ".sdf.npy"`.
    size, max_batch, bbmin, bbmax: Forwarded to `calc_sdf`.
    level: The iso-value of the extracted surface.
    mesh_scale: A factor applied to the vertices after mapping them back to
        [bbmin, bbmax].
    save_sdf: Whether to save the grid values.
    **kwargs: Ignored.

  If marching cubes produces no surface, a warning is printed and nothing is
  written.
  '''
  # marching cubes
  sdf_values = calc_sdf(model, size, max_batch, bbmin, bbmax)
  vtx, faces = np.zeros((0, 3)), np.zeros((0, 3))
  try:
    vtx, faces, _, _ = skimage.measure.marching_cubes(sdf_values, level)
  except (ValueError, RuntimeError):
    # Raised when `level` is outside of the value range or when no surface is
    # found; handled by the empty-mesh check below. Other exceptions (e.g.
    # KeyboardInterrupt) are no longer swallowed.
    pass
  if vtx.size == 0 or faces.size == 0:
    print('Warning from marching cubes: Empty mesh!')
    return

  # normalize vtx
  vtx = vtx * ((bbmax - bbmin) / size) + bbmin   # [0,sz]->[bbmin,bbmax]
  vtx = vtx * mesh_scale                         # rescale

  # save to ply and npy
  mesh = trimesh.Trimesh(vtx, faces)
  mesh.export(filename)
  if save_sdf:
    np.save(filename[:-4] + ".sdf.npy", sdf_values)
143 |
144 |
def calc_sdf_err(filename_gt, filename_pred):
  r''' Returns the mean absolute difference of two saved arrays, times 100.

  Both files are loaded with `np.load`.
  '''
  scale = 1.0e2  # scale the result for better display
  reference = np.load(filename_gt)
  predicted = np.load(filename_pred)
  return np.abs(predicted - reference).mean() * scale
151 |
152 |
def calc_chamfer(filename_gt, filename_pred, point_num):
  r''' Computes the two one-sided Chamfer distances between two meshes.

  `point_num` points are sampled on each mesh with a fixed random seed.

  Returns:
    `(chamfer_a, chamfer_b)`: the mean squared nearest-neighbor distance from
    the predicted samples to the ground-truth samples and vice versa, each
    multiplied by 1e5.
  '''
  scale = 1.0e5  # scale the result for better display
  np.random.seed(101)

  # sample the ground-truth mesh first, then the predicted one
  points_a = trimesh.sample.sample_surface(trimesh.load(filename_gt), point_num)[0]
  points_b = trimesh.sample.sample_surface(trimesh.load(filename_pred), point_num)[0]

  def _mean_sq_dist(targets, queries):
    # mean squared distance from each query to its nearest target
    nearest, _ = cKDTree(targets).query(queries)
    return np.mean(np.square(nearest)) * scale

  chamfer_a = _mean_sq_dist(points_a, points_b)
  chamfer_b = _mean_sq_dist(points_b, points_a)
  return chamfer_a, chamfer_b
170 |
171 |
def points2ply(filename, points, scale=1.0):
  r''' Saves the positions (and the normals, if any) of `points` as a ply file.

  Args:
    filename: The output file.
    points: The points object queried with `ocnn.points_property`.
    scale: A factor applied to the positions only.
  '''
  xyz = ocnn.points_property(points, 'xyz')
  normal = ocnn.points_property(points, 'normal')
  data = xyz.numpy() * scale

  # data types
  npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
  if normal is not None:
    npy_types = npy_types + [('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4')]
    data = np.concatenate((data, normal.numpy()), axis=1)

  # format into NumPy structured array: one tuple of Python floats per vertex
  num_fields = len(npy_types)
  vertices = [tuple(float(v) for v in row[:num_fields]) for row in data]
  structured_array = np.array(vertices, dtype=npy_types)
  element = PlyElement.describe(structured_array, 'vertex')

  # write ply
  PlyData([element]).write(filename)
197 |
--------------------------------------------------------------------------------
/models/graph_ounet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import ocnn
9 | import torch
10 | import torch.nn
11 |
12 | from . import mpu
13 | from . import modules
14 | from . import dual_octree
15 |
16 |
class GraphOUNet(torch.nn.Module):
  r''' An encoder-decoder network built on dual octree graphs.

  The encoder applies graph convolutions from `depth` down to `full_depth`.
  The decoder goes from `full_depth` up to `depth_out`; at every decoder depth
  it predicts 2-channel split logits and a 4-channel regressed signal, which
  is passed to `mpu.NeuralMPU` to evaluate the output function at query
  positions.

  Args:
    depth: The depth of the input octree, i.e. the finest encoder depth.
    channel_in: The channel number of the input feature.
    nout: Passed to `ocnn.create_full_octree` when the output octree has to
        be created from scratch.
    full_depth: The coarsest depth processed by the network.
    depth_out: The finest depth processed by the decoder.
    resblk_type: The residual block type, forwarded to `GraphResBlocks`.
    bottleneck: The bottleneck factor, forwarded to `GraphResBlocks`.
  '''

  def __init__(self, depth, channel_in, nout, full_depth=2, depth_out=6,
               resblk_type='bottleneck', bottleneck=4):
    super().__init__()
    self.depth = depth
    self.channel_in = channel_in
    self.nout = nout
    self.full_depth = full_depth
    self.depth_out = depth_out
    self.resblk_type = resblk_type
    self.bottleneck = bottleneck
    self.neural_mpu = mpu.NeuralMPU(self.full_depth, self.depth_out)
    self._setup_channels_and_resblks()
    n_edge_type, avg_degree = 7, 7

    # encoder
    self.conv1 = modules.GraphConvBnRelu(
        channel_in, self.channels[depth], n_edge_type, avg_degree, depth-1)
    # one group of residual blocks per depth, ordered from `depth` to `full_depth`
    self.encoder = torch.nn.ModuleList(
        [modules.GraphResBlocks(self.channels[d], self.channels[d],
         self.resblk_num[d], bottleneck, n_edge_type, avg_degree, d-1, resblk_type)
         for d in range(depth, full_depth-1, -1)])
    # downsample[i] maps the features of depth `depth - i` to depth `depth - i - 1`
    self.downsample = torch.nn.ModuleList(
        [modules.GraphDownsample(self.channels[d], self.channels[d-1])
         for d in range(depth, full_depth, -1)])

    # decoder
    # upsample[j] / decoder[j] produce the features of depth `full_depth + 1 + j`
    self.upsample = torch.nn.ModuleList(
        [modules.GraphUpsample(self.channels[d-1], self.channels[d])
         for d in range(full_depth+1, depth + 1)])
    self.decoder = torch.nn.ModuleList(
        [modules.GraphResBlocks(self.channels[d], self.channels[d],
         self.resblk_num[d], bottleneck, n_edge_type, avg_degree, d-1, resblk_type)
         for d in range(full_depth+1, depth + 1)])

    # header
    # predict[k] / regress[k] are applied to the features of depth `full_depth + k`
    self.predict = torch.nn.ModuleList(
        [self._make_predict_module(self.channels[d], 2)
         for d in range(full_depth, depth + 1)])
    self.regress = torch.nn.ModuleList(
        [self._make_predict_module(self.channels[d], 4)
         for d in range(full_depth, depth + 1)])

  def _setup_channels_and_resblks(self):
    r''' Sets the per-depth tables; both lists are indexed by the octree depth. '''
    # self.resblk_num = [3] * 7 + [1] + [1] * 9
    self.resblk_num = [3] * 16
    # 9 entries, i.e. channel numbers for the depths 0 to 8
    self.channels = [4, 512, 512, 256, 128, 64, 32, 32, 24]

  def _make_predict_module(self, channel_in, channel_out=2, num_hidden=32):
    r''' Builds a head: Conv1x1BnRelu followed by a Conv1x1 with bias. '''
    return torch.nn.Sequential(
        modules.Conv1x1BnRelu(channel_in, num_hidden),
        modules.Conv1x1(num_hidden, channel_out, use_bias=True))

  def _get_input_feature(self, doctree):
    r''' Returns the input feature provided by the dual octree. '''
    return doctree.get_input_feature()

  def octree_encoder(self, octree, doctree):
    r''' Runs the encoder; returns a dict mapping each depth to its features.

    `octree` is not used in this method; all data is read from `doctree`.
    '''
    depth, full_depth = self.depth, self.full_depth
    data = self._get_input_feature(doctree)

    convs = dict()
    convs[depth] = data
    for i, d in enumerate(range(depth, full_depth-1, -1)):
      # perform graph conv
      convd = convs[d]  # get convd
      edge_idx = doctree.graph[d]['edge_idx']
      edge_type = doctree.graph[d]['edge_dir']
      node_type = doctree.graph[d]['node_type']
      if d == self.depth:  # the first conv
        convd = self.conv1(convd, edge_idx, edge_type, node_type)
      convd = self.encoder[i](convd, edge_idx, edge_type, node_type)
      convs[d] = convd  # update convd

      # downsampleing
      if d > full_depth:  # init convd
        nnum = doctree.nnum[d]
        lnum = doctree.lnum[d-1]
        leaf_mask = doctree.node_child(d-1) < 0
        convs[d-1] = self.downsample[i](convd, leaf_mask, nnum, lnum)

    return convs

  def octree_decoder(self, convs, doctree_out, doctree, update_octree=False):
    r''' Runs the decoder from `full_depth` to `depth_out`.

    Args:
      convs: The encoder features returned by `octree_encoder`.
      doctree_out: The dual octree of the output.
      doctree: The dual octree of the input, used to align the skip features.
      update_octree: If True, the output octree is split according to the
          predicted labels and the dual octree is rebuilt at every depth.

    Returns:
      A tuple `(logits, reg_voxs, octree_out)`; the first two are dicts keyed
      by the depth.
    '''
    logits = dict()
    reg_voxs = dict()
    deconvs = dict()

    deconvs[self.full_depth] = convs[self.full_depth]
    # NOTE(review): the module lists are built up to `self.depth` while this
    # loop runs up to `self.depth_out` -- presumably depth_out <= depth is
    # required; confirm against the configs.
    for i, d in enumerate(range(self.full_depth, self.depth_out+1)):
      if d > self.full_depth:
        nnum = doctree_out.nnum[d-1]
        leaf_mask = doctree_out.node_child(d-1) < 0
        deconvd = self.upsample[i-1](deconvs[d-1], leaf_mask, nnum)
        skip = modules.doctree_align(
            convs[d], doctree.graph[d]['keyd'], doctree_out.graph[d]['keyd'])
        deconvd = deconvd + skip  # skip connections

        edge_idx = doctree_out.graph[d]['edge_idx']
        edge_type = doctree_out.graph[d]['edge_dir']
        node_type = doctree_out.graph[d]['node_type']
        deconvs[d] = self.decoder[i-1](deconvd, edge_idx, edge_type, node_type)

      # predict the splitting label
      logit = self.predict[i](deconvs[d])
      nnum = doctree_out.nnum[d]
      logits[d] = logit[-nnum:]  # keep the last `nnum` rows only

      # update the octree according to predicted labels
      if update_octree:
        label = logits[d].argmax(1).to(torch.int32)
        octree_out = doctree_out.octree
        octree_out = ocnn.octree_update(octree_out, label, d, split=1)
        if d < self.depth_out:
          octree_out = ocnn.octree_grow(octree_out, target_depth=d+1)
        # rebuild the dual octree from the updated octree
        doctree_out = dual_octree.DualOctree(octree_out)
        doctree_out.post_processing_for_docnn()

      # predict the signal
      reg_vox = self.regress[i](deconvs[d])

      # TODO: improve it
      # pad zeros to reg_vox to reuse the original code for ocnn
      node_mask = doctree_out.graph[d]['node_mask']
      shape = (node_mask.shape[0], reg_vox.shape[1])
      reg_vox_pad = torch.zeros(shape, device=reg_vox.device)
      reg_vox_pad[node_mask] = reg_vox
      reg_voxs[d] = reg_vox_pad

    return logits, reg_voxs, doctree_out.octree

  def forward(self, octree_in, octree_out=None, pos=None):
    r''' Runs the network.

    Args:
      octree_in: The input octree.
      octree_out: The output octree. If None, a full octree of depth
          `full_depth` is created and grown according to the predictions.
      pos: Optional query positions; if given, the MPU values are stored in
          `output['mpus']`.

    Returns:
      A dict with the keys `logits`, `reg_voxs`, `octree_out`, `neural_mpu`
      (a function evaluating the MPU of depth `depth_out` at given positions)
      and, if `pos` is given, `mpus`.
    '''
    # generate dual octrees
    doctree_in = dual_octree.DualOctree(octree_in)
    doctree_in.post_processing_for_docnn()

    update_octree = octree_out is None
    if update_octree:
      octree_out = ocnn.create_full_octree(self.full_depth, self.nout)
    doctree_out = dual_octree.DualOctree(octree_out)
    doctree_out.post_processing_for_docnn()

    # run encoder and decoder
    convs = self.octree_encoder(octree_in, doctree_in)
    out = self.octree_decoder(convs, doctree_out, doctree_in, update_octree)
    output = {'logits': out[0], 'reg_voxs': out[1], 'octree_out': out[2]}

    # compute function value with mpu
    if pos is not None:
      output['mpus'] = self.neural_mpu(pos, out[1], out[2])

    # create the mpu wrapper
    def _neural_mpu(pos):
      pred = self.neural_mpu(pos, out[1], out[2])
      return pred[self.depth_out][0]
    output['neural_mpu'] = _neural_mpu

    return output
175 |
--------------------------------------------------------------------------------
/solver/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | # autopep8: off
9 | import os
10 | import sys
11 | import shutil
12 | import argparse
13 | from datetime import datetime
14 | from yacs.config import CfgNode as CN
15 |
# The default values of all config options. Config files and command-line
# options are merged into this node by `_update_config`.
_C = CN()

# SOLVER related parameters
_C.SOLVER = CN()
_C.SOLVER.alias = ''  # The experiment alias
_C.SOLVER.gpu = (0,)  # The gpu ids
_C.SOLVER.run = 'train'  # Choose from train or test

_C.SOLVER.logdir = 'logs'  # Directory where to write event logs
_C.SOLVER.ckpt = ''  # Restore weights from checkpoint file
_C.SOLVER.ckpt_num = 10  # The number of checkpoint kept

_C.SOLVER.type = 'sgd'  # Choose from sgd or adam
_C.SOLVER.weight_decay = 0.0005  # The weight decay on model weights
_C.SOLVER.max_epoch = 300  # Maximum training epoch
_C.SOLVER.eval_epoch = 1  # Maximum evaluating epoch
_C.SOLVER.eval_step = -1  # Maximum evaluating steps
_C.SOLVER.test_every_epoch = 10  # Test model every n training epochs
_C.SOLVER.log_per_iter = -1  # Output log every k training iteration

_C.SOLVER.lr_type = 'step'  # Learning rate type: step or cos
_C.SOLVER.lr = 0.1  # Initial learning rate
_C.SOLVER.gamma = 0.1  # Learning rate step-wise decay
_C.SOLVER.step_size = (120,60,)  # Learning rate step size.
_C.SOLVER.lr_power = 0.9  # Used in poly learning rate

_C.SOLVER.dist_url = 'tcp://localhost:10001'
_C.SOLVER.progress_bar = True


# DATA related parameters
_C.DATA = CN()
_C.DATA.train = CN()
_C.DATA.train.name = ''  # The name of the dataset

# For octree building
# If node_dis = True and there are normals, the octree features
# is 4 channels, i.e., the average normals and the 1 channel displacement.
# If node_dis = True and there are no normals, the feature is also 4 channels,
# i.e., a 3 channel displacement of average points relative to the center
# points, and the last channel is constant.
_C.DATA.train.depth = 5  # The octree depth
_C.DATA.train.full_depth = 2  # The full depth
_C.DATA.train.node_dis = False  # Save the node displacement
_C.DATA.train.split_label = False  # Save the split label
_C.DATA.train.adaptive = False  # Build the adaptive octree
_C.DATA.train.node_feat = False  # Calculate the node feature

# For normalization
# If radius < 0, then the method will compute a bounding sphere
_C.DATA.train.bsphere = 'sphere'  # The method used to calc the bounding sphere
_C.DATA.train.radius = -1.  # The radius and center of the bounding sphere
_C.DATA.train.center = (-1., -1., -1.)

# For transformation
_C.DATA.train.offset = 0.016  # Used to displace the points when building octree
_C.DATA.train.normal_axis = ''  # Used to re-orient normal directions

# For data augmentation
_C.DATA.train.disable = False  # Disable this dataset or not
_C.DATA.train.distort = False  # Whether to apply data augmentation
_C.DATA.train.scale = 0.0  # Scale the points
_C.DATA.train.uniform = False  # Generate uniform scales
_C.DATA.train.jitter = 0.0  # Jitter the points
_C.DATA.train.interval = (1, 1, 1)  # Use interval&angle to generate random angle
_C.DATA.train.angle = (180, 180, 180)

# For data loading
_C.DATA.train.location = ''  # The data location
_C.DATA.train.filelist = ''  # The data filelist
_C.DATA.train.batch_size = 32  # Training data batch size
_C.DATA.train.num_workers = 8  # Number of workers to load the data
_C.DATA.train.shuffle = False  # Shuffle the input data
_C.DATA.train.in_memory = False  # Load the training data into memory


# The test data starts from a copy of the training data defaults
_C.DATA.test = _C.DATA.train.clone()


# MODEL related parameters
_C.MODEL = CN()
_C.MODEL.name = ''  # The name of the model
_C.MODEL.depth = 5  # The input octree depth
_C.MODEL.full_depth = 2  # The input octree full depth layer
_C.MODEL.depth_out = 5  # The output feature depth
_C.MODEL.channel = 3  # The input feature channel
_C.MODEL.factor = 1  # The factor used to widen the network
_C.MODEL.nout = 40  # The output feature channel
_C.MODEL.resblock_num = 3  # The resblock number
_C.MODEL.resblock_type = 'bottleneck'  # Choose from 'bottleneck' and 'basic'
_C.MODEL.bottleneck = 4  # The bottleneck factor of one resblock
_C.MODEL.dropout = (0.0,)  # The dropout ratio

_C.MODEL.upsample = 'nearest'  # The method used for upsampling
_C.MODEL.interp = 'linear'  # The interpolation method: linear or nearest
_C.MODEL.nempty = False  # Perform Octree Conv on non-empty octree nodes
_C.MODEL.sync_bn = False  # Use sync_bn when training the network
_C.MODEL.use_checkpoint = False  # Use checkpoint to save memory
_C.MODEL.find_unused_parameters = False  # Used in DistributedDataParallel

# loss related parameters
_C.LOSS = CN()
_C.LOSS.name = ''  # The name of the loss
_C.LOSS.num_class = 40  # The class number for the cross-entropy loss
_C.LOSS.weights = (1.0, 1.0)  # The weight factors for different losses
_C.LOSS.label_smoothing = 0.0  # The factor of label smoothing


# backup the commands
_C.SYS = CN()
_C.SYS.cmds = ''  # Used to backup the commands

# The global config object shared by the functions below
FLAGS = _C
129 |
130 |
131 | def _update_config(FLAGS, args):
132 | FLAGS.defrost()
133 | if args.config:
134 | FLAGS.merge_from_file(args.config)
135 | if args.opts:
136 | FLAGS.merge_from_list(args.opts)
137 | FLAGS.SYS.cmds = ' '.join(sys.argv)
138 |
139 | # update logdir
140 | alias = FLAGS.SOLVER.alias.lower()
141 | if 'time' in alias: # 'time' is a special keyword
142 | alias = alias.replace('time', datetime.now().strftime('%m%d%H%M')) #%S
143 | if alias is not '':
144 | FLAGS.SOLVER.logdir += '_' + alias
145 | FLAGS.freeze()
146 |
147 |
148 | def _backup_config(FLAGS, args):
149 | logdir = FLAGS.SOLVER.logdir
150 | if not os.path.exists(logdir):
151 | os.makedirs(logdir)
152 | # copy the file to logdir
153 | if args.config:
154 | shutil.copy2(args.config, logdir)
155 | # dump all configs
156 | filename = os.path.join(logdir, 'all_configs.yaml')
157 | with open(filename, 'w') as fid:
158 | fid.write(FLAGS.dump())
159 |
160 |
161 | def _set_env_var(FLAGS):
162 | gpus = ','.join([str(a) for a in FLAGS.SOLVER.gpu])
163 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus
164 |
165 |
def get_config():
  r''' Returns the global config object `FLAGS`. '''
  return FLAGS
168 |
def parse_args(backup=True):
  r''' Parses the command line and updates the global config.

  Args:
    backup: Whether to save the configs to the log directory.

  Returns:
    The global config `FLAGS`, updated by `_update_config`. The environment
    variable for the GPUs is set as a side effect.
  '''
  parser = argparse.ArgumentParser(description='The configs')
  parser.add_argument(
      '--config', type=str, help='experiment configure file name')
  parser.add_argument(
      'opts', nargs=argparse.REMAINDER,
      help="Modify config options using the command-line")
  cli_args = parser.parse_args()

  _update_config(FLAGS, cli_args)
  if backup:
    _backup_config(FLAGS, cli_args)
  _set_env_var(FLAGS)
  return FLAGS
182 |
183 |
if __name__ == '__main__':
  # Parse the command line without writing a backup, and print the configs.
  flags = parse_args(backup=False)
  print(flags)
187 |
--------------------------------------------------------------------------------
/losses/loss.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import ocnn
9 | import torch
10 | import torch.nn.functional as F
11 |
12 |
def compute_gradient(y, x):
  r''' Returns the gradient of `y` with respect to `x`.

  All-ones weights are used for `y`, and the graph of the derivative is kept
  (`create_graph=True`) so that the result can be differentiated again.
  '''
  weights = torch.ones_like(y)
  (grad,) = torch.autograd.grad(y, [x], weights, create_graph=True)
  return grad
17 |
18 |
def sdf_reg_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''):
  r''' Computes the mean squared errors of the gradients and the sdf values.

  Returns a dict with the keys `grad_loss` (weight 1) and `sdf_loss`
  (weight 200), each followed by `name_suffix`.
  '''
  wg, ws = 1.0, 200.0
  grad_mse = (grad - grad_gt).pow(2).mean()
  sdf_mse = (sdf - sdf_gt).pow(2).mean()
  return {'grad_loss' + name_suffix: grad_mse * wg,
          'sdf_loss' + name_suffix: sdf_mse * ws}
26 |
27 |
def sdf_grad_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''):
  r''' Computes the losses for learning an SDF from points with normals.

  Samples whose `sdf_gt` equals -1 are treated as off-surface samples, all
  others as on-surface samples.

  Returns a dict with the keys `sdf_loss`, `inter_loss`, `norm_loss` and
  `grad_loss`, each followed by `name_suffix`.
  '''
  off_surf = sdf_gt == -1
  on_surf = off_surf.logical_not()

  return {
      # the sdf is zero on the surface
      'sdf_loss' + name_suffix: sdf[on_surf].pow(2).mean() * 200.0,
      # penalize off-surface sdf values that are close to zero
      'inter_loss' + name_suffix:
          torch.exp(-40 * torch.abs(sdf[off_surf])).mean() * 0.1,
      # the gradient matches the normal on the surface
      'norm_loss' + name_suffix:
          (grad[on_surf] - grad_gt[on_surf]).pow(2).mean() * 1.0,
      # the gradient has a unit norm off the surface
      'grad_loss' + name_suffix:
          (grad[off_surf].norm(2, dim=-1) - 1).abs().mean() * 0.1,
  }
42 |
43 |
def sdf_grad_regularized_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''):
    ''' Same terms as the SDF/gradient loss plus a gradient regularizer.

    Samples whose `sdf_gt` differs from -1 are treated as on-surface samples,
    all the others as off-surface samples. The returned dict contains, each
    followed by `name_suffix`: 'sdf_loss', 'inter_loss', 'norm_loss',
    'grad_loss' and 'grad_reg_loss'; the last one is the squared difference
    between `grad` and `grad_gt` on the off-surface samples (x0.1).
    '''
    surface = sdf_gt != -1
    elsewhere = surface.logical_not()

    grad_on, grad_off = grad[surface], grad[elsewhere]
    sdf_off = sdf[elsewhere]

    terms = (
        ('sdf_loss', sdf[surface].pow(2).mean() * 200.0),
        ('inter_loss', torch.exp(-40 * torch.abs(sdf_off)).mean() * 0.1),
        ('norm_loss', (grad_on - grad_gt[surface]).pow(2).mean() * 1.0),
        ('grad_loss', (grad_off.norm(2, dim=-1) - 1).abs().mean() * 0.1),
        ('grad_reg_loss', (grad_off - grad_gt[elsewhere]).pow(2).mean() * 0.1),
    )
    return {name + name_suffix: value for name, value in terms}
59 |
60 |
def possion_grad_loss(sdf, grad, sdf_gt, grad_gt, name_suffix=''):
    ''' Computes the Poisson-style loss terms for a predicted field.

    The samples are grouped by the value stored in `sdf_gt`: entries equal to 0
    are treated as on-surface samples, entries equal to 1.0 as samples outside
    of the bounding box, and every entry that is not on-surface as an
    off-surface sample (this includes the out-of-bbox ones).

    Args:
      sdf: Predicted field values; indexed with boolean masks built from
          `sdf_gt`.
      grad: Gradients of the prediction; indexed with the same masks.
      sdf_gt: Ground-truth values encoding the sample groups described above.
      grad_gt: Reference gradients; only the on-surface rows are used.
      name_suffix: String appended to every key of the returned dict.

    Returns:
      A dict with the keys 'sdf_loss', 'inter_loss', 'norm_loss', 'grad_loss'
      and 'bbox_loss' (each followed by `name_suffix`).
    '''
    on_surf = sdf_gt == 0
    out_of_bbox = sdf_gt == 1.0
    off_surf = on_surf.logical_not()

    sdf_loss = sdf[on_surf].pow(2).mean() * 200.0
    norm_loss = (grad[on_surf] - grad_gt[on_surf]).pow(2).mean() * 1.0
    intr_loss = torch.exp(-40 * torch.abs(sdf[off_surf])).mean() * 0.1
    grad_loss = grad[off_surf].pow(2).mean() * 0.1  # poisson loss

    # The mean of an empty selection is NaN, which would poison the summed
    # loss for every batch without out-of-bbox samples; such a batch simply
    # contributes a zero penalty instead.
    sdf_outside = sdf[out_of_bbox]
    if sdf_outside.numel() > 0:
        bbox_loss = torch.relu(-sdf_outside).mean() * 100.0
    else:
        bbox_loss = sdf.new_zeros(())

    losses = [sdf_loss, intr_loss, norm_loss, grad_loss, bbox_loss]
    names = ['sdf_loss', 'inter_loss', 'norm_loss', 'grad_loss', 'bbox_loss']
    return {name + name_suffix: loss for name, loss in zip(names, losses)}
77 |
78 |
def compute_mpu_gradients(mpus, pos, fval_transform=None):
    ''' Differentiates the function values of every entry in `mpus` w.r.t. `pos`.

    Each value of `mpus` is unpacked as a `(fval, flags)` pair; `fval` is
    optionally passed through `fval_transform` first. Only the first three
    columns of each gradient are kept. Returns a dict with the keys of `mpus`.
    '''
    grads = {}
    for depth, (fval, _flags) in mpus.items():
        values = fval if fval_transform is None else fval_transform(fval)
        grads[depth] = compute_gradient(values, pos)[:, :3]
    return grads
87 |
88 |
def compute_octree_loss(logits, octree_out):
    ''' Cross-entropy loss and accuracy of the predicted octree split labels.

    For every key `d` of `logits`, the target labels are read from the 'split'
    property of `octree_out` at depth `d`. Returns a dict with the entries
    'loss_<d>' and 'accu_<d>'.
    '''
    depth_weights = [1.0] * 16
    # depth_weights = [1.0] * 4 + [0.8, 0.6, 0.4] + [0.2] * 16

    stats = {}
    for depth, logit in logits.items():
        target = ocnn.octree_property(octree_out, 'split', depth).long()
        entropy = F.cross_entropy(logit, target)
        hits = logit.argmax(1).eq(target)
        stats['loss_%d' % depth] = entropy * depth_weights[depth]
        stats['accu_%d' % depth] = hits.float().mean()
    return stats
100 |
101 |
def compute_sdf_loss(mpus, grads, sdf_gt, grad_gt, reg_loss_func):
    ''' Applies `reg_loss_func` to every entry of `mpus` and merges the results.

    Each value of `mpus` is unpacked as an `(sdf, flags)` pair and the loss
    function is called with the suffix '_<key>', so the per-entry dicts can be
    merged into a single dict, which is returned.
    '''
    merged = {}
    # TODO: tune the loss weights and the flags stored next to each sdf
    for depth, (sdf, _flags) in mpus.items():
        suffix = '_%d' % depth
        merged.update(reg_loss_func(sdf, grads[depth], sdf_gt, grad_gt, suffix))
    return merged
112 |
113 |
def compute_occu_loss_v0(mpus, grads, occu_gt, grad_gt, weight):
    ''' Binary cross-entropy occupancy loss and accuracy for every entry.

    Each value of `mpus` is unpacked as a triple whose first item holds the
    occupancy logits. The accuracy only counts the samples whose `occu_gt`
    differs from 0.5. `grads` and `grad_gt` are accepted but not used.
    Returns a dict with the entries 'occu_loss_<d>' and 'occu_accu_<d>'.
    '''
    report = {}
    for depth, (logit, _flags, _grad) in mpus.items():
        bce = F.binary_cross_entropy_with_logits(logit, occu_gt, weight=weight)

        predicted = torch.sigmoid(logit) > 0.5
        countable = occu_gt != 0.5
        correct = predicted.eq(occu_gt).float()

        report['occu_loss_%d' % depth] = bce
        report['occu_accu_%d' % depth] = correct[countable].mean()
    return report
131 |
132 |
def compute_occu_loss_1214(mpus, occu):
    ''' Sign-consistency losses of the predicted field (tried on 2021.12.13).

    Samples with `occu == 0` are treated as inside and samples with
    `occu == 1` as outside. Positive predictions on inside samples and negative
    predictions on outside samples are penalized, scaled by 1000 times a
    per-key weight. Returns a dict with the entries 'inside_loss_<d>',
    'outside_loss_<d>', 'inside_accu_<d>' and 'outside_accu_<d>'.
    '''
    # TODO: tune the weights
    depth_weights = [0.2] * 4 + [0.4, 0.6, 0.8] + [1.0] * 16

    is_inside = occu == 0
    is_outside = occu == 1

    report = {}
    for depth, (sdf, _flags) in mpus.items():
        scale = 1000 * depth_weights[depth]
        sdf_in, sdf_out = sdf[is_inside], sdf[is_outside]

        report['inside_loss_%d' % depth] = torch.relu(sdf_in).mean() * scale
        report['outside_loss_%d' % depth] = torch.relu(-sdf_out).mean() * scale
        report['inside_accu_%d' % depth] = (sdf_in < 0).float().mean()
        report['outside_accu_%d' % depth] = (sdf_out > 0).float().mean()
    return report
151 |
152 |
def compute_occu_loss(mpus, grads, occu, grad_gt):
    ''' Occupancy-supervised loss terms for every entry of `mpus`.

    The samples are grouped by the value of `occu`: 0 is treated as inside,
    1 as outside and 0.5 as on-surface; everything that is not on-surface is
    an off-surface sample. Each value of `mpus` is unpacked as an
    `(sdf, flags)` pair and `grads` provides the matching gradients.

    Returns a dict with the entries 'sdf_loss', 'norm_loss', 'grad_loss',
    'inter_loss', 'inside_loss', 'inside_accu', 'outside_loss' and
    'outside_accu', each followed by '_<d>'.
    '''
    depth_weights = [1.0] * 16
    # depth_weights = [0.2] * 4 + [0.4, 0.6, 0.8] + [1.0] * 16
    # depth_weights = [0.0] * 7 + [1.0] * 16  # Single level loss

    is_inside = occu == 0
    is_outside = occu == 1
    is_surface = occu == 0.5
    not_surface = is_surface.logical_not()

    report = {}
    for depth, (sdf, _flags) in mpus.items():
        grad = grads[depth]
        w = depth_weights[depth]
        tag = '_%d' % depth

        normal_error = grad[is_surface] - grad_gt[is_surface]
        sdf_in, sdf_out = sdf[is_inside], sdf[is_outside]
        sdf_off = sdf[not_surface]

        report['sdf_loss' + tag] = sdf[is_surface].pow(2).mean() * (200 * w)
        report['norm_loss' + tag] = normal_error.pow(2).mean() * (1.0 * w)
        report['grad_loss' + tag] = grad[not_surface].pow(2).mean() * (0.1 * w)
        report['inter_loss' + tag] = torch.exp(-40 * sdf_off.abs()).mean() * (0.1 * w)
        report['inside_loss' + tag] = torch.relu(sdf_in).mean() * (500 * w)
        report['inside_accu' + tag] = (sdf_in < 0).float().mean()
        report['outside_loss' + tag] = torch.relu(-sdf_out).mean() * (2000 * w)
        report['outside_accu' + tag] = (sdf_out > 0).float().mean()

    return report
188 |
189 |
def compute_occu_loss_cls(mpus, grads, occu_gt, grad_gt):
    ''' Occupancy loss variant that classifies against soft occupancy targets.

    The samples are grouped by the value of `occu_gt`: 0 is treated as inside,
    1 as outside and 0.5 as on-surface. The binary cross-entropy term uses the
    soft targets 0.3 (inside), 0.7 (outside) and 0.5 (everything else), with
    the predicted field used directly as logits.

    Returns a dict with the entries 'sdf_loss', 'norm_loss', 'grad_loss',
    'inter_loss', 'occu_loss', 'inside_accu' and 'outside_accu', each followed
    by '_<d>'.
    '''
    depth_weights = [1.0] * 16
    # depth_weights = [0.2] * 4 + [0.4, 0.6, 0.8] + [1.0] * 16
    # depth_weights = [0.0] * 7 + [1.0] * 16  # Single level loss

    is_inside = occu_gt == 0
    is_outside = occu_gt == 1
    is_surface = occu_gt == 0.5
    not_surface = is_surface.logical_not()

    # Use soft-version occupancies
    soft_target = torch.ones_like(occu_gt) * 0.5
    soft_target[is_inside] = 0.3
    soft_target[is_outside] = 0.7

    report = {}
    for depth, (sdf, _flags) in mpus.items():
        grad = grads[depth]
        w = depth_weights[depth]
        tag = '_%d' % depth

        normal_error = grad[is_surface] - grad_gt[is_surface]
        sdf_off = sdf[not_surface]
        correct = (sdf.sigmoid() > 0.5).eq(occu_gt).float()

        report['sdf_loss' + tag] = sdf[is_surface].pow(2).mean() * (200 * w)
        report['norm_loss' + tag] = normal_error.pow(2).mean() * (1.0 * w)
        report['grad_loss' + tag] = grad[not_surface].pow(2).mean() * (0.1 * w)
        report['inter_loss' + tag] = torch.exp(-40 * sdf_off.abs()).mean() * (0.1 * w)
        report['occu_loss' + tag] = F.binary_cross_entropy_with_logits(sdf, soft_target)
        report['inside_accu' + tag] = correct[is_inside].mean()
        report['outside_accu' + tag] = correct[is_outside].mean()

    return report
230 |
231 |
def get_sdf_loss_function(loss_type=''):
    ''' Maps the name of an SDF loss to the function implementing it.

    Known names are 'sdf_reg_loss', 'sdf_grad_loss', 'possion_grad_loss' and
    'sdf_grad_reg_loss'. Any other name (including the default empty string)
    yields None.
    '''
    known_types = ('sdf_reg_loss', 'sdf_grad_loss', 'possion_grad_loss',
                   'sdf_grad_reg_loss')
    if loss_type not in known_types:
        return None

    # The table is only built for a known name.
    table = {
        'sdf_reg_loss': sdf_reg_loss,
        'sdf_grad_loss': sdf_grad_loss,
        'possion_grad_loss': possion_grad_loss,
        'sdf_grad_reg_loss': sdf_grad_regularized_loss,
    }
    return table[loss_type]
243 |
244 |
def shapenet_loss(batch, model_out, reg_loss_type=''):
    ''' Octree split loss plus the SDF regression loss selected by name.

    Reads 'logits', 'octree_out' and 'mpus' from `model_out` and 'pos', 'sdf'
    and 'grad' from `batch`. Returns a single dict with all the terms.
    '''
    # octree loss
    losses = compute_octree_loss(model_out['logits'], model_out['octree_out'])

    # regression loss
    mpus = model_out['mpus']
    grads = compute_mpu_gradients(mpus, batch['pos'])
    loss_func = get_sdf_loss_function(reg_loss_type)
    losses.update(
        compute_sdf_loss(mpus, grads, batch['sdf'], batch['grad'], loss_func))
    return losses
256 |
257 |
def dfaust_loss(batch, model_out, reg_loss_type=''):
    ''' SDF regression loss selected by name; no octree loss is computed.

    Reads 'mpus' from `model_out` and 'pos', 'sdf' and 'grad' from `batch`.
    '''
    mpus = model_out['mpus']
    grads = compute_mpu_gradients(mpus, batch['pos'])
    loss_func = get_sdf_loss_function(reg_loss_type)
    return compute_sdf_loss(mpus, grads, batch['sdf'], batch['grad'], loss_func)
265 |
266 |
def synthetic_room_loss(batch, model_out, reg_loss_type=''):
    ''' Octree split loss plus the occupancy-supervised loss.

    Reads 'logits', 'octree_out' and 'mpus' from `model_out` and 'pos', 'occu'
    and 'grad' from `batch`. `reg_loss_type` is accepted but not used.
    '''
    # octree loss
    losses = compute_octree_loss(model_out['logits'], model_out['octree_out'])

    # gradients of the predicted field, then the occupancy loss
    mpus = model_out['mpus']
    grads = compute_mpu_gradients(mpus, batch['pos'])
    losses.update(compute_occu_loss(mpus, grads, batch['occu'], batch['grad']))
    return losses
280 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dual Octree Graph Networks
2 |
3 | This repository contains the implementation of our paper *Dual Octree Graph Networks*. The experiments are conducted on Ubuntu 18.04 with 4 V100 GPUs (32GB memory). The code is released under the **MIT license**.
4 |
5 | **[Dual Octree Graph Networks for Learning Adaptive Volumetric Shape Representations](https://arxiv.org/abs/2205.02825)**
6 | [Peng-Shuai Wang](https://wang-ps.github.io/), [Yang Liu](https://xueyuhanlang.github.io/), and [Xin Tong](https://www.microsoft.com/en-us/research/people/xtong/)
7 | ACM Transactions on Graphics (SIGGRAPH), 41(4), 2022
8 |
9 | 
10 |
11 |
12 | - [Dual Octree Graph Networks](#dual-octree-graph-networks)
13 | - [1. Installation](#1-installation)
14 | - [2. Shape Reconstruction with ShapeNet](#2-shape-reconstruction-with-shapenet)
15 | - [2.1 Data Preparation](#21-data-preparation)
16 | - [2.2 Experiment](#22-experiment)
17 | - [2.3 Generalization](#23-generalization)
18 | - [3. Synthetic Scene Reconstruction](#3-synthetic-scene-reconstruction)
19 | - [3.1 Data Preparation](#31-data-preparation)
20 | - [3.2 Experiment](#32-experiment)
21 | - [4. Unsupervised Surface Reconstruction with DFaust](#4-unsupervised-surface-reconstruction-with-dfaust)
22 | - [4.1 Data Preparation](#41-data-preparation)
23 | - [4.2 Experiment](#42-experiment)
24 | - [4.3 Generalization](#43-generalization)
25 | - [5. Autoencoder with ShapeNet](#5-autoencoder-with-shapenet)
26 | - [5.1 Data Preparation](#51-data-preparation)
27 | - [5.2 Experiment](#52-experiment)
28 |
29 |
30 | ## 1. Installation
31 |
32 | 1. Install [Conda](https://www.anaconda.com/) and create a `Conda` environment.
33 | ```bash
34 | conda create --name dualocnn python=3.7
35 | conda activate dualocnn
36 | ```
37 |
38 | 1. Install PyTorch-1.9.1 with conda according to the official documentation.
39 | ```bash
40 | conda install pytorch==1.9.1 torchvision==0.10.1 cudatoolkit=10.2 -c pytorch
41 | ```
42 |
43 | 2. Install `ocnn-pytorch` from [O-CNN](https://github.com/microsoft/O-CNN).
44 | ```bash
45 | git clone https://github.com/microsoft/O-CNN.git
46 | cd O-CNN/pytorch
47 | pip install -r requirements.txt
48 | python setup.py install --build_octree
49 | ```
50 |
51 | 3. Clone this repository and install other requirements.
52 | ```bash
53 | git clone https://github.com/microsoft/DualOctreeGNN.git
54 | cd DualOctreeGNN
55 | pip install -r requirements.txt
56 | ```
57 |
58 | ## 2. Shape Reconstruction with ShapeNet
59 |
60 | ### 2.1 Data Preparation
61 |
62 | 1. Download `ShapeNetCore.v1.zip` (31G) from [ShapeNet](https://shapenet.org/)
63 | and place it into the folder `data/ShapeNet`.
64 |
65 | 2. Convert the meshes in `ShapeNetCore.v1` to signed distance fields (SDFs).
66 |
67 | ```bash
68 | python tools/shapenet.py --run convert_mesh_to_sdf
69 | ```
70 |
71 | Note that this process is relatively slow; it may take several days to finish
72 | converting all the meshes from ShapeNet. For simplicity, I did not use Python
73 | multiprocessing to speed it up. If speed matters, you can manually run
74 | multiple Python commands at the same time, each specifying the `start` and
75 | `end` indices of the meshes to be processed. An example is shown as
76 | follows:
77 |
78 | ```bash
79 | python tools/shapenet.py --run convert_mesh_to_sdf --start 10000 --end 20000
80 | ```
81 |
82 | `ShapeNetCore.v1` contains 57k meshes. After unzipping, the total size is
83 | about 100G. The total sizes of the generated SDFs and the repaired meshes
84 | are 450G and 90G, respectively. Please make sure your hard disk has enough
85 | space.
86 |
87 | 3. Sample points and ground-truth SDFs for the learning process.
88 |
89 | ```bash
90 | python tools/shapenet.py --run generate_dataset
91 | ```
92 |
93 | 4. If you just want to forward the pretrained network, the test point clouds
94 | (330M) can be downloaded manually from
95 | [here](https://www.dropbox.com/s/us28g6808srcop5/shapenet.test.input.zip?dl=0).
96 | After downloading the zip file, unzip it to the folder
97 | `data/ShapeNet/test.input`.
98 |
99 |
107 |
108 | ### 2.2 Experiment
109 |
110 | 1. **Train**: Run the following command to train the network on 4 GPUs. The
111 | training takes 17 hours on 4 V100 GPUs. The trained weight and log can be
112 | downloaded [here](https://www.dropbox.com/s/v3tnopnqxoqqbvb/shapenet_weights_log.zip?dl=1).
113 |
114 | ```bash
115 | python dualocnn.py --config configs/shapenet.yaml SOLVER.gpu 0,1,2,3
116 | ```
117 |
118 | 2. **Test**: Run the following command to generate the extracted meshes. It is
119 | also possible to specify other trained weights by replacing the parameter
120 | after `SOLVER.ckpt`.
121 |
122 | ```bash
123 | python dualocnn.py --config configs/shapenet_eval.yaml \
124 | SOLVER.ckpt logs/shapenet/shapenet/checkpoints/00300.model.pth
125 | ```
126 |
127 | 3. **Evaluate**: We use the code of
128 | [ConvONet](https://github.com/autonomousvision/convolutional_occupancy_networks.git)
129 | to compute the evaluation metrics. Follow the instructions
130 | [here](https://github.com/wang-ps/ConvONet) to reproduce our results in Table
131 | 1.
132 |
133 | ### 2.3 Generalization
134 |
135 | 1. **Test**: Run the following command to test the trained network on unseen 5
136 | categories of ShapeNet:
137 |
138 | ```bash
139 | python dualocnn.py --config configs/shapenet_unseen5.yaml \
140 | SOLVER.ckpt logs/shapenet/shapenet/checkpoints/00300.model.pth
141 | ```
142 |
143 | 2. **Evaluate**: Follow the instructions
144 | [here](https://github.com/wang-ps/ConvONet) to reproduce our results on the
145 | unseen dataset in Table 1.
146 |
147 | ## 3. Synthetic Scene Reconstruction
148 |
149 | ### 3.1 Data Preparation
150 |
151 | Download and unzip the synthetic scene dataset (205G in total) and the data
152 | splitting filelists by
153 | [ConvONet](https://github.com/autonomousvision/convolutional_occupancy_networks)
154 | via the following command.
155 | If needed, the ground truth meshes can be downloaded from
156 | [here](https://s3.eu-central-1.amazonaws.com/avg-projects/convolutional_occupancy_networks/data/room_watertight_mesh.zip)
157 | (90G).
158 |
159 | ```bash
160 | python tools/room.py --run generate_dataset
161 | ```
162 |
163 | ### 3.2 Experiment
164 |
165 | 1. **Train**: Run the following command to train the network on 4 GPUs. The
166 | training takes 27 hours on 4 V100 GPUs. The trained weight and log can be
167 | downloaded
168 | [here](https://www.dropbox.com/s/t9p8e8tg9rzeaeq/room_weights_log.zip?dl=1).
169 |
170 | ```bash
171 | python dualocnn.py --config configs/synthetic_room.yaml SOLVER.gpu 0,1,2,3
172 | ```
173 |
174 | 2. **Test**: Run the following command to generate the extracted meshes.
175 |
176 | ```bash
177 | python dualocnn.py --config configs/synthetic_room_eval.yaml \
178 | SOLVER.ckpt logs/room/room/checkpoints/00900.model.pth
179 | ```
180 |
181 | 3. **Evaluate**: Follow the instructions
182 | [here](https://github.com/wang-ps/ConvONet) to reproduce our results in Table
183 | 5.
184 |
185 |
186 | ## 4. Unsupervised Surface Reconstruction with DFaust
187 |
188 |
189 | ### 4.1 Data Preparation
190 |
191 | 1. Download the [DFaust](https://dfaust.is.tue.mpg.de/) dataset, unzip the raw
192 | scans into the folder `data/dfaust/scans`, and unzip the ground-truth meshes
193 | into the folder `data/dfaust/mesh_gt`. Note that the ground-truth meshes are
194 | used in computing evaluation metric and NOT used in training.
195 |
196 | 2. Run the following command to prepare the dataset.
197 |
198 | ```bash
199 | python tools/dfaust.py --run genereate_dataset
200 | ```
201 |
202 | 3. For convenience, we also provide the dataset for downloading.
203 |
204 | ```bash
205 | python tools/dfaust.py --run download_dataset
206 | ```
207 |
208 | ### 4.2 Experiment
209 |
210 | 1. **Train**: Run the following command to train the network on 4 GPUs. The
211 | training takes 20 hours on 4 V100 GPUs. The trained weight and log can be
212 | downloaded
213 | [here](https://www.dropbox.com/s/lyhr9n3b7uhjul8/dfaust_weights_log.zip?dl=0).
214 |
215 | ```bash
216 | python dualocnn.py --config configs/dfaust.yaml SOLVER.gpu 0,1,2,3
217 | ```
218 |
219 |
220 | 2. **Test**: Run the following command to generate the meshes with the trained
221 | weights.
222 |
223 | ```bash
224 | python dualocnn.py --config configs/dfaust_eval.yaml \
225 | SOLVER.ckpt logs/dfaust/dfaust/checkpoints/00600.model.pth
226 | ```
227 |
228 | 3. **Evaluate**: To calculate the evaluation metric, we first need to rescale the
229 | mesh into the original size, since the point clouds are scaled during the
230 | data processing stage.
231 |
232 | ```bash
233 | python tools/dfaust.py \
234 | --mesh_folder logs/dfaust_eval/dfaust \
235 | --output_folder logs/dfaust_eval/dfaust_rescale \
236 | --run rescale_mesh
237 | ```
238 |
239 | Then our results in Table 6 can be reproduced in the file `metrics.csv`.
240 |
241 | ```bash
242 | python tools/compute_metrics.py \
243 | --mesh_folder logs/dfaust_eval/dfaust_rescale \
244 | --filelist data/dfaust/filelist/test.txt \
245 | --ref_folder data/dfaust/mesh_gt \
246 | --filename_out logs/dfaust_eval/dfaust_rescale/metrics.csv
247 | ```
248 |
249 |
250 | ### 4.3 Generalization
251 |
252 | In Figures 1 and 11 of our paper, we test the generalization ability of our
253 | network on several out-of-distribution point clouds. Please download the point
254 | clouds from [here](https://www.dropbox.com/s/pzo9xajktaml4hh/shapes.zip?dl=1),
255 | and place the unzipped data in the folder `data/shapes`. Then run the following
256 | command to reproduce the results:
257 |
258 | ```bash
259 | python dualocnn.py --config configs/shapes.yaml \
260 | SOLVER.ckpt logs/dfaust/dfaust/checkpoints/00600.model.pth
261 | ```
262 |
263 |
264 | ## 5. Autoencoder with ShapeNet
265 |
266 | ### 5.1 Data Preparation
267 |
268 | Follow the instructions [here](#21-data-preparation) to prepare the dataset.
269 |
270 |
271 | ### 5.2 Experiment
272 |
273 | 1. **Train**: Run the following command to train the network on 4 GPUs. The
274 | training takes 24 hours on 4 V100 GPUs. The trained weight and log can be
275 | downloaded
276 | [here](https://www.dropbox.com/s/3e4bx3zaj0b85kd/shapenet_ae_weights_log.zip?dl=1).
277 |
278 | ```bash
279 | python dualocnn.py --config configs/shapenet_ae.yaml SOLVER.gpu 0,1,2,3
280 | ```
281 |
282 | 2. **Test**: Run the following command to generate the extracted meshes.
283 |
284 | ```bash
285 | python dualocnn.py --config configs/shapenet_ae_eval.yaml \
286 | SOLVER.ckpt logs/shapenet/ae/checkpoints/00300.model.pth
287 | ```
288 |
289 | 3. **Evaluate**: Run the following command to evaluate the predicted meshes.
290 | Then our results in Table 7 can be reproduced in the file `metrics.4096.csv`.
291 |
292 | ```bash
293 | python tools/compute_metrics.py \
294 | --mesh_folder logs/shapenet_eval/ae \
295 | --filelist data/ShapeNet/filelist/test_im.txt \
296 | --ref_folder data/ShapeNet/mesh \
297 | --num_samples 4096 \
298 | --filename_out logs/shapenet_eval/ae/metrics.4096.csv
299 | ```
300 |
--------------------------------------------------------------------------------
/models/modules.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import math
9 | import torch
10 | import torch.nn
11 | import torch.nn.init
12 | import torch.nn.functional as F
13 | import torch.utils.checkpoint
14 | # import torch_geometric.nn
15 |
16 | from .utils.scatter import scatter_mean
17 |
# Shared hyper-parameters of the `torch.nn.BatchNorm1d` layers created by the
# `*Bn` / `*BnRelu` modules of this file.
bn_momentum, bn_eps = 0.01, 0.001    # the default value of Tensorflow 1.x
# bn_momentum, bn_eps = 0.1, 1e-05   # the default value of pytorch
20 |
21 |
def ckpt_conv_wrapper(conv_op, x, edge_index, edge_type):
    ''' Runs `conv_op(x, edge_index, edge_type)` under gradient checkpointing.

    An extra dummy tensor that requires gradients is passed through the
    checkpoint. It is a workaround for the case where the checkpoint is used
    for the first conv layer, i.e. when no input requires gradients:
    https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11
    '''
    def run_conv(feature, index, etype, _grad_anchor):
        # The anchor is deliberately ignored by the actual computation.
        return conv_op(feature, index, etype)

    grad_anchor = torch.ones(1, dtype=torch.float32, requires_grad=True)
    checkpoint = torch.utils.checkpoint.checkpoint
    return checkpoint(run_conv, x, edge_index, edge_type, grad_anchor)
32 |
33 |
# class GraphConv_v0(torch_geometric.nn.MessagePassing):
class GraphConv_v0:
    ''' This implementation explicitly constructs the self.weights[edge_type],
    thus consuming a lot of computation and memory.

    NOTE(review): the `torch_geometric.nn.MessagePassing` base class is
    commented out above, so this class currently has no base class. As written,
    `super().__init__(aggr='add')` passes a keyword argument to
    `object.__init__`, and `forward` calls `self.propagate`, which is not
    defined in this class. The class looks like a disabled reference
    implementation - confirm before using it.
    '''

    def __init__(self, in_channels, out_channels, n_edge_type=7, avg_degree=7):
        super().__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_edge_type = n_edge_type
        self.avg_degree = avg_degree

        # One (out_channels x in_channels) matrix per edge type.
        self.weights = torch.nn.Parameter(
            torch.Tensor(n_edge_type, out_channels, in_channels))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Uniform initialization in [-a, a] with a = sqrt(3) * std, where the
        # fan-in and fan-out are the channel numbers scaled by `avg_degree`.
        fan_in = self.avg_degree * self.weights.shape[2]
        fan_out = self.avg_degree * self.weights.shape[1]
        std = math.sqrt(2.0 / float(fan_in + fan_out))
        a = math.sqrt(3.0) * std
        torch.nn.init.uniform_(self.weights, -a, a)

    def forward(self, x, edge_index, edge_type):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        return self.propagate(edge_index, x=x, edge_type=edge_type)

    def message(self, x_j, edge_type):
        # Gather one weight matrix per message according to its edge type and
        # apply it to the matching row of `x_j`.
        weights = self.weights[edge_type]  # (N, out_channels, in_channels)
        output = weights @ x_j.unsqueeze(-1)
        return output.squeeze(-1)
68 |
69 |
class GraphConv(torch.nn.Module):
    ''' Graph convolution with one weight block per edge type.

    For every node and edge type, the features of the neighbors connected by
    edges of that type are averaged with `scatter_mean`; the averaged features
    of all edge types are then concatenated per node and multiplied with a
    single weight matrix. If `n_node_type > 1` and node types are given, a
    one-hot encoding of the node type is appended to the input features.
    '''

    def __init__(self, in_channels, out_channels, n_edge_type=7, avg_degree=7,
                 n_node_type=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_edge_type = n_edge_type
        self.avg_degree = avg_degree
        self.n_node_type = n_node_type

        # The one-hot node type only widens the input when there is more than
        # one node type.
        extra_channels = n_node_type if n_node_type > 1 else 0
        rows = n_edge_type * (in_channels + extra_channels)
        self.weights = torch.nn.Parameter(torch.Tensor(rows, out_channels))
        # if n_node_type > 0:
        #   self.node_weights = torch.nn.Parameter(
        #       torch.tensor([0.5 ** i for i in range(n_node_type)]))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        ''' Initializes the weights uniformly in [-a, a], a = sqrt(3) * std. '''
        fan_in = self.avg_degree * self.in_channels
        fan_out = self.avg_degree * self.out_channels
        bound = math.sqrt(3.0) * math.sqrt(2.0 / float(fan_in + fan_out))
        torch.nn.init.uniform_(self.weights, -bound, bound)

    def forward(self, x, edge_index, edge_type, node_type=None):
        if node_type is not None and self.n_node_type > 1:
            # concatenate the one_hot vector
            one_hot = F.one_hot(node_type, num_classes=self.n_node_type)
            x = torch.cat([x, one_hot], dim=1)

        num_nodes = x.shape[0]
        row, col = edge_index[0], edge_index[1]

        # x -> col_data: one slot per (node, edge type) pair
        # TODO: ablation the weights, e.g. torch.pow(0.5, node_type[col])
        slot = row * self.n_edge_type + edge_type
        col_data = scatter_mean(x[col], slot, dim=0, weights=None,
                                dim_size=num_nodes * self.n_edge_type)

        # matrix product
        return col_data.view(num_nodes, -1) @ self.weights

    def extra_repr(self) -> str:
        template = ('channel_in={}, channel_out={}, n_edge_type={}, avg_degree={}, '
                    'n_node_type={}')
        return template.format(self.in_channels, self.out_channels,
                               self.n_edge_type, self.avg_degree,
                               self.n_node_type)
119 |
120 |
class GraphConvBn(torch.nn.Module):
    ''' A `GraphConv` followed by a 1D batch normalization. '''

    def __init__(self, in_channels, out_channels, n_edge_type=7, avg_degree=7,
                 n_node_type=0):
        super().__init__()
        self.conv = GraphConv(
            in_channels, out_channels, n_edge_type=n_edge_type,
            avg_degree=avg_degree, n_node_type=n_node_type)
        self.bn = torch.nn.BatchNorm1d(
            out_channels, eps=bn_eps, momentum=bn_momentum)

    def forward(self, x, edge_index, edge_type, node_type=None):
        return self.bn(self.conv(x, edge_index, edge_type, node_type))
134 |
135 |
class GraphConvBnRelu(torch.nn.Module):
    ''' A `GraphConv` followed by a 1D batch normalization and a ReLU. '''

    def __init__(self, in_channels, out_channels, n_edge_type=7, avg_degree=7,
                 n_node_type=0):
        super().__init__()
        self.conv = GraphConv(
            in_channels, out_channels, n_edge_type=n_edge_type,
            avg_degree=avg_degree, n_node_type=n_node_type)
        self.bn = torch.nn.BatchNorm1d(
            out_channels, eps=bn_eps, momentum=bn_momentum)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x, edge_index, edge_type, node_type=None):
        # Checkpointed alternative for the convolution:
        # feature = ckpt_conv_wrapper(self.conv, x, edge_index, edge_type)
        feature = self.conv(x, edge_index, edge_type, node_type)
        return self.relu(self.bn(feature))
152 |
153 |
class Conv1x1(torch.nn.Module):
    ''' Per-row linear map, implemented with `torch.nn.Linear`.

    The bias is disabled unless `use_bias` is set.
    '''

    def __init__(self, channel_in, channel_out, use_bias=False):
        super().__init__()
        self.linear = torch.nn.Linear(channel_in, channel_out, bias=use_bias)

    def forward(self, x):
        projected = self.linear(x)
        return projected
162 |
163 |
class Conv1x1Bn(torch.nn.Module):
    ''' A bias-free `Conv1x1` followed by a 1D batch normalization. '''

    def __init__(self, channel_in, channel_out):
        super().__init__()
        self.conv = Conv1x1(channel_in, channel_out, use_bias=False)
        self.bn = torch.nn.BatchNorm1d(
            channel_out, eps=bn_eps, momentum=bn_momentum)

    def forward(self, x):
        return self.bn(self.conv(x))
175 |
176 |
class Conv1x1BnRelu(torch.nn.Module):
    ''' A bias-free `Conv1x1` followed by a 1D batch normalization and a ReLU. '''

    def __init__(self, channel_in, channel_out):
        super().__init__()
        self.conv = Conv1x1(channel_in, channel_out, use_bias=False)
        self.bn = torch.nn.BatchNorm1d(
            channel_out, eps=bn_eps, momentum=bn_momentum)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
190 |
191 |
class Upsample(torch.nn.Module):
    ''' Learned upsampling that turns every input row into 8 output rows.

    The weight tensor has the shape (channels, channels, 8) and is
    initialized with `xavier_uniform_`.
    '''

    def __init__(self, channels):
        super().__init__()
        self.channels = channels

        self.weights = torch.nn.Parameter(torch.Tensor(channels, channels, 8))
        torch.nn.init.xavier_uniform_(self.weights)

    def forward(self, x):
        # (N, C) @ (C, C * 8) -> (N, C * 8), re-chunked into rows of C values.
        kernel = self.weights.flatten(1)
        expanded = torch.matmul(x, kernel)
        return expanded.view(-1, self.channels)

    def extra_repr(self):
        return 'channels={}'.format(self.channels)
209 |
210 |
class Downsample(torch.nn.Module):
    ''' Learned downsampling that merges every 8 input rows into one row.

    The weight tensor has the shape (channels, channels, 8) and is
    initialized with `xavier_uniform_`.
    '''

    def __init__(self, channels):
        super().__init__()
        self.channels = channels

        self.weights = torch.nn.Parameter(torch.Tensor(channels, channels, 8))
        torch.nn.init.xavier_uniform_(self.weights)

    def forward(self, x):
        # (N / 8, C * 8) @ (C * 8, C) -> (N / 8, C)
        grouped = x.view(-1, self.channels * 8)
        kernel = self.weights.flatten(1).t()
        return torch.matmul(grouped, kernel)

    def extra_repr(self):
        return 'channels={}'.format(self.channels)
228 |
229 |
class GraphDownsample(torch.nn.Module):
    ''' Downsamples the nodes of the finest layer of a stacked feature tensor.

    The last `numd` rows of the input are downsampled with `Downsample` (8
    rows are merged into one) and interleaved with the `lnumd` rows in front
    of them according to `leaf_mask`. If `channels_out` differs from
    `channels_in`, a `Conv1x1BnRelu` changes the channel number afterwards.
    '''

    def __init__(self, channels_in, channels_out=None):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out or channels_in
        self.downsample = Downsample(channels_in)
        if self.channels_in != self.channels_out:
            self.conv1x1 = Conv1x1BnRelu(self.channels_in, self.channels_out)

    def forward(self, x, leaf_mask, numd, lnumd):
        ''' Returns the features with the finest layer merged into its parent.

        Args:
          x: Feature tensor of shape (N, channels_in); per the comments below,
              the last `numd` rows are the nodes at the current depth and the
              `lnumd` rows before them the leaf nodes at (depth-1).
          leaf_mask: Boolean mask over the nodes at (depth-1); True entries
              receive the existing leaf rows, False entries the downsampled
              rows.
          numd: Number of rows at the current depth.
          lnumd: Number of leaf rows at (depth-1).
        '''
        # downsample nodes at layer depth
        outd = self.downsample(x[-numd:])

        # get the nodes at layer (depth-1); `new_zeros` keeps the dtype and
        # the device of `x`, so the masked assignments below also work for
        # half/double features (a float32 buffer raises a dtype mismatch).
        out = x.new_zeros(leaf_mask.shape[0], x.shape[1])
        out[leaf_mask] = x[-lnumd-numd:-numd]
        out[leaf_mask.logical_not()] = outd

        # construct the final output
        out = torch.cat([x[:-numd-lnumd], out], dim=0)

        if self.channels_in != self.channels_out:
            out = self.conv1x1(out)
        return out

    def extra_repr(self):
        return 'channels_in={}, channels_out={}'.format(
            self.channels_in, self.channels_out)
260 |
261 |
class GraphMaxpool(torch.nn.Module):
    ''' Max-pools the nodes of the finest layer of a stacked feature tensor.

    The last `numd` rows of the input are pooled in groups of 8 consecutive
    rows and interleaved with the `lnumd` rows in front of them according to
    `leaf_mask`.
    '''

    def __init__(self):
        super().__init__()

    def forward(self, x, leaf_mask, numd, lnumd):
        ''' Returns the features with the finest layer pooled into its parent.

        Args:
          x: Feature tensor of shape (N, C); per the comments below, the last
              `numd` rows are the nodes at the current depth and the `lnumd`
              rows before them the leaf nodes at (depth-1).
          leaf_mask: Boolean mask over the nodes at (depth-1); True entries
              receive the existing leaf rows, False entries the pooled rows.
          numd: Number of rows at the current depth.
          lnumd: Number of leaf rows at (depth-1).
        '''
        # downsample nodes at layer depth
        channel = x.shape[1]
        outd, _ = x[-numd:].view(-1, 8, channel).max(dim=1)

        # get the nodes at layer (depth-1); `new_zeros` keeps the dtype and
        # the device of `x`, so the masked assignments below also work for
        # half/double features (a float32 buffer raises a dtype mismatch).
        out = x.new_zeros(leaf_mask.shape[0], channel)
        out[leaf_mask] = x[-lnumd-numd:-numd]
        out[leaf_mask.logical_not()] = outd

        # construct the final output
        return torch.cat([x[:-numd-lnumd], out], dim=0)
281 |
282 |
class GraphUpsample(torch.nn.Module):
    ''' Upsamples the non-leaf nodes of the last layer of a stacked tensor.

    Among the last `numd` rows of the input, the rows flagged as leaves by
    `leaf_mask` are kept and the remaining rows are expanded with `Upsample`
    (one row becomes 8). If `channels_out` differs from `channels_in`, a
    `Conv1x1BnRelu` changes the channel number afterwards.
    '''

    def __init__(self, channels_in, channels_out=None):
        super().__init__()
        self.channels_in = channels_in
        self.channels_out = channels_out or channels_in
        self.upsample = Upsample(channels_in)
        if self.channels_in != self.channels_out:
            self.conv1x1 = Conv1x1BnRelu(self.channels_in, self.channels_out)

    def forward(self, x, leaf_mask, numd):
        # upsample nodes at layer (depth-1)
        head, last_layer = x[:-numd], x[-numd:]
        leaves = last_layer[leaf_mask]
        children = self.upsample(last_layer[leaf_mask.logical_not()])

        # construct the final output
        stacked = torch.cat([head, leaves, children], dim=0)
        if self.channels_in == self.channels_out:
            return stacked
        return self.conv1x1(stacked)

    def extra_repr(self):
        return 'channels_in={}, channels_out={}'.format(
            self.channels_in, self.channels_out)
308 |
309 |
class GraphResBlock2(torch.nn.Module):
    ''' Residual block made of two graph convolutions.

    The main branch is `GraphConvBnRelu` followed by `GraphConvBn`; the
    shortcut is the identity, or a `Conv1x1Bn` when the input and output
    channel numbers differ. A ReLU is applied to the sum of both branches.
    '''

    def __init__(self, channel_in, channel_out, bottleneck=1, n_edge_type=7,
                 avg_degree=7, n_node_type=0):
        super().__init__()
        self.channel_in = channel_in
        self.channel_out = channel_out
        self.bottleneck = bottleneck
        channel_mid = int(channel_out / bottleneck)
        graph_args = (n_edge_type, avg_degree, n_node_type)

        self.conva = GraphConvBnRelu(channel_in, channel_mid, *graph_args)
        self.convb = GraphConvBn(channel_mid, channel_out, *graph_args)
        if self.channel_in != self.channel_out:
            self.conv1x1 = Conv1x1Bn(channel_in, channel_out)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x, edge_index, edge_type, node_type):
        branch = self.conva(x, edge_index, edge_type, node_type)
        branch = self.convb(branch, edge_index, edge_type, node_type)

        # Project the shortcut only when the channel numbers differ.
        shortcut = x
        if self.channel_in != self.channel_out:
            shortcut = self.conv1x1(x)

        return self.relu(branch + shortcut)
337 |
338 |
class GraphResBlock(torch.nn.Module):
  r''' A bottleneck residual block: 1x1 conv -> graph conv -> 1x1 conv.

  The shortcut is the identity, or a :obj:`Conv1x1Bn` when the input and
  output channel numbers differ. A ReLU is applied to the sum.

  Args:
    channel_in: Channel number of the input feature.
    channel_out: Channel number of the output feature.
    bottleneck: The intermediate channel number is
        :obj:`int(channel_out / bottleneck)`.
    n_edge_type, avg_degree, n_node_type: Forwarded unchanged to the graph
        convolution.
  '''

  def __init__(self, channel_in, channel_out, bottleneck=4, n_edge_type=7,
               avg_degree=7, n_node_type=0):
    super().__init__()
    self.channel_in = channel_in
    self.channel_out = channel_out
    self.bottleneck = bottleneck
    hidden = int(channel_out / bottleneck)

    self.conv1x1a = Conv1x1BnRelu(channel_in, hidden)
    self.conv = GraphConvBnRelu(
        hidden, hidden, n_edge_type, avg_degree, n_node_type)
    self.conv1x1b = Conv1x1Bn(hidden, channel_out)
    if self.channel_in != self.channel_out:
      self.conv1x1c = Conv1x1Bn(channel_in, channel_out)
    self.relu = torch.nn.ReLU(inplace=True)

  def forward(self, x, edge_index, edge_type, node_type):
    # main branch: squeeze, graph convolution, expand
    branch = self.conv1x1a(x)
    branch = self.conv(branch, edge_index, edge_type, node_type)
    branch = self.conv1x1b(branch)

    # shortcut branch: project only when the channel number changes
    shortcut = x
    if self.channel_in != self.channel_out:
      shortcut = self.conv1x1c(x)

    return self.relu(branch + shortcut)
367 |
368 |
class GraphResBlocks(torch.nn.Module):
  r''' A sequence of :attr:`resblk_num` residual blocks.

  The first block maps :attr:`channel_in` to :attr:`channel_out`; the
  remaining blocks keep :attr:`channel_out` channels.

  Args:
    channel_in: Channel number of the input feature.
    channel_out: Channel number of the output feature.
    resblk_num: Number of residual blocks.
    bottleneck, n_edge_type, avg_degree, n_node_type: Forwarded unchanged to
        every residual block.
    resblk_type: :obj:`'bottleneck'` selects :class:`GraphResBlock` and
        :obj:`'basic'` selects :class:`GraphResBlock2`.

  Raises:
    ValueError: If :attr:`resblk_type` is not one of the supported names.
  '''

  def __init__(self, channel_in, channel_out, resblk_num, bottleneck=4,
               n_edge_type=7, avg_degree=7, n_node_type=0,
               resblk_type='bottleneck'):
    super().__init__()
    self.resblk_num = resblk_num
    channels = [channel_in] + [channel_out] * resblk_num
    ResBlk = self._get_resblock(resblk_type)
    self.resblks = torch.nn.ModuleList([
        ResBlk(channels[i], channels[i+1], bottleneck,
               n_edge_type, avg_degree, n_node_type)
        for i in range(self.resblk_num)])

  def _get_resblock(self, resblk_type):
    r''' Returns the residual block class named by :attr:`resblk_type`.
    '''

    if resblk_type == 'bottleneck':
      return GraphResBlock
    if resblk_type == 'basic':
      return GraphResBlock2
    # Name the offending value so that a config typo is easy to spot.
    raise ValueError(
        "Unsupported resblk_type %r, expected 'bottleneck' or 'basic'"
        % (resblk_type,))

  def forward(self, data, edge_index, edge_type, node_type):
    for resblk in self.resblks:
      data = resblk(data, edge_index, edge_type, node_type)
    return data
395 |
396 |
def doctree_align(value, key, query):
  r''' Gathers the rows of :attr:`value` whose :attr:`key` matches
  :attr:`query`, and fills the rows without a match with zeros.

  Args:
    value: A 2D tensor; row :obj:`i` belongs to :obj:`key[i]`.
    key: A 1D tensor sorted in ascending order (required by
        :func:`torch.searchsorted`).
    query: A 1D tensor of keys to look up. It is NOT modified.

  Returns:
    A tensor of shape :obj:`(query.shape[0], value.shape[1])` with the dtype
    and device of :attr:`value`.
  '''

  out = torch.zeros(query.shape[0], value.shape[1],
                    dtype=value.dtype, device=value.device)
  if key.numel() == 0 or query.numel() == 0:
    return out  # nothing can be matched

  # search; a query larger than every key yields the index `len(key)`, which
  # is clamped to the last valid index. The equality test below then fails
  # for it, so no sentinel has to be written into the caller's `query`.
  idx = torch.searchsorted(key, query).clamp(max=key.numel() - 1)
  found = key[idx] == query

  # assign the found value to the output
  out[found] = value[idx[found]]
  return out
410 |
--------------------------------------------------------------------------------
/tools/shapenet.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import time
10 | import wget
11 | import shutil
12 | import torch
13 | import ocnn
14 | import trimesh
15 | import logging
16 | import mesh2sdf
17 | import zipfile
18 | import argparse
19 | import numpy as np
20 | from tqdm import tqdm
21 | from plyfile import PlyData, PlyElement
22 |
# Only report errors from the "trimesh" logger.
logger = logging.getLogger("trimesh")
logger.setLevel(logging.ERROR)

# Command line interface. NOTE: the arguments are parsed at import time, so
# importing this module without `--run` makes argparse exit.
#   --run   : name of the function in this file to execute (see __main__)
#   --start : first index into the filelist for the per-shape loops
#   --end   : end index (exclusive, it is used as the stop of `range`)
parser = argparse.ArgumentParser()
parser.add_argument('--run', type=str, required=True)
parser.add_argument('--start', type=int, default=0)
parser.add_argument('--end', type=int, default=45572)
args = parser.parse_args()

size = 128  # resolution of SDF
level = 0.015  # 2/128 = 0.015625
shape_scale = 0.5  # rescale the shape into [-0.5, 0.5]
# The project folder is the parent of the folder containing this file; all
# data is read from and written to `<project>/data/ShapeNet`.
project_folder = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
root_folder = os.path.join(project_folder, 'data/ShapeNet')
37 |
38 |
def create_flag_file(filename):
  r''' Creates a flag file to indicate whether some time-consuming works
  have been done.

  The parent folder is created when missing. The file content is the text
  :obj:`'succ @ '` followed by the current time.

  Args:
    filename: Path of the flag file to write.
  '''

  folder = os.path.dirname(filename)
  # `folder` is empty for a bare filename, and `os.makedirs('')` would raise;
  # `exist_ok` avoids the race between an existence test and the creation.
  if folder:
    os.makedirs(folder, exist_ok=True)
  with open(filename, 'w') as fid:
    fid.write('succ @ ' + time.ctime())
49 |
50 |
def check_folder(filenames: list):
  r''' Makes sure that the parent folder of every given filename exists,
  creating the missing folders.

  Args:
    filenames: A list of file paths.
  '''

  for filename in filenames:
    folder = os.path.dirname(filename)
    # Skip bare filenames (empty parent); `os.makedirs('')` would raise.
    if folder:
      os.makedirs(folder, exist_ok=True)
59 |
60 |
def get_filenames(filelist):
  r''' Gets filenames from a filelist.

  Args:
    filelist: Name of a file in `<root_folder>/filelist`.

  Returns:
    A list with the first whitespace-separated token of every non-blank
    line of the filelist.
  '''

  filelist = os.path.join(root_folder, 'filelist', filelist)
  with open(filelist, 'r') as fid:
    # Blank lines (e.g. trailing newlines) have no first token; skip them
    # instead of failing with an IndexError.
    return [line.split()[0] for line in fid if line.strip()]
70 |
71 |
def unzip_shapenet():
  r''' Unzips `ShapeNetCore.v1.zip` and then every zip file found inside the
  extracted `ShapeNetCore.v1` folder.

  Each of the two stages is skipped when its flag file already exists. The
  inner zip files are deleted after extraction.
  '''

  # stage 1: the top-level archive
  archive = os.path.join(root_folder, 'ShapeNetCore.v1.zip')
  flag_file = os.path.join(root_folder, 'flags/unzip_shapenet_succ')
  if not os.path.exists(flag_file):
    print('-> Unzip ShapeNetCore.v1.zip.')
    with zipfile.ZipFile(archive, 'r') as zip_ref:
      zip_ref.extractall(root_folder)
    create_flag_file(flag_file)

  # stage 2: the zip files contained in the extracted folder
  folder = os.path.join(root_folder, 'ShapeNetCore.v1')
  flag_file = os.path.join(root_folder, 'flags/unzip_shapenet_all_succ')
  if not os.path.exists(flag_file):
    print('-> Unzip all zip files in ShapeNetCore.v1.')
    for entry in os.listdir(folder):
      if not entry.endswith('.zip'):
        continue
      print('- Unzip %s' % entry)
      zipname = os.path.join(folder, entry)
      with zipfile.ZipFile(zipname, 'r') as zip_ref:
        zip_ref.extractall(folder)
      os.remove(zipname)
    create_flag_file(flag_file)
97 |
98 |
def download_filelist():
  r''' Downloads the zipped filelists, extracts them into
  `<root_folder>/filelist` and removes the downloaded archive.

  Nothing is done when the flag file of this step already exists.
  '''

  flag_file = os.path.join(root_folder, 'flags/download_filelist_succ')
  if os.path.exists(flag_file):
    return

  print('-> Download the filelist.')
  url = 'https://www.dropbox.com/s/4jvam486l8961t7/shapenet.filelist.zip?dl=1'
  archive = os.path.join(root_folder, 'filelist.zip')
  wget.download(url, archive, bar=None)

  target = os.path.join(root_folder, 'filelist')
  with zipfile.ZipFile(archive, 'r') as zip_ref:
    zip_ref.extractall(path=target)
  os.remove(archive)
  create_flag_file(flag_file)
115 |
116 |
def run_mesh2sdf():
  r''' Runs mesh2sdf on the raw ShapeNet meshes listed in `all.txt`.

  For every index in :obj:`[args.start, args.end)` three files are written:
  the bounding box of the raw mesh (`bbox/*.npz`), the SDF volume
  (`sdf/*.npy`) and the mesh returned by mesh2sdf (`mesh/*.obj`). Shapes
  whose output mesh already exists are skipped.
  '''

  print('-> Run mesh2sdf.')
  mesh_scale = 0.8
  names = get_filenames('all.txt')
  for index in tqdm(range(args.start, args.end), ncols=80):
    name = names[index]
    path_raw = os.path.join(root_folder, 'ShapeNetCore.v1', name, 'model.obj')
    path_obj = os.path.join(root_folder, 'mesh', name + '.obj')
    path_box = os.path.join(root_folder, 'bbox', name + '.npz')
    path_npy = os.path.join(root_folder, 'sdf', name + '.npy')
    check_folder([path_obj, path_box, path_npy])
    if os.path.exists(path_obj):
      continue

    # load the raw mesh
    mesh = trimesh.load(path_raw, force='mesh')

    # center the mesh and scale its longest bounding-box side to
    # 2 * mesh_scale, i.e. the mesh lies in [-mesh_scale, mesh_scale]
    vertices = mesh.vertices
    bbmin = vertices.min(0)
    bbmax = vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    scale = 2.0 * mesh_scale / (bbmax - bbmin).max()
    vertices = (vertices - center) * scale

    # run mesh2sdf, then rescale the returned mesh by `shape_scale`
    sdf, mesh_new = mesh2sdf.compute(
        vertices, mesh.faces, size, fix=True, level=level, return_mesh=True)
    mesh_new.vertices = mesh_new.vertices * shape_scale

    # save
    np.savez(path_box, bbmax=bbmax, bbmin=bbmin, mul=mesh_scale)
    np.save(path_npy, sdf)
    mesh_new.export(path_obj)
153 |
154 |
def sample_pts_from_mesh():
  r''' Samples 40000 surface points with face normals from every mesh in the
  `mesh` folder and stores them as float16 in `dataset/*/pointcloud.npz`.

  Only the indices in :obj:`[args.start, args.end)` of `all.txt` are
  processed; shapes whose output file already exists are skipped.
  '''

  print('-> Run sample_pts_from_mesh.')
  num_samples = 40000
  mesh_folder = os.path.join(root_folder, 'mesh')
  output_folder = os.path.join(root_folder, 'dataset')
  names = get_filenames('all.txt')
  for index in tqdm(range(args.start, args.end), ncols=80):
    name = names[index]
    path_obj = os.path.join(mesh_folder, name + '.obj')
    path_pts = os.path.join(output_folder, name, 'pointcloud.npz')
    check_folder([path_pts])
    if os.path.exists(path_pts):
      continue

    # sample points; the normal of a point is the normal of its face
    mesh = trimesh.load(path_obj, force='mesh')
    points, face_idx = trimesh.sample.sample_surface(mesh, num_samples)
    normals = mesh.face_normals[face_idx]

    # save points
    np.savez(path_pts,
             points=points.astype(np.float16),
             normals=normals.astype(np.float16))
179 |
180 |
def sample_sdf():
  r''' Samples ground-truth SDF values for training.

  For every shape with index in :obj:`[args.start, args.end)` of `all.txt`,
  an octree is built from the stored point cloud, random points are drawn
  inside the octree nodes of depth :obj:`full_depth` to :obj:`depth`, and the
  SDF value and a normalized gradient are interpolated from the SDF volume.
  The result is saved to `dataset/*/sdf.npz` with the keys `points`, `grad`
  and `sdf`. Shapes whose output file already exists are skipped.
  '''

  # constants
  depth, full_depth = 6, 4
  sample_num = 4  # number of samples in each octree node
  # the 8 corners of a unit cell; the corner index is 4 * x + 2 * y + z
  grid = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
                   [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]])

  print('-> Sample SDFs from the ground truth.')
  filenames = get_filenames('all.txt')
  for i in tqdm(range(args.start, args.end), ncols=80):
    filename = filenames[i]
    dataset_folder = os.path.join(root_folder, 'dataset')
    filename_sdf = os.path.join(root_folder, 'sdf', filename + '.npy')
    filename_pts = os.path.join(dataset_folder, filename, 'pointcloud.npz')
    filename_out = os.path.join(dataset_folder, filename, 'sdf.npz')
    if os.path.exists(filename_out): continue

    # load data
    pts = np.load(filename_pts)
    sdf = np.load(filename_sdf)
    sdf = torch.from_numpy(sdf)
    points = pts['points'].astype(np.float32)
    normals = pts['normals'].astype(np.float32)
    points = points / shape_scale  # rescale points to [-1, 1]

    # build octree
    points = ocnn.points_new(
        torch.from_numpy(points), torch.from_numpy(normals),
        torch.Tensor(), torch.Tensor())
    octree2points = ocnn.Points2Octree(depth=depth, full_depth=full_depth)
    octree = octree2points(points)

    # sample points and grads according to the xyz
    xyzs, grads, sdfs = [], [], []
    for d in range(full_depth, depth + 1):
      xyz = ocnn.octree_property(octree, 'xyz', d)
      xyz = ocnn.octree_decode_key(xyz)

      # sample k points in each octree node
      xyz = xyz[:, :3].float()  # + 0.5 -> octree node center
      # add a random offset in [0, 1) to the node coordinate
      xyz = xyz.unsqueeze(1) + torch.rand(xyz.shape[0], sample_num, 3)
      xyz = xyz.view(-1, 3)  # (N, 3)
      # scale the depth-d coordinates to the SDF grid of resolution `size`
      xyz = xyz * (size / 2 ** d)  # normalize to [0, 2^sdf_depth]
      # 127 is `size - 1` for size = 128: the cell corner `floor(xyz) + 1`
      # used below has to stay inside the SDF volume
      xyz = xyz[(xyz < 127).all(dim=1)]  # remove out-of-bound points
      xyzs.append(xyz)

      # interpolate the sdf values
      xyzi = torch.floor(xyz)  # the integer part (N, 3)
      corners = xyzi.unsqueeze(1) + grid  # (N, 8, 3)
      coordsf = xyz.unsqueeze(1) - corners  # (N, 8, 3), in [-1.0, 1.0]
      # trilinear weights: product over the last (xyz) dimension
      weights = (1 - coordsf.abs()).prod(dim=-1)  # (N, 8)
      corners = corners.long().view(-1, 3)
      x, y, z = corners[:, 0], corners[:, 1], corners[:, 2]
      s = sdf[x, y, z].view(-1, 8)
      sw = torch.sum(s * weights, dim=1)
      sdfs.append(sw)

      # calc the gradient: for each axis, sum the differences between the 4
      # corners with coordinate 1 and the 4 corners with coordinate 0
      gx = s[:, 4] - s[:, 0] + s[:, 5] - s[:, 1] + \
          s[:, 6] - s[:, 2] + s[:, 7] - s[:, 3]  # noqa
      gy = s[:, 2] - s[:, 0] + s[:, 3] - s[:, 1] + \
          s[:, 6] - s[:, 4] + s[:, 7] - s[:, 5]  # noqa
      gz = s[:, 1] - s[:, 0] + s[:, 3] - s[:, 2] + \
          s[:, 5] - s[:, 4] + s[:, 7] - s[:, 6]  # noqa
      grad = torch.stack([gx, gy, gz], dim=-1)
      # normalize to unit length; the epsilon avoids a division by zero
      norm = torch.sqrt(torch.sum(grad ** 2, dim=-1, keepdims=True))
      grad = grad / (norm + 1.0e-8)
      grads.append(grad)

    # concat the results
    xyzs = torch.cat(xyzs, dim=0).numpy()
    # 64 is `size / 2` for size = 128: map grid coordinates to [-1, 1), then
    # apply `shape_scale`
    points = (xyzs / 64 - 1).astype(np.float16) * shape_scale  # !shape_scale
    grads = torch.cat(grads, dim=0).numpy().astype(np.float16)
    sdfs = torch.cat(sdfs, dim=0).numpy().astype(np.float16)

    # save results
    np.savez(filename_out, points=points, grad=grads, sdf=sdfs)
262 |
263 |
def sample_occu():
  r''' Samples occupancy values for evaluating the IoU following ConvONet.

  For every shape listed in `test.txt` and `test_unseen5.txt`, 100000
  uniformly random points are drawn, their SDF values are trilinearly
  interpolated from the stored SDF volume, and a point is marked as occupied
  when its SDF value is negative. The points (float16) and the bit-packed
  occupancies are saved to `dataset/*/points.npz`. Shapes without an SDF
  file and shapes whose output file already exists are skipped.
  '''

  num_samples = 100000
  # the 8 corners of a unit cell; the corner index is 4 * x + 2 * y + z
  grid = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
                   [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]])

  filenames = get_filenames('test.txt') + get_filenames('test_unseen5.txt')
  for filename in tqdm(filenames, ncols=80):
    filename_sdf = os.path.join(root_folder, 'sdf', filename + '.npy')
    # `np.savez` appends '.npz' to a name without that extension, so the
    # existence test must use the full name; otherwise existing outputs are
    # never found and get regenerated with new random samples.
    filename_occu = os.path.join(
        root_folder, 'dataset', filename, 'points.npz')
    if os.path.exists(filename_occu) or (not os.path.exists(filename_sdf)):
      continue

    sdf = np.load(filename_sdf)
    # make sure the interpolation is well defined: the largest grid
    # coordinate stays below `size - 1`, so the corner `floor + 1` is valid
    factor = (size - 1) / size
    points_uniform = np.random.rand(num_samples, 3) * factor  # in [0, factor)
    points = (points_uniform - 0.5) * (2 * shape_scale)  # !!! rescale
    points = points.astype(np.float16)

    # interpolate the sdf values
    xyz = points_uniform * size  # in [0, size - 1)
    xyzi = np.floor(xyz)  # the integer part (N, 3)
    corners = np.expand_dims(xyzi, 1) + grid  # (N, 8, 3)
    coordsf = np.expand_dims(xyz, 1) - corners  # (N, 8, 3), in [-1.0, 1.0]
    weights = np.prod(1 - np.abs(coordsf), axis=-1)  # (N, 8)

    corners = np.reshape(corners.astype(np.int64), (-1, 3))
    x, y, z = corners[:, 0], corners[:, 1], corners[:, 2]
    values = np.reshape(sdf[x, y, z], (-1, 8))
    value = np.sum(values * weights, axis=1)
    occu = value < 0  # inside the shape
    occu = np.packbits(occu)

    # save
    np.savez(filename_occu, points=points, occupancies=occu)
302 |
303 |
def generate_test_points():
  r''' Writes noisy test point clouds in `ply` format.

  For every shape listed in `test.txt` and `test_unseen5.txt` that has a
  `pointcloud.npz`, 3000 points are drawn with replacement, Gaussian noise
  with a standard deviation of 0.005 is added, and the result is written to
  `test.input/*.ply` as a `vertex` element with float32 `x`, `y`, `z`.
  '''

  noise_std = 0.005
  point_sample_num = 3000
  npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
  names = get_filenames('test.txt') + get_filenames('test_unseen5.txt')
  for name in tqdm(names, ncols=80):
    path_pts = os.path.join(root_folder, 'dataset', name, 'pointcloud.npz')
    path_ply = os.path.join(root_folder, 'test.input', name + '.ply')
    if not os.path.exists(path_pts):
      continue
    check_folder([path_ply])

    # sample points; the noise is drawn before the indices
    points = np.load(path_pts)['points'].astype(np.float32)
    noise = noise_std * np.random.randn(point_sample_num, 3)
    rand_idx = np.random.choice(points.shape[0], size=point_sample_num)
    points_noise = points[rand_idx] + noise

    # save ply: one tuple of 3 Python floats per vertex
    vertices = [tuple(float(c) for c in row[:3]) for row in points_noise]
    structured_array = np.array(vertices, dtype=npy_types)
    element = PlyElement.describe(structured_array, 'vertex')
    PlyData([element]).write(path_ply)
337 |
338 |
def download_dataset():
  r''' Downloads the zipped dataset and extracts it into the root folder.

  Nothing is done when the flag file of this step already exists. The
  downloaded archive is kept on disk after the extraction.
  '''

  flag_file = os.path.join(root_folder, 'flags/download_dataset_succ')
  if os.path.exists(flag_file):
    return

  print('-> Download the dataset.')
  url = 'https://www.dropbox.com/s/mc3lrwqpmnfq3j8/shapenet.dataset.zip?dl=1'
  archive = os.path.join(root_folder, 'shapenet.dataset.zip')
  wget.download(url, archive)

  with zipfile.ZipFile(archive, 'r') as zip_ref:
    zip_ref.extractall(path=root_folder)
  create_flag_file(flag_file)
354 |
355 |
def generate_dataset_unseen5():
  r''' Creates the unseen5 dataset by moving 5 category folders from
  `dataset` into `dataset.unseen5`. Categories without a folder in `dataset`
  are left out.
  '''

  dataset_folder = os.path.join(root_folder, 'dataset')
  unseen5_folder = os.path.join(root_folder, 'dataset.unseen5')
  if not os.path.exists(unseen5_folder):
    os.makedirs(unseen5_folder)

  categories = ('02808440', '02773838', '02818832', '02876657', '03938244')
  for category in categories:
    source = os.path.join(dataset_folder, category)
    if not os.path.exists(source):
      continue
    shutil.move(source, unseen5_folder)
368 |
369 |
def copy_convonet_filelists():
  r''' Copies the files named in `filelist/lists.txt` from the filelist
  folder into the dataset folders. They are needed when calculating the
  evaluation metrics.

  The destination is the source path with `filelist/convonet.filelist`
  replaced by `dataset` and `filelist/unseen5.filelist` replaced by
  `dataset.unseen5`. Existing destination files are not overwritten.
  '''

  filelist_folder = os.path.join(root_folder, 'filelist')
  with open(os.path.join(root_folder, 'filelist/lists.txt'), 'r') as fid:
    names = [line.split()[0] for line in fid.readlines()]

  for name in names:
    src_name = os.path.join(filelist_folder, name)
    des_name = src_name.replace('filelist/convonet.filelist', 'dataset')
    des_name = des_name.replace('filelist/unseen5.filelist', 'dataset.unseen5')
    if os.path.exists(des_name):
      continue
    shutil.copy(src_name, des_name)
385 |
386 |
def convert_mesh_to_sdf():
  r''' Runs the mesh-to-SDF pipeline: unzips ShapeNet, downloads the
  filelists and runs mesh2sdf, in this order.
  '''

  for step in (unzip_shapenet, download_filelist, run_mesh2sdf):
    step()
391 |
392 |
def generate_dataset():
  r''' Runs all dataset generation steps one after another: point sampling,
  SDF sampling, occupancy sampling, test point generation, the unseen5 split
  and the copy of the ConvONet filelists.
  '''

  steps = (sample_pts_from_mesh, sample_sdf, sample_occu,
           generate_test_points, generate_dataset_unseen5,
           copy_convonet_filelists)
  for step in steps:
    step()
400 |
401 |
if __name__ == '__main__':
  # Run the function named by `--run`. The name is looked up explicitly
  # instead of passing the command line string to `eval`, which would execute
  # arbitrary expressions and report a typo as an obscure NameError.
  task = globals().get(args.run)
  if not callable(task):
    raise ValueError('Unknown function for --run: %s' % args.run)
  task()
404 |
--------------------------------------------------------------------------------
/models/dual_octree.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import ocnn
10 | # import torch_sparse
11 | import numpy as np
12 |
13 |
14 | class DualOctree:
15 |
16 | def __init__(self, octree):
17 | # prime octree
18 | self.octree = octree
19 | self.device = octree.device
20 | self.depth = ocnn.octree_property(octree, 'depth').item()
21 | self.full_depth = ocnn.octree_property(octree, 'full_depth').item()
22 | self.batch_size = ocnn.octree_property(octree, 'batch_size').item()
23 |
24 | # node numbers
25 | self.nnum = ocnn.octree_property(octree, 'node_num')
26 | self.nenum = ocnn.octree_property(octree, 'node_num_ne')
27 | self.ncum = ocnn.octree_property(octree, 'node_num_cum')
28 | self.lnum = self.nnum - self.nenum # leaf node numbers
29 |
30 | # node properties
31 | xyzi = ocnn.octree_property(octree, 'xyz')
32 | self.xyzi = ocnn.octree_decode_key(xyzi)
33 | self.xyz = self.xyzi[:, :3]
34 | self.batch = self.xyzi[:, 3]
35 | self.node_depth = self._node_depth()
36 | self.child = ocnn.octree_property(octree, 'child')
37 | self.key = ocnn.octree_property(octree, 'key')
38 | self.keyd = self.key | (self.node_depth << 58)
39 |
40 | # build lookup tables
41 | self._lookup_table()
42 |
43 | # build dual graph
44 | self._graph = [dict()] * (self.depth + 1) # the internal graph
45 | self.graph = [dict()] * (self.depth + 1) # the output graph
46 | self.build_dual_graph()
47 |
48 | def _lookup_table(self):
49 | self.ngh = torch.tensor(
50 | [[0, 0, 1], [0, 0, -1], # up, down
51 | [0, 1, 0], [0, -1, 0], # right, left
52 | [1, 0, 0], [-1, 0, 0]], # front, back
53 | dtype=torch.int16, device=self.device)
54 | self.dir_table = torch.tensor(
55 | [[1, 3, 5, 7], [0, 2, 4, 6], # up, down
56 | [2, 3, 6, 7], [0, 1, 4, 5], # right, left
57 | [4, 5, 6, 7], [0, 1, 2, 3]], # front, back
58 | dtype=torch.int64, device=self.device)
59 | self.dir_type = torch.tensor(
60 | [0, 1, 2, 3, 4, 5],
61 | dtype=torch.int64, device=self.device)
62 | self.remap = torch.tensor(
63 | [1, 0, 3, 2, 5, 4],
64 | dtype=torch.int64, device=self.device)
65 | self.inter_row = torch.tensor(
66 | [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3,
67 | 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7],
68 | dtype=torch.int64, device=self.device)
69 | self.inter_col = torch.tensor(
70 | [1, 2, 4, 0, 3, 5, 0, 3, 6, 1, 2, 7,
71 | 0, 5, 6, 1, 4, 7, 2, 4, 7, 3, 5, 6],
72 | dtype=torch.int64, device=self.device)
73 | self.inter_dir = torch.tensor(
74 | [0, 2, 4, 1, 2, 4, 3, 0, 4, 3, 1, 4,
75 | 5, 0, 2, 5, 1, 2, 5, 3, 0, 5, 3, 1],
76 | dtype=torch.int64, device=self.device)
77 |
78 | def _node_depth(self):
79 | nd = [torch.tensor([d], dtype=torch.int64, device=self.device).expand(
80 | self.nnum[d]) for d in range(self.depth + 1)]
81 | return torch.cat(nd)
82 |
83 | def build_dual_graph(self):
84 | self._graph[self.full_depth] = self.dense_graph(self.full_depth)
85 | for d in range(self.full_depth + 1, self.depth + 1):
86 | self._graph[d] = self.sparse_graph(d, self._graph[d - 1])
87 |
88 | def dense_graph(self, depth=3):
89 | K = 6 # each node has at most K neighboring node
90 | bnd = 2 ** depth
91 | num = bnd ** 3
92 |
93 | ki = torch.arange(0, num, dtype=torch.int64, device=self.device)
94 | xi = ocnn.octree_key2xyz(ki, depth)
95 | xi = ocnn.octree_decode_key(xi)[:, :3]
96 | xj = xi.unsqueeze(1) + self.ngh # [N, K, 3]
97 |
98 | row = ki.unsqueeze(1).repeat(1, K).view(-1)
99 | zj = torch.zeros(num, K, 1, dtype=torch.int16, device=self.device)
100 | kj = torch.cat([xj, zj], dim=-1).view(-1, 4)
101 | kj = ocnn.octree_encode_key(kj)
102 | # for full octree, the octree key is the index
103 | col = ocnn.octree_xyz2key(kj, depth)
104 |
105 | valid = torch.logical_and(xj > -1, xj < bnd) # out-of-bound
106 | valid = torch.all(valid, dim=-1).view(-1)
107 | row, col = row[valid], col[valid]
108 |
109 | edge_dir = self.dir_type.repeat(num)
110 | edge_dir = edge_dir[valid]
111 |
112 | # deal with batches
113 | dis = torch.arange(self.batch_size, dtype=torch.int64, device=self.device)
114 | dis = dis.unsqueeze(1) * num + self.ncum[depth] # NOTE:add self.ncum[depth]
115 | row = row.unsqueeze(0) + dis
116 | col = col.unsqueeze(0) + dis
117 | edge_dir = edge_dir.unsqueeze(0) + torch.zeros_like(dis)
118 | # rowptr = torch.ops.torch_sparse.ind2ptr(row, num)
119 | return {'edge_idx': torch.stack([row.view(-1), col.view(-1)]),
120 | 'edge_dir': edge_dir.view(-1)}
121 |
122 | def _internal_edges(self, nnum, dis=0):
123 | assert(nnum % 8 == 0)
124 | d = torch.arange(0, nnum / 8, dtype=torch.int64, device=self.device)
125 | d = torch.unsqueeze(d * 8 + dis, dim=1)
126 | row = self.inter_row.unsqueeze(0) + d
127 | col = self.inter_col.unsqueeze(0) + d
128 | edge_dir = self.inter_dir.unsqueeze(0) + torch.zeros_like(d)
129 | return row.view(-1), col.view(-1), edge_dir.view(-1)
130 |
131 | def relative_dir(self, vi, vj, depth, rescale=True):
132 | xi = self.xyz[vi]
133 | xj = self.xyz[vj]
134 |
135 | # get 6 neighborhoods of xi via `self.ngh`
136 | xn = xi.unsqueeze(1) + self.ngh
137 |
138 | # rescale the coord of xj
139 | scale = torch.ones_like(vj)
140 | if rescale:
141 | dj = self.node_depth[vj]
142 | scale = torch.pow(2.0, depth - dj)
143 | # torch._assert((scale > 1.0).all(), 'vj is larger than vi')
144 | xj = xj * scale.unsqueeze(-1)
145 |
146 | # inbox testing
147 | xj = xj.unsqueeze(1)
148 | scale = scale.view(-1, 1, 1)
149 | inbox = torch.logical_and(xn >= xj, xn < xj + scale.view(-1, 1, 1))
150 | inbox = torch.all(inbox, dim=-1)
151 | rel_dir = torch.argmax(inbox.byte(), dim=-1)
152 | return rel_dir
153 |
154 | def _node_property(self, prop, depth):
155 | return prop[self.ncum[depth]: self.ncum[depth] + self.nnum[depth]]
156 |
157 | def node_child(self, depth):
158 | return self._node_property(self.child, depth)
159 |
160 | def sparse_graph(self, depth, graph):
161 | # Add internal edges connecting sliding nodes.
162 | ncum_d = self.ncum[depth] # NOTE: add ncum_d, i.e., self.ncum[depth]
163 | row_i, col_i, dir_i = self._internal_edges(self.nnum[depth], ncum_d)
164 |
165 | # mark invalid nodes of layer (depth-1)
166 | edge_idx, edge_dir = graph['edge_idx'], graph['edge_dir']
167 | row, col = edge_idx[0], edge_idx[1]
168 | valid_row = self.child[row] < 0
169 | valid_col = self.child[col] < 0
170 | invalid_row = torch.logical_not(valid_row)
171 | invalid_col = torch.logical_not(valid_col)
172 | valid_edges = torch.logical_and(valid_row, valid_col)
173 | invalid_row_vtx = torch.logical_and(invalid_row, valid_col)
174 | invalid_both_vtx = torch.logical_and(invalid_row, invalid_col)
175 |
176 | # deal with edges with invalid row vtx only
177 | vi, vj = row[invalid_row_vtx], col[invalid_row_vtx]
178 | rel_dir = self.relative_dir(vi, vj, depth - 1)
179 | row_o1 = self.child[vi].unsqueeze(1) * 8 + self.dir_table[rel_dir, :]
180 | row_o1 = row_o1.view(-1) + ncum_d # NOTE: add ncum_d
181 | col_o1 = vj.unsqueeze(1).repeat(1, 4).view(-1)
182 | dir_o1 = rel_dir.unsqueeze(1).repeat(1, 4).view(-1)
183 |
184 | # deal with edges with 2 invalid nodes
185 | row_o2 = torch.tensor([], dtype=torch.int64, device=self.device)
186 | col_o2 = torch.tensor([], dtype=torch.int64, device=self.device)
187 | dir_o2 = torch.tensor([], dtype=torch.int64, device=self.device)
188 | if invalid_both_vtx.any():
189 | vi, vj = row[invalid_both_vtx], col[invalid_both_vtx]
190 | rel_dir = self.relative_dir(vi, vj, depth - 1, rescale=False)
191 | row_o2 = self.child[vi].unsqueeze(1) * 8 + self.dir_table[rel_dir, :]
192 | row_o2 = row_o2.view(-1) + ncum_d # NOTE: add ncum_d
193 | dir_o2 = rel_dir.unsqueeze(1).repeat(1, 4).view(-1)
194 | rel_dir_col = self.remap[rel_dir]
195 | col_o2 = self.child[vj].unsqueeze(1) * 8 + self.dir_table[rel_dir_col, :]
196 | col_o2 = col_o2.view(-1) + ncum_d # NOTE: add ncum_d
197 |
198 | # gather the results
199 | edge_idx = torch.stack([
200 | torch.cat([row[valid_edges], row_i, row_o1, col_o1, row_o2]),
201 | torch.cat([col[valid_edges], col_i, col_o1, row_o1, col_o2])])
202 | edge_dir = torch.cat([
203 | edge_dir[valid_edges], dir_i, dir_o1, self.remap[dir_o1], dir_o2])
204 | return {'edge_idx': edge_idx, 'edge_dir': edge_dir}
205 |
206 | def add_self_loops(self):
207 | for d in range(self.full_depth, self.depth + 1):
208 | edge_idx = self._graph[d]['edge_idx']
209 | edge_dir = self._graph[d]['edge_dir']
210 | row, col = edge_idx[0], edge_idx[1]
211 | unique_idx = torch.unique(row, sorted=True)
212 | dir_idx = torch.ones_like(unique_idx) * 6
213 | self.graph[d] = {'edge_idx': torch.stack([torch.cat([row, unique_idx]),
214 | torch.cat([col, unique_idx])]),
215 | 'edge_dir': torch.cat([edge_dir, dir_idx])}
216 |
217 | def calc_edge_type(self):
218 | dir_num = 7
219 | for d in range(self.full_depth, self.depth + 1):
220 | depth_num = d - self.full_depth + 1
221 | edge_idx = self._graph[d]['edge_idx']
222 | edge_dir = self._graph[d]['edge_dir']
223 | row, col = edge_idx[0], edge_idx[1]
224 |
225 | dr = self.node_depth[row] - self.full_depth
226 | dc = self.node_depth[col] - self.full_depth
227 | edge_type = (dr * depth_num + dc) * dir_num + edge_dir
228 |
229 | self.graph[d]['edge_type'] = edge_type
230 |
231 | def remap_node_idx(self):
232 | leaf_nodes = self.child < 0
233 | for d in range(self.full_depth, self.depth + 1):
234 | leaf_d = torch.ones(self.nnum[d], dtype=torch.bool, device=self.device)
235 | mask = torch.cat([leaf_nodes[:self.ncum[d]], leaf_d], dim=0)
236 | remap = torch.cumsum(mask.long(), dim=0) - 1
237 | self.graph[d]['edge_idx'] = remap[self.graph[d]['edge_idx']]
238 |
239 | def filter_multiple_level_edges(self):
240 | for d in range(self.full_depth, self.depth + 1):
241 | edge_idx = self.graph[d]['edge_idx']
242 | edge_dir = self.graph[d]['edge_dir']
243 | valid_edges = (self.node_depth[edge_idx] == d).all(dim=0)
244 |
245 | # filter edges
246 | edge_idx = edge_idx[:, valid_edges]
247 | edge_dir = edge_dir[valid_edges]
248 |
249 | self.graph[d] = {'edge_idx': edge_idx, 'edge_dir': edge_dir}
250 |
251 | def filter_coarse_to_fine_edges(self):
252 | for d in range(self.full_depth, self.depth + 1):
253 | edge_idx = self.graph[d]['edge_idx']
254 | edge_dir = self.graph[d]['edge_dir']
255 |
256 | edge_node_depth = self.node_depth[edge_idx]
257 | # the depth of sender nodes should be larger than receivers
258 | valid_edges = edge_node_depth[0] >= edge_node_depth[1]
259 |
260 | # filter edges
261 | edge_idx = edge_idx[:, valid_edges]
262 | edge_dir = edge_dir[valid_edges]
263 |
264 | self.graph[d] = {'edge_idx': edge_idx, 'edge_dir': edge_dir}
265 |
266 | def filter_crosslevel_edges(self):
267 | for d in range(self.full_depth, self.depth + 1):
268 | edge_idx = self.graph[d]['edge_idx']
269 | edge_dir = self.graph[d]['edge_dir']
270 |
271 | edge_node_depth = self.node_depth[edge_idx]
272 | valid_edges = edge_node_depth[0] == edge_node_depth[1]
273 |
274 | # filter edges
275 | edge_idx = edge_idx[:, valid_edges]
276 | edge_dir = edge_dir[valid_edges]
277 |
278 | self.graph[d] = {'edge_idx': edge_idx, 'edge_dir': edge_dir}
279 |
280 | def displace_edge_and_add_node_type(self):
281 | for d in range(self.full_depth, self.depth + 1):
282 | # displace edge index
283 | self.graph[d]['edge_idx'] -= self.ncum[d]
284 |
285 | # only one type of node
286 | zeros = torch.zeros(self.nnum[d], dtype=torch.long, device=self.device)
287 | self.graph[d]['node_type'] = zeros
288 |
289 | # used in skip connections
290 | self.graph[d]['keyd'] = self._node_property(self.keyd, d)
291 |
292 | def post_processing_for_ocnn(self):
293 | self.add_self_loops()
294 | self.filter_multiple_level_edges()
295 | self.displace_edge_and_add_node_type()
296 | self.sort_edges()
297 |
298 | def sort_edges(self):
299 | dir_num = 7
300 | for d in range(self.full_depth, self.depth + 1):
301 | edge_idx = self.graph[d]['edge_idx']
302 | edge_dir = self.graph[d]['edge_dir']
303 |
304 | edge_key = edge_idx[0] * dir_num + edge_dir
305 | sidx = torch.argsort(edge_key)
306 | self.graph[d]['edge_idx'] = edge_idx[:, sidx]
307 | self.graph[d]['edge_dir'] = edge_dir[sidx]
308 |
309 | def get_input_feature(self, all_leaf_nodes=True):
310 | # the initial feature of leaf nodes in the layer self.depth
311 | data = ocnn.octree_property(self.octree, 'feature', self.depth)
312 | data = data.squeeze(0).squeeze(-1).t()
313 |
314 | # the initial feature of leaf nodes in other layers
315 | if all_leaf_nodes:
316 | channel = data.shape[1]
317 | leaf_num = self.lnum[self.full_depth:self.depth].sum()
318 | zeros = torch.zeros(leaf_num, channel, device=self.device)
319 |
320 | # concat zero features with the initial features in layer depth
321 | data = torch.cat([zeros, data], dim=0)
322 |
323 | return data
324 |
325 | def add_node_keyd(self):
326 | keyd1, keyd2 = [], []
327 | for d in range(self.full_depth, self.depth + 1):
328 | keyd = self._node_property(self.keyd, d)
329 | leaf_mask = self._node_property(self.child, d) < 0
330 | keyd1.append(keyd[leaf_mask])
331 | keyd2.append(keyd)
332 | self.graph[d]['keyd'] = torch.cat(keyd1[:-1] + keyd2[-1:], dim=0)
333 |
334 | def add_node_xyzd(self):
335 | xyz1, xyz2 = [], []
336 | for d in range(self.full_depth, self.depth + 1):
337 | xyzd = self._node_property(self.xyz, d)
338 | xyzf = xyzd.float() / 2 ** d # normalize to [0, 1]
339 | leaf_mask = self._node_property(self.child, d) < 0
340 | xyz1.append(xyzf[leaf_mask])
341 | xyz2.append(xyzf)
342 | self.graph[d]['xyz'] = torch.cat(xyz1[:-1] + xyz2[-1:], dim=0)
343 |
def add_node_type(self):
    """Label every graph node with the layer it comes from.

    A node of depth d gets the label `d - full_depth`.  For each depth d,
    `self.graph[d]['node_type']` (int64) holds the labels of `lnum[k]` nodes
    for every coarser layer k < d, followed by the labels of `nnum[d]` nodes
    of layer d itself.
    """
    first, last = self.full_depth, self.depth
    leaf_labels = []  # one label tensor per visited layer, sized by lnum
    for d in range(first, last + 1):
        label = d - first
        layer_labels = torch.ones(self.nnum[d], device=self.device) * label
        pieces = leaf_labels + [layer_labels]
        self.graph[d]['node_type'] = torch.cat(pieces, dim=0).long()
        leaf_labels.append(torch.ones(self.lnum[d], device=self.device) * label)
353 |
def add_node_mask(self):
    """Store a boolean node mask for every graph layer.

    For each depth d, `self.graph[d]['node_mask']` is the concatenation of the
    leaf flags (child < 0) of the layers full_depth .. d-1, followed by
    `nnum[d]` True values for layer d.
    """
    coarser_leaf_flags = []  # leaf flags of the layers visited so far
    for d in range(self.full_depth, self.depth + 1):
        is_leaf = self._node_property(self.child, d) < 0
        # every node of the current layer is kept
        all_true = torch.ones(self.nnum[d], dtype=torch.bool, device=self.device)
        self.graph[d]['node_mask'] = torch.cat(
            coarser_leaf_flags + [all_true], dim=0)
        coarser_leaf_flags.append(is_leaf)
362 |
def post_processing_for_docnn(self):
    """Run the graph post-processing steps in their required order.

    The steps are: add self loops, remap node indices, add node types, add
    node keys, add node masks and finally sort the edges.
    """
    # NOTE: `filter_coarse_to_fine_edges` and `filter_crosslevel_edges` are
    # only used in the ablation study; when needed they go right after
    # `add_self_loops`.
    steps = (
        self.add_self_loops,
        self.remap_node_idx,
        self.add_node_type,
        self.add_node_keyd,  # used in skip connects
        self.add_node_mask,
        self.sort_edges,
    )
    for step in steps:
        step()
373 |
def save(self, filename):
    """Dump the node data and the per-layer edges as `.npy` files.

    Args:
        filename: Common path prefix; the suffixes 'xyz.npy', 'batch.npy',
            'node_depth.npy' and 'edge_<d>.npy' (one per depth d in
            [full_depth, depth]) are appended to it.
    """
    def as_numpy(tensor):
        # move to host memory before converting to a numpy array
        return tensor.cpu().numpy()

    np.save(filename + 'xyz.npy', as_numpy(self.xyz))
    np.save(filename + 'batch.npy', as_numpy(self.batch))
    np.save(filename + 'node_depth.npy', as_numpy(self.node_depth))

    # NOTE(review): edges are read from `self._graph`, not `self.graph` --
    # confirm this is the intended container.
    for d in range(self.full_depth, self.depth + 1):
        edges = self._graph[d]['edge_idx']
        # stored transposed, i.e. one edge per row
        np.save(filename + "edge_%d.npy" % d, as_numpy(edges.t()))
381 |
382 |
if __name__ == '__main__':
    # Smoke test: batch two sample octrees, build the dual octree on the GPU,
    # dump it to disk and run a few of the graph-processing steps.
    octree = ocnn.octree_batch(
        ocnn.octree_samples(['octree_1', 'octree_2'])).cuda()
    pdoctree = DualOctree(octree)
    pdoctree.save('data/batch_12_')
    pdoctree.add_self_loops()
    pdoctree.calc_edge_type()
    pdoctree.remap_node_idx()
    print('succ!')
392 |
--------------------------------------------------------------------------------
/solver/solver.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dual Octree Graph Networks
3 | # Copyright (c) 2022 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Peng-Shuai Wang
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import time
10 | import torch
11 | import torch.nn
12 | import torch.optim
13 | import torch.distributed
14 | import torch.multiprocessing
15 | import torch.utils.data
16 | import warnings
17 | from datetime import datetime
18 | from tqdm import tqdm
19 | from torch.utils.tensorboard import SummaryWriter
20 |
21 | from .sampler import InfSampler, DistributedInfSampler
22 | from .config import parse_args
23 |
24 | warnings.filterwarnings("ignore", module="torch.optim.lr_scheduler")
25 |
26 |
class AverageTracker:
    """Accumulate named tensors over iterations and report their averages.

    `update` adds a dict of tensors to a running sum, `average` divides the
    sums by the number of updates, and `log` formats the averages as a message
    (optionally forwarding them to a summary writer and a log file).
    """

    def __init__(self):
        self.value = None  # dict name -> running sum; None until first update
        self.num = 0.0     # number of non-empty updates so far
        self.max_len = 76  # maximum characters per printed line in `log`
        self.start_time = time.time()

    def update(self, value):
        """Add the tensors in `value` (dict name -> tensor) to the running sums.

        An empty or None `value` is ignored.  The first non-empty update fixes
        the set of tracked keys; a key missing from it raises `KeyError` in a
        later update.
        """
        if not value:
            return  # empty input, return

        if self.value is None:
            # `detach()` shares storage with the caller's tensor, so clone it;
            # otherwise the in-place accumulation below would silently modify
            # the tensors returned by the first iteration.
            self.value = {key: val.detach().clone()
                          for key, val in value.items()}
        else:
            for key, val in value.items():
                self.value[key] += val.detach()
        self.num += 1

    def average(self):
        """Return a dict name -> python number: running sum / number of updates.

        Returns an empty dict when nothing has been tracked yet.
        """
        if not self.value:
            return {}
        return {key: val.item() / self.num for key, val in self.value.items()}

    @torch.no_grad()
    def average_all_gather(self):
        """Replace each running sum by its mean over all distributed processes.

        Requires an initialized `torch.distributed` process group.  Does
        nothing when no value has been tracked yet.
        """
        if not self.value:
            return

        world_size = torch.distributed.get_world_size()
        for key, tensor in self.value.items():
            gathered = [torch.ones_like(tensor) for _ in range(world_size)]
            torch.distributed.all_gather(gathered, tensor, async_op=False)
            self.value[key] = torch.mean(torch.stack(gathered, dim=0))

    def log(self, epoch, summry_writer=None, log_file=None, msg_tag='->',
            notes='', print_time=True):
        """Print the current averages; optionally record them as well.

        Args:
            epoch: Epoch number shown in the message and used as the step of
                the summary writer.
            summry_writer: Optional object with an `add_scalar(key, val, step)`
                method; receives every averaged value.
            log_file: Optional path; the 'Epoch: ..., key: val' part of the
                message is appended to this file.
            msg_tag: Prefix of the printed message.
            notes: Optional extra text appended to the printed message.
            print_time: Whether to append the current time and the time
                elapsed since this tracker was created.
        """
        if not self.value:
            return  # empty, return

        msg = 'Epoch: %d' % epoch
        for key, val in self.average().items():
            msg += ', %s: %.3f' % (key, val)
            if summry_writer:
                summry_writer.add_scalar(key, val, epoch)

        # if the log_file is provided, save the log
        if log_file:
            with open(log_file, 'a') as fid:
                fid.write(msg + '\n')

        # memory
        memory = ''
        if torch.cuda.is_available():
            size = torch.cuda.memory_reserved(device=None)
            memory = ', memory: {:.3f}GB'.format(size / 2**30)

        # time
        time_str = ''
        if print_time:
            curr_time = ', time: ' + datetime.now().strftime("%Y/%m/%d %H:%M:%S")
            duration = ', duration: {:.2f}s'.format(time.time() - self.start_time)
            time_str = curr_time + duration

        # other notes
        if notes:
            notes = ', ' + notes

        msg += memory + time_str + notes

        # split the msg into lines of at most `max_len` characters; the
        # continuation lines are indented to align with the text after the tag
        chunks = [msg[i:i + self.max_len]
                  for i in range(0, len(msg), self.max_len)]
        msg = (msg_tag + ' ') + ('\n' + len(msg_tag) * ' ' + ' ').join(chunks)
        tqdm.write(msg)
100 |
101 |
class Solver:
    """Base class driving the training, testing and evaluation of a model.

    Subclasses supply the task-specific parts by overriding `get_model`,
    `get_dataset`, `train_step`, `test_step` and `eval_step`.  The entry point
    is `Solver.main()`, which parses the configuration and starts one worker
    per GPU listed in `FLAGS.SOLVER.gpu`.
    """

    def __init__(self, FLAGS, is_master=True):
        """Store the configuration; the real setup happens in `config_*`.

        Args:
            FLAGS: Configuration object with `SOLVER`, `DATA` and `MODEL`
                sections.
            is_master: Whether this process does logging and checkpointing.
        """
        self.FLAGS = FLAGS
        self.is_master = is_master
        self.world_size = len(FLAGS.SOLVER.gpu)
        self.device = torch.cuda.current_device()
        self.disable_tqdm = not (is_master and FLAGS.SOLVER.progress_bar)
        self.start_epoch = 1

        self.model = None           # torch.nn.Module
        self.optimizer = None       # torch.optim.Optimizer
        self.scheduler = None       # torch.optim.lr_scheduler._LRScheduler
        self.summary_writer = None  # torch.utils.tensorboard.SummaryWriter
        self.log_file = None        # str, used to save training logs

    def get_model(self, flags=None):
        """Build and return the model; `config_model` passes `FLAGS.MODEL`."""
        raise NotImplementedError

    def get_dataset(self, flags):
        """Return a `(dataset, collate_fn)` pair for the given data flags."""
        raise NotImplementedError

    def train_step(self, batch):
        """Run one training forward pass.

        Must return a dict of tensors containing the key 'train/loss', which
        `train_epoch` back-propagates.
        """
        raise NotImplementedError

    def test_step(self, batch):
        """Run one testing forward pass and return a dict of tensors."""
        raise NotImplementedError

    def eval_step(self, batch):
        """Process one batch during evaluation; the return value is unused."""
        raise NotImplementedError

    def result_callback(self, avg_tracker: 'AverageTracker', epoch):
        """Hook called on the master process after each testing epoch."""
        pass  # additional operations based on the avg_tracker

    def config_dataloader(self, disable_train_data=False):
        """Create the train/test data loaders and their iterators.

        A loader is skipped when its `disable` flag is set; the training
        loader is also skipped when `disable_train_data` is True.
        """
        flags_train, flags_test = self.FLAGS.DATA.train, self.FLAGS.DATA.test

        if not disable_train_data and not flags_train.disable:
            self.train_loader = self.get_dataloader(flags_train)
            self.train_iter = iter(self.train_loader)

        if not flags_test.disable:
            self.test_loader = self.get_dataloader(flags_test)
            self.test_iter = iter(self.test_loader)

    def get_dataloader(self, flags):
        """Build a `DataLoader` over the dataset described by `flags`.

        Uses `DistributedInfSampler` when running on several GPUs and
        `InfSampler` otherwise.
        """
        dataset, collate_fn = self.get_dataset(flags)

        if self.world_size > 1:
            sampler = DistributedInfSampler(dataset, shuffle=flags.shuffle)
        else:
            sampler = InfSampler(dataset, shuffle=flags.shuffle)

        return torch.utils.data.DataLoader(
            dataset, batch_size=flags.batch_size, num_workers=flags.num_workers,
            sampler=sampler, collate_fn=collate_fn, pin_memory=True)

    def config_model(self):
        """Create the model, move it to the GPU and wrap it for multi-GPU use."""
        flags = self.FLAGS.MODEL
        model = self.get_model(flags)
        model.cuda(device=self.device)
        if self.world_size > 1:
            if flags.sync_bn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                module=model, device_ids=[self.device],
                output_device=self.device, broadcast_buffers=False,
                find_unused_parameters=flags.find_unused_parameters)
        if self.is_master:
            print(model)
        self.model = model

    def configure_optimizer(self):
        """Create `self.optimizer` and `self.scheduler` from `FLAGS.SOLVER`.

        Raises:
            ValueError: If `type` is not one of sgd/adam/adamw (case
                insensitive) or `lr_type` is not one of
                step/cos/poly/constant.
        """
        flags = self.FLAGS.SOLVER
        # The learning rate scales with regard to the world_size
        lr = flags.lr * self.world_size
        params = self.model.parameters()
        optim_type = flags.type.lower()
        if optim_type == 'sgd':
            self.optimizer = torch.optim.SGD(
                params, lr=lr, weight_decay=flags.weight_decay, momentum=0.9)
        elif optim_type == 'adam':
            self.optimizer = torch.optim.Adam(
                params, lr=lr, weight_decay=flags.weight_decay)
        elif optim_type == 'adamw':
            self.optimizer = torch.optim.AdamW(
                params, lr=lr, weight_decay=flags.weight_decay)
        else:
            raise ValueError('Unsupported optimizer type: %s' % flags.type)

        if flags.lr_type == 'step':
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer, milestones=flags.step_size, gamma=0.1)
        elif flags.lr_type == 'cos':
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, flags.max_epoch, eta_min=0)
        elif flags.lr_type == 'poly':
            def poly(epoch):
                return (1 - epoch / flags.max_epoch) ** flags.lr_power
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer, poly)
        elif flags.lr_type == 'constant':
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer, lambda epoch: 1)
        else:
            raise ValueError('Unsupported lr_type: %s' % flags.lr_type)

    def configure_log(self, set_writer=True):
        """Set the log/checkpoint paths; optionally create the summary writer.

        The summary writer and the checkpoint directory are only created on
        the master process and only when `set_writer` is True.
        """
        self.logdir = self.FLAGS.SOLVER.logdir
        self.ckpt_dir = os.path.join(self.logdir, 'checkpoints')
        self.log_file = os.path.join(self.logdir, 'log.csv')

        if self.is_master:
            tqdm.write('Logdir: ' + self.logdir)

        if self.is_master and set_writer:
            self.summary_writer = SummaryWriter(self.logdir, flush_secs=20)
            os.makedirs(self.ckpt_dir, exist_ok=True)

    def train_epoch(self, epoch):
        """Train for `len(self.train_loader)` iterations and log the averages."""
        self.model.train()
        if self.world_size > 1:
            self.train_loader.sampler.set_epoch(epoch)

        train_tracker = AverageTracker()
        rng = range(len(self.train_loader))
        log_per_iter = self.FLAGS.SOLVER.log_per_iter
        for it in tqdm(rng, ncols=80, leave=False, disable=self.disable_tqdm):
            self.optimizer.zero_grad()

            # forward; use the built-in next(): iterators are not required to
            # provide a `.next()` method
            batch = next(self.train_iter)
            batch['iter_num'] = it
            batch['epoch'] = epoch
            output = self.train_step(batch)

            # backward
            output['train/loss'].backward()
            self.optimizer.step()

            # track the averaged tensors
            train_tracker.update(output)

            # output intermediate logs
            if self.is_master and log_per_iter > 0 and it % log_per_iter == 0:
                notes = 'iter: %d' % it
                train_tracker.log(epoch, msg_tag='- ', notes=notes,
                                  print_time=False)

        # save logs
        if self.world_size > 1:
            train_tracker.average_all_gather()
        if self.is_master:
            train_tracker.log(epoch, self.summary_writer)

    def test_epoch(self, epoch):
        """Run `test_step` over `len(self.test_loader)` batches and log results.

        Gradients are intentionally left enabled here; `test_step` is free to
        disable them itself.
        """
        self.model.eval()
        test_tracker = AverageTracker()
        rng = range(len(self.test_loader))
        for it in tqdm(rng, ncols=80, leave=False, disable=self.disable_tqdm):
            # forward
            batch = next(self.test_iter)
            batch['iter_num'] = it
            batch['epoch'] = epoch
            output = self.test_step(batch)

            # track the averaged tensors
            test_tracker.update(output)

        if self.world_size > 1:
            test_tracker.average_all_gather()
        if self.is_master:
            test_tracker.log(epoch, self.summary_writer, self.log_file,
                             msg_tag='=>')
            self.result_callback(test_tracker, epoch)

    def eval_epoch(self, epoch):
        """Call `eval_step` (without gradients) on a number of test batches.

        The number of batches is `FLAGS.SOLVER.eval_step`, clipped to the
        length of the test loader; a value below 1 means the whole loader.
        """
        self.model.eval()
        eval_step = min(self.FLAGS.SOLVER.eval_step, len(self.test_loader))
        if eval_step < 1:
            eval_step = len(self.test_loader)
        for it in tqdm(range(eval_step), ncols=80, leave=False):
            batch = next(self.test_iter)
            batch['iter_num'] = it
            batch['epoch'] = epoch
            with torch.no_grad():
                self.eval_step(batch)

    def save_checkpoint(self, epoch):
        """Save the model and solver state for `epoch` (master process only).

        Before saving, the oldest '.pth'/'.tar' files in the checkpoint folder
        are deleted so that at most `FLAGS.SOLVER.ckpt_num` of them remain.
        Two files are written: '<epoch>.model.pth' with the model weights, and
        '<epoch>.solver.tar' with the weights plus epoch, optimizer and
        scheduler state.
        """
        if not self.is_master:
            return

        # clean up
        ckpt_num = self.FLAGS.SOLVER.ckpt_num
        ckpts = [ck for ck in sorted(os.listdir(self.ckpt_dir))
                 if ck.endswith('.pth') or ck.endswith('.tar')]
        if len(ckpts) > ckpt_num:
            for ckpt in ckpts[:-ckpt_num]:
                os.remove(os.path.join(self.ckpt_dir, ckpt))

        # save ckpt; unwrap the DistributedDataParallel module if present
        model = self.model.module if self.world_size > 1 else self.model
        model_dict = model.state_dict()
        ckpt_name = os.path.join(self.ckpt_dir, '%05d' % epoch)
        torch.save(model_dict, ckpt_name + '.model.pth')
        torch.save({'model_dict': model_dict, 'epoch': epoch,
                    'optimizer_dict': self.optimizer.state_dict(),
                    'scheduler_dict': self.scheduler.state_dict()},
                   ckpt_name + '.solver.tar')

    def load_checkpoint(self):
        """Restore the model (and, if available, the solver state).

        Loads `FLAGS.SOLVER.ckpt` when given; otherwise the last
        '*solver.tar' file (in sorted order) of the checkpoint folder.
        Returns silently when no checkpoint is found.  For a '.solver.tar'
        checkpoint `start_epoch` is set to the saved epoch plus one, and the
        optimizer/scheduler states are restored if those objects exist.
        """
        ckpt = self.FLAGS.SOLVER.ckpt
        if not ckpt:
            # If ckpt is empty, then get the latest checkpoint from ckpt_dir
            if not os.path.exists(self.ckpt_dir):
                return
            ckpts = [ck for ck in sorted(os.listdir(self.ckpt_dir))
                     if ck.endswith('solver.tar')]
            if ckpts:
                ckpt = os.path.join(self.ckpt_dir, ckpts[-1])
        if not ckpt:
            return  # return if ckpt is still empty

        # load trained model
        # check: map_location = {'cuda:0' : 'cuda:%d' % self.rank}
        trained_dict = torch.load(ckpt, map_location='cuda')
        if ckpt.endswith('.solver.tar'):
            model_dict = trained_dict['model_dict']
            self.start_epoch = trained_dict['epoch'] + 1  # !!! add 1
            if self.optimizer:
                self.optimizer.load_state_dict(trained_dict['optimizer_dict'])
            if self.scheduler:
                self.scheduler.load_state_dict(trained_dict['scheduler_dict'])
        else:
            model_dict = trained_dict
        model = self.model.module if self.world_size > 1 else self.model
        model.load_state_dict(model_dict)

        # print messages
        if self.is_master:
            tqdm.write('Load the checkpoint: %s' % ckpt)
            tqdm.write('The start_epoch is %d' % self.start_epoch)

    def train(self):
        """Full training loop: train, step the scheduler, test and checkpoint.

        Testing and checkpointing happen every
        `FLAGS.SOLVER.test_every_epoch` epochs.
        """
        self.config_model()
        self.config_dataloader()
        self.configure_optimizer()
        self.configure_log()
        self.load_checkpoint()

        rng = range(self.start_epoch, self.FLAGS.SOLVER.max_epoch + 1)
        for epoch in tqdm(rng, ncols=80, disable=self.disable_tqdm):
            # training epoch
            self.train_epoch(epoch)

            # update learning rate
            self.scheduler.step()
            if self.is_master:
                lr = self.scheduler.get_last_lr()  # lr is a list
                self.summary_writer.add_scalar('train/lr', lr[0], epoch)

            # testing or not
            if epoch % self.FLAGS.SOLVER.test_every_epoch != 0:
                continue

            # testing epoch
            self.test_epoch(epoch)

            # checkpoint
            self.save_checkpoint(epoch)

        # sync and exit
        if self.world_size > 1:
            torch.distributed.barrier()

    def test(self):
        """Load a checkpoint and run a single testing epoch (epoch 0)."""
        self.config_model()
        self.configure_log(set_writer=False)
        self.config_dataloader(disable_train_data=True)
        self.load_checkpoint()
        self.test_epoch(epoch=0)

    def evaluate(self):
        """Load a checkpoint and run `FLAGS.SOLVER.eval_epoch` eval epochs."""
        self.config_model()
        self.configure_log(set_writer=False)
        self.config_dataloader(disable_train_data=True)
        self.load_checkpoint()
        for epoch in tqdm(range(self.FLAGS.SOLVER.eval_epoch), ncols=80):
            self.eval_epoch(epoch)

    def profile(self):
        ''' Set `DATA.train.num_workers 0` when using this function'''
        self.config_model()
        self.config_dataloader()

        # warm up
        batch = next(iter(self.train_loader))
        for _ in range(3):
            output = self.train_step(batch)
            output['train/loss'].backward()

        # profile
        with torch.autograd.profiler.profile(
                use_cuda=True, profile_memory=True,
                with_stack=True, record_shapes=True) as prof:
            output = self.train_step(batch)
            output['train/loss'].backward()

        trace_file = os.path.join(self.FLAGS.SOLVER.logdir, 'trace.json')
        print('Save the profile into: ' + trace_file)
        prof.export_chrome_trace(trace_file)
        print(prof.key_averages(group_by_stack_n=10)
                  .table(sort_by="cuda_time_total", row_limit=10))
        print(prof.key_averages(group_by_stack_n=10)
                  .table(sort_by="cuda_memory_usage", row_limit=10))

    def run(self):
        """Call the method named by `FLAGS.SOLVER.run` (e.g. 'train', 'test')."""
        # Look the method up by name rather than eval()-ing a string built
        # from the configuration, which would execute arbitrary expressions.
        getattr(self, self.FLAGS.SOLVER.run)()

    @classmethod
    def update_configs(cls):
        """Hook called by `main` before the arguments are parsed."""
        pass

    @classmethod
    def worker(cls, gpu, FLAGS):
        """Create a solver in the current process and run it.

        Args:
            gpu: Index of this worker; in the multi-GPU case it is used both
                as the CUDA device and as the distributed rank.
            FLAGS: Parsed configuration.
        """
        world_size = len(FLAGS.SOLVER.gpu)
        if world_size > 1:
            # Set the GPU to use.
            torch.cuda.set_device(gpu)
            # Initialize the process group. Currently, the code only supports
            # the `single node + multiple GPU` mode, so the rank is equal to
            # gpu id.
            torch.distributed.init_process_group(
                backend='nccl', init_method=FLAGS.SOLVER.dist_url,
                world_size=world_size, rank=gpu)
            # Master process is responsible for logging, writing and loading
            # checkpoints. In the multi GPU setting, we assign the master role
            # to the rank 0 process.
            is_master = gpu == 0
            the_solver = cls(FLAGS, is_master)
        else:
            the_solver = cls(FLAGS, is_master=True)
        the_solver.run()

    @classmethod
    def main(cls):
        """Parse the configuration and launch one worker per configured GPU."""
        cls.update_configs()
        FLAGS = parse_args()

        num_gpus = len(FLAGS.SOLVER.gpu)
        if num_gpus > 1:
            torch.multiprocessing.spawn(
                cls.worker, nprocs=num_gpus, args=(FLAGS,))
        else:
            cls.worker(0, FLAGS)
453 |
--------------------------------------------------------------------------------