├── teaser.png ├── ckpts ├── denoise.pth ├── model_u.pth └── upsmodel.pth ├── utils ├── __pycache__ │ ├── misc.cpython-36.pyc │ ├── misc.cpython-38.pyc │ ├── render.cpython-36.pyc │ ├── render.cpython-38.pyc │ ├── denoise.cpython-36.pyc │ ├── denoise.cpython-38.pyc │ ├── evaluate.cpython-36.pyc │ ├── evaluate.cpython-38.pyc │ ├── mitsuba.cpython-36.pyc │ ├── mitsuba.cpython-38.pyc │ ├── transforms.cpython-36.pyc │ └── transforms.cpython-38.pyc ├── datasets │ ├── __pycache__ │ │ ├── patch.cpython-36.pyc │ │ ├── patch.cpython-38.pyc │ │ ├── pcl.cpython-36.pyc │ │ ├── pcl.cpython-38.pyc │ │ ├── toy.cpython-36.pyc │ │ ├── toy.cpython-38.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── __init__.cpython-38.pyc │ ├── __init__.py │ ├── toy.py │ ├── patch.py │ └── pcl.py ├── mesh_to_pcl.py ├── mitsuba.py ├── render.py ├── misc.py ├── denoise.py ├── transforms.py └── evaluate.py ├── scripts ├── __pycache__ │ ├── demo.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── evaluate.cpython-38.pyc │ ├── train_toy.cpython-36.pyc │ ├── train_toy.cpython-38.pyc │ ├── upsample.cpython-36.pyc │ ├── upsample.cpython-38.pyc │ ├── validate.cpython-36.pyc │ ├── validate.cpython-38.pyc │ ├── mesh_to_pcl.cpython-36.pyc │ ├── mesh_to_pcl.cpython-38.pyc │ ├── new_upsample.cpython-36.pyc │ ├── render_noisy.cpython-36.pyc │ ├── render_noisy.cpython-38.pyc │ ├── train_main.cpython-36.pyc │ ├── train_main.cpython-38.pyc │ ├── validate_p2m.cpython-36.pyc │ ├── validate_p2m.cpython-37.pyc │ ├── validate_p2m.cpython-38.pyc │ ├── save_denoised.cpython-36.pyc │ ├── save_denoised.cpython-37.pyc │ ├── save_denoised.cpython-38.pyc │ ├── save_upsample.cpython-36.pyc │ ├── save_upsample.cpython-38.pyc │ ├── train_denoise.cpython-38.pyc │ ├── train_upsample.cpython-38.pyc │ ├── evaluate_upsample.cpython-38.pyc │ ├── validate_largePC.cpython-36.pyc │ └── validate_upsample.cpython-36.pyc ├── save_denoised.py ├── evaluate_upsample.py ├── save_upsample.py ├── evaluate.py ├── train_denoise.py └── train_upsample.py ├── 
models ├── __pycache__ │ ├── common.cpython-36.pyc │ ├── common.cpython-38.pyc │ ├── resampler.cpython-36.pyc │ ├── resampler.cpython-37.pyc │ └── resampler.cpython-38.pyc ├── encoders │ ├── __pycache__ │ │ ├── knn.cpython-36.pyc │ │ ├── knn.cpython-38.pyc │ │ ├── mccnn.cpython-36.pyc │ │ ├── mccnn.cpython-38.pyc │ │ ├── __init__.cpython-36.pyc │ │ └── __init__.cpython-38.pyc │ ├── __init__.py │ └── knn.py ├── vecfields │ ├── __pycache__ │ │ ├── knn.cpython-36.pyc │ │ ├── knn.cpython-38.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── radius.cpython-36.pyc │ │ └── radius.cpython-38.pyc │ ├── __init__.py │ ├── knn.py │ └── radius.py ├── resampler.py └── common.py ├── configs ├── ups.yml └── denoise.yml ├── LICENSE ├── README.md └── environment.yml /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/teaser.png -------------------------------------------------------------------------------- /ckpts/denoise.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/ckpts/denoise.pth -------------------------------------------------------------------------------- /ckpts/model_u.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/ckpts/model_u.pth -------------------------------------------------------------------------------- /ckpts/upsmodel.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/ckpts/upsmodel.pth -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/demo.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/demo.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/render.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/render.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/render.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/render.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- 
/utils/__pycache__/denoise.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/denoise.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/denoise.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/denoise.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/evaluate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/evaluate.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/evaluate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/evaluate.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/mitsuba.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/mitsuba.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/mitsuba.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/mitsuba.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resampler.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/__pycache__/resampler.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/__pycache__/resampler.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/resampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/__pycache__/resampler.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/evaluate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/evaluate.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/train_toy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/train_toy.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/train_toy.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/train_toy.cpython-38.pyc 
-------------------------------------------------------------------------------- /scripts/__pycache__/upsample.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/upsample.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/upsample.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/upsample.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/validate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/validate.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/validate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/validate.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /models/encoders/__pycache__/knn.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/encoders/__pycache__/knn.cpython-36.pyc -------------------------------------------------------------------------------- /models/encoders/__pycache__/knn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/encoders/__pycache__/knn.cpython-38.pyc -------------------------------------------------------------------------------- /models/vecfields/__pycache__/knn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/vecfields/__pycache__/knn.cpython-36.pyc -------------------------------------------------------------------------------- /models/vecfields/__pycache__/knn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/vecfields/__pycache__/knn.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/mesh_to_pcl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/mesh_to_pcl.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/mesh_to_pcl.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/mesh_to_pcl.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/new_upsample.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/new_upsample.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/render_noisy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/render_noisy.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/render_noisy.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/render_noisy.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/train_main.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/train_main.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/train_main.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/train_main.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/validate_p2m.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/validate_p2m.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/validate_p2m.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/validate_p2m.cpython-37.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/validate_p2m.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/validate_p2m.cpython-38.pyc -------------------------------------------------------------------------------- /utils/datasets/__pycache__/patch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/datasets/__pycache__/patch.cpython-36.pyc -------------------------------------------------------------------------------- /utils/datasets/__pycache__/patch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/datasets/__pycache__/patch.cpython-38.pyc -------------------------------------------------------------------------------- /utils/datasets/__pycache__/pcl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/datasets/__pycache__/pcl.cpython-36.pyc -------------------------------------------------------------------------------- /utils/datasets/__pycache__/pcl.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/datasets/__pycache__/pcl.cpython-38.pyc -------------------------------------------------------------------------------- /utils/datasets/__pycache__/toy.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/datasets/__pycache__/toy.cpython-36.pyc -------------------------------------------------------------------------------- /utils/datasets/__pycache__/toy.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/utils/datasets/__pycache__/toy.cpython-38.pyc -------------------------------------------------------------------------------- /models/encoders/__pycache__/mccnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/encoders/__pycache__/mccnn.cpython-36.pyc -------------------------------------------------------------------------------- /models/encoders/__pycache__/mccnn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/encoders/__pycache__/mccnn.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/save_denoised.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/save_denoised.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/save_denoised.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/save_denoised.cpython-37.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/save_denoised.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/save_denoised.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/save_upsample.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/save_upsample.cpython-36.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/save_upsample.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/save_upsample.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/train_denoise.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/train_denoise.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/train_upsample.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/train_upsample.cpython-38.pyc -------------------------------------------------------------------------------- /models/encoders/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/encoders/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/vecfields/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/vecfields/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/vecfields/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/vecfields/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/vecfields/__pycache__/radius.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/vecfields/__pycache__/radius.cpython-36.pyc -------------------------------------------------------------------------------- /models/vecfields/__pycache__/radius.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/models/vecfields/__pycache__/radius.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/evaluate_upsample.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenhLiwnl/deep-rs/HEAD/scripts/__pycache__/evaluate_upsample.cpython-38.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/validate_largePC.cpython-36.pyc: -------------------------------------------------------------------------------- 
def get_vecfield(cfg, ctx_point_feature_dim):
    """Build the vector-field module selected by ``cfg.type``.

    ``'knn'`` returns a ``KNearestVectorField`` configured from ``cfg.knn``
    and ``cfg.radius``; ``'radius'`` returns a ``RadiusVectorField``
    configured from ``cfg.radius``, ``cfg.num_points`` and ``cfg.style``.
    ``ctx_point_feature_dim`` is forwarded unchanged in both cases.

    Raises:
        NotImplementedError: if ``cfg.type`` is neither ``'knn'`` nor
            ``'radius'``.
    """
    kind = cfg.type

    if kind == 'knn':
        return KNearestVectorField(
            knn=cfg.knn,
            radius=cfg.radius,
            ctx_point_feature_dim=ctx_point_feature_dim,
        )

    if kind == 'radius':
        return RadiusVectorField(
            radius=cfg.radius,
            ctx_point_feature_dim=ctx_point_feature_dim,
            num_points=cfg.num_points,
            style=cfg.style,
        )

    # Anything else is an unsupported configuration value.
    raise NotImplementedError('Vecfield `%s` is not implemented.' % kind)
class ToyPointCloudDataset(Dataset):
    """Synthetic dataset of random point clouds lying on a plane or a sphere.

    Each access draws a new ``(num_pnts, 3)`` sample with ``torch.rand``;
    nothing is cached, so repeated reads of the same index differ unless the
    torch RNG is reseeded.
    """

    def __init__(self, shape='plane', num_pnts=10000, size=24, transform=None):
        """Store the sampling parameters.

        Args:
            shape: ``'plane'`` or ``'sphere'`` (checked with an assert).
            num_pnts: number of points generated per sample.
            size: value reported by ``len()`` and upper bound for indices.
            transform: optional callable applied to the sample dict.
        """
        super().__init__()
        assert shape in ('plane', 'sphere')
        self.shape, self.size = shape, size
        self.num_pnts, self.transform = num_pnts, transform

    def __len__(self):
        """Return the nominal number of samples."""
        return self.size

    def __getitem__(self, idx):
        """Generate one sample dict with the key ``'pcl_clean'``.

        Raises:
            IndexError: if ``idx >= self.size``.
        """
        if idx >= self.size:
            raise IndexError()

        points = torch.rand([self.num_pnts, 3])
        if self.shape == 'plane':
            # Flatten onto the z = 0 plane; x and y stay uniform in [0, 1).
            points[:, 2] = 0
        elif self.shape == 'sphere':
            # Centre the unit cube on the origin, then scale every row to
            # unit Euclidean length.
            centered = points - 0.5
            points = centered / centered.pow(2).sum(dim=1, keepdim=True).sqrt()

        sample = {
            'pcl_clean': points,
        }
        return sample if self.transform is None else self.transform(sample)
documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/denoise.yml: -------------------------------------------------------------------------------- 1 | # demo configuration 2 | model: 3 | 4 | encoder: 5 | type: knn 6 | dynamic_graph: false 7 | num_convs: 7 8 | conv_channels: 48 9 | conv_num_fc_layers: 5 10 | conv_growth_rate: 24 11 | conv_knn: 30 12 | vecfield: 13 | type: knn 14 | knn: 24 15 | style : normal 16 | radius: 0.2 17 | num_points : 24 18 | raise_xyz_channels: 48 19 | 20 | train: 21 | seed: 2021 22 | train_batch_size: 4 23 | num_workers: 4 24 | lr: 5.e-4 25 | weight_decay: 0. 
26 | vec_avg_knn: 4 27 | max_iters: 1000000 28 | val_freq: 2000 29 | scheduler: 30 | factor: 0.6 31 | patience: 100 32 | threshold : 1.e-5 33 | 34 | dataset: 35 | method : deeprs 36 | base_dir : # directory that you save your denoising results 37 | dataset_root: # clean point cloud for training and evaluation 38 | input_root: # noisy data for evaluation 39 | dataset : PUNet 40 | resolutions: 41 | - '10000_poisson' 42 | - '30000_poisson' 43 | - '10000_poisson' 44 | patch_size: 1000 45 | num_pnts: 10000 46 | noise_min: 0.005 47 | noise_max: 0.030 48 | aug_rotate: true 49 | val_noise: 0.02 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Point Set Resampling via Gradient Fields 2 | 3 | teaser 4 | 5 | [Paper](https://arxiv.org/abs/2111.02045) [code](https://github.com/ChenhLiwnl/deep-rs) 6 | 7 | The official code repository for our TPAMI paper 'Deep Point Set Resampling via Gradient Fields'. 8 | 9 | 10 | 11 | ## Installation 12 | 13 | You can install via conda environment .yaml file 14 | 15 | ```bash 16 | conda env create -f environment.yml 17 | conda activate deeprs 18 | ``` 19 | 20 | ## Dataset 21 | 22 | to release soon 23 | 24 | ## Training and Testing 25 | 26 | To train or test your model, you should first edit the corresponding config file according to the comments (fill in the file directory). 
Then, you can simply run 27 | 28 | ```bash 29 | ## Train a network for denoising 30 | python -m scripts.train_denoise --config ./configs/denoise.yml 31 | ## Train a network for upsampling 32 | python -m scripts.train_upsample --config ./configs/ups.yml 33 | ``` 34 | 35 | for training, and run 36 | 37 | ```bash 38 | ## Save your denoising results 39 | python -m scripts.save_denoised --config ./configs/denoise.yml [--ckpt] 40 | ## Save your upsampling results 41 | python -m scripts.save_upsample --config ./configs/ups.yml [--ckpt_model] [--ckpt_gen] 42 | ``` 43 | 44 | for testing. 45 | 46 | ## Evaluation 47 | 48 | You can run 49 | 50 | ```bash 51 | python -m scripts.evaluate --config ./configs/denoise.yml 52 | python -m scripts.evaluate_upsample --config ./configs/ups.yml 53 | ``` 54 | 55 | to evaluate your results. 56 | 57 | ## Citation 58 | 59 | If you find this work helpful, please cite 60 | 61 | ``` 62 | @ARTICLE{9775211, 63 | author={Chen, Haolan and Du, Bi'an and Luo, Shitong and Hu, Wei}, 64 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 65 | title={Deep Point Set Resampling via Gradient Fields}, 66 | year={2022}, 67 | volume={}, 68 | number={}, 69 | pages={1-1}, 70 | doi={10.1109/TPAMI.2022.3175183}} 71 | ``` 72 | 73 | and contact chenhl99@pku.edu.cn with any questions. 
-------------------------------------------------------------------------------- /scripts/save_denoised.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import shutil 4 | from types import CodeType 5 | import yaml 6 | from easydict import EasyDict 7 | from tqdm.auto import tqdm 8 | import random 9 | import torch 10 | from torch.utils.data import DataLoader 11 | import torch.utils.tensorboard 12 | from models.resampler import PointSetResampler 13 | from models.common import chamfer_distance_unit_sphere 14 | from utils.datasets import PointCloudDataset, PairedPatchDataset 15 | from utils.transforms import * 16 | from utils.misc import * 17 | from utils.denoise import patch_based_denoise , patch_based_denoise_big,denoise_large_pointcloud 18 | from .validate_p2m import rotate 19 | def denormalize ( pc , center , scale): 20 | return pc * scale + center 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--config', type=str) 24 | parser.add_argument('--device', type=str, default='cuda') 25 | parser.add_argument('--log_root', type=str, default='./logs') 26 | parser.add_argument('--ckpt' , type = str ,default = './ckpts/denoise.pth' ) 27 | args = parser.parse_args() 28 | 29 | # Load configs 30 | with open(args.config, 'r') as f: 31 | config = EasyDict(yaml.safe_load(f)) 32 | config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')] 33 | seed_all(config.train.seed) 34 | 35 | # Datasets and loaders 36 | 37 | val_dset = PointCloudDataset( 38 | root=config.dataset.dataset_root, 39 | dataset=config.dataset.dataset, 40 | split='test', 41 | resolution=config.dataset.resolutions[0], 42 | need_mesh=False, 43 | from_saved= True, 44 | noise_level = config.dataset.val_noise, 45 | input_root = config.dataset.input_root, 46 | transform = NormalizeUnitSphere() , 47 | ) 48 | print("resolution is %s" % (config.dataset.resolutions[0])) 49 | # 
Model 50 | model = PointSetResampler(config.model).to(args.device) 51 | model.load_state_dict(torch.load(args.ckpt)) 52 | def validate(it): 53 | model.eval() 54 | filepath = config.dataset.dataset+"_" + config.dataset.method+"_" + config.dataset.resolutions[0] +"_" +str(config.dataset.val_noise) 55 | file = os.path.join( config.dataset.base_dir, filepath ) 56 | os.makedirs(file, exist_ok=True) 57 | print(file) 58 | for i, data in enumerate(tqdm(val_dset, desc='Validate')): 59 | pcl_noisy = data['pcl_clean'].to(args.device) 60 | center = data['center'].to(args.device) 61 | scale = data['scale'].to(args.device) 62 | pcl_denoised = patch_based_denoise(model, pcl_noisy) 63 | pcl_denoised = denormalize(pcl_denoised , center , scale) 64 | filename = os.path.join(file, data['name']+".xyz") 65 | np.savetxt( filename, pcl_denoised.cpu().numpy()) 66 | 67 | print('[Val] noise %s , %s | Finished Saving ' % (config.dataset.val_noise, config.dataset.dataset)) 68 | 69 | # Main loop 70 | try: 71 | cd_loss = validate(0) 72 | 73 | except KeyboardInterrupt: 74 | print('Terminating...') 75 | -------------------------------------------------------------------------------- /utils/mesh_to_pcl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import multiprocessing as mp 4 | import logging 5 | import numpy as np 6 | import point_cloud_utils as pcu 7 | logging.basicConfig(level=logging.DEBUG) 8 | 9 | ''' 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--dataset_root', type=str, default='./data/PUNet') 12 | args = parser.parse_args() 13 | 14 | DATASET_ROOT = args.dataset_root 15 | ''' 16 | DATASET_ROOT = '/172.31.222.52_data/chenhaolan/Mixed' 17 | MESH_ROOT = os.path.join(DATASET_ROOT, 'meshes') 18 | SAVE_ROOT = '/172.31.222.52_data/chenhaolan/Mixed/' 19 | POINTCLOUD_ROOT = os.path.join(SAVE_ROOT, 'pointclouds') 20 | RESOLUTIONS = [1024, 2048 , 4096, 8192 , 16384 , 32768 ,2048*12 ,2048*20 , 2048*32 , 
def enum_meshes():
    """Enumerate every (mesh file, sampling configuration) work item.

    For each ``(subset, resolution, sampler)`` produced by ``enum_configs``,
    the ``.off`` meshes under ``MESH_ROOT/<subset>`` are listed. Subsets whose
    mesh directory does not exist are skipped silently.

    Yields:
        Tuples ``(subset, resolution, sampler, in_file, out_file)`` where
        ``out_file`` is
        ``POINTCLOUD_ROOT/<subset>/<resolution>_<sampler>/<basename>.xyz``.
    """
    for subset, resolution, sampler in enum_configs():
        in_dir = os.path.join(MESH_ROOT, subset)
        out_dir = os.path.join(POINTCLOUD_ROOT, subset, '%d_%s' % (resolution, sampler))
        if not os.path.exists(in_dir):
            continue
        for fn in os.listdir(in_dir):
            # Split on the real extension: comparing the last three characters
            # also accepted names such as "takeoff", and slicing four
            # characters off then removed a character that was not a dot.
            basename, ext = os.path.splitext(fn)
            if ext != '.off':
                continue
            yield (subset, resolution, sampler,
                   os.path.join(in_dir, fn),
                   os.path.join(out_dir, basename + '.xyz'))
