├── model ├── __init__.py ├── log.txt ├── Diffusion.py ├── Encoder.py ├── Decoder_Component.py ├── DiffusionCop.py ├── EncoderCompress.py ├── DiffusionSR.py ├── DiffusionPretrain.py ├── Encoder_Component.py ├── EncoderCop.py └── EncoderSR.py ├── utils ├── __init__.py ├── misc.py ├── visualization.py └── logger.py ├── metrics ├── __init__.py ├── .gitignore ├── pytorch_structural_losses │ ├── .gitignore │ ├── __init__.py │ ├── src │ │ ├── nndistance.cuh │ │ ├── approxmatch.cuh │ │ ├── utils.hpp │ │ ├── nndistance.cu │ │ ├── structural_loss.cpp │ │ └── approxmatch.cu │ ├── pybind │ │ ├── extern.hpp │ │ └── bind.cpp │ ├── setup.py │ ├── nn_distance.py │ ├── match_cost.py │ └── Makefile └── evaluation_metrics.py ├── assets └── structure.png ├── .gitignore ├── pretrain_model └── .gitignore ├── .idea ├── .gitignore ├── misc.xml ├── vcs.xml ├── other.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml └── DiffPointMAE.iml ├── dataset └── .gitignore ├── requirements.txt ├── train_encoder.py ├── eval_compression.py ├── eval_cop.py ├── eval_upsampling.py ├── train_decoder.py ├── readme.md └── eval_diffpmae.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/.gitignore: -------------------------------------------------------------------------------- 1 | StructuralLosses 2 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/.gitignore: -------------------------------------------------------------------------------- 1 | PyTorchStructuralLosses.egg-info/ 2 | -------------------------------------------------------------------------------- /assets/structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TyraelDLee/DiffPMAE/HEAD/assets/structure.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | eval_diffpmae_ts300.py 2 | train_decoder_kitti.py 3 | train_encoder_kitti.py 4 | train_encoder_kitti_object.py 5 | /results -------------------------------------------------------------------------------- /pretrain_model/.gitignore: -------------------------------------------------------------------------------- 1 | /completion 2 | /compress 3 | /encocder_cop_exp 4 | /new_start_eval_2023_10_21_15_37 5 | /pretrain 6 | /sr -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/__init__.py: -------------------------------------------------------------------------------- 1 | #import torch 2 | 3 | #from MakePytorchBackend import AddGPU, Foo, ApproxMatch 4 | 5 | #from Add import add_gpu, approx_match 6 | 7 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # 默认忽略的文件 2 | /shelf/ 3 | /workspace.xml 4 | # 基于编辑器的 HTTP 客户端请求 5 | 
/httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /dataset/.gitignore: -------------------------------------------------------------------------------- 1 | data_depth_velodyne.zip 2 | data_object_velodyne.zip 3 | data_odometry_calib.zip 4 | data_odometry_labels.zip 5 | /data_odometry_labels 6 | /KITTI 7 | /KITTI_depth 8 | /ModelNet 9 | /PU1K 10 | /SEMANTIC_KITTI_DIR 11 | /ShapeNet55 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy~=1.21.5 2 | matplotlib~=3.1.1 3 | h5py~=2.9.0 4 | yaml~=0.1.7 5 | pyyaml~=5.1.2 6 | easydict~=1.10 7 | open3d~=0.17.0 8 | tqdm~=4.64.1 9 | termcolor~=1.1.0 10 | scipy~=1.4.1 11 | scikit-learn~=0.23.2 12 | setuptools~=41.4.0 -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | from pointnet2_ops import pointnet2_utils 2 | 3 | 4 | def fps(data, number): 5 | ''' 6 | data B N 3 7 | number int 8 | ''' 9 | fps_idx = pointnet2_utils.furthest_point_sample(data, number) 10 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous() 11 | return fps_data 12 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/src/nndistance.cuh: -------------------------------------------------------------------------------- 1 | void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); 2 | void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); 3 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/pybind/extern.hpp: -------------------------------------------------------------------------------- 1 | std::vector ApproxMatch(at::Tensor in_a, at::Tensor in_b); 2 | at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match); 3 | std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match); 4 | 5 | 
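// Nearest-neighbour distances between two point sets (both directions), matching the
// shapes documented in src/structural_loss.cpp:
//   set_d : batch_size * #dataset_points * 3
//   set_q : batch_size * #query_points  * 3
// NNDistance returns {dist1, idx1, dist2, idx2}; NNDistanceGrad routes the two
// incoming distance gradients back onto the coordinates of set_d and set_q.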
std::vector NNDistance(at::Tensor set_d, at::Tensor set_q); 6 | std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2); 7 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/pybind/bind.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "R:\\Documents\\developing\\COMP5702_04\\models\\metrics\\pytorch_structural_losses\\pybind\\extern.hpp" 6 | 7 | namespace py = pybind11; 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 10 | m.def("ApproxMatch", &ApproxMatch); 11 | m.def("MatchCost", &MatchCost); 12 | m.def("MatchCostGrad", &MatchCostGrad); 13 | m.def("NNDistance", &NNDistance); 14 | m.def("NNDistanceGrad", &NNDistanceGrad); 15 | } 16 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/src/approxmatch.cuh: -------------------------------------------------------------------------------- 1 | /* 2 | template 3 | void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N, 4 | cudaStream_t stream); 5 | */ 6 | void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream); 7 | void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream); 8 | void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream); 9 | -------------------------------------------------------------------------------- /.idea/DiffPointMAE.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/src/utils.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | class Formatter { 6 | public: 7 | Formatter() {} 8 | ~Formatter() {} 9 | 10 | template Formatter &operator<<(const Type &value) { 11 | stream_ << value; 12 | return *this; 13 | } 14 | 15 | std::string str() const { return stream_.str(); } 16 | operator std::string() const { return stream_.str(); } 17 | 18 | enum ConvertToString { to_str }; 19 | 20 | std::string operator>>(ConvertToString) { return stream_.str(); } 21 | 22 | private: 23 | std::stringstream stream_; 24 | Formatter(const Formatter &); 25 | Formatter &operator=(Formatter &); 26 | }; 27 | -------------------------------------------------------------------------------- /model/log.txt: -------------------------------------------------------------------------------- 1 | [2023-10-21 15:37:35,969::INFO] loading dataset 2 | [2023-10-21 15:37:36,432::INFO] Training Stable Diffusion 3 | [2023-10-21 15:37:36,433::INFO] config: 4 | [2023-10-21 15:37:36,433::INFO] Namespace(batch_size=32, beta_1=0.0001, beta_T=0.05, decoder_depth=4, decoder_num_heads=4, decoder_trans_dim=192, depth=12, device='cuda', drop_path_rate=0.1, encoder_dims=384, group_size=32, log=True, loss='cdl2', mask_ratio=0.75, mask_type='rand', num_group=64, num_heads=6, num_output=8192, num_points=2048, num_steps=200, save_dir='./results', sched_mode='linear', trans_dim=384, transformer_dim_forward=128, transformer_drop_out=0.1, val_batch_size=1) 5 | [2023-10-21 
15:37:36,433::INFO] dataset loaded 6 | [2023-10-21 15:37:38,342::INFO] [Point_MAE] 7 | [2023-10-21 15:37:38,427::INFO] [Point_MAE] divide point cloud into G64 x S32 points ... 8 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 3 | 4 | # Python interface 5 | setup( 6 | name='PyTorchStructuralLosses', 7 | version='0.1.0', 8 | install_requires=['torch'], 9 | packages=['StructuralLosses'], 10 | package_dir={'StructuralLosses': '.\\'}, 11 | ext_modules=[ 12 | CUDAExtension( 13 | name='StructuralLossesBackend', 14 | include_dirs=['.\\'], 15 | sources=[ 16 | 'pybind/bind.cpp', 17 | ], 18 | libraries=['make_pytorch'], 19 | library_dirs=['objs'], 20 | # extra_compile_args=['-g'] 21 | ) 22 | ], 23 | cmdclass={'build_ext': BuildExtension}, 24 | author='Christopher B. Choy', 25 | author_email='chrischoy@ai.stanford.edu', 26 | description='Tutorial for Pytorch C++ Extension with a Makefile', 27 | keywords='Pytorch C++ Extension', 28 | url='https://github.com/chrischoy/MakePytorchPlusPlus', 29 | zip_safe=False, 30 | ) 31 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/nn_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | # from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad 4 | from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad 5 | 6 | # Inherit from Function 7 | class NNDistanceFunction(Function): 8 | # Note that both forward and backward are @staticmethods 9 | @staticmethod 10 | # bias is an optional argument 11 | def forward(ctx, seta, setb): 12 | #print("Match Cost Forward") 13 | ctx.save_for_backward(seta, setb) 14 | ''' 15 | input: 16 | set1 : batch_size * #dataset_points * 3 17 | set2 : batch_size * #query_points * 3 18 | returns: 19 | dist1, idx1, dist2, idx2 20 | ''' 21 | dist1, idx1, dist2, idx2 = NNDistance(seta, setb) 22 | ctx.idx1 = idx1 23 | ctx.idx2 = idx2 24 | return dist1, dist2 25 | 26 | # This function has only a single output, so it gets only one gradient 27 | @staticmethod 28 | def backward(ctx, grad_dist1, grad_dist2): 29 | #print("Match Cost Backward") 30 | # This is a pattern that is very convenient - at the top of backward 31 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to 32 | # None. Thanks to the fact that additional trailing Nones are 33 | # ignored, the return statement is simple even when the function has 34 | # optional inputs. 
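        # Recover the saved point sets and the nearest-neighbour index maps from the
        # forward pass, then let the CUDA kernel scatter the incoming distance
        # gradients back onto the coordinates of both sets.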
35 | seta, setb = ctx.saved_tensors 36 | idx1 = ctx.idx1 37 | idx2 = ctx.idx2 38 | grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2) 39 | return grada, gradb 40 | 41 | nn_distance = NNDistanceFunction.apply 42 | 43 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/match_cost.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad 4 | 5 | # Inherit from Function 6 | class MatchCostFunction(Function): 7 | # Note that both forward and backward are @staticmethods 8 | @staticmethod 9 | # bias is an optional argument 10 | def forward(ctx, seta, setb): 11 | #print("Match Cost Forward") 12 | ctx.save_for_backward(seta, setb) 13 | ''' 14 | input: 15 | set1 : batch_size * #dataset_points * 3 16 | set2 : batch_size * #query_points * 3 17 | returns: 18 | match : batch_size * #query_points * #dataset_points 19 | ''' 20 | match, temp = ApproxMatch(seta, setb) 21 | ctx.match = match 22 | cost = MatchCost(seta, setb, match) 23 | return cost 24 | 25 | """ 26 | grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match) 27 | return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None] 28 | """ 29 | # This function has only a single output, so it gets only one gradient 30 | @staticmethod 31 | def backward(ctx, grad_output): 32 | #print("Match Cost Backward") 33 | # This is a pattern that is very convenient - at the top of backward 34 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to 35 | # None. Thanks to the fact that additional trailing Nones are 36 | # ignored, the return statement is simple even when the function has 37 | # optional inputs. 38 | seta, setb = ctx.saved_tensors 39 | #grad_input = grad_weight = grad_bias = None 40 | grada, gradb = MatchCostGrad(seta, setb, ctx.match) 41 | grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2) 42 | return grada*grad_output_expand, gradb*grad_output_expand 43 | 44 | match_cost = MatchCostFunction.apply 45 | 46 | -------------------------------------------------------------------------------- /model/Diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | 6 | class VarianceSchedule(nn.Module): 7 | 8 | def __init__(self, num_steps, beta_1, beta_T, mode='linear'): 9 | super().__init__() 10 | assert mode in ('linear') 11 | self.num_steps = num_steps 12 | self.beta_1 = beta_1 13 | self.beta_T = beta_T 14 | self.mode = mode 15 | 16 | if mode == 'linear': 17 | betas = torch.linspace(beta_1, beta_T, steps=num_steps) 18 | # create a 1D tensor of size num_steps, values are evenly spaced from beta1 and betaT. 19 | # beta1, ... , betaT are hyper-parameter that control the diffusion rate of the process. 20 | 21 | betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding a 0 at beginning. 22 | 23 | alphas = 1 - betas 24 | log_alphas = torch.log(alphas) # (7) 25 | for i in range(1, log_alphas.size(0)): # 1 to T 26 | log_alphas[i] += log_alphas[i - 1] 27 | # log alpha add all previous step 28 | alpha_bars = log_alphas.exp() # ? 
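            # The cumulative sum of log-alphas followed by exp() gives the running
            # product alpha_bar_t = prod_{s<=t} alpha_s, the quantity the posterior
            # variance in (11) below is built from.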
29 | 30 | sigmas_flex = torch.sqrt(betas) 31 | sigmas_inflex = torch.zeros_like(sigmas_flex) # a 0 filled tensor with sigmas_flex dimension 32 | for i in range(1, sigmas_flex.size(0)): 33 | sigmas_inflex[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i] # (11) 34 | sigmas_inflex = torch.sqrt(sigmas_inflex) 35 | 36 | self.register_buffer('betas', betas) 37 | self.register_buffer('alphas', alphas) 38 | self.register_buffer('alpha_bars', alpha_bars) 39 | self.register_buffer('sigmas_flex', sigmas_flex) 40 | self.register_buffer('sigmas_inflex', sigmas_inflex) 41 | 42 | def uniform_sample_t(self, batch_size): 43 | ts = np.random.choice(np.arange(1, self.num_steps+1), batch_size) 44 | return ts.tolist() 45 | 46 | def get_sigmas(self, t, flexibility): 47 | assert 0 <= flexibility <= 1 48 | sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility) 49 | return sigmas 50 | 51 | 52 | class TimeEmbedding(nn.Module): 53 | """ 54 | Sinusoidal Time Embedding for the diffusion process. 55 | 56 | Input: 57 | Timestep: the current timestep, in range [1, ..., T] 58 | """ 59 | def __init__(self, dim): 60 | super().__init__() 61 | self.emb_dim = dim 62 | 63 | def forward(self, ts): 64 | half_dim = self.emb_dim // 2 65 | emb = math.log(10000) / (half_dim - 1) 66 | emb = torch.exp(torch.arange(half_dim, device=ts.device) * -emb) 67 | emb = ts[:, None] * emb[None, :] 68 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 69 | return emb -------------------------------------------------------------------------------- /model/Encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | from metrics.evaluation_metrics import chamfer_distance_l1, chamfer_distance_l2 4 | from model.Encoder_Component import * 5 | 6 | class Encoder_Module(nn.Module): 7 | def __init__(self, config): 8 | super().__init__() 9 | self.config = config 10 | self.trans_dim = config.trans_dim 11 | self.AE_encoder = PointTransformer(config) 12 | self.group_size = config.group_size 13 | self.num_group = config.num_group 14 | self.num_output = config.num_output 15 | self.num_channel = 3 16 | self.drop_path_rate = config.drop_path_rate 17 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim)) 18 | 19 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size) 20 | 21 | # prediction head 22 | self.increase_dim = nn.Sequential( 23 | nn.Conv1d(self.trans_dim, (self.num_channel * self.num_output) // self.num_group, 1) 24 | ) 25 | 26 | trunc_normal_(self.mask_token, std=.02) 27 | self.loss = config.loss 28 | # loss 29 | self.build_loss_func(self.loss) 30 | 31 | def build_loss_func(self, loss_type): 32 | if loss_type == "cdl1": 33 | self.loss_func = chamfer_distance_l1 34 | elif loss_type == 'cdl2': 35 | self.loss_func = chamfer_distance_l2 36 | elif loss_type == 'mse': 37 | self.loss_func = F.mse_loss 38 | else: 39 | raise NotImplementedError 40 | 41 | def forward(self, pts, hr_pt): 42 | x_vis, x_msk, mask, center, vis_pc, msk_pc = self.encode(pts, False) 43 | 44 | B, _, C = x_vis.shape # B VIS C 45 | x_full = torch.cat([x_vis, x_msk], dim=1) 46 | rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 3 47 | loss1 = self.loss_func(rebuild_points, hr_pt) 48 | return loss1 49 | 50 | def encode(self, pt, masked=False): 51 | B, _, N = pt.shape 52 | neighborhood, center = self.group_divider(pt) 53 | x_vis, mask, x_masked = self.AE_encoder(neighborhood, center) 54 
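        # x_vis / x_masked are the encoder tokens for the visible and masked groups;
        # `mask` is the per-group boolean used by self.neighborhood() to split the
        # grouped points back into visible and masked patches.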
| if masked: 55 | return x_vis, mask, center 56 | else: 57 | vis_pc, msk_pc = self.neighborhood(neighborhood, center, mask, x_vis) 58 | return x_vis, x_masked, mask, center, vis_pc, msk_pc 59 | 60 | def evaluate(self, x_vis, x_msk): 61 | B, _, C = x_vis.shape # B VIS C 62 | x_full = torch.cat([x_vis, x_msk], dim=1) 63 | rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 3 64 | return rebuild_points.reshape(-1, 3).unsqueeze(0) 65 | 66 | def neighborhood(self, neighborhood, center, mask, x_vis): 67 | B, M, N = x_vis.shape 68 | vis_point = neighborhood[~mask].reshape(B * M, -1, 3) 69 | full_vis = vis_point + center[~mask].unsqueeze(1) 70 | msk_point = neighborhood[mask].reshape(B * (self.num_group - M), -1, 3) 71 | full_msk = msk_point + center[mask].unsqueeze(1) 72 | 73 | full_vis = full_vis.reshape(B, -1, 3) 74 | full_msk = full_msk.reshape(B, -1, 3) 75 | 76 | return full_vis, full_msk 77 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/Makefile: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Uncomment for debugging 3 | # DEBUG := 1 4 | # Pretty build 5 | # Q ?= @ 6 | 7 | CXX := g++ 8 | PYTHON := python 9 | NVCC := /usr/local/cuda/bin/nvcc 10 | 11 | # PYTHON Header path 12 | PYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())') 13 | PYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]') 14 | PYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]') 15 | 16 | # CUDA ROOT DIR that contains bin/ lib64/ and include/ 17 | # CUDA_DIR := /usr/local/cuda 18 | CUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())') 19 | 20 | INCLUDE_DIRS := ./ $(CUDA_DIR)/include 21 | 22 | INCLUDE_DIRS += $(PYTHON_HEADER_DIR) 23 | INCLUDE_DIRS += $(PYTORCH_INCLUDES) 24 | 25 | # Custom (MKL/ATLAS/OpenBLAS) include and lib directories. 26 | # Leave commented to accept the defaults for your choice of BLAS 27 | # (which should work)! 28 | # BLAS_INCLUDE := /path/to/your/blas 29 | # BLAS_LIB := /path/to/your/blas 30 | 31 | ############################################################################### 32 | SRC_DIR := ./src 33 | OBJ_DIR := ./objs 34 | CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp) 35 | CU_SRCS := $(wildcard $(SRC_DIR)/*.cu) 36 | OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS)) 37 | CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS)) 38 | STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a 39 | 40 | # CUDA architecture setting: going with all of them. 41 | # For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility. 42 | # For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility. 43 | CUDA_ARCH := -gencode arch=compute_61,code=sm_61 \ 44 | -gencode arch=compute_61,code=compute_61 \ 45 | -gencode arch=compute_52,code=sm_52 46 | 47 | # We will also explicitly add stdc++ to the link target. 
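# (Note: the caffe2 / caffe2_gpu entries below only exist in older PyTorch builds;
# on recent versions their symbols live in libtorch, so they may need to be dropped.)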
48 | LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu 49 | 50 | # Debugging 51 | ifeq ($(DEBUG), 1) 52 | COMMON_FLAGS += -DDEBUG -g -O0 53 | # https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/ 54 | NVCCFLAGS += -g -G # -rdc true 55 | else 56 | COMMON_FLAGS += -DNDEBUG -O3 57 | endif 58 | 59 | WARNINGS := -Wall -Wno-sign-compare -Wcomment 60 | 61 | INCLUDE_DIRS += $(BLAS_INCLUDE) 62 | 63 | # Automatic dependency generation (nvcc is handled separately) 64 | CXXFLAGS += -MMD -MP 65 | 66 | # Complete build flags. 67 | COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \ 68 | -DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0 69 | CXXFLAGS += -pthread -fPIC -fwrapv -std=c++11 $(COMMON_FLAGS) $(WARNINGS) 70 | NVCCFLAGS += -std=c++11 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS) 71 | 72 | all: $(STATIC_LIB) 73 | $(PYTHON) setup.py build 74 | @ mv build/lib.linux-x86_64-3.6/StructuralLosses .. 75 | @ mv build/lib.linux-x86_64-3.6/*.so ../StructuralLosses/ 76 | @- $(RM) -rf $(OBJ_DIR) build objs 77 | 78 | $(OBJ_DIR): 79 | @ mkdir -p $@ 80 | @ mkdir -p $@/cuda 81 | 82 | $(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR) 83 | @ echo CXX $< 84 | $(Q)$(CXX) $< $(CXXFLAGS) -c -o $@ 85 | 86 | $(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR) 87 | @ echo NVCC $< 88 | $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \ 89 | -odir $(@D) 90 | $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@ 91 | 92 | $(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR) 93 | $(RM) -f $(STATIC_LIB) 94 | $(RM) -rf build dist 95 | @ echo LD -o $@ 96 | ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS) 97 | 98 | clean: 99 | @- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses 100 | 101 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from utils.dataset import Semantic_KITTI, KITTI_Object 4 | 5 | PINK = [0xdd / 255, 0x83 / 255, 0xa2 / 255] 6 | BLUE = [0x81 / 255, 0xb0 / 255, 0xbf / 255] 7 | YELLOW = [0xf3 / 255, 0xdb / 255, 0x74 / 255] 8 | CYAN = [0x2b / 255, 0xda / 255, 0xc0 / 255] 9 | 10 | def plot_point_cloud(pc): 11 | fig = plt.figure(figsize=(8, 8)) 12 | ax = fig.add_subplot(111, projection='3d') 13 | ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2]) 14 | # ax.set_xlim3d(-1, 1) 15 | # ax.set_ylim3d(-1, 1) 16 | # ax.set_zlim3d(-1, 1) 17 | return fig.show() 18 | 19 | 20 | def show_model(pc, index=0, save_img=False, show_axis=True): 21 | """ 22 | Load the generated result from npy file and show it in plots. 23 | Args: 24 | pc: the list of npy files which contain point cloud result. (Each npy file contains multiple PC model) 25 | for AE can take 2 files for output and reference; for GEN take 1 file for output results. 26 | index: show the specific results the group of point cloud result. 
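        save_img: if True, also write the figure to 'output.png' at 300 dpi.
        show_axis: if False, hide the 3D axes of every subplot.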
27 | """ 28 | fig = plt.figure(figsize=(8, 8)) 29 | for n_pc in range(0, len(pc)): 30 | ax = fig.add_subplot(1, len(pc), n_pc+1, projection='3d') 31 | models = np.load(pc[n_pc]) 32 | print('Input file contains ' + str(len(models)) + ' models') 33 | print(len(models[0])) 34 | print(len(models)) 35 | # print(len(models[1])) 36 | ax.scatter(models[index][:, 0], models[index][:, 1], models[index][:, 2]) 37 | if n_pc == 0: 38 | ax.set_title('Output', fontsize=14) 39 | else: 40 | ax.set_title('Reference', fontsize=14) 41 | if not show_axis: 42 | ax.axis('off') 43 | if save_img: 44 | fig.savefig('output.png', dpi=300) 45 | fig.show() 46 | 47 | def eval_model_normal(x, t, index=0, show_img=True, save_img=True, show_axis=True, save_location="", second_part=None, color=BLUE, color2=BLUE): 48 | """ 49 | Load the generated result from npy file and show it in plots. 50 | Args: 51 | pc: the list of npy files which contain point cloud result. (Each npy file contains multiple PC model) 52 | for AE can take 2 files for output and reference; for GEN take 1 file for output results. 53 | index: show the specific results the group of point cloud result. 54 | """ 55 | fig = plt.figure(figsize=(8, 8)) 56 | 57 | ax = fig.add_subplot(1, 1, 1, projection='3d') 58 | models = x 59 | # ax.set_title('x_t', fontsize=14) 60 | ax.scatter(models[index][:, 0], models[index][:, 1], models[index][:, 2], color=color, marker='o', alpha=1.0) 61 | if second_part is not None: 62 | ax.scatter(second_part[index][:, 0], second_part[index][:, 1], second_part[index][:, 2], color=color2, marker='o', alpha=1.0) 63 | if not show_axis: 64 | ax.axis('off') 65 | if save_img: 66 | fig.savefig(save_location+'/{i}.png'.format(i=t), dpi=100) 67 | if show_img: 68 | fig.show() 69 | 70 | def show_forward_diff(pc): 71 | fig = plt.figure(figsize=(20,8)) 72 | print(len(pc)) 73 | for i in range(0, len(pc)): 74 | if i % 10 == 0: 75 | ax = fig.add_subplot(1, 55, i+1, projection='3d') 76 | models = np.load(pc) 77 | ax.scatter(models[i][:, 0], models[i][:, 1], models[i][:, 2]) 78 | fig.show() 79 | 80 | if __name__ == '__main__': 81 | # test_KITTI = Semantic_KITTI( 82 | # datapath='../dataset/SEMANTIC_KITTI_DIR', 83 | # subset='test', 84 | # norm=True, 85 | # npoints=2048 86 | # ) 87 | # print(test_KITTI.__len__()) 88 | # plot_point_cloud(test_KITTI.__getitem__(4)['coord']) 89 | # show_model(['../results/pretrain_decoder_2024_02_28_17_43/out.npy'], 4) 90 | 91 | test_KITTI = KITTI_Object( 92 | datapath='../dataset/KITTI', 93 | subset='test', 94 | norm=True, 95 | 96 | ) 97 | print(test_KITTI.__len__()) 98 | plot_point_cloud(test_KITTI.__getitem__(90)['hr']) 99 | # show_model(['./results/new_start_2023_10_09_15_17/out.npy'], 21) 100 | # show_forward_diff('./results/new_stable_diffusion_2023_08_10_21_38/out.npy') 101 | 102 | # show_model(['./results/Maksed_noise_70/out.npy', 103 | # './results/baseline/ref.npy'], 100) -------------------------------------------------------------------------------- /model/Decoder_Component.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from timm.models.layers import DropPath 3 | 4 | # Ref: https://github.com/Pang-Yatian/Point-MAE/blob/main/models/Point_MAE.py#L82 5 | class Mlp(nn.Module): 6 | """ 7 | The feedforward in Transformer blockers. 
8 | """ 9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 10 | super().__init__() 11 | out_features = out_features or in_features 12 | hidden_features = hidden_features or in_features 13 | self.fc1 = nn.Linear(in_features, hidden_features) 14 | self.act = act_layer() 15 | self.fc2 = nn.Linear(hidden_features, out_features) 16 | self.drop = nn.Dropout(drop) 17 | 18 | def forward(self, x): 19 | x = self.fc1(x) 20 | x = self.act(x) 21 | x = self.drop(x) 22 | x = self.fc2(x) 23 | x = self.drop(x) 24 | return x 25 | 26 | class Attention(nn.Module): 27 | """ 28 | The multi-head attention in Transformer blockers. 29 | """ 30 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 31 | super().__init__() 32 | self.num_heads = num_heads 33 | head_dim = dim // num_heads 34 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 35 | self.scale = qk_scale or head_dim ** -0.5 36 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 37 | self.attn_drop = nn.Dropout(attn_drop) 38 | self.proj = nn.Linear(dim, dim) 39 | self.proj_drop = nn.Dropout(proj_drop) 40 | 41 | def forward(self, x): 42 | B, N, C = x.shape 43 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 44 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 45 | 46 | attn = (q @ k.transpose(-2, -1)) * self.scale 47 | attn = attn.softmax(dim=-1) 48 | attn = self.attn_drop(attn) 49 | 50 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 51 | x = self.proj(x) 52 | x = self.proj_drop(x) 53 | return x 54 | 55 | 56 | class Block(nn.Module): 57 | """ 58 | The Transformer Block in the Decoder module. 59 | """ 60 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 61 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 62 | super().__init__() 63 | self.norm1 = norm_layer(dim) 64 | 65 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 66 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 67 | self.norm2 = norm_layer(dim) 68 | mlp_hidden_dim = int(dim * mlp_ratio) 69 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 70 | 71 | self.attn = Attention( 72 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 73 | 74 | def forward(self, x): 75 | x = x + self.drop_path(self.attn(self.norm1(x))) 76 | x = x + self.drop_path(self.mlp(self.norm2(x))) 77 | return x 78 | 79 | 80 | class Transformer(nn.Module): 81 | """ 82 | The Transformer for the Decoder module. 
83 | 84 | Inputs: 85 | Latent: [B, V+M, L] 86 | Position Embedding: [B, V+M, L] 87 | Time Embedding: [B, V+M, L] 88 | 89 | B: Batch size, 90 | V: Visible patches size 91 | M: Masked patches size 92 | L: Latent size 93 | """ 94 | def __init__(self, embed_dim=384, depth=4, num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None, 95 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm): 96 | super().__init__() 97 | self.blocks = nn.ModuleList([ 98 | Block( 99 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 100 | drop=drop_rate, attn_drop=attn_drop_rate, 101 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate 102 | ) 103 | for i in range(depth)]) 104 | self.norm = norm_layer(embed_dim) 105 | self.head = nn.Identity() 106 | 107 | self.apply(self._init_weights) 108 | 109 | def _init_weights(self, m): 110 | if isinstance(m, nn.Linear): 111 | nn.init.xavier_uniform_(m.weight) 112 | if isinstance(m, nn.Linear) and m.bias is not None: 113 | nn.init.constant_(m.bias, 0) 114 | elif isinstance(m, nn.LayerNorm): 115 | nn.init.constant_(m.bias, 0) 116 | nn.init.constant_(m.weight, 1.0) 117 | 118 | def forward(self, x, pos, num_of_group, ts): 119 | for _, block in enumerate(self.blocks): 120 | x = block(x + pos + ts) 121 | x = self.head(self.norm(x[:, -num_of_group:])) 122 | return x 123 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/src/nndistance.cu: -------------------------------------------------------------------------------- 1 | 2 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 3 | const int batch=512; 4 | __shared__ float buf[batch*3]; 5 | for (int i=blockIdx.x;ibest){ 117 | result[(i*n+j)]=best; 118 | result_i[(i*n+j)]=best_i; 119 | } 120 | } 121 | __syncthreads(); 122 | } 123 | } 124 | } 125 | void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 126 | NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i); 127 | NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); 128 | } 129 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 130 | for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); 153 | NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); 154 | } 155 | 156 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | logger_initialized = {} 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'): 7 | """Get root logger and add a keyword filter to it. 8 | The logger will be initialized if it has not been initialized. By default a 9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 10 | also be added. The name of the root logger is the top-level package name, 11 | e.g., "mmdet3d". 12 | Args: 13 | log_file (str, optional): File path of log. Defaults to None. 14 | log_level (int, optional): The level of logger. 15 | Defaults to logging.INFO. 16 | name (str, optional): The name of the root logger, also used as a 17 | filter keyword. 
Defaults to 'mmdet3d'. 18 | Returns: 19 | :obj:`logging.Logger`: The obtained logger 20 | """ 21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 22 | # add a logging filter 23 | logging_filter = logging.Filter(name) 24 | logging_filter.filter = lambda record: record.find(name) != -1 25 | 26 | return logger 27 | 28 | 29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 30 | """Initialize and get a logger by name. 31 | If the logger has not been initialized, this method will initialize the 32 | logger by adding one or two handlers, otherwise the initialized logger will 33 | be directly returned. During initialization, a StreamHandler will always be 34 | added. If `log_file` is specified and the process rank is 0, a FileHandler 35 | will also be added. 36 | Args: 37 | name (str): Logger name. 38 | log_file (str | None): The log filename. If specified, a FileHandler 39 | will be added to the logger. 40 | log_level (int): The logger level. Note that only the process of 41 | rank 0 is affected, and other processes will set the level to 42 | "Error" thus be silent most of the time. 43 | file_mode (str): The file mode used in opening log file. 44 | Defaults to 'w'. 45 | Returns: 46 | logging.Logger: The expected logger. 47 | """ 48 | logger = logging.getLogger(name) 49 | if name in logger_initialized: 50 | return logger 51 | # handle hierarchical names 52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 53 | # initialization since it is a child of "a". 54 | for logger_name in logger_initialized: 55 | if name.startswith(logger_name): 56 | return logger 57 | 58 | # handle duplicate logs to the console 59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 60 | # to the root logger. As logger.propagate is True by default, this root 61 | # level handler causes logging messages from rank>0 processes to 62 | # unexpectedly show up on the console, creating much unwanted clutter. 63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 64 | # at the ERROR level. 65 | for handler in logger.root.handlers: 66 | if type(handler) is logging.StreamHandler: 67 | handler.setLevel(logging.ERROR) 68 | 69 | stream_handler = logging.StreamHandler() 70 | handlers = [stream_handler] 71 | 72 | if dist.is_available() and dist.is_initialized(): 73 | rank = dist.get_rank() 74 | else: 75 | rank = 0 76 | 77 | # only rank 0 will add a FileHandler 78 | if rank == 0 and log_file is not None: 79 | # Here, the default behaviour of the official logger is 'a'. Thus, we 80 | # provide an interface to change the file mode to the default 81 | # behaviour. 82 | file_handler = logging.FileHandler(log_file, file_mode) 83 | handlers.append(file_handler) 84 | 85 | formatter = logging.Formatter( 86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 87 | for handler in handlers: 88 | handler.setFormatter(formatter) 89 | handler.setLevel(log_level) 90 | logger.addHandler(handler) 91 | 92 | if rank == 0: 93 | logger.setLevel(log_level) 94 | else: 95 | logger.setLevel(logging.ERROR) 96 | 97 | logger_initialized[name] = True 98 | 99 | 100 | return logger 101 | 102 | 103 | def print_log(msg, logger=None, level=logging.INFO): 104 | """Print a log message. 105 | Args: 106 | msg (str): The message to be logged. 107 | logger (logging.Logger | str | None): The logger to be used. 108 | Some special loggers are: 109 | - "silent": no message will be printed. 110 | - other str: the logger obtained with `get_root_logger(logger)`. 
111 | - None: The `print()` method will be used to print log messages. 112 | level (int): Logging level. Only available when `logger` is a Logger 113 | object or "root". 114 | """ 115 | if logger is None: 116 | print(msg) 117 | elif isinstance(logger, logging.Logger): 118 | logger.log(level, msg) 119 | elif logger == 'silent': 120 | pass 121 | elif isinstance(logger, str): 122 | _logger = get_logger(logger) 123 | _logger.log(level, msg) 124 | else: 125 | raise TypeError( 126 | 'logger should be either a logging.Logger object, str, ' 127 | f'"silent" or None, but got {type(logger)}') -------------------------------------------------------------------------------- /train_encoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | 4 | from torch import optim 5 | from torch.utils.data import DataLoader 6 | from utils.dataset import * 7 | from utils.logger import * 8 | from model.Encoder import * 9 | 10 | 11 | def get_data_iterator(iterable): 12 | """Allows training with DataLoaders in a single infinite loop: 13 | for i, data in enumerate(inf_generator(train_loader)): 14 | """ 15 | iterator = iterable.__iter__() 16 | while True: 17 | try: 18 | yield iterator.__next__() 19 | except StopIteration: 20 | iterator = iterable.__iter__() 21 | 22 | 23 | parser = argparse.ArgumentParser() 24 | # Experiment setting 25 | parser.add_argument('--batch_size', type=int, default=4) 26 | parser.add_argument('--val_batch_size', type=int, default=1) 27 | parser.add_argument('--device', type=str, default='cuda') # mps for mac 28 | parser.add_argument('--save_dir', type=str, default='./results') 29 | parser.add_argument('--log', type=bool, default=False) 30 | 31 | # Grouping setting 32 | parser.add_argument('--mask_type', type=str, default='rand') 33 | parser.add_argument('--mask_ratio', type=float, default=0.75) 34 | parser.add_argument('--group_size', type=int, default=32) 35 | parser.add_argument('--num_group', type=int, default=64) 36 | parser.add_argument('--num_points', type=int, default=2048) 37 | parser.add_argument('--num_output', type=int, default=8192) 38 | 39 | # Transformer setting 40 | parser.add_argument('--trans_dim', type=int, default=384) 41 | parser.add_argument('--drop_path_rate', type=float, default=0.1) 42 | 43 | # Encoder setting 44 | parser.add_argument('--encoder_depth', type=int, default=12) 45 | parser.add_argument('--encoder_num_heads', type=int, default=6) 46 | parser.add_argument('--encoder_dims', type=int, default=384) 47 | parser.add_argument('--loss', type=str, default='cdl2') 48 | 49 | # sche / optim 50 | parser.add_argument('--learning_rate', type=float, default=0.001) 51 | parser.add_argument('--weight_decay', type=float, default=0.05) 52 | parser.add_argument('--eta_min', type=float, default=0.000001) 53 | parser.add_argument('--t_max', type=float, default=200) 54 | 55 | args = parser.parse_args() 56 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M") 57 | save_dir = os.path.join(args.save_dir, 'encocder_8192_width_{gSize}'.format(gSize=args.encoder_dims)) 58 | 59 | if args.log: 60 | if not os.path.exists(save_dir): 61 | os.makedirs(save_dir) 62 | log = logging.getLogger() 63 | log.setLevel(logging.INFO) 64 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s') 65 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt')) 66 | log_file.setLevel(logging.INFO) 67 | log_file.setFormatter(formatter) 68 | log.addHandler(log_file) 69 | 70 | 71 | 72 | if args.log: 
73 | log.info('loading dataset') 74 | print('loading dataset') 75 | 76 | train_dset = ShapeNet( 77 | data_path='dataset/ShapeNet55/ShapeNet-55', 78 | pc_path='dataset/ShapeNet55/shapenet_pc', 79 | subset='train', 80 | n_points=2048, 81 | downsample=True 82 | ) 83 | val_dset = ShapeNet( 84 | data_path='dataset/ShapeNet55/ShapeNet-55', 85 | pc_path='dataset/ShapeNet55/shapenet_pc', 86 | subset='test', 87 | n_points=2048, 88 | downsample=True 89 | ) 90 | 91 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True) 92 | trn_loader = DataLoader(train_dset, batch_size=args.batch_size, pin_memory=True) 93 | if args.log: 94 | log.info('Training decoder for stable diffusion.') 95 | log.info('config:') 96 | log.info(args) 97 | log.info('dataset loaded') 98 | print('dataset loaded') 99 | 100 | 101 | model = Encoder_Module(args).to(args.device) 102 | 103 | 104 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) 105 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=args.eta_min, T_max=args.t_max) 106 | 107 | 108 | def train(i, batch, epoch): 109 | if i == 0 and epoch > 1: 110 | print('Saving checkpoint at epoch {epoch}'.format(epoch=epoch)) 111 | save_model = {'args': args, 'model': model.state_dict()} 112 | if args.log: 113 | torch.save(save_model, os.path.join(save_dir, 'autoencoder_diff.pt')) 114 | 115 | x = batch['lr'].to(args.device) 116 | optimizer.zero_grad() 117 | model.train() 118 | loss = model(x, batch['hr'].to(args.device)) 119 | loss.backward() 120 | optimizer.step() 121 | if args.log and i == 0: 122 | log.info('epoch: {epoch}, iteration: {i}, loss: {loss}'.format(i=i, epoch=epoch, loss=loss)) 123 | print('epoch: {epoch}, iteration: {i}, loss: {loss}'.format(i=i, epoch=epoch, loss=loss)) 124 | 125 | 126 | def validate(): 127 | all_recons = [] 128 | for i, batch in enumerate(val_loader): 129 | print('sampling model {i}'.format(i=i)) 130 | ref = batch['lr'].to(args.device) 131 | if i > 200: 132 | break 133 | with torch.no_grad(): 134 | model.eval() 135 | x_vis, x_masked, mask, center, vis_pc, msk_pc = model.encode(ref, masked=False) 136 | recons = model.evaluate(x_vis, x_masked) 137 | all_recons.append(recons) 138 | all_recons = torch.cat(all_recons, dim=0) 139 | np.save(os.path.join(save_dir, 'out.npy'), all_recons.cpu().numpy()) 140 | 141 | 142 | try: 143 | n_it = 50 144 | epoch = 1 145 | while epoch <= n_it: 146 | model.train() 147 | for i, pc in enumerate(trn_loader): 148 | train(i, pc, epoch) 149 | scheduler.step() 150 | 151 | if epoch == n_it: 152 | if args.log: 153 | saved_file = {'args': args, 'model': model.state_dict()} 154 | torch.save(saved_file, os.path.join(save_dir, 'autoencoder_diff.pt')) 155 | validate() 156 | epoch += 1 157 | except Exception as e: 158 | log.error(e) 159 | 160 | -------------------------------------------------------------------------------- /eval_compression.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | from torch.utils.data import DataLoader 4 | from utils.dataset import * 5 | from utils.logger import * 6 | from model.EncoderCompress import * 7 | from model.DiffusionCop import * 8 | 9 | def get_data_iterator(iterable): 10 | """Allows training with DataLoaders in a single infinite loop: 11 | for i, data in enumerate(inf_generator(train_loader)): 12 | """ 13 | iterator = iterable.__iter__() 14 | while True: 15 | try: 16 | yield iterator.__next__() 17 | except 
StopIteration: 18 | iterator = iterable.__iter__() 19 | 20 | parser = argparse.ArgumentParser() 21 | # Experiment setting 22 | parser.add_argument('--val_batch_size', type=int, default=1) 23 | parser.add_argument('--device', type=str, default='cuda') # mps for mac 24 | parser.add_argument('--save_dir', type=str, default='./results') 25 | parser.add_argument('--log', type=bool, default=True) 26 | 27 | # Grouping setting 28 | parser.add_argument('--mask_type', type=str, default='rand') 29 | parser.add_argument('--mask_ratio', type=float, default=0.75) 30 | parser.add_argument('--group_size', type=int, default=32) # points in each group 31 | parser.add_argument('--num_group', type=int, default=64) # number of group 32 | parser.add_argument('--num_points', type=int, default=2048) 33 | parser.add_argument('--num_output', type=int, default=8192) 34 | 35 | # Transformer setting 36 | parser.add_argument('--trans_dim', type=int, default=384) 37 | parser.add_argument('--drop_path_rate', type=float, default=0.1) 38 | 39 | # Encoder setting 40 | parser.add_argument('--encoder_depth', type=int, default=12) 41 | parser.add_argument('--encoder_num_heads', type=int, default=6) 42 | parser.add_argument('--loss', type=str, default='cdl2') 43 | 44 | # Decoder setting 45 | parser.add_argument('--decoder_depth', type=int, default=4) 46 | parser.add_argument('--decoder_num_heads', type=int, default=4) 47 | 48 | # diffusion 49 | parser.add_argument('--num_steps', type=int, default=200) 50 | parser.add_argument('--beta_1', type=float, default=1e-4) 51 | parser.add_argument('--beta_T', type=float, default=0.05) 52 | parser.add_argument('--sched_mode', type=str, default='linear') 53 | 54 | 55 | args = parser.parse_args() 56 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M") 57 | save_dir = os.path.join(args.save_dir, 'eval_new_pe'.format(date=time_now)) 58 | 59 | 60 | check_point_dir = os.path.join('./pretrain_model/compress/encoder.pt') 61 | check_point = torch.load(check_point_dir)['model'] 62 | encoder = Encoder_Module(args).to(args.device) 63 | encoder.load_state_dict(check_point) 64 | 65 | 66 | if args.log: 67 | if not os.path.exists(save_dir): 68 | os.makedirs(save_dir) 69 | log = logging.getLogger() 70 | log.setLevel(logging.INFO) 71 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s') 72 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt')) 73 | log_file.setLevel(logging.INFO) 74 | log_file.setFormatter(formatter) 75 | log.addHandler(log_file) 76 | 77 | if args.log: 78 | log.info('loading dataset') 79 | print('loading dataset') 80 | 81 | 82 | test_dset_MN = ModelNet( 83 | root='dataset/ModelNet', 84 | number_pts=8192, 85 | downsampling=2048, 86 | use_normal=False, 87 | cats=40, 88 | subset='test' 89 | ) 90 | val_loader_MN = DataLoader(test_dset_MN, batch_size=args.val_batch_size, pin_memory=True, shuffle=True) 91 | 92 | val_dset = ShapeNet( 93 | data_path='dataset/ShapeNet55/ShapeNet-55', 94 | pc_path='dataset/ShapeNet55/shapenet_pc', 95 | subset='test', 96 | n_points=2048, 97 | downsample=True 98 | ) 99 | 100 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True) 101 | if args.log: 102 | log.info('Training Stable Diffusion') 103 | log.info('config:') 104 | log.info(args) 105 | log.info('dataset loaded') 106 | print('dataset loaded') 107 | 108 | print('loading model') 109 | 110 | PINK = [0xdd / 255, 0x83 / 255, 0xa2 / 255] 111 | BLUE = [0x81 / 255, 0xb0 / 255, 0xbf / 255] 112 | YELLOW = [0xf3 / 255, 0xdb / 255, 0x74 / 255] 113 | CYAN = 
[0x2b / 255, 0xda / 255, 0xc0 / 255] 114 | 115 | def calculate_metric_all(size=-1): 116 | # mean_cd = [] 117 | diff_check_point_dir = os.path.join('./pretrain_model/compress/decoder.pt') 118 | diff_check_point = torch.load(diff_check_point_dir)['model'] 119 | model = Diff_Point_MAE(args).to(args.device) 120 | model = torch.nn.DataParallel(model, device_ids=[0]) 121 | model.load_state_dict(diff_check_point) 122 | 123 | all_sample = [] 124 | all_ref = [] 125 | all_hd = [] 126 | all_cd = [] 127 | for i, batch in enumerate(val_loader): 128 | if i == size: 129 | break 130 | with torch.no_grad(): 131 | ref = batch['lr'].to(args.device) 132 | model.eval() 133 | compress = Compress(args).to(args.device) 134 | vis, center, mask = compress.compress(ref) 135 | 136 | x_vis, vis_pc = encoder.encode(vis, center, mask) 137 | recons = model.module.sampling(x_vis, mask, center) 138 | recons = torch.cat([vis_pc, recons], dim=1) 139 | 140 | all_sample.append(recons) 141 | all_ref.append(ref) 142 | hd = averaged_hausdorff_distance(recons.squeeze(), ref.squeeze()) 143 | all_hd.append(hd) 144 | 145 | cd = chamfer_distance_l2(recons, ref) 146 | all_cd.append(cd) 147 | print("evaluating model: {i}, CD: {cd}".format(i=i, cd=cd)) 148 | 149 | mean_hd = sum(all_hd) / len(all_hd) 150 | mean_cd = sum(all_cd) / len(all_cd) 151 | sample = torch.cat(all_sample, dim=0) 152 | refpc = torch.cat(all_ref, dim=0) 153 | jsd = jsd_between_point_cloud_sets(sample, refpc) 154 | 155 | print("MMD CD: {mmd}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=mean_cd, jsd=jsd, hd=mean_hd)) 156 | log.info("MMD CD: {mmd}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=mean_cd, jsd=jsd, hd=mean_hd)) 157 | 158 | 159 | calculate_metric_all() 160 | -------------------------------------------------------------------------------- /eval_cop.py: -------------------------------------------------------------------------------- 1 | from utils.dataset import * 2 | from metrics.evaluation_metrics import averaged_hausdorff_distance, jsd_between_point_cloud_sets 3 | from model.Completion import * 4 | from model.EncoderCop import * 5 | import argparse 6 | from datetime import datetime 7 | 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from utils.logger import * 11 | 12 | def get_data_iterator(iterable): 13 | """Allows training with DataLoaders in a single infinite loop: 14 | for i, data in enumerate(inf_generator(train_loader)): 15 | """ 16 | iterator = iterable.__iter__() 17 | while True: 18 | try: 19 | yield iterator.__next__() 20 | except StopIteration: 21 | iterator = iterable.__iter__() 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | # Experiment setting 26 | parser.add_argument('--val_batch_size', type=int, default=1) 27 | parser.add_argument('--device', type=str, default='cuda') # mps for mac 28 | parser.add_argument('--save_dir', type=str, default='./results') 29 | parser.add_argument('--log', type=bool, default=True) 30 | 31 | # Grouping setting 32 | parser.add_argument('--mask_type', type=str, default='rand') 33 | parser.add_argument('--mask_ratio', type=float, default=0.75) 34 | parser.add_argument('--group_size', type=int, default=32) # points in each group 35 | parser.add_argument('--num_group', type=int, default=64) # number of group 36 | parser.add_argument('--num_points', type=int, default=2048) 37 | parser.add_argument('--num_output', type=int, default=8192) 38 | 39 | # Transformer setting 40 | parser.add_argument('--trans_dim', type=int, default=384) 41 | parser.add_argument('--drop_path_rate', type=float, default=0.1) 42 | 43 | # 
Encoder setting 44 | parser.add_argument('--encoder_depth', type=int, default=12) 45 | parser.add_argument('--encoder_num_heads', type=int, default=6) 46 | parser.add_argument('--loss', type=str, default='cdl2') 47 | 48 | # Decoder setting 49 | parser.add_argument('--decoder_depth', type=int, default=4) 50 | parser.add_argument('--decoder_num_heads', type=int, default=4) 51 | 52 | # diffusion 53 | parser.add_argument('--num_steps', type=int, default=200) 54 | parser.add_argument('--beta_1', type=float, default=1e-4) 55 | parser.add_argument('--beta_T', type=float, default=0.05) 56 | parser.add_argument('--sched_mode', type=str, default='linear') 57 | 58 | args = parser.parse_args() 59 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M") 60 | save_dir = os.path.join(args.save_dir, 'new_start_eval_{date}'.format(date=time_now)) 61 | 62 | if args.log: 63 | if not os.path.exists(save_dir): 64 | os.makedirs(save_dir) 65 | log = logging.getLogger() 66 | log.setLevel(logging.INFO) 67 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s') 68 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt')) 69 | log_file.setLevel(logging.INFO) 70 | log_file.setFormatter(formatter) 71 | log.addHandler(log_file) 72 | 73 | if args.log: 74 | log.info('loading dataset') 75 | print('loading dataset') 76 | 77 | 78 | test_dset_MN = ModelNet( 79 | root='dataset/ModelNet', 80 | number_pts=8192, 81 | downsampling=2048, 82 | use_normal=False, 83 | cats=40, 84 | subset='test' 85 | ) 86 | val_loader_MN = DataLoader(test_dset_MN, batch_size=args.val_batch_size, pin_memory=True, shuffle=True) 87 | 88 | val_dset = ShapeNet( 89 | data_path='dataset/ShapeNet55/ShapeNet-55', 90 | pc_path='dataset/ShapeNet55/shapenet_pc', 91 | subset='test', 92 | n_points=2048, 93 | downsample=True 94 | ) 95 | 96 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True) 97 | if args.log: 98 | log.info('Training Stable Diffusion') 99 | log.info('config:') 100 | log.info(args) 101 | log.info('dataset loaded') 102 | print('dataset loaded') 103 | 104 | print('loading model') 105 | 106 | def calculate_metric_all(size=-1, encoder=None): 107 | # mean_cd = [] 108 | diff_check_point_dir = os.path.join('./pretrain_model/completion/decoder.pt') 109 | model = Diff_Point_MAE(args, encoder).to(args.device) 110 | model = torch.nn.DataParallel(model, device_ids=[0]) 111 | model.module.load_model_from_ckpt(diff_check_point_dir) 112 | 113 | 114 | all_sample = [] 115 | all_ref = [] 116 | all_hd = [] 117 | all_cd = [] 118 | for i, batch in enumerate(val_loader_MN): 119 | if i == size: 120 | break 121 | with torch.no_grad(): 122 | ref = batch['model'].to(args.device) 123 | model.eval() 124 | x_vis, x_masked, mask, center, vis_pc, msk_pc = encoder.encode(ref) 125 | 126 | recons = model.module.sampling(x_vis) 127 | 128 | recons = torch.cat([vis_pc, recons], dim=1) 129 | all_sample.append(recons) 130 | all_ref.append(ref) 131 | hd = averaged_hausdorff_distance(recons.squeeze(), ref.squeeze()) 132 | all_hd.append(hd) 133 | 134 | cd = chamfer_distance_l2(recons, ref) 135 | all_cd.append(cd) 136 | print("evaluating model: {i}, CD: {cd}".format(i=i, cd=cd)) 137 | 138 | mean_hd = sum(all_hd) / len(all_hd) 139 | mean_cd = sum(all_cd) / len(all_cd) 140 | sample = torch.cat(all_sample, dim=0) 141 | refpc = torch.cat(all_ref, dim=0) 142 | jsd = jsd_between_point_cloud_sets(sample, refpc) 143 | print("MMD CD: {mmd}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=mean_cd, jsd=jsd, hd=mean_hd)) 144 | log.info("MMD CD: 
{mmd}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=mean_cd, jsd=jsd, hd=mean_hd)) 145 | 146 | 147 | # calculate_metric_all() 148 | 149 | def auto_run(): 150 | lose_ratio = [0.75] 151 | for lr in lose_ratio: 152 | args.mask_ratio = lr 153 | check_point_dir = os.path.join('./pretrain_model/completion/encoder.pt') 154 | # check_point_dir = os.path.join(args.save_dir, '{saved_src}/encoder.pt' 155 | # .format(saved_src='new_start_eval_2023_11_12_02_21')) 156 | check_point = torch.load(check_point_dir)['model'] 157 | encoder = Point_MAE(args).to(args.device) 158 | encoder.load_state_dict(check_point) 159 | print('model loaded') 160 | 161 | 162 | calculate_metric_all(encoder=encoder) 163 | 164 | auto_run() 165 | -------------------------------------------------------------------------------- /metrics/pytorch_structural_losses/src/structural_loss.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "src/approxmatch.cuh" 5 | #include "src/nndistance.cuh" 6 | 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | /* 15 | input: 16 | set1 : batch_size * #dataset_points * 3 17 | set2 : batch_size * #query_points * 3 18 | returns: 19 | match : batch_size * #query_points * #dataset_points 20 | */ 21 | // temp: TensorShape{b,(n+m)*2} 22 | std::vector ApproxMatch(at::Tensor set_d, at::Tensor set_q) { 23 | //std::cout << "[ApproxMatch] Called." << std::endl; 24 | int64_t batch_size = set_d.size(0); 25 | int64_t n_dataset_points = set_d.size(1); // n 26 | int64_t n_query_points = set_q.size(1); // m 27 | //std::cout << "[ApproxMatch] batch_size:" << batch_size << std::endl; 28 | at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 29 | at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 30 | CHECK_INPUT(set_d); 31 | CHECK_INPUT(set_q); 32 | CHECK_INPUT(match); 33 | CHECK_INPUT(temp); 34 | 35 | approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),temp.data(), at::cuda::getCurrentCUDAStream()); 36 | return {match, temp}; 37 | } 38 | 39 | at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { 40 | //std::cout << "[MatchCost] Called." << std::endl; 41 | int64_t batch_size = set_d.size(0); 42 | int64_t n_dataset_points = set_d.size(1); // n 43 | int64_t n_query_points = set_q.size(1); // m 44 | //std::cout << "[MatchCost] batch_size:" << batch_size << std::endl; 45 | at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 46 | CHECK_INPUT(set_d); 47 | CHECK_INPUT(set_q); 48 | CHECK_INPUT(match); 49 | CHECK_INPUT(out); 50 | matchcost(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),out.data(),at::cuda::getCurrentCUDAStream()); 51 | return out; 52 | } 53 | 54 | std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { 55 | //std::cout << "[MatchCostGrad] Called." 
<< std::endl; 56 | int64_t batch_size = set_d.size(0); 57 | int64_t n_dataset_points = set_d.size(1); // n 58 | int64_t n_query_points = set_q.size(1); // m 59 | //std::cout << "[MatchCostGrad] batch_size:" << batch_size << std::endl; 60 | at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 61 | at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 62 | CHECK_INPUT(set_d); 63 | CHECK_INPUT(set_q); 64 | CHECK_INPUT(match); 65 | CHECK_INPUT(grad1); 66 | CHECK_INPUT(grad2); 67 | matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),grad1.data(),grad2.data(),at::cuda::getCurrentCUDAStream()); 68 | return {grad1, grad2}; 69 | } 70 | 71 | 72 | /* 73 | input: 74 | set_d : batch_size * #dataset_points * 3 75 | set_q : batch_size * #query_points * 3 76 | returns: 77 | dist1, idx1 : batch_size * #dataset_points 78 | dist2, idx2 : batch_size * #query_points 79 | */ 80 | std::vector NNDistance(at::Tensor set_d, at::Tensor set_q) { 81 | //std::cout << "[NNDistance] Called." << std::endl; 82 | int64_t batch_size = set_d.size(0); 83 | int64_t n_dataset_points = set_d.size(1); // n 84 | int64_t n_query_points = set_q.size(1); // m 85 | //std::cout << "[NNDistance] batch_size:" << batch_size << std::endl; 86 | at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 87 | at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device())); 88 | at::Tensor dist2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 89 | at::Tensor idx2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device())); 90 | CHECK_INPUT(set_d); 91 | CHECK_INPUT(set_q); 92 | CHECK_INPUT(dist1); 93 | CHECK_INPUT(idx1); 94 | CHECK_INPUT(dist2); 95 | CHECK_INPUT(idx2); 96 | // void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); 97 | nndistance(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(),dist1.data(),idx1.data(),dist2.data(),idx2.data(), at::cuda::getCurrentCUDAStream()); 98 | return {dist1, idx1, dist2, idx2}; 99 | } 100 | 101 | std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) { 102 | //std::cout << "[NNDistanceGrad] Called." 
<< std::endl; 103 | int64_t batch_size = set_d.size(0); 104 | int64_t n_dataset_points = set_d.size(1); // n 105 | int64_t n_query_points = set_q.size(1); // m 106 | //std::cout << "[NNDistanceGrad] batch_size:" << batch_size << std::endl; 107 | at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 108 | at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); 109 | CHECK_INPUT(set_d); 110 | CHECK_INPUT(set_q); 111 | CHECK_INPUT(idx1); 112 | CHECK_INPUT(idx2); 113 | CHECK_INPUT(grad_dist1); 114 | CHECK_INPUT(grad_dist2); 115 | CHECK_INPUT(grad1); 116 | CHECK_INPUT(grad2); 117 | //void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); 118 | nndistancegrad(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(), 119 | grad_dist1.data(),idx1.data(), 120 | grad_dist2.data(),idx2.data(), 121 | grad1.data(),grad2.data(), 122 | at::cuda::getCurrentCUDAStream()); 123 | return {grad1, grad2}; 124 | } 125 | 126 | -------------------------------------------------------------------------------- /model/DiffusionCop.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from metrics.evaluation_metrics import * 3 | from model.Diffusion import VarianceSchedule, TimeEmbedding 4 | from model.Decoder_Component import * 5 | 6 | """ 7 | Prediction on the masked patches only 8 | w/ diffusion process 9 | Use FC layer as the mask token convertor 10 | """ 11 | 12 | 13 | class Diff_Point_MAE(nn.Module): 14 | def __init__(self, config): 15 | super().__init__() 16 | self.config = config 17 | self.trans_dim = config.trans_dim 18 | self.group_size = config.group_size 19 | self.num_group = config.num_group 20 | self.num_output = config.num_output 21 | self.num_channel = 3 22 | self.drop_path_rate = config.drop_path_rate 23 | self.mask_token = nn.Conv1d((self.num_channel * 2048) // self.num_group, self.trans_dim, 1) 24 | self.decoder_pos_embed = nn.Sequential( 25 | nn.Linear(3, 128), 26 | nn.GELU(), 27 | nn.Linear(128, self.trans_dim) 28 | ) 29 | 30 | self.decoder_depth = config.decoder_depth 31 | self.decoder_num_heads = config.decoder_num_heads 32 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.decoder_depth)] 33 | self.MAE_decoder = Transformer( 34 | embed_dim=self.trans_dim, 35 | depth=self.decoder_depth, 36 | drop_path_rate=dpr, 37 | num_heads=self.decoder_num_heads, 38 | ) 39 | 40 | self.loss = config.loss 41 | # loss 42 | self.build_loss_func(self.loss) 43 | self.var = VarianceSchedule( 44 | num_steps=config.num_steps, 45 | beta_1=config.beta_1, 46 | beta_T=config.beta_T, 47 | mode=config.sched_mode 48 | ) 49 | 50 | # prediction head 51 | self.increase_dim = nn.Sequential( 52 | nn.Conv1d(self.trans_dim, (self.num_channel * 2048) // self.num_group, 1) 53 | ) 54 | 55 | self.timestep = config.num_steps 56 | self.beta_1 = config.beta_1 57 | self.beta_T = config.beta_T 58 | 59 | self.betas = self.linear_schedule(timesteps=self.timestep) 60 | 61 | self.alphas = 1.0 - self.betas 62 | self.alpha_bar = torch.cumprod(self.alphas, axis=0) 63 | self.alpha_bar_t_minus_one = F.pad(self.alpha_bar[:-1], (1, 0), value=1.0) 64 | self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas) 65 | self.sqrt_alphas_bar = 
torch.sqrt(self.alpha_bar) 66 | self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alpha_bar) 67 | self.sigma = self.betas * (1.0 - self.alpha_bar_t_minus_one) / (1.0 - self.alpha_bar) 68 | self.sqrt_alphas = torch.sqrt(self.alphas) 69 | self.sqrt_alpha_bar_minus_one = torch.sqrt(self.alpha_bar_t_minus_one) 70 | 71 | self.time_emb = nn.Sequential( 72 | TimeEmbedding(self.trans_dim), 73 | nn.Linear(self.trans_dim, self.trans_dim), 74 | nn.ReLU() 75 | ) 76 | 77 | def build_loss_func(self, loss_type): 78 | if loss_type == "cdl1": 79 | self.loss_func = chamfer_distance_l1 80 | elif loss_type == 'cdl2': 81 | self.loss_func = chamfer_distance_l2 82 | elif loss_type == 'mse': 83 | self.loss_func = F.mse_loss 84 | else: 85 | raise NotImplementedError 86 | 87 | def linear_schedule(self, timesteps): 88 | return torch.linspace(self.beta_1, self.beta_T, timesteps) 89 | 90 | def get_index_from_list(self, vals, t, x_shape): 91 | b = t.shape[0] 92 | out = vals.gather(-1, t.cpu()) 93 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).to(t.device) 94 | 95 | def forward_diffusion(self, x_0, t): 96 | noise = torch.randn_like(x_0).to(x_0.device) 97 | sqrt_alphas_cumprod_t = self.get_index_from_list(self.sqrt_alphas_bar, t, x_0.shape).to(x_0.device) 98 | sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(self.sqrt_one_minus_alphas_bar, t, x_0.shape).to( 99 | x_0.device) 100 | 101 | return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise 102 | 103 | def forward(self, x_0, t, x_vis, mask, center, vis_pc, ori, debug=False): 104 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1) 105 | x_t, noise = self.forward_diffusion(x_0, t) 106 | 107 | B, _, C = x_vis.shape # B VIS C 108 | 109 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C) 110 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C) 111 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1) 112 | _, N, _ = pos_emd_msk.shape 113 | mask_token = self.mask_token(x_t.reshape(B, N, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device) 114 | x_full = torch.cat([x_vis, mask_token], dim=1) 115 | x_rec = self.MAE_decoder(x_full, pos_full, N, ts) 116 | x_rec = self.increase_dim(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) 117 | pc_full = torch.cat([vis_pc, x_rec], dim=1) 118 | if debug: 119 | return x_t, noise, x_rec 120 | else: 121 | return self.loss_func(pc_full, ori) 122 | 123 | def sampling_t(self, x, t, mask, center, x_vis): 124 | center = center.float() 125 | B, _, C = x_vis.shape # B VIS C 126 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1) 127 | betas_t = self.get_index_from_list(self.betas, t, x.shape).to(x_vis.device) 128 | 129 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C) 130 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C) 131 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1) 132 | _, N, _ = pos_emd_msk.shape 133 | mask_token = self.mask_token(x.reshape(B, N, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device) 134 | x_full = torch.cat([x_vis, mask_token], dim=1) 135 | x_rec = self.MAE_decoder(x_full, pos_full, N, ts) 136 | x_rec = self.increase_dim(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) 137 | 138 | alpha_bar_t = self.get_index_from_list(self.alpha_bar, t, x.shape).to(x_vis.device) 139 | alpha_bar_t_minus_one = self.get_index_from_list(self.alpha_bar_t_minus_one, t, x.shape).to(x_vis.device) 140 | sqrt_alpha_t = 
self.get_index_from_list(self.sqrt_alphas, t, x.shape).to(x_vis.device) 141 | sqrt_alphas_bar_t_minus_one = self.get_index_from_list(self.sqrt_alpha_bar_minus_one, t, x.shape).to( 142 | x_vis.device) 143 | 144 | model_mean = (sqrt_alpha_t * (1 - alpha_bar_t_minus_one)) / (1 - alpha_bar_t) * x + ( 145 | sqrt_alphas_bar_t_minus_one * betas_t) / (1 - alpha_bar_t) * x_rec 146 | 147 | sigma_t = self.get_index_from_list(self.sigma, t, x.shape).to(x_vis.device) 148 | 149 | if t == 0: 150 | return model_mean 151 | else: 152 | return model_mean + torch.sqrt(sigma_t) * x_rec 153 | 154 | def sampling(self, x_vis, mask, center, trace=False, noise_patch=None): 155 | B, M, C = x_vis.shape 156 | if noise_patch is None: 157 | noise_patch = torch.randn((B, (self.num_group - M) * self.group_size, 3)).to(x_vis.device) 158 | diffusion_sequence = [] 159 | 160 | for i in range(0, self.timestep)[::-1]: 161 | t = torch.full((1,), i, device=x_vis.device) 162 | noise_patch = self.sampling_t(noise_patch, t, mask, center, x_vis) 163 | if trace: 164 | diffusion_sequence.append(noise_patch.reshape(B, -1, 3)) 165 | 166 | if trace: 167 | return diffusion_sequence 168 | else: 169 | return noise_patch.reshape(B, -1, 3) 170 | -------------------------------------------------------------------------------- /eval_upsampling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path 3 | from datetime import datetime 4 | 5 | import point_cloud_utils as pcu 6 | from glob import glob 7 | from torch.utils.data import DataLoader 8 | from torch import optim 9 | 10 | from utils.dataset import * 11 | from utils.logger import * 12 | from model.DiffusionSR import * 13 | from model.EncoderSR import * 14 | from utils.visualization import * 15 | from metrics.evaluation_metrics import compute_all_metrics, hausdorff_distance 16 | 17 | 18 | 19 | def get_data_iterator(iterable): 20 | """Allows training with DataLoaders in a single infinite loop: 21 | for i, data in enumerate(inf_generator(train_loader)): 22 | """ 23 | iterator = iterable.__iter__() 24 | while True: 25 | try: 26 | yield iterator.__next__() 27 | except StopIteration: 28 | iterator = iterable.__iter__() 29 | 30 | 31 | parser = argparse.ArgumentParser() 32 | # Experiment setting 33 | parser.add_argument('--val_batch_size', type=int, default=1) 34 | parser.add_argument('--device', type=str, default='cuda') # mps for mac 35 | parser.add_argument('--save_dir', type=str, default='./results') 36 | parser.add_argument('--log', type=bool, default=True) 37 | 38 | # Grouping setting 39 | parser.add_argument('--mask_type', type=str, default='rand') 40 | parser.add_argument('--mask_ratio', type=float, default=0.4) 41 | parser.add_argument('--group_size', type=int, default=32) # points in each group 42 | parser.add_argument('--num_group', type=int, default=64) # number of group 43 | parser.add_argument('--num_points', type=int, default=2048) 44 | parser.add_argument('--num_output', type=int, default=8192) 45 | parser.add_argument('--diffusion_output_size', default=8192) 46 | 47 | # Transformer setting 48 | parser.add_argument('--trans_dim', type=int, default=384) 49 | parser.add_argument('--drop_path_rate', type=float, default=0.1) 50 | 51 | # Encoder setting 52 | parser.add_argument('--encoder_depth', type=int, default=12) 53 | parser.add_argument('--encoder_num_heads', type=int, default=6) 54 | parser.add_argument('--loss', type=str, default='cdl2') 55 | 56 | # Decoder setting 57 | parser.add_argument('--decoder_depth', 
type=int, default=4) 58 | parser.add_argument('--decoder_num_heads', type=int, default=4) 59 | 60 | # diffusion 61 | parser.add_argument('--num_steps', type=int, default=200) 62 | parser.add_argument('--beta_1', type=float, default=1e-4) 63 | parser.add_argument('--beta_T', type=float, default=0.05) 64 | parser.add_argument('--sched_mode', type=str, default='linear') 65 | 66 | args = parser.parse_args() 67 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M") 68 | save_dir = os.path.join(args.save_dir, 'new_start_eval_{date}'.format(date=time_now)) 69 | 70 | if args.log: 71 | if not os.path.exists(save_dir): 72 | os.makedirs(save_dir) 73 | log = logging.getLogger() 74 | log.setLevel(logging.INFO) 75 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s') 76 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt')) 77 | log_file.setLevel(logging.INFO) 78 | log_file.setFormatter(formatter) 79 | log.addHandler(log_file) 80 | 81 | if args.log: 82 | log.info('loading dataset') 83 | print('loading dataset') 84 | 85 | 86 | test_dset_MN = ModelNet( 87 | root='dataset/ModelNet', 88 | number_pts=8192, 89 | use_normal=False, 90 | cats=40, 91 | subset='test' 92 | ) 93 | val_loader_MN = DataLoader(test_dset_MN, batch_size=args.val_batch_size, pin_memory=True) 94 | 95 | val_dset = ShapeNet( 96 | data_path='dataset/ShapeNet55/ShapeNet-55', 97 | pc_path='dataset/ShapeNet55/shapenet_pc', 98 | subset='test', 99 | n_points=2048, 100 | downsample=True 101 | ) 102 | 103 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True) 104 | if args.log: 105 | log.info('Training Stable Diffusion') 106 | log.info('config:') 107 | log.info(args) 108 | log.info('dataset loaded') 109 | print('dataset loaded') 110 | 111 | print('loading model') 112 | check_point_dir = os.path.join('./pretrain_model/sr/encoder.pt') 113 | check_point = torch.load(check_point_dir)['model'] 114 | encoder = Encoder_Module(args).to(args.device) 115 | encoder.load_state_dict(check_point, strict=False) 116 | print('model loaded') 117 | 118 | diff_check_point_dir = os.path.join('./pretrain_model/sr/decoder.pt') 119 | diff_check_point = torch.load(diff_check_point_dir)['model'] 120 | model = Diff_Point_MAE(args).to(args.device) 121 | model = torch.nn.DataParallel(model, device_ids=[0]) 122 | incompatible = model.load_state_dict(diff_check_point, strict=False) 123 | 124 | optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.05) 125 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0.000001, T_max=20) 126 | PINK = [0xdd / 255, 0x83 / 255, 0xa2 / 255] 127 | BLUE = [0x81 / 255, 0xb0 / 255, 0xbf / 255] 128 | YELLOW = [0xf3 / 255, 0xdb / 255, 0x74 / 255] 129 | CYAN = [0x2b / 255, 0xda / 255, 0xc0 / 255] 130 | 131 | 132 | def load(filename, count=None): 133 | points = np.loadtxt(filename).astype(np.float32) 134 | if count is not None: 135 | if count > points.shape[0]: 136 | # fill the point clouds with the random point 137 | tmp = np.zeros((count, points.shape[1]), dtype=points.dtype) 138 | tmp[:points.shape[0], ...] = points 139 | tmp[points.shape[0]:, ...] 
= points[np.random.choice( 140 | points.shape[0], count - points.shape[0]), :] 141 | points = tmp 142 | elif count < points.shape[0]: 143 | # different to pointnet2, take random x point instead of the first 144 | # idx = np.random.permutation(count) 145 | # points = points[idx, :] 146 | points = uniform_down_sample(points, count) 147 | 148 | return points 149 | 150 | def calculate_metric_all(size=-1): 151 | all_sample = [] 152 | all_ref = [] 153 | all_hd = [] 154 | all_vis = [] 155 | gt_paths = glob(os.path.join('dataset/PU1K/test/input_2048/gt_8192', '*.xyz')) 156 | x_paths = glob(os.path.join('dataset/PU1K/test/input_2048/input_2048', '*.xyz')) 157 | 158 | for i in range(0, len(gt_paths)): 159 | if not os.path.exists(save_dir + '/' + str(i)): 160 | os.makedirs(save_dir + '/' + str(i)) 161 | with torch.no_grad(): 162 | x_hr = torch.from_numpy(load(gt_paths[i])[:, :3]).float().unsqueeze(0).to('cuda') 163 | x = torch.from_numpy(load(x_paths[i])[:, :3]).float().unsqueeze(0).to('cuda') 164 | model.eval() 165 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(x, masked=False) 166 | recons = model.module.sampling(x_vis, mask, center) 167 | all_vis.append(x_vis) 168 | all_sample.append(recons) 169 | all_ref.append(x_hr) 170 | hd = hausdorff_distance(recons, x_hr) 171 | all_hd.append(hd) 172 | print("evaluating model: {i}".format(i=i)) 173 | 174 | mean_hd = sum(all_hd) / len(all_hd) 175 | sample = torch.cat(all_sample, dim=0) 176 | refpc = torch.cat(all_ref, dim=0) 177 | all = compute_all_metrics(sample, refpc, 1) 178 | print("MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'], N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd)) 179 | log.info("MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'], N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd)) 180 | 181 | calculate_metric_all() 182 | -------------------------------------------------------------------------------- /train_decoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | from torch import optim 4 | from torch.utils.data import DataLoader 5 | from utils.dataset import * 6 | from utils.logger import * 7 | from model.DiffusionPretrain import * 8 | from model.Encoder import * 9 | 10 | 11 | def get_data_iterator(iterable): 12 | """Allows training with DataLoaders in a single infinite loop: 13 | for i, data in enumerate(inf_generator(train_loader)): 14 | """ 15 | iterator = iterable.__iter__() 16 | while True: 17 | try: 18 | yield iterator.__next__() 19 | except StopIteration: 20 | iterator = iterable.__iter__() 21 | 22 | parser = argparse.ArgumentParser() 23 | # Experiment setting 24 | parser.add_argument('--batch_size', type=int, default=32) 25 | parser.add_argument('--val_batch_size', type=int, default=1) 26 | parser.add_argument('--device', type=str, default='cuda') # mps for mac 27 | parser.add_argument('--log', type=bool, default=True) 28 | parser.add_argument('--save_dir', type=str, default='./results') 29 | 30 | # Grouping setting 31 | parser.add_argument('--mask_type', 
type=str, default='rand') 32 | parser.add_argument('--mask_ratio', type=float, default=0.75) 33 | parser.add_argument('--group_size', type=int, default=32) # points in each group 34 | parser.add_argument('--num_group', type=int, default=64) # number of group 35 | parser.add_argument('--num_points', type=int, default=2048) 36 | parser.add_argument('--num_output', type=int, default=8192) 37 | parser.add_argument('--diffusion_output_size', default=2048) 38 | 39 | # Transformer setting 40 | parser.add_argument('--trans_dim', type=int, default=384) 41 | parser.add_argument('--drop_path_rate', type=float, default=0.1) 42 | 43 | # Encoder setting 44 | parser.add_argument('--encoder_depth', type=int, default=12) 45 | parser.add_argument('--encoder_num_heads', type=int, default=6) 46 | parser.add_argument('--loss', type=str, default='cdl2') 47 | 48 | # Decoder setting 49 | parser.add_argument('--decoder_depth', type=int, default=4) 50 | parser.add_argument('--decoder_num_heads', type=int, default=4) 51 | 52 | # diffusion 53 | parser.add_argument('--num_steps', type=int, default=200) 54 | parser.add_argument('--beta_1', type=float, default=1e-4) 55 | parser.add_argument('--beta_T', type=float, default=0.05) 56 | parser.add_argument('--sched_mode', type=str, default='linear') 57 | 58 | # sche / optim 59 | parser.add_argument('--learning_rate', type=float, default=0.001) 60 | parser.add_argument('--weight_decay', type=float, default=0.05) 61 | parser.add_argument('--eta_min', type=float, default=0.000001) 62 | parser.add_argument('--t_max', type=float, default=200) 63 | 64 | args = parser.parse_args() 65 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M") 66 | save_dir = os.path.join(args.save_dir, 'pretrain_decoder_{date}'.format(date=time_now)) 67 | 68 | if args.log: 69 | if not os.path.exists(save_dir): 70 | os.makedirs(save_dir) 71 | log = logging.getLogger() 72 | log.setLevel(logging.INFO) 73 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s') 74 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt')) 75 | log_file.setLevel(logging.INFO) 76 | log_file.setFormatter(formatter) 77 | log.addHandler(log_file) 78 | 79 | 80 | 81 | if args.log: 82 | log.info('loading dataset') 83 | print('loading dataset') 84 | 85 | train_dset = ShapeNet( 86 | data_path='dataset/ShapeNet55/ShapeNet-55', 87 | pc_path='dataset/ShapeNet55/shapenet_pc', 88 | subset='train', 89 | n_points=2048, 90 | downsample=True 91 | ) 92 | val_dset = ShapeNet( 93 | data_path='dataset/ShapeNet55/ShapeNet-55', 94 | pc_path='dataset/ShapeNet55/shapenet_pc', 95 | subset='test', 96 | n_points=2048, 97 | downsample=True 98 | ) 99 | 100 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True) 101 | trn_loader = DataLoader(train_dset, batch_size=args.batch_size, pin_memory=True) 102 | if args.log: 103 | log.info('Training Stable Diffusion') 104 | log.info('config:') 105 | log.info(args) 106 | log.info('dataset loaded') 107 | print('dataset loaded') 108 | 109 | print('loading model') 110 | check_point_dir = os.path.join('./pretrain_model/pretrain/encoder.pt') 111 | 112 | check_point = torch.load(check_point_dir)['model'] 113 | encoder = Encoder_Module(args).to(args.device) 114 | encoder.load_state_dict(check_point) 115 | print('model loaded') 116 | 117 | model = Diff_Point_MAE(args).to(args.device) 118 | model = torch.nn.DataParallel(model, device_ids=[0]) 119 | 120 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) 121 | scheduler 
= optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=args.eta_min, T_max=args.t_max) 122 | 123 | def train(i, batch, epoch): 124 | if i == 0 and epoch > 1: 125 | print('Saving checkpoint at epoch {epoch}'.format(epoch=epoch)) 126 | save_model = {'args': args, 'model': model.state_dict()} 127 | if args.log: 128 | torch.save(save_model, os.path.join(save_dir, 'autoencoder_diff.pt')) 129 | 130 | x = batch['lr'].to(args.device) 131 | optimizer.zero_grad() 132 | model.train() 133 | encoder.eval() 134 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(x, masked=False) 135 | t = torch.randint(1, args.num_steps, (x.size(0),)) 136 | 137 | loss = model(msk_pc, t, x_vis, mask, center, vis_pc, x) 138 | loss = loss.mean() 139 | loss.backward() 140 | optimizer.step() 141 | if args.log and i == 0: 142 | log.info('epoch: {epoch}, iteration: {i}, loss: {loss}'.format(i=i, epoch=epoch, loss=loss)) 143 | print('epoch: {epoch}, iteration: {i}, loss: {loss}'.format(i=i, epoch=epoch, loss=loss)) 144 | 145 | 146 | def validate(): 147 | # all_refs = [] 148 | all_recons = [] 149 | for i, batch in enumerate(val_loader): 150 | print('sampling model {i}'.format(i=i)) 151 | ref = batch['lr'].to(args.device) 152 | if i > 200: 153 | break 154 | with torch.no_grad(): 155 | model.eval() 156 | encoder.eval() 157 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(ref, masked=False) 158 | recons = model.module.sampling(x_vis, mask, center) 159 | all_recons.append(recons) 160 | all_recons = torch.cat(all_recons, dim=0) 161 | np.save(os.path.join(save_dir, 'out.npy'), all_recons.cpu().numpy()) 162 | 163 | def forward_diff(): 164 | all_recons = [] 165 | for i, batch in enumerate(val_loader): 166 | ref = batch['lr'].to(args.device) 167 | if i > 200: 168 | break 169 | with torch.no_grad(): 170 | model.eval() 171 | encoder.eval() 172 | if i == 100: 173 | print('sampling model {i}'.format(i=i)) 174 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(ref, masked=False) 175 | 176 | for idx in range(1, args.num_steps, 1): 177 | print('sampling step {idx}'.format(idx=idx)) 178 | t = torch.Tensor([idx]).type(torch.int64) 179 | x_noisze, _ = model.module.forward_diffusion(msk_pc, t) 180 | recons = _.reshape(1, -1, 3) 181 | 182 | all_recons.append(torch.cat([recons, vis_pc.reshape(1, -1, 3)], dim=1)) 183 | 184 | all_recons = torch.cat(all_recons, dim=0) 185 | np.save(os.path.join(save_dir, 'out.npy'), all_recons.cpu().numpy()) 186 | 187 | 188 | try: 189 | n_it = 100 190 | epoch = 1 191 | while epoch <= n_it: 192 | model.train() 193 | for i, pc in enumerate(trn_loader): 194 | train(i, pc, epoch) 195 | scheduler.step() 196 | 197 | if epoch == n_it: 198 | if args.log: 199 | saved_file = {'args': args, 'model': model.state_dict()} 200 | torch.save(saved_file, os.path.join(save_dir, 'autoencoder_diff.pt')) 201 | validate() 202 | epoch += 1 203 | except Exception as e: 204 | log.error(e) 205 | 206 | -------------------------------------------------------------------------------- /model/EncoderCompress.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import trunc_normal_ 4 | import numpy as np 5 | import random 6 | 7 | import torch.nn.functional as F 8 | 9 | from metrics.evaluation_metrics import chamfer_distance_l2, chamfer_distance_l1 10 | from model.Encoder_Component import Encoder, Group, TransformerEncoder 11 | 12 | class PointTransformer(nn.Module): 13 | def __init__(self, config): 14 | 
super().__init__() 15 | self.config = config 16 | # define the transformer argparse 17 | self.trans_dim = config.trans_dim 18 | self.depth = config.encoder_depth 19 | self.drop_path_rate = config.drop_path_rate 20 | self.num_heads = config.encoder_num_heads 21 | # embedding 22 | self.encoder_dims = config.trans_dim 23 | self.encoder = Encoder(encoder_channel=self.encoder_dims) 24 | self.mask_type = config.mask_type 25 | self.mask_ratio = config.mask_ratio 26 | self.group_size = config.group_size 27 | 28 | self.pos_embed = nn.Sequential( 29 | nn.Linear(3, 128), 30 | nn.GELU(), 31 | nn.Linear(128, self.trans_dim), 32 | ) 33 | 34 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 35 | self.blocks = TransformerEncoder( 36 | embed_dim=self.trans_dim, 37 | depth=self.depth, 38 | drop_path_rate=dpr, 39 | num_heads=self.num_heads, 40 | ) 41 | 42 | self.norm = nn.LayerNorm(self.trans_dim) 43 | self.apply(self._init_weights) 44 | 45 | def _init_weights(self, m): 46 | if isinstance(m, nn.Linear): 47 | trunc_normal_(m.weight, std=.02) 48 | if isinstance(m, nn.Linear) and m.bias is not None: 49 | nn.init.constant_(m.bias, 0) 50 | elif isinstance(m, nn.LayerNorm): 51 | nn.init.constant_(m.bias, 0) 52 | nn.init.constant_(m.weight, 1.0) 53 | elif isinstance(m, nn.Conv1d): 54 | trunc_normal_(m.weight, std=.02) 55 | if m.bias is not None: 56 | nn.init.constant_(m.bias, 0) 57 | 58 | 59 | def forward(self, neighborhood, center, mask): 60 | # generate mask 61 | B, _ = mask.size() 62 | _, _, C = neighborhood.shape 63 | vis = neighborhood.reshape(B, -1, self.group_size, C) 64 | group_input_tokens = self.encoder(vis) # B G C 65 | 66 | batch_size, seq_len, L = group_input_tokens.size() 67 | vis_center = center[~mask].reshape(B, -1, C) 68 | p = self.pos_embed(vis_center) 69 | z = self.blocks(group_input_tokens, p) 70 | z = self.norm(z) 71 | return z.reshape(batch_size, -1, L) 72 | 73 | class Encoder_Module(nn.Module): 74 | def __init__(self, config): 75 | super().__init__() 76 | self.config = config 77 | self.trans_dim = config.trans_dim 78 | self.AE_encoder = PointTransformer(config) 79 | self.group_size = config.group_size 80 | self.num_group = config.num_group 81 | self.num_output = config.num_output 82 | self.mask_ratio = config.mask_ratio 83 | self.num_channel = 3 84 | self.drop_path_rate = config.drop_path_rate 85 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim)) 86 | 87 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size) 88 | 89 | # prediction head 90 | self.increase_dim = nn.Sequential( 91 | nn.Conv1d(self.trans_dim, (self.num_channel * int(self.num_output * (1 - self.mask_ratio))) // self.num_group, 1) 92 | ) 93 | 94 | trunc_normal_(self.mask_token, std=.02) 95 | self.loss = config.loss 96 | # loss 97 | self.build_loss_func(self.loss) 98 | 99 | def build_loss_func(self, loss_type): 100 | if loss_type == "cdl1": 101 | self.loss_func = chamfer_distance_l1 102 | elif loss_type == 'cdl2': 103 | self.loss_func = chamfer_distance_l2 104 | elif loss_type == 'mse': 105 | self.loss_func = F.mse_loss 106 | else: 107 | raise NotImplementedError 108 | 109 | def forward(self, pts, neighborhood, center, mask, msk_pc): 110 | neighborhood = neighborhood.float() 111 | center = center.float() 112 | x_vis, vis_pc = self.encode(neighborhood, center, mask) 113 | 114 | B, _, C = x_vis.shape # B VIS C 115 | 116 | rebuild_points = self.increase_dim(x_vis.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 3 117 | 118 | full = 
torch.cat([rebuild_points, msk_pc], dim=1) 119 | loss1 = self.loss_func(full, pts) 120 | return loss1 121 | 122 | def encode(self, neighborhood, center, mask): 123 | neighborhood = neighborhood.float() 124 | center = center.float() 125 | x_vis = self.AE_encoder(neighborhood, center, mask) 126 | 127 | full_vis = self.neighborhood(neighborhood, center, mask, x_vis) 128 | return x_vis, full_vis 129 | 130 | 131 | def evaluate(self, x_vis, x_msk): 132 | B, _, C = x_vis.shape # B VIS C 133 | x_full = torch.cat([x_vis, x_msk], dim=1) 134 | rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 3 135 | return rebuild_points.reshape(-1, 3).unsqueeze(0) 136 | 137 | def neighborhood(self, neighborhood, center, mask, x_vis): 138 | B, M, N = x_vis.shape 139 | vis_point = neighborhood.reshape(B * M, -1, 3) 140 | full_vis = vis_point + center[~mask].unsqueeze(1) 141 | full_vis = full_vis.reshape(B, -1, 3) 142 | 143 | return full_vis 144 | 145 | 146 | class Compress(nn.Module): 147 | def __init__(self, config): 148 | super().__init__() 149 | self.group_size = config.group_size 150 | self.num_group = config.num_group 151 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size) 152 | self.mask_ratio = config.mask_ratio 153 | self.mask_type = config.mask_type 154 | 155 | def _mask_center_block(self, center, noaug=False): 156 | if noaug or self.mask_ratio == 0: 157 | return torch.zeros(center.shape[:2]).bool() 158 | mask_idx = [] 159 | for points in center: 160 | points = points.unsqueeze(0) 161 | index = random.randint(0, points.size(1) - 1) 162 | distance_matrix = torch.norm(points[:, index].reshape(1,1,3) - points, p=2, dim=-1) 163 | idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0] 164 | ratio = self.mask_ratio 165 | mask_num = int(ratio * len(idx)) 166 | mask = torch.zeros(len(idx)) 167 | mask[idx[:mask_num]] = 1 168 | mask_idx.append(mask.bool()) 169 | bool_masked_pos = torch.stack(mask_idx).to(center.device) 170 | return bool_masked_pos 171 | 172 | def _mask_center_rand(self, center, noaug=False): 173 | ''' 174 | center : B G 3 175 | -------------- 176 | mask : B G (bool) 177 | ''' 178 | B, G, _ = center.shape 179 | # skip the mask 180 | if noaug or self.mask_ratio == 0: 181 | return torch.zeros(center.shape[:2]).bool() 182 | 183 | self.num_mask = int(self.mask_ratio * G) 184 | 185 | overall_mask = np.zeros([B, G]) 186 | for i in range(B): 187 | mask = np.hstack([ 188 | np.zeros(G - self.num_mask), 189 | np.ones(self.num_mask), 190 | ]) 191 | np.random.shuffle(mask) 192 | overall_mask[i, :] = mask 193 | overall_mask = torch.from_numpy(overall_mask).to(torch.bool) 194 | 195 | return overall_mask.to(center.device) # B G 196 | 197 | def compress(self, pt): 198 | neighborhood, center = self.group_divider(pt) 199 | if self.mask_type == 'rand': 200 | mask = self._mask_center_rand(center) # B G 201 | else: 202 | mask = self._mask_center_block(center) 203 | vis_point = neighborhood[~mask] 204 | return vis_point.half(), center.half(), mask 205 | -------------------------------------------------------------------------------- /model/DiffusionSR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from metrics.evaluation_metrics import chamfer_distance_l1, chamfer_distance_l2 5 | from model.Diffusion import VarianceSchedule, TimeEmbedding 6 | from model.Decoder_Component import * 7 | 8 | 9 | """ 10 | Prediction on the masked patches 
only 11 | w/ diffusion process 12 | Use FC layer as the mask token convertor 13 | """ 14 | 15 | class Diff_Point_MAE(nn.Module): 16 | def __init__(self, config): 17 | super().__init__() 18 | self.config = config 19 | self.trans_dim = config.trans_dim 20 | self.group_size = config.group_size 21 | self.num_group = config.num_group 22 | self.num_output = config.diffusion_output_size 23 | self.num_channel = 3 24 | self.drop_path_rate = config.drop_path_rate 25 | self.mask_token = nn.Conv1d((self.num_channel * 2048) // self.num_group, self.trans_dim, 1) 26 | self.mask_token_sr = nn.Conv1d((self.num_channel * self.num_output) // self.num_group, self.trans_dim, 1) 27 | self.decoder_pos_embed = nn.Sequential( 28 | nn.Linear(3, 128), 29 | nn.GELU(), 30 | nn.Linear(128, self.trans_dim) 31 | ) 32 | 33 | self.decoder_depth = config.decoder_depth 34 | self.decoder_num_heads = config.decoder_num_heads 35 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.decoder_depth)] 36 | self.decoder = Transformer( 37 | embed_dim=self.trans_dim, 38 | depth=self.decoder_depth, 39 | drop_path_rate=dpr, 40 | num_heads=self.decoder_num_heads, 41 | ) 42 | 43 | self.loss = config.loss 44 | # loss 45 | self.build_loss_func(self.loss) 46 | self.var = VarianceSchedule( 47 | num_steps=config.num_steps, 48 | beta_1=config.beta_1, 49 | beta_T=config.beta_T, 50 | mode=config.sched_mode 51 | ) 52 | 53 | # prediction head 54 | self.increase_dim = nn.Sequential( 55 | nn.Conv1d(self.trans_dim, (self.num_channel * 2048) // self.num_group, 1) 56 | ) 57 | 58 | self.increase_dim_sr = nn.Sequential( 59 | nn.Conv1d(self.trans_dim, (self.num_channel * self.num_output) // self.num_group, 1) 60 | ) 61 | 62 | self.timestep = config.num_steps 63 | self.beta_1 = config.beta_1 64 | self.beta_T = config.beta_T 65 | 66 | self.betas = self.linear_schedule(timesteps=self.timestep) 67 | 68 | self.alphas = 1.0 - self.betas 69 | self.alpha_bar = torch.cumprod(self.alphas, axis=0) 70 | self.alpha_bar_t_minus_one = F.pad(self.alpha_bar[:-1], (1, 0), value=1.0) 71 | self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas) 72 | self.sqrt_alphas_bar = torch.sqrt(self.alpha_bar) 73 | self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alpha_bar) 74 | self.sigma = self.betas * (1.0 - self.alpha_bar_t_minus_one) / (1.0 - self.alpha_bar) 75 | self.sqrt_alphas = torch.sqrt(self.alphas) 76 | self.sqrt_alpha_bar_minus_one = torch.sqrt(self.alpha_bar_t_minus_one) 77 | 78 | self.time_emb = nn.Sequential( 79 | TimeEmbedding(self.trans_dim), 80 | nn.Linear(self.trans_dim, self.trans_dim), 81 | nn.ReLU() 82 | ) 83 | 84 | def build_loss_func(self, loss_type): 85 | if loss_type == "cdl1": 86 | self.loss_func = chamfer_distance_l1 87 | elif loss_type == 'cdl2': 88 | self.loss_func = chamfer_distance_l2 89 | elif loss_type == 'mse': 90 | self.loss_func = nn.MSELoss() 91 | else: 92 | raise NotImplementedError 93 | 94 | def linear_schedule(self, timesteps): 95 | return torch.linspace(self.beta_1, self.beta_T, timesteps) 96 | 97 | def get_index_from_list(self, vals, t, x_shape): 98 | b = t.shape[0] 99 | out = vals.gather(-1, t.cpu()) 100 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).to(t.device) 101 | 102 | def forward_diffusion(self, x_0, t): 103 | noise = torch.randn_like(x_0).to(x_0.device) 104 | sqrt_alphas_cumprod_t = self.get_index_from_list(self.sqrt_alphas_bar, t, x_0.shape).to(x_0.device) 105 | sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(self.sqrt_one_minus_alphas_bar, t, x_0.shape).to(x_0.device) 106 | 107 | return 
sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise 108 | 109 | 110 | def forward(self, x_0, t, x_vis, mask, center, ori, debug=False): 111 | """ 112 | x_0: lr full (vis, msk) 113 | t: time step 114 | x_vis: lr vis latent 115 | mask: mask 116 | center: center 117 | ori: lr overall model 118 | """ 119 | # print() 120 | # print(t.size()) 121 | 122 | # print(ts.size()) 123 | x_t, noise = self.forward_diffusion(x_0, t) 124 | 125 | B, _, C = x_vis.shape # B VIS C 126 | 127 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C) 128 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C) 129 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1) 130 | _, N, _ = pos_emd_msk.shape 131 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1) 132 | mask_token = self.mask_token_sr(x_t.reshape(B, self.num_group, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device) 133 | x_full = mask_token[:, :N, :] 134 | x_full = torch.cat([x_vis, x_full], dim=1) 135 | # x_full = torch.cat([x_vis, mask_token], dim=1) # x_vis, x_vis, msk 136 | x_rec = self.decoder(x_full, pos_full, self.num_group, ts) 137 | x_rec = self.increase_dim_sr(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) 138 | if debug: 139 | return x_t, noise, x_rec 140 | else: 141 | return self.loss_func(x_rec, ori) 142 | # return F.mse_loss(x_rec, x_0, reduction='mean') 143 | 144 | def sampling_t(self, x, t, mask, center, x_vis): 145 | B, M, C = x_vis.shape # B VIS C 146 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1) 147 | betas_t = self.get_index_from_list(self.betas, t, x.shape).to(x_vis.device) 148 | 149 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C) 150 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C) 151 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1) 152 | _, N, _ = pos_emd_msk.shape 153 | mask_token = self.mask_token_sr(x.reshape(B, self.num_group, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device) 154 | x_full = mask_token[:, :N, :] 155 | x_full = torch.cat([x_vis, x_full], dim=1) 156 | x_rec = self.decoder(x_full, pos_full, self.num_group, ts) 157 | x_rec = self.increase_dim_sr(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) 158 | 159 | 160 | alpha_bar_t = self.get_index_from_list(self.alpha_bar, t, x.shape).to(x_vis.device) 161 | alpha_bar_t_minus_one = self.get_index_from_list(self.alpha_bar_t_minus_one, t, x.shape).to(x_vis.device) 162 | sqrt_alpha_t = self.get_index_from_list(self.sqrt_alphas, t, x.shape).to(x_vis.device) 163 | sqrt_alphas_bar_t_minus_one = self.get_index_from_list(self.sqrt_alpha_bar_minus_one, t, x.shape).to(x_vis.device) 164 | 165 | model_mean = (sqrt_alpha_t * (1 - alpha_bar_t_minus_one)) / (1 - alpha_bar_t) * x + (sqrt_alphas_bar_t_minus_one * betas_t) / (1 - alpha_bar_t) * x_rec 166 | 167 | sigma_t = self.get_index_from_list(self.sigma, t, x.shape).to(x_vis.device) 168 | 169 | if t == 0: 170 | return model_mean 171 | else: 172 | return model_mean + torch.sqrt(sigma_t) * x_rec 173 | 174 | def sampling(self, x_vis, mask, center, ret=False, noise_patch=None): 175 | B, M, C = x_vis.shape 176 | if noise_patch is None: 177 | noise_patch = torch.randn((B, self.num_output, 3)).to(x_vis.device) 178 | traj = [] 179 | 180 | for i in range(0, self.timestep)[::-1]: 181 | t = torch.full((1,), i, device=x_vis.device) 182 | noise_patch = self.sampling_t(noise_patch, t, mask, center, x_vis) 183 | if ret: 184 | 
traj.append(noise_patch.reshape(B, -1, 3)) 185 | 186 | if ret: 187 | return traj 188 | else: 189 | return noise_patch.reshape(B, -1, 3) 190 | -------------------------------------------------------------------------------- /model/DiffusionPretrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from metrics.evaluation_metrics import chamfer_distance_l1, chamfer_distance_l2 5 | from model.Diffusion import VarianceSchedule, TimeEmbedding 6 | from model.Decoder_Component import * 7 | 8 | """ 9 | Prediction on the masked patches only 10 | w/ diffusion process 11 | Use FC layer as the mask token convertor 12 | """ 13 | 14 | 15 | class Diff_Point_MAE(nn.Module): 16 | def __init__(self, config): 17 | super().__init__() 18 | self.config = config 19 | self.trans_dim = config.trans_dim 20 | self.group_size = config.group_size 21 | self.num_group = config.num_group 22 | self.num_output = config.num_output 23 | self.diffusion_output_size = config.diffusion_output_size 24 | self.num_channel = 3 25 | self.drop_path_rate = config.drop_path_rate 26 | self.mask_token = nn.Conv1d((self.num_channel * self.diffusion_output_size) // self.num_group, self.trans_dim, 1) 27 | self.decoder_pos_embed = nn.Sequential( 28 | nn.Linear(3, 128), 29 | nn.GELU(), 30 | nn.Linear(128, self.trans_dim) 31 | ) 32 | 33 | self.decoder_depth = config.decoder_depth 34 | self.decoder_num_heads = config.decoder_num_heads 35 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.decoder_depth)] 36 | self.MAE_decoder = Transformer( 37 | embed_dim=self.trans_dim, 38 | depth=self.decoder_depth, 39 | drop_path_rate=dpr, 40 | num_heads=self.decoder_num_heads, 41 | ) 42 | 43 | self.loss = config.loss 44 | # loss 45 | self.build_loss_func(self.loss) 46 | self.var = VarianceSchedule( 47 | num_steps=config.num_steps, 48 | beta_1=config.beta_1, 49 | beta_T=config.beta_T, 50 | mode=config.sched_mode 51 | ) 52 | 53 | # prediction head 54 | self.increase_dim = nn.Sequential( 55 | nn.Conv1d(self.trans_dim, (self.num_channel * 2048) // self.num_group, 1) 56 | ) 57 | 58 | self.timestep = config.num_steps 59 | self.beta_1 = config.beta_1 60 | self.beta_T = config.beta_T 61 | 62 | self.betas = self.linear_schedule(timesteps=self.timestep) 63 | 64 | self.alphas = 1.0 - self.betas 65 | self.alpha_bar = torch.cumprod(self.alphas, axis=0) 66 | self.alpha_bar_t_minus_one = F.pad(self.alpha_bar[:-1], (1, 0), value=1.0) 67 | self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas) 68 | self.sqrt_alphas_bar = torch.sqrt(self.alpha_bar) 69 | self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alpha_bar) 70 | self.sigma = self.betas * (1.0 - self.alpha_bar_t_minus_one) / (1.0 - self.alpha_bar) 71 | self.sqrt_alphas = torch.sqrt(self.alphas) 72 | self.sqrt_alpha_bar_minus_one = torch.sqrt(self.alpha_bar_t_minus_one) 73 | 74 | self.time_emb = nn.Sequential( 75 | TimeEmbedding(self.trans_dim), 76 | nn.Linear(self.trans_dim, self.trans_dim), 77 | nn.ReLU() 78 | ) 79 | 80 | def build_loss_func(self, loss_type): 81 | if loss_type == "cdl1": 82 | self.loss_func = chamfer_distance_l1 83 | elif loss_type == 'cdl2': 84 | self.loss_func = chamfer_distance_l2 85 | elif loss_type == 'mse': 86 | self.loss_func = F.mse_loss 87 | else: 88 | raise NotImplementedError 89 | 90 | def linear_schedule(self, timesteps): 91 | return torch.linspace(self.beta_1, self.beta_T, timesteps) 92 | 93 | def get_index_from_list(self, vals, t, x_shape): 94 | b = t.shape[0] 95 | out = 
vals.gather(-1, t.cpu()) 96 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).to(t.device) 97 | 98 | def forward_diffusion(self, x_0, t): 99 | """ 100 | Adding noise to the original input for 101 | the forward diffusion process. The noise level 102 | calculated based on current timestep t. 103 | 104 | :param x_0: The original masked input. 105 | [B, NM, C] where 106 | B = Batch size, 107 | NM = Number of point in masked patches, 108 | C = Data Channels (3 for current task) 109 | :param t: The current timestep 110 | :return: The noisy masked input at timestep t. 111 | [B, NM, C] 112 | """ 113 | noise = torch.randn_like(x_0).to(x_0.device) 114 | sqrt_alphas_cumprod_t = self.get_index_from_list(self.sqrt_alphas_bar, t, x_0.shape).to(x_0.device) 115 | sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(self.sqrt_one_minus_alphas_bar, t, x_0.shape).to( 116 | x_0.device) 117 | 118 | return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise 119 | 120 | def forward(self, x_0, t, x_vis, mask, center, vis_pc, ori, debug=False): 121 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1) 122 | x_t, noise = self.forward_diffusion(x_0, t) 123 | 124 | B, _, C = x_vis.shape # B VIS C 125 | 126 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C) 127 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C) 128 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1) 129 | _, N, _ = pos_emd_msk.shape 130 | mask_token = self.mask_token(x_t.reshape(B, N, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device) 131 | x_full = torch.cat([x_vis, mask_token], dim=1) 132 | x_rec = self.MAE_decoder(x_full, pos_full, N, ts) 133 | x_rec = self.increase_dim(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) 134 | pc_full = torch.cat([vis_pc, x_rec], dim=1) 135 | if debug: 136 | return x_t, noise, x_rec 137 | else: 138 | return self.loss_func(pc_full, ori) 139 | 140 | def sampling_t(self, noisy_t, t, mask, center, x_vis): 141 | """ 142 | Reverse sampling at timestep t. 143 | Input noisy level at timestep t, 144 | return noisy level at timestep t-1. 145 | 146 | :param noisy_t: The noisy masked patches at timestep t. 147 | [B, NM, C] 148 | :param t: Timestep. 149 | :param mask: The mask indicator. [B, G] 150 | :param center: The center points. [B, G, C] 151 | :param x_vis: The latent of visible patches. [B, V, L] 152 | :return: The noisy masked patches at timestep t-1. 
[B, NM, C] 153 | """ 154 | B, _, C = x_vis.shape # B VIS C 155 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1) 156 | betas_t = self.get_index_from_list(self.betas, t, noisy_t.shape).to(x_vis.device) 157 | 158 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C) 159 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C) 160 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1) 161 | _, N, _ = pos_emd_msk.shape 162 | mask_token = self.mask_token(noisy_t.reshape(B, N, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device) 163 | x_full = torch.cat([x_vis, mask_token], dim=1) 164 | x_rec = self.MAE_decoder(x_full, pos_full, N, ts) 165 | x_rec = self.increase_dim(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) 166 | 167 | alpha_bar_t = self.get_index_from_list(self.alpha_bar, t, noisy_t.shape).to(x_vis.device) 168 | alpha_bar_t_minus_one = self.get_index_from_list(self.alpha_bar_t_minus_one, t, noisy_t.shape).to(x_vis.device) 169 | sqrt_alpha_t = self.get_index_from_list(self.sqrt_alphas, t, noisy_t.shape).to(x_vis.device) 170 | sqrt_alphas_bar_t_minus_one = self.get_index_from_list(self.sqrt_alpha_bar_minus_one, t, noisy_t.shape).to( 171 | x_vis.device) 172 | 173 | model_mean = (sqrt_alpha_t * (1 - alpha_bar_t_minus_one)) / (1 - alpha_bar_t) * noisy_t + ( 174 | sqrt_alphas_bar_t_minus_one * betas_t) / (1 - alpha_bar_t) * x_rec 175 | 176 | sigma_t = self.get_index_from_list(self.sigma, t, noisy_t.shape).to(x_vis.device) 177 | 178 | if t == 0: 179 | return model_mean 180 | else: 181 | return model_mean + torch.sqrt(sigma_t) * x_rec 182 | 183 | def sampling(self, x_vis, mask, center, trace=False, noise_patch=None): 184 | """ 185 | Sampling the masked patches from Gaussian noise. 186 | 187 | :param x_vis: The latent of visible patches. 188 | [B, V, L] 189 | B = Batch size, 190 | V = Visible patches size, 191 | L = Latent size. 192 | :param mask: The mask indicator. 193 | [B, G] 194 | B = Batch size, 195 | G = Group (Visible patches + Masked patches) size. 196 | :param center: The center points. 197 | [B, G, C] 198 | B = Batch size, 199 | G = Group size, 200 | C = Data Channels (3 for current task) 201 | :param trace: Boolean, False by default. 202 | if true: return all reverse diffusion steps. 203 | else: return the last step only. 204 | :param noise_patch: The pre-defined noises, None by default. 205 | :return: See param trace. 
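(If trace is True, this is a list of [B, NM, C] tensors, one per reverse step from t = num_steps - 1 down to 0; otherwise only the final [B, NM, C] sample is returned, where NM is the number of points in the masked patches.)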
206 | """ 207 | B, M, C = x_vis.shape 208 | if noise_patch is None: 209 | noise_patch = torch.randn((B, (self.num_group - M) * self.group_size, 3)).to(x_vis.device) 210 | diffusion_sequence = [] 211 | 212 | for i in range(0, self.timestep)[::-1]: 213 | t = torch.full((1,), i, device=x_vis.device) 214 | noise_patch = self.sampling_t(noise_patch, t, mask, center, x_vis) 215 | if trace: 216 | diffusion_sequence.append(noise_patch.reshape(B, -1, 3)) 217 | 218 | if trace: 219 | return diffusion_sequence 220 | else: 221 | return noise_patch.reshape(B, -1, 3) 222 | -------------------------------------------------------------------------------- /model/Encoder_Component.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from timm.models.layers import DropPath, trunc_normal_ 4 | import numpy as np 5 | from utils import misc 6 | import random 7 | from knn_cuda import KNN 8 | 9 | class Encoder(nn.Module): ## Embedding module 10 | def __init__(self, encoder_channel): 11 | super().__init__() 12 | self.encoder_channel = encoder_channel 13 | self.first_conv = nn.Sequential( 14 | nn.Conv1d(3, 128, 1), 15 | nn.BatchNorm1d(128), 16 | nn.ReLU(inplace=True), 17 | nn.Conv1d(128, 256, 1) 18 | ) 19 | self.second_conv = nn.Sequential( 20 | nn.Conv1d(512, 512, 1), 21 | nn.BatchNorm1d(512), 22 | nn.ReLU(inplace=True), 23 | nn.Conv1d(512, self.encoder_channel, 1) 24 | ) 25 | 26 | def forward(self, point_groups): 27 | ''' 28 | point_groups : B G N 3 29 | ----------------- 30 | feature_global : B G C 31 | ''' 32 | bs, g, n, _ = point_groups.shape 33 | point_groups = point_groups.reshape(bs * g, n, 3) 34 | # encoder 35 | feature = self.first_conv(point_groups.transpose(2, 1)) # BG 256 n 36 | feature_global = torch.max(feature, dim=2, keepdim=True)[0] # BG 256 1 37 | feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) # BG 512 n 38 | feature = self.second_conv(feature) # BG 1024 n 39 | feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024 40 | return feature_global.reshape(bs, g, self.encoder_channel) 41 | 42 | 43 | class Group(nn.Module): # FPS + KNN 44 | def __init__(self, num_group, group_size): 45 | super().__init__() 46 | self.num_group = num_group 47 | self.group_size = group_size 48 | self.knn = KNN(k=self.group_size, transpose_mode=True) 49 | 50 | def forward(self, xyz): 51 | ''' 52 | input: B N 3 53 | --------------------------- 54 | output: B G M 3 55 | center : B G 3 56 | ''' 57 | batch_size, num_points, _ = xyz.shape 58 | # fps the centers out 59 | center = misc.fps(xyz, self.num_group) # B G 3 60 | # knn to get the neighborhood 61 | _, idx = self.knn(xyz, center) # B G M 62 | assert idx.size(1) == self.num_group 63 | assert idx.size(2) == self.group_size 64 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points 65 | idx = idx + idx_base 66 | idx = idx.view(-1) 67 | neighborhood = xyz.view(batch_size * num_points, -1)[idx, :] 68 | neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous() 69 | # normalize 70 | neighborhood = neighborhood - center.unsqueeze(2) 71 | return neighborhood, center 72 | 73 | 74 | ## Transformers 75 | class Mlp(nn.Module): 76 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 77 | super().__init__() 78 | out_features = out_features or in_features 79 | hidden_features = hidden_features or in_features 80 | self.fc1 = nn.Linear(in_features, 
hidden_features) 81 | self.act = act_layer() 82 | self.fc2 = nn.Linear(hidden_features, out_features) 83 | self.drop = nn.Dropout(drop) 84 | 85 | def forward(self, x): 86 | x = self.fc1(x) 87 | x = self.act(x) 88 | x = self.drop(x) 89 | x = self.fc2(x) 90 | x = self.drop(x) 91 | return x 92 | 93 | 94 | class Attention(nn.Module): 95 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 96 | super().__init__() 97 | self.num_heads = num_heads 98 | head_dim = dim // num_heads 99 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 100 | self.scale = qk_scale or head_dim ** -0.5 101 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 102 | self.attn_drop = nn.Dropout(attn_drop) 103 | self.proj = nn.Linear(dim, dim) 104 | self.proj_drop = nn.Dropout(proj_drop) 105 | 106 | def forward(self, x): 107 | B, N, C = x.shape 108 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 109 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 110 | 111 | attn = (q @ k.transpose(-2, -1)) * self.scale 112 | attn = attn.softmax(dim=-1) 113 | attn = self.attn_drop(attn) 114 | 115 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 116 | x = self.proj(x) 117 | x = self.proj_drop(x) 118 | return x 119 | 120 | 121 | class Block(nn.Module): 122 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 123 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 124 | super().__init__() 125 | self.norm1 = norm_layer(dim) 126 | 127 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 128 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 129 | self.norm2 = norm_layer(dim) 130 | mlp_hidden_dim = int(dim * mlp_ratio) 131 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 132 | 133 | self.attn = Attention( 134 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 135 | 136 | def forward(self, x): 137 | x = x + self.drop_path(self.attn(self.norm1(x))) 138 | x = x + self.drop_path(self.mlp(self.norm2(x))) 139 | return x 140 | 141 | 142 | class TransformerEncoder(nn.Module): 143 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, 144 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.): 145 | super().__init__() 146 | 147 | self.blocks = nn.ModuleList([ 148 | Block( 149 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 150 | drop=drop_rate, attn_drop=attn_drop_rate, 151 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate 152 | ) 153 | for i in range(depth)]) 154 | 155 | def forward(self, x, pos): 156 | for _, block in enumerate(self.blocks): 157 | x = block(x + pos) 158 | return x 159 | 160 | 161 | # Pretrain model 162 | class PointTransformer(nn.Module): 163 | def __init__(self, config, **kwargs): 164 | super().__init__() 165 | self.config = config 166 | # define the transformer argparse 167 | self.trans_dim = config.trans_dim 168 | self.depth = config.encoder_depth 169 | self.drop_path_rate = config.drop_path_rate 170 | self.num_heads = config.encoder_num_heads 171 | # embedding 172 | self.encoder_dims = config.trans_dim 173 | self.encoder = Encoder(encoder_channel=self.encoder_dims) 174 | self.mask_type = config.mask_type 175 | self.mask_ratio = config.mask_ratio 176 | 177 | self.pos_embed = nn.Sequential( 178 | nn.Linear(3, 128), 179 | nn.GELU(), 180 | nn.Linear(128, self.trans_dim), 181 | ) 182 | 183 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 184 | self.blocks = TransformerEncoder( 185 | embed_dim=self.trans_dim, 186 | depth=self.depth, 187 | drop_path_rate=dpr, 188 | num_heads=self.num_heads, 189 | ) 190 | 191 | self.norm = nn.LayerNorm(self.trans_dim) 192 | self.apply(self._init_weights) 193 | 194 | def _init_weights(self, m): 195 | if isinstance(m, nn.Linear): 196 | trunc_normal_(m.weight, std=.02) 197 | if isinstance(m, nn.Linear) and m.bias is not None: 198 | nn.init.constant_(m.bias, 0) 199 | elif isinstance(m, nn.LayerNorm): 200 | nn.init.constant_(m.bias, 0) 201 | nn.init.constant_(m.weight, 1.0) 202 | elif isinstance(m, nn.Conv1d): 203 | trunc_normal_(m.weight, std=.02) 204 | if m.bias is not None: 205 | nn.init.constant_(m.bias, 0) 206 | 207 | def _mask_center_block(self, center, noaug=False): 208 | if noaug or self.mask_ratio == 0: 209 | return torch.zeros(center.shape[:2]).bool() 210 | mask_idx = [] 211 | for points in center: 212 | points = points.unsqueeze(0) 213 | index = random.randint(0, points.size(1) - 1) 214 | distance_matrix = torch.norm(points[:, index].reshape(1,1,3) - points, p=2, dim=-1) 215 | idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0] 216 | ratio = self.mask_ratio 217 | mask_num = int(ratio * len(idx)) 218 | mask = torch.zeros(len(idx)) 219 | mask[idx[:mask_num]] = 1 220 | mask_idx.append(mask.bool()) 221 | bool_masked_pos = torch.stack(mask_idx).to(center.device) 222 | return bool_masked_pos 223 | 224 | def _mask_center_rand(self, center, noaug=False): 225 | ''' 226 | center : B G 3 227 | 
-------------- 228 | mask : B G (bool) 229 | ''' 230 | B, G, _ = center.shape 231 | # skip the mask 232 | if noaug or self.mask_ratio == 0: 233 | return torch.zeros(center.shape[:2]).bool() 234 | 235 | self.num_mask = int(self.mask_ratio * G) 236 | 237 | overall_mask = np.zeros([B, G]) 238 | for i in range(B): 239 | mask = np.hstack([ 240 | np.zeros(G - self.num_mask), 241 | np.ones(self.num_mask), 242 | ]) 243 | np.random.shuffle(mask) 244 | overall_mask[i, :] = mask 245 | overall_mask = torch.from_numpy(overall_mask).to(torch.bool) 246 | 247 | return overall_mask.to(center.device) # B G 248 | 249 | def forward(self, neighborhood, center, noaug=False): 250 | # generate mask 251 | if self.mask_type == 'rand': 252 | bool_masked_pos = self._mask_center_rand(center, noaug=noaug) # B G 253 | else: 254 | bool_masked_pos = self._mask_center_block(center, noaug=noaug) 255 | 256 | group_input_tokens = self.encoder(neighborhood) # B G C 257 | 258 | batch_size, seq_len, C = group_input_tokens.size() 259 | 260 | p = self.pos_embed(center) 261 | 262 | z = self.blocks(group_input_tokens, p) 263 | z = self.norm(z) 264 | return z[~bool_masked_pos].reshape(batch_size, -1, C), bool_masked_pos, z[bool_masked_pos].reshape(batch_size, -1, C) 265 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # DiffPMAE 2 | 3 | In this work, we propose an effective point cloud reconstruction architecture, DiffPMAE. 4 | Inspired by self-supervised learning concepts, we combine Masked Auto-Encoding and Diffusion 5 | Model mechanism to remotely reconstruct point cloud data. DiffPMAE can be extended to many related 6 | downstream tasks including point cloud compression, upsampling and completion with minimal modifications. 7 | 8 |

9 | 10 |

11 | 12 | GitHub repo: [https://github.com/DiffPMAE/DiffPMAE](https://github.com/DiffPMAE/DiffPMAE)
13 | [[arxiv]](https://arxiv.org/abs/2312.03298) [[poster]](https://eccv.ecva.net/virtual/2024/poster/1690) [[Project page]](https://tyraeldlee.github.io/DiffPMAE.github.io/) 14 | ## Datasets 15 | We use ShapeNet-55 and ModelNet40 for training and validation of the models, and PU1K for upsampling validation. 16 | All datasets should be placed in the directories shown below; they are read by the scripts automatically.
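The datasets are consumed by the dataset classes in `utils/dataset.py`. For reference, this is roughly how `eval_diffpmae.py` (included later in this repository) builds its test loaders from the layout shown below; the constructor arguments are copied from that script, and the literal batch size here stands in for its `--val_batch_size` flag:

```python
from torch.utils.data import DataLoader
from utils.dataset import ShapeNet, ModelNet  # dataset classes shipped with this repo

# ShapeNet-55 test split, downsampled to 2048 points per cloud
val_dset = ShapeNet(
    data_path='dataset/ShapeNet55/ShapeNet-55',
    pc_path='dataset/ShapeNet55/shapenet_pc',
    subset='test',
    n_points=2048,
    downsample=True,
)
val_loader = DataLoader(val_dset, batch_size=1, pin_memory=True)

# ModelNet40 test split with 8192 points per cloud
test_dset_MN = ModelNet(
    root='dataset/ModelNet',
    number_pts=8192,
    use_normal=False,
    cats=40,
    subset='test',
)
val_loader_MN = DataLoader(test_dset_MN, batch_size=1, pin_memory=True)
```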
17 | The overall directory structure should be: 18 | ``` 19 | │DiffPMAE/ 20 | ├──dataset/ 21 | │ ├──ModelNet/ 22 | │ ├──PU1K/ 23 | │ └──ShapeNet55/ 24 | ├──....... 25 | ``` 26 | 27 | ### ShapeNet-55: 28 | 29 | ``` 30 | │ShapeNet55/ 31 | ├──ShapeNet-55/ 32 | │ ├── train.txt 33 | │ └── test.txt 34 | ├──shapenet_pc/ 35 | │ ├── 02691156-1a04e3eab45ca15dd86060f189eb133.npy 36 | │ ├── 02691156-1a6ad7a24bb89733f412783097373bdc.npy 37 | │ ├── ....... 38 | ``` 39 | Download: You can download the processed ShapeNet55 dataset from [Point-BERT](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md) 40 | 41 | ### ModelNet40: 42 | 43 | ``` 44 | │ModelNet40/ 45 | ├──modelnet40_shape_names.txt 46 | ├──modelnet40_test.txt 47 | ├──modelnet40_test_8192pts_fps.dat 48 | ├──modelnet40_train.txt 49 | └──modelnet40_train_8192pts_fps.dat 50 | ``` 51 | Download: You can download the processed ModelNet40 dataset from [Point-BERT](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md) 52 | 53 | ### PU1K: 54 | 55 | ``` 56 | │PU1K/ 57 | ├──test/ 58 | │ ├── input_256/ 59 | │ ├── input_512/ 60 | │ ├── input_1024/ 61 | │ ├── input_2048/ 62 | │ │ ├── gt_8192/ 63 | │ │ │ ├── 11509_Panda_v4.xyz 64 | │ │ │ ├── ....... 65 | │ │ ├── input_2048/ 66 | │ │ │ ├── 11509_Panda_v4.xyz 67 | │ │ │ ├── ....... 68 | │ └── original_meshes/ 69 | │ │ ├── 11509_Panda_v4.off 70 | │ │ ├── ....... 71 | ├──train/ 72 | │ └── pu1k_poisson_256_poisson_1024_pc_2500_patch50_addpugan.h5 73 | 74 | ``` 75 | Download: You can download the processed PU1K dataset from [PU-GCN](https://github.com/guochengqian/PU-GCN) 76 | 77 | ## Requirements 78 | python >= 3.7
79 | pytorch >= 1.13.1
80 | CUDA >= 11.6 81 | ``` 82 | pip install -r requirements.txt 83 | ``` 84 | 85 | ``` 86 | # PointNet++ 87 | pip install "git+https://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib" 88 | # GPU kNN 89 | pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl 90 | ``` 91 | ## Pre-trained Models 92 | 93 | Pre-trained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1lkEursyogjBY6HgjPwP6Gy33Y833nQOG?usp=sharing) 94 | 95 | The overall directory structure should be: 96 | ``` 97 | │DiffPMAE/ 98 | ├──dataset/ 99 | ├──pretrain_model/ 100 | │ ├──completion/ 101 | │ ├──compress/ 102 | │ ├──pretrain/ 103 | │ └──sr/ 104 | ├──....... 105 | ``` 106 | ## Training 107 | 108 | For training, you should train the Encoder first by using the command below. Then use pre-trained 109 | Encoder to train a decoder. 110 | 111 | For encoder: 112 | 113 | ``` 114 | CUDA_VISIBLE_DEVICES= python train_encoder.py 115 | ``` 116 | 117 | Hyperparameter setting can be adjusted in train_encoder.py: 118 | 119 | ```python 120 | # Experiment setting 121 | parser.add_argument('--batch_size', type=int, default=4) 122 | parser.add_argument('--val_batch_size', type=int, default=1) 123 | parser.add_argument('--device', type=str, default='cuda') # mps for mac 124 | parser.add_argument('--save_dir', type=str, default='./results') 125 | parser.add_argument('--log', type=bool, default=False) 126 | 127 | # Grouping setting 128 | parser.add_argument('--mask_type', type=str, default='rand') 129 | parser.add_argument('--mask_ratio', type=float, default=0.75) 130 | parser.add_argument('--group_size', type=int, default=32) 131 | parser.add_argument('--num_group', type=int, default=64) 132 | parser.add_argument('--num_points', type=int, default=2048) 133 | parser.add_argument('--num_output', type=int, default=8192) 134 | 135 | # Transformer setting 136 | parser.add_argument('--trans_dim', type=int, default=384) 137 | parser.add_argument('--depth', type=int, default=12) 138 | parser.add_argument('--drop_path_rate', type=float, default=0.1) 139 | parser.add_argument('--num_heads', type=int, default=6) 140 | 141 | # Encoder setting 142 | parser.add_argument('--encoder_dims', type=int, default=384) 143 | parser.add_argument('--loss', type=str, default='cdl2') 144 | 145 | # sche / optim 146 | parser.add_argument('--learning_rate', type=float, default=0.001) 147 | parser.add_argument('--weight_decay', type=float, default=0.05) 148 | parser.add_argument('--eta_min', type=float, default=0.000001) 149 | parser.add_argument('--t_max', type=float, default=200) 150 | ``` 151 | 152 | For decoder: 153 | 154 | ``` 155 | CUDA_VISIBLE_DEVICES= python train_decoder.py 156 | ``` 157 | 158 | To load the pre-trained Encoder, you can change the following in train_decoder.py: 159 | 160 | ```python 161 | check_point_dir = os.path.join('./pretrain_model/pretrain/encoder.pt') 162 | 163 | check_point = torch.load(check_point_dir)['model'] 164 | encoder = Encoder_Module(args).to(args.device) 165 | encoder.load_state_dict(check_point) 166 | ``` 167 | 168 | Hyperparameter setting for Decoder can be adjusted in train_decoder.py: 169 | 170 | ```python 171 | # Experiment setting 172 | parser.add_argument('--batch_size', type=int, default=32) 173 | parser.add_argument('--val_batch_size', type=int, default=1) 174 | parser.add_argument('--device', type=str, default='cuda') # mps for mac 175 | parser.add_argument('--log', type=bool, default=True) 
176 | parser.add_argument('--save_dir', type=str, default='./results') 177 | 178 | # Grouping setting 179 | parser.add_argument('--mask_type', type=str, default='rand') 180 | parser.add_argument('--mask_ratio', type=float, default=0.75) 181 | parser.add_argument('--group_size', type=int, default=32) # points in each group 182 | parser.add_argument('--num_group', type=int, default=64) # number of groups 183 | parser.add_argument('--num_points', type=int, default=2048) 184 | parser.add_argument('--num_output', type=int, default=8192) 185 | parser.add_argument('--diffusion_output_size', default=2048) 186 | 187 | # Transformer setting 188 | parser.add_argument('--trans_dim', type=int, default=384) 189 | parser.add_argument('--drop_path_rate', type=float, default=0.1) 190 | 191 | # Encoder setting 192 | parser.add_argument('--encoder_depth', type=int, default=12) 193 | parser.add_argument('--encoder_num_heads', type=int, default=6) 194 | parser.add_argument('--loss', type=str, default='cdl2') 195 | 196 | # Decoder setting 197 | parser.add_argument('--decoder_depth', type=int, default=4) 198 | parser.add_argument('--decoder_num_heads', type=int, default=4) 199 | 200 | # diffusion 201 | parser.add_argument('--num_steps', type=int, default=200) 202 | parser.add_argument('--beta_1', type=float, default=1e-4) 203 | parser.add_argument('--beta_T', type=float, default=0.05) 204 | parser.add_argument('--sched_mode', type=str, default='linear') 205 | 206 | # sche / optim 207 | parser.add_argument('--learning_rate', type=float, default=0.001) 208 | parser.add_argument('--weight_decay', type=float, default=0.05) 209 | parser.add_argument('--eta_min', type=float, default=0.000001) 210 | parser.add_argument('--t_max', type=float, default=200) 211 | ``` 212 | 213 | ## Evaluation 214 | For the pre-trained model: 215 | 216 | ``` 217 | python eval_diffpmae.py 218 | ``` 219 | 220 | For upsampling: 221 | ``` 222 | python eval_upsampling.py 223 | ``` 224 | 225 | For compression: 226 | ``` 227 | python eval_compression.py 228 | ``` 229 | 230 | The configuration for each task can be adjusted in the corresponding Python file (a quick sketch of how the default values fit together follows below).
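As a rough illustration (not part of the repository code): 64 groups of 32 points exactly cover a 2048-point input cloud, a mask ratio of 0.75 leaves 16 visible groups for the encoder and 48 masked groups (1536 points) for the diffusion decoder to reconstruct, and `--sched_mode linear` is assumed here to follow the usual DDPM-style schedule of evenly spaced betas:

```python
import torch

# Grouping defaults (sanity check only; values mirror the flags documented below)
num_points, num_group, group_size, mask_ratio = 2048, 64, 32, 0.75
assert num_group * group_size == num_points        # 64 * 32 = 2048 points covered
num_masked = int(mask_ratio * num_group)           # 48 groups are hidden from the encoder
num_visible = num_group - num_masked               # 16 groups stay visible
masked_points = num_masked * group_size            # 1536 points for the decoder to rebuild

# Assumed DDPM-style linear variance schedule for --sched_mode 'linear'
beta_1, beta_T, num_steps = 1e-4, 0.05, 200
betas = torch.linspace(beta_1, beta_T, num_steps)  # beta_t for t = 1 .. 200
alpha_bars = torch.cumprod(1.0 - betas, dim=0)     # cumulative signal kept at each step
```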
231 | For example, the model configuration for pre-train evaluation can be adjusted in 232 | the eval_diffpmae.py file at L14~45. 233 | 234 | For the experiment setup: 235 | 236 | ```python 237 | parser.add_argument('--batch_size', type=int, default=32) 238 | # Batch size 239 | parser.add_argument('--val_batch_size', type=int, default=1) 240 | # Validation batch size 241 | parser.add_argument('--device', type=str, default='cuda') 242 | parser.add_argument('--log', type=bool, default=True) 243 | # Neither the trained model nor the log will be saved when this is set to False. 244 | parser.add_argument('--save_dir', type=str, default='./results') 245 | # The root directory for saved files. 246 | ``` 247 | 248 | For the Grouping setting: 249 | 250 | ```python 251 | parser.add_argument('--mask_type', type=str, default='rand') 252 | # Could be either rand or block 253 | parser.add_argument('--mask_ratio', type=float, default=0.75) 254 | parser.add_argument('--group_size', type=int, default=32) 255 | # Points in each group 256 | parser.add_argument('--num_group', type=int, default=64) 257 | # Number of groups 258 | parser.add_argument('--num_points', type=int, default=2048) 259 | # Input size of the point cloud 260 | parser.add_argument('--num_output', type=int, default=8192) 261 | # Output size of the Encoder module 262 | parser.add_argument('--diffusion_output_size', default=2048) 263 | # Output size of the Decoder module 264 | ``` 265 | 266 | For the Transformer setting: 267 | 268 | ```python 269 | parser.add_argument('--trans_dim', type=int, default=384) 270 | # Latent size 271 | parser.add_argument('--drop_path_rate', type=float, default=0.1) 272 | ``` 273 | 274 | For the Encoder setting: 275 | 276 | ```python 277 | parser.add_argument('--encoder_depth', type=int, default=12) 278 | # Number of blocks in the Encoder Transformer 279 | parser.add_argument('--encoder_num_heads', type=int, default=6) 280 | # Number of heads in each Transformer block 281 | parser.add_argument('--loss', type=str, default='cdl2') 282 | ``` 283 | 284 | For the Decoder setting: 285 | 286 | ```python 287 | parser.add_argument('--decoder_depth', type=int, default=4) 288 | # Number of blocks in the Decoder Transformer 289 | parser.add_argument('--decoder_num_heads', type=int, default=4) 290 | # Number of heads in each Transformer block 291 | ``` 292 | 293 | For the diffusion process: 294 | 295 | ```python 296 | parser.add_argument('--num_steps', type=int, default=200) 297 | parser.add_argument('--beta_1', type=float, default=1e-4) 298 | parser.add_argument('--beta_T', type=float, default=0.05) 299 | parser.add_argument('--sched_mode', type=str, default='linear') 300 | ``` 301 | 302 | For the optimizer and scheduler: 303 | 304 | ```python 305 | parser.add_argument('--learning_rate', type=float, default=0.001) 306 | parser.add_argument('--weight_decay', type=float, default=0.05) 307 | parser.add_argument('--eta_min', type=float, default=0.000001) 308 | parser.add_argument('--t_max', type=float, default=200) 309 | ``` 310 | ## Acknowledgements 311 | Our code is built upon [PointMAE](https://github.com/Pang-Yatian/Point-MAE). 312 | 313 | ## Citation 314 | ```bibtex 315 | @inproceedings{li2024diffpmae, 316 | author = {Yanlong Li and Chamara Madarasingha and Kanchana Thilakarathna}, 317 | title = {DiffPMAE: Diffusion Masked Autoencoders for Point Cloud Reconstruction}, 318 | booktitle = {ECCV}, 319 | year = {2024}, 320 | url = {https://arxiv.org/abs/2312.03298} 321 | } 322 | ``` 323 | --------------------------------------------------------------------------------
/metrics/pytorch_structural_losses/src/approxmatch.cu: -------------------------------------------------------------------------------- 1 | #include "utils.hpp" 2 | 3 | __global__ void approxmatchkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){ 4 | float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; 5 | float multiL,multiR; 6 | if (n>=m){ 7 | multiL=1; 8 | multiR=n/m; 9 | }else{ 10 | multiL=m/n; 11 | multiR=1; 12 | } 13 | const int Block=1024; 14 | __shared__ float buf[Block*4]; 15 | for (int i=blockIdx.x;i=-2;j--){ 24 | for (int j=7;j>-2;j--){ 25 | float level=-powf(4.0f,j); 26 | if (j==-2){ 27 | level=0; 28 | } 29 | for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,out); 227 | //} 228 | 229 | __global__ void matchcostgrad2kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){ 230 | __shared__ float sum_grad[256*3]; 231 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad2); 294 | //} 295 | 296 | /*void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N, 297 | cudaStream_t stream)*/ 298 | // temp: TensorShape{b,(n+m)*2} 299 | void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream){ 300 | approxmatchkernel 301 | <<<32, 512, 0, stream>>>(b,n,m,xyz1,xyz2,match,temp); 302 | 303 | cudaError_t err = cudaGetLastError(); 304 | if (cudaSuccess != err) 305 | throw std::runtime_error(Formatter() 306 | << "CUDA kernel failed : " << std::to_string(err)); 307 | } 308 | 309 | void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream){ 310 | matchcostkernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,out); 311 | 312 | cudaError_t err = cudaGetLastError(); 313 | if (cudaSuccess != err) 314 | throw std::runtime_error(Formatter() 315 | << "CUDA kernel failed : " << std::to_string(err)); 316 | } 317 | 318 | void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream){ 319 | matchcostgrad1kernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,grad1); 320 | matchcostgrad2kernel<<>>(b,n,m,xyz1,xyz2,match,grad2); 321 | 322 | cudaError_t err = cudaGetLastError(); 323 | if (cudaSuccess != err) 324 | throw std::runtime_error(Formatter() 325 | << "CUDA kernel failed : " << std::to_string(err)); 326 | } 327 | -------------------------------------------------------------------------------- /eval_diffpmae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | 4 | from torch.utils.data import DataLoader 5 | from utils.logger import * 6 | from utils.dataset import * 7 | from model.DiffusionPretrain import * 8 | from model.Encoder import * 9 | from utils.visualization import * 10 | from metrics.evaluation_metrics import compute_all_metrics, averaged_hausdorff_distance, jsd_between_point_cloud_sets 11 | 12 | parser = argparse.ArgumentParser() 13 | # Experiment setting 14 | parser.add_argument('--val_batch_size', type=int, default=1) 15 | parser.add_argument('--device', type=str, default='cuda') # mps for mac 16 | parser.add_argument('--log', type=bool, default=True) 17 | parser.add_argument('--save_dir', type=str, default='./results') 18 | 19 | # Grouping 
setting 20 | parser.add_argument('--mask_type', type=str, default='rand') 21 | parser.add_argument('--mask_ratio', type=float, default=0.75) 22 | parser.add_argument('--group_size', type=int, default=32) # points in each group 23 | parser.add_argument('--num_group', type=int, default=64) # number of group 24 | parser.add_argument('--num_points', type=int, default=2048) 25 | parser.add_argument('--num_output', type=int, default=8192) 26 | parser.add_argument('--diffusion_output_size', type=int, default=2048) 27 | 28 | # Transformer setting 29 | parser.add_argument('--trans_dim', type=int, default=384) 30 | parser.add_argument('--drop_path_rate', type=float, default=0.1) 31 | 32 | # Encoder setting 33 | parser.add_argument('--encoder_depth', type=int, default=12) 34 | parser.add_argument('--encoder_num_heads', type=int, default=6) 35 | parser.add_argument('--loss', type=str, default='cdl2') 36 | 37 | # Decoder setting 38 | parser.add_argument('--decoder_depth', type=int, default=4) 39 | parser.add_argument('--decoder_num_heads', type=int, default=4) 40 | 41 | # diffusion 42 | parser.add_argument('--num_steps', type=int, default=200) 43 | parser.add_argument('--beta_1', type=float, default=1e-4) 44 | parser.add_argument('--beta_T', type=float, default=0.05) 45 | parser.add_argument('--sched_mode', type=str, default='linear') 46 | 47 | 48 | args = parser.parse_args() 49 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M") 50 | save_dir = os.path.join(args.save_dir, 'new_start_eval_{date}'.format(date=time_now)) 51 | 52 | if args.log: 53 | if not os.path.exists(save_dir): 54 | os.makedirs(save_dir) 55 | log = logging.getLogger() 56 | log.setLevel(logging.INFO) 57 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s') 58 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt')) 59 | log_file.setLevel(logging.INFO) 60 | log_file.setFormatter(formatter) 61 | log.addHandler(log_file) 62 | 63 | if args.log: 64 | log.info('loading dataset') 65 | print('loading dataset') 66 | 67 | test_dset_MN = ModelNet( 68 | root='dataset/ModelNet', 69 | number_pts=8192, 70 | use_normal=False, 71 | cats=40, 72 | subset='test' 73 | ) 74 | val_loader_MN = DataLoader(test_dset_MN, batch_size=args.val_batch_size, pin_memory=True) 75 | 76 | val_dset = ShapeNet( 77 | data_path='dataset/ShapeNet55/ShapeNet-55', 78 | pc_path='dataset/ShapeNet55/shapenet_pc', 79 | subset='test', 80 | n_points=2048, 81 | downsample=True 82 | ) 83 | 84 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True) 85 | 86 | if args.log: 87 | log.info('Training Stable Diffusion') 88 | log.info('config:') 89 | log.info(args) 90 | log.info('dataset loaded') 91 | print('dataset loaded') 92 | 93 | print('loading model') 94 | check_point_dir = os.path.join('./pretrain_model/pretrain/encoder.pt') 95 | check_point = torch.load(check_point_dir)['model'] 96 | encoder = Encoder_Module(args).to(args.device) 97 | encoder.load_state_dict(check_point) 98 | print('model loaded') 99 | 100 | diff_check_point_dir = os.path.join('./pretrain_model/pretrain/decoder.pt') 101 | diff_check_point = torch.load(diff_check_point_dir)['model'] 102 | model = Diff_Point_MAE(args).to(args.device) 103 | model = torch.nn.DataParallel(model, device_ids=[0]) 104 | model.load_state_dict(diff_check_point) 105 | 106 | PINK = [0xdd / 255, 0x83 / 255, 0xa2 / 255] 107 | BLUE = [0x81 / 255, 0xb0 / 255, 0xbf / 255] 108 | YELLOW = [0xf3 / 255, 0xdb / 255, 0x74 / 255] 109 | CYAN = [0x2b / 255, 0xda / 255, 0xc0 / 255] 110 | GRAY = [0xaa 
/ 255, 0xaa / 255, 0xaa / 255] 111 | 112 | 113 | def visiualizaion(index=-1): 114 | for i, batch in enumerate(val_loader): 115 | with torch.no_grad(): 116 | if index == -1 or index == i: 117 | ref = batch['lr'].to(args.device) 118 | model.eval() 119 | encoder.eval() 120 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(ref, masked=False) 121 | print('sampling model {i}'.format(i=i)) 122 | t = torch.randint(0, args.num_steps, (ref.size(0),)) 123 | recons = model.module.sampling(x_vis, mask, center, True) 124 | eval_model_normal(ref.cpu().numpy(), 'ori', show_axis=False, show_img=False, save_img=True, 125 | save_location=save_dir, color=BLUE) 126 | eval_model_normal(vis_pc.cpu().numpy(), 'msk', show_axis=False, show_img=False, save_img=True, 127 | save_location=save_dir, color=PINK) 128 | for i in range(0, len(recons)): 129 | if i in [0, 10, 20, 30, 40, 50, 90, 150, 180]: 130 | eval_model_normal(recons[i].cpu().numpy(), i, show_axis=False, show_img=True, save_img=True, 131 | save_location=save_dir, color=BLUE) 132 | # recons = torch.cat([vis_pc.reshape(1, -1, 3), recons], dim=1) 133 | eval_model_normal(recons[199].cpu().numpy(), 199, show_axis=False, show_img=False, save_img=True, 134 | save_location=save_dir, color=BLUE) 135 | # vis_results(recons, ref, vis_pc, i, show_axis=True, save_img=True, show_img=False) 136 | eval_model_normal(torch.cat([recons[199], vis_pc], dim=1).cpu().numpy(), -2, show_axis=False, 137 | show_img=False, save_img=True, save_location=save_dir, color=BLUE) 138 | eval_model_normal(recons[199].cpu().numpy(), 'comb', second_part=vis_pc.cpu().numpy(), show_axis=False, 139 | show_img=True, save_img=True, save_location=save_dir, color=BLUE, color2=PINK) 140 | eval_model_normal(msk_pc.cpu().numpy(), 'segment', second_part=vis_pc.cpu().numpy(), show_axis=False, 141 | show_img=False, save_img=True, save_location=save_dir, color=GRAY, color2=PINK) 142 | if index <= i: 143 | break 144 | 145 | def calculate_metric_all_on_prediction(size=-1, enc=encoder, pred=model): 146 | all_sample = [] 147 | all_ref = [] 148 | all_hd = [] 149 | for i, batch in enumerate(val_loader): 150 | if i == size: 151 | break 152 | with torch.no_grad(): 153 | ref = batch['lr'].to(args.device) 154 | pred.eval() 155 | x_vis, z_masked, mask, center, vis_pc, msk_pc = enc.encode(ref, masked=False) 156 | recons = pred.module.sampling(x_vis, mask, center, False) 157 | recons = torch.cat([recons, vis_pc], dim=1) 158 | all_sample.append(recons) 159 | all_ref.append(ref) 160 | hd = averaged_hausdorff_distance(recons.squeeze(), ref.squeeze()) 161 | all_hd.append(hd) 162 | 163 | print("evaluating model: {i}".format(i=i)) 164 | log.info("evaluating model: {i}".format(i=i)) 165 | 166 | mean_hd = sum(all_hd) / len(all_hd) 167 | sample = torch.cat(all_sample, dim=0) 168 | refpc = torch.cat(all_ref, dim=0) 169 | all = compute_all_metrics(sample, refpc, 32) 170 | print( 171 | "MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format( 172 | mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'], 173 | N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd)) 174 | log.info( 175 | "MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format( 176 | mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'], 177 | 
N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd)) 178 | 179 | 180 | def auto_run(): 181 | mask_ratio = [0.75] 182 | for v in mask_ratio: 183 | args.mask_ratio = v 184 | log.info("mask ratio: {mr}".format(mr=v)) 185 | print('loading encoder') 186 | check_point_dir_l = os.path.join('./pretrain_model/pretrain/encoder.pt') 187 | check_point_l = torch.load(check_point_dir_l)['model'] 188 | encoder_l = Encoder_Module(args).to(args.device) 189 | encoder_l.load_state_dict(check_point_l) 190 | print('encoder loaded\rloading model') 191 | 192 | diff_check_point_dir_l = os.path.join('./pretrain_model/pretrain/decoder.pt') 193 | diff_check_point_l = torch.load(diff_check_point_dir_l)['model'] 194 | model_l = Diff_Point_MAE(args).to(args.device) 195 | model_l = torch.nn.DataParallel(model_l, device_ids=[0]) 196 | model_l.load_state_dict(diff_check_point_l) 197 | print('model loaded') 198 | calculate_metric_all_on_prediction(2048, encoder_l, model_l) 199 | 200 | 201 | # auto_run() 202 | 203 | def calculate_metric_all_set(): 204 | all_cd = [] 205 | all_sample = [] 206 | all_ref = [] 207 | all_hd = [] 208 | for i, batch in enumerate(val_loader): 209 | with torch.no_grad(): 210 | ref = batch['lr'].to(args.device) 211 | model.eval() 212 | encoder.eval() 213 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(ref, masked=False) 214 | recons = model.module.sampling(x_vis, mask, center, False) 215 | recons = torch.cat([vis_pc, recons], dim=1) 216 | all_sample.append(recons) 217 | all_ref.append(ref) 218 | hd = averaged_hausdorff_distance(recons.squeeze(), ref.squeeze()) 219 | all_hd.append(hd) 220 | 221 | cd = model.module.loss_func(recons, ref) 222 | all_cd.append(cd) 223 | 224 | print("evaluating model: {i}".format(i=i)) 225 | # log.info("evaluating model: {i}, CD={cd}".format(i=i, cd=cd)) 226 | 227 | mean_cd = sum(all_cd) / len(all_cd) 228 | mean_hd = sum(all_hd) / len(all_hd) 229 | sample = torch.cat(all_sample, dim=0) 230 | refpc = torch.cat(all_ref, dim=0) 231 | JSD = jsd_between_point_cloud_sets(sample, refpc) 232 | # all = compute_all_metrics(sample, refpc, 32) 233 | # print("MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'], N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd)) 234 | # log.info("MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'], N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd)) 235 | print("Mean CD: {mCD}, Mean HD: {mHD}, JSD: {jsd}".format(mCD=mean_cd, mHD=mean_hd, jsd=JSD)) 236 | log.info("Mean CD: {mCD}, Mean HD: {mHD}, JSD: {jsd}".format(mCD=mean_cd, mHD=mean_hd, jsd=JSD)) 237 | 238 | calculate_metric_all_set() -------------------------------------------------------------------------------- /metrics/evaluation_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/stevenygd/PointFlow/tree/master/metrics 3 | """ 4 | import torch 5 | import numpy as np 6 | import warnings 7 | from scipy.stats import entropy 8 | from sklearn.neighbors import NearestNeighbors 9 | from numpy.linalg import norm 10 | from tqdm.auto 
import tqdm 11 | 12 | _EMD_NOT_IMPL_WARNED = True 13 | 14 | 15 | def emd_approx(sample, ref): 16 | global _EMD_NOT_IMPL_WARNED 17 | emd = torch.zeros([sample.size(0)]).to(sample) 18 | if not _EMD_NOT_IMPL_WARNED: 19 | _EMD_NOT_IMPL_WARNED = True 20 | print('\n\n[WARNING]') 21 | print(' * EMD is not implemented due to GPU compatability issue.') 22 | print(' * We will set all EMD to zero by default.') 23 | print(' * You may implement your own EMD in the function `emd_approx` in ./evaluation/evaluation_metrics.py') 24 | print('\n') 25 | return emd 26 | 27 | 28 | # Borrow from https://github.com/ThibaultGROUEIX/AtlasNet 29 | def distChamfer(a, b): 30 | x, y = a, b 31 | bs, num_points, points_dim = x.size() 32 | xx = torch.bmm(x, x.transpose(2, 1)) 33 | yy = torch.bmm(y, y.transpose(2, 1)) 34 | zz = torch.bmm(x, y.transpose(2, 1)) 35 | diag_ind = torch.arange(0, num_points).to(a).long() 36 | rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx) 37 | ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy) 38 | P = (rx.transpose(2, 1) + ry - 2 * zz) 39 | return P.min(1)[0], P.min(2)[0] 40 | 41 | 42 | def chamfer_distance_l1(a, b): 43 | d1, d2 = distChamfer(a, b) 44 | d1 = torch.sqrt(d1) 45 | d2 = torch.sqrt(d2) 46 | return (torch.mean(d1) + torch.mean(d2)) / 2 47 | 48 | 49 | def chamfer_distance_l2(a, b): 50 | d1, d2 = distChamfer(a, b) 51 | return torch.mean(d1) + torch.mean(d2) 52 | 53 | 54 | def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True): 55 | N_sample = sample_pcs.shape[0] 56 | N_ref = ref_pcs.shape[0] 57 | assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample) 58 | 59 | cd_lst = [] 60 | emd_lst = [] 61 | iterator = range(0, N_sample, batch_size) 62 | 63 | for b_start in tqdm(iterator, desc='EMD-CD'): 64 | b_end = min(N_sample, b_start + batch_size) 65 | sample_batch = sample_pcs[b_start:b_end] 66 | ref_batch = ref_pcs[b_start:b_end] 67 | 68 | dl, dr = distChamfer(sample_batch, ref_batch) 69 | cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1)) 70 | 71 | emd_batch = emd_approx(sample_batch, ref_batch) 72 | emd_lst.append(emd_batch) 73 | 74 | if reduced: 75 | cd = torch.cat(cd_lst).mean() 76 | emd = torch.cat(emd_lst).mean() 77 | else: 78 | cd = torch.cat(cd_lst) 79 | emd = torch.cat(emd_lst) 80 | 81 | results = { 82 | 'MMD-CD': cd, 83 | 'MMD-EMD': emd, 84 | } 85 | return results 86 | 87 | 88 | def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, verbose=True): 89 | N_sample = sample_pcs.shape[0] 90 | N_ref = ref_pcs.shape[0] 91 | all_cd = [] 92 | iterator = range(N_sample) 93 | if verbose: 94 | iterator = tqdm(iterator, desc='Pairwise EMD-CD') 95 | for sample_b_start in iterator: 96 | sample_batch = sample_pcs[sample_b_start] 97 | 98 | cd_lst = [] 99 | 100 | sub_iterator = range(0, N_ref, batch_size) 101 | # if verbose: 102 | # sub_iterator = tqdm(sub_iterator, leave=False) 103 | for ref_b_start in sub_iterator: 104 | ref_b_end = min(N_ref, ref_b_start + batch_size) 105 | ref_batch = ref_pcs[ref_b_start:ref_b_end] 106 | 107 | batch_size_ref = ref_batch.size(0) 108 | point_dim = ref_batch.size(2) 109 | sample_batch_exp = sample_batch.view(1, -1, point_dim).expand( 110 | batch_size_ref, -1, -1) 111 | sample_batch_exp = sample_batch_exp.contiguous() 112 | 113 | dl, dr = distChamfer(sample_batch_exp, ref_batch) 114 | cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1)) 115 | 116 | cd_lst = torch.cat(cd_lst, dim=1) 117 | all_cd.append(cd_lst) 118 | 119 | all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref 120 | 121 | return all_cd 122 | 123 | 124 | # Adapted from 
https://github.com/xuqiantong/ 125 | # GAN-Metrics/blob/master/framework/metric.py 126 | def knn(Mxx, Mxy, Myy, k, sqrt=False): 127 | n0 = Mxx.size(0) 128 | n1 = Myy.size(0) 129 | label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx) 130 | M = torch.cat([ 131 | torch.cat((Mxx, Mxy), 1), 132 | torch.cat((Mxy.transpose(0, 1), Myy), 1)], 0) 133 | if sqrt: 134 | M = M.abs().sqrt() 135 | INFINITY = float('inf') 136 | val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk( 137 | k, 0, False) 138 | 139 | count = torch.zeros(n0 + n1).to(Mxx) 140 | for i in range(0, k): 141 | count = count + label.index_select(0, idx[i]) 142 | pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float() 143 | 144 | s = { 145 | 'tp': (pred * label).sum(), 146 | 'fp': (pred * (1 - label)).sum(), 147 | 'fn': ((1 - pred) * label).sum(), 148 | 'tn': ((1 - pred) * (1 - label)).sum(), 149 | } 150 | 151 | s.update({ 152 | 'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10), 153 | 'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10), 154 | 'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10), 155 | 'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10), 156 | 'acc': torch.eq(label, pred).float().mean(), 157 | }) 158 | return s 159 | 160 | 161 | def lgan_mmd_cov(all_dist): 162 | N_sample, N_ref = all_dist.size(0), all_dist.size(1) 163 | min_val_fromsmp, min_idx = torch.min(all_dist, dim=1) 164 | min_val, _ = torch.min(all_dist, dim=0) 165 | mmd = min_val.mean() 166 | mmd_smp = min_val_fromsmp.mean() 167 | cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref) 168 | cov = torch.tensor(cov).to(all_dist) 169 | return { 170 | 'lgan_mmd': mmd, 171 | 'lgan_cov': cov, 172 | 'lgan_mmd_smp': mmd_smp, 173 | } 174 | 175 | 176 | def lgan_mmd_cov_match(all_dist): 177 | N_sample, N_ref = all_dist.size(0), all_dist.size(1) 178 | min_val_fromsmp, min_idx = torch.min(all_dist, dim=1) 179 | min_val, _ = torch.min(all_dist, dim=0) 180 | mmd = min_val.mean() 181 | mmd_smp = min_val_fromsmp.mean() 182 | cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref) 183 | cov = torch.tensor(cov).to(all_dist) 184 | return { 185 | 'lgan_mmd': mmd, 186 | 'lgan_cov': cov, 187 | 'lgan_mmd_smp': mmd_smp, 188 | }, min_idx.view(-1) 189 | 190 | 191 | def compute_all_metrics(sample_pcs, ref_pcs, batch_size): 192 | results = {} 193 | 194 | # print("Pairwise EMD CD") 195 | M_rs_cd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size) 196 | 197 | ## CD 198 | res_cd = lgan_mmd_cov(M_rs_cd.t()) 199 | results.update({ 200 | "%s-CD" % k: v for k, v in res_cd.items() 201 | }) 202 | 203 | for k, v in results.items(): 204 | print('[%s] %.8f' % (k, v.item())) 205 | 206 | M_rr_cd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size) 207 | M_ss_cd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size) 208 | 209 | # 1-NN results 210 | ## CD 211 | one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False) 212 | results.update({ 213 | "1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k 214 | }) 215 | 216 | # JSD results 217 | jsd = jsd_between_point_cloud_sets(sample_pcs, ref_pcs) 218 | results.update({ 219 | "JSD": jsd 220 | }) 221 | 222 | 223 | 224 | return results 225 | 226 | 227 | ####################################################### 228 | # JSD : from https://github.com/optas/latent_3d_points 229 | ####################################################### 230 | def unit_cube_grid_point_cloud(resolution, clip_sphere=False): 231 | """Returns the center coordinates of each cell of a 3D grid with 232 | resolution^3 cells, that is 
placed in the unit-cube. If clip_sphere it True 233 | it drops the "corner" cells that lie outside the unit-sphere. 234 | """ 235 | grid = np.ndarray((resolution, resolution, resolution, 3), np.float32) 236 | spacing = 1.0 / float(resolution - 1) 237 | for i in range(resolution): 238 | for j in range(resolution): 239 | for k in range(resolution): 240 | grid[i, j, k, 0] = i * spacing - 0.5 241 | grid[i, j, k, 1] = j * spacing - 0.5 242 | grid[i, j, k, 2] = k * spacing - 0.5 243 | 244 | if clip_sphere: 245 | grid = grid.reshape(-1, 3) 246 | grid = grid[norm(grid, axis=1) <= 0.5] 247 | 248 | return grid, spacing 249 | 250 | 251 | def jsd_between_point_cloud_sets( 252 | sample_pcs, ref_pcs, resolution=28): 253 | """Computes the JSD between two sets of point-clouds, 254 | as introduced in the paper 255 | ```Learning Representations And Generative Models For 3D Point Clouds```. 256 | Args: 257 | sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points. 258 | ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points. 259 | resolution: (int) grid-resolution. Affects granularity of measurements. 260 | """ 261 | sample_pcs = sample_pcs.cpu().numpy() 262 | ref_pcs = ref_pcs.cpu().numpy() 263 | in_unit_sphere = True 264 | sample_grid_var = entropy_of_occupancy_grid( 265 | sample_pcs, resolution, in_unit_sphere)[1] 266 | ref_grid_var = entropy_of_occupancy_grid( 267 | ref_pcs, resolution, in_unit_sphere)[1] 268 | return jensen_shannon_divergence(sample_grid_var, ref_grid_var) 269 | 270 | 271 | def entropy_of_occupancy_grid( 272 | pclouds, grid_resolution, in_sphere=False, verbose=False): 273 | """Given a collection of point-clouds, estimate the entropy of 274 | the random variables corresponding to occupancy-grid activation patterns. 275 | Inputs: 276 | pclouds: (numpy array) #point-clouds x points per point-cloud x 3 277 | grid_resolution (int) size of occupancy grid that will be used. 278 | """ 279 | epsilon = 10e-4 280 | bound = 0.5 + epsilon 281 | if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound: 282 | if verbose: 283 | warnings.warn('Point-clouds are not in unit cube.') 284 | 285 | if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound: 286 | if verbose: 287 | warnings.warn('Point-clouds are not in unit sphere.') 288 | 289 | grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere) 290 | grid_coordinates = grid_coordinates.reshape(-1, 3) 291 | grid_counters = np.zeros(len(grid_coordinates)) 292 | grid_bernoulli_rvars = np.zeros(len(grid_coordinates)) 293 | nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates) 294 | 295 | for pc in tqdm(pclouds, desc='JSD'): 296 | _, indices = nn.kneighbors(pc) 297 | indices = np.squeeze(indices) 298 | for i in indices: 299 | grid_counters[i] += 1 300 | indices = np.unique(indices) 301 | for i in indices: 302 | grid_bernoulli_rvars[i] += 1 303 | 304 | acc_entropy = 0.0 305 | n = float(len(pclouds)) 306 | for g in grid_bernoulli_rvars: 307 | if g > 0: 308 | p = float(g) / n 309 | acc_entropy += entropy([p, 1.0 - p]) 310 | 311 | return acc_entropy / len(grid_counters), grid_counters 312 | 313 | 314 | def jensen_shannon_divergence(P, Q): 315 | if np.any(P < 0) or np.any(Q < 0): 316 | raise ValueError('Negative values.') 317 | if len(P) != len(Q): 318 | raise ValueError('Non equal size.') 319 | 320 | P_ = P / np.sum(P) # Ensure probabilities. 
321 | Q_ = Q / np.sum(Q) 322 | 323 | e1 = entropy(P_, base=2) 324 | e2 = entropy(Q_, base=2) 325 | e_sum = entropy((P_ + Q_) / 2.0, base=2) 326 | res = e_sum - ((e1 + e2) / 2.0) 327 | 328 | res2 = _jsdiv(P_, Q_) 329 | 330 | if not np.allclose(res, res2, atol=10e-5, rtol=0): 331 | warnings.warn('Numerical values of two JSD methods don\'t agree.') 332 | 333 | return res 334 | 335 | 336 | def _jsdiv(P, Q): 337 | """another way of computing JSD""" 338 | 339 | def _kldiv(A, B): 340 | a = A.copy() 341 | b = B.copy() 342 | idx = np.logical_and(a > 0, b > 0) 343 | a = a[idx] 344 | b = b[idx] 345 | return np.sum([v for v in a * np.log2(a / b)]) 346 | 347 | P_ = P / np.sum(P) 348 | Q_ = Q / np.sum(Q) 349 | 350 | M = 0.5 * (P_ + Q_) 351 | 352 | return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M)) 353 | 354 | from sklearn.metrics.pairwise import pairwise_distances 355 | def averaged_hausdorff_distance(sample_pcs, ref_pcs, max_ahd=np.inf): 356 | sample_pcs = sample_pcs.cpu().numpy() 357 | ref_pcs = ref_pcs.cpu().numpy() 358 | 359 | if len(sample_pcs) == 0 or len(ref_pcs) == 0: 360 | return max_ahd 361 | 362 | sample_pcs = np.array(sample_pcs) 363 | ref_pcs = np.array(ref_pcs) 364 | 365 | assert sample_pcs.ndim == 2, 'got %s' % sample_pcs.ndim 366 | assert ref_pcs.ndim == 2, 'got %s' % ref_pcs.ndim 367 | 368 | assert sample_pcs.shape[1] == ref_pcs.shape[1], \ 369 | 'The points in both sets must have the same number of dimensions, got %s and %s.' \ 370 | % (ref_pcs.shape[1], ref_pcs.shape[1]) 371 | 372 | d2_matrix = pairwise_distances(sample_pcs, ref_pcs, metric='euclidean') 373 | 374 | res = np.average(np.min(d2_matrix, axis=0)) + \ 375 | np.average(np.min(d2_matrix, axis=1)) 376 | 377 | return res 378 | 379 | 380 | def hausdorff_distance(pred, ref): 381 | A_ab, A_ba = distChamfer(pred, ref) 382 | h_ab = torch.max(A_ab - A_ba) 383 | B_ab, B_ba = distChamfer(ref, pred) 384 | h_ba = torch.max(B_ab - B_ba) 385 | 386 | return torch.max(h_ab, h_ba) / 2 387 | 388 | if __name__ == '__main__': 389 | a = torch.randn([16, 2048, 3]).cuda() 390 | b = torch.randn([16, 2048, 3]).cuda() 391 | print(EMD_CD(a, b, batch_size=8)) 392 | -------------------------------------------------------------------------------- /model/EncoderCop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import DropPath, trunc_normal_ 5 | import numpy as np 6 | from utils import misc 7 | import random 8 | from knn_cuda import KNN 9 | from metrics.evaluation_metrics import chamfer_distance_l1, chamfer_distance_l2 10 | 11 | 12 | 13 | class Encoder(nn.Module): ## Embedding module 14 | def __init__(self, encoder_channel): 15 | super().__init__() 16 | self.encoder_channel = encoder_channel 17 | self.first_conv = nn.Sequential( 18 | nn.Conv1d(3, 128, 1), 19 | nn.BatchNorm1d(128), 20 | nn.ReLU(inplace=True), 21 | nn.Conv1d(128, 256, 1) 22 | ) 23 | self.second_conv = nn.Sequential( 24 | nn.Conv1d(512, 512, 1), 25 | nn.BatchNorm1d(512), 26 | nn.ReLU(inplace=True), 27 | nn.Conv1d(512, self.encoder_channel, 1) 28 | ) 29 | 30 | def forward(self, point_groups): 31 | ''' 32 | point_groups : B G N 3 33 | ----------------- 34 | feature_global : B G C 35 | ''' 36 | bs, g, n, _ = point_groups.shape 37 | point_groups = point_groups.reshape(bs * g, n, 3) 38 | # encoder 39 | feature = self.first_conv(point_groups.transpose(2, 1)) # BG 256 n 40 | feature_global = torch.max(feature, dim=2, keepdim=True)[0] # BG 256 1 41 | feature = 
torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) # BG 512 n 42 | feature = self.second_conv(feature) # BG 1024 n 43 | feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024 44 | return feature_global.reshape(bs, g, self.encoder_channel) 45 | 46 | 47 | class Group(nn.Module): # FPS + KNN 48 | def __init__(self, num_group, group_size): 49 | super().__init__() 50 | self.num_group = num_group 51 | self.group_size = group_size 52 | self.knn = KNN(k=self.group_size, transpose_mode=True) 53 | 54 | def forward(self, xyz): 55 | ''' 56 | input: B N 3 57 | --------------------------- 58 | output: B G M 3 59 | center : B G 3 60 | ''' 61 | batch_size, num_points, _ = xyz.shape 62 | # fps the centers out 63 | center = misc.fps(xyz, self.num_group) # B G 3 64 | # knn to get the neighborhood 65 | _, idx = self.knn(xyz, center) # B G M 66 | assert idx.size(1) == self.num_group 67 | assert idx.size(2) == self.group_size 68 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points 69 | idx = idx + idx_base 70 | idx = idx.view(-1) 71 | neighborhood = xyz.view(batch_size * num_points, -1)[idx, :] 72 | neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous() 73 | # normalize 74 | neighborhood = neighborhood - center.unsqueeze(2) 75 | return neighborhood, center 76 | 77 | 78 | ## Transformers 79 | class Mlp(nn.Module): 80 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 81 | super().__init__() 82 | out_features = out_features or in_features 83 | hidden_features = hidden_features or in_features 84 | self.fc1 = nn.Linear(in_features, hidden_features) 85 | self.act = act_layer() 86 | self.fc2 = nn.Linear(hidden_features, out_features) 87 | self.drop = nn.Dropout(drop) 88 | 89 | def forward(self, x): 90 | x = self.fc1(x) 91 | x = self.act(x) 92 | x = self.drop(x) 93 | x = self.fc2(x) 94 | x = self.drop(x) 95 | return x 96 | 97 | 98 | class Attention(nn.Module): 99 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 100 | super().__init__() 101 | self.num_heads = num_heads 102 | head_dim = dim // num_heads 103 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 104 | self.scale = qk_scale or head_dim ** -0.5 105 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 106 | self.attn_drop = nn.Dropout(attn_drop) 107 | self.proj = nn.Linear(dim, dim) 108 | self.proj_drop = nn.Dropout(proj_drop) 109 | 110 | def forward(self, x): 111 | B, N, C = x.shape 112 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 113 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 114 | 115 | attn = (q @ k.transpose(-2, -1)) * self.scale 116 | attn = attn.softmax(dim=-1) 117 | attn = self.attn_drop(attn) 118 | 119 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 120 | x = self.proj(x) 121 | x = self.proj_drop(x) 122 | return x 123 | 124 | 125 | class Block(nn.Module): 126 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 127 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 128 | super().__init__() 129 | self.norm1 = norm_layer(dim) 130 | 131 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 132 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 133 | self.norm2 = norm_layer(dim) 134 | mlp_hidden_dim = int(dim * mlp_ratio) 135 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 136 | 137 | self.attn = Attention( 138 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 139 | 140 | def forward(self, x): 141 | x = x + self.drop_path(self.attn(self.norm1(x))) 142 | x = x + self.drop_path(self.mlp(self.norm2(x))) 143 | return x 144 | 145 | 146 | class TransformerEncoder(nn.Module): 147 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, 148 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.): 149 | super().__init__() 150 | 151 | self.blocks = nn.ModuleList([ 152 | Block( 153 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 154 | drop=drop_rate, attn_drop=attn_drop_rate, 155 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate 156 | ) 157 | for i in range(depth)]) 158 | 159 | def forward(self, x, pos): 160 | for _, block in enumerate(self.blocks): 161 | x = block(x + pos) 162 | return x 163 | 164 | 165 | # Pretrain model 166 | class PointTransformer(nn.Module): 167 | def __init__(self, config, **kwargs): 168 | super().__init__() 169 | self.config = config 170 | # define the transformer argparse 171 | self.trans_dim = config.trans_dim 172 | self.depth = config.depth 173 | self.drop_path_rate = config.drop_path_rate 174 | self.num_heads = config.num_heads 175 | # embedding 176 | self.encoder_dims = config.encoder_dims 177 | self.encoder = Encoder(encoder_channel=self.encoder_dims) 178 | self.mask_type = config.mask_type 179 | self.mask_ratio = config.mask_ratio 180 | 181 | self.pos_embed = nn.Sequential( 182 | nn.Linear(3, 128), 183 | nn.GELU(), 184 | nn.Linear(128, self.trans_dim), 185 | ) 186 | 187 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 188 | self.blocks = TransformerEncoder( 189 | embed_dim=self.trans_dim, 190 | depth=self.depth, 191 | drop_path_rate=dpr, 192 | num_heads=self.num_heads, 193 | ) 194 | 195 | self.norm = nn.LayerNorm(self.trans_dim) 196 | self.apply(self._init_weights) 197 | 198 | def _init_weights(self, m): 199 | if isinstance(m, nn.Linear): 200 | trunc_normal_(m.weight, std=.02) 201 | if isinstance(m, nn.Linear) and m.bias is not None: 202 | nn.init.constant_(m.bias, 0) 203 | elif isinstance(m, nn.LayerNorm): 204 | nn.init.constant_(m.bias, 0) 205 | nn.init.constant_(m.weight, 1.0) 206 | elif isinstance(m, nn.Conv1d): 207 | trunc_normal_(m.weight, std=.02) 208 | if m.bias is not None: 209 | nn.init.constant_(m.bias, 0) 210 | 211 | def _mask_center_block(self, center, noaug=False): 212 | if noaug or self.mask_ratio == 0: 213 | return torch.zeros(center.shape[:2]).bool() 214 | mask_idx = [] 215 | for points in center: 216 | points = points.unsqueeze(0) 217 | index = random.randint(0, points.size(1) - 1) 218 | distance_matrix = torch.norm(points[:, index].reshape(1,1,3) - points, p=2, dim=-1) 219 | idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0] 220 | ratio = self.mask_ratio 221 | mask_num = int(ratio * len(idx)) 222 | mask = torch.zeros(len(idx)) 223 | mask[idx[:mask_num]] = 1 224 | mask_idx.append(mask.bool()) 225 | bool_masked_pos = torch.stack(mask_idx).to(center.device) 226 | return bool_masked_pos 227 | def _mask_center_rand(self, center, noaug=False): 228 | ''' 229 | center : B G 3 230 | -------------- 231 | mask 
: B G (bool) 232 | ''' 233 | B, G, _ = center.shape 234 | # skip the mask 235 | if noaug or self.mask_ratio == 0: 236 | return torch.zeros(center.shape[:2]).bool() 237 | 238 | self.num_mask = int(self.mask_ratio * G) 239 | 240 | overall_mask = np.zeros([B, G]) 241 | for i in range(B): 242 | mask = np.hstack([ 243 | np.zeros(G - self.num_mask), 244 | np.ones(self.num_mask), 245 | ]) 246 | np.random.shuffle(mask) 247 | overall_mask[i, :] = mask 248 | overall_mask = torch.from_numpy(overall_mask).to(torch.bool) 249 | 250 | return overall_mask.to(center.device) # B G 251 | 252 | def forward(self, neighborhood, center, noaug=False): 253 | # generate mask 254 | if self.mask_type == 'rand': 255 | bool_masked_pos = self._mask_center_rand(center, noaug=noaug) # B G 256 | else: 257 | bool_masked_pos = self._mask_center_block(center, noaug=noaug) 258 | B, G, N, C = neighborhood.shape 259 | vis = neighborhood[~bool_masked_pos].reshape(B, -1, N, C) 260 | group_input_tokens = self.encoder(vis) # B G C 261 | 262 | batch_size, seq_len, L = group_input_tokens.size() 263 | vis_center = center[~bool_masked_pos].reshape(B, -1, C) 264 | p = self.pos_embed(vis_center) 265 | 266 | z = self.blocks(group_input_tokens, p) 267 | z = self.norm(z) 268 | return z.reshape(batch_size, -1, L), bool_masked_pos 269 | 270 | class Encoder_Module(nn.Module): 271 | def __init__(self, config): 272 | super().__init__() 273 | self.config = config 274 | self.trans_dim = config.trans_dim 275 | self.AE_encoder = PointTransformer(config) 276 | self.group_size = config.group_size 277 | self.num_group = config.num_group 278 | self.num_output = config.num_output 279 | self.mask_ratio = config.mask_ratio 280 | self.num_channel = 3 281 | self.drop_path_rate = config.drop_path_rate 282 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim)) 283 | 284 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size) 285 | 286 | # prediction head 287 | self.increase_dim = nn.Sequential( 288 | nn.Conv1d(self.trans_dim, (self.num_channel * int(self.num_output * (1 - self.mask_ratio))) // self.num_group, 1) 289 | ) 290 | 291 | trunc_normal_(self.mask_token, std=.02) 292 | self.loss = config.loss 293 | # loss 294 | self.build_loss_func(self.loss) 295 | 296 | def build_loss_func(self, loss_type): 297 | if loss_type == "cdl1": 298 | self.loss_func = chamfer_distance_l1 299 | elif loss_type == 'cdl2': 300 | self.loss_func = chamfer_distance_l2 301 | elif loss_type == 'mse': 302 | self.loss_func = F.mse_loss 303 | else: 304 | raise NotImplementedError 305 | 306 | def forward(self, pts): 307 | x_vis, mask, center, vis_pc, msk_pc = self.encode(pts, False) 308 | 309 | B, _, C = x_vis.shape # B VIS C 310 | 311 | rebuild_points = self.increase_dim(x_vis.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 3 312 | full = torch.cat([rebuild_points, msk_pc], dim=1) 313 | loss1 = self.loss_func(full, pts) 314 | return loss1 315 | 316 | def encode(self, pt, masked=False): 317 | B, _, N = pt.shape 318 | neighborhood, center = self.group_divider(pt) 319 | x_vis, mask = self.AE_encoder(neighborhood, center) 320 | if masked: 321 | return x_vis, mask, center 322 | else: 323 | vis_pc, msk_pc = self.neighborhood(neighborhood, center, mask, x_vis) 324 | return x_vis, mask, center, vis_pc, msk_pc 325 | 326 | def evaluate(self, x_vis, x_msk): 327 | B, _, C = x_vis.shape # B VIS C 328 | x_full = torch.cat([x_vis, x_msk], dim=1) 329 | rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 
3 330 | return rebuild_points.reshape(-1, 3).unsqueeze(0) 331 | 332 | def neighborhood(self, neighborhood, center, mask, x_vis): 333 | B, M, N = x_vis.shape 334 | vis_point = neighborhood[~mask].reshape(B * M, -1, 3) 335 | full_vis = vis_point + center[~mask].unsqueeze(1) 336 | msk_point = neighborhood[mask].reshape(B * (self.num_group - M), -1, 3) 337 | full_msk = msk_point + center[mask].unsqueeze(1) 338 | 339 | full_vis = full_vis.reshape(B, -1, 3) 340 | full_msk = full_msk.reshape(B, -1, 3) 341 | 342 | return full_vis, full_msk 343 | -------------------------------------------------------------------------------- /model/EncoderSR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import DropPath, trunc_normal_ 5 | import numpy as np 6 | from utils import misc 7 | import random 8 | from knn_cuda import KNN 9 | from metrics.evaluation_metrics import chamfer_distance_l1, chamfer_distance_l2 10 | # from model.Encoder_Component import * 11 | 12 | class Encoder(nn.Module): ## Embedding module 13 | def __init__(self, encoder_channel): 14 | super().__init__() 15 | self.encoder_channel = encoder_channel 16 | self.first_conv = nn.Sequential( 17 | nn.Conv1d(3, 128, 1), 18 | nn.BatchNorm1d(128), 19 | nn.ReLU(inplace=True), 20 | nn.Conv1d(128, 256, 1) 21 | ) 22 | self.second_conv = nn.Sequential( 23 | nn.Conv1d(512, 512, 1), 24 | nn.BatchNorm1d(512), 25 | nn.ReLU(inplace=True), 26 | nn.Conv1d(512, self.encoder_channel, 1) 27 | ) 28 | 29 | def forward(self, point_groups): 30 | ''' 31 | point_groups : B G N 3 32 | ----------------- 33 | feature_global : B G C 34 | ''' 35 | bs, g, n, _ = point_groups.shape 36 | point_groups = point_groups.reshape(bs * g, n, 3) 37 | # encoder 38 | feature = self.first_conv(point_groups.transpose(2, 1)) # BG 256 n 39 | feature_global = torch.max(feature, dim=2, keepdim=True)[0] # BG 256 1 40 | feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) # BG 512 n 41 | feature = self.second_conv(feature) # BG 1024 n 42 | feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024 43 | return feature_global.reshape(bs, g, self.encoder_channel) 44 | 45 | 46 | class Group(nn.Module): # FPS + KNN 47 | def __init__(self, num_group, group_size): 48 | super().__init__() 49 | self.num_group = num_group 50 | self.group_size = group_size 51 | self.knn = KNN(k=self.group_size, transpose_mode=True) 52 | self.knnhr = KNN(k=self.group_size*4, transpose_mode=True) 53 | 54 | def forward(self, xyz, hr=None): 55 | ''' 56 | input: B N 3 57 | --------------------------- 58 | output: B G M 3 59 | center : B G 3 60 | ''' 61 | batch_size, num_points, _ = xyz.shape 62 | # fps the centers out 63 | center = misc.fps(xyz, self.num_group) # B G 3 64 | # knn to get the neighborhood 65 | _, idx = self.knn(xyz, center) # B G M 66 | assert idx.size(1) == self.num_group 67 | assert idx.size(2) == self.group_size 68 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points 69 | idx = idx + idx_base 70 | idx = idx.view(-1) 71 | neighborhood = xyz.view(batch_size * num_points, -1)[idx, :] 72 | neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous() 73 | # normalize 74 | neighborhood = neighborhood - center.unsqueeze(2) 75 | if hr is not None: 76 | _, hr_points, _ = hr.shape 77 | _, hidx = self.knnhr(hr, center) 78 | hidx_base = torch.arange(0, batch_size, 
device=xyz.device).view(-1, 1, 1) * hr_points  # offset per-batch indices by the HR point count so they index into the flattened hr tensor correctly
79 |             hidx = hidx + hidx_base
80 |             hidx = hidx.view(-1)
81 |             h_n = hr.view(batch_size * hr_points, -1)[hidx, :]
82 |             h_n = h_n.view(batch_size, self.num_group, self.group_size*4, 3).contiguous()
83 |             h_n = h_n - center.unsqueeze(2)
84 |             return neighborhood, center, h_n
85 |         else:
86 |             return neighborhood, center
87 | 
88 | 
89 | ## Transformers
90 | class Mlp(nn.Module):
91 |     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
92 |         super().__init__()
93 |         out_features = out_features or in_features
94 |         hidden_features = hidden_features or in_features
95 |         self.fc1 = nn.Linear(in_features, hidden_features)
96 |         self.act = act_layer()
97 |         self.fc2 = nn.Linear(hidden_features, out_features)
98 |         self.drop = nn.Dropout(drop)
99 | 
100 |     def forward(self, x):
101 |         x = self.fc1(x)
102 |         x = self.act(x)
103 |         x = self.drop(x)
104 |         x = self.fc2(x)
105 |         x = self.drop(x)
106 |         return x
107 | 
108 | 
109 | class Attention(nn.Module):
110 |     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
111 |         super().__init__()
112 |         self.num_heads = num_heads
113 |         head_dim = dim // num_heads
114 |         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
115 |         self.scale = qk_scale or head_dim ** -0.5
116 |         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
117 |         self.attn_drop = nn.Dropout(attn_drop)
118 |         self.proj = nn.Linear(dim, dim)
119 |         self.proj_drop = nn.Dropout(proj_drop)
120 | 
121 |     def forward(self, x):
122 |         B, N, C = x.shape
123 |         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124 |         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
125 | 
126 |         attn = (q @ k.transpose(-2, -1)) * self.scale
127 |         attn = attn.softmax(dim=-1)
128 |         attn = self.attn_drop(attn)
129 | 
130 |         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
131 |         x = self.proj(x)
132 |         x = self.proj_drop(x)
133 |         return x
134 | 
135 | 
136 | class Block(nn.Module):
137 |     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
138 |                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
139 |         super().__init__()
140 |         self.norm1 = norm_layer(dim)
141 | 
142 |         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
143 |         self.drop_path = DropPath(drop_path) if drop_path > 0.
else nn.Identity() 144 | self.norm2 = norm_layer(dim) 145 | mlp_hidden_dim = int(dim * mlp_ratio) 146 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 147 | 148 | self.attn = Attention( 149 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 150 | 151 | def forward(self, x): 152 | x = x + self.drop_path(self.attn(self.norm1(x))) 153 | x = x + self.drop_path(self.mlp(self.norm2(x))) 154 | return x 155 | 156 | 157 | class TransformerEncoder(nn.Module): 158 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, 159 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.): 160 | super().__init__() 161 | 162 | self.blocks = nn.ModuleList([ 163 | Block( 164 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 165 | drop=drop_rate, attn_drop=attn_drop_rate, 166 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate 167 | ) 168 | for i in range(depth)]) 169 | 170 | def forward(self, x, pos): 171 | for _, block in enumerate(self.blocks): 172 | x = block(x + pos) 173 | return x 174 | 175 | 176 | class PointTransformerDecoder(nn.Module): 177 | def __init__(self, embed_dim=384, depth=4, num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None, 178 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm): 179 | super().__init__() 180 | self.blocks = nn.ModuleList([ 181 | Block( 182 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 183 | drop=drop_rate, attn_drop=attn_drop_rate, 184 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate 185 | ) 186 | for i in range(depth)]) 187 | self.norm = norm_layer(embed_dim) 188 | self.head = nn.Identity() 189 | 190 | self.apply(self._init_weights) 191 | 192 | def _init_weights(self, m): 193 | if isinstance(m, nn.Linear): 194 | nn.init.xavier_uniform_(m.weight) 195 | if isinstance(m, nn.Linear) and m.bias is not None: 196 | nn.init.constant_(m.bias, 0) 197 | elif isinstance(m, nn.LayerNorm): 198 | nn.init.constant_(m.bias, 0) 199 | nn.init.constant_(m.weight, 1.0) 200 | 201 | def forward(self, x, pos, n): 202 | for _, block in enumerate(self.blocks): 203 | x = block(x + pos) 204 | x = self.head(self.norm(x[:, -n:])) 205 | return x 206 | 207 | # 208 | # Pretrain model 209 | class PointTransformer(nn.Module): 210 | def __init__(self, config): 211 | super().__init__() 212 | self.config = config 213 | # define the transformer argparse 214 | self.trans_dim = config.trans_dim 215 | self.depth = config.encoder_depth 216 | self.drop_path_rate = config.drop_path_rate 217 | self.num_heads = config.encoder_num_heads 218 | # embedding 219 | self.encoder_dims = config.trans_dim 220 | self.encoder = Encoder(encoder_channel=self.encoder_dims) 221 | self.mask_type = config.mask_type 222 | self.mask_ratio = config.mask_ratio 223 | 224 | self.pos_embed = nn.Sequential( 225 | nn.Linear(3, 128), 226 | nn.GELU(), 227 | nn.Linear(128, self.trans_dim), 228 | ) 229 | 230 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 231 | self.blocks = TransformerEncoder( 232 | embed_dim=self.trans_dim, 233 | depth=self.depth, 234 | drop_path_rate=dpr, 235 | num_heads=self.num_heads, 236 | ) 237 | 238 | self.norm = nn.LayerNorm(self.trans_dim) 239 | self.apply(self._init_weights) 240 | 241 | def _init_weights(self, m): 242 | if isinstance(m, nn.Linear): 243 
| trunc_normal_(m.weight, std=.02) 244 | if isinstance(m, nn.Linear) and m.bias is not None: 245 | nn.init.constant_(m.bias, 0) 246 | elif isinstance(m, nn.LayerNorm): 247 | nn.init.constant_(m.bias, 0) 248 | nn.init.constant_(m.weight, 1.0) 249 | elif isinstance(m, nn.Conv1d): 250 | trunc_normal_(m.weight, std=.02) 251 | if m.bias is not None: 252 | nn.init.constant_(m.bias, 0) 253 | 254 | def _mask_center_block(self, center, noaug=False): 255 | if noaug or self.mask_ratio == 0: 256 | return torch.zeros(center.shape[:2]).bool() 257 | mask_idx = [] 258 | for points in center: 259 | points = points.unsqueeze(0) 260 | index = random.randint(0, points.size(1) - 1) 261 | distance_matrix = torch.norm(points[:, index].reshape(1,1,3) - points, p=2, dim=-1) 262 | idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0] 263 | ratio = self.mask_ratio 264 | mask_num = int(ratio * len(idx)) 265 | mask = torch.zeros(len(idx)) 266 | mask[idx[:mask_num]] = 1 267 | mask_idx.append(mask.bool()) 268 | bool_masked_pos = torch.stack(mask_idx).to(center.device) 269 | return bool_masked_pos 270 | 271 | def _mask_center_rand(self, center, noaug=False): 272 | ''' 273 | center : B G 3 274 | -------------- 275 | mask : B G (bool) 276 | ''' 277 | B, G, _ = center.shape 278 | # skip the mask 279 | if noaug or self.mask_ratio == 0: 280 | return torch.zeros(center.shape[:2]).bool() 281 | 282 | self.num_mask = int(self.mask_ratio * G) 283 | 284 | overall_mask = np.zeros([B, G]) 285 | for i in range(B): 286 | mask = np.hstack([ 287 | np.zeros(G - self.num_mask), 288 | np.ones(self.num_mask), 289 | ]) 290 | np.random.shuffle(mask) 291 | overall_mask[i, :] = mask 292 | overall_mask = torch.from_numpy(overall_mask).to(torch.bool) 293 | 294 | return overall_mask.to(center.device) # B G 295 | 296 | def forward(self, neighborhood, center, noaug=False): 297 | # generate mask 298 | if self.mask_type == 'rand': 299 | bool_masked_pos = self._mask_center_rand(center, noaug=noaug) # B G 300 | else: 301 | bool_masked_pos = self._mask_center_block(center, noaug=noaug) 302 | 303 | group_input_tokens = self.encoder(neighborhood) # B G C 304 | 305 | batch_size, seq_len, C = group_input_tokens.size() 306 | 307 | p = self.pos_embed(center) 308 | 309 | z = self.blocks(group_input_tokens, p) 310 | z = self.norm(z) 311 | return z[~bool_masked_pos].reshape(batch_size, -1, C), bool_masked_pos, z[bool_masked_pos].reshape(batch_size, -1, C) 312 | 313 | class Encoder_Module(nn.Module): 314 | def __init__(self, config): 315 | super().__init__() 316 | self.config = config 317 | self.trans_dim = config.trans_dim 318 | self.AE_encoder = PointTransformer(config) 319 | self.group_size = config.group_size 320 | self.num_group = config.num_group 321 | self.num_output = config.num_output 322 | self.num_channel = 3 323 | self.drop_path_rate = config.drop_path_rate 324 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim)) 325 | self.decoder_pos_embed = nn.Sequential( 326 | nn.Linear(3, 128), 327 | nn.GELU(), 328 | nn.Linear(128, self.trans_dim) 329 | ) 330 | self.decoder_depth = config.decoder_depth 331 | self.decoder_num_heads = config.decoder_num_heads 332 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.decoder_depth)] 333 | self.AE_decoder = PointTransformerDecoder( 334 | embed_dim=self.trans_dim, 335 | depth=self.decoder_depth, 336 | drop_path_rate=dpr, 337 | num_heads=self.decoder_num_heads, 338 | ) 339 | 340 | self.AE_decoder = None 341 | 342 | 343 | self.group_divider = Group(num_group=self.num_group, 
group_size=self.group_size)
344 | 
345 |         # prediction head
346 |         self.increase_dim = nn.Sequential(
347 |             nn.Conv1d(self.trans_dim, (self.num_channel * self.num_output) // self.num_group, 1)
348 |         )
349 | 
350 |         trunc_normal_(self.mask_token, std=.02)
351 |         self.loss = config.loss
352 |         # loss
353 |         self.build_loss_func(self.loss)
354 | 
355 |     def build_loss_func(self, loss_type):
356 |         if loss_type == "cdl1":
357 |             self.loss_func = chamfer_distance_l1
358 |         elif loss_type == 'cdl2':
359 |             self.loss_func = chamfer_distance_l2
360 |         elif loss_type == 'mse':
361 |             self.loss_func = F.mse_loss
362 |         else:
363 |             raise NotImplementedError
364 | 
365 |     def forward(self, pts, hr_pt):
366 |         x_vis, x_msk, mask, center, _, _ = self.encode(pts)  # no HR grouping needed here; the recovered point groups are unused
367 | 
368 |         B, _, C = x_vis.shape # B VIS C
369 |         x_full = torch.cat([x_vis, x_msk], dim=1)
370 |         rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 3
371 |         loss1 = self.loss_func(rebuild_points, hr_pt)
372 |         return loss1
373 | 
374 |     def encode(self, pt, hr=None, masked=False):
375 |         B, _, N = pt.shape
376 |         if masked:
377 |             if hr is None:
378 |                 neighborhood, center = self.group_divider(pt)
379 |             else:
380 |                 neighborhood, center, h_n = self.group_divider(pt, hr)  # forward hr so the grouper returns the HR neighbourhood as well
381 |             x_vis, mask, x_masked = self.AE_encoder(neighborhood, center)
382 |             return x_vis, mask, center
383 |         else:
384 |             if hr is not None:
385 |                 neighborhood, center, h_n = self.group_divider(pt, hr)
386 |                 x_vis, mask, x_masked = self.AE_encoder(neighborhood, center)
387 |                 vis_pc, msk_pc = self.neighborhood(neighborhood, center, mask, x_vis)
388 |                 vis_hr, msk_hr = self.neighborhood(h_n, center, mask, x_vis)
389 |                 return x_vis, x_masked, mask, center, vis_pc, msk_pc, vis_hr, msk_hr
390 |             else:
391 |                 neighborhood, center = self.group_divider(pt)
392 |                 x_vis, mask, x_masked = self.AE_encoder(neighborhood, center)
393 |                 vis_pc, msk_pc = self.neighborhood(neighborhood, center, mask, x_vis)
394 |                 return x_vis, x_masked, mask, center, vis_pc, msk_pc
395 | 
396 |     def evaluate(self, x_vis, x_msk):  # rebuild the full point cloud from visible and predicted masked tokens
397 |         B, _, C = x_vis.shape # B VIS C
398 |         x_full = torch.cat([x_vis, x_msk], dim=1)
399 |         rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 3
400 |         return rebuild_points.reshape(-1, 3).unsqueeze(0)
401 | 
402 |     def neighborhood(self, neighborhood, center, mask, x_vis):  # recover absolute visible / masked point groups from normalised neighbourhoods
403 |         B, M, N = x_vis.shape
404 |         vis_point = neighborhood[~mask].reshape(B * M, -1, 3)
405 |         full_vis = vis_point + center[~mask].unsqueeze(1)
406 |         msk_point = neighborhood[mask].reshape(B * (self.num_group - M), -1, 3)
407 |         full_msk = msk_point + center[mask].unsqueeze(1)
408 | 
409 |         full_vis = full_vis.reshape(B, -1, 3)
410 |         full_msk = full_msk.reshape(B, -1, 3)
411 | 
412 |         return full_vis, full_msk
413 | 
--------------------------------------------------------------------------------
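A minimal usage sketch for the Encoder_Module defined in /model/EncoderSR.py above. The config field names mirror the attributes read in its __init__; the concrete values, the EasyDict wrapper and the random tensors are illustrative assumptions rather than the repository's actual training configuration, and a CUDA device with the compiled pointnet2_ops and KNN_CUDA extensions is assumed.

import torch
from easydict import EasyDict
from model.EncoderSR import Encoder_Module

# Illustrative configuration only -- the values below are assumptions, not the repository's settings.
config = EasyDict(
    trans_dim=384, encoder_depth=12, encoder_num_heads=6,
    decoder_depth=4, decoder_num_heads=6, drop_path_rate=0.1,
    group_size=32, num_group=64, num_output=8192,
    mask_type='rand', mask_ratio=0.75, loss='cdl2',
)

model = Encoder_Module(config).cuda()
lr_pts = torch.rand(2, 2048, 3).cuda()   # low-resolution input, B N 3
hr_pts = torch.rand(2, 8192, 3).cuda()   # high-resolution target used as the reconstruction ground truth

loss = model(lr_pts, hr_pts)                                         # Chamfer loss between the rebuilt cloud and hr_pts
x_vis, x_msk, mask, center, vis_pc, msk_pc = model.encode(lr_pts)    # latent tokens plus visible / masked point groups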