def setup_seed(seed):
    """Seed every random number generator the training script relies on.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    CUDA generators), then switches cuDNN to deterministic mode with
    auto-tuning (``benchmark``) disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
def train_one_epoch(net, opt, train_loader):
    """Train ``net`` for one pass over ``train_loader`` and report mask metrics.

    Each batch yields ``(src, target, rotation, translation, euler,
    gt_mask_src, gt_mask_tgt)``; only the two clouds and the two ground-truth
    masks are used here. The loss is the equally weighted sum of the
    cross-entropy of the source and target mask logits.

    Args:
        net: Model called as ``net(src, target)`` and returning
            ``(mask_src, mask_tgt, mask_src_idx, mask_tgt_idx)``.
        opt: Optimizer stepping ``net``'s parameters.
        train_loader: Iterable of training batches.

    Returns:
        Tuple ``(total_loss, f1, acc, precis, recall)``: the summed batch
        losses and the source/target-averaged metrics (each side is the mean
        over batches of what ``evaluate_mask`` returned).
    """
    net.train()
    loss_fn = nn.CrossEntropyLoss()
    tgt_weight = 0.5  # weight of the target-mask loss; the source gets 1 - tgt_weight
    total_loss = 0

    accs_src, preciss_src, recalls_src, f1s_src = [], [], [], []
    accs_tgt, preciss_tgt, recalls_tgt, f1s_tgt = [], [], [], []

    def _score(logits, topk_mask, gt_mask):
        # Accuracy/precision/recall come from the arg-max prediction, F1 from
        # the top-k hard mask produced by the network.
        # NOTE(review): the (acc, precision, recall, f1) order of
        # evaluate_mask's return value is inferred from the variable names.
        acc, precis, recall, _ = evaluate_mask(torch.max(logits, dim=1)[1], gt_mask)
        _, _, _, f1 = evaluate_mask(topk_mask, gt_mask)
        return acc, precis, recall, f1

    for src, target, _rotation, _translation, _euler, gt_mask_src, gt_mask_tgt in tqdm(train_loader):
        if use_cuda:
            src = src.cuda()
            target = target.cuda()

        mask_src, mask_tgt, mask_src_idx, mask_tgt_idx = net(src, target)

        opt.zero_grad()
        # Put the targets on whatever device the logits live on instead of
        # calling .cuda() unconditionally, which breaks CPU-only runs and
        # contradicts the use_cuda guard above.
        loss_src = loss_fn(mask_src, gt_mask_src.long().to(mask_src.device))
        loss_tgt = loss_fn(mask_tgt, gt_mask_tgt.long().to(mask_tgt.device))
        loss = (1 - tgt_weight) * loss_src + tgt_weight * loss_tgt
        total_loss += loss.item()
        loss.backward()
        opt.step()

        acc_src, precis_src, recall_src, f1_src = _score(mask_src, mask_src_idx, gt_mask_src)
        accs_src.append(acc_src)
        preciss_src.append(precis_src)
        recalls_src.append(recall_src)
        f1s_src.append(f1_src)

        acc_tgt, precis_tgt, recall_tgt, f1_tgt = _score(mask_tgt, mask_tgt_idx, gt_mask_tgt)
        accs_tgt.append(acc_tgt)
        preciss_tgt.append(precis_tgt)
        recalls_tgt.append(recall_tgt)
        f1s_tgt.append(f1_tgt)

    # Average over batches per side, then average the two sides.
    f1 = (np.mean(f1s_src) + np.mean(f1s_tgt)) / 2
    acc = (np.mean(accs_src) + np.mean(accs_tgt)) / 2
    precis = (np.mean(preciss_src) + np.mean(preciss_tgt)) / 2
    recall = (np.mean(recalls_src) + np.mean(recalls_tgt)) / 2

    return total_loss, f1, acc, precis, recall
src_subsampled_points = int(subsampled_rate_src * all_points) 190 | tgt_subsampled_points = int(subsampled_rate_tgt * all_points) 191 | elif file_type in ['Kitti_odo', 'bunny']: 192 | all_points = 2048 193 | src_subsampled_points = int(subsampled_rate_src * all_points) 194 | tgt_subsampled_points = int(subsampled_rate_tgt * all_points) 195 | 196 | train_loader = DataLoader( 197 | dataset=ModelNet40_Reg(all_points, partition='train', max_angle=45, max_t=0.5, unseen=unseen, file_type=file_type, 198 | subsampled_rate_src=subsampled_rate_src, subsampled_rate_tgt=subsampled_rate_tgt, 199 | partial_overlap=partial_overlap, noise=noise), 200 | batch_size=batch_size, 201 | shuffle=True, 202 | num_workers=4, 203 | pin_memory=False, 204 | prefetch_factor=2 205 | ) 206 | test_loader = DataLoader( 207 | dataset=ModelNet40_Reg(all_points, partition='test', max_angle=45, max_t=0.5, unseen=unseen, file_type=file_type, 208 | subsampled_rate_src=subsampled_rate_src, subsampled_rate_tgt=subsampled_rate_tgt, 209 | partial_overlap=partial_overlap, noise=noise), 210 | batch_size=batch_size, 211 | shuffle=False, 212 | num_workers=4, 213 | pin_memory=False, 214 | prefetch_factor=2 215 | ) 216 | 217 | net = OverlapNet(all_points=all_points, src_subsampled_points=src_subsampled_points, 218 | tgt_subsampled_points=tgt_subsampled_points) 219 | opt = optim.RAdam(params=net.parameters(), lr=lr) 220 | 221 | if use_cuda: 222 | net = net.cuda() 223 | # net = DataParallel(net, device_ids=[0, 1]) 224 | 225 | start_epoch = -1 226 | RESUME = False # 是否加载模型继续上次训练 227 | if RESUME: 228 | path_checkpoint = "./checkpoint/ckpt%s.pth"%(str(file_type)+str(subsampled_rate_src)+str(subsampled_rate_tgt)) # 断点路径 229 | checkpoint = torch.load(path_checkpoint, map_location=lambda storage, loc: storage.cuda(gpu_id)) # 加载断点 230 | net.load_state_dict(checkpoint['net']) # 加载模型可学习参数 231 | # scheduler.load_state_dict(checkpoint["lr_step"]) 232 | opt.load_state_dict(checkpoint['optimizer']) # 加载优化器参数 233 | start_epoch = 
checkpoint['epoch'] # 设置开始的epoch 234 | # 加载上次best结果 235 | best_loss = checkpoint['best_loss'] 236 | best_precis = checkpoint['best_Precis'] 237 | best_recall = checkpoint['best_Recall'] 238 | best_acc = checkpoint['best_Acc'] 239 | best_f1 = checkpoint['best_f1'] 240 | 241 | for epoch in range(start_epoch + 1, epochs): 242 | 243 | train_total_loss, train_f1, train_acc, train_precis, train_recall = train_one_epoch(net, opt, train_loader) 244 | 245 | test_total_loss, test_f1, test_acc, test_precis, test_recall = test_one_epoch(net, test_loader) 246 | 247 | if test_f1 >= best_f1: 248 | best_loss = test_total_loss 249 | best_precis = test_precis 250 | best_recall = test_recall 251 | best_f1 = test_f1 252 | best_acc = test_acc 253 | # 保存最好的checkpoint 254 | checkpoint_best = { 255 | "net": net.state_dict(), 256 | } 257 | if not os.path.isdir("./checkpoint"): 258 | os.mkdir("./checkpoint") 259 | torch.save(checkpoint_best, './checkpoint/ckpt_best%s.pth'%(str(file_type)+str(subsampled_rate_src)+str(subsampled_rate_tgt))) 260 | 261 | print('---------Epoch: %d---------' % (epoch+1)) 262 | print('Train: Loss: %f, F1: %f, Acc: %f, Precis: %f, Recall: %f' 263 | % (train_total_loss, train_f1, train_acc, train_precis, train_recall)) 264 | 265 | print('Test: Loss: %f, F1: %f, Acc: %f, Precis: %f, Recall: %f' 266 | % (test_total_loss, test_f1, test_acc, test_precis, test_recall)) 267 | 268 | print('Best: Loss: %f, F1: %f, Acc: %f, Precis: %f, Recall: %f' 269 | % (best_loss, best_f1, best_acc, best_precis, best_recall)) 270 | writer.add_scalar('Train/train_loss', train_total_loss, global_step=epoch) 271 | writer.add_scalar('Train/train_Precis', train_precis, global_step=epoch) 272 | writer.add_scalar('Train/train_Recall', train_recall, global_step=epoch) 273 | writer.add_scalar('Train/train_Acc', train_acc, global_step=epoch) 274 | writer.add_scalar('Train/train_F1', train_f1, global_step=epoch) 275 | 276 | writer.add_scalar('Test/test_loss', test_total_loss, global_step=epoch) 277 
def mask_point(mask_idx, points):
    """Keep only the points selected by a binary mask.

    The previous implementation zeroed the masked points and then deleted
    every all-zero row, which also removed selected points lying exactly at
    the origin, round-tripped through NumPy, and moved the result to CUDA
    whenever CUDA was available. Selection is now driven by the mask itself
    and the result stays on the device of ``points``.

    Args:
        mask_idx: Tensor reshapeable to ``(b, n)`` holding 0 (drop) or
            non-zero (keep) per point.
        points: Tensor of shape ``(b, 3, n)``.

    Returns:
        Tensor of shape ``(b, 3, n2)`` with the kept points in their original
        order. As before, the result is detached from the autograd graph.

    Raises:
        ValueError: If the samples of the batch keep different numbers of
            points, so no rectangular ``(b, 3, n2)`` tensor exists.
    """
    batch_size, channels = points.shape[0], points.shape[1]
    keep = mask_idx.reshape(batch_size, -1) != 0  # (b, n) boolean
    kept_per_sample = keep.sum(dim=1)

    if batch_size > 0 and not bool((kept_per_sample == kept_per_sample[0]).all()):
        raise ValueError(
            "mask_point requires the same number of kept points in every sample, got %s"
            % kept_per_sample.tolist()
        )
    num_kept = int(kept_per_sample[0]) if batch_size > 0 else 0

    per_point = points.detach().permute(0, 2, 1)  # (b, n, 3)
    # Boolean indexing walks the batch row-major, so the kept points of each
    # sample stay contiguous and in order; the reshape restores the batch.
    selected = per_point[keep.to(per_point.device)]
    return selected.reshape(batch_size, num_kept, channels).permute(0, 2, 1)
class OverlapNet(nn.Module):
    """Predict, for two point clouds, which points lie in their overlap.

    Both clouds are embedded by a shared DGCNN, enriched with
    cross-cloud features (``feature_interaction``), refined by per-cloud
    1x1-convolution stacks, and finally classified point-wise into
    overlap / non-overlap from score-weighted cosine similarities.

    Args:
        n_emb_dims: Channel width of the refined embeddings; the DGCNN is
            built with half of it.
        all_points: Number of points in the full (un-cropped) cloud.
        src_subsampled_points: Number of points in the source cloud.
        tgt_subsampled_points: Number of points in the target cloud.
    """

    def __init__(self, n_emb_dims=1024, all_points=1024, src_subsampled_points=768, tgt_subsampled_points=768):
        super(OverlapNet, self).__init__()
        self.emb_dims = n_emb_dims
        self.all_points = all_points
        self.emb_dims1 = int(self.emb_dims / 2)
        self.src_subsampled_points = src_subsampled_points
        self.tgt_subsampled_points = tgt_subsampled_points

        # Layers are created in the same order and at the same Sequential
        # indices as before, so seeded initialisation and checkpoint keys
        # are unchanged.
        self.emb_nn1 = DGCNN(self.emb_dims1, k=32)
        self.emb_nn2_src = self._fusion_head()
        self.emb_nn2_tgt = self._fusion_head()
        self.score_nn_src = self._score_head()
        self.score_nn_tgt = self._score_head()
        # The mask heads consume similarity rows, so their input channel
        # count is the number of points of the *other* cloud.
        self.mask_src_nn = self._mask_head(self.tgt_subsampled_points)
        self.mask_tgt_nn = self._mask_head(self.src_subsampled_points)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        # One 1x1 convolution followed by batch norm and LeakyReLU.
        return [
            nn.Conv1d(in_channels, out_channels, 1),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.01),
        ]

    def _fusion_head(self):
        # Maps the 4-way concatenated interaction features to emb_dims channels.
        return nn.Sequential(
            *self._conv_block(self.emb_dims1 * 4, self.emb_dims1 * 2),
            *self._conv_block(self.emb_dims1 * 2, self.emb_dims),
        )

    def _score_head(self):
        # Per-point scalar score squashed to (0, 1).
        return nn.Sequential(
            *self._conv_block(self.emb_dims, 512),
            *self._conv_block(512, 128),
            *self._conv_block(128, 128),
            nn.Conv1d(128, 1, 1), nn.Sigmoid()
        )

    def _mask_head(self, in_points):
        # Per-point two-class logits (non-overlap / overlap).
        return nn.Sequential(
            *self._conv_block(in_points, 1024),
            *self._conv_block(1024, 512),
            *self._conv_block(512, 128),
            *self._conv_block(128, 128),
            nn.Conv1d(128, 2, 1)
        )

    @staticmethod
    def _topk_mask(scores, k):
        # Hard 0/1 mask marking the k highest-scoring points of each sample.
        # Built on the scores' own device (previously a hard-coded .cuda()).
        hard_mask = torch.zeros(scores.shape, device=scores.device)
        _, indices = torch.topk(scores, k=k, dim=1)
        hard_mask.scatter_(1, indices, 1)
        return hard_mask

    def forward(self, *input):
        """Compute overlap logits and hard overlap masks for a cloud pair.

        Args:
            *input: ``input[0]`` is the source cloud and ``input[1]`` the
                target cloud, each ``(B, 3, N)`` per the usage comment in
                this file.

        Returns:
            Tuple ``(mask_src, mask_tgt, mask_src_idx, mask_tgt_idx)``:
            two-class logits of shape ``(B, 2, N)`` for each cloud, and
            detached 0/1 masks of shape ``(B, N)`` marking the
            ``overlap_points`` most confident overlap points per sample.
        """
        src = input[0]
        tgt = input[1]
        batch_size = src.shape[0]

        src_embedding = self.emb_nn1(src)
        tgt_embedding = self.emb_nn1(tgt)
        # Feature fusion: give every point a view of the other cloud.
        inter_src_feature = feature_interaction(src_embedding, tgt_embedding)
        inter_tar_feature = feature_interaction(tgt_embedding, src_embedding)
        # Refine the fused features.
        src_embedding = self.emb_nn2_src(inter_src_feature)
        tgt_embedding = self.emb_nn2_tgt(inter_tar_feature)
        # Per-point scores, normalised over the points of each cloud.
        src_score = self.score_nn_src(src_embedding).reshape(batch_size, 1, -1)
        tgt_score = self.score_nn_tgt(tgt_embedding).reshape(batch_size, 1, -1)
        src_score = torch.softmax(src_score, dim=2)
        tgt_score = torch.softmax(tgt_score, dim=2)

        simi1 = cos_simi(src_embedding, tgt_embedding)
        simi2 = cos_simi(tgt_embedding, src_embedding)

        # Weight each similarity by the score of the point it refers to.
        simi_src = simi1 * tgt_score
        simi_tgt = simi2 * src_score

        mask_src = self.mask_src_nn(simi_src.permute(0, 2, 1))
        mask_tgt = self.mask_tgt_nn(simi_tgt.permute(0, 2, 1))

        # Points that survive both crops: all - (all - src) - (all - tgt).
        overlap_points = self.src_subsampled_points + self.tgt_subsampled_points - self.all_points

        # Probability of the "overlap" class per point, shape (B, N).
        mask_src_score = torch.softmax(mask_src, dim=1)[:, 1, :].detach()
        mask_tgt_score = torch.softmax(mask_tgt, dim=1)[:, 1, :].detach()
        # Take the top overlap_points points as the overlap region.
        mask_src_idx = self._topk_mask(mask_src_score, overlap_points)
        mask_tgt_idx = self._topk_mask(mask_tgt_score, overlap_points)

        return mask_src, mask_tgt, mask_src_idx, mask_tgt_idx
def setup_seed(seed):
    """Make a run repeatable by seeding all random number generators.

    Applies ``seed`` to PyTorch (CPU and CUDA), NumPy and Python's
    ``random`` module, then disables cuDNN benchmarking and enables its
    deterministic mode.

    Args:
        seed: Integer seed shared by every generator.
    """
    for seed_with in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    ):
        seed_with(seed)

    backend = torch.backends.cudnn
    backend.benchmark = False
    backend.deterministic = True
preciss_src = [] 65 | recalls_src = [] 66 | f1s_src = [] 67 | 68 | accs_tgt = [] 69 | preciss_tgt = [] 70 | recalls_tgt = [] 71 | f1s_tgt = [] 72 | 73 | Rs_gt = [] 74 | ts_gt = [] 75 | Rs1_pred = [] 76 | ts1_pred = [] 77 | for src, target, R_gt, t_gt, euler, gt_mask_src, gt_mask_tgt in tqdm(test_loader): 78 | batch_size = src.shape[0] 79 | if use_cuda: 80 | src = src.cuda() 81 | target = target.cuda() 82 | R_gt = R_gt.cuda() 83 | t_gt = t_gt.cuda() 84 | 85 | init_R, init_t, mask_src, mask_tgt, mask_src_idx, mask_tgt_idx, \ 86 | R1_pred, t1_pred = net(src, target) 87 | 88 | # compute mask loss and registration loss 89 | loss_mask = loss_fn(mask_src, gt_mask_src.long().cuda()) + loss_fn(mask_tgt, gt_mask_tgt.long().cuda()) 90 | E = torch.eye(3).unsqueeze(0).repeat(batch_size, 1, 1).cuda() 91 | loss_init = F.mse_loss(torch.matmul(init_R.transpose(2, 1), R_gt), E) \ 92 | + F.mse_loss(init_t, t_gt) 93 | loss_reg1 = F.mse_loss(torch.matmul(R1_pred.transpose(2, 1), R_gt), E) \ 94 | + F.mse_loss(t1_pred, t_gt) 95 | loss_reg = loss_reg1 96 | loss = loss_mask + loss_reg + loss_init 97 | total_loss += loss.item() 98 | reg_total_loss += loss_reg.item() 99 | 100 | Rs_gt.append(R_gt.detach().cpu().numpy()) 101 | ts_gt.append(t_gt.detach().cpu().numpy()) 102 | Rs1_pred.append(R1_pred.detach().cpu().numpy()) 103 | ts1_pred.append(t1_pred.detach().cpu().numpy()) 104 | # 评估 105 | acc_src, precis_src, recall_src, _ = evaluate_mask(torch.max(mask_src, dim=1)[1], gt_mask_src) 106 | _, _, _, f1_src = evaluate_mask(mask_src_idx, gt_mask_src) 107 | accs_src.append(acc_src) 108 | preciss_src.append(precis_src) 109 | recalls_src.append(recall_src) 110 | f1s_src.append(f1_src) 111 | 112 | acc_tgt, precis_tgt, recall_tgt, _ = evaluate_mask(torch.max(mask_tgt, dim=1)[1], gt_mask_tgt) 113 | _, _, _, f1_tgt = evaluate_mask(mask_tgt_idx, gt_mask_tgt) 114 | accs_tgt.append(acc_tgt) 115 | preciss_tgt.append(precis_tgt) 116 | recalls_tgt.append(recall_tgt) 117 | f1s_tgt.append(f1_tgt) 118 | 119 | 
def train_one_epoch(net, opt, train_loader):
    """Train the registration network for one pass over ``train_loader``.

    Each batch yields ``(src, target, R_gt, t_gt, euler, gt_mask_src,
    gt_mask_tgt)``. The optimised loss is the sum of the mask
    cross-entropy (source + target), the refined registration loss and the
    initial registration loss. A registration loss is the MSE between
    ``R_pred^T @ R_gt`` and the identity plus the MSE between predicted and
    ground-truth translation.

    Args:
        net: Model called as ``net(src, target)`` and returning ``(init_R,
            init_t, mask_src, mask_tgt, mask_src_idx, mask_tgt_idx, R1_pred,
            t1_pred)``.
        opt: Optimizer stepping ``net``'s parameters.
        train_loader: Iterable of training batches.

    Returns:
        Tuple ``(total_loss, f1, acc, precis, recall, Rs_gt, ts_gt,
        Rs1_pred, ts1_pred, reg_total_loss)``: summed batch losses, the
        source/target-averaged mask metrics, the ground-truth and predicted
        transforms concatenated over the epoch as NumPy arrays, and the
        summed refined-registration loss.
    """
    net.train()
    loss_fn = nn.CrossEntropyLoss()
    total_loss = 0
    reg_total_loss = 0

    accs_src, preciss_src, recalls_src, f1s_src = [], [], [], []
    accs_tgt, preciss_tgt, recalls_tgt, f1s_tgt = [], [], [], []
    Rs_gt, ts_gt, Rs1_pred, ts1_pred = [], [], [], []

    def _registration_loss(R_pred, t_pred, R_true, t_true, identity):
        # R_pred^T @ R_true equals the identity when the rotations agree.
        return F.mse_loss(torch.matmul(R_pred.transpose(2, 1), R_true), identity) \
            + F.mse_loss(t_pred, t_true)

    def _score(logits, topk_mask, gt_mask):
        # Accuracy/precision/recall from the arg-max prediction, F1 from the
        # network's top-k hard mask.
        # NOTE(review): the (acc, precision, recall, f1) order of
        # evaluate_mask's return value is inferred from the variable names.
        acc, precis, recall, _ = evaluate_mask(torch.max(logits, dim=1)[1], gt_mask)
        _, _, _, f1 = evaluate_mask(topk_mask, gt_mask)
        return acc, precis, recall, f1

    for src, target, R_gt, t_gt, _euler, gt_mask_src, gt_mask_tgt in tqdm(train_loader):
        batch_size = src.shape[0]
        if use_cuda:
            src = src.cuda()
            target = target.cuda()
            R_gt = R_gt.cuda()
            t_gt = t_gt.cuda()

        init_R, init_t, mask_src, mask_tgt, mask_src_idx, mask_tgt_idx, \
            R1_pred, t1_pred = net(src, target)

        opt.zero_grad()
        # Targets and the identity follow the tensors they are compared with
        # instead of being pushed to CUDA unconditionally.
        loss_mask = loss_fn(mask_src, gt_mask_src.long().to(mask_src.device)) \
            + loss_fn(mask_tgt, gt_mask_tgt.long().to(mask_tgt.device))
        E = torch.eye(3, device=R_gt.device).unsqueeze(0).repeat(batch_size, 1, 1)
        loss_init = _registration_loss(init_R, init_t, R_gt, t_gt, E)
        loss_reg = _registration_loss(R1_pred, t1_pred, R_gt, t_gt, E)
        loss = loss_mask + loss_reg + loss_init

        # Report the loss that is actually optimised; this previously left
        # out loss_init, unlike test_one_epoch which accumulates loss.item().
        total_loss += loss.item()
        reg_total_loss += loss_reg.item()
        loss.backward()
        opt.step()

        Rs_gt.append(R_gt.detach().cpu().numpy())
        ts_gt.append(t_gt.detach().cpu().numpy())
        Rs1_pred.append(R1_pred.detach().cpu().numpy())
        ts1_pred.append(t1_pred.detach().cpu().numpy())

        acc_src, precis_src, recall_src, f1_src = _score(mask_src, mask_src_idx, gt_mask_src)
        accs_src.append(acc_src)
        preciss_src.append(precis_src)
        recalls_src.append(recall_src)
        f1s_src.append(f1_src)

        acc_tgt, precis_tgt, recall_tgt, f1_tgt = _score(mask_tgt, mask_tgt_idx, gt_mask_tgt)
        accs_tgt.append(acc_tgt)
        preciss_tgt.append(precis_tgt)
        recalls_tgt.append(recall_tgt)
        f1s_tgt.append(f1_tgt)

    Rs_gt = np.concatenate(Rs_gt, axis=0)
    ts_gt = np.concatenate(ts_gt, axis=0)
    Rs1_pred = np.concatenate(Rs1_pred, axis=0)
    ts1_pred = np.concatenate(ts1_pred, axis=0)

    # Average over batches per side, then average the two sides.
    f1 = (np.mean(f1s_src) + np.mean(f1s_tgt)) / 2
    acc = (np.mean(accs_src) + np.mean(accs_tgt)) / 2
    precis = (np.mean(preciss_src) + np.mean(preciss_tgt)) / 2
    recall = (np.mean(recalls_src) + np.mean(recalls_tgt)) / 2

    return total_loss, f1, acc, precis, recall, Rs_gt, ts_gt, Rs1_pred, ts1_pred, reg_total_loss
use_cuda: 288 | net = net.cuda() 289 | # net = DataParallel(net, device_ids=[0, 1]) 290 | 291 | start_epoch = -1 292 | RESUME = False # 是否加载模型继续上次训练 293 | if RESUME: 294 | path_checkpoint = "./checkpoint/ckpt%s.pth"%(str(file_type)+str(subsampled_rate_src)+str(subsampled_rate_tgt)) # 断点路径 295 | checkpoint = torch.load(path_checkpoint, map_location=lambda storage, loc: storage.cuda(gpu_id)) # 加载断点 296 | net.load_state_dict(checkpoint['net']) # 加载模型可学习参数 297 | # scheduler.load_state_dict(checkpoint["lr_step"]) 298 | opt.load_state_dict(checkpoint['optimizer']) # 加载优化器参数 299 | start_epoch = checkpoint['epoch'] # 设置开始的epoch 300 | # 加载上次best结果 301 | best_loss = checkpoint['best_loss'] 302 | best_precis = checkpoint['best_Precis'] 303 | best_recall = checkpoint['best_Recall'] 304 | best_acc = checkpoint['best_Acc'] 305 | best_f1 = checkpoint['best_f1'] 306 | 307 | best_R_mse = checkpoint['best_MSE(R)'] 308 | best_R_rmse = checkpoint['best_RMSE(R)'] 309 | best_R_mae = checkpoint['best_MAE(R)'] 310 | best_t_mse = checkpoint['best_MSE(t)'] 311 | best_t_rmse = checkpoint['best_RMSE(t)'] 312 | best_t_mae = checkpoint['best_MAE(t)'] 313 | 314 | for epoch in range(start_epoch + 1, epochs): 315 | train_total_loss, train_f1, train_acc, train_precis, train_recall, train_Rs_gt, train_ts_gt, \ 316 | train_Rs1_pred, train_ts1_pred, train_reg_loss = train_one_epoch(net, opt, train_loader) 317 | 318 | test_total_loss, test_f1, test_acc, test_precis, test_recall, test_Rs_gt, test_ts_gt, \ 319 | test_Rs1_pred, test_ts1_pred, test_reg_loss = test_one_epoch(net, test_loader) 320 | 321 | train_R_mse, train_R_mae = calculate_R_msemae(train_Rs_gt, train_Rs1_pred) 322 | train_R_rmse = np.sqrt(train_R_mse) 323 | train_t_mse, train_t_mae = calculate_t_msemae(train_ts_gt, train_ts1_pred) 324 | train_t_rmse = np.sqrt(train_t_mse) 325 | 326 | test_R_mse, test_R_mae = calculate_R_msemae(test_Rs_gt, test_Rs1_pred) 327 | test_R_rmse = np.sqrt(test_R_mse) 328 | test_t_mse, test_t_mae = 
calculate_t_msemae(test_ts_gt, test_ts1_pred) 329 | test_t_rmse = np.sqrt(test_t_mse) 330 | 331 | if test_reg_loss <= best_loss: 332 | best_loss = test_reg_loss 333 | best_precis = test_precis 334 | best_recall = test_recall 335 | best_f1 = test_f1 336 | best_acc = test_acc 337 | 338 | best_R_mse = test_R_mse 339 | best_R_rmse = test_R_rmse 340 | best_R_mae = test_R_mae 341 | 342 | best_t_mse = test_t_mse 343 | best_t_rmse = test_t_rmse 344 | best_t_mae = test_t_mae 345 | 346 | # 保存最好的checkpoint 347 | checkpoint_best = { 348 | "net": net.state_dict(), 349 | } 350 | if not os.path.isdir("./checkpoint"): 351 | os.mkdir("./checkpoint") 352 | torch.save(checkpoint_best, './checkpoint/ckpt_best%s.pth'%(str(file_type)+str(subsampled_rate_src)+str(subsampled_rate_tgt))) 353 | 354 | print('---------Epoch: %d---------' % (epoch+1)) 355 | print('Train: Loss: %f, F1: %f, Acc: %f, Precis: %f, Recall: %f, MSE(R): %f, RMSE(R): %f, ' 356 | 'MAE(R): %f, MSE(t): %f, RMSE(t): %f, MAE(t): %f' 357 | % (train_total_loss, train_f1, train_acc, train_precis, train_recall, train_R_mse, train_R_rmse, 358 | train_R_mae, train_t_mse, train_t_rmse, train_t_mae)) 359 | 360 | print('Test: Loss: %f, F1: %f, Acc: %f, Precis: %f, Recall: %f, MSE(R): %f, RMSE(R): %f, ' 361 | 'MAE(R): %f, MSE(t): %f, RMSE(t): %f, MAE(t): %f,' 362 | % (test_total_loss, test_f1, test_acc, test_precis, test_recall, test_R_mse, test_R_rmse, 363 | test_R_mae, test_t_mse, test_t_rmse, test_t_mae)) 364 | 365 | print('Best: Loss: %f, F1: %f, Acc: %f, Precis: %f, Recall: %f, MSE(R): %f, RMSE(R): %f, ' 366 | 'MAE(R): %f, MSE(t): %f, RMSE(t): %f, MAE(t): %f' 367 | % (best_loss, best_f1, best_acc, best_precis, best_recall, best_R_mse, best_R_rmse, 368 | best_R_mae, best_t_mse, best_t_rmse, best_t_mae)) 369 | writer.add_scalar('Train/train_loss', train_total_loss, global_step=epoch) 370 | writer.add_scalar('Train/train_Precis', train_precis, global_step=epoch) 371 | writer.add_scalar('Train/train_Recall', train_recall, 
global_step=epoch) 372 | writer.add_scalar('Train/train_Acc', train_acc, global_step=epoch) 373 | writer.add_scalar('Train/train_F1', train_f1, global_step=epoch) 374 | writer.add_scalar('Train/train_MSER', train_R_mse, global_step=epoch) 375 | writer.add_scalar('Train/train_RMSER', train_R_rmse, global_step=epoch) 376 | writer.add_scalar('Train/train_MAER', train_R_mae, global_step=epoch) 377 | writer.add_scalar('Train/train_MSEt', train_t_mse, global_step=epoch) 378 | writer.add_scalar('Train/train_RMSEt', train_t_rmse, global_step=epoch) 379 | writer.add_scalar('Train/train_MAEt', train_t_mae, global_step=epoch) 380 | 381 | writer.add_scalar('Test/test_loss', test_total_loss, global_step=epoch) 382 | writer.add_scalar('Test/test_Precis', test_precis, global_step=epoch) 383 | writer.add_scalar('Test/test_Recall', test_recall, global_step=epoch) 384 | writer.add_scalar('Test/test_Acc', test_acc, global_step=epoch) 385 | writer.add_scalar('Test/test_F1', test_f1, global_step=epoch) 386 | writer.add_scalar('Test/test_MSER', test_R_mse, global_step=epoch) 387 | writer.add_scalar('Test/test_RMSER', test_R_rmse, global_step=epoch) 388 | writer.add_scalar('Test/test_MAER', test_R_mae, global_step=epoch) 389 | writer.add_scalar('Test/test_MSEt', test_t_mse, global_step=epoch) 390 | writer.add_scalar('Test/test_RMSEt', test_t_rmse, global_step=epoch) 391 | writer.add_scalar('Test/test_MAEt', test_t_mae, global_step=epoch) 392 | 393 | writer.add_scalar('Best/best_loss', best_loss, global_step=epoch) 394 | writer.add_scalar('Best/best_Precis', best_precis, global_step=epoch) 395 | writer.add_scalar('Best/best_Recall', best_recall, global_step=epoch) 396 | writer.add_scalar('Best/best_Acc', best_acc, global_step=epoch) 397 | writer.add_scalar('Best/best_F1', best_f1, global_step=epoch) 398 | writer.add_scalar('Best/best_MSER', best_R_mse, global_step=epoch) 399 | writer.add_scalar('Best/best_RMSER', best_R_rmse, global_step=epoch) 400 | writer.add_scalar('Best/best_MAER', 
best_R_mae, global_step=epoch) 401 | writer.add_scalar('Best/best_MSEt', best_t_mse, global_step=epoch) 402 | writer.add_scalar('Best/best_RMSEt', best_t_rmse, global_step=epoch) 403 | writer.add_scalar('Best/best_MAEt', best_t_mae, global_step=epoch) 404 | 405 | # 保存checkpoint 406 | checkpoint = { 407 | "net": net.state_dict(), 408 | 'optimizer': opt.state_dict(), 409 | "epoch": epoch, 410 | # "lr_step": scheduler.state_dict(), 411 | "best_loss":best_loss, 412 | 'best_Precis': best_precis, 413 | 'best_Recall': best_recall, 414 | 'best_Acc': best_acc, 415 | 'best_f1': best_f1, 416 | 417 | 'best_MSE(R)': best_R_mse, 418 | 'best_RMSE(R)': best_R_rmse, 419 | 'best_MAE(R)': best_R_mae, 420 | 'best_MSE(t)': best_t_mse, 421 | 'best_RMSE(t)': best_t_rmse, 422 | 'best_MAE(t)': best_t_mae, 423 | } 424 | if not os.path.isdir("./checkpoint"): 425 | os.mkdir("./checkpoint") 426 | torch.save(checkpoint, './checkpoint/ckpt%s.pth'%(str(file_type)+str(subsampled_rate_src)+str(subsampled_rate_tgt))) 427 | writer.close() 428 | 429 | 430 | 431 | 432 | -------------------------------------------------------------------------------- /OverlapReg/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import copy 6 | import math, sys, random 7 | sys.path.append("..") 8 | from feature_extract import PointNet, DGCNN 9 | from util import batch_transform, batch_quat2mat 10 | 11 | 12 | def mask_point(mask_idx, points): 13 | # masks: [b, n] : Tensor, 包含0和1 14 | # points: [b, 3, n] : Tensor 15 | # return: [b, 3, n2] : Tensor 16 | batch_size = points.shape[0] 17 | points = points.permute(0, 2, 1) 18 | mask_idx = mask_idx.reshape(batch_size, -1, 1) 19 | new_pcs = points * mask_idx 20 | new_points = [] 21 | 22 | for new_pc in new_pcs: 23 | # 删除被屏蔽的0点 24 | temp = new_pc[:, ...] 
def mask_point(mask_idx, points):
    """Keep only the points selected by a binary mask.

    The previous implementation multiplied the points by the mask and then
    deleted every all-zero row, which also removed genuine points located
    exactly at the origin. It also forced the result onto CUDA whenever a GPU
    was present. Selection is now done directly with the mask and the result
    stays on the device of ``points``.

    Args:
        mask_idx: Tensor reshapeable to [b, n]; non-zero entries mark the
            points to keep (the callers pass 0/1 masks).
        points: Tensor of shape [b, 3, n].

    Returns:
        Tensor of shape [b, 3, n2] holding the kept points in their original
        order. As before, the result is detached from the autograd graph
        (the old code went through NumPy).

    Raises:
        ValueError: if the clouds of the batch keep different numbers of
            points, in which case they cannot be stacked into one tensor.
    """
    batch_size, num_dims = points.shape[0], points.shape[1]
    keep = (mask_idx.reshape(batch_size, -1) != 0).to(points.device)
    kept_per_cloud = keep.sum(dim=1)
    if batch_size > 0 and not bool((kept_per_cloud == kept_per_cloud[0]).all()):
        raise ValueError('every cloud in the batch must keep the same number of points')
    num_kept = int(kept_per_cloud[0]) if batch_size > 0 else 0
    # [b, n, 3] indexed by a [b, n] boolean mask -> [b * n2, 3], batch-major.
    selected = points.detach().permute(0, 2, 1)[keep]
    return selected.reshape(batch_size, num_kept, num_dims).permute(0, 2, 1)


def gather_points(points, inds):
    '''
    Gather points by index, independently for every batch element.

    :param points: shape=(B, N, C)
    :param inds: shape=(B, M) or shape=(B, M, K)
    :return: sampling points: shape=(B, M, C) or shape=(B, M, K, C)
    '''
    device = points.device
    B, N, C = points.shape
    # Build a batch-index tensor with the same shape as `inds` so that
    # advanced indexing pairs every index with its own batch element.
    inds_shape = list(inds.shape)
    inds_shape[1:] = [1] * len(inds_shape[1:])
    repeat_shape = list(inds.shape)
    repeat_shape[0] = 1
    batchlists = torch.arange(0, B, dtype=torch.long).to(device).reshape(inds_shape).repeat(repeat_shape)
    return points[batchlists, inds, :]


def feature_interaction(src_embedding, tar_embedding):
    """Augment per-point source features with target-aware context.

    Args:
        src_embedding: (batch, emb_dims, num_points1)
        tar_embedding: (batch, emb_dims, num_points2)

    Returns:
        (batch, 4 * emb_dims, num_points1): for every source point the
        concatenation of its own feature, the similarity-weighted target
        feature, the max-pooled source feature and their difference.
    """
    num_points1 = src_embedding.shape[2]

    simi1 = cos_simi(src_embedding, tar_embedding)  # (batch, num_points1, num_points2)

    src_embedding = src_embedding.permute(0, 2, 1)
    tar_embedding = tar_embedding.permute(0, 2, 1)

    simi_src = nn.Softmax(dim=2)(simi1)  # turn similarities into probabilities
    glob_tar = torch.matmul(simi_src, tar_embedding)  # probability-weighted target features
    glob_src = torch.max(src_embedding, dim=1, keepdim=True)[0]
    glob_src = glob_src.repeat(1, num_points1, 1)
    inter_src_feature = torch.cat((src_embedding, glob_tar, glob_src, glob_tar - glob_src), dim=2)
    inter_src_feature = inter_src_feature.permute(0, 2, 1)

    return inter_src_feature


def cos_simi(src_embedding, tgt_embedding):
    """Pairwise cosine similarity between two sets of point features.

    Args:
        src_embedding: (batch, emb_dims, num_points1)
        tgt_embedding: (batch, emb_dims, num_points2)

    Returns:
        (batch, num_points1, num_points2) similarity matrix.
    """
    src_norm = F.normalize(src_embedding, p=2, dim=1)
    tar_norm = F.normalize(tgt_embedding, p=2, dim=1)
    simi = torch.matmul(src_norm.transpose(2, 1).contiguous(), tar_norm)
    return simi
class MLPs(nn.Module):
    """Stack of fully connected layers with a ReLU between (not after) them."""

    def __init__(self, in_dim, mlps):
        """
        :param in_dim: size of the input feature
        :param mlps: output size of every linear layer, in order
        """
        super(MLPs, self).__init__()
        self.mlps = nn.Sequential()
        l = len(mlps)
        for i, out_dim in enumerate(mlps):
            self.mlps.add_module(f'fc_{i}', nn.Linear(in_dim, out_dim))
            if i != l - 1:
                self.mlps.add_module(f'relu_{i}', nn.ReLU(inplace=True))
            in_dim = out_dim

    def forward(self, x):
        x = self.mlps(x)
        return x


class InitReg(nn.Module):
    """Coarse registration: regress a 7-D pose (translation + quaternion)
    from the max-pooled features of both clouds."""

    def __init__(self):
        super(InitReg, self).__init__()
        self.num_dims = 512
        self.encoder = PointNet(n_emb_dims=self.num_dims)
        self.decoder = MLPs(in_dim=self.num_dims*2, mlps=[512, 512, 256, 7])

    def forward(self, src, tgt):
        """
        :param src: (batch, 3, n)
        :param tgt: (batch, 3, n)
        :return: batch_R as produced by batch_quat2mat, batch_t (batch, 3)
        """
        src_emb = self.encoder(src)
        tgt_emb = self.encoder(tgt)
        src_glob, _ = torch.max(src_emb, dim=2)
        tgt_glob, _ = torch.max(tgt_emb, dim=2)
        pose7d = self.decoder(torch.cat((src_glob, tgt_glob), dim=1))
        # First three values are the translation; the last four are
        # normalised into a unit quaternion (epsilon avoids division by zero).
        batch_t, batch_quat = pose7d[:, :3], pose7d[:, 3:] / (
                torch.norm(pose7d[:, 3:], dim=1, keepdim=True) + 1e-8)
        batch_R = batch_quat2mat(batch_quat)

        return batch_R, batch_t


class RegNet(nn.Module):
    """Fine registration network: feature extraction, cross-cloud feature
    interaction, keypoint selection, soft correspondences and an SVD solve."""

    def __init__(self, n_emb_dims=1024):
        super(RegNet, self).__init__()
        self.emb_dims = n_emb_dims
        self.emb_dims1 = int(self.emb_dims / 2)
        self.emb_nn1 = DGCNN(self.emb_dims1, k=32)
        self.init_reg = InitReg()
        self.emb_nn2_src = nn.Sequential(
            nn.Conv1d(self.emb_dims1 * 4, self.emb_dims1 * 2, 1), nn.BatchNorm1d(self.emb_dims1 * 2),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(self.emb_dims1 * 2, self.emb_dims, 1), nn.BatchNorm1d(self.emb_dims),
            nn.LeakyReLU(negative_slope=0.01),
        )
        self.emb_nn2_tgt = nn.Sequential(
            nn.Conv1d(self.emb_dims1 * 4, self.emb_dims1 * 2, 1), nn.BatchNorm1d(self.emb_dims1 * 2),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(self.emb_dims1 * 2, self.emb_dims, 1), nn.BatchNorm1d(self.emb_dims),
            nn.LeakyReLU(negative_slope=0.01),
        )
        # Plain CPU tensor counting SVD calls made in training mode.
        self.my_iter = torch.ones(1)
        self.reflect = nn.Parameter(torch.eye(3), requires_grad=False)
        self.reflect[2, 2] = -1

    def generate_keypoints(self, src, tgt, src_embedding, tgt_embedding):
        """Keep, in each cloud, the 90% of points whose best match in the
        other cloud has the highest cosine similarity.

        :param src, tgt: (batch, n, 3)
        :param src_embedding, tgt_embedding: (batch, emb_dims, num_points)
        :return: key points (batch, 3, k) and key features (batch, emb_dims, k)
                 for source and target, with k = int(0.9 * n)
        """
        num_points1 = src.shape[1]
        num_points2 = tgt.shape[1]
        simi1 = cos_simi(src_embedding, tgt_embedding)  # (batch, num_points1, num_points2)
        simi1 = torch.max(simi1, dim=2)[0]  # (batch, num_points1)
        _, indices = torch.topk(simi1, k=int(num_points1 * 0.9), dim=1, sorted=False)
        src_keypoints = gather_points(src, indices)
        src_embedding_key = gather_points(src_embedding.permute(0, 2, 1), indices)

        simi2 = cos_simi(tgt_embedding, src_embedding)  # (batch, num_points2, num_points1)
        simi2 = torch.max(simi2, dim=2)[0]  # (batch, num_points2)
        _, indices = torch.topk(simi2, k=int(num_points2 * 0.9), dim=1, sorted=False)
        tgt_keypoints = gather_points(tgt, indices)
        tgt_embedding_key = gather_points(tgt_embedding.permute(0, 2, 1), indices)

        return src_keypoints.permute(0, 2, 1), tgt_keypoints.permute(0, 2, 1), \
            src_embedding_key.permute(0, 2, 1), tgt_embedding_key.permute(0, 2, 1)

    def generate_corr(self, src, tgt, src_embedding, tar_embedding):
        """Soft correspondences: every point is matched to the
        similarity-weighted average of the points of the other cloud.

        :param src, tgt: (batch, n, 3)
        :param src_embedding, tar_embedding: (batch, emb_dims, num_points)
        :return: src_corr (batch, 3, n1), tgt_corr (batch, 3, n2)
        """
        simi1 = cos_simi(src_embedding, tar_embedding)  # (batch, n1, n2)
        simi2 = cos_simi(tar_embedding, src_embedding)  # (batch, n2, n1)

        simi_src = nn.Softmax(dim=2)(simi1)  # similarities -> probabilities
        src_corr = torch.matmul(simi_src, tgt)  # weighted target points, (batch, n1, 3)

        simi_tar = nn.Softmax(dim=2)(simi2)
        tgt_corr = torch.matmul(simi_tar, src)  # weighted source points, (batch, n2, 3)

        return src_corr.permute(0, 2, 1), tgt_corr.permute(0, 2, 1)

    def SVD(self, src, src_corr):
        """Closed-form rigid alignment (Kabsch) of ``src`` onto ``src_corr``.

        :param src: (batch, 3, n)
        :param src_corr: (batch, 3, n) correspondences of ``src``
        :return: R (batch, 3, 3) and t (batch, 3), on the device of ``src``
        """
        batch_size = src.shape[0]
        src_centered = src - src.mean(dim=2, keepdim=True)
        src_corr_centered = src_corr - src_corr.mean(dim=2, keepdim=True)

        # The decomposition is run on the CPU, one sample at a time.
        H = torch.matmul(src_centered, src_corr_centered.transpose(2, 1).contiguous()).cpu()
        R = []
        for i in range(src.size(0)):
            u, s, v = torch.svd(H[i])
            r = torch.matmul(v, u.transpose(1, 0)).contiguous()
            # det(r) is about -1 when the solution is a reflection; scaling
            # the last singular direction by it yields a proper rotation.
            r_det = torch.det(r).item()
            diag = torch.diag(torch.tensor([1.0, 1.0, r_det], dtype=v.dtype, device=v.device))
            r = torch.matmul(torch.matmul(v, diag), u.transpose(1, 0)).contiguous()
            R.append(r)

        # Return to the input's device; a hard-coded .cuda() here used to
        # crash on CPU-only machines.
        R = torch.stack(R, dim=0).to(src.device)
        t = torch.matmul(-R, src.mean(dim=2, keepdim=True)) + src_corr.mean(dim=2, keepdim=True)
        if self.training:
            self.my_iter += 1
        return R, t.view(batch_size, 3)

    def forward(self, src, tgt):
        """
        :param src, tgt: (batch, 3, n)
        :return: R1 (batch, 3, 3), t1 (batch, 3) aligning src to tgt
        """
        src_embedding = self.emb_nn1(src)
        tgt_embedding = self.emb_nn1(tgt)
        # Feature fusion between the two clouds.
        inter_src_feature = feature_interaction(src_embedding, tgt_embedding)
        inter_tar_feature = feature_interaction(tgt_embedding, src_embedding)
        # Further feature extraction.
        src_embedding = self.emb_nn2_src(inter_src_feature)
        tgt_embedding = self.emb_nn2_tgt(inter_tar_feature)
        src, tgt, src_embedding, tgt_embedding = self.generate_keypoints(
            src.permute(0, 2, 1), tgt.permute(0, 2, 1), src_embedding, tgt_embedding)
        # NaN features are replaced by one random constant per call.
        src_embedding = torch.where(torch.isnan(src_embedding), torch.full_like(src_embedding, random.random()),
                                    src_embedding)
        tgt_embedding = torch.where(torch.isnan(tgt_embedding), torch.full_like(tgt_embedding, random.random()),
                                    tgt_embedding)
        src_corr, tgt_corr = self.generate_corr(src.permute(0, 2, 1), tgt.permute(0, 2, 1), src_embedding, tgt_embedding)
        R1, t1 = self.SVD(src, src_corr)

        return R1, t1
class OverlapNet(nn.Module):
    """End-to-end network: coarse pose, overlap-mask prediction on both
    clouds, then fine registration restricted to the predicted overlap."""

    def __init__(self, n_emb_dims=1024, all_points=1024, src_subsampled_points=768, tgt_subsampled_points=768):
        """
        :param n_emb_dims: size of the final per-point embedding
        :param all_points: number of points of the complete shape
        :param src_subsampled_points: number of points in the partial source
        :param tgt_subsampled_points: number of points in the partial target
        """
        super(OverlapNet, self).__init__()
        self.emb_dims = n_emb_dims
        self.all_points = all_points
        self.emb_dims1 = int(self.emb_dims / 2)
        self.src_subsampled_points = src_subsampled_points
        self.tgt_subsampled_points = tgt_subsampled_points
        self.reg_net = RegNet(self.emb_dims)
        self.emb_nn1 = DGCNN(self.emb_dims1, k=32)
        self.init_reg = InitReg()
        self.emb_nn2_src = nn.Sequential(
            nn.Conv1d(self.emb_dims1 * 4, self.emb_dims1 * 2, 1), nn.BatchNorm1d(self.emb_dims1 * 2),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(self.emb_dims1 * 2, self.emb_dims, 1), nn.BatchNorm1d(self.emb_dims),
            nn.LeakyReLU(negative_slope=0.01),
        )
        self.emb_nn2_tgt = nn.Sequential(
            nn.Conv1d(self.emb_dims1 * 4, self.emb_dims1 * 2, 1), nn.BatchNorm1d(self.emb_dims1 * 2),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(self.emb_dims1 * 2, self.emb_dims, 1), nn.BatchNorm1d(self.emb_dims),
            nn.LeakyReLU(negative_slope=0.01),
        )
        self.score_nn_src = nn.Sequential(
            nn.Conv1d(self.emb_dims, 512, 1), nn.BatchNorm1d(512), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(512, 128, 1), nn.BatchNorm1d(128), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(128, 128, 1), nn.BatchNorm1d(128), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(128, 1, 1), nn.Sigmoid()
        )
        self.score_nn_tgt = nn.Sequential(
            nn.Conv1d(self.emb_dims, 512, 1), nn.BatchNorm1d(512), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(512, 128, 1), nn.BatchNorm1d(128), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(128, 128, 1), nn.BatchNorm1d(128), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(128, 1, 1), nn.Sigmoid()
        )
        # The mask heads treat the similarity row of a point (one value per
        # point of the *other* cloud) as its channel vector, hence the input
        # channel count equals the size of the other cloud.
        self.mask_src_nn = nn.Sequential(
            nn.Conv1d(self.tgt_subsampled_points, 1024, 1), nn.BatchNorm1d(1024), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(1024, 512, 1), nn.BatchNorm1d(512), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(512, 128, 1), nn.BatchNorm1d(128), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(128, 128, 1), nn.BatchNorm1d(128), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(128, 2, 1)
        )
        self.mask_tgt_nn = nn.Sequential(
            nn.Conv1d(self.src_subsampled_points, 1024, 1), nn.BatchNorm1d(1024), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(1024, 512, 1), nn.BatchNorm1d(512), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(512, 128, 1), nn.BatchNorm1d(128), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(128, 128, 1), nn.BatchNorm1d(128), nn.LeakyReLU(negative_slope=0.01),
            nn.Conv1d(128, 2, 1)
        )

    def forward(self, *input):
        """
        :param input: ``(src, tgt)``, each of shape (batch, 3, n)
        :return: init_R, init_t (coarse pose), mask_src, mask_tgt (two-class
                 logits, (batch, 2, n)), mask_src_idx, mask_tgt_idx (hard 0/1
                 masks, (batch, n)), R1, t1 (fine pose on the masked clouds)
        :raises ValueError: if the configured point counts imply no overlap,
                 or more overlap points than a cloud contains
        """
        src = input[0]
        tgt = input[1]
        batch_size = src.shape[0]
        # Coarse alignment first; everything below works on the moved source.
        init_R, init_t = self.init_reg(src, tgt)
        src = torch.matmul(init_R, src) + init_t.reshape(batch_size, 3, 1)

        src_embedding = self.emb_nn1(src)
        tgt_embedding = self.emb_nn1(tgt)
        # Feature fusion between the two clouds.
        inter_src_feature = feature_interaction(src_embedding, tgt_embedding)
        inter_tar_feature = feature_interaction(tgt_embedding, src_embedding)
        # Further feature extraction.
        src_embedding = self.emb_nn2_src(inter_src_feature)
        tgt_embedding = self.emb_nn2_tgt(inter_tar_feature)
        # Per-point scores, normalised over the points of each cloud.
        src_score = self.score_nn_src(src_embedding).reshape(batch_size, 1, -1)
        tgt_score = self.score_nn_tgt(tgt_embedding).reshape(batch_size, 1, -1)

        src_score = nn.Softmax(dim=2)(src_score)
        tgt_score = nn.Softmax(dim=2)(tgt_score)

        simi1 = cos_simi(src_embedding, tgt_embedding)
        simi2 = cos_simi(tgt_embedding, src_embedding)

        # Weight the similarities by the score of the matched point.
        simi_src = simi1 * tgt_score
        simi_tgt = simi2 * src_score

        mask_src = self.mask_src_nn(simi_src.permute(0, 2, 1))
        mask_tgt = self.mask_tgt_nn(simi_tgt.permute(0, 2, 1))
        # Points shared by both partial clouds: all - missing_src - missing_tgt.
        overlap_points = self.all_points - (self.all_points - self.src_subsampled_points) \
            - (self.all_points - self.tgt_subsampled_points)

        mask_src_score = torch.softmax(mask_src, dim=1)[:, 1, :].detach()  # (B, N)
        mask_tgt_score = torch.softmax(mask_tgt, dim=1)[:, 1, :].detach()
        if not 0 < overlap_points <= min(mask_src_score.shape[1], mask_tgt_score.shape[1]):
            raise ValueError('invalid number of overlap points: %d' % overlap_points)

        # Mark the overlap_points highest-scoring points of each cloud.
        # The masks are created on the device of the scores instead of an
        # unconditional .cuda(), so the model also runs on the CPU.
        mask_src_idx = torch.zeros(mask_src_score.shape, device=mask_src_score.device)
        _, indices = torch.topk(mask_src_score, k=overlap_points, dim=1)
        mask_src_idx.scatter_(1, indices, 1)  # (dim, index, value written at index)

        mask_tgt_idx = torch.zeros(mask_tgt_score.shape, device=mask_tgt_score.device)
        _, indices = torch.topk(mask_tgt_score, k=overlap_points, dim=1)
        mask_tgt_idx.scatter_(1, indices, 1)

        src = mask_point(mask_src_idx, src)
        tgt = mask_point(mask_tgt_idx, tgt)

        R1, t1 = self.reg_net(src, tgt)

        return init_R, init_t, mask_src, mask_tgt, mask_src_idx, mask_tgt_idx, R1, t1
def batch_transform(batch_pc, batch_R, batch_t=None):
    '''
    Apply a rigid transform to every point cloud of a batch.

    :param batch_pc: shape=(B, N, 3)
    :param batch_R: shape=(B, 3, 3)
    :param batch_t: shape=(B, 3), optional translation
    :return: shape(B, N, 3)
    '''
    # Points are row vectors, so R is applied as p @ R^T.
    transformed_pc = torch.matmul(batch_pc, batch_R.permute(0, 2, 1).contiguous())
    if batch_t is not None:
        transformed_pc = transformed_pc + torch.unsqueeze(batch_t, 1)
    return transformed_pc


def batch_quat2mat(batch_quat):
    '''
    Convert quaternions in (w, x, y, z) order to rotation matrices.

    The formula assumes unit quaternions; the input is not re-normalised.

    :param batch_quat: shape=(B, 4); extra singleton dimensions such as
                       (B, 1, 4) are accepted
    :return: shape=(B, 3, 3), float32
    '''
    # reshape instead of squeeze(): squeeze() also removed the batch
    # dimension when B == 1, which made the column indexing below fail.
    batch_quat = batch_quat.reshape(-1, 4)
    w, x, y, z = batch_quat[:, 0], batch_quat[:, 1], batch_quat[:, 2], \
                 batch_quat[:, 3]
    device = batch_quat.device
    B = batch_quat.size()[0]
    R = torch.zeros(dtype=torch.float, size=(B, 3, 3), device=device)
    R[:, 0, 0] = 1 - 2 * y * y - 2 * z * z
    R[:, 0, 1] = 2 * x * y - 2 * z * w
    R[:, 0, 2] = 2 * x * z + 2 * y * w
    R[:, 1, 0] = 2 * x * y + 2 * z * w
    R[:, 1, 1] = 1 - 2 * x * x - 2 * z * z
    R[:, 1, 2] = 2 * y * z - 2 * x * w
    R[:, 2, 0] = 2 * x * z - 2 * y * w
    R[:, 2, 1] = 2 * y * z + 2 * x * w
    R[:, 2, 2] = 1 - 2 * x * x - 2 * y * y
    return R
Qin, and Qiguang Miao. Details are in the [paper](https://ieeexplore.ieee.org/document/10168979). 4 | 5 | ## Usage 6 | 7 | 1. Clone the repository. 8 | 9 | 2. Set the "DATA_DIR" variable in the "data_utils.py" file to the path of your own dataset folder. 10 | 11 | 3. Run "main.py" in the OverlapDetect folder and save the pkl file; then load the pkl file trained by OverlapDetect and run "main.py" in the OverlapReg folder. 12 | **Note**: the "OverlapNet" model definition must be identical in the OverlapDetect folder and the OverlapReg folder. 13 | 14 | 4. For convenience, we also provide end-to-end training (run OverlapReg/main.py directly), but this may cause some loss of accuracy. 15 | 16 | ## Requirements 17 | 18 | h5py=3.7.0 19 | 20 | open3d=0.15.2 21 | 22 | pytorch=1.11.0 23 | 24 | scikit-learn=1.1.1 25 | 26 | transforms3d=0.4.1 27 | 28 | tensorboardX=1.15.0 29 | 30 | tqdm 31 | 32 | numpy 33 | 34 | ## Dataset 35 | 36 | (1) [ModelNet40](https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip) 37 | 38 | (2) [KITTI_odo](https://www.cvlibs.net/datasets/kitti/eval_odometry.php) 39 | 40 | (3) [Stanford Bunny](http://graphics.stanford.edu/data/3Dscanrep/) 41 | 42 | ## Citation 43 | 44 | If you find the code or trained models useful, please consider citing: 45 | 46 | ``` 47 | @article{2023rornet, 48 | title={RORNet: Partial-to-Partial Registration Network With Reliable Overlapping Representations}, 49 | author={Wu, Yue and Zhang, Yue and Ma, Wenping and Gong, Maoguo and Fan, Xiaolong and Zhang, Mingyang and Qin, AK and Miao, Qiguang}, 50 | journal={IEEE Transactions on Neural Networks and Learning Systems}, 51 | year={2023}, 52 | publisher={IEEE} 53 | } 54 | ``` 55 | 56 | ## Acknowledgement 57 | 58 | Our code builds on [PointNet](https://github.com/fxia22/pointnet.pytorch), [DCP](https://github.com/WangYueFt/dcp) and [MaskNet](https://github.com/vinits5/masknet). We thank the authors of these open-source projects. 
# ModelNet40 arrays: (9840, 2048, 3), (9840, 1)
# download in:https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip
DATA_DIR = '/home/zy/dataset'


def load_data(partition, file_type='modelnet40'):
    """Load the point clouds of one dataset split from ``DATA_DIR``.

    Args:
        partition: split name; selects the KITTI sequences or the
            ``ply_data_<partition>*.h5`` ModelNet40 files. Ignored for
            ``'bunny'``.
        file_type: ``'modelnet40'``, ``'Kitti_odo'`` or ``'bunny'``.

    Returns:
        ``(data, label)`` NumPy arrays. For KITTI every label is 0; for
        bunny the label array is empty.

    Raises:
        ValueError: if ``file_type`` is not supported (this used to print a
            message and then fail on an unbound variable).
        FileNotFoundError: if no ModelNet40 file matches the split.
    """
    if file_type not in ('Kitti_odo', 'modelnet40', 'bunny'):
        raise ValueError('Error file name! Unsupported file_type: %r' % (file_type,))

    if file_type == 'Kitti_odo':
        file_name = 'KITTI_odometry/sequences/'
        DATA_FILES = {
            'train': ['00', '01', '02', '03', '04', '05'],  # 0,1,2,3,4,5
            'test': ['08', '09', '10'],  # 8,9,10
        }
        all_data = []
        all_label = []
        for idx in DATA_FILES[partition]:
            for fname in glob.glob(os.path.join(DATA_DIR, file_name, idx, "velodyne", "*.bin")):
                # Each record holds four float32 values; only the first three
                # columns of the first 4096 records are kept.
                template_data = np.fromfile(fname, dtype=np.float32).reshape(-1, 4)
                points = template_data[:4096, :3]
                all_data.append(points)
                all_label.append(0)
        return np.array(all_data), np.array(all_label)

    if file_type == 'bunny':
        file_name = 'bunny/data/'
        all_data = []
        all_label = []
        for h5_name in glob.glob(os.path.join(DATA_DIR, file_name, '*.ply')):
            pc = o3d.io.read_point_cloud(h5_name)
            points = normalize_pc(np.array(pc.points))
            # Randomly keep at most 4096 points.
            points_idx = np.arange(points.shape[0])
            np.random.shuffle(points_idx)
            points = points[points_idx[:4096], :]
            all_data.append(points)

        return np.array(all_data), np.array(all_label)

    file_name = 'modelnet40_ply_hdf5_2048'
    all_data = []
    all_label = []
    for h5_name in glob.glob(os.path.join(DATA_DIR, file_name, 'ply_data_%s*.h5' % partition)):
        # The context manager closes the file even if reading fails.
        with h5py.File(h5_name, 'r') as f:
            data = f['data'][:].astype('float32')
            # Never true for the file types handled above; kept from the
            # original in case S3DIS support is re-enabled.
            if file_name == 'S3DIS_hdf5':
                data = data[:, :, 0:3]
            label = f['label'][:].astype('int64')
        all_data.append(data)
        all_label.append(label)
    if not all_data:
        raise FileNotFoundError(
            'no ply_data_%s*.h5 files found in %s' % (partition, os.path.join(DATA_DIR, file_name)))
    all_data = np.concatenate(all_data, axis=0)
    all_label = np.concatenate(all_label, axis=0)
    return all_data, all_label  # (9840, 2048, 3), (9840, 1)


def normalize_pc(point_cloud):
    """Centre a cloud at its centroid and scale it into the unit sphere.

    The array is modified in place and also returned.
    """
    centroid = np.mean(point_cloud, axis=0)
    point_cloud -= centroid
    furthest_distance = np.max(np.sqrt(np.sum(abs(point_cloud) ** 2, axis=-1)))
    point_cloud /= furthest_distance
    return point_cloud


def add_outliers(pointcloud, gt_mask, num_outliers=20):
    """Append uniformly sampled outliers and shuffle cloud and mask together.

    Args:
        pointcloud: point cloud [N x C], ndarray or tensor.
        gt_mask: 1-D tensor of length N.
        num_outliers: number of outliers to add (previously fixed to 20).

    Returns:
        Corrupted cloud as ndarray [(N + num_outliers) x C] and the mask
        tensor extended with zeros for the outliers, in the same shuffled
        order.
    """
    if isinstance(pointcloud, np.ndarray):
        pointcloud = torch.from_numpy(pointcloud)

    N, C = pointcloud.shape
    outliers = 2*torch.rand(num_outliers, C)-1  # sample points in the cube [-1, 1]
    pointcloud = torch.cat([pointcloud, outliers], dim=0)
    gt_mask = torch.cat([gt_mask, torch.zeros(num_outliers)])

    idx = torch.randperm(pointcloud.shape[0])
    pointcloud, gt_mask = pointcloud[idx], gt_mask[idx]
    return pointcloud.numpy(), gt_mask


# Add clipped Gaussian noise.
def jitter_pointcloud(pointcloud, sigma=0.2, clip=0.05):
    """Add Gaussian noise of scale ``sigma``, clipped to ``[-clip, clip]``.

    The array is modified in place and also returned.
    """
    N, C = pointcloud.shape
    pointcloud += np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
    return pointcloud
def _nearest_point_indices(xyz, query, k):
    """Indices of the ``k`` rows of ``xyz`` closest to ``query``, nearest first."""
    if k > xyz.shape[0]:
        raise ValueError('cannot select %d points from a cloud of %d points' % (k, xyz.shape[0]))
    distances = np.linalg.norm(xyz - query, axis=1)
    return np.argsort(distances, kind='stable')[:k]


def Farthest_Point_Sampling(pointcloud1, src_subsampled_points, tgt_subsampled_points=None):
    """Crop one or two partial views out of a point cloud.

    Despite its name this is not farthest point sampling: a view consists of
    the points nearest to a far-away random viewpoint. The source viewpoint
    lies near (+500, +500, +500) and the target viewpoint near
    (-500, -500, -500), so the two views are cut from opposite sides.

    The neighbour search is a plain NumPy distance sort. It replaces two
    scikit-learn models fitted on the same data with a Python-callable
    metric, which evaluated the same Euclidean distance one pair at a time.

    Args:
        pointcloud1: array of shape (num_points, >=3); the first three
            columns are used as coordinates.
        src_subsampled_points: number of points in the source view.
        tgt_subsampled_points: number of points in the target view, or
            ``None`` to produce the source view only.

    Returns:
        Without a target size: ``(src_points, src_mask)``, the points
        ordered nearest-first.
        With a target size: ``(src_points, src_mask, tgt_points, tgt_mask)``,
        the points of each view ordered by their index in the input.
        Masks are float tensors of length num_points with 1 at selected
        indices.

    Raises:
        ValueError: if more points are requested than the cloud contains.
    """
    num_points = pointcloud1.shape[0]
    xyz = np.asarray(pointcloud1[:, :3], dtype=np.float64)

    if tgt_subsampled_points is None:
        random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]])
        idx1 = _nearest_point_indices(xyz, random_p1, src_subsampled_points)
        gt_mask_src = torch.zeros(num_points).scatter_(0, torch.tensor(idx1), 1)

        return pointcloud1[idx1, :], gt_mask_src

    # A single random draw, shifted to opposite corners for the two views.
    random = np.random.random(size=(1, 3))

    random_p1 = random + np.array([[500, 500, 500]])
    src = _nearest_point_indices(xyz, random_p1, src_subsampled_points)
    mask_src = torch.zeros(num_points).scatter_(0, torch.tensor(src), 1)
    src = np.sort(src)

    random_p2 = random - np.array([[500, 500, 500]])
    tgt = _nearest_point_indices(xyz, random_p2, tgt_subsampled_points)
    mask_tgt = torch.zeros(num_points).scatter_(0, torch.tensor(tgt), 1)
    tgt = np.sort(tgt)

    return pointcloud1[src, :], mask_src, pointcloud1[tgt, :], mask_tgt
class ModelNet40_Reg(Dataset):
    """Partial-to-partial registration pairs generated on the fly.

    Every item crops a source and a target view from one shape, moves the
    target by a random rigid transform and returns both clouds together with
    the transform and the ground-truth overlap masks.
    """

    def __init__(self, num_points, subsampled_rate_src, subsampled_rate_tgt, partition='train', max_angle=45, max_t=0.5,
                 noise=False, partial_overlap=2, unseen=False, file_type='modelnet40'):
        """
        :param num_points: number of points taken from every shape
        :param subsampled_rate_src: fraction of points kept in the source view
        :param subsampled_rate_tgt: fraction of points kept in the target view
        :param partition: 'train' or 'test'
        :param max_angle: maximum rotation per axis, in degrees
        :param max_t: maximum translation per axis
        :param noise: add clipped Gaussian noise to the target
        :param partial_overlap: number of partial clouds; only 2 is supported
        :param unseen: split ModelNet40 categories between train and test
        :param file_type: dataset name passed on to load_data
        """
        self.partial_overlap = partial_overlap
        self.data, self.label = load_data(partition, file_type=file_type)
        self.num_points = num_points
        self.file_type = file_type
        self.partition = partition
        self.label = self.label.squeeze()  # drop singleton dimensions
        self.max_angle = np.pi / 180 * max_angle
        self.max_t = max_t
        self.noise = noise
        self.unseen = unseen
        self.subsampled_rate_src = subsampled_rate_src
        self.subsampled_rate_tgt = subsampled_rate_tgt

        if file_type == 'modelnet40' and self.unseen:
            # Train on the first 20 categories, test on the last 20.
            if self.partition == 'test':
                self.data = self.data[self.label >= 20]
                self.label = self.label[self.label >= 20]
            elif self.partition == 'train':
                self.data = self.data[self.label < 20]
                self.label = self.label[self.label < 20]

    def __getitem__(self, item):
        """
        :return: source (3, n_src), target (3, n_tgt), R_ab (3, 3),
                 translation_ab (3,), euler_ab (3,) in z, y, x order,
                 gt_mask_src (n_src,), gt_mask_tgt (n_tgt,); a mask entry is
                 1 where the point also appears in the other view
        :raises ValueError: if partial_overlap is not 2
        """
        pointcloud = self.data[item][:self.num_points]

        anglex = np.random.uniform(-self.max_angle, self.max_angle)
        angley = np.random.uniform(-self.max_angle, self.max_angle)
        anglez = np.random.uniform(-self.max_angle, self.max_angle)
        cosx = np.cos(anglex)
        cosy = np.cos(angley)
        cosz = np.cos(anglez)
        sinx = np.sin(anglex)
        siny = np.sin(angley)
        sinz = np.sin(anglez)
        Rx = np.array([[1, 0, 0],
                       [0, cosx, -sinx],
                       [0, sinx, cosx]])
        Ry = np.array([[cosy, 0, siny],
                       [0, 1, 0],
                       [-siny, 0, cosy]])
        Rz = np.array([[cosz, -sinz, 0],
                       [sinz, cosz, 0],
                       [0, 0, 1]])
        R_ab = Rx.dot(Ry).dot(Rz)
        rotation_ab = Rotation.from_euler('zyx', [anglez, angley, anglex])
        euler_ab = np.asarray([anglez, angley, anglex])
        # Random translation t.
        translation_ab = np.array([np.random.uniform(-self.max_t, self.max_t), np.random.uniform(-self.max_t, self.max_t),
                                   np.random.uniform(-self.max_t, self.max_t)])
        pointcloud1 = pointcloud

        if self.partial_overlap != 2:
            raise ValueError('partial_overlap must be 2!')

        # Crop both partial views from the same shape.
        src_subsampled_points = int(self.subsampled_rate_src * pointcloud1.shape[0])
        tgt_subsampled_points = int(self.subsampled_rate_tgt * pointcloud1.shape[0])
        pointcloud1, mask_src, pointcloud2, mask_tgt = Farthest_Point_Sampling(
            pointcloud1, src_subsampled_points, tgt_subsampled_points)

        # Both views keep their points in input order, so the overlap label
        # of the k-th source point is the target mask at the k-th selected
        # source index (and vice versa). This replaces a Python loop over
        # every point.
        gt_mask_src = mask_tgt[mask_src == 1]
        gt_mask_tgt = mask_src[mask_tgt == 1]

        pointcloud2 = rotation_ab.apply(pointcloud2).T + np.expand_dims(translation_ab, axis=1)
        # Shuffle the source points; restoring the RNG state makes the mask
        # receive exactly the same permutation.
        state = np.random.get_state()
        pointcloud1 = np.random.permutation(pointcloud1).T
        np.random.set_state(state)
        gt_mask_src = np.random.permutation(gt_mask_src.numpy()).T

        if self.noise:
            # jitter_pointcloud expects (num_points, 3)
            pointcloud2 = jitter_pointcloud(pointcloud2.T)
            pointcloud2 = pointcloud2.T

        return pointcloud1.astype('float32'), pointcloud2.astype('float32'), R_ab.astype('float32'), \
            translation_ab.astype('float32'), euler_ab.astype('float32'), gt_mask_src, gt_mask_tgt

    def __len__(self):
        return self.data.shape[0]
def calculate_R_msemae(r1, r2, seq='zyx', degrees=True):
    '''
    Calculate mse, mae euler angle error.

    Both batches are converted to Euler angles with SciPy's batched API
    (one call per batch instead of one per sample). The metric is symmetric
    in its two arguments.

    :param r1: shape=(B, 3, 3), tensor or ndarray
    :param r2: shape=(B, 3, 3), tensor or ndarray
    :param seq: Euler axis sequence passed to scipy's as_euler
    :param degrees: report the angles in degrees instead of radians
    :return: (mse, mae) averaged over the three angles and the batch
    :raises ValueError: if the two inputs differ in shape
    '''
    if isinstance(r1, torch.Tensor):
        r1 = r1.cpu().detach().numpy()
    if isinstance(r2, torch.Tensor):
        r2 = r2.cpu().detach().numpy()
    if r1.shape != r2.shape:
        raise ValueError('shape mismatch: %s vs %s' % (r1.shape, r2.shape))
    eulers1 = Rotation.from_matrix(r1).as_euler(seq, degrees=degrees)
    eulers2 = Rotation.from_matrix(r2).as_euler(seq, degrees=degrees)
    r_mse = np.mean((eulers1 - eulers2)**2, axis=-1)
    r_mae = np.mean(np.abs(eulers1 - eulers2), axis=-1)

    return np.mean(r_mse), np.mean(r_mae)


def calculate_t_msemae(t1, t2):
    '''
    calculate translation mse and mae error.

    :param t1: shape=(B, 3), tensor or ndarray
    :param t2: shape=(B, 3), tensor or ndarray
    :return: (mse, mae) averaged over the three components and the batch
    :raises ValueError: if the two inputs differ in shape
    '''
    if isinstance(t1, torch.Tensor):
        t1 = t1.cpu().detach().numpy()
    if isinstance(t2, torch.Tensor):
        t2 = t2.cpu().detach().numpy()
    if t1.shape != t2.shape:
        raise ValueError('shape mismatch: %s vs %s' % (t1.shape, t2.shape))
    t_mse = np.mean((t1 - t2) ** 2, axis=1)
    t_mae = np.mean(np.abs(t1 - t2), axis=1)
    return np.mean(t_mse), np.mean(t_mae)
38 | :param t1: shape=(B, 3) 39 | :param t2: shape=(B, 3) 40 | :return: 41 | ''' 42 | if isinstance(t1, torch.Tensor): 43 | t1 = t1.cpu().detach().numpy() 44 | if isinstance(t2, torch.Tensor): 45 | t2 = t2.cpu().detach().numpy() 46 | assert t1.shape == t2.shape 47 | t_mse = np.mean((t1 - t2) ** 2, axis=1) 48 | t_mae = np.mean(np.abs(t1 - t2), axis=1) 49 | return np.mean(t_mse), np.mean(t_mae) 50 | 51 | 52 | def find_errors(gt_R, pred_R, gt_t, pred_t): 53 | # gt_R: ground truth Rotation matrix [3, 3] 54 | # pred_R: predicted rotation matrix [3, 3] 55 | # gt_t: ground truth translation vector [1, 3] 56 | # pred_t: predicted translation matrix [1, 3] 57 | 58 | translation_error = np.sqrt(np.sum(np.square(gt_t - pred_t))) 59 | # Convert matrix remains to axis angle representation and report the angle as rotation error. 60 | error_mat = np.dot(gt_R.T, pred_R) # matrix remains [3, 3] 61 | rad = transforms3d.axangles.mat2axangle(error_mat)[1] # 返回弧度 62 | angle = abs(rad*(180/np.pi)) 63 | 64 | # 另一种方法计算角度误差 65 | rad1 = transforms3d.axangles.mat2axangle(pred_R)[1] 66 | rad2 = transforms3d.axangles.mat2axangle(gt_R)[1] 67 | angle_our = abs(rad1*(180/np.pi) - rad2*(180/np.pi)) % 360 68 | return translation_error, angle, angle_our 69 | 70 | 71 | def compute_error(rotation, rotation_pred, translation, translation_pred): 72 | # 输入batch个数据 73 | errors = [] 74 | # 计算旋转角度误差和平移误差 75 | if isinstance(rotation, torch.Tensor): 76 | rotation = rotation.cpu().detach().numpy() 77 | if isinstance(rotation_pred, torch.Tensor): 78 | rotation_pred = rotation_pred.cpu().detach().numpy() 79 | if isinstance(translation, torch.Tensor): 80 | translation = translation.cpu().detach().numpy() 81 | if isinstance(translation_pred, torch.Tensor): 82 | translation_pred = translation_pred.cpu().detach().numpy() 83 | 84 | for gt_R_i, pred_R_i, gt_t_i, pred_t_i in \ 85 | zip(rotation, rotation_pred, translation, translation_pred): 86 | errors.append(find_errors(gt_R_i, pred_R_i, gt_t_i, pred_t_i)) 87 | # 
def quat2mat(quat):
    """Convert a batch of quaternions into rotation matrices.

    :param quat: tensor of shape (B, 4) whose columns are read in
        (x, y, z, w) order. The quaternions are not normalised here, so the
        result is a proper rotation only for unit quaternions.
    :return: tensor of shape (B, 3, 3)
    """
    x, y, z, w = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]

    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w*x, w*y, w*z
    xy, xz, yz = x*y, x*z, y*z

    # Nine entries in row-major order, stacked per sample then reshaped.
    rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
                          2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
                          2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).reshape(B, 3, 3)
    return rotMat


def transform_point_cloud(point_cloud, rotation, translation):
    """Apply a batched rigid transform ``R @ P + t`` to a point cloud.

    :param point_cloud: tensor multiplied on the left by the rotation;
        presumably (B, 3, N) -- confirm against the callers
    :param rotation: either a 2-D tensor of quaternions (converted with
        ``quat2mat``) or a tensor that is used directly as rotation matrices
    :param translation: tensor that gains a trailing axis so it broadcasts
        over the points; presumably (B, 3)
    :return: the transformed point cloud
    """
    if len(rotation.size()) == 2:
        # 2-D input means one quaternion per sample.
        rot_mat = quat2mat(rotation)
    else:
        rot_mat = rotation
    return torch.matmul(rot_mat, point_cloud) + translation.unsqueeze(2)


def evaluate_mask(mask, mask_gt):
    """Score predicted masks against ground-truth masks, sample by sample.

    :param mask: iterable of predicted label tensors, one per sample
    :param mask_gt: iterable of ground-truth label tensors, one per sample
    :return: ``(accuracy, precision, recall, f1)``, each the mean of the
        per-sample scikit-learn score
    """
    accs = []
    preciss = []
    recalls = []
    f1s = []
    for m, m_gt in zip(mask, mask_gt):
        # scikit-learn works on host memory.
        m = m.cpu()
        m_gt = m_gt.cpu()
        # m and m_gt are vectors with one label per point.
        accs.append(accuracy_score(m_gt, m))
        preciss.append(precision_score(m_gt, m, zero_division=0))
        recalls.append(recall_score(m_gt, m, zero_division=0))
        # Same zero_division policy as precision/recall: a sample with no
        # positives scores 0 instead of raising an UndefinedMetricWarning.
        f1s.append(f1_score(m_gt, m, zero_division=0))

    return np.mean(accs), np.mean(preciss), np.mean(recalls), np.mean(f1s)
# ----- DGCNN -----
def knn(x, k):
    """Return the indices of the k nearest neighbours of every point.

    :param x: tensor of shape (batch_size, num_dims, num_points)
    :param k: number of neighbours to select (each point counts as its own
        neighbour because its distance to itself is zero)
    :return: index tensor of shape (batch_size, num_points, k)
    """
    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    # This is the NEGATIVE squared Euclidean distance, -(|a|^2 - 2ab + |b|^2),
    # so topk (largest first) selects the closest points.
    distance = -xx - inner - xx.transpose(2, 1).contiguous()

    idx = distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    return idx


def get_graph_feature(x, k=20):
    """Build the edge features of the k-NN graph of a point-feature tensor.

    :param x: tensor whose first three dimensions are
        (batch_size, num_dims, num_points); a trailing singleton dimension,
        as produced by ``max(..., keepdim=True)``, is dropped
    :param k: number of neighbours per point
    :return: tensor of shape (batch_size, 2 * num_dims, num_points, k) whose
        first ``num_dims`` channels hold the neighbour features and whose
        last ``num_dims`` channels hold the centre-point features
    """
    x = x.view(*x.size()[:3])
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()

    # Offset every sample's indices so they address the flattened
    # (batch_size * num_points, num_dims) feature table below.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = idx + idx_base
    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    # expand() broadcasts the centre feature over the k neighbours without
    # materialising k copies; torch.cat below copies the data once anyway.
    x = x.view(batch_size, num_points, 1, num_dims).expand(-1, -1, k, -1)

    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
    # [batch_size, 2 * input feature dim, num_points, k]
    return feature


class DGCNN(nn.Module):
    """Point-wise feature extractor built from four k-NN graph convolutions.

    Each stage rebuilds the k-NN graph on the previous stage's features,
    applies a 1x1 convolution with batch norm and leaky ReLU, and max-pools
    over the neighbours. The four stage outputs (64 + 64 + 128 + 256 = 512
    channels) are concatenated and projected to ``n_emb_dims`` channels.
    """

    def __init__(self, n_emb_dims=512, k=20):
        """
        :param n_emb_dims: number of output channels per point
        :param k: number of neighbours used for every graph
        """
        super().__init__()
        self.k = k
        self.conv1 = nn.Conv2d(6, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False)
        self.conv4 = nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False)
        self.conv5 = nn.Conv2d(512, n_emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(n_emb_dims)

    def forward(self, x):
        """
        :param x: tensor of shape (batch_size, num_dims, num_points); conv1
            expects 6 edge channels, i.e. num_dims == 3
        :return: tensor of shape (batch_size, n_emb_dims, num_points)
        """
        batch_size, num_dims, num_points = x.size()
        x = get_graph_feature(x, self.k)
        x = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.2)
        x1 = x.max(dim=-1, keepdim=True)[0]  # pool over the k neighbours

        x = get_graph_feature(x1, self.k)
        x = F.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.2)
        x2 = x.max(dim=-1, keepdim=True)[0]

        x = get_graph_feature(x2, self.k)
        x = F.leaky_relu(self.bn3(self.conv3(x)), negative_slope=0.2)
        x3 = x.max(dim=-1, keepdim=True)[0]

        x = get_graph_feature(x3, self.k)
        x = F.leaky_relu(self.bn4(self.conv4(x)), negative_slope=0.2)
        x4 = x.max(dim=-1, keepdim=True)[0]

        x = torch.cat((x1, x2, x3, x4), dim=1)

        x = F.leaky_relu(self.bn5(self.conv5(x)), negative_slope=0.2).view(batch_size, -1, num_points)
        return x

# ----- END-DGCNN -----
class STNkd(nn.Module):
    """Spatial transformer that predicts a k x k feature-alignment matrix.

    Point-wise 1-D convolutions are max-pooled over the points into a 1024-d
    global descriptor, which three fully connected layers map to ``k * k``
    values. The identity matrix is added to that prediction.
    """

    def __init__(self, k=3):
        """
        :param k: number of input channels and side of the predicted matrix
        """
        super().__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        """
        :param x: tensor of shape (batch_size, k, num_points)
        :return: tensor of shape (batch_size, k, k)
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]  # global max-pool over the points
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Create the identity directly with x's dtype and on x's device:
        # a NumPy float32 array moved with .cuda() lands on the *current*
        # CUDA device (not necessarily x's) and promotes half-precision
        # outputs to float32. Broadcasting adds it to every sample.
        iden = torch.eye(self.k, dtype=x.dtype, device=x.device).flatten()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x