else: 79 | raise ValueError('Unknown sampler: ' + sampler) 80 | np.savetxt(out_file, pointcloud, '%.6f') 81 | 82 | 83 | if __name__ == '__main__': 84 | if NUM_WORKERS > 1: 85 | with mp.Pool(processes=NUM_WORKERS) as pool: 86 | pool.map(process, enum_meshes()) 87 | else: 88 | for args in enum_meshes(): 89 | process(args) -------------------------------------------------------------------------------- /models/vecfields/knn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module, Linear, Conv2d, BatchNorm2d, Conv1d, BatchNorm1d, Sequential, ReLU 3 | import pytorch3d.ops 4 | import numpy as np 5 | 6 | 7 | class KNearestVectorField(Module): 8 | 9 | def __init__( 10 | self, 11 | knn, 12 | ctx_point_feature_dim, 13 | raise_xyz_channels=64, 14 | point_dim=3, 15 | radius = 0.1, 16 | hidden_dims_pointwise=[128, 128, 256], 17 | hidden_dims_global=[128, 64] 18 | ): 19 | super().__init__() 20 | self.knn = knn 21 | self.radius = radius 22 | self.raise_xyz = Linear(3, raise_xyz_channels) 23 | print(self.radius) 24 | dims = [raise_xyz_channels+ctx_point_feature_dim] + hidden_dims_pointwise 25 | conv_layers = [] 26 | for i in range(len(dims)-1): 27 | conv_layers += [ 28 | Conv2d(dims[i], dims[i+1], kernel_size=(1, 1)), 29 | BatchNorm2d(dims[i+1]), 30 | ] 31 | if i < len(dims)-2: 32 | conv_layers += [ 33 | ReLU(), 34 | ] 35 | self.pointwise_convmlp = Sequential(*conv_layers) 36 | 37 | dims = [hidden_dims_pointwise[-1]] + hidden_dims_global + [point_dim] 38 | conv_layers = [] 39 | for i in range(len(dims)-1): 40 | conv_layers += [ 41 | Conv1d(dims[i], dims[i+1], kernel_size=1), 42 | ] 43 | if i < len(dims)-2: 44 | conv_layers += [ 45 | BatchNorm1d(dims[i+1]), 46 | ReLU(), 47 | ] 48 | self.global_convmlp = Sequential(*conv_layers) 49 | 50 | 51 | def forward(self, p_query, p_context, h_context): 52 | """ 53 | Args: 54 | p_query: Query point set, (B, N_query, 3). 
55 | p_context: Context point set, (B, N_ctx, 3). 56 | h_context: Point-wise features of the context point set, (B, N_ctx, H_ctx). 57 | Returns: 58 | (B, N_query, 3) 59 | """ 60 | b , N_query , _ = p_query.shape 61 | dist, knn_idx, knn_points = pytorch3d.ops.knn_points( 62 | p1=p_query, 63 | p2=p_context, 64 | K=self.knn, 65 | return_nn=True 66 | ) # (B, N_query,K), (B, N_query, K), (B, N_query, K, 3) 67 | # Relative coordinates and their embeddings 68 | p_rel = knn_points - p_query.unsqueeze(-2) # (B, N_query, K, 3) 69 | 70 | h_rel = self.raise_xyz(p_rel) # (B, N_query, K, H_rel) 71 | 72 | # Grouped features of neighbor points 73 | h_group = pytorch3d.ops.knn_gather( 74 | x=h_context, 75 | idx=knn_idx, 76 | ) # (B, N_query, K, H_ctx) 77 | 78 | # Combine 79 | h_combined = torch.cat([h_rel, h_group], dim=-1) # (B, N_query, K, H_rel+H_ctx) 80 | 81 | # Featurize 82 | h_combined = h_combined.permute(0, 3, 1, 2).contiguous() # (B, H_rel+H_ctx, N_query, K) 83 | y = self.pointwise_convmlp(h_combined) # (B, H_out, N_query, K) 84 | y = y.permute(0, 2, 3, 1).contiguous() # (B, N_query, K, H_out) 85 | dist = torch.sqrt(dist) 86 | dist = dist.unsqueeze(-1) #(B, N_query, K, 1) 87 | c = 0.5*(torch.cos(dist * np.pi / self.radius) + 1.0) 88 | c = c * (dist <= self.radius) * (dist > 0.0) # (B, N_query, K, 1) 89 | #c = 1 90 | y = torch.mul(y , c).permute(0, 3, 1, 2).contiguous() 91 | y = y.sum(-1) # (B, H_out, N_query) 92 | #y, _ = torch.max(y, dim=3) # (B, H_out, N_query) 93 | 94 | # Vectorize 95 | y = self.global_convmlp(y) # (B, 3, N_query) 96 | y = y.permute(0, 2, 1).contiguous() # (B, N_query, 3) 97 | return y 98 | -------------------------------------------------------------------------------- /scripts/evaluate_upsample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import shutil 4 | import yaml 5 | from easydict import EasyDict 6 | from tqdm.auto import tqdm 7 | import random 8 | import torch 9 | from 
torch.utils.data import DataLoader 10 | import torch.utils.tensorboard 11 | from models.resampler import PointSetResampler 12 | from models.common import chamfer_distance_unit_sphere 13 | from utils.datasets import UpsampleDataset, PairedPatchDataset 14 | from utils.transforms import * 15 | from utils.misc import * 16 | from utils.evaluate import * 17 | from utils.denoise import patch_based_upsample , patch_based_upsample_big 18 | 19 | def denormalize ( pc , center , scale): 20 | return pc * scale + center 21 | def validate_cd ( data , file): 22 | pcl_clean = data['gt'].to(args.device) 23 | name = data['name'] 24 | pcl_denoised = np.loadtxt(os.path.join(file , name+".xyz")) 25 | pcl_denoised = torch.from_numpy(pcl_denoised).type(torch.FloatTensor).to(args.device) 26 | chamfer = chamfer_distance_unit_sphere(pcl_clean.unsqueeze(0),pcl_denoised.unsqueeze(0), batch_reduction='mean')[0].item() 27 | return chamfer 28 | def validate_p2m(data , file): 29 | verts = data['meshes']['verts'].to(args.device) 30 | faces = data['meshes']['faces'].to(args.device) 31 | name = data['name'] 32 | pcl_denoised = np.loadtxt(os.path.join(file , name+".xyz")) 33 | pcl_denoised = torch.from_numpy(pcl_denoised).type(torch.FloatTensor).to(args.device) 34 | p2f = point_mesh_bidir_distance_single_unit_sphere( 35 | pcl=pcl_denoised, 36 | verts=verts, 37 | faces=faces 38 | ).mean() 39 | return p2f 40 | def validate_hd(data , file): 41 | ref = data['gt'].to(args.device) 42 | name = data['name'] 43 | gen = np.loadtxt(os.path.join(file , name+".xyz")) 44 | gen = torch.from_numpy(gen).type(torch.FloatTensor).to(args.device) 45 | hd = hausdorff_distance_unit_sphere(gen = gen.unsqueeze(0) , ref = ref.unsqueeze(0)) 46 | return hd 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--config', type=str) 50 | parser.add_argument('--device', type=str, default='cuda') 51 | parser.add_argument('--log_root', type=str, default='./logs') 52 | args = 
parser.parse_args() 53 | 54 | # Load configs 55 | with open(args.config, 'r') as f: 56 | config = EasyDict(yaml.safe_load(f)) 57 | config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')] 58 | seed_all(config.train.seed) 59 | 60 | # Datasets and loaders 61 | 62 | val_dset = UpsampleDataset( 63 | root=config.dataset.dataset_root, 64 | dataset=config.dataset.dataset, 65 | split='test', 66 | resolution=config.dataset.resolutions[1], 67 | rate = 16, 68 | noise_min = 0, 69 | noise_max = 0, 70 | need_mesh=True 71 | ) 72 | print("resolution is %s" % (config.dataset.resolutions[1])) 73 | filepath = config.dataset.dataset+"_" + config.dataset.method+"_" + config.dataset.resolutions[0] +"_" + str(val_dset.rate) + "x" 74 | file = os.path.join( config.dataset.base_dir, filepath ) 75 | print(file) 76 | def validate(it): 77 | global best 78 | avg_chamfer = 0 79 | total_chamfer = 0 80 | total_p2m = 0 81 | total_hd = 0 82 | for i, data in enumerate(tqdm(val_dset, desc='Validate')): 83 | total_p2m += validate_p2m(data,file) 84 | total_chamfer += validate_cd(data, file) 85 | avg_chamfer = total_chamfer / len(val_dset) 86 | avg_p2m = total_p2m / len(val_dset) 87 | print('[Val] noise %.6f | CD %.8f ' % (config.dataset.val_noise, avg_chamfer)) 88 | print('[Val] noise %.6f | P2M%.8f ' % (config.dataset.val_noise, avg_p2m)) 89 | 90 | # Main loop 91 | try: 92 | cd_loss = validate(0) 93 | 94 | except KeyboardInterrupt: 95 | print('Terminating...') 96 | -------------------------------------------------------------------------------- /scripts/save_upsample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import shutil 4 | import yaml 5 | from easydict import EasyDict 6 | from tqdm.auto import tqdm 7 | import random 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import torch.utils.tensorboard 11 | from models.resampler import PointSetResampler 12 | from models.common 
def addnoise(pcl, noise):
    """Return ``pcl`` perturbed by zero-mean Gaussian noise.

    Args:
        pcl: Point cloud tensor. The noise sample is drawn with
            ``torch.randn_like`` and therefore matches it in shape, dtype
            and device.
        noise: Standard deviation of the noise; a scalar or a tensor that
            broadcasts against ``pcl``.

    Returns:
        A new tensor ``pcl + eps * noise``. The input tensor is not modified.
    """
    # The previous body stored the result into an undefined ``data`` dict and
    # read an undefined ``noise_std``, so every call raised NameError and the
    # noisy cloud was never handed back. Returning the tensor is the only
    # result this function can produce from its own arguments.
    return pcl + torch.randn_like(pcl) * noise
