├── README.md ├── figs ├── mainfig.jpg └── overview.jpg └── src ├── chamfer_distance ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ └── chamfer_distance.cpython-39.pyc ├── chamfer_distance.cpp ├── chamfer_distance.cu └── chamfer_distance.py ├── configs ├── ULIP.yaml ├── base_model.yaml ├── finetune_modelnet.yaml ├── ltp_model.yaml └── pro_model.yaml ├── dataset_svr ├── augmenter.py ├── dataset_shapenet.py ├── dataset_shapenet_text.py ├── mesh_processor.py ├── pointcloud_processor.py ├── taxonomy.json ├── trainer_dataset.py └── trainer_text_dataset.py ├── loss ├── cdloss.py ├── losses.py └── losses_v2.py ├── models ├── Point_MAE.py ├── ULIP_models.py ├── ULIP_utils.py ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── build.py ├── clip_utils.py ├── encoder.py ├── pointbert │ ├── PointTransformer_8192point.yaml │ ├── ULIP_2_PointBERT_10k_colored_pointclouds.yaml │ ├── __pycache__ │ │ ├── checkpoint.cpython-39.pyc │ │ ├── dvae.cpython-39.pyc │ │ ├── logger.cpython-39.pyc │ │ ├── misc.cpython-39.pyc │ │ └── point_encoder.cpython-39.pyc │ ├── checkpoint.py │ ├── dvae.py │ ├── logger.py │ ├── misc.py │ └── point_encoder.py ├── text_encoder_3d.py └── tokenizer.py ├── network ├── base_model.py ├── ltp_model.py └── pro_model.py ├── tools ├── builder.py └── smt.py ├── train ├── train_base.py ├── train_ltp.py └── train_pro.py ├── utils ├── __init__.py ├── checkpoint.py ├── config.py ├── logger.py ├── misc.py └── registry.py └── val ├── val_base.py └── val_pro.py /README.md: -------------------------------------------------------------------------------- 1 |

2 |

MESC-3D: Mining Effective Semantic Cues for 3D Reconstruction from a Single Image (CVPR 2025)

3 | 17 | 18 |

Paper

19 |
20 |

21 |

22 | 23 |

24 | 25 | We release the code of the paper MESC-3D: Mining Effective Semantic Cues for 3D Reconstruction from a Single Image in this repository. 26 | 27 | 28 | ## Abstract 29 | 30 |

31 | In this work, we propose a novel single-image 3D reconstruction method called Mining Effective Semantic Cues for 3D Reconstruction from a Single Image (MESC-3D), which can actively mine effective semantic cues from entangled features. Specifically, we design an Effective Semantic Mining Module to establish connections between point clouds and image semantic attributes, enabling the point clouds to autonomously select the necessary information. Furthermore, to address the potential insufficiency of semantic information from a single image, such as occlusions, and inspired by the human ability to represent 3D objects using prior knowledge drawn from daily experience, we introduce a 3DSPL module. This module incorporates semantic understanding of spatial structures, enabling the model to interpret and reconstruct 3D objects with greater accuracy and realism, closely mirroring human perception of complex 3D environments. Extensive evaluations show that our method achieves significant improvements in reconstruction quality and robustness compared to prior works. Additionally, further experiments validate its strong generalization capability and excellent zero-shot performance on unseen classes. 32 |

33 | 34 | 35 | ## Method 36 | 37 |

38 | 39 |

40 | 41 |

42 | Overview of MESC-3D. Our network is composed of two main components. (a) The 3DSPL aligns point cloud modality features with text features, aiming to capture the unique 3D geometric characteristics of each category. (b) The ESM establishes a connection between the semantic feature Fi and the 3D point cloud at the i-th stage, allowing each point to autonomously select the most valuable semantic information (an illustrative sketch of this selection mechanism follows below). 43 |

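To make the point-wise semantic selection described above concrete, here is a minimal, hypothetical PyTorch sketch of an ESM-style block in which each point queries image semantic tokens through cross-attention. The module name, tensor shapes, and the residual fusion are illustrative assumptions (dimensions loosely follow `embed_dim: 32` and `num_heads: 8` from the configs); this is not the repository's actual implementation.

```python
import torch
import torch.nn as nn


class SemanticMiningBlock(nn.Module):
    """Illustrative ESM-style block: points attend to image semantic tokens."""

    def __init__(self, dim: int = 32, num_heads: int = 8):
        super().__init__()
        self.norm_points = nn.LayerNorm(dim)
        self.norm_tokens = nn.LayerNorm(dim)
        # Cross-attention: queries come from the points, keys/values from the semantics.
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, point_feat: torch.Tensor, semantic_feat: torch.Tensor) -> torch.Tensor:
        # point_feat:    (B, N_points, dim) per-point features at stage i
        # semantic_feat: (B, N_tokens, dim) image semantic feature Fi as tokens
        q = self.norm_points(point_feat)
        kv = self.norm_tokens(semantic_feat)
        selected, _ = self.cross_attn(q, kv, kv)  # each point picks the cues it needs
        return point_feat + selected              # residual fusion with the selected cues


if __name__ == "__main__":
    block = SemanticMiningBlock()
    points = torch.randn(2, 2048, 32)      # num_points: 2048 as in the configs
    semantics = torch.randn(2, 196, 32)    # hypothetical number of image tokens
    print(block(points, semantics).shape)  # torch.Size([2, 2048, 32])
```

In the actual network each decoder stage would presumably consume its own semantic feature Fi, but the core idea this sketch tries to convey is the same: points act as queries over image semantics and autonomously select what they need.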
44 | 45 | ## Installation 46 | Clone this repository and install the required packages: 47 | 48 | - Install Python dependencies 49 | ```shell 50 | 51 | git clone https://github.com/QINGQINGLE/MESC-3D.git 52 | cd MESC-3D 53 | 54 | conda create -n mesc3d python=3.9 55 | conda activate mesc3d 56 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia 57 | 58 | pip install -r requirements.txt 59 | 60 | ``` 61 | 62 | - Compile PyTorch 3rd-party modules. 63 | 64 | ```shell 65 | 66 | cd package/Pointnet2_PyTorch-master/ 67 | pip install -e . 68 | pip install pointnet2_ops_lib/. 69 | 70 | cd - 71 | cd package/KNN_CUDA-master/ 72 | make && make install 73 | 74 | ``` 75 | 76 | - CLIP Usage 77 | The following steps install CLIP and apply the required modification. 78 | ```shell 79 | 80 | pip install git+https://github.com/openai/CLIP.git 81 | # or 82 | pip install clip 83 | 84 | ``` 85 | Replace 86 | ```python 87 | def encode_text(self, text): 88 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 89 | x = x + self.positional_embedding.type(self.dtype) 90 | ... 91 | ``` 92 | with 93 | ```python 94 | def encode_token(self, token): 95 | x = self.token_embedding(token) 96 | return x 97 | def encode_text(self, text, token): 98 | #x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 99 | x = text.type(self.dtype) + self.positional_embedding.type(self.dtype) 100 | ... 101 | ``` 102 | ## Dataset 103 | 104 | ## Pretrained Models 105 | 106 | 107 | We provide the following pretrained models: BaseModel, ProModel, ULIPModel, etc. Please download them from Google Drive. 108 | 109 | 110 | 111 | ## Training 112 | 113 | ## Testing 114 | 115 | The remaining code is on the way. 116 | -------------------------------------------------------------------------------- /figs/mainfig.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/figs/mainfig.jpg -------------------------------------------------------------------------------- /figs/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/figs/overview.jpg -------------------------------------------------------------------------------- /src/chamfer_distance/__init__.py: -------------------------------------------------------------------------------- 1 | from .chamfer_distance import ChamferDistance 2 | -------------------------------------------------------------------------------- /src/chamfer_distance/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/chamfer_distance/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/chamfer_distance/__pycache__/chamfer_distance.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/chamfer_distance/__pycache__/chamfer_distance.cpython-39.pyc -------------------------------------------------------------------------------- /src/chamfer_distance/chamfer_distance.cpp:
-------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | int ChamferDistanceKernelLauncher( 5 | const int b, const int n, 6 | const float* xyz, 7 | const int m, 8 | const float* xyz2, 9 | float* result, 10 | int* result_i, 11 | float* result2, 12 | int* result2_i); 13 | 14 | int ChamferDistanceGradKernelLauncher( 15 | const int b, const int n, 16 | const float* xyz1, 17 | const int m, 18 | const float* xyz2, 19 | const float* grad_dist1, 20 | const int* idx1, 21 | const float* grad_dist2, 22 | const int* idx2, 23 | float* grad_xyz1, 24 | float* grad_xyz2); 25 | 26 | 27 | void chamfer_distance_forward_cuda( 28 | const at::Tensor xyz1, 29 | const at::Tensor xyz2, 30 | const at::Tensor dist1, 31 | const at::Tensor dist2, 32 | const at::Tensor idx1, 33 | const at::Tensor idx2) 34 | { 35 | ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 36 | xyz2.size(1), xyz2.data(), 37 | dist1.data(), idx1.data(), 38 | dist2.data(), idx2.data()); 39 | } 40 | 41 | void chamfer_distance_backward_cuda( 42 | const at::Tensor xyz1, 43 | const at::Tensor xyz2, 44 | at::Tensor gradxyz1, 45 | at::Tensor gradxyz2, 46 | at::Tensor graddist1, 47 | at::Tensor graddist2, 48 | at::Tensor idx1, 49 | at::Tensor idx2) 50 | { 51 | ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 52 | xyz2.size(1), xyz2.data(), 53 | graddist1.data(), idx1.data(), 54 | graddist2.data(), idx2.data(), 55 | gradxyz1.data(), gradxyz2.data()); 56 | } 57 | 58 | 59 | void nnsearch( 60 | const int b, const int n, const int m, 61 | const float* xyz1, 62 | const float* xyz2, 63 | float* dist, 64 | int* idx) 65 | { 66 | for (int i = 0; i < b; i++) { 67 | for (int j = 0; j < n; j++) { 68 | const float x1 = xyz1[(i*n+j)*3+0]; 69 | const float y1 = xyz1[(i*n+j)*3+1]; 70 | const float z1 = xyz1[(i*n+j)*3+2]; 71 | double best = 0; 72 | int besti = 0; 73 | for (int k = 0; k < m; k++) { 74 | const float x2 = xyz2[(i*m+k)*3+0] - x1; 75 | const float y2 = xyz2[(i*m+k)*3+1] - y1; 76 | const float z2 = xyz2[(i*m+k)*3+2] - z1; 77 | const double d=x2*x2+y2*y2+z2*z2; 78 | if (k==0 || d < best){ 79 | best = d; 80 | besti = k; 81 | } 82 | } 83 | dist[i*n+j] = best; 84 | idx[i*n+j] = besti; 85 | } 86 | } 87 | } 88 | 89 | 90 | void chamfer_distance_forward( 91 | const at::Tensor xyz1, 92 | const at::Tensor xyz2, 93 | const at::Tensor dist1, 94 | const at::Tensor dist2, 95 | const at::Tensor idx1, 96 | const at::Tensor idx2) 97 | { 98 | const int batchsize = xyz1.size(0); 99 | const int n = xyz1.size(1); 100 | const int m = xyz2.size(1); 101 | 102 | const float* xyz1_data = xyz1.data(); 103 | const float* xyz2_data = xyz2.data(); 104 | float* dist1_data = dist1.data(); 105 | float* dist2_data = dist2.data(); 106 | int* idx1_data = idx1.data(); 107 | int* idx2_data = idx2.data(); 108 | 109 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data); 110 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); 111 | } 112 | 113 | 114 | void chamfer_distance_backward( 115 | const at::Tensor xyz1, 116 | const at::Tensor xyz2, 117 | at::Tensor gradxyz1, 118 | at::Tensor gradxyz2, 119 | at::Tensor graddist1, 120 | at::Tensor graddist2, 121 | at::Tensor idx1, 122 | at::Tensor idx2) 123 | { 124 | const int b = xyz1.size(0); 125 | const int n = xyz1.size(1); 126 | const int m = xyz2.size(1); 127 | 128 | const float* xyz1_data = xyz1.data(); 129 | const float* xyz2_data = xyz2.data(); 130 | float* gradxyz1_data = gradxyz1.data(); 
131 | float* gradxyz2_data = gradxyz2.data(); 132 | float* graddist1_data = graddist1.data(); 133 | float* graddist2_data = graddist2.data(); 134 | const int* idx1_data = idx1.data(); 135 | const int* idx2_data = idx2.data(); 136 | 137 | for (int i = 0; i < b*n*3; i++) 138 | gradxyz1_data[i] = 0; 139 | for (int i = 0; i < b*m*3; i++) 140 | gradxyz2_data[i] = 0; 141 | for (int i = 0;i < b; i++) { 142 | for (int j = 0; j < n; j++) { 143 | const float x1 = xyz1_data[(i*n+j)*3+0]; 144 | const float y1 = xyz1_data[(i*n+j)*3+1]; 145 | const float z1 = xyz1_data[(i*n+j)*3+2]; 146 | const int j2 = idx1_data[i*n+j]; 147 | 148 | const float x2 = xyz2_data[(i*m+j2)*3+0]; 149 | const float y2 = xyz2_data[(i*m+j2)*3+1]; 150 | const float z2 = xyz2_data[(i*m+j2)*3+2]; 151 | const float g = graddist1_data[i*n+j]*2; 152 | 153 | gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2); 154 | gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2); 155 | gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2); 156 | gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2)); 157 | gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2)); 158 | gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2)); 159 | } 160 | for (int j = 0; j < m; j++) { 161 | const float x1 = xyz2_data[(i*m+j)*3+0]; 162 | const float y1 = xyz2_data[(i*m+j)*3+1]; 163 | const float z1 = xyz2_data[(i*m+j)*3+2]; 164 | const int j2 = idx2_data[i*m+j]; 165 | const float x2 = xyz1_data[(i*n+j2)*3+0]; 166 | const float y2 = xyz1_data[(i*n+j2)*3+1]; 167 | const float z2 = xyz1_data[(i*n+j2)*3+2]; 168 | const float g = graddist2_data[i*m+j]*2; 169 | gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2); 170 | gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2); 171 | gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2); 172 | gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2)); 173 | gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2)); 174 | gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2)); 175 | } 176 | } 177 | } 178 | 179 | 180 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 181 | m.def("forward", &chamfer_distance_forward, "ChamferDistance forward"); 182 | m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)"); 183 | m.def("backward", &chamfer_distance_backward, "ChamferDistance backward"); 184 | m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)"); 185 | } 186 | -------------------------------------------------------------------------------- /src/chamfer_distance/chamfer_distance.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | __global__ 7 | void ChamferDistanceKernel( 8 | int b, 9 | int n, 10 | const float* xyz, 11 | int m, 12 | const float* xyz2, 13 | float* result, 14 | int* result_i) 15 | { 16 | const int batch=512; 17 | __shared__ float buf[batch*3]; 18 | for (int i=blockIdx.x;ibest){ 130 | result[(i*n+j)]=best; 131 | result_i[(i*n+j)]=best_i; 132 | } 133 | } 134 | __syncthreads(); 135 | } 136 | } 137 | } 138 | 139 | void ChamferDistanceKernelLauncher( 140 | const int b, const int n, 141 | const float* xyz, 142 | const int m, 143 | const float* xyz2, 144 | float* result, 145 | int* result_i, 146 | float* result2, 147 | int* result2_i) 148 | { 149 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); 150 | ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); 151 | 152 | cudaError_t err = cudaGetLastError(); 153 | if (err != cudaSuccess) 154 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); 155 | } 156 | 157 | 158 | __global__ 159 | void ChamferDistanceGradKernel( 160 | int b, 
int n, 161 | const float* xyz1, 162 | int m, 163 | const float* xyz2, 164 | const float* grad_dist1, 165 | const int* idx1, 166 | float* grad_xyz1, 167 | float* grad_xyz2) 168 | { 169 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); 204 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); 205 | 206 | cudaError_t err = cudaGetLastError(); 207 | if (err != cudaSuccess) 208 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); 209 | } 210 | -------------------------------------------------------------------------------- /src/chamfer_distance/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import importlib 3 | from torch.utils.cpp_extension import load 4 | import os 5 | # os.environ["CUDA_VISIBLE_DEVICES"] = '5' 6 | script_dir = os.path.dirname(__file__) 7 | 8 | sources = [ 9 | os.path.join(script_dir, "chamfer_distance.cpp"), 10 | os.path.join(script_dir, "chamfer_distance.cu"), 11 | ] 12 | chamfer_found = importlib.find_loader("chamfer_3D") is not None 13 | if not chamfer_found: 14 | ## Cool trick from https://github.com/chrdiller 15 | print("Jitting Chamfer 3D") 16 | from torch.utils.cpp_extension import load 17 | chamfer_3D = load(name="chamfer_3D",sources=sources) 18 | print("Loaded JIT 3D CUDA chamfer distance") 19 | else: 20 | import chamfer_3D 21 | print("Loaded compiled 3D CUDA chamfer distance") 22 | 23 | 24 | class ChamferDistanceFunction(torch.autograd.Function): 25 | @staticmethod 26 | def forward(ctx, xyz1, xyz2): 27 | batchsize, n, _ = xyz1.size() 28 | _, m, _ = xyz2.size() 29 | device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu') 30 | # print("device:", device) 31 | xyz1 = xyz1.contiguous() 32 | xyz2 = xyz2.contiguous() 33 | dist1 = torch.zeros(batchsize, n) 34 | dist2 = torch.zeros(batchsize, m) 35 | 36 | idx1 = torch.zeros(batchsize, n, dtype=torch.int) 37 | idx2 = torch.zeros(batchsize, m, dtype=torch.int) 38 | xyz1 = xyz1.to(device) 39 | xyz2 = xyz2.to(device) 40 | dist1 = dist1.to(device) 41 | dist2 = dist2.to(device) 42 | idx1 = idx1.to(device) 43 | idx2 = idx2.to(device) 44 | torch.cuda.set_device(device) 45 | chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 46 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 47 | # if not xyz1.is_cuda: 48 | # chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 49 | # else: 50 | # # dist1 = dist1.cuda() 51 | # # dist2 = dist2.cuda() 52 | # # idx1 = idx1.cuda() 53 | # # idx2 = idx2.cuda() 54 | # dist1 = dist1.to(device) 55 | # dist2 = dist2.to(device) 56 | # idx1 = idx1.to(device) 57 | # idx2 = idx2.to(device) 58 | # torch.cuda.set_device(device) 59 | # # chamfer_3D.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, sidx2) 60 | # chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 61 | 62 | # ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 63 | return dist1, dist2 64 | 65 | @staticmethod 66 | def backward(ctx, graddist1, graddist2): 67 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 68 | 69 | graddist1 = graddist1.contiguous() 70 | graddist2 = graddist2.contiguous() 71 | device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu') 72 | 73 | gradxyz1 = torch.zeros(xyz1.size()) 74 | gradxyz2 = torch.zeros(xyz2.size()) 75 | 76 | if not graddist1.is_cuda: 77 | chamfer_3D.backward( 78 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 79 | ) 80 | else: 81 | # gradxyz1 = gradxyz1.cuda() 82 | # gradxyz2 = 
gradxyz2.cuda() 83 | gradxyz1 = gradxyz1.to(device) 84 | gradxyz2 = gradxyz2.to(device) 85 | # chamfer_3D.backward_cuda( 86 | # xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 87 | # ) 88 | chamfer_3D.backward( 89 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 90 | ) 91 | 92 | return gradxyz1, gradxyz2 93 | 94 | 95 | class ChamferDistance(torch.nn.Module): 96 | def __init__(self): 97 | super(ChamferDistance, self).__init__() 98 | def forward(self, xyz1, xyz2): 99 | return ChamferDistanceFunction.apply(xyz1, xyz2) 100 | -------------------------------------------------------------------------------- /src/configs/ULIP.yaml: -------------------------------------------------------------------------------- 1 | model: ULIP_PointBERT 2 | npoints: 2048 3 | evaluate_3d: store_true 4 | -------------------------------------------------------------------------------- /src/configs/base_model.yaml: -------------------------------------------------------------------------------- 1 | # Adam & scheduler 2 | optimizer : { 3 | type: Adam, 4 | kwargs: { 5 | lr : 0.001, 6 | betas : [0.9,0.999], 7 | eps: 0.000001 8 | }} 9 | scheduler: { 10 | type: CosLR, 11 | kwargs: { 12 | epochs: 400, 13 | initial_epochs : 20 14 | }} 15 | # model 16 | model_name: BaseModel 17 | manual_seed: null 18 | 19 | model: 20 | embed_dim: 32 21 | depth: 4 22 | out_dim: 32 23 | num_heads: 8 24 | mlp_ratio: 2.0 25 | qkv_bias: False 26 | qk_scale: null 27 | drop_rate: 0.2 28 | attn_drop_rate: 0.2 29 | drop_path_rate: 0.2 30 | num_points: 2048 31 | dec_dim: [768,512,256,128,64,32] 32 | yaml: Path/configs/finetune_modelnet.yaml 33 | 34 | # batch 35 | batch_size: 24 36 | start_epoch: 0 37 | nepoch: 400 38 | epoch_interval_to_save: 5 39 | epoch_interval_to_val: 1 40 | 41 | # gpus 42 | distributed: True 43 | gpus: [0,1] 44 | # SVR Data 45 | radius: 0.5 46 | normalization: UnitBall 47 | shapenet13: True 48 | SVR: True 49 | class_choice: ["airplane"] 50 | number_points: 2048 51 | number_points_eval: 2048 52 | random_rotation: False 53 | data_augmentation_axis_rotation: False 54 | data_augmentation_random_flips: False 55 | random_translation: False 56 | anisotropic_scaling: False 57 | demo: False 58 | sample: True 59 | workers: 8 60 | 61 | # path 62 | taxonomy_path: Path/dataset_svr/taxonomy.json 63 | dir_outpath: Path/log_dir 64 | 65 | ckpt_path: path/ckpt_path/base_model.pth 66 | 67 | pointcloud_path: Path/ShapeNetV1PointCloud/ 68 | image_path: Path/ShapeNetV1Renderings/ 69 | cache_path: Path/Cache/ 70 | cache_path_test: Path/Cachetest/ 71 | -------------------------------------------------------------------------------- /src/configs/finetune_modelnet.yaml: -------------------------------------------------------------------------------- 1 | optimizer : { 2 | type: AdamW, 3 | kwargs: { 4 | lr : 0.0005, 5 | weight_decay : 0.05 6 | }} 7 | 8 | scheduler: { 9 | type: CosLR, 10 | kwargs: { 11 | epochs: 300, 12 | initial_epochs : 10 13 | }} 14 | 15 | # dataset : { 16 | # train : { _base_: cfgs/dataset_configs/ModelNet40.yaml, 17 | # others: {subset: 'train'}}, 18 | # val : { _base_: cfgs/dataset_configs/ModelNet40.yaml, 19 | # others: {subset: 'test'}}, 20 | # test : { _base_: cfgs/dataset_configs/ModelNet40.yaml, 21 | # others: {subset: 'test'}}} 22 | model : { 23 | NAME: PointTransformer, 24 | trans_dim: 384, 25 | depth: 12, 26 | drop_path_rate: 0.1, 27 | cls_dim: 40, 28 | num_heads: 6, 29 | group_size: 32, 30 | num_group: 64, 31 | encoder_dims: 384, 32 | } 33 | 34 | ckpts: Path/checkpoints/modelnet_1k.pth 35 | 36 
| npoints: 2048 37 | total_bs : 32 38 | step_per_update : 1 39 | max_epoch : 300 40 | grad_norm_clip : 10 -------------------------------------------------------------------------------- /src/configs/ltp_model.yaml: -------------------------------------------------------------------------------- 1 | optimizer : { 2 | type: Adam, 3 | kwargs: { 4 | lr : 0.0001, 5 | betas : [0.9,0.999], 6 | eps: 0.000001 7 | }} 8 | 9 | scheduler: { 10 | type: CosLR, 11 | kwargs: { 12 | epochs: 250, 13 | initial_epochs : 10 14 | }} 15 | 16 | 17 | #trainer 18 | start_epoch: 0 19 | epochs: 250 20 | model: TextAlignPCModel 21 | 22 | batch_size: 32 # 128 64 32 23 | disable_amp: store_true 24 | # update_freq: 2 # 1 2 4 25 | print_freq: 10 26 | warmup_epochs: 1 27 | wandb: store_true 28 | #Shapenet Data 29 | radius: 0.5 30 | normalization: UnitBall 31 | shapenet13: True 32 | SVR: True 33 | 34 | class_choice: ["airplane"] 35 | number_points: 2048 36 | number_points_eval: 2048 37 | random_rotation: False 38 | data_augmentation_axis_rotation: False 39 | data_augmentation_random_flips: False 40 | random_translation: False 41 | anisotropic_scaling: False 42 | demo: False 43 | sample: True 44 | workers: 8 45 | 46 | # path 47 | dir_outpath: Path/dir_log/ 48 | dir_checkpoints: None 49 | output_dir: Path/dir_out/ 50 | clip_ckpt_path: Path/checkpoints/ViT-B-16.pt 51 | ulip_ckpt_path: Path/checkpoints/pretrained_models_ckpt_zero-sho_classification_checkpoint_pointbert.pt 52 | 53 | ulip: Path/configs/Ulip.yaml 54 | 55 | pointcloud_path: Path/ShapeNetV1PointCloud/ 56 | image_path: Path/ShapeNetV1Renderings/ 57 | cache_path: Path/Cache/ 58 | cache_path_test: Path/Cachetest/ -------------------------------------------------------------------------------- /src/configs/pro_model.yaml: -------------------------------------------------------------------------------- 1 | # Adam & Scheduler 2 | optimizer : { 3 | type: Adam, 4 | kwargs: { 5 | lr : 0.001, 6 | betas : [0.9,0.999], 7 | eps: 0.000001 8 | }} 9 | 10 | scheduler: { 11 | type: CosLR, 12 | kwargs: { 13 | epochs: 400, 14 | initial_epochs : 20 15 | }} 16 | # model 17 | model_name: ProModel 18 | manual_seed: null 19 | 20 | model: 21 | embed_dim: 32 22 | depth: 4 23 | out_dim: 32 24 | num_heads: 8 25 | mlp_ratio: 2.0 26 | qkv_bias: False 27 | qk_scale: null 28 | drop_rate: 0.2 29 | attn_drop_rate: 0.2 30 | drop_path_rate: 0.2 31 | num_points: 2048 32 | dec_dim: [1280,1024,768,512,256,128,64,32] 33 | yaml: Path/configs/finetune_modelnet.yaml 34 | 35 | # batch 36 | batch_size: 24 # 32 37 | start_epoch: 0 38 | nepoch: 400 39 | epoch_interval_to_save: 5 40 | epoch_interval_to_val: 1 41 | 42 | # gpus 43 | distributed: True 44 | gpus: [0,1] 45 | # SVR Data 46 | radius: 0.5 47 | normalization: UnitBall 48 | shapenet13: True 49 | SVR: True 50 | 51 | class_choice: ["airplane"] 52 | number_points: 2048 53 | number_points_eval: 2048 54 | random_rotation: False 55 | data_augmentation_axis_rotation: False 56 | data_augmentation_random_flips: False 57 | random_translation: False 58 | anisotropic_scaling: False 59 | demo: False 60 | sample: True 61 | workers: 8 62 | 63 | # path 64 | taxonomy_path: Path/dataset_svr/taxonomy.json 65 | dir_outpath: Path/log_dir 66 | 67 | ckpt_path: Path/ckpt_path/pro_model.pth 68 | prompt_ckpt_path: Path/prompt.pth 69 | text_encoder_ckpt_path: Path/text_encoder.pth 70 | 71 | pointcloud_path: Path/ShapeNetV1PointCloud/ 72 | image_path: Path/ShapeNetV1Renderings/ 73 | cache_path: Path/Cache/ 74 | cache_path_test: Path/Cachetest/ 
-------------------------------------------------------------------------------- /src/dataset_svr/augmenter.py: -------------------------------------------------------------------------------- 1 | import dataset_svr.pointcloud_processor as pointcloud_processor 2 | 3 | 4 | class Augmenter(object): 5 | def __init__(self, translation=False, rotation_axis=[], rotation_3D=False, anisotropic_scaling=False, flips=[]): 6 | self.translation = translation 7 | self.rotation_axis = rotation_axis 8 | self.rotation_3D = rotation_3D 9 | self.anisotropic_scaling = anisotropic_scaling 10 | self.flips = flips 11 | 12 | def __call__(self, points): 13 | operation = pointcloud_processor.DataAugmentation(points) 14 | for axis in self.rotation_axis: 15 | operation.random_axial_rotation(axis=axis) 16 | if self.rotation_3D: 17 | operation.random_rotation() 18 | if self.anisotropic_scaling: 19 | operation.random_anisotropic_scaling() 20 | if len(self.flips) > 0: 21 | operation.random_flips(self.flips) 22 | if self.translation: 23 | operation.random_translation() 24 | -------------------------------------------------------------------------------- /src/dataset_svr/mesh_processor.py: -------------------------------------------------------------------------------- 1 | import pymesh 2 | import numpy as np 3 | from os.path import join, dirname 4 | 5 | 6 | class ColorMap: 7 | def __init__(self): 8 | self.colormap_path = "auxiliary/colormap.npy" 9 | self.colormap = (np.load(self.colormap_path) * 255).astype('int') 10 | 11 | def __call__(self, index): 12 | """ 13 | :param value: a float 14 | :return: 15 | """ 16 | colors = self.colormap[index] 17 | return colors 18 | 19 | 20 | def save(mesh, path, colormap): 21 | try: 22 | vertex_sources = mesh.get_attribute("vertex_sources") # batch, nb_prim, num_point, 3 23 | if vertex_sources.max() > 0: 24 | vertex_sources = (255 * (vertex_sources / vertex_sources.max())).astype('int') 25 | mesh.add_attribute("vertex_red") 26 | mesh.add_attribute("vertex_green") 27 | mesh.add_attribute("vertex_blue") 28 | mesh.set_attribute("vertex_red", colormap.colormap[vertex_sources][:, 0]) 29 | mesh.set_attribute("vertex_green", colormap.colormap[vertex_sources][:, 1]) 30 | mesh.set_attribute("vertex_blue", colormap.colormap[vertex_sources][:, 2]) 31 | except: 32 | pass 33 | pymesh.save_mesh(path[:-3] + "ply", mesh, *mesh.get_attribute_names(), ascii=True) 34 | -------------------------------------------------------------------------------- /src/dataset_svr/pointcloud_processor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class FunctionGenerator(object): 6 | def invert(self): 7 | print("This function has to be reimplemented in every inherited class") 8 | 9 | 10 | class ScaleFunctions(FunctionGenerator): 11 | def __init__(self, operator, inplace): 12 | self.operator = operator.clone() 13 | self.inplace = inplace 14 | 15 | def __call__(self, points): 16 | if self.inplace: 17 | points *= self.operator 18 | return points 19 | else: 20 | return points * self.operator 21 | 22 | def invert(self): 23 | self.operator = 1.0 / self.operator 24 | 25 | 26 | class RotationFunctions(FunctionGenerator): 27 | def __init__(self, operator, inplace): 28 | self.operator = operator.clone() 29 | self.inplace = inplace 30 | assert (self.operator.bmm(self.operator.transpose(1, 2).contiguous()).sum().item() - ( 31 | operator.size(0) * operator.size(2))) ** 2 < 0.001, "Input matrix is not a rotation matrix" 32 | 33 | def 
__call__(self, points): 34 | rotated_points = torch.bmm(points, self.operator) 35 | if self.inplace: 36 | points.copy_(rotated_points) 37 | return points 38 | return rotated_points 39 | 40 | def invert(self): 41 | self.operator = self.operator.transpose(1, 2).contiguous() 42 | 43 | 44 | class TranslationFunctions(FunctionGenerator): 45 | def __init__(self, operator, inplace): 46 | self.operator = operator.clone() 47 | self.inplace = inplace 48 | 49 | def __call__(self, points): 50 | if self.inplace: 51 | points += self.operator 52 | return points 53 | else: 54 | return points + self.operator 55 | 56 | def invert(self): 57 | self.operator = -self.operator 58 | 59 | 60 | class Operation(object): 61 | def __init__(self, points, inplace=True, keep_track=False): 62 | """ 63 | The keep track boolean is used in case one wants to unroll all the operation that have been performed 64 | :param keep_track: boolean 65 | """ 66 | self.keep_track = keep_track 67 | self.transforms = [] 68 | self.points = points 69 | self.device = points.device 70 | self.inplace = inplace 71 | self.dim = points.dim() 72 | self.type = self.points.type() 73 | 74 | if not self.inplace: 75 | self.points = self.points.clone() 76 | if self.dim == 2: 77 | self.points = self.points.unsqueeze_(0) 78 | elif self.dim == 3: 79 | pass 80 | else: 81 | print("Input should have dimension 2 or 3") 82 | 83 | def apply(self, points): 84 | for func in self.transforms: 85 | points = func(points) 86 | return points 87 | 88 | def invert(self): 89 | self.transforms.reverse() 90 | for func in self.transforms: 91 | func.invert() 92 | 93 | def scale(self, scale_vector): 94 | scaling_op = ScaleFunctions(scale_vector.to(self.device).type(self.type), inplace=self.inplace) 95 | self.points = scaling_op(self.points) 96 | if self.keep_track: 97 | self.transforms.append(scaling_op) 98 | return 99 | 100 | def translate(self, translation_vector): 101 | translation_op = TranslationFunctions(translation_vector.to(self.device).type(self.type), inplace=self.inplace) 102 | self.points = translation_op(self.points) 103 | if self.keep_track: 104 | self.transforms.append(translation_op) 105 | return 106 | 107 | def rotate(self, rotation_vector): 108 | rotation_op = RotationFunctions(rotation_vector.to(self.device).type(self.type), inplace=self.inplace) 109 | self.points = rotation_op(self.points) 110 | if self.keep_track: 111 | self.transforms.append(rotation_op) 112 | return 113 | 114 | @staticmethod 115 | def get_3D_rot_matrix(axis, rad_angle): 116 | """ 117 | Get a 3D rotation matrix around axis with angle in radian 118 | :param axis: int 119 | :param angle: torch.tensor of size Batch. 
120 | :return: Rotation Matrix as a tensor 121 | """ 122 | cos_angle = torch.cos(rad_angle) 123 | sin_angle = torch.sin(rad_angle) 124 | rotation_matrix = torch.zeros(rad_angle.size(0), 3, 3) 125 | rotation_matrix[:, 1, 1].fill_(1) 126 | rotation_matrix[:, 0, 0].copy_(cos_angle) 127 | rotation_matrix[:, 0, 2].copy_(sin_angle) 128 | rotation_matrix[:, 2, 0].copy_(-sin_angle) 129 | rotation_matrix[:, 2, 2].copy_(cos_angle) 130 | if axis == 0: 131 | rotation_matrix = rotation_matrix[:, [1, 0, 2], :][:, :, [1, 0, 2]] 132 | if axis == 2: 133 | rotation_matrix = rotation_matrix[:, [0, 2, 1], :][:, :, [0, 2, 1]] 134 | return rotation_matrix 135 | 136 | def rotate_axis_angle(self, axis, rad_angle, normals=False): 137 | """ 138 | 139 | :param points: Batched points 140 | :param axis: int 141 | :param angle: batched angles 142 | :return: 143 | """ 144 | rot_matrix = Operation.get_3D_rot_matrix(axis=axis, rad_angle=rad_angle) 145 | if normals: 146 | rot_matrix = torch.cat([rot_matrix, rot_matrix], dim=2) 147 | self.rotate(rot_matrix) 148 | return 149 | 150 | 151 | class Normalization(Operation): 152 | def __init__(self, *args, **kwargs): 153 | super(Normalization, self).__init__(*args, **kwargs) 154 | 155 | def center_pointcloud(self): 156 | """ 157 | In-place centering 158 | :param points: Tensor Batch, N_pts, D_dim 159 | :return: None 160 | """ 161 | # input : 162 | # ouput : torch Tensor N_pts, D_dim 163 | centroid = torch.mean(self.points, dim=1, keepdim=True) 164 | self.translate(-centroid) 165 | return self.points 166 | 167 | @staticmethod 168 | def center_pointcloud_functional(points): 169 | operator = Normalization(points, inplace=False) 170 | return operator.center_pointcloud() 171 | 172 | def normalize_unitL2ball(self): 173 | """ 174 | In-place normalization of input to unit ball 175 | :param points: torch Tensor Batch, N_pts, D_dim 176 | :return: None 177 | """ 178 | # input : torch Tensor N_pts, D_dim 179 | # ouput : torch Tensor N_pts, D_dim 180 | # 181 | self.center_pointcloud() 182 | scaling_factor_square, _ = torch.max(torch.sum(self.points ** 2, dim=2, keepdim=True), dim=1, keepdim=True) 183 | scaling_factor = torch.sqrt(scaling_factor_square) 184 | self.scale(1.0 / scaling_factor) 185 | return self.points 186 | 187 | @staticmethod 188 | def normalize_unitL2ball_functional(points): 189 | operator = Normalization(points, inplace=False) 190 | return operator.normalize_unitL2ball() 191 | 192 | def center_bounding_box(self): 193 | """ 194 | in place Centering : return center the bounding box 195 | :param points: torch Tensor Batch, N_pts, D_dim 196 | :return: diameter 197 | """ 198 | min_vals, _ = torch.min(self.points, 1, keepdim=True) 199 | max_vals, _ = torch.max(self.points, 1, keepdim=True) 200 | self.translate(-(min_vals + max_vals) / 2) 201 | return self.points, (max_vals - min_vals) / 2 202 | 203 | @staticmethod 204 | def center_bounding_box_functional(points): 205 | operator = Normalization(points, inplace=False) 206 | points, _ = operator.center_bounding_box() 207 | return points 208 | 209 | def normalize_bounding_box(self, isotropic=True): 210 | """ 211 | In place : center the bounding box and uniformly scale the bounding box to edge lenght 1 or max edge length 1 if isotropic is True (default). 
212 | :param points: torch Tensor Batch, N_pts, D_dim 213 | :return: 214 | """ 215 | _, diameter = self.center_bounding_box() 216 | if isotropic: 217 | diameter, _ = torch.max(diameter, 2, keepdim=True) 218 | self.scale(1.0 / diameter) 219 | return self.points 220 | 221 | @staticmethod 222 | def normalize_bounding_box_functional(points): 223 | operator = Normalization(points, inplace=False) 224 | return operator.normalize_bounding_box() 225 | 226 | @staticmethod 227 | def identity_functional(points): 228 | return points 229 | 230 | 231 | class DataAugmentation(Operation): 232 | def __init__(self, *args, **kwargs): 233 | super(DataAugmentation, self).__init__(*args, **kwargs) 234 | 235 | def random_anisotropic_scaling(self, min_val=0.75, max_val=1.25): 236 | """ 237 | In place : Random Anisotropic scaling by a factor between min_val and max_val 238 | :param points: torch Tensor Batch, N_pts, D_dim 239 | :return: 240 | """ 241 | scale = torch.rand(self.points.size(0), 1, self.points.size(2)) * (max_val - min_val) + min_val 242 | self.scale(scale) 243 | return 244 | 245 | def random_axial_rotation(self, axis=0, normals=False, range_rot=360): 246 | """ 247 | Compute a random rotation of the batch around an axis. There is is no in-place version of this function because bmm_ is not possible in pytorch. 248 | :param points: torch Tensor Batch, N_pts, D_dim 249 | :return: torch Tensor Batch, N_pts, D_dim 250 | """ 251 | scale_factor = 360.0 / range_rot 252 | scale_factor = np.pi / scale_factor 253 | rad_angle = torch.rand(self.points.size(0)) * 2 * scale_factor - scale_factor 254 | self.rotate_axis_angle(axis=axis, rad_angle=rad_angle, normals=normals) 255 | return 256 | 257 | @staticmethod 258 | def get_random_rotation_matrix(batch_size=1): 259 | """ 260 | Get a random 3D rotation matrix 261 | :return: Rotation Matrix as a tensor 262 | from : https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations 263 | An easy way to do this : sample a point on the sphere (with normalize(normal(), normal(), normal()) 264 | then sample an angle, then just compute the associated rotation matrix 265 | """ 266 | # Select a random point on the sphere 267 | x = torch.randn(batch_size, 1, 3).double() 268 | scaling_factor_square, _ = torch.max(torch.sum(x ** 2, dim=2, keepdim=True), dim=1, keepdim=True) 269 | scaling_factor = torch.sqrt(scaling_factor_square) 270 | x /= scaling_factor 271 | x = x.squeeze() 272 | XX = torch.bmm(x.unsqueeze(2), x.unsqueeze(1)) 273 | 274 | # get random angle 275 | rad_angle = torch.rand(batch_size).double() * 2 * np.pi + np.pi 276 | cos_angle = torch.cos(rad_angle) 277 | sin_angle = torch.sin(rad_angle) 278 | 279 | # Compute fat matrix 280 | rotation_matrix = torch.zeros(rad_angle.size(0), 3, 3).double() 281 | 282 | rotation_matrix[:, 0, 0].copy_(cos_angle + XX[:, 0, 0] * (1 - cos_angle)) 283 | rotation_matrix[:, 1, 1].copy_(cos_angle + XX[:, 1, 1] * (1 - cos_angle)) 284 | rotation_matrix[:, 2, 2].copy_(cos_angle + XX[:, 2, 2] * (1 - cos_angle)) 285 | 286 | rotation_matrix[:, 0, 1].copy_(XX[:, 0, 1] * (1 - cos_angle) - x[:, 2] * sin_angle) 287 | rotation_matrix[:, 1, 0].copy_(XX[:, 0, 1] * (1 - cos_angle) + x[:, 2] * sin_angle) 288 | 289 | rotation_matrix[:, 0, 2].copy_(XX[:, 0, 2] * (1 - cos_angle) + x[:, 1] * sin_angle) 290 | rotation_matrix[:, 2, 0].copy_(XX[:, 0, 2] * (1 - cos_angle) - x[:, 1] * sin_angle) 291 | 292 | rotation_matrix[:, 1, 2].copy_(XX[:, 1, 2] * (1 - cos_angle) - x[:, 0] * sin_angle) 293 | rotation_matrix[:, 2, 1].copy_(XX[:, 1, 2] * (1 - cos_angle) + x[:, 0] * 
sin_angle) 294 | 295 | return rotation_matrix 296 | 297 | def random_rotation(self, normals=False): 298 | """ 299 | Compute a random rotation of the batch. There is is no in-place version of this function because bmm_ is not possible in pytorch. 300 | :param points: torch Tensor Batch, N_pts, D_dim 301 | :return: torch Tensor Batch, N_pts, D_dim 302 | """ 303 | rot_matrix = DataAugmentation.get_random_rotation_matrix(batch_size=self.points.size(0)) 304 | if normals: 305 | rot_matrix = torch.cat([rot_matrix, rot_matrix], dim=2) 306 | self.rotate(rot_matrix) 307 | return 308 | 309 | def random_translation(self, scale=0.03): 310 | """ 311 | In place Compute a random tranlation of the batch. 312 | :param points: torch Tensor Batch, N_pts, D_dim 313 | :return: 314 | """ 315 | translation_vector = torch.rand(self.points.size(0), 1, self.points.size(2)) * 2 * scale - scale 316 | self.translate(translation_vector) 317 | return 318 | 319 | @staticmethod 320 | def diff(first, second): 321 | second = set(second) 322 | return [item for item in first if item not in second] 323 | 324 | def random_flips(self, dims=[]): 325 | """ 326 | In place Random flip 327 | :param points: torch Tensor Batch, N_pts, D_dim 328 | :return: 329 | """ 330 | exclude_dims = DataAugmentation.diff(range(self.points.size(2)), dims) 331 | scale_factor = torch.randint(2, (self.points.size(0), 1, self.points.size(2))) * 2 - 1 332 | for axis in exclude_dims: 333 | scale_factor[:, :, axis].fill_(1) 334 | self.scale(scale_factor) 335 | return 336 | 337 | 338 | # Done for eurographics 19 339 | def barycentric(p, a, b, c): 340 | """ 341 | :param p: numpy arrays of size N_points x 3 342 | :param a: numpy arrays of size N_points x 3 343 | :param b: numpy arrays of size N_points x 3 344 | :param c: numpy arrays of size N_points x 3 345 | :return: barycentric coordinates point p in triangle (a,b,c) 346 | """ 347 | 348 | v0 = b - a 349 | v1 = c - a 350 | v2 = p - a 351 | 352 | d00 = np.sum(np.multiply(v0, v0), 1) 353 | d01 = np.sum(np.multiply(v0, v1), 1) 354 | d11 = np.sum(np.multiply(v1, v1), 1) 355 | d20 = np.sum(np.multiply(v2, v0), 1) 356 | d21 = np.sum(np.multiply(v2, v1), 1) 357 | 358 | denom = np.multiply(d00, d11) - np.multiply(d01, d01) 359 | 360 | v = (np.multiply(d11, d20) - np.multiply(d01, d21)) / denom 361 | w = (np.multiply(d00, d21) - np.multiply(d01, d20)) / denom 362 | u = - v - w + 1.0 363 | 364 | return (u, v, w) 365 | 366 | 367 | if __name__ == '__main__': 368 | print("Start unit test") 369 | -------------------------------------------------------------------------------- /src/dataset_svr/trainer_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import munch 5 | # import dataset_shapenet 6 | import dataset_svr.dataset_shapenet as dataset_shapenet 7 | import yaml 8 | 9 | def pc_normalize(pc, radius): 10 | l = pc.shape[0] 11 | centroid = np.mean(pc, axis=0) 12 | pc = pc - centroid 13 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 14 | pc = pc / m * radius 15 | return pc 16 | 17 | def get_spherepoints(num_points, radius): 18 | ball_name = 'Path/balls/%d.xyz' % num_points 19 | ball = np.loadtxt(ball_name) 20 | ball = pc_normalize(ball, radius) 21 | return ball 22 | def build_dataset(args): 23 | # Create Datasets 24 | dataset_train = dataset_shapenet.ShapeNet(args, train=True) 25 | dataset_test = dataset_shapenet.ShapeNet(args, train=False) 26 | 27 | # Create dataloaders 28 | dataloader_train = 
torch.utils.data.DataLoader(dataset_train, 29 | batch_size=args.batch_size, 30 | shuffle=True, 31 | num_workers=int(args.workers)) 32 | dataloader_test = torch.utils.data.DataLoader(dataset_test, 33 | batch_size=args.batch_size, 34 | shuffle=False, num_workers=int(args.workers)) 35 | 36 | len_dataset = len(dataset_train) 37 | len_dataset_test = len(dataset_test) 38 | print('Length of train dataset:%d', len_dataset) 39 | print('Length of test dataset:%d', len_dataset_test) 40 | 41 | return dataloader_train, dataloader_test 42 | 43 | def build_dataset_val(args): 44 | 45 | # Create Datasets 46 | dataset_test = dataset_shapenet.ShapeNet_val(args, train=False) 47 | 48 | # Create dataloaders 49 | dataloader_test = torch.utils.data.DataLoader(dataset_test, 50 | batch_size=args.batch_size, 51 | shuffle=False, num_workers=int(args.workers)) 52 | 53 | len_dataset_test = len(dataset_test) 54 | print('Length of test dataset:%d', len_dataset_test) 55 | 56 | return dataloader_test 57 | 58 | if __name__ == '__main__': 59 | config_path = "Path/MAE/config.yaml" 60 | args = munch.munchify(yaml.safe_load(open(config_path))) 61 | dataloader_test = build_dataset_val(args) 62 | for data in dataloader_test: 63 | img = data['image'] 64 | pc = data['points'] 65 | print(img.shape) 66 | print(pc.shape) 67 | print('done') -------------------------------------------------------------------------------- /src/dataset_svr/trainer_text_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import munch 4 | import dataset_svr.dataset_shapenet_text as dataset_shapenet 5 | import yaml 6 | import numpy as np 7 | def pc_normalize(pc, radius): 8 | l = pc.shape[0] 9 | centroid = np.mean(pc, axis=0) 10 | pc = pc - centroid 11 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 12 | pc = pc / m * radius 13 | return pc 14 | 15 | def get_spherepoints(num_points, radius): 16 | ball_name = 'Path/balls/%d.xyz' % num_points 17 | ball = np.loadtxt(ball_name) 18 | ball = pc_normalize(ball, radius) 19 | return ball 20 | 21 | def build_dataset(args): 22 | # Create Datasets 23 | dataset_train = dataset_shapenet.ShapeNet(args, train=True) 24 | dataset_test = dataset_shapenet.ShapeNet(args, train=False) 25 | 26 | # Create dataloaders 27 | dataloader_train = torch.utils.data.DataLoader(dataset_train, 28 | batch_size=args.batch_size, 29 | shuffle=True, 30 | num_workers=int(args.workers)) 31 | dataloader_test = torch.utils.data.DataLoader(dataset_test, 32 | batch_size=args.batch_size, 33 | shuffle=False, num_workers=int(args.workers)) 34 | 35 | len_dataset = len(dataset_train) 36 | len_dataset_test = len(dataset_test) 37 | print('Length of train dataset:%d', len_dataset) 38 | print('Length of test dataset:%d', len_dataset_test) 39 | 40 | return dataloader_train, dataloader_test 41 | 42 | def build_dataset_val(args): 43 | 44 | # Create Datasets 45 | dataset_test = dataset_shapenet.ShapeNet_val(args, train=False) 46 | 47 | # Create dataloaders 48 | dataloader_test = torch.utils.data.DataLoader(dataset_test, 49 | batch_size=args.batch_size, 50 | shuffle=False, num_workers=int(args.workers)) 51 | 52 | len_dataset_test = len(dataset_test) 53 | print('Length of test dataset:%d', len_dataset_test) 54 | 55 | return dataloader_test 56 | def get_batch_label(texts, prompt_text, label_map: dict): 57 | label_vectors = torch.zeros(0) 58 | if len(label_map) != 7: 59 | if len(label_map) == 2: 60 | for text in texts: 61 | label_vector = torch.zeros(2) 62 | if text == 'Normal': 63 | 
label_vector[0] = 1 64 | else: 65 | label_vector[1] = 1 66 | label_vector = label_vector.unsqueeze(0) 67 | label_vectors = torch.cat([label_vectors, label_vector], dim=0) 68 | else: 69 | for text in texts: 70 | label_vector = torch.zeros(len(prompt_text)) 71 | if text in label_map: 72 | label_text = label_map[text] 73 | label_vector[prompt_text.index(label_text)] = 1 74 | 75 | label_vector = label_vector.unsqueeze(0) 76 | label_vectors = torch.cat([label_vectors, label_vector], dim=0) 77 | else: 78 | for text in texts: 79 | label_vector = torch.zeros(len(prompt_text)) 80 | labels = text.split('-') 81 | for label in labels: 82 | if label in label_map: 83 | label_text = label_map[label] 84 | label_vector[prompt_text.index(label_text)] = 1 85 | 86 | label_vector = label_vector.unsqueeze(0) 87 | label_vectors = torch.cat([label_vectors, label_vector], dim=0) 88 | 89 | return label_vectors 90 | # def get_prompt_text(category:list,label_map: dict): 91 | # prompt_text = [] 92 | # for v in label_map.values(): 93 | # prompt_text.append(v) 94 | 95 | # return prompt_text 96 | def get_prompt_text(category: list, label_map: dict) -> list: 97 | prompt_text = [] 98 | for cat in category: 99 | if cat in label_map: 100 | prompt_text.append(label_map[cat]) 101 | else: 102 | 103 | prompt_text.append('Unknown') 104 | return prompt_text 105 | if __name__ == '__main__': 106 | from tqdm import tqdm 107 | from models.tokenizer import SimpleTokenizer 108 | from models.ULIP_models import ULIP_PointBERT 109 | from collections import OrderedDict 110 | config_path = 'Path/MAE/config.yaml' 111 | args = munch.munchify(yaml.safe_load(open(config_path))) 112 | dataloader_train, dataloader_test = build_dataset(args) 113 | tokenizer = SimpleTokenizer() 114 | label_map = dict({'02691156':'airplane','02828884':'bench','02933112':'cabinet','02958343':'car','03001627':'chair', 115 | '03211117':'display','03636649':'lamp','03691459':'loudspeaker','04090263':'rifle', 116 | '04256520':'sofa','04379243':'table','04401088':'telephone','04530566':'vessel'}) 117 | # prompt_text = get_prompt_text(label_map) 118 | # print(f'prompt_text:{prompt_text}') 119 | config_path = "Path/cfgs/ULIP.yaml" 120 | args = munch.munchify(yaml.safe_load(open(config_path))) 121 | ULIP = ULIP_PointBERT(args).to('cuda:7') 122 | ckpt = torch.load('Path/checkpoints/pretrained_models_ckpt_zero-sho_classification_checkpoint_pointbert.pt') 123 | state_dict = OrderedDict() 124 | for k, v in ckpt["state_dict"].items(): 125 | # print(f'k {k}') 126 | state_dict[k.replace("module.", "")] = v 127 | ULIP.load_state_dict(state_dict, strict=False) 128 | tokenized_captions = [] 129 | with tqdm(dataloader_train) as t: 130 | for batch_idx,data in enumerate(t): 131 | # if batch_idx == 0: 132 | # category 133 | tokenized_captions = [] 134 | category = data['category'] 135 | text_labels = list(category) 136 | # text_labels = get_batch_label(text_labels, prompt_text, label_map) 137 | prompt_text = get_prompt_text(text_labels,label_map) 138 | path = data['image_path'] 139 | # captions 140 | caption = data["caption"] 141 | print(f'caption : {caption}') 142 | caption = list(zip(*caption)) 143 | for i in range(len(caption)): 144 | caption[i] = list(caption[i]) 145 | for i in range(len(caption)): 146 | texts = tokenizer(caption[i]).cuda(7, non_blocking=True) 147 | tokenized_captions.append(texts) 148 | tokenized_captions = torch.stack(tokenized_captions) 149 | text_features = [] 150 | for i in range(tokenized_captions.shape[0]): 151 | text_for_one_sample = tokenized_captions[i] 152 | 
text_embed = ULIP.encode_text(text_for_one_sample) 153 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 154 | text_embed = text_embed.mean(dim=0) 155 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 156 | text_features.append(text_embed) 157 | text_features = torch.stack(text_features) 158 | 159 | # print(f'tokenized {tokenized_captions.shape}') 160 | # print(f"catefory {category}") 161 | # print(f"text_labels {text_labels}") 162 | # print(f'prompt_text:{prompt_text}') 163 | print(f'text_embed {text_features.shape}') 164 | # print(f'path {path}') 165 | -------------------------------------------------------------------------------- /src/loss/cdloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from chamfer_distance.chamfer_distance import ChamferDistance 4 | def fscore(dist1, dist2, threshold=0.01): 5 | """ 6 | Calculates the F-score between two point clouds with the corresponding threshold value. 7 | :param dist1: Batch, N-Points 8 | :param dist2: Batch, N-Points 9 | :param th: float 10 | :return: fscore, precision, recall 11 | """ 12 | # NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt the threshold accordingly. 13 | precision_1 = torch.mean((dist1 < threshold).float(), dim=1) 14 | precision_2 = torch.mean((dist2 < threshold).float(), dim=1) 15 | fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2) 16 | fscore[torch.isnan(fscore)] = 0 17 | return fscore, precision_1, precision_2 18 | class SimplificationLoss(nn.Module): 19 | def __init__(self): 20 | super(SimplificationLoss, self).__init__() 21 | 22 | def forward(self, ref_pc, samp_pc,calc_f1=False): 23 | 24 | cost_p1_p2, cost_p2_p1 = ChamferDistance()(samp_pc, ref_pc) 25 | cd_p = (torch.sqrt(cost_p1_p2).mean(1) + torch.sqrt(cost_p2_p1).mean(1)) / 2 # l1 26 | cd_t = (cost_p1_p2.mean(1) + cost_p2_p1.mean(1)) #l2 27 | if calc_f1: 28 | f1,_,_ =fscore(cost_p1_p2,cost_p2_p1) 29 | return cd_p,cd_t,f1 30 | 31 | else: 32 | return cd_p,cd_t -------------------------------------------------------------------------------- /src/loss/losses.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2023, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Le Xue 7 | ''' 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from ULIP_utils import all_gather_batch,get_rank 13 | 14 | loss = torch.nn.KLDivLoss(reduction='batchmean') 15 | 16 | def KL_loss(student_features,teacher_features): 17 | feature_loss = loss(F.log_softmax(student_features, dim=-1), F.softmax(teacher_features, dim=-1)) 18 | return feature_loss 19 | 20 | labels = torch.tensor([1, 0, 1, 0], dtype=torch.float32) 21 | def contrastive_loss(student_features, teacher_features, margin=1.0): 22 | distances = (student_features - teacher_features).pow(2).sum(1) # 欧几里得距离的平方 23 | loss = (1 - labels) * distances + labels * F.relu(margin - distances.sqrt()).pow(2) 24 | return loss.mean() 25 | 26 | class ULIPWithImageLoss(nn.Module): 27 | def __init__(self): 28 | super().__init__() 29 | self.labels = None 30 | self.last_local_batch_size = None 31 | 32 | def forward(self, outputs): 33 | pc_embed = outputs['pc_embed'] 34 | text_embed = outputs['text_embed'] 35 | image_embed = outputs['image_embed'] 36 | logit_scale = outputs['logit_scale'] 37 | local_batch_size = pc_embed.size(0) 38 | 39 | if local_batch_size != self.last_local_batch_size: 40 | self.labels = local_batch_size * get_rank() + torch.arange( 41 | local_batch_size, device=pc_embed.device 42 | ) 43 | self.last_local_batch_size = local_batch_size 44 | 45 | # normalized features 46 | # pc_embed = F.normalize(pc_embed, dim=-1, p=2) 47 | # text_embed = F.normalize(text_embed, dim=-1, p=2) 48 | # image_embed = F.normalize(image_embed, dim=-1, p=2) 49 | # # gather features from all GPUs 50 | # pc_embed_all, text_embed_all, image_embed_all = \ 51 | # all_gather_batch([pc_embed, text_embed, image_embed]) 52 | # cosine similarity as logits 53 | # logits_per_pc_text = logit_scale * pc_embed @ text_embed_all.t() 54 | # logits_per_text_pc = logit_scale * text_embed @ pc_embed_all.t() 55 | # logits_per_text_image = logit_scale * text_embed @ image_embed_all.t() 56 | # logits_per_image_text = logit_scale * image_embed @ text_embed_all.t() 57 | 58 | # logits_per_pc_text = F.cosine_similarity(text_embed,pc_embed) 59 | logits_per_text_image = F.cosine_similarity(text_embed,image_embed) 60 | # ulip_pc_text = logit_scale - logits_per_pc_text.mean() 61 | ulip_text_image = logit_scale - logits_per_text_image.mean() 62 | 63 | # loss = 0.9*ulip_pc_text + 0.1*ulip_text_image 64 | loss = ulip_text_image 65 | 66 | return {'loss': loss, 'ulip_loss': loss, 'ulip_text_pc_sim':0, 'ulip_text_image_sim': ulip_text_image} 67 | -------------------------------------------------------------------------------- /src/loss/losses_v2.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2023, salesforce.com, inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Le Xue 7 | ''' 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from models.ULIP_utils import all_gather_batch,get_rank 13 | 14 | loss = torch.nn.KLDivLoss(reduction='batchmean') 15 | 16 | def KL_loss(student_features,teacher_features): 17 | feature_loss = loss(F.log_softmax(student_features, dim=-1), F.softmax(teacher_features, dim=-1)) 18 | return feature_loss 19 | def contrastive_loss(e_s, e_t): 20 | """ 21 | Compute the contrastive loss between shape (e_s) and image (e_t) embeddings. 22 | 23 | Args: 24 | e_s: Tensor of shape (N, C) 25 | e_t: Tensor of shape (N, C) 26 | 27 | Returns: 28 | loss: Contrastive loss value 29 | """ 30 | N, C = e_s.size() 31 | 32 | # Normalize embeddings 33 | e_s = F.normalize(e_s, p=2, dim=1) 34 | e_t = F.normalize(e_t, p=2, dim=1) 35 | 36 | # Compute the similarity matrix 37 | sim_matrix = torch.matmul(e_s, e_t.T) # Shape: (N, N) 38 | 39 | # Compute the log probabilities 40 | log_prob_s = F.log_softmax(sim_matrix, dim=1) # Shape: (N, N) 41 | log_prob_t = F.log_softmax(sim_matrix.T, dim=1) # Shape: (N, N) 42 | 43 | # Contrastive loss 44 | loss_s = -torch.mean(torch.diag(log_prob_s)) 45 | loss_t = -torch.mean(torch.diag(log_prob_t)) 46 | 47 | # Total loss 48 | loss = (loss_s + loss_t) / 2 49 | 50 | return loss 51 | 52 | class ULIPWithImageLoss(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | self.labels = None 56 | self.last_local_batch_size = None 57 | 58 | def forward(self, outputs): 59 | pc_embed = outputs['pc_embed'] 60 | text_embed = outputs['text_embed'] 61 | image_embed = outputs['image_embed'] 62 | logit_scale = outputs['logit_scale'] 63 | local_batch_size = pc_embed.size(0) 64 | 65 | if local_batch_size != self.last_local_batch_size: 66 | self.labels = local_batch_size * get_rank() + torch.arange( 67 | local_batch_size, device=pc_embed.device 68 | ) 69 | self.last_local_batch_size = local_batch_size 70 | 71 | #normalized features 72 | pc_embed = F.normalize(pc_embed, dim=-1, p=2) 73 | text_embed = F.normalize(text_embed, dim=-1, p=2) 74 | image_embed = F.normalize(image_embed, dim=-1, p=2) 75 | # gather features from all GPUs 76 | pc_embed_all, text_embed_all, image_embed_all = \ 77 | all_gather_batch([pc_embed, text_embed, image_embed]) 78 | 79 | #cosine similarity as logits #[batch_size,batch_size] 80 | 81 | logits_per_pc_text = logit_scale * pc_embed @ text_embed_all.t() 82 | # logits_per_image_text = logit_scale * image_embed @ text_embed_all.t() 83 | 84 | 85 | ulip_pc_text = logit_scale - torch.diag(logits_per_pc_text).mean() 86 | # ulip_pc_image = logit_scale - torch.diag(logits_per_pc_image).mean() 87 | # ulip_text_image = logit_scale - torch.diag(logits_per_image_text).mean() 88 | # kl = KL_loss(text_embed,pc_embed).mean() 89 | con_loss = contrastive_loss(pc_embed,text_embed_all) 90 | # ctr = contrastive_loss(text_embed,pc_embed) 91 | # loss = ulip_pc_text + 0.2*ulip_text_image 92 | # loss = ulip_pc_text + 0.05 *ulip_text_image 93 | # loss = ulip_pc_text + kl 94 | loss = con_loss 95 | return {'loss': loss, 'ulip_loss': loss, 'ulip_text_pc_sim': ulip_pc_text, 'ulip_text_image_sim': 0} 96 | -------------------------------------------------------------------------------- /src/models/ULIP_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as 
dist 3 | import numpy as np 4 | import shutil 5 | import os 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value""" 8 | def __init__(self, name, fmt=':f'): 9 | self.name = name 10 | self.fmt = fmt 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | def synchronize(self): 26 | if not is_dist_avail_and_initialized(): 27 | return 28 | t = torch.tensor([self.sum, self.count], dtype=torch.float64, device='cuda') 29 | dist.barrier() 30 | dist.all_reduce(t) 31 | t = t.tolist() 32 | self.sum = int(t[0]) 33 | self.count = t[1] 34 | self.avg = self.sum / self.count 35 | 36 | def __str__(self): 37 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 38 | return fmtstr.format(**self.__dict__) 39 | 40 | def get_model(model): 41 | if isinstance(model, torch.nn.DataParallel) \ 42 | or isinstance(model, torch.nn.parallel.DistributedDataParallel): 43 | return model.module 44 | else: 45 | return model 46 | def is_dist_avail_and_initialized(): 47 | if not dist.is_available(): 48 | return False 49 | if not dist.is_initialized(): 50 | return False 51 | return True 52 | 53 | 54 | def get_world_size(): 55 | if not is_dist_avail_and_initialized(): 56 | return 1 57 | return dist.get_world_size() 58 | 59 | 60 | def get_rank(): 61 | if not is_dist_avail_and_initialized(): 62 | return 0 63 | return dist.get_rank() 64 | 65 | 66 | def is_main_process(): 67 | return get_rank() == 0 68 | 69 | class ProgressMeter(object): 70 | def __init__(self, num_batches, meters, prefix=""): 71 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 72 | self.meters = meters 73 | self.prefix = prefix 74 | 75 | def display(self, batch): 76 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 77 | entries += [str(meter) for meter in self.meters] 78 | print('\t'.join(entries)) 79 | 80 | def synchronize(self): 81 | for meter in self.meters: 82 | meter.synchronize() 83 | 84 | def _get_batch_fmtstr(self, num_batches): 85 | num_digits = len(str(num_batches // 1)) 86 | fmt = '{:' + str(num_digits) + 'd}' 87 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 88 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): 89 | warmup_schedule = np.array([]) 90 | warmup_iters = warmup_epochs * niter_per_ep 91 | if warmup_epochs > 0: 92 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 93 | 94 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 95 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 96 | 97 | schedule = np.concatenate((warmup_schedule, schedule)) 98 | assert len(schedule) == epochs * niter_per_ep 99 | return schedule 100 | 101 | def scaled_all_reduce(tensors, is_scale=True): 102 | """Performs the scaled all_reduce operation on the provided tensors. 103 | The input tensors are modified in-place. Currently supports only the sum 104 | reduction operator. The reduced values are scaled by the inverse size of the 105 | world size. 
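Illustrative usage (assumed multi-GPU setup, not from the original docstring): `mean_loss, = scaled_all_reduce([loss.clone()])` returns the loss averaged over all ranks (`.clone()` because the inputs are modified in-place); pass `is_scale=False` to keep the plain summed reduction.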
106 | """ 107 | world_size = get_world_size() 108 | # There is no need for reduction in the single-proc case 109 | if world_size == 1: 110 | return tensors 111 | # Queue the reductions 112 | reductions = [] 113 | for tensor in tensors: 114 | reduction = dist.all_reduce(tensor, async_op=True) 115 | reductions.append(reduction) 116 | # Wait for reductions to finish 117 | for reduction in reductions: 118 | reduction.wait() 119 | # Scale the results 120 | if is_scale: 121 | for tensor in tensors: 122 | tensor.mul_(1.0 / world_size) 123 | return tensors 124 | def save_on_master(state, is_best, output_dir): 125 | if is_main_process(): 126 | ckpt_path = '{}/checkpoint_{}.pt'.format(output_dir, state['epoch']) 127 | best_path = f'{output_dir}/checkpoint_best.pt' 128 | torch.save(state, ckpt_path) 129 | if is_best: 130 | shutil.copyfile(ckpt_path, best_path) 131 | if os.path.exists(ckpt_path): 132 | os.remove(ckpt_path) 133 | 134 | def all_gather_batch(tensors): 135 | """ 136 | Performs all_gather operation on the provided tensors. 137 | """ 138 | # Queue the gathered tensors 139 | world_size = get_world_size() 140 | # There is no need for reduction in the single-proc case 141 | if world_size == 1: 142 | return tensors 143 | tensor_list = [] 144 | output_tensor = [] 145 | for tensor in tensors: 146 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] 147 | dist.all_gather( 148 | tensor_all, 149 | tensor, 150 | async_op=False # performance opt 151 | ) 152 | 153 | tensor_list.append(tensor_all) 154 | 155 | for tensor_all in tensor_list: 156 | output_tensor.append(torch.cat(tensor_all, dim=0)) 157 | return output_tensor -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model_from_cfg 2 | import models.Point_MAE -------------------------------------------------------------------------------- /src/models/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /src/models/build.py: -------------------------------------------------------------------------------- 1 | from utils import registry 2 | 3 | 4 | MODELS = registry.Registry('models') 5 | 6 | 7 | def build_model_from_cfg(cfg, **kwargs): 8 | """ 9 | Build a dataset, defined by `dataset_name`. 10 | Args: 11 | cfg (eDICT): 12 | Returns: 13 | Dataset: a constructed dataset specified by dataset_name. 14 | """ 15 | return MODELS.build(cfg, **kwargs) 16 | 17 | 18 | -------------------------------------------------------------------------------- /src/models/clip_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import math 3 | from typing import Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from timm.models.layers import drop_path 10 | 11 | 12 | class Bottleneck(nn.Module): 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, stride=1): 16 | super().__init__() 17 | 18 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 19 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | 22 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | 30 | self.relu = nn.ReLU(inplace=True) 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu(self.bn1(self.conv1(x))) 46 | out = self.relu(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | self.embed_dim = embed_dim 68 | self.spacial_dim = spacial_dim 69 | 70 | def forward(self, x): 71 | B, C, H, W = x.shape 72 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 73 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 74 | 75 | cls_pos = self.positional_embedding[0:1, :] 76 | spatial_pos = F.interpolate(self.positional_embedding[1:,].reshape(1, self.spacial_dim, self.spacial_dim, self.embed_dim).permute(0, 3, 1, 2), size=(H, W), mode='bilinear') 77 | spatial_pos = spatial_pos.reshape(self.embed_dim, H*W).permute(1, 0) 78 | positional_embedding = torch.cat([cls_pos, spatial_pos], dim=0) 79 | 80 | x = x + positional_embedding[:, None, :] 81 | x, _ = F.multi_head_attention_forward( 82 | query=x, key=x, value=x, 83 | embed_dim_to_check=x.shape[-1], 84 | num_heads=self.num_heads, 85 | q_proj_weight=self.q_proj.weight, 86 | k_proj_weight=self.k_proj.weight, 87 | v_proj_weight=self.v_proj.weight, 88 | in_proj_weight=None, 89 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 90 | bias_k=None, 91 | bias_v=None, 92 | add_zero_attn=False, 93 | dropout_p=0, 94 | out_proj_weight=self.c_proj.weight, 95 | out_proj_bias=self.c_proj.bias, 96 | use_separate_proj_weight=True, 97 | training=self.training, 98 | need_weights=False 99 | ) 100 | 101 | x = x.permute(1, 2, 0) 102 | global_feat = x[:, :, 0] 103 | feature_map = x[:, :, 1:].reshape(B, -1, H, W) 104 | return global_feat, feature_map 105 | 106 | 107 | class LayerNorm(nn.LayerNorm): 108 | """Subclass torch's LayerNorm to handle fp16.""" 109 | 110 | def forward(self, x: 
torch.Tensor): 111 | orig_type = x.dtype 112 | ret = super().forward(x.type(torch.float32)) 113 | return ret.type(orig_type) 114 | 115 | 116 | class QuickGELU(nn.Module): 117 | def forward(self, x: torch.Tensor): 118 | return x * torch.sigmoid(1.702 * x) 119 | 120 | 121 | class DropPath(nn.Module): 122 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 123 | """ 124 | def __init__(self, drop_prob=None): 125 | super(DropPath, self).__init__() 126 | self.drop_prob = drop_prob 127 | 128 | def forward(self, x): 129 | return drop_path(x, self.drop_prob, self.training) 130 | 131 | def extra_repr(self) -> str: 132 | return 'p={}'.format(self.drop_prob) 133 | 134 | 135 | class ResidualAttentionBlock(nn.Module): 136 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, drop_path=0.): 137 | super().__init__() 138 | 139 | self.attn = nn.MultiheadAttention(d_model, n_head) 140 | self.ln_1 = LayerNorm(d_model) 141 | self.mlp = nn.Sequential(OrderedDict([ 142 | ("c_fc", nn.Linear(d_model, d_model * 4)), 143 | ("gelu", QuickGELU()), 144 | ("c_proj", nn.Linear(d_model * 4, d_model)) 145 | ])) 146 | self.ln_2 = LayerNorm(d_model) 147 | self.attn_mask = attn_mask 148 | 149 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 150 | 151 | def attention(self, x: torch.Tensor): 152 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 153 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 154 | 155 | def attention_weight(self, x: torch.Tensor): 156 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 157 | return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)[1] 158 | 159 | def forward(self, x: torch.Tensor, return_attention: bool=False): 160 | x = x + self.attention(self.ln_1(x)) 161 | x = x + self.mlp(self.ln_2(x)) 162 | return x 163 | 164 | class Transformer(nn.Module): 165 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, drop_path_rate=0.): 166 | super().__init__() 167 | 168 | self.width = width 169 | self.layers = layers 170 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, layers)] # stochastic depth decay rule 171 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) for i in range(layers)]) 172 | 173 | def forward(self, x: torch.Tensor): 174 | return self.resblocks(x) 175 | 176 | # ADDED 177 | def forward_attention(self, x: torch.Tensor): 178 | for index, layer in enumerate(self.resblocks): 179 | if index == len(self.resblocks) - 1: 180 | return layer(x, return_attention=True) 181 | x = layer(x) 182 | 183 | 184 | class Attention(nn.Module): 185 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 186 | super().__init__() 187 | self.num_heads = num_heads 188 | head_dim = dim // num_heads 189 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 190 | self.scale = qk_scale or head_dim ** -0.5 191 | 192 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) 193 | self.k_proj = nn.Linear(dim, dim, bias=qkv_bias) 194 | self.v_proj = nn.Linear(dim, dim, bias=qkv_bias) 195 | 196 | 197 | self.attn_drop = nn.Dropout(attn_drop) 198 | self.proj = nn.Linear(dim, dim) 199 | self.proj_drop = nn.Dropout(proj_drop) 200 | 201 | def forward(self, q, k, v): 202 | B, N, C = q.shape 203 | assert k.shape 
== v.shape 204 | B, M, C = k.shape 205 | q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads) 206 | k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads) 207 | v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads) 208 | 209 | attn = torch.einsum('bnkc,bmkc->bknm', q, k) * self.scale 210 | 211 | attn = attn.softmax(dim=-1) 212 | 213 | x = torch.einsum('bknm,bmkc->bnkc', attn, v).reshape(B, N, C) 214 | 215 | x = self.proj(x) 216 | x = self.proj_drop(x) 217 | return x 218 | -------------------------------------------------------------------------------- /src/models/encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 6 | # From : https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152'] 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | "3x3 convolution with padding" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = nn.BatchNorm2d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = nn.BatchNorm2d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None): 62 | super(Bottleneck, self).__init__() 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 66 | padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * 4) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class ResNet(nn.Module): 98 | 99 | def __init__(self, block, 
layers, num_classes=1000): 100 | self.inplanes = 64 101 | super(ResNet, self).__init__() 102 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 103 | bias=False) 104 | self.bn1 = nn.BatchNorm2d(64) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 107 | self.layer1 = self._make_layer(block, 64, layers[0]) 108 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 109 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 110 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 111 | self.avgpool = nn.AvgPool2d(7) 112 | self.fc = nn.Linear(512 * block.expansion, num_classes) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. / n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def _make_layer(self, block, planes, blocks, stride=1): 123 | downsample = None 124 | if stride != 1 or self.inplanes != planes * block.expansion: 125 | downsample = nn.Sequential( 126 | nn.Conv2d(self.inplanes, planes * block.expansion, 127 | kernel_size=1, stride=stride, bias=False), 128 | nn.BatchNorm2d(planes * block.expansion), 129 | ) 130 | 131 | layers = [] 132 | layers.append(block(self.inplanes, planes, stride, downsample)) 133 | self.inplanes = planes * block.expansion 134 | for i in range(1, blocks): 135 | layers.append(block(self.inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.bn1(x) 142 | x = self.relu(x) 143 | x = self.maxpool(x) 144 | 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | 150 | x = self.avgpool(x) 151 | x = x.view(x.size(0), -1) 152 | x = self.fc(x) 153 | 154 | return x 155 | 156 | class ResNetReturnMid(nn.Module): 157 | 158 | def __init__(self, block, layers, num_classes=1000): 159 | self.inplanes = 64 160 | super(ResNetReturnMid, self).__init__() 161 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 162 | bias=False) 163 | self.bn1 = nn.BatchNorm2d(64) 164 | self.relu = nn.ReLU(inplace=True) 165 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 166 | self.layer1 = self._make_layer(block, 64, layers[0]) 167 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 168 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 169 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 170 | self.avgpool = nn.AvgPool2d(7) 171 | self.fc = nn.Linear(512 * block.expansion, num_classes) 172 | 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 176 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 177 | elif isinstance(m, nn.BatchNorm2d): 178 | m.weight.data.fill_(1) 179 | m.bias.data.zero_() 180 | 181 | def _make_layer(self, block, planes, blocks, stride=1): 182 | downsample = None 183 | if stride != 1 or self.inplanes != planes * block.expansion: 184 | downsample = nn.Sequential( 185 | nn.Conv2d(self.inplanes, planes * block.expansion, 186 | kernel_size=1, stride=stride, bias=False), 187 | nn.BatchNorm2d(planes * block.expansion), 188 | ) 189 | 190 | layers = [] 191 | layers.append(block(self.inplanes, planes, stride, downsample)) 192 | self.inplanes = planes * block.expansion 193 | for i in range(1, blocks): 194 | layers.append(block(self.inplanes, planes)) 195 | 196 | return nn.Sequential(*layers) 197 | 198 | def forward(self, x): 199 | x = self.conv1(x) 200 | x = self.bn1(x) 201 | x = self.relu(x) 202 | x = self.maxpool(x) 203 | 204 | x = self.layer1(x) 205 | x = self.layer2(x) 206 | x1 = x 207 | x = self.layer3(x) 208 | x = self.layer4(x) 209 | 210 | x = self.avgpool(x) 211 | x = x.view(x.size(0), -1) 212 | x = self.fc(x) 213 | 214 | return x, x1 215 | 216 | def resnet18(pretrained=False, **kwargs): 217 | """Constructs a ResNet-18 model. 218 | Args: 219 | pretrained (bool): If True, returns a model pre-trained on ImageNet 220 | """ 221 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 222 | if pretrained: 223 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 224 | return model 225 | 226 | def resnet18ReturnMid(pretrained=False, **kwargs): 227 | """Constructs a ResNet-18 model. 228 | Args: 229 | pretrained (bool): If True, returns a model pre-trained on ImageNet 230 | """ 231 | model = ResNetReturnMid(BasicBlock, [2, 2, 2, 2], **kwargs) 232 | if pretrained: 233 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 234 | return model 235 | 236 | def resnet34(pretrained=False, **kwargs): 237 | """Constructs a ResNet-34 model. 238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | """ 241 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 242 | if pretrained: 243 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 244 | return model 245 | 246 | 247 | def resnet50(pretrained=False, **kwargs): 248 | """Constructs a ResNet-50 model. 249 | Args: 250 | pretrained (bool): If True, returns a model pre-trained on ImageNet 251 | """ 252 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 253 | if pretrained: 254 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 255 | return model 256 | 257 | 258 | def resnet101(pretrained=False, **kwargs): 259 | """Constructs a ResNet-101 model. 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | """ 263 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 264 | if pretrained: 265 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 266 | return model 267 | 268 | 269 | def resnet152(pretrained=False, **kwargs): 270 | """Constructs a ResNet-152 model. 
271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | """ 274 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 275 | if pretrained: 276 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 277 | return model 278 | 279 | 280 | -------------------------------------------------------------------------------- /src/models/pointbert/PointTransformer_8192point.yaml: -------------------------------------------------------------------------------- 1 | optimizer : { 2 | type: AdamW, 3 | kwargs: { 4 | lr : 0.0005, 5 | weight_decay : 0.05 6 | }} 7 | 8 | scheduler: { 9 | type: CosLR, 10 | kwargs: { 11 | epochs: 300, 12 | initial_epochs : 10 13 | }} 14 | 15 | model : { 16 | NAME: PointTransformer, 17 | trans_dim: 384, 18 | depth: 12, 19 | drop_path_rate: 0.1, 20 | cls_dim: 40, 21 | num_heads: 6, 22 | group_size: 32, 23 | num_group: 64, 24 | encoder_dims: 256, 25 | } 26 | npoints: 2048 27 | total_bs : 32 28 | step_per_update : 1 29 | max_epoch : 300 30 | grad_norm_clip : 10 31 | 32 | consider_metric: CDL1 -------------------------------------------------------------------------------- /src/models/pointbert/ULIP_2_PointBERT_10k_colored_pointclouds.yaml: -------------------------------------------------------------------------------- 1 | optimizer : { 2 | type: AdamW, 3 | kwargs: { 4 | lr : 0.0005, 5 | weight_decay : 0.05 6 | }} 7 | 8 | scheduler: { 9 | type: CosLR, 10 | kwargs: { 11 | epochs: 200, 12 | initial_epochs : 10 13 | }} 14 | 15 | model : { 16 | NAME: PointTransformer, 17 | trans_dim: 384, 18 | depth: 18, 19 | drop_path_rate: 0.1, 20 | cls_dim: 40, 21 | num_heads: 6, 22 | group_size: 32, 23 | num_group: 512, 24 | encoder_dims: 256, 25 | } 26 | npoints: 10000 27 | total_bs : 32 28 | step_per_update : 1 29 | max_epoch : 300 30 | grad_norm_clip : 10 31 | 32 | consider_metric: CDL1 -------------------------------------------------------------------------------- /src/models/pointbert/__pycache__/checkpoint.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/pointbert/__pycache__/checkpoint.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/pointbert/__pycache__/dvae.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/pointbert/__pycache__/dvae.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/pointbert/__pycache__/logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/pointbert/__pycache__/logger.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/pointbert/__pycache__/misc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/pointbert/__pycache__/misc.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/pointbert/__pycache__/point_encoder.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/pointbert/__pycache__/point_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /src/models/pointbert/checkpoint.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch.nn as nn 3 | 4 | from typing import Any 5 | from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable 6 | 7 | from termcolor import colored 8 | 9 | def get_missing_parameters_message(keys: List[str]) -> str: 10 | """ 11 | Get a logging-friendly message to report parameter names (keys) that are in 12 | the model but not found in a checkpoint. 13 | Args: 14 | keys (list[str]): List of keys that were not found in the checkpoint. 15 | Returns: 16 | str: message. 17 | """ 18 | groups = _group_checkpoint_keys(keys) 19 | msg = "Some model parameters or buffers are not found in the checkpoint:\n" 20 | msg += "\n".join( 21 | " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items() 22 | ) 23 | return msg 24 | 25 | 26 | def get_unexpected_parameters_message(keys: List[str]) -> str: 27 | """ 28 | Get a logging-friendly message to report parameter names (keys) that are in 29 | the checkpoint but not found in the model. 30 | Args: 31 | keys (list[str]): List of keys that were not found in the model. 32 | Returns: 33 | str: message. 34 | """ 35 | groups = _group_checkpoint_keys(keys) 36 | msg = "The checkpoint state_dict contains keys that are not used by the model:\n" 37 | msg += "\n".join( 38 | " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items() 39 | ) 40 | return msg 41 | 42 | 43 | def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None: 44 | """ 45 | Strip the prefix in metadata, if any. 46 | Args: 47 | state_dict (OrderedDict): a state-dict to be loaded to the model. 48 | prefix (str): prefix. 49 | """ 50 | keys = sorted(state_dict.keys()) 51 | if not all(len(key) == 0 or key.startswith(prefix) for key in keys): 52 | return 53 | 54 | for key in keys: 55 | newkey = key[len(prefix):] 56 | state_dict[newkey] = state_dict.pop(key) 57 | 58 | # also strip the prefix in metadata, if any.. 59 | try: 60 | metadata = state_dict._metadata # pyre-ignore 61 | except AttributeError: 62 | pass 63 | else: 64 | for key in list(metadata.keys()): 65 | # for the metadata dict, the key can be: 66 | # '': for the DDP module, which we want to remove. 67 | # 'module': for the actual model. 68 | # 'module.xx.xx': for the rest. 69 | 70 | if len(key) == 0: 71 | continue 72 | newkey = key[len(prefix):] 73 | metadata[newkey] = metadata.pop(key) 74 | 75 | 76 | def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]: 77 | """ 78 | Group keys based on common prefixes. A prefix is the string up to the final 79 | "." in each key. 80 | Args: 81 | keys (list[str]): list of parameter names, i.e. keys in the model 82 | checkpoint dict. 83 | Returns: 84 | dict[list]: keys with common prefixes are grouped into lists. 
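Illustrative example (added for clarity): ["encoder.conv1.weight", "encoder.conv1.bias", "cls_token"] is grouped as {"encoder.conv1": ["weight", "bias"], "cls_token": []}.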
85 | """ 86 | groups = defaultdict(list) 87 | for key in keys: 88 | pos = key.rfind(".") 89 | if pos >= 0: 90 | head, tail = key[:pos], [key[pos + 1:]] 91 | else: 92 | head, tail = key, [] 93 | groups[head].extend(tail) 94 | return groups 95 | 96 | 97 | def _group_to_str(group: List[str]) -> str: 98 | """ 99 | Format a group of parameter name suffixes into a loggable string. 100 | Args: 101 | group (list[str]): list of parameter name suffixes. 102 | Returns: 103 | str: formated string. 104 | """ 105 | if len(group) == 0: 106 | return "" 107 | 108 | if len(group) == 1: 109 | return "." + group[0] 110 | 111 | return ".{" + ", ".join(group) + "}" 112 | 113 | 114 | def _named_modules_with_dup( 115 | model: nn.Module, prefix: str = "" 116 | ) -> Iterable[Tuple[str, nn.Module]]: 117 | """ 118 | The same as `model.named_modules()`, except that it includes 119 | duplicated modules that have more than one name. 120 | """ 121 | yield prefix, model 122 | for name, module in model._modules.items(): # pyre-ignore 123 | if module is None: 124 | continue 125 | submodule_prefix = prefix + ("." if prefix else "") + name 126 | yield from _named_modules_with_dup(module, submodule_prefix) -------------------------------------------------------------------------------- /src/models/pointbert/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | logger_initialized = {} 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'): 7 | """Get root logger and add a keyword filter to it. 8 | The logger will be initialized if it has not been initialized. By default a 9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 10 | also be added. The name of the root logger is the top-level package name, 11 | e.g., "mmdet3d". 12 | Args: 13 | log_file (str, optional): File path of log. Defaults to None. 14 | log_level (int, optional): The level of logger. 15 | Defaults to logging.INFO. 16 | name (str, optional): The name of the root logger, also used as a 17 | filter keyword. Defaults to 'mmdet3d'. 18 | Returns: 19 | :obj:`logging.Logger`: The obtained logger 20 | """ 21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 22 | # add a logging filter 23 | logging_filter = logging.Filter(name) 24 | logging_filter.filter = lambda record: record.find(name) != -1 25 | 26 | return logger 27 | 28 | 29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 30 | """Initialize and get a logger by name. 31 | If the logger has not been initialized, this method will initialize the 32 | logger by adding one or two handlers, otherwise the initialized logger will 33 | be directly returned. During initialization, a StreamHandler will always be 34 | added. If `log_file` is specified and the process rank is 0, a FileHandler 35 | will also be added. 36 | Args: 37 | name (str): Logger name. 38 | log_file (str | None): The log filename. If specified, a FileHandler 39 | will be added to the logger. 40 | log_level (int): The logger level. Note that only the process of 41 | rank 0 is affected, and other processes will set the level to 42 | "Error" thus be silent most of the time. 43 | file_mode (str): The file mode used in opening log file. 44 | Defaults to 'w'. 45 | Returns: 46 | logging.Logger: The expected logger. 
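Illustrative call (names and paths are placeholders): `logger = get_logger('MESC-3D', log_file='train.log')` attaches a StreamHandler on every rank and, on rank 0 only, a FileHandler writing to 'train.log'.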
47 | """ 48 | logger = logging.getLogger(name) 49 | if name in logger_initialized: 50 | return logger 51 | # handle hierarchical names 52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 53 | # initialization since it is a child of "a". 54 | for logger_name in logger_initialized: 55 | if name.startswith(logger_name): 56 | return logger 57 | 58 | # handle duplicate logs to the console 59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 60 | # to the root logger. As logger.propagate is True by default, this root 61 | # level handler causes logging messages from rank>0 processes to 62 | # unexpectedly show up on the console, creating much unwanted clutter. 63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 64 | # at the ERROR level. 65 | for handler in logger.root.handlers: 66 | if type(handler) is logging.StreamHandler: 67 | handler.setLevel(logging.ERROR) 68 | 69 | stream_handler = logging.StreamHandler() 70 | handlers = [stream_handler] 71 | 72 | if dist.is_available() and dist.is_initialized(): 73 | rank = dist.get_rank() 74 | else: 75 | rank = 0 76 | 77 | # only rank 0 will add a FileHandler 78 | if rank == 0 and log_file is not None: 79 | # Here, the default behaviour of the official logger is 'a'. Thus, we 80 | # provide an interface to change the file mode to the default 81 | # behaviour. 82 | file_handler = logging.FileHandler(log_file, file_mode) 83 | handlers.append(file_handler) 84 | 85 | formatter = logging.Formatter( 86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 87 | for handler in handlers: 88 | handler.setFormatter(formatter) 89 | handler.setLevel(log_level) 90 | logger.addHandler(handler) 91 | 92 | if rank == 0: 93 | logger.setLevel(log_level) 94 | else: 95 | logger.setLevel(logging.ERROR) 96 | 97 | logger_initialized[name] = True 98 | 99 | 100 | return logger 101 | 102 | 103 | def print_log(msg, logger=None, level=logging.INFO): 104 | """Print a log message. 105 | Args: 106 | msg (str): The message to be logged. 107 | logger (logging.Logger | str | None): The logger to be used. 108 | Some special loggers are: 109 | - "silent": no message will be printed. 110 | - other str: the logger obtained with `get_root_logger(logger)`. 111 | - None: The `print()` method will be used to print log messages. 112 | level (int): Logging level. Only available when `logger` is a Logger 113 | object or "root". 
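For example, `print_log('training done')` falls back to `print()`, `print_log('training done', logger='silent')` is suppressed, and passing a `logging.Logger` object routes the message through that logger at `level`.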
114 | """ 115 | if logger is None: 116 | print(msg) 117 | elif isinstance(logger, logging.Logger): 118 | logger.log(level, msg) 119 | elif logger == 'silent': 120 | pass 121 | elif isinstance(logger, str): 122 | _logger = get_logger(logger) 123 | _logger.log(level, msg) 124 | else: 125 | raise TypeError( 126 | 'logger should be either a logging.Logger object, str, ' 127 | f'"silent" or None, but got {type(logger)}') -------------------------------------------------------------------------------- /src/models/pointbert/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import os 9 | from collections import abc 10 | from pointnet2_ops import pointnet2_utils 11 | 12 | 13 | # def fps(data, number): 14 | # ''' 15 | # data B N 3 16 | # number int 17 | # ''' 18 | # fps_idx = pointnet2_utils.furthest_point_sample(data, number) 19 | # fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous() 20 | # return fps_data 21 | 22 | def index_points(points, idx): 23 | """ 24 | Input: 25 | points: input points data, [B, N, C] 26 | idx: sample index data, [B, S] 27 | Return: 28 | new_points:, indexed points data, [B, S, C] 29 | """ 30 | device = points.device 31 | B = points.shape[0] 32 | view_shape = list(idx.shape) 33 | view_shape[1:] = [1] * (len(view_shape) - 1) 34 | repeat_shape = list(idx.shape) 35 | repeat_shape[0] = 1 36 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 37 | new_points = points[batch_indices, idx, :] 38 | return new_points 39 | 40 | # def fps(xyz, npoint): 41 | # """ 42 | # Input: 43 | # xyz: pointcloud data, [B, N, 3] 44 | # npoint: number of samples 45 | # Return: 46 | # centroids: sampled pointcloud index, [B, npoint] 47 | # """ 48 | # device = xyz.device 49 | # B, N, C = xyz.shape 50 | # centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 51 | # distance = torch.ones(B, N).to(device) * 1e10 52 | # farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 53 | # batch_indices = torch.arange(B, dtype=torch.long).to(device) 54 | # for i in range(npoint): 55 | # centroids[:, i] = farthest 56 | # centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 57 | # dist = torch.sum((xyz - centroid) ** 2, -1) 58 | # distance = torch.min(distance, dist) 59 | # farthest = torch.max(distance, -1)[1] 60 | # return index_points(xyz, centroids) 61 | 62 | def fps(data, number): 63 | ''' 64 | data B N 3 65 | number int 66 | ''' 67 | fps_idx = pointnet2_utils.furthest_point_sample(data, number) 68 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous() 69 | return fps_data 70 | 71 | def worker_init_fn(worker_id): 72 | np.random.seed(np.random.get_state()[1][0] + worker_id) 73 | 74 | def build_lambda_sche(opti, config): 75 | if config.get('decay_step') is not None: 76 | lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay) 77 | scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd) 78 | else: 79 | raise NotImplementedError() 80 | return scheduler 81 | 82 | def build_lambda_bnsche(model, config): 83 | if config.get('decay_step') is not None: 84 | bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), 
config.lowest_decay) 85 | bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd) 86 | else: 87 | raise NotImplementedError() 88 | return bnm_scheduler 89 | 90 | def set_random_seed(seed, deterministic=False): 91 | """Set random seed. 92 | Args: 93 | seed (int): Seed to be used. 94 | deterministic (bool): Whether to set the deterministic option for 95 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 96 | to True and `torch.backends.cudnn.benchmark` to False. 97 | Default: False. 98 | 99 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 100 | if cuda_deterministic: # slower, more reproducible 101 | cudnn.deterministic = True 102 | cudnn.benchmark = False 103 | else: # faster, less reproducible 104 | cudnn.deterministic = False 105 | cudnn.benchmark = True 106 | 107 | """ 108 | random.seed(seed) 109 | np.random.seed(seed) 110 | torch.manual_seed(seed) 111 | torch.cuda.manual_seed_all(seed) 112 | if deterministic: 113 | torch.backends.cudnn.deterministic = True 114 | torch.backends.cudnn.benchmark = False 115 | 116 | 117 | def is_seq_of(seq, expected_type, seq_type=None): 118 | """Check whether it is a sequence of some type. 119 | Args: 120 | seq (Sequence): The sequence to be checked. 121 | expected_type (type): Expected type of sequence items. 122 | seq_type (type, optional): Expected sequence type. 123 | Returns: 124 | bool: Whether the sequence is valid. 125 | """ 126 | if seq_type is None: 127 | exp_seq_type = abc.Sequence 128 | else: 129 | assert isinstance(seq_type, type) 130 | exp_seq_type = seq_type 131 | if not isinstance(seq, exp_seq_type): 132 | return False 133 | for item in seq: 134 | if not isinstance(item, expected_type): 135 | return False 136 | return True 137 | 138 | 139 | def set_bn_momentum_default(bn_momentum): 140 | def fn(m): 141 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 142 | m.momentum = bn_momentum 143 | return fn 144 | 145 | class BNMomentumScheduler(object): 146 | 147 | def __init__( 148 | self, model, bn_lambda, last_epoch=-1, 149 | setter=set_bn_momentum_default 150 | ): 151 | if not isinstance(model, nn.Module): 152 | raise RuntimeError( 153 | "Class '{}' is not a PyTorch nn Module".format( 154 | type(model).__name__ 155 | ) 156 | ) 157 | 158 | self.model = model 159 | self.setter = setter 160 | self.lmbd = bn_lambda 161 | 162 | self.step(last_epoch + 1) 163 | self.last_epoch = last_epoch 164 | 165 | def step(self, epoch=None): 166 | if epoch is None: 167 | epoch = self.last_epoch + 1 168 | 169 | self.last_epoch = epoch 170 | self.model.apply(self.setter(self.lmbd(epoch))) 171 | 172 | def get_momentum(self, epoch=None): 173 | if epoch is None: 174 | epoch = self.last_epoch + 1 175 | return self.lmbd(epoch) 176 | 177 | 178 | 179 | def seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False): 180 | ''' 181 | seprate point cloud: usage : using to generate the incomplete point cloud with a setted number. 
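In short, it drops the `crop` points closest to a random (or fixed) viewpoint and returns (remaining points, cropped points); an illustrative call, assuming CUDA tensors of shape (B, 2048, 3): `partial, missing = seprate_point_cloud(xyz, 2048, 512)`.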
182 | ''' 183 | _,n,c = xyz.shape 184 | 185 | assert n == num_points 186 | assert c == 3 187 | if crop == num_points: 188 | return xyz, None 189 | 190 | INPUT = [] 191 | CROP = [] 192 | for points in xyz: 193 | if isinstance(crop,list): 194 | num_crop = random.randint(crop[0],crop[1]) 195 | else: 196 | num_crop = crop 197 | 198 | points = points.unsqueeze(0) 199 | 200 | if fixed_points is None: 201 | center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda() 202 | else: 203 | if isinstance(fixed_points,list): 204 | fixed_point = random.sample(fixed_points,1)[0] 205 | else: 206 | fixed_point = fixed_points 207 | center = fixed_point.reshape(1,1,3).cuda() 208 | 209 | distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048 210 | 211 | idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048 212 | 213 | if padding_zeros: 214 | input_data = points.clone() 215 | input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0 216 | 217 | else: 218 | input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3 219 | 220 | crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0) 221 | 222 | if isinstance(crop,list): 223 | INPUT.append(fps(input_data,2048)) 224 | CROP.append(fps(crop_data,2048)) 225 | else: 226 | INPUT.append(input_data) 227 | CROP.append(crop_data) 228 | 229 | input_data = torch.cat(INPUT,dim=0)# B N 3 230 | crop_data = torch.cat(CROP,dim=0)# B M 3 231 | 232 | return input_data.contiguous(), crop_data.contiguous() 233 | 234 | def get_ptcloud_img(ptcloud): 235 | fig = plt.figure(figsize=(8, 8)) 236 | 237 | x, z, y = ptcloud.transpose(1, 0) 238 | ax = fig.gca(projection=Axes3D.name, adjustable='box') 239 | ax.axis('off') 240 | # ax.axis('scaled') 241 | ax.view_init(30, 45) 242 | max, min = np.max(ptcloud), np.min(ptcloud) 243 | ax.set_xbound(min, max) 244 | ax.set_ybound(min, max) 245 | ax.set_zbound(min, max) 246 | ax.scatter(x, y, z, zdir='z', c=x, cmap='jet') 247 | 248 | fig.canvas.draw() 249 | img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 250 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, )) 251 | return img 252 | 253 | 254 | 255 | def visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y', 256 | xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ): 257 | fig = plt.figure(figsize=(6*len(data_list),6)) 258 | cmax = data_list[-1][:,0].max() 259 | 260 | for i in range(len(data_list)): 261 | data = data_list[i][:-2048] if i == 1 else data_list[i] 262 | color = data[:,0] /cmax 263 | ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d') 264 | ax.view_init(30, -120) 265 | b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color,vmin=-1,vmax=1 ,cmap = cmap[0],s=4,linewidth=0.05, edgecolors = 'black') 266 | ax.set_title(titles[i]) 267 | 268 | ax.set_axis_off() 269 | ax.set_xlim(xlim) 270 | ax.set_ylim(ylim) 271 | ax.set_zlim(zlim) 272 | plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0) 273 | if not os.path.exists(path): 274 | os.makedirs(path) 275 | 276 | pic_path = path + '.png' 277 | fig.savefig(pic_path) 278 | 279 | np.save(os.path.join(path, 'input.npy'), data_list[0].numpy()) 280 | np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy()) 281 | plt.close(fig) 282 | 283 | 284 | def random_dropping(pc, e): 285 | up_num = max(64, 768 // (e//50 + 1)) 286 | pc = pc 287 | random_num = torch.randint(1, up_num, (1,1))[0,0] 288 | pc = fps(pc, random_num) 289 | padding = torch.zeros(pc.size(0), 2048 - 
pc.size(1), 3).to(pc.device) 290 | pc = torch.cat([pc, padding], dim = 1) 291 | return pc 292 | 293 | 294 | def random_scale(partial, scale_range=[0.8, 1.2]): 295 | scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0] 296 | return partial * scale 297 | -------------------------------------------------------------------------------- /src/models/pointbert/point_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import DropPath 5 | from models.pointbert.dvae import Group 6 | from models.pointbert.dvae import Encoder 7 | from models.pointbert.logger import print_log 8 | 9 | from models.pointbert.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message 10 | 11 | def cal_model_parm_nums(model): 12 | total = sum([param.nelement() for param in model.parameters()]) 13 | 14 | if total >= 1e9: 15 | return "{:.2f}B".format(total / 1e9) 16 | elif total >= 1e6: 17 | return "{:.2f}M".format(total / 1e6) 18 | elif total >= 1e3: 19 | return "{:.2f}K".format(total / 1e3) 20 | else: 21 | return str(total) 22 | 23 | class Mlp(nn.Module): 24 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 25 | super().__init__() 26 | out_features = out_features or in_features 27 | hidden_features = hidden_features or in_features 28 | self.fc1 = nn.Linear(in_features, hidden_features) 29 | self.act = act_layer() 30 | self.fc2 = nn.Linear(hidden_features, out_features) 31 | self.drop = nn.Dropout(drop) 32 | 33 | def forward(self, x): 34 | x = self.fc1(x) 35 | x = self.act(x) 36 | x = self.drop(x) 37 | x = self.fc2(x) 38 | x = self.drop(x) 39 | return x 40 | 41 | 42 | class Attention(nn.Module): 43 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 44 | super().__init__() 45 | self.num_heads = num_heads 46 | head_dim = dim // num_heads 47 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 48 | self.scale = qk_scale or head_dim ** -0.5 49 | 50 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 51 | self.attn_drop = nn.Dropout(attn_drop) 52 | self.proj = nn.Linear(dim, dim) 53 | self.proj_drop = nn.Dropout(proj_drop) 54 | 55 | def forward(self, x): 56 | B, N, C = x.shape 57 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 58 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 59 | 60 | attn = (q @ k.transpose(-2, -1)) * self.scale 61 | attn = attn.softmax(dim=-1) 62 | attn = self.attn_drop(attn) 63 | 64 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 65 | x = self.proj(x) 66 | x = self.proj_drop(x) 67 | return x 68 | 69 | 70 | class Block(nn.Module): 71 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 72 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 73 | super().__init__() 74 | self.norm1 = norm_layer(dim) 75 | 76 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 77 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 78 | self.norm2 = norm_layer(dim) 79 | mlp_hidden_dim = int(dim * mlp_ratio) 80 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 81 | 82 | self.attn = Attention( 83 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 84 | 85 | def forward(self, x): 86 | x = x + self.drop_path(self.attn(self.norm1(x))) 87 | x = x + self.drop_path(self.mlp(self.norm2(x))) 88 | return x 89 | 90 | 91 | class TransformerEncoder(nn.Module): 92 | """ Transformer Encoder without hierarchical structure 93 | """ 94 | 95 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, 96 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.): 97 | super().__init__() 98 | 99 | self.blocks = nn.ModuleList([ 100 | Block( 101 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 102 | drop=drop_rate, attn_drop=attn_drop_rate, 103 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate 104 | ) 105 | for i in range(depth)]) 106 | 107 | def forward(self, x, pos): 108 | for _, block in enumerate(self.blocks): 109 | x = block(x + pos) 110 | return x 111 | 112 | 113 | class PointTransformer(nn.Module): 114 | def __init__(self, config, **kwargs): 115 | super().__init__() 116 | self.config = config 117 | self.args = kwargs["args"] 118 | 119 | self.trans_dim = config.trans_dim 120 | self.depth = config.depth 121 | self.drop_path_rate = config.drop_path_rate 122 | self.cls_dim = config.cls_dim 123 | self.num_heads = config.num_heads 124 | 125 | self.group_size = config.group_size 126 | self.num_group = config.num_group 127 | # grouper 128 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size) 129 | # define the encoder 130 | self.encoder_dims = config.encoder_dims 131 | self.encoder = Encoder(encoder_channel=self.encoder_dims) 132 | # bridge encoder and transformer 133 | self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim) 134 | 135 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim)) 136 | self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim)) 137 | 138 | self.pos_embed = nn.Sequential( 139 | nn.Linear(3, 128), 140 | nn.GELU(), 141 | nn.Linear(128, self.trans_dim) 142 | ) 143 | 144 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 145 | self.blocks = TransformerEncoder( 146 | embed_dim=self.trans_dim, 147 | depth=self.depth, 148 | drop_path_rate=dpr, 149 | num_heads=self.num_heads 150 | ) 151 | 152 | self.norm = nn.LayerNorm(self.trans_dim) 153 | # self.load_model_from_ckpt('/export/home/repos/SLIP/pretrained_models/point_transformer_8192.pt') 154 | if not self.args.evaluate_3d: 155 | self.load_model_from_ckpt('Path/checkpoints/initialize_models_point_bert_pretrained.pt') 156 | 157 | # self.cls_head_finetune = nn.Sequential( 158 | # nn.Linear(self.trans_dim * 2, 256), 159 | # nn.ReLU(inplace=True), 160 | # nn.Dropout(0.5), 161 | # nn.Linear(256, self.cls_dim) 162 | # ) 163 | 164 | # self.build_loss_func() 165 | 166 | def build_loss_func(self): 167 | self.loss_ce = nn.CrossEntropyLoss() 168 | 169 | def get_loss_acc(self, pred, gt, smoothing=True): 170 | # import pdb; pdb.set_trace() 171 | gt = gt.contiguous().view(-1).long() 172 | 173 | if smoothing: 174 | eps = 0.2 175 | n_class = pred.size(1) 176 | 177 | one_hot = torch.zeros_like(pred).scatter(1, gt.view(-1, 1), 1) 178 | one_hot = one_hot * (1 - eps) + (1 - 
one_hot) * eps / (n_class - 1) 179 | log_prb = F.log_softmax(pred, dim=1) 180 | 181 | loss = -(one_hot * log_prb).sum(dim=1).mean() 182 | else: 183 | loss = self.loss_ce(pred, gt.long()) 184 | 185 | pred = pred.argmax(-1) 186 | acc = (pred == gt).sum() / float(gt.size(0)) 187 | 188 | return loss, acc * 100 189 | 190 | def load_model_from_ckpt(self, bert_ckpt_path): 191 | ckpt = torch.load(bert_ckpt_path) 192 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()} 193 | for k in list(base_ckpt.keys()): 194 | if k.startswith('transformer_q') and not k.startswith('transformer_q.cls_head'): 195 | base_ckpt[k[len('transformer_q.'):]] = base_ckpt[k] 196 | elif k.startswith('base_model'): 197 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k] 198 | del base_ckpt[k] 199 | 200 | incompatible = self.load_state_dict(base_ckpt, strict=False) 201 | 202 | if incompatible.missing_keys: 203 | print_log('missing_keys', logger='Transformer') 204 | print_log( 205 | get_missing_parameters_message(incompatible.missing_keys), 206 | logger='Transformer' 207 | ) 208 | if incompatible.unexpected_keys: 209 | print_log('unexpected_keys', logger='Transformer') 210 | print_log( 211 | get_unexpected_parameters_message(incompatible.unexpected_keys), 212 | logger='Transformer' 213 | ) 214 | 215 | print_log(f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}', logger='Transformer') 216 | 217 | def forward(self, pts): 218 | # divide the point cloud in the same form. This is important 219 | neighborhood, center = self.group_divider(pts) 220 | # encoder the input cloud blocks 221 | group_input_tokens = self.encoder(neighborhood) # B G N 222 | group_input_tokens = self.reduce_dim(group_input_tokens) 223 | # prepare cls 224 | cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1) 225 | cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1) 226 | # add pos embedding 227 | pos = self.pos_embed(center) 228 | # final input 229 | x = torch.cat((cls_tokens, group_input_tokens), dim=1) 230 | pos = torch.cat((cls_pos, pos), dim=1) 231 | # transformer 232 | x = self.blocks(x, pos) 233 | x = self.norm(x) 234 | concat_f = torch.cat([x[:, 0], x[:, 1:].max(1)[0]], dim=-1) 235 | # ret = self.cls_head_finetune(concat_f) 236 | return concat_f 237 | 238 | class PointTransformer_Colored(nn.Module): 239 | def __init__(self, config, **kwargs): 240 | super().__init__() 241 | self.config = config 242 | self.args = kwargs["args"] 243 | 244 | self.trans_dim = config.trans_dim 245 | self.depth = config.depth 246 | self.drop_path_rate = config.drop_path_rate 247 | self.cls_dim = config.cls_dim 248 | self.num_heads = config.num_heads 249 | 250 | self.group_size = config.group_size 251 | self.num_group = config.num_group 252 | # grouper 253 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size) 254 | # define the encoder 255 | self.encoder_dims = config.encoder_dims 256 | self.encoder = Encoder(encoder_channel=self.encoder_dims, input_dim=6) 257 | # bridge encoder and transformer 258 | self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim) 259 | 260 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim)) 261 | self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim)) 262 | 263 | self.pos_embed = nn.Sequential( 264 | nn.Linear(3, 128), 265 | nn.GELU(), 266 | nn.Linear(128, self.trans_dim) 267 | ) 268 | 269 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 270 | self.blocks = TransformerEncoder( 271 | 
embed_dim=self.trans_dim, 272 | depth=self.depth, 273 | drop_path_rate=dpr, 274 | num_heads=self.num_heads 275 | ) 276 | 277 | self.norm = nn.LayerNorm(self.trans_dim) 278 | 279 | print("training from scratch for pointbert.") 280 | 281 | model_size = cal_model_parm_nums(self) 282 | print("model size:") 283 | print(model_size) 284 | 285 | def build_loss_func(self): 286 | self.loss_ce = nn.CrossEntropyLoss() 287 | 288 | def get_loss_acc(self, pred, gt, smoothing=True): 289 | # import pdb; pdb.set_trace() 290 | gt = gt.contiguous().view(-1).long() 291 | 292 | if smoothing: 293 | eps = 0.2 294 | n_class = pred.size(1) 295 | 296 | one_hot = torch.zeros_like(pred).scatter(1, gt.view(-1, 1), 1) 297 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 298 | log_prb = F.log_softmax(pred, dim=1) 299 | 300 | loss = -(one_hot * log_prb).sum(dim=1).mean() 301 | else: 302 | loss = self.loss_ce(pred, gt.long()) 303 | 304 | pred = pred.argmax(-1) 305 | acc = (pred == gt).sum() / float(gt.size(0)) 306 | 307 | return loss, acc * 100 308 | 309 | def load_model_from_ckpt(self, bert_ckpt_path): 310 | ckpt = torch.load(bert_ckpt_path, map_location=torch.device('cpu')) 311 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()} 312 | for k in list(base_ckpt.keys()): 313 | if k.startswith('transformer_q') and not k.startswith('transformer_q.cls_head'): 314 | base_ckpt[k[len('transformer_q.'):]] = base_ckpt[k] 315 | elif k.startswith('base_model'): 316 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k] 317 | del base_ckpt[k] 318 | 319 | incompatible = self.load_state_dict(base_ckpt, strict=False) 320 | 321 | if incompatible.missing_keys: 322 | print_log('missing_keys', logger='Transformer') 323 | print_log( 324 | get_missing_parameters_message(incompatible.missing_keys), 325 | logger='Transformer' 326 | ) 327 | if incompatible.unexpected_keys: 328 | print_log('unexpected_keys', logger='Transformer') 329 | print_log( 330 | get_unexpected_parameters_message(incompatible.unexpected_keys), 331 | logger='Transformer' 332 | ) 333 | 334 | print_log(f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}', logger='Transformer') 335 | 336 | def forward(self, pts): 337 | # divide the point cloud in the same form. 
This is important 338 | neighborhood, center = self.group_divider(pts) 339 | # encoder the input cloud blocks 340 | group_input_tokens = self.encoder(neighborhood) # B G N 341 | group_input_tokens = self.reduce_dim(group_input_tokens) 342 | # prepare cls 343 | cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1) 344 | cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1) 345 | # add pos embedding 346 | pos = self.pos_embed(center) 347 | # final input 348 | x = torch.cat((cls_tokens, group_input_tokens), dim=1) 349 | pos = torch.cat((cls_pos, pos), dim=1) 350 | # transformer 351 | x = self.blocks(x, pos) 352 | x = self.norm(x) 353 | concat_f = torch.cat([x[:, 0], x[:, 1:].max(1)[0]], dim=-1) 354 | # ret = self.cls_head_finetune(concat_f) 355 | return concat_f -------------------------------------------------------------------------------- /src/models/text_encoder_3d.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | # import math 3 | from .clip_utils import * 4 | from torch.nn.parameter import Parameter 5 | import math 6 | from torch.nn import Dropout 7 | from functools import reduce 8 | from operator import mul 9 | class CLIPTextEncoder(nn.Module): 10 | def __init__(self, context_length=77, 11 | vocab_size=49408, 12 | transformer_width=512, 13 | transformer_heads=8, 14 | transformer_layers=12, 15 | embed_dim=512, 16 | out_dim=256, 17 | patch_size = 16, 18 | pretrained=None, **kwargs): 19 | super().__init__() 20 | self.layers = transformer_layers 21 | self.total_d_layer = transformer_layers-1 22 | self.pretrained = pretrained 23 | self.out_indices = [3, 5, 7, 11] 24 | self.num_tokens = 100 25 | self.prompt_dim = embed_dim 26 | self.context_length = context_length 27 | 28 | self.transformer = Transformer( 29 | width=transformer_width, 30 | layers=transformer_layers, 31 | heads=transformer_heads, 32 | attn_mask=self.build_attention_mask() 33 | ) 34 | 35 | self.vocab_size = vocab_size 36 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 37 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 38 | self.ln_final = LayerNorm(transformer_width) 39 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 40 | 41 | self._init_prompt(patch_size, self.num_tokens, self.prompt_dim, self.total_d_layer) 42 | 43 | def _init_prompt(self, patch, num_tokens, prompt_dim, total_d_layer): 44 | patch_size = [] 45 | patch_size.append(patch) 46 | patch_size.append(patch) 47 | val = math.sqrt(6. 
/ float(3 * reduce(mul, patch_size, 1) + prompt_dim)) 48 | 49 | if total_d_layer >= 0: 50 | self.prompt_embeddings = nn.Parameter(torch.zeros(1, num_tokens, prompt_dim)) 51 | # xavier_uniform initialization 52 | nn.init.uniform_(self.prompt_embeddings.data, -val, val) 53 | 54 | if total_d_layer > 0: # noqa 55 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros(total_d_layer, num_tokens, prompt_dim)) 56 | # xavier_uniform initialization 57 | nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val) 58 | 59 | self.prompt_proj = nn.Linear(prompt_dim, prompt_dim) 60 | nn.init.kaiming_normal_(self.prompt_proj.weight, a=0, mode='fan_out') 61 | self.prompt_norm = LayerNorm(prompt_dim, eps=1e-6) 62 | self.prompt_dropout = Dropout(0.1) 63 | 64 | else: # total_d_layer < 0 65 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros(abs(total_d_layer), num_tokens, prompt_dim)) 66 | nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val) 67 | self.prompt_proj = nn.Linear(prompt_dim, prompt_dim) 68 | nn.init.kaiming_normal_(self.prompt_proj.weight, a=0, mode='fan_out') 69 | self.prompt_norm = LayerNorm(prompt_dim, eps=1e-6) 70 | self.prompt_dropout = Dropout(0.1) 71 | 72 | def init_weights(self, pretrained=None): 73 | pretrained = pretrained or self.pretrained 74 | if isinstance(pretrained, str): 75 | checkpoint = torch.jit.load(pretrained, map_location='cpu').float().state_dict() 76 | 77 | state_dict = {} 78 | 79 | for k in checkpoint.keys(): 80 | if k.startswith('transformer.'): 81 | state_dict[k] = checkpoint[k] 82 | 83 | if k == 'positional_embedding' or k == 'text_projection' or k.startswith('token_embedding') or k.startswith('ln_final'): 84 | if k == 'positional_embedding' and checkpoint[k].size(0) > self.context_length: 85 | checkpoint[k] = checkpoint[k][:self.context_length] 86 | print('positional_embedding is tuncated from 77 to', self.context_length) 87 | state_dict[k] = checkpoint[k] 88 | 89 | u, w = self.load_state_dict(state_dict, False) 90 | print(u, w, 'are misaligned params in text encoder') 91 | 92 | def build_attention_mask(self): 93 | # lazily create causal attention mask, with full attention between the vision tokens 94 | # pytorch uses additive attention mask; fill with -inf 95 | mask = torch.empty(self.context_length + self.num_tokens, self.context_length+self.num_tokens) 96 | mask.fill_(float("-inf")) 97 | mask.triu_(1) # zero out the lower diagonal 98 | return mask 99 | 100 | def forward_deep_prompt(self, embedding_output, features, out_last=False): 101 | B = embedding_output.shape[1] # batch_size 102 | for i in range(self.layers): 103 | if i == 0: 104 | hidden_states = self.transformer.resblocks[i](embedding_output) 105 | elif i <= self.deep_prompt_embeddings.shape[0]: 106 | deep_prompt_emb = self.prompt_dropout(self.prompt_proj(self.deep_prompt_embeddings[i-1]).expand(B, -1, -1)).permute(1, 0, 2) # seems like the middle layer's dpt 107 | hidden_states = torch.cat(( 108 | deep_prompt_emb, 109 | hidden_states[self.num_tokens:,:,:] 110 | ), dim=0) # 177 B 768 111 | 112 | hidden_states = self.transformer.resblocks[i](hidden_states) 113 | else: 114 | # hidden_states = torch.cat(( 115 | # hidden_states[:1, :, :], 116 | # hidden_states[-(H*W):, :, :] 117 | # ), dim=0) 118 | hidden_states = self.transformer.resblocks[i](hidden_states) 119 | 120 | if len(self.out_indices) > 1: 121 | if i in self.out_indices: 122 | xp = hidden_states.permute(1, 0, 2)[:, -77:, :].permute(0, 2, 1) # B,512,77 123 | features.append(xp.contiguous()) 124 | 125 | if i == (self.layers-2): #10 126 
| before_last_feats = self.prompt_norm(hidden_states) # 1125x4x768 127 | 128 | encoded = self.prompt_norm(hidden_states) 129 | if out_last: 130 | return before_last_feats 131 | else: 132 | return encoded, features 133 | def encode_token(self, token): 134 | x = self.token_embedding(token) 135 | return x 136 | 137 | def forward(self, text, token): 138 | # x = self.token_embedding(text) 139 | x = text + self.positional_embedding 140 | if self.total_d_layer >=0: 141 | # concat prompt 142 | x = torch.cat(( # Deep Prompt Tuning 143 | self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(x.shape[0], -1, -1)), 144 | x 145 | ), dim=1)# B,177,512 146 | 147 | x = x.permute(1, 0, 2) 148 | features = [] 149 | outs = [] 150 | if self.total_d_layer > 0: 151 | x, features = self.forward_deep_prompt(x, features) 152 | # x = self.transformer(x) 153 | x = x[-77:,:,:] 154 | x = x.permute(1, 0, 2) 155 | x = self.ln_final(x) 156 | x = x[torch.arange(x.shape[0]), token.argmax(dim=-1)] @ self.text_projection 157 | # outs.append(tuple(features)) 158 | return x 159 | 160 | if __name__ == '__main__': 161 | from tokenizer import SimpleTokenizer 162 | tokenizer = SimpleTokenizer() 163 | encoder = CLIPTextEncoder(pretrained="Path/ViT-B-16.pt") 164 | encoder.init_weights() 165 | text = 'a airplane' 166 | texts = tokenizer(text).unsqueeze(0) 167 | text_embed = encoder(texts) 168 | print(f'text_mebed: {text_embed.shape}') 169 | exclude_key = 'prompt' 170 | for n,m in encoder.named_parameters(): 171 | if exclude_key not in n: 172 | m.requires_grad = False 173 | else: 174 | print(n) 175 | # if exclude_key: 176 | # if isinstance(exclude_key, str): 177 | # if not exclude_key in n: 178 | # m.requires_grad = False 179 | # print(f'False : {n}') 180 | # elif isinstance(exclude_key, list): 181 | # count = 0 182 | # for i in range(len(exclude_key)): 183 | # i_layer = str(exclude_key[i]) 184 | # if i_layer in n: 185 | # count += 1 186 | # if count == 0: 187 | # m.requires_grad = False 188 | # elif count>0: 189 | # print('Finetune layer in backbone:', n) 190 | # else: 191 | # assert AttributeError("Dont support the type of exclude_key!") 192 | # else: 193 | # m.requires_grad = False 194 | # print(f'False : {n}') 195 | -------------------------------------------------------------------------------- /src/models/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Modified from github.com/openai/CLIP 2 | import gzip 3 | import html 4 | import os 5 | from functools import lru_cache 6 | 7 | import ftfy 8 | import regex as re 9 | import torch 10 | 11 | 12 | @lru_cache() 13 | def default_bpe(): 14 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 15 | 16 | 17 | @lru_cache() 18 | def bytes_to_unicode(): 19 | """ 20 | Returns list of utf-8 byte and a corresponding list of unicode strings. 21 | The reversible bpe codes work on unicode strings. 22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 24 | This is a signficant percentage of your normal, say, 32K bpe vocab. 25 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 26 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
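In `CLIPTextEncoder.forward` above, the learnable prompt embeddings are prepended along the sequence dimension, the transformer runs on the lengthened sequence, and only the last 77 positions (the original text tokens) are kept for the final projection. A shape-level sketch of that concatenate-then-slice pattern (dimensions follow the defaults above; the transformer blocks themselves are elided):

```python
import torch

B, L, D, P = 2, 77, 512, 100            # batch, context length, width, prompt tokens
text_emb = torch.randn(B, L, D)          # token embeddings + positional embedding
prompts = torch.randn(1, P, D)           # learnable prompt embeddings

x = torch.cat((prompts.expand(B, -1, -1), text_emb), dim=1)   # B, P + L, D
x = x.permute(1, 0, 2)                   # sequence-first, as the transformer expects
# ... self.transformer / forward_deep_prompt would run here ...
x = x[-L:, :, :].permute(1, 0, 2)        # keep only the original 77 text positions
print(x.shape)                           # torch.Size([2, 77, 512])
```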
27 | """ 28 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 29 | cs = bs[:] 30 | n = 0 31 | for b in range(2**8): 32 | if b not in bs: 33 | bs.append(b) 34 | cs.append(2**8+n) 35 | n += 1 36 | cs = [chr(n) for n in cs] 37 | return dict(zip(bs, cs)) 38 | 39 | 40 | def get_pairs(word): 41 | """Return set of symbol pairs in a word. 42 | Word is represented as tuple of symbols (symbols being variable-length strings). 43 | """ 44 | pairs = set() 45 | prev_char = word[0] 46 | for char in word[1:]: 47 | pairs.add((prev_char, char)) 48 | prev_char = char 49 | return pairs 50 | 51 | 52 | def basic_clean(text): 53 | text = ftfy.fix_text(text) 54 | text = html.unescape(html.unescape(text)) 55 | return text.strip() 56 | 57 | 58 | def whitespace_clean(text): 59 | text = re.sub(r'\s+', ' ', text) 60 | text = text.strip() 61 | return text 62 | 63 | 64 | class SimpleTokenizer(object): 65 | def __init__(self, bpe_path: str = default_bpe()): 66 | self.byte_encoder = bytes_to_unicode() 67 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 68 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 69 | merges = merges[1:49152-256-2+1] 70 | merges = [tuple(merge.split()) for merge in merges] 71 | vocab = list(bytes_to_unicode().values()) 72 | vocab = vocab + [v+'' for v in vocab] 73 | for merge in merges: 74 | vocab.append(''.join(merge)) 75 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 76 | self.encoder = dict(zip(vocab, range(len(vocab)))) 77 | self.decoder = {v: k for k, v in self.encoder.items()} 78 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 79 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 80 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 81 | 82 | def bpe(self, token): 83 | if token in self.cache: 84 | return self.cache[token] 85 | word = tuple(token[:-1]) + ( token[-1] + '',) 86 | pairs = get_pairs(word) 87 | 88 | if not pairs: 89 | return token+'' 90 | 91 | while True: 92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 93 | if bigram not in self.bpe_ranks: 94 | break 95 | first, second = bigram 96 | new_word = [] 97 | i = 0 98 | while i < len(word): 99 | try: 100 | j = word.index(first, i) 101 | new_word.extend(word[i:j]) 102 | i = j 103 | except: 104 | new_word.extend(word[i:]) 105 | break 106 | 107 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 108 | new_word.append(first+second) 109 | i += 2 110 | else: 111 | new_word.append(word[i]) 112 | i += 1 113 | new_word = tuple(new_word) 114 | word = new_word 115 | if len(word) == 1: 116 | break 117 | else: 118 | pairs = get_pairs(word) 119 | word = ' '.join(word) 120 | self.cache[token] = word 121 | return word 122 | 123 | def encode(self, text): 124 | bpe_tokens = [] 125 | text = whitespace_clean(basic_clean(text)).lower() 126 | for token in re.findall(self.pat, text): 127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 129 | return bpe_tokens 130 | 131 | def decode(self, tokens): 132 | text = ''.join([self.decoder[token] for token in tokens]) 133 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 134 | return text 135 | 136 | def __call__(self, texts, context_length=77): 137 | if isinstance(texts, 
str): 138 | texts = [texts] 139 | 140 | sot_token = self.encoder["<|startoftext|>"] 141 | eot_token = self.encoder["<|endoftext|>"] 142 | all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] 143 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 144 | 145 | for i, tokens in enumerate(all_tokens): 146 | tokens = tokens[:context_length] 147 | result[i, :len(tokens)] = torch.tensor(tokens) 148 | 149 | if len(result) == 1: 150 | return result[0] 151 | return result -------------------------------------------------------------------------------- /src/tools/builder.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | # online package 3 | import torch 4 | # optimizer 5 | import torch.optim as optim 6 | # dataloader 7 | # from datasets import build_dataset_from_cfg 8 | # from .models.build import build_model_from_cfg 9 | from models.build import build_model_from_cfg 10 | # utils 11 | from utils.logger import * 12 | 13 | from utils.misc import * 14 | from timm.scheduler import CosineLRScheduler 15 | 16 | # def dataset_builder(args, config): 17 | # dataset = build_dataset_from_cfg(config._base_, config.others) 18 | # shuffle = config.others.subset == 'train' 19 | # if args.distributed: 20 | # sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle = shuffle) 21 | # dataloader = torch.utils.data.DataLoader(dataset, batch_size = config.others.bs, 22 | # num_workers = int(args.num_workers), 23 | # drop_last = config.others.subset == 'train', 24 | # worker_init_fn = worker_init_fn, 25 | # sampler = sampler) 26 | # else: 27 | # sampler = None 28 | # dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.others.bs, 29 | # shuffle = shuffle, 30 | # drop_last = config.others.subset == 'train', 31 | # num_workers = int(args.num_workers), 32 | # worker_init_fn=worker_init_fn) 33 | # return sampler, dataloader 34 | 35 | def model_builder(config): 36 | model = build_model_from_cfg(config) 37 | return model 38 | 39 | def build_opti_sche(base_model, config): 40 | opti_config = config.optimizer 41 | if opti_config.type == 'AdamW': 42 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()): 43 | decay = [] 44 | no_decay = [] 45 | for name,param in model.named_parameters(): 46 | # for name, param in model.module.named_parameters(): 47 | if not param.requires_grad: 48 | continue # frozen weights 49 | if len(param.shape) == 1 or name.endswith(".bias") or 'token' in name or name in skip_list: 50 | # print(name) 51 | no_decay.append(param) 52 | else: 53 | decay.append(param) 54 | return [ 55 | {'params': no_decay, 'weight_decay': 0.}, 56 | {'params': decay, 'weight_decay': weight_decay}] 57 | param_groups = add_weight_decay(base_model, weight_decay=opti_config.kwargs.weight_decay) 58 | optimizer = optim.AdamW(param_groups, **opti_config.kwargs) 59 | elif opti_config.type == 'Adam': 60 | optimizer = optim.Adam(base_model.parameters(), **opti_config.kwargs) 61 | elif opti_config.type == 'SGD': 62 | optimizer = optim.SGD(base_model.parameters(), nesterov=True, **opti_config.kwargs) 63 | else: 64 | raise NotImplementedError() 65 | 66 | sche_config = config.scheduler 67 | if sche_config.type == 'LambdaLR': 68 | scheduler = build_lambda_sche(optimizer, sche_config.kwargs) # misc.py 69 | elif sche_config.type == 'CosLR': 70 | scheduler = CosineLRScheduler(optimizer, 71 | t_initial=sche_config.kwargs.epochs, 72 | t_mul=1, 73 | lr_min=1e-6, 74 | decay_rate=0.1, 75 | warmup_lr_init=1e-6, 
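`build_opti_sche` above is driven entirely by the `optimizer` and `scheduler` sections of the experiment config. A hypothetical config sketch showing the fields it reads (the numeric values here are placeholders, not the repository's settings; see the YAML files under `src/configs` for the real ones):

```python
import torch.nn as nn
from easydict import EasyDict

# placeholder values -- only the field names mirror what build_opti_sche reads
config = EasyDict({
    'optimizer': {'type': 'AdamW', 'kwargs': {'lr': 1e-4, 'weight_decay': 5e-2}},
    'scheduler': {'type': 'CosLR', 'kwargs': {'epochs': 300, 'initial_epochs': 10}},
})
model = nn.Linear(16, 4)  # stand-in for the real network

# inside the training scripts one would then call:
# optimizer, scheduler = build_opti_sche(model, config)
```

Note that biases, 1-D parameters, and any parameter whose name contains `token` land in the zero-weight-decay group, which matters for the class token and prompt embeddings.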
76 | warmup_t=sche_config.kwargs.initial_epochs, 77 | cycle_limit=1, 78 | t_in_epochs=True) 79 | elif sche_config.type == 'StepLR': 80 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **sche_config.kwargs) 81 | elif sche_config.type == 'function': 82 | scheduler = None 83 | else: 84 | raise NotImplementedError() 85 | 86 | if config.get('bnmscheduler') is not None: 87 | bnsche_config = config.bnmscheduler 88 | if bnsche_config.type == 'Lambda': 89 | bnscheduler = build_lambda_bnsche(base_model, bnsche_config.kwargs) # misc.py 90 | scheduler = [scheduler, bnscheduler] 91 | 92 | return optimizer, scheduler 93 | 94 | # def resume_model(base_model, args, logger = None): 95 | # ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth') 96 | # if not os.path.exists(ckpt_path): 97 | # print_log(f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger = logger) 98 | # return 0, 0 99 | # print_log(f'[RESUME INFO] Loading model weights from {ckpt_path}...', logger = logger ) 100 | 101 | # # load state dict 102 | # map_location = {'cuda:%d' % 0: 'cuda:%d' % args.local_rank} 103 | # state_dict = torch.load(ckpt_path, map_location=map_location) 104 | # # parameter resume of base model 105 | # # if args.local_rank == 0: 106 | # base_ckpt = {k.replace("module.", ""): v for k, v in state_dict['base_model'].items()} 107 | # base_model.load_state_dict(base_ckpt, strict = True) 108 | 109 | # # parameter 110 | # start_epoch = state_dict['epoch'] + 1 111 | # best_metrics = state_dict['best_metrics'] 112 | # if not isinstance(best_metrics, dict): 113 | # best_metrics = best_metrics.state_dict() 114 | # # print(best_metrics) 115 | 116 | # print_log(f'[RESUME INFO] resume ckpts @ {start_epoch - 1} epoch( best_metrics = {str(best_metrics):s})', logger = logger) 117 | # return start_epoch, best_metrics 118 | 119 | # def resume_optimizer(optimizer, args, logger = None): 120 | # ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth') 121 | # if not os.path.exists(ckpt_path): 122 | # print_log(f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger = logger) 123 | # return 0, 0, 0 124 | # print_log(f'[RESUME INFO] Loading optimizer from {ckpt_path}...', logger = logger ) 125 | # # load state dict 126 | # state_dict = torch.load(ckpt_path, map_location='cpu') 127 | # # optimizer 128 | # optimizer.load_state_dict(state_dict['optimizer']) 129 | 130 | def save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, prefix, args, logger = None): 131 | if args.local_rank == 0: 132 | torch.save({ 133 | 'base_model' : base_model.module.state_dict() if args.distributed else base_model.state_dict(), 134 | 'optimizer' : optimizer.state_dict(), 135 | 'epoch' : epoch, 136 | 'metrics' : metrics.state_dict() if metrics is not None else dict(), 137 | 'best_metrics' : best_metrics.state_dict() if best_metrics is not None else dict(), 138 | }, os.path.join(args.experiment_path, prefix + '.pth')) 139 | print_log(f"Save checkpoint at {os.path.join(args.experiment_path, prefix + '.pth')}", logger = logger) 140 | 141 | def load_model(base_model, ckpt_path, logger = None): 142 | if not os.path.exists(ckpt_path): 143 | raise NotImplementedError('no checkpoint file from path %s...' 
% ckpt_path) 144 | print_log(f'Loading weights from {ckpt_path}...', logger = logger ) 145 | 146 | # load state dict 147 | state_dict = torch.load(ckpt_path, map_location='cpu') 148 | # parameter resume of base model 149 | if state_dict.get('model') is not None: 150 | base_ckpt = {k.replace("module.", ""): v for k, v in state_dict['model'].items()} 151 | elif state_dict.get('base_model') is not None: 152 | base_ckpt = {k.replace("module.", ""): v for k, v in state_dict['base_model'].items()} 153 | else: 154 | raise RuntimeError('mismatch of ckpt weight') 155 | base_model.load_state_dict(base_ckpt, strict = True) 156 | 157 | epoch = -1 158 | if state_dict.get('epoch') is not None: 159 | epoch = state_dict['epoch'] 160 | if state_dict.get('metrics') is not None: 161 | metrics = state_dict['metrics'] 162 | if not isinstance(metrics, dict): 163 | metrics = metrics.state_dict() 164 | else: 165 | metrics = 'No Metrics' 166 | print_log(f'ckpts @ {epoch} epoch( performance = {str(metrics):s})', logger = logger) 167 | return -------------------------------------------------------------------------------- /src/train/train_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import munch 5 | import yaml 6 | import sys 7 | from torch.utils.tensorboard import SummaryWriter 8 | from network.base_model import BaseModel 9 | from network.base_model import AverageValueMeter,build_optimizer,save_model_2 10 | import argparse 11 | from time import time 12 | from tqdm.auto import tqdm 13 | import time as timetmp 14 | import logging 15 | import random 16 | import math 17 | from loss.cdloss import SimplificationLoss 18 | from dataset_svr.trainer_dataset import build_dataset,get_spherepoints 19 | 20 | def setFolders(args): 21 | LOG_DIR = args.dir_outpath 22 | MODEL_NAME = '%s-%s'%(args.model_name, timetmp.strftime("%m%d_%H%M", timetmp.localtime())) 23 | 24 | OUT_DIR = os.path.join(LOG_DIR, MODEL_NAME) 25 | args.dir_checkpoints = os.path.join(OUT_DIR, 'checkpoints') 26 | if not os.path.exists(OUT_DIR): os.mkdir(OUT_DIR) 27 | if not os.path.exists(args.dir_checkpoints): 28 | os.makedirs(args.dir_checkpoints) 29 | 30 | LOG_FOUT = open(os.path.join(OUT_DIR, 'log_%s.csv' %(MODEL_NAME)), 'w') 31 | return MODEL_NAME, OUT_DIR, LOG_FOUT 32 | def log_string(out_str,LOG_FOUT): 33 | LOG_FOUT.write(out_str+'\n') 34 | LOG_FOUT.flush() 35 | 36 | def train(): 37 | exp_name,Log_dir,LOG_FOUT = setFolders(args) 38 | writer = SummaryWriter(log_dir=Log_dir) 39 | log_string('EPOCH,avg_cd_l1,avg_cd_l2,Best CDL2[epoch,best_loss],',LOG_FOUT) 40 | logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(os.path.join(Log_dir, 'train.log')), 41 | logging.StreamHandler(sys.stdout)]) 42 | logging.info(str(args)) 43 | metrics = ['cd_p', 'cd_t', 'f1'] 44 | best_epoch_losses = {m: (0, 0) if m == 'f1' else (0, math.inf) for m in metrics} 45 | train_loss_meter = AverageValueMeter() 46 | val_loss_meters = {m: AverageValueMeter() for m in metrics} 47 | #seed 48 | if not args.manual_seed: 49 | seed = random.randint(1,10000) 50 | else: 51 | seed = int(args.manual_seed) 52 | logging.info('Random Seed: %d' % seed) 53 | random.seed(seed) 54 | torch.manual_seed(seed) 55 | if args.distributed: 56 | svrmodel = torch.nn.DataParallel(BaseModel(),device_ids=args.gpus,output_device=args.gpus[0]) 57 | svrmodel.to(device) 58 | else: 59 | svrmodel = BaseModel.to(device) 60 | 61 | optimizer,scheduler = build_optimizer(svrmodel,args) 62 | best_cd_l1 = 
float("inf") 63 | best_cd_l2 = float("inf") 64 | best_f1 = float("inf") 65 | print("Data Uploading...") 66 | 67 | dataloader, dataloader_test = build_dataset(args) 68 | 69 | print("Data Preparation Done...") 70 | loss_function = SimplificationLoss() 71 | print("Loss Function Done...") 72 | 73 | 74 | for epoch in tqdm(range(args.start_epoch,args.nepoch),desc='Training'): 75 | epoch_start_time = time() 76 | total_cd_l1 = 0 77 | total_cd_l2 = 0 78 | train_loss_meter.reset() 79 | 80 | svrmodel.module.train() 81 | 82 | for param_group in optimizer.param_groups: 83 | print(f"Epoch {epoch+1}, Learning Rate: {param_group['lr']}") 84 | n_batches = len(dataloader) 85 | # train 86 | with tqdm(dataloader) as t: 87 | for batch_idx,data in enumerate(t): 88 | optimizer.zero_grad() 89 | 90 | n_itr = epoch * n_batches + batch_idx 91 | # to(cuda) 92 | images = data['image'].to(device) 93 | batch_size = images.shape[0] 94 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device) 95 | 96 | pointclouds = data['points'].to(device) 97 | 98 | pred_points = svrmodel(images,partial_pcs) 99 | # to(cuda) 100 | pred_points = pred_points.to(pointclouds.device) 101 | pred_points = pred_points.transpose(2,1) 102 | net_loss,loss_t=loss_function(pred_points,pointclouds) 103 | 104 | # caculate Chamfer distance loss 105 | net_loss = net_loss.mean() 106 | loss_t = loss_t.mean() 107 | 108 | net_loss_all = net_loss + loss_t 109 | 110 | train_loss_meter.update(net_loss.item()) 111 | net_loss_all.backward() 112 | optimizer.step() 113 | 114 | cd_l2_item = torch.sum(loss_t).item() / batch_size * 1e4 115 | total_cd_l2 += cd_l2_item 116 | cd_l1_item = net_loss.item() * 1e4 117 | total_cd_l1 += cd_l1_item 118 | 119 | t.set_description('[Epoch %d/%d][Batch %d/%d]' % (epoch, args.nepoch, batch_idx + 1, n_batches)) 120 | t.set_postfix(loss='%s' % ['%.4f' % l for l in [cd_l1_item, cd_l2_item]]) 121 | writer.add_scalar('MAE_netloss',net_loss.item(),n_itr) 122 | writer.add_scalar('MAE_losst',loss_t.item(),n_itr) 123 | scheduler.step(epoch) 124 | 125 | avg_cd_l1 = total_cd_l1 / n_batches 126 | avg_cd_l2 = total_cd_l2 / n_batches 127 | epoch_end_time = time() 128 | logging.info('') 129 | logging.info( 130 | exp_name + '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s ' % 131 | (epoch,args.nepoch,epoch_end_time - epoch_start_time,['%.4f' % l for l in [avg_cd_l1,avg_cd_l2]]) 132 | ) 133 | log_string(f'{epoch}[{batch_idx}],{avg_cd_l1:.4f},{avg_cd_l2:.4f},{total_cd_l2:.4f},{best_epoch_losses},', LOG_FOUT) 134 | writer.add_scalar('MAE_CDL2', avg_cd_l2, epoch) 135 | writer.add_scalar('MAE_CDL1', avg_cd_l1, epoch) 136 | 137 | if epoch % args.epoch_interval_to_val == 0 or epoch == args.nepoch - 1: 138 | best_cd_l1, best_cd_l2 ,best_f1= val(svrmodel,loss_function,epoch,val_loss_meters,dataloader_test,best_epoch_losses,LOG_FOUT,Log_dir,best_cd_l1,best_cd_l2,best_f1) 139 | 140 | def val(net,cal_loss,curr_epoch,val_loss_meters,dataloader_test,best_epoch_losses,LOG_FOUT,log_dir,best_cd_l1,best_cd_l2,best_f1): 141 | val_start_time = time() 142 | metrics_val = ['cd_t'] 143 | val_loss_meters = {m: AverageValueMeter() for m in metrics_val} 144 | logging.info('Testing...') 145 | for v in val_loss_meters.values(): 146 | v.reset() 147 | net.module.eval() 148 | total_cd_l1 = 0 149 | total_cd_l2 = 0 150 | total_f1 = 0 151 | n_batches = len(dataloader_test) 152 | with torch.no_grad(): 153 | with tqdm(dataloader_test) as tt: 154 | for i,data in enumerate(tt): 155 | images = data['image'].to(device) 156 | 
batch_size = images.shape[0] 157 | gt = data['points'].to(device) 158 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device) 159 | pred_points = net(images,partial_pcs) 160 | pred_points = pred_points.transpose(2,1) 161 | 162 | loss_p, loss_t,f1 = cal_loss(pred_points, gt,calc_f1=True) 163 | 164 | cd_l1_item = torch.sum(loss_p).item() / batch_size * 1e4 165 | cd_l2_item = torch.sum(loss_t).item() / batch_size * 1e4 166 | f1_item = torch.sum(f1).item() / batch_size * 1e4 167 | total_cd_l1 += cd_l1_item 168 | total_cd_l2 += cd_l2_item 169 | total_f1 += f1_item 170 | 171 | avg_cd_l1 = total_cd_l1 / n_batches 172 | avg_cd_l2 = total_cd_l2 / n_batches 173 | avg_f1 = total_f1 / n_batches 174 | 175 | if avg_cd_l1 < best_cd_l1: 176 | best_cd_l1 = avg_cd_l1 177 | save_model_2(str(log_dir) + '/checkpoints/bestl1_network.pth', net) 178 | logging.info("Saving net...") 179 | 180 | if avg_cd_l2 < best_cd_l2: 181 | best_cd_l2 = avg_cd_l2 182 | 183 | if avg_f1 > best_f1: 184 | best_f1 = avg_f1 185 | 186 | log_string('%d,%.2f,%.2f,%.2f,%.2f,%.2f,%.2f'%(curr_epoch, avg_cd_l1, best_cd_l1, avg_cd_l2, best_cd_l2,avg_f1,best_f1), LOG_FOUT) 187 | 188 | val_end_time = time() 189 | 190 | logging.info( 191 | '[Epoch %d/%d] TestTime = %.3f (s) Curr_cdl1 = %s Best_cdl1 = %s Curr_cdl2 = %s Best_cdl2 = %s Curr_f1 = %s Best_f1 = %s' % 192 | (curr_epoch, args.nepoch, val_end_time - val_start_time, avg_cd_l1, best_cd_l1, avg_cd_l2, best_cd_l2, avg_f1, best_f1)) 193 | 194 | return best_cd_l1, best_cd_l2 , best_f1 195 | 196 | if __name__ == '__main__': 197 | parser = argparse.ArgumentParser(description='Train config file') 198 | parser.add_argument('-c', '--config', help='path to config file', required=True) 199 | parser.add_argument('-gpu', '--gpu_id', help='gpu_id', required=True) 200 | arg = parser.parse_args() 201 | config_path = arg.config 202 | args = munch.munchify(yaml.safe_load(open(config_path))) 203 | 204 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 205 | os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu_id) #os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5,6,7' 206 | print('Using gpu:' + str(arg.gpu_id)) 207 | sphere_points = get_spherepoints(args.number_points,0.5) 208 | device = torch.device(args.gpus[0] if torch.cuda.is_available() else 'cpu') 209 | print('Number of points:' + str(args.number_points)) 210 | 211 | train() -------------------------------------------------------------------------------- /src/train/train_ltp.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from dataset_svr.trainer_text_dataset import * 3 | import torch 4 | import torch.nn.functional as F 5 | from models.ULIP_models import get_loss_v2 6 | import sys 7 | import time 8 | from network.ltp_model import * 9 | from network.ltp_model import TextAlignPCModel 10 | from models.ULIP_utils import * 11 | from dataset_svr.trainer_dataset import build_dataset 12 | import os 13 | import logging 14 | import time as timetmp 15 | 16 | 17 | def setFolders(args): 18 | LOG_DIR = args.dir_outpath 19 | MODEL_NAME = '%s-%s'%(args.model, timetmp.strftime("%m%d_%H%M", timetmp.localtime())) 20 | 21 | OUT_DIR = os.path.join(LOG_DIR, MODEL_NAME) 22 | args.dir_checkpoints = os.path.join(OUT_DIR, 'checkpoints') 23 | if not os.path.exists(OUT_DIR): os.mkdir(OUT_DIR) 24 | if not os.path.exists(args.dir_checkpoints): 25 | os.makedirs(args.dir_checkpoints) 26 | 27 | LOG_FOUT = open(os.path.join(OUT_DIR, 'log_%s.csv' %(MODEL_NAME)), 'w') 
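Both training scripts seed the decoder with the same fixed sphere template for every image in the batch: `get_spherepoints` returns an `(N, 3)` array that is converted to a tensor and tiled along the batch dimension. A minimal sketch of that tiling (the random points on a radius-0.5 sphere below only stand in for `get_spherepoints`, whose actual sampling lives in `dataset_svr/trainer_dataset.py`):

```python
import numpy as np
import torch

n_points = 2048
pts = np.random.randn(n_points, 3)
sphere_points = 0.5 * pts / np.linalg.norm(pts, axis=1, keepdims=True)  # stand-in template

batch_size = 8
partial_pcs = torch.tensor(sphere_points).to(torch.float32)      # N, 3
partial_pcs = partial_pcs.unsqueeze(0).repeat(batch_size, 1, 1)  # B, N, 3
print(partial_pcs.shape)  # torch.Size([8, 2048, 3]) -- the same template for every image
```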
28 | return MODEL_NAME, OUT_DIR, args.dir_checkpoints 29 | def log_string(out_str,LOG_FOUT): 30 | LOG_FOUT.write(out_str+'\n') 31 | LOG_FOUT.flush() 32 | def save_model(path, net, net_d=None): 33 | if net_d is not None: 34 | torch.save({'net_state_dict': net.state_dict(), 35 | 'D_state_dict': net_d.state_dict()}, path) 36 | else: 37 | torch.save({'net_state_dict': net.state_dict()}, path) 38 | def main(args): 39 | best_cos = float(0) 40 | exp_name,Log_dir,Check_FOUT = setFolders(args) 41 | logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(os.path.join(Log_dir, 'train.log')), 42 | logging.StreamHandler(sys.stdout)]) 43 | logging.info(str(args)) 44 | dataloader_train, dataloader_test = build_dataset(args) 45 | # create model 46 | print("=> creating model: {}".format(args.model)) 47 | model = TextAlignPCModel(args).to(device) 48 | 49 | criterion = get_loss_v2(args).to(device) 50 | optimizer,scheduler = build_optimizer(model,args) 51 | print("=> beginning training") 52 | best_epoch = -1 53 | for epoch in range(args.start_epoch, args.epochs): 54 | train_stats = train(dataloader_train,model,criterion,optimizer,epoch,scheduler,args) 55 | logging.info('') 56 | logging.info(f'Training result: {train_stats}') 57 | logging.info('Testing...') 58 | val_stats = {"cos@t&p":-1} 59 | 60 | if epoch % 1 == 0: 61 | val_stats = val(dataloader_test,model,args) 62 | cos = val_stats["cos@t&p"] 63 | print(f'val_stats:{val_stats}') 64 | 65 | is_best = cos > best_cos 66 | if is_best: 67 | best_epoch = epoch 68 | 69 | best_cos = max(cos,best_cos) 70 | if is_best : 71 | print("=> saving checkpoint") 72 | save_model(str(Check_FOUT)+'/prompt.pth',model.text_prompt_embeddings) 73 | save_model(str(Check_FOUT)+'/text_encoder.pth',model.text_encoder) 74 | logging.info("Saving net...") 75 | if epoch + 1 == args.epochs: 76 | print("=> saving last checkpoint") 77 | save_model(str(Check_FOUT)+'/last_prompt.pth',model.text_prompt_embeddings) 78 | save_model(str(Check_FOUT)+'/last_text_encoder.pth',model.text_encoder) 79 | 80 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 81 | **{f'test_{k}': v for k, v in val_stats.items()}, 82 | 'epoch': epoch, 83 | 'best_cos@t&p': best_cos, 84 | 'best_epoch': best_epoch} 85 | logging.info(f'log_stats: {log_stats}') 86 | 87 | def train(train_loader, model, criterion, optimizer, epoch, scheduler, args): 88 | n_batches = len(train_loader) 89 | # switch to train mode 90 | total_loss = 0 91 | model.train() 92 | end = time.time() 93 | with tqdm(train_loader) as t: 94 | for data_iter,data in enumerate(t): 95 | optimizer.zero_grad() 96 | # images 97 | images = data['image'].to(device) 98 | # text 99 | category = data['category'] 100 | category_indices = [int(c) for c in category] 101 | text_labels = torch.tensor(category_indices).to(device) 102 | # pointclouds 103 | pc = data['points'].to(device) 104 | 105 | outputs = model(images,text_labels,pc) 106 | loss_dict = criterion(outputs) 107 | loss = loss_dict['loss'] 108 | 109 | loss.backward() 110 | optimizer.step() 111 | total_loss += loss.item() 112 | get_model(model).logit_scale.data.clamp_(0, 4.6052) 113 | logit_scale = get_model(model).logit_scale.exp().item() 114 | avg_loss = total_loss / n_batches 115 | scheduler.step(epoch) 116 | return {**{'loss':avg_loss}, 117 | 'lr': optimizer.param_groups[0]['lr'], 118 | 'logit_scale': logit_scale} 119 | def val(test_loader,model,tokenizer,args=None): 120 | batch_time = AverageMeter('Time', ':6.3f') 121 | total = 0 122 | avg = 0 123 | model.eval() 124 | n_batches = 
len(test_loader) 125 | 126 | print('==>encoding captions') 127 | with torch.no_grad(): 128 | end = time.time() 129 | for i,data in enumerate(test_loader): 130 | # image 131 | images = data['image'].to(device) 132 | batch_size = images.shape[0] 133 | # category 134 | category = data['category'] 135 | category_indices = [int(c) for c in category] 136 | text_labels = torch.tensor(category_indices).to(device) 137 | # pointclouds 138 | gt = data['points'].to(device) 139 | # model 140 | outputs = model(images,text_labels,gt) 141 | text_embed = outputs['text_embed'] 142 | gt_embed = outputs['pc_embed'] 143 | 144 | text_embed = F.normalize(text_embed, dim=-1, p=2) 145 | gt_embed = F.normalize(gt_embed, dim=-1, p=2) 146 | 147 | logits_per_pc_text = text_embed @ gt_embed.t() 148 | 149 | cos_sim_1 = torch.diag(logits_per_pc_text).mean() 150 | 151 | total += cos_sim_1 152 | 153 | avg = total / n_batches 154 | 155 | batch_time.update(time.time() - end) 156 | end = time.time() 157 | logging.info(f"batch_time {batch_time}, 'cos@t&p':{avg}") 158 | return {'cos@t&p':avg} 159 | 160 | if __name__ =='__main__': 161 | parser = argparse.ArgumentParser(description='Train config file') 162 | parser.add_argument('-c', '--config', help='path to config file', required=True) 163 | parser.add_argument('-gpu', '--gpu_id', help='gpu_id', required=True) 164 | arg = parser.parse_args() 165 | config_path = arg.config 166 | args = munch.munchify(yaml.safe_load(open(config_path))) 167 | device = torch.device(args.gpu_id if torch.cuda.is_available() else 'cpu') 168 | print('Using gpu:' + str(arg.gpu_id)) 169 | 170 | train() -------------------------------------------------------------------------------- /src/train/train_pro.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import munch 4 | import yaml 5 | import sys 6 | from network.pro_model import ProModel 7 | from network.pro_model import AverageValueMeter,build_optimizer,save_model_2 8 | from time import time 9 | from tqdm.auto import tqdm 10 | import time as timetmp 11 | import argparse 12 | import logging 13 | import random 14 | import math 15 | from loss.cdloss import SimplificationLoss 16 | from dataset_svr.trainer_dataset import build_dataset,get_spherepoints 17 | 18 | def setFolders(args): 19 | LOG_DIR = args.dir_outpath 20 | MODEL_NAME = '%s-%s'%(args.model_name, timetmp.strftime("%m%d_%H%M", timetmp.localtime())) 21 | 22 | OUT_DIR = os.path.join(LOG_DIR, MODEL_NAME) 23 | args.dir_checkpoints = os.path.join(OUT_DIR, 'checkpoints') 24 | if not os.path.exists(OUT_DIR): os.mkdir(OUT_DIR) 25 | if not os.path.exists(args.dir_checkpoints): 26 | os.makedirs(args.dir_checkpoints) 27 | 28 | LOG_FOUT = open(os.path.join(OUT_DIR, 'log_%s.csv' %(MODEL_NAME)), 'w') 29 | return MODEL_NAME, OUT_DIR, LOG_FOUT 30 | def log_string(out_str,LOG_FOUT): 31 | LOG_FOUT.write(out_str+'\n') 32 | LOG_FOUT.flush() 33 | 34 | 35 | def train(): 36 | exp_name,Log_dir,LOG_FOUT = setFolders(args) 37 | log_string('EPOCH,avg_cd_l1,avg_cd_l2,Best CDL2[epoch,best_loss],',LOG_FOUT) 38 | logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(os.path.join(Log_dir, 'train.log')), 39 | logging.StreamHandler(sys.stdout)]) 40 | logging.info(str(args)) 41 | 42 | metrics = ['cd_p', 'cd_t', 'f1'] 43 | best_epoch_losses = {m: (0, 0) if m == 'f1' else (0, math.inf) for m in metrics} 44 | train_loss_meter = AverageValueMeter() 45 | val_loss_meters = {m: AverageValueMeter() for m in metrics} 46 | #seed 47 | if not args.manual_seed: 
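The `val` routine of `train_ltp.py` above scores text/point-cloud alignment by L2-normalising both embeddings and averaging the diagonal of their similarity matrix, i.e. the cosine similarity of each matched pair. A toy version of that computation (the embedding width is illustrative):

```python
import torch
import torch.nn.functional as F

text_embed = F.normalize(torch.randn(4, 512), dim=-1, p=2)
pc_embed = F.normalize(torch.randn(4, 512), dim=-1, p=2)

# diagonal of the similarity matrix = cosine similarity of each matched pair
cos_per_pair = torch.diag(text_embed @ pc_embed.t())
print(cos_per_pair.mean())   # the 'cos@t&p' value tracked during validation
```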
48 | seed = random.randint(1,10000) 49 | else: 50 | seed = int(args.manual_seed) 51 | logging.info('Random Seed: %d' % seed) 52 | random.seed(seed) 53 | torch.manual_seed(seed) 54 | 55 | if args.distirbuted: 56 | promodel = torch.nn.DataParallel(ProModel(),device_ids=args.gpus,output_device=args.gpus[0]) 57 | else: 58 | promodel = ProModel.to(device) 59 | optimizer,scheduler = build_optimizer(promodel,args) 60 | best_cd_l1 = float("inf") 61 | best_cd_l2 = float("inf") 62 | best_f1 = float("inf") 63 | print("Data Uploading...") 64 | 65 | dataloader, dataloader_test = build_dataset(args) 66 | print("Data Preparation Done...") 67 | loss_function = SimplificationLoss() 68 | print("Loss Function Done...") 69 | 70 | for epoch in tqdm(range(args.start_epoch,args.nepoch),desc='Training'): 71 | # time,loss 72 | epoch_start_time = time() 73 | total_cd_l1 = 0 74 | total_cd_l2 = 0 75 | train_loss_meter.reset() 76 | 77 | promodel.module.train() 78 | 79 | for param_group in optimizer.param_groups: 80 | print(f"Epoch {epoch+1}, Learning Rate: {param_group['lr']}") 81 | n_batches = len(dataloader) 82 | # train 83 | with tqdm(dataloader) as t: 84 | for batch_idx,data in enumerate(t): 85 | optimizer.zero_grad() 86 | 87 | n_itr = epoch * n_batches + batch_idx 88 | # to(cuda) 89 | images = data['image'].to(device) 90 | batch_size = images.shape[0] 91 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device) 92 | # partial_pcs = partial_pcs.to(device) 93 | pointclouds = data['points'].to(device) 94 | # category 95 | category = data['category'] 96 | category_indices = [int(c) for c in category] 97 | text_labels = torch.tensor(category_indices).to(device) 98 | 99 | pred_points = promodel(images,partial_pcs,text_labels) 100 | # to(cuda) 101 | pred_points = pred_points.to(pointclouds.device) 102 | pred_points = pred_points.transpose(2,1) 103 | net_loss,loss_t=loss_function(pred_points,pointclouds) 104 | 105 | # caculate Chamfer distance loss 106 | net_loss = net_loss.mean() 107 | loss_t = loss_t.mean() 108 | 109 | # net_loss_all = net_loss + loss_t 110 | net_loss_all = net_loss 111 | train_loss_meter.update(net_loss.item()) 112 | net_loss_all.backward() 113 | optimizer.step() 114 | 115 | cd_l2_item = torch.sum(loss_t).item() / batch_size * 1e4 116 | total_cd_l2 += cd_l2_item 117 | cd_l1_item = net_loss.item() * 1e4 118 | total_cd_l1 += cd_l1_item 119 | 120 | t.set_description('[Epoch %d/%d][Batch %d/%d]' % (epoch, args.nepoch, batch_idx + 1, n_batches)) 121 | t.set_postfix(loss='%s' % ['%.4f' % l for l in [cd_l1_item, cd_l2_item]]) 122 | scheduler.step(epoch) # CosLR, 123 | 124 | avg_cd_l1 = total_cd_l1 / n_batches 125 | avg_cd_l2 = total_cd_l2 / n_batches 126 | epoch_end_time = time() 127 | logging.info('') 128 | logging.info( 129 | exp_name + '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s ' % 130 | (epoch,args.nepoch,epoch_end_time - epoch_start_time,['%.4f' % l for l in [avg_cd_l1,avg_cd_l2]]) 131 | ) 132 | log_string(f'{epoch}[{batch_idx}],{avg_cd_l1:.4f},{avg_cd_l2:.4f},{total_cd_l2:.4f},{best_epoch_losses},', LOG_FOUT) 133 | if epoch % args.epoch_interval_to_val == 0 or epoch == args.nepoch - 1: 134 | best_cd_l1, best_cd_l2 ,best_f1= val(promodel,loss_function,epoch,val_loss_meters,dataloader_test,best_epoch_losses,LOG_FOUT,Log_dir,best_cd_l1,best_cd_l2,best_f1) 135 | 136 | def val(net,cal_loss,curr_epoch,val_loss_meters,dataloader_test,best_epoch_losses,LOG_FOUT,log_dir,best_cd_l1,best_cd_l2,best_f1): 137 | val_start_time = time() 138 | metrics_val 
= ['cd_t'] 139 | val_loss_meters = {m: AverageValueMeter() for m in metrics_val} 140 | logging.info('Testing...') 141 | for v in val_loss_meters.values(): 142 | v.reset() 143 | net.module.eval() 144 | 145 | total_cd_l1 = 0 146 | total_cd_l2 = 0 147 | total_f1 = 0 148 | n_batches = len(dataloader_test) 149 | with torch.no_grad(): 150 | with tqdm(dataloader_test) as tt: 151 | for i,data in enumerate(tt): 152 | images = data['image'].to(device) 153 | batch_size = images.shape[0] 154 | gt = data['points'].to(device) 155 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device) 156 | category = data['category'] 157 | category_indices = [int(c) for c in category] 158 | text_labels = torch.tensor(category_indices).to(device) 159 | 160 | pred_points = net(images,partial_pcs,text_labels) 161 | pred_points = pred_points.transpose(2,1) 162 | 163 | loss_p, loss_t,f1 = cal_loss(pred_points, gt,calc_f1=True) 164 | 165 | cd_l1_item = torch.sum(loss_p).item() / batch_size * 1e4 166 | cd_l2_item = torch.sum(loss_t).item() / batch_size * 1e4 167 | f1_item = torch.sum(f1).item() / batch_size * 1e4 168 | total_cd_l1 += cd_l1_item 169 | total_cd_l2 += cd_l2_item 170 | total_f1 += f1_item 171 | 172 | avg_cd_l1 = total_cd_l1 / n_batches 173 | avg_cd_l2 = total_cd_l2 / n_batches 174 | avg_f1 = total_f1 / n_batches 175 | 176 | if avg_cd_l1 < best_cd_l1: 177 | best_cd_l1 = avg_cd_l1 178 | save_model_2(str(log_dir) + '/checkpoints/bestl1_network.pth', net) 179 | logging.info("Saving net...") 180 | if avg_cd_l2 < best_cd_l2: 181 | best_cd_l2 = avg_cd_l2 182 | if avg_f1 > best_f1: 183 | best_f1 = avg_f1 184 | 185 | log_string('%d,%.2f,%.2f,%.2f,%.2f,%.2f,%.2f'%(curr_epoch, avg_cd_l1, best_cd_l1, avg_cd_l2, best_cd_l2,avg_f1,best_f1), LOG_FOUT) 186 | 187 | val_end_time = time() 188 | 189 | logging.info( 190 | '[Epoch %d/%d] TestTime = %.3f (s) Curr_cdl1 = %s Best_cdl1 = %s Curr_cdl2 = %s Best_cdl2 = %s Curr_f1 = %s Best_f1 = %s' % 191 | (curr_epoch, args.nepoch, val_end_time - val_start_time, avg_cd_l1, best_cd_l1, avg_cd_l2, best_cd_l2, avg_f1, best_f1)) 192 | 193 | return best_cd_l1, best_cd_l2 , best_f1 194 | 195 | if __name__ == '__main__': 196 | parser = argparse.ArgumentParser(description='Train config file') 197 | parser.add_argument('-c', '--config', help='path to config file', required=True) 198 | parser.add_argument('-gpu', '--gpu_id', help='gpu_id', required=True) 199 | arg = parser.parse_args() 200 | config_path = arg.config 201 | args = munch.munchify(yaml.safe_load(open(config_path))) 202 | 203 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 204 | os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu_id) #os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5,6,7' 205 | print('Using gpu:' + str(arg.gpu_id)) 206 | sphere_points = get_spherepoints(args.number_points,0.5) 207 | device = torch.device(args.gpus[0] if torch.cuda.is_available() else 'cpu') 208 | print('Number of points:' + str(args.number_points)) 209 | 210 | train() -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 
Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | import copy 5 | import logging 6 | import os 7 | from collections import defaultdict 8 | import torch 9 | import torch.nn as nn 10 | 11 | from typing import Any 12 | from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable 13 | 14 | from termcolor import colored 15 | 16 | def get_missing_parameters_message(keys: List[str]) -> str: 17 | """ 18 | Get a logging-friendly message to report parameter names (keys) that are in 19 | the model but not found in a checkpoint. 20 | Args: 21 | keys (list[str]): List of keys that were not found in the checkpoint. 22 | Returns: 23 | str: message. 24 | """ 25 | groups = _group_checkpoint_keys(keys) 26 | msg = "Some model parameters or buffers are not found in the checkpoint:\n" 27 | msg += "\n".join( 28 | " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items() 29 | ) 30 | return msg 31 | 32 | 33 | def get_unexpected_parameters_message(keys: List[str]) -> str: 34 | """ 35 | Get a logging-friendly message to report parameter names (keys) that are in 36 | the checkpoint but not found in the model. 37 | Args: 38 | keys (list[str]): List of keys that were not found in the model. 39 | Returns: 40 | str: message. 41 | """ 42 | groups = _group_checkpoint_keys(keys) 43 | msg = "The checkpoint state_dict contains keys that are not used by the model:\n" 44 | msg += "\n".join( 45 | " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items() 46 | ) 47 | return msg 48 | 49 | 50 | def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None: 51 | """ 52 | Strip the prefix in metadata, if any. 53 | Args: 54 | state_dict (OrderedDict): a state-dict to be loaded to the model. 55 | prefix (str): prefix. 56 | """ 57 | keys = sorted(state_dict.keys()) 58 | if not all(len(key) == 0 or key.startswith(prefix) for key in keys): 59 | return 60 | 61 | for key in keys: 62 | newkey = key[len(prefix):] 63 | state_dict[newkey] = state_dict.pop(key) 64 | 65 | # also strip the prefix in metadata, if any.. 66 | try: 67 | metadata = state_dict._metadata # pyre-ignore 68 | except AttributeError: 69 | pass 70 | else: 71 | for key in list(metadata.keys()): 72 | # for the metadata dict, the key can be: 73 | # '': for the DDP module, which we want to remove. 74 | # 'module': for the actual model. 75 | # 'module.xx.xx': for the rest. 76 | 77 | if len(key) == 0: 78 | continue 79 | newkey = key[len(prefix):] 80 | metadata[newkey] = metadata.pop(key) 81 | 82 | 83 | def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]: 84 | """ 85 | Group keys based on common prefixes. A prefix is the string up to the final 86 | "." in each key. 87 | Args: 88 | keys (list[str]): list of parameter names, i.e. keys in the model 89 | checkpoint dict. 90 | Returns: 91 | dict[list]: keys with common prefixes are grouped into lists. 92 | """ 93 | groups = defaultdict(list) 94 | for key in keys: 95 | pos = key.rfind(".") 96 | if pos >= 0: 97 | head, tail = key[:pos], [key[pos + 1:]] 98 | else: 99 | head, tail = key, [] 100 | groups[head].extend(tail) 101 | return groups 102 | 103 | 104 | def _group_to_str(group: List[str]) -> str: 105 | """ 106 | Format a group of parameter name suffixes into a loggable string. 107 | Args: 108 | group (list[str]): list of parameter name suffixes. 109 | Returns: 110 | str: formated string. 111 | """ 112 | if len(group) == 0: 113 | return "" 114 | 115 | if len(group) == 1: 116 | return "." 
+ group[0] 117 | 118 | return ".{" + ", ".join(group) + "}" 119 | 120 | 121 | def _named_modules_with_dup( 122 | model: nn.Module, prefix: str = "" 123 | ) -> Iterable[Tuple[str, nn.Module]]: 124 | """ 125 | The same as `model.named_modules()`, except that it includes 126 | duplicated modules that have more than one name. 127 | """ 128 | yield prefix, model 129 | for name, module in model._modules.items(): # pyre-ignore 130 | if module is None: 131 | continue 132 | submodule_prefix = prefix + ("." if prefix else "") + name 133 | yield from _named_modules_with_dup(module, submodule_prefix) -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from easydict import EasyDict 3 | import os 4 | from .logger import print_log 5 | 6 | def log_args_to_file(args, pre='args', logger=None): 7 | for key, val in args.__dict__.items(): 8 | print_log(f'{pre}.{key} : {val}', logger = logger) 9 | 10 | def log_config_to_file(cfg, pre='cfg', logger=None): 11 | for key, val in cfg.items(): 12 | if isinstance(cfg[key], EasyDict): 13 | print_log(f'{pre}.{key} = edict()', logger = logger) 14 | log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger) 15 | continue 16 | print_log(f'{pre}.{key} : {val}', logger = logger) 17 | 18 | def merge_new_config(config, new_config): 19 | for key, val in new_config.items(): 20 | if not isinstance(val, dict): 21 | if key == '_base_': 22 | with open(new_config['_base_'], 'r') as f: 23 | try: 24 | val = yaml.load(f, Loader=yaml.FullLoader) 25 | except: 26 | val = yaml.load(f) 27 | config[key] = EasyDict() 28 | merge_new_config(config[key], val) 29 | else: 30 | config[key] = val 31 | continue 32 | if key not in config: 33 | config[key] = EasyDict() 34 | merge_new_config(config[key], val) 35 | return config 36 | 37 | def cfg_from_yaml_file(cfg_file): 38 | config = EasyDict() 39 | with open(cfg_file, 'r') as f: 40 | try: 41 | new_config = yaml.load(f, Loader=yaml.FullLoader) 42 | except: 43 | new_config = yaml.load(f) 44 | merge_new_config(config=config, new_config=new_config) 45 | return config 46 | 47 | def get_config(args, logger=None): 48 | if args.resume: 49 | cfg_path = os.path.join(args.experiment_path, 'config.yaml') 50 | if not os.path.exists(cfg_path): 51 | print_log("Failed to resume", logger = logger) 52 | raise FileNotFoundError() 53 | print_log(f'Resume yaml from {cfg_path}', logger = logger) 54 | args.config = cfg_path 55 | config = cfg_from_yaml_file(args.config) 56 | if not args.resume and args.local_rank == 0: 57 | save_experiment_config(args, config, logger) 58 | return config 59 | 60 | def save_experiment_config(args, config, logger = None): 61 | config_path = os.path.join(args.experiment_path, 'config.yaml') 62 | os.system('cp %s %s' % (args.config, config_path)) 63 | print_log(f'Copy the Config file from {args.config} to {config_path}',logger = logger ) -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | logger_initialized = {} 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'): 7 | """Get root logger and add a keyword filter to it. 8 | The logger will be initialized if it has not been initialized. By default a 9 | StreamHandler will be added. 
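`merge_new_config` above gives the YAML configs a simple inheritance mechanism: when a file contains a `_base_` key, the referenced file is loaded and nested under `cfg._base_`. A small hypothetical example (file names and fields are made up for illustration):

```python
# child.yaml refers to base.yaml; cfg_from_yaml_file('./child.yaml') would yield
#   cfg.depth            -> 6    (top-level value from child.yaml)
#   cfg._base_.trans_dim -> 384  (base.yaml contents nested under the _base_ key)
with open('./base.yaml', 'w') as f:
    f.write("trans_dim: 384\ndepth: 12\n")
with open('./child.yaml', 'w') as f:
    f.write("_base_: ./base.yaml\ndepth: 6\n")

# cfg = cfg_from_yaml_file('./child.yaml')   # defined above in utils/config.py
```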
If `log_file` is specified, a FileHandler will 10 | also be added. The name of the root logger is the top-level package name, 11 | e.g., "mmdet3d". 12 | Args: 13 | log_file (str, optional): File path of log. Defaults to None. 14 | log_level (int, optional): The level of logger. 15 | Defaults to logging.INFO. 16 | name (str, optional): The name of the root logger, also used as a 17 | filter keyword. Defaults to 'mmdet3d'. 18 | Returns: 19 | :obj:`logging.Logger`: The obtained logger 20 | """ 21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 22 | # add a logging filter 23 | logging_filter = logging.Filter(name) 24 | logging_filter.filter = lambda record: record.find(name) != -1 25 | 26 | return logger 27 | 28 | 29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 30 | """Initialize and get a logger by name. 31 | If the logger has not been initialized, this method will initialize the 32 | logger by adding one or two handlers, otherwise the initialized logger will 33 | be directly returned. During initialization, a StreamHandler will always be 34 | added. If `log_file` is specified and the process rank is 0, a FileHandler 35 | will also be added. 36 | Args: 37 | name (str): Logger name. 38 | log_file (str | None): The log filename. If specified, a FileHandler 39 | will be added to the logger. 40 | log_level (int): The logger level. Note that only the process of 41 | rank 0 is affected, and other processes will set the level to 42 | "Error" thus be silent most of the time. 43 | file_mode (str): The file mode used in opening log file. 44 | Defaults to 'w'. 45 | Returns: 46 | logging.Logger: The expected logger. 47 | """ 48 | logger = logging.getLogger(name) 49 | if name in logger_initialized: 50 | return logger 51 | # handle hierarchical names 52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 53 | # initialization since it is a child of "a". 54 | for logger_name in logger_initialized: 55 | if name.startswith(logger_name): 56 | return logger 57 | 58 | # handle duplicate logs to the console 59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 60 | # to the root logger. As logger.propagate is True by default, this root 61 | # level handler causes logging messages from rank>0 processes to 62 | # unexpectedly show up on the console, creating much unwanted clutter. 63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 64 | # at the ERROR level. 65 | for handler in logger.root.handlers: 66 | if type(handler) is logging.StreamHandler: 67 | handler.setLevel(logging.ERROR) 68 | 69 | stream_handler = logging.StreamHandler() 70 | handlers = [stream_handler] 71 | 72 | if dist.is_available() and dist.is_initialized(): 73 | rank = dist.get_rank() 74 | else: 75 | rank = 0 76 | 77 | # only rank 0 will add a FileHandler 78 | if rank == 0 and log_file is not None: 79 | # Here, the default behaviour of the official logger is 'a'. Thus, we 80 | # provide an interface to change the file mode to the default 81 | # behaviour. 
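A minimal usage sketch of `get_root_logger` above (the log path and logger name are examples, and the import assumes the working directory is `src/`):

```python
from utils.logger import get_root_logger

logger = get_root_logger(log_file='./experiments/run1/train.log', name='MESC')
logger.info('training started')   # always printed to the console; written to file only on rank 0
```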
82 | file_handler = logging.FileHandler(log_file, file_mode) 83 | handlers.append(file_handler) 84 | 85 | formatter = logging.Formatter( 86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 87 | for handler in handlers: 88 | handler.setFormatter(formatter) 89 | handler.setLevel(log_level) 90 | logger.addHandler(handler) 91 | 92 | if rank == 0: 93 | logger.setLevel(log_level) 94 | else: 95 | logger.setLevel(logging.ERROR) 96 | 97 | logger_initialized[name] = True 98 | 99 | 100 | return logger 101 | 102 | 103 | def print_log(msg, logger=None, level=logging.INFO): 104 | """Print a log message. 105 | Args: 106 | msg (str): The message to be logged. 107 | logger (logging.Logger | str | None): The logger to be used. 108 | Some special loggers are: 109 | - "silent": no message will be printed. 110 | - other str: the logger obtained with `get_root_logger(logger)`. 111 | - None: The `print()` method will be used to print log messages. 112 | level (int): Logging level. Only available when `logger` is a Logger 113 | object or "root". 114 | """ 115 | if logger is None: 116 | print(msg) 117 | elif isinstance(logger, logging.Logger): 118 | logger.log(level, msg) 119 | elif logger == 'silent': 120 | pass 121 | elif isinstance(logger, str): 122 | _logger = get_logger(logger) 123 | _logger.log(level, msg) 124 | else: 125 | raise TypeError( 126 | 'logger should be either a logging.Logger object, str, ' 127 | f'"silent" or None, but got {type(logger)}') -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import os 9 | from collections import abc 10 | from pointnet2_ops import pointnet2_utils 11 | 12 | 13 | def fps(data, number): 14 | ''' 15 | data B N 3 16 | number int 17 | ''' 18 | fps_idx = pointnet2_utils.furthest_point_sample(data, number) 19 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous() 20 | return fps_data 21 | 22 | 23 | def worker_init_fn(worker_id): 24 | np.random.seed(np.random.get_state()[1][0] + worker_id) 25 | 26 | def build_lambda_sche(opti, config): 27 | if config.get('decay_step') is not None: 28 | lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay) 29 | scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd) 30 | else: 31 | raise NotImplementedError() 32 | return scheduler 33 | 34 | def build_lambda_bnsche(model, config): 35 | if config.get('decay_step') is not None: 36 | bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay) 37 | bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd) 38 | else: 39 | raise NotImplementedError() 40 | return bnm_scheduler 41 | 42 | def set_random_seed(seed, deterministic=False): 43 | """Set random seed. 44 | Args: 45 | seed (int): Seed to be used. 46 | deterministic (bool): Whether to set the deterministic option for 47 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 48 | to True and `torch.backends.cudnn.benchmark` to False. 49 | Default: False. 
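`print_log` above dispatches on the type of its `logger` argument, which is why the model and builder code can pass either a logger object or just a name string. A quick illustration (again assuming `utils.logger` is importable from the working directory):

```python
import logging
from utils.logger import print_log, get_logger

print_log('plain print')                    # logger=None -> falls back to print()
print_log('silenced', logger='silent')      # 'silent'    -> message is dropped
print_log('named', logger='MESC')           # str         -> routed via get_logger('MESC')
print_log('warn', logger=get_logger('MESC'), level=logging.WARNING)
```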
50 | 51 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 52 | if cuda_deterministic: # slower, more reproducible 53 | cudnn.deterministic = True 54 | cudnn.benchmark = False 55 | else: # faster, less reproducible 56 | cudnn.deterministic = False 57 | cudnn.benchmark = True 58 | 59 | """ 60 | random.seed(seed) 61 | np.random.seed(seed) 62 | torch.manual_seed(seed) 63 | torch.cuda.manual_seed_all(seed) 64 | if deterministic: 65 | torch.backends.cudnn.deterministic = True 66 | torch.backends.cudnn.benchmark = False 67 | 68 | 69 | def is_seq_of(seq, expected_type, seq_type=None): 70 | """Check whether it is a sequence of some type. 71 | Args: 72 | seq (Sequence): The sequence to be checked. 73 | expected_type (type): Expected type of sequence items. 74 | seq_type (type, optional): Expected sequence type. 75 | Returns: 76 | bool: Whether the sequence is valid. 77 | """ 78 | if seq_type is None: 79 | exp_seq_type = abc.Sequence 80 | else: 81 | assert isinstance(seq_type, type) 82 | exp_seq_type = seq_type 83 | if not isinstance(seq, exp_seq_type): 84 | return False 85 | for item in seq: 86 | if not isinstance(item, expected_type): 87 | return False 88 | return True 89 | 90 | 91 | def set_bn_momentum_default(bn_momentum): 92 | def fn(m): 93 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 94 | m.momentum = bn_momentum 95 | return fn 96 | 97 | class BNMomentumScheduler(object): 98 | 99 | def __init__( 100 | self, model, bn_lambda, last_epoch=-1, 101 | setter=set_bn_momentum_default 102 | ): 103 | if not isinstance(model, nn.Module): 104 | raise RuntimeError( 105 | "Class '{}' is not a PyTorch nn Module".format( 106 | type(model).__name__ 107 | ) 108 | ) 109 | 110 | self.model = model 111 | self.setter = setter 112 | self.lmbd = bn_lambda 113 | 114 | self.step(last_epoch + 1) 115 | self.last_epoch = last_epoch 116 | 117 | def step(self, epoch=None): 118 | if epoch is None: 119 | epoch = self.last_epoch + 1 120 | 121 | self.last_epoch = epoch 122 | self.model.apply(self.setter(self.lmbd(epoch))) 123 | 124 | def get_momentum(self, epoch=None): 125 | if epoch is None: 126 | epoch = self.last_epoch + 1 127 | return self.lmbd(epoch) 128 | 129 | 130 | 131 | def seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False): 132 | ''' 133 | Separate a point cloud: used to generate an incomplete point cloud with a specified number of cropped points.
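Args (as used in the body below):
    xyz (Tensor): (B, N, 3) point cloud on GPU; N must equal num_points.
    crop (int | list): number of points to drop around a viewpoint, or a
        [low, high] range to sample the drop count from.
    fixed_points (Tensor | list, optional): fixed viewpoint(s); a random unit
        direction is used when None.
    padding_zeros (bool): if True, the cropped points are zeroed in place
        instead of being removed.
Illustrative sketch (not part of the original code):
    >>> partial, cropped = seprate_point_cloud(xyz, 2048, crop=512)
    >>> partial.shape, cropped.shape   # (B, 1536, 3) and (B, 512, 3)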
134 | ''' 135 | _,n,c = xyz.shape 136 | 137 | assert n == num_points 138 | assert c == 3 139 | if crop == num_points: 140 | return xyz, None 141 | 142 | INPUT = [] 143 | CROP = [] 144 | for points in xyz: 145 | if isinstance(crop,list): 146 | num_crop = random.randint(crop[0],crop[1]) 147 | else: 148 | num_crop = crop 149 | 150 | points = points.unsqueeze(0) 151 | 152 | if fixed_points is None: 153 | center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda() 154 | else: 155 | if isinstance(fixed_points,list): 156 | fixed_point = random.sample(fixed_points,1)[0] 157 | else: 158 | fixed_point = fixed_points 159 | center = fixed_point.reshape(1,1,3).cuda() 160 | 161 | distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048 162 | 163 | idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048 164 | 165 | if padding_zeros: 166 | input_data = points.clone() 167 | input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0 168 | 169 | else: 170 | input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3 171 | 172 | crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0) 173 | 174 | if isinstance(crop,list): 175 | INPUT.append(fps(input_data,2048)) 176 | CROP.append(fps(crop_data,2048)) 177 | else: 178 | INPUT.append(input_data) 179 | CROP.append(crop_data) 180 | 181 | input_data = torch.cat(INPUT,dim=0)# B N 3 182 | crop_data = torch.cat(CROP,dim=0)# B M 3 183 | 184 | return input_data.contiguous(), crop_data.contiguous() 185 | 186 | def get_ptcloud_img(ptcloud,roll,pitch): 187 | fig = plt.figure(figsize=(8, 8)) 188 | 189 | x, z, y = ptcloud.transpose(1, 0) 190 | ax = fig.gca(projection=Axes3D.name, adjustable='box') 191 | ax.axis('off') 192 | # ax.axis('scaled') 193 | ax.view_init(roll,pitch) 194 | max, min = np.max(ptcloud), np.min(ptcloud) 195 | ax.set_xbound(min, max) 196 | ax.set_ybound(min, max) 197 | ax.set_zbound(min, max) 198 | ax.scatter(x, y, z, zdir='z', c=y, cmap='jet') 199 | 200 | fig.canvas.draw() 201 | img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 202 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, )) 203 | return img 204 | 205 | 206 | 207 | def visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y', 208 | xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ): 209 | fig = plt.figure(figsize=(6*len(data_list),6)) 210 | cmax = data_list[-1][:,0].max() 211 | 212 | for i in range(len(data_list)): 213 | data = data_list[i][:-2048] if i == 1 else data_list[i] 214 | color = data[:,0] /cmax 215 | ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d') 216 | ax.view_init(30, -120) 217 | b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color,vmin=-1,vmax=1 ,cmap = cmap[0],s=4,linewidth=0.05, edgecolors = 'black') 218 | ax.set_title(titles[i]) 219 | 220 | ax.set_axis_off() 221 | ax.set_xlim(xlim) 222 | ax.set_ylim(ylim) 223 | ax.set_zlim(zlim) 224 | plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0) 225 | if not os.path.exists(path): 226 | os.makedirs(path) 227 | 228 | pic_path = path + '.png' 229 | fig.savefig(pic_path) 230 | 231 | np.save(os.path.join(path, 'input.npy'), data_list[0].numpy()) 232 | np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy()) 233 | plt.close(fig) 234 | 235 | 236 | def random_dropping(pc, e): 237 | up_num = max(64, 768 // (e//50 + 1)) 238 | pc = pc 239 | random_num = torch.randint(1, up_num, (1,1))[0,0] 240 | pc = fps(pc, random_num) 241 | padding = 
torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device) 242 | pc = torch.cat([pc, padding], dim = 1) 243 | return pc 244 | 245 | 246 | def random_scale(partial, scale_range=[0.8, 1.2]): 247 | scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0] 248 | return partial * scale 249 | -------------------------------------------------------------------------------- /src/utils/registry.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from functools import partial 4 | from utils import config,misc 5 | 6 | class Registry: 7 | """A registry to map strings to classes. 8 | Registered object could be built from registry. 9 | Example: 10 | >>> MODELS = Registry('models') 11 | >>> @MODELS.register_module() 12 | >>> class ResNet: 13 | >>> pass 14 | >>> resnet = MODELS.build(dict(NAME='ResNet')) 15 | Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for 16 | advanced useage. 17 | Args: 18 | name (str): Registry name. 19 | build_func(func, optional): Build function to construct instance from 20 | Registry, func:`build_from_cfg` is used if neither ``parent`` or 21 | ``build_func`` is specified. If ``parent`` is specified and 22 | ``build_func`` is not given, ``build_func`` will be inherited 23 | from ``parent``. Default: None. 24 | parent (Registry, optional): Parent registry. The class registered in 25 | children registry could be built from parent. Default: None. 26 | scope (str, optional): The scope of registry. It is the key to search 27 | for children registry. If not specified, scope will be the name of 28 | the package where class is defined, e.g. mmdet, mmcls, mmseg. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, name, build_func=None, parent=None, scope=None): 33 | self._name = name 34 | self._module_dict = dict() 35 | self._children = dict() 36 | self._scope = self.infer_scope() if scope is None else scope 37 | 38 | # self.build_func will be set with the following priority: 39 | # 1. build_func 40 | # 2. parent.build_func 41 | # 3. build_from_cfg 42 | if build_func is None: 43 | if parent is not None: 44 | self.build_func = parent.build_func 45 | else: 46 | self.build_func = build_from_cfg 47 | else: 48 | self.build_func = build_func 49 | if parent is not None: 50 | assert isinstance(parent, Registry) 51 | parent._add_children(self) 52 | self.parent = parent 53 | else: 54 | self.parent = None 55 | 56 | def __len__(self): 57 | return len(self._module_dict) 58 | 59 | def __contains__(self, key): 60 | return self.get(key) is not None 61 | 62 | def __repr__(self): 63 | format_str = self.__class__.__name__ + \ 64 | f'(name={self._name}, ' \ 65 | f'items={self._module_dict})' 66 | return format_str 67 | 68 | @staticmethod 69 | def infer_scope(): 70 | """Infer the scope of registry. 71 | The name of the package where registry is defined will be returned. 72 | Example: 73 | # in mmdet/models/backbone/resnet.py 74 | >>> MODELS = Registry('models') 75 | >>> @MODELS.register_module() 76 | >>> class ResNet: 77 | >>> pass 78 | The scope of ``ResNet`` will be ``mmdet``. 79 | Returns: 80 | scope (str): The inferred scope name. 
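(Illustrative note, not part of the original docstring: in this repository, a registry created inside ``models/build.py`` and imported as ``models.build`` would infer the scope ``models``.)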
81 | """ 82 | # inspect.stack() trace where this function is called, the index-2 83 | # indicates the frame where `infer_scope()` is called 84 | filename = inspect.getmodule(inspect.stack()[2][0]).__name__ 85 | split_filename = filename.split('.') 86 | return split_filename[0] 87 | 88 | @staticmethod 89 | def split_scope_key(key): 90 | """Split scope and key. 91 | The first scope will be split from key. 92 | Examples: 93 | >>> Registry.split_scope_key('mmdet.ResNet') 94 | 'mmdet', 'ResNet' 95 | >>> Registry.split_scope_key('ResNet') 96 | None, 'ResNet' 97 | Return: 98 | scope (str, None): The first scope. 99 | key (str): The remaining key. 100 | """ 101 | split_index = key.find('.') 102 | if split_index != -1: 103 | return key[:split_index], key[split_index + 1:] 104 | else: 105 | return None, key 106 | 107 | @property 108 | def name(self): 109 | return self._name 110 | 111 | @property 112 | def scope(self): 113 | return self._scope 114 | 115 | @property 116 | def module_dict(self): 117 | return self._module_dict 118 | 119 | @property 120 | def children(self): 121 | return self._children 122 | 123 | def get(self, key): 124 | """Get the registry record. 125 | Args: 126 | key (str): The class name in string format. 127 | Returns: 128 | class: The corresponding class. 129 | """ 130 | scope, real_key = self.split_scope_key(key) 131 | if scope is None or scope == self._scope: 132 | # get from self 133 | if real_key in self._module_dict: 134 | return self._module_dict[real_key] 135 | else: 136 | # get from self._children 137 | if scope in self._children: 138 | return self._children[scope].get(real_key) 139 | else: 140 | # goto root 141 | parent = self.parent 142 | while parent.parent is not None: 143 | parent = parent.parent 144 | return parent.get(key) 145 | 146 | def build(self, *args, **kwargs): 147 | return self.build_func(*args, **kwargs, registry=self) 148 | 149 | def _add_children(self, registry): 150 | """Add children for a registry. 151 | The ``registry`` will be added as children based on its scope. 152 | The parent registry could build objects from children registry. 
153 | Example: 154 | >>> models = Registry('models') 155 | >>> mmdet_models = Registry('models', parent=models) 156 | >>> @mmdet_models.register_module() 157 | >>> class ResNet: 158 | >>> pass 159 | >>> resnet = models.build(dict(NAME='mmdet.ResNet')) 160 | """ 161 | 162 | assert isinstance(registry, Registry) 163 | assert registry.scope is not None 164 | assert registry.scope not in self.children, \ 165 | f'scope {registry.scope} exists in {self.name} registry' 166 | self.children[registry.scope] = registry 167 | 168 | def _register_module(self, module_class, module_name=None, force=False): 169 | if not inspect.isclass(module_class): 170 | raise TypeError('module must be a class, ' 171 | f'but got {type(module_class)}') 172 | 173 | if module_name is None: 174 | module_name = module_class.__name__ 175 | if isinstance(module_name, str): 176 | module_name = [module_name] 177 | for name in module_name: 178 | if not force and name in self._module_dict: 179 | raise KeyError(f'{name} is already registered ' 180 | f'in {self.name}') 181 | self._module_dict[name] = module_class 182 | 183 | def deprecated_register_module(self, cls=None, force=False): 184 | warnings.warn( 185 | 'The old API of register_module(module, force=False) ' 186 | 'is deprecated and will be removed, please use the new API ' 187 | 'register_module(name=None, force=False, module=None) instead.') 188 | if cls is None: 189 | return partial(self.deprecated_register_module, force=force) 190 | self._register_module(cls, force=force) 191 | return cls 192 | 193 | def register_module(self, name=None, force=False, module=None): 194 | """Register a module. 195 | A record will be added to `self._module_dict`, whose key is the class 196 | name or the specified name, and value is the class itself. 197 | It can be used as a decorator or a normal function. 198 | Example: 199 | >>> backbones = Registry('backbone') 200 | >>> @backbones.register_module() 201 | >>> class ResNet: 202 | >>> pass 203 | >>> backbones = Registry('backbone') 204 | >>> @backbones.register_module(name='mnet') 205 | >>> class MobileNet: 206 | >>> pass 207 | >>> backbones = Registry('backbone') 208 | >>> class ResNet: 209 | >>> pass 210 | >>> backbones.register_module(ResNet) 211 | Args: 212 | name (str | None): The module name to be registered. If not 213 | specified, the class name will be used. 214 | force (bool, optional): Whether to override an existing class with 215 | the same name. Default: False. 216 | module (type): Module class to be registered. 217 | """ 218 | if not isinstance(force, bool): 219 | raise TypeError(f'force must be a boolean, but got {type(force)}') 220 | # NOTE: This is a workaround to be compatible with the old api, 221 | # while it may introduce unexpected bugs.
222 | if isinstance(name, type): 223 | return self.deprecated_register_module(name, force=force) 224 | 225 | # raise the error ahead of time 226 | if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)): 227 | raise TypeError( 228 | 'name must be either of None, an instance of str or a sequence' 229 | f' of str, but got {type(name)}') 230 | 231 | # use it as a normal method: x.register_module(module=SomeClass) 232 | if module is not None: 233 | self._register_module( 234 | module_class=module, module_name=name, force=force) 235 | return module 236 | 237 | # use it as a decorator: @x.register_module() 238 | def _register(cls): 239 | self._register_module( 240 | module_class=cls, module_name=name, force=force) 241 | return cls 242 | 243 | return _register 244 | 245 | 246 | def build_from_cfg(cfg, registry, default_args=None): 247 | """Build a module from config dict. 248 | Args: 249 | cfg (edict): Config dict. It should at least contain the key "NAME". 250 | registry (:obj:`Registry`): The registry to search the type from. 251 | Returns: 252 | object: The constructed object. 253 | """ 254 | if not isinstance(cfg, dict): 255 | raise TypeError(f'cfg must be a dict, but got {type(cfg)}') 256 | if 'NAME' not in cfg: 257 | if default_args is None or 'NAME' not in default_args: 258 | raise KeyError( 259 | '`cfg` or `default_args` must contain the key "NAME", ' 260 | f'but got {cfg}\n{default_args}') 261 | if not isinstance(registry, Registry): 262 | raise TypeError('registry must be an mmcv.Registry object, ' 263 | f'but got {type(registry)}') 264 | 265 | if not (isinstance(default_args, dict) or default_args is None): 266 | raise TypeError('default_args must be a dict or None, ' 267 | f'but got {type(default_args)}') 268 | 269 | if default_args is not None: 270 | cfg = config.merge_new_config(cfg, default_args) 271 | 272 | obj_type = cfg.get('NAME') 273 | print(f'obj_type:{obj_type}') 274 | if isinstance(obj_type, str): 275 | obj_cls = registry.get(obj_type) 276 | if obj_cls is None: 277 | raise KeyError( 278 | f'{obj_type} is not in the {registry.name} registry') 279 | elif inspect.isclass(obj_type): 280 | obj_cls = obj_type 281 | else: 282 | raise TypeError( 283 | f'type must be a str or valid type, but got {type(obj_type)}') 284 | try: 285 | return obj_cls(cfg) 286 | except Exception as e: 287 | # Normal TypeError does not print class name. 
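# Illustrative example (not from the original code): if obj_cls were ResNet and
# its constructor raised ValueError('invalid depth'), the line below would
# re-raise it as ValueError('ResNet: invalid depth'), so the failing class name
# is visible in the traceback.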
288 | raise type(e)(f'{obj_cls.__name__}: {e}') 289 | -------------------------------------------------------------------------------- /src/val/val_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | import random 5 | import logging 6 | from dataset_svr.trainer_dataset import build_dataset_val 7 | from dataset_svr.trainer_dataset import get_spherepoints 8 | from network.base_model import BaseModel 9 | from network.base_model import AverageValueMeter 10 | from loss.cdloss import SimplificationLoss 11 | from tqdm import tqdm 12 | import munch 13 | import yaml 14 | 15 | def val(): 16 | device = torch.device(args.gpus[0] if torch.cuda.is_available() else 'cpu') 17 | dataloader_test = build_dataset_val(args) 18 | len_test = len(dataloader_test) 19 | 20 | if not args.manual_seed: 21 | seed = random.randint(1,10000) 22 | else: 23 | seed = int(args.manual_seed) 24 | logging.info('Random Seed: %d' % seed) 25 | random.seed(seed) 26 | torch.manual_seed(seed) 27 | ckpt = torch.load(args.ckpt_path) 28 | if args.distributed: 29 | svrmodel = torch.nn.DataParallel(BaseModel(args.model),device_ids=args.gpus,output_device=args.gpus[0]) 30 | svrmodel.to(device) 31 | svrmodel.module.load_state_dict(ckpt['net_state_dict']) 32 | logging.info("%s's previous weights loaded." % args.model_name) 33 | svrmodel.module.eval() 34 | else: 35 | svrmodel = BaseModel(args.model).to(device) 36 | svrmodel.load_state_dict(ckpt['net_state_dict']) 37 | logging.info("%s's previous weights loaded." % args.model_name) 38 | svrmodel.eval() 39 | 40 | 41 | logging.info('Testing....') 42 | test_loss_l1 = AverageValueMeter() 43 | test_loss_l2 = AverageValueMeter() 44 | test_f1 = AverageValueMeter() 45 | loss_function = SimplificationLoss() 46 | sphere_points = get_spherepoints(args.number_points,0.5) 47 | 48 | with tqdm(dataloader_test) as t: 49 | for i,data in enumerate(t): 50 | with torch.no_grad(): 51 | images = data['image'].to(device) 52 | gt = data['points'].to(device) 53 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device) 54 | batch_size = gt.shape[0] 55 | pred_points = svrmodel(images,partial_pcs) 56 | 57 | pred_points = pred_points.transpose(2,1) 58 | loss_p,loss_t,f1 = loss_function(pred_points,gt,calc_f1=True) 59 | 60 | cd_l1_item = torch.sum(loss_p).item() / batch_size 61 | cd_l2_item = torch.sum(loss_t).item() / batch_size 62 | f1_item = torch.sum(f1).item()/batch_size 63 | test_loss_l1.update(cd_l1_item, images.shape[0]) 64 | test_loss_l2.update(cd_l2_item, images.shape[0]) 65 | test_f1.update(f1_item,images.shape[0]) 66 | 67 | print('cd_l1 %f cd_l2 %f f1 %f' % (test_loss_l1.avg, test_loss_l2.avg,test_f1.avg)) 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser(description='Train config file') 71 | parser.add_argument('-c', '--config', help='path to config file', required=True) 72 | parser.add_argument('-gpu', '--gpu_id', help='gpu_id', required=True) 73 | arg = parser.parse_args() 74 | config_path = arg.config 75 | args = munch.munchify(yaml.safe_load(open(config_path))) 76 | 77 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 78 | os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu_id) 79 | print('Using gpu:' + str(arg.gpu_id)) 80 | val() -------------------------------------------------------------------------------- /src/val/val_pro.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
argparse 3 | import os 4 | import random 5 | import logging 6 | from dataset_svr.trainer_dataset import build_dataset_val,get_spherepoints 7 | from network.pro_model import ProModel 8 | from network.pro_model import AverageValueMeter 9 | from loss.cdloss import SimplificationLoss 10 | from tqdm import tqdm 11 | import munch 12 | import yaml 13 | import os 14 | 15 | def val(): 16 | device = torch.device(args.gpus[0] if torch.cuda.is_available() else 'cpu') 17 | dataloader_test = build_dataset_val(args) 18 | len_test = len(dataloader_test) 19 | 20 | if not args.manual_seed: 21 | seed = random.randint(1,10000) 22 | else: 23 | seed = int(args.manual_seed) 24 | logging.info('Random Seed: %d' % seed) 25 | random.seed(seed) 26 | torch.manual_seed(seed) 27 | ckpt = torch.load(args.ckpt_path,map_location=device) 28 | if args.distributed: 29 | promodel = torch.nn.DataParallel(ProModel(),device_ids=args.gpus,output_device=args.gpus[0]) 30 | promodel.to(device) 31 | promodel.module.load_state_dict(ckpt['net_state_dict']) 32 | logging.info("%s's previous weights loaded." % args.model_name) 33 | promodel.module.eval() 34 | else: 35 | promodel = ProModel().to(device) 36 | promodel.load_state_dict(ckpt['net_state_dict']) 37 | logging.info("%s's previous weights loaded." % args.model_name) 38 | promodel.eval() 39 | 40 | logging.info('Testing....') 41 | 42 | test_loss_l1 = AverageValueMeter() 43 | test_loss_l2 = AverageValueMeter() 44 | test_f1 = AverageValueMeter() 45 | loss_function = SimplificationLoss() 46 | sphere_points = get_spherepoints(args.number_points,0.5) 47 | 48 | 49 | with tqdm(dataloader_test) as t: 50 | for i,data in enumerate(t): 51 | with torch.no_grad(): 52 | images = data['image'].to(device) 53 | gt = data['points'].to(device) 54 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device) 55 | 56 | category = data['category'] 57 | category_indices = [int(c) for c in category] 58 | text_labels = torch.tensor(category_indices).to(device) 59 | 60 | batch_size = gt.shape[0] 61 | 62 | pred_points = promodel(images,partial_pcs,text_labels) 63 | pred_points = pred_points.transpose(2,1) 64 | loss_p,loss_t,f1 = loss_function(pred_points,gt,calc_f1=True) 65 | 66 | cd_l1_item = torch.sum(loss_p).item() / batch_size 67 | cd_l2_item = torch.sum(loss_t).item() / batch_size 68 | f1_item = torch.sum(f1).item() / batch_size 69 | test_loss_l1.update(cd_l1_item, images.shape[0]) 70 | test_loss_l2.update(cd_l2_item, images.shape[0]) 71 | test_f1.update(f1_item, images.shape[0]) 72 | 73 | print('cd_l1 %f cd_l2 %f f1 %f' % (test_loss_l1.avg, test_loss_l2.avg,test_f1.avg)) 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser(description='Train config file') 77 | parser.add_argument('-c', '--config', help='path to config file', required=True) 78 | parser.add_argument('-gpu', '--gpu_id', help='gpu_id', required=True) 79 | arg = parser.parse_args() 80 | config_path = arg.config 81 | args = munch.munchify(yaml.safe_load(open(config_path))) 82 | 83 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 84 | os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu_id) 85 | print('Using gpu:' + str(arg.gpu_id)) 86 | val() --------------------------------------------------------------------------------