├── model
│   ├── __init__.py
│   ├── log.txt
│   ├── Diffusion.py
│   ├── Encoder.py
│   ├── Decoder_Component.py
│   ├── DiffusionCop.py
│   ├── EncoderCompress.py
│   ├── DiffusionSR.py
│   ├── DiffusionPretrain.py
│   ├── Encoder_Component.py
│   ├── EncoderCop.py
│   └── EncoderSR.py
├── utils
│   ├── __init__.py
│   ├── misc.py
│   ├── visualization.py
│   └── logger.py
├── metrics
│   ├── __init__.py
│   ├── .gitignore
│   ├── pytorch_structural_losses
│   │   ├── .gitignore
│   │   ├── __init__.py
│   │   ├── src
│   │   │   ├── nndistance.cuh
│   │   │   ├── approxmatch.cuh
│   │   │   ├── utils.hpp
│   │   │   ├── nndistance.cu
│   │   │   ├── structural_loss.cpp
│   │   │   └── approxmatch.cu
│   │   ├── pybind
│   │   │   ├── extern.hpp
│   │   │   └── bind.cpp
│   │   ├── setup.py
│   │   ├── nn_distance.py
│   │   ├── match_cost.py
│   │   └── Makefile
│   └── evaluation_metrics.py
├── assets
│   └── structure.png
├── .gitignore
├── pretrain_model
│   └── .gitignore
├── .idea
│   ├── .gitignore
│   ├── misc.xml
│   ├── vcs.xml
│   ├── other.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── DiffPointMAE.iml
├── dataset
│   └── .gitignore
├── requirements.txt
├── train_encoder.py
├── eval_compression.py
├── eval_cop.py
├── eval_upsampling.py
├── train_decoder.py
├── readme.md
└── eval_diffpmae.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/metrics/.gitignore:
--------------------------------------------------------------------------------
1 | StructuralLosses
2 |
--------------------------------------------------------------------------------
/metrics/pytorch_structural_losses/.gitignore:
--------------------------------------------------------------------------------
1 | PyTorchStructuralLosses.egg-info/
2 |
--------------------------------------------------------------------------------
/assets/structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TyraelDLee/DiffPMAE/HEAD/assets/structure.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | eval_diffpmae_ts300.py
2 | train_decoder_kitti.py
3 | train_encoder_kitti.py
4 | train_encoder_kitti_object.py
5 | /results
--------------------------------------------------------------------------------
/pretrain_model/.gitignore:
--------------------------------------------------------------------------------
1 | /completion
2 | /compress
3 | /encocder_cop_exp
4 | /new_start_eval_2023_10_21_15_37
5 | /pretrain
6 | /sr
--------------------------------------------------------------------------------
/metrics/pytorch_structural_losses/__init__.py:
--------------------------------------------------------------------------------
1 | #import torch
2 |
3 | #from MakePytorchBackend import AddGPU, Foo, ApproxMatch
4 |
5 | #from Add import add_gpu, approx_match
6 |
7 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/dataset/.gitignore:
--------------------------------------------------------------------------------
1 | data_depth_velodyne.zip
2 | data_object_velodyne.zip
3 | data_odometry_calib.zip
4 | data_odometry_labels.zip
5 | /data_odometry_labels
6 | /KITTI
7 | /KITTI_depth
8 | /ModelNet
9 | /PU1K
10 | /SEMANTIC_KITTI_DIR
11 | /ShapeNet55
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy~=1.21.5
2 | matplotlib~=3.1.1
3 | h5py~=2.9.0
4 | yaml~=0.1.7
5 | pyyaml~=5.1.2
6 | easydict~=1.10
7 | open3d~=0.17.0
8 | tqdm~=4.64.1
9 | termcolor~=1.1.0
10 | scipy~=1.4.1
11 | scikit-learn~=0.23.2
12 | setuptools~=41.4.0
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | from pointnet2_ops import pointnet2_utils
2 |
3 |
4 | def fps(data, number):
5 | '''
6 | data B N 3
7 | number int
8 | '''
9 | fps_idx = pointnet2_utils.furthest_point_sample(data, number)
10 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
11 | return fps_data
12 |
--------------------------------------------------------------------------------
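A minimal usage sketch for fps above, assuming the pointnet2_ops CUDA extension is installed and a CUDA device is available; the input is a float tensor of shape B x N x 3 and the result keeps `number` points per cloud:

    import torch
    from utils.misc import fps

    # Illustrative data only; the pointnet2_ops kernels require CUDA tensors.
    pts = torch.rand(4, 2048, 3, device='cuda')
    centers = fps(pts, 64)        # farthest point sampling down to 64 points per cloud
    print(centers.shape)          # torch.Size([4, 64, 3])
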
/metrics/pytorch_structural_losses/src/nndistance.cuh:
--------------------------------------------------------------------------------
1 | void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);
2 | void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);
3 |
--------------------------------------------------------------------------------
/metrics/pytorch_structural_losses/pybind/extern.hpp:
--------------------------------------------------------------------------------
1 | std::vector<at::Tensor> ApproxMatch(at::Tensor in_a, at::Tensor in_b);
2 | at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
3 | std::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
4 | 
5 | std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q);
6 | std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2);
7 |
--------------------------------------------------------------------------------
/metrics/pytorch_structural_losses/pybind/bind.cpp:
--------------------------------------------------------------------------------
1 | #include <pybind11/pybind11.h>
2 | 
3 | #include <torch/extension.h>
4 |
5 | #include "R:\\Documents\\developing\\COMP5702_04\\models\\metrics\\pytorch_structural_losses\\pybind\\extern.hpp"
6 |
7 | namespace py = pybind11;
8 |
9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
10 | m.def("ApproxMatch", &ApproxMatch);
11 | m.def("MatchCost", &MatchCost);
12 | m.def("MatchCostGrad", &MatchCostGrad);
13 | m.def("NNDistance", &NNDistance);
14 | m.def("NNDistanceGrad", &NNDistanceGrad);
15 | }
16 |
--------------------------------------------------------------------------------
/metrics/pytorch_structural_losses/src/approxmatch.cuh:
--------------------------------------------------------------------------------
1 | /*
2 | template <typename Dtype>
3 | void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
4 | cudaStream_t stream);
5 | */
6 | void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream);
7 | void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream);
8 | void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream);
9 |
--------------------------------------------------------------------------------
/.idea/DiffPointMAE.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/metrics/pytorch_structural_losses/src/utils.hpp:
--------------------------------------------------------------------------------
1 | #include <exception>
2 | #include <string>
3 | #include <sstream>
4 |
5 | class Formatter {
6 | public:
7 | Formatter() {}
8 | ~Formatter() {}
9 |
10 |   template <typename Type> Formatter &operator<<(const Type &value) {
11 | stream_ << value;
12 | return *this;
13 | }
14 |
15 | std::string str() const { return stream_.str(); }
16 | operator std::string() const { return stream_.str(); }
17 |
18 | enum ConvertToString { to_str };
19 |
20 | std::string operator>>(ConvertToString) { return stream_.str(); }
21 |
22 | private:
23 | std::stringstream stream_;
24 | Formatter(const Formatter &);
25 | Formatter &operator=(Formatter &);
26 | };
27 |
--------------------------------------------------------------------------------
/model/log.txt:
--------------------------------------------------------------------------------
1 | [2023-10-21 15:37:35,969::INFO] loading dataset
2 | [2023-10-21 15:37:36,432::INFO] Training Stable Diffusion
3 | [2023-10-21 15:37:36,433::INFO] config:
4 | [2023-10-21 15:37:36,433::INFO] Namespace(batch_size=32, beta_1=0.0001, beta_T=0.05, decoder_depth=4, decoder_num_heads=4, decoder_trans_dim=192, depth=12, device='cuda', drop_path_rate=0.1, encoder_dims=384, group_size=32, log=True, loss='cdl2', mask_ratio=0.75, mask_type='rand', num_group=64, num_heads=6, num_output=8192, num_points=2048, num_steps=200, save_dir='./results', sched_mode='linear', trans_dim=384, transformer_dim_forward=128, transformer_drop_out=0.1, val_batch_size=1)
5 | [2023-10-21 15:37:36,433::INFO] dataset loaded
6 | [2023-10-21 15:37:38,342::INFO] [Point_MAE]
7 | [2023-10-21 15:37:38,427::INFO] [Point_MAE] divide point cloud into G64 x S32 points ...
8 |
--------------------------------------------------------------------------------
/metrics/pytorch_structural_losses/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
3 |
4 | # Python interface
5 | setup(
6 | name='PyTorchStructuralLosses',
7 | version='0.1.0',
8 | install_requires=['torch'],
9 | packages=['StructuralLosses'],
10 | package_dir={'StructuralLosses': '.\\'},
11 | ext_modules=[
12 | CUDAExtension(
13 | name='StructuralLossesBackend',
14 | include_dirs=['.\\'],
15 | sources=[
16 | 'pybind/bind.cpp',
17 | ],
18 | libraries=['make_pytorch'],
19 | library_dirs=['objs'],
20 | # extra_compile_args=['-g']
21 | )
22 | ],
23 | cmdclass={'build_ext': BuildExtension},
24 | author='Christopher B. Choy',
25 | author_email='chrischoy@ai.stanford.edu',
26 | description='Tutorial for Pytorch C++ Extension with a Makefile',
27 | keywords='Pytorch C++ Extension',
28 | url='https://github.com/chrischoy/MakePytorchPlusPlus',
29 | zip_safe=False,
30 | )
31 |
--------------------------------------------------------------------------------
/metrics/pytorch_structural_losses/nn_distance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | # from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
4 | from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
5 |
6 | # Inherit from Function
7 | class NNDistanceFunction(Function):
8 | # Note that both forward and backward are @staticmethods
9 | @staticmethod
10 | # bias is an optional argument
11 | def forward(ctx, seta, setb):
12 | #print("Match Cost Forward")
13 | ctx.save_for_backward(seta, setb)
14 | '''
15 | input:
16 | set1 : batch_size * #dataset_points * 3
17 | set2 : batch_size * #query_points * 3
18 | returns:
19 | dist1, idx1, dist2, idx2
20 | '''
21 | dist1, idx1, dist2, idx2 = NNDistance(seta, setb)
22 | ctx.idx1 = idx1
23 | ctx.idx2 = idx2
24 | return dist1, dist2
25 |
26 | # This function has only a single output, so it gets only one gradient
27 | @staticmethod
28 | def backward(ctx, grad_dist1, grad_dist2):
29 | #print("Match Cost Backward")
30 | # This is a pattern that is very convenient - at the top of backward
31 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to
32 | # None. Thanks to the fact that additional trailing Nones are
33 | # ignored, the return statement is simple even when the function has
34 | # optional inputs.
35 | seta, setb = ctx.saved_tensors
36 | idx1 = ctx.idx1
37 | idx2 = ctx.idx2
38 | grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2)
39 | return grada, gradb
40 |
41 | nn_distance = NNDistanceFunction.apply
42 |
43 |
--------------------------------------------------------------------------------
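A minimal usage sketch for nn_distance above, assuming the StructuralLosses CUDA backend has been built (see the Makefile in this package) and both point sets are CUDA float tensors of shape B x N x 3:

    import torch
    from metrics.pytorch_structural_losses.nn_distance import nn_distance

    # Illustrative inputs; the backend kernels run on the GPU.
    a = torch.rand(2, 1024, 3, device='cuda', requires_grad=True)
    b = torch.rand(2, 2048, 3, device='cuda')

    dist_ab, dist_ba = nn_distance(a, b)       # nearest-neighbour distances in both directions
    loss = dist_ab.mean() + dist_ba.mean()     # symmetric Chamfer-style loss
    loss.backward()                            # gradients are supplied by NNDistanceGrad
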
/metrics/pytorch_structural_losses/match_cost.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad
4 |
5 | # Inherit from Function
6 | class MatchCostFunction(Function):
7 | # Note that both forward and backward are @staticmethods
8 | @staticmethod
9 | # bias is an optional argument
10 | def forward(ctx, seta, setb):
11 | #print("Match Cost Forward")
12 | ctx.save_for_backward(seta, setb)
13 | '''
14 | input:
15 | set1 : batch_size * #dataset_points * 3
16 | set2 : batch_size * #query_points * 3
17 | returns:
18 | match : batch_size * #query_points * #dataset_points
19 | '''
20 | match, temp = ApproxMatch(seta, setb)
21 | ctx.match = match
22 | cost = MatchCost(seta, setb, match)
23 | return cost
24 |
25 | """
26 | grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)
27 | return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]
28 | """
29 | # This function has only a single output, so it gets only one gradient
30 | @staticmethod
31 | def backward(ctx, grad_output):
32 | #print("Match Cost Backward")
33 | # This is a pattern that is very convenient - at the top of backward
34 | # unpack saved_tensors and initialize all gradients w.r.t. inputs to
35 | # None. Thanks to the fact that additional trailing Nones are
36 | # ignored, the return statement is simple even when the function has
37 | # optional inputs.
38 | seta, setb = ctx.saved_tensors
39 | #grad_input = grad_weight = grad_bias = None
40 | grada, gradb = MatchCostGrad(seta, setb, ctx.match)
41 | grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)
42 | return grada*grad_output_expand, gradb*grad_output_expand
43 |
44 | match_cost = MatchCostFunction.apply
45 |
46 |
--------------------------------------------------------------------------------
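Similarly, a sketch of match_cost above, which returns one approximate Earth Mover's Distance style cost per batch element; the same compiled backend and GPU tensors are assumed:

    import torch
    from metrics.pytorch_structural_losses.match_cost import match_cost

    pred = torch.rand(2, 2048, 3, device='cuda', requires_grad=True)
    ref = torch.rand(2, 2048, 3, device='cuda')

    cost = match_cost(pred, ref)   # shape [2]: one matching cost per cloud pair
    cost.mean().backward()         # gradients are supplied by MatchCostGrad
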
/model/Diffusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import math
5 |
6 | class VarianceSchedule(nn.Module):
7 |
8 | def __init__(self, num_steps, beta_1, beta_T, mode='linear'):
9 | super().__init__()
10 |         assert mode in ('linear',)
11 | self.num_steps = num_steps
12 | self.beta_1 = beta_1
13 | self.beta_T = beta_T
14 | self.mode = mode
15 |
16 | if mode == 'linear':
17 | betas = torch.linspace(beta_1, beta_T, steps=num_steps)
18 |             # create a 1D tensor of size num_steps, with values evenly spaced from beta_1 to beta_T.
19 |             # beta_1, ..., beta_T are hyper-parameters that control the diffusion rate of the process.
20 |
21 | betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding a 0 at beginning.
22 |
23 | alphas = 1 - betas
24 | log_alphas = torch.log(alphas) # (7)
25 | for i in range(1, log_alphas.size(0)): # 1 to T
26 | log_alphas[i] += log_alphas[i - 1]
27 |             # cumulative sum of the log alphas over all previous steps
28 |             alpha_bars = log_alphas.exp()  # cumulative product of alphas, i.e. alpha_bar_t
29 |
30 | sigmas_flex = torch.sqrt(betas)
31 | sigmas_inflex = torch.zeros_like(sigmas_flex) # a 0 filled tensor with sigmas_flex dimension
32 | for i in range(1, sigmas_flex.size(0)):
33 | sigmas_inflex[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i] # (11)
34 | sigmas_inflex = torch.sqrt(sigmas_inflex)
35 |
36 | self.register_buffer('betas', betas)
37 | self.register_buffer('alphas', alphas)
38 | self.register_buffer('alpha_bars', alpha_bars)
39 | self.register_buffer('sigmas_flex', sigmas_flex)
40 | self.register_buffer('sigmas_inflex', sigmas_inflex)
41 |
42 | def uniform_sample_t(self, batch_size):
43 | ts = np.random.choice(np.arange(1, self.num_steps+1), batch_size)
44 | return ts.tolist()
45 |
46 | def get_sigmas(self, t, flexibility):
47 | assert 0 <= flexibility <= 1
48 | sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility)
49 | return sigmas
50 |
51 |
52 | class TimeEmbedding(nn.Module):
53 | """
54 | Sinusoidal Time Embedding for the diffusion process.
55 |
56 | Input:
57 | Timestep: the current timestep, in range [1, ..., T]
58 | """
59 | def __init__(self, dim):
60 | super().__init__()
61 | self.emb_dim = dim
62 |
63 | def forward(self, ts):
64 | half_dim = self.emb_dim // 2
65 | emb = math.log(10000) / (half_dim - 1)
66 | emb = torch.exp(torch.arange(half_dim, device=ts.device) * -emb)
67 | emb = ts[:, None] * emb[None, :]
68 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
69 | return emb
--------------------------------------------------------------------------------
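A minimal sketch of driving VarianceSchedule and TimeEmbedding together, using the schedule hyper-parameters seen in model/log.txt (num_steps=200, beta_1=1e-4, beta_T=0.05); the batch size and embedding width below are illustrative:

    import torch
    from model.Diffusion import VarianceSchedule, TimeEmbedding

    sched = VarianceSchedule(num_steps=200, beta_1=1e-4, beta_T=0.05, mode='linear')
    t = torch.tensor(sched.uniform_sample_t(batch_size=4))   # random timesteps in [1, 200]

    alpha_bar = sched.alpha_bars[t]               # cumulative products of alphas at step t
    sigma = sched.get_sigmas(t, flexibility=0.0)  # per-step sampling noise level
    emb = TimeEmbedding(dim=384)(t.float())       # sinusoidal embedding, shape [4, 384]
    print(alpha_bar.shape, sigma.shape, emb.shape)
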
/model/Encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 |
3 | from metrics.evaluation_metrics import chamfer_distance_l1, chamfer_distance_l2
4 | from model.Encoder_Component import *
5 |
6 | class Encoder_Module(nn.Module):
7 | def __init__(self, config):
8 | super().__init__()
9 | self.config = config
10 | self.trans_dim = config.trans_dim
11 | self.AE_encoder = PointTransformer(config)
12 | self.group_size = config.group_size
13 | self.num_group = config.num_group
14 | self.num_output = config.num_output
15 | self.num_channel = 3
16 | self.drop_path_rate = config.drop_path_rate
17 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
18 |
19 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
20 |
21 | # prediction head
22 | self.increase_dim = nn.Sequential(
23 | nn.Conv1d(self.trans_dim, (self.num_channel * self.num_output) // self.num_group, 1)
24 | )
25 |
26 | trunc_normal_(self.mask_token, std=.02)
27 | self.loss = config.loss
28 | # loss
29 | self.build_loss_func(self.loss)
30 |
31 | def build_loss_func(self, loss_type):
32 | if loss_type == "cdl1":
33 | self.loss_func = chamfer_distance_l1
34 | elif loss_type == 'cdl2':
35 | self.loss_func = chamfer_distance_l2
36 | elif loss_type == 'mse':
37 | self.loss_func = F.mse_loss
38 | else:
39 | raise NotImplementedError
40 |
41 | def forward(self, pts, hr_pt):
42 | x_vis, x_msk, mask, center, vis_pc, msk_pc = self.encode(pts, False)
43 |
44 | B, _, C = x_vis.shape # B VIS C
45 | x_full = torch.cat([x_vis, x_msk], dim=1)
46 | rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 3
47 | loss1 = self.loss_func(rebuild_points, hr_pt)
48 | return loss1
49 |
50 | def encode(self, pt, masked=False):
51 | B, _, N = pt.shape
52 | neighborhood, center = self.group_divider(pt)
53 | x_vis, mask, x_masked = self.AE_encoder(neighborhood, center)
54 | if masked:
55 | return x_vis, mask, center
56 | else:
57 | vis_pc, msk_pc = self.neighborhood(neighborhood, center, mask, x_vis)
58 | return x_vis, x_masked, mask, center, vis_pc, msk_pc
59 |
60 | def evaluate(self, x_vis, x_msk):
61 | B, _, C = x_vis.shape # B VIS C
62 | x_full = torch.cat([x_vis, x_msk], dim=1)
63 | rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 3
64 | return rebuild_points.reshape(-1, 3).unsqueeze(0)
65 |
66 | def neighborhood(self, neighborhood, center, mask, x_vis):
67 | B, M, N = x_vis.shape
68 | vis_point = neighborhood[~mask].reshape(B * M, -1, 3)
69 | full_vis = vis_point + center[~mask].unsqueeze(1)
70 | msk_point = neighborhood[mask].reshape(B * (self.num_group - M), -1, 3)
71 | full_msk = msk_point + center[mask].unsqueeze(1)
72 |
73 | full_vis = full_vis.reshape(B, -1, 3)
74 | full_msk = full_msk.reshape(B, -1, 3)
75 |
76 | return full_vis, full_msk
77 |
--------------------------------------------------------------------------------
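A rough usage sketch for Encoder_Module, wired with the defaults used by train_encoder.py; it depends on model/Encoder_Component.py (PointTransformer, Group) and on the compiled Chamfer distance in metrics.evaluation_metrics, so it illustrates the call pattern rather than being a standalone snippet:

    import argparse
    import torch
    from model.Encoder import Encoder_Module

    config = argparse.Namespace(
        trans_dim=384, group_size=32, num_group=64, num_points=2048, num_output=8192,
        encoder_depth=12, encoder_num_heads=6, encoder_dims=384,
        mask_type='rand', mask_ratio=0.75, drop_path_rate=0.1, loss='cdl2',
    )
    model = Encoder_Module(config).cuda()

    lr_pts = torch.rand(2, 2048, 3, device='cuda')   # low-resolution inputs (as batch['lr'])
    hr_pts = torch.rand(2, 8192, 3, device='cuda')   # high-resolution targets (as batch['hr'])
    loss = model(lr_pts, hr_pts)                     # Chamfer-L2 between rebuilt points and hr_pts
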
/metrics/pytorch_structural_losses/Makefile:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Uncomment for debugging
3 | # DEBUG := 1
4 | # Pretty build
5 | # Q ?= @
6 |
7 | CXX := g++
8 | PYTHON := python
9 | NVCC := /usr/local/cuda/bin/nvcc
10 |
11 | # PYTHON Header path
12 | PYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())')
13 | PYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]')
14 | PYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]')
15 |
16 | # CUDA ROOT DIR that contains bin/ lib64/ and include/
17 | # CUDA_DIR := /usr/local/cuda
18 | CUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())')
19 |
20 | INCLUDE_DIRS := ./ $(CUDA_DIR)/include
21 |
22 | INCLUDE_DIRS += $(PYTHON_HEADER_DIR)
23 | INCLUDE_DIRS += $(PYTORCH_INCLUDES)
24 |
25 | # Custom (MKL/ATLAS/OpenBLAS) include and lib directories.
26 | # Leave commented to accept the defaults for your choice of BLAS
27 | # (which should work)!
28 | # BLAS_INCLUDE := /path/to/your/blas
29 | # BLAS_LIB := /path/to/your/blas
30 |
31 | ###############################################################################
32 | SRC_DIR := ./src
33 | OBJ_DIR := ./objs
34 | CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp)
35 | CU_SRCS := $(wildcard $(SRC_DIR)/*.cu)
36 | OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS))
37 | CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS))
38 | STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a
39 |
40 | # CUDA architecture setting: going with all of them.
41 | # For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility.
42 | # For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility.
43 | CUDA_ARCH := -gencode arch=compute_61,code=sm_61 \
44 | -gencode arch=compute_61,code=compute_61 \
45 | -gencode arch=compute_52,code=sm_52
46 |
47 | # We will also explicitly add stdc++ to the link target.
48 | LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu
49 |
50 | # Debugging
51 | ifeq ($(DEBUG), 1)
52 | COMMON_FLAGS += -DDEBUG -g -O0
53 | # https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/
54 | NVCCFLAGS += -g -G # -rdc true
55 | else
56 | COMMON_FLAGS += -DNDEBUG -O3
57 | endif
58 |
59 | WARNINGS := -Wall -Wno-sign-compare -Wcomment
60 |
61 | INCLUDE_DIRS += $(BLAS_INCLUDE)
62 |
63 | # Automatic dependency generation (nvcc is handled separately)
64 | CXXFLAGS += -MMD -MP
65 |
66 | # Complete build flags.
67 | COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \
68 | -DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0
69 | CXXFLAGS += -pthread -fPIC -fwrapv -std=c++11 $(COMMON_FLAGS) $(WARNINGS)
70 | NVCCFLAGS += -std=c++11 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
71 |
72 | all: $(STATIC_LIB)
73 | $(PYTHON) setup.py build
74 | @ mv build/lib.linux-x86_64-3.6/StructuralLosses ..
75 | @ mv build/lib.linux-x86_64-3.6/*.so ../StructuralLosses/
76 | @- $(RM) -rf $(OBJ_DIR) build objs
77 |
78 | $(OBJ_DIR):
79 | @ mkdir -p $@
80 | @ mkdir -p $@/cuda
81 |
82 | $(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR)
83 | @ echo CXX $<
84 | $(Q)$(CXX) $< $(CXXFLAGS) -c -o $@
85 |
86 | $(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR)
87 | @ echo NVCC $<
88 | $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \
89 | -odir $(@D)
90 | $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@
91 |
92 | $(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR)
93 | $(RM) -f $(STATIC_LIB)
94 | $(RM) -rf build dist
95 | @ echo LD -o $@
96 | ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS)
97 |
98 | clean:
99 | @- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses
100 |
101 |
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from utils.dataset import Semantic_KITTI, KITTI_Object
4 |
5 | PINK = [0xdd / 255, 0x83 / 255, 0xa2 / 255]
6 | BLUE = [0x81 / 255, 0xb0 / 255, 0xbf / 255]
7 | YELLOW = [0xf3 / 255, 0xdb / 255, 0x74 / 255]
8 | CYAN = [0x2b / 255, 0xda / 255, 0xc0 / 255]
9 |
10 | def plot_point_cloud(pc):
11 | fig = plt.figure(figsize=(8, 8))
12 | ax = fig.add_subplot(111, projection='3d')
13 | ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2])
14 | # ax.set_xlim3d(-1, 1)
15 | # ax.set_ylim3d(-1, 1)
16 | # ax.set_zlim3d(-1, 1)
17 | return fig.show()
18 |
19 |
20 | def show_model(pc, index=0, save_img=False, show_axis=True):
21 | """
22 | Load the generated result from npy file and show it in plots.
23 | Args:
24 |         pc: list of npy files containing point cloud results (each npy file holds multiple point cloud models);
25 |             for the AE pass 2 files (output and reference), for GEN pass 1 file (output only).
26 |         index: which point cloud in each file to plot.
27 | """
28 | fig = plt.figure(figsize=(8, 8))
29 | for n_pc in range(0, len(pc)):
30 | ax = fig.add_subplot(1, len(pc), n_pc+1, projection='3d')
31 | models = np.load(pc[n_pc])
32 | print('Input file contains ' + str(len(models)) + ' models')
33 | print(len(models[0]))
34 | print(len(models))
35 | # print(len(models[1]))
36 | ax.scatter(models[index][:, 0], models[index][:, 1], models[index][:, 2])
37 | if n_pc == 0:
38 | ax.set_title('Output', fontsize=14)
39 | else:
40 | ax.set_title('Reference', fontsize=14)
41 | if not show_axis:
42 | ax.axis('off')
43 | if save_img:
44 | fig.savefig('output.png', dpi=300)
45 | fig.show()
46 |
47 | def eval_model_normal(x, t, index=0, show_img=True, save_img=True, show_axis=True, save_location="", second_part=None, color=BLUE, color2=BLUE):
48 | """
49 |     Plot one point cloud (optionally overlaid with a second set) and show/save the figure.
50 |     Args:
51 |         x: array of point clouds; the cloud at position `index` is plotted.
52 |         t: identifier used in the saved file name (e.g. the diffusion timestep).
53 |         index: which point cloud in `x` to plot.
54 | """
55 | fig = plt.figure(figsize=(8, 8))
56 |
57 | ax = fig.add_subplot(1, 1, 1, projection='3d')
58 | models = x
59 | # ax.set_title('x_t', fontsize=14)
60 | ax.scatter(models[index][:, 0], models[index][:, 1], models[index][:, 2], color=color, marker='o', alpha=1.0)
61 | if second_part is not None:
62 | ax.scatter(second_part[index][:, 0], second_part[index][:, 1], second_part[index][:, 2], color=color2, marker='o', alpha=1.0)
63 | if not show_axis:
64 | ax.axis('off')
65 | if save_img:
66 | fig.savefig(save_location+'/{i}.png'.format(i=t), dpi=100)
67 | if show_img:
68 | fig.show()
69 |
70 | def show_forward_diff(pc):
71 | fig = plt.figure(figsize=(20,8))
72 |     models = np.load(pc)                 # load the whole trajectory once (pc is a single npy path)
73 |     print(len(models))
74 |     for i in range(0, len(models)):
75 |         if i % 10 == 0:                  # plot every 10th diffusion step
76 |             ax = fig.add_subplot(1, 55, i // 10 + 1, projection='3d')
77 |             ax.scatter(models[i][:, 0], models[i][:, 1], models[i][:, 2])
78 | fig.show()
79 |
80 | if __name__ == '__main__':
81 | # test_KITTI = Semantic_KITTI(
82 | # datapath='../dataset/SEMANTIC_KITTI_DIR',
83 | # subset='test',
84 | # norm=True,
85 | # npoints=2048
86 | # )
87 | # print(test_KITTI.__len__())
88 | # plot_point_cloud(test_KITTI.__getitem__(4)['coord'])
89 | # show_model(['../results/pretrain_decoder_2024_02_28_17_43/out.npy'], 4)
90 |
91 | test_KITTI = KITTI_Object(
92 | datapath='../dataset/KITTI',
93 | subset='test',
94 | norm=True,
95 |
96 | )
97 | print(test_KITTI.__len__())
98 | plot_point_cloud(test_KITTI.__getitem__(90)['hr'])
99 | # show_model(['./results/new_start_2023_10_09_15_17/out.npy'], 21)
100 | # show_forward_diff('./results/new_stable_diffusion_2023_08_10_21_38/out.npy')
101 |
102 | # show_model(['./results/Maksed_noise_70/out.npy',
103 | # './results/baseline/ref.npy'], 100)
--------------------------------------------------------------------------------
/model/Decoder_Component.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from timm.models.layers import DropPath
3 |
4 | # Ref: https://github.com/Pang-Yatian/Point-MAE/blob/main/models/Point_MAE.py#L82
5 | class Mlp(nn.Module):
6 | """
7 | The feedforward in Transformer blockers.
8 | """
9 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
10 | super().__init__()
11 | out_features = out_features or in_features
12 | hidden_features = hidden_features or in_features
13 | self.fc1 = nn.Linear(in_features, hidden_features)
14 | self.act = act_layer()
15 | self.fc2 = nn.Linear(hidden_features, out_features)
16 | self.drop = nn.Dropout(drop)
17 |
18 | def forward(self, x):
19 | x = self.fc1(x)
20 | x = self.act(x)
21 | x = self.drop(x)
22 | x = self.fc2(x)
23 | x = self.drop(x)
24 | return x
25 |
26 | class Attention(nn.Module):
27 | """
28 | The multi-head attention in Transformer blockers.
29 | """
30 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
31 | super().__init__()
32 | self.num_heads = num_heads
33 | head_dim = dim // num_heads
34 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
35 | self.scale = qk_scale or head_dim ** -0.5
36 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
37 | self.attn_drop = nn.Dropout(attn_drop)
38 | self.proj = nn.Linear(dim, dim)
39 | self.proj_drop = nn.Dropout(proj_drop)
40 |
41 | def forward(self, x):
42 | B, N, C = x.shape
43 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
44 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
45 |
46 | attn = (q @ k.transpose(-2, -1)) * self.scale
47 | attn = attn.softmax(dim=-1)
48 | attn = self.attn_drop(attn)
49 |
50 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
51 | x = self.proj(x)
52 | x = self.proj_drop(x)
53 | return x
54 |
55 |
56 | class Block(nn.Module):
57 | """
58 | The Transformer Block in the Decoder module.
59 | """
60 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
61 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
62 | super().__init__()
63 | self.norm1 = norm_layer(dim)
64 |
65 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
66 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
67 | self.norm2 = norm_layer(dim)
68 | mlp_hidden_dim = int(dim * mlp_ratio)
69 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
70 |
71 | self.attn = Attention(
72 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
73 |
74 | def forward(self, x):
75 | x = x + self.drop_path(self.attn(self.norm1(x)))
76 | x = x + self.drop_path(self.mlp(self.norm2(x)))
77 | return x
78 |
79 |
80 | class Transformer(nn.Module):
81 | """
82 | The Transformer for the Decoder module.
83 |
84 | Inputs:
85 | Latent: [B, V+M, L]
86 | Position Embedding: [B, V+M, L]
87 | Time Embedding: [B, V+M, L]
88 |
89 | B: Batch size,
90 | V: Visible patches size
91 | M: Masked patches size
92 | L: Latent size
93 | """
94 | def __init__(self, embed_dim=384, depth=4, num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None,
95 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm):
96 | super().__init__()
97 | self.blocks = nn.ModuleList([
98 | Block(
99 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
100 | drop=drop_rate, attn_drop=attn_drop_rate,
101 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate
102 | )
103 | for i in range(depth)])
104 | self.norm = norm_layer(embed_dim)
105 | self.head = nn.Identity()
106 |
107 | self.apply(self._init_weights)
108 |
109 | def _init_weights(self, m):
110 | if isinstance(m, nn.Linear):
111 | nn.init.xavier_uniform_(m.weight)
112 | if isinstance(m, nn.Linear) and m.bias is not None:
113 | nn.init.constant_(m.bias, 0)
114 | elif isinstance(m, nn.LayerNorm):
115 | nn.init.constant_(m.bias, 0)
116 | nn.init.constant_(m.weight, 1.0)
117 |
118 | def forward(self, x, pos, num_of_group, ts):
119 | for _, block in enumerate(self.blocks):
120 | x = block(x + pos + ts)
121 | x = self.head(self.norm(x[:, -num_of_group:]))
122 | return x
123 |
--------------------------------------------------------------------------------
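A self-contained sketch of the decoder Transformer above: it takes one token per patch plus positional and timestep embeddings of the same width and returns only the last num_of_group (masked) tokens; the sizes below are illustrative:

    import torch
    from model.Decoder_Component import Transformer

    B, V, M, L = 2, 16, 48, 384            # batch, visible patches, masked patches, latent width
    decoder = Transformer(embed_dim=L, depth=4, num_heads=6)

    x = torch.rand(B, V + M, L)            # concatenated visible + masked tokens
    pos = torch.rand(B, V + M, L)          # positional embedding
    ts = torch.rand(B, V + M, L)           # broadcast time embedding
    out = decoder(x, pos, num_of_group=M, ts=ts)
    print(out.shape)                       # torch.Size([2, 48, 384]); only the masked tokens remain
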
/metrics/pytorch_structural_losses/src/nndistance.cu:
--------------------------------------------------------------------------------
1 |
2 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
3 | const int batch=512;
4 | __shared__ float buf[batch*3];
5 | for (int i=blockIdx.x;ibest){
117 | result[(i*n+j)]=best;
118 | result_i[(i*n+j)]=best_i;
119 | }
120 | }
121 | __syncthreads();
122 | }
123 | }
124 | }
125 | void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
126 | NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i);
127 | NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i);
128 | }
129 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
130 | for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2);
153 | NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1);
154 | }
155 |
156 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.distributed as dist
3 |
4 | logger_initialized = {}
5 |
6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
7 | """Get root logger and add a keyword filter to it.
8 | The logger will be initialized if it has not been initialized. By default a
9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
10 | also be added. The name of the root logger is the top-level package name,
11 | e.g., "mmdet3d".
12 | Args:
13 | log_file (str, optional): File path of log. Defaults to None.
14 | log_level (int, optional): The level of logger.
15 | Defaults to logging.INFO.
16 |         name (str, optional): The name of the root logger, also used as a
17 |             filter keyword. Defaults to 'main'.
18 | Returns:
19 | :obj:`logging.Logger`: The obtained logger
20 | """
21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level)
22 | # add a logging filter
23 | logging_filter = logging.Filter(name)
24 | logging_filter.filter = lambda record: record.find(name) != -1
25 |
26 | return logger
27 |
28 |
29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
30 | """Initialize and get a logger by name.
31 | If the logger has not been initialized, this method will initialize the
32 | logger by adding one or two handlers, otherwise the initialized logger will
33 | be directly returned. During initialization, a StreamHandler will always be
34 | added. If `log_file` is specified and the process rank is 0, a FileHandler
35 | will also be added.
36 | Args:
37 | name (str): Logger name.
38 | log_file (str | None): The log filename. If specified, a FileHandler
39 | will be added to the logger.
40 | log_level (int): The logger level. Note that only the process of
41 | rank 0 is affected, and other processes will set the level to
42 | "Error" thus be silent most of the time.
43 | file_mode (str): The file mode used in opening log file.
44 | Defaults to 'w'.
45 | Returns:
46 | logging.Logger: The expected logger.
47 | """
48 | logger = logging.getLogger(name)
49 | if name in logger_initialized:
50 | return logger
51 | # handle hierarchical names
52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the
53 | # initialization since it is a child of "a".
54 | for logger_name in logger_initialized:
55 | if name.startswith(logger_name):
56 | return logger
57 |
58 | # handle duplicate logs to the console
59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
60 | # to the root logger. As logger.propagate is True by default, this root
61 | # level handler causes logging messages from rank>0 processes to
62 | # unexpectedly show up on the console, creating much unwanted clutter.
63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log
64 | # at the ERROR level.
65 | for handler in logger.root.handlers:
66 | if type(handler) is logging.StreamHandler:
67 | handler.setLevel(logging.ERROR)
68 |
69 | stream_handler = logging.StreamHandler()
70 | handlers = [stream_handler]
71 |
72 | if dist.is_available() and dist.is_initialized():
73 | rank = dist.get_rank()
74 | else:
75 | rank = 0
76 |
77 | # only rank 0 will add a FileHandler
78 | if rank == 0 and log_file is not None:
79 | # Here, the default behaviour of the official logger is 'a'. Thus, we
80 | # provide an interface to change the file mode to the default
81 | # behaviour.
82 | file_handler = logging.FileHandler(log_file, file_mode)
83 | handlers.append(file_handler)
84 |
85 | formatter = logging.Formatter(
86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
87 | for handler in handlers:
88 | handler.setFormatter(formatter)
89 | handler.setLevel(log_level)
90 | logger.addHandler(handler)
91 |
92 | if rank == 0:
93 | logger.setLevel(log_level)
94 | else:
95 | logger.setLevel(logging.ERROR)
96 |
97 | logger_initialized[name] = True
98 |
99 |
100 | return logger
101 |
102 |
103 | def print_log(msg, logger=None, level=logging.INFO):
104 | """Print a log message.
105 | Args:
106 | msg (str): The message to be logged.
107 | logger (logging.Logger | str | None): The logger to be used.
108 | Some special loggers are:
109 | - "silent": no message will be printed.
110 | - other str: the logger obtained with `get_root_logger(logger)`.
111 | - None: The `print()` method will be used to print log messages.
112 | level (int): Logging level. Only available when `logger` is a Logger
113 | object or "root".
114 | """
115 | if logger is None:
116 | print(msg)
117 | elif isinstance(logger, logging.Logger):
118 | logger.log(level, msg)
119 | elif logger == 'silent':
120 | pass
121 | elif isinstance(logger, str):
122 | _logger = get_logger(logger)
123 | _logger.log(level, msg)
124 | else:
125 | raise TypeError(
126 | 'logger should be either a logging.Logger object, str, '
127 | f'"silent" or None, but got {type(logger)}')
--------------------------------------------------------------------------------
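A minimal sketch of the helpers above; the logger name is illustrative, and passing a path via log_file would additionally attach a FileHandler on rank 0:

    from utils.logger import get_root_logger, print_log

    logger = get_root_logger(log_file=None, name='diffpmae')   # console-only logger
    print_log('dataset loaded', logger=logger)                 # routed through the named logger
    print_log('plain message', logger=None)                    # falls back to print()
    print_log('suppressed', logger='silent')                   # dropped entirely
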
/train_encoder.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 |
4 | from torch import optim
5 | from torch.utils.data import DataLoader
6 | from utils.dataset import *
7 | from utils.logger import *
8 | from model.Encoder import *
9 |
10 |
11 | def get_data_iterator(iterable):
12 | """Allows training with DataLoaders in a single infinite loop:
13 | for i, data in enumerate(inf_generator(train_loader)):
14 | """
15 | iterator = iterable.__iter__()
16 | while True:
17 | try:
18 | yield iterator.__next__()
19 | except StopIteration:
20 | iterator = iterable.__iter__()
21 |
22 |
23 | parser = argparse.ArgumentParser()
24 | # Experiment setting
25 | parser.add_argument('--batch_size', type=int, default=4)
26 | parser.add_argument('--val_batch_size', type=int, default=1)
27 | parser.add_argument('--device', type=str, default='cuda') # mps for mac
28 | parser.add_argument('--save_dir', type=str, default='./results')
29 | parser.add_argument('--log', type=bool, default=False)
30 |
31 | # Grouping setting
32 | parser.add_argument('--mask_type', type=str, default='rand')
33 | parser.add_argument('--mask_ratio', type=float, default=0.75)
34 | parser.add_argument('--group_size', type=int, default=32)
35 | parser.add_argument('--num_group', type=int, default=64)
36 | parser.add_argument('--num_points', type=int, default=2048)
37 | parser.add_argument('--num_output', type=int, default=8192)
38 |
39 | # Transformer setting
40 | parser.add_argument('--trans_dim', type=int, default=384)
41 | parser.add_argument('--drop_path_rate', type=float, default=0.1)
42 |
43 | # Encoder setting
44 | parser.add_argument('--encoder_depth', type=int, default=12)
45 | parser.add_argument('--encoder_num_heads', type=int, default=6)
46 | parser.add_argument('--encoder_dims', type=int, default=384)
47 | parser.add_argument('--loss', type=str, default='cdl2')
48 |
49 | # sche / optim
50 | parser.add_argument('--learning_rate', type=float, default=0.001)
51 | parser.add_argument('--weight_decay', type=float, default=0.05)
52 | parser.add_argument('--eta_min', type=float, default=0.000001)
53 | parser.add_argument('--t_max', type=float, default=200)
54 |
55 | args = parser.parse_args()
56 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M")
57 | save_dir = os.path.join(args.save_dir, 'encocder_8192_width_{gSize}'.format(gSize=args.encoder_dims))
58 |
59 | if args.log:
60 | if not os.path.exists(save_dir):
61 | os.makedirs(save_dir)
62 | log = logging.getLogger()
63 | log.setLevel(logging.INFO)
64 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s')
65 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
66 | log_file.setLevel(logging.INFO)
67 | log_file.setFormatter(formatter)
68 | log.addHandler(log_file)
69 |
70 |
71 |
72 | if args.log:
73 | log.info('loading dataset')
74 | print('loading dataset')
75 |
76 | train_dset = ShapeNet(
77 | data_path='dataset/ShapeNet55/ShapeNet-55',
78 | pc_path='dataset/ShapeNet55/shapenet_pc',
79 | subset='train',
80 | n_points=2048,
81 | downsample=True
82 | )
83 | val_dset = ShapeNet(
84 | data_path='dataset/ShapeNet55/ShapeNet-55',
85 | pc_path='dataset/ShapeNet55/shapenet_pc',
86 | subset='test',
87 | n_points=2048,
88 | downsample=True
89 | )
90 |
91 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True)
92 | trn_loader = DataLoader(train_dset, batch_size=args.batch_size, pin_memory=True)
93 | if args.log:
94 |     log.info('Training encoder.')
95 | log.info('config:')
96 | log.info(args)
97 | log.info('dataset loaded')
98 | print('dataset loaded')
99 |
100 |
101 | model = Encoder_Module(args).to(args.device)
102 |
103 |
104 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
105 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=args.eta_min, T_max=args.t_max)
106 |
107 |
108 | def train(i, batch, epoch):
109 | if i == 0 and epoch > 1:
110 | print('Saving checkpoint at epoch {epoch}'.format(epoch=epoch))
111 | save_model = {'args': args, 'model': model.state_dict()}
112 | if args.log:
113 | torch.save(save_model, os.path.join(save_dir, 'autoencoder_diff.pt'))
114 |
115 | x = batch['lr'].to(args.device)
116 | optimizer.zero_grad()
117 | model.train()
118 | loss = model(x, batch['hr'].to(args.device))
119 | loss.backward()
120 | optimizer.step()
121 | if args.log and i == 0:
122 | log.info('epoch: {epoch}, iteration: {i}, loss: {loss}'.format(i=i, epoch=epoch, loss=loss))
123 | print('epoch: {epoch}, iteration: {i}, loss: {loss}'.format(i=i, epoch=epoch, loss=loss))
124 |
125 |
126 | def validate():
127 | all_recons = []
128 | for i, batch in enumerate(val_loader):
129 | print('sampling model {i}'.format(i=i))
130 | ref = batch['lr'].to(args.device)
131 | if i > 200:
132 | break
133 | with torch.no_grad():
134 | model.eval()
135 | x_vis, x_masked, mask, center, vis_pc, msk_pc = model.encode(ref, masked=False)
136 | recons = model.evaluate(x_vis, x_masked)
137 | all_recons.append(recons)
138 | all_recons = torch.cat(all_recons, dim=0)
139 | np.save(os.path.join(save_dir, 'out.npy'), all_recons.cpu().numpy())
140 |
141 |
142 | try:
143 | n_it = 50
144 | epoch = 1
145 | while epoch <= n_it:
146 | model.train()
147 | for i, pc in enumerate(trn_loader):
148 | train(i, pc, epoch)
149 | scheduler.step()
150 |
151 | if epoch == n_it:
152 | if args.log:
153 | saved_file = {'args': args, 'model': model.state_dict()}
154 | torch.save(saved_file, os.path.join(save_dir, 'autoencoder_diff.pt'))
155 | validate()
156 | epoch += 1
157 | except Exception as e:
158 |     print(e)
159 |     if args.log:
160 |         log.error(e)
--------------------------------------------------------------------------------
/eval_compression.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 | from torch.utils.data import DataLoader
4 | from utils.dataset import *
5 | from utils.logger import *
6 | from model.EncoderCompress import *
7 | from model.DiffusionCop import *
8 |
9 | def get_data_iterator(iterable):
10 | """Allows training with DataLoaders in a single infinite loop:
11 | for i, data in enumerate(inf_generator(train_loader)):
12 | """
13 | iterator = iterable.__iter__()
14 | while True:
15 | try:
16 | yield iterator.__next__()
17 | except StopIteration:
18 | iterator = iterable.__iter__()
19 |
20 | parser = argparse.ArgumentParser()
21 | # Experiment setting
22 | parser.add_argument('--val_batch_size', type=int, default=1)
23 | parser.add_argument('--device', type=str, default='cuda') # mps for mac
24 | parser.add_argument('--save_dir', type=str, default='./results')
25 | parser.add_argument('--log', type=bool, default=True)
26 |
27 | # Grouping setting
28 | parser.add_argument('--mask_type', type=str, default='rand')
29 | parser.add_argument('--mask_ratio', type=float, default=0.75)
30 | parser.add_argument('--group_size', type=int, default=32) # points in each group
31 | parser.add_argument('--num_group', type=int, default=64) # number of group
32 | parser.add_argument('--num_points', type=int, default=2048)
33 | parser.add_argument('--num_output', type=int, default=8192)
34 |
35 | # Transformer setting
36 | parser.add_argument('--trans_dim', type=int, default=384)
37 | parser.add_argument('--drop_path_rate', type=float, default=0.1)
38 |
39 | # Encoder setting
40 | parser.add_argument('--encoder_depth', type=int, default=12)
41 | parser.add_argument('--encoder_num_heads', type=int, default=6)
42 | parser.add_argument('--loss', type=str, default='cdl2')
43 |
44 | # Decoder setting
45 | parser.add_argument('--decoder_depth', type=int, default=4)
46 | parser.add_argument('--decoder_num_heads', type=int, default=4)
47 |
48 | # diffusion
49 | parser.add_argument('--num_steps', type=int, default=200)
50 | parser.add_argument('--beta_1', type=float, default=1e-4)
51 | parser.add_argument('--beta_T', type=float, default=0.05)
52 | parser.add_argument('--sched_mode', type=str, default='linear')
53 |
54 |
55 | args = parser.parse_args()
56 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M")
57 | save_dir = os.path.join(args.save_dir, 'eval_new_pe_{date}'.format(date=time_now))
58 |
59 |
60 | check_point_dir = os.path.join('./pretrain_model/compress/encoder.pt')
61 | check_point = torch.load(check_point_dir)['model']
62 | encoder = Encoder_Module(args).to(args.device)
63 | encoder.load_state_dict(check_point)
64 |
65 |
66 | if args.log:
67 | if not os.path.exists(save_dir):
68 | os.makedirs(save_dir)
69 | log = logging.getLogger()
70 | log.setLevel(logging.INFO)
71 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s')
72 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
73 | log_file.setLevel(logging.INFO)
74 | log_file.setFormatter(formatter)
75 | log.addHandler(log_file)
76 |
77 | if args.log:
78 | log.info('loading dataset')
79 | print('loading dataset')
80 |
81 |
82 | test_dset_MN = ModelNet(
83 | root='dataset/ModelNet',
84 | number_pts=8192,
85 | downsampling=2048,
86 | use_normal=False,
87 | cats=40,
88 | subset='test'
89 | )
90 | val_loader_MN = DataLoader(test_dset_MN, batch_size=args.val_batch_size, pin_memory=True, shuffle=True)
91 |
92 | val_dset = ShapeNet(
93 | data_path='dataset/ShapeNet55/ShapeNet-55',
94 | pc_path='dataset/ShapeNet55/shapenet_pc',
95 | subset='test',
96 | n_points=2048,
97 | downsample=True
98 | )
99 |
100 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True)
101 | if args.log:
102 |     log.info('Evaluating compression')
103 | log.info('config:')
104 | log.info(args)
105 | log.info('dataset loaded')
106 | print('dataset loaded')
107 |
108 | print('loading model')
109 |
110 | PINK = [0xdd / 255, 0x83 / 255, 0xa2 / 255]
111 | BLUE = [0x81 / 255, 0xb0 / 255, 0xbf / 255]
112 | YELLOW = [0xf3 / 255, 0xdb / 255, 0x74 / 255]
113 | CYAN = [0x2b / 255, 0xda / 255, 0xc0 / 255]
114 |
115 | def calculate_metric_all(size=-1):
116 | # mean_cd = []
117 | diff_check_point_dir = os.path.join('./pretrain_model/compress/decoder.pt')
118 | diff_check_point = torch.load(diff_check_point_dir)['model']
119 | model = Diff_Point_MAE(args).to(args.device)
120 | model = torch.nn.DataParallel(model, device_ids=[0])
121 | model.load_state_dict(diff_check_point)
122 |
123 | all_sample = []
124 | all_ref = []
125 | all_hd = []
126 | all_cd = []
127 | for i, batch in enumerate(val_loader):
128 | if i == size:
129 | break
130 | with torch.no_grad():
131 | ref = batch['lr'].to(args.device)
132 | model.eval()
133 | compress = Compress(args).to(args.device)
134 | vis, center, mask = compress.compress(ref)
135 |
136 | x_vis, vis_pc = encoder.encode(vis, center, mask)
137 | recons = model.module.sampling(x_vis, mask, center)
138 | recons = torch.cat([vis_pc, recons], dim=1)
139 |
140 | all_sample.append(recons)
141 | all_ref.append(ref)
142 | hd = averaged_hausdorff_distance(recons.squeeze(), ref.squeeze())
143 | all_hd.append(hd)
144 |
145 | cd = chamfer_distance_l2(recons, ref)
146 | all_cd.append(cd)
147 | print("evaluating model: {i}, CD: {cd}".format(i=i, cd=cd))
148 |
149 | mean_hd = sum(all_hd) / len(all_hd)
150 | mean_cd = sum(all_cd) / len(all_cd)
151 | sample = torch.cat(all_sample, dim=0)
152 | refpc = torch.cat(all_ref, dim=0)
153 | jsd = jsd_between_point_cloud_sets(sample, refpc)
154 |
155 | print("MMD CD: {mmd}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=mean_cd, jsd=jsd, hd=mean_hd))
156 | log.info("MMD CD: {mmd}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=mean_cd, jsd=jsd, hd=mean_hd))
157 |
158 |
159 | calculate_metric_all()
160 |
--------------------------------------------------------------------------------
/eval_cop.py:
--------------------------------------------------------------------------------
1 | from utils.dataset import *
2 | from metrics.evaluation_metrics import averaged_hausdorff_distance, jsd_between_point_cloud_sets
3 | from model.Completion import *
4 | from model.EncoderCop import *
5 | import argparse
6 | from datetime import datetime
7 |
8 | import torch
9 | from torch.utils.data import DataLoader
10 | from utils.logger import *
11 |
12 | def get_data_iterator(iterable):
13 | """Allows training with DataLoaders in a single infinite loop:
14 | for i, data in enumerate(inf_generator(train_loader)):
15 | """
16 | iterator = iterable.__iter__()
17 | while True:
18 | try:
19 | yield iterator.__next__()
20 | except StopIteration:
21 | iterator = iterable.__iter__()
22 |
23 |
24 | parser = argparse.ArgumentParser()
25 | # Experiment setting
26 | parser.add_argument('--val_batch_size', type=int, default=1)
27 | parser.add_argument('--device', type=str, default='cuda') # mps for mac
28 | parser.add_argument('--save_dir', type=str, default='./results')
29 | parser.add_argument('--log', type=bool, default=True)
30 |
31 | # Grouping setting
32 | parser.add_argument('--mask_type', type=str, default='rand')
33 | parser.add_argument('--mask_ratio', type=float, default=0.75)
34 | parser.add_argument('--group_size', type=int, default=32) # points in each group
35 | parser.add_argument('--num_group', type=int, default=64) # number of group
36 | parser.add_argument('--num_points', type=int, default=2048)
37 | parser.add_argument('--num_output', type=int, default=8192)
38 |
39 | # Transformer setting
40 | parser.add_argument('--trans_dim', type=int, default=384)
41 | parser.add_argument('--drop_path_rate', type=float, default=0.1)
42 |
43 | # Encoder setting
44 | parser.add_argument('--encoder_depth', type=int, default=12)
45 | parser.add_argument('--encoder_num_heads', type=int, default=6)
46 | parser.add_argument('--loss', type=str, default='cdl2')
47 |
48 | # Decoder setting
49 | parser.add_argument('--decoder_depth', type=int, default=4)
50 | parser.add_argument('--decoder_num_heads', type=int, default=4)
51 |
52 | # diffusion
53 | parser.add_argument('--num_steps', type=int, default=200)
54 | parser.add_argument('--beta_1', type=float, default=1e-4)
55 | parser.add_argument('--beta_T', type=float, default=0.05)
56 | parser.add_argument('--sched_mode', type=str, default='linear')
57 |
58 | args = parser.parse_args()
59 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M")
60 | save_dir = os.path.join(args.save_dir, 'new_start_eval_{date}'.format(date=time_now))
61 |
62 | if args.log:
63 | if not os.path.exists(save_dir):
64 | os.makedirs(save_dir)
65 | log = logging.getLogger()
66 | log.setLevel(logging.INFO)
67 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s')
68 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
69 | log_file.setLevel(logging.INFO)
70 | log_file.setFormatter(formatter)
71 | log.addHandler(log_file)
72 |
73 | if args.log:
74 | log.info('loading dataset')
75 | print('loading dataset')
76 |
77 |
78 | test_dset_MN = ModelNet(
79 | root='dataset/ModelNet',
80 | number_pts=8192,
81 | downsampling=2048,
82 | use_normal=False,
83 | cats=40,
84 | subset='test'
85 | )
86 | val_loader_MN = DataLoader(test_dset_MN, batch_size=args.val_batch_size, pin_memory=True, shuffle=True)
87 |
88 | val_dset = ShapeNet(
89 | data_path='dataset/ShapeNet55/ShapeNet-55',
90 | pc_path='dataset/ShapeNet55/shapenet_pc',
91 | subset='test',
92 | n_points=2048,
93 | downsample=True
94 | )
95 |
96 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True)
97 | if args.log:
98 |     log.info('Evaluating completion')
99 | log.info('config:')
100 | log.info(args)
101 | log.info('dataset loaded')
102 | print('dataset loaded')
103 |
104 | print('loading model')
105 |
106 | def calculate_metric_all(size=-1, encoder=None):
107 | # mean_cd = []
108 | diff_check_point_dir = os.path.join('./pretrain_model/completion/decoder.pt')
109 | model = Diff_Point_MAE(args, encoder).to(args.device)
110 | model = torch.nn.DataParallel(model, device_ids=[0])
111 | model.module.load_model_from_ckpt(diff_check_point_dir)
112 |
113 |
114 | all_sample = []
115 | all_ref = []
116 | all_hd = []
117 | all_cd = []
118 | for i, batch in enumerate(val_loader_MN):
119 | if i == size:
120 | break
121 | with torch.no_grad():
122 | ref = batch['model'].to(args.device)
123 | model.eval()
124 | x_vis, x_masked, mask, center, vis_pc, msk_pc = encoder.encode(ref)
125 |
126 | recons = model.module.sampling(x_vis)
127 |
128 | recons = torch.cat([vis_pc, recons], dim=1)
129 | all_sample.append(recons)
130 | all_ref.append(ref)
131 | hd = averaged_hausdorff_distance(recons.squeeze(), ref.squeeze())
132 | all_hd.append(hd)
133 |
134 | cd = chamfer_distance_l2(recons, ref)
135 | all_cd.append(cd)
136 | print("evaluating model: {i}, CD: {cd}".format(i=i, cd=cd))
137 |
138 | mean_hd = sum(all_hd) / len(all_hd)
139 | mean_cd = sum(all_cd) / len(all_cd)
140 | sample = torch.cat(all_sample, dim=0)
141 | refpc = torch.cat(all_ref, dim=0)
142 | jsd = jsd_between_point_cloud_sets(sample, refpc)
143 | print("MMD CD: {mmd}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=mean_cd, jsd=jsd, hd=mean_hd))
144 | log.info("MMD CD: {mmd}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=mean_cd, jsd=jsd, hd=mean_hd))
145 |
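# Metric notes (a sketch of what the helpers above are assumed to compute; the exact
# definitions live in metrics/evaluation_metrics.py):
#   chamfer_distance_l2(X, Y): mean over X of min squared distance to Y, plus the same in the
#   other direction.
#   averaged_hausdorff_distance(X, Y): mean over X of min distance to Y, plus the reverse term.
#   jsd_between_point_cloud_sets: Jensen-Shannon divergence between occupancy histograms of
#   the sampled and reference sets.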
146 |
147 | # calculate_metric_all()
148 |
149 | def auto_run():
150 |     mask_ratios = [0.75]
151 |     for ratio in mask_ratios:
152 |         args.mask_ratio = ratio
153 | check_point_dir = os.path.join('./pretrain_model/completion/encoder.pt')
154 | # check_point_dir = os.path.join(args.save_dir, '{saved_src}/encoder.pt'
155 | # .format(saved_src='new_start_eval_2023_11_12_02_21'))
156 | check_point = torch.load(check_point_dir)['model']
157 | encoder = Point_MAE(args).to(args.device)
158 | encoder.load_state_dict(check_point)
159 | print('model loaded')
160 |
161 |
162 | calculate_metric_all(encoder=encoder)
163 |
164 | auto_run()
165 |
--------------------------------------------------------------------------------
/metrics/pytorch_structural_losses/src/structural_loss.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/cuda/CUDAContext.h>
3 |
4 | #include "src/approxmatch.cuh"
5 | #include "src/nndistance.cuh"
6 |
7 | #include <vector>
8 | #include <iostream>
9 |
10 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
13 |
14 | /*
15 | input:
16 | set1 : batch_size * #dataset_points * 3
17 | set2 : batch_size * #query_points * 3
18 | returns:
19 | match : batch_size * #query_points * #dataset_points
20 | */
21 | // temp: TensorShape{b,(n+m)*2}
22 | std::vector<at::Tensor> ApproxMatch(at::Tensor set_d, at::Tensor set_q) {
23 | //std::cout << "[ApproxMatch] Called." << std::endl;
24 | int64_t batch_size = set_d.size(0);
25 | int64_t n_dataset_points = set_d.size(1); // n
26 | int64_t n_query_points = set_q.size(1); // m
27 | //std::cout << "[ApproxMatch] batch_size:" << batch_size << std::endl;
28 | at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
29 | at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
30 | CHECK_INPUT(set_d);
31 | CHECK_INPUT(set_q);
32 | CHECK_INPUT(match);
33 | CHECK_INPUT(temp);
34 |
35 |   approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),temp.data<float>(), at::cuda::getCurrentCUDAStream());
36 | return {match, temp};
37 | }
38 |
39 | at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {
40 | //std::cout << "[MatchCost] Called." << std::endl;
41 | int64_t batch_size = set_d.size(0);
42 | int64_t n_dataset_points = set_d.size(1); // n
43 | int64_t n_query_points = set_q.size(1); // m
44 | //std::cout << "[MatchCost] batch_size:" << batch_size << std::endl;
45 | at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
46 | CHECK_INPUT(set_d);
47 | CHECK_INPUT(set_q);
48 | CHECK_INPUT(match);
49 | CHECK_INPUT(out);
50 |   matchcost(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),out.data<float>(),at::cuda::getCurrentCUDAStream());
51 | return out;
52 | }
53 |
54 | std::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {
55 | //std::cout << "[MatchCostGrad] Called." << std::endl;
56 | int64_t batch_size = set_d.size(0);
57 | int64_t n_dataset_points = set_d.size(1); // n
58 | int64_t n_query_points = set_q.size(1); // m
59 | //std::cout << "[MatchCostGrad] batch_size:" << batch_size << std::endl;
60 | at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
61 | at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
62 | CHECK_INPUT(set_d);
63 | CHECK_INPUT(set_q);
64 | CHECK_INPUT(match);
65 | CHECK_INPUT(grad1);
66 | CHECK_INPUT(grad2);
67 |   matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),grad1.data<float>(),grad2.data<float>(),at::cuda::getCurrentCUDAStream());
68 | return {grad1, grad2};
69 | }
70 |
71 |
72 | /*
73 | input:
74 | set_d : batch_size * #dataset_points * 3
75 | set_q : batch_size * #query_points * 3
76 | returns:
77 | dist1, idx1 : batch_size * #dataset_points
78 | dist2, idx2 : batch_size * #query_points
79 | */
80 | std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q) {
81 | //std::cout << "[NNDistance] Called." << std::endl;
82 | int64_t batch_size = set_d.size(0);
83 | int64_t n_dataset_points = set_d.size(1); // n
84 | int64_t n_query_points = set_q.size(1); // m
85 | //std::cout << "[NNDistance] batch_size:" << batch_size << std::endl;
86 | at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
87 | at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));
88 | at::Tensor dist2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
89 | at::Tensor idx2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));
90 | CHECK_INPUT(set_d);
91 | CHECK_INPUT(set_q);
92 | CHECK_INPUT(dist1);
93 | CHECK_INPUT(idx1);
94 | CHECK_INPUT(dist2);
95 | CHECK_INPUT(idx2);
96 | // void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);
97 |   nndistance(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),dist1.data<float>(),idx1.data<int>(),dist2.data<float>(),idx2.data<int>(), at::cuda::getCurrentCUDAStream());
98 | return {dist1, idx1, dist2, idx2};
99 | }
100 |
101 | std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) {
102 | //std::cout << "[NNDistanceGrad] Called." << std::endl;
103 | int64_t batch_size = set_d.size(0);
104 | int64_t n_dataset_points = set_d.size(1); // n
105 | int64_t n_query_points = set_q.size(1); // m
106 | //std::cout << "[NNDistanceGrad] batch_size:" << batch_size << std::endl;
107 | at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
108 | at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
109 | CHECK_INPUT(set_d);
110 | CHECK_INPUT(set_q);
111 | CHECK_INPUT(idx1);
112 | CHECK_INPUT(idx2);
113 | CHECK_INPUT(grad_dist1);
114 | CHECK_INPUT(grad_dist2);
115 | CHECK_INPUT(grad1);
116 | CHECK_INPUT(grad2);
117 | //void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);
118 |   nndistancegrad(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),
119 |       grad_dist1.data<float>(),idx1.data<int>(),
120 |       grad_dist2.data<float>(),idx2.data<int>(),
121 |       grad1.data<float>(),grad2.data<float>(),
122 | at::cuda::getCurrentCUDAStream());
123 | return {grad1, grad2};
124 | }
125 |
126 |
--------------------------------------------------------------------------------
/model/DiffusionCop.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from metrics.evaluation_metrics import *
3 | from model.Diffusion import VarianceSchedule, TimeEmbedding
4 | from model.Decoder_Component import *
5 |
6 | """
7 | Prediction on the masked patches only
8 | w/ diffusion process
9 | Use FC layer as the mask token convertor
10 | """
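# A minimal usage sketch (variable names are illustrative; the encoder call mirrors the
# pattern used in the training / evaluation scripts of this repo and is an assumption here):
#
#   model = Diff_Point_MAE(config).to(device)
#   x_vis, _, mask, center, vis_pc, msk_pc = encoder.encode(pts)
#   t = torch.randint(1, config.num_steps, (pts.size(0),))
#   loss = model(msk_pc, t, x_vis, mask, center, vis_pc, pts)   # training objective
#   recons = model.sampling(x_vis, mask, center)                # masked patches from noise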
11 |
12 |
13 | class Diff_Point_MAE(nn.Module):
14 | def __init__(self, config):
15 | super().__init__()
16 | self.config = config
17 | self.trans_dim = config.trans_dim
18 | self.group_size = config.group_size
19 | self.num_group = config.num_group
20 | self.num_output = config.num_output
21 | self.num_channel = 3
22 | self.drop_path_rate = config.drop_path_rate
23 | self.mask_token = nn.Conv1d((self.num_channel * 2048) // self.num_group, self.trans_dim, 1)
24 | self.decoder_pos_embed = nn.Sequential(
25 | nn.Linear(3, 128),
26 | nn.GELU(),
27 | nn.Linear(128, self.trans_dim)
28 | )
29 |
30 | self.decoder_depth = config.decoder_depth
31 | self.decoder_num_heads = config.decoder_num_heads
32 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.decoder_depth)]
33 | self.MAE_decoder = Transformer(
34 | embed_dim=self.trans_dim,
35 | depth=self.decoder_depth,
36 | drop_path_rate=dpr,
37 | num_heads=self.decoder_num_heads,
38 | )
39 |
40 | self.loss = config.loss
41 | # loss
42 | self.build_loss_func(self.loss)
43 | self.var = VarianceSchedule(
44 | num_steps=config.num_steps,
45 | beta_1=config.beta_1,
46 | beta_T=config.beta_T,
47 | mode=config.sched_mode
48 | )
49 |
50 | # prediction head
51 | self.increase_dim = nn.Sequential(
52 | nn.Conv1d(self.trans_dim, (self.num_channel * 2048) // self.num_group, 1)
53 | )
54 |
55 | self.timestep = config.num_steps
56 | self.beta_1 = config.beta_1
57 | self.beta_T = config.beta_T
58 |
59 | self.betas = self.linear_schedule(timesteps=self.timestep)
60 |
61 | self.alphas = 1.0 - self.betas
62 | self.alpha_bar = torch.cumprod(self.alphas, axis=0)
63 | self.alpha_bar_t_minus_one = F.pad(self.alpha_bar[:-1], (1, 0), value=1.0)
64 | self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
65 | self.sqrt_alphas_bar = torch.sqrt(self.alpha_bar)
66 | self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alpha_bar)
67 | self.sigma = self.betas * (1.0 - self.alpha_bar_t_minus_one) / (1.0 - self.alpha_bar)
68 | self.sqrt_alphas = torch.sqrt(self.alphas)
69 | self.sqrt_alpha_bar_minus_one = torch.sqrt(self.alpha_bar_t_minus_one)
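        # These buffers are the standard DDPM quantities (Ho et al., 2020) that the methods
        # below look up per timestep via get_index_from_list():
        #   q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
        #   sigma_t (posterior variance) = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)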
70 |
71 | self.time_emb = nn.Sequential(
72 | TimeEmbedding(self.trans_dim),
73 | nn.Linear(self.trans_dim, self.trans_dim),
74 | nn.ReLU()
75 | )
76 |
77 | def build_loss_func(self, loss_type):
78 | if loss_type == "cdl1":
79 | self.loss_func = chamfer_distance_l1
80 | elif loss_type == 'cdl2':
81 | self.loss_func = chamfer_distance_l2
82 | elif loss_type == 'mse':
83 | self.loss_func = F.mse_loss
84 | else:
85 | raise NotImplementedError
86 |
87 | def linear_schedule(self, timesteps):
88 | return torch.linspace(self.beta_1, self.beta_T, timesteps)
89 |
90 | def get_index_from_list(self, vals, t, x_shape):
91 | b = t.shape[0]
92 | out = vals.gather(-1, t.cpu())
93 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).to(t.device)
94 |
95 | def forward_diffusion(self, x_0, t):
96 | noise = torch.randn_like(x_0).to(x_0.device)
97 | sqrt_alphas_cumprod_t = self.get_index_from_list(self.sqrt_alphas_bar, t, x_0.shape).to(x_0.device)
98 | sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(self.sqrt_one_minus_alphas_bar, t, x_0.shape).to(
99 | x_0.device)
100 |
101 | return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
102 |
103 | def forward(self, x_0, t, x_vis, mask, center, vis_pc, ori, debug=False):
104 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1)
105 | x_t, noise = self.forward_diffusion(x_0, t)
106 |
107 | B, _, C = x_vis.shape # B VIS C
108 |
109 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C)
110 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C)
111 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1)
112 | _, N, _ = pos_emd_msk.shape
113 | mask_token = self.mask_token(x_t.reshape(B, N, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device)
114 | x_full = torch.cat([x_vis, mask_token], dim=1)
115 | x_rec = self.MAE_decoder(x_full, pos_full, N, ts)
116 | x_rec = self.increase_dim(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3)
117 | pc_full = torch.cat([vis_pc, x_rec], dim=1)
118 | if debug:
119 | return x_t, noise, x_rec
120 | else:
121 | return self.loss_func(pc_full, ori)
122 |
123 | def sampling_t(self, x, t, mask, center, x_vis):
124 | center = center.float()
125 | B, _, C = x_vis.shape # B VIS C
126 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1)
127 | betas_t = self.get_index_from_list(self.betas, t, x.shape).to(x_vis.device)
128 |
129 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C)
130 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C)
131 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1)
132 | _, N, _ = pos_emd_msk.shape
133 | mask_token = self.mask_token(x.reshape(B, N, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device)
134 | x_full = torch.cat([x_vis, mask_token], dim=1)
135 | x_rec = self.MAE_decoder(x_full, pos_full, N, ts)
136 | x_rec = self.increase_dim(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3)
137 |
138 | alpha_bar_t = self.get_index_from_list(self.alpha_bar, t, x.shape).to(x_vis.device)
139 | alpha_bar_t_minus_one = self.get_index_from_list(self.alpha_bar_t_minus_one, t, x.shape).to(x_vis.device)
140 | sqrt_alpha_t = self.get_index_from_list(self.sqrt_alphas, t, x.shape).to(x_vis.device)
141 | sqrt_alphas_bar_t_minus_one = self.get_index_from_list(self.sqrt_alpha_bar_minus_one, t, x.shape).to(
142 | x_vis.device)
143 |
144 | model_mean = (sqrt_alpha_t * (1 - alpha_bar_t_minus_one)) / (1 - alpha_bar_t) * x + (
145 | sqrt_alphas_bar_t_minus_one * betas_t) / (1 - alpha_bar_t) * x_rec
146 |
147 | sigma_t = self.get_index_from_list(self.sigma, t, x.shape).to(x_vis.device)
148 |
149 | if t == 0:
150 | return model_mean
151 | else:
152 | return model_mean + torch.sqrt(sigma_t) * x_rec
153 |
154 | def sampling(self, x_vis, mask, center, trace=False, noise_patch=None):
155 | B, M, C = x_vis.shape
156 | if noise_patch is None:
157 | noise_patch = torch.randn((B, (self.num_group - M) * self.group_size, 3)).to(x_vis.device)
158 | diffusion_sequence = []
159 |
160 | for i in range(0, self.timestep)[::-1]:
161 | t = torch.full((1,), i, device=x_vis.device)
162 | noise_patch = self.sampling_t(noise_patch, t, mask, center, x_vis)
163 | if trace:
164 | diffusion_sequence.append(noise_patch.reshape(B, -1, 3))
165 |
166 | if trace:
167 | return diffusion_sequence
168 | else:
169 | return noise_patch.reshape(B, -1, 3)
170 |
--------------------------------------------------------------------------------
/eval_upsampling.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path
3 | from datetime import datetime
4 |
5 | import point_cloud_utils as pcu
6 | from glob import glob
7 | from torch.utils.data import DataLoader
8 | from torch import optim
9 |
10 | from utils.dataset import *
11 | from utils.logger import *
12 | from model.DiffusionSR import *
13 | from model.EncoderSR import *
14 | from utils.visualization import *
15 | from metrics.evaluation_metrics import compute_all_metrics, hausdorff_distance
16 |
17 |
18 |
19 | def get_data_iterator(iterable):
20 | """Allows training with DataLoaders in a single infinite loop:
21 |     for i, data in enumerate(get_data_iterator(train_loader)):
22 | """
23 | iterator = iterable.__iter__()
24 | while True:
25 | try:
26 | yield iterator.__next__()
27 | except StopIteration:
28 | iterator = iterable.__iter__()
29 |
30 |
31 | parser = argparse.ArgumentParser()
32 | # Experiment setting
33 | parser.add_argument('--val_batch_size', type=int, default=1)
34 | parser.add_argument('--device', type=str, default='cuda') # mps for mac
35 | parser.add_argument('--save_dir', type=str, default='./results')
36 | parser.add_argument('--log', type=bool, default=True)
37 |
38 | # Grouping setting
39 | parser.add_argument('--mask_type', type=str, default='rand')
40 | parser.add_argument('--mask_ratio', type=float, default=0.4)
41 | parser.add_argument('--group_size', type=int, default=32) # points in each group
42 | parser.add_argument('--num_group', type=int, default=64) # number of group
43 | parser.add_argument('--num_points', type=int, default=2048)
44 | parser.add_argument('--num_output', type=int, default=8192)
45 | parser.add_argument('--diffusion_output_size', type=int, default=8192)
46 |
47 | # Transformer setting
48 | parser.add_argument('--trans_dim', type=int, default=384)
49 | parser.add_argument('--drop_path_rate', type=float, default=0.1)
50 |
51 | # Encoder setting
52 | parser.add_argument('--encoder_depth', type=int, default=12)
53 | parser.add_argument('--encoder_num_heads', type=int, default=6)
54 | parser.add_argument('--loss', type=str, default='cdl2')
55 |
56 | # Decoder setting
57 | parser.add_argument('--decoder_depth', type=int, default=4)
58 | parser.add_argument('--decoder_num_heads', type=int, default=4)
59 |
60 | # diffusion
61 | parser.add_argument('--num_steps', type=int, default=200)
62 | parser.add_argument('--beta_1', type=float, default=1e-4)
63 | parser.add_argument('--beta_T', type=float, default=0.05)
64 | parser.add_argument('--sched_mode', type=str, default='linear')
65 |
66 | args = parser.parse_args()
67 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M")
68 | save_dir = os.path.join(args.save_dir, 'new_start_eval_{date}'.format(date=time_now))
69 |
70 | if args.log:
71 | if not os.path.exists(save_dir):
72 | os.makedirs(save_dir)
73 | log = logging.getLogger()
74 | log.setLevel(logging.INFO)
75 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s')
76 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
77 | log_file.setLevel(logging.INFO)
78 | log_file.setFormatter(formatter)
79 | log.addHandler(log_file)
80 |
81 | if args.log:
82 | log.info('loading dataset')
83 | print('loading dataset')
84 |
85 |
86 | test_dset_MN = ModelNet(
87 | root='dataset/ModelNet',
88 | number_pts=8192,
89 | use_normal=False,
90 | cats=40,
91 | subset='test'
92 | )
93 | val_loader_MN = DataLoader(test_dset_MN, batch_size=args.val_batch_size, pin_memory=True)
94 |
95 | val_dset = ShapeNet(
96 | data_path='dataset/ShapeNet55/ShapeNet-55',
97 | pc_path='dataset/ShapeNet55/shapenet_pc',
98 | subset='test',
99 | n_points=2048,
100 | downsample=True
101 | )
102 |
103 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True)
104 | if args.log:
105 |     log.info('Evaluating DiffPMAE upsampling')
106 | log.info('config:')
107 | log.info(args)
108 | log.info('dataset loaded')
109 | print('dataset loaded')
110 |
111 | print('loading model')
112 | check_point_dir = os.path.join('./pretrain_model/sr/encoder.pt')
113 | check_point = torch.load(check_point_dir)['model']
114 | encoder = Encoder_Module(args).to(args.device)
115 | encoder.load_state_dict(check_point, strict=False)
116 | print('model loaded')
117 |
118 | diff_check_point_dir = os.path.join('./pretrain_model/sr/decoder.pt')
119 | diff_check_point = torch.load(diff_check_point_dir)['model']
120 | model = Diff_Point_MAE(args).to(args.device)
121 | model = torch.nn.DataParallel(model, device_ids=[0])
122 | incompatible = model.load_state_dict(diff_check_point, strict=False)
123 |
124 | optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.05)
125 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0.000001, T_max=20)
126 | PINK = [0xdd / 255, 0x83 / 255, 0xa2 / 255]
127 | BLUE = [0x81 / 255, 0xb0 / 255, 0xbf / 255]
128 | YELLOW = [0xf3 / 255, 0xdb / 255, 0x74 / 255]
129 | CYAN = [0x2b / 255, 0xda / 255, 0xc0 / 255]
130 |
131 |
132 | def load(filename, count=None):
133 | points = np.loadtxt(filename).astype(np.float32)
134 | if count is not None:
135 | if count > points.shape[0]:
136 |             # pad the point cloud by duplicating randomly chosen points
137 | tmp = np.zeros((count, points.shape[1]), dtype=points.dtype)
138 | tmp[:points.shape[0], ...] = points
139 | tmp[points.shape[0]:, ...] = points[np.random.choice(
140 | points.shape[0], count - points.shape[0]), :]
141 | points = tmp
142 | elif count < points.shape[0]:
143 |             # unlike PointNet++, down-sample uniformly instead of taking the first count points
144 | # idx = np.random.permutation(count)
145 | # points = points[idx, :]
146 | points = uniform_down_sample(points, count)
147 |
148 | return points
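# e.g. load('some_cloud.xyz', count=2048) pads or uniformly down-samples the file's points
# to exactly 2048 rows (the filename here is purely illustrative).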
149 |
150 | def calculate_metric_all(size=-1):
151 | all_sample = []
152 | all_ref = []
153 | all_hd = []
154 | all_vis = []
155 | gt_paths = glob(os.path.join('dataset/PU1K/test/input_2048/gt_8192', '*.xyz'))
156 | x_paths = glob(os.path.join('dataset/PU1K/test/input_2048/input_2048', '*.xyz'))
157 |
158 | for i in range(0, len(gt_paths)):
159 | if not os.path.exists(save_dir + '/' + str(i)):
160 | os.makedirs(save_dir + '/' + str(i))
161 | with torch.no_grad():
162 | x_hr = torch.from_numpy(load(gt_paths[i])[:, :3]).float().unsqueeze(0).to('cuda')
163 | x = torch.from_numpy(load(x_paths[i])[:, :3]).float().unsqueeze(0).to('cuda')
164 | model.eval()
165 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(x, masked=False)
166 | recons = model.module.sampling(x_vis, mask, center)
167 | all_vis.append(x_vis)
168 | all_sample.append(recons)
169 | all_ref.append(x_hr)
170 | hd = hausdorff_distance(recons, x_hr)
171 | all_hd.append(hd)
172 | print("evaluating model: {i}".format(i=i))
173 |
174 | mean_hd = sum(all_hd) / len(all_hd)
175 | sample = torch.cat(all_sample, dim=0)
176 | refpc = torch.cat(all_ref, dim=0)
177 | all = compute_all_metrics(sample, refpc, 1)
178 | print("MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'], N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd))
179 | log.info("MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'], N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd))
180 |
181 | calculate_metric_all()
182 |
--------------------------------------------------------------------------------
/train_decoder.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 | from torch import optim
4 | from torch.utils.data import DataLoader
5 | from utils.dataset import *
6 | from utils.logger import *
7 | from model.DiffusionPretrain import *
8 | from model.Encoder import *
9 |
10 |
11 | def get_data_iterator(iterable):
12 | """Allows training with DataLoaders in a single infinite loop:
13 |     for i, data in enumerate(get_data_iterator(train_loader)):
14 | """
15 | iterator = iterable.__iter__()
16 | while True:
17 | try:
18 | yield iterator.__next__()
19 | except StopIteration:
20 | iterator = iterable.__iter__()
21 |
22 | parser = argparse.ArgumentParser()
23 | # Experiment setting
24 | parser.add_argument('--batch_size', type=int, default=32)
25 | parser.add_argument('--val_batch_size', type=int, default=1)
26 | parser.add_argument('--device', type=str, default='cuda') # mps for mac
27 | parser.add_argument('--log', type=bool, default=True)
28 | parser.add_argument('--save_dir', type=str, default='./results')
29 |
30 | # Grouping setting
31 | parser.add_argument('--mask_type', type=str, default='rand')
32 | parser.add_argument('--mask_ratio', type=float, default=0.75)
33 | parser.add_argument('--group_size', type=int, default=32) # points in each group
34 | parser.add_argument('--num_group', type=int, default=64) # number of group
35 | parser.add_argument('--num_points', type=int, default=2048)
36 | parser.add_argument('--num_output', type=int, default=8192)
37 | parser.add_argument('--diffusion_output_size', type=int, default=2048)
38 |
39 | # Transformer setting
40 | parser.add_argument('--trans_dim', type=int, default=384)
41 | parser.add_argument('--drop_path_rate', type=float, default=0.1)
42 |
43 | # Encoder setting
44 | parser.add_argument('--encoder_depth', type=int, default=12)
45 | parser.add_argument('--encoder_num_heads', type=int, default=6)
46 | parser.add_argument('--loss', type=str, default='cdl2')
47 |
48 | # Decoder setting
49 | parser.add_argument('--decoder_depth', type=int, default=4)
50 | parser.add_argument('--decoder_num_heads', type=int, default=4)
51 |
52 | # diffusion
53 | parser.add_argument('--num_steps', type=int, default=200)
54 | parser.add_argument('--beta_1', type=float, default=1e-4)
55 | parser.add_argument('--beta_T', type=float, default=0.05)
56 | parser.add_argument('--sched_mode', type=str, default='linear')
57 |
58 | # sche / optim
59 | parser.add_argument('--learning_rate', type=float, default=0.001)
60 | parser.add_argument('--weight_decay', type=float, default=0.05)
61 | parser.add_argument('--eta_min', type=float, default=0.000001)
62 | parser.add_argument('--t_max', type=float, default=200)
63 |
64 | args = parser.parse_args()
65 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M")
66 | save_dir = os.path.join(args.save_dir, 'pretrain_decoder_{date}'.format(date=time_now))
67 |
68 | if args.log:
69 | if not os.path.exists(save_dir):
70 | os.makedirs(save_dir)
71 | log = logging.getLogger()
72 | log.setLevel(logging.INFO)
73 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s')
74 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
75 | log_file.setLevel(logging.INFO)
76 | log_file.setFormatter(formatter)
77 | log.addHandler(log_file)
78 |
79 |
80 |
81 | if args.log:
82 | log.info('loading dataset')
83 | print('loading dataset')
84 |
85 | train_dset = ShapeNet(
86 | data_path='dataset/ShapeNet55/ShapeNet-55',
87 | pc_path='dataset/ShapeNet55/shapenet_pc',
88 | subset='train',
89 | n_points=2048,
90 | downsample=True
91 | )
92 | val_dset = ShapeNet(
93 | data_path='dataset/ShapeNet55/ShapeNet-55',
94 | pc_path='dataset/ShapeNet55/shapenet_pc',
95 | subset='test',
96 | n_points=2048,
97 | downsample=True
98 | )
99 |
100 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True)
101 | trn_loader = DataLoader(train_dset, batch_size=args.batch_size, pin_memory=True)
102 | if args.log:
103 |     log.info('Training DiffPMAE decoder')
104 | log.info('config:')
105 | log.info(args)
106 | log.info('dataset loaded')
107 | print('dataset loaded')
108 |
109 | print('loading model')
110 | check_point_dir = os.path.join('./pretrain_model/pretrain/encoder.pt')
111 |
112 | check_point = torch.load(check_point_dir)['model']
113 | encoder = Encoder_Module(args).to(args.device)
114 | encoder.load_state_dict(check_point)
115 | print('model loaded')
116 |
117 | model = Diff_Point_MAE(args).to(args.device)
118 | model = torch.nn.DataParallel(model, device_ids=[0])
119 |
120 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
121 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=args.eta_min, T_max=args.t_max)
122 |
123 | def train(i, batch, epoch):
124 | if i == 0 and epoch > 1:
125 | print('Saving checkpoint at epoch {epoch}'.format(epoch=epoch))
126 | save_model = {'args': args, 'model': model.state_dict()}
127 | if args.log:
128 | torch.save(save_model, os.path.join(save_dir, 'autoencoder_diff.pt'))
129 |
130 | x = batch['lr'].to(args.device)
131 | optimizer.zero_grad()
132 | model.train()
133 | encoder.eval()
134 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(x, masked=False)
135 | t = torch.randint(1, args.num_steps, (x.size(0),))
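    # Sample one diffusion timestep per example (step 0, the clean sample, is skipped);
    # the decoder diffuses the masked patches msk_pc to step t and is supervised with
    # Chamfer distance against the full cloud x (see DiffusionPretrain.forward).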
136 |
137 | loss = model(msk_pc, t, x_vis, mask, center, vis_pc, x)
138 | loss = loss.mean()
139 | loss.backward()
140 | optimizer.step()
141 | if args.log and i == 0:
142 | log.info('epoch: {epoch}, iteration: {i}, loss: {loss}'.format(i=i, epoch=epoch, loss=loss))
143 | print('epoch: {epoch}, iteration: {i}, loss: {loss}'.format(i=i, epoch=epoch, loss=loss))
144 |
145 |
146 | def validate():
147 | # all_refs = []
148 | all_recons = []
149 | for i, batch in enumerate(val_loader):
150 | print('sampling model {i}'.format(i=i))
151 | ref = batch['lr'].to(args.device)
152 | if i > 200:
153 | break
154 | with torch.no_grad():
155 | model.eval()
156 | encoder.eval()
157 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(ref, masked=False)
158 | recons = model.module.sampling(x_vis, mask, center)
159 | all_recons.append(recons)
160 | all_recons = torch.cat(all_recons, dim=0)
161 | np.save(os.path.join(save_dir, 'out.npy'), all_recons.cpu().numpy())
162 |
163 | def forward_diff():
164 | all_recons = []
165 | for i, batch in enumerate(val_loader):
166 | ref = batch['lr'].to(args.device)
167 | if i > 200:
168 | break
169 | with torch.no_grad():
170 | model.eval()
171 | encoder.eval()
172 | if i == 100:
173 | print('sampling model {i}'.format(i=i))
174 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(ref, masked=False)
175 |
176 | for idx in range(1, args.num_steps, 1):
177 | print('sampling step {idx}'.format(idx=idx))
178 | t = torch.Tensor([idx]).type(torch.int64)
179 |                 x_noisy, noise = model.module.forward_diffusion(msk_pc, t)
180 |                 recons = x_noisy.reshape(1, -1, 3)
181 |
182 | all_recons.append(torch.cat([recons, vis_pc.reshape(1, -1, 3)], dim=1))
183 |
184 | all_recons = torch.cat(all_recons, dim=0)
185 | np.save(os.path.join(save_dir, 'out.npy'), all_recons.cpu().numpy())
186 |
187 |
188 | try:
189 | n_it = 100
190 | epoch = 1
191 | while epoch <= n_it:
192 | model.train()
193 | for i, pc in enumerate(trn_loader):
194 | train(i, pc, epoch)
195 | scheduler.step()
196 |
197 | if epoch == n_it:
198 | if args.log:
199 | saved_file = {'args': args, 'model': model.state_dict()}
200 | torch.save(saved_file, os.path.join(save_dir, 'autoencoder_diff.pt'))
201 | validate()
202 | epoch += 1
203 | except Exception as e:
204 | log.error(e)
205 |
206 |
--------------------------------------------------------------------------------
/model/EncoderCompress.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import trunc_normal_
4 | import numpy as np
5 | import random
6 |
7 | import torch.nn.functional as F
8 |
9 | from metrics.evaluation_metrics import chamfer_distance_l2, chamfer_distance_l1
10 | from model.Encoder_Component import Encoder, Group, TransformerEncoder
11 |
12 | class PointTransformer(nn.Module):
13 | def __init__(self, config):
14 | super().__init__()
15 | self.config = config
16 | # define the transformer argparse
17 | self.trans_dim = config.trans_dim
18 | self.depth = config.encoder_depth
19 | self.drop_path_rate = config.drop_path_rate
20 | self.num_heads = config.encoder_num_heads
21 | # embedding
22 | self.encoder_dims = config.trans_dim
23 | self.encoder = Encoder(encoder_channel=self.encoder_dims)
24 | self.mask_type = config.mask_type
25 | self.mask_ratio = config.mask_ratio
26 | self.group_size = config.group_size
27 |
28 | self.pos_embed = nn.Sequential(
29 | nn.Linear(3, 128),
30 | nn.GELU(),
31 | nn.Linear(128, self.trans_dim),
32 | )
33 |
34 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
35 | self.blocks = TransformerEncoder(
36 | embed_dim=self.trans_dim,
37 | depth=self.depth,
38 | drop_path_rate=dpr,
39 | num_heads=self.num_heads,
40 | )
41 |
42 | self.norm = nn.LayerNorm(self.trans_dim)
43 | self.apply(self._init_weights)
44 |
45 | def _init_weights(self, m):
46 | if isinstance(m, nn.Linear):
47 | trunc_normal_(m.weight, std=.02)
48 | if isinstance(m, nn.Linear) and m.bias is not None:
49 | nn.init.constant_(m.bias, 0)
50 | elif isinstance(m, nn.LayerNorm):
51 | nn.init.constant_(m.bias, 0)
52 | nn.init.constant_(m.weight, 1.0)
53 | elif isinstance(m, nn.Conv1d):
54 | trunc_normal_(m.weight, std=.02)
55 | if m.bias is not None:
56 | nn.init.constant_(m.bias, 0)
57 |
58 |
59 | def forward(self, neighborhood, center, mask):
60 | # generate mask
61 | B, _ = mask.size()
62 | _, _, C = neighborhood.shape
63 | vis = neighborhood.reshape(B, -1, self.group_size, C)
64 | group_input_tokens = self.encoder(vis) # B G C
65 |
66 | batch_size, seq_len, L = group_input_tokens.size()
67 | vis_center = center[~mask].reshape(B, -1, C)
68 | p = self.pos_embed(vis_center)
69 | z = self.blocks(group_input_tokens, p)
70 | z = self.norm(z)
71 | return z.reshape(batch_size, -1, L)
72 |
73 | class Encoder_Module(nn.Module):
74 | def __init__(self, config):
75 | super().__init__()
76 | self.config = config
77 | self.trans_dim = config.trans_dim
78 | self.AE_encoder = PointTransformer(config)
79 | self.group_size = config.group_size
80 | self.num_group = config.num_group
81 | self.num_output = config.num_output
82 | self.mask_ratio = config.mask_ratio
83 | self.num_channel = 3
84 | self.drop_path_rate = config.drop_path_rate
85 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
86 |
87 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
88 |
89 | # prediction head
90 | self.increase_dim = nn.Sequential(
91 | nn.Conv1d(self.trans_dim, (self.num_channel * int(self.num_output * (1 - self.mask_ratio))) // self.num_group, 1)
92 | )
93 |
94 | trunc_normal_(self.mask_token, std=.02)
95 | self.loss = config.loss
96 | # loss
97 | self.build_loss_func(self.loss)
98 |
99 | def build_loss_func(self, loss_type):
100 | if loss_type == "cdl1":
101 | self.loss_func = chamfer_distance_l1
102 | elif loss_type == 'cdl2':
103 | self.loss_func = chamfer_distance_l2
104 | elif loss_type == 'mse':
105 | self.loss_func = F.mse_loss
106 | else:
107 | raise NotImplementedError
108 |
109 | def forward(self, pts, neighborhood, center, mask, msk_pc):
110 | neighborhood = neighborhood.float()
111 | center = center.float()
112 | x_vis, vis_pc = self.encode(neighborhood, center, mask)
113 |
114 | B, _, C = x_vis.shape # B VIS C
115 |
116 | rebuild_points = self.increase_dim(x_vis.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 3
117 |
118 | full = torch.cat([rebuild_points, msk_pc], dim=1)
119 | loss1 = self.loss_func(full, pts)
120 | return loss1
121 |
122 | def encode(self, neighborhood, center, mask):
123 | neighborhood = neighborhood.float()
124 | center = center.float()
125 | x_vis = self.AE_encoder(neighborhood, center, mask)
126 |
127 | full_vis = self.neighborhood(neighborhood, center, mask, x_vis)
128 | return x_vis, full_vis
129 |
130 |
131 | def evaluate(self, x_vis, x_msk):
132 | B, _, C = x_vis.shape # B VIS C
133 | x_full = torch.cat([x_vis, x_msk], dim=1)
134 | rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3) # 38, 32, 3
135 | return rebuild_points.reshape(-1, 3).unsqueeze(0)
136 |
137 | def neighborhood(self, neighborhood, center, mask, x_vis):
138 | B, M, N = x_vis.shape
139 | vis_point = neighborhood.reshape(B * M, -1, 3)
140 | full_vis = vis_point + center[~mask].unsqueeze(1)
141 | full_vis = full_vis.reshape(B, -1, 3)
142 |
143 | return full_vis
144 |
145 |
146 | class Compress(nn.Module):
147 | def __init__(self, config):
148 | super().__init__()
149 | self.group_size = config.group_size
150 | self.num_group = config.num_group
151 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
152 | self.mask_ratio = config.mask_ratio
153 | self.mask_type = config.mask_type
154 |
155 | def _mask_center_block(self, center, noaug=False):
156 | if noaug or self.mask_ratio == 0:
157 | return torch.zeros(center.shape[:2]).bool()
158 | mask_idx = []
159 | for points in center:
160 | points = points.unsqueeze(0)
161 | index = random.randint(0, points.size(1) - 1)
162 | distance_matrix = torch.norm(points[:, index].reshape(1,1,3) - points, p=2, dim=-1)
163 | idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0]
164 | ratio = self.mask_ratio
165 | mask_num = int(ratio * len(idx))
166 | mask = torch.zeros(len(idx))
167 | mask[idx[:mask_num]] = 1
168 | mask_idx.append(mask.bool())
169 | bool_masked_pos = torch.stack(mask_idx).to(center.device)
170 | return bool_masked_pos
171 |
172 | def _mask_center_rand(self, center, noaug=False):
173 | '''
174 | center : B G 3
175 | --------------
176 | mask : B G (bool)
177 | '''
178 | B, G, _ = center.shape
179 | # skip the mask
180 | if noaug or self.mask_ratio == 0:
181 | return torch.zeros(center.shape[:2]).bool()
182 |
183 | self.num_mask = int(self.mask_ratio * G)
184 |
185 | overall_mask = np.zeros([B, G])
186 | for i in range(B):
187 | mask = np.hstack([
188 | np.zeros(G - self.num_mask),
189 | np.ones(self.num_mask),
190 | ])
191 | np.random.shuffle(mask)
192 | overall_mask[i, :] = mask
193 | overall_mask = torch.from_numpy(overall_mask).to(torch.bool)
194 |
195 | return overall_mask.to(center.device) # B G
196 |
197 | def compress(self, pt):
198 | neighborhood, center = self.group_divider(pt)
199 | if self.mask_type == 'rand':
200 | mask = self._mask_center_rand(center) # B G
201 | else:
202 | mask = self._mask_center_block(center)
203 | vis_point = neighborhood[~mask]
204 | return vis_point.half(), center.half(), mask
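    # Rough payload size, assuming the defaults used elsewhere in this repo (num_group=64,
    # group_size=32, mask_ratio=0.75): 16 visible groups x 32 points x 3 coords in fp16,
    # plus 64 fp16 centers x 3 coords and a 64-entry boolean mask per cloud. These are
    # roughly what a compression pipeline would store/transmit; the remaining points are
    # reconstructed by the diffusion decoder.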
205 |
--------------------------------------------------------------------------------
/model/DiffusionSR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from metrics.evaluation_metrics import chamfer_distance_l1, chamfer_distance_l2
5 | from model.Diffusion import VarianceSchedule, TimeEmbedding
6 | from model.Decoder_Component import *
7 |
8 |
9 | """
10 | Super-resolution variant: predicts the full high-resolution cloud
11 | (diffusion_output_size points) with the diffusion process, conditioned on the visible latent.
12 | A Conv1d layer converts the noisy point patches into mask tokens.
13 | """
14 |
15 | class Diff_Point_MAE(nn.Module):
16 | def __init__(self, config):
17 | super().__init__()
18 | self.config = config
19 | self.trans_dim = config.trans_dim
20 | self.group_size = config.group_size
21 | self.num_group = config.num_group
22 | self.num_output = config.diffusion_output_size
23 | self.num_channel = 3
24 | self.drop_path_rate = config.drop_path_rate
25 | self.mask_token = nn.Conv1d((self.num_channel * 2048) // self.num_group, self.trans_dim, 1)
26 | self.mask_token_sr = nn.Conv1d((self.num_channel * self.num_output) // self.num_group, self.trans_dim, 1)
27 | self.decoder_pos_embed = nn.Sequential(
28 | nn.Linear(3, 128),
29 | nn.GELU(),
30 | nn.Linear(128, self.trans_dim)
31 | )
32 |
33 | self.decoder_depth = config.decoder_depth
34 | self.decoder_num_heads = config.decoder_num_heads
35 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.decoder_depth)]
36 | self.decoder = Transformer(
37 | embed_dim=self.trans_dim,
38 | depth=self.decoder_depth,
39 | drop_path_rate=dpr,
40 | num_heads=self.decoder_num_heads,
41 | )
42 |
43 | self.loss = config.loss
44 | # loss
45 | self.build_loss_func(self.loss)
46 | self.var = VarianceSchedule(
47 | num_steps=config.num_steps,
48 | beta_1=config.beta_1,
49 | beta_T=config.beta_T,
50 | mode=config.sched_mode
51 | )
52 |
53 | # prediction head
54 | self.increase_dim = nn.Sequential(
55 | nn.Conv1d(self.trans_dim, (self.num_channel * 2048) // self.num_group, 1)
56 | )
57 |
58 | self.increase_dim_sr = nn.Sequential(
59 | nn.Conv1d(self.trans_dim, (self.num_channel * self.num_output) // self.num_group, 1)
60 | )
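        # increase_dim_sr maps each of the num_group decoder tokens to
        # (3 * diffusion_output_size / num_group) coordinates, so sampling() emits
        # diffusion_output_size points in total (e.g. 8192 from a 2048-point input
        # with the defaults in eval_upsampling.py).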
61 |
62 | self.timestep = config.num_steps
63 | self.beta_1 = config.beta_1
64 | self.beta_T = config.beta_T
65 |
66 | self.betas = self.linear_schedule(timesteps=self.timestep)
67 |
68 | self.alphas = 1.0 - self.betas
69 | self.alpha_bar = torch.cumprod(self.alphas, axis=0)
70 | self.alpha_bar_t_minus_one = F.pad(self.alpha_bar[:-1], (1, 0), value=1.0)
71 | self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
72 | self.sqrt_alphas_bar = torch.sqrt(self.alpha_bar)
73 | self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alpha_bar)
74 | self.sigma = self.betas * (1.0 - self.alpha_bar_t_minus_one) / (1.0 - self.alpha_bar)
75 | self.sqrt_alphas = torch.sqrt(self.alphas)
76 | self.sqrt_alpha_bar_minus_one = torch.sqrt(self.alpha_bar_t_minus_one)
77 |
78 | self.time_emb = nn.Sequential(
79 | TimeEmbedding(self.trans_dim),
80 | nn.Linear(self.trans_dim, self.trans_dim),
81 | nn.ReLU()
82 | )
83 |
84 | def build_loss_func(self, loss_type):
85 | if loss_type == "cdl1":
86 | self.loss_func = chamfer_distance_l1
87 | elif loss_type == 'cdl2':
88 | self.loss_func = chamfer_distance_l2
89 | elif loss_type == 'mse':
90 | self.loss_func = nn.MSELoss()
91 | else:
92 | raise NotImplementedError
93 |
94 | def linear_schedule(self, timesteps):
95 | return torch.linspace(self.beta_1, self.beta_T, timesteps)
96 |
97 | def get_index_from_list(self, vals, t, x_shape):
98 | b = t.shape[0]
99 | out = vals.gather(-1, t.cpu())
100 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).to(t.device)
101 |
102 | def forward_diffusion(self, x_0, t):
103 | noise = torch.randn_like(x_0).to(x_0.device)
104 | sqrt_alphas_cumprod_t = self.get_index_from_list(self.sqrt_alphas_bar, t, x_0.shape).to(x_0.device)
105 | sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(self.sqrt_one_minus_alphas_bar, t, x_0.shape).to(x_0.device)
106 |
107 | return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
108 |
109 |
110 | def forward(self, x_0, t, x_vis, mask, center, ori, debug=False):
111 | """
112 | x_0: lr full (vis, msk)
113 | t: time step
114 | x_vis: lr vis latent
115 | mask: mask
116 | center: center
117 | ori: lr overall model
118 | """
119 | # print()
120 | # print(t.size())
121 |
122 | # print(ts.size())
123 | x_t, noise = self.forward_diffusion(x_0, t)
124 |
125 | B, _, C = x_vis.shape # B VIS C
126 |
127 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C)
128 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C)
129 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1)
130 | _, N, _ = pos_emd_msk.shape
131 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1)
132 | mask_token = self.mask_token_sr(x_t.reshape(B, self.num_group, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device)
133 | x_full = mask_token[:, :N, :]
134 | x_full = torch.cat([x_vis, x_full], dim=1)
135 | # x_full = torch.cat([x_vis, mask_token], dim=1) # x_vis, x_vis, msk
136 | x_rec = self.decoder(x_full, pos_full, self.num_group, ts)
137 | x_rec = self.increase_dim_sr(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3)
138 | if debug:
139 | return x_t, noise, x_rec
140 | else:
141 | return self.loss_func(x_rec, ori)
142 | # return F.mse_loss(x_rec, x_0, reduction='mean')
143 |
144 | def sampling_t(self, x, t, mask, center, x_vis):
145 | B, M, C = x_vis.shape # B VIS C
146 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1)
147 | betas_t = self.get_index_from_list(self.betas, t, x.shape).to(x_vis.device)
148 |
149 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C)
150 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C)
151 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1)
152 | _, N, _ = pos_emd_msk.shape
153 | mask_token = self.mask_token_sr(x.reshape(B, self.num_group, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device)
154 | x_full = mask_token[:, :N, :]
155 | x_full = torch.cat([x_vis, x_full], dim=1)
156 | x_rec = self.decoder(x_full, pos_full, self.num_group, ts)
157 | x_rec = self.increase_dim_sr(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3)
158 |
159 |
160 | alpha_bar_t = self.get_index_from_list(self.alpha_bar, t, x.shape).to(x_vis.device)
161 | alpha_bar_t_minus_one = self.get_index_from_list(self.alpha_bar_t_minus_one, t, x.shape).to(x_vis.device)
162 | sqrt_alpha_t = self.get_index_from_list(self.sqrt_alphas, t, x.shape).to(x_vis.device)
163 | sqrt_alphas_bar_t_minus_one = self.get_index_from_list(self.sqrt_alpha_bar_minus_one, t, x.shape).to(x_vis.device)
164 |
165 | model_mean = (sqrt_alpha_t * (1 - alpha_bar_t_minus_one)) / (1 - alpha_bar_t) * x + (sqrt_alphas_bar_t_minus_one * betas_t) / (1 - alpha_bar_t) * x_rec
166 |
167 | sigma_t = self.get_index_from_list(self.sigma, t, x.shape).to(x_vis.device)
168 |
169 | if t == 0:
170 | return model_mean
171 | else:
172 | return model_mean + torch.sqrt(sigma_t) * x_rec
173 |
174 | def sampling(self, x_vis, mask, center, ret=False, noise_patch=None):
175 | B, M, C = x_vis.shape
176 | if noise_patch is None:
177 | noise_patch = torch.randn((B, self.num_output, 3)).to(x_vis.device)
178 | traj = []
179 |
180 | for i in range(0, self.timestep)[::-1]:
181 | t = torch.full((1,), i, device=x_vis.device)
182 | noise_patch = self.sampling_t(noise_patch, t, mask, center, x_vis)
183 | if ret:
184 | traj.append(noise_patch.reshape(B, -1, 3))
185 |
186 | if ret:
187 | return traj
188 | else:
189 | return noise_patch.reshape(B, -1, 3)
190 |
--------------------------------------------------------------------------------
/model/DiffusionPretrain.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from metrics.evaluation_metrics import chamfer_distance_l1, chamfer_distance_l2
5 | from model.Diffusion import VarianceSchedule, TimeEmbedding
6 | from model.Decoder_Component import *
7 |
8 | """
9 | Prediction on the masked patches only
10 | w/ diffusion process
11 | Use FC layer as the mask token convertor
12 | """
13 |
14 |
15 | class Diff_Point_MAE(nn.Module):
16 | def __init__(self, config):
17 | super().__init__()
18 | self.config = config
19 | self.trans_dim = config.trans_dim
20 | self.group_size = config.group_size
21 | self.num_group = config.num_group
22 | self.num_output = config.num_output
23 | self.diffusion_output_size = config.diffusion_output_size
24 | self.num_channel = 3
25 | self.drop_path_rate = config.drop_path_rate
26 | self.mask_token = nn.Conv1d((self.num_channel * self.diffusion_output_size) // self.num_group, self.trans_dim, 1)
27 | self.decoder_pos_embed = nn.Sequential(
28 | nn.Linear(3, 128),
29 | nn.GELU(),
30 | nn.Linear(128, self.trans_dim)
31 | )
32 |
33 | self.decoder_depth = config.decoder_depth
34 | self.decoder_num_heads = config.decoder_num_heads
35 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.decoder_depth)]
36 | self.MAE_decoder = Transformer(
37 | embed_dim=self.trans_dim,
38 | depth=self.decoder_depth,
39 | drop_path_rate=dpr,
40 | num_heads=self.decoder_num_heads,
41 | )
42 |
43 | self.loss = config.loss
44 | # loss
45 | self.build_loss_func(self.loss)
46 | self.var = VarianceSchedule(
47 | num_steps=config.num_steps,
48 | beta_1=config.beta_1,
49 | beta_T=config.beta_T,
50 | mode=config.sched_mode
51 | )
52 |
53 | # prediction head
54 | self.increase_dim = nn.Sequential(
55 | nn.Conv1d(self.trans_dim, (self.num_channel * 2048) // self.num_group, 1)
56 | )
57 |
58 | self.timestep = config.num_steps
59 | self.beta_1 = config.beta_1
60 | self.beta_T = config.beta_T
61 |
62 | self.betas = self.linear_schedule(timesteps=self.timestep)
63 |
64 | self.alphas = 1.0 - self.betas
65 | self.alpha_bar = torch.cumprod(self.alphas, axis=0)
66 | self.alpha_bar_t_minus_one = F.pad(self.alpha_bar[:-1], (1, 0), value=1.0)
67 | self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
68 | self.sqrt_alphas_bar = torch.sqrt(self.alpha_bar)
69 | self.sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alpha_bar)
70 | self.sigma = self.betas * (1.0 - self.alpha_bar_t_minus_one) / (1.0 - self.alpha_bar)
71 | self.sqrt_alphas = torch.sqrt(self.alphas)
72 | self.sqrt_alpha_bar_minus_one = torch.sqrt(self.alpha_bar_t_minus_one)
73 |
74 | self.time_emb = nn.Sequential(
75 | TimeEmbedding(self.trans_dim),
76 | nn.Linear(self.trans_dim, self.trans_dim),
77 | nn.ReLU()
78 | )
79 |
80 | def build_loss_func(self, loss_type):
81 | if loss_type == "cdl1":
82 | self.loss_func = chamfer_distance_l1
83 | elif loss_type == 'cdl2':
84 | self.loss_func = chamfer_distance_l2
85 | elif loss_type == 'mse':
86 | self.loss_func = F.mse_loss
87 | else:
88 | raise NotImplementedError
89 |
90 | def linear_schedule(self, timesteps):
91 | return torch.linspace(self.beta_1, self.beta_T, timesteps)
92 |
93 | def get_index_from_list(self, vals, t, x_shape):
94 | b = t.shape[0]
95 | out = vals.gather(-1, t.cpu())
96 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).to(t.device)
97 |
98 | def forward_diffusion(self, x_0, t):
99 | """
100 | Adding noise to the original input for
101 | the forward diffusion process. The noise level
102 | calculated based on current timestep t.
103 |
104 | :param x_0: The original masked input.
105 | [B, NM, C] where
106 | B = Batch size,
107 | NM = Number of point in masked patches,
108 | C = Data Channels (3 for current task)
109 | :param t: The current timestep
110 | :return: The noisy masked input at timestep t.
111 | [B, NM, C]
112 |                 [B, NM, C]
        """
113 | noise = torch.randn_like(x_0).to(x_0.device)
114 | sqrt_alphas_cumprod_t = self.get_index_from_list(self.sqrt_alphas_bar, t, x_0.shape).to(x_0.device)
115 | sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(self.sqrt_one_minus_alphas_bar, t, x_0.shape).to(
116 | x_0.device)
117 |
118 | return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise
119 |
120 | def forward(self, x_0, t, x_vis, mask, center, vis_pc, ori, debug=False):
121 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1)
122 | x_t, noise = self.forward_diffusion(x_0, t)
123 |
124 | B, _, C = x_vis.shape # B VIS C
125 |
126 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C)
127 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C)
128 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1)
129 | _, N, _ = pos_emd_msk.shape
130 | mask_token = self.mask_token(x_t.reshape(B, N, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device)
131 | x_full = torch.cat([x_vis, mask_token], dim=1)
132 | x_rec = self.MAE_decoder(x_full, pos_full, N, ts)
133 | x_rec = self.increase_dim(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3)
134 | pc_full = torch.cat([vis_pc, x_rec], dim=1)
135 | if debug:
136 | return x_t, noise, x_rec
137 | else:
138 | return self.loss_func(pc_full, ori)
139 |
140 | def sampling_t(self, noisy_t, t, mask, center, x_vis):
141 | """
142 | Reverse sampling at timestep t.
143 | Input noisy level at timestep t,
144 | return noisy level at timestep t-1.
145 |
146 | :param noisy_t: The noisy masked patches at timestep t.
147 | [B, NM, C]
148 | :param t: Timestep.
149 | :param mask: The mask indicator. [B, G]
150 | :param center: The center points. [B, G, C]
151 | :param x_vis: The latent of visible patches. [B, V, L]
152 | :return: The noisy masked patches at timestep t-1. [B, NM, C]
153 | """
154 | B, _, C = x_vis.shape # B VIS C
155 | ts = self.time_emb(t.to(x_vis.device)).unsqueeze(1).expand(-1, self.num_group, -1)
156 | betas_t = self.get_index_from_list(self.betas, t, noisy_t.shape).to(x_vis.device)
157 |
158 | pos_emd_vis = self.decoder_pos_embed(center[~mask]).reshape(B, -1, C)
159 | pos_emd_msk = self.decoder_pos_embed(center[mask]).reshape(B, -1, C)
160 | pos_full = torch.cat([pos_emd_vis, pos_emd_msk], dim=1)
161 | _, N, _ = pos_emd_msk.shape
162 | mask_token = self.mask_token(noisy_t.reshape(B, N, -1).transpose(1, 2)).transpose(1, 2).to(x_vis.device)
163 | x_full = torch.cat([x_vis, mask_token], dim=1)
164 | x_rec = self.MAE_decoder(x_full, pos_full, N, ts)
165 | x_rec = self.increase_dim(x_rec.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3)
166 |
167 | alpha_bar_t = self.get_index_from_list(self.alpha_bar, t, noisy_t.shape).to(x_vis.device)
168 | alpha_bar_t_minus_one = self.get_index_from_list(self.alpha_bar_t_minus_one, t, noisy_t.shape).to(x_vis.device)
169 | sqrt_alpha_t = self.get_index_from_list(self.sqrt_alphas, t, noisy_t.shape).to(x_vis.device)
170 | sqrt_alphas_bar_t_minus_one = self.get_index_from_list(self.sqrt_alpha_bar_minus_one, t, noisy_t.shape).to(
171 | x_vis.device)
172 |
173 | model_mean = (sqrt_alpha_t * (1 - alpha_bar_t_minus_one)) / (1 - alpha_bar_t) * noisy_t + (
174 | sqrt_alphas_bar_t_minus_one * betas_t) / (1 - alpha_bar_t) * x_rec
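        # This is the DDPM posterior mean with x_rec playing the role of the predicted x_0:
        #   mu_t = [sqrt(alpha_t) * (1 - alpha_bar_{t-1}) * x_t
        #           + sqrt(alpha_bar_{t-1}) * beta_t * x_0_hat] / (1 - alpha_bar_t)
        # Note that for t > 0 the stochastic term added below scales x_rec rather than
        # fresh Gaussian noise, which differs from vanilla DDPM sampling.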
175 |
176 | sigma_t = self.get_index_from_list(self.sigma, t, noisy_t.shape).to(x_vis.device)
177 |
178 | if t == 0:
179 | return model_mean
180 | else:
181 | return model_mean + torch.sqrt(sigma_t) * x_rec
182 |
183 | def sampling(self, x_vis, mask, center, trace=False, noise_patch=None):
184 | """
185 | Sampling the masked patches from Gaussian noise.
186 |
187 | :param x_vis: The latent of visible patches.
188 | [B, V, L]
189 | B = Batch size,
190 | V = Visible patches size,
191 | L = Latent size.
192 | :param mask: The mask indicator.
193 | [B, G]
194 | B = Batch size,
195 | G = Group (Visible patches + Masked patches) size.
196 | :param center: The center points.
197 | [B, G, C]
198 | B = Batch size,
199 | G = Group size,
200 | C = Data Channels (3 for current task)
201 | :param trace: Boolean, False by default.
202 | if true: return all reverse diffusion steps.
203 | else: return the last step only.
204 | :param noise_patch: The pre-defined noises, None by default.
205 | :return: See param trace.
206 | """
207 | B, M, C = x_vis.shape
208 | if noise_patch is None:
209 | noise_patch = torch.randn((B, (self.num_group - M) * self.group_size, 3)).to(x_vis.device)
210 | diffusion_sequence = []
211 |
212 | for i in range(0, self.timestep)[::-1]:
213 | t = torch.full((1,), i, device=x_vis.device)
214 | noise_patch = self.sampling_t(noise_patch, t, mask, center, x_vis)
215 | if trace:
216 | diffusion_sequence.append(noise_patch.reshape(B, -1, 3))
217 |
218 | if trace:
219 | return diffusion_sequence
220 | else:
221 | return noise_patch.reshape(B, -1, 3)
222 |
--------------------------------------------------------------------------------
/model/Encoder_Component.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import DropPath, trunc_normal_
4 | import numpy as np
5 | from utils import misc
6 | import random
7 | from knn_cuda import KNN
8 |
9 | class Encoder(nn.Module): ## Embedding module
10 | def __init__(self, encoder_channel):
11 | super().__init__()
12 | self.encoder_channel = encoder_channel
13 | self.first_conv = nn.Sequential(
14 | nn.Conv1d(3, 128, 1),
15 | nn.BatchNorm1d(128),
16 | nn.ReLU(inplace=True),
17 | nn.Conv1d(128, 256, 1)
18 | )
19 | self.second_conv = nn.Sequential(
20 | nn.Conv1d(512, 512, 1),
21 | nn.BatchNorm1d(512),
22 | nn.ReLU(inplace=True),
23 | nn.Conv1d(512, self.encoder_channel, 1)
24 | )
25 |
26 | def forward(self, point_groups):
27 | '''
28 | point_groups : B G N 3
29 | -----------------
30 | feature_global : B G C
31 | '''
32 | bs, g, n, _ = point_groups.shape
33 | point_groups = point_groups.reshape(bs * g, n, 3)
34 | # encoder
35 | feature = self.first_conv(point_groups.transpose(2, 1)) # BG 256 n
36 | feature_global = torch.max(feature, dim=2, keepdim=True)[0] # BG 256 1
37 | feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) # BG 512 n
38 | feature = self.second_conv(feature) # BG 1024 n
39 | feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
40 | return feature_global.reshape(bs, g, self.encoder_channel)
41 |
42 |
43 | class Group(nn.Module): # FPS + KNN
44 | def __init__(self, num_group, group_size):
45 | super().__init__()
46 | self.num_group = num_group
47 | self.group_size = group_size
48 | self.knn = KNN(k=self.group_size, transpose_mode=True)
49 |
50 | def forward(self, xyz):
51 | '''
52 | input: B N 3
53 | ---------------------------
54 | output: B G M 3
55 | center : B G 3
56 | '''
57 | batch_size, num_points, _ = xyz.shape
58 | # fps the centers out
59 | center = misc.fps(xyz, self.num_group) # B G 3
60 | # knn to get the neighborhood
61 | _, idx = self.knn(xyz, center) # B G M
62 | assert idx.size(1) == self.num_group
63 | assert idx.size(2) == self.group_size
64 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
65 | idx = idx + idx_base
66 | idx = idx.view(-1)
67 | neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
68 | neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()
69 | # normalize
70 | neighborhood = neighborhood - center.unsqueeze(2)
71 | return neighborhood, center
72 |
73 |
74 | ## Transformers
75 | class Mlp(nn.Module):
76 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
77 | super().__init__()
78 | out_features = out_features or in_features
79 | hidden_features = hidden_features or in_features
80 | self.fc1 = nn.Linear(in_features, hidden_features)
81 | self.act = act_layer()
82 | self.fc2 = nn.Linear(hidden_features, out_features)
83 | self.drop = nn.Dropout(drop)
84 |
85 | def forward(self, x):
86 | x = self.fc1(x)
87 | x = self.act(x)
88 | x = self.drop(x)
89 | x = self.fc2(x)
90 | x = self.drop(x)
91 | return x
92 |
93 |
94 | class Attention(nn.Module):
95 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
96 | super().__init__()
97 | self.num_heads = num_heads
98 | head_dim = dim // num_heads
99 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
100 | self.scale = qk_scale or head_dim ** -0.5
101 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
102 | self.attn_drop = nn.Dropout(attn_drop)
103 | self.proj = nn.Linear(dim, dim)
104 | self.proj_drop = nn.Dropout(proj_drop)
105 |
106 | def forward(self, x):
107 | B, N, C = x.shape
108 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
109 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
110 |
111 | attn = (q @ k.transpose(-2, -1)) * self.scale
112 | attn = attn.softmax(dim=-1)
113 | attn = self.attn_drop(attn)
114 |
115 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
116 | x = self.proj(x)
117 | x = self.proj_drop(x)
118 | return x
119 |
120 |
121 | class Block(nn.Module):
122 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
123 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
124 | super().__init__()
125 | self.norm1 = norm_layer(dim)
126 |
127 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
128 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
129 | self.norm2 = norm_layer(dim)
130 | mlp_hidden_dim = int(dim * mlp_ratio)
131 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
132 |
133 | self.attn = Attention(
134 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
135 |
136 | def forward(self, x):
137 | x = x + self.drop_path(self.attn(self.norm1(x)))
138 | x = x + self.drop_path(self.mlp(self.norm2(x)))
139 | return x
140 |
141 |
142 | class TransformerEncoder(nn.Module):
143 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
144 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
145 | super().__init__()
146 |
147 | self.blocks = nn.ModuleList([
148 | Block(
149 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
150 | drop=drop_rate, attn_drop=attn_drop_rate,
151 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate
152 | )
153 | for i in range(depth)])
154 |
155 | def forward(self, x, pos):
156 | for _, block in enumerate(self.blocks):
157 | x = block(x + pos)
158 | return x
159 |
160 |
161 | # Pretrain model
162 | class PointTransformer(nn.Module):
163 | def __init__(self, config, **kwargs):
164 | super().__init__()
165 | self.config = config
166 | # define the transformer argparse
167 | self.trans_dim = config.trans_dim
168 | self.depth = config.encoder_depth
169 | self.drop_path_rate = config.drop_path_rate
170 | self.num_heads = config.encoder_num_heads
171 | # embedding
172 | self.encoder_dims = config.trans_dim
173 | self.encoder = Encoder(encoder_channel=self.encoder_dims)
174 | self.mask_type = config.mask_type
175 | self.mask_ratio = config.mask_ratio
176 |
177 | self.pos_embed = nn.Sequential(
178 | nn.Linear(3, 128),
179 | nn.GELU(),
180 | nn.Linear(128, self.trans_dim),
181 | )
182 |
183 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
184 | self.blocks = TransformerEncoder(
185 | embed_dim=self.trans_dim,
186 | depth=self.depth,
187 | drop_path_rate=dpr,
188 | num_heads=self.num_heads,
189 | )
190 |
191 | self.norm = nn.LayerNorm(self.trans_dim)
192 | self.apply(self._init_weights)
193 |
194 | def _init_weights(self, m):
195 | if isinstance(m, nn.Linear):
196 | trunc_normal_(m.weight, std=.02)
197 | if isinstance(m, nn.Linear) and m.bias is not None:
198 | nn.init.constant_(m.bias, 0)
199 | elif isinstance(m, nn.LayerNorm):
200 | nn.init.constant_(m.bias, 0)
201 | nn.init.constant_(m.weight, 1.0)
202 | elif isinstance(m, nn.Conv1d):
203 | trunc_normal_(m.weight, std=.02)
204 | if m.bias is not None:
205 | nn.init.constant_(m.bias, 0)
206 |
207 | def _mask_center_block(self, center, noaug=False):
208 | if noaug or self.mask_ratio == 0:
209 | return torch.zeros(center.shape[:2]).bool()
210 | mask_idx = []
211 | for points in center:
212 | points = points.unsqueeze(0)
213 | index = random.randint(0, points.size(1) - 1)
214 | distance_matrix = torch.norm(points[:, index].reshape(1,1,3) - points, p=2, dim=-1)
215 | idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0]
216 | ratio = self.mask_ratio
217 | mask_num = int(ratio * len(idx))
218 | mask = torch.zeros(len(idx))
219 | mask[idx[:mask_num]] = 1
220 | mask_idx.append(mask.bool())
221 | bool_masked_pos = torch.stack(mask_idx).to(center.device)
222 | return bool_masked_pos
223 |
224 | def _mask_center_rand(self, center, noaug=False):
225 | '''
226 | center : B G 3
227 | --------------
228 | mask : B G (bool)
229 | '''
230 | B, G, _ = center.shape
231 | # skip the mask
232 | if noaug or self.mask_ratio == 0:
233 | return torch.zeros(center.shape[:2]).bool()
234 |
235 | self.num_mask = int(self.mask_ratio * G)
236 |
237 | overall_mask = np.zeros([B, G])
238 | for i in range(B):
239 | mask = np.hstack([
240 | np.zeros(G - self.num_mask),
241 | np.ones(self.num_mask),
242 | ])
243 | np.random.shuffle(mask)
244 | overall_mask[i, :] = mask
245 | overall_mask = torch.from_numpy(overall_mask).to(torch.bool)
246 |
247 | return overall_mask.to(center.device) # B G
248 |
249 | def forward(self, neighborhood, center, noaug=False):
250 | # generate mask
251 | if self.mask_type == 'rand':
252 | bool_masked_pos = self._mask_center_rand(center, noaug=noaug) # B G
253 | else:
254 | bool_masked_pos = self._mask_center_block(center, noaug=noaug)
255 |
256 | group_input_tokens = self.encoder(neighborhood) # B G C
257 |
258 | batch_size, seq_len, C = group_input_tokens.size()
259 |
260 | p = self.pos_embed(center)
261 |
262 | z = self.blocks(group_input_tokens, p)
263 | z = self.norm(z)
264 | return z[~bool_masked_pos].reshape(batch_size, -1, C), bool_masked_pos, z[bool_masked_pos].reshape(batch_size, -1, C)
265 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # DiffPMAE
2 |
3 | In this work, we propose an effective point cloud reconstruction architecture, DiffPMAE.
4 | Inspired by self-supervised learning concepts, we combine the Masked Auto-Encoding and Diffusion
5 | Model mechanisms to remotely reconstruct point cloud data. DiffPMAE can be extended to many related
6 | downstream tasks, including point cloud compression, upsampling and completion, with minimal modifications.
7 |
8 | 
9 | ![DiffPMAE overview](assets/structure.png)
10 | 
11 | 
12 | GitHub repo: [https://github.com/DiffPMAE/DiffPMAE](https://github.com/DiffPMAE/DiffPMAE)
13 | [[arxiv]](https://arxiv.org/abs/2312.03298) [[poster]](https://eccv.ecva.net/virtual/2024/poster/1690) [[Project page]](https://tyraeldlee.github.io/DiffPMAE.github.io/)
14 | ## Datasets
15 | We use ShapeNet-55 and ModelNet40 for training and validation of the models, and PU1K for upsampling validation.
16 | All datasets should be placed under the dataset/ folder shown below; the scripts read them from there automatically.
17 | The overall directory structure should be:
18 | ```
19 | │DiffPMAE/
20 | ├──dataset/
21 | │ ├──ModelNet/
22 | │ ├──PU1K/
23 | │ └──ShapeNet55/
24 | ├──.......
25 | ```
26 |
27 | ### ShapeNet-55:
28 |
29 | ```
30 | │ShapeNet55/
31 | ├──ShapeNet-55/
32 | │ ├── train.txt
33 | │ └── test.txt
34 | ├──shapenet_pc/
35 | │ ├── 02691156-1a04e3eab45ca15dd86060f189eb133.npy
36 | │ ├── 02691156-1a6ad7a24bb89733f412783097373bdc.npy
37 | │ ├── .......
38 | ```
39 | Download: You can download the processed ShapeNet55 dataset from [Point-BERT](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md)
40 |
41 | ### ModelNet40:
42 |
43 | ```
44 | │ModelNet40/
45 | ├──modelnet40_shape_names.txt
46 | ├──modelnet40_test.txt
47 | ├──modelnet40_test_8192pts_fps.dat
48 | ├──modelnet40_train.txt
49 | └──modelnet40_train_8192pts_fps.dat
50 | ```
51 | Download: You can download the processed ModelNet40 dataset from [Point-BERT](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md)
52 |
53 | ### PU1K:
54 |
55 | ```
56 | │PU1K/
57 | ├──test/
58 | │ ├── input_256/
59 | │ ├── input_512/
60 | │ ├── input_1024/
61 | │ ├── input_2048/
62 | │ │ ├── gt_8192/
63 | │ │ │ ├── 11509_Panda_v4.xyz
64 | │ │ │ ├── .......
65 | │ │ ├── input_2048/
66 | │ │ │ ├── 11509_Panda_v4.xyz
67 | │ │ │ ├── .......
68 | │ └── original_meshes/
69 | │ │ ├── 11509_Panda_v4.off
70 | │ │ ├── .......
71 | ├──train/
72 | │ └── pu1k_poisson_256_poisson_1024_pc_2500_patch50_addpugan.h5
73 |
74 | ```
75 | Download: You can download the processed PU1K dataset from [PU-GCN](https://github.com/guochengqian/PU-GCN)
76 |
77 | ## Requirements
78 | - python >= 3.7
79 | - pytorch >= 1.13.1
80 | - CUDA >= 11.6
81 | ```
82 | pip install -r requirements.txt
83 | ```
84 |
85 | ```
86 | # PointNet++
87 | pip install "git+https://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib"
88 | # GPU kNN
89 | pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl
90 | ```
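91 | 
92 | A quick import check can confirm the compiled extensions are usable. This is a minimal sketch, assuming a CUDA device is available; it uses `knn_cuda.KNN` (imported by the model code) and `pointnet2_ops.pointnet2_utils`, which provides the furthest point sampling op:
93 | 
94 | ```python
95 | import torch
96 | from knn_cuda import KNN                     # GPU kNN
97 | from pointnet2_ops import pointnet2_utils    # PointNet++ ops (furthest point sampling)
98 | 
99 | pts = torch.rand(1, 2048, 3).cuda()
100 | center_idx = pointnet2_utils.furthest_point_sample(pts, 64)        # [1, 64] FPS indices
101 | _, knn_idx = KNN(k=32, transpose_mode=True)(pts, pts[:, :64, :])   # [1, 64, 32] neighbour indices
102 | print(center_idx.shape, knn_idx.shape)
103 | ```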
91 | ## Pre-trained Models
92 |
93 | Pre-trained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1lkEursyogjBY6HgjPwP6Gy33Y833nQOG?usp=sharing)
94 |
95 | The overall directory structure should be:
96 | ```
97 | │DiffPMAE/
98 | ├──dataset/
99 | ├──pretrain_model/
100 | │ ├──completion/
101 | │ ├──compress/
102 | │ ├──pretrain/
103 | │ └──sr/
104 | ├──.......
105 | ```
106 | ## Training
107 |
108 | For training, train the Encoder first using the command below, then use the pre-trained
109 | Encoder to train the Decoder.
110 |
111 | For encoder:
112 |
113 | ```
114 | CUDA_VISIBLE_DEVICES= python train_encoder.py
115 | ```
116 |
117 | Hyperparameter settings can be adjusted in train_encoder.py:
118 |
119 | ```python
120 | # Experiment setting
121 | parser.add_argument('--batch_size', type=int, default=4)
122 | parser.add_argument('--val_batch_size', type=int, default=1)
123 | parser.add_argument('--device', type=str, default='cuda') # mps for mac
124 | parser.add_argument('--save_dir', type=str, default='./results')
125 | parser.add_argument('--log', type=bool, default=False)
126 |
127 | # Grouping setting
128 | parser.add_argument('--mask_type', type=str, default='rand')
129 | parser.add_argument('--mask_ratio', type=float, default=0.75)
130 | parser.add_argument('--group_size', type=int, default=32)
131 | parser.add_argument('--num_group', type=int, default=64)
132 | parser.add_argument('--num_points', type=int, default=2048)
133 | parser.add_argument('--num_output', type=int, default=8192)
134 |
135 | # Transformer setting
136 | parser.add_argument('--trans_dim', type=int, default=384)
137 | parser.add_argument('--depth', type=int, default=12)
138 | parser.add_argument('--drop_path_rate', type=float, default=0.1)
139 | parser.add_argument('--num_heads', type=int, default=6)
140 |
141 | # Encoder setting
142 | parser.add_argument('--encoder_dims', type=int, default=384)
143 | parser.add_argument('--loss', type=str, default='cdl2')
144 |
145 | # sche / optim
146 | parser.add_argument('--learning_rate', type=float, default=0.001)
147 | parser.add_argument('--weight_decay', type=float, default=0.05)
148 | parser.add_argument('--eta_min', type=float, default=0.000001)
149 | parser.add_argument('--t_max', type=float, default=200)
150 | ```
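151 | 
152 | The `eta_min` / `t_max` arguments correspond to a cosine-annealing learning-rate schedule. Below is a minimal sketch of the assumed optimizer/scheduler wiring; the exact choice lives in train_encoder.py and may differ:
153 | 
154 | ```python
155 | import torch
156 | import torch.nn as nn
157 | 
158 | model = nn.Linear(3, 3)  # placeholder for the Encoder module built in train_encoder.py
159 | optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.05)
160 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200, eta_min=1e-6)
161 | ```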
151 |
152 | For decoder:
153 |
154 | ```
155 | CUDA_VISIBLE_DEVICES= python train_decoder.py
156 | ```
157 |
158 | To load the pre-trained Encoder, you can change the following in train_decoder.py:
159 |
160 | ```python
161 | check_point_dir = os.path.join('./pretrain_model/pretrain/encoder.pt')
162 |
163 | check_point = torch.load(check_point_dir)['model']
164 | encoder = Encoder_Module(args).to(args.device)
165 | encoder.load_state_dict(check_point)
166 | ```
167 |
168 | Hyperparameter settings for the Decoder can be adjusted in train_decoder.py:
169 |
170 | ```python
171 | # Experiment setting
172 | parser.add_argument('--batch_size', type=int, default=32)
173 | parser.add_argument('--val_batch_size', type=int, default=1)
174 | parser.add_argument('--device', type=str, default='cuda') # mps for mac
175 | parser.add_argument('--log', type=bool, default=True)
176 | parser.add_argument('--save_dir', type=str, default='./results')
177 |
178 | # Grouping setting
179 | parser.add_argument('--mask_type', type=str, default='rand')
180 | parser.add_argument('--mask_ratio', type=float, default=0.75)
181 | parser.add_argument('--group_size', type=int, default=32) # points in each group
182 | parser.add_argument('--num_group', type=int, default=64) # number of groups
183 | parser.add_argument('--num_points', type=int, default=2048)
184 | parser.add_argument('--num_output', type=int, default=8192)
185 | parser.add_argument('--diffusion_output_size', default=2048)
186 |
187 | # Transformer setting
188 | parser.add_argument('--trans_dim', type=int, default=384)
189 | parser.add_argument('--drop_path_rate', type=float, default=0.1)
190 |
191 | # Encoder setting
192 | parser.add_argument('--encoder_depth', type=int, default=12)
193 | parser.add_argument('--encoder_num_heads', type=int, default=6)
194 | parser.add_argument('--loss', type=str, default='cdl2')
195 |
196 | # Decoder setting
197 | parser.add_argument('--decoder_depth', type=int, default=4)
198 | parser.add_argument('--decoder_num_heads', type=int, default=4)
199 |
200 | # diffusion
201 | parser.add_argument('--num_steps', type=int, default=200)
202 | parser.add_argument('--beta_1', type=float, default=1e-4)
203 | parser.add_argument('--beta_T', type=float, default=0.05)
204 | parser.add_argument('--sched_mode', type=str, default='linear')
205 |
206 | # sche / optim
207 | parser.add_argument('--learning_rate', type=float, default=0.001)
208 | parser.add_argument('--weight_decay', type=float, default=0.05)
209 | parser.add_argument('--eta_min', type=float, default=0.000001)
210 | parser.add_argument('--t_max', type=float, default=200)
211 | ```
212 |
213 | ## Evaluation
214 | For pre-train model:
215 |
216 | ```
217 | python eval_diffpmae.py
218 | ```
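219 | 
220 | eval_diffpmae.py restores both the pre-trained Encoder (loaded as shown in the Training section) and the trained Decoder; the Decoder is wrapped in `DataParallel` before its checkpoint is loaded:
221 | 
222 | ```python
223 | diff_check_point = torch.load('./pretrain_model/pretrain/decoder.pt')['model']
224 | model = Diff_Point_MAE(args).to(args.device)
225 | model = torch.nn.DataParallel(model, device_ids=[0])  # wrap before loading, as in eval_diffpmae.py
226 | model.load_state_dict(diff_check_point)
227 | ```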
219 |
220 | For upsampling:
221 | ```
222 | python eval_upsampling.py
223 | ```
224 |
225 | For compression:
226 | ```
227 | python eval_compression.py
228 | ```
229 |
230 | The configuration for each task can be adjusted in the corresponding Python file.
231 | For example, the model configuration for pre-train evaluation can be adjusted in
232 | eval_diffpmae.py at L14~45.
233 |
234 | For experiment setup:
235 |
236 | ```python
237 | parser.add_argument('--batch_size', type=int, default=32)
238 | # Batch size
239 | parser.add_argument('--val_batch_size', type=int, default=1)
240 | # Validation batch size
241 | parser.add_argument('--device', type=str, default='cuda')
242 | parser.add_argument('--log', type=bool, default=True)
243 | # Neither the trained model nor the log will be saved when set to False.
244 | parser.add_argument('--save_dir', type=str, default='./results')
245 | # The root directory for saved files.
246 | ```
247 |
248 | For Grouping setting
249 |
250 | ```python
251 | parser.add_argument('--mask_type', type=str, default='rand')
252 | # Could be either rand or block
253 | parser.add_argument('--mask_ratio', type=float, default=0.75)
254 | parser.add_argument('--group_size', type=int, default=32)
255 | # Points in each group
256 | parser.add_argument('--num_group', type=int, default=64)
257 | # Number of groups
258 | parser.add_argument('--num_points', type=int, default=2048)
259 | # Input size of point cloud
260 | parser.add_argument('--num_output', type=int, default=8192)
261 | # Output size of Encoder module
262 | parser.add_argument('--diffusion_output_size', default=2048)
263 | # Output size of Decoder module
264 | ```
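265 | 
266 | With the defaults above, the 2048 input points are split into `num_group = 64` patches of `group_size = 32` points (64 × 32 = 2048); `mask_ratio = 0.75` leaves 16 visible patches (512 points) for the Encoder, and the Decoder reconstructs the remaining 48 masked patches (1536 points). A minimal sketch of that bookkeeping:
267 | 
268 | ```python
269 | num_points, num_group, group_size, mask_ratio = 2048, 64, 32, 0.75
270 | 
271 | num_mask = int(mask_ratio * num_group)                # 48 masked patches
272 | num_vis = num_group - num_mask                        # 16 visible patches
273 | assert num_group * group_size == num_points           # 64 * 32 == 2048
274 | print(num_vis * group_size, num_mask * group_size)    # 512 visible points, 1536 reconstructed points
275 | ```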
265 |
266 | For Transformer setting
267 |
268 | ```python
269 | parser.add_argument('--trans_dim', type=int, default=384)
270 | # Latent size
271 | parser.add_argument('--drop_path_rate', type=float, default=0.1)
272 | ```
273 |
274 | For Encoder setting
275 |
276 | ```python
277 | parser.add_argument('--encoder_depth', type=int, default=12)
278 | # Number of blocks in Encoder Transformer
279 | parser.add_argument('--encoder_num_heads', type=int, default=6)
280 | # Number of heads in each Transformer block
281 | parser.add_argument('--loss', type=str, default='cdl2')
282 | ```
283 |
284 | For Decoder setting
285 |
286 | ```python
287 | parser.add_argument('--decoder_depth', type=int, default=4)
288 | # Number of blocks in Decoder Transformer
289 | parser.add_argument('--decoder_num_heads', type=int, default=4)
290 | # Number of heads in each Transformer block
291 | ```
292 |
293 | For diffusion process
294 |
295 | ```python
296 | parser.add_argument('--num_steps', type=int, default=200)
297 | parser.add_argument('--beta_1', type=float, default=1e-4)
298 | parser.add_argument('--beta_T', type=float, default=0.05)
299 | parser.add_argument('--sched_mode', type=str, default='linear')
300 | ```
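301 | 
302 | These arguments define the forward-diffusion variance schedule: with `sched_mode = 'linear'`, the betas are spaced linearly from `beta_1` to `beta_T` over `num_steps` steps. A minimal sketch of such a schedule follows; the exact implementation lives in the Diffusion modules and may differ in detail:
303 | 
304 | ```python
305 | import torch
306 | 
307 | def linear_beta_schedule(beta_1=1e-4, beta_T=0.05, num_steps=200):
308 |     # Linearly spaced betas; alpha_bars[t] is the cumulative product used to noise x_0 at step t.
309 |     betas = torch.linspace(beta_1, beta_T, num_steps)
310 |     alphas = 1.0 - betas
311 |     alpha_bars = torch.cumprod(alphas, dim=0)
312 |     return betas, alpha_bars
313 | 
314 | betas, alpha_bars = linear_beta_schedule()
315 | print(betas[0].item(), betas[-1].item(), alpha_bars[-1].item())
316 | ```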
301 |
302 | For optimizer and scheduler
303 |
304 | ```python
305 | parser.add_argument('--learning_rate', type=float, default=0.001)
306 | parser.add_argument('--weight_decay', type=float, default=0.05)
307 | parser.add_argument('--eta_min', type=float, default=0.000001)
308 | parser.add_argument('--t_max', type=float, default=200)
309 | ```
310 | ## Acknowledgements
311 | Our code is built based on [Point-MAE](https://github.com/Pang-Yatian/Point-MAE).
312 |
313 | ## Citation
314 | ```bibtex
315 | @inproceedings{li2024diffpmae,
316 | author = {Yanlong Li and Chamara Madarasingha and Kanchana Thilakarathna},
317 | title = {DiffPMAE: Diffusion Masked Autoencoders for Point Cloud Reconstruction},
318 | booktitle = {ECCV},
319 | year = {2024},
320 | url = {https://arxiv.org/abs/2312.03298}
321 | }
322 | ```
323 |
--------------------------------------------------------------------------------
/metrics/pytorch_structural_losses/src/approxmatch.cu:
--------------------------------------------------------------------------------
1 | #include "utils.hpp"
2 |
3 | __global__ void approxmatchkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){
4 | float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
5 | float multiL,multiR;
6 | if (n>=m){
7 | multiL=1;
8 | multiR=n/m;
9 | }else{
10 | multiL=m/n;
11 | multiR=1;
12 | }
13 | const int Block=1024;
14 | __shared__ float buf[Block*4];
15 | for (int i=blockIdx.x;i=-2;j--){
24 | for (int j=7;j>=-2;j--){
25 | float level=-powf(4.0f,j);
26 | if (j==-2){
27 | level=0;
28 | }
29 | for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,out);
227 | //}
228 |
229 | __global__ void matchcostgrad2kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){
230 | __shared__ float sum_grad[256*3];
231 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad2);
294 | //}
295 |
296 | /*void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
297 | cudaStream_t stream)*/
298 | // temp: TensorShape{b,(n+m)*2}
299 | void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream){
300 | approxmatchkernel
301 | <<<32, 512, 0, stream>>>(b,n,m,xyz1,xyz2,match,temp);
302 |
303 | cudaError_t err = cudaGetLastError();
304 | if (cudaSuccess != err)
305 | throw std::runtime_error(Formatter()
306 | << "CUDA kernel failed : " << std::to_string(err));
307 | }
308 |
309 | void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream){
310 | matchcostkernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,out);
311 |
312 | cudaError_t err = cudaGetLastError();
313 | if (cudaSuccess != err)
314 | throw std::runtime_error(Formatter()
315 | << "CUDA kernel failed : " << std::to_string(err));
316 | }
317 |
318 | void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream){
319 | matchcostgrad1kernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,grad1);
320 | matchcostgrad2kernel<<<dim3(32,32),256,0,stream>>>(b,n,m,xyz1,xyz2,match,grad2);
321 |
322 | cudaError_t err = cudaGetLastError();
323 | if (cudaSuccess != err)
324 | throw std::runtime_error(Formatter()
325 | << "CUDA kernel failed : " << std::to_string(err));
326 | }
327 |
--------------------------------------------------------------------------------
/eval_diffpmae.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 |
4 | from torch.utils.data import DataLoader
5 | from utils.logger import *
6 | from utils.dataset import *
7 | from model.DiffusionPretrain import *
8 | from model.Encoder import *
9 | from utils.visualization import *
10 | from metrics.evaluation_metrics import compute_all_metrics, averaged_hausdorff_distance, jsd_between_point_cloud_sets
11 |
12 | parser = argparse.ArgumentParser()
13 | # Experiment setting
14 | parser.add_argument('--val_batch_size', type=int, default=1)
15 | parser.add_argument('--device', type=str, default='cuda') # mps for mac
16 | parser.add_argument('--log', type=bool, default=True)
17 | parser.add_argument('--save_dir', type=str, default='./results')
18 |
19 | # Grouping setting
20 | parser.add_argument('--mask_type', type=str, default='rand')
21 | parser.add_argument('--mask_ratio', type=float, default=0.75)
22 | parser.add_argument('--group_size', type=int, default=32) # points in each group
23 | parser.add_argument('--num_group', type=int, default=64) # number of groups
24 | parser.add_argument('--num_points', type=int, default=2048)
25 | parser.add_argument('--num_output', type=int, default=8192)
26 | parser.add_argument('--diffusion_output_size', type=int, default=2048)
27 |
28 | # Transformer setting
29 | parser.add_argument('--trans_dim', type=int, default=384)
30 | parser.add_argument('--drop_path_rate', type=float, default=0.1)
31 |
32 | # Encoder setting
33 | parser.add_argument('--encoder_depth', type=int, default=12)
34 | parser.add_argument('--encoder_num_heads', type=int, default=6)
35 | parser.add_argument('--loss', type=str, default='cdl2')
36 |
37 | # Decoder setting
38 | parser.add_argument('--decoder_depth', type=int, default=4)
39 | parser.add_argument('--decoder_num_heads', type=int, default=4)
40 |
41 | # diffusion
42 | parser.add_argument('--num_steps', type=int, default=200)
43 | parser.add_argument('--beta_1', type=float, default=1e-4)
44 | parser.add_argument('--beta_T', type=float, default=0.05)
45 | parser.add_argument('--sched_mode', type=str, default='linear')
46 |
47 |
48 | args = parser.parse_args()
49 | time_now = datetime.now().strftime("%Y_%m_%d_%H_%M")
50 | save_dir = os.path.join(args.save_dir, 'new_start_eval_{date}'.format(date=time_now))
51 |
52 | if args.log:
53 | if not os.path.exists(save_dir):
54 | os.makedirs(save_dir)
55 | log = logging.getLogger()
56 | log.setLevel(logging.INFO)
57 | formatter = logging.Formatter('[%(asctime)s::%(levelname)s] %(message)s')
58 | log_file = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
59 | log_file.setLevel(logging.INFO)
60 | log_file.setFormatter(formatter)
61 | log.addHandler(log_file)
62 |
63 | if args.log:
64 | log.info('loading dataset')
65 | print('loading dataset')
66 |
67 | test_dset_MN = ModelNet(
68 | root='dataset/ModelNet',
69 | number_pts=8192,
70 | use_normal=False,
71 | cats=40,
72 | subset='test'
73 | )
74 | val_loader_MN = DataLoader(test_dset_MN, batch_size=args.val_batch_size, pin_memory=True)
75 |
76 | val_dset = ShapeNet(
77 | data_path='dataset/ShapeNet55/ShapeNet-55',
78 | pc_path='dataset/ShapeNet55/shapenet_pc',
79 | subset='test',
80 | n_points=2048,
81 | downsample=True
82 | )
83 |
84 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, pin_memory=True)
85 |
86 | if args.log:
87 | log.info('Evaluating DiffPMAE')
88 | log.info('config:')
89 | log.info(args)
90 | log.info('dataset loaded')
91 | print('dataset loaded')
92 |
93 | print('loading model')
94 | check_point_dir = os.path.join('./pretrain_model/pretrain/encoder.pt')
95 | check_point = torch.load(check_point_dir)['model']
96 | encoder = Encoder_Module(args).to(args.device)
97 | encoder.load_state_dict(check_point)
98 | print('model loaded')
99 |
100 | diff_check_point_dir = os.path.join('./pretrain_model/pretrain/decoder.pt')
101 | diff_check_point = torch.load(diff_check_point_dir)['model']
102 | model = Diff_Point_MAE(args).to(args.device)
103 | model = torch.nn.DataParallel(model, device_ids=[0])
104 | model.load_state_dict(diff_check_point)
105 |
106 | PINK = [0xdd / 255, 0x83 / 255, 0xa2 / 255]
107 | BLUE = [0x81 / 255, 0xb0 / 255, 0xbf / 255]
108 | YELLOW = [0xf3 / 255, 0xdb / 255, 0x74 / 255]
109 | CYAN = [0x2b / 255, 0xda / 255, 0xc0 / 255]
110 | GRAY = [0xaa / 255, 0xaa / 255, 0xaa / 255]
111 |
112 |
113 | def visualization(index=-1):
114 | for i, batch in enumerate(val_loader):
115 | with torch.no_grad():
116 | if index == -1 or index == i:
117 | ref = batch['lr'].to(args.device)
118 | model.eval()
119 | encoder.eval()
120 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(ref, masked=False)
121 | print('sampling model {i}'.format(i=i))
122 | t = torch.randint(0, args.num_steps, (ref.size(0),))
123 | recons = model.module.sampling(x_vis, mask, center, True)
124 | eval_model_normal(ref.cpu().numpy(), 'ori', show_axis=False, show_img=False, save_img=True,
125 | save_location=save_dir, color=BLUE)
126 | eval_model_normal(vis_pc.cpu().numpy(), 'msk', show_axis=False, show_img=False, save_img=True,
127 | save_location=save_dir, color=PINK)
128 | for step in range(0, len(recons)):
129 | if step in [0, 10, 20, 30, 40, 50, 90, 150, 180]:
130 | eval_model_normal(recons[step].cpu().numpy(), step, show_axis=False, show_img=True, save_img=True,
131 | save_location=save_dir, color=BLUE)
132 | # recons = torch.cat([vis_pc.reshape(1, -1, 3), recons], dim=1)
133 | eval_model_normal(recons[199].cpu().numpy(), 199, show_axis=False, show_img=False, save_img=True,
134 | save_location=save_dir, color=BLUE)
135 | # vis_results(recons, ref, vis_pc, i, show_axis=True, save_img=True, show_img=False)
136 | eval_model_normal(torch.cat([recons[199], vis_pc], dim=1).cpu().numpy(), -2, show_axis=False,
137 | show_img=False, save_img=True, save_location=save_dir, color=BLUE)
138 | eval_model_normal(recons[199].cpu().numpy(), 'comb', second_part=vis_pc.cpu().numpy(), show_axis=False,
139 | show_img=True, save_img=True, save_location=save_dir, color=BLUE, color2=PINK)
140 | eval_model_normal(msk_pc.cpu().numpy(), 'segment', second_part=vis_pc.cpu().numpy(), show_axis=False,
141 | show_img=False, save_img=True, save_location=save_dir, color=GRAY, color2=PINK)
142 | if index <= i:
143 | break
144 |
145 | def calculate_metric_all_on_prediction(size=-1, enc=encoder, pred=model):
146 | all_sample = []
147 | all_ref = []
148 | all_hd = []
149 | for i, batch in enumerate(val_loader):
150 | if i == size:
151 | break
152 | with torch.no_grad():
153 | ref = batch['lr'].to(args.device)
154 | pred.eval()
155 | x_vis, z_masked, mask, center, vis_pc, msk_pc = enc.encode(ref, masked=False)
156 | recons = pred.module.sampling(x_vis, mask, center, False)
157 | recons = torch.cat([recons, vis_pc], dim=1)
158 | all_sample.append(recons)
159 | all_ref.append(ref)
160 | hd = averaged_hausdorff_distance(recons.squeeze(), ref.squeeze())
161 | all_hd.append(hd)
162 |
163 | print("evaluating model: {i}".format(i=i))
164 | log.info("evaluating model: {i}".format(i=i))
165 |
166 | mean_hd = sum(all_hd) / len(all_hd)
167 | sample = torch.cat(all_sample, dim=0)
168 | refpc = torch.cat(all_ref, dim=0)
169 | all = compute_all_metrics(sample, refpc, 32)
170 | print(
171 | "MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format(
172 | mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'],
173 | N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd))
174 | log.info(
175 | "MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format(
176 | mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'],
177 | N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd))
178 |
179 |
180 | def auto_run():
181 | mask_ratio = [0.75]
182 | for v in mask_ratio:
183 | args.mask_ratio = v
184 | log.info("mask ratio: {mr}".format(mr=v))
185 | print('loading encoder')
186 | check_point_dir_l = os.path.join('./pretrain_model/pretrain/encoder.pt')
187 | check_point_l = torch.load(check_point_dir_l)['model']
188 | encoder_l = Encoder_Module(args).to(args.device)
189 | encoder_l.load_state_dict(check_point_l)
190 | print('encoder loaded\rloading model')
191 |
192 | diff_check_point_dir_l = os.path.join('./pretrain_model/pretrain/decoder.pt')
193 | diff_check_point_l = torch.load(diff_check_point_dir_l)['model']
194 | model_l = Diff_Point_MAE(args).to(args.device)
195 | model_l = torch.nn.DataParallel(model_l, device_ids=[0])
196 | model_l.load_state_dict(diff_check_point_l)
197 | print('model loaded')
198 | calculate_metric_all_on_prediction(2048, encoder_l, model_l)
199 |
200 |
201 | # auto_run()
202 |
203 | def calculate_metric_all_set():
204 | all_cd = []
205 | all_sample = []
206 | all_ref = []
207 | all_hd = []
208 | for i, batch in enumerate(val_loader):
209 | with torch.no_grad():
210 | ref = batch['lr'].to(args.device)
211 | model.eval()
212 | encoder.eval()
213 | x_vis, z_masked, mask, center, vis_pc, msk_pc = encoder.encode(ref, masked=False)
214 | recons = model.module.sampling(x_vis, mask, center, False)
215 | recons = torch.cat([vis_pc, recons], dim=1)
216 | all_sample.append(recons)
217 | all_ref.append(ref)
218 | hd = averaged_hausdorff_distance(recons.squeeze(), ref.squeeze())
219 | all_hd.append(hd)
220 |
221 | cd = model.module.loss_func(recons, ref)
222 | all_cd.append(cd)
223 |
224 | print("evaluating model: {i}".format(i=i))
225 | # log.info("evaluating model: {i}, CD={cd}".format(i=i, cd=cd))
226 |
227 | mean_cd = sum(all_cd) / len(all_cd)
228 | mean_hd = sum(all_hd) / len(all_hd)
229 | sample = torch.cat(all_sample, dim=0)
230 | refpc = torch.cat(all_ref, dim=0)
231 | JSD = jsd_between_point_cloud_sets(sample, refpc)
232 | # all = compute_all_metrics(sample, refpc, 32)
233 | # print("MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'], N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd))
234 | # log.info("MMD CD: {mmd}, \r\nCOV CD: {cov}, \r\nMMD-SMP CD: {mmd_smp}, \r\n1NN CD-t: {N_t}, \r\n1NN CD-f: {N_f}, \r\n1NN CD: {N}\r\nJSD: {jsd}\r\nHD: {hd}".format(mmd=all['lgan_mmd-CD'], cov=all['lgan_cov-CD'], mmd_smp=all['lgan_mmd_smp-CD'], N_t=all['1-NN-CD-acc_t'], N_f=all['1-NN-CD-acc_f'], N=all['1-NN-CD-acc'], jsd=all['JSD'], hd=mean_hd))
235 | print("Mean CD: {mCD}, Mean HD: {mHD}, JSD: {jsd}".format(mCD=mean_cd, mHD=mean_hd, jsd=JSD))
236 | log.info("Mean CD: {mCD}, Mean HD: {mHD}, JSD: {jsd}".format(mCD=mean_cd, mHD=mean_hd, jsd=JSD))
237 |
238 | calculate_metric_all_set()
--------------------------------------------------------------------------------
/metrics/evaluation_metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | From https://github.com/stevenygd/PointFlow/tree/master/metrics
3 | """
4 | import torch
5 | import numpy as np
6 | import warnings
7 | from scipy.stats import entropy
8 | from sklearn.neighbors import NearestNeighbors
9 | from numpy.linalg import norm
10 | from tqdm.auto import tqdm
11 |
12 | _EMD_NOT_IMPL_WARNED = True
13 |
14 |
15 | def emd_approx(sample, ref):
16 | global _EMD_NOT_IMPL_WARNED
17 | emd = torch.zeros([sample.size(0)]).to(sample)
18 | if not _EMD_NOT_IMPL_WARNED:
19 | _EMD_NOT_IMPL_WARNED = True
20 | print('\n\n[WARNING]')
21 | print(' * EMD is not implemented due to GPU compatibility issue.')
22 | print(' * We will set all EMD to zero by default.')
23 | print(' * You may implement your own EMD in the function `emd_approx` in ./evaluation/evaluation_metrics.py')
24 | print('\n')
25 | return emd
26 |
27 |
28 | # Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
29 | def distChamfer(a, b):
30 | x, y = a, b
31 | bs, num_points, points_dim = x.size()
32 | xx = torch.bmm(x, x.transpose(2, 1))
33 | yy = torch.bmm(y, y.transpose(2, 1))
34 | zz = torch.bmm(x, y.transpose(2, 1))
35 | diag_ind = torch.arange(0, num_points).to(a).long()
36 | rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
37 | ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
38 | P = (rx.transpose(2, 1) + ry - 2 * zz)
39 | return P.min(1)[0], P.min(2)[0]
40 |
41 |
42 | def chamfer_distance_l1(a, b):
43 | d1, d2 = distChamfer(a, b)
44 | d1 = torch.sqrt(d1)
45 | d2 = torch.sqrt(d2)
46 | return (torch.mean(d1) + torch.mean(d2)) / 2
47 |
48 |
49 | def chamfer_distance_l2(a, b):
50 | d1, d2 = distChamfer(a, b)
51 | return torch.mean(d1) + torch.mean(d2)
52 |
53 |
54 | def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
55 | N_sample = sample_pcs.shape[0]
56 | N_ref = ref_pcs.shape[0]
57 | assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
58 |
59 | cd_lst = []
60 | emd_lst = []
61 | iterator = range(0, N_sample, batch_size)
62 |
63 | for b_start in tqdm(iterator, desc='EMD-CD'):
64 | b_end = min(N_sample, b_start + batch_size)
65 | sample_batch = sample_pcs[b_start:b_end]
66 | ref_batch = ref_pcs[b_start:b_end]
67 |
68 | dl, dr = distChamfer(sample_batch, ref_batch)
69 | cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))
70 |
71 | emd_batch = emd_approx(sample_batch, ref_batch)
72 | emd_lst.append(emd_batch)
73 |
74 | if reduced:
75 | cd = torch.cat(cd_lst).mean()
76 | emd = torch.cat(emd_lst).mean()
77 | else:
78 | cd = torch.cat(cd_lst)
79 | emd = torch.cat(emd_lst)
80 |
81 | results = {
82 | 'MMD-CD': cd,
83 | 'MMD-EMD': emd,
84 | }
85 | return results
86 |
87 |
88 | def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, verbose=True):
89 | N_sample = sample_pcs.shape[0]
90 | N_ref = ref_pcs.shape[0]
91 | all_cd = []
92 | iterator = range(N_sample)
93 | if verbose:
94 | iterator = tqdm(iterator, desc='Pairwise EMD-CD')
95 | for sample_b_start in iterator:
96 | sample_batch = sample_pcs[sample_b_start]
97 |
98 | cd_lst = []
99 |
100 | sub_iterator = range(0, N_ref, batch_size)
101 | # if verbose:
102 | # sub_iterator = tqdm(sub_iterator, leave=False)
103 | for ref_b_start in sub_iterator:
104 | ref_b_end = min(N_ref, ref_b_start + batch_size)
105 | ref_batch = ref_pcs[ref_b_start:ref_b_end]
106 |
107 | batch_size_ref = ref_batch.size(0)
108 | point_dim = ref_batch.size(2)
109 | sample_batch_exp = sample_batch.view(1, -1, point_dim).expand(
110 | batch_size_ref, -1, -1)
111 | sample_batch_exp = sample_batch_exp.contiguous()
112 |
113 | dl, dr = distChamfer(sample_batch_exp, ref_batch)
114 | cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))
115 |
116 | cd_lst = torch.cat(cd_lst, dim=1)
117 | all_cd.append(cd_lst)
118 |
119 | all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref
120 |
121 | return all_cd
122 |
123 |
124 | # Adapted from https://github.com/xuqiantong/
125 | # GAN-Metrics/blob/master/framework/metric.py
126 | def knn(Mxx, Mxy, Myy, k, sqrt=False):
127 | n0 = Mxx.size(0)
128 | n1 = Myy.size(0)
129 | label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx)
130 | M = torch.cat([
131 | torch.cat((Mxx, Mxy), 1),
132 | torch.cat((Mxy.transpose(0, 1), Myy), 1)], 0)
133 | if sqrt:
134 | M = M.abs().sqrt()
135 | INFINITY = float('inf')
136 | val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(
137 | k, 0, False)
138 |
139 | count = torch.zeros(n0 + n1).to(Mxx)
140 | for i in range(0, k):
141 | count = count + label.index_select(0, idx[i])
142 | pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
143 |
144 | s = {
145 | 'tp': (pred * label).sum(),
146 | 'fp': (pred * (1 - label)).sum(),
147 | 'fn': ((1 - pred) * label).sum(),
148 | 'tn': ((1 - pred) * (1 - label)).sum(),
149 | }
150 |
151 | s.update({
152 | 'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),
153 | 'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
154 | 'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
155 | 'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),
156 | 'acc': torch.eq(label, pred).float().mean(),
157 | })
158 | return s
159 |
160 |
161 | def lgan_mmd_cov(all_dist):
162 | N_sample, N_ref = all_dist.size(0), all_dist.size(1)
163 | min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
164 | min_val, _ = torch.min(all_dist, dim=0)
165 | mmd = min_val.mean()
166 | mmd_smp = min_val_fromsmp.mean()
167 | cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
168 | cov = torch.tensor(cov).to(all_dist)
169 | return {
170 | 'lgan_mmd': mmd,
171 | 'lgan_cov': cov,
172 | 'lgan_mmd_smp': mmd_smp,
173 | }
174 |
175 |
176 | def lgan_mmd_cov_match(all_dist):
177 | N_sample, N_ref = all_dist.size(0), all_dist.size(1)
178 | min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
179 | min_val, _ = torch.min(all_dist, dim=0)
180 | mmd = min_val.mean()
181 | mmd_smp = min_val_fromsmp.mean()
182 | cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
183 | cov = torch.tensor(cov).to(all_dist)
184 | return {
185 | 'lgan_mmd': mmd,
186 | 'lgan_cov': cov,
187 | 'lgan_mmd_smp': mmd_smp,
188 | }, min_idx.view(-1)
189 |
190 |
191 | def compute_all_metrics(sample_pcs, ref_pcs, batch_size):
192 | results = {}
193 |
194 | # print("Pairwise EMD CD")
195 | M_rs_cd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size)
196 |
197 | ## CD
198 | res_cd = lgan_mmd_cov(M_rs_cd.t())
199 | results.update({
200 | "%s-CD" % k: v for k, v in res_cd.items()
201 | })
202 |
203 | for k, v in results.items():
204 | print('[%s] %.8f' % (k, v.item()))
205 |
206 | M_rr_cd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
207 | M_ss_cd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)
208 |
209 | # 1-NN results
210 | ## CD
211 | one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
212 | results.update({
213 | "1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
214 | })
215 |
216 | # JSD results
217 | jsd = jsd_between_point_cloud_sets(sample_pcs, ref_pcs)
218 | results.update({
219 | "JSD": jsd
220 | })
221 |
222 |
223 |
224 | return results
225 |
226 |
227 | #######################################################
228 | # JSD : from https://github.com/optas/latent_3d_points
229 | #######################################################
230 | def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
231 | """Returns the center coordinates of each cell of a 3D grid with
232 | resolution^3 cells, that is placed in the unit-cube. If clip_sphere is True
233 | it drops the "corner" cells that lie outside the unit-sphere.
234 | """
235 | grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
236 | spacing = 1.0 / float(resolution - 1)
237 | for i in range(resolution):
238 | for j in range(resolution):
239 | for k in range(resolution):
240 | grid[i, j, k, 0] = i * spacing - 0.5
241 | grid[i, j, k, 1] = j * spacing - 0.5
242 | grid[i, j, k, 2] = k * spacing - 0.5
243 |
244 | if clip_sphere:
245 | grid = grid.reshape(-1, 3)
246 | grid = grid[norm(grid, axis=1) <= 0.5]
247 |
248 | return grid, spacing
249 |
250 |
251 | def jsd_between_point_cloud_sets(
252 | sample_pcs, ref_pcs, resolution=28):
253 | """Computes the JSD between two sets of point-clouds,
254 | as introduced in the paper
255 | ```Learning Representations And Generative Models For 3D Point Clouds```.
256 | Args:
257 | sample_pcs: (np.ndarray S1xR1x3) S1 point-clouds, each of R1 points.
258 | ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
259 | resolution: (int) grid-resolution. Affects granularity of measurements.
260 | """
261 | sample_pcs = sample_pcs.cpu().numpy()
262 | ref_pcs = ref_pcs.cpu().numpy()
263 | in_unit_sphere = True
264 | sample_grid_var = entropy_of_occupancy_grid(
265 | sample_pcs, resolution, in_unit_sphere)[1]
266 | ref_grid_var = entropy_of_occupancy_grid(
267 | ref_pcs, resolution, in_unit_sphere)[1]
268 | return jensen_shannon_divergence(sample_grid_var, ref_grid_var)
269 |
270 |
271 | def entropy_of_occupancy_grid(
272 | pclouds, grid_resolution, in_sphere=False, verbose=False):
273 | """Given a collection of point-clouds, estimate the entropy of
274 | the random variables corresponding to occupancy-grid activation patterns.
275 | Inputs:
276 | pclouds: (numpy array) #point-clouds x points per point-cloud x 3
277 | grid_resolution (int) size of occupancy grid that will be used.
278 | """
279 | epsilon = 10e-4
280 | bound = 0.5 + epsilon
281 | if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
282 | if verbose:
283 | warnings.warn('Point-clouds are not in unit cube.')
284 |
285 | if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
286 | if verbose:
287 | warnings.warn('Point-clouds are not in unit sphere.')
288 |
289 | grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
290 | grid_coordinates = grid_coordinates.reshape(-1, 3)
291 | grid_counters = np.zeros(len(grid_coordinates))
292 | grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
293 | nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)
294 |
295 | for pc in tqdm(pclouds, desc='JSD'):
296 | _, indices = nn.kneighbors(pc)
297 | indices = np.squeeze(indices)
298 | for i in indices:
299 | grid_counters[i] += 1
300 | indices = np.unique(indices)
301 | for i in indices:
302 | grid_bernoulli_rvars[i] += 1
303 |
304 | acc_entropy = 0.0
305 | n = float(len(pclouds))
306 | for g in grid_bernoulli_rvars:
307 | if g > 0:
308 | p = float(g) / n
309 | acc_entropy += entropy([p, 1.0 - p])
310 |
311 | return acc_entropy / len(grid_counters), grid_counters
312 |
313 |
314 | def jensen_shannon_divergence(P, Q):
315 | if np.any(P < 0) or np.any(Q < 0):
316 | raise ValueError('Negative values.')
317 | if len(P) != len(Q):
318 | raise ValueError('Non equal size.')
319 |
320 | P_ = P / np.sum(P) # Ensure probabilities.
321 | Q_ = Q / np.sum(Q)
322 |
323 | e1 = entropy(P_, base=2)
324 | e2 = entropy(Q_, base=2)
325 | e_sum = entropy((P_ + Q_) / 2.0, base=2)
326 | res = e_sum - ((e1 + e2) / 2.0)
327 |
328 | res2 = _jsdiv(P_, Q_)
329 |
330 | if not np.allclose(res, res2, atol=10e-5, rtol=0):
331 | warnings.warn('Numerical values of two JSD methods don\'t agree.')
332 |
333 | return res
334 |
335 |
336 | def _jsdiv(P, Q):
337 | """another way of computing JSD"""
338 |
339 | def _kldiv(A, B):
340 | a = A.copy()
341 | b = B.copy()
342 | idx = np.logical_and(a > 0, b > 0)
343 | a = a[idx]
344 | b = b[idx]
345 | return np.sum([v for v in a * np.log2(a / b)])
346 |
347 | P_ = P / np.sum(P)
348 | Q_ = Q / np.sum(Q)
349 |
350 | M = 0.5 * (P_ + Q_)
351 |
352 | return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
353 |
354 | from sklearn.metrics.pairwise import pairwise_distances
355 | def averaged_hausdorff_distance(sample_pcs, ref_pcs, max_ahd=np.inf):
356 | sample_pcs = sample_pcs.cpu().numpy()
357 | ref_pcs = ref_pcs.cpu().numpy()
358 |
359 | if len(sample_pcs) == 0 or len(ref_pcs) == 0:
360 | return max_ahd
361 |
362 | sample_pcs = np.array(sample_pcs)
363 | ref_pcs = np.array(ref_pcs)
364 |
365 | assert sample_pcs.ndim == 2, 'got %s' % sample_pcs.ndim
366 | assert ref_pcs.ndim == 2, 'got %s' % ref_pcs.ndim
367 |
368 | assert sample_pcs.shape[1] == ref_pcs.shape[1], \
369 | 'The points in both sets must have the same number of dimensions, got %s and %s.' \
370 | % (sample_pcs.shape[1], ref_pcs.shape[1])
371 |
372 | d2_matrix = pairwise_distances(sample_pcs, ref_pcs, metric='euclidean')
373 |
374 | res = np.average(np.min(d2_matrix, axis=0)) + \
375 | np.average(np.min(d2_matrix, axis=1))
376 |
377 | return res
378 |
379 |
380 | def hausdorff_distance(pred, ref):
381 | A_ab, A_ba = distChamfer(pred, ref)
382 | h_ab = torch.max(A_ab - A_ba)
383 | B_ab, B_ba = distChamfer(ref, pred)
384 | h_ba = torch.max(B_ab - B_ba)
385 |
386 | return torch.max(h_ab, h_ba) / 2
387 |
388 | if __name__ == '__main__':
389 | a = torch.randn([16, 2048, 3]).cuda()
390 | b = torch.randn([16, 2048, 3]).cuda()
391 | print(EMD_CD(a, b, batch_size=8))
392 |
--------------------------------------------------------------------------------
/model/EncoderCop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from timm.models.layers import DropPath, trunc_normal_
5 | import numpy as np
6 | from utils import misc
7 | import random
8 | from knn_cuda import KNN
9 | from metrics.evaluation_metrics import chamfer_distance_l1, chamfer_distance_l2
10 |
11 |
12 |
13 | class Encoder(nn.Module): ## Embedding module
14 | def __init__(self, encoder_channel):
15 | super().__init__()
16 | self.encoder_channel = encoder_channel
17 | self.first_conv = nn.Sequential(
18 | nn.Conv1d(3, 128, 1),
19 | nn.BatchNorm1d(128),
20 | nn.ReLU(inplace=True),
21 | nn.Conv1d(128, 256, 1)
22 | )
23 | self.second_conv = nn.Sequential(
24 | nn.Conv1d(512, 512, 1),
25 | nn.BatchNorm1d(512),
26 | nn.ReLU(inplace=True),
27 | nn.Conv1d(512, self.encoder_channel, 1)
28 | )
29 |
30 | def forward(self, point_groups):
31 | '''
32 | point_groups : B G N 3
33 | -----------------
34 | feature_global : B G C
35 | '''
36 | bs, g, n, _ = point_groups.shape
37 | point_groups = point_groups.reshape(bs * g, n, 3)
38 | # encoder
39 | feature = self.first_conv(point_groups.transpose(2, 1)) # BG 256 n
40 | feature_global = torch.max(feature, dim=2, keepdim=True)[0] # BG 256 1
41 | feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) # BG 512 n
42 | feature = self.second_conv(feature) # BG 1024 n
43 | feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
44 | return feature_global.reshape(bs, g, self.encoder_channel)
45 |
46 |
47 | class Group(nn.Module): # FPS + KNN
48 | def __init__(self, num_group, group_size):
49 | super().__init__()
50 | self.num_group = num_group
51 | self.group_size = group_size
52 | self.knn = KNN(k=self.group_size, transpose_mode=True)
53 |
54 | def forward(self, xyz):
55 | '''
56 | input: B N 3
57 | ---------------------------
58 | output: B G M 3
59 | center : B G 3
60 | '''
61 | batch_size, num_points, _ = xyz.shape
62 | # fps the centers out
63 | center = misc.fps(xyz, self.num_group) # B G 3
64 | # knn to get the neighborhood
65 | _, idx = self.knn(xyz, center) # B G M
66 | assert idx.size(1) == self.num_group
67 | assert idx.size(2) == self.group_size
68 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
69 | idx = idx + idx_base
70 | idx = idx.view(-1)
71 | neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
72 | neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()
73 | # normalize
74 | neighborhood = neighborhood - center.unsqueeze(2)
75 | return neighborhood, center
76 |
77 |
78 | ## Transformers
79 | class Mlp(nn.Module):
80 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
81 | super().__init__()
82 | out_features = out_features or in_features
83 | hidden_features = hidden_features or in_features
84 | self.fc1 = nn.Linear(in_features, hidden_features)
85 | self.act = act_layer()
86 | self.fc2 = nn.Linear(hidden_features, out_features)
87 | self.drop = nn.Dropout(drop)
88 |
89 | def forward(self, x):
90 | x = self.fc1(x)
91 | x = self.act(x)
92 | x = self.drop(x)
93 | x = self.fc2(x)
94 | x = self.drop(x)
95 | return x
96 |
97 |
98 | class Attention(nn.Module):
99 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
100 | super().__init__()
101 | self.num_heads = num_heads
102 | head_dim = dim // num_heads
103 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
104 | self.scale = qk_scale or head_dim ** -0.5
105 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
106 | self.attn_drop = nn.Dropout(attn_drop)
107 | self.proj = nn.Linear(dim, dim)
108 | self.proj_drop = nn.Dropout(proj_drop)
109 |
110 | def forward(self, x):
111 | B, N, C = x.shape
112 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
113 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
114 |
115 | attn = (q @ k.transpose(-2, -1)) * self.scale
116 | attn = attn.softmax(dim=-1)
117 | attn = self.attn_drop(attn)
118 |
119 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
120 | x = self.proj(x)
121 | x = self.proj_drop(x)
122 | return x
123 |
124 |
125 | class Block(nn.Module):
126 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
127 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
128 | super().__init__()
129 | self.norm1 = norm_layer(dim)
130 |
131 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
132 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
133 | self.norm2 = norm_layer(dim)
134 | mlp_hidden_dim = int(dim * mlp_ratio)
135 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
136 |
137 | self.attn = Attention(
138 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
139 |
140 | def forward(self, x):
141 | x = x + self.drop_path(self.attn(self.norm1(x)))
142 | x = x + self.drop_path(self.mlp(self.norm2(x)))
143 | return x
144 |
145 |
146 | class TransformerEncoder(nn.Module):
147 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
148 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
149 | super().__init__()
150 |
151 | self.blocks = nn.ModuleList([
152 | Block(
153 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
154 | drop=drop_rate, attn_drop=attn_drop_rate,
155 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate
156 | )
157 | for i in range(depth)])
158 |
159 | def forward(self, x, pos):
160 | for _, block in enumerate(self.blocks):
161 | x = block(x + pos)
162 | return x
163 |
164 |
165 | # Pretrain model
166 | class PointTransformer(nn.Module):
167 | def __init__(self, config, **kwargs):
168 | super().__init__()
169 | self.config = config
170 | # define the transformer argparse
171 | self.trans_dim = config.trans_dim
172 | self.depth = config.depth
173 | self.drop_path_rate = config.drop_path_rate
174 | self.num_heads = config.num_heads
175 | # embedding
176 | self.encoder_dims = config.encoder_dims
177 | self.encoder = Encoder(encoder_channel=self.encoder_dims)
178 | self.mask_type = config.mask_type
179 | self.mask_ratio = config.mask_ratio
180 |
181 | self.pos_embed = nn.Sequential(
182 | nn.Linear(3, 128),
183 | nn.GELU(),
184 | nn.Linear(128, self.trans_dim),
185 | )
186 |
187 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
188 | self.blocks = TransformerEncoder(
189 | embed_dim=self.trans_dim,
190 | depth=self.depth,
191 | drop_path_rate=dpr,
192 | num_heads=self.num_heads,
193 | )
194 |
195 | self.norm = nn.LayerNorm(self.trans_dim)
196 | self.apply(self._init_weights)
197 |
198 | def _init_weights(self, m):
199 | if isinstance(m, nn.Linear):
200 | trunc_normal_(m.weight, std=.02)
201 | if isinstance(m, nn.Linear) and m.bias is not None:
202 | nn.init.constant_(m.bias, 0)
203 | elif isinstance(m, nn.LayerNorm):
204 | nn.init.constant_(m.bias, 0)
205 | nn.init.constant_(m.weight, 1.0)
206 | elif isinstance(m, nn.Conv1d):
207 | trunc_normal_(m.weight, std=.02)
208 | if m.bias is not None:
209 | nn.init.constant_(m.bias, 0)
210 |
211 | def _mask_center_block(self, center, noaug=False):
212 | if noaug or self.mask_ratio == 0:
213 | return torch.zeros(center.shape[:2]).bool()
214 | mask_idx = []
215 | for points in center:
216 | points = points.unsqueeze(0)
217 | index = random.randint(0, points.size(1) - 1)
218 | distance_matrix = torch.norm(points[:, index].reshape(1,1,3) - points, p=2, dim=-1)
219 | idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0]
220 | ratio = self.mask_ratio
221 | mask_num = int(ratio * len(idx))
222 | mask = torch.zeros(len(idx))
223 | mask[idx[:mask_num]] = 1
224 | mask_idx.append(mask.bool())
225 | bool_masked_pos = torch.stack(mask_idx).to(center.device)
226 | return bool_masked_pos
227 | def _mask_center_rand(self, center, noaug=False):
228 | '''
229 | center : B G 3
230 | --------------
231 | mask : B G (bool)
232 | '''
233 | B, G, _ = center.shape
234 | # skip the mask
235 | if noaug or self.mask_ratio == 0:
236 | return torch.zeros(center.shape[:2]).bool()
237 |
238 | self.num_mask = int(self.mask_ratio * G)
239 |
240 | overall_mask = np.zeros([B, G])
241 | for i in range(B):
242 | mask = np.hstack([
243 | np.zeros(G - self.num_mask),
244 | np.ones(self.num_mask),
245 | ])
246 | np.random.shuffle(mask)
247 | overall_mask[i, :] = mask
248 | overall_mask = torch.from_numpy(overall_mask).to(torch.bool)
249 |
250 | return overall_mask.to(center.device) # B G
251 |
252 | def forward(self, neighborhood, center, noaug=False):
253 | # generate mask
254 | if self.mask_type == 'rand':
255 | bool_masked_pos = self._mask_center_rand(center, noaug=noaug) # B G
256 | else:
257 | bool_masked_pos = self._mask_center_block(center, noaug=noaug)
258 | B, G, N, C = neighborhood.shape
259 | vis = neighborhood[~bool_masked_pos].reshape(B, -1, N, C)
260 | group_input_tokens = self.encoder(vis) # B G C
261 |
262 | batch_size, seq_len, L = group_input_tokens.size()
263 | vis_center = center[~bool_masked_pos].reshape(B, -1, C)
264 | p = self.pos_embed(vis_center)
265 |
266 | z = self.blocks(group_input_tokens, p)
267 | z = self.norm(z)
268 | return z.reshape(batch_size, -1, L), bool_masked_pos
269 |
270 | class Encoder_Module(nn.Module):
271 | def __init__(self, config):
272 | super().__init__()
273 | self.config = config
274 | self.trans_dim = config.trans_dim
275 | self.AE_encoder = PointTransformer(config)
276 | self.group_size = config.group_size
277 | self.num_group = config.num_group
278 | self.num_output = config.num_output
279 | self.mask_ratio = config.mask_ratio
280 | self.num_channel = 3
281 | self.drop_path_rate = config.drop_path_rate
282 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
283 |
284 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
285 |
286 | # prediction head
287 | self.increase_dim = nn.Sequential(
288 | nn.Conv1d(self.trans_dim, (self.num_channel * int(self.num_output * (1 - self.mask_ratio))) // self.num_group, 1)
289 | )
290 |
291 | trunc_normal_(self.mask_token, std=.02)
292 | self.loss = config.loss
293 | # loss
294 | self.build_loss_func(self.loss)
295 |
296 | def build_loss_func(self, loss_type):
297 | if loss_type == "cdl1":
298 | self.loss_func = chamfer_distance_l1
299 | elif loss_type == 'cdl2':
300 | self.loss_func = chamfer_distance_l2
301 | elif loss_type == 'mse':
302 | self.loss_func = F.mse_loss
303 | else:
304 | raise NotImplementedError
305 |
306 | def forward(self, pts):
307 | x_vis, mask, center, vis_pc, msk_pc = self.encode(pts, False)
308 |
309 | B, _, C = x_vis.shape # B VIS C
310 |
311 |         rebuild_points = self.increase_dim(x_vis.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3)  # (B, N_rebuild, 3) points predicted from the visible tokens
312 | full = torch.cat([rebuild_points, msk_pc], dim=1)
313 | loss1 = self.loss_func(full, pts)
314 | return loss1
315 |
316 | def encode(self, pt, masked=False):
317 | B, _, N = pt.shape
318 | neighborhood, center = self.group_divider(pt)
319 | x_vis, mask = self.AE_encoder(neighborhood, center)
320 | if masked:
321 | return x_vis, mask, center
322 | else:
323 | vis_pc, msk_pc = self.neighborhood(neighborhood, center, mask, x_vis)
324 | return x_vis, mask, center, vis_pc, msk_pc
325 |
326 | def evaluate(self, x_vis, x_msk):
327 | B, _, C = x_vis.shape # B VIS C
328 | x_full = torch.cat([x_vis, x_msk], dim=1)
329 |         rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3)  # (B, N_rebuild, 3) points predicted from all tokens
330 | return rebuild_points.reshape(-1, 3).unsqueeze(0)
331 |
332 | def neighborhood(self, neighborhood, center, mask, x_vis):
333 | B, M, N = x_vis.shape
334 | vis_point = neighborhood[~mask].reshape(B * M, -1, 3)
335 | full_vis = vis_point + center[~mask].unsqueeze(1)
336 | msk_point = neighborhood[mask].reshape(B * (self.num_group - M), -1, 3)
337 | full_msk = msk_point + center[mask].unsqueeze(1)
338 |
339 | full_vis = full_vis.reshape(B, -1, 3)
340 | full_msk = full_msk.reshape(B, -1, 3)
341 |
342 | return full_vis, full_msk
343 |
--------------------------------------------------------------------------------
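Usage note (not part of the repository): a minimal, hypothetical sketch of how the Encoder_Module defined above can be driven. It assumes the class is in scope, a CUDA device (Group relies on knn_cuda and misc.fps), and a config namespace carrying the fields read in __init__; all field values below are illustrative, not the repo's actual settings.

from types import SimpleNamespace
import torch

cfg = SimpleNamespace(
    trans_dim=384, encoder_depth=12, encoder_num_heads=6, drop_path_rate=0.1,
    group_size=32, num_group=64, num_output=2048,
    mask_ratio=0.75, mask_type='rand', loss='cdl2',
)
encoder = Encoder_Module(cfg).cuda()
pts = torch.rand(2, 2048, 3).cuda()                      # B N 3 input point clouds
loss = encoder(pts)                                      # Chamfer loss of the coarse reconstruction against the input
x_vis, mask, center = encoder.encode(pts, masked=True)   # visible latent tokens, mask and patch centers for a downstream decoder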
/model/EncoderSR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from timm.models.layers import DropPath, trunc_normal_
5 | import numpy as np
6 | from utils import misc
7 | import random
8 | from knn_cuda import KNN
9 | from metrics.evaluation_metrics import chamfer_distance_l1, chamfer_distance_l2
10 | # from model.Encoder_Component import *
11 |
12 | class Encoder(nn.Module): ## Embedding module
13 | def __init__(self, encoder_channel):
14 | super().__init__()
15 | self.encoder_channel = encoder_channel
16 | self.first_conv = nn.Sequential(
17 | nn.Conv1d(3, 128, 1),
18 | nn.BatchNorm1d(128),
19 | nn.ReLU(inplace=True),
20 | nn.Conv1d(128, 256, 1)
21 | )
22 | self.second_conv = nn.Sequential(
23 | nn.Conv1d(512, 512, 1),
24 | nn.BatchNorm1d(512),
25 | nn.ReLU(inplace=True),
26 | nn.Conv1d(512, self.encoder_channel, 1)
27 | )
28 |
29 | def forward(self, point_groups):
30 | '''
31 | point_groups : B G N 3
32 | -----------------
33 | feature_global : B G C
34 | '''
35 | bs, g, n, _ = point_groups.shape
36 | point_groups = point_groups.reshape(bs * g, n, 3)
37 | # encoder
38 | feature = self.first_conv(point_groups.transpose(2, 1)) # BG 256 n
39 | feature_global = torch.max(feature, dim=2, keepdim=True)[0] # BG 256 1
40 | feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) # BG 512 n
41 | feature = self.second_conv(feature) # BG 1024 n
42 | feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
43 | return feature_global.reshape(bs, g, self.encoder_channel)
44 |
45 |
46 | class Group(nn.Module): # FPS + KNN
47 | def __init__(self, num_group, group_size):
48 | super().__init__()
49 | self.num_group = num_group
50 | self.group_size = group_size
51 | self.knn = KNN(k=self.group_size, transpose_mode=True)
52 | self.knnhr = KNN(k=self.group_size*4, transpose_mode=True)
53 |
54 | def forward(self, xyz, hr=None):
55 | '''
56 | input: B N 3
57 | ---------------------------
58 | output: B G M 3
59 | center : B G 3
60 | '''
61 | batch_size, num_points, _ = xyz.shape
62 | # fps the centers out
63 | center = misc.fps(xyz, self.num_group) # B G 3
64 | # knn to get the neighborhood
65 | _, idx = self.knn(xyz, center) # B G M
66 | assert idx.size(1) == self.num_group
67 | assert idx.size(2) == self.group_size
68 | idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
69 | idx = idx + idx_base
70 | idx = idx.view(-1)
71 | neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
72 | neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()
73 | # normalize
74 | neighborhood = neighborhood - center.unsqueeze(2)
75 | if hr is not None:
76 | _, hr_points, _ = hr.shape
77 | _, hidx = self.knnhr(hr, center)
78 |             hidx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * hr_points  # offset into the flattened hr tensor, so it must use hr's point count
79 | hidx = hidx + hidx_base
80 | hidx = hidx.view(-1)
81 | h_n = hr.view(batch_size * hr_points, -1)[hidx,:]
82 | h_n = h_n.view(batch_size, self.num_group, self.group_size*4, 3).contiguous()
83 | h_n = h_n - center.unsqueeze(2)
84 | return neighborhood, center, h_n
85 | else:
86 | return neighborhood, center
87 |
88 |
89 | ## Transformers
90 | class Mlp(nn.Module):
91 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
92 | super().__init__()
93 | out_features = out_features or in_features
94 | hidden_features = hidden_features or in_features
95 | self.fc1 = nn.Linear(in_features, hidden_features)
96 | self.act = act_layer()
97 | self.fc2 = nn.Linear(hidden_features, out_features)
98 | self.drop = nn.Dropout(drop)
99 |
100 | def forward(self, x):
101 | x = self.fc1(x)
102 | x = self.act(x)
103 | x = self.drop(x)
104 | x = self.fc2(x)
105 | x = self.drop(x)
106 | return x
107 |
108 |
109 | class Attention(nn.Module):
110 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
111 | super().__init__()
112 | self.num_heads = num_heads
113 | head_dim = dim // num_heads
114 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
115 | self.scale = qk_scale or head_dim ** -0.5
116 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
117 | self.attn_drop = nn.Dropout(attn_drop)
118 | self.proj = nn.Linear(dim, dim)
119 | self.proj_drop = nn.Dropout(proj_drop)
120 |
121 | def forward(self, x):
122 | B, N, C = x.shape
123 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
125 |
126 | attn = (q @ k.transpose(-2, -1)) * self.scale
127 | attn = attn.softmax(dim=-1)
128 | attn = self.attn_drop(attn)
129 |
130 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
131 | x = self.proj(x)
132 | x = self.proj_drop(x)
133 | return x
134 |
135 |
136 | class Block(nn.Module):
137 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
138 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
139 | super().__init__()
140 | self.norm1 = norm_layer(dim)
141 |
142 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
143 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
144 | self.norm2 = norm_layer(dim)
145 | mlp_hidden_dim = int(dim * mlp_ratio)
146 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
147 |
148 | self.attn = Attention(
149 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
150 |
151 | def forward(self, x):
152 | x = x + self.drop_path(self.attn(self.norm1(x)))
153 | x = x + self.drop_path(self.mlp(self.norm2(x)))
154 | return x
155 |
156 |
157 | class TransformerEncoder(nn.Module):
158 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
159 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
160 | super().__init__()
161 |
162 | self.blocks = nn.ModuleList([
163 | Block(
164 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
165 | drop=drop_rate, attn_drop=attn_drop_rate,
166 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate
167 | )
168 | for i in range(depth)])
169 |
170 | def forward(self, x, pos):
171 | for _, block in enumerate(self.blocks):
172 | x = block(x + pos)
173 | return x
174 |
175 |
176 | class PointTransformerDecoder(nn.Module):
177 | def __init__(self, embed_dim=384, depth=4, num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None,
178 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm):
179 | super().__init__()
180 | self.blocks = nn.ModuleList([
181 | Block(
182 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
183 | drop=drop_rate, attn_drop=attn_drop_rate,
184 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate
185 | )
186 | for i in range(depth)])
187 | self.norm = norm_layer(embed_dim)
188 | self.head = nn.Identity()
189 |
190 | self.apply(self._init_weights)
191 |
192 | def _init_weights(self, m):
193 | if isinstance(m, nn.Linear):
194 | nn.init.xavier_uniform_(m.weight)
195 | if isinstance(m, nn.Linear) and m.bias is not None:
196 | nn.init.constant_(m.bias, 0)
197 | elif isinstance(m, nn.LayerNorm):
198 | nn.init.constant_(m.bias, 0)
199 | nn.init.constant_(m.weight, 1.0)
200 |
201 | def forward(self, x, pos, n):
202 | for _, block in enumerate(self.blocks):
203 | x = block(x + pos)
204 | x = self.head(self.norm(x[:, -n:]))
205 | return x
206 |
207 | #
208 | # PointTransformer backbone (pretrain-style encoder, reused here for super-resolution)
209 | class PointTransformer(nn.Module):
210 | def __init__(self, config):
211 | super().__init__()
212 | self.config = config
213 | # define the transformer argparse
214 | self.trans_dim = config.trans_dim
215 | self.depth = config.encoder_depth
216 | self.drop_path_rate = config.drop_path_rate
217 | self.num_heads = config.encoder_num_heads
218 | # embedding
219 | self.encoder_dims = config.trans_dim
220 | self.encoder = Encoder(encoder_channel=self.encoder_dims)
221 | self.mask_type = config.mask_type
222 | self.mask_ratio = config.mask_ratio
223 |
224 | self.pos_embed = nn.Sequential(
225 | nn.Linear(3, 128),
226 | nn.GELU(),
227 | nn.Linear(128, self.trans_dim),
228 | )
229 |
230 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
231 | self.blocks = TransformerEncoder(
232 | embed_dim=self.trans_dim,
233 | depth=self.depth,
234 | drop_path_rate=dpr,
235 | num_heads=self.num_heads,
236 | )
237 |
238 | self.norm = nn.LayerNorm(self.trans_dim)
239 | self.apply(self._init_weights)
240 |
241 | def _init_weights(self, m):
242 | if isinstance(m, nn.Linear):
243 | trunc_normal_(m.weight, std=.02)
244 | if isinstance(m, nn.Linear) and m.bias is not None:
245 | nn.init.constant_(m.bias, 0)
246 | elif isinstance(m, nn.LayerNorm):
247 | nn.init.constant_(m.bias, 0)
248 | nn.init.constant_(m.weight, 1.0)
249 | elif isinstance(m, nn.Conv1d):
250 | trunc_normal_(m.weight, std=.02)
251 | if m.bias is not None:
252 | nn.init.constant_(m.bias, 0)
253 |
254 | def _mask_center_block(self, center, noaug=False):
255 | if noaug or self.mask_ratio == 0:
256 | return torch.zeros(center.shape[:2]).bool()
257 | mask_idx = []
258 | for points in center:
259 | points = points.unsqueeze(0)
260 | index = random.randint(0, points.size(1) - 1)
261 | distance_matrix = torch.norm(points[:, index].reshape(1,1,3) - points, p=2, dim=-1)
262 | idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0]
263 | ratio = self.mask_ratio
264 | mask_num = int(ratio * len(idx))
265 | mask = torch.zeros(len(idx))
266 | mask[idx[:mask_num]] = 1
267 | mask_idx.append(mask.bool())
268 | bool_masked_pos = torch.stack(mask_idx).to(center.device)
269 | return bool_masked_pos
270 |
271 | def _mask_center_rand(self, center, noaug=False):
272 | '''
273 | center : B G 3
274 | --------------
275 | mask : B G (bool)
276 | '''
277 | B, G, _ = center.shape
278 | # skip the mask
279 | if noaug or self.mask_ratio == 0:
280 | return torch.zeros(center.shape[:2]).bool()
281 |
282 | self.num_mask = int(self.mask_ratio * G)
283 |
284 | overall_mask = np.zeros([B, G])
285 | for i in range(B):
286 | mask = np.hstack([
287 | np.zeros(G - self.num_mask),
288 | np.ones(self.num_mask),
289 | ])
290 | np.random.shuffle(mask)
291 | overall_mask[i, :] = mask
292 | overall_mask = torch.from_numpy(overall_mask).to(torch.bool)
293 |
294 | return overall_mask.to(center.device) # B G
295 |
296 | def forward(self, neighborhood, center, noaug=False):
297 | # generate mask
298 | if self.mask_type == 'rand':
299 | bool_masked_pos = self._mask_center_rand(center, noaug=noaug) # B G
300 | else:
301 | bool_masked_pos = self._mask_center_block(center, noaug=noaug)
302 |
303 | group_input_tokens = self.encoder(neighborhood) # B G C
304 |
305 | batch_size, seq_len, C = group_input_tokens.size()
306 |
307 | p = self.pos_embed(center)
308 |
309 | z = self.blocks(group_input_tokens, p)
310 | z = self.norm(z)
311 | return z[~bool_masked_pos].reshape(batch_size, -1, C), bool_masked_pos, z[bool_masked_pos].reshape(batch_size, -1, C)
312 |
313 | class Encoder_Module(nn.Module):
314 | def __init__(self, config):
315 | super().__init__()
316 | self.config = config
317 | self.trans_dim = config.trans_dim
318 | self.AE_encoder = PointTransformer(config)
319 | self.group_size = config.group_size
320 | self.num_group = config.num_group
321 | self.num_output = config.num_output
322 | self.num_channel = 3
323 | self.drop_path_rate = config.drop_path_rate
324 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
325 | self.decoder_pos_embed = nn.Sequential(
326 | nn.Linear(3, 128),
327 | nn.GELU(),
328 | nn.Linear(128, self.trans_dim)
329 | )
330 | self.decoder_depth = config.decoder_depth
331 | self.decoder_num_heads = config.decoder_num_heads
332 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.decoder_depth)]
333 | self.AE_decoder = PointTransformerDecoder(
334 | embed_dim=self.trans_dim,
335 | depth=self.decoder_depth,
336 | drop_path_rate=dpr,
337 | num_heads=self.decoder_num_heads,
338 | )
339 |
340 |         self.AE_decoder = None  # the decoder built above is discarded; this module only uses the increase_dim head
341 |
342 |
343 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
344 |
345 | # prediction head
346 | self.increase_dim = nn.Sequential(
347 | nn.Conv1d(self.trans_dim, (self.num_channel * self.num_output) // self.num_group, 1)
348 | )
349 |
350 | trunc_normal_(self.mask_token, std=.02)
351 | self.loss = config.loss
352 | # loss
353 | self.build_loss_func(self.loss)
354 |
355 | def build_loss_func(self, loss_type):
356 | if loss_type == "cdl1":
357 | self.loss_func = chamfer_distance_l1
358 | elif loss_type == 'cdl2':
359 | self.loss_func = chamfer_distance_l2
360 | elif loss_type == 'mse':
361 | self.loss_func = F.mse_loss
362 | else:
363 | raise NotImplementedError
364 |
365 | def forward(self, pts, hr_pt):
366 |         x_vis, x_msk, mask, center, _, _ = self.encode(pts)  # hr=None, masked=False; the reassembled vis/msk point groups are unused here
367 |
368 | B, _, C = x_vis.shape # B VIS C
369 | x_full = torch.cat([x_vis, x_msk], dim=1)
370 |         rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3)  # (B, num_output, 3) upsampled prediction
371 | loss1 = self.loss_func(rebuild_points, hr_pt)
372 | return loss1
373 |
374 | def encode(self, pt, hr=None, masked=False):
375 | B, _, N = pt.shape
376 | if masked:
377 | if hr is None:
378 | neighborhood, center = self.group_divider(pt)
379 | else:
380 |                 neighborhood, center, h_n = self.group_divider(pt, hr)  # hr is not None here, so the divider also returns the hr groups
381 | x_vis, mask, x_masked = self.AE_encoder(neighborhood, center)
382 | return x_vis, mask, center
383 | else:
384 | if hr is not None:
385 | neighborhood, center, h_n = self.group_divider(pt, hr)
386 | x_vis, mask, x_masked = self.AE_encoder(neighborhood, center)
387 | vis_pc, msk_pc = self.neighborhood(neighborhood, center, mask, x_vis)
388 | vis_hr, msk_hr = self.neighborhood(h_n, center, mask, x_vis)
389 | return x_vis, x_masked, mask, center, vis_pc, msk_pc, vis_hr, msk_hr
390 | else:
391 | neighborhood, center = self.group_divider(pt)
392 | x_vis, mask, x_masked = self.AE_encoder(neighborhood, center)
393 | vis_pc, msk_pc = self.neighborhood(neighborhood, center, mask, x_vis)
394 | return x_vis, x_masked, mask, center, vis_pc, msk_pc
395 |
396 | def evaluate(self, x_vis, x_msk):
397 | B, _, C = x_vis.shape # B VIS C
398 | x_full = torch.cat([x_vis, x_msk], dim=1)
399 |         rebuild_points = self.increase_dim(x_full.transpose(1, 2)).transpose(1, 2).reshape(B, -1, 3)  # (B, num_output, 3) upsampled prediction
400 | return rebuild_points.reshape(-1, 3).unsqueeze(0)
401 |
402 | def neighborhood(self, neighborhood, center, mask, x_vis):
403 | B, M, N = x_vis.shape
404 | vis_point = neighborhood[~mask].reshape(B * M, -1, 3)
405 | full_vis = vis_point + center[~mask].unsqueeze(1)
406 | msk_point = neighborhood[mask].reshape(B * (self.num_group - M), -1, 3)
407 | full_msk = msk_point + center[mask].unsqueeze(1)
408 |
409 | full_vis = full_vis.reshape(B, -1, 3)
410 | full_msk = full_msk.reshape(B, -1, 3)
411 |
412 | return full_vis, full_msk
413 |
--------------------------------------------------------------------------------
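Usage note (not part of the repository): a minimal, hypothetical sketch of the super-resolution Encoder_Module in EncoderSR.py above. It pairs a low-resolution input with a high-resolution target whose point count matches config.num_output; the hr neighborhoods use group_size * 4 neighbors per center. The class is assumed to be in scope, a CUDA device is assumed (knn_cuda / misc.fps), and the config values are illustrative only.

from types import SimpleNamespace
import torch

cfg = SimpleNamespace(
    trans_dim=384, encoder_depth=12, encoder_num_heads=6,
    decoder_depth=4, decoder_num_heads=6, drop_path_rate=0.1,
    group_size=32, num_group=64, num_output=8192,
    mask_ratio=0.75, mask_type='rand', loss='cdl2',
)
sr = Encoder_Module(cfg).cuda()
lr_pts = torch.rand(2, 2048, 3).cuda()    # low-resolution input, B N 3
hr_pts = torch.rand(2, 8192, 3).cuda()    # high-resolution target, B 4N 3
loss = sr(lr_pts, hr_pts)                 # Chamfer loss between the upsampled prediction and hr_pts

# Latent tokens plus reassembled low-/high-resolution point groups:
x_vis, x_msk, mask, center, vis_pc, msk_pc, vis_hr, msk_hr = sr.encode(lr_pts, hr_pts)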