coarsesGenerator(rate = 16 ).to(args.device) 63 | upsampler.load_state_dict(torch.load(args.ckpt_gen)) 64 | 65 | def validate(it): 66 | model.eval() 67 | upsampler.eval() 68 | all_clean = [] 69 | all_denoised = [] 70 | filepath = config.dataset.dataset+"_" + config.dataset.method+"_" + config.dataset.resolutions[0] +"_" + str(val_dset.rate) + "x" 71 | filepath = os.path.join( config.dataset.base_dir, filepath ) 72 | os.makedirs(filepath, exist_ok=True) 73 | for i, data in enumerate(tqdm(val_dset, desc='Validate')): 74 | with torch.no_grad(): 75 | pcl_noisy = data['ups'].to(args.device) 76 | pcl_low = data['original'].to(args.device) 77 | pcl_clean = data['gt'].to(args.device) 78 | pcl_denoised = upsampler(pcl_low.unsqueeze(0)).squeeze(0) 79 | file = os.path.join(filepath , data['name']+".xyz" ) 80 | np.savetxt( file,pcl_denoised.detach().cpu().numpy() ) 81 | all_clean.append(pcl_clean.unsqueeze(0)) 82 | all_denoised.append(pcl_denoised.unsqueeze(0)) 83 | all_clean = torch.cat(all_clean, dim=0) 84 | all_denoised = torch.cat(all_denoised, dim=0) 85 | avg_chamfer = chamfer_distance_unit_sphere(all_denoised, all_clean, batch_reduction='mean')[0].item() 86 | print('[Val] Iter %04d | CD %.8f ' % (it, avg_chamfer)) 87 | 88 | # Main loop 89 | try: 90 | cd_loss = validate(0) 91 | 92 | except KeyboardInterrupt: 93 | print('Terminating...') 94 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import shutil 4 | import yaml 5 | from easydict import EasyDict 6 | from tqdm.auto import tqdm 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import torch.utils.tensorboard 11 | from models.resampler import PointSetResampler 12 | from models.common import chamfer_distance_unit_sphere 13 | from utils.datasets import PointCloudDataset, PairedPatchDataset 14 | from utils.transforms import * 15 | 
def rotate(pc, degree):
    """Rotate a point cloud about the x-axis.

    Args:
        pc: Tensor whose last dimension holds the 3 coordinates; it is
            right-multiplied by a 3x3 matrix.
        degree: Rotation angle in degrees.

    Returns:
        ``pc @ R`` with ``R = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]]``,
        on the same device as ``pc``.
    """
    # Imported locally so the function does not depend on ``math`` leaking in
    # through one of the file's star-imports.
    import math

    angle = math.pi * degree / 180.0
    sin, cos = math.sin(angle), math.cos(angle)
    # Build the matrix on the input's device (it used to be pinned to "cuda",
    # which broke CPU runs) and in the input's floating dtype so matmul does
    # not fail on float64 clouds.
    dtype = pc.dtype if pc.is_floating_point() else torch.get_default_dtype()
    matrix = torch.tensor(
        [[1, 0, 0], [0, cos, sin], [0, -sin, cos]],
        dtype=dtype,
        device=pc.device,
    )
    return torch.matmul(pc, matrix)
| val_noise = str(config.dataset.val_noise) 71 | filepath = config.dataset.dataset+ "_"+method +"_"+ resolution +"_" + val_noise 72 | file = os.path.join(config.dataset.base_dir, filepath ) 73 | #filepath = '/172.31.222.52_data/luost/denoisegf_data/results' 74 | file = os.path.join(filepath,file) 75 | print(file) 76 | def validate(it): 77 | global best 78 | avg_chamfer = 0 79 | avg_p2f = 0 80 | for i, data in enumerate(tqdm(val_dset, desc='Validate')): 81 | pcl_clean = data['pcl_clean'].to(args.device) 82 | verts = data['meshes']['verts'].to(args.device) 83 | faces = data['meshes']['faces'].to(args.device) 84 | name = data['name'] 85 | pcl_denoised = np.loadtxt(os.path.join(file , name+".xyz")) 86 | pcl_denoised = torch.from_numpy(pcl_denoised).type(torch.FloatTensor).to(args.device) 87 | if val_noise == 'blensor': 88 | pcl_denoised = rotate(pcl_denoised , -90) # blensor requires rotation, -90 89 | avg_p2f += point_mesh_bidir_distance_single_unit_sphere( 90 | pcl=pcl_denoised, 91 | verts=verts, 92 | faces=faces 93 | ).mean() 94 | pcl_denoised, _ , _ = normalize(pcl_denoised,data['center'].cuda(),data['scale'].cuda()) 95 | avg_chamfer += chamfer_distance_unit_sphere(pcl_clean.unsqueeze(0),pcl_denoised.unsqueeze(0), batch_reduction='mean')[0].item() 96 | 97 | avg_chamfer /= len(val_dset) 98 | avg_p2f /= len(val_dset) 99 | print('[Val] noise %s | CD %.8f ' % (config.dataset.val_noise, avg_chamfer)) 100 | print('[Val] noise %s | P2M %.8f ' % (config.dataset.val_noise, avg_p2f)) 101 | 102 | # Main loop 103 | try: 104 | cd_loss = validate(0) 105 | 106 | except KeyboardInterrupt: 107 | print('Terminating...') 108 | -------------------------------------------------------------------------------- /utils/mitsuba.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def standardize_bbox(pcl): 5 | mins = np.amin(pcl, axis=0) 6 | maxs = np.amax(pcl, axis=0) 7 | center = ( mins + maxs ) / 2. 
def select_downsample(pcl, noise, n_points, low_noise=False):
    """Select ``n_points`` rows from a point cloud.

    Args:
        pcl: Array of points indexed along axis 0 (n by 3 in this file).
        noise: Per-point noise values; flattened and used only when
            ``low_noise`` is true.
        n_points: Number of points to keep.
        low_noise: When true, keep the ``n_points`` points with the smallest
            noise (in no particular order); otherwise keep a uniformly
            random subset.

    Returns:
        A new array holding the selected rows. ``pcl`` itself is left
        untouched.
    """
    total = pcl.shape[0]
    if low_noise:
        if n_points >= total:
            # argpartition needs kth < N; nothing has to be discarded here.
            return pcl.copy()
        idx = np.argpartition(noise.flatten(), n_points)
        return pcl[idx[:n_points]]  # n by 3

    # Shuffle an index vector rather than the data: np.random.shuffle(pcl)
    # reordered the caller's array in place as a hidden side effect.
    idx = np.random.permutation(total)
    return pcl[idx[:n_points]]
class NoisyPointCloudXMLMaker(object):
    """Build a scene XML in which every point is coloured by its error.

    Points with zero error are drawn blue and points whose error reaches
    ``max_noise`` are drawn yellow; values in between are blended.
    """

    def __init__(self, point_size=0.01, max_noise=0.03):
        """
        Args:
            point_size: Forwarded to ``make_xml`` as the sphere ``radius``.
            max_noise: Error magnitude that maps to the fully "noisy" colour;
                larger errors are clipped to it.
        """
        super().__init__()
        self.point_size = point_size
        self.max_noise = max_noise

    def get_color(self, noise, showing_noisy=False):
        """Map per-point error magnitudes to RGB colours.

        Args:
            noise: (N, 1) array of error magnitudes.
            showing_noisy: Selects the base of the exponential ramp applied
                to the normalised error (10 when true, 25 otherwise).

        Returns:
            (N, 3) float array of RGB values in [0, 1].
        """
        max_noise = self.max_noise

        N = noise.shape[0]
        noise_level = np.clip(noise / max_noise, 0, 1)  # (N, 1)

        # ``np.float`` was only an alias of the builtin float and has been
        # removed from NumPy (1.24+); float64 is the dtype it resolved to.
        base_color = np.repeat(np.array([[0, 0, 1]], dtype=np.float64), N, axis=0)   # Blue, (N, 3)
        noise_color = np.repeat(np.array([[1, 1, 0]], dtype=np.float64), N, axis=0)  # Yellow
        # Exponential ramp: keeps small errors close to the base colour.
        if showing_noisy:
            noise_level = (np.power(10, noise_level)-1) / (np.power(10, 1)-1)
        else:
            noise_level = (np.power(25, noise_level)-1) / (np.power(25, 1)-1)

        # Blend in inverted colour space, then invert back.
        inv_mix = (1-noise_level)*(1-base_color) + noise_level*(1-noise_color)
        return 1 - inv_mix  # (N, 3)

    def render(self, pcl, verts, faces, rotation=None, lookat=(2.0, 30, -45), distance='p2m'):
        """Produce the scene XML for a point cloud coloured by its distance to a mesh.

        Args:
            pcl: Point cloud tensor; presumably (N, 3) -- TODO confirm against callers.
            verts: Mesh vertices tensor, normalised with ``normalize_sphere``.
            faces: Mesh faces tensor, used only when ``distance == 'p2m'``.
            rotation: Optional matrix; when given, ``pcl`` is multiplied by
                its transpose before export.
            lookat: Accepted for interface compatibility; not used here.
            distance: ``'p2m'`` for point-to-mesh distance or ``'nn'`` for the
                distance to the nearest mesh vertex.

        Returns:
            The XML string produced by ``make_xml``.

        Raises:
            AssertionError: If ``distance`` is neither ``'p2m'`` nor ``'nn'``.
        """
        assert distance in ('p2m', 'nn')

        device = pcl.device
        # Normalize mesh
        verts, center, scale = normalize_sphere(verts.unsqueeze(0))
        verts = verts[0]
        # Normalize pcl with the mesh's center/scale so both share one frame
        pcl = normalize_pcl(pcl.unsqueeze(0), center=center, scale=scale)
        pcl = pcl[0]
        if distance == 'p2m':
            # Compute point-to-surface distance
            p2m = pointwise_p2m_distance_normalized(pcl.to(device), verts.to(device), faces.to(device)).sqrt().unsqueeze(-1)
        elif distance == 'nn':
            # Use the ``knn_points`` name this file imports explicitly rather
            # than relying on ``pytorch3d`` being bound via a star-import.
            # NOTE(review): pytorch3d documents knn_points as returning
            # squared distances and no sqrt is taken here, unlike the 'p2m'
            # branch -- confirm whether that is intended.
            p2m, _, _ = knn_points(pcl.unsqueeze(0), verts.unsqueeze(0), K=1)
            p2m = p2m[0,:,:].mean(dim=-1,keepdim=True)
        print(p2m.max() , p2m.mean() , p2m.min())

        # Rotate point cloud
        if rotation is not None:
            pcl = torch.matmul(pcl, rotation.t())

        xml = make_xml(pcl.cpu().numpy(), self.get_color(p2m.cpu().numpy()), radius=self.point_size , max_points=None)
        return xml
class RadiusVectorField(Module):
    """Regress one vector per query point from nearby context points.

    For every query point, candidate neighbours are taken from the
    ``max_points`` nearest context points. Candidates whose distance value
    (as returned by ``pytorch3d.ops.knn_points``) exceeds ``radius ** 2`` are
    discarded, and at most ``num_points`` of the remaining ones are kept.
    Relative coordinates are lifted by a linear layer, concatenated with the
    context features, passed through a point-wise conv MLP, reduced over the
    neighbourhood with a cosine-window weighted sum, and finally mapped to a
    ``point_dim``-dimensional vector by a second conv MLP.
    """

    # Layer families accepted by the ``style`` argument.
    _STYLES = ("normal", "residual")

    def __init__(
        self,
        radius,
        num_points,
        ctx_point_feature_dim,
        style="normal",
        max_points=60,
        raise_xyz_channels=64,
        point_dim=3,
        hidden_dims_pointwise=(128, 128, 256),
        hidden_dims_global=(128, 64),
    ):
        """
        Args:
            radius: Neighbourhood radius used to filter candidate neighbours.
            num_points: Maximum number of neighbours kept per query point.
            ctx_point_feature_dim: Channel count of the context features.
            style: ``"normal"`` (1x1 convolutions) or ``"residual"``
                (``ResnetBlockConv*`` blocks).
            max_points: Number of nearest candidates fetched before the
                radius filter; should be at least ``num_points``.
            raise_xyz_channels: Output width of the relative-coordinate lift.
            point_dim: Dimension of the predicted vectors.
            hidden_dims_pointwise: Widths of the point-wise conv MLP.
            hidden_dims_global: Hidden widths of the final conv MLP.
        Raises:
            ValueError: If ``style`` is not one of the supported values.
        """
        super().__init__()
        # Previously an unknown style silently produced a network without any
        # convolution layers; fail early with a clear message instead.
        if style not in self._STYLES:
            raise ValueError(
                "Unknown style %r, expected one of %r" % (style, self._STYLES)
            )
        # Accept any sequence (list from a config file, tuple default, ...).
        hidden_dims_pointwise = list(hidden_dims_pointwise)
        hidden_dims_global = list(hidden_dims_global)

        self.radius = radius
        self.style = style
        self.num_points = num_points
        self.max_points = max_points
        self.raise_xyz = Linear(3, raise_xyz_channels)
        residual = style == "residual"

        # Point-wise MLP over (B, C, N_query, num_points) tensors. Every stage
        # is conv + batch-norm; all but the last stage are followed by ReLU.
        # The layer order is kept exactly as before so that existing
        # state_dict keys still match.
        dims = [raise_xyz_channels + ctx_point_feature_dim] + hidden_dims_pointwise
        conv_layers = []
        for i in range(len(dims) - 1):
            if residual:
                conv_layers.append(ResnetBlockConv2d(dims[i], dims[i + 1]))
            else:
                conv_layers.append(Conv2d(dims[i], dims[i + 1], kernel_size=(1, 1)))
            conv_layers.append(BatchNorm2d(dims[i + 1]))
            if i < len(dims) - 2:
                conv_layers.append(ReLU())
        self.pointwise_convmlp = Sequential(*conv_layers)

        # Final MLP over (B, C, N_query) tensors. Only the hidden stages get
        # batch-norm + ReLU; the last stage outputs raw vectors.
        dims = [hidden_dims_pointwise[-1]] + hidden_dims_global + [point_dim]
        conv_layers = []
        for i in range(len(dims) - 1):
            if residual:
                conv_layers.append(ResnetBlockConv1d(dims[i], dims[i + 1]))
            else:
                conv_layers.append(Conv1d(dims[i], dims[i + 1], kernel_size=1))
            if i < len(dims) - 2:
                conv_layers.append(BatchNorm1d(dims[i + 1]))
                conv_layers.append(ReLU())
        self.global_convmlp = Sequential(*conv_layers)

    def forward(self, p_query, p_context, h_context):
        """
        Args:
            p_query: Query point set, (B, N_query, 3).
            p_context: Context point set, (B, N_ctx, 3).
            h_context: Point-wise features of the context point set, (B, N_ctx, H_ctx).
        Returns:
            (B, N_query, 3)
        """
        # The gathered neighbour coordinates were never used here, so they are
        # no longer requested (return_nn=False); distances and indices are
        # unchanged. Results are sorted by ascending distance.
        dist, knn_idx, _ = pytorch3d.ops.knn_points(
            p1=p_query,
            p2=p_context,
            K=self.max_points,
            return_nn=False,
            return_sorted=True,
        )  # (B, N_query, K), (B, N_query, K)
        N = p_context.size(1)

        # Mark candidates outside the radius with the sentinel index N + 1.
        # The comparison uses radius ** 2, i.e. `dist` is treated as a squared
        # distance here.
        knn_idx[dist > self.radius ** 2] = N + 1
        # Sorted order means the first num_points columns are the closest ones.
        radius_graph_idx = knn_idx[:, :, :self.num_points]

        # Slots holding the sentinel (fewer than num_points neighbours within
        # range) are redirected to context index 0 so that gathering stays
        # valid; their distance is zeroed so that they get zero weight below.
        mask = (radius_graph_idx == N + 1)
        radius_graph_idx = torch.where(
            mask, torch.zeros_like(radius_graph_idx), radius_graph_idx
        )
        radius_graph_points = pytorch3d.ops.knn_gather(
            x=p_context, idx=radius_graph_idx
        )  # (B, N_query, num_points, 3)
        dist = dist[:, :, :self.num_points]
        dist = torch.where(mask, torch.zeros_like(dist), dist)

        # Relative coordinates and their embeddings
        p_rel = radius_graph_points - p_query.unsqueeze(-2)  # (B, N_query, num_points, 3)
        h_rel = self.raise_xyz(p_rel)  # (B, N_query, num_points, H_rel)

        # Grouped features of neighbor points
        h_group = pytorch3d.ops.knn_gather(
            x=h_context,
            idx=radius_graph_idx,
        )  # (B, N_query, num_points, H_ctx)

        # Combine
        h_combined = torch.cat([h_rel, h_group], dim=-1)  # (B, N_query, num_points, H_rel+H_ctx)

        # Featurize
        h_combined = h_combined.permute(0, 3, 1, 2).contiguous()  # (B, H_rel+H_ctx, N_query, num_points)
        y = self.pointwise_convmlp(h_combined)  # (B, H_out, N_query, num_points)
        y = y.permute(0, 2, 3, 1).contiguous()  # (B, N_query, num_points, H_out)

        # Cosine-window weights. `dist > 0` removes the masked slots (and any
        # neighbour that coincides exactly with the query point).
        # NOTE(review): `dist` is compared against radius ** 2 above but is
        # used with plain `radius` here. If `dist` is a squared distance the
        # window is evaluated on squared values; confirm whether that is
        # intended. Left unchanged so numerics match existing checkpoints.
        dist = dist.unsqueeze(-1)  # (B, N_query, num_points, 1)
        c = 0.5 * (torch.cos(dist * np.pi / self.radius) + 1.0)
        c = c * (dist <= self.radius) * (dist > 0.0)  # (B, N_query, num_points, 1)
        y = torch.mul(y, c).permute(0, 3, 1, 2).contiguous()
        y = y.sum(-1)  # (B, H_out, N_query)

        # Vectorize
        y = self.global_convmlp(y)  # (B, 3, N_query)
        y = y.permute(0, 2, 1).contiguous()  # (B, N_query, 3)
        return y
