├── assets
│   ├── benchmark.png
│   ├── kitti_qual.png
│   └── qual_kitti.png
├── utils
│   ├── __init__.py
│   ├── softmax_entropy.py
│   ├── log_util.py
│   ├── metric_util.py
│   ├── cmap_plot.py
│   ├── util.py
│   ├── np_ioueval.py
│   ├── load_save_util.py
│   └── lovasz_losses.py
├── config
│   ├── __init__.py
│   ├── semantickitti-tta.yaml
│   ├── semantickitti-multiscan.yaml
│   ├── semantickitti-tta-val.yaml
│   ├── semantickitti.yaml
│   ├── config.py
│   └── label_mapping
│       ├── semantic-poss-multiscan.yaml
│       ├── semantic-kitti.yaml
│       ├── semantic-kitti-multiscan.yaml
│       └── semantic-kitti-all.yaml
├── network
│   ├── __init__.py
│   ├── cylinder_spconv_3d_unlock.py
│   ├── cylinder_fea_generator.py
│   ├── conv_base.py
│   └── segmentator_3d_asymm_spconv_unlock.py
├── dataloader
│   ├── __init__.py
│   ├── pc_dataset.py
│   ├── pc_dataset_test.py
│   ├── dataset_semantickitti.py
│   └── dataset_semantickitti_test.py
├── requirements.txt
├── README.md
└── run_tta_test.py
/assets/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blue-531/TALoS/HEAD/assets/benchmark.png
--------------------------------------------------------------------------------
/assets/kitti_qual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blue-531/TALoS/HEAD/assets/kitti_qual.png
--------------------------------------------------------------------------------
/assets/qual_kitti.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blue-531/TALoS/HEAD/assets/qual_kitti.png
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: __init__.py
4 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: __init__.py
4 |
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: __init__.py
4 |
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: __init__.py
4 |
5 | # from . import dataset_nuscenes
--------------------------------------------------------------------------------
/utils/softmax_entropy.py:
--------------------------------------------------------------------------------
1 | import torch  # x is a torch.Tensor; .softmax/.log_softmax are tensor methods
2 |
3 |
4 | def softmax_entropy(x):
5 | """Entropy of softmax distribution from logits."""
6 | return -(x.softmax(1) * x.log_softmax(1)).sum(1)
--------------------------------------------------------------------------------
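A minimal usage sketch (assuming a PyTorch logits tensor with the class dimension at index 1, matching this repo's [batch, num_class, ...] outputs):

```
import torch
from utils.softmax_entropy import softmax_entropy

logits = torch.randn(4, 20, 8)   # e.g. [batch, num_class, voxels]
ent = softmax_entropy(logits)    # entropy taken over the class dimension
print(ent.shape)                 # torch.Size([4, 8])
```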
/utils/log_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: log_util.py
4 |
5 |
6 | def save_to_log(logdir, logfile, message):
7 |     with open(logdir + '/' + logfile, "a") as f:
8 |         f.write(message + '\n')
--------------------------------------------------------------------------------
/utils/metric_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: metric_util.py
4 |
5 | import numpy as np
6 |
7 |
8 | def fast_hist(pred, label, n):
9 | k = (label >= 0) & (label < n)
10 | bin_count = np.bincount(
11 | n * label[k].astype(int) + pred[k], minlength=n ** 2)
12 | return bin_count[:n ** 2].reshape(n, n)
13 |
14 |
15 | def per_class_iu(hist):
16 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
17 |
18 |
19 | def fast_hist_crop(output, target, unique_label):
20 | hist = fast_hist(output.flatten(), target.flatten(), np.max(unique_label) + 2)
21 | hist = hist[unique_label + 1, :]
22 | hist = hist[:, unique_label + 1]
23 | return hist
24 |
--------------------------------------------------------------------------------
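A minimal sketch of how these helpers chain together (synthetic data): labels are expected in 1..N while unique_label lists 0..N-1, matching the offset-by-one crop in fast_hist_crop above:

```
import numpy as np
from utils.metric_util import fast_hist_crop, per_class_iu

unique_label = np.array([0, 1, 2])   # evaluated train IDs
pred = np.array([1, 2, 3, 3])        # raw labels = train ID + 1
target = np.array([1, 2, 2, 3])
hist = fast_hist_crop(pred, target, unique_label)
print(per_class_iu(hist))            # [1.  0.5 0.5]
```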
/config/semantickitti-tta.yaml:
--------------------------------------------------------------------------------
1 | # Config format schema number
2 | format_version: 4
3 |
4 | ###################
5 | ## Model options
6 | model_params:
7 | model_architecture: "cylinder_asym"
8 |
9 | output_shape:
10 | - 256
11 | - 256
12 | - 32
13 |
14 | fea_dim: 7
15 | out_fea_dim: 256
16 | num_class: 20
17 | num_input_features: 32
18 | use_norm: True
19 | init_size: 32
20 |
21 |
22 | ###################
23 | ## Dataset options
24 | dataset_params:
25 | dataset_type: "voxel_dataset"
26 | pc_dataset_type: "SemKITTI_sk_multiscan"
27 | ignore_label: 255
28 | return_test: True
29 | fixed_volume_space: True
30 | label_mapping: "./config/label_mapping/semantic-kitti-multiscan.yaml"
31 | max_volume_space:
32 | - 51.2
33 | - 25.6
34 | - 4.4
35 | min_volume_space:
36 | - 0
37 | - -25.6
38 | - -2
39 |
40 |
41 | ###################
42 | ## Data_loader options
43 | train_data_loader:
44 | data_path: "./dataset/sequences"
45 | imageset: "train"
46 | return_ref: True
47 | batch_size: 2
48 | shuffle: True
49 | num_workers: 0
50 |
51 | val_data_loader:
52 | data_path: "./dataset/sequences"
53 | # imageset: "val"
54 | imageset: "test"
55 | return_ref: True
56 | batch_size: 1 #2
57 | shuffle: False
58 | num_workers: 0
59 |
60 |
61 |
62 |
63 | ###################
64 | ## Train params
65 | train_params:
66 | model_load_path: "./model_load_dir/"
67 | model_save_path: "./model_load_dir/"
68 | checkpoint_every_n_steps: 4599
69 | max_num_epochs: 40
70 | eval_every_n_steps: 1917
71 | learning_rate: 0.0015
72 |
--------------------------------------------------------------------------------
/config/semantickitti-multiscan.yaml:
--------------------------------------------------------------------------------
1 | # Config format schema number
2 | format_version: 4
3 |
4 | ###################
5 | ## Model options
6 | model_params:
7 | model_architecture: "cylinder_asym"
8 |
9 | output_shape:
10 | - 256
11 | - 256
12 | - 32
13 |
14 | fea_dim: 7
15 | out_fea_dim: 256
16 | num_class: 20
17 | num_input_features: 32
18 | use_norm: True
19 | init_size: 32
20 |
21 |
22 | ###################
23 | ## Dataset options
24 | dataset_params:
25 | dataset_type: "voxel_dataset"
26 | pc_dataset_type: "SemKITTI_sk_multiscan"
27 | ignore_label: 255
28 | return_test: True
29 | fixed_volume_space: True
30 | label_mapping: "./config/label_mapping/semantic-kitti-multiscan.yaml"
31 | max_volume_space:
32 | - 51.2
33 | - 25.6
34 | - 4.4
35 | min_volume_space:
36 | - 0
37 | - -25.6
38 | - -2
39 |
40 |
41 | ###################
42 | ## Data_loader options
43 | train_data_loader:
44 | data_path: "./dataset/sequences"
45 | imageset: "train"
46 | return_ref: True
47 | batch_size: 2
48 | shuffle: True
49 | num_workers: 0
50 |
51 | val_data_loader:
52 | data_path: "./dataset/sequences"
53 | imageset: "val"
54 | # imageset: "test"
55 | return_ref: True
56 | batch_size: 1 #2
57 | shuffle: False
58 | num_workers: 0
59 |
60 |
61 |
62 |
63 | ###################
64 | ## Train params
65 | train_params:
66 | model_load_path: "./model_load_dir/"
67 | model_save_path: "./model_load_dir/"
68 | checkpoint_every_n_steps: 4599
69 | max_num_epochs: 40
70 | eval_every_n_steps: 1917
71 | learning_rate: 0.0015
72 |
--------------------------------------------------------------------------------
/config/semantickitti-tta-val.yaml:
--------------------------------------------------------------------------------
1 | # Config format schema number
2 | format_version: 4
3 |
4 | ###################
5 | ## Model options
6 | model_params:
7 | model_architecture: "cylinder_asym"
8 |
9 | output_shape:
10 | - 256
11 | - 256
12 | - 32
13 |
14 | fea_dim: 7
15 | out_fea_dim: 256
16 | num_class: 20
17 | num_input_features: 32
18 | use_norm: True
19 | init_size: 32
20 |
21 |
22 | ###################
23 | ## Dataset options
24 | dataset_params:
25 | dataset_type: "voxel_dataset"
26 | pc_dataset_type: "SemKITTI_sk_multiscan"
27 | ignore_label: 255
28 | return_test: True
29 | fixed_volume_space: True
30 | label_mapping: "./config/label_mapping/semantic-kitti-multiscan.yaml"
31 | max_volume_space:
32 | - 51.2
33 | - 25.6
34 | - 4.4
35 | min_volume_space:
36 | - 0
37 | - -25.6
38 | - -2
39 |
40 |
41 | ###################
42 | ## Data_loader options
43 | train_data_loader:
44 | data_path: "./dataset/sequences"
45 | imageset: "train"
46 | return_ref: True
47 | batch_size: 2
48 | shuffle: True
49 | num_workers: 0
50 |
51 | val_data_loader:
52 | data_path: "./dataset/sequences"
53 | imageset: "val"
54 | # imageset: "test"
55 | return_ref: True
56 | batch_size: 1 #2
57 | shuffle: False
58 | num_workers: 0
59 |
60 |
61 |
62 |
63 | ###################
64 | ## Train params
65 | train_params:
66 | model_load_path: "./model_load_dir/"
67 | model_save_path: "./model_load_dir/"
68 | checkpoint_every_n_steps: 4599
69 | max_num_epochs: 40
70 | eval_every_n_steps: 1917
71 | learning_rate: 0.0015
72 |
--------------------------------------------------------------------------------
/config/semantickitti.yaml:
--------------------------------------------------------------------------------
1 | # Config format schema number
2 | format_version: 4
3 |
4 | ###################
5 | ## Model options
6 | model_params:
7 | model_architecture: "cylinder_asym"
8 |
9 | output_shape:
10 | - 480
11 | - 360
12 | - 32
13 |
14 | fea_dim: 9
15 | out_fea_dim: 256
16 | num_class: 20
17 | num_input_features: 32 #16
18 | use_norm: True
19 | init_size: 32 #16
20 |
21 |
22 | ###################
23 | ## Dataset options
24 | dataset_params:
25 | dataset_type: "cylinder_dataset"
26 | pc_dataset_type: "SemKITTI_sk"
27 | ignore_label: 0
28 | return_test: False
29 | fixed_volume_space: True
30 | label_mapping: "./config/label_mapping/semantic-kitti.yaml"
31 | max_volume_space:
32 | - 50
33 | - 3.1415926
34 | - 2
35 | min_volume_space:
36 | - 0
37 | - -3.1415926
38 | - -4
39 |
40 |
41 | ###################
42 | ## Data_loader options
43 | train_data_loader:
44 | data_path: "/nvme/yuenan/semantickitti_dataset/sequences/"
45 | imageset: "train"
46 | return_ref: True
47 | batch_size: 2
48 | shuffle: True
49 | num_workers: 12 #4
50 |
51 | val_data_loader:
52 | data_path: "/nvme/yuenan/semantickitti_dataset/sequences/"
53 | imageset: "test" #"val"
54 | return_ref: True
55 | batch_size: 1
56 | shuffle: False
57 | num_workers: 12 #4
58 |
59 |
60 | ###################
61 | ## Train params
62 | train_params:
63 | model_load_path: "./model_load_dir/model_full_ft.pt"
64 | model_save_path: "./model_save_dir/model_tmp.pt"
65 | checkpoint_every_n_steps: 4599
66 | max_num_epochs: 20 #40
67 | eval_every_n_steps: 5000 #4599
68 | learning_rate: 0.002 #1
69 |
--------------------------------------------------------------------------------
/utils/cmap_plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import itertools
3 | import io
4 | import numpy as np
5 |
6 |
7 | def label_2_cmap(target, pred):
8 | pred_counts=np.zeros((20,20))
9 | for i in range(20):
10 | pred_i=np.where(target==i,pred,255)
11 | pred_id,counts=np.unique(pred_i,return_counts=True)
12 |
13 | for idx,pred_category in enumerate(pred_id):
14 | if pred_category==255:
15 | continue
16 | pred_counts[i,pred_category]=counts[idx]
17 | return pred_counts
18 |
19 | def plot_confusion_matrix(cm, class_names =["unlabeled","car","bicycle","motorcycle","truck","other-vehicle","person","bicyclist","motorcyclist","road","parking","sidewalk","other-ground","building","fence","vegetation","trunk","terrain","pole","traffic-sign"]):
20 | """
21 | Returns a matplotlib figure containing the plotted confusion matrix.
22 |
23 | Args:
24 | cm (array, shape = [n, n]): a confusion matrix of integer classes
25 | class_names (array, shape = [n]): String names of the integer classes
26 | """
27 | cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=4)
28 |
29 | figure = plt.figure(figsize=(15, 15))
30 | plt.imshow(cm, cmap=plt.cm.Blues)
31 | plt.title("Confusion matrix")
32 | plt.colorbar()
33 | tick_marks = np.arange(len(class_names))
34 | plt.xticks(tick_marks, class_names, rotation=45)
35 | plt.yticks(tick_marks, class_names)
36 |
37 |     # cm was already row-normalized above, so its values serve directly as the cell labels.
38 |     labels = cm
39 |
40 | # Use white text if squares are dark; otherwise black.
41 | threshold = cm.max() / 2.
42 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
43 | color = "white" if cm[i, j] > threshold else "black"
44 | plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)
45 |
46 | plt.tight_layout()
47 | plt.ylabel('True label')
48 | plt.xlabel('Predicted label')
49 | # plt.savefig(f'{fig_name}.png', dpi=300)
50 | buf = io.BytesIO()
51 | plt.savefig(buf, format='png')
52 | plt.close(figure)
53 | buf.seek(0)
54 |
55 | return buf
56 |
57 |
--------------------------------------------------------------------------------
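A minimal sketch chaining the two helpers above (synthetic 20-class data; the returned buffer holds PNG bytes suitable for tensorboardX or saving to disk):

```
import numpy as np
from utils.cmap_plot import label_2_cmap, plot_confusion_matrix

target = np.random.randint(0, 20, size=10000)
pred = np.random.randint(0, 20, size=10000)
cm = label_2_cmap(target, pred)   # 20x20 count matrix, rows = true class
buf = plot_confusion_matrix(cm)   # BytesIO holding the rendered PNG
with open("confusion_matrix.png", "wb") as f:
    f.write(buf.read())
```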
/network/cylinder_spconv_3d_unlock.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: cylinder_spconv_3d.py
4 |
5 | from torch import nn
6 | import torch
7 |
8 | REGISTERED_MODELS_CLASSES = {}
9 |
10 |
11 | def register_model(cls, name=None):
12 | global REGISTERED_MODELS_CLASSES
13 | if name is None:
14 | name = cls.__name__
15 |     assert name not in REGISTERED_MODELS_CLASSES, f"class already registered: {name}"
16 | REGISTERED_MODELS_CLASSES[name] = cls
17 | return cls
18 |
19 |
20 | def get_model_class(name):
21 | global REGISTERED_MODELS_CLASSES
22 |     assert name in REGISTERED_MODELS_CLASSES, f"available classes: {list(REGISTERED_MODELS_CLASSES)}"
23 | return REGISTERED_MODELS_CLASSES[name]
24 |
25 |
26 | @register_model
27 | class cylinder_asym(nn.Module):
28 | def __init__(self,
29 | cylin_model,
30 | segmentator_spconv,
31 | sparse_shape,
32 | ):
33 | super().__init__()
34 | self.name = "cylinder_asym"
35 |
36 | self.cylinder_3d_generator = cylin_model #cylinder_fea_generator
37 |
38 | self.cylinder_3d_spconv_seg = segmentator_spconv #Asymm_3d_spconv
39 |
40 | self.sparse_shape = sparse_shape
41 |
42 | def forward(self, train_pt_fea_ten, train_vox_ten, batch_size, val_grid=None, voting_num=4, use_tta=False, extraction='all'):
43 | coords, features_3d = self.cylinder_3d_generator(train_pt_fea_ten, train_vox_ten)
44 | # train_pt_fea_ten: [batch_size, N1, 7]
45 | # train_vox_ten: [batch_size, N1, 3]
46 |
47 | if use_tta:
48 | batch_size *= voting_num
49 |
50 | spatial_features = self.cylinder_3d_spconv_seg(features_3d, coords, batch_size, extraction) # [batch_size, 20, 256, 256, 32]
51 |
52 | if use_tta:
53 | fused_predict = spatial_features[0, :]
54 | for idx in range(1, voting_num, 1):
55 | aug_predict = spatial_features[idx, :]
56 | aug_predict = torch.flip(aug_predict, dims=[2])
57 | fused_predict += aug_predict
58 | return torch.unsqueeze(fused_predict, 0)
59 | else:
60 | return spatial_features
61 |
--------------------------------------------------------------------------------
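A minimal sketch of the registry above; the constructor arguments are placeholders here, since the real sub-modules are built elsewhere in the repo:

```
from network.cylinder_spconv_3d_unlock import get_model_class

model_cls = get_model_class("cylinder_asym")   # registered via @register_model
# model = model_cls(cylin_model=..., segmentator_spconv=..., sparse_shape=[256, 256, 32])
```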
/utils/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def Bresenham3D(start_point, end_points, line_end):  # line_end is unused; kept for call-site compatibility
4 | ListOfPoints = []
5 |
6 | for end_point in end_points:
7 | x1, y1, z1=start_point[0]
8 | x2, y2, z2=end_point
9 | ListOfPoints.append((x1, y1, z1))
10 | dx = abs(x2 - x1)
11 | dy = abs(y2 - y1)
12 | dz = abs(z2 - z1)
13 | if (x2 > x1):
14 | xs = 1
15 | else:
16 | xs = -1
17 | if (y2 > y1):
18 | ys = 1
19 | else:
20 | ys = -1
21 | if (z2 > z1):
22 | zs = 1
23 | else:
24 | zs = -1
25 | # Driving axis is X-axis"
26 | if (dx >= dy and dx >= dz):
27 | p1 = 2 * dy - dx
28 | p2 = 2 * dz - dx
29 | while (x1 != x2):
30 | x1 += xs
31 | if (p1 >= 0):
32 | y1 += ys
33 | p1 -= 2 * dx
34 | if (p2 >= 0):
35 | z1 += zs
36 | p2 -= 2 * dx
37 | p1 += 2 * dy
38 | p2 += 2 * dz
39 |
40 | ListOfPoints.append((x1, y1, z1))
41 |
42 | # Driving axis is Y-axis"
43 | elif (dy >= dx and dy >= dz):
44 | p1 = 2 * dx - dy
45 | p2 = 2 * dz - dy
46 | while (y1 != y2):
47 | y1 += ys
48 | if (p1 >= 0):
49 | x1 += xs
50 | p1 -= 2 * dy
51 | if (p2 >= 0):
52 | z1 += zs
53 | p2 -= 2 * dy
54 | p1 += 2 * dx
55 | p2 += 2 * dz
56 |
57 | ListOfPoints.append((x1, y1, z1))
58 |
59 | # Driving axis is Z-axis"
60 | else:
61 | p1 = 2 * dy - dz
62 | p2 = 2 * dx - dz
63 | while (z1 != z2):
64 | z1 += zs
65 | if (p1 >= 0):
66 | y1 += ys
67 | p1 -= 2 * dz
68 | if (p2 >= 0):
69 | x1 += xs
70 | p2 -= 2 * dz
71 | p1 += 2 * dy
72 | p2 += 2 * dx
73 |
74 | ListOfPoints.append((x1, y1, z1))
75 | return ListOfPoints
76 |
77 |
--------------------------------------------------------------------------------
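A minimal sketch of Bresenham3D: rasterizing the voxels crossed on the line of sight from the sensor cell to each observed end point. Note that start_point is indexed as start_point[0] inside the function, so it is wrapped in a list, and the unused line_end argument is passed as None:

```
from utils.util import Bresenham3D

origin = [(0, 0, 0)]                  # sensor voxel, wrapped once
end_points = [(5, 3, 1), (2, 7, 4)]   # observed occupied voxels
visited = Bresenham3D(origin, end_points, None)
print(visited)                        # integer voxel indices along both rays
```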
/network/cylinder_fea_generator.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch_scatter
8 |
9 |
10 | class cylinder_fea(nn.Module):
11 |
12 | def __init__(self, grid_size, fea_dim=3,
13 | out_pt_fea_dim=64, max_pt_per_encode=64, fea_compre=None):
14 | super(cylinder_fea, self).__init__()
15 |
16 | self.PPmodel = nn.Sequential(
17 | nn.BatchNorm1d(fea_dim),
18 |
19 | nn.Linear(fea_dim, 64),
20 | nn.BatchNorm1d(64),
21 | nn.ReLU(),
22 |
23 | nn.Linear(64, 128),
24 | nn.BatchNorm1d(128),
25 | nn.ReLU(),
26 |
27 | nn.Linear(128, 256),
28 | nn.BatchNorm1d(256),
29 | nn.ReLU(),
30 |
31 | nn.Linear(256, out_pt_fea_dim)
32 | )
33 |
34 | self.max_pt = max_pt_per_encode
35 | self.fea_compre = fea_compre
36 | self.grid_size = grid_size
37 | kernel_size = 3
38 | self.local_pool_op = torch.nn.MaxPool2d(kernel_size, stride=1,
39 | padding=(kernel_size - 1) // 2,
40 | dilation=1)
41 | self.pool_dim = out_pt_fea_dim
42 |
43 | # point feature compression
44 | if self.fea_compre is not None:
45 | self.fea_compression = nn.Sequential(
46 | nn.Linear(self.pool_dim, self.fea_compre),
47 | nn.ReLU())
48 | self.pt_fea_dim = self.fea_compre
49 | else:
50 | self.pt_fea_dim = self.pool_dim
51 |
52 | def forward(self, pt_fea, xy_ind):
53 | cur_dev = pt_fea[0].get_device()
54 |
55 | ### concate everything
56 | cat_pt_ind = []
57 | for i_batch in range(len(xy_ind)):
58 | cat_pt_ind.append(F.pad(xy_ind[i_batch], (1, 0), 'constant', value=i_batch))
59 |
60 | cat_pt_fea = torch.cat(pt_fea, dim=0)
61 | cat_pt_ind = torch.cat(cat_pt_ind, dim=0)
62 | pt_num = cat_pt_ind.shape[0]
63 |
64 | ### shuffle the data
65 | shuffled_ind = torch.randperm(pt_num, device=cur_dev)
66 | cat_pt_fea = cat_pt_fea[shuffled_ind, :]
67 | cat_pt_ind = cat_pt_ind[shuffled_ind, :]
68 |
69 | ### unique xy grid index
70 | unq, unq_inv, unq_cnt = torch.unique(cat_pt_ind, return_inverse=True, return_counts=True, dim=0)
71 | unq = unq.type(torch.int64)
72 |
73 | ### process feature
74 | processed_cat_pt_fea = self.PPmodel(cat_pt_fea)
75 | pooled_data = torch_scatter.scatter_max(processed_cat_pt_fea, unq_inv, dim=0)[0]
76 |
77 | if self.fea_compre:
78 | processed_pooled_data = self.fea_compression(pooled_data)
79 | else:
80 | processed_pooled_data = pooled_data
81 |
82 | return unq, processed_pooled_data
83 |
--------------------------------------------------------------------------------
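A toy sketch of the voxel-wise max pooling at the heart of forward(): features of points that fall into the same unique (batch, r, theta, z) cell are reduced with torch_scatter.scatter_max:

```
import torch
import torch_scatter

feats = torch.tensor([[1.0, 5.0], [3.0, 2.0], [0.0, 9.0]])
cell_ids = torch.tensor([0, 0, 1])   # first two points share a voxel
pooled, _ = torch_scatter.scatter_max(feats, cell_ids, dim=0)
print(pooled)                        # tensor([[3., 5.], [0., 9.]])
```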
/utils/np_ioueval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # This file is covered by the LICENSE file in the root of this project.
3 |
4 | import sys
5 | import numpy as np
6 |
7 |
8 | class iouEval:
9 | def __init__(self, n_classes, ignore=None):
10 | # classes
11 | self.n_classes = n_classes
12 |
13 |         # What to include and ignore from the means
14 |         self.ignore = np.array(ignore if ignore is not None else [], dtype=np.int64)
15 | self.include = np.array(
16 | [n for n in range(self.n_classes) if n not in self.ignore], dtype=np.int64)
17 | # print("[IOU EVAL] IGNORE: ", self.ignore)
18 | # print("[IOU EVAL] INCLUDE: ", self.include)
19 |
20 | # reset the class counters
21 | self.reset()
22 |
23 | def num_classes(self):
24 | return self.n_classes
25 |
26 | def reset(self):
27 | self.conf_matrix = np.zeros((self.n_classes,
28 | self.n_classes),
29 | dtype=np.int64)
30 |
31 | def addBatch(self, x, y): # x=preds, y=targets
32 | # sizes should be matching
33 | x_row = x.reshape(-1) # de-batchify
34 | y_row = y.reshape(-1) # de-batchify
35 |
36 | # check
37 | assert(x_row.shape == y_row.shape)
38 |
39 | # create indexes
40 | idxs = tuple(np.stack((x_row, y_row), axis=0))
41 |
42 | # make confusion matrix (cols = gt, rows = pred)
43 | np.add.at(self.conf_matrix, idxs, 1)
44 |
45 | def getStats(self):
46 | # remove fp from confusion on the ignore classes cols
47 | conf = self.conf_matrix.copy()
48 | conf[:, self.ignore] = 0
49 |
50 | # get the clean stats
51 | tp = np.diag(conf)
52 | fp = conf.sum(axis=1) - tp
53 | fn = conf.sum(axis=0) - tp
54 | return tp, fp, fn
55 |
56 | def getIoU(self):
57 | tp, fp, fn = self.getStats()
58 | intersection = tp
59 | union = tp + fp + fn + 1e-15
60 | iou = intersection / union
61 | iou_mean = (intersection[self.include] / union[self.include]).mean()
62 | return iou_mean, iou # returns "iou mean", "iou per class" ALL CLASSES
63 |
64 | def getacc(self):
65 | tp, fp, fn = self.getStats()
66 | total_tp = tp.sum()
67 | total = tp[self.include].sum() + fp[self.include].sum() + 1e-15
68 | acc_mean = total_tp / total
69 | return acc_mean # returns "acc mean"
70 |
71 | def get_confusion(self):
72 | return self.conf_matrix.copy()
73 |
74 |
75 |
76 | if __name__ == "__main__":
77 | # mock problem
78 | nclasses = 2
79 | ignore = []
80 |
81 | # test with 2 squares and a known IOU
82 | lbl = np.zeros((7, 7), dtype=np.int64)
83 | argmax = np.zeros((7, 7), dtype=np.int64)
84 |
85 | # put squares
86 | lbl[2:4, 2:4] = 1
87 | argmax[3:5, 3:5] = 1
88 |
89 | # make evaluator
90 |     evaluator = iouEval(nclasses, ignore)
91 |
92 |     # run
93 |     evaluator.addBatch(argmax, lbl)
94 |     m_iou, iou = evaluator.getIoU()
95 |     print("IoU: ", m_iou)
96 |     print("IoU class: ", iou)
97 |     m_acc = evaluator.getacc()
98 |     print("Acc: ", m_acc)
99 |
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 |
4 | from pathlib import Path
5 |
6 | from strictyaml import Bool, Float, Int, Map, Seq, Str, as_document, load
7 |
8 | model_params = Map(
9 | {
10 | "model_architecture": Str(),
11 | "output_shape": Seq(Int()),
12 | "fea_dim": Int(),
13 | "out_fea_dim": Int(),
14 | "num_class": Int(),
15 | "num_input_features": Int(),
16 | "use_norm": Bool(),
17 | "init_size": Int(),
18 | }
19 | )
20 |
21 | dataset_params = Map(
22 | {
23 | "dataset_type": Str(),
24 | "pc_dataset_type": Str(),
25 | "ignore_label": Int(),
26 | "return_test": Bool(),
27 | "fixed_volume_space": Bool(),
28 | "label_mapping": Str(),
29 | "max_volume_space": Seq(Float()),
30 | "min_volume_space": Seq(Float()),
31 | }
32 | )
33 |
34 |
35 | train_data_loader = Map(
36 | {
37 | "data_path": Str(),
38 | "imageset": Str(),
39 | "return_ref": Bool(),
40 | "batch_size": Int(),
41 | "shuffle": Bool(),
42 | "num_workers": Int(),
43 | }
44 | )
45 |
46 | val_data_loader = Map(
47 | {
48 | "data_path": Str(),
49 | "imageset": Str(),
50 | "return_ref": Bool(),
51 | "batch_size": Int(),
52 | "shuffle": Bool(),
53 | "num_workers": Int(),
54 | }
55 | )
56 |
57 |
58 | train_params = Map(
59 | {
60 | "model_load_path": Str(),
61 | "model_save_path": Str(),
62 | "checkpoint_every_n_steps": Int(),
63 | "max_num_epochs": Int(),
64 | "eval_every_n_steps": Int(),
65 | "learning_rate": Float()
66 | }
67 | )
68 |
69 | schema_v4 = Map(
70 | {
71 | "format_version": Int(),
72 | "model_params": model_params,
73 | "dataset_params": dataset_params,
74 | "train_data_loader": train_data_loader,
75 | "val_data_loader": val_data_loader,
76 | "train_params": train_params,
77 | }
78 | )
79 |
80 |
81 | SCHEMA_FORMAT_VERSION_TO_SCHEMA = {4: schema_v4}
82 |
83 |
84 | def load_config_data(path: str) -> dict:
85 | yaml_string = Path(path).read_text()
86 | cfg_without_schema = load(yaml_string, schema=None)
87 | schema_version = int(cfg_without_schema["format_version"])
88 | if schema_version not in SCHEMA_FORMAT_VERSION_TO_SCHEMA:
89 | raise Exception(f"Unsupported schema format version: {schema_version}.")
90 |
91 | strict_cfg = load(yaml_string, schema=SCHEMA_FORMAT_VERSION_TO_SCHEMA[schema_version])
92 | cfg: dict = strict_cfg.data
93 | return cfg
94 |
95 |
96 | def config_data_to_config(data): # type: ignore
97 | return as_document(data, schema_v4)
98 |
99 |
100 | def save_config_data(data: dict, path: str) -> None:
101 | cfg_document = config_data_to_config(data)
102 | with open(Path(path), "w") as f:
103 | f.write(cfg_document.as_yaml())
104 |
--------------------------------------------------------------------------------
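A minimal usage sketch: validating and loading one of the shipped configs through the schema above:

```
from config.config import load_config_data

cfg = load_config_data("./config/semantickitti-tta.yaml")
print(cfg["model_params"]["model_architecture"])   # cylinder_asym
print(cfg["train_params"]["learning_rate"])        # 0.0015
```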
/config/label_mapping/semantic-poss-multiscan.yaml:
--------------------------------------------------------------------------------
1 | # This file is covered by the LICENSE file in the root of this project.
2 | labels:
3 | 0: "unlabeled"
4 | 4: "1 person"
5 | 5: "2+ person"
6 | 6: "rider"
7 | 7: "car"
8 | 8: "trunk"
9 | 9: "plants"
10 | 10: "traffic sign 1" # standing sign
11 | 11: "traffic sign 2" # hanging sign
12 | 12: "traffic sign 3" # high/big hanging sign
13 | 13: "pole"
14 | 14: "trashcan"
15 | 15: "building"
16 |   16: "cone/stone"
17 |   17: "fence"
18 | 21: "bike"
19 | 22: "ground" # class definition
20 |
21 | color_map: # bgr
22 |
23 | 0 : [0, 0, 0] # 0: "unlabeled"
24 | 4 : [30, 30, 255] # 4: "1 person"
25 | 5 : [30, 30, 255] # 5: "2+ person"
26 | 6 : [200, 40, 255] # 6: "rider"
27 | 7 : [245, 150, 100] # 7: "car"
28 | 8 : [0,60,135] # 8: "trunk"
29 | 9 : [0, 175, 0] # 9: "plants"
30 | 10: [0, 0, 255] # 10: "traffic sign 1"
31 | 11: [0, 0, 255] # 11: "traffic sign 2"
32 | 12: [0, 0, 255] # 12: "traffic sign 3"
33 | 13: [150, 240, 255] # 13: "pole"
34 | 14: [0, 255, 125] # 14: "trashcan"
35 | 15: [0, 200, 255] # 15: "building"
36 | 16: [255, 255, 50] # 16: "cone/stone"
37 | 17: [50, 120, 255] # 17: "fence"
38 | 21: [245, 230, 100] # 21: "bike"
39 | 22: [128, 128, 128] # 22: "ground"
40 |
41 | content: # as a ratio with the total number of points
42 | 0: 0.018889854628292943
43 | 1: 0.0002937197336781505
44 | 10: 0.040818519255974316
45 | 11: 0.00016609538710764618
46 | 13: 2.7879693665067774e-05
47 | 15: 0.00039838616015114444
48 | 16: 0.0
49 | 18: 0.0020633612104619787
50 | 20: 0.0016218197275284021
51 | 30: 0.00017698551338515307
52 | 31: 1.1065903904919655e-08
53 | 32: 5.532951952459828e-09
54 | 40: 0.1987493871255525
55 | 44: 0.014717169549888214
56 | 48: 0.14392298360372
57 | 49: 0.0039048553037472045
58 | 50: 0.1326861944777486
59 | 51: 0.0723592229456223
60 | 52: 0.002395131480328884
61 | 60: 4.7084144280367186e-05
62 | 70: 0.26681502148037506
63 | 71: 0.006035012012626033
64 | 72: 0.07814222006271769
65 | 80: 0.002855498193863172
66 | 81: 0.0006155958086189918
67 | 99: 0.009923127583046915
68 | 252: 0.001789309418528068
69 | 253: 0.00012709999297008662
70 | 254: 0.00016059776092534436
71 | 255: 3.745553104802113e-05
72 | 256: 0.0
73 | 257: 0.00011351574470342043
74 | 258: 0.00010157861367183268
75 | 259: 4.3840131989471124e-05
76 | # classes that are indistinguishable from single scan or inconsistent in
77 | # ground truth are mapped to their closest equivalent
78 |
79 | # 11 CLASSES
80 | learning_map:
81 | 0: 0 #"unlabeled"
82 | 4: 1 # "1 person" --> "people" ----------------mapped
83 | 5: 1 # "2+ person" --> "people" ---------------mapped
84 | 6: 2 #"rider"
85 | 7: 3 #"car"
86 | 8: 4 #"trunk"
87 | 9: 5 #"plants"
88 | 10: 6 # "traffic sign 1" # standing sign -->traffic sign----------------mapped
89 | 11: 6 #"traffic sign 2" # hanging sign-->traffic sign----------------mapped
90 | 12: 6 #"traffic sign 3" # high/big hanging sign-->traffic sign----------------mapped
91 | 13: 7 #"pole"
92 | 14: 0 #"trashcan" --> "unlabeled" ----------------mapped
93 | 15: 8 #"building"
94 | 16: 0 # "cone/stone" --> "unlabeled" ----------------mapped
95 | 17: 9 # "fence"
96 | 21: 10 #"bike"
97 | 22: 11 #"ground" # class definition
98 |
99 | learning_map_inv: # inverse of previous map
100 | 0: 0 # "unlabeled"
101 | 1: 4 # "people"
102 | 2: 6 # "rider"
103 | 3: 7 # "car"
104 | 4: 8 # "trunk"
105 | 5: 9 # "plants"
106 | 6: 10 # "traffic sign"
107 | 7: 13 # "pole"
108 | 8: 15 # "building"
109 | 9: 17 # "fence"
110 | 10: 21 # "bike"
111 | 11: 22 # "ground"
112 |
113 | learning_ignore: # Ignore classes
114 |   0: True # "unlabeled", and others ignored
115 |   1: False # "people"
116 |   2: False # "rider"
117 |   3: False # "car"
118 |   4: False # "trunk"
119 |   5: False # "plants"
120 |   6: False # "traffic sign"
121 |   7: False # "pole"
122 |   8: False # "building"
123 |   9: False # "fence"
124 |   10: False # "bike"
125 |   11: False # "ground"
126 |
127 |
128 | split: # sequence numbers
129 | train:
130 | - 0
131 | - 1
132 | - 3
133 | - 4
134 | - 5
135 | valid:
136 | - 2
137 | test:
138 | - 2
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | addict==2.4.0
3 | aliyun-python-sdk-core==2.14.0
4 | aliyun-python-sdk-kms==2.16.2
5 | ansi2html==1.9.1
6 | antlr4-python3-runtime==4.9.3
7 | asttokens==2.4.1
8 | attrs==23.2.0
9 | backcall==0.2.0
10 | basicsr==1.4.2
11 | blinker==1.7.0
12 | cache==1.0.3
13 | cachetools==5.3.2
14 | ccimport==0.4.2
15 | certifi==2024.2.2
16 | cffi==1.16.0
17 | charset-normalizer==3.3.2
18 | click==8.1.7
19 | colorama==0.4.6
20 | comm==0.2.1
21 | ConfigArgParse==1.7
22 | contourpy==1.1.1
23 | crcmod==1.7
24 | cryptography==42.0.2
25 | cumm==0.4.11
26 | cumm-cu113==0.4.11
27 | cycler==0.12.1
28 | Cython==3.0.9
29 | dash==2.15.0
30 | dash-core-components==2.0.0
31 | dash-html-components==2.0.0
32 | dash-table==5.0.0
33 | decorator==5.1.1
34 | easydict==1.11
35 | einops==0.7.0
36 | executing==2.0.1
37 | fastjsonschema==2.19.1
38 | filelock==3.13.1
39 | fire==0.6.0
40 | Flask==3.0.2
41 | fonttools==4.47.2
42 | freetype-py==2.4.0
43 | fsspec==2024.2.0
44 | future==0.18.3
45 | gin==0.1.6
46 | gin-config==0.5.0
47 | glfw==1.8.3
48 | google-auth==2.27.0
49 | google-auth-oauthlib==1.0.0
50 | grpcio==1.60.1
51 | h5py==3.10.0
52 | hsluv==5.0.4
53 | huggingface-hub==0.20.3
54 | idna==3.6
55 | imageio==2.33.1
56 | imageio-ffmpeg==0.4.9
57 | importlib-metadata==7.0.1
58 | importlib-resources==6.1.1
59 | ipython==8.12.3
60 | ipywidgets==8.1.1
61 | itsdangerous==2.1.2
62 | jedi==0.19.1
63 | Jinja2==3.1.3
64 | jmespath==0.10.0
65 | joblib==1.3.2
66 | jsonschema==4.21.1
67 | jsonschema-specifications==2023.12.1
68 | jupyter_core==5.7.1
69 | jupyterlab-widgets==3.0.9
70 | kiwisolver==1.4.5
71 | kornia==0.7.1
72 | lark==1.1.9
73 | lazy_loader==0.3
74 | lightning-utilities==0.10.1
75 | llvmlite==0.41.1
76 | lmdb==1.4.1
77 | loralib==0.1.2
78 | Markdown==3.5.2
79 | markdown-it-py==3.0.0
80 | MarkupSafe==2.1.5
81 | matplotlib==3.7.4
82 | matplotlib-inline==0.1.6
83 | mdurl==0.1.2
84 | MinkowskiEngine==0.5.4
85 | mkl-fft==1.3.1
86 | mkl-random==1.2.2
87 | mkl-service==2.4.0
88 | mlxtend==0.23.1
89 | mmcv-full==1.7.2
90 | mmengine==0.10.3
91 | model-index==0.1.11
92 | nbformat==5.9.2
93 | nest-asyncio==1.6.0
94 | networkx==3.1
95 | ninja==1.11.1.1
96 | numba==0.58.1
97 | numpy==1.23.0
98 | numpy-indexed==0.3.7
99 | nvidia-ml-py3==7.352.0
100 | oauthlib==3.2.2
101 | olefile @ file:///home/conda/feedstock_root/build_artifacts/olefile_1701735466804/work
102 | omegaconf==2.3.0
103 | open3d==0.18.0
104 | opencv-python==4.9.0.80
105 | opendatalab==0.0.10
106 | openmim==0.3.9
107 | openxlab==0.0.34
108 | ordered-set==4.1.0
109 | oss2==2.17.0
110 | packaging==23.2
111 | pandas==2.0.3
112 | parso==0.8.3
113 | pccm==0.4.11
114 | pexpect==4.9.0
115 | pickleshare==0.7.5
116 | Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1630696607296/work
117 | pkgutil_resolve_name==1.3.10
118 | platformdirs==4.2.0
119 | plotly==5.18.0
120 | plyfile==1.0.3
121 | portalocker==2.8.2
122 | prompt-toolkit==3.0.43
123 | protobuf==4.25.2
124 | ptyprocess==0.7.0
125 | pure-eval==0.2.2
126 | purge==1.0
127 | pyasn1==0.5.1
128 | pyasn1-modules==0.3.0
129 | pybind11==2.11.1
130 | pycocotools==2.0.7
131 | pycparser==2.21
132 | pycryptodome==3.20.0
133 | Pygments==2.17.2
134 | pyparsing==3.1.1
135 | pyquaternion==0.9.9
136 | python-dateutil==2.8.2
137 | pytorch-msssim==1.0.0
138 | pytz==2023.4
139 | PyWavelets==1.4.1
140 | PyYAML==6.0.1
141 | referencing==0.33.0
142 | requests==2.28.2
143 | requests-oauthlib==1.3.1
144 | retrying==1.3.4
145 | rich==13.4.2
146 | rpds-py==0.17.1
147 | rsa==4.9
148 | safetensors==0.4.2
149 | scikit-image==0.21.0
150 | scikit-learn==1.3.2
151 | scipy==1.10.1
152 | seaborn==0.13.2
153 | shapely==2.0.2
154 | SharedArray==3.2.3
155 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
156 | spconv @ file:///home/user/spconv1.0/dist/spconv-1.0-cp38-cp38-linux_x86_64.whl#sha256=251459523b0dd39a7a5f29f5b3391a15a023d842c835c169b946bbe37c8e60af
157 | stack-data==0.6.3
158 | strictyaml==1.7.3
159 | tabulate==0.9.0
160 | tenacity==8.2.3
161 | tensorboard==2.14.0
162 | tensorboard-data-server==0.7.2
163 | tensorboardX==2.6.2.2
164 | termcolor==2.4.0
165 | terminaltables==3.1.10
166 | threadpoolctl==3.2.0
167 | tifffile==2023.7.10
168 | timm==0.9.12
169 | tomli==2.0.1
170 | torch==1.10.1
171 | torch-efficient-distloss==0.1.3
172 | torch-scatter==2.1.2
173 | torchaudio==0.10.1
174 | torchmetrics==1.3.0.post0
175 | torchvision==0.11.2
176 | tqdm==4.65.2
177 | traitlets==5.14.1
178 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1702176139754/work
179 | tzdata==2023.4
180 | urllib3==1.26.18
181 | vispy==0.14.2
182 | wcwidth==0.2.13
183 | Werkzeug==3.0.1
184 | widgetsnbextension==4.0.9
185 | yapf==0.40.2
186 | zipp==3.17.0
187 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [NeurIPS 2024] TALoS: Enhancing Semantic Scene Completion via Test-time Adaptation on the Line of Sight
2 |
3 | This repository contains the official PyTorch implementation of the paper
4 | "TALoS: Enhancing Semantic Scene Completion via Test-time Adaptation on the Line of Sight" (NeurIPS 2024)
5 | by [Hyun-Kurl Jang*](https://blue-531.github.io/), [Jihun Kim*](https://jihun1998.github.io/),
6 | and [Hyeokjun Kweon*](https://sangrockeg.github.io/).
7 |
8 | (* denotes equal contribution.)
9 |
10 | [[Paper]](https://arxiv.org/abs/2410.15674)
11 |
12 | ## News
13 |
23 |
24 | Our main idea is simple yet effective:
25 | **an observation made at one moment could serve as supervision for the SSC prediction at another moment.**
26 | While traveling through an environment, an autonomous vehicle can continuously observe the overall scene structure, including objects that were previously occluded (or will be occluded later); these observations provide concrete guidance for adapting scene completion. Given the characteristics of the LiDAR sensor, an observation of a point at a specific spatial location at a specific moment confirms not only the occupation of that location itself but also the absence of obstacles along the line of sight from the sensor to that location.
27 | The proposed method, named
28 | **Test-time Adaptation via Line of Sight (TALoS)**
29 | , is designed to explicitly leverage these characteristics, obtaining self-supervision for geometric completion.
30 | Additionally, we extend the TALoS framework to semantic recognition, another key goal of SSC, by collecting only the reliable regions among the semantic segmentation results predicted at each moment.
31 | Further, to leverage valuable future information that is not accessible at the time of the current update, we devise a novel dual optimization scheme in which the model is gradually updated across the temporal dimension.
32 | ## Installation
33 |
34 | - PyTorch >= 1.10
35 | - pyyaml
36 | - Cython
37 | - tqdm
38 | - numba
39 | - numpy-indexed
40 | - [torch-scatter](https://github.com/rusty1s/pytorch_scatter)
41 | - [spconv](https://github.com/tyjiang1997/spconv1.0) (tested with spconv==1.0 and cuda==11.3)
42 |
43 |
44 |
45 | ## Data Preparation
46 |
47 | ### SemanticKITTI
48 | ```
49 | ./
50 | ├── ...
51 | ├── model_load_dir
52 | │   └── pretrained.pth
53 | └── dataset
54 |     └── sequences
55 |         ├── 00/
56 |         │   ├── velodyne/
57 |         │   │   ├── 000000.bin
58 |         │   │   ├── 000001.bin
59 |         │   │   └── ...
60 |         │   ├── labels/
61 |         │   │   ├── 000000.label
62 |         │   │   ├── 000001.label
63 |         │   │   └── ...
64 |         │   └── voxels/
65 |         │       ├── 000000.bin
66 |         │       ├── 000000.label
67 |         │       ├── 000000.invalid
68 |         │       ├── 000000.occluded
69 |         │       ├── 000001.bin
70 |         │       ├── 000001.label
71 |         │       ├── 000001.invalid
72 |         │       ├── 000001.occluded
73 |         │       └── ...
74 |         ├── ...
75 |         ├── 08/ # for validation
76 |         ├── 11/ # 11-21 for testing
77 |         ├── ...
78 |         └── 21/
79 | ```
80 |
81 | ## Test-Time Adaptation
82 | 1. Download the pre-trained models and put them in ```./model_load_dir```. [[link]](https://drive.google.com/file/d/12jYauPbVodnSA-faBjFucUNgxeGU0pmP/view?usp=drive_link)
83 | 2. (Optional) Download pre-trained model results and put them in ```./experiments/baseline``` for comparison. [[link]](https://drive.google.com/file/d/1gt65t7hkdnnax2v7BALgUsunTaGHRVkh/view?usp=drive_link)
84 | 3. Generate predictions on the dataset:
85 |
86 | ### Validation set
87 | ```
88 | python run_tta_val.py --do_adapt --do_cont --use_los --use_pgt
89 | ```
90 | ### Test set
91 | ```
92 | python run_tta_test.py --do_adapt --do_cont --use_los --use_pgt --sq_num={sequence number}
93 | ```
94 | ## Evaluation
95 | To evaluate on the SemanticKITTI test sequences, submit the generated predictions to the [benchmark server](https://codalab.lisn.upsaclay.fr/competitions/7170).
96 | After generating predictions, prepare your submission in the designated format, as described on the competition page.
97 | Use the validation script from the [semantic-kitti-api](https://github.com/PRBonn/semantic-kitti-api) to ensure that the folder structure and the number of label files in the zip file are correct.
98 |
99 |
100 |
101 |
102 | ## Acknowledgements
103 | We thank the authors of the open-source project [SCPNet](https://github.com/SCPNet/Codes-for-SCPNet).
104 |
--------------------------------------------------------------------------------
/config/label_mapping/semantic-kitti.yaml:
--------------------------------------------------------------------------------
1 | # This file is covered by the LICENSE file in the root of this project.
2 | labels:
3 | 0 : "unlabeled"
4 | 1 : "outlier"
5 | 10: "car"
6 | 11: "bicycle"
7 | 13: "bus"
8 | 15: "motorcycle"
9 | 16: "on-rails"
10 | 18: "truck"
11 | 20: "other-vehicle"
12 | 30: "person"
13 | 31: "bicyclist"
14 | 32: "motorcyclist"
15 | 40: "road"
16 | 44: "parking"
17 | 48: "sidewalk"
18 | 49: "other-ground"
19 | 50: "building"
20 | 51: "fence"
21 | 52: "other-structure"
22 | 60: "lane-marking"
23 | 70: "vegetation"
24 | 71: "trunk"
25 | 72: "terrain"
26 | 80: "pole"
27 | 81: "traffic-sign"
28 | 99: "other-object"
29 | 252: "moving-car"
30 | 253: "moving-bicyclist"
31 | 254: "moving-person"
32 | 255: "moving-motorcyclist"
33 | 256: "moving-on-rails"
34 | 257: "moving-bus"
35 | 258: "moving-truck"
36 | 259: "moving-other-vehicle"
37 | color_map: # bgr
38 | 0 : [0, 0, 0]
39 | 1 : [0, 0, 255]
40 | 10: [245, 150, 100]
41 | 11: [245, 230, 100]
42 | 13: [250, 80, 100]
43 | 15: [150, 60, 30]
44 | 16: [255, 0, 0]
45 | 18: [180, 30, 80]
46 | 20: [255, 0, 0]
47 | 30: [30, 30, 255]
48 | 31: [200, 40, 255]
49 | 32: [90, 30, 150]
50 | 40: [255, 0, 255]
51 | 44: [255, 150, 255]
52 | 48: [75, 0, 75]
53 | 49: [75, 0, 175]
54 | 50: [0, 200, 255]
55 | 51: [50, 120, 255]
56 | 52: [0, 150, 255]
57 | 60: [170, 255, 150]
58 | 70: [0, 175, 0]
59 | 71: [0, 60, 135]
60 | 72: [80, 240, 150]
61 | 80: [150, 240, 255]
62 | 81: [0, 0, 255]
63 | 99: [255, 255, 50]
64 | 252: [245, 150, 100]
65 | 256: [255, 0, 0]
66 | 253: [200, 40, 255]
67 | 254: [30, 30, 255]
68 | 255: [90, 30, 150]
69 | 257: [250, 80, 100]
70 | 258: [180, 30, 80]
71 | 259: [255, 0, 0]
72 | content: # as a ratio with the total number of points
73 | 0: 0.018889854628292943
74 | 1: 0.0002937197336781505
75 | 10: 0.040818519255974316
76 | 11: 0.00016609538710764618
77 | 13: 2.7879693665067774e-05
78 | 15: 0.00039838616015114444
79 | 16: 0.0
80 | 18: 0.0020633612104619787
81 | 20: 0.0016218197275284021
82 | 30: 0.00017698551338515307
83 | 31: 1.1065903904919655e-08
84 | 32: 5.532951952459828e-09
85 | 40: 0.1987493871255525
86 | 44: 0.014717169549888214
87 | 48: 0.14392298360372
88 | 49: 0.0039048553037472045
89 | 50: 0.1326861944777486
90 | 51: 0.0723592229456223
91 | 52: 0.002395131480328884
92 | 60: 4.7084144280367186e-05
93 | 70: 0.26681502148037506
94 | 71: 0.006035012012626033
95 | 72: 0.07814222006271769
96 | 80: 0.002855498193863172
97 | 81: 0.0006155958086189918
98 | 99: 0.009923127583046915
99 | 252: 0.001789309418528068
100 | 253: 0.00012709999297008662
101 | 254: 0.00016059776092534436
102 | 255: 3.745553104802113e-05
103 | 256: 0.0
104 | 257: 0.00011351574470342043
105 | 258: 0.00010157861367183268
106 | 259: 4.3840131989471124e-05
107 | # classes that are indistinguishable from single scan or inconsistent in
108 | # ground truth are mapped to their closest equivalent
109 | learning_map:
110 | 0 : 0 # "unlabeled"
111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped
112 | 10: 1 # "car"
113 | 11: 2 # "bicycle"
114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped
115 | 15: 3 # "motorcycle"
116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped
117 | 18: 4 # "truck"
118 | 20: 5 # "other-vehicle"
119 | 30: 6 # "person"
120 | 31: 7 # "bicyclist"
121 | 32: 8 # "motorcyclist"
122 | 40: 9 # "road"
123 | 44: 10 # "parking"
124 | 48: 11 # "sidewalk"
125 | 49: 12 # "other-ground"
126 | 50: 13 # "building"
127 | 51: 14 # "fence"
128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped
129 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped
130 | 70: 15 # "vegetation"
131 | 71: 16 # "trunk"
132 | 72: 17 # "terrain"
133 | 80: 18 # "pole"
134 | 81: 19 # "traffic-sign"
135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped
136 | 252: 1 # "moving-car" to "car" ------------------------------------mapped
137 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped
138 | 254: 6 # "moving-person" to "person" ------------------------------mapped
139 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped
140 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped
141 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped
142 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped
143 |   259: 5    # "moving-other-vehicle" to "other-vehicle" ----------------mapped
144 | learning_map_inv: # inverse of previous map
145 | 0: 0 # "unlabeled", and others ignored
146 | 1: 10 # "car"
147 | 2: 11 # "bicycle"
148 | 3: 15 # "motorcycle"
149 | 4: 18 # "truck"
150 | 5: 20 # "other-vehicle"
151 | 6: 30 # "person"
152 | 7: 31 # "bicyclist"
153 | 8: 32 # "motorcyclist"
154 | 9: 40 # "road"
155 | 10: 44 # "parking"
156 | 11: 48 # "sidewalk"
157 | 12: 49 # "other-ground"
158 | 13: 50 # "building"
159 | 14: 51 # "fence"
160 | 15: 70 # "vegetation"
161 | 16: 71 # "trunk"
162 | 17: 72 # "terrain"
163 | 18: 80 # "pole"
164 | 19: 81 # "traffic-sign"
165 | learning_ignore: # Ignore classes
166 | 0: True # "unlabeled", and others ignored
167 | 1: False # "car"
168 | 2: False # "bicycle"
169 | 3: False # "motorcycle"
170 | 4: False # "truck"
171 | 5: False # "other-vehicle"
172 | 6: False # "person"
173 | 7: False # "bicyclist"
174 | 8: False # "motorcyclist"
175 | 9: False # "road"
176 | 10: False # "parking"
177 | 11: False # "sidewalk"
178 | 12: False # "other-ground"
179 | 13: False # "building"
180 | 14: False # "fence"
181 | 15: False # "vegetation"
182 | 16: False # "trunk"
183 | 17: False # "terrain"
184 | 18: False # "pole"
185 | 19: False # "traffic-sign"
186 | split: # sequence numbers
187 | train:
188 | - 0
189 | - 1
190 | - 2
191 | - 3
192 | - 4
193 | - 5
194 | - 6
195 | - 7
196 | - 9
197 | - 10
198 | valid:
199 | - 8
200 | test:
201 | - 11
202 | - 12
203 | - 13
204 | - 14
205 | - 15
206 | - 16
207 | - 17
208 | - 18
209 | - 19
210 | - 20
211 | - 21
212 |
--------------------------------------------------------------------------------
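A minimal sketch of applying learning_map from this file to raw SemanticKITTI labels via a lookup table, mirroring what the dataloaders are expected to do:

```
import numpy as np
import yaml

with open("./config/label_mapping/semantic-kitti.yaml") as f:
    mapping = yaml.safe_load(f)

learning_map = mapping["learning_map"]
lut = np.zeros(max(learning_map) + 1, dtype=np.int64)
for raw_id, train_id in learning_map.items():
    lut[raw_id] = train_id

raw = np.array([10, 252, 60, 99])   # car, moving-car, lane-marking, other-object
print(lut[raw])                     # [1 1 9 0]
```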
/config/label_mapping/semantic-kitti-multiscan.yaml:
--------------------------------------------------------------------------------
1 | # This file is covered by the LICENSE file in the root of this project.
2 | labels:
3 | 0 : "unlabeled"
4 | 1 : "outlier"
5 | 10: "car"
6 | 11: "bicycle"
7 | 13: "bus"
8 | 15: "motorcycle"
9 | 16: "on-rails"
10 | 18: "truck"
11 | 20: "other-vehicle"
12 | 30: "person"
13 | 31: "bicyclist"
14 | 32: "motorcyclist"
15 | 40: "road"
16 | 44: "parking"
17 | 48: "sidewalk"
18 | 49: "other-ground"
19 | 50: "building"
20 | 51: "fence"
21 | 52: "other-structure"
22 | 60: "lane-marking"
23 | 70: "vegetation"
24 | 71: "trunk"
25 | 72: "terrain"
26 | 80: "pole"
27 | 81: "traffic-sign"
28 | 99: "other-object"
29 | 252: "moving-car"
30 | 253: "moving-bicyclist"
31 | 254: "moving-person"
32 | 255: "moving-motorcyclist"
33 | 256: "moving-on-rails"
34 | 257: "moving-bus"
35 | 258: "moving-truck"
36 | 259: "moving-other-vehicle"
37 | color_map: # bgr
38 | 0 : [0, 0, 0]
39 | 1 : [0, 0, 255]
40 | 10: [245, 150, 100]
41 | 11: [245, 230, 100]
42 | 13: [250, 80, 100]
43 | 15: [150, 60, 30]
44 | 16: [255, 0, 0]
45 | 18: [180, 30, 80]
46 | 20: [255, 0, 0]
47 | 30: [30, 30, 255]
48 | 31: [200, 40, 255]
49 | 32: [90, 30, 150]
50 | 40: [255, 0, 255]
51 | 44: [255, 150, 255]
52 | 48: [75, 0, 75]
53 | 49: [75, 0, 175]
54 | 50: [0, 200, 255]
55 | 51: [50, 120, 255]
56 | 52: [0, 150, 255]
57 | 60: [170, 255, 150]
58 | 70: [0, 175, 0]
59 | 71: [0, 60, 135]
60 | 72: [80, 240, 150]
61 | 80: [150, 240, 255]
62 | 81: [0, 0, 255]
63 | 99: [255, 255, 50]
64 | 252: [245, 150, 100]
65 | 256: [255, 0, 0]
66 | 253: [200, 40, 255]
67 | 254: [30, 30, 255]
68 | 255: [90, 30, 150]
69 | 257: [250, 80, 100]
70 | 258: [180, 30, 80]
71 | 259: [255, 0, 0]
72 | content: # as a ratio with the total number of points
73 | 0: 0.018889854628292943
74 | 1: 0.0002937197336781505
75 | 10: 0.040818519255974316
76 | 11: 0.00016609538710764618
77 | 13: 2.7879693665067774e-05
78 | 15: 0.00039838616015114444
79 | 16: 0.0
80 | 18: 0.0020633612104619787
81 | 20: 0.0016218197275284021
82 | 30: 0.00017698551338515307
83 | 31: 1.1065903904919655e-08
84 | 32: 5.532951952459828e-09
85 | 40: 0.1987493871255525
86 | 44: 0.014717169549888214
87 | 48: 0.14392298360372
88 | 49: 0.0039048553037472045
89 | 50: 0.1326861944777486
90 | 51: 0.0723592229456223
91 | 52: 0.002395131480328884
92 | 60: 4.7084144280367186e-05
93 | 70: 0.26681502148037506
94 | 71: 0.006035012012626033
95 | 72: 0.07814222006271769
96 | 80: 0.002855498193863172
97 | 81: 0.0006155958086189918
98 | 99: 0.009923127583046915
99 | 252: 0.001789309418528068
100 | 253: 0.00012709999297008662
101 | 254: 0.00016059776092534436
102 | 255: 3.745553104802113e-05
103 | 256: 0.0
104 | 257: 0.00011351574470342043
105 | 258: 0.00010157861367183268
106 | 259: 4.3840131989471124e-05
107 | # classes that are indistinguishable from single scan or inconsistent in
108 | # ground truth are mapped to their closest equivalent
109 | learning_map:
110 | 0 : 0 # "unlabeled"
111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped
112 | 10: 1 # "car"
113 | 11: 2 # "bicycle"
114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped
115 | 15: 3 # "motorcycle"
116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped
117 | 18: 4 # "truck"
118 | 20: 5 # "other-vehicle"
119 | 30: 6 # "person"
120 | 31: 7 # "bicyclist"
121 | 32: 8 # "motorcyclist"
122 | 40: 9 # "road"
123 | 44: 10 # "parking"
124 | 48: 11 # "sidewalk"
125 | 49: 12 # "other-ground"
126 | 50: 13 # "building"
127 | 51: 14 # "fence"
128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped
129 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped
130 | 70: 15 # "vegetation"
131 | 71: 16 # "trunk"
132 | 72: 17 # "terrain"
133 | 80: 18 # "pole"
134 | 81: 19 # "traffic-sign"
135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped
136 | 252: 1 # "moving-car" to "car" ------------------------------------mapped
137 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped
138 | 254: 6 # "moving-person" to "person" ------------------------------mapped
139 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped
140 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped
141 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped
142 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped
143 |   259: 5    # "moving-other-vehicle" to "other-vehicle" ----------------mapped
144 | learning_map_inv: # inverse of previous map
145 | 0: 0 # "unlabeled", and others ignored
146 | 1: 10 # "car"
147 | 2: 11 # "bicycle"
148 | 3: 15 # "motorcycle"
149 | 4: 18 # "truck"
150 | 5: 20 # "other-vehicle"
151 | 6: 30 # "person"
152 | 7: 31 # "bicyclist"
153 | 8: 32 # "motorcyclist"
154 | 9: 40 # "road"
155 | 10: 44 # "parking"
156 | 11: 48 # "sidewalk"
157 | 12: 49 # "other-ground"
158 | 13: 50 # "building"
159 | 14: 51 # "fence"
160 | 15: 70 # "vegetation"
161 | 16: 71 # "trunk"
162 | 17: 72 # "terrain"
163 | 18: 80 # "pole"
164 | 19: 81 # "traffic-sign"
165 | learning_ignore: # Ignore classes
166 | 0: True # "unlabeled", and others ignored
167 | 1: False # "car"
168 | 2: False # "bicycle"
169 | 3: False # "motorcycle"
170 | 4: False # "truck"
171 | 5: False # "other-vehicle"
172 | 6: False # "person"
173 | 7: False # "bicyclist"
174 | 8: False # "motorcyclist"
175 | 9: False # "road"
176 | 10: False # "parking"
177 | 11: False # "sidewalk"
178 | 12: False # "other-ground"
179 | 13: False # "building"
180 | 14: False # "fence"
181 | 15: False # "vegetation"
182 | 16: False # "trunk"
183 | 17: False # "terrain"
184 | 18: False # "pole"
185 | 19: False # "traffic-sign"
186 | split: # sequence numbers
187 | train:
188 | - 0
189 | - 1
190 | - 2
191 | - 3
192 | - 4
193 | - 5
194 | - 6
195 | - 7
196 | - 9
197 | - 10
198 | valid:
199 | - 8
200 | test:
201 | - 11
202 | - 12
203 | - 13
204 | - 14
205 | - 15
206 | - 16
207 | - 17
208 | - 18
209 | - 19
210 | - 20
211 | - 21
212 |
--------------------------------------------------------------------------------
/config/label_mapping/semantic-kitti-all.yaml:
--------------------------------------------------------------------------------
1 | # This file is covered by the LICENSE file in the root of this project.
2 | labels:
3 | 0 : "unlabeled"
4 | 1 : "outlier"
5 | 10: "car"
6 | 11: "bicycle"
7 | 13: "bus"
8 | 15: "motorcycle"
9 | 16: "on-rails"
10 | 18: "truck"
11 | 20: "other-vehicle"
12 | 30: "person"
13 | 31: "bicyclist"
14 | 32: "motorcyclist"
15 | 40: "road"
16 | 44: "parking"
17 | 48: "sidewalk"
18 | 49: "other-ground"
19 | 50: "building"
20 | 51: "fence"
21 | 52: "other-structure"
22 | 60: "lane-marking"
23 | 70: "vegetation"
24 | 71: "trunk"
25 | 72: "terrain"
26 | 80: "pole"
27 | 81: "traffic-sign"
28 | 99: "other-object"
29 | 252: "moving-car"
30 | 253: "moving-bicyclist"
31 | 254: "moving-person"
32 | 255: "moving-motorcyclist"
33 | 256: "moving-on-rails"
34 | 257: "moving-bus"
35 | 258: "moving-truck"
36 | 259: "moving-other-vehicle"
37 | color_map: # bgr
38 | 0 : [0, 0, 0]
39 | 1 : [0, 0, 255]
40 | 10: [245, 150, 100]
41 | 11: [245, 230, 100]
42 | 13: [250, 80, 100]
43 | 15: [150, 60, 30]
44 | 16: [255, 0, 0]
45 | 18: [180, 30, 80]
46 | 20: [255, 0, 0]
47 | 30: [30, 30, 255]
48 | 31: [200, 40, 255]
49 | 32: [90, 30, 150]
50 | 40: [255, 0, 255]
51 | 44: [255, 150, 255]
52 | 48: [75, 0, 75]
53 | 49: [75, 0, 175]
54 | 50: [0, 200, 255]
55 | 51: [50, 120, 255]
56 | 52: [0, 150, 255]
57 | 60: [170, 255, 150]
58 | 70: [0, 175, 0]
59 | 71: [0, 60, 135]
60 | 72: [80, 240, 150]
61 | 80: [150, 240, 255]
62 | 81: [0, 0, 255]
63 | 99: [255, 255, 50]
64 | 252: [245, 150, 100]
65 | 256: [255, 0, 0]
66 | 253: [200, 40, 255]
67 | 254: [30, 30, 255]
68 | 255: [90, 30, 150]
69 | 257: [250, 80, 100]
70 | 258: [180, 30, 80]
71 | 259: [255, 0, 0]
72 | content: # as a ratio with the total number of points
73 | 0: 0.018889854628292943
74 | 1: 0.0002937197336781505
75 | 10: 0.040818519255974316
76 | 11: 0.00016609538710764618
77 | 13: 2.7879693665067774e-05
78 | 15: 0.00039838616015114444
79 | 16: 0.0
80 | 18: 0.0020633612104619787
81 | 20: 0.0016218197275284021
82 | 30: 0.00017698551338515307
83 | 31: 1.1065903904919655e-08
84 | 32: 5.532951952459828e-09
85 | 40: 0.1987493871255525
86 | 44: 0.014717169549888214
87 | 48: 0.14392298360372
88 | 49: 0.0039048553037472045
89 | 50: 0.1326861944777486
90 | 51: 0.0723592229456223
91 | 52: 0.002395131480328884
92 | 60: 4.7084144280367186e-05
93 | 70: 0.26681502148037506
94 | 71: 0.006035012012626033
95 | 72: 0.07814222006271769
96 | 80: 0.002855498193863172
97 | 81: 0.0006155958086189918
98 | 99: 0.009923127583046915
99 | 252: 0.001789309418528068
100 | 253: 0.00012709999297008662
101 | 254: 0.00016059776092534436
102 | 255: 3.745553104802113e-05
103 | 256: 0.0
104 | 257: 0.00011351574470342043
105 | 258: 0.00010157861367183268
106 | 259: 4.3840131989471124e-05
107 | # classes that are indistinguishable from single scan or inconsistent in
108 | # ground truth are mapped to their closest equivalent
109 | learning_map:
110 | 0 : 0 # "unlabeled"
111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped
112 | 10: 1 # "car"
113 | 11: 2 # "bicycle"
114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped
115 | 15: 3 # "motorcycle"
116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped
117 | 18: 4 # "truck"
118 | 20: 5 # "other-vehicle"
119 | 30: 6 # "person"
120 | 31: 7 # "bicyclist"
121 | 32: 8 # "motorcyclist"
122 | 40: 9 # "road"
123 | 44: 10 # "parking"
124 | 48: 11 # "sidewalk"
125 | 49: 12 # "other-ground"
126 | 50: 13 # "building"
127 | 51: 14 # "fence"
128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped
129 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped
130 | 70: 15 # "vegetation"
131 | 71: 16 # "trunk"
132 | 72: 17 # "terrain"
133 | 80: 18 # "pole"
134 | 81: 19 # "traffic-sign"
135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped
136 | 252: 20 # "moving-car"
137 | 253: 21 # "moving-bicyclist"
138 | 254: 22 # "moving-person"
139 | 255: 23 # "moving-motorcyclist"
140 | 256: 24 # "moving-on-rails" mapped to "moving-other-vehicle" ------mapped
141 | 257: 24 # "moving-bus" mapped to "moving-other-vehicle" -----------mapped
142 | 258: 25 # "moving-truck"
143 | 259: 24 # "moving-other-vehicle"
144 | learning_map_inv: # inverse of previous map
145 | 0: 0 # "unlabeled", and others ignored
146 | 1: 10 # "car"
147 | 2: 11 # "bicycle"
148 | 3: 15 # "motorcycle"
149 | 4: 18 # "truck"
150 | 5: 20 # "other-vehicle"
151 | 6: 30 # "person"
152 | 7: 31 # "bicyclist"
153 | 8: 32 # "motorcyclist"
154 | 9: 40 # "road"
155 | 10: 44 # "parking"
156 | 11: 48 # "sidewalk"
157 | 12: 49 # "other-ground"
158 | 13: 50 # "building"
159 | 14: 51 # "fence"
160 | 15: 70 # "vegetation"
161 | 16: 71 # "trunk"
162 | 17: 72 # "terrain"
163 | 18: 80 # "pole"
164 | 19: 81 # "traffic-sign"
165 | 20: 252 # "moving-car"
166 | 21: 253 # "moving-bicyclist"
167 | 22: 254 # "moving-person"
168 | 23: 255 # "moving-motorcyclist"
169 | 24: 259 # "moving-other-vehicle"
170 | 25: 258 # "moving-truck"
171 | learning_ignore: # Ignore classes
172 | 0: True # "unlabeled", and others ignored
173 | 1: False # "car"
174 | 2: False # "bicycle"
175 | 3: False # "motorcycle"
176 | 4: False # "truck"
177 | 5: False # "other-vehicle"
178 | 6: False # "person"
179 | 7: False # "bicyclist"
180 | 8: False # "motorcyclist"
181 | 9: False # "road"
182 | 10: False # "parking"
183 | 11: False # "sidewalk"
184 | 12: False # "other-ground"
185 | 13: False # "building"
186 | 14: False # "fence"
187 | 15: False # "vegetation"
188 | 16: False # "trunk"
189 | 17: False # "terrain"
190 | 18: False # "pole"
191 | 19: False # "traffic-sign"
192 | 20: False # "moving-car"
193 | 21: False # "moving-bicyclist"
194 | 22: False # "moving-person"
195 | 23: False # "moving-motorcyclist"
196 | 24: False # "moving-other-vehicle"
197 | 25: False # "moving-truck"
198 | split: # sequence numbers
199 | train:
200 | - 0
201 | - 1
202 | - 2
203 | - 3
204 | - 4
205 | - 5
206 | - 6
207 | - 7
208 | - 9
209 | - 10
210 | valid:
211 | - 8
212 | test:
213 | - 11
214 | - 12
215 | - 13
216 | - 14
217 | - 15
218 | - 16
219 | - 17
220 | - 18
221 | - 19
222 | - 20
223 | - 21
224 |
--------------------------------------------------------------------------------
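
The learning_map above collapses the raw SemanticKITTI label ids into the 26 training classes (0-25); learning_map_inv restores the raw ids for submission. A minimal sketch of how the dataloaders in this repo apply it as a NumPy lookup table (the YAML path is illustrative):

import numpy as np
import yaml

# Load the mapping shown above; path is illustrative.
with open('config/label_mapping/semantic-kitti-multiscan.yaml') as stream:
    semkittiyaml = yaml.safe_load(stream)

# Build a lookup table, mirroring the LUT construction in dataloader/pc_dataset.py.
remapdict = semkittiyaml['learning_map']
remap_lut = np.zeros(max(remapdict.keys()) + 100, dtype=np.int32)
remap_lut[list(remapdict.keys())] = list(remapdict.values())

raw = np.array([10, 40, 60, 252])  # car, road, lane-marking, moving-car
print(remap_lut[raw])              # [ 1  9  9 20]
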
/network/conv_base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import List, Tuple
3 |
4 |
5 | class SharedMLP(nn.Sequential):
6 |
7 | def __init__(
8 | self,
9 | args: List[int],
10 | *,
11 | bn: bool = False,
12 | activation=nn.ReLU(inplace=True),
13 | preact: bool = False,
14 | first: bool = False,
15 | name: str = "",
16 | instance_norm: bool = False, ):
17 | super().__init__()
18 |
19 | for i in range(len(args) - 1):
20 | self.add_module(
21 | name + 'layer{}'.format(i),
22 | Conv2d(
23 | args[i],
24 | args[i + 1],
25 | bn=(not first or not preact or (i != 0)) and bn,
26 | activation=activation
27 | if (not first or not preact or (i != 0)) else None,
28 | preact=preact,
29 | instance_norm=instance_norm
30 | )
31 | )
32 |
33 |
34 | class _ConvBase(nn.Sequential):
35 |
36 | def __init__(
37 | self,
38 | in_size,
39 | out_size,
40 | kernel_size,
41 | stride,
42 | padding,
43 | activation,
44 | bn,
45 | init,
46 | conv=None,
47 | batch_norm=None,
48 | bias=True,
49 | preact=False,
50 | name="",
51 | instance_norm=False,
52 | instance_norm_func=None
53 | ):
54 | super().__init__()
55 |
56 | bias = bias and (not bn)
57 | conv_unit = conv(
58 | in_size,
59 | out_size,
60 | kernel_size=kernel_size,
61 | stride=stride,
62 | padding=padding,
63 | bias=bias
64 | )
65 | init(conv_unit.weight)
66 | if bias:
67 | nn.init.constant_(conv_unit.bias, 0)
68 |
69 | if bn:
70 | if not preact:
71 | bn_unit = batch_norm(out_size)
72 | else:
73 | bn_unit = batch_norm(in_size)
74 | if instance_norm:
75 | if not preact:
76 | in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
77 | else:
78 | in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
79 |
80 | if preact:
81 | if bn:
82 | self.add_module(name + 'bn', bn_unit)
83 |
84 | if activation is not None:
85 | self.add_module(name + 'activation', activation)
86 |
87 | if not bn and instance_norm:
88 | self.add_module(name + 'in', in_unit)
89 |
90 | self.add_module(name + 'conv', conv_unit)
91 |
92 | if not preact:
93 | if bn:
94 | self.add_module(name + 'bn', bn_unit)
95 |
96 | if activation is not None:
97 | self.add_module(name + 'activation', activation)
98 |
99 | if not bn and instance_norm:
100 | self.add_module(name + 'in', in_unit)
101 |
102 |
103 | class _BNBase(nn.Sequential):
104 |
105 | def __init__(self, in_size, batch_norm=None, name=""):
106 | super().__init__()
107 | self.add_module(name + "bn", batch_norm(in_size))
108 |
109 | nn.init.constant_(self[0].weight, 1.0)
110 | nn.init.constant_(self[0].bias, 0)
111 |
112 |
113 | class BatchNorm1d(_BNBase):
114 |
115 | def __init__(self, in_size: int, *, name: str = ""):
116 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
117 |
118 |
119 | class BatchNorm2d(_BNBase):
120 |
121 | def __init__(self, in_size: int, name: str = ""):
122 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
123 |
124 |
125 | class BatchNorm3d(_BNBase):
126 |
127 | def __init__(self, in_size: int, name: str = ""):
128 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name)
129 |
130 |
131 | class Conv1d(_ConvBase):
132 |
133 | def __init__(
134 | self,
135 | in_size: int,
136 | out_size: int,
137 | *,
138 | kernel_size: int = 1,
139 | stride: int = 1,
140 | padding: int = 0,
141 | activation=nn.ReLU(inplace=True),
142 | bn: bool = False,
143 | init=nn.init.kaiming_normal_,
144 | bias: bool = True,
145 | preact: bool = False,
146 | name: str = "",
147 | instance_norm=False
148 | ):
149 | super().__init__(
150 | in_size,
151 | out_size,
152 | kernel_size,
153 | stride,
154 | padding,
155 | activation,
156 | bn,
157 | init,
158 | conv=nn.Conv1d,
159 | batch_norm=BatchNorm1d,
160 | bias=bias,
161 | preact=preact,
162 | name=name,
163 | instance_norm=instance_norm,
164 | instance_norm_func=nn.InstanceNorm1d
165 | )
166 |
167 |
168 | class Conv2d(_ConvBase):
169 |
170 | def __init__(
171 | self,
172 | in_size: int,
173 | out_size: int,
174 | *,
175 | kernel_size: Tuple[int, int] = (1, 1),
176 | stride: Tuple[int, int] = (1, 1),
177 | padding: Tuple[int, int] = (0, 0),
178 | activation=nn.ReLU(inplace=True),
179 | bn: bool = False,
180 | init=nn.init.kaiming_normal_,
181 | bias: bool = True,
182 | preact: bool = False,
183 | name: str = "",
184 | instance_norm=False
185 | ):
186 | super().__init__(
187 | in_size,
188 | out_size,
189 | kernel_size,
190 | stride,
191 | padding,
192 | activation,
193 | bn,
194 | init,
195 | conv=nn.Conv2d,
196 | batch_norm=BatchNorm2d,
197 | bias=bias,
198 | preact=preact,
199 | name=name,
200 | instance_norm=instance_norm,
201 | instance_norm_func=nn.InstanceNorm2d
202 | )
203 |
204 |
205 | class Conv3d(_ConvBase):
206 |
207 | def __init__(
208 | self,
209 | in_size: int,
210 | out_size: int,
211 | *,
212 | kernel_size: Tuple[int, int, int] = (1, 1, 1),
213 | stride: Tuple[int, int, int] = (1, 1, 1),
214 | padding: Tuple[int, int, int] = (0, 0, 0),
215 | activation=nn.ReLU(inplace=True),
216 | bn: bool = False,
217 | init=nn.init.kaiming_normal_,
218 | bias: bool = True,
219 | preact: bool = False,
220 | name: str = "",
221 | instance_norm=False
222 | ):
223 | super().__init__(
224 | in_size,
225 | out_size,
226 | kernel_size,
227 | stride,
228 | padding,
229 | activation,
230 | bn,
231 | init,
232 | conv=nn.Conv3d,
233 | batch_norm=BatchNorm3d,
234 | bias=bias,
235 | preact=preact,
236 | name=name,
237 | instance_norm=instance_norm,
238 | instance_norm_func=nn.InstanceNorm3d
239 | )
240 |
241 |
242 | class FC(nn.Sequential):
243 |
244 | def __init__(
245 | self,
246 | in_size: int,
247 | out_size: int,
248 | *,
249 | activation=nn.ReLU(inplace=True),
250 | bn: bool = False,
251 | init=None,
252 | preact: bool = False,
253 | name: str = ""
254 | ):
255 | super().__init__()
256 |
257 | fc = nn.Linear(in_size, out_size, bias=not bn)
258 | if init is not None:
259 | init(fc.weight)
260 | if not bn:
261 | nn.init.constant_(fc.bias, 0)
262 |
263 | if preact:
264 | if bn:
265 | self.add_module(name + 'bn', BatchNorm1d(in_size))
266 |
267 | if activation is not None:
268 | self.add_module(name + 'activation', activation)
269 |
270 | self.add_module(name + 'fc', fc)
271 |
272 | if not preact:
273 | if bn:
274 | self.add_module(name + 'bn', BatchNorm1d(out_size))
275 |
276 | if activation is not None:
277 | self.add_module(name + 'activation', activation)
278 |
279 |
280 |
--------------------------------------------------------------------------------
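
For orientation, a small usage sketch (not part of the repo): SharedMLP chains 1x1 Conv2d blocks, so per-point features of shape [B, C, N, 1] are transformed channel-wise without mixing information across points.

import torch
from network.conv_base import SharedMLP

mlp = SharedMLP([16, 32, 64], bn=True)  # two 1x1 conv layers: 16 -> 32 -> 64
feats = torch.randn(4, 16, 1024, 1)     # batch of 4, 16 channels, 1024 points
out = mlp(feats)
print(out.shape)                        # torch.Size([4, 64, 1024, 1])
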
/utils/load_save_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: load_save_util.py
4 |
5 | import torch
6 |
7 |
8 | # def load_checkpoint_old2(model_load_path, model):
9 | def load_checkpoint(model_load_path, model):
10 | pre_weight = torch.load(model_load_path)
11 | my_model_dict = model.state_dict()
12 | part_load = {}
13 | match_size = 0
14 | nomatch_size = 0
15 | for k in pre_weight.keys():
16 | value = pre_weight[k]
17 | # str3 = 'seg_head.sparseModel.1.weight'
18 | # if k.find(str3) > 0:
19 | # value = value[:, :, 0, :]
20 | if k in my_model_dict and my_model_dict[k].shape == value.shape:
21 | # print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape)))
22 | match_size += 1
23 | part_load[k] = value
24 | else:
25 | # print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape)))
26 | Lvalue = len(value.shape)
27 | assert 1 <= Lvalue <= 5
28 | if len(value.shape) == 1:
29 | c = value.shape[0]
30 | cc = my_model_dict[k].shape[0] - c #int(c*0.5)
31 | if 0 < cc <= c:
32 | value = torch.cat([value, value[:cc]], dim=0)
33 | elif cc > c:
34 | value = torch.cat([value, value, value[:(cc-c)]], dim=0)
35 | elif cc < 0:
36 | value = value[:-cc]
37 | else:
38 | # print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape)))
39 | cs = value.shape
40 | ccs = [0]*Lvalue
41 | j = -1
42 | for ci in cs:
43 | j += 1
44 | ccs[j] = (my_model_dict[k].shape[j] - ci)
45 | if ccs[j] != 0:
46 | for m in range(Lvalue):
47 | if m != j:
48 | ccs[m] = value.shape[m]
49 | # print(ccs)
50 | if ccs[j] > 0:
51 | if ccs[j] > ci:
52 | ccs[j] = ccs[j] - ci
53 | if Lvalue == 5:
54 | value = torch.cat([value, value[:ccs[0], :ccs[1], :ccs[2], :ccs[3], :ccs[4]]], dim=j)
55 | elif Lvalue == 4:
56 | # print(value[:ccs[0], :ccs[1], :ccs[2], :ccs[3]].shape)
57 | value = torch.cat([value, value[:ccs[0], :ccs[1], :ccs[2], :ccs[3]]], dim=j)
58 | elif Lvalue == 3:
59 | value = torch.cat([value, value[:ccs[0], :ccs[1], :ccs[2]]], dim=j)
60 | elif Lvalue == 2:
61 | value = torch.cat([value, value[:ccs[0], :ccs[1]]], dim=j)
62 | ccs[j] = value.shape[j]
63 | elif ccs[j] < 0:
64 | # ccs[j] = -ccs[j]
65 | ccs[j] = my_model_dict[k].shape[j]
66 | if j == 0:
67 | value = value[:ccs[0], :]
68 | elif j == 1:
69 | if Lvalue == 2:
70 | value = value[:, :ccs[1]]
71 | else:
72 | value = value[:, :ccs[1], :]
73 | elif j == 2:
74 | if Lvalue == 3:
75 | value = value[:, :, :ccs[2]]
76 | else:
77 | value = value[:, :, :ccs[2], :]
78 | elif j == 3 and j <= Lvalue-1:
79 | if Lvalue == 4:
80 | value = value[:, :, :, :ccs[3]]
81 | else:
82 | value = value[:, :, :, :ccs[3], :]
83 | elif j == 4 and j <= Lvalue-1:
84 | value = value[:, :, :, :, :ccs[4]]
85 |
86 | nomatch_size += 1
87 | if my_model_dict[k].shape == value.shape:
88 | part_load[k] = value
89 | # print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape)))
90 | # assert my_model_dict[k].shape == value.shape
91 |
92 | print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size))
93 |
94 | my_model_dict.update(part_load)
95 | model.load_state_dict(my_model_dict)
96 | # model.load_state_dict(my_model_dict, strict=False) # True
97 |
98 | return model
99 |
100 |
101 | def load_checkpoint_old2(model_load_path, model):
102 | my_model_dict = model.state_dict()
103 | pre_weight = torch.load(model_load_path)
104 |
105 | part_load = {}
106 | match_size = 0
107 | nomatch_size = 0
108 | for k in pre_weight.keys():
109 | value = pre_weight[k]
110 | if k in my_model_dict and my_model_dict[k].shape == value.shape:
111 | #print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape)))
112 | match_size += 1
113 | part_load[k] = value
114 | else:
115 | assert len(value.shape) == 1 or len(value.shape) == 5
116 | if len(value.shape) == 1:
117 | c = value.shape[0]
118 | cc = my_model_dict[k].shape[0] - c #int(c*0.5)
119 | if cc <= c:
120 | value = torch.cat([value, value[:cc]], dim=0)
121 | else:
122 | value = torch.cat([value, value, value[:(cc-c)]], dim=0)
123 | else:
124 | _, _, _, c1, c2 = value.shape
125 | cc1 = my_model_dict[k].shape[3] - c1 #int(c1*0.5)
126 | cc2 = my_model_dict[k].shape[4] - c2 #int(c2*0.5)
127 | if cc1 > 0 and cc1 <= c1:
128 | value1 = torch.cat([value, value[:, :, :, :cc1, :]], dim=3)
129 | elif cc1 > c1:
130 | value1 = torch.cat([value, value, value[:, :, :, :(cc1-c1), :]], dim=3)
131 | else:
132 | value1 = value
133 | if cc2 > 0 and cc2 <= c2:
134 | value = torch.cat([value1, value1[:, :, :, :, :cc2]], dim=4)
135 | elif cc2 > c2:
136 | value = torch.cat([value1, value1, value1[:, :, :, :, :(cc2-c2)]], dim=4)
137 | else:
138 | value = value1
139 | nomatch_size += 1
140 | part_load[k] = value
141 | assert my_model_dict[k].shape == value.shape
142 | #print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape)))
143 |
144 | print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size))
145 |
146 | my_model_dict.update(part_load)
147 | # model.load_state_dict(my_model_dict)
148 | model.load_state_dict(my_model_dict, strict=False) # True
149 |
150 | return model
151 |
152 | def load_checkpoint_old(model_load_path, model):
153 | my_model_dict = model.state_dict()
154 | pre_weight = torch.load(model_load_path)
155 |
156 | part_load = {}
157 | match_size = 0
158 | nomatch_size = 0
159 | for k in pre_weight.keys():
160 | value = pre_weight[k]
161 | if k in my_model_dict and my_model_dict[k].shape == value.shape:
162 | # print("loading ", k)
163 | match_size += 1
164 | part_load[k] = value
165 | else:
166 | nomatch_size += 1
167 |
168 | print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size))
169 |
170 | my_model_dict.update(part_load)
171 | model.load_state_dict(my_model_dict)
172 |
173 | return model
174 |
175 | def load_checkpoint_1b1(model_load_path, model):
176 | my_model_dict = model.state_dict()
177 | pre_weight = torch.load(model_load_path)
178 |
179 | part_load = {}
180 | match_size = 0
181 | nomatch_size = 0
182 |
183 | pre_weight_list = [*pre_weight]
184 | my_model_dict_list = [*my_model_dict]
185 |
186 | for idx in range(len(pre_weight_list)):
187 | key_ = pre_weight_list[idx]
188 | key_2 = my_model_dict_list[idx]
189 | value_ = pre_weight[key_]
190 | if my_model_dict[key_2].shape == pre_weight[key_].shape:
191 | # print("loading ", k)
192 | match_size += 1
193 | part_load[key_2] = value_
194 | else:
195 | print(key_)
196 | print(key_2)
197 | nomatch_size += 1
198 |
199 | print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size))
200 |
201 | my_model_dict.update(part_load)
202 | model.load_state_dict(my_model_dict)
203 |
204 | return model
205 |
--------------------------------------------------------------------------------
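
load_checkpoint keeps every tensor whose shape matches the target model exactly and tiles or crops the rest along the mismatched dimensions before loading. A self-contained sketch of that behavior (layer sizes and path are hypothetical):

import torch
import torch.nn as nn
from utils.load_save_util import load_checkpoint

# Save a narrow layer, then load it into a wider one: the 4-unit bias and
# the (4, 8) weight are tiled up to 6 output units to fit the target shapes.
torch.save(nn.Linear(8, 4).state_dict(), '/tmp/narrow.pt')  # hypothetical path
wide = nn.Linear(8, 6)
wide = load_checkpoint('/tmp/narrow.pt', wide)  # prints matched / not-matched counts
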
/utils/lovasz_losses.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 |
4 |
5 | """
6 | Lovasz-Softmax and Jaccard hinge loss in PyTorch
7 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
8 | """
9 |
10 | from __future__ import print_function, division
11 |
12 | import torch
13 | from torch.autograd import Variable
14 | import torch.nn.functional as F
15 | import numpy as np
16 | try:
17 | from itertools import ifilterfalse
18 | except ImportError: # py3k
19 | from itertools import filterfalse as ifilterfalse
20 |
21 | def lovasz_grad(gt_sorted):
22 | """
23 | Computes gradient of the Lovasz extension w.r.t sorted errors
24 | See Alg. 1 in paper
25 | """
26 | p = len(gt_sorted)
27 | gts = gt_sorted.sum()
28 | intersection = gts - gt_sorted.float().cumsum(0)
29 | union = gts + (1 - gt_sorted).float().cumsum(0)
30 | jaccard = 1. - intersection / union
31 | if p > 1: # cover 1-pixel case
32 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
33 | return jaccard
34 |
35 |
36 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
37 | """
38 | IoU for foreground class
39 | binary: 1 foreground, 0 background
40 | """
41 | if not per_image:
42 | preds, labels = (preds,), (labels,)
43 | ious = []
44 | for pred, label in zip(preds, labels):
45 | intersection = ((label == 1) & (pred == 1)).sum()
46 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
47 | if not union:
48 | iou = EMPTY
49 | else:
50 | iou = float(intersection) / float(union)
51 | ious.append(iou)
52 | iou = mean(ious)    # mean across images if per_image
53 | return 100 * iou
54 |
55 |
56 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
57 | """
58 | Array of IoU for each (non ignored) class
59 | """
60 | if not per_image:
61 | preds, labels = (preds,), (labels,)
62 | ious = []
63 | for pred, label in zip(preds, labels):
64 | iou = []
65 | for i in range(C):
66 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
67 | intersection = ((label == i) & (pred == i)).sum()
68 | union = ((label == i) | ((pred == i) & (label != ignore))).sum()
69 | if not union:
70 | iou.append(EMPTY)
71 | else:
72 | iou.append(float(intersection) / float(union))
73 | ious.append(iou)
74 | ious = [mean(iou) for iou in zip(*ious)] # mean across images if per_image
75 | return 100 * np.array(ious)
76 |
77 |
78 | # --------------------------- BINARY LOSSES ---------------------------
79 |
80 |
81 | def lovasz_hinge(logits, labels, per_image=True, ignore=None):
82 | """
83 | Binary Lovasz hinge loss
84 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
85 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
86 | per_image: compute the loss per image instead of per batch
87 | ignore: void class id
88 | """
89 | if per_image:
90 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
91 | for log, lab in zip(logits, labels))
92 | else:
93 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
94 | return loss
95 |
96 |
97 | def lovasz_hinge_flat(logits, labels):
98 | """
99 | Binary Lovasz hinge loss
100 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
101 | labels: [P] Tensor, binary ground truth labels (0 or 1)
102 | ignore: label to ignore
103 | """
104 | if len(labels) == 0:
105 | # only void pixels, the gradients should be 0
106 | return logits.sum() * 0.
107 | signs = 2. * labels.float() - 1.
108 | errors = (1. - logits * Variable(signs))
109 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
110 | perm = perm.data
111 | gt_sorted = labels[perm]
112 | grad = lovasz_grad(gt_sorted)
113 | loss = torch.dot(F.relu(errors_sorted), Variable(grad))
114 | return loss
115 |
116 |
117 | def flatten_binary_scores(scores, labels, ignore=None):
118 | """
119 | Flattens predictions in the batch (binary case)
120 | Remove labels equal to 'ignore'
121 | """
122 | scores = scores.view(-1)
123 | labels = labels.view(-1)
124 | if ignore is None:
125 | return scores, labels
126 | valid = (labels != ignore)
127 | vscores = scores[valid]
128 | vlabels = labels[valid]
129 | return vscores, vlabels
130 |
131 |
132 | class StableBCELoss(torch.nn.modules.Module):
133 | def __init__(self):
134 | super(StableBCELoss, self).__init__()
135 | def forward(self, input, target):
136 | neg_abs = - input.abs()
137 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
138 | return loss.mean()
139 |
140 |
141 | def binary_xloss(logits, labels, ignore=None):
142 | """
143 | Binary Cross entropy loss
144 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
145 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
146 | ignore: void class id
147 | """
148 | logits, labels = flatten_binary_scores(logits, labels, ignore)
149 | loss = StableBCELoss()(logits, Variable(labels.float()))
150 | return loss
151 |
152 |
153 | # --------------------------- MULTICLASS LOSSES ---------------------------
154 |
155 |
156 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
157 | """
158 | Multi-class Lovasz-Softmax loss
159 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
160 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
161 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
162 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
163 | per_image: compute the loss per image instead of per batch
164 | ignore: void class labels
165 | """
166 | if per_image:
167 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
168 | for prob, lab in zip(probas, labels))
169 | else:
170 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
171 | return loss
172 |
173 |
174 | def lovasz_softmax_flat(probas, labels, classes='present'):
175 | """
176 | Multi-class Lovasz-Softmax loss
177 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
178 | labels: [P] Tensor, ground truth labels (between 0 and C - 1)
179 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
180 | """
181 | if probas.numel() == 0:
182 | # only void pixels, the gradients should be 0
183 | return probas * 0.
184 | C = probas.size(1)
185 | losses = []
186 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
187 | for c in class_to_sum:
188 | fg = (labels == c).float() # foreground for class c
189 | if (classes == 'present' and fg.sum() == 0):
190 | continue
191 | if C == 1:
192 | if len(classes) > 1:
193 | raise ValueError('Sigmoid output possible only with 1 class')
194 | class_pred = probas[:, 0]
195 | else:
196 | class_pred = probas[:, c]
197 | errors = (Variable(fg) - class_pred).abs()
198 | errors_sorted, perm = torch.sort(errors, 0, descending=True)
199 | perm = perm.data
200 | fg_sorted = fg[perm]
201 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
202 | return mean(losses)
203 |
204 |
205 | def flatten_probas(probas, labels, ignore=None):
206 | """
207 | Flattens predictions in the batch
208 | """
209 | if probas.dim() == 3:
210 | # assumes output of a sigmoid layer
211 | B, H, W = probas.size()
212 | probas = probas.view(B, 1, H, W)
213 | elif probas.dim() == 5:
214 | #3D segmentation
215 | B, C, L, H, W = probas.size()
216 | probas = probas.contiguous().view(B, C, L, H*W)
217 | B, C, H, W = probas.size()
218 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
219 | labels = labels.view(-1)
220 | if ignore is None:
221 | return probas, labels
222 | valid = (labels != ignore)
223 | vprobas = probas[valid.nonzero().squeeze()]
224 | vlabels = labels[valid]
225 | return vprobas, vlabels
226 |
227 | def xloss(logits, labels, ignore=None):
228 | """
229 | Cross entropy loss
230 | """
231 | return F.cross_entropy(logits, Variable(labels), ignore_index=255)
232 |
233 | def jaccard_loss(probas, labels, ignore=None, smooth=100, bk_class=None):
234 | """
235 | Smoothed Jaccard (soft IoU) loss
236 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
237 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
238 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
239 | ignore: void class labels
240 | smooth: additive smoothing constant
241 | bk_class: optional class whose one-hot target is zeroed out
242 | returns (1 - smoothed IoU) * smooth
243 | """
244 | vprobas, vlabels = flatten_probas(probas, labels, ignore)
245 |
246 |
247 | true_1_hot = torch.eye(vprobas.shape[1])[vlabels]
248 |
249 | if bk_class:
250 | one_hot_assignment = torch.ones_like(vlabels)
251 | one_hot_assignment[vlabels == bk_class] = 0
252 | one_hot_assignment = one_hot_assignment.float().unsqueeze(1)
253 | true_1_hot = true_1_hot*one_hot_assignment
254 |
255 | true_1_hot = true_1_hot.to(vprobas.device)
256 | intersection = torch.sum(vprobas * true_1_hot)
257 | cardinality = torch.sum(vprobas + true_1_hot)
258 | loss = ((intersection + smooth) / (cardinality - intersection + smooth)).mean()  # smoothed soft IoU
259 | return (1-loss)*smooth
260 |
261 | def hinge_jaccard_loss(probas, labels, ignore=None, classes='present', hinge=0.1, smooth=100):
262 | """
263 | Multi-class Hinge Jaccard loss
264 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
265 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
266 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
267 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
268 | ignore: void class labels
269 | """
270 | vprobas, vlabels = flatten_probas(probas, labels, ignore)
271 | C = vprobas.size(1)
272 | losses = []
273 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
274 | for c in class_to_sum:
275 | if c in vlabels:
276 | c_sample_ind = vlabels == c
277 | cprobas = vprobas[c_sample_ind,:]
278 | non_c_ind =np.array([a for a in class_to_sum if a != c])
279 | class_pred = cprobas[:,c]
280 | max_non_class_pred = torch.max(cprobas[:,non_c_ind],dim = 1)[0]
281 | TP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.) + smooth
282 | FN = torch.sum(torch.clamp(max_non_class_pred - class_pred, min = -hinge)+hinge)
283 |
284 | if (~c_sample_ind).sum() == 0:
285 | FP = 0
286 | else:
287 | nonc_probas = vprobas[~c_sample_ind,:]
288 | class_pred = nonc_probas[:,c]
289 | max_non_class_pred = torch.max(nonc_probas[:,non_c_ind],dim = 1)[0]
290 | FP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.)
291 |
292 | losses.append(1 - TP/(TP+FP+FN))
293 |
294 | if len(losses) == 0: return 0
295 | return mean(losses)
296 |
297 | # --------------------------- HELPER FUNCTIONS ---------------------------
298 | def isnan(x):
299 | return x != x
300 |
301 |
302 | def mean(l, ignore_nan=False, empty=0):
303 | """
304 | nanmean compatible with generators.
305 | """
306 | l = iter(l)
307 | if ignore_nan:
308 | l = ifilterfalse(isnan, l)
309 | try:
310 | n = 1
311 | acc = next(l)
312 | except StopIteration:
313 | if empty == 'raise':
314 | raise ValueError('Empty mean')
315 | return empty
316 | for n, v in enumerate(l, 2):
317 | acc += v
318 | if n == 1:
319 | return acc
320 | return acc / n
321 |
--------------------------------------------------------------------------------
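
A usage sketch for lovasz_softmax (shapes illustrative): it expects softmax probabilities, integer labels in [0, C-1], and an optional void label to exclude.

import torch
import torch.nn.functional as F
from utils.lovasz_losses import lovasz_softmax

logits = torch.randn(2, 20, 64, 64, requires_grad=True)  # B=2, C=20 classes
labels = torch.randint(0, 20, (2, 64, 64))
labels[0, :8, :8] = 255                                  # mark some pixels void

loss = lovasz_softmax(F.softmax(logits, dim=1), labels, ignore=255)
loss.backward()
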
/dataloader/pc_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: pc_dataset.py
4 |
5 | import os
6 | import numpy as np
7 | from torch.utils import data
8 | import yaml
9 | import pickle
10 | import pathlib
11 | REGISTERED_PC_DATASET_CLASSES = {}
12 |
13 |
14 | def register_dataset(cls, name=None):
15 | global REGISTERED_PC_DATASET_CLASSES
16 | if name is None:
17 | name = cls.__name__
18 | assert name not in REGISTERED_PC_DATASET_CLASSES, f"class already registered: {REGISTERED_PC_DATASET_CLASSES}"
19 | REGISTERED_PC_DATASET_CLASSES[name] = cls
20 | return cls
21 |
22 |
23 | def get_pc_model_class(name):
24 | global REGISTERED_PC_DATASET_CLASSES
25 | # print(REGISTERED_PC_DATASET_CLASSES)
26 | assert name in REGISTERED_PC_DATASET_CLASSES, f"available class: {REGISTERED_PC_DATASET_CLASSES}"
27 | return REGISTERED_PC_DATASET_CLASSES[name]
28 |
29 |
30 | @register_dataset
31 | class SemKITTI_sk(data.Dataset):
32 | def __init__(self, data_path, imageset='train',
33 | return_ref=False, label_mapping="semantic-kitti.yaml", nusc=None):
34 | self.return_ref = return_ref
35 | with open(label_mapping, 'r') as stream:
36 | semkittiyaml = yaml.safe_load(stream)
37 | self.learning_map = semkittiyaml['learning_map']
38 | self.imageset = imageset
39 | if imageset == 'train':
40 | split = semkittiyaml['split']['train']
41 | elif imageset == 'val':
42 | split = semkittiyaml['split']['valid']
43 | elif imageset == 'test':
44 | split = semkittiyaml['split']['test']
45 | else:
46 | raise Exception('Split must be train/val/test')
47 |
48 | self.im_idx = []
49 | for i_folder in split:
50 | self.im_idx += absoluteFilePaths('/'.join([data_path, str(i_folder).zfill(2), 'velodyne']))
51 |
52 | def __len__(self):
53 | 'Denotes the total number of samples'
54 | return len(self.im_idx)
55 |
56 | def __getitem__(self, index):
57 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4))
58 | if self.imageset == 'test':
59 | annotated_data = np.expand_dims(np.zeros_like(raw_data[:, 0], dtype=int), axis=1)
60 | else:
61 | annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'labels')[:-3] + 'label',
62 | dtype=np.uint32).reshape((-1, 1))
63 | annotated_data = annotated_data & 0xFFFF # delete high 16 digits binary
64 | annotated_data = np.vectorize(self.learning_map.__getitem__)(annotated_data)
65 |
66 | data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8))
67 | if self.return_ref:
68 | data_tuple += (raw_data[:, 3],)
69 | return data_tuple
70 |
71 |
72 | def absoluteFilePaths(directory):
73 | for dirpath, _, filenames in os.walk(directory):
74 | filenames.sort()
75 | for f in filenames:
76 | yield os.path.abspath(os.path.join(dirpath, f))
77 |
78 |
79 | def SemKITTI2train(label):
80 | if isinstance(label, list):
81 | return [SemKITTI2train_single(a) for a in label]
82 | else:
83 | return SemKITTI2train_single(label)
84 |
85 |
86 | def SemKITTI2train_single(label):
87 | remove_ind = label == 0
88 | label -= 1
89 | label[remove_ind] = 255
90 | return label
91 |
92 |
93 | def unpack(compressed): # from the semantic-kitti api
94 | ''' given a bit encoded voxel grid, make a normal voxel grid out of it. '''
95 | uncompressed = np.zeros(compressed.shape[0] * 8, dtype=np.uint8)
96 | uncompressed[::8] = compressed[:] >> 7 & 1
97 | uncompressed[1::8] = compressed[:] >> 6 & 1
98 | uncompressed[2::8] = compressed[:] >> 5 & 1
99 | uncompressed[3::8] = compressed[:] >> 4 & 1
100 | uncompressed[4::8] = compressed[:] >> 3 & 1
101 | uncompressed[5::8] = compressed[:] >> 2 & 1
102 | uncompressed[6::8] = compressed[:] >> 1 & 1
103 | uncompressed[7::8] = compressed[:] & 1
104 |
105 | return uncompressed
106 |
107 | def get_eval_mask(labels, invalid_voxels): # from the semantic-kitti api
108 | """
109 | Ignore labels set to 255 and invalid voxels (the ones never hit by a laser ray, probed using ray tracing)
110 | :param labels: input ground truth voxels
111 | :param invalid_voxels: voxels ignored during evaluation since they lie beyond the scene that was captured by the laser
112 | :return: boolean mask to subsample the voxels to evaluate
113 | """
114 | masks = np.ones_like(labels, dtype=bool)
115 | masks[labels == 255] = False
116 | masks[invalid_voxels == 1] = False
117 |
118 | return masks
119 |
120 |
121 | from os.path import join
122 | @register_dataset
123 | class SemKITTI_sk_multiscan(data.Dataset):
124 | def __init__(self, data_path, imageset='train',stride=5,return_ref=False, label_mapping="semantic-kitti-multiscan.yaml", nusc=None):
125 | self.return_ref = return_ref
126 | with open(label_mapping, 'r') as stream:
127 | semkittiyaml = yaml.safe_load(stream)
128 | ### remap completion label
129 | remapdict = semkittiyaml['learning_map']
130 | # make lookup table for mapping
131 | maxkey = max(remapdict.keys())
132 | remap_lut = np.zeros((maxkey + 100), dtype=np.int32)
133 | remap_lut[list(remapdict.keys())] = list(remapdict.values())
134 | # in completion we have to distinguish empty and invalid voxels.
135 | # Important: For voxels 0 corresponds to "empty" and not "unlabeled".
136 | remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid'
137 | remap_lut[0] = 0 # only 'empty' stays 'empty'.
138 | self.comletion_remap_lut = remap_lut
139 |
140 | self.learning_map = semkittiyaml['learning_map']
141 | self.imageset = imageset
142 | self.data_path = data_path
143 | if imageset == 'train':
144 | split = semkittiyaml['split']['train']
145 | elif imageset == 'val':
146 | split = semkittiyaml['split']['valid']
147 | elif imageset == 'test':
148 | split = semkittiyaml['split']['test']
149 | else:
150 | raise Exception('Split must be train/val/test')
151 | multiscan = 0 # additional frames are fused with target-frame. Hence, multiscan+1 point clouds in total
152 | print('multiscan: %d' %multiscan)
153 | self.multiscan = multiscan
154 | self.im_idx = []
155 | self.stride = stride
156 | self.accumulation=1
157 | self.calibrations = []
158 | self.times = []
159 | self.poses = []
160 | self.use_test_time_adaptation=True
161 | self.load_calib_poses()
162 | for i_folder in split:
163 | # velodyne path corresponding to voxel path
164 | complete_path = os.path.join(data_path, str(i_folder).zfill(2), "voxels")
165 | files = list(pathlib.Path(complete_path).glob('*.bin'))
166 |
167 | for filename in files:
168 | self.im_idx.append(str(filename).replace('voxels', 'velodyne'))
169 | if i_folder==8:
170 | self.im_idx.sort()
171 |
172 |
173 |
174 | def __len__(self):
175 | 'Denotes the total number of samples'
176 | return len(self.im_idx)
177 |
178 | def load_calib_poses(self):
179 | """
180 | load calib poses and times.
181 | """
182 |
183 | ###########
184 | # Load data
185 | ###########
186 |
187 | self.calibrations = []
188 | self.times = []
189 | self.poses = []
190 |
191 | for seq in range(0, 22):
192 | seq_folder = join(self.data_path, str(seq).zfill(2))
193 |
194 | # Read Calib
195 | self.calibrations.append(self.parse_calibration(join(seq_folder, "calib.txt")))
196 |
197 | # Read times
198 | self.times.append(np.loadtxt(join(seq_folder, 'times.txt'), dtype=np.float32))
199 |
200 | # Read poses
201 | poses_f64 = self.parse_poses(join(seq_folder, 'poses.txt'), self.calibrations[-1])
202 | self.poses.append([pose.astype(np.float32) for pose in poses_f64])
203 |
204 | def parse_calibration(self, filename):
205 | """ read calibration file with given filename
206 |
207 | Returns
208 | -------
209 | dict
210 | Calibration matrices as 4x4 numpy arrays.
211 | """
212 | calib = {}
213 |
214 | calib_file = open(filename)
215 | for line in calib_file:
216 | key, content = line.strip().split(":")
217 | values = [float(v) for v in content.strip().split()]
218 |
219 | pose = np.zeros((4, 4))
220 | pose[0, 0:4] = values[0:4]
221 | pose[1, 0:4] = values[4:8]
222 | pose[2, 0:4] = values[8:12]
223 | pose[3, 3] = 1.0
224 |
225 | calib[key] = pose
226 |
227 | calib_file.close()
228 |
229 | return calib
230 |
231 | def parse_poses(self, filename, calibration):
232 | """ read poses file with per-scan poses from given filename
233 |
234 | Returns
235 | -------
236 | list
237 | list of poses as 4x4 numpy arrays.
238 | """
239 | file = open(filename)
240 |
241 | poses = []
242 | Tr = calibration["Tr"]
243 | Tr_inv = np.linalg.inv(Tr)
244 |
245 | for line in file:
246 | values = [float(v) for v in line.strip().split()]
247 |
248 | pose = np.zeros((4, 4))
249 | pose[0, 0:4] = values[0:4]
250 | pose[1, 0:4] = values[4:8]
251 | pose[2, 0:4] = values[8:12]
252 | pose[3, 3] = 1.0
253 |
254 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr)))
255 |
256 | return poses
257 |
258 | def fuse_multi_scan(self, points, pose0, pose):
259 |
260 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1])))
261 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1)
262 | new_points = new_points[:, :3]
263 | new_coords = new_points - pose0[:3, 3]
264 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1)
265 | new_coords = np.hstack((new_coords, points[:, 3:]))
266 |
267 | return new_coords
268 |
269 | def frame_transform_scan(self, points, pose0, pose):
270 |
271 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1])))
272 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1)
273 | new_points = new_points[:, :3]
274 | new_coords = new_points - pose0[:3, 3]
275 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1)
276 | new_coords = np.hstack((new_coords, points[:, 3:]))
277 |
278 | return new_coords
279 |
280 | def __getitem__(self, index):
281 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4)) # point cloud
282 | origin_len = len(raw_data)
283 | voxel_label = 1
284 | number_idx = int(self.im_idx[index][-10:-4])
285 |
286 |
287 | dir_idx = int(self.im_idx[index][-22:-20])
288 | pose_list=[]
289 | pose0 = self.poses[dir_idx][number_idx]
290 | pose_list.append(pose0)
291 |
292 | prev_scans=[]
293 | prev_annotated_data=[]
294 | stride_list=[]
295 | for stride_ in self.stride:
296 | data_idx = number_idx + stride_
297 |
298 | if data_idx<0:
299 | continue
300 |
301 | path = self.im_idx[index][:-10]
302 | newpath_prev = path + str(data_idx).zfill(6) + self.im_idx[index][-4:]
303 |
304 | voxel_path = path.replace('velodyne', 'labels')
305 | files = list(pathlib.Path(voxel_path).glob('*.label'))
306 |
307 | if data_idx < len(files):
308 | pose = self.poses[dir_idx][data_idx]
309 | pose_list.append(pose)
310 | raw_data_prev = np.fromfile(newpath_prev, dtype=np.float32).reshape((-1, 4))
311 | prev_scans.append(raw_data_prev)
312 | annotated_data_prev = np.zeros([256, 256, 32], dtype=int).reshape((-1, 1))
313 | annotated_data_prev = self.comletion_remap_lut[annotated_data_prev]
314 | annotated_data_prev = annotated_data_prev.reshape((256, 256, 32))
315 | prev_annotated_data.append(annotated_data_prev)
316 | stride_list.append(stride_)
317 |
318 | if len(pose_list)>1:
319 | prev_annotated_data=np.stack(prev_annotated_data,0)
320 | prev_scans_xyz=[prev_scan[:, :3] for prev_scan in prev_scans]
321 | prev_scans_i=[prev_scan[:, 3] for prev_scan in prev_scans]
322 | prev_data_tuple=(prev_scans_xyz, prev_annotated_data.astype(np.uint8))
323 | prev_data_tuple += (prev_scans_i,origin_len)
324 | else:
325 | prev_data_tuple=(np.zeros([1,1,1]),np.zeros([1,1,1]),np.zeros([1,1,1]),1)
326 |
327 | if self.imageset == 'test':
328 | annotated_data = np.zeros([256, 256, 32], dtype=int).reshape((-1, 1))
329 | else:
330 | annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'voxels')[:-3] + 'label',
331 | dtype=np.uint16).reshape((-1, 1)) # voxel labels
332 |
333 | annotated_data = self.comletion_remap_lut[annotated_data]
334 | annotated_data = annotated_data.reshape((256, 256, 32))
335 |
336 | data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8)) # xyz, voxel labels
337 | data_tuple += (raw_data[:, 3], origin_len) # origin_len is used to indicate the length of target-scan
338 |
339 | return data_tuple,prev_data_tuple,pose_list,stride_list
340 |
341 |
342 | # load Semantic KITTI class info
343 | def get_SemKITTI_label_name(label_mapping):
344 | with open(label_mapping, 'r') as stream:
345 | semkittiyaml = yaml.safe_load(stream)
346 | SemKITTI_label_name = dict()
347 | for i in sorted(list(semkittiyaml['learning_map'].keys()))[::-1]:
348 | SemKITTI_label_name[semkittiyaml['learning_map'][i]] = semkittiyaml['labels'][i]
349 |
350 | return SemKITTI_label_name
351 |
352 |
353 | def get_nuScenes_label_name(label_mapping):
354 | with open(label_mapping, 'r') as stream:
355 | nuScenesyaml = yaml.safe_load(stream)
356 | nuScenes_label_name = dict()
357 | for i in sorted(list(nuScenesyaml['learning_map'].keys()))[::-1]:
358 | val_ = nuScenesyaml['learning_map'][i]
359 | nuScenes_label_name[val_] = nuScenesyaml['labels_16'][val_]
360 |
361 | return nuScenes_label_name
362 |
--------------------------------------------------------------------------------
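
The voxels/*.bin files read above are bit-packed: unpack expands each byte into eight occupancy bits, and get_eval_mask drops void (255) and never-observed voxels. A sketch of decoding one frame (paths illustrative; in this repo the labels are remapped with the completion LUT first, so that 255 marks invalid classes):

import numpy as np
from dataloader.pc_dataset import unpack, get_eval_mask

prefix = 'sequences/08/voxels/000000'  # illustrative path prefix
occupancy = unpack(np.fromfile(prefix + '.bin', dtype=np.uint8)).reshape(256, 256, 32)
invalid = unpack(np.fromfile(prefix + '.invalid', dtype=np.uint8)).reshape(256, 256, 32)
labels = np.fromfile(prefix + '.label', dtype=np.uint16).reshape(256, 256, 32)

# After the completion-LUT remap (see SemKITTI_sk_multiscan.__init__ above),
# get_eval_mask selects the voxels that count toward the score.
eval_mask = get_eval_mask(labels, invalid)
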
/dataloader/pc_dataset_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 | # @file: pc_dataset.py
4 |
5 | import os
6 | import numpy as np
7 | from torch.utils import data
8 | import yaml
9 | import pickle
10 | import pathlib
11 | REGISTERED_PC_DATASET_CLASSES = {}
12 |
13 |
14 | def register_dataset(cls, name=None):
15 | global REGISTERED_PC_DATASET_CLASSES
16 | if name is None:
17 | name = cls.__name__
18 | assert name not in REGISTERED_PC_DATASET_CLASSES, f"class already registered: {REGISTERED_PC_DATASET_CLASSES}"
19 | REGISTERED_PC_DATASET_CLASSES[name] = cls
20 | return cls
21 |
22 |
23 | def get_pc_model_class(name):
24 | global REGISTERED_PC_DATASET_CLASSES
25 | # print(REGISTERED_PC_DATASET_CLASSES)
26 | assert name in REGISTERED_PC_DATASET_CLASSES, f"available class: {REGISTERED_PC_DATASET_CLASSES}"
27 | return REGISTERED_PC_DATASET_CLASSES[name]
28 |
29 |
30 | @register_dataset
31 | class SemKITTI_sk(data.Dataset):
32 | def __init__(self, data_path, imageset='train',
33 | return_ref=False, label_mapping="semantic-kitti.yaml", nusc=None):
34 | self.return_ref = return_ref
35 | with open(label_mapping, 'r') as stream:
36 | semkittiyaml = yaml.safe_load(stream)
37 | self.learning_map = semkittiyaml['learning_map']
38 | self.imageset = imageset
39 | if imageset == 'train':
40 | split = semkittiyaml['split']['train']
41 | elif imageset == 'val':
42 | split = semkittiyaml['split']['valid']
43 | elif imageset == 'test':
44 | split = semkittiyaml['split']['test']
45 | else:
46 | raise Exception('Split must be train/val/test')
47 |
48 | self.im_idx = []
49 | for i_folder in split:
50 | self.im_idx += absoluteFilePaths('/'.join([data_path, str(i_folder).zfill(2), 'velodyne']))
51 |
52 | def __len__(self):
53 | 'Denotes the total number of samples'
54 | return len(self.im_idx)
55 |
56 | def __getitem__(self, index):
57 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4))
58 | if self.imageset == 'test':
59 | annotated_data = np.expand_dims(np.zeros_like(raw_data[:, 0], dtype=int), axis=1)
60 | else:
61 | annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'labels')[:-3] + 'label',
62 | dtype=np.uint32).reshape((-1, 1))
63 | annotated_data = annotated_data & 0xFFFF # delete high 16 digits binary
64 | annotated_data = np.vectorize(self.learning_map.__getitem__)(annotated_data)
65 |
66 | data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8))
67 | if self.return_ref:
68 | data_tuple += (raw_data[:, 3],)
69 | return data_tuple
70 |
71 |
72 | def absoluteFilePaths(directory):
73 | for dirpath, _, filenames in os.walk(directory):
74 | filenames.sort()
75 | for f in filenames:
76 | yield os.path.abspath(os.path.join(dirpath, f))
77 |
78 |
79 | def SemKITTI2train(label):
80 | if isinstance(label, list):
81 | return [SemKITTI2train_single(a) for a in label]
82 | else:
83 | return SemKITTI2train_single(label)
84 |
85 |
86 | def SemKITTI2train_single(label):
87 | remove_ind = label == 0
88 | label -= 1
89 | label[remove_ind] = 255
90 | return label
91 |
92 |
93 | def unpack(compressed): # from the semantic-kitti api
94 | ''' given a bit encoded voxel grid, make a normal voxel grid out of it. '''
95 | uncompressed = np.zeros(compressed.shape[0] * 8, dtype=np.uint8)
96 | uncompressed[::8] = compressed[:] >> 7 & 1
97 | uncompressed[1::8] = compressed[:] >> 6 & 1
98 | uncompressed[2::8] = compressed[:] >> 5 & 1
99 | uncompressed[3::8] = compressed[:] >> 4 & 1
100 | uncompressed[4::8] = compressed[:] >> 3 & 1
101 | uncompressed[5::8] = compressed[:] >> 2 & 1
102 | uncompressed[6::8] = compressed[:] >> 1 & 1
103 | uncompressed[7::8] = compressed[:] & 1
104 |
105 | return uncompressed
106 |
107 | def get_eval_mask(labels, invalid_voxels): # from the semantic-kitti api
108 | """
109 | Ignore labels set to 255 and invalid voxels (the ones never hit by a laser ray, probed using ray tracing)
110 | :param labels: input ground truth voxels
111 | :param invalid_voxels: voxels ignored during evaluation since they lie beyond the scene that was captured by the laser
112 | :return: boolean mask to subsample the voxels to evaluate
113 | """
114 | masks = np.ones_like(labels, dtype=bool)
115 | masks[labels == 255] = False
116 | masks[invalid_voxels == 1] = False
117 |
118 | return masks
119 |
120 |
121 | from os.path import join
122 | @register_dataset
123 | class SemKITTI_sk_multiscan(data.Dataset):
124 | def __init__(self, data_path, imageset='train',stride=5,return_ref=False, label_mapping="semantic-kitti-multiscan.yaml", nusc=None,sq_num=10):
125 | self.return_ref = return_ref
126 | with open(label_mapping, 'r') as stream:
127 | semkittiyaml = yaml.safe_load(stream)
128 | ### remap completion label
129 | remapdict = semkittiyaml['learning_map']
130 | # make lookup table for mapping
131 | maxkey = max(remapdict.keys())
132 | remap_lut = np.zeros((maxkey + 100), dtype=np.int32)
133 | remap_lut[list(remapdict.keys())] = list(remapdict.values())
134 | # in completion we have to distinguish empty and invalid voxels.
135 | # Important: For voxels 0 corresponds to "empty" and not "unlabeled".
136 | remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid'
137 | remap_lut[0] = 0 # only 'empty' stays 'empty'.
138 | self.comletion_remap_lut = remap_lut
139 |
140 | self.learning_map = semkittiyaml['learning_map']
141 | self.imageset = imageset
142 | self.data_path = data_path
143 | if imageset == 'train':
144 | split = semkittiyaml['split']['train']
145 | elif imageset == 'val':
146 | split = semkittiyaml['split']['valid']
147 | elif imageset == 'test':
148 | # split = semkittiyaml['split']['test']
149 | split = [sq_num]
150 | else:
151 | raise Exception('Split must be train/val/test')
152 | # import pdb;pdb.set_trace()
153 | multiscan = 0 # additional frames are fused with target-frame. Hence, multiscan+1 point clouds in total
154 | print('multiscan: %d' %multiscan)
155 | self.multiscan = multiscan
156 | self.im_idx = []
157 | self.stride = stride
158 | self.accumulation=1
159 | self.calibrations = []
160 | self.times = []
161 | self.poses = []
162 | self.use_test_time_adaptation=True
163 | self.load_calib_poses()
164 | for i_folder in split:
165 | # velodyne path corresponding to voxel path
166 | complete_path = os.path.join(data_path, str(i_folder).zfill(2), "voxels")
167 | files = list(pathlib.Path(complete_path).glob('*.bin'))
168 |
169 | for filename in files:
170 | self.im_idx.append(str(filename).replace('voxels', 'velodyne'))
171 |
172 | self.im_idx.sort()
173 |
174 |
175 |
176 | def __len__(self):
177 | 'Denotes the total number of samples'
178 | return len(self.im_idx)
179 |
180 | def load_calib_poses(self):
181 | """
182 | load calib poses and times.
183 | """
184 |
185 | ###########
186 | # Load data
187 | ###########
188 |
189 | self.calibrations = []
190 | self.times = []
191 | self.poses = []
192 |
193 | for seq in range(0, 22):
194 | seq_folder = join(self.data_path, str(seq).zfill(2))
195 |
196 | # Read Calib
197 | self.calibrations.append(self.parse_calibration(join(seq_folder, "calib.txt")))
198 |
199 | # Read times
200 | self.times.append(np.loadtxt(join(seq_folder, 'times.txt'), dtype=np.float32))
201 |
202 | # Read poses
203 | poses_f64 = self.parse_poses(join(seq_folder, 'poses.txt'), self.calibrations[-1])
204 | self.poses.append([pose.astype(np.float32) for pose in poses_f64])
205 |
206 | def parse_calibration(self, filename):
207 | """ read calibration file with given filename
208 |
209 | Returns
210 | -------
211 | dict
212 | Calibration matrices as 4x4 numpy arrays.
213 | """
214 | calib = {}
215 |
216 | calib_file = open(filename)
217 | for line in calib_file:
218 | key, content = line.strip().split(":")
219 | values = [float(v) for v in content.strip().split()]
220 |
221 | pose = np.zeros((4, 4))
222 | pose[0, 0:4] = values[0:4]
223 | pose[1, 0:4] = values[4:8]
224 | pose[2, 0:4] = values[8:12]
225 | pose[3, 3] = 1.0
226 |
227 | calib[key] = pose
228 |
229 | calib_file.close()
230 |
231 | return calib
232 |
233 | def parse_poses(self, filename, calibration):
234 | """ read poses file with per-scan poses from given filename
235 |
236 | Returns
237 | -------
238 | list
239 | list of poses as 4x4 numpy arrays.
240 | """
241 | file = open(filename)
242 |
243 | poses = []
244 | Tr = calibration["Tr"]
245 | Tr_inv = np.linalg.inv(Tr)
246 |
247 | for line in file:
248 | values = [float(v) for v in line.strip().split()]
249 |
250 | pose = np.zeros((4, 4))
251 | pose[0, 0:4] = values[0:4]
252 | pose[1, 0:4] = values[4:8]
253 | pose[2, 0:4] = values[8:12]
254 | pose[3, 3] = 1.0
255 |
256 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr)))
257 |
258 | return poses
259 |
260 | def fuse_multi_scan(self, points, pose0, pose):
261 |
262 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1])))
263 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1)
264 | new_points = new_points[:, :3]
265 | new_coords = new_points - pose0[:3, 3]
266 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1)
267 | new_coords = np.hstack((new_coords, points[:, 3:]))
268 |
269 | return new_coords
270 |
271 | def frame_transform_scan(self, points, pose0, pose):
272 |
273 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1])))
274 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1)
275 | new_points = new_points[:, :3]
276 | new_coords = new_points - pose0[:3, 3]
277 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1)
278 | new_coords = np.hstack((new_coords, points[:, 3:]))
279 |
280 | return new_coords
281 |
282 | def __getitem__(self, index):
283 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4)) # point cloud
284 | origin_len = len(raw_data)
285 | voxel_label = 1
286 | number_idx = int(self.im_idx[index][-10:-4])
287 | dir_idx = int(self.im_idx[index][-22:-20])
288 | pose_list=[]
289 | pose0 = self.poses[dir_idx][number_idx]
290 | pose_list.append(pose0)
291 |
292 | prev_scans=[]
293 | prev_annotated_data=[]
294 | stride_list=[]
295 | for stride_ in self.stride:
296 | data_idx = number_idx + stride_
297 |
298 | if data_idx<0:
299 | continue
300 |
301 | path = self.im_idx[index][:-10]
302 | newpath_prev = path + str(data_idx).zfill(6) + self.im_idx[index][-4:]
303 | voxel_path = path.replace('velodyne', 'labels')
304 | files=list(pathlib.Path(path).glob('*.bin'))
305 |
306 | if data_idx < len(files):
307 | pose = self.poses[dir_idx][data_idx]
308 | pose_list.append(pose)
309 | raw_data_prev = np.fromfile(newpath_prev, dtype=np.float32).reshape((-1, 4))
310 | prev_scans.append(raw_data_prev)
311 | annotated_data_prev = np.zeros([256, 256, 32], dtype=int).reshape((-1, 1))
312 | annotated_data_prev = self.comletion_remap_lut[annotated_data_prev]
313 | annotated_data_prev = annotated_data_prev.reshape((256, 256, 32))
314 | prev_annotated_data.append(annotated_data_prev)
315 | stride_list.append(stride_)
316 |
317 | if len(pose_list)>1:
318 | prev_annotated_data=np.stack(prev_annotated_data,0)
319 | prev_scans_xyz=[prev_scan[:, :3] for prev_scan in prev_scans]
320 | prev_scans_i=[prev_scan[:, 3] for prev_scan in prev_scans]
321 | prev_data_tuple=(prev_scans_xyz, prev_annotated_data.astype(np.uint8))
322 | prev_data_tuple += (prev_scans_i,origin_len)
323 | else:
324 | prev_data_tuple=(np.zeros([1,1,1]),np.zeros([1,1,1]),np.zeros([1,1,1]),1)
325 |
326 | if self.imageset == 'test':
327 | annotated_data = np.zeros([256, 256, 32], dtype=int).reshape((-1, 1))
328 | else:
329 | annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'voxels')[:-3] + 'label',
330 | dtype=np.uint16).reshape((-1, 1)) # voxel labels
331 |
332 | annotated_data = self.comletion_remap_lut[annotated_data]
333 | annotated_data = annotated_data.reshape((256, 256, 32))
334 |
335 | data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8)) # xyz, voxel labels
336 | data_tuple += (raw_data[:, 3], origin_len) # origin_len is used to indicate the length of target-scan
337 |
338 | return data_tuple,prev_data_tuple,pose_list,stride_list
339 |
340 |
341 | # load Semantic KITTI class info
342 | def get_SemKITTI_label_name(label_mapping):
343 | with open(label_mapping, 'r') as stream:
344 | semkittiyaml = yaml.safe_load(stream)
345 | SemKITTI_label_name = dict()
346 | for i in sorted(list(semkittiyaml['learning_map'].keys()))[::-1]:
347 | SemKITTI_label_name[semkittiyaml['learning_map'][i]] = semkittiyaml['labels'][i]
348 |
349 | return SemKITTI_label_name
350 |
351 |
352 | def get_nuScenes_label_name(label_mapping):
353 | with open(label_mapping, 'r') as stream:
354 | nuScenesyaml = yaml.safe_load(stream)
355 | nuScenes_label_name = dict()
356 | for i in sorted(list(nuScenesyaml['learning_map'].keys()))[::-1]:
357 | val_ = nuScenesyaml['learning_map'][i]
358 | nuScenes_label_name[val_] = nuScenesyaml['labels_16'][val_]
359 |
360 | return nuScenes_label_name
361 |
--------------------------------------------------------------------------------
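
fuse_multi_scan and frame_transform_scan above compute the same rigid transform in broadcast form; np.sum(np.expand_dims(h, 2) * pose.T, axis=1) is simply h @ pose.T. An equivalent matrix-form sketch, mapping points from a neighboring scan through world coordinates into the target scan's frame:

import numpy as np

def transform_to_target_frame(points, pose0, pose_k):
    """points: (N, 4) xyz + intensity; pose0, pose_k: 4x4 scan poses."""
    hpoints = np.hstack([points[:, :3], np.ones((len(points), 1))])
    world = hpoints @ pose_k.T                             # scan k -> world
    local = (world[:, :3] - pose0[:3, 3]) @ pose0[:3, :3]  # world -> scan 0
    return np.hstack([local, points[:, 3:]])
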
/dataloader/dataset_semantickitti.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 |
4 | """
5 | SemKITTI dataloader
6 | """
7 | import numpy as np
8 | import torch
9 | import numba as nb
10 | from torch.utils import data
11 | import time
12 | import random
13 |
14 | from utils.util import Bresenham3D
15 | REGISTERED_DATASET_CLASSES = {}
16 |
17 | def from_voxel_to_voxel(max_volume_space,min_volume_space,intervals,pose0,pose):
18 | x_bias = (max_volume_space[0] - min_volume_space[0])/2
19 | max_bound = np.asarray(max_volume_space)
20 | min_bound = np.asarray(min_volume_space)
21 | min_bound[0] -= x_bias
22 | max_bound[0] -= x_bias
23 | max_bound2 = np.asarray(max_volume_space)
24 | min_bound2 = np.asarray(min_volume_space)
25 |
26 | voxel_grid = np.indices((256, 256, 32)).transpose(1, 2, 3, 0)
27 | voxel_grid = voxel_grid.reshape(-1,3)
28 | full_voxel_center=(voxel_grid.astype(np.float32) + 0.5) * intervals + min_bound
29 | full_voxel_center[:,0]+= x_bias
30 | current_vox_center=frame_transform_scan(full_voxel_center,pose0,pose)
31 | current_vox_center=np.concatenate([current_vox_center,voxel_grid],1)
32 | vox_xyz0 = current_vox_center
33 | for ci in range(3):
34 | vox_xyz0[current_vox_center[:, ci] < min_bound2[ci], :] = 1000
35 | vox_xyz0[current_vox_center[:, ci] > max_bound2[ci], :] = 1000
36 | vox_valid_inds = vox_xyz0[:, 0] != 1000
37 | current_vox_center = current_vox_center[vox_valid_inds, :]
38 | current_vox_center[:, 0] -= x_bias # back to the x-shifted grid frame used for voxelization
39 | vox_grid_from=current_vox_center[:,-3:]
40 | vox_grid_to=current_vox_center[:,:3]
41 | vox_grid_to = (np.floor((np.clip(vox_grid_to, min_bound, max_bound) - min_bound) / intervals)).astype(int)
42 | return vox_grid_from,vox_grid_to
43 |
44 | def frame_transform_scan(points, pose0, pose):
45 |
46 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1])))
47 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1)
48 | new_points = new_points[:, :3]
49 | new_coords = new_points - pose0[:3, 3]
50 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1)
51 | new_coords = np.hstack((new_coords, points[:, 3:]))
52 |
53 | return new_coords
54 |
55 | def register_dataset(cls, name=None):
56 | global REGISTERED_DATASET_CLASSES
57 | if name is None:
58 | name = cls.__name__
59 | assert name not in REGISTERED_DATASET_CLASSES, f"class already registered: {REGISTERED_DATASET_CLASSES}"
60 | REGISTERED_DATASET_CLASSES[name] = cls
61 | return cls
62 |
63 |
64 | def get_model_class(name):
65 | global REGISTERED_DATASET_CLASSES
66 | assert name in REGISTERED_DATASET_CLASSES, f"available class: {REGISTERED_DATASET_CLASSES}"
67 | return REGISTERED_DATASET_CLASSES[name]
68 |
69 |
70 | @register_dataset
71 | class voxel_dataset(data.Dataset):
72 | def __init__(self, in_dataset, grid_size, rotate_aug=False, flip_aug=False, ignore_label=255, return_test=False,
73 | fixed_volume_space=False, max_volume_space=[50, 50, 1.5], min_volume_space=[-50, -50, -3]):
74 | 'Initialization'
75 | self.point_cloud_dataset = in_dataset
76 | self.grid_size = np.asarray(grid_size)
77 | self.rotate_aug = rotate_aug
78 | self.ignore_label = ignore_label
79 | self.return_test = return_test
80 | self.flip_aug = flip_aug
81 | self.fixed_volume_space = fixed_volume_space
82 | self.max_volume_space = max_volume_space
83 | self.min_volume_space = min_volume_space
84 |
85 | def __len__(self):
86 | 'Denotes the total number of samples'
87 | return len(self.point_cloud_dataset)
88 |
89 |     def __getitem__(self, index):
90 |         'Generates one sample of data'
91 |         data,prev_data,pose_list,stride_list = self.point_cloud_dataset[index]
92 |
93 |         if len(data) == 4:
94 |             xyz, labels, sig, origin_len = data
95 |             prev_xyz, prev_labels, prev_sig,_ = prev_data
96 |             if len(sig.shape) == 2: sig = np.squeeze(sig)
97 |         else:
98 |             raise Exception('Returned invalid data tuple')
99 |
100 |         origin_len = len(xyz)
101 |         max_bound = np.asarray(self.max_volume_space)
102 |         min_bound = np.asarray(self.min_volume_space)
103 |
104 |         ### Cut point cloud and segmentation label for valid range
105 |         xyz0 = xyz  # alias of xyz: out-of-range rows are overwritten in place below
106 |         for ci in range(3):
107 |             xyz0[xyz[:, ci] < min_bound[ci], :] = 1000
108 |             xyz0[xyz[:, ci] > max_bound[ci], :] = 1000
109 |         valid_inds = xyz0[:, 0] != 1000
110 |         xyz = xyz[valid_inds, :]
111 |         sig = sig[valid_inds]
112 |
113 |         ### post_scan_preprocess ###
114 |
115 |         prev_raws=[]
116 |         prev_sigs=[]
117 |         prev_vox=[]
118 |         prev_velodyne=[]
119 |         prev_trans_sigs=[]
120 |
121 |         if len(pose_list)>1:
122 |             prev_frames_num=len(prev_xyz)
123 |             for f_idx in range(prev_frames_num):
124 |
125 |                 ### Cut point cloud
126 |                 prev_xyz0 = prev_xyz[f_idx]
127 |                 prev_xyz_sin=prev_xyz[f_idx]
128 |                 prev_sig_sin=prev_sig[f_idx]
129 |                 for ci in range(3):
130 |                     prev_xyz0[prev_xyz_sin[:, ci] < min_bound[ci], :] = 1000
131 |                     prev_xyz0[prev_xyz_sin[:, ci] > max_bound[ci], :] = 1000
132 |                 valid_inds = prev_xyz0[:, 0] != 1000
133 |                 prev_xyz_single = prev_xyz_sin[valid_inds, :]
134 |                 prev_sig_single = prev_sig_sin[valid_inds]
135 |                 prev_raws.append(prev_xyz_single)
136 |                 prev_sigs.append(prev_sig_single)
137 |
138 |                 ### Transform prev point cloud
139 |                 transformed_prev_xyz = frame_transform_scan(prev_xyz[f_idx], pose_list[0], pose_list[f_idx+1])
140 |                 transformed_prev_velodyne=frame_transform_scan(np.zeros([1,4]), pose_list[0], pose_list[f_idx+1])
141 |
142 |                 ### Cut prev point cloud
143 |                 prev_xyz0 = transformed_prev_xyz
144 |                 for ci in range(3):
145 |                     prev_xyz0[transformed_prev_xyz[:, ci] < min_bound[ci], :] = 1000
146 |                     prev_xyz0[transformed_prev_xyz[:, ci] > max_bound[ci], :] = 1000
147 |                 prev_valid_inds = prev_xyz0[:, 0] != 1000
148 |                 transformed_prev_xyz = transformed_prev_xyz[prev_valid_inds, :]
149 |                 transformed_prev_sig_single = prev_sig_sin[prev_valid_inds]
150 |                 prev_vox.append(transformed_prev_xyz)
151 |                 prev_trans_sigs.append(transformed_prev_sig_single)
152 |                 prev_velodyne.append(transformed_prev_velodyne)
153 |
154 |         # translate the centre coordinate along the x axis
155 |         x_bias = (self.max_volume_space[0] - self.min_volume_space[0])/2
156 |         min_bound[0] -= x_bias
157 |         max_bound[0] -= x_bias
158 |         xyz[:, 0] -= x_bias
159 |         if len(pose_list)>1:
160 |             for f_idx in range(prev_frames_num):
161 |                 prev_raws_sim=prev_raws[f_idx]
162 |                 prev_vox_sin=prev_vox[f_idx]
163 |                 prev_raws_sim[:, 0]-= x_bias
164 |                 prev_vox_sin[:, 0]-= x_bias
165 |                 prev_velodyne_sin=prev_velodyne[f_idx]
166 |                 prev_velodyne_sin=prev_velodyne_sin[:,:3]
167 |                 prev_velodyne_sin[:,0]-= x_bias
168 |                 prev_vox[f_idx]=prev_vox_sin
169 |                 prev_velodyne[f_idx]=prev_velodyne_sin
170 |                 prev_raws[f_idx]=prev_raws_sim
171 |
172 |         # get grid index
173 |         crop_range = max_bound - min_bound
174 |         cur_grid_size = self.grid_size
175 |         intervals = crop_range / (cur_grid_size - 1)
176 |         if (intervals == 0).any(): print("Zero interval!")
177 |
178 |         grid_ind = (np.floor((np.clip(xyz, min_bound, max_bound) - min_bound) / intervals)).astype(int)
179 |
180 |         voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound
181 |         return_xyz = xyz - voxel_centers
182 |         return_xyz = np.concatenate((return_xyz, xyz), axis=1)
183 |         return_fea = np.concatenate((return_xyz, sig[..., np.newaxis]), axis=1)  # 7: xyz_bias + xyz + intensity
184 |
185 |
186 |         prev_grid_list=[]
187 |         prev_fea_list=[]
188 |         prev_label_list=[]
189 |         prev_transformed_fea_list=[]
190 |         prev_transformed_grid_list=[]
191 |         if len(pose_list)>1:
192 |             for f_idx in range(prev_frames_num):
193 |                 single_prev_xyz=prev_raws[f_idx]
194 |                 single_prev_sig=prev_sigs[f_idx]
195 |                 single_trans_prev_sig=prev_trans_sigs[f_idx]
196 |                 single_prev_vox=prev_vox[f_idx]
197 |
198 |                 prev_grid_ind = (np.floor((np.clip(single_prev_xyz, min_bound, max_bound) - min_bound) / intervals)).astype(int)
199 |                 prev_voxel_centers = (prev_grid_ind.astype(np.float32) + 0.5) * intervals + min_bound
200 |                 return_prev_xyz = single_prev_xyz - prev_voxel_centers
201 |                 return_prev_xyz = np.concatenate((return_prev_xyz, single_prev_xyz), axis=1)
202 |                 return_prev_fea = np.concatenate((return_prev_xyz, single_prev_sig[..., np.newaxis]), axis=1)  # 7: xyz_bias + xyz + intensity
203 |
204 |                 prev_transformed_grid_ind = (np.floor((np.clip(single_prev_vox, min_bound, max_bound) - min_bound) / intervals)).astype(int)
205 |                 prev_transformed_voxel_centers = (prev_transformed_grid_ind.astype(np.float32) + 0.5) * intervals + min_bound
206 |                 return_transformed_prev_xyz = single_prev_vox - prev_transformed_voxel_centers
207 |                 return_transformed_prev_xyz = np.concatenate((return_transformed_prev_xyz, single_prev_vox), axis=1)
208 |                 return_transformed_prev_fea = np.concatenate((return_transformed_prev_xyz, single_trans_prev_sig[..., np.newaxis]), axis=1)  # 7: xyz_bias + xyz + intensity
209 |
210 |                 prev_fea_list.append(return_prev_fea)
211 |                 prev_transformed_fea_list.append(return_transformed_prev_fea)
212 |                 prev_transformed_grid_list.append(prev_transformed_grid_ind)
213 |                 prev_grid_list.append(prev_grid_ind)
214 |                 prev_label_list.append(prev_labels[f_idx])
215 |
216 |         dim_array = np.ones(len(self.grid_size) + 1, int)
217 |         dim_array[0] = -1
218 |         voxel_position = np.indices(self.grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array)
219 |         processed_label = labels  # voxel labels
220 |
221 |         data_tuple = (voxel_position, processed_label)
222 |
223 |         vox_from_list=[]
224 |         vox_to_list=[]
225 |         if len(pose_list)>1:
226 |             for f_idx in range(prev_frames_num):
227 |                 vox_grid_from,vox_grid_to= from_voxel_to_voxel(self.max_volume_space,self.min_volume_space,intervals,pose_list[0],pose_list[f_idx+1])
228 |                 vox_from_list.append(vox_grid_from)
229 |                 vox_to_list.append(vox_grid_to)
230 |
231 |         else:
232 |             vox_grid_to=np.array([[0,0,0]])
233 |             vox_grid_from=np.array([[0,0,0]])
234 |             vox_from_list.append(vox_grid_from)
235 |             vox_to_list.append(vox_grid_to)
236 |
237 |
238 |         if self.return_test:
239 |             data_tuple += (grid_ind, labels, return_fea, index)
240 |         else:
241 |             data_tuple += (grid_ind, labels, return_fea)
242 |
243 |         data_tuple += (origin_len,prev_velodyne,vox_to_list,vox_from_list,prev_grid_list,prev_fea_list,prev_transformed_grid_list,prev_transformed_fea_list,prev_label_list,min_bound,max_bound,intervals,stride_list)
244 |
245 |         return data_tuple
246 |
247 |
248 | # transformation between Cartesian coordinates and polar coordinates
249 | def cart2polar(input_xyz):
250 |     rho = np.sqrt(input_xyz[:, 0] ** 2 + input_xyz[:, 1] ** 2)
251 |     phi = np.arctan2(input_xyz[:, 1], input_xyz[:, 0])
252 |     return np.stack((rho, phi, input_xyz[:, 2]), axis=1)
253 |
254 |
255 | def polar2cat(input_xyz_polar):
256 |     # print(input_xyz_polar.shape)
257 |     x = input_xyz_polar[0] * np.cos(input_xyz_polar[1])
258 |     y = input_xyz_polar[0] * np.sin(input_xyz_polar[1])
259 |     return np.stack((x, y, input_xyz_polar[2]), axis=0)
260 |
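# Illustrative round-trip sketch: note the asymmetric layouts, cart2polar
# returns an (N, 3) array stacked on axis 1 while polar2cat expects the polar
# components along axis 0 (a (3, N) layout), hence the transpose below.
def _example_polar_round_trip():
    xyz = np.array([[3.0, 4.0, 1.0]])
    polar = cart2polar(xyz)    # rho = 5.0, phi = arctan2(4, 3), z = 1.0
    back = polar2cat(polar.T)  # (3, 1) array
    assert np.allclose(back, xyz.T)
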
261 | @nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])', nopython=True, cache=True, parallel=False)
262 | def nb_process_label(processed_label, sorted_label_voxel_pair):
263 |     label_size = 256
264 |     counter = np.zeros((label_size,), dtype=np.uint16)
265 |     counter[sorted_label_voxel_pair[0, 3]] = 1
266 |     cur_sear_ind = sorted_label_voxel_pair[0, :3]
267 |     for i in range(1, sorted_label_voxel_pair.shape[0]):
268 |         cur_ind = sorted_label_voxel_pair[i, :3]
269 |         if not np.all(np.equal(cur_ind, cur_sear_ind)):
270 |             processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
271 |             counter = np.zeros((label_size,), dtype=np.uint16)
272 |             cur_sear_ind = cur_ind
273 |         counter[sorted_label_voxel_pair[i, 3]] += 1
274 |     processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
275 |     return processed_label
276 |
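# Illustrative sketch: nb_process_label majority-votes point labels into voxels.
# Each row of the (pre-sorted) pair array is (ix, iy, iz, label); the most
# frequent label within each voxel index wins.
def _example_majority_vote():
    grid = np.zeros((2, 1, 1), dtype=np.uint8)
    pairs = np.array([[0, 0, 0, 3],
                      [0, 0, 0, 3],
                      [0, 0, 0, 5],
                      [1, 0, 0, 7]], dtype=np.int64)
    grid = nb_process_label(grid, pairs)
    assert grid[0, 0, 0] == 3 and grid[1, 0, 0] == 7
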
277 |
278 |
279 | def collate_fn_BEV_ms_tta(data):
280 |     label2stack = np.stack([d[1] for d in data]).astype(int)
281 |     grid_ind_stack = [d[2] for d in data]
282 |     xyz = [d[4] for d in data]
283 |     index = [d[5] for d in data]
284 |     prev_velodyne=[d[7] for d in data]
285 |     vox_grid_to=[d[8] for d in data]
286 |     vox_grid_from=[d[9] for d in data]
287 |     prev_grid_ind=[d[10] for d in data]
288 |     prev_feat=[d[11] for d in data]
289 |     prev_trans_ind=[d[12] for d in data]
290 |     prev_trans_feat=[d[13] for d in data]
291 |     min_bound=[d[15] for d in data]
292 |     max_bound=[d[16] for d in data]
293 |     interval=[d[17] for d in data]
294 |     strides=[d[18] for d in data]
295 |     current_frame={'grid_ind':grid_ind_stack, 'pt_feat': xyz, 'index': index, 'gt':torch.from_numpy(label2stack),'min_bound':min_bound,'max_bound':max_bound,'interval':interval,'stride':strides[0]}
296 |     prev_frame={'vox_grid_to':vox_grid_to[0], 'vox_grid_from': vox_grid_from[0], 'grid_ind': prev_grid_ind[0], 'pt_feat':prev_feat[0],'trans_grid_ind':prev_trans_ind[0],'trans_pt_feat':prev_trans_feat[0],'lidar_pose':prev_velodyne[0]}
297 |
298 |     return current_frame,prev_frame
299 |
300 |
301 | def collate_fn_BEV(data):
302 |     data2stack = np.stack([d[0] for d in data]).astype(np.float32)
303 |     label2stack = np.stack([d[1] for d in data]).astype(int)
304 |     grid_ind_stack = [d[2] for d in data]
305 |     point_label = [d[3] for d in data]
306 |     xyz = [d[4] for d in data]
307 |     return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz
308 |
309 | def collate_fn_BEV_tta(data):
310 |     voxel_label = []
311 |     for da1 in data:
312 |         for da2 in da1:
313 |             voxel_label.append(da2[1])
314 |     grid_ind_stack = []
315 |     for da1 in data:
316 |         for da2 in da1:
317 |             grid_ind_stack.append(da2[2])
318 |     point_label = []
319 |     for da1 in data:
320 |         for da2 in da1:
321 |             point_label.append(da2[3])
322 |     xyz = []
323 |     for da1 in data:
324 |         for da2 in da1:
325 |             xyz.append(da2[4])
326 |     index = []
327 |     for da1 in data:
328 |         for da2 in da1:
329 |             index.append(da2[5])
330 |     return xyz, xyz, grid_ind_stack, point_label, xyz, index  # xyz repeated as a placeholder for unused slots
331 |
332 | def collate_fn_BEV_ms(data):
333 |     data2stack = np.stack([d[0] for d in data]).astype(np.float32)
334 |     label2stack = np.stack([d[1] for d in data]).astype(int)
335 |     grid_ind_stack = [d[2] for d in data]
336 |     point_label = [d[3] for d in data]
337 |     xyz = [d[4] for d in data]
338 |     index = [d[5] for d in data]
339 |     origin_len = [d[6] for d in data]
340 |     return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz, index, origin_len
341 |
342 |
--------------------------------------------------------------------------------
/dataloader/dataset_semantickitti_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge
3 |
4 | """
5 | SemKITTI dataloader
6 | """
7 | import numpy as np
8 | import torch
9 | import numba as nb
10 | from torch.utils import data
11 | import time
12 | import random
13 |
14 | from utils.util import Bresenham3D
15 | REGISTERED_DATASET_CLASSES = {}
16 |
17 | def from_voxel_to_voxel(max_volume_space,min_volume_space,intervals,pose0,pose):
18 |     x_bias = (max_volume_space[0] - min_volume_space[0])/2
19 |     max_bound = np.asarray(max_volume_space)
20 |     min_bound = np.asarray(min_volume_space)
21 |     min_bound[0] -= x_bias
22 |     max_bound[0] -= x_bias
23 |     max_bound2 = np.asarray(max_volume_space)
24 |     min_bound2 = np.asarray(min_volume_space)
25 |
26 |     voxel_grid = np.indices((256, 256, 32)).transpose(1, 2, 3, 0)
27 |     voxel_grid = voxel_grid.reshape(-1,3)
28 |     full_voxel_center=(voxel_grid.astype(np.float32) + 0.5) * intervals + min_bound
29 |     full_voxel_center[:,0]+= x_bias
30 |     current_vox_center=frame_transform_scan(full_voxel_center,pose0,pose)
31 |     current_vox_center=np.concatenate([current_vox_center,voxel_grid],1)
32 |     vox_xyz0 = current_vox_center
33 |     for ci in range(3):
34 |         vox_xyz0[current_vox_center[:, ci] < min_bound2[ci], :] = 1000
35 |         vox_xyz0[current_vox_center[:, ci] > max_bound2[ci], :] = 1000
36 |     vox_valid_inds = vox_xyz0[:, 0] != 1000
37 |     current_vox_center = current_vox_center[vox_valid_inds, :]
38 |     current_vox_center[:, 0] -= x_bias  # shift x back into the bias-shifted grid frame
39 |     vox_grid_from=current_vox_center[:,-3:]
40 |     vox_grid_to=current_vox_center[:,:3]
41 |     vox_grid_to = (np.floor((np.clip(vox_grid_to, min_bound, max_bound) - min_bound) / intervals)).astype(int)
42 |     return vox_grid_from,vox_grid_to
43 |
44 | def frame_transform_scan(points, pose0, pose):
45 |
46 |     hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1])))
47 |     new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1)
48 |     new_points = new_points[:, :3]
49 |     new_coords = new_points - pose0[:3, 3]
50 |     new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1)
51 |     new_coords = np.hstack((new_coords, points[:, 3:]))
52 |
53 |     return new_coords
54 |
55 | def register_dataset(cls, name=None):
56 |     global REGISTERED_DATASET_CLASSES
57 |     if name is None:
58 |         name = cls.__name__
59 |     assert name not in REGISTERED_DATASET_CLASSES, f"class already registered: {REGISTERED_DATASET_CLASSES}"
60 |     REGISTERED_DATASET_CLASSES[name] = cls
61 |     return cls
62 |
63 |
64 | def get_model_class(name):
65 |     global REGISTERED_DATASET_CLASSES
66 |     assert name in REGISTERED_DATASET_CLASSES, f"available classes: {REGISTERED_DATASET_CLASSES}"
67 |     return REGISTERED_DATASET_CLASSES[name]
68 |
69 |
70 | @register_dataset
71 | class voxel_dataset(data.Dataset):
72 |     def __init__(self, in_dataset, grid_size, rotate_aug=False, flip_aug=False, ignore_label=255, return_test=False,
73 |                  fixed_volume_space=False, max_volume_space=[50, 50, 1.5], min_volume_space=[-50, -50, -3]):
74 |         'Initialization'
75 |         self.point_cloud_dataset = in_dataset
76 |         self.grid_size = np.asarray(grid_size)
77 |         self.rotate_aug = rotate_aug
78 |         self.ignore_label = ignore_label
79 |         self.return_test = return_test
80 |         self.flip_aug = flip_aug
81 |         self.fixed_volume_space = fixed_volume_space
82 |         self.max_volume_space = max_volume_space
83 |         self.min_volume_space = min_volume_space
84 |
85 |     def __len__(self):
86 |         'Denotes the total number of samples'
87 |         return len(self.point_cloud_dataset)
88 |
89 |     def __getitem__(self, index):
90 |         'Generates one sample of data'
91 |         data,prev_data,pose_list,stride_list = self.point_cloud_dataset[index]
92 |
93 |         if len(data) == 4:
94 |             xyz, labels, sig, origin_len = data
95 |             prev_xyz, prev_labels, prev_sig,_ = prev_data
96 |             if len(sig.shape) == 2: sig = np.squeeze(sig)
97 |         else:
98 |             raise Exception('Returned invalid data tuple')
99 |
100 |         origin_len = len(xyz)
101 |         max_bound = np.asarray(self.max_volume_space)
102 |         min_bound = np.asarray(self.min_volume_space)
103 |
104 |         ### Cut point cloud and segmentation label for valid range
105 |         xyz0 = xyz  # alias of xyz: out-of-range rows are overwritten in place below
106 |         for ci in range(3):
107 |             xyz0[xyz[:, ci] < min_bound[ci], :] = 1000
108 |             xyz0[xyz[:, ci] > max_bound[ci], :] = 1000
109 |         valid_inds = xyz0[:, 0] != 1000
110 |         xyz = xyz[valid_inds, :]
111 |         sig = sig[valid_inds]
112 |
113 |
114 |         ### post_scan_preprocess ###
115 |
116 |         prev_raws=[]
117 |         prev_sigs=[]
118 |         prev_vox=[]
119 |         prev_velodyne=[]
120 |         prev_trans_sigs=[]
121 |
122 |         if len(pose_list)>1:
123 |             prev_frames_num=len(prev_xyz)
124 |             for f_idx in range(prev_frames_num):
125 |
126 |                 ### Cut point cloud
127 |                 prev_xyz0 = prev_xyz[f_idx]
128 |                 prev_xyz_sin=prev_xyz[f_idx]
129 |                 prev_sig_sin=prev_sig[f_idx]
130 |                 for ci in range(3):
131 |                     prev_xyz0[prev_xyz_sin[:, ci] < min_bound[ci], :] = 1000
132 |                     prev_xyz0[prev_xyz_sin[:, ci] > max_bound[ci], :] = 1000
133 |                 valid_inds = prev_xyz0[:, 0] != 1000
134 |                 prev_xyz_single = prev_xyz_sin[valid_inds, :]
135 |                 prev_sig_single = prev_sig_sin[valid_inds]
136 |                 prev_raws.append(prev_xyz_single)
137 |                 prev_sigs.append(prev_sig_single)
138 |
139 |                 ### Transform prev point cloud
140 |                 transformed_prev_xyz = frame_transform_scan(prev_xyz[f_idx], pose_list[0], pose_list[f_idx+1])
141 |                 transformed_prev_velodyne=frame_transform_scan(np.zeros([1,4]), pose_list[0], pose_list[f_idx+1])
142 |
143 |                 ### Cut prev point cloud
144 |                 prev_xyz0 = transformed_prev_xyz
145 |                 for ci in range(3):
146 |                     prev_xyz0[transformed_prev_xyz[:, ci] < min_bound[ci], :] = 1000
147 |                     prev_xyz0[transformed_prev_xyz[:, ci] > max_bound[ci], :] = 1000
148 |                 prev_valid_inds = prev_xyz0[:, 0] != 1000
149 |                 transformed_prev_xyz = transformed_prev_xyz[prev_valid_inds, :]
150 |                 transformed_prev_sig_single = prev_sig_sin[prev_valid_inds]
151 |                 prev_vox.append(transformed_prev_xyz)
152 |                 prev_trans_sigs.append(transformed_prev_sig_single)
153 |                 prev_velodyne.append(transformed_prev_velodyne)
154 |
155 |         # translate the centre coordinate along the x axis
156 |         x_bias = (self.max_volume_space[0] - self.min_volume_space[0])/2
157 |         min_bound[0] -= x_bias
158 |         max_bound[0] -= x_bias
159 |         xyz[:, 0] -= x_bias
160 |         if len(pose_list)>1:
161 |             for f_idx in range(prev_frames_num):
162 |                 prev_raws_sim=prev_raws[f_idx]
163 |                 prev_vox_sin=prev_vox[f_idx]
164 |                 prev_raws_sim[:, 0]-= x_bias
165 |                 prev_vox_sin[:, 0]-= x_bias
166 |                 prev_velodyne_sin=prev_velodyne[f_idx]
167 |                 prev_velodyne_sin=prev_velodyne_sin[:,:3]
168 |                 prev_velodyne_sin[:,0]-= x_bias
169 |                 prev_vox[f_idx]=prev_vox_sin
170 |                 prev_velodyne[f_idx]=prev_velodyne_sin
171 |                 prev_raws[f_idx]=prev_raws_sim
172 |
173 |         # get grid index
174 |         crop_range = max_bound - min_bound
175 |         cur_grid_size = self.grid_size
176 |         intervals = crop_range / (cur_grid_size - 1)
177 |         if (intervals == 0).any(): print("Zero interval!")
178 |
179 |         grid_ind = (np.floor((np.clip(xyz, min_bound, max_bound) - min_bound) / intervals)).astype(int)
180 |
181 |         voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound
182 |         return_xyz = xyz - voxel_centers
183 |         return_xyz = np.concatenate((return_xyz, xyz), axis=1)
184 |         return_fea = np.concatenate((return_xyz, sig[..., np.newaxis]), axis=1)  # 7: xyz_bias + xyz + intensity
185 |
186 |
187 |         prev_grid_list=[]
188 |         prev_fea_list=[]
189 |         prev_label_list=[]
190 |         prev_transformed_fea_list=[]
191 |         prev_transformed_grid_list=[]
192 |         if len(pose_list)>1:
193 |             for f_idx in range(prev_frames_num):
194 |                 single_prev_xyz=prev_raws[f_idx]
195 |                 single_prev_sig=prev_sigs[f_idx]
196 |                 single_trans_prev_sig=prev_trans_sigs[f_idx]
197 |                 single_prev_vox=prev_vox[f_idx]
198 |
199 |                 prev_grid_ind = (np.floor((np.clip(single_prev_xyz, min_bound, max_bound) - min_bound) / intervals)).astype(int)
200 |                 prev_voxel_centers = (prev_grid_ind.astype(np.float32) + 0.5) * intervals + min_bound
201 |                 return_prev_xyz = single_prev_xyz - prev_voxel_centers
202 |                 return_prev_xyz = np.concatenate((return_prev_xyz, single_prev_xyz), axis=1)
203 |                 return_prev_fea = np.concatenate((return_prev_xyz, single_prev_sig[..., np.newaxis]), axis=1)  # 7: xyz_bias + xyz + intensity
204 |
205 |                 prev_transformed_grid_ind = (np.floor((np.clip(single_prev_vox, min_bound, max_bound) - min_bound) / intervals)).astype(int)
206 |                 prev_transformed_voxel_centers = (prev_transformed_grid_ind.astype(np.float32) + 0.5) * intervals + min_bound
207 |                 return_transformed_prev_xyz = single_prev_vox - prev_transformed_voxel_centers
208 |                 return_transformed_prev_xyz = np.concatenate((return_transformed_prev_xyz, single_prev_vox), axis=1)
209 |                 return_transformed_prev_fea = np.concatenate((return_transformed_prev_xyz, single_trans_prev_sig[..., np.newaxis]), axis=1)  # 7: xyz_bias + xyz + intensity
210 |
211 |                 prev_fea_list.append(return_prev_fea)
212 |                 prev_transformed_fea_list.append(return_transformed_prev_fea)
213 |                 prev_transformed_grid_list.append(prev_transformed_grid_ind)
214 |                 prev_grid_list.append(prev_grid_ind)
215 |                 prev_label_list.append(prev_labels[f_idx])
216 |
217 |         dim_array = np.ones(len(self.grid_size) + 1, int)
218 |         dim_array[0] = -1
219 |         voxel_position = np.indices(self.grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array)
220 |         processed_label = labels  # voxel labels
221 |
222 |         data_tuple = (voxel_position, processed_label)
223 |
224 |         vox_from_list=[]
225 |         vox_to_list=[]
226 |         if len(pose_list)>1:
227 |             for f_idx in range(prev_frames_num):
228 |                 vox_grid_from,vox_grid_to= from_voxel_to_voxel(self.max_volume_space,self.min_volume_space,intervals,pose_list[0],pose_list[f_idx+1])
229 |                 vox_from_list.append(vox_grid_from)
230 |                 vox_to_list.append(vox_grid_to)
231 |
232 |         else:
233 |             vox_grid_to=np.array([[0,0,0]])
234 |             vox_grid_from=np.array([[0,0,0]])
235 |             vox_from_list.append(vox_grid_from)
236 |             vox_to_list.append(vox_grid_to)
237 |
238 |
239 |         if self.return_test:
240 |             data_tuple += (grid_ind, labels, return_fea, index)
241 |         else:
242 |             data_tuple += (grid_ind, labels, return_fea)
243 |
244 |         data_tuple += (origin_len,prev_velodyne,vox_to_list,vox_from_list,prev_grid_list,prev_fea_list,prev_transformed_grid_list,prev_transformed_fea_list,prev_label_list,min_bound,max_bound,intervals,stride_list)
245 |
246 |         return data_tuple
247 |
248 |
249 | # transformation between Cartesian coordinates and polar coordinates
250 | def cart2polar(input_xyz):
251 |     rho = np.sqrt(input_xyz[:, 0] ** 2 + input_xyz[:, 1] ** 2)
252 |     phi = np.arctan2(input_xyz[:, 1], input_xyz[:, 0])
253 |     return np.stack((rho, phi, input_xyz[:, 2]), axis=1)
254 |
255 |
256 | def polar2cat(input_xyz_polar):
257 |     # print(input_xyz_polar.shape)
258 |     x = input_xyz_polar[0] * np.cos(input_xyz_polar[1])
259 |     y = input_xyz_polar[0] * np.sin(input_xyz_polar[1])
260 |     return np.stack((x, y, input_xyz_polar[2]), axis=0)
261 |
262 | @nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])', nopython=True, cache=True, parallel=False)
263 | def nb_process_label(processed_label, sorted_label_voxel_pair):
264 |     label_size = 256
265 |     counter = np.zeros((label_size,), dtype=np.uint16)
266 |     counter[sorted_label_voxel_pair[0, 3]] = 1
267 |     cur_sear_ind = sorted_label_voxel_pair[0, :3]
268 |     for i in range(1, sorted_label_voxel_pair.shape[0]):
269 |         cur_ind = sorted_label_voxel_pair[i, :3]
270 |         if not np.all(np.equal(cur_ind, cur_sear_ind)):
271 |             processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
272 |             counter = np.zeros((label_size,), dtype=np.uint16)
273 |             cur_sear_ind = cur_ind
274 |         counter[sorted_label_voxel_pair[i, 3]] += 1
275 |     processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter)
276 |     return processed_label
277 |
278 |
279 |
280 | def collate_fn_BEV_ms_tta(data):
281 |     label2stack = np.stack([d[1] for d in data]).astype(int)
282 |     grid_ind_stack = [d[2] for d in data]
283 |     xyz = [d[4] for d in data]
284 |     index = [d[5] for d in data]
285 |     prev_velodyne=[d[7] for d in data]
286 |     vox_grid_to=[d[8] for d in data]
287 |     vox_grid_from=[d[9] for d in data]
288 |     prev_grid_ind=[d[10] for d in data]
289 |     prev_feat=[d[11] for d in data]
290 |     prev_trans_ind=[d[12] for d in data]
291 |     prev_trans_feat=[d[13] for d in data]
292 |     min_bound=[d[15] for d in data]
293 |     max_bound=[d[16] for d in data]
294 |     interval=[d[17] for d in data]
295 |     strides=[d[18] for d in data]
296 |     current_frame={'grid_ind':grid_ind_stack, 'pt_feat': xyz, 'index': index, 'gt':torch.from_numpy(label2stack),'min_bound':min_bound,'max_bound':max_bound,'interval':interval,'stride':strides[0]}
297 |     prev_frame={'vox_grid_to':vox_grid_to[0], 'vox_grid_from': vox_grid_from[0], 'grid_ind': prev_grid_ind[0], 'pt_feat':prev_feat[0],'trans_grid_ind':prev_trans_ind[0],'trans_pt_feat':prev_trans_feat[0],'lidar_pose':prev_velodyne[0]}
298 |
299 |     return current_frame,prev_frame
300 |
301 |
302 | def collate_fn_BEV(data):
303 |     data2stack = np.stack([d[0] for d in data]).astype(np.float32)
304 |     label2stack = np.stack([d[1] for d in data]).astype(int)
305 |     grid_ind_stack = [d[2] for d in data]
306 |     point_label = [d[3] for d in data]
307 |     xyz = [d[4] for d in data]
308 |     return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz
309 |
310 | def collate_fn_BEV_tta(data):
311 |     voxel_label = []
312 |     for da1 in data:
313 |         for da2 in da1:
314 |             voxel_label.append(da2[1])
315 |     grid_ind_stack = []
316 |     for da1 in data:
317 |         for da2 in da1:
318 |             grid_ind_stack.append(da2[2])
319 |     point_label = []
320 |     for da1 in data:
321 |         for da2 in da1:
322 |             point_label.append(da2[3])
323 |     xyz = []
324 |     for da1 in data:
325 |         for da2 in da1:
326 |             xyz.append(da2[4])
327 |     index = []
328 |     for da1 in data:
329 |         for da2 in da1:
330 |             index.append(da2[5])
331 |     return xyz, xyz, grid_ind_stack, point_label, xyz, index  # xyz repeated as a placeholder for unused slots
332 |
333 | def collate_fn_BEV_ms(data):
334 |     data2stack = np.stack([d[0] for d in data]).astype(np.float32)
335 |     label2stack = np.stack([d[1] for d in data]).astype(int)
336 |     grid_ind_stack = [d[2] for d in data]
337 |     point_label = [d[3] for d in data]
338 |     xyz = [d[4] for d in data]
339 |     index = [d[5] for d in data]
340 |     origin_len = [d[6] for d in data]
341 |     return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz, index, origin_len
342 |
343 |
--------------------------------------------------------------------------------
/network/segmentator_3d_asymm_spconv_unlock.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | # author: Xinge, Xzy
3 | # @file: segmentator_3d_asymm_spconv.py
4 |
5 | import numpy as np
6 | import spconv
7 | import torch
8 | from torch import nn
9 |
10 |
11 | def extract_nonzero_features(x):
12 | device = x.device
13 | nonzero_index = torch.sum(torch.abs(x), dim=1).nonzero()
14 | coords = nonzero_index.type(torch.int32).to(device)
15 | channels = int(x.shape[1])
16 | features = x.permute(0, 2, 3, 4, 1).reshape(-1, channels)
17 | features = features[torch.sum(torch.abs(features), dim=1).nonzero(), :]
18 | features = features.squeeze(1).to(device)
19 | coords, _, _ = torch.unique(coords, return_inverse=True, return_counts=True, dim=0)
20 | return coords, features
21 |
22 | def no_extraction(x):
23 | device = x.device
24 | total_index=torch.ones_like(torch.sum(torch.abs(x), dim=1)).nonzero()
25 | coords = total_index.type(torch.int32).to(device)
26 | channels = int(x.shape[1])
27 | features = x.permute(0, 2, 3, 4, 1).reshape(-1, channels)
28 | total_index2=torch.ones_like(torch.sum(torch.abs(features), dim=1)).nonzero()
29 | features = features[total_index2, :]
30 | features = features.squeeze(1).to(device)
31 | coords, _, _ = torch.unique(coords, return_inverse=True, return_counts=True, dim=0)
32 | return coords, features
33 |
34 | # def manual_extraction(x, idx_for_extract):
35 | # device = x.device
36 | # nonzero_index = torch.sum(torch.abs(x), dim=1).nonzero()
37 | # manually_added_index = (idx_for_extract==1).nonzero()
38 | # total_index = torch.cat((nonzero_index, manually_added_index), dim=0)
39 | # total_index, _, _ = torch.unique(total_index, return_inverse=True, return_counts=True, dim=0)
40 | # coords = total_index.type(torch.int32)
41 |
42 | # channels = int(x.shape[1])
43 | # features = x.permute(0, 2, 3, 4, 1).reshape(-1, channels)
44 | # nonzero_index2 = torch.sum(torch.abs(features), dim=1).nonzero()
45 | # manually_added_index2 = (idx_for_extract.flatten()==1).nonzero()
46 | # total_index2 = torch.cat((nonzero_index2, manually_added_index2), dim=0)
47 | # total_index2, _, _ = torch.unique(total_index2, return_inverse=True, return_counts=True, dim=0)
48 | # features = features[total_index2, :]
49 | # features = features.squeeze(1).to(device)
50 | # # coords, _, _ = torch.unique(coords, return_inverse=True, return_counts=True, dim=0)
51 | # return coords, features
52 |
53 | def manual_extraction(x, proposure):
54 | device = x.device
55 | #proposure.shape -> (n,3)
56 | x_proposure=x.clone().detach()
57 | # import pdb;pdb.set_trace()
58 |     if proposure is not None:
59 | x_proposure[0,0,proposure[:,0],proposure[:,1],proposure[:,2]]=1
60 | total_index=torch.sum(torch.abs(x_proposure), dim=1).nonzero()
61 |
62 | else:
63 | total_index=torch.ones_like(torch.sum(torch.abs(x), dim=1)).nonzero()
64 |
65 | coords = total_index.type(torch.int32).to(device)
66 | channels = int(x.shape[1])
67 | features = x.permute(0, 2, 3, 4, 1).reshape(-1, channels)
68 | features_proposure = x_proposure.permute(0, 2, 3, 4, 1).reshape(-1, channels)
69 |     if proposure is None:
70 | total_index=torch.ones_like(torch.sum(torch.abs(features), dim=1)).nonzero()
71 | else:
72 | total_index=torch.sum(torch.abs(features_proposure), dim=1).nonzero()
73 | features = features[total_index, :]
74 | features = features.squeeze(1).to(device)
75 | coords, _, _ = torch.unique(coords, return_inverse=True, return_counts=True, dim=0)
76 | return coords, features
77 |
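# Illustrative sketch: the three helpers above control which voxels survive the
# dense-to-sparse conversion. 'all' keeps only nonzero voxels, None keeps every
# voxel, and a proposal index additionally revives the proposed (zero) locations.
def _example_extraction_modes():
    x = torch.zeros(1, 2, 4, 4, 4)
    x[0, :, 1, 2, 3] = 1.0
    coords, _ = extract_nonzero_features(x)   # 1 active voxel
    coords_all, _ = no_extraction(x)          # all 4*4*4 = 64 voxels
    coords_p, _ = manual_extraction(x, torch.tensor([[0, 0, 0]]))
    assert len(coords) == 1 and len(coords_all) == 64 and len(coords_p) == 2
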
78 | class Asymm_3d_spconv(nn.Module):
79 | def __init__(self,
80 | output_shape,
81 | use_norm=True,
82 | num_input_features=128,
83 | nclasses=20, n_height=32, strict=False, init_size=16):
84 | super(Asymm_3d_spconv, self).__init__()
85 | self.nclasses = nclasses
86 | self.nheight = n_height
87 | self.strict = False
88 |
89 | sparse_shape = np.array(output_shape)
90 | self.sparse_shape = sparse_shape
91 |
92 | ### Completion sub-network
93 | mybias = False # False
94 | chs = [init_size, init_size*1, init_size*1, init_size*1]
95 | self.a_conv1 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 3, 1, padding=1, bias=mybias), nn.ReLU())
96 | self.a_conv2 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 3, 1, padding=1, bias=mybias), nn.ReLU())
97 | self.a_conv3 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 5, 1, padding=2, bias=mybias), nn.ReLU())
98 | self.a_conv4 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 7, 1, padding=3, bias=mybias), nn.ReLU())
99 | self.a_conv5 = nn.Sequential(nn.Conv3d(chs[1]*3, chs[1], 3, 1, padding=1, bias=mybias), nn.ReLU())
100 | self.a_conv6 = nn.Sequential(nn.Conv3d(chs[1]*3, chs[1], 5, 1, padding=2, bias=mybias), nn.ReLU())
101 | self.a_conv7 = nn.Sequential(nn.Conv3d(chs[1]*3, chs[1], 7, 1, padding=3, bias=mybias), nn.ReLU())
102 | self.ch_conv1 = nn.Sequential(nn.Conv3d(chs[1]*7, chs[0], kernel_size=1, stride=1, bias=mybias), nn.ReLU())
103 | self.res_1 = nn.Sequential(nn.Conv3d(chs[0], chs[0], 3, 1, padding=1, bias=mybias), nn.ReLU())
104 | self.res_2 = nn.Sequential(nn.Conv3d(chs[0], chs[0], 5, 1, padding=2, bias=mybias), nn.ReLU())
105 | self.res_3 = nn.Sequential(nn.Conv3d(chs[0], chs[0], 7, 1, padding=3, bias=mybias), nn.ReLU())
106 |
107 | ### Segmentation sub-network
108 | self.downCntx = ResContextBlock(num_input_features, init_size, indice_key="pre")
109 | self.resBlock2 = ResBlock(init_size, 2 * init_size, 0.2, height_pooling=True, indice_key="down2")
110 | self.resBlock3 = ResBlock(2 * init_size, 4 * init_size, 0.2, height_pooling=True, indice_key="down3")
111 | self.resBlock4 = ResBlock(4 * init_size, 8 * init_size, 0.2, pooling=True, height_pooling=False,
112 | indice_key="down4")
113 | self.resBlock5 = ResBlock(8 * init_size, 16 * init_size, 0.2, pooling=True, height_pooling=False,
114 | indice_key="down5")
115 |
116 | self.upBlock0 = UpBlock(16 * init_size, 16 * init_size, indice_key="up0", up_key="down5")
117 | self.upBlock1 = UpBlock(16 * init_size, 8 * init_size, indice_key="up1", up_key="down4")
118 | self.upBlock2 = UpBlock(8 * init_size, 4 * init_size, indice_key="up2", up_key="down3")
119 | self.upBlock3 = UpBlock(4 * init_size, 2 * init_size, indice_key="up3", up_key="down2")
120 |
121 | self.ReconNet = ReconBlock(2 * init_size, 2 * init_size, indice_key="recon")
122 |
123 | self.logits = spconv.SubMConv3d(4 * init_size, nclasses, indice_key="logit", kernel_size=3, stride=1, padding=1,
124 | bias=True)
125 |
126 |
127 | def forward(self, voxel_features, coors, batch_size, extraction='all'):
128 |
129 | coors = coors.int()
130 | x_sparse = spconv.SparseConvTensor(voxel_features, coors, self.sparse_shape, batch_size)
131 | x = x_sparse
132 | # import pdb;pdb.set_trace()
133 |         # Sparse to dense
134 | x_dense = x_sparse.dense()
135 |
136 | ### Completion sub-network by dense convolution
137 | x1 = self.a_conv1(x_dense)
138 | x2 = self.a_conv2(x1)
139 | x3 = self.a_conv3(x1)
140 | x4 = self.a_conv4(x1)
141 | t1 = torch.cat((x2, x3, x4), 1)
142 | x5 = self.a_conv5(t1)
143 | x6 = self.a_conv6(t1)
144 | x7 = self.a_conv7(t1)
145 | x = torch.cat((x1, x2, x3, x4, x5, x6, x7), 1)
146 | y0 = self.ch_conv1(x)
147 | y1 = self.res_1(x_dense)
148 | y2 = self.res_2(x_dense)
149 | y3 = self.res_3(x_dense)
150 | x = x_dense + y0 + y1 + y2 + y3
151 |
152 | # Dense to sparse
153 |         if extraction == 'all':
154 | coord, features = extract_nonzero_features(x)
155 | elif extraction is None:
156 | coord, features = no_extraction(x)
157 | else:
158 | coord, features = manual_extraction(x, extraction)
159 | x = spconv.SparseConvTensor(features, coord.int(), self.sparse_shape, batch_size) # voxel features
160 |
161 | ### Segmentation sub-network by sparse convolution
162 | x = self.downCntx(x)
163 | down1c, down1b = self.resBlock2(x)
164 | down2c, down2b = self.resBlock3(down1c)
165 | down3c, down3b = self.resBlock4(down2c)
166 | down4c, down4b = self.resBlock5(down3c)
167 |
168 | up4e = self.upBlock0(down4c, down4b)
169 | up3e = self.upBlock1(up4e, down3b)
170 | up2e = self.upBlock2(up3e, down2b)
171 | up1e = self.upBlock3(up2e, down1b)
172 |
173 | up0e = self.ReconNet(up1e)
174 |
175 | up0e.features = torch.cat((up0e.features, up1e.features), 1)
176 |
177 | logits = self.logits(up0e)
178 | y = logits.dense()
179 |
180 | return y
181 |
182 | @staticmethod
183 | def _joining(encoder_features, x, concat):
184 | if concat:
185 | return torch.cat((encoder_features, x), dim=1)
186 | else:
187 | return encoder_features + x
188 |
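# Illustrative forward-pass sketch (assumes spconv 1.x, where SparseConvTensor
# features are assignable as in the blocks below, and assumes num_input_features
# equals init_size so the completion convs line up with the raw dense features):
def _example_forward():
    net = Asymm_3d_spconv(output_shape=[256, 256, 32], num_input_features=16,
                          nclasses=20, init_size=16).cuda()
    dense = (torch.rand(1, 16, 256, 256, 32).cuda() > 0.999).float()
    coords, feats = extract_nonzero_features(dense)  # (N, 4) indices, (N, 16) features
    logits = net(feats, coords, batch_size=1)        # dense (1, 20, 256, 256, 32)
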
189 |
190 |
191 | def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
192 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=3, stride=stride,
193 | padding=1, bias=False, indice_key=indice_key)
194 |
195 |
196 | def conv1x3(in_planes, out_planes, stride=1, indice_key=None):
197 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,
198 | padding=(0, 1, 1), bias=False, indice_key=indice_key)
199 |
200 |
201 | def conv1x1x3(in_planes, out_planes, stride=1, indice_key=None):
202 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride,
203 | padding=(0, 0, 1), bias=False, indice_key=indice_key)
204 |
205 |
206 | def conv1x3x1(in_planes, out_planes, stride=1, indice_key=None):
207 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride,
208 | padding=(0, 1, 0), bias=False, indice_key=indice_key)
209 |
210 |
211 | def conv3x1x1(in_planes, out_planes, stride=1, indice_key=None):
212 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride,
213 | padding=(1, 0, 0), bias=False, indice_key=indice_key)
214 |
215 |
216 | def conv3x1(in_planes, out_planes, stride=1, indice_key=None):
217 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride,
218 | padding=(1, 0, 1), bias=False, indice_key=indice_key)
219 |
220 |
221 | def conv1x1(in_planes, out_planes, stride=1, indice_key=None):
222 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=1, stride=stride,
223 | padding=1, bias=False, indice_key=indice_key)
224 |
225 |
226 | class ResContextBlock(nn.Module):
227 | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, indice_key=None):
228 | super(ResContextBlock, self).__init__()
229 | self.conv1 = conv1x3(in_filters, out_filters, indice_key=indice_key + "bef")
230 | self.bn0 = nn.BatchNorm1d(out_filters)
231 | self.act1 = nn.LeakyReLU()
232 |
233 | self.conv1_2 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef")
234 | self.bn0_2 = nn.BatchNorm1d(out_filters)
235 | self.act1_2 = nn.LeakyReLU()
236 |
237 | self.conv2 = conv3x1(in_filters, out_filters, indice_key=indice_key + "bef")
238 | self.act2 = nn.LeakyReLU()
239 | self.bn1 = nn.BatchNorm1d(out_filters)
240 |
241 | self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef")
242 | self.act3 = nn.LeakyReLU()
243 | self.bn2 = nn.BatchNorm1d(out_filters)
244 |
245 | self.weight_initialization()
246 |
247 | def weight_initialization(self):
248 | for m in self.modules():
249 | if isinstance(m, nn.BatchNorm1d):
250 | nn.init.constant_(m.weight, 1)
251 | nn.init.constant_(m.bias, 0)
252 |
253 | def forward(self, x):
254 | shortcut = self.conv1(x)
255 | shortcut.features = self.act1(shortcut.features)
256 | shortcut.features = self.bn0(shortcut.features)
257 |
258 | shortcut = self.conv1_2(shortcut)
259 | shortcut.features = self.act1_2(shortcut.features)
260 | shortcut.features = self.bn0_2(shortcut.features)
261 |
262 | resA = self.conv2(x)
263 | resA.features = self.act2(resA.features)
264 | resA.features = self.bn1(resA.features)
265 |
266 | resA = self.conv3(resA)
267 | resA.features = self.act3(resA.features)
268 | resA.features = self.bn2(resA.features)
269 | resA.features = resA.features + shortcut.features
270 |
271 | return resA
272 |
273 |
274 | class ResBlock(nn.Module):
275 | def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3, 3), stride=1,
276 | pooling=True, drop_out=True, height_pooling=False, indice_key=None):
277 | super(ResBlock, self).__init__()
278 | self.pooling = pooling
279 | self.drop_out = drop_out
280 |
281 | self.conv1 = conv3x1(in_filters, out_filters, indice_key=indice_key + "bef")
282 | self.act1 = nn.LeakyReLU()
283 | self.bn0 = nn.BatchNorm1d(out_filters)
284 |
285 | self.conv1_2 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef")
286 | self.act1_2 = nn.LeakyReLU()
287 | self.bn0_2 = nn.BatchNorm1d(out_filters)
288 |
289 | self.conv2 = conv1x3(in_filters, out_filters, indice_key=indice_key + "bef")
290 | self.act2 = nn.LeakyReLU()
291 | self.bn1 = nn.BatchNorm1d(out_filters)
292 |
293 | self.conv3 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef")
294 | self.act3 = nn.LeakyReLU()
295 | self.bn2 = nn.BatchNorm1d(out_filters)
296 |
297 | if pooling:
298 | if height_pooling:
299 | self.pool = spconv.SparseConv3d(out_filters, out_filters, kernel_size=3, stride=2,
300 | padding=1, indice_key=indice_key, bias=False)
301 | else:
302 | self.pool = spconv.SparseConv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),
303 | padding=1, indice_key=indice_key, bias=False)
304 | self.weight_initialization()
305 |
306 | def weight_initialization(self):
307 | for m in self.modules():
308 | if isinstance(m, nn.BatchNorm1d):
309 | nn.init.constant_(m.weight, 1)
310 | nn.init.constant_(m.bias, 0)
311 |
312 | def forward(self, x):
313 | shortcut = self.conv1(x)
314 | shortcut.features = self.act1(shortcut.features)
315 | shortcut.features = self.bn0(shortcut.features)
316 |
317 | shortcut = self.conv1_2(shortcut)
318 | shortcut.features = self.act1_2(shortcut.features)
319 | shortcut.features = self.bn0_2(shortcut.features)
320 |
321 | resA = self.conv2(x)
322 | resA.features = self.act2(resA.features)
323 | resA.features = self.bn1(resA.features)
324 |
325 | resA = self.conv3(resA)
326 | resA.features = self.act3(resA.features)
327 | resA.features = self.bn2(resA.features)
328 |
329 | resA.features = resA.features + shortcut.features
330 |
331 | if self.pooling:
332 | resB = self.pool(resA)
333 | return resB, resA
334 | else:
335 | return resA
336 |
337 |
338 | class UpBlock(nn.Module):
339 | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), indice_key=None, up_key=None):
340 | super(UpBlock, self).__init__()
341 | # self.drop_out = drop_out
342 | self.trans_dilao = conv3x3(in_filters, out_filters, indice_key=indice_key + "new_up")
343 | self.trans_act = nn.LeakyReLU()
344 | self.trans_bn = nn.BatchNorm1d(out_filters)
345 |
346 | self.conv1 = conv1x3(out_filters, out_filters, indice_key=indice_key)
347 | self.act1 = nn.LeakyReLU()
348 | self.bn1 = nn.BatchNorm1d(out_filters)
349 |
350 | self.conv2 = conv3x1(out_filters, out_filters, indice_key=indice_key)
351 | self.act2 = nn.LeakyReLU()
352 | self.bn2 = nn.BatchNorm1d(out_filters)
353 |
354 | self.conv3 = conv3x3(out_filters, out_filters, indice_key=indice_key)
355 | self.act3 = nn.LeakyReLU()
356 | self.bn3 = nn.BatchNorm1d(out_filters)
357 | # self.dropout3 = nn.Dropout3d(p=dropout_rate)
358 |
359 | self.up_subm = spconv.SparseInverseConv3d(out_filters, out_filters, kernel_size=3, indice_key=up_key,
360 | bias=False)
361 |
362 | self.weight_initialization()
363 |
364 | def weight_initialization(self):
365 | for m in self.modules():
366 | if isinstance(m, nn.BatchNorm1d):
367 | nn.init.constant_(m.weight, 1)
368 | nn.init.constant_(m.bias, 0)
369 |
370 | def forward(self, x, skip):
371 | upA = self.trans_dilao(x)
372 | upA.features = self.trans_act(upA.features)
373 | upA.features = self.trans_bn(upA.features)
374 |
375 | ## upsample
376 | upA = self.up_subm(upA)
377 |
378 | upA.features = upA.features + skip.features
379 |
380 | upE = self.conv1(upA)
381 | upE.features = self.act1(upE.features)
382 | upE.features = self.bn1(upE.features)
383 |
384 | upE = self.conv2(upE)
385 | upE.features = self.act2(upE.features)
386 | upE.features = self.bn2(upE.features)
387 |
388 | upE = self.conv3(upE)
389 | upE.features = self.act3(upE.features)
390 | upE.features = self.bn3(upE.features)
391 |
392 | return upE
393 |
394 |
395 | class ReconBlock(nn.Module):
396 | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, indice_key=None):
397 | super(ReconBlock, self).__init__()
398 | self.conv1 = conv3x1x1(in_filters, out_filters, indice_key=indice_key + "bef")
399 | self.bn0 = nn.BatchNorm1d(out_filters)
400 | self.act1 = nn.Sigmoid()
401 |
402 | self.conv1_2 = conv1x3x1(in_filters, out_filters, indice_key=indice_key + "bef")
403 | self.bn0_2 = nn.BatchNorm1d(out_filters)
404 | self.act1_2 = nn.Sigmoid()
405 |
406 | self.conv1_3 = conv1x1x3(in_filters, out_filters, indice_key=indice_key + "bef")
407 | self.bn0_3 = nn.BatchNorm1d(out_filters)
408 | self.act1_3 = nn.Sigmoid()
409 |
410 | def forward(self, x):
411 | shortcut = self.conv1(x)
412 | shortcut.features = self.bn0(shortcut.features)
413 | shortcut.features = self.act1(shortcut.features)
414 |
415 | shortcut2 = self.conv1_2(x)
416 | shortcut2.features = self.bn0_2(shortcut2.features)
417 | shortcut2.features = self.act1_2(shortcut2.features)
418 |
419 | shortcut3 = self.conv1_3(x)
420 | shortcut3.features = self.bn0_3(shortcut3.features)
421 | shortcut3.features = self.act1_3(shortcut3.features)
422 | shortcut.features = shortcut.features + shortcut2.features + shortcut3.features
423 |
424 | shortcut.features = shortcut.features * x.features
425 |
426 | return shortcut
427 |
--------------------------------------------------------------------------------
/run_tta_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import os
3 | import time
4 | import argparse
5 | import sys
6 | import numpy as np
7 |
8 | import torch
9 | import torch.optim as optim
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from tqdm import tqdm
13 |
14 | from dataloader.pc_dataset_test import get_SemKITTI_label_name, get_eval_mask, unpack
15 | from config.config import load_config_data
16 | from builder import loss_builder
17 | from builder import model_builder_unlock as model_builder
18 | from builder import data_builder_test as data_builder
19 | from utils.load_save_util import load_checkpoint
20 | import warnings
21 | warnings.filterwarnings("ignore")
22 | import yaml
23 | from utils.util import Bresenham3D
24 | import random
25 |
26 | import ast
27 | import pdb
28 | import copy
29 | import importlib.util as ilutil
30 | import glob
31 | from torch.utils.tensorboard import SummaryWriter
32 | from matplotlib import pyplot as plt
33 | from utils.np_ioueval import iouEval
34 | from utils.softmax_entropy import softmax_entropy
35 |
36 | import pandas as pd
37 | import seaborn as sn
38 |
39 | CATS = ["empty",
40 | "car", "bicycle", "motorcycle", "truck", "other-vehicle",
41 | "person", "bicyclist", "motorcyclist", "road", "parking",
42 | "sidewalk", "other-ground", "building", "fence", "vegetation",
43 | "trunk", "terrain", "pole", "traffic-sign"]
44 |
45 | mapping_forward = {0:0,
46 | 1:0, 10:1, 11:2, 13:5, 15:3,
47 | 16:5, 18:4, 20:5, 30:6, 31:7,
48 | 32:8, 40:9, 44:10, 48:11, 49:12,
49 | 50:13, 51:14, 52:0, 60:9, 70:15,
50 | 71:16, 72:17, 80:18, 81:19, 99:0,
51 | 252:1, 253:7, 254:6, 255:8, 256:5,
52 | 257:5, 258:4, 259:5}
53 |
54 | max_key = max(mapping_forward.keys())
55 |
56 | MAP_ARRAY = np.zeros(max_key + 1, dtype=int)
57 |
58 | for key, value in mapping_forward.items():
59 | MAP_ARRAY[key] = value
60 |
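# Illustrative sketch: MAP_ARRAY collapses raw SemanticKITTI ids (including the
# moving classes 252-259) onto the training classes via fancy indexing.
def _example_label_lut():
    raw = np.array([10, 252, 40])  # car, moving-car, road
    assert (MAP_ARRAY[raw] == np.array([1, 1, 9])).all()
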
61 | PALLETE = np.asarray([[0, 0, 0],[245, 150, 100],[245, 230, 100],[150, 60, 30],[180, 30, 80],[255, 0, 0],[30, 30, 255],[200, 40, 255],[90, 30, 150],[255, 0, 255],
62 | [255, 150, 255],[75, 0, 75],[75, 0, 175],[0, 200, 255],[50, 120, 255],[0, 175, 0],[0, 60, 135],[80, 240, 150],[150, 240, 255],[0, 0, 255], [255,255,255]]).astype(np.uint8)
63 | PALLETE[:,[0,2]]=PALLETE[:,[2,0]]  # swap BGR -> RGB
64 | PALLETE_BINARY = np.asarray([[255,0,0], [0,255,0], [0,0,255], [0,0,0]]).astype(np.uint8)
65 |
66 | def train2SemKITTI(input_label):
67 |     # shift labels back by +1, inverting the uint8 trick (0 - 1 = 255) used to drop the 0 label
68 | return input_label + 1
69 |
70 | def get_remap_first(semkittiyaml):
71 | # make lookup table for mapping
72 | learning_map_inv = semkittiyaml["learning_map_inv"]
73 | maxkey = max(learning_map_inv.keys())
74 | # +100 hack making lut bigger just in case there are unknown labels
75 | remap_lut = np.zeros((maxkey + 100), dtype=np.int32)
76 | remap_lut[list(learning_map_inv.keys())] = list(learning_map_inv.values())
77 | return remap_lut,learning_map_inv
78 |
79 | def get_remap_second(semkittiyaml):
80 | class_remap = semkittiyaml["learning_map"]
81 | maxkey2 = max(class_remap.keys())
82 | remap_lut = np.zeros((maxkey2 + 100), dtype=np.int32)
83 | remap_lut[list(class_remap.keys())] = list(class_remap.values())
84 | remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid'
85 | remap_lut[0] = 0
86 | return remap_lut
87 |
88 |
89 | def remapping(pred, remap=None):
90 | ### save prediction after remapping
91 | upper_half = pred >> 16 # get upper half for instances
92 | lower_half = pred & 0xFFFF # get lower half for semantics
93 | lower_half = remap[lower_half] # do the remapping of semantics
94 | pred = (upper_half << 16) + lower_half # reconstruct full label
95 | pred = pred.astype(np.uint32)
96 | pred = pred.astype(np.uint16)
97 | return pred
98 |
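# Illustrative sketch: predictions in train-label space are mapped back to raw
# SemanticKITTI ids for submission, assuming a LUT built by get_remap_first.
def _example_remapping(remap_first):
    pred_train = np.array([1, 9], dtype=np.uint32)   # car, road
    return remapping(pred_train, remap=remap_first)  # -> array([10, 40], dtype=uint16)
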
99 | def save_pred(final_preds, save_dir, output_path):
100 | _, dir2 = save_dir.split('/sequences/',1)
101 | new_save_dir = output_path + '/sequences/' + dir2.replace('velodyne', 'predictions')[:-3]+'label'
102 | if not os.path.exists(os.path.dirname(new_save_dir)):
103 | try:
104 | os.makedirs(os.path.dirname(new_save_dir))
105 |         except FileExistsError:
106 |             # the directory was created concurrently; nothing to do
107 |             pass
108 | final_preds.tofile(new_save_dir)
109 |
110 | def extract_bev_for_vis(arr, ignore_idx=0):
111 | for i in arr[::-1]:
112 | if i != ignore_idx:
113 | return i
114 | return ignore_idx
115 |
116 | def extract_mev_for_vis(arr, ignore_idx=0):
117 | for i in arr:
118 | if i != ignore_idx:
119 | return i
120 | return ignore_idx
121 |
122 | def extract_bev_for_vis_dual(arr, ignore_idx=(0,255)):
123 | for i in arr[::-1]:
124 | if i not in ignore_idx:
125 | return i
126 | return 255
127 |
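# Illustrative sketch: these helpers scan a single height column and return the
# first non-ignored label from the top (bev) or bottom (mev); they are applied
# with np.apply_along_axis over axis 2 of the (256, 256, 32) label grid.
def _example_bev_column():
    col = np.array([0, 4, 9, 0])          # one z-column, 0 = empty
    assert extract_bev_for_vis(col) == 9  # topmost occupied label
    assert extract_mev_for_vis(col) == 4  # bottom-most occupied label
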
128 |
129 | def main(args):
130 | pytorch_device = torch.device('cuda:0')
131 | epsilon = np.finfo(np.float32).eps
132 |
133 | config_path = args.config_path
134 |
135 | configs = load_config_data(config_path)
136 |
137 | dataset_config = configs['dataset_params']
138 | train_dataloader_config = configs['train_data_loader']
139 | val_dataloader_config = configs['val_data_loader']
140 | val_batch_size = val_dataloader_config['batch_size']
141 |
142 | model_config = configs['model_params']
143 | train_hypers = configs['train_params']
144 |
145 | grid_size = model_config['output_shape']
146 | num_class = model_config['num_class']
147 | ignore_label = dataset_config['ignore_label']
148 | loss_fn_ce, loss_fn_lovasz = loss_builder.build(wce=True, lovasz=True, num_class=num_class, ignore_label=ignore_label)
149 | loss_fn_ce_binary, loss_fn_lovasz_binary = loss_builder.build(wce=True, lovasz=True, num_class=2, ignore_label=ignore_label)
150 |
151 | # Define dataset/loader
152 | with open("config/label_mapping/semantic-kitti.yaml", 'r') as stream:
153 | semkittiyaml = yaml.safe_load(stream)
154 | remap_first,class_inv_remap = get_remap_first(semkittiyaml)
155 | remap_second = get_remap_second(semkittiyaml)
156 | class_strings = semkittiyaml["labels"]
157 | strides=[int(num) for num in ast.literal_eval(args.stride)]
158 |
159 | test_dataset_loader, test_pt_dataset = data_builder.build(dataset_config,
160 | train_dataloader_config,
161 | val_dataloader_config,
162 | grid_size=grid_size,
163 | use_tta=True,
164 | use_multiscan=True,
165 | stride=args.stride,
166 | sq_num=args.sq_num)
167 |
168 | # Define experiment path
169 | exp_name = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) + '_' + __file__.replace('run_tta_','').replace('.py','') + args.name
170 | exp_path = args.talos_root+'experiments/' + exp_name
171 | print("Experiment path is "+exp_path)
172 | os.makedirs(exp_path, exist_ok=True)
173 |
174 | # Save run_code snapshots
175 | os.system('scp -r '+args.talos_root+__file__+' '+exp_path+'/'+__file__)
176 | #
177 |
178 | writer = SummaryWriter(exp_path)
179 | config_file = exp_path + '/config.txt'
180 | with open(config_file, 'w') as log:
181 | log.write(str(args))
182 |
183 | loss_cont_names = ['loss_cont', 'loss_cont_occ_ce', 'loss_cont_occ_lovasz', 'loss_cont_pgt_ce','loss_cont_pgt_lovasz']
184 | loss_adapt_names = ['loss_adapt', 'loss_adapt_occ_ce', 'loss_adapt_occ_lovasz', 'loss_adapt_pgt_ce','loss_adapt_pgt_lovasz']
185 | evaluator_all = iouEval(num_class, [])
186 | baseline_performance = open(args.baseline_perf_txt, 'r')
187 | baseline_prediction_paths = sorted(glob.glob(args.baseline_preds+'/*.label'))
188 |
189 | model_load_path = train_hypers['model_load_path']
190 | model_load_path += 'pretrained.pth'
191 |
192 | model_baseline = model_builder.build(model_config)
193 | print('Load model from: %s' % model_load_path)
194 | model_baseline = load_checkpoint(model_load_path, model_baseline)
195 | model_baseline.to(pytorch_device)
196 |
197 | # For freeze
198 | module_names_mlp = ['cylinder_3d_generator']
199 | module_names_comp = ['a_conv1', 'a_conv2', 'a_conv3', 'a_conv4', 'a_conv5', 'a_conv6', 'a_conv7', 'ch_conv1','res_1','res_2','res_3']
200 | module_names_seg = ['downCntx', 'resBlock2', 'resBlock3', 'resBlock4', 'resBlock5', 'upBlock0', 'upBlock1', 'upBlock2', 'upBlock3', 'ReconNet']
201 | module_names_logit = ['logits']
202 |
203 | assert (args.do_adapt or args.do_cont)
204 |
205 | if args.do_cont:
206 | print("continual")
207 | model_cont = copy.deepcopy(model_baseline)
208 | param_to_update_cont = []
209 |
210 | print("Update segmentation module for continual tta.")
211 | for tn in module_names_seg:
212 | param_to_update_cont += list(getattr(model_cont.cylinder_3d_spconv_seg, tn).parameters())
213 |
214 | optimizer_cont = optim.Adam(param_to_update_cont, lr=args.cont_lr)
215 |
216 | if args.do_adapt:
217 | print("scan-wise adaptation")
218 |
219 | # current_frame={'grid_ind','pt_feat','index','gt'}
220 | # prev_frame={'vox_grid_to','vox_grid_from','grid_ind','pt_feat','trans_grid_ind','trans_pt_feat','lidar_pose'}
221 | for idx_test, (frame_curr, frame_aux) in enumerate(tqdm(test_dataset_loader)):
222 |
223 | print('')
224 |
225 | #########################################################################################################################################
226 | ############################################################# Data process ##############################################################
227 | #########################################################################################################################################
228 |
229 | exist_stride = [stride in frame_curr['stride'] for stride in strides]
230 |
231 | flag_aux_exist= bool(len(frame_aux['trans_grid_ind'])!=0)
232 | flag_adapt_aux_exist= bool(exist_stride[0])
233 | flag_cont_aux_exist= bool(exist_stride[1])
234 | adapt_aux_idx= 0
235 | cont_aux_idx = 1 if flag_adapt_aux_exist else 0
236 |
237 | feat_curr = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in frame_curr['pt_feat']]
238 | grid_curr = [torch.from_numpy(i).to(pytorch_device) for i in frame_curr['grid_ind']]
239 | gt_curr = frame_curr['gt'].to(pytorch_device)
240 |
241 | if flag_aux_exist:
242 | feat_aux = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in frame_aux['pt_feat']]
243 | grid_aux = [torch.from_numpy(i).to(pytorch_device) for i in frame_aux['grid_ind']]
244 | else:
245 | print("The pointer is at the edge of the sequence: " + str(idx_test))
246 |
247 | #########################################################################################################################################
248 | ##################################################### Pseudo GT generation ##############################################################
249 | #########################################################################################################################################
250 |
251 | if args.use_los:
252 | print("Generate occupancy pgt by checking the line of sight")
253 | voxel_los_adapt = 255*torch.ones([1,256,256,32],dtype=torch.long).to(pytorch_device) # 0:empty, 1:occupied, 2:LoS, 255:ignore
254 | voxel_los_cont = 255*torch.ones([1,256,256,32],dtype=torch.long).to(pytorch_device) # 0:empty, 1:occupied, 2:LoS, 255:ignore
255 | if flag_adapt_aux_exist:
256 | idx_curr_occupied = frame_curr['grid_ind'][0]
257 | idx_curr_occupied = np.unique(idx_curr_occupied, axis =0)
258 | idx_aux_occupied = frame_aux['trans_grid_ind'][adapt_aux_idx]
259 | idx_aux_occupied = np.unique(idx_aux_occupied, axis=0)
260 | idx_cat_occupied = np.concatenate([idx_curr_occupied, idx_aux_occupied],0)
261 |
262 | voxel_aux = np.zeros([256,256,32])
263 | voxel_aux[idx_aux_occupied[:,0],idx_aux_occupied[:,1],idx_aux_occupied[:,2]] = 1 # Empty:0, Occupy:1
264 | voxel_curr = np.zeros([256,256,32])
265 | voxel_curr[idx_curr_occupied[:,0],idx_curr_occupied[:,1],idx_curr_occupied[:,2]] = 1 # Empty:0, Occupy:1
266 |
267 | voxel_aux_only = np.where(np.logical_and(voxel_curr!=1,voxel_aux==1), True, False)
268 |                 los_start = (np.floor((np.clip(frame_aux['lidar_pose'][0] , frame_curr['min_bound'][0], frame_curr['max_bound'][0]) - frame_curr['min_bound'][0]) / frame_curr['interval'][0])).astype(int)
269 | los_end = np.argwhere(voxel_aux_only==True)
270 | los_end = los_end[random.sample(range(los_end.shape[0]),los_end.shape[0]//8)]
271 |                 idx_los_empty = Bresenham3D(los_start,los_end,idx_cat_occupied)  # voxels crossed by the sensor-to-target rays, observed as free space
272 | idx_los_empty = np.unique(np.array(idx_los_empty), axis=0)
273 | idx_los_occupied = idx_aux_occupied
274 |
275 | voxel_los_adapt[0, idx_los_empty[:,0],idx_los_empty[:,1],idx_los_empty[:,2]] = 0
276 | voxel_los_adapt[0, idx_los_occupied[:,0],idx_los_occupied[:,1],idx_los_occupied[:,2]] = 1
277 |
278 | if flag_cont_aux_exist:
279 | idx_curr_occupied = frame_curr['grid_ind'][0]
280 | idx_curr_occupied = np.unique(idx_curr_occupied, axis =0)
281 | idx_aux_occupied = frame_aux['trans_grid_ind'][cont_aux_idx]
282 | idx_aux_occupied = np.unique(idx_aux_occupied, axis=0)
283 | idx_cat_occupied = np.concatenate([idx_curr_occupied, idx_aux_occupied],0)
284 |
285 | voxel_aux = np.zeros([256,256,32])
286 | voxel_aux[idx_aux_occupied[:,0],idx_aux_occupied[:,1],idx_aux_occupied[:,2]] = 1 # Empty:0, Occupy:1
287 | voxel_curr = np.zeros([256,256,32])
288 | voxel_curr[idx_curr_occupied[:,0],idx_curr_occupied[:,1],idx_curr_occupied[:,2]] = 1 # Empty:0, Occupy:1
289 |
290 | voxel_aux_only = np.where(np.logical_and(voxel_curr!=1,voxel_aux==1), True, False)
291 |                 los_start = (np.floor((np.clip(frame_aux['lidar_pose'][cont_aux_idx] , frame_curr['min_bound'][0], frame_curr['max_bound'][0]) - frame_curr['min_bound'][0]) / frame_curr['interval'][0])).astype(int)
292 | los_end = np.argwhere(voxel_aux_only==True)
293 | los_end = los_end[random.sample(range(los_end.shape[0]),los_end.shape[0]//30)]
294 | idx_los_empty = Bresenham3D(los_start,los_end,idx_cat_occupied)
295 | idx_los_empty = np.unique(np.array(idx_los_empty), axis=0)
296 | idx_los_occupied = idx_aux_occupied
297 |
298 | voxel_los_cont[0, idx_los_empty[:,0],idx_los_empty[:,1],idx_los_empty[:,2]] = 0
299 | voxel_los_cont[0, idx_los_occupied[:,0],idx_los_occupied[:,1],idx_los_occupied[:,2]] = 1
300 | else:
301 |                 print("Skipping LoS-based pgt generation.")
302 |
303 | dim_chn=20
304 |         conv_ones = nn.Conv3d(dim_chn, dim_chn, kernel_size=(5,5,3), stride=1, padding=(2,2,1), bias=False)  # all-ones kernel used to flag empty voxels near current predictions
305 | conv_ones.weight = torch.nn.Parameter(torch.ones((dim_chn,dim_chn,5,5,3)))
306 | conv_ones.weight.requires_grad=False
307 | conv_ones.cuda()
308 | proposure_idx=None
309 |
310 | if args.use_pgt:
311 | print("Generate class pgt according to entropy-based confidence.")
312 | model_baseline.eval()
313 | voxel_pgt_aux_cont = 255*torch.ones([1,256,256,32]).type(torch.LongTensor).to(pytorch_device)
314 | voxel_pgt_aux_adapt = 255*torch.ones([1,256,256,32]).type(torch.LongTensor).to(pytorch_device)
315 | voxel_pgt_curr = 255*torch.ones([1,256,256,32]).type(torch.LongTensor).to(pytorch_device)
316 | voxel_pgt_all_adapt = 255*torch.ones([1,256,256,32]).type(torch.LongTensor).to(pytorch_device)
317 | voxel_pgt_all_cont = 255*torch.ones([1,256,256,32]).type(torch.LongTensor).to(pytorch_device)
318 |
319 | with torch.no_grad():
320 |
321 | ### Current PGT
322 | print("- Generate class pgt using current scan.")
323 | pred_logit_curr = model_baseline(feat_curr, grid_curr, val_batch_size, frame_curr['grid_ind'], use_tta=False)
324 |
325 | pred_cls_curr = torch.argmax(pred_logit_curr, 1).type(torch.LongTensor).to(pytorch_device)
326 | mask_pred_curr_zeroforced = (torch.sum(pred_logit_curr, dim=1)==0) # Forced to be zero during dense-to-sparse (refer to spconv)
327 |
328 | conf_curr = softmax_entropy(pred_logit_curr)
329 | conf_curr[mask_pred_curr_zeroforced] = 0
330 | conf_curr = 1 - conf_curr/torch.max(conf_curr)
331 | conf_curr[mask_pred_curr_zeroforced] = -1
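    | # Entropy-to-confidence mapping: the per-voxel softmax entropy is
    | # normalized by its maximum and flipped, so conf = 1 is most and
    | # conf = 0 least confident. For 20 classes, a uniform prediction has
    | # entropy ln(20), about 3.0, and maps near 0, while a near-one-hot
    | # prediction maps near 1. Voxels whose logits spconv forced to all-zero
    | # are set to -1 so they can never pass the confidence thresholds.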
332 |
333 | cur_middle_rfield = conv_ones(pred_logit_curr)
334 | cur_middle_field = (torch.sum(cur_middle_rfield, dim=1) != 0)
335 | cur_empty_field = (torch.sum(pred_logit_curr, dim=1) == 0)
336 | proposure_idx = torch.where(cur_middle_field*cur_empty_field, 1, 0)
337 | proposure_idx = proposure_idx.squeeze().nonzero()
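    | # proposure_idx holds "frontier" voxel coordinates: empty in the current
    | # prediction (all-zero logits) yet with at least one predicted voxel in
    | # their 5x5x3 window. They are later passed to the network's extraction
    | # argument as extra query locations (see the eval and cont phases below).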
338 |
339 | mask_reliable_curr_occupied = torch.logical_and(conf_curr>args.th_pgt_occupied, pred_cls_curr!=0)
340 | mask_reliable_curr_empty = torch.logical_and(conf_curr>args.th_pgt_empty, pred_cls_curr==0)
341 | mask_reliable_curr = torch.logical_or(mask_reliable_curr_empty, mask_reliable_curr_occupied)
342 |
343 | voxel_pgt_curr[mask_reliable_curr] = pred_cls_curr[mask_reliable_curr]
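    | # A voxel keeps its class pseudo-label only if its confidence clears a
    | # state-dependent threshold: th_pgt_occupied (default 0.75) for occupied
    | # predictions, and the much stricter th_pgt_empty (default 0.999) for
    | # empty ones; every other voxel stays at the ignore label 255.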
344 |
345 | vis_voxel_pgt_curr = voxel_pgt_curr[0].cpu().detach().numpy()
346 | vis_voxel_pgt_curr = np.apply_along_axis(extract_bev_for_vis_dual, 2, vis_voxel_pgt_curr, ignore_idx=(0,255))
347 | vis_voxel_pgt_curr[vis_voxel_pgt_curr==255] = 20
348 | vis_voxel_pgt_curr = PALLETE[vis_voxel_pgt_curr]
349 |
350 | ### Aux PGT
351 | if flag_adapt_aux_exist:
352 | print("- Generate class pgt for adapt using auxiliary scan.")
353 |
354 | vox_grid_to = frame_aux['vox_grid_to'][adapt_aux_idx].astype(np.int64)
355 | vox_grid_from = frame_aux['vox_grid_from'][adapt_aux_idx].astype(np.int64)
356 |
357 | pred_logit_aux = model_baseline([feat_aux[adapt_aux_idx]], [grid_aux[adapt_aux_idx]], val_batch_size, [frame_aux['grid_ind'][adapt_aux_idx]], use_tta=False)
358 | pred_cls_aux = torch.argmax(pred_logit_aux, 1).type(torch.LongTensor).to(pytorch_device)
359 | mask_pred_aux_zeroforced = (torch.sum(pred_logit_aux, dim=1)==0) # Forced to be zero during dense-to-sparse (refer to spconv)
360 |
361 | conf_aux = softmax_entropy(pred_logit_aux)
362 | conf_aux[mask_pred_aux_zeroforced] = 0
363 | conf_aux = 1 - conf_aux/torch.max(conf_aux)
364 | conf_aux[mask_pred_aux_zeroforced] = -1
365 |
366 | mask_reliable_aux_occupied = torch.logical_and(conf_aux>args.th_pgt_occupied, pred_cls_aux!=0)
367 | mask_reliable_aux_empty = torch.logical_and(conf_aux>args.th_pgt_empty, pred_cls_aux==0)
368 | mask_reliable_aux = torch.logical_or(mask_reliable_aux_occupied, mask_reliable_aux_empty)
369 |
370 | pred_cls_aux[~mask_reliable_aux] = 255
371 | pred_cls_aux[mask_pred_aux_zeroforced] = 255
372 | voxel_pgt_aux_adapt[0, vox_grid_to[:, 0], vox_grid_to[:, 1], vox_grid_to[:, 2]] = pred_cls_aux[0, vox_grid_from[:, 0], vox_grid_from[:, 1], vox_grid_from[:, 2]]
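    | # The filtered aux predictions are warped into the current frame: as read
    | # from the code, vox_grid_from indexes cells of the aux voxel grid and
    | # vox_grid_to the corresponding cells after the pose transform, so the
    | # aux PGT lands in the same 256x256x32 volume as the current-scan PGT.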
373 |
374 | else:
375 | vis_voxel_pgt_aux_adapt = None
376 | print("- Skip to make aux pgt for adapt.")
377 |
378 | vis_voxel_pgt_aux_adapt = voxel_pgt_aux_adapt[0].cpu().detach().numpy()
379 | vis_voxel_pgt_aux_adapt = np.apply_along_axis(extract_bev_for_vis_dual, 2, vis_voxel_pgt_aux_adapt, ignore_idx=(0,255))
380 | vis_voxel_pgt_aux_adapt[vis_voxel_pgt_aux_adapt==255] = 20
381 | vis_voxel_pgt_aux_adapt = PALLETE[vis_voxel_pgt_aux_adapt]
382 |
383 | if flag_cont_aux_exist:
384 |
385 | print("- Generate class pgt for cont using auxiliary scan.")
386 |
387 | vox_grid_to = frame_aux['vox_grid_to'][cont_aux_idx].astype(np.int64)
388 | vox_grid_from = frame_aux['vox_grid_from'][cont_aux_idx].astype(np.int64)
389 |
390 | pred_logit_aux = model_baseline([feat_aux[cont_aux_idx]], [grid_aux[cont_aux_idx]], val_batch_size, [frame_aux['grid_ind'][cont_aux_idx]], use_tta=False)
391 | pred_cls_aux_cont = torch.argmax(pred_logit_aux, 1).type(torch.LongTensor).to(pytorch_device)
392 | mask_pred_aux_zeroforced = (torch.sum(pred_logit_aux, dim=1)==0) # Forced to be zero during dense-to-sparse (refer to spconv)
393 |
394 | conf_aux = softmax_entropy(pred_logit_aux)
395 | conf_aux[mask_pred_aux_zeroforced] = 0
396 | conf_aux = 1 - conf_aux/torch.max(conf_aux)
397 | conf_aux[mask_pred_aux_zeroforced] = -1
398 |
399 | mask_reliable_aux_occupied = torch.logical_and(conf_aux>args.th_pgt_occupied, pred_cls_aux_cont!=0)
400 | mask_reliable_aux_empty = torch.logical_and(conf_aux>args.th_pgt_empty, pred_cls_aux_cont==0)
401 | mask_reliable_aux = torch.logical_or(mask_reliable_aux_occupied, mask_reliable_aux_empty)
402 |
403 | pred_cls_aux_cont[~mask_reliable_aux] = 255
404 | pred_cls_aux_cont[mask_pred_aux_zeroforced] = 255
405 | voxel_pgt_aux_cont[0, vox_grid_to[:, 0], vox_grid_to[:, 1], vox_grid_to[:, 2]] = pred_cls_aux_cont[0, vox_grid_from[:, 0], vox_grid_from[:, 1], vox_grid_from[:, 2]]
406 |
407 | else:
408 | vis_voxel_pgt_aux_cont = None
409 | print("- Skip to make aux pgt for cont.")
410 |
411 | vis_voxel_pgt_aux_cont = voxel_pgt_aux_cont[0].cpu().detach().numpy()
412 | vis_voxel_pgt_aux_cont = np.apply_along_axis(extract_bev_for_vis_dual, 2, vis_voxel_pgt_aux_cont, ignore_idx=(0,255))
413 | vis_voxel_pgt_aux_cont[vis_voxel_pgt_aux_cont==255] = 20
414 | vis_voxel_pgt_aux_cont = PALLETE[vis_voxel_pgt_aux_cont]
415 |
416 |
417 | ### Aggregation of current PGT and aux PGT for adapt
418 | voxel_pgt_all_adapt = voxel_pgt_curr.clone()
419 | mask_temp = (voxel_pgt_all_adapt==255)*(voxel_pgt_aux_adapt!=255)
420 | voxel_pgt_all_adapt[mask_temp] = voxel_pgt_aux_adapt[mask_temp]
421 | mask_temp = (voxel_pgt_curr!=255)*(voxel_pgt_aux_adapt!=255)*(voxel_pgt_curr!=voxel_pgt_aux_adapt)
422 | voxel_pgt_all_adapt[mask_temp] = 255
423 | ### Aggregation of current PGT and aux PGT for cont
424 | voxel_pgt_all_cont = voxel_pgt_curr.clone()
425 | mask_temp = (voxel_pgt_all_cont==255)*(voxel_pgt_aux_cont!=255)
426 | voxel_pgt_all_cont[mask_temp] = voxel_pgt_aux_cont[mask_temp]
427 | mask_temp = (voxel_pgt_curr!=255)*(voxel_pgt_aux_cont!=255)*(voxel_pgt_curr!=voxel_pgt_aux_cont)
428 | voxel_pgt_all_cont[mask_temp] = 255
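    | # Aggregation rule for both variants: start from the current-scan PGT,
    | # fill voxels it marked unknown (255) from the aux PGT, then reset any
    | # voxel where the two hold different known labels back to 255, so only
    | # consistent evidence survives.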
429 |
430 |
431 | vis_voxel_pgt = voxel_pgt_all_adapt[0].cpu().detach().numpy()
432 | vis_voxel_pgt = np.apply_along_axis(extract_bev_for_vis_dual, 2, vis_voxel_pgt, ignore_idx=(0,255))
433 | vis_voxel_pgt[vis_voxel_pgt==255] = 20
434 | vis_voxel_pgt = PALLETE[vis_voxel_pgt]
435 |
436 | else:
437 | print("Skip to generate class pgt.")
438 |
439 | #########################################################################################################################################
440 | ########################################################## Scan-wise Adaptation #########################################################
441 | #########################################################################################################################################
442 |
443 | if args.do_adapt:
444 |
445 | print("Do scan-wise adaptation.")
446 | print("- From baseline model")
447 | model_adapt = copy.deepcopy(model_baseline)
448 |
449 | param_to_update_adapt = []
450 |
451 | print("- Scan-wise adapt segmentation module")
452 | for tn in module_names_seg:
453 | param_to_update_adapt += list(getattr(model_adapt.cylinder_3d_spconv_seg, tn).parameters())
454 |
455 | print("- Scan-wise adapt final logit layer")
456 | for tn in module_names_logit:
457 | param_to_update_adapt += list(getattr(model_adapt.cylinder_3d_spconv_seg, tn).parameters())
458 |
459 | optimizer_adapt = optim.Adam(param_to_update_adapt, lr=args.adapt_lr)
460 |
461 | model_adapt.train()
462 |
463 | # Partial freeze
464 | for name, param in model_adapt.named_parameters():
465 | #freeze_mlp
466 | if name.split('.')[0] in module_names_mlp:
467 | param.requires_grad = False
468 | #freeze_comp:
469 | if any(mona in name.split('.')[1] for mona in module_names_comp):
470 | param.requires_grad = False
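    | # Partial-freeze summary: only the parameters gathered above (the
    | # segmentation blocks in module_names_seg plus the logit layers in
    | # module_names_logit) receive updates; parameters matching
    | # module_names_mlp or module_names_comp (defined earlier in the script)
    | # are frozen.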
471 |
472 | for idx_adapt in range(args.adapt_iter):
473 |
474 | logit = model_adapt(feat_curr, grid_curr, val_batch_size, frame_curr['grid_ind'], use_tta=False) # (B,C,x,y,z)
475 | loss_adapt_occ_ce = 0
476 | loss_adapt_occ_lovasz = 0
477 | loss_adapt_pgt_ce = 0
478 | loss_adapt_pgt_lovasz = 0
479 |
480 | if args.use_los and flag_adapt_aux_exist:
481 | logit_empty = logit[:,:1,:,:,:]
482 | logit_occupied = logit[:,1:,:,:,:].max(dim=1, keepdim=True)[0]
483 | logit_comp = torch.cat((logit_empty, logit_occupied), dim=1) # (B,2,x,y,z)
484 | loss_adapt_occ_ce += loss_fn_ce_binary(logit_comp, voxel_los_adapt)
485 | loss_adapt_occ_lovasz += loss_fn_lovasz_binary(F.softmax(logit_comp, dim=1), voxel_los_adapt, ignore=255)
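    | # The 20-way logits are collapsed into a binary occupancy problem:
    | # channel 0 scores "empty" and the max over channels 1..19 scores
    | # "occupied", so the 0/1/255 LOS labels can supervise class-agnostic
    | # geometry.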
486 |
487 | if args.use_pgt:
488 | loss_adapt_pgt_ce += loss_fn_ce(logit, voxel_pgt_all_adapt)
489 | loss_adapt_pgt_lovasz += loss_fn_lovasz(F.softmax(logit, dim=1), voxel_pgt_all_adapt, ignore=255)
490 |
491 | optimizer_adapt.zero_grad()
492 | loss_adapt = args.weight_adapt_occ_ce*loss_adapt_occ_ce + args.weight_adapt_occ_lovasz*loss_adapt_occ_lovasz \
493 | + args.weight_adapt_pgt_ce*loss_adapt_pgt_ce + args.weight_adapt_pgt_lovasz*loss_adapt_pgt_lovasz
494 | if loss_adapt!=0 and not torch.isnan(loss_adapt):
495 | loss_adapt.backward()
496 | optimizer_adapt.step()
497 | else:
498 | print('Loss is zero or NaN; skipping adapt optimization.')
499 |
500 | # plot adapt losses
501 | for loss_name in loss_adapt_names:
502 | loss_now = locals()[loss_name]
503 | writer.add_scalar("loss_adapt/"+loss_name, loss_now, global_step=args.adapt_iter*idx_test+idx_adapt)
504 |
505 |
506 | else:
507 | print("Skip scan-wise adaptation.")
508 |
509 | #########################################################################################################################################
510 | ############################################################## Eval phase ###############################################################
511 | #########################################################################################################################################
512 |
513 | model_cont.eval()
514 | model_adapt.eval()
515 | with torch.no_grad():
516 | pred_logit = model_cont(feat_curr, grid_curr, val_batch_size, frame_curr['grid_ind'], use_tta=False, extraction=proposure_idx)
517 |
518 | pred = torch.argmax(pred_logit, dim=1)
519 |
520 | pred_logit_bs = model_adapt(feat_curr, grid_curr, val_batch_size, frame_curr['grid_ind'], use_tta=False)
521 | pred_bs = torch.argmax(pred_logit_bs, dim=1)
522 |
523 | mask_pred_ = (pred==9) | (pred==10) | (pred==11) | (pred==12) | (pred==13) | (pred==15) | (pred==16) | (pred==17)
524 | pred[~mask_pred_] = pred_bs[~mask_pred_]
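    | # Prediction fusion: under the standard SemanticKITTI learning map,
    | # indices 9-13 and 15-17 are static classes (road, parking, sidewalk,
    | # other-ground, building, vegetation, trunk, terrain). Those come from
    | # the continually adapted model (pred); every other voxel, including the
    | # dynamic classes, is taken from the scan-wise adapted model (pred_bs).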
525 |
526 | pred = pred.cpu().detach().numpy()
527 | pred = np.squeeze(pred)
528 | pred = pred.astype(np.uint32)
529 | pred = pred.reshape((-1))
530 |
531 | ### save prediction after remapping
532 | pred_remapped = remapping(pred, remap=remap_first)
533 | name_velodyne = test_pt_dataset.im_idx[frame_curr['index'][0]]
534 | save_pred(pred_remapped, name_velodyne, exp_path)
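    | # The saved labels are first passed through remapping with remap_first
    | # (built elsewhere in the script), presumably the inverse learning-map
    | # lookup back to raw SemanticKITTI label IDs, and written per scan under
    | # exp_path in the benchmark's submission layout.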
535 |
536 | #########################################################################################################################################
537 | ######################################################## Continual TTA phase ############################################################
538 | #########################################################################################################################################
539 |
540 | if args.do_cont:
541 |
542 | print("Do continual adaptation.")
543 |
544 | model_cont.train()
545 | # Partial freeze
546 | for name, param in model_cont.named_parameters():
547 | #freeze_mlp
548 | if name.split('.')[0] in module_names_mlp:
549 | param.requires_grad = False
550 | #freeze_comp
551 | if any(mona in name.split('.')[1] for mona in module_names_comp):
552 | param.requires_grad = False
553 | #freeze_logit
554 | if any(mona in name.split('.')[1] for mona in module_names_logit):
555 | param.requires_grad = False
556 |
557 | # Continual TTA loop
558 | for idx_cont in range(args.cont_iter):
559 |
560 | # Random masking
561 | mask_size = 4
562 | mask_ratio = 0.1
563 | upsample = nn.Upsample(scale_factor=mask_size, mode='nearest')
564 | mask_voxel = torch.zeros(int(256/mask_size), int(256/mask_size), int(32/mask_size))
565 | mask_voxel = mask_voxel.reshape(-1)
566 | rand_idx = torch.randperm(mask_voxel.shape[0])
567 | mask_number = int(rand_idx.shape[0]*mask_ratio)
568 | mask_patch = rand_idx[:mask_number]
569 | mask_voxel[mask_patch] = 1
570 | mask_voxel = mask_voxel.reshape(int(256/mask_size), int(256/mask_size), int(32/mask_size))
571 | mask_voxel = upsample(mask_voxel.unsqueeze(0).unsqueeze(0)).squeeze()
572 | voxel_curr_grid = torch.zeros([256,256,32])
573 | voxel_curr_grid[grid_curr[0][:,0],grid_curr[0][:,1],grid_curr[0][:,2]] = 1
574 | masked_voxel = voxel_curr_grid*mask_voxel
575 | masked_coords = masked_voxel.nonzero()
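    | # Random patch masking, as implemented above: the 256x256x32 grid is
    | # tiled into 4x4x4 patches (mask_size=4), 10% of patches (mask_ratio)
    | # are drawn at random, and nearest-neighbor upsampling expands the patch
    | # mask back to voxel resolution; masked_coords are the occupied voxels
    | # that fall inside masked patches.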
576 |
577 | matches = torch.nonzero((grid_curr[0][:, None] == masked_coords.cuda()).all(-1), as_tuple=True)
578 | masked_idx = matches[0]
579 | all_indices = torch.arange(grid_curr[0].shape[0])
580 | retain_idx = all_indices[~torch.isin(all_indices.cuda(), masked_idx)]
581 | feat_curr_retain = [feat_curr[0][retain_idx]]
582 | grid_curr_retain = [grid_curr[0][retain_idx]]
583 | frame_curr_grid_ind_retain = [frame_curr['grid_ind'][0][retain_idx.cpu().detach().numpy()]]
584 | grid_curr_masked = grid_curr[0][masked_idx]
585 | grid_curr_masked = torch.cat([grid_curr_masked, proposure_idx], 0)
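    | # Only the retained (unmasked) points are fed forward; the masked voxel
    | # coordinates, concatenated with the frontier proposals in proposure_idx,
    | # are handed to extraction as query locations, and the losses below
    | # supervise those queries with the LOS and class pseudo-labels built
    | # earlier in this iteration.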
586 | logit = model_cont(feat_curr_retain, grid_curr_retain, val_batch_size, frame_curr_grid_ind_retain, use_tta=False, extraction=grid_curr_masked) # (B,C,x,y,z)
587 | pred = torch.argmax(logit, dim=1) # (B,x,y,z)
588 |
589 | loss_cont_occ_ce = 0
590 | loss_cont_occ_lovasz = 0
591 | loss_cont_pgt_ce = 0
592 | loss_cont_pgt_lovasz = 0
593 |
594 | if args.use_los and flag_cont_aux_exist:
595 | logit_empty = logit[:,:1,:,:,:]
596 | logit_occupied = logit[:,1:,:,:,:].max(dim=1, keepdim=True)[0]
597 | logit_comp = torch.cat((logit_empty, logit_occupied), dim=1) # (B,2,x,y,z)
598 | loss_cont_occ_ce += loss_fn_ce_binary(logit_comp, voxel_los_cont)
599 | loss_cont_occ_lovasz += loss_fn_lovasz_binary(F.softmax(logit_comp, dim=1), voxel_los_cont, ignore=255)
600 |
601 | if args.use_pgt:
602 | loss_cont_pgt_ce += loss_fn_ce(logit, voxel_pgt_all_cont)
603 | loss_cont_pgt_lovasz += loss_fn_lovasz(F.softmax(logit, dim=1), voxel_pgt_all_cont, ignore=255)
604 |
605 | optimizer_cont.zero_grad()
606 | loss_cont = args.weight_cont_occ_ce*loss_cont_occ_ce + args.weight_cont_occ_lovasz*loss_cont_occ_lovasz \
607 | + args.weight_cont_pgt_ce*loss_cont_pgt_ce + args.weight_cont_pgt_lovasz*loss_cont_pgt_lovasz
608 | if loss_cont!=0 and not torch.isnan(loss_cont):
609 | loss_cont.backward()
610 | optimizer_cont.step()
611 | else:
612 | print('Loss is zero or NaN; skipping cont optimization.')
613 |
614 | # plot cont tta losses
615 | for loss_name in loss_cont_names:
616 | loss_now = locals()[loss_name]
617 | writer.add_scalar("loss_cont/"+loss_name, loss_now, global_step=args.cont_iter*idx_test+idx_cont)
618 | else:
619 | print("Skip continual adaptation.")
620 |
621 |
622 | if __name__ == '__main__':
623 | # Training settings
624 | parser = argparse.ArgumentParser(description='')
625 |
626 | # Sources
627 | parser.add_argument('--talos_root', default='./', type=str)
628 | parser.add_argument('--config_path', default='config/semantickitti-tta.yaml')
629 | parser.add_argument('--baseline_perf_txt', default='baseline_performance.txt', type=str)
630 | parser.add_argument('--baseline_preds', default='experiments/baseline/sequences/08/predictions', type=str)
631 | parser.add_argument('--sq_num', default='8', type=str)
632 | # Experiment
633 | parser.add_argument('--name', default='debug')
634 | parser.add_argument('--ang', action='store_true', help='Auto Name Generator')
635 | parser.add_argument('--loader', default='data_builder', type=str)
636 | parser.add_argument('--stride', default='[-5,5]', type=str)
637 | parser.add_argument('--do_cont', action='store_true')
638 | parser.add_argument('--do_adapt', action='store_true')
639 |
640 | # Attributes
641 | parser.add_argument('--use_los', action='store_true')
642 | parser.add_argument('--use_pgt', action='store_true')
643 | parser.add_argument('--th_pgt_occupied', default=0.75, type=float)
644 | parser.add_argument('--th_pgt_empty', default=0.999, type=float)
645 |
646 | # Optimization (cont)
647 | parser.add_argument('--cont_lr', default=3e-05, type=float)
648 | parser.add_argument('--cont_iter', default=1, type=int)
649 | parser.add_argument('--weight_cont_occ_ce', default=1, type=float)
650 | parser.add_argument('--weight_cont_occ_lovasz', default=1, type=float)
651 | parser.add_argument('--weight_cont_pgt_ce', default=1, type=float)
652 | parser.add_argument('--weight_cont_pgt_lovasz', default=1, type=float)
653 |
654 | # Optimization (adapt)
655 | parser.add_argument('--adapt_lr', default=0.0003, type=float)
656 | parser.add_argument('--adapt_iter', default=3, type=int)
657 | parser.add_argument('--weight_adapt_occ_ce', default=1, type=float)
658 | parser.add_argument('--weight_adapt_occ_lovasz', default=1, type=float)
659 | parser.add_argument('--weight_adapt_pgt_ce', default=1, type=float)
660 | parser.add_argument('--weight_adapt_pgt_lovasz', default=1, type=float)
661 |
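    | # Example invocation (hypothetical paths; all flags are defined above):
    | #   python run_tta_test.py --ang --use_los --use_pgt --do_adapt --do_cont \
    | #       --config_path config/semantickitti-tta.yaml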
662 | args = parser.parse_args()
663 |
664 | args.baseline_perf_txt = args.talos_root+args.baseline_perf_txt
665 | args.baseline_preds = args.talos_root+args.baseline_preds
666 |
667 | print(' '.join(sys.argv))
668 | print(args)
669 | print('#####')
670 | print('Stride: '+str(args.stride))
671 | print('#####')
672 |
673 | if args.ang:
674 |
675 | args.name = ''
676 | args.name += '_stride'+str(args.stride)
677 |
678 | if args.use_los:
679 | args.name += '_los'
680 | if args.use_pgt:
681 | args.name += '_pgt'
682 | if args.th_pgt_occupied != 0.75:
683 | args.name += '_thocc'+str(args.th_pgt_occupied)
684 | if args.th_pgt_empty != 0.999:
685 | args.name += '_themp'+str(args.th_pgt_empty)
686 |
687 | if args.do_cont:
688 | args.name += '_cont'
689 | args.name += '_clr'+str(args.cont_lr)
690 | args.name += '_cit'+str(args.cont_iter)
691 |
692 | if args.do_adapt:
693 | args.name += '_adapt'
694 | args.name += '_alr'+str(args.adapt_lr)
695 | args.name += '_ait'+str(args.adapt_iter)
696 |
697 |
698 |
699 | print(args.name)
700 |
701 | main(args)
702 |
--------------------------------------------------------------------------------