├── README.md
├── figs
│   ├── mainfig.jpg
│   └── overview.jpg
└── src
    ├── chamfer_distance
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-39.pyc
    │   │   └── chamfer_distance.cpython-39.pyc
    │   ├── chamfer_distance.cpp
    │   ├── chamfer_distance.cu
    │   └── chamfer_distance.py
    ├── configs
    │   ├── ULIP.yaml
    │   ├── base_model.yaml
    │   ├── finetune_modelnet.yaml
    │   ├── ltp_model.yaml
    │   └── pro_model.yaml
    ├── dataset_svr
    │   ├── augmenter.py
    │   ├── dataset_shapenet.py
    │   ├── dataset_shapenet_text.py
    │   ├── mesh_processor.py
    │   ├── pointcloud_processor.py
    │   ├── taxonomy.json
    │   ├── trainer_dataset.py
    │   └── trainer_text_dataset.py
    ├── loss
    │   ├── cdloss.py
    │   ├── losses.py
    │   └── losses_v2.py
    ├── models
    │   ├── Point_MAE.py
    │   ├── ULIP_models.py
    │   ├── ULIP_utils.py
    │   ├── __init__.py
    │   ├── bpe_simple_vocab_16e6.txt.gz
    │   ├── build.py
    │   ├── clip_utils.py
    │   ├── encoder.py
    │   ├── pointbert
    │   │   ├── PointTransformer_8192point.yaml
    │   │   ├── ULIP_2_PointBERT_10k_colored_pointclouds.yaml
    │   │   ├── __pycache__
    │   │   │   ├── checkpoint.cpython-39.pyc
    │   │   │   ├── dvae.cpython-39.pyc
    │   │   │   ├── logger.cpython-39.pyc
    │   │   │   ├── misc.cpython-39.pyc
    │   │   │   └── point_encoder.cpython-39.pyc
    │   │   ├── checkpoint.py
    │   │   ├── dvae.py
    │   │   ├── logger.py
    │   │   ├── misc.py
    │   │   └── point_encoder.py
    │   ├── text_encoder_3d.py
    │   └── tokenizer.py
    ├── network
    │   ├── base_model.py
    │   ├── ltp_model.py
    │   └── pro_model.py
    ├── tools
    │   ├── builder.py
    │   └── smt.py
    ├── train
    │   ├── train_base.py
    │   ├── train_ltp.py
    │   └── train_pro.py
    ├── utils
    │   ├── __init__.py
    │   ├── checkpoint.py
    │   ├── config.py
    │   ├── logger.py
    │   ├── misc.py
    │   └── registry.py
    └── val
        ├── val_base.py
        └── val_pro.py
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # MESC-3D: Mining Effective Semantic Cues for 3D Reconstruction from a Single Image (CVPR 2025)
3 | 
25 | We release the code of the paper MESC-3D: Mining Effective Semantic Cues for 3D Reconstruction from a Single Image in this repository.
26 |
27 |
28 | ## Abstract
29 |
30 |
31 | In this work, we propose a novel single-image 3D reconstruction method called Mining Effective Semantic Cues for 3D Reconstruction from a Single Image (MESC-3D), which can actively mine effective semantic cues from entangled features. Specifically, we design an Effective Semantic Mining Module to establish connections between point clouds and image semantic attributes, enabling the point clouds to autonomously select the necessary information. Furthermore, to address potential insufficiencies in the semantic information from a single image, such as occlusions, and inspired by the human ability to represent 3D objects using prior knowledge drawn from daily experience, we introduce a 3DSPL module. It incorporates semantic understanding of spatial structures, enabling the model to interpret and reconstruct 3D objects with greater accuracy and realism, closely mirroring human perception of complex 3D environments. Extensive evaluations show that our method achieves significant improvements in reconstruction quality and robustness compared to prior works. Additionally, further experiments validate its strong generalization capability and excellent zero-shot performance on unseen classes.
32 |
33 |
34 |
35 | ## Method
36 |
37 |
38 |
39 |
40 |
41 |
42 | Overview of MESC-3D. Our network is composed of two main components. (a) The 3DSPL aligns point cloud modality features with text features, aiming to capture the unique 3D geometric characteristics of each category. (b) The ESM establishes a connection between the semantic feature F_i and the 3D point cloud at the i-th stage, allowing each point to autonomously select the most valuable semantic information.
43 |
44 |
45 | ## Installation
46 | Clone this repository and install the required packages:
47 |
48 | - Install Python dependencies
49 | ```shell
50 |
51 | git clone https://github.com/QINGQINGLE/MESC-3D.git
52 | cd MESC-3D
53 |
54 | conda create -n mesc3d python=3.9
55 | conda activate mesc3d
56 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
57 |
58 | pip install -r requirements.txt
59 |
60 | ```
61 |
62 | - Compile PyTorch 3rd-party modules.
63 |
64 | ```shell
65 |
66 | cd package/Pointnet2_PyTorch-master/
67 | pip install -e .
68 | pip install pointnet2_ops_lib/.
69 |
70 | cd -
71 | cd package/KNN_CUDA-master/
72 | make && make install
73 |
74 | ```
75 |
76 | - CLIP Usage
77 | The following steps cover installing and modifying CLIP.
78 | ```shell
79 |
80 | pip install git+https://github.com/openai/CLIP.git
81 | # or
82 | pip install clip
83 |
84 | ```
85 | Replace
86 | ```python
87 | def encode_text(self, text):
88 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
89 | x = x + self.positional_embedding.type(self.dtype)
90 | ...
91 | ```
92 | with
93 | ```python
94 | def encode_token(self, token):
95 | x = self.token_embedding(token)
96 | return x
97 | def encode_text(self, text, token):
98 | #x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
99 | x = text.type(self.dtype) + self.positional_embedding.type(self.dtype)
100 | ...
101 | ```
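Below is a minimal sketch of how the modified interface can be called. It assumes the `encode_text` replacement above has been applied to the installed CLIP package; all variable names are illustrative.
```python
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/16", device=device)

# Tokenize a caption, embed the tokens, then encode the text features.
tokens = clip.tokenize(["a photo of an airplane"]).to(device)  # [batch, 77]
token_embed = model.encode_token(tokens)                       # per-token embeddings
text_feat = model.encode_text(token_embed, tokens)             # adds positional embedding internally
text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
```
Splitting token embedding from text encoding presumably lets precomputed or learnable token embeddings (e.g. prompts) be fed to the text encoder in place of raw token ids.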
102 | ## Dataset
103 |
104 | ## Pretrained-model
105 |
106 |
107 | We provide the following pretrained models: BaseModel, ProModel, ULIPModel, etc. Please download them from Google Drive.
108 |
109 |
110 |
111 | ## Training
112 |
113 | ## Testing
114 |
115 | The remaining code is on the way.
116 |
--------------------------------------------------------------------------------
/figs/mainfig.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/figs/mainfig.jpg
--------------------------------------------------------------------------------
/figs/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/figs/overview.jpg
--------------------------------------------------------------------------------
/src/chamfer_distance/__init__.py:
--------------------------------------------------------------------------------
1 | from .chamfer_distance import ChamferDistance
2 |
--------------------------------------------------------------------------------
/src/chamfer_distance/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/chamfer_distance/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/chamfer_distance/__pycache__/chamfer_distance.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/chamfer_distance/__pycache__/chamfer_distance.cpython-39.pyc
--------------------------------------------------------------------------------
/src/chamfer_distance/chamfer_distance.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | // CUDA forward declarations
4 | int ChamferDistanceKernelLauncher(
5 | const int b, const int n,
6 | const float* xyz,
7 | const int m,
8 | const float* xyz2,
9 | float* result,
10 | int* result_i,
11 | float* result2,
12 | int* result2_i);
13 |
14 | int ChamferDistanceGradKernelLauncher(
15 | const int b, const int n,
16 | const float* xyz1,
17 | const int m,
18 | const float* xyz2,
19 | const float* grad_dist1,
20 | const int* idx1,
21 | const float* grad_dist2,
22 | const int* idx2,
23 | float* grad_xyz1,
24 | float* grad_xyz2);
25 |
26 |
27 | void chamfer_distance_forward_cuda(
28 | const at::Tensor xyz1,
29 | const at::Tensor xyz2,
30 | const at::Tensor dist1,
31 | const at::Tensor dist2,
32 | const at::Tensor idx1,
33 | const at::Tensor idx2)
34 | {
35 | ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(),
36 | xyz2.size(1), xyz2.data(),
37 | dist1.data(), idx1.data(),
38 | dist2.data(), idx2.data());
39 | }
40 |
41 | void chamfer_distance_backward_cuda(
42 | const at::Tensor xyz1,
43 | const at::Tensor xyz2,
44 | at::Tensor gradxyz1,
45 | at::Tensor gradxyz2,
46 | at::Tensor graddist1,
47 | at::Tensor graddist2,
48 | at::Tensor idx1,
49 | at::Tensor idx2)
50 | {
51 | ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(),
52 | xyz2.size(1), xyz2.data(),
53 | graddist1.data(), idx1.data(),
54 | graddist2.data(), idx2.data(),
55 | gradxyz1.data(), gradxyz2.data());
56 | }
57 |
58 |
59 | void nnsearch(
60 | const int b, const int n, const int m,
61 | const float* xyz1,
62 | const float* xyz2,
63 | float* dist,
64 | int* idx)
65 | {
66 | for (int i = 0; i < b; i++) {
67 | for (int j = 0; j < n; j++) {
68 | const float x1 = xyz1[(i*n+j)*3+0];
69 | const float y1 = xyz1[(i*n+j)*3+1];
70 | const float z1 = xyz1[(i*n+j)*3+2];
71 | double best = 0;
72 | int besti = 0;
73 | for (int k = 0; k < m; k++) {
74 | const float x2 = xyz2[(i*m+k)*3+0] - x1;
75 | const float y2 = xyz2[(i*m+k)*3+1] - y1;
76 | const float z2 = xyz2[(i*m+k)*3+2] - z1;
77 | const double d=x2*x2+y2*y2+z2*z2;
78 | if (k==0 || d < best){
79 | best = d;
80 | besti = k;
81 | }
82 | }
83 | dist[i*n+j] = best;
84 | idx[i*n+j] = besti;
85 | }
86 | }
87 | }
88 |
89 |
90 | void chamfer_distance_forward(
91 | const at::Tensor xyz1,
92 | const at::Tensor xyz2,
93 | const at::Tensor dist1,
94 | const at::Tensor dist2,
95 | const at::Tensor idx1,
96 | const at::Tensor idx2)
97 | {
98 | const int batchsize = xyz1.size(0);
99 | const int n = xyz1.size(1);
100 | const int m = xyz2.size(1);
101 |
102 | const float* xyz1_data = xyz1.data();
103 | const float* xyz2_data = xyz2.data();
104 | float* dist1_data = dist1.data();
105 | float* dist2_data = dist2.data();
106 | int* idx1_data = idx1.data();
107 | int* idx2_data = idx2.data();
108 |
109 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
110 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
111 | }
112 |
113 |
114 | void chamfer_distance_backward(
115 | const at::Tensor xyz1,
116 | const at::Tensor xyz2,
117 | at::Tensor gradxyz1,
118 | at::Tensor gradxyz2,
119 | at::Tensor graddist1,
120 | at::Tensor graddist2,
121 | at::Tensor idx1,
122 | at::Tensor idx2)
123 | {
124 | const int b = xyz1.size(0);
125 | const int n = xyz1.size(1);
126 | const int m = xyz2.size(1);
127 |
128 | const float* xyz1_data = xyz1.data();
129 | const float* xyz2_data = xyz2.data();
130 | float* gradxyz1_data = gradxyz1.data();
131 | float* gradxyz2_data = gradxyz2.data();
132 | float* graddist1_data = graddist1.data();
133 | float* graddist2_data = graddist2.data();
134 | const int* idx1_data = idx1.data();
135 | const int* idx2_data = idx2.data();
136 |
137 | for (int i = 0; i < b*n*3; i++)
138 | gradxyz1_data[i] = 0;
139 | for (int i = 0; i < b*m*3; i++)
140 | gradxyz2_data[i] = 0;
141 | for (int i = 0;i < b; i++) {
142 | for (int j = 0; j < n; j++) {
143 | const float x1 = xyz1_data[(i*n+j)*3+0];
144 | const float y1 = xyz1_data[(i*n+j)*3+1];
145 | const float z1 = xyz1_data[(i*n+j)*3+2];
146 | const int j2 = idx1_data[i*n+j];
147 |
148 | const float x2 = xyz2_data[(i*m+j2)*3+0];
149 | const float y2 = xyz2_data[(i*m+j2)*3+1];
150 | const float z2 = xyz2_data[(i*m+j2)*3+2];
151 | const float g = graddist1_data[i*n+j]*2;
152 |
153 | gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2);
154 | gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2);
155 | gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2);
156 | gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2));
157 | gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2));
158 | gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2));
159 | }
160 | for (int j = 0; j < m; j++) {
161 | const float x1 = xyz2_data[(i*m+j)*3+0];
162 | const float y1 = xyz2_data[(i*m+j)*3+1];
163 | const float z1 = xyz2_data[(i*m+j)*3+2];
164 | const int j2 = idx2_data[i*m+j];
165 | const float x2 = xyz1_data[(i*n+j2)*3+0];
166 | const float y2 = xyz1_data[(i*n+j2)*3+1];
167 | const float z2 = xyz1_data[(i*n+j2)*3+2];
168 | const float g = graddist2_data[i*m+j]*2;
169 | gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2);
170 | gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2);
171 | gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2);
172 | gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2));
173 | gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2));
174 | gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2));
175 | }
176 | }
177 | }
178 |
179 |
180 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
181 | m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
182 | m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
183 | m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
184 | m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
185 | }
186 |
--------------------------------------------------------------------------------
/src/chamfer_distance/chamfer_distance.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | 
3 | #include <cuda.h>
4 | #include <cuda_runtime.h>
5 |
6 | __global__
7 | void ChamferDistanceKernel(
8 | int b,
9 | int n,
10 | const float* xyz,
11 | int m,
12 | const float* xyz2,
13 | float* result,
14 | int* result_i)
15 | {
16 | const int batch=512;
17 | __shared__ float buf[batch*3];
18 | for (int i=blockIdx.x;ibest){
130 | result[(i*n+j)]=best;
131 | result_i[(i*n+j)]=best_i;
132 | }
133 | }
134 | __syncthreads();
135 | }
136 | }
137 | }
138 |
139 | void ChamferDistanceKernelLauncher(
140 | const int b, const int n,
141 | const float* xyz,
142 | const int m,
143 | const float* xyz2,
144 | float* result,
145 | int* result_i,
146 | float* result2,
147 | int* result2_i)
148 | {
149 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i);
150 | ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i);
151 |
152 | cudaError_t err = cudaGetLastError();
153 | if (err != cudaSuccess)
154 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err));
155 | }
156 |
157 |
158 | __global__
159 | void ChamferDistanceGradKernel(
160 | int b, int n,
161 | const float* xyz1,
162 | int m,
163 | const float* xyz2,
164 | const float* grad_dist1,
165 | const int* idx1,
166 | float* grad_xyz1,
167 | float* grad_xyz2)
168 | {
169 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2);
204 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1);
205 |
206 | cudaError_t err = cudaGetLastError();
207 | if (err != cudaSuccess)
208 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err));
209 | }
210 |
--------------------------------------------------------------------------------
/src/chamfer_distance/chamfer_distance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import importlib
3 | from torch.utils.cpp_extension import load
4 | import os
5 | # os.environ["CUDA_VISIBLE_DEVICES"] = '5'
6 | script_dir = os.path.dirname(__file__)
7 |
8 | sources = [
9 | os.path.join(script_dir, "chamfer_distance.cpp"),
10 | os.path.join(script_dir, "chamfer_distance.cu"),
11 | ]
12 | chamfer_found = importlib.find_loader("chamfer_3D") is not None
13 | if not chamfer_found:
14 | ## Cool trick from https://github.com/chrdiller
15 | print("Jitting Chamfer 3D")
16 | from torch.utils.cpp_extension import load
17 | chamfer_3D = load(name="chamfer_3D",sources=sources)
18 | print("Loaded JIT 3D CUDA chamfer distance")
19 | else:
20 | import chamfer_3D
21 | print("Loaded compiled 3D CUDA chamfer distance")
22 |
23 |
24 | class ChamferDistanceFunction(torch.autograd.Function):
25 | @staticmethod
26 | def forward(ctx, xyz1, xyz2):
27 | batchsize, n, _ = xyz1.size()
28 | _, m, _ = xyz2.size()
29 | device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')
30 | # print("device:", device)
31 | xyz1 = xyz1.contiguous()
32 | xyz2 = xyz2.contiguous()
33 | dist1 = torch.zeros(batchsize, n)
34 | dist2 = torch.zeros(batchsize, m)
35 |
36 | idx1 = torch.zeros(batchsize, n, dtype=torch.int)
37 | idx2 = torch.zeros(batchsize, m, dtype=torch.int)
38 | xyz1 = xyz1.to(device)
39 | xyz2 = xyz2.to(device)
40 | dist1 = dist1.to(device)
41 | dist2 = dist2.to(device)
42 | idx1 = idx1.to(device)
43 | idx2 = idx2.to(device)
44 | torch.cuda.set_device(device)
45 | chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
46 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
47 | # if not xyz1.is_cuda:
48 | # chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
49 | # else:
50 | # # dist1 = dist1.cuda()
51 | # # dist2 = dist2.cuda()
52 | # # idx1 = idx1.cuda()
53 | # # idx2 = idx2.cuda()
54 | # dist1 = dist1.to(device)
55 | # dist2 = dist2.to(device)
56 | # idx1 = idx1.to(device)
57 | # idx2 = idx2.to(device)
58 | # torch.cuda.set_device(device)
59 | # # chamfer_3D.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, sidx2)
60 | # chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
61 |
62 | # ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
63 | return dist1, dist2
64 |
65 | @staticmethod
66 | def backward(ctx, graddist1, graddist2):
67 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
68 |
69 | graddist1 = graddist1.contiguous()
70 | graddist2 = graddist2.contiguous()
71 | device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')
72 |
73 | gradxyz1 = torch.zeros(xyz1.size())
74 | gradxyz2 = torch.zeros(xyz2.size())
75 |
76 | if not graddist1.is_cuda:
77 | chamfer_3D.backward(
78 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
79 | )
80 | else:
81 | # gradxyz1 = gradxyz1.cuda()
82 | # gradxyz2 = gradxyz2.cuda()
83 | gradxyz1 = gradxyz1.to(device)
84 | gradxyz2 = gradxyz2.to(device)
85 | # chamfer_3D.backward_cuda(
86 | # xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
87 | # )
88 | chamfer_3D.backward(
89 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
90 | )
91 |
92 | return gradxyz1, gradxyz2
93 |
94 |
95 | class ChamferDistance(torch.nn.Module):
96 | def __init__(self):
97 | super(ChamferDistance, self).__init__()
98 | def forward(self, xyz1, xyz2):
99 | return ChamferDistanceFunction.apply(xyz1, xyz2)
100 |
--------------------------------------------------------------------------------
/src/configs/ULIP.yaml:
--------------------------------------------------------------------------------
1 | model: ULIP_PointBERT
2 | npoints: 2048
3 | evaluate_3d: store_true
4 |
--------------------------------------------------------------------------------
/src/configs/base_model.yaml:
--------------------------------------------------------------------------------
1 | # Adam & scheduler
2 | optimizer : {
3 | type: Adam,
4 | kwargs: {
5 | lr : 0.001,
6 | betas : [0.9,0.999],
7 | eps: 0.000001
8 | }}
9 | scheduler: {
10 | type: CosLR,
11 | kwargs: {
12 | epochs: 400,
13 | initial_epochs : 20
14 | }}
15 | # model
16 | model_name: BaseModel
17 | manual_seed: null
18 |
19 | model:
20 | embed_dim: 32
21 | depth: 4
22 | out_dim: 32
23 | num_heads: 8
24 | mlp_ratio: 2.0
25 | qkv_bias: False
26 | qk_scale: null
27 | drop_rate: 0.2
28 | attn_drop_rate: 0.2
29 | drop_path_rate: 0.2
30 | num_points: 2048
31 | dec_dim: [768,512,256,128,64,32]
32 | yaml: Path/configs/finetune_modelnet.yaml
33 |
34 | # batch
35 | batch_size: 24
36 | start_epoch: 0
37 | nepoch: 400
38 | epoch_interval_to_save: 5
39 | epoch_interval_to_val: 1
40 |
41 | # gpus
42 | distributed: True
43 | gpus: [0,1]
44 | # SVR Data
45 | radius: 0.5
46 | normalization: UnitBall
47 | shapenet13: True
48 | SVR: True
49 | class_choice: ["airplane"]
50 | number_points: 2048
51 | number_points_eval: 2048
52 | random_rotation: False
53 | data_augmentation_axis_rotation: False
54 | data_augmentation_random_flips: False
55 | random_translation: False
56 | anisotropic_scaling: False
57 | demo: False
58 | sample: True
59 | workers: 8
60 |
61 | # path
62 | taxonomy_path: Path/dataset_svr/taxonomy.json
63 | dir_outpath: Path/log_dir
64 |
65 | ckpt_path: path/ckpt_path/base_model.pth
66 |
67 | pointcloud_path: Path/ShapeNetV1PointCloud/
68 | image_path: Path/ShapeNetV1Renderings/
69 | cache_path: Path/Cache/
70 | cache_path_test: Path/Cachetest/
71 |
--------------------------------------------------------------------------------
/src/configs/finetune_modelnet.yaml:
--------------------------------------------------------------------------------
1 | optimizer : {
2 | type: AdamW,
3 | kwargs: {
4 | lr : 0.0005,
5 | weight_decay : 0.05
6 | }}
7 |
8 | scheduler: {
9 | type: CosLR,
10 | kwargs: {
11 | epochs: 300,
12 | initial_epochs : 10
13 | }}
14 |
15 | # dataset : {
16 | # train : { _base_: cfgs/dataset_configs/ModelNet40.yaml,
17 | # others: {subset: 'train'}},
18 | # val : { _base_: cfgs/dataset_configs/ModelNet40.yaml,
19 | # others: {subset: 'test'}},
20 | # test : { _base_: cfgs/dataset_configs/ModelNet40.yaml,
21 | # others: {subset: 'test'}}}
22 | model : {
23 | NAME: PointTransformer,
24 | trans_dim: 384,
25 | depth: 12,
26 | drop_path_rate: 0.1,
27 | cls_dim: 40,
28 | num_heads: 6,
29 | group_size: 32,
30 | num_group: 64,
31 | encoder_dims: 384,
32 | }
33 |
34 | ckpts: Path/checkpoints/modelnet_1k.pth
35 |
36 | npoints: 2048
37 | total_bs : 32
38 | step_per_update : 1
39 | max_epoch : 300
40 | grad_norm_clip : 10
--------------------------------------------------------------------------------
/src/configs/ltp_model.yaml:
--------------------------------------------------------------------------------
1 | optimizer : {
2 | type: Adam,
3 | kwargs: {
4 | lr : 0.0001,
5 | betas : [0.9,0.999],
6 | eps: 0.000001
7 | }}
8 |
9 | scheduler: {
10 | type: CosLR,
11 | kwargs: {
12 | epochs: 250,
13 | initial_epochs : 10
14 | }}
15 |
16 |
17 | #trainer
18 | start_epoch: 0
19 | epochs: 250
20 | model: TextAlignPCModel
21 |
22 | batch_size: 32 # 128 64 32
23 | disable_amp: store_true
24 | # update_freq: 2 # 1 2 4
25 | print_freq: 10
26 | warmup_epochs: 1
27 | wandb: store_true
28 | #Shapenet Data
29 | radius: 0.5
30 | normalization: UnitBall
31 | shapenet13: True
32 | SVR: True
33 |
34 | class_choice: ["airplane"]
35 | number_points: 2048
36 | number_points_eval: 2048
37 | random_rotation: False
38 | data_augmentation_axis_rotation: False
39 | data_augmentation_random_flips: False
40 | random_translation: False
41 | anisotropic_scaling: False
42 | demo: False
43 | sample: True
44 | workers: 8
45 |
46 | # path
47 | dir_outpath: Path/dir_log/
48 | dir_checkpoints: None
49 | output_dir: Path/dir_out/
50 | clip_ckpt_path: Path/checkpoints/ViT-B-16.pt
51 | ulip_ckpt_path: Path/checkpoints/pretrained_models_ckpt_zero-sho_classification_checkpoint_pointbert.pt
52 |
53 | ulip: Path/configs/ULIP.yaml
54 |
55 | pointcloud_path: Path/ShapeNetV1PointCloud/
56 | image_path: Path/ShapeNetV1Renderings/
57 | cache_path: Path/Cache/
58 | cache_path_test: Path/Cachetest/
--------------------------------------------------------------------------------
/src/configs/pro_model.yaml:
--------------------------------------------------------------------------------
1 | # Adam & Scheduler
2 | optimizer : {
3 | type: Adam,
4 | kwargs: {
5 | lr : 0.001,
6 | betas : [0.9,0.999],
7 | eps: 0.000001
8 | }}
9 |
10 | scheduler: {
11 | type: CosLR,
12 | kwargs: {
13 | epochs: 400,
14 | initial_epochs : 20
15 | }}
16 | # model
17 | model_name: ProModel
18 | manual_seed: null
19 |
20 | model:
21 | embed_dim: 32
22 | depth: 4
23 | out_dim: 32
24 | num_heads: 8
25 | mlp_ratio: 2.0
26 | qkv_bias: False
27 | qk_scale: null
28 | drop_rate: 0.2
29 | attn_drop_rate: 0.2
30 | drop_path_rate: 0.2
31 | num_points: 2048
32 | dec_dim: [1280,1024,768,512,256,128,64,32]
33 | yaml: Path/configs/finetune_modelnet.yaml
34 |
35 | # batch
36 | batch_size: 24 # 32
37 | start_epoch: 0
38 | nepoch: 400
39 | epoch_interval_to_save: 5
40 | epoch_interval_to_val: 1
41 |
42 | # gpus
43 | distributed: True
44 | gpus: [0,1]
45 | # SVR Data
46 | radius: 0.5
47 | normalization: UnitBall
48 | shapenet13: True
49 | SVR: True
50 |
51 | class_choice: ["airplane"]
52 | number_points: 2048
53 | number_points_eval: 2048
54 | random_rotation: False
55 | data_augmentation_axis_rotation: False
56 | data_augmentation_random_flips: False
57 | random_translation: False
58 | anisotropic_scaling: False
59 | demo: False
60 | sample: True
61 | workers: 8
62 |
63 | # path
64 | taxonomy_path: Path/dataset_svr/taxonomy.json
65 | dir_outpath: Path/log_dir
66 |
67 | ckpt_path: Path/ckpt_path/pro_model.pth
68 | prompt_ckpt_path: Path/prompt.pth
69 | text_encoder_ckpt_path: Path/text_encoder.pth
70 |
71 | pointcloud_path: Path/ShapeNetV1PointCloud/
72 | image_path: Path/ShapeNetV1Renderings/
73 | cache_path: Path/Cache/
74 | cache_path_test: Path/Cachetest/
--------------------------------------------------------------------------------
/src/dataset_svr/augmenter.py:
--------------------------------------------------------------------------------
1 | import dataset_svr.pointcloud_processor as pointcloud_processor
2 |
3 |
4 | class Augmenter(object):
5 | def __init__(self, translation=False, rotation_axis=[], rotation_3D=False, anisotropic_scaling=False, flips=[]):
6 | self.translation = translation
7 | self.rotation_axis = rotation_axis
8 | self.rotation_3D = rotation_3D
9 | self.anisotropic_scaling = anisotropic_scaling
10 | self.flips = flips
11 |
12 | def __call__(self, points):
13 | operation = pointcloud_processor.DataAugmentation(points)
14 | for axis in self.rotation_axis:
15 | operation.random_axial_rotation(axis=axis)
16 | if self.rotation_3D:
17 | operation.random_rotation()
18 | if self.anisotropic_scaling:
19 | operation.random_anisotropic_scaling()
20 | if len(self.flips) > 0:
21 | operation.random_flips(self.flips)
22 | if self.translation:
23 | operation.random_translation()
24 |
--------------------------------------------------------------------------------
/src/dataset_svr/mesh_processor.py:
--------------------------------------------------------------------------------
1 | import pymesh
2 | import numpy as np
3 | from os.path import join, dirname
4 |
5 |
6 | class ColorMap:
7 | def __init__(self):
8 | self.colormap_path = "auxiliary/colormap.npy"
9 | self.colormap = (np.load(self.colormap_path) * 255).astype('int')
10 |
11 | def __call__(self, index):
12 | """
13 |         :param index: integer index (or array of indices) into the colormap
14 |         :return: the corresponding RGB colors
15 | """
16 | colors = self.colormap[index]
17 | return colors
18 |
19 |
20 | def save(mesh, path, colormap):
21 | try:
22 | vertex_sources = mesh.get_attribute("vertex_sources") # batch, nb_prim, num_point, 3
23 | if vertex_sources.max() > 0:
24 | vertex_sources = (255 * (vertex_sources / vertex_sources.max())).astype('int')
25 | mesh.add_attribute("vertex_red")
26 | mesh.add_attribute("vertex_green")
27 | mesh.add_attribute("vertex_blue")
28 | mesh.set_attribute("vertex_red", colormap.colormap[vertex_sources][:, 0])
29 | mesh.set_attribute("vertex_green", colormap.colormap[vertex_sources][:, 1])
30 | mesh.set_attribute("vertex_blue", colormap.colormap[vertex_sources][:, 2])
31 | except:
32 | pass
33 | pymesh.save_mesh(path[:-3] + "ply", mesh, *mesh.get_attribute_names(), ascii=True)
34 |
--------------------------------------------------------------------------------
/src/dataset_svr/pointcloud_processor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class FunctionGenerator(object):
6 | def invert(self):
7 | print("This function has to be reimplemented in every inherited class")
8 |
9 |
10 | class ScaleFunctions(FunctionGenerator):
11 | def __init__(self, operator, inplace):
12 | self.operator = operator.clone()
13 | self.inplace = inplace
14 |
15 | def __call__(self, points):
16 | if self.inplace:
17 | points *= self.operator
18 | return points
19 | else:
20 | return points * self.operator
21 |
22 | def invert(self):
23 | self.operator = 1.0 / self.operator
24 |
25 |
26 | class RotationFunctions(FunctionGenerator):
27 | def __init__(self, operator, inplace):
28 | self.operator = operator.clone()
29 | self.inplace = inplace
30 | assert (self.operator.bmm(self.operator.transpose(1, 2).contiguous()).sum().item() - (
31 | operator.size(0) * operator.size(2))) ** 2 < 0.001, "Input matrix is not a rotation matrix"
32 |
33 | def __call__(self, points):
34 | rotated_points = torch.bmm(points, self.operator)
35 | if self.inplace:
36 | points.copy_(rotated_points)
37 | return points
38 | return rotated_points
39 |
40 | def invert(self):
41 | self.operator = self.operator.transpose(1, 2).contiguous()
42 |
43 |
44 | class TranslationFunctions(FunctionGenerator):
45 | def __init__(self, operator, inplace):
46 | self.operator = operator.clone()
47 | self.inplace = inplace
48 |
49 | def __call__(self, points):
50 | if self.inplace:
51 | points += self.operator
52 | return points
53 | else:
54 | return points + self.operator
55 |
56 | def invert(self):
57 | self.operator = -self.operator
58 |
59 |
60 | class Operation(object):
61 | def __init__(self, points, inplace=True, keep_track=False):
62 | """
63 |         The keep_track boolean is used in case one wants to unroll all the operations that have been performed
64 | :param keep_track: boolean
65 | """
66 | self.keep_track = keep_track
67 | self.transforms = []
68 | self.points = points
69 | self.device = points.device
70 | self.inplace = inplace
71 | self.dim = points.dim()
72 | self.type = self.points.type()
73 |
74 | if not self.inplace:
75 | self.points = self.points.clone()
76 | if self.dim == 2:
77 | self.points = self.points.unsqueeze_(0)
78 | elif self.dim == 3:
79 | pass
80 | else:
81 | print("Input should have dimension 2 or 3")
82 |
83 | def apply(self, points):
84 | for func in self.transforms:
85 | points = func(points)
86 | return points
87 |
88 | def invert(self):
89 | self.transforms.reverse()
90 | for func in self.transforms:
91 | func.invert()
92 |
93 | def scale(self, scale_vector):
94 | scaling_op = ScaleFunctions(scale_vector.to(self.device).type(self.type), inplace=self.inplace)
95 | self.points = scaling_op(self.points)
96 | if self.keep_track:
97 | self.transforms.append(scaling_op)
98 | return
99 |
100 | def translate(self, translation_vector):
101 | translation_op = TranslationFunctions(translation_vector.to(self.device).type(self.type), inplace=self.inplace)
102 | self.points = translation_op(self.points)
103 | if self.keep_track:
104 | self.transforms.append(translation_op)
105 | return
106 |
107 | def rotate(self, rotation_vector):
108 | rotation_op = RotationFunctions(rotation_vector.to(self.device).type(self.type), inplace=self.inplace)
109 | self.points = rotation_op(self.points)
110 | if self.keep_track:
111 | self.transforms.append(rotation_op)
112 | return
113 |
114 | @staticmethod
115 | def get_3D_rot_matrix(axis, rad_angle):
116 | """
117 | Get a 3D rotation matrix around axis with angle in radian
118 | :param axis: int
119 | :param angle: torch.tensor of size Batch.
120 | :return: Rotation Matrix as a tensor
121 | """
122 | cos_angle = torch.cos(rad_angle)
123 | sin_angle = torch.sin(rad_angle)
124 | rotation_matrix = torch.zeros(rad_angle.size(0), 3, 3)
125 | rotation_matrix[:, 1, 1].fill_(1)
126 | rotation_matrix[:, 0, 0].copy_(cos_angle)
127 | rotation_matrix[:, 0, 2].copy_(sin_angle)
128 | rotation_matrix[:, 2, 0].copy_(-sin_angle)
129 | rotation_matrix[:, 2, 2].copy_(cos_angle)
130 | if axis == 0:
131 | rotation_matrix = rotation_matrix[:, [1, 0, 2], :][:, :, [1, 0, 2]]
132 | if axis == 2:
133 | rotation_matrix = rotation_matrix[:, [0, 2, 1], :][:, :, [0, 2, 1]]
134 | return rotation_matrix
135 |
136 | def rotate_axis_angle(self, axis, rad_angle, normals=False):
137 | """
138 |
139 | :param points: Batched points
140 | :param axis: int
141 | :param angle: batched angles
142 | :return:
143 | """
144 | rot_matrix = Operation.get_3D_rot_matrix(axis=axis, rad_angle=rad_angle)
145 | if normals:
146 | rot_matrix = torch.cat([rot_matrix, rot_matrix], dim=2)
147 | self.rotate(rot_matrix)
148 | return
149 |
150 |
151 | class Normalization(Operation):
152 | def __init__(self, *args, **kwargs):
153 | super(Normalization, self).__init__(*args, **kwargs)
154 |
155 | def center_pointcloud(self):
156 | """
157 | In-place centering
158 | :param points: Tensor Batch, N_pts, D_dim
159 | :return: None
160 | """
161 | # input :
162 |         # output : torch Tensor N_pts, D_dim
163 | centroid = torch.mean(self.points, dim=1, keepdim=True)
164 | self.translate(-centroid)
165 | return self.points
166 |
167 | @staticmethod
168 | def center_pointcloud_functional(points):
169 | operator = Normalization(points, inplace=False)
170 | return operator.center_pointcloud()
171 |
172 | def normalize_unitL2ball(self):
173 | """
174 | In-place normalization of input to unit ball
175 | :param points: torch Tensor Batch, N_pts, D_dim
176 | :return: None
177 | """
178 | # input : torch Tensor N_pts, D_dim
179 |         # output : torch Tensor N_pts, D_dim
180 | #
181 | self.center_pointcloud()
182 | scaling_factor_square, _ = torch.max(torch.sum(self.points ** 2, dim=2, keepdim=True), dim=1, keepdim=True)
183 | scaling_factor = torch.sqrt(scaling_factor_square)
184 | self.scale(1.0 / scaling_factor)
185 | return self.points
186 |
187 | @staticmethod
188 | def normalize_unitL2ball_functional(points):
189 | operator = Normalization(points, inplace=False)
190 | return operator.normalize_unitL2ball()
191 |
192 | def center_bounding_box(self):
193 | """
194 |         In-place centering on the bounding-box center
195 | :param points: torch Tensor Batch, N_pts, D_dim
196 | :return: diameter
197 | """
198 | min_vals, _ = torch.min(self.points, 1, keepdim=True)
199 | max_vals, _ = torch.max(self.points, 1, keepdim=True)
200 | self.translate(-(min_vals + max_vals) / 2)
201 | return self.points, (max_vals - min_vals) / 2
202 |
203 | @staticmethod
204 | def center_bounding_box_functional(points):
205 | operator = Normalization(points, inplace=False)
206 | points, _ = operator.center_bounding_box()
207 | return points
208 |
209 | def normalize_bounding_box(self, isotropic=True):
210 | """
211 |         In place: center the bounding box and uniformly scale it so that the maximum edge length is 1 if isotropic is True (default), or every edge length is 1 otherwise.
212 | :param points: torch Tensor Batch, N_pts, D_dim
213 | :return:
214 | """
215 | _, diameter = self.center_bounding_box()
216 | if isotropic:
217 | diameter, _ = torch.max(diameter, 2, keepdim=True)
218 | self.scale(1.0 / diameter)
219 | return self.points
220 |
221 | @staticmethod
222 | def normalize_bounding_box_functional(points):
223 | operator = Normalization(points, inplace=False)
224 | return operator.normalize_bounding_box()
225 |
226 | @staticmethod
227 | def identity_functional(points):
228 | return points
229 |
230 |
231 | class DataAugmentation(Operation):
232 | def __init__(self, *args, **kwargs):
233 | super(DataAugmentation, self).__init__(*args, **kwargs)
234 |
235 | def random_anisotropic_scaling(self, min_val=0.75, max_val=1.25):
236 | """
237 | In place : Random Anisotropic scaling by a factor between min_val and max_val
238 | :param points: torch Tensor Batch, N_pts, D_dim
239 | :return:
240 | """
241 | scale = torch.rand(self.points.size(0), 1, self.points.size(2)) * (max_val - min_val) + min_val
242 | self.scale(scale)
243 | return
244 |
245 | def random_axial_rotation(self, axis=0, normals=False, range_rot=360):
246 | """
247 |         Compute a random rotation of the batch around an axis. There is no in-place version of this function because bmm_ is not possible in pytorch.
248 | :param points: torch Tensor Batch, N_pts, D_dim
249 | :return: torch Tensor Batch, N_pts, D_dim
250 | """
251 | scale_factor = 360.0 / range_rot
252 | scale_factor = np.pi / scale_factor
253 | rad_angle = torch.rand(self.points.size(0)) * 2 * scale_factor - scale_factor
254 | self.rotate_axis_angle(axis=axis, rad_angle=rad_angle, normals=normals)
255 | return
256 |
257 | @staticmethod
258 | def get_random_rotation_matrix(batch_size=1):
259 | """
260 | Get a random 3D rotation matrix
261 | :return: Rotation Matrix as a tensor
262 | from : https://en.wikipedia.org/wiki/Rotation_matrix#Basic_rotations
263 | An easy way to do this : sample a point on the sphere (with normalize(normal(), normal(), normal())
264 | then sample an angle, then just compute the associated rotation matrix
265 | """
266 | # Select a random point on the sphere
267 | x = torch.randn(batch_size, 1, 3).double()
268 | scaling_factor_square, _ = torch.max(torch.sum(x ** 2, dim=2, keepdim=True), dim=1, keepdim=True)
269 | scaling_factor = torch.sqrt(scaling_factor_square)
270 | x /= scaling_factor
271 | x = x.squeeze()
272 | XX = torch.bmm(x.unsqueeze(2), x.unsqueeze(1))
273 |
274 | # get random angle
275 | rad_angle = torch.rand(batch_size).double() * 2 * np.pi + np.pi
276 | cos_angle = torch.cos(rad_angle)
277 | sin_angle = torch.sin(rad_angle)
278 |
279 | # Compute fat matrix
280 | rotation_matrix = torch.zeros(rad_angle.size(0), 3, 3).double()
281 |
282 | rotation_matrix[:, 0, 0].copy_(cos_angle + XX[:, 0, 0] * (1 - cos_angle))
283 | rotation_matrix[:, 1, 1].copy_(cos_angle + XX[:, 1, 1] * (1 - cos_angle))
284 | rotation_matrix[:, 2, 2].copy_(cos_angle + XX[:, 2, 2] * (1 - cos_angle))
285 |
286 | rotation_matrix[:, 0, 1].copy_(XX[:, 0, 1] * (1 - cos_angle) - x[:, 2] * sin_angle)
287 | rotation_matrix[:, 1, 0].copy_(XX[:, 0, 1] * (1 - cos_angle) + x[:, 2] * sin_angle)
288 |
289 | rotation_matrix[:, 0, 2].copy_(XX[:, 0, 2] * (1 - cos_angle) + x[:, 1] * sin_angle)
290 | rotation_matrix[:, 2, 0].copy_(XX[:, 0, 2] * (1 - cos_angle) - x[:, 1] * sin_angle)
291 |
292 | rotation_matrix[:, 1, 2].copy_(XX[:, 1, 2] * (1 - cos_angle) - x[:, 0] * sin_angle)
293 | rotation_matrix[:, 2, 1].copy_(XX[:, 1, 2] * (1 - cos_angle) + x[:, 0] * sin_angle)
294 |
295 | return rotation_matrix
296 |
297 | def random_rotation(self, normals=False):
298 | """
299 |         Compute a random rotation of the batch. There is no in-place version of this function because bmm_ is not possible in pytorch.
300 | :param points: torch Tensor Batch, N_pts, D_dim
301 | :return: torch Tensor Batch, N_pts, D_dim
302 | """
303 | rot_matrix = DataAugmentation.get_random_rotation_matrix(batch_size=self.points.size(0))
304 | if normals:
305 | rot_matrix = torch.cat([rot_matrix, rot_matrix], dim=2)
306 | self.rotate(rot_matrix)
307 | return
308 |
309 | def random_translation(self, scale=0.03):
310 | """
311 |         In place: compute a random translation of the batch.
312 | :param points: torch Tensor Batch, N_pts, D_dim
313 | :return:
314 | """
315 | translation_vector = torch.rand(self.points.size(0), 1, self.points.size(2)) * 2 * scale - scale
316 | self.translate(translation_vector)
317 | return
318 |
319 | @staticmethod
320 | def diff(first, second):
321 | second = set(second)
322 | return [item for item in first if item not in second]
323 |
324 | def random_flips(self, dims=[]):
325 | """
326 | In place Random flip
327 | :param points: torch Tensor Batch, N_pts, D_dim
328 | :return:
329 | """
330 | exclude_dims = DataAugmentation.diff(range(self.points.size(2)), dims)
331 | scale_factor = torch.randint(2, (self.points.size(0), 1, self.points.size(2))) * 2 - 1
332 | for axis in exclude_dims:
333 | scale_factor[:, :, axis].fill_(1)
334 | self.scale(scale_factor)
335 | return
336 |
337 |
338 | # Done for eurographics 19
339 | def barycentric(p, a, b, c):
340 | """
341 | :param p: numpy arrays of size N_points x 3
342 | :param a: numpy arrays of size N_points x 3
343 | :param b: numpy arrays of size N_points x 3
344 | :param c: numpy arrays of size N_points x 3
345 | :return: barycentric coordinates point p in triangle (a,b,c)
346 | """
347 |
348 | v0 = b - a
349 | v1 = c - a
350 | v2 = p - a
351 |
352 | d00 = np.sum(np.multiply(v0, v0), 1)
353 | d01 = np.sum(np.multiply(v0, v1), 1)
354 | d11 = np.sum(np.multiply(v1, v1), 1)
355 | d20 = np.sum(np.multiply(v2, v0), 1)
356 | d21 = np.sum(np.multiply(v2, v1), 1)
357 |
358 | denom = np.multiply(d00, d11) - np.multiply(d01, d01)
359 |
360 | v = (np.multiply(d11, d20) - np.multiply(d01, d21)) / denom
361 | w = (np.multiply(d00, d21) - np.multiply(d01, d20)) / denom
362 | u = - v - w + 1.0
363 |
364 | return (u, v, w)
365 |
366 |
367 | if __name__ == '__main__':
368 | print("Start unit test")
369 |
--------------------------------------------------------------------------------
/src/dataset_svr/trainer_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | import munch
5 | # import dataset_shapenet
6 | import dataset_svr.dataset_shapenet as dataset_shapenet
7 | import yaml
8 |
9 | def pc_normalize(pc, radius):
10 | l = pc.shape[0]
11 | centroid = np.mean(pc, axis=0)
12 | pc = pc - centroid
13 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
14 | pc = pc / m * radius
15 | return pc
16 |
17 | def get_spherepoints(num_points, radius):
18 | ball_name = 'Path/balls/%d.xyz' % num_points
19 | ball = np.loadtxt(ball_name)
20 | ball = pc_normalize(ball, radius)
21 | return ball
22 | def build_dataset(args):
23 | # Create Datasets
24 | dataset_train = dataset_shapenet.ShapeNet(args, train=True)
25 | dataset_test = dataset_shapenet.ShapeNet(args, train=False)
26 |
27 | # Create dataloaders
28 | dataloader_train = torch.utils.data.DataLoader(dataset_train,
29 | batch_size=args.batch_size,
30 | shuffle=True,
31 | num_workers=int(args.workers))
32 | dataloader_test = torch.utils.data.DataLoader(dataset_test,
33 | batch_size=args.batch_size,
34 | shuffle=False, num_workers=int(args.workers))
35 |
36 | len_dataset = len(dataset_train)
37 | len_dataset_test = len(dataset_test)
38 |     print('Length of train dataset: %d' % len_dataset)
39 |     print('Length of test dataset: %d' % len_dataset_test)
40 |
41 | return dataloader_train, dataloader_test
42 |
43 | def build_dataset_val(args):
44 |
45 | # Create Datasets
46 | dataset_test = dataset_shapenet.ShapeNet_val(args, train=False)
47 |
48 | # Create dataloaders
49 | dataloader_test = torch.utils.data.DataLoader(dataset_test,
50 | batch_size=args.batch_size,
51 | shuffle=False, num_workers=int(args.workers))
52 |
53 | len_dataset_test = len(dataset_test)
54 |     print('Length of test dataset: %d' % len_dataset_test)
55 |
56 | return dataloader_test
57 |
58 | if __name__ == '__main__':
59 | config_path = "Path/MAE/config.yaml"
60 | args = munch.munchify(yaml.safe_load(open(config_path)))
61 | dataloader_test = build_dataset_val(args)
62 | for data in dataloader_test:
63 | img = data['image']
64 | pc = data['points']
65 | print(img.shape)
66 | print(pc.shape)
67 | print('done')
--------------------------------------------------------------------------------
/src/dataset_svr/trainer_text_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import munch
4 | import dataset_svr.dataset_shapenet_text as dataset_shapenet
5 | import yaml
6 | import numpy as np
7 | def pc_normalize(pc, radius):
8 | l = pc.shape[0]
9 | centroid = np.mean(pc, axis=0)
10 | pc = pc - centroid
11 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
12 | pc = pc / m * radius
13 | return pc
14 |
15 | def get_spherepoints(num_points, radius):
16 | ball_name = 'Path/balls/%d.xyz' % num_points
17 | ball = np.loadtxt(ball_name)
18 | ball = pc_normalize(ball, radius)
19 | return ball
20 |
21 | def build_dataset(args):
22 | # Create Datasets
23 | dataset_train = dataset_shapenet.ShapeNet(args, train=True)
24 | dataset_test = dataset_shapenet.ShapeNet(args, train=False)
25 |
26 | # Create dataloaders
27 | dataloader_train = torch.utils.data.DataLoader(dataset_train,
28 | batch_size=args.batch_size,
29 | shuffle=True,
30 | num_workers=int(args.workers))
31 | dataloader_test = torch.utils.data.DataLoader(dataset_test,
32 | batch_size=args.batch_size,
33 | shuffle=False, num_workers=int(args.workers))
34 |
35 | len_dataset = len(dataset_train)
36 | len_dataset_test = len(dataset_test)
37 |     print('Length of train dataset: %d' % len_dataset)
38 |     print('Length of test dataset: %d' % len_dataset_test)
39 |
40 | return dataloader_train, dataloader_test
41 |
42 | def build_dataset_val(args):
43 |
44 | # Create Datasets
45 | dataset_test = dataset_shapenet.ShapeNet_val(args, train=False)
46 |
47 | # Create dataloaders
48 | dataloader_test = torch.utils.data.DataLoader(dataset_test,
49 | batch_size=args.batch_size,
50 | shuffle=False, num_workers=int(args.workers))
51 |
52 | len_dataset_test = len(dataset_test)
53 |     print('Length of test dataset: %d' % len_dataset_test)
54 |
55 | return dataloader_test
56 | def get_batch_label(texts, prompt_text, label_map: dict):
57 | label_vectors = torch.zeros(0)
58 | if len(label_map) != 7:
59 | if len(label_map) == 2:
60 | for text in texts:
61 | label_vector = torch.zeros(2)
62 | if text == 'Normal':
63 | label_vector[0] = 1
64 | else:
65 | label_vector[1] = 1
66 | label_vector = label_vector.unsqueeze(0)
67 | label_vectors = torch.cat([label_vectors, label_vector], dim=0)
68 | else:
69 | for text in texts:
70 | label_vector = torch.zeros(len(prompt_text))
71 | if text in label_map:
72 | label_text = label_map[text]
73 | label_vector[prompt_text.index(label_text)] = 1
74 |
75 | label_vector = label_vector.unsqueeze(0)
76 | label_vectors = torch.cat([label_vectors, label_vector], dim=0)
77 | else:
78 | for text in texts:
79 | label_vector = torch.zeros(len(prompt_text))
80 | labels = text.split('-')
81 | for label in labels:
82 | if label in label_map:
83 | label_text = label_map[label]
84 | label_vector[prompt_text.index(label_text)] = 1
85 |
86 | label_vector = label_vector.unsqueeze(0)
87 | label_vectors = torch.cat([label_vectors, label_vector], dim=0)
88 |
89 | return label_vectors
90 | # def get_prompt_text(category:list,label_map: dict):
91 | # prompt_text = []
92 | # for v in label_map.values():
93 | # prompt_text.append(v)
94 |
95 | # return prompt_text
96 | def get_prompt_text(category: list, label_map: dict) -> list:
97 | prompt_text = []
98 | for cat in category:
99 | if cat in label_map:
100 | prompt_text.append(label_map[cat])
101 | else:
102 |
103 | prompt_text.append('Unknown')
104 | return prompt_text
105 | if __name__ == '__main__':
106 | from tqdm import tqdm
107 | from models.tokenizer import SimpleTokenizer
108 | from models.ULIP_models import ULIP_PointBERT
109 | from collections import OrderedDict
110 | config_path = 'Path/MAE/config.yaml'
111 | args = munch.munchify(yaml.safe_load(open(config_path)))
112 | dataloader_train, dataloader_test = build_dataset(args)
113 | tokenizer = SimpleTokenizer()
114 | label_map = dict({'02691156':'airplane','02828884':'bench','02933112':'cabinet','02958343':'car','03001627':'chair',
115 | '03211117':'display','03636649':'lamp','03691459':'loudspeaker','04090263':'rifle',
116 | '04256520':'sofa','04379243':'table','04401088':'telephone','04530566':'vessel'})
117 | # prompt_text = get_prompt_text(label_map)
118 | # print(f'prompt_text:{prompt_text}')
119 | config_path = "Path/cfgs/ULIP.yaml"
120 | args = munch.munchify(yaml.safe_load(open(config_path)))
121 | ULIP = ULIP_PointBERT(args).to('cuda:7')
122 | ckpt = torch.load('Path/checkpoints/pretrained_models_ckpt_zero-sho_classification_checkpoint_pointbert.pt')
123 | state_dict = OrderedDict()
124 | for k, v in ckpt["state_dict"].items():
125 | # print(f'k {k}')
126 | state_dict[k.replace("module.", "")] = v
127 | ULIP.load_state_dict(state_dict, strict=False)
128 | tokenized_captions = []
129 | with tqdm(dataloader_train) as t:
130 | for batch_idx,data in enumerate(t):
131 | # if batch_idx == 0:
132 | # category
133 | tokenized_captions = []
134 | category = data['category']
135 | text_labels = list(category)
136 | # text_labels = get_batch_label(text_labels, prompt_text, label_map)
137 | prompt_text = get_prompt_text(text_labels,label_map)
138 | path = data['image_path']
139 | # captions
140 | caption = data["caption"]
141 | print(f'caption : {caption}')
142 | caption = list(zip(*caption))
143 | for i in range(len(caption)):
144 | caption[i] = list(caption[i])
145 | for i in range(len(caption)):
146 | texts = tokenizer(caption[i]).cuda(7, non_blocking=True)
147 | tokenized_captions.append(texts)
148 | tokenized_captions = torch.stack(tokenized_captions)
149 | text_features = []
150 | for i in range(tokenized_captions.shape[0]):
151 | text_for_one_sample = tokenized_captions[i]
152 | text_embed = ULIP.encode_text(text_for_one_sample)
153 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
154 | text_embed = text_embed.mean(dim=0)
155 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
156 | text_features.append(text_embed)
157 | text_features = torch.stack(text_features)
158 |
159 | # print(f'tokenized {tokenized_captions.shape}')
160 | # print(f"catefory {category}")
161 | # print(f"text_labels {text_labels}")
162 | # print(f'prompt_text:{prompt_text}')
163 | print(f'text_embed {text_features.shape}')
164 | # print(f'path {path}')
165 |
--------------------------------------------------------------------------------
/src/loss/cdloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from chamfer_distance.chamfer_distance import ChamferDistance
4 | def fscore(dist1, dist2, threshold=0.01):
5 | """
6 | Calculates the F-score between two point clouds with the corresponding threshold value.
7 | :param dist1: Batch, N-Points
8 | :param dist2: Batch, N-Points
9 |     :param threshold: float
10 |     :return: fscore, precision, recall
11 |     """
12 |     # NB : In this repo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt the threshold accordingly.
13 | precision_1 = torch.mean((dist1 < threshold).float(), dim=1)
14 | precision_2 = torch.mean((dist2 < threshold).float(), dim=1)
15 | fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
16 | fscore[torch.isnan(fscore)] = 0
17 | return fscore, precision_1, precision_2
18 | class SimplificationLoss(nn.Module):
19 | def __init__(self):
20 | super(SimplificationLoss, self).__init__()
21 |
22 | def forward(self, ref_pc, samp_pc,calc_f1=False):
23 |
24 | cost_p1_p2, cost_p2_p1 = ChamferDistance()(samp_pc, ref_pc)
25 | cd_p = (torch.sqrt(cost_p1_p2).mean(1) + torch.sqrt(cost_p2_p1).mean(1)) / 2 # l1
26 | cd_t = (cost_p1_p2.mean(1) + cost_p2_p1.mean(1)) #l2
27 | if calc_f1:
28 | f1,_,_ =fscore(cost_p1_p2,cost_p2_p1)
29 | return cd_p,cd_t,f1
30 |
31 | else:
32 | return cd_p,cd_t
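A minimal usage sketch (not part of the original file), assuming the chamfer_3D CUDA extension builds, a GPU is available, and `src/` is on the Python path:
```python
import torch
from loss.cdloss import SimplificationLoss

loss_fn = SimplificationLoss()
pred = torch.rand(2, 2048, 3).cuda()  # reconstructed point cloud (B, N, 3)
gt = torch.rand(2, 2048, 3).cuda()    # ground-truth point cloud (B, N, 3)

# cd_p is the L1 (sqrt) Chamfer distance, cd_t the squared (L2) variant,
# and f1 the F-score computed from the squared distances.
cd_p, cd_t, f1 = loss_fn(gt, pred, calc_f1=True)
print(cd_p.mean().item(), cd_t.mean().item(), f1.mean().item())
```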
--------------------------------------------------------------------------------
/src/loss/losses.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2023, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Le Xue
7 | '''
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from ULIP_utils import all_gather_batch,get_rank
13 |
14 | loss = torch.nn.KLDivLoss(reduction='batchmean')
15 |
16 | def KL_loss(student_features,teacher_features):
17 | feature_loss = loss(F.log_softmax(student_features, dim=-1), F.softmax(teacher_features, dim=-1))
18 | return feature_loss
19 |
20 | labels = torch.tensor([1, 0, 1, 0], dtype=torch.float32)
21 | def contrastive_loss(student_features, teacher_features, margin=1.0):
22 |     distances = (student_features - teacher_features).pow(2).sum(1)  # squared Euclidean distance
23 | loss = (1 - labels) * distances + labels * F.relu(margin - distances.sqrt()).pow(2)
24 | return loss.mean()
25 |
26 | class ULIPWithImageLoss(nn.Module):
27 | def __init__(self):
28 | super().__init__()
29 | self.labels = None
30 | self.last_local_batch_size = None
31 |
32 | def forward(self, outputs):
33 | pc_embed = outputs['pc_embed']
34 | text_embed = outputs['text_embed']
35 | image_embed = outputs['image_embed']
36 | logit_scale = outputs['logit_scale']
37 | local_batch_size = pc_embed.size(0)
38 |
39 | if local_batch_size != self.last_local_batch_size:
40 | self.labels = local_batch_size * get_rank() + torch.arange(
41 | local_batch_size, device=pc_embed.device
42 | )
43 | self.last_local_batch_size = local_batch_size
44 |
45 | # normalized features
46 | # pc_embed = F.normalize(pc_embed, dim=-1, p=2)
47 | # text_embed = F.normalize(text_embed, dim=-1, p=2)
48 | # image_embed = F.normalize(image_embed, dim=-1, p=2)
49 | # # gather features from all GPUs
50 | # pc_embed_all, text_embed_all, image_embed_all = \
51 | # all_gather_batch([pc_embed, text_embed, image_embed])
52 | # cosine similarity as logits
53 | # logits_per_pc_text = logit_scale * pc_embed @ text_embed_all.t()
54 | # logits_per_text_pc = logit_scale * text_embed @ pc_embed_all.t()
55 | # logits_per_text_image = logit_scale * text_embed @ image_embed_all.t()
56 | # logits_per_image_text = logit_scale * image_embed @ text_embed_all.t()
57 |
58 | # logits_per_pc_text = F.cosine_similarity(text_embed,pc_embed)
59 | logits_per_text_image = F.cosine_similarity(text_embed,image_embed)
60 | # ulip_pc_text = logit_scale - logits_per_pc_text.mean()
61 | ulip_text_image = logit_scale - logits_per_text_image.mean()
62 |
63 | # loss = 0.9*ulip_pc_text + 0.1*ulip_text_image
64 | loss = ulip_text_image
65 |
66 | return {'loss': loss, 'ulip_loss': loss, 'ulip_text_pc_sim':0, 'ulip_text_image_sim': ulip_text_image}
67 |
--------------------------------------------------------------------------------
/src/loss/losses_v2.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2023, salesforce.com, inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: BSD-3-Clause
5 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | * By Le Xue
7 | '''
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from models.ULIP_utils import all_gather_batch,get_rank
13 |
14 | loss = torch.nn.KLDivLoss(reduction='batchmean')
15 |
16 | def KL_loss(student_features,teacher_features):
17 | feature_loss = loss(F.log_softmax(student_features, dim=-1), F.softmax(teacher_features, dim=-1))
18 | return feature_loss
19 | def contrastive_loss(e_s, e_t):
20 | """
21 | Compute the contrastive loss between shape (e_s) and image (e_t) embeddings.
22 |
23 | Args:
24 | e_s: Tensor of shape (N, C)
25 | e_t: Tensor of shape (N, C)
26 |
27 | Returns:
28 | loss: Contrastive loss value
29 | """
30 | N, C = e_s.size()
31 |
32 | # Normalize embeddings
33 | e_s = F.normalize(e_s, p=2, dim=1)
34 | e_t = F.normalize(e_t, p=2, dim=1)
35 |
36 | # Compute the similarity matrix
37 | sim_matrix = torch.matmul(e_s, e_t.T) # Shape: (N, N)
38 |
39 | # Compute the log probabilities
40 | log_prob_s = F.log_softmax(sim_matrix, dim=1) # Shape: (N, N)
41 | log_prob_t = F.log_softmax(sim_matrix.T, dim=1) # Shape: (N, N)
42 |
43 | # Contrastive loss
44 | loss_s = -torch.mean(torch.diag(log_prob_s))
45 | loss_t = -torch.mean(torch.diag(log_prob_t))
46 |
47 | # Total loss
48 | loss = (loss_s + loss_t) / 2
49 |
50 | return loss
51 |
52 | class ULIPWithImageLoss(nn.Module):
53 | def __init__(self):
54 | super().__init__()
55 | self.labels = None
56 | self.last_local_batch_size = None
57 |
58 | def forward(self, outputs):
59 | pc_embed = outputs['pc_embed']
60 | text_embed = outputs['text_embed']
61 | image_embed = outputs['image_embed']
62 | logit_scale = outputs['logit_scale']
63 | local_batch_size = pc_embed.size(0)
64 |
65 | if local_batch_size != self.last_local_batch_size:
66 | self.labels = local_batch_size * get_rank() + torch.arange(
67 | local_batch_size, device=pc_embed.device
68 | )
69 | self.last_local_batch_size = local_batch_size
70 |
71 | #normalized features
72 | pc_embed = F.normalize(pc_embed, dim=-1, p=2)
73 | text_embed = F.normalize(text_embed, dim=-1, p=2)
74 | image_embed = F.normalize(image_embed, dim=-1, p=2)
75 | # gather features from all GPUs
76 | pc_embed_all, text_embed_all, image_embed_all = \
77 | all_gather_batch([pc_embed, text_embed, image_embed])
78 |
79 | #cosine similarity as logits #[batch_size,batch_size]
80 |
81 | logits_per_pc_text = logit_scale * pc_embed @ text_embed_all.t()
82 | # logits_per_image_text = logit_scale * image_embed @ text_embed_all.t()
83 |
84 |
85 | ulip_pc_text = logit_scale - torch.diag(logits_per_pc_text).mean()
86 | # ulip_pc_image = logit_scale - torch.diag(logits_per_pc_image).mean()
87 | # ulip_text_image = logit_scale - torch.diag(logits_per_image_text).mean()
88 | # kl = KL_loss(text_embed,pc_embed).mean()
89 | con_loss = contrastive_loss(pc_embed,text_embed_all)
90 | # ctr = contrastive_loss(text_embed,pc_embed)
91 | # loss = ulip_pc_text + 0.2*ulip_text_image
92 | # loss = ulip_pc_text + 0.05 *ulip_text_image
93 | # loss = ulip_pc_text + kl
94 | loss = con_loss
95 | return {'loss': loss, 'ulip_loss': loss, 'ulip_text_pc_sim': ulip_pc_text, 'ulip_text_image_sim': 0}
96 |
--------------------------------------------------------------------------------
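A minimal usage sketch of the symmetric `contrastive_loss` defined above, assuming `src/` is on `PYTHONPATH` so `loss.losses_v2` is importable:

```python
import torch
from loss.losses_v2 import contrastive_loss

# Two batches of 8 paired embeddings with feature dim 512
# (e.g. point-cloud and text features before normalization).
e_s = torch.randn(8, 512)
e_t = torch.randn(8, 512)

print(contrastive_loss(e_s, e_t).item())  # random pairs: roughly log(8) ~ 2.08

# Identical pairs give a noticeably lower loss (no logit scale is applied here).
print(contrastive_loss(e_s, e_s).item())
```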
/src/models/ULIP_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import numpy as np
4 | import shutil
5 | import os
6 | class AverageMeter(object):
7 | """Computes and stores the average and current value"""
8 | def __init__(self, name, fmt=':f'):
9 | self.name = name
10 | self.fmt = fmt
11 | self.reset()
12 |
13 | def reset(self):
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
24 |
25 | def synchronize(self):
26 | if not is_dist_avail_and_initialized():
27 | return
28 | t = torch.tensor([self.sum, self.count], dtype=torch.float64, device='cuda')
29 | dist.barrier()
30 | dist.all_reduce(t)
31 | t = t.tolist()
32 | self.sum = int(t[0])
33 | self.count = t[1]
34 | self.avg = self.sum / self.count
35 |
36 | def __str__(self):
37 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
38 | return fmtstr.format(**self.__dict__)
39 |
40 | def get_model(model):
41 | if isinstance(model, torch.nn.DataParallel) \
42 | or isinstance(model, torch.nn.parallel.DistributedDataParallel):
43 | return model.module
44 | else:
45 | return model
46 | def is_dist_avail_and_initialized():
47 | if not dist.is_available():
48 | return False
49 | if not dist.is_initialized():
50 | return False
51 | return True
52 |
53 |
54 | def get_world_size():
55 | if not is_dist_avail_and_initialized():
56 | return 1
57 | return dist.get_world_size()
58 |
59 |
60 | def get_rank():
61 | if not is_dist_avail_and_initialized():
62 | return 0
63 | return dist.get_rank()
64 |
65 |
66 | def is_main_process():
67 | return get_rank() == 0
68 |
69 | class ProgressMeter(object):
70 | def __init__(self, num_batches, meters, prefix=""):
71 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
72 | self.meters = meters
73 | self.prefix = prefix
74 |
75 | def display(self, batch):
76 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
77 | entries += [str(meter) for meter in self.meters]
78 | print('\t'.join(entries))
79 |
80 | def synchronize(self):
81 | for meter in self.meters:
82 | meter.synchronize()
83 |
84 | def _get_batch_fmtstr(self, num_batches):
85 |         num_digits = len(str(num_batches))
86 | fmt = '{:' + str(num_digits) + 'd}'
87 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
88 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
89 | warmup_schedule = np.array([])
90 | warmup_iters = warmup_epochs * niter_per_ep
91 | if warmup_epochs > 0:
92 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
93 |
94 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
95 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
96 |
97 | schedule = np.concatenate((warmup_schedule, schedule))
98 | assert len(schedule) == epochs * niter_per_ep
99 | return schedule
100 |
101 | def scaled_all_reduce(tensors, is_scale=True):
102 | """Performs the scaled all_reduce operation on the provided tensors.
103 | The input tensors are modified in-place. Currently supports only the sum
104 |     reduction operator. The reduced values are scaled by the inverse of the
105 |     world size.
106 | """
107 | world_size = get_world_size()
108 | # There is no need for reduction in the single-proc case
109 | if world_size == 1:
110 | return tensors
111 | # Queue the reductions
112 | reductions = []
113 | for tensor in tensors:
114 | reduction = dist.all_reduce(tensor, async_op=True)
115 | reductions.append(reduction)
116 | # Wait for reductions to finish
117 | for reduction in reductions:
118 | reduction.wait()
119 | # Scale the results
120 | if is_scale:
121 | for tensor in tensors:
122 | tensor.mul_(1.0 / world_size)
123 | return tensors
124 | def save_on_master(state, is_best, output_dir):
125 | if is_main_process():
126 | ckpt_path = '{}/checkpoint_{}.pt'.format(output_dir, state['epoch'])
127 | best_path = f'{output_dir}/checkpoint_best.pt'
128 | torch.save(state, ckpt_path)
129 | if is_best:
130 | shutil.copyfile(ckpt_path, best_path)
131 | if os.path.exists(ckpt_path):
132 | os.remove(ckpt_path)
133 |
134 | def all_gather_batch(tensors):
135 | """
136 | Performs all_gather operation on the provided tensors.
137 | """
138 |     # Gather each tensor from every process
139 |     world_size = get_world_size()
140 |     # There is no need to gather in the single-proc case
141 | if world_size == 1:
142 | return tensors
143 | tensor_list = []
144 | output_tensor = []
145 | for tensor in tensors:
146 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
147 | dist.all_gather(
148 | tensor_all,
149 | tensor,
150 | async_op=False # performance opt
151 | )
152 |
153 | tensor_list.append(tensor_all)
154 |
155 | for tensor_all in tensor_list:
156 | output_tensor.append(torch.cat(tensor_all, dim=0))
157 | return output_tensor
--------------------------------------------------------------------------------
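Two of the helpers above are used throughout the training scripts; the following sketch (assuming `src/` is importable) shows the per-iteration cosine learning-rate schedule and the running-average meter:

```python
from models.ULIP_utils import AverageMeter, cosine_scheduler

# 2 warmup epochs followed by cosine decay over 10 epochs, 100 iterations each.
lr_schedule = cosine_scheduler(base_value=1e-3, final_value=1e-5,
                               epochs=10, niter_per_ep=100, warmup_epochs=2)
assert len(lr_schedule) == 10 * 100  # one LR value per iteration

loss_meter = AverageMeter('loss', ':.4f')
for it in range(5):
    loss_meter.update(val=0.1 * it, n=32)  # batch size 32
print(loss_meter)  # "loss 0.4000 (0.2000)" -> current value and running average
```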
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_model_from_cfg
2 | import models.Point_MAE
--------------------------------------------------------------------------------
/src/models/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/src/models/build.py:
--------------------------------------------------------------------------------
1 | from utils import registry
2 |
3 |
4 | MODELS = registry.Registry('models')
5 |
6 |
7 | def build_model_from_cfg(cfg, **kwargs):
8 | """
9 |     Build a model from the config through the MODELS registry.
10 |     Args:
11 |         cfg (eDICT): model config; the registered model name selects the class.
12 |     Returns:
13 |         nn.Module: a constructed model specified by the config.
14 | """
15 | return MODELS.build(cfg, **kwargs)
16 |
17 |
18 |
--------------------------------------------------------------------------------
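An illustrative, deliberately hedged sketch of how the `MODELS` registry is typically used with a Point-BERT-style `utils/registry.py`; `TinyHead` is a hypothetical module used only for the example, and the exact decorator/build signatures should be confirmed against `utils/registry.py`:

```python
# Hypothetical example: TinyHead is not part of the repository; it only
# illustrates the assumed register/build pattern of a Point-BERT-style registry.
import torch.nn as nn
from models.build import MODELS, build_model_from_cfg

@MODELS.register_module()  # assumed decorator name, see utils/registry.py
class TinyHead(nn.Module):
    def __init__(self, config, **kwargs):
        super().__init__()
        self.fc = nn.Linear(config.in_dim, config.out_dim)

    def forward(self, x):
        return self.fc(x)

# cfg.NAME is assumed to select the registered class, e.g.:
# from easydict import EasyDict
# model = build_model_from_cfg(EasyDict(NAME='TinyHead', in_dim=512, out_dim=40))
```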
/src/models/clip_utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import math
3 | from typing import Tuple, Union
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 | from timm.models.layers import drop_path
10 |
11 |
12 | class Bottleneck(nn.Module):
13 | expansion = 4
14 |
15 | def __init__(self, inplanes, planes, stride=1):
16 | super().__init__()
17 |
18 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
19 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 |
22 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
23 | self.bn2 = nn.BatchNorm2d(planes)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 |
30 | self.relu = nn.ReLU(inplace=True)
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.relu(self.bn1(self.conv1(x)))
46 | out = self.relu(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.relu(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 | self.embed_dim = embed_dim
68 | self.spacial_dim = spacial_dim
69 |
70 | def forward(self, x):
71 | B, C, H, W = x.shape
72 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
73 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
74 |
75 | cls_pos = self.positional_embedding[0:1, :]
76 | spatial_pos = F.interpolate(self.positional_embedding[1:,].reshape(1, self.spacial_dim, self.spacial_dim, self.embed_dim).permute(0, 3, 1, 2), size=(H, W), mode='bilinear')
77 | spatial_pos = spatial_pos.reshape(self.embed_dim, H*W).permute(1, 0)
78 | positional_embedding = torch.cat([cls_pos, spatial_pos], dim=0)
79 |
80 | x = x + positional_embedding[:, None, :]
81 | x, _ = F.multi_head_attention_forward(
82 | query=x, key=x, value=x,
83 | embed_dim_to_check=x.shape[-1],
84 | num_heads=self.num_heads,
85 | q_proj_weight=self.q_proj.weight,
86 | k_proj_weight=self.k_proj.weight,
87 | v_proj_weight=self.v_proj.weight,
88 | in_proj_weight=None,
89 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
90 | bias_k=None,
91 | bias_v=None,
92 | add_zero_attn=False,
93 | dropout_p=0,
94 | out_proj_weight=self.c_proj.weight,
95 | out_proj_bias=self.c_proj.bias,
96 | use_separate_proj_weight=True,
97 | training=self.training,
98 | need_weights=False
99 | )
100 |
101 | x = x.permute(1, 2, 0)
102 | global_feat = x[:, :, 0]
103 | feature_map = x[:, :, 1:].reshape(B, -1, H, W)
104 | return global_feat, feature_map
105 |
106 |
107 | class LayerNorm(nn.LayerNorm):
108 | """Subclass torch's LayerNorm to handle fp16."""
109 |
110 | def forward(self, x: torch.Tensor):
111 | orig_type = x.dtype
112 | ret = super().forward(x.type(torch.float32))
113 | return ret.type(orig_type)
114 |
115 |
116 | class QuickGELU(nn.Module):
117 | def forward(self, x: torch.Tensor):
118 | return x * torch.sigmoid(1.702 * x)
119 |
120 |
121 | class DropPath(nn.Module):
122 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
123 | """
124 | def __init__(self, drop_prob=None):
125 | super(DropPath, self).__init__()
126 | self.drop_prob = drop_prob
127 |
128 | def forward(self, x):
129 | return drop_path(x, self.drop_prob, self.training)
130 |
131 | def extra_repr(self) -> str:
132 | return 'p={}'.format(self.drop_prob)
133 |
134 |
135 | class ResidualAttentionBlock(nn.Module):
136 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, drop_path=0.):
137 | super().__init__()
138 |
139 | self.attn = nn.MultiheadAttention(d_model, n_head)
140 | self.ln_1 = LayerNorm(d_model)
141 | self.mlp = nn.Sequential(OrderedDict([
142 | ("c_fc", nn.Linear(d_model, d_model * 4)),
143 | ("gelu", QuickGELU()),
144 | ("c_proj", nn.Linear(d_model * 4, d_model))
145 | ]))
146 | self.ln_2 = LayerNorm(d_model)
147 | self.attn_mask = attn_mask
148 |
149 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
150 |
151 | def attention(self, x: torch.Tensor):
152 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
153 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
154 |
155 | def attention_weight(self, x: torch.Tensor):
156 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
157 | return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)[1]
158 |
159 | def forward(self, x: torch.Tensor, return_attention: bool=False):
160 | x = x + self.attention(self.ln_1(x))
161 | x = x + self.mlp(self.ln_2(x))
162 | return x
163 |
164 | class Transformer(nn.Module):
165 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, drop_path_rate=0.):
166 | super().__init__()
167 |
168 | self.width = width
169 | self.layers = layers
170 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, layers)] # stochastic depth decay rule
171 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) for i in range(layers)])
172 |
173 | def forward(self, x: torch.Tensor):
174 | return self.resblocks(x)
175 |
176 | # ADDED
177 | def forward_attention(self, x: torch.Tensor):
178 | for index, layer in enumerate(self.resblocks):
179 | if index == len(self.resblocks) - 1:
180 | return layer(x, return_attention=True)
181 | x = layer(x)
182 |
183 |
184 | class Attention(nn.Module):
185 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
186 | super().__init__()
187 | self.num_heads = num_heads
188 | head_dim = dim // num_heads
189 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
190 | self.scale = qk_scale or head_dim ** -0.5
191 |
192 | self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
193 | self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
194 | self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
195 |
196 |
197 | self.attn_drop = nn.Dropout(attn_drop)
198 | self.proj = nn.Linear(dim, dim)
199 | self.proj_drop = nn.Dropout(proj_drop)
200 |
201 | def forward(self, q, k, v):
202 | B, N, C = q.shape
203 | assert k.shape == v.shape
204 | B, M, C = k.shape
205 | q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads)
206 | k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads)
207 | v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads)
208 |
209 | attn = torch.einsum('bnkc,bmkc->bknm', q, k) * self.scale
210 |
211 | attn = attn.softmax(dim=-1)
212 |
213 | x = torch.einsum('bknm,bmkc->bnkc', attn, v).reshape(B, N, C)
214 |
215 | x = self.proj(x)
216 | x = self.proj_drop(x)
217 | return x
218 |
--------------------------------------------------------------------------------
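A shape-level sketch of the cross-attention `Attention` block at the end of this file (assuming `src/` is importable and `timm` is installed); it lets N query tokens attend over M key/value tokens:

```python
import torch
from models.clip_utils import Attention

attn = Attention(dim=384, num_heads=6)
q = torch.randn(2, 128, 384)   # B=2, N=128 query tokens (e.g. point features)
kv = torch.randn(2, 50, 384)   # M=50 key/value tokens (e.g. image tokens)

out = attn(q, kv, kv)
print(out.shape)  # torch.Size([2, 128, 384]): one updated feature per query token
```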
/src/models/encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 |
5 |
6 | # From : https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
7 |
8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
9 | 'resnet152']
10 |
11 | model_urls = {
12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 | }
18 |
19 |
20 | def conv3x3(in_planes, out_planes, stride=1):
21 | "3x3 convolution with padding"
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
23 | padding=1, bias=False)
24 |
25 |
26 | class BasicBlock(nn.Module):
27 | expansion = 1
28 |
29 | def __init__(self, inplanes, planes, stride=1, downsample=None):
30 | super(BasicBlock, self).__init__()
31 | self.conv1 = conv3x3(inplanes, planes, stride)
32 | self.bn1 = nn.BatchNorm2d(planes)
33 | self.relu = nn.ReLU(inplace=True)
34 | self.conv2 = conv3x3(planes, planes)
35 | self.bn2 = nn.BatchNorm2d(planes)
36 | self.downsample = downsample
37 | self.stride = stride
38 |
39 | def forward(self, x):
40 | residual = x
41 |
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 |
49 | if self.downsample is not None:
50 | residual = self.downsample(x)
51 |
52 | out += residual
53 | out = self.relu(out)
54 |
55 | return out
56 |
57 |
58 | class Bottleneck(nn.Module):
59 | expansion = 4
60 |
61 | def __init__(self, inplanes, planes, stride=1, downsample=None):
62 | super(Bottleneck, self).__init__()
63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
64 | self.bn1 = nn.BatchNorm2d(planes)
65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
66 | padding=1, bias=False)
67 | self.bn2 = nn.BatchNorm2d(planes)
68 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
69 | self.bn3 = nn.BatchNorm2d(planes * 4)
70 | self.relu = nn.ReLU(inplace=True)
71 | self.downsample = downsample
72 | self.stride = stride
73 |
74 | def forward(self, x):
75 | residual = x
76 |
77 | out = self.conv1(x)
78 | out = self.bn1(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv2(out)
82 | out = self.bn2(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv3(out)
86 | out = self.bn3(out)
87 |
88 | if self.downsample is not None:
89 | residual = self.downsample(x)
90 |
91 | out += residual
92 | out = self.relu(out)
93 |
94 | return out
95 |
96 |
97 | class ResNet(nn.Module):
98 |
99 | def __init__(self, block, layers, num_classes=1000):
100 | self.inplanes = 64
101 | super(ResNet, self).__init__()
102 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
103 | bias=False)
104 | self.bn1 = nn.BatchNorm2d(64)
105 | self.relu = nn.ReLU(inplace=True)
106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
107 | self.layer1 = self._make_layer(block, 64, layers[0])
108 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
109 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
110 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
111 | self.avgpool = nn.AvgPool2d(7)
112 | self.fc = nn.Linear(512 * block.expansion, num_classes)
113 |
114 | for m in self.modules():
115 | if isinstance(m, nn.Conv2d):
116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
117 | m.weight.data.normal_(0, math.sqrt(2. / n))
118 | elif isinstance(m, nn.BatchNorm2d):
119 | m.weight.data.fill_(1)
120 | m.bias.data.zero_()
121 |
122 | def _make_layer(self, block, planes, blocks, stride=1):
123 | downsample = None
124 | if stride != 1 or self.inplanes != planes * block.expansion:
125 | downsample = nn.Sequential(
126 | nn.Conv2d(self.inplanes, planes * block.expansion,
127 | kernel_size=1, stride=stride, bias=False),
128 | nn.BatchNorm2d(planes * block.expansion),
129 | )
130 |
131 | layers = []
132 | layers.append(block(self.inplanes, planes, stride, downsample))
133 | self.inplanes = planes * block.expansion
134 | for i in range(1, blocks):
135 | layers.append(block(self.inplanes, planes))
136 |
137 | return nn.Sequential(*layers)
138 |
139 | def forward(self, x):
140 | x = self.conv1(x)
141 | x = self.bn1(x)
142 | x = self.relu(x)
143 | x = self.maxpool(x)
144 |
145 | x = self.layer1(x)
146 | x = self.layer2(x)
147 | x = self.layer3(x)
148 | x = self.layer4(x)
149 |
150 | x = self.avgpool(x)
151 | x = x.view(x.size(0), -1)
152 | x = self.fc(x)
153 |
154 | return x
155 |
156 | class ResNetReturnMid(nn.Module):
157 |
158 | def __init__(self, block, layers, num_classes=1000):
159 | self.inplanes = 64
160 | super(ResNetReturnMid, self).__init__()
161 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
162 | bias=False)
163 | self.bn1 = nn.BatchNorm2d(64)
164 | self.relu = nn.ReLU(inplace=True)
165 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
166 | self.layer1 = self._make_layer(block, 64, layers[0])
167 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
168 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
169 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
170 | self.avgpool = nn.AvgPool2d(7)
171 | self.fc = nn.Linear(512 * block.expansion, num_classes)
172 |
173 | for m in self.modules():
174 | if isinstance(m, nn.Conv2d):
175 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
176 | m.weight.data.normal_(0, math.sqrt(2. / n))
177 | elif isinstance(m, nn.BatchNorm2d):
178 | m.weight.data.fill_(1)
179 | m.bias.data.zero_()
180 |
181 | def _make_layer(self, block, planes, blocks, stride=1):
182 | downsample = None
183 | if stride != 1 or self.inplanes != planes * block.expansion:
184 | downsample = nn.Sequential(
185 | nn.Conv2d(self.inplanes, planes * block.expansion,
186 | kernel_size=1, stride=stride, bias=False),
187 | nn.BatchNorm2d(planes * block.expansion),
188 | )
189 |
190 | layers = []
191 | layers.append(block(self.inplanes, planes, stride, downsample))
192 | self.inplanes = planes * block.expansion
193 | for i in range(1, blocks):
194 | layers.append(block(self.inplanes, planes))
195 |
196 | return nn.Sequential(*layers)
197 |
198 | def forward(self, x):
199 | x = self.conv1(x)
200 | x = self.bn1(x)
201 | x = self.relu(x)
202 | x = self.maxpool(x)
203 |
204 | x = self.layer1(x)
205 | x = self.layer2(x)
206 | x1 = x
207 | x = self.layer3(x)
208 | x = self.layer4(x)
209 |
210 | x = self.avgpool(x)
211 | x = x.view(x.size(0), -1)
212 | x = self.fc(x)
213 |
214 | return x, x1
215 |
216 | def resnet18(pretrained=False, **kwargs):
217 | """Constructs a ResNet-18 model.
218 | Args:
219 | pretrained (bool): If True, returns a model pre-trained on ImageNet
220 | """
221 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
222 | if pretrained:
223 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
224 | return model
225 |
226 | def resnet18ReturnMid(pretrained=False, **kwargs):
227 | """Constructs a ResNet-18 model.
228 | Args:
229 | pretrained (bool): If True, returns a model pre-trained on ImageNet
230 | """
231 | model = ResNetReturnMid(BasicBlock, [2, 2, 2, 2], **kwargs)
232 | if pretrained:
233 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
234 | return model
235 |
236 | def resnet34(pretrained=False, **kwargs):
237 | """Constructs a ResNet-34 model.
238 | Args:
239 | pretrained (bool): If True, returns a model pre-trained on ImageNet
240 | """
241 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
242 | if pretrained:
243 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
244 | return model
245 |
246 |
247 | def resnet50(pretrained=False, **kwargs):
248 | """Constructs a ResNet-50 model.
249 | Args:
250 | pretrained (bool): If True, returns a model pre-trained on ImageNet
251 | """
252 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
253 | if pretrained:
254 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
255 | return model
256 |
257 |
258 | def resnet101(pretrained=False, **kwargs):
259 | """Constructs a ResNet-101 model.
260 | Args:
261 | pretrained (bool): If True, returns a model pre-trained on ImageNet
262 | """
263 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
264 | if pretrained:
265 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
266 | return model
267 |
268 |
269 | def resnet152(pretrained=False, **kwargs):
270 | """Constructs a ResNet-152 model.
271 | Args:
272 | pretrained (bool): If True, returns a model pre-trained on ImageNet
273 | """
274 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
275 | if pretrained:
276 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
277 | return model
278 |
279 |
280 |
--------------------------------------------------------------------------------
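A small sketch of `resnet18ReturnMid` above, which returns the classifier output together with the intermediate `layer2` feature map (assuming `src/` is importable):

```python
import torch
from models.encoder import resnet18ReturnMid

net = resnet18ReturnMid(pretrained=False, num_classes=1000)
img = torch.randn(2, 3, 224, 224)

feat, mid = net(img)
print(feat.shape)  # torch.Size([2, 1000])         final fc output
print(mid.shape)   # torch.Size([2, 128, 28, 28])  layer2 feature map
```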
/src/models/pointbert/PointTransformer_8192point.yaml:
--------------------------------------------------------------------------------
1 | optimizer : {
2 | type: AdamW,
3 | kwargs: {
4 | lr : 0.0005,
5 | weight_decay : 0.05
6 | }}
7 |
8 | scheduler: {
9 | type: CosLR,
10 | kwargs: {
11 | epochs: 300,
12 | initial_epochs : 10
13 | }}
14 |
15 | model : {
16 | NAME: PointTransformer,
17 | trans_dim: 384,
18 | depth: 12,
19 | drop_path_rate: 0.1,
20 | cls_dim: 40,
21 | num_heads: 6,
22 | group_size: 32,
23 | num_group: 64,
24 | encoder_dims: 256,
25 | }
26 | npoints: 2048
27 | total_bs : 32
28 | step_per_update : 1
29 | max_epoch : 300
30 | grad_norm_clip : 10
31 |
32 | consider_metric: CDL1
--------------------------------------------------------------------------------
/src/models/pointbert/ULIP_2_PointBERT_10k_colored_pointclouds.yaml:
--------------------------------------------------------------------------------
1 | optimizer : {
2 | type: AdamW,
3 | kwargs: {
4 | lr : 0.0005,
5 | weight_decay : 0.05
6 | }}
7 |
8 | scheduler: {
9 | type: CosLR,
10 | kwargs: {
11 | epochs: 200,
12 | initial_epochs : 10
13 | }}
14 |
15 | model : {
16 | NAME: PointTransformer,
17 | trans_dim: 384,
18 | depth: 18,
19 | drop_path_rate: 0.1,
20 | cls_dim: 40,
21 | num_heads: 6,
22 | group_size: 32,
23 | num_group: 512,
24 | encoder_dims: 256,
25 | }
26 | npoints: 10000
27 | total_bs : 32
28 | step_per_update : 1
29 | max_epoch : 300
30 | grad_norm_clip : 10
31 |
32 | consider_metric: CDL1
--------------------------------------------------------------------------------
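The two Point-BERT configs above differ mainly in `depth` (12 vs. 18), `num_group` (64 vs. 512), `npoints` (2048 vs. 10000), and the scheduler epochs. A minimal sketch of reading one of them with plain PyYAML; the training code itself goes through `utils/config.py`, which is assumed to wrap the same keys in a dot-accessible config object:

```python
import yaml

with open('src/models/pointbert/PointTransformer_8192point.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['model']['NAME'])       # PointTransformer
print(cfg['model']['trans_dim'])  # 384
print(cfg['npoints'])             # 2048
```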
/src/models/pointbert/__pycache__/checkpoint.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/pointbert/__pycache__/checkpoint.cpython-39.pyc
--------------------------------------------------------------------------------
/src/models/pointbert/__pycache__/dvae.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/pointbert/__pycache__/dvae.cpython-39.pyc
--------------------------------------------------------------------------------
/src/models/pointbert/__pycache__/logger.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/pointbert/__pycache__/logger.cpython-39.pyc
--------------------------------------------------------------------------------
/src/models/pointbert/__pycache__/misc.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/pointbert/__pycache__/misc.cpython-39.pyc
--------------------------------------------------------------------------------
/src/models/pointbert/__pycache__/point_encoder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/models/pointbert/__pycache__/point_encoder.cpython-39.pyc
--------------------------------------------------------------------------------
/src/models/pointbert/checkpoint.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch.nn as nn
3 |
4 | from typing import Any
5 | from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable
6 |
7 | from termcolor import colored
8 |
9 | def get_missing_parameters_message(keys: List[str]) -> str:
10 | """
11 | Get a logging-friendly message to report parameter names (keys) that are in
12 | the model but not found in a checkpoint.
13 | Args:
14 | keys (list[str]): List of keys that were not found in the checkpoint.
15 | Returns:
16 | str: message.
17 | """
18 | groups = _group_checkpoint_keys(keys)
19 | msg = "Some model parameters or buffers are not found in the checkpoint:\n"
20 | msg += "\n".join(
21 | " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
22 | )
23 | return msg
24 |
25 |
26 | def get_unexpected_parameters_message(keys: List[str]) -> str:
27 | """
28 | Get a logging-friendly message to report parameter names (keys) that are in
29 | the checkpoint but not found in the model.
30 | Args:
31 | keys (list[str]): List of keys that were not found in the model.
32 | Returns:
33 | str: message.
34 | """
35 | groups = _group_checkpoint_keys(keys)
36 | msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
37 | msg += "\n".join(
38 | " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
39 | )
40 | return msg
41 |
42 |
43 | def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
44 | """
45 | Strip the prefix in metadata, if any.
46 | Args:
47 | state_dict (OrderedDict): a state-dict to be loaded to the model.
48 | prefix (str): prefix.
49 | """
50 | keys = sorted(state_dict.keys())
51 | if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
52 | return
53 |
54 | for key in keys:
55 | newkey = key[len(prefix):]
56 | state_dict[newkey] = state_dict.pop(key)
57 |
58 | # also strip the prefix in metadata, if any..
59 | try:
60 | metadata = state_dict._metadata # pyre-ignore
61 | except AttributeError:
62 | pass
63 | else:
64 | for key in list(metadata.keys()):
65 | # for the metadata dict, the key can be:
66 | # '': for the DDP module, which we want to remove.
67 | # 'module': for the actual model.
68 | # 'module.xx.xx': for the rest.
69 |
70 | if len(key) == 0:
71 | continue
72 | newkey = key[len(prefix):]
73 | metadata[newkey] = metadata.pop(key)
74 |
75 |
76 | def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
77 | """
78 | Group keys based on common prefixes. A prefix is the string up to the final
79 | "." in each key.
80 | Args:
81 | keys (list[str]): list of parameter names, i.e. keys in the model
82 | checkpoint dict.
83 | Returns:
84 | dict[list]: keys with common prefixes are grouped into lists.
85 | """
86 | groups = defaultdict(list)
87 | for key in keys:
88 | pos = key.rfind(".")
89 | if pos >= 0:
90 | head, tail = key[:pos], [key[pos + 1:]]
91 | else:
92 | head, tail = key, []
93 | groups[head].extend(tail)
94 | return groups
95 |
96 |
97 | def _group_to_str(group: List[str]) -> str:
98 | """
99 | Format a group of parameter name suffixes into a loggable string.
100 | Args:
101 | group (list[str]): list of parameter name suffixes.
102 | Returns:
103 | str: formated string.
104 | """
105 | if len(group) == 0:
106 | return ""
107 |
108 | if len(group) == 1:
109 | return "." + group[0]
110 |
111 | return ".{" + ", ".join(group) + "}"
112 |
113 |
114 | def _named_modules_with_dup(
115 | model: nn.Module, prefix: str = ""
116 | ) -> Iterable[Tuple[str, nn.Module]]:
117 | """
118 | The same as `model.named_modules()`, except that it includes
119 | duplicated modules that have more than one name.
120 | """
121 | yield prefix, model
122 | for name, module in model._modules.items(): # pyre-ignore
123 | if module is None:
124 | continue
125 | submodule_prefix = prefix + ("." if prefix else "") + name
126 | yield from _named_modules_with_dup(module, submodule_prefix)
--------------------------------------------------------------------------------
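A tiny sketch of the key-grouping helpers above (requires `termcolor`), showing how missing-key reports are condensed by common prefix:

```python
from models.pointbert.checkpoint import (
    _group_checkpoint_keys, get_missing_parameters_message)

keys = ['encoder.conv1.weight', 'encoder.conv1.bias', 'cls_token']
print(dict(_group_checkpoint_keys(keys)))
# {'encoder.conv1': ['weight', 'bias'], 'cls_token': []}

print(get_missing_parameters_message(keys))
# Some model parameters or buffers are not found in the checkpoint:
#   encoder.conv1.{weight, bias}
#   cls_token
```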
/src/models/pointbert/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.distributed as dist
3 |
4 | logger_initialized = {}
5 |
6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
7 | """Get root logger and add a keyword filter to it.
8 | The logger will be initialized if it has not been initialized. By default a
9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
10 | also be added. The name of the root logger is the top-level package name,
11 | e.g., "mmdet3d".
12 | Args:
13 | log_file (str, optional): File path of log. Defaults to None.
14 | log_level (int, optional): The level of logger.
15 | Defaults to logging.INFO.
16 | name (str, optional): The name of the root logger, also used as a
17 |             filter keyword. Defaults to 'main'.
18 | Returns:
19 | :obj:`logging.Logger`: The obtained logger
20 | """
21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level)
22 | # add a logging filter
23 | logging_filter = logging.Filter(name)
24 | logging_filter.filter = lambda record: record.find(name) != -1
25 |
26 | return logger
27 |
28 |
29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
30 | """Initialize and get a logger by name.
31 | If the logger has not been initialized, this method will initialize the
32 | logger by adding one or two handlers, otherwise the initialized logger will
33 | be directly returned. During initialization, a StreamHandler will always be
34 | added. If `log_file` is specified and the process rank is 0, a FileHandler
35 | will also be added.
36 | Args:
37 | name (str): Logger name.
38 | log_file (str | None): The log filename. If specified, a FileHandler
39 | will be added to the logger.
40 | log_level (int): The logger level. Note that only the process of
41 | rank 0 is affected, and other processes will set the level to
42 | "Error" thus be silent most of the time.
43 | file_mode (str): The file mode used in opening log file.
44 | Defaults to 'w'.
45 | Returns:
46 | logging.Logger: The expected logger.
47 | """
48 | logger = logging.getLogger(name)
49 | if name in logger_initialized:
50 | return logger
51 | # handle hierarchical names
52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the
53 | # initialization since it is a child of "a".
54 | for logger_name in logger_initialized:
55 | if name.startswith(logger_name):
56 | return logger
57 |
58 | # handle duplicate logs to the console
59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
60 | # to the root logger. As logger.propagate is True by default, this root
61 | # level handler causes logging messages from rank>0 processes to
62 | # unexpectedly show up on the console, creating much unwanted clutter.
63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log
64 | # at the ERROR level.
65 | for handler in logger.root.handlers:
66 | if type(handler) is logging.StreamHandler:
67 | handler.setLevel(logging.ERROR)
68 |
69 | stream_handler = logging.StreamHandler()
70 | handlers = [stream_handler]
71 |
72 | if dist.is_available() and dist.is_initialized():
73 | rank = dist.get_rank()
74 | else:
75 | rank = 0
76 |
77 | # only rank 0 will add a FileHandler
78 | if rank == 0 and log_file is not None:
79 | # Here, the default behaviour of the official logger is 'a'. Thus, we
80 | # provide an interface to change the file mode to the default
81 | # behaviour.
82 | file_handler = logging.FileHandler(log_file, file_mode)
83 | handlers.append(file_handler)
84 |
85 | formatter = logging.Formatter(
86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
87 | for handler in handlers:
88 | handler.setFormatter(formatter)
89 | handler.setLevel(log_level)
90 | logger.addHandler(handler)
91 |
92 | if rank == 0:
93 | logger.setLevel(log_level)
94 | else:
95 | logger.setLevel(logging.ERROR)
96 |
97 | logger_initialized[name] = True
98 |
99 |
100 | return logger
101 |
102 |
103 | def print_log(msg, logger=None, level=logging.INFO):
104 | """Print a log message.
105 | Args:
106 | msg (str): The message to be logged.
107 | logger (logging.Logger | str | None): The logger to be used.
108 | Some special loggers are:
109 | - "silent": no message will be printed.
110 | - other str: the logger obtained with `get_root_logger(logger)`.
111 | - None: The `print()` method will be used to print log messages.
112 | level (int): Logging level. Only available when `logger` is a Logger
113 | object or "root".
114 | """
115 | if logger is None:
116 | print(msg)
117 | elif isinstance(logger, logging.Logger):
118 | logger.log(level, msg)
119 | elif logger == 'silent':
120 | pass
121 | elif isinstance(logger, str):
122 | _logger = get_logger(logger)
123 | _logger.log(level, msg)
124 | else:
125 | raise TypeError(
126 | 'logger should be either a logging.Logger object, str, '
127 | f'"silent" or None, but got {type(logger)}')
--------------------------------------------------------------------------------
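A minimal usage sketch of the logging helpers above:

```python
import logging
from models.pointbert.logger import get_root_logger, print_log

logger = get_root_logger(log_file=None, log_level=logging.INFO, name='MESC-3D')
logger.info('training started')

# print_log routes through a named logger when given a string,
# falls back to print() when logger is None, and stays quiet for 'silent'.
print_log('epoch 1 done', logger='MESC-3D')
print_log('plain print fallback', logger=None)
```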
/src/models/pointbert/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from mpl_toolkits.mplot3d import Axes3D
4 | import random
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import os
9 | from collections import abc
10 | from pointnet2_ops import pointnet2_utils
11 |
12 |
13 | # def fps(data, number):
14 | # '''
15 | # data B N 3
16 | # number int
17 | # '''
18 | # fps_idx = pointnet2_utils.furthest_point_sample(data, number)
19 | # fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
20 | # return fps_data
21 |
22 | def index_points(points, idx):
23 | """
24 | Input:
25 | points: input points data, [B, N, C]
26 | idx: sample index data, [B, S]
27 | Return:
28 |         new_points: indexed points data, [B, S, C]
29 | """
30 | device = points.device
31 | B = points.shape[0]
32 | view_shape = list(idx.shape)
33 | view_shape[1:] = [1] * (len(view_shape) - 1)
34 | repeat_shape = list(idx.shape)
35 | repeat_shape[0] = 1
36 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
37 | new_points = points[batch_indices, idx, :]
38 | return new_points
39 |
40 | # def fps(xyz, npoint):
41 | # """
42 | # Input:
43 | # xyz: pointcloud data, [B, N, 3]
44 | # npoint: number of samples
45 | # Return:
46 | # centroids: sampled pointcloud index, [B, npoint]
47 | # """
48 | # device = xyz.device
49 | # B, N, C = xyz.shape
50 | # centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
51 | # distance = torch.ones(B, N).to(device) * 1e10
52 | # farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
53 | # batch_indices = torch.arange(B, dtype=torch.long).to(device)
54 | # for i in range(npoint):
55 | # centroids[:, i] = farthest
56 | # centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
57 | # dist = torch.sum((xyz - centroid) ** 2, -1)
58 | # distance = torch.min(distance, dist)
59 | # farthest = torch.max(distance, -1)[1]
60 | # return index_points(xyz, centroids)
61 |
62 | def fps(data, number):
63 | '''
64 | data B N 3
65 | number int
66 | '''
67 | fps_idx = pointnet2_utils.furthest_point_sample(data, number)
68 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
69 | return fps_data
70 |
71 | def worker_init_fn(worker_id):
72 | np.random.seed(np.random.get_state()[1][0] + worker_id)
73 |
74 | def build_lambda_sche(opti, config):
75 | if config.get('decay_step') is not None:
76 | lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)
77 | scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)
78 | else:
79 | raise NotImplementedError()
80 | return scheduler
81 |
82 | def build_lambda_bnsche(model, config):
83 | if config.get('decay_step') is not None:
84 | bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)
85 | bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)
86 | else:
87 | raise NotImplementedError()
88 | return bnm_scheduler
89 |
90 | def set_random_seed(seed, deterministic=False):
91 | """Set random seed.
92 | Args:
93 | seed (int): Seed to be used.
94 | deterministic (bool): Whether to set the deterministic option for
95 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
96 | to True and `torch.backends.cudnn.benchmark` to False.
97 | Default: False.
98 |
99 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
100 |     if deterministic: # slower, more reproducible
101 | cudnn.deterministic = True
102 | cudnn.benchmark = False
103 | else: # faster, less reproducible
104 | cudnn.deterministic = False
105 | cudnn.benchmark = True
106 |
107 | """
108 | random.seed(seed)
109 | np.random.seed(seed)
110 | torch.manual_seed(seed)
111 | torch.cuda.manual_seed_all(seed)
112 | if deterministic:
113 | torch.backends.cudnn.deterministic = True
114 | torch.backends.cudnn.benchmark = False
115 |
116 |
117 | def is_seq_of(seq, expected_type, seq_type=None):
118 | """Check whether it is a sequence of some type.
119 | Args:
120 | seq (Sequence): The sequence to be checked.
121 | expected_type (type): Expected type of sequence items.
122 | seq_type (type, optional): Expected sequence type.
123 | Returns:
124 | bool: Whether the sequence is valid.
125 | """
126 | if seq_type is None:
127 | exp_seq_type = abc.Sequence
128 | else:
129 | assert isinstance(seq_type, type)
130 | exp_seq_type = seq_type
131 | if not isinstance(seq, exp_seq_type):
132 | return False
133 | for item in seq:
134 | if not isinstance(item, expected_type):
135 | return False
136 | return True
137 |
138 |
139 | def set_bn_momentum_default(bn_momentum):
140 | def fn(m):
141 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
142 | m.momentum = bn_momentum
143 | return fn
144 |
145 | class BNMomentumScheduler(object):
146 |
147 | def __init__(
148 | self, model, bn_lambda, last_epoch=-1,
149 | setter=set_bn_momentum_default
150 | ):
151 | if not isinstance(model, nn.Module):
152 | raise RuntimeError(
153 | "Class '{}' is not a PyTorch nn Module".format(
154 | type(model).__name__
155 | )
156 | )
157 |
158 | self.model = model
159 | self.setter = setter
160 | self.lmbd = bn_lambda
161 |
162 | self.step(last_epoch + 1)
163 | self.last_epoch = last_epoch
164 |
165 | def step(self, epoch=None):
166 | if epoch is None:
167 | epoch = self.last_epoch + 1
168 |
169 | self.last_epoch = epoch
170 | self.model.apply(self.setter(self.lmbd(epoch)))
171 |
172 | def get_momentum(self, epoch=None):
173 | if epoch is None:
174 | epoch = self.last_epoch + 1
175 | return self.lmbd(epoch)
176 |
177 |
178 |
179 | def seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False):
180 | '''
181 |     separate point cloud: used to generate an incomplete point cloud with a given number of cropped points.
182 | '''
183 | _,n,c = xyz.shape
184 |
185 | assert n == num_points
186 | assert c == 3
187 | if crop == num_points:
188 | return xyz, None
189 |
190 | INPUT = []
191 | CROP = []
192 | for points in xyz:
193 | if isinstance(crop,list):
194 | num_crop = random.randint(crop[0],crop[1])
195 | else:
196 | num_crop = crop
197 |
198 | points = points.unsqueeze(0)
199 |
200 | if fixed_points is None:
201 | center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda()
202 | else:
203 | if isinstance(fixed_points,list):
204 | fixed_point = random.sample(fixed_points,1)[0]
205 | else:
206 | fixed_point = fixed_points
207 | center = fixed_point.reshape(1,1,3).cuda()
208 |
209 | distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048
210 |
211 | idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048
212 |
213 | if padding_zeros:
214 | input_data = points.clone()
215 | input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0
216 |
217 | else:
218 | input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3
219 |
220 | crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)
221 |
222 | if isinstance(crop,list):
223 | INPUT.append(fps(input_data,2048))
224 | CROP.append(fps(crop_data,2048))
225 | else:
226 | INPUT.append(input_data)
227 | CROP.append(crop_data)
228 |
229 | input_data = torch.cat(INPUT,dim=0)# B N 3
230 | crop_data = torch.cat(CROP,dim=0)# B M 3
231 |
232 | return input_data.contiguous(), crop_data.contiguous()
233 |
234 | def get_ptcloud_img(ptcloud):
235 | fig = plt.figure(figsize=(8, 8))
236 |
237 | x, z, y = ptcloud.transpose(1, 0)
238 | ax = fig.gca(projection=Axes3D.name, adjustable='box')
239 | ax.axis('off')
240 | # ax.axis('scaled')
241 | ax.view_init(30, 45)
242 | max, min = np.max(ptcloud), np.min(ptcloud)
243 | ax.set_xbound(min, max)
244 | ax.set_ybound(min, max)
245 | ax.set_zbound(min, max)
246 | ax.scatter(x, y, z, zdir='z', c=x, cmap='jet')
247 |
248 | fig.canvas.draw()
249 |         img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
250 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
251 | return img
252 |
253 |
254 |
255 | def visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y',
256 | xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ):
257 | fig = plt.figure(figsize=(6*len(data_list),6))
258 | cmax = data_list[-1][:,0].max()
259 |
260 | for i in range(len(data_list)):
261 | data = data_list[i][:-2048] if i == 1 else data_list[i]
262 | color = data[:,0] /cmax
263 | ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d')
264 | ax.view_init(30, -120)
265 | b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color,vmin=-1,vmax=1 ,cmap = cmap[0],s=4,linewidth=0.05, edgecolors = 'black')
266 | ax.set_title(titles[i])
267 |
268 | ax.set_axis_off()
269 | ax.set_xlim(xlim)
270 | ax.set_ylim(ylim)
271 | ax.set_zlim(zlim)
272 | plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)
273 | if not os.path.exists(path):
274 | os.makedirs(path)
275 |
276 | pic_path = path + '.png'
277 | fig.savefig(pic_path)
278 |
279 | np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())
280 | np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())
281 | plt.close(fig)
282 |
283 |
284 | def random_dropping(pc, e):
285 | up_num = max(64, 768 // (e//50 + 1))
286 | pc = pc
287 | random_num = torch.randint(1, up_num, (1,1))[0,0]
288 | pc = fps(pc, random_num)
289 | padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)
290 | pc = torch.cat([pc, padding], dim = 1)
291 | return pc
292 |
293 |
294 | def random_scale(partial, scale_range=[0.8, 1.2]):
295 | scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]
296 | return partial * scale
297 |
--------------------------------------------------------------------------------
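A short sketch of the two helpers from this file that the training code leans on most, seeding and furthest-point sampling; `fps` requires the `pointnet2_ops` CUDA extension and a GPU:

```python
import torch
from models.pointbert.misc import fps, set_random_seed

# Seed python/numpy/torch (optionally force deterministic cuDNN).
set_random_seed(42, deterministic=False)

# Furthest-point sampling: keep 2048 well-spread points per cloud.
pts = torch.rand(4, 8192, 3).cuda()  # B=4 clouds of 8192 points each
sub = fps(pts, 2048)
print(sub.shape)  # torch.Size([4, 2048, 3])
```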
/src/models/pointbert/point_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from timm.models.layers import DropPath
5 | from models.pointbert.dvae import Group
6 | from models.pointbert.dvae import Encoder
7 | from models.pointbert.logger import print_log
8 |
9 | from models.pointbert.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
10 |
11 | def cal_model_parm_nums(model):
12 | total = sum([param.nelement() for param in model.parameters()])
13 |
14 | if total >= 1e9:
15 | return "{:.2f}B".format(total / 1e9)
16 | elif total >= 1e6:
17 | return "{:.2f}M".format(total / 1e6)
18 | elif total >= 1e3:
19 | return "{:.2f}K".format(total / 1e3)
20 | else:
21 | return str(total)
22 |
23 | class Mlp(nn.Module):
24 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
25 | super().__init__()
26 | out_features = out_features or in_features
27 | hidden_features = hidden_features or in_features
28 | self.fc1 = nn.Linear(in_features, hidden_features)
29 | self.act = act_layer()
30 | self.fc2 = nn.Linear(hidden_features, out_features)
31 | self.drop = nn.Dropout(drop)
32 |
33 | def forward(self, x):
34 | x = self.fc1(x)
35 | x = self.act(x)
36 | x = self.drop(x)
37 | x = self.fc2(x)
38 | x = self.drop(x)
39 | return x
40 |
41 |
42 | class Attention(nn.Module):
43 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
44 | super().__init__()
45 | self.num_heads = num_heads
46 | head_dim = dim // num_heads
47 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
48 | self.scale = qk_scale or head_dim ** -0.5
49 |
50 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
51 | self.attn_drop = nn.Dropout(attn_drop)
52 | self.proj = nn.Linear(dim, dim)
53 | self.proj_drop = nn.Dropout(proj_drop)
54 |
55 | def forward(self, x):
56 | B, N, C = x.shape
57 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
58 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
59 |
60 | attn = (q @ k.transpose(-2, -1)) * self.scale
61 | attn = attn.softmax(dim=-1)
62 | attn = self.attn_drop(attn)
63 |
64 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
65 | x = self.proj(x)
66 | x = self.proj_drop(x)
67 | return x
68 |
69 |
70 | class Block(nn.Module):
71 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
72 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
73 | super().__init__()
74 | self.norm1 = norm_layer(dim)
75 |
76 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
77 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
78 | self.norm2 = norm_layer(dim)
79 | mlp_hidden_dim = int(dim * mlp_ratio)
80 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
81 |
82 | self.attn = Attention(
83 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
84 |
85 | def forward(self, x):
86 | x = x + self.drop_path(self.attn(self.norm1(x)))
87 | x = x + self.drop_path(self.mlp(self.norm2(x)))
88 | return x
89 |
90 |
91 | class TransformerEncoder(nn.Module):
92 | """ Transformer Encoder without hierarchical structure
93 | """
94 |
95 | def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
96 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
97 | super().__init__()
98 |
99 | self.blocks = nn.ModuleList([
100 | Block(
101 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
102 | drop=drop_rate, attn_drop=attn_drop_rate,
103 | drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate
104 | )
105 | for i in range(depth)])
106 |
107 | def forward(self, x, pos):
108 | for _, block in enumerate(self.blocks):
109 | x = block(x + pos)
110 | return x
111 |
112 |
113 | class PointTransformer(nn.Module):
114 | def __init__(self, config, **kwargs):
115 | super().__init__()
116 | self.config = config
117 | self.args = kwargs["args"]
118 |
119 | self.trans_dim = config.trans_dim
120 | self.depth = config.depth
121 | self.drop_path_rate = config.drop_path_rate
122 | self.cls_dim = config.cls_dim
123 | self.num_heads = config.num_heads
124 |
125 | self.group_size = config.group_size
126 | self.num_group = config.num_group
127 | # grouper
128 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
129 | # define the encoder
130 | self.encoder_dims = config.encoder_dims
131 | self.encoder = Encoder(encoder_channel=self.encoder_dims)
132 | # bridge encoder and transformer
133 | self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim)
134 |
135 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
136 | self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
137 |
138 | self.pos_embed = nn.Sequential(
139 | nn.Linear(3, 128),
140 | nn.GELU(),
141 | nn.Linear(128, self.trans_dim)
142 | )
143 |
144 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
145 | self.blocks = TransformerEncoder(
146 | embed_dim=self.trans_dim,
147 | depth=self.depth,
148 | drop_path_rate=dpr,
149 | num_heads=self.num_heads
150 | )
151 |
152 | self.norm = nn.LayerNorm(self.trans_dim)
153 | # self.load_model_from_ckpt('/export/home/repos/SLIP/pretrained_models/point_transformer_8192.pt')
154 | if not self.args.evaluate_3d:
155 | self.load_model_from_ckpt('Path/checkpoints/initialize_models_point_bert_pretrained.pt')
156 |
157 | # self.cls_head_finetune = nn.Sequential(
158 | # nn.Linear(self.trans_dim * 2, 256),
159 | # nn.ReLU(inplace=True),
160 | # nn.Dropout(0.5),
161 | # nn.Linear(256, self.cls_dim)
162 | # )
163 |
164 | # self.build_loss_func()
165 |
166 | def build_loss_func(self):
167 | self.loss_ce = nn.CrossEntropyLoss()
168 |
169 | def get_loss_acc(self, pred, gt, smoothing=True):
170 | # import pdb; pdb.set_trace()
171 | gt = gt.contiguous().view(-1).long()
172 |
173 | if smoothing:
174 | eps = 0.2
175 | n_class = pred.size(1)
176 |
177 | one_hot = torch.zeros_like(pred).scatter(1, gt.view(-1, 1), 1)
178 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
179 | log_prb = F.log_softmax(pred, dim=1)
180 |
181 | loss = -(one_hot * log_prb).sum(dim=1).mean()
182 | else:
183 | loss = self.loss_ce(pred, gt.long())
184 |
185 | pred = pred.argmax(-1)
186 | acc = (pred == gt).sum() / float(gt.size(0))
187 |
188 | return loss, acc * 100
189 |
190 | def load_model_from_ckpt(self, bert_ckpt_path):
191 | ckpt = torch.load(bert_ckpt_path)
192 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()}
193 | for k in list(base_ckpt.keys()):
194 | if k.startswith('transformer_q') and not k.startswith('transformer_q.cls_head'):
195 | base_ckpt[k[len('transformer_q.'):]] = base_ckpt[k]
196 | elif k.startswith('base_model'):
197 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k]
198 | del base_ckpt[k]
199 |
200 | incompatible = self.load_state_dict(base_ckpt, strict=False)
201 |
202 | if incompatible.missing_keys:
203 | print_log('missing_keys', logger='Transformer')
204 | print_log(
205 | get_missing_parameters_message(incompatible.missing_keys),
206 | logger='Transformer'
207 | )
208 | if incompatible.unexpected_keys:
209 | print_log('unexpected_keys', logger='Transformer')
210 | print_log(
211 | get_unexpected_parameters_message(incompatible.unexpected_keys),
212 | logger='Transformer'
213 | )
214 |
215 |         print_log(f'[Transformer] Successfully loaded checkpoint from {bert_ckpt_path}', logger='Transformer')
216 |
217 | def forward(self, pts):
218 | # divide the point cloud in the same form. This is important
219 | neighborhood, center = self.group_divider(pts)
220 | # encoder the input cloud blocks
221 | group_input_tokens = self.encoder(neighborhood) # B G N
222 | group_input_tokens = self.reduce_dim(group_input_tokens)
223 | # prepare cls
224 | cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
225 | cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
226 | # add pos embedding
227 | pos = self.pos_embed(center)
228 | # final input
229 | x = torch.cat((cls_tokens, group_input_tokens), dim=1)
230 | pos = torch.cat((cls_pos, pos), dim=1)
231 | # transformer
232 | x = self.blocks(x, pos)
233 | x = self.norm(x)
234 | concat_f = torch.cat([x[:, 0], x[:, 1:].max(1)[0]], dim=-1)
235 | # ret = self.cls_head_finetune(concat_f)
236 | return concat_f
237 |
238 | class PointTransformer_Colored(nn.Module):
239 | def __init__(self, config, **kwargs):
240 | super().__init__()
241 | self.config = config
242 | self.args = kwargs["args"]
243 |
244 | self.trans_dim = config.trans_dim
245 | self.depth = config.depth
246 | self.drop_path_rate = config.drop_path_rate
247 | self.cls_dim = config.cls_dim
248 | self.num_heads = config.num_heads
249 |
250 | self.group_size = config.group_size
251 | self.num_group = config.num_group
252 | # grouper
253 | self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
254 | # define the encoder
255 | self.encoder_dims = config.encoder_dims
256 | self.encoder = Encoder(encoder_channel=self.encoder_dims, input_dim=6)
257 | # bridge encoder and transformer
258 | self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim)
259 |
260 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
261 | self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
262 |
263 | self.pos_embed = nn.Sequential(
264 | nn.Linear(3, 128),
265 | nn.GELU(),
266 | nn.Linear(128, self.trans_dim)
267 | )
268 |
269 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
270 | self.blocks = TransformerEncoder(
271 | embed_dim=self.trans_dim,
272 | depth=self.depth,
273 | drop_path_rate=dpr,
274 | num_heads=self.num_heads
275 | )
276 |
277 | self.norm = nn.LayerNorm(self.trans_dim)
278 |
279 | print("training from scratch for pointbert.")
280 |
281 | model_size = cal_model_parm_nums(self)
282 | print("model size:")
283 | print(model_size)
284 |
285 | def build_loss_func(self):
286 | self.loss_ce = nn.CrossEntropyLoss()
287 |
288 | def get_loss_acc(self, pred, gt, smoothing=True):
289 | # import pdb; pdb.set_trace()
290 | gt = gt.contiguous().view(-1).long()
291 |
292 | if smoothing:
293 | eps = 0.2
294 | n_class = pred.size(1)
295 |
296 | one_hot = torch.zeros_like(pred).scatter(1, gt.view(-1, 1), 1)
297 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
298 | log_prb = F.log_softmax(pred, dim=1)
299 |
300 | loss = -(one_hot * log_prb).sum(dim=1).mean()
301 | else:
302 | loss = self.loss_ce(pred, gt.long())
303 |
304 | pred = pred.argmax(-1)
305 | acc = (pred == gt).sum() / float(gt.size(0))
306 |
307 | return loss, acc * 100
308 |
309 | def load_model_from_ckpt(self, bert_ckpt_path):
310 | ckpt = torch.load(bert_ckpt_path, map_location=torch.device('cpu'))
311 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()}
312 | for k in list(base_ckpt.keys()):
313 | if k.startswith('transformer_q') and not k.startswith('transformer_q.cls_head'):
314 | base_ckpt[k[len('transformer_q.'):]] = base_ckpt[k]
315 | elif k.startswith('base_model'):
316 | base_ckpt[k[len('base_model.'):]] = base_ckpt[k]
317 | del base_ckpt[k]
318 |
319 | incompatible = self.load_state_dict(base_ckpt, strict=False)
320 |
321 | if incompatible.missing_keys:
322 | print_log('missing_keys', logger='Transformer')
323 | print_log(
324 | get_missing_parameters_message(incompatible.missing_keys),
325 | logger='Transformer'
326 | )
327 | if incompatible.unexpected_keys:
328 | print_log('unexpected_keys', logger='Transformer')
329 | print_log(
330 | get_unexpected_parameters_message(incompatible.unexpected_keys),
331 | logger='Transformer'
332 | )
333 |
334 | print_log(f'[Transformer] Successfully loaded the ckpt from {bert_ckpt_path}', logger='Transformer')
335 |
336 | def forward(self, pts):
337 | # divide the point cloud in the same form. This is important
338 | neighborhood, center = self.group_divider(pts)
339 | # encode the input cloud blocks
340 | group_input_tokens = self.encoder(neighborhood) # B G N
341 | group_input_tokens = self.reduce_dim(group_input_tokens)
342 | # prepare cls
343 | cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
344 | cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
345 | # add pos embedding
346 | pos = self.pos_embed(center)
347 | # final input
348 | x = torch.cat((cls_tokens, group_input_tokens), dim=1)
349 | pos = torch.cat((cls_pos, pos), dim=1)
350 | # transformer
351 | x = self.blocks(x, pos)
352 | x = self.norm(x)
353 | concat_f = torch.cat([x[:, 0], x[:, 1:].max(1)[0]], dim=-1)
354 | # ret = self.cls_head_finetune(concat_f)
355 | return concat_f
--------------------------------------------------------------------------------
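A minimal standalone sketch (not part of the repository) of the feature pooling used at the end of the PointTransformer forward passes above: the class-token output is concatenated with a channel-wise max over the group tokens, so the returned feature is twice `trans_dim` wide. The batch size, group count, and `trans_dim` below are illustrative assumptions.

import torch

B, num_group, trans_dim = 2, 64, 384           # illustrative sizes, not taken from a config
x = torch.randn(B, num_group + 1, trans_dim)   # [cls token] + group tokens after the transformer blocks

# same pooling as forward(): cls feature concatenated with a max-pool over the group tokens
concat_f = torch.cat([x[:, 0], x[:, 1:].max(1)[0]], dim=-1)
print(concat_f.shape)                          # torch.Size([2, 768]) == [B, 2 * trans_dim]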
/src/models/text_encoder_3d.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | # import math
3 | from .clip_utils import *
4 | from torch.nn.parameter import Parameter
5 | import math
6 | from torch.nn import Dropout
7 | from functools import reduce
8 | from operator import mul
9 | class CLIPTextEncoder(nn.Module):
10 | def __init__(self, context_length=77,
11 | vocab_size=49408,
12 | transformer_width=512,
13 | transformer_heads=8,
14 | transformer_layers=12,
15 | embed_dim=512,
16 | out_dim=256,
17 | patch_size = 16,
18 | pretrained=None, **kwargs):
19 | super().__init__()
20 | self.layers = transformer_layers
21 | self.total_d_layer = transformer_layers-1
22 | self.pretrained = pretrained
23 | self.out_indices = [3, 5, 7, 11]
24 | self.num_tokens = 100
25 | self.prompt_dim = embed_dim
26 | self.context_length = context_length
27 |
28 | self.transformer = Transformer(
29 | width=transformer_width,
30 | layers=transformer_layers,
31 | heads=transformer_heads,
32 | attn_mask=self.build_attention_mask()
33 | )
34 |
35 | self.vocab_size = vocab_size
36 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
37 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
38 | self.ln_final = LayerNorm(transformer_width)
39 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
40 |
41 | self._init_prompt(patch_size, self.num_tokens, self.prompt_dim, self.total_d_layer)
42 |
43 | def _init_prompt(self, patch, num_tokens, prompt_dim, total_d_layer):
44 | patch_size = []
45 | patch_size.append(patch)
46 | patch_size.append(patch)
47 | val = math.sqrt(6. / float(3 * reduce(mul, patch_size, 1) + prompt_dim))
48 |
49 | if total_d_layer >= 0:
50 | self.prompt_embeddings = nn.Parameter(torch.zeros(1, num_tokens, prompt_dim))
51 | # xavier_uniform initialization
52 | nn.init.uniform_(self.prompt_embeddings.data, -val, val)
53 |
54 | if total_d_layer > 0: # noqa
55 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros(total_d_layer, num_tokens, prompt_dim))
56 | # xavier_uniform initialization
57 | nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)
58 |
59 | self.prompt_proj = nn.Linear(prompt_dim, prompt_dim)
60 | nn.init.kaiming_normal_(self.prompt_proj.weight, a=0, mode='fan_out')
61 | self.prompt_norm = LayerNorm(prompt_dim, eps=1e-6)
62 | self.prompt_dropout = Dropout(0.1)
63 |
64 | else: # total_d_layer < 0
65 | self.deep_prompt_embeddings = nn.Parameter(torch.zeros(abs(total_d_layer), num_tokens, prompt_dim))
66 | nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)
67 | self.prompt_proj = nn.Linear(prompt_dim, prompt_dim)
68 | nn.init.kaiming_normal_(self.prompt_proj.weight, a=0, mode='fan_out')
69 | self.prompt_norm = LayerNorm(prompt_dim, eps=1e-6)
70 | self.prompt_dropout = Dropout(0.1)
71 |
72 | def init_weights(self, pretrained=None):
73 | pretrained = pretrained or self.pretrained
74 | if isinstance(pretrained, str):
75 | checkpoint = torch.jit.load(pretrained, map_location='cpu').float().state_dict()
76 |
77 | state_dict = {}
78 |
79 | for k in checkpoint.keys():
80 | if k.startswith('transformer.'):
81 | state_dict[k] = checkpoint[k]
82 |
83 | if k == 'positional_embedding' or k == 'text_projection' or k.startswith('token_embedding') or k.startswith('ln_final'):
84 | if k == 'positional_embedding' and checkpoint[k].size(0) > self.context_length:
85 | checkpoint[k] = checkpoint[k][:self.context_length]
86 | print('positional_embedding is truncated from 77 to', self.context_length)
87 | state_dict[k] = checkpoint[k]
88 |
89 | u, w = self.load_state_dict(state_dict, False)
90 | print(u, w, 'are misaligned params in text encoder')
91 |
92 | def build_attention_mask(self):
93 | # lazily create causal attention mask, with full attention between the vision tokens
94 | # pytorch uses additive attention mask; fill with -inf
95 | mask = torch.empty(self.context_length + self.num_tokens, self.context_length+self.num_tokens)
96 | mask.fill_(float("-inf"))
97 | mask.triu_(1) # zero out the lower diagonal
98 | return mask
99 |
100 | def forward_deep_prompt(self, embedding_output, features, out_last=False):
101 | B = embedding_output.shape[1] # batch_size
102 | for i in range(self.layers):
103 | if i == 0:
104 | hidden_states = self.transformer.resblocks[i](embedding_output)
105 | elif i <= self.deep_prompt_embeddings.shape[0]:
106 | deep_prompt_emb = self.prompt_dropout(self.prompt_proj(self.deep_prompt_embeddings[i-1]).expand(B, -1, -1)).permute(1, 0, 2) # seems like the middle layer's dpt
107 | hidden_states = torch.cat((
108 | deep_prompt_emb,
109 | hidden_states[self.num_tokens:,:,:]
110 | ), dim=0) # 177 B 768
111 |
112 | hidden_states = self.transformer.resblocks[i](hidden_states)
113 | else:
114 | # hidden_states = torch.cat((
115 | # hidden_states[:1, :, :],
116 | # hidden_states[-(H*W):, :, :]
117 | # ), dim=0)
118 | hidden_states = self.transformer.resblocks[i](hidden_states)
119 |
120 | if len(self.out_indices) > 1:
121 | if i in self.out_indices:
122 | xp = hidden_states.permute(1, 0, 2)[:, -77:, :].permute(0, 2, 1) # B,512,77
123 | features.append(xp.contiguous())
124 |
125 | if i == (self.layers-2): #10
126 | before_last_feats = self.prompt_norm(hidden_states) # 1125x4x768
127 |
128 | encoded = self.prompt_norm(hidden_states)
129 | if out_last:
130 | return before_last_feats
131 | else:
132 | return encoded, features
133 | def encode_token(self, token):
134 | x = self.token_embedding(token)
135 | return x
136 |
137 | def forward(self, text, token):
138 | # x = self.token_embedding(text)
139 | x = text + self.positional_embedding
140 | if self.total_d_layer >=0:
141 | # concat prompt
142 | x = torch.cat(( # Deep Prompt Tuning
143 | self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(x.shape[0], -1, -1)),
144 | x
145 | ), dim=1)# B,177,512
146 |
147 | x = x.permute(1, 0, 2)
148 | features = []
149 | outs = []
150 | if self.total_d_layer > 0:
151 | x, features = self.forward_deep_prompt(x, features)
152 | # x = self.transformer(x)
153 | x = x[-77:,:,:]
154 | x = x.permute(1, 0, 2)
155 | x = self.ln_final(x)
156 | x = x[torch.arange(x.shape[0]), token.argmax(dim=-1)] @ self.text_projection
157 | # outs.append(tuple(features))
158 | return x
159 |
160 | if __name__ == '__main__':
161 | from tokenizer import SimpleTokenizer
162 | tokenizer = SimpleTokenizer()
163 | encoder = CLIPTextEncoder(pretrained="Path/ViT-B-16.pt")
164 | encoder.init_weights()
165 | text = 'an airplane'
166 | texts = tokenizer(text).unsqueeze(0)
167 | text_embed = encoder(encoder.encode_token(texts), texts)
168 | print(f'text_embed: {text_embed.shape}')
169 | exclude_key = 'prompt'
170 | for n,m in encoder.named_parameters():
171 | if exclude_key not in n:
172 | m.requires_grad = False
173 | else:
174 | print(n)
175 | # if exclude_key:
176 | # if isinstance(exclude_key, str):
177 | # if not exclude_key in n:
178 | # m.requires_grad = False
179 | # print(f'False : {n}')
180 | # elif isinstance(exclude_key, list):
181 | # count = 0
182 | # for i in range(len(exclude_key)):
183 | # i_layer = str(exclude_key[i])
184 | # if i_layer in n:
185 | # count += 1
186 | # if count == 0:
187 | # m.requires_grad = False
188 | # elif count>0:
189 | # print('Finetune layer in backbone:', n)
190 | # else:
191 | # assert AttributeError("Dont support the type of exclude_key!")
192 | # else:
193 | # m.requires_grad = False
194 | # print(f'False : {n}')
195 |
--------------------------------------------------------------------------------
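A small standalone sketch (assuming the default context_length=77 and num_tokens=100 of CLIPTextEncoder above) of what build_attention_mask produces once the prompt tokens are accounted for: a 177x177 additive causal mask.

import torch

context_length, num_tokens = 77, 100   # defaults of CLIPTextEncoder above
size = context_length + num_tokens     # 177 positions once the prompt tokens are prepended

mask = torch.empty(size, size)
mask.fill_(float("-inf"))
mask.triu_(1)                          # strictly upper triangle stays -inf, the rest becomes 0

print(mask.shape)                                 # torch.Size([177, 177])
print(mask[0, 1].item(), mask[1, 0].item())       # -inf (no attention to future), 0.0 (past is visible)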
/src/models/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Modified from github.com/openai/CLIP
2 | import gzip
3 | import html
4 | import os
5 | from functools import lru_cache
6 |
7 | import ftfy
8 | import regex as re
9 | import torch
10 |
11 |
12 | @lru_cache()
13 | def default_bpe():
14 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
15 |
16 |
17 | @lru_cache()
18 | def bytes_to_unicode():
19 | """
20 | Returns list of utf-8 byte and a corresponding list of unicode strings.
21 | The reversible bpe codes work on unicode strings.
22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
24 | This is a significant percentage of your normal, say, 32K bpe vocab.
25 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
26 | And avoids mapping to whitespace/control characters the bpe code barfs on.
27 | """
28 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
29 | cs = bs[:]
30 | n = 0
31 | for b in range(2**8):
32 | if b not in bs:
33 | bs.append(b)
34 | cs.append(2**8+n)
35 | n += 1
36 | cs = [chr(n) for n in cs]
37 | return dict(zip(bs, cs))
38 |
39 |
40 | def get_pairs(word):
41 | """Return set of symbol pairs in a word.
42 | Word is represented as tuple of symbols (symbols being variable-length strings).
43 | """
44 | pairs = set()
45 | prev_char = word[0]
46 | for char in word[1:]:
47 | pairs.add((prev_char, char))
48 | prev_char = char
49 | return pairs
50 |
51 |
52 | def basic_clean(text):
53 | text = ftfy.fix_text(text)
54 | text = html.unescape(html.unescape(text))
55 | return text.strip()
56 |
57 |
58 | def whitespace_clean(text):
59 | text = re.sub(r'\s+', ' ', text)
60 | text = text.strip()
61 | return text
62 |
63 |
64 | class SimpleTokenizer(object):
65 | def __init__(self, bpe_path: str = default_bpe()):
66 | self.byte_encoder = bytes_to_unicode()
67 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
68 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
69 | merges = merges[1:49152-256-2+1]
70 | merges = [tuple(merge.split()) for merge in merges]
71 | vocab = list(bytes_to_unicode().values())
72 | vocab = vocab + [v+'</w>' for v in vocab]
73 | for merge in merges:
74 | vocab.append(''.join(merge))
75 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
76 | self.encoder = dict(zip(vocab, range(len(vocab))))
77 | self.decoder = {v: k for k, v in self.encoder.items()}
78 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
79 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
80 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
81 |
82 | def bpe(self, token):
83 | if token in self.cache:
84 | return self.cache[token]
85 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
86 | pairs = get_pairs(word)
87 |
88 | if not pairs:
89 | return token+'</w>'
90 |
91 | while True:
92 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
93 | if bigram not in self.bpe_ranks:
94 | break
95 | first, second = bigram
96 | new_word = []
97 | i = 0
98 | while i < len(word):
99 | try:
100 | j = word.index(first, i)
101 | new_word.extend(word[i:j])
102 | i = j
103 | except:
104 | new_word.extend(word[i:])
105 | break
106 |
107 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
108 | new_word.append(first+second)
109 | i += 2
110 | else:
111 | new_word.append(word[i])
112 | i += 1
113 | new_word = tuple(new_word)
114 | word = new_word
115 | if len(word) == 1:
116 | break
117 | else:
118 | pairs = get_pairs(word)
119 | word = ' '.join(word)
120 | self.cache[token] = word
121 | return word
122 |
123 | def encode(self, text):
124 | bpe_tokens = []
125 | text = whitespace_clean(basic_clean(text)).lower()
126 | for token in re.findall(self.pat, text):
127 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
128 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
129 | return bpe_tokens
130 |
131 | def decode(self, tokens):
132 | text = ''.join([self.decoder[token] for token in tokens])
133 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
134 | return text
135 |
136 | def __call__(self, texts, context_length=77):
137 | if isinstance(texts, str):
138 | texts = [texts]
139 |
140 | sot_token = self.encoder["<|startoftext|>"]
141 | eot_token = self.encoder["<|endoftext|>"]
142 | all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
143 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
144 |
145 | for i, tokens in enumerate(all_tokens):
146 | tokens = tokens[:context_length]
147 | result[i, :len(tokens)] = torch.tensor(tokens)
148 |
149 | if len(result) == 1:
150 | return result[0]
151 | return result
--------------------------------------------------------------------------------
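A usage sketch for SimpleTokenizer above (assumptions: ftfy and regex are installed, the bundled bpe_simple_vocab_16e6.txt.gz sits next to tokenizer.py, and the import path follows the src/models layout):

from models.tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("a photo of an airplane")   # list of BPE ids (no start/end tokens)
print(tokenizer.decode(ids))                       # recovers the cleaned, lower-cased text

tokens = tokenizer("a photo of an airplane")       # LongTensor of shape (77,), sot/eot added, zero-padded
batch = tokenizer(["an airplane", "a chair"])      # LongTensor of shape (2, 77)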
/src/tools/builder.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | # online package
3 | import torch
4 | # optimizer
5 | import torch.optim as optim
6 | # dataloader
7 | # from datasets import build_dataset_from_cfg
8 | # from .models.build import build_model_from_cfg
9 | from models.build import build_model_from_cfg
10 | # utils
11 | from utils.logger import *
12 |
13 | from utils.misc import *
14 | from timm.scheduler import CosineLRScheduler
15 |
16 | # def dataset_builder(args, config):
17 | # dataset = build_dataset_from_cfg(config._base_, config.others)
18 | # shuffle = config.others.subset == 'train'
19 | # if args.distributed:
20 | # sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle = shuffle)
21 | # dataloader = torch.utils.data.DataLoader(dataset, batch_size = config.others.bs,
22 | # num_workers = int(args.num_workers),
23 | # drop_last = config.others.subset == 'train',
24 | # worker_init_fn = worker_init_fn,
25 | # sampler = sampler)
26 | # else:
27 | # sampler = None
28 | # dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.others.bs,
29 | # shuffle = shuffle,
30 | # drop_last = config.others.subset == 'train',
31 | # num_workers = int(args.num_workers),
32 | # worker_init_fn=worker_init_fn)
33 | # return sampler, dataloader
34 |
35 | def model_builder(config):
36 | model = build_model_from_cfg(config)
37 | return model
38 |
39 | def build_opti_sche(base_model, config):
40 | opti_config = config.optimizer
41 | if opti_config.type == 'AdamW':
42 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
43 | decay = []
44 | no_decay = []
45 | for name,param in model.named_parameters():
46 | # for name, param in model.module.named_parameters():
47 | if not param.requires_grad:
48 | continue # frozen weights
49 | if len(param.shape) == 1 or name.endswith(".bias") or 'token' in name or name in skip_list:
50 | # print(name)
51 | no_decay.append(param)
52 | else:
53 | decay.append(param)
54 | return [
55 | {'params': no_decay, 'weight_decay': 0.},
56 | {'params': decay, 'weight_decay': weight_decay}]
57 | param_groups = add_weight_decay(base_model, weight_decay=opti_config.kwargs.weight_decay)
58 | optimizer = optim.AdamW(param_groups, **opti_config.kwargs)
59 | elif opti_config.type == 'Adam':
60 | optimizer = optim.Adam(base_model.parameters(), **opti_config.kwargs)
61 | elif opti_config.type == 'SGD':
62 | optimizer = optim.SGD(base_model.parameters(), nesterov=True, **opti_config.kwargs)
63 | else:
64 | raise NotImplementedError()
65 |
66 | sche_config = config.scheduler
67 | if sche_config.type == 'LambdaLR':
68 | scheduler = build_lambda_sche(optimizer, sche_config.kwargs) # misc.py
69 | elif sche_config.type == 'CosLR':
70 | scheduler = CosineLRScheduler(optimizer,
71 | t_initial=sche_config.kwargs.epochs,
72 | t_mul=1,
73 | lr_min=1e-6,
74 | decay_rate=0.1,
75 | warmup_lr_init=1e-6,
76 | warmup_t=sche_config.kwargs.initial_epochs,
77 | cycle_limit=1,
78 | t_in_epochs=True)
79 | elif sche_config.type == 'StepLR':
80 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **sche_config.kwargs)
81 | elif sche_config.type == 'function':
82 | scheduler = None
83 | else:
84 | raise NotImplementedError()
85 |
86 | if config.get('bnmscheduler') is not None:
87 | bnsche_config = config.bnmscheduler
88 | if bnsche_config.type == 'Lambda':
89 | bnscheduler = build_lambda_bnsche(base_model, bnsche_config.kwargs) # misc.py
90 | scheduler = [scheduler, bnscheduler]
91 |
92 | return optimizer, scheduler
93 |
94 | # def resume_model(base_model, args, logger = None):
95 | # ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth')
96 | # if not os.path.exists(ckpt_path):
97 | # print_log(f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger = logger)
98 | # return 0, 0
99 | # print_log(f'[RESUME INFO] Loading model weights from {ckpt_path}...', logger = logger )
100 |
101 | # # load state dict
102 | # map_location = {'cuda:%d' % 0: 'cuda:%d' % args.local_rank}
103 | # state_dict = torch.load(ckpt_path, map_location=map_location)
104 | # # parameter resume of base model
105 | # # if args.local_rank == 0:
106 | # base_ckpt = {k.replace("module.", ""): v for k, v in state_dict['base_model'].items()}
107 | # base_model.load_state_dict(base_ckpt, strict = True)
108 |
109 | # # parameter
110 | # start_epoch = state_dict['epoch'] + 1
111 | # best_metrics = state_dict['best_metrics']
112 | # if not isinstance(best_metrics, dict):
113 | # best_metrics = best_metrics.state_dict()
114 | # # print(best_metrics)
115 |
116 | # print_log(f'[RESUME INFO] resume ckpts @ {start_epoch - 1} epoch( best_metrics = {str(best_metrics):s})', logger = logger)
117 | # return start_epoch, best_metrics
118 |
119 | # def resume_optimizer(optimizer, args, logger = None):
120 | # ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth')
121 | # if not os.path.exists(ckpt_path):
122 | # print_log(f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger = logger)
123 | # return 0, 0, 0
124 | # print_log(f'[RESUME INFO] Loading optimizer from {ckpt_path}...', logger = logger )
125 | # # load state dict
126 | # state_dict = torch.load(ckpt_path, map_location='cpu')
127 | # # optimizer
128 | # optimizer.load_state_dict(state_dict['optimizer'])
129 |
130 | def save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, prefix, args, logger = None):
131 | if args.local_rank == 0:
132 | torch.save({
133 | 'base_model' : base_model.module.state_dict() if args.distributed else base_model.state_dict(),
134 | 'optimizer' : optimizer.state_dict(),
135 | 'epoch' : epoch,
136 | 'metrics' : metrics.state_dict() if metrics is not None else dict(),
137 | 'best_metrics' : best_metrics.state_dict() if best_metrics is not None else dict(),
138 | }, os.path.join(args.experiment_path, prefix + '.pth'))
139 | print_log(f"Save checkpoint at {os.path.join(args.experiment_path, prefix + '.pth')}", logger = logger)
140 |
141 | def load_model(base_model, ckpt_path, logger = None):
142 | if not os.path.exists(ckpt_path):
143 | raise NotImplementedError('no checkpoint file from path %s...' % ckpt_path)
144 | print_log(f'Loading weights from {ckpt_path}...', logger = logger )
145 |
146 | # load state dict
147 | state_dict = torch.load(ckpt_path, map_location='cpu')
148 | # parameter resume of base model
149 | if state_dict.get('model') is not None:
150 | base_ckpt = {k.replace("module.", ""): v for k, v in state_dict['model'].items()}
151 | elif state_dict.get('base_model') is not None:
152 | base_ckpt = {k.replace("module.", ""): v for k, v in state_dict['base_model'].items()}
153 | else:
154 | raise RuntimeError('mismatch of ckpt weight')
155 | base_model.load_state_dict(base_ckpt, strict = True)
156 |
157 | epoch = -1
158 | if state_dict.get('epoch') is not None:
159 | epoch = state_dict['epoch']
160 | if state_dict.get('metrics') is not None:
161 | metrics = state_dict['metrics']
162 | if not isinstance(metrics, dict):
163 | metrics = metrics.state_dict()
164 | else:
165 | metrics = 'No Metrics'
166 | print_log(f'ckpts @ {epoch} epoch (performance = {str(metrics):s})', logger = logger)
167 | return
--------------------------------------------------------------------------------
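A sketch of the config structure that build_opti_sche above consumes; only the keys (optimizer.type/kwargs, scheduler.type/kwargs with epochs and initial_epochs for CosLR) come from the code, while the numeric values are placeholder assumptions:

from easydict import EasyDict

cfg = EasyDict({
    'optimizer': {'type': 'AdamW', 'kwargs': {'lr': 5e-4, 'weight_decay': 5e-2}},
    'scheduler': {'type': 'CosLR', 'kwargs': {'epochs': 300, 'initial_epochs': 10}},
})

# optimizer, scheduler = build_opti_sche(model, cfg)
# scheduler.step(epoch)   # the CosineLRScheduler is stepped once per epoch (t_in_epochs=True)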
/src/train/train_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import munch
5 | import yaml
6 | import sys
7 | from torch.utils.tensorboard import SummaryWriter
8 | from network.base_model import BaseModel
9 | from network.base_model import AverageValueMeter,build_optimizer,save_model_2
10 | import argparse
11 | from time import time
12 | from tqdm.auto import tqdm
13 | import time as timetmp
14 | import logging
15 | import random
16 | import math
17 | from loss.cdloss import SimplificationLoss
18 | from dataset_svr.trainer_dataset import build_dataset,get_spherepoints
19 |
20 | def setFolders(args):
21 | LOG_DIR = args.dir_outpath
22 | MODEL_NAME = '%s-%s'%(args.model_name, timetmp.strftime("%m%d_%H%M", timetmp.localtime()))
23 |
24 | OUT_DIR = os.path.join(LOG_DIR, MODEL_NAME)
25 | args.dir_checkpoints = os.path.join(OUT_DIR, 'checkpoints')
26 | if not os.path.exists(OUT_DIR): os.mkdir(OUT_DIR)
27 | if not os.path.exists(args.dir_checkpoints):
28 | os.makedirs(args.dir_checkpoints)
29 |
30 | LOG_FOUT = open(os.path.join(OUT_DIR, 'log_%s.csv' %(MODEL_NAME)), 'w')
31 | return MODEL_NAME, OUT_DIR, LOG_FOUT
32 | def log_string(out_str,LOG_FOUT):
33 | LOG_FOUT.write(out_str+'\n')
34 | LOG_FOUT.flush()
35 |
36 | def train():
37 | exp_name,Log_dir,LOG_FOUT = setFolders(args)
38 | writer = SummaryWriter(log_dir=Log_dir)
39 | log_string('EPOCH,avg_cd_l1,avg_cd_l2,Best CDL2[epoch,best_loss],',LOG_FOUT)
40 | logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(os.path.join(Log_dir, 'train.log')),
41 | logging.StreamHandler(sys.stdout)])
42 | logging.info(str(args))
43 | metrics = ['cd_p', 'cd_t', 'f1']
44 | best_epoch_losses = {m: (0, 0) if m == 'f1' else (0, math.inf) for m in metrics}
45 | train_loss_meter = AverageValueMeter()
46 | val_loss_meters = {m: AverageValueMeter() for m in metrics}
47 | #seed
48 | if not args.manual_seed:
49 | seed = random.randint(1,10000)
50 | else:
51 | seed = int(args.manual_seed)
52 | logging.info('Random Seed: %d' % seed)
53 | random.seed(seed)
54 | torch.manual_seed(seed)
55 | if args.distributed:
56 | svrmodel = torch.nn.DataParallel(BaseModel(),device_ids=args.gpus,output_device=args.gpus[0])
57 | svrmodel.to(device)
58 | else:
59 | svrmodel = BaseModel().to(device)
60 |
61 | optimizer,scheduler = build_optimizer(svrmodel,args)
62 | best_cd_l1 = float("inf")
63 | best_cd_l2 = float("inf")
64 | best_f1 = float("inf")
65 | print("Data Uploading...")
66 |
67 | dataloader, dataloader_test = build_dataset(args)
68 |
69 | print("Data Preparation Done...")
70 | loss_function = SimplificationLoss()
71 | print("Loss Function Done...")
72 |
73 |
74 | for epoch in tqdm(range(args.start_epoch,args.nepoch),desc='Training'):
75 | epoch_start_time = time()
76 | total_cd_l1 = 0
77 | total_cd_l2 = 0
78 | train_loss_meter.reset()
79 |
80 | svrmodel.module.train()
81 |
82 | for param_group in optimizer.param_groups:
83 | print(f"Epoch {epoch+1}, Learning Rate: {param_group['lr']}")
84 | n_batches = len(dataloader)
85 | # train
86 | with tqdm(dataloader) as t:
87 | for batch_idx,data in enumerate(t):
88 | optimizer.zero_grad()
89 |
90 | n_itr = epoch * n_batches + batch_idx
91 | # to(cuda)
92 | images = data['image'].to(device)
93 | batch_size = images.shape[0]
94 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device)
95 |
96 | pointclouds = data['points'].to(device)
97 |
98 | pred_points = svrmodel(images,partial_pcs)
99 | # to(cuda)
100 | pred_points = pred_points.to(pointclouds.device)
101 | pred_points = pred_points.transpose(2,1)
102 | net_loss,loss_t=loss_function(pred_points,pointclouds)
103 |
104 | # calculate Chamfer distance loss
105 | net_loss = net_loss.mean()
106 | loss_t = loss_t.mean()
107 |
108 | net_loss_all = net_loss + loss_t
109 |
110 | train_loss_meter.update(net_loss.item())
111 | net_loss_all.backward()
112 | optimizer.step()
113 |
114 | cd_l2_item = torch.sum(loss_t).item() / batch_size * 1e4
115 | total_cd_l2 += cd_l2_item
116 | cd_l1_item = net_loss.item() * 1e4
117 | total_cd_l1 += cd_l1_item
118 |
119 | t.set_description('[Epoch %d/%d][Batch %d/%d]' % (epoch, args.nepoch, batch_idx + 1, n_batches))
120 | t.set_postfix(loss='%s' % ['%.4f' % l for l in [cd_l1_item, cd_l2_item]])
121 | writer.add_scalar('MAE_netloss',net_loss.item(),n_itr)
122 | writer.add_scalar('MAE_losst',loss_t.item(),n_itr)
123 | scheduler.step(epoch)
124 |
125 | avg_cd_l1 = total_cd_l1 / n_batches
126 | avg_cd_l2 = total_cd_l2 / n_batches
127 | epoch_end_time = time()
128 | logging.info('')
129 | logging.info(
130 | exp_name + '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s ' %
131 | (epoch,args.nepoch,epoch_end_time - epoch_start_time,['%.4f' % l for l in [avg_cd_l1,avg_cd_l2]])
132 | )
133 | log_string(f'{epoch}[{batch_idx}],{avg_cd_l1:.4f},{avg_cd_l2:.4f},{total_cd_l2:.4f},{best_epoch_losses},', LOG_FOUT)
134 | writer.add_scalar('MAE_CDL2', avg_cd_l2, epoch)
135 | writer.add_scalar('MAE_CDL1', avg_cd_l1, epoch)
136 |
137 | if epoch % args.epoch_interval_to_val == 0 or epoch == args.nepoch - 1:
138 | best_cd_l1, best_cd_l2 ,best_f1= val(svrmodel,loss_function,epoch,val_loss_meters,dataloader_test,best_epoch_losses,LOG_FOUT,Log_dir,best_cd_l1,best_cd_l2,best_f1)
139 |
140 | def val(net,cal_loss,curr_epoch,val_loss_meters,dataloader_test,best_epoch_losses,LOG_FOUT,log_dir,best_cd_l1,best_cd_l2,best_f1):
141 | val_start_time = time()
142 | metrics_val = ['cd_t']
143 | val_loss_meters = {m: AverageValueMeter() for m in metrics_val}
144 | logging.info('Testing...')
145 | for v in val_loss_meters.values():
146 | v.reset()
147 | net.module.eval()
148 | total_cd_l1 = 0
149 | total_cd_l2 = 0
150 | total_f1 = 0
151 | n_batches = len(dataloader_test)
152 | with torch.no_grad():
153 | with tqdm(dataloader_test) as tt:
154 | for i,data in enumerate(tt):
155 | images = data['image'].to(device)
156 | batch_size = images.shape[0]
157 | gt = data['points'].to(device)
158 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device)
159 | pred_points = net(images,partial_pcs)
160 | pred_points = pred_points.transpose(2,1)
161 |
162 | loss_p, loss_t,f1 = cal_loss(pred_points, gt,calc_f1=True)
163 |
164 | cd_l1_item = torch.sum(loss_p).item() / batch_size * 1e4
165 | cd_l2_item = torch.sum(loss_t).item() / batch_size * 1e4
166 | f1_item = torch.sum(f1).item() / batch_size * 1e4
167 | total_cd_l1 += cd_l1_item
168 | total_cd_l2 += cd_l2_item
169 | total_f1 += f1_item
170 |
171 | avg_cd_l1 = total_cd_l1 / n_batches
172 | avg_cd_l2 = total_cd_l2 / n_batches
173 | avg_f1 = total_f1 / n_batches
174 |
175 | if avg_cd_l1 < best_cd_l1:
176 | best_cd_l1 = avg_cd_l1
177 | save_model_2(str(log_dir) + '/checkpoints/bestl1_network.pth', net)
178 | logging.info("Saving net...")
179 |
180 | if avg_cd_l2 < best_cd_l2:
181 | best_cd_l2 = avg_cd_l2
182 |
183 | if avg_f1 > best_f1:
184 | best_f1 = avg_f1
185 |
186 | log_string('%d,%.2f,%.2f,%.2f,%.2f,%.2f,%.2f'%(curr_epoch, avg_cd_l1, best_cd_l1, avg_cd_l2, best_cd_l2,avg_f1,best_f1), LOG_FOUT)
187 |
188 | val_end_time = time()
189 |
190 | logging.info(
191 | '[Epoch %d/%d] TestTime = %.3f (s) Curr_cdl1 = %s Best_cdl1 = %s Curr_cdl2 = %s Best_cdl2 = %s Curr_f1 = %s Best_f1 = %s' %
192 | (curr_epoch, args.nepoch, val_end_time - val_start_time, avg_cd_l1, best_cd_l1, avg_cd_l2, best_cd_l2, avg_f1, best_f1))
193 |
194 | return best_cd_l1, best_cd_l2 , best_f1
195 |
196 | if __name__ == '__main__':
197 | parser = argparse.ArgumentParser(description='Train config file')
198 | parser.add_argument('-c', '--config', help='path to config file', required=True)
199 | parser.add_argument('-gpu', '--gpu_id', help='gpu_id', required=True)
200 | arg = parser.parse_args()
201 | config_path = arg.config
202 | args = munch.munchify(yaml.safe_load(open(config_path)))
203 |
204 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
205 | os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu_id) #os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5,6,7'
206 | print('Using gpu:' + str(arg.gpu_id))
207 | sphere_points = get_spherepoints(args.number_points,0.5)
208 | device = torch.device(args.gpus[0] if torch.cuda.is_available() else 'cpu')
209 | print('Number of points:' + str(args.number_points))
210 |
211 | train()
--------------------------------------------------------------------------------
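A standalone sketch of how the training loop above tiles the fixed sphere template into a per-image batch before feeding it to the model; the random array stands in for get_spherepoints(args.number_points, 0.5) and the sizes are assumptions:

import numpy as np
import torch

sphere_points = np.random.randn(2048, 3).astype(np.float32)   # stand-in for get_spherepoints(2048, 0.5)
batch_size = 4

partial_pcs = (torch.tensor(sphere_points)
               .to(torch.float32)
               .unsqueeze(0)               # (1, N, 3)
               .repeat(batch_size, 1, 1))  # (B, N, 3): the same template paired with every image
print(partial_pcs.shape)                   # torch.Size([4, 2048, 3])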
/src/train/train_ltp.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from dataset_svr.trainer_text_dataset import *
3 | import torch
4 | import torch.nn.functional as F
5 | from models.ULIP_models import get_loss_v2
6 | import sys
7 | import time
8 | from network.ltp_model import *
9 | from network.ltp_model import TextAlignPCModel
10 | from models.ULIP_utils import *
11 | from dataset_svr.trainer_dataset import build_dataset
12 | import os, argparse, munch, yaml
13 | import logging
14 | import time as timetmp
15 |
16 |
17 | def setFolders(args):
18 | LOG_DIR = args.dir_outpath
19 | MODEL_NAME = '%s-%s'%(args.model, timetmp.strftime("%m%d_%H%M", timetmp.localtime()))
20 |
21 | OUT_DIR = os.path.join(LOG_DIR, MODEL_NAME)
22 | args.dir_checkpoints = os.path.join(OUT_DIR, 'checkpoints')
23 | if not os.path.exists(OUT_DIR): os.mkdir(OUT_DIR)
24 | if not os.path.exists(args.dir_checkpoints):
25 | os.makedirs(args.dir_checkpoints)
26 |
27 | LOG_FOUT = open(os.path.join(OUT_DIR, 'log_%s.csv' %(MODEL_NAME)), 'w')
28 | return MODEL_NAME, OUT_DIR, args.dir_checkpoints
29 | def log_string(out_str,LOG_FOUT):
30 | LOG_FOUT.write(out_str+'\n')
31 | LOG_FOUT.flush()
32 | def save_model(path, net, net_d=None):
33 | if net_d is not None:
34 | torch.save({'net_state_dict': net.state_dict(),
35 | 'D_state_dict': net_d.state_dict()}, path)
36 | else:
37 | torch.save({'net_state_dict': net.state_dict()}, path)
38 | def main(args):
39 | best_cos = float(0)
40 | exp_name,Log_dir,Check_FOUT = setFolders(args)
41 | logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(os.path.join(Log_dir, 'train.log')),
42 | logging.StreamHandler(sys.stdout)])
43 | logging.info(str(args))
44 | dataloader_train, dataloader_test = build_dataset(args)
45 | # create model
46 | print("=> creating model: {}".format(args.model))
47 | model = TextAlignPCModel(args).to(device)
48 |
49 | criterion = get_loss_v2(args).to(device)
50 | optimizer,scheduler = build_optimizer(model,args)
51 | print("=> beginning training")
52 | best_epoch = -1
53 | for epoch in range(args.start_epoch, args.epochs):
54 | train_stats = train(dataloader_train,model,criterion,optimizer,epoch,scheduler,args)
55 | logging.info('')
56 | logging.info(f'Training result: {train_stats}')
57 | logging.info('Testing...')
58 | val_stats = {"cos@t&p":-1}
59 |
60 | if epoch % 1 == 0:
61 | val_stats = val(dataloader_test,model,args)
62 | cos = val_stats["cos@t&p"]
63 | print(f'val_stats:{val_stats}')
64 |
65 | is_best = cos > best_cos
66 | if is_best:
67 | best_epoch = epoch
68 |
69 | best_cos = max(cos,best_cos)
70 | if is_best :
71 | print("=> saving checkpoint")
72 | save_model(str(Check_FOUT)+'/prompt.pth',model.text_prompt_embeddings)
73 | save_model(str(Check_FOUT)+'/text_encoder.pth',model.text_encoder)
74 | logging.info("Saving net...")
75 | if epoch + 1 == args.epochs:
76 | print("=> saving last checkpoint")
77 | save_model(str(Check_FOUT)+'/last_prompt.pth',model.text_prompt_embeddings)
78 | save_model(str(Check_FOUT)+'/last_text_encoder.pth',model.text_encoder)
79 |
80 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
81 | **{f'test_{k}': v for k, v in val_stats.items()},
82 | 'epoch': epoch,
83 | 'best_cos@t&p': best_cos,
84 | 'best_epoch': best_epoch}
85 | logging.info(f'log_stats: {log_stats}')
86 |
87 | def train(train_loader, model, criterion, optimizer, epoch, scheduler, args):
88 | n_batches = len(train_loader)
89 | # switch to train mode
90 | total_loss = 0
91 | model.train()
92 | end = time.time()
93 | with tqdm(train_loader) as t:
94 | for data_iter,data in enumerate(t):
95 | optimizer.zero_grad()
96 | # images
97 | images = data['image'].to(device)
98 | # text
99 | category = data['category']
100 | category_indices = [int(c) for c in category]
101 | text_labels = torch.tensor(category_indices).to(device)
102 | # pointclouds
103 | pc = data['points'].to(device)
104 |
105 | outputs = model(images,text_labels,pc)
106 | loss_dict = criterion(outputs)
107 | loss = loss_dict['loss']
108 |
109 | loss.backward()
110 | optimizer.step()
111 | total_loss += loss.item()
112 | get_model(model).logit_scale.data.clamp_(0, 4.6052)
113 | logit_scale = get_model(model).logit_scale.exp().item()
114 | avg_loss = total_loss / n_batches
115 | scheduler.step(epoch)
116 | return {**{'loss':avg_loss},
117 | 'lr': optimizer.param_groups[0]['lr'],
118 | 'logit_scale': logit_scale}
119 | def val(test_loader,model,args=None):
120 | batch_time = AverageMeter('Time', ':6.3f')
121 | total = 0
122 | avg = 0
123 | model.eval()
124 | n_batches = len(test_loader)
125 |
126 | print('==>encoding captions')
127 | with torch.no_grad():
128 | end = time.time()
129 | for i,data in enumerate(test_loader):
130 | # image
131 | images = data['image'].to(device)
132 | batch_size = images.shape[0]
133 | # category
134 | category = data['category']
135 | category_indices = [int(c) for c in category]
136 | text_labels = torch.tensor(category_indices).to(device)
137 | # pointclouds
138 | gt = data['points'].to(device)
139 | # model
140 | outputs = model(images,text_labels,gt)
141 | text_embed = outputs['text_embed']
142 | gt_embed = outputs['pc_embed']
143 |
144 | text_embed = F.normalize(text_embed, dim=-1, p=2)
145 | gt_embed = F.normalize(gt_embed, dim=-1, p=2)
146 |
147 | logits_per_pc_text = text_embed @ gt_embed.t()
148 |
149 | cos_sim_1 = torch.diag(logits_per_pc_text).mean()
150 |
151 | total += cos_sim_1
152 |
153 | avg = total / n_batches
154 |
155 | batch_time.update(time.time() - end)
156 | end = time.time()
157 | logging.info(f"batch_time {batch_time}, 'cos@t&p':{avg}")
158 | return {'cos@t&p':avg}
159 |
160 | if __name__ =='__main__':
161 | parser = argparse.ArgumentParser(description='Train config file')
162 | parser.add_argument('-c', '--config', help='path to config file', required=True)
163 | parser.add_argument('-gpu', '--gpu_id', help='gpu_id', required=True)
164 | arg = parser.parse_args()
165 | config_path = arg.config
166 | args = munch.munchify(yaml.safe_load(open(config_path)))
167 | device = torch.device(args.gpu_id if torch.cuda.is_available() else 'cpu')
168 | print('Using gpu:' + str(arg.gpu_id))
169 |
170 | main(args)
--------------------------------------------------------------------------------
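The val() routine above reduces to the mean cosine similarity between matched text and point-cloud embeddings; a standalone sketch on random embeddings (batch size and width are assumptions):

import torch
import torch.nn.functional as F

B, D = 8, 512
text_embed = F.normalize(torch.randn(B, D), dim=-1, p=2)
pc_embed = F.normalize(torch.randn(B, D), dim=-1, p=2)

# diagonal of the similarity matrix = cosine similarity of each matched (text, point cloud) pair
cos_per_pair = torch.diag(text_embed @ pc_embed.t())
print(cos_per_pair.mean().item())   # the per-batch "cos@t&p" value that val() accumulates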
/src/train/train_pro.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import munch
4 | import yaml
5 | import sys
6 | from network.pro_model import ProModel
7 | from network.pro_model import AverageValueMeter,build_optimizer,save_model_2
8 | from time import time
9 | from tqdm.auto import tqdm
10 | import time as timetmp
11 | import argparse
12 | import logging
13 | import random
14 | import math
15 | from loss.cdloss import SimplificationLoss
16 | from dataset_svr.trainer_dataset import build_dataset,get_spherepoints
17 |
18 | def setFolders(args):
19 | LOG_DIR = args.dir_outpath
20 | MODEL_NAME = '%s-%s'%(args.model_name, timetmp.strftime("%m%d_%H%M", timetmp.localtime()))
21 |
22 | OUT_DIR = os.path.join(LOG_DIR, MODEL_NAME)
23 | args.dir_checkpoints = os.path.join(OUT_DIR, 'checkpoints')
24 | if not os.path.exists(OUT_DIR): os.mkdir(OUT_DIR)
25 | if not os.path.exists(args.dir_checkpoints):
26 | os.makedirs(args.dir_checkpoints)
27 |
28 | LOG_FOUT = open(os.path.join(OUT_DIR, 'log_%s.csv' %(MODEL_NAME)), 'w')
29 | return MODEL_NAME, OUT_DIR, LOG_FOUT
30 | def log_string(out_str,LOG_FOUT):
31 | LOG_FOUT.write(out_str+'\n')
32 | LOG_FOUT.flush()
33 |
34 |
35 | def train():
36 | exp_name,Log_dir,LOG_FOUT = setFolders(args)
37 | log_string('EPOCH,avg_cd_l1,avg_cd_l2,Best CDL2[epoch,best_loss],',LOG_FOUT)
38 | logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(os.path.join(Log_dir, 'train.log')),
39 | logging.StreamHandler(sys.stdout)])
40 | logging.info(str(args))
41 |
42 | metrics = ['cd_p', 'cd_t', 'f1']
43 | best_epoch_losses = {m: (0, 0) if m == 'f1' else (0, math.inf) for m in metrics}
44 | train_loss_meter = AverageValueMeter()
45 | val_loss_meters = {m: AverageValueMeter() for m in metrics}
46 | #seed
47 | if not args.manual_seed:
48 | seed = random.randint(1,10000)
49 | else:
50 | seed = int(args.manual_seed)
51 | logging.info('Random Seed: %d' % seed)
52 | random.seed(seed)
53 | torch.manual_seed(seed)
54 |
55 | if args.distributed:
56 | promodel = torch.nn.DataParallel(ProModel().to(device),device_ids=args.gpus,output_device=args.gpus[0])
57 | else:
58 | promodel = ProModel().to(device)
59 | optimizer,scheduler = build_optimizer(promodel,args)
60 | best_cd_l1 = float("inf")
61 | best_cd_l2 = float("inf")
62 | best_f1 = float("inf")
63 | print("Data Uploading...")
64 |
65 | dataloader, dataloader_test = build_dataset(args)
66 | print("Data Preparation Done...")
67 | loss_function = SimplificationLoss()
68 | print("Loss Function Done...")
69 |
70 | for epoch in tqdm(range(args.start_epoch,args.nepoch),desc='Training'):
71 | # time,loss
72 | epoch_start_time = time()
73 | total_cd_l1 = 0
74 | total_cd_l2 = 0
75 | train_loss_meter.reset()
76 |
77 | promodel.module.train()
78 |
79 | for param_group in optimizer.param_groups:
80 | print(f"Epoch {epoch+1}, Learning Rate: {param_group['lr']}")
81 | n_batches = len(dataloader)
82 | # train
83 | with tqdm(dataloader) as t:
84 | for batch_idx,data in enumerate(t):
85 | optimizer.zero_grad()
86 |
87 | n_itr = epoch * n_batches + batch_idx
88 | # to(cuda)
89 | images = data['image'].to(device)
90 | batch_size = images.shape[0]
91 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device)
92 | # partial_pcs = partial_pcs.to(device)
93 | pointclouds = data['points'].to(device)
94 | # category
95 | category = data['category']
96 | category_indices = [int(c) for c in category]
97 | text_labels = torch.tensor(category_indices).to(device)
98 |
99 | pred_points = promodel(images,partial_pcs,text_labels)
100 | # to(cuda)
101 | pred_points = pred_points.to(pointclouds.device)
102 | pred_points = pred_points.transpose(2,1)
103 | net_loss,loss_t=loss_function(pred_points,pointclouds)
104 |
105 | # calculate Chamfer distance loss
106 | net_loss = net_loss.mean()
107 | loss_t = loss_t.mean()
108 |
109 | # net_loss_all = net_loss + loss_t
110 | net_loss_all = net_loss
111 | train_loss_meter.update(net_loss.item())
112 | net_loss_all.backward()
113 | optimizer.step()
114 |
115 | cd_l2_item = torch.sum(loss_t).item() / batch_size * 1e4
116 | total_cd_l2 += cd_l2_item
117 | cd_l1_item = net_loss.item() * 1e4
118 | total_cd_l1 += cd_l1_item
119 |
120 | t.set_description('[Epoch %d/%d][Batch %d/%d]' % (epoch, args.nepoch, batch_idx + 1, n_batches))
121 | t.set_postfix(loss='%s' % ['%.4f' % l for l in [cd_l1_item, cd_l2_item]])
122 | scheduler.step(epoch) # CosLR,
123 |
124 | avg_cd_l1 = total_cd_l1 / n_batches
125 | avg_cd_l2 = total_cd_l2 / n_batches
126 | epoch_end_time = time()
127 | logging.info('')
128 | logging.info(
129 | exp_name + '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s ' %
130 | (epoch,args.nepoch,epoch_end_time - epoch_start_time,['%.4f' % l for l in [avg_cd_l1,avg_cd_l2]])
131 | )
132 | log_string(f'{epoch}[{batch_idx}],{avg_cd_l1:.4f},{avg_cd_l2:.4f},{total_cd_l2:.4f},{best_epoch_losses},', LOG_FOUT)
133 | if epoch % args.epoch_interval_to_val == 0 or epoch == args.nepoch - 1:
134 | best_cd_l1, best_cd_l2 ,best_f1= val(promodel,loss_function,epoch,val_loss_meters,dataloader_test,best_epoch_losses,LOG_FOUT,Log_dir,best_cd_l1,best_cd_l2,best_f1)
135 |
136 | def val(net,cal_loss,curr_epoch,val_loss_meters,dataloader_test,best_epoch_losses,LOG_FOUT,log_dir,best_cd_l1,best_cd_l2,best_f1):
137 | val_start_time = time()
138 | metrics_val = ['cd_t']
139 | val_loss_meters = {m: AverageValueMeter() for m in metrics_val}
140 | logging.info('Testing...')
141 | for v in val_loss_meters.values():
142 | v.reset()
143 | net.module.eval()
144 |
145 | total_cd_l1 = 0
146 | total_cd_l2 = 0
147 | total_f1 = 0
148 | n_batches = len(dataloader_test)
149 | with torch.no_grad():
150 | with tqdm(dataloader_test) as tt:
151 | for i,data in enumerate(tt):
152 | images = data['image'].to(device)
153 | batch_size = images.shape[0]
154 | gt = data['points'].to(device)
155 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device)
156 | category = data['category']
157 | category_indices = [int(c) for c in category]
158 | text_labels = torch.tensor(category_indices).to(device)
159 |
160 | pred_points = net(images,partial_pcs,text_labels)
161 | pred_points = pred_points.transpose(2,1)
162 |
163 | loss_p, loss_t,f1 = cal_loss(pred_points, gt,calc_f1=True)
164 |
165 | cd_l1_item = torch.sum(loss_p).item() / batch_size * 1e4
166 | cd_l2_item = torch.sum(loss_t).item() / batch_size * 1e4
167 | f1_item = torch.sum(f1).item() / batch_size * 1e4
168 | total_cd_l1 += cd_l1_item
169 | total_cd_l2 += cd_l2_item
170 | total_f1 += f1_item
171 |
172 | avg_cd_l1 = total_cd_l1 / n_batches
173 | avg_cd_l2 = total_cd_l2 / n_batches
174 | avg_f1 = total_f1 / n_batches
175 |
176 | if avg_cd_l1 < best_cd_l1:
177 | best_cd_l1 = avg_cd_l1
178 | save_model_2(str(log_dir) + '/checkpoints/bestl1_network.pth', net)
179 | logging.info("Saving net...")
180 | if avg_cd_l2 < best_cd_l2:
181 | best_cd_l2 = avg_cd_l2
182 | if avg_f1 > best_f1:
183 | best_f1 = avg_f1
184 |
185 | log_string('%d,%.2f,%.2f,%.2f,%.2f,%.2f,%.2f'%(curr_epoch, avg_cd_l1, best_cd_l1, avg_cd_l2, best_cd_l2,avg_f1,best_f1), LOG_FOUT)
186 |
187 | val_end_time = time()
188 |
189 | logging.info(
190 | '[Epoch %d/%d] TestTime = %.3f (s) Curr_cdl1 = %s Best_cdl1 = %s Curr_cdl2 = %s Best_cdl2 = %s Curr_f1 = %s Best_f1 = %s' %
191 | (curr_epoch, args.nepoch, val_end_time - val_start_time, avg_cd_l1, best_cd_l1, avg_cd_l2, best_cd_l2, avg_f1, best_f1))
192 |
193 | return best_cd_l1, best_cd_l2 , best_f1
194 |
195 | if __name__ == '__main__':
196 | parser = argparse.ArgumentParser(description='Train config file')
197 | parser.add_argument('-c', '--config', help='path to config file', required=True)
198 | parser.add_argument('-gpu', '--gpu_id', help='gpu_id', required=True)
199 | arg = parser.parse_args()
200 | config_path = arg.config
201 | args = munch.munchify(yaml.safe_load(open(config_path)))
202 |
203 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
204 | os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu_id) #os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5,6,7'
205 | print('Using gpu:' + str(arg.gpu_id))
206 | sphere_points = get_spherepoints(args.number_points,0.5)
207 | device = torch.device(args.gpus[0] if torch.cuda.is_available() else 'cpu')
208 | print('Number of points:' + str(args.number_points))
209 |
210 | train()
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QINGQINGLE/MESC-3D/cfef548f06951ecb112ee15513e184ff962d4e7a/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | import copy
5 | import logging
6 | import os
7 | from collections import defaultdict
8 | import torch
9 | import torch.nn as nn
10 |
11 | from typing import Any
12 | from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable
13 |
14 | from termcolor import colored
15 |
16 | def get_missing_parameters_message(keys: List[str]) -> str:
17 | """
18 | Get a logging-friendly message to report parameter names (keys) that are in
19 | the model but not found in a checkpoint.
20 | Args:
21 | keys (list[str]): List of keys that were not found in the checkpoint.
22 | Returns:
23 | str: message.
24 | """
25 | groups = _group_checkpoint_keys(keys)
26 | msg = "Some model parameters or buffers are not found in the checkpoint:\n"
27 | msg += "\n".join(
28 | " " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
29 | )
30 | return msg
31 |
32 |
33 | def get_unexpected_parameters_message(keys: List[str]) -> str:
34 | """
35 | Get a logging-friendly message to report parameter names (keys) that are in
36 | the checkpoint but not found in the model.
37 | Args:
38 | keys (list[str]): List of keys that were not found in the model.
39 | Returns:
40 | str: message.
41 | """
42 | groups = _group_checkpoint_keys(keys)
43 | msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
44 | msg += "\n".join(
45 | " " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
46 | )
47 | return msg
48 |
49 |
50 | def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
51 | """
52 | Strip the prefix in metadata, if any.
53 | Args:
54 | state_dict (OrderedDict): a state-dict to be loaded to the model.
55 | prefix (str): prefix.
56 | """
57 | keys = sorted(state_dict.keys())
58 | if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
59 | return
60 |
61 | for key in keys:
62 | newkey = key[len(prefix):]
63 | state_dict[newkey] = state_dict.pop(key)
64 |
65 | # also strip the prefix in metadata, if any..
66 | try:
67 | metadata = state_dict._metadata # pyre-ignore
68 | except AttributeError:
69 | pass
70 | else:
71 | for key in list(metadata.keys()):
72 | # for the metadata dict, the key can be:
73 | # '': for the DDP module, which we want to remove.
74 | # 'module': for the actual model.
75 | # 'module.xx.xx': for the rest.
76 |
77 | if len(key) == 0:
78 | continue
79 | newkey = key[len(prefix):]
80 | metadata[newkey] = metadata.pop(key)
81 |
82 |
83 | def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
84 | """
85 | Group keys based on common prefixes. A prefix is the string up to the final
86 | "." in each key.
87 | Args:
88 | keys (list[str]): list of parameter names, i.e. keys in the model
89 | checkpoint dict.
90 | Returns:
91 | dict[list]: keys with common prefixes are grouped into lists.
92 | """
93 | groups = defaultdict(list)
94 | for key in keys:
95 | pos = key.rfind(".")
96 | if pos >= 0:
97 | head, tail = key[:pos], [key[pos + 1:]]
98 | else:
99 | head, tail = key, []
100 | groups[head].extend(tail)
101 | return groups
102 |
103 |
104 | def _group_to_str(group: List[str]) -> str:
105 | """
106 | Format a group of parameter name suffixes into a loggable string.
107 | Args:
108 | group (list[str]): list of parameter name suffixes.
109 | Returns:
110 | str: formatted string.
111 | """
112 | if len(group) == 0:
113 | return ""
114 |
115 | if len(group) == 1:
116 | return "." + group[0]
117 |
118 | return ".{" + ", ".join(group) + "}"
119 |
120 |
121 | def _named_modules_with_dup(
122 | model: nn.Module, prefix: str = ""
123 | ) -> Iterable[Tuple[str, nn.Module]]:
124 | """
125 | The same as `model.named_modules()`, except that it includes
126 | duplicated modules that have more than one name.
127 | """
128 | yield prefix, model
129 | for name, module in model._modules.items(): # pyre-ignore
130 | if module is None:
131 | continue
132 | submodule_prefix = prefix + ("." if prefix else "") + name
133 | yield from _named_modules_with_dup(module, submodule_prefix)
--------------------------------------------------------------------------------
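A quick sketch of what the grouping helpers above produce when reporting missing/unexpected keys (the import path follows the src/utils layout and is an assumption):

from utils.checkpoint import _group_checkpoint_keys, _group_to_str

keys = ['blocks.0.attn.qkv.weight', 'blocks.0.attn.qkv.bias', 'cls_token']
groups = _group_checkpoint_keys(keys)   # {'blocks.0.attn.qkv': ['weight', 'bias'], 'cls_token': []}
for head, tails in groups.items():
    print(head + _group_to_str(tails))
# blocks.0.attn.qkv.{weight, bias}
# cls_token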
/src/utils/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from easydict import EasyDict
3 | import os
4 | from .logger import print_log
5 |
6 | def log_args_to_file(args, pre='args', logger=None):
7 | for key, val in args.__dict__.items():
8 | print_log(f'{pre}.{key} : {val}', logger = logger)
9 |
10 | def log_config_to_file(cfg, pre='cfg', logger=None):
11 | for key, val in cfg.items():
12 | if isinstance(cfg[key], EasyDict):
13 | print_log(f'{pre}.{key} = edict()', logger = logger)
14 | log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger)
15 | continue
16 | print_log(f'{pre}.{key} : {val}', logger = logger)
17 |
18 | def merge_new_config(config, new_config):
19 | for key, val in new_config.items():
20 | if not isinstance(val, dict):
21 | if key == '_base_':
22 | with open(new_config['_base_'], 'r') as f:
23 | try:
24 | val = yaml.load(f, Loader=yaml.FullLoader)
25 | except:
26 | val = yaml.load(f)
27 | config[key] = EasyDict()
28 | merge_new_config(config[key], val)
29 | else:
30 | config[key] = val
31 | continue
32 | if key not in config:
33 | config[key] = EasyDict()
34 | merge_new_config(config[key], val)
35 | return config
36 |
37 | def cfg_from_yaml_file(cfg_file):
38 | config = EasyDict()
39 | with open(cfg_file, 'r') as f:
40 | try:
41 | new_config = yaml.load(f, Loader=yaml.FullLoader)
42 | except:
43 | new_config = yaml.load(f)
44 | merge_new_config(config=config, new_config=new_config)
45 | return config
46 |
47 | def get_config(args, logger=None):
48 | if args.resume:
49 | cfg_path = os.path.join(args.experiment_path, 'config.yaml')
50 | if not os.path.exists(cfg_path):
51 | print_log("Failed to resume", logger = logger)
52 | raise FileNotFoundError()
53 | print_log(f'Resume yaml from {cfg_path}', logger = logger)
54 | args.config = cfg_path
55 | config = cfg_from_yaml_file(args.config)
56 | if not args.resume and args.local_rank == 0:
57 | save_experiment_config(args, config, logger)
58 | return config
59 |
60 | def save_experiment_config(args, config, logger = None):
61 | config_path = os.path.join(args.experiment_path, 'config.yaml')
62 | os.system('cp %s %s' % (args.config, config_path))
63 | print_log(f'Copy the Config file from {args.config} to {config_path}',logger = logger )
--------------------------------------------------------------------------------
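A runnable sketch of the `_base_` mechanism handled by merge_new_config above: the referenced base yaml is loaded and merged under cfg._base_ rather than inlined at the top level (the temporary file names and values here are hypothetical, and the import path per the src/utils layout is an assumption):

import os
import tempfile

import yaml
from utils.config import cfg_from_yaml_file

with tempfile.TemporaryDirectory() as d:
    base_path = os.path.join(d, 'base.yaml')
    child_path = os.path.join(d, 'child.yaml')
    with open(base_path, 'w') as f:
        yaml.safe_dump({'lr': 0.001, 'epochs': 300}, f)
    with open(child_path, 'w') as f:
        yaml.safe_dump({'_base_': base_path, 'epochs': 100}, f)

    cfg = cfg_from_yaml_file(child_path)
    print(cfg.epochs)        # 100   (key set directly in the child file)
    print(cfg._base_.lr)     # 0.001 (base file ends up nested under cfg._base_)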
/src/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.distributed as dist
3 |
4 | logger_initialized = {}
5 |
6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
7 | """Get root logger and add a keyword filter to it.
8 | The logger will be initialized if it has not been initialized. By default a
9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
10 | also be added. The name of the root logger is the top-level package name,
11 | e.g., "mmdet3d".
12 | Args:
13 | log_file (str, optional): File path of log. Defaults to None.
14 | log_level (int, optional): The level of logger.
15 | Defaults to logging.INFO.
16 | name (str, optional): The name of the root logger, also used as a
17 | filter keyword. Defaults to 'main'.
18 | Returns:
19 | :obj:`logging.Logger`: The obtained logger
20 | """
21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level)
22 | # add a logging filter
23 | logging_filter = logging.Filter(name)
24 | logging_filter.filter = lambda record: record.find(name) != -1
25 |
26 | return logger
27 |
28 |
29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
30 | """Initialize and get a logger by name.
31 | If the logger has not been initialized, this method will initialize the
32 | logger by adding one or two handlers, otherwise the initialized logger will
33 | be directly returned. During initialization, a StreamHandler will always be
34 | added. If `log_file` is specified and the process rank is 0, a FileHandler
35 | will also be added.
36 | Args:
37 | name (str): Logger name.
38 | log_file (str | None): The log filename. If specified, a FileHandler
39 | will be added to the logger.
40 | log_level (int): The logger level. Note that only the process of
41 | rank 0 is affected, and other processes will set the level to
42 | "Error" thus be silent most of the time.
43 | file_mode (str): The file mode used in opening log file.
44 | Defaults to 'w'.
45 | Returns:
46 | logging.Logger: The expected logger.
47 | """
48 | logger = logging.getLogger(name)
49 | if name in logger_initialized:
50 | return logger
51 | # handle hierarchical names
52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the
53 | # initialization since it is a child of "a".
54 | for logger_name in logger_initialized:
55 | if name.startswith(logger_name):
56 | return logger
57 |
58 | # handle duplicate logs to the console
59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
60 | # to the root logger. As logger.propagate is True by default, this root
61 | # level handler causes logging messages from rank>0 processes to
62 | # unexpectedly show up on the console, creating much unwanted clutter.
63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log
64 | # at the ERROR level.
65 | for handler in logger.root.handlers:
66 | if type(handler) is logging.StreamHandler:
67 | handler.setLevel(logging.ERROR)
68 |
69 | stream_handler = logging.StreamHandler()
70 | handlers = [stream_handler]
71 |
72 | if dist.is_available() and dist.is_initialized():
73 | rank = dist.get_rank()
74 | else:
75 | rank = 0
76 |
77 | # only rank 0 will add a FileHandler
78 | if rank == 0 and log_file is not None:
79 | # Here, the default behaviour of the official logger is 'a'. Thus, we
80 | # provide an interface to change the file mode to the default
81 | # behaviour.
82 | file_handler = logging.FileHandler(log_file, file_mode)
83 | handlers.append(file_handler)
84 |
85 | formatter = logging.Formatter(
86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
87 | for handler in handlers:
88 | handler.setFormatter(formatter)
89 | handler.setLevel(log_level)
90 | logger.addHandler(handler)
91 |
92 | if rank == 0:
93 | logger.setLevel(log_level)
94 | else:
95 | logger.setLevel(logging.ERROR)
96 |
97 | logger_initialized[name] = True
98 |
99 |
100 | return logger
101 |
102 |
103 | def print_log(msg, logger=None, level=logging.INFO):
104 | """Print a log message.
105 | Args:
106 | msg (str): The message to be logged.
107 | logger (logging.Logger | str | None): The logger to be used.
108 | Some special loggers are:
109 | - "silent": no message will be printed.
110 | - other str: the logger obtained with `get_logger(logger)`.
111 | - None: The `print()` method will be used to print log messages.
112 | level (int): Logging level. Only used when `logger` is a Logger
113 | object or a logger name (str).
114 | """
115 | if logger is None:
116 | print(msg)
117 | elif isinstance(logger, logging.Logger):
118 | logger.log(level, msg)
119 | elif logger == 'silent':
120 | pass
121 | elif isinstance(logger, str):
122 | _logger = get_logger(logger)
123 | _logger.log(level, msg)
124 | else:
125 | raise TypeError(
126 | 'logger should be either a logging.Logger object, str, '
127 | f'"silent" or None, but got {type(logger)}')
--------------------------------------------------------------------------------
/src/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from mpl_toolkits.mplot3d import Axes3D
4 | import random
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import os
9 | from collections import abc
10 | from pointnet2_ops import pointnet2_utils
11 |
12 |
13 | def fps(data, number):
14 | '''
15 | data: (B, N, 3) CUDA float tensor
16 | number: int, number of points to keep after farthest point sampling
17 | '''
18 | fps_idx = pointnet2_utils.furthest_point_sample(data, number)
19 | fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
20 | return fps_data
21 |
22 |
23 | def worker_init_fn(worker_id):
24 | np.random.seed(np.random.get_state()[1][0] + worker_id)
25 |
26 | def build_lambda_sche(opti, config):
27 | if config.get('decay_step') is not None:
28 | lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)
29 | scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)
30 | else:
31 | raise NotImplementedError()
32 | return scheduler
33 |
34 | def build_lambda_bnsche(model, config):
35 | if config.get('decay_step') is not None:
36 | bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)
37 | bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)
38 | else:
39 | raise NotImplementedError()
40 | return bnm_scheduler
41 |
42 | def set_random_seed(seed, deterministic=False):
43 | """Set random seed.
44 | Args:
45 | seed (int): Seed to be used.
46 | deterministic (bool): Whether to set the deterministic option for
47 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
48 | to True and `torch.backends.cudnn.benchmark` to False.
49 | Default: False.
50 |
51 | Note:
52 | This is a speed vs. reproducibility tradeoff, see
53 | https://pytorch.org/docs/stable/notes/randomness.html
54 | With ``deterministic=True``, cuDNN runs slower but results are
55 | reproducible (``cudnn.deterministic = True``,
56 | ``cudnn.benchmark = False``); with ``deterministic=False`` it runs
57 | faster but is less reproducible (``cudnn.deterministic = False``,
58 | ``cudnn.benchmark = True``).
59 | """
60 | random.seed(seed)
61 | np.random.seed(seed)
62 | torch.manual_seed(seed)
63 | torch.cuda.manual_seed_all(seed)
64 | if deterministic:
65 | torch.backends.cudnn.deterministic = True
66 | torch.backends.cudnn.benchmark = False
67 |
68 |
69 | def is_seq_of(seq, expected_type, seq_type=None):
70 | """Check whether it is a sequence of some type.
71 | Args:
72 | seq (Sequence): The sequence to be checked.
73 | expected_type (type): Expected type of sequence items.
74 | seq_type (type, optional): Expected sequence type.
75 | Returns:
76 | bool: Whether the sequence is valid.
77 | """
78 | if seq_type is None:
79 | exp_seq_type = abc.Sequence
80 | else:
81 | assert isinstance(seq_type, type)
82 | exp_seq_type = seq_type
83 | if not isinstance(seq, exp_seq_type):
84 | return False
85 | for item in seq:
86 | if not isinstance(item, expected_type):
87 | return False
88 | return True
89 |
90 |
91 | def set_bn_momentum_default(bn_momentum):
92 | def fn(m):
93 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
94 | m.momentum = bn_momentum
95 | return fn
96 |
97 | class BNMomentumScheduler(object):
98 |
99 | def __init__(
100 | self, model, bn_lambda, last_epoch=-1,
101 | setter=set_bn_momentum_default
102 | ):
103 | if not isinstance(model, nn.Module):
104 | raise RuntimeError(
105 | "Class '{}' is not a PyTorch nn Module".format(
106 | type(model).__name__
107 | )
108 | )
109 |
110 | self.model = model
111 | self.setter = setter
112 | self.lmbd = bn_lambda
113 |
114 | self.step(last_epoch + 1)
115 | self.last_epoch = last_epoch
116 |
117 | def step(self, epoch=None):
118 | if epoch is None:
119 | epoch = self.last_epoch + 1
120 |
121 | self.last_epoch = epoch
122 | self.model.apply(self.setter(self.lmbd(epoch)))
123 |
124 | def get_momentum(self, epoch=None):
125 | if epoch is None:
126 | epoch = self.last_epoch + 1
127 | return self.lmbd(epoch)
128 |
129 |
130 |
131 | def seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False):
132 | '''
133 | separate point cloud: used to generate an incomplete point cloud with a given number of cropped points.
134 | '''
135 | _,n,c = xyz.shape
136 |
137 | assert n == num_points
138 | assert c == 3
139 | if crop == num_points:
140 | return xyz, None
141 |
142 | INPUT = []
143 | CROP = []
144 | for points in xyz:
145 | if isinstance(crop,list):
146 | num_crop = random.randint(crop[0],crop[1])
147 | else:
148 | num_crop = crop
149 |
150 | points = points.unsqueeze(0)
151 |
152 | if fixed_points is None:
153 | center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda()
154 | else:
155 | if isinstance(fixed_points,list):
156 | fixed_point = random.sample(fixed_points,1)[0]
157 | else:
158 | fixed_point = fixed_points
159 | center = fixed_point.reshape(1,1,3).cuda()
160 |
161 | distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048
162 |
163 | idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048
164 |
165 | if padding_zeros:
166 | input_data = points.clone()
167 | input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0
168 |
169 | else:
170 | input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3
171 |
172 | crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)
173 |
174 | if isinstance(crop,list):
175 | INPUT.append(fps(input_data,2048))
176 | CROP.append(fps(crop_data,2048))
177 | else:
178 | INPUT.append(input_data)
179 | CROP.append(crop_data)
180 |
181 | input_data = torch.cat(INPUT,dim=0)# B N 3
182 | crop_data = torch.cat(CROP,dim=0)# B M 3
183 |
184 | return input_data.contiguous(), crop_data.contiguous()
185 |
186 | def get_ptcloud_img(ptcloud,roll,pitch):
187 | fig = plt.figure(figsize=(8, 8))
188 |
189 | x, z, y = ptcloud.transpose(1, 0)
190 | ax = fig.add_subplot(projection=Axes3D.name, adjustable='box')  # fig.gca(projection=...) was removed in recent Matplotlib
191 | ax.axis('off')
192 | # ax.axis('scaled')
193 | ax.view_init(roll,pitch)
194 | pts_max, pts_min = np.max(ptcloud), np.min(ptcloud)  # avoid shadowing the built-ins max/min
195 | ax.set_xbound(pts_min, pts_max)
196 | ax.set_ybound(pts_min, pts_max)
197 | ax.set_zbound(pts_min, pts_max)
198 | ax.scatter(x, y, z, zdir='z', c=y, cmap='jet')
199 |
200 | fig.canvas.draw()
201 | img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated
202 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
203 | return img
204 |
205 |
206 |
207 | def visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y',
208 | xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ):
209 | fig = plt.figure(figsize=(6*len(data_list),6))
210 | cmax = data_list[-1][:,0].max()
211 |
212 | for i in range(len(data_list)):
213 | data = data_list[i][:-2048] if i == 1 else data_list[i]
214 | color = data[:,0] /cmax
215 | ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d')
216 | ax.view_init(30, -120)
217 | b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color,vmin=-1,vmax=1 ,cmap = cmap[0],s=4,linewidth=0.05, edgecolors = 'black')
218 | ax.set_title(titles[i])
219 |
220 | ax.set_axis_off()
221 | ax.set_xlim(xlim)
222 | ax.set_ylim(ylim)
223 | ax.set_zlim(zlim)
224 | plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)
225 | if not os.path.exists(path):
226 | os.makedirs(path)
227 |
228 | pic_path = path + '.png'
229 | fig.savefig(pic_path)
230 |
231 | np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())
232 | np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())
233 | plt.close(fig)
234 |
235 |
236 | def random_dropping(pc, e):
237 | up_num = max(64, 768 // (e//50 + 1))
238 | # keep a random FPS-sampled subset, then zero-pad back to 2048 points
239 | random_num = torch.randint(1, up_num, (1,1))[0,0]
240 | pc = fps(pc, random_num)
241 | padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)
242 | pc = torch.cat([pc, padding], dim = 1)
243 | return pc
244 |
245 |
246 | def random_scale(partial, scale_range=[0.8, 1.2]):
247 | scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]
248 | return partial * scale
249 |
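250 | # Illustrative sketch of how the helpers above are typically called (shapes and
251 | # config keys inferred from the code; `pc`, `optimizer`, `model`, `cfg` are placeholders):
252 | #   pc_1024 = fps(pc, 1024)                        # pc: (B, N, 3) CUDA tensor -> (B, 1024, 3)
253 | #   scheduler = build_lambda_sche(optimizer, cfg)  # cfg: decay_step, lr_decay, lowest_decay
254 | #   bn_sched = build_lambda_bnsche(model, cfg)     # cfg: bn_momentum, bn_decay, decay_step, lowest_decay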
--------------------------------------------------------------------------------
/src/utils/registry.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import warnings
3 | from functools import partial
4 | from utils import config,misc
5 |
6 | class Registry:
7 | """A registry to map strings to classes.
8 | Registered object could be built from registry.
9 | Example:
10 | >>> MODELS = Registry('models')
11 | >>> @MODELS.register_module()
12 | >>> class ResNet:
13 | >>> pass
14 | >>> resnet = MODELS.build(dict(NAME='ResNet'))
15 | Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
16 | advanced usage.
17 | Args:
18 | name (str): Registry name.
19 | build_func (func, optional): Build function to construct instances from
20 | the registry. :func:`build_from_cfg` is used if neither ``parent`` nor
21 | ``build_func`` is specified. If ``parent`` is specified and
22 | ``build_func`` is not given, ``build_func`` will be inherited
23 | from ``parent``. Default: None.
24 | parent (Registry, optional): Parent registry. The class registered in
25 | children registry could be built from parent. Default: None.
26 | scope (str, optional): The scope of registry. It is the key to search
27 | for children registry. If not specified, scope will be the name of
28 | the package where class is defined, e.g. mmdet, mmcls, mmseg.
29 | Default: None.
30 | """
31 |
32 | def __init__(self, name, build_func=None, parent=None, scope=None):
33 | self._name = name
34 | self._module_dict = dict()
35 | self._children = dict()
36 | self._scope = self.infer_scope() if scope is None else scope
37 |
38 | # self.build_func will be set with the following priority:
39 | # 1. build_func
40 | # 2. parent.build_func
41 | # 3. build_from_cfg
42 | if build_func is None:
43 | if parent is not None:
44 | self.build_func = parent.build_func
45 | else:
46 | self.build_func = build_from_cfg
47 | else:
48 | self.build_func = build_func
49 | if parent is not None:
50 | assert isinstance(parent, Registry)
51 | parent._add_children(self)
52 | self.parent = parent
53 | else:
54 | self.parent = None
55 |
56 | def __len__(self):
57 | return len(self._module_dict)
58 |
59 | def __contains__(self, key):
60 | return self.get(key) is not None
61 |
62 | def __repr__(self):
63 | format_str = self.__class__.__name__ + \
64 | f'(name={self._name}, ' \
65 | f'items={self._module_dict})'
66 | return format_str
67 |
68 | @staticmethod
69 | def infer_scope():
70 | """Infer the scope of registry.
71 | The name of the package where registry is defined will be returned.
72 | Example:
73 | # in mmdet/models/backbone/resnet.py
74 | >>> MODELS = Registry('models')
75 | >>> @MODELS.register_module()
76 | >>> class ResNet:
77 | >>> pass
78 | The scope of ``ResNet`` will be ``mmdet``.
79 | Returns:
80 | scope (str): The inferred scope name.
81 | """
82 | # inspect.stack() traces where this function is called; index 2 is the
83 | # frame of the module that instantiated the Registry
84 | filename = inspect.getmodule(inspect.stack()[2][0]).__name__
85 | split_filename = filename.split('.')
86 | return split_filename[0]
87 |
88 | @staticmethod
89 | def split_scope_key(key):
90 | """Split scope and key.
91 | The first scope will be split from key.
92 | Examples:
93 | >>> Registry.split_scope_key('mmdet.ResNet')
94 | 'mmdet', 'ResNet'
95 | >>> Registry.split_scope_key('ResNet')
96 | None, 'ResNet'
97 | Return:
98 | scope (str, None): The first scope.
99 | key (str): The remaining key.
100 | """
101 | split_index = key.find('.')
102 | if split_index != -1:
103 | return key[:split_index], key[split_index + 1:]
104 | else:
105 | return None, key
106 |
107 | @property
108 | def name(self):
109 | return self._name
110 |
111 | @property
112 | def scope(self):
113 | return self._scope
114 |
115 | @property
116 | def module_dict(self):
117 | return self._module_dict
118 |
119 | @property
120 | def children(self):
121 | return self._children
122 |
123 | def get(self, key):
124 | """Get the registry record.
125 | Args:
126 | key (str): The class name in string format.
127 | Returns:
128 | class: The corresponding class.
129 | """
130 | scope, real_key = self.split_scope_key(key)
131 | if scope is None or scope == self._scope:
132 | # get from self
133 | if real_key in self._module_dict:
134 | return self._module_dict[real_key]
135 | else:
136 | # get from self._children
137 | if scope in self._children:
138 | return self._children[scope].get(real_key)
139 | else:
140 | # goto root
141 | parent = self.parent
142 | while parent.parent is not None:
143 | parent = parent.parent
144 | return parent.get(key)
145 |
146 | def build(self, *args, **kwargs):
147 | return self.build_func(*args, **kwargs, registry=self)
148 |
149 | def _add_children(self, registry):
150 | """Add children for a registry.
151 | The ``registry`` will be added as children based on its scope.
152 | The parent registry could build objects from children registry.
153 | Example:
154 | >>> models = Registry('models')
155 | >>> mmdet_models = Registry('models', parent=models)
156 | >>> @mmdet_models.register_module()
157 | >>> class ResNet:
158 | >>> pass
159 | >>> resnet = models.build(dict(NAME='mmdet.ResNet'))
160 | """
161 |
162 | assert isinstance(registry, Registry)
163 | assert registry.scope is not None
164 | assert registry.scope not in self.children, \
165 | f'scope {registry.scope} exists in {self.name} registry'
166 | self.children[registry.scope] = registry
167 |
168 | def _register_module(self, module_class, module_name=None, force=False):
169 | if not inspect.isclass(module_class):
170 | raise TypeError('module must be a class, '
171 | f'but got {type(module_class)}')
172 |
173 | if module_name is None:
174 | module_name = module_class.__name__
175 | if isinstance(module_name, str):
176 | module_name = [module_name]
177 | for name in module_name:
178 | if not force and name in self._module_dict:
179 | raise KeyError(f'{name} is already registered '
180 | f'in {self.name}')
181 | self._module_dict[name] = module_class
182 |
183 | def deprecated_register_module(self, cls=None, force=False):
184 | warnings.warn(
185 | 'The old API of register_module(module, force=False) '
186 | 'is deprecated and will be removed, please use the new API '
187 | 'register_module(name=None, force=False, module=None) instead.')
188 | if cls is None:
189 | return partial(self.deprecated_register_module, force=force)
190 | self._register_module(cls, force=force)
191 | return cls
192 |
193 | def register_module(self, name=None, force=False, module=None):
194 | """Register a module.
195 | A record will be added to `self._module_dict`, whose key is the class
196 | name or the specified name, and value is the class itself.
197 | It can be used as a decorator or a normal function.
198 | Example:
199 | >>> backbones = Registry('backbone')
200 | >>> @backbones.register_module()
201 | >>> class ResNet:
202 | >>> pass
203 | >>> backbones = Registry('backbone')
204 | >>> @backbones.register_module(name='mnet')
205 | >>> class MobileNet:
206 | >>> pass
207 | >>> backbones = Registry('backbone')
208 | >>> class ResNet:
209 | >>> pass
210 | >>> backbones.register_module(ResNet)
211 | Args:
212 | name (str | None): The module name to be registered. If not
213 | specified, the class name will be used.
214 | force (bool, optional): Whether to override an existing class with
215 | the same name. Default: False.
216 | module (type): Module class to be registered.
217 | """
218 | if not isinstance(force, bool):
219 | raise TypeError(f'force must be a boolean, but got {type(force)}')
220 | # NOTE: This is a workaround to be compatible with the old api,
221 | # while it may introduce unexpected bugs.
222 | if isinstance(name, type):
223 | return self.deprecated_register_module(name, force=force)
224 |
225 | # raise the error ahead of time
226 | if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)):
227 | raise TypeError(
228 | 'name must be either of None, an instance of str or a sequence'
229 | f' of str, but got {type(name)}')
230 |
231 | # use it as a normal method: x.register_module(module=SomeClass)
232 | if module is not None:
233 | self._register_module(
234 | module_class=module, module_name=name, force=force)
235 | return module
236 |
237 | # use it as a decorator: @x.register_module()
238 | def _register(cls):
239 | self._register_module(
240 | module_class=cls, module_name=name, force=force)
241 | return cls
242 |
243 | return _register
244 |
245 |
246 | def build_from_cfg(cfg, registry, default_args=None):
247 | """Build a module from config dict.
248 | Args:
249 | cfg (edict): Config dict. It should at least contain the key "NAME".
250 | registry (:obj:`Registry`): The registry to search the type from.
251 | Returns:
252 | object: The constructed object.
253 | """
254 | if not isinstance(cfg, dict):
255 | raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
256 | if 'NAME' not in cfg:
257 | if default_args is None or 'NAME' not in default_args:
258 | raise KeyError(
259 | '`cfg` or `default_args` must contain the key "NAME", '
260 | f'but got {cfg}\n{default_args}')
261 | if not isinstance(registry, Registry):
262 | raise TypeError('registry must be a Registry object, '
263 | f'but got {type(registry)}')
264 |
265 | if not (isinstance(default_args, dict) or default_args is None):
266 | raise TypeError('default_args must be a dict or None, '
267 | f'but got {type(default_args)}')
268 |
269 | if default_args is not None:
270 | cfg = config.merge_new_config(cfg, default_args)
271 |
272 | obj_type = cfg.get('NAME')
273 | print(f'obj_type:{obj_type}')
274 | if isinstance(obj_type, str):
275 | obj_cls = registry.get(obj_type)
276 | if obj_cls is None:
277 | raise KeyError(
278 | f'{obj_type} is not in the {registry.name} registry')
279 | elif inspect.isclass(obj_type):
280 | obj_cls = obj_type
281 | else:
282 | raise TypeError(
283 | f'type must be a str or valid type, but got {type(obj_type)}')
284 | try:
285 | return obj_cls(cfg)
286 | except Exception as e:
287 | # Normal TypeError does not print class name.
288 | raise type(e)(f'{obj_cls.__name__}: {e}')
289 |
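290 | # Illustrative usage sketch (class body and extra config keys are placeholders).
291 | # Note that `build_from_cfg` above dispatches on the upper-case key "NAME" and
292 | # passes the whole cfg dict to the class constructor:
293 | #   MODELS = Registry('models')
294 | #   @MODELS.register_module()
295 | #   class BaseModel:
296 | #       def __init__(self, cfg): ...
297 | #   model = MODELS.build(dict(NAME='BaseModel', hidden_dim=256))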
--------------------------------------------------------------------------------
/src/val/val_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import os
4 | import random
5 | import logging
6 | from dataset_svr.trainer_dataset import build_dataset_val
7 | from dataset_svr.trainer_dataset import get_spherepoints
8 | from network.base_model import BaseModel
9 | from network.base_model import AverageValueMeter
10 | from loss.cdloss import SimplificationLoss
11 | from tqdm import tqdm
12 | import munch
13 | import yaml
14 |
15 | def val():
16 | device = torch.device(args.gpus[0] if torch.cuda.is_available() else 'cpu')
17 | dataloader_test = build_dataset_val(args)
18 | len_test = len(dataloader_test)
19 |
20 | if not args.manual_seed:
21 | seed = random.randint(1,10000)
22 | else:
23 | seed = int(args.manual_seed)
24 | logging.info('Random Seed: %d' % seed)
25 | random.seed(seed)
26 | torch.manual_seed(seed)
27 | ckpt = torch.load(args.ckpt_path, map_location=device)
28 | if args.distributed:
29 | svrmodel = torch.nn.DataParallel(BaseModel(args.model),device_ids=args.gpus,output_device=args.gpus[0])
30 | svrmodel.to(device)
31 | svrmodel.module.load_state_dict(ckpt['net_state_dict'])
32 | logging.info("%s's previous weights loaded." % args.model_name)
33 | svrmodel.module.eval()
34 | else:
35 | svrmodel = BaseModel(args.model).to(device)
36 | svrmodel.load_state_dict(ckpt['net_state_dict'])
37 | logging.info("%s's previous weights loaded." % args.model_name)
38 | svrmodel.eval()
39 |
40 |
41 | logging.info('Testing....')
42 | test_loss_l1 = AverageValueMeter()
43 | test_loss_l2 = AverageValueMeter()
44 | test_f1 = AverageValueMeter()
45 | loss_function = SimplificationLoss()
46 | sphere_points = get_spherepoints(args.number_points,0.5)
47 |
48 | with tqdm(dataloader_test) as t:
49 | for i,data in enumerate(t):
50 | with torch.no_grad():
51 | images = data['image'].to(device)
52 | gt = data['points'].to(device)
53 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device)
54 | batch_size = gt.shape[0]
55 | pred_points = svrmodel(images,partial_pcs)
56 |
57 | pred_points = pred_points.transpose(2,1)
58 | loss_p,loss_t,f1 = loss_function(pred_points,gt,calc_f1=True)
59 |
60 | cd_l1_item = torch.sum(loss_p).item() / batch_size
61 | cd_l2_item = torch.sum(loss_t).item() / batch_size
62 | f1_item = torch.sum(f1).item()/batch_size
63 | test_loss_l1.update(cd_l1_item, images.shape[0])
64 | test_loss_l2.update(cd_l2_item, images.shape[0])
65 | test_f1.update(f1_item,images.shape[0])
66 |
67 | print('cd_l1 %f cd_l2 %f f1 %f' % (test_loss_l1.avg, test_loss_l2.avg,test_f1.avg))
68 |
69 | if __name__ == "__main__":
70 | parser = argparse.ArgumentParser(description='Evaluation config file')
71 | parser.add_argument('-c', '--config', help='path to config file', required=True)
72 | parser.add_argument('-gpu', '--gpu_id', help='gpu_id', required=True)
73 | arg = parser.parse_args()
74 | config_path = arg.config
75 | args = munch.munchify(yaml.safe_load(open(config_path)))
76 |
77 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
78 | os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu_id)
79 | print('Using gpu:' + str(arg.gpu_id))
80 | val()
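81 |
82 | # Example invocation (the config is assumed to follow the repository layout,
83 | # e.g. configs/base_model.yaml, and must define gpus, ckpt_path, number_points, ...):
84 | #   python src/val/val_base.py -c configs/base_model.yaml -gpu 0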
--------------------------------------------------------------------------------
/src/val/val_pro.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import os
4 | import random
5 | import logging
6 | from dataset_svr.trainer_dataset import build_dataset_val,get_spherepoints
7 | from network.pro_model import ProModel
8 | from network.pro_model import AverageValueMeter
9 | from loss.cdloss import SimplificationLoss
10 | from tqdm import tqdm
11 | import munch
12 | import yaml
13 |
14 |
15 | def val():
16 | device = torch.device(args.gpus[0] if torch.cuda.is_available() else 'cpu')
17 | dataloader_test = build_dataset_val(args)
18 | len_test = len(dataloader_test)
19 |
20 | if not args.manual_seed:
21 | seed = random.randint(1,10000)
22 | else:
23 | seed = int(args.manual_seed)
24 | logging.info('Random Seed: %d' % seed)
25 | random.seed(seed)
26 | torch.manual_seed(seed)
27 | ckpt = torch.load(args.ckpt_path,map_location=device)
28 | if args.distributed:
29 | promodel = torch.nn.DataParallel(ProModel(),device_ids=args.gpus,output_device=args.gpus[0])
30 | promodel.to(device)
31 | promodel.module.load_state_dict(ckpt['net_state_dict'])
32 | logging.info("%s's previous weights loaded." % args.model_name)
33 | promodel.module.eval()
34 | else:
35 | promodel = ProModel().to(device)
36 | promodel.load_state_dict(ckpt['net_state_dict'])
37 | logging.info("%s's previous weights loaded." % args.model_name)
38 | promodel.eval()
39 |
40 | logging.info('Testing....')
41 |
42 | test_loss_l1 = AverageValueMeter()
43 | test_loss_l2 = AverageValueMeter()
44 | test_f1 = AverageValueMeter()
45 | loss_function = SimplificationLoss()
46 | sphere_points = get_spherepoints(args.number_points,0.5)
47 |
48 |
49 | with tqdm(dataloader_test) as t:
50 | for i,data in enumerate(t):
51 | with torch.no_grad():
52 | images = data['image'].to(device)
53 | gt = data['points'].to(device)
54 | partial_pcs = torch.tensor(sphere_points).to(torch.float32).unsqueeze(0).repeat(images.shape[0], 1, 1).to(device)
55 |
56 | category = data['category']
57 | category_indices = [int(c) for c in category]
58 | text_labels = torch.tensor(category_indices).to(device)
59 |
60 | batch_size = gt.shape[0]
61 |
62 | pred_points = promodel(images,partial_pcs,text_labels)
63 | pred_points = pred_points.transpose(2,1)
64 | loss_p,loss_t,f1 = loss_function(pred_points,gt,calc_f1=True)
65 |
66 | cd_l1_item = torch.sum(loss_p).item() / batch_size
67 | cd_l2_item = torch.sum(loss_t).item() / batch_size
68 | f1_item = torch.sum(f1).item() / batch_size
69 | test_loss_l1.update(cd_l1_item, images.shape[0])
70 | test_loss_l2.update(cd_l2_item, images.shape[0])
71 | test_f1.update(f1_item, images.shape[0])
72 |
73 | print('cd_l1 %f cd_l2 %f f1 %f' % (test_loss_l1.avg, test_loss_l2.avg,test_f1.avg))
74 |
75 | if __name__ == "__main__":
76 | parser = argparse.ArgumentParser(description='Evaluation config file')
77 | parser.add_argument('-c', '--config', help='path to config file', required=True)
78 | parser.add_argument('-gpu', '--gpu_id', help='gpu_id', required=True)
79 | arg = parser.parse_args()
80 | config_path = arg.config
81 | args = munch.munchify(yaml.safe_load(open(config_path)))
82 |
83 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
84 | os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu_id)
85 | print('Using gpu:' + str(arg.gpu_id))
86 | val()
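87 |
88 | # Example invocation (config path assumed from the repository layout):
89 | #   python src/val/val_pro.py -c configs/pro_model.yaml -gpu 0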
--------------------------------------------------------------------------------