class PointNet(nn.Module):
    """PointNet-style point-wise feature extractor with a 64-d feature transform.

    The outputs of the four convolution stages (64 + 64 + 128 + 256 = 512
    channels) are concatenated and projected to ``n_emb_dims`` channels.
    """

    def __init__(self, n_emb_dims):
        """
        :param n_emb_dims: number of output channels per point
        """
        super().__init__()
        self.stn = STNkd(k=64)
        self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv1d(64, 128, kernel_size=1, bias=False)
        self.conv4 = nn.Conv1d(128, 256, kernel_size=1, bias=False)
        self.conv5 = nn.Conv1d(512, n_emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(256)
        self.bn5 = nn.BatchNorm1d(n_emb_dims)

    def forward(self, x):
        """
        :param x: tensor of shape (batch_size, 3, num_points)
        :return: tensor of shape (batch_size, n_emb_dims, num_points)
        """
        x1 = F.relu(self.bn1(self.conv1(x)))
        x2 = F.relu(self.bn2(self.conv2(x1)))
        # Align the 64-d point features with the learned 64 x 64 transform.
        trans_feat = self.stn(x2)
        x2 = x2.transpose(2, 1)
        x2 = torch.bmm(x2, trans_feat)
        x2 = x2.transpose(2, 1)
        x3 = F.relu(self.bn3(self.conv3(x2)))
        x4 = F.relu(self.bn4(self.conv4(x3)))
        x5 = torch.cat((x1, x2, x3, x4), dim=1)
        x = F.relu(self.bn5(self.conv5(x5)))
        return x