self.in_channels + self.num_fc_layers * self.growth_rate 32 | 33 | def get_edge_feature(self, x, knn_idx): 34 | """ 35 | :param x: (B, N, d) 36 | :param knn_idx: (B, N, K) 37 | :return (B, N, K, 2*d) 38 | """ 39 | knn_feat = knn_group(x, knn_idx) # B * N * K * d 40 | x_tiled = x.unsqueeze(-2).expand_as(knn_feat) 41 | if self.relative_feat_only: 42 | edge_feat = knn_feat - x_tiled 43 | else: 44 | edge_feat = torch.cat([x_tiled, knn_feat, knn_feat - x_tiled], dim=3) 45 | return edge_feat 46 | 47 | def forward(self, x, pos): 48 | """ 49 | :param x: (B, N, d) 50 | :return (B, N, d+L*c) 51 | """ 52 | knn_idx = get_knn_idx(pos, pos, k=self.knn, offset=1) 53 | 54 | # First Layer 55 | edge_feat = self.get_edge_feature(x, knn_idx) 56 | y = torch.cat([ 57 | self.layer_first(edge_feat), # (B, N, K, c) 58 | x.unsqueeze(-2).repeat(1, 1, self.knn, 1) # (B, N, K, d) 59 | ], dim=-1) # (B, N, K, d+c) 60 | 61 | # Intermediate Layers 62 | for layer in self.layers: 63 | y = torch.cat([ 64 | layer(y), # (B, N, K, c) 65 | y, # (B, N, K, c+d) 66 | ], dim=-1) # (B, N, K, d+c+...) 
class KNearestGCNNPointEncoder(Module):
    """Point encoder made of alternating ``FCLayer`` and ``DenseEdgeConv`` units.

    Each unit first maps the running features to ``conv_channels`` with an
    ``FCLayer`` and then applies a ``DenseEdgeConv``. With
    ``dynamic_graph=True`` the features themselves are passed to the
    convolution as positions; otherwise the original input is passed as
    positions for every unit.
    """

    def __init__(
        self,
        in_channels=3,
        dynamic_graph=False,
        conv_channels=24,
        num_convs=4,
        conv_num_fc_layers=3,
        conv_growth_rate=12,
        conv_knn=16,
        conv_aggr='max',
        activation='relu'
    ):
        super().__init__()
        self.in_channels = in_channels
        self.dynamic_graph = dynamic_graph
        self.num_convs = num_convs

        # Edge Convolution Units
        self.transforms = ModuleList()
        self.convs = ModuleList()
        width = in_channels  # input width of the next FCLayer
        for unit in range(num_convs):
            is_first = unit == 0
            # The first unit has no activation on its input transform and
            # uses relative features only; later units use the configured
            # activation and the full edge feature.
            self.transforms.append(
                FCLayer(
                    width,
                    conv_channels,
                    bias=True,
                    activation=None if is_first else activation,
                )
            )
            self.convs.append(
                DenseEdgeConv(
                    conv_channels,
                    num_fc_layers=conv_num_fc_layers,
                    growth_rate=conv_growth_rate,
                    knn=conv_knn,
                    aggr=conv_aggr,
                    activation=activation,
                    relative_feat_only=is_first,
                )
            )
            width = self.convs[-1].out_channels

    @property
    def out_channels(self):
        """Feature width produced by the last convolution unit."""
        return self.convs[-1].out_channels

    def dynamic_graph_forward(self, x):
        """Run all units, passing the current features as positions."""
        feats = x
        for i in range(self.num_convs):
            feats = self.transforms[i](feats)
            feats = self.convs[i](feats, feats)
        return feats

    def static_graph_forward(self, pos):
        """Run all units, passing the input ``pos`` as positions every time."""
        feats = pos
        for i in range(self.num_convs):
            feats = self.transforms[i](feats)
            feats = self.convs[i](feats, pos)
        return feats

    def forward(self, x):
        """Dispatch to the dynamic- or static-graph variant."""
        run = self.dynamic_graph_forward if self.dynamic_graph else self.static_graph_forward
        return run(x)
54 | - idna=3.3=pyhd3eb1b0_0 55 | - importlib-metadata=4.11.3=py38h06a4308_0 56 | - intel-openmp=2021.4.0=h06a4308_3561 57 | - iopath=0.1.9=py38 58 | - joblib=1.1.0=pyhd3eb1b0_0 59 | - jpeg=9b=h024ee3a_2 60 | - kiwisolver=1.3.2=py38h295c915_0 61 | - lcms2=2.12=h3be6417_0 62 | - ld_impl_linux-64=2.35.1=h7274673_9 63 | - libffi=3.3=he6710b0_2 64 | - libgcc-ng=9.3.0=h5101ec6_17 65 | - libgfortran-ng=7.5.0=ha8ba4b0_17 66 | - libgfortran4=7.5.0=ha8ba4b0_17 67 | - libgomp=9.3.0=h5101ec6_17 68 | - libpng=1.6.37=hbc83047_0 69 | - libprotobuf=3.19.1=h4ff587b_0 70 | - libstdcxx-ng=9.3.0=hd4cf53a_17 71 | - libtiff=4.2.0=h85742a9_0 72 | - libuuid=1.0.3=h7f8727e_2 73 | - libuv=1.40.0=h7b6447c_0 74 | - libwebp=1.2.0=h89dd481_0 75 | - libwebp-base=1.2.0=h27cfd23_0 76 | - libxcb=1.14=h7b6447c_0 77 | - libxml2=2.9.12=h03d6c58_0 78 | - llvm-openmp=8.0.1=hc9558a2_0 79 | - lz4-c=1.9.3=h295c915_1 80 | - markdown=3.3.4=py38h06a4308_0 81 | - matplotlib=3.5.1=py38h06a4308_1 82 | - matplotlib-base=3.5.1=py38ha18d171_1 83 | - mkl=2021.4.0=h06a4308_640 84 | - mkl-service=2.4.0=py38h7f8727e_0 85 | - mkl_fft=1.3.1=py38hd3c417c_0 86 | - mkl_random=1.2.2=py38h51133e4_0 87 | - multidict=5.2.0=py38h7f8727e_2 88 | - munkres=1.1.4=py_0 89 | - ncurses=6.3=h7f8727e_2 90 | - ninja=1.10.2=py38hd09550d_3 91 | - numexpr=2.8.1=py38h6abb31d_0 92 | - numpy=1.21.2=py38h20f2e39_0 93 | - numpy-base=1.21.2=py38h79a1101_0 94 | - nvidia-ml=7.352.0=pyhd3eb1b0_0 95 | - nvidiacub=1.10.0=0 96 | - oauthlib=3.2.0=pyhd3eb1b0_0 97 | - openmp=8.0.1=0 98 | - openssl=1.1.1o=h7f8727e_0 99 | - packaging=21.3=pyhd3eb1b0_0 100 | - pandas=1.4.1=py38h295c915_1 101 | - pcre=8.45=h295c915_0 102 | - pillow=9.0.1=py38h22f2fdc_0 103 | - pip=21.2.4=py38h06a4308_0 104 | - point_cloud_utils=0.18.0=py38hc10631b_1 105 | - portalocker=2.3.2=py38h578d9bd_1 106 | - protobuf=3.19.1=py38h295c915_0 107 | - psutil=5.8.0=py38h27cfd23_1 108 | - pyasn1=0.4.8=pyhd3eb1b0_0 109 | - pyasn1-modules=0.2.8=py_0 110 | - pycparser=2.21=pyhd3eb1b0_0 111 | - 
pyjwt=2.1.0=py38h06a4308_0 112 | - pyopenssl=21.0.0=pyhd3eb1b0_1 113 | - pyparsing=3.0.4=pyhd3eb1b0_0 114 | - pyqt=5.9.2=py38h05f1152_4 115 | - pysocks=1.7.1=py38h06a4308_0 116 | - python=3.8.13=h12debd9_0 117 | - python-dateutil=2.8.2=pyhd3eb1b0_0 118 | - python_abi=3.8=2_cp38 119 | - pytorch=1.7.1=py3.8_cuda10.2.89_cudnn7.6.5_0 120 | - pytorch3d=0.5.0=py38_cu102_pyt171 121 | - pytz=2021.3=pyhd3eb1b0_0 122 | - pyyaml=6.0=py38h7f8727e_1 123 | - qt=5.9.7=h5867ecd_1 124 | - readline=8.1.2=h7f8727e_1 125 | - requests=2.27.1=pyhd3eb1b0_0 126 | - requests-oauthlib=1.3.0=py_0 127 | - rsa=4.7.2=pyhd3eb1b0_1 128 | - scikit-learn=1.0.2=py38h51133e4_1 129 | - scipy=1.7.3=py38hc147768_0 130 | - setuptools=58.0.4=py38h06a4308_0 131 | - sip=4.19.13=py38h295c915_0 132 | - six=1.16.0=pyhd3eb1b0_1 133 | - sqlite=3.38.2=hc218d9a_0 134 | - tabulate=0.8.9=py38h06a4308_0 135 | - tensorboard=2.6.0=py_1 136 | - tensorboard-data-server=0.6.0=py38hca6d32c_0 137 | - tensorboard-plugin-wit=1.6.0=py_0 138 | - termcolor=1.1.0=py_2 139 | - threadpoolctl=2.2.0=pyh0d69192_0 140 | - tk=8.6.11=h1ccaba5_0 141 | - torchvision=0.8.2=py38_cu102 142 | - tornado=6.1=py38h27cfd23_0 143 | - tqdm=4.63.0=pyhd3eb1b0_0 144 | - typing-extensions=4.1.1=hd3eb1b0_0 145 | - typing_extensions=4.1.1=pyh06a4308_0 146 | - urllib3=1.26.8=pyhd3eb1b0_0 147 | - werkzeug=2.0.3=pyhd3eb1b0_0 148 | - wheel=0.37.1=pyhd3eb1b0_0 149 | - xz=5.2.5=h7b6447c_0 150 | - yacs=0.1.6=pyhd3eb1b0_1 151 | - yaml=0.2.5=h7b6447c_0 152 | - yarl=1.6.3=py38h27cfd23_0 153 | - zipp=3.7.0=pyhd3eb1b0_0 154 | - zlib=1.2.11=h7f8727e_4 155 | - zstd=1.4.9=haebb681_0 156 | - pip: 157 | - easydict==1.9 158 | - jinja2==3.1.1 159 | - markupsafe==2.1.1 160 | - torch-cluster==1.5.9 161 | - torch-geometric==2.0.4 162 | - torch-scatter==2.0.5 163 | - torch-sparse==0.6.8 164 | - torch-spline-conv==1.2.1 165 | prefix: /home/chenhaolan/.conda/envs/deeprs 166 | -------------------------------------------------------------------------------- /models/resampler.py: 
class PointSetResampler(Module):
    """Encoder + vector field that moves query points relative to a context set.

    The encoder turns the context point cloud into per-point features; the
    vector field predicts one vector per query point from those features.
    """

    def __init__(self, config):
        """
        Args:
            config: Object with ``encoder`` and ``vecfield`` sub-configs that
                are handed to ``get_encoder`` / ``get_vecfield``.
        """
        super().__init__()
        self.cfg = config
        self.encoder = get_encoder(config.encoder)
        self.vecfield = get_vecfield(config.vecfield, self.encoder.out_channels)

    def forward(self, p_query, p_ctx):
        """
        Args:
            p_query: Query point set, (B, N_query, 3).
            p_ctx: Context point set, (B, N_ctx, 3).
        Returns:
            (B, N_query, 3)
        """
        h_ctx = self.encoder(p_ctx)  # (B, N_ctx, H), features of each point in `p_ctx`.
        vec = self.vecfield(p_query, p_ctx, h_ctx)  # (B, N_query, 3)
        return vec

    def get_loss_vec(self, p_query, p_ctx, vec_gt):
        """
        Computes loss according to ground truth vectors.
        Args:
            vec_gt: Ground truth vectors, (B, N_query, 3).
        Returns:
            Scalar: squared error summed over coordinates, averaged over points.
        """
        vec_pred = self(p_query, p_ctx)  # (B, N_query, 3)
        loss = ((vec_pred - vec_gt) ** 2.0).sum(dim=-1).mean()
        return loss

    def get_loss_pc(self, p_query, p_ctx, p_gt, avg_knn):
        """
        Computes loss according to ground truth point clouds.
        Args:
            p_gt: Ground truth point clouds, (B, N_gt, 3).
            avg_knn: For each point in `p_query`, use how many nearest points in `p_gt`
                to estimate the ground truth vector.
        """
        _, _, gt_nbs = pytorch3d.ops.knn_points(
            p_query,
            p_gt,
            K=avg_knn,
            return_nn=True,
        )  # (B, N_query, K, 3)

        # Target vector = mean offset from the query point to its K nearest
        # ground-truth points.
        vec_gt = (gt_nbs - p_query.unsqueeze(-2)).mean(-2)  # (B, N_query, 3)
        return self.get_loss_vec(p_query, p_ctx, vec_gt)

    def get_cd_loss(self, ref, gen):
        """Sum of the two directional mean-of-minimum terms of ``batch_pairwise_dist(ref, gen)``."""
        P = batch_pairwise_dist(ref, gen)
        mins, _ = torch.min(P, 1)
        loss_1 = torch.mean(mins)
        mins, _ = torch.min(P, 2)
        loss_2 = torch.mean(mins)
        return loss_1 + loss_2

    @torch.no_grad()
    def resample(self, p_init, p_ctx, step_size=0.2, step_decay=0.95, num_steps=40):
        """Iteratively move ``p_init`` along the predicted vector field.

        Args:
            p_init: Starting points, (B, N_query, 3).
            p_ctx: Context point set, (B, N_ctx, 3).
            step_size: Step length of the first update.
            step_decay: Multiplicative decay applied to the step every iteration.
            num_steps: Number of updates.
        Returns:
            Tuple ``(points, traj)``: the final points and a list of CPU
            copies holding the initial points plus every fifth update.
        """
        traj = [p_init.clone().cpu()]
        # Context features are computed once and reused for every step.
        h_ctx = self.encoder(p_ctx)  # (B, N_ctx, H)

        p_current = p_init
        for step in range(num_steps):
            vec_pred = self.vecfield(p_current, p_ctx, h_ctx)  # (B, N_query, 3)
            s = step_size * (step_decay ** step)
            p_next = p_current + s * vec_pred
            if step % 5 == 4:
                traj.append(p_next.clone().cpu())
            p_current = p_next
        return p_current, traj

    def get_repulsion_loss(self, p_cur):
        """Repulsion term over the 6 nearest other points of every point.

        Args:
            p_cur: Point set, (B, N, 3).
        Returns:
            Scalar ``0.05 * mean(-dist * exp(-dist2 / h**2)) + 1e-4``.
        """
        h = 7e-3
        dist2, _, _ = pytorch3d.ops.knn_points(
            p1=p_cur,
            p2=p_cur,
            K=7,
            return_nn=False,  # only the distances are needed
        )  # (B, N, K), squared distances
        dist2 = dist2[:, :, 1:]  # drop the point itself (distance 0)

        # sqrt has an infinite gradient at 0, which turns into NaN during
        # back-propagation when two points coincide. Take the root of a safe
        # value and put the exact zeros back, so the forward result is
        # unchanged while the gradient stays finite.
        positive = dist2 > 0
        safe_dist2 = torch.where(positive, dist2, torch.ones_like(dist2))
        dist = torch.where(positive, torch.sqrt(safe_dist2), torch.zeros_like(dist2))

        weight = torch.exp(-dist2 / h ** 2)
        loss = torch.mean((-dist) * weight)
        return 0.05 * loss + 1e-4

    @staticmethod
    def _offsets_to_neighbors(p_cur, k):
        """Return ``p_cur - neighbour`` for the k-1 nearest other points, (B, N, k-1, 3)."""
        _, _, nbr_points = pytorch3d.ops.knn_points(
            p1=p_cur,
            p2=p_cur,
            K=k,
            return_nn=True,
        )  # (B, N, k, 3)
        p_rel = p_cur.unsqueeze(-2) - nbr_points  # (B, N, k, 3)
        return p_rel[:, :, 1:, :]  # drop the point itself

    def glr(self, p_cur):
        """Weighted sum of ``2 * (p - neighbour)`` over the 5 nearest other points.

        Presumably the gradient of a graph-Laplacian regulariser (going by
        the name); confirm against the caller.
        Returns:
            (B, N, 3)
        """
        p_rel = self._offsets_to_neighbors(p_cur, 6)
        glr_grad = 2 * p_rel
        dist = (p_rel ** 2).sum(dim=-1)  # (B, N, K-1), squared distances
        target = torch.exp(-1 * dist ** 2 / 1e-9).unsqueeze(-1) * glr_grad
        target = torch.sum(target, dim=-2, keepdim=False)
        return target

    def gtv(self, p_cur):
        """Weighted sum of the per-coordinate sign of ``p - neighbour``.

        Presumably the gradient of a graph total-variation regulariser (going
        by the name); confirm against the caller.
        Returns:
            (B, N, 3)
        """
        p_rel = self._offsets_to_neighbors(p_cur, 6)
        # Smoothed sign: x / (|x| + eps) avoids division by zero.
        gtv_grad = p_rel / (torch.abs(p_rel) + 1e-7)
        dist = (p_rel ** 2).sum(dim=-1)  # (B, N, K-1), squared distances
        target = torch.exp(-1 * dist ** 2 / 5e-9).unsqueeze(-1) * gtv_grad
        target = torch.sum(target, dim=-2, keepdim=False)
        return target
