├── teaser.png
├── ckpts
│   ├── denoise.pth
│   ├── model_u.pth
│   └── upsmodel.pth
├── utils
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── toy.py
│   │   ├── patch.py
│   │   └── pcl.py
│   ├── mesh_to_pcl.py
│   ├── mitsuba.py
│   ├── render.py
│   ├── misc.py
│   ├── denoise.py
│   ├── transforms.py
│   └── evaluate.py
├── scripts
│   ├── save_denoised.py
│   ├── evaluate_upsample.py
│   ├── save_upsample.py
│   ├── evaluate.py
│   ├── train_denoise.py
│   └── train_upsample.py
├── models
│   ├── encoders
│   │   ├── __init__.py
│   │   └── knn.py
│   ├── vecfields
│   │   ├── __init__.py
│   │   ├── knn.py
│   │   └── radius.py
│   ├── resampler.py
│   └── common.py
├── configs
│   ├── ups.yml
│   └── denoise.yml
├── LICENSE
├── README.md
└── environment.yml
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/teaser.png
--------------------------------------------------------------------------------
/ckpts/denoise.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/ckpts/denoise.pth
--------------------------------------------------------------------------------
/ckpts/model_u.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/ckpts/model_u.pth
--------------------------------------------------------------------------------
/ckpts/upsmodel.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/ckpts/upsmodel.pth
--------------------------------------------------------------------------------
/utils/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .pcl import PointCloudDataset, UpsampleDataset
2 | from .patch import PairedPatchDataset, PairedUpsDataset
3 | from .toy import ToyPointCloudDataset
4 |
--------------------------------------------------------------------------------
/models/vecfields/__init__.py:
--------------------------------------------------------------------------------
1 | from .knn import KNearestVectorField
2 | from .radius import RadiusVectorField
3 | 
4 | 
5 | def get_vecfield(cfg, ctx_point_feature_dim):
6 |     if cfg.type == 'knn':
7 |         return KNearestVectorField(
8 |             knn=cfg.knn,
9 |             radius=cfg.radius,
10 |             ctx_point_feature_dim=ctx_point_feature_dim,
11 |         )
12 |     elif cfg.type == 'radius':
13 |         return RadiusVectorField(
14 |             radius=cfg.radius,
15 |             ctx_point_feature_dim=ctx_point_feature_dim,
16 |             num_points=cfg.num_points,
17 |             style=cfg.style,
18 |         )
19 |     else:
20 |         raise NotImplementedError('Vecfield `%s` is not implemented.' % cfg.type)
21 | 
--------------------------------------------------------------------------------
/models/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from .knn import KNearestGCNNPointEncoder
2 | from .mccnn import MCCNNEncoder
3 | 
4 | 
5 | def get_encoder(cfg):
6 |     if cfg.type == 'knn':
7 |         return KNearestGCNNPointEncoder(
8 |             dynamic_graph=cfg.dynamic_graph,
9 |             conv_channels=cfg.conv_channels,
10 |             num_convs=cfg.num_convs,
11 |             conv_num_fc_layers=cfg.conv_num_fc_layers,
12 |             conv_growth_rate=cfg.conv_growth_rate,
13 |             conv_knn=cfg.conv_knn,
14 |         )
15 |     elif cfg.type == 'mccnn':
16 |         return MCCNNEncoder(
17 |             radius=cfg.radius,
18 |             num_points=cfg.num_points,
19 |             kde_bandwidth=cfg.kde_bandwidth,
20 |             point_dim=cfg.point_dim,
21 |             first_block_dim=cfg.first_block_dim,
22 |             # hidden_dims_pointwise=cfg.hidden_dims_pointwise,
23 |         )
24 |     else:
25 |         raise NotImplementedError('Encoder `%s` is not implemented.' % cfg.type)
26 | 
--------------------------------------------------------------------------------
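Note: the two factories above consume the `model.encoder` and `model.vecfield` blocks of the YAML configs that follow. A minimal usage sketch (not part of the repository; the `ctx_point_feature_dim` value is an illustrative assumption and must match the encoder's output feature width):

```python
from easydict import EasyDict
from models.encoders import get_encoder
from models.vecfields import get_vecfield

# Config fragments mirroring configs/denoise.yml
enc_cfg = EasyDict({'type': 'knn', 'dynamic_graph': False, 'num_convs': 7,
                    'conv_channels': 48, 'conv_num_fc_layers': 5,
                    'conv_growth_rate': 24, 'conv_knn': 30})
vf_cfg = EasyDict({'type': 'knn', 'knn': 24, 'radius': 0.2})

encoder = get_encoder(enc_cfg)
vecfield = get_vecfield(vf_cfg, ctx_point_feature_dim=128)  # assumed feature dim
```
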
/configs/ups.yml:
--------------------------------------------------------------------------------
1 | # demo configuration
2 | model:
3 | encoder:
4 | type: knn
5 | dynamic_graph: false
6 | num_convs: 7
7 | conv_channels: 30
8 | conv_num_fc_layers: 4
9 | conv_growth_rate: 24
10 | conv_knn: 30
11 | vecfield:
12 | type: knn
13 | knn: 48
14 |     style: normal
15 |     radius: 0.2
16 |     num_points: 48
17 | raise_xyz_channels: 48
18 |
19 | train:
20 | seed: 2021
21 | train_batch_size: 4
22 | num_workers: 4
23 | lr: 5.e-4
24 | weight_decay: 0.
25 | vec_avg_knn: 4
26 | max_iters: 1000000
27 | val_freq: 2000
28 | scheduler:
29 | factor: 0.8
30 | patience: 50
31 |     threshold: 3.e-6
32 |
33 | dataset:
34 |   method: deeprs
35 |   base_dir:  # directory where your upsampling results are saved
36 |   dataset_root:  # sparse and dense point clouds for training and evaluation
37 |   dataset: Mixed
38 | resolutions:
39 | - '1024_poisson'
40 | - '2048_poisson'
41 |   rate: 4
42 | patch_size: 512
43 | num_pnts: 10000
44 | noise_min: 0.005
45 | noise_max: 0.030
46 | aug_rotate: true
47 | val_noise: 0.010
48 |
--------------------------------------------------------------------------------
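For reference, every script in `scripts/` loads these YAML files the same way: `yaml.safe_load` wrapped in an `EasyDict` for attribute-style access. A minimal sketch of that pattern:

```python
import yaml
from easydict import EasyDict

# Same loading pattern as scripts/evaluate.py and friends
with open('./configs/ups.yml', 'r') as f:
    config = EasyDict(yaml.safe_load(f))

print(config.model.encoder.type)  # 'knn'
print(config.train.lr)            # 0.0005 ('5.e-4' parses as a YAML float)
```
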
/utils/datasets/toy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | from tqdm.auto import tqdm
6 |
7 |
8 | class ToyPointCloudDataset(Dataset):
9 |
10 | def __init__(self, shape='plane', num_pnts=10000, size=24, transform=None):
11 | super().__init__()
12 | assert shape in ('plane', 'sphere')
13 | self.shape = shape
14 | self.size = size
15 | self.num_pnts = num_pnts
16 | self.transform = transform
17 |
18 | def __len__(self):
19 | return self.size
20 |
21 | def __getitem__(self, idx):
22 | if idx >= self.size:
23 | raise IndexError()
24 | pcl = torch.rand([self.num_pnts, 3])
25 | if self.shape == 'plane':
26 | pcl[:, 2] = 0
27 | elif self.shape == 'sphere':
28 | pcl -= 0.5
29 | pcl /= (pcl ** 2).sum(dim=1, keepdim=True).sqrt()
30 | data = {
31 | 'pcl_clean': pcl,
32 | }
33 | if self.transform is not None:
34 | data = self.transform(data)
35 | return data
36 |
--------------------------------------------------------------------------------
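As a quick illustration (a sketch, assuming the repository root is on `PYTHONPATH`), the toy dataset above plugs directly into a standard `DataLoader`:

```python
from torch.utils.data import DataLoader
from utils.datasets import ToyPointCloudDataset

dset = ToyPointCloudDataset(shape='sphere', num_pnts=2048, size=8)
loader = DataLoader(dset, batch_size=4)
batch = next(iter(loader))
print(batch['pcl_clean'].shape)  # torch.Size([4, 2048, 3])
```
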
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 ChenhLiwnl
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/configs/denoise.yml:
--------------------------------------------------------------------------------
1 | # demo configuration
2 | model:
3 |
4 | encoder:
5 | type: knn
6 | dynamic_graph: false
7 | num_convs: 7
8 | conv_channels: 48
9 | conv_num_fc_layers: 5
10 | conv_growth_rate: 24
11 | conv_knn: 30
12 | vecfield:
13 | type: knn
14 | knn: 24
15 |     style: normal
16 |     radius: 0.2
17 |     num_points: 24
18 | raise_xyz_channels: 48
19 |
20 | train:
21 | seed: 2021
22 | train_batch_size: 4
23 | num_workers: 4
24 | lr: 5.e-4
25 | weight_decay: 0.
26 | vec_avg_knn: 4
27 | max_iters: 1000000
28 | val_freq: 2000
29 | scheduler:
30 | factor: 0.6
31 | patience: 100
32 |     threshold: 1.e-5
33 |
34 | dataset:
35 |   method: deeprs
36 |   base_dir:  # directory where your denoising results are saved
37 |   dataset_root:  # clean point clouds for training and evaluation
38 |   input_root:  # noisy data for evaluation
39 |   dataset: PUNet
40 | resolutions:
41 | - '10000_poisson'
42 | - '30000_poisson'
43 | - '10000_poisson'
44 | patch_size: 1000
45 | num_pnts: 10000
46 | noise_min: 0.005
47 | noise_max: 0.030
48 | aug_rotate: true
49 | val_noise: 0.02
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Point Set Resampling via Gradient Fields
2 |
3 | ![teaser](teaser.png)
4 |
5 | [Paper](https://arxiv.org/abs/2111.02045) | [Code](https://github.com/ChenhLiwnl/deep-rs)
6 |
7 | The official code repository for our TPAMI paper 'Deep Point Set Resampling via Gradient Fields'.
8 |
9 |
10 |
11 | ## Installation
12 |
13 | You can install the dependencies via the provided conda environment file:
14 |
15 | ```bash
16 | conda env create -f environment.yml
17 | conda activate deeprs
18 | ```
19 |
20 | ## Dataset
21 |
22 | To be released soon.
23 |
24 | ## Training and Testing
25 |
26 | To train or test a model, first edit the corresponding config file according to its comments (fill in the data directories). Then run
27 |
28 | ```bash
29 | ## Train a network for denoising
30 | python -m scripts.train_denoise --config ./configs/denoise.yml
31 | ## Train a network for upsampling
32 | python -m scripts.train_upsample --config ./configs/ups.yml
33 | ```
34 |
35 | for training, and run
36 |
37 | ```bash
38 | ## Save your denoising results
39 | python -m scripts.save_denoised --config ./configs/denoise.yml [--ckpt]
40 | ## Save your upsampling results
41 | python -m scripts.save_upsample --config ./configs/ups.yml [--ckpt_model] [--ckpt_gen]
42 | ```
43 |
44 | for testing.
45 |
46 | ## Evaluation
47 |
48 | You can run
49 |
50 | ```bash
51 | python -m scripts.evaluate --config ./configs/denoise.yml
52 | python -m scripts.evaluate_upsample --config ./configs/ups.yml
53 | ```
54 |
55 | to evaluate your results.
56 |
57 | ## Citation
58 |
59 | If you find this work helpful, please cite
60 |
61 | ```
62 | @ARTICLE{9775211,
63 | author={Chen, Haolan and Du, Bi'an and Luo, Shitong and Hu, Wei},
64 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
65 | title={Deep Point Set Resampling via Gradient Fields},
66 | year={2022},
67 | volume={},
68 | number={},
69 | pages={1-1},
70 | doi={10.1109/TPAMI.2022.3175183}}
71 | ```
72 |
73 | For any questions, please contact chenhl99@pku.edu.cn.
--------------------------------------------------------------------------------
/scripts/save_denoised.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import shutil
4 | import numpy as np
5 | import yaml
6 | from easydict import EasyDict
7 | from tqdm.auto import tqdm
8 | import random
9 | import torch
10 | from torch.utils.data import DataLoader
11 | import torch.utils.tensorboard
12 | from models.resampler import PointSetResampler
13 | from models.common import chamfer_distance_unit_sphere
14 | from utils.datasets import PointCloudDataset, PairedPatchDataset
15 | from utils.transforms import *
16 | from utils.misc import *
17 | from utils.denoise import patch_based_denoise, patch_based_denoise_big, denoise_large_pointcloud
18 | from .validate_p2m import rotate
19 | def denormalize(pc, center, scale):
20 | return pc * scale + center
21 | if __name__ == '__main__':
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--config', type=str)
24 | parser.add_argument('--device', type=str, default='cuda')
25 | parser.add_argument('--log_root', type=str, default='./logs')
26 |     parser.add_argument('--ckpt', type=str, default='./ckpts/denoise.pth')
27 | args = parser.parse_args()
28 |
29 | # Load configs
30 | with open(args.config, 'r') as f:
31 | config = EasyDict(yaml.safe_load(f))
32 | config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
33 | seed_all(config.train.seed)
34 |
35 | # Datasets and loaders
36 |
37 | val_dset = PointCloudDataset(
38 | root=config.dataset.dataset_root,
39 | dataset=config.dataset.dataset,
40 | split='test',
41 | resolution=config.dataset.resolutions[0],
42 | need_mesh=False,
43 |         from_saved=True,
44 |         noise_level=config.dataset.val_noise,
45 |         input_root=config.dataset.input_root,
46 |         transform=NormalizeUnitSphere(),
47 | )
48 | print("resolution is %s" % (config.dataset.resolutions[0]))
49 | # Model
50 | model = PointSetResampler(config.model).to(args.device)
51 | model.load_state_dict(torch.load(args.ckpt))
52 | def validate(it):
53 | model.eval()
54 |         filepath = config.dataset.dataset + "_" + config.dataset.method + "_" + config.dataset.resolutions[0] + "_" + str(config.dataset.val_noise)
55 |         file = os.path.join(config.dataset.base_dir, filepath)
56 | os.makedirs(file, exist_ok=True)
57 | print(file)
58 | for i, data in enumerate(tqdm(val_dset, desc='Validate')):
59 | pcl_noisy = data['pcl_clean'].to(args.device)
60 | center = data['center'].to(args.device)
61 | scale = data['scale'].to(args.device)
62 | pcl_denoised = patch_based_denoise(model, pcl_noisy)
63 |             pcl_denoised = denormalize(pcl_denoised, center, scale)
64 |             filename = os.path.join(file, data['name'] + ".xyz")
65 |             np.savetxt(filename, pcl_denoised.cpu().numpy())
66 |
67 | print('[Val] noise %s , %s | Finished Saving ' % (config.dataset.val_noise, config.dataset.dataset))
68 |
69 | # Main loop
70 | try:
71 | cd_loss = validate(0)
72 |
73 | except KeyboardInterrupt:
74 | print('Terminating...')
75 |
--------------------------------------------------------------------------------
/utils/mesh_to_pcl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import multiprocessing as mp
4 | import logging
5 | import numpy as np
6 | import point_cloud_utils as pcu
7 | logging.basicConfig(level=logging.DEBUG)
8 |
9 | '''
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--dataset_root', type=str, default='./data/PUNet')
12 | args = parser.parse_args()
13 |
14 | DATASET_ROOT = args.dataset_root
15 | '''
16 | DATASET_ROOT = '/172.31.222.52_data/chenhaolan/Mixed'
17 | MESH_ROOT = os.path.join(DATASET_ROOT, 'meshes')
18 | SAVE_ROOT = '/172.31.222.52_data/chenhaolan/Mixed/'
19 | POINTCLOUD_ROOT = os.path.join(SAVE_ROOT, 'pointclouds')
20 | RESOLUTIONS = [1024, 2048, 4096, 8192, 16384, 32768, 2048*12, 2048*20, 2048*32, 2048*24]
21 | SAMPLERS = ['poisson']
22 | #SUBSETS = ['simple', 'medium', 'complex', 'train', 'test']
23 | SUBSETS = ['train','test']
24 | NUM_WORKERS = 32
25 |
26 |
27 | def poisson_sample(v, f, n, num_points):
28 | pc, _ = pcu.sample_mesh_poisson_disk(v, f, n, num_points, use_geodesic_distance=True)
29 | if pc.shape[0] > num_points:
30 | pc = pc[:num_points, :]
31 | else:
32 | compl, _ = pcu.sample_mesh_random(v, f, n, num_points - pc.shape[0])
33 | # Notice: if (num_points - pc.shape[0]) == 1, sample_mesh_random will
34 | # return a tensor of size (3, ) but not (1, 3)
35 | compl = np.reshape(compl, [-1, 3])
36 | pc = np.concatenate([pc, compl], axis=0)
37 | return pc
38 |
39 |
40 | def random_sample(v, f, n, num_points):
41 | pc, _ = pcu.sample_mesh_random(v, f, n, num_points)
42 | return pc
43 |
44 |
45 | def enum_configs():
46 | for subset in SUBSETS:
47 | for resolution in RESOLUTIONS:
48 | for sampler in SAMPLERS:
49 | yield (subset, resolution, sampler)
50 |
51 |
52 | def enum_meshes():
53 | for subset, resolution, sampler in enum_configs():
54 | in_dir = os.path.join(MESH_ROOT, subset)
55 | out_dir = os.path.join(POINTCLOUD_ROOT, subset, '%d_%s' % (resolution, sampler))
56 | if not os.path.exists(in_dir):
57 | continue
58 | for fn in os.listdir(in_dir):
59 | if fn[-3:] == 'off':
60 | basename = fn[:-4]
61 | yield (subset, resolution, sampler,
62 | os.path.join(in_dir, fn),
63 | os.path.join(out_dir, basename+'.xyz'))
64 |
65 |
66 | def process(args):
67 | subset, resolution, sampler, in_file, out_file = args
68 | if os.path.exists(out_file):
69 | logging.info('Already exists: ' + in_file)
70 | return
71 | logging.info('Start processing: [%s,%d,%s] %s' % (subset, resolution, sampler, in_file))
72 | os.makedirs(os.path.dirname(out_file), exist_ok=True)
73 | v, f, n = pcu.read_off(in_file)
74 | if sampler == 'poisson':
75 | pointcloud = poisson_sample(v, f, n, resolution)
76 | elif sampler == 'random':
77 | pointcloud = random_sample(v, f, n, resolution)
78 | else:
79 | raise ValueError('Unknown sampler: ' + sampler)
80 | np.savetxt(out_file, pointcloud, '%.6f')
81 |
82 |
83 | if __name__ == '__main__':
84 | if NUM_WORKERS > 1:
85 | with mp.Pool(processes=NUM_WORKERS) as pool:
86 | pool.map(process, enum_meshes())
87 | else:
88 | for args in enum_meshes():
89 | process(args)
--------------------------------------------------------------------------------
/models/vecfields/knn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module, Linear, Conv2d, BatchNorm2d, Conv1d, BatchNorm1d, Sequential, ReLU
3 | import pytorch3d.ops
4 | import numpy as np
5 |
6 |
7 | class KNearestVectorField(Module):
8 |
9 | def __init__(
10 | self,
11 | knn,
12 | ctx_point_feature_dim,
13 | raise_xyz_channels=64,
14 | point_dim=3,
15 |         radius=0.1,
16 | hidden_dims_pointwise=[128, 128, 256],
17 | hidden_dims_global=[128, 64]
18 | ):
19 | super().__init__()
20 | self.knn = knn
21 | self.radius = radius
22 | self.raise_xyz = Linear(3, raise_xyz_channels)
23 | print(self.radius)
24 | dims = [raise_xyz_channels+ctx_point_feature_dim] + hidden_dims_pointwise
25 | conv_layers = []
26 | for i in range(len(dims)-1):
27 | conv_layers += [
28 | Conv2d(dims[i], dims[i+1], kernel_size=(1, 1)),
29 | BatchNorm2d(dims[i+1]),
30 | ]
31 | if i < len(dims)-2:
32 | conv_layers += [
33 | ReLU(),
34 | ]
35 | self.pointwise_convmlp = Sequential(*conv_layers)
36 |
37 | dims = [hidden_dims_pointwise[-1]] + hidden_dims_global + [point_dim]
38 | conv_layers = []
39 | for i in range(len(dims)-1):
40 | conv_layers += [
41 | Conv1d(dims[i], dims[i+1], kernel_size=1),
42 | ]
43 | if i < len(dims)-2:
44 | conv_layers += [
45 | BatchNorm1d(dims[i+1]),
46 | ReLU(),
47 | ]
48 | self.global_convmlp = Sequential(*conv_layers)
49 |
50 |
51 | def forward(self, p_query, p_context, h_context):
52 | """
53 | Args:
54 | p_query: Query point set, (B, N_query, 3).
55 | p_context: Context point set, (B, N_ctx, 3).
56 | h_context: Point-wise features of the context point set, (B, N_ctx, H_ctx).
57 | Returns:
58 | (B, N_query, 3)
59 | """
60 |         b, N_query, _ = p_query.shape
61 | dist, knn_idx, knn_points = pytorch3d.ops.knn_points(
62 | p1=p_query,
63 | p2=p_context,
64 | K=self.knn,
65 | return_nn=True
66 | ) # (B, N_query,K), (B, N_query, K), (B, N_query, K, 3)
67 | # Relative coordinates and their embeddings
68 | p_rel = knn_points - p_query.unsqueeze(-2) # (B, N_query, K, 3)
69 |
70 | h_rel = self.raise_xyz(p_rel) # (B, N_query, K, H_rel)
71 |
72 | # Grouped features of neighbor points
73 | h_group = pytorch3d.ops.knn_gather(
74 | x=h_context,
75 | idx=knn_idx,
76 | ) # (B, N_query, K, H_ctx)
77 |
78 | # Combine
79 | h_combined = torch.cat([h_rel, h_group], dim=-1) # (B, N_query, K, H_rel+H_ctx)
80 |
81 | # Featurize
82 | h_combined = h_combined.permute(0, 3, 1, 2).contiguous() # (B, H_rel+H_ctx, N_query, K)
83 | y = self.pointwise_convmlp(h_combined) # (B, H_out, N_query, K)
84 | y = y.permute(0, 2, 3, 1).contiguous() # (B, N_query, K, H_out)
85 |         dist = torch.sqrt(dist)
86 |         dist = dist.unsqueeze(-1)  # (B, N_query, K, 1)
87 |         # Cosine window: weight 1 at dist=0, decaying smoothly to 0 at dist=radius
88 |         c = 0.5 * (torch.cos(dist * np.pi / self.radius) + 1.0)
89 |         c = c * (dist <= self.radius) * (dist > 0.0)  # zero outside the radius and for zero-distance matches
90 |         y = torch.mul(y, c).permute(0, 3, 1, 2).contiguous()
91 |         y = y.sum(-1)  # (B, H_out, N_query)
92 | #y, _ = torch.max(y, dim=3) # (B, H_out, N_query)
93 |
94 | # Vectorize
95 | y = self.global_convmlp(y) # (B, 3, N_query)
96 | y = y.permute(0, 2, 1).contiguous() # (B, N_query, 3)
97 | return y
98 |
--------------------------------------------------------------------------------
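A minimal shape check for the vector field above (a sketch, assuming `pytorch3d` is installed; all sizes are arbitrary):

```python
import torch
from models.vecfields.knn import KNearestVectorField

field = KNearestVectorField(knn=8, ctx_point_feature_dim=16, radius=0.2)
p_query = torch.rand(2, 100, 3)   # (B, N_query, 3)
p_ctx = torch.rand(2, 500, 3)     # (B, N_ctx, 3)
h_ctx = torch.rand(2, 500, 16)    # (B, N_ctx, H_ctx)
vec = field(p_query, p_ctx, h_ctx)
print(vec.shape)                  # torch.Size([2, 100, 3])
```
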
/scripts/evaluate_upsample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import shutil
4 | import yaml
5 | from easydict import EasyDict
6 | from tqdm.auto import tqdm
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torch.utils.tensorboard
11 | from models.resampler import PointSetResampler
12 | from models.common import chamfer_distance_unit_sphere
13 | from utils.datasets import UpsampleDataset, PairedPatchDataset
14 | from utils.transforms import *
15 | from utils.misc import *
16 | from utils.evaluate import *
17 | from utils.denoise import patch_based_upsample , patch_based_upsample_big
18 |
19 | def denormalize(pc, center, scale):
20 |     return pc * scale + center
21 | def validate_cd(data, file):
22 |     pcl_clean = data['gt'].to(args.device)
23 |     name = data['name']
24 |     pcl_denoised = np.loadtxt(os.path.join(file, name + ".xyz"))
25 |     pcl_denoised = torch.from_numpy(pcl_denoised).type(torch.FloatTensor).to(args.device)
26 |     chamfer = chamfer_distance_unit_sphere(pcl_clean.unsqueeze(0), pcl_denoised.unsqueeze(0), batch_reduction='mean')[0].item()
27 |     return chamfer
28 | def validate_p2m(data, file):
29 |     verts = data['meshes']['verts'].to(args.device)
30 |     faces = data['meshes']['faces'].to(args.device)
31 |     name = data['name']
32 |     pcl_denoised = np.loadtxt(os.path.join(file, name + ".xyz"))
33 |     pcl_denoised = torch.from_numpy(pcl_denoised).type(torch.FloatTensor).to(args.device)
34 |     p2f = point_mesh_bidir_distance_single_unit_sphere(
35 |         pcl=pcl_denoised,
36 |         verts=verts,
37 |         faces=faces
38 |     ).mean()
39 |     return p2f
40 | def validate_hd(data, file):
41 |     ref = data['gt'].to(args.device)
42 |     name = data['name']
43 |     gen = np.loadtxt(os.path.join(file, name + ".xyz"))
44 |     gen = torch.from_numpy(gen).type(torch.FloatTensor).to(args.device)
45 |     hd = hausdorff_distance_unit_sphere(gen=gen.unsqueeze(0), ref=ref.unsqueeze(0))
46 |     return hd
47 | if __name__ == '__main__':
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument('--config', type=str)
50 | parser.add_argument('--device', type=str, default='cuda')
51 | parser.add_argument('--log_root', type=str, default='./logs')
52 | args = parser.parse_args()
53 |
54 | # Load configs
55 | with open(args.config, 'r') as f:
56 | config = EasyDict(yaml.safe_load(f))
57 | config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
58 | seed_all(config.train.seed)
59 |
60 | # Datasets and loaders
61 |
62 | val_dset = UpsampleDataset(
63 | root=config.dataset.dataset_root,
64 | dataset=config.dataset.dataset,
65 | split='test',
66 | resolution=config.dataset.resolutions[1],
67 |         rate=16,
68 |         noise_min=0,
69 |         noise_max=0,
70 |         need_mesh=True
71 | )
72 | print("resolution is %s" % (config.dataset.resolutions[1]))
73 |     filepath = config.dataset.dataset + "_" + config.dataset.method + "_" + config.dataset.resolutions[0] + "_" + str(val_dset.rate) + "x"
74 |     file = os.path.join(config.dataset.base_dir, filepath)
75 | print(file)
76 | def validate(it):
77 | global best
78 | avg_chamfer = 0
79 | total_chamfer = 0
80 | total_p2m = 0
81 | total_hd = 0
82 | for i, data in enumerate(tqdm(val_dset, desc='Validate')):
83 | total_p2m += validate_p2m(data,file)
84 | total_chamfer += validate_cd(data, file)
85 | avg_chamfer = total_chamfer / len(val_dset)
86 | avg_p2m = total_p2m / len(val_dset)
87 | print('[Val] noise %.6f | CD %.8f ' % (config.dataset.val_noise, avg_chamfer))
88 |         print('[Val] noise %.6f | P2M %.8f ' % (config.dataset.val_noise, avg_p2m))
89 |
90 | # Main loop
91 | try:
92 | cd_loss = validate(0)
93 |
94 | except KeyboardInterrupt:
95 | print('Terminating...')
96 |
--------------------------------------------------------------------------------
/scripts/save_upsample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import shutil
4 | import yaml
5 | from easydict import EasyDict
6 | from tqdm.auto import tqdm
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torch.utils.tensorboard
11 | from models.resampler import PointSetResampler
12 | from models.common import chamfer_distance_unit_sphere, coarsesGenerator
13 | from utils.datasets import UpsampleDataset, PairedPatchDataset, pcl
14 | from utils.transforms import *
15 | from utils.misc import *
16 | from utils.denoise import patch_based_upsample , patch_based_upsample_big
17 | def addnoise(pcl, noise):
18 |     # Perturb the point cloud with isotropic Gaussian noise of std `noise`.
19 |     pcl = pcl + torch.randn_like(pcl) * noise
20 |     return pcl
21 | def normalize(pcl, center=None, scale=None):
22 | if center is None:
23 | p_max = pcl.max(dim=0, keepdim=True)[0]
24 | p_min = pcl.min(dim=0, keepdim=True)[0]
25 | center = (p_max + p_min) / 2 # (1, 3)
26 | pcl = pcl - center
27 | if scale is None:
28 | scale = (pcl ** 2).sum(dim=1, keepdim=True).sqrt().max(dim=0, keepdim=True)[0] # (1, 1)
29 | pcl = pcl / scale
30 | return pcl, center, scale
31 | if __name__ == '__main__':
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('--config', type=str)
34 | parser.add_argument('--device', type=str, default='cuda')
35 | parser.add_argument('--log_root', type=str, default='./logs')
36 |     parser.add_argument('--ckpt_model', type=str, default='./ckpts/model_u.pth')
37 |     parser.add_argument('--ckpt_gen', type=str, default='./ckpts/upsmodel.pth')
38 | args = parser.parse_args()
39 |
40 | # Load configs
41 | with open(args.config, 'r') as f:
42 | config = EasyDict(yaml.safe_load(f))
43 | config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
44 | seed_all(config.train.seed)
45 |
46 | # Datasets and loaders
47 |
48 | val_dset = UpsampleDataset(
49 | root=config.dataset.dataset_root,
50 | dataset=config.dataset.dataset,
51 | split='test',
52 | resolution=config.dataset.resolutions[1],
53 | rate=16,
54 |         noise_min=1.5e-2,
55 |         noise_max=1.5e-2
56 | )
57 |
58 | print("resolution is %s" % (config.dataset.resolutions[1]))
59 | # Model
60 | model = PointSetResampler(config.model).to(args.device)
61 | model.load_state_dict(torch.load(args.ckpt_model))
62 |     upsampler = coarsesGenerator(rate=16).to(args.device)
63 | upsampler.load_state_dict(torch.load(args.ckpt_gen))
64 |
65 | def validate(it):
66 | model.eval()
67 | upsampler.eval()
68 | all_clean = []
69 | all_denoised = []
70 |         filepath = config.dataset.dataset + "_" + config.dataset.method + "_" + config.dataset.resolutions[0] + "_" + str(val_dset.rate) + "x"
71 |         filepath = os.path.join(config.dataset.base_dir, filepath)
72 | os.makedirs(filepath, exist_ok=True)
73 | for i, data in enumerate(tqdm(val_dset, desc='Validate')):
74 | with torch.no_grad():
75 | pcl_noisy = data['ups'].to(args.device)
76 | pcl_low = data['original'].to(args.device)
77 | pcl_clean = data['gt'].to(args.device)
78 | pcl_denoised = upsampler(pcl_low.unsqueeze(0)).squeeze(0)
79 |                 file = os.path.join(filepath, data['name'] + ".xyz")
80 |                 np.savetxt(file, pcl_denoised.detach().cpu().numpy())
81 | all_clean.append(pcl_clean.unsqueeze(0))
82 | all_denoised.append(pcl_denoised.unsqueeze(0))
83 | all_clean = torch.cat(all_clean, dim=0)
84 | all_denoised = torch.cat(all_denoised, dim=0)
85 | avg_chamfer = chamfer_distance_unit_sphere(all_denoised, all_clean, batch_reduction='mean')[0].item()
86 | print('[Val] Iter %04d | CD %.8f ' % (it, avg_chamfer))
87 |
88 | # Main loop
89 | try:
90 | cd_loss = validate(0)
91 |
92 | except KeyboardInterrupt:
93 | print('Terminating...')
94 |
--------------------------------------------------------------------------------
/scripts/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import math
4 | import yaml
5 | from easydict import EasyDict
6 | from tqdm.auto import tqdm
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torch.utils.tensorboard
11 | from models.resampler import PointSetResampler
12 | from models.common import chamfer_distance_unit_sphere
13 | from utils.datasets import PointCloudDataset, PairedPatchDataset
14 | from utils.transforms import *
15 | from utils.misc import *
16 | from utils.evaluate import *
17 | from utils.denoise import patch_based_denoise
18 | def rotate(pc, degree):
19 | degree = math.pi * degree / 180.0
20 | sin, cos = math.sin(degree), math.cos(degree)
21 | matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]]
22 | #matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]
23 | #matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]
24 | matrix = torch.tensor(matrix).to("cuda")
25 | pc = torch.matmul(pc, matrix)
26 | return pc
27 | def normalize(pcl, center=None, scale=None):
28 | """
29 | Args:
30 | pcl: The point cloud to be normalized, (N, 3)
31 | """
32 | if center is None:
33 | p_max = pcl.max(dim=0, keepdim=True)[0]
34 | p_min = pcl.min(dim=0, keepdim=True)[0]
35 | center = (p_max + p_min) / 2 # (1, 3)
36 | pcl = pcl - center
37 | if scale is None:
38 | scale = (pcl ** 2).sum(dim=1, keepdim=True).sqrt().max(dim=0, keepdim=True)[0] # (1, 1)
39 | pcl = pcl / scale
40 | return pcl, center, scale
41 | if __name__ == '__main__':
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument('--config', type=str)
44 | parser.add_argument('--device', type=str, default='cuda')
45 | parser.add_argument('--log_root', type=str, default='./logs')
46 | args = parser.parse_args()
47 |
48 | # Load configs
49 | with open(args.config, 'r') as f:
50 | config = EasyDict(yaml.safe_load(f))
51 | config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
52 | seed_all(config.train.seed)
53 |
54 | # Datasets and loaders
55 |
56 | val_dset = PointCloudDataset(
57 | root=config.dataset.dataset_root,
58 | dataset=config.dataset.dataset,
59 | split='test',
60 | input_root=config.dataset.input_root,
61 | resolution=config.dataset.resolutions[0],
62 | from_saved=False,
63 | need_mesh=True,
64 | transform = NormalizeUnitSphere() ,
65 | )
66 |
67 | print("resolution is %s" % (config.dataset.resolutions[0]))
68 | method = config.dataset.method
69 | resolution = config.dataset.resolutions[0]
70 | val_noise = str(config.dataset.val_noise)
71 |     filepath = config.dataset.dataset + "_" + method + "_" + resolution + "_" + val_noise
72 |     file = os.path.join(config.dataset.base_dir, filepath)
73 |     # 'file' is the directory of results written by scripts/save_denoised.py
74 |     # for this dataset/method/resolution/noise combination.
75 | print(file)
76 | def validate(it):
77 | global best
78 | avg_chamfer = 0
79 | avg_p2f = 0
80 | for i, data in enumerate(tqdm(val_dset, desc='Validate')):
81 | pcl_clean = data['pcl_clean'].to(args.device)
82 | verts = data['meshes']['verts'].to(args.device)
83 | faces = data['meshes']['faces'].to(args.device)
84 | name = data['name']
85 | pcl_denoised = np.loadtxt(os.path.join(file , name+".xyz"))
86 | pcl_denoised = torch.from_numpy(pcl_denoised).type(torch.FloatTensor).to(args.device)
87 | if val_noise == 'blensor':
88 | pcl_denoised = rotate(pcl_denoised , -90) # blensor requires rotation, -90
89 | avg_p2f += point_mesh_bidir_distance_single_unit_sphere(
90 | pcl=pcl_denoised,
91 | verts=verts,
92 | faces=faces
93 | ).mean()
94 |             pcl_denoised, _, _ = normalize(pcl_denoised, data['center'].cuda(), data['scale'].cuda())
95 | avg_chamfer += chamfer_distance_unit_sphere(pcl_clean.unsqueeze(0),pcl_denoised.unsqueeze(0), batch_reduction='mean')[0].item()
96 |
97 | avg_chamfer /= len(val_dset)
98 | avg_p2f /= len(val_dset)
99 | print('[Val] noise %s | CD %.8f ' % (config.dataset.val_noise, avg_chamfer))
100 | print('[Val] noise %s | P2M %.8f ' % (config.dataset.val_noise, avg_p2f))
101 |
102 | # Main loop
103 | try:
104 | cd_loss = validate(0)
105 |
106 | except KeyboardInterrupt:
107 | print('Terminating...')
108 |
--------------------------------------------------------------------------------
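The `normalize` helper above and the `denormalize` helpers in the saving scripts are mutual inverses; a self-contained sanity check (a sketch reusing the normalize logic defined in scripts/evaluate.py):

```python
import torch

def normalize(pcl, center=None, scale=None):
    # Same logic as scripts/evaluate.py: center the bounding box, scale by max norm
    if center is None:
        p_max = pcl.max(dim=0, keepdim=True)[0]
        p_min = pcl.min(dim=0, keepdim=True)[0]
        center = (p_max + p_min) / 2
    pcl = pcl - center
    if scale is None:
        scale = (pcl ** 2).sum(dim=1, keepdim=True).sqrt().max(dim=0, keepdim=True)[0]
    pcl = pcl / scale
    return pcl, center, scale

pcl = torch.rand(1000, 3) * 10 - 5           # arbitrary un-normalized cloud
normed, center, scale = normalize(pcl)
restored = normed * scale + center           # denormalize, as in save_denoised.py
assert torch.allclose(restored, pcl, atol=1e-5)
```
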
/utils/mitsuba.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def standardize_bbox(pcl):
5 | mins = np.amin(pcl, axis=0)
6 | maxs = np.amax(pcl, axis=0)
7 | center = ( mins + maxs ) / 2.
8 | scale = np.amax(maxs-mins)
9 | # print("Center: {}, Scale: {}".format(center, scale))
10 | result = ((pcl - center)/scale).astype(np.float32) # [-0.5, 0.5]
11 |
12 | # Move down onto the floor
13 |
14 | m = result.min(0)
15 | m[:2] = 0
16 | m[2] = m[2] + 0.3
17 | result = result - m
18 |
19 | return result
20 |
21 |
22 | def downsample(pcl, colors, n_points):
23 | pt_indices = np.random.choice(pcl.shape[0], n_points, replace=False)
24 | np.random.shuffle(pt_indices)
25 | pcl = pcl[pt_indices] # n by 3
26 | colors = colors[pt_indices]
27 | return pcl, colors
28 |
29 |
30 | def select_downsample(pcl, noise, n_points, low_noise=False):
31 | if low_noise:
32 | idx = np.argpartition(noise.flatten(), n_points)
33 | pcl = pcl[idx[:n_points]] # n by 3
34 | else:
35 | np.random.shuffle(pcl)
36 | pcl = pcl[:n_points]
37 | return pcl
38 |
39 |
40 |
41 | def xml_head():
42 |     # NOTE: the Mitsuba XML scene header (integrator, sensor, film, sampler)
43 |     # was stripped during text extraction; only the string skeleton remains.
44 |     return """
45 | """
73 |
74 |
75 | def xml_tail():
76 |     # NOTE: the closing Mitsuba XML scene elements were likewise stripped
77 |     # during text extraction.
78 |     return """
79 | """
96 |
100 |
101 | def xml_sphere(x, y, z, r, g, b, radius=0.0075):
102 |     # NOTE: the per-point sphere XML template was stripped during text
103 |     # extraction; it is formatted with (radius, x, y, z, r, g, b) below.
104 |     tmpl = """
105 | """
113 | return tmpl.format(radius, x, y, z, r, g, b)
114 |
115 |
130 |
131 | def make_xml(pcl, color, radius, flipX=False, flipY=False, flipZ=False, max_points=None):
132 | xml = ''
133 | xml += xml_head()
134 |
135 | pcl = standardize_bbox(pcl)
136 | print(pcl.shape)
137 | if max_points is not None:
138 | pcl, color = downsample(pcl, color, max_points)
139 | for p, c in zip(pcl, color):
140 | x, y, z = p
141 | r, g, b = c
142 | if flipX: x = -x
143 | if flipY: y = -y
144 | if flipZ: z = -z
145 | xml += xml_sphere(x, y, z, r, g, b, radius)
146 |
147 | xml += xml_tail()
148 |
149 | return xml
150 |
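
# Usage sketch (an illustrative addition, not part of the original file): write a
# generated scene to disk so it can be rendered with the Mitsuba CLI. The random
# cloud, flat gray colors, and output path are placeholder assumptions.
def _write_scene_demo(out_path='pcl_scene.xml'):
    pcl = np.random.randn(2048, 3).astype(np.float32)
    colors = np.full((2048, 3), 0.5, dtype=np.float32)  # the real pipeline colors points by noise (utils/render.py)
    xml = make_xml(pcl, colors, radius=0.0075, max_points=1024)
    with open(out_path, 'w') as f:
        f.write(xml)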
--------------------------------------------------------------------------------
/utils/render.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import point_cloud_utils as pcu
6 | from matplotlib import cm
7 |
8 | from pytorch3d.ops import knn_points
9 | from pytorch3d.structures import Meshes
10 | from pytorch3d.renderer import (
11 | look_at_view_transform,
12 | FoVPerspectiveCameras,
13 | PointLights,
14 | DirectionalLights,
15 | Materials,
16 | RasterizationSettings,
17 | MeshRenderer,
18 | MeshRasterizer,
19 | SoftPhongShader,
20 | TexturesUV,
21 | TexturesVertex
22 | )
23 |
24 | from models.common import *
25 | from .mitsuba import *
26 |
27 |
28 | def get_icosahedron(size):
29 | X = .525731112119133606
30 | Z = .850650808352039932
31 | verts = torch.FloatTensor([
32 | [-X, 0.0, Z], [X, 0.0, Z], [-X, 0.0, -Z], [X, 0.0, -Z],
33 | [0.0, Z, X], [0.0, Z, -X], [0.0, -Z, X], [0.0, -Z, -X],
34 | [Z, X, 0.0], [-Z, X, 0.0], [Z, -X, 0.0], [-Z, -X, 0.0]
35 | ]) * size
36 | faces = torch.LongTensor([
37 | [1, 4, 0], [4, 9, 0], [4, 5, 9], [8, 5, 4], [1, 8, 4],
38 | [1, 10, 8], [10, 3, 8], [8, 3, 5], [3, 2, 5], [3, 7, 2],
39 | [3, 10, 7], [10, 6, 7], [6, 11, 7], [6, 0, 11], [6, 1, 0],
40 | [10, 1, 6], [11, 0, 9], [2, 11, 9], [5, 2, 9], [11, 2, 7],
41 | ])
42 | return verts, faces
43 |
44 |
45 | class NoisyPointCloudXMLMaker(object):
46 |
47 | def __init__(self, point_size=0.01, max_noise=0.03):
48 | super().__init__()
49 | self.point_size = point_size
50 | self.max_noise = max_noise
51 |
52 | def get_color(self, noise, showing_noisy=False):
53 | """
54 | Args:
55 | noise: (N, 1)
56 | """
57 | max_noise = self.max_noise
58 |
59 | N = noise.shape[0]
60 | noise_level = np.clip(noise / max_noise, 0, 1) # (N, 1)
61 |
62 | base_color = np.repeat(np.array([[0, 0, 1]], dtype=np.float64), N, axis=0) # Blue, (N, 3)
63 | noise_color = np.repeat(np.array([[1, 1, 0]], dtype=np.float64), N, axis=0) # Yellow, (N, 3)
64 | if showing_noisy:
65 | noise_level = (np.power(10, noise_level)-1) / (np.power(10, 1)-1)
66 | else:
67 | noise_level = (np.power(25, noise_level)-1) / (np.power(25, 1)-1)
68 |
69 | inv_mix = (1-noise_level)*(1-base_color) + noise_level*(1-noise_color)
70 | return 1 - inv_mix # (N, 3)
71 |
72 | def render(self, pcl, verts, faces, rotation=None, lookat=(2.0, 30, -45), distance='p2m'):
73 | assert distance in ('p2m', 'nn')
74 |
75 | device = pcl.device
76 | # Normalize mesh
77 | verts, center, scale = normalize_sphere(verts.unsqueeze(0))
78 | verts = verts[0]
79 | # Normalize pcl
80 | pcl = normalize_pcl(pcl.unsqueeze(0), center=center, scale=scale)
81 | pcl = pcl[0]
82 | if distance == 'p2m':
83 | # Compute point-to-surface distance
84 | p2m = pointwise_p2m_distance_normalized(pcl.to(device), verts.to(device), faces.to(device)).sqrt().unsqueeze(-1)
85 | elif distance == 'nn':
86 | p2m, _, _ = knn_points(pcl.unsqueeze(0), verts.unsqueeze(0), K=1) # knn_points is imported directly above; the pytorch3d module itself is not
87 | p2m = p2m[0, :, :].mean(dim=-1, keepdim=True)
88 | print(p2m.max(), p2m.mean(), p2m.min())
89 |
93 | # Rotate point cloud
94 | if rotation is not None:
95 | pcl = torch.matmul(pcl, rotation.t())
96 | # pcl[:, 1] = -pcl[:, 1]
97 |
98 | xml = make_xml(pcl.cpu().numpy(), self.get_color(p2m.cpu().numpy()), radius=self.point_size, max_points=None)
99 | return xml
100 |
101 |
102 | class SimpleMeshRenderer(object):
103 |
104 | def __init__(self, image_size=1024):
105 | super().__init__()
106 | self.image_size = image_size
107 |
108 | def render(self, verts, faces, lookat=(2.0, 30, -45), color=(1, 1, 1)):
109 | device = verts.device
110 | # Normalize mesh
111 | verts, _, _ = normalize_sphere(verts.unsqueeze(0))
112 | verts = verts[0]
113 |
114 | textures = TexturesVertex([
115 | torch.FloatTensor(color).to(device).view(1, 3).repeat(verts.size(0), 1)
116 | ])
117 | meshes = Meshes([verts], [faces], textures)
118 |
119 | # Render
120 | R, T = look_at_view_transform(*lookat)
121 | cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
122 | raster_settings = RasterizationSettings(
123 | image_size=self.image_size,
124 | blur_radius=0.,
125 | faces_per_pixel=1,
126 | )
127 |
128 | lights = PointLights(
129 | device=device,
130 | ambient_color=torch.ones(1,3) * 0.5,
131 | diffuse_color=torch.ones(1,3) * 0.4,
132 | specular_color=torch.ones(1,3) * 0.1,
133 | location=cameras.get_camera_center()
134 | )
135 |
136 | renderer = MeshRenderer(
137 | rasterizer=MeshRasterizer(
138 | cameras=cameras,
139 | raster_settings=raster_settings
140 | ),
141 | shader=SoftPhongShader(
142 | device=device,
143 | cameras=cameras,
144 | lights=lights
145 | )
146 | )
147 |
148 | images = renderer(meshes)
149 | return images
154 |
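
# Usage sketch (an illustrative addition, not part of the original file):
# SimpleMeshRenderer.render returns a (1, H, W, 4) RGBA tensor; saving the RGB
# channels with matplotlib gives a quick preview. `verts`/`faces` are assumed inputs.
def _render_demo(verts, faces, out_path='mesh.png'):
    renderer = SimpleMeshRenderer(image_size=512)
    images = renderer.render(verts, faces)  # (1, H, W, 4)
    plt.imsave(out_path, images[0, ..., :3].clamp(0, 1).cpu().numpy())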
--------------------------------------------------------------------------------
/models/vecfields/radius.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module, Linear, Conv2d, BatchNorm2d, Conv1d, BatchNorm1d, Sequential, ReLU
3 | import pytorch3d.ops
4 | import numpy as np
5 | from ..common import *
6 |
7 |
8 | class RadiusVectorField(Module):
9 |
10 | def __init__(
11 | self,
12 | radius,
13 | num_points,
14 | ctx_point_feature_dim,
15 | style="normal",
16 | max_points=60,
17 | raise_xyz_channels=64,
18 | point_dim=3,
19 | hidden_dims_pointwise=[128, 128, 256],
20 | hidden_dims_global=[128, 64]
21 | ):
22 | super().__init__()
23 | self.radius = radius
24 | self.style = style
25 | self.num_points = num_points
26 | self.max_points = max_points
27 | self.raise_xyz = Linear(3, raise_xyz_channels)
28 |
29 | dims = [raise_xyz_channels+ctx_point_feature_dim] + hidden_dims_pointwise
30 | conv_layers = []
31 | for i in range(len(dims)-1):
32 | if self.style == "normal":
33 | conv_layers += [
34 | Conv2d(dims[i], dims[i+1], kernel_size=(1, 1)),
35 | BatchNorm2d(dims[i+1]),
36 | ]
37 | elif self.style == "residual":
38 | conv_layers += [
39 | ResnetBlockConv2d(dims[i], dims[i+1]),
40 | BatchNorm2d(dims[i+1]),
41 | ]
42 | if i < len(dims)-2:
43 | conv_layers += [
44 | ReLU(),
45 | ]
46 | self.pointwise_convmlp = Sequential(*conv_layers)
47 |
48 | dims = [hidden_dims_pointwise[-1]] + hidden_dims_global + [point_dim]
49 | conv_layers = []
50 | for i in range(len(dims)-1):
51 | if self.style == "normal":
52 | conv_layers += [
53 | Conv1d(dims[i], dims[i+1], kernel_size=1),
54 | ]
55 | elif self.style == "residual":
56 | conv_layers += [
57 | ResnetBlockConv1d(dims[i] , dims[i+1]) ,
58 | ]
59 | if i < len(dims)-2:
60 | conv_layers += [
61 | BatchNorm1d(dims[i+1]),
62 | ReLU(),
63 | ]
64 | self.global_convmlp = Sequential(*conv_layers)
65 |
66 |
67 | def forward(self, p_query, p_context, h_context):
68 | """
69 | Args:
70 | p_query: Query point set, (B, N_query, 3).
71 | p_context: Context point set, (B, N_ctx, 3).
72 | h_context: Point-wise features of the context point set, (B, N_ctx, H_ctx).
73 | Returns:
74 | (B, N_query, 3)
75 | """
76 | dist, knn_idx, _ = pytorch3d.ops.knn_points(
77 | p1=p_query,
78 | p2=p_context,
79 | K=self.max_points,
80 | return_nn=True,
81 | return_sorted=True
82 | ) # (B, N_query, K), (B, N_query, K), (B, N_query, K, 3); K must be large enough to cover the ball
83 | # dist: (B, N_query, K) squared distances from each query point to its K nearest neighbours
84 | B , N_query , _ = p_query.shape
85 | _ , N , _ = p_context.shape
86 | #radius_first_point = knn_idx[ : , : , 0].view(B , N_query , 1).repeat([1, 1, self.num_points])
87 | knn_idx[dist > self.radius ** 2] = N + 1 # dist is squared, so compare against radius**2; N + 1 marks out-of-range neighbours
88 | radius_graph_idx = knn_idx[:, :, :self.num_points] # knn_idx is sorted by distance in ascending order
89 |
90 | # Queries with fewer than num_points neighbours inside the radius keep the sentinel; replace it
91 | # with index 0 so knn_gather stays valid, and zero their distances so the window below assigns
92 | # them zero weight.
93 | mask = (radius_graph_idx == N + 1)
94 | radius_graph_idx = torch.where(mask, torch.zeros_like(radius_graph_idx), radius_graph_idx)
95 | radius_graph_points = pytorch3d.ops.knn_gather(x=p_context, idx=radius_graph_idx) # (B, N_query, num_points, 3)
96 | dist = dist[:, :, :self.num_points]
97 | dist = torch.where(mask, torch.zeros_like(dist), dist)
96 | # Relative coordinates and their embeddings
97 | p_rel = radius_graph_points - p_query.unsqueeze(-2) # (B, N_query, num_points, 3)
98 | h_rel = self.raise_xyz(p_rel) # (B, N_query, num_points , H_rel)
99 |
100 | # Grouped features of neighbor points
101 | h_group = pytorch3d.ops.knn_gather(
102 | x=h_context,
103 | idx=radius_graph_idx,
104 | ) # (B, N_query, num_points, H_ctx)
105 | # Combine
106 | h_combined = torch.cat([h_rel, h_group], dim=-1) # (B, N_query, num_points, H_rel+H_ctx)
107 |
108 | # Featurize
109 | h_combined = h_combined.permute(0, 3, 1, 2).contiguous() # (B, H_rel+H_ctx, N_query, num_points)
110 | y = self.pointwise_convmlp(h_combined) # (B, H_out, N_query, num_points)
111 | y = y.permute(0, 2, 3, 1).contiguous() # (B, N_query, num_points, H_out)
112 | dist = dist.sqrt().unsqueeze(-1) # knn_points returns squared distances; the window needs Euclidean distance. (B, N_query, num_points, 1)
113 | c = 0.5 * (torch.cos(dist * np.pi / self.radius) + 1.0) # raised-cosine window: 1 at d = 0, 0 at d = radius
114 | c = c * (dist <= self.radius) * (dist > 0.0) # (B, N_query, num_points, 1)
115 | y = torch.mul(y, c).permute(0, 3, 1, 2).contiguous()
116 | y = y.sum(-1) # (B, H_out, N_query)
117 | #y, _ = torch.max(y, dim=3) # (B, H_out, N_query)
118 | # Vectorize
119 | y = self.global_convmlp(y) # (B, 3, N_query)
120 | y = y.permute(0, 2, 1).contiguous() # (B, N_query, 3)
121 | return y
122 |
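
# Sketch (an illustrative addition, not part of the original file): the neighbor
# weight in forward() is a raised-cosine window c(d) = 0.5 * (cos(pi * d / radius) + 1),
# decaying smoothly from 1 at d = 0 to 0 at d = radius, so neighbors near the ball
# boundary contribute little to the aggregated feature.
def _cosine_window_demo(radius=0.1):
    d = torch.linspace(0, radius, 5)
    c = 0.5 * (torch.cos(d * np.pi / radius) + 1.0)
    print(c)  # tensor([1.0000, 0.8536, 0.5000, 0.1464, 0.0000])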
--------------------------------------------------------------------------------
/models/encoders/knn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module, Linear, ModuleList
3 | import pytorch3d.ops
4 | from ..common import *
5 |
6 | class DenseEdgeConv(Module):
7 |
8 | def __init__(self, in_channels, num_fc_layers, growth_rate, knn=16, aggr='max', activation='relu', relative_feat_only=False):
9 | super().__init__()
10 | self.in_channels = in_channels
11 | self.knn = knn
12 | assert num_fc_layers > 2
13 | self.num_fc_layers = num_fc_layers
14 | self.growth_rate = growth_rate
15 | self.relative_feat_only = relative_feat_only
16 |
17 | # Densely Connected Layers
18 | if relative_feat_only:
19 | self.layer_first = FCLayer(in_channels, growth_rate, bias=True, activation=activation)
20 | else:
21 | self.layer_first = FCLayer(3*in_channels, growth_rate, bias=True, activation=activation)
22 | self.layer_last = FCLayer(in_channels + (num_fc_layers - 1) * growth_rate, growth_rate, bias=True, activation=None)
23 | self.layers = ModuleList()
24 | for i in range(1, num_fc_layers-1):
25 | self.layers.append(FCLayer(in_channels + i * growth_rate, growth_rate, bias=True, activation=activation))
26 |
27 | self.aggr = Aggregator(aggr)
28 |
29 | @property
30 | def out_channels(self):
31 | return self.in_channels + self.num_fc_layers * self.growth_rate
32 |
33 | def get_edge_feature(self, x, knn_idx):
34 | """
35 | :param x: (B, N, d)
36 | :param knn_idx: (B, N, K)
37 | :return (B, N, K, d) if relative_feat_only, else (B, N, K, 3*d)
38 | """
39 | knn_feat = knn_group(x, knn_idx) # B * N * K * d
40 | x_tiled = x.unsqueeze(-2).expand_as(knn_feat)
41 | if self.relative_feat_only:
42 | edge_feat = knn_feat - x_tiled
43 | else:
44 | edge_feat = torch.cat([x_tiled, knn_feat, knn_feat - x_tiled], dim=3)
45 | return edge_feat
46 |
47 | def forward(self, x, pos):
48 | """
49 | :param x: (B, N, d)
50 | :param pos: (B, N, 3), positions used to build the kNN graph
51 | :return (B, N, d+L*c)
51 | """
52 | knn_idx = get_knn_idx(pos, pos, k=self.knn, offset=1)
53 |
54 | # First Layer
55 | edge_feat = self.get_edge_feature(x, knn_idx)
56 | y = torch.cat([
57 | self.layer_first(edge_feat), # (B, N, K, c)
58 | x.unsqueeze(-2).repeat(1, 1, self.knn, 1) # (B, N, K, d)
59 | ], dim=-1) # (B, N, K, d+c)
60 |
61 | # Intermediate Layers
62 | for layer in self.layers:
63 | y = torch.cat([
64 | layer(y), # (B, N, K, c)
65 | y, # (B, N, K, c+d)
66 | ], dim=-1) # (B, N, K, d+c+...)
67 |
68 | # Last Layer
69 | y = torch.cat([
70 | self.layer_last(y), # (B, N, K, c)
71 | y # (B, N, K, d+(L-1)*c)
72 | ], dim=-1) # (B, N, K, d+L*c)
73 |
74 | # Pooling
75 | y = self.aggr(y, dim=-2)
76 |
77 | return y
78 |
79 |
80 | class KNearestGCNNPointEncoder(Module):
81 |
82 | def __init__(self,
83 | in_channels=3,
84 | dynamic_graph=False,
85 | conv_channels=24,
86 | num_convs=4,
87 | conv_num_fc_layers=3,
88 | conv_growth_rate=12,
89 | conv_knn=16,
90 | conv_aggr='max',
91 | activation='relu'
92 | ):
93 | super().__init__()
94 | self.in_channels = in_channels
95 | self.dynamic_graph = dynamic_graph
96 | self.num_convs = num_convs
97 | # Edge Convolution Units
98 | self.transforms = ModuleList()
99 | self.convs = ModuleList()
100 | for i in range(num_convs):
101 | if i == 0:
102 | trans = FCLayer(in_channels, conv_channels, bias=True, activation=None)
103 | conv = DenseEdgeConv(
104 | conv_channels,
105 | num_fc_layers=conv_num_fc_layers,
106 | growth_rate=conv_growth_rate,
107 | knn=conv_knn,
108 | aggr=conv_aggr,
109 | activation=activation,
110 | relative_feat_only=True
111 | )
112 | else:
113 | trans = FCLayer(in_channels, conv_channels, bias=True, activation=activation)
114 | conv = DenseEdgeConv(
115 | conv_channels,
116 | num_fc_layers=conv_num_fc_layers,
117 | growth_rate=conv_growth_rate,
118 | knn=conv_knn,
119 | aggr=conv_aggr,
120 | activation=activation,
121 | relative_feat_only=False
122 | )
123 | self.transforms.append(trans)
124 | self.convs.append(conv)
125 | in_channels = conv.out_channels
126 |
127 | @property
128 | def out_channels(self):
129 | return self.convs[-1].out_channels
130 |
131 | def dynamic_graph_forward(self, x):
132 | for i in range(self.num_convs):
133 | x = self.transforms[i](x)
134 | x = self.convs[i](x, x)
135 | return x
136 |
137 | def static_graph_forward(self, pos):
138 | x = pos
139 | for i in range(self.num_convs):
140 | x = self.transforms[i](x)
141 | x = self.convs[i](x, pos)
142 | return x
143 |
144 | def forward(self, x):
145 | if self.dynamic_graph:
146 | return self.dynamic_graph_forward(x)
147 | else:
148 | return self.static_graph_forward(x)
149 |
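
# Shape sketch (an illustrative addition, not part of the original file): the encoder
# maps (B, N, 3) coordinates to (B, N, out_channels) per-point features; with the
# defaults each DenseEdgeConv emits conv_channels + conv_num_fc_layers * conv_growth_rate
# = 24 + 3 * 12 = 60 channels.
def _encoder_shape_demo():
    enc = KNearestGCNNPointEncoder()
    x = torch.randn(2, 256, 3)
    h = enc(x)
    assert h.shape == (2, 256, enc.out_channels)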
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: deeprs
2 | channels:
3 | - pytorch3d
4 | - pytorch
5 | - bottler
6 | - conda-forge
7 | - iopath
8 | - fvcore
9 | - defaults
10 | dependencies:
11 | - _libgcc_mutex=0.1=main
12 | - _openmp_mutex=4.5=1_gnu
13 | - absl-py=0.15.0=pyhd3eb1b0_0
14 | - aiohttp=3.8.1=py38h7f8727e_1
15 | - aiosignal=1.2.0=pyhd3eb1b0_0
16 | - async-timeout=4.0.1=pyhd3eb1b0_0
17 | - attrs=21.4.0=pyhd3eb1b0_0
18 | - blas=1.0=mkl
19 | - blessings=1.7=py38h06a4308_1002
20 | - blinker=1.4=py38h06a4308_0
21 | - bottleneck=1.3.4=py38hce1f21e_0
22 | - brotli=1.0.9=he6710b0_2
23 | - brotlipy=0.7.0=py38h27cfd23_1003
24 | - c-ares=1.18.1=h7f8727e_0
25 | - ca-certificates=2022.4.26=h06a4308_0
26 | - cachetools=4.2.2=pyhd3eb1b0_0
27 | - certifi=2022.5.18.1=py38h06a4308_0
28 | - cffi=1.15.0=py38hd667e15_1
29 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
30 | - click=8.0.4=py38h06a4308_0
31 | - colorama=0.4.4=pyhd3eb1b0_0
32 | - cryptography=3.4.8=py38hd23ed53_0
33 | - cudatoolkit=10.2.89=hfd86e86_1
34 | - cycler=0.11.0=pyhd3eb1b0_0
35 | - dataclasses=0.8=pyh6d0b6a4_7
36 | - dbus=1.13.18=hb2f20db_0
37 | - expat=2.4.4=h295c915_0
38 | - fontconfig=2.13.1=h6c09931_0
39 | - fonttools=4.25.0=pyhd3eb1b0_0
40 | - freetype=2.11.0=h70c0345_0
41 | - frozenlist=1.2.0=py38h7f8727e_0
42 | - fvcore=0.1.5.post20210915=py38
43 | - giflib=5.2.1=h7b6447c_0
44 | - glib=2.69.1=h4ff587b_1
45 | - google-auth=2.6.0=pyhd3eb1b0_0
46 | - google-auth-oauthlib=0.4.1=py_2
47 | - gpustat=0.6.0=pyhd3eb1b0_1
48 | - grpcio=1.42.0=py38hce63b2e_0
49 | - gst-plugins-base=1.14.0=h8213a91_2
50 | - gstreamer=1.14.0=h28cd5cc_2
51 | - h5py=3.6.0=py38ha0f2276_0
52 | - hdf5=1.10.6=hb1b8bf9_0
53 | - icu=58.2=he6710b0_3
54 | - idna=3.3=pyhd3eb1b0_0
55 | - importlib-metadata=4.11.3=py38h06a4308_0
56 | - intel-openmp=2021.4.0=h06a4308_3561
57 | - iopath=0.1.9=py38
58 | - joblib=1.1.0=pyhd3eb1b0_0
59 | - jpeg=9b=h024ee3a_2
60 | - kiwisolver=1.3.2=py38h295c915_0
61 | - lcms2=2.12=h3be6417_0
62 | - ld_impl_linux-64=2.35.1=h7274673_9
63 | - libffi=3.3=he6710b0_2
64 | - libgcc-ng=9.3.0=h5101ec6_17
65 | - libgfortran-ng=7.5.0=ha8ba4b0_17
66 | - libgfortran4=7.5.0=ha8ba4b0_17
67 | - libgomp=9.3.0=h5101ec6_17
68 | - libpng=1.6.37=hbc83047_0
69 | - libprotobuf=3.19.1=h4ff587b_0
70 | - libstdcxx-ng=9.3.0=hd4cf53a_17
71 | - libtiff=4.2.0=h85742a9_0
72 | - libuuid=1.0.3=h7f8727e_2
73 | - libuv=1.40.0=h7b6447c_0
74 | - libwebp=1.2.0=h89dd481_0
75 | - libwebp-base=1.2.0=h27cfd23_0
76 | - libxcb=1.14=h7b6447c_0
77 | - libxml2=2.9.12=h03d6c58_0
78 | - llvm-openmp=8.0.1=hc9558a2_0
79 | - lz4-c=1.9.3=h295c915_1
80 | - markdown=3.3.4=py38h06a4308_0
81 | - matplotlib=3.5.1=py38h06a4308_1
82 | - matplotlib-base=3.5.1=py38ha18d171_1
83 | - mkl=2021.4.0=h06a4308_640
84 | - mkl-service=2.4.0=py38h7f8727e_0
85 | - mkl_fft=1.3.1=py38hd3c417c_0
86 | - mkl_random=1.2.2=py38h51133e4_0
87 | - multidict=5.2.0=py38h7f8727e_2
88 | - munkres=1.1.4=py_0
89 | - ncurses=6.3=h7f8727e_2
90 | - ninja=1.10.2=py38hd09550d_3
91 | - numexpr=2.8.1=py38h6abb31d_0
92 | - numpy=1.21.2=py38h20f2e39_0
93 | - numpy-base=1.21.2=py38h79a1101_0
94 | - nvidia-ml=7.352.0=pyhd3eb1b0_0
95 | - nvidiacub=1.10.0=0
96 | - oauthlib=3.2.0=pyhd3eb1b0_0
97 | - openmp=8.0.1=0
98 | - openssl=1.1.1o=h7f8727e_0
99 | - packaging=21.3=pyhd3eb1b0_0
100 | - pandas=1.4.1=py38h295c915_1
101 | - pcre=8.45=h295c915_0
102 | - pillow=9.0.1=py38h22f2fdc_0
103 | - pip=21.2.4=py38h06a4308_0
104 | - point_cloud_utils=0.18.0=py38hc10631b_1
105 | - portalocker=2.3.2=py38h578d9bd_1
106 | - protobuf=3.19.1=py38h295c915_0
107 | - psutil=5.8.0=py38h27cfd23_1
108 | - pyasn1=0.4.8=pyhd3eb1b0_0
109 | - pyasn1-modules=0.2.8=py_0
110 | - pycparser=2.21=pyhd3eb1b0_0
111 | - pyjwt=2.1.0=py38h06a4308_0
112 | - pyopenssl=21.0.0=pyhd3eb1b0_1
113 | - pyparsing=3.0.4=pyhd3eb1b0_0
114 | - pyqt=5.9.2=py38h05f1152_4
115 | - pysocks=1.7.1=py38h06a4308_0
116 | - python=3.8.13=h12debd9_0
117 | - python-dateutil=2.8.2=pyhd3eb1b0_0
118 | - python_abi=3.8=2_cp38
119 | - pytorch=1.7.1=py3.8_cuda10.2.89_cudnn7.6.5_0
120 | - pytorch3d=0.5.0=py38_cu102_pyt171
121 | - pytz=2021.3=pyhd3eb1b0_0
122 | - pyyaml=6.0=py38h7f8727e_1
123 | - qt=5.9.7=h5867ecd_1
124 | - readline=8.1.2=h7f8727e_1
125 | - requests=2.27.1=pyhd3eb1b0_0
126 | - requests-oauthlib=1.3.0=py_0
127 | - rsa=4.7.2=pyhd3eb1b0_1
128 | - scikit-learn=1.0.2=py38h51133e4_1
129 | - scipy=1.7.3=py38hc147768_0
130 | - setuptools=58.0.4=py38h06a4308_0
131 | - sip=4.19.13=py38h295c915_0
132 | - six=1.16.0=pyhd3eb1b0_1
133 | - sqlite=3.38.2=hc218d9a_0
134 | - tabulate=0.8.9=py38h06a4308_0
135 | - tensorboard=2.6.0=py_1
136 | - tensorboard-data-server=0.6.0=py38hca6d32c_0
137 | - tensorboard-plugin-wit=1.6.0=py_0
138 | - termcolor=1.1.0=py_2
139 | - threadpoolctl=2.2.0=pyh0d69192_0
140 | - tk=8.6.11=h1ccaba5_0
141 | - torchvision=0.8.2=py38_cu102
142 | - tornado=6.1=py38h27cfd23_0
143 | - tqdm=4.63.0=pyhd3eb1b0_0
144 | - typing-extensions=4.1.1=hd3eb1b0_0
145 | - typing_extensions=4.1.1=pyh06a4308_0
146 | - urllib3=1.26.8=pyhd3eb1b0_0
147 | - werkzeug=2.0.3=pyhd3eb1b0_0
148 | - wheel=0.37.1=pyhd3eb1b0_0
149 | - xz=5.2.5=h7b6447c_0
150 | - yacs=0.1.6=pyhd3eb1b0_1
151 | - yaml=0.2.5=h7b6447c_0
152 | - yarl=1.6.3=py38h27cfd23_0
153 | - zipp=3.7.0=pyhd3eb1b0_0
154 | - zlib=1.2.11=h7f8727e_4
155 | - zstd=1.4.9=haebb681_0
156 | - pip:
157 | - easydict==1.9
158 | - jinja2==3.1.1
159 | - markupsafe==2.1.1
160 | - torch-cluster==1.5.9
161 | - torch-geometric==2.0.4
162 | - torch-scatter==2.0.5
163 | - torch-sparse==0.6.8
164 | - torch-spline-conv==1.2.1
165 | prefix: /home/chenhaolan/.conda/envs/deeprs
166 |
--------------------------------------------------------------------------------
/models/resampler.py:
--------------------------------------------------------------------------------
1 | from pytorch3d.ops.knn import knn_points
2 | import torch
3 | from torch.nn import Module
4 | import pytorch3d.ops
5 | from .common import *
6 |
7 | from .encoders import get_encoder
8 | from .vecfields import get_vecfield
9 |
10 | class PointSetResampler(Module):
11 |
12 | def __init__(self, config):
13 | super().__init__()
14 | self.cfg = config
15 | #print(config)
16 | self.encoder = get_encoder(config.encoder)
17 | self.vecfield = get_vecfield(config.vecfield, self.encoder.out_channels)
18 |
19 | def forward(self, p_query, p_ctx):
20 | """
21 | Args:
22 | p_query: Query point set, (B, N_query, 3).
23 | p_ctx: Context point set, (B, N_ctx, 3).
24 | Returns:
25 | (B, N_query, 3)
26 | """
27 | h_ctx = self.encoder(p_ctx) # (B, N_ctx, H), features of each point in `p_ctx`.
28 | vec = self.vecfield(p_query, p_ctx, h_ctx) # (B, N_query, 3)
29 | return vec
30 |
31 | def get_loss_vec(self, p_query, p_ctx, vec_gt):
32 | """
33 | Computes loss according to ground truth vectors.
34 | Args:
35 | vec_gt: Ground truth vectors, (B, N_query, 3).
36 | """
37 | vec_pred = self(p_query, p_ctx) # (B, N_query, 3)
38 | loss = ((vec_pred - vec_gt) ** 2.0).sum(dim=-1).mean()
39 | return loss
40 |
41 | def get_loss_pc(self, p_query, p_ctx, p_gt, avg_knn):
42 | """
43 | Computes loss according to ground truth point clouds.
44 | Args:
45 | p_gt: Ground truth point clouds, (B, N_gt, 3).
46 | avg_knn: For each point in `p_query`, use how many nearest points in `p_gt`
47 | to estimate the ground truth vector.
48 | """
49 | _, _, gt_nbs = pytorch3d.ops.knn_points(
50 | p_query,
51 | p_gt,
52 | K=avg_knn,
53 | return_nn=True,
54 | ) # (B, N_query, K, 3)
55 |
56 | vec_gt = (gt_nbs - p_query.unsqueeze(-2)).mean(-2) # (B, N_query, 3)
57 | return self.get_loss_vec(p_query, p_ctx, vec_gt)
58 | def get_cd_loss(self, ref, gen):
59 | P = batch_pairwise_dist(ref, gen)
60 | mins, _ = torch.min(P, 1)
61 | loss_1 = torch.mean(mins)
62 | mins, _ = torch.min(P, 2)
63 | loss_2 = torch.mean(mins)
64 | return loss_1 + loss_2
65 | @torch.no_grad()
66 | def resample(self, p_init, p_ctx, step_size=0.2, step_decay=0.95, num_steps=40):
67 | traj = [p_init.clone().cpu()]
68 | h_ctx = self.encoder(p_ctx) # (B, N_ctx, H)
69 |
70 | p_current = p_init
71 | for step in range(num_steps):
72 | vec_pred = self.vecfield(p_current, p_ctx, h_ctx) # (B, N_query, 3)
73 | # vec_pred: f(x_i) -> v_i + dL(x)/dx
74 | s = step_size * (step_decay ** step)
75 | p_next = p_current + s * vec_pred
76 | if step % 5 == 4: # record a trajectory snapshot every 5 steps
77 | traj.append(p_next.clone().cpu())
78 | p_current = p_next
79 | return p_current, traj
80 | def get_repulsion_loss(self, p_cur):
81 | h = 7e-3
82 | dist2, _, _ = pytorch3d.ops.knn_points(
83 | p1=p_cur,
84 | p2=p_cur,
85 | K=7,
86 | return_nn=True
87 | ) # (B, N_query,K ) , dist2 is a squared number
88 | dist2 = dist2[: , :, 1:] # dist2[: ,:, 0] = 0
89 | dist = torch.sqrt(dist2)
90 | #print(dist.mean())
91 | weight = torch.exp(-dist2 / h ** 2)
92 | loss = torch.mean((- dist) * weight)
93 | #print('Loss %.6f' % (
94 | # 0.05*loss.item(),
95 | #))
96 | return 0.05*loss + 1e-4
97 | '''
98 | def repulsion(self, p_cur) :
99 | _, _, knn_points = pytorch3d.ops.knn_points(
100 | p1=p_cur,
101 | p2=p_cur,
102 | K=6,
103 | return_nn=True
104 | ) # (B, N_query,K , 3)
105 | p_rel = knn_points - p_cur.unsqueeze(-2) # (B, N_query, K, 3)
106 | p_rel = p_rel[:,:,1:,:]
107 | p_normed = torch.norm(p_rel, p=2, dim = -1,keepdim=True)
108 | p_normed = p_rel / p_normed
109 | dist = torch.sqrt((p_rel ** 2 ).sum(dim = -1)) # (B,N_query,K)
110 | #p_normed = (-dist * torch.exp ( -1 * dist ** 2 / 6e-3)).unsqueeze(-1) * p_normed
111 | p_normed = -(5e-4/dist ** 2).unsqueeze(-1) * p_normed
112 | target = torch.sum(p_normed,dim = -2 , keepdim=False) # target of the repulsion power
113 | return target
114 | '''
115 | def glr (self, p_cur):
116 | _, _, knn_points = pytorch3d.ops.knn_points(
117 | p1=p_cur,
118 | p2=p_cur,
119 | K=6,
120 | return_nn=True
121 | ) # (B, N_query,K , 3)
122 | p_rel = p_cur.unsqueeze(-2) - knn_points # (B, N_query, K, 3)
123 | p_rel = p_rel[:,:,1:,:]
124 | glr_grad = 2 * p_rel
125 | dist = (p_rel ** 2 ).sum(dim = -1) # (B,N_query,K)
126 | target = torch.exp ( -1 * dist ** 2 / 1e-9).unsqueeze(-1) * glr_grad
127 | target = torch.sum(target,dim = -2 , keepdim=False)
128 | return target
129 | def gtv (self, p_cur) :
130 | _, _, knn_points = pytorch3d.ops.knn_points(
131 | p1=p_cur,
132 | p2=p_cur,
133 | K=6,
134 | return_nn=True
135 | ) # (B, N_query,K , 3)
136 | p_rel = (p_cur.unsqueeze(-2) - knn_points)
137 | p_rel = p_rel[:,:,1:,:]
138 | gtv_grad = p_rel / (torch.abs(p_rel)+1e-7)
139 | dist = (p_rel ** 2 ).sum(dim = -1) # (B,N_query,K)
140 | target = torch.exp ( -1 * dist ** 2 / 5e-9).unsqueeze(-1) * gtv_grad
141 | target = torch.sum(target,dim = -2 , keepdim=False)
142 | return target
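
# Usage sketch (an illustrative addition, not part of the original file): denoising
# as iterative resampling. Starting from the noisy cloud itself, points move along
# the predicted vector field with a decaying step size; `model` is assumed to be a
# trained PointSetResampler.
def _resample_demo(model, p_noisy):
    # p_noisy: (1, N, 3); traj holds the start plus a snapshot every 5 steps.
    p_out, traj = model.resample(p_init=p_noisy, p_ctx=p_noisy,
                                 step_size=0.2, step_decay=0.95, num_steps=40)
    return p_out  # (1, N, 3)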
--------------------------------------------------------------------------------
/scripts/train_denoise.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import shutil
4 | import yaml
5 | from easydict import EasyDict
6 | from tqdm.auto import tqdm
7 |
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torch.utils.tensorboard
11 | from models.resampler import PointSetResampler
12 | from models.common import chamfer_distance_unit_sphere
13 | from utils.datasets import PointCloudDataset, PairedPatchDataset
14 | from utils.transforms import *
15 | from utils.misc import *
16 | from utils.denoise import patch_based_denoise
17 |
18 |
19 | if __name__ == '__main__':
20 | best = 10000
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--config', type=str )
23 | parser.add_argument('--device', type=str, default='cuda')
24 | parser.add_argument('--log_root', type=str, default='./logs')
25 | args = parser.parse_args()
26 |
27 | # Load configs
28 | with open(args.config, 'r') as f:
29 | config = EasyDict(yaml.safe_load(f))
30 | config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
31 | seed_all(config.train.seed)
32 |
33 | # Logging
34 | log_dir = get_new_log_dir(args.log_root, prefix=config_name +'_'+str(config.dataset.val_noise)+'_')
35 | print(log_dir)
36 | ckpt_dir = os.path.join(log_dir, 'checkpoints')
37 | os.makedirs(ckpt_dir, exist_ok=True)
38 | logger = get_logger('train', log_dir)
39 | writer = torch.utils.tensorboard.SummaryWriter(log_dir)
40 | logger.info(args)
41 | logger.info(config)
42 | shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config)))
43 |
44 | # Datasets and loaders
45 | logger.info('Loading datasets')
46 | train_dset = PairedPatchDataset(
47 | datasets=[PointCloudDataset(
48 | root=config.dataset.dataset_root,
49 | dataset=config.dataset.dataset,
50 | split='train',
51 | resolution=resl,
52 | transform=standard_train_transforms(noise_std_max=config.dataset.noise_max, noise_std_min=config.dataset.noise_min, rotate=config.dataset.aug_rotate)
53 | ) for resl in config.dataset.resolutions
54 | ],
55 | patch_size=config.dataset.patch_size,
56 | patch_ratio=1.2,
57 | on_the_fly=True
58 | )
59 |
60 | val_dset = PointCloudDataset(
61 | root=config.dataset.dataset_root,
62 | dataset=config.dataset.dataset,
63 | split='test',
64 | resolution=config.dataset.resolutions[2],
65 | transform=standard_train_transforms(noise_std_max=config.dataset.val_noise, noise_std_min=config.dataset.val_noise, rotate=False, scale_d=0),
66 | )
67 | train_iter = get_data_iterator(DataLoader(train_dset, batch_size=config.train.train_batch_size, num_workers=config.train.num_workers, shuffle=True))
68 |
69 | # Model
70 | logger.info('Building model...')
71 | model = PointSetResampler(config.model).to(args.device)
72 | logger.info(repr(model))
73 |
74 | # Optimizer and Scheduler
75 | optimizer = torch.optim.Adam(
76 | model.parameters(),
77 | lr=config.train.lr,
78 | weight_decay=config.train.weight_decay,
79 | )
80 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
81 | optimizer,
82 | mode="min",
83 | factor=config.scheduler.factor,
84 | patience=config.scheduler.patience,
85 | threshold=config.scheduler.threshold,
86 | )
87 | def train(it):
90 | batch = next(train_iter)
91 | p_noisy = batch['pcl_noisy'].to(args.device)
92 | p_clean = batch['pcl_clean'].to(args.device)
93 |
94 | # Reset grad and model state
95 | optimizer.zero_grad()
96 | model.train()
97 |
98 | # Forward
99 | loss = model.get_loss_pc(
100 | p_query=p_noisy,
101 | p_ctx=p_noisy,
102 | p_gt=p_clean,
103 | avg_knn=config.train.vec_avg_knn,
104 | )
105 |
106 | # Backward
107 | loss.backward()
108 | optimizer.step()
109 |
110 | # Logging
111 | logger.info('[Train] Iter %04d | Loss %.6f' % (
112 | it, loss.item(),
113 | ))
114 | writer.add_scalar('train/loss', loss, it)
115 | writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], it)
116 | writer.flush()
117 |
118 | def validate(it):
119 | global best
120 | all_clean = []
121 | all_denoised = []
122 | for i, data in enumerate(tqdm(val_dset, desc='Validate')):
123 | pcl_noisy = data['pcl_noisy'].to(args.device)
124 | pcl_clean = data['pcl_clean'].to(args.device)
125 | pcl_denoised = patch_based_denoise(model, pcl_noisy)
126 | all_clean.append(pcl_clean.unsqueeze(0))
127 | all_denoised.append(pcl_denoised.unsqueeze(0))
128 | all_clean = torch.cat(all_clean, dim=0)
129 | all_denoised = torch.cat(all_denoised, dim=0)
130 |
131 | avg_chamfer = chamfer_distance_unit_sphere(all_denoised, all_clean, batch_reduction='mean')[0].item()
132 | logger.info('[Val] Iter %04d | CD %.6f ' % (it, avg_chamfer))
133 | writer.add_scalar('val/chamfer', avg_chamfer, it)
134 | writer.add_mesh('val/pcl', all_denoised[:2], global_step=it)
135 | if avg_chamfer < best :
136 | best = avg_chamfer
137 | torch.save(model.state_dict(),log_dir+'/model.pth')
138 | writer.flush()
139 | scheduler.step(avg_chamfer)
140 |
141 | # Main loop
142 | logger.info('Start training...')
143 | try:
144 | for it in range(1, config.train.max_iters+1):
145 | train(it)
146 | if it % config.train.val_freq == 0 or it == config.train.max_iters:
147 | validate(it)
148 |
149 | except KeyboardInterrupt:
150 | logger.info('Terminating...')
151 |
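
# Usage sketch (an illustrative addition, not part of the original script), assuming
# the repository root is on PYTHONPATH so that `models/` and `utils/` resolve:
#
#   python scripts/train_denoise.py --config configs/denoise.yml --device cuda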
--------------------------------------------------------------------------------
/utils/datasets/patch.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.utils.data import Dataset
4 | import pytorch3d.ops
5 | from tqdm.auto import tqdm
6 |
7 |
8 | def make_patches_for_pcl_pair(pcl_A, pcl_B, patch_size, num_patches, ratio):
9 | """
10 | Args:
11 | pcl_A: The first point cloud, (N, 3).
12 | pcl_B: The second point cloud, (rN, 3).
13 | patch_size: Patch size M.
14 | num_patches: Number of patches P.
15 | ratio: Ratio r.
16 | Returns:
17 | (P, M, 3), (P, rM, 3)
18 | """
19 | N = pcl_A.size(0)
20 | seed_idx = torch.randperm(N)[:num_patches] # (P, )
21 | seed_pnts = pcl_A[seed_idx].unsqueeze(0) # (1, P, 3)
22 | _, _, pat_A = pytorch3d.ops.knn_points(seed_pnts, pcl_A.unsqueeze(0), K=patch_size, return_nn=True)
23 | pat_A = pat_A[0] # (P, M, 3)
24 | _, _, pat_B = pytorch3d.ops.knn_points(seed_pnts, pcl_B.unsqueeze(0), K=int(ratio*patch_size), return_nn=True)
25 | pat_B = pat_B[0]
26 | return pat_A, pat_B
27 |
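# Shape sketch (an illustrative addition, not part of the original file): both patches
# are cut around the same seed points, so with ratio r the second patch carries r times
# as many points as the first.
def _patch_shapes_demo():
    pcl_noisy, pcl_clean = torch.randn(5000, 3), torch.randn(5000, 3)
    pat_n, pat_c = make_patches_for_pcl_pair(pcl_noisy, pcl_clean,
                                             patch_size=1000, num_patches=4, ratio=1.2)
    assert pat_n.shape == (4, 1000, 3) and pat_c.shape == (4, 1200, 3)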
28 |
29 | class PairedPatchDataset(Dataset):
30 |
31 | def __init__(self, datasets, patch_ratio, on_the_fly=True, patch_size=1000, num_patches=1000, transform=None):
32 | super().__init__()
33 | self.datasets = datasets
34 | self.len_datasets = sum([len(dset) for dset in datasets])
35 | self.patch_ratio = patch_ratio
36 | self.patch_size = patch_size
37 | self.num_patches = num_patches
38 | self.on_the_fly = on_the_fly
39 | self.transform = transform
40 | self.patches = []
41 | # Initialize
42 | if not on_the_fly:
43 | self.make_patches()
44 |
45 | def make_patches(self):
46 | for dataset in tqdm(self.datasets, desc='MakePatch'):
47 | for data in tqdm(dataset):
48 | pat_noisy, pat_clean = make_patches_for_pcl_pair(
49 | data['pcl_noisy'],
50 | data['pcl_clean'],
51 | patch_size=self.patch_size,
52 | num_patches=self.num_patches,
53 | ratio=self.patch_ratio
54 | ) # (P, M, 3), (P, rM, 3)
55 | for i in range(pat_noisy.size(0)):
56 | self.patches.append((pat_noisy[i], pat_clean[i], ))
57 |
58 | def __len__(self):
59 | if not self.on_the_fly:
60 | return len(self.patches)
61 | else:
62 | return self.len_datasets * self.num_patches
63 |
64 |
65 | def __getitem__(self, idx):
66 | if self.on_the_fly:
67 | pcl_dset = random.choice(self.datasets)
68 | pcl_data = pcl_dset[idx % len(pcl_dset)]
69 | pat_noisy, pat_clean = make_patches_for_pcl_pair(
70 | pcl_data['pcl_noisy'],
71 | pcl_data['pcl_clean'],
72 | patch_size=self.patch_size,
73 | num_patches=1,
74 | ratio=self.patch_ratio
75 | )
76 | data = {
77 | 'pcl_noisy': pat_noisy[0],
78 | 'pcl_clean': pat_clean[0]
79 | }
80 | else:
81 | data = {
82 | 'pcl_noisy': self.patches[idx][0].clone(),
83 | 'pcl_clean': self.patches[idx][1].clone(),
84 | }
85 | if self.transform is not None:
86 | data = self.transform(data)
87 | return data
88 |
89 | class PairedUpsDataset(Dataset):
90 |
91 | def __init__(self, datasets, patch_ratio, on_the_fly=True, patch_size=1000, num_patches=200, transform=None):
92 | super().__init__()
93 | self.datasets = datasets
94 | self.len_datasets = sum([len(dset) for dset in datasets])
95 | self.patch_ratio = patch_ratio
96 | self.patch_size = patch_size
97 | self.num_patches = num_patches
98 | self.on_the_fly = on_the_fly
99 | self.transform = transform
100 | self.patches = []
101 |
102 | def __len__(self):
103 | if not self.on_the_fly:
104 | return len(self.patches)
105 | else:
106 | return self.len_datasets * self.num_patches
107 |
108 | def make_patches_for_ups_pair(self, pcl_A, pcl_B, pcl_C, patch_size, num_patches, ratio):
109 | """
110 | Args:
111 | pcl_A: The low-resolution point cloud, (N, 3).
112 | pcl_B: The upsampled (noisy) point cloud, (rN, 3).
113 | pcl_C: The ground-truth point cloud, (rN, 3).
114 | patch_size: Patch size M.
115 | num_patches: Number of patches P.
116 | ratio: Ratio r.
117 | Returns:
118 | (P, M, 3), (P, rM, 3), (P, rM, 3)
119 | """
119 | N = pcl_A.size(0)
120 | seed_idx = torch.randperm(N)[:num_patches] # (P, )
121 | seed_pnts = pcl_A[seed_idx].unsqueeze(0) # (1, P, 3)
122 | _, _, pat_A = pytorch3d.ops.knn_points(seed_pnts, pcl_A.unsqueeze(0), K=patch_size, return_nn=True)
123 | pat_A = pat_A[0] # (P, M, 3)
124 | _, _, pat_B = pytorch3d.ops.knn_points(seed_pnts, pcl_B.unsqueeze(0), K=int(ratio*patch_size), return_nn=True)
125 | pat_B = pat_B[0]
126 | _, _, pat_C = pytorch3d.ops.knn_points(seed_pnts, pcl_C.unsqueeze(0), K=int(ratio*patch_size), return_nn=True)
127 | pat_C = pat_C[0]
128 | return pat_A,pat_B, pat_C
129 |
130 | def __getitem__(self, idx):
131 | pcl_dset = random.choice(self.datasets)
132 | pcl_data = pcl_dset[idx % len(pcl_dset)]
133 |
134 | pat_low, pat_noisy , pat_gt= self.make_patches_for_ups_pair(
135 | pcl_data['original'],
136 | pcl_data['ups'],
137 | pcl_data['gt'],
138 | patch_size=self.patch_size,
139 | num_patches=1,
140 | ratio=self.patch_ratio
141 | )
142 | '''
143 | pat_low, pat_noisy , pat_gt= self.make_patches_for_ups_pair(
144 | pcl_data['pcl_noisy'],
145 | pcl_data['pcl_noisy'],
146 | pcl_data['pcl_clean'],
147 | patch_size=self.patch_size,
148 | num_patches=1,
149 | ratio=self.patch_ratio
150 | )
151 | '''
152 | #pat_low = pcl_data['original']
153 | pat_low = pat_low[0]
154 | pat_noisy = pat_noisy[0]
155 | pat_gt = pat_gt[0]
156 | '''
157 | pat_low = pat_low[0]
158 | #patch?
159 | pat_noisy = pcl_data['down']
160 | pat_gt = pcl_data['gt']
161 | '''
162 | data = {
163 | 'pcl_low': pat_low,
164 | 'pcl_noisy': pat_noisy,
165 | 'pcl_gt' : pat_gt
166 | }
167 | if self.transform is not None:
168 | data = self.transform(data)
169 | return data
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import random
5 | import time
6 | import logging
7 | import logging.handlers
8 |
9 | THOUSAND = 1000
10 | MILLION = 1000000
11 |
12 |
13 | class BlackHole(object):
14 | def __setattr__(self, name, value):
15 | pass
16 | def __call__(self, *args, **kwargs):
17 | return self
18 | def __getattr__(self, name):
19 | return self
20 |
21 |
22 | class CheckpointManager(object):
21 |
22 | def __init__(self, save_dir, logger=BlackHole()):
23 | super().__init__()
24 | os.makedirs(save_dir, exist_ok=True)
25 | self.save_dir = save_dir
26 | self.ckpts = []
27 | self.logger = logger
28 |
29 | for f in os.listdir(self.save_dir):
30 | if f[:4] != 'ckpt':
31 | continue
32 | _, score, it = f.split('_')
33 | it = it.split('.')[0]
34 | self.ckpts.append({
35 | 'score': float(score),
36 | 'file': f,
37 | 'iteration': int(it) if it else 0, # files saved without a step (ckpt_<score>_.pt) have an empty iteration field
38 | })
39 |
40 | def get_worst_ckpt_idx(self):
41 | idx = -1
42 | worst = float('-inf')
43 | for i, ckpt in enumerate(self.ckpts):
44 | if ckpt['score'] >= worst:
45 | idx = i
46 | worst = ckpt['score']
47 | return idx if idx >= 0 else None
48 |
49 | def get_best_ckpt_idx(self):
50 | idx = -1
51 | best = float('inf')
52 | for i, ckpt in enumerate(self.ckpts):
53 | if ckpt['score'] <= best:
54 | idx = i
55 | best = ckpt['score']
56 | return idx if idx >= 0 else None
57 |
58 | def get_latest_ckpt_idx(self):
59 | idx = -1
60 | latest_it = -1
61 | for i, ckpt in enumerate(self.ckpts):
62 | if ckpt['iteration'] > latest_it:
63 | idx = i
64 | latest_it = ckpt['iteration']
65 | return idx if idx >= 0 else None
66 |
67 | def save(self, model, args, score, others=None, step=None):
68 |
69 | if step is None:
70 | fname = 'ckpt_%.6f_.pt' % float(score)
71 | else:
72 | fname = 'ckpt_%.6f_%d.pt' % (float(score), int(step))
73 | path = os.path.join(self.save_dir, fname)
74 |
75 | torch.save({
76 | 'args': args,
77 | 'state_dict': model.state_dict(),
78 | 'others': others
79 | }, path)
80 |
81 | self.ckpts.append({
82 | 'score': score,
83 | 'file': fname,
84 | 'iteration': int(step) if step is not None else 0, # keep get_latest_ckpt_idx usable for in-memory entries
85 | })
85 |
86 | return True
87 |
88 | def load_best(self):
89 | idx = self.get_best_ckpt_idx()
90 | if idx is None:
91 | raise IOError('No checkpoints found.')
92 | ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file']))
93 | return ckpt
94 |
95 | def load_latest(self):
96 | idx = self.get_latest_ckpt_idx()
97 | if idx is None:
98 | raise IOError('No checkpoints found.')
99 | ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file']))
100 | return ckpt
101 |
102 | def load_selected(self, file):
103 | ckpt = torch.load(os.path.join(self.save_dir, file))
104 | return ckpt
105 |
106 | def seed_all(seed):
107 | torch.manual_seed(seed)
108 | np.random.seed(seed)
109 | random.seed(seed)
110 |
111 |
112 | def get_logger(name, log_dir=None):
113 | logger = logging.getLogger(name)
114 | logger.setLevel(logging.DEBUG)
115 | formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')
116 |
117 | stream_handler = logging.StreamHandler()
118 | stream_handler.setLevel(logging.DEBUG)
119 | stream_handler.setFormatter(formatter)
120 | logger.addHandler(stream_handler)
121 |
122 | if log_dir is not None:
123 | file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
124 | file_handler.setLevel(logging.INFO)
125 | file_handler.setFormatter(formatter)
126 | logger.addHandler(file_handler)
127 |
128 | return logger
129 |
130 |
131 | def get_new_log_dir(root='./logs', postfix='', prefix=''):
132 | log_dir = os.path.join(root, prefix + time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime()) + postfix)
133 | os.makedirs(log_dir)
134 | return log_dir
135 |
136 |
137 | def int_tuple(argstr):
138 | return tuple(map(int, argstr.split(',')))
139 |
140 |
141 | def str_tuple(argstr):
142 | return tuple(argstr.split(','))
143 |
144 |
145 | def int_list(argstr):
146 | return list(map(int, argstr.split(',')))
147 |
148 |
149 | def str_list(argstr):
150 | return list(argstr.split(','))
151 |
152 |
153 | def log_hyperparams(writer, log_dir, args):
154 | from torch.utils.tensorboard.summary import hparams
155 | vars_args = {k:v if isinstance(v, str) else repr(v) for k, v in vars(args).items()}
156 | exp, ssi, sei = hparams(vars_args, {"hp_metric": -1})
157 | fw = writer._get_file_writer()
158 | fw.add_summary(exp)
159 | fw.add_summary(ssi)
160 | fw.add_summary(sei)
161 | with open(os.path.join(log_dir, 'hparams.csv'), 'w') as csvf:
162 | csvf.write('key,value\n')
163 | for k, v in vars_args.items():
164 | csvf.write('%s,%s\n' % (k, v))
165 |
166 |
167 |
168 | def get_data_iterator(iterable):
169 | """Allows training with DataLoaders in a single infinite loop:
170 | for i, data in enumerate(get_data_iterator(train_loader)):
171 | """
172 | iterator = iterable.__iter__()
173 | while True:
174 | try:
175 | yield iterator.__next__()
176 | except StopIteration:
177 | iterator = iterable.__iter__()
178 |
179 |
180 | def parse_experiment_name(name):
181 | if 'blensor' in name:
182 | if 'Ours' in name:
183 | dataset, method, tag, blensor_, noise = name.split('_')[:5]
184 | else:
185 | dataset, method, blensor_, noise = name.split('_')[:4]
186 | return {
187 | 'dataset': dataset,
188 | 'method': method,
189 | 'resolution': 'blensor',
190 | 'noise': noise,
191 | }
192 |
193 | if 'real' in name:
194 | if 'Ours' in name:
195 | dataset, method, tag, blensor_, noise = name.split('_')[:5]
196 | else:
197 | dataset, method, blensor_, noise = name.split('_')[:4]
198 | return {
199 | 'dataset': dataset,
200 | 'method': method,
201 | 'resolution': 'real',
202 | 'noise': noise,
203 | }
204 |
205 | else:
206 | if 'Ours' in name:
207 | dataset, method, tag, num_pnts, sample_method, noise = name.split('_')[:6]
208 | else:
209 | dataset, method, num_pnts, sample_method, noise = name.split('_')[:5]
210 | return {
211 | 'dataset': dataset,
212 | 'method': method,
213 | 'resolution': num_pnts + '_' + sample_method,
214 | 'noise': noise,
215 | }
216 |
--------------------------------------------------------------------------------
/utils/denoise.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import pytorch3d.ops
5 | from tqdm.auto import tqdm
6 | from sklearn.cluster import KMeans
7 | from sklearn.neighbors import kneighbors_graph, KDTree
8 | from torch import nn
9 | from collections import OrderedDict
10 | import torch.nn.functional as F
11 | import itertools
12 | from models.common import farthest_point_sampling
13 | from .transforms import NormalizeUnitSphere
14 |
15 |
16 | def split_tensor_to_segments(x, segsize):
16 | num_segs = math.ceil(x.size(0) / segsize)
17 | segs = []
18 | for i in range(num_segs):
19 | segs.append(x[i*segsize : (i+1)*segsize])
20 | return segs
21 |
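# Sketch (an illustrative addition, not part of the original file):
# split_tensor_to_segments chunks along the first dimension; the tail may be shorter.
def _split_demo():
    segs = split_tensor_to_segments(torch.arange(7), segsize=3)
    assert [s.tolist() for s in segs] == [[0, 1, 2], [3, 4, 5], [6]]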
22 | def patch_based_denoise(model, pcl_noisy, step_size=0.15, num_steps=50, patch_size=4000, seed_k=5, denoise_knn=4, step_decay=0.98, get_traj=False):
23 | """
24 | Args:
25 | pcl_noisy: Input point cloud, (N, 3)
26 | """
27 | assert pcl_noisy.dim() == 2, 'The shape of input point cloud must be (N, 3).'
28 | N, d = pcl_noisy.size()
29 | pcl_noisy = pcl_noisy.unsqueeze(0) # (1, N, 3)
30 | #seed_pnts, _ = farthest_point_sampling(pcl_noisy, int(seed_k * N / patch_size))
31 | #_, _, patches = pytorch3d.ops.knn_points(seed_pnts, pcl_noisy, K=patch_size, return_nn=True)
32 | #patches = patches[0]
33 | patches = pcl_noisy
34 | with torch.no_grad():
35 | #model.eval()
36 | patches_denoised, traj = model.resample(
37 | p_init=patches,
38 | p_ctx=patches,
39 | step_size=step_size,
40 | step_decay=step_decay,
41 | num_steps=num_steps
42 | )
43 | '''
44 | pcl_denoised, fps_idx = farthest_point_sampling(patches_denoised.view(1, -1, d), N)
45 | pcl_denoised = pcl_denoised[0]
46 | fps_idx = fps_idx[0]
47 | '''
48 | if get_traj:
49 | return patches_denoised[0], traj
50 | else:
51 | return patches_denoised[0]
52 | def patch_based_upsample(model, pcl_low, pcl_noisy, step_size=0.15, num_steps=50, patch_size=512, seed_k=3, denoise_knn=4, step_decay=0.98, get_traj=False):
53 | """
54 | Args:
55 | pcl_noisy: Input point cloud, (N, 3)
56 | """
57 | assert pcl_noisy.dim() == 2, 'The shape of input point cloud must be (N, 3).'
58 | N, d = pcl_noisy.size()
59 | M , d = pcl_low.size()
60 | rate = int(N/M)
61 | pcl_noisy = pcl_noisy.unsqueeze(0) # (1, N, 3)
62 | pcl_low = pcl_low.unsqueeze(0)
63 | #seed_pnts, _ = farthest_point_sampling(pcl_noisy, int(seed_k * N / patch_size))
64 | #_, _, patches_noisy = pytorch3d.ops.knn_points(seed_pnts, pcl_noisy,K=int(patch_size) ,return_nn=True)
65 | #patches_noisy = patches_noisy[0]
66 | #_, _, patches_low = pytorch3d.ops.knn_points(seed_pnts, pcl_low, K=patch_size, return_nn=True)
67 | #patches_low = patches_low[0]
68 | patches_noisy = pcl_noisy
69 | patches_low = pcl_low
70 | with torch.no_grad():
71 | model.eval()
72 | patches_denoised, traj = model.resample(
73 | p_init=patches_noisy,
74 | p_ctx=patches_low,
75 | step_size=step_size,
76 | step_decay=step_decay,
77 | num_steps=num_steps
78 | )
79 |
80 | if get_traj:
81 | return patches_denoised[0], traj
82 | else:
83 | return patches_denoised[0]
85 | '''
86 | pcl_denoised, fps_idx = farthest_point_sampling(patches_denoised.view(1, -1, d), N)
87 | pcl_denoised = pcl_denoised[0]
88 | fps_idx = fps_idx[0]
89 |
90 | if get_traj:
91 | for i in range(len(traj)):
92 | traj[i] = traj[i].view(-1, d)[fps_idx, :]
93 | return pcl_denoised, traj
94 | else:
95 | return pcl_denoised
96 | '''
97 |
98 | def patch_based_upsample_big(model, pcl_low, pcl_noisy, step_size=0.15, num_steps=50, patch_size=512, seed_k=3, denoise_knn=4, step_decay=0.99, get_traj=False):
99 | """
100 | Args:
101 | pcl_noisy: Input point cloud, (N, 3)
102 | """
103 | assert pcl_noisy.dim() == 2, 'The shape of input point cloud must be (N, 3).'
104 | N, d = pcl_noisy.size()
105 | M , d = pcl_low.size()
106 | pcl_noisy = pcl_noisy.unsqueeze(0) # (1, N, 3)
107 | pcl_low = pcl_low.unsqueeze(0)
108 | seed_pnts, _ = farthest_point_sampling(pcl_noisy, int(seed_k * N / patch_size))
109 | _, _, patches_noisy = pytorch3d.ops.knn_points(seed_pnts, pcl_noisy, K=int(patch_size), return_nn=True)
110 | patches_noisy = patches_noisy[0]
111 | _, _, patches_low = pytorch3d.ops.knn_points(seed_pnts, pcl_low, K=patch_size, return_nn=True)
112 | patches_low = patches_low[0]
113 | patches_low = split_tensor_to_segments( patches_low, 5)
114 | patches_noisy = split_tensor_to_segments(patches_noisy ,5 )
115 | n = len(patches_low)
116 | patches_denoised = []
117 | for i in range(n) :
118 | patch_denoised, traj = model.resample(
119 | p_init=patches_noisy[i],
120 | p_ctx=patches_low[i],
121 | step_size=step_size,
122 | step_decay=step_decay,
123 | num_steps=num_steps
124 | )
125 | patches_denoised.append(patch_denoised)
126 | patches_denoised = torch.cat(patches_denoised , dim=0)
127 | pcl_denoised, fps_idx = farthest_point_sampling(patches_denoised.view(1, -1, d), N)
128 | pcl_denoised = pcl_denoised[0]
129 | fps_idx = fps_idx[0]
130 |
131 | if get_traj:
132 | for i in range(len(traj)):
133 | traj[i] = traj[i].view(-1, d)[fps_idx, :]
134 | return pcl_denoised, traj
135 | else:
136 | return pcl_denoised
137 |
138 |
139 | def patch_based_denoise_big(model, pcl_noisy, step_size=0.15, num_steps=50, patch_size=10000, seed_k=3, denoise_knn=4, step_decay=0.95, get_traj=False):
140 | """
141 | Args:
142 | pcl_noisy: Input point cloud, (N, 3)
143 | """
144 | assert pcl_noisy.dim() == 2, 'The shape of input point cloud must be (N, 3).'
145 | N, d = pcl_noisy.size()
146 | pcl_noisy = pcl_noisy.unsqueeze(0) # (1, N, 3)
147 | seed_pnts, _ = farthest_point_sampling(pcl_noisy, int(seed_k * N / patch_size))
148 | _, _, patches = pytorch3d.ops.knn_points(seed_pnts, pcl_noisy, K=patch_size, return_nn=True)
149 | print("here!")
150 | patches = patches[0] # (N, K, 3)
151 | patches = split_tensor_to_segments( patches, 5)
152 | n = len(patches)
153 | patches_denoised = []
154 | for i in range(n) :
155 | patch_denoised, traj = model.resample(
156 | p_init=patches[i],
157 | p_ctx=patches[i],
158 | step_size=step_size,
159 | step_decay=step_decay,
160 | num_steps=num_steps
161 | )
162 | patches_denoised.append(patch_denoised)
163 | patches_denoised = torch.cat(patches_denoised , dim=0)
164 | pcl_denoised, fps_idx = farthest_point_sampling(patches_denoised.view(1, -1, d), N)
165 | pcl_denoised = pcl_denoised[0]
166 | fps_idx = fps_idx[0]
167 |
168 | if get_traj:
169 | for i in range(len(traj)):
170 | traj[i] = traj[i].view(-1, d)[fps_idx, :]
171 | return pcl_denoised, traj
172 | else:
173 | return pcl_denoised
174 |
175 | def denoise_large_pointcloud(model, pcl, cluster_size, seed=0):
176 | device = pcl.device
177 | pcl = pcl.cpu().numpy()
178 |
179 | print('Running KMeans to construct clusters...')
180 | n_clusters = math.ceil(pcl.shape[0] / cluster_size)
181 | print(n_clusters)
182 | kmeans = KMeans(n_clusters=n_clusters, random_state=seed).fit(pcl)
183 |
184 | pcl_parts = []
185 | for i in tqdm(range(n_clusters), desc='Denoise Cluster'):
186 | pts_idx = kmeans.labels_ == i
187 |
188 | pcl_part_noisy = torch.FloatTensor(pcl[pts_idx]).to(device)
189 | pcl_part_noisy, center, scale = NormalizeUnitSphere.normalize(pcl_part_noisy)
190 | pcl_part_denoised = patch_based_denoise(
191 | model,
192 | pcl_part_noisy,
193 | seed_k=3
194 | )
195 | pcl_part_denoised = pcl_part_denoised * scale + center
196 | pcl_parts.append(pcl_part_denoised)
197 |
198 | return torch.cat(pcl_parts, dim=0)
199 |
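
# Usage sketch (an illustrative addition, not part of the original file): for clouds
# too large to process at once, denoise_large_pointcloud clusters the input with
# KMeans, normalizes each cluster to the unit sphere, denoises it, and maps it back
# before concatenating. `model` and the cluster budget below are assumptions.
def _large_cloud_demo(model, pcl_noisy):
    return denoise_large_pointcloud(model, pcl_noisy, cluster_size=30000)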
--------------------------------------------------------------------------------
/scripts/train_upsample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import shutil
4 | import yaml
5 | from easydict import EasyDict
6 | from tqdm.auto import tqdm
7 | import itertools
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torch.utils.tensorboard
11 | from models.resampler import PointSetResampler
12 | from models.common import chamfer_distance_unit_sphere,coarsesGenerator
13 | from utils.datasets import UpsampleDataset, PairedUpsDataset
14 | from utils.transforms import *
15 | from utils.misc import *
16 | from utils.denoise import patch_based_denoise, patch_based_upsample
17 | from utils.evaluate import *
18 |
19 |
20 | def batch_pairwise_dist(x, y):
20 | bs, num_points_x, points_dim = x.size()
21 | _, num_points_y, _ = y.size()
22 | xx = torch.bmm(x, x.transpose(2, 1))
23 | yy = torch.bmm(y, y.transpose(2, 1))
24 | zz = torch.bmm(x, y.transpose(2, 1))
25 | diag_ind_x = torch.arange(0, num_points_x)
26 | diag_ind_y = torch.arange(0, num_points_y)
27 | if x.get_device() != -1:
28 | diag_ind_x = diag_ind_x.cuda(x.get_device())
29 | diag_ind_y = diag_ind_y.cuda(x.get_device())
30 | rx = xx[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as(zz.transpose(2, 1))
31 | ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1).expand_as(zz)
32 | P = (rx.transpose(2, 1) + ry - 2 * zz)
33 | return P
34 |
35 | def get_cd_loss(ref, gen):
36 | P = batch_pairwise_dist(ref, gen)
37 | mins, _ = torch.min(P, 1)
38 | loss_1 = torch.mean(mins)
39 | mins, _ = torch.min(P, 2)
40 | loss_2 = torch.mean(mins)
41 | return loss_1 + loss_2
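
# Sketch (an illustrative addition, not part of the original script):
# batch_pairwise_dist relies on the expansion
# ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>;
# the check below verifies that identity against torch.cdist.
def _pairwise_dist_identity_demo():
    x, y = torch.randn(2, 128, 3), torch.randn(2, 96, 3)
    P = (x ** 2).sum(-1)[:, :, None] + (y ** 2).sum(-1)[:, None, :] \
        - 2 * torch.bmm(x, y.transpose(2, 1))
    assert torch.allclose(P, torch.cdist(x, y) ** 2, atol=1e-4)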
42 | if __name__ == '__main__':
43 | r = 5
44 | best = 10000
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument('--config', type=str)
47 | parser.add_argument('--device', type=str, default='cuda')
48 | parser.add_argument('--log_root', type=str, default='./logs')
49 | args = parser.parse_args()
50 |
51 | # Load configs
52 | with open(args.config, 'r') as f:
53 | config = EasyDict(yaml.safe_load(f))
54 | config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
55 | seed_all(config.train.seed)
56 |
57 | # Logging
58 | log_dir = get_new_log_dir(args.log_root, prefix=config_name + '_' +'retrain_ups_gen')
59 | ckpt_mgr = CheckpointManager(log_dir)
60 | ckpt_dir = os.path.join(log_dir, 'checkpoints')
61 | os.makedirs(ckpt_dir, exist_ok=True)
62 | logger = get_logger('train', log_dir)
63 | writer = torch.utils.tensorboard.SummaryWriter(log_dir)
64 | logger.info(args)
65 | logger.info(config)
66 | shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config)))
67 |
68 | # Datasets and loaders
69 | logger.info('Loading datasets')
70 | train_dset = PairedUpsDataset(
71 | datasets=[UpsampleDataset(
72 | root=config.dataset.dataset_root,
73 | dataset=config.dataset.dataset,
74 | rate = 4 ,
75 | split='train',
76 | resolution=resl,
77 | noise_min = 1e-2,
78 | noise_max = 2.5e-2,
79 | ) for resl in config.dataset.resolutions
80 | ],
81 | patch_size=config.dataset.patch_size,
82 | patch_ratio=4.2,
83 | on_the_fly=True
84 | )
85 |
86 | val_dset = UpsampleDataset(
87 | root=config.dataset.dataset_root,
88 | dataset=config.dataset.dataset,
89 | split='test',
90 | resolution=config.dataset.resolutions[1],
91 | rate = 4,
92 | noise_min = 1.7e-2,
93 | noise_max = 1.7e-2,
94 | need_mesh=True
95 | )
96 | train_iter = get_data_iterator(DataLoader(train_dset, batch_size=config.train.train_batch_size, num_workers=config.train.num_workers, shuffle=True))
97 |
98 | # Model
99 | logger.info('Building model...')
100 | model = PointSetResampler(config.model).to(args.device)
101 | upsampler = coarsesGenerator(rate = 8 ).to(args.device)
102 | #logger.info(repr(model))
103 | #logger.info(repr(upsampler))
104 |
105 | # Optimizer and Scheduler
106 | optimizer = torch.optim.Adam(
107 | itertools.chain(upsampler.parameters(),model.parameters()),
108 | #model.parameters(),
109 | lr=config.train.lr,
110 | weight_decay=config.train.weight_decay,
111 | )
112 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
113 | optimizer,
114 | mode="min",
115 | factor=config.scheduler.factor,
116 | patience=config.scheduler.patience,
117 | threshold=config.scheduler.threshold,
118 | )
119 | def train(it):
120 | batch = next(train_iter)
121 | pcl_low = batch['pcl_low'].to(args.device)
122 | pcl_noisy = batch['pcl_noisy'].to(args.device)
123 | pcl_gt = batch['pcl_gt'].to(args.device)
124 | # Reset grad and model state
125 | model.train()
126 | upsampler.train()
127 | optimizer.zero_grad()
128 | # Forward
129 |
130 | pcl_noisy = upsampler(pcl_low)
131 | cd_loss = get_cd_loss(
132 | gen = pcl_noisy ,
133 | ref = pcl_gt
134 | )
135 |
136 | vec_loss = model.get_loss_pc(
137 | p_query=pcl_noisy,
138 | p_ctx=pcl_low,
139 | p_gt=pcl_gt,
140 | avg_knn=config.train.vec_avg_knn,
141 | )
142 | loss = vec_loss
143 | loss += cd_loss
144 | # Backward
145 | loss.backward()
146 | optimizer.step()
147 | # Logging
148 | #logger.info('[Train] Iter %04d |cd Loss %.6f | vec Loss %.6f' % (
149 | # it, cd_loss.item(),vec_loss.item()
150 | #))
151 | logger.info('[Train] Iter %04d | Vec Loss %.6f' % (
152 | it, vec_loss.item()
153 | ))
154 | writer.add_scalar('train/loss', loss, it)
155 | writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], it)
156 | writer.flush()
157 |
158 | def validate(it):
159 | global best
160 | all_clean = []
161 | all_denoised = []
162 | #upsampler.eval()
163 | avg_p2m = 0
164 | with torch.no_grad():
165 | for i, data in enumerate(tqdm(val_dset, desc='Validate')):
166 |
167 | pcl_noisy = data['ups'].to(args.device)
168 | pcl_low = data['original'].to(args.device)
169 | pcl_clean = data['gt'].to(args.device)
170 | #if it != 1 :
171 | pcl_denoised = patch_based_upsample(model,pcl_low, pcl_noisy)
172 | #pcl_denoised = patch_based_denoise(model,pcl_noisy)
173 | avg_p2m += point_mesh_bidir_distance_single_unit_sphere(
174 | pcl=pcl_denoised,
175 | verts=data['meshes']['verts'].to(args.device),
176 | faces=data['meshes']['faces'].to(args.device)
177 | ).mean()
178 | all_clean.append(pcl_clean.unsqueeze(0))
179 | all_denoised.append(pcl_denoised.unsqueeze(0))
180 |
181 | all_clean = torch.cat(all_clean, dim=0)
182 | all_denoised = torch.cat(all_denoised, dim=0)
183 | avg_chamfer = chamfer_distance_unit_sphere(all_denoised, all_clean, batch_reduction='mean')[0].item()
184 | avg_p2m /= len(val_dset)
185 |     logger.info('[Val] Iter %04d | CD %.6f | P2M %.6f' % (it, avg_chamfer, avg_p2m))
186 | writer.add_scalar('val/chamfer', avg_chamfer, it)
187 | writer.add_scalar('val/p2m', avg_p2m, it)
188 |     if avg_p2m < best:
189 |         best = avg_p2m
190 |         torch.save(model.state_dict(), log_dir + '/model.pth')
191 |         torch.save(upsampler.state_dict(), log_dir + '/upsmodel.pth')
192 | writer.flush()
193 | scheduler.step(avg_chamfer)
194 |     return avg_chamfer, avg_p2m
195 |
196 | # Main loop
197 | logger.info('Start training...')
198 | try:
199 | for it in range(1, config.train.max_iters+1):
200 | train(it)
201 | if it % config.train.val_freq == 0 or it == config.train.max_iters:
202 |             cd_loss, _ = validate(it)
203 | '''
204 | opt_states = {
205 | 'optimizer': optimizer.state_dict(),
206 | 'scheduler': scheduler.state_dict(),
207 | }
208 | ckpt_mgr.save(model, args, cd_loss, opt_states, step=it)
209 | '''
210 |
211 | except KeyboardInterrupt:
212 | logger.info('Terminating...')
213 |
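214 | # Added note (not part of the original script): validate() keeps only the
215 | # best-P2M weights. To reload them later, mirror the torch.save() calls above:
216 | #   model.load_state_dict(torch.load(log_dir + '/model.pth'))
217 | #   upsampler.load_state_dict(torch.load(log_dir + '/upsmodel.pth'))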
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import numbers
4 | import torch
5 | import numpy as np
6 | from torchvision.transforms import Compose
7 |
8 |
9 |
10 | class NormalizeUnitSphere(object):
11 |
12 | def __init__(self):
13 | super().__init__()
14 |
15 | @staticmethod
16 | def normalize(pcl, center=None, scale=None):
17 | """
18 | Args:
19 | pcl: The point cloud to be normalized, (N, 3)
20 | """
21 | if center is None:
22 | p_max = pcl.max(dim=0, keepdim=True)[0]
23 | p_min = pcl.min(dim=0, keepdim=True)[0]
24 | center = (p_max + p_min) / 2 # (1, 3)
25 | pcl = pcl - center
26 | if scale is None:
27 | scale = (pcl ** 2).sum(dim=1, keepdim=True).sqrt().max(dim=0, keepdim=True)[0] # (1, 1)
28 | pcl = pcl / scale
29 | return pcl, center, scale
30 |
31 | def __call__(self, data):
32 | assert 'pcl_noisy' not in data, 'Point clouds must be normalized before applying noise perturbation.'
33 | data['pcl_clean'], center, scale = self.normalize(data['pcl_clean'])
34 | data['center'] = center
35 | data['scale'] = scale
36 | return data
37 |
38 |
39 | class AddNoise(object):
40 |
41 | def __init__(self, noise_std_min, noise_std_max):
42 | super().__init__()
43 | self.noise_std_min = noise_std_min
44 | self.noise_std_max = noise_std_max
45 |
46 | def __call__(self, data):
47 | noise_std = random.uniform(self.noise_std_min, self.noise_std_max)
48 | data['pcl_noisy'] = data['pcl_clean'] + torch.randn_like(data['pcl_clean']) * noise_std
49 | data['noise_std'] = noise_std
50 | return data
51 |
52 |
53 | class AddLaplacianNoise(object):
54 |
55 | def __init__(self, noise_std_min, noise_std_max):
56 | super().__init__()
57 | self.noise_std_min = noise_std_min
58 | self.noise_std_max = noise_std_max
59 |
60 | def __call__(self, data):
61 | noise_std = random.uniform(self.noise_std_min, self.noise_std_max)
62 | noise = torch.FloatTensor(np.random.laplace(0, noise_std, size=data['pcl_clean'].shape)).to(data['pcl_clean'])
63 | data['pcl_noisy'] = data['pcl_clean'] + noise
64 | data['noise_std'] = noise_std
65 | return data
66 |
67 |
68 | class AddUniformBallNoise(object):
69 |
70 | def __init__(self, scale):
71 | super().__init__()
72 | self.scale = scale
73 |
74 | def __call__(self, data):
75 | N = data['pcl_clean'].shape[0]
76 | phi = np.random.uniform(0, 2*np.pi, size=N)
77 | costheta = np.random.uniform(-1, 1, size=N)
78 | u = np.random.uniform(0, 1, size=N)
79 | theta = np.arccos(costheta)
80 | r = self.scale * u ** (1/3)
81 |
82 | noise = np.zeros([N, 3])
83 | noise[:, 0] = r * np.sin(theta) * np.cos(phi)
84 | noise[:, 1] = r * np.sin(theta) * np.sin(phi)
85 | noise[:, 2] = r * np.cos(theta)
86 | noise = torch.FloatTensor(noise).to(data['pcl_clean'])
87 | data['pcl_noisy'] = data['pcl_clean'] + noise
88 | return data
89 |
90 |
91 | class AddCovNoise(object):
92 |
93 | def __init__(self, cov, std_factor=1.0):
94 | super().__init__()
95 | self.cov = torch.FloatTensor(cov)
96 | self.std_factor = std_factor
97 |
98 | def __call__(self, data):
99 | num_points = data['pcl_clean'].shape[0]
100 | noise = np.random.multivariate_normal(np.zeros(3), self.cov.numpy(), num_points) # (N, 3)
101 | noise = torch.FloatTensor(noise).to(data['pcl_clean'])
102 | data['pcl_noisy'] = data['pcl_clean'] + noise * self.std_factor
103 | data['noise_std'] = self.std_factor
104 | return data
105 |
106 |
107 | class AddDiscreteNoise(object):
108 |
109 | def __init__(self, scale, prob=0.1):
110 | super().__init__()
111 | self.scale = scale
112 | self.prob = prob
113 | self.template = np.array([
114 | [1, 0, 0],
115 | [-1, 0, 0],
116 | [0, 1, 0],
117 | [0, -1, 0],
118 | [0, 0, 1],
119 | [0, 0, -1],
120 | ], dtype=np.float32)
121 |
122 | def __call__(self, data):
123 | num_points = data['pcl_clean'].shape[0]
124 | uni_rand = np.random.uniform(size=num_points)
125 | noise = np.zeros([num_points, 3])
126 | for i in range(self.template.shape[0]):
127 | idx = np.logical_and(0.1*i <= uni_rand, uni_rand < 0.1*(i+1))
128 | noise[idx] = self.template[i].reshape(1, 3)
129 | noise = torch.FloatTensor(noise).to(data['pcl_clean'])
130 | # print(data['pcl_clean'])
131 | # print(self.scale)
132 | data['pcl_noisy'] = data['pcl_clean'] + noise * self.scale
133 | data['noise_std'] = self.scale
134 | return data
135 |
136 |
137 | class RandomScale(object):
138 |
139 | def __init__(self, scales):
140 | assert isinstance(scales, (tuple, list)) and len(scales) == 2
141 | self.scales = scales
142 |
143 | def __call__(self, data):
144 | scale = random.uniform(*self.scales)
145 | '''
146 | for k in data:
147 | if k != 'name':
148 | data[k] = data[k]* scale
149 | '''
150 | data['pcl_clean'] = data['pcl_clean'] * scale
151 | if 'pcl_noisy' in data:
152 | data['pcl_noisy'] = data['pcl_noisy'] * scale
153 | return data
154 |
155 |
156 | class RandomRotate(object):
157 |
158 | def __init__(self, degrees=180.0, axis=0):
159 | if isinstance(degrees, numbers.Number):
160 | degrees = (-abs(degrees), abs(degrees))
161 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2
162 | self.degrees = degrees
163 | self.axis = axis
164 |
165 | def __call__(self, data):
166 | degree = math.pi * random.uniform(*self.degrees) / 180.0
167 | sin, cos = math.sin(degree), math.cos(degree)
168 |
169 | if self.axis == 0:
170 | matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]]
171 | elif self.axis == 1:
172 | matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]
173 | else:
174 | matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]
175 | matrix = torch.tensor(matrix)
176 | '''
177 | for k in data:
178 | if k != 'name' and k != 'center' and k != 'scale':
179 | data[k] = torch.matmul(data[k],matrix)
180 |
181 | '''
182 | data['pcl_clean'] = torch.matmul(data['pcl_clean'], matrix)
183 | if 'pcl_noisy' in data:
184 | data['pcl_noisy'] = torch.matmul(data['pcl_noisy'], matrix)
185 | return data
186 |
187 |
188 | def standard_train_transforms(noise_std_min, noise_std_max, scale_d=0.2, rotate=True):
189 | transforms = [
190 | #NormalizeUnitSphere(),
191 | AddNoise(noise_std_min=noise_std_min, noise_std_max=noise_std_max),
192 | RandomScale([1.0-scale_d, 1.0+scale_d]),
193 | ]
194 | if rotate:
195 | transforms += [
196 | RandomRotate(axis=0),
197 | RandomRotate(axis=1),
198 | RandomRotate(axis=2),
199 | ]
200 | return Compose(transforms)
201 | class add_varying_noise(object):
202 |
203 | def __init__(self, noise_std_min, noise_std_max):
204 | super().__init__()
205 | self.noise_std_min = noise_std_min
206 | self.noise_std_max = noise_std_max
207 |
208 | def __call__(self, data):
209 | noise_std = random.uniform(self.noise_std_min, self.noise_std_max)
210 | N = data['pcl_clean'].shape[0]
211 | p_max = data['pcl_clean'].max(dim=0, keepdim=False)[0]
212 | p_min = data['pcl_clean'].min(dim=0, keepdim=False)[0]
213 | center = (p_max + p_min) / 2 # (3,)
214 | noise_top = torch.FloatTensor(np.random.laplace(0, noise_std, size=data['pcl_clean'].shape)).to(data['pcl_clean'])
215 | phi = np.random.uniform(0, 2*np.pi, size=N)
216 | costheta = np.random.uniform(-1, 1, size=N)
217 | u = np.random.uniform(0, 1, size=N)
218 | theta = np.arccos(costheta)
219 | r = noise_std * u ** (1/3)
220 |
221 | noise_bottom= np.zeros([N, 3])
222 | noise_bottom[:, 0] = r * np.sin(theta) * np.cos(phi)
223 | noise_bottom[:, 1] = r * np.sin(theta) * np.sin(phi)
224 | noise_bottom[:, 2] = r * np.cos(theta)
225 | noise_bottom = torch.FloatTensor(noise_bottom).to(data['pcl_clean'])
226 |
227 |         idx_top = data['pcl_clean'][:, 0] > center[0]
228 |         idx_bottom = data['pcl_clean'][:, 0] <= center[0]
229 |         noise_top = noise_top * idx_top.unsqueeze(-1)            # Laplacian noise only on the top half (x > center)
230 |         noise_bottom = noise_bottom * idx_bottom.unsqueeze(-1)   # uniform-ball noise only on the bottom half
231 | data['pcl_noisy'] = data['pcl_clean'] + noise_top + noise_bottom
232 | data['noise_std'] = noise_std
233 | return data
234 |
235 |
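236 | # Added usage sketch (not part of the original module): every transform operates
237 | # on a dict holding a 'pcl_clean' tensor, composed exactly as in
238 | # standard_train_transforms() above.
239 | if __name__ == '__main__':
240 |     t = standard_train_transforms(noise_std_min=1e-2, noise_std_max=2.5e-2)
241 |     data = t({'pcl_clean': torch.rand(2048, 3)})
242 |     print(data['pcl_noisy'].shape, data['noise_std'])  # torch.Size([2048, 3]), sampled std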
--------------------------------------------------------------------------------
/utils/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch3d
4 | import pytorch3d.loss
5 | import numpy as np
6 | from scipy.spatial.transform import Rotation
7 | import pandas as pd
8 | import point_cloud_utils as pcu
9 | from tqdm.auto import tqdm
10 |
11 | from .misc import BlackHole
12 | import h5py
13 |
14 | def load_h5():
15 | DATA_DIR = "./"
16 | all_data = []
17 | all_label = []
18 | h5_name = os.path.join(DATA_DIR, 'PUGAN_poisson_256_poisson_1024.h5')
19 |     f = h5py.File(h5_name, 'r')
20 | data = f['data'][:].astype('float32')
21 | label = f['label'][:].astype('int64')
22 | f.close()
23 | all_data.append(data)
24 | all_label.append(label)
25 | all_data = np.concatenate(all_data, axis=0)
26 | all_label = np.concatenate(all_label, axis=0)
27 | return all_data, all_label
28 |
29 | def load_xyz(xyz_dir):
30 | all_pcls = {}
31 | for fn in tqdm(os.listdir(xyz_dir), desc='Loading'):
32 | if fn[-3:] != 'xyz':
33 | continue
34 | name = fn[:-4]
35 | path = os.path.join(xyz_dir, fn)
36 | all_pcls[name] = torch.FloatTensor(np.loadtxt(path, dtype=np.float32))
37 | return all_pcls
38 |
39 | def load_off(off_dir):
40 | all_meshes = {}
41 | for fn in tqdm(os.listdir(off_dir), desc='Loading'):
42 | if fn[-3:] != 'off':
43 | continue
44 | name = fn[:-4]
45 | path = os.path.join(off_dir, fn)
46 | verts, faces, _ = pcu.read_off(path)
47 | verts = torch.FloatTensor(verts)
48 | faces = torch.LongTensor(faces)
49 | all_meshes[name] = {'verts': verts, 'faces': faces}
50 | return all_meshes
51 |
52 |
53 | class Evaluator(object):
54 |
55 | def __init__(self, output_pcl_dir, dataset_root, dataset, summary_dir, experiment_name, device='cuda', res_gts='8192_poisson', logger=BlackHole()):
56 | super().__init__()
57 | self.output_pcl_dir = output_pcl_dir
58 | self.dataset_root = dataset_root
59 | self.dataset = dataset
60 | self.summary_dir = summary_dir
61 | self.experiment_name = experiment_name
62 | self.gts_pcl_dir = os.path.join(dataset_root, dataset, 'pointclouds', 'test', res_gts)
63 | self.gts_mesh_dir = os.path.join(dataset_root, dataset, 'meshes', 'test')
64 | self.res_gts = res_gts
65 | self.device = device
66 | self.logger = logger
67 | self.load_data()
68 |
69 | def load_data(self):
70 | self.pcls_up = load_xyz(self.output_pcl_dir)
71 | self.pcls_high = load_xyz(self.gts_pcl_dir)
72 | self.meshes = load_off(self.gts_mesh_dir)
73 | self.pcls_name = list(self.pcls_up.keys())
74 |
75 | def run(self):
76 | pcls_up, pcls_high, pcls_name = self.pcls_up, self.pcls_high, self.pcls_name
77 | results = {}
78 | for name in tqdm(pcls_name, desc='Evaluate'):
79 | pcl_up = pcls_up[name][:,:3].unsqueeze(0).to(self.device)
80 | if name not in pcls_high:
81 | self.logger.warning('Shape `%s` not found, ignored.' % name)
82 | continue
83 | pcl_high = pcls_high[name].unsqueeze(0).to(self.device)
84 | verts = self.meshes[name]['verts'].to(self.device)
85 | faces = self.meshes[name]['faces'].to(self.device)
86 |
87 | cd = pytorch3d.loss.chamfer_distance(pcl_up, pcl_high)[0].item()
88 | cd_sph = chamfer_distance_unit_sphere(pcl_up, pcl_high)[0].item()
89 | hd_sph = hausdorff_distance_unit_sphere(pcl_up, pcl_high)[0].item()
90 |
91 | # p2f = point_to_mesh_distance_single_unit_sphere(
92 | # pcl=pcl_up[0],
93 | # verts=verts,
94 | # faces=faces
95 | # ).sqrt().mean().item()
96 | if 'blensor' in self.experiment_name:
97 | rotmat = torch.FloatTensor(Rotation.from_euler('xyz', [-90, 0, 0], degrees=True).as_matrix()).to(pcl_up[0])
98 | p2f = point_mesh_bidir_distance_single_unit_sphere(
99 | pcl=pcl_up[0].matmul(rotmat.t()),
100 | verts=verts,
101 | faces=faces
102 | ).item()
103 | else:
104 | p2f = point_mesh_bidir_distance_single_unit_sphere(
105 | pcl=pcl_up[0],
106 | verts=verts,
107 | faces=faces
108 | ).item()
109 |
110 | results[name] = {
111 | 'cd': cd,
112 | 'cd_sph': cd_sph,
113 | 'p2f': p2f,
114 | 'hd_sph': hd_sph,
115 | }
116 |
117 | results = pd.DataFrame(results).transpose()
118 | res_mean = results.mean(axis=0)
119 | self.logger.info("\n" + repr(results))
120 | self.logger.info("\nMean\n" + '\n'.join([
121 | '%s\t%.12f' % (k, v) for k, v in res_mean.items()
122 | ]))
123 |
124 | update_summary(
125 | os.path.join(self.summary_dir, 'Summary_%s.csv' % self.dataset),
126 | model=self.experiment_name,
127 | metrics={
128 | # 'cd(mean)': res_mean['cd'],
129 | 'cd_sph(mean)': res_mean['cd_sph'],
130 | 'p2f(mean)': res_mean['p2f'],
131 | 'hd_sph(mean)': res_mean['hd_sph'],
132 | }
133 | )
134 |
135 |
136 | def update_summary(path, model, metrics):
137 | if os.path.exists(path):
138 |         df = pd.read_csv(path, index_col=0, sep=r"\s*,\s*", engine='python')
139 | else:
140 | df = pd.DataFrame()
141 | for metric, value in metrics.items():
142 | setting = metric
143 | if setting not in df.columns:
144 | df[setting] = np.nan
145 | df.loc[model, setting] = value
146 | df.to_csv(path, float_format='%.12f')
147 | return df
148 | def hausdorff_distance_unit_sphere(gen, ref):
149 | """
150 | Args:
151 | gen: (B, N, 3)
152 | ref: (B, N, 3)
153 | Returns:
154 | (B, )
155 | """
156 | ref, center, scale = normalize_sphere(ref)
157 | gen = normalize_pcl(gen, center, scale)
158 |
159 | dists_ab, _, _ = pytorch3d.ops.knn_points(ref, gen, K=1)
160 | dists_ab = dists_ab[:,:,0].max(dim=1, keepdim=True)[0] # (B, 1)
161 | # print(dists_ab)
162 |
163 | dists_ba, _, _ = pytorch3d.ops.knn_points(gen, ref, K=1)
164 | dists_ba = dists_ba[:,:,0].max(dim=1, keepdim=True)[0] # (B, 1)
165 | # print(dists_ba)
166 |
167 | dists_hausdorff = torch.max(torch.cat([dists_ab, dists_ba], dim=1), dim=1)[0]
168 |
169 | return dists_hausdorff
170 | def normalize_sphere(pc, radius=1.0):
171 | """
172 | Args:
173 | pc: A batch of point clouds, (B, N, 3).
174 | """
175 | ## Center
176 | p_max = pc.max(dim=-2, keepdim=True)[0]
177 | p_min = pc.min(dim=-2, keepdim=True)[0]
178 | center = (p_max + p_min) / 2 # (B, 1, 3)
179 | pc = pc - center
180 | ## Scale
181 |     scale = (pc ** 2).sum(dim=-1, keepdim=True).sqrt().max(dim=-2, keepdim=True)[0] / radius # (B, 1, 1)
182 | pc = pc / scale
183 | return pc, center, scale
184 |
185 | def normalize_pcl(pc, center, scale):
186 | return (pc - center) / scale
187 | '''
188 | def pointwise_p2m_distance_normalized(pcl, verts, faces):
189 | assert pcl.dim() == 2 and verts.dim() == 2 and faces.dim() == 2, 'Batch is not supported.'
190 | # Normalize mesh
191 | #print(' before :verts %.6f pcl %.6f' % (verts.abs().max().item(), pcl.abs().max().item()))
192 | verts, center, scale = normalize_sphere(verts.unsqueeze(0))
193 | verts = verts[0]
194 | # Normalize pcl
195 | pcl = normalize_pcl(pcl.unsqueeze(0), center=center, scale=scale)
196 | pcl = pcl[0]
197 | #print('after :verts %.6f pcl %.6f' % (verts.abs().max().item(), pcl.abs().max().item()))
198 | # Convert them to pytorch3d structures
199 | pcls = pytorch3d.structures.Pointclouds([pcl])
200 | meshes = pytorch3d.structures.Meshes([verts], [faces])
201 |
202 | # packed representation for pointclouds
203 | points = pcls.points_packed() # (P, 3)
204 | points_first_idx = pcls.cloud_to_packed_first_idx()
205 | max_points = pcls.num_points_per_cloud().max().item()
206 |
207 | # packed representation for faces
208 | verts_packed = meshes.verts_packed()
209 | faces_packed = meshes.faces_packed()
210 | tris = verts_packed[faces_packed] # (T, 3, 3)
211 | tris_first_idx = meshes.mesh_to_faces_packed_first_idx()
212 | max_tris = meshes.num_faces_per_mesh().max().item()
213 | # point to face distance: shape (P,)
214 | point_to_face = point_face_distance(
215 | points, points_first_idx, tris, tris_first_idx, max_points
216 | )
217 | return point_to_face
218 | '''
219 | def point_mesh_bidir_distance_single_unit_sphere(pcl, verts, faces):
220 | """
221 | Args:
222 | pcl: (N, 3).
223 | verts: (M, 3).
224 | faces: LongTensor, (T, 3).
225 |     Returns:
226 |         Scalar: bidirectional point-to-face distance between the cloud and the mesh.
227 |     """
228 | assert pcl.dim() == 2 and verts.dim() == 2 and faces.dim() == 2, 'Batch is not supported.'
229 |
230 | # Normalize mesh
231 | verts, center, scale = normalize_sphere(verts.unsqueeze(0))
232 | verts = verts[0]
233 | # Normalize pcl
234 | pcl = normalize_pcl(pcl.unsqueeze(0), center=center, scale=scale)
235 | pcl = pcl[0]
236 |
237 | # Convert them to pytorch3d structures
238 | pcls = pytorch3d.structures.Pointclouds([pcl])
239 | meshes = pytorch3d.structures.Meshes([verts], [faces])
240 | return pytorch3d.loss.point_mesh_face_distance(meshes, pcls)
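241 |
242 | # Added smoke test (not part of the original module): the unit-sphere metrics
243 | # accept batched (B, N, 3) clouds and return one distance per cloud.
244 | if __name__ == '__main__':
245 |     gen, ref = torch.rand(2, 1024, 3), torch.rand(2, 1024, 3)
246 |     print(hausdorff_distance_unit_sphere(gen, ref).shape)  # torch.Size([2])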
--------------------------------------------------------------------------------
/utils/datasets/pcl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | from tqdm.auto import tqdm
6 | import point_cloud_utils as pcu
7 | import random
8 |
9 | class PointCloudDataset(Dataset):
10 |     def __init__(self, root, dataset, split, resolution, from_saved=False, noise_level=None, input_root=None, need_mesh=False, transform=None):
11 | super().__init__()
12 |         if not from_saved:
13 |             self.pcl_dir = os.path.join(root, dataset, 'pointclouds', split, resolution)  # during training, set from_saved=False
14 |         else:
15 |             self.pcl_dir = os.path.join(input_root, dataset + "_" + resolution + "_" + str(noise_level))  # during evaluation, set from_saved=True
16 | print(self.pcl_dir)
17 | self.resolution = resolution
18 | self.transform = transform
19 | self.need_mesh = need_mesh
20 | if self.need_mesh == False:
21 | self.pointclouds = []
22 | self.pointcloud_names = []
23 | for fn in tqdm(os.listdir(self.pcl_dir), desc='Loading'):
24 | if fn[-3:] != 'xyz':
25 | continue
26 | pcl_path = os.path.join(self.pcl_dir, fn)
27 | if not os.path.exists(pcl_path):
28 | raise FileNotFoundError('File not found: %s' % pcl_path)
29 | pcl = torch.FloatTensor(np.loadtxt(pcl_path, dtype=np.float32))
30 | self.pointclouds.append(pcl)
31 | self.pointcloud_names.append(fn[:-4])
32 | if self.need_mesh == True:
33 | self.mesh_dir = os.path.join(root, dataset, 'meshes', 'test')
34 | self.meshes = {}
35 | self.meshes_names = []
36 | self.pointclouds = {}
37 | self.pointcloud_names = []
38 | for fn in tqdm(os.listdir(self.pcl_dir), desc='Loading'):
39 | if fn[-3:] != 'xyz':
40 | continue
41 | pcl_path = os.path.join(self.pcl_dir, fn)
42 | if not os.path.exists(pcl_path):
43 | raise FileNotFoundError('File not found: %s' % pcl_path)
44 | pcl = torch.FloatTensor(np.loadtxt(pcl_path, dtype=np.float32))
45 | name = fn[:-4]
46 | self.pointclouds[name] = pcl
47 | self.pointcloud_names.append(name)
48 | for fn in tqdm(os.listdir(self.mesh_dir), desc='Loading'):
49 | if fn[-3:] != 'off':
50 | continue
51 | mesh_path = os.path.join(self.mesh_dir , fn)
52 |                 if not os.path.exists(mesh_path):
53 |                     raise FileNotFoundError('File not found: %s' % mesh_path)
54 |                 verts, faces = pcu.load_mesh_vf(mesh_path)
55 | verts = torch.FloatTensor(verts)
56 | faces = torch.LongTensor(faces)
57 | name = fn[:-4]
58 | self.meshes[name] = {'verts': verts, 'faces': faces}
59 | self.meshes_names.append(name)
60 | def __len__(self):
61 | return len(self.pointclouds)
62 |
63 | def __getitem__(self, idx):
64 | if self.need_mesh == False :
65 | data = {
66 | 'pcl_clean': self.pointclouds[idx].clone(),
67 | 'name': self.pointcloud_names[idx]
68 | }
69 | if self.transform is not None:
70 | data = self.transform(data)
71 | return data
72 | if self.need_mesh == True :
73 | name = self.pointcloud_names[idx]
74 | data = {
75 | 'pcl_clean': self.pointclouds[name].clone(),
76 | 'name': name,
77 | }
78 | if self.transform is not None:
79 | data = self.transform(data)
80 | data['meshes'] = self.meshes[name]
81 | return data
82 |
83 | class UpsampleDataset(Dataset):
84 |
85 |     def __init__(self, root, dataset, split, resolution, rate, noise_min, noise_max, from_saved=False, input_root=None, need_mesh=False, transform=None):
86 | super().__init__()
87 | self.resolution = resolution
88 | self.rate = rate
89 | self.transform = transform
90 | self.need_mesh = need_mesh
91 | self.noise_min = noise_min
92 | self.noise_max = noise_max
93 |         if not from_saved:
94 |             self.pcl_dir = os.path.join(root, dataset, 'pointclouds', split, resolution)
95 |         else:
96 |             self.pcl_dir = os.path.join(input_root, dataset + "_" + resolution + "_" + "0.03")
97 |         self.gt_resolution = str(int(self.resolution[:self.resolution.index('_')]) * self.rate)  # e.g. '2048_poisson' with rate 4 -> '8192'
98 | #self.gt_resolution = "32768"
99 |         self.gt_dir = os.path.join(root, dataset, 'pointclouds', split, self.gt_resolution + "_poisson")
100 | print(self.pcl_dir)
101 | print(self.gt_dir)
102 | if self.need_mesh == False:
103 | self.pointclouds = []
104 | self.gt_pointclouds = []
105 | self.pointcloud_names = []
106 | for fn in tqdm(os.listdir(self.pcl_dir), desc='Loading'):
107 | if fn[-3:] != 'xyz':
108 | continue
109 | pcl_path = os.path.join(self.pcl_dir, fn)
110 | gt_path = os.path.join(self.gt_dir , fn)
111 | if not os.path.exists(pcl_path):
112 | raise FileNotFoundError('File not found: %s' % pcl_path)
113 | pcl = torch.FloatTensor(np.loadtxt(pcl_path, dtype=np.float32))
114 | gt = torch.FloatTensor(np.loadtxt(gt_path, dtype=np.float32))
115 | self.pointclouds.append(pcl)
116 | self.gt_pointclouds.append(gt)
117 | self.pointcloud_names.append(fn[:-4])
118 | if self.need_mesh == True:
119 | self.mesh_dir = os.path.join(root, dataset, 'meshes', 'test')
120 | self.meshes = {}
121 | self.meshes_names = []
122 | self.pointclouds = []
123 | self.gt_pointclouds = []
124 | self.pointcloud_names = []
125 | for fn in tqdm(os.listdir(self.pcl_dir), desc='Loading'):
126 | if fn[-3:] != 'xyz':
127 | continue
128 | pcl_path = os.path.join(self.pcl_dir, fn)
129 | gt_path = os.path.join(self.gt_dir , fn)
130 | if not os.path.exists(pcl_path):
131 | raise FileNotFoundError('File not found: %s' % pcl_path)
132 | pcl = torch.FloatTensor(np.loadtxt(pcl_path, dtype=np.float32))
133 | gt = torch.FloatTensor(np.loadtxt(gt_path, dtype=np.float32))
134 | self.pointclouds.append(pcl)
135 | self.gt_pointclouds.append(gt)
136 | self.pointcloud_names.append(fn[:-4])
137 | for fn in tqdm(os.listdir(self.mesh_dir), desc='Loading'):
138 | if fn[-3:] != 'off':
139 | continue
140 | mesh_path = os.path.join(self.mesh_dir , fn)
141 |                 if not os.path.exists(mesh_path):
142 |                     raise FileNotFoundError('File not found: %s' % mesh_path)
143 |                 verts, faces = pcu.load_mesh_vf(mesh_path)
144 | verts = torch.FloatTensor(verts)
145 | faces = torch.LongTensor(faces)
146 | name = fn[:-4]
147 | self.meshes[name] = {'verts': verts, 'faces': faces}
148 | self.meshes_names.append(name)
149 | def __len__(self):
150 | return len(self.pointclouds)
151 |     def normalize(self, pcl, center=None, scale=None):
152 | """
153 | Args:
154 | pcl: The point cloud to be normalized, (N, 3)
155 | """
156 | if center is None:
157 | p_max = pcl.max(dim=0, keepdim=True)[0]
158 | p_min = pcl.min(dim=0, keepdim=True)[0]
159 | center = (p_max + p_min) / 2 # (1, 3)
160 | pcl = pcl - center
161 | if scale is None:
162 | scale = (pcl ** 2).sum(dim=1, keepdim=True).sqrt().max(dim=0, keepdim=True)[0] # (1, 1)
163 | pcl = pcl / scale
164 | return pcl, center, scale
165 | def __getitem__(self, idx):
166 | if self.need_mesh == False :
167 | pcl_noisy = []
168 | #gt , center , scale = self.normalize(self.gt_pointclouds[idx])
169 | gt = self.gt_pointclouds[idx].clone()
170 | #print(center,scale)
171 | original = self.pointclouds[idx]
172 | #original, _ , _ = self.normalize(self.pointclouds[idx],center , scale)
173 |             for i in range(self.rate - 1):
174 |                 noise_std = random.uniform(self.noise_min, self.noise_max)
175 |                 jittered = original + torch.randn_like(original) * noise_std
176 |                 #jittered = original + (2*torch.rand_like(original)-1) * noise_std
177 |                 pcl_noisy.append(jittered)
178 |             pcl_noisy.append(original)
179 |             pcl_noisy = torch.cat(pcl_noisy, dim=0)
180 | #pcl_noisy , _ , _ = self.normalize(pcl_noisy , center , scale)
181 | data = {
182 | 'gt': gt,
183 | 'ups' : pcl_noisy,
184 | 'original' : original,
185 | 'name': self.pointcloud_names[idx]
186 | }
187 | if self.transform is not None:
188 | data = self.transform(data)
189 | return data
190 | if self.need_mesh == True :
191 | pcl_noisy = []
192 | #gt , center , scale = self.normalize(self.gt_pointclouds[idx])
193 | gt = self.gt_pointclouds[idx].clone()
194 | #print(center,scale)
195 | original = self.pointclouds[idx]
196 | #original, _ , _ = self.normalize(self.pointclouds[idx],center , scale)
197 |             for i in range(self.rate - 1):
198 |                 noise_std = random.uniform(self.noise_min, self.noise_max)
199 |                 jittered = original + torch.randn_like(original) * noise_std
200 |                 #jittered = original + (2*torch.rand_like(original)-1) * noise_std
201 |                 pcl_noisy.append(jittered)
202 |             pcl_noisy.append(original)
203 |             pcl_noisy = torch.cat(pcl_noisy, dim=0)
204 | #pcl_noisy , _ , _ = self.normalize(pcl_noisy , center , scale)
205 | name = self.pointcloud_names[idx]
206 | data = {
207 | 'gt': gt,
208 | 'ups' : pcl_noisy,
209 | 'original' : original,
210 | 'name': name
211 | }
212 | if self.transform is not None:
213 | data = self.transform(data)
214 | data['meshes'] = self.meshes[name]
215 | return data
216 |
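217 | # Added standalone check (not part of the original module): __getitem__ builds
218 | # 'ups' by stacking (rate - 1) Gaussian-jittered copies of the input plus the
219 | # input itself, so 'ups' always holds rate * N points.
220 | if __name__ == '__main__':
221 |     original = torch.rand(512, 3)
222 |     rate, noise_min, noise_max = 4, 1e-2, 2.5e-2
223 |     stack = [original + torch.randn_like(original) * random.uniform(noise_min, noise_max)
224 |              for _ in range(rate - 1)] + [original]
225 |     ups = torch.cat(stack, dim=0)
226 |     assert ups.shape == (rate * 512, 3)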
--------------------------------------------------------------------------------
/models/common.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import pytorch3d.loss
5 | import pytorch3d.structures
6 | import pytorch3d.ops
7 | from pytorch3d.loss.point_mesh_distance import point_face_distance
8 | from torch_cluster import fps
9 | from torch.nn import Module, ModuleList, Identity, ReLU, Parameter, Sequential, Conv2d, BatchNorm2d, Conv1d, BatchNorm1d
10 |
11 | def batch_pairwise_dist(x, y):
12 | bs, num_points_x, points_dim = x.size()
13 | _, num_points_y, _ = y.size()
14 | xx = torch.bmm(x, x.transpose(2, 1))
15 | yy = torch.bmm(y, y.transpose(2, 1))
16 | zz = torch.bmm(x, y.transpose(2, 1))
17 | diag_ind_x = torch.arange(0, num_points_x)
18 | diag_ind_y = torch.arange(0, num_points_y)
19 | if x.get_device() != -1:
20 | diag_ind_x = diag_ind_x.cuda(x.get_device())
21 | diag_ind_y = diag_ind_y.cuda(x.get_device())
22 | rx = xx[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as(zz.transpose(2, 1))
23 | ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1).expand_as(zz)
24 | P = (rx.transpose(2, 1) + ry - 2 * zz)
25 | return P
26 |
27 |
28 | class EdgeConv(Module):
29 |
30 | def __init__(self, in_channels, num_layers, layer_out_dim, knn=16, aggr='max', activation='relu'):
31 | super().__init__()
32 | self.in_channels = in_channels
33 | self.knn = knn
34 | assert num_layers > 2
35 | self.num_layers = num_layers
36 | self.layer_out_dim = layer_out_dim
37 |
38 | # Densely Connected Layers
39 | self.layer_first = FullyConnected(3*in_channels, layer_out_dim, bias=True, activation=activation)
40 | self.layer_last = FullyConnected(in_channels + (num_layers - 1) * layer_out_dim, layer_out_dim, bias=True, activation=None)
41 | self.layers = ModuleList()
42 | for i in range(1, num_layers-1):
43 | self.layers.append(FullyConnected(in_channels + i * layer_out_dim, layer_out_dim, bias=True, activation=activation))
44 |
45 | self.aggr = Aggregator(aggr)
46 |
47 | @property
48 | def out_channels(self):
49 | return self.in_channels + self.num_layers * self.layer_out_dim
50 |
51 | def get_edge_feature(self, x, knn_idx):
52 | """
53 | :param x: (B, N, d)
54 | :param knn_idx: (B, N, K)
55 | :return (B, N, K, 2*d)
56 | """
57 | knn_feat = knn_group(x, knn_idx) # B * N * K * d
58 | x_tiled = x.unsqueeze(-2).expand_as(knn_feat)
59 | edge_feat = torch.cat([x_tiled, knn_feat, knn_feat - x_tiled], dim=3)
60 | return edge_feat
61 |
62 | def forward(self, x, pos):
63 | """
64 | :param x: (B, N, d)
65 | :return (B, N, d+L*c)
66 | """
67 | knn_idx = get_knn_idx(pos, pos, k=self.knn, offset=1)
68 |
69 | # First Layer
70 | edge_feat = self.get_edge_feature(x, knn_idx)
71 | y = torch.cat([
72 | self.layer_first(edge_feat), # (B, N, K, c)
73 | x.unsqueeze(-2).repeat(1, 1, self.knn, 1) # (B, N, K, d)
74 | ], dim=-1) # (B, N, K, d+c)
75 |
76 | # Intermediate Layers
77 | for layer in self.layers:
78 | y = torch.cat([
79 | layer(y), # (B, N, K, c)
80 | y, # (B, N, K, c+d)
81 | ], dim=-1) # (B, N, K, d+c+...)
82 |
83 | # Last Layer
84 | y = torch.cat([
85 | self.layer_last(y), # (B, N, K, c)
86 | y # (B, N, K, d+(L-1)*c)
87 | ], dim=-1) # (B, N, K, d+L*c)
88 |
89 | # Pooling
90 | y = self.aggr(y, dim=-2)
91 |
92 | return y
93 |
94 | def gen_grid(rate):
95 |     '''
96 |     Factor `rate` into a near-square num_x * num_y lattice of 2D offsets.
97 |     in  : int
98 |     out : (rate, 2)
99 |     '''
100 |     sqrted = int(math.sqrt(rate)) + 1
101 |     for i in reversed(range(1, sqrted + 1)):
102 |         if rate % i == 0:
103 |             num_x = i
104 |             num_y = rate // i
105 |             break
106 | grid_x = torch.linspace(-0.2, 0.2, num_x)
107 | grid_y = torch.linspace(-0.2, 0.2, num_y)
108 |
109 |     x, y = torch.meshgrid(grid_x, grid_y)  # default 'ij' indexing; pass indexing='ij' explicitly on torch >= 1.10
110 | grid = torch.reshape(torch.stack([x,y],dim=-1) ,[-1,2])
111 | return grid
112 |
113 | class GCN_feature_extractor(torch.nn.Module):
114 |
115 | def __init__(self ,conv_growth_rate , knn , block_num) :
116 | super().__init__()
117 |
118 | self.layer_out_dim = conv_growth_rate
119 | self.knn = knn
120 | self.block_num = block_num
121 |
122 | dims = [conv_growth_rate , conv_growth_rate * 5, conv_growth_rate * 10]
123 |         # input channels per block: concatenated features grow c -> 5c -> 10c
124 | self.layers = []
125 | self.edgeconvs = []
126 | for i in range(block_num):
127 | if i == 0 :
128 | self.layers += [nn.Conv1d(in_channels= 3 , out_channels = self.layer_out_dim, kernel_size = 1 )]
129 | self.edgeconvs += [EdgeConv(in_channels= self.layer_out_dim , num_layers= 3 , layer_out_dim= self.layer_out_dim , knn = self.knn)]
130 | else :
131 | self.layers += [nn.Conv1d(in_channels= dims[i] , out_channels = self.layer_out_dim * 2 , kernel_size = 1 )]
132 | self.edgeconvs += [EdgeConv(in_channels= self.layer_out_dim * 2 , num_layers= 3 , layer_out_dim= self.layer_out_dim , knn = self.knn)]
133 |         self.layers = ModuleList(self.layers)
134 |         self.edgeconvs = ModuleList(self.edgeconvs)
135 | def forward(self, points) :
136 | '''
137 | points : ( B , N , 3)
138 | out_feature : ( B , N , D)
139 | '''
140 | out_feature = points.permute(0,2,1).contiguous()
141 | for i in range(self.block_num):
142 | cur_feat = self.layers[i](out_feature).permute(0,2,1).contiguous() # (B, N ,D1)
143 | if i == 0 :
144 | out_feature = cur_feat.permute(0,2,1).contiguous()
145 | cur_feat = self.edgeconvs[i](cur_feat , points) # (B ,N ,D2)
146 | out_feature = torch.cat([out_feature.permute(0,2,1).contiguous(),cur_feat],dim=-1)
147 | out_feature = out_feature.permute(0,2,1).contiguous()
148 |
149 | return out_feature.permute(0,2,1).contiguous()
150 |
151 |
152 | class duplicate_up(torch.nn.Module):
153 |
154 | def __init__(self , rate) :
155 | super().__init__()
156 |
157 | self.rate = rate
158 | dims = [360+2 , 256, 128]
159 | conv_layers = []
160 |
161 | for i in range(len(dims)-1):
162 | conv_layers += [
163 | Conv1d(dims[i], dims[i+1], kernel_size=1),
164 | BatchNorm1d(dims[i+1]),
165 | ]
166 | if i < len(dims)-2:
167 | conv_layers += [
168 | ReLU(),
169 | ]
170 | self.layers = Sequential(*conv_layers)
171 |
172 | def forward(self, coarse_feature):
173 |
174 | B, N, d = coarse_feature.shape
175 | feat = coarse_feature.repeat(1, self.rate, 1) # ( B , N*rate , d)
176 |         grid = gen_grid(self.rate).unsqueeze(0).to(coarse_feature.device) # (1, rate , 2)
177 | grid = grid.repeat( B , 1 , N) # ( B , rate , 2*N)
178 | grid = grid.reshape(B, N*self.rate, 2) #(B , N*rate , 2)
179 | feat = torch.cat([feat,grid] ,dim = -1).permute(0,2,1).contiguous() # ( B , d+2 , N*R)
180 |
181 | feat = self.layers(feat)
182 | return feat
183 |
184 |
185 | class regressor(torch.nn.Module):
186 |
187 | def __init__(self) :
188 | super().__init__()
189 |
190 | dims = [128,256,64,3]
191 | conv_layers = []
192 |
193 | for i in range(len(dims)-1):
194 | conv_layers += [
195 | Conv1d(dims[i], dims[i+1], kernel_size=1, stride=1),
196 | ]
197 | if i < len(dims)-2:
198 | conv_layers += [
199 | ReLU(),
200 | ]
201 | self.layers = Sequential(*conv_layers)
202 |
203 | def forward(self , coarse_feature):
204 | B, d, N = coarse_feature.shape
205 | #coarse_feature = coarse_feature.permute(0,2,1).contiguous()
206 | coarse = self.layers(coarse_feature)
207 |
208 | coarse = coarse.permute(0,2,1).contiguous()
209 | return coarse
210 |
211 | class coarsesGenerator(torch.nn.Module):
212 |
213 |     def __init__(self, rate, block_num=3, knn=24, conv_growth_rate=24):
214 | super().__init__()
215 | self.rate = rate
216 | self.block_num = block_num
217 | self.conv_growth_rate = conv_growth_rate
218 | self.knn = knn
219 |
220 |         self.GCN = GCN_feature_extractor(self.conv_growth_rate, self.knn, self.block_num)
221 | self.duplicate_up = duplicate_up(self.rate)
222 | self.regressor = regressor()
223 |
224 | def forward(self, points) :
225 | coarse_feature = self.GCN(points)
226 | coarse_feature = self.duplicate_up(coarse_feature)
227 | coarse = self.regressor(coarse_feature)
228 | return coarse
229 |
230 | class FullyConnected(torch.nn.Module):
231 |
232 | def __init__(self, in_features, out_features, bias=True, activation=None):
233 | super().__init__()
234 |
235 | self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
236 |
237 | if activation is None:
238 | self.activation = torch.nn.Identity()
239 | elif activation == 'relu':
240 | self.activation = torch.nn.ReLU()
241 | elif activation == 'elu':
242 | self.activation = torch.nn.ELU(alpha=1.0)
243 | elif activation == 'lrelu':
244 | self.activation = torch.nn.LeakyReLU(0.1)
245 | else:
246 | raise ValueError()
247 |
248 | def forward(self, x):
249 | return self.activation(self.linear(x))
250 |
251 | class FeatureExtraction(Module):
252 |
253 | def __init__(self, in_channels=3, dynamic_graph=True, conv_channels=24, num_convs=3, conv_num_layers=3, conv_layer_out_dim=12, conv_knn=16, conv_aggr='max', activation='relu'):
254 | super().__init__()
255 | self.in_channels = in_channels
256 | self.dynamic_graph = dynamic_graph
257 | self.num_convs = num_convs
258 |
259 | # Edge Convolution Units
260 | self.transforms = ModuleList()
261 | self.convs = ModuleList()
262 | for i in range(num_convs):
263 | if i == 0:
264 | trans = FullyConnected(in_channels, conv_channels, bias=True, activation=None)
265 | else:
266 | trans = FullyConnected(in_channels, conv_channels, bias=True, activation=activation)
267 | conv = EdgeConv(conv_channels, num_layers=conv_num_layers, layer_out_dim=conv_layer_out_dim, knn=conv_knn, aggr=conv_aggr, activation=activation)
268 | self.transforms.append(trans)
269 | self.convs.append(conv)
270 | in_channels = conv.out_channels
271 |
272 | @property
273 | def out_channels(self):
274 | return self.convs[-1].out_channels
275 |
276 | def dynamic_graph_forward(self, x):
277 | for i in range(self.num_convs):
278 | x = self.transforms[i](x)
279 | x = self.convs[i](x, x)
280 | return x
281 |
282 | def static_graph_forward(self, pos):
283 | x = pos
284 | for i in range(self.num_convs):
285 | x = self.transforms[i](x)
286 | x = self.convs[i](x, pos)
287 | return x
288 |
289 | def forward(self, x):
290 | if self.dynamic_graph:
291 | return self.dynamic_graph_forward(x)
292 | else:
293 | return self.static_graph_forward(x)
294 |
295 | class FCLayer(torch.nn.Module):  # note: exact duplicate of FullyConnected above
296 |
297 | def __init__(self, in_features, out_features, bias=True, activation=None):
298 | super().__init__()
299 |
300 | self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
301 |
302 | if activation is None:
303 | self.activation = torch.nn.Identity()
304 | elif activation == 'relu':
305 | self.activation = torch.nn.ReLU()
306 | elif activation == 'elu':
307 | self.activation = torch.nn.ELU(alpha=1.0)
308 | elif activation == 'lrelu':
309 | self.activation = torch.nn.LeakyReLU(0.1)
310 | else:
311 | raise ValueError()
312 |
313 | def forward(self, x):
314 | return self.activation(self.linear(x))
315 |
316 |
317 | class Aggregator(torch.nn.Module):
318 |
319 | def __init__(self, oper):
320 | super().__init__()
321 | assert oper in ('mean', 'sum', 'max')
322 | self.oper = oper
323 |
324 | def forward(self, x, dim=2):
325 | if self.oper == 'mean':
326 | return x.mean(dim=dim, keepdim=False)
327 | elif self.oper == 'sum':
328 | return x.sum(dim=dim, keepdim=False)
329 | elif self.oper == 'max':
330 | ret, _ = x.max(dim=dim, keepdim=False)
331 | return ret
332 |
333 |
334 | class ResnetBlockConv1d(nn.Module):
335 | """ 1D-Convolutional ResNet block class.
336 | Args:
337 | size_in (int): input dimension
338 | size_out (int): output dimension
339 | size_h (int): hidden dimension
340 | """
341 |
342 | def __init__(self, size_in, size_out ,
343 | norm_method='batch_norm', legacy=False):
344 | super().__init__()
345 | # Attributes
346 |
347 | self.size_in = size_in
348 | self.size_out = size_out
349 | # Submodules
350 | if norm_method == 'batch_norm':
351 | norm = nn.BatchNorm1d
352 | elif norm_method == 'sync_batch_norm':
353 | norm = nn.SyncBatchNorm
354 | else:
355 | raise Exception("Invalid norm method: %s" % norm_method)
356 |
357 | self.fc_0 = nn.Conv1d(size_in, size_out, 1)
358 | self.actvn = nn.ReLU()
359 | if size_in == size_out:
360 | self.shortcut = None
361 | else:
362 | self.shortcut = nn.Conv1d(size_in, size_out, 1)
363 |
364 | def forward(self, x):
365 | dx = self.fc_0(x)
366 | if self.shortcut is not None:
367 | x_s = self.shortcut(x)
368 | else:
369 | x_s = x
370 |
371 | out = x_s + dx
372 |
373 | return out
374 |
375 | class ResnetBlockConv2d(nn.Module):
376 | """ 2D-Convolutional ResNet block class.
377 | Args:
378 | size_in (int): input dimension
379 | size_out (int): output dimension
380 | """
381 |
382 | def __init__(self, size_in, size_out,
383 | norm_method='batch_norm', legacy=False):
384 | super().__init__()
385 | # Attributes
386 |
387 | self.size_in = size_in
388 | self.size_out = size_out
389 | # Submodules
390 | if norm_method == 'batch_norm':
391 | norm = nn.BatchNorm1d
392 | elif norm_method == 'sync_batch_norm':
393 | norm = nn.SyncBatchNorm
394 | else:
395 | raise Exception("Invalid norm method: %s" % norm_method)
396 |
397 | self.fc_0 = nn.Conv2d(size_in, size_out, (1,1))
398 | self.actvn = nn.ReLU()
399 | if size_in == size_out:
400 | self.shortcut = None
401 | else:
402 |             self.shortcut = nn.Conv2d(size_in, size_out, (1,1))
403 |
404 | def forward(self, x):
405 | dx = self.fc_0(x)
406 | if self.shortcut is not None:
407 | x_s = self.shortcut(x)
408 | else:
409 | x_s = x
410 |
411 | out = x_s + dx
412 |
413 | return out
414 |
415 |
416 | def normalize_sphere(pc, radius=1.0):
417 | """
418 | Args:
419 | pc: A batch of point clouds, (B, N, 3).
420 | """
421 | ## Center
422 | p_max = pc.max(dim=-2, keepdim=True)[0]
423 | p_min = pc.min(dim=-2, keepdim=True)[0]
424 | center = (p_max + p_min) / 2 # (B, 1, 3)
425 | pc = pc - center
426 | ## Scale
427 |     scale = (pc ** 2).sum(dim=-1, keepdim=True).sqrt().max(dim=-2, keepdim=True)[0] / radius # (B, 1, 1)
428 | pc = pc / scale
429 | return pc, center, scale
430 |
431 |
432 | def normalize_std(pc, std=1.0):
433 | """
434 | Args:
435 | pc: A batch of point clouds, (B, N, 3).
436 | """
437 | center = pc.mean(dim=-2, keepdim=True) # (B, 1, 3)
438 | pc = pc - center
439 | scale = pc.view(pc.size(0), -1).std(dim=-1).view(pc.size(0), 1, 1) / std
440 | pc = pc / scale
441 | return pc, center, scale
442 |
443 |
444 | def normalize_pcl(pc, center, scale):
445 | return (pc - center) / scale
446 |
447 |
448 | def denormalize_pcl(pc, center, scale):
449 | return pc * scale + center
450 |
451 |
452 | def chamfer_distance_unit_sphere(gen, ref, batch_reduction='mean', point_reduction='mean'):
453 | ref, center, scale = normalize_sphere(ref)
454 |     gen = normalize_pcl(gen, center, scale)  # for upsampling, skipping these two lines gives a slight improvement, since the inputs are already normalized
455 | return pytorch3d.loss.chamfer_distance(gen, ref, batch_reduction=batch_reduction, point_reduction=point_reduction)
456 |
457 |
458 | def farthest_point_sampling(pcls, num_pnts):
459 | """
460 | Args:
461 | pcls: A batch of point clouds, (B, N, 3).
462 | num_pnts: Target number of points.
463 | """
464 | ratio = 0.01 + num_pnts / pcls.size(1)
465 | sampled = []
466 | indices = []
467 | for i in range(pcls.size(0)):
468 | idx = fps(pcls[i], ratio=ratio, random_start=False)[:num_pnts]
469 | sampled.append(pcls[i:i+1, idx, :])
470 | indices.append(idx)
471 | sampled = torch.cat(sampled, dim=0)
472 | return sampled, indices
473 |
474 |
475 | def point_mesh_bidir_distance_single_unit_sphere(pcl, verts, faces):
476 | """
477 | Args:
478 | pcl: (N, 3).
479 | verts: (M, 3).
480 | faces: LongTensor, (T, 3).
481 |     Returns:
482 |         Scalar: bidirectional point-to-face distance between the cloud and the mesh.
483 |     """
484 | assert pcl.dim() == 2 and verts.dim() == 2 and faces.dim() == 2, 'Batch is not supported.'
485 |
486 | # Normalize mesh
487 | verts, center, scale = normalize_sphere(verts.unsqueeze(0))
488 | verts = verts[0]
489 | # Normalize pcl
490 | pcl = normalize_pcl(pcl.unsqueeze(0), center=center, scale=scale)
491 | pcl = pcl[0]
492 |
493 | # print('%.6f %.6f' % (verts.abs().max().item(), pcl.abs().max().item()))
494 |
495 | # Convert them to pytorch3d structures
496 | pcls = pytorch3d.structures.Pointclouds([pcl])
497 | meshes = pytorch3d.structures.Meshes([verts], [faces])
498 | return pytorch3d.loss.point_mesh_face_distance(meshes, pcls)
499 |
500 |
501 | def pointwise_p2m_distance_normalized(pcl, verts, faces):
502 | assert pcl.dim() == 2 and verts.dim() == 2 and faces.dim() == 2, 'Batch is not supported.'
503 |
504 | # Normalize mesh
505 | verts, center, scale = normalize_sphere(verts.unsqueeze(0))
506 | verts = verts[0]
507 | # Normalize pcl
508 | pcl = normalize_pcl(pcl.unsqueeze(0), center=center, scale=scale)
509 | pcl = pcl[0]
510 |
511 | # Convert them to pytorch3d structures
512 | pcls = pytorch3d.structures.Pointclouds([pcl])
513 | meshes = pytorch3d.structures.Meshes([verts], [faces])
514 |
515 | # packed representation for pointclouds
516 | points = pcls.points_packed() # (P, 3)
517 | points_first_idx = pcls.cloud_to_packed_first_idx()
518 | max_points = pcls.num_points_per_cloud().max().item()
519 |
520 | # packed representation for faces
521 | verts_packed = meshes.verts_packed()
522 | faces_packed = meshes.faces_packed()
523 | tris = verts_packed[faces_packed] # (T, 3, 3)
524 | tris_first_idx = meshes.mesh_to_faces_packed_first_idx()
525 | max_tris = meshes.num_faces_per_mesh().max().item()
526 |
527 | # point to face distance: shape (P,)
528 | point_to_face = point_face_distance(
529 | points, points_first_idx, tris, tris_first_idx, max_points
530 | )
531 | return point_to_face
532 |
533 |
534 | def hausdorff_distance_unit_sphere(gen, ref):
535 | """
536 | Args:
537 | gen: (B, N, 3)
538 | ref: (B, N, 3)
539 | Returns:
540 | (B, )
541 | """
542 | ref, center, scale = normalize_sphere(ref)
543 | gen = normalize_pcl(gen, center, scale)
544 |
545 | dists_ab, _, _ = pytorch3d.ops.knn_points(ref, gen, K=1)
546 | dists_ab = dists_ab[:,:,0].max(dim=1, keepdim=True)[0] # (B, 1)
547 | # print(dists_ab)
548 |
549 | dists_ba, _, _ = pytorch3d.ops.knn_points(gen, ref, K=1)
550 | dists_ba = dists_ba[:,:,0].max(dim=1, keepdim=True)[0] # (B, 1)
551 | # print(dists_ba)
552 |
553 | dists_hausdorff = torch.max(torch.cat([dists_ab, dists_ba], dim=1), dim=1)[0]
554 |
555 | return dists_hausdorff
556 |
557 |
558 | def get_knn_idx(x, y, k, offset=0):
559 | """
560 | Args:
561 | x: (B, N, d)
562 | y: (B, M, d)
563 | Returns:
564 | (B, N, k)
565 | """
566 | _, knn_idx, _ = pytorch3d.ops.knn_points(x, y, K=k+offset)
567 | return knn_idx[:, :, offset:]
568 |
569 |
570 | def knn_group(x: torch.FloatTensor, idx: torch.LongTensor):
571 | """
572 | :param x: (B, N, F)
573 | :param idx: (B, M, k)
574 | :return (B, M, k, F)
575 | """
576 | B, N, F = tuple(x.size())
577 | _, M, k = tuple(idx.size())
578 |
579 | x = x.unsqueeze(1).expand(B, M, N, F)
580 | idx = idx.unsqueeze(3).expand(B, M, k, F)
581 |
582 | return torch.gather(x, dim=2, index=idx)
583 |
584 |
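585 | # Added smoke test (not part of the original module) for the pure-tensor helpers
586 | # and the coarse generator above; expected shapes follow the inline comments.
587 | if __name__ == '__main__':
588 |     for r in (4, 6, 9):
589 |         assert gen_grid(r).shape == (r, 2)        # rate offsets on a num_x * num_y lattice
590 |     x = torch.rand(2, 128, 16)                    # (B, N, F)
591 |     idx = torch.randint(0, 128, (2, 32, 8))       # (B, M, k)
592 |     assert knn_group(x, idx).shape == (2, 32, 8, 16)
593 |     generator = coarsesGenerator(rate=4)
594 |     coarse = generator(torch.rand(2, 64, 3))      # (B, N, 3) -> (B, N*rate, 3)
595 |     assert coarse.shape == (2, 256, 3)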
--------------------------------------------------------------------------------