├── assets
│   ├── benchmark.png
│   ├── kitti_qual.png
│   └── qual_kitti.png
├── utils
│   ├── __init__.py
│   ├── softmax_entropy.py
│   ├── log_util.py
│   ├── metric_util.py
│   ├── cmap_plot.py
│   ├── util.py
│   ├── np_ioueval.py
│   ├── load_save_util.py
│   └── lovasz_losses.py
├── config
│   ├── __init__.py
│   ├── semantickitti-tta.yaml
│   ├── semantickitti-multiscan.yaml
│   ├── semantickitti-tta-val.yaml
│   ├── semantickitti.yaml
│   ├── config.py
│   └── label_mapping
│       ├── semantic-poss-multiscan.yaml
│       ├── semantic-kitti.yaml
│       ├── semantic-kitti-multiscan.yaml
│       └── semantic-kitti-all.yaml
├── network
│   ├── __init__.py
│   ├── cylinder_spconv_3d_unlock.py
│   ├── cylinder_fea_generator.py
│   ├── conv_base.py
│   └── segmentator_3d_asymm_spconv_unlock.py
├── dataloader
│   ├── __init__.py
│   ├── pc_dataset.py
│   ├── pc_dataset_test.py
│   ├── dataset_semantickitti.py
│   └── dataset_semantickitti_test.py
├── requirements.txt
├── README.md
└── run_tta_test.py

/assets/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blue-531/TALoS/HEAD/assets/benchmark.png
--------------------------------------------------------------------------------

/assets/kitti_qual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blue-531/TALoS/HEAD/assets/kitti_qual.png
--------------------------------------------------------------------------------

/assets/qual_kitti.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blue-531/TALoS/HEAD/assets/qual_kitti.png
--------------------------------------------------------------------------------

/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: __init__.py
4 | 
--------------------------------------------------------------------------------

/config/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: __init__.py
4 | 
--------------------------------------------------------------------------------

/network/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: __init__.py
4 | 
--------------------------------------------------------------------------------

/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: __init__.py
4 | 
5 | # from . import dataset_nuscenes
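# NOTE: no dataset_nuscenes module ships with this repository; only the
# SemanticKITTI loaders (pc_dataset.py, dataset_semantickitti.py and their
# *_test variants) listed in the tree above are provided.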
--------------------------------------------------------------------------------

/utils/softmax_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
5 |     """Per-sample entropy of the softmax distribution from logits (dim 1)."""
6 |     # equivalent to -(p * log p).sum(1) with p = softmax(x, dim=1); commonly
7 |     # used as an uncertainty signal in test-time adaptation
8 |     return -(x.softmax(1) * x.log_softmax(1)).sum(1)
--------------------------------------------------------------------------------

/utils/log_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: log_util.py
4 | 
5 | 
6 | def save_to_log(logdir, logfile, message):
7 |     # append one line to <logdir>/<logfile>; the context manager ensures the
8 |     # file handle is closed even if the write raises
9 |     with open(logdir + '/' + logfile, "a") as f:
10 |         f.write(message + '\n')
--------------------------------------------------------------------------------

/utils/metric_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: metric_util.py
4 | 
5 | import numpy as np
6 | 
7 | 
8 | def fast_hist(pred, label, n):
9 |     # n x n confusion matrix via bincount: rows = label, cols = pred
10 |     k = (label >= 0) & (label < n)
11 |     bin_count = np.bincount(
12 |         n * label[k].astype(int) + pred[k], minlength=n ** 2)
13 |     return bin_count[:n ** 2].reshape(n, n)
14 | 
15 | 
16 | def per_class_iu(hist):
17 |     # IoU per class: diag / (row sum + col sum - diag)
18 |     return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
19 | 
20 | 
21 | def fast_hist_crop(output, target, unique_label):
22 |     # build the full histogram, then crop rows and columns down to the classes
23 |     # of interest (the +1 offset skips the ignore class at index 0)
24 |     hist = fast_hist(output.flatten(), target.flatten(), np.max(unique_label) + 2)
25 |     hist = hist[unique_label + 1, :]
26 |     hist = hist[:, unique_label + 1]
27 |     return hist
--------------------------------------------------------------------------------

/config/semantickitti-tta.yaml:
--------------------------------------------------------------------------------
1 | # Config format schema number
2 | format_version: 4
3 | 
4 | ###################
5 | ## Model options
6 | model_params:
7 |   model_architecture: "cylinder_asym"
8 | 
9 |   output_shape:
10 |     - 256
11 |     - 256
12 |     - 32
13 | 
14 |   fea_dim: 7
15 |   out_fea_dim: 256
16 |   num_class: 20
17 |   num_input_features: 32
18 |   use_norm: True
19 |   init_size: 32
20 | 
21 | 
22 | ###################
23 | ## Dataset options
24 | dataset_params:
25 |   dataset_type: "voxel_dataset"
26 |   pc_dataset_type: "SemKITTI_sk_multiscan"
27 |   ignore_label: 255
28 |   return_test: True
29 |   fixed_volume_space: True
30 |   label_mapping: "./config/label_mapping/semantic-kitti-multiscan.yaml"
31 |   max_volume_space:
32 |     - 51.2
33 |     - 25.6
34 |     - 4.4
35 |   min_volume_space:
36 |     - 0
37 |     - -25.6
38 |     - -2
39 | 
40 | 
41 | ###################
42 | ## Data_loader options
43 | train_data_loader:
44 |   data_path: "./dataset/sequences"
45 |   imageset: "train"
46 |   return_ref: True
47 |   batch_size: 2
48 |   shuffle: True
49 |   num_workers: 0
50 | 
51 | val_data_loader:
52 |   data_path: "./dataset/sequences"
53 |   # imageset: "val"
54 |   imageset: "test"
55 |   return_ref: True
56 |   batch_size: 1 #2
57 |   shuffle: False
58 |   num_workers: 0
59 | 
60 | 
61 | 
62 | 
63 | ###################
64 | ## Train params
65 | train_params:
66 |   model_load_path: "./model_load_dir/"
67 |   model_save_path: "./model_load_dir/"
68 |   checkpoint_every_n_steps: 4599
69 |   max_num_epochs: 40
70 |   eval_every_n_steps: 1917
71 |   learning_rate: 0.0015
72 | 
--------------------------------------------------------------------------------

/config/semantickitti-multiscan.yaml:
--------------------------------------------------------------------------------
1 | # Config format schema number
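# NOTE: identical to semantickitti-tta.yaml above except that
# val_data_loader.imageset is "val" here (adaptation on the validation split)
# instead of "test".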
2 | format_version: 4 3 | 4 | ################### 5 | ## Model options 6 | model_params: 7 | model_architecture: "cylinder_asym" 8 | 9 | output_shape: 10 | - 256 11 | - 256 12 | - 32 13 | 14 | fea_dim: 7 15 | out_fea_dim: 256 16 | num_class: 20 17 | num_input_features: 32 18 | use_norm: True 19 | init_size: 32 20 | 21 | 22 | ################### 23 | ## Dataset options 24 | dataset_params: 25 | dataset_type: "voxel_dataset" 26 | pc_dataset_type: "SemKITTI_sk_multiscan" 27 | ignore_label: 255 28 | return_test: True 29 | fixed_volume_space: True 30 | label_mapping: "./config/label_mapping/semantic-kitti-multiscan.yaml" 31 | max_volume_space: 32 | - 51.2 33 | - 25.6 34 | - 4.4 35 | min_volume_space: 36 | - 0 37 | - -25.6 38 | - -2 39 | 40 | 41 | ################### 42 | ## Data_loader options 43 | train_data_loader: 44 | data_path: "./dataset/sequences" 45 | imageset: "train" 46 | return_ref: True 47 | batch_size: 2 48 | shuffle: True 49 | num_workers: 0 50 | 51 | val_data_loader: 52 | data_path: "./dataset/sequences" 53 | imageset: "val" 54 | # imageset: "test" 55 | return_ref: True 56 | batch_size: 1 #2 57 | shuffle: False 58 | num_workers: 0 59 | 60 | 61 | 62 | 63 | ################### 64 | ## Train params 65 | train_params: 66 | model_load_path: "./model_load_dir/" 67 | model_save_path: "./model_load_dir/" 68 | checkpoint_every_n_steps: 4599 69 | max_num_epochs: 40 70 | eval_every_n_steps: 1917 71 | learning_rate: 0.0015 72 | -------------------------------------------------------------------------------- /config/semantickitti-tta-val.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema number 2 | format_version: 4 3 | 4 | ################### 5 | ## Model options 6 | model_params: 7 | model_architecture: "cylinder_asym" 8 | 9 | output_shape: 10 | - 256 11 | - 256 12 | - 32 13 | 14 | fea_dim: 7 15 | out_fea_dim: 256 16 | num_class: 20 17 | num_input_features: 32 18 | use_norm: True 19 | init_size: 32 20 | 21 | 22 | ################### 23 | ## Dataset options 24 | dataset_params: 25 | dataset_type: "voxel_dataset" 26 | pc_dataset_type: "SemKITTI_sk_multiscan" 27 | ignore_label: 255 28 | return_test: True 29 | fixed_volume_space: True 30 | label_mapping: "./config/label_mapping/semantic-kitti-multiscan.yaml" 31 | max_volume_space: 32 | - 51.2 33 | - 25.6 34 | - 4.4 35 | min_volume_space: 36 | - 0 37 | - -25.6 38 | - -2 39 | 40 | 41 | ################### 42 | ## Data_loader options 43 | train_data_loader: 44 | data_path: "./dataset/sequences" 45 | imageset: "train" 46 | return_ref: True 47 | batch_size: 2 48 | shuffle: True 49 | num_workers: 0 50 | 51 | val_data_loader: 52 | data_path: "./dataset/sequences" 53 | imageset: "val" 54 | # imageset: "test" 55 | return_ref: True 56 | batch_size: 1 #2 57 | shuffle: False 58 | num_workers: 0 59 | 60 | 61 | 62 | 63 | ################### 64 | ## Train params 65 | train_params: 66 | model_load_path: "./model_load_dir/" 67 | model_save_path: "./model_load_dir/" 68 | checkpoint_every_n_steps: 4599 69 | max_num_epochs: 40 70 | eval_every_n_steps: 1917 71 | learning_rate: 0.0015 72 | -------------------------------------------------------------------------------- /config/semantickitti.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema number 2 | format_version: 4 3 | 4 | ################### 5 | ## Model options 6 | model_params: 7 | model_architecture: "cylinder_asym" 8 | 9 | output_shape: 10 | - 480 11 | - 360 12 | - 32 13 | 14 | 
fea_dim: 9 15 | out_fea_dim: 256 16 | num_class: 20 17 | num_input_features: 32 #16 18 | use_norm: True 19 | init_size: 32 #16 20 | 21 | 22 | ################### 23 | ## Dataset options 24 | dataset_params: 25 | dataset_type: "cylinder_dataset" 26 | pc_dataset_type: "SemKITTI_sk" 27 | ignore_label: 0 28 | return_test: False 29 | fixed_volume_space: True 30 | label_mapping: "./config/label_mapping/semantic-kitti.yaml" 31 | max_volume_space: 32 | - 50 33 | - 3.1415926 34 | - 2 35 | min_volume_space: 36 | - 0 37 | - -3.1415926 38 | - -4 39 | 40 | 41 | ################### 42 | ## Data_loader options 43 | train_data_loader: 44 | data_path: "/nvme/yuenan/semantickitti_dataset/sequences/" 45 | imageset: "train" 46 | return_ref: True 47 | batch_size: 2 48 | shuffle: True 49 | num_workers: 12 #4 50 | 51 | val_data_loader: 52 | data_path: "/nvme/yuenan/semantickitti_dataset/sequences/" 53 | imageset: "test" #"val" 54 | return_ref: True 55 | batch_size: 1 56 | shuffle: False 57 | num_workers: 12 #4 58 | 59 | 60 | ################### 61 | ## Train params 62 | train_params: 63 | model_load_path: "./model_load_dir/model_full_ft.pt" 64 | model_save_path: "./model_save_dir/model_tmp.pt" 65 | checkpoint_every_n_steps: 4599 66 | max_num_epochs: 20 #40 67 | eval_every_n_steps: 5000 #4599 68 | learning_rate: 0.002 #1 69 | -------------------------------------------------------------------------------- /utils/cmap_plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import itertools 3 | import io 4 | import numpy as np 5 | 6 | 7 | def label_2_cmap(target,pred): 8 | pred_counts=np.zeros((20,20)) 9 | for i in range(20): 10 | pred_i=np.where(target==i,pred,255) 11 | pred_id,counts=np.unique(pred_i,return_counts=True) 12 | 13 | for idx,pred_category in enumerate(pred_id): 14 | if pred_category==255: 15 | continue 16 | pred_counts[i,pred_category]=counts[idx] 17 | return pred_counts 18 | 19 | def plot_confusion_matrix(cm, class_names =["unlabeled","car","bicycle","motorcycle","truck","other-vehicle","person","bicyclist","motorcyclist","road","parking","sidewalk","other-ground","building","fence","vegetation","trunk","terrain","pole","traffic-sign"]): 20 | """ 21 | Returns a matplotlib figure containing the plotted confusion matrix. 22 | 23 | Args: 24 | cm (array, shape = [n, n]): a confusion matrix of integer classes 25 | class_names (array, shape = [n]): String names of the integer classes 26 | """ 27 | cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=4) 28 | 29 | figure = plt.figure(figsize=(15, 15)) 30 | plt.imshow(cm, cmap=plt.cm.Blues) 31 | plt.title("Confusion matrix") 32 | plt.colorbar() 33 | tick_marks = np.arange(len(class_names)) 34 | plt.xticks(tick_marks, class_names, rotation=45) 35 | plt.yticks(tick_marks, class_names) 36 | 37 | # Compute the labels from the normalized confusion matrix. 38 | labels = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=4) 39 | 40 | # Use white text if squares are dark; otherwise black. 41 | threshold = cm.max() / 2. 
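    # NOTE: `cm` was row-normalized above, so `threshold` is half of the largest
    # normalized cell; in the loop below, cells darker than this threshold get
    # white text and lighter cells get black text for readability.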
42 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 43 | color = "white" if cm[i, j] > threshold else "black" 44 | plt.text(j, i, labels[i, j], horizontalalignment="center", color=color) 45 | 46 | plt.tight_layout() 47 | plt.ylabel('True label') 48 | plt.xlabel('Predicted label') 49 | # plt.savefig(f'{fig_name}.png', dpi=300) 50 | buf = io.BytesIO() 51 | plt.savefig(buf, format='png') 52 | plt.close(figure) 53 | buf.seek(0) 54 | 55 | return buf 56 | 57 | -------------------------------------------------------------------------------- /network/cylinder_spconv_3d_unlock.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: cylinder_spconv_3d.py 4 | 5 | from torch import nn 6 | import torch 7 | 8 | REGISTERED_MODELS_CLASSES = {} 9 | 10 | 11 | def register_model(cls, name=None): 12 | global REGISTERED_MODELS_CLASSES 13 | if name is None: 14 | name = cls.__name__ 15 | assert name not in REGISTERED_MODELS_CLASSES, f"exist class: {REGISTERED_MODELS_CLASSES}" 16 | REGISTERED_MODELS_CLASSES[name] = cls 17 | return cls 18 | 19 | 20 | def get_model_class(name): 21 | global REGISTERED_MODELS_CLASSES 22 | assert name in REGISTERED_MODELS_CLASSES, f"available class: {REGISTERED_MODELS_CLASSES}" 23 | return REGISTERED_MODELS_CLASSES[name] 24 | 25 | 26 | @register_model 27 | class cylinder_asym(nn.Module): 28 | def __init__(self, 29 | cylin_model, 30 | segmentator_spconv, 31 | sparse_shape, 32 | ): 33 | super().__init__() 34 | self.name = "cylinder_asym" 35 | 36 | self.cylinder_3d_generator = cylin_model #cylinder_fea_generator 37 | 38 | self.cylinder_3d_spconv_seg = segmentator_spconv #Asymm_3d_spconv 39 | 40 | self.sparse_shape = sparse_shape 41 | 42 | def forward(self, train_pt_fea_ten, train_vox_ten, batch_size, val_grid=None, voting_num=4, use_tta=False, extraction='all'): 43 | coords, features_3d = self.cylinder_3d_generator(train_pt_fea_ten, train_vox_ten) 44 | # train_pt_fea_ten: [batch_size, N1, 7] 45 | # train_vox_ten: [batch_size, N1, 3] 46 | 47 | if use_tta: 48 | batch_size *= voting_num 49 | 50 | spatial_features = self.cylinder_3d_spconv_seg(features_3d, coords, batch_size, extraction) # [batch_size, 20, 256, 256, 32] 51 | 52 | if use_tta: 53 | fused_predict = spatial_features[0, :] 54 | for idx in range(1, voting_num, 1): 55 | aug_predict = spatial_features[idx, :] 56 | aug_predict = torch.flip(aug_predict, dims=[2]) 57 | fused_predict += aug_predict 58 | return torch.unsqueeze(fused_predict, 0) 59 | else: 60 | return spatial_features 61 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def Bresenham3D(start_point, end_points,line_end): 4 | ListOfPoints = [] 5 | 6 | for end_point in end_points: 7 | x1, y1, z1=start_point[0] 8 | x2, y2, z2=end_point 9 | ListOfPoints.append((x1, y1, z1)) 10 | dx = abs(x2 - x1) 11 | dy = abs(y2 - y1) 12 | dz = abs(z2 - z1) 13 | if (x2 > x1): 14 | xs = 1 15 | else: 16 | xs = -1 17 | if (y2 > y1): 18 | ys = 1 19 | else: 20 | ys = -1 21 | if (z2 > z1): 22 | zs = 1 23 | else: 24 | zs = -1 25 | # Driving axis is X-axis" 26 | if (dx >= dy and dx >= dz): 27 | p1 = 2 * dy - dx 28 | p2 = 2 * dz - dx 29 | while (x1 != x2): 30 | x1 += xs 31 | if (p1 >= 0): 32 | y1 += ys 33 | p1 -= 2 * dx 34 | if (p2 >= 0): 35 | z1 += zs 36 | p2 -= 2 * dx 37 | p1 += 2 * dy 38 | p2 += 2 * dz 39 | 40 | 
ListOfPoints.append((x1, y1, z1)) 41 | 42 | # Driving axis is Y-axis" 43 | elif (dy >= dx and dy >= dz): 44 | p1 = 2 * dx - dy 45 | p2 = 2 * dz - dy 46 | while (y1 != y2): 47 | y1 += ys 48 | if (p1 >= 0): 49 | x1 += xs 50 | p1 -= 2 * dy 51 | if (p2 >= 0): 52 | z1 += zs 53 | p2 -= 2 * dy 54 | p1 += 2 * dx 55 | p2 += 2 * dz 56 | 57 | ListOfPoints.append((x1, y1, z1)) 58 | 59 | # Driving axis is Z-axis" 60 | else: 61 | p1 = 2 * dy - dz 62 | p2 = 2 * dx - dz 63 | while (z1 != z2): 64 | z1 += zs 65 | if (p1 >= 0): 66 | y1 += ys 67 | p1 -= 2 * dz 68 | if (p2 >= 0): 69 | x1 += xs 70 | p2 -= 2 * dz 71 | p1 += 2 * dy 72 | p2 += 2 * dx 73 | 74 | ListOfPoints.append((x1, y1, z1)) 75 | return ListOfPoints 76 | 77 | -------------------------------------------------------------------------------- /network/cylinder_fea_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch_scatter 8 | 9 | 10 | class cylinder_fea(nn.Module): 11 | 12 | def __init__(self, grid_size, fea_dim=3, 13 | out_pt_fea_dim=64, max_pt_per_encode=64, fea_compre=None): 14 | super(cylinder_fea, self).__init__() 15 | 16 | self.PPmodel = nn.Sequential( 17 | nn.BatchNorm1d(fea_dim), 18 | 19 | nn.Linear(fea_dim, 64), 20 | nn.BatchNorm1d(64), 21 | nn.ReLU(), 22 | 23 | nn.Linear(64, 128), 24 | nn.BatchNorm1d(128), 25 | nn.ReLU(), 26 | 27 | nn.Linear(128, 256), 28 | nn.BatchNorm1d(256), 29 | nn.ReLU(), 30 | 31 | nn.Linear(256, out_pt_fea_dim) 32 | ) 33 | 34 | self.max_pt = max_pt_per_encode 35 | self.fea_compre = fea_compre 36 | self.grid_size = grid_size 37 | kernel_size = 3 38 | self.local_pool_op = torch.nn.MaxPool2d(kernel_size, stride=1, 39 | padding=(kernel_size - 1) // 2, 40 | dilation=1) 41 | self.pool_dim = out_pt_fea_dim 42 | 43 | # point feature compression 44 | if self.fea_compre is not None: 45 | self.fea_compression = nn.Sequential( 46 | nn.Linear(self.pool_dim, self.fea_compre), 47 | nn.ReLU()) 48 | self.pt_fea_dim = self.fea_compre 49 | else: 50 | self.pt_fea_dim = self.pool_dim 51 | 52 | def forward(self, pt_fea, xy_ind): 53 | cur_dev = pt_fea[0].get_device() 54 | 55 | ### concate everything 56 | cat_pt_ind = [] 57 | for i_batch in range(len(xy_ind)): 58 | cat_pt_ind.append(F.pad(xy_ind[i_batch], (1, 0), 'constant', value=i_batch)) 59 | 60 | cat_pt_fea = torch.cat(pt_fea, dim=0) 61 | cat_pt_ind = torch.cat(cat_pt_ind, dim=0) 62 | pt_num = cat_pt_ind.shape[0] 63 | 64 | ### shuffle the data 65 | shuffled_ind = torch.randperm(pt_num, device=cur_dev) 66 | cat_pt_fea = cat_pt_fea[shuffled_ind, :] 67 | cat_pt_ind = cat_pt_ind[shuffled_ind, :] 68 | 69 | ### unique xy grid index 70 | unq, unq_inv, unq_cnt = torch.unique(cat_pt_ind, return_inverse=True, return_counts=True, dim=0) 71 | unq = unq.type(torch.int64) 72 | 73 | ### process feature 74 | processed_cat_pt_fea = self.PPmodel(cat_pt_fea) 75 | pooled_data = torch_scatter.scatter_max(processed_cat_pt_fea, unq_inv, dim=0)[0] 76 | 77 | if self.fea_compre: 78 | processed_pooled_data = self.fea_compression(pooled_data) 79 | else: 80 | processed_pooled_data = pooled_data 81 | 82 | return unq, processed_pooled_data 83 | -------------------------------------------------------------------------------- /utils/np_ioueval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 
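# Confusion-matrix based mIoU / accuracy evaluator. Typical usage (sketch):
#   evaluator = iouEval(n_classes=20, ignore=[0])
#   evaluator.addBatch(pred, gt)        # integer arrays of matching shape
#   miou, per_class_iou = evaluator.getIoU()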
3 | 4 | import sys 5 | import numpy as np 6 | 7 | 8 | class iouEval: 9 | def __init__(self, n_classes, ignore=None): 10 | # classes 11 | self.n_classes = n_classes 12 | 13 | # What to include and ignore from the means 14 | self.ignore = np.array(ignore, dtype=np.int64) 15 | self.include = np.array( 16 | [n for n in range(self.n_classes) if n not in self.ignore], dtype=np.int64) 17 | # print("[IOU EVAL] IGNORE: ", self.ignore) 18 | # print("[IOU EVAL] INCLUDE: ", self.include) 19 | 20 | # reset the class counters 21 | self.reset() 22 | 23 | def num_classes(self): 24 | return self.n_classes 25 | 26 | def reset(self): 27 | self.conf_matrix = np.zeros((self.n_classes, 28 | self.n_classes), 29 | dtype=np.int64) 30 | 31 | def addBatch(self, x, y): # x=preds, y=targets 32 | # sizes should be matching 33 | x_row = x.reshape(-1) # de-batchify 34 | y_row = y.reshape(-1) # de-batchify 35 | 36 | # check 37 | assert(x_row.shape == y_row.shape) 38 | 39 | # create indexes 40 | idxs = tuple(np.stack((x_row, y_row), axis=0)) 41 | 42 | # make confusion matrix (cols = gt, rows = pred) 43 | np.add.at(self.conf_matrix, idxs, 1) 44 | 45 | def getStats(self): 46 | # remove fp from confusion on the ignore classes cols 47 | conf = self.conf_matrix.copy() 48 | conf[:, self.ignore] = 0 49 | 50 | # get the clean stats 51 | tp = np.diag(conf) 52 | fp = conf.sum(axis=1) - tp 53 | fn = conf.sum(axis=0) - tp 54 | return tp, fp, fn 55 | 56 | def getIoU(self): 57 | tp, fp, fn = self.getStats() 58 | intersection = tp 59 | union = tp + fp + fn + 1e-15 60 | iou = intersection / union 61 | iou_mean = (intersection[self.include] / union[self.include]).mean() 62 | return iou_mean, iou # returns "iou mean", "iou per class" ALL CLASSES 63 | 64 | def getacc(self): 65 | tp, fp, fn = self.getStats() 66 | total_tp = tp.sum() 67 | total = tp[self.include].sum() + fp[self.include].sum() + 1e-15 68 | acc_mean = total_tp / total 69 | return acc_mean # returns "acc mean" 70 | 71 | def get_confusion(self): 72 | return self.conf_matrix.copy() 73 | 74 | 75 | 76 | if __name__ == "__main__": 77 | # mock problem 78 | nclasses = 2 79 | ignore = [] 80 | 81 | # test with 2 squares and a known IOU 82 | lbl = np.zeros((7, 7), dtype=np.int64) 83 | argmax = np.zeros((7, 7), dtype=np.int64) 84 | 85 | # put squares 86 | lbl[2:4, 2:4] = 1 87 | argmax[3:5, 3:5] = 1 88 | 89 | # make evaluator 90 | eval = iouEval(nclasses, ignore) 91 | 92 | # run 93 | eval.addBatch(argmax, lbl) 94 | m_iou, iou = eval.getIoU() 95 | print("IoU: ", m_iou) 96 | print("IoU class: ", iou) 97 | m_acc = eval.getacc() 98 | print("Acc: ", m_acc) 99 | -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | 4 | from pathlib import Path 5 | 6 | from strictyaml import Bool, Float, Int, Map, Seq, Str, as_document, load 7 | 8 | model_params = Map( 9 | { 10 | "model_architecture": Str(), 11 | "output_shape": Seq(Int()), 12 | "fea_dim": Int(), 13 | "out_fea_dim": Int(), 14 | "num_class": Int(), 15 | "num_input_features": Int(), 16 | "use_norm": Bool(), 17 | "init_size": Int(), 18 | } 19 | ) 20 | 21 | dataset_params = Map( 22 | { 23 | "dataset_type": Str(), 24 | "pc_dataset_type": Str(), 25 | "ignore_label": Int(), 26 | "return_test": Bool(), 27 | "fixed_volume_space": Bool(), 28 | "label_mapping": Str(), 29 | "max_volume_space": Seq(Float()), 30 | "min_volume_space": Seq(Float()), 31 | } 32 | ) 33 | 34 | 35 | 
train_data_loader = Map( 36 | { 37 | "data_path": Str(), 38 | "imageset": Str(), 39 | "return_ref": Bool(), 40 | "batch_size": Int(), 41 | "shuffle": Bool(), 42 | "num_workers": Int(), 43 | } 44 | ) 45 | 46 | val_data_loader = Map( 47 | { 48 | "data_path": Str(), 49 | "imageset": Str(), 50 | "return_ref": Bool(), 51 | "batch_size": Int(), 52 | "shuffle": Bool(), 53 | "num_workers": Int(), 54 | } 55 | ) 56 | 57 | 58 | train_params = Map( 59 | { 60 | "model_load_path": Str(), 61 | "model_save_path": Str(), 62 | "checkpoint_every_n_steps": Int(), 63 | "max_num_epochs": Int(), 64 | "eval_every_n_steps": Int(), 65 | "learning_rate": Float() 66 | } 67 | ) 68 | 69 | schema_v4 = Map( 70 | { 71 | "format_version": Int(), 72 | "model_params": model_params, 73 | "dataset_params": dataset_params, 74 | "train_data_loader": train_data_loader, 75 | "val_data_loader": val_data_loader, 76 | "train_params": train_params, 77 | } 78 | ) 79 | 80 | 81 | SCHEMA_FORMAT_VERSION_TO_SCHEMA = {4: schema_v4} 82 | 83 | 84 | def load_config_data(path: str) -> dict: 85 | yaml_string = Path(path).read_text() 86 | cfg_without_schema = load(yaml_string, schema=None) 87 | schema_version = int(cfg_without_schema["format_version"]) 88 | if schema_version not in SCHEMA_FORMAT_VERSION_TO_SCHEMA: 89 | raise Exception(f"Unsupported schema format version: {schema_version}.") 90 | 91 | strict_cfg = load(yaml_string, schema=SCHEMA_FORMAT_VERSION_TO_SCHEMA[schema_version]) 92 | cfg: dict = strict_cfg.data 93 | return cfg 94 | 95 | 96 | def config_data_to_config(data): # type: ignore 97 | return as_document(data, schema_v4) 98 | 99 | 100 | def save_config_data(data: dict, path: str) -> None: 101 | cfg_document = config_data_to_config(data) 102 | with open(Path(path), "w") as f: 103 | f.write(cfg_document.as_yaml()) 104 | -------------------------------------------------------------------------------- /config/label_mapping/semantic-poss-multiscan.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
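# Label mapping for SemanticPOSS: the raw labels below are reduced to 11
# training classes (people, rider, car, trunk, plants, traffic sign, pole,
# building, fence, bike, ground) plus "unlabeled"; see learning_map below.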
2 | labels:
3 |   0: "unlabeled"
4 |   4: "1 person"
5 |   5: "2+ person"
6 |   6: "rider"
7 |   7: "car"
8 |   8: "trunk"
9 |   9: "plants"
10 |   10: "traffic sign 1" # standing sign
11 |   11: "traffic sign 2" # hanging sign
12 |   12: "traffic sign 3" # high/big hanging sign
13 |   13: "pole"
14 |   14: "trashcan"
15 |   15: "building"
16 |   17: "fence"
17 |   16: "cone/stone"
18 |   21: "bike"
19 |   22: "ground"
20 | 
21 | color_map: # bgr
22 | 
23 |   0 : [0, 0, 0] # 0: "unlabeled"
24 |   4 : [30, 30, 255] # 4: "1 person"
25 |   5 : [30, 30, 255] # 5: "2+ person"
26 |   6 : [200, 40, 255] # 6: "rider"
27 |   7 : [245, 150, 100] # 7: "car"
28 |   8 : [0, 60, 135] # 8: "trunk"
29 |   9 : [0, 175, 0] # 9: "plants"
30 |   10: [0, 0, 255] # 10: "traffic sign 1"
31 |   11: [0, 0, 255] # 11: "traffic sign 2"
32 |   12: [0, 0, 255] # 12: "traffic sign 3"
33 |   13: [150, 240, 255] # 13: "pole"
34 |   14: [0, 255, 125] # 14: "trashcan"
35 |   15: [0, 200, 255] # 15: "building"
36 |   16: [255, 255, 50] # 16: "cone/stone"
37 |   17: [50, 120, 255] # 17: "fence"
38 |   21: [245, 230, 100] # 21: "bike"
39 |   22: [128, 128, 128] # 22: "ground"
40 | 
41 | content: # as a ratio with the total number of points (NOTE: identical to the content block in semantic-kitti.yaml; the keys are KITTI label IDs, not the POSS labels above)
42 |   0: 0.018889854628292943
43 |   1: 0.0002937197336781505
44 |   10: 0.040818519255974316
45 |   11: 0.00016609538710764618
46 |   13: 2.7879693665067774e-05
47 |   15: 0.00039838616015114444
48 |   16: 0.0
49 |   18: 0.0020633612104619787
50 |   20: 0.0016218197275284021
51 |   30: 0.00017698551338515307
52 |   31: 1.1065903904919655e-08
53 |   32: 5.532951952459828e-09
54 |   40: 0.1987493871255525
55 |   44: 0.014717169549888214
56 |   48: 0.14392298360372
57 |   49: 0.0039048553037472045
58 |   50: 0.1326861944777486
59 |   51: 0.0723592229456223
60 |   52: 0.002395131480328884
61 |   60: 4.7084144280367186e-05
62 |   70: 0.26681502148037506
63 |   71: 0.006035012012626033
64 |   72: 0.07814222006271769
65 |   80: 0.002855498193863172
66 |   81: 0.0006155958086189918
67 |   99: 0.009923127583046915
68 |   252: 0.001789309418528068
69 |   253: 0.00012709999297008662
70 |   254: 0.00016059776092534436
71 |   255: 3.745553104802113e-05
72 |   256: 0.0
73 |   257: 0.00011351574470342043
74 |   258: 0.00010157861367183268
75 |   259: 4.3840131989471124e-05
76 | # classes that are indistinguishable from single scan or inconsistent in
77 | # ground truth are mapped to their closest equivalent
78 | 
79 | # 11 CLASSES
80 | learning_map:
81 |   0: 0 # "unlabeled"
82 |   4: 1 # "1 person" --> "people" ----------------mapped
83 |   5: 1 # "2+ person" --> "people" ---------------mapped
84 |   6: 2 # "rider"
85 |   7: 3 # "car"
86 |   8: 4 # "trunk"
87 |   9: 5 # "plants"
88 |   10: 6 # "traffic sign 1" (standing sign) --> "traffic sign" -------mapped
89 |   11: 6 # "traffic sign 2" (hanging sign) --> "traffic sign" --------mapped
90 |   12: 6 # "traffic sign 3" (high/big hanging sign) --> "traffic sign" ---mapped
91 |   13: 7 # "pole"
92 |   14: 0 # "trashcan" --> "unlabeled" ----------------mapped
93 |   15: 8 # "building"
94 |   16: 0 # "cone/stone" --> "unlabeled" ----------------mapped
95 |   17: 9 # "fence"
96 |   21: 10 # "bike"
97 |   22: 11 # "ground"
98 | 
99 | learning_map_inv: # inverse of previous map
100 |   0: 0 # "unlabeled"
101 |   1: 4 # "people"
102 |   2: 6 # "rider"
103 |   3: 7 # "car"
104 |   4: 8 # "trunk"
105 |   5: 9 # "plants"
106 |   6: 10 # "traffic sign"
107 |   7: 13 # "pole"
108 |   8: 15 # "building"
109 |   9: 17 # "fence"
110 |   10: 21 # "bike"
111 |   11: 22 # "ground"
112 | 
113 | learning_ignore: # Ignore classes
114 |   0: True # "unlabeled", and others ignored
115 |   1: False # "people"
116 |   2: False # "rider"
117 |   3: False # "car"
118 |   4: False # "trunk"
119 |   5: False # "plants"
120 |   6: False # "traffic sign"
121 |   7: False # "pole"
122 |   8: False # "building"
123 |   9: False # "fence"
124 |   10: False # "bike"
125 |   11: False # "ground"
126 | 
127 | 
128 | split: # sequence numbers
129 |   train:
130 |     - 0
131 |     - 1
132 |     - 3
133 |     - 4
134 |     - 5
135 |   valid:
136 |     - 2
137 |   test:
138 |     - 2
--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | addict==2.4.0
3 | aliyun-python-sdk-core==2.14.0
4 | aliyun-python-sdk-kms==2.16.2
5 | ansi2html==1.9.1
6 | antlr4-python3-runtime==4.9.3
7 | asttokens==2.4.1
8 | attrs==23.2.0
9 | backcall==0.2.0
10 | basicsr==1.4.2
11 | blinker==1.7.0
12 | cache==1.0.3
13 | cachetools==5.3.2
14 | ccimport==0.4.2
15 | certifi==2024.2.2
16 | cffi==1.16.0
17 | charset-normalizer==3.3.2
18 | click==8.1.7
19 | colorama==0.4.6
20 | comm==0.2.1
21 | ConfigArgParse==1.7
22 | contourpy==1.1.1
23 | crcmod==1.7
24 | cryptography==42.0.2
25 | cumm==0.4.11
26 | cumm-cu113==0.4.11
27 | cycler==0.12.1
28 | Cython==3.0.9
29 | dash==2.15.0
30 | dash-core-components==2.0.0
31 | dash-html-components==2.0.0
32 | dash-table==5.0.0
33 | decorator==5.1.1
34 | easydict==1.11
35 | einops==0.7.0
36 | executing==2.0.1
37 | fastjsonschema==2.19.1
38 | filelock==3.13.1
39 | fire==0.6.0
40 | Flask==3.0.2
41 | fonttools==4.47.2
42 | freetype-py==2.4.0
43 | fsspec==2024.2.0
44 | future==0.18.3
45 | gin==0.1.6
46 | gin-config==0.5.0
47 | glfw==1.8.3
48 | google-auth==2.27.0
49 | google-auth-oauthlib==1.0.0
50 | grpcio==1.60.1
51 | h5py==3.10.0
52 | hsluv==5.0.4
53 | huggingface-hub==0.20.3
54 | idna==3.6
55 | imageio==2.33.1
56 | imageio-ffmpeg==0.4.9
57 | importlib-metadata==7.0.1
58 | importlib-resources==6.1.1
59 | ipython==8.12.3
60 | ipywidgets==8.1.1
61 | itsdangerous==2.1.2
62 | jedi==0.19.1
63 | Jinja2==3.1.3
64 | jmespath==0.10.0
65 | joblib==1.3.2
66 | jsonschema==4.21.1
67 | jsonschema-specifications==2023.12.1
68 | jupyter_core==5.7.1
69 | jupyterlab-widgets==3.0.9
70 | kiwisolver==1.4.5
71 | kornia==0.7.1
72 | lark==1.1.9
73 | lazy_loader==0.3
74 | lightning-utilities==0.10.1
75 | llvmlite==0.41.1
76 | lmdb==1.4.1
77 | loralib==0.1.2
78 | Markdown==3.5.2
79 | markdown-it-py==3.0.0
80 | MarkupSafe==2.1.5
81 | matplotlib==3.7.4
82 | matplotlib-inline==0.1.6
83 | mdurl==0.1.2
84 | MinkowskiEngine==0.5.4
85 | mkl-fft==1.3.1
86 | mkl-random==1.2.2
87 | mkl-service==2.4.0
88 | mlxtend==0.23.1
89 | mmcv-full==1.7.2
90 | mmengine==0.10.3
91 | model-index==0.1.11
92 | nbformat==5.9.2
93 | nest-asyncio==1.6.0
94 | networkx==3.1
95 | ninja==1.11.1.1
96 | numba==0.58.1
97 | numpy==1.23.0
98 | numpy-indexed==0.3.7
99 | nvidia-ml-py3==7.352.0
100 | oauthlib==3.2.2
101 | olefile @ file:///home/conda/feedstock_root/build_artifacts/olefile_1701735466804/work
102 | omegaconf==2.3.0
103 | open3d==0.18.0
104 | opencv-python==4.9.0.80
105 | opendatalab==0.0.10
106 | openmim==0.3.9
107 | openxlab==0.0.34
108 | ordered-set==4.1.0
109 | oss2==2.17.0
110 | packaging==23.2
111 | pandas==2.0.3
112 | parso==0.8.3
113 | pccm==0.4.11
114 | pexpect==4.9.0
115 | pickleshare==0.7.5
116 | Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1630696607296/work
117 | pkgutil_resolve_name==1.3.10
118 | platformdirs==4.2.0
119 | plotly==5.18.0
120 | plyfile==1.0.3
121 | portalocker==2.8.2
122 | prompt-toolkit==3.0.43
123 | protobuf==4.25.2
124 | ptyprocess==0.7.0
125 | pure-eval==0.2.2
126 | purge==1.0
127 | pyasn1==0.5.1
128 | pyasn1-modules==0.3.0
129 | pybind11==2.11.1
130 | pycocotools==2.0.7
131 | pycparser==2.21
132 | pycryptodome==3.20.0
133 | Pygments==2.17.2
134 | pyparsing==3.1.1
135 | pyquaternion==0.9.9
136 | python-dateutil==2.8.2
137 | pytorch-msssim==1.0.0
138 | pytz==2023.4
139 | PyWavelets==1.4.1
140 | PyYAML==6.0.1
141 | referencing==0.33.0
142 | requests==2.28.2
143 | requests-oauthlib==1.3.1
144 | retrying==1.3.4
145 | rich==13.4.2
146 | rpds-py==0.17.1
147 | rsa==4.9
148 | safetensors==0.4.2
149 | scikit-image==0.21.0
150 | scikit-learn==1.3.2
151 | scipy==1.10.1
152 | seaborn==0.13.2
153 | shapely==2.0.2
154 | SharedArray==3.2.3
155 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
156 | spconv @ file:///home/user/spconv1.0/dist/spconv-1.0-cp38-cp38-linux_x86_64.whl#sha256=251459523b0dd39a7a5f29f5b3391a15a023d842c835c169b946bbe37c8e60af
157 | stack-data==0.6.3
158 | strictyaml==1.7.3
159 | tabulate==0.9.0
160 | tenacity==8.2.3
161 | tensorboard==2.14.0
162 | tensorboard-data-server==0.7.2
163 | tensorboardX==2.6.2.2
164 | termcolor==2.4.0
165 | terminaltables==3.1.10
166 | threadpoolctl==3.2.0
167 | tifffile==2023.7.10
168 | timm==0.9.12
169 | tomli==2.0.1
170 | torch==1.10.1
171 | torch-efficient-distloss==0.1.3
172 | torch-scatter==2.1.2
173 | torchaudio==0.10.1
174 | torchmetrics==1.3.0.post0
175 | torchvision==0.11.2
176 | tqdm==4.65.2
177 | traitlets==5.14.1
178 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1702176139754/work
179 | tzdata==2023.4
180 | urllib3==1.26.18
181 | vispy==0.14.2
182 | wcwidth==0.2.13
183 | Werkzeug==3.0.1
184 | widgetsnbextension==4.0.9
185 | yapf==0.40.2
186 | zipp==3.17.0
187 | 
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 | # [NeurIPS 2024] TALoS: Enhancing Semantic Scene Completion via Test-time Adaptation on the Line of Sight
2 | 
3 | This repository contains the official PyTorch implementation of "TALoS: Enhancing Semantic Scene Completion via Test-time Adaptation on the Line of Sight" (NeurIPS 2024)
4 | by [Hyun-Kurl Jang*](https://blue-531.github.io/),
5 | [Jihun Kim*](https://jihun1998.github.io/),
6 | and [Hyeokjun Kweon*](https://sangrockeg.github.io/).
7 | 
8 | (* denotes equal contribution.)
9 | 
10 | [[Paper]](https://arxiv.org/abs/2410.15674)
11 | 
12 | ## News
13 | 
20 | 
21 | ## Introduction
22 | 
23 | 
24 | Our main idea is simple yet effective:
25 | **an observation made at one moment could serve as supervision for the SSC prediction at another moment.**
26 | While traveling through an environment, an autonomous vehicle can continuously observe the overall scene structure, including objects that were previously occluded (or will be occluded later); these observations provide concrete guidance for adapting scene completion. Given the characteristics of a LiDAR sensor, observing a point at a specific spatial location at a specific moment confirms not only that the location is occupied, but also that no obstacles lie along the line of sight from the sensor to that location.
27 | The proposed method, named
28 | **Test-time Adaptation via Line of Sight (TALoS)**,
29 | is designed to explicitly leverage these characteristics, obtaining self-supervision for geometric completion.
30 | Additionally, we extend the TALoS framework to semantic recognition, another key goal of SSC, by collecting only the reliable regions among the semantic segmentation results predicted at each moment.
31 | Further, to leverage valuable future information that is not accessible at the time of the current update, we devise a novel dual optimization scheme in which the model is gradually updated across the temporal dimension.
32 | ## Installation
33 | 
34 | - PyTorch >= 1.10
35 | - pyyaml
36 | - Cython
37 | - tqdm
38 | - numba
39 | - numpy-indexed
40 | - [torch-scatter](https://github.com/rusty1s/pytorch_scatter)
41 | - [spconv](https://github.com/tyjiang1997/spconv1.0) (tested with spconv==1.0 and cuda==11.3)
42 | 
43 | 
44 | 
45 | ## Data Preparation
46 | 
47 | ### SemanticKITTI
48 | ```
49 | ./
50 | ├── 
51 | ├── ...
52 | ├── model_load_dir
53 | │   ├── pretrained.pth
54 | └── dataset/
55 |     ├── sequences
56 |     │   ├── 00/
57 |     │   │   ├── velodyne/
58 |     │   │   │   ├── 000000.bin
59 |     │   │   │   ├── 000001.bin
60 |     │   │   │   └── ...
61 |     │   │   ├── labels/
62 |     │   │   │   ├── 000000.label
63 |     │   │   │   ├── 000001.label
64 |     │   │   │   └── ...
65 |     │   │   └── voxels/
66 |     │   │       ├── 000000.bin
67 |     │   │       ├── 000000.label
68 |     │   │       ├── 000000.invalid
69 |     │   │       ├── 000000.occluded
70 |     │   │       ├── 000001.bin
71 |     │   │       ├── 000001.label
72 |     │   │       ├── 000001.invalid
73 |     │   │       ├── 000001.occluded
74 |     │   │       └── ...
75 |     │   ├── 08/ # for validation
76 |     │   ├── 11/ # 11-21 for testing
77 |     │   └── 21/
78 |     │       └── ...
79 | ```
80 | 
81 | ## Test-Time Adaptation
82 | 1. Download the pre-trained models and put them in ```./model_load_dir```. [[link]](https://drive.google.com/file/d/12jYauPbVodnSA-faBjFucUNgxeGU0pmP/view?usp=drive_link)
83 | 2. (Optional) Download pre-trained model results and put them in ```./experiments/baseline``` for comparison. [[link]](https://drive.google.com/file/d/1gt65t7hkdnnax2v7BALgUsunTaGHRVkh/view?usp=drive_link)
84 | 3. Generate predictions on the dataset.
85 | 
86 | ### Validation set
87 | ```
88 | python run_tta_val.py --do_adapt --do_cont --use_los --use_pgt
89 | ```
90 | ### Test set
91 | ```
92 | python run_tta_test.py --do_adapt --do_cont --use_los --use_pgt --sq_num={sequence number}
93 | ```
94 | ## Evaluation
95 | To evaluate the test sequences of SemanticKITTI, submit the generated predictions to the [benchmark server](https://codalab.lisn.upsaclay.fr/competitions/7170).
96 | After generating predictions, prepare your submission in the designated format, as described on the competition page.
97 | Use the validation script from the [semantic-kitti-api](https://github.com/PRBonn/semantic-kitti-api) to ensure that the folder structure and the number of label files in the zip file are correct.
98 | 
99 | 
100 | 
101 | 
102 | ## Acknowledgements
103 | We thank the authors of the open-source project [SCPNet](https://github.com/SCPNet/Codes-for-SCPNet).
104 | 
--------------------------------------------------------------------------------

/config/label_mapping/semantic-kitti.yaml:
--------------------------------------------------------------------------------
1 | # This file is covered by the LICENSE file in the root of this project.
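# Single-scan SemanticKITTI mapping: 34 raw label IDs are reduced to 19 training
# classes plus "unlabeled"; every moving-* label is merged into its static
# counterpart (see learning_map below).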
2 | labels: 3 | 0 : "unlabeled" 4 | 1 : "outlier" 5 | 10: "car" 6 | 11: "bicycle" 7 | 13: "bus" 8 | 15: "motorcycle" 9 | 16: "on-rails" 10 | 18: "truck" 11 | 20: "other-vehicle" 12 | 30: "person" 13 | 31: "bicyclist" 14 | 32: "motorcyclist" 15 | 40: "road" 16 | 44: "parking" 17 | 48: "sidewalk" 18 | 49: "other-ground" 19 | 50: "building" 20 | 51: "fence" 21 | 52: "other-structure" 22 | 60: "lane-marking" 23 | 70: "vegetation" 24 | 71: "trunk" 25 | 72: "terrain" 26 | 80: "pole" 27 | 81: "traffic-sign" 28 | 99: "other-object" 29 | 252: "moving-car" 30 | 253: "moving-bicyclist" 31 | 254: "moving-person" 32 | 255: "moving-motorcyclist" 33 | 256: "moving-on-rails" 34 | 257: "moving-bus" 35 | 258: "moving-truck" 36 | 259: "moving-other-vehicle" 37 | color_map: # bgr 38 | 0 : [0, 0, 0] 39 | 1 : [0, 0, 255] 40 | 10: [245, 150, 100] 41 | 11: [245, 230, 100] 42 | 13: [250, 80, 100] 43 | 15: [150, 60, 30] 44 | 16: [255, 0, 0] 45 | 18: [180, 30, 80] 46 | 20: [255, 0, 0] 47 | 30: [30, 30, 255] 48 | 31: [200, 40, 255] 49 | 32: [90, 30, 150] 50 | 40: [255, 0, 255] 51 | 44: [255, 150, 255] 52 | 48: [75, 0, 75] 53 | 49: [75, 0, 175] 54 | 50: [0, 200, 255] 55 | 51: [50, 120, 255] 56 | 52: [0, 150, 255] 57 | 60: [170, 255, 150] 58 | 70: [0, 175, 0] 59 | 71: [0, 60, 135] 60 | 72: [80, 240, 150] 61 | 80: [150, 240, 255] 62 | 81: [0, 0, 255] 63 | 99: [255, 255, 50] 64 | 252: [245, 150, 100] 65 | 256: [255, 0, 0] 66 | 253: [200, 40, 255] 67 | 254: [30, 30, 255] 68 | 255: [90, 30, 150] 69 | 257: [250, 80, 100] 70 | 258: [180, 30, 80] 71 | 259: [255, 0, 0] 72 | content: # as a ratio with the total number of points 73 | 0: 0.018889854628292943 74 | 1: 0.0002937197336781505 75 | 10: 0.040818519255974316 76 | 11: 0.00016609538710764618 77 | 13: 2.7879693665067774e-05 78 | 15: 0.00039838616015114444 79 | 16: 0.0 80 | 18: 0.0020633612104619787 81 | 20: 0.0016218197275284021 82 | 30: 0.00017698551338515307 83 | 31: 1.1065903904919655e-08 84 | 32: 5.532951952459828e-09 85 | 40: 0.1987493871255525 86 | 44: 0.014717169549888214 87 | 48: 0.14392298360372 88 | 49: 0.0039048553037472045 89 | 50: 0.1326861944777486 90 | 51: 0.0723592229456223 91 | 52: 0.002395131480328884 92 | 60: 4.7084144280367186e-05 93 | 70: 0.26681502148037506 94 | 71: 0.006035012012626033 95 | 72: 0.07814222006271769 96 | 80: 0.002855498193863172 97 | 81: 0.0006155958086189918 98 | 99: 0.009923127583046915 99 | 252: 0.001789309418528068 100 | 253: 0.00012709999297008662 101 | 254: 0.00016059776092534436 102 | 255: 3.745553104802113e-05 103 | 256: 0.0 104 | 257: 0.00011351574470342043 105 | 258: 0.00010157861367183268 106 | 259: 4.3840131989471124e-05 107 | # classes that are indistinguishable from single scan or inconsistent in 108 | # ground truth are mapped to their closest equivalent 109 | learning_map: 110 | 0 : 0 # "unlabeled" 111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 112 | 10: 1 # "car" 113 | 11: 2 # "bicycle" 114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 115 | 15: 3 # "motorcycle" 116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 117 | 18: 4 # "truck" 118 | 20: 5 # "other-vehicle" 119 | 30: 6 # "person" 120 | 31: 7 # "bicyclist" 121 | 32: 8 # "motorcyclist" 122 | 40: 9 # "road" 123 | 44: 10 # "parking" 124 | 48: 11 # "sidewalk" 125 | 49: 12 # "other-ground" 126 | 50: 13 # "building" 127 | 51: 14 # "fence" 128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 129 | 60: 9 # "lane-marking" to "road" 
---------------------------------mapped 130 | 70: 15 # "vegetation" 131 | 71: 16 # "trunk" 132 | 72: 17 # "terrain" 133 | 80: 18 # "pole" 134 | 81: 19 # "traffic-sign" 135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 136 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 137 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 138 | 254: 6 # "moving-person" to "person" ------------------------------mapped 139 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 140 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 141 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 142 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 143 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 144 | learning_map_inv: # inverse of previous map 145 | 0: 0 # "unlabeled", and others ignored 146 | 1: 10 # "car" 147 | 2: 11 # "bicycle" 148 | 3: 15 # "motorcycle" 149 | 4: 18 # "truck" 150 | 5: 20 # "other-vehicle" 151 | 6: 30 # "person" 152 | 7: 31 # "bicyclist" 153 | 8: 32 # "motorcyclist" 154 | 9: 40 # "road" 155 | 10: 44 # "parking" 156 | 11: 48 # "sidewalk" 157 | 12: 49 # "other-ground" 158 | 13: 50 # "building" 159 | 14: 51 # "fence" 160 | 15: 70 # "vegetation" 161 | 16: 71 # "trunk" 162 | 17: 72 # "terrain" 163 | 18: 80 # "pole" 164 | 19: 81 # "traffic-sign" 165 | learning_ignore: # Ignore classes 166 | 0: True # "unlabeled", and others ignored 167 | 1: False # "car" 168 | 2: False # "bicycle" 169 | 3: False # "motorcycle" 170 | 4: False # "truck" 171 | 5: False # "other-vehicle" 172 | 6: False # "person" 173 | 7: False # "bicyclist" 174 | 8: False # "motorcyclist" 175 | 9: False # "road" 176 | 10: False # "parking" 177 | 11: False # "sidewalk" 178 | 12: False # "other-ground" 179 | 13: False # "building" 180 | 14: False # "fence" 181 | 15: False # "vegetation" 182 | 16: False # "trunk" 183 | 17: False # "terrain" 184 | 18: False # "pole" 185 | 19: False # "traffic-sign" 186 | split: # sequence numbers 187 | train: 188 | - 0 189 | - 1 190 | - 2 191 | - 3 192 | - 4 193 | - 5 194 | - 6 195 | - 7 196 | - 9 197 | - 10 198 | valid: 199 | - 8 200 | test: 201 | - 11 202 | - 12 203 | - 13 204 | - 14 205 | - 15 206 | - 16 207 | - 17 208 | - 18 209 | - 19 210 | - 20 211 | - 21 212 | -------------------------------------------------------------------------------- /config/label_mapping/semantic-kitti-multiscan.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
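# Multi-scan variant referenced by the TTA configs; the mapping itself is
# identical to semantic-kitti.yaml (moving-* classes are merged into their
# static counterparts).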
2 | labels: 3 | 0 : "unlabeled" 4 | 1 : "outlier" 5 | 10: "car" 6 | 11: "bicycle" 7 | 13: "bus" 8 | 15: "motorcycle" 9 | 16: "on-rails" 10 | 18: "truck" 11 | 20: "other-vehicle" 12 | 30: "person" 13 | 31: "bicyclist" 14 | 32: "motorcyclist" 15 | 40: "road" 16 | 44: "parking" 17 | 48: "sidewalk" 18 | 49: "other-ground" 19 | 50: "building" 20 | 51: "fence" 21 | 52: "other-structure" 22 | 60: "lane-marking" 23 | 70: "vegetation" 24 | 71: "trunk" 25 | 72: "terrain" 26 | 80: "pole" 27 | 81: "traffic-sign" 28 | 99: "other-object" 29 | 252: "moving-car" 30 | 253: "moving-bicyclist" 31 | 254: "moving-person" 32 | 255: "moving-motorcyclist" 33 | 256: "moving-on-rails" 34 | 257: "moving-bus" 35 | 258: "moving-truck" 36 | 259: "moving-other-vehicle" 37 | color_map: # bgr 38 | 0 : [0, 0, 0] 39 | 1 : [0, 0, 255] 40 | 10: [245, 150, 100] 41 | 11: [245, 230, 100] 42 | 13: [250, 80, 100] 43 | 15: [150, 60, 30] 44 | 16: [255, 0, 0] 45 | 18: [180, 30, 80] 46 | 20: [255, 0, 0] 47 | 30: [30, 30, 255] 48 | 31: [200, 40, 255] 49 | 32: [90, 30, 150] 50 | 40: [255, 0, 255] 51 | 44: [255, 150, 255] 52 | 48: [75, 0, 75] 53 | 49: [75, 0, 175] 54 | 50: [0, 200, 255] 55 | 51: [50, 120, 255] 56 | 52: [0, 150, 255] 57 | 60: [170, 255, 150] 58 | 70: [0, 175, 0] 59 | 71: [0, 60, 135] 60 | 72: [80, 240, 150] 61 | 80: [150, 240, 255] 62 | 81: [0, 0, 255] 63 | 99: [255, 255, 50] 64 | 252: [245, 150, 100] 65 | 256: [255, 0, 0] 66 | 253: [200, 40, 255] 67 | 254: [30, 30, 255] 68 | 255: [90, 30, 150] 69 | 257: [250, 80, 100] 70 | 258: [180, 30, 80] 71 | 259: [255, 0, 0] 72 | content: # as a ratio with the total number of points 73 | 0: 0.018889854628292943 74 | 1: 0.0002937197336781505 75 | 10: 0.040818519255974316 76 | 11: 0.00016609538710764618 77 | 13: 2.7879693665067774e-05 78 | 15: 0.00039838616015114444 79 | 16: 0.0 80 | 18: 0.0020633612104619787 81 | 20: 0.0016218197275284021 82 | 30: 0.00017698551338515307 83 | 31: 1.1065903904919655e-08 84 | 32: 5.532951952459828e-09 85 | 40: 0.1987493871255525 86 | 44: 0.014717169549888214 87 | 48: 0.14392298360372 88 | 49: 0.0039048553037472045 89 | 50: 0.1326861944777486 90 | 51: 0.0723592229456223 91 | 52: 0.002395131480328884 92 | 60: 4.7084144280367186e-05 93 | 70: 0.26681502148037506 94 | 71: 0.006035012012626033 95 | 72: 0.07814222006271769 96 | 80: 0.002855498193863172 97 | 81: 0.0006155958086189918 98 | 99: 0.009923127583046915 99 | 252: 0.001789309418528068 100 | 253: 0.00012709999297008662 101 | 254: 0.00016059776092534436 102 | 255: 3.745553104802113e-05 103 | 256: 0.0 104 | 257: 0.00011351574470342043 105 | 258: 0.00010157861367183268 106 | 259: 4.3840131989471124e-05 107 | # classes that are indistinguishable from single scan or inconsistent in 108 | # ground truth are mapped to their closest equivalent 109 | learning_map: 110 | 0 : 0 # "unlabeled" 111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 112 | 10: 1 # "car" 113 | 11: 2 # "bicycle" 114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 115 | 15: 3 # "motorcycle" 116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 117 | 18: 4 # "truck" 118 | 20: 5 # "other-vehicle" 119 | 30: 6 # "person" 120 | 31: 7 # "bicyclist" 121 | 32: 8 # "motorcyclist" 122 | 40: 9 # "road" 123 | 44: 10 # "parking" 124 | 48: 11 # "sidewalk" 125 | 49: 12 # "other-ground" 126 | 50: 13 # "building" 127 | 51: 14 # "fence" 128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 129 | 60: 9 # "lane-marking" to "road" 
---------------------------------mapped 130 | 70: 15 # "vegetation" 131 | 71: 16 # "trunk" 132 | 72: 17 # "terrain" 133 | 80: 18 # "pole" 134 | 81: 19 # "traffic-sign" 135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 136 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 137 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 138 | 254: 6 # "moving-person" to "person" ------------------------------mapped 139 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 140 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 141 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 142 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 143 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 144 | learning_map_inv: # inverse of previous map 145 | 0: 0 # "unlabeled", and others ignored 146 | 1: 10 # "car" 147 | 2: 11 # "bicycle" 148 | 3: 15 # "motorcycle" 149 | 4: 18 # "truck" 150 | 5: 20 # "other-vehicle" 151 | 6: 30 # "person" 152 | 7: 31 # "bicyclist" 153 | 8: 32 # "motorcyclist" 154 | 9: 40 # "road" 155 | 10: 44 # "parking" 156 | 11: 48 # "sidewalk" 157 | 12: 49 # "other-ground" 158 | 13: 50 # "building" 159 | 14: 51 # "fence" 160 | 15: 70 # "vegetation" 161 | 16: 71 # "trunk" 162 | 17: 72 # "terrain" 163 | 18: 80 # "pole" 164 | 19: 81 # "traffic-sign" 165 | learning_ignore: # Ignore classes 166 | 0: True # "unlabeled", and others ignored 167 | 1: False # "car" 168 | 2: False # "bicycle" 169 | 3: False # "motorcycle" 170 | 4: False # "truck" 171 | 5: False # "other-vehicle" 172 | 6: False # "person" 173 | 7: False # "bicyclist" 174 | 8: False # "motorcyclist" 175 | 9: False # "road" 176 | 10: False # "parking" 177 | 11: False # "sidewalk" 178 | 12: False # "other-ground" 179 | 13: False # "building" 180 | 14: False # "fence" 181 | 15: False # "vegetation" 182 | 16: False # "trunk" 183 | 17: False # "terrain" 184 | 18: False # "pole" 185 | 19: False # "traffic-sign" 186 | split: # sequence numbers 187 | train: 188 | - 0 189 | - 1 190 | - 2 191 | - 3 192 | - 4 193 | - 5 194 | - 6 195 | - 7 196 | - 9 197 | - 10 198 | valid: 199 | - 8 200 | test: 201 | - 11 202 | - 12 203 | - 13 204 | - 14 205 | - 15 206 | - 16 207 | - 17 208 | - 18 209 | - 19 210 | - 20 211 | - 21 212 | -------------------------------------------------------------------------------- /config/label_mapping/semantic-kitti-all.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
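# "All" variant: unlike semantic-kitti.yaml above, the moving classes keep their
# own training IDs (20-25, e.g. 252 "moving-car" -> 20) rather than being merged
# into the static classes.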
2 | labels: 3 | 0 : "unlabeled" 4 | 1 : "outlier" 5 | 10: "car" 6 | 11: "bicycle" 7 | 13: "bus" 8 | 15: "motorcycle" 9 | 16: "on-rails" 10 | 18: "truck" 11 | 20: "other-vehicle" 12 | 30: "person" 13 | 31: "bicyclist" 14 | 32: "motorcyclist" 15 | 40: "road" 16 | 44: "parking" 17 | 48: "sidewalk" 18 | 49: "other-ground" 19 | 50: "building" 20 | 51: "fence" 21 | 52: "other-structure" 22 | 60: "lane-marking" 23 | 70: "vegetation" 24 | 71: "trunk" 25 | 72: "terrain" 26 | 80: "pole" 27 | 81: "traffic-sign" 28 | 99: "other-object" 29 | 252: "moving-car" 30 | 253: "moving-bicyclist" 31 | 254: "moving-person" 32 | 255: "moving-motorcyclist" 33 | 256: "moving-on-rails" 34 | 257: "moving-bus" 35 | 258: "moving-truck" 36 | 259: "moving-other-vehicle" 37 | color_map: # bgr 38 | 0 : [0, 0, 0] 39 | 1 : [0, 0, 255] 40 | 10: [245, 150, 100] 41 | 11: [245, 230, 100] 42 | 13: [250, 80, 100] 43 | 15: [150, 60, 30] 44 | 16: [255, 0, 0] 45 | 18: [180, 30, 80] 46 | 20: [255, 0, 0] 47 | 30: [30, 30, 255] 48 | 31: [200, 40, 255] 49 | 32: [90, 30, 150] 50 | 40: [255, 0, 255] 51 | 44: [255, 150, 255] 52 | 48: [75, 0, 75] 53 | 49: [75, 0, 175] 54 | 50: [0, 200, 255] 55 | 51: [50, 120, 255] 56 | 52: [0, 150, 255] 57 | 60: [170, 255, 150] 58 | 70: [0, 175, 0] 59 | 71: [0, 60, 135] 60 | 72: [80, 240, 150] 61 | 80: [150, 240, 255] 62 | 81: [0, 0, 255] 63 | 99: [255, 255, 50] 64 | 252: [245, 150, 100] 65 | 256: [255, 0, 0] 66 | 253: [200, 40, 255] 67 | 254: [30, 30, 255] 68 | 255: [90, 30, 150] 69 | 257: [250, 80, 100] 70 | 258: [180, 30, 80] 71 | 259: [255, 0, 0] 72 | content: # as a ratio with the total number of points 73 | 0: 0.018889854628292943 74 | 1: 0.0002937197336781505 75 | 10: 0.040818519255974316 76 | 11: 0.00016609538710764618 77 | 13: 2.7879693665067774e-05 78 | 15: 0.00039838616015114444 79 | 16: 0.0 80 | 18: 0.0020633612104619787 81 | 20: 0.0016218197275284021 82 | 30: 0.00017698551338515307 83 | 31: 1.1065903904919655e-08 84 | 32: 5.532951952459828e-09 85 | 40: 0.1987493871255525 86 | 44: 0.014717169549888214 87 | 48: 0.14392298360372 88 | 49: 0.0039048553037472045 89 | 50: 0.1326861944777486 90 | 51: 0.0723592229456223 91 | 52: 0.002395131480328884 92 | 60: 4.7084144280367186e-05 93 | 70: 0.26681502148037506 94 | 71: 0.006035012012626033 95 | 72: 0.07814222006271769 96 | 80: 0.002855498193863172 97 | 81: 0.0006155958086189918 98 | 99: 0.009923127583046915 99 | 252: 0.001789309418528068 100 | 253: 0.00012709999297008662 101 | 254: 0.00016059776092534436 102 | 255: 3.745553104802113e-05 103 | 256: 0.0 104 | 257: 0.00011351574470342043 105 | 258: 0.00010157861367183268 106 | 259: 4.3840131989471124e-05 107 | # classes that are indistinguishable from single scan or inconsistent in 108 | # ground truth are mapped to their closest equivalent 109 | learning_map: 110 | 0 : 0 # "unlabeled" 111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 112 | 10: 1 # "car" 113 | 11: 2 # "bicycle" 114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 115 | 15: 3 # "motorcycle" 116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 117 | 18: 4 # "truck" 118 | 20: 5 # "other-vehicle" 119 | 30: 6 # "person" 120 | 31: 7 # "bicyclist" 121 | 32: 8 # "motorcyclist" 122 | 40: 9 # "road" 123 | 44: 10 # "parking" 124 | 48: 11 # "sidewalk" 125 | 49: 12 # "other-ground" 126 | 50: 13 # "building" 127 | 51: 14 # "fence" 128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 129 | 60: 9 # "lane-marking" to "road" 
---------------------------------mapped 130 | 70: 15 # "vegetation" 131 | 71: 16 # "trunk" 132 | 72: 17 # "terrain" 133 | 80: 18 # "pole" 134 | 81: 19 # "traffic-sign" 135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 136 | 252: 20 # "moving-car" 137 | 253: 21 # "moving-bicyclist" 138 | 254: 22 # "moving-person" 139 | 255: 23 # "moving-motorcyclist" 140 | 256: 24 # "moving-on-rails" mapped to "moving-other-vehicle" ------mapped 141 | 257: 24 # "moving-bus" mapped to "moving-other-vehicle" -----------mapped 142 | 258: 25 # "moving-truck" 143 | 259: 24 # "moving-other-vehicle" 144 | learning_map_inv: # inverse of previous map 145 | 0: 0 # "unlabeled", and others ignored 146 | 1: 10 # "car" 147 | 2: 11 # "bicycle" 148 | 3: 15 # "motorcycle" 149 | 4: 18 # "truck" 150 | 5: 20 # "other-vehicle" 151 | 6: 30 # "person" 152 | 7: 31 # "bicyclist" 153 | 8: 32 # "motorcyclist" 154 | 9: 40 # "road" 155 | 10: 44 # "parking" 156 | 11: 48 # "sidewalk" 157 | 12: 49 # "other-ground" 158 | 13: 50 # "building" 159 | 14: 51 # "fence" 160 | 15: 70 # "vegetation" 161 | 16: 71 # "trunk" 162 | 17: 72 # "terrain" 163 | 18: 80 # "pole" 164 | 19: 81 # "traffic-sign" 165 | 20: 252 # "moving-car" 166 | 21: 253 # "moving-bicyclist" 167 | 22: 254 # "moving-person" 168 | 23: 255 # "moving-motorcyclist" 169 | 24: 259 # "moving-other-vehicle" 170 | 25: 258 # "moving-truck" 171 | learning_ignore: # Ignore classes 172 | 0: True # "unlabeled", and others ignored 173 | 1: False # "car" 174 | 2: False # "bicycle" 175 | 3: False # "motorcycle" 176 | 4: False # "truck" 177 | 5: False # "other-vehicle" 178 | 6: False # "person" 179 | 7: False # "bicyclist" 180 | 8: False # "motorcyclist" 181 | 9: False # "road" 182 | 10: False # "parking" 183 | 11: False # "sidewalk" 184 | 12: False # "other-ground" 185 | 13: False # "building" 186 | 14: False # "fence" 187 | 15: False # "vegetation" 188 | 16: False # "trunk" 189 | 17: False # "terrain" 190 | 18: False # "pole" 191 | 19: False # "traffic-sign" 192 | 20: False # "moving-car" 193 | 21: False # "moving-bicyclist" 194 | 22: False # "moving-person" 195 | 23: False # "moving-motorcyclist" 196 | 24: False # "moving-other-vehicle" 197 | 25: False # "moving-truck" 198 | split: # sequence numbers 199 | train: 200 | - 0 201 | - 1 202 | - 2 203 | - 3 204 | - 4 205 | - 5 206 | - 6 207 | - 7 208 | - 9 209 | - 10 210 | valid: 211 | - 8 212 | test: 213 | - 11 214 | - 12 215 | - 13 216 | - 14 217 | - 15 218 | - 16 219 | - 17 220 | - 18 221 | - 19 222 | - 20 223 | - 21 224 | -------------------------------------------------------------------------------- /network/conv_base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import List, Tuple 3 | 4 | 5 | class SharedMLP(nn.Sequential): 6 | 7 | def __init__( 8 | self, 9 | args: List[int], 10 | *, 11 | bn: bool = False, 12 | activation=nn.ReLU(inplace=True), 13 | preact: bool = False, 14 | first: bool = False, 15 | name: str = "", 16 | instance_norm: bool = False, ): 17 | super().__init__() 18 | 19 | for i in range(len(args) - 1): 20 | self.add_module( 21 | name + 'layer{}'.format(i), 22 | Conv2d( 23 | args[i], 24 | args[i + 1], 25 | bn=(not first or not preact or (i != 0)) and bn, 26 | activation=activation 27 | if (not first or not preact or (i != 0)) else None, 28 | preact=preact, 29 | instance_norm=instance_norm 30 | ) 31 | ) 32 | 33 | 34 | class _ConvBase(nn.Sequential): 35 | 36 | def __init__( 37 | self, 38 | in_size, 39 | out_size, 40 | 
kernel_size, 41 | stride, 42 | padding, 43 | activation, 44 | bn, 45 | init, 46 | conv=None, 47 | batch_norm=None, 48 | bias=True, 49 | preact=False, 50 | name="", 51 | instance_norm=False, 52 | instance_norm_func=None 53 | ): 54 | super().__init__() 55 | 56 | bias = bias and (not bn) 57 | conv_unit = conv( 58 | in_size, 59 | out_size, 60 | kernel_size=kernel_size, 61 | stride=stride, 62 | padding=padding, 63 | bias=bias 64 | ) 65 | init(conv_unit.weight) 66 | if bias: 67 | nn.init.constant_(conv_unit.bias, 0) 68 | 69 | if bn: 70 | if not preact: 71 | bn_unit = batch_norm(out_size) 72 | else: 73 | bn_unit = batch_norm(in_size) 74 | if instance_norm: 75 | if not preact: 76 | in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False) 77 | else: 78 | in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False) 79 | 80 | if preact: 81 | if bn: 82 | self.add_module(name + 'bn', bn_unit) 83 | 84 | if activation is not None: 85 | self.add_module(name + 'activation', activation) 86 | 87 | if not bn and instance_norm: 88 | self.add_module(name + 'in', in_unit) 89 | 90 | self.add_module(name + 'conv', conv_unit) 91 | 92 | if not preact: 93 | if bn: 94 | self.add_module(name + 'bn', bn_unit) 95 | 96 | if activation is not None: 97 | self.add_module(name + 'activation', activation) 98 | 99 | if not bn and instance_norm: 100 | self.add_module(name + 'in', in_unit) 101 | 102 | 103 | class _BNBase(nn.Sequential): 104 | 105 | def __init__(self, in_size, batch_norm=None, name=""): 106 | super().__init__() 107 | self.add_module(name + "bn", batch_norm(in_size)) 108 | 109 | nn.init.constant_(self[0].weight, 1.0) 110 | nn.init.constant_(self[0].bias, 0) 111 | 112 | 113 | class BatchNorm1d(_BNBase): 114 | 115 | def __init__(self, in_size: int, *, name: str = ""): 116 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 117 | 118 | 119 | class BatchNorm2d(_BNBase): 120 | 121 | def __init__(self, in_size: int, name: str = ""): 122 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 123 | 124 | 125 | class BatchNorm3d(_BNBase): 126 | 127 | def __init__(self, in_size: int, name: str = ""): 128 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 129 | 130 | 131 | class Conv1d(_ConvBase): 132 | 133 | def __init__( 134 | self, 135 | in_size: int, 136 | out_size: int, 137 | *, 138 | kernel_size: int = 1, 139 | stride: int = 1, 140 | padding: int = 0, 141 | activation=nn.ReLU(inplace=True), 142 | bn: bool = False, 143 | init=nn.init.kaiming_normal_, 144 | bias: bool = True, 145 | preact: bool = False, 146 | name: str = "", 147 | instance_norm=False 148 | ): 149 | super().__init__( 150 | in_size, 151 | out_size, 152 | kernel_size, 153 | stride, 154 | padding, 155 | activation, 156 | bn, 157 | init, 158 | conv=nn.Conv1d, 159 | batch_norm=BatchNorm1d, 160 | bias=bias, 161 | preact=preact, 162 | name=name, 163 | instance_norm=instance_norm, 164 | instance_norm_func=nn.InstanceNorm1d 165 | ) 166 | 167 | 168 | class Conv2d(_ConvBase): 169 | 170 | def __init__( 171 | self, 172 | in_size: int, 173 | out_size: int, 174 | *, 175 | kernel_size: Tuple[int, int] = (1, 1), 176 | stride: Tuple[int, int] = (1, 1), 177 | padding: Tuple[int, int] = (0, 0), 178 | activation=nn.ReLU(inplace=True), 179 | bn: bool = False, 180 | init=nn.init.kaiming_normal_, 181 | bias: bool = True, 182 | preact: bool = False, 183 | name: str = "", 184 | instance_norm=False 185 | ): 186 | super().__init__( 187 | in_size, 188 | out_size, 189 | kernel_size, 190 | stride, 191 | 
padding,
192 | activation,
193 | bn,
194 | init,
195 | conv=nn.Conv2d,
196 | batch_norm=BatchNorm2d,
197 | bias=bias,
198 | preact=preact,
199 | name=name,
200 | instance_norm=instance_norm,
201 | instance_norm_func=nn.InstanceNorm2d
202 | )
203 | 
204 | 
205 | class Conv3d(_ConvBase):
206 | 
207 | def __init__(
208 | self,
209 | in_size: int,
210 | out_size: int,
211 | *,
212 | kernel_size: Tuple[int, int, int] = (1, 1, 1),
213 | stride: Tuple[int, int, int] = (1, 1, 1),
214 | padding: Tuple[int, int, int] = (0, 0, 0),
215 | activation=nn.ReLU(inplace=True),
216 | bn: bool = False,
217 | init=nn.init.kaiming_normal_,
218 | bias: bool = True,
219 | preact: bool = False,
220 | name: str = "",
221 | instance_norm=False
222 | ):
223 | super().__init__(
224 | in_size,
225 | out_size,
226 | kernel_size,
227 | stride,
228 | padding,
229 | activation,
230 | bn,
231 | init,
232 | conv=nn.Conv3d,
233 | batch_norm=BatchNorm3d,
234 | bias=bias,
235 | preact=preact,
236 | name=name,
237 | instance_norm=instance_norm,
238 | instance_norm_func=nn.InstanceNorm3d
239 | )
240 | 
241 | 
242 | class FC(nn.Sequential):
243 | 
244 | def __init__(
245 | self,
246 | in_size: int,
247 | out_size: int,
248 | *,
249 | activation=nn.ReLU(inplace=True),
250 | bn: bool = False,
251 | init=None,
252 | preact: bool = False,
253 | name: str = ""
254 | ):
255 | super().__init__()
256 | 
257 | fc = nn.Linear(in_size, out_size, bias=not bn)
258 | if init is not None:
259 | init(fc.weight)
260 | if not bn:
261 | nn.init.constant_(fc.bias, 0)
262 | 
263 | if preact:
264 | if bn:
265 | self.add_module(name + 'bn', BatchNorm1d(in_size))
266 | 
267 | if activation is not None:
268 | self.add_module(name + 'activation', activation)
269 | 
270 | self.add_module(name + 'fc', fc)
271 | 
272 | if not preact:
273 | if bn:
274 | self.add_module(name + 'bn', BatchNorm1d(out_size))
275 | 
276 | if activation is not None:
277 | self.add_module(name + 'activation', activation)
278 | 
279 | 
280 | 
--------------------------------------------------------------------------------
/utils/load_save_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: load_save_util.py
4 | 
5 | import torch
6 | 
7 | 
8 | 
9 | def load_checkpoint(model_load_path, model):
10 | pre_weight = torch.load(model_load_path)
11 | my_model_dict = model.state_dict()
12 | part_load = {}
13 | match_size = 0
14 | nomatch_size = 0
15 | for k in pre_weight.keys():
16 | value = pre_weight[k]
17 | # str3 = 'seg_head.sparseModel.1.weight'
18 | # if k.find(str3) > 0:
19 | # value = value[:, :, 0, :]
20 | if k in my_model_dict and my_model_dict[k].shape == value.shape:
21 | # print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape)))
22 | match_size += 1
23 | part_load[k] = value
24 | else:
25 | # print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape)))
26 | Lvalue = len(value.shape)
27 | assert 1 <= Lvalue <= 5
28 | if len(value.shape) == 1:
29 | c = value.shape[0]
30 | cc = my_model_dict[k].shape[0] - c #int(c*0.5)
31 | if 0 < cc <= c:
32 | value = torch.cat([value, value[:cc]], dim=0)
33 | elif cc > c:
34 | value = torch.cat([value, value, value[:(cc-c)]], dim=0)
35 | elif cc < 0:
36 | value = value[:my_model_dict[k].shape[0]] # truncate to the target width
37 | else:
38 | # print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape)))
39 | cs = value.shape
40 | ccs = [0]*Lvalue
41 | j = -1
42 | for ci in cs:
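# Per-dimension shape adaptation in the loop that follows: ccs[j] holds the
# size gap (target - source) along dim j. A positive gap is filled by
# concatenating a slice of the tensor with itself along that dim; a negative
# gap truncates along it, so a pretrained checkpoint can be force-fitted into
# layers whose channel counts differ from the ones it was trained with.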
43 | j += 1 44 | ccs[j] = (my_model_dict[k].shape[j] - ci) 45 | if ccs[j] != 0: 46 | for m in range(Lvalue): 47 | if m != j: 48 | ccs[m] = value.shape[m] 49 | # print(ccs) 50 | if ccs[j] > 0: 51 | if ccs[j] > ci: 52 | ccs[j] = ccs[j] - ci 53 | if Lvalue == 5: 54 | value = torch.cat([value, value[:ccs[0], :ccs[1], :ccs[2], :ccs[3], :ccs[4]]], dim=j) 55 | elif Lvalue == 4: 56 | # print(value[:ccs[0], :ccs[1], :ccs[2], :ccs[3]].shape) 57 | value = torch.cat([value, value[:ccs[0], :ccs[1], :ccs[2], :ccs[3]]], dim=j) 58 | elif Lvalue == 3: 59 | value = torch.cat([value, value[:ccs[0], :ccs[1], :ccs[2]]], dim=j) 60 | elif Lvalue == 2: 61 | value = torch.cat([value, value[:ccs[0], :ccs[1]]], dim=j) 62 | ccs[j] = value.shape[j] 63 | elif ccs[j] < 0: 64 | # ccs[j] = -ccs[j] 65 | ccs[j] = my_model_dict[k].shape[j] 66 | if j == 0: 67 | value = value[:ccs[0], :] 68 | elif j == 1: 69 | if Lvalue == 2: 70 | value = value[:, :ccs[1]] 71 | else: 72 | value = value[:, :ccs[1], :] 73 | elif j == 2: 74 | if Lvalue == 3: 75 | value = value[:, :, :ccs[2]] 76 | else: 77 | value = value[:, :, :ccs[2], :] 78 | elif j == 3 and j <= Lvalue-1: 79 | if Lvalue == 4: 80 | value = value[:, :, :, :ccs[3]] 81 | else: 82 | value = value[:, :, :, :ccs[3], :] 83 | elif j == 4 and j <= Lvalue-1: 84 | value = value[:, :, :, :, :ccs[4]] 85 | 86 | nomatch_size += 1 87 | if my_model_dict[k].shape == value.shape: 88 | part_load[k] = value 89 | # print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape))) 90 | # assert my_model_dict[k].shape == value.shape 91 | 92 | print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size)) 93 | 94 | my_model_dict.update(part_load) 95 | model.load_state_dict(my_model_dict) 96 | # model.load_state_dict(my_model_dict, strict=False) # True 97 | 98 | return model 99 | 100 | 101 | def load_checkpoint_old2(model_load_path, model): 102 | my_model_dict = model.state_dict() 103 | pre_weight = torch.load(model_load_path) 104 | 105 | part_load = {} 106 | match_size = 0 107 | nomatch_size = 0 108 | for k in pre_weight.keys(): 109 | value = pre_weight[k] 110 | if k in my_model_dict and my_model_dict[k].shape == value.shape: 111 | #print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape))) 112 | match_size += 1 113 | part_load[k] = value 114 | else: 115 | assert len(value.shape) == 1 or len(value.shape) == 5 116 | if len(value.shape) == 1: 117 | c = value.shape[0] 118 | cc = my_model_dict[k].shape[0] - c #int(c*0.5) 119 | if cc <= c: 120 | value = torch.cat([value, value[:cc]], dim=0) 121 | else: 122 | value = torch.cat([value, value, value[:(cc-c)]], dim=0) 123 | else: 124 | _, _, _, c1, c2 = value.shape 125 | cc1 = my_model_dict[k].shape[3] - c1 #int(c1*0.5) 126 | cc2 = my_model_dict[k].shape[4] - c2 #int(c2*0.5) 127 | if cc1 > 0 and cc1 <= c1: 128 | value1 = torch.cat([value, value[:, :, :, :cc1, :]], dim=3) 129 | elif cc1 > c1: 130 | value1 = torch.cat([value, value, value[:, :, :, :(cc1-c1), :]], dim=3) 131 | else: 132 | value1 = value 133 | if cc2 > 0 and cc2 <= c2: 134 | value = torch.cat([value1, value1[:, :, :, :, :cc2]], dim=4) 135 | elif cc2 > c2: 136 | value = torch.cat([value1, value1, value1[:, :, :, :, :(cc2-c2)]], dim=4) 137 | else: 138 | value = value1 139 | nomatch_size += 1 140 | part_load[k] = value 141 | assert my_model_dict[k].shape == value.shape 142 | #print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape))) 143 | 144 | print("matched parameter sets: {}, and no 
matched: {}".format(match_size, nomatch_size)) 145 | 146 | my_model_dict.update(part_load) 147 | # model.load_state_dict(my_model_dict) 148 | model.load_state_dict(my_model_dict, strict=False) # True 149 | 150 | return model 151 | 152 | def load_checkpoint_old(model_load_path, model): 153 | my_model_dict = model.state_dict() 154 | pre_weight = torch.load(model_load_path) 155 | 156 | part_load = {} 157 | match_size = 0 158 | nomatch_size = 0 159 | for k in pre_weight.keys(): 160 | value = pre_weight[k] 161 | if k in my_model_dict and my_model_dict[k].shape == value.shape: 162 | # print("loading ", k) 163 | match_size += 1 164 | part_load[k] = value 165 | else: 166 | nomatch_size += 1 167 | 168 | print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size)) 169 | 170 | my_model_dict.update(part_load) 171 | model.load_state_dict(my_model_dict) 172 | 173 | return model 174 | 175 | def load_checkpoint_1b1(model_load_path, model): 176 | my_model_dict = model.state_dict() 177 | pre_weight = torch.load(model_load_path) 178 | 179 | part_load = {} 180 | match_size = 0 181 | nomatch_size = 0 182 | 183 | pre_weight_list = [*pre_weight] 184 | my_model_dict_list = [*my_model_dict] 185 | 186 | for idx in range(len(pre_weight_list)): 187 | key_ = pre_weight_list[idx] 188 | key_2 = my_model_dict_list[idx] 189 | value_ = pre_weight[key_] 190 | if my_model_dict[key_2].shape == pre_weight[key_].shape: 191 | # print("loading ", k) 192 | match_size += 1 193 | part_load[key_2] = value_ 194 | else: 195 | print(key_) 196 | print(key_2) 197 | nomatch_size += 1 198 | 199 | print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size)) 200 | 201 | my_model_dict.update(part_load) 202 | model.load_state_dict(my_model_dict) 203 | 204 | return model 205 | -------------------------------------------------------------------------------- /utils/lovasz_losses.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | 4 | 5 | """ 6 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 7 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 8 | """ 9 | 10 | from __future__ import print_function, division 11 | 12 | import torch 13 | from torch.autograd import Variable 14 | import torch.nn.functional as F 15 | import numpy as np 16 | try: 17 | from itertools import ifilterfalse 18 | except ImportError: # py3k 19 | from itertools import filterfalse as ifilterfalse 20 | 21 | def lovasz_grad(gt_sorted): 22 | """ 23 | Computes gradient of the Lovasz extension w.r.t sorted errors 24 | See Alg. 1 in paper 25 | """ 26 | p = len(gt_sorted) 27 | gts = gt_sorted.sum() 28 | intersection = gts - gt_sorted.float().cumsum(0) 29 | union = gts + (1 - gt_sorted).float().cumsum(0) 30 | jaccard = 1. 
- intersection / union 31 | if p > 1: # cover 1-pixel case 32 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 33 | return jaccard 34 | 35 | 36 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 37 | """ 38 | IoU for foreground class 39 | binary: 1 foreground, 0 background 40 | """ 41 | if not per_image: 42 | preds, labels = (preds,), (labels,) 43 | ious = [] 44 | for pred, label in zip(preds, labels): 45 | intersection = ((label == 1) & (pred == 1)).sum() 46 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 47 | if not union: 48 | iou = EMPTY 49 | else: 50 | iou = float(intersection) / float(union) 51 | ious.append(iou) 52 | iou = mean(ious) # mean accross images if per_image 53 | return 100 * iou 54 | 55 | 56 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 57 | """ 58 | Array of IoU for each (non ignored) class 59 | """ 60 | if not per_image: 61 | preds, labels = (preds,), (labels,) 62 | ious = [] 63 | for pred, label in zip(preds, labels): 64 | iou = [] 65 | for i in range(C): 66 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 67 | intersection = ((label == i) & (pred == i)).sum() 68 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 69 | if not union: 70 | iou.append(EMPTY) 71 | else: 72 | iou.append(float(intersection) / float(union)) 73 | ious.append(iou) 74 | ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image 75 | return 100 * np.array(ious) 76 | 77 | 78 | # --------------------------- BINARY LOSSES --------------------------- 79 | 80 | 81 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 82 | """ 83 | Binary Lovasz hinge loss 84 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 85 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 86 | per_image: compute the loss per image instead of per batch 87 | ignore: void class id 88 | """ 89 | if per_image: 90 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 91 | for log, lab in zip(logits, labels)) 92 | else: 93 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 94 | return loss 95 | 96 | 97 | def lovasz_hinge_flat(logits, labels): 98 | """ 99 | Binary Lovasz hinge loss 100 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 101 | labels: [P] Tensor, binary ground truth labels (0 or 1) 102 | ignore: label to ignore 103 | """ 104 | if len(labels) == 0: 105 | # only void pixels, the gradients should be 0 106 | return logits.sum() * 0. 107 | signs = 2. * labels.float() - 1. 108 | errors = (1. 
- logits * Variable(signs)) 109 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 110 | perm = perm.data 111 | gt_sorted = labels[perm] 112 | grad = lovasz_grad(gt_sorted) 113 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 114 | return loss 115 | 116 | 117 | def flatten_binary_scores(scores, labels, ignore=None): 118 | """ 119 | Flattens predictions in the batch (binary case) 120 | Remove labels equal to 'ignore' 121 | """ 122 | scores = scores.view(-1) 123 | labels = labels.view(-1) 124 | if ignore is None: 125 | return scores, labels 126 | valid = (labels != ignore) 127 | vscores = scores[valid] 128 | vlabels = labels[valid] 129 | return vscores, vlabels 130 | 131 | 132 | class StableBCELoss(torch.nn.modules.Module): 133 | def __init__(self): 134 | super(StableBCELoss, self).__init__() 135 | def forward(self, input, target): 136 | neg_abs = - input.abs() 137 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 138 | return loss.mean() 139 | 140 | 141 | def binary_xloss(logits, labels, ignore=None): 142 | """ 143 | Binary Cross entropy loss 144 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 145 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 146 | ignore: void class id 147 | """ 148 | logits, labels = flatten_binary_scores(logits, labels, ignore) 149 | loss = StableBCELoss()(logits, Variable(labels.float())) 150 | return loss 151 | 152 | 153 | # --------------------------- MULTICLASS LOSSES --------------------------- 154 | 155 | 156 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 157 | """ 158 | Multi-class Lovasz-Softmax loss 159 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 160 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 161 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 162 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 163 | per_image: compute the loss per image instead of per batch 164 | ignore: void class labels 165 | """ 166 | if per_image: 167 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 168 | for prob, lab in zip(probas, labels)) 169 | else: 170 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 171 | return loss 172 | 173 | 174 | def lovasz_softmax_flat(probas, labels, classes='present'): 175 | """ 176 | Multi-class Lovasz-Softmax loss 177 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 178 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 179 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 180 | """ 181 | if probas.numel() == 0: 182 | # only void pixels, the gradients should be 0 183 | return probas * 0. 
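# Per-class Lovasz extension (below): for each class c, the absolute errors
# |fg - class_pred| are sorted in decreasing order and dotted with
# lovasz_grad(...), the discrete gradient of the Jaccard index, yielding a
# convex surrogate that directly optimizes per-class IoU.
# Minimal call sketch (hypothetical tensors, not part of this file):
#   probas = torch.softmax(logits, dim=1)             # [B, C, H, W]
#   loss = lovasz_softmax(probas, labels, ignore=255)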
184 | C = probas.size(1)
185 | losses = []
186 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
187 | for c in class_to_sum:
188 | fg = (labels == c).float() # foreground for class c
189 | if (classes == 'present' and fg.sum() == 0):
190 | continue
191 | if C == 1:
192 | if len(classes) > 1:
193 | raise ValueError('Sigmoid output possible only with 1 class')
194 | class_pred = probas[:, 0]
195 | else:
196 | class_pred = probas[:, c]
197 | errors = (Variable(fg) - class_pred).abs()
198 | errors_sorted, perm = torch.sort(errors, 0, descending=True)
199 | perm = perm.data
200 | fg_sorted = fg[perm]
201 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
202 | return mean(losses)
203 | 
204 | 
205 | def flatten_probas(probas, labels, ignore=None):
206 | """
207 | Flattens predictions in the batch
208 | """
209 | if probas.dim() == 3:
210 | # assumes output of a sigmoid layer
211 | B, H, W = probas.size()
212 | probas = probas.view(B, 1, H, W)
213 | elif probas.dim() == 5:
214 | # 3D segmentation
215 | B, C, L, H, W = probas.size()
216 | probas = probas.contiguous().view(B, C, L, H*W)
217 | B, C, H, W = probas.size()
218 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
219 | labels = labels.view(-1)
220 | if ignore is None:
221 | return probas, labels
222 | valid = (labels != ignore)
223 | vprobas = probas[valid.nonzero().squeeze()]
224 | vlabels = labels[valid]
225 | return vprobas, vlabels
226 | 
227 | def xloss(logits, labels, ignore=None):
228 | """
229 | Cross entropy loss that respects the 'ignore' argument (defaults to 255)
230 | """
231 | return F.cross_entropy(logits, Variable(labels), ignore_index=255 if ignore is None else ignore)
232 | 
233 | def jaccard_loss(probas, labels, ignore=None, smooth=100, bk_class=None):
234 | """
235 | Multi-class smoothed Jaccard (soft-IoU) loss
236 | Returns (1 - J_smooth) * smooth, with J_smooth = (I + smooth) / (U + smooth).
237 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
238 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
239 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
240 | ignore: void class labels
241 | smooth: additive smoothing constant for numerator and denominator
242 | bk_class: optional class excluded from the one-hot targets (treated as background)
243 | """
244 | vprobas, vlabels = flatten_probas(probas, labels, ignore)
245 | 
246 | 
247 | true_1_hot = torch.eye(vprobas.shape[1])[vlabels]
248 | 
249 | if bk_class:
250 | one_hot_assignment = torch.ones_like(vlabels)
251 | one_hot_assignment[vlabels == bk_class] = 0
252 | one_hot_assignment = one_hot_assignment.float().unsqueeze(1)
253 | true_1_hot = true_1_hot * one_hot_assignment
254 | 
255 | true_1_hot = true_1_hot.to(vprobas.device)
256 | intersection = torch.sum(vprobas * true_1_hot)
257 | cardinality = torch.sum(vprobas + true_1_hot)
258 | loss = ((intersection + smooth) / (cardinality - intersection + smooth)).mean()
259 | return (1 - loss) * smooth
260 | 
261 | def hinge_jaccard_loss(probas, labels, ignore=None, classes='present', hinge=0.1, smooth=100):
262 | """
263 | Multi-class Hinge Jaccard loss
264 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
265 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
266 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
267 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
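hinge: clamping margin applied to the gap between the class score and the best other-class score
smooth: additive smoothing constant added to the true-positive count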
268 | ignore: void class labels 269 | """ 270 | vprobas, vlabels = flatten_probas(probas, labels, ignore) 271 | C = vprobas.size(1) 272 | losses = [] 273 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 274 | for c in class_to_sum: 275 | if c in vlabels: 276 | c_sample_ind = vlabels == c 277 | cprobas = vprobas[c_sample_ind,:] 278 | non_c_ind =np.array([a for a in class_to_sum if a != c]) 279 | class_pred = cprobas[:,c] 280 | max_non_class_pred = torch.max(cprobas[:,non_c_ind],dim = 1)[0] 281 | TP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.) + smooth 282 | FN = torch.sum(torch.clamp(max_non_class_pred - class_pred, min = -hinge)+hinge) 283 | 284 | if (~c_sample_ind).sum() == 0: 285 | FP = 0 286 | else: 287 | nonc_probas = vprobas[~c_sample_ind,:] 288 | class_pred = nonc_probas[:,c] 289 | max_non_class_pred = torch.max(nonc_probas[:,non_c_ind],dim = 1)[0] 290 | FP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.) 291 | 292 | losses.append(1 - TP/(TP+FP+FN)) 293 | 294 | if len(losses) == 0: return 0 295 | return mean(losses) 296 | 297 | # --------------------------- HELPER FUNCTIONS --------------------------- 298 | def isnan(x): 299 | return x != x 300 | 301 | 302 | def mean(l, ignore_nan=False, empty=0): 303 | """ 304 | nanmean compatible with generators. 305 | """ 306 | l = iter(l) 307 | if ignore_nan: 308 | l = ifilterfalse(isnan, l) 309 | try: 310 | n = 1 311 | acc = next(l) 312 | except StopIteration: 313 | if empty == 'raise': 314 | raise ValueError('Empty mean') 315 | return empty 316 | for n, v in enumerate(l, 2): 317 | acc += v 318 | if n == 1: 319 | return acc 320 | return acc / n 321 | -------------------------------------------------------------------------------- /dataloader/pc_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: pc_dataset.py 4 | 5 | import os 6 | import numpy as np 7 | from torch.utils import data 8 | import yaml 9 | import pickle 10 | import pathlib 11 | REGISTERED_PC_DATASET_CLASSES = {} 12 | 13 | 14 | def register_dataset(cls, name=None): 15 | global REGISTERED_PC_DATASET_CLASSES 16 | if name is None: 17 | name = cls.__name__ 18 | assert name not in REGISTERED_PC_DATASET_CLASSES, f"exist class: {REGISTERED_PC_DATASET_CLASSES}" 19 | REGISTERED_PC_DATASET_CLASSES[name] = cls 20 | return cls 21 | 22 | 23 | def get_pc_model_class(name): 24 | global REGISTERED_PC_DATASET_CLASSES 25 | # print(REGISTERED_PC_DATASET_CLASSES) 26 | assert name in REGISTERED_PC_DATASET_CLASSES, f"available class: {REGISTERED_PC_DATASET_CLASSES}" 27 | return REGISTERED_PC_DATASET_CLASSES[name] 28 | 29 | 30 | @register_dataset 31 | class SemKITTI_sk(data.Dataset): 32 | def __init__(self, data_path, imageset='train', 33 | return_ref=False, label_mapping="semantic-kitti.yaml", nusc=None): 34 | self.return_ref = return_ref 35 | with open(label_mapping, 'r') as stream: 36 | semkittiyaml = yaml.safe_load(stream) 37 | self.learning_map = semkittiyaml['learning_map'] 38 | self.imageset = imageset 39 | if imageset == 'train': 40 | split = semkittiyaml['split']['train'] 41 | elif imageset == 'val': 42 | split = semkittiyaml['split']['valid'] 43 | elif imageset == 'test': 44 | split = semkittiyaml['split']['test'] 45 | else: 46 | raise Exception('Split must be train/val/test') 47 | 48 | self.im_idx = [] 49 | for i_folder in split: 50 | self.im_idx += absoluteFilePaths('/'.join([data_path, str(i_folder).zfill(2), 
'velodyne']))
51 | 
52 | def __len__(self):
53 | 'Denotes the total number of samples'
54 | return len(self.im_idx)
55 | 
56 | def __getitem__(self, index):
57 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4))
58 | if self.imageset == 'test':
59 | annotated_data = np.expand_dims(np.zeros_like(raw_data[:, 0], dtype=int), axis=1)
60 | else:
61 | annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'labels')[:-3] + 'label',
62 | dtype=np.uint32).reshape((-1, 1))
63 | annotated_data = annotated_data & 0xFFFF # keep the lower 16 bits (semantic label); the upper 16 bits hold the instance id
64 | annotated_data = np.vectorize(self.learning_map.__getitem__)(annotated_data)
65 | 
66 | data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8))
67 | if self.return_ref:
68 | data_tuple += (raw_data[:, 3],)
69 | return data_tuple
70 | 
71 | 
72 | def absoluteFilePaths(directory):
73 | for dirpath, _, filenames in os.walk(directory):
74 | filenames.sort()
75 | for f in filenames:
76 | yield os.path.abspath(os.path.join(dirpath, f))
77 | 
78 | 
79 | def SemKITTI2train(label):
80 | if isinstance(label, list):
81 | return [SemKITTI2train_single(a) for a in label]
82 | else:
83 | return SemKITTI2train_single(label)
84 | 
85 | 
86 | def SemKITTI2train_single(label):
87 | remove_ind = label == 0
88 | label -= 1
89 | label[remove_ind] = 255
90 | return label
91 | 
92 | 
93 | def unpack(compressed): # from the semantickitti api
94 | ''' given a bit encoded voxel grid, make a normal voxel grid out of it. '''
95 | uncompressed = np.zeros(compressed.shape[0] * 8, dtype=np.uint8)
96 | uncompressed[::8] = compressed[:] >> 7 & 1
97 | uncompressed[1::8] = compressed[:] >> 6 & 1
98 | uncompressed[2::8] = compressed[:] >> 5 & 1
99 | uncompressed[3::8] = compressed[:] >> 4 & 1
100 | uncompressed[4::8] = compressed[:] >> 3 & 1
101 | uncompressed[5::8] = compressed[:] >> 2 & 1
102 | uncompressed[6::8] = compressed[:] >> 1 & 1
103 | uncompressed[7::8] = compressed[:] & 1
104 | 
105 | return uncompressed
106 | 
107 | def get_eval_mask(labels, invalid_voxels): # from the semantickitti api
108 | """
109 | Ignore labels set to 255 and invalid voxels (the ones never hit by a laser ray, probed using ray tracing)
110 | :param labels: input ground truth voxels
111 | :param invalid_voxels: voxels ignored during evaluation since they lie beyond the scene that was captured by the laser
112 | :return: boolean mask to subsample the voxels to evaluate
113 | """
114 | masks = np.ones_like(labels, dtype=bool)
115 | masks[labels == 255] = False
116 | masks[invalid_voxels == 1] = False
117 | 
118 | return masks
119 | 
120 | 
121 | from os.path import join
122 | @register_dataset
123 | class SemKITTI_sk_multiscan(data.Dataset):
124 | def __init__(self, data_path, imageset='train', stride=5, return_ref=False, label_mapping="semantic-kitti-multiscan.yaml", nusc=None):
125 | self.return_ref = return_ref
126 | with open(label_mapping, 'r') as stream:
127 | semkittiyaml = yaml.safe_load(stream)
128 | ### remap completion label
129 | remapdict = semkittiyaml['learning_map']
130 | # make lookup table for mapping
131 | maxkey = max(remapdict.keys())
132 | remap_lut = np.zeros((maxkey + 100), dtype=np.int32)
133 | remap_lut[list(remapdict.keys())] = list(remapdict.values())
134 | # in completion we have to distinguish empty and invalid voxels.
135 | # Important: For voxels 0 corresponds to "empty" and not "unlabeled".
136 | remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid'
137 | remap_lut[0] = 0 # only 'empty' stays 'empty'.
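# With the semantic-kitti-multiscan.yaml mapping above, the LUT gives, e.g.:
#   remap_lut[10] -> 1 ("car"), remap_lut[252] -> 20 ("moving-car"),
#   remap_lut[52] -> 255 ("other-structure" is invalid for completion),
#   remap_lut[0]  -> 0  (an empty voxel stays empty);
# raw ids absent from the map also fall through to 255 (invalid).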
138 | self.comletion_remap_lut = remap_lut 139 | 140 | self.learning_map = semkittiyaml['learning_map'] 141 | self.imageset = imageset 142 | self.data_path = data_path 143 | if imageset == 'train': 144 | split = semkittiyaml['split']['train'] 145 | elif imageset == 'val': 146 | split = semkittiyaml['split']['valid'] 147 | elif imageset == 'test': 148 | split = semkittiyaml['split']['test'] 149 | else: 150 | raise Exception('Split must be train/val/test') 151 | multiscan = 0 # additional frames are fused with target-frame. Hence, multiscan+1 point clouds in total 152 | print('multiscan: %d' %multiscan) 153 | self.multiscan = multiscan 154 | self.im_idx = [] 155 | self.stride = stride 156 | self.accumulation=1 157 | self.calibrations = [] 158 | self.times = [] 159 | self.poses = [] 160 | self.use_test_time_adaptation=True 161 | self.load_calib_poses() 162 | for i_folder in split: 163 | # velodyne path corresponding to voxel path 164 | complete_path = os.path.join(data_path, str(i_folder).zfill(2), "voxels") 165 | files = list(pathlib.Path(complete_path).glob('*.bin')) 166 | 167 | for filename in files: 168 | self.im_idx.append(str(filename).replace('voxels', 'velodyne')) 169 | if i_folder==8: 170 | self.im_idx.sort() 171 | 172 | 173 | 174 | def __len__(self): 175 | 'Denotes the total number of samples' 176 | return len(self.im_idx) 177 | 178 | def load_calib_poses(self): 179 | """ 180 | load calib poses and times. 181 | """ 182 | 183 | ########### 184 | # Load data 185 | ########### 186 | 187 | self.calibrations = [] 188 | self.times = [] 189 | self.poses = [] 190 | 191 | for seq in range(0, 22): 192 | seq_folder = join(self.data_path, str(seq).zfill(2)) 193 | 194 | # Read Calib 195 | self.calibrations.append(self.parse_calibration(join(seq_folder, "calib.txt"))) 196 | 197 | # Read times 198 | self.times.append(np.loadtxt(join(seq_folder, 'times.txt'), dtype=np.float32)) 199 | 200 | # Read poses 201 | poses_f64 = self.parse_poses(join(seq_folder, 'poses.txt'), self.calibrations[-1]) 202 | self.poses.append([pose.astype(np.float32) for pose in poses_f64]) 203 | 204 | def parse_calibration(self, filename): 205 | """ read calibration file with given filename 206 | 207 | Returns 208 | ------- 209 | dict 210 | Calibration matrices as 4x4 numpy arrays. 211 | """ 212 | calib = {} 213 | 214 | calib_file = open(filename) 215 | for line in calib_file: 216 | key, content = line.strip().split(":") 217 | values = [float(v) for v in content.strip().split()] 218 | 219 | pose = np.zeros((4, 4)) 220 | pose[0, 0:4] = values[0:4] 221 | pose[1, 0:4] = values[4:8] 222 | pose[2, 0:4] = values[8:12] 223 | pose[3, 3] = 1.0 224 | 225 | calib[key] = pose 226 | 227 | calib_file.close() 228 | 229 | return calib 230 | 231 | def parse_poses(self, filename, calibration): 232 | """ read poses file with per-scan poses from given filename 233 | 234 | Returns 235 | ------- 236 | list 237 | list of poses as 4x4 numpy arrays. 
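Each pose is remapped into the LiDAR frame via Tr^-1 @ pose @ Tr, using the "Tr" calibration matrix (SemanticKITTI convention).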
238 | """ 239 | file = open(filename) 240 | 241 | poses = [] 242 | Tr = calibration["Tr"] 243 | Tr_inv = np.linalg.inv(Tr) 244 | 245 | for line in file: 246 | values = [float(v) for v in line.strip().split()] 247 | 248 | pose = np.zeros((4, 4)) 249 | pose[0, 0:4] = values[0:4] 250 | pose[1, 0:4] = values[4:8] 251 | pose[2, 0:4] = values[8:12] 252 | pose[3, 3] = 1.0 253 | 254 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr))) 255 | 256 | return poses 257 | 258 | def fuse_multi_scan(self, points, pose0, pose): 259 | 260 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1]))) 261 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1) 262 | new_points = new_points[:, :3] 263 | new_coords = new_points - pose0[:3, 3] 264 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1) 265 | new_coords = np.hstack((new_coords, points[:, 3:])) 266 | 267 | return new_coords 268 | 269 | def frame_transform_scan(self, points, pose0, pose): 270 | 271 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1]))) 272 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1) 273 | new_points = new_points[:, :3] 274 | new_coords = new_points - pose0[:3, 3] 275 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1) 276 | new_coords = np.hstack((new_coords, points[:, 3:])) 277 | 278 | return new_coords 279 | 280 | def __getitem__(self, index): 281 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4)) # point cloud 282 | origin_len = len(raw_data) 283 | voxel_label = 1 284 | number_idx = int(self.im_idx[index][-10:-4]) 285 | 286 | 287 | dir_idx = int(self.im_idx[index][-22:-20]) 288 | pose_list=[] 289 | pose0 = self.poses[dir_idx][number_idx] 290 | pose_list.append(pose0) 291 | 292 | prev_scans=[] 293 | prev_annotated_data=[] 294 | stride_list=[] 295 | for stride_ in self.stride: 296 | data_idx = number_idx + stride_ 297 | 298 | if data_idx<0: 299 | continue 300 | 301 | path = self.im_idx[index][:-10] 302 | newpath_prev = path + str(data_idx).zfill(6) + self.im_idx[index][-4:] 303 | 304 | voxel_path = path.replace('velodyne', 'labels') 305 | files = list(pathlib.Path(voxel_path).glob('*.label')) 306 | 307 | if data_idx < len(files): 308 | pose = self.poses[dir_idx][data_idx] 309 | pose_list.append(pose) 310 | raw_data_prev = np.fromfile(newpath_prev, dtype=np.float32).reshape((-1, 4)) 311 | prev_scans.append(raw_data_prev) 312 | annotated_data_prev = np.zeros([256, 256, 32], dtype=int).reshape((-1, 1)) 313 | annotated_data_prev = self.comletion_remap_lut[annotated_data_prev] 314 | annotated_data_prev = annotated_data_prev.reshape((256, 256, 32)) 315 | prev_annotated_data.append(annotated_data_prev) 316 | stride_list.append(stride_) 317 | 318 | if len(pose_list)>1: 319 | prev_annotated_data=np.stack(prev_annotated_data,0) 320 | prev_scans_xyz=[prev_scan[:, :3] for prev_scan in prev_scans] 321 | prev_scans_i=[prev_scan[:, 3] for prev_scan in prev_scans] 322 | prev_data_tuple=(prev_scans_xyz, prev_annotated_data.astype(np.uint8)) 323 | prev_data_tuple += (prev_scans_i,origin_len) 324 | else: 325 | prev_data_tuple=(np.zeros([1,1,1]),np.zeros([1,1,1]),np.zeros([1,1,1]),1) 326 | 327 | if self.imageset == 'test': 328 | annotated_data = np.zeros([256, 256, 32], dtype=int).reshape((-1, 1)) 329 | else: 330 | annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'voxels')[:-3] + 'label', 331 | dtype=np.uint16).reshape((-1, 1)) # voxel labels 332 | 333 | annotated_data = 
self.comletion_remap_lut[annotated_data] 334 | annotated_data = annotated_data.reshape((256, 256, 32)) 335 | 336 | data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8)) # xyz, voxel labels 337 | data_tuple += (raw_data[:, 3], origin_len) # origin_len is used to indicate the length of target-scan 338 | 339 | return data_tuple,prev_data_tuple,pose_list,stride_list 340 | 341 | 342 | # load Semantic KITTI class info 343 | def get_SemKITTI_label_name(label_mapping): 344 | with open(label_mapping, 'r') as stream: 345 | semkittiyaml = yaml.safe_load(stream) 346 | SemKITTI_label_name = dict() 347 | for i in sorted(list(semkittiyaml['learning_map'].keys()))[::-1]: 348 | SemKITTI_label_name[semkittiyaml['learning_map'][i]] = semkittiyaml['labels'][i] 349 | 350 | return SemKITTI_label_name 351 | 352 | 353 | def get_nuScenes_label_name(label_mapping): 354 | with open(label_mapping, 'r') as stream: 355 | nuScenesyaml = yaml.safe_load(stream) 356 | nuScenes_label_name = dict() 357 | for i in sorted(list(nuScenesyaml['learning_map'].keys()))[::-1]: 358 | val_ = nuScenesyaml['learning_map'][i] 359 | nuScenes_label_name[val_] = nuScenesyaml['labels_16'][val_] 360 | 361 | return nuScenes_label_name 362 | -------------------------------------------------------------------------------- /dataloader/pc_dataset_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: pc_dataset.py 4 | 5 | import os 6 | import numpy as np 7 | from torch.utils import data 8 | import yaml 9 | import pickle 10 | import pathlib 11 | REGISTERED_PC_DATASET_CLASSES = {} 12 | 13 | 14 | def register_dataset(cls, name=None): 15 | global REGISTERED_PC_DATASET_CLASSES 16 | if name is None: 17 | name = cls.__name__ 18 | assert name not in REGISTERED_PC_DATASET_CLASSES, f"exist class: {REGISTERED_PC_DATASET_CLASSES}" 19 | REGISTERED_PC_DATASET_CLASSES[name] = cls 20 | return cls 21 | 22 | 23 | def get_pc_model_class(name): 24 | global REGISTERED_PC_DATASET_CLASSES 25 | # print(REGISTERED_PC_DATASET_CLASSES) 26 | assert name in REGISTERED_PC_DATASET_CLASSES, f"available class: {REGISTERED_PC_DATASET_CLASSES}" 27 | return REGISTERED_PC_DATASET_CLASSES[name] 28 | 29 | 30 | @register_dataset 31 | class SemKITTI_sk(data.Dataset): 32 | def __init__(self, data_path, imageset='train', 33 | return_ref=False, label_mapping="semantic-kitti.yaml", nusc=None): 34 | self.return_ref = return_ref 35 | with open(label_mapping, 'r') as stream: 36 | semkittiyaml = yaml.safe_load(stream) 37 | self.learning_map = semkittiyaml['learning_map'] 38 | self.imageset = imageset 39 | if imageset == 'train': 40 | split = semkittiyaml['split']['train'] 41 | elif imageset == 'val': 42 | split = semkittiyaml['split']['valid'] 43 | elif imageset == 'test': 44 | split = semkittiyaml['split']['test'] 45 | else: 46 | raise Exception('Split must be train/val/test') 47 | 48 | self.im_idx = [] 49 | for i_folder in split: 50 | self.im_idx += absoluteFilePaths('/'.join([data_path, str(i_folder).zfill(2), 'velodyne'])) 51 | 52 | def __len__(self): 53 | 'Denotes the total number of samples' 54 | return len(self.im_idx) 55 | 56 | def __getitem__(self, index): 57 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4)) 58 | if self.imageset == 'test': 59 | annotated_data = np.expand_dims(np.zeros_like(raw_data[:, 0], dtype=int), axis=1) 60 | else: 61 | annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'labels')[:-3] + 'label', 62 | 
dtype=np.uint32).reshape((-1, 1))
63 | annotated_data = annotated_data & 0xFFFF # keep the lower 16 bits (semantic label); the upper 16 bits hold the instance id
64 | annotated_data = np.vectorize(self.learning_map.__getitem__)(annotated_data)
65 | 
66 | data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8))
67 | if self.return_ref:
68 | data_tuple += (raw_data[:, 3],)
69 | return data_tuple
70 | 
71 | 
72 | def absoluteFilePaths(directory):
73 | for dirpath, _, filenames in os.walk(directory):
74 | filenames.sort()
75 | for f in filenames:
76 | yield os.path.abspath(os.path.join(dirpath, f))
77 | 
78 | 
79 | def SemKITTI2train(label):
80 | if isinstance(label, list):
81 | return [SemKITTI2train_single(a) for a in label]
82 | else:
83 | return SemKITTI2train_single(label)
84 | 
85 | 
86 | def SemKITTI2train_single(label):
87 | remove_ind = label == 0
88 | label -= 1
89 | label[remove_ind] = 255
90 | return label
91 | 
92 | 
93 | def unpack(compressed): # from the semantickitti api
94 | ''' given a bit encoded voxel grid, make a normal voxel grid out of it. '''
95 | uncompressed = np.zeros(compressed.shape[0] * 8, dtype=np.uint8)
96 | uncompressed[::8] = compressed[:] >> 7 & 1
97 | uncompressed[1::8] = compressed[:] >> 6 & 1
98 | uncompressed[2::8] = compressed[:] >> 5 & 1
99 | uncompressed[3::8] = compressed[:] >> 4 & 1
100 | uncompressed[4::8] = compressed[:] >> 3 & 1
101 | uncompressed[5::8] = compressed[:] >> 2 & 1
102 | uncompressed[6::8] = compressed[:] >> 1 & 1
103 | uncompressed[7::8] = compressed[:] & 1
104 | 
105 | return uncompressed
106 | 
107 | def get_eval_mask(labels, invalid_voxels): # from the semantickitti api
108 | """
109 | Ignore labels set to 255 and invalid voxels (the ones never hit by a laser ray, probed using ray tracing)
110 | :param labels: input ground truth voxels
111 | :param invalid_voxels: voxels ignored during evaluation since they lie beyond the scene that was captured by the laser
112 | :return: boolean mask to subsample the voxels to evaluate
113 | """
114 | masks = np.ones_like(labels, dtype=bool)
115 | masks[labels == 255] = False
116 | masks[invalid_voxels == 1] = False
117 | 
118 | return masks
119 | 
120 | 
121 | from os.path import join
122 | @register_dataset
123 | class SemKITTI_sk_multiscan(data.Dataset):
124 | def __init__(self, data_path, imageset='train', stride=5, return_ref=False, label_mapping="semantic-kitti-multiscan.yaml", nusc=None, sq_num=10):
125 | self.return_ref = return_ref
126 | with open(label_mapping, 'r') as stream:
127 | semkittiyaml = yaml.safe_load(stream)
128 | ### remap completion label
129 | remapdict = semkittiyaml['learning_map']
130 | # make lookup table for mapping
131 | maxkey = max(remapdict.keys())
132 | remap_lut = np.zeros((maxkey + 100), dtype=np.int32)
133 | remap_lut[list(remapdict.keys())] = list(remapdict.values())
134 | # in completion we have to distinguish empty and invalid voxels.
135 | # Important: For voxels 0 corresponds to "empty" and not "unlabeled".
136 | remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid'
137 | remap_lut[0] = 0 # only 'empty' stays 'empty'.
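# The LUT is applied by integer fancy-indexing, which remaps every voxel in a
# single vectorized step, e.g. (hypothetical array):
#   remapped = remap_lut[raw_voxel_labels.astype(np.int32)]
# which is exactly what self.comletion_remap_lut[annotated_data] does below.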
138 | self.comletion_remap_lut = remap_lut 139 | 140 | self.learning_map = semkittiyaml['learning_map'] 141 | self.imageset = imageset 142 | self.data_path = data_path 143 | if imageset == 'train': 144 | split = semkittiyaml['split']['train'] 145 | elif imageset == 'val': 146 | split = semkittiyaml['split']['valid'] 147 | elif imageset == 'test': 148 | # split = semkittiyaml['split']['test'] 149 | split = [sq_num] 150 | else: 151 | raise Exception('Split must be train/val/test') 152 | # import pdb;pdb.set_trace() 153 | multiscan = 0 # additional frames are fused with target-frame. Hence, multiscan+1 point clouds in total 154 | print('multiscan: %d' %multiscan) 155 | self.multiscan = multiscan 156 | self.im_idx = [] 157 | self.stride = stride 158 | self.accumulation=1 159 | self.calibrations = [] 160 | self.times = [] 161 | self.poses = [] 162 | self.use_test_time_adaptation=True 163 | self.load_calib_poses() 164 | for i_folder in split: 165 | # velodyne path corresponding to voxel path 166 | complete_path = os.path.join(data_path, str(i_folder).zfill(2), "voxels") 167 | files = list(pathlib.Path(complete_path).glob('*.bin')) 168 | 169 | for filename in files: 170 | self.im_idx.append(str(filename).replace('voxels', 'velodyne')) 171 | 172 | self.im_idx.sort() 173 | 174 | 175 | 176 | def __len__(self): 177 | 'Denotes the total number of samples' 178 | return len(self.im_idx) 179 | 180 | def load_calib_poses(self): 181 | """ 182 | load calib poses and times. 183 | """ 184 | 185 | ########### 186 | # Load data 187 | ########### 188 | 189 | self.calibrations = [] 190 | self.times = [] 191 | self.poses = [] 192 | 193 | for seq in range(0, 22): 194 | seq_folder = join(self.data_path, str(seq).zfill(2)) 195 | 196 | # Read Calib 197 | self.calibrations.append(self.parse_calibration(join(seq_folder, "calib.txt"))) 198 | 199 | # Read times 200 | self.times.append(np.loadtxt(join(seq_folder, 'times.txt'), dtype=np.float32)) 201 | 202 | # Read poses 203 | poses_f64 = self.parse_poses(join(seq_folder, 'poses.txt'), self.calibrations[-1]) 204 | self.poses.append([pose.astype(np.float32) for pose in poses_f64]) 205 | 206 | def parse_calibration(self, filename): 207 | """ read calibration file with given filename 208 | 209 | Returns 210 | ------- 211 | dict 212 | Calibration matrices as 4x4 numpy arrays. 213 | """ 214 | calib = {} 215 | 216 | calib_file = open(filename) 217 | for line in calib_file: 218 | key, content = line.strip().split(":") 219 | values = [float(v) for v in content.strip().split()] 220 | 221 | pose = np.zeros((4, 4)) 222 | pose[0, 0:4] = values[0:4] 223 | pose[1, 0:4] = values[4:8] 224 | pose[2, 0:4] = values[8:12] 225 | pose[3, 3] = 1.0 226 | 227 | calib[key] = pose 228 | 229 | calib_file.close() 230 | 231 | return calib 232 | 233 | def parse_poses(self, filename, calibration): 234 | """ read poses file with per-scan poses from given filename 235 | 236 | Returns 237 | ------- 238 | list 239 | list of poses as 4x4 numpy arrays. 
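The calibration matrix Tr is used below to express every pose in the sensor (LiDAR) frame: np.matmul(Tr_inv, np.matmul(pose, Tr)).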
240 | """ 241 | file = open(filename) 242 | 243 | poses = [] 244 | Tr = calibration["Tr"] 245 | Tr_inv = np.linalg.inv(Tr) 246 | 247 | for line in file: 248 | values = [float(v) for v in line.strip().split()] 249 | 250 | pose = np.zeros((4, 4)) 251 | pose[0, 0:4] = values[0:4] 252 | pose[1, 0:4] = values[4:8] 253 | pose[2, 0:4] = values[8:12] 254 | pose[3, 3] = 1.0 255 | 256 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr))) 257 | 258 | return poses 259 | 260 | def fuse_multi_scan(self, points, pose0, pose): 261 | 262 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1]))) 263 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1) 264 | new_points = new_points[:, :3] 265 | new_coords = new_points - pose0[:3, 3] 266 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1) 267 | new_coords = np.hstack((new_coords, points[:, 3:])) 268 | 269 | return new_coords 270 | 271 | def frame_transform_scan(self, points, pose0, pose): 272 | 273 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1]))) 274 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1) 275 | new_points = new_points[:, :3] 276 | new_coords = new_points - pose0[:3, 3] 277 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1) 278 | new_coords = np.hstack((new_coords, points[:, 3:])) 279 | 280 | return new_coords 281 | 282 | def __getitem__(self, index): 283 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4)) # point cloud 284 | origin_len = len(raw_data) 285 | voxel_label = 1 286 | number_idx = int(self.im_idx[index][-10:-4]) 287 | dir_idx = int(self.im_idx[index][-22:-20]) 288 | pose_list=[] 289 | pose0 = self.poses[dir_idx][number_idx] 290 | pose_list.append(pose0) 291 | 292 | prev_scans=[] 293 | prev_annotated_data=[] 294 | stride_list=[] 295 | for stride_ in self.stride: 296 | data_idx = number_idx + stride_ 297 | 298 | if data_idx<0: 299 | continue 300 | 301 | path = self.im_idx[index][:-10] 302 | newpath_prev = path + str(data_idx).zfill(6) + self.im_idx[index][-4:] 303 | voxel_path = path.replace('velodyne', 'labels') 304 | files=list(pathlib.Path(path).glob('*.bin')) 305 | 306 | if data_idx < len(files): 307 | pose = self.poses[dir_idx][data_idx] 308 | pose_list.append(pose) 309 | raw_data_prev = np.fromfile(newpath_prev, dtype=np.float32).reshape((-1, 4)) 310 | prev_scans.append(raw_data_prev) 311 | annotated_data_prev = np.zeros([256, 256, 32], dtype=int).reshape((-1, 1)) 312 | annotated_data_prev = self.comletion_remap_lut[annotated_data_prev] 313 | annotated_data_prev = annotated_data_prev.reshape((256, 256, 32)) 314 | prev_annotated_data.append(annotated_data_prev) 315 | stride_list.append(stride_) 316 | 317 | if len(pose_list)>1: 318 | prev_annotated_data=np.stack(prev_annotated_data,0) 319 | prev_scans_xyz=[prev_scan[:, :3] for prev_scan in prev_scans] 320 | prev_scans_i=[prev_scan[:, 3] for prev_scan in prev_scans] 321 | prev_data_tuple=(prev_scans_xyz, prev_annotated_data.astype(np.uint8)) 322 | prev_data_tuple += (prev_scans_i,origin_len) 323 | else: 324 | prev_data_tuple=(np.zeros([1,1,1]),np.zeros([1,1,1]),np.zeros([1,1,1]),1) 325 | 326 | if self.imageset == 'test': 327 | annotated_data = np.zeros([256, 256, 32], dtype=int).reshape((-1, 1)) 328 | else: 329 | annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'voxels')[:-3] + 'label', 330 | dtype=np.uint16).reshape((-1, 1)) # voxel labels 331 | 332 | annotated_data = self.comletion_remap_lut[annotated_data] 333 | annotated_data 
= annotated_data.reshape((256, 256, 32)) 334 | 335 | data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8)) # xyz, voxel labels 336 | data_tuple += (raw_data[:, 3], origin_len) # origin_len is used to indicate the length of target-scan 337 | 338 | return data_tuple,prev_data_tuple,pose_list,stride_list 339 | 340 | 341 | # load Semantic KITTI class info 342 | def get_SemKITTI_label_name(label_mapping): 343 | with open(label_mapping, 'r') as stream: 344 | semkittiyaml = yaml.safe_load(stream) 345 | SemKITTI_label_name = dict() 346 | for i in sorted(list(semkittiyaml['learning_map'].keys()))[::-1]: 347 | SemKITTI_label_name[semkittiyaml['learning_map'][i]] = semkittiyaml['labels'][i] 348 | 349 | return SemKITTI_label_name 350 | 351 | 352 | def get_nuScenes_label_name(label_mapping): 353 | with open(label_mapping, 'r') as stream: 354 | nuScenesyaml = yaml.safe_load(stream) 355 | nuScenes_label_name = dict() 356 | for i in sorted(list(nuScenesyaml['learning_map'].keys()))[::-1]: 357 | val_ = nuScenesyaml['learning_map'][i] 358 | nuScenes_label_name[val_] = nuScenesyaml['labels_16'][val_] 359 | 360 | return nuScenes_label_name 361 | -------------------------------------------------------------------------------- /dataloader/dataset_semantickitti.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | 4 | """ 5 | SemKITTI dataloader 6 | """ 7 | import numpy as np 8 | import torch 9 | import numba as nb 10 | from torch.utils import data 11 | import time 12 | import random 13 | 14 | from utils.util import Bresenham3D 15 | REGISTERED_DATASET_CLASSES = {} 16 | 17 | def from_voxel_to_voxel(max_volume_space,min_volume_space,intervals,pose0,pose): 18 | x_bias = (max_volume_space[0] - min_volume_space[0])/2 19 | max_bound = np.asarray(max_volume_space) 20 | min_bound = np.asarray(min_volume_space) 21 | min_bound[0] -= x_bias 22 | max_bound[0] -= x_bias 23 | max_bound2 = np.asarray(max_volume_space) 24 | min_bound2 = np.asarray(min_volume_space) 25 | 26 | voxel_grid = np.indices((256, 256, 32)).transpose(1, 2, 3, 0) 27 | voxel_grid = voxel_grid.reshape(-1,3) 28 | full_voxel_center=(voxel_grid.astype(np.float32) + 0.5) * intervals + min_bound 29 | full_voxel_center[:,0]+= x_bias 30 | current_vox_center=frame_transform_scan(full_voxel_center,pose0,pose) 31 | current_vox_center=np.concatenate([current_vox_center,voxel_grid],1) 32 | vox_xyz0 = current_vox_center 33 | for ci in range(3): 34 | vox_xyz0[current_vox_center[:, ci] < min_bound2[ci], :] = 1000 35 | vox_xyz0[current_vox_center[:, ci] > max_bound2[ci], :] = 1000 36 | vox_valid_inds = vox_xyz0[:, 0] != 1000 37 | current_vox_center = current_vox_center[vox_valid_inds, :] 38 | current_vox_center[:, 0]-= x_bias # current_vox_center is 39 | vox_grid_from=current_vox_center[:,-3:] 40 | vox_grid_to=current_vox_center[:,:3] 41 | vox_grid_to = (np.floor((np.clip(vox_grid_to, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 42 | return vox_grid_from,vox_grid_to 43 | 44 | def frame_transform_scan(points, pose0, pose): 45 | 46 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1]))) 47 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1) 48 | new_points = new_points[:, :3] 49 | new_coords = new_points - pose0[:3, 3] 50 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1) 51 | new_coords = np.hstack((new_coords, points[:, 3:])) 52 | 53 | return new_coords 54 | 55 | def register_dataset(cls, name=None): 56 | global 
REGISTERED_DATASET_CLASSES 57 | if name is None: 58 | name = cls.__name__ 59 | assert name not in REGISTERED_DATASET_CLASSES, f"exist class: {REGISTERED_DATASET_CLASSES}" 60 | REGISTERED_DATASET_CLASSES[name] = cls 61 | return cls 62 | 63 | 64 | def get_model_class(name): 65 | global REGISTERED_DATASET_CLASSES 66 | assert name in REGISTERED_DATASET_CLASSES, f"available class: {REGISTERED_DATASET_CLASSES}" 67 | return REGISTERED_DATASET_CLASSES[name] 68 | 69 | 70 | @register_dataset 71 | class voxel_dataset(data.Dataset): 72 | def __init__(self, in_dataset, grid_size, rotate_aug=False, flip_aug=False, ignore_label=255, return_test=False, 73 | fixed_volume_space=False, max_volume_space=[50, 50, 1.5], min_volume_space=[-50, -50, -3]): 74 | 'Initialization' 75 | self.point_cloud_dataset = in_dataset 76 | self.grid_size = np.asarray(grid_size) 77 | self.rotate_aug = rotate_aug 78 | self.ignore_label = ignore_label 79 | self.return_test = return_test 80 | self.flip_aug = flip_aug 81 | self.fixed_volume_space = fixed_volume_space 82 | self.max_volume_space = max_volume_space 83 | self.min_volume_space = min_volume_space 84 | 85 | def __len__(self): 86 | 'Denotes the total number of samples' 87 | return len(self.point_cloud_dataset) 88 | 89 | def __getitem__(self, index): 90 | 'Generates one sample of data' 91 | data,prev_data,pose_list,stride_list = self.point_cloud_dataset[index] 92 | 93 | if len(data) == 4: 94 | xyz, labels, sig, origin_len = data 95 | prev_xyz, prev_labels, prev_sig,_ = prev_data 96 | if len(sig.shape) == 2: sig = np.squeeze(sig) 97 | else: 98 | raise Exception('Return invalid data tuple') 99 | 100 | origin_len = len(xyz) 101 | max_bound = np.asarray(self.max_volume_space) 102 | min_bound = np.asarray(self.min_volume_space) 103 | 104 | ### Cut point cloud and segmentation label for valid range 105 | xyz0 = xyz 106 | for ci in range(3): 107 | xyz0[xyz[:, ci] < min_bound[ci], :] = 1000 108 | xyz0[xyz[:, ci] > max_bound[ci], :] = 1000 109 | valid_inds = xyz0[:, 0] != 1000 110 | xyz = xyz[valid_inds, :] 111 | sig = sig[valid_inds] 112 | 113 | ### post_scan_preprocess ### 114 | 115 | prev_raws=[] 116 | prev_sigs=[] 117 | prev_vox=[] 118 | prev_velodyne=[] 119 | prev_trans_sigs=[] 120 | 121 | if len(pose_list)>1: 122 | prev_frames_num=len(prev_xyz) 123 | for f_idx in range(prev_frames_num): 124 | 125 | ### cut point clound 126 | prev_xyz0 = prev_xyz[f_idx] 127 | prev_xyz_sin=prev_xyz[f_idx] 128 | prev_sig_sin=prev_sig[f_idx] 129 | for ci in range(3): 130 | prev_xyz0[prev_xyz_sin[:, ci] < min_bound[ci], :] = 1000 131 | prev_xyz0[prev_xyz_sin[:, ci] > max_bound[ci], :] = 1000 132 | valid_inds = prev_xyz0[:, 0] != 1000 133 | prev_xyz_single = prev_xyz_sin[valid_inds, :] 134 | prev_sig_single = prev_sig_sin[valid_inds] 135 | prev_raws.append(prev_xyz_single) 136 | prev_sigs.append(prev_sig_single) 137 | 138 | ### Transform prev point cloud 139 | transformed_prev_xyz = frame_transform_scan(prev_xyz[f_idx], pose_list[0], pose_list[f_idx+1]) 140 | transformed_prev_velodyne=frame_transform_scan(np.zeros([1,4]), pose_list[0], pose_list[f_idx+1]) 141 | 142 | ### Cut prev point cloud 143 | prev_xyz0 = transformed_prev_xyz 144 | for ci in range(3): 145 | prev_xyz0[transformed_prev_xyz[:, ci] < min_bound[ci], :] = 1000 146 | prev_xyz0[transformed_prev_xyz[:, ci] > max_bound[ci], :] = 1000 147 | prev_valid_inds = prev_xyz0[:, 0] != 1000 148 | transformed_prev_xyz = transformed_prev_xyz[prev_valid_inds, :] 149 | transformed_prev_sig_single = prev_sig_sin[prev_valid_inds] 150 | 
prev_vox.append(transformed_prev_xyz) 151 | prev_trans_sigs.append(transformed_prev_sig_single) 152 | prev_velodyne.append(transformed_prev_velodyne) 153 | 154 | # transpose centre coord for x axis 155 | x_bias = (self.max_volume_space[0] - self.min_volume_space[0])/2 156 | min_bound[0] -= x_bias 157 | max_bound[0] -= x_bias 158 | xyz[:, 0] -= x_bias 159 | if len(pose_list)>1: 160 | for f_idx in range(prev_frames_num): 161 | prev_raws_sim=prev_raws[f_idx] 162 | prev_vox_sin=prev_vox[f_idx] 163 | prev_raws_sim[:, 0]-= x_bias 164 | prev_vox_sin[:, 0]-= x_bias 165 | prev_velodyne_sin=prev_velodyne[f_idx] 166 | prev_velodyne_sin=prev_velodyne_sin[:,:3] 167 | prev_velodyne_sin[:,0]-= x_bias 168 | prev_vox[f_idx]=prev_vox_sin 169 | prev_velodyne[f_idx]=prev_velodyne_sin 170 | prev_raws[f_idx]=prev_raws_sim 171 | 172 | # get grid index 173 | crop_range = max_bound - min_bound 174 | cur_grid_size = self.grid_size 175 | intervals = crop_range / (cur_grid_size - 1) 176 | if (intervals == 0).any(): print("Zero interval!") 177 | 178 | grid_ind = (np.floor((np.clip(xyz, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 179 | 180 | voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 181 | return_xyz = xyz - voxel_centers 182 | return_xyz = np.concatenate((return_xyz, xyz), axis=1) 183 | return_fea = np.concatenate((return_xyz, sig[..., np.newaxis]), axis=1) # 7:xyz_bias + xyz + intensity 184 | 185 | 186 | prev_grid_list=[] 187 | prev_fea_list=[] 188 | prev_label_list=[] 189 | prev_transformed_fea_list=[] 190 | prev_transformed_grid_list=[] 191 | if len(pose_list)>1: 192 | for f_idx in range(prev_frames_num): 193 | single_prev_xyz=prev_raws[f_idx] 194 | single_prev_sig=prev_sigs[f_idx] 195 | single_trans_prev_sig=prev_trans_sigs[f_idx] 196 | single_prev_vox=prev_vox[f_idx] 197 | 198 | prev_grid_ind = (np.floor((np.clip(single_prev_xyz, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 199 | prev_voxel_centers = (prev_grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 200 | return_prev_xyz = single_prev_xyz - prev_voxel_centers 201 | return_prev_xyz = np.concatenate((return_prev_xyz, single_prev_xyz), axis=1) 202 | return_prev_fea = np.concatenate((return_prev_xyz, single_prev_sig[..., np.newaxis]), axis=1) # 7:xyz_bias + xyz + intensity 203 | 204 | prev_transformed_grid_ind = (np.floor((np.clip(single_prev_vox, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 205 | prev_transformed_voxel_centers = (prev_transformed_grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 206 | return_transformed_prev_xyz = single_prev_vox - prev_transformed_voxel_centers 207 | return_transformed_prev_xyz = np.concatenate((return_transformed_prev_xyz, single_prev_vox), axis=1) 208 | return_transformed_prev_fea = np.concatenate((return_transformed_prev_xyz, single_trans_prev_sig[..., np.newaxis]), axis=1) # 7:xyz_bias + xyz + intensity 209 | 210 | prev_fea_list.append(return_prev_fea) 211 | prev_transformed_fea_list.append(return_transformed_prev_fea) 212 | prev_transformed_grid_list.append(prev_transformed_grid_ind) 213 | prev_grid_list.append(prev_grid_ind) 214 | prev_label_list.append(prev_labels[f_idx]) 215 | 216 | dim_array = np.ones(len(self.grid_size) + 1, int) 217 | dim_array[0] = -1 218 | voxel_position = np.indices(self.grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array) 219 | processed_label = labels # voxel labels 220 | 221 | data_tuple = (voxel_position, processed_label) 222 | 223 | vox_from_list=[] 224 | 
vox_to_list=[] 225 | if len(pose_list)>1: 226 | for f_idx in range(prev_frames_num): 227 | vox_grid_from,vox_grid_to= from_voxel_to_voxel(self.max_volume_space,self.min_volume_space,intervals,pose_list[0],pose_list[f_idx+1]) 228 | vox_from_list.append(vox_grid_from) 229 | vox_to_list.append(vox_grid_to) 230 | 231 | else: 232 | vox_grid_to=np.array([[0,0,0]]) 233 | vox_grid_from=np.array([[0,0,0]]) 234 | vox_from_list.append(vox_grid_from) 235 | vox_to_list.append(vox_grid_to) 236 | 237 | 238 | if self.return_test: 239 | data_tuple += (grid_ind, labels, return_fea, index) 240 | else: 241 | data_tuple += (grid_ind, labels, return_fea) 242 | 243 | data_tuple += (origin_len,prev_velodyne,vox_to_list,vox_from_list,prev_grid_list,prev_fea_list,prev_transformed_grid_list,prev_transformed_fea_list,prev_label_list,min_bound,max_bound,intervals,stride_list) 244 | 245 | return data_tuple 246 | 247 | 248 | # transformation between Cartesian coordinates and polar coordinates 249 | def cart2polar(input_xyz): 250 | rho = np.sqrt(input_xyz[:, 0] ** 2 + input_xyz[:, 1] ** 2) 251 | phi = np.arctan2(input_xyz[:, 1], input_xyz[:, 0]) 252 | return np.stack((rho, phi, input_xyz[:, 2]), axis=1) 253 | 254 | 255 | def polar2cat(input_xyz_polar): 256 | # print(input_xyz_polar.shape) 257 | x = input_xyz_polar[0] * np.cos(input_xyz_polar[1]) 258 | y = input_xyz_polar[0] * np.sin(input_xyz_polar[1]) 259 | return np.stack((x, y, input_xyz_polar[2]), axis=0) 260 | 261 | @nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])', nopython=True, cache=True, parallel=False) 262 | def nb_process_label(processed_label, sorted_label_voxel_pair): 263 | label_size = 256 264 | counter = np.zeros((label_size,), dtype=np.uint16) 265 | counter[sorted_label_voxel_pair[0, 3]] = 1 266 | cur_sear_ind = sorted_label_voxel_pair[0, :3] 267 | for i in range(1, sorted_label_voxel_pair.shape[0]): 268 | cur_ind = sorted_label_voxel_pair[i, :3] 269 | if not np.all(np.equal(cur_ind, cur_sear_ind)): 270 | processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter) 271 | counter = np.zeros((label_size,), dtype=np.uint16) 272 | cur_sear_ind = cur_ind 273 | counter[sorted_label_voxel_pair[i, 3]] += 1 274 | processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter) 275 | return processed_label 276 | 277 | 278 | 279 | def collate_fn_BEV_ms_tta(data): 280 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 281 | grid_ind_stack = [d[2] for d in data] 282 | xyz = [d[4] for d in data] 283 | index = [d[5] for d in data] 284 | prev_velodyne=[d[7] for d in data] 285 | vox_grid_to=[d[8] for d in data] 286 | vox_grid_from=[d[9] for d in data] 287 | prev_grid_ind=[d[10] for d in data] 288 | prev_feat=[d[11] for d in data] 289 | prev_trans_ind=[d[12] for d in data] 290 | prev_trans_feat=[d[13] for d in data] 291 | min_bound=[d[15] for d in data] 292 | max_bound=[d[16] for d in data] 293 | interval=[d[17] for d in data] 294 | strides=[d[18] for d in data] 295 | current_frame={'grid_ind':grid_ind_stack, 'pt_feat': xyz, 'index': index, 'gt':torch.from_numpy(label2stack),'min_bound':min_bound,'max_bound':max_bound,'interval':interval,'stride':strides[0]} 296 | prev_frame={'vox_grid_to':vox_grid_to[0], 'vox_grid_from': vox_grid_from[0], 'grid_ind': prev_grid_ind[0], 'pt_feat':prev_feat[0],'trans_grid_ind':prev_trans_ind[0],'trans_pt_feat':prev_trans_feat[0],'lidar_pose':prev_velodyne[0]} 297 | 298 | return current_frame,prev_frame 299 | 300 | 301 | def collate_fn_BEV(data): 302 | data2stack = np.stack([d[0] for d in 
data]).astype(np.float32) 303 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 304 | grid_ind_stack = [d[2] for d in data] 305 | point_label = [d[3] for d in data] 306 | xyz = [d[4] for d in data] 307 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz 308 | 309 | def collate_fn_BEV_tta(data): 310 | voxel_label = [] 311 | for da1 in data: 312 | for da2 in da1: 313 | voxel_label.append(da2[1]) 314 | grid_ind_stack = [] 315 | for da1 in data: 316 | for da2 in da1: 317 | grid_ind_stack.append(da2[2]) 318 | point_label = [] 319 | for da1 in data: 320 | for da2 in da1: 321 | point_label.append(da2[3]) 322 | xyz = [] 323 | for da1 in data: 324 | for da2 in da1: 325 | xyz.append(da2[4]) 326 | index = [] 327 | for da1 in data: 328 | for da2 in da1: 329 | index.append(da2[5]) 330 | return xyz, xyz, grid_ind_stack, point_label, xyz, index 331 | 332 | def collate_fn_BEV_ms(data): 333 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 334 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 335 | grid_ind_stack = [d[2] for d in data] 336 | point_label = [d[3] for d in data] 337 | xyz = [d[4] for d in data] 338 | index = [d[5] for d in data] 339 | origin_len = [d[6] for d in data] 340 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz, index, origin_len 341 | 342 | -------------------------------------------------------------------------------- /dataloader/dataset_semantickitti_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | 4 | """ 5 | SemKITTI dataloader 6 | """ 7 | import numpy as np 8 | import torch 9 | import numba as nb 10 | from torch.utils import data 11 | import time 12 | import random 13 | 14 | from utils.util import Bresenham3D 15 | REGISTERED_DATASET_CLASSES = {} 16 | 17 | def from_voxel_to_voxel(max_volume_space,min_volume_space,intervals,pose0,pose): 18 | x_bias = (max_volume_space[0] - min_volume_space[0])/2 19 | max_bound = np.asarray(max_volume_space) 20 | min_bound = np.asarray(min_volume_space) 21 | min_bound[0] -= x_bias 22 | max_bound[0] -= x_bias 23 | max_bound2 = np.asarray(max_volume_space) 24 | min_bound2 = np.asarray(min_volume_space) 25 | 26 | voxel_grid = np.indices((256, 256, 32)).transpose(1, 2, 3, 0) 27 | voxel_grid = voxel_grid.reshape(-1,3) 28 | full_voxel_center=(voxel_grid.astype(np.float32) + 0.5) * intervals + min_bound 29 | full_voxel_center[:,0]+= x_bias 30 | current_vox_center=frame_transform_scan(full_voxel_center,pose0,pose) 31 | current_vox_center=np.concatenate([current_vox_center,voxel_grid],1) 32 | vox_xyz0 = current_vox_center 33 | for ci in range(3): 34 | vox_xyz0[current_vox_center[:, ci] < min_bound2[ci], :] = 1000 35 | vox_xyz0[current_vox_center[:, ci] > max_bound2[ci], :] = 1000 36 | vox_valid_inds = vox_xyz0[:, 0] != 1000 37 | current_vox_center = current_vox_center[vox_valid_inds, :] 38 | current_vox_center[:, 0]-= x_bias # current_vox_center is 39 | vox_grid_from=current_vox_center[:,-3:] 40 | vox_grid_to=current_vox_center[:,:3] 41 | vox_grid_to = (np.floor((np.clip(vox_grid_to, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 42 | return vox_grid_from,vox_grid_to 43 | 44 | def frame_transform_scan(points, pose0, pose): 45 | 46 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1]))) 47 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1) 48 | new_points = new_points[:, :3] 49 | 
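# ------------------------------------------------------------------
# [editor's sketch, not part of the original file] What the
# (vox_grid_from, vox_grid_to) pairs returned by from_voxel_to_voxel above
# are for: they pair up voxel indices that correspond under the relative
# pose, so a per-voxel value grid from one frame can be scattered into the
# other (run_tta_test.py uses exactly this indexing for the auxiliary
# pseudo-GT):
import numpy as np

src = np.zeros((4, 4, 2), dtype=int)
src[1, 2, 0] = 5                           # some per-voxel value (e.g. a class id)
vox_from = np.array([[1, 2, 0]])           # index in the source grid
vox_to = np.array([[2, 2, 0]])             # corresponding index in the target grid
dst = np.full((4, 4, 2), 255, dtype=int)   # 255 = ignore
dst[vox_to[:, 0], vox_to[:, 1], vox_to[:, 2]] = src[vox_from[:, 0], vox_from[:, 1], vox_from[:, 2]]
assert dst[2, 2, 0] == 5
# ------------------------------------------------------------------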
new_coords = new_points - pose0[:3, 3] 50 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1) 51 | new_coords = np.hstack((new_coords, points[:, 3:])) 52 | 53 | return new_coords 54 | 55 | def register_dataset(cls, name=None): 56 | global REGISTERED_DATASET_CLASSES 57 | if name is None: 58 | name = cls.__name__ 59 | assert name not in REGISTERED_DATASET_CLASSES, f"exist class: {REGISTERED_DATASET_CLASSES}" 60 | REGISTERED_DATASET_CLASSES[name] = cls 61 | return cls 62 | 63 | 64 | def get_model_class(name): 65 | global REGISTERED_DATASET_CLASSES 66 | assert name in REGISTERED_DATASET_CLASSES, f"available class: {REGISTERED_DATASET_CLASSES}" 67 | return REGISTERED_DATASET_CLASSES[name] 68 | 69 | 70 | @register_dataset 71 | class voxel_dataset(data.Dataset): 72 | def __init__(self, in_dataset, grid_size, rotate_aug=False, flip_aug=False, ignore_label=255, return_test=False, 73 | fixed_volume_space=False, max_volume_space=[50, 50, 1.5], min_volume_space=[-50, -50, -3]): 74 | 'Initialization' 75 | self.point_cloud_dataset = in_dataset 76 | self.grid_size = np.asarray(grid_size) 77 | self.rotate_aug = rotate_aug 78 | self.ignore_label = ignore_label 79 | self.return_test = return_test 80 | self.flip_aug = flip_aug 81 | self.fixed_volume_space = fixed_volume_space 82 | self.max_volume_space = max_volume_space 83 | self.min_volume_space = min_volume_space 84 | 85 | def __len__(self): 86 | 'Denotes the total number of samples' 87 | return len(self.point_cloud_dataset) 88 | 89 | def __getitem__(self, index): 90 | 'Generates one sample of data' 91 | data,prev_data,pose_list,stride_list = self.point_cloud_dataset[index] 92 | 93 | if len(data) == 4: 94 | xyz, labels, sig, origin_len = data 95 | prev_xyz, prev_labels, prev_sig,_ = prev_data 96 | if len(sig.shape) == 2: sig = np.squeeze(sig) 97 | else: 98 | raise Exception('Return invalid data tuple') 99 | 100 | origin_len = len(xyz) 101 | max_bound = np.asarray(self.max_volume_space) 102 | min_bound = np.asarray(self.min_volume_space) 103 | 104 | ### Cut point cloud and segmentation label for valid range 105 | xyz0 = xyz 106 | for ci in range(3): 107 | xyz0[xyz[:, ci] < min_bound[ci], :] = 1000 108 | xyz0[xyz[:, ci] > max_bound[ci], :] = 1000 109 | valid_inds = xyz0[:, 0] != 1000 110 | xyz = xyz[valid_inds, :] 111 | sig = sig[valid_inds] 112 | 113 | 114 | ### post_scan_preprocess ### 115 | 116 | prev_raws=[] 117 | prev_sigs=[] 118 | prev_vox=[] 119 | prev_velodyne=[] 120 | prev_trans_sigs=[] 121 | 122 | if len(pose_list)>1: 123 | prev_frames_num=len(prev_xyz) 124 | for f_idx in range(prev_frames_num): 125 | 126 | ### cut point clound 127 | prev_xyz0 = prev_xyz[f_idx] 128 | prev_xyz_sin=prev_xyz[f_idx] 129 | prev_sig_sin=prev_sig[f_idx] 130 | for ci in range(3): 131 | prev_xyz0[prev_xyz_sin[:, ci] < min_bound[ci], :] = 1000 132 | prev_xyz0[prev_xyz_sin[:, ci] > max_bound[ci], :] = 1000 133 | valid_inds = prev_xyz0[:, 0] != 1000 134 | prev_xyz_single = prev_xyz_sin[valid_inds, :] 135 | prev_sig_single = prev_sig_sin[valid_inds] 136 | prev_raws.append(prev_xyz_single) 137 | prev_sigs.append(prev_sig_single) 138 | 139 | ### Transform prev point cloud 140 | transformed_prev_xyz = frame_transform_scan(prev_xyz[f_idx], pose_list[0], pose_list[f_idx+1]) 141 | transformed_prev_velodyne=frame_transform_scan(np.zeros([1,4]), pose_list[0], pose_list[f_idx+1]) 142 | 143 | ### Cut prev point cloud 144 | prev_xyz0 = transformed_prev_xyz 145 | for ci in range(3): 146 | prev_xyz0[transformed_prev_xyz[:, ci] < min_bound[ci], :] = 1000 147 | 
prev_xyz0[transformed_prev_xyz[:, ci] > max_bound[ci], :] = 1000 148 | prev_valid_inds = prev_xyz0[:, 0] != 1000 149 | transformed_prev_xyz = transformed_prev_xyz[prev_valid_inds, :] 150 | transformed_prev_sig_single = prev_sig_sin[prev_valid_inds] 151 | prev_vox.append(transformed_prev_xyz) 152 | prev_trans_sigs.append(transformed_prev_sig_single) 153 | prev_velodyne.append(transformed_prev_velodyne) 154 | 155 | # transpose centre coord for x axis 156 | x_bias = (self.max_volume_space[0] - self.min_volume_space[0])/2 157 | min_bound[0] -= x_bias 158 | max_bound[0] -= x_bias 159 | xyz[:, 0] -= x_bias 160 | if len(pose_list)>1: 161 | for f_idx in range(prev_frames_num): 162 | prev_raws_sim=prev_raws[f_idx] 163 | prev_vox_sin=prev_vox[f_idx] 164 | prev_raws_sim[:, 0]-= x_bias 165 | prev_vox_sin[:, 0]-= x_bias 166 | prev_velodyne_sin=prev_velodyne[f_idx] 167 | prev_velodyne_sin=prev_velodyne_sin[:,:3] 168 | prev_velodyne_sin[:,0]-= x_bias 169 | prev_vox[f_idx]=prev_vox_sin 170 | prev_velodyne[f_idx]=prev_velodyne_sin 171 | prev_raws[f_idx]=prev_raws_sim 172 | 173 | # get grid index 174 | crop_range = max_bound - min_bound 175 | cur_grid_size = self.grid_size 176 | intervals = crop_range / (cur_grid_size - 1) 177 | if (intervals == 0).any(): print("Zero interval!") 178 | 179 | grid_ind = (np.floor((np.clip(xyz, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 180 | 181 | voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 182 | return_xyz = xyz - voxel_centers 183 | return_xyz = np.concatenate((return_xyz, xyz), axis=1) 184 | return_fea = np.concatenate((return_xyz, sig[..., np.newaxis]), axis=1) # 7:xyz_bias + xyz + intensity 185 | 186 | 187 | prev_grid_list=[] 188 | prev_fea_list=[] 189 | prev_label_list=[] 190 | prev_transformed_fea_list=[] 191 | prev_transformed_grid_list=[] 192 | if len(pose_list)>1: 193 | for f_idx in range(prev_frames_num): 194 | single_prev_xyz=prev_raws[f_idx] 195 | single_prev_sig=prev_sigs[f_idx] 196 | single_trans_prev_sig=prev_trans_sigs[f_idx] 197 | single_prev_vox=prev_vox[f_idx] 198 | 199 | prev_grid_ind = (np.floor((np.clip(single_prev_xyz, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 200 | prev_voxel_centers = (prev_grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 201 | return_prev_xyz = single_prev_xyz - prev_voxel_centers 202 | return_prev_xyz = np.concatenate((return_prev_xyz, single_prev_xyz), axis=1) 203 | return_prev_fea = np.concatenate((return_prev_xyz, single_prev_sig[..., np.newaxis]), axis=1) # 7:xyz_bias + xyz + intensity 204 | 205 | prev_transformed_grid_ind = (np.floor((np.clip(single_prev_vox, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 206 | prev_transformed_voxel_centers = (prev_transformed_grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 207 | return_transformed_prev_xyz = single_prev_vox - prev_transformed_voxel_centers 208 | return_transformed_prev_xyz = np.concatenate((return_transformed_prev_xyz, single_prev_vox), axis=1) 209 | return_transformed_prev_fea = np.concatenate((return_transformed_prev_xyz, single_trans_prev_sig[..., np.newaxis]), axis=1) # 7:xyz_bias + xyz + intensity 210 | 211 | prev_fea_list.append(return_prev_fea) 212 | prev_transformed_fea_list.append(return_transformed_prev_fea) 213 | prev_transformed_grid_list.append(prev_transformed_grid_ind) 214 | prev_grid_list.append(prev_grid_ind) 215 | prev_label_list.append(prev_labels[f_idx]) 216 | 217 | dim_array = np.ones(len(self.grid_size) + 1, int) 218 | dim_array[0] = -1 219 | 
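# ------------------------------------------------------------------
# [editor's sketch, not part of the original file] The dim_array trick set up
# just above reshapes per-axis vectors so they broadcast against np.indices,
# giving every voxel its metric coordinate in one vectorized expression:
import numpy as np

grid_size = np.array([4, 4, 2])
intervals = np.array([0.5, 0.5, 1.0])
min_bound = np.array([0.0, 0.0, -1.0])
dim_array = np.ones(len(grid_size) + 1, int)
dim_array[0] = -1                          # -> [-1, 1, 1, 1]
pos = np.indices(grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array)
assert pos.shape == (3, 4, 4, 2)           # channel k = k-th coordinate of each voxel
assert pos[0, 1, 0, 0] == 0.5 and pos[2, 0, 0, 1] == 0.0
# ------------------------------------------------------------------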
voxel_position = np.indices(self.grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array) 220 | processed_label = labels # voxel labels 221 | 222 | data_tuple = (voxel_position, processed_label) 223 | 224 | vox_from_list=[] 225 | vox_to_list=[] 226 | if len(pose_list)>1: 227 | for f_idx in range(prev_frames_num): 228 | vox_grid_from,vox_grid_to= from_voxel_to_voxel(self.max_volume_space,self.min_volume_space,intervals,pose_list[0],pose_list[f_idx+1]) 229 | vox_from_list.append(vox_grid_from) 230 | vox_to_list.append(vox_grid_to) 231 | 232 | else: 233 | vox_grid_to=np.array([[0,0,0]]) 234 | vox_grid_from=np.array([[0,0,0]]) 235 | vox_from_list.append(vox_grid_from) 236 | vox_to_list.append(vox_grid_to) 237 | 238 | 239 | if self.return_test: 240 | data_tuple += (grid_ind, labels, return_fea, index) 241 | else: 242 | data_tuple += (grid_ind, labels, return_fea) 243 | 244 | data_tuple += (origin_len,prev_velodyne,vox_to_list,vox_from_list,prev_grid_list,prev_fea_list,prev_transformed_grid_list,prev_transformed_fea_list,prev_label_list,min_bound,max_bound,intervals,stride_list) 245 | 246 | return data_tuple 247 | 248 | 249 | # transformation between Cartesian coordinates and polar coordinates 250 | def cart2polar(input_xyz): 251 | rho = np.sqrt(input_xyz[:, 0] ** 2 + input_xyz[:, 1] ** 2) 252 | phi = np.arctan2(input_xyz[:, 1], input_xyz[:, 0]) 253 | return np.stack((rho, phi, input_xyz[:, 2]), axis=1) 254 | 255 | 256 | def polar2cat(input_xyz_polar): 257 | # print(input_xyz_polar.shape) 258 | x = input_xyz_polar[0] * np.cos(input_xyz_polar[1]) 259 | y = input_xyz_polar[0] * np.sin(input_xyz_polar[1]) 260 | return np.stack((x, y, input_xyz_polar[2]), axis=0) 261 | 262 | @nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])', nopython=True, cache=True, parallel=False) 263 | def nb_process_label(processed_label, sorted_label_voxel_pair): 264 | label_size = 256 265 | counter = np.zeros((label_size,), dtype=np.uint16) 266 | counter[sorted_label_voxel_pair[0, 3]] = 1 267 | cur_sear_ind = sorted_label_voxel_pair[0, :3] 268 | for i in range(1, sorted_label_voxel_pair.shape[0]): 269 | cur_ind = sorted_label_voxel_pair[i, :3] 270 | if not np.all(np.equal(cur_ind, cur_sear_ind)): 271 | processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter) 272 | counter = np.zeros((label_size,), dtype=np.uint16) 273 | cur_sear_ind = cur_ind 274 | counter[sorted_label_voxel_pair[i, 3]] += 1 275 | processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter) 276 | return processed_label 277 | 278 | 279 | 280 | def collate_fn_BEV_ms_tta(data): 281 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 282 | grid_ind_stack = [d[2] for d in data] 283 | xyz = [d[4] for d in data] 284 | index = [d[5] for d in data] 285 | prev_velodyne=[d[7] for d in data] 286 | vox_grid_to=[d[8] for d in data] 287 | vox_grid_from=[d[9] for d in data] 288 | prev_grid_ind=[d[10] for d in data] 289 | prev_feat=[d[11] for d in data] 290 | prev_trans_ind=[d[12] for d in data] 291 | prev_trans_feat=[d[13] for d in data] 292 | min_bound=[d[15] for d in data] 293 | max_bound=[d[16] for d in data] 294 | interval=[d[17] for d in data] 295 | strides=[d[18] for d in data] 296 | current_frame={'grid_ind':grid_ind_stack, 'pt_feat': xyz, 'index': index, 'gt':torch.from_numpy(label2stack),'min_bound':min_bound,'max_bound':max_bound,'interval':interval,'stride':strides[0]} 297 | prev_frame={'vox_grid_to':vox_grid_to[0], 'vox_grid_from': vox_grid_from[0], 'grid_ind': prev_grid_ind[0], 
'pt_feat':prev_feat[0],'trans_grid_ind':prev_trans_ind[0],'trans_pt_feat':prev_trans_feat[0],'lidar_pose':prev_velodyne[0]} 298 | 299 | return current_frame,prev_frame 300 | 301 | 302 | def collate_fn_BEV(data): 303 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 304 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 305 | grid_ind_stack = [d[2] for d in data] 306 | point_label = [d[3] for d in data] 307 | xyz = [d[4] for d in data] 308 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz 309 | 310 | def collate_fn_BEV_tta(data): 311 | voxel_label = [] 312 | for da1 in data: 313 | for da2 in da1: 314 | voxel_label.append(da2[1]) 315 | grid_ind_stack = [] 316 | for da1 in data: 317 | for da2 in da1: 318 | grid_ind_stack.append(da2[2]) 319 | point_label = [] 320 | for da1 in data: 321 | for da2 in da1: 322 | point_label.append(da2[3]) 323 | xyz = [] 324 | for da1 in data: 325 | for da2 in da1: 326 | xyz.append(da2[4]) 327 | index = [] 328 | for da1 in data: 329 | for da2 in da1: 330 | index.append(da2[5]) 331 | return xyz, xyz, grid_ind_stack, point_label, xyz, index 332 | 333 | def collate_fn_BEV_ms(data): 334 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 335 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 336 | grid_ind_stack = [d[2] for d in data] 337 | point_label = [d[3] for d in data] 338 | xyz = [d[4] for d in data] 339 | index = [d[5] for d in data] 340 | origin_len = [d[6] for d in data] 341 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz, index, origin_len 342 | 343 | -------------------------------------------------------------------------------- /network/segmentator_3d_asymm_spconv_unlock.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge, Xzy 3 | # @file: segmentator_3d_asymm_spconv.py 4 | 5 | import numpy as np 6 | import spconv 7 | import torch 8 | from torch import nn 9 | 10 | 11 | def extract_nonzero_features(x): 12 | device = x.device 13 | nonzero_index = torch.sum(torch.abs(x), dim=1).nonzero() 14 | coords = nonzero_index.type(torch.int32).to(device) 15 | channels = int(x.shape[1]) 16 | features = x.permute(0, 2, 3, 4, 1).reshape(-1, channels) 17 | features = features[torch.sum(torch.abs(features), dim=1).nonzero(), :] 18 | features = features.squeeze(1).to(device) 19 | coords, _, _ = torch.unique(coords, return_inverse=True, return_counts=True, dim=0) 20 | return coords, features 21 | 22 | def no_extraction(x): 23 | device = x.device 24 | total_index=torch.ones_like(torch.sum(torch.abs(x), dim=1)).nonzero() 25 | coords = total_index.type(torch.int32).to(device) 26 | channels = int(x.shape[1]) 27 | features = x.permute(0, 2, 3, 4, 1).reshape(-1, channels) 28 | total_index2=torch.ones_like(torch.sum(torch.abs(features), dim=1)).nonzero() 29 | features = features[total_index2, :] 30 | features = features.squeeze(1).to(device) 31 | coords, _, _ = torch.unique(coords, return_inverse=True, return_counts=True, dim=0) 32 | return coords, features 33 | 34 | # def manual_extraction(x, idx_for_extract): 35 | # device = x.device 36 | # nonzero_index = torch.sum(torch.abs(x), dim=1).nonzero() 37 | # manually_added_index = (idx_for_extract==1).nonzero() 38 | # total_index = torch.cat((nonzero_index, manually_added_index), dim=0) 39 | # total_index, _, _ = torch.unique(total_index, return_inverse=True, return_counts=True, dim=0) 40 | # coords = 
total_index.type(torch.int32) 41 | 42 | # channels = int(x.shape[1]) 43 | # features = x.permute(0, 2, 3, 4, 1).reshape(-1, channels) 44 | # nonzero_index2 = torch.sum(torch.abs(features), dim=1).nonzero() 45 | # manually_added_index2 = (idx_for_extract.flatten()==1).nonzero() 46 | # total_index2 = torch.cat((nonzero_index2, manually_added_index2), dim=0) 47 | # total_index2, _, _ = torch.unique(total_index2, return_inverse=True, return_counts=True, dim=0) 48 | # features = features[total_index2, :] 49 | # features = features.squeeze(1).to(device) 50 | # # coords, _, _ = torch.unique(coords, return_inverse=True, return_counts=True, dim=0) 51 | # return coords, features 52 | 53 | def manual_extraction(x, proposure): 54 | device = x.device 55 | #proposure.shape -> (n,3) 56 | x_proposure=x.clone().detach() 57 | # import pdb;pdb.set_trace() 58 | if proposure!=None: 59 | x_proposure[0,0,proposure[:,0],proposure[:,1],proposure[:,2]]=1 60 | total_index=torch.sum(torch.abs(x_proposure), dim=1).nonzero() 61 | 62 | else: 63 | total_index=torch.ones_like(torch.sum(torch.abs(x), dim=1)).nonzero() 64 | 65 | coords = total_index.type(torch.int32).to(device) 66 | channels = int(x.shape[1]) 67 | features = x.permute(0, 2, 3, 4, 1).reshape(-1, channels) 68 | features_proposure = x_proposure.permute(0, 2, 3, 4, 1).reshape(-1, channels) 69 | if proposure==None: 70 | total_index=torch.ones_like(torch.sum(torch.abs(features), dim=1)).nonzero() 71 | else: 72 | total_index=torch.sum(torch.abs(features_proposure), dim=1).nonzero() 73 | features = features[total_index, :] 74 | features = features.squeeze(1).to(device) 75 | coords, _, _ = torch.unique(coords, return_inverse=True, return_counts=True, dim=0) 76 | return coords, features 77 | 78 | class Asymm_3d_spconv(nn.Module): 79 | def __init__(self, 80 | output_shape, 81 | use_norm=True, 82 | num_input_features=128, 83 | nclasses=20, n_height=32, strict=False, init_size=16): 84 | super(Asymm_3d_spconv, self).__init__() 85 | self.nclasses = nclasses 86 | self.nheight = n_height 87 | self.strict = False 88 | 89 | sparse_shape = np.array(output_shape) 90 | self.sparse_shape = sparse_shape 91 | 92 | ### Completion sub-network 93 | mybias = False # False 94 | chs = [init_size, init_size*1, init_size*1, init_size*1] 95 | self.a_conv1 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 3, 1, padding=1, bias=mybias), nn.ReLU()) 96 | self.a_conv2 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 3, 1, padding=1, bias=mybias), nn.ReLU()) 97 | self.a_conv3 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 5, 1, padding=2, bias=mybias), nn.ReLU()) 98 | self.a_conv4 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 7, 1, padding=3, bias=mybias), nn.ReLU()) 99 | self.a_conv5 = nn.Sequential(nn.Conv3d(chs[1]*3, chs[1], 3, 1, padding=1, bias=mybias), nn.ReLU()) 100 | self.a_conv6 = nn.Sequential(nn.Conv3d(chs[1]*3, chs[1], 5, 1, padding=2, bias=mybias), nn.ReLU()) 101 | self.a_conv7 = nn.Sequential(nn.Conv3d(chs[1]*3, chs[1], 7, 1, padding=3, bias=mybias), nn.ReLU()) 102 | self.ch_conv1 = nn.Sequential(nn.Conv3d(chs[1]*7, chs[0], kernel_size=1, stride=1, bias=mybias), nn.ReLU()) 103 | self.res_1 = nn.Sequential(nn.Conv3d(chs[0], chs[0], 3, 1, padding=1, bias=mybias), nn.ReLU()) 104 | self.res_2 = nn.Sequential(nn.Conv3d(chs[0], chs[0], 5, 1, padding=2, bias=mybias), nn.ReLU()) 105 | self.res_3 = nn.Sequential(nn.Conv3d(chs[0], chs[0], 7, 1, padding=3, bias=mybias), nn.ReLU()) 106 | 107 | ### Segmentation sub-network 108 | self.downCntx = ResContextBlock(num_input_features, init_size, indice_key="pre") 109 | 
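# ------------------------------------------------------------------
# [editor's sketch, not part of the original file] The dense-to-sparse handoff
# implemented by extract_nonzero_features above (and used in forward below):
# occupied voxel coordinates come from the channel-summed magnitude, and the
# matching rows of the flattened feature map become the sparse features.
import torch

x = torch.zeros(1, 4, 3, 3, 2)                         # (B, C, X, Y, Z) dense grid
x[0, :, 1, 2, 0] = torch.tensor([1.0, -2.0, 0.5, 3.0])
coords = torch.sum(torch.abs(x), dim=1).nonzero().int()          # (n, 4): b, x, y, z
feats = x.permute(0, 2, 3, 4, 1).reshape(-1, x.shape[1])
feats = feats[torch.sum(torch.abs(feats), dim=1).nonzero().squeeze(1)]
assert coords.shape == (1, 4) and feats.shape == (1, 4)
# coords / feats are what spconv.SparseConvTensor takes to re-sparsify the grid
# ------------------------------------------------------------------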
self.resBlock2 = ResBlock(init_size, 2 * init_size, 0.2, height_pooling=True, indice_key="down2")
110 |         self.resBlock3 = ResBlock(2 * init_size, 4 * init_size, 0.2, height_pooling=True, indice_key="down3")
111 |         self.resBlock4 = ResBlock(4 * init_size, 8 * init_size, 0.2, pooling=True, height_pooling=False,
112 |                                   indice_key="down4")
113 |         self.resBlock5 = ResBlock(8 * init_size, 16 * init_size, 0.2, pooling=True, height_pooling=False,
114 |                                   indice_key="down5")
115 | 
116 |         self.upBlock0 = UpBlock(16 * init_size, 16 * init_size, indice_key="up0", up_key="down5")
117 |         self.upBlock1 = UpBlock(16 * init_size, 8 * init_size, indice_key="up1", up_key="down4")
118 |         self.upBlock2 = UpBlock(8 * init_size, 4 * init_size, indice_key="up2", up_key="down3")
119 |         self.upBlock3 = UpBlock(4 * init_size, 2 * init_size, indice_key="up3", up_key="down2")
120 | 
121 |         self.ReconNet = ReconBlock(2 * init_size, 2 * init_size, indice_key="recon")
122 | 
123 |         self.logits = spconv.SubMConv3d(4 * init_size, nclasses, indice_key="logit", kernel_size=3, stride=1, padding=1,
124 |                                         bias=True)
125 | 
126 | 
127 |     def forward(self, voxel_features, coors, batch_size, extraction='all'):
128 | 
129 |         coors = coors.int()
130 |         x_sparse = spconv.SparseConvTensor(voxel_features, coors, self.sparse_shape, batch_size)
131 |         x = x_sparse
132 | 
133 |         # Sparse to dense
134 |         x_dense = x_sparse.dense()
135 | 
136 |         ### Completion sub-network by dense convolution
137 |         x1 = self.a_conv1(x_dense)
138 |         x2 = self.a_conv2(x1)
139 |         x3 = self.a_conv3(x1)
140 |         x4 = self.a_conv4(x1)
141 |         t1 = torch.cat((x2, x3, x4), 1)
142 |         x5 = self.a_conv5(t1)
143 |         x6 = self.a_conv6(t1)
144 |         x7 = self.a_conv7(t1)
145 |         x = torch.cat((x1, x2, x3, x4, x5, x6, x7), 1)
146 |         y0 = self.ch_conv1(x)
147 |         y1 = self.res_1(x_dense)
148 |         y2 = self.res_2(x_dense)
149 |         y3 = self.res_3(x_dense)
150 |         x = x_dense + y0 + y1 + y2 + y3
151 | 
152 |         # Dense to sparse
153 |         if isinstance(extraction, str) and extraction == 'all':  # default: keep every non-zero voxel
154 |             coord, features = extract_nonzero_features(x)
155 |         elif extraction is None:  # keep the full dense grid
156 |             coord, features = no_extraction(x)
157 |         else:  # (n, 3) proposal voxels are added to the non-zero set
158 |             coord, features = manual_extraction(x, extraction)
159 |         x = spconv.SparseConvTensor(features, coord.int(), self.sparse_shape, batch_size)  # voxel features
160 | 
161 |         ### Segmentation sub-network by sparse convolution
162 |         x = self.downCntx(x)
163 |         down1c, down1b = self.resBlock2(x)
164 |         down2c, down2b = self.resBlock3(down1c)
165 |         down3c, down3b = self.resBlock4(down2c)
166 |         down4c, down4b = self.resBlock5(down3c)
167 | 
168 |         up4e = self.upBlock0(down4c, down4b)
169 |         up3e = self.upBlock1(up4e, down3b)
170 |         up2e = self.upBlock2(up3e, down2b)
171 |         up1e = self.upBlock3(up2e, down1b)
172 | 
173 |         up0e = self.ReconNet(up1e)
174 | 
175 |         up0e.features = torch.cat((up0e.features, up1e.features), 1)
176 | 
177 |         logits = self.logits(up0e)
178 |         y = logits.dense()
179 | 
180 |         return y
181 | 
182 |     @staticmethod
183 |     def _joining(encoder_features, x, concat):
184 |         if concat:
185 |             return torch.cat((encoder_features, x), dim=1)
186 |         else:
187 |             return encoder_features + x
188 | 
189 | 
190 | 
191 | def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
192 |     return spconv.SubMConv3d(in_planes, out_planes, kernel_size=3, stride=stride,
193 |                              padding=1, bias=False, indice_key=indice_key)
194 | 
195 | 
196 | def conv1x3(in_planes, out_planes, stride=1, indice_key=None):
197 |     return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,
198 |                              padding=(0, 1, 1), bias=False,
indice_key=indice_key) 199 | 200 | 201 | def conv1x1x3(in_planes, out_planes, stride=1, indice_key=None): 202 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, 203 | padding=(0, 0, 1), bias=False, indice_key=indice_key) 204 | 205 | 206 | def conv1x3x1(in_planes, out_planes, stride=1, indice_key=None): 207 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, 208 | padding=(0, 1, 0), bias=False, indice_key=indice_key) 209 | 210 | 211 | def conv3x1x1(in_planes, out_planes, stride=1, indice_key=None): 212 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, 213 | padding=(1, 0, 0), bias=False, indice_key=indice_key) 214 | 215 | 216 | def conv3x1(in_planes, out_planes, stride=1, indice_key=None): 217 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, 218 | padding=(1, 0, 1), bias=False, indice_key=indice_key) 219 | 220 | 221 | def conv1x1(in_planes, out_planes, stride=1, indice_key=None): 222 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=1, stride=stride, 223 | padding=1, bias=False, indice_key=indice_key) 224 | 225 | 226 | class ResContextBlock(nn.Module): 227 | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, indice_key=None): 228 | super(ResContextBlock, self).__init__() 229 | self.conv1 = conv1x3(in_filters, out_filters, indice_key=indice_key + "bef") 230 | self.bn0 = nn.BatchNorm1d(out_filters) 231 | self.act1 = nn.LeakyReLU() 232 | 233 | self.conv1_2 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef") 234 | self.bn0_2 = nn.BatchNorm1d(out_filters) 235 | self.act1_2 = nn.LeakyReLU() 236 | 237 | self.conv2 = conv3x1(in_filters, out_filters, indice_key=indice_key + "bef") 238 | self.act2 = nn.LeakyReLU() 239 | self.bn1 = nn.BatchNorm1d(out_filters) 240 | 241 | self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef") 242 | self.act3 = nn.LeakyReLU() 243 | self.bn2 = nn.BatchNorm1d(out_filters) 244 | 245 | self.weight_initialization() 246 | 247 | def weight_initialization(self): 248 | for m in self.modules(): 249 | if isinstance(m, nn.BatchNorm1d): 250 | nn.init.constant_(m.weight, 1) 251 | nn.init.constant_(m.bias, 0) 252 | 253 | def forward(self, x): 254 | shortcut = self.conv1(x) 255 | shortcut.features = self.act1(shortcut.features) 256 | shortcut.features = self.bn0(shortcut.features) 257 | 258 | shortcut = self.conv1_2(shortcut) 259 | shortcut.features = self.act1_2(shortcut.features) 260 | shortcut.features = self.bn0_2(shortcut.features) 261 | 262 | resA = self.conv2(x) 263 | resA.features = self.act2(resA.features) 264 | resA.features = self.bn1(resA.features) 265 | 266 | resA = self.conv3(resA) 267 | resA.features = self.act3(resA.features) 268 | resA.features = self.bn2(resA.features) 269 | resA.features = resA.features + shortcut.features 270 | 271 | return resA 272 | 273 | 274 | class ResBlock(nn.Module): 275 | def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3, 3), stride=1, 276 | pooling=True, drop_out=True, height_pooling=False, indice_key=None): 277 | super(ResBlock, self).__init__() 278 | self.pooling = pooling 279 | self.drop_out = drop_out 280 | 281 | self.conv1 = conv3x1(in_filters, out_filters, indice_key=indice_key + "bef") 282 | self.act1 = nn.LeakyReLU() 283 | self.bn0 = nn.BatchNorm1d(out_filters) 284 | 285 | self.conv1_2 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef") 286 | self.act1_2 = nn.LeakyReLU() 
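# ------------------------------------------------------------------
# [editor's sketch, not part of the original file] Why the block above stacks
# conv3x1 / conv1x3 pairs: two asymmetric 3D kernels cover a 3x3 cross section
# with 2*9 weights per filter instead of the 27 of a full 3x3x3 kernel. Dense
# nn.Conv3d stand-ins below only illustrate the shape/parameter arithmetic;
# the real blocks use spconv submanifold convolutions with the same layouts.
import torch
import torch.nn as nn

full = nn.Conv3d(32, 32, kernel_size=3, padding=1, bias=False)
asym = nn.Sequential(
    nn.Conv3d(32, 32, kernel_size=(3, 1, 3), padding=(1, 0, 1), bias=False),
    nn.Conv3d(32, 32, kernel_size=(1, 3, 3), padding=(0, 1, 1), bias=False),
)
n_full = sum(p.numel() for p in full.parameters())     # 32*32*27 = 27648
n_asym = sum(p.numel() for p in asym.parameters())     # 32*32*(9+9) = 18432
out = asym(torch.randn(1, 32, 8, 8, 4))                # spatial shape preserved
assert out.shape == (1, 32, 8, 8, 4) and n_asym < n_full
# ------------------------------------------------------------------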
287 | self.bn0_2 = nn.BatchNorm1d(out_filters) 288 | 289 | self.conv2 = conv1x3(in_filters, out_filters, indice_key=indice_key + "bef") 290 | self.act2 = nn.LeakyReLU() 291 | self.bn1 = nn.BatchNorm1d(out_filters) 292 | 293 | self.conv3 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef") 294 | self.act3 = nn.LeakyReLU() 295 | self.bn2 = nn.BatchNorm1d(out_filters) 296 | 297 | if pooling: 298 | if height_pooling: 299 | self.pool = spconv.SparseConv3d(out_filters, out_filters, kernel_size=3, stride=2, 300 | padding=1, indice_key=indice_key, bias=False) 301 | else: 302 | self.pool = spconv.SparseConv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1), 303 | padding=1, indice_key=indice_key, bias=False) 304 | self.weight_initialization() 305 | 306 | def weight_initialization(self): 307 | for m in self.modules(): 308 | if isinstance(m, nn.BatchNorm1d): 309 | nn.init.constant_(m.weight, 1) 310 | nn.init.constant_(m.bias, 0) 311 | 312 | def forward(self, x): 313 | shortcut = self.conv1(x) 314 | shortcut.features = self.act1(shortcut.features) 315 | shortcut.features = self.bn0(shortcut.features) 316 | 317 | shortcut = self.conv1_2(shortcut) 318 | shortcut.features = self.act1_2(shortcut.features) 319 | shortcut.features = self.bn0_2(shortcut.features) 320 | 321 | resA = self.conv2(x) 322 | resA.features = self.act2(resA.features) 323 | resA.features = self.bn1(resA.features) 324 | 325 | resA = self.conv3(resA) 326 | resA.features = self.act3(resA.features) 327 | resA.features = self.bn2(resA.features) 328 | 329 | resA.features = resA.features + shortcut.features 330 | 331 | if self.pooling: 332 | resB = self.pool(resA) 333 | return resB, resA 334 | else: 335 | return resA 336 | 337 | 338 | class UpBlock(nn.Module): 339 | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), indice_key=None, up_key=None): 340 | super(UpBlock, self).__init__() 341 | # self.drop_out = drop_out 342 | self.trans_dilao = conv3x3(in_filters, out_filters, indice_key=indice_key + "new_up") 343 | self.trans_act = nn.LeakyReLU() 344 | self.trans_bn = nn.BatchNorm1d(out_filters) 345 | 346 | self.conv1 = conv1x3(out_filters, out_filters, indice_key=indice_key) 347 | self.act1 = nn.LeakyReLU() 348 | self.bn1 = nn.BatchNorm1d(out_filters) 349 | 350 | self.conv2 = conv3x1(out_filters, out_filters, indice_key=indice_key) 351 | self.act2 = nn.LeakyReLU() 352 | self.bn2 = nn.BatchNorm1d(out_filters) 353 | 354 | self.conv3 = conv3x3(out_filters, out_filters, indice_key=indice_key) 355 | self.act3 = nn.LeakyReLU() 356 | self.bn3 = nn.BatchNorm1d(out_filters) 357 | # self.dropout3 = nn.Dropout3d(p=dropout_rate) 358 | 359 | self.up_subm = spconv.SparseInverseConv3d(out_filters, out_filters, kernel_size=3, indice_key=up_key, 360 | bias=False) 361 | 362 | self.weight_initialization() 363 | 364 | def weight_initialization(self): 365 | for m in self.modules(): 366 | if isinstance(m, nn.BatchNorm1d): 367 | nn.init.constant_(m.weight, 1) 368 | nn.init.constant_(m.bias, 0) 369 | 370 | def forward(self, x, skip): 371 | upA = self.trans_dilao(x) 372 | upA.features = self.trans_act(upA.features) 373 | upA.features = self.trans_bn(upA.features) 374 | 375 | ## upsample 376 | upA = self.up_subm(upA) 377 | 378 | upA.features = upA.features + skip.features 379 | 380 | upE = self.conv1(upA) 381 | upE.features = self.act1(upE.features) 382 | upE.features = self.bn1(upE.features) 383 | 384 | upE = self.conv2(upE) 385 | upE.features = self.act2(upE.features) 386 | upE.features = self.bn2(upE.features) 387 | 388 | upE 
= self.conv3(upE) 389 | upE.features = self.act3(upE.features) 390 | upE.features = self.bn3(upE.features) 391 | 392 | return upE 393 | 394 | 395 | class ReconBlock(nn.Module): 396 | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, indice_key=None): 397 | super(ReconBlock, self).__init__() 398 | self.conv1 = conv3x1x1(in_filters, out_filters, indice_key=indice_key + "bef") 399 | self.bn0 = nn.BatchNorm1d(out_filters) 400 | self.act1 = nn.Sigmoid() 401 | 402 | self.conv1_2 = conv1x3x1(in_filters, out_filters, indice_key=indice_key + "bef") 403 | self.bn0_2 = nn.BatchNorm1d(out_filters) 404 | self.act1_2 = nn.Sigmoid() 405 | 406 | self.conv1_3 = conv1x1x3(in_filters, out_filters, indice_key=indice_key + "bef") 407 | self.bn0_3 = nn.BatchNorm1d(out_filters) 408 | self.act1_3 = nn.Sigmoid() 409 | 410 | def forward(self, x): 411 | shortcut = self.conv1(x) 412 | shortcut.features = self.bn0(shortcut.features) 413 | shortcut.features = self.act1(shortcut.features) 414 | 415 | shortcut2 = self.conv1_2(x) 416 | shortcut2.features = self.bn0_2(shortcut2.features) 417 | shortcut2.features = self.act1_2(shortcut2.features) 418 | 419 | shortcut3 = self.conv1_3(x) 420 | shortcut3.features = self.bn0_3(shortcut3.features) 421 | shortcut3.features = self.act1_3(shortcut3.features) 422 | shortcut.features = shortcut.features + shortcut2.features + shortcut3.features 423 | 424 | shortcut.features = shortcut.features * x.features 425 | 426 | return shortcut 427 | -------------------------------------------------------------------------------- /run_tta_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import os 3 | import time 4 | import argparse 5 | import sys 6 | import numpy as np 7 | 8 | import torch 9 | import torch.optim as optim 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from tqdm import tqdm 13 | 14 | from dataloader.pc_dataset_test import get_SemKITTI_label_name, get_eval_mask, unpack 15 | from config.config import load_config_data 16 | from builder import loss_builder 17 | from builder import model_builder_unlock as model_builder 18 | from builder import data_builder_test as data_builder 19 | from utils.load_save_util import load_checkpoint 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | import yaml 23 | from utils.util import Bresenham3D 24 | import random 25 | 26 | import ast 27 | import pdb 28 | import copy 29 | import importlib.util as ilutil 30 | import glob 31 | from torch.utils.tensorboard import SummaryWriter 32 | from matplotlib import pyplot as plt 33 | from utils.np_ioueval import iouEval 34 | from utils.softmax_entropy import softmax_entropy 35 | 36 | import pandas as pd 37 | import seaborn as sn 38 | 39 | CATS = ["empty", 40 | "car", "bicycle", "motorcycle", "truck", "other-vehicle", 41 | "person", "bicyclist", "motorcyclist", "road", "parking", 42 | "sidewalk", "other-ground", "building", "fence", "vegetation", 43 | "trunk", "terrain", "pole", "traffic-sign"] 44 | 45 | mapping_forward = {0:0, 46 | 1:0, 10:1, 11:2, 13:5, 15:3, 47 | 16:5, 18:4, 20:5, 30:6, 31:7, 48 | 32:8, 40:9, 44:10, 48:11, 49:12, 49 | 50:13, 51:14, 52:0, 60:9, 70:15, 50 | 71:16, 72:17, 80:18, 81:19, 99:0, 51 | 252:1, 253:7, 254:6, 255:8, 256:5, 52 | 257:5, 258:4, 259:5} 53 | 54 | max_key = max(mapping_forward.keys()) 55 | 56 | MAP_ARRAY = np.zeros(max_key + 1, dtype=int) 57 | 58 | for key, value in mapping_forward.items(): 59 | MAP_ARRAY[key] = value 60 | 61 | PALLETE = 
np.asarray([[0, 0, 0],[245, 150, 100],[245, 230, 100],[150, 60, 30],[180, 30, 80],[255, 0, 0],[30, 30, 255],[200, 40, 255],[90, 30, 150],[255, 0, 255], 62 | [255, 150, 255],[75, 0, 75],[75, 0, 175],[0, 200, 255],[50, 120, 255],[0, 175, 0],[0, 60, 135],[80, 240, 150],[150, 240, 255],[0, 0, 255], [255,255,255]]).astype(np.uint8) 63 | PALLETE[:,[0,2]]=PALLETE[:,[2,0]] 64 | PALLETE_BINARY = np.asarray([[255,0,0], [0,255,0], [0,0,255], [0,0,0]]).astype(np.uint8) 65 | 66 | def train2SemKITTI(input_label): 67 | # delete 0 label (uses uint8 trick : 0 - 1 = 255 ) 68 | return input_label + 1 69 | 70 | def get_remap_first(semkittiyaml): 71 | # make lookup table for mapping 72 | learning_map_inv = semkittiyaml["learning_map_inv"] 73 | maxkey = max(learning_map_inv.keys()) 74 | # +100 hack making lut bigger just in case there are unknown labels 75 | remap_lut = np.zeros((maxkey + 100), dtype=np.int32) 76 | remap_lut[list(learning_map_inv.keys())] = list(learning_map_inv.values()) 77 | return remap_lut,learning_map_inv 78 | 79 | def get_remap_second(semkittiyaml): 80 | class_remap = semkittiyaml["learning_map"] 81 | maxkey2 = max(class_remap.keys()) 82 | remap_lut = np.zeros((maxkey2 + 100), dtype=np.int32) 83 | remap_lut[list(class_remap.keys())] = list(class_remap.values()) 84 | remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid' 85 | remap_lut[0] = 0 86 | return remap_lut 87 | 88 | 89 | def remapping(pred, remap=None): 90 | ### save prediction after remapping 91 | upper_half = pred >> 16 # get upper half for instances 92 | lower_half = pred & 0xFFFF # get lower half for semantics 93 | lower_half = remap[lower_half] # do the remapping of semantics 94 | pred = (upper_half << 16) + lower_half # reconstruct full label 95 | pred = pred.astype(np.uint32) 96 | pred = pred.astype(np.uint16) 97 | return pred 98 | 99 | def save_pred(final_preds, save_dir, output_path): 100 | _, dir2 = save_dir.split('/sequences/',1) 101 | new_save_dir = output_path + '/sequences/' + dir2.replace('velodyne', 'predictions')[:-3]+'label' 102 | if not os.path.exists(os.path.dirname(new_save_dir)): 103 | try: 104 | os.makedirs(os.path.dirname(new_save_dir)) 105 | except OSError as exc: 106 | if exc.errno != errno.EEXIST: 107 | raise 108 | final_preds.tofile(new_save_dir) 109 | 110 | def extract_bev_for_vis(arr, ignore_idx=0): 111 | for i in arr[::-1]: 112 | if i != ignore_idx: 113 | return i 114 | return ignore_idx 115 | 116 | def extract_mev_for_vis(arr, ignore_idx=0): 117 | for i in arr: 118 | if i != ignore_idx: 119 | return i 120 | return ignore_idx 121 | 122 | def extract_bev_for_vis_dual(arr, ignore_idx=(0,255)): 123 | for i in arr[::-1]: 124 | if i not in ignore_idx: 125 | return i 126 | return 255 127 | 128 | 129 | def main(args): 130 | pytorch_device = torch.device('cuda:0') 131 | epsilon = np.finfo(np.float32).eps 132 | 133 | config_path = args.config_path 134 | 135 | configs = load_config_data(config_path) 136 | 137 | dataset_config = configs['dataset_params'] 138 | train_dataloader_config = configs['train_data_loader'] 139 | val_dataloader_config = configs['val_data_loader'] 140 | val_batch_size = val_dataloader_config['batch_size'] 141 | 142 | model_config = configs['model_params'] 143 | train_hypers = configs['train_params'] 144 | 145 | grid_size = model_config['output_shape'] 146 | num_class = model_config['num_class'] 147 | ignore_label = dataset_config['ignore_label'] 148 | loss_fn_ce, loss_fn_lovasz = loss_builder.build(wce=True, lovasz=True, num_class=num_class, ignore_label=ignore_label) 149 | 
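# ------------------------------------------------------------------
# [editor's sketch, not part of the original file] The bit arithmetic in
# remapping() above relies on the SemanticKITTI .label layout: instance id in
# the upper 16 bits, semantic id in the lower 16 bits of a uint32. The remap
# LUTs built by get_remap_first/get_remap_second are plain integer arrays
# indexed by raw label id:
import numpy as np

sem = np.array([10, 40], dtype=np.uint32)     # raw semantic ids (car, road)
inst = np.array([7, 0], dtype=np.uint32)      # instance ids
packed = (inst << 16) + sem
assert np.all(packed >> 16 == inst) and np.all((packed & 0xFFFF) == sem)

lut = np.zeros(100, dtype=np.int32)
lut[[10, 40]] = [1, 9]                        # e.g. learning_map: car -> 1, road -> 9
assert np.all(lut[packed & 0xFFFF] == [1, 9])
# ------------------------------------------------------------------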
loss_fn_ce_binary, loss_fn_lovasz_binary = loss_builder.build(wce=True, lovasz=True, num_class=2, ignore_label=ignore_label) 150 | 151 | # Define dataset/loader 152 | with open("config/label_mapping/semantic-kitti.yaml", 'r') as stream: 153 | semkittiyaml = yaml.safe_load(stream) 154 | remap_first,class_inv_remap = get_remap_first(semkittiyaml) 155 | remap_second = get_remap_second(semkittiyaml) 156 | class_strings = semkittiyaml["labels"] 157 | strides=[int(num) for num in ast.literal_eval(args.stride)] 158 | 159 | test_dataset_loader, test_pt_dataset = data_builder.build(dataset_config, 160 | train_dataloader_config, 161 | val_dataloader_config, 162 | grid_size=grid_size, 163 | use_tta=True, 164 | use_multiscan=True, 165 | stride=args.stride, 166 | sq_num=args.sq_num) 167 | 168 | # Define experiment path 169 | exp_name = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) + '_' + __file__.replace('run_tta_','').replace('.py','') + args.name 170 | exp_path = args.talos_root+'experiments/' + exp_name 171 | print("Experiment path is "+exp_path) 172 | os.makedirs(exp_path, exist_ok=True) 173 | 174 | # Save run_code snapshots 175 | os.system('scp -r '+args.talos_root+__file__+' '+exp_path+'/'+__file__) 176 | # 177 | 178 | writer = SummaryWriter(exp_path) 179 | config_file = exp_path + '/config.txt' 180 | with open(config_file, 'w') as log: 181 | log.write(str(args)) 182 | 183 | loss_cont_names = ['loss_cont', 'loss_cont_occ_ce', 'loss_cont_occ_lovasz', 'loss_cont_pgt_ce','loss_cont_pgt_lovasz'] 184 | loss_adapt_names = ['loss_adapt', 'loss_adapt_occ_ce', 'loss_adapt_occ_lovasz', 'loss_adapt_pgt_ce','loss_adapt_pgt_lovasz'] 185 | evaluator_all = iouEval(num_class, []) 186 | baseline_performance = open(args.baseline_perf_txt, 'r') 187 | baseline_prediction_paths = sorted(glob.glob(args.baseline_preds+'/*.label')) 188 | 189 | model_load_path = train_hypers['model_load_path'] 190 | model_load_path += 'pretrained.pth' 191 | 192 | model_baseline = model_builder.build(model_config) 193 | print('Load model from: %s' % model_load_path) 194 | model_baseline = load_checkpoint(model_load_path, model_baseline) 195 | model_baseline.to(pytorch_device) 196 | 197 | # For freeze 198 | module_names_mlp = ['cylinder_3d_generator'] 199 | module_names_comp = ['a_conv1', 'a_conv2', 'a_conv3', 'a_conv4', 'a_conv5', 'a_conv6', 'a_conv7', 'ch_conv1','res_1','res_2','res_3'] 200 | module_names_seg = ['downCntx', 'resBlock2', 'resBlock3', 'resBlock4', 'resBlock5', 'upBlock0', 'upBlock1', 'upBlock2', 'upBlock3', 'ReconNet'] 201 | module_names_logit = ['logits'] 202 | 203 | assert (args.do_adapt or args.do_cont) 204 | 205 | if args.do_cont: 206 | print("continual") 207 | model_cont = copy.deepcopy(model_baseline) 208 | param_to_update_cont = [] 209 | 210 | print("Update segmentation module for continual tta.") 211 | for tn in module_names_seg: 212 | param_to_update_cont += list(getattr(model_cont.cylinder_3d_spconv_seg, tn).parameters()) 213 | 214 | optimizer_cont = optim.Adam(param_to_update_cont, lr=args.cont_lr) 215 | 216 | if args.do_adapt: 217 | print("scan-wise adaptation") 218 | 219 | # current_frame={'grid_ind','pt_feat','index','gt'} 220 | # prev_frame={'vox_grid_to','vox_grid_from','grid_ind','pt_feat','trans_grid_ind','trans_pt_feat','lidar_pose'} 221 | for idx_test, (frame_curr, frame_aux) in enumerate(tqdm(test_dataset_loader)): 222 | 223 | print('') 224 | 225 | ######################################################################################################################################### 226 
| ############################################################# Data process ############################################################## 227 | ######################################################################################################################################### 228 | 229 | exist_stride = [stride in frame_curr['stride'] for stride in strides] 230 | 231 | flag_aux_exist= bool(len(frame_aux['trans_grid_ind'])!=0) 232 | flag_adapt_aux_exist= bool(exist_stride[0]) 233 | flag_cont_aux_exist= bool(exist_stride[1]) 234 | adapt_aux_idx= 0 235 | cont_aux_idx = 1 if flag_adapt_aux_exist else 0 236 | 237 | feat_curr = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in frame_curr['pt_feat']] 238 | grid_curr = [torch.from_numpy(i).to(pytorch_device) for i in frame_curr['grid_ind']] 239 | gt_curr = frame_curr['gt'].to(pytorch_device) 240 | 241 | if flag_aux_exist: 242 | feat_aux = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in frame_aux['pt_feat']] 243 | grid_aux = [torch.from_numpy(i).to(pytorch_device) for i in frame_aux['grid_ind']] 244 | else: 245 | print("The pointer is at the edge of the sequence: " + str(idx_test)) 246 | 247 | ######################################################################################################################################### 248 | ##################################################### Pseudo GT generation ############################################################## 249 | ######################################################################################################################################### 250 | 251 | if args.use_los: 252 | print("Generate occupancy pgt by checking the line of sight") 253 | voxel_los_adapt = 255*torch.ones([1,256,256,32],dtype=torch.long).to(pytorch_device) # 0:empty, 1:occupied, 2:LoS, 255:ignore 254 | voxel_los_cont = 255*torch.ones([1,256,256,32],dtype=torch.long).to(pytorch_device) # 0:empty, 1:occupied, 2:LoS, 255:ignore 255 | if flag_adapt_aux_exist: 256 | idx_curr_occupied = frame_curr['grid_ind'][0] 257 | idx_curr_occupied = np.unique(idx_curr_occupied, axis =0) 258 | idx_aux_occupied = frame_aux['trans_grid_ind'][adapt_aux_idx] 259 | idx_aux_occupied = np.unique(idx_aux_occupied, axis=0) 260 | idx_cat_occupied = np.concatenate([idx_curr_occupied, idx_aux_occupied],0) 261 | 262 | voxel_aux = np.zeros([256,256,32]) 263 | voxel_aux[idx_aux_occupied[:,0],idx_aux_occupied[:,1],idx_aux_occupied[:,2]] = 1 # Empty:0, Occupy:1 264 | voxel_curr = np.zeros([256,256,32]) 265 | voxel_curr[idx_curr_occupied[:,0],idx_curr_occupied[:,1],idx_curr_occupied[:,2]] = 1 # Empty:0, Occupy:1 266 | 267 | voxel_aux_only = np.where(np.logical_and(voxel_curr!=1,voxel_aux==1), True, False) 268 | los_start = (np.floor((np.clip(frame_aux['lidar_pose'][0] , frame_curr['min_bound'][0], frame_curr['max_bound'][0]) - frame_curr['min_bound'][0]) / frame_curr['interval'][0])).astype(np.int) 269 | los_end = np.argwhere(voxel_aux_only==True) 270 | los_end = los_end[random.sample(range(los_end.shape[0]),los_end.shape[0]//8)] 271 | idx_los_empty = Bresenham3D(los_start,los_end,idx_cat_occupied) 272 | idx_los_empty = np.unique(np.array(idx_los_empty), axis=0) 273 | idx_los_occupied = idx_aux_occupied 274 | 275 | voxel_los_adapt[0, idx_los_empty[:,0],idx_los_empty[:,1],idx_los_empty[:,2]] = 0 276 | voxel_los_adapt[0, idx_los_occupied[:,0],idx_los_occupied[:,1],idx_los_occupied[:,2]] = 1 277 | 278 | if flag_cont_aux_exist: 279 | idx_curr_occupied = frame_curr['grid_ind'][0] 280 | 
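# ------------------------------------------------------------------
# [editor's sketch, not part of the original file] The line-of-sight idea used
# above: a voxel that is occupied only in the auxiliary scan defines a ray from
# the sensor position, and voxels along that ray can be pseudo-labelled empty.
# Below is a simplified dense ray-march stand-in; the repo's actual
# implementation is Bresenham3D in utils/util.py, and the function name here
# is illustrative only.
import numpy as np

def los_empty_voxels(start, end, occupied):
    """Integer voxels along the ray start->end, minus any listed occupied voxels."""
    start, end = np.asarray(start, float), np.asarray(end, float)
    n = int(np.max(np.abs(end - start)) * 2) + 1      # dense enough sampling
    pts = np.floor(start + np.linspace(0, 1, n)[:, None] * (end - start)).astype(int)
    pts = np.unique(pts, axis=0)
    hit = (pts[:, None] == np.asarray(occupied)[None]).all(-1).any(1)
    return pts[~hit]

print(los_empty_voxels([0, 0, 0], [5, 3, 1], occupied=[[5, 3, 1]]))
# ------------------------------------------------------------------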
idx_curr_occupied = np.unique(idx_curr_occupied, axis =0) 281 | idx_aux_occupied = frame_aux['trans_grid_ind'][cont_aux_idx] 282 | idx_aux_occupied = np.unique(idx_aux_occupied, axis=0) 283 | idx_cat_occupied = np.concatenate([idx_curr_occupied, idx_aux_occupied],0) 284 | 285 | voxel_aux = np.zeros([256,256,32]) 286 | voxel_aux[idx_aux_occupied[:,0],idx_aux_occupied[:,1],idx_aux_occupied[:,2]] = 1 # Empty:0, Occupy:1 287 | voxel_curr = np.zeros([256,256,32]) 288 | voxel_curr[idx_curr_occupied[:,0],idx_curr_occupied[:,1],idx_curr_occupied[:,2]] = 1 # Empty:0, Occupy:1 289 | 290 | voxel_aux_only = np.where(np.logical_and(voxel_curr!=1,voxel_aux==1), True, False) 291 | los_start = (np.floor((np.clip(frame_aux['lidar_pose'][cont_aux_idx] , frame_curr['min_bound'][0], frame_curr['max_bound'][0]) - frame_curr['min_bound'][0]) / frame_curr['interval'][0])).astype(np.int) 292 | los_end = np.argwhere(voxel_aux_only==True) 293 | los_end = los_end[random.sample(range(los_end.shape[0]),los_end.shape[0]//30)] 294 | idx_los_empty = Bresenham3D(los_start,los_end,idx_cat_occupied) 295 | idx_los_empty = np.unique(np.array(idx_los_empty), axis=0) 296 | idx_los_occupied = idx_aux_occupied 297 | 298 | voxel_los_cont[0, idx_los_empty[:,0],idx_los_empty[:,1],idx_los_empty[:,2]] = 0 299 | voxel_los_cont[0, idx_los_occupied[:,0],idx_los_occupied[:,1],idx_los_occupied[:,2]] = 1 300 | else: 301 | print("Skip to generate los-based pgt.") 302 | 303 | dim_chn=20 304 | conv_ones = nn.Conv3d(dim_chn, dim_chn, kernel_size=(5,5,3), stride=1, padding=(2,2,1), bias=False) 305 | conv_ones.weight = torch.nn.Parameter(torch.ones((dim_chn,dim_chn,5,5,3))) 306 | conv_ones.weight.requires_grad=False 307 | conv_ones.cuda() 308 | proposure_idx=None 309 | 310 | if args.use_pgt: 311 | print("Generate class pgt according to entropy-based confidence.") 312 | model_baseline.eval() 313 | voxel_pgt_aux_cont = 255*torch.ones([1,256,256,32]).type(torch.LongTensor).to(pytorch_device) 314 | voxel_pgt_aux_adapt = 255*torch.ones([1,256,256,32]).type(torch.LongTensor).to(pytorch_device) 315 | voxel_pgt_curr = 255*torch.ones([1,256,256,32]).type(torch.LongTensor).to(pytorch_device) 316 | voxel_pgt_all_adapt = 255*torch.ones([1,256,256,32]).type(torch.LongTensor).to(pytorch_device) 317 | voxel_pgt_all_cont = 255*torch.ones([1,256,256,32]).type(torch.LongTensor).to(pytorch_device) 318 | 319 | with torch.no_grad(): 320 | 321 | ### Current PGT 322 | print("- Generate class pgt using current scan.") 323 | pred_logit_curr = model_baseline(feat_curr, grid_curr, val_batch_size, frame_curr['grid_ind'], use_tta=False) 324 | 325 | pred_cls_curr = torch.argmax(pred_logit_curr, 1).type(torch.LongTensor).to(pytorch_device) 326 | mask_pred_curr_zeroforced = (torch.sum(pred_logit_curr, dim=1)==0) # Forced to be zero during dense-to-sparse (refer to spconv) 327 | 328 | conf_curr = softmax_entropy(pred_logit_curr) 329 | conf_curr[mask_pred_curr_zeroforced] = 0 330 | conf_curr = 1 - conf_curr/torch.max(conf_curr) 331 | conf_curr[mask_pred_curr_zeroforced] = -1 332 | 333 | cur_middle_rfield=conv_ones(pred_logit_curr) 334 | cur_middle_field = (torch.sum(cur_middle_rfield, dim=1)!=0) 335 | cur_empty_field=(torch.sum(pred_logit_curr, dim=1)==0) 336 | proposure_idx=torch.where(cur_middle_field*cur_empty_field,1,0) 337 | proposure_idx=proposure_idx.squeeze().nonzero() 338 | 339 | mask_reliable_curr_occupied = torch.logical_and(conf_curr>args.th_pgt_occupied, pred_cls_curr!=0) 340 | mask_reliable_curr_empty = torch.logical_and(conf_curr>args.th_pgt_empty, 
pred_cls_curr==0) 341 | mask_reliable_curr = torch.logical_or(mask_reliable_curr_empty, mask_reliable_curr_occupied) 342 | 343 | voxel_pgt_curr[mask_reliable_curr] = pred_cls_curr[mask_reliable_curr] 344 | 345 | vis_voxel_pgt_curr = voxel_pgt_curr[0].cpu().detach().numpy() 346 | vis_voxel_pgt_curr = np.apply_along_axis(extract_bev_for_vis_dual, 2, vis_voxel_pgt_curr, ignore_idx=(0,255)) 347 | vis_voxel_pgt_curr[vis_voxel_pgt_curr==255] = 20 348 | vis_voxel_pgt_curr = PALLETE[vis_voxel_pgt_curr] 349 | 350 | ### Aux PGT 351 | if flag_adapt_aux_exist: 352 | print("- Generate class pgt for adapt using auxiliary scan.") 353 | 354 | vox_grid_to = frame_aux['vox_grid_to'][adapt_aux_idx].astype(np.int) 355 | vox_grid_from = frame_aux['vox_grid_from'][adapt_aux_idx].astype(np.int) 356 | 357 | pred_logit_aux = model_baseline([feat_aux[adapt_aux_idx]], [grid_aux[adapt_aux_idx]], val_batch_size, [frame_aux['grid_ind'][adapt_aux_idx]], use_tta=False) 358 | pred_cls_aux = torch.argmax(pred_logit_aux, 1).type(torch.LongTensor).to(pytorch_device) 359 | mask_pred_aux_zeroforced = (torch.sum(pred_logit_aux, dim=1)==0) # Forced to be zero during dense-to-sparse (refer to spconv) 360 | 361 | conf_aux = softmax_entropy(pred_logit_aux) 362 | conf_aux[mask_pred_aux_zeroforced] = 0 363 | conf_aux = 1 - conf_aux/torch.max(conf_aux) 364 | conf_aux[mask_pred_aux_zeroforced] = -1 365 | 366 | mask_reliable_aux_occupied = torch.logical_and(conf_aux>args.th_pgt_occupied, pred_cls_aux!=0) 367 | mask_reliable_aux_empty = torch.logical_and(conf_aux>args.th_pgt_empty, pred_cls_aux==0) 368 | mask_reliable_aux = torch.logical_or(mask_reliable_aux_occupied, mask_reliable_aux_empty) 369 | 370 | pred_cls_aux[~mask_reliable_aux] = 255 371 | pred_cls_aux[mask_pred_aux_zeroforced] = 255 372 | voxel_pgt_aux_adapt[0, vox_grid_to[:, 0], vox_grid_to[:, 1], vox_grid_to[:, 2]] = pred_cls_aux[0, vox_grid_from[:, 0], vox_grid_from[:, 1], vox_grid_from[:, 2]] 373 | 374 | else: 375 | vis_voxel_pgt_aux_adapt = None 376 | print("- Skip to make aux pgt for adapt.") 377 | 378 | vis_voxel_pgt_aux_adapt = voxel_pgt_aux_adapt[0].cpu().detach().numpy() 379 | vis_voxel_pgt_aux_adapt = np.apply_along_axis(extract_bev_for_vis_dual, 2, vis_voxel_pgt_aux_adapt, ignore_idx=(0,255)) 380 | vis_voxel_pgt_aux_adapt[vis_voxel_pgt_aux_adapt==255] = 20 381 | vis_voxel_pgt_aux_adapt = PALLETE[vis_voxel_pgt_aux_adapt] 382 | 383 | if flag_cont_aux_exist: 384 | 385 | print("- Generate class pgt for cont using auxiliary scan.") 386 | 387 | vox_grid_to = frame_aux['vox_grid_to'][cont_aux_idx].astype(np.int) 388 | vox_grid_from = frame_aux['vox_grid_from'][cont_aux_idx].astype(np.int) 389 | 390 | pred_logit_aux = model_baseline([feat_aux[cont_aux_idx]], [grid_aux[cont_aux_idx]], val_batch_size, [frame_aux['grid_ind'][cont_aux_idx]], use_tta=False) 391 | pred_cls_aux_cont = torch.argmax(pred_logit_aux, 1).type(torch.LongTensor).to(pytorch_device) 392 | mask_pred_aux_zeroforced = (torch.sum(pred_logit_aux, dim=1)==0) # Forced to be zero during dense-to-sparse (refer to spconv) 393 | 394 | conf_aux = softmax_entropy(pred_logit_aux) 395 | conf_aux[mask_pred_aux_zeroforced] = 0 396 | conf_aux = 1 - conf_aux/torch.max(conf_aux) 397 | conf_aux[mask_pred_aux_zeroforced] = -1 398 | 399 | mask_reliable_aux_occupied = torch.logical_and(conf_aux>args.th_pgt_occupied, pred_cls_aux_cont!=0) 400 | mask_reliable_aux_empty = torch.logical_and(conf_aux>args.th_pgt_empty, pred_cls_aux_cont==0) 401 | mask_reliable_aux = torch.logical_or(mask_reliable_aux_occupied, mask_reliable_aux_empty) 
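# ------------------------------------------------------------------
# [editor's sketch, not part of the original file] The confidence score used
# in the reliability masks above: softmax entropy rescaled into [0, 1] so that
# peaky (low-entropy) predictions score near 1. The 0.5 threshold is a
# stand-in for args.th_pgt_occupied / args.th_pgt_empty.
import torch

logits = torch.randn(1, 20, 4)                               # (B, C, n) toy logits
ent = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)    # == softmax_entropy()
conf = 1 - ent / ent.max()                                   # 1 = most confident
pred = logits.argmax(1)
reliable_occupied = (conf > 0.5) & (pred != 0)
reliable_empty = (conf > 0.5) & (pred == 0)
# ------------------------------------------------------------------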
402 | 403 | pred_cls_aux_cont[~mask_reliable_aux] = 255 404 | pred_cls_aux_cont[mask_pred_aux_zeroforced] = 255 405 | voxel_pgt_aux_cont[0, vox_grid_to[:, 0], vox_grid_to[:, 1], vox_grid_to[:, 2]] = pred_cls_aux_cont[0, vox_grid_from[:, 0], vox_grid_from[:, 1], vox_grid_from[:, 2]] 406 | 407 | else: 408 | vis_voxel_pgt_aux_cont = None 409 | print("- Skip to make aux pgt for cont.") 410 | 411 | vis_voxel_pgt_aux_cont = voxel_pgt_aux_cont[0].cpu().detach().numpy() 412 | vis_voxel_pgt_aux_cont = np.apply_along_axis(extract_bev_for_vis_dual, 2, vis_voxel_pgt_aux_cont, ignore_idx=(0,255)) 413 | vis_voxel_pgt_aux_cont[vis_voxel_pgt_aux_cont==255] = 20 414 | vis_voxel_pgt_aux_cont = PALLETE[vis_voxel_pgt_aux_cont] 415 | 416 | 417 | ### Aggregation of current PGT and aux PGT for adapt 418 | voxel_pgt_all_adapt = voxel_pgt_curr.clone() 419 | mask_temp = (voxel_pgt_all_adapt==255)*(voxel_pgt_aux_adapt!=255) 420 | voxel_pgt_all_adapt[mask_temp] = voxel_pgt_aux_adapt[mask_temp] 421 | mask_temp = (voxel_pgt_curr!=255)*(voxel_pgt_aux_adapt!=255)*(voxel_pgt_curr!=voxel_pgt_aux_adapt) 422 | voxel_pgt_all_adapt[mask_temp] = 255 423 | ### Aggregation of current PGT and aux PGT for cont 424 | voxel_pgt_all_cont = voxel_pgt_curr.clone() 425 | mask_temp = (voxel_pgt_all_cont==255)*(voxel_pgt_aux_cont!=255) 426 | voxel_pgt_all_cont[mask_temp] = voxel_pgt_aux_cont[mask_temp] 427 | mask_temp = (voxel_pgt_curr!=255)*(voxel_pgt_aux_cont!=255)*(voxel_pgt_curr!=voxel_pgt_aux_cont) 428 | voxel_pgt_all_cont[mask_temp] = 255 429 | 430 | 431 | vis_voxel_pgt = voxel_pgt_all_adapt[0].cpu().detach().numpy() 432 | vis_voxel_pgt = np.apply_along_axis(extract_bev_for_vis_dual, 2, vis_voxel_pgt, ignore_idx=(0,255)) 433 | vis_voxel_pgt[vis_voxel_pgt==255] = 20 434 | vis_voxel_pgt = PALLETE[vis_voxel_pgt] 435 | 436 | else: 437 | print("Skip to generate class pgt.") 438 | 439 | ######################################################################################################################################### 440 | ########################################################## Scan-wise Adaptation ######################################################### 441 | ######################################################################################################################################### 442 | 443 | if args.do_adapt: 444 | 445 | print("Do scan-wise adaptation.") 446 | print("- From baseline model") 447 | model_adapt = copy.deepcopy(model_baseline) 448 | 449 | param_to_update_adapt = [] 450 | 451 | print("- Scan-wise adapt segmentation module") 452 | for tn in module_names_seg: 453 | param_to_update_adapt += list(getattr(model_adapt.cylinder_3d_spconv_seg, tn).parameters()) 454 | 455 | print("- Scan-wise adapt final logit layer") 456 | for tn in module_names_logit: 457 | param_to_update_adapt += list(getattr(model_adapt.cylinder_3d_spconv_seg, tn).parameters()) 458 | 459 | optimizer_adapt = optim.Adam(param_to_update_adapt, lr=args.adapt_lr) 460 | 461 | model_adapt.train() 462 | 463 | # Partial freeze 464 | for name, param in model_adapt.named_parameters(): 465 | #freeze_mlp 466 | if name.split('.')[0] in module_names_mlp: 467 | param.requires_grad = False 468 | #freeze_comp: 469 | if any(mona in name.split('.')[1] for mona in module_names_comp): 470 | param.requires_grad = False 471 | 472 | for idx_adapt in range(args.adapt_iter): 473 | 474 | logit = model_adapt(feat_curr, grid_curr, val_batch_size, frame_curr['grid_ind'], use_tta=False) # (B,C,x,y,z) 475 | loss_adapt_occ_ce = 0 476 | loss_adapt_occ_lovasz = 0 477 | 
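# ------------------------------------------------------------------
# [editor's sketch, not part of the original file] The occupancy loss a few
# lines below collapses the 20-way logits into a binary empty/occupied pair:
# channel 0 is 'empty', and the occupied logit is the max over all semantic
# channels, so the binary CE/Lovasz losses can consume the LoS pseudo-GT.
import torch

logit = torch.randn(1, 20, 2, 2, 2)                         # (B, C, x, y, z)
logit_empty = logit[:, :1]                                  # (B, 1, x, y, z)
logit_occupied = logit[:, 1:].max(dim=1, keepdim=True)[0]   # (B, 1, x, y, z)
logit_comp = torch.cat((logit_empty, logit_occupied), dim=1)
assert logit_comp.shape == (1, 2, 2, 2, 2)
# ------------------------------------------------------------------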
                loss_adapt_pgt_ce = 0
                loss_adapt_pgt_lovasz = 0

                if args.use_los and flag_adapt_aux_exist:
                    # Collapse the class channels into a binary empty/occupied pair.
                    logit_empty = logit[:, :1, :, :, :]
                    logit_occupied = logit[:, 1:, :, :, :].max(dim=1, keepdim=True)[0]
                    logit_comp = torch.cat((logit_empty, logit_occupied), dim=1)  # (B,2,x,y,z)
                    loss_adapt_occ_ce += loss_fn_ce_binary(logit_comp, voxel_los_adapt)
                    loss_adapt_occ_lovasz += loss_fn_lovasz_binary(F.softmax(logit_comp, dim=1), voxel_los_adapt, ignore=255)

                if args.use_pgt:
                    loss_adapt_pgt_ce += loss_fn_ce(logit, voxel_pgt_all_adapt)
                    loss_adapt_pgt_lovasz += loss_fn_lovasz(F.softmax(logit, dim=1), voxel_pgt_all_adapt, ignore=255)

                optimizer_adapt.zero_grad()
                loss_adapt = args.weight_adapt_occ_ce * loss_adapt_occ_ce + args.weight_adapt_occ_lovasz * loss_adapt_occ_lovasz \
                    + args.weight_adapt_pgt_ce * loss_adapt_pgt_ce + args.weight_adapt_pgt_lovasz * loss_adapt_pgt_lovasz
                if loss_adapt != 0 and not torch.isnan(loss_adapt):
                    loss_adapt.backward()
                    optimizer_adapt.step()
                else:
                    print('Loss is zero or NaN! Skipping adapt optimization.')

                # Log the adapt losses.
                for loss_name in loss_adapt_names:
                    loss_now = locals()[loss_name]
                    writer.add_scalar("loss_adapt/" + loss_name, loss_now, global_step=args.adapt_iter * idx_test + idx_adapt)

        else:
            print("Skipping scan-wise adaptation.")

        #########################################################################################################################################
        ############################################################## Eval phase ###############################################################
        #########################################################################################################################################

        model_cont.eval()
        model_adapt.eval()
        with torch.no_grad():
            pred_logit = model_cont(feat_curr, grid_curr, val_batch_size, frame_curr['grid_ind'], use_tta=False, extraction=proposure_idx)
            pred = torch.argmax(pred_logit, dim=1)

            pred_logit_bs = model_adapt(feat_curr, grid_curr, val_batch_size, frame_curr['grid_ind'], use_tta=False)
            pred_bs = torch.argmax(pred_logit_bs, dim=1)

            # Keep the continually adapted model's output for classes 9-13 and 15-17 (static scene
            # classes in the SemanticKITTI mapping); elsewhere, fall back to the scan-wise model.
            mask_pred_ = (pred == 9) + (pred == 10) + (pred == 11) + (pred == 12) + (pred == 13) + (pred == 15) + (pred == 16) + (pred == 17)
            pred[~mask_pred_] = pred_bs[~mask_pred_]

            pred = pred.cpu().detach().numpy()
            pred = np.squeeze(pred)
            pred = pred.astype(np.uint32)
            pred = pred.reshape(-1)

            ### Save the prediction after remapping to the original label IDs.
            pred_remapped = remapping(pred, remap=remap_first)
            name_velodyne = test_pt_dataset.im_idx[frame_curr['index'][0]]
            save_pred(pred_remapped, name_velodyne, exp_path)
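        # Predictions are written out immediately: in this online setting each scan is evaluated
        # with the models as adapted up to that point, before any further update on later scans.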
        #########################################################################################################################################
        ######################################################## Continual TTA phase ############################################################
        #########################################################################################################################################

        if args.do_cont:

            print("Running continual adaptation.")

            model_cont.train()
            # Partial freeze: keep the MLP, completion, and logit modules fixed.
            for name, param in model_cont.named_parameters():
                if name.split('.')[0] in module_names_mlp:  # freeze MLP
                    param.requires_grad = False
                if any(mona in name.split('.')[1] for mona in module_names_comp):  # freeze completion
                    param.requires_grad = False
                if any(mona in name.split('.')[1] for mona in module_names_logit):  # freeze logit
                    param.requires_grad = False

            # Continual TTA loop
            for idx_cont in range(args.cont_iter):

                # Random masking: mark a random 10% of 4x4x4 voxel patches of the 256x256x32 grid.
                mask_size = 4
                mask_ratio = 0.1
                upsample = nn.Upsample(scale_factor=mask_size, mode='nearest')
                mask_voxel = torch.zeros(int(256 / mask_size), int(256 / mask_size), int(32 / mask_size))
                mask_voxel = mask_voxel.reshape(-1)
                rand_idx = torch.randperm(mask_voxel.shape[0])
                mask_number = int(rand_idx.shape[0] * mask_ratio)
                mask_patch = rand_idx[:mask_number]
                mask_voxel[mask_patch] = 1
                mask_voxel = mask_voxel.reshape(int(256 / mask_size), int(256 / mask_size), int(32 / mask_size))
                mask_voxel = upsample(mask_voxel.unsqueeze(0).unsqueeze(0)).squeeze()
                # Occupancy grid of the current scan, built on CPU with CPU indices.
                grid_curr_cpu = grid_curr[0].cpu()
                voxel_curr_grid = torch.zeros([256, 256, 32])
                voxel_curr_grid[grid_curr_cpu[:, 0], grid_curr_cpu[:, 1], grid_curr_cpu[:, 2]] = 1
                masked_voxel = voxel_curr_grid * mask_voxel
                masked_coords = masked_voxel.nonzero()
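                # The masked voxels are removed from the network input below and requested again
                # through the `extraction` argument, so the model has to predict them from the
                # retained context (a masked-modeling-style supervision signal).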
                # Map the masked voxel coordinates back to point indices; operands are kept on
                # grid_curr's device to avoid mixed CPU/GPU indexing.
                matches = torch.nonzero((grid_curr[0][:, None] == masked_coords.to(grid_curr[0].device)).all(-1), as_tuple=True)
                masked_idx = matches[0]
                all_indices = torch.arange(grid_curr[0].shape[0], device=masked_idx.device)
                retain_idx = all_indices[~torch.isin(all_indices, masked_idx)]
                feat_curr_retain = [feat_curr[0][retain_idx]]
                grid_curr_retain = [grid_curr[0][retain_idx]]
                frame_curr_grid_ind_retain = [frame_curr['grid_ind'][0][retain_idx.cpu().numpy()]]
                grid_curr_masked = grid_curr[0][masked_idx]
                grid_curr_masked = torch.cat([grid_curr_masked, proposure_idx], 0)
                logit = model_cont(feat_curr_retain, grid_curr_retain, val_batch_size, frame_curr_grid_ind_retain, use_tta=False, extraction=grid_curr_masked)  # (B,C,x,y,z)
                pred = torch.argmax(logit, dim=1)  # (B,x,y,z)

                loss_cont_occ_ce = 0
                loss_cont_occ_lovasz = 0
                loss_cont_pgt_ce = 0
                loss_cont_pgt_lovasz = 0

                if args.use_los and flag_cont_aux_exist:
                    logit_empty = logit[:, :1, :, :, :]
                    logit_occupied = logit[:, 1:, :, :, :].max(dim=1, keepdim=True)[0]
                    logit_comp = torch.cat((logit_empty, logit_occupied), dim=1)  # (B,2,x,y,z)
                    loss_cont_occ_ce += loss_fn_ce_binary(logit_comp, voxel_los_cont)
                    loss_cont_occ_lovasz += loss_fn_lovasz_binary(F.softmax(logit_comp, dim=1), voxel_los_cont, ignore=255)

                if args.use_pgt:
                    loss_cont_pgt_ce += loss_fn_ce(logit, voxel_pgt_all_cont)
                    loss_cont_pgt_lovasz += loss_fn_lovasz(F.softmax(logit, dim=1), voxel_pgt_all_cont, ignore=255)

                optimizer_cont.zero_grad()
                loss_cont = args.weight_cont_occ_ce * loss_cont_occ_ce + args.weight_cont_occ_lovasz * loss_cont_occ_lovasz \
                    + args.weight_cont_pgt_ce * loss_cont_pgt_ce + args.weight_cont_pgt_lovasz * loss_cont_pgt_lovasz
                if loss_cont != 0 and not torch.isnan(loss_cont):
                    loss_cont.backward()
                    optimizer_cont.step()
                else:
                    print('Loss is zero or NaN! Skipping cont optimization.')

                # Log the continual-TTA losses.
                for loss_name in loss_cont_names:
                    loss_now = locals()[loss_name]
                    writer.add_scalar("loss_cont/" + loss_name, loss_now, global_step=args.cont_iter * idx_test + idx_cont)
        else:
            print("Skipping continual adaptation.")


if __name__ == '__main__':
    # Runtime settings
    parser = argparse.ArgumentParser(description='')

    # Sources
    parser.add_argument('--talos_root', default='./', type=str)
    parser.add_argument('--config_path', default='config/semantickitti-tta.yaml')
    parser.add_argument('--baseline_perf_txt', default='baseline_performance.txt', type=str)
    parser.add_argument('--baseline_preds', default='experiments/baseline/sequences/08/predictions', type=str)
    parser.add_argument('--sq_num', default='8', type=str)
    # Experiment
    parser.add_argument('--name', default='debug')
    parser.add_argument('--ang', action='store_true', help='Auto Name Generator')
    parser.add_argument('--loader', default='data_builder', type=str)
    parser.add_argument('--stride', default='[-5,5]', type=str)
    parser.add_argument('--do_cont', action='store_true')
    parser.add_argument('--do_adapt', action='store_true')

    # Attributes
    parser.add_argument('--use_los', action='store_true')
    parser.add_argument('--use_pgt', action='store_true')
    parser.add_argument('--th_pgt_occupied', default=0.75, type=float)
    parser.add_argument('--th_pgt_empty', default=0.999, type=float)

    # Optimization (cont)
    parser.add_argument('--cont_lr', default=3e-05, type=float)
    parser.add_argument('--cont_iter', default=1, type=int)
    parser.add_argument('--weight_cont_occ_ce', default=1, type=float)
    parser.add_argument('--weight_cont_occ_lovasz', default=1, type=float)
    parser.add_argument('--weight_cont_pgt_ce', default=1, type=float)
    parser.add_argument('--weight_cont_pgt_lovasz', default=1, type=float)

    # Optimization (adapt)
    parser.add_argument('--adapt_lr', default=0.0003, type=float)
    parser.add_argument('--adapt_iter', default=3, type=int)
    parser.add_argument('--weight_adapt_occ_ce', default=1, type=float)
    parser.add_argument('--weight_adapt_occ_lovasz', default=1, type=float)
    parser.add_argument('--weight_adapt_pgt_ce', default=1, type=float)
    parser.add_argument('--weight_adapt_pgt_lovasz', default=1, type=float)

    args = parser.parse_args()

    args.baseline_perf_txt = args.talos_root + args.baseline_perf_txt
    args.baseline_preds = args.talos_root + args.baseline_preds

    print(' '.join(sys.argv))
    print(args)
    print('#####')
    print('Stride: ' + str(args.stride))
    print('#####')

    # Auto Name Generator: derive the experiment name from the active options.
    if args.ang:
        args.name = ''
        args.name += '_stride' + str(args.stride)

        if args.use_los:
            args.name += '_los'
        if args.use_pgt:
            args.name += '_pgt'
        if args.th_pgt_occupied != 0.75:
            args.name += '_thocc' + str(args.th_pgt_occupied)
        if args.th_pgt_empty != 0.999:
            args.name += '_themp' + str(args.th_pgt_empty)

        if args.do_cont:
            args.name += '_cont'
            args.name += '_clr' + str(args.cont_lr)
            args.name += '_cit' + str(args.cont_iter)

        if args.do_adapt:
            args.name += '_adapt'
            args.name += '_alr' + str(args.adapt_lr)
            args.name += '_ait' + str(args.adapt_iter)

    print(args.name)
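    # Example: "python run_tta_test.py --ang --use_los --use_pgt --do_adapt --do_cont" with the
    # default hyper-parameters yields
    # args.name == '_stride[-5,5]_los_pgt_cont_clr3e-05_cit1_adapt_alr0.0003_ait3'.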
    main(args)
--------------------------------------------------------------------------------