if __name__ == '__main__':
    # Lowest validation Chamfer distance seen so far; updated by validate().
    best = 10000
    parser = argparse.ArgumentParser()
    # The config path is opened unconditionally below, so it must be given.
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--log_root', type=str, default='./logs')
    args = parser.parse_args()

    # Load configs
    with open(args.config, 'r') as f:
        config = EasyDict(yaml.safe_load(f))
    # splitext keeps the full name when the file has no extension (slicing up
    # to rfind('.') dropped the last character in that case).
    config_name = os.path.splitext(os.path.basename(args.config))[0]
    seed_all(config.train.seed)

    # Logging
    log_dir = get_new_log_dir(args.log_root, prefix=config_name + '_' + str(config.dataset.val_noise) + '_')
    print(log_dir)
    ckpt_dir = os.path.join(log_dir, 'checkpoints')
    os.makedirs(ckpt_dir, exist_ok=True)
    logger = get_logger('train', log_dir)
    writer = torch.utils.tensorboard.SummaryWriter(log_dir)
    logger.info(args)
    logger.info(config)
    # Keep a copy of the config next to the logs.
    shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config)))

    # Datasets and loaders
    logger.info('Loading datasets')
    train_dset = PairedPatchDataset(
        datasets=[
            PointCloudDataset(
                root=config.dataset.dataset_root,
                dataset=config.dataset.dataset,
                split='train',
                resolution=resl,
                transform=standard_train_transforms(
                    noise_std_max=config.dataset.noise_max,
                    noise_std_min=config.dataset.noise_min,
                    rotate=config.dataset.aug_rotate,
                ),
            ) for resl in config.dataset.resolutions
        ],
        patch_size=config.dataset.patch_size,
        patch_ratio=1.2,
        on_the_fly=True
    )

    # Validation uses the third configured resolution with a fixed noise
    # level (min == max) and no rotation.
    val_dset = PointCloudDataset(
        root=config.dataset.dataset_root,
        dataset=config.dataset.dataset,
        split='test',
        resolution=config.dataset.resolutions[2],
        transform=standard_train_transforms(
            noise_std_max=config.dataset.val_noise,
            noise_std_min=config.dataset.val_noise,
            rotate=False,
            scale_d=0,
        ),
    )
    train_iter = get_data_iterator(DataLoader(
        train_dset,
        batch_size=config.train.train_batch_size,
        num_workers=config.train.num_workers,
        shuffle=True,
    ))

    # Model
    logger.info('Building model...')
    model = PointSetResampler(config.model).to(args.device)
    logger.info(repr(model))

    # Optimizer and Scheduler
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.train.lr,
        weight_decay=config.train.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=config.scheduler.factor,
        patience=config.scheduler.patience,
        threshold=config.scheduler.threshold,
    )

    def train(it):
        """Run one optimisation step on the next training batch and log it."""
        if it % 10000 == 0:
            print(it)
        batch = next(train_iter)
        p_noisy = batch['pcl_noisy'].to(args.device)
        p_clean = batch['pcl_clean'].to(args.device)

        # Reset grad and model state
        optimizer.zero_grad()
        model.train()

        # Forward: the noisy patch is both the query set and the context set.
        loss = model.get_loss_pc(
            p_query=p_noisy,
            p_ctx=p_noisy,
            p_gt=p_clean,
            avg_knn=config.train.vec_avg_knn,
        )

        # Backward
        loss.backward()
        optimizer.step()

        # Logging
        logger.info('[Train] Iter %04d | Loss %.6f' % (
            it, loss.item(),
        ))
        writer.add_scalar('train/loss', loss, it)
        writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], it)
        writer.flush()

    def validate(it):
        """Denoise the validation set, log the Chamfer distance and return it.

        Saves the model weights to ``<log_dir>/model.pth`` whenever the score
        improves on the best one seen so far, and steps the LR scheduler.
        """
        global best
        all_clean = []
        all_denoised = []
        for i, data in enumerate(tqdm(val_dset, desc='Validate')):
            pcl_noisy = data['pcl_noisy'].to(args.device)
            pcl_clean = data['pcl_clean'].to(args.device)
            pcl_denoised = patch_based_denoise(model, pcl_noisy)
            all_clean.append(pcl_clean.unsqueeze(0))
            all_denoised.append(pcl_denoised.unsqueeze(0))
        all_clean = torch.cat(all_clean, dim=0)
        all_denoised = torch.cat(all_denoised, dim=0)

        avg_chamfer = chamfer_distance_unit_sphere(all_denoised, all_clean, batch_reduction='mean')[0].item()
        logger.info('[Val] Iter %04d | CD %.6f ' % (it, avg_chamfer))
        writer.add_scalar('val/chamfer', avg_chamfer, it)
        writer.add_mesh('val/pcl', all_denoised[:2], global_step=it)
        if avg_chamfer < best:
            best = avg_chamfer
            torch.save(model.state_dict(), os.path.join(log_dir, 'model.pth'))
        writer.flush()
        scheduler.step(avg_chamfer)
        # The caller assigns the result, so actually return the metric.
        return avg_chamfer

    # Main loop
    logger.info('Start training...')
    try:
        for it in range(1, config.train.max_iters + 1):
            train(it)
            if it % config.train.val_freq == 0 or it == config.train.max_iters:
                cd_loss = validate(it)

    except KeyboardInterrupt:
        logger.info('Terminating...')
class PairedPatchDataset(Dataset):
    """Dataset of (noisy, clean) patch pairs cut from point-cloud datasets.

    In on-the-fly mode a single patch pair is cut at every access from a
    randomly chosen source dataset; otherwise all patches are cut once in
    the constructor and served from memory.
    """

    def __init__(self, datasets, patch_ratio, on_the_fly=True, patch_size=1000, num_patches=1000, transform=None):
        """
        Args:
            datasets: Sequence of datasets whose items provide ``'pcl_noisy'``
                and ``'pcl_clean'``.
            patch_ratio: Ratio passed to ``make_patches_for_pcl_pair``.
            on_the_fly: Cut patches lazily instead of up front.
            patch_size: Patch size passed to ``make_patches_for_pcl_pair``.
            num_patches: Patches per point cloud (pre-cut mode) and length
                multiplier (on-the-fly mode).
            transform: Optional callable applied to every returned item.
        """
        super().__init__()
        self.datasets = datasets
        self.len_datasets = sum(len(source) for source in datasets)
        self.patch_ratio = patch_ratio
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.on_the_fly = on_the_fly
        self.transform = transform
        self.patches = []
        # Initialize
        if not on_the_fly:
            self.make_patches()

    def make_patches(self):
        """Cut ``num_patches`` patch pairs from every item and store them."""
        for source in tqdm(self.datasets, desc='MakePatch'):
            for sample in tqdm(source):
                noisy_batch, clean_batch = make_patches_for_pcl_pair(
                    sample['pcl_noisy'],
                    sample['pcl_clean'],
                    patch_size=self.patch_size,
                    num_patches=self.num_patches,
                    ratio=self.patch_ratio
                )  # (P, M, 3), (P, rM, 3)
                for k in range(noisy_batch.size(0)):
                    pair = (noisy_batch[k], clean_batch[k], )
                    self.patches.append(pair)

    def __len__(self):
        if self.on_the_fly:
            return self.len_datasets * self.num_patches
        return len(self.patches)

    def __getitem__(self, idx):
        if not self.on_the_fly:
            # Clone so callers cannot modify the cached patches.
            noisy, clean = self.patches[idx][0], self.patches[idx][1]
            data = {
                'pcl_noisy': noisy.clone(),
                'pcl_clean': clean.clone(),
            }
        else:
            # The source dataset is random; idx only selects the item in it.
            source = random.choice(self.datasets)
            sample = source[idx % len(source)]
            noisy_batch, clean_batch = make_patches_for_pcl_pair(
                sample['pcl_noisy'],
                sample['pcl_clean'],
                patch_size=self.patch_size,
                num_patches=1,
                ratio=self.patch_ratio
            )
            data = {
                'pcl_noisy': noisy_batch[0],
                'pcl_clean': clean_batch[0]
            }

        if self.transform is None:
            return data
        return self.transform(data)
116 | Returns: 117 | (P, M, 3), (P, rM, 3) 118 | """ 119 | N = pcl_A.size(0) 120 | seed_idx = torch.randperm(N)[:num_patches] # (P, ) 121 | seed_pnts = pcl_A[seed_idx].unsqueeze(0) # (1, P, 3) 122 | _, _, pat_A = pytorch3d.ops.knn_points(seed_pnts, pcl_A.unsqueeze(0), K=patch_size, return_nn=True) 123 | pat_A = pat_A[0] # (P, M, 3) 124 | _, _, pat_B = pytorch3d.ops.knn_points(seed_pnts, pcl_B.unsqueeze(0), K=int(ratio*patch_size), return_nn=True) 125 | pat_B = pat_B[0] 126 | _, _, pat_C = pytorch3d.ops.knn_points(seed_pnts, pcl_C.unsqueeze(0), K=int(ratio*patch_size), return_nn=True) 127 | pat_C = pat_C[0] 128 | return pat_A,pat_B, pat_C 129 | 130 | def __getitem__(self, idx): 131 | pcl_dset = random.choice(self.datasets) 132 | pcl_data = pcl_dset[idx % len(pcl_dset)] 133 | 134 | pat_low, pat_noisy , pat_gt= self.make_patches_for_ups_pair( 135 | pcl_data['original'], 136 | pcl_data['ups'], 137 | pcl_data['gt'], 138 | patch_size=self.patch_size, 139 | num_patches=1, 140 | ratio=self.patch_ratio 141 | ) 142 | ''' 143 | pat_low, pat_noisy , pat_gt= self.make_patches_for_ups_pair( 144 | pcl_data['pcl_noisy'], 145 | pcl_data['pcl_noisy'], 146 | pcl_data['pcl_clean'], 147 | patch_size=self.patch_size, 148 | num_patches=1, 149 | ratio=self.patch_ratio 150 | ) 151 | ''' 152 | #pat_low = pcl_data['original'] 153 | pat_low = pat_low[0] 154 | pat_noisy = pat_noisy[0] 155 | pat_gt = pat_gt[0] 156 | ''' 157 | pat_low = pat_low[0] 158 | #patch? 
class CheckpointManager(object):
    """Save, index and load checkpoints stored in a single directory.

    Files are named ``ckpt_<score>_<iteration>.pt``; the iteration part is
    empty when a checkpoint was saved without a step. Lower scores are
    treated as better.
    """

    def __init__(self, save_dir, logger=None):
        """
        Args:
            save_dir: Directory holding the checkpoints (created if missing).
            logger: Optional logger; a ``BlackHole`` is used when omitted.
        """
        super().__init__()
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.ckpts = []
        # Built lazily instead of in the signature so that no instance is
        # created at definition time and shared between managers.
        self.logger = BlackHole() if logger is None else logger

        for f in os.listdir(self.save_dir):
            if not f.startswith('ckpt'):
                continue
            entry = self._parse_filename(f)
            if entry is None:
                # A stray file used to crash construction; skip it instead.
                self.logger.warning('Skipping unrecognised checkpoint file: %s', f)
                continue
            self.ckpts.append(entry)

    @staticmethod
    def _parse_filename(fname):
        """Turn ``ckpt_<score>_<iteration>.pt`` into an index entry, or None."""
        parts = fname.split('_')
        if len(parts) != 3:
            return None
        _, score, tail = parts
        it = tail.split('.')[0]
        try:
            return {
                'score': float(score),
                'file': fname,
                # Files saved without a step have an empty iteration part.
                'iteration': int(it) if it else None,
            }
        except ValueError:
            return None

    def get_worst_ckpt_idx(self):
        """Index of the entry with the highest score (last one on ties), or None."""
        idx = -1
        worst = float('-inf')
        for i, ckpt in enumerate(self.ckpts):
            if ckpt['score'] >= worst:
                idx = i
                worst = ckpt['score']
        return idx if idx >= 0 else None

    def get_best_ckpt_idx(self):
        """Index of the entry with the lowest score (last one on ties), or None."""
        idx = -1
        best = float('inf')
        for i, ckpt in enumerate(self.ckpts):
            if ckpt['score'] <= best:
                idx = i
                best = ckpt['score']
        return idx if idx >= 0 else None

    def get_latest_ckpt_idx(self):
        """Index of the entry with the largest iteration, or None.

        Entries without a known iteration are never reported as latest.
        """
        idx = -1
        latest_it = -1
        for i, ckpt in enumerate(self.ckpts):
            it = ckpt.get('iteration')
            if it is not None and it > latest_it:
                idx = i
                latest_it = it
        return idx if idx >= 0 else None

    def save(self, model, args, score, others=None, step=None):
        """Write a checkpoint and add it to the index.

        Args:
            model: Object providing ``state_dict()``.
            args: Stored under the ``'args'`` key.
            score: Score used in the file name and for best/worst lookup.
            others: Stored under the ``'others'`` key.
            step: Optional iteration number.
        Returns:
            True.
        """
        if step is None:
            fname = 'ckpt_%.6f_.pt' % float(score)
        else:
            fname = 'ckpt_%.6f_%d.pt' % (float(score), int(step))
        path = os.path.join(self.save_dir, fname)

        torch.save({
            'args': args,
            'state_dict': model.state_dict(),
            'others': others
        }, path)

        # Record the iteration as well: get_latest_ckpt_idx() reads it and
        # used to raise KeyError for checkpoints saved in the same session.
        self.ckpts.append({
            'score': float(score),
            'file': fname,
            'iteration': None if step is None else int(step),
        })

        return True

    def load_best(self):
        """Load the checkpoint with the lowest score.

        Raises:
            IOError: If no checkpoint is indexed.
        """
        idx = self.get_best_ckpt_idx()
        if idx is None:
            raise IOError('No checkpoints found.')
        ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file']))
        return ckpt

    def load_latest(self):
        """Load the checkpoint with the largest iteration.

        Raises:
            IOError: If no checkpoint with a known iteration is indexed.
        """
        idx = self.get_latest_ckpt_idx()
        if idx is None:
            raise IOError('No checkpoints found.')
        ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file']))
        return ckpt

    def load_selected(self, file):
        """Load the checkpoint stored as ``file`` inside ``save_dir``."""
        ckpt = torch.load(os.path.join(self.save_dir, file))
        return ckpt
def get_data_iterator(iterable):
    """Yield items from ``iterable`` forever, restarting it when exhausted.

    Allows training with DataLoaders in a single infinite loop:
        for i, data in enumerate(get_data_iterator(train_loader)):

    Args:
        iterable: Any re-iterable object (for example a DataLoader).

    Raises:
        ValueError: If a fresh pass over ``iterable`` produces no items.
            Without this check an empty or one-shot iterable would make the
            generator spin forever without ever yielding.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            raise ValueError('get_data_iterator: iterable produced no items.')
def split_tensor_to_segments(x, segsize):
    """Cut ``x`` along its first dimension into consecutive chunks.

    Args:
        x: Tensor to split.
        segsize: Maximum number of leading-dimension entries per chunk.

    Returns:
        A list of slices of ``x``; the last one may be shorter than
        ``segsize``. The list is empty when ``x`` has no entries.
    """
    count = math.ceil(x.size(0) / segsize)
    return [x[k * segsize:(k + 1) * segsize] for k in range(count)]
def patch_based_upsample(model, pcl_low, pcl_noisy, step_size=0.15, num_steps=50, patch_size=512, seed_k=3, denoise_knn=4, step_decay=0.98, get_traj=False):
    """Resample an initial point cloud using a low-resolution cloud as context.

    The whole cloud is passed to ``model.resample`` as a single batch entry.
    ``patch_size``, ``seed_k`` and ``denoise_knn`` are accepted for signature
    compatibility but are not used here.

    Args:
        model: Object exposing ``eval()`` and ``resample(p_init, p_ctx,
            step_size, step_decay, num_steps)``.
        pcl_low: Context point cloud, 2-D tensor.
        pcl_noisy: Initial point cloud to be moved, (N, 3).
        step_size, step_decay, num_steps: Forwarded to ``model.resample``.
        get_traj: When true, also return the trajectory from ``resample``.

    Returns:
        The first batch entry of the resampled cloud, or a tuple of that
        entry and the trajectory when ``get_traj`` is true.
    """
    assert pcl_noisy.dim() == 2, 'The shape of input point cloud must be (N, 3).'
    assert pcl_low.dim() == 2, 'The shape of low-resolution point cloud must be (M, 3).'

    patches_noisy = pcl_noisy.unsqueeze(0)  # (1, N, 3)
    patches_low = pcl_low.unsqueeze(0)

    with torch.no_grad():
        # NOTE: this leaves the model in eval mode for the caller.
        model.eval()
        patches_denoised, traj = model.resample(
            p_init=patches_noisy,
            p_ctx=patches_low,
            step_size=step_size,
            step_decay=step_decay,
            num_steps=num_steps
        )

    if get_traj:
        return patches_denoised[0], traj
    return patches_denoised[0]
def patch_based_denoise_big(model, pcl_noisy, step_size=0.15, num_steps=50, patch_size=10000, seed_k=3, denoise_knn=4, step_decay=0.95, get_traj=False, seg_size=5):
    """Denoise a point cloud patch by patch, a few patches at a time.

    Seeds are chosen with farthest point sampling, a K-nearest-neighbour patch
    is gathered around each seed, the patches are resampled in groups of
    ``seg_size``, and the result is reduced back to N points with farthest
    point sampling. ``denoise_knn`` is accepted but unused.

    Args:
        model: Object exposing ``resample(p_init, p_ctx, step_size,
            step_decay, num_steps)``.
        pcl_noisy: Input point cloud, (N, 3).
        step_size, step_decay, num_steps: Forwarded to ``model.resample``.
        patch_size: Number of neighbours gathered per seed.
        seed_k: The seed count is ``int(seed_k * N / patch_size)``.
        get_traj: When true, also return the per-step trajectory.
        seg_size: Number of patches resampled per ``model.resample`` call.

    Returns:
        The denoised cloud, or ``(denoised cloud, trajectory)`` when
        ``get_traj`` is true. Each trajectory entry is gathered with the same
        sampling indices as the output.
    """
    assert pcl_noisy.dim() == 2, 'The shape of input point cloud must be (N, 3).'
    N, d = pcl_noisy.size()
    pcl_noisy = pcl_noisy.unsqueeze(0)  # (1, N, 3)
    seed_pnts, _ = farthest_point_sampling(pcl_noisy, int(seed_k * N / patch_size))
    _, _, patches = pytorch3d.ops.knn_points(seed_pnts, pcl_noisy, K=patch_size, return_nn=True)
    patches = patches[0]  # presumably (num_seeds, K, 3) -- TODO confirm
    segments = split_tensor_to_segments(patches, seg_size)

    denoised_segments = []
    segment_trajs = []
    # Inference only: avoid building autograd graphs, as patch_based_denoise does.
    with torch.no_grad():
        for segment in segments:
            segment_denoised, segment_traj = model.resample(
                p_init=segment,
                p_ctx=segment,
                step_size=step_size,
                step_decay=step_decay,
                num_steps=num_steps
            )
            denoised_segments.append(segment_denoised)
            if get_traj:
                segment_trajs.append(segment_traj)

    patches_denoised = torch.cat(denoised_segments, dim=0)
    pcl_denoised, fps_idx = farthest_point_sampling(patches_denoised.view(1, -1, d), N)
    pcl_denoised = pcl_denoised[0]
    fps_idx = fps_idx[0]

    if get_traj:
        # fps_idx indexes points of ALL patches, so every step must first be
        # stitched across segments. Previously only the last segment's
        # trajectory was indexed.
        # NOTE(review): assumes each segment's trajectory is a same-length
        # sequence of tensors -- confirm against model.resample.
        traj = [
            torch.cat(step_states, dim=0).view(-1, d)[fps_idx, :]
            for step_states in zip(*segment_trajs)
        ]
        return pcl_denoised, traj
    return pcl_denoised
def batch_pairwise_dist(x, y):
    """Return squared Euclidean distances between all points of two batches.

    Args:
        x: Tensor of shape (B, Nx, D).
        y: Tensor of shape (B, Ny, D).

    Returns:
        Tensor of shape (B, Nx, Ny) whose entry [b, i, j] is
        ``|x[b, i]|^2 + |y[b, j]|^2 - 2 * <x[b, i], y[b, j]>``.
    """
    # Squared norms computed directly rather than as the diagonals of the
    # full x·xᵀ and y·yᵀ matrices; this also stays on the inputs' device.
    x_sq = (x * x).sum(dim=-1)  # (B, Nx)
    y_sq = (y * y).sum(dim=-1)  # (B, Ny)
    cross = torch.bmm(x, y.transpose(2, 1))  # (B, Nx, Ny)
    return x_sq.unsqueeze(2) + y_sq.unsqueeze(1) - 2 * cross


def get_cd_loss(ref, gen):
    """Return the symmetric Chamfer loss on squared distances.

    Args:
        ref: Reference clouds, (B, Nr, D).
        gen: Generated clouds, (B, Ng, D).

    Returns:
        Scalar tensor: mean nearest-reference distance over generated points
        plus mean nearest-generated distance over reference points.
    """
    P = batch_pairwise_dist(ref, gen)  # (B, Nr, Ng)
    mins, _ = torch.min(P, 1)  # nearest reference point for each generated point
    loss_1 = torch.mean(mins)
    mins, _ = torch.min(P, 2)  # nearest generated point for each reference point
    loss_2 = torch.mean(mins)
    return loss_1 + loss_2
default='./logs') 49 | args = parser.parse_args() 50 | 51 | # Load configs 52 | with open(args.config, 'r') as f: 53 | config = EasyDict(yaml.safe_load(f)) 54 | config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')] 55 | seed_all(config.train.seed) 56 | 57 | # Logging 58 | log_dir = get_new_log_dir(args.log_root, prefix=config_name + '_' +'retrain_ups_gen') 59 | ckpt_mgr = CheckpointManager(log_dir) 60 | ckpt_dir = os.path.join(log_dir, 'checkpoints') 61 | os.makedirs(ckpt_dir, exist_ok=True) 62 | logger = get_logger('train', log_dir) 63 | writer = torch.utils.tensorboard.SummaryWriter(log_dir) 64 | logger.info(args) 65 | logger.info(config) 66 | shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config))) 67 | 68 | # Datasets and loaders 69 | logger.info('Loading datasets') 70 | train_dset = PairedUpsDataset( 71 | datasets=[UpsampleDataset( 72 | root=config.dataset.dataset_root, 73 | dataset=config.dataset.dataset, 74 | rate = 4 , 75 | split='train', 76 | resolution=resl, 77 | noise_min = 1e-2, 78 | noise_max = 2.5e-2, 79 | ) for resl in config.dataset.resolutions 80 | ], 81 | patch_size=config.dataset.patch_size, 82 | patch_ratio=4.2, 83 | on_the_fly=True 84 | ) 85 | 86 | val_dset = UpsampleDataset( 87 | root=config.dataset.dataset_root, 88 | dataset=config.dataset.dataset, 89 | split='test', 90 | resolution=config.dataset.resolutions[1], 91 | rate = 4, 92 | noise_min = 1.7e-2, 93 | noise_max = 1.7e-2, 94 | need_mesh=True 95 | ) 96 | train_iter = get_data_iterator(DataLoader(train_dset, batch_size=config.train.train_batch_size, num_workers=config.train.num_workers, shuffle=True)) 97 | 98 | # Model 99 | logger.info('Building model...') 100 | model = PointSetResampler(config.model).to(args.device) 101 | upsampler = coarsesGenerator(rate = 8 ).to(args.device) 102 | #logger.info(repr(model)) 103 | #logger.info(repr(upsampler)) 104 | 105 | # Optimizer and Scheduler 106 | optimizer = torch.optim.Adam( 107 | 
itertools.chain(upsampler.parameters(),model.parameters()), 108 | #model.parameters(), 109 | lr=config.train.lr, 110 | weight_decay=config.train.weight_decay, 111 | ) 112 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 113 | optimizer, 114 | mode="min", 115 | factor=config.scheduler.factor, 116 | patience=config.scheduler.patience, 117 | threshold=config.scheduler.threshold, 118 | ) 119 | def train(it): 120 | batch = next(train_iter) 121 | pcl_low = batch['pcl_low'].to(args.device) 122 | pcl_noisy = batch['pcl_noisy'].to(args.device) 123 | pcl_gt = batch['pcl_gt'].to(args.device) 124 | # Reset grad and model state 125 | model.train() 126 | upsampler.train() 127 | optimizer.zero_grad() 128 | # Forward 129 | 130 | pcl_noisy = upsampler(pcl_low) 131 | cd_loss = get_cd_loss( 132 | gen = pcl_noisy , 133 | ref = pcl_gt 134 | ) 135 | 136 | vec_loss = model.get_loss_pc( 137 | p_query=pcl_noisy, 138 | p_ctx=pcl_low, 139 | p_gt=pcl_gt, 140 | avg_knn=config.train.vec_avg_knn, 141 | ) 142 | loss = vec_loss 143 | loss += cd_loss 144 | # Backward 145 | loss.backward() 146 | optimizer.step() 147 | # Logging 148 | #logger.info('[Train] Iter %04d |cd Loss %.6f | vec Loss %.6f' % ( 149 | # it, cd_loss.item(),vec_loss.item() 150 | #)) 151 | logger.info('[Train] Iter %04d || vec Loss %.6f' % ( 152 | it,vec_loss.item() 153 | )) 154 | writer.add_scalar('train/loss', loss, it) 155 | writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], it) 156 | writer.flush() 157 | 158 | def validate(it): 159 | global best 160 | all_clean = [] 161 | all_denoised = [] 162 | #upsampler.eval() 163 | avg_p2m = 0 164 | with torch.no_grad(): 165 | for i, data in enumerate(tqdm(val_dset, desc='Validate')): 166 | 167 | pcl_noisy = data['ups'].to(args.device) 168 | pcl_low = data['original'].to(args.device) 169 | pcl_clean = data['gt'].to(args.device) 170 | #if it != 1 : 171 | pcl_denoised = patch_based_upsample(model,pcl_low, pcl_noisy) 172 | #pcl_denoised = 
class NormalizeUnitSphere(object):
    """Transform that centres a point cloud and scales it into the unit sphere."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def normalize(pcl, center=None, scale=None):
        """Centre and scale a point cloud.

        Args:
            pcl: The point cloud to be normalized, (N, 3).
            center: Optional centre to subtract; defaults to the midpoint of
                the axis-aligned bounding box, shape (1, 3).
            scale: Optional divisor; defaults to the largest distance of a
                centred point from the origin, shape (1, 1).

        Returns:
            Tuple ``(normalized pcl, center, scale)``.
        """
        if center is None:
            p_max = pcl.max(dim=0, keepdim=True)[0]
            p_min = pcl.min(dim=0, keepdim=True)[0]
            center = (p_max + p_min) / 2    # (1, 3)
        pcl = pcl - center
        if scale is None:
            scale = (pcl ** 2).sum(dim=1, keepdim=True).sqrt().max(dim=0, keepdim=True)[0]  # (1, 1)
            # A cloud whose points all coincide has radius 0; dividing by it
            # would turn every coordinate into NaN, so fall back to 1.
            scale = torch.where(scale > 0, scale, torch.ones_like(scale))
        pcl = pcl / scale
        return pcl, center, scale

    def __call__(self, data):
        """Normalize ``data['pcl_clean']`` in place and record center and scale.

        Raises:
            AssertionError: If ``'pcl_noisy'`` is already present in ``data``.
        """
        assert 'pcl_noisy' not in data, 'Point clouds must be normalized before applying noise perturbation.'
        data['pcl_clean'], center, scale = self.normalize(data['pcl_clean'])
        data['center'] = center
        data['scale'] = scale
        return data
class AddDiscreteNoise(object):
    """Transform that shifts random points by ``scale`` along a coordinate axis.

    Each of the six axis directions (+x, -x, +y, -y, +z, -z) is selected with
    probability ``prob``; the remaining points are left untouched. With
    ``prob`` above 1/6 the later directions receive less than ``prob`` because
    the uniform sample only covers [0, 1).
    """

    def __init__(self, scale, prob=0.1):
        """
        Args:
            scale: Length of the displacement applied to a selected point.
            prob: Probability mass given to each of the six directions.
        """
        super().__init__()
        self.scale = scale
        self.prob = prob
        self.template = np.array([
            [1, 0, 0],
            [-1, 0, 0],
            [0, 1, 0],
            [0, -1, 0],
            [0, 0, 1],
            [0, 0, -1],
        ], dtype=np.float32)

    def __call__(self, data):
        """Add ``'pcl_noisy'`` and ``'noise_std'`` to ``data`` and return it."""
        num_points = data['pcl_clean'].shape[0]
        uni_rand = np.random.uniform(size=num_points)
        noise = np.zeros([num_points, 3])
        for i, direction in enumerate(self.template):
            # Use the configured probability; it used to be hard-coded to 0.1,
            # which silently ignored the `prob` constructor argument.
            idx = np.logical_and(self.prob * i <= uni_rand, uni_rand < self.prob * (i + 1))
            noise[idx] = direction.reshape(1, 3)
        noise = torch.FloatTensor(noise).to(data['pcl_clean'])
        data['pcl_noisy'] = data['pcl_clean'] + noise * self.scale
        data['noise_std'] = self.scale
        return data
class add_varying_noise(object):
    """Transform that applies two different noise models to the halves of a cloud.

    Points whose first coordinate lies above the bounding-box midpoint receive
    Laplacian noise; all other points receive noise drawn uniformly from a ball
    of radius ``noise_std``.
    """

    def __init__(self, noise_std_min, noise_std_max):
        super().__init__()
        self.noise_std_min = noise_std_min
        self.noise_std_max = noise_std_max

    def __call__(self, data):
        """Add ``'pcl_noisy'`` and ``'noise_std'`` to ``data`` and return it."""
        clean = data['pcl_clean']
        sigma = random.uniform(self.noise_std_min, self.noise_std_max)
        num_pts = clean.shape[0]

        # Midpoint of the axis-aligned bounding box, shape (3,).
        upper_corner = clean.max(dim=0, keepdim=False)[0]
        lower_corner = clean.min(dim=0, keepdim=False)[0]
        midpoint = (upper_corner + lower_corner) / 2

        # Laplacian noise for every point (masked further below).
        laplace = torch.FloatTensor(
            np.random.laplace(0, sigma, size=clean.shape)
        ).to(clean)

        # Uniform-in-ball noise for every point, via spherical coordinates.
        phi = np.random.uniform(0, 2*np.pi, size=num_pts)
        costheta = np.random.uniform(-1, 1, size=num_pts)
        u = np.random.uniform(0, 1, size=num_pts)
        theta = np.arccos(costheta)
        radius = sigma * u ** (1/3)
        ball = np.stack(
            [
                radius * np.sin(theta) * np.cos(phi),
                radius * np.sin(theta) * np.sin(phi),
                radius * np.cos(theta),
            ],
            axis=1,
        )
        ball = torch.FloatTensor(ball).to(clean)

        # Keep each noise model only on its own half of the cloud.
        above = clean[:, 0] > midpoint[0]
        not_above = clean[:, 0] <= midpoint[0]
        laplace = torch.mul(laplace, above.unsqueeze(-1).repeat(1, 3))
        ball = torch.mul(ball, not_above.unsqueeze(-1).repeat(1, 3))

        data['pcl_noisy'] = clean + laplace + ball
        data['noise_std'] = sigma
        return data
def load_h5(data_dir="./", filename='PUGAN_poisson_256_poisson_1024.h5'):
    """Load the ``data`` and ``label`` arrays from an HDF5 file.

    Args:
        data_dir: Directory containing the file.
        filename: Name of the HDF5 file.

    Returns:
        Tuple ``(data, label)`` of numpy arrays, cast to float32 and int64
        respectively.
    """
    h5_name = os.path.join(data_dir, filename)
    # Open read-only, and close the handle even if a dataset is missing.
    with h5py.File(h5_name, 'r') as f:
        data = f['data'][:].astype('float32')
        label = f['label'][:].astype('int64')
    return data, label
| self.pcls_up = load_xyz(self.output_pcl_dir) 71 | self.pcls_high = load_xyz(self.gts_pcl_dir) 72 | self.meshes = load_off(self.gts_mesh_dir) 73 | self.pcls_name = list(self.pcls_up.keys()) 74 | 75 | def run(self): 76 | pcls_up, pcls_high, pcls_name = self.pcls_up, self.pcls_high, self.pcls_name 77 | results = {} 78 | for name in tqdm(pcls_name, desc='Evaluate'): 79 | pcl_up = pcls_up[name][:,:3].unsqueeze(0).to(self.device) 80 | if name not in pcls_high: 81 | self.logger.warning('Shape `%s` not found, ignored.' % name) 82 | continue 83 | pcl_high = pcls_high[name].unsqueeze(0).to(self.device) 84 | verts = self.meshes[name]['verts'].to(self.device) 85 | faces = self.meshes[name]['faces'].to(self.device) 86 | 87 | cd = pytorch3d.loss.chamfer_distance(pcl_up, pcl_high)[0].item() 88 | cd_sph = chamfer_distance_unit_sphere(pcl_up, pcl_high)[0].item() 89 | hd_sph = hausdorff_distance_unit_sphere(pcl_up, pcl_high)[0].item() 90 | 91 | # p2f = point_to_mesh_distance_single_unit_sphere( 92 | # pcl=pcl_up[0], 93 | # verts=verts, 94 | # faces=faces 95 | # ).sqrt().mean().item() 96 | if 'blensor' in self.experiment_name: 97 | rotmat = torch.FloatTensor(Rotation.from_euler('xyz', [-90, 0, 0], degrees=True).as_matrix()).to(pcl_up[0]) 98 | p2f = point_mesh_bidir_distance_single_unit_sphere( 99 | pcl=pcl_up[0].matmul(rotmat.t()), 100 | verts=verts, 101 | faces=faces 102 | ).item() 103 | else: 104 | p2f = point_mesh_bidir_distance_single_unit_sphere( 105 | pcl=pcl_up[0], 106 | verts=verts, 107 | faces=faces 108 | ).item() 109 | 110 | results[name] = { 111 | 'cd': cd, 112 | 'cd_sph': cd_sph, 113 | 'p2f': p2f, 114 | 'hd_sph': hd_sph, 115 | } 116 | 117 | results = pd.DataFrame(results).transpose() 118 | res_mean = results.mean(axis=0) 119 | self.logger.info("\n" + repr(results)) 120 | self.logger.info("\nMean\n" + '\n'.join([ 121 | '%s\t%.12f' % (k, v) for k, v in res_mean.items() 122 | ])) 123 | 124 | update_summary( 125 | os.path.join(self.summary_dir, 'Summary_%s.csv' % 
def update_summary(path, model, metrics):
    """Insert or update one model's metrics in a CSV summary table.

    Rows are indexed by model name and columns by metric name. The file is
    created when it does not exist yet.

    Args:
        path: Location of the CSV file.
        model: Row label to create or update.
        metrics: Mapping from metric name to value.

    Returns:
        The updated DataFrame, which has also been written to ``path``.
    """
    if os.path.exists(path):
        # Raw string: "\s" in a normal literal is an invalid escape sequence.
        df = pd.read_csv(path, index_col=0, sep=r"\s*,\s*", engine='python')
    else:
        df = pd.DataFrame()
    for metric, value in metrics.items():
        if metric not in df.columns:
            df[metric] = np.nan
        df.loc[model, metric] = value
    df.to_csv(path, float_format='%.12f')
    return df
174 | """ 175 | ## Center 176 | p_max = pc.max(dim=-2, keepdim=True)[0] 177 | p_min = pc.min(dim=-2, keepdim=True)[0] 178 | center = (p_max + p_min) / 2 # (B, 1, 3) 179 | pc = pc - center 180 | ## Scale 181 | scale = (pc ** 2).sum(dim=-1, keepdim=True).sqrt().max(dim=-2, keepdim=True)[0] / radius # (B, N, 1) 182 | pc = pc / scale 183 | return pc, center, scale 184 | 185 | def normalize_pcl(pc, center, scale): 186 | return (pc - center) / scale 187 | ''' 188 | def pointwise_p2m_distance_normalized(pcl, verts, faces): 189 | assert pcl.dim() == 2 and verts.dim() == 2 and faces.dim() == 2, 'Batch is not supported.' 190 | # Normalize mesh 191 | #print(' before :verts %.6f pcl %.6f' % (verts.abs().max().item(), pcl.abs().max().item())) 192 | verts, center, scale = normalize_sphere(verts.unsqueeze(0)) 193 | verts = verts[0] 194 | # Normalize pcl 195 | pcl = normalize_pcl(pcl.unsqueeze(0), center=center, scale=scale) 196 | pcl = pcl[0] 197 | #print('after :verts %.6f pcl %.6f' % (verts.abs().max().item(), pcl.abs().max().item())) 198 | # Convert them to pytorch3d structures 199 | pcls = pytorch3d.structures.Pointclouds([pcl]) 200 | meshes = pytorch3d.structures.Meshes([verts], [faces]) 201 | 202 | # packed representation for pointclouds 203 | points = pcls.points_packed() # (P, 3) 204 | points_first_idx = pcls.cloud_to_packed_first_idx() 205 | max_points = pcls.num_points_per_cloud().max().item() 206 | 207 | # packed representation for faces 208 | verts_packed = meshes.verts_packed() 209 | faces_packed = meshes.faces_packed() 210 | tris = verts_packed[faces_packed] # (T, 3, 3) 211 | tris_first_idx = meshes.mesh_to_faces_packed_first_idx() 212 | max_tris = meshes.num_faces_per_mesh().max().item() 213 | # point to face distance: shape (P,) 214 | point_to_face = point_face_distance( 215 | points, points_first_idx, tris, tris_first_idx, max_points 216 | ) 217 | return point_to_face 218 | ''' 219 | def point_mesh_bidir_distance_single_unit_sphere(pcl, verts, faces): 220 | """ 
class PointCloudDataset(Dataset):
    """Dataset of ``.xyz`` point clouds, optionally paired with ``.off`` meshes.

    Point clouds are read from
    ``<root>/<dataset>/pointclouds/<split>/<resolution>`` or, when
    ``from_saved`` is set, from
    ``<input_root>/<dataset>_<resolution>_<noise_level>``. With ``need_mesh``
    set, meshes are also read from ``<root>/<dataset>/meshes/test`` and
    attached to every sample under the ``'meshes'`` key.

    Args:
        root: Root directory of the datasets.
        dataset: Dataset name (sub-directory of ``root``).
        split: Split name (sub-directory of ``pointclouds``).
        resolution: Resolution name (sub-directory of the split).
        from_saved: Read previously saved clouds from ``input_root``
            (evaluation) instead of the training layout.
        noise_level: Noise tag used in the saved directory name.
        input_root: Directory that holds saved clouds.
        need_mesh: Also load meshes and return them with each sample.
        transform: Optional callable applied to the sample dict.
    """

    def __init__(self, root, dataset , split, resolution, from_saved = False , noise_level = None , input_root = None , need_mesh = False , transform=None):
        super().__init__()
        if from_saved:
            # Evaluation: clouds were written out by an earlier stage.
            self.pcl_dir = os.path.join(input_root, dataset + "_" + resolution + "_" + str(noise_level))
        else:
            # Training layout.
            self.pcl_dir = os.path.join(root, dataset, 'pointclouds', split, resolution)
        print(self.pcl_dir)
        self.resolution = resolution
        self.transform = transform
        self.need_mesh = need_mesh

        self.pointcloud_names = []
        if self.need_mesh:
            # With meshes, clouds are keyed by name so they can be matched
            # against the mesh of the same name.
            self.mesh_dir = os.path.join(root, dataset, 'meshes', 'test')
            self.meshes = {}
            self.meshes_names = []
            self.pointclouds = {}
        else:
            self.pointclouds = []

        for name, pcl_path in self._iter_files(self.pcl_dir, 'xyz'):
            pcl = torch.FloatTensor(np.loadtxt(pcl_path, dtype=np.float32))
            if self.need_mesh:
                self.pointclouds[name] = pcl
            else:
                self.pointclouds.append(pcl)
            self.pointcloud_names.append(name)

        if self.need_mesh:
            for name, mesh_path in self._iter_files(self.mesh_dir, 'off'):
                verts, faces = pcu.load_mesh_vf(mesh_path)
                self.meshes[name] = {
                    'verts': torch.FloatTensor(verts),
                    'faces': torch.LongTensor(faces),
                }
                self.meshes_names.append(name)

    @staticmethod
    def _iter_files(directory, extension):
        """Yield ``(name, path)`` for entries of ``directory`` ending in ``extension``.

        ``extension`` is the three-character suffix without the dot; ``name``
        is the file name with its last four characters removed.

        Raises:
            FileNotFoundError: If a listed entry does not exist. The message
                names the path that was actually checked.
        """
        for fn in tqdm(os.listdir(directory), desc='Loading'):
            if fn[-3:] != extension:
                continue
            path = os.path.join(directory, fn)
            if not os.path.exists(path):
                raise FileNotFoundError('File not found: %s' % path)
            yield fn[:-4], path

    def __len__(self):
        return len(self.pointclouds)

    def __getitem__(self, idx):
        """Return the sample dict for ``idx``.

        The dict holds ``'pcl_clean'`` (a clone of the stored cloud) and
        ``'name'``; ``transform`` is applied if set. In mesh mode the mesh
        dict is attached afterwards under ``'meshes'``, so the transform
        never sees it.
        """
        name = self.pointcloud_names[idx]
        key = name if self.need_mesh else idx
        data = {
            'pcl_clean': self.pointclouds[key].clone(),
            'name': name,
        }
        if self.transform is not None:
            data = self.transform(data)
        if self.need_mesh:
            data['meshes'] = self.meshes[name]
        return data
class UpsampleDataset(Dataset):
    """Dataset pairing sparse point clouds with denser ground-truth clouds.

    Sparse clouds are read from
    ``<root>/<dataset>/pointclouds/<split>/<resolution>`` or, when
    ``from_saved`` is set, from
    ``<input_root>/<dataset>_<resolution>_<noise_level>``. Ground truth is
    read from
    ``<root>/<dataset>/pointclouds/<split>/<points * rate>_poisson``, where
    ``points`` is the integer before the first ``_`` in ``resolution``.
    A ground-truth file has the same name as its sparse cloud.

    Each sample holds an upsampling input made of ``rate - 1`` noisy copies
    of the sparse cloud followed by the sparse cloud itself.

    Args:
        root: Root directory of the datasets.
        dataset: Dataset name (sub-directory of ``root``).
        split: Split name.
        resolution: Resolution name whose leading integer is the point count.
        rate: Upsampling rate.
        noise_min: Lower bound of the per-copy noise standard deviation.
        noise_max: Upper bound of the per-copy noise standard deviation.
        from_saved: Read the sparse clouds from ``input_root`` instead.
        input_root: Directory that holds saved clouds.
        need_mesh: Also load meshes from ``<root>/<dataset>/meshes/test``.
        transform: Optional callable applied to the sample dict.
        noise_level: Noise tag used in the saved directory name. Defaults
            to 0.03, the value that used to be hard-coded.
    """

    def __init__(self, root, dataset , split, resolution, rate , noise_min , noise_max , from_saved = False , input_root = None , need_mesh = False , transform=None, noise_level=0.03):
        super().__init__()
        self.resolution = resolution
        self.rate = rate
        self.transform = transform
        self.need_mesh = need_mesh
        self.noise_min = noise_min
        self.noise_max = noise_max
        if from_saved:
            self.pcl_dir = os.path.join(input_root, dataset + "_" + resolution + "_" + str(noise_level))
        else:
            self.pcl_dir = os.path.join(root, dataset, 'pointclouds', split, resolution)
        # The leading integer of the resolution name is the input point
        # count; ground truth has `rate` times as many points.
        self.gt_resolution = str(int(self.resolution.split('_')[0]) * self.rate)
        self.gt_dir = os.path.join(root, dataset, 'pointclouds', split, self.gt_resolution + "_poisson")
        print(self.pcl_dir)
        print(self.gt_dir)

        self.pointclouds = []
        self.gt_pointclouds = []
        self.pointcloud_names = []
        if self.need_mesh:
            self.mesh_dir = os.path.join(root, dataset, 'meshes', 'test')
            self.meshes = {}
            self.meshes_names = []

        for fn, pcl_path in self._iter_files(self.pcl_dir, 'xyz'):
            gt_path = os.path.join(self.gt_dir, fn)
            if not os.path.exists(gt_path):
                raise FileNotFoundError('File not found: %s' % gt_path)
            self.pointclouds.append(torch.FloatTensor(np.loadtxt(pcl_path, dtype=np.float32)))
            self.gt_pointclouds.append(torch.FloatTensor(np.loadtxt(gt_path, dtype=np.float32)))
            self.pointcloud_names.append(fn[:-4])

        if self.need_mesh:
            for fn, mesh_path in self._iter_files(self.mesh_dir, 'off'):
                verts, faces = pcu.load_mesh_vf(mesh_path)
                name = fn[:-4]
                self.meshes[name] = {
                    'verts': torch.FloatTensor(verts),
                    'faces': torch.LongTensor(faces),
                }
                self.meshes_names.append(name)

    @staticmethod
    def _iter_files(directory, extension):
        """Yield ``(file_name, path)`` for entries of ``directory`` ending in ``extension``.

        Raises:
            FileNotFoundError: If a listed entry does not exist. The message
                names the path that was actually checked.
        """
        for fn in tqdm(os.listdir(directory), desc='Loading'):
            if fn[-3:] != extension:
                continue
            path = os.path.join(directory, fn)
            if not os.path.exists(path):
                raise FileNotFoundError('File not found: %s' % path)
            yield fn, path

    def __len__(self):
        return len(self.pointclouds)

    def normalize(self,pcl, center=None, scale=None):
        """Center ``pcl`` and scale it, computing whichever of ``center`` / ``scale`` is missing.

        Args:
            pcl: The point cloud to be normalized, (N, 3).
            center: Optional center to subtract; defaults to the midpoint
                of the axis-aligned bounding box, (1, 3).
            scale: Optional divisor; defaults to the largest point norm
                after centering, (1, 1).

        Returns:
            Tuple ``(normalized_pcl, center, scale)``.
        """
        if center is None:
            p_max = pcl.max(dim=0, keepdim=True)[0]
            p_min = pcl.min(dim=0, keepdim=True)[0]
            center = (p_max + p_min) / 2  # (1, 3)
        pcl = pcl - center
        if scale is None:
            scale = (pcl ** 2).sum(dim=1, keepdim=True).sqrt().max(dim=0, keepdim=True)[0]  # (1, 1)
        pcl = pcl / scale
        return pcl, center, scale

    def __getitem__(self, idx):
        """Return the sample dict for ``idx``.

        Keys: ``'gt'`` (ground-truth cloud), ``'ups'`` (``rate - 1`` noisy
        copies followed by the clean input, concatenated along the point
        axis), ``'original'`` (the clean input) and ``'name'``. In mesh
        mode ``'meshes'`` is attached after the transform has run.
        """
        gt = self.gt_pointclouds[idx].clone()
        # Clone so that an in-place transform cannot modify the stored
        # cloud, matching how 'gt' is handled.
        original = self.pointclouds[idx].clone()
        name = self.pointcloud_names[idx]

        replicas = []
        for _ in range(self.rate - 1):
            # Each replica gets its own noise scale drawn from the range.
            noise_std = random.uniform(self.noise_min, self.noise_max)
            replicas.append(original + torch.randn_like(original) * noise_std)
        replicas.append(original)
        pcl_noisy = torch.cat(replicas, dim=0)

        data = {
            'gt': gt,
            'ups' : pcl_noisy,
            'original' : original,
            'name': name,
        }
        if self.transform is not None:
            data = self.transform(data)
        if self.need_mesh:
            data['meshes'] = self.meshes[name]
        return data
def batch_pairwise_dist(x, y):
    """Compute squared Euclidean distances between two batches of point sets.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b``. The squared
    norms are computed directly rather than read off the diagonals of full
    Gram matrices, which avoids two (N, N) intermediates and any
    device-specific index handling.

    Args:
        x: (B, Nx, d) tensor.
        y: (B, Ny, d) tensor.

    Returns:
        (B, Nx, Ny) tensor whose entry ``[b, i, j]`` is the squared
        distance between ``x[b, i]`` and ``y[b, j]``.
    """
    x_sq = (x * x).sum(dim=-1).unsqueeze(2)  # (B, Nx, 1)
    y_sq = (y * y).sum(dim=-1).unsqueeze(1)  # (B, 1, Ny)
    cross = torch.bmm(x, y.transpose(2, 1))  # (B, Nx, Ny)
    return x_sq + y_sq - 2 * cross
def gen_grid(rate):
    """Build a 2-D grid of ``rate`` codes spanning ``[-0.2, 0.2]`` on each axis.

    ``rate`` is factored as ``num_x * num_y``, where ``num_x`` is the
    largest divisor of ``rate`` not exceeding ``floor(sqrt(rate)) + 1``.
    The grid is the Cartesian product of ``num_x`` and ``num_y`` evenly
    spaced values, with the second coordinate varying fastest.

    Args:
        rate: Number of grid points (int).

    Returns:
        Tensor of shape (rate, 2).
    """
    upper = int(math.sqrt(rate)) + 1
    # 1 divides every integer, so the search always finds a divisor.
    num_x = next(d for d in range(upper, 0, -1) if rate % d == 0)
    num_y = rate // num_x

    grid_x = torch.linspace(-0.2, 0.2, num_x)
    grid_y = torch.linspace(-0.2, 0.2, num_y)

    # Explicit broadcasting gives the same 'ij' layout as the legacy
    # torch.meshgrid default, without relying on that default (current
    # PyTorch warns when `indexing` is omitted).
    x = grid_x.unsqueeze(1).expand(num_x, num_y)
    y = grid_y.unsqueeze(0).expand(num_x, num_y)
    return torch.stack([x, y], dim=-1).reshape(-1, 2)
class duplicate_up(torch.nn.Module):
    """Duplicate per-point features ``rate`` times and tag each copy with a 2-D grid code.

    The features are tiled ``rate`` times along the point axis, a
    two-channel grid code from ``gen_grid`` is appended to each copy, and
    the result goes through a point-wise MLP
    (``in_channels + 2 -> 256 -> 128``).

    Args:
        rate: Number of copies per input point.
        in_channels: Width of the incoming features. Defaults to 360, the
            value that used to be hard-coded.
    """

    def __init__(self , rate, in_channels=360) :
        super().__init__()

        self.rate = rate
        dims = [in_channels + 2, 256, 128]
        conv_layers = []

        last = len(dims) - 2
        for i, (c_in, c_out) in enumerate(zip(dims[:-1], dims[1:])):
            conv_layers += [
                Conv1d(c_in, c_out, kernel_size=1),
                BatchNorm1d(c_out),
            ]
            # No activation after the final projection.
            if i < last:
                conv_layers.append(ReLU())
        self.layers = Sequential(*conv_layers)

    def forward(self, coarse_feature):
        """
        Args:
            coarse_feature: (B, N, d) per-point features.

        Returns:
            (B, 128, N*rate) features.
        """
        B, N, d = coarse_feature.shape
        feat = coarse_feature.repeat(1, self.rate, 1)    # (B, N*rate, d)
        # Put the grid on the device (and dtype) of the features rather
        # than assuming CUDA, so the module also runs on CPU.
        grid = gen_grid(self.rate).unsqueeze(0).to(
            device=coarse_feature.device, dtype=coarse_feature.dtype
        )                                                # (1, rate, 2)
        # Tile so that copy r of every point (rows r*N .. r*N + N - 1)
        # carries grid code r, matching the layout of `feat`.
        grid = grid.repeat(B, 1, N)                      # (B, rate, 2*N)
        grid = grid.reshape(B, N * self.rate, 2)         # (B, N*rate, 2)
        feat = torch.cat([feat, grid], dim=-1).permute(0, 2, 1).contiguous()  # (B, d+2, N*rate)

        return self.layers(feat)
class FullyConnected(torch.nn.Module):
    """Linear layer followed by an optional activation.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the linear layer has a bias term.
        activation: ``None`` (identity), ``'relu'``, ``'elu'`` or
            ``'lrelu'`` (leaky ReLU with negative slope 0.1).

    Raises:
        ValueError: If ``activation`` is not one of the supported values.
    """

    def __init__(self, in_features, out_features, bias=True, activation=None):
        super().__init__()

        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)

        if activation is None:
            self.activation = torch.nn.Identity()
        elif activation == 'relu':
            self.activation = torch.nn.ReLU()
        elif activation == 'elu':
            self.activation = torch.nn.ELU(alpha=1.0)
        elif activation == 'lrelu':
            self.activation = torch.nn.LeakyReLU(0.1)
        else:
            # Name the bad value and the valid choices.
            raise ValueError(
                "Unsupported activation %r; expected None, 'relu', 'elu' or 'lrelu'."
                % (activation,)
            )

    def forward(self, x):
        """Apply the linear map over the last dimension, then the activation."""
        return self.activation(self.linear(x))
class ResnetBlockConv2d(nn.Module):
    """2D-convolutional residual block built from 1x1 convolutions.

    Computes ``shortcut(x) + fc_0(x)``. The shortcut is the identity when
    ``size_in == size_out`` and a 1x1 convolution otherwise.

    Args:
        size_in (int): input dimension
        size_out (int): output dimension
        norm_method (str): ``'batch_norm'`` or ``'sync_batch_norm'``. The
            value is validated, but no normalization layer is created.
        legacy (bool): unused; kept for interface compatibility.

    Raises:
        ValueError: If ``norm_method`` is not a supported value.
    """

    def __init__(self, size_in, size_out,
                 norm_method='batch_norm', legacy=False):
        super().__init__()
        # Attributes
        self.size_in = size_in
        self.size_out = size_out

        # ValueError is a subclass of Exception, so callers that catch
        # Exception still work.
        if norm_method not in ('batch_norm', 'sync_batch_norm'):
            raise ValueError("Invalid norm method: %s" % norm_method)

        # Submodules
        self.fc_0 = nn.Conv2d(size_in, size_out, (1,1))
        self.actvn = nn.ReLU()
        if size_in == size_out:
            self.shortcut = None
        else:
            # Must be Conv2d: a Conv1d here rejects the 4-D input that
            # fc_0 operates on.
            self.shortcut = nn.Conv2d(size_in, size_out, (1,1))

    def forward(self, x):
        """
        Args:
            x: (B, size_in, H, W)

        Returns:
            (B, size_out, H, W)
        """
        dx = self.fc_0(x)
        x_s = x if self.shortcut is None else self.shortcut(x)
        return x_s + dx
def farthest_point_sampling(pcls, num_pnts):
    """Sub-sample every cloud in a batch with farthest point sampling.

    Args:
        pcls: A batch of point clouds, (B, N, 3).
        num_pnts: Target number of points.

    Returns:
        Tuple ``(sampled, indices)``: ``sampled`` is the per-cloud
        selections concatenated along the batch axis, and ``indices`` is a
        list with one index tensor per cloud.
    """
    batch_size, total_pnts = pcls.size(0), pcls.size(1)
    # The ratio is padded by 0.01 and the result truncated to `num_pnts`
    # (presumably so `fps` never returns too few indices; verify against
    # torch_cluster).
    ratio = 0.01 + num_pnts / total_pnts

    indices = [
        fps(pcls[b], ratio=ratio, random_start=False)[:num_pnts]
        for b in range(batch_size)
    ]
    # Slicing with b:b+1 keeps the batch axis so the pieces can be stacked.
    picked = [pcls[b:b+1, idx, :] for b, idx in enumerate(indices)]
    sampled = torch.cat(picked, dim=0)
    return sampled, indices
def pointwise_p2m_distance_normalized(pcl, verts, faces):
    """Compute the per-point distance from a point cloud to a mesh in the mesh's normalized frame.

    The mesh vertices are normalized with ``normalize_sphere``, and the
    point cloud is mapped into the same frame with the mesh's center and
    scale.

    Args:
        pcl: (N, 3).
        verts: (M, 3).
        faces: LongTensor, (T, 3).

    Returns:
        Tensor of shape (N, ) as returned by pytorch3d's
        ``point_face_distance`` (documented there as squared distances to
        the closest triangle; confirm against the installed version).
    """
    assert pcl.dim() == 2 and verts.dim() == 2 and faces.dim() == 2, 'Batch is not supported.'

    # Normalize the mesh, then reuse its center/scale for the point cloud.
    verts, center, scale = normalize_sphere(verts.unsqueeze(0))
    verts = verts[0]
    pcl = normalize_pcl(pcl.unsqueeze(0), center=center, scale=scale)[0]

    # Convert them to pytorch3d structures
    pcls = pytorch3d.structures.Pointclouds([pcl])
    meshes = pytorch3d.structures.Meshes([verts], [faces])

    # packed representation for pointclouds
    points = pcls.points_packed()  # (P, 3)
    points_first_idx = pcls.cloud_to_packed_first_idx()
    max_points = pcls.num_points_per_cloud().max().item()

    # packed representation for faces
    tris = meshes.verts_packed()[meshes.faces_packed()]  # (T, 3, 3)
    tris_first_idx = meshes.mesh_to_faces_packed_first_idx()

    # point to face distance: shape (P,)
    return point_face_distance(
        points, points_first_idx, tris, tris_first_idx, max_points
    )
def knn_group(x:torch.FloatTensor, idx:torch.LongTensor):
    """Gather the feature vectors of each query's neighbours.

    :param  x:      (B, N, F) source features
    :param  idx:    (B, M, k) indices into the N axis of ``x``
    :return (B, M, k, F), where ``out[b, m, j] = x[b, idx[b, m, j]]``
    """
    batch, n_src, feat_dim = x.shape
    _, n_query, n_nbr = idx.shape

    # Broadcast views (no copy): every query sees all source features, and
    # every index is repeated across the feature axis.
    src = x[:, None, :, :].expand(batch, n_query, n_src, feat_dim)
    gather_idx = idx[..., None].expand(batch, n_query, n_nbr, feat_dim)

    return src.gather(2, gather